import os
import sys
import glob
import tqdm
import json
import pickle
import varname
from objprint import objstr
from rich.console import Console
import cv2
from PIL import Image
import numpy as np
import torch
from kiui.typing import *
from kiui.env import is_imported
[docs]
def lo(*xs, verbose=0):
"""inspect array like objects and report statistics.
Args:
xs (Any): array like objects to inspect.
verbose (int, optional): level of verbosity, set to 1 to report mean and std, 2 to print the content. Defaults to 0.
"""
console = Console()
def _lo(x, name):
if isinstance(x, np.ndarray):
# general stats
text = ""
text += f"[orange1]Array {name}[/orange1] {x.shape} {x.dtype}"
if x.size > 0:
text += f" ∈ [{x.min()}, {x.max()}]"
if verbose >= 1:
text += f" μ = {x.mean()} σ = {x.std()}"
# detect abnormal values
if np.isnan(x).any():
text += "[red] NaN![/red]"
if np.isinf(x).any():
text += "[red] Inf![/red]"
console.print(text)
# show values if shape is small or verbose is high
if x.size < 50 or verbose >= 2:
# np.set_printoptions(precision=4)
print(x)
elif torch.is_tensor(x):
# general stats
text = ""
text += f"[orange1]Tensor {name}[/orange1] {x.shape} {x.dtype} {x.device}"
if x.numel() > 0:
text += f" ∈ [{x.min().item()}, {x.max().item()}]"
if verbose >= 1:
text += f" μ = {x.mean().item()} σ = {x.std().item()}"
# detect abnormal values
if torch.isnan(x).any():
text += "[red] NaN![/red]"
if torch.isinf(x).any():
text += "[red] Inf![/red]"
console.print(text)
# show values if shape is small or verbose is high
if x.numel() < 50 or verbose >= 2:
# np.set_printoptions(precision=4)
print(x)
else: # other type, just print them
console.print(f"[orange1]{type(x)} {name}[/orange1] {objstr(x)}")
# inspect names
for i, x in enumerate(xs):
try:
name = varname.argname(f"xs[{i}]", func=lo)
except:
name = f"UNKNOWN"
_lo(x, name)
[docs]
def seed_everything(seed=42, verbose=False, strict=False):
"""auto set seed for random, numpy and torch.
Args:
seed (int, optional): random seed. Defaults to 42.
verbose (bool, optional): whether to report each seed setting. Defaults to False.
strict (bool, optional): whether to use strict deterministic mode for better torch reproduction. Defaults to False.
"""
os.environ['PYTHONHASHSEED'] = str(seed)
if is_imported('random'):
import random # still need to import it here
random.seed(seed)
if verbose: print(f'[INFO] set random.seed = {seed}')
else:
if verbose: print(f'[INFO] random not imported, skip setting seed')
# assume numpy is imported as np
if is_imported('np'):
import numpy as np
np.random.seed(seed)
if verbose: print(f'[INFO] set np.random.seed = {seed}')
else:
if verbose: print(f'[INFO] numpy not imported, skip setting seed')
if is_imported('torch'):
import torch
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if verbose: print(f'[INFO] set torch.manual_seed = {seed}')
if strict:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
if verbose: print(f'[INFO] set strict deterministic mode for torch.')
else:
if verbose: print(f'[INFO] torch not imported, skip setting seed')
[docs]
def read_json(path):
"""load a json file.
Args:
path (str): path to json file.
Returns:
dict: json content.
"""
with open(path, "r") as f:
return json.load(f)
[docs]
def write_json(path, x):
"""write a json file.
Args:
path (str): path to write json file.
x (dict): dict to write.
"""
with open(path, "w") as f:
json.dump(x, f, indent=2)
[docs]
def read_pickle(path):
"""read a pickle file.
Args:
path (str): path to pickle file.
Returns:
Any: pickle content.
"""
with open(path, "rb") as f:
return pickle.load(f)
[docs]
def write_pickle(path, x):
"""write a pickle file.
Args:
path (str): path to write pickle file.
x (Any): content to write.
"""
with open(path, "wb") as f:
pickle.dump(x, f)
[docs]
def read_image(
path: str,
mode: Literal["float", "uint8", "pil", "torch", "tensor"] = "float",
order: Literal["RGB", "RGBA", "BGR", "BGRA"] = "RGB",
):
"""read an image file into various formats and color mode.
Args:
path (str): path to the image file.
mode (Literal["float", "uint8", "pil", "torch", "tensor"], optional): returned image format. Defaults to "float".
float: float32 numpy array, range [0, 1];
uint8: uint8 numpy array, range [0, 255];
pil: PIL image;
torch/tensor: float32 torch tensor, range [0, 1];
order (Literal["RGB", "RGBA", "BGR", "BGRA"], optional): channel order. Defaults to "RGB".
Note:
By default this function will convert RGBA image to white-background RGB image. Use ``order="RGBA"`` to keep the alpha channel.
Returns:
Union[np.ndarray, PIL.Image, torch.Tensor]: the image array.
"""
if mode == "pil":
return Image.open(path).convert(order)
if path.endswith('.exr'):
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
img = cv2.imread(path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
else:
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
# cvtColor
if len(img.shape) == 3: # ignore if gray scale
if order in ["RGB", "RGBA"]:
if img.shape[-1] == 4:
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
elif img.shape[-1] == 3:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# mix background
if img.shape[-1] == 4 and 'A' not in order:
img = img.astype(np.float32) / 255
img = img[..., :3] * img[..., 3:] + (1 - img[..., 3:])
# mode
if mode == "uint8":
if img.dtype != np.uint8:
img = (img * 255).astype(np.uint8)
return img
elif mode == "float":
if img.dtype == np.uint8:
img = img.astype(np.float32) / 255
return img
elif mode in ["tensor", "torch"]:
if img.dtype == np.uint8:
img = img.astype(np.float32) / 255
return torch.from_numpy(img)
else:
raise ValueError(f"Unknown read_image mode {mode}")
[docs]
def write_image(
path: str,
img: Union[Tensor, np.ndarray, Image.Image],
order: Literal["RGB", "BGR"] = "RGB",
):
"""write an image to various formats.
Args:
path (str): path to write the image file.
img (Union[torch.Tensor, np.ndarray, PIL.Image.Image]): image to write.
order (str, optional): channel order. Defaults to "RGB".
"""
if isinstance(img, Image.Image):
img.save(path)
return
if torch.is_tensor(img):
img = img.detach().cpu().numpy()
if img.dtype == np.float32 or img.dtype == np.float64:
img = (img * 255).astype(np.uint8)
if len(img.shape) == 4:
if img.shape[0] > 1:
raise ValueError(f'only support saving a single image! current image: {img.shape}')
img = img[0]
if len(img.shape) == 3:
# cvtColor
if order == "RGB":
if img.shape[-1] == 4:
img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
elif img.shape[-1] == 3:
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
dir_path = os.path.dirname(path)
if dir_path != '' and not os.path.exists(dir_path):
os.makedirs(os.path.dirname(path), exist_ok=True)
cv2.imwrite(path, img)
[docs]
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Load file form http url, will download models if necessary.
Args:
url (str): URL to be downloaded.
model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
Default: None.
progress (bool): Whether to show the download progress. Default: True.
file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
Returns:
str: The path to the downloaded file.
"""
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse
if model_dir is None: # use the pytorch hub_dir
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
os.makedirs(model_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(model_dir, filename))
if not os.path.exists(cached_file):
print(f'[INFO] Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
[docs]
def batch_process_files(
process_fn, path, out_path,
overwrite=False,
in_format=[".jpg", ".jpeg", ".png"],
out_format=None,
image_mode='uint8',
image_color_order="RGB",
**kwargs
):
"""simple function wrapper to batch processing files.
Args:
process_fn (Callable): process function.
path (str): path to a file or a directory containing the files to process.
out_path (str): output path of a file or a directory.
overwrite (bool, optional): whether to overwrite existing results. Defaults to False.
in_format (list, optional): input file formats. Defaults to [".jpg", ".jpeg", ".png"].
out_format (str, optional): output file format. Defaults to None.
image_mode (str, optional): for images, the mode to read. Defaults to 'uint8'.
image_color_order (str, optional): for images, the color order. Defaults to "RGB".
"""
if os.path.isdir(path):
file_paths = glob.glob(os.path.join(path, "*"))
file_paths = [f for f in file_paths if is_format(f, in_format)]
else:
file_paths = [path]
if os.path.dirname(out_path) != '':
os.makedirs(os.path.dirname(out_path), exist_ok=True)
for file_path in tqdm.tqdm(file_paths):
try:
if len(file_paths) == 1:
file_out_path = out_path
else:
file_out_path = os.path.join(out_path, os.path.basename(file_path))
if out_format is not None:
file_out_path = os.path.splitext(file_out_path)[0] + out_format
if os.path.exists(file_out_path) and not overwrite:
print(f"[INFO] ignoring {file_path} --> {file_out_path}")
continue
# dispatch loader
if is_format(file_path, ['.jpg', '.jpeg', '.png']):
input = read_image(file_path, mode=image_mode, order=image_color_order)
elif is_format(file_path, ['.ply', '.obj', '.glb', '.gltf']):
from kiui.mesh import Mesh
input = Mesh.load(file_path)
else:
with open(file_path, "r") as f:
input = f.read()
# process
output = process_fn(input, **kwargs)
# dispatch writer
if is_format(file_out_path, ['.jpg', '.jpeg', '.png']):
write_image(file_out_path, output, order=image_color_order)
elif is_format(file_out_path, ['.ply', '.obj', '.glb', '.gltf']):
output.write(file_out_path)
elif is_format(file_out_path, ['.npy']):
np.save(file_out_path, output)
else:
with open(file_out_path, "w") as f:
f.write(output)
except Exception as e:
print(f"[ERROR] when processing {file_path} --> {file_out_path}")
print(e)