def test_batched_run(self, batch_size=10):
    """All `batch_size` compute() calls must be served by one full batch."""
    batcher = actorpool.DynamicBatcher(
        batch_dim=0, minimum_batch_size=batch_size, maximum_batch_size=batch_size
    )

    inputs = [torch.full((1, 2, 3), i) for i in range(batch_size)]
    outputs = torch.ones(batch_size, 42, 3)

    def worker(i):
        # Make sure thread i calls compute before thread i + 1.
        while batcher.size() < i:
            time.sleep(0.05)
        result = batcher.compute(inputs[i])
        # Each caller gets back its own slice of the batched outputs.
        np.testing.assert_array_equal(result, outputs[i : i + 1])

    threads = [
        threading.Thread(target=worker, name=f"compute-thread-{i}", args=(i,))
        for i in range(batch_size)
    ]
    for thread in threads:
        thread.start()

    batch = next(batcher)
    np.testing.assert_array_equal(batch.get_inputs(), torch.cat(inputs))
    batch.set_outputs(outputs)

    for thread in threads:
        thread.join()
def setUp(self):
    """Spawn an env server subprocess and an ActorPool thread talking to it."""
    # NOTE(review): assumes the test runs from the repository root so the
    # script path resolves — confirm.
    self.server_proc = subprocess.Popen(
        ["python", "tests/contiguous_arrays_env.py"])
    server_address = ["unix:/tmp/contiguous_arrays_test"]

    self.learner_queue = actorpool.BatchingQueue(
        batch_dim=1, minimum_batch_size=1, maximum_batch_size=10,
        check_inputs=True)
    self.inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=10,
        timeout_ms=100,
        check_outputs=True,
    )

    # No agent state in this test, hence the empty initial_agent_state tuple.
    actor = actorpool.ActorPool(
        unroll_length=1,
        learner_queue=self.learner_queue,
        inference_batcher=self.inference_batcher,
        env_server_addresses=server_address,
        initial_agent_state=(),
    )

    def run():
        actor.run()

    # The pool runs in the background for the duration of the test.
    self.actor_thread = threading.Thread(target=run)
    self.actor_thread.start()

    # Reference array; the transpose makes it non-contiguous in memory,
    # which is what this test case is about.
    self.target = np.arange(3 * 4 * 5)
    self.target = self.target.reshape(3, 4, 5)
    self.target = self.target.transpose(2, 1, 0)
def setUp(self):
    """Spawn an env server subprocess and configure queues for a B=2, T=3 run."""
    # NOTE(review): assumes the test runs from the repository root so the
    # script path resolves — confirm.
    self.server_proc = subprocess.Popen(["python", "tests/core_agent_state_env.py"])

    self.B = 2  # learner batch size
    self.T = 3  # unroll length

    self.model = Net()

    server_address = ["unix:/tmp/core_agent_state_test"]

    # Static learner batch size: minimum == maximum == B.
    self.learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=self.B,
        maximum_batch_size=self.B,
        check_inputs=True,
    )
    # Inference batches are of size exactly 1 in this test.
    self.inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=1,
        timeout_ms=100,
        check_outputs=True,
    )
    self.actor = actorpool.ActorPool(
        unroll_length=self.T,
        learner_queue=self.learner_queue,
        inference_batcher=self.inference_batcher,
        env_server_addresses=server_address,
        initial_agent_state=self.model.initial_state(),
    )
def test_timeout(self):
    """A batch below minimum_batch_size is still released after the timeout."""
    timeout_ms = 300
    batcher = actorpool.DynamicBatcher(
        batch_dim=0,
        minimum_batch_size=5,
        maximum_batch_size=5,
        timeout_ms=timeout_ms,
    )

    inputs = torch.zeros(1, 2, 3)
    outputs = torch.ones(1, 42, 3)

    thread = threading.Thread(
        target=lambda: batcher.compute(inputs), name="compute-thread"
    )
    thread.start()

    start = time.time()
    # Wait until approximately timeout_ms.
    batch = next(batcher)
    elapsed_ms = 1000 * (time.time() - start)

    # Timeout has expired and the batch of size 1 (< minimum_batch_size)
    # has been consumed.
    batch.set_outputs(outputs)
    thread.join()

    self.assertTrue(timeout_ms <= elapsed_ms <= timeout_ms + timeout_ms / 10)
