# Source code for kiui.grid_put

import torch
import torch.nn.functional as F

from kiui.typing import *

def stride_from_shape(shape):
    """Return row-major (C-order) element strides for the given shape.

    Example: ``[2, 3, 4] -> [12, 4, 1]``. The innermost stride is always 1,
    and an empty shape yields ``[1]``.

    Args:
        shape: sequence of dimension sizes (list, tuple or ``torch.Size``).

    Returns:
        list[int]: one stride per dimension (always at least one entry).
    """
    strides = [1]
    # Walk from the innermost dimension outward; each stride is the product of
    # every dimension size to its right, so shape[0] never contributes.
    for dim_size in reversed(shape[1:]):
        strides.insert(0, strides[0] * dim_size)
    return strides


def scatter_add_nd(input, indices, values):
    """Accumulate ``values`` into ``input`` at integer N-d ``indices``, in place.

    ``input`` is treated as a grid of D leading dimensions followed by one
    channel dimension. Each row of ``indices`` is turned into a flat row-major
    offset, and the matching row of ``values`` is added there with
    ``scatter_add_``. Repeated indices accumulate.

    Args:
        input: tensor [..., C] (D grid dims + C channels). It is written through
            ``input.view(-1, C)``, so it must be viewable that way.
        indices: long tensor [N, D].
        values: tensor [N, C].

    Returns:
        Tensor shaped like ``input`` (a view sharing its storage).
    """
    grid_shape = input.shape[:-1]
    channels = input.shape[-1]
    num_dims = indices.shape[-1]
    stride_list = stride_from_shape(grid_shape)

    assert len(grid_shape) == num_dims

    flat_input = input.view(-1, channels)  # [prod(grid_shape), C]
    strides = torch.tensor(stride_list, dtype=torch.long, device=indices.device)
    flat_idx = (indices * strides).sum(-1)  # [N]
    scatter_idx = flat_idx.unsqueeze(1).repeat(1, channels)  # [N, C]
    flat_input.scatter_add_(0, scatter_idx, values)

    return flat_input.view(*grid_shape, channels)


def scatter_add_nd_with_count(input, count, indices, values, weights=None):
    """Accumulate ``values`` into ``input`` and ``weights`` into ``count`` at N-d ``indices``.

    Both accumulators are modified in place through flat views. Each row of
    ``indices`` is converted to a row-major offset over the D grid dimensions.

    Args:
        input: tensor [..., C] (D grid dims + C channels), summed values.
        count: tensor [..., 1] over the same D grid dims, summed weights.
        indices: long tensor [N, D].
        values: tensor [N, C].
        weights: tensor [N, 1]; defaults to ones (i.e. plain hit counting).

    Returns:
        ``(input, count)`` reshaped back to ``[..., C]`` and ``[..., 1]``.
    """
    grid_shape = input.shape[:-1]
    channels = input.shape[-1]
    num_dims = indices.shape[-1]
    stride_list = stride_from_shape(grid_shape)

    assert len(grid_shape) == num_dims

    flat_input = input.view(-1, channels)  # [prod(grid_shape), C]
    flat_count = count.view(-1, 1)  # [prod(grid_shape), 1]

    strides = torch.tensor(stride_list, dtype=torch.long, device=indices.device)
    flat_idx = (indices * strides).sum(-1).unsqueeze(1)  # [N, 1]

    if weights is None:
        # Every sample contributes a weight of one.
        weights = torch.ones_like(values[..., :1])

    flat_input.scatter_add_(0, flat_idx.repeat(1, channels), values)
    flat_count.scatter_add_(0, flat_idx, weights)

    return flat_input.view(*grid_shape, channels), flat_count.view(*grid_shape, 1)

def nearest_grid_put_2d(H, W, coords, values, return_count=False):
    """Splat each value onto its nearest pixel of an H x W grid and average collisions.

    Args:
        H, W: output grid size.
        coords: float [N, 2], expected in [-1, 1]. Column 0 is scaled onto the H
            axis and column 1 onto the W axis (-1 -> index 0, 1 -> last index).
            Coordinates are not clamped.
        values: [N, C].
        return_count: if True, skip normalization and also return the hit counts.

    Returns:
        [H, W, C] averaged grid (pixels that received nothing stay 0), or
        ``(sum [H, W, C], count [H, W, 1])`` when ``return_count`` is True.
    """
    num_channels = values.shape[-1]

    scale = torch.tensor([H - 1, W - 1], dtype=torch.float32, device=coords.device)
    pixel = ((coords * 0.5 + 0.5) * scale).round().long()  # [N, 2]

    accum = torch.zeros(H, W, num_channels, device=values.device, dtype=values.dtype)
    hits = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)
    unit = torch.ones_like(values[..., :1])  # [N, 1]: every sample counts once

    accum, hits = scatter_add_nd_with_count(accum, hits, pixel, values, unit)

    if return_count:
        return accum, hits

    filled = hits.squeeze(-1) > 0
    accum[filled] = accum[filled] / hits[filled].repeat(1, num_channels)
    return accum


