class DistribHost(Container):
    """
    DistribHost saves models and writes summaries. This is the only
    difference from the worker.
    """

    def __init__(
        self,
        args,
        logger,
        log_id_dir,
        initial_step_count,
        local_rank,
        global_rank,
        world_size,
    ):
        seed = (
            args.seed
            if global_rank == 0
            else args.seed + args.nb_env * global_rank
        )
        logger.info("Using {} for rank {} seed.".format(seed, global_rank))

        # ENV
        engine = REGISTRY.lookup_engine(args.env)
        env_cls = REGISTRY.lookup_env(args.env)
        mgr_cls = REGISTRY.lookup_manager(args.manager)
        env_mgr = mgr_cls.from_args(args, engine, env_cls, seed=seed)

        # NETWORK
        torch.manual_seed(args.seed)
        device = torch.device("cuda:{}".format(local_rank))
        output_space = REGISTRY.lookup_output_space(
            args.agent, env_mgr.action_space
        )
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(
            args,
            env_mgr.observation_space,
            output_space,
            env_mgr.gpu_preprocessor,
            REGISTRY,
        )
        logger.info("Network parameters: " + str(self.count_parameters(net)))

        def optim_fn(x):
            return torch.optim.RMSprop(x, lr=args.lr, eps=1e-5, alpha=0.99)

        # AGENT
        rwd_norm = REGISTRY.lookup_reward_normalizer(args.rwd_norm).from_args(
            args
        )
        agent_cls = REGISTRY.lookup_agent(args.agent)
        builder = agent_cls.exp_spec_builder(
            env_mgr.observation_space,
            env_mgr.action_space,
            net.internal_space(),
            env_mgr.nb_env,
        )
        agent = agent_cls.from_args(
            args, rwd_norm, env_mgr.action_space, builder
        )

        self.agent = agent
        self.nb_step = args.nb_step
        self.env_mgr = env_mgr
        self.nb_env = args.nb_env
        self.network = net.to(device)
        self.optimizer = optim_fn(self.network.parameters())
        self.device = device
        self.initial_step_count = initial_step_count
        self.log_id_dir = log_id_dir
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.logger = logger
        self.summary_writer = SummaryWriter(
            os.path.join(log_id_dir, "rank{}".format(global_rank))
        )
        self.saver = SimpleModelSaver(log_id_dir)
        self.local_rank = local_rank
        self.global_rank = global_rank
        self.world_size = world_size

        self.updater = DistribUpdater(
            self.optimizer,
            self.network,
            args.grad_norm_clip,
            world_size,
            not args.no_divide,
        )

        if args.load_network:
            self.network = self.load_network(self.network, args.load_network)
            logger.info("Reloaded network from {}".format(args.load_network))
        if args.load_optim:
            self.optimizer = self.load_optim(self.optimizer, args.load_optim)
            logger.info("Reloaded optimizer from {}".format(args.load_optim))

        self.network.train()

    def run(self):
        local_step_count = global_step_count = self.initial_step_count
        next_save = self.init_next_save(self.initial_step_count, self.epoch_len)
        prev_step_t = time()
        ep_rewards = torch.zeros(self.nb_env)

        obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        internals = listd_to_dlist(
            [
                self.network.new_internals(self.device)
                for _ in range(self.nb_env)
            ]
        )
        start_time = time()
        while global_step_count < self.nb_step:
            actions, internals = self.agent.act(self.network, obs, internals)
            next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
            next_obs = dtensor_to_dev(next_obs, self.device)

            self.agent.observe(
                obs,
                rewards.to(self.device).float(),
                terminals.to(self.device).float(),
                infos,
            )

            # Perform state updates
            local_step_count += self.nb_env
            global_step_count += self.nb_env * self.world_size
            ep_rewards += rewards.float()
            obs = next_obs

            # Reset internals for terminal envs and collect episode rewards
            term_rewards = []
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v
                    term_rewards.append(ep_rewards[i].item())
                    ep_rewards[i].zero_()

            if term_rewards:
                term_reward = np.mean(term_rewards)
                delta_t = time() - start_time
                self.logger.info(
                    "RANK: {} "
                    "GLOBAL STEP: {} "
                    "REWARD: {} "
                    "GLOBAL STEP/S: {} "
                    "LOCAL STEP/S: {}".format(
                        self.global_rank,
                        global_step_count,
                        term_reward,
                        (global_step_count - self.initial_step_count) / delta_t,
                        (local_step_count - self.initial_step_count) / delta_t,
                    )
                )
                self.summary_writer.add_scalar(
                    "reward", term_reward, global_step_count
                )

            if global_step_count >= next_save:
                self.saver.save_state_dicts(
                    self.network, global_step_count, self.optimizer
                )
                next_save += self.epoch_len

            # Learn
            if self.agent.is_ready():
                loss_dict, metric_dict = self.agent.learn_step(
                    self.updater, self.network, next_obs, internals
                )
                total_loss = torch.sum(
                    torch.stack(tuple(loss for loss in loss_dict.values()))
                )

                self.agent.clear()
                for k, vs in internals.items():
                    internals[k] = [v.detach() for v in vs]

                # write summaries
                cur_step_t = time()
                if cur_step_t - prev_step_t > self.summary_freq:
                    self.write_summaries(
                        self.summary_writer,
                        global_step_count,
                        total_loss,
                        loss_dict,
                        metric_dict,
                        self.network.named_parameters(),
                    )
                    prev_step_t = cur_step_t

    def close(self):
        return self.env_mgr.close()
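A minimal launch sketch for DistribHost, assuming a torch.distributed-style launcher that exports LOCAL_RANK, RANK, and WORLD_SIZE; the `args`, `logger`, and `log_id_dir` objects are assumed to be built by the calling script and are not shown here.

# Hypothetical launch snippet; per the docstring, only the host rank would
# construct DistribHost (other ranks presumably run the worker container).
import os

local_rank = int(os.environ["LOCAL_RANK"])   # set by the distributed launcher
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

host = DistribHost(
    args, logger, log_id_dir, initial_step_count=0,
    local_rank=local_rank, global_rank=global_rank, world_size=world_size,
)
try:
    host.run()
finally:
    host.close()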
class ActorLearnerHost(Container):
    @classmethod
    def as_remote(
        cls,
        num_cpus=None,
        num_gpus=None,
        memory=None,
        object_store_memory=None,
        resources=None,
    ):
        return ray.remote(
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            memory=memory,
            object_store_memory=object_store_memory,
            resources=resources,
        )(cls)

    def __init__(
        self,
        args,
        log_id_dir,
        initial_step_count,
        rank=0,
    ):
        # ARGS TO STATE VARS
        self._args = args
        self.nb_learners = args.nb_learners
        self.nb_workers = args.nb_workers
        self.rank = rank
        self.nb_step = args.nb_step
        self.nb_env = args.nb_env
        self.initial_step_count = initial_step_count
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.nb_learn_batch = args.nb_learn_batch
        self.rollout_queue_size = args.rollout_queue_size
        # can be none if rank != 0
        self.log_id_dir = log_id_dir

        # load saved registry classes
        REGISTRY.load_extern_classes(log_id_dir)

        # ENV (temporary)
        env_cls = REGISTRY.lookup_env(args.env)
        env = env_cls.from_args(args, 0)
        env_action_space, env_observation_space, env_gpu_preprocessor = \
            env.action_space, env.observation_space, env.gpu_preprocessor
        env.close()

        # NETWORK
        torch.manual_seed(args.seed)
        device = torch.device("cuda")  # ray handles gpus
        torch.backends.cudnn.benchmark = True
        output_space = REGISTRY.lookup_output_space(
            args.actor_worker, env_action_space)
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(
            args,
            env_observation_space,
            output_space,
            env_gpu_preprocessor,
            REGISTRY
        )
        self.network = net.to(device)
        # TODO: this is a hack, remove once queuer puts rollouts on the correct device
        self.network.device = device
        self.device = device
        self.network.train()

        # OPTIMIZER
        def optim_fn(x):
            return torch.optim.RMSprop(x, lr=args.lr, eps=1e-5, alpha=0.99)

        if args.nb_learners > 1:
            self.optimizer = NCCLOptimizer(optim_fn, self.network, self.nb_learners)
        else:
            self.optimizer = optim_fn(self.network.parameters())

        # LEARNER / EXP
        rwd_norm = REGISTRY.lookup_reward_normalizer(
            args.rwd_norm).from_args(args)
        actor_cls = REGISTRY.lookup_actor(args.actor_host)
        builder = actor_cls.exp_spec_builder(
            env_observation_space,
            env_action_space,
            net.internal_space(),
            args.nb_env * args.nb_learn_batch
        )
        w_builder = REGISTRY.lookup_actor(args.actor_worker).exp_spec_builder(
            env_observation_space,
            env_action_space,
            net.internal_space(),
            args.nb_env
        )

        actor = actor_cls.from_args(args, env_action_space)
        learner = REGISTRY.lookup_learner(args.learner).from_args(args, rwd_norm)
        exp_cls = REGISTRY.lookup_exp(args.exp)

        self.actor = actor
        self.learner = learner
        self.exp = exp_cls.from_args(args, builder).to(device)

        # Rank 0 setup, load network/optimizer and create SummaryWriter/Saver
        if rank == 0:
            if args.load_network:
                self.network = self.load_network(self.network, args.load_network)
                print('Reloaded network from {}'.format(args.load_network))
            if args.load_optim:
                self.optimizer = self.load_optim(self.optimizer, args.load_optim)
                print('Reloaded optimizer from {}'.format(args.load_optim))

            print('Network parameters: ' + str(self.count_parameters(net)))
            self.summary_writer = SummaryWriter(log_id_dir)
            self.saver = SimpleModelSaver(log_id_dir)

    def run(self, workers, profile=False):
        if profile:
            try:
                from pyinstrument import Profiler
            except ImportError:
                raise ImportError('You must install pyinstrument to use profiling.')
            profiler = Profiler()
            profiler.start()

        # setup queuer
        rollout_queuer = RolloutQueuerAsync(workers, self.nb_learn_batch,
                                            self.rollout_queue_size)
        rollout_queuer.start()

        # initial setup
        global_step_count = self.initial_step_count
        next_save = self.init_next_save(self.initial_step_count, self.epoch_len)
        prev_step_t = time()
        ep_rewards = torch.zeros(self.nb_env)
        start_time = time()

        # loop until total number steps
        print('{} starting training'.format(self.rank))
        while not self.done(global_step_count):
            self.exp.clear()
            # Get batch from queue
            rollouts, terminal_rewards, terminal_infos = rollout_queuer.get()

            # Iterate forward on batch
            self.exp.write_exps(rollouts)
            # keep a copy of terminals on the cpu it's faster
            rollout_terminals = torch.stack(self.exp['terminals']).numpy()
            self.exp.to(self.device)
            r = self.exp.read()
            internals = {k: ts[0].unbind(0) for k, ts in r.internals.items()}
            for obs, rewards, terminals in zip(
                r.observations, r.rewards, rollout_terminals
            ):
                _, h_exp, internals = self.actor.act(self.network, obs, internals)
                self.exp.write_actor(h_exp, no_env=True)

                # where returns a single element tuple with the indexes
                terminal_inds = np.where(terminals)[0]
                for i in terminal_inds:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v

            # compute loss
            loss_dict, metric_dict = self.learner.compute_loss(
                self.network, self.exp.read(), r.next_observation, internals
            )
            total_loss = torch.sum(
                torch.stack(tuple(loss for loss in loss_dict.values()))
            )

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # Perform state updates
            global_step_count += (self.nb_env * self.nb_learn_batch
                                  * len(r.terminals) * self.nb_learners)

            # if rank 0 write summaries and save
            # and send parameters to workers async
            if self.rank == 0:
                # TODO: this could be parallelized, chunk by nb learners
                self.synchronize_worker_parameters(workers, global_step_count)

                # possible save
                if global_step_count >= next_save:
                    self.saver.save_state_dicts(
                        self.network, global_step_count, self.optimizer
                    )
                    next_save += self.epoch_len

                # write reward summaries
                if any(terminal_rewards):
                    terminal_rewards = list(
                        filter(lambda x: x is not None, terminal_rewards))
                    self.summary_writer.add_scalar(
                        'reward', np.mean(terminal_rewards), global_step_count
                    )

                # write infos
                if any(terminal_infos):
                    terminal_infos = list(
                        filter(lambda x: x is not None, terminal_infos))
                    float_keys = [
                        k for k, v in terminal_infos[0].items() if type(v) == float
                    ]
                    terminal_infos_dlist = listd_to_dlist(terminal_infos)
                    for k in float_keys:
                        self.summary_writer.add_scalar(
                            f'info/{k}',
                            np.mean(terminal_infos_dlist[k]),
                            global_step_count
                        )

            # write summaries
            cur_step_t = time()
            if cur_step_t - prev_step_t > self.summary_freq:
                print('Rank {} Metrics:'.format(self.rank), rollout_queuer.metrics())
                if self.rank == 0:
                    self.write_summaries(
                        self.summary_writer, global_step_count, total_loss,
                        loss_dict, metric_dict, self.network.named_parameters()
                    )
                prev_step_t = cur_step_t

        rollout_queuer.close()
        print('{} stopped training'.format(self.rank))
        if profile:
            profiler.stop()
            print(profiler.output_text(unicode=True, color=True))

    def done(self, global_step_count):
        return global_step_count >= self.nb_step

    def close(self):
        pass

    def get_parameters(self):
        params = [p.cpu() for p in self.network.parameters()]
        return params

    def synchronize_worker_parameters(self, workers, global_step_count=0,
                                      blocking=False):
        parameters = self.get_parameters()
        futures = [w.set_weights.remote(parameters) for w in workers]

        if global_step_count != 0:
            futures.extend(
                [w.set_global_step.remote(global_step_count) for w in workers])

        if blocking:
            ray.get(futures)

    def _rank0_nccl_port_init(self):
        ip = ray.services.get_node_ip_address()
        port = find_free_port()
        nccl_addr = "tcp://{ip}:{port}".format(ip=ip, port=port)
        return nccl_addr, ip, port

    def _nccl_init(self, nccl_addr, nccl_ip, nccl_port):
        self.nccl_ip, self.nccl_addr, self.nccl_port = nccl_ip, nccl_addr, nccl_port
        print('Rank {} calling init_process_group. Addr: {}'.format(self.rank, nccl_addr))
        # from https://github.com/pytorch/pytorch/blob/master/test/simulate_nccl_errors.py
        store = dist.TCPStore(self.nccl_ip, self.nccl_port, self.nb_learners,
                              self.rank == 0)
        process_group = dist.ProcessGroupNCCL(store, self.rank, self.nb_learners)
        print('Rank {} initialized process group.'.format(self.rank))
        process_group.barrier()
        print('Rank {} process group barrier finished.'.format(self.rank))
        self.process_group = process_group
        # set optimizer process_group
        self.optimizer.set_process_group(self.process_group)

    def _sync_peer_parameters(self):
        print('Rank {} syncing parameters.'.format(self.rank))
        self.process_group.barrier()
        for p in self.network.parameters():
            self.process_group.allreduce(p.data)
            p.data = p.data / self.nb_learners
        print('Rank {} parameters synced.'.format(self.rank))
def main(args):
    # host needs to broadcast timestamp so all procs create the same log dir
    if rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        log_id = make_log_id_from_timestamp(
            args.tag, args.mode_name, args.agent,
            args.vision_network + args.network_body, timestamp)
        log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)

        os.makedirs(log_id_dir)
        saver = SimpleModelSaver(log_id_dir)
        print_ascii_logo()
    else:
        timestamp = None

    timestamp = comm.bcast(timestamp, root=0)

    if rank != 0:
        log_id = make_log_id_from_timestamp(
            args.tag, args.mode_name, args.agent,
            args.vision_network + args.network_body, timestamp)
        log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)

    comm.Barrier()

    # construct env
    seed = args.seed if rank == 0 else args.seed + (
        args.nb_env * (rank - 1))  # unique seed per process
    env = make_env(args, seed)

    # construct network
    torch.manual_seed(args.seed)
    network_head_shapes = get_head_shapes(env.action_space, env.engine, args.agent)
    network = make_network(env.observation_space, network_head_shapes, args)

    # sync network params
    if rank == 0:
        for v in network.parameters():
            comm.Bcast(v.detach().cpu().numpy(), root=0)
        print('Root variables synced')
    else:
        # can just use the numpy buffers
        variables = [v.detach().cpu().numpy() for v in network.parameters()]
        for v in variables:
            comm.Bcast(v, root=0)
        for shared_v, model_v in zip(variables, network.parameters()):
            model_v.data.copy_(torch.from_numpy(shared_v), non_blocking=True)
        print('{} variables synced'.format(rank))

    # construct agent
    # host is always the first gpu, workers are distributed evenly across the rest
    if len(args.gpu_id) > 1:  # nargs is always a list
        if rank == 0:
            gpu_id = args.gpu_id[0]
        else:
            gpu_id = args.gpu_id[1:][(rank - 1) % len(args.gpu_id[1:])]
    else:
        gpu_id = args.gpu_id[-1]
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cudnn = True
    # disable cudnn for dynamic batches
    if rank == 0 and args.max_dynamic_batch > 0:
        cudnn = False

    torch.backends.cudnn.benchmark = cudnn
    agent = make_agent(network, device, env.engine, env.gpu_preprocessor, args)

    # workers
    if rank != 0:
        logger = make_logger(
            'ImpalaWorker{}'.format(rank),
            os.path.join(log_id_dir, 'train_log{}.txt'.format(rank)))
        summary_writer = SummaryWriter(os.path.join(log_id_dir, str(rank)))

        container = ImpalaWorker(agent, env, args.nb_env, logger, summary_writer,
                                 use_local_buffers=args.use_local_buffers)

        # Run the container
        if args.profile:
            try:
                from pyinstrument import Profiler
            except ImportError:
                raise ImportError('You must install pyinstrument to use profiling.')
            profiler = Profiler()
            profiler.start()
            container.run()
            profiler.stop()
            print(profiler.output_text(unicode=True, color=True))
        else:
            container.run()
        env.close()
    # host
    else:
        logger = make_logger(
            'ImpalaHost',
            os.path.join(log_id_dir, 'train_log{}.txt'.format(rank)))
        summary_writer = SummaryWriter(os.path.join(log_id_dir, str(rank)))
        log_args(logger, args)
        write_args_file(log_id_dir, args)
        logger.info('Network Parameter Count: {}'.format(
            count_parameters(network)))

        # no need for the env anymore
        env.close()

        # Construct the optimizer
        def make_optimizer(params):
            opt = torch.optim.RMSprop(params, lr=args.learning_rate,
                                      eps=1e-5, alpha=0.99)
            return opt

        container = ImpalaHost(agent, comm, make_optimizer, summary_writer,
                               args.summary_frequency, saver, args.epoch_len,
                               args.host_training_info_interval,
                               use_local_buffers=args.use_local_buffers)

        # Run the container
        if args.profile:
            try:
                from pyinstrument import Profiler
            except ImportError:
                raise ImportError('You must install pyinstrument to use profiling.')
            profiler = Profiler()
            profiler.start()
            if args.max_dynamic_batch > 0:
                container.run(args.max_dynamic_batch, args.max_queue_length,
                              args.max_train_steps, dynamic=True,
                              min_dynamic_batch=args.min_dynamic_batch)
            else:
                container.run(args.num_rollouts_in_batch, args.max_queue_length,
                              args.max_train_steps)
            profiler.stop()
            print(profiler.output_text(unicode=True, color=True))
        else:
            if args.max_dynamic_batch > 0:
                container.run(args.max_dynamic_batch, args.max_queue_length,
                              args.max_train_steps, dynamic=True,
                              min_dynamic_batch=args.min_dynamic_batch)
            else:
                container.run(args.num_rollouts_in_batch, args.max_queue_length,
                              args.max_train_steps)
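The MPI-based entry points above reference module-level `comm` and `rank` objects that are not defined in this listing. With mpi4py, which provides the `bcast`, `Bcast`, and `Barrier` calls used here, they would typically be created at module level, roughly as in this sketch.

# Hedged sketch: module-level MPI handles assumed by the MPI mains.
from mpi4py import MPI

comm = MPI.COMM_WORLD   # world communicator shared by host and workers
rank = comm.Get_rank()  # 0 = host, >0 = workers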
def main(args):
    # construct logging objects
    print_ascii_logo()
    log_id = make_log_id(args.tag, args.mode_name, args.agent,
                         args.vision_network + args.network_body)
    log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)

    os.makedirs(log_id_dir)
    logger = make_logger('Local', os.path.join(log_id_dir, 'train_log.txt'))
    summary_writer = SummaryWriter(log_id_dir)
    saver = SimpleModelSaver(log_id_dir)
    log_args(logger, args)
    write_args_file(log_id_dir, args)

    # construct env
    env = make_env(args, args.seed)

    # construct network
    torch.manual_seed(args.seed)
    network_head_shapes = get_head_shapes(env.action_space, env.engine, args.agent)
    network = make_network(env.observation_space, network_head_shapes, args)
    logger.info('Network Parameter Count: {}'.format(count_parameters(network)))

    # construct agent
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True
    agent = make_agent(network, device, env.engine, env.gpu_preprocessor, args)

    # Construct the Container
    def make_optimizer(params):
        opt = torch.optim.RMSprop(params, lr=args.learning_rate,
                                  eps=1e-5, alpha=0.99)
        return opt

    container = Local(agent, env, make_optimizer, args.epoch_len, args.nb_env,
                      logger, summary_writer, args.summary_frequency, saver)

    # if running an eval thread create eval env, agent, & logger
    if args.nb_eval_env > 0:
        # replace args num envs & seed
        eval_args = deepcopy(args)
        eval_args.seed = args.seed + args.nb_env

        # env and agent
        eval_args.nb_env = args.nb_eval_env
        eval_env = make_env(eval_args, eval_args.seed)
        eval_net = make_network(eval_env.observation_space,
                                network_head_shapes, eval_args)
        eval_agent = make_agent(eval_net, device, eval_env.engine,
                                eval_env.gpu_preprocessor, eval_args)
        eval_net.load_state_dict(network.state_dict())

        # logger
        eval_logger = make_logger('LocalEval',
                                  os.path.join(log_id_dir, 'eval_log.txt'))

        evaluation_container = EvaluationThread(
            network,
            eval_agent,
            eval_env,
            args.nb_eval_env,
            eval_logger,
            summary_writer,
            args.eval_step_rate,
            # wire local containers step count into eval
            override_step_count_fn=lambda: container.local_step_count
        )
        evaluation_container.start()

    # Run the container
    if args.profile:
        try:
            from pyinstrument import Profiler
        except ImportError:
            raise ImportError('You must install pyinstrument to use profiling.')
        profiler = Profiler()
        profiler.start()
        container.run(10e3)
        profiler.stop()
        print(profiler.output_text(unicode=True, color=True))
    else:
        container.run(args.max_train_steps)
    env.close()

    if args.nb_eval_env > 0:
        evaluation_container.stop()
        eval_env.close()
def main(args):
    # host needs to broadcast timestamp so all procs create the same log dir
    if rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        log_id = make_log_id_from_timestamp(
            args.tag, args.mode_name, args.agent,
            args.vision_network + args.network_body, timestamp)
        log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)

        os.makedirs(log_id_dir)
        saver = SimpleModelSaver(log_id_dir)
        print_ascii_logo()
    else:
        timestamp = None

    timestamp = comm.bcast(timestamp, root=0)

    if rank != 0:
        log_id = make_log_id_from_timestamp(
            args.tag, args.mode_name, args.agent,
            args.vision_network + args.network_body, timestamp)
        log_id_dir = os.path.join(args.log_dir, args.env_id, log_id)

    comm.Barrier()

    # construct env
    seed = args.seed if rank == 0 else args.seed + (
        args.nb_env * (rank - 1))  # unique seed per process
    # don't make a ton of envs if host
    if rank == 0:
        env_args = deepcopy(args)
        env_args.nb_env = 1
        env = make_env(env_args, seed)
    else:
        env = make_env(args, seed)

    # construct network
    torch.manual_seed(args.seed)
    network_head_shapes = get_head_shapes(env.action_space, env.engine, args.agent)
    network = make_network(env.observation_space, network_head_shapes, args)

    # sync network params
    if rank == 0:
        for v in network.parameters():
            comm.Bcast(v.detach().cpu().numpy(), root=0)
        print('Root variables synced')
    else:
        # can just use the numpy buffers
        variables = [v.detach().cpu().numpy() for v in network.parameters()]
        for v in variables:
            comm.Bcast(v, root=0)
        for shared_v, model_v in zip(variables, network.parameters()):
            model_v.data.copy_(torch.from_numpy(shared_v), non_blocking=True)
        print('{} variables synced'.format(rank))

    # host is rank 0
    if rank != 0:
        # construct logger
        logger = make_logger(
            'ToweredWorker{}'.format(rank),
            os.path.join(log_id_dir, 'train_log_rank{}.txt'.format(rank)))
        summary_writer = SummaryWriter(
            os.path.join(log_id_dir, 'rank{}'.format(rank)))

        # construct agent
        # distribute evenly across gpus
        if isinstance(args.gpu_id, list):
            gpu_id = args.gpu_id[(rank - 1) % len(args.gpu_id)]
        else:
            gpu_id = args.gpu_id
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        torch.backends.cudnn.benchmark = True
        agent = make_agent(network, device, env.engine, env.gpu_preprocessor, args)

        # construct container
        container = ToweredWorker(agent, env, args.nb_env, logger,
                                  summary_writer, args.summary_frequency)

        # Run the container
        try:
            container.run()
        finally:
            env.close()
    # host
    else:
        logger = make_logger(
            'ToweredHost',
            os.path.join(log_id_dir, 'train_log_rank{}.txt'.format(rank)))
        log_args(logger, args)
        write_args_file(log_id_dir, args)
        logger.info('Network Parameter Count: {}'.format(
            count_parameters(network)))

        # no need for the env anymore
        env.close()

        # Construct the optimizer
        def make_optimizer(params):
            opt = torch.optim.RMSprop(params, lr=args.learning_rate,
                                      eps=1e-5, alpha=0.99)
            return opt

        container = ToweredHost(comm, args.num_grads_to_drop, network,
                                make_optimizer, saver, args.epoch_len, logger)

        # Run the container
        if args.profile:
            try:
                from pyinstrument import Profiler
            except ImportError:
                raise ImportError('You must install pyinstrument to use profiling.')
            profiler = Profiler()
            profiler.start()
            container.run(10e3)
            profiler.stop()
            print(profiler.output_text(unicode=True, color=True))
        else:
            container.run(args.max_train_steps)
class Local(Container):
    def __init__(self, args, logger, log_id_dir, initial_step_count):
        # ENV
        engine = REGISTRY.lookup_engine(args.env)
        env_cls = REGISTRY.lookup_env(args.env)
        mgr_cls = REGISTRY.lookup_manager(args.manager)
        env_mgr = mgr_cls.from_args(args, engine, env_cls)

        # NETWORK
        torch.manual_seed(args.seed)
        if torch.cuda.is_available() and args.gpu_id >= 0:
            device = torch.device("cuda:{}".format(args.gpu_id))
            torch.backends.cudnn.benchmark = True
        else:
            device = torch.device("cpu")
        output_space = REGISTRY.lookup_output_space(args.agent, env_mgr.action_space)
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(
            args,
            env_mgr.gpu_preprocessor.observation_space,
            output_space,
            env_mgr.gpu_preprocessor,
            REGISTRY,
        )
        logger.info("Network parameters: " + str(self.count_parameters(net)))

        def optim_fn(x):
            if args.optim == "RMSprop":
                return torch.optim.RMSprop(x, lr=args.lr, eps=1e-5, alpha=0.99)
            elif args.optim == "Adam":
                return torch.optim.Adam(x, lr=args.lr, eps=1e-5)

        def warmup_schedule(back_step):
            return back_step / args.warmup if back_step < args.warmup else 1.0

        # AGENT
        rwd_norm = REGISTRY.lookup_reward_normalizer(
            args.rwd_norm).from_args(args)
        agent_cls = REGISTRY.lookup_agent(args.agent)
        builder = agent_cls.exp_spec_builder(
            env_mgr.observation_space,
            env_mgr.action_space,
            net.internal_space(),
            env_mgr.nb_env,
        )
        agent = agent_cls.from_args(args, rwd_norm, env_mgr.action_space, builder)

        self.agent = agent.to(device)
        self.nb_step = args.nb_step
        self.env_mgr = env_mgr
        self.nb_env = args.nb_env
        self.network = net.to(device)
        self.optimizer = optim_fn(self.network.parameters())
        self.scheduler = LambdaLR(self.optimizer, warmup_schedule)
        self.device = device
        self.initial_step_count = initial_step_count
        self.log_id_dir = log_id_dir
        self.epoch_len = args.epoch_len
        self.summary_freq = args.summary_freq
        self.logger = logger
        self.summary_writer = SummaryWriter(log_id_dir)
        self.saver = SimpleModelSaver(log_id_dir)

        self.updater = LocalUpdater(self.optimizer, self.network,
                                    args.grad_norm_clip)

        if args.load_network:
            self.network = self.load_network(self.network, args.load_network)
            logger.info("Reloaded network from {}".format(args.load_network))
        if args.load_optim:
            self.optimizer = self.load_optim(self.optimizer, args.load_optim)
            logger.info("Reloaded optimizer from {}".format(args.load_optim))

        self.network.train()

    def run(self):
        step_count = self.initial_step_count
        next_save = self.init_next_save(self.initial_step_count, self.epoch_len)
        prev_step_t = time()
        ep_rewards = torch.zeros(self.nb_env)

        obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        internals = listd_to_dlist([
            self.network.new_internals(self.device) for _ in range(self.nb_env)
        ])
        start_time = time()
        while step_count < self.nb_step:
            actions, internals = self.agent.act(self.network, obs, internals)
            next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
            next_obs = dtensor_to_dev(next_obs, self.device)

            self.agent.observe(
                obs,
                rewards.to(self.device).float(),
                terminals.to(self.device).float(),
                infos,
            )

            # Perform state updates
            step_count += self.nb_env
            ep_rewards += rewards.float()
            obs = next_obs

            term_rewards, term_infos = [], []
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v
                    term_rewards.append(ep_rewards[i].item())
                    if infos[i]:
                        term_infos.append(infos[i])
                    ep_rewards[i].zero_()

            if term_rewards:
                term_reward = np.mean(term_rewards)
                delta_t = time() - start_time
                self.logger.info("STEP: {} REWARD: {} STEP/S: {}".format(
                    step_count,
                    term_reward,
                    (step_count - self.initial_step_count) / delta_t,
                ))
                self.summary_writer.add_scalar("reward", term_reward, step_count)

            if term_infos:
                float_keys = [
                    k for k, v in term_infos[0].items() if type(v) == float
                ]
                term_infos_dlist = listd_to_dlist(term_infos)
                for k in float_keys:
                    self.summary_writer.add_scalar(
                        f"info/{k}",
                        np.mean(term_infos_dlist[k]),
                        step_count,
                    )

            if step_count >= next_save:
                self.saver.save_state_dicts(self.network, step_count,
                                            self.optimizer)
                next_save += self.epoch_len

            # Learn
            if self.agent.is_ready():
                loss_dict, metric_dict = self.agent.learn_step(
                    self.updater,
                    self.network,
                    next_obs,
                    internals,
                )
                total_loss = sum(loss_dict.values())
                epoch = step_count / self.nb_env
                self.scheduler.step(epoch)

                self.agent.clear()
                for k, vs in internals.items():
                    internals[k] = [v.detach() for v in vs]

                # write summaries
                cur_step_t = time()
                if cur_step_t - prev_step_t > self.summary_freq:
                    self.write_summaries(
                        self.summary_writer,
                        step_count,
                        total_loss,
                        loss_dict,
                        metric_dict,
                        self.network.named_parameters(),
                    )
                    prev_step_t = cur_step_t

    def close(self):
        return self.env_mgr.close()