import itertools

import torch
import torch.nn.functional as F

try:
    from kiui.typing import *
except ImportError:
    # NOTE(review): fallback so this module stays importable when
    # kiui.typing (or one of its own dependencies) is unavailable. It only
    # supplies the names used by `grid_put`'s annotations.
    from typing import Literal, Sequence

    from torch import Tensor


def stride_from_shape(shape):
    """Return the row-major (C-order) element strides of a grid of ``shape``.

    Example: ``stride_from_shape((2, 3, 4)) == [12, 4, 1]``.
    """
    stride = [1]
    for x in reversed(shape[1:]):
        stride.append(stride[-1] * x)
    return list(reversed(stride))


def _flatten_indices(indices, size):
    """Turn integer grid indices [N, D] into flat row-major offsets [N] for a grid of ``size``."""
    stride = torch.tensor(stride_from_shape(size), dtype=torch.long, device=indices.device)
    return (indices * stride).sum(-1)


def scatter_add_nd(input, indices, values):
    """Accumulate ``values`` into the grid cells of ``input`` addressed by ``indices``.

    Args:
        input: [..., C] tensor, D grid dimensions followed by C channels. It is
            modified in place (through ``view``), so it must be viewable as [-1, C].
        indices: [N, D] long tensor of grid indices. Components are not
            range-checked per axis; an out-of-range component can alias a
            different cell after flattening.
        values: [N, C] values added to the addressed cells.

    Returns:
        The updated grid (sharing storage with ``input``) with shape [..., C].
    """
    D = indices.shape[-1]
    C = input.shape[-1]
    size = input.shape[:-1]
    assert len(size) == D

    flat = input.view(-1, C)  # [prod(size), C]
    flat_indices = _flatten_indices(indices, size)  # [N]
    flat.scatter_add_(0, flat_indices.unsqueeze(1).repeat(1, C), values)
    return flat.view(*size, C)


def scatter_add_nd_with_count(input, count, indices, values, weights=None):
    """Like :func:`scatter_add_nd`, but also accumulate per-cell weights into ``count``.

    Args:
        input: [..., C] grid, modified in place.
        count: [..., 1] weight accumulator with the same grid dimensions, modified in place.
        indices: [N, D] long grid indices (not range-checked per axis).
        values: [N, C] values added to ``input``.
        weights: optional [N, 1] weights added to ``count``; defaults to ones.

    Returns:
        ``(input, count)`` reshaped back to [..., C] and [..., 1].
    """
    D = indices.shape[-1]
    C = input.shape[-1]
    size = input.shape[:-1]
    assert len(size) == D

    flat = input.view(-1, C)  # [prod(size), C]
    flat_count = count.view(-1, 1)  # [prod(size), 1]
    flat_indices = _flatten_indices(indices, size)  # [N]
    if weights is None:
        weights = torch.ones_like(values[..., :1])

    flat.scatter_add_(0, flat_indices.unsqueeze(1).repeat(1, C), values)
    flat_count.scatter_add_(0, flat_indices.unsqueeze(1), weights)
    return flat.view(*size, C), flat_count.view(*size, 1)


def _coords_to_indices(coords, size):
    """Map coords in [-1, 1] to continuous indices: -1 -> 0 and 1 -> size - 1 on each axis."""
    scale = torch.tensor([s - 1 for s in size], dtype=torch.float32, device=coords.device)
    return (coords * 0.5 + 0.5) * scale


def _divide_by_count(result, count):
    """Normalize accumulated ``result`` [..., C] in place by ``count`` [..., 1] where count > 0."""
    mask = count.squeeze(-1) > 0
    result[mask] = result[mask] / count[mask].repeat(1, result.shape[-1])
    return result


def _nearest_grid_put_nd(size, coords, values, return_count):
    """Nearest-neighbour splatting shared by the 2D and 3D public functions.

    Each point is rounded to its nearest grid cell; cells hit by several points
    receive the average of their values (or the raw sums if ``return_count``).
    Coords outside [-1, 1] are not clamped.
    """
    C = values.shape[-1]
    indices = _coords_to_indices(coords, size).round().long()  # [N, D]

    result = torch.zeros(*size, C, device=values.device, dtype=values.dtype)
    count = torch.zeros(*size, 1, device=values.device, dtype=values.dtype)
    weights = torch.ones_like(values[..., :1])  # [N, 1]
    result, count = scatter_add_nd_with_count(result, count, indices, values, weights)

    if return_count:
        return result, count
    return _divide_by_count(result, count)


