def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if flags.num_buffers is None:  # Set sensible default for num_buffers.
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    gym_env = environment.create_gym_env(flags, seed=flags.seed)
    env = environment.Environment(flags, gym_env)
    env.initial()

    kg_model = load_kg_model_for_env(flags, gym_env)
    model = model_for_env(flags, gym_env, kg_model)
    buffers = create_buffers(flags, gym_env.observation_space, model.num_actions)

    cumulative_steps = torch.zeros(1, dtype=int).share_memory_()

    model.share_memory()

    ctx = mp.get_context("fork")

    tester_processes = []
    if flags.test_interval > 0:
        splits = ["test", "train"]
        for split in splits:
            tester = ctx.Process(
                target=test,
                args=(flags, env, cumulative_steps, model, split, plogger),
            )
            tester_processes.append(tester)
            tester.start()

    # Add initial RNN state.
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    actor_processes = []
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(
                flags,
                env,
                i,
                free_queue,
                full_queue,
                model,
                buffers,
                initial_agent_state_buffers,
            ),
        )
        actor.start()
        actor_processes.append(actor)

    learner_model = model_for_env(flags, gym_env, kg_model).to(device=flags.device)

    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
    ]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    step, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats, cumulative_steps
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            stats = learn(
                flags, model, learner_model, batch, agent_state, optimizer, scheduler
            )
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T * B
                cumulative_steps[0] = step

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())

    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(
            target=batch_and_learn, name="batch-and-learn-%d" % i, args=(i,)
        )
        thread.start()
        threads.append(thread)

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "flags": vars(flags),
            },
            checkpointpath,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        while step < flags.total_steps:
            start_step = step
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            sps = (step - start_step) / (timer() - start_time)
            if stats.get("episode_returns", None):
                mean_return = (
                    "Return per episode: %.1f. " % stats["mean_episode_return"]
                )
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
                step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)
        for tester in tester_processes:
            tester.join(timeout=1)

    checkpoint()
    plogger.close()
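

# A minimal, hypothetical driver for the train() function above: build the flags
# namespace it expects and invoke it. All values below are illustrative
# assumptions (the real entry point presumably parses them with argparse), and
# the environment/model builders will read additional, project-specific flags
# that are not listed here.
def example_run():
    from types import SimpleNamespace

    flags = SimpleNamespace(
        xpid=None,                 # auto-generated timestamped id when None
        savedir="~/logs/torchbeast",
        seed=1,
        disable_cuda=False,
        disable_checkpoint=False,
        num_actors=4,
        num_buffers=None,          # defaults to max(2 * num_actors, batch_size)
        num_learner_threads=2,
        batch_size=8,
        unroll_length=80,
        total_steps=1_000_000,
        test_interval=0,           # > 0 launches test/train tester processes
        learning_rate=0.0006,
        momentum=0.0,
        epsilon=0.01,
        alpha=0.99,
    )
    train(flags)
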
def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    if flags.xpid is None:
        flags.xpid = 'torchbeast-%s' % time.strftime('%Y%m%d-%H%M%S')
    plogger = file_writer.FileWriter(
        xpid=flags.xpid,
        xp_args=flags.__dict__,
        rootdir=flags.savedir,
        symlink_latest=False,
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 'model.tar')))

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info('Using CUDA.')
        flags.device = torch.device('cuda')
    else:
        logging.info('Not using CUDA.')
        flags.device = torch.device('cpu')

    env = Net.create_env(flags)

    model = Net.make(flags, env)
    buffers = create_buffers(env.observation_space, len(env.action_space), flags)

    model.share_memory()

    actor_processes = []
    ctx = mp.get_context('fork')
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(i, free_queue, full_queue, model, buffers, flags))
        actor.start()
        actor_processes.append(actor)

    learner_model = Net.make(flags, env).to(device=flags.device)

    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha)

    def lr_lambda(epoch):
        return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    if flags.resume:
        save = torch.load(flags.resume, map_location='cpu')
        learner_model.load_state_dict(save['model_state_dict'])
        optimizer.load_state_dict(save['optimizer_state_dict'])
        if flags.resume_scheduler:
            scheduler.load_state_dict(save['scheduler_state_dict'])
        # tune only the embedding layer
        if flags.resume_strategy == 'emb':
            keep = []
            for group in optimizer.param_groups:
                if group['params'][0].size() == (len(learner_model.vocab), flags.demb):
                    keep.append(group)
            optimizer.param_groups = keep

    logger = logging.getLogger('logfile')
    stat_keys = [
        'total_loss',
        'mean_episode_return',
        'pg_loss',
        'baseline_loss',
        'entropy_loss',
        'aux_loss',
        'mean_win_rate',
        'mean_episode_len',
    ]
    logger.info('# Step\t%s', '\t'.join(stat_keys))

    frames, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal frames, stats
        timings = prof.Timings()
        while frames < flags.total_frames:
            timings.reset()
            batch = get_batch(free_queue, full_queue, buffers, flags, timings)
            stats = learn(model, learner_model, batch, optimizer, scheduler, flags)
            timings.time('learn')
            with lock:
                to_log = dict(frames=frames)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                frames += T * B

        if i == 0:
            logging.info('Batch and learn: %s', timings.summary())

    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_threads):
        thread = threading.Thread(
            target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,))
        thread.start()
        threads.append(thread)

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info('Saving checkpoint to %s', checkpointpath)
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'flags': vars(flags),
        }, checkpointpath)

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        while frames < flags.total_frames:
            start_frames = frames
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            fps = (frames - start_frames) / (timer() - start_time)
            if stats.get('episode_returns', None):
                mean_return = 'Return per episode: %.1f. ' % stats[
                    'mean_episode_return']
            else:
                mean_return = ''
            total_loss = stats.get('total_loss', float('inf'))
            logging.info('After %i frames: loss %f @ %.1f fps. %sStats:\n%s',
                         frames, total_loss, fps, mean_return,
                         pprint.pformat(stats))
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info('Learning finished after %d frames.', frames)
    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)

    checkpoint()
    plogger.close()
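

# Standalone sketch of the linear decay implemented by lr_lambda above: LambdaLR
# multiplies the optimizer's base learning rate by the factor returned for the
# current scheduler step, so the rate anneals from learning_rate to zero as the
# processed frames approach total_frames. The numbers here are illustrative only.
def _lr_decay_example():
    import torch

    T, B, total_frames = 80, 8, 1_000_000  # hypothetical unroll/batch/budget
    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.RMSprop(params, lr=0.0006)

    def lr_lambda(epoch):
        return 1 - min(epoch * T * B, total_frames) / total_frames

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    for _ in range(5):
        opt.step()
        sched.step()
        print(sched.get_last_lr())  # base lr scaled by the current decay factor
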
def train(flags, rank=0, barrier=None, device="cuda:0", gossip_buffer=None):
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device(device)
        flags.actor_device = torch.device(device)
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static. We could make it dynamic, but that
    # requires a loss (and learning rate schedule) that's batch size
    # independent.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    model = Net(
        observation_shape=flags.observation_shape,
        hidden_size=flags.hidden_size,
        num_actions=flags.num_actions,
        use_lstm=flags.use_lstm,
    )
    model = model.to(device=flags.learner_device)

    actor_model = Net(
        observation_shape=flags.observation_shape,
        hidden_size=flags.hidden_size,
        num_actions=flags.num_actions,
        use_lstm=flags.use_lstm,
    )
    actor_model.to(device=flags.actor_device)

    # The ActorPool that will accept connections from actor clients.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        server_address=flags.address,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        try:
            actors.run()
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        return (
            1
            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
            / flags.total_steps
        )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                flags,
                plogger,
                rank,
                gossip_buffer,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(inference_batcher, actor_model, flags),
        )
        for i in range(flags.num_inference_threads)
    ]

    # Synchronize GALA agents before starting training
    if barrier is not None:
        barrier.wait()
        logging.info("%s: barrier passed" % rank)

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()

    def checkpoint():
        if flags.checkpoint:
            logging.info("Saving checkpoint to %s", flags.checkpoint)
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "flags": vars(flags),
                },
                flags.checkpoint,
            )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}" for key, value in stats.items()
                ),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Let's stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()

    actors.stop()
    actorpool_thread.join()

    for t in learner_threads + inference_threads:
        t.join()

    # Trace and save the final model.
    trace_model(flags, model)
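

# A hypothetical launcher sketch for the GALA-style signature above: several
# train() processes share a barrier so they begin training in lockstep, each
# pinned to its own device. `make_flags` and the per-rank device string are
# assumptions, and a real launcher would also construct and pass a shared
# gossip_buffer instead of None.
def example_launch(make_flags, num_agents=2):
    import torch.multiprocessing as mp

    ctx = mp.get_context("spawn")
    barrier = ctx.Barrier(num_agents)
    procs = []
    for rank in range(num_agents):
        p = ctx.Process(
            target=train,
            args=(make_flags(rank), rank, barrier, "cuda:%d" % rank, None),
        )
        p.start()
        procs.append(p)
    for p in procs:
        p.join()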