Exemplo n.º 1
0
    def train(self, epoch: int = -1, writer=None) -> str:
        """Perform one PPO-style update of actor and critic on the full dataset.

        :param epoch: epoch index (unused here; kept for a common trainer interface).
        :param writer: optional tensorboard-style writer for loss distributions.
        :return: one-line summary of actor and critic loss statistics.
        """
        self.put_model_on_device()
        batch = self.data_loader.get_dataset()
        assert len(batch) != 0

        # Critic values are detached: they only serve as baselines/targets here.
        values = self._net.critic(inputs=batch.observations, train=False).squeeze().detach()
        phi_weights = self._calculate_phi(batch, values).to(self._device).squeeze(-1).detach()

        # With GAE the critic target is value + advantage; otherwise plain reward-to-go.
        critic_targets = get_reward_to_go(batch).to(self._device) if self._config.phi_key != 'gae' else \
            (values + phi_weights).detach()
        critic_loss_distribution = self._train_critic_clipped(batch, critic_targets, values)
        actor_loss_distribution = self._train_actor_ppo(batch, phi_weights, writer)

        if writer is not None:
            writer.write_distribution(critic_loss_distribution, "critic_loss")
            writer.write_distribution(Distribution(phi_weights.detach()), "phi_weights")
            writer.write_distribution(Distribution(critic_targets.detach()), "critic_targets")

        if self._config.scheduler_config is not None:
            self._actor_scheduler.step()
            self._critic_scheduler.step()
        self._net.global_step += 1
        self.put_model_back_to_original_device()
        return f" training policy loss {actor_loss_distribution.mean: 0.3e} [{actor_loss_distribution.std: 0.2e}], " \
               f"critic loss {critic_loss_distribution.mean: 0.3e} [{critic_loss_distribution.std: 0.3e}]"
    def _train_adversarial_actor_ppo(self, batch: Dataset, phi_weights: torch.Tensor, writer: TensorboardWrapper = None) \
            -> Distribution:
        """Update the adversarial policy with the clipped-surrogate PPO objective.

        :param batch: dataset of observations and actions to train on.
        :param phi_weights: advantage-style weights, one per sample.
        :param writer: optional tensorboard wrapper for loss logging.
        :return: Distribution over the mini-batch losses accumulated during training.
        """
        # Log-probabilities under the pre-update policy; detached so the ratio's
        # denominator is treated as a constant.
        original_log_probabilities = self._net.adversarial_policy_log_probabilities(inputs=batch.observations,
                                                                                    actions=batch.actions,
                                                                                    train=False).detach()
        list_batch_loss = []
        list_entropy_loss = []
        # -1 means "no explicit limit": fall back to a large fixed iteration count.
        for _ in range(self._config.max_actor_training_iterations
                       if self._config.max_actor_training_iterations != -1 else 1000):
            for data in self.data_loader.split_data(np.zeros((0,)),  # provide empty array if all data can be selected
                                                    batch.observations,
                                                    batch.actions,
                                                    original_log_probabilities,
                                                    phi_weights):
                mini_batch_observations, mini_batch_actions, \
                    mini_batch_original_log_probabilities, mini_batch_phi_weights = data

                # normalize advantages (phi_weights)
                mini_batch_phi_weights = (mini_batch_phi_weights - mini_batch_phi_weights.mean()) \
                    / (mini_batch_phi_weights.std() + 1e-6)

                new_log_probabilities = self._net.adversarial_policy_log_probabilities(inputs=mini_batch_observations,
                                                                                       actions=mini_batch_actions,
                                                                                       train=True)
                # Importance ratio between new and old policy.
                ratio = torch.exp(new_log_probabilities - mini_batch_original_log_probabilities)
                unclipped_loss = ratio * mini_batch_phi_weights
                clipped_loss = ratio.clamp(1 - self._config.epsilon, 1 + self._config.epsilon) \
                    * mini_batch_phi_weights
                # Pessimistic (min) clipped surrogate, negated for gradient descent.
                surrogate_loss = - torch.min(unclipped_loss, clipped_loss).mean()
                entropy_loss = - self._config.entropy_coefficient * \
                    self._net.get_adversarial_policy_entropy(mini_batch_observations, train=True).mean()

                batch_loss = surrogate_loss + entropy_loss
                kl_approximation = (mini_batch_original_log_probabilities - new_log_probabilities).abs().mean().item()
                # NOTE(review): this break only leaves the mini-batch loop; the outer
                # iteration loop continues optimising — confirm that is intended.
                if kl_approximation > 1.5 * self._config.kl_target and self._config.use_kl_stop:
                    break
                self._adversarial_actor_optimizer.zero_grad()
                batch_loss.backward()
                if self._config.gradient_clip_norm != -1:
                    nn.utils.clip_grad_norm_(self._net.get_adversarial_actor_parameters(),
                                             self._config.gradient_clip_norm)
                self._adversarial_actor_optimizer.step()
                assert not np.isnan(batch_loss.detach().numpy())
                list_batch_loss.append(batch_loss.detach())
                list_entropy_loss.append(entropy_loss.detach())
        actor_loss_distribution = Distribution(torch.stack(list_batch_loss))
        if writer is not None:
            writer.set_step(self._net.global_step)
            writer.write_distribution(actor_loss_distribution, "adversarial_policy_loss")
            writer.write_distribution(Distribution(torch.stack(list_entropy_loss)), "adversarial_policy_entropy_loss")
            writer.write_scalar(list_batch_loss[-1].item(), 'adversarial_final_policy_loss')
            # NOTE(review): kl_approximation is the value of the last processed
            # mini-batch; it is unbound if split_data never yields — verify upstream.
            writer.write_scalar(kl_approximation, 'adversarial_kl_difference')
        return actor_loss_distribution
