Beispiel #1
0
    def backward_local_D(self, generator_outputs, ground_truths, real_labels, fake_labels):
        """Back-propagate the local (patch-level) discriminator loss.

        The detached generator outputs and the ground truths are both split into
        128x128 patches with ``extract_patch_from_tensor``. Every (generated, real)
        patch pair is scored by ``self.local_discriminator``. Scores are reshaped to
        ``(N, 1)`` and penalised with ``self.loss["lsgan"]``: generated patches
        against ``fake_labels``, real patches against ``real_labels``. The summed
        loss is averaged over the number of generated patches, back-propagated,
        and returned as a Python float.

        :param generator_outputs: generator output tensor; detached here so no
            gradient flows back into the generator.
        :param ground_truths: target tensor; a channel axis is inserted at dim 1
            and it is cast to float32 on the device of ``generator_outputs``.
        :param real_labels: labels the discriminator should assign to real patches.
        :param fake_labels: labels the discriminator should assign to generated patches.
        :return: the averaged discriminator loss (float).
        """
        list_X = extract_patch_from_tensor(generator_outputs.detach(), patch_size=(128, 128))
        # Follow the generator's device instead of hard-coding torch.cuda.FloatTensor,
        # which broke CPU-only runs and always used the *current* CUDA device.
        real_images = ground_truths.unsqueeze(1).to(device=generator_outputs.device, dtype=torch.float32)
        list_Y = extract_patch_from_tensor(real_images, patch_size=(128, 128))

        discriminator_loss = 0
        for X, Y in zip(list_X, list_Y):
            generated_labels = self.local_discriminator(X)
            generated_labels = generated_labels.view(generated_labels.size(0), 1)
            generated_loss = self.loss["lsgan"](generated_labels, fake_labels)
            target_labels = self.local_discriminator(Y)
            target_labels = target_labels.view(target_labels.size(0), 1)
            target_loss = self.loss["lsgan"](target_labels, real_labels)
            discriminator_loss += generated_loss + target_loss

        discriminator_loss = discriminator_loss / len(list_X)
        discriminator_loss.backward()
        return discriminator_loss.item()
Beispiel #2
0
    def backward_G(self, outputs, probs, targets, locations, orig_window_length, full_image, other_full_image,
                   real_labels):
        """Compute the combined generator loss and back-propagate it.

        The total loss is
        ``lambda_1 * ce + lambda_2 * local + lambda_3 * global`` where:

        * ``global`` is ``self.loss["pg"]`` applied to the rewards and
          log-probabilities returned by ``self.global_discriminator.reward_forward``;
        * ``local`` is ``self.loss["lsgan"]`` of the local discriminator's scores on
          128x128 patches of ``probs`` against ``real_labels``, averaged over patches;
        * ``ce`` is ``self.loss["ce"]`` between ``outputs`` and ``targets``.

        :return: tuple of floats ``(total, ce, local, global)``.
        """
        log_probs, rewards = self.global_discriminator.reward_forward(probs, locations, orig_window_length, full_image,
                                                                      other_full_image)
        generator_global_loss = self.loss["pg"](rewards, log_probs)

        list_G = extract_patch_from_tensor(probs, patch_size=(128, 128))
        generator_local_loss = 0
        for G in list_G:
            generated_labels = self.local_discriminator(G)
            # Flatten to (N, 1), as backward_local_D does for the same discriminator.
            # Extra singleton dims would otherwise silently broadcast against
            # real_labels (presumably shaped (N, 1) -- confirm with callers).
            generated_labels = generated_labels.view(generated_labels.size(0), 1)
            generator_local_loss += self.loss["lsgan"](generated_labels, real_labels)
        generator_local_loss = generator_local_loss / len(list_G)
        generator_cce_loss = self.loss["ce"](outputs, targets)

        generator_loss = (self.lambda_1 * generator_cce_loss
                          + self.lambda_2 * generator_local_loss
                          + self.lambda_3 * generator_global_loss)
        generator_loss.backward()

        return generator_loss.item(), generator_cce_loss.item(), generator_local_loss.item(), generator_global_loss.item()
