import torch
import numpy as np
from kiui.op import safe_normalize
import pymeshlab as pml
from importlib.metadata import version
PML_VER = version('pymeshlab')

# The functions in this file call pymeshlab through its newest (2023.12) names.
# For older releases, the names used here are aliased onto their older
# equivalents so the rest of the file does not need to branch on the version.
if PML_VER.startswith('0.2'):
    # 0.2.x: alias only the filters this file actually calls (new name <- old name).
    for _new_name, _old_name in (
        ('meshing_decimation_quadric_edge_collapse', 'simplification_quadric_edge_collapse_decimation'),
        ('meshing_isotropic_explicit_remeshing', 'remeshing_isotropic_explicit_remeshing'),
        ('meshing_remove_unreferenced_vertices', 'remove_unreferenced_vertices'),
        ('meshing_merge_close_vertices', 'merge_close_vertices'),
        ('meshing_remove_duplicate_faces', 'remove_duplicate_faces'),
        ('meshing_remove_null_faces', 'remove_zero_area_faces'),
        ('meshing_remove_connected_component_by_diameter', 'remove_isolated_pieces_wrt_diameter'),
        ('meshing_remove_connected_component_by_face_number', 'remove_isolated_pieces_wrt_face_num'),
        ('meshing_repair_non_manifold_edges', 'repair_non_manifold_edges_by_removing_faces'),
        ('meshing_repair_non_manifold_vertices', 'repair_non_manifold_vertices_by_splitting'),
    ):
        setattr(pml.MeshSet, _new_name, getattr(pml.MeshSet, _old_name))
    # do not leave the loop variables behind in the module namespace
    del _new_name, _old_name

    # value wrappers: 0.2.x takes a plain float where newer versions use PureValue
    pml.PercentageValue = pml.Percentage
    pml.PureValue = float
elif PML_VER.startswith('2022.2'):
    # 2022.2: the filters already have the new names, only the value wrappers differ
    pml.PercentageValue = pml.Percentage
    pml.PureValue = pml.AbsoluteValue
[docs]
def decimate_mesh(
    verts, faces, target=5e4, backend="pymeshlab", remesh=False, optimalplacement=True, verbose=True
):
    """ perform mesh decimation.

    Args:
        verts (np.ndarray): mesh vertices, float [N, 3]
        faces (np.ndarray): mesh faces, int [M, 3]
        target (int): targeted number of faces. Cast to ``int`` before being handed to
            either backend, so the float default ``5e4`` is accepted.
        backend (str, optional): algorithm backend, can be "pymeshlab" or "pyfqmr".
            Any value other than "pyfqmr" selects pymeshlab. Defaults to "pymeshlab".
        remesh (bool, optional): whether to remesh after decimation. Only applied by the
            pymeshlab backend. Defaults to False.
        optimalplacement (bool, optional): For flat mesh, use False to prevent spikes.
            Only used by the pymeshlab backend. Defaults to True.
        verbose (bool, optional): whether to print the decimation process. Defaults to True.

    Returns:
        Tuple[np.ndarray]: vertices and faces after decimation.
    """
    # remember the input sizes for the summary printed at the end
    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    # both backends expect an integral face count; the default (5e4) is a float
    target_faces = int(target)

    if backend == "pyfqmr":
        # imported lazily so pyfqmr is only required when this backend is chosen
        import pyfqmr

        solver = pyfqmr.Simplify()
        solver.setMesh(verts, faces)
        solver.simplify_mesh(target_count=target_faces, preserve_border=False, verbose=False)
        # getMesh also returns normals, which this function does not use
        verts, faces, _ = solver.getMesh()
    else:
        m = pml.Mesh(verts, faces)
        ms = pml.MeshSet()
        ms.add_mesh(m, "mesh")  # will copy!

        # filters
        ms.meshing_decimation_quadric_edge_collapse(
            targetfacenum=target_faces, optimalplacement=optimalplacement
        )

        if remesh:
            ms.meshing_isotropic_explicit_remeshing(
                iterations=3, targetlen=pml.PercentageValue(1)
            )

        # extract mesh (re-fetch it: the MeshSet works on its own copy)
        m = ms.current_mesh()
        m.compact()
        verts = m.vertex_matrix()
        faces = m.face_matrix()

    if verbose:
        print(f"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}")

    return verts, faces
