def train(flags):
    """Run distributed IMPALA-style training (polybeast variant).

    Builds the learner/actor models, a static-batch learner queue and a
    dynamic inference batcher, starts the C++ ActorPool plus learner and
    inference threads, then loops: logging throughput every ~5s and
    checkpointing every 10 minutes until ``flags.total_steps`` is reached.

    Args:
        flags: parsed command-line namespace; mutated in place (devices,
            xpid, max_learner_queue_size are filled in here).
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda:0")
        # NOTE(review): actors are pinned to a second GPU — this assumes at
        # least two CUDA devices are present; verify on single-GPU machines.
        flags.actor_device = torch.device("cuda:1")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # One pipe address per actor (connections_per_server is fixed at 1 here,
    # so each pipe_id serves exactly one connection).
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    model = model.to(device=flags.learner_device)

    actor_model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    actor_model.to(device=flags.actor_device)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        """Thread target: run the actor pool, logging any crash loudly."""
        try:
            actors.run()
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        """Linearly anneal the LR to 0 over flags.total_steps environment steps."""
        return (
            1
            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
            / flags.total_steps
        )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=flags.learner_device
        )
        model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(flags, inference_batcher, actor_model),
        )
        for i in range(flags.num_inference_threads)
    ]

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()

    def checkpoint():
        """Save model/optimizer/scheduler/stats to checkpointpath (unless disabled)."""
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        """Render floats with 5 significant digits; everything else via str()."""
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        # Main monitoring loop: the learner threads update stats["step"]; this
        # thread only sleeps, logs throughput, and checkpoints periodically.
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}" for key, value in stats.items()
                ),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    actorpool_thread.join()

    for t in learner_threads + inference_threads:
        t.join()
def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    """Run multi-task IMPALA training over one or more Atari environments.

    Spawns ``flags.num_actors`` actor processes per environment, a shared
    learner model on ``flags.device``, and ``flags.num_learner_threads``
    learner threads; periodically checkpoints and (optionally) saves
    intermediate models every ``flags.save_model_every_nsteps`` steps.

    Fix: the ``"six"``/``"three"`` environment-set aliases are now expanded
    BEFORE ``flags.env`` is split into individual environment names.
    Previously the split ran first, so an alias produced a single bogus
    environment (e.g. ``["six"]``) and ``num_tasks == 1``.
    """
    # prepare for logging and saving models
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar")))
    if flags.save_model_every_nsteps > 0:
        os.makedirs(checkpointpath.replace("model.tar", "intermediate"),
                    exist_ok=True)

    # expand environment-set aliases first (bug fix: must precede the split)
    if flags.env == "six":
        flags.env = "AirRaidNoFrameskip-v4,CarnivalNoFrameskip-v4,DemonAttackNoFrameskip-v4," \
                    "NameThisGameNoFrameskip-v4,PongNoFrameskip-v4,SpaceInvadersNoFrameskip-v4"
    elif flags.env == "three":
        flags.env = "AirRaidNoFrameskip-v4,CarnivalNoFrameskip-v4,DemonAttackNoFrameskip-v4"

    # get a list and determine the number of environments
    environments = flags.env.split(",")
    flags.num_tasks = len(environments)

    # set the number of buffers
    if flags.num_buffers is None:
        flags.num_buffers = max(2 * flags.num_actors * flags.num_tasks,
                                flags.batch_size)
    if flags.num_actors * flags.num_tasks >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    # set the device to do the training on
    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    # set the right agent class
    if flags.agent_type.lower() in [
            "aaa", "attention_augmented", "attention_augmented_agent"
    ]:
        Net = AttentionAugmentedAgent
        logging.info("Using the Attention-Augmented Agent architecture.")
    elif flags.agent_type.lower() in ["rn", "res", "resnet", "res_net"]:
        Net = ResNet
        logging.info("Using the ResNet architecture (monobeast version).")
    else:
        Net = AtariNet
        logging.warning(
            "No valid agent type specified. Using the default agent.")

    # create a dummy environment, mostly to get the observation and action spaces from
    gym_env = create_env(environments[0],
                         frame_height=flags.frame_height,
                         frame_width=flags.frame_width,
                         gray_scale=(flags.aaa_input_format == "gray_stack"))
    observation_space_shape = gym_env.observation_space.shape
    action_space_n = gym_env.action_space.n

    # if any task has a different action space, fall back to the full
    # 18-action Atari action set for all of them
    full_action_space = False
    for environment in environments[1:]:
        gym_env = create_env(environment)
        if gym_env.action_space.n != action_space_n:
            logging.warning(
                "Action spaces don't match, using full action space.")
            full_action_space = True
            action_space_n = 18
            break

    # create the model and the buffers to pass around data between actors and learner
    model = Net(observation_space_shape,
                action_space_n,
                use_lstm=flags.use_lstm,
                num_tasks=flags.num_tasks,
                use_popart=flags.use_popart,
                reward_clipping=flags.reward_clipping,
                rgb_last=(flags.aaa_input_format == "rgb_last"))
    buffers = create_buffers(flags, observation_space_shape, model.num_actions)

    # shared memory so actor processes can read the actor model's parameters
    model.share_memory()

    # Add initial RNN state.
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    # create stuff to keep track of the actor processes
    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    # create and start actor processes (the same number for each environment)
    for i, environment in enumerate(environments):
        for j in range(flags.num_actors):
            actor = ctx.Process(
                target=act,
                args=(
                    flags,
                    environment,
                    i,
                    full_action_space,
                    i * flags.num_actors + j,
                    free_queue,
                    full_queue,
                    model,
                    buffers,
                    initial_agent_state_buffers,
                ),
            )
            actor.start()
            actor_processes.append(actor)

    learner_model = Net(observation_space_shape,
                        action_space_n,
                        use_lstm=flags.use_lstm,
                        num_tasks=flags.num_tasks,
                        use_popart=flags.use_popart,
                        reward_clipping=flags.reward_clipping,
                        rgb_last=(flags.aaa_input_format == "rgb_last")).to(
                            device=flags.device)

    # the hyperparameters in the paper are found/adjusted using population-based training,
    # which might be a bit too difficult for us to do; while the IMPALA paper also does
    # some experiments with this, it doesn't seem to be implemented here
    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        # linear LR decay to 0 over total_steps (epoch counts optimizer steps,
        # each consuming T * B environment frames)
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.device)
        learner_model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        # stats = checkpoint_states["stats"]
        # logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    model.load_state_dict(learner_model.state_dict())

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
        "mu",
        "sigma",
    ] + ["{}_step".format(e) for e in environments]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    step, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        # step in particular needs to be from the outside scope, since all learner threads can update
        # it and all learners should stop once the total number of steps/frames has been processed
        nonlocal step, stats
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            learn(flags, model, learner_model, batch, agent_state,
                  optimizer, scheduler, stats, envs=environments)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update(
                    {k: stats[k] for k in stat_keys if "_step" not in k})
                for e in stats["env_step"]:
                    to_log["{}_step".format(e)] = stats["env_step"][e]
                plogger.log(to_log)
                step += T * B  # so this counts the number of frames, not e.g. trajectories/rollouts

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())

    # populate the free queue with the indices of all the buffers at the start
    for m in range(flags.num_buffers):
        free_queue.put(m)

    # start as many learner threads as specified => could in principle do PBT
    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(target=batch_and_learn,
                                  name="batch-and-learn-%d" % i,
                                  args=(i, ))
        thread.start()
        threads.append(thread)

    def save_latest_model():
        """Overwrite the single rolling checkpoint at checkpointpath."""
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": learner_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def save_intermediate_model():
        """Save a step-stamped snapshot under the 'intermediate' directory."""
        save_model_path = checkpointpath.replace(
            "model.tar",
            "intermediate/model." + str(stats.get("step", 0)).zfill(9) + ".tar")
        logging.info("Saving model to %s", save_model_path)
        torch.save(
            {
                "model_state_dict": learner_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            save_model_path,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        last_savemodel_nsteps = 0
        while step < flags.total_steps:
            start_step = stats.get("step", 0)
            start_time = timer()
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timer() - last_checkpoint_time > 10 * 60:
                # save every 10 min.
                save_latest_model()
                last_checkpoint_time = timer()

            if flags.save_model_every_nsteps > 0 and end_step >= last_savemodel_nsteps + flags.save_model_every_nsteps:
                # save model every save_model_every_nsteps steps
                save_intermediate_model()
                last_savemodel_nsteps = end_step

            sps = (end_step - start_step) / (timer() - start_time)
            if stats.get("episode_returns", None):
                mean_return = ("Return per episode: %.1f. " %
                               stats["mean_episode_return"])
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
                end_step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        gradient_tracker.print_total()
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)  # send quit signal to actors
        for actor in actor_processes:
            actor.join(timeout=10)
        gradient_tracker.print_total()

    save_latest_model()
    plogger.close()
def train(flags, game_params):  # pylint: disable=too-many-branches, too-many-statements
    """
    1. Init actor model and create_buffers()
    2. Starts 'num_actors' act() functions
    3. Init learner model and optimizer, loads the former on the GPU
    4. Launches 'num_learner_threads' threads executing batch_and_learn()
    5. train finishes when all batch_and_learn threads finish, i.e. when
       steps >= flags.total_steps
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar")))
    print("checkpointpath: ", checkpointpath)

    if flags.num_buffers is None:  # Set sensible default for num_buffers. IMPORTANT!!
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    # T = rollout length, B = batch size (in rollouts)
    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    # Actor model lives on CPU; its shapes drive buffer creation below.
    env = init_game(game_params['env'], flags.map_name)
    model = IMPALA_AC(env=env, device='cpu', **game_params['HPs'])
    observation_shape = (game_params['HPs']['spatial_dict']['in_channels'],
                         *model.screen_res)
    player_shape = game_params['HPs']['spatial_dict']['in_player']
    num_actions = len(model.action_names)
    buffers = create_buffers(flags, observation_shape, player_shape,
                             num_actions, model.max_num_spatial_args,
                             model.max_num_categorical_args)

    model.share_memory()  # see if this works out of the box for my A2C

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(
                flags,
                game_params,
                i,
                free_queue,
                full_queue,
                model,  # with share memory
                buffers),
        )
        actor.start()
        actor_processes.append(actor)

    # only model loaded into the GPU ?
    if not flags.disable_cuda and torch.cuda.is_available():
        learner_model = IMPALA_AC(env=env, device='cuda',
                                  **game_params['HPs']).to(device=flags.device)
    else:
        learner_model = IMPALA_AC(env=env, device='cpu',
                                  **game_params['HPs']).to(device=flags.device)

    if flags.optim == "Adam":
        optimizer = torch.optim.Adam(learner_model.parameters(),
                                     lr=flags.learning_rate)
    else:
        optimizer = torch.optim.RMSprop(
            learner_model.parameters(),
            lr=flags.learning_rate,
            momentum=flags.momentum,
            eps=flags.epsilon,
            alpha=flags.alpha,
        )

    def lr_lambda(epoch):
        """
        Linear schedule from 1 to 0 used only for RMSprop.
        To be adjusted multiplying or not by batch size B depending on how
        the steps are counted.
        epoch = number of optimizer steps
        total_steps = optimizer steps * time steps * batch size
                      or optimizer steps * time steps
        """
        return 1 - min(epoch * T, flags.total_steps
                       ) / flags.total_steps  #epoch * T * B if using B steps

    # NOTE(review): the scheduler is created for Adam too, even though the
    # docstring says the schedule is only meant for RMSprop — confirm intent.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
    ]
    logger.info("# Step\t%s", "\t".join(stat_keys))
    step, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                timings,
            )
            stats = learn(flags, model, learner_model, batch, optimizer,
                          scheduler)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T  #* B # just count the parallel steps
        # end batch_and_learn
        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())

    # hand every buffer index to the actors before starting the learners
    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(target=batch_and_learn,
                                  name="batch-and-learn-%d" % i,
                                  args=(i, ))
        thread.start()
        threads.append(thread)

    def checkpoint():
        """Save the rolling checkpoint (Adam has no scheduler state to save)."""
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        if flags.optim == "Adam":
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "flags": vars(flags),
                },
                checkpointpath,  # only one checkpoint at the time is saved
            )
        else:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "flags": vars(flags),
                },
                checkpointpath,  # only one checkpoint at the time is saved
            )
    # end checkpoint

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        # monitoring loop: learner threads advance `step`; this thread logs
        # throughput and checkpoints every 10 minutes
        while step < flags.total_steps:
            start_step = step
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            sps = (step - start_step) / (timer() - start_time
                                         )  # steps per second
            if stats.get("episode_returns", None):
                mean_return = ("Return per episode: %.1f. " %
                               stats["mean_episode_return"])
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
                step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)

    checkpoint()
    plogger.close()
