Example #1
 def _train_iter(self, iter_num, policy_state, time_step):
     if not self._driver_started:
         self._driver.start()
         self._driver_started = True
     if not self._config.update_counter_every_mini_batch:
         common.get_global_counter().assign_add(1)
     with record_time("time/driver_run"):
         if iter_num == 0 and self._initial_collect_steps != 0:
             steps = 0
             while steps < self._initial_collect_steps:
                 steps += self._driver.run_async()
         else:
             self._driver.run_async()
     with record_time("time/train"):
         # `train_steps` might be different from `steps`!
         train_steps = self._algorithm.train(
             num_updates=self._num_updates_per_train_step,
             mini_batch_size=self._mini_batch_size,
             mini_batch_length=self._mini_batch_length,
             whole_replay_buffer_training=(
                 self._whole_replay_buffer_training),
             clear_replay_buffer=self._clear_replay_buffer,
             update_counter_every_mini_batch=(
                 self._config.update_counter_every_mini_batch))
     return time_step, policy_state, train_steps
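
The point to notice in Example #1 is the initial-collect loop: on the very first iteration the driver keeps running until at least `_initial_collect_steps` environment steps have been gathered, because each `run_async()` call only reports how many steps it happened to produce. A minimal sketch of the same accumulate-until-threshold idiom, with a hypothetical `collect_once()` standing in for `self._driver.run_async()`:

def collect_at_least(collect_once, min_steps):
    """Call `collect_once()` until at least `min_steps` steps are gathered.

    `collect_once` is assumed to return the number of environment steps it
    produced, the way `run_async()` does in Example #1.
    """
    steps = 0
    while steps < min_steps:
        steps += collect_once()
    return steps

# Hypothetical usage: each call pretends to collect 8 steps.
assert collect_at_least(lambda: 8, min_steps=100) >= 100
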
Example #2
 def _train_iter(self, iter_num, policy_state, time_step):
     if not self._config.update_counter_every_mini_batch:
         common.get_global_counter().assign_add(1)
     unroll_steps = self._unroll_length * self._envs[0].batch_size
     max_num_steps = unroll_steps
     if iter_num == 0 and self._initial_collect_steps != 0:
         max_num_steps = self._initial_collect_steps
     with record_time("time/driver_run"):
         for _ in range((max_num_steps + unroll_steps - 1) // unroll_steps):
             time_step, policy_state = self._driver.run(
                 max_num_steps=unroll_steps,
                 time_step=time_step,
                 policy_state=policy_state)
     with record_time("time/train"):
         # `train_steps` might be different from `max_num_steps`!
         train_steps = self._algorithm.train(
             num_updates=self._num_updates_per_train_step,
             mini_batch_size=self._mini_batch_size,
             mini_batch_length=self._mini_batch_length,
             whole_replay_buffer_training=(
                 self._whole_replay_buffer_training),
             clear_replay_buffer=self._clear_replay_buffer,
             update_counter_every_mini_batch=(
                 self._config.update_counter_every_mini_batch))
     return time_step, policy_state, train_steps
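
In Example #2 the number of driver runs is `(max_num_steps + unroll_steps - 1) // unroll_steps`, i.e. integer ceiling division, so the initial collection of `_initial_collect_steps` steps is carried out as whole unrolls of `unroll_steps` each. A small sketch of that arithmetic (the function name is illustrative, not from the example):

import math

def num_unrolls(max_num_steps, unroll_steps):
    # Integer ceiling division: the smallest n with n * unroll_steps >= max_num_steps.
    return (max_num_steps + unroll_steps - 1) // unroll_steps

# Same result as math.ceil, but computed without going through floats.
assert num_unrolls(1000, 64) == math.ceil(1000 / 64) == 16
assert num_unrolls(128, 64) == 2
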
Example #3
    def _train_iter_off_policy(self):
        """User may override this for their own training procedure."""
        config: TrainerConfig = self._config

        if not config.update_counter_every_mini_batch:
            alf.summary.increment_global_counter()

        with torch.set_grad_enabled(config.unroll_with_grad):
            with record_time("time/unroll"):
                self.eval()
                experience = self.unroll(config.unroll_length)
                self.summarize_rollout(experience)
                self.summarize_metrics()

        self.train()
        steps = self.train_from_replay_buffer(update_global_counter=True)

        with record_time("time/after_train_iter"):
            train_info = experience.rollout_info
            experience = experience._replace(rollout_info=())
            if config.unroll_with_grad:
                self.after_train_iter(experience, train_info)
            else:
                self.after_train_iter(experience)  # only off-policy training

        # For now, we only return the steps of the primary algorithm's training
        return steps
Example #4
    def _train_iter_on_policy(self):
        """User may override this for their own training procedure."""
        alf.summary.increment_global_counter()

        with record_time("time/unroll"):
            experience = self.unroll(self._config.unroll_length)
            self.summarize_metrics()

        with record_time("time/train"):
            train_info = experience.rollout_info
            experience = experience._replace(rollout_info=())
            steps = self.train_from_unroll(experience, train_info)

        with record_time("time/after_train_iter"):
            # Here we don't pass ``train_info`` to disable another on-policy
            # training because otherwise it will backprop on the same graph
            # twice, which is unnecessary because we could have simply merged
            # the two trainings into the parent's ``rollout_step``.
            self.after_train_iter(experience)

        return steps
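
In both Example #3 and Example #4, `experience._replace(rollout_info=())` relies on the experience being a namedtuple-like structure: `_replace` returns a copy with one field swapped out, so the (potentially large) rollout info is kept aside as `train_info` and stripped from the experience that is passed on. A tiny sketch with a hypothetical, simplified `Experience`:

from collections import namedtuple

# Hypothetical stand-in for the real experience structure.
Experience = namedtuple('Experience', ['observation', 'rollout_info'])

exp = Experience(observation=[1, 2, 3], rollout_info={'logits': [0.1, 0.9]})
train_info = exp.rollout_info        # keep the info needed for training
exp = exp._replace(rollout_info=())  # strip it from the copy passed on

assert exp.rollout_info == ()
assert train_info == {'logits': [0.1, 0.9]}
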
Example #5
    def _train(self):
        for env in self._envs:
            env.reset()
        time_step = self._driver.get_initial_time_step()
        policy_state = self._driver.get_initial_policy_state()
        iter_num = 0
        while True:
            t0 = time.time()
            with record_time("time/train_iter"):
                time_step, policy_state, train_steps = self._train_iter(
                    iter_num=iter_num,
                    policy_state=policy_state,
                    time_step=time_step)
            t = time.time() - t0
            logging.log_every_n_seconds(logging.INFO,
                                        '%s time=%.3f throughput=%0.2f' %
                                        (iter_num, t, int(train_steps) / t),
                                        n_seconds=1)
            if (iter_num + 1) % self._checkpoint_interval == 0:
                self._save_checkpoint()
            if self._evaluate and (iter_num + 1) % self._eval_interval == 0:
                self._eval()
            if iter_num == 0:
                # We need to wait for one iteration to get the operative args.
                # For now, just use a fixed gin file name to store the operative args.
                common.write_gin_configs(self._root_dir, "configured.gin")
                with tf.summary.record_if(True):

                    def _markdownify(paragraph):
                        return "    ".join(
                            (os.linesep + paragraph).splitlines(keepends=True))

                    common.summarize_gin_config()
                    tf.summary.text('commandline', ' '.join(sys.argv))
                    tf.summary.text(
                        'optimizers',
                        _markdownify(self._algorithm.get_optimizer_info()))
                    tf.summary.text('revision', git_utils.get_revision())
                    tf.summary.text('diff', _markdownify(git_utils.get_diff()))
                    tf.summary.text('seed', str(self._random_seed))

            # check termination
            env_steps_metric = self._driver.get_step_metrics()[1]
            total_time_steps = env_steps_metric.result().numpy()
            iter_num += 1
            if (self._num_iterations and iter_num >= self._num_iterations) \
                or (self._num_env_steps and total_time_steps >= self._num_env_steps):
                break
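
Stripped of framework calls, the loop in Example #5 is: run one iteration, log throughput, checkpoint and evaluate on fixed intervals, and stop once either the iteration budget or the environment-step budget is exhausted. A self-contained sketch of that control flow (every callable here is a hypothetical placeholder, not ALF API):

import time

def training_loop(train_iter, num_iterations=0, num_env_steps=0,
                  checkpoint_interval=10, eval_interval=5,
                  save_checkpoint=lambda: None, evaluate=lambda: None):
    """`train_iter()` is assumed to return (train_steps, total_env_steps)."""
    iter_num = 0
    total_env_steps = 0
    while True:
        t0 = time.time()
        train_steps, total_env_steps = train_iter()
        t = time.time() - t0
        print('%s time=%.3f throughput=%0.2f' %
              (iter_num, t, train_steps / max(t, 1e-8)))
        if (iter_num + 1) % checkpoint_interval == 0:
            save_checkpoint()
        if (iter_num + 1) % eval_interval == 0:
            evaluate()
        iter_num += 1
        if ((num_iterations and iter_num >= num_iterations)
                or (num_env_steps and total_env_steps >= num_env_steps)):
            break

# Hypothetical usage: three iterations of 100 training steps each.
training_loop(lambda: (100, 0), num_iterations=3)
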
Example #6
    def _train(self):
        begin_epoch_num = int(self._trainer_progress._iter_num)
        epoch_num = begin_epoch_num

        checkpoint_interval = math.ceil(self._num_epochs /
                                        self._num_checkpoints)
        time_to_checkpoint = checkpoint_interval

        logging.info("==> Begin Training")
        while True:
            logging.info("-" * 68)
            logging.info("Epoch: {}".format(epoch_num + 1))
            with record_time("time/train_iter"):
                self._algorithm.train_iter()

            if self._evaluate and (epoch_num + 1) % self._eval_interval == 0:
                self._algorithm.evaluate()

            if epoch_num == begin_epoch_num:
                self._summarize_training_setting()

            # check termination
            epoch_num += 1
            self._trainer_progress.update(epoch_num)

            if (self._num_epochs and epoch_num >= self._num_epochs):
                if self._evaluate:
                    self._algorithm.evaluate()
                break

            if self._num_epochs and epoch_num >= time_to_checkpoint:
                self._save_checkpoint()
                time_to_checkpoint += checkpoint_interval
            elif self._checkpoint_requested:
                logging.info("Saving checkpoint upon request...")
                self._save_checkpoint()
                self._checkpoint_requested = False

            if self._debug_requested:
                self._debug_requested = False
                import pdb
                pdb.set_trace()
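
The checkpoint cadence in Example #6 is derived from a fixed checkpoint budget: `math.ceil(num_epochs / num_checkpoints)` spreads `_num_checkpoints` saves over the run, and `time_to_checkpoint` advances by that interval each time a save happens. A small sketch of that bookkeeping (it ignores the early `break` on the final epoch, which the example handles separately):

import math

def checkpoint_epochs(num_epochs, num_checkpoints):
    """Epochs after which a checkpoint would be written (illustrative only)."""
    interval = math.ceil(num_epochs / num_checkpoints)
    time_to_checkpoint = interval
    saved_at = []
    for epoch_num in range(1, num_epochs + 1):
        if epoch_num >= time_to_checkpoint:
            saved_at.append(epoch_num)
            time_to_checkpoint += interval
    return saved_at

# 10 epochs with 3 requested checkpoints -> an interval of 4 epochs.
assert checkpoint_epochs(10, 3) == [4, 8]
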
Example #7
    def train_iter(self, num_particles=None, state=None):
        """Perform one epoch (iteration) of training.

        Args:
            num_particles (int): number of sampled particles. Default is None.
            state: not used

        Returns:
            the number of mini-batches in one epoch
        """

        assert self._train_loader is not None, "Must set data_loader first."
        alf.summary.increment_global_counter()
        with record_time("time/train"):
            loss = 0.
            if self._loss_type == 'classification':
                avg_acc = []
            for batch_idx, (data, target) in enumerate(self._train_loader):
                data = data.to(alf.get_default_device())
                target = target.to(alf.get_default_device())
                alg_step = self.train_step((data, target),
                                           num_particles=num_particles,
                                           state=state)
                loss_info, params = self.update_with_gradient(alg_step.info)
                loss += loss_info.extra.generator.loss
                if self._loss_type == 'classification':
                    avg_acc.append(alg_step.info.extra.generator.extra)
        acc = None
        if self._loss_type == 'classification':
            acc = torch.as_tensor(avg_acc).mean() * 100
        if self._logging_training:
            if self._loss_type == 'classification':
                logging.info("Avg acc: {}".format(acc))
            logging.info("Cum loss: {}".format(loss))
        self.summarize_train(loss_info, params, cum_loss=loss, avg_acc=acc)

        return batch_idx + 1
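
Example #7 is essentially a standard PyTorch data-loader epoch: move each batch to the device, take one training step, accumulate the loss (and accuracy for classification), and return the number of mini-batches as `batch_idx + 1`. A minimal sketch of the same shape in plain PyTorch, where the model, optimizer and loader are hypothetical stand-ins for the ALF objects:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = 'cpu'
model = nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(
    TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,))), batch_size=8)

def train_epoch():
    """One pass over the loader; returns the number of mini-batches."""
    total_loss = 0.
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        loss = nn.functional.cross_entropy(model(data), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return batch_idx + 1

assert train_epoch() == 4  # 32 samples / batch_size 8
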
Example #8
    def evaluate(self, num_particles=None):
        """Evaluate on a randomly drawn ensemble. 

        Args:
            num_particles (int): number of sampled particles. Default is None.
        """

        assert self._test_loader is not None, "Must set test_loader first."
        logging.info("==> Begin testing")
        if self._use_fc_bn:
            self._generator.eval()
        params = self.sample_parameters(num_particles=num_particles)
        self._param_net.set_parameters(params)
        if self._use_fc_bn:
            self._generator.train()
        with record_time("time/test"):
            if self._loss_type == 'classification':
                test_acc = 0.
            test_loss = 0.
            for i, (data, target) in enumerate(self._test_loader):
                data = data.to(alf.get_default_device())
                target = target.to(alf.get_default_device())
                output, _ = self._param_net(data)  # [B, N, D]
                loss, extra = self._vote(output, target)
                if self._loss_type == 'classification':
                    test_acc += extra.item()
                test_loss += loss.loss.item()

        if self._loss_type == 'classification':
            test_acc /= len(self._test_loader.dataset)
            alf.summary.scalar(name='eval/test_acc', data=test_acc * 100)
        if self._logging_evaluate:
            if self._loss_type == 'classification':
                logging.info("Test acc: {}".format(test_acc * 100))
            logging.info("Test loss: {}".format(test_loss))
        alf.summary.scalar(name='eval/test_loss', data=test_loss)
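
One detail worth calling out in Example #8 is the normalization: the per-batch `extra` values are summed and then divided by `len(self._test_loader.dataset)`, which only gives a correct average if `extra` is the count of correct predictions in the batch rather than a per-batch accuracy. A small sketch of that counting convention (hypothetical data, plain PyTorch):

import torch

def num_correct(logits, target):
    # Count of correct predictions in the batch (a count, not a ratio).
    return (logits.argmax(dim=-1) == target).sum().item()

logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]])
target = torch.tensor([0, 1, 1])

correct = num_correct(logits, target)  # 2 of 3 correct
accuracy = correct / target.numel()
assert abs(accuracy - 2 / 3) < 1e-6
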
Example #9
    def _train(self):
        for env in self._envs:
            env.reset()
        if self._eval_env:
            self._eval_env.reset()

        begin_iter_num = int(self._trainer_progress._iter_num)
        iter_num = begin_iter_num

        checkpoint_interval = math.ceil(
            (self._num_iterations or self._num_env_steps) /
            self._num_checkpoints)

        if self._num_iterations:
            time_to_checkpoint = self._trainer_progress._iter_num + checkpoint_interval
        else:
            time_to_checkpoint = self._trainer_progress._env_steps + checkpoint_interval

        while True:
            t0 = time.time()
            with record_time("time/train_iter"):
                train_steps = self._algorithm.train_iter()
            t = time.time() - t0
            logging.log_every_n_seconds(
                logging.INFO,
                '%s -> %s: %s time=%.3f throughput=%0.2f' %
                (common.get_gin_file(), [
                    os.path.basename(self._root_dir.strip('/'))
                ], iter_num, t, int(train_steps) / t),
                n_seconds=1)

            if self._evaluate and (iter_num + 1) % self._eval_interval == 0:
                self._eval()
            if iter_num == begin_iter_num:
                self._summarize_training_setting()

            # check termination
            env_steps_metric = self._algorithm.get_step_metrics()[1]
            total_time_steps = env_steps_metric.result()
            iter_num += 1

            self._trainer_progress.update(iter_num, total_time_steps)

            if ((self._num_iterations and iter_num >= self._num_iterations)
                    or (self._num_env_steps
                        and total_time_steps >= self._num_env_steps)):
                # Evaluate before exiting so that the eval curve shown in TB
                # will align with the final iter/env_step.
                if self._evaluate:
                    self._eval()
                break

            if ((self._num_iterations and iter_num >= time_to_checkpoint)
                    or (self._num_env_steps
                        and total_time_steps >= time_to_checkpoint)):
                self._save_checkpoint()
                time_to_checkpoint += checkpoint_interval
            elif self._checkpoint_requested:
                logging.info("Saving checkpoint upon request...")
                self._save_checkpoint()
                self._checkpoint_requested = False

            if self._debug_requested:
                self._debug_requested = False
                import pdb
                pdb.set_trace()
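
Compared with Example #5, the checkpoint cadence in Example #9 adapts to whichever budget is configured: it is measured in iterations when `_num_iterations` is set and in environment steps otherwise, with `time_to_checkpoint` advanced by the same `checkpoint_interval` in both cases. A compact sketch of that dual bookkeeping (all names are illustrative):

import math

def make_checkpoint_schedule(num_iterations, num_env_steps, num_checkpoints):
    """Returns (interval, is_due); `is_due(iter_num, env_steps, due_at)` checks
    whichever budget is actually configured."""
    interval = math.ceil((num_iterations or num_env_steps) / num_checkpoints)
    def is_due(iter_num, env_steps, due_at):
        if num_iterations:
            return iter_num >= due_at
        return env_steps >= due_at
    return interval, is_due

# Budget expressed in env steps: 1,000,000 steps, 10 checkpoints.
interval, is_due = make_checkpoint_schedule(0, 1_000_000, 10)
assert interval == 100_000
assert is_due(iter_num=37, env_steps=120_000, due_at=100_000)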