Beispiel #1
0
def experiment(variant, bcq_buffers, prev_exp_state=None):
    """Train an ensemble of prediction networks on pre-collected buffers.

    Args:
        variant: Experiment configuration. Keys read here: ``'domain'``,
            ``'seed'``, ``'num_network_ensemble'``, ``'P_hidden_sizes'``,
            ``'train_goal'``, ``'std_threshold'`` and ``'algo_params'``
            (the latter is forwarded as keyword arguments to
            ``BatchMetaRLAlgorithm``).
        bcq_buffers: List of replay buffers wrapped into a single
            ``MultiTaskReplayBuffer``.
        prev_exp_state: Optional mapping with an ``'epoch'`` entry. When
            given, training starts at the epoch following it; otherwise
            training starts at epoch 0.
    """
    # Wrap the per-task buffers into one multi-task buffer.
    train_buffer = MultiTaskReplayBuffer(bcq_buffers_list=bcq_buffers)

    # Build the environment for the requested domain and reset it once.
    env = env_producer(variant['domain'], variant['seed'])
    env.reset()

    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    net_input_size = obs_dim + action_dim

    # Each ensemble member maps a flattened (observation, action) pair
    # to a single scalar output.
    network_ensemble = [
        FlattenMlp(
            hidden_sizes=variant['P_hidden_sizes'],
            input_size=net_input_size,
            output_size=1,
        )
        for _ in range(variant['num_network_ensemble'])
    ]

    trainer = SuperQTrainer(
        env,
        network_ensemble=network_ensemble,
        train_goal=variant['train_goal'],
        std_threshold=variant['std_threshold'],
        domain=variant['domain'],
    )

    algorithm = BatchMetaRLAlgorithm(
        trainer,
        train_buffer,
        **variant['algo_params'],
    )
    algorithm.to(ptu.device)

    # Resume right after the last recorded epoch, or start from scratch.
    if prev_exp_state is None:
        start_epoch = 0
    else:
        start_epoch = prev_exp_state['epoch'] + 1

    algorithm.train(start_epoch)
