def experiment(variant):

    # create multi-task environment and sample tasks
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    tasks = env.get_all_task_idx()
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    reward_dim = 1

    # instantiate networks
    latent_dim = variant['latent_size']
    context_encoder_input_dim = (
        2 * obs_dim + action_dim + reward_dim
        if variant['algo_params']['use_next_obs_in_context']
        else obs_dim + action_dim + reward_dim)
    context_encoder_output_dim = (
        latent_dim * 2
        if variant['algo_params']['use_information_bottleneck']
        else latent_dim)
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + latent_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(latent_dim, context_encoder, policy,
                       **variant['algo_params'])
    algorithm = PEARLSoftActorCritic(
        env=env,
        train_tasks=list(tasks[:variant['n_train_tasks']]),
        eval_tasks=list(tasks[-variant['n_eval_tasks']:]),
        nets=[agent, qf1, qf2, vf],
        latent_dim=latent_dim,
        **variant['algo_params'])

    # optionally load pre-trained weights
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        context_encoder.load_state_dict(
            torch.load(os.path.join(path, 'context_encoder.pth')))
        qf1.load_state_dict(torch.load(os.path.join(path, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path, 'qf2.pth')))
        vf.load_state_dict(torch.load(os.path.join(path, 'vf.pth')))
        # TODO hacky, revisit after model refactor
        algorithm.networks[-2].load_state_dict(
            torch.load(os.path.join(path, 'target_vf.pth')))
        policy.load_state_dict(torch.load(os.path.join(path, 'policy.pth')))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'],
                     variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        algorithm.to()

    # debugging triggers a lot of printing and logs to a debug directory
    DEBUG = variant['util_params']['debug']
    os.environ['DEBUG'] = str(int(DEBUG))

    # create logging directory
    # TODO support Docker
    exp_id = 'debug' if DEBUG else None
    experiment_log_dir = setup_logger(
        variant['env_name'],
        variant=variant,
        exp_id=exp_id,
        base_log_dir=variant['util_params']['base_log_dir'])

    # optionally save eval trajectories as pkl files
    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = experiment_log_dir + '/eval_trajectories'
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    # run the algorithm
    algorithm.train()
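
For reference, a minimal sketch of the variant dictionary this launcher reads. The key names are taken from the accesses above; the values are illustrative placeholders, not the repository's defaults.

variant_sketch = {
    'env_name': 'cheetah-dir',        # placeholder; must be a key of ENVS
    'env_params': {},
    'n_train_tasks': 2,
    'n_eval_tasks': 2,
    'latent_size': 5,
    'net_size': 300,
    'path_to_weights': None,          # or a directory containing the *.pth files loaded above
    'algo_params': {
        'use_next_obs_in_context': False,
        'use_information_bottleneck': True,
        'recurrent': False,
        'dump_eval_paths': False,
        # remaining PEARLAgent / PEARLSoftActorCritic keyword arguments go here
    },
    'util_params': {
        'use_gpu': False,
        'gpu_id': 0,
        'debug': False,
        'base_log_dir': 'output',
    },
}
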
Example #2
def experiment(variant):

    domain = variant['domain']
    seed = variant['seed']
    exp_mode = variant['exp_mode']
    max_path_length = variant['algo_params']['max_path_length']
    bcq_interactions = variant['bcq_interactions']
    num_tasks = variant['num_tasks']

    filename = f'./goals/{domain}-{exp_mode}-goals.pkl'
    with open(filename, 'rb') as f:
        idx_list, train_goals, wd_goals, ood_goals = pickle.load(f)
    idx_list = idx_list[:num_tasks]

    sub_buffer_dir = f"buffers/{domain}/{exp_mode}/max_path_length_{max_path_length}/interactions_{bcq_interactions}k/seed_{seed}"
    buffer_dir = os.path.join(variant['data_models_root'], sub_buffer_dir)

    print("Buffer directory: " + buffer_dir)

    # Load buffer
    bcq_buffers = []

    buffer_loader_id_list = []
    for i, idx in enumerate(idx_list):
        bname = f'goal_0{idx}.zip_pkl' if idx < 10 else f'goal_{idx}.zip_pkl'
        filename = os.path.join(buffer_dir, bname)
        rp_buffer = ReplayBuffer.remote(
            index=i,
            seed=seed,
            num_trans_context=variant['num_trans_context'],
            in_mdp_batch_size=variant['in_mdp_batch_size'],
        )

        buffer_loader_id_list.append(rp_buffer.load_from_gzip.remote(filename))
        bcq_buffers.append(rp_buffer)
    ray.get(buffer_loader_id_list)

    assert len(bcq_buffers) == len(idx_list)

    train_buffer = MultiTaskReplayBuffer(bcq_buffers_list=bcq_buffers)

    set_seed(variant['seed'])

    # create multi-task environment and sample tasks
    env = env_producer(variant['domain'], seed=0)
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    reward_dim = 1

    # instantiate networks
    latent_dim = variant['latent_size']
    context_encoder_input_dim = (
        2 * obs_dim + action_dim + reward_dim
        if variant['algo_params']['use_next_obs_in_context']
        else obs_dim + action_dim + reward_dim)
    context_encoder_output_dim = (
        latent_dim * 2
        if variant['algo_params']['use_information_bottleneck']
        else latent_dim)
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + latent_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(latent_dim, context_encoder, policy,
                       **variant['algo_params'])
    algorithm = PEARLSoftActorCritic(env=env,
                                     train_goals=train_goals,
                                     wd_goals=wd_goals,
                                     ood_goals=ood_goals,
                                     replay_buffers=train_buffer,
                                     nets=[agent, qf1, qf2, vf],
                                     latent_dim=latent_dim,
                                     **variant['algo_params'])

    # optionally load pre-trained weights
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        context_encoder.load_state_dict(
            torch.load(os.path.join(path, 'context_encoder.pth')))
        qf1.load_state_dict(torch.load(os.path.join(path, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path, 'qf2.pth')))
        vf.load_state_dict(torch.load(os.path.join(path, 'vf.pth')))
        # TODO hacky, revisit after model refactor
        algorithm.networks[-2].load_state_dict(
            torch.load(os.path.join(path, 'target_vf.pth')))
        policy.load_state_dict(torch.load(os.path.join(path, 'policy.pth')))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'],
                     variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        algorithm.to()

    # debugging triggers a lot of printing and logs to a debug directory
    DEBUG = variant['util_params']['debug']
    os.environ['DEBUG'] = str(int(DEBUG))

    # create logging directory
    # TODO support Docker
    exp_id = 'debug' if DEBUG else None
    experiment_log_dir = setup_logger(
        variant['domain'],
        variant=variant,
        exp_id=exp_id,
        base_log_dir=variant['util_params']['base_log_dir'])

    # optionally save eval trajectories as pkl files
    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = experiment_log_dir + '/eval_trajectories'
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    # run the algorithm
    algorithm.train()
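
The buffer-loading loop above uses Ray remote actors. A minimal, self-contained sketch of that pattern, with a stand-in Loader class instead of the repository's ReplayBuffer actor (illustrative only):

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class Loader:
    # stand-in for the remote ReplayBuffer actor used above
    def __init__(self, index):
        self.index = index
        self.path = None

    def load(self, filename):
        # stand-in for ReplayBuffer.load_from_gzip
        self.path = filename
        return self.index

loaders = [Loader.remote(i) for i in range(4)]
# launch every load asynchronously, then block until all of them have finished
ray.get([ldr.load.remote(f'goal_{i:02d}.zip_pkl') for i, ldr in enumerate(loaders)])
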
Example #3
def experiment(variant):
    print(variant['env_name'])
    print(variant['env_params'])
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    tasks = env.get_all_task_idx()

    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))

    cont_latent_dim, num_cat, latent_dim, num_dir, dir_latent_dim = read_dim(variant['global_latent'])
    r_cont_dim, r_n_cat, r_cat_dim, r_n_dir, r_dir_dim = read_dim(variant['vrnn_latent'])
    reward_dim = 1
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    glob = variant['algo_params']['glob']
    rnn = variant['rnn']
    vrnn_latent = variant['vrnn_latent']
    encoder_model = MlpEncoder
    if recurrent:
        if variant['vrnn_constraint'] == 'logitnormal':
            output_size = r_cont_dim * 2 + r_n_cat * r_cat_dim + r_n_dir * r_dir_dim * 2
        else:
            output_size = r_cont_dim * 2 + r_n_cat * r_cat_dim + r_n_dir * r_dir_dim
        if variant['rnn_sample'] == 'batch_sampling':
            if variant['algo_params']['use_next_obs']:
                input_size = (2 * obs_dim + action_dim + reward_dim) * variant['temp_res']
            else:
                input_size = (obs_dim + action_dim + reward_dim) * variant['temp_res']
        else:
            if variant['algo_params']['use_next_obs']:
                input_size = (2 * obs_dim + action_dim + reward_dim)
            else:
                input_size = (obs_dim + action_dim + reward_dim)
        if rnn == 'rnn':
            recurrent_model = RecurrentEncoder
            recurrent_context_encoder = recurrent_model(
                hidden_sizes=[net_size, net_size, net_size],
                input_size=input_size,
                output_size=output_size,
            )
        elif rnn == 'vrnn':
            recurrent_model = VRNNEncoder
            recurrent_context_encoder = recurrent_model(
                hidden_sizes=[net_size, net_size, net_size],
                input_size=input_size,
                output_size=output_size, 
                temperature=variant['temperature'],
                vrnn_latent=variant['vrnn_latent'],
                vrnn_constraint=variant['vrnn_constraint'],
                r_alpha=variant['vrnn_alpha'],
                r_var=variant['vrnn_var'],
            )

    else:
        recurrent_context_encoder = None

    ptu.set_gpu_mode(variant['util_params']['use_gpu'], variant['util_params']['gpu_id'])
    if glob:
        if dir_latent_dim > 0 and variant['constraint'] == 'logitnormal':
            output_size = cont_latent_dim * 2 + num_cat * latent_dim + num_dir * dir_latent_dim * 2
        else:
            output_size = cont_latent_dim * 2 + num_cat * latent_dim + num_dir * dir_latent_dim
        if variant['algo_params']['use_next_obs']:
            input_size = 2 * obs_dim + action_dim + reward_dim
        else:
            input_size = obs_dim + action_dim + reward_dim
        global_context_encoder = encoder_model(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=input_size,
            output_size=output_size, 
        )
    else:
        global_context_encoder = None      
    # total task-latent width appended to observations (and actions) below
    z_dim = (latent_dim * num_cat + cont_latent_dim + dir_latent_dim * num_dir
             + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim)
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + z_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + z_dim,
        output_size=1,
    )
    target_qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + z_dim,
        output_size=1,
    )
    target_qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + z_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + z_dim,
        latent_dim=z_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(
        global_context_encoder,
        recurrent_context_encoder,
        variant['global_latent'],
        variant['vrnn_latent'],
        policy,
        variant['temperature'],
        variant['unitkl'],
        variant['alpha'],
        variant['constraint'],
        variant['vrnn_constraint'],
        variant['var'],
        variant['vrnn_alpha'],
        variant['vrnn_var'],
        rnn,
        variant['temp_res'],
        variant['rnn_sample'],
        variant['weighted_sample'],
        **variant['algo_params']
    )
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        with open(os.path.join(path, 'extra_data.pkl'), 'rb') as f:
            extra_data = pickle.load(f)
            variant['algo_params']['start_epoch'] = extra_data['epoch'] + 1
            replay_buffer = extra_data['replay_buffer']
            enc_replay_buffer = extra_data['enc_replay_buffer']
            variant['algo_params']['_n_train_steps_total'] = extra_data['_n_train_steps_total']
            variant['algo_params']['_n_env_steps_total'] = extra_data['_n_env_steps_total']
            variant['algo_params']['_n_rollouts_total'] = extra_data['_n_rollouts_total']
    else:
        replay_buffer = None
        enc_replay_buffer = None

    algorithm = PEARLSoftActorCritic(
        env=env,
        train_tasks=list(tasks[:variant['n_train_tasks']]),
        eval_tasks=list(tasks[-variant['n_eval_tasks']:]),
        nets=[agent, qf1, qf2, target_qf1, target_qf2],
        latent_dim=latent_dim,
        replay_buffer=replay_buffer,
        enc_replay_buffer=enc_replay_buffer,
        temp_res=variant['temp_res'],
        rnn_sample=variant['rnn_sample'],
        **variant['algo_params']
    )

    if variant['path_to_weights'] is not None: 
        path = variant['path_to_weights']
        if recurrent_context_encoder is not None:
            recurrent_context_encoder.load_state_dict(torch.load(os.path.join(path, 'recurrent_context_encoder.pth')))
        if global_context_encoder is not None:
            global_context_encoder.load_state_dict(torch.load(os.path.join(path, 'global_context_encoder.pth')))
        qf1.load_state_dict(torch.load(os.path.join(path, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path, 'qf2.pth')))
        target_qf1.load_state_dict(torch.load(os.path.join(path, 'target_qf1.pth')))
        target_qf2.load_state_dict(torch.load(os.path.join(path, 'target_qf2.pth')))
        policy.load_state_dict(torch.load(os.path.join(path, 'policy.pth')))

    if ptu.gpu_enabled():
        algorithm.to()

    DEBUG = variant['util_params']['debug']
    os.environ['DEBUG'] = str(int(DEBUG))
    exp_id = 'debug' if DEBUG else None
    if variant.get('log_name', "") == "":
        log_name = variant['env_name']
    else:
        log_name = variant['log_name']
    experiment_log_dir = setup_logger(
        log_name,
        variant=variant,
        exp_id=exp_id,
        base_log_dir=variant['util_params']['base_log_dir'],
        config_log_dir=variant['util_params']['config_log_dir'],
        log_dir=variant['util_params']['log_dir'])
    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = experiment_log_dir + '/eval_trajectories'
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    env.save_all_tasks(experiment_log_dir)

    if variant['eval']:
        algorithm._try_to_eval(0, eval_all=True, eval_train_offline=False, animated=True)
    else:
        algorithm.train()
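
The encoder output sizes computed above follow the same rule in both the recurrent and the global branch; a small illustrative helper restating it (not code from this repository): continuous dimensions contribute two parameters each (presumably mean and variance), each categorical contributes one logit per class, and the Dirichlet dimensions are doubled only under the 'logitnormal' constraint.

def encoder_head_size(cont_dim, n_cat, cat_dim, n_dir, dir_dim, constraint):
    # restates the output_size arithmetic used for both context encoders above
    size = cont_dim * 2 + n_cat * cat_dim + n_dir * dir_dim
    if constraint == 'logitnormal':
        size += n_dir * dir_dim   # second parameter set for the Dirichlet dimensions
    return size
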
Example #4
def setup_and_run(variant):

    ptu.set_gpu_mode(variant['util_params']['use_gpu'],
                     variant['seed'] % variant['util_params']['num_gpus'])
    #setup env
    env_name = variant['env_name']
    env_params = variant['env_params']
    env_params['n_tasks'] = variant["n_train_tasks"] + variant["n_eval_tasks"]
    env = NormalizedBoxEnv(ENVS[env_name](**env_params))

    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    latent_dim = variant['latent_size']
    reward_dim = 1

    #setup encoder
    context_encoder_input_dim = (
        2 * obs_dim + action_dim + reward_dim
        if variant['algo_params']['use_next_obs_in_context']
        else obs_dim + action_dim + reward_dim)
    context_encoder_output_dim = (
        latent_dim * 2
        if variant['algo_params']['use_information_bottleneck']
        else latent_dim)
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )

    #setup actor, critic
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    target_qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    target_qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(latent_dim, context_encoder, policy,
                       **variant['algo_params'])

    algorithm = PEARLSoftActorCritic(
        env=env,
        train_tasks=list(np.arange(variant['n_train_tasks'])),
        eval_tasks=list(
            np.arange(variant['n_train_tasks'],
                      variant['n_train_tasks'] + variant['n_eval_tasks'])),
        nets=[agent, qf1, qf2, target_qf1, target_qf2],
        latent_dim=latent_dim,
        **variant['algo_params'])
    # optionally load pre-trained weights
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        context_encoder.load_state_dict(
            torch.load(os.path.join(path, 'context_encoder.pth')))
        qf1.load_state_dict(torch.load(os.path.join(path, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path, 'qf2.pth')))

        target_qf1.load_state_dict(
            torch.load(os.path.join(path, 'target_qf1.pth')))
        target_qf2.load_state_dict(
            torch.load(os.path.join(path, 'target_qf2.pth')))

        # TODO hacky, revisit after model refactor
        policy.load_state_dict(torch.load(os.path.join(path, 'policy.pth')))

    if ptu.gpu_enabled():
        algorithm.to()

    os.environ['DEBUG'] = str(int(variant['util_params']['debug']))

    #setup logger
    run_mode = variant['run_mode']
    exp_log_name = os.path.join(
        variant['env_name'], run_mode,
        variant['log_annotation'] + variant['variant_name'],
        'seed-' + str(variant['seed']))

    setup_logger(exp_log_name,
                 variant=variant,
                 exp_id=None,
                 base_log_dir=os.environ.get('PEARL_DATA_PATH'),
                 snapshot_mode='gap',
                 snapshot_gap=10)

    # run the algorithm
    if run_mode == 'TRAIN':
        algorithm.train()
    elif run_mode == 'EVAL':
        assert variant['algo_params']['dump_eval_paths']
        algorithm._try_to_eval()
    else:
        algorithm.eval_with_loaded_latent()
Example #5
def experiment(variant):
    eval_env = gym.make(variant['env_name'])
    expl_env = eval_env
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    M = variant['layer_size']
    # Q and policy networks
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)

    # initialize with bc or not
    if variant['bc_model'] is None:
        policy = TanhGaussianPolicy(
            obs_dim=obs_dim,
            action_dim=action_dim,
            hidden_sizes=[M, M],
        ).to(ptu.device)
    else:
        bc_model = Mlp(
            input_size=obs_dim,
            output_size=action_dim,
            hidden_sizes=[64, 64],
            output_activation=F.tanh,
        ).to(ptu.device)

        # note: map_location is not defined in this snippet; presumably it is set at
        # module level in the original script (e.g. 'cpu' when no GPU is available)
        checkpoint = torch.load(variant['bc_model'], map_location=map_location)
        bc_model.load_state_dict(checkpoint['network_state_dict'])
        print('Loading bc model: {}'.format(variant['bc_model']))

        # policy initialized with bc
        policy = TanhGaussianPolicy_BC(
            obs_dim=obs_dim,
            action_dim=action_dim,
            mean_network=bc_model,
            hidden_sizes=[M, M],
        ).to(ptu.device)

    # if bonus: define bonus networks
    if not variant['offline']:
        bonus_layer_size = variant['bonus_layer_size']
        bonus_network = Mlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=[bonus_layer_size, bonus_layer_size],
            output_activation=F.sigmoid,
        ).to(ptu.device)

        checkpoint = torch.load(variant['bonus_path'],
                                map_location=map_location)
        bonus_network.load_state_dict(checkpoint['network_state_dict'])
        print('Loading bonus model: {}'.format(variant['bonus_path']))

        if variant['initialize_Q'] and bonus_layer_size == M:
            target_qf1.load_state_dict(checkpoint['network_state_dict'])
            target_qf2.load_state_dict(checkpoint['network_state_dict'])
            print('Initializing target QF1 and QF2 with the bonus model: {}'.format(
                variant['bonus_path']))
        if variant['initialize_Q'] and bonus_layer_size != M:
            print('Size mismatch between Q and bonus networks - turning off Q initialization')

    # eval_policy = MakeDeterministic(policy)
    eval_path_collector = CustomMDPPathCollector(eval_env)
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )
    buffer_filename = None
    if variant['buffer_filename'] is not None:
        buffer_filename = variant['buffer_filename']

    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )

    dataset = eval_env.unwrapped.get_dataset()

    load_hdf5(dataset, replay_buffer, max_size=variant['replay_buffer_size'])

    if variant['normalize']:
        obs_mu, obs_std = dataset['observations'].mean(
            axis=0), dataset['observations'].std(axis=0)
        bonus_norm_param = [obs_mu, obs_std]
    else:
        bonus_norm_param = [None] * 2

    # shift the reward
    if variant['reward_shift'] is not None:
        rewards_shift_param = min(dataset['rewards']) - variant['reward_shift']
        print('Reward shift parameter: {}'.format(rewards_shift_param))
    else:
        rewards_shift_param = None

    if variant['offline']:
        trainer = SACTrainer(env=eval_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             rewards_shift_param=rewards_shift_param,
                             **variant['trainer_kwargs'])
        print('Agent of type offline SAC created')

    elif variant['bonus'] == 'bonus_add':
        trainer = SAC_BonusTrainer(
            env=eval_env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            bonus_network=bonus_network,
            beta=variant['bonus_beta'],
            use_bonus_critic=variant['use_bonus_critic'],
            use_bonus_policy=variant['use_bonus_policy'],
            use_log=variant['use_log'],
            bonus_norm_param=bonus_norm_param,
            rewards_shift_param=rewards_shift_param,
            device=ptu.device,
            **variant['trainer_kwargs'])
        print('Agent of type SAC + additive bonus created')
    elif variant['bonus'] == 'bonus_mlt':
        trainer = SAC_BonusTrainer_Mlt(
            env=eval_env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            bonus_network=bonus_network,
            beta=variant['bonus_beta'],
            use_bonus_critic=variant['use_bonus_critic'],
            use_bonus_policy=variant['use_bonus_policy'],
            bonus_norm_param=bonus_norm_param,
            rewards_shift_param=rewards_shift_param,
            device=ptu.device,
            **variant['trainer_kwargs'])
        print('Agent of type SAC + multiplicative bonus created')

    else:
        raise ValueError('Unknown bonus mode: {}'.format(variant['bonus']))

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        batch_rl=True,
        q_learning_alg=True,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
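
A minimal sketch of the configuration dictionary this offline launcher expects. The key names come from the accesses above; the values are illustrative placeholders.

variant_sketch = {
    'env_name': 'halfcheetah-medium-v0',   # placeholder; any env exposing get_dataset()
    'layer_size': 256,
    'replay_buffer_size': int(1e6),
    'buffer_filename': None,
    'bc_model': None,                      # or a path to a behavior-cloning checkpoint
    'offline': True,                       # False enables the bonus branches above
    'bonus': 'bonus_add',                  # or 'bonus_mlt'
    'bonus_path': None,
    'bonus_layer_size': 256,
    'bonus_beta': 1.0,
    'initialize_Q': False,
    'use_bonus_critic': True,
    'use_bonus_policy': True,
    'use_log': False,
    'normalize': False,
    'reward_shift': None,
    'trainer_kwargs': {},
    'algorithm_kwargs': {},
}
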
Example #6
def main(
        env_name,
        seed,
        deterministic,
        traj_prior,
        start_ft_after,
        ft_steps,
        avoid_freezing_z,
        lr,
        batch_size,
        avoid_loading_critics
):
    config = "configs/{}.json".format(env_name)
    variant = default_config
    if config:
        with open(osp.join(config)) as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)

    exp_name = variant['env_name']
    print("Experiment: {}".format(exp_name))

    env = NormalizedBoxEnv(ENVS[exp_name](**variant['env_params']))
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))

    print("Observation space:")
    print(env.observation_space)
    print(obs_dim)
    print("Action space:")
    print(env.action_space)
    print(action_dim)
    print("-" * 10)

    # instantiate networks
    latent_dim = variant['latent_size']
    reward_dim = 1
    context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim if variant['algo_params']['use_next_obs_in_context'] else obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2 if variant['algo_params']['use_information_bottleneck'] else latent_dim
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    target_qf1 = qf1.copy()
    target_qf2 = qf2.copy()
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(
        latent_dim,
        context_encoder,
        policy,
        **variant['algo_params']
    )

    # deterministic eval
    if deterministic:
        agent = MakeDeterministic(agent)

    # load trained weights (otherwise simulate random policy)
    path_to_exp = "output/{}/pearl_{}".format(env_name, seed-1)
    print("Based on experiment: {}".format(path_to_exp))
    context_encoder.load_state_dict(torch.load(os.path.join(path_to_exp, 'context_encoder.pth')))
    policy.load_state_dict(torch.load(os.path.join(path_to_exp, 'policy.pth')))
    if not avoid_loading_critics:
        qf1.load_state_dict(torch.load(os.path.join(path_to_exp, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path_to_exp, 'qf2.pth')))
        target_qf1.load_state_dict(torch.load(os.path.join(path_to_exp, 'target_qf1.pth')))
        target_qf2.load_state_dict(torch.load(os.path.join(path_to_exp, 'target_qf2.pth')))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'], variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        # ptu.device is set by ptu.set_gpu_mode above
        agent.to(ptu.device)
        policy.to(ptu.device)
        context_encoder.to(ptu.device)
        qf1.to(ptu.device)
        qf2.to(ptu.device)
        target_qf1.to(ptu.device)
        target_qf2.to(ptu.device)

    helper = PEARLFineTuningHelper(
        env=env,
        agent=agent,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,

        num_exp_traj_eval=traj_prior,
        start_fine_tuning=start_ft_after,
        fine_tuning_steps=ft_steps,
        should_freeze_z=(not avoid_freezing_z),

        replay_buffer_size=int(1e6),
        batch_size=batch_size,
        discount=0.99,
        policy_lr=lr,
        qf_lr=lr,
        temp_lr=lr,
        target_entropy=-action_dim,
    )

    helper.fine_tune(variant=variant, seed=seed)
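
main() receives its settings as plain keyword arguments, so the original script presumably wraps it in a CLI (click, argparse, or similar). A hypothetical argparse wrapper, shown only to document the expected parameters; the defaults are illustrative:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('env_name')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--traj_prior', type=int, default=1)
    parser.add_argument('--start_ft_after', type=int, default=0)
    parser.add_argument('--ft_steps', type=int, default=1000)
    parser.add_argument('--avoid_freezing_z', action='store_true')
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--avoid_loading_critics', action='store_true')
    main(**vars(parser.parse_args()))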