def test_check_outputs2(self):
    """set_outputs must reject outputs whose batch dimension size mismatches."""
    batcher = actorpool.DynamicBatcher(
        batch_dim=2, minimum_batch_size=1, maximum_batch_size=1
    )
    inputs = torch.zeros(1, 2, 3)

    thread = threading.Thread(
        target=lambda: batcher.compute(inputs), name="compute-thread"
    )
    thread.start()

    batch = next(batcher)
    with self.assertRaisesRegex(
        ValueError,
        "Output shape must have the same batch dimension as the input batch size.",
    ):
        # Dimension two of the outputs differs from the batch size (3 != 1).
        batch.set_outputs(torch.ones(1, 42, 3))

    # Set correct outputs so the thread can join.
    batch.set_outputs(torch.ones(1, 1, 1))
    thread.join()
def test_many_consumers(
    self,
    minimum_batch_size=1,
    compute_thread_number=64,
    repeats=100,
    consume_thread_number=16,
):
    """Many producers + many consumers must account for every element once."""
    batcher = actorpool.DynamicBatcher(
        batch_dim=0, minimum_batch_size=minimum_batch_size
    )

    lock = threading.Lock()
    total_batches_consumed = 0

    def produce(i):
        for _ in range(repeats):
            batcher.compute(torch.full((1, 2, 3), i))

    def consume():
        nonlocal total_batches_consumed
        # Iteration stops once the batcher is closed below.
        for batch in batcher:
            inputs = batch.get_inputs()
            count = inputs.shape[0]
            batch.set_outputs(torch.ones_like(inputs))
            with lock:
                total_batches_consumed += count

    compute_threads = [
        threading.Thread(target=produce, name=f"compute-thread-{i}", args=(i,))
        for i in range(compute_thread_number)
    ]
    consume_threads = [
        threading.Thread(target=consume, name=f"consume-thread-{i}")
        for i in range(consume_thread_number)
    ]

    for t in compute_threads + consume_threads:
        t.start()
    for t in compute_threads:
        t.join()

    # Stop iteration in all consume_threads.
    batcher.close()
    for t in consume_threads:
        t.join()

    self.assertEqual(total_batches_consumed, compute_thread_number * repeats)
def test_dropped_batch(self):
    """Dropping the batch object breaks the promise seen by compute()."""
    batcher = actorpool.DynamicBatcher(
        batch_dim=0, minimum_batch_size=1, maximum_batch_size=1
    )
    inputs = torch.zeros(1, 2, 3)

    def compute():
        with self.assertRaisesRegex(actorpool.AsyncError, _BROKEN_PROMISE_MESSAGE):
            batcher.compute(inputs)

    thread = threading.Thread(target=compute, name="compute-thread")
    thread.start()

    # Retrieves but doesn't keep the batch object.
    next(batcher)
    thread.join()
def test_simple_run(self):
    """One compute() call is served by one batch carrying the same tensors."""
    batcher = actorpool.DynamicBatcher(
        batch_dim=0, minimum_batch_size=1, maximum_batch_size=1
    )
    inputs = torch.zeros(1, 2, 3)
    outputs = torch.ones(1, 42, 3)

    def compute():
        np.testing.assert_array_equal(batcher.compute(inputs), outputs)

    thread = threading.Thread(target=compute, name="compute-thread")
    thread.start()

    batch = next(batcher)
    np.testing.assert_array_equal(batch.get_inputs(), inputs)
    batch.set_outputs(outputs)
    thread.join()
