Beispiel #1
0
    def from_defaults(args):
        """Combine ``args`` with the default arguments of the chosen components.

        The agent defaults come either from the single agent class named by
        ``args.agent`` or, when that is falsy, from the merged ``args``
        attributes of the actor host, actor worker, learner and experience
        classes. Environment, reward-normalizer and network defaults are
        then layered on top.

        On key collisions the later source wins, in this order:
        ``args`` < agent < env < reward normalizer < network.

        :param args: mapping-like object that also exposes its entries as
            attributes (``agent``, ``actor_host``, ``actor_worker``,
            ``learner``, ``exp``, ``env``, ``rwd_norm``, ``custom_network``).
        :return: a new ``DotDict`` holding the merged arguments.
        """
        if args.agent:
            agent_defaults = R.lookup_agent(args.agent).args
        else:
            # No monolithic agent: assemble defaults from its four parts.
            host_cls = R.lookup_actor(args.actor_host)
            worker_cls = R.lookup_actor(args.actor_worker)
            learner_cls = R.lookup_learner(args.learner)
            exp_cls = R.lookup_exp(args.exp)
            agent_defaults = {
                **host_cls.args,
                **worker_cls.args,
                **learner_cls.args,
                **exp_cls.args,
            }

        environment_cls = R.lookup_env(args.env)
        normalizer_cls = R.lookup_reward_normalizer(args.rwd_norm)
        env_defaults = environment_cls.args
        normalizer_defaults = normalizer_cls.args

        if args.custom_network:
            network_defaults = R.lookup_network(args.custom_network).args
        else:
            network_defaults = R.lookup_modular_args(args)

        return DotDict({
            **args,
            **agent_defaults,
            **env_defaults,
            **normalizer_defaults,
            **network_defaults,
        })
Beispiel #2
0
    def from_prompt(args):
        """Combine ``args`` with values returned by each component's ``prompt``.

        Mirrors the defaults-based merge, but every component contributes
        whatever its ``prompt`` method returns instead of its static
        ``args`` attribute. The agent part comes from the class named by
        ``args.agent`` or, when that is falsy, from the actor host, actor
        worker, learner and experience classes.

        On key collisions the later source wins, in this order:
        ``args`` < agent < env < reward normalizer < network.

        :param args: mapping-like object that also exposes its entries as
            attributes (``agent``, ``actor_host``, ``actor_worker``,
            ``learner``, ``exp``, ``env``, ``rwd_norm``, ``custom_network``).
        :return: a new ``DotDict`` holding the merged arguments.
        """
        if args.agent:
            agent_values = R.lookup_agent(args.agent).prompt(provided=args)
        else:
            # Resolve all four classes first, then prompt them in order.
            host_cls = R.lookup_actor(args.actor_host)
            worker_cls = R.lookup_actor(args.actor_worker)
            learner_cls = R.lookup_learner(args.learner)
            exp_cls = R.lookup_exp(args.exp)
            agent_values = {
                **host_cls.prompt(args),
                **worker_cls.prompt(args),
                **learner_cls.prompt(args),
                **exp_cls.prompt(args),
            }

        # Both classes are resolved before either one is prompted.
        environment_cls = R.lookup_env(args.env)
        normalizer_cls = R.lookup_reward_normalizer(args.rwd_norm)
        env_values = environment_cls.prompt(provided=args)
        normalizer_values = normalizer_cls.prompt(provided=args)

        if args.custom_network:
            # NOTE(review): unlike the other components, the custom network
            # is prompted without the already-provided args -- confirm this
            # is intentional.
            network_values = R.lookup_network(args.custom_network).prompt()
        else:
            network_values = R.prompt_modular_args(args)

        return DotDict({
            **args,
            **agent_values,
            **env_values,
            **normalizer_values,
            **network_values,
        })
