Example #1
def main(**kwargs):
    exp_dir = os.getcwd() + '/data/' + EXP_NAME + '/' + str(kwargs['seed'])
    logger.configure(dir=exp_dir,
                     format_strs=['stdout', 'log', 'csv'],
                     snapshot_mode='last')
    json.dump(kwargs,
              open(exp_dir + '/params.json', 'w'),
              indent=2,
              sort_keys=True,
              cls=ClassEncoder)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = kwargs.get(
        'gpu_frac', 0.95)
    sess = tf.Session(config=config)

    with sess.as_default() as sess:
        folder = './data/policy/' + kwargs['env']
        paths = pickle.load(open(folder + '/paths.pickle', 'rb'))
        niters = paths.get_current_episode_size() // 100
        train_data, test_data = split_data(paths, niters)

        dimo = train_data[0]['o'].shape[-1]

        dims = [dimo]
        env = gym.make(kwargs['env'],
                       obs_type=kwargs['obs_type'],
                       fixed_num_of_contact=kwargs['fixed_num_of_contact'])

        feature_net = FeatureNet(
            dims,
            fixed_num_of_contact=kwargs['fixed_num_of_contact'],
            contact_dim=env.contact_dim,
            sess=sess,
            output=kwargs['prediction'],
            process_type=kwargs['process_type'],
            feature_dim=kwargs['feature_dim'],
            feature_layer=kwargs['feature_layer'])

        sess.run(tf.global_variables_initializer())
        for i in range(niters):
            start = timer.time()
            feature_net.train(train_data[i])
            feature_net.test(test_data[i])
            logger.logkv("iter", i)
            logger.logkv("iter_time", timer.time() - start)
            logger.dumpkvs()
            if i == 0:
                sess.graph.finalize()
Example #2
def test_mpi_weighted_mean():
    comm = MPI.COMM_WORLD
    with logger.scoped_configure(comm=comm):
        if comm.rank == 0:
            name2valcount = {'a': (10, 2), 'b': (20, 3)}
        elif comm.rank == 1:
            name2valcount = {'a': (19, 1), 'c': (42, 3)}
        else:
            raise NotImplementedError
        d = mpi_util.mpi_weighted_mean(comm, name2valcount)
        correctval = {'a': (10 * 2 + 19) / 3.0, 'b': 20, 'c': 42}
        if comm.rank == 0:
            assert d == correctval, '{} != {}'.format(d, correctval)

        for name, (val, count) in name2valcount.items():
            for _ in range(count):
                logger.logkv_mean(name, val)
        d2 = logger.dumpkvs()
        if comm.rank == 0:
            assert d2 == correctval
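
For reference, a pure-Python sketch of the reduction this test checks (no MPI involved; weighted_mean is a hypothetical helper written here for illustration, not the library function):

def weighted_mean(rank_dicts):
    """Combine per-rank {name: (value, count)} dicts into {name: weighted mean}."""
    sums, counts = {}, {}
    for d in rank_dicts:
        for name, (val, count) in d.items():
            sums[name] = sums.get(name, 0.0) + val * count
            counts[name] = counts.get(name, 0) + count
    return {name: sums[name] / counts[name] for name in sums}

# Reproduces the expected values asserted above:
# weighted_mean([{'a': (10, 2), 'b': (20, 3)}, {'a': (19, 1), 'c': (42, 3)}])
# -> {'a': (10 * 2 + 19 * 1) / 3.0, 'b': 20.0, 'c': 42.0}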
Example #3
def learn(*,
          network,
          env,
          total_timesteps,
          eval_env=None,
          seed=None,
          nsteps=2048,
          ent_coef=0.0,
          lr=3e-4,
          vf_coef=0.5,
          max_grad_norm=0.5,
          gamma=0.99,
          lam=0.95,
          log_interval=10,
          nminibatches=4,
          noptepochs=4,
          cliprange=0.2,
          save_interval=0,
          load_path=None,
          model_fn=None,
          update_fn=None,
          init_fn=None,
          mpi_rank_weight=1,
          comm=None,
          **network_kwargs):
    '''
    Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347)

    Parameters:
    ----------

    network:                          policy network architecture. Either a string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for the full list)
                                      naming a standard architecture, or a function that takes a tensorflow tensor as input and returns a
                                      tuple (output_tensor, extra_feed), where output_tensor is the last network layer's output; extra_feed is None for feed-forward
                                      networks and a dictionary describing how to feed state into the network for recurrent ones.
                                      See common/models.py/lstm for more details on using recurrent nets in policies.

    env: baselines.common.vec_env.VecEnv     environment. Needs to be vectorized for parallel environment simulation.
                                      Environments produced by gym.make can be wrapped with the baselines.common.vec_env.DummyVecEnv class.


    nsteps: int                       number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
                                      nenv is number of environment copies simulated in parallel)

    total_timesteps: int              number of timesteps (i.e. number of actions taken in the environment)

    ent_coef: float                   policy entropy coefficient in the optimization objective

    lr: float or function             learning rate: either a constant, or a schedule function [0,1] -> R+ applied to the fraction of
                                      training remaining (1 at the start of training, 0 at the end).

    vf_coef: float                    value function loss coefficient in the optimization objective

    max_grad_norm: float or None      gradient norm clipping coefficient

    gamma: float                      discounting factor

    lam: float                        advantage estimation discounting factor (lambda in the paper)

    log_interval: int                 number of updates between logging events

    nminibatches: int                 number of training minibatches per update. For recurrent policies,
                                      should be less than or equal to the number of environments run in parallel.

    noptepochs: int                   number of training epochs per update

    cliprange: float or function      clipping range: either a constant, or a schedule function [0,1] -> R+ applied to the same fraction
                                      of training remaining as lr (1 at the start of training, 0 at the end)

    save_interval: int                number of updates between model checkpoints (0 disables saving)

    load_path: str                    path to load the model from

    **network_kwargs:                 keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and the arguments to a particular type of network.
                                      For instance, the 'mlp' network architecture has arguments num_hidden and num_layers.



    '''

    set_global_seeds(seed)

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    policy = build_policy(env, network, **network_kwargs)

    # Get the number of environments
    nenvs = env.num_envs

    # Get state_space and action_space
    ob_space = env.observation_space
    ac_space = env.action_space

    # Calculate the batch_size
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches
    is_mpi_root = (MPI is None or MPI.COMM_WORLD.Get_rank() == 0)

    # Instantiate the model object (that creates act_model and train_model)
    if model_fn is None:
        from tactile_baselines.ppo2.model import Model
        model_fn = Model

    model = model_fn(policy=policy,
                     ob_space=ob_space,
                     ac_space=ac_space,
                     nbatch_act=nenvs,
                     nbatch_train=nbatch_train,
                     nsteps=nsteps,
                     ent_coef=ent_coef,
                     vf_coef=vf_coef,
                     max_grad_norm=max_grad_norm,
                     comm=comm,
                     mpi_rank_weight=mpi_rank_weight)

    if load_path is not None:
        model.load(load_path)
    # Instantiate the runner object
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
    if eval_env is not None:
        eval_runner = Runner(env=eval_env,
                             model=model,
                             nsteps=nsteps,
                             gamma=gamma,
                             lam=lam)

    epinfobuf = deque(maxlen=100)
    if eval_env is not None:
        eval_epinfobuf = deque(maxlen=100)

    if init_fn is not None:
        init_fn()

    # Start total timer
    tfirststart = time.perf_counter()

    nupdates = total_timesteps // nbatch
    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        # Start timer
        tstart = time.perf_counter()
        frac = 1.0 - (update - 1.0) / nupdates
        # Calculate the learning rate
        lrnow = lr(frac)
        # Calculate the cliprange
        cliprangenow = cliprange(frac)

        if update % log_interval == 0 and is_mpi_root:
            logger.info('Stepping environment...')

        # Get minibatch
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run()  # pylint: disable=E0632
        if eval_env is not None:
            eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run()  # pylint: disable=E0632

        if update % log_interval == 0 and is_mpi_root: logger.info('Done.')

        epinfobuf.extend(epinfos)
        if eval_env is not None:
            eval_epinfobuf.extend(eval_epinfos)

        # For each minibatch, compute the loss values and append them to mblossvals.
        mblossvals = []
        if states is None:  # nonrecurrent version
            # Index of each element of batch_size
            # Create the indices array
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                # Randomize the indexes
                np.random.shuffle(inds)
                # 0 to batch_size with batch_train_size step
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (arr[mbinds]
                              for arr in (obs, returns, masks, actions, values,
                                          neglogpacs))
                    mblossvals.append(model.train(lrnow, cliprangenow,
                                                  *slices))
        else:  # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds]
                              for arr in (obs, returns, masks, actions, values,
                                          neglogpacs))
                    mbstates = states[mbenvinds]
                    mblossvals.append(
                        model.train(lrnow, cliprangenow, *slices, mbstates))

        # Feedforward --> get losses --> update
        lossvals = np.mean(mblossvals, axis=0)
        # End timer
        tnow = time.perf_counter()
        # Calculate the fps (frame per second)
        fps = int(nbatch / (tnow - tstart))

        if update_fn is not None:
            update_fn(update)

        if update % log_interval == 0 or update == 1:
            # Check whether the value function is a good predictor of the returns (ev close to 1)
            # or worse than predicting nothing (ev <= 0)
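            # Assuming the usual baselines definition, explained_variance(values, returns)
            # computes 1 - Var(returns - values) / Var(returns).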
            ev = explained_variance(values, returns)
            logger.logkv("misc/serial_timesteps", update * nsteps)
            logger.logkv("misc/nupdates", update)
            logger.logkv("misc/total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("misc/explained_variance", float(ev))
            logger.logkv('eprewmean',
                         safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean',
                         safemean([epinfo['l'] for epinfo in epinfobuf]))
            if eval_env is not None:
                logger.logkv(
                    'eval_eprewmean',
                    safemean([epinfo['r'] for epinfo in eval_epinfobuf]))
                logger.logkv(
                    'eval_eplenmean',
                    safemean([epinfo['l'] for epinfo in eval_epinfobuf]))
            logger.logkv('misc/time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv('loss/' + lossname, lossval)

            logger.dumpkvs()
        if save_interval and (update % save_interval == 0
                              or update == 1) and logger.get_dir() and is_mpi_root:
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)

    return model
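
A minimal usage sketch of learn(), assuming it is exposed as tactile_baselines.ppo2.ppo2.learn (the module path is a guess based on the Model import above) and that the standard baselines DummyVecEnv wrapper is available; adjust the imports to your checkout:

import gym
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from tactile_baselines.ppo2.ppo2 import learn  # hypothetical module path

# Wrap a single gym env so it satisfies the VecEnv interface expected by learn().
venv = DummyVecEnv([lambda: gym.make('CartPole-v1')])

model = learn(network='mlp',        # standard architecture from baselines.common/models.py
              env=venv,
              total_timesteps=20000,
              nsteps=128,           # batch size = nsteps * nenvs = 128
              nminibatches=4,       # must divide the batch size evenly
              log_interval=1,
              num_hidden=64,        # **network_kwargs forwarded to the 'mlp' builder
              num_layers=2)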
Example #4
def main(**kwargs):
    exp_dir = os.getcwd() + '/cpc_model/' + kwargs['process_type'][0] + '/n200-8'
    logger.configure(dir=exp_dir,
                     format_strs=['stdout', 'log', 'csv'],
                     snapshot_mode='last')
    json.dump(kwargs,
              open(exp_dir + '/params.json', 'w'),
              indent=2,
              sort_keys=True,
              cls=ClassEncoder)

    obs, acts, fixed_num_of_contact = pickle.load(
        open('../untrained/HandManipulateEgg-v0/5seeds-dict.pickle', 'rb'))

    include_action = kwargs['include_action'][0]

    env = gym.make(kwargs['env'][0],
                   obs_type=kwargs['obs_type'][0],
                   fixed_num_of_contact=[fixed_num_of_contact, True])

    ngeoms = env.sim.model.ngeom
    obs, object_info = expand_data(obs, ngeoms, fixed_num_of_contact)
    next_obs = obs[:, 1:]
    obs = obs[:, :-1]
    N, L, _, contact_point_dim = obs.shape
    N, L, action_dim = acts.shape

    obs_dim = (fixed_num_of_contact, contact_point_dim)

    z_dim = 8
    lr = 1e-3
    epochs = 100
    batch_size = 2
    n = 200
    k = 1

    encoder = Encoder(z_dim, obs_dim[1], fixed_num_of_contact).cuda()
    if include_action:
        trans = Transition(z_dim, action_dim).cuda()
    else:
        trans = Transition(z_dim, 0).cuda()
    decoder = Decoder(z_dim, 3).cuda()

    optim_cpc = optim.Adam(list(encoder.parameters()) +
                           list(trans.parameters()),
                           lr=lr)
    optim_dec = optim.Adam(decoder.parameters(), lr=lr)
    train_data, test_data = split_data([obs, acts, next_obs])

    for epoch in range(epochs):
        train_cpc(encoder, trans, optim_cpc, epoch, train_data, batch_size, n,
                  k, include_action)
        test_cpc(encoder, trans, epoch, test_data, batch_size, n, k,
                 include_action)

        logger.logkv("epoch", epoch)
        logger.dumpkvs()

    train_data, test_data = split_data([obs, acts, next_obs, object_info])
    for epoch in range(100):
        train_decoder(decoder,
                      encoder,
                      optim_dec,
                      epoch,
                      train_data,
                      batch_size,
                      include_action,
                      n,
                      k=1)
        test_decoder(decoder,
                     encoder,
                     epoch,
                     test_data,
                     batch_size,
                     include_action,
                     n,
                     k=1)
        logger.logkv("epoch", epoch)
        logger.dumpkvs()
Example #5
def main(**kwargs):
    z_dim = kwargs['z_dim']
    trans_mode = kwargs['trans_mode']
    epochs = kwargs['epochs']
    include_action = kwargs['include_action']
    label = kwargs['label']

    dataset = kwargs['data_path']
    feature_dims = kwargs['feature_dims']
    mode = kwargs['mode']
    n = kwargs['n']
    k = kwargs['k']
    encoder_lr = kwargs['encoder_lr']
    decoder_lr = kwargs['decoder_lr']
    decoder_feature_dims = kwargs['decoder_feature_dims']
    process_type = kwargs['process_type']

    if kwargs['data_path'] == '../dataset/sequence/HandManipulateEgg-v0/5seeds-dict.pickle':
        kwargs['dataset'] = 'trained_5seeds'
    elif kwargs['data_path'] == '../dataset/untrained/HandManipulateEgg-v0/5seeds-dict.pickle':
        kwargs['dataset'] = 'untrained_5seeds'
    elif kwargs['data_path'] == '../dataset/HandManipulateEgg-v09-dict.pickle':
        kwargs['dataset'] = 'trained_1seed'
    exp_dir = os.getcwd() + '/data/' + EXP_NAME + '/' + str(kwargs['seed'])
    if kwargs['debug']:
        save_dir = '../saved_cpc/' + str(label) + '/' + str(kwargs['normalize_data']) + '/' + str(process_type) + '/trained/debug'
        # save_dir = '../saved_cpc/' + str(label) + '/' + str(process_type) + '/trained/debug'
    else:
        save_dir = '../saved_cpc/' + str(label) + '/' + str(kwargs['normalize_data']) + '/' + str(process_type) + '/trained'
        # save_dir = '../saved_cpc/' + str(label) + '/' + str(process_type) + '/trained'
    logger.configure(dir=exp_dir, format_strs=['stdout', 'log', 'csv'], snapshot_mode='last')
    json.dump(kwargs, open(exp_dir + '/params.json', 'w'), indent=2, sort_keys=True, cls=ClassEncoder)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = kwargs.get('gpu_frac', 0.95)
    sess = tf.Session(config=config)

    obs, acts, fixed_num_of_contact = pickle.load(open(dataset, 'rb'))

    env = gym.make(kwargs['env'],
                   obs_type=kwargs['obs_type'],
                   fixed_num_of_contact=[fixed_num_of_contact, True])

    ngeoms = env.sim.model.ngeom
    obs, object_info = expand_data(obs, ngeoms, fixed_num_of_contact)
    if kwargs['normalize_data']:
        obs = normalize_obs(obs)
    next_obs = obs[:, 1:]
    obs = obs[:, :-1]
    N, L, _, contact_point_dim = obs.shape
    N, L, action_dim = acts.shape

    obs_dim = (fixed_num_of_contact, contact_point_dim)
    train_data, test_data = split_data([obs, acts, next_obs, object_info])

    batch_size = 2

    if mode in ['restore', 'store_weights']:
        saver = tf.train.import_meta_graph(save_dir + '-999.meta')
        pur_save_dir = save_dir[:-8]
        saver.restore(sess, tf.train.latest_checkpoint(pur_save_dir))
        graph = tf.get_default_graph()

    with sess.as_default() as sess:
        encoder = Encoder(z_dim,
                          fixed_num_of_contact,
                          contact_point_dim,
                          feature_dims)
        trans = Transition(z_dim, action_dim, mode=trans_mode)
        cpc = CPC(sess,
                  encoder,
                  trans,
                  encoder_lr,
                  fixed_num_of_contact,
                  contact_point_dim,
                  action_dim,
                  include_action=include_action,
                  type=1 * (label == 'cpc1') + 2 * (label == 'cpc2'),
                  n_neg=n,
                  process_type=process_type,
                  mode=mode)

        cpc_epochs, decoder_epochs = epochs
        if mode == 'train':
            sess.run(tf.global_variables_initializer())
            logger.log("training started")
            for epoch in range(cpc_epochs):
                # train_cpc(cpc, epoch, train_data, batch_size, n, k)
                test_cpc(cpc, epoch, test_data, batch_size, n, k)

                logger.logkv("epoch", epoch)
                logger.dumpkvs()
            cpc.save_model(save_dir, 999)

            """decoder"""
            logger.log("Done with cpc training.")

            decoder = Decoder(cpc,
                              sess,
                              z_dim,
                              decoder_feature_dims,
                              fixed_num_of_contact,
                              contact_point_dim,
                              decoder_lr)
            uninit_vars = [var for var in tf.global_variables() if not sess.run(tf.is_variable_initialized(var))]
            sess.run(tf.variables_initializer(uninit_vars))
            for epoch in range(decoder_epochs):
                train_decoder(decoder, epoch, train_data, batch_size, n, k)
                test_decoder(decoder, epoch, test_data, batch_size, n, k)

                logger.logkv("epoch", (epoch + cpc_epochs))
                logger.dumpkvs()
            print("model saved in", save_dir)

        elif mode == 'restore':
            decoder = Decoder(cpc,
                              sess,
                              z_dim,
                              decoder_feature_dims,
                              fixed_num_of_contact,
                              contact_point_dim,
                              decoder_lr)
            uninit_vars = [var for var in tf.global_variables() if not sess.run(tf.is_variable_initialized(var))]
            sess.run(tf.variables_initializer(uninit_vars))
            print("initialized")
            for epoch in range(100):
                train_decoder(decoder, epoch, train_data, batch_size, n, k)
                test_decoder(decoder, epoch, test_data, batch_size, n, k)

                logger.logkv("epoch", epoch)
                logger.dumpkvs()
                print("logging to", exp_dir)

        elif mode == 'store_weights':
            old = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='')
            old = sess.run(old)
            save_dir = './saved_model/' + str(label) + '/' + str(process_type) + '/trained/'
            with open(save_dir + 'weights.pickle', 'wb') as pickle_file:
                pickle.dump(old, pickle_file)
            print("weights saved to", save_dir)

            save_dir = '/home/vioichigo/try/tactile-baselines/saved_model/cpc2/trained'
            with open(save_dir + '/params.pickle', 'wb') as pickle_file:
                pickle.dump([z_dim, fixed_num_of_contact, contact_point_dim, action_dim, encoder_lr, feature_dims, trans_mode, label, include_action], pickle_file)

        tf.reset_default_graph()
        print("graph reset successfully")
Example #6
def main(**kwargs):
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = kwargs.get(
        'gpu_frac', 0.95)
    sess = tf.Session(config=config)
    exp_dir = (os.getcwd() + '/data/feature_net/' +
               kwargs['input_label'][0] + kwargs['output_label'][0] + '/')
    mode = kwargs['mode'][0]

    if mode == 'restore':
        rotation_saver = tf.train.import_meta_graph(exp_dir + '-999.meta')
        rotation_saver.restore(sess, tf.train.latest_checkpoint(exp_dir))
        graph = tf.get_default_graph()

    with sess.as_default() as sess:

        input_label = kwargs['input_label'][0]
        output_label = kwargs['output_label'][0]
        buffer = {}
        name = '1'
        paths, fixed_num_of_contact = pickle.load(
            open(
                '../saved/trained/SoftHandManipulateEgg-v080-' + name +
                '-dict.pickle', 'rb'))
        for key in paths:
            buffer[key] = paths[key]

        for name in [str(i) for i in range(2, 17)]:
            paths, fixed_num_of_contact = pickle.load(
                open(
                    '../saved/trained/SoftHandManipulateEgg-v080-' + name +
                    '-dict.pickle', 'rb'))
            for key in paths:
                buffer[key] = np.concatenate([buffer[key], paths[key]], axis=0)

        env = gym.make(kwargs['env'][0],
                       obs_type=kwargs['obs_type'][0],
                       fixed_num_of_contact=fixed_num_of_contact)
        batch_size = 100
        paths = data_filter(buffer, fixed_num_of_contact, batch_size)
        niters = paths['positions'].shape[0] // batch_size
        print("total iteration: ", niters)
        print("total number of data: ", paths['positions'].shape[0])

        train_data, test_data, _, _ = split_data(paths, niters)
        train_data['object_position'] = train_data['object_position'][:, :, :3]
        test_data['object_position'] = test_data['object_position'][:, :, :3]

        labels_to_dims = {}
        labels_to_dims['positions'] = 3

        rotation_model = RotationModel(
            dims=[labels_to_dims[input_label]],
            sess=sess,
            fixed_num_of_contact=fixed_num_of_contact,
            feature_layers=kwargs['feature_layers'][0],
            output_layers=kwargs['output_layers'][0],
            learning_rate=kwargs['learning_rate'][0])

        if mode == 'train':
            sess.run(tf.global_variables_initializer())
            for i in range(niters):
                inp, out = train_data[input_label][i], train_data[output_label][i]
                pred = rotation_model.train(inp, out)
                logger.logkv("iter", i)
                logger.dumpkvs()
            rotation_model.save_model(exp_dir, 999)

        if mode == 'restore':
            rotation_model.restore()
            for i in range(1):
                logger.logkv("iter", i)
                _, _ = rotation_model.restore_predict(
                    train_data[input_label][i], train_data[output_label][i])
                logger.dumpkvs()
Example #7
    def train(self):
        """
        Trains policy on env using algo
        Pseudocode:
            for itr in n_itr:
                for step in num_inner_grad_steps:
                    sampler.sample()
                    algo.compute_updated_dists()
                algo.optimize_policy()
                sampler.update_goals()
        """

        with self.sess.as_default() as sess:
            # initialize variables (tf.global_variables_initializer() initializes all
            # global variables, including any that were previously loaded)
            sess.run(tf.global_variables_initializer())
            start_time = time.time()

            if self.start_itr == 0:
                self.algo._update_target(tau=1.0)
                if self.n_initial_exploration_steps > 0:
                    while self.replay_buffer._size < self.n_initial_exploration_steps:
                        paths = self.sampler.obtain_samples(
                            log=True, log_prefix='train-', random=True)
                        samples_data = self.sample_processor.process_samples(
                            paths, log='all', log_prefix='train-')[0]
                        self.replay_buffer.add_samples(
                            samples_data['observations'],
                            samples_data['actions'],
                            samples_data['rewards'],
                            samples_data['dones'],
                            samples_data['next_observations'],
                        )

            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                logger.log(
                    "\n ---------------- Iteration %d ----------------" % itr)
                logger.log(
                    "Sampling set of tasks/goals for this meta-batch...")
                """ -------------------- Sampling --------------------------"""

                logger.log("Obtaining samples...")
                time_env_sampling_start = time.time()
                paths = self.sampler.obtain_samples(log=True,
                                                    log_prefix='train-')
                sampling_time = time.time() - time_env_sampling_start
                """ ----------------- Processing Samples ---------------------"""
                # check how the samples are processed
                logger.log("Processing samples...")
                time_proc_samples_start = time.time()
                samples_data = self.sample_processor.process_samples(
                    paths, log='all', log_prefix='train-')[0]
                self.replay_buffer.add_samples(
                    samples_data['observations'],
                    samples_data['actions'],
                    samples_data['rewards'],
                    samples_data['dones'],
                    samples_data['next_observations'],
                )
                proc_samples_time = time.time() - time_proc_samples_start

                paths = self.sampler.obtain_samples(log=True,
                                                    log_prefix='eval-',
                                                    deterministic=True)
                _ = self.sample_processor.process_samples(
                    paths, log='all', log_prefix='eval-')[0]

                # self.log_diagnostics(paths, prefix='train-')
                """ ------------------ Policy Update ---------------------"""

                logger.log("Optimizing policy...")

                # This needs to take all samples_data so that it can construct graph for meta-optimization.
                time_optimization_step_start = time.time()

                self.algo.optimize_policy(self.replay_buffer,
                                          itr * self.epoch_length,
                                          self.num_grad_steps)
                """ ------------------- Logging Stuff --------------------------"""
                logger.logkv('Itr', itr)
                logger.logkv('n_timesteps',
                             self.sampler.total_timesteps_sampled)

                logger.logkv('Time-Optimization',
                             time.time() - time_optimization_step_start)
                logger.logkv('Time-SampleProc', np.sum(proc_samples_time))
                logger.logkv('Time-Sampling', sampling_time)

                logger.logkv('Time', time.time() - start_time)
                logger.logkv('ItrTime', time.time() - itr_start_time)

                logger.dumpkvs()
                if itr == 0:
                    sess.graph.finalize()

        logger.log("Training finished")
        self.sess.close()