def test_multiple_set_outputs_calls(self):
    """Calling set_outputs twice on the same batch must raise."""
    batcher = actorpool.DynamicBatcher(
        batch_dim=0, minimum_batch_size=1, maximum_batch_size=1
    )
    inputs = torch.zeros(1, 2, 3)
    outputs = torch.ones(1, 42, 3)

    thread = threading.Thread(
        target=lambda: batcher.compute(inputs), name="compute-thread"
    )
    thread.start()

    batch = next(batcher)
    batch.set_outputs(outputs)
    with self.assertRaisesRegex(RuntimeError, "set_outputs called twice"):
        batch.set_outputs(outputs)
    thread.join()
def test_check_outputs1(self):
    """Outputs with fewer dimensions than batch_dim requires must be rejected."""
    batcher = actorpool.DynamicBatcher(
        batch_dim=2, minimum_batch_size=1, maximum_batch_size=1
    )
    inputs = torch.zeros(1, 2, 3)

    thread = threading.Thread(
        target=lambda: batcher.compute(inputs), name="compute-thread"
    )
    thread.start()

    batch = next(batcher)
    with self.assertRaisesRegex(ValueError, "output shape must have at least"):
        # A rank-1 tensor cannot have batch dimension 2.
        batch.set_outputs(torch.ones(1))

    # Set correct outputs so the thread can join.
    batch.set_outputs(torch.ones(1, 1, 1))
    thread.join()