Exemplo n.º 3
0
    def train(self, epoch: int = -1, writer=None) -> str:
        """One epoch of supervised training with a domain-adaptation term.

        Task loss (weight 1 - epsilon) fits predictions to source actions; the
        domain loss (weight epsilon) aligns source and target feature maps.

        :param epoch: epoch index; every 30th epoch output images are stored.
        :param writer: optional tensorboard wrapper.
        :return: one-line summary of task and domain error statistics.
        """
        self.put_model_on_device()
        total_error = []
        task_error = []
        domain_error = []
        # Source and target batches are consumed in lockstep; iteration stops
        # with the shorter of the two streams.
        for source_batch, target_batch in zip(self.data_loader.sample_shuffled_batch(),
                                              self.target_data_loader.sample_shuffled_batch()):
            self._optimizer.zero_grad()
            targets = data_to_tensor(source_batch.actions).type(self._net.dtype).to(self._device)
            # task loss
            predictions = self._net.forward(source_batch.observations, train=True)
            task_loss = (1 - self._config.epsilon) * self._criterion(predictions, targets).mean()

            # add domain adaptation loss
            domain_loss = self._config.epsilon * self._domain_adaptation_criterion(
                self._net.get_features(source_batch.observations, train=True),
                self._net.get_features(target_batch.observations, train=True))

            loss = task_loss + domain_loss
            loss.backward()
            if self._config.gradient_clip_norm != -1:
                nn.utils.clip_grad_norm_(self._net.parameters(),
                                         self._config.gradient_clip_norm)
            self._optimizer.step()
            self._net.global_step += 1
            task_error.append(task_loss.cpu().detach())
            domain_error.append(domain_loss.cpu().detach())
            total_error.append(loss.cpu().detach())
        self.put_model_back_to_original_device()

        if self._scheduler is not None:
            self._scheduler.step()

        task_error_distribution = Distribution(task_error)
        domain_error_distribution = Distribution(domain_error)
        total_error_distribution = Distribution(total_error)
        if writer is not None:
            writer.set_step(self._net.global_step)
            writer.write_distribution(task_error_distribution, 'training/task_error')
            writer.write_distribution(domain_error_distribution, 'training/domain_error')
            writer.write_distribution(total_error_distribution, 'training/total_error')
            if self._config.store_output_on_tensorboard and epoch % 30 == 0:
                # NOTE(review): predictions/targets refer to the last processed batch;
                # they are unbound if the loop body never ran — confirm upstream.
                writer.write_output_image(predictions, 'source/predictions')
                writer.write_output_image(targets, 'source/targets')
                writer.write_output_image(torch.stack(source_batch.observations), 'source/inputs')
                writer.write_output_image(self._net.forward(target_batch.observations, train=True),
                                          'target/predictions')
                writer.write_output_image(torch.stack(target_batch.observations), 'target/inputs')

        return f' training task: {self._config.criterion} {task_error_distribution.mean: 0.3e} ' \
               f'[{task_error_distribution.std:0.2e}]' \
               f' domain: {self._config.domain_adaptation_criterion} {domain_error_distribution.mean: 0.3e} ' \
               f'[{domain_error_distribution.std:0.2e}]'
    def _get_result_message(self, test: bool = False, tag: str = '', lowest_return: bool = False):
        """Summarise the collected episodes and decide whether to checkpoint.

        :param test: if True, prefixes tensorboard tags and the message with 'test'.
        :param tag: optional extra tag appended to tensorboard keys and the message.
        :param lowest_return: if True, a new minimum mean return also marks a best checkpoint.
        :return: tuple (message string, best_checkpoint flag).
        """
        # Fix: the original used `tag is not ""` (identity comparison on a string,
        # a SyntaxWarning on modern CPython); use equality. Hoist the repeated
        # prefix/suffix expressions so the three writer keys cannot drift apart.
        prefix = '' if not test else 'test_'
        suffix = f'_{tag}' if tag != '' else ''
        msg = f" {'' if not test else 'test'} {tag} "
        msg += f"{self._count_episodes} episodes"
        if self._count_success != 0:
            msg += f" with {self._count_success} success"
            if self._writer is not None:
                self._writer.write_scalar(self._count_success / float(self._count_episodes), "success")
        return_distribution = Distribution(self._episode_returns)
        msg += f" with return {return_distribution.mean: 0.3e} [{return_distribution.std: 0.2e}]"
        if self._writer is not None:
            self._writer.write_scalar(np.mean(self._episode_lengths).item(),
                                      f'{prefix}episode_lengths{suffix}')
            self._writer.write_distribution(return_distribution,
                                            f'{prefix}episode_return{suffix}')
            self._writer.write_gif(self._frames,
                                   f'{prefix}episode{suffix}')

        best_checkpoint = False
        # Track the best (highest) and worst (lowest) mean returns seen so far.
        if self._max_mean_return is None or return_distribution.mean > self._max_mean_return:
            self._max_mean_return = return_distribution.mean
            best_checkpoint = True
        if self._min_mean_return is None or return_distribution.mean < self._min_mean_return:
            self._min_mean_return = return_distribution.mean
            if lowest_return:
                best_checkpoint = True
        return msg, best_checkpoint
    def _train_adversarial_critic_clipped(self, batch: Dataset, targets: torch.Tensor, previous_values: torch.Tensor) \
            -> Distribution:
        """Fit the adversarial critic with PPO-style clipped value regression.

        :param batch: dataset providing observations and done flags.
        :param targets: regression targets for the critic.
        :param previous_values: pre-update critic values used as clipping anchors.
        :return: Distribution over the mini-batch critic losses.
        """
        critic_loss = []
        for value_train_it in range(self._config.max_critic_training_iterations):
            # Restrict the value update to non-terminal states.
            state_indices = np.asarray([index for index in range(len(batch)) if not batch.done[index]])
            for data in self.data_loader.split_data(state_indices,
                                                    batch.observations,
                                                    targets,
                                                    previous_values):
                self._adversarial_critic_optimizer.zero_grad()
                mini_batch_observations, mini_batch_targets, mini_batch_previous_values = data

                batch_values = self._net.adversarial_critic(inputs=mini_batch_observations, train=True).squeeze()
                unclipped_loss = self._criterion(batch_values, mini_batch_targets)
                # absolute clipping
                clipped_values = mini_batch_previous_values + \
                    (batch_values - mini_batch_previous_values).clamp(-self._config.epsilon,
                                                                      self._config.epsilon)
                clipped_loss = self._criterion(clipped_values, mini_batch_targets)
                # Pessimistic maximum of clipped/unclipped losses (PPO value clipping).
                batch_loss = torch.max(unclipped_loss, clipped_loss)
                batch_loss.mean().backward()
                if self._config.gradient_clip_norm != -1:
                    nn.utils.clip_grad_norm_(self._net.get_adversarial_critic_parameters(),
                                             self._config.gradient_clip_norm)
                self._adversarial_critic_optimizer.step()
                critic_loss.append(batch_loss.mean().detach())
        return Distribution(torch.stack(critic_loss))
 def test_distribution_list(self):
     """A constant-valued list yields equal max/min/mean and zero std."""
     samples = [10 for _ in range(5)]
     dist = Distribution(samples)
     for statistic in (dist.max, dist.min, dist.mean):
         self.assertEqual(statistic, 10)
     self.assertEqual(dist.std, 0)
 def test_distribution_tensor(self):
     """A constant-valued tensor yields equal max/min/mean and zero std."""
     samples = torch.as_tensor([10 for _ in range(5)])
     dist = Distribution(samples)
     for statistic in (dist.max, dist.min, dist.mean):
         self.assertEqual(statistic, 10)
     self.assertEqual(dist.std, 0)
