Exemplo n.º 1
0
def train(params_dict: dict):
    """Train a baselines learner on a vectorised (by default Atari) environment.

    Args:
        params_dict: Nested configuration dictionary. Keys read here:
            ``env_params['env']`` -- environment id;
            ``env_params['frame_stack']`` -- optional, truthy enables the
            ``frame_stack`` env wrapper option;
            ``env_params['env_type']`` -- optional, defaults to ``'atari'``;
            ``training_params['num_timesteps']`` -- step budget, a number or a
            numeric string such as ``'1e6'``;
            ``model_params['alg']`` / ``model_params['network']`` -- learner
            name and network passed to it;
            ``model_params['lr']`` -- optional, defaults to ``0.0001``.

    Returns:
        ``(model, env)``: the value returned by the algorithm's ``learn``
        function and the vectorised environment it was trained on, so the
        caller can reuse or close it.
    """
    env_params = params_dict['env_params']
    model_params = params_dict['model_params']
    env_type = env_params.get('env_type', 'atari')

    # One environment copy per CPU core.
    ncpu = multiprocessing.cpu_count()
    env_id = env_params['env']

    # The step budget is a count: a float (e.g. 1e6) breaks range()/integer
    # division loops in typical learners, so cast to int. Going through
    # float() first keeps accepting config values written as '1e6'.
    total_timesteps = int(float(params_dict['training_params']['num_timesteps']))

    alg = model_params['alg']
    learn = get_learn_function(alg)
    alg_kwargs = get_learn_function_defaults(alg, env_type)
    alg_kwargs['network'] = model_params['network']
    alg_kwargs['lr'] = model_params.get('lr', 0.0001)

    # Only request frame stacking when the config explicitly enables it.
    wrapper_kwargs = {'frame_stack': True} if env_params.get('frame_stack') else {}

    env = make_vec_env(env_id,
                       env_type,
                       ncpu,
                       seed=None,
                       wrapper_kwargs=wrapper_kwargs)

    model = learn(env=env,
                  seed=None,
                  total_timesteps=total_timesteps,
                  **alg_kwargs)

    return model, env
Exemplo n.º 2
0
def train(args, extra_args):
    """Train a GAIL/BGAIL agent and return the trained model with its env.

    Args:
        args: Parsed command-line namespace. Reads ``env``, ``alg``, ``seed``,
            ``save_path``, ``load_path`` and ``render``. ``args.alg`` is
            rewritten in place from ``'gail'`` to ``'bgail'``.
        extra_args: Keyword arguments merged over the algorithm defaults and
            forwarded to ``learn``. Must contain ``data_subsample_freq``,
            ``num_expert_trajs`` and ``timesteps_per_batch``, which also name
            the log directory.

    Returns:
        ``(model, env)``: the value returned by ``learn`` and the environment
        created by ``build_env``.

    Raises:
        NotImplementedError: if ``args.alg`` is neither ``'gail'`` nor
            ``'bgail'``.
    """
    # Fail fast, before any environment or logger setup happens.
    if args.alg not in ('gail', 'bgail'):
        raise NotImplementedError(
            "train() supports only 'gail' and 'bgail', got {!r}".format(args.alg))

    env_type, env_id = run.get_env_type(args.env)
    if args.alg == 'gail':
        # 'gail' runs through the 'bgail' learner, with the env type suffixed
        # by '_gail' so the GAIL-specific defaults are selected.
        env_type += '_gail'
        args.alg = 'bgail'

    learn = run.get_learn_function(args.alg)
    alg_kwargs = run.get_learn_function_defaults(args.alg, env_type)
    alg_kwargs.update(extra_args)

    # Build the log path up front: a missing extra_args key then raises
    # KeyError before an environment has been created.
    log_dir = os.path.join(
        "log", "GAIL", args.env,
        "subsample_{}".format(extra_args["data_subsample_freq"]),
        "traj_{}".format(extra_args["num_expert_trajs"]),
        "batch_size_{}".format(extra_args["timesteps_per_batch"]),
        "seed_{}".format(args.seed))

    env = build_env(args)
    logger.configure(log_dir)

    print('Training {} on {}:{} with arguments \n{}'.format(
        args.alg, env_type, env_id, alg_kwargs))

    model = learn(env=env,
                  seed=args.seed,
                  save_path=args.save_path,
                  load_path=args.load_path,
                  render=args.render,
                  **alg_kwargs)

    # Previously the trained model was computed and then discarded.
    return model, env
Exemplo n.º 3
0
Arquivo: run.py Projeto: zwc662/BGAIL
def train(args, extra_args):
    """Train a GAIL/BGAIL agent and return the trained model with its env.

    Args:
        args: Parsed command-line namespace. Reads ``env``, ``alg``, ``seed``,
            ``save_path``, ``load_path`` and ``render``. ``args.alg`` is
            rewritten in place from ``'gail'`` to ``'bgail'``.
        extra_args: Keyword arguments merged over the algorithm defaults and
            forwarded to ``learn``.

    Returns:
        ``(model, env)``: the value returned by ``learn`` and the environment
        created by ``build_env``.

    Raises:
        NotImplementedError: if ``args.alg`` is neither ``'gail'`` nor
            ``'bgail'``.
    """
    # Fail fast, before any environment setup happens.
    if args.alg not in ('gail', 'bgail'):
        raise NotImplementedError(
            "train() supports only 'gail' and 'bgail', got {!r}".format(args.alg))

    env_type, env_id = run.get_env_type(args.env)
    if args.alg == 'gail':
        # 'gail' runs through the 'bgail' learner, with the env type suffixed
        # by '_gail' so the GAIL-specific defaults are selected.
        env_type += '_gail'
        args.alg = 'bgail'

    learn = run.get_learn_function(args.alg)
    alg_kwargs = run.get_learn_function_defaults(args.alg, env_type)
    alg_kwargs.update(extra_args)

    env = build_env(args)

    print('Training {} on {}:{} with arguments \n{}'.format(
        args.alg, env_type, env_id, alg_kwargs))

    model = learn(env=env,
                  seed=args.seed,
                  save_path=args.save_path,
                  load_path=args.load_path,
                  render=args.render,
                  **alg_kwargs)

    # Previously the trained model was computed and then discarded.
    return model, env
Exemplo n.º 4
0
    'noptepochs': 10,
    'save_interval': 20,
    'log_interval': 1,
    'save_path': save_path,
    'model_load_path': model_load_path,
    'seed': 0,
    'reward_scale': 1,
    'flatten_dict_observations': True,
    'transfer_weights': False
}
args = SimpleNamespace(**args_dict)

# Prepare the environment and learning algorithm
env_type, env_id = get_env_type(args.env)
learn = get_learn_function(args.alg)
alg_kwargs = get_learn_function_defaults(args.alg, env_type)
env = build_env(args)
alg_kwargs['network'] = args.network

# Directory holding this experiment's results: <save_path>/<env>-<alg>
full_path = os.path.join(args.save_path, args.env + '-' + args.alg)

# Create the experiment folder and its 'checkpoints' sub-folder in one call.
# exist_ok=True makes this idempotent. Unlike the previous exists() check, it
# still creates 'checkpoints' when the experiment folder already exists from
# an earlier run, and it cannot race with another process creating the folder.
os.makedirs(os.path.join(full_path, 'checkpoints'), exist_ok=True)

print("About to start learning model")

model = learn(env=env,
              seed=args.seed,