Example #1
  def validate(self, loader, num_batches_to_return=0):
    epoch_data = []
    epoch_loss_metrics = {}
    epoch_metrics = {}
    self._set_test()

    self.data_iter = iter(loader)

    for _ in range(len(loader)):
      loss_metrics, data = self._val_step(loader)
      if data is None:
        break

      if len(epoch_data) < num_batches_to_return:
        epoch_data.append(utils.cpuify(data))

      metrics = self._compute_test_metrics(data)

      # Drop the reference so the GPU memory can be freed
      del data

      for name, loss_metric in loss_metrics.items():
        accumulate_metric(epoch_loss_metrics, name, loss_metric)
      for name, metric in metrics.items():
        accumulate_metric(epoch_metrics, name, metric)

    value_by_loss = {name: loss_metric.average()
                     for name, loss_metric in epoch_loss_metrics.items()}
    value_by_metric = {name: metric.average()
                       for name, metric in epoch_metrics.items()}

    return epoch_data, value_by_loss, value_by_metric
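
Both examples lean on helpers the snippets do not define: accumulate_metric, get_loss_metric, metric objects exposing value, average(), and a readable string form, plus utils.cpuify, which presumably moves tensors back to the CPU. Below is a minimal sketch of what the metric helpers might look like, assuming a simple running-sum metric; the class name and fields are guesses, not the original implementation:

class RunningMetric:
    """Hypothetical running metric: tracks the last value and the mean."""

    def __init__(self, value, count=1):
        self.value = value   # most recent value, used for per-step summaries
        self.total = value
        self.count = count

    def accumulate(self, other):
        # Merge another metric's running totals into this one
        self.value = other.value
        self.total += other.total
        self.count += other.count

    def average(self):
        return self.total / max(self.count, 1)

    def __str__(self):
        return '{:.4f} (avg {:.4f})'.format(self.value, self.average())


def get_loss_metric(value):
    # Wrap a scalar loss value in a metric object (assumed helper)
    return RunningMetric(value)


def accumulate_metric(metrics, name, metric):
    # Insert on first occurrence, otherwise merge into the existing entry
    if name in metrics:
        metrics[name].accumulate(metric)
    else:
        metrics[name] = metric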
Example #2
    def train_epoch(self,
                    loader,
                    epoch,
                    summary_writer=None,
                    steps_per_train_summary=1,
                    verbose=False):
        num_batches_per_epoch = len(loader)
        epoch_loss_metrics = {}
        epoch_metrics = {}
        self._set_train()

        self.data_iter = iter(loader)

        current_batch = 0
        while current_batch < num_batches_per_epoch:
            num_batches, loss_metrics, data = self._train_step(loader)
            if num_batches == 0:
                break

            current_batch += num_batches

            metrics = self._compute_train_metrics(data)

            for name, loss_metric in loss_metrics.items():
                accumulate_metric(epoch_loss_metrics, name, loss_metric)
            for name, metric in metrics.items():
                accumulate_metric(epoch_metrics, name, metric)

            global_step = num_batches_per_epoch * (epoch - 1) + current_batch
            if global_step % steps_per_train_summary == 0:
                s = '===> Epoch[{}]({}/{}): '.format(epoch, current_batch,
                                                     num_batches_per_epoch)
                s += ', '.join(('{}: {}'.format(name, loss_metric)
                                for name, loss_metric in loss_metrics.items()))
                s += '\n'
                if verbose:
                    s += '\n'.join(('     {}: {}'.format(name, metric)
                                    for name, metric in metrics.items()))
                print(s)

                if summary_writer is not None:
                    for name, metric in chain(loss_metrics.items(),
                                              metrics.items()):
                        summary_writer.add_scalar('train/{}'.format(name),
                                                  metric.value, global_step)

        value_by_loss = {
            loss: loss_value.average()
            for loss, loss_value in epoch_loss_metrics.items()
        }
        value_by_metric = {
            metric: metric_value.average()
            for metric, metric_value in epoch_metrics.items()
        }
        return value_by_loss, value_by_metric
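
    def _request_data(self, loader):
        # NOTE: minimal sketch of an assumed helper, not the original code.
        # _train_multiple_steps below calls self._request_data(loader);
        # presumably it pulls the next batch from self.data_iter (set up
        # in train_epoch/validate) and returns None once the iterator is
        # exhausted, which is the sentinel the calling code checks for.
        try:
            return next(self.data_iter)
        except StopIteration:
            return None
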
    def _train_multiple_steps(self, loader):
        """Train generator and discriminator for multiple steps at once"""
        last_batch = None
        max_updates = max(self.disc_updates_per_step,
                          self.gen_updates_per_step)

        # Dequeue input data upfront (this could lead to memory problems)
        batches = []
        for _ in range(max_updates):
            batch = self._request_data(loader)
            if batch is None:
                break
            batches.append(batch)

        gen_uses_feature_matching = 'FeatureMatching' in self.gen_adv_criteria
        loss_metrics = {}

        # Train discriminator
        for idx, batch in enumerate(batches[:self.disc_updates_per_step]):
            if not self.discriminator_enabled:
                continue

            last_batch = batch

            # Propagate fake image through discriminator
            gen_inp = self.train_model_input_fn(batch)
            out_gen = self.gen(*gen_inp)
            out_disc_fake = self.disc(
                self.disc_input_fn(out_gen,
                                   gen_inp[0],
                                   out_gen,
                                   is_real_input=False,
                                   detach=True))

            # Propagate real images through discriminator
            target = batch['target']
            out_disc_real = self.disc(
                self.disc_input_fn(target,
                                   gen_inp[0],
                                   out_gen,
                                   is_real_input=True,
                                   detach=True))

            disc_losses = []
            # Compute discriminator losses
            for name, criterion in self.disc_adv_criteria.items():
                loss = criterion(out_disc_fake, out_disc_real)
                disc_losses.append(loss)
                accumulate_metric(loss_metrics, 'disc_loss_' + name,
                                  get_loss_metric(loss.item()))

            # Perform discriminator update
            total_disc_loss = self._update_step(self.disc_optimizer,
                                                disc_losses,
                                                self.disc_loss_weights)
            accumulate_metric(loss_metrics, 'disc_loss',
                              get_loss_metric(total_disc_loss.item()))

            if idx < len(batches) - 1 and idx < self.disc_updates_per_step - 1:
                del out_gen
                del out_disc_real
                del out_disc_fake
            elif self.generator_enabled:
                del out_gen
                del out_disc_fake

        # Train generator
        for idx, batch in enumerate(batches[:self.gen_updates_per_step]):
            if not self.generator_enabled:
                continue

            last_batch = batch
            gen_losses = []

            gen_inp = self.train_model_input_fn(batch)
            out_gen = self.gen(*gen_inp)

            if self.discriminator_enabled:
                # Propagate again with non-detached input to allow gradients on the
                # generator
                out_disc_fake = self.disc(
                    self.disc_input_fn(out_gen,
                                       gen_inp[0],
                                       out_gen,
                                       is_real_input=False,
                                       detach=False))
                if gen_uses_feature_matching:
                    # Only need to compute the discriminator output for real targets if we
                    # use feature matching loss
                    target = batch['target']
                    out_disc_real = self.disc(
                        self.disc_input_fn(target,
                                           gen_inp[0],
                                           out_gen,
                                           is_real_input=True,
                                           detach=True))
                else:
                    out_disc_real = None

                # Compute adversarial generator losses from discriminator output
                # Order matters: first compute the adversarial generator losses,
                # then the other generator losses; otherwise the entries in
                # self.gen_loss_weights will not line up with the losses
                for name, criterion in self.gen_adv_criteria.items():
                    loss = criterion(out_disc_fake, out_disc_real)
                    gen_losses.append(loss)
                    accumulate_metric(loss_metrics, 'gen_loss_' + name,
                                      get_loss_metric(loss.item()))

            # Compute generator losses on prediction and target image
            for name, criterion in self.gen_criteria.items():
                loss = criterion(out_gen, batch)
                gen_losses.append(loss)
                accumulate_metric(loss_metrics, 'gen_loss_' + name,
                                  get_loss_metric(loss.item()))

            # Perform generator update
            total_gen_loss = self._update_step(self.gen_optimizer, gen_losses,
                                               self.gen_loss_weights)
            accumulate_metric(loss_metrics, 'gen_loss',
                              get_loss_metric(total_gen_loss.item()))

            if idx < len(batches) - 1 and idx < self.gen_updates_per_step - 1:
                del out_gen
                if self.discriminator_enabled:
                    del out_disc_fake
                    if out_disc_real is not None:
                        del out_disc_real

        # For simplicity, we just return the last batch of data
        # This is a bit of a smell, as our metrics will only be on this last batch
        # of data, whereas the loss metrics are averaged over all updates
        if len(batches) > 0:
            avg_loss_metrics = {
                name: metric.average()
                for name, metric in loss_metrics.items()
            }
            if not self.discriminator_enabled:
                out_disc_fake = None
                out_disc_real = None
            data = (last_batch, out_gen, out_disc_fake, out_disc_real)
        else:
            avg_loss_metrics = None
            data = None

        return len(batches), avg_loss_metrics, data
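
For context, here is a hypothetical driver loop showing how train_epoch and validate might be wired together with a TensorBoard writer. The names trainer, train_loader, val_loader, and num_epochs are illustrative, not part of the original code:

from itertools import chain

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/experiment')

for epoch in range(1, num_epochs + 1):
    # train_epoch expects a 1-indexed epoch for its global_step bookkeeping
    train_losses, train_metrics = trainer.train_epoch(
        train_loader,
        epoch,
        summary_writer=writer,
        steps_per_train_summary=50)

    # Keep two batches of raw model outputs around for visual inspection
    val_data, val_losses, val_metrics = trainer.validate(
        val_loader, num_batches_to_return=2)

    # validate() returns plain averaged floats, so log them directly
    for name, value in chain(val_losses.items(), val_metrics.items()):
        writer.add_scalar('val/{}'.format(name), value, epoch)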