Exemplo n.º 8
0
    def evaluate(self,
                 epoch: int = -1,
                 writer=None,
                 tag: str = 'validation') -> Tuple[str, bool]:
        """Compute the criterion on the evaluation set without gradients.

        :param epoch: epoch index; every 30th epoch (or when tag == 'test')
            output images are stored on tensorboard.
        :param writer: optional tensorboard wrapper.
        :param tag: tensorboard tag prefix, e.g. 'validation' or 'test'.
        :return: (summary message, True if this is the lowest mean error so far).
        """
        self.put_model_on_device()
        total_error = []
        #        for batch in tqdm(self.data_loader.get_data_batch(), ascii=True, desc='evaluate'):
        for batch in self.data_loader.get_data_batch():
            with torch.no_grad():
                predictions = self._net.forward(batch.observations,
                                                train=False)
                targets = data_to_tensor(batch.actions).type(
                    self._net.dtype).to(self._device)
                error = self._criterion(predictions, targets).mean()
                total_error.append(error)
        error_distribution = Distribution(total_error)
        self.put_model_back_to_original_device()
        if writer is not None:
            writer.write_distribution(error_distribution, tag)
            if self._config.store_output_on_tensorboard and (epoch % 30 == 0
                                                             or tag == 'test'):
                # Images come from the last evaluated batch.
                writer.write_output_image(predictions, f'{tag}/predictions')
                writer.write_output_image(targets, f'{tag}/targets')
                writer.write_output_image(torch.stack(batch.observations),
                                          f'{tag}/inputs')

        msg = f' {tag} {self._config.criterion} {error_distribution.mean: 0.3e} [{error_distribution.std:0.2e}]'

        best_checkpoint = False
        # A new minimum mean error marks a checkpoint-worthy model.
        if self._lowest_validation_loss is None or error_distribution.mean < self._lowest_validation_loss:
            self._lowest_validation_loss = error_distribution.mean
            best_checkpoint = True
        return msg, best_checkpoint