def train(flags):
    """Run adversarial-reward IMPALA training (paint/SPIRAL-style variant).

    Builds a policy model plus a discriminator ``D`` (with a lagged copy
    ``D_eval`` used for the policy's reward), the actor pool and its queues,
    a DataLoader over the target image dataset, and learner / inference /
    discriminator-learner threads; then monitors progress and checkpoints
    every 10 minutes until ``flags.total_steps``.

    Fix: on checkpoint resume the two LR-scheduler states were swapped —
    ``scheduler`` was loaded from ``"D_scheduler_state_dict"`` and
    ``D_scheduler`` from ``"scheduler_state_dict"``, the opposite of the keys
    written by ``checkpoint()`` below. They are now restored from the keys
    they were saved under.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )
    # Queues feeding the discriminator learner and the inference threads.
    d_queue = Queue(maxsize=flags.max_learner_queue_size // flags.batch_size)
    image_queue = Queue(maxsize=flags.max_learner_queue_size)

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # One pipe address per actor (one connection per pipe here).
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    # Environment configuration for the painting environments.
    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )
    if flags.dataset == "celeba" or flags.dataset == "celeba-hq":
        use_color = True
    else:
        use_color = False
    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            ))
    if flags.use_compound:
        env_name += "-v1"
    else:
        env_name += "-v0"

    # A throwaway environment instance, only used to read shapes/metadata.
    env = env_wrapper.make_raw(env_name, config)
    if frame_width != flags.canvas_width:
        env = env_wrapper.WarpFrame(env, height=frame_width, width=frame_width)
    env = env_wrapper.wrap_pytorch(env)
    obs_shape = env.observation_space.shape
    if flags.condition:
        # Conditioned models see the canvas and the target stacked on channels.
        c, h, w = obs_shape
        c *= 2
        obs_shape = (c, h, w)
    action_shape = env.action_space.nvec.tolist()
    order = env.order
    env.close()

    model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        model = models.Condition(model)
    model = model.to(device=flags.learner_device)

    actor_model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        actor_model = models.Condition(actor_model)
    actor_model.to(device=flags.actor_device)

    # D is trained; D_eval is the copy used to score policy rollouts.
    D = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D = models.Conditional(D)
    D.to(device=flags.learner_device)

    D_eval = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D_eval = models.Conditional(D_eval)
    D_eval = D_eval.to(device=flags.learner_device)

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=flags.discriminator_learning_rate,
                             betas=(0.5, 0.999))

    def lr_lambda(epoch):
        """Linearly anneal the LR to 0 over flags.total_steps frames."""
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    D_scheduler = torch.optim.lr_scheduler.LambdaLR(D_optimizer, lr_lambda)

    C, H, W = obs_shape
    if flags.condition:
        C //= 2

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_action=actor_model.initial_action(),
        initial_agent_state=actor_model.initial_state(),
        image=torch.zeros(1, 1, C, H, W),
    )

    def run():
        """Thread target: run the actor pool, logging any crash loudly."""
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    # Dataset of target images the discriminator compares canvases against.
    c, h, w = obs_shape
    tsfm = transforms.Compose(
        [transforms.Resize((h, w)), transforms.ToTensor()])

    dataset = flags.dataset
    if dataset == "mnist":
        dataset = MNIST(root="./", train=True, transform=tsfm, download=True)
    elif dataset == "omniglot":
        dataset = Omniglot(root="./",
                           background=True,
                           transform=tsfm,
                           download=True)
    elif dataset == "celeba":
        dataset = CelebA(root="./",
                         split="train",
                         target_type=None,
                         transform=tsfm,
                         download=True)
    elif dataset == "celeba-hq":
        dataset = datasets.CelebAHQ(root="./",
                                    split="train",
                                    transform=tsfm,
                                    download=True)
    else:
        raise NotImplementedError

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=True,
                            drop_last=True,
                            pin_memory=True)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        # Bug fix: restore each scheduler from the key it was saved under in
        # checkpoint() below (the two keys were previously swapped).
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        D_scheduler.load_state_dict(
            checkpoint_states["D_scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                d_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
                image_queue,
            ),
        ) for i in range(flags.num_inference_threads)
    ]
    d_learner = [
        threading.Thread(
            target=learn_D,
            name="d_learner-thread-%i" % i,
            args=(
                flags,
                d_queue,
                D,
                D_eval,
                D_optimizer,
                D_scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    # Discriminator learners and the dataloader feeder are daemons so they
    # don't block interpreter shutdown (they are never explicitly joined).
    for thread in d_learner:
        thread.daemon = True
    dataloader_thread = threading.Thread(target=data_loader,
                                         args=(
                                             flags,
                                             dataloader,
                                             image_queue,
                                         ))
    dataloader_thread.daemon = True

    actorpool_thread.start()

    threads = learner_threads + inference_threads
    daemons = d_learner + [dataloader_thread]

    for t in threads + daemons:
        t.start()

    def checkpoint():
        """Save all models/optimizers/schedulers/stats to checkpointpath."""
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "D_scheduler_state_dict": D_scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        """Render floats with 5 significant digits; everything else via str()."""
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    actorpool_thread.join()

    for t in threads:
        t.join()
def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    """Run single-task IMPALA training (monobeast variant with wandb support).

    Starts ``flags.num_actors`` actor processes and
    ``flags.num_learner_threads`` learner threads, optionally resumes from
    ``flags.resume_checkpoint_path``, checkpoints every 10 minutes, and on
    exit saves a final checkpoint (also into the wandb run dir when active).

    Returns:
        The number of environment steps processed.

    Fixes:
    - ``checkpoint(path)`` now logs the path it actually saves to (it always
      logged ``checkpointpath`` even when writing elsewhere).
    - ``step, stats = 0, {}`` is initialized BEFORE the resume block, so the
      ``stats`` restored from a checkpoint are no longer immediately
      clobbered.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    # torch.device is not JSON-serializable; stringify it for the log writer.
    xp_args = flags.__dict__.copy()
    if 'device' in xp_args and isinstance(xp_args['device'], torch.device):
        xp_args['device'] = str(xp_args['device'])
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=xp_args, rootdir=flags.savedir
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if flags.num_buffers is None:  # Set sensible default for num_buffers.
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    env = create_env(flags)

    model = Net(env.observation_space.shape, env.action_space.n, flags.use_lstm)
    buffers = create_buffers(flags, env.observation_space.shape, model.num_actions)

    model.share_memory()

    # Add initial RNN state.
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(
                flags,
                i,
                free_queue,
                full_queue,
                model,
                buffers,
                initial_agent_state_buffers,
            ),
        )
        actor.start()
        actor_processes.append(actor)

    learner_model = Net(
        env.observation_space.shape, env.action_space.n, flags.use_lstm
    ).to(device=flags.device)

    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        # Linear LR decay to 0 over total_steps frames (T * B per optimizer step).
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
    ]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    # Initialize BEFORE the resume block so loaded stats survive (bug fix).
    step, stats = 0, {}

    # Load state from a checkpoint, if possible.
    resume_checkpoint_path = checkpointpath
    if 'resume_checkpoint_path' in flags:
        resume_checkpoint_path = flags.resume_checkpoint_path
    if os.path.exists(resume_checkpoint_path):
        checkpoint_states = torch.load(
            resume_checkpoint_path, map_location=flags.device
        )
        learner_model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    model.load_state_dict(learner_model.state_dict())

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats
        timings = prof.Timings()
        steps = flags.total_steps
        if 'train_steps' in flags:
            steps = flags.train_steps
        while step < steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            stats = learn(
                flags, model, learner_model, batch, agent_state, optimizer, scheduler
            )
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T * B

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())

    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(
            target=batch_and_learn, name="batch-and-learn-%d" % i, args=(i,)
        )
        thread.start()
        threads.append(thread)

    def checkpoint(path=None):
        """Save model/optimizer/scheduler/stats; defaults to checkpointpath."""
        if flags.disable_checkpoint:
            return
        if path is None:
            path = checkpointpath
        # Bug fix: log the destination actually written, not checkpointpath.
        logging.info("Saving checkpoint to %s", path)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            path,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        steps = flags.total_steps
        if 'train_steps' in flags:
            steps = flags.train_steps
        while step < steps:
            start_step = step
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            sps = (step - start_step) / (timer() - start_time)
            if stats.get("episode_returns", None):
                mean_return = (
                    "Return per episode: %.1f. " % stats["mean_episode_return"]
                )
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            print_step = step
            if 'current_step' in flags:
                print_step += flags.current_step
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
                print_step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)

        checkpoint()
        if wandb.run is not None and wandb.run.dir is not None:
            # Also snapshot into the wandb run directory so the run can be resumed.
            save_path = os.path.join(wandb.run.dir, f'model-{step}.tar')
            checkpoint(save_path)
            flags.resume_checkpoint_path = save_path

    plogger.close()
    return step