Beispiel #2
0
def experiment(variant):
    """Load per-task replay buffers, build the PEARL networks and train.

    The function reads the goal file for the requested domain/mode, loads
    one remote replay buffer per selected task, instantiates the context
    encoder, Q-functions, value function and policy, optionally restores
    pre-trained weights, configures logging and finally calls
    ``algorithm.train()``.

    Args:
        variant: Experiment configuration. Keys read here:
            ``'domain'``, ``'seed'``, ``'exp_mode'``, ``'bcq_interactions'``,
            ``'num_tasks'``, ``'data_models_root'``, ``'num_trans_context'``,
            ``'in_mdp_batch_size'``, ``'latent_size'``, ``'net_size'``,
            ``'path_to_weights'``, ``'algo_params'`` (including
            ``'max_path_length'``, ``'use_next_obs_in_context'``,
            ``'use_information_bottleneck'``, ``'recurrent'`` and
            ``'dump_eval_paths'``) and ``'util_params'`` (``'use_gpu'``,
            ``'gpu_id'``, ``'debug'``, ``'base_log_dir'``).
    """
    domain = variant['domain']
    seed = variant['seed']
    exp_mode = variant['exp_mode']
    max_path_length = variant['algo_params']['max_path_length']
    bcq_interactions = variant['bcq_interactions']
    num_tasks = variant['num_tasks']

    # NOTE: pickle can execute arbitrary code while loading; the goal file
    # is assumed to be a trusted, locally generated artifact.
    goals_path = f'./goals/{domain}-{exp_mode}-goals.pkl'
    # Use a context manager so the file handle is closed deterministically
    # instead of being leaked until garbage collection.
    with open(goals_path, 'rb') as goals_file:
        idx_list, train_goals, wd_goals, ood_goals = pickle.load(goals_file)
    # Only the first `num_tasks` task indices are used.
    idx_list = idx_list[:num_tasks]

    sub_buffer_dir = f"buffers/{domain}/{exp_mode}/max_path_length_{max_path_length}/interactions_{bcq_interactions}k/seed_{seed}"
    buffer_dir = os.path.join(variant['data_models_root'], sub_buffer_dir)

    print("Buffer directory: " + buffer_dir)

    # Load buffers: create one remote ReplayBuffer per task and start all
    # load calls before waiting on any of them.
    bcq_buffers = []
    buffer_loader_id_list = []
    for i, idx in enumerate(idx_list):
        # Buffer files use a zero-padded, two-digit index for idx < 10.
        bname = f'goal_0{idx}.zip_pkl' if idx < 10 else f'goal_{idx}.zip_pkl'
        buffer_path = os.path.join(buffer_dir, bname)
        rp_buffer = ReplayBuffer.remote(
            index=i,
            seed=seed,
            num_trans_context=variant['num_trans_context'],
            in_mdp_batch_size=variant['in_mdp_batch_size'],
        )

        buffer_loader_id_list.append(
            rp_buffer.load_from_gzip.remote(buffer_path))
        bcq_buffers.append(rp_buffer)
    # Wait for every remote load call to finish.
    ray.get(buffer_loader_id_list)

    assert len(bcq_buffers) == len(idx_list)

    train_buffer = MultiTaskReplayBuffer(bcq_buffers_list=bcq_buffers, )

    set_seed(seed)

    # create multi-task environment and sample tasks
    # NOTE(review): the environment is built with seed=0 rather than the
    # experiment seed — confirm this is intentional.
    env = env_producer(variant['domain'], seed=0)
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    reward_dim = 1

    # instantiate networks
    latent_dim = variant['latent_size']
    # A context entry is (obs, action, reward), plus next_obs when enabled.
    context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim if variant[
        'algo_params'][
            'use_next_obs_in_context'] else obs_dim + action_dim + reward_dim
    # With the information bottleneck the encoder output is twice the
    # latent size.
    context_encoder_output_dim = latent_dim * 2 if variant['algo_params'][
        'use_information_bottleneck'] else latent_dim
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + latent_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(latent_dim, context_encoder, policy,
                       **variant['algo_params'])
    algorithm = PEARLSoftActorCritic(env=env,
                                     train_goals=train_goals,
                                     wd_goals=wd_goals,
                                     ood_goals=ood_goals,
                                     replay_buffers=train_buffer,
                                     nets=[agent, qf1, qf2, vf],
                                     latent_dim=latent_dim,
                                     **variant['algo_params'])

    # optionally load pre-trained weights
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        context_encoder.load_state_dict(
            torch.load(os.path.join(path, 'context_encoder.pth')))
        qf1.load_state_dict(torch.load(os.path.join(path, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path, 'qf2.pth')))
        vf.load_state_dict(torch.load(os.path.join(path, 'vf.pth')))
        # TODO hacky, revisit after model refactor
        # presumably networks[-2] is the target value network, given the
        # checkpoint name — verify against PEARLSoftActorCritic.
        algorithm.networks[-2].load_state_dict(
            torch.load(os.path.join(path, 'target_vf.pth')))
        policy.load_state_dict(torch.load(os.path.join(path, 'policy.pth')))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'],
                     variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        algorithm.to()

    # debugging triggers a lot of printing and logs to a debug directory
    DEBUG = variant['util_params']['debug']
    os.environ['DEBUG'] = str(int(DEBUG))

    # create logging directory
    # TODO support Docker
    exp_id = 'debug' if DEBUG else None
    experiment_log_dir = setup_logger(
        variant['domain'],
        variant=variant,
        exp_id=exp_id,
        base_log_dir=variant['util_params']['base_log_dir'])

    # optionally save eval trajectories as pkl files
    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = experiment_log_dir + '/eval_trajectories'
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    # run the algorithm
    algorithm.train()
