Example #1
File: ac_agent.py  Project: rlcyf/tbase-1
    def __init__(self, env, args, *other_args):
        super(ACAgent, self).__init__(env, args, other_args)
        self.num_env = args.num_env
        optimizer_fn = get_optimizer_func(args.opt_fn)()
        # policy net and its target copy
        self.policy = get_policy_net(env, args)
        self.target_policy = get_policy_net(env, args)
        self.policy_opt = optimizer_fn(params=filter(lambda p: p.requires_grad,
                                                     self.policy.parameters()),
                                       lr=self.policy.learning_rate)
        # value net and its target copy
        self.value = get_value_net(env, args)
        self.target_value = get_value_net(env, args)
        self.value_opt = optimizer_fn(params=filter(lambda p: p.requires_grad,
                                                    self.value.parameters()),
                                      lr=self.args.lr)
        if self.num_env > 1:
            # shared queue for collecting results from the worker processes
            self.queue = mp.Queue()
        self.envs = []
        self.states = []
        self.memorys = []
        # one environment, initial state and replay buffer per worker
        for i in range(self.num_env):
            env = make_env(args=args)
            state = env.reset()
            self.envs.append(env)
            self.states.append(state)
            self.memorys.append(ReplayBuffer(1e5))
        # write progress as "current,total" for an external progress bar
        with open(self.args.progress_bar_path, "w") as progress_file:
            if self.args.eval or self.args.infer:
                progress_file.write("%d,%d\n" % (0, 1))
            else:
                progress_file.write("%d,%d\n" % (0, self.args.max_iter_num))
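The constructor above relies on tbase helpers (get_optimizer_func, get_policy_net, get_value_net, make_env) and allocates one ReplayBuffer holding 1e5 transitions per worker environment. The project's actual buffer implementation is not shown on this page; below is a minimal sketch of the interface such a buffer typically exposes, with hypothetical method names (add, sample):

import random
from collections import deque

class ReplayBuffer:
    """Minimal sketch of a fixed-capacity FIFO transition buffer."""

    def __init__(self, capacity):
        # the example passes 1e5, a float, so cast to int for deque's maxlen
        self.buffer = deque(maxlen=int(capacity))

    def add(self, state, action, reward, next_state, done):
        # oldest transitions are evicted automatically once capacity is hit
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # uniform sampling without replacement
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)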
Example #2
File: tbase.py  Project: tradingAI/runner
    def run(self):
        logger.info("%s is running" % self.name)
        args = common_arg_parser()
        if args.debug:
            import logging
            logger.setLevel(logging.DEBUG)
        set_global_seeds(args.seed)
        logger.info("tbase.run set global_seeds: %s" % str(args.seed))
        if torch.cuda.is_available():
            # CUDA contexts cannot be shared across fork()ed workers, so
            # multi-env GPU runs must use the 'spawn' start method
            if args.num_env > 1 and args.device != 'cpu':
                set_start_method('spawn')
        env = make_env(args=args)
        print("\n" + "*" * 80)
        logger.info("Initializing agent by parameters:")
        logger.info(str(args))
        agent = get_agent(env, args)
        if not args.eval and not args.infer:
            logger.info("Training agent")
            agent.learn()
            logger.info("Finished, tensorboard --logdir=%s" %
                        args.tensorboard_dir)
        # eval models
        if not args.infer:
            eval_env = make_eval_env(args=args)
            agent.eval(eval_env, args)

        # infer actions
        if args.infer:
            infer_env = make_infer_env(args=args)
            agent.infer(infer_env)
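run() wires together the same pipeline as Example #4's main() below, plus an infer branch. set_global_seeds is imported from tbase; a plausible sketch of what such a seeding helper does (an assumption, not the project's exact code):

import random

import numpy as np
import torch

def set_global_seeds(seed):
    # sketch: seed every RNG the training loop may touch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)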
Example #3
    def __init__(self, env, args, *other_args):
        # change to random policy
        args.policy_net = "Random"
        super(Agent, self).__init__(env, args, other_args)
        self.policy = get_policy_net(env, args)

        self.num_env = args.num_env
        self.envs = []
        self.states = []
        self.memorys = []
        # one environment, initial state and replay buffer per worker
        for i in range(self.num_env):
            env = make_env(args=args)
            state = env.reset()
            self.envs.append(env)
            self.states.append(state)
            self.memorys.append(ReplayBuffer(1e5))
        # unlike Example #1's ACAgent, the queue is created even for one env
        self.queue = mp.Queue()
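An agent forced onto a "Random" policy net like this is typically used to pre-fill replay buffers or to serve as a baseline. A minimal sketch of one exploration step per worker, assuming the gym-style reset/step API seen above and a hypothetical policy.action helper:

def explore_one_step(agent, i):
    # advance worker i by one step and store the transition
    env = agent.envs[i]
    state = agent.states[i]
    action = agent.policy.action(state)  # hypothetical: sample a random action
    next_state, reward, done, _ = env.step(action)
    agent.memorys[i].add(state, action, reward, next_state, done)
    agent.states[i] = env.reset() if done else next_state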
Example #4
File: run.py  Project: iminders/tbase
def main():
    args = common_arg_parser()
    if args.debug:
        import logging
        logger.setLevel(logging.DEBUG)
    set_global_seeds(args.seed)
    logger.info("tbase.run set global_seeds: %s" % str(args.seed))
    if torch.cuda.is_available() and args.num_env > 1 and args.device != 'cpu':
        set_start_method('spawn')
    env = make_env(args=args)
    print("\n" + "*" * 80)
    logger.info("Initializing agent by parameters:")
    logger.info(str(args))
    agent = get_agent(env, args)
    if not args.eval:
        logger.info("Training agent")
        agent.learn()
        logger.info("Finished, check details by run tensorboard --logdir=%s" %
                    args.tensorboard_dir)
    # eval models
    eval_env = make_eval_env(args=args)
    agent.eval(eval_env, args)
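main() mirrors Example #2's run() without the infer branch: evaluation always follows (optional) training. One caveat worth noting: because set_start_method('spawn') makes every worker re-import the entry module, run.py must call main() behind the standard guard, sketched here:

if __name__ == "__main__":
    # required with the 'spawn' start method: workers re-import this module,
    # and the guard keeps them from recursively launching main()
    main()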