Trouble correctly displaying images generated by EnlightenGAN


I downloaded the pretrained EnlightenGAN model from this repository: https://github.com/VITA-Group/EnlightenGAN. I'm trying to use the model in my Flask app to enhance uploaded images, but I can't figure out how to normalize the images correctly. Here are the outputs I'm getting:

[Image: the input image next to the two outputs and the desired output]

When the input image tensor is in the [0, 1] range I get output 1, a very gray, washed-out version of the input image (output tensor in the [0, 4.37] range). When I normalize the tensor to [-1, 1] before passing it through the generator I get output 2 (output tensor in the [-1.01, 5.086] range), which is only barely enhanced compared to the input. When I tested the same model with the predict.py script from the original repository, I got the desired output image, which is what I'm trying to achieve. So the problem seems to be in the way I preprocess/normalize/post-process the images. Here's my code:
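For comparison, here is a minimal sketch of the preprocessing convention the EnlightenGAN repository is often described as using (this is an assumption to verify against predict.py and the repository's dataset code, not a confirmed reading of it): RGB scaled to [-1, 1], plus a single-channel attention map computed as one minus the luminance of the [0, 1] image, rather than a plain ToTensor grayscale. The function name `preprocess` and the exact luminance weights are illustrative.

```python
import torch

def preprocess(rgb01):
    """rgb01: (3, H, W) float tensor in [0, 1]."""
    # Assumed training convention: each channel scaled to [-1, 1]
    img = rgb01 * 2.0 - 1.0
    # Assumed attention input: inverted luminance of the [0, 1] image,
    # so darker regions get larger attention values
    r, g, b = rgb01[0], rgb01[1], rgb01[2]
    gray = 1.0 - (0.299 * r + 0.587 * g + 0.114 * b)
    # Add a batch dimension (and a channel dimension for the gray map)
    return img.unsqueeze(0), gray.unsqueeze(0).unsqueeze(0)
```

If the model was trained with an attention map like this, feeding it a plain grayscale image (bright regions high) instead of the inverse would plausibly explain weak enhancement even when the [-1, 1] RGB normalization is correct.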

from collections import OrderedDict

import numpy as np
import torch
from torchvision import transforms

opt = TrainOpt()

# Initialize generator
netG = networks_eg.define_G(input_nc=3, output_nc=3, ngf=64,
                            which_model_netG='sid_unet_resize',
                            norm='instance', skip=True, opt=opt)

# Load the pretrained weights
state_dict = torch.load('model/200_net_G_A.pth', map_location='cpu')

# Remove the 'module.' prefix if present (weights saved from DataParallel)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if k.startswith('module.'):
        k = k.replace('module.', '', 1)
    new_state_dict[k] = v
netG.load_state_dict(new_state_dict)
netG.eval()


def tensor2im(image_tensor, imtype=np.uint8):
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    image_numpy = np.maximum(image_numpy, 0)
    image_numpy = np.minimum(image_numpy, 255)
    return image_numpy.astype(imtype)


def enhance_image(input_image):
    transform = transforms.ToTensor()
    inverse_transform = transforms.ToPILImage()

    # Convert input image to tensor
    input_tensor = transform(input_image).unsqueeze(0)
    # If the model was trained with images in [-1, 1], normalize:
    input_tensor = input_tensor * 2 - 1

    # Convert grayscale image to tensor
    gray_image = input_image.convert('L')
    gray_tensor = transform(gray_image).unsqueeze(0)

    # torch.no_grad() replaces the deprecated Variable(..., volatile=True)
    real_A = input_tensor
    real_A_gray = gray_tensor
    with torch.no_grad():
        fake_B, latent_real_A = netG(real_A, real_A_gray)

    print("Raw output min:", fake_B.min().item())
    print("Raw output max:", fake_B.max().item())
    return inverse_transform(tensor2im(fake_B.data))
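One thing worth noticing about tensor2im, independent of the model: it assumes the generator output lives in [-1, 1], linearly maps that range to [0, 255], and clips everything outside it. A tiny standalone sketch of that mapping (the helper name `tensor2im_like` is just for illustration) shows why an output tensor in [0, 4.37] looks washed out: every pixel lands at or above mid-gray, so nothing in the image can be darker than that.

```python
import numpy as np

def tensor2im_like(arr):
    # Same mapping tensor2im applies: [-1, 1] -> [0, 255], then clip
    out = (arr + 1) / 2.0 * 255.0
    return np.clip(out, 0, 255).astype(np.uint8)

# A value of 0.0 (the bottom of a [0, 4.37] output) maps to mid-gray,
# and anything above 1.0 saturates at 255.
print(tensor2im_like(np.array([-1.0, 0.0, 1.0, 4.37])))  # [  0 127 255 255]
```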

The tensor2im function is from the original repository. What am I doing wrong here? How do I normalize the image?
