def _train_iter(self, iter_num, policy_state, time_step):
    if not self._driver_started:
        self._driver.start()
        self._driver_started = True
    if not self._config.update_counter_every_mini_batch:
        common.get_global_counter().assign_add(1)
    with record_time("time/driver_run"):
        if iter_num == 0 and self._initial_collect_steps != 0:
            steps = 0
            while steps < self._initial_collect_steps:
                steps += self._driver.run_async()
        else:
            self._driver.run_async()
    with record_time("time/train"):
        # `train_steps` might be different from `steps`!
        train_steps = self._algorithm.train(
            num_updates=self._num_updates_per_train_step,
            mini_batch_size=self._mini_batch_size,
            mini_batch_length=self._mini_batch_length,
            whole_replay_buffer_training=self._whole_replay_buffer_training,
            clear_replay_buffer=self._clear_replay_buffer,
            update_counter_every_mini_batch=self._config.
            update_counter_every_mini_batch)
    return time_step, policy_state, train_steps

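# Illustration (standalone sketch with hypothetical return values, not part of
# the trainer): the initial-collect branch above keeps invoking the
# asynchronous driver until at least ``initial_collect_steps`` environment
# steps have been gathered, since each ``run_async()`` call reports how many
# steps it collected.
def run_async_stub():
    return 32  # pretend each call collects 32 environment steps

initial_collect_steps, steps = 100, 0
while steps < initial_collect_steps:
    steps += run_async_stub()
assert steps >= initial_collect_steps  # 128 >= 100
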
def _train_iter(self, iter_num, policy_state, time_step):
    if not self._config.update_counter_every_mini_batch:
        common.get_global_counter().assign_add(1)
    unroll_steps = self._unroll_length * self._envs[0].batch_size
    max_num_steps = unroll_steps
    if iter_num == 0 and self._initial_collect_steps != 0:
        max_num_steps = self._initial_collect_steps
    with record_time("time/driver_run"):
        for _ in range((max_num_steps + unroll_steps - 1) // unroll_steps):
            time_step, policy_state = self._driver.run(
                max_num_steps=unroll_steps,
                time_step=time_step,
                policy_state=policy_state)
    with record_time("time/train"):
        # `train_steps` might be different from `max_num_steps`!
        train_steps = self._algorithm.train(
            num_updates=self._num_updates_per_train_step,
            mini_batch_size=self._mini_batch_size,
            mini_batch_length=self._mini_batch_length,
            whole_replay_buffer_training=self._whole_replay_buffer_training,
            clear_replay_buffer=self._clear_replay_buffer,
            update_counter_every_mini_batch=self._config.
            update_counter_every_mini_batch)
    return time_step, policy_state, train_steps

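# Illustration (hypothetical numbers): the loop above runs the driver
# ceil(max_num_steps / unroll_steps) times via integer ceiling division, so
# the first iteration collects at least ``initial_collect_steps`` steps.
unroll_steps = 8 * 4   # e.g. unroll_length=8 with a batch of 4 environments
max_num_steps = 100    # e.g. initial_collect_steps=100
num_runs = (max_num_steps + unroll_steps - 1) // unroll_steps  # ceil -> 4
assert num_runs * unroll_steps >= max_num_steps
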
def _train_iter_off_policy(self):
    """User may override this for their own training procedure."""
    config: TrainerConfig = self._config

    if not config.update_counter_every_mini_batch:
        alf.summary.increment_global_counter()

    with torch.set_grad_enabled(config.unroll_with_grad):
        with record_time("time/unroll"):
            self.eval()
            experience = self.unroll(config.unroll_length)
            self.summarize_rollout(experience)
            self.summarize_metrics()

    self.train()
    steps = self.train_from_replay_buffer(update_global_counter=True)

    with record_time("time/after_train_iter"):
        train_info = experience.rollout_info
        experience = experience._replace(rollout_info=())
        if config.unroll_with_grad:
            self.after_train_iter(experience, train_info)
        else:
            self.after_train_iter(experience)  # only off-policy training

    # For now, we only return the steps of the primary algorithm's training
    return steps

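# Illustration (standalone sketch; ``Experience`` here is a stand-in for ALF's
# actual experience structure): ``experience._replace(rollout_info=())`` uses
# namedtuple-style ``_replace`` semantics -- it returns a *copy* of the
# experience with ``rollout_info`` cleared, so the rollout info can be handed
# to training separately while the rest of the experience moves on without it.
from collections import namedtuple

Experience = namedtuple('Experience', ['observation', 'rollout_info'])
exp = Experience(observation=1.0, rollout_info={'logits': [0.1, 0.9]})
train_info = exp.rollout_info        # keep the rollout info for training
exp = exp._replace(rollout_info=())  # pass the experience onward without it
assert exp.rollout_info == ()
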
def _train_iter_on_policy(self):
    """User may override this for their own training procedure."""
    alf.summary.increment_global_counter()

    with record_time("time/unroll"):
        experience = self.unroll(self._config.unroll_length)
        self.summarize_metrics()

    with record_time("time/train"):
        train_info = experience.rollout_info
        experience = experience._replace(rollout_info=())
        steps = self.train_from_unroll(experience, train_info)

    with record_time("time/after_train_iter"):
        # Here we don't pass ``train_info`` to disable another on-policy
        # training because otherwise it will backprop on the same graph
        # twice, which is unnecessary because we could have simply merged
        # the two trainings into the parent's ``rollout_step``.
        self.after_train_iter(experience)

    return steps

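# Illustration (standalone sketch) of why ``train_info`` is withheld above:
# backpropagating twice through the same graph raises a RuntimeError in
# PyTorch unless the first call used ``retain_graph=True``.
import torch

x = torch.ones(3, requires_grad=True)
y = (x * x).sum()  # mul saves its operands for the backward pass
y.backward()
try:
    y.backward()  # second backward through the same (now freed) graph
except RuntimeError as e:
    print("second backward failed:", e)
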
def _train(self):
    for env in self._envs:
        env.reset()
    time_step = self._driver.get_initial_time_step()
    policy_state = self._driver.get_initial_policy_state()
    iter_num = 0
    while True:
        t0 = time.time()
        with record_time("time/train_iter"):
            time_step, policy_state, train_steps = self._train_iter(
                iter_num=iter_num,
                policy_state=policy_state,
                time_step=time_step)
        t = time.time() - t0
        logging.log_every_n_seconds(
            logging.INFO,
            '%s time=%.3f throughput=%0.2f' % (iter_num, t,
                                               int(train_steps) / t),
            n_seconds=1)
        if (iter_num + 1) % self._checkpoint_interval == 0:
            self._save_checkpoint()
        if self._evaluate and (iter_num + 1) % self._eval_interval == 0:
            self._eval()
        if iter_num == 0:
            # We need to wait for one iteration to get the operative args.
            # For now, just use a fixed gin file name to store them.
            common.write_gin_configs(self._root_dir, "configured.gin")
            with tf.summary.record_if(True):

                def _markdownify(paragraph):
                    return "    ".join(
                        (os.linesep + paragraph).splitlines(keepends=True))

                common.summarize_gin_config()
                tf.summary.text('commandline', ' '.join(sys.argv))
                tf.summary.text(
                    'optimizers',
                    _markdownify(self._algorithm.get_optimizer_info()))
                tf.summary.text('revision', git_utils.get_revision())
                tf.summary.text('diff', _markdownify(git_utils.get_diff()))
                tf.summary.text('seed', str(self._random_seed))

        # check termination
        env_steps_metric = self._driver.get_step_metrics()[1]
        total_time_steps = env_steps_metric.result().numpy()
        iter_num += 1
        if ((self._num_iterations and iter_num >= self._num_iterations)
                or (self._num_env_steps
                    and total_time_steps >= self._num_env_steps)):
            break

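# Illustration (standalone sketch): ``_markdownify`` prefixes every line of a
# paragraph with four spaces, which TensorBoard's markdown renderer displays
# as a preformatted code block.
import os

def _markdownify(paragraph):
    return "    ".join((os.linesep + paragraph).splitlines(keepends=True))

print(_markdownify("line one\nline two"))
# prints (after a leading newline):
#     line one
#     line two
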
def _train(self):
    begin_epoch_num = int(self._trainer_progress._iter_num)
    epoch_num = begin_epoch_num

    checkpoint_interval = math.ceil(self._num_epochs / self._num_checkpoints)
    time_to_checkpoint = checkpoint_interval

    logging.info("==> Begin Training")
    while True:
        logging.info("-" * 68)
        logging.info("Epoch: {}".format(epoch_num + 1))
        with record_time("time/train_iter"):
            self._algorithm.train_iter()

        if self._evaluate and (epoch_num + 1) % self._eval_interval == 0:
            self._algorithm.evaluate()

        if epoch_num == begin_epoch_num:
            self._summarize_training_setting()

        # check termination
        epoch_num += 1
        self._trainer_progress.update(epoch_num)

        if (self._num_epochs and epoch_num >= self._num_epochs):
            if self._evaluate:
                self._algorithm.evaluate()
            break

        if self._num_epochs and epoch_num >= time_to_checkpoint:
            self._save_checkpoint()
            time_to_checkpoint += checkpoint_interval
        elif self._checkpoint_requested:
            logging.info("Saving checkpoint upon request...")
            self._save_checkpoint()
            self._checkpoint_requested = False

        if self._debug_requested:
            self._debug_requested = False
            import pdb
            pdb.set_trace()

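# Illustration (hypothetical numbers): the checkpoint cadence above spaces
# ``num_checkpoints`` saves evenly across training. With 100 epochs and 4
# checkpoints, saves are scheduled at epochs 25, 50, 75, and 100 (though the
# loop above breaks at the final epoch before the checkpoint branch runs).
import math

num_epochs, num_checkpoints = 100, 4
checkpoint_interval = math.ceil(num_epochs / num_checkpoints)  # 25
schedule = list(range(checkpoint_interval, num_epochs + 1, checkpoint_interval))
assert schedule == [25, 50, 75, 100]
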
def train_iter(self, num_particles=None, state=None):
    """Perform one epoch (iteration) of training.

    Args:
        num_particles (int): number of sampled particles. Default is None.
        state: not used

    Returns:
        int: the number of mini-batches processed in this epoch
    """
    assert self._train_loader is not None, "Must set data_loader first."
    alf.summary.increment_global_counter()
    with record_time("time/train"):
        loss = 0.
        if self._loss_type == 'classification':
            avg_acc = []
        for batch_idx, (data, target) in enumerate(self._train_loader):
            data = data.to(alf.get_default_device())
            target = target.to(alf.get_default_device())
            alg_step = self.train_step((data, target),
                                       num_particles=num_particles,
                                       state=state)
            loss_info, params = self.update_with_gradient(alg_step.info)
            loss += loss_info.extra.generator.loss
            if self._loss_type == 'classification':
                avg_acc.append(alg_step.info.extra.generator.extra)
    acc = None
    if self._loss_type == 'classification':
        acc = torch.as_tensor(avg_acc).mean() * 100
    if self._logging_training:
        if self._loss_type == 'classification':
            logging.info("Avg acc: {}".format(acc))
        logging.info("Cum loss: {}".format(loss))
    self.summarize_train(loss_info, params, cum_loss=loss, avg_acc=acc)
    return batch_idx + 1

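# Illustration (hypothetical values): the per-batch accuracies collected in
# ``avg_acc`` are averaged over the epoch and scaled to a percentage.
import torch

avg_acc = [0.91, 0.88, 0.93]                 # one entry per mini-batch
acc = torch.as_tensor(avg_acc).mean() * 100  # -> tensor(90.6667)
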
def evaluate(self, num_particles=None):
    """Evaluate on a randomly drawn ensemble.

    Args:
        num_particles (int): number of sampled particles. Default is None.
    """
    assert self._test_loader is not None, "Must set test_loader first."
    logging.info("==> Begin testing")
    if self._use_fc_bn:
        self._generator.eval()
    params = self.sample_parameters(num_particles=num_particles)
    self._param_net.set_parameters(params)
    if self._use_fc_bn:
        self._generator.train()
    with record_time("time/test"):
        if self._loss_type == 'classification':
            test_acc = 0.
        test_loss = 0.
        for i, (data, target) in enumerate(self._test_loader):
            data = data.to(alf.get_default_device())
            target = target.to(alf.get_default_device())
            output, _ = self._param_net(data)  # [B, N, D]
            loss, extra = self._vote(output, target)
            if self._loss_type == 'classification':
                test_acc += extra.item()
            test_loss += loss.loss.item()

    if self._loss_type == 'classification':
        test_acc /= len(self._test_loader.dataset)
        alf.summary.scalar(name='eval/test_acc', data=test_acc * 100)
    if self._logging_evaluate:
        if self._loss_type == 'classification':
            logging.info("Test acc: {}".format(test_acc * 100))
        logging.info("Test loss: {}".format(test_loss))
    alf.summary.scalar(name='eval/test_loss', data=test_loss)

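# Illustration (a minimal sketch, not ALF's actual ``_vote``): with ensemble
# outputs of shape [B, N, D] (batch, particles, classes), one plausible voting
# scheme takes each particle's argmax and then the mode across particles.
import torch

output = torch.randn(16, 8, 10)        # [B, N, D]
preds = output.argmax(dim=-1)          # [B, N] per-particle predictions
voted, _ = torch.mode(preds, dim=-1)   # [B] majority-vote prediction
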
def _train(self):
    for env in self._envs:
        env.reset()
    if self._eval_env:
        self._eval_env.reset()

    begin_iter_num = int(self._trainer_progress._iter_num)
    iter_num = begin_iter_num

    checkpoint_interval = math.ceil(
        (self._num_iterations or self._num_env_steps) / self._num_checkpoints)
    if self._num_iterations:
        time_to_checkpoint = (self._trainer_progress._iter_num
                              + checkpoint_interval)
    else:
        time_to_checkpoint = (self._trainer_progress._env_steps
                              + checkpoint_interval)

    while True:
        t0 = time.time()
        with record_time("time/train_iter"):
            train_steps = self._algorithm.train_iter()
        t = time.time() - t0
        logging.log_every_n_seconds(
            logging.INFO,
            '%s -> %s: %s time=%.3f throughput=%0.2f' %
            (common.get_gin_file(),
             [os.path.basename(self._root_dir.strip('/'))], iter_num, t,
             int(train_steps) / t),
            n_seconds=1)

        if self._evaluate and (iter_num + 1) % self._eval_interval == 0:
            self._eval()
        if iter_num == begin_iter_num:
            self._summarize_training_setting()

        # check termination
        env_steps_metric = self._algorithm.get_step_metrics()[1]
        total_time_steps = env_steps_metric.result()
        iter_num += 1

        self._trainer_progress.update(iter_num, total_time_steps)

        if ((self._num_iterations and iter_num >= self._num_iterations)
                or (self._num_env_steps
                    and total_time_steps >= self._num_env_steps)):
            # Evaluate before exiting so that the eval curve shown in TB
            # will align with the final iter/env_step.
            if self._evaluate:
                self._eval()
            break

        if ((self._num_iterations and iter_num >= time_to_checkpoint)
                or (self._num_env_steps
                    and total_time_steps >= time_to_checkpoint)):
            self._save_checkpoint()
            time_to_checkpoint += checkpoint_interval
        elif self._checkpoint_requested:
            logging.info("Saving checkpoint upon request...")
            self._save_checkpoint()
            self._checkpoint_requested = False

        if self._debug_requested:
            self._debug_requested = False
            import pdb
            pdb.set_trace()

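# Illustration (standalone sketch): training stops when either budget --
# iterations or environment steps -- is set and exhausted; a zero/None budget
# on one axis defers entirely to the other.
def should_stop(num_iterations, num_env_steps, iter_num, total_time_steps):
    return bool((num_iterations and iter_num >= num_iterations)
                or (num_env_steps and total_time_steps >= num_env_steps))

assert should_stop(100, 0, 100, 12345)        # iteration budget exhausted
assert not should_stop(0, 50000, 100, 12345)  # env-step budget still open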