def test_build_one_hot(fake_geo_data):
    """Check that ``create_one_hot`` respects the divergence threshold.

    The divergence field is replaced with large-amplitude Gaussian noise so
    that all three bands (above ``+thresh``, between, below ``-thresh``) are
    populated, then every pixel flagged in a channel is required to lie in
    the matching band of the divergence field.

    The checks are deliberately one-directional (flagged pixels must be in
    the band, not the converse), which is what the original membership-style
    assertions were trying to express.
    """
    dataroot, DIV_datas, Vx_datas, Vy_datas = fake_geo_data

    p = GeoPickler(dataroot)
    p.collect_all()
    p.group_by_series()

    data_dict = p.get_data_dict(0, 0)

    thresh = 1000
    DIV = np.random.randn(*data_dict['A_DIV'].shape) * 20000
    data_dict['A_DIV'] = DIV

    p.create_one_hot(data_dict, thresh)
    one_hot = data_dict['A']

    # Boolean masks of the pixels flagged in each channel.
    above = np.asarray(one_hot[:, :, 0], dtype=bool)
    between = np.asarray(one_hot[:, :, 1], dtype=bool)
    below = np.asarray(one_hot[:, :, 2], dtype=bool)

    # Asserting on a list comprehension is always true for a non-empty list,
    # so compare the selected divergence values directly instead.
    assert np.all(DIV[above] > thresh)
    assert np.all(np.abs(DIV[between]) <= thresh)
    assert np.all(DIV[below] < -thresh)
def test_resized(fake_geo_data):
    """With ``row_height=18`` every array in the data dict must be 18 x 36."""
    dataroot, _div_datas, _vx_datas, _vy_datas = fake_geo_data

    pickler = GeoPickler(dataroot, row_height=18)
    pickler.collect_all()
    pickler.group_by_series()

    sample = pickler.get_data_dict(0, 0)
    pickler.create_one_hot(sample, 1000)

    # (key, expected shape) pairs, checked in the same order as before.
    expected_shapes = (
        ('A_DIV', (18, 36)),
        ('A_Vx', (18, 36)),
        ('A_Vy', (18, 36)),
        ('A', (18, 36, 3)),
    )
    for key, shape in expected_shapes:
        assert sample[key].shape == shape
def test_mask_params_stored_in_dict(fake_geo_data):
    """``get_mask_loc`` must record its two parameters in the data dict."""
    dataroot, _div_datas, _vx_datas, _vy_datas = fake_geo_data

    pickler = GeoPickler(dataroot)
    pickler.collect_all()
    pickler.group_by_series()

    sample = pickler.get_data_dict(0, 0)

    # Swap in large-amplitude noise as the divergence field before
    # thresholding it at 1000.
    noise_shape = sample['A_DIV'].shape
    sample['A_DIV'] = np.random.randn(*noise_shape) * 20000

    pickler.create_one_hot(sample, 1000)
    pickler.get_mask_loc(sample, 4, 6)

    assert sample['mask_size'] == 4
    assert sample['min_pix_in_mask'] == 6
def test_mask_location(fake_geo_data):
    """Validate every candidate mask position against ``mask_locs``.

    A 4 x 4 window is slid over the one-hot image. Positions listed in
    ``mask_locs`` must contain at least 6 set pixels in both channel 0 and
    channel 2; every other position must fall short in at least one of them.
    """
    dataroot, _div_datas, _vx_datas, _vy_datas = fake_geo_data

    pickler = GeoPickler(dataroot)
    pickler.collect_all()
    pickler.group_by_series()

    sample = pickler.get_data_dict(0, 0)
    sample['A_DIV'] = np.random.randn(*sample['A_DIV'].shape) * 20000

    pickler.create_one_hot(sample, 1000)
    pickler.get_mask_loc(sample, 4, 6)

    one_hot = sample['A']
    mask_locs = sample['mask_locs']

    assert len(mask_locs) > 0

    n_rows, n_cols = one_hot.shape[0], one_hot.shape[1]
    for x in range(n_cols - 4):
        for y in range(n_rows - 4):
            window = one_hot[y:y + 4, x:x + 4]
            count_ch0 = np.sum(window[:, :, 0])
            count_ch2 = np.sum(window[:, :, 2])

            if (y, x) in mask_locs:
                assert count_ch0 >= 6
                assert count_ch2 >= 6
            else:
                assert count_ch0 < 6 or count_ch2 < 6
