Beispiel #1
0
    def train(self, dataset, test_dataset=None, max_epochs=10000, save_step=1000, print_step=1,  plot_step=1,
              record_stats=False):

        for epoch in range(1, max_epochs + 1):
            stats = self.train_epoch(dataset, epoch)
            test = self.train_epoch(test_dataset, epoch, train=False)
            for k, v in test.items():
                stats['V ' + k] = v
            stats['Test RL'] = self.test_pd(test_dataset)

            if epoch % print_step == 0:
                with logger.prefix('itr #%d | ' % epoch):
                    self.print_diagnostics(stats)

            if epoch % plot_step == 0:
                self.plot_compare(dataset, epoch)
                self.plot_interp(dataset, epoch)
                self.plot_compare(test_dataset, epoch, save_dir='test')
                self.plot_random(dataset, epoch)

            if epoch % save_step == 0 and logger.get_snapshot_dir() is not None:
                self.save(logger.get_snapshot_dir() + '/snapshots/', epoch)

            if record_stats:
                with logger.prefix('itr #%d | ' % epoch):
                    self.log_diagnostics(stats)
                    logger.dump_tabular()

        return stats
Beispiel #2
0
    def train(
            self,
            dataset,
            test_dataset,
            dummy_dataset,
            joint_training,
            max_itr=1000,
            save_step=10,
            train_vae_after_add=10,
            plot_step=0,
            record_stats=True,
            print_step=True,
            start_itr=0,
            add_size=0,
            add_interval=0):
        """Alternate between an explorer pass and several VAE fitting passes.

        Each outer iteration (numbered from 1) optionally snapshots the model,
        calls ``train_explorer`` once, then fits the VAE
        ``train_vae_after_add`` times.

        Args:
            dataset: Main training dataset.
            test_dataset: Main validation dataset.
            dummy_dataset: Dataset containing only data from the current iteration.
            joint_training: If truthy, fit with ``train_vae_joint`` (which also
                receives ``dummy_dataset``); otherwise fit with ``train_vae``
                on ``dataset`` / ``test_dataset``.
            max_itr: Number of outer iterations.
            save_step: Snapshot every ``save_step`` iterations, provided the
                logger has a snapshot directory.
            train_vae_after_add: Number of VAE fitting passes per outer iteration.
            plot_step, record_stats, print_step, start_itr, add_size,
                add_interval: Accepted but not read by this method.
        """
        for outer_itr in range(1, max_itr + 1):
            if outer_itr % save_step == 0:
                snapshot_dir = logger.get_snapshot_dir()
                if snapshot_dir is not None:
                    self.save(snapshot_dir + '/snapshots', outer_itr)
                    # Note: this persists self.dataset.train_data, not the
                    # `dataset` argument.
                    np.save(snapshot_dir + '/snapshots/traindata',
                            self.dataset.train_data)

            # Explorer pass: gather new data and its statistics.
            explorer_stats = self.train_explorer(
                dataset, test_dataset, dummy_dataset, outer_itr)
            with logger.prefix('itr #%d | ' % (outer_itr)):
                self.vae.print_diagnostics(explorer_stats)
            record_tabular(explorer_stats, 'ex_stats.csv')

            # VAE passes: print and record the stats of each one.
            for inner_itr in range(train_vae_after_add):
                if joint_training:
                    fit_stats = self.train_vae_joint(
                        dataset, dummy_dataset, test_dataset, outer_itr,
                        inner_itr)
                else:
                    fit_stats = self.train_vae(
                        dataset, test_dataset, outer_itr, inner_itr)
                with logger.prefix(
                        'itr #%d vae itr #%d | ' % (outer_itr, inner_itr)):
                    self.vae.print_diagnostics(fit_stats)
                record_tabular(fit_stats, 'vae_stats.csv')