Beispiel #3
0
    def _train_discriminator_epoch(self, epoch):
        """
        Pre-training logic for one discriminator epoch.

        Only the first quarter of the training loader's batches is used (at least
        one batch). For each batch:

        * the generator runs in eval mode without gradient tracking;
        * its first output channel and the float-cast targets are cut into
          128x128 patches;
        * the discriminator is trained with the ``"lsgan"`` loss to score
          generated patches as -1 and real patches as +1.

        Logged averages are taken over the batches actually processed.

        :param epoch: Current training epoch (the writer step assumes numbering starts at 1)
        :return: A log that contains all information you want to save
        Note:
        If you have additional information to record, for example:
                > additional_log = {"x": x, "y": y}
            merge it with log before return. i.e.
                > log = {**log, **additional_log}
                > return log
            The metrics in log must have the key 'metrics'.
        """
        self.discriminator.train()
        self.generator.eval()
        total_loss = 0
        total_metrics = np.zeros(len(self.metrics))
        n_batches = len(self.train_data_loader)
        # Pre-train on a quarter of the data. Never on zero batches, which the old
        # `batch_idx == int(len / 4)` check caused for loaders with < 4 batches.
        max_batches = max(1, n_batches // 4)
        n_processed = 0
        for batch_idx, (inputs, targets, _, _, _) in enumerate(self.train_data_loader):
            if batch_idx >= max_batches:
                break

            inputs, targets = inputs.to(self.device), targets.to(self.device)
            batch_size = inputs.size(0)
            true_labels = torch.ones((batch_size, 1), device=self.device)
            fake_labels = -torch.ones((batch_size, 1), device=self.device)

            self.discriminator_optimizer.zero_grad()
            # The generator is not updated here, so don't build its autograd graph.
            with torch.no_grad():
                outputs = self.generator(inputs)
            list_X = extract_patch_from_tensor(outputs[:, :1, :, :].detach(), patch_size=(128, 128))
            # targets already live on self.device; cast there instead of hard-coding
            # torch.cuda.FloatTensor (which fails on CPU-only runs).
            list_Y = extract_patch_from_tensor(targets.unsqueeze(1).float(), patch_size=(128, 128))

            discriminator_loss = 0
            for X, Y in zip(list_X, list_Y):
                generated_labels = self.discriminator(X)
                generated_labels = generated_labels.view(batch_size, 1)
                generated_loss = self.loss["lsgan"](generated_labels, fake_labels)
                target_labels = self.discriminator(Y)
                target_labels = target_labels.view(batch_size, 1)
                target_loss = self.loss["lsgan"](target_labels, true_labels)
                discriminator_loss += generated_loss + target_loss

            discriminator_loss = discriminator_loss / len(list_X)
            discriminator_loss.backward()
            self.discriminator_optimizer.step()

            batch_loss = discriminator_loss.item()
            self.writer.set_step((epoch - 1) * n_batches + batch_idx)
            self.writer.add_scalar('loss', batch_loss)
            total_loss += batch_loss
            total_metrics += self._eval_metrics(outputs, targets)
            n_processed += 1

            if self.verbosity >= 2 and batch_idx % self.log_step == 0:
                self.logger.info(
                    'Discriminator Pre-Train Epoch: {} [{}/{} ({:.0f}%)] Discriminator Loss: {:.6f}'.format(
                        epoch,
                        batch_idx * self.train_data_loader.batch_size,
                        self.train_data_loader.n_samples,
                        100.0 * batch_idx / n_batches,
                        batch_loss))

        # Average over the batches actually processed, not the whole loader.
        # The old denominator understated the values ~4x.
        denominator = max(n_processed, 1)
        log = {
            'Discriminator_Loss': total_loss / denominator,
            'metrics': (total_metrics / denominator).tolist()
        }

        return log