Exemplo n.º 9
0
    def _train_discriminator_network(self, writer=None) -> str:
        """Train the discriminator to separate sim (label 0) from real (label 1) outputs.

        The main network's outputs for both domains are computed without
        gradients (train=False); only the discriminator is updated.

        :param writer: optional tensorboard wrapper.
        :return: one-line summary of the BCE error.
        """
        total_error = []
        criterion = nn.BCELoss()
        for sim_batch, real_batch in zip(
                self.data_loader.sample_shuffled_batch(),
                self.target_data_loader.sample_shuffled_batch()):
            self._discriminator_optimizer.zero_grad()
            # Concatenate all deep-supervision outputs into one prediction tensor.
            sim_predictions = torch.cat(
                self._net.forward_with_all_outputs(sim_batch.observations,
                                                   train=False))
            real_predictions = torch.cat(
                self._net.forward_with_all_outputs(real_batch.observations,
                                                   train=False))
            # Labels: 0 for simulated samples, 1 for real samples.
            targets = torch.as_tensor([*[0] * len(sim_predictions), *[1] * len(real_predictions)])\
                .type(self._net.dtype).to(self._device)
            outputs = self._net.discriminate(torch.cat(
                [sim_predictions, real_predictions]).unsqueeze(dim=1),
                                             train=True).squeeze(dim=1)
            loss = criterion(outputs, targets)
            loss.mean().backward()
            if self._config.gradient_clip_norm != -1:
                nn.utils.clip_grad_norm_(self._net.discriminator_parameters(),
                                         self._config.gradient_clip_norm)
            self._discriminator_optimizer.step()
            total_error.append(loss.cpu().detach())

        error_distribution = Distribution(total_error)
        if writer is not None:
            writer.write_distribution(error_distribution, 'discriminator_loss')
        return f' train discriminator network BCE {error_distribution.mean: 0.3e}'
    def train(self, epoch: int = -1, writer=None) -> str:
        """One epoch of deeply-supervised training over shuffled batches.

        The criterion is applied to every intermediate output of the network
        (deep supervision) and the per-output means are summed before
        backpropagation.

        :param epoch: epoch index; every 30th epoch output images are stored.
        :param writer: optional tensorboard wrapper.
        :return: one-line summary of the training error statistics.
        """
        self.put_model_on_device()
        total_error = []
        for batch in self.data_loader.sample_shuffled_batch():
            self._optimizer.zero_grad()
            targets = data_to_tensor(batch.actions).type(self._net.dtype).to(
                self._device)
            probabilities = self._net.forward_with_all_outputs(
                batch.observations, train=True)
            # Sum the criterion over all deep-supervision outputs.
            loss = self._criterion(probabilities[-1], targets).mean()
            for prob in probabilities[:-1]:  # index was unused; iterate directly
                loss += self._criterion(prob, targets).mean()
            # loss is already a 0-dim scalar; the original's extra .mean() was a no-op.
            loss.backward()
            if self._config.gradient_clip_norm != -1:
                nn.utils.clip_grad_norm_(self._net.parameters(),
                                         self._config.gradient_clip_norm)
            self._optimizer.step()
            self._net.global_step += 1
            total_error.append(loss.cpu().detach())
        self.put_model_back_to_original_device()

        error_distribution = Distribution(total_error)
        if writer is not None:
            writer.set_step(self._net.global_step)
            writer.write_distribution(error_distribution, 'training')
            if self._config.store_output_on_tensorboard and epoch % 30 == 0:
                for index, prob in enumerate(probabilities):
                    writer.write_output_image(prob,
                                              f'training/predictions_{index}')
                writer.write_output_image(targets, 'training/targets')
                writer.write_output_image(torch.stack(batch.observations),
                                          'training/inputs')
        return f' training {self._config.criterion} {error_distribution.mean: 0.3e} [{error_distribution.std:0.2e}]'
    def train(self, epoch: int = -1, writer=None) -> str:
        """One epoch of deeply-supervised training with optional feature-map logging.

        :param epoch: epoch index; every 30th epoch images/feature maps are stored.
        :param writer: optional tensorboard wrapper.
        :return: one-line summary of the training error statistics.
        """
        self.put_model_on_device()
        total_error = []
        for batch in self.data_loader.sample_shuffled_batch():
            self._optimizer.zero_grad()
            targets = data_to_tensor(batch.actions).type(self._net.dtype).to(
                self._device)
            probabilities = self._net.forward_with_all_outputs(
                batch.observations, train=True)
            # Deep supervision: sum the criterion over all intermediate outputs.
            loss = self._criterion(probabilities[-1], targets).mean()
            for index, prob in enumerate(probabilities[:-1]):
                loss += self._criterion(prob, targets).mean()
            loss.mean().backward()
            if self._config.gradient_clip_norm != -1:
                nn.utils.clip_grad_norm_(self._net.parameters(),
                                         self._config.gradient_clip_norm)
            self._optimizer.step()
            self._net.global_step += 1
            total_error.append(loss.cpu().detach())
        self.put_model_back_to_original_device()

        error_distribution = Distribution(total_error)
        if writer is not None:
            writer.set_step(self._net.global_step)
            writer.write_distribution(error_distribution, 'training')
            if self._config.store_output_on_tensorboard and epoch % 30 == 0:
                for index, prob in enumerate(probabilities):
                    writer.write_output_image(prob,
                                              f'training/predictions_{index}')
                writer.write_output_image(targets, 'training/targets')
                writer.write_output_image(torch.stack(batch.observations),
                                          'training/inputs')
            if self._config.store_feature_maps_on_tensorboard and epoch % 30 == 0:
                outputs = self._net.forward_with_intermediate_outputs(
                    batch.observations, train=False)
                for i in range(4):  # store feature maps for the first 4 images of the batch
                    for layer in ['x1', 'x2', 'x3', 'x4']:
                        # NOTE(review): flatten(start_dim=0, end_dim=0) is a no-op;
                        # presumably a broader flatten was intended — confirm.
                        feature_maps = outputs[layer][i].flatten(start_dim=0,
                                                                 end_dim=0)
                        title = f'feature_map/layer_{layer}/{i}'
                        # title += 'inds_' + '_'.join([str(v.item()) for v in winning_indices.indices])
                        # title += '_vals_' + '_'.join([f'{v.item():0.2f}' for v in winning_indices.values])
                        writer.write_output_image(feature_maps, title)
            writer.write_figure(tag='gradient',
                                figure=plot_gradient_flow(
                                    self._net.named_parameters()))
        return f' training {self._config.criterion} {error_distribution.mean: 0.3e} [{error_distribution.std:0.2e}]'
