import os
import cv2
import math
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.nn import functional as F
from torch.nn.modules.batchnorm import _BatchNorm
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from kiui.typing import *

HF_MODELS = {
    2: dict(
        repo_id='ai-forever/Real-ESRGAN',
        filename='RealESRGAN_x2.pth',
    ),
    4: dict(
        repo_id='ai-forever/Real-ESRGAN',
        filename='RealESRGAN_x4.pth',
    ),
    8: dict(
        repo_id='ai-forever/Real-ESRGAN',
        filename='RealESRGAN_x8.pth',
    ),
}


@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0.
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.Module): nn.Module class for the basic block.
        num_basic_block (int): Number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features. Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
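
# A minimal usage sketch (not part of the original module): each
# Conv+PixelShuffle stage in `Upsample` trades 4x channels for 2x spatial
# resolution, so `Upsample(4, num_feat)` maps (N, C, H, W) -> (N, C, 4H, 4W).
# The `_demo_*` name is hypothetical and only defined for illustration;
# nothing in this module calls it.
def _demo_upsample_shapes():
    up = Upsample(scale=4, num_feat=64)  # two Conv -> PixelShuffle(2) stages
    x = torch.randn(1, 64, 8, 8)
    y = up(x)
    assert y.shape == (1, 64, 32, 32)  # spatial size grows by the scale factor
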
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use True as the default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1, 1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)

    # TODO, what if align_corners=False
    return output


def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
                downsampling, the ratio should be smaller than 1.0 (i.e.,
                ratio < 1.0). For upsampling, the ratio should be larger than
                1.0 (i.e., ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether to align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
    return resized_flow


def pixel_unshuffle(x, scale):
    """Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)


def pad_reflect(image, pad_size):
    imsize = image.shape
    height, width = imsize[:2]
    new_img = np.zeros([height + pad_size * 2, width + pad_size * 2, imsize[2]]).astype(np.uint8)
    new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image
    new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0)  # top
    new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)  # bottom
    new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size * 2, :], axis=1)  # left
    new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size * 2:-pad_size, :], axis=1)  # right
    return new_img


def unpad_image(image, pad_size):
    return image[pad_size:-pad_size, pad_size:-pad_size, :]


def process_array(image_array, expand=True):
    """Process a 3-dimensional array into a scaled, 4-dimensional batch of size 1."""
    image_batch = image_array / 255.0
    if expand:
        image_batch = np.expand_dims(image_batch, axis=0)
    return image_batch


def process_output(output_tensor):
    """Transform the 4-dimensional output tensor into a suitable image format."""
    sr_img = output_tensor.clip(0, 1) * 255
    sr_img = np.uint8(sr_img)
    return sr_img
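
# A minimal sanity check (not part of the original module): `pixel_unshuffle`
# above folds each scale x scale spatial block into scale^2 channels and is
# the exact inverse of F.pixel_shuffle (pure reindexing, no interpolation).
# The `_demo_*` name is hypothetical and only defined for illustration.
def _demo_pixel_unshuffle_roundtrip():
    x = torch.randn(2, 4, 8, 8)
    y = F.pixel_shuffle(x, 2)            # (2, 1, 16, 16): channels -> space
    x_rec = pixel_unshuffle(y, scale=2)  # (2, 4, 8, 8): space -> channels
    assert torch.equal(x_rec, x)         # exact round trip
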
"""ifchannel_last:returnnp.pad(image_patch,((padding_size,padding_size),(padding_size,padding_size),(0,0)),'edge',)else:returnnp.pad(image_patch,((0,0),(padding_size,padding_size),(padding_size,padding_size)),'edge',)defunpad_patches(image_patches,padding_size):returnimage_patches[:,padding_size:-padding_size,padding_size:-padding_size,:]defsplit_image_into_overlapping_patches(image_array,patch_size,padding_size=2):""" Splits the image into partially overlapping patches. The patches overlap by padding_size pixels. Pads the image twice: - first to have a size multiple of the patch size, - then to have equal padding at the borders. Args: image_array: numpy array of the input image. patch_size: size of the patches from the original image (without padding). padding_size: size of the overlapping area. """xmax,ymax,_=image_array.shapex_remainder=xmax%patch_sizey_remainder=ymax%patch_size# modulo here is to avoid extending of patch_size instead of 0x_extend=(patch_size-x_remainder)%patch_sizey_extend=(patch_size-y_remainder)%patch_size# make sure the image is divisible into regular patchesextended_image=np.pad(image_array,((0,x_extend),(0,y_extend),(0,0)),'edge')# add padding around the image to simplify computationspadded_image=pad_patch(extended_image,padding_size,channel_last=True)xmax,ymax,_=padded_image.shapepatches=[]x_lefts=range(padding_size,xmax-padding_size,patch_size)y_tops=range(padding_size,ymax-padding_size,patch_size)forxinx_lefts:foryiny_tops:x_left=x-padding_sizey_top=y-padding_sizex_right=x+patch_size+padding_sizey_bottom=y+patch_size+padding_sizepatch=padded_image[x_left:x_right,y_top:y_bottom,:]patches.append(patch)returnnp.array(patches),padded_image.shapedefstich_together(patches,padded_image_shape,target_shape,padding_size=4):""" Reconstruct the image from overlapping patches. After scaling, shapes and padding should be scaled too. Args: patches: patches obtained with split_image_into_overlapping_patches padded_image_shape: shape of the padded image contructed in split_image_into_overlapping_patches target_shape: shape of the final image padding_size: size of the overlapping area. """xmax,ymax,_=padded_image_shapepatches=unpad_patches(patches,padding_size)patch_size=patches.shape[1]n_patches_per_row=ymax//patch_sizecomplete_image=np.zeros((xmax,ymax,3))row=-1col=0foriinrange(len(patches)):ifi%n_patches_per_row==0:row+=1col=0complete_image[row*patch_size:(row+1)*patch_size,col*patch_size:(col+1)*patch_size,:]=patches[i]col+=1returncomplete_image[0:target_shape[0],0:target_shape[1],:]classResidualDenseBlock(nn.Module):"""Residual Dense Block. Used in RRDB block in ESRGAN. Args: num_feat (int): Channel number of intermediate features. num_grow_ch (int): Channels for each growth. 
"""def__init__(self,num_feat=64,num_grow_ch=32):super(ResidualDenseBlock,self).__init__()self.conv1=nn.Conv2d(num_feat,num_grow_ch,3,1,1)self.conv2=nn.Conv2d(num_feat+num_grow_ch,num_grow_ch,3,1,1)self.conv3=nn.Conv2d(num_feat+2*num_grow_ch,num_grow_ch,3,1,1)self.conv4=nn.Conv2d(num_feat+3*num_grow_ch,num_grow_ch,3,1,1)self.conv5=nn.Conv2d(num_feat+4*num_grow_ch,num_feat,3,1,1)self.lrelu=nn.LeakyReLU(negative_slope=0.2,inplace=True)# initializationdefault_init_weights([self.conv1,self.conv2,self.conv3,self.conv4,self.conv5],0.1)defforward(self,x):x1=self.lrelu(self.conv1(x))x2=self.lrelu(self.conv2(torch.cat((x,x1),1)))x3=self.lrelu(self.conv3(torch.cat((x,x1,x2),1)))x4=self.lrelu(self.conv4(torch.cat((x,x1,x2,x3),1)))x5=self.conv5(torch.cat((x,x1,x2,x3,x4),1))# Emperically, we use 0.2 to scale the residual for better performancereturnx5*0.2+xclassRRDB(nn.Module):"""Residual in Residual Dense Block. Used in RRDB-Net in ESRGAN. Args: num_feat (int): Channel number of intermediate features. num_grow_ch (int): Channels for each growth. """def__init__(self,num_feat,num_grow_ch=32):super(RRDB,self).__init__()self.rdb1=ResidualDenseBlock(num_feat,num_grow_ch)self.rdb2=ResidualDenseBlock(num_feat,num_grow_ch)self.rdb3=ResidualDenseBlock(num_feat,num_grow_ch)defforward(self,x):out=self.rdb1(x)out=self.rdb2(out)out=self.rdb3(out)# Emperically, we use 0.2 to scale the residual for better performancereturnout*0.2+xclassRRDBNet(nn.Module):"""Networks consisting of Residual in Residual Dense Block, which is used in ESRGAN. ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. We extend ESRGAN for scale x2 and scale x1. Note: This is one option for scale 1, scale 2 in RRDBNet. We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size and enlarge the channel size before feeding inputs into the main ESRGAN architecture. Args: num_in_ch (int): Channel number of inputs. num_out_ch (int): Channel number of outputs. num_feat (int): Channel number of intermediate features. Default: 64 num_block (int): Block number in the trunk network. Defaults: 23 num_grow_ch (int): Channels for each growth. Default: 32. 
"""def__init__(self,num_in_ch,num_out_ch,scale=4,num_feat=64,num_block=23,num_grow_ch=32):super(RRDBNet,self).__init__()self.scale=scaleifscale==2:num_in_ch=num_in_ch*4elifscale==1:num_in_ch=num_in_ch*16self.conv_first=nn.Conv2d(num_in_ch,num_feat,3,1,1)self.body=make_layer(RRDB,num_block,num_feat=num_feat,num_grow_ch=num_grow_ch)self.conv_body=nn.Conv2d(num_feat,num_feat,3,1,1)# upsampleself.conv_up1=nn.Conv2d(num_feat,num_feat,3,1,1)self.conv_up2=nn.Conv2d(num_feat,num_feat,3,1,1)ifscale==8:self.conv_up3=nn.Conv2d(num_feat,num_feat,3,1,1)self.conv_hr=nn.Conv2d(num_feat,num_feat,3,1,1)self.conv_last=nn.Conv2d(num_feat,num_out_ch,3,1,1)self.lrelu=nn.LeakyReLU(negative_slope=0.2,inplace=True)defforward(self,x):ifself.scale==2:feat=pixel_unshuffle(x,scale=2)elifself.scale==1:feat=pixel_unshuffle(x,scale=4)else:feat=xfeat=self.conv_first(feat)body_feat=self.conv_body(self.body(feat))feat=feat+body_feat# upsamplefeat=self.lrelu(self.conv_up1(F.interpolate(feat,scale_factor=2,mode='nearest')))feat=self.lrelu(self.conv_up2(F.interpolate(feat,scale_factor=2,mode='nearest')))ifself.scale==8:feat=self.lrelu(self.conv_up3(F.interpolate(feat,scale_factor=2,mode='nearest')))out=self.conv_last(self.lrelu(self.conv_hr(feat)))returnoutclassRealESRGAN:def__init__(self,device,scale=4):print(f'[INFO] init RealESRGAN_{scale}x: {device}')self.device=deviceself.scale=scaleself.model=RRDBNet(num_in_ch=3,num_out_ch=3,num_feat=64,num_block=23,num_grow_ch=32,scale=scale)self.load_weights()defload_weights(self):model_path=hf_hub_download(repo_id=HF_MODELS[self.scale]['repo_id'],filename=HF_MODELS[self.scale]['filename'])checkpoint=torch.load(model_path)if'params'incheckpoint:self.model.load_state_dict(checkpoint['params'],strict=True)elif'params_ema'incheckpoint:self.model.load_state_dict(checkpoint['params_ema'],strict=True)else:self.model.load_state_dict(checkpoint,strict=True)self.model.eval()self.model.to(self.device)@torch.cuda.amp.autocast()defpredict(self,lr_image,batch_size=4,patches_size=192,padding=24,pad_size=15):# lr_image: np.ndarray, [h, w, 3], RGB uint8# return: np.ndarray, [H, W, 3], RGB uint8return_tensor=Falseiftorch.is_tensor(lr_image):# or Tensor, [1, 3, H, W], RGB float32lr_image=(lr_image.detach().permute(0,2,3,1)[0].cpu().numpy()*255).astype(np.uint8)return_tensor=Truelr_image=pad_reflect(lr_image,pad_size)patches,p_shape=split_image_into_overlapping_patches(lr_image,patch_size=patches_size,padding_size=padding)img=torch.from_numpy(patches.astype(np.float32)/255).permute((0,3,1,2)).to(self.device).detach()withtorch.no_grad():res=self.model(img[0:batch_size])foriinrange(batch_size,img.shape[0],batch_size):res=torch.cat((res,self.model(img[i:i+batch_size])),0)sr_image=res.permute((0,2,3,1)).clamp_(0,1).cpu()np_sr_image=sr_image.numpy()padded_size_scaled=tuple(np.multiply(p_shape[0:2],self.scale))+(3,)scaled_image_shape=tuple(np.multiply(lr_image.shape[0:2],self.scale))+(3,)np_sr_image=stich_together(np_sr_image,padded_image_shape=padded_size_scaled,target_shape=scaled_image_shape,padding_size=padding*self.scale)sr_img=(np_sr_image*255).astype(np.uint8)sr_img=unpad_image(sr_img,pad_size*self.scale)ifreturn_tensor:sr_img=torch.from_numpy(sr_img.astype(np.float32)/255).permute((2,0,1)).unsqueeze(0).to(self.device)returnsr_imgMODELS={}
def sr(image: ndarray, scale: Literal[2, 4, 8] = 2, device=None):
    """Lazily-loaded functional super-resolution API for convenience.

    Args:
        image (ndarray): input image, uint8/float32 [H, W, 3]
        scale (Literal[2, 4, 8], optional): upscale factor. Defaults to 2.
        device (torch.device, optional): device to put SR models. If not
            provided, will try to use 'cuda'. Defaults to None.

    Returns:
        ndarray: super-resolved image, uint8/float32 [H * scale, W * scale, 3]
    """
    global MODELS
    if scale not in MODELS:
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        MODELS[scale] = RealESRGAN(device, scale=scale)

    return_float = False
    if image.dtype == np.float32:
        return_float = True
        image = (image * 255).astype(np.uint8)

    sr_image = MODELS[scale].predict(image)

    if return_float:
        sr_image = sr_image.astype(np.float32) / 255.0

    return sr_image
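
# A minimal smoke test for the functional API (not part of the original
# module): running this file directly super-resolves a random image. The
# first call per scale downloads weights from the Hugging Face Hub.
if __name__ == '__main__':
    img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    out = sr(img, scale=2)
    print(out.shape, out.dtype)  # expected: (128, 128, 3) uint8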