Beispiel #3
0
def experiment(variant,
               bcq_policies,
               bcq_buffers,
               ensemble_params_list,
               prev_exp_state=None):
    """Build the context-conditioned networks and run batch meta-RL training.

    Args:
        variant: Experiment configuration. Keys read here: ``'domain'``,
            ``'seed'``, ``'use_next_obs_in_context'``, ``'latent_dim'``,
            ``'Qs_hidden_sizes'``, ``'vae_hidden_sizes'``,
            ``'perturbation_hidden_sizes'``, ``'num_network_ensemble'``,
            ``'std_threshold'``, ``'is_combine'`` and ``'algo_params'``.
            The dict is also updated in place with ``'env_max_action'``,
            ``'obs_dim'``, ``'action_dim'`` and
            ``'mlp_enconder_input_size'`` before it is handed to
            ``RemotePathCollector``.
        bcq_policies: Passed through to ``SuperQTrainer``.
        bcq_buffers: List of replay buffers wrapped into a single
            ``MultiTaskReplayBuffer``.
        ensemble_params_list: Passed to ``EnsemblePredictor``.
        prev_exp_state: Optional mapping with an ``'epoch'`` entry. When
            given, training starts at the epoch following it; otherwise
            training starts at epoch 0.
    """
    # Wrap the per-task buffers into one multi-task buffer.
    train_buffer = MultiTaskReplayBuffer(bcq_buffers_list=bcq_buffers)

    # Build the environment for the requested domain.
    env = env_producer(variant['domain'], variant['seed'])

    env_max_action = float(env.action_space.high[0])
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    vae_latent_dim = 2 * action_dim

    # A context entry is (obs, action, scalar), plus next_obs when enabled.
    transition_dim = obs_dim + action_dim + 1
    if variant['use_next_obs_in_context']:
        encoder_input_size = obs_dim + transition_dim
    else:
        encoder_input_size = transition_dim

    # Record the derived sizes on the variant; it is later passed to the
    # path collector and logged.
    derived_entries = (
        ('env_max_action', env_max_action),
        ('obs_dim', obs_dim),
        ('action_dim', action_dim),
        ('mlp_enconder_input_size', encoder_input_size),
    )
    for key, value in derived_entries:
        variant[key] = value

    latent_dim = variant['latent_dim']
    # Input width shared by the Q-network and the perturbation generator.
    state_action_latent_dim = obs_dim + action_dim + latent_dim

    # Instantiate the networks (construction order kept as-is).
    mlp_encoder = MlpEncoder(
        hidden_sizes=[200, 200, 200],
        input_size=encoder_input_size,
        output_size=2 * latent_dim,
    )
    context_encoder = ProbabilisticContextEncoder(mlp_encoder, latent_dim)

    ensemble_predictor = EnsemblePredictor(ensemble_params_list)

    q_network = FlattenMlp(
        hidden_sizes=variant['Qs_hidden_sizes'],
        input_size=state_action_latent_dim,
        output_size=1,
    )
    vae_decoder = VaeDecoder(
        max_action=env_max_action,
        hidden_sizes=variant['vae_hidden_sizes'],
        input_size=obs_dim + vae_latent_dim + latent_dim,
        output_size=action_dim,
    )
    perturbation_generator = PerturbationGenerator(
        max_action=env_max_action,
        hidden_sizes=variant['perturbation_hidden_sizes'],
        input_size=state_action_latent_dim,
        output_size=action_dim,
    )

    trainable_nets = [
        context_encoder,
        q_network,
        vae_decoder,
        perturbation_generator,
    ]
    trainer = SuperQTrainer(
        ensemble_predictor=ensemble_predictor,
        num_network_ensemble=variant['num_network_ensemble'],
        bcq_policies=bcq_policies,
        std_threshold=variant['std_threshold'],
        is_combine=variant['is_combine'],
        nets=trainable_nets,
    )

    path_collector = RemotePathCollector(variant)

    algorithm = BatchMetaRLAlgorithm(
        trainer,
        path_collector,
        train_buffer,
        **variant['algo_params'],
    )
    algorithm.to(ptu.device)

    # Resume right after the last recorded epoch, or start from scratch.
    if prev_exp_state is None:
        start_epoch = 0
    else:
        start_epoch = prev_exp_state['epoch'] + 1

    # Log the variant
    logger.log("Variant:")
    variant_json = json.dumps(dict_to_safe_json(variant), indent=2)
    logger.log(variant_json)

    algorithm.train(start_epoch)
Beispiel #4
0
        bname = f'goal_0{idx}.zip_pkl' if idx < 10 else f'goal_{idx}.zip_pkl'
        filename = os.path.join(buffer_dir, bname)
        rp_buffer = ReplayBuffer.remote(
            index=i,
            seed=seed,
            num_trans_context=variant['num_trans_context'],
            in_mdp_batch_size=variant['in_mdp_batch_size'],
        )

        buffer_loader_id_list.append(rp_buffer.load_from_gzip.remote(filename))
        bcq_buffers.append(rp_buffer)
    ray.get(buffer_loader_id_list)

    assert len(bcq_buffers) == len(idx_list)

    train_buffer = MultiTaskReplayBuffer(bcq_buffers_list=bcq_buffers, )
    path_collector = RemotePathCollector(variant)

    # Initialize policy
    policy = BCQ(state_dim, action_dim, max_action, **variant['policy_params'])

    algorithm = BatchMetaRLAlgorithm(policy, path_collector, train_buffer,
                                     **variant['algo_params'])

    algorithm.to(ptu.device)

    # Log the variant
    logger.log("Variant:")
    logger.log(json.dumps(dict_to_safe_json(variant), indent=2))

    algorithm.train()