def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    """Run IMPALA-style training with actor processes and learner threads.

    Spawns ``flags.num_actors`` actor processes (via a ``fork`` context) that
    fill shared-memory rollout buffers, and ``flags.num_learner_threads``
    learner threads that batch those rollouts and update ``learner_model``.
    The main thread only monitors throughput, prints stats roughly every
    10 seconds and checkpoints every 10 minutes.

    Fix vs. the previous revision: steps-per-second (SPS) was computed from
    ``start_step``/``start_time`` that were reset on every 0.5 s sleep
    iteration, so the printed SPS covered only the last 0.5 s window instead
    of the interval since the previous print. SPS is now measured over the
    whole print interval.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar")))

    if flags.num_buffers is None:  # Set sensible default for num_buffers.
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    # Optionally resume model/optimizer/scheduler state from flags.loaddir.
    load_checkpoint = None
    if flags.loaddir is not None:
        loadpath = os.path.join(os.path.expanduser(flags.loaddir), 'model.tar')
        logging.info("Continue training from {}".format(loadpath))
        load_checkpoint = torch.load(loadpath, map_location=flags.device)

    if flags.env.startswith('Coin'):
        obs_shape = (3, 64, 64)  # will deadlock if we try to create env here
        act_shape = 7  # so specify space shapes manually
    else:
        env = create_env(flags.env, flags)
        obs_shape = copy(env.observation_space.shape)
        act_shape = env.action_space.n
        del env

    # `model` lives in shared memory and is used by the actor processes;
    # `learner_model` (below) is the one actually optimized.
    model = Net(obs_shape, act_shape, flags.use_lstm)
    if load_checkpoint is not None:
        model.load_state_dict(load_checkpoint["model_state_dict"])
    buffers = create_buffers(flags, obs_shape, model.num_actions)

    model.share_memory()

    # Add initial RNN state, one per rollout buffer, in shared memory.
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(
                flags,
                i,
                free_queue,
                full_queue,
                model,
                buffers,
                initial_agent_state_buffers,
            ),
        )
        actor.start()
        actor_processes.append(actor)

    learner_model = Net(obs_shape, act_shape, flags.use_lstm).to(device=flags.device)
    if load_checkpoint is not None:
        learner_model.load_state_dict(load_checkpoint["model_state_dict"])

    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )
    if load_checkpoint is not None:
        optimizer.load_state_dict(load_checkpoint["optimizer_state_dict"])

    def lr_lambda(epoch):
        # Linear LR decay to 0 over flags.total_steps environment steps.
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    if load_checkpoint is not None:
        scheduler.load_state_dict(load_checkpoint["scheduler_state_dict"])

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
    ]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    step, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            stats = learn(flags, model, learner_model, batch, agent_state,
                          optimizer, scheduler)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T * B
        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())

    # Prime the free queue with every buffer index so actors can start.
    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(
            target=batch_and_learn, name="batch-and-learn-%d" % i, args=(i, ))
        thread.start()
        threads.append(thread)

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "flags": vars(flags),
            },
            checkpointpath,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        last_print_time = timer()
        last_print_step = step  # step count at the previous print, for SPS
        episode_returns = []
        while step < flags.total_steps:
            time.sleep(0.5)
            if stats.get("episode_returns", None):
                episode_returns.extend(stats["episode_returns"])
            if timer() - last_print_time < 10.0:
                continue  # wait 10s to print
            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()
            # SPS over the full interval since the previous print (fix: was
            # measured over only the last 0.5 s sleep window).
            sps = (step - last_print_step) / (timer() - last_print_time)
            if len(episode_returns) > 0:
                mean_return_val = sum(episode_returns) / len(episode_returns)
                mean_return = ("Return per episode: %.1f" % mean_return_val)
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            logging.info(
                "Steps %i @ %.1f SPS Loss %f %s\nEpsd returns:%s\nStats:\n%s",
                step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(episode_returns),
                pprint.pformat(stats),
            )
            last_print_time = timer()
            last_print_step = step
            episode_returns = []
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        # Sentinel None tells each actor to exit; then best-effort join.
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)
        checkpoint()
        plogger.close()
def train(flags):
    """Train a policy plus a GAN-style discriminator via a libtorchbeast ActorPool.

    The visible orchestration: build three native queues (learner, replay,
    inference), construct policy nets (``model`` on the learner device,
    ``actor_model`` in eval mode on the actor device) and a discriminator
    pair (``D`` trained, ``D_eval`` frozen copy), then run an actorpool
    thread, inference threads, learner threads and a discriminator-learner
    thread until ``stats["step"]`` reaches ``flags.total_steps``.
    Checkpoints every 10 minutes and on clean shutdown.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The queue the actorpool stores final render image pairs.
    # A separate thread will load them to the ReplayBuffer.
    # The batch size of the pairs will be dynamic.
    replay_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.batch_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # One pipe address per actor, round-robin over pipe ids
    # (`connections_per_server` actors share each pipe id).
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    # Grayscale bookkeeping: mnist/omniglot are inherently gray; otherwise
    # grayscale is forced unless flags.use_color is set.
    dataset_is_gray = flags.dataset in ["mnist", "omniglot"]
    grayscale = not dataset_is_gray and not flags.use_color
    dataset_is_gray |= grayscale
    dataset = utils.create_dataset(flags.dataset, grayscale)
    env_name, config = utils.parse_flags(flags)
    env = utils.create_env(env_name, config, dataset_is_gray, dataset=None)
    if flags.condition:
        # Conditional mode doubles the canvas channels (canvas + target,
        # presumably — TODO confirm against the env implementation).
        new_space = env.observation_space.spaces
        c, h, w = new_space["canvas"].shape
        new_space["canvas"] = spaces.Box(
            low=0, high=255, shape=(c * 2, h, w), dtype=np.uint8)
        env.observation_space = spaces.Dict(new_space)
    obs_shape = env.observation_space["canvas"].shape
    action_shape = env.action_space.nvec
    order = env.order
    # Env only used to read space shapes; actors run their own env servers.
    env.close()

    # NOTE(review): `grid_width` is not defined in this function — it must
    # come from the enclosing module scope; verify before refactoring.
    model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    )
    model = model.to(device=flags.learner_device)

    # Actor-side copy of the policy, kept in eval mode; weights are synced
    # from `model` below (and, presumably, by the learner during training).
    actor_model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    ).eval()
    actor_model.to(device=flags.actor_device)

    if flags.condition:
        D = models.ComplementDiscriminator(obs_shape)
    else:
        D = models.Discriminator(obs_shape)
    D.to(device=flags.learner_device)

    # Frozen evaluation copy of the discriminator used by the policy learner.
    if flags.condition:
        D_eval = models.ComplementDiscriminator(obs_shape)
    else:
        D_eval = models.Discriminator(obs_shape)
    D_eval = D_eval.to(device=flags.learner_device).eval()

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    # betas=(0.5, 0.999) is the customary GAN discriminator setting.
    D_optimizer = optim.Adam(
        D.parameters(), lr=flags.discriminator_learning_rate, betas=(0.5, 0.999))

    def lr_lambda(epoch):
        # Linear decay of the policy LR to 0 over flags.total_steps.
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        replay_queue=replay_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        # Thread target: actors.run() blocks until the pool stops.
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    # Real-image loader feeding the discriminator learner.
    dataloader = DataLoader(
        dataset,
        batch_size=flags.batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
            ),
        )
        for i in range(flags.num_inference_threads)
    ]
    # Single dedicated thread training the discriminator from the replay
    # queue plus real samples from `dataloader`.
    d_learner = threading.Thread(
        target=learn_D,
        name="d_learner-thread",
        args=(
            flags,
            dataloader,
            replay_queue,
            D,
            D_eval,
            D_optimizer,
            stats,
            plogger,
        ),
    )

    # Start actors and inference first; learner threads start only after the
    # replay queue has warmed up (see the try block below).
    actorpool_thread.start()
    threads = learner_threads + inference_threads
    for t in inference_threads:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        # Compact float formatting for the periodic stats line.
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        # Warm-up: until the replay queue holds a full batch, keep draining
        # the learner queue (discarding batches) so the actors don't stall
        # on a full queue before the learners have started.
        while replay_queue.size() < flags.batch_size:
            if learner_queue.size() >= flags.batch_size:
                next(learner_queue)
            time.sleep(0.01)
        d_learner.start()
        for t in learner_threads:
            t.start()
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}"
                    for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    replay_queue.close()
    actorpool_thread.join()
    d_learner.join()
    for t in threads:
        t.join()
