Example #1
File: sac.py  Project: sumitsk/oyster
def experiment(variant):
    env = NormalizedBoxEnv(PointEnv(**variant['task_params']))
    ptu.set_gpu_mode(variant['use_gpu'], variant['gpu_id'])

    tasks = env.get_all_task_idx()

    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    latent_dim = 5
    task_enc_output_dim = latent_dim * 2 if variant['algo_params']['use_information_bottleneck'] else latent_dim
    reward_dim = 1

    net_size = variant['net_size']
    # start with linear task encoding
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder
    task_enc = encoder_model(
            hidden_sizes=[200, 200, 200],  # a deeper net and a higher-dimensional latent space generalize better
            input_size=obs_dim + action_dim + reward_dim,
            output_size=task_enc_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + latent_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = ProtoAgent(
        latent_dim,
        [task_enc, policy, qf1, qf2, vf],
        **variant['algo_params']
    )

    algorithm = ProtoSoftActorCritic(
        env=env,
        train_tasks=list(tasks[:-20]),
        eval_tasks=list(tasks[-20:]),
        nets=[agent, task_enc, policy, qf1, qf2, vf],
        latent_dim=latent_dim,
        **variant['algo_params']
    )
    if ptu.gpu_enabled():
        algorithm.to()
    algorithm.train()
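For orientation, here is a minimal sketch of the kind of `variant` dict this launcher reads. Only the key names mirror the lookups in `experiment()` above; every value (and the `n_tasks` field inside `task_params`) is purely illustrative.

# Hypothetical configuration; key names taken from the code above, values are made up.
variant = dict(
    task_params=dict(n_tasks=100),        # forwarded to PointEnv(**task_params); field name is an assumption
    use_gpu=True,
    gpu_id=0,
    net_size=300,
    algo_params=dict(
        use_information_bottleneck=True,  # doubles the task-encoder output size
        recurrent=False,                  # False -> MlpEncoder, True -> RecurrentEncoder
        # ...remaining SAC / meta-training hyperparameters expected by ProtoSoftActorCritic...
    ),
)
experiment(variant)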
Example #2
def experiment(variant):

    # create multi-task environment and sample tasks
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    tasks = env.get_all_task_idx()
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    reward_dim = 1

    # instantiate networks
    latent_dim = variant['latent_size']
    context_encoder_input_dim = (2 * obs_dim + action_dim + reward_dim
                                 if variant['algo_params']['use_next_obs_in_context']
                                 else obs_dim + action_dim + reward_dim)
    context_encoder_output_dim = (latent_dim * 2
                                  if variant['algo_params']['use_information_bottleneck']
                                  else latent_dim)
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + latent_dim,
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(latent_dim, context_encoder, policy,
                       **variant['algo_params'])
    algorithm = PEARLSoftActorCritic(
        env=env,
        train_tasks=list(tasks[:variant['n_train_tasks']]),
        eval_tasks=list(tasks[-variant['n_eval_tasks']:]),
        nets=[agent, qf1, qf2, vf],
        latent_dim=latent_dim,
        **variant['algo_params'])

    # optionally load pre-trained weights
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        context_encoder.load_state_dict(
            torch.load(os.path.join(path, 'context_encoder.pth')))
        qf1.load_state_dict(torch.load(os.path.join(path, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path, 'qf2.pth')))
        vf.load_state_dict(torch.load(os.path.join(path, 'vf.pth')))
        # TODO hacky, revisit after model refactor
        algorithm.networks[-2].load_state_dict(
            torch.load(os.path.join(path, 'target_vf.pth')))
        policy.load_state_dict(torch.load(os.path.join(path, 'policy.pth')))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'],
                     variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        algorithm.to()

    # debugging triggers a lot of printing and logs to a debug directory
    DEBUG = variant['util_params']['debug']
    os.environ['DEBUG'] = str(int(DEBUG))

    # create logging directory
    # TODO support Docker
    exp_id = 'debug' if DEBUG else None
    experiment_log_dir = setup_logger(
        variant['env_name'],
        variant=variant,
        exp_id=exp_id,
        base_log_dir=variant['util_params']['base_log_dir'])

    # optionally save eval trajectories as pkl files
    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = experiment_log_dir + '/eval_trajectories'
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    # run the algorithm
    algorithm.train()
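As a quick sanity check on the dimension bookkeeping above, the same two expressions evaluated with made-up sizes (the numbers below do not come from any real environment):

# Made-up sizes for illustration: obs_dim=20, action_dim=6, reward_dim=1, latent_dim=5.
obs_dim, action_dim, reward_dim, latent_dim = 20, 6, 1, 5
use_next_obs_in_context = True
use_information_bottleneck = True

encoder_in = (2 * obs_dim if use_next_obs_in_context else obs_dim) + action_dim + reward_dim
encoder_out = latent_dim * 2 if use_information_bottleneck else latent_dim
print(encoder_in, encoder_out)  # 47 10: (s, a, r, s') tuples in, Gaussian parameters over z out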
Example #3
def experiment(variant, seed=None):

    # create multi-task environment and sample tasks, normalize obs if provided with 'normalizer.npz'
    if 'normalizer.npz' in os.listdir(variant['algo_params']['data_dir']):
        obs_absmax = np.load(os.path.join(variant['algo_params']['data_dir'], 'normalizer.npz'))['abs_max']
        env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']), obs_absmax=obs_absmax)
    else:
        env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    
    if seed is not None:
        global_seed(seed)
        env.seed(seed)

    tasks = env.get_all_task_idx()
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    reward_dim = 1

    # instantiate networks
    latent_dim = variant['latent_size']
    context_encoder_input_dim = 2 * obs_dim + action_dim + reward_dim if variant['algo_params']['use_next_obs_in_context'] else obs_dim + action_dim + reward_dim
    context_encoder_output_dim = latent_dim * 2 if variant['algo_params']['use_information_bottleneck'] else latent_dim
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
        output_activation=torch.tanh,
    )
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + latent_dim,
        output_size=1,
    )

    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=action_dim,
    )

    agent = PEARLAgent(
        latent_dim,
        context_encoder,
        policy,
        **variant['algo_params']
    )
    if variant['algo_type'] == 'FOCAL':
        # critic network for divergence in dual form (see the BRAC paper, https://arxiv.org/abs/1911.11361)
        c = FlattenMlp(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=obs_dim + action_dim + latent_dim,
            output_size=1,
        )
        # pick the train/eval split: either a random partition of the task indices
        # or the deterministic head/tail slices
        if variant.get('randomize_tasks', False):
            rng = default_rng()
            train_tasks = list(rng.choice(len(tasks), size=variant['n_train_tasks'], replace=False))
            eval_tasks = list(set(range(len(tasks))).difference(train_tasks))
        else:
            train_tasks = list(tasks[:variant['n_train_tasks']])
            eval_tasks = list(tasks[-variant['n_eval_tasks']:])

        algo_kwargs = dict(
            env=env,
            train_tasks=train_tasks,
            eval_tasks=eval_tasks,
            nets=[agent, qf1, qf2, vf, c],
            latent_dim=latent_dim,
            **variant['algo_params']
        )
        # pass goal_radius through only when the environment defines one
        if 'goal_radius' in variant['env_params']:
            algo_kwargs['goal_radius'] = variant['env_params']['goal_radius']
        algorithm = FOCALSoftActorCritic(**algo_kwargs)
    else:
        raise NotImplementedError("unsupported algo_type: {}".format(variant['algo_type']))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'], variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        algorithm.to()

    # debugging triggers a lot of printing and logs to a debug directory
    DEBUG = variant['util_params']['debug']
    os.environ['DEBUG'] = str(int(DEBUG))

    # create logging directory
    # TODO support Docker
    exp_id = 'debug' if DEBUG else None
    experiment_log_dir = setup_logger(
        variant['env_name'],
        variant=variant,
        exp_id=exp_id,
        base_log_dir=variant['util_params']['base_log_dir'],
        seed=seed,
        snapshot_mode="all"
    )

    # optionally save eval trajectories as pkl files
    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = experiment_log_dir + '/eval_trajectories'
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    # run the algorithm
    algorithm.train()
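The branch at the top of this example expects a normalizer.npz containing an 'abs_max' array in the data directory. Below is a sketch of how such a file could be produced from logged observations; only the file name and the 'abs_max' key come from the loader above, while the helper name and the per-dimension-maximum choice are assumptions.

import os
import numpy as np

def save_obs_normalizer(data_dir, observations):
    """Hypothetical helper: write normalizer.npz with the per-dimension
    absolute maximum of an (N, obs_dim) array of logged observations."""
    abs_max = np.abs(observations).max(axis=0)
    np.savez(os.path.join(data_dir, 'normalizer.npz'), abs_max=abs_max)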
Example #4
def experiment(variant):
    print(variant['env_name'])
    print(variant['env_params'])
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    tasks = env.get_all_task_idx()

    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))

    cont_latent_dim, num_cat, latent_dim, num_dir, dir_latent_dim = read_dim(variant['global_latent'])
    r_cont_dim, r_n_cat, r_cat_dim, r_n_dir, r_dir_dim = read_dim(variant['vrnn_latent'])
    reward_dim = 1
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    glob = variant['algo_params']['glob']
    rnn = variant['rnn']
    vrnn_latent = variant['vrnn_latent']
    encoder_model = MlpEncoder
    if recurrent:
        if variant['vrnn_constraint'] == 'logitnormal':
            output_size = r_cont_dim * 2 + r_n_cat * r_cat_dim + r_n_dir * r_dir_dim * 2
        else:
            output_size = r_cont_dim * 2 + r_n_cat * r_cat_dim + r_n_dir * r_dir_dim
        if variant['rnn_sample'] == 'batch_sampling':
            if variant['algo_params']['use_next_obs']:
                input_size = (2 * obs_dim + action_dim + reward_dim) * variant['temp_res']
            else:
                input_size = (obs_dim + action_dim + reward_dim) * variant['temp_res']
        else:
            if variant['algo_params']['use_next_obs']:
                input_size = (2 * obs_dim + action_dim + reward_dim)
            else:
                input_size = (obs_dim + action_dim + reward_dim)
        if rnn == 'rnn':
            recurrent_model = RecurrentEncoder
            recurrent_context_encoder = recurrent_model(
                hidden_sizes=[net_size, net_size, net_size],
                input_size=input_size,
                output_size = output_size
            )
        elif rnn == 'vrnn':
            recurrent_model = VRNNEncoder
            recurrent_context_encoder = recurrent_model(
                hidden_sizes=[net_size, net_size, net_size],
                input_size=input_size,
                output_size=output_size, 
                temperature=variant['temperature'],
                vrnn_latent=variant['vrnn_latent'],
                vrnn_constraint=variant['vrnn_constraint'],
                r_alpha=variant['vrnn_alpha'],
                r_var=variant['vrnn_var'],
            )

    else:
        recurrent_context_encoder = None

    ptu.set_gpu_mode(variant['util_params']['use_gpu'], variant['util_params']['gpu_id'])
    if glob:
        if dir_latent_dim > 0 and variant['constraint'] == 'logitnormal':
            output_size = cont_latent_dim * 2 + num_cat * latent_dim + num_dir * dir_latent_dim * 2
        else:
            output_size = cont_latent_dim * 2 + num_cat * latent_dim + num_dir * dir_latent_dim
        if variant['algo_params']['use_next_obs']:
            input_size = 2 * obs_dim + action_dim + reward_dim
        else:
            input_size = obs_dim + action_dim + reward_dim
        global_context_encoder = encoder_model(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=input_size,
            output_size=output_size, 
        )
    else:
        global_context_encoder = None      
    qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim*num_cat + cont_latent_dim + dir_latent_dim*num_dir \
                        + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim, 
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim*num_cat + cont_latent_dim + dir_latent_dim*num_dir \
                        + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim,  
        output_size=1,
    )
    target_qf1 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim*num_cat + cont_latent_dim + dir_latent_dim*num_dir \
                        + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim,  
        output_size=1,
    )
    target_qf2 = FlattenMlp(
        hidden_sizes=[net_size, net_size, net_size],
        input_size=obs_dim + action_dim + latent_dim*num_cat + cont_latent_dim + dir_latent_dim*num_dir \
                        + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim, 
        output_size=1,
    )
    policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size],
        obs_dim=obs_dim + latent_dim*num_cat + cont_latent_dim + dir_latent_dim*num_dir \
                        + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim, 
        latent_dim=latent_dim*num_cat + cont_latent_dim + dir_latent_dim*num_dir \
                        + r_n_cat * r_cat_dim + r_cont_dim + r_n_dir * r_dir_dim,
        action_dim=action_dim,
    )
    agent = PEARLAgent(
        global_context_encoder,
        recurrent_context_encoder,
        variant['global_latent'],
        variant['vrnn_latent'],
        policy,
        variant['temperature'],
        variant['unitkl'],
        variant['alpha'],
        variant['constraint'],
        variant['vrnn_constraint'],
        variant['var'],
        variant['vrnn_alpha'],
        variant['vrnn_var'],
        rnn,
        variant['temp_res'],
        variant['rnn_sample'],
        variant['weighted_sample'],
        **variant['algo_params']
    )
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        with open(os.path.join(path, 'extra_data.pkl'), 'rb') as f:
            extra_data = pickle.load(f)
            variant['algo_params']['start_epoch'] = extra_data['epoch'] + 1
            replay_buffer = extra_data['replay_buffer']
            enc_replay_buffer = extra_data['enc_replay_buffer']
            variant['algo_params']['_n_train_steps_total'] = extra_data['_n_train_steps_total']
            variant['algo_params']['_n_env_steps_total'] = extra_data['_n_env_steps_total']
            variant['algo_params']['_n_rollouts_total'] = extra_data['_n_rollouts_total']
    else:
        replay_buffer = None
        enc_replay_buffer = None

    algorithm = PEARLSoftActorCritic(
        env=env,
        train_tasks=list(tasks[:variant['n_train_tasks']]),
        eval_tasks=list(tasks[-variant['n_eval_tasks']:]),
        nets=[agent, qf1, qf2, target_qf1, target_qf2],
        latent_dim=latent_dim,
        replay_buffer=replay_buffer,
        enc_replay_buffer=enc_replay_buffer,
        temp_res=variant['temp_res'],
        rnn_sample=variant['rnn_sample'],
        **variant['algo_params']
    )

    if variant['path_to_weights'] is not None: 
        path = variant['path_to_weights']
        if recurrent_context_encoder is not None:
            recurrent_context_encoder.load_state_dict(torch.load(os.path.join(path, 'recurrent_context_encoder.pth')))
        if global_context_encoder is not None:
            global_context_encoder.load_state_dict(torch.load(os.path.join(path, 'global_context_encoder.pth')))
        qf1.load_state_dict(torch.load(os.path.join(path, 'qf1.pth')))
        qf2.load_state_dict(torch.load(os.path.join(path, 'qf2.pth')))
        target_qf1.load_state_dict(torch.load(os.path.join(path, 'target_qf1.pth')))
        target_qf2.load_state_dict(torch.load(os.path.join(path, 'target_qf2.pth')))
        policy.load_state_dict(torch.load(os.path.join(path, 'policy.pth')))

    if ptu.gpu_enabled():
        algorithm.to()

    DEBUG = variant['util_params']['debug']
    os.environ['DEBUG'] = str(int(DEBUG))
    exp_id = 'debug' if DEBUG else None
    if variant.get('log_name', "") == "":
        log_name = variant['env_name']
    else:
        log_name = variant['log_name']
    experiment_log_dir = setup_logger(log_name,
                                      variant=variant,
                                      exp_id=exp_id,
                                      base_log_dir=variant['util_params']['base_log_dir'],
                                      config_log_dir=variant['util_params']['config_log_dir'],
                                      log_dir=variant['util_params']['log_dir'])
    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = experiment_log_dir + '/eval_trajectories'
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    env.save_all_tasks(experiment_log_dir)

    if variant['eval']:
        algorithm._try_to_eval(0, eval_all=True, eval_train_offline=False, animated=True)
    else:
        algorithm.train()
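The resume path in this example unpickles extra_data.pkl with a fixed set of keys. Here is a hedged sketch of the matching save side; the key names come from the loader above, while the function name and the attributes read from the algorithm object are assumptions.

import os
import pickle

def save_extra_data(path, algorithm, epoch):
    # Hypothetical counterpart to the resume logic above; attribute names are assumed.
    extra_data = dict(
        epoch=epoch,
        replay_buffer=algorithm.replay_buffer,
        enc_replay_buffer=algorithm.enc_replay_buffer,
        _n_train_steps_total=algorithm._n_train_steps_total,
        _n_env_steps_total=algorithm._n_env_steps_total,
        _n_rollouts_total=algorithm._n_rollouts_total,
    )
    with open(os.path.join(path, 'extra_data.pkl'), 'wb') as f:
        pickle.dump(extra_data, f)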
Example #5
def experiment(variant):

    # create multi-task environment and sample tasks
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    tasks = env.get_all_task_idx()
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    reward_dim = 1

    # instantiate networks
    latent_dim = variant['latent_size']
    context_encoder_input_dim = (2 * obs_dim + action_dim + reward_dim
                                 if variant['algo_params']['use_next_obs_in_context']
                                 else obs_dim + action_dim + reward_dim)
    context_encoder_output_dim = (latent_dim * 2
                                  if variant['algo_params']['use_information_bottleneck']
                                  else latent_dim)
    net_size = variant['net_size']
    recurrent = variant['algo_params']['recurrent']
    encoder_model = RecurrentEncoder if recurrent else MlpEncoder

    context_encoder = encoder_model(
        hidden_sizes=[200, 200, 200],
        input_size=context_encoder_input_dim,
        output_size=context_encoder_output_dim,
    )

    # low-level Q-functions first, then high-level Q-functions
    q_list = [
        [
            FlattenMlp(
                hidden_sizes=[net_size, net_size, net_size],
                input_size=2 * obs_dim + action_dim,
                output_size=1,
            ),
            FlattenMlp(
                hidden_sizes=[net_size, net_size, net_size],
                input_size=2 * obs_dim + action_dim,
                output_size=1,
            ),
        ],
        [
            FlattenMlp(
                hidden_sizes=[net_size, net_size, net_size],
                input_size=obs_dim + action_dim + latent_dim,
                output_size=1,
            ),
            FlattenMlp(
                hidden_sizes=[net_size, net_size, net_size],
                input_size=obs_dim + action_dim + latent_dim,
                output_size=1,
            ),
        ],
    ]
    # low-level value function first, then high-level value function
    vf_list = [
        FlattenMlp(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=2 * obs_dim,
            output_size=1,
        ),
        FlattenMlp(
            hidden_sizes=[net_size, net_size, net_size],
            input_size=obs_dim + latent_dim,
            output_size=1,
        )
    ]

    # NOTE: reduced the number of hidden layers in h_policy from 3 to 2 (it has less to learn than the full policy in PEARL)
    h_policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size],
        obs_dim=obs_dim + latent_dim,
        latent_dim=latent_dim,
        action_dim=obs_dim,
    )
    # NOTE: kept the low-level policy deep, since it will see far more data
    l_policy = TanhGaussianPolicy(
        hidden_sizes=[net_size, net_size, net_size, net_size],
        obs_dim=2 * obs_dim,
        latent_dim=0,
        action_dim=action_dim,
    )
    #TODO Implement BernAgent
    agent = BURNAgent(latent_dim,
                      context_encoder,
                      h_policy,
                      l_policy,
                      c=2,
                      **variant['algo_params'])
    algorithm = BURNSoftActorCritic(
        env=env,
        train_tasks=list(tasks[:variant['n_train_tasks']]),
        eval_tasks=list(tasks[-variant['n_eval_tasks']:]),
        nets=[agent, q_list, vf_list],
        latent_dim=latent_dim,
        **variant['algo_params'])

    # optionally load pre-trained weights
    #TODO Make sure weights are properly saved
    if variant['path_to_weights'] is not None:
        path = variant['path_to_weights']
        context_encoder.load_state_dict(
            torch.load(os.path.join(path, 'context_encoder.pth')))
        q_list[0][0].load_state_dict(
            torch.load(os.path.join(path, 'l_qf1.pth')))
        q_list[0][1].load_state_dict(
            torch.load(os.path.join(path, 'l_qf2.pth')))
        q_list[1][0].load_state_dict(
            torch.load(os.path.join(path, 'h_qf1.pth')))
        q_list[1][1].load_state_dict(
            torch.load(os.path.join(path, 'h_qf2.pth')))
        vf_list[0].load_state_dict(torch.load(os.path.join(path, 'l_vf.pth')))
        vf_list[1].load_state_dict(torch.load(os.path.join(path, 'h_vf.pth')))
        # TODO hacky, revisit after model refactor
        algorithm.networks[-2].load_state_dict(
            torch.load(os.path.join(path, 'target_vf.pth')))
        h_policy.load_state_dict(torch.load(os.path.join(path,
                                                         'h_policy.pth')))
        l_policy.load_state_dict(torch.load(os.path.join(path,
                                                         'l_policy.pth')))

    # optional GPU mode
    ptu.set_gpu_mode(variant['util_params']['use_gpu'],
                     variant['util_params']['gpu_id'])
    if ptu.gpu_enabled():
        algorithm.to()

    # debugging triggers a lot of printing and logs to a debug directory
    DEBUG = variant['util_params']['debug']
    os.environ['DEBUG'] = str(int(DEBUG))

    # create logging directory
    # TODO support Docker
    exp_id = 'debug' if DEBUG else None
    experiment_log_dir = setup_logger(
        variant['env_name'],
        variant=variant,
        exp_id=exp_id,
        base_log_dir=variant['util_params']['base_log_dir'])

    # optionally save eval trajectories as pkl files
    if variant['algo_params']['dump_eval_paths']:
        pickle_dir = experiment_log_dir + '/eval_trajectories'
        pathlib.Path(pickle_dir).mkdir(parents=True, exist_ok=True)

    # run the algorithm
    algorithm.train()
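The weight-loading block in this example expects one .pth file per network. Below is a sketch of a matching save routine; the file names are taken from the loader above, while the function name and how target_vf is obtained are illustrative.

import os
import torch

def save_networks(path, context_encoder, q_list, vf_list, h_policy, l_policy, target_vf):
    # Hypothetical save routine mirroring the file names the loader above expects.
    torch.save(context_encoder.state_dict(), os.path.join(path, 'context_encoder.pth'))
    torch.save(q_list[0][0].state_dict(), os.path.join(path, 'l_qf1.pth'))
    torch.save(q_list[0][1].state_dict(), os.path.join(path, 'l_qf2.pth'))
    torch.save(q_list[1][0].state_dict(), os.path.join(path, 'h_qf1.pth'))
    torch.save(q_list[1][1].state_dict(), os.path.join(path, 'h_qf2.pth'))
    torch.save(vf_list[0].state_dict(), os.path.join(path, 'l_vf.pth'))
    torch.save(vf_list[1].state_dict(), os.path.join(path, 'h_vf.pth'))
    torch.save(target_vf.state_dict(), os.path.join(path, 'target_vf.pth'))
    torch.save(h_policy.state_dict(), os.path.join(path, 'h_policy.pth'))
    torch.save(l_policy.state_dict(), os.path.join(path, 'l_policy.pth'))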