Example #1
 def train(self, np_batch, np_demo_batch=None):
     self._num_train_steps += 1
     batch = np_to_pytorch_batch(np_batch)
     if np_demo_batch is not None:
         demo_batch = np_to_pytorch_batch(np_demo_batch)
         self.train_from_torch(batch, demo_batch)
     else:
         self.train_from_torch(batch)
Example #2
 def train(self, np_batch, np_batch_dead):
     self._num_train_steps += 1
     # for k in np_batch.keys():
     #     print(k)
     np_full_batch = {
         k: np.concatenate((np_batch[k], np_batch_dead[k]))
         for k in np_batch.keys()
     }
     full_batch = np_to_pytorch_batch(np_full_batch)
     batch_dead = np_to_pytorch_batch(np_batch_dead)
     self.train_from_torch(full_batch, batch_dead)
Example #3
    def random_batch(self, batch_size):
        env_i = np.random.choice(self.num_envs, batch_size)
        trans_i = np.random.choice(self.sample_size, batch_size)

        match_i = np.random.choice(self.num_envs, batch_size)
        trans_x = np.random.choice(self.sample_size, batch_size)
        trans_y = np.random.choice(self.sample_size, batch_size)

        rand_a = np.random.choice(self.num_envs - 1, batch_size // 2)
        rand_b = np.add(rand_a, np.ones(batch_size // 2)).astype(int)

        trans_m = np.random.choice(self.sample_size, batch_size // 2)
        trans_n = np.random.choice(self.sample_size, batch_size // 2)

        matches = np.random.uniform(0, 0.1, batch_size // 2)
        nonmatches = np.random.uniform(0.9, 1, batch_size // 2)
        swap_count = int(batch_size * 0.05)

        matches[:swap_count], nonmatches[:swap_count] = nonmatches[:swap_count], matches[:swap_count]
        labels = np.concatenate([matches, nonmatches])

        data_dict = {
            'observations': self.data['observations'][env_i, trans_i, :],
            'env_set_1': self.data['observations'][match_i, trans_x, :],
            'env_set_2': self.data['observations'][match_i, trans_y, :],
        }

        return np_to_pytorch_batch(data_dict)
Example #4
 def train(self, np_batch):
     # 1. increment the training step counter
     self._num_train_steps += 1
     # 2. convert the data to torch format
     batch = np_to_pytorch_batch(np_batch)
     # 3. train
     self.train_from_torch(batch)
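Every example here converts its numpy batch with np_to_pytorch_batch before calling train_from_torch, but the helper itself is not shown. Below is a minimal sketch of what such a converter could look like, assuming it only has to turn a flat dict of numpy arrays into float32 torch tensors; the device argument and dtype handling are illustrative assumptions, not the actual signature.

import numpy as np
import torch

def np_to_pytorch_batch(np_batch, device="cpu"):
    """Convert a dict of numpy arrays into a dict of torch tensors (sketch).

    Object-dtype entries (e.g. nested env-info dicts) are skipped, and
    float64 arrays are downcast to float32 to match typical model weights.
    """
    torch_batch = {}
    for k, v in np_batch.items():
        if not isinstance(v, np.ndarray) or v.dtype == np.dtype('O'):
            continue  # keep only plain numeric arrays
        t = torch.from_numpy(v)
        if t.dtype == torch.float64:
            t = t.float()
        torch_batch[k] = t.to(device)
    return torch_batch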
Example #5
 def get_batch(self, batch_size, from_expert, keys=None):
     if from_expert:
         buffer = self.expert_replay_buffer
     else:
         buffer = self.replay_buffer
     batch = buffer.random_batch(batch_size, keys=keys)
     batch = np_to_pytorch_batch(batch)
     return batch
Example #6
 def get_batch(self, batch_size, keys=None, use_expert_buffer=True):
     if use_expert_buffer:
         rb = self.expert_replay_buffer
     else:
         rb = self.replay_buffer
     batch = rb.random_batch(batch_size, keys=keys)
     batch = np_to_pytorch_batch(batch)
     return batch
Example #7
 def random_batch(self, batch_size):
     i = np.random.choice(self.size, batch_size, replace=(self.size < batch_size))
     obs = self.data[i, :]
     if self.normalize:
         obs = normalize_image(obs)
     data_dict = {
         'observations': obs,
     }
     return np_to_pytorch_batch(data_dict)
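This buffer, like the trajectory buffers in Examples #23 and #25, passes observations through normalize_image before the torch conversion. That helper is not shown; a plausible sketch, assuming it simply rescales uint8 pixel values to floats in [0, 1]:

import numpy as np

def normalize_image(image):
    """Rescale uint8 pixel values to floats in [0, 1] (sketch of an assumed helper)."""
    assert image.dtype == np.uint8
    return image.astype(np.float64) / 255.0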
Example #8
 def get_test_batch(self):
     batch = self.test_replay_buffer.random_batch(self.batch_size)
     batch = np_to_pytorch_batch(batch)
     obs = batch['observations']
     next_obs = batch['next_observations']
     goals = batch['resampled_goals']
     batch['observations'] = torch.cat((obs, goals), dim=1)
     batch['next_observations'] = torch.cat((next_obs, goals), dim=1)
     return batch
Example #9
    def train(self, np_batch_dict):
        """
        :param np_batch_dict: dict with 'safe' and 'danger' keys, each mapping to a numpy batch
        :return:
        """
        self._num_train_steps += 1

        # concatenate the safe and danger transitions into one full batch
        np_full_batch = {
            k: np.concatenate((np_batch_dict['safe'][k], np_batch_dict['danger'][k]))
            for k in np_batch_dict['safe'].keys()
        }

        # convert to pytorch
        torch_batch_full = np_to_pytorch_batch(np_full_batch)
        torch_batch_danger = np_to_pytorch_batch(np_batch_dict['danger'])

        # pack it to dict
        torch_batch_dict = {'full': torch_batch_full, 'danger': torch_batch_danger}
        self.train_from_torch(torch_batch_dict)
Example #10
    def random_batch(self, batch_size):
        traj_i = np.random.choice(np.arange(self.size), batch_size)
        trans_i = np.random.choice(np.arange(self.traj_length - 1), batch_size)
        data_dict = {
            'observations': self.data['observations'][traj_i, trans_i, :],
            'next_observations': self.data['observations'][traj_i, trans_i + 1, :],
            'actions': self.data['actions'][traj_i, trans_i, :],
        }

        return np_to_pytorch_batch(data_dict)
Example #11
    def get_batch_from_buffer(replay_buffer, batch_size):
        """

        :param replay_buffer:
        :param batch_size:
        :return:
        """
        batch = replay_buffer.random_batch(batch_size)
        batch = np_to_pytorch_batch(batch)
        return batch
Example #12
 def get_batch(self, batch_size, from_target_state_buffer, keys=None):
     if from_target_state_buffer:
         buffer = self.target_state_buffer
         batch = {
             'observations':
             buffer[np.random.choice(buffer.shape[0], size=batch_size)]
         }
     else:
         buffer = self.replay_buffer
         batch = buffer.random_batch(batch_size, keys=keys)
     batch = np_to_pytorch_batch(batch)
     return batch
Example #13
    def random_batch(self, batch_size):
        num_traj = self.replay_buffer._size // self.horizon
        traj_i = np.random.choice(num_traj, batch_size)
        trans_i = np.random.choice(self.horizon - 2, batch_size)

        indices = traj_i * self.horizon + trans_i
        batch = dict(
            x0=self.replay_buffer._obs["image_observation"][indices],
            x1=self.replay_buffer._obs["image_observation"][indices+1],
            x2=self.replay_buffer._obs["image_observation"][indices+2],
        )
        return np_to_pytorch_batch(batch)
Example #14
 def get_batch(self):
     sample_size = self.batch_size // 2
     batch1 = self.replay_buffer1().random_batch(sample_size)
     batch2 = self.replay_buffer2().random_batch(sample_size)
     new_batch = {}
     for k, v in batch1.items():
         new_batch[k] = np.concatenate(
             (
                 v,
                 batch2[k]
             ),
             axis=0,
         )
     return np_to_pytorch_batch(new_batch)
Example #15
 def _statistics_from_paths(self, paths, stat_prefix):
     rewards, terminals, obs, actions, next_obs = split_paths(paths)
     np_batch = dict(
         rewards=rewards,
         terminals=terminals,
         observations=obs,
         actions=actions,
         next_observations=next_obs,
     )
     batch = np_to_pytorch_batch(np_batch)
     statistics = self._statistics_from_batch(batch, stat_prefix)
     statistics.update(create_stats_ordered_dict(
         'Num Paths', len(paths), stat_prefix=stat_prefix
     ))
     return statistics
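The statistics here (and in Examples #21 and #27) are built with create_stats_ordered_dict. A simplified sketch of such a helper, assuming it reduces an array to mean/std/max/min entries under a given name; the real version likely supports extra options.

from collections import OrderedDict

import numpy as np

def create_stats_ordered_dict(name, data, stat_prefix=None):
    """Summarize an array as an ordered dict of basic statistics (sketch)."""
    if stat_prefix is not None:
        name = "{}{}".format(stat_prefix, name)
    data = np.asarray(data)
    return OrderedDict([
        (name + ' Mean', np.mean(data)),
        (name + ' Std', np.std(data)),
        (name + ' Max', np.max(data)),
        (name + ' Min', np.min(data)),
    ])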
Example #16
 def get_batch_from_buffer(self, replay_buffer):
     batch = replay_buffer.random_batch(self.bc_batch_size)
     batch = np_to_pytorch_batch(batch)
     # obs = batch['observations']
     # next_obs = batch['next_observations']
     # goals = batch['resampled_goals']
     # import ipdb; ipdb.set_trace()
     # batch['observations'] = torch.cat((
     #     obs,
     #     goals
     # ), dim=1)
     # batch['next_observations'] = torch.cat((
     #     next_obs,
     #     goals
     # ), dim=1)
     return batch
Example #17
    def pretrain_q_with_bc_data(self, batch_size):
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.num_pretrain_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data)
            if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
                total_ret = self.do_rollouts()
                print("Return at step {} : {}".format(i, total_ret / 20))

            if i % self.pretraining_logging_period == 0:
                if self.do_pretrain_rollouts:
                    self.eval_statistics[
                        "pretrain_bc/avg_return"] = total_ret / 20
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                logger.record_dict(self.eval_statistics)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()
Example #18
    def get_batch(self):
        batch = self.replay_buffer.random_batch_random_tau(
            self.batch_size, self.max_tau)
        """
        Update the goal states/rewards
        """
        num_steps_left = self._sample_taus_for_training(batch)
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        goals = batch['training_goals']
        rewards = self._compute_rewards_np(batch, obs, actions, next_obs,
                                           goals)
        terminals = batch['terminals']

        # 'all_valid' tau sampling: repeat each transition once for every tau in [0, max_tau]
        if self.tau_sample_strategy == 'all_valid':
            obs = np.repeat(obs, self.max_tau + 1, 0)
            actions = np.repeat(actions, self.max_tau + 1, 0)
            next_obs = np.repeat(next_obs, self.max_tau + 1, 0)
            goals = np.repeat(goals, self.max_tau + 1, 0)
            rewards = np.repeat(rewards, self.max_tau + 1, 0)
            terminals = np.repeat(terminals, self.max_tau + 1, 0)

        if self.finite_horizon:
            terminals = 1 - (1 - terminals) * (num_steps_left != 0)
        if self.terminate_when_goal_reached:
            diff = self.env.convert_obs_to_goals(next_obs) - goals
            goal_not_reached = (np.linalg.norm(diff, axis=1, keepdims=True) >
                                self.goal_reached_epsilon)
            terminals = 1 - (1 - terminals) * goal_not_reached

        if not self.dense_rewards:
            rewards = rewards * terminals
        """
        Update the batch
        """
        batch['rewards'] = rewards
        batch['terminals'] = terminals
        batch['actions'] = actions
        batch['num_steps_left'] = num_steps_left
        batch['goals'] = goals
        batch['observations'] = obs
        batch['next_observations'] = next_obs

        return np_to_pytorch_batch(batch)
Example #19
    def get_batch(self, training=True):
        if self.replay_buffer_is_split:
            replay_buffer = self.replay_buffer.get_replay_buffer(training)
        else:
            replay_buffer = self.replay_buffer
        batch = replay_buffer.random_batch(self.batch_size)

        """
        Update the goal states/rewards
        """
        num_steps_left = np.random.randint(
            0, self.max_tau + 1, (self.batch_size, 1)
        )
        terminals = 1 - (1 - batch['terminals']) * (num_steps_left != 0)
        batch['terminals'] = terminals

        obs = batch['observations']
        next_obs = batch['next_observations']
        if self.sample_train_goals_from == 'her':
            goals = batch['goals']
        else:
            goals = self._sample_goals_for_training()
        goal_differences = np.abs(
            self.env.convert_obs_to_goals(next_obs)
            # - self.env.convert_obs_to_goals(obs)
            - goals
        )
        batch['goal_differences'] = goal_differences * self.reward_scale
        batch['goals'] = goals

        """
        Update the observations
        """
        batch['observations'] = merge_into_flat_obs(
            obs=batch['observations'],
            goals=batch['goals'],
            num_steps_left=num_steps_left,
        )
        batch['next_observations'] = merge_into_flat_obs(
            obs=batch['next_observations'],
            goals=batch['goals'],
            num_steps_left=num_steps_left-1,
        )

        return np_to_pytorch_batch(batch)
Example #20
 def train(self):
     # first train only the Q function
     iteration = 0
     for i in range(self.num_batches):
         train_data = self.replay_buffer.random_batch(self.batch_size)
         train_data = np_to_pytorch_batch(train_data)
         obs = train_data['observations']
         next_obs = train_data['next_observations']
         train_data['observations'] = obs
         train_data['next_observations'] = next_obs
         self.trainer.train_from_torch(train_data)
         if i % self.logging_period == 0:
             stats_with_prefix = add_prefix(
                 self.trainer.eval_statistics, prefix="trainer/")
             self.trainer.end_epoch(iteration)
             iteration += 1
             logger.record_dict(stats_with_prefix)
             logger.dump_tabular(with_prefix=True, with_timestamp=False)
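The logging path above namespaces trainer statistics with add_prefix. A minimal sketch under the assumption that it only prepends a prefix string to every key:

from collections import OrderedDict

def add_prefix(log_dict, prefix, divider=''):
    """Return a copy of log_dict with every key prefixed, e.g. 'trainer/QF1 Loss' (sketch)."""
    with_prefix = OrderedDict()
    for key, val in log_dict.items():
        with_prefix[prefix + divider + key] = val
    return with_prefix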
Example #21
    def train(self, np_batch):
        batch = np_to_pytorch_batch(np_batch)
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Compute loss
        """

        best_action_idxs = self.qf(next_obs).max(1, keepdim=True)[1]
        target_q_values = self.target_qf(next_obs).gather(
            1, best_action_idxs).detach()
        y_target = rewards + (1. - terminals) * self.discount * target_q_values
        y_target = y_target.detach()
        # actions is a one-hot vector
        y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
        qf_loss = self.qf_criterion(y_pred, y_target)
        """
        Update networks
        """
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()
        """
        Soft target network updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf, self.target_qf,
                                    self.soft_target_tau)
        """
        Save some statistics for eval using just one batch.
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Y Predictions',
                    ptu.get_numpy(y_pred),
                ))
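The periodic target update above goes through ptu.soft_update_from_to. A minimal sketch of a Polyak-averaging update with that interface, assuming source and target are torch modules with matching parameter order:

import torch

def soft_update_from_to(source, target, tau):
    """Polyak update (sketch): target <- (1 - tau) * target + tau * source."""
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - tau) + param.data * tau
            )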
Example #22
    def pretrain_q_with_bc_data(self, batch_size):
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        prev_time = time.time()
        for i in range(self.num_pretrain_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)

            if i % self.pretraining_logging_period == 0:
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()
Example #23
    def random_batch(self, batch_size):
        traj_i = np.random.choice(self.size, batch_size)
        trans_i = np.random.choice(self.traj_length, batch_size)
        # conditioning = np.random.choice(self.traj_length, batch_size)
        # env = normalize_image(self.data['observations'][traj_i, conditioning, :])
        try:
            env = normalize_image(self.data['env'][traj_i, :])
        except KeyError:
            # fall back to the first frame of each trajectory when no 'env' key is stored
            env = normalize_image(self.data['observations'][traj_i, 0, :])
        x_t = normalize_image(self.data['observations'][traj_i, trans_i, :])

        episode_num = np.random.randint(0, self.size)
        episode_obs = normalize_image(self.data['observations'][episode_num, :8, :])


        data_dict = {
            'x_t': x_t,
            'env': env,
            'episode_obs': episode_obs,
        }
        return np_to_pytorch_batch(data_dict)
Example #24
 def fix_data_set(self):
     for training in [True, False]:
         replay_buffer = self.replay_buffer.get_replay_buffer(training)
         batch_dict = {}
         for i in range(self.num_unique_batches):
             batch_size = min(
                 replay_buffer.num_steps_can_sample(),
                 self.batch_size
             )
             batch = replay_buffer.random_batch(batch_size)
             goal_states = self.sample_goal_states(batch_size, training)
             new_rewards = self.env.compute_rewards(
                 batch['observations'],
                 batch['actions'],
                 batch['next_observations'],
                 goal_states,
             )
             batch['goal_states'] = goal_states
             batch['rewards'] = new_rewards
             torch_batch = np_to_pytorch_batch(batch)
             batch_dict[i] = torch_batch
         self.mode_to_batch_iterator[training] = create_batch_iterator(
             self.num_unique_batches, batch_dict
         )
Example #25
    def random_batch(self, batch_size):
        traj_i = np.random.choice(self.size, batch_size)
        trans_i = np.random.choice(self.traj_length - 1, batch_size)

        try:
            env = normalize_image(self.data['env'][traj_i, :])
        except KeyError:
            # fall back to the first frame of each trajectory when no 'env' key is stored
            env = normalize_image(self.data['observations'][traj_i, 0, :])

        x_t = normalize_image(self.data['observations'][traj_i, trans_i, :])
        x_next = normalize_image(self.data['observations'][traj_i, trans_i + 1, :])

        episode_num = np.random.randint(0, self.size)
        episode_obs = normalize_image(self.data['observations'][episode_num, :8, :])

        data_dict = {
            'x_t': x_t,
            'x_next': x_next,
            'env': env,
            'actions': self.data['actions'][traj_i, trans_i, :],
            'episode_obs': episode_obs,
            'episode_acts': self.data['actions'][episode_num, :7, :],
        }
        return np_to_pytorch_batch(data_dict)
Example #26
 def train(self, np_batch):
     self._num_train_steps += 1
     batch = np_to_pytorch_batch(np_batch)
     self.train_from_torch(batch)
Example #27
    def train_from_torch(
        self,
        batch,
        train=True,
        pretrain=False,
    ):
        """

        :param batch:
        :param train:
        :param pretrain:
        :return:
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        weights = batch.get('weights', None)
        if self.reward_transform:
            rewards = self.reward_transform(rewards)

        if self.terminal_transform:
            terminals = self.terminal_transform(terminals)
        """
        Policy and Alpha Loss
        """
        dist = self.policy(obs)
        new_obs_actions, log_pi = dist.rsample_and_logprob()
        policy_mle = dist.mle_estimate()

        if self.brac:
            buf_dist = self.buffer_policy(obs)
            buf_log_pi = buf_dist.log_prob(actions)
            rewards = rewards + buf_log_pi

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = self.alpha
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        # Make sure policy accounts for squashing functions like tanh correctly!
        next_dist = self.policy(next_obs)
        new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
        target_q_values = torch.min(
            self.target_qf1(next_obs, new_next_actions),
            self.target_qf2(next_obs, new_next_actions),
        ) - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
        qf2_loss = self.qf_criterion(q2_pred, q_target.detach())
        """
        Policy Loss
        """
        qf1_new_actions = self.qf1(obs, new_obs_actions)
        qf2_new_actions = self.qf2(obs, new_obs_actions)
        q_new_actions = torch.min(
            qf1_new_actions,
            qf2_new_actions,
        )

        # Advantage-weighted regression
        if self.awr_use_mle_for_vf:
            v1_pi = self.qf1(obs, policy_mle)
            v2_pi = self.qf2(obs, policy_mle)
            v_pi = torch.min(v1_pi, v2_pi)
        else:
            if self.vf_K > 1:
                vs = []
                for i in range(self.vf_K):
                    u = dist.sample()
                    q1 = self.qf1(obs, u)
                    q2 = self.qf2(obs, u)
                    v = torch.min(q1, q2)
                    # v = q1
                    vs.append(v)
                v_pi = torch.cat(vs, 1).mean(dim=1)
            else:
                # v_pi = self.qf1(obs, new_obs_actions)
                v1_pi = self.qf1(obs, new_obs_actions)
                v2_pi = self.qf2(obs, new_obs_actions)
                v_pi = torch.min(v1_pi, v2_pi)

        if self.awr_sample_actions:
            u = new_obs_actions
            if self.awr_min_q:
                q_adv = q_new_actions
            else:
                q_adv = qf1_new_actions
        elif self.buffer_policy_sample_actions:
            buf_dist = self.buffer_policy(obs)
            u, _ = buf_dist.rsample_and_logprob()
            qf1_buffer_actions = self.qf1(obs, u)
            qf2_buffer_actions = self.qf2(obs, u)
            q_buffer_actions = torch.min(
                qf1_buffer_actions,
                qf2_buffer_actions,
            )
            if self.awr_min_q:
                q_adv = q_buffer_actions
            else:
                q_adv = qf1_buffer_actions
        else:
            u = actions
            if self.awr_min_q:
                q_adv = torch.min(q1_pred, q2_pred)
            else:
                q_adv = q1_pred

        policy_logpp = dist.log_prob(u)

        if self.use_automatic_beta_tuning:
            buffer_dist = self.buffer_policy(obs)
            beta = self.log_beta.exp()
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
            beta_loss = -1 * (beta *
                              (kldiv - self.beta_epsilon).detach()).mean()

            self.beta_optimizer.zero_grad()
            beta_loss.backward()
            self.beta_optimizer.step()
        else:
            beta = self.beta_schedule.get_value(self._n_train_steps_total)

        if self.normalize_over_state == "advantage":
            score = q_adv - v_pi
            if self.mask_positive_advantage:
                score = torch.sign(score)
        elif self.normalize_over_state == "Z":
            buffer_dist = self.buffer_policy(obs)
            K = self.Z_K
            buffer_obs = []
            buffer_actions = []
            log_bs = []
            log_pis = []
            for i in range(K):
                u = buffer_dist.sample()
                log_b = buffer_dist.log_prob(u)
                log_pi = dist.log_prob(u)
                buffer_obs.append(obs)
                buffer_actions.append(u)
                log_bs.append(log_b)
                log_pis.append(log_pi)
            buffer_obs = torch.cat(buffer_obs, 0)
            buffer_actions = torch.cat(buffer_actions, 0)
            p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1, ))
            log_pi = torch.cat(log_pis, 0)
            log_pi = log_pi.sum(dim=1, )
            q1_b = self.qf1(buffer_obs, buffer_actions)
            q2_b = self.qf2(buffer_obs, buffer_actions)
            q_b = torch.min(q1_b, q2_b)
            q_b = torch.reshape(q_b, (-1, K))
            adv_b = q_b - v_pi
            # if self._n_train_steps_total % 100 == 0:
            #     import ipdb; ipdb.set_trace()
            # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True)
            # score = torch.exp((q_adv - v_pi) / beta) / Z
            # score = score / sum(score)
            logK = torch.log(ptu.tensor(float(K)))
            logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True)
            logS = (q_adv - v_pi) / beta - logZ
            # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True)
            # logS = q_adv/beta - logZ
            score = F.softmax(logS, dim=0)  # score / sum(score)
        else:
            raise ValueError("Unknown normalize_over_state: {}".format(self.normalize_over_state))

        if self.clip_score is not None:
            score = torch.clamp(score, max=self.clip_score)

        if self.weight_loss and weights is None:
            if self.normalize_over_batch is True:
                weights = F.softmax(score / beta, dim=0)
            elif self.normalize_over_batch == "whiten":
                adv_mean = torch.mean(score)
                adv_std = torch.std(score) + 1e-5
                normalized_score = (score - adv_mean) / adv_std
                weights = torch.exp(normalized_score / beta)
            elif self.normalize_over_batch == "exp":
                weights = torch.exp(score / beta)
            elif self.normalize_over_batch == "step_fn":
                weights = (score > 0).float()
            elif not self.normalize_over_batch:
                weights = score
            else:
                raise ValueError("Unknown normalize_over_batch: {}".format(self.normalize_over_batch))
        weights = weights[:, 0]

        policy_loss = alpha * log_pi.mean()

        if self.use_awr_update and self.weight_loss:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp * len(weights) * weights.detach()).mean()
        elif self.use_awr_update:
            policy_loss = policy_loss + self.awr_weight * (
                -policy_logpp).mean()

        if self.use_reparam_update:
            policy_loss = policy_loss + self.reparam_weight * (
                -q_new_actions).mean()

        policy_loss = self.rl_weight * policy_loss
        if self.compute_bc:
            train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(
                self.demo_train_buffer, self.policy)
            policy_loss = policy_loss + self.bc_weight * train_policy_loss

        if not pretrain and self.buffer_policy_reset_period > 0 and self._n_train_steps_total % self.buffer_policy_reset_period == 0:
            del self.buffer_policy_optimizer
            self.buffer_policy_optimizer = self.optimizer_class(
                self.buffer_policy.parameters(),
                weight_decay=self.policy_weight_decay,
                lr=self.policy_lr,
            )
            self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
            for i in range(self.num_buffer_policy_train_steps_on_reset):
                if self.train_bc_on_rl_buffer:
                    if self.advantage_weighted_buffer_loss:
                        buffer_dist = self.buffer_policy(obs)
                        buffer_u = actions
                        buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob(
                        )
                        buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                        buffer_policy_logpp = buffer_policy_logpp[:, None]

                        buffer_q1_pred = self.qf1(obs, buffer_u)
                        buffer_q2_pred = self.qf2(obs, buffer_u)
                        buffer_q_adv = torch.min(buffer_q1_pred,
                                                 buffer_q2_pred)

                        buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                        buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                        buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                        buffer_score = buffer_q_adv - buffer_v_pi
                        buffer_weights = F.softmax(buffer_score / beta, dim=0)
                        buffer_policy_loss = self.awr_weight * (
                            -buffer_policy_logpp * len(buffer_weights) *
                            buffer_weights.detach()).mean()
                    else:
                        buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                            self.replay_buffer.train_replay_buffer,
                            self.buffer_policy)

                    self.buffer_policy_optimizer.zero_grad()
                    buffer_policy_loss.backward(retain_graph=True)
                    self.buffer_policy_optimizer.step()

        if self.train_bc_on_rl_buffer:
            if self.advantage_weighted_buffer_loss:
                buffer_dist = self.buffer_policy(obs)
                buffer_u = actions
                buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
                buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                buffer_policy_logpp = buffer_policy_logpp[:, None]

                buffer_q1_pred = self.qf1(obs, buffer_u)
                buffer_q2_pred = self.qf2(obs, buffer_u)
                buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)

                buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                buffer_score = buffer_q_adv - buffer_v_pi
                buffer_weights = F.softmax(buffer_score / beta, dim=0)
                buffer_policy_loss = self.awr_weight * (
                    -buffer_policy_logpp * len(buffer_weights) *
                    buffer_weights.detach()).mean()
            else:
                buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)
        """
        Update networks
        """
        if self._n_train_steps_total % self.q_update_period == 0:
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward()
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

        if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

        if self.train_bc_on_rl_buffer and self._n_train_steps_total % self.policy_update_period == 0:
            self.buffer_policy_optimizer.zero_grad()
            buffer_policy_loss.backward()
            self.buffer_policy_optimizer.step()
        """
        Soft Updates
        """
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                    self.soft_target_tau)
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'terminals',
                    ptu.get_numpy(terminals),
                ))
            policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
            self.eval_statistics.update(policy_statistics)
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Weights',
                    ptu.get_numpy(weights),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Advantage Score',
                    ptu.get_numpy(score),
                ))

            if self.normalize_over_state == "Z":
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'logZ',
                        ptu.get_numpy(logZ),
                    ))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.compute_bc:
                test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.policy)
                self.eval_statistics.update({
                    "bc/Train Logprob Loss":
                    ptu.get_numpy(train_logp_loss),
                    "bc/Test Logprob Loss":
                    ptu.get_numpy(test_logp_loss),
                    "bc/Train MSE":
                    ptu.get_numpy(train_mse_loss),
                    "bc/Test MSE":
                    ptu.get_numpy(test_mse_loss),
                    "bc/train_policy_loss":
                    ptu.get_numpy(train_policy_loss),
                    "bc/test_policy_loss":
                    ptu.get_numpy(test_policy_loss),
                })
            if self.train_bc_on_rl_buffer:
                _, buffer_train_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer, self.buffer_policy)

                _, buffer_test_logp_loss, _, _ = self.run_bc_batch(
                    self.replay_buffer.validation_replay_buffer,
                    self.buffer_policy)
                buffer_dist = self.buffer_policy(obs)
                kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

                _, train_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_train_buffer, self.buffer_policy)

                _, test_offline_logp_loss, _, _ = self.run_bc_batch(
                    self.demo_test_buffer, self.buffer_policy)

                self.eval_statistics.update({
                    "buffer_policy/Train Online Logprob":
                    -1 * ptu.get_numpy(buffer_train_logp_loss),
                    "buffer_policy/Test Online Logprob":
                    -1 * ptu.get_numpy(buffer_test_logp_loss),
                    "buffer_policy/Train Offline Logprob":
                    -1 * ptu.get_numpy(train_offline_logp_loss),
                    "buffer_policy/Test Offline Logprob":
                    -1 * ptu.get_numpy(test_offline_logp_loss),
                    "buffer_policy/train_policy_loss":
                    ptu.get_numpy(buffer_policy_loss),
                    # "buffer_policy/test_policy_loss": ptu.get_numpy(buffer_test_policy_loss),
                    "buffer_policy/kl_div":
                    ptu.get_numpy(kldiv.mean()),
                })
            if self.use_automatic_beta_tuning:
                self.eval_statistics.update({
                    "adaptive_beta/beta":
                    ptu.get_numpy(beta.mean()),
                    "adaptive_beta/beta loss":
                    ptu.get_numpy(beta_loss.mean()),
                })

            if self.validation_qlearning:
                train_data = self.replay_buffer.validation_replay_buffer.random_batch(
                    self.bc_batch_size)
                train_data = np_to_pytorch_batch(train_data)
                obs = train_data['observations']
                next_obs = train_data['next_observations']
                # goals = train_data['resampled_goals']
                train_data[
                    'observations'] = obs  # torch.cat((obs, goals), dim=1)
                train_data[
                    'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
                self.test_from_torch(train_data)

        self._n_train_steps_total += 1
Example #28
    def pretrain_q_with_bc_data(self):
        """

        :return:
        """
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('pretrain_q.csv',
                                  relative_to_snapshot_dir=True)

        self.update_policy = False
        # first train only the Q function
        for i in range(self.q_num_pretrain1_steps):
            self.eval_statistics = dict()

            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)
            if i % self.pretraining_logging_period == 0:
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)

        self.update_policy = True
        # then train policy and Q function together
        prev_time = time.time()
        for i in range(self.q_num_pretrain2_steps):
            self.eval_statistics = dict()
            if i % self.pretraining_logging_period == 0:
                self._need_to_update_eval_statistics = True
            train_data = self.replay_buffer.random_batch(self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # goals = train_data['resampled_goals']
            train_data['observations'] = obs  # torch.cat((obs, goals), dim=1)
            train_data[
                'next_observations'] = next_obs  # torch.cat((next_obs, goals), dim=1)
            self.train_from_torch(train_data, pretrain=True)

            if i % self.pretraining_logging_period == 0:
                self.eval_statistics["batch"] = i
                self.eval_statistics["epoch_time"] = time.time() - prev_time
                stats_with_prefix = add_prefix(self.eval_statistics,
                                               prefix="trainer/")
                logger.record_dict(stats_with_prefix)
                logger.dump_tabular(with_prefix=True, with_timestamp=False)
                prev_time = time.time()

        logger.remove_tabular_output(
            'pretrain_q.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

        self._need_to_update_eval_statistics = True
        self.eval_statistics = dict()

        if self.post_pretrain_hyperparams:
            self.set_algorithm_weights(**self.post_pretrain_hyperparams)
Example #29
 def get_batch_from_buffer(self, replay_buffer):
     batch = replay_buffer.random_batch(self.bc_batch_size)
     batch = np_to_pytorch_batch(batch)
     return batch
Example #30
 def random_batch(self, batch_size):
     i = np.random.choice(self.size, batch_size, replace=(self.size < batch_size))
     data_dict = {
         'observations': self.data[i, :],
     }
     return np_to_pytorch_batch(data_dict)