def test_batched_run(self, batch_size=10):
    batcher = libtorchbeast.DynamicBatcher(
        batch_dim=0,
        minimum_batch_size=batch_size,
        maximum_batch_size=batch_size,
    )
    inputs = [torch.full((1, 2, 3), i) for i in range(batch_size)]
    outputs = torch.ones(batch_size, 42, 3)

    def target(i):
        while batcher.size() < i:
            # Make sure thread i calls compute before thread i + 1.
            time.sleep(0.05)
        np.testing.assert_array_equal(batcher.compute(inputs[i]), outputs[i : i + 1])

    threads = []
    for i in range(batch_size):
        threads.append(
            threading.Thread(target=target, name=f"compute-thread-{i}", args=(i,))
        )
    for t in threads:
        t.start()

    batch = next(batcher)
    batched_inputs = batch.get_inputs()
    np.testing.assert_array_equal(batched_inputs, torch.cat(inputs))
    batch.set_outputs(outputs)

    for t in threads:
        t.join()
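# A standalone shape check (uses only torch; not part of the test suite)
# mirroring the concatenation the batcher performs along batch_dim=0 above:
# ten (1, 2, 3) inputs form one (10, 2, 3) batch, and thread i's result is
# the i-th slice along the batch dimension.
import torch

inputs = [torch.full((1, 2, 3), i) for i in range(10)]
batched = torch.cat(inputs)  # cat along dim 0, matching batch_dim=0
assert batched.shape == (10, 2, 3)
assert batched[3:4].equal(torch.full((1, 2, 3), 3))  # thread 3's slice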
def setUp(self):
    self.server_proc = subprocess.Popen(
        ["python", "tests/contiguous_arrays_env.py"]
    )

    server_address = ["unix:/tmp/contiguous_arrays_test"]
    self.learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=10,
        check_inputs=True,
    )
    self.inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=10,
        timeout_ms=100,
        check_outputs=True,
    )

    actor = libtorchbeast.ActorPool(
        unroll_length=1,
        learner_queue=self.learner_queue,
        inference_batcher=self.inference_batcher,
        env_server_addresses=server_address,
        initial_agent_state=(),
    )

    def run():
        actor.run()

    self.actor_thread = threading.Thread(target=run)
    self.actor_thread.start()

    self.target = np.arange(3 * 4 * 5)
    self.target = self.target.reshape(3, 4, 5)
    self.target = self.target.transpose(2, 1, 0)
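# The matching tearDown is not part of this listing. A plausible sketch
# (hypothetical; this body is an assumption, not the original code), using
# only the attributes created in setUp and the close() semantics the other
# tests rely on:
def tearDown(self):
    self.inference_batcher.close()  # Stops iteration inside the actor loop.
    self.learner_queue.close()
    self.actor_thread.join()
    self.server_proc.terminate()
    self.server_proc.wait()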
def test_check_outputs2(self):
    batcher = libtorchbeast.DynamicBatcher(
        batch_dim=2, minimum_batch_size=1, maximum_batch_size=1
    )
    inputs = torch.zeros(1, 2, 3)

    def target():
        batcher.compute(inputs)

    t = threading.Thread(target=target, name="compute-thread")
    t.start()

    batch = next(batcher)
    with self.assertRaisesRegex(
        ValueError,
        "Output shape must have the same batch dimension as the input batch size.",
    ):
        # Dimension two of the outputs differs from the size of the batch (3 != 1).
        batch.set_outputs(torch.ones(1, 42, 3))

    # Set correct outputs so the thread can join.
    batch.set_outputs(torch.ones(1, 1, 1))
    t.join()
def test_timeout(self):
    timeout_ms = 300
    batcher = libtorchbeast.DynamicBatcher(
        batch_dim=0,
        minimum_batch_size=5,
        maximum_batch_size=5,
        timeout_ms=timeout_ms,
    )
    inputs = torch.zeros(1, 2, 3)
    outputs = torch.ones(1, 42, 3)

    def compute_target():
        batcher.compute(inputs)

    compute_thread = threading.Thread(target=compute_target, name="compute-thread")
    compute_thread.start()

    start_waiting_time = time.time()
    # Wait until approximately timeout_ms.
    batch = next(batcher)
    waiting_time_ms = (time.time() - start_waiting_time) * 1000

    # Timeout has expired and the batch of size 1 (< minimum_batch_size)
    # has been consumed.
    batch.set_outputs(outputs)
    compute_thread.join()

    self.assertTrue(timeout_ms <= waiting_time_ms <= timeout_ms + timeout_ms / 10)
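# Worked out for the values above: with timeout_ms = 300, the assertion
# accepts 300 <= waiting_time_ms <= 300 + 300 / 10 = 330 milliseconds,
# i.e. a 10% slack for the consumer to wake up after the timeout fires.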
def test_many_consumers(
    self,
    minimum_batch_size=1,
    compute_thread_number=64,
    repeats=100,
    consume_thread_number=16,
):
    batcher = libtorchbeast.DynamicBatcher(
        batch_dim=0, minimum_batch_size=minimum_batch_size
    )
    lock = threading.Lock()
    total_batches_consumed = 0

    def compute_thread_target(i):
        for _ in range(repeats):
            inputs = torch.full((1, 2, 3), i)
            batcher.compute(inputs)

    def consume_thread_target():
        nonlocal total_batches_consumed
        for batch in batcher:
            inputs = batch.get_inputs()
            batch_size, *_ = inputs.shape
            batch.set_outputs(torch.ones_like(inputs))
            with lock:
                total_batches_consumed += batch_size

    compute_threads = []
    for i in range(compute_thread_number):
        compute_threads.append(
            threading.Thread(
                target=compute_thread_target, name=f"compute-thread-{i}", args=(i,)
            )
        )
    consume_threads = []
    for i in range(consume_thread_number):
        consume_threads.append(
            threading.Thread(target=consume_thread_target, name=f"consume-thread-{i}")
        )

    for t in compute_threads + consume_threads:
        t.start()
    for t in compute_threads:
        t.join()
    # Stop iteration in all consume_threads.
    batcher.close()
    for t in consume_threads:
        t.join()

    self.assertEqual(total_batches_consumed, compute_thread_number * repeats)
def test_dropped_batch(self):
    batcher = libtorchbeast.DynamicBatcher(
        batch_dim=0, minimum_batch_size=1, maximum_batch_size=1
    )
    inputs = torch.zeros(1, 2, 3)

    def target():
        with self.assertRaisesRegex(
            libtorchbeast.AsyncError, _BROKEN_PROMISE_MESSAGE
        ):
            batcher.compute(inputs)

    t = threading.Thread(target=target, name="compute-thread")
    t.start()
    next(batcher)  # Retrieves but doesn't keep the batch object.
    t.join()
def test_simple_run(self):
    batcher = libtorchbeast.DynamicBatcher(
        batch_dim=0, minimum_batch_size=1, maximum_batch_size=1
    )
    inputs = torch.zeros(1, 2, 3)
    outputs = torch.ones(1, 42, 3)

    def target():
        np.testing.assert_array_equal(batcher.compute(inputs), outputs)

    t = threading.Thread(target=target, name="compute-thread")
    t.start()

    batch = next(batcher)
    np.testing.assert_array_equal(batch.get_inputs(), inputs)
    batch.set_outputs(outputs)
    t.join()
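# The producer/consumer protocol this test exercises, distilled from the
# calls above:
#   producer thread:  outputs = batcher.compute(inputs)  # blocks
#   consumer thread:  batch = next(batcher)              # blocks until a batch forms
#                     inputs = batch.get_inputs()
#                     batch.set_outputs(outputs)         # unblocks the producers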
def test_multiple_set_outputs_calls(self):
    batcher = libtorchbeast.DynamicBatcher(
        batch_dim=0, minimum_batch_size=1, maximum_batch_size=1
    )
    inputs = torch.zeros(1, 2, 3)
    outputs = torch.ones(1, 42, 3)

    def target():
        batcher.compute(inputs)

    t = threading.Thread(target=target, name="compute-thread")
    t.start()

    batch = next(batcher)
    batch.set_outputs(outputs)
    with self.assertRaisesRegex(RuntimeError, "set_outputs called twice"):
        batch.set_outputs(outputs)
    t.join()
def test_check_outputs1(self):
    batcher = libtorchbeast.DynamicBatcher(
        batch_dim=2, minimum_batch_size=1, maximum_batch_size=1
    )
    inputs = torch.zeros(1, 2, 3)

    def target():
        batcher.compute(inputs)

    t = threading.Thread(target=target, name="compute-thread")
    t.start()

    batch = next(batcher)
    with self.assertRaisesRegex(ValueError, "output shape must have at least"):
        outputs = torch.ones(1)
        batch.set_outputs(outputs)

    # Set correct outputs so the thread can join.
    batch.set_outputs(torch.ones(1, 1, 1))
    t.join()
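# Taken together, test_check_outputs1 and test_check_outputs2 pin down the
# two validations set_outputs performs: each output tensor must have at
# least batch_dim + 1 dimensions, and its size along batch_dim must equal
# the input batch size.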
def setUp(self):
    self.server_proc = subprocess.Popen(
        ["python", "tests/core_agent_state_env.py"]
    )

    self.B = 2
    self.T = 3
    self.model = Net()

    server_address = ["unix:/tmp/core_agent_state_test"]
    self.learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=self.B,
        maximum_batch_size=self.B,
        check_inputs=True,
    )
    self.replay_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=1,
        timeout_ms=100,
        check_inputs=True,
        maximum_queue_size=1,
    )
    self.inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=1,
        timeout_ms=100,
        check_outputs=True,
    )
    self.actor = libtorchbeast.ActorPool(
        unroll_length=self.T,
        learner_queue=self.learner_queue,
        replay_queue=self.replay_queue,
        inference_batcher=self.inference_batcher,
        env_server_addresses=server_address,
        initial_agent_state=self.model.initial_state(),
    )
def train(flags):
    logging.info("Logging results to %s", flags.savedir)
    if isinstance(flags, omegaconf.DictConfig):
        flag_dict = omegaconf.OmegaConf.to_container(flags)
    else:
        flag_dict = vars(flags)
    plogger = file_writer.FileWriter(xp_args=flag_dict, rootdir=flags.savedir)

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        learner_device = torch.device(flags.learner_device)
        actor_device = torch.device(flags.actor_device)
    else:
        logging.info("Not using CUDA.")
        learner_device = torch.device("cpu")
        actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static. We could make it dynamic, but that
    # requires a loss (and learning rate schedule) that's batch size
    # independent.
    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    logging.info("Using model %s", flags.model)
    model = create_model(flags, learner_device)
    plogger.metadata["model_numel"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    logging.info("Number of model parameters: %i", plogger.metadata["model_numel"])
    actor_model = create_model(flags, actor_device)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=model.initial_state(),
    )

    def run():
        try:
            actors.run()
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        return (
            1
            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
            / flags.total_steps
        )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    if flags.checkpoint and os.path.exists(flags.checkpoint):
        logging.info("Loading checkpoint: %s", flags.checkpoint)
        checkpoint_states = torch.load(
            flags.checkpoint, map_location=flags.learner_device
        )
        model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                flags,
                plogger,
                learner_device,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(inference_batcher, actor_model, flags, actor_device),
        )
        for i in range(flags.num_inference_threads)
    ]

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()

    def checkpoint(checkpoint_path=None):
        if flags.checkpoint:
            if checkpoint_path is None:
                checkpoint_path = flags.checkpoint
            logging.info("Saving checkpoint to %s", checkpoint_path)
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "stats": stats,
                    "flags": vars(flags),
                },
                checkpoint_path,
            )

    # TODO: test this again, then uncomment (from deleted polyhydra code).
    # def receive_slurm_signal(signal_num=None, frame=None):
    #     logging.info("Received SIGTERM, checkpointing")
    #     make_checkpoint()
    # signal.signal(signal.SIGTERM, receive_slurm_signal)

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        train_start_time = timeit.default_timer()
        train_time_offset = stats.get("train_seconds", 0)  # Used for resuming training.
        last_checkpoint_time = timeit.default_timer()

        dev_checkpoint_intervals = [0, 0.25, 0.5, 0.75]

        loop_start_time = timeit.default_timer()
        loop_start_step = stats.get("step", 0)
        while True:
            if loop_start_step >= flags.total_steps:
                break
            time.sleep(5)
            loop_end_time = timeit.default_timer()
            loop_end_step = stats.get("step", 0)

            stats["train_seconds"] = round(
                loop_end_time - train_start_time + train_time_offset, 1
            )

            if loop_end_time - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = loop_end_time

            if len(dev_checkpoint_intervals) > 0:
                step_percentage = loop_end_step / flags.total_steps
                i = dev_checkpoint_intervals[0]
                if step_percentage > i:
                    checkpoint(flags.checkpoint[:-4] + "_" + str(i) + ".tar")
                    dev_checkpoint_intervals = dev_checkpoint_intervals[1:]

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                loop_end_step,
                (loop_end_step - loop_start_step) / (loop_end_time - loop_start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}" for key, value in stats.items()
                ),
            )
            loop_start_time = loop_end_time
            loop_start_step = loop_end_step
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Let's stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    actorpool_thread.join()

    for t in learner_threads + inference_threads:
        t.join()
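# A standalone sketch of inspecting a checkpoint file written by
# checkpoint() above. The path here is hypothetical; the keys are the ones
# the trainer saves.
import torch

states = torch.load("/path/to/model.tar", map_location="cpu")
print(sorted(states.keys()))
# ['flags', 'model_state_dict', 'optimizer_state_dict',
#  'scheduler_state_dict', 'stats']
print(states["stats"].get("step", 0))  # The resumable step counter.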
def train(flags):
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )

    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda:0")
        flags.actor_device = torch.device("cuda:1")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    model = model.to(device=flags.learner_device)

    actor_model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    actor_model.to(device=flags.actor_device)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        try:
            actors.run()
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        return (
            1
            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
            / flags.total_steps
        )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=flags.learner_device
        )
        model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(flags, inference_batcher, actor_model),
        )
        for i in range(flags.num_inference_threads)
    ]

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}" for key, value in stats.items()
                ),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    actorpool_thread.join()

    for t in learner_threads + inference_threads:
        t.join()
def run():
    flags = argparse.Namespace()
    flags, argv = polybeast_learner.parser.parse_known_args(namespace=flags)
    flags, argv = polybeast_env.parser.parse_known_args(args=argv, namespace=flags)
    if argv:
        # Produce an error message.
        polybeast_env.parser.print_usage()
        print("Unknown args:", " ".join(argv))
        return -1

    env_processes = []
    for actor_id in range(1):
        p = mp.Process(target=run_env, args=(flags, actor_id))
        p.start()
        env_processes.append(p)

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )
    replay_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=flags.num_actors,
        timeout_ms=100,
        check_inputs=True,
        maximum_queue_size=flags.num_actors,
    )
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        replay_queue=replay_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=(),
    )

    # Renamed from `run` to avoid shadowing the enclosing function's name.
    def run_actorpool():
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run_actorpool, name="actorpool-thread")

    def dequeue(queue):
        for tensor in queue:
            del tensor

    dequeue_threads = [
        threading.Thread(target=dequeue, name="dequeue-thread-%i" % i, args=(queue,))
        for i, queue in enumerate([learner_queue, replay_queue])
    ]

    # Create an environment to sample random actions.
    dataset_uses_color = flags.dataset not in ["mnist", "omniglot"]
    grayscale = dataset_uses_color and not flags.use_color
    is_color = flags.use_color or flags.env_type == "fluid"
    if is_color is False:
        grayscale = True
    else:
        grayscale = is_color and not dataset_uses_color
    env_name, config = utils.parse_flags(flags)
    env = utils.create_env(env_name, config, grayscale, dataset=None)
    if flags.condition:
        new_space = env.observation_space.spaces
        c, h, w = new_space["canvas"].shape
        new_space["canvas"] = spaces.Box(
            low=0, high=255, shape=(c * 2, h, w), dtype=np.uint8
        )
        env.observation_space = spaces.Dict(new_space)
    action_space = env.action_space
    env.close()

    def inference(inference_batcher, lock=threading.Lock()):
        nonlocal step
        for batch in inference_batcher:
            batched_env_outputs, agent_state = batch.get_inputs()
            obs, _, done, *_ = batched_env_outputs
            B = done.shape[1]
            with lock:
                step += B
            actions = nest.map(
                lambda i: action_space.sample(), [i for i in range(B)]
            )
            action = torch.from_numpy(np.concatenate(actions)).view(1, B, -1)
            outputs = ((action,), ())
            outputs = nest.map(lambda t: t, outputs)
            batch.set_outputs(outputs)

    lock = threading.Lock()
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(inference_batcher, lock),
        )
        for i in range(flags.num_inference_threads)
    ]

    # Initialize `step` before the inference threads that increment it start.
    step = 0

    actorpool_thread.start()
    threads = dequeue_threads + inference_threads
    for t in threads:
        t.start()

    try:
        while step < 10000:
            start_time = timeit.default_timer()
            start_step = step
            time.sleep(3)
            end_step = step
            logging.info(
                "Step %i @ %.1f SPS.",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
            )
    except KeyboardInterrupt:
        pass

    inference_batcher.close()
    learner_queue.close()
    replay_queue.close()
    actorpool_thread.join()
    for t in threads:
        t.join()
    for p in env_processes:
        p.terminate()
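# nest.map, used in the inference function above, applies a function to
# every leaf of a nested structure of tuples/lists and returns a structure
# of the same shape. A minimal sketch, assuming torchbeast's nest package:
import nest

doubled = nest.map(lambda x: 2 * x, ((1, 2), [3]))
assert doubled == ((2, 4), [6])  # Container structure is preserved.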
def train(flags):
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )

    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The queue where the actorpool stores final render image pairs.
    # A separate thread will load them into the ReplayBuffer.
    # The batch size of the pairs will be dynamic.
    replay_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.batch_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    dataset_is_gray = flags.dataset in ["mnist", "omniglot"]
    grayscale = not dataset_is_gray and not flags.use_color
    dataset_is_gray |= grayscale
    dataset = utils.create_dataset(flags.dataset, grayscale)
    env_name, config = utils.parse_flags(flags)
    env = utils.create_env(env_name, config, dataset_is_gray, dataset=None)
    if flags.condition:
        new_space = env.observation_space.spaces
        c, h, w = new_space["canvas"].shape
        new_space["canvas"] = spaces.Box(
            low=0, high=255, shape=(c * 2, h, w), dtype=np.uint8
        )
        env.observation_space = spaces.Dict(new_space)
    obs_shape = env.observation_space["canvas"].shape
    action_shape = env.action_space.nvec
    order = env.order
    env.close()

    model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    )
    model = model.to(device=flags.learner_device)

    actor_model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    ).eval()
    actor_model.to(device=flags.actor_device)

    if flags.condition:
        D = models.ComplementDiscriminator(obs_shape)
    else:
        D = models.Discriminator(obs_shape)
    D.to(device=flags.learner_device)

    if flags.condition:
        D_eval = models.ComplementDiscriminator(obs_shape)
    else:
        D_eval = models.Discriminator(obs_shape)
    D_eval = D_eval.to(device=flags.learner_device).eval()

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(
        D.parameters(), lr=flags.discriminator_learning_rate, betas=(0.5, 0.999)
    )

    def lr_lambda(epoch):
        return (
            1
            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
            / flags.total_steps
        )
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        replay_queue=replay_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    dataloader = DataLoader(
        dataset,
        batch_size=flags.batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=flags.learner_device
        )
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(checkpoint_states["D_optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(flags, inference_batcher, actor_model),
        )
        for i in range(flags.num_inference_threads)
    ]
    d_learner = threading.Thread(
        target=learn_D,
        name="d_learner-thread",
        args=(flags, dataloader, replay_queue, D, D_eval, D_optimizer, stats, plogger),
    )

    actorpool_thread.start()
    threads = learner_threads + inference_threads
    for t in inference_threads:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        # Wait for the replay queue to fill before starting the discriminator
        # and policy learners, discarding full learner batches in the meantime.
        while replay_queue.size() < flags.batch_size:
            if learner_queue.size() >= flags.batch_size:
                next(learner_queue)
            time.sleep(0.01)
        d_learner.start()
        for t in learner_threads:
            t.start()

        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}" for key, value in stats.items()
                ),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    replay_queue.close()
    actorpool_thread.join()
    d_learner.join()
    for t in threads:
        t.join()
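# All three trainers in this section shut down in the same order:
#   1. Close the batcher and queues, which stops iteration in the
#      inference/dequeue/learner threads.
#   2. Join the actorpool thread, which exits once its queues are closed.
#   3. Join the remaining worker threads.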