def validate(self, loader, num_batches_to_return=0):
    epoch_data = []
    epoch_loss_metrics = {}
    epoch_metrics = {}

    self._set_test()
    self.data_iter = iter(loader)

    for current_batch in range(len(loader)):
        loss_metrics, data = self._val_step(loader)
        if data is None:
            break

        if len(epoch_data) < num_batches_to_return:
            epoch_data.append(utils.cpuify(data))

        metrics = self._compute_test_metrics(data)
        # Remove reference to ensure that GPU memory is freed
        del data

        for name, loss_metric in loss_metrics.items():
            accumulate_metric(epoch_loss_metrics, name, loss_metric)
        for name, metric in metrics.items():
            accumulate_metric(epoch_metrics, name, metric)

    value_by_loss = {name: loss_metric.average()
                     for name, loss_metric in epoch_loss_metrics.items()}
    value_by_metric = {name: metric.average()
                       for name, metric in epoch_metrics.items()}

    return epoch_data, value_by_loss, value_by_metric
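
# `accumulate_metric` and the metric objects with `.average()` used throughout
# this file are project utilities defined elsewhere. As a rough illustration of
# the accumulate-then-average pattern they provide (hypothetical names, not the
# project's actual implementation):
class _ExampleRunningMetric(object):
    """Tracks a running sum and count so an epoch average can be read off."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def add(self, value):
        self.total += value
        self.count += 1

    def average(self):
        return self.total / max(self.count, 1)


def _example_accumulate_metric(metric_dict, name, value):
    """Accumulates `value` under `name`, creating the entry on first use."""
    metric_dict.setdefault(name, _ExampleRunningMetric()).add(value)
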
def train_epoch(self, loader, epoch, summary_writer=None,
                steps_per_train_summary=1, verbose=False):
    num_batches_per_epoch = len(loader)
    epoch_loss_metrics = {}
    epoch_metrics = {}

    self._set_train()
    self.data_iter = iter(loader)

    current_batch = 0
    while current_batch < num_batches_per_epoch:
        num_batches, loss_metrics, data = self._train_step(loader)
        if num_batches == 0:
            break
        current_batch += num_batches

        metrics = self._compute_train_metrics(data)

        for name, loss_metric in loss_metrics.items():
            accumulate_metric(epoch_loss_metrics, name, loss_metric)
        for name, metric in metrics.items():
            accumulate_metric(epoch_metrics, name, metric)

        global_step = num_batches_per_epoch * (epoch - 1) + current_batch
        if global_step % steps_per_train_summary == 0:
            s = '===> Epoch[{}]({}/{}): '.format(epoch, current_batch,
                                                 num_batches_per_epoch)
            s += ', '.join('{}: {}'.format(name, loss_metric)
                           for name, loss_metric in loss_metrics.items())
            s += '\n'
            if verbose:
                s += '\n'.join(' {}: {}'.format(name, metric)
                               for name, metric in metrics.items())
            print(s)

            if summary_writer is not None:
                for name, metric in chain(loss_metrics.items(), metrics.items()):
                    summary_writer.add_scalar('train/{}'.format(name),
                                              metric.value, global_step)

    value_by_loss = {loss: loss_value.average()
                     for loss, loss_value in epoch_loss_metrics.items()}
    value_by_metric = {metric: metric_value.average()
                       for metric, metric_value in epoch_metrics.items()}

    return value_by_loss, value_by_metric
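
# Hypothetical usage sketch (names such as `trainer`, `run_training`, and the
# loader arguments are illustrative, not part of this module): an outer loop
# would typically alternate `train_epoch` and `validate`, forwarding the same
# summary writer so training and validation scalars end up side by side.
def run_training(trainer, train_loader, val_loader, num_epochs,
                 summary_writer=None):
    for epoch in range(1, num_epochs + 1):
        # `epoch` is 1-based, matching the global-step computation above
        train_losses, train_metrics = trainer.train_epoch(
            train_loader, epoch, summary_writer=summary_writer,
            steps_per_train_summary=50, verbose=False)

        # Keep one batch of validation outputs around, e.g. for visualization
        val_data, val_losses, val_metrics = trainer.validate(
            val_loader, num_batches_to_return=1)

        if summary_writer is not None:
            for name, value in list(val_losses.items()) + list(val_metrics.items()):
                # Assumes the `.average()` values above are plain scalars
                summary_writer.add_scalar('val/{}'.format(name), value, epoch)
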
def _train_multiple_steps(self, loader):
    """Train generator and discriminator for multiple steps at once"""
    last_batch = None
    max_updates = max(self.disc_updates_per_step, self.gen_updates_per_step)

    # Dequeue input data upfront (this could lead to memory problems)
    batches = []
    for _ in range(max_updates):
        batch = self._request_data(loader)
        if batch is None:
            break
        batches.append(batch)

    gen_uses_feature_matching = 'FeatureMatching' in self.gen_adv_criteria

    loss_metrics = {}

    # Train discriminator
    for idx, batch in enumerate(batches[:self.disc_updates_per_step]):
        if not self.discriminator_enabled:
            continue
        last_batch = batch

        # Propagate fake image through discriminator
        gen_inp = self.train_model_input_fn(batch)
        out_gen = self.gen(*gen_inp)
        out_disc_fake = self.disc(
            self.disc_input_fn(out_gen, gen_inp[0], out_gen,
                               is_real_input=False, detach=True))

        # Propagate real images through discriminator
        target = batch['target']
        out_disc_real = self.disc(
            self.disc_input_fn(target, gen_inp[0], out_gen,
                               is_real_input=True, detach=True))

        # Compute discriminator losses
        disc_losses = []
        for name, criterion in self.disc_adv_criteria.items():
            loss = criterion(out_disc_fake, out_disc_real)
            disc_losses.append(loss)
            accumulate_metric(loss_metrics, 'disc_loss_' + name,
                              get_loss_metric(loss.data[0]))

        # Perform discriminator update
        total_disc_loss = self._update_step(self.disc_optimizer, disc_losses,
                                            self.disc_loss_weights)
        accumulate_metric(loss_metrics, 'disc_loss',
                          get_loss_metric(total_disc_loss.data[0]))

        if idx < len(batches) - 1 and idx < self.disc_updates_per_step - 1:
            del out_gen
            del out_disc_real
            del out_disc_fake
        elif self.generator_enabled:
            del out_gen
            del out_disc_fake

    # Train generator
    for idx, batch in enumerate(batches[:self.gen_updates_per_step]):
        if not self.generator_enabled:
            continue
        last_batch = batch
        gen_losses = []

        gen_inp = self.train_model_input_fn(batch)
        out_gen = self.gen(*gen_inp)

        if self.discriminator_enabled:
            # Propagate again with non-detached input to allow gradients on
            # the generator
            out_disc_fake = self.disc(
                self.disc_input_fn(out_gen, gen_inp[0], out_gen,
                                   is_real_input=False, detach=False))

            if gen_uses_feature_matching:
                # Only need to compute the discriminator output for real
                # targets if we use feature matching loss
                target = batch['target']
                out_disc_real = self.disc(
                    self.disc_input_fn(target, gen_inp[0], out_gen,
                                       is_real_input=True, detach=True))
            else:
                out_disc_real = None

            # Compute adversarial generator losses from discriminator output.
            # Order matters: first compute adversarial losses for the
            # generator, then the other generator losses. Otherwise the loss
            # weights will not match
            for name, criterion in self.gen_adv_criteria.items():
                loss = criterion(out_disc_fake, out_disc_real)
                gen_losses.append(loss)
                accumulate_metric(loss_metrics, 'gen_loss_' + name,
                                  get_loss_metric(loss.data[0]))

        # Compute generator losses on prediction and target image
        for name, criterion in self.gen_criteria.items():
            loss = criterion(out_gen, batch)
            gen_losses.append(loss)
            accumulate_metric(loss_metrics, 'gen_loss_' + name,
                              get_loss_metric(loss.data[0]))

        # Perform generator update
        total_gen_loss = self._update_step(self.gen_optimizer, gen_losses,
                                           self.gen_loss_weights)
        accumulate_metric(loss_metrics, 'gen_loss',
                          get_loss_metric(total_gen_loss.data[0]))

        if idx < len(batches) - 1 and idx < self.gen_updates_per_step - 1:
            del out_gen
            if self.discriminator_enabled:
                del out_disc_fake
                if out_disc_real is not None:
                    del out_disc_real

    # For simplicity, we just return the last batch of data.
    # This is a bit of a smell, as our metrics will only be on this last batch
    # of data, whereas the loss metrics are averaged over all updates
    if len(batches) > 0:
        avg_loss_metrics = {name: metric.average()
                            for name, metric in loss_metrics.items()}
        if not self.discriminator_enabled:
            out_disc_fake = None
            out_disc_real = None
        data = (last_batch, out_gen, out_disc_fake, out_disc_real)
    else:
        avg_loss_metrics = None
        data = None

    return len(batches), avg_loss_metrics, data
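
# `_update_step` is defined elsewhere in the trainer; judging from how it is
# called above (an optimizer, a list of losses, and a list of loss weights)
# and from the comment that loss order must match the weights, a minimal
# sketch of such a helper might look like the following (illustrative only,
# hypothetical name; `losses` are assumed to be autograd tensors):
def _example_update_step(optimizer, losses, loss_weights):
    """Weights and sums the losses, then performs one optimizer step."""
    total_loss = sum(weight * loss
                     for weight, loss in zip(loss_weights, losses))
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    # Returning the combined loss lets the caller log it, as done above via
    # get_loss_metric(total_disc_loss.data[0])
    return total_loss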