Exemplo n.º 1
0
    def _summarize_training_setting(self):
        """Record the training setup as text summaries.

        Writes the gin configuration to ``configured.gin`` under
        ``self._root_dir`` and then, inside
        ``alf.summary.record_if(lambda: True)``, emits text summaries for the
        gin config, the command line, the optimizer info, the git revision
        and diff, the random seed and, when ``self._config.code_snapshots``
        is not None, the content of each listed code file.
        """
        # We need to wait for one iteration to get the operative args.
        # Right now we just use a fixed gin file name to store them.
        common.write_gin_configs(self._root_dir, "configured.gin")
        with alf.summary.record_if(lambda: True):

            def _markdownify(paragraph):
                # Prepend a line break and put four spaces in front of every
                # line of `paragraph` (presumably so that the text is shown
                # as a Markdown code block -- confirm against the viewer).
                return "    ".join(
                    (os.linesep + paragraph).splitlines(keepends=True))

            common.summarize_gin_config()
            alf.summary.text('commandline', ' '.join(sys.argv))
            alf.summary.text(
                'optimizers',
                _markdownify(self._algorithm.get_optimizer_info()))
            alf.summary.text('revision', git_utils.get_revision())
            alf.summary.text('diff', _markdownify(git_utils.get_diff()))
            alf.summary.text('seed', str(self._random_seed))

            if self._config.code_snapshots is not None:
                # Snapshot paths are resolved relative to the parent of the
                # directory that contains this file.
                base_dir = os.path.abspath(os.path.dirname(__file__))
                for f in self._config.code_snapshots:
                    path = os.path.join(base_dir, "..", f)
                    if not os.path.isfile(path):
                        common.warning_once(
                            "The code file '%s' for summary is invalid" % path)
                        continue
                    # Read as UTF-8 regardless of the platform locale, and
                    # replace undecodable bytes instead of raising: a code
                    # snapshot is informational and must not abort training.
                    with open(
                            path, 'r', encoding='utf-8',
                            errors='replace') as fin:
                        code = fin.read()
                    # adding "<pre>" will make TB show raw text instead of MD
                    alf.summary.text('code/%s' % f, "<pre>" + code + "</pre>")
Exemplo n.º 2
0
    def _train(self):
        """Run training iterations until a configured limit is reached.

        Resets every environment in ``self._envs``, fetches the initial time
        step and policy state from ``self._driver``, and then repeatedly
        calls ``self._train_iter``. Each iteration logs its duration and
        throughput (at most once per second), saves a checkpoint every
        ``self._checkpoint_interval`` iterations and, when ``self._evaluate``
        is set, evaluates every ``self._eval_interval`` iterations. After the
        first iteration the gin config and run metadata are written as text
        summaries. The loop stops once ``self._num_iterations`` iterations
        or ``self._num_env_steps`` environment steps (whichever is set) have
        been reached.
        """
        for env in self._envs:
            env.reset()
        time_step = self._driver.get_initial_time_step()
        policy_state = self._driver.get_initial_policy_state()
        iter_num = 0
        while True:
            # perf_counter() is monotonic and high-resolution, unlike the
            # wall clock, which may be coarse or adjusted while we measure.
            t0 = time.perf_counter()
            with record_time("time/train_iter"):
                time_step, policy_state, train_steps = self._train_iter(
                    iter_num=iter_num,
                    policy_state=policy_state,
                    time_step=time_step)
            t = time.perf_counter() - t0
            # Never let a zero-length measurement crash the training loop.
            throughput = int(train_steps) / t if t > 0 else float('inf')
            logging.log_every_n_seconds(logging.INFO,
                                        '%s time=%.3f throughput=%0.2f' %
                                        (iter_num, t, throughput),
                                        n_seconds=1)
            if (iter_num + 1) % self._checkpoint_interval == 0:
                self._save_checkpoint()
            if self._evaluate and (iter_num + 1) % self._eval_interval == 0:
                self._eval()
            if iter_num == 0:
                # We need to wait for one iteration to get the operative args.
                # Right now we just use a fixed gin file name to store them.
                common.write_gin_configs(self._root_dir, "configured.gin")
                with tf.summary.record_if(True):

                    def _markdownify(paragraph):
                        # Prepend a line break and put four spaces in front
                        # of every line of `paragraph`.
                        return "    ".join(
                            (os.linesep + paragraph).splitlines(keepends=True))

                    common.summarize_gin_config()
                    tf.summary.text('commandline', ' '.join(sys.argv))
                    tf.summary.text(
                        'optimizers',
                        _markdownify(self._algorithm.get_optimizer_info()))
                    tf.summary.text('revision', git_utils.get_revision())
                    tf.summary.text('diff', _markdownify(git_utils.get_diff()))
                    tf.summary.text('seed', str(self._random_seed))

            # check termination
            env_steps_metric = self._driver.get_step_metrics()[1]
            total_time_steps = env_steps_metric.result().numpy()
            iter_num += 1
            if (self._num_iterations and iter_num >= self._num_iterations) \
                or (self._num_env_steps and total_time_steps >= self._num_env_steps):
                break
