import time
import torch
import numpy as np
from datetime import datetime
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from kiui.typing import *
from kiui.utils import lo, write_image
[docs]
def map_color(value: ndarray, cmap_name: str = "viridis", vmin: float = None, vmax: float = None):
    """ map an array of scalars to continuous colors.

    Args:
        value (ndarray): array of float, [N] (any leading shape works, the color axis is appended last).
        cmap_name (str, optional): color map name, ref: https://matplotlib.org/stable/users/explain/colors/colormaps.html#classes-of-colormaps. Defaults to "viridis".
        vmin (float, optional): value mapped to the start of the color map. Defaults to None (use ``value.min()``).
        vmax (float, optional): value mapped to the end of the color map. Defaults to None (use ``value.max()``).

    Returns:
        ndarray: array of color, [N, 3] in [0, 1]
    """
    value = np.asarray(value)

    # nothing to map: min()/max() would raise on an empty array
    if value.size == 0:
        return np.zeros(value.shape + (3,), dtype=np.float64)

    if vmin is None:
        vmin = value.min()
    if vmax is None:
        vmax = value.max()

    span = vmax - vmin
    if span == 0:
        # constant input: avoid 0/0 -> NaN, map everything to the start of the color map
        value = np.zeros(value.shape, dtype=np.float64)
    else:
        value = (value - vmin) / span  # range in [0, 1]

    # cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9; pyplot.get_cmap is still available
    cmap = plt.get_cmap(cmap_name)

    # the color map returns rgba, keep only the first 3 channels to get rgb
    rgb = np.asarray(cmap(value))[..., :3]
    return rgb
[docs]
def plot_image(*xs, normalize=False, save=False, prefix='kiui_vis_plot_image'):
    """ show the given images one after another, or write them to the current dir instead.

    Args:
        xs (Sequence[Union[torch.Tensor, numpy.ndarray]]): can be uint8 or float32.
            [B, 4/3/1, H, W], [B, H, W, 4/3/1], [4/3/1, H, W], [H, W, 4/3/1], [H, W] torch.Tensor or numpy.ndarray
        normalize (bool, optional): whether to rescale the image with the min/max taken along its first axis (after the channel-last conversion). Defaults to False.
        save (bool, optional): whether to save the image to current dir instead of showing it (in case the plot cannot be showed, like in vscode remote). Defaults to False.
        prefix (str, optional): image save name prefix if save=True.
    """
    # one timestamp per call, shared by every file written during this call
    stamp = datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')
    n_saved = 0

    def _render(img):
        nonlocal n_saved
        lo(img)

        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()

        # uint8 images are brought to float in [0, 1]
        if img.dtype == np.uint8:
            img = img.astype(np.float32) / 255.0

        # heuristic: a 3D array whose first dim is smaller than its last is treated as channel-first
        looks_channel_first = len(img.shape) == 3 and img.shape[0] < img.shape[-1]
        if looks_channel_first:
            img = img.transpose(1, 2, 0)

        if normalize:
            low = img.min(axis=0, keepdims=True)
            high = img.max(axis=0, keepdims=True)
            # the small constant keeps the division defined when high == low
            img = (img - low) / (high - low + 1e-8)

        img = img.astype(np.float32)

        if not save:
            plt.imshow(img)
            plt.show()
            return

        out_path = f'{prefix}_{stamp}_{n_saved}.png'
        n_saved += 1
        write_image(out_path, img)
        print(f'[INFO] write image to {out_path}')

    for x in xs:
        # a 4D input is a batch and is unrolled along its first axis; 3D / 2D inputs are a single image
        frames = x if len(x.shape) == 4 else (x,)
        for frame in frames:
            _render(frame)
[docs]
def plot_matrix(*xs):
    """ visualize some 2D matrix, different from ``kiui.vis.plot_image``, this will keep the original range and plot channel-by-channel.

    Args:
        xs (Sequence[Union[torch.Tensor, numpy.ndarray]]): [B, C, H, W], [C, H, W], or [H, W] torch.Tensor or numpy.ndarray

    Note:
        3D inputs are always read as channel-first ([C, H, W]), for both torch.Tensor and numpy.ndarray.
    """
    def _plot_matrix(matrix):
        lo(matrix)

        # convert first, so tensors and arrays go through exactly the same code path below
        if isinstance(matrix, torch.Tensor):
            matrix = matrix.detach().cpu().numpy()

        if len(matrix.shape) == 3:
            # per channel: iterate the leading (channel) axis, which also keeps
            # size-1 H / W dims intact (a bare squeeze() would drop them)
            for channel in matrix:
                plt.matshow(channel.astype(np.float32))
                plt.show()
        else:
            plt.matshow(matrix.astype(np.float32))
            plt.show()

    for x in xs:
        if len(x.shape) == 4:
            # batched input: plot each sample separately
            for i in range(x.shape[0]):
                _plot_matrix(x[i])
        else:  # 3 or 2
            _plot_matrix(x)