Exemplo n.º 12
0
    def train(self, epoch: int = -1, writer=None) -> str:
        """One epoch of supervised training, optionally with a VAE KL-divergence term.

        :param epoch: epoch index; every 30th epoch output images are stored.
        :param writer: optional tensorboard wrapper.
        :return: one-line summary of the training error statistics.
        """
        self.put_model_on_device()
        total_error = []
        #        for batch in tqdm(self.data_loader.sample_shuffled_batch(), ascii=True, desc='train'):
        for batch in self.data_loader.sample_shuffled_batch():
            self._optimizer.zero_grad()
            targets = data_to_tensor(batch.actions).type(self._net.dtype).to(
                self._device)
            if self._config.add_KL_divergence_loss:
                predictions, mean, std = self._net.forward_with_distribution(
                    batch.observations, train=True)
            else:
                predictions = self._net.forward(batch.observations, train=True)

            loss = self._criterion(predictions, targets).mean()
            if self._config.add_KL_divergence_loss:
                # https://arxiv.org/pdf/1312.6114.pdf
                # KL(N(mean, std) || N(0, 1)), appendix B of the VAE paper.
                KL_loss = -0.5 * torch.sum(1 + std.pow(2).log() - mean.pow(2) -
                                           std.pow(2))
                loss += KL_loss

            loss.backward()
            if self._config.gradient_clip_norm != -1:
                nn.utils.clip_grad_norm_(self._net.parameters(),
                                         self._config.gradient_clip_norm)
            self._optimizer.step()
            self._net.global_step += 1
            total_error.append(loss.cpu().detach())
        self.put_model_back_to_original_device()

        if self._scheduler is not None:
            self._scheduler.step()

        error_distribution = Distribution(total_error)
        if writer is not None:
            writer.set_step(self._net.global_step)
            writer.write_distribution(error_distribution, 'training')
            if self._config.add_KL_divergence_loss:
                # NOTE(review): logs the KL of the last batch only; unbound if the
                # loader yielded no batches — confirm upstream guarantees.
                writer.write_scalar(KL_loss, 'KL_divergence')
            if self._config.store_output_on_tensorboard and epoch % 30 == 0:
                writer.write_output_image(predictions, 'training/predictions')
                writer.write_output_image(targets, 'training/targets')
                writer.write_output_image(torch.stack(batch.observations),
                                          'training/inputs')
        return f' training {self._config.criterion} {error_distribution.mean: 0.3e} [{error_distribution.std:0.2e}]'