def linear_grid_put_2d(H, W, coords, values, return_count=False):
    """Bilinearly splat values onto an H x W grid and normalize by the splatted weight.

    Each sample is spread over the 2x2 pixels around its continuous position.
    The same bilinear weights are accumulated into a weight map that is used
    for normalization.

    Args:
        H, W: output grid size.
        coords: float [N, 2], expected in [-1, 1]. Column 0 maps onto the H axis
            and column 1 onto the W axis.
        values: [N, C].
        return_count: if True, skip normalization and also return the weight map.

    Returns:
        [H, W, C] normalized grid (pixels with zero weight stay 0), or
        ``(weighted sum [H, W, C], weight [H, W, 1])`` when ``return_count``.
    """
    num_channels = values.shape[-1]

    scale = torch.tensor([H - 1, W - 1], dtype=torch.float32, device=coords.device)
    pos = (coords * 0.5 + 0.5) * scale  # [N, 2] continuous (row, col)

    base = pos.floor().long()  # [N, 2] top-left corner of the 2x2 footprint
    # Keep the footprint inside the grid. A sample on the last row/column is
    # attributed to the cell ending there, so its fractional offset becomes 1.
    base[:, 0].clamp_(0, H - 2)
    base[:, 1].clamp_(0, W - 2)

    frac_h = pos[..., 0] - base[..., 0].float()
    frac_w = pos[..., 1] - base[..., 1].float()

    accum = torch.zeros(H, W, num_channels, device=values.device, dtype=values.dtype)
    weight_sum = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)
    unit = torch.ones_like(values[..., :1])  # [N, 1]

    # Corner order (00, 01, 10, 11) fixes the accumulation order.
    for dh, dw in ((0, 0), (0, 1), (1, 0), (1, 1)):
        corner = base + torch.tensor([dh, dw], dtype=torch.long, device=pos.device)
        corner_weight = (frac_h if dh else 1 - frac_h) * (frac_w if dw else 1 - frac_w)
        accum, weight_sum = scatter_add_nd_with_count(
            accum,
            weight_sum,
            corner,
            values * corner_weight.unsqueeze(1),
            unit * corner_weight.unsqueeze(1),
        )

    if return_count:
        return accum, weight_sum

    filled = weight_sum.squeeze(-1) > 0
    accum[filled] = accum[filled] / weight_sum[filled].repeat(1, num_channels)
    return accum

def mipmap_linear_grid_put_2d(H, W, coords, values, min_resolution=32, return_count=False):
    """Bilinearly splat values onto an H x W grid, filling holes from coarser levels.

    The full-resolution level is always splatted. Coarser levels then halve each
    side (integer division) while ``min(cur_H, cur_W) > min_resolution`` and
    there are still empty pixels. A level stops at size 2, since a
    size-1 level cannot hold a 2x2 bilinear footprint. Each level's weighted
    sums and weights are bilinearly upsampled to (H, W) with
    ``align_corners=False`` and added only to pixels whose accumulated weight is
    still exactly zero when that level starts.

    Args:
        H, W: output grid size.
        coords: float [N, 2], expected in [-1, 1].
        values: [N, C].
        min_resolution: coarser levels are only used while both sides exceed this.
        return_count: if True, skip normalization and also return the weight map.

    Returns:
        [H, W, C] normalized grid (pixels with zero weight stay 0), or
        ``(weighted sum [H, W, C], weight [H, W, 1])`` when ``return_count``.
    """
    C = values.shape[-1]

    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]

    cur_H, cur_W = H, W
    first_level = True

    # Always splat at full resolution, even when the grid is already no larger
    # than min_resolution; otherwise nothing is written and the output is all zeros.
    while first_level or min(cur_H, cur_W) > max(min_resolution, 1):
        first_level = False

        # Only pixels that are still empty receive contributions from this level.
        mask = (count.squeeze(-1) == 0)
        if not mask.any():
            break

        cur_result, cur_count = linear_grid_put_2d(cur_H, cur_W, coords, values, return_count=True)
        up_result = F.interpolate(
            cur_result.permute(2, 0, 1).unsqueeze(0).contiguous(), (H, W), mode='bilinear', align_corners=False
        ).squeeze(0).permute(1, 2, 0).contiguous()  # [H, W, C]
        up_count = F.interpolate(
            cur_count.view(1, 1, cur_H, cur_W), (H, W), mode='bilinear', align_corners=False
        ).view(H, W, 1)  # [H, W, 1]

        result[mask] = result[mask] + up_result[mask]
        count[mask] = count[mask] + up_count[mask]

        cur_H //= 2
        cur_W //= 2

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result