def _linear_grid_put_nd(size, coords, values, return_count):
    """Bi-/trilinear splatting shared by the 2D and 3D public functions.

    Each point contributes to the 2**D grid nodes surrounding it. The weight of a
    node is the product over axes of ``frac`` (upper neighbour on that axis) or
    ``1 - frac`` (lower neighbour), so offset bit ``axis`` always pairs with
    ``frac[:, axis]``. Every axis of ``size`` must be at least 2.
    """
    C = values.shape[-1]
    indices = _coords_to_indices(coords, size)  # [N, D], continuous
    lower = indices.floor().long()  # [N, D]
    # Keep lower + 1 inside the grid: a point exactly on the upper border then
    # gets frac == 1 and lands entirely on the last node.
    for axis, extent in enumerate(size):
        lower[:, axis].clamp_(0, extent - 2)
    frac = indices - lower.float()  # [N, D]

    result = torch.zeros(*size, C, device=values.device, dtype=values.dtype)
    count = torch.zeros(*size, 1, device=values.device, dtype=values.dtype)
    ones = torch.ones_like(values[..., :1])  # [N, 1]

    # Offsets are visited in lexicographic order: (0, 0), (0, 1), (1, 0), (1, 1), ...
    for offset in itertools.product((0, 1), repeat=len(size)):
        corner = lower + torch.tensor(offset, dtype=torch.long, device=indices.device)
        weight = None
        for axis, bit in enumerate(offset):
            factor = frac[:, axis] if bit else 1 - frac[:, axis]
            weight = factor if weight is None else weight * factor
        weight = weight.unsqueeze(1)  # [N, 1]
        result, count = scatter_add_nd_with_count(result, count, corner, values * weight, ones * weight)

    if return_count:
        return result, count
    return _divide_by_count(result, count)