Beispiel #3
0
    def __init__(
        self,
        args,
        logger,
        log_id_dir,
        initial_step_count,
        local_rank,
        global_rank,
        world_size,
    ):
        """Build the environment manager, network, agent and updater for a rank.

        :param args: run configuration; this constructor reads ``seed``,
            ``nb_env``, ``env``, ``manager``, ``agent``, ``custom_network``,
            ``lr``, ``rwd_norm``, ``nb_step``, ``epoch_len``,
            ``summary_freq``, ``grad_norm_clip``, ``no_divide``,
            ``load_network`` and ``load_optim``.
        :param logger: logger used for informational messages.
        :param log_id_dir: directory handed to the model saver; the summary
            writer uses its ``rank<global_rank>`` subdirectory.
        :param initial_step_count: stored as-is on the instance.
        :param local_rank: index of the CUDA device used by this process.
        :param global_rank: rank used to offset the environment seed and to
            name the summary directory.
        :param world_size: stored and forwarded to ``DistribUpdater``.
        """
        # Rank 0 keeps the base seed; every other rank shifts it by
        # nb_env * global_rank.
        if global_rank == 0:
            seed = args.seed
        else:
            seed = args.seed + args.nb_env * global_rank
        logger.info("Using {} for rank {} seed.".format(seed, global_rank))

        # ENV
        engine = REGISTRY.lookup_engine(args.env)
        env_cls = REGISTRY.lookup_env(args.env)
        manager_cls = REGISTRY.lookup_manager(args.manager)
        env_mgr = manager_cls.from_args(args, engine, env_cls, seed=seed)

        # NETWORK
        # Torch is seeded with the un-offset seed on every rank
        # (presumably so all ranks start from identical weights -- confirm).
        torch.manual_seed(args.seed)
        device = torch.device("cuda:{}".format(local_rank))
        output_space = REGISTRY.lookup_output_space(
            args.agent, env_mgr.action_space
        )
        net_cls = (
            REGISTRY.lookup_network(args.custom_network)
            if args.custom_network
            else ModularNetwork
        )
        net = net_cls.from_args(
            args,
            env_mgr.observation_space,
            output_space,
            env_mgr.gpu_preprocessor,
            REGISTRY,
        )
        logger.info("Network parameters: " + str(self.count_parameters(net)))

        # AGENT
        rwd_norm_cls = REGISTRY.lookup_reward_normalizer(args.rwd_norm)
        rwd_norm = rwd_norm_cls.from_args(args)
        agent_cls = REGISTRY.lookup_agent(args.agent)
        spec_builder = agent_cls.exp_spec_builder(
            env_mgr.observation_space,
            env_mgr.action_space,
            net.internal_space(),
            env_mgr.nb_env,
        )

        self.agent = agent_cls.from_args(
            args, rwd_norm, env_mgr.action_space, spec_builder
        )
        self.nb_step = args.nb_step
        self.env_mgr = env_mgr
        self.nb_env = args.nb_env
        # The network is moved to the device before the optimizer is built
        # over its parameters.
        self.network = net.to(device)
        self.optimizer = torch.optim.RMSprop(
            self.network.parameters(), lr=args.lr, eps=1e-5, alpha=0.99
        )
        self.device = device
        self.initial_step_count = initial_step_count
        self.log_id_dir = log_id_dir
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.logger = logger
        rank_log_dir = os.path.join(log_id_dir, "rank{}".format(global_rank))
        self.summary_writer = SummaryWriter(rank_log_dir)
        self.saver = SimpleModelSaver(log_id_dir)
        self.local_rank = local_rank
        self.global_rank = global_rank
        self.world_size = world_size
        self.updater = DistribUpdater(
            self.optimizer,
            self.network,
            args.grad_norm_clip,
            world_size,
            not args.no_divide,
        )

        # Optional checkpoint restore.
        if args.load_network:
            self.network = self.load_network(self.network, args.load_network)
            logger.info("Reloaded network from {}".format(args.load_network))
        if args.load_optim:
            self.optimizer = self.load_optim(self.optimizer, args.load_optim)
            logger.info("Reloaded optimizer from {}".format(args.load_optim))

        self.network.train()
