Example #1
class DistribHost(Container):
    """
    DistribHost saves models and writes summaries. This is the only difference
    from the worker.
    """

    def __init__(
        self,
        args,
        logger,
        log_id_dir,
        initial_step_count,
        local_rank,
        global_rank,
        world_size,
    ):
        seed = (
            args.seed
            if global_rank == 0
            else args.seed + args.nb_env * global_rank
        )
        logger.info("Using {} for rank {} seed.".format(seed, global_rank))

        # ENV
        engine = REGISTRY.lookup_engine(args.env)
        env_cls = REGISTRY.lookup_env(args.env)
        mgr_cls = REGISTRY.lookup_manager(args.manager)
        env_mgr = mgr_cls.from_args(args, engine, env_cls, seed=seed)

        # NETWORK
        torch.manual_seed(args.seed)
        device = torch.device("cuda:{}".format(local_rank))
        output_space = REGISTRY.lookup_output_space(
            args.agent, env_mgr.action_space
        )
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(
            args,
            env_mgr.observation_space,
            output_space,
            env_mgr.gpu_preprocessor,
            REGISTRY,
        )
        logger.info("Network parameters: " + str(self.count_parameters(net)))

        def optim_fn(x):
            return torch.optim.RMSprop(x, lr=args.lr, eps=1e-5, alpha=0.99)

        # AGENT
        rwd_norm = REGISTRY.lookup_reward_normalizer(args.rwd_norm).from_args(
            args
        )
        agent_cls = REGISTRY.lookup_agent(args.agent)
        builder = agent_cls.exp_spec_builder(
            env_mgr.observation_space,
            env_mgr.action_space,
            net.internal_space(),
            env_mgr.nb_env,
        )
        agent = agent_cls.from_args(
            args, rwd_norm, env_mgr.action_space, builder
        )

        self.agent = agent
        self.nb_step = args.nb_step
        self.env_mgr = env_mgr
        self.nb_env = args.nb_env
        self.network = net.to(device)
        self.optimizer = optim_fn(self.network.parameters())
        self.device = device
        self.initial_step_count = initial_step_count
        self.log_id_dir = log_id_dir
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.logger = logger
        self.summary_writer = SummaryWriter(
            os.path.join(log_id_dir, "rank{}".format(global_rank))
        )
        self.saver = SimpleModelSaver(log_id_dir)
        self.local_rank = local_rank
        self.global_rank = global_rank
        self.world_size = world_size
        self.updater = DistribUpdater(
            self.optimizer,
            self.network,
            args.grad_norm_clip,
            world_size,
            not args.no_divide,
        )

        if args.load_network:
            self.network = self.load_network(self.network, args.load_network)
            logger.info("Reloaded network from {}".format(args.load_network))
        if args.load_optim:
            self.optimizer = self.load_optim(self.optimizer, args.load_optim)
            logger.info("Reloaded optimizer from {}".format(args.load_optim))

        self.network.train()

    def run(self):
        local_step_count = global_step_count = self.initial_step_count
        next_save = self.init_next_save(self.initial_step_count, self.epoch_len)
        prev_step_t = time()
        ep_rewards = torch.zeros(self.nb_env)

        obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        internals = listd_to_dlist(
            [
                self.network.new_internals(self.device)
                for _ in range(self.nb_env)
            ]
        )
        start_time = time()
        while global_step_count < self.nb_step:
            actions, internals = self.agent.act(self.network, obs, internals)
            next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
            next_obs = dtensor_to_dev(next_obs, self.device)

            self.agent.observe(
                obs,
                rewards.to(self.device).float(),
                terminals.to(self.device).float(),
                infos,
            )

            # Perform state updates
            local_step_count += self.nb_env
            global_step_count += self.nb_env * self.world_size
            ep_rewards += rewards.float()
            obs = next_obs

            term_rewards = []
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v
                    term_rewards.append(ep_rewards[i].item())
                    ep_rewards[i].zero_()

            if term_rewards:
                term_reward = np.mean(term_rewards)
                delta_t = time() - start_time
                self.logger.info(
                    "RANK: {} "
                    "GLOBAL STEP: {} "
                    "REWARD: {} "
                    "GLOBAL STEP/S: {} "
                    "LOCAL STEP/S: {}".format(
                        self.global_rank,
                        global_step_count,
                        term_reward,
                        (global_step_count - self.initial_step_count) / delta_t,
                        (local_step_count - self.initial_step_count) / delta_t,
                    )
                )
                self.summary_writer.add_scalar(
                    "reward", term_reward, global_step_count
                )

            if global_step_count >= next_save:
                self.saver.save_state_dicts(
                    self.network, global_step_count, self.optimizer
                )
                next_save += self.epoch_len

            # Learn
            if self.agent.is_ready():
                loss_dict, metric_dict = self.agent.learn_step(
                    self.updater, self.network, next_obs, internals
                )
                total_loss = torch.sum(
                    torch.stack(tuple(loss for loss in loss_dict.values()))
                )

                self.agent.clear()
                for k, vs in internals.items():
                    internals[k] = [v.detach() for v in vs]

                # write summaries
                cur_step_t = time()
                if cur_step_t - prev_step_t > self.summary_freq:
                    self.write_summaries(
                        self.summary_writer,
                        global_step_count,
                        total_loss,
                        loss_dict,
                        metric_dict,
                        self.network.named_parameters(),
                    )
                    prev_step_t = cur_step_t

    def close(self):
        return self.env_mgr.close()
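
The run loop above leans on two small helpers, listd_to_dlist and dtensor_to_dev, whose implementations live elsewhere in the library. A minimal sketch of what they are assumed to do (not the library's actual code):

def listd_to_dlist(list_of_dicts):
    # assumed semantics: turn [{"h": t0}, {"h": t1}] into {"h": [t0, t1]},
    # so per-env internal states can be indexed as internals[key][env_idx]
    return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0]}

def dtensor_to_dev(dict_of_tensors, device):
    # assumed semantics: move every tensor in an observation dict to the device
    return {k: v.to(device) for k, v in dict_of_tensors.items()}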
Example #2
    def __init__(
        self,
        args,
        logger,
        log_id_dir,
        initial_step_count,
        local_rank,
        global_rank,
        world_size,
    ):
        seed = (
            args.seed
            if global_rank == 0
            else args.seed + args.nb_env * global_rank
        )
        logger.info("Using {} for rank {} seed.".format(seed, global_rank))

        # ENV
        engine = REGISTRY.lookup_engine(args.env)
        env_cls = REGISTRY.lookup_env(args.env)
        mgr_cls = REGISTRY.lookup_manager(args.manager)
        env_mgr = mgr_cls.from_args(args, engine, env_cls, seed=seed)

        # NETWORK
        torch.manual_seed(args.seed)
        device = torch.device("cuda:{}".format(local_rank))
        output_space = REGISTRY.lookup_output_space(
            args.agent, env_mgr.action_space
        )
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(
            args,
            env_mgr.observation_space,
            output_space,
            env_mgr.gpu_preprocessor,
            REGISTRY,
        )
        logger.info("Network parameters: " + str(self.count_parameters(net)))

        def optim_fn(x):
            return torch.optim.RMSprop(x, lr=args.lr, eps=1e-5, alpha=0.99)

        # AGENT
        rwd_norm = REGISTRY.lookup_reward_normalizer(args.rwd_norm).from_args(
            args
        )
        agent_cls = REGISTRY.lookup_agent(args.agent)
        builder = agent_cls.exp_spec_builder(
            env_mgr.observation_space,
            env_mgr.action_space,
            net.internal_space(),
            env_mgr.nb_env,
        )
        agent = agent_cls.from_args(
            args, rwd_norm, env_mgr.action_space, builder
        )

        self.agent = agent
        self.nb_step = args.nb_step
        self.env_mgr = env_mgr
        self.nb_env = args.nb_env
        self.network = net.to(device)
        self.optimizer = optim_fn(self.network.parameters())
        self.device = device
        self.initial_step_count = initial_step_count
        self.log_id_dir = log_id_dir
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.logger = logger
        self.summary_writer = SummaryWriter(
            os.path.join(log_id_dir, "rank{}".format(global_rank))
        )
        self.saver = SimpleModelSaver(log_id_dir)
        self.local_rank = local_rank
        self.global_rank = global_rank
        self.world_size = world_size
        self.updater = DistribUpdater(
            self.optimizer,
            self.network,
            args.grad_norm_clip,
            world_size,
            not args.no_divide,
        )

        if args.load_network:
            self.network = self.load_network(self.network, args.load_network)
            logger.info("Reloaded network from {}".format(args.load_network))
        if args.load_optim:
            self.optimizer = self.load_optim(self.optimizer, args.load_optim)
            logger.info("Reloaded optimizer from {}".format(args.load_optim))

        self.network.train()
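
The seed expression above, args.seed + args.nb_env * global_rank, spaces ranks nb_env seeds apart so no two environment instances share a seed, assuming each rank consumes nb_env consecutive seeds. A quick check with hypothetical values:

base_seed, nb_env = 1, 4  # hypothetical values for args.seed and args.nb_env
for global_rank in range(3):
    seed = base_seed if global_rank == 0 else base_seed + nb_env * global_rank
    print(global_rank, seed)  # rank 0 -> 1, rank 1 -> 5, rank 2 -> 9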
Example #3
class ActorLearnerHost(Container):
    @classmethod
    def as_remote(cls,
                  num_cpus=None,
                  num_gpus=None,
                  memory=None,
                  object_store_memory=None,
                  resources=None):
        return ray.remote(
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            memory=memory,
            object_store_memory=object_store_memory,
            resources=resources)(cls)

    def __init__(
            self,
            args,
            log_id_dir,
            initial_step_count,
            rank=0,
    ):
        # ARGS TO STATE VARS
        self._args = args
        self.nb_learners = args.nb_learners
        self.nb_workers = args.nb_workers
        self.rank = rank
        self.nb_step = args.nb_step
        self.nb_env = args.nb_env
        self.initial_step_count = initial_step_count
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.nb_learn_batch = args.nb_learn_batch
        self.rollout_queue_size = args.rollout_queue_size
        # can be None if rank != 0
        self.log_id_dir = log_id_dir

        # load saved registry classes
        REGISTRY.load_extern_classes(log_id_dir)

        # ENV (temporary)
        env_cls = REGISTRY.lookup_env(args.env)
        env = env_cls.from_args(args, 0)
        env_action_space, env_observation_space, env_gpu_preprocessor = \
            env.action_space, env.observation_space, env.gpu_preprocessor
        env.close()

        # NETWORK
        torch.manual_seed(args.seed)
        device = torch.device("cuda")  # ray handles gpus
        torch.backends.cudnn.benchmark = True
        output_space = REGISTRY.lookup_output_space(
            args.actor_worker, env_action_space)
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(
            args,
            env_observation_space,
            output_space,
            env_gpu_preprocessor,
            REGISTRY
        )
        self.network = net.to(device)
        # TODO: this is a hack, remove once queuer puts rollouts on the correct device
        self.network.device = device
        self.device = device
        self.network.train()

        # OPTIMIZER
        def optim_fn(x):
            return torch.optim.RMSprop(x, lr=args.lr, eps=1e-5, alpha=0.99)
        if args.nb_learners > 1:
            self.optimizer = NCCLOptimizer(optim_fn, self.network, self.nb_learners)
        else:
            self.optimizer = optim_fn(self.network.parameters())

        # LEARNER / EXP
        rwd_norm = REGISTRY.lookup_reward_normalizer(
            args.rwd_norm).from_args(args)
        actor_cls = REGISTRY.lookup_actor(args.actor_host)
        builder = actor_cls.exp_spec_builder(
            env.observation_space,
            env.action_space,
            net.internal_space(),
            args.nb_env * args.nb_learn_batch
        )
        w_builder = REGISTRY.lookup_actor(args.actor_worker).exp_spec_builder(
            env.observation_space,
            env.action_space,
            net.internal_space(),
            args.nb_env
        )
        actor = actor_cls.from_args(args, env.action_space)
        learner = REGISTRY.lookup_learner(args.learner).from_args(args, rwd_norm)

        exp_cls = REGISTRY.lookup_exp(args.exp)

        self.actor = actor
        self.learner = learner
        self.exp = exp_cls.from_args(args, builder).to(device)

        # Rank 0 setup, load network/optimizer and create SummaryWriter/Saver
        if rank == 0:
            if args.load_network:
                self.network = self.load_network(self.network, args.load_network)
                print('Reloaded network from {}'.format(args.load_network))
            if args.load_optim:
                self.optimizer = self.load_optim(self.optimizer, args.load_optim)
                print('Reloaded optimizer from {}'.format(args.load_optim))

            print('Network parameters: ' + str(self.count_parameters(net)))
            self.summary_writer = SummaryWriter(log_id_dir)
            self.saver = SimpleModelSaver(log_id_dir)

    def run(self, workers, profile=False):
        if profile:
            try:
                from pyinstrument import Profiler
            except ImportError:
                raise ImportError('You must install pyinstrument to use profiling.')
            profiler = Profiler()
            profiler.start()

        # setup queuer
        rollout_queuer = RolloutQueuerAsync(workers, self.nb_learn_batch, self.rollout_queue_size)
        rollout_queuer.start()

        # initial setup
        global_step_count = self.initial_step_count
        next_save = self.init_next_save(self.initial_step_count, self.epoch_len)
        prev_step_t = time()
        ep_rewards = torch.zeros(self.nb_env)
        start_time = time()

        # loop until total number steps
        print('{} starting training'.format(self.rank))
        while not self.done(global_step_count):
            self.exp.clear()
            # Get batch from queue
            rollouts, terminal_rewards, terminal_infos = rollout_queuer.get()

            # Iterate forward on batch
            self.exp.write_exps(rollouts)
            # keep a copy of terminals on the CPU; it's faster
            rollout_terminals = torch.stack(self.exp['terminals']).numpy()
            self.exp.to(self.device)
            r = self.exp.read()
            internals = {k: ts[0].unbind(0) for k, ts in r.internals.items()}
            for obs, rewards, terminals in zip(
                    r.observations,
                    r.rewards,
                    rollout_terminals
            ):
                _, h_exp, internals = self.actor.act(self.network, obs,
                                                     internals)
                self.exp.write_actor(h_exp, no_env=True)

                # np.where returns a single-element tuple containing the terminal indexes
                terminal_inds = np.where(terminals)[0]
                for i in terminal_inds:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v

            # compute loss
            loss_dict, metric_dict = self.learner.compute_loss(
                self.network, self.exp.read(), r.next_observation, internals
            )
            total_loss = torch.sum(
                torch.stack(tuple(loss for loss in loss_dict.values()))
            )

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # Perform state updates
            global_step_count += self.nb_env * self.nb_learn_batch * len(r.terminals) * self.nb_learners

            # if rank 0 write summaries and save
            # and send parameters to workers async
            if self.rank == 0:
                # TODO: this could be parallelized, chunk by nb learners
                self.synchronize_worker_parameters(workers, global_step_count)

                # possible save
                if global_step_count >= next_save:
                    self.saver.save_state_dicts(
                        self.network, global_step_count, self.optimizer
                    )
                    next_save += self.epoch_len

                # write reward summaries
                if any(terminal_rewards):
                    terminal_rewards = list(filter(lambda x: x is not None, terminal_rewards))
                    self.summary_writer.add_scalar(
                        'reward', np.mean(terminal_rewards), global_step_count
                    )

                # write infos
                if any(terminal_infos):
                    terminal_infos = list(filter(lambda x: x is not None, terminal_infos))
                    float_keys = [
                        k for k, v in terminal_infos[0].items() if type(v) == float
                    ]
                    terminal_infos_dlist = listd_to_dlist(terminal_infos)
                    for k in float_keys:
                        self.summary_writer.add_scalar(
                            f'info/{k}',
                            np.mean(terminal_infos_dlist[k]),
                            global_step_count
                        )

            # write summaries
            cur_step_t = time()
            if cur_step_t - prev_step_t > self.summary_freq:
                print('Rank {} Metrics:'.format(self.rank), rollout_queuer.metrics())
                if self.rank == 0:
                    self.write_summaries(
                        self.summary_writer, global_step_count, total_loss,
                        loss_dict, metric_dict, self.network.named_parameters()
                    )
                prev_step_t = cur_step_t

        rollout_queuer.close()
        print('{} stopped training'.format(self.rank))
        if profile:
            profiler.stop()
            print(profiler.output_text(unicode=True, color=True))

    def done(self, global_step_count):
        return global_step_count >= self.nb_step

    def close(self):
        pass

    def get_parameters(self):
        params = [p.cpu() for p in self.network.parameters()]
        return params

    def synchronize_worker_parameters(self, workers, global_step_count=0, blocking=False):
        parameters = self.get_parameters()
        futures = [w.set_weights.remote(parameters) for w in workers]

        if global_step_count != 0:
            futures.extend([w.set_global_step.remote(global_step_count) for w in workers])

        if blocking:
            ray.get(futures)

    def _rank0_nccl_port_init(self):
        ip = ray.services.get_node_ip_address()
        port = find_free_port()
        nccl_addr = "tcp://{ip}:{port}".format(ip=ip, port=port)
        return nccl_addr, ip, port

    def _nccl_init(self, nccl_addr, nccl_ip, nccl_port):
        self.nccl_ip, self.nccl_addr, self.nccl_port = nccl_ip, nccl_addr, nccl_port
        print('Rank {} calling init_process_group. Addr: {}'.format(self.rank, nccl_addr))
        # from https://github.com/pytorch/pytorch/blob/master/test/simulate_nccl_errors.py
        store = dist.TCPStore(self.nccl_ip, self.nccl_port, self.nb_learners, self.rank == 0)
        process_group = dist.ProcessGroupNCCL(store, self.rank, self.nb_learners)
        print('Rank {} initialized process group.'.format(self.rank))
        process_group.barrier()
        print('Rank {} process group barrier finished.'.format(self.rank))
        self.process_group = process_group
        # set optimizer process_group
        self.optimizer.set_process_group(self.process_group)

    def _sync_peer_parameters(self):
        print('Rank {} syncing parameters.'.format(self.rank))
        self.process_group.barrier()
        for p in self.network.parameters():
            self.process_group.allreduce(p.data)
            p.data = p.data / self.nb_learners
        print('Rank {} parameters synced.'.format(self.rank))
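
as_remote wraps the class with ray.remote so a host can be created as a Ray actor and driven asynchronously. A rough usage sketch; args, log_id_dir and workers are placeholders, not values taken from the example:

import ray

ray.init()

# build an actor class with resource requirements, then instantiate it remotely
RemoteHost = ActorLearnerHost.as_remote(num_cpus=2, num_gpus=1)
host = RemoteHost.remote(args, log_id_dir, initial_step_count=0, rank=0)

# actor methods return futures; block on them with ray.get
ray.get(host.run.remote(workers, profile=False))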
Example #4
    def __init__(
            self,
            args,
            log_id_dir,
            initial_step_count,
            rank=0,
    ):
        # ARGS TO STATE VARS
        self._args = args
        self.nb_learners = args.nb_learners
        self.nb_workers = args.nb_workers
        self.rank = rank
        self.nb_step = args.nb_step
        self.nb_env = args.nb_env
        self.initial_step_count = initial_step_count
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.nb_learn_batch = args.nb_learn_batch
        self.rollout_queue_size = args.rollout_queue_size
        # can be None if rank != 0
        self.log_id_dir = log_id_dir

        # load saved registry classes
        REGISTRY.load_extern_classes(log_id_dir)

        # ENV (temporary)
        env_cls = REGISTRY.lookup_env(args.env)
        env = env_cls.from_args(args, 0)
        env_action_space, env_observation_space, env_gpu_preprocessor = \
            env.action_space, env.observation_space, env.gpu_preprocessor
        env.close()

        # NETWORK
        torch.manual_seed(args.seed)
        device = torch.device("cuda")  # ray handles gpus
        torch.backends.cudnn.benchmark = True
        output_space = REGISTRY.lookup_output_space(
            args.actor_worker, env_action_space)
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(
            args,
            env_observation_space,
            output_space,
            env_gpu_preprocessor,
            REGISTRY
        )
        self.network = net.to(device)
        # TODO: this is a hack, remove once queuer puts rollouts on the correct device
        self.network.device = device
        self.device = device
        self.network.train()

        # OPTIMIZER
        def optim_fn(x):
            return torch.optim.RMSprop(x, lr=args.lr, eps=1e-5, alpha=0.99)
        if args.nb_learners > 1:
            self.optimizer = NCCLOptimizer(optim_fn, self.network, self.nb_learners)
        else:
            self.optimizer = optim_fn(self.network.parameters())

        # LEARNER / EXP
        rwd_norm = REGISTRY.lookup_reward_normalizer(
            args.rwd_norm).from_args(args)
        actor_cls = REGISTRY.lookup_actor(args.actor_host)
        builder = actor_cls.exp_spec_builder(
            env.observation_space,
            env.action_space,
            net.internal_space(),
            args.nb_env * args.nb_learn_batch
        )
        w_builder = REGISTRY.lookup_actor(args.actor_worker).exp_spec_builder(
            env.observation_space,
            env.action_space,
            net.internal_space(),
            args.nb_env
        )
        actor = actor_cls.from_args(args, env.action_space)
        learner = REGISTRY.lookup_learner(args.learner).from_args(args, rwd_norm)

        exp_cls = REGISTRY.lookup_exp(args.exp)

        self.actor = actor
        self.learner = learner
        self.exp = exp_cls.from_args(args, builder).to(device)

        # Rank 0 setup, load network/optimizer and create SummaryWriter/Saver
        if rank == 0:
            if args.load_network:
                self.network = self.load_network(self.network, args.load_network)
                print('Reloaded network from {}'.format(args.load_network))
            if args.load_optim:
                self.optimizer = self.load_optim(self.optimizer, args.load_optim)
                print('Reloaded optimizer from {}'.format(args.load_optim))

            print('Network parameters: ' + str(self.count_parameters(net)))
            self.summary_writer = SummaryWriter(log_id_dir)
            self.saver = SimpleModelSaver(log_id_dir)
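
When nb_learners > 1 the plain RMSprop optimizer is wrapped in NCCLOptimizer, which is expected to all-reduce gradients across learner processes before applying the update. Its real implementation is not shown here; the sketch below only illustrates the general technique with torch.distributed, assuming a process group is already initialized:

import torch.distributed as dist

def allreduce_step(optimizer, network, world_size):
    # average gradients across all learner processes, then apply the update
    for p in network.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
            p.grad.data /= world_size
    optimizer.step()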
Example #5
def main(args):
    # host needs to broadcast timestamp so all procs create the same log dir
    if rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        log_id = make_log_id_from_timestamp(
            args.tag, args.mode_name, args.agent,
            args.vision_network + args.network_body, timestamp)
        log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)
        os.makedirs(log_id_dir)
        saver = SimpleModelSaver(log_id_dir)
        print_ascii_logo()
    else:
        timestamp = None
    timestamp = comm.bcast(timestamp, root=0)

    if rank != 0:
        log_id = make_log_id_from_timestamp(
            args.tag, args.mode_name, args.agent,
            args.vision_network + args.network_body, timestamp)
        log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)

    comm.Barrier()

    # construct env
    seed = args.seed if rank == 0 else args.seed + (
        args.nb_env * (rank - 1))  # unique seed per process
    env = make_env(args, seed)

    # construct network
    torch.manual_seed(args.seed)
    network_head_shapes = get_head_shapes(env.action_space, env.engine,
                                          args.agent)
    network = make_network(env.observation_space, network_head_shapes, args)

    # sync network params
    if rank == 0:
        for v in network.parameters():
            comm.Bcast(v.detach().cpu().numpy(), root=0)
        print('Root variables synced')
    else:
        # can just use the numpy buffers
        variables = [v.detach().cpu().numpy() for v in network.parameters()]
        for v in variables:
            comm.Bcast(v, root=0)
        for shared_v, model_v in zip(variables, network.parameters()):
            model_v.data.copy_(torch.from_numpy(shared_v), non_blocking=True)
        print('{} variables synced'.format(rank))

    # construct agent
    # host is always the first gpu, workers are distributed evenly across the rest
    if len(args.gpu_id) > 1:  # nargs is always a list
        if rank == 0:
            gpu_id = args.gpu_id[0]
        else:
            gpu_id = args.gpu_id[1:][(rank - 1) % len(args.gpu_id[1:])]
    else:
        gpu_id = args.gpu_id[-1]
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cudnn = True
    # disable cudnn for dynamic batches
    if rank == 0 and args.max_dynamic_batch > 0:
        cudnn = False

    torch.backends.cudnn.benchmark = cudnn
    agent = make_agent(network, device, env.engine, env.gpu_preprocessor, args)

    # workers
    if rank != 0:
        logger = make_logger(
            'ImpalaWorker{}'.format(rank),
            os.path.join(log_id_dir, 'train_log{}.txt'.format(rank)))
        summary_writer = SummaryWriter(os.path.join(log_id_dir, str(rank)))
        container = ImpalaWorker(agent,
                                 env,
                                 args.nb_env,
                                 logger,
                                 summary_writer,
                                 use_local_buffers=args.use_local_buffers)

        # Run the container
        if args.profile:
            try:
                from pyinstrument import Profiler
            except ImportError:
                raise ImportError(
                    'You must install pyinstrument to use profiling.')
            profiler = Profiler()
            profiler.start()
            container.run()
            profiler.stop()
            print(profiler.output_text(unicode=True, color=True))
        else:
            container.run()
        env.close()
    # host
    else:
        logger = make_logger(
            'ImpalaHost',
            os.path.join(log_id_dir, 'train_log{}.txt'.format(rank)))
        summary_writer = SummaryWriter(os.path.join(log_id_dir, str(rank)))
        log_args(logger, args)
        write_args_file(log_id_dir, args)
        logger.info('Network Parameter Count: {}'.format(
            count_parameters(network)))

        # no need for the env anymore
        env.close()

        # Construct the optimizer
        def make_optimizer(params):
            opt = torch.optim.RMSprop(params,
                                      lr=args.learning_rate,
                                      eps=1e-5,
                                      alpha=0.99)
            return opt

        container = ImpalaHost(agent,
                               comm,
                               make_optimizer,
                               summary_writer,
                               args.summary_frequency,
                               saver,
                               args.epoch_len,
                               args.host_training_info_interval,
                               use_local_buffers=args.use_local_buffers)

        # Run the container
        if args.profile:
            try:
                from pyinstrument import Profiler
            except ImportError:
                raise ImportError(
                    'You must install pyinstrument to use profiling.')
            profiler = Profiler()
            profiler.start()
            if args.max_dynamic_batch > 0:
                container.run(args.max_dynamic_batch,
                              args.max_queue_length,
                              args.max_train_steps,
                              dynamic=True,
                              min_dynamic_batch=args.min_dynamic_batch)
            else:
                container.run(args.num_rollouts_in_batch,
                              args.max_queue_length, args.max_train_steps)
            profiler.stop()
            print(profiler.output_text(unicode=True, color=True))
        else:
            if args.max_dynamic_batch > 0:
                container.run(args.max_dynamic_batch,
                              args.max_queue_length,
                              args.max_train_steps,
                              dynamic=True,
                              min_dynamic_batch=args.min_dynamic_batch)
            else:
                container.run(args.num_rollouts_in_batch,
                              args.max_queue_length, args.max_train_steps)
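
comm and rank are not defined inside main; they are presumably module-level mpi4py globals. Note the two broadcast variants: lowercase bcast pickles an arbitrary Python object (the timestamp), while uppercase Bcast fills a preallocated NumPy buffer in place (the network parameters). A minimal sketch of that setup, assuming mpi4py:

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# object broadcast: every rank receives the pickled value from root
stamp = "2024-01-01_00-00-00" if rank == 0 else None  # hypothetical timestamp
stamp = comm.bcast(stamp, root=0)

# buffer broadcast: non-root ranks have their array overwritten in place
weights = np.arange(8, dtype=np.float32) if rank == 0 else np.zeros(8, dtype=np.float32)
comm.Bcast(weights, root=0)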
Example #6
def main(args):
    # construct logging objects
    print_ascii_logo()
    log_id = make_log_id(args.tag, args.mode_name, args.agent,
                         args.vision_network + args.network_body)
    log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)

    os.makedirs(log_id_dir)
    logger = make_logger('Local', os.path.join(log_id_dir, 'train_log.txt'))
    summary_writer = SummaryWriter(log_id_dir)
    saver = SimpleModelSaver(log_id_dir)

    log_args(logger, args)
    write_args_file(log_id_dir, args)

    # construct env
    env = make_env(args, args.seed)

    # construct network
    torch.manual_seed(args.seed)
    network_head_shapes = get_head_shapes(env.action_space, env.engine,
                                          args.agent)
    network = make_network(env.observation_space, network_head_shapes, args)
    logger.info('Network Parameter Count: {}'.format(
        count_parameters(network)))

    # construct agent
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True
    agent = make_agent(network, device, env.engine, env.gpu_preprocessor, args)

    # Construct the Container
    def make_optimizer(params):
        opt = torch.optim.RMSprop(params,
                                  lr=args.learning_rate,
                                  eps=1e-5,
                                  alpha=0.99)
        return opt

    container = Local(agent, env, make_optimizer, args.epoch_len, args.nb_env,
                      logger, summary_writer, args.summary_frequency, saver)

    # if running an eval thread create eval env, agent, & logger
    if args.nb_eval_env > 0:
        # replace args num envs & seed
        eval_args = deepcopy(args)
        eval_args.seed = args.seed + args.nb_env

        # env and agent
        eval_args.nb_env = args.nb_eval_env
        eval_env = make_env(eval_args, eval_args.seed)
        eval_net = make_network(eval_env.observation_space,
                                network_head_shapes, eval_args)
        eval_agent = make_agent(eval_net, device, eval_env.engine,
                                eval_env.gpu_preprocessor, eval_args)
        eval_net.load_state_dict(network.state_dict())

        # logger
        eval_logger = make_logger('LocalEval',
                                  os.path.join(log_id_dir, 'eval_log.txt'))

        evaluation_container = EvaluationThread(
            network,
            eval_agent,
            eval_env,
            args.nb_eval_env,
            eval_logger,
            summary_writer,
            args.eval_step_rate,
            # wire the local container's step count into the eval thread
            override_step_count_fn=lambda: container.local_step_count
        )
        evaluation_container.start()

    # Run the container
    if args.profile:
        try:
            from pyinstrument import Profiler
        except ImportError:
            raise ImportError(
                'You must install pyinstrument to use profiling.')
        profiler = Profiler()
        profiler.start()
        container.run(10e3)
        profiler.stop()
        print(profiler.output_text(unicode=True, color=True))
    else:
        container.run(args.max_train_steps)
    env.close()

    if args.nb_eval_env > 0:
        evaluation_container.stop()
        eval_env.close()
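
override_step_count_fn is passed as a zero-argument callable rather than a value so the evaluation thread always reads the trainer's current step count instead of a snapshot taken at construction time. A small illustration of the closure behaviour:

class Trainer:
    def __init__(self):
        self.local_step_count = 0

trainer = Trainer()
get_step = lambda: trainer.local_step_count  # re-evaluated on every call

trainer.local_step_count = 128
print(get_step())  # 128, not the 0 seen when the lambda was created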
Example #7
def main(args):
    # host needs to broadcast timestamp so all procs create the same log dir
    if rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        log_id = make_log_id_from_timestamp(
            args.tag, args.mode_name, args.agent,
            args.vision_network + args.network_body, timestamp)
        log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)
        os.makedirs(log_id_dir)
        saver = SimpleModelSaver(log_id_dir)
        print_ascii_logo()
    else:
        timestamp = None
    timestamp = comm.bcast(timestamp, root=0)

    if rank != 0:
        log_id = make_log_id_from_timestamp(
            args.tag, args.mode_name, args.agent,
            args.vision_network + args.network_body, timestamp)
        log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)

    comm.Barrier()

    # construct env
    seed = args.seed if rank == 0 else args.seed + (
        args.nb_env * (rank - 1))  # unique seed per process
    # don't make a ton of envs if host
    if rank == 0:
        env_args = deepcopy(args)
        env_args.nb_env = 1
        env = make_env(env_args, seed)
    else:
        env = make_env(args, seed)

    # construct network
    torch.manual_seed(args.seed)
    network_head_shapes = get_head_shapes(env.action_space, env.engine,
                                          args.agent)
    network = make_network(env.observation_space, network_head_shapes, args)

    # sync network params
    if rank == 0:
        for v in network.parameters():
            comm.Bcast(v.detach().cpu().numpy(), root=0)
        print('Root variables synced')
    else:
        # can just use the numpy buffers
        variables = [v.detach().cpu().numpy() for v in network.parameters()]
        for v in variables:
            comm.Bcast(v, root=0)
        for shared_v, model_v in zip(variables, network.parameters()):
            model_v.data.copy_(torch.from_numpy(shared_v), non_blocking=True)
        print('{} variables synced'.format(rank))

    # host is rank 0
    if rank != 0:
        # construct logger
        logger = make_logger(
            'ToweredWorker{}'.format(rank),
            os.path.join(log_id_dir, 'train_log_rank{}.txt'.format(rank)))
        summary_writer = SummaryWriter(
            os.path.join(log_id_dir, 'rank{}'.format(rank)))

        # construct agent
        # distribute evenly across gpus
        if isinstance(args.gpu_id, list):
            gpu_id = args.gpu_id[(rank - 1) % len(args.gpu_id)]
        else:
            gpu_id = args.gpu_id
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        torch.backends.cudnn.benchmark = True
        agent = make_agent(network, device, env.engine, env.gpu_preprocessor,
                           args)

        # construct container
        container = ToweredWorker(agent, env, args.nb_env, logger,
                                  summary_writer, args.summary_frequency)

        # Run the container
        try:
            container.run()
        finally:
            env.close()
    # host
    else:
        logger = make_logger(
            'ToweredHost',
            os.path.join(log_id_dir, 'train_log_rank{}.txt'.format(rank)))
        log_args(logger, args)
        write_args_file(log_id_dir, args)
        logger.info('Network Parameter Count: {}'.format(
            count_parameters(network)))

        # no need for the env anymore
        env.close()

        # Construct the optimizer
        def make_optimizer(params):
            opt = torch.optim.RMSprop(params,
                                      lr=args.learning_rate,
                                      eps=1e-5,
                                      alpha=0.99)
            return opt

        container = ToweredHost(comm, args.num_grads_to_drop, network,
                                make_optimizer, saver, args.epoch_len, logger)

        # Run the container
        if args.profile:
            try:
                from pyinstrument import Profiler
            except ImportError:
                raise ImportError(
                    'You must install pyinstrument to use profiling.')
            profiler = Profiler()
            profiler.start()
            container.run(10e3)
            profiler.stop()
            print(profiler.output_text(unicode=True, color=True))
        else:
            container.run(args.max_train_steps)
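
Worker ranks are spread round-robin over the listed GPU ids via args.gpu_id[(rank - 1) % len(args.gpu_id)]. For example, with a hypothetical gpu_id list of [0, 1] and three workers:

gpu_ids = [0, 1]  # hypothetical --gpu-id list
for rank in (1, 2, 3):  # worker ranks; rank 0 is the host
    print(rank, gpu_ids[(rank - 1) % len(gpu_ids)])  # 1 -> 0, 2 -> 1, 3 -> 0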
Example #8
    def __init__(self, args, logger, log_id_dir, initial_step_count):
        # ENV
        engine = REGISTRY.lookup_engine(args.env)
        env_cls = REGISTRY.lookup_env(args.env)
        mgr_cls = REGISTRY.lookup_manager(args.manager)
        env_mgr = mgr_cls.from_args(args, engine, env_cls)

        # NETWORK
        torch.manual_seed(args.seed)
        if torch.cuda.is_available() and args.gpu_id >= 0:
            device = torch.device("cuda:{}".format(args.gpu_id))
            torch.backends.cudnn.benchmark = True
        else:
            device = torch.device("cpu")
        output_space = REGISTRY.lookup_output_space(args.agent,
                                                    env_mgr.action_space)
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(
            args,
            env_mgr.gpu_preprocessor.observation_space,
            output_space,
            env_mgr.gpu_preprocessor,
            REGISTRY,
        )
        logger.info("Network parameters: " + str(self.count_parameters(net)))

        def optim_fn(x):
            if args.optim == "RMSprop":
                return torch.optim.RMSprop(x, lr=args.lr, eps=1e-5, alpha=0.99)
            elif args.optim == "Adam":
                return torch.optim.Adam(x, lr=args.lr, eps=1e-5)

        def warmup_schedule(back_step):
            return back_step / args.warmup if back_step < args.warmup else 1.0

        # AGENT
        rwd_norm = REGISTRY.lookup_reward_normalizer(
            args.rwd_norm).from_args(args)
        agent_cls = REGISTRY.lookup_agent(args.agent)
        builder = agent_cls.exp_spec_builder(
            env_mgr.observation_space,
            env_mgr.action_space,
            net.internal_space(),
            env_mgr.nb_env,
        )
        agent = agent_cls.from_args(args, rwd_norm, env_mgr.action_space,
                                    builder)

        self.agent = agent.to(device)
        self.nb_step = args.nb_step
        self.env_mgr = env_mgr
        self.nb_env = args.nb_env
        self.network = net.to(device)
        self.optimizer = optim_fn(self.network.parameters())
        self.scheduler = LambdaLR(self.optimizer, warmup_schedule)
        self.device = device
        self.initial_step_count = initial_step_count
        self.log_id_dir = log_id_dir
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.logger = logger
        self.summary_writer = SummaryWriter(log_id_dir)
        self.saver = SimpleModelSaver(log_id_dir)
        self.updater = LocalUpdater(self.optimizer, self.network,
                                    args.grad_norm_clip)

        if args.load_network:
            self.network = self.load_network(self.network, args.load_network)
            logger.info("Reloaded network from {}".format(args.load_network))
        if args.load_optim:
            self.optimizer = self.load_optim(self.optimizer, args.load_optim)
            logger.info("Reloaded optimizer from {}".format(args.load_optim))

        self.network.train()
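
warmup_schedule is handed to LambdaLR, which multiplies the optimizer's base learning rate by the schedule's return value, so the rate ramps linearly from 0 to args.lr over the first args.warmup scheduler steps and then stays flat. A self-contained sketch with hypothetical values:

import torch
from torch.optim.lr_scheduler import LambdaLR

warmup = 100  # hypothetical args.warmup
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.RMSprop(params, lr=1e-3)
scheduler = LambdaLR(optimizer, lambda s: s / warmup if s < warmup else 1.0)

for _ in range(3):
    optimizer.step()
    scheduler.step()
    # lr climbs 1e-5, 2e-5, 3e-5, ... until it reaches the base 1e-3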
Example #9
class Local(Container):
    def __init__(self, args, logger, log_id_dir, initial_step_count):
        # ENV
        engine = REGISTRY.lookup_engine(args.env)
        env_cls = REGISTRY.lookup_env(args.env)
        mgr_cls = REGISTRY.lookup_manager(args.manager)
        env_mgr = mgr_cls.from_args(args, engine, env_cls)

        # NETWORK
        torch.manual_seed(args.seed)
        if torch.cuda.is_available() and args.gpu_id >= 0:
            device = torch.device("cuda:{}".format(args.gpu_id))
            torch.backends.cudnn.benchmark = True
        else:
            device = torch.device("cpu")
        output_space = REGISTRY.lookup_output_space(args.agent,
                                                    env_mgr.action_space)
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(
            args,
            env_mgr.gpu_preprocessor.observation_space,
            output_space,
            env_mgr.gpu_preprocessor,
            REGISTRY,
        )
        logger.info("Network parameters: " + str(self.count_parameters(net)))

        def optim_fn(x):
            if args.optim == "RMSprop":
                return torch.optim.RMSprop(x, lr=args.lr, eps=1e-5, alpha=0.99)
            elif args.optim == "Adam":
                return torch.optim.Adam(x, lr=args.lr, eps=1e-5)

        def warmup_schedule(back_step):
            return back_step / args.warmup if back_step < args.warmup else 1.0

        # AGENT
        rwd_norm = REGISTRY.lookup_reward_normalizer(
            args.rwd_norm).from_args(args)
        agent_cls = REGISTRY.lookup_agent(args.agent)
        builder = agent_cls.exp_spec_builder(
            env_mgr.observation_space,
            env_mgr.action_space,
            net.internal_space(),
            env_mgr.nb_env,
        )
        agent = agent_cls.from_args(args, rwd_norm, env_mgr.action_space,
                                    builder)

        self.agent = agent.to(device)
        self.nb_step = args.nb_step
        self.env_mgr = env_mgr
        self.nb_env = args.nb_env
        self.network = net.to(device)
        self.optimizer = optim_fn(self.network.parameters())
        self.scheduler = LambdaLR(self.optimizer, warmup_schedule)
        self.device = device
        self.initial_step_count = initial_step_count
        self.log_id_dir = log_id_dir
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.logger = logger
        self.summary_writer = SummaryWriter(log_id_dir)
        self.saver = SimpleModelSaver(log_id_dir)
        self.updater = LocalUpdater(self.optimizer, self.network,
                                    args.grad_norm_clip)

        if args.load_network:
            self.network = self.load_network(self.network, args.load_network)
            logger.info("Reloaded network from {}".format(args.load_network))
        if args.load_optim:
            self.optimizer = self.load_optim(self.optimizer, args.load_optim)
            logger.info("Reloaded optimizer from {}".format(args.load_optim))

        self.network.train()

    def run(self):
        step_count = self.initial_step_count
        next_save = self.init_next_save(self.initial_step_count,
                                        self.epoch_len)
        prev_step_t = time()
        ep_rewards = torch.zeros(self.nb_env)

        obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        internals = listd_to_dlist([
            self.network.new_internals(self.device) for _ in range(self.nb_env)
        ])
        start_time = time()
        while step_count < self.nb_step:
            actions, internals = self.agent.act(self.network, obs, internals)
            next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
            next_obs = dtensor_to_dev(next_obs, self.device)

            self.agent.observe(
                obs,
                rewards.to(self.device).float(),
                terminals.to(self.device).float(),
                infos,
            )

            # Perform state updates
            step_count += self.nb_env
            ep_rewards += rewards.float()
            obs = next_obs

            term_rewards, term_infos = [], []
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(
                            self.device).items():
                        internals[k][i] = v
                    term_rewards.append(ep_rewards[i].item())
                    if infos[i]:
                        term_infos.append(infos[i])
                    ep_rewards[i].zero_()

            if term_rewards:
                term_reward = np.mean(term_rewards)
                delta_t = time() - start_time
                self.logger.info("STEP: {} REWARD: {} STEP/S: {}".format(
                    step_count,
                    term_reward,
                    (step_count - self.initial_step_count) / delta_t,
                ))
                self.summary_writer.add_scalar("reward", term_reward,
                                               step_count)
                if term_infos:
                    float_keys = [
                        k for k, v in term_infos[0].items() if type(v) == float
                    ]
                    term_infos_dlist = listd_to_dlist(term_infos)
                    for k in float_keys:
                        self.summary_writer.add_scalar(
                            f"info/{k}",
                            np.mean(term_infos_dlist[k]),
                            step_count,
                        )

            if step_count >= next_save:
                self.saver.save_state_dicts(self.network, step_count,
                                            self.optimizer)
                next_save += self.epoch_len

            # Learn
            if self.agent.is_ready():
                loss_dict, metric_dict = self.agent.learn_step(
                    self.updater,
                    self.network,
                    next_obs,
                    internals,
                )
                total_loss = sum(loss_dict.values())

                epoch = step_count / self.nb_env
                self.scheduler.step(epoch)

                self.agent.clear()
                for k, vs in internals.items():
                    internals[k] = [v.detach() for v in vs]

                # write summaries
                cur_step_t = time()
                if cur_step_t - prev_step_t > self.summary_freq:
                    self.write_summaries(
                        self.summary_writer,
                        step_count,
                        total_loss,
                        loss_dict,
                        metric_dict,
                        self.network.named_parameters(),
                    )
                    prev_step_t = cur_step_t

    def close(self):
        return self.env_mgr.close()
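
After each learn_step the recurrent internals are detached, so the next backward pass cannot reach into graphs that have already been freed; this is the usual truncated backpropagation-through-time pattern. A minimal illustration, independent of the library:

import torch

w = torch.nn.Parameter(torch.randn(8, 8))
h = torch.zeros(1, 8)

for _ in range(3):
    h = torch.tanh(h @ w)   # one segment of recurrent state updates
    loss = h.pow(2).sum()
    loss.backward()         # gradients flow back only to the last detach point
    h = h.detach()          # truncate: the next backward stops here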