def _mipmap_linear_grid_put_nd(size, coords, values, min_resolution, return_count):
    """Hole-filling linear splatting shared by the 2D and 3D public functions.

    The points are linearly splatted at the full resolution and then at
    successively halved resolutions. Each level is upsampled back to ``size``
    (bilinear / trilinear, ``align_corners=False``) and added only to cells whose
    accumulated count is still exactly zero. Halving stops once the smallest
    dimension is <= ``min_resolution`` or no empty cell remains.
    """
    size = tuple(size)
    C = values.shape[-1]
    interp_mode = 'bilinear' if len(size) == 2 else 'trilinear'

    result = torch.zeros(*size, C, device=values.device, dtype=values.dtype)
    count = torch.zeros(*size, 1, device=values.device, dtype=values.dtype)

    cur_size = size
    # The full-resolution pass always runs, so grids no larger than
    # min_resolution are still filled instead of coming back all zeros.
    while True:
        mask = count.squeeze(-1) == 0  # cells still empty
        if not mask.any():
            break

        cur_result, cur_count = _linear_grid_put_nd(cur_size, coords, values, return_count=True)
        up_result = F.interpolate(
            cur_result.movedim(-1, 0).unsqueeze(0).contiguous(), size,
            mode=interp_mode, align_corners=False,
        ).squeeze(0).movedim(0, -1).contiguous()
        up_count = F.interpolate(
            cur_count.view(1, 1, *cur_size), size,
            mode=interp_mode, align_corners=False,
        ).view(*size, 1)
        result[mask] = result[mask] + up_result[mask]
        count[mask] = count[mask] + up_count[mask]

        cur_size = tuple(s // 2 for s in cur_size)
        if min(cur_size) <= min_resolution:
            break

    if return_count:
        return result, count
    return _divide_by_count(result, count)


def nearest_grid_put_2d(H, W, coords, values, return_count=False):
    """Splat ``values`` [N, C] at ``coords`` [N, 2] (in [-1, 1]) onto an [H, W, C] image, nearest mode.

    Returns the averaged image, or ``(sum [H, W, C], count [H, W, 1])`` if ``return_count``.
    """
    return _nearest_grid_put_nd((H, W), coords, values, return_count)


def linear_grid_put_2d(H, W, coords, values, return_count=False):
    """Splat ``values`` [N, C] at ``coords`` [N, 2] (in [-1, 1]) onto an [H, W, C] image, bilinear mode.

    Returns the weight-normalized image, or ``(sum [H, W, C], count [H, W, 1])`` if ``return_count``.
    """
    return _linear_grid_put_nd((H, W), coords, values, return_count)


def mipmap_linear_grid_put_2d(H, W, coords, values, min_resolution=32, return_count=False):
    """Bilinear splatting onto [H, W, C] with coarser levels filling empty pixels.

    ``coords`` [N, 2] in [-1, 1], ``values`` [N, C]. Returns the normalized image,
    or ``(sum [H, W, C], count [H, W, 1])`` if ``return_count``.
    """
    return _mipmap_linear_grid_put_nd((H, W), coords, values, min_resolution, return_count)


def nearest_grid_put_3d(H, W, D, coords, values, return_count=False):
    """Splat ``values`` [N, C] at ``coords`` [N, 3] (in [-1, 1]) onto an [H, W, D, C] volume, nearest mode.

    Returns the averaged volume, or ``(sum [H, W, D, C], count [H, W, D, 1])`` if ``return_count``.
    """
    return _nearest_grid_put_nd((H, W, D), coords, values, return_count)


def linear_grid_put_3d(H, W, D, coords, values, return_count=False):
    """Splat ``values`` [N, C] at ``coords`` [N, 3] (in [-1, 1]) onto an [H, W, D, C] volume, trilinear mode.

    Returns the weight-normalized volume, or ``(sum [H, W, D, C], count [H, W, D, 1])`` if ``return_count``.
    """
    return _linear_grid_put_nd((H, W, D), coords, values, return_count)


def mipmap_linear_grid_put_3d(H, W, D, coords, values, min_resolution=32, return_count=False):
    """Trilinear splatting onto [H, W, D, C] with coarser levels filling empty voxels.

    ``coords`` [N, 3] in [-1, 1], ``values`` [N, C]. Returns the normalized volume,
    or ``(sum [H, W, D, C], count [H, W, D, 1])`` if ``return_count``.
    """
    return _mipmap_linear_grid_put_nd((H, W, D), coords, values, min_resolution, return_count)
def grid_put(shape: Sequence[int], coords: Tensor, values: Tensor, mode: Literal['nearest', 'linear', 'linear-mipmap'] = 'linear-mipmap', min_resolution: int = 32, return_count: bool = False) -> Tensor:
    """Scatter point values back onto a regular grid (the inverse of ``F.grid_sample``).

    Args:
        shape (Sequence[int]): grid size, either (H, W) for an image or (H, W, D) for a volume.
        coords (Tensor): point coordinates, float [N, len(shape)] in [-1, 1].
        values (Tensor): point values, float [N, C].
        mode (Literal['nearest', 'linear', 'linear-mipmap']): splatting scheme, see
            https://github.com/ashawkey/grid_put for examples. Defaults to 'linear-mipmap'.
        min_resolution (int, optional): smallest mipmap level size, only used by
            'linear-mipmap'. Defaults to 32.
        return_count (bool, optional): if True, return the un-normalized sums together
            with the accumulated weights as a ``(result, count)`` tuple instead of the
            divided result. Defaults to False.

    Returns:
        Tensor: the grid, float [H, W, C] or [H, W, D, C] (or the ``(result, count)``
        tuple when ``return_count`` is set).

    Raises:
        AssertionError: if ``shape`` is neither 2D nor 3D.
        NotImplementedError: if ``mode`` is not one of the supported modes.
    """
    ndim = len(shape)
    assert ndim in [2, 3], f'only support D == 2 or 3, but got D == {ndim}'
    is_image = ndim == 2

    if mode == 'nearest':
        splat = nearest_grid_put_2d if is_image else nearest_grid_put_3d
        return splat(*shape, coords, values, return_count)

    if mode == 'linear':
        splat = linear_grid_put_2d if is_image else linear_grid_put_3d
        return splat(*shape, coords, values, return_count)

    if mode == 'linear-mipmap':
        splat = mipmap_linear_grid_put_2d if is_image else mipmap_linear_grid_put_3d
        return splat(*shape, coords, values, min_resolution, return_count)

    raise NotImplementedError(f"got mode {mode}")