Example #1
# Standard and third-party imports. Project-specific helpers used below
# (ptu, SimpleReplayBuffer, GenericMap, rollout, convert_numpy_dict_to_pytorch,
# set_seed, setup_logger, logger) are assumed to be importable from the
# surrounding codebase and are not reproduced here.
import os
import random

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam


class PEARLFineTuningHelper:
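    """Fine-tunes a (typically meta-trained) PEARL agent on a single meta-test task.

    After an initial exploration phase the latent task variable z is inferred
    (and optionally frozen); the policy, the twin Q-functions and the entropy
    temperature are then updated with SAC-style gradient steps on transitions
    collected online into a replay buffer.
    """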

    def __init__(
            self,
            env,
            agent,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            total_steps=int(1e6),
            max_path_length=200,
            num_exp_traj_eval=1,
            start_fine_tuning=10,
            fine_tuning_steps=1,
            should_freeze_z=True,
            replay_buffer_size=int(1e6),
            batch_size=256,
            discount=0.99,
            policy_lr=1e-4,
            qf_lr=1e-4,
            temp_lr=1e-4,
            target_entropy=None,
            optimizer_class=torch.optim.Adam,
            soft_target_tau=1e-2
    ):
        self.env = env
        self.agent = agent

        # Critic networks
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        # Entropy temperature: log(alpha) is optimized directly and is created
        # on the same CUDA device as the networks.
        self.log_alpha = torch.zeros(1, requires_grad=True, device='cuda')
        self.target_entropy = target_entropy

        # Experimental setting
        self.total_steps = total_steps
        self.max_path_length = max_path_length
        self.num_exp_traj_eval = num_exp_traj_eval
        self.start_fine_tuning = start_fine_tuning
        self.fine_tuning_steps = fine_tuning_steps
        self.should_freeze_z = should_freeze_z

        # Hyperparams
        self.batch_size = batch_size
        self.discount = discount
        self.soft_target_tau = soft_target_tau

        self.replay_buffer = SimpleReplayBuffer(
            max_replay_buffer_size=replay_buffer_size,
            observation_dim=int(np.prod(env.observation_space.shape)),
            action_dim=int(np.prod(env.action_space.shape)),
        )

        self.q_losses = []
        self.temp_losses = []
        self.policy_losses = []
        self.temp_vals = []

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.agent.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )
        self.temp_optimizer = optimizer_class(
            [self.log_alpha],
            lr=temp_lr,
        )

        self.print_experiment_description()

    def get_mean(self, losses):
        if not losses:
            return None
        tot = 0
        for tensor in losses:
            tot += np.mean(tensor.to('cpu').detach().numpy())
        return tot / len(losses)

    def collect_samples(self, should_accum_context):
        path = self.rollout(should_accum_context)
        self.replay_buffer.add_path(path)
        steps = path['rewards'].shape[0]
        ret = sum(path['rewards'])[0]
        return ret, steps

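    # Collects a single episode. With should_accum_context=True the transitions
    # are fed into the agent's context (posterior-inference phase); otherwise
    # the agent is being fine-tuned and fine_tuning_steps gradient steps are
    # taken after every environment step.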
    def rollout(self, should_accum_context):
        should_fine_tune = not should_accum_context
        observations = []
        actions = []
        rewards = []
        terminals = []
        agent_infos = []
        env_infos = []
        o = self.env.reset()
        next_o = None
        path_length = 0
        done = False

        while not done:
            a, agent_info = self.agent.get_action(o)
            next_o, r, d, env_info = self.env.step(a)
            path_length += 1
            # Do not treat a time-limit termination as a true environment done,
            # so that bootstrapping is not cut off at the horizon.
            real_done = False if path_length == self.max_path_length else d
            observations.append(o)
            rewards.append(r)
            terminals.append(real_done)
            actions.append(a)
            agent_infos.append(agent_info)
            env_infos.append(env_info)
            if should_accum_context:
                # Update the context with the pre-step observation, before o is
                # overwritten by next_o.
                self.agent.update_context([o, a, r, next_o, d, env_info])
            o = next_o
            if should_fine_tune:
                for _ in range(self.fine_tuning_steps):
                    self.fine_tuning_step()
            if d or path_length >= self.max_path_length:
                done = True

        actions = np.array(actions)
        if len(actions.shape) == 1:
            actions = np.expand_dims(actions, 1)
        observations = np.array(observations)
        if len(observations.shape) == 1:
            observations = np.expand_dims(observations, 1)
            next_o = np.array([next_o])
        next_observations = np.vstack(
            (
                observations[1:, :],
                np.expand_dims(next_o, 0)
            )
        )

        if should_accum_context:
            self.agent.sample_z()

        return dict(
            observations=observations,
            actions=actions,
            rewards=np.array(rewards).reshape(-1, 1),
            next_observations=next_observations,
            terminals=np.array(terminals).reshape(-1, 1),
            agent_infos=agent_infos,
            env_infos=env_infos,
        )

    def get_samples(self):
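        # Draw a random off-policy batch and add a leading singleton "task"
        # dimension so the tensors match the (num_tasks, batch, dim) layout the
        # agent expects; fine_tuning_step flattens it back out.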
        batch = ptu.np_to_pytorch_batch(self.replay_buffer.random_batch(self.batch_size))
        o = batch['observations'][None, ...]
        a = batch['actions'][None, ...]
        r = batch['rewards'][None, ...]
        no = batch['next_observations'][None, ...]
        t = batch['terminals'][None, ...]
        return o, a, r, no, t

    def _min_q(self, obs, actions, task_z):
        q1 = self.qf1(obs, actions, task_z.detach())
        q2 = self.qf2(obs, actions, task_z.detach())
        min_q = torch.min(q1, q2)
        return min_q

    def _update_target_networks(self):
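        # Polyak averaging of the critics into their targets, i.e. the usual
        # target <- tau * source + (1 - tau) * target update.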
        ptu.soft_update_from_to(self.qf1, self.target_qf1, self.soft_target_tau)
        ptu.soft_update_from_to(self.qf2, self.target_qf2, self.soft_target_tau)

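    # One SAC-style update: sample a batch, regress both Q-functions onto the
    # entropy-regularized Bellman target, then update the policy and the
    # temperature with the critics frozen, and finally soft-update the targets.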
    def fine_tuning_step(self):
        obs, actions, rewards, next_obs, terms = self.get_samples()

        # flattens out the task dimension
        t, b, _ = obs.size()
        obs_flat = obs.view(t * b, -1)
        actions_flat = actions.view(t * b, -1)
        next_obs_flat = next_obs.view(t * b, -1)
        rewards_flat = rewards.view(self.batch_size, -1)
        terms_flat = terms.view(self.batch_size, -1)

        """
        QF Loss
        """
        with torch.no_grad():
            next_policy_outputs, task_z = self.agent(next_obs, self.agent.context)
            next_new_actions, _, _, next_log_prob = next_policy_outputs[:4]
            t_q1_pred = self.target_qf1(next_obs_flat, next_new_actions, task_z.detach())  # TODO: Remove .detach() if redundant
            t_q2_pred = self.target_qf2(next_obs_flat, next_new_actions, task_z.detach())
            t_q_min = torch.min(t_q1_pred, t_q2_pred)
            q_target = rewards_flat + (1. - terms_flat) * self.discount * (t_q_min - self.alpha * next_log_prob)
        q1_pred = self.qf1(obs_flat, actions_flat, task_z.detach())                    # TODO: Remove .detach() if redundant
        q2_pred = self.qf2(obs_flat, actions_flat, task_z.detach())
        qf_loss = torch.mean((q1_pred - q_target.detach()) ** 2) + torch.mean((q2_pred - q_target.detach()) ** 2)

        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()
        qf_loss.backward()
        self.qf1_optimizer.step()
        self.qf2_optimizer.step()

        """
        Policy and Temp Loss
        """
        for p in self.qf1.parameters():
            p.requires_grad = False
        for p in self.qf2.parameters():
            p.requires_grad = False

        policy_outputs, task_z = self.agent(obs, self.agent.context)
        new_actions, policy_mean, policy_log_std, log_prob = policy_outputs[:4]
        min_q_new_actions = self._min_q(obs_flat, new_actions, task_z)

        policy_loss = (self.alpha * log_prob - min_q_new_actions).mean()
        temp_loss = -self.alpha * (log_prob.detach() + self.target_entropy).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        self.temp_optimizer.zero_grad()
        temp_loss.backward()
        self.temp_optimizer.step()

        for p in self.qf1.parameters():
            p.requires_grad = True
        for p in self.qf2.parameters():
            p.requires_grad = True

        """
        Update Target Networks
        """
        self._update_target_networks()

        self.q_losses.append(qf_loss.detach())
        self.temp_losses.append(temp_loss.detach())
        self.policy_losses.append(policy_loss.detach())
        self.temp_vals.append(self.alpha.detach())

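    # Evaluation uses the module-level rollout helper (assumed to be the
    # project's sampler utility) rather than self.rollout, so evaluation
    # episodes neither fill the replay buffer nor trigger fine-tuning steps.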
    def evaluate_agent(self, n_starts=10):
        reward_sum = 0
        for _ in range(n_starts):
            path = rollout(self.env, self.agent, max_path_length=self.max_path_length, accum_context=False)
            reward_sum += sum(path['rewards'])[0]
        return reward_sum / n_starts

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def fine_tune(self, variant, seed):
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)

        cumulative_timestep = 0
        i_episode = 0

        df = pd.DataFrame(
            columns=[
                'step',
                'real_step',
                'train_reward',
                'eval_reward',
                'loss/q-f1',
                'loss/alpha',
                'loss/policy',
                'val/alpha'
            ]
        )

        # For this experiment, we are evaluating in just one sampled task from the meta-test set
        tasks = self.env.get_all_task_idx()
        eval_tasks = list(tasks[-variant['n_eval_tasks']:])
        idx = eval_tasks[0]
        self.env.reset_task(idx)

        self.agent.clear_z()
        while cumulative_timestep < self.total_steps:
            i_episode += 1
            should_infer_posterior = self.num_exp_traj_eval <= i_episode < self.start_fine_tuning
            should_fine_tune = self.start_fine_tuning <= i_episode
            should_accum_context = not should_fine_tune
            if should_fine_tune and self.should_freeze_z and (not self.agent.freeze_z):
                self.agent.freeze_z = True
            train_reward, episode_steps = self.collect_samples(should_accum_context)
            cumulative_timestep += episode_steps
            if should_infer_posterior:
                self.agent.infer_posterior(self.agent.context)
            eval_reward = self.evaluate_agent()
            message = 'Episode {} \t\t Samples {} \t\t Real samples {} \t\t Train reward: {} \t\t Eval reward: {}'
            print(message.format(i_episode, i_episode * self.max_path_length, cumulative_timestep, train_reward,
                                 eval_reward))
            new_df_row = {
                'step': int(i_episode * self.max_path_length),
                'real_step': int(cumulative_timestep),
                'train_reward': train_reward,
                'eval_reward': eval_reward,
                'loss/q-f1': self.get_mean(self.q_losses),
                'loss/alpha': self.get_mean(self.temp_losses),
                'loss/policy': self.get_mean(self.policy_losses),
                'val/alpha': self.get_mean(self.temp_vals)
            }
            self.q_losses = []
            self.temp_losses = []
            self.policy_losses = []
            self.temp_vals = []
            # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead.
            df = pd.concat([df, pd.DataFrame([new_df_row])], ignore_index=True)
            results_path = "results_ft/{}/ft_{}".format(variant['env_name'], seed - 1)
            os.makedirs(results_path, exist_ok=True)
            df.to_csv("{}/progress.csv".format(results_path))

    def print_experiment_description(self):
        print("\n\n", " -" * 15, "\n")
        print("Total steps:  \t\t\t", self.total_steps)
        print("Max path length:  \t\t\t", self.max_path_length)
        print("Trajectory length with prior:  \t\t\t", self.num_exp_traj_eval)
        print("Start fine tuning after:  \t\t\t", self.start_fine_tuning)
        print("Number of fine-tuning steps:  \t\t\t", self.fine_tuning_steps)
        print("Should freeze Z during fine-tuning?  \t\t\t", self.should_freeze_z)
        print("Batch size:  \t\t\t", self.batch_size)
        print("Gamma:  \t\t\t", self.discount)
        print("Tau:  \t\t\t", self.soft_target_tau)
def experiment(exp_specs):
    # Load the data -----------------------------------------------------------
    extra_data_path = exp_specs['extra_data_path']
    train_replay_buffer = joblib.load(extra_data_path)['replay_buffer']
    train_replay_buffer.change_max_size_to_cur_size()
    # Drop the leading extra observation dimensions from the stored next observations.
    train_replay_buffer._next_obs = \
        train_replay_buffer._next_obs[:, exp_specs['extra_obs_dim']:]

    print('\nRewards: {} +/- {}'.format(np.mean(train_replay_buffer._rewards),
                                        np.std(train_replay_buffer._rewards)))

    next_obs_mean = np.mean(train_replay_buffer._next_obs, 0)
    next_obs_std = np.std(train_replay_buffer._next_obs, 0)
    print('\nNext Obs:\n{}\n+/-\n{}'.format(next_obs_mean, next_obs_std))

    print('\nAvg Next Obs Square Norm: {}'.format(
        np.mean(np.linalg.norm(train_replay_buffer._next_obs, axis=1)**2)))

    sample_batch = train_replay_buffer.random_batch(
        exp_specs['train_batch_size'])
    obs_dim = sample_batch['observations'].shape[-1]
    act_dim = sample_batch['actions'].shape[-1]

    val_replay_buffer = SimpleReplayBuffer(exp_specs['val_set_size'], obs_dim,
                                           act_dim)
    val_replay_buffer.set_buffer_from_dict(
        train_replay_buffer.sample_and_remove(exp_specs['val_set_size']))
    train_replay_buffer.set_buffer_from_dict(
        train_replay_buffer.sample_and_remove(exp_specs['train_set_size']))

    # Model Definitions -------------------------------------------------------
    model = GenericMap([obs_dim + act_dim],
                       [obs_dim - exp_specs['extra_obs_dim'] + 1],
                       siamese_input=False,
                       siamese_output=False,
                       num_hidden_layers=exp_specs['num_hidden_layers'],
                       hidden_dim=exp_specs['hidden_dim'],
                       act='relu',
                       use_bn=True,
                       deterministic=True)

    gap_model = GenericMap([obs_dim + act_dim], [
        obs_dim - exp_specs['extra_obs_dim'],
        obs_dim - exp_specs['extra_obs_dim']
    ],
                           siamese_input=False,
                           siamese_output=True,
                           num_hidden_layers=exp_specs['num_hidden_layers'],
                           hidden_dim=exp_specs['hidden_dim'],
                           act='relu',
                           use_bn=True,
                           deterministic=True)

    model_optim = Adam(model.parameters(), lr=float(exp_specs['lr']))
    gap_model_optim = Adam(gap_model.parameters(),
                           lr=float(exp_specs['gap_lr']))

    # Train -------------------------------------------------------------------
    model.train()
    for iter_num in range(exp_specs['max_iters']):
        model_optim.zero_grad()
        gap_model_optim.zero_grad()

        batch = train_replay_buffer.random_batch(exp_specs['train_batch_size'])
        batch = convert_numpy_dict_to_pytorch(batch)
        inputs = Variable(
            torch.cat([batch['observations'], batch['actions']], -1))
        outputs = Variable(
            torch.cat([batch['next_observations'], batch['rewards']], -1))
        true_next_obs = Variable(batch['next_observations'])

        preds = model([inputs])[0]
        gap_preds = gap_model([inputs])
        lower, upper = gap_preds[0], gap_preds[1]
        # residual for observations
        # preds = preds + Variable(torch.cat([batch['observations'], torch.zeros(exp_specs['train_batch_size'], 1)], 1))

        loss = torch.mean(torch.sum((outputs - preds)**2, -1))

        lower_loss = torch.mean(torch.sum(F.relu(lower - true_next_obs), -1))
        upper_loss = torch.mean(torch.sum(F.relu(true_next_obs - upper), -1))
        upper_lower_gap_loss = torch.mean(
            torch.sum(torch.abs(upper - lower), -1))

        total_loss = loss + upper_loss + lower_loss + float(
            exp_specs['upper_lower_gap_loss_weight']) * upper_lower_gap_loss

        total_loss.backward()
        model_optim.step()
        gap_model_optim.step()

        if iter_num % exp_specs['freq_val'] == 0:
            model.eval()

            val_batch = val_replay_buffer.random_batch(
                exp_specs['val_batch_size'])
            val_batch = convert_numpy_dict_to_pytorch(val_batch)
            inputs = Variable(
                torch.cat([val_batch['observations'], val_batch['actions']],
                          -1))
            outputs = Variable(
                torch.cat(
                    [val_batch['next_observations'], val_batch['rewards']],
                    -1))
            true_next_obs = Variable(val_batch['next_observations'])

            preds = model([inputs])[0]
            gap_preds = gap_model([inputs])
            lower, upper = gap_preds[0], gap_preds[1]
            # residual for observations
            # pred = preds + Variable(torch.cat([val_batch['observations'], torch.zeros(exp_specs['train_batch_size'], 1)], 1))

            loss = torch.mean(torch.sum((outputs - preds)**2, -1))
            next_obs_loss = torch.mean(
                torch.sum((outputs[:, :-1] - preds[:, :-1])**2, -1))
            rew_loss = torch.mean(
                torch.sum((outputs[:, -1:] - preds[:, -1:])**2, -1))

            lower_loss = torch.mean(
                torch.sum(F.relu(lower - true_next_obs), -1))
            upper_loss = torch.mean(
                torch.sum(F.relu(true_next_obs - upper), -1))
            upper_lower_gap_loss = torch.mean(
                torch.sum(torch.abs(upper - lower), -1))

            pred_over_upper = torch.mean(
                torch.sum(F.relu(preds[:, :-1] - upper), -1))
            pred_under_lower = torch.mean(
                torch.sum(F.relu(lower - preds[:, :-1]), -1))

            adj_next_obs_pred = torch.max(torch.min(preds[:, :-1], upper),
                                          lower)
            adj_next_obs_loss = torch.mean(
                torch.sum((outputs[:, :-1] - adj_next_obs_pred)**2, -1))

            ul_mean = (upper + lower) / 2.0
            ul_mean_as_obs_loss = torch.mean(
                torch.sum((outputs[:, :-1] - ul_mean)**2, -1))

            print('\n')
            print('-' * 20)
            print('Iter %d' % iter_num)
            print('Loss: %.4f' % loss)
            print('Obs Loss: %.4f' % next_obs_loss)
            print('Rew Loss: %.4f' % rew_loss)
            print('\nUpper Loss: %.4f' % upper_loss)
            print('Lower Loss: %.4f' % lower_loss)
            print('UL-Gap Loss: %.4f' % upper_lower_gap_loss)
            print('\nPred Over Upper: %.4f' % pred_over_upper)
            print('Pred Under Lower: %.4f' % pred_under_lower)
            print('\nAdj Obs Loss: %.4f' % adj_next_obs_loss)
            print('\nUL Mean as Obs Loss: %.4f' % ul_mean_as_obs_loss)

            model.train()
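

# A second variant of the dynamics-model experiment: it adds experiment logging,
# options to drop or normalize the extra env-info observation dimensions, an
# optional residual parameterization of the next-observation prediction, and
# reports losses through the tabular logger instead of plain prints. It reuses
# the name `experiment`, so within a single module it would shadow the
# definition above.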
def experiment(exp_specs):
    # Set up logging ----------------------------------------------------------
    exp_id = exp_specs['exp_id']
    exp_prefix = exp_specs['exp_name']
    seed = exp_specs['seed']
    set_seed(seed)
    setup_logger(exp_prefix=exp_prefix, exp_id=exp_id, variant=exp_specs)

    # Load the data -----------------------------------------------------------
    extra_data_path = exp_specs['extra_data_path']
    train_replay_buffer = joblib.load(extra_data_path)['replay_buffer']
    train_replay_buffer.change_max_size_to_cur_size()
    train_replay_buffer._next_obs = train_replay_buffer._next_obs[:,exp_specs['extra_obs_dim']:]
    if exp_specs['remove_env_info']:
        train_replay_buffer._observations = train_replay_buffer._observations[:,exp_specs['extra_obs_dim']:]
    else:
        if exp_specs['normalize_env_info']:
            low, high = exp_specs['env_info_range'][0], exp_specs['env_info_range'][1]
            train_replay_buffer._observations[:,:exp_specs['extra_obs_dim']] -= (low + high)/2.0
            train_replay_buffer._observations[:,:exp_specs['extra_obs_dim']] /= (high - low)/2.0

    print('\nRewards: {} +/- {}'.format(
        np.mean(train_replay_buffer._rewards),
        np.std(train_replay_buffer._rewards)
    ))

    next_obs_mean = np.mean(train_replay_buffer._next_obs, 0)
    next_obs_std = np.std(train_replay_buffer._next_obs, 0)
    print('\nNext Obs:\n{}\n+/-\n{}'.format(
        next_obs_mean,
        next_obs_std
    ))

    print('\nAvg Next Obs Square Norm: {}'.format(
        np.mean(np.linalg.norm(train_replay_buffer._next_obs, axis=1)**2)
    ))

    sample_batch = train_replay_buffer.random_batch(exp_specs['train_batch_size'])
    obs_dim = sample_batch['observations'].shape[-1]
    act_dim = sample_batch['actions'].shape[-1]

    val_replay_buffer = SimpleReplayBuffer(exp_specs['val_set_size'], obs_dim, act_dim)
    val_replay_buffer.set_buffer_from_dict(
        train_replay_buffer.sample_and_remove(exp_specs['val_set_size'])
    )
    if exp_specs['train_from_beginning_transitions']:
        trans_dict = dict(
            observations=train_replay_buffer._observations[:exp_specs['train_set_size']],
            actions=train_replay_buffer._actions[:exp_specs['train_set_size']],
            rewards=train_replay_buffer._rewards[:exp_specs['train_set_size']],
            terminals=train_replay_buffer._terminals[:exp_specs['train_set_size']],
            next_observations=train_replay_buffer._next_obs[:exp_specs['train_set_size']],
        )
        train_replay_buffer.set_buffer_from_dict(trans_dict)
    else:
        train_replay_buffer.set_buffer_from_dict(
            train_replay_buffer.sample_and_remove(exp_specs['train_set_size'])
        )

    # Model Definitions -------------------------------------------------------
    if exp_specs['remove_env_info']:
        output_dim = [obs_dim + 1]
    else:
        output_dim = [obs_dim - exp_specs['extra_obs_dim'] + 1]
    model = GenericMap(
        [obs_dim + act_dim],
        output_dim,
        siamese_input=False,
        siamese_output=False,
        num_hidden_layers=exp_specs['num_hidden_layers'],
        hidden_dim=exp_specs['hidden_dim'],
        act='relu',
        use_bn=True,
        deterministic=True
    )

    model_optim = Adam(model.parameters(), lr=float(exp_specs['lr']))

    # Train -------------------------------------------------------------------
    model.train()
    for iter_num in range(exp_specs['max_iters']):
        model_optim.zero_grad()

        batch = train_replay_buffer.random_batch(exp_specs['train_batch_size'])
        batch = convert_numpy_dict_to_pytorch(batch)
        inputs = Variable(torch.cat([batch['observations'], batch['actions']], -1))
        outputs = Variable(torch.cat([batch['next_observations'], batch['rewards']], -1))

        preds = model([inputs])[0]
        if exp_specs['residual']:
            # residual for observations
            preds = preds + Variable(
                        torch.cat(
                            [
                                batch['observations'][:,exp_specs['extra_obs_dim']:],
                                torch.zeros(exp_specs['train_batch_size'], 1)
                            ],
                        1)
                    )
        
        loss = torch.mean(torch.sum((outputs - preds)**2, -1))

        loss.backward()
        model_optim.step()

        if iter_num % exp_specs['freq_val'] == 0:
            model.eval()

            val_batch = val_replay_buffer.random_batch(exp_specs['val_batch_size'])
            val_batch = convert_numpy_dict_to_pytorch(val_batch)
            inputs = Variable(torch.cat([val_batch['observations'], val_batch['actions']], -1))
            outputs = Variable(torch.cat([val_batch['next_observations'], val_batch['rewards']], -1))

            # print(exp_specs['remove_env_info'])
            # print(inputs)
            # print(outputs)
            # sleep(5)
            
            preds = model([inputs])[0]
            if exp_specs['residual']:
                # residual for observations
                preds = preds + Variable(
                            torch.cat(
                                [
                                    val_batch['observations'][:,exp_specs['extra_obs_dim']:],
                                    torch.zeros(exp_specs['val_batch_size'], 1)  # match the validation batch size
                                ],
                            1)
                        )

            loss = torch.mean(torch.sum((outputs - preds)**2, -1))
            next_obs_loss = torch.mean(torch.sum((outputs[:,:-1] - preds[:,:-1])**2, -1))
            rew_loss = torch.mean(torch.sum((outputs[:,-1:] - preds[:,-1:])**2, -1))

            print('\n')
            print('-'*20)
            logger.record_tabular('Iter', iter_num)
            logger.record_tabular('Loss', loss.item())
            logger.record_tabular('Obs Loss', next_obs_loss.item())
            logger.record_tabular('Rew Loss', rew_loss.item())
            logger.dump_tabular(with_prefix=False, with_timestamp=False)

            model.train()
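

# A minimal entry-point sketch (hypothetical: the original launch script and its
# CLI are not shown here) that loads an experiment specification from a YAML
# file and runs the `experiment` variant defined above.
if __name__ == '__main__':
    import argparse

    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--experiment', required=True,
                        help='path to a YAML experiment specification file')
    args = parser.parse_args()
    with open(args.experiment, 'r') as spec_file:
        exp_specs = yaml.safe_load(spec_file)
    experiment(exp_specs)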