class DivInlineModel(BaseModel):
    """Generator/critic model that in-paints a masked region of a divergence map.

    The generator receives a one-hot (thresholded) divergence image with a
    chunk missing, optionally concatenated with the mask and a continent
    channel, and predicts the complete image. Depending on ``opt.int_vars``
    the generator output is treated either as a continuous divergence map
    (which is then thresholded into a one-hot image) or directly as the
    discrete image.
    """

    def name(self):
        """Return the model's display name."""
        return 'DivInlineModel'

    def initialize(self, opt):
        """Build networks, losses, optimizers and schedulers from ``opt``.

        Raises:
            ValueError: if training and ``opt.optim_type`` is neither
                ``'adam'`` nor ``'rmsprop'``.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        self.D_has_run = False

        # Dummy pickler that we use to create discrete images from divergence
        self.p = GeoPickler('')

        # load/define networks
        # Input channels = 3 channels for input one-hot map + mask (optional)
        # + continents (optional)
        input_channels = opt.input_nc
        if self.opt.mask_to_G:
            input_channels += 1
        if self.opt.continent_data:
            input_channels += 1

        # One extra output channel carries the foreground prediction
        G_output_channels = opt.output_nc
        if opt.with_BCE:
            G_output_channels += 1

        self.netG = networks.define_G(input_channels, G_output_channels,
                                      opt.ngf, opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type,
                                      self.gpu_ids)

        # Filters to create gradient images, if we are using gradient loss
        if self.opt.grad_loss:
            self.sobel_layer_y = torch.nn.Conv2d(1, 1, 3, padding=1)
            self.sobel_layer_y.weight.data = torch.FloatTensor(
                [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).unsqueeze(0).unsqueeze(0)
            self.sobel_layer_y.weight.requires_grad = False

            self.sobel_layer_x = torch.nn.Conv2d(1, 1, 3, padding=1)
            self.sobel_layer_x.weight.data = torch.FloatTensor(
                [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).unsqueeze(0).unsqueeze(0)
            self.sobel_layer_x.weight.requires_grad = False

            if len(self.gpu_ids) > 0:
                self.sobel_layer_y.cuda(self.gpu_ids[0])
                self.sobel_layer_x.cuda(self.gpu_ids[0])

        if self.isTrain:
            use_sigmoid = opt.no_lsgan

            # Inputs: 1 channel of divergence output data + mask (optional)
            discrim_input_channels = opt.output_nc
            if not opt.no_mask_to_critic:
                discrim_input_channels += 1

            # If we are only looking at the missing region
            if opt.local_critic:
                self.critic_im_size = (64, 64)
            else:
                self.critic_im_size = (256, opt.x_size)

            if self.opt.continent_data:
                discrim_input_channels += 1

            # Create discriminator
            self.netD = networks.define_D(discrim_input_channels, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm,
                                          use_sigmoid, opt.init_type,
                                          self.gpu_ids,
                                          critic_im_size=self.critic_im_size)

            if len(self.gpu_ids) > 0:
                self.netD.cuda()

        if not self.isTrain or opt.continue_train or opt.restart_G:
            self.load_network(self.netG, 'G', opt.which_epoch)

            # Only restore the critic when resuming, not when restarting
            # from a saved generator
            if self.isTrain and not opt.restart_G:
                self.load_network(self.netD, 'D', opt.which_epoch)

        # Spatial size of the region the reconstruction losses are taken over
        if self.opt.local_loss:
            self.im_dims = self.opt.mask_size, self.opt.mask_size
        else:
            self.im_dims = (256, self.opt.x_size)

        if self.isTrain:
            # define loss functions
            if self.opt.int_vars:
                self.criterionR = torch.nn.MSELoss(
                    size_average=True,
                    reduce=(not self.opt.weighted_reconstruction))
            else:
                ce_fun = torch.nn.CrossEntropyLoss(
                    size_average=True,
                    reduce=(not self.opt.weighted_reconstruction))

                # Flatten the spatial dims and turn the one-hot target into
                # class indices before calling cross-entropy
                self.criterionR = lambda test, target: ce_fun(
                    test.view(self.opt.batchSize, self.opt.output_nc, -1),
                    target.max(dim=1)[1].view(self.opt.batchSize, -1).long())

            self.criterionBCE = torch.nn.BCELoss(
                size_average=True, reduce=(not self.opt.weighted_CE))

            # Choose post-processing function for reconstruction losses
            if self.opt.log_L2:
                self.processL2 = torch.log
            else:
                self.processL2 = identity

            if self.opt.log_BCE:
                self.processBCE = torch.log
            else:
                self.processBCE = identity

            # Choose post-processing function for discriminator output
            if self.opt.use_hinge:
                self.criterionGAN = hinge_criterionGAN
            elif self.opt.which_model_netD == 'wgan-gp' or self.opt.which_model_netD == 'self-attn':
                self.criterionGAN = wgan_criterionGAN
            else:
                self.criterionGAN = networks.GANLoss(
                    use_lsgan=not opt.no_lsgan, tensor=self.Tensor)

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []

            if opt.optim_type == 'adam':
                optim = torch.optim.Adam
                G_optim_kwargs = {
                    'lr': opt.g_lr,
                    'betas': (opt.g_beta1, 0.999)
                }
                D_optim_kwargs = {
                    'lr': opt.d_lr,
                    'betas': (opt.d_beta1, 0.999)
                }
            elif opt.optim_type == 'rmsprop':
                optim = torch.optim.RMSprop
                G_optim_kwargs = {'lr': opt.g_lr, 'alpha': opt.alpha}
                D_optim_kwargs = {'lr': opt.d_lr, 'alpha': opt.alpha}
            else:
                # Fail with a clear message instead of an unbound local below
                raise ValueError(
                    "Unsupported optim_type '{}': expected 'adam' or "
                    "'rmsprop'".format(opt.optim_type))

            self.optimizer_G = optim(
                filter(lambda p: p.requires_grad, self.netG.parameters()),
                **G_optim_kwargs)
            self.optimizers.append(self.optimizer_G)

            self.optimizer_D = optim(
                filter(lambda p: p.requires_grad, self.netD.parameters()),
                **D_optim_kwargs)
            self.optimizers.append(self.optimizer_D)

            # Just a linear decay over the last 100 iterations, by default
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            if self.opt.num_discrims > 0:
                networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Unpack one batch dict into attributes, moving tensors to the GPU.

        Reads the keys ``'A'``, ``'B'``, ``'A_DIV'``, ``'B_DIV'``, ``'mask'``,
        ``'mask_size'``, ``'DIV_thresh'``, ``'DIV_min'``, ``'DIV_max'`` and,
        when ``opt.continent_data`` is set, ``'cont'``.
        """
        # This model is B to A by default
        AtoB = self.opt.which_direction == 'AtoB'

        # This is a confusing leftover of the pix2pix code
        # We process the images in the geo dataset A to B, that is
        # full image to masked out
        # So we want to switch direction, as we're trying to predict the
        # full image from the masked out
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']

        input_A_DIV = input['A_DIV' if AtoB else 'B_DIV']
        input_B_DIV = input['B_DIV' if AtoB else 'A_DIV']

        mask = input['mask']

        if self.opt.continent_data:
            continents = input['cont']

        if len(self.gpu_ids) > 0:
            # `async` is a reserved keyword from Python 3.7 onwards; the
            # equivalent keyword argument is `non_blocking`.
            device = self.gpu_ids[0]
            input_A = input_A.cuda(device, non_blocking=True)
            input_B = input_B.cuda(device, non_blocking=True)
            input_A_DIV = input_A_DIV.cuda(device, non_blocking=True)
            input_B_DIV = input_B_DIV.cuda(device, non_blocking=True)
            mask = mask.cuda(device, non_blocking=True)

            if self.opt.continent_data:
                continents = continents.cuda(device, non_blocking=True)

        self.input_A = input_A
        self.input_B = input_B
        self.input_A_DIV = input_A_DIV
        self.input_B_DIV = input_B_DIV
        self.mask = mask

        if 'A_paths' in input.keys():
            self.A_path = input['A_paths']
        elif 'folder_id' in input.keys():
            self.A_path = [
                'serie_{}_{:05}'.format(input['folder_name'][0],
                                        input['series_number'][0])
            ]

        if self.opt.continent_data:
            self.continent_img = continents

        self.batch_size = input_A.shape[0]
        self.mask_size = input['mask_size'].numpy()[0]

        self.div_thresh = input['DIV_thresh']
        self.div_min = input['DIV_min']
        self.div_max = input['DIV_max']

        # Sanity-check that the critic was built for the size it will see
        if self.isTrain and self.opt.num_discrims > 0:
            if self.opt.local_critic:
                assert (
                    self.mask_size, self.mask_size
                ) == self.critic_im_size, "Fix im dimensions in critic {} -> {}".format(
                    self.critic_im_size, (self.mask_size, self.mask_size))
            else:
                assert input_A.shape[
                    2:] == self.critic_im_size, "Fix im dimensions in critic {} -> {}".format(
                        self.critic_im_size, input_A.shape[2:])

        if self.opt.local_loss:
            self.loss_mask = self.mask.byte()
        else:
            # if we aren't taking local loss, use entire image
            loss_mask = torch.ones(self.mask.shape).byte()
            loss_mask = loss_mask.cuda() if len(
                self.gpu_ids) > 0 else loss_mask
            self.loss_mask = torch.autograd.Variable(loss_mask)

    def _build_G_input(self):
        """Assemble the generator input from the masked one-hot image.

        Appends the mask when ``opt.mask_to_G`` is set and the continent map
        when ``opt.continent_data`` is set, matching the channel count that
        ``initialize`` gives the generator.
        """
        G_input = self.real_A_discrete

        if self.opt.mask_to_G:
            G_input = torch.cat((G_input, self.mask.float()), dim=1)

        if self.opt.continent_data:
            G_input = torch.cat((G_input, self.continents.float()), dim=1)

        return G_input

    def _discretise_fake_DIV(self):
        """Threshold ``self.fake_B_DIV`` into ``self.fake_B_discrete``."""
        # One hot was created by thresholding unscaled image, so rescale the
        # threshold to apply to scaled image
        scaled_thresh = self.div_thresh.repeat(1, 3) / torch.cat(
            (self.div_max, torch.ones(self.div_max.shape), -self.div_min),
            dim=1)
        scaled_thresh = scaled_thresh.view(self.fake_B_DIV.shape[0], 3, 1, 1)
        scaled_thresh = scaled_thresh.cuda() if len(
            self.gpu_ids) > 0 else scaled_thresh

        # Apply threshold to divergence image
        self.fake_B_discrete = (torch.cat(
            (self.fake_B_DIV * (-1 if self.opt.invert_ridge else 1),
             torch.zeros(self.fake_B_DIV.shape,
                         device=self.fake_B_DIV.device.type),
             self.fake_B_DIV * (1 if self.opt.invert_ridge else -1)),
            dim=1) > scaled_thresh)

        # The middle channel is whatever the two outer channels did not claim
        plate = 1 - torch.max(self.fake_B_discrete, dim=1)[0]
        self.fake_B_discrete[:, 1, :, :].copy_(plate.detach())

    def forward(self):
        """Run the generator on the current batch and derive all loss inputs."""
        # Thresholded, one-hot divergence map with chunk missing
        self.real_A_discrete = torch.autograd.Variable(self.input_A)
        # Complete thresholded, one-hot divergence map
        self.real_B_discrete = torch.autograd.Variable(self.input_B)
        self.real_B_fg = torch.max(self.real_B_discrete[:, [0, 2], :, :],
                                   dim=1)[0].unsqueeze(1)

        # Continuous divergence map with chunk missing
        self.real_A_DIV = torch.autograd.Variable(self.input_A_DIV)
        # Complete continuous divergence map
        self.real_B_DIV = torch.autograd.Variable(self.input_B_DIV)

        # Mask of inpainted region
        self.mask = torch.autograd.Variable(self.mask)

        if self.opt.continent_data:
            self.continents = torch.autograd.Variable(self.continent_img)

        # Input is one-hot image with chunk missing; the mask and continent
        # map are appended as conditional data when enabled
        self.G_input = self._build_G_input()

        self.G_out = self.netG(self.G_input)
        self.fake_B_out = self.G_out[:, :self.opt.output_nc, :, :]

        self.fake_B_out_ROI = self.fake_B_out.masked_select(
            self.loss_mask).view(self.batch_size, self.fake_B_out.shape[1],
                                 *self.im_dims)

        # If we're creating the foreground image, just use that as discrete
        if self.opt.with_BCE:
            self.fake_B_fg = torch.nn.Sigmoid()(
                self.G_out[:, -1, :, :].unsqueeze(1))
            self.fake_fg_discrete = self.fake_B_fg > 0.5

        if self.opt.int_vars:
            self.fake_B_DIV = self.fake_B_out
            self.fake_B_DIV_ROI = self.fake_B_out_ROI
            self._discretise_fake_DIV()
        else:
            self.fake_B_discrete = self.fake_B_out

        # Must run after fake_B_DIV has been assigned for this pass; it is
        # only assigned in the int_vars branch above.
        if self.opt.grad_loss:
            self.real_B_DIV_grad_x = self.sobel_layer_x(self.real_B_DIV)
            self.real_B_DIV_grad_y = self.sobel_layer_y(self.real_B_DIV)

            self.fake_B_DIV_grad_x = self.sobel_layer_x(self.fake_B_DIV)
            self.fake_B_DIV_grad_y = self.sobel_layer_y(self.fake_B_DIV)

        if self.opt.with_BCE:
            self.real_B_fg_ROI = self.real_B_fg.masked_select(
                self.loss_mask).view(self.batch_size, 1, *self.im_dims)
            self.fake_fg_discrete_ROI = self.fake_fg_discrete.masked_select(
                self.loss_mask).view(self.batch_size, 1, *self.im_dims)
            self.fake_B_fg_ROI = self.fake_B_fg.masked_select(
                self.loss_mask).view(self.batch_size, 1, *self.im_dims)

        self.real_B_DIV_ROI = self.real_B_DIV.masked_select(
            self.loss_mask).view(self.batch_size, 1, *self.im_dims)

        if self.opt.grad_loss:
            self.real_B_DIV_grad_x = self.real_B_DIV_grad_x.masked_select(
                self.loss_mask).view(self.batch_size, 1, *self.im_dims)
            self.real_B_DIV_grad_y = self.real_B_DIV_grad_y.masked_select(
                self.loss_mask).view(self.batch_size, 1, *self.im_dims)
            self.fake_B_DIV_grad_x = self.fake_B_DIV_grad_x.masked_select(
                self.loss_mask).view(self.batch_size, 1, *self.im_dims)
            self.fake_B_DIV_grad_y = self.fake_B_DIV_grad_y.masked_select(
                self.loss_mask).view(self.batch_size, 1, *self.im_dims)

        self.fake_B_discrete_ROI = self.fake_B_discrete.masked_select(
            self.loss_mask.repeat(1, 3, 1, 1)).view(self.batch_size, 3,
                                                    *self.im_dims)
        self.real_B_discrete_ROI = self.real_B_discrete.masked_select(
            self.loss_mask.repeat(1, 3, 1, 1)).view(self.batch_size, 3,
                                                    *self.im_dims)

        if self.opt.int_vars:
            self.real_B_out_ROI = self.real_B_DIV_ROI
        else:
            self.real_B_out_ROI = self.real_B_discrete_ROI

        if self.opt.weighted_reconstruction or self.opt.weighted_CE:
            # If we are using BCE, default is to create both weight masks using the fg channel
            # We can alternatively specify that only BCE weighting is created using the fg channel,
            # and the discrete output is used to create its own weight mask
            if self.opt.with_BCE:
                self.ce_weight_mask = util.create_weight_mask(
                    self.real_B_fg_ROI, self.fake_fg_discrete_ROI.float())

            if self.opt.with_BCE and not self.opt.ce_weight_mask:
                self.weight_mask = self.ce_weight_mask
            else:
                self.weight_mask = util.create_weight_mask(
                    self.real_B_discrete_ROI,
                    self.fake_B_discrete_ROI.float(),
                    diff_in_numerator=self.opt.diff_in_numerator,
                    method='freq')

    # no backprop gradients
    def test(self):
        """Run the generator for evaluation without building a graph."""
        # `volatile=True` no longer disables graph construction, so use the
        # no_grad context instead.
        with torch.no_grad():
            self.real_A_discrete = torch.autograd.Variable(self.input_A)
            self.real_B_discrete = torch.autograd.Variable(self.input_B)
            self.real_B_fg = torch.max(self.real_B_discrete[:, [0, 2], :, :],
                                       dim=1)[0].unsqueeze(1)
            self.real_A_DIV = torch.autograd.Variable(self.input_A_DIV)
            self.real_B_DIV = torch.autograd.Variable(self.input_B_DIV)
            self.mask = torch.autograd.Variable(self.mask)

            if self.opt.continent_data:
                self.continents = torch.autograd.Variable(self.continent_img)

            self.G_input = self._build_G_input()

            self.G_out = self.netG(self.G_input)
            self.fake_B_out = self.G_out[:, :self.opt.output_nc, :, :]

            if self.opt.with_BCE:
                self.fake_B_fg = torch.nn.Sigmoid()(self.G_out[:, -1:, :, :])
                self.fake_fg_discrete = self.fake_B_fg > 0.5

            if self.opt.int_vars:
                self.fake_B_DIV = self.fake_B_out
                self.fake_B_DIV_ROI = self.fake_B_DIV.masked_select(
                    self.loss_mask).view(self.batch_size, 1, *self.im_dims)
                self._discretise_fake_DIV()
            else:
                self.fake_B_discrete = self.fake_B_out

            self.real_B_DIV_ROI = self.real_B_DIV.masked_select(
                self.loss_mask).view(self.batch_size, 1, *self.im_dims)

            self.real_B_discrete_ROI = self.real_B_discrete.masked_select(
                self.loss_mask.repeat(1, 3, 1, 1)).view(
                    self.batch_size, 3, *self.im_dims)
            self.fake_B_discrete_ROI = self.fake_B_discrete.masked_select(
                self.loss_mask.repeat(1, 3, 1, 1)).view(
                    self.batch_size, 3, *self.im_dims)

            if self.opt.int_vars:
                self.real_B_out_ROI = self.real_B_DIV_ROI
            else:
                self.real_B_out_ROI = self.real_B_discrete_ROI

            if self.opt.with_BCE:
                self.real_B_fg_ROI = self.real_B_fg.masked_select(
                    self.loss_mask).view(self.batch_size, 1, *self.im_dims)
                self.fake_B_fg_ROI = self.fake_B_fg.masked_select(
                    self.loss_mask).view(self.batch_size, 1, *self.im_dims)
                self.fake_fg_discrete_ROI = self.fake_fg_discrete.masked_select(
                    self.loss_mask).view(self.batch_size, 1, *self.im_dims)

    # get image paths
    def get_image_paths(self):
        """Return the path(s) recorded by the last ``set_input`` call."""
        return self.A_path

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        """Return the gradient penalty for ``netD`` on interpolated samples.

        The penalty is the batch mean of ``(||grad||_2 - 1) ** 2``, where the
        gradient is taken with respect to points interpolated between
        ``real_data`` and ``fake_data``.
        """
        # Calculate gradient penalty of points interpolated between real and fake pairs
        # One random mixing weight per sample, broadcast over that sample
        alpha = torch.rand(real_data.shape[0], 1)
        alpha = alpha.expand(alpha.shape[0],
                             real_data[0, ...].nelement()).contiguous().view(
                                 -1, *real_data.shape[1:])
        alpha = alpha.cuda(self.gpu_ids[0]) if len(self.gpu_ids) > 0 else alpha

        interpolates = alpha * fake_data + ((1 - alpha) * real_data)

        if len(self.gpu_ids) > 0:
            interpolates = interpolates.cuda(self.gpu_ids[0])
        interpolates = autograd.Variable(interpolates, requires_grad=True)

        disc_interpolates = netD(interpolates)

        # grad() returns one gradient per input; there is a single input
        # here, hence the [0]
        gradients = autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size()).cuda(
                self.gpu_ids[0])
            if len(self.gpu_ids) > 0 else torch.ones(disc_interpolates.size()),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        gradients = gradients.view(gradients.size(0), -1)

        # Flattened, so we take the gradient wrt every x (each pixel in each channel)
        # Take mean across the batch
        gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean(
            dim=0, keepdim=True)

        return gradient_penalty

    def backward_G(self):
        """Compute the generator loss and back-propagate it."""
        self.loss_G_GAN = 0
        self.loss_G_L2 = 0

        if self.opt.num_discrims > 0:
            # Conditional data (input with chunk missing + mask) + fake data
            # Remember self.fake_B_discrete is the generator output
            if self.opt.local_critic:
                if self.opt.no_int_vars:
                    fake_AB = self.fake_B_discrete_ROI
                else:
                    fake_AB = self.fake_B_DIV_ROI
            else:
                if self.opt.no_int_vars:
                    fake_AB = self.fake_B_discrete
                else:
                    fake_AB = self.fake_B_DIV

            if self.opt.continent_data:
                fake_AB = torch.cat((fake_AB, self.continents.float()), dim=1)

            if not self.opt.no_mask_to_critic:
                fake_AB = torch.cat((fake_AB, self.mask.float()), dim=1)

            # Mean across batch, then across discriminators
            # We only optimise with respect to the fake prediction because
            # the first term (i.e. the real one) is independent of the generator i.e. it is just a constant term

            # Freeze the critic while the generator loss is evaluated
            for p in self.netD.parameters():
                p.requires_grad = False

            if self.opt.use_hinge:
                self.loss_G_GAN1 = -self.netD(fake_AB)
            else:
                self.loss_G_GAN1 = self.criterionGAN(self.netD(fake_AB), True)

            # Trying to incentivise making this big, so it's mistaken for real
            self.loss_G_GAN = self.loss_G_GAN1 * self.opt.lambda_D

            for p in self.netD.parameters():
                p.requires_grad = True

        ##### L2 Loss
        self.loss_G_rec = self.criterionR(self.fake_B_out_ROI,
                                          self.real_B_out_ROI)

        if self.opt.weighted_reconstruction:
            self.loss_G_rec = (
                self.weight_mask.detach() * self.loss_G_rec.reshape(
                    self.opt.batchSize, 1, *self.im_dims)).sum(3).sum(2)

        # The small constant keeps the argument positive if processL2 is log
        self.loss_G_rec = self.processL2(self.loss_G_rec * self.opt.lambda_A +
                                         1e-8) * self.opt.lambda_A2

        if self.opt.grad_loss:
            grad_x_L2_img = self.criterionR(self.fake_B_DIV_grad_x,
                                            self.real_B_DIV_grad_x.detach())
            grad_y_L2_img = self.criterionR(self.fake_B_DIV_grad_y,
                                            self.real_B_DIV_grad_y.detach())

            if self.opt.weighted_grad:
                grad_x_L2_img = self.weight_mask.detach() * grad_x_L2_img
                grad_y_L2_img = self.weight_mask.detach() * grad_y_L2_img

            self.loss_L2_DIV_grad_x = (grad_x_L2_img)
            self.loss_L2_DIV_grad_y = (grad_y_L2_img)

            self.loss_G_L2 += self.loss_L2_DIV_grad_x
            self.loss_G_L2 += self.loss_L2_DIV_grad_y

        # NOTE(review): loss_G_L2 (the gradient terms) is accumulated above
        # but is not part of loss_G, so it is reported without being
        # optimised. Confirm whether that is intentional.
        self.loss_G = self.loss_G_GAN + self.loss_G_rec

        ##### BCE Loss
        if self.opt.with_BCE:
            self.loss_fg_CE = self.criterionBCE(self.fake_B_fg_ROI,
                                                self.real_B_fg_ROI.float())

            if self.opt.weighted_CE:
                self.loss_fg_CE = (self.ce_weight_mask.detach() *
                                   self.loss_fg_CE).sum(3).sum(2)

            self.loss_fg_CE = self.processBCE(
                self.loss_fg_CE * self.opt.lambda_B +
                1e-8) * self.opt.lambda_B2

            self.loss_G += self.loss_fg_CE

        self.loss_G = self.loss_G.mean()

        self.loss_G.backward()

    def optimize_D(self):
        """Compute the critic loss and back-propagate it.

        Does nothing when ``opt.num_discrims`` is not positive.
        """
        if self.opt.num_discrims > 0:
            if self.opt.local_critic:
                if self.opt.no_int_vars:
                    fake_AB = self.fake_B_discrete_ROI
                    real_AB = self.real_B_discrete_ROI
                else:
                    fake_AB = self.fake_B_DIV_ROI
                    real_AB = self.real_B_DIV_ROI
            else:
                if self.opt.no_int_vars:
                    fake_AB = self.fake_B_discrete
                    real_AB = self.real_B_discrete
                else:
                    fake_AB = self.fake_B_DIV
                    real_AB = self.real_B_DIV

            if self.opt.continent_data:
                fake_AB = torch.cat((fake_AB, self.continents.float()), dim=1)
                real_AB = torch.cat((real_AB, self.continents.float()), dim=1)

            if not self.opt.no_mask_to_critic:
                fake_AB = torch.cat((fake_AB, self.mask.float()), dim=1)
                real_AB = torch.cat((real_AB, self.mask.float()), dim=1)

            # stop backprop to the generator by detaching fake_B
            self.loss_D_fake = self.criterionGAN(self.netD(fake_AB.detach()),
                                                 False)

            # Real
            self.loss_D_real = self.criterionGAN(self.netD(real_AB), True)

            loss = self.loss_D_fake + self.loss_D_real

            if self.opt.which_model_netD == 'wgan-gp' or self.opt.which_model_netD == 'self-attn':
                if not self.opt.use_hinge:
                    self.grad_pen_loss = self.calc_gradient_penalty(
                        self.netD, real_AB.data,
                        fake_AB.data) * self.opt.lambda_C
                    loss += self.grad_pen_loss

            loss = loss.mean()
            loss.backward()

            # Lets get_current_errors know the critic losses now exist
            if not self.D_has_run:
                self.D_has_run = True

    def optimize_G(self):
        """Run one generator backward pass."""
        self.backward_G()

    def zero_optimisers(self):
        """Zero the gradients held by every optimizer."""
        for optimiser in self.optimizers:
            optimiser.zero_grad()

    def step_optimisers(self):
        """Apply one update step with every optimizer."""
        for optimiser in self.optimizers:
            optimiser.step()

    def get_current_errors(self):
        """Return the latest loss values as an ``OrderedDict`` keyed by name."""
        errors = [('G', self.loss_G.data[0]),
                  ('G_rec', self.loss_G_rec.data[0])]

        if self.opt.grad_loss:
            errors += [('G_L2_grad_x', self.loss_L2_DIV_grad_x.data[0]),
                       ('G_L2_grad_y', self.loss_L2_DIV_grad_y.data[0])]

        if self.opt.with_BCE:
            errors += [('G_fg_CE', self.loss_fg_CE.data[0])]

        # Critic losses only exist once optimize_D has run at least once
        if self.opt.num_discrims > 0 and self.D_has_run:
            errors += [('G_GAN_D', self.loss_G_GAN.data[0]),
                       ('D_real', self.loss_D_real.data[0]),
                       ('D_fake', self.loss_D_fake.data[0])]

            if self.opt.which_model_netD == 'wgan-gp' or self.opt.which_model_netD == 'self-attn':
                if not self.opt.use_hinge:
                    errors += [('G_grad_pen', self.grad_pen_loss.data[0])]

        if self.isTrain and self.opt.num_folders > 1 and self.opt.folder_pred:
            errors.append(('folder_CE', self.folder_pred_CE.data[0]))

        return OrderedDict(errors)

    def get_current_visuals(self):
        """Return displayable images as an ``OrderedDict`` keyed by name.

        The outline of the mask (first sample in the batch) is drawn onto
        most images at their maximum intensity.
        """
        mask_edge = roberts(self.mask.data.cpu().numpy()[0, ...].squeeze())
        mask_edge_coords = np.where(mask_edge)

        visuals = []

        real_A_discrete = util.tensor2im(self.real_A_discrete.data)
        real_A_discrete[mask_edge_coords] = np.max(real_A_discrete)
        visuals.append(('input_one_hot', real_A_discrete))

        real_B_discrete = util.tensor2im(self.real_B_discrete.data)
        real_B_discrete[mask_edge_coords] = np.max(real_B_discrete)
        visuals.append(('ground_truth_one_hot', real_B_discrete))

        fake_B_discrete = util.tensor2im(self.fake_B_discrete.data)
        fake_B_discrete[mask_edge_coords] = np.max(fake_B_discrete)
        visuals.append(('output_one_hot', fake_B_discrete))

        real_A_DIV = util.tensor2im(self.real_A_DIV.data)
        real_A_DIV[mask_edge_coords] = np.max(real_A_DIV)
        visuals.append(('input_divergence', real_A_DIV))

        real_B_DIV = util.tensor2im(self.real_B_DIV.data)
        real_B_DIV[mask_edge_coords] = np.max(real_B_DIV)
        visuals.append(('ground_truth_divergence', real_B_DIV))

        if self.opt.int_vars:
            fake_B_DIV = util.tensor2im(self.fake_B_DIV.data)
            fake_B_DIV[mask_edge_coords] = np.max(fake_B_DIV)
            visuals.append(('output_divergence', fake_B_DIV))

        if self.opt.grad_loss:
            real_B_DIV_grad_x = util.tensor2im(self.real_B_DIV_grad_x.data)
            visuals.append(('ground_truth_x_gradient', real_B_DIV_grad_x))

            real_B_DIV_grad_y = util.tensor2im(self.real_B_DIV_grad_y.data)
            visuals.append(('ground_truth_y_gradient', real_B_DIV_grad_y))

            fake_B_DIV_grad_x = util.tensor2im(self.fake_B_DIV_grad_x.data)
            visuals.append(('output_x_gradient', fake_B_DIV_grad_x))

            fake_B_DIV_grad_y = util.tensor2im(self.fake_B_DIV_grad_y.data)
            visuals.append(('output_y_gradient', fake_B_DIV_grad_y))

        if self.opt.with_BCE:
            fake_B_fg = util.tensor2im(self.fake_B_fg.data)
            fake_B_fg[mask_edge_coords] = np.max(fake_B_fg)
            visuals.append(('fake_B_fg', fake_B_fg))

            fake_fg_discrete = util.tensor2im(
                self.fake_fg_discrete.data.float())
            fake_fg_discrete[mask_edge_coords] = np.max(fake_fg_discrete)
            visuals.append(('fake_fg_discrete', fake_fg_discrete))

            real_B_fg = util.tensor2im(self.real_B_fg.data)
            real_B_fg[mask_edge_coords] = np.max(real_B_fg)
            visuals.append(('real_foreground', real_B_fg))
        elif self.opt.weighted_reconstruction:
            fake_B_discrete = util.tensor2im(self.fake_B_discrete.data)
            fake_B_discrete[mask_edge_coords] = np.max(fake_B_discrete)
            visuals.append(('fake_B_discrete', fake_B_discrete))

        if self.opt.weighted_reconstruction or self.opt.weighted_CE:
            weight_mask = util.tensor2im(self.weight_mask.data)
            # The outline coordinates only apply to a full-size weight mask
            if not self.opt.local_loss:
                weight_mask[mask_edge_coords] = np.max(weight_mask)
            visuals.append(('weight_mask', weight_mask))

        if self.opt.continent_data:
            continents = util.tensor2im(self.continents.data)
            continents[mask_edge_coords] = np.max(continents)
            visuals.append(('continents', continents))

        if not self.isTrain:
            visuals.append(('emd_ridge_error', self.emd_ridge_error))
            visuals.append(('emd_subduction_error', self.emd_subduction_error))

        return OrderedDict(visuals)

    def get_current_metrics(self):
        """Compute evaluation metrics for the current sample.

        With ``opt.int_vars`` this reports global and local L2 error, then
        searches for the threshold on the generated divergence that gives the
        lowest mean EMD score against the ground truth inside the mask, and
        rewrites ``self.fake_B_discrete`` using that threshold. Otherwise the
        EMD is computed directly on the discrete output.

        The reshapes use a leading dimension of 1, so this expects a batch of
        one sample.

        Returns:
            OrderedDict of metric name to value.
        """
        from collections import defaultdict

        real_disc_local = self.real_B_discrete.masked_select(
            self.mask.repeat(1, 3, 1, 1)).view(
                1, 3, self.mask_size,
                self.mask_size).data.numpy().squeeze().transpose(1, 2, 0)

        metrics = []

        if self.opt.int_vars:
            real_DIV = self.real_B_DIV.data.numpy().squeeze()
            fake_DIV = self.fake_B_DIV.data.numpy().squeeze()

            real_DIV_local = self.real_B_DIV.masked_select(self.mask).view(
                1, 1, self.mask_size, self.mask_size).numpy().squeeze()
            fake_DIV_local = self.fake_B_DIV.masked_select(self.mask).view(
                1, 1, self.mask_size, self.mask_size).data.numpy().squeeze()

            L2_error = np.mean((real_DIV - fake_DIV)**2)
            L2_local_error = np.mean((real_DIV_local - fake_DIV_local)**2)

            metrics.append(('L2_global', L2_error))
            metrics.append(('L2_local', L2_local_error))

            low_thresh = 2e-4
            high_thresh = max(np.max(fake_DIV_local),
                              np.abs(np.min(fake_DIV_local)))

            # Somehow goofed and produced inverted divergence maps for circles/ellipses, so we sometimes need to flip to compare
            tmp = {
                'A_DIV': fake_DIV_local * (-1 if self.opt.invert_ridge else 1)
            }

            # inf marks a threshold slot whose score still has to be computed
            scores = np.ones((5, 1)) * np.inf

            # EMD results keyed by threshold, then by channel
            results_cache = defaultdict(dict)

            print('search_iter: ')
            for search_iter in range(10):
                print('{}... '.format(search_iter), end='\r')

                # Threshold at equal intervals between low and high, find best score
                thresholds = np.linspace(low_thresh, high_thresh, 5)

                for thresh_idx, thresh in enumerate(thresholds):
                    # Scores carried over from the previous round are reused
                    if scores[thresh_idx] != np.inf:
                        continue

                    self.p.create_one_hot(tmp,
                                          thresh,
                                          skel=self.opt.skel_metric)
                    tmp_disc = tmp['A']

                    s = []
                    for i in [0, 2]:
                        tmp_emd, pairs = get_emd(tmp_disc[:, :, i],
                                                 real_disc_local[:, :, i],
                                                 average=True,
                                                 return_pairs=True)
                        s.append(tmp_emd)

                        results_cache[thresh][i] = {'pairs': pairs}
                        results_cache[thresh][i]['score'] = tmp_emd

                    scores[thresh_idx] = (np.mean(s))

                best_idx = np.argmin(scores)
                DIV_thresh = thresholds[best_idx]
                best_score = scores.ravel()[best_idx]

                # Narrow the search to the neighbours of the best threshold,
                # clamped to the ends of the current range
                high_idx = best_idx + 1
                low_idx = best_idx - 1

                if high_idx >= len(thresholds):
                    high_idx -= 1
                if low_idx < 0:
                    low_idx += 1

                high_thresh = thresholds[high_idx]
                low_thresh = thresholds[low_idx]

                print(scores.ravel()[best_idx])

                # The new end points were already scored; reset the interior
                scores[0] = scores[low_idx]
                scores[-1] = scores[high_idx]
                scores[1:-1] = np.inf

            print('Best thresh/score : {}/{}'.format(DIV_thresh, best_score))

            self.p.create_one_hot(tmp, DIV_thresh, skel=self.opt.skel_metric)
            print('Created new one-hot')

            print('Computing emd 0 ', end='')
            results = results_cache[DIV_thresh][0]
            emd_cost0 = results['score']
            im0 = visualise_emd(emd_cost0, *self.im_dims, **results['pairs'])

            print('Computing emd 1 ', end='')
            results = results_cache[DIV_thresh][2]
            emd_cost1 = results['score']
            im1 = visualise_emd(emd_cost1, *self.im_dims, **results['pairs'])

            tmp['A_DIV'] = fake_DIV * (-1 if self.opt.invert_ridge else 1)

            print('Creating full one hot image')
            self.p.create_one_hot(tmp, DIV_thresh, skel=self.opt.skel_metric)

            self.fake_B_discrete.data.copy_(
                torch.from_numpy(tmp['A'].transpose(2, 0, 1)))

            self.emd_ridge_error = im0
            self.emd_subduction_error = im1
        else:
            self.fake_disc_local = self.fake_B_discrete.masked_select(
                self.mask.repeat(1, 3, 1, 1)).view(
                    1, 3, self.mask_size,
                    self.mask_size).data.numpy().squeeze().transpose(1, 2, 0)

            print('Computing emd 0 ', end='')
            emd_cost0, results = get_emd(self.fake_disc_local[:, :, 0],
                                         real_disc_local[:, :, 0],
                                         return_pairs=True)
            im0 = visualise_emd(emd_cost0, *self.im_dims, **results)

            print('Computing emd 1 ', end='')
            emd_cost1, results = get_emd(self.fake_disc_local[:, :, 2],
                                         real_disc_local[:, :, 2],
                                         return_pairs=True)
            im1 = visualise_emd(emd_cost1, *self.im_dims, **results)

            self.emd_ridge_error = im0
            self.emd_subduction_error = im1

        metrics += [('EMD_ridge', emd_cost0), ('EMD_subduction', emd_cost1),
                    ('EMD_mean', (emd_cost0 + emd_cost1) / 2)]

        print('Done')
        return OrderedDict(metrics)

    def accumulate_metrics(self, metrics):
        """Average a list of per-sample metric dicts into one ``OrderedDict``."""
        a_metrics = []

        if self.opt.int_vars:
            a_metrics.append(
                ('L2_global',
                 np.mean([metric['L2_global'] for metric in metrics])))
            a_metrics.append(
                ('L2_local',
                 np.mean([metric['L2_local'] for metric in metrics])))

        a_metrics.append(
            ('EMD_ridge', np.mean([metric['EMD_ridge']
                                   for metric in metrics])))
        a_metrics.append(
            ('EMD_subduction',
             np.mean([metric['EMD_subduction'] for metric in metrics])))
        a_metrics.append(
            ('EMD_mean', np.mean([metric['EMD_mean'] for metric in metrics])))

        return OrderedDict(a_metrics)

    def save(self, label):
        """Save the generator and critic networks under ``label``."""
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)