#!/usr/bin/env python3

# meshlab_finetune_openmvs.py
# 2025-10-31
# by Gernot Walzl

# This PyMeshLab script fine-tunes a textured mesh
# reconstructed by OpenMVS.

# To install PyMeshLab:
# python3 -m venv pymeshlab.venv
# source pymeshlab.venv/bin/activate
# python3 -m pip install pymeshlab

import argparse
import numpy as np
import pymeshlab


def _vertex_face_adjacency(faces, num_verts):
    """Map each vertex index to the ascending list of face indices using it.

    Each face is listed at most once per vertex, even if it is degenerate
    and references the same vertex more than once."""
    vert_faces = [[] for _ in range(num_verts)]
    for face_idx, face in enumerate(faces):
        # set(): a degenerate face repeating a vertex is still listed
        # only once for that vertex.
        for vert_idx in set(face):
            vert_faces[vert_idx].append(face_idx)
    return vert_faces


def replace_vertex_color_by_neighbors(
        mesh, color_irgb_to_be_replaced=(255, 255, 255), max_iterations=100):
    """Recolor vertices that carry the "empty" color of OpenMVS.

    OpenMVS colors triangles in the reconstructed mesh that are not
    visible in any input image in a specific color (empty color).
    Every vertex in the empty color is recolored with the mean color of
    its adjacent faces. Each face contributes the mean color of those of
    its vertices that are not in the empty color. The pass is repeated,
    so larger empty areas are filled progressively from their border.

    Args:
        mesh: pymeshlab mesh with per-vertex colors.
        color_irgb_to_be_replaced: (R, G, B) of the empty color, each
            component 0..255. Only fully opaque vertices (alpha 255) match.
        max_iterations: upper bound on the number of recoloring passes.
            Iteration stops earlier once the number of empty-colored
            vertices is zero or stops decreasing.

    Returns:
        A new pymeshlab.Mesh with the vertices, faces, edges, normals,
        scalars, face colors and texture coordinates of ``mesh`` and the
        updated vertex colors. No texture images are passed to it.
    """

    faces = mesh.face_matrix()
    vert_colors = mesh.vertex_color_matrix()
    vert_color_arr = mesh.vertex_color_array()

    # Empty color in the packed ARGB layout used by vertex_color_array(),
    # with full opacity.
    color_i_to_be_replaced = (
        0xff000000 |
        (color_irgb_to_be_replaced[0] & 0xff) << 16 |
        (color_irgb_to_be_replaced[1] & 0xff) << 8 |
        (color_irgb_to_be_replaced[2] & 0xff))

    # The connectivity never changes, so build the vertex -> faces map once
    # instead of scanning the whole face matrix for every vertex and pass.
    face_list = faces.tolist()
    vert_faces = _vertex_face_adjacency(face_list, len(vert_color_arr))

    num_verts_ctbr_before = 0
    for _ in range(max_iterations):
        verts_ctbr_idx = (vert_color_arr == color_i_to_be_replaced).nonzero()[0]
        # Stop when nothing is left to recolor or the previous pass made
        # no progress.
        if num_verts_ctbr_before == len(verts_ctbr_idx):
            break
        for vert_ctbr_idx in verts_ctbr_idx:
            faces_adj_colors = []
            for face_adj_idx in vert_faces[vert_ctbr_idx]:
                # Only vertices that already have a real color contribute.
                # vert_color_arr is updated in place below, so vertices
                # recolored earlier in this same pass already count.
                vertices_adj_colors = [
                    vert_colors[vertex_adj_idx]
                    for vertex_adj_idx in face_list[face_adj_idx]
                    if vert_color_arr[vertex_adj_idx] != color_i_to_be_replaced]
                if vertices_adj_colors:
                    faces_adj_colors.append(
                        np.mean(vertices_adj_colors, axis=0))
            if faces_adj_colors:
                color_frgba = np.mean(faces_adj_colors, axis=0)
                vert_colors[vert_ctbr_idx] = color_frgba
                # Keep the packed ARGB copy in sync with the float RGBA
                # color (channels are truncated to 0..255).
                # http://vcglib.net/color4_8h_source.html
                vert_color_arr[vert_ctbr_idx] = (
                    (int(color_frgba[3] * 255.0) & 0xff) << 24 |
                    (int(color_frgba[0] * 255.0) & 0xff) << 16 |
                    (int(color_frgba[1] * 255.0) & 0xff) << 8 |
                    (int(color_frgba[2] * 255.0) & 0xff))
        num_verts_ctbr_before = len(verts_ctbr_idx)

    return pymeshlab.Mesh(
        vertex_matrix=mesh.vertex_matrix(),
        face_matrix=mesh.face_matrix(),
        edge_matrix=mesh.edge_matrix(),
        v_normals_matrix=mesh.vertex_normal_matrix(),
        f_normals_matrix=mesh.face_normal_matrix(),
        v_scalar_array=(
            mesh.vertex_scalar_array() if mesh.has_vertex_scalar()
            else np.empty((0, 1))),
        f_scalar_array=(
            mesh.face_scalar_array() if mesh.has_face_scalar()
            else np.empty((0, 1))),
        v_color_matrix=vert_colors,
        f_color_matrix=(
            mesh.face_color_matrix() if mesh.has_face_color()
            else np.empty((0, 4))),
        v_tex_coords_matrix=(
            mesh.vertex_tex_coord_matrix() if mesh.has_vertex_tex_coord()
            else np.empty((0, 2))),
        w_tex_coords_matrix=(
            mesh.wedge_tex_coord_matrix() if mesh.has_wedge_tex_coord()
            else np.empty((0, 2))))