Exemplo n.º 13
0
    def train(self, epoch: int = -1, writer=None) -> str:
        """Run a single actor-critic update over the whole dataset.

        :param epoch: epoch index (unused here; kept for a common trainer interface).
        :param writer: optional tensorboard-style writer.
        :return: one-line summary of the policy and critic losses.
        """
        self.put_model_on_device()
        dataset = self.data_loader.get_dataset()
        assert len(dataset) != 0

        # Detached critic values serve as the baseline for the advantage estimate.
        state_values = self._net.critic(inputs=dataset.observations, train=False).squeeze().detach()
        advantages = self._calculate_phi(dataset, state_values).to(self._device)
        actor_loss = self._train_actor(dataset, advantages)
        reward_to_go = get_reward_to_go(dataset).to(self._device)
        critic_loss_distribution = Distribution(self._train_critic(dataset, reward_to_go))

        if writer is not None:
            writer.set_step(self._net.global_step)
            writer.write_scalar(actor_loss.data, "policy_loss")
            writer.write_distribution(critic_loss_distribution, "critic_loss")

        self._net.global_step += 1
        self.put_model_back_to_original_device()
        if self._config.scheduler_config is not None:
            self._actor_scheduler.step()
            self._critic_scheduler.step()
        return f" training policy loss {actor_loss.data: 0.3e}, critic loss {critic_loss_distribution.mean: 0.3e}"