def train(flags):
    """Run the full IMPALA-style training loop.

    Wires together the learner queue, the dynamic inference batcher, an
    ActorPool serving `flags.num_actors` environment connections, and
    learner/inference threads; then loops, logging throughput and
    checkpointing every 10 minutes until `flags.total_steps` is reached.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        # Learner and actor models are placed on separate devices.
        # NOTE(review): assumes at least two GPUs are present — confirm.
        flags.learner_device = torch.device("cuda:0")
        flags.actor_device = torch.device("cuda:1")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # One pipe address per actor connection.
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    model = model.to(device=flags.learner_device)

    actor_model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    actor_model.to(device=flags.actor_device)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        # Surface actorpool exceptions instead of dying silently.
        try:
            actors.run()
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        # Linear decay from 1 to 0 over flags.total_steps environment steps.
        return (
            1
            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
            / flags.total_steps
        )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=flags.learner_device
        )
        model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(flags, inference_batcher, actor_model),
        )
        for i in range(flags.num_inference_threads)
    ]

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()

    def checkpoint():
        # Persist everything needed to resume the run.
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    # Main monitoring loop: learner threads update `stats` concurrently.
    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}" for key, value in stats.items()
                ),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    actorpool_thread.join()
    for t in learner_threads + inference_threads:
        t.join()
def train(flags):
    """Run adversarial (discriminator-guided) IMPALA-style training.

    Sets up the painting environment, policy (`model`/`actor_model`) and
    discriminator (`D`/`D_eval`) networks, the learner/inference/discriminator
    threads and the dataset loader; then loops, logging throughput and
    checkpointing every 10 minutes until `flags.total_steps` is reached.

    Bug fix: on checkpoint resume, `scheduler` and `D_scheduler` were loaded
    from each other's state-dict keys (swapped relative to what `checkpoint()`
    saves below); each scheduler now loads its own state.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )
    # Queues feeding the discriminator learner and the data loader thread.
    d_queue = Queue(maxsize=flags.max_learner_queue_size // flags.batch_size)
    image_queue = Queue(maxsize=flags.max_learner_queue_size)

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # One pipe address per actor connection.
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    # Environment configuration (libmypaint or fluid paint simulator).
    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )
    if flags.dataset == "celeba" or flags.dataset == "celeba-hq":
        use_color = True
    else:
        use_color = False
    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            )
        )
    if flags.use_compound:
        env_name += "-v1"
    else:
        env_name += "-v0"

    # Build a throwaway env only to read shapes and the action order.
    env = env_wrapper.make_raw(env_name, config)
    if frame_width != flags.canvas_width:
        env = env_wrapper.WarpFrame(env, height=frame_width, width=frame_width)
    env = env_wrapper.wrap_pytorch(env)
    obs_shape = env.observation_space.shape
    if flags.condition:
        # Conditioned models see the canvas and the target stacked on channels.
        c, h, w = obs_shape
        c *= 2
        obs_shape = (c, h, w)
    action_shape = env.action_space.nvec.tolist()
    order = env.order
    env.close()

    model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        model = models.Condition(model)
    model = model.to(device=flags.learner_device)

    actor_model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        actor_model = models.Condition(actor_model)
    actor_model.to(device=flags.actor_device)

    # D is trained; D_eval is the snapshot used for reward computation.
    D = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D = models.Conditional(D)
    D.to(device=flags.learner_device)

    D_eval = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D_eval = models.Conditional(D_eval)
    D_eval = D_eval.to(device=flags.learner_device)

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(
        D.parameters(), lr=flags.discriminator_learning_rate, betas=(0.5, 0.999)
    )

    def lr_lambda(epoch):
        # Linear decay from 1 to 0 over flags.total_steps environment steps.
        return (
            1
            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
            / flags.total_steps
        )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    D_scheduler = torch.optim.lr_scheduler.LambdaLR(D_optimizer, lr_lambda)

    C, H, W = obs_shape
    if flags.condition:
        C //= 2

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_action=actor_model.initial_action(),
        initial_agent_state=actor_model.initial_state(),
        image=torch.zeros(1, 1, C, H, W),
    )

    def run():
        # Surface actorpool exceptions instead of dying silently.
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    # Real-image dataset feeding the discriminator.
    c, h, w = obs_shape
    tsfm = transforms.Compose([transforms.Resize((h, w)), transforms.ToTensor()])
    dataset = flags.dataset
    if dataset == "mnist":
        dataset = MNIST(root="./", train=True, transform=tsfm, download=True)
    elif dataset == "omniglot":
        dataset = Omniglot(root="./", background=True, transform=tsfm, download=True)
    elif dataset == "celeba":
        dataset = CelebA(
            root="./", split="train", target_type=None, transform=tsfm, download=True
        )
    elif dataset == "celeba-hq":
        dataset = datasets.CelebAHQ(
            root="./", split="train", transform=tsfm, download=True
        )
    else:
        raise NotImplementedError
    dataloader = DataLoader(
        dataset, batch_size=1, shuffle=True, drop_last=True, pin_memory=True
    )

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=flags.learner_device
        )
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(checkpoint_states["D_optimizer_state_dict"])
        # FIX: these two loads used each other's keys (swapped), so resuming
        # restored the policy scheduler into the discriminator scheduler and
        # vice versa. Each scheduler must load its own saved state.
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        D_scheduler.load_state_dict(checkpoint_states["D_scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                d_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
                image_queue,
            ),
        )
        for i in range(flags.num_inference_threads)
    ]
    d_learner = [
        threading.Thread(
            target=learn_D,
            name="d_learner-thread-%i" % i,
            args=(
                flags,
                d_queue,
                D,
                D_eval,
                D_optimizer,
                D_scheduler,
                stats,
                plogger,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    # Discriminator learners and the data loader never finish on their own;
    # daemonize them so process exit is not blocked.
    for thread in d_learner:
        thread.daemon = True
    dataloader_thread = threading.Thread(
        target=data_loader,
        args=(
            flags,
            dataloader,
            image_queue,
        ),
    )
    dataloader_thread.daemon = True

    actorpool_thread.start()

    threads = learner_threads + inference_threads
    daemons = d_learner + [dataloader_thread]
    for t in threads + daemons:
        t.start()

    def checkpoint():
        # Persist everything needed to resume the run.
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "D_scheduler_state_dict": D_scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    # Main monitoring loop: learner threads update `stats` concurrently.
    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}" for key, value in stats.items()
                ),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    actorpool_thread.join()
    for t in threads:
        t.join()
def test(flags, **kwargs):
    """Serve a trained model for inference-only evaluation.

    Loads weights from `flags.checkpoint`, starts inference threads and an
    ActorPool in test mode (no learner queue), then blocks until
    interrupted with Ctrl-C.
    """
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA for testing.")
        flags.actor_device = torch.device("cuda:0")
    else:
        logging.info("Not using CUDA for testing.")
        flags.actor_device = torch.device("cpu")

    model = Net(
        observation_shape=flags.observation_shape,
        hidden_size=flags.hidden_size,
        num_actions=flags.num_actions,
        use_lstm=flags.use_lstm,
    )
    model.eval()

    logging.info("Initializing weights from {} for testing.".format(
        flags.checkpoint))
    checkpoint = torch.load(flags.checkpoint, map_location=flags.actor_device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(flags.actor_device)

    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(inference_batcher, model, flags),
        )
        for i in range(flags.num_inference_threads)
    ]

    # Initialize ActorPool in test mode (without learner queue) for
    # RPC communication with the env and enqueueing steps in inference batcher.
    actors = actorpool.ActorPool(
        unroll_length=0,  # Unused in test mode
        learner_queue=None,  # Indicates test mode
        inference_batcher=inference_batcher,
        server_address=flags.address,
        initial_agent_state=model.initial_state(),
    )

    def run():
        # Surface actorpool exceptions instead of dying silently.
        try:
            actors.run()
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")
    actorpool_thread.start()
    for t in inference_threads:
        t.start()

    # Wait until interrupted
    try:
        while True:
            time.sleep(10)
    except KeyboardInterrupt:
        pass

    # Close properly.
    logging.info("Testing finished.")
    inference_batcher.close()
    actors.stop()
    actorpool_thread.join()
    for t in inference_threads:
        t.join()