def main(inputfilename, outputfilename,
         color_to_be_replaced=(255, 255, 255), scale=1.0,
         min_component_size=1000, smoothing_steps=5):
    """Clean up an OpenMVS textured mesh, bake its texture into vertex
    colors and save the result.

    Args:
        inputfilename: path of the mesh to load.
        outputfilename: path the processed mesh is written to.
        color_to_be_replaced: (R, G, B), each 0..255, of the "empty" color
            whose vertices are recolored from their neighbors.
        scale: uniform scale factor; 1.0 skips the scaling step.
        min_component_size: connected components with fewer faces than
            this are removed.
        smoothing_steps: number of Laplacian smoothing steps.
    """
    ms = pymeshlab.MeshSet()
    ms.load_new_mesh(inputfilename)

    # Remove small unconnected objects from the scene
    # Cleaning and Repairing > Remove Isolated pieces (wrt Face num.)
    ms.meshing_remove_connected_component_by_face_number(
        mincomponentsize=min_component_size)

    # In COLMAP, the y-axis points downwards.
    # After a rotation of -90°, the z-axis points upwards.
    # Normals, Curvatures and Orientation > Transform: Rotate
    ms.compute_matrix_from_rotation(angle=-90.0)

    # Normals, Curvatures and Orientation > Transform: Scale, Normalize
    if scale != 1.0:
        ms.compute_matrix_from_scaling_or_normalization(
            axisx=scale, axisy=scale, axisz=scale)

    # Move the center of the mesh to the origin
    # of the coordinate system (0, 0, 0)
    # Normals, Curvatures and Orientation > Transform: Translate, Center, set Origin
    ms.compute_matrix_from_translation(traslmethod=2)

    # Make the surface of the mesh a bit smoother
    # Smoothing, Fairing and Deformation > Laplacian Smooth
    ms.apply_coord_laplacian_smoothing(stepsmoothnum=smoothing_steps)

    # Texture > Transfer: Texture to Vertex Color (1 or 2 meshes)
    ms.transfer_texture_to_color_per_vertex()

    # NOTE(review): the save below relies on add_mesh() making the
    # recolored copy the current mesh — confirm against pymeshlab docs.
    ms.add_mesh(replace_vertex_color_by_neighbors(
        ms.current_mesh(), color_to_be_replaced))

    # Only vertex colors are written (binary, no textures).
    # NOTE(review): these keyword arguments presumably assume a PLY
    # output (the CLI default is output.ply) — verify for other formats.
    # https://pymeshlab.readthedocs.io/en/latest/io_format_list.html#save-mesh-parameters
    ms.save_current_mesh(
        outputfilename,
        save_textures=False,
        binary=True,
        save_vertex_quality=False,
        save_vertex_flag=False,
        save_vertex_color=True,
        save_vertex_coord=False,
        save_vertex_normal=False,
        save_vertex_radius=False,
        save_face_quality=False,
        save_face_flag=False,
        save_face_color=False,
        save_wedge_color=False,
        save_wedge_texcoord=False,
        save_wedge_normal=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Fine-tune a textured mesh reconstructed by OpenMVS.')
    parser.add_argument(
        'inputfile', type=str,
        help='textured mesh to process')
    parser.add_argument(
        '-o', '--outputfile', type=str,
        default='output.ply',
        help='output mesh file (default: %(default)s)')
    parser.add_argument(
        '-ctbr', '--color-to-be-replaced', type=int, nargs=3,
        metavar=('R', 'G', 'B'),
        default=[255, 255, 255],
        help='empty color to be recolored from neighboring vertices, '
             'each component 0..255 (default: 255 255 255)')
    parser.add_argument(
        '-s', '--scale', type=float,
        default=1.0,
        help='uniform scale factor (default: %(default)s)')

    args = parser.parse_args()
    # Out-of-range components would otherwise be silently masked to 8 bits.
    if not all(0 <= c <= 255 for c in args.color_to_be_replaced):
        parser.error('--color-to-be-replaced components must be in 0..255')
    main(args.inputfile, args.outputfile, args.color_to_be_replaced, args.scale)
