Example #1
class GaussianBidirectionalNetwork(GaussianNetwork):
    def __init__(self, input_dim, hidden_dim, num_layers, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        # self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim,
                            hidden_dim,
                            num_layers,
                            bidirectional=True)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.modules.extend([self.rnn])
        # self.linear = nn.Linear(hidden_dim * 2, output_dim)
    def reset(self, x):
        self.rnn.init_hidden(x.size()[1])

    def forward(self, x):
        self.reset(x)
        lstm_out = self.rnn.forward(x)
        lstm_mean = torch.mean(lstm_out, dim=0)
        # output = self.linear(lstm_mean)
        mean, log_var = self.mean_network(lstm_mean), self.log_var_network(
            lstm_mean)
        dist = Normal(mean, log_var=log_var)
        return dist

    def recurrent(self):
        return True
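
A minimal plain-PyTorch sketch of the same pattern, for reference: a bidirectional LSTM whose time-averaged output feeds separate mean and log-variance heads. The class name, the linear heads, and the use of torch.distributions.Normal are illustrative assumptions; the original code instead relies on the custom RNN wrapper and the mean/log-var networks supplied through the GaussianNetwork base class.

import torch
import torch.nn as nn
from torch.distributions import Normal

class BiLSTMGaussianEncoderSketch(nn.Module):
    """Hypothetical stand-in for GaussianBidirectionalNetwork, plain torch only."""
    def __init__(self, input_dim, hidden_dim, num_layers, latent_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, bidirectional=True)
        # a bidirectional LSTM emits 2 * hidden_dim features per time step
        self.mean_head = nn.Linear(2 * hidden_dim, latent_dim)
        self.log_var_head = nn.Linear(2 * hidden_dim, latent_dim)

    def forward(self, x):                      # x: (seq_len, batch, input_dim)
        lstm_out, _ = self.lstm(x)
        pooled = lstm_out.mean(dim=0)          # average over time, as in Example #1
        mean = self.mean_head(pooled)
        log_var = self.log_var_head(pooled)
        return Normal(mean, torch.exp(0.5 * log_var))   # std = exp(log_var / 2)

# usage: encode 4 trajectories of length 20 with 10-dim observations
enc = BiLSTMGaussianEncoderSketch(input_dim=10, hidden_dim=64, num_layers=2, latent_dim=8)
dist = enc(torch.randn(20, 4, 10))
z = dist.rsample()                             # (4, 8) latent sample
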
Example #2
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 num_layers,
                 output_dim,
                 log_var_network,
                 init=xavier_init,
                 scale_final=False,
                 min_var=1e-4,
                 obs_filter=None):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.linear = nn.Linear(hidden_dim, output_dim)

        self.log_var_network = log_var_network
        self.modules = [self.rnn, self.linear, self.log_var_network]

        self.obs_filter = obs_filter
        self.min_log_var = np_to_var(
            np.log(np.array([min_var])).astype(np.float32))

        self.apply(init)
        # self.apply(weights_init_mlp)
        if scale_final:
            if hasattr(self.mean_network, 'network'):
                self.mean_network.network.finallayer.weight.data.mul_(0.01)
Example #3
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.linear = nn.Linear(hidden_dim, output_dim)
        self.modules = [self.rnn, self.linear]
Example #4
    def __init__(self, input_dim, hidden_dim, num_layers, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        # self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim,
                            hidden_dim,
                            num_layers,
                            bidirectional=True)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.modules.extend([self.rnn])
Example #5
class LSTMPolicy(ModuleContainer):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.linear = nn.Linear(hidden_dim, output_dim)
        self.modules = [self.rnn, self.linear]

    def reset(self, bs):
        self.rnn.init_hidden(bs)

    def forward(self, x):
        if len(x.size()) == 2:
            x = x.unsqueeze(0)
        lstm_out = self.rnn.forward(x)
        lstm_reshape = lstm_out.view((-1, self.hidden_dim))
        output = self.softmax(self.linear(lstm_reshape))
        dist = Categorical(output)
        return dist

    def set_state(self, state):
        self.rnn.set_state(state)

    def get_state(self):
        return self.rnn.get_state()

    def recurrent(self):
        return True
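
The policy above flattens the per-step LSTM outputs to (seq_len * batch, hidden_dim), applies a linear layer plus softmax, and wraps the result in a Categorical distribution. Below is a self-contained plain-PyTorch sketch of that flow, assuming the custom RNN wrapper mainly manages the LSTM hidden state; the names here are hypothetical and not part of the original code.

import torch
import torch.nn as nn
from torch.distributions import Categorical

class LSTMCategoricalSketch(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):                  # x: (seq_len, batch, input_dim)
        if x.dim() == 2:                   # allow a single step of shape (batch, input_dim)
            x = x.unsqueeze(0)
        lstm_out, _ = self.lstm(x)         # (seq_len, batch, hidden_dim)
        flat = lstm_out.reshape(-1, self.hidden_dim)
        probs = torch.softmax(self.linear(flat), dim=-1)
        return Categorical(probs=probs)

# usage: one action distribution per (step, batch) element
policy = LSTMCategoricalSketch(input_dim=12, hidden_dim=32, num_layers=2, output_dim=5)
dist = policy(torch.randn(7, 3, 12))
actions = dist.sample()                    # shape (21,)
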
Example #6
def run_task(vv):
    set_gpu_mode(vv['gpu'])
    env_name = None

    goals = np.array(vv['goals'])
    env = lambda: SwimmerEnv(
        vv['frame_skip'], goals=goals, include_rstate=False)

    obs_dim = int(env().observation_space.shape[0])
    action_dim = int(env().action_space.shape[0])
    vv['block_config'] = [env().reset().tolist(), vv['goals']]
    print(vv['block_config'])

    path_len = vv['path_len']
    data_path = vv['initial_data_path']
    use_actions = vv['use_actions']

    # create a dummy dataset since we initialize with no data
    dummy = np.zeros((1, path_len + 1, obs_dim + action_dim))
    train_data, test_data = dummy, dummy
    train_dataset = WheeledContDataset(data_path=data_path,
                                       raw_data=train_data,
                                       obs_dim=obs_dim,
                                       action_dim=action_dim,
                                       path_len=path_len,
                                       env_id='Playpen',
                                       normalize=False,
                                       use_actions=use_actions,
                                       batch_size=vv['batch_size'],
                                       buffer_size=vv['buffer_size'],
                                       pltidx=[-2, -1])

    test_dataset = WheeledContDataset(data_path=data_path,
                                      raw_data=train_data,
                                      obs_dim=obs_dim,
                                      action_dim=action_dim,
                                      path_len=path_len,
                                      env_id='Playpen',
                                      normalize=False,
                                      use_actions=use_actions,
                                      batch_size=vv['batch_size'] // 9,
                                      buffer_size=vv['buffer_size'] // 9,
                                      pltidx=[-2, -1])
    dummy_dataset = WheeledContDataset(data_path=data_path,
                                       raw_data=train_data,
                                       obs_dim=obs_dim,
                                       action_dim=action_dim,
                                       path_len=path_len,
                                       env_id='Playpen',
                                       normalize=False,
                                       use_actions=use_actions,
                                       batch_size=vv['batch_size'],
                                       buffer_size=vv['buffer_size'],
                                       pltidx=[-2, -1])

    train_dataset.clear()
    test_dataset.clear()
    dummy_dataset.clear()

    latent_dim = vv['latent_dim']
    policy_rnn_hidden_dim = vv['policy_rnn_hidden_dim']
    rnn_hidden_dim = vv['decoder_rnn_hidden_dim']

    step_dim = obs_dim
    # note: fixed at 256 here, overriding vv['decoder_rnn_hidden_dim'] read above
    rnn_hidden_dim = 256

    # build encoder
    if vv['encoder_type'] == 'mlp':
        encoder = GaussianNetwork(mean_network=MLP(
            (path_len + 1) * step_dim,
            latent_dim,
            hidden_sizes=vv['encoder_hidden_sizes'],
            hidden_act=nn.ReLU),
                                  log_var_network=MLP(
                                      (path_len + 1) * step_dim, latent_dim))
    elif vv['encoder_type'] == 'lstm':
        encoder = GaussianBidirectionalNetwork(
            input_dim=step_dim,
            hidden_dim=rnn_hidden_dim,
            num_layers=2,
            mean_network=MLP(2 * rnn_hidden_dim, latent_dim),
            log_var_network=MLP(2 * rnn_hidden_dim, latent_dim))

    # build state decoder
    if vv['decoder_var_type'] == 'param':
        decoder_log_var_network = Parameter(latent_dim,
                                            step_dim,
                                            init=np.log(0.1))
    else:
        decoder_log_var_network = MLP(rnn_hidden_dim, step_dim)
    if vv['decoder_type'] == 'grnn':
        decoder = GaussianRecurrentNetwork(
            recurrent_network=RNN(
                nn.LSTM(step_dim + latent_dim, rnn_hidden_dim),
                rnn_hidden_dim),
            mean_network=MLP(rnn_hidden_dim,
                             step_dim,
                             hidden_sizes=vv['decoder_hidden_sizes'],
                             hidden_act=nn.ReLU),
            log_var_network=decoder_log_var_network,
            path_len=path_len,
            output_dim=step_dim,
            min_var=1e-4,
        )
    elif vv['decoder_type'] == 'gmlp':
        decoder = GaussianNetwork(
            mean_network=MLP(latent_dim,
                             path_len * step_dim,
                             hidden_sizes=vv['decoder_hidden_sizes'],
                             hidden_act=nn.ReLU),
            log_var_network=Parameter(latent_dim,
                                      path_len * step_dim,
                                      init=np.log(0.1)),
            min_var=1e-4)
    elif vv['decoder_type'] == 'mixedrnn':
        gauss_output_dim = 10
        cat_output_dim = 5
        decoder = MixedRecurrentNetwork(
            recurrent_network=RNN(
                nn.LSTM(step_dim + latent_dim, rnn_hidden_dim),
                rnn_hidden_dim),
            mean_network=MLP(rnn_hidden_dim,
                             gauss_output_dim,
                             hidden_sizes=vv['decoder_hidden_sizes'],
                             hidden_act=nn.ReLU),
            prob_network=MLP(rnn_hidden_dim,
                             cat_output_dim,
                             final_act=nn.Softmax),
            log_var_network=Parameter(latent_dim,
                                      gauss_output_dim,
                                      init=np.log(0.1)),
            path_len=path_len,
            output_dim=step_dim,
            min_var=1e-4,
            gaussian_output_dim=gauss_output_dim,
            cat_output_dim=cat_output_dim)

    # policy decoder
    if vv['policy_type'] == 'grnn':
        policy = GaussianRecurrentPolicy(
            recurrent_network=RNN(
                nn.LSTM(obs_dim + latent_dim, policy_rnn_hidden_dim),
                policy_rnn_hidden_dim),
            mean_network=MLP(policy_rnn_hidden_dim,
                             action_dim,
                             hidden_act=nn.ReLU),
            log_var_network=Parameter(obs_dim + latent_dim,
                                      action_dim,
                                      init=np.log(1)),
            path_len=path_len,
            output_dim=action_dim)

    elif vv['policy_type'] == 'gmlp':
        policy = GaussianNetwork(
            mean_network=MLP(obs_dim + latent_dim,
                             action_dim,
                             hidden_sizes=vv['policy_hidden_sizes'],
                             hidden_act=nn.ReLU),
            log_var_network=Parameter(obs_dim + latent_dim,
                                      action_dim,
                                      init=np.log(1)))
        policy_ex = GaussianNetwork(mean_network=MLP(
            obs_dim,
            action_dim,
            hidden_sizes=vv['policy_hidden_sizes'],
            hidden_act=nn.ReLU),
                                    log_var_network=Parameter(obs_dim,
                                                              action_dim,
                                                              init=np.log(1)))
    elif vv['policy_type'] == 'crnn':
        policy = RecurrentCategoricalPolicy(
            recurrent_network=RNN(
                nn.LSTM(obs_dim + latent_dim, policy_rnn_hidden_dim),
                policy_rnn_hidden_dim),
            prob_network=MLP(policy_rnn_hidden_dim,
                             action_dim,
                             hidden_sizes=vv['policy_hidden_sizes'],
                             final_act=nn.Softmax),
            path_len=path_len,
            output_dim=action_dim)
    elif vv['policy_type'] == 'cmlp':
        policy = CategoricalNetwork(prob_network=MLP(obs_dim + latent_dim,
                                                     action_dim,
                                                     hidden_sizes=(400, 300,
                                                                   200),
                                                     hidden_act=nn.ReLU,
                                                     final_act=nn.Softmax),
                                    output_dim=action_dim)
        policy_ex = CategoricalNetwork(prob_network=MLP(obs_dim,
                                                        action_dim,
                                                        hidden_sizes=(400, 300,
                                                                      200),
                                                        hidden_act=nn.ReLU,
                                                        final_act=nn.Softmax),
                                       output_dim=action_dim)
    elif vv['policy_type'] == 'lstm':
        policy = LSTMPolicy(input_dim=obs_dim + latent_dim,
                            hidden_dim=rnn_hidden_dim,
                            num_layers=2,
                            output_dim=action_dim)

    # VAE with behavioral cloning
    vae = TrajVAEBC(encoder=encoder,
                    decoder=decoder,
                    latent_dim=latent_dim,
                    step_dim=step_dim,
                    feature_dim=train_dataset.obs_dim,
                    env=env,
                    path_len=train_dataset.path_len,
                    init_kl_weight=vv['kl_weight'],
                    max_kl_weight=vv['kl_weight'],
                    kl_mul=1.03,
                    loss_type=vv['vae_loss_type'],
                    lr=vv['vae_lr'],
                    obs_dim=obs_dim,
                    act_dim=action_dim,
                    policy=policy,
                    bc_weight=vv['bc_weight'])

    # zero baseline due to constantly changing rewards
    baseline = ZeroBaseline()
    policy_algo = PPO(env,
                      env_name,
                      policy,
                      baseline=baseline,
                      obs_dim=obs_dim,
                      action_dim=action_dim,
                      max_path_length=path_len,
                      center_adv=True,
                      optimizer=optim.Adam(policy.get_params(),
                                           vv['policy_lr'],
                                           eps=1e-5),
                      use_gae=vv['use_gae'],
                      epoch=10,
                      ppo_batch_size=200)

    baseline_ex = ZeroBaseline()
    policy_ex_algo = PPO(env,
                         env_name,
                         policy_ex,
                         baseline=baseline_ex,
                         obs_dim=obs_dim,
                         action_dim=action_dim,
                         max_path_length=path_len,
                         center_adv=True,
                         optimizer=optim.Adam(policy_ex.get_params(),
                                              vv['policy_lr'],
                                              eps=1e-5),
                         use_gae=vv['use_gae'],
                         epoch=10,
                         ppo_batch_size=200,
                         entropy_bonus=vv['entropy_bonus'])

    # for loading the models from a saved state
    if vv['load_models_dir'] is not None:
        dir = getcwd() + "/research/lang/traj2vecv3_jd/" + vv['load_models_dir']
        itr = vv['load_models_idx']
        encoder.load_state_dict(torch.load(dir + '/encoder_%d.pkl' % itr))
        decoder.load_state_dict(torch.load(dir + '/decoder_%d.pkl' % itr))
        policy.load_state_dict(torch.load(dir + '/policy_%d.pkl' % itr))
        policy_ex.load_state_dict(torch.load(dir + '/policy_ex_%d.pkl' % itr))
        vae.optimizer.load_state_dict(
            torch.load(dir + '/vae_optimizer_%d.pkl' % itr))
        policy_algo.optimizer.load_state_dict(
            torch.load(dir + '/policy_optimizer_%d.pkl' % itr))

    # reward function for MPC
    rf = lambda obs, rstate: reward_fn(obs, rstate, goals, 3)
    mpc_explore = 4000
    if vv['path_len'] <= 50:
        mpc_explore *= 2
    # main algorithm launcher, includes mpc controller and exploration
    vaepd = VAEPDEntropy(env,
                         env_name,
                         policy,
                         policy_ex,
                         encoder,
                         decoder,
                         path_len,
                         obs_dim,
                         action_dim,
                         step_dim,
                         policy_algo,
                         policy_ex_algo,
                         train_dataset,
                         latent_dim,
                         vae,
                         batch_size=400,
                         block_config=vv['block_config'],
                         plan_horizon=vv['mpc_plan'],
                         max_horizon=vv['mpc_max'],
                         mpc_batch=vv['mpc_batch'],
                         rand_per_mpc_step=vv['mpc_explore_step'],
                         mpc_explore=mpc_explore,
                         mpc_explore_batch=1,
                         reset_ent=vv['reset_ent'],
                         vae_train_steps=vv['vae_train_steps'],
                         mpc_explore_len=vv['mpc_explore_len'],
                         true_reward_scale=vv['true_reward_scale'],
                         discount_factor=vv['discount_factor'],
                         reward_fn=(rf, init_rstate))

    vaepd.train(train_dataset,
                test_dataset=test_dataset,
                dummy_dataset=dummy_dataset,
                plot_step=10,
                max_itr=vv['max_itr'],
                record_stats=True,
                print_step=1000,
                save_step=2,
                start_itr=0,
                train_vae_after_add=vv['train_vae_after_add'],
                joint_training=vv['joint_training'])
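
run_task is driven entirely by the vv dict. The sketch below lists the keys this example actually reads; every value is a placeholder guess for illustration only, not a setting from the original experiments.

vv = {
    'gpu': False,
    'goals': [[1.0, 0.0], [0.0, 1.0]],      # passed to SwimmerEnv as goal positions
    'frame_skip': 5,
    'path_len': 50,
    'initial_data_path': None,
    'use_actions': True,
    'batch_size': 450,                       # test dataset uses batch_size // 9
    'buffer_size': 9000,
    'latent_dim': 8,
    'policy_rnn_hidden_dim': 128,
    'decoder_rnn_hidden_dim': 256,           # run_task overrides rnn_hidden_dim to 256 anyway
    'encoder_type': 'lstm',                  # 'mlp' or 'lstm'
    'encoder_hidden_sizes': (400, 300),      # used only by the 'mlp' encoder
    'decoder_var_type': 'param',             # 'param', or anything else for an MLP log-var head
    'decoder_type': 'grnn',                  # 'grnn', 'gmlp', or 'mixedrnn'
    'decoder_hidden_sizes': (400, 300),
    'policy_type': 'gmlp',                   # 'gmlp' or 'cmlp' also build policy_ex, used below
    'policy_hidden_sizes': (400, 300),
    'kl_weight': 0.1,
    'vae_loss_type': 'mse',                  # placeholder; passed straight to TrajVAEBC
    'vae_lr': 1e-3,
    'bc_weight': 1.0,
    'policy_lr': 3e-4,
    'use_gae': True,
    'entropy_bonus': 0.01,
    'load_models_dir': None,                 # set to a run directory to restore checkpoints
    'load_models_idx': 0,
    'mpc_plan': 1,
    'mpc_max': 5,
    'mpc_batch': 32,
    'mpc_explore_step': 20,
    'mpc_explore_len': 1,
    'reset_ent': False,
    'vae_train_steps': 100,
    'true_reward_scale': 0.0,
    'discount_factor': 0.99,
    'max_itr': 100,
    'train_vae_after_add': 10,
    'joint_training': False,
}
# vv['block_config'] is filled in by run_task itself before it is used
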
Example #7
class GaussianLSTMPolicy(ModuleContainer):
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 num_layers,
                 output_dim,
                 log_var_network,
                 init=xavier_init,
                 scale_final=False,
                 min_var=1e-4,
                 obs_filter=None):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.rnn = RNN(self.lstm, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.linear = nn.Linear(hidden_dim, output_dim)

        self.log_var_network = log_var_network
        self.modules = [self.rnn, self.linear, self.log_var_network]

        self.obs_filter = obs_filter
        self.min_log_var = np_to_var(
            np.log(np.array([min_var])).astype(np.float32))

        self.apply(init)
        # self.apply(weights_init_mlp)
        if scale_final:
            if hasattr(self.mean_network, 'network'):
                self.mean_network.network.finallayer.weight.data.mul_(0.01)

    def forward(self, x):
        if self.obs_filter is not None:
            x.data = self.obs_filter(x.data)
        if len(x.size()) == 2:
            x = x.unsqueeze(0)
        lstm_out = self.rnn.forward(x)
        lstm_reshape = lstm_out.view((-1, self.hidden_dim))
        mean = self.softmax(self.linear(lstm_reshape))

        log_var = self.log_var_network(x.contiguous().view((-1, x.shape[-1])))
        log_var = torch.max(self.min_log_var, log_var)
        # TODO Limit log var
        dist = Normal(mean=mean, log_var=log_var)
        return dist

    def reset(self, bs):
        self.rnn.init_hidden(bs)

    def set_state(self, state):
        self.rnn.set_state(state)

    def get_state(self):
        return self.rnn.get_state()

    def recurrent(self):
        return True
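
The forward pass above floors the predicted log-variance at log(min_var) with torch.max before building the Normal, so the policy variance can never collapse below min_var. A minimal standalone illustration of that clamp (the example tensors are made up):

import numpy as np
import torch

min_var = 1e-4
min_log_var = torch.tensor(np.log(np.array([min_var])).astype(np.float32))

raw_log_var = torch.tensor([[-12.0, -3.0, 0.5]])   # e.g. output of log_var_network
clamped = torch.max(min_log_var, raw_log_var)      # elementwise max, broadcasting the floor
# clamped == tensor([[-9.2103, -3.0000,  0.5000]])
std = torch.exp(0.5 * clamped)                     # variance stays at or above min_var
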
Example #8
def run_task(vv):
    set_gpu_mode(vv['gpu'])
    env_name = vv['env_name']
    env = make_env(env_name,
                   1,
                   0,
                   '/tmp/gym',
                   kwargs=dict(border=vv['block_config'][2]))
    obs_dim = int(env().observation_space.shape[0])
    action_dim = int(env().action_space.n)

    path_len = vv['path_len']
    data_path = None
    # True so that behavioral cloning has access to actions
    use_actions = True

    # create a dummy dataset since we initialize with no data
    dummy = np.zeros((1, path_len + 1, obs_dim + action_dim))
    train_data, test_data = dummy, dummy

    train_dataset = PlayPenContDataset(data_path=data_path,
                                       raw_data=train_data,
                                       obs_dim=obs_dim,
                                       action_dim=action_dim,
                                       path_len=path_len,
                                       env_id='Playpen',
                                       normalize=False,
                                       use_actions=use_actions,
                                       batch_size=vv['batch_size'],
                                       buffer_size=vv['buffer_size'])
    #validation set for vae training
    test_dataset = PlayPenContDataset(data_path=data_path,
                                      raw_data=train_data,
                                      obs_dim=obs_dim,
                                      action_dim=action_dim,
                                      path_len=path_len,
                                      env_id='Playpen',
                                      normalize=False,
                                      use_actions=use_actions,
                                      batch_size=vv['batch_size'] // 9,
                                      buffer_size=vv['buffer_size'] // 9)

    #this holds the data from the latest iteration for joint training
    dummy_dataset = PlayPenContDataset(data_path=data_path,
                                       raw_data=train_data,
                                       obs_dim=obs_dim,
                                       action_dim=action_dim,
                                       path_len=path_len,
                                       env_id='Playpen',
                                       normalize=False,
                                       use_actions=use_actions,
                                       batch_size=vv['batch_size'],
                                       buffer_size=vv['buffer_size'])

    train_dataset.clear()
    test_dataset.clear()
    dummy_dataset.clear()

    latent_dim = vv['latent_dim']
    rnn_hidden_dim = vv['decoder_rnn_hidden_dim']

    step_dim = obs_dim

    # build encoder
    if vv['encoder_type'] == 'mlp':
        encoder = GaussianNetwork(mean_network=MLP(
            (path_len + 1) * step_dim,
            latent_dim,
            hidden_sizes=vv['encoder_hidden_sizes'],
            hidden_act=nn.ReLU),
                                  log_var_network=MLP(
                                      (path_len + 1) * step_dim, latent_dim))
    elif vv['encoder_type'] == 'lstm':
        encoder = GaussianBidirectionalNetwork(
            input_dim=step_dim,
            hidden_dim=rnn_hidden_dim,
            num_layers=2,
            mean_network=MLP(2 * rnn_hidden_dim, latent_dim),
            log_var_network=MLP(2 * rnn_hidden_dim, latent_dim))

    # build state decoder
    if vv['decoder_var_type'] == 'param':
        decoder_log_var_network = Parameter(latent_dim,
                                            step_dim,
                                            init=np.log(0.1))
    else:
        decoder_log_var_network = MLP(rnn_hidden_dim, step_dim)
    if vv['decoder_type'] == 'grnn':
        decoder = GaussianRecurrentNetwork(
            recurrent_network=RNN(
                nn.LSTM(step_dim + latent_dim, rnn_hidden_dim),
                rnn_hidden_dim),
            mean_network=MLP(rnn_hidden_dim,
                             step_dim,
                             hidden_sizes=vv['decoder_hidden_sizes'],
                             hidden_act=nn.ReLU),
            #log_var_network=Parameter(latent_dim, step_dim, init=np.log(0.1)),
            log_var_network=decoder_log_var_network,
            path_len=path_len,
            output_dim=step_dim,
            min_var=1e-4,
        )
    elif vv['decoder_type'] == 'gmlp':
        decoder = GaussianNetwork(
            mean_network=MLP(latent_dim,
                             path_len * step_dim,
                             hidden_sizes=vv['decoder_hidden_sizes'],
                             hidden_act=nn.ReLU),
            log_var_network=Parameter(latent_dim,
                                      path_len * step_dim,
                                      init=np.log(0.1)),
            min_var=1e-4)
    elif vv['decoder_type'] == 'mixedrnn':
        gauss_output_dim = 10
        cat_output_dim = 5
        decoder = MixedRecurrentNetwork(
            recurrent_network=RNN(
                nn.LSTM(step_dim + latent_dim, rnn_hidden_dim),
                rnn_hidden_dim),
            mean_network=MLP(rnn_hidden_dim,
                             gauss_output_dim,
                             hidden_sizes=vv['decoder_hidden_sizes'],
                             hidden_act=nn.ReLU),
            prob_network=MLP(rnn_hidden_dim,
                             cat_output_dim,
                             final_act=nn.Softmax),
            log_var_network=Parameter(latent_dim,
                                      gauss_output_dim,
                                      init=np.log(0.1)),
            path_len=path_len,
            output_dim=step_dim,
            min_var=1e-4,
            gaussian_output_dim=gauss_output_dim,
            cat_output_dim=cat_output_dim)

    # policy decoder
    policy = CategoricalNetwork(prob_network=MLP(obs_dim + latent_dim,
                                                 action_dim,
                                                 hidden_sizes=(400, 300, 200),
                                                 hidden_act=nn.ReLU,
                                                 final_act=nn.Softmax),
                                output_dim=action_dim)

    # explorer policy
    policy_ex = CategoricalNetwork(prob_network=MLP(obs_dim,
                                                    action_dim,
                                                    hidden_sizes=(400, 300,
                                                                  200),
                                                    hidden_act=nn.ReLU,
                                                    final_act=nn.Softmax),
                                   output_dim=action_dim)

    # vae with behavioral cloning
    vae = TrajVAEBC(encoder=encoder,
                    decoder=decoder,
                    latent_dim=latent_dim,
                    step_dim=step_dim,
                    feature_dim=train_dataset.obs_dim,
                    env=env,
                    path_len=train_dataset.path_len,
                    init_kl_weight=vv['kl_weight'],
                    max_kl_weight=vv['kl_weight'],
                    kl_mul=1.03,
                    loss_type=vv['vae_loss_type'],
                    lr=vv['vae_lr'],
                    obs_dim=obs_dim,
                    act_dim=action_dim,
                    policy=policy,
                    bc_weight=vv['bc_weight'])

    # 0 baseline due to constantly changing rewards
    baseline = ZeroBaseline()
    # policy opt for policy decoder
    policy_algo = PPO(
        env,
        env_name,
        policy,
        baseline=baseline,
        obs_dim=obs_dim,
        action_dim=action_dim,
        max_path_length=path_len,
        center_adv=True,
        optimizer=optim.Adam(policy.get_params(), vv['policy_lr'],
                             eps=1e-5),  #vv['global_lr']),
        use_gae=vv['use_gae'],
        epoch=10,
        ppo_batch_size=200)

    # baseline for the explorer
    baseline_ex = ZeroBaseline()
    # policy opt for the explorer
    policy_ex_algo = PPO(
        env,
        env_name,
        policy_ex,
        baseline=baseline_ex,
        obs_dim=obs_dim,
        action_dim=action_dim,
        max_path_length=path_len,
        center_adv=True,
        optimizer=optim.Adam(policy_ex.get_params(), vv['policy_lr'],
                             eps=1e-5),  #vv['global_lr']),
        use_gae=vv['use_gae'],
        epoch=10,
        ppo_batch_size=200,
        entropy_bonus=vv['entropy_bonus'])

    # for loading the model from a saved state
    if vv['load_models_dir'] is not None:
        dir = getcwd() + "/research/lang/traj2vecv3_jd/" + vv['load_models_dir']
        itr = vv['load_models_idx']
        encoder.load_state_dict(torch.load(dir + '/encoder_%d.pkl' % itr))
        decoder.load_state_dict(torch.load(dir + '/decoder_%d.pkl' % itr))
        policy.load_state_dict(torch.load(dir + '/policy_%d.pkl' % itr))
        policy_ex.load_state_dict(torch.load(dir + '/policy_ex_%d.pkl' % itr))
        vae.optimizer.load_state_dict(
            torch.load(dir + '/vae_optimizer_%d.pkl' % itr))
        policy_algo.optimizer.load_state_dict(
            torch.load(dir + '/policy_optimizer_%d.pkl' % itr))

    # block goals
    goals = 2 * np.array(vv['block_config'][1])
    # reward function for MPC
    rf = lambda obs, rstate: reward_fn(obs, rstate, goals)

    # main algorithm launcher, includes mpc controller and exploration
    vaepd = VAEPDEntropy(
        env,
        env_name,
        policy,
        policy_ex,
        encoder,
        decoder,
        path_len,
        obs_dim,
        action_dim,
        step_dim,
        policy_algo,
        policy_ex_algo,
        train_dataset,
        latent_dim,
        vae,
        batch_size=400,
        block_config=vv['block_config'],
        plan_horizon=vv['mpc_plan'],
        max_horizon=vv['mpc_max'],
        mpc_batch=vv['mpc_batch'],
        rand_per_mpc_step=vv['mpc_explore_step'],
        mpc_explore=2048,
        mpc_explore_batch=6,
        reset_ent=vv['reset_ent'],
        vae_train_steps=vv['vae_train_steps'],
        mpc_explore_len=vv['mpc_explore_len'],
        consis_finetuning=vv['consis_finetuning'],
        true_reward_scale=vv['true_reward_scale'],
        discount_factor=vv['discount_factor'],
        reward_fn=(rf, init_rstate),
    )

    vaepd.train(train_dataset,
                test_dataset=test_dataset,
                dummy_dataset=dummy_dataset,
                plot_step=10,
                max_itr=vv['max_itr'],
                record_stats=True,
                print_step=1000,
                save_step=20,
                start_itr=0,
                train_vae_after_add=vv['train_vae_after_add'],
                joint_training=vv['joint_training'])
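
The restore block in this example loads each component from files named '<component>_<itr>.pkl'. No save routine appears here; a counterpart consistent with those filenames might look like the following sketch (the function name and save location are assumptions):

import os
import torch

def save_models_sketch(save_dir, itr, encoder, decoder, policy, policy_ex,
                       vae_optimizer, policy_optimizer):
    """Hypothetical save routine matching the checkpoint files loaded above."""
    os.makedirs(save_dir, exist_ok=True)
    torch.save(encoder.state_dict(), os.path.join(save_dir, 'encoder_%d.pkl' % itr))
    torch.save(decoder.state_dict(), os.path.join(save_dir, 'decoder_%d.pkl' % itr))
    torch.save(policy.state_dict(), os.path.join(save_dir, 'policy_%d.pkl' % itr))
    torch.save(policy_ex.state_dict(), os.path.join(save_dir, 'policy_ex_%d.pkl' % itr))
    torch.save(vae_optimizer.state_dict(),
               os.path.join(save_dir, 'vae_optimizer_%d.pkl' % itr))
    torch.save(policy_optimizer.state_dict(),
               os.path.join(save_dir, 'policy_optimizer_%d.pkl' % itr))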