Beispiel #3
0
 def train(self):
     """Run iterations ``self.start_itr`` .. ``self.n_itr - 1`` of the training loop.

     Each iteration obtains samples, optionally passes them through
     ``self.alter_sd_fn``, processes them, logs diagnostics, optimizes the
     policy and dumps timing to the tabular log. Afterwards it may plot a
     rollout and save a snapshot, depending on the configured intervals.
     """
     run_started = time.time()
     for itr in range(self.start_itr, self.n_itr):
         itr_started = time.time()
         with logger.prefix('itr #%d | ' % itr):
             logger.log("Obtaining samples...")
             samples = self.obtain_samples(itr)
             # Optional hook that receives the freshly collected samples.
             if self.alter_sd_fn is not None:
                 self.alter_sd_fn(samples, *self.alter_sd_args)

             logger.log("Processing samples...")
             self.process_samples(itr, samples)

             logger.log("Logging diagnostics...")
             self.log_diagnostics(samples['stats'])

             logger.log("Optimizing policy...")
             self.optimize_policy(itr, samples)

             logger.record_tabular('Time', time.time() - run_started)
             logger.record_tabular('ItrTime', time.time() - itr_started)
             logger.dump_tabular(with_prefix=False)

         # Plot a rollout on plot iterations, once past the threshold.
         if itr % self.plot_every == 0:
             if self.plot and itr > self.plot_itr_threshold:
                 rollout(self.policy, self.env_obj, self.max_path_length,
                         plot=True)

         # Snapshot on save iterations when a snapshot directory is configured.
         if itr % self.save_step == 0:
             snapshot_dir = logger.get_snapshot_dir()
             if snapshot_dir is not None:
                 self.save(snapshot_dir + '/snapshots', itr)
Beispiel #4
0
    def optimize_policy(self,
                        itr,
                        samples_data,
                        add_input_fn=None,
                        add_input_input=None,
                        add_loss_fn=None,
                        print=True):
        """Run ``self.epoch`` passes of clipped-surrogate PPO minibatch updates.

        Args:
            itr: Outer iteration number; only used in the diagnostics prefix.
            samples_data: Mapping that provides ``'obs'``, ``'actions'``,
                ``'discount_adv'``, ``'discount_returns'``, ``'log_prob'``
                and, for recurrent policies, ``'states'``.
            add_input_fn: Optional callable applied to
                ``Variable(add_input_input)``; a sample drawn from the returned
                distribution is repeated over time steps and appended to every
                observation. When given, observations are first truncated to
                their leading ``self.obs_dim`` features.
            add_input_input: Input handed to ``add_input_fn``.
            add_loss_fn: Optional extra loss term, called as
                ``add_loss_fn(add_input_dist, add_input, add_input_input)``.
                Requires ``add_input_fn``.
            print: When truthy, print the last minibatch's losses after each
                epoch. (The name shadows the builtin inside this method; it is
                kept for caller compatibility.)

        Returns:
            The total loss of the last minibatch update, or ``None`` if no
            update was performed.

        Raises:
            ValueError: If ``add_loss_fn`` is given without ``add_input_fn``.
        """
        # add_loss_fn consumes the distribution/sample that only add_input_fn
        # produces; previously this combination died on an unbound local.
        if add_loss_fn is not None and add_input_fn is None:
            raise ValueError('add_loss_fn requires add_input_fn to be provided')

        # Bound up front so the return value is defined even when no minibatch
        # runs (self.epoch == 0 or no samples).
        total_loss = None

        # Reshaped to a column once; it is indexed per minibatch below.
        advantages = from_numpy(samples_data['discount_adv'].astype(
            np.float32)).view(-1, 1)
        n_traj = samples_data['obs'].shape[0]
        n_obs = n_traj * self.max_path_length

        # Flatten (trajectory, step) into a single sample axis.
        obs_np = samples_data['obs'][:, :self.max_path_length]
        if add_input_fn is not None:
            obs_np = obs_np[:, :, :self.obs_dim]
        obs = from_numpy(obs_np.astype(np.float32)).view(n_obs, -1)

        actions = samples_data['actions'].view(n_obs, -1).data
        returns = from_numpy(samples_data['discount_returns'].copy()).view(
            -1, 1).float()
        old_action_log_probs = samples_data['log_prob'].view(n_obs, -1).data
        states = samples_data['states'].view(
            samples_data['states'].size()[0], n_obs,
            -1) if self.policy.recurrent() else None

        for epoch_itr in range(self.epoch):
            sampler = BatchSampler(SubsetRandomSampler(range(n_obs)),
                                   self.ppo_batch_size,
                                   drop_last=False)
            for indices in sampler:
                indices = LongTensor(indices)
                obs_batch = Variable(obs[indices])
                actions_batch = actions[indices]
                return_batch = returns[indices]
                old_action_log_probs_batch = old_action_log_probs[indices]
                if states is not None:
                    self.policy.set_state(Variable(states[:, indices]))

                if add_input_fn is not None:
                    # A fresh sample is drawn for every minibatch.
                    add_input_dist = add_input_fn(Variable(add_input_input))
                    add_input = add_input_dist.sample()
                    # Repeat the sample along the time axis, then flatten to
                    # line up with the flattened observations.
                    add_input_rep = torch.unsqueeze(add_input, 1).repeat(
                        1, self.max_path_length, 1).view(n_obs, -1)
                    add_input_batch = add_input_rep[indices]
                    obs_batch = torch.cat([obs_batch, add_input_batch], -1)

                # The baseline sees a detached copy of the observations.
                values = self.baseline.forward(obs_batch.detach())
                action_dist = self.policy.forward(obs_batch)
                action_log_probs = action_dist.log_likelihood(
                    Variable(actions_batch)).unsqueeze(-1)
                dist_entropy = action_dist.entropy().mean()

                ratio = torch.exp(action_log_probs -
                                  Variable(old_action_log_probs_batch))
                adv_targ = Variable(advantages[indices])
                surr1 = ratio * adv_targ
                surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
                                    1.0 + self.clip_param) * adv_targ
                action_loss = -torch.min(
                    surr1,
                    surr2).mean()  # PPO's pessimistic surrogate (L^CLIP)

                value_loss = (Variable(return_batch) - values).pow(2).mean()

                self.optimizer.zero_grad()

                total_loss = (value_loss + action_loss -
                              dist_entropy * self.entropy_bonus)
                if add_loss_fn is not None:
                    total_loss += add_loss_fn(add_input_dist, add_input,
                                              add_input_input)
                total_loss.backward()
                self.optimizer.step()

            # Report only once at least one update has produced losses.
            if print and total_loss is not None:
                stats = {
                    'total loss': get_numpy(total_loss)[0],
                    'action loss': get_numpy(action_loss)[0],
                    'value loss': get_numpy(value_loss)[0],
                    'entropy': get_numpy(dist_entropy)[0]
                }
                with logger.prefix('Train PPO itr %d epoch itr %d | ' %
                                   (itr, epoch_itr)):
                    self.print_diagnostics(stats)

        return total_loss