def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    """Train with actor processes, a shared inference thread and learner threads.

    Actors send buffer indices over per-actor ``actor_model_queues``; a single
    ``batch_and_inference`` thread runs the model forward for all of them and
    replies via ``actor_env_queues``.  Learner threads consume full rollouts
    and update the same ``model``.

    Fixes vs. the previous revision:
      * ``flags.xpid.split("-")`` ran BEFORE the ``flags.xpid is None``
        default was applied, so a missing xpid crashed with AttributeError.
        The default is now applied first.
      * The loaded checkpoint dict was bound to the name ``checkpoint``,
        shadowing the inner ``checkpoint()`` save function; when the initial
        ``torch.load`` failed, the later ``model.load_state_dict`` raised a
        confusing NameError.  The dict is now ``checkpoint_states`` and the
        restore is guarded on a successful load.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    terms = flags.xpid.split("-")
    if len(terms) == 3:
        # NOTE(review): `group` is never used in this function; retained for
        # parity with the original (possibly intended for run grouping).
        group = terms[0] + "-" + terms[1]
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar")))

    if flags.num_buffers is None:  # Set sensible default for num_buffers.
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if (not flags.disable_cuda) and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda:" + str(torch.cuda.current_device()))
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    step, stats = 0, {}

    env = create_gymenv(flags)
    actor_flags = flags

    # Best-effort resume: a missing or unreadable checkpoint simply means a
    # fresh start (the error is printed, mirroring the original behavior).
    checkpoint_states = None
    try:
        checkpoint_states = torch.load(checkpointpath, map_location=flags.device)
        step = checkpoint_states["step"]
    except Exception as e:
        print(e)

    model = create_model(flags, env).to(device=flags.device)
    if checkpoint_states is not None:
        try:
            model.load_state_dict(checkpoint_states["model_state_dict"])
        except Exception as e:
            print(e)

    # Rollout buffers (learner) and single-step buffers (inference), with the
    # observation/action layout depending on the agent type.
    if flags.agent in ["CNN"]:
        buffers = create_buffers(
            flags,
            env.observation_space.spaces["image"].shape,
            env.action_space.n,
            flags.unroll_length,
            flags.num_buffers,
            img_shape=env.observation_space.spaces["image"].shape)
        actor_buffers = create_buffers(
            flags,
            env.observation_space.spaces["image"].shape,
            env.action_space.n,
            0,
            flags.num_actors,
            img_shape=env.observation_space.spaces["image"].shape)
    elif flags.agent in ["NLM", "KBMLP", "GCN"]:
        buffers = create_buffers(
            flags,
            env.obs_shape,
            model.num_actions,
            flags.unroll_length,
            flags.num_buffers,
            img_shape=env.observation_space.spaces["image"].shape)
        actor_buffers = create_buffers(
            flags,
            env.obs_shape,
            model.num_actions,
            0,
            flags.num_actors,
            img_shape=env.observation_space.spaces["image"].shape)
    else:
        raise ValueError()

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()
    # Per-actor request/response queues for the central inference thread.
    actor_model_queues = [ctx.SimpleQueue() for _ in range(flags.num_actors)]
    actor_env_queues = [ctx.SimpleQueue() for _ in range(flags.num_actors)]

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(actor_flags, create_gymenv(flags), i, free_queue, full_queue,
                  buffers, actor_buffers, actor_model_queues,
                  actor_env_queues),
        )
        actor.start()
        actor_processes.append(actor)

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        # Linear LR decay to 0 over flags.total_steps environment steps.
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    logger = logging.getLogger("logfile")
    if flags.mode == "imitate":
        stat_keys = [
            "total_loss",
            "accuracy",
            "mean_episode_return",
        ]
    else:
        stat_keys = [
            "total_loss",
            "mean_episode_return",
            "pg_loss",
            "baseline_loss",
            "entropy_loss",
        ]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    finish = False

    def batch_and_inference():
        # Serves all actors: gathers one request per actor, runs a single
        # batched forward pass, writes outputs back into actor_buffers, then
        # releases every actor. Stops once every learner thread has finished.
        nonlocal finish
        while not all(finish):
            indexes = []
            for i in range(flags.num_actors):
                indexes.append(actor_model_queues[i].get())
            batch = get_inference_batch(flags, actor_buffers)
            with torch.no_grad():
                agent_output = model(batch)
            for index in indexes:
                for key in agent_output:
                    actor_buffers[key][index][0] = agent_output[key][0, index]
            for i in range(flags.num_actors):
                actor_env_queues[i].put(None)

    # Rebound to a list before any thread starts, so `all(finish)` is valid.
    finish = [False for _ in range(flags.num_learner_threads)]

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats, finish
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                timings,
            )
            stats = learn(flags, model, batch, optimizer, scheduler)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T * B
        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())
        finish[i] = True

    # Prime the free queue with every buffer index so actors can start.
    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    thread = threading.Thread(
        target=batch_and_inference, name="batch-and-inference")
    thread.start()
    threads.append(thread)
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(
            target=batch_and_learn, name="batch-and-learn-%d" % i, args=(i, ))
        thread.start()
        threads.append(thread)

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "flags": vars(flags),
                "step": step
            },
            checkpointpath,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        while step < flags.total_steps:
            start_step = step
            start_time = timer()
            time.sleep(5)
            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()
            sps = (step - start_step) / (timer() - start_time)
            total_loss = stats.get("total_loss", float("inf"))
            if stats.get("episode_returns", None):
                mean_return = ("Return per episode: %.1f. "
                               % stats["mean_episode_return"])
            else:
                mean_return = ""
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
                step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        # None on free_queue stops the rollout loop; "exit" on the env queue
        # unblocks any actor waiting for an inference reply.
        for i in range(flags.num_actors):
            free_queue.put(None)
            actor_env_queues[i].put("exit")
        for actor in actor_processes:
            actor.join(timeout=1)
        checkpoint()
        plogger.close()