Example #1
def experiment(args):
    # NOTE: this snippet assumes the surrounding script provides the usual
    # imports (os, os.path as osp, random, numpy as np, torch) and the
    # library objects used below (get_vec_env, networks, policies, Logger,
    # OnPolicyReplayBuffer, VecOnPolicyCollector, A2C), as well as a
    # module-level `params` dict loaded from the config file.

    device = torch.device(
        "cuda:{}".format(args.device) if args.cuda else "cpu")

    env = get_vec_env(params["env_name"], params["env"], args.vec_env_nums)

    # Seed the environment and every RNG (torch, numpy, python) so runs are
    # reproducible; with CUDA, also make cuDNN deterministic.
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    buffer_param = params['replay_buffer']

    experiment_name = os.path.split(
        os.path.splitext(args.config)[0])[-1] if args.id is None \
        else args.id
    logger = Logger(experiment_name, params['env_name'], args.seed, params,
                    args.log_dir, args.overwrite)

    params['general_setting']['env'] = env

    replay_buffer = OnPolicyReplayBuffer(
        env_nums=args.vec_env_nums,
        max_replay_buffer_size=int(buffer_param['size']),
        time_limit_filter=buffer_param['time_limit_filter'])
    params['general_setting']['replay_buffer'] = replay_buffer

    params['general_setting']['logger'] = logger
    params['general_setting']['device'] = device

    params['net']['base_type'] = networks.MLPBase
    params['net']['activation_func'] = torch.nn.Tanh
    # Policy (actor): Gaussian policy over the continuous action space.
    pf = policies.GuassianContPolicyBasicBias(
        input_shape=env.observation_space.shape[0],
        output_shape=env.action_space.shape[0],
        **params['net'],
        **params['policy'])
    # Value function (critic): maps an observation to a single scalar.
    vf = networks.Net(input_shape=env.observation_space.shape,
                      output_shape=1,
                      **params['net'])
    print(pf)
    print(vf)
    # Vectorized on-policy collector: gathers rollouts into the replay buffer.
    params['general_setting']['collector'] = VecOnPolicyCollector(
        vf,
        env=env,
        pf=pf,
        replay_buffer=replay_buffer,
        device=device,
        train_render=False,
        **params["collector"])
    params['general_setting']['save_dir'] = osp.join(logger.work_dir, "model")
    agent = A2C(pf=pf, vf=vf, **params["a2c"], **params["general_setting"])
    agent.train()
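The experiment function above expects an `args` namespace and a module-level `params` dict prepared by the calling script. A minimal sketch of such an entry point, assuming argparse flags that mirror the attributes accessed above and a YAML config file (the flag names and the YAML format are assumptions, not the library's actual CLI):

import argparse

import yaml  # assumption: the file passed via --config is YAML


def get_cli_args():
    # Illustrative parser; flag names mirror the attributes read in
    # experiment() (args.config, args.seed, args.cuda, ...).
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--id", type=str, default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--vec_env_nums", type=int, default=1)
    parser.add_argument("--log_dir", type=str, default="log")
    parser.add_argument("--overwrite", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_cli_args()
    with open(args.config) as f:
        params = yaml.safe_load(f)  # module-level, so experiment() can see it
    experiment(args)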
Example #2
def experiment(args):
    # import torch.multiprocessing as mp
    # mp.set_start_method('spawn')

    # NOTE: as in Example #1, the surrounding script is assumed to provide
    # the imports (os, osp, random, np, torch, nn) and the library objects
    # used below (get_env, networks, policies, init, Logger,
    # OnPolicyReplayBuffer, OnPlicyCollectorBase, TRPO), plus a module-level
    # `params` dict loaded from the config file.

    device = torch.device(
        "cuda:{}".format(args.device) if args.cuda else "cpu")

    env = get_env(params['env_name'], params['env'])

    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    buffer_param = params['replay_buffer']

    experiment_name = os.path.split(os.path.splitext(args.config)[0])[-1] if args.id is None \
        else args.id
    logger = Logger(experiment_name, params['env_name'], args.seed, params,
                    args.log_dir)

    params['general_setting']['env'] = env

    replay_buffer = OnPolicyReplayBuffer(
        int(buffer_param['size']),
        time_limit_filter=buffer_param['time_limit_filter'])

    params['general_setting']['replay_buffer'] = replay_buffer

    params['general_setting']['logger'] = logger
    params['general_setting']['device'] = device

    params['net']['base_type'] = networks.MLPBase
    params['net']['activation_func'] = nn.Tanh
    # Policy (actor) with explicit orthogonal weight initialization.
    pf = policies.GuassianContPolicyBasicBias(
        input_shape=env.observation_space.shape[0],
        output_shape=env.action_space.shape[0],
        init_func=lambda x: init.orthogonal_init(
            x, scale=np.sqrt(2), constant=0),
        net_last_init_func=lambda x: init.orthogonal_init(
            x, scale=0.01, constant=0),
        **params['net'],
        **params['policy'])
    # Value function (critic), also orthogonally initialized.
    vf = networks.Net(input_shape=env.observation_space.shape,
                      output_shape=1,
                      init_func=lambda x: init.orthogonal_init(
                          x, scale=np.sqrt(2), constant=0),
                      net_last_init_func=lambda x: init.orthogonal_init(
                          x, scale=1, constant=0),
                      **params['net'])
    params['general_setting']['collector'] = OnPlicyCollectorBase(
        vf,
        env=env,
        pf=pf,
        replay_buffer=replay_buffer,
        device=device,
        train_render=False,
        **params["collector"])

    params['general_setting']['save_dir'] = osp.join(logger.work_dir, "model")
    agent = TRPO(pf=pf, vf=vf, **params["trpo"], **params["general_setting"])
    print(params["general_setting"])
    print(agent.epoch_frames)
    agent.train()
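Example #2 differs from Example #1 mainly in passing explicit orthogonal initializers to the policy and value networks. `init.orthogonal_init` comes from the example's own library; a rough stand-in using only standard PyTorch, assuming the helper orthogonalizes a layer's weight with the given scale (gain) and fills its bias with a constant:

import numpy as np
import torch.nn as nn


def orthogonal_init(layer, scale=np.sqrt(2), constant=0.0):
    # Assumed behaviour of the library's init.orthogonal_init:
    # orthogonal weight matrix scaled by `scale`, bias filled with `constant`.
    nn.init.orthogonal_(layer.weight, gain=scale)
    nn.init.constant_(layer.bias, constant)
    return layer


# Mirrors the example's choices: sqrt(2) gain for hidden layers, a small
# 0.01 gain for the policy's final layer, and gain 1 for the value head.
hidden = orthogonal_init(nn.Linear(64, 64))
policy_head = orthogonal_init(nn.Linear(64, 6), scale=0.01)
value_head = orthogonal_init(nn.Linear(64, 1), scale=1.0)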