Exemplo n.º 3
0
    def _summarize_training_setting(self):
        """Write the gin config file and emit text summaries of the run setup.

        The gin configuration is stored as ``configured.gin`` under
        ``self._root_dir``. Then, inside
        ``alf.summary.record_if(lambda: True)``, the gin config, command
        line, optimizer info, git revision, git diff and random seed are
        written as text summaries, in that order.
        """

        def _indent_block(text):
            # Prepend a line break, then put four spaces in front of every
            # line of `text`.
            pieces = (os.linesep + text).splitlines(keepends=True)
            return "    ".join(pieces)

        # The operative args are only available after one iteration; they
        # are stored under a fixed gin file name.
        common.write_gin_configs(self._root_dir, "configured.gin")

        with alf.summary.record_if(lambda: True):
            common.summarize_gin_config()
            # Each builder is invoked right before its summary is written so
            # that the values are produced in the listed order.
            summaries = (
                ('commandline', lambda: ' '.join(sys.argv)),
                ('optimizers',
                 lambda: _indent_block(self._algorithm.get_optimizer_info())),
                ('revision', lambda: git_utils.get_revision()),
                ('diff', lambda: _indent_block(git_utils.get_diff())),
                ('seed', lambda: str(self._random_seed)),
            )
            for tag, build_text in summaries:
                alf.summary.text(tag, build_text())
Exemplo n.º 4
0
    def _train(self):
        """Run training iterations until a configured limit is reached.

        Resets every environment in ``self._envs``, fetches the initial time
        step and policy state from ``self._driver``, and then repeatedly
        calls ``self.train_iter``. Each iteration logs its duration and
        throughput (at most once per second), records the duration as the
        ``time/train_iter`` scalar summary, saves a checkpoint every
        ``self._checkpoint_interval`` iterations and, when ``self._evaluate``
        is set, evaluates every ``self._eval_interval`` iterations. After the
        first iteration the gin config, command line and optimizer info are
        written as summaries. The loop stops once ``self._num_iterations``
        iterations or ``self._num_env_steps`` environment steps (whichever
        is set) have been reached.
        """
        for env in self._envs:
            env.reset()
        time_step = self._driver.get_initial_time_step()
        policy_state = self._driver.get_initial_policy_state()
        iter_num = 0
        while True:
            # perf_counter() is monotonic and high-resolution, unlike the
            # wall clock, which may be coarse or adjusted while we measure.
            t0 = time.perf_counter()
            time_step, policy_state, train_steps = self.train_iter(
                iter_num=iter_num,
                policy_state=policy_state,
                time_step=time_step)
            t = time.perf_counter() - t0
            # Never let a zero-length measurement crash the training loop.
            throughput = int(train_steps) / t if t > 0 else float('inf')
            logging.log_every_n_seconds(logging.INFO,
                                        '%s time=%.3f throughput=%0.2f' %
                                        (iter_num, t, throughput),
                                        n_seconds=1)
            tf.summary.scalar("time/train_iter", t)
            if (iter_num + 1) % self._checkpoint_interval == 0:
                self._save_checkpoint()
            if self._evaluate and (iter_num + 1) % self._eval_interval == 0:
                self._eval()
            if iter_num == 0:
                # We need to wait for one iteration to get the operative args.
                # Right now we just use a fixed gin file name to store them.
                common.write_gin_configs(self._root_dir, "configured.gin")
                with tf.summary.record_if(True):
                    common.summarize_gin_config()
                    tf.summary.text('commandline', ' '.join(sys.argv))
                    tf.summary.text('optimizers',
                                    self._algorithm.get_optimizer_info())

            # check termination
            env_steps_metric = self._driver.get_step_metrics()[1]
            total_time_steps = env_steps_metric.result().numpy()
            iter_num += 1
            if (self._num_iterations and iter_num >= self._num_iterations) \
                or (self._num_env_steps and total_time_steps >= self._num_env_steps):
                break