Beispiel #5
0
    def train_explorer(self, dataset, test_dataset, dummy_dataset, itr):
        """Run MPC, roll out the explorer, grow the datasets and update the explorer.

        Steps, in order:
          1. Run MPC for ``self.max_horizon`` steps from the configured
             initial state, recording every visited state.
          2. Roll out the explorer policy from a random subset of those states.
          3. Append the explorer's observation/action trajectories to
             ``dataset`` / ``test_dataset`` and reset ``dummy_dataset`` to this
             iteration's data.
          4. Use the VAE's ``neg_ll + kl`` on those trajectories as explorer
             reward (skipped when ``itr == 1``) and run a PPO update.
          5. Re-initialize the explorer if its entropy drops below
             ``self.reset_ent``, then save a scatter plot and ``.npy`` dump of
             the final observations.

        Args:
            dataset: Training dataset; receives new samples and is used for plotting.
            test_dataset: Validation dataset; receives new samples.
            dummy_dataset: Cleared, then filled with this iteration's samples.
            itr: Outer iteration number, used for logging, file names and to
                skip rewarding/optimizing on iteration 1.

        Returns:
            ``ex_trajs['stats']`` with ``'MPC Actual'`` and ``'Final Reward'`` added.
        """
        # load fixed initial state and goals from config
        init_state = self.block_config[0]
        goals = np.array(self.block_config[1])

        # functions for computing the reward and initializing the reward state (rstate)
        # rstate is used to keep track of things such as which goal you are currently on
        reward_fn, init_rstate = self.reward_fn

        # total actual reward collected by MPC agent so far
        total_mpc_rew = np.zeros(self.mpc_batch)

        # keep track of states visited by MPC to initialize the explorer from
        all_inits = []

        # current state of the MPC batch
        cur_state = np.array([init_state] * self.mpc_batch)

        # initialize the reward state for the mpc batch
        rstate = init_rstate(self.mpc_batch)

        # for visualization purposes
        mpc_preds = []
        mpc_actual = []
        mpc_span = []
        rstates = []

        # Perform MPC over max_horizon
        for T in range(self.max_horizon):
            print(T)

            # for goal visualization
            rstates.append(rstate)

            # rollout imaginary trajectories using state decoder
            rollouts = self.mpc(cur_state,
                                min(self.plan_horizon,
                                    self.max_horizon - T), self.mpc_explore,
                                self.mpc_explore_batch, reward_fn, rstate)

            # get first latent of best trajectory for each batch
            np_latents = rollouts[2][:, 0]

            # rollout the first latent in simulator
            mpc_traj = self.sampler_mpc.obtain_samples(self.mpc_batch *
                                                       self.max_path_length,
                                                       self.max_path_length,
                                                       np_to_var(np_latents),
                                                       reset_args=cur_state)

            # update reward and reward state based on trajectory from simulator
            mpc_rew, rstate = self.eval_rewards(mpc_traj['obs'], reward_fn,
                                                rstate)

            # for logging and visualization purposes
            futures = rollouts[0] + total_mpc_rew
            total_mpc_rew += mpc_rew
            mpc_preds.append(rollouts[1][0])
            mpc_span.append(rollouts[3])
            mpc_stats = {
                'mean futures': np.mean(futures),
                'std futures': np.std(futures),
                'mean actual': np.mean(total_mpc_rew),
                'std actual': np.std(total_mpc_rew),
            }
            mpc_actual.append(mpc_traj['obs'][0])
            with logger.prefix('itr #%d mpc step #%d | ' % (itr, T)):
                self.vae.print_diagnostics(mpc_stats)
            record_tabular(mpc_stats, 'mpc_stats.csv')

            # add current state to list of states explorer can initialize from
            all_inits.append(cur_state)

            # update current state to current state of simulator
            cur_state = mpc_traj['obs'][:, -1]

        # for visualization
        for idx, (actual, pred, rs, span) in enumerate(
                zip(mpc_actual, mpc_preds, rstates, mpc_span)):
            dataset.plot_pd_compare(
                [actual, pred, span[:100], span[:100, :dataset.path_len]],
                ['actual', 'pred', 'imagined', 'singlestep'],
                itr,
                save_dir='mpc_match',
                name='Pred' + str(idx),
                goals=goals,
                goalidx=rs[0])

        # compute reward at final state, for some tasks that care about final state reward
        final_reward, _ = reward_fn(cur_state, rstate)
        print(total_mpc_rew)
        print(final_reward)

        # randomly select states for explorer to explore
        # (sampling with replacement only when more are requested than exist)
        start_states = np.concatenate(all_inits, axis=0)
        start_states = start_states[np.random.choice(
            start_states.shape[0],
            self.rand_per_mpc_step,
            replace=self.rand_per_mpc_step > start_states.shape[0])]

        # run the explorer from those states
        explore_len = ((self.max_path_length + 1) * self.mpc_explore_len) - 1
        self.policy_ex_algo.max_path_length = explore_len
        ex_trajs = self.sampler_ex.obtain_samples(start_states.shape[0] *
                                                  explore_len,
                                                  explore_len,
                                                  None,
                                                  reset_args=start_states)

        # Now concat actions taken by explorer with observations for adding to the dataset
        trajs = ex_trajs['obs']
        obs = trajs[:, -1]
        if hasattr(self.action_space,
                   'shape') and len(self.action_space.shape) > 0:
            acts = get_numpy(ex_trajs['actions'])
        else:
            # convert discrete actions into onehot
            act_idx = get_numpy(ex_trajs['actions'])
            acts = np.zeros(
                (trajs.shape[0], trajs.shape[1] - 1, dataset.action_dim))
            # writes through to `acts` as long as reshape returns a view
            acts_reshape = acts.reshape((-1, dataset.action_dim))
            acts_reshape[range(acts_reshape.shape[0]),
                         act_idx.reshape(-1)] = 1.0

        # concat actions with obs (the last action is duplicated so the time
        # axes have equal length)
        acts = np.concatenate((acts, acts[:, -1:, :]), 1)
        trajacts = np.concatenate((ex_trajs['obs'], acts), axis=-1)
        trajacts = trajacts.reshape(
            (-1, self.max_path_length + 1, trajacts.shape[-1]))

        # compute train/val split
        ntrain = min(int(0.9 * trajacts.shape[0]),
                     dataset.buffer_size // self.add_frac)
        if dataset.n < dataset.batch_size and ntrain < dataset.batch_size:
            ntrain = dataset.batch_size
        nvalid = min(trajacts.shape[0] - ntrain,
                     test_dataset.buffer_size // self.add_frac)
        if test_dataset.n < test_dataset.batch_size and nvalid < test_dataset.batch_size:
            nvalid = test_dataset.batch_size

        print("Adding ", ntrain, ", Valid: ", nvalid)

        # NOTE(review): if nvalid can ever be 0, `trajacts[-nvalid:]` selects the
        # whole array and `trajacts[:-nvalid]` selects nothing — confirm that
        # the split above always yields nvalid >= 1.
        dataset.add_samples(trajacts[:ntrain].reshape((ntrain, -1)))
        test_dataset.add_samples(trajacts[-nvalid:].reshape((nvalid, -1)))

        # dummy dataset stores only data from this iteration
        dummy_dataset.clear()
        dummy_dataset.add_samples(trajacts[:-nvalid].reshape(
            (trajacts.shape[0] - nvalid, -1)))

        # compute negative ELBO on trajectories of explorer
        neg_elbos = []
        cur_batch = from_numpy(trajacts).float()
        for i in range(0, trajacts.shape[0], self.batch_size):
            # only the likelihood and KL terms are used here
            _, neg_ll, kl, _, _ = self.vae.forward_batch(
                cur_batch[i:i + self.batch_size])
            neg_elbo = (get_numpy(neg_ll) + get_numpy(kl))
            neg_elbos.append(neg_elbo)

        # reward the explorer
        rewards = np.zeros_like(ex_trajs['rewards'])
        neg_elbos = np.concatenate(neg_elbos, axis=0)
        neg_elbos = neg_elbos.reshape((rewards.shape[0], -1))
        # just not on the first iteration, since the VAE hasn't been fitted yet
        if itr != 1:
            # reward indices: every (max_path_length + 1)-th step plus the last step
            rewidx = list(
                range(self.max_path_length, explore_len,
                      self.max_path_length + 1)) + [explore_len - 1]
            for i in range(rewards.shape[0]):
                rewards[i, rewidx] = neg_elbos[i]

            # add in true reward to explorer if desired
            if self.true_reward_scale != 0:
                rstate = init_rstate(rewards.shape[0])
                for oidx in range(rewards.shape[1]):
                    r, rstate = reward_fn(ex_trajs['obs'][:, oidx], rstate)
                    rewards[:, oidx] += r * self.true_reward_scale

        ex_trajs['rewards'] = rewards

        # train explorer using PPO with neg elbo
        self.policy_ex_algo.process_samples(0, ex_trajs)
        if itr != 1:
            self.policy_ex_algo.optimize_policy(0, ex_trajs)
        ex_trajs['stats']['MPC Actual'] = np.mean(total_mpc_rew)
        ex_trajs['stats']['Final Reward'] = np.mean(final_reward)

        # reset explorer if necessary
        if ex_trajs['stats']['Entropy'] < self.reset_ent:
            if hasattr(self.policy_ex, "prob_network"):
                self.policy_ex.prob_network.apply(xavier_init)
            else:
                self.policy_ex.apply(xavier_init)
                self.policy_ex.log_var_network.params_var.data = self.policy_ex.log_var_network.param_init

        # for visualization purposes: scatter pairs of final-observation
        # coordinates and dump the raw observations. Skipped when the logger
        # has no snapshot directory, matching the guard used before saving
        # snapshots elsewhere.
        snapshot_dir = logger.get_snapshot_dir()
        if snapshot_dir is not None:
            path = snapshot_dir + '/final_dist'
            os.makedirs(path, exist_ok=True)

            colors = ['purple', 'magenta', 'green', 'black', 'yellow', 'black']
            fig, ax = plt.subplots(3, 2, figsize=(10, 10))
            try:
                for i in range(6):
                    if i * 2 + 1 < obs.shape[1]:
                        axx = ax[i // 2][i % 2]
                        if i == 5:
                            axx.scatter(obs[:, -3],
                                        obs[:, -2],
                                        color=colors[i],
                                        s=10)
                        else:
                            axx.scatter(obs[:, i * 2],
                                        obs[:, i * 2 + 1],
                                        color=colors[i],
                                        s=10)
                        axx.set_xlim(-3, 3)
                        axx.set_ylim(-3, 3)
                fig.savefig('%s/%d.png' % (path, itr))
            finally:
                # Close the figure so repeated calls do not accumulate open figures.
                plt.close(fig)
            np.save(path + "/" + str(itr), obs)

        return ex_trajs['stats']