def get_agent(params):
    """Build the networks and policies for params['agent'] and return the matching agent."""
    env = params['general_setting']['env']
    # params['general_setting']['collector'] = BaseCollector(
    #     env
    # )

    # Image observations (3-D) get a CNN trunk, vector observations an MLP.
    if len(env.observation_space.shape) == 3:
        params['net']['base_type'] = networks.CNNBase
        if params['env']['frame_stack']:
            buffer_param = params['replay_buffer']
            efficient_buffer = replay_buffers.MemoryEfficientReplayBuffer(
                int(buffer_param['size']))
            params['general_setting']['replay_buffer'] = efficient_buffer
    else:
        params['net']['base_type'] = networks.MLPBase

    if params['agent'] == 'sac':
        pf = policies.GuassianContPolicy(
            input_shape=env.observation_space.shape[0],
            output_shape=2 * env.action_space.shape[0],
            **params['net'])
        vf = networks.Net(
            input_shape=env.observation_space.shape[0],
            output_shape=1,
            **params['net'])
        qf = networks.FlattenNet(
            input_shape=env.observation_space.shape[0] +
            env.action_space.shape[0],
            output_shape=1,
            **params['net'])
        pretrain_pf = policies.UniformPolicyContinuous(
            env.action_space.shape[0])
        return SAC(
            pf=pf, vf=vf, qf=qf,
            pretrain_pf=pretrain_pf,
            **params['sac'],
            **params['general_setting'])

    if params['agent'] == 'twin_sac':
        pf = policies.GuassianContPolicy(
            input_shape=env.observation_space.shape[0],
            output_shape=2 * env.action_space.shape[0],
            **params['net'])
        vf = networks.Net(
            input_shape=env.observation_space.shape[0],
            output_shape=1,
            **params['net'])
        qf1 = networks.FlattenNet(
            input_shape=env.observation_space.shape[0] +
            env.action_space.shape[0],
            output_shape=1,
            **params['net'])
        qf2 = networks.FlattenNet(
            input_shape=env.observation_space.shape[0] +
            env.action_space.shape[0],
            output_shape=1,
            **params['net'])
        pretrain_pf = policies.UniformPolicyContinuous(
            env.action_space.shape[0])
        return TwinSAC(
            pf=pf, vf=vf, qf1=qf1, qf2=qf2,
            pretrain_pf=pretrain_pf,
            **params['twin_sac'],
            **params['general_setting'])

    if params['agent'] == 'td3':
        pf = policies.DetContPolicy(
            input_shape=env.observation_space.shape[0],
            output_shape=env.action_space.shape[0],
            **params['net'])
        qf1 = networks.FlattenNet(
            input_shape=env.observation_space.shape[0] +
            env.action_space.shape[0],
            output_shape=1,
            **params['net'])
        qf2 = networks.FlattenNet(
            input_shape=env.observation_space.shape[0] +
            env.action_space.shape[0],
            output_shape=1,
            **params['net'])
        pretrain_pf = policies.UniformPolicyContinuous(
            env.action_space.shape[0])
        return TD3(
            pf=pf, qf1=qf1, qf2=qf2,
            pretrain_pf=pretrain_pf,
            **params['td3'],
            **params['general_setting'])

    if params['agent'] == 'ddpg':
        pf = policies.DetContPolicy(
            input_shape=env.observation_space.shape[0],
            output_shape=env.action_space.shape[0],
            **params['net'])
        qf = networks.FlattenNet(
            input_shape=env.observation_space.shape[0] +
            env.action_space.shape[0],
            output_shape=1,
            **params['net'])
        pretrain_pf = policies.UniformPolicyContinuous(
            env.action_space.shape[0])
        return DDPG(
            pf=pf, qf=qf,
            pretrain_pf=pretrain_pf,
            **params['ddpg'],
            **params['general_setting'])

    if params['agent'] == 'dqn':
        qf = networks.Net(
            input_shape=env.observation_space.shape,
            output_shape=env.action_space.n,
            **params['net'])
        pf = policies.EpsilonGreedyDQNDiscretePolicy(
            qf=qf,
            action_shape=env.action_space.n,
            **params['policy'])
        pretrain_pf = policies.UniformPolicyDiscrete(
            action_num=env.action_space.n)
        params["general_setting"]["optimizer_class"] = optim.RMSprop
        return DQN(
            pf=pf, qf=qf,
            pretrain_pf=pretrain_pf,
            **params["dqn"],
            **params["general_setting"])

    if params['agent'] == 'bootstrapped dqn':
        qf = networks.BootstrappedNet(
            input_shape=env.observation_space.shape,
            output_shape=env.action_space.n,
            head_num=params['bootstrapped dqn']['head_num'],
            **params['net'])
        pf = policies.BootstrappedDQNDiscretePolicy(
            qf=qf,
            head_num=params['bootstrapped dqn']['head_num'],
            action_shape=env.action_space.n,
            **params['policy'])
        pretrain_pf = policies.UniformPolicyDiscrete(
            action_num=env.action_space.n)
        params["general_setting"]["optimizer_class"] = optim.RMSprop
        return BootstrappedDQN(
            pf=pf, qf=qf,
            pretrain_pf=pretrain_pf,
            **params["bootstrapped dqn"],
            **params["general_setting"])

    if params['agent'] == 'qrdqn':
        qf = networks.Net(
            input_shape=env.observation_space.shape,
            output_shape=env.action_space.n * params["qrdqn"]["quantile_num"],
            **params['net'])
        pf = policies.EpsilonGreedyQRDQNDiscretePolicy(
            qf=qf,
            action_shape=env.action_space.n,
            **params['policy'])
        pretrain_pf = policies.UniformPolicyDiscrete(
            action_num=env.action_space.n)
        return QRDQN(
            pf=pf, qf=qf,
            pretrain_pf=pretrain_pf,
            **params["qrdqn"],
            **params["general_setting"])

    # On-policy methods
    act_space = env.action_space
    params[params['agent']]['continuous'] = isinstance(
        act_space, gym.spaces.Box)

    buffer_param = params['replay_buffer']
    buffer = replay_buffers.OnPolicyReplayBuffer(int(buffer_param['size']))
    params['general_setting']['replay_buffer'] = buffer

    if params[params['agent']]['continuous']:
        pf = policies.GuassianContPolicy(
            input_shape=env.observation_space.shape,
            output_shape=2 * env.action_space.shape[0],
            **params['net'])
    else:
        pf = policies.CategoricalDisPolicy(
            input_shape=env.observation_space.shape,
            output_shape=env.action_space.n,
            **params['net'],
            **params['policy'])

    if params['agent'] == 'reinforce':
        return Reinforce(
            pf=pf,
            **params["reinforce"],
            **params["general_setting"])

    # Actor-critic frameworks
    vf = networks.Net(
        input_shape=env.observation_space.shape,
        output_shape=1,
        **params['net'])

    if params['agent'] == 'a2c':
        return A2C(
            pf=pf, vf=vf,
            **params["a2c"],
            **params["general_setting"])

    if params['agent'] == 'ppo':
        return PPO(
            pf=pf, vf=vf,
            **params["ppo"],
            **params["general_setting"])

    raise Exception("specified algorithm is not implemented")
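# Usage sketch for get_agent. The helper below and the config path are
# hypothetical; the `params` layout only mirrors the keys read above. The
# remaining general_setting entries (replay buffer, logger, device, collector,
# save_dir) would be wired up the same way the experiment() functions below
# wire them before agent.train() is called.
def _example_usage(config_path="config/sac_halfcheetah.json"):
    import json
    with open(config_path) as f:   # illustrative config file
        params = json.load(f)
    # get_env is the env factory used by the experiment scripts below.
    env = get_env(params['env_name'], params['env'])
    params['general_setting']['env'] = env
    # Dispatches on params['agent'], e.g. 'sac', 'td3', 'ppo'.
    agent = get_agent(params)
    return agent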
def experiment(args):
    """Train PPO with a categorical policy on a single (non-vectorised) environment."""
    import torch.multiprocessing as mp
    mp.set_start_method('spawn')

    device = torch.device(
        "cuda:{}".format(args.device) if args.cuda else "cpu")

    env = get_env(
        params['env_name'],
        params['env'])

    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if args.cuda:
        torch.backends.cudnn.deterministic = True

    buffer_param = params['replay_buffer']

    experiment_name = os.path.split(
        os.path.splitext(args.config)[0])[-1] if args.id is None \
        else args.id
    logger = Logger(
        experiment_name, params['env_name'], args.seed, params, args.log_dir)
    params['general_setting']['env'] = env

    # replay_buffer = OnPolicyReplayBuffer(int(buffer_param['size']))
    # example_ob = env.reset()
    # example_dict = {
    #     "obs": example_ob,
    #     "next_obs": example_ob,
    #     "acts": env.action_space.sample(),
    #     "values": [0],
    #     "rewards": [0],
    #     "terminals": [False]
    # }
    replay_buffer = OnPolicyReplayBuffer(
        int(buffer_param['size']))
    # replay_buffer.build_by_example(example_dict)
    params['general_setting']['replay_buffer'] = replay_buffer

    params['general_setting']['logger'] = logger
    params['general_setting']['device'] = device

    params['net']['base_type'] = networks.MLPBase

    pf = policies.CategoricalDisPolicy(
        input_shape=env.observation_space.shape[0],
        output_shape=env.action_space.n,
        **params['net'],
        **params['policy'])
    vf = networks.Net(
        input_shape=env.observation_space.shape,
        output_shape=1,
        **params['net'])

    # Use the configured device for the collector as well.
    params['general_setting']['collector'] = OnPlicyCollectorBase(
        vf, env=env, pf=pf, replay_buffer=replay_buffer, device=device,
        train_render=False)
    # params['general_setting']['collector'] = ParallelOnPlicyCollector(
    #     vf, env=env, pf=pf, replay_buffer=replay_buffer, device=device,
    #     worker_nums=2)
    params['general_setting']['save_dir'] = osp.join(logger.work_dir, "model")
    agent = PPO(
        pf=pf,
        vf=vf,
        **params["ppo"],
        **params["general_setting"])
    agent.train()
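# Hypothetical command-line entry point for the experiment() above. The flag
# names are inferred only from the attributes the function reads (config, seed,
# device, cuda, id, log_dir); the real script may parse them differently.
if __name__ == "__main__":
    import argparse
    import json

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--id", type=str, default=None)
    parser.add_argument("--log_dir", type=str, default="./log")
    args = parser.parse_args()

    # experiment() reads `params` as a module-level global; loading the JSON
    # config here is an assumption about how the real script does it.
    with open(args.config) as f:
        params = json.load(f)

    experiment(args)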
def experiment(args):
    """Train PPO with a CNN-based categorical policy on vectorised environments."""
    device = torch.device(
        "cuda:{}".format(args.device) if args.cuda else "cpu")

    env = get_vec_env(
        params["env_name"],
        params["env"],
        args.vec_env_nums)

    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    buffer_param = params['replay_buffer']

    experiment_name = os.path.split(
        os.path.splitext(args.config)[0])[-1] if args.id is None \
        else args.id
    logger = Logger(
        experiment_name, params['env_name'], args.seed, params, args.log_dir,
        args.overwrite)

    params['general_setting']['env'] = env
    replay_buffer = OnPolicyReplayBuffer(
        env_nums=args.vec_env_nums,
        max_replay_buffer_size=int(buffer_param['size']),
        time_limit_filter=buffer_param['time_limit_filter'])
    params['general_setting']['replay_buffer'] = replay_buffer

    params['general_setting']['logger'] = logger
    params['general_setting']['device'] = device

    params['net']['base_type'] = networks.CNNBase
    params['net']['activation_func'] = torch.nn.Tanh

    pf = policies.CategoricalDisPolicy(
        input_shape=env.observation_space.shape,
        output_shape=env.action_space.n,
        **params['net'],
        **params['policy'])
    vf = networks.Net(
        input_shape=env.observation_space.shape,
        output_shape=1,
        **params['net'])

    params['general_setting']['collector'] = VecOnPolicyCollector(
        vf, env=env, pf=pf,
        replay_buffer=replay_buffer, device=device,
        train_render=False,
        **params["collector"])
    params['general_setting']['save_dir'] = osp.join(
        logger.work_dir, "model")
    agent = PPO(
        pf=pf,
        vf=vf,
        **params["ppo"],
        **params["general_setting"])
    agent.train()
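# Hypothetical config sketch (as a Python dict) showing the `params` keys the
# vectorised experiment() above actually reads. Key names are taken from the
# accesses in the code; the concrete values are illustrative only and the real
# JSON configs may carry additional fields.
example_params = {
    "env_name": "BreakoutNoFrameskip-v4",  # illustrative discrete-action env id
    "env": {},                 # env wrapper options passed to get_vec_env
    "replay_buffer": {
        "size": 2048,          # becomes max_replay_buffer_size
        "time_limit_filter": False,
    },
    "net": {},                 # extra network kwargs; base_type / activation_func are set in code above
    "policy": {},              # extra CategoricalDisPolicy kwargs
    "collector": {},           # extra VecOnPolicyCollector kwargs
    "ppo": {},                 # PPO hyper-parameters
    "general_setting": {},     # env, replay_buffer, logger, device, collector, save_dir are filled at runtime
}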