Exemplo n.º 14
0
    def train(self, epoch: int = -1, writer=None) -> str:
        """One epoch of deeply-supervised training with a domain-adaptation term.

        Task loss: criterion summed over all deep-supervision outputs, scaled by
        (1 - epsilon). Domain loss: the domain-adaptation criterion between the
        per-pixel output distributions of source and target, scaled by epsilon.

        :param epoch: epoch index; every 30th epoch images/feature maps are stored.
        :param writer: optional tensorboard wrapper.
        :return: one-line summary of task and domain error statistics.
        """
        self.put_model_on_device()
        total_error = []
        task_error = []
        domain_error = []
        for source_batch, target_batch in zip(self.data_loader.sample_shuffled_batch(),
                                              self.target_data_loader.sample_shuffled_batch()):
            self._optimizer.zero_grad()
            targets = data_to_tensor(source_batch.actions).type(self._net.dtype).to(self._device)

            # deep supervision loss
            probabilities = self._net.forward_with_all_outputs(source_batch.observations, train=True)
            task_loss = self._criterion(probabilities[-1], targets).mean()
            for index, prob in enumerate(probabilities[:-1]):
                task_loss += self._criterion(prob, targets).mean()
            task_loss *= (1 - self._config.epsilon)

            # add domain adaptation loss on distribution of output pixels at each output
            # NOTE(review): this re-runs forward_with_all_outputs on source_batch even
            # though `probabilities` above holds the result of the same call — likely
            # a redundant forward pass; reusing it could change dropout sampling, so
            # it is left untouched pending confirmation.
            domain_loss = sum([self._domain_adaptation_criterion(sp.flatten().unsqueeze(1), tp.flatten().unsqueeze(1))
                               for sp, tp in zip(self._net.forward_with_all_outputs(source_batch.observations,
                                                                                    train=True),
                                                 self._net.forward_with_all_outputs(target_batch.observations,
                                                                                    train=True))
                               ]) * self._config.epsilon

            loss = task_loss + domain_loss
            loss.backward()
            if self._config.gradient_clip_norm != -1:
                nn.utils.clip_grad_norm_(self._net.parameters(),
                                         self._config.gradient_clip_norm)
            self._optimizer.step()
            self._net.global_step += 1
            task_error.append(task_loss.cpu().detach().numpy())
            domain_error.append(domain_loss.cpu().detach().numpy())
            total_error.append(loss.cpu().detach().numpy())

        self.put_model_back_to_original_device()

        if self._scheduler is not None:
            self._scheduler.step()

        task_error_distribution = Distribution(task_error)
        domain_error_distribution = Distribution(domain_error)
        total_error_distribution = Distribution(total_error)
        if writer is not None:
            writer.set_step(self._net.global_step)
            writer.write_distribution(task_error_distribution, 'training/task_error')
            writer.write_distribution(domain_error_distribution, 'training/domain_error')
            writer.write_distribution(total_error_distribution, 'training/total_error')
            if self._config.store_output_on_tensorboard and epoch % 30 == 0:
                writer.write_output_image(probabilities[-1], 'source/predictions')
                writer.write_output_image(targets, 'source/targets')
                writer.write_output_image(torch.stack(source_batch.observations), 'source/inputs')
                writer.write_output_image(self._net.forward(target_batch.observations, train=False),
                                          'target/predictions')
                writer.write_output_image(torch.stack(target_batch.observations), 'target/inputs')
            if self._config.store_feature_maps_on_tensorboard and epoch % 30 == 0:
                for name, batch in zip(['source', 'target'], [source_batch, target_batch]):
                    outputs = self._net.forward_with_intermediate_outputs(batch.observations, train=False)
                    for i in range(4):  # store feature maps for the first 4 images of the batch
                        for layer in ['x1', 'x2', 'x3', 'x4']:
                            # NOTE(review): flatten(start_dim=0, end_dim=0) is a no-op — confirm intent.
                            feature_maps = outputs[layer][i].flatten(start_dim=0, end_dim=0)
                            title = f'feature_map/{name}/layer_{layer}/{i}'
                            # title += 'inds_' + '_'.join([str(v.item()) for v in winning_indices.indices])
                            # title += '_vals_' + '_'.join([f'{v.item():0.2f}' for v in winning_indices.values])
                            writer.write_output_image(feature_maps, title)
        return f' task {self._config.criterion} ' \
               f'{task_error_distribution.mean: 0.3e} ' \
               f'[{task_error_distribution.std:0.2e}] ' \
               f' domain {self._config.domain_adaptation_criterion} ' \
               f'{domain_error_distribution.mean: 0.3e} ' \
               f'[{domain_error_distribution.std: 0.2e}]'