[docs]
def clean_mesh(
    verts,
    faces,
    v_pct=1,
    min_f=0,
    min_d=0,
    repair=False,
    remesh=False,
    remesh_size=0.01,
    remesh_iters=3,
    verbose=True,
):
    """ perform mesh cleaning, including floater removal, non manifold repair, and remeshing.

    The steps run in a fixed order: drop unreferenced vertices, merge close vertices,
    drop duplicate and null faces, drop small connected components, optionally repair
    non-manifold geometry, optionally remesh.

    Args:
        verts (np.ndarray): mesh vertices, float [N, 3]
        faces (np.ndarray): mesh faces, int [M, 3]
        v_pct (int, optional): threshold for merging close vertices, passed to pymeshlab as a
            ``PercentageValue``; merging is skipped unless it is > 0. Defaults to 1.
        min_f (int, optional): passed as ``mincomponentsize`` to the face-number based component
            removal; the step is skipped unless it is > 0. Defaults to 0.
        min_d (int, optional): passed (as a ``PercentageValue``) as ``mincomponentdiag`` to the
            diameter based component removal; the step is skipped unless it is > 0. Defaults to 0.
        repair (bool, optional): whether to repair non-manifold edges and vertices
            (cannot guarantee). Defaults to False.
        remesh (bool, optional): whether to perform a remeshing after all cleaning. Defaults to False.
        remesh_size (float, optional): the targeted edge length for remeshing. Defaults to 0.01.
        remesh_iters (int, optional): the iterations of remeshing. Defaults to 3.
        verbose (bool, optional): whether to print the cleaning process. Defaults to True.

    Returns:
        Tuple[np.ndarray]: vertices and faces after cleaning.
    """
    # keep the input sizes around for the summary line
    verts_shape_before, faces_shape_before = verts.shape, faces.shape

    mesh_set = pml.MeshSet()
    mesh_set.add_mesh(pml.Mesh(verts, faces), "mesh")  # the MeshSet stores a copy

    # --- basic cleanup ---
    mesh_set.meshing_remove_unreferenced_vertices()  # vertices no face refers to
    if v_pct > 0:
        merge_threshold = pml.PercentageValue(v_pct)
        mesh_set.meshing_merge_close_vertices(threshold=merge_threshold)
    mesh_set.meshing_remove_duplicate_faces()  # faces built from the same vertices
    mesh_set.meshing_remove_null_faces()  # zero-area faces

    # --- floater removal ---
    if min_d > 0:
        mesh_set.meshing_remove_connected_component_by_diameter(
            mincomponentdiag=pml.PercentageValue(min_d)
        )
    if min_f > 0:
        mesh_set.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)

    # --- optional repair; be careful: may lead to strangely missing triangles ---
    if repair:
        mesh_set.meshing_repair_non_manifold_edges(method=0)
        mesh_set.meshing_repair_non_manifold_vertices(vertdispratio=0)

    # --- optional remeshing, always last ---
    if remesh:
        edge_length = pml.PureValue(remesh_size)
        mesh_set.meshing_isotropic_explicit_remeshing(
            iterations=remesh_iters, targetlen=edge_length
        )

    # pull the processed mesh back out of the MeshSet
    cleaned = mesh_set.current_mesh()
    cleaned.compact()
    verts = cleaned.vertex_matrix()
    faces = cleaned.face_matrix()

    if verbose:
        print(f"[INFO] mesh cleaning: {verts_shape_before} --> {verts.shape}, {faces_shape_before} --> {faces.shape}")

    return verts, faces
[docs]
def normalize_mesh(vertices, bound=0.95):
    """ center the mesh and rescale it to fit inside a cube.

    The axis-aligned bounding box is moved so its center sits at the origin, then all
    axes are scaled by the same factor so that the longest side of the box spans
    ``[-bound, bound]``. The input array is not modified.

    Args:
        vertices (np.ndarray): mesh vertices, float [N, 3]
        bound (float, optional): half side length of the target cube. Defaults to 0.95.

    Returns:
        np.ndarray: normalized vertices.
    """
    lower, upper = vertices.min(0), vertices.max(0)
    center = (upper + lower) / 2
    # uniform scale: the largest extent becomes 2 * bound
    scale = 2 * bound / np.max(upper - lower)
    return (vertices - center) * scale
### mesh related losses
[docs]
def laplacian_smooth_loss(verts, faces):
    """ calculate laplacian smooth loss.

    Applies the matrix returned by ``laplacian_uniform`` to the vertices and averages
    the per-row L2 norm of the result.

    Args:
        verts (torch.Tensor): mesh vertices, float [N, 3]
        faces (torch.Tensor): mesh faces, int [M, 3]

    Returns:
        torch.Tensor: loss value.
    """
    # Only building the Laplacian is excluded from autograd; the product with
    # verts below is computed outside the no_grad block.
    with torch.no_grad():
        laplacian = laplacian_uniform(verts, faces.long())

    return laplacian.mm(verts).norm(dim=1).mean()
[docs]
@torch.no_grad()
def compute_edge_to_face_mapping(faces):
    """ compute edge to face mapping.

    Every triangle contributes its three edges (v0, v1), (v1, v2), (v2, v0). Edges are
    made undirected by putting the smaller vertex index first and are then de-duplicated.
    For each unique edge, column 0 holds a face that lists the edge with the smaller
    index first, and column 1 holds a face that lists it with the larger index first.
    A column that no face writes keeps its initial value 0, so such an edge refers to
    face 0 in that column.

    The result is allocated on the same device as ``faces``.

    Args:
        faces (torch.Tensor): mesh faces, int [M, 3]

    Returns:
        torch.Tensor: indices to faces for each edge, long, [E, 2] with E the number of unique edges.
    """
    device = faces.device

    # Create all edges, packed by triangle: rows 3*i .. 3*i+2 belong to face i
    all_edges = torch.cat((
        torch.stack((faces[:, 0], faces[:, 1]), dim=-1),
        torch.stack((faces[:, 1], faces[:, 2]), dim=-1),
        torch.stack((faces[:, 2], faces[:, 0]), dim=-1),
    ), dim=-1).view(-1, 2)

    # order is 1 where the edge is stored with the larger index first, else 0;
    # gathering with (order, 1 - order) therefore yields (min, max) for each edge
    order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
    sorted_edges = torch.cat((
        torch.gather(all_edges, 1, order),
        torch.gather(all_edges, 1, 1 - order),
    ), dim=-1)

    # Eliminate duplicates; idx_map[k] is the unique-edge row of edge k
    unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)

    # face index of every entry in all_edges: 0, 0, 0, 1, 1, 1, ...
    tris = torch.arange(faces.shape[0], device=device).repeat_interleave(3)

    tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64, device=device)

    # Compute edge to face table: the original orientation of an edge selects the column
    mask0 = order[:, 0] == 0
    mask1 = order[:, 0] == 1
    tris_per_edge[idx_map[mask0], 0] = tris[mask0]
    tris_per_edge[idx_map[mask1], 1] = tris[mask1]

    return tris_per_edge
[docs]
def normal_consistency(verts, faces, face_normals=None):
    """ calculate normal consistency loss.

    For every unique edge, the normals of the two faces recorded for that edge by
    ``compute_edge_to_face_mapping`` are compared; the loss is the mean of
    ``1 - dot(n0, n1)`` over all edges.

    Args:
        verts (torch.Tensor): mesh vertices, float [N, 3]
        faces (torch.Tensor): mesh faces, int [M, 3]
        face_normals (Optional[torch.Tensor]): the normal vector for each face, will be calculated
            if not provided, float [M, 3]. Normals passed in are used as given (they are not
            normalized here).

    Returns:
        torch.Tensor: loss value.
    """
    if face_normals is None:
        i0, i1, i2 = faces[:, 0].long(), faces[:, 1].long(), faces[:, 2].long()
        v0, v1, v2 = verts[i0, :], verts[i1, :], verts[i2, :]

        # dim must be explicit: without it torch.cross uses the first dimension of
        # size 3, which is the face dimension when the mesh has exactly 3 faces
        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
        face_normals = safe_normalize(face_normals)

    tris_per_edge = compute_edge_to_face_mapping(faces)

    # Fetch normals for both faces sharing an edge
    n0 = face_normals[tris_per_edge[:, 0], :]
    n1 = face_normals[tris_per_edge[:, 1], :]

    # Compute error metric based on normal difference: the dot product is clamped
    # to [-1, 1], so each term lies in [0, 2]
    term = torch.clamp(torch.sum(n0 * n1, -1, keepdim=True), min=-1.0, max=1.0)
    term = (1.0 - term)

    return torch.mean(torch.abs(term))