def backward_local_D(self, generator_outputs, ground_truths, real_labels, fake_labels):
    # TODO: add CUDA (patch extraction currently runs on the host)
    list_X = extract_patch_from_tensor(generator_outputs.detach(), patch_size=(128, 128))
    list_Y = extract_patch_from_tensor(ground_truths.unsqueeze(1).type(torch.cuda.FloatTensor), patch_size=(128, 128))

    discriminator_loss = 0
    for X, Y in zip(list_X, list_Y):
        # Generated patches should be scored as fake
        generated_labels = self.local_discriminator(X)
        generated_labels = generated_labels.view(generated_labels.size(0), 1)
        generated_loss = self.loss["lsgan"](generated_labels, fake_labels)

        # Ground-truth patches should be scored as real
        target_labels = self.local_discriminator(Y)
        target_labels = target_labels.view(target_labels.size(0), 1)
        target_loss = self.loss["lsgan"](target_labels, real_labels)

        discriminator_loss += generated_loss + target_loss

    discriminator_loss = discriminator_loss / len(list_X)
    discriminator_loss.backward()

    return discriminator_loss.item()
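# --- Hedged sketch, not part of the original code ------------------------------------
# self.loss["lsgan"] is not shown in this section; a plausible definition, assuming the
# usual least-squares GAN objective against the +1 (real) / -1 (fake) label tensors
# built by the trainer, is a plain mean-squared-error criterion. Illustrative only.
def example_lsgan_loss(scores, labels):
    """Illustrative only: least-squares GAN objective as MSE between scores and labels."""
    return torch.nn.functional.mse_loss(scores, labels)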
def backward_G(self, outputs, probs, targets, locations, orig_window_length, full_image, other_full_image, real_labels):
    # Global adversarial term: policy-gradient reward from the global discriminator
    log_probs, rewards = self.global_discriminator.reward_forward(probs, locations, orig_window_length, full_image, other_full_image)
    generator_global_loss = self.loss["pg"](rewards, log_probs)

    # Local adversarial term: LSGAN loss on 128x128 patches of the generator output
    # outputs = self.generator(inputs)
    list_G = extract_patch_from_tensor(probs, patch_size=(128, 128))
    generator_local_loss = 0
    for G in list_G:
        generated_labels = self.local_discriminator(G)
        generator_local_loss += self.loss["lsgan"](generated_labels, real_labels)
    generator_local_loss = generator_local_loss / len(list_G)

    # Supervised term: cross-entropy against the ground-truth targets
    generator_cce_loss = self.loss["ce"](outputs, targets)

    # Weighted sum of the three terms
    generator_loss = self.lambda_1 * generator_cce_loss \
        + self.lambda_2 * generator_local_loss \
        + self.lambda_3 * generator_global_loss
    generator_loss.backward()

    return generator_loss.item(), generator_cce_loss.item(), generator_local_loss.item(), generator_global_loss.item()
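# --- Hedged sketch, not part of the original code ------------------------------------
# self.loss["pg"] is not shown in this section; assuming a REINFORCE-style objective,
# it would minimise the negative reward-weighted log-likelihood of the sampled outputs.
# The argument order matches the call above (rewards first, then log_probs).
def example_policy_gradient_loss(rewards, log_probs):
    """Illustrative only: REINFORCE-style loss, -E[reward * log_prob]."""
    return -(rewards.detach() * log_probs).mean()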
def _train_discriminator_epoch(self, epoch):
    """
    Pre-training logic for an epoch

    :param epoch: Current training epoch
    :return: A log that contains all information you want to save

    Note:
        If you have additional information to record, for example:
            > additional_log = {"x": x, "y": y}
        merge it with log before return. i.e.
            > log = {**log, **additional_log}
            > return log

        The metrics in log must have the key 'metrics'.
    """
    self.discriminator.train()
    self.generator.eval()

    total_loss = 0
    total_metrics = np.zeros(len(self.metrics))
    for batch_idx, (inputs, targets, _, _, _) in enumerate(self.train_data_loader):
        # Pre-train the discriminator on the first quarter of the data loader only
        if batch_idx == int(len(self.train_data_loader) / 4):
            break

        inputs, targets = inputs.to(self.device), targets.to(self.device)
        batch_size = inputs.size(0)
        true_labels = torch.ones((batch_size, 1)).to(self.device)
        fake_labels = -torch.ones((batch_size, 1)).to(self.device)

        self.discriminator_optimizer.zero_grad()
        outputs = self.generator(inputs)

        list_X = extract_patch_from_tensor(outputs[:, :1, :, :].detach(), patch_size=(128, 128))
        list_Y = extract_patch_from_tensor(targets.unsqueeze(1).type(torch.cuda.FloatTensor), patch_size=(128, 128))

        discriminator_loss = 0
        for X, Y in zip(list_X, list_Y):
            generated_labels = self.discriminator(X)
            generated_labels = generated_labels.view(batch_size, 1)
            generated_loss = self.loss["lsgan"](generated_labels, fake_labels)

            target_labels = self.discriminator(Y)
            target_labels = target_labels.view(batch_size, 1)
            target_loss = self.loss["lsgan"](target_labels, true_labels)

            discriminator_loss += generated_loss + target_loss

        discriminator_loss = discriminator_loss / len(list_X)
        discriminator_loss.backward()
        self.discriminator_optimizer.step()

        self.writer.set_step((epoch - 1) * len(self.train_data_loader) + batch_idx)
        self.writer.add_scalar('loss', discriminator_loss.item())
        total_loss += discriminator_loss.item()
        total_metrics += self._eval_metrics(outputs, targets)

        if self.verbosity >= 2 and batch_idx % self.log_step == 0:
            self.logger.info(
                'Discriminator Pre-Train Epoch: {} [{}/{} ({:.0f}%)] Discriminator Loss: {:.6f}'.format(
                    epoch,
                    batch_idx * self.train_data_loader.batch_size,
                    self.train_data_loader.n_samples,
                    100.0 * batch_idx / len(self.train_data_loader),
                    discriminator_loss.item()))

    log = {
        'Discriminator_Loss': total_loss / len(self.train_data_loader),
        'metrics': (total_metrics / len(self.train_data_loader)).tolist()
    }

    return log
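# --- Hedged sketch, not part of the original code ------------------------------------
# extract_patch_from_tensor is called throughout this section but defined elsewhere.
# One plausible implementation, assuming non-overlapping tiling of an (N, C, H, W)
# tensor into patch_size windows, returning a list of (N, C, ph, pw) tensors.
def example_extract_patch_from_tensor(tensor, patch_size=(128, 128)):
    """Illustrative only: split a batch of images into non-overlapping patches."""
    ph, pw = patch_size
    _, _, height, width = tensor.size()
    patches = []
    for top in range(0, height - ph + 1, ph):
        for left in range(0, width - pw + 1, pw):
            patches.append(tensor[:, :, top:top + ph, left:left + pw])
    return patches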