Exemplo n.º 15
0
    def _train_main_network(self, epoch: int = -1, writer=None) -> str:
        """Train the main network with deep supervision plus an adversarial term.

        The deep-supervision loss fits sim-batch targets; the adversarial term
        (weight epsilon) is the mean discriminator score on the network's
        outputs for real data, with the discriminator frozen (train=False).

        :param epoch: epoch index; every 30th epoch output images are stored.
        :param writer: optional tensorboard wrapper.
        :return: one-line summary of supervision and discriminator error statistics.
        """
        deeply_supervised_error = []
        discriminator_error = []
        for sim_batch, real_batch in zip(
                self.data_loader.sample_shuffled_batch(),
                self.target_data_loader.sample_shuffled_batch()):
            self._optimizer.zero_grad()

            # normal deep supervision loss
            targets = data_to_tensor(sim_batch.actions).type(
                self._net.dtype).to(self._device)
            probabilities = self._net.forward_with_all_outputs(
                sim_batch.observations, train=True)
            loss = self._criterion(probabilities[-1], targets).mean()
            for index, prob in enumerate(probabilities[:-1]):
                loss += self._criterion(prob, targets).mean()
            deeply_supervised_error.append(loss.mean().cpu().detach())

            # adversarial loss on discriminator data
            # NOTE(review): adds the raw mean discriminator score; verify the sign
            # matches the intended adversarial objective for the generator.
            network_outputs = torch.cat(
                self._net.forward_with_all_outputs(
                    real_batch.observations, train=True)).unsqueeze(dim=1)
            discriminator_loss = self._net.discriminate(network_outputs,
                                                        train=False).mean()
            #results = self._net.forward_with_intermediate_outputs(real_batch.observations, train=True)
            #feature_maps =

            # combine losses with epsilon weight
            loss *= (1 - self._config.epsilon)
            loss += self._config.epsilon * discriminator_loss
            loss.mean().backward()
            discriminator_error.append(
                discriminator_loss.mean().cpu().detach())
            # clip gradients
            if self._config.gradient_clip_norm != -1:
                nn.utils.clip_grad_norm_(self._net.parameters(),
                                         self._config.gradient_clip_norm)
            self._optimizer.step()

        supervised_error_distribution = Distribution(deeply_supervised_error)
        discriminator_error_distribution = Distribution(discriminator_error)
        if writer is not None:
            writer.write_distribution(supervised_error_distribution,
                                      'training_loss_from_deep_supervision')
            writer.write_distribution(discriminator_error_distribution,
                                      'training_loss_from_discriminator')
            if self._config.store_output_on_tensorboard and epoch % 30 == 0:
                for index, prob in enumerate(probabilities):
                    writer.write_output_image(prob,
                                              f'training/predictions_{index}')
                writer.write_output_image(targets, 'training/targets')
                writer.write_output_image(torch.stack(sim_batch.observations),
                                          'training/inputs')
            for index, prob in enumerate(
                    self._net.forward_with_all_outputs(real_batch.observations,
                                                       train=False)):
                writer.write_output_image(prob, f'real/predictions_{index}')
            writer.write_output_image(torch.stack(real_batch.observations),
                                      'real/inputs')
        return f' Training: supervision {self._config.criterion} {supervised_error_distribution.mean: 0.3e} ' \
               f'[{supervised_error_distribution.std:0.2e}]' \
               f' discriminator {discriminator_error_distribution.mean: 0.3e} ' \
               f'[{discriminator_error_distribution.std:0.2e}]'