Example 1
def experiment(args):

    device = torch.device(
        "cuda:{}".format(args.device) if args.cuda else "cpu")

    env = get_env(params['env_name'], params['env'])

    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.cuda:
        torch.backends.cudnn.deterministic = True

    buffer_param = params['replay_buffer']
    replay_buffer = BaseReplayBuffer(int(buffer_param['size']))

    experiment_name = os.path.split(os.path.splitext(args.config)[0])[-1] if args.id is None \
        else args.id
    logger = Logger(experiment_name, params['env_name'], args.seed, params,
                    args.log_dir)

    params['general_setting']['env'] = env
    params['general_setting']['replay_buffer'] = replay_buffer
    params['general_setting']['logger'] = logger
    params['general_setting']['device'] = device

    params['net']['base_type'] = networks.MLPBase
    # agent = get_agent( params )
    # print(env)
    # params['general_setting']['collector'] = BaseCollector(
    #     env, pf, replay_buffer
    # )

    pf = policies.GuassianContPolicy(
        input_shape=env.observation_space.shape[0],
        output_shape=2 * env.action_space.shape[0],
        **params['net'])
    vf = networks.Net(input_shape=env.observation_space.shape[0],
                      output_shape=1,
                      **params['net'])
    qf = networks.FlattenNet(input_shape=env.observation_space.shape[0] +
                             env.action_space.shape[0],
                             output_shape=1,
                             **params['net'])
    pretrain_pf = policies.UniformPolicyContinuous(env.action_space.shape[0])

    params['general_setting']['collector'] = BaseCollector(env,
                                                           pf,
                                                           replay_buffer,
                                                           device=device)
    params['general_setting']['save_dir'] = osp.join(logger.work_dir, "model")
    agent = SAC(pf=pf,
                vf=vf,
                qf=qf,
                pretrain_pf=pretrain_pf,
                **params['sac'],
                **params['general_setting'])
    agent.train()
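
Each of these experiment(args) functions reads a module-level params dict and a handful of command-line attributes (args.config, args.id, args.seed, args.device, args.cuda, args.log_dir). A minimal driver sketch is shown below; the flag names mirror those attributes, but the JSON config loading and the defaults are assumptions rather than the repository's actual entry point.

# Hypothetical driver, assumed to live in the same module as experiment() so
# that the module-level `params` dict it populates is the one the function reads.
import argparse
import json


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True,
                        help='path to the experiment configuration file')
    parser.add_argument('--id', type=str, default=None,
                        help='experiment name (defaults to the config file name)')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=int, default=0,
                        help='CUDA device index, used when --cuda is set')
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--log_dir', type=str, default='./log')
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    with open(args.config) as f:
        params = json.load(f)  # assumed JSON; the real project may use YAML
    experiment(args)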
Example 2
def experiment(args):

    device = torch.device("cuda:{}".format(args.device) if args.cuda else "cpu")

    env = get_env(params['env_name'], params['env'])

    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.cuda:
        torch.backends.cudnn.deterministic = True

    buffer_param = params['replay_buffer']

    experiment_name = os.path.split(os.path.splitext(args.config)[0])[-1] if args.id is None \
        else args.id
    logger = Logger(experiment_name, params['env_name'], args.seed, params,
                    args.log_dir)

    params['general_setting']['env'] = env
    params['general_setting']['logger'] = logger
    params['general_setting']['device'] = device

    params['net']['base_type'] = networks.MLPBase

    import torch.multiprocessing as mp
    mp.set_start_method('spawn')

    pf = policies.GuassianContPolicy(
        input_shape=env.observation_space.shape[0],
        output_shape=2 * env.action_space.shape[0],
        **params['net'])
    vf = networks.Net(
        input_shape=env.observation_space.shape[0],
        output_shape=1,
        **params['net'])
    qf = networks.FlattenNet(
        input_shape=env.observation_space.shape[0] + env.action_space.shape[0],
        output_shape=1,
        **params['net'])
    # pretrain_pf = policies.UniformPolicyContinuous(env.action_space.shape[0])
    
    example_ob = env.reset()
    example_dict = { 
        "obs": example_ob,
        "next_obs": example_ob,
        "acts": env.action_space.sample(),
        "rewards": [0],
        "terminals": [False]
    }
    replay_buffer = SharedBaseReplayBuffer(int(buffer_param['size']), 1)
    replay_buffer.build_by_example(example_dict)

    params['general_setting']['replay_buffer'] = replay_buffer

    params['general_setting']['collector'] = ParallelCollector(
        env, pf, replay_buffer, device=device, worker_nums=1
    )

    params['general_setting']['save_dir'] = osp.join(logger.work_dir, "model")
    agent = SAC(
        pf=pf,
        vf=vf,
        qf=qf,
        **params['sac'],
        **params['general_setting']
    )
    agent.train()
Example 3
def get_agent(params):

    env = params['general_setting']['env']
    # params['general_setting']['collector'] = BaseCollector(
    #     env
    # )

    if len(env.observation_space.shape) == 3:
        params['net']['base_type'] = networks.CNNBase
        if params['env']['frame_stack']:
            buffer_param = params['replay_buffer']
            efficient_buffer = replay_buffers.MemoryEfficientReplayBuffer(
                int(buffer_param['size']))
            params['general_setting']['replay_buffer'] = efficient_buffer
    else:
        params['net']['base_type'] = networks.MLPBase

    if params['agent'] == 'sac':
        pf = policies.GuassianContPolicy(
            input_shape=env.observation_space.shape[0],
            output_shape=2 * env.action_space.shape[0],
            **params['net'])
        vf = networks.Net(input_shape=env.observation_space.shape[0],
                          output_shape=1,
                          **params['net'])
        qf = networks.FlattenNet(input_shape=env.observation_space.shape[0] +
                                 env.action_space.shape[0],
                                 output_shape=1,
                                 **params['net'])
        pretrain_pf = policies.UniformPolicyContinuous(
            env.action_space.shape[0])

        return SAC(pf=pf,
                   vf=vf,
                   qf=qf,
                   pretrain_pf=pretrain_pf,
                   **params['sac'],
                   **params['general_setting'])

    if params['agent'] == 'twin_sac':
        pf = policies.GuassianContPolicy(
            input_shape=env.observation_space.shape[0],
            output_shape=2 * env.action_space.shape[0],
            **params['net'])
        vf = networks.Net(input_shape=env.observation_space.shape[0],
                          output_shape=1,
                          **params['net'])
        qf1 = networks.FlattenNet(input_shape=env.observation_space.shape[0] +
                                  env.action_space.shape[0],
                                  output_shape=1,
                                  **params['net'])
        qf2 = networks.FlattenNet(input_shape=env.observation_space.shape[0] +
                                  env.action_space.shape[0],
                                  output_shape=1,
                                  **params['net'])
        pretrain_pf = policies.UniformPolicyContinuous(
            env.action_space.shape[0])

        return TwinSAC(pf=pf,
                       vf=vf,
                       qf1=qf1,
                       qf2=qf2,
                       pretrain_pf=pretrain_pf,
                       **params['twin_sac'],
                       **params['general_setting'])

    if params['agent'] == 'td3':
        pf = policies.DetContPolicy(input_shape=env.observation_space.shape[0],
                                    output_shape=env.action_space.shape[0],
                                    **params['net'])
        qf1 = networks.FlattenNet(input_shape=env.observation_space.shape[0] +
                                  env.action_space.shape[0],
                                  output_shape=1,
                                  **params['net'])
        qf2 = networks.FlattenNet(input_shape=env.observation_space.shape[0] +
                                  env.action_space.shape[0],
                                  output_shape=1,
                                  **params['net'])
        pretrain_pf = policies.UniformPolicyContinuous(
            env.action_space.shape[0])

        return TD3(pf=pf,
                   qf1=qf1,
                   qf2=qf2,
                   pretrain_pf=pretrain_pf,
                   **params['td3'],
                   **params['general_setting'])

    if params['agent'] == 'ddpg':
        pf = policies.DetContPolicy(input_shape=env.observation_space.shape[0],
                                    output_shape=env.action_space.shape[0],
                                    **params['net'])
        qf = networks.FlattenNet(input_shape=env.observation_space.shape[0] +
                                 env.action_space.shape[0],
                                 output_shape=1,
                                 **params['net'])
        pretrain_pf = policies.UniformPolicyContinuous(
            env.action_space.shape[0])

        return DDPG(pf=pf,
                    qf=qf,
                    pretrain_pf=pretrain_pf,
                    **params['ddpg'],
                    **params['general_setting'])

    if params['agent'] == 'dqn':
        qf = networks.Net(input_shape=env.observation_space.shape,
                          output_shape=env.action_space.n,
                          **params['net'])
        pf = policies.EpsilonGreedyDQNDiscretePolicy(
            qf=qf, action_shape=env.action_space.n, **params['policy'])
        pretrain_pf = policies.UniformPolicyDiscrete(
            action_num=env.action_space.n)
        params["general_setting"]["optimizer_class"] = optim.RMSprop
        return DQN(pf=pf,
                   qf=qf,
                   pretrain_pf=pretrain_pf,
                   **params["dqn"],
                   **params["general_setting"])

    if params['agent'] == 'bootstrapped dqn':
        qf = networks.BootstrappedNet(
            input_shape=env.observation_space.shape,
            output_shape=env.action_space.n,
            head_num=params['bootstrapped dqn']['head_num'],
            **params['net'])
        pf = policies.BootstrappedDQNDiscretePolicy(
            qf=qf,
            head_num=params['bootstrapped dqn']['head_num'],
            action_shape=env.action_space.n,
            **params['policy'])
        pretrain_pf = policies.UniformPolicyDiscrete(
            action_num=env.action_space.n)
        params["general_setting"]["optimizer_class"] = optim.RMSprop
        return BootstrappedDQN(pf=pf,
                               qf=qf,
                               pretrain_pf=pretrain_pf,
                               **params["bootstrapped dqn"],
                               **params["general_setting"])

    if params['agent'] == 'qrdqn':
        qf = networks.Net(input_shape=env.observation_space.shape,
                          output_shape=env.action_space.n *
                          params["qrdqn"]["quantile_num"],
                          **params['net'])
        pf = policies.EpsilonGreedyQRDQNDiscretePolicy(
            qf=qf, action_shape=env.action_space.n, **params['policy'])
        pretrain_pf = policies.UniformPolicyDiscrete(
            action_num=env.action_space.n)
        return QRDQN(pf=pf,
                     qf=qf,
                     pretrain_pf=pretrain_pf,
                     **params["qrdqn"],
                     **params["general_setting"])

    # On Policy Methods
    act_space = env.action_space
    params[params['agent']]['continuous'] = isinstance(act_space,
                                                       gym.spaces.Box)

    buffer_param = params['replay_buffer']
    buffer = replay_buffers.OnPolicyReplayBuffer(int(buffer_param['size']))
    params['general_setting']['replay_buffer'] = buffer

    if params[params['agent']]['continuous']:
        pf = policies.GuassianContPolicy(
            input_shape=env.observation_space.shape,
            output_shape=2 * env.action_space.shape[0],
            **params['net'])
    else:
        print(params['policy'])
        print(params['net'])
        # print(**params['policy'])
        pf = policies.CategoricalDisPolicy(
            input_shape=env.observation_space.shape,
            output_shape=env.action_space.n,
            **params['net'],
            **params['policy'])

    if params['agent'] == 'reinforce':
        return Reinforce(pf=pf,
                         **params["reinforce"],
                         **params["general_setting"])

    # Actor-Critic Frameworks
    vf = networks.Net(input_shape=env.observation_space.shape,
                      output_shape=1,
                      **params['net'])

    if params['agent'] == 'a2c':
        return A2C(pf=pf, vf=vf, **params["a2c"], **params["general_setting"])

    if params['agent'] == 'ppo':
        return PPO(pf=pf, vf=vf, **params["ppo"], **params["general_setting"])

    raise Exception("specified algorithm is not implemented")
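
Example 1 hints at this helper through its commented-out `agent = get_agent( params )` call: rather than wiring up networks by hand, the calling code only has to pick the algorithm and hand over a fully populated params dict. A minimal sketch of that calling side follows; it assumes params['general_setting'] has already been filled in as in Example 1 (env, replay_buffer, logger, device, collector) and that the matching per-agent hyperparameter section is present.

# Hypothetical calling side for get_agent(); `params` is assumed to already
# carry the 'general_setting' entries built in Example 1 plus the 'sac' section.
params['agent'] = 'sac'        # or 'twin_sac', 'td3', 'ddpg', 'dqn', 'qrdqn', ...
agent = get_agent(params)      # constructs pf/vf/qf internally from params['net']
agent.train()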
Example 4
def experiment(args):

    device = torch.device("cuda:{}".format(args.device) if args.cuda else "cpu")

    env, cls_dicts, cls_args = get_meta_env(params['env_name'], params['env'],
                                            params['meta_env'])

    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if args.cuda:
        torch.backends.cudnn.deterministic = True

    buffer_param = params['replay_buffer']

    experiment_name = os.path.split(os.path.splitext(args.config)[0])[-1] if args.id is None \
        else args.id
    logger = Logger(experiment_name, params['env_name'], args.seed, params,
                    args.log_dir)

    params['general_setting']['env'] = env
    params['general_setting']['logger'] = logger
    params['general_setting']['device'] = device

    params['net']['base_type'] = networks.MLPBase

    import torch.multiprocessing as mp
    mp.set_start_method('spawn', force=True)

    from torchrl.networks.init import normal_init

    pf = policies.GuassianContPolicy(
        input_shape=env.observation_space.shape[0],
        output_shape=2 * env.action_space.shape[0],
        **params['net'])
    qf1 = networks.FlattenNet(
        input_shape=env.observation_space.shape[0] + env.action_space.shape[0],
        output_shape=1,
        **params['net'])
    qf2 = networks.FlattenNet(
        input_shape=env.observation_space.shape[0] + env.action_space.shape[0],
        output_shape=1,
        **params['net'])

    example_ob = env.reset()
    example_dict = { 
        "obs": example_ob,
        "next_obs": example_ob,
        "acts": env.action_space.sample(),
        "rewards": [0],
        "terminals": [False],
        "task_idxs": [0]
    }
    replay_buffer = AsyncSharedReplayBuffer(int(buffer_param['size']),
                                            args.worker_nums)
    replay_buffer.build_by_example(example_dict)

    params['general_setting']['replay_buffer'] = replay_buffer

    epochs = params['general_setting']['pretrain_epochs'] + \
        params['general_setting']['num_epochs']

    print(env.action_space)
    print(env.observation_space)
    params['general_setting']['collector'] = AsyncMultiTaskParallelCollectorUniform(
        env=env, pf=pf, replay_buffer=replay_buffer,
        env_cls=cls_dicts, env_args=[params["env"], cls_args, params["meta_env"]],
        device=device,
        reset_idx=True,
        epoch_frames=params['general_setting']['epoch_frames'],
        max_episode_frames=params['general_setting']['max_episode_frames'],
        eval_episodes=params['general_setting']['eval_episodes'],
        worker_nums=args.worker_nums, eval_worker_nums=args.eval_worker_nums,
        train_epochs=epochs, eval_epochs=params['general_setting']['num_epochs']
    )
    params['general_setting']['batch_size'] = int(params['general_setting']['batch_size'])
    params['general_setting']['save_dir'] = osp.join(logger.work_dir, "model")
    agent = MTSAC(
        pf=pf,
        qf1=qf1,
        qf2=qf2,
        task_nums=env.num_tasks,
        **params['sac'],
        **params['general_setting']
    )
    agent.train()