def train(flags, rank=0, barrier=None, device="cuda:0", gossip_buffer=None):
    """Run one GALA agent's training loop.

    Args:
        flags: parsed command-line flags/namespace for this run.
        rank: index of this agent among the GALA agents (used for logging
            and passed through to `learn`).
        barrier: optional synchronization barrier; when given, all agents
            wait on it before training starts.
        device: CUDA device string used for both learner and actor models.
        gossip_buffer: passed through to `learn` — presumably the shared
            buffer used for gossip-based parameter mixing; verify against
            the `learn` implementation.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir)

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device(device)
        flags.actor_device = torch.device(device)
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static. We could make it dynamic, but that
    # requires a loss (and learning rate schedule) that's batch size
    # independent.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    model = Net(
        observation_shape=flags.observation_shape,
        hidden_size=flags.hidden_size,
        num_actions=flags.num_actions,
        use_lstm=flags.use_lstm,
    )
    model = model.to(device=flags.learner_device)

    actor_model = Net(
        observation_shape=flags.observation_shape,
        hidden_size=flags.hidden_size,
        num_actions=flags.num_actions,
        use_lstm=flags.use_lstm,
    )
    actor_model.to(device=flags.actor_device)

    # The ActorPool that will accept connections from actor clients.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        server_address=flags.address,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        # Surface actorpool exceptions instead of dying silently.
        try:
            actors.run()
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        # Linear decay from 1 to 0 over flags.total_steps environment steps.
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                flags,
                plogger,
                rank,
                gossip_buffer,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(inference_batcher, actor_model, flags),
        )
        for i in range(flags.num_inference_threads)
    ]

    # Synchronize GALA agents before starting training
    if barrier is not None:
        barrier.wait()
        logging.info("%s: barrier passed" % rank)

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()

    def checkpoint():
        # Persist model/optimizer/scheduler state to flags.checkpoint, if set.
        if flags.checkpoint:
            logging.info("Saving checkpoint to %s", flags.checkpoint)
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "flags": vars(flags),
                },
                flags.checkpoint,
            )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    # Main monitoring loop: learner threads update `stats` concurrently.
    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Let's stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    actors.stop()
    actorpool_thread.join()

    for t in learner_threads + inference_threads:
        t.join()

    # Trace and save the final model.
    trace_model(flags, model)