def nearest_grid_put_3d(H, W, D, coords, values, return_count=False):
    """Splat each value onto its nearest voxel of an H x W x D grid and average collisions.

    Args:
        H, W, D: output grid size.
        coords: float [N, 3], expected in [-1, 1]. Columns 0/1/2 are scaled onto
            the H/W/D axes (-1 -> index 0, 1 -> last index). Coordinates are not clamped.
        values: [N, C].
        return_count: if True, skip normalization and also return the hit counts.

    Returns:
        [H, W, D, C] averaged grid (voxels that received nothing stay 0), or
        ``(sum [H, W, D, C], count [H, W, D, 1])`` when ``return_count`` is True.
    """
    num_channels = values.shape[-1]

    scale = torch.tensor([H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device)
    voxel = ((coords * 0.5 + 0.5) * scale).round().long()  # [N, 3]

    accum = torch.zeros(H, W, D, num_channels, device=values.device, dtype=values.dtype)  # [H, W, D, C]
    hits = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype)  # [H, W, D, 1]
    unit = torch.ones_like(values[..., :1])  # [N, 1]: every sample counts once

    accum, hits = scatter_add_nd_with_count(accum, hits, voxel, values, unit)

    if return_count:
        return accum, hits

    filled = hits.squeeze(-1) > 0
    accum[filled] = accum[filled] / hits[filled].repeat(1, num_channels)
    return accum


def linear_grid_put_3d(H, W, D, coords, values, return_count=False):
    """Trilinearly splat values onto an H x W x D grid and normalize by the splatted weight.

    Each sample is spread over the 2x2x2 voxels around its continuous position.
    The corner offset by ``(oh, ow, od)`` from the base voxel receives weight
    ``(h or 1-h) * (w or 1-w) * (d or 1-d)``, where the factors are chosen by
    the same offsets. Weights therefore always match the corner they are
    scattered to.

    Args:
        H, W, D: output grid size.
        coords: float [N, 3], expected in [-1, 1]. Columns 0/1/2 map onto the
            H/W/D axes. Coordinates are not clamped, so samples outside the
            grid get weights outside [0, 1].
        values: [N, C].
        return_count: if True, skip normalization and also return the weight map.

    Returns:
        [H, W, D, C] normalized grid (voxels with zero weight stay 0), or
        ``(weighted sum [H, W, D, C], weight [H, W, D, 1])`` when ``return_count``.
    """
    C = values.shape[-1]

    indices = (coords * 0.5 + 0.5) * torch.tensor(
        [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device
    )  # [N, 3] continuous position along (H, W, D)
    indices_000 = indices.floor().long()  # [N, 3] base corner of the 2x2x2 footprint
    # Keep the footprint inside the grid; samples on the last slice get frac 1.
    indices_000[:, 0].clamp_(0, H - 2)
    indices_000[:, 1].clamp_(0, W - 2)
    indices_000[:, 2].clamp_(0, D - 2)

    h = indices[..., 0] - indices_000[..., 0].float()
    w = indices[..., 1] - indices_000[..., 1].float()
    d = indices[..., 2] - indices_000[..., 2].float()

    result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype)  # [H, W, D, C]
    count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype)  # [H, W, D, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    # Build each corner's index and weight from the SAME (oh, ow, od) offset so
    # they cannot drift apart. The previous hand-written weights used (d, h, w)
    # bit order against (h, w, d) indices, misplacing 6 of the 8 corners.
    corner_offsets = (
        (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
        (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1),
    )
    for oh, ow, od in corner_offsets:
        corner = indices_000 + torch.tensor([oh, ow, od], dtype=torch.long, device=indices.device)
        corner_weight = (h if oh else 1 - h) * (w if ow else 1 - w) * (d if od else 1 - d)
        corner_weight = corner_weight.unsqueeze(1)  # [N, 1]
        result, count = scatter_add_nd_with_count(
            result, count, corner, values * corner_weight, weights * corner_weight
        )

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result

def mipmap_linear_grid_put_3d(H, W, D, coords, values, min_resolution=32, return_count=False):
    """Trilinearly splat values onto an H x W x D grid, filling holes from coarser levels.

    The full-resolution level is always splatted. Coarser levels then halve each
    side (integer division) while ``min(cur_H, cur_W, cur_D) > min_resolution``
    and there are still empty voxels. A level stops at size 2, since a
    size-1 level cannot hold a 2x2x2 footprint. Each level's weighted sums and
    weights are trilinearly upsampled to (H, W, D) with ``align_corners=False``
    and added only to voxels whose accumulated weight is still exactly zero when
    that level starts.

    Args:
        H, W, D: output grid size.
        coords: float [N, 3], expected in [-1, 1].
        values: [N, C].
        min_resolution: coarser levels are only used while all sides exceed this.
        return_count: if True, skip normalization and also return the weight map.

    Returns:
        [H, W, D, C] normalized grid (voxels with zero weight stay 0), or
        ``(weighted sum [H, W, D, C], weight [H, W, D, 1])`` when ``return_count``.
    """
    C = values.shape[-1]

    result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype)  # [H, W, D, C]
    count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype)  # [H, W, D, 1]

    cur_H, cur_W, cur_D = H, W, D
    first_level = True

    # Always splat at full resolution, even when the volume is already no
    # larger than min_resolution; otherwise nothing is written and the output is all zeros.
    while first_level or min(cur_H, cur_W, cur_D) > max(min_resolution, 1):
        first_level = False

        # Only voxels that are still empty receive contributions from this level.
        mask = (count.squeeze(-1) == 0)
        if not mask.any():
            break

        cur_result, cur_count = linear_grid_put_3d(cur_H, cur_W, cur_D, coords, values, return_count=True)
        up_result = F.interpolate(
            cur_result.permute(3, 0, 1, 2).unsqueeze(0).contiguous(), (H, W, D), mode='trilinear', align_corners=False
        ).squeeze(0).permute(1, 2, 3, 0).contiguous()  # [H, W, D, C]
        up_count = F.interpolate(
            cur_count.view(1, 1, cur_H, cur_W, cur_D), (H, W, D), mode='trilinear', align_corners=False
        ).view(H, W, D, 1)  # [H, W, D, 1]

        result[mask] = result[mask] + up_result[mask]
        count[mask] = count[mask] + up_count[mask]

        cur_H //= 2
        cur_W //= 2
        cur_D //= 2

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result