Beispiel #4
0
    def __init__(
            self,
            args,
            logger,
            log_id_dir,
            initial_step_count,
            local_rank,
            global_rank,
            world_size
    ):
        """Build the sub-process environment manager, network and agent.

        :param args: run configuration; this constructor reads ``seed``,
            ``nb_env``, ``env``, ``agent``, ``custom_network``, ``lr``,
            ``rwd_norm``, ``nb_step``, ``epoch_len``, ``summary_freq``,
            ``load_network`` and ``load_optim``.
        :param logger: logger used for informational messages.
        :param log_id_dir: stored as-is on the instance.
        :param initial_step_count: stored as-is on the instance.
        :param local_rank: index of the CUDA device used by this process.
        :param global_rank: rank used to offset the environment seed.
        :param world_size: stored as-is on the instance.
        """
        # Rank 0 keeps the base seed; every other rank shifts it by
        # nb_env * global_rank.
        if global_rank == 0:
            seed = args.seed
        else:
            seed = args.seed + args.nb_env * global_rank
        logger.info('Using {} for rank {} seed.'.format(seed, global_rank))

        # ENV
        engine = REGISTRY.lookup_engine(args.env)
        env_cls = REGISTRY.lookup_env(args.env)
        env_mgr = SubProcEnvManager.from_args(args, engine, env_cls, seed=seed)

        # NETWORK
        # Torch is seeded with the un-offset seed on every rank
        # (presumably so all ranks start from identical weights -- confirm).
        torch.manual_seed(args.seed)
        device = torch.device("cuda:{}".format(local_rank))
        output_space = REGISTRY.lookup_output_space(
            args.agent, env_mgr.action_space)
        net_cls = (
            REGISTRY.lookup_network(args.custom_network)
            if args.custom_network
            else ModularNetwork
        )
        net = net_cls.from_args(
            args,
            env_mgr.observation_space,
            output_space,
            env_mgr.gpu_preprocessor,
            REGISTRY
        )

        # AGENT
        rwd_norm_cls = REGISTRY.lookup_reward_normalizer(args.rwd_norm)
        rwd_norm = rwd_norm_cls.from_args(args)
        agent_cls = REGISTRY.lookup_agent(args.agent)
        spec_builder = agent_cls.exp_spec_builder(
            env_mgr.observation_space,
            env_mgr.action_space,
            net.internal_space(),
            env_mgr.nb_env
        )

        self.agent = agent_cls.from_args(
            args,
            rwd_norm,
            env_mgr.action_space,
            spec_builder
        )
        self.nb_step = args.nb_step
        self.env_mgr = env_mgr
        self.nb_env = args.nb_env
        # The network is moved to the device before the optimizer is built
        # over its parameters.
        self.network = net.to(device)
        self.optimizer = torch.optim.RMSprop(
            self.network.parameters(), lr=args.lr, eps=1e-5, alpha=0.99)
        self.device = device
        self.initial_step_count = initial_step_count
        self.log_id_dir = log_id_dir
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.logger = logger
        self.local_rank = local_rank
        self.global_rank = global_rank
        self.world_size = world_size

        # Optional checkpoint restore.
        if args.load_network:
            self.network = self.load_network(self.network, args.load_network)
            logger.info('Reloaded network from {}'.format(args.load_network))
        if args.load_optim:
            self.optimizer = self.load_optim(self.optimizer, args.load_optim)
            logger.info('Reloaded optimizer from {}'.format(args.load_optim))

        self.network.train()
