    def _update_target_networks(self):
        if self.use_soft_update:
            # Soft (Polyak) update: blend a fraction tau of the live networks
            # into the target networks every training step.
            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf, self.target_qf, self.tau)
        else:
            # Hard update: copy parameters outright every
            # `target_hard_update_period` environment steps.
            if self._n_env_steps_total % self.target_hard_update_period == 0:
                ptu.copy_model_params_from_to(self.qf, self.target_qf)
                ptu.copy_model_params_from_to(self.policy, self.target_policy)
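
All of these examples lean on two small helpers from rlkit's `pytorch_util` module (imported as `ptu`): `soft_update_from_to` and `copy_model_params_from_to`. Their definitions are not part of this listing; the sketch below shows what they are generally understood to do (Polyak averaging and a plain parameter copy) and should be read as an illustration rather than the library's exact source.

import torch.nn as nn


def soft_update_from_to(source: nn.Module, target: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- tau * source + (1 - tau) * target.
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + param.data * tau
        )


def copy_model_params_from_to(source: nn.Module, target: nn.Module) -> None:
    # Hard update: overwrite the target's parameters with the source's.
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)

With a small `tau` (e.g. 0.001 to 0.005), the target network trails the live network slowly, which is why the soft path runs on every call while the hard path only fires every `target_hard_update_period` steps.
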
Example #2
    def _update_target_networks(self):
        if self.use_soft_update:
            ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)
        else:
            if self._n_train_steps_total % self.target_hard_update_period == 0:
                ptu.copy_model_params_from_to(self.qf1, self.target_qf1)
                ptu.copy_model_params_from_to(self.qf2, self.target_qf2)
Example #3
    def _do_training(self):
        batch = self.get_batch()
        """
        Optimize Critic/Actor.
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        _, _, v_pred = self.target_policy(next_obs, None)
        y_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * v_pred
        y_target = y_target.detach()
        mu, y_pred, v = self.policy(obs, actions)
        policy_loss = self.policy_criterion(y_pred, y_target)

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        """
        Update Target Networks
        """
        if self.use_soft_update:
            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
        else:
            if self._n_train_steps_total % self.target_hard_update_period == 0:
                ptu.copy_model_params_from_to(self.policy, self.target_policy)

        if self.need_to_update_eval_statistics:
            self.need_to_update_eval_statistics = False
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy v',
                    ptu.get_numpy(v),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(mu),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Y targets',
                    ptu.get_numpy(y_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Y predictions',
                    ptu.get_numpy(y_pred),
                ))
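
The target in Example #3, `reward_scale * rewards + (1. - terminals) * discount * v_pred`, is a standard bootstrapped (Bellman-style) target, detached so that no gradient flows back through the target network. The tiny standalone sketch below, with made-up numbers, shows how the `(1. - terminals)` mask drops the bootstrap term on terminal transitions.

import torch

# Hypothetical batch of 3 transitions (values chosen only to show the shapes).
rewards = torch.tensor([[1.0], [0.5], [0.0]])
terminals = torch.tensor([[0.0], [0.0], [1.0]])
v_pred = torch.tensor([[2.0], [1.0], [3.0]])
reward_scale, discount = 1.0, 0.99

y_target = (reward_scale * rewards
            + (1. - terminals) * discount * v_pred).detach()
print(y_target)  # tensor([[2.9800], [1.4900], [0.0000]])
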
Example #4
    def _do_training(self, n_steps_total):
        raw_subtraj_batch, start_indices = (
            self.replay_buffer.train_replay_buffer.random_subtrajectories(
                self.num_subtrajs_per_batch))
        subtraj_batch = create_torch_subtraj_batch(raw_subtraj_batch)
        if self.save_memory_gradients:
            subtraj_batch['memories'].requires_grad = True
        self.train_critic(subtraj_batch)
        self.train_policy(subtraj_batch, start_indices)
        if self.use_soft_update:
            ptu.soft_update_from_to(self.policy, self.target_policy, self.tau)
            ptu.soft_update_from_to(self.qf, self.target_qf, self.tau)
        else:
            if n_steps_total % self.target_hard_update_period == 0:
                ptu.copy_model_params_from_to(self.qf, self.target_qf)
                ptu.copy_model_params_from_to(self.policy, self.target_policy)
Example #5
    def _update_target_networks(self):
        for cg1, target_cg1, qf1, target_qf1, cg2, target_cg2, qf2, target_qf2 in \
                zip(self.cg1_n, self.target_cg1_n, self.qf1_n, self.target_qf1_n,
                    self.cg2_n, self.target_cg2_n, self.qf2_n, self.target_qf2_n):
            if self.use_soft_update:
                ptu.soft_update_from_to(cg1, target_cg1, self.tau)
                ptu.soft_update_from_to(qf1, target_qf1, self.tau)
                ptu.soft_update_from_to(cg2, target_cg2, self.tau)
                ptu.soft_update_from_to(qf2, target_qf2, self.tau)
            else:
                if self._n_train_steps_total % self.target_hard_update_period == 0:
                    ptu.copy_model_params_from_to(cg1, target_cg1)
                    ptu.copy_model_params_from_to(qf1, target_qf1)
                    ptu.copy_model_params_from_to(cg2, target_cg2)
                    ptu.copy_model_params_from_to(qf2, target_qf2)
Example #6
    def _update_target_networks(self):
        for policy, target_policy, qf, target_qf in \
                zip(self.policy_n, self.target_policy_n, self.qf_n, self.target_qf_n):
            if self.use_soft_update:
                ptu.soft_update_from_to(policy, target_policy, self.tau)
                ptu.soft_update_from_to(qf, target_qf, self.tau)
            else:
                if self._n_train_steps_total % self.target_hard_update_period == 0:
                    ptu.copy_model_params_from_to(qf, target_qf)
                    ptu.copy_model_params_from_to(policy, target_policy)
        if self.double_q:
            for qf2, target_qf2 in zip(self.qf2_n, self.target_qf2_n):
                if self.use_soft_update:
                    ptu.soft_update_from_to(qf2, target_qf2, self.tau)
                else:
                    if self._n_train_steps_total % self.target_hard_update_period == 0:
                        ptu.copy_model_params_from_to(qf2, target_qf2)
Example #7
    def _update_target_network(self):
        # Note: unlike the examples above, this variant keys off `use_hard_updates`,
        # so the soft (Polyak) update is the `else` branch.
        if self.use_hard_updates:
            if self._n_train_steps_total % self.hard_update_period == 0:
                ptu.copy_model_params_from_to(self.qf, self.target_qf)
        else:
            ptu.soft_update_from_to(self.qf, self.target_qf, self.tau)
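
For a quick sanity check of the two update modes, the snippet below exercises the sketched helpers from earlier on two throwaway networks; `qf` and `target_qf` here are hypothetical stand-ins, not objects from the examples above.

import torch
import torch.nn as nn

qf = nn.Linear(4, 1)          # stand-in for a live Q-network
target_qf = nn.Linear(4, 1)   # stand-in for its (independently initialized) target

# Soft update: nudge target_qf a small step (tau) toward qf.
soft_update_from_to(qf, target_qf, 0.001)

# Hard update: afterwards the two networks hold identical parameters.
copy_model_params_from_to(qf, target_qf)
assert all(torch.equal(p, tp)
           for p, tp in zip(qf.parameters(), target_qf.parameters()))
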
Example #8
    def copy(self):
        copy = Serializable.clone(self)
        ptu.copy_model_params_from_to(self, copy)
        return copy
Example #9
    def pretrain_policy_with_bc(self):
        if self.buffer_for_bc_training == "demos":
            self.bc_training_buffer = self.demo_train_buffer
            self.bc_test_buffer = self.demo_test_buffer
        elif self.buffer_for_bc_training == "replay_buffer":
            self.bc_training_buffer = self.replay_buffer.train_replay_buffer
            self.bc_test_buffer = self.replay_buffer.validation_replay_buffer
        else:
            self.bc_training_buffer = None
            self.bc_test_buffer = None

        if self.load_policy_path:
            self.policy = load_local_or_remote_file(self.load_policy_path)
            ptu.copy_model_params_from_to(self.policy, self.target_policy)
            return

        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_policy.csv',
                                  relative_to_snapshot_dir=True)
        if self.do_pretrain_rollouts:
            total_ret = self.do_rollouts()
            print("INITIAL RETURN", total_ret / 20)

        prev_time = time.time()
        for i in range(self.bc_num_pretrain_steps):
            train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = self.run_bc_batch(
                self.demo_train_buffer, self.policy)
            train_policy_loss = train_policy_loss * self.bc_weight

            self.policy_optimizer.zero_grad()
            train_policy_loss.backward()
            self.policy_optimizer.step()

            test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = self.run_bc_batch(
                self.demo_test_buffer, self.policy)
            test_policy_loss = test_policy_loss * self.bc_weight

            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret / 20))

            if i % self.pretraining_logging_period == 0:
                stats = {
                    "pretrain_bc/batch": i,
                    "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                    "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                    "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                    "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                    "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                    "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                    "pretrain_bc/epoch_time": time.time() - prev_time,
                }

                if self.do_pretrain_rollouts:
                    stats["pretrain_bc/avg_return"] = total_ret / 20

                logger.record_dict(stats)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                pickle.dump(self.policy,
                            open(logger.get_snapshot_dir() + '/bc.pkl', "wb"))
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_policy.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        ptu.copy_model_params_from_to(self.policy, self.target_policy)

        if self.post_bc_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
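
`run_bc_batch` is not shown in Example #9; judging from how its four return values are used, it samples a batch from the given buffer and returns a behavior-cloning loss together with log-probability and MSE diagnostics. A rough, self-contained sketch of that idea follows, with every name, shape, and the fixed-variance Gaussian policy assumption made up purely for illustration.

import torch
import torch.nn as nn
from torch.distributions import Normal


def run_bc_batch_sketch(policy_mean_net, obs, actions, log_std=-1.0):
    # Hypothetical BC step: maximize the likelihood of demo actions under a
    # fixed-variance Gaussian policy whose mean comes from policy_mean_net.
    mean = policy_mean_net(obs)
    dist = Normal(mean, torch.exp(torch.tensor(log_std)))
    logp_loss = -dist.log_prob(actions).sum(dim=-1).mean()  # negative log-likelihood
    mse_loss = ((mean - actions) ** 2).mean()                # diagnostic only
    return logp_loss, logp_loss, mse_loss, log_std


# Usage with dummy data:
policy_mean_net = nn.Linear(11, 3)
obs, actions = torch.randn(32, 11), torch.randn(32, 3)
policy_loss, logp_loss, mse_loss, _ = run_bc_batch_sketch(policy_mean_net, obs, actions)
policy_loss.backward()
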