def grid_put(shape: Sequence[int], coords: Tensor, values: Tensor, mode: Literal['nearest', 'linear', 'linear-mipmap']='linear-mipmap', min_resolution: int=32, return_count: bool=False) -> Tensor:
    """Put values back onto an image/volume according to their coords; an (approximate) inverse of ``F.grid_sample``.

    Args:
        shape (Sequence[int]): output grid shape, either [H, W] (2D image) or
            [H, W, D] (3D volume).
        coords (Tensor): coordinates, float [N, len(shape)] in [-1, 1]; column i
            is mapped onto axis ``shape[i]`` (-1 -> first index, 1 -> last index).
        values (Tensor): values, float [N, C].
        mode (str, Literal['nearest', 'linear', 'linear-mipmap']): splatting mode,
            see https://github.com/ashawkey/grid_put for examples. Defaults to 'linear-mipmap'.
        min_resolution (int, optional): minimal resolution for mipmap levels
            (only used by 'linear-mipmap'). Defaults to 32.
        return_count (bool, optional): if True, return the un-normalized summed values
            and the summed weights as a tuple, instead of the divided result. Defaults to False.

    Returns:
        Tensor: the restored image/volume, float [H, W, C] / [H, W, D, C]
        (or a ``(sum, count)`` tuple when ``return_count`` is True).

    Raises:
        AssertionError: if ``len(shape)`` is not 2 or 3.
        NotImplementedError: if ``mode`` is not one of the supported modes.
    """
    D = len(shape)
    assert D in [2, 3], f'only support D == 2 or 3, but got D == {D}'

    if mode == 'nearest':
        if D == 2:
            return nearest_grid_put_2d(*shape, coords, values, return_count)
        return nearest_grid_put_3d(*shape, coords, values, return_count)
    elif mode == 'linear':
        if D == 2:
            return linear_grid_put_2d(*shape, coords, values, return_count)
        return linear_grid_put_3d(*shape, coords, values, return_count)
    elif mode == 'linear-mipmap':
        if D == 2:
            return mipmap_linear_grid_put_2d(*shape, coords, values, min_resolution, return_count)
        return mipmap_linear_grid_put_3d(*shape, coords, values, min_resolution, return_count)
    else:
        raise NotImplementedError(f"got mode {mode}")