Beispiel #5
0
    def __init__(self, args, logger, log_id_dir, initial_step_count):
        """Build the environment manager, network, agent and local updater.

        :param args: run configuration; this constructor reads ``env``,
            ``manager``, ``seed``, ``gpu_id``, ``agent``, ``custom_network``,
            ``optim``, ``lr``, ``warmup``, ``rwd_norm``, ``nb_step``,
            ``nb_env``, ``epoch_len``, ``summary_freq``, ``grad_norm_clip``,
            ``load_network`` and ``load_optim``.
        :param logger: logger used for informational messages.
        :param log_id_dir: directory handed to the summary writer and the
            model saver.
        :param initial_step_count: stored as-is on the instance.
        :raises ValueError: if ``args.optim`` is neither ``"RMSprop"`` nor
            ``"Adam"``.
        """
        # ENV
        engine = REGISTRY.lookup_engine(args.env)
        env_cls = REGISTRY.lookup_env(args.env)
        mgr_cls = REGISTRY.lookup_manager(args.manager)
        env_mgr = mgr_cls.from_args(args, engine, env_cls)

        # NETWORK
        torch.manual_seed(args.seed)
        # Use the requested GPU when CUDA is available and gpu_id is
        # non-negative; otherwise run on the CPU.
        if torch.cuda.is_available() and args.gpu_id >= 0:
            device = torch.device("cuda:{}".format(args.gpu_id))
            torch.backends.cudnn.benchmark = True
        else:
            device = torch.device("cpu")
        output_space = REGISTRY.lookup_output_space(args.agent,
                                                    env_mgr.action_space)
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(
            args,
            env_mgr.gpu_preprocessor.observation_space,
            output_space,
            env_mgr.gpu_preprocessor,
            REGISTRY,
        )
        logger.info("Network parameters: " + str(self.count_parameters(net)))

        def optim_fn(x):
            """Return the optimizer selected by ``args.optim`` over ``x``."""
            if args.optim == "RMSprop":
                return torch.optim.RMSprop(x, lr=args.lr, eps=1e-5, alpha=0.99)
            elif args.optim == "Adam":
                return torch.optim.Adam(x, lr=args.lr, eps=1e-5)
            # Previously an unknown name fell through and returned None,
            # which only failed later inside the scheduler with an
            # unrelated-looking error. Fail here with a clear message.
            raise ValueError(
                "Unsupported optimizer {!r}; expected 'RMSprop' or "
                "'Adam'".format(args.optim)
            )

        def warmup_schedule(back_step):
            """Return the LR multiplier: linear ramp up to ``args.warmup``."""
            return back_step / args.warmup if back_step < args.warmup else 1.0

        # AGENT
        rwd_norm = REGISTRY.lookup_reward_normalizer(
            args.rwd_norm).from_args(args)
        agent_cls = REGISTRY.lookup_agent(args.agent)
        builder = agent_cls.exp_spec_builder(
            env_mgr.observation_space,
            env_mgr.action_space,
            net.internal_space(),
            env_mgr.nb_env,
        )
        agent = agent_cls.from_args(args, rwd_norm, env_mgr.action_space,
                                    builder)

        self.agent = agent.to(device)
        self.nb_step = args.nb_step
        self.env_mgr = env_mgr
        self.nb_env = args.nb_env
        self.network = net.to(device)
        self.optimizer = optim_fn(self.network.parameters())
        self.scheduler = LambdaLR(self.optimizer, warmup_schedule)
        self.device = device
        self.initial_step_count = initial_step_count
        self.log_id_dir = log_id_dir
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.logger = logger
        self.summary_writer = SummaryWriter(log_id_dir)
        self.saver = SimpleModelSaver(log_id_dir)
        self.updater = LocalUpdater(self.optimizer, self.network,
                                    args.grad_norm_clip)

        # Optional checkpoint restore.
        # NOTE(review): the scheduler and updater were built around the
        # optimizer/network objects created above; if load_optim or
        # load_network return new objects rather than the ones passed in,
        # those references would be stale -- confirm against the loaders.
        if args.load_network:
            self.network = self.load_network(self.network, args.load_network)
            logger.info("Reloaded network from {}".format(args.load_network))
        if args.load_optim:
            self.optimizer = self.load_optim(self.optimizer, args.load_optim)
            logger.info("Reloaded optimizer from {}".format(args.load_optim))

        self.network.train()