def train(self, np_batch, np_demo_batch=None):
    self._num_train_steps += 1
    batch = np_to_pytorch_batch(np_batch)
    if np_demo_batch is not None:
        demo_batch = np_to_pytorch_batch(np_demo_batch)
        self.train_from_torch(batch, demo_batch)
    else:
        self.train_from_torch(batch)

def train(self, np_batch, np_batch_dead):
    self._num_train_steps += 1
    # Concatenate the regular and "dead" transitions key-by-key.
    np_full_batch = {
        k: np.concatenate((np_batch[k], np_batch_dead[k]))
        for k in np_batch.keys()
    }
    full_batch = np_to_pytorch_batch(np_full_batch)
    batch_dead = np_to_pytorch_batch(np_batch_dead)
    self.train_from_torch(full_batch, batch_dead)

def random_batch(self, batch_size):
    env_i = np.random.choice(self.num_envs, batch_size)
    trans_i = np.random.choice(self.sample_size, batch_size)
    match_i = np.random.choice(self.num_envs, batch_size)
    trans_x = np.random.choice(self.sample_size, batch_size)
    trans_y = np.random.choice(self.sample_size, batch_size)
    rand_a = np.random.choice(self.num_envs - 1, batch_size // 2)
    rand_b = np.add(rand_a, np.ones(batch_size // 2)).astype(int)
    trans_m = np.random.choice(self.sample_size, batch_size // 2)
    trans_n = np.random.choice(self.sample_size, batch_size // 2)
    # Soft match / non-match targets, with 5% of the pairs label-swapped.
    matches = np.random.uniform(0, 0.1, batch_size // 2)
    nonmatches = np.random.uniform(0.9, 1, batch_size // 2)
    swap_count = int(batch_size * 0.05)
    # Copy to avoid numpy view aliasing while swapping the slices.
    matches[:swap_count], nonmatches[:swap_count] = (
        nonmatches[:swap_count].copy(), matches[:swap_count].copy())
    labels = np.concatenate([matches, nonmatches])
    # Note: rand_a/rand_b, trans_m/trans_n, and labels are computed but not
    # included in the returned batch.
    data_dict = {
        'observations': self.data['observations'][env_i, trans_i, :],
        'env_set_1': self.data['observations'][match_i, trans_x, :],
        'env_set_2': self.data['observations'][match_i, trans_y, :],
    }
    return np_to_pytorch_batch(data_dict)

def train(self, np_batch):
    # 1. Increment the training-step counter.
    self._num_train_steps += 1
    # 2. Convert the numpy batch to torch tensors.
    batch = np_to_pytorch_batch(np_batch)
    # 3. Train on the torch batch.
    self.train_from_torch(batch)

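# Note: every snippet in this file funnels its numpy batch through
# np_to_pytorch_batch before training. That helper is not shown here; the
# following is only a minimal sketch of the assumed behavior (convert each
# numpy array in the dict to a float torch tensor), not the project's
# actual implementation.
import numpy as np
import torch

def np_to_pytorch_batch(np_batch):
    # Sketch: turn each numpy array value into a float32 torch tensor and
    # skip anything that is not an array.
    return {
        k: torch.as_tensor(v, dtype=torch.float32)
        for k, v in np_batch.items()
        if isinstance(v, np.ndarray)
    }
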
def get_batch(self, batch_size, from_expert, keys=None):
    if from_expert:
        buffer = self.expert_replay_buffer
    else:
        buffer = self.replay_buffer
    batch = buffer.random_batch(batch_size, keys=keys)
    batch = np_to_pytorch_batch(batch)
    return batch

def get_batch(self, batch_size, keys=None, use_expert_buffer=True):
    if use_expert_buffer:
        rb = self.expert_replay_buffer
    else:
        rb = self.replay_buffer
    batch = rb.random_batch(batch_size, keys=keys)
    batch = np_to_pytorch_batch(batch)
    return batch

def random_batch(self, batch_size):
    # Sample with replacement only when the buffer is smaller than the batch.
    i = np.random.choice(self.size, batch_size,
                         replace=(self.size < batch_size))
    obs = self.data[i, :]
    if self.normalize:
        obs = normalize_image(obs)
    data_dict = {
        'observations': obs,
    }
    return np_to_pytorch_batch(data_dict)

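# Several of the image samplers in this file call normalize_image before
# converting to torch. Sketch only, assuming the helper simply rescales
# uint8 pixel values into [0, 1]; the project's own version may handle
# dtypes differently.
import numpy as np

def normalize_image(image):
    # Assumed behavior: uint8 pixels -> floats in [0, 1].
    return np.asarray(image, dtype=np.float64) / 255.0
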
def get_test_batch(self):
    batch = self.test_replay_buffer.random_batch(self.batch_size)
    batch = np_to_pytorch_batch(batch)
    obs = batch['observations']
    next_obs = batch['next_observations']
    goals = batch['resampled_goals']
    # Goal-condition the observations by concatenating the resampled goals.
    batch['observations'] = torch.cat((obs, goals), dim=1)
    batch['next_observations'] = torch.cat((next_obs, goals), dim=1)
    return batch

def train(self, np_batch_dict):
    """
    :param np_batch_dict: dict with 'safe' and 'danger' keys, each holding a batch
    """
    self._num_train_steps += 1
    # Concatenate the safe and danger transitions into one batch.
    np_full_batch = {
        k: np.concatenate((np_batch_dict['safe'][k], np_batch_dict['danger'][k]))
        for k in np_batch_dict['safe'].keys()
    }
    # Convert to pytorch.
    torch_batch_full = np_to_pytorch_batch(np_full_batch)
    # Torch batch of only the danger transitions.
    torch_batch_danger = np_to_pytorch_batch(np_batch_dict['danger'])
    # Pack into a dict.
    torch_batch_dict = {
        'full': torch_batch_full,
        'danger': torch_batch_danger,
    }
    self.train_from_torch(torch_batch_dict)

def random_batch(self, batch_size):
    traj_i = np.random.choice(np.arange(self.size), batch_size)
    trans_i = np.random.choice(np.arange(self.traj_length - 1), batch_size)
    data_dict = {
        'observations': self.data['observations'][traj_i, trans_i, :],
        'next_observations': self.data['observations'][traj_i, trans_i + 1, :],
        'actions': self.data['actions'][traj_i, trans_i, :],
    }
    return np_to_pytorch_batch(data_dict)

def get_batch_from_buffer(replay_buffer, batch_size):
    """
    Sample a random batch from the buffer and convert it to torch tensors.

    :param replay_buffer: buffer exposing ``random_batch``
    :param batch_size: number of transitions to sample
    :return: dict of torch tensors
    """
    batch = replay_buffer.random_batch(batch_size)
    batch = np_to_pytorch_batch(batch)
    return batch

def get_batch(self, batch_size, from_target_state_buffer, keys=None):
    if from_target_state_buffer:
        buffer = self.target_state_buffer
        batch = {
            'observations': buffer[np.random.choice(buffer.shape[0], size=batch_size)],
        }
    else:
        buffer = self.replay_buffer
        batch = buffer.random_batch(batch_size, keys=keys)
    batch = np_to_pytorch_batch(batch)
    return batch

def random_batch(self, batch_size):
    num_traj = self.replay_buffer._size // self.horizon
    traj_i = np.random.choice(num_traj, batch_size)
    trans_i = np.random.choice(self.horizon - 2, batch_size)
    indices = traj_i * self.horizon + trans_i
    # Return three consecutive image observations (x0, x1, x2).
    batch = dict(
        x0=self.replay_buffer._obs["image_observation"][indices],
        x1=self.replay_buffer._obs["image_observation"][indices + 1],
        x2=self.replay_buffer._obs["image_observation"][indices + 2],
    )
    return np_to_pytorch_batch(batch)

def get_batch(self):
    sample_size = self.batch_size // 2
    batch1 = self.replay_buffer1().random_batch(sample_size)
    batch2 = self.replay_buffer2().random_batch(sample_size)
    # Concatenate the two half-batches key-by-key.
    new_batch = {}
    for k, v in batch1.items():
        new_batch[k] = np.concatenate((v, batch2[k]), axis=0)
    return np_to_pytorch_batch(new_batch)

def _statistics_from_paths(self, paths, stat_prefix):
    rewards, terminals, obs, actions, next_obs = split_paths(paths)
    np_batch = dict(
        rewards=rewards,
        terminals=terminals,
        observations=obs,
        actions=actions,
        next_observations=next_obs,
    )
    batch = np_to_pytorch_batch(np_batch)
    statistics = self._statistics_from_batch(batch, stat_prefix)
    statistics.update(create_stats_ordered_dict(
        'Num Paths', len(paths), stat_prefix=stat_prefix
    ))
    return statistics

def get_batch_from_buffer(self, replay_buffer):
    batch = replay_buffer.random_batch(self.bc_batch_size)
    batch = np_to_pytorch_batch(batch)
    # Goal concatenation is currently disabled:
    # obs = batch['observations']
    # next_obs = batch['next_observations']
    # goals = batch['resampled_goals']
    # batch['observations'] = torch.cat((obs, goals), dim=1)
    # batch['next_observations'] = torch.cat((next_obs, goals), dim=1)
    return batch

def pretrain_q_with_bc_data(self, batch_size):
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv', relative_to_snapshot_dir=True)

    self.update_policy = True
    # Train the policy and Q function together.
    prev_time = time.time()
    for i in range(self.num_pretrain_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # Goal concatenation is disabled; observations are used as-is.
        train_data['observations'] = obs
        train_data['next_observations'] = next_obs
        self.train_from_torch(train_data)

        if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
            total_ret = self.do_rollouts()
            print("Return at step {} : {}".format(i, total_ret / 20))

        if i % self.pretraining_logging_period == 0:
            if self.do_pretrain_rollouts:
                self.eval_statistics["pretrain_bc/avg_return"] = total_ret / 20
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            logger.record_dict(self.eval_statistics)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()

def get_batch(self):
    batch = self.replay_buffer.random_batch_random_tau(
        self.batch_size, self.max_tau)

    """
    Update the goal states/rewards
    """
    num_steps_left = self._sample_taus_for_training(batch)
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    goals = batch['training_goals']
    rewards = self._compute_rewards_np(batch, obs, actions, next_obs, goals)
    terminals = batch['terminals']

    # With the 'all_valid' strategy, every transition is repeated once per
    # possible value of tau (0, ..., max_tau).
    if self.tau_sample_strategy == 'all_valid':
        obs = np.repeat(obs, self.max_tau + 1, 0)
        actions = np.repeat(actions, self.max_tau + 1, 0)
        next_obs = np.repeat(next_obs, self.max_tau + 1, 0)
        goals = np.repeat(goals, self.max_tau + 1, 0)
        rewards = np.repeat(rewards, self.max_tau + 1, 0)
        terminals = np.repeat(terminals, self.max_tau + 1, 0)

    if self.finite_horizon:
        terminals = 1 - (1 - terminals) * (num_steps_left != 0)
    if self.terminate_when_goal_reached:
        diff = self.env.convert_obs_to_goals(next_obs) - goals
        goal_not_reached = (
            np.linalg.norm(diff, axis=1, keepdims=True)
            > self.goal_reached_epsilon
        )
        terminals = 1 - (1 - terminals) * goal_not_reached

    if not self.dense_rewards:
        rewards = rewards * terminals

    """
    Update the batch
    """
    batch['rewards'] = rewards
    batch['terminals'] = terminals
    batch['actions'] = actions
    batch['num_steps_left'] = num_steps_left
    batch['goals'] = goals
    batch['observations'] = obs
    batch['next_observations'] = next_obs

    return np_to_pytorch_batch(batch)

def get_batch(self, training=True):
    if self.replay_buffer_is_split:
        replay_buffer = self.replay_buffer.get_replay_buffer(training)
    else:
        replay_buffer = self.replay_buffer
    batch = replay_buffer.random_batch(self.batch_size)

    """
    Update the goal states/rewards
    """
    num_steps_left = np.random.randint(
        0, self.max_tau + 1, (self.batch_size, 1)
    )
    terminals = 1 - (1 - batch['terminals']) * (num_steps_left != 0)
    batch['terminals'] = terminals

    obs = batch['observations']
    next_obs = batch['next_observations']
    if self.sample_train_goals_from == 'her':
        goals = batch['goals']
    else:
        goals = self._sample_goals_for_training()
    goal_differences = np.abs(
        self.env.convert_obs_to_goals(next_obs)
        # - self.env.convert_obs_to_goals(obs)
        - goals
    )
    batch['goal_differences'] = goal_differences * self.reward_scale
    batch['goals'] = goals

    """
    Update the observations
    """
    batch['observations'] = merge_into_flat_obs(
        obs=batch['observations'],
        goals=batch['goals'],
        num_steps_left=num_steps_left,
    )
    batch['next_observations'] = merge_into_flat_obs(
        obs=batch['next_observations'],
        goals=batch['goals'],
        num_steps_left=num_steps_left - 1,
    )

    return np_to_pytorch_batch(batch)

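# merge_into_flat_obs is used above to build goal- and tau-conditioned
# observations. The following is only a sketch of the assumed behavior
# (concatenate obs, goals, and the remaining-steps counter along the feature
# axis), inferred from how it is called; it is not necessarily the project's
# exact implementation.
import numpy as np

def merge_into_flat_obs(obs, goals, num_steps_left):
    # Assumed behavior: stack the conditioning information onto each observation.
    return np.hstack((obs, goals, num_steps_left))
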
def train(self):
    # first train only the Q function
    iteration = 0
    for i in range(self.num_batches):
        train_data = self.replay_buffer.random_batch(self.batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        train_data['observations'] = obs
        train_data['next_observations'] = next_obs
        self.trainer.train_from_torch(train_data)
        if i % self.logging_period == 0:
            stats_with_prefix = add_prefix(
                self.trainer.eval_statistics, prefix="trainer/")
            self.trainer.end_epoch(iteration)
            iteration += 1
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

def train(self, np_batch):
    batch = np_to_pytorch_batch(np_batch)
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    """
    Compute loss
    """
    # Double DQN: pick the argmax action with the online Q network and
    # evaluate it with the target network.
    best_action_idxs = self.qf(next_obs).max(1, keepdim=True)[1]
    target_q_values = self.target_qf(next_obs).gather(
        1, best_action_idxs).detach()
    y_target = rewards + (1. - terminals) * self.discount * target_q_values
    y_target = y_target.detach()
    # actions is a one-hot vector
    y_pred = torch.sum(self.qf(obs) * actions, dim=1, keepdim=True)
    qf_loss = self.qf_criterion(y_pred, y_target)

    """
    Update networks
    """
    self.qf_optimizer.zero_grad()
    qf_loss.backward()
    self.qf_optimizer.step()

    """
    Soft target network updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf, self.target_qf, self.soft_target_tau)

    """
    Save some statistics for eval using just one batch.
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Y Predictions',
            ptu.get_numpy(y_pred),
        ))

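# ptu.soft_update_from_to performs the Polyak averaging used for the target
# networks above and below. A minimal sketch of that standard update, assumed
# to match what the ptu helper does:
def soft_update_from_to(source, target, tau):
    # target <- (1 - tau) * target + tau * source, parameter by parameter.
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + param.data * tau
        )
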
def pretrain_q_with_bc_data(self, batch_size):
    # Note: the batch_size argument is unused; self.bc_batch_size is used instead.
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv', relative_to_snapshot_dir=True)

    prev_time = time.time()
    for i in range(self.num_pretrain_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # Goal concatenation is disabled; observations are used as-is.
        train_data['observations'] = obs
        train_data['next_observations'] = next_obs
        self.train_from_torch(train_data, pretrain=True)

        if i % self.pretraining_logging_period == 0:
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            stats_with_prefix = add_prefix(self.eval_statistics, prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()

def random_batch(self, batch_size):
    traj_i = np.random.choice(self.size, batch_size)
    trans_i = np.random.choice(self.traj_length, batch_size)
    # conditioning = np.random.choice(self.traj_length, batch_size)
    # env = normalize_image(self.data['observations'][traj_i, conditioning, :])
    # Use a stored per-trajectory 'env' frame if available; otherwise fall
    # back to the first observation of each trajectory.
    try:
        env = normalize_image(self.data['env'][traj_i, :])
    except KeyError:
        env = normalize_image(self.data['observations'][traj_i, 0, :])
    x_t = normalize_image(self.data['observations'][traj_i, trans_i, :])
    episode_num = np.random.randint(0, self.size)
    episode_obs = normalize_image(self.data['observations'][episode_num, :8, :])
    data_dict = {
        'x_t': x_t,
        'env': env,
        'episode_obs': episode_obs,
    }
    return np_to_pytorch_batch(data_dict)

def fix_data_set(self):
    for training in [True, False]:
        replay_buffer = self.replay_buffer.get_replay_buffer(training)
        batch_dict = {}
        for i in range(self.num_unique_batches):
            batch_size = min(
                replay_buffer.num_steps_can_sample(),
                self.batch_size,
            )
            batch = replay_buffer.random_batch(batch_size)
            goal_states = self.sample_goal_states(batch_size, training)
            new_rewards = self.env.compute_rewards(
                batch['observations'],
                batch['actions'],
                batch['next_observations'],
                goal_states,
            )
            batch['goal_states'] = goal_states
            batch['rewards'] = new_rewards
            torch_batch = np_to_pytorch_batch(batch)
            batch_dict[i] = torch_batch

        self.mode_to_batch_iterator[training] = create_batch_iterator(
            self.num_unique_batches, batch_dict
        )

def random_batch(self, batch_size):
    traj_i = np.random.choice(self.size, batch_size)
    trans_i = np.random.choice(self.traj_length - 1, batch_size)
    # Use a stored per-trajectory 'env' frame if available; otherwise fall
    # back to the first observation of each trajectory.
    try:
        env = normalize_image(self.data['env'][traj_i, :])
    except KeyError:
        env = normalize_image(self.data['observations'][traj_i, 0, :])
    x_t = normalize_image(self.data['observations'][traj_i, trans_i, :])
    x_next = normalize_image(self.data['observations'][traj_i, trans_i + 1, :])
    episode_num = np.random.randint(0, self.size)
    episode_obs = normalize_image(self.data['observations'][episode_num, :8, :])
    data_dict = {
        'x_t': x_t,
        'x_next': x_next,
        'env': env,
        'actions': self.data['actions'][traj_i, trans_i, :],
        'episode_obs': episode_obs,
        'episode_acts': self.data['actions'][episode_num, :7, :],
    }
    return np_to_pytorch_batch(data_dict)

def train(self, np_batch):
    self._num_train_steps += 1
    batch = np_to_pytorch_batch(np_batch)
    self.train_from_torch(batch)

def train_from_torch(
        self,
        batch,
        train=True,
        pretrain=False,
):
    """
    :param batch: dict of torch tensors (plus optional per-sample 'weights')
    :param train: unused in this body
    :param pretrain: whether this call is part of pretraining
    """
    rewards = batch['rewards']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']
    weights = batch.get('weights', None)
    if self.reward_transform:
        rewards = self.reward_transform(rewards)
    if self.terminal_transform:
        terminals = self.terminal_transform(terminals)

    """
    Policy and Alpha Loss
    """
    dist = self.policy(obs)
    new_obs_actions, log_pi = dist.rsample_and_logprob()
    policy_mle = dist.mle_estimate()

    if self.brac:
        buf_dist = self.buffer_policy(obs)
        buf_log_pi = buf_dist.log_prob(actions)
        rewards = rewards + buf_log_pi

    if self.use_automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha *
                       (log_pi + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = self.alpha

    """
    QF Loss
    """
    q1_pred = self.qf1(obs, actions)
    q2_pred = self.qf2(obs, actions)
    # Make sure policy accounts for squashing functions like tanh correctly!
    next_dist = self.policy(next_obs)
    new_next_actions, new_log_pi = next_dist.rsample_and_logprob()
    target_q_values = torch.min(
        self.target_qf1(next_obs, new_next_actions),
        self.target_qf2(next_obs, new_next_actions),
    ) - alpha * new_log_pi

    q_target = self.reward_scale * rewards + (
        1. - terminals) * self.discount * target_q_values
    qf1_loss = self.qf_criterion(q1_pred, q_target.detach())
    qf2_loss = self.qf_criterion(q2_pred, q_target.detach())

    """
    Policy Loss
    """
    qf1_new_actions = self.qf1(obs, new_obs_actions)
    qf2_new_actions = self.qf2(obs, new_obs_actions)
    q_new_actions = torch.min(
        qf1_new_actions,
        qf2_new_actions,
    )

    # Advantage-weighted regression
    if self.awr_use_mle_for_vf:
        v1_pi = self.qf1(obs, policy_mle)
        v2_pi = self.qf2(obs, policy_mle)
        v_pi = torch.min(v1_pi, v2_pi)
    else:
        if self.vf_K > 1:
            vs = []
            for i in range(self.vf_K):
                u = dist.sample()
                q1 = self.qf1(obs, u)
                q2 = self.qf2(obs, u)
                v = torch.min(q1, q2)
                # v = q1
                vs.append(v)
            v_pi = torch.cat(vs, 1).mean(dim=1)
        else:
            # v_pi = self.qf1(obs, new_obs_actions)
            v1_pi = self.qf1(obs, new_obs_actions)
            v2_pi = self.qf2(obs, new_obs_actions)
            v_pi = torch.min(v1_pi, v2_pi)

    if self.awr_sample_actions:
        u = new_obs_actions
        if self.awr_min_q:
            q_adv = q_new_actions
        else:
            q_adv = qf1_new_actions
    elif self.buffer_policy_sample_actions:
        buf_dist = self.buffer_policy(obs)
        u, _ = buf_dist.rsample_and_logprob()
        qf1_buffer_actions = self.qf1(obs, u)
        qf2_buffer_actions = self.qf2(obs, u)
        q_buffer_actions = torch.min(
            qf1_buffer_actions,
            qf2_buffer_actions,
        )
        if self.awr_min_q:
            q_adv = q_buffer_actions
        else:
            q_adv = qf1_buffer_actions
    else:
        u = actions
        if self.awr_min_q:
            q_adv = torch.min(q1_pred, q2_pred)
        else:
            q_adv = q1_pred

    policy_logpp = dist.log_prob(u)

    if self.use_automatic_beta_tuning:
        buffer_dist = self.buffer_policy(obs)
        beta = self.log_beta.exp()
        kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)
        beta_loss = -1 * (beta * (kldiv - self.beta_epsilon).detach()).mean()

        self.beta_optimizer.zero_grad()
        beta_loss.backward()
        self.beta_optimizer.step()
    else:
        beta = self.beta_schedule.get_value(self._n_train_steps_total)

    if self.normalize_over_state == "advantage":
        score = q_adv - v_pi
        if self.mask_positive_advantage:
            score = torch.sign(score)
    elif self.normalize_over_state == "Z":
        buffer_dist = self.buffer_policy(obs)
        K = self.Z_K
        buffer_obs = []
        buffer_actions = []
        log_bs = []
        log_pis = []
        for i in range(K):
            u = buffer_dist.sample()
            log_b = buffer_dist.log_prob(u)
            log_pi = dist.log_prob(u)
            buffer_obs.append(obs)
            buffer_actions.append(u)
            log_bs.append(log_b)
            log_pis.append(log_pi)
        buffer_obs = torch.cat(buffer_obs, 0)
        buffer_actions = torch.cat(buffer_actions, 0)
        p_buffer = torch.exp(torch.cat(log_bs, 0).sum(dim=1))
        log_pi = torch.cat(log_pis, 0)
        log_pi = log_pi.sum(dim=1)
        q1_b = self.qf1(buffer_obs, buffer_actions)
        q2_b = self.qf2(buffer_obs, buffer_actions)
        q_b = torch.min(q1_b, q2_b)
        q_b = torch.reshape(q_b, (-1, K))
        adv_b = q_b - v_pi
        # Z = torch.exp(adv_b / beta).mean(dim=1, keepdim=True)
        # score = torch.exp((q_adv - v_pi) / beta) / Z
        # score = score / sum(score)
        logK = torch.log(ptu.tensor(float(K)))
        logZ = torch.logsumexp(adv_b / beta - logK, dim=1, keepdim=True)
        logS = (q_adv - v_pi) / beta - logZ
        # logZ = torch.logsumexp(q_b/beta - logK, dim=1, keepdim=True)
        # logS = q_adv/beta - logZ
        score = F.softmax(logS, dim=0)  # score / sum(score)
    else:
        raise ValueError(
            "Invalid normalize_over_state: {}".format(self.normalize_over_state))

    if self.clip_score is not None:
        score = torch.clamp(score, max=self.clip_score)

    if self.weight_loss and weights is None:
        # Compare against the booleans explicitly so the string modes below
        # remain reachable.
        if self.normalize_over_batch is True:
            weights = F.softmax(score / beta, dim=0)
        elif self.normalize_over_batch == "whiten":
            adv_mean = torch.mean(score)
            adv_std = torch.std(score) + 1e-5
            normalized_score = (score - adv_mean) / adv_std
            weights = torch.exp(normalized_score / beta)
        elif self.normalize_over_batch == "exp":
            weights = torch.exp(score / beta)
        elif self.normalize_over_batch == "step_fn":
            weights = (score > 0).float()
        elif self.normalize_over_batch is False:
            weights = score
        else:
            raise ValueError(
                "Invalid normalize_over_batch: {}".format(self.normalize_over_batch))
    weights = weights[:, 0]

    policy_loss = alpha * log_pi.mean()

    if self.use_awr_update and self.weight_loss:
        policy_loss = policy_loss + self.awr_weight * (
            -policy_logpp * len(weights) * weights.detach()).mean()
    elif self.use_awr_update:
        policy_loss = policy_loss + self.awr_weight * (-policy_logpp).mean()

    if self.use_reparam_update:
        policy_loss = policy_loss + self.reparam_weight * (-q_new_actions).mean()

    policy_loss = self.rl_weight * policy_loss
    if self.compute_bc:
        train_policy_loss, train_logp_loss, train_mse_loss, _ = self.run_bc_batch(
            self.demo_train_buffer, self.policy)
        policy_loss = policy_loss + self.bc_weight * train_policy_loss

    if (not pretrain and self.buffer_policy_reset_period > 0 and
            self._n_train_steps_total % self.buffer_policy_reset_period == 0):
        del self.buffer_policy_optimizer
        self.buffer_policy_optimizer = self.optimizer_class(
            self.buffer_policy.parameters(),
            weight_decay=self.policy_weight_decay,
            lr=self.policy_lr,
        )
        self.optimizers[self.buffer_policy] = self.buffer_policy_optimizer
        for i in range(self.num_buffer_policy_train_steps_on_reset):
            if self.train_bc_on_rl_buffer:
                if self.advantage_weighted_buffer_loss:
                    buffer_dist = self.buffer_policy(obs)
                    buffer_u = actions
                    buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
                    buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
                    buffer_policy_logpp = buffer_policy_logpp[:, None]

                    buffer_q1_pred = self.qf1(obs, buffer_u)
                    buffer_q2_pred = self.qf2(obs, buffer_u)
                    buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)

                    buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
                    buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
                    buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

                    buffer_score = buffer_q_adv - buffer_v_pi
                    buffer_weights = F.softmax(buffer_score / beta, dim=0)
                    buffer_policy_loss = self.awr_weight * (
                        -buffer_policy_logpp * len(buffer_weights) *
                        buffer_weights.detach()).mean()
                else:
                    buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = \
                        self.run_bc_batch(
                            self.replay_buffer.train_replay_buffer,
                            self.buffer_policy)

                self.buffer_policy_optimizer.zero_grad()
                buffer_policy_loss.backward(retain_graph=True)
                self.buffer_policy_optimizer.step()

    if self.train_bc_on_rl_buffer:
        if self.advantage_weighted_buffer_loss:
            buffer_dist = self.buffer_policy(obs)
            buffer_u = actions
            buffer_new_obs_actions, _ = buffer_dist.rsample_and_logprob()
            buffer_policy_logpp = buffer_dist.log_prob(buffer_u)
            buffer_policy_logpp = buffer_policy_logpp[:, None]

            buffer_q1_pred = self.qf1(obs, buffer_u)
            buffer_q2_pred = self.qf2(obs, buffer_u)
            buffer_q_adv = torch.min(buffer_q1_pred, buffer_q2_pred)

            buffer_v1_pi = self.qf1(obs, buffer_new_obs_actions)
            buffer_v2_pi = self.qf2(obs, buffer_new_obs_actions)
            buffer_v_pi = torch.min(buffer_v1_pi, buffer_v2_pi)

            buffer_score = buffer_q_adv - buffer_v_pi
            buffer_weights = F.softmax(buffer_score / beta, dim=0)
            buffer_policy_loss = self.awr_weight * (
                -buffer_policy_logpp * len(buffer_weights) *
                buffer_weights.detach()).mean()
        else:
            buffer_policy_loss, buffer_train_logp_loss, buffer_train_mse_loss, _ = \
                self.run_bc_batch(
                    self.replay_buffer.train_replay_buffer,
                    self.buffer_policy)

    """
    Update networks
    """
    if self._n_train_steps_total % self.q_update_period == 0:
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss.backward()
        self.qf2_optimizer.step()

    if self._n_train_steps_total % self.policy_update_period == 0 and self.update_policy:
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    if (self.train_bc_on_rl_buffer and
            self._n_train_steps_total % self.policy_update_period == 0):
        self.buffer_policy_optimizer.zero_grad()
        buffer_policy_loss.backward()
        self.buffer_policy_optimizer.step()

    """
    Soft Updates
    """
    if self._n_train_steps_total % self.target_update_period == 0:
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau)

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        policy_loss = (log_pi - q_new_actions).mean()

        self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
        self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
        self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q1 Predictions',
            ptu.get_numpy(q1_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q2 Predictions',
            ptu.get_numpy(q2_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Q Targets',
            ptu.get_numpy(q_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Log Pis',
            ptu.get_numpy(log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'rewards',
            ptu.get_numpy(rewards),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'terminals',
            ptu.get_numpy(terminals),
        ))
        policy_statistics = add_prefix(dist.get_diagnostics(), "policy/")
        self.eval_statistics.update(policy_statistics)
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Weights',
            ptu.get_numpy(weights),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Advantage Score',
            ptu.get_numpy(score),
        ))
        if self.normalize_over_state == "Z":
            self.eval_statistics.update(create_stats_ordered_dict(
                'logZ',
                ptu.get_numpy(logZ),
            ))

        if self.use_automatic_entropy_tuning:
            self.eval_statistics['Alpha'] = alpha.item()
            self.eval_statistics['Alpha Loss'] = alpha_loss.item()

        if self.compute_bc:
            test_policy_loss, test_logp_loss, test_mse_loss, _ = self.run_bc_batch(
                self.demo_test_buffer, self.policy)
            self.eval_statistics.update({
                "bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                "bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                "bc/Train MSE": ptu.get_numpy(train_mse_loss),
                "bc/Test MSE": ptu.get_numpy(test_mse_loss),
                "bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                "bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
            })
        if self.train_bc_on_rl_buffer:
            _, buffer_train_logp_loss, _, _ = self.run_bc_batch(
                self.replay_buffer.train_replay_buffer,
                self.buffer_policy)
            _, buffer_test_logp_loss, _, _ = self.run_bc_batch(
                self.replay_buffer.validation_replay_buffer,
                self.buffer_policy)
            buffer_dist = self.buffer_policy(obs)
            kldiv = torch.distributions.kl.kl_divergence(dist, buffer_dist)

            _, train_offline_logp_loss, _, _ = self.run_bc_batch(
                self.demo_train_buffer,
                self.buffer_policy)
            _, test_offline_logp_loss, _, _ = self.run_bc_batch(
                self.demo_test_buffer,
                self.buffer_policy)

            self.eval_statistics.update({
                "buffer_policy/Train Online Logprob":
                    -1 * ptu.get_numpy(buffer_train_logp_loss),
                "buffer_policy/Test Online Logprob":
                    -1 * ptu.get_numpy(buffer_test_logp_loss),
                "buffer_policy/Train Offline Logprob":
                    -1 * ptu.get_numpy(train_offline_logp_loss),
                "buffer_policy/Test Offline Logprob":
                    -1 * ptu.get_numpy(test_offline_logp_loss),
                "buffer_policy/train_policy_loss":
                    ptu.get_numpy(buffer_policy_loss),
                # "buffer_policy/test_policy_loss": ptu.get_numpy(buffer_test_policy_loss),
                "buffer_policy/kl_div": ptu.get_numpy(kldiv.mean()),
            })
        if self.use_automatic_beta_tuning:
            self.eval_statistics.update({
                "adaptive_beta/beta": ptu.get_numpy(beta.mean()),
                "adaptive_beta/beta loss": ptu.get_numpy(beta_loss.mean()),
            })

        if self.validation_qlearning:
            train_data = self.replay_buffer.validation_replay_buffer.random_batch(
                self.bc_batch_size)
            train_data = np_to_pytorch_batch(train_data)
            obs = train_data['observations']
            next_obs = train_data['next_observations']
            # Goal concatenation is disabled; observations are used as-is.
            train_data['observations'] = obs
            train_data['next_observations'] = next_obs
            self.test_from_torch(train_data)

    self._n_train_steps_total += 1

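# The policy update in train_from_torch above is advantage-weighted
# regression: score each buffer action by its advantage over the current
# policy, turn the scores into weights, and maximize the weighted
# log-likelihood. A stripped-down sketch of the softmax-over-batch path with
# illustrative names (dist, qf1, qf2, beta here are stand-ins, not the
# trainer's exact attributes):
import torch
import torch.nn.functional as F

def awr_policy_loss(dist, obs, actions, qf1, qf2, beta):
    # Advantage of the buffer action over the policy's own sampled action.
    new_actions, _ = dist.rsample_and_logprob()
    q_adv = torch.min(qf1(obs, actions), qf2(obs, actions))
    v_pi = torch.min(qf1(obs, new_actions), qf2(obs, new_actions))
    score = q_adv - v_pi
    # Softmax-normalized exponential-advantage weights over the batch.
    weights = F.softmax(score / beta, dim=0)
    logp = dist.log_prob(actions)
    return (-logp * len(weights) * weights.detach()).mean()
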
def pretrain_q_with_bc_data(self):
    """
    Pretrain the Q function alone, then the policy and Q function together.
    """
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_q.csv', relative_to_snapshot_dir=True)

    self.update_policy = False
    # first train only the Q function
    for i in range(self.q_num_pretrain1_steps):
        self.eval_statistics = dict()

        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        # Goal concatenation is disabled; observations are used as-is.
        train_data['observations'] = obs
        train_data['next_observations'] = next_obs
        self.train_from_torch(train_data, pretrain=True)

        if i % self.pretraining_logging_period == 0:
            stats_with_prefix = add_prefix(self.eval_statistics, prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)

    self.update_policy = True
    # then train policy and Q function together
    prev_time = time.time()
    for i in range(self.q_num_pretrain2_steps):
        self.eval_statistics = dict()
        if i % self.pretraining_logging_period == 0:
            self._need_to_update_eval_statistics = True
        train_data = self.replay_buffer.random_batch(self.bc_batch_size)
        train_data = np_to_pytorch_batch(train_data)
        obs = train_data['observations']
        next_obs = train_data['next_observations']
        train_data['observations'] = obs
        train_data['next_observations'] = next_obs
        self.train_from_torch(train_data, pretrain=True)

        if i % self.pretraining_logging_period == 0:
            self.eval_statistics["batch"] = i
            self.eval_statistics["epoch_time"] = time.time() - prev_time
            stats_with_prefix = add_prefix(self.eval_statistics, prefix="trainer/")
            logger.record_dict(stats_with_prefix)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            prev_time = time.time()

    logger.remove_tabular_output(
        'pretrain_q.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )

    self._need_to_update_eval_statistics = True
    self.eval_statistics = dict()

    if self.post_pretrain_hyperparams:
        self.set_algorithm_weights(**self.post_pretrain_hyperparams)

def get_batch_from_buffer(self, replay_buffer):
    batch = replay_buffer.random_batch(self.bc_batch_size)
    batch = np_to_pytorch_batch(batch)
    return batch

def random_batch(self, batch_size):
    # Sample with replacement only when the buffer is smaller than the batch.
    i = np.random.choice(self.size, batch_size,
                         replace=(self.size < batch_size))
    data_dict = {
        'observations': self.data[i, :],
    }
    return np_to_pytorch_batch(data_dict)

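# Hypothetical glue code showing how the pieces above typically fit together:
# a random_batch method that already returns torch tensors can be fed straight
# into a trainer's train_from_torch. The names dataset, trainer, and num_steps
# are illustrative, not taken from the source.
def fit(dataset, trainer, num_steps=1000, batch_size=128):
    for _ in range(num_steps):
        torch_batch = dataset.random_batch(batch_size)
        trainer.train_from_torch(torch_batch)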