Code Example #1
def play(policy_file, seed, n_test_rollouts, render):
    set_global_seeds(seed)

    # Load policy.
    with open(policy_file, 'rb') as f:
        policy = pickle.load(f)
    env_name = policy.info['env_name']

    # Load params
    with open(PARAMS_FILE) as json_file:
        params = json.load(json_file)

    params['env_name'] = env_name

    params = config.prepare_params(params)
    config.log_params(params, logger=logger)

    structure = params['structure']
    task_selection = params['task_selection']
    goal_selection = params['goal_selection']

    dims = config.configure_dims(params)

    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'use_demo_states': False,
        'compute_Q': True,
        'T': params['T'],
        'structure': structure,
        'task_selection': task_selection,
        'goal_selection': goal_selection,
        'queue_length': params['queue_length'],
        'eval': True,
        'render': render
    }

    for name in ['T', 'gamma', 'noise_eps', 'random_eps']:
        eval_params[name] = params[name]

    evaluator = RolloutWorker(params['make_env'], policy, dims, logger,
                              **eval_params)
    evaluator.seed(seed)

    # Run evaluation.
    evaluator.clear_history()
    for _ in range(n_test_rollouts):
        evaluator.generate_rollouts()

    # record logs
    for key, val in evaluator.logs('test'):
        logger.record_tabular(key, np.mean(val))
    logger.dump_tabular()
Code Example #2
def main(policy_file, seed, n_test_rollouts, render):
    """
    run HER from a saved policy

    :param policy_file: (str) pickle path to a saved policy
    :param seed: (int) initial seed
    :param n_test_rollouts: (int) the number of test rollouts
    :param render: (bool) if rendering should be done
    """
    set_global_seeds(seed)

    # Load policy.
    with open(policy_file, 'rb') as file_handler:
        policy = pickle.load(file_handler)
    env_name = policy.info['env_name']

    # Prepare params.
    params = config.DEFAULT_PARAMS
    if env_name in config.DEFAULT_ENV_PARAMS:
        params.update(config.DEFAULT_ENV_PARAMS[env_name]
                      )  # merge env-specific parameters in
    params['env_name'] = env_name
    params = config.prepare_params(params)
    config.log_params(params, logger_input=logger)

    dims = config.configure_dims(params)

    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'compute_q': True,
        'rollout_batch_size': 1,
        'render': bool(render),
    }

    for name in ['time_horizon', 'gamma', 'noise_eps', 'random_eps']:
        eval_params[name] = params[name]

    evaluator = RolloutWorker(params['make_env'], policy, dims, logger,
                              **eval_params)
    evaluator.seed(seed)

    # Run evaluation.
    evaluator.clear_history()
    for _ in range(n_test_rollouts):
        evaluator.generate_rollouts()

    # record logs
    for key, val in evaluator.logs('test'):
        logger.record_tabular(key, np.mean(val))
    logger.dump_tabular()
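
The function above is only the evaluation body; upstream it is exposed through a command-line entry point. A minimal, hypothetical wrapper is sketched below; the option names and defaults are assumptions for illustration, not copied from the project.

import click

@click.command()
@click.argument('policy_file', type=str)
@click.option('--seed', type=int, default=0, help='RNG seed (assumed default)')
@click.option('--n_test_rollouts', type=int, default=10, help='number of evaluation rollouts (assumed default)')
@click.option('--render', type=int, default=1, help='render the rollouts, 0/1 (assumed default)')
def cli(**kwargs):
    # Forward the parsed options to the main() defined in Code Example #2.
    main(**kwargs)

if __name__ == '__main__':
    cli()
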
Code Example #3
File: play.py  Project: qdrn/baselines
def main(policy_file, seed, n_test_rollouts, render, record):
    set_global_seeds(seed)

    # Load policy.
    with open(policy_file, 'rb') as f:
        policy = pickle.load(f)
    env_name = policy.info['env_name']

    # Prepare params.
    params = config.DEFAULT_PARAMS
    if env_name in config.DEFAULT_ENV_PARAMS:
        params.update(config.DEFAULT_ENV_PARAMS[env_name])  # merge env-specific parameters in
    params['env_name'] = env_name
    params = config.prepare_params(params)
    config.log_params(params, logger=logger)

    dims = config.configure_dims(params)

    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'compute_Q': True,
        'rollout_batch_size': 1,
        'render': bool(render),
    }

    for name in ['T', 'gamma', 'noise_eps', 'random_eps']:
        eval_params[name] = params[name]

    if record:
        make_env = params['make_env']
        def video_callable(episode_id):
            return True
        def make_record_env():
            env = make_env()
            return gym.wrappers.Monitor(env, '../../../results/video/' + env_name, force=True, video_callable=video_callable)
        params['make_env'] = make_record_env
    evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params)
    evaluator.seed(seed)

    # Run evaluation.
    evaluator.clear_history()
    for _ in range(n_test_rollouts):
        evaluator.generate_rollouts()

    # record logs
    for key, val in evaluator.logs('test'):
        logger.record_tabular(key, np.mean(val))
    logger.dump_tabular()
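
Code Example #3 differs from #2 mainly in the optional `record` branch, which wraps the environment in `gym.wrappers.Monitor` so that every evaluation episode is written to disk: `video_callable` returning True records all episodes, and `force=True` overwrites an existing output directory. The wrapping pattern in isolation, with a placeholder env id and output path, might look like the sketch below (note that `Monitor` only exists in older gym releases; newer versions replaced it with `RecordVideo`).

import gym

def make_monitored_env(env_id='FetchPickAndPlace-v1', video_dir='/tmp/her_videos'):
    # Record every episode; force=True clears previous recordings in video_dir.
    env = gym.make(env_id)
    return gym.wrappers.Monitor(env, video_dir, force=True,
                                video_callable=lambda episode_id: True)
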
Code Example #4
def main(policy_file, seed, n_test_rollouts, render, with_forces):
    set_global_seeds(seed)

    # Load policy.
    with open(policy_file, 'rb') as f:
        policy = pickle.load(f)
    env_name = policy.info['env_name']

    # Prepare params.
    params = config.DEFAULT_PARAMS
    params['with_forces'] = with_forces
    params['plot_forces'] = False
    if env_name in config.DEFAULT_ENV_PARAMS:
        params.update(config.DEFAULT_ENV_PARAMS[env_name]
                      )  # merge env-specific parameters in
    params['env_name'] = env_name
    params = config.prepare_params(params)
    config.log_params(params, logger=logger)

    dims = config.configure_dims(params)

    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'compute_Q': True,
        'rollout_batch_size': 1,
        'with_forces': with_forces,
        'plot_forces': False,
        'render': bool(render),
    }

    for name in ['T', 'gamma', 'noise_eps', 'random_eps']:
        eval_params[name] = params[name]

    evaluator = RolloutWorker(params['make_env'], policy, dims, logger,
                              **eval_params)
    evaluator.seed(seed)

    # Run evaluation.
    evaluator.clear_history()
    for _ in range(n_test_rollouts):
        evaluator.generate_rollouts()

    # record logs
    for key, val in evaluator.logs('test'):
        logger.record_tabular(key, np.mean(val))
    logger.dump_tabular()
Code Example #5
File: play.py  Project: Divyankpandey/baselines
def main(policy_file, seed, n_test_rollouts, render):
    set_global_seeds(seed)

    # Load policy.
    with open(policy_file, 'rb') as f:
        policy = pickle.load(f)
    env_name = policy.info['env_name']

    # Prepare params.
    params = config.DEFAULT_PARAMS
    if env_name in config.DEFAULT_ENV_PARAMS:
        params.update(config.DEFAULT_ENV_PARAMS[env_name])  # merge env-specific parameters in
    params['env_name'] = env_name
    params = config.prepare_params(params)
    config.log_params(params, logger=logger)

    dims = config.configure_dims(params)

    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'compute_Q': True,
        'rollout_batch_size': 1,
        'render': bool(render),
    }

    for name in ['T', 'gamma', 'noise_eps', 'random_eps']:
        eval_params[name] = params[name]
    
    evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params)
    evaluator.seed(seed)

    # Run evaluation.
    evaluator.clear_history()
    for _ in range(n_test_rollouts):
        evaluator.generate_rollouts()

    # record logs
    for key, val in evaluator.logs('test'):
        logger.record_tabular(key, np.mean(val))
    logger.dump_tabular()
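
Note that Examples #2 through #6 all start from `params = config.DEFAULT_PARAMS` and then mutate the dictionary in place. This binds a name to the module-level dict rather than copying it, so any later use of `config.DEFAULT_PARAMS` in the same process sees the modified values. If that matters, a defensive copy avoids the issue; this is an editorial suggestion, not part of the upstream scripts.

import copy

params = copy.deepcopy(config.DEFAULT_PARAMS)  # work on a copy, leave the defaults untouched
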
Code Example #6
def main(policy_file, seed, n_test_rollouts, render, exploit, compute_q,
         collect_data, goal_generation, note):
    set_global_seeds(seed)

    # Load policy.
    with open(policy_file, 'rb') as f:
        policy = pickle.load(f)
    env_name = policy.info['env_name']

    # Prepare params.
    params = config.DEFAULT_PARAMS
    params['note'] = note or params['note']
    if note:
        with open('params/' + env_name + '/' + note + '.json', 'r') as file:
            override_params = json.loads(file.read())
            params.update(**override_params)

    if env_name in config.DEFAULT_ENV_PARAMS:
        params.update(config.DEFAULT_ENV_PARAMS[env_name]
                      )  # merge env-specific parameters in
    params['env_name'] = env_name
    goal_generation = params['goal_generation']
    params = config.prepare_params(params)
    config.log_params(params, logger=logger)

    dims = config.configure_dims(params)

    eval_params = {
        'exploit': exploit,  # eval: True, train: False
        'use_target_net': params['test_with_polyak'],  # eval/train: False
        'compute_Q': compute_q,  # eval: True, train: False
        'rollout_batch_size': 1,
        'render': render,
    }

    for name in ['T', 'gamma', 'noise_eps', 'random_eps']:
        eval_params[name] = params[name]

    evaluator = RolloutWorker(params['make_env'], policy, dims, logger,
                              **eval_params)
    evaluator.seed(seed)

    # Run evaluation.
    evaluator.clear_history()
    num_skills = params['num_skills']

    if goal_generation == 'Zero':
        generated_goal = np.zeros(evaluator.g.shape)
    else:
        generated_goal = False

    for z in range(num_skills):
        assert (evaluator.rollout_batch_size == 1)
        z_s_onehot = np.zeros([evaluator.rollout_batch_size, num_skills])
        z_s_onehot[0, z] = 1

        base = os.path.splitext(policy_file)[0]
        for i_test_rollouts in range(n_test_rollouts):
            if render == 'rgb_array' or render == 'human':

                imgs, episode = evaluator.generate_rollouts(
                    generated_goal=generated_goal, z_s_onehot=z_s_onehot)
                end = '_test_{:02d}_exploit_{}_compute_q_{}_skill_{}.avi'.format(
                    i_test_rollouts, exploit, compute_q, z)
                test_filename = base + end
                save_video(imgs[0], test_filename, lib='cv2')
            else:
                episode = evaluator.generate_rollouts(
                    generated_goal=generated_goal, z_s_onehot=z_s_onehot)

            if collect_data:
                end = '_test_{:02d}_exploit_{}_compute_q_{}_skill_{}.txt'.format(
                    i_test_rollouts, exploit, compute_q, z)
                test_filename = base + end
                with open(test_filename, 'w') as file:
                    file.write(json.dumps(episode['o'].tolist()))

    # record logs
    for key, val in evaluator.logs('test'):
        logger.record_tabular(key, np.mean(val))
    logger.dump_tabular()
Code Example #7
def run_task(v):
    random.seed(v['seed'])
    np.random.seed(v['seed'])

    num_cpu = 1
    if num_cpu > 1:
        try:
            whoami = mpi_fork(num_cpu, ['--bind-to', 'core'])
            print("fancy call succeeded")
        except CalledProcessError:
            print("fancy version of mpi call failed, try simple version")
            whoami = mpi_fork(num_cpu)

        if whoami == 'parent':
            sys.exit(0)
        import baselines.common.tf_util as U
        U.single_threaded_session().__enter__()

    # Configure logging
    rank = MPI.COMM_WORLD.Get_rank()
    logdir = ''
    if rank == 0:
        if logdir or logger_b.get_dir() is None:
            logger_b.configure(dir=logdir)
    else:
        logger_b.configure()
    logdir = logger_b.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)

    # Seed everything.
    rank_seed = v['seed'] + 1000000 * rank
    set_global_seeds(rank_seed)

    def make_env():
        return PnPEnv()

    env = make_env()
    test_env = make_env()
    env.reset()

    # for _ in range(1000):
    #     env.render()
    #     import pdb; pdb.set_trace()
    #     env.step(env.action_space.sample())

    params = config.DEFAULT_PARAMS
    params['action_l2'] = v['action_l2']
    params['max_u'] = v['max_u']
    params['gamma'] = v['discount']
    params['env_name'] = 'FetchReach-v0'
    params['replay_strategy'] = v['replay_strategy']
    params['lr'] = v['lr']
    params['layers'] = v['layers']
    params['hidden'] = v['hidden']
    params['n_cycles'] = v['n_cycles']  # cycles per epoch
    params['n_batches'] = v['n_batches']  # training batches per cycle
    params['batch_size'] = v[
        'batch_size']  # per mpi thread, measured in transitions and reduced to even multiple of chunk_length.
    params['n_test_rollouts'] = v[
        'n_test_rollouts']  # changed from 10 to 3 # number of test rollouts per epoch, each consists of rollout_batch_size rollouts
    # exploration
    params['random_eps'] = 0.3  # percentage of time a random action is taken
    params['noise_eps'] = v['action_noise']
    params['goal_weight'] = v['goal_weight']
    params['scope'] = 'ddpg3'

    params['sample_expert'] = v['sample_expert']
    params['expert_batch_size'] = v['expert_batch_size']
    params['bc_loss'] = v['bc_loss']
    params['anneal_bc'] = v['anneal_bc']
    params['gail_weight'] = v['gail_weight']
    params['terminate_bootstrapping'] = v['terminate_bootstrapping']
    params['mask_q'] = int(v['mode'] == 'pure_bc')
    params['two_qs'] = v['two_qs']
    params['anneal_discriminator'] = v['anneal_discriminator']
    params['two_rs'] = v['two_qs'] or v['anneal_discriminator']
    params['with_termination'] = v['rollout_terminate']

    if 'clip_dis' in v and v['clip_dis']:
        params['dis_bound'] = v['clip_dis']

    with open(os.path.join(logger_b.get_dir(), 'params.json'), 'w') as f:
        json.dump(params, f)

    params['T'] = v['horizon']
    params['to_goal'] = v['to_goal']

    params = config.prepare_params(params)
    params['make_env'] = make_env
    config.log_params(params, logger=logger_b)

    dims = config.configure_dims(params)

    # prepare GAIL
    if v['use_s_p']:
        discriminator = GAIL(dims['o'] + dims['o'] +
                             dims['g'] if not v['only_s'] else dims['o'] +
                             dims['g'],
                             dims['o'],
                             dims['o'],
                             dims['g'],
                             0.,
                             gail_loss=v['gail_reward'],
                             use_s_p=True,
                             only_s=v['only_s'])
    else:
        discriminator = GAIL(dims['o'] + dims['u'] +
                             dims['g'] if not v['only_s'] else dims['o'] +
                             dims['g'],
                             dims['o'],
                             dims['u'],
                             dims['g'],
                             0.,
                             gail_loss=v['gail_reward'],
                             only_s=v['only_s'])
    params['discriminator'] = discriminator

    # configure replay buffer for expert buffer
    params_expert = {
        k: params[k]
        for k in [
            'make_env', 'replay_k', 'discriminator', 'gail_weight', 'two_rs',
            'with_termination'
        ]
    }
    params_expert[
        'replay_strategy'] = 'future' if v['relabel_expert'] else 'none'

    params_policy_buffer = {
        k: params[k]
        for k in [
            'make_env', 'replay_k', 'discriminator', 'gail_weight', 'two_rs',
            'with_termination'
        ]
    }
    params_policy_buffer['replay_strategy'] = 'future'

    params_empty = {
        k: params[k]
        for k in [
            'make_env', 'replay_k', 'discriminator', 'gail_weight',
            'replay_strategy'
        ]
    }

    policy = config.configure_ddpg(dims=dims,
                                   params=params,
                                   clip_return=v['clip_return'],
                                   reuse=tf.AUTO_REUSE,
                                   env=env,
                                   to_goal=v['to_goal'])

    rollout_params = {
        'exploit': False,
        'use_target_net': False,
        'use_demo_states': True,
        'compute_Q': True,
        'T': params['T'],
        'weight': v['goal_weight'],
        'rollout_terminate': v['rollout_terminate'],
        'to_goal': v['to_goal']
    }

    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'use_demo_states': False,
        'compute_Q': True,
        'T': params['T'],
        'weight': v['goal_weight'],
        'rollout_terminate': v['rollout_terminate'],
        'to_goal': v['to_goal']
    }

    for name in [
            'T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps'
    ]:
        rollout_params[name] = params[name]
        eval_params[name] = params[name]

    rollout_worker = RolloutWorker([env], policy, dims, logger_b,
                                   **rollout_params)
    # rollout_worker.seed(rank_seed)

    evaluator = RolloutWorker([env], policy, dims, logger_b, **eval_params)
    # evaluator.seed(rank_seed)

    n_traj = v['n_evaluation_traj']

    logger.log("Initializing report and plot_policy_reward...")
    log_dir = logger.get_snapshot_dir()
    inner_log_dir = osp.join(log_dir, 'inner_iters')
    report = HTMLReport(osp.join(log_dir, 'report.html'), images_per_row=3)
    report.add_header("{}".format(EXPERIMENT_TYPE))
    report.add_text(format_dict(v))

    logger.log("Starting the outer iterations")

    logger.log("Generating heat map")

    def evaluate_pnp(env, policy, n_rollouts=100):
        goal_reached = []
        distance_to_goal = []
        for i in range(n_rollouts):
            traj = rollout(env,
                           policy,
                           max_path_length=v['horizon'],
                           using_gym=True)
            goal_reached.append(np.max(traj['env_infos']['goal_reached']))
            distance_to_goal.append(np.min(traj['env_infos']['distance']))

        return np.mean(goal_reached), np.mean(distance_to_goal)

    from sandbox.experiments.goals.pick_n_place.pnp_expert import PnPExpert

    expert_policy = PnPExpert(env)

    expert_params = {
        'exploit': not v['noisy_expert'],
        'use_target_net': False,
        'use_demo_states': False,
        'compute_Q': False,
        'T': params['T'],
        'weight': v['goal_weight'],
        'rollout_terminate': v['rollout_terminate'],
        'to_goal': v['to_goal']
    }

    for name in [
            'T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps'
    ]:
        expert_params[name] = params[name]

    expert_params['noise_eps'] = v['expert_noise']
    expert_params['random_eps'] = v['expert_eps']

    expert_worker = RolloutWorker([env], expert_policy, dims, logger_b,
                                  **expert_params)

    input_shapes = dims_to_shapes(dims)
    expert_sample_transitions = config.configure_her(params_expert)
    buffer_shapes = {
        key:
        (v['horizon'] if key != 'o' else v['horizon'] + 1, *input_shapes[key])
        for key, val in input_shapes.items()
    }
    buffer_shapes['g'] = (buffer_shapes['g'][0],
                          3 if not v['full_space_as_goal'] else 6)
    buffer_shapes['ag'] = (v['horizon'] + 1,
                           3 if not v['full_space_as_goal'] else 6)
    buffer_shapes['successes'] = (v['horizon'], )
    expert_buffer = ReplayBuffer(buffer_shapes, int(1e6), v['horizon'],
                                 expert_sample_transitions)
    policy.expert_buffer = expert_buffer

    sample_transitions_relabel = config.configure_her(params_policy_buffer)

    for _ in range(v['num_demos']):
        # rollout is generated by expert policy
        episode = expert_worker.generate_rollouts(
            slice_goal=(3, 6) if v['full_space_as_goal'] else None)
        # and is stored into the current expert buffer
        expert_buffer.store_episode(episode)

        # TODO: what is subsampling_rate
    uninitialized_vars = []
    for var in tf.global_variables():
        try:
            tf.get_default_session().run(var)
        except tf.errors.FailedPreconditionError:
            uninitialized_vars.append(var)

    init_new_vars_op = tf.initialize_variables(uninitialized_vars)
    tf.get_default_session().run(init_new_vars_op)

    max_success, min_distance = evaluate_pnp(env, policy)
    outer_iter = 0
    logger.record_tabular("Outer_iter", outer_iter)
    logger.record_tabular("Outer_Success", max_success)
    logger.record_tabular("MinDisToGoal", min_distance)
    logger.dump_tabular()

    for outer_iter in range(1, v['outer_iters']):
        logger.log("Outer itr # %i" % outer_iter)

        with ExperimentLogger(inner_log_dir,
                              outer_iter,
                              snapshot_mode='last',
                              hold_outter_log=True):
            train(
                policy,
                discriminator,
                rollout_worker,
                v['inner_iters'],
                v['n_cycles'],
                v['n_batches'],
                v['n_batches_dis'],
                policy.buffer,
                expert_buffer,
                empty_buffer=empty_buffer if v['on_policy_dis'] else None,
                num_rollouts=v['num_rollouts'],
                feasible_states=feasible_states if v['query_expert'] else None,
                expert_policy=expert_policy if v['query_expert'] else None,
                agent_policy=policy if v['query_agent'] else None,
                train_dis_per_rollout=v['train_dis_per_rollout'],
                noise_expert=v['noise_dis_agent'],
                noise_agent=v['noise_dis_expert'],
                sample_transitions_relabel=sample_transitions_relabel
                if v['relabel_for_policy'] else None,
                outer_iter=outer_iter,
                annealing_coeff=v['annealing_coeff'],
                q_annealing=v['q_annealing'])

        print("evaluating policy performance")

        logger.log("Generating heat map")

        success, min_distance = evaluate_pnp(env, policy)

        logger.record_tabular("Outer_iter", outer_iter)
        logger.record_tabular("Outer_Success", max_success)
        logger.record_tabular("MinDisToGoal", min_distance)
        logger.dump_tabular()

        if success > max_success:
            print("% f >= %f, saving policy to params_best" %
                  (success, max_success))
            with open(osp.join(log_dir, 'params_best.pkl'), 'wb') as f:
                cloudpickle.dump({'env': env, 'policy': policy}, f)
            max_success = success

        report.save()
        report.new_row()
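
Examples #7 and #8 share a TF1-era idiom for initializing only the variables that are still uninitialized after the policy has been loaded: try to read each global variable and collect those that raise `FailedPreconditionError`. Isolated as a helper, the pattern looks like the sketch below (TF1 API; `tf.variables_initializer` is the non-deprecated spelling of the `tf.initialize_variables` call used above).

import tensorflow as tf

def initialize_uninitialized(sess):
    # Reading an uninitialized variable raises FailedPreconditionError.
    uninitialized = []
    for var in tf.global_variables():
        try:
            sess.run(var)
        except tf.errors.FailedPreconditionError:
            uninitialized.append(var)
    if uninitialized:
        sess.run(tf.variables_initializer(uninitialized))
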
Code Example #8
def run_task(v):
    random.seed(v['seed'])
    np.random.seed(v['seed'])
    sampling_res = 2 if 'sampling_res' not in v.keys() else v['sampling_res']

    # Log performance of randomly initialized policy with FIXED goal [0.1, 0.1]
    logger.log("Initializing report and plot_policy_reward...")
    log_dir = logger.get_snapshot_dir()  # problem with logger module here!!
    report = HTMLReport(osp.join(log_dir, 'report.html'), images_per_row=4)

    report.add_header("{}".format(EXPERIMENT_TYPE))
    report.add_text(format_dict(v))

    if v['control_mode'] == 'linear':
        from sandbox.envs.maze.point_maze_env import PointMazeEnv
        inner_env = normalize(PointMazeEnv(maze_id=v['maze_id'], maze_size_scaling=v['maze_scaling'], control_mode=v['control_mode']))
        inner_env_test = normalize(PointMazeEnv(maze_id=v['maze_id'], maze_size_scaling=v['maze_scaling'], control_mode=v['control_mode']))
    elif v['control_mode'] == 'pos':
        from sandbox.envs.maze.point_maze_pos_env import PointMazeEnv
        inner_env = normalize(PointMazeEnv(maze_id=v['maze_id'], maze_size_scaling=v['maze_scaling'], control_mode=v['control_mode']))
        inner_env_test = normalize(PointMazeEnv(maze_id=v['maze_id'], maze_size_scaling=v['maze_scaling'], control_mode=v['control_mode']))

    uniform_goal_generator = UniformStateGenerator(state_size=v['goal_size'], bounds=v['goal_range'],
                                                   center=v['goal_center'])

    num_cpu = 1
    if num_cpu > 1:
        try:
            whoami = mpi_fork(num_cpu, ['--bind-to', 'core'])
            print("fancy call succeeded")
        except CalledProcessError:
            print("fancy version of mpi call failed, try simple version")
            whoami = mpi_fork(num_cpu)

        if whoami == 'parent':
            sys.exit(0)
        import baselines.common.tf_util as U
        U.single_threaded_session().__enter__()

    # Configure logging
    rank = MPI.COMM_WORLD.Get_rank()
    logdir = ''
    if rank == 0:
        if logdir or logger_b.get_dir() is None:
            logger_b.configure(dir=logdir)
    else:
        logger_b.configure()
    logdir = logger_b.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)

    # Seed everything.
    rank_seed = v['seed'] + 1000000 * rank
    set_global_seeds(rank_seed)

    feasible_states = sample_unif_feas(inner_env, 10)
    if v['unif_starts']:
        starts = np.random.permutation(np.array(feasible_states))[:300]
    else:
        starts = np.array([[0, 0]])

    uniform_start_generator = UniformListStateGenerator(
        starts.tolist(), persistence=v['persistence'], with_replacement=v['with_replacement'], )


    # Prepare params.
    def make_env(inner_env=inner_env, terminal_eps=v['terminal_eps'], terminate_env=v['terminate_env']):
        return GoalStartExplorationEnv(
            env=inner_env, goal_generator=uniform_goal_generator,
            obs2goal_transform=lambda x: x[:v['goal_size']],
            start_generator=uniform_start_generator,
            obs2start_transform=lambda x: x[:v['goal_size']],
            terminal_eps=terminal_eps,
            distance_metric=v['distance_metric'],
            extend_dist_rew=v['extend_dist_rew'],
            only_feasible=v['only_feasible'],
            terminate_env=terminate_env,
            goal_weight=v['goal_weight'],
            inner_weight=0,
            append_goal_to_observation=False
        )


    env = make_env()
    test_env = make_env(inner_env=inner_env_test, terminal_eps=1., terminate_env=True)

    params = config.DEFAULT_PARAMS
    params['action_l2'] = v['action_l2']
    params['max_u'] = v['max_u']
    params['gamma'] = v['discount']
    params['env_name'] = 'FetchReach-v0'
    params['replay_strategy'] = v['replay_strategy']
    params['lr'] = v['lr']
    params['layers'] = v['layers']
    params['hidden'] = v['hidden']
    params['n_cycles'] = v['n_cycles']  # cycles per epoch
    params['n_batches'] = v['n_batches']  # training batches per cycle
    params['batch_size'] = v['batch_size']  # per mpi thread, measured in transitions and reduced to even multiple of chunk_length.
    params['n_test_rollouts'] = v['n_test_rollouts']  # changed from 10 to 3 # number of test rollouts per epoch, each consists of rollout_batch_size rollouts
    # exploration
    params['random_eps'] = 0.3  # percentage of time a random action is taken
    params['noise_eps'] = v['action_noise']
    params['goal_weight'] = v['goal_weight']

    with open(os.path.join(logger_b.get_dir(), 'params.json'), 'w') as f:
        json.dump(params, f)

    params['T'] = v['horizon']
    params['to_goal'] = v['to_goal']
    params['nearby_action_penalty'] = v['nearby_action_penalty']
    params['nearby_penalty_weight'] = v['nearby_penalty_weight']
    params['nearby_p'] = v['nearby_p']
    params['perturb_scale'] = v['perturb_scale']
    params['cells_apart'] = v['cells_apart']
    params['perturb_to_feasible'] = v['perturb_to_feasible']

    params['sample_expert'] = v['sample_expert']
    params['expert_batch_size'] = v['expert_batch_size']
    params['bc_loss'] = v['bc_loss']
    params['anneal_bc'] = v['anneal_bc']
    params['gail_weight'] = v['gail_weight']
    params['terminate_bootstrapping'] = v['terminate_bootstrapping']
    params['mask_q'] = int(v['mode'] == 'pure_bc')
    params['two_qs'] = v['two_qs']
    params['anneal_discriminator'] = v['anneal_discriminator']
    params['two_rs'] = v['two_qs'] or v['anneal_discriminator']
    params['with_termination'] = v['rollout_terminate']


    if 'clip_dis' in v and v['clip_dis']:
        params['dis_bound'] = v['clip_dis']

    params = config.prepare_params(params)
    params['make_env'] = make_env
    config.log_params(params, logger=logger_b)

    dims = config.configure_dims(params)
    # prepare GAIL

    if v['use_s_p']:
        discriminator = GAIL(dims['o'] + dims['o'] + dims['g'] if not v['only_s'] else dims['o'] + dims['g'],
                             dims['o'], dims['o'], dims['g'], 0., gail_loss = v['gail_reward'], use_s_p = True, only_s=v['only_s'])
    else:
        discriminator = GAIL(dims['o'] + dims['u'] + dims['g'] if not v['only_s'] else dims['o'] + dims['g'],
                             dims['o'], dims['u'], dims['g'], 0., gail_loss = v['gail_reward'], only_s=v['only_s'])

    params['discriminator'] = discriminator
    # configure replay buffer for expert buffer
    params_expert = {k:params[k] for k in ['make_env', 'replay_k', 'discriminator', 'gail_weight', 'two_rs', 'with_termination']}
    params_expert['replay_strategy'] = 'future' if v['relabel_expert'] else 'none'
    params_expert['sample_g_first'] = v['relabel_expert'] and v['sample_g_first']
    params_expert['zero_action_p'] = v['zero_action_p']

    params_policy_buffer = {k: params[k] for k in ['make_env', 'replay_k', 'discriminator', 'gail_weight', 'two_rs', 'with_termination']}
    params_policy_buffer['replay_strategy'] = 'future'
    params_policy_buffer['sample_g_first'] = False

    policy = config.configure_ddpg(dims=dims, params=params, clip_return=v['clip_return'], reuse=tf.AUTO_REUSE, env=env)


    rollout_params = {
        'exploit': False,
        'use_target_net': False,
        'use_demo_states': True,
        'compute_Q': True,
        'T': params['T'],
        'weight': v['goal_weight'],
        'rollout_terminate': v['rollout_terminate']
    }

    expert_rollout_params = {
        'exploit': not v['noisy_expert'],
        'use_target_net': False,
        'use_demo_states': True,
        'compute_Q': False,
        'T': params['T'],
        'weight': v['goal_weight'],
        'rollout_terminate': v['rollout_terminate']
    }

    for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
        rollout_params[name] = params[name]
        expert_rollout_params[name] = params[name]

    expert_rollout_params['noise_eps'] = v['expert_noise']
    expert_rollout_params['random_eps'] = v['expert_eps']

    rollout_worker = RolloutWorker([env], policy, dims, logger_b, **rollout_params)

    # prepare expert policy, rollout worker
    import joblib
    if v['expert_policy'] == 'planner':
        from sandbox.experiments.goals.maze.expert.maze_expert import MazeExpert
        expert_policy = MazeExpert(inner_env, step_size=0.2)
    else:
        expert_policy = joblib.load(os.path.join(os.path.dirname(os.path.abspath(__file__)), v['expert_policy']))['policy']

    expert_rollout_worker = RolloutWorker([env], expert_policy, dims, logger_b, **expert_rollout_params)
    input_shapes = dims_to_shapes(dims)
    expert_sample_transitions = config.configure_her(params_expert)
    buffer_shapes = {key: (v['horizon'] if key != 'o' else v['horizon'] + 1, *input_shapes[key])
                     for key, val in input_shapes.items()}
    buffer_shapes['g'] = (buffer_shapes['g'][0], 2)
    buffer_shapes['ag'] = (v['horizon'] + 1, 2)
    buffer_shapes['successes'] = (v['horizon'],)
    expert_buffer = ReplayBuffer(buffer_shapes, int(1e6), v['horizon'], expert_sample_transitions)
    policy.expert_buffer = expert_buffer

    sample_transitions_relabel = config.configure_her(params_policy_buffer)

    normal_sample_transitions = policy.sample_transitions
    empty_buffer = ReplayBuffer(buffer_shapes, int(1e6), v['horizon'], normal_sample_transitions)

    if not v['query_expert'] or not 'gail' in v['mode']:
        for i in range(v['num_demos']):
            # rollout is generated by expert policy
            episode = expert_rollout_worker.generate_rollouts(reset=not v['no_resets'])
            # and is stored into the expert buffer
            expert_buffer.store_episode(episode)
            if i <= 20:
                path_length = np.argmax(episode['info_goal_reached'][0])
                path_length = v['horizon'] - 1 if path_length == 0 else path_length

                plot_path(episode['o'][0][:path_length], report=report, obs=True, goal=episode['g'][0][0], limit=v['goal_range'], center=v['goal_center'])
    report.new_row()


    # TODO: what is subsampling_rate
    uninitialized_vars = []
    for var in tf.global_variables():
        try:
            tf.get_default_session().run(var)
        except tf.errors.FailedPreconditionError:
            uninitialized_vars.append(var)

    init_new_vars_op = tf.initialize_variables(uninitialized_vars)
    tf.get_default_session().run(init_new_vars_op)

    outer_iter = 0
    logger.log('Generating the Initial Heatmap...')



    def evaluate_performance(env):
        four_rooms = np.array([[-2, -2], [-13, -13]])
        if v['unif_starts']:
            mean_rewards, successes = [], []
            for pos in four_rooms:
                env.update_start_generator(FixedStateGenerator(np.array(pos)))
                mr, scs = test_and_plot_policy(policy, env, horizon=v['horizon'],  max_reward=v['max_reward'], sampling_res=sampling_res,
                                               n_traj=v['n_traj'],
                                               itr=outer_iter, report=report, limit=v['goal_range'],
                                               center=v['goal_center'], using_gym=True,
                                               noise=v['action_noise'], n_processes=8, log=False)
                mean_rewards.append(mr)
                successes.append(scs)
            with logger.tabular_prefix('Outer_'):
                logger.record_tabular('iter', outer_iter)
                logger.record_tabular('MeanRewards', np.mean(mean_rewards))
                logger.record_tabular('Success', np.mean(successes))
        else:
            env.update_start_generator(FixedStateGenerator(np.array([0, 0])))
            _, scs = test_and_plot_policy(policy, env, horizon=v['horizon'], max_reward=v['max_reward'], sampling_res=sampling_res,
                                          n_traj=v['n_traj'],
                                          itr=outer_iter, report=report, limit=v['goal_range'], center=v['goal_center'],
                                          using_gym=True,
                                          noise=v['action_noise'], n_processes=8)

        report.new_row()

        env.update_start_generator(uniform_start_generator)

        return scs

    logger.dump_tabular(with_prefix=False)

    import cloudpickle
    max_success = 0.

    if not v['query_expert'] and v['num_demos'] > 0:
        if not v['relabel_expert']:
            goals = goals_filtered = expert_buffer.buffers['g'][:v['num_demos'], 0, :]
        else: # collect all states visited by the expert
            goals = None
            for i in range(v['num_demos']):
                terminate_index = np.argmax(expert_buffer.buffers['successes'][i])
                if np.logical_not(np.any(expert_buffer.buffers['successes'][i])):
                    terminate_index = v['horizon']
                cur_goals = expert_buffer.buffers['o'][i, :terminate_index, :2]
                if goals is None:
                    goals = cur_goals
                else:
                    goals = np.concatenate([goals, cur_goals])
            goal_state_collection = StateCollection(distance_threshold=v['coll_eps'])
            goal_state_collection.append(goals)
            goals_filtered = goal_state_collection.states
    else:
        goals_filtered = goals = np.random.permutation(np.array(feasible_states))[:300]
    if v['agent_state_as_goal']:
        goals = goals
    else:
        feasible_states = sample_unif_feas(inner_env, 10)
        goals = np.random.permutation(np.array(feasible_states))[:300]

    evaluate_performance(test_env)
    logger.dump_tabular(with_prefix=False)


    for outer_iter in range(1, v['outer_iters']):

        logger.log("Outer itr # %i" % outer_iter)

        with ExperimentLogger(log_dir, 'last', snapshot_mode='last', hold_outter_log=True):
            logger.log("Updating the environment goal generator")
            if v['unif_goals']:
                env.update_goal_generator(
                    UniformListStateGenerator(
                        goals.tolist(), persistence=v['persistence'], with_replacement=v['with_replacement'],
                    )
                )
            else:
                env.update_goal_generator(FixedStateGenerator(v['final_goal']))

            logger.log("Training the algorithm")

            train(policy, discriminator, rollout_worker, v['inner_iters'], v['n_cycles'], v['n_batches'], v['n_batches_dis'], policy.buffer, expert_buffer,
                  empty_buffer=empty_buffer if v['on_policy_dis'] else None, num_rollouts=v['num_rollouts'], reset=not v['no_resets'],
                  feasible_states=feasible_states if v['query_expert'] else None, expert_policy=expert_policy if v['query_expert'] else None,
                  agent_policy=policy if v['query_agent'] else None, train_dis_per_rollout=v['train_dis_per_rollout'],
                  noise_expert=v['noise_dis_agent'], noise_agent=v['noise_dis_expert'], sample_transitions_relabel=sample_transitions_relabel if v['relabel_for_policy'] else None,
                  q_annealing=v['q_annealing'], outer_iter=outer_iter, annealing_coeff=v['annealing_coeff'])

        # logger.record_tabular('NonZeroRewProp', nonzeros)
        logger.log('Generating the Heatmap...')

        success = evaluate_performance(test_env)

        if success > max_success:
            print ("% f >= %f, saving policy to params_best" % (success, max_success))
            with open(osp.join(log_dir, 'params_best.pkl'), 'wb') as f:
                cloudpickle.dump({'env': env, 'policy': policy}, f)
                max_success = success

        report.new_row()

        logger.dump_tabular(with_prefix=False)