[docs]
def plot_pointcloud(pc, color=None):
    """plot point cloud.

    Points without color or with RGB color are shown with open3d; points with
    RGBA color are shown with trimesh, together with an axis and a box outline.

    Args:
        pc (ndarray): point cloud positions, float [N, 3]. torch.Tensor is also accepted.
        color (ndarray, optional): point cloud colors, float in [0, 1] or uint8 in [0, 255], [N, 3/4]. Defaults to None.

    Note:
        This function requires a desktop (cannot be forwarded by ssh)!
    """
    lo(pc)
    if torch.is_tensor(pc):
        pc = pc.detach().cpu().numpy()

    if color is not None:
        lo(color)
        if torch.is_tensor(color):
            color = color.detach().cpu().numpy()
        color = np.asarray(color)

    if color is None or color.shape[-1] == 3:
        # use o3d as it's better to control
        import open3d as o3d

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(np.asarray(pc, dtype=np.float64))
        if color is not None:
            # open3d expects float64 colors in [0, 1], so integer colors are scaled down
            rgb = color.astype(np.float64)
            if not np.issubdtype(color.dtype, np.floating):
                rgb = rgb / 255.0
            pcd.colors = o3d.utility.Vector3dVector(np.clip(rgb, 0.0, 1.0))
        o3d.visualization.draw_geometries([pcd])
    else:
        import trimesh

        # trimesh gets uint8 colors in [0, 255]; any float dtype (not only float32) is scaled up
        if np.issubdtype(color.dtype, np.floating):
            color = (np.clip(color, 0.0, 1.0) * 255).astype(np.uint8)
        cloud = trimesh.PointCloud(pc, color)
        # axis
        axes = trimesh.creation.axis(axis_length=4)
        # box outline with extents (2, 2, 2)
        box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
        box.colors = np.array([[128, 128, 128]] * len(box.entities))
        trimesh.Scene([cloud, axes, box]).show()
[docs]
def plot_poses(poses, size=0.1, bound=1, points=None, mesh=None, opengl=True):
    """plot camera poses.

    Args:
        poses (ndarray): camera poses, float [N, 4, 4]. A single [4, 4] pose and torch.Tensor are also accepted.
        size (float, optional): size of the camera frustum and of the colored axis lines drawn per camera. Defaults to 0.1.
        bound (int, optional): bounding box bound, the outline box has extents ``2 * bound``. Defaults to 1.
        points (ndarray, optional): also draw point clouds, float [M, 3]. Defaults to None.
        mesh (trimesh.Trimesh, optional): also draw mesh. Defaults to None.
        opengl (bool, optional): use OpenGL camera convention (camera looks along -z). Defaults to True.

    Note:
        This function requires a desktop (cannot be forwarded by ssh)!
    """
    lo(poses)

    if torch.is_tensor(poses):
        poses = poses.detach().cpu().numpy()
    poses = np.asarray(poses)
    if poses.ndim == 2:
        # a single [4, 4] pose: add the batch dim so the loop below sees one camera
        poses = poses[None]

    import trimesh

    def _outline_box(extent):
        # gray wireframe cube with the given edge length, centered at the origin
        outline = trimesh.primitives.Box(extents=[extent] * 3).as_outline()
        outline.colors = np.array([[128, 128, 128]] * len(outline.entities))
        return outline

    def _colored_segment(start, end, rgba):
        # single line segment from start to end, drawn in one color
        segment = trimesh.load_path(np.array([[start, end]]))
        segment.colors = np.array([rgba]).repeat(len(segment.entities), axis=0)
        return segment

    objects = [trimesh.creation.axis(axis_length=4), _outline_box(2 * bound)]

    if bound > 1:
        # also draw the unit box as a scale reference
        objects.append(_outline_box(2))

    # the viewing direction is -z in the OpenGL convention, +z otherwise
    forward_sign = -1 if opengl else 1

    for pose in poses:
        # a camera is visualized with 9 line segments:
        # 4 frustum edges, 4 image-plane edges, and 1 line along the viewing direction.
        pos = pose[:3, 3]
        right = size * pose[:3, 0]
        up = size * pose[:3, 1]
        forward = size * pose[:3, 2] * forward_sign

        # image-plane corners
        a = pos + right + up + forward
        b = pos - right + up + forward
        c = pos - right - up + forward
        d = pos + right - up + forward

        # construct 3D paths
        frame = np.array([
            [pos, a],
            [pos, b],
            [pos, c],
            [pos, d],
            [a, b],
            [b, c],
            [c, d],
            [d, a],
            [pos, pos + pose[:3, 2] * forward_sign * 3],  # point to target
        ])
        objects.append(trimesh.load_path(frame))

        # camera axes: right (red), up (green), z axis (blue, not flipped by opengl)
        objects.append(_colored_segment(pos, pos + pose[:3, 0] * size, [255, 0, 0, 255]))
        objects.append(_colored_segment(pos, pos + pose[:3, 1] * size, [0, 255, 0, 255]))
        objects.append(_colored_segment(pos, pos + pose[:3, 2] * size, [0, 0, 255, 255]))

    if points is not None:
        lo(points)
        if torch.is_tensor(points):
            points = points.detach().cpu().numpy()
        colors = np.zeros((points.shape[0], 4), dtype=np.uint8)
        colors[:, 2] = 255  # blue
        colors[:, 3] = 30  # transparent
        objects.append(trimesh.PointCloud(points, colors))

    if mesh is not None:
        objects.append(mesh)

    scene = trimesh.Scene(objects)
    scene.set_camera(distance=bound, center=[0, 0, 0])
    scene.show()