def test_initial_state(self):
    """The initial recurrent state has two elements, each of shape
    (1, batch_size, core_output_size)."""
    net = models.Net(
        obs_shape=self.obs_shape,
        order=self.order,
        action_shape=self.action_shape,
        grid_shape=self.grid_shape,
    )
    state = net.initial_state(self.batch_size)
    # The core state is a (hidden, cell) pair.
    self.assertEqual(len(state), 2)
    expected_shape = (1, self.batch_size, self.core_output_size)
    for element in state:
        self.assertSequenceEqual(element.shape, expected_shape)
def test_forward_return_signature(self):
    """A forward pass returns (action, policy_logits, baseline) plus a new
    core state, all with the expected shapes."""
    net = models.Net(
        obs_shape=self.obs_shape,
        order=self.order,
        action_shape=self.action_shape,
        grid_shape=self.grid_shape,
    )
    state = net.initial_state(self.batch_size)
    (action, policy_logits, baseline), state = net(*self.inputs, state)
    batch_and_time = (self.batch_size, self.unroll_length)
    self.assertSequenceEqual(action.shape,
                             batch_and_time + (len(self.action_shape),))
    # One logits tensor per sub-action, sized by that sub-action's arity.
    for logits, n_actions in zip(policy_logits, self.action_shape):
        self.assertSequenceEqual(logits.shape, batch_and_time + (n_actions,))
    self.assertSequenceEqual(baseline.shape, batch_and_time)
    for element in state:
        self.assertSequenceEqual(
            element.shape, (1, self.batch_size, self.core_output_size))
def test(flags):
    """Run a single evaluation episode with a trained agent.

    Rebuilds the painting environment from ``flags``, loads the policy
    network and the discriminator from the checkpoint, rolls out one episode
    of ``flags.episode_length`` steps, scores the produced canvases with the
    discriminator, and logs the final/episode reward.

    NOTE(review): `grid_width`, `frame_width`, `SHADERS_BASEDIR` and
    `BRUSHES_BASEDIR` are module-level globals not visible in this chunk.
    """
    # Resolve the checkpoint location; default to the "latest" directory.
    if flags.xpid is None:
        checkpointpath = "./latest/model.tar"
    else:
        checkpointpath = os.path.expandvars(
            os.path.expanduser("%s/%s/%s" %
                               (flags.savedir, flags.xpid, "model.tar")))
    # Base environment configuration.
    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )
    # Only the CelebA datasets use color.
    if flags.dataset == "celeba" or flags.dataset == "celeba-hq":
        use_color = True
    else:
        use_color = False
    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            ))
    # "-v1" environments add per-stroke penalties for compound actions.
    if flags.use_compound:
        env_name += "-v1"
        config.update(
            dict(
                new_stroke_penalty=flags.new_stroke_penalty,
                stroke_length_penalty=flags.stroke_length_penalty,
            ))
    else:
        env_name += "-v0"
    env = env_wrapper.make_raw(env_name, config)
    if frame_width != flags.canvas_width:
        env = env_wrapper.WarpFrame(env, height=frame_width, width=frame_width)
    env = env_wrapper.wrap_pytorch(env)
    env = env_wrapper.AddDim(env)
    obs_shape = env.observation_space.shape
    # Conditional models see the canvas stacked with the target image, so
    # the channel dimension is doubled.
    if flags.condition:
        c, h, w = obs_shape
        c *= 2
        obs_shape = (c, h, w)
    action_shape = env.action_space.nvec.tolist()
    order = env.order
    model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        model = models.Condition(model)
    model.eval()
    D = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D = models.Conditional(D)
    D.eval()
    checkpoint = torch.load(checkpointpath, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    D.load_state_dict(checkpoint["D_state_dict"])
    if flags.condition:
        # Sample one random target image from the dataset to condition on.
        from random import randrange
        c, h, w = obs_shape
        tsfm = transforms.Compose(
            [transforms.Resize((h, w)), transforms.ToTensor()])
        dataset = flags.dataset
        if dataset == "mnist":
            dataset = MNIST(root="./", train=True, transform=tsfm,
                            download=True)
        elif dataset == "omniglot":
            dataset = Omniglot(root="./", background=True, transform=tsfm,
                               download=True)
        elif dataset == "celeba":
            dataset = CelebA(
                root="./",
                split="train",
                target_type=None,
                transform=tsfm,
                download=True,
            )
        elif dataset == "celeba-hq":
            dataset = datasets.CelebAHQ(root="./", split="train",
                                        transform=tsfm, download=True)
        else:
            raise NotImplementedError
        # NOTE(review): `obs_shape` here carries the doubled channel count
        # while the sampled image has the dataset's native channels — confirm
        # the view below is shape-compatible.
        condition = dataset[randrange(len(dataset))].view((1, 1) + obs_shape)
    else:
        condition = None
    frame = env.reset()
    action = model.initial_action()
    agent_state = model.initial_state()
    done = torch.tensor(False).view(1, 1)
    rewards = []
    frames = [frame]
    # Roll out one episode: feed back the previous action and a fresh noise
    # sample at every step.
    for i in range(flags.episode_length - 1):
        if flags.mode == "test_render":
            env.render()
        noise = torch.randn(1, 1, 10)
        agent_outputs, agent_state = model(
            dict(
                obs=frame,
                condition=condition,
                action=action,
                noise=noise,
                done=done,
            ),
            agent_state,
        )
        action, *_ = agent_outputs
        frame, reward, done, _ = env.step(action)
        rewards.append(reward)
        frames.append(frame)
    reward = torch.cat(rewards)
    frame = torch.cat(frames)
    # With temporal credit assignment (TCA) the discriminator scores every
    # intermediate canvas; otherwise only the final one.
    if flags.use_tca:
        frame = torch.flatten(frame, 0, 1)
        if flags.condition:
            condition = torch.flatten(condition, 0, 1)
    else:
        frame = frame[-1]
        if flags.condition:
            condition = condition[-1]
    D = D.eval()
    with torch.no_grad():
        if flags.condition:
            p = D(frame, condition).view(-1, 1)
        else:
            p = D(frame).view(-1, 1)
    if flags.use_tca:
        # TCA reward shaping: difference of consecutive discriminator scores.
        d_reward = p[1:] - p[:-1]
        reward = reward[1:] + d_reward
    else:
        # Only the last step receives the discriminator score.
        reward[-1] = reward[-1] + p
        reward = reward[1:]
    # empty condition
    condition = None
    logging.info(
        "Episode ended after %d steps. Final reward: %.4f. Episode reward: %.4f,",
        flags.episode_length,
        reward[-1].item(),
        reward.sum(),
    )
    env.close()
def train(flags):
    """Run the full asynchronous training loop (actorpool variant).

    Builds the environment spec, the policy network (learner + actor copies)
    and the discriminator (trained + frozen-eval copies), wires up the
    batching queues and all worker threads (actor pool, policy learners,
    inference, discriminator learners, dataloader), then supervises training
    until ``flags.total_steps``, checkpointing every 10 minutes.

    NOTE(review): `grid_width`, `frame_width`, `SHADERS_BASEDIR`,
    `BRUSHES_BASEDIR`, `learn`, `learn_D`, `inference` and `data_loader` are
    module-level names not visible in this chunk.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )
    # Hand-off queues for the discriminator learners and the image loader.
    d_queue = Queue(maxsize=flags.max_learner_queue_size // flags.batch_size)
    image_queue = Queue(maxsize=flags.max_learner_queue_size)

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # One pipe address per actor connection.
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    # Environment configuration, used only to read out shapes and orders.
    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )
    # Only the CelebA datasets use color.
    if flags.dataset == "celeba" or flags.dataset == "celeba-hq":
        use_color = True
    else:
        use_color = False
    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            ))
    if flags.use_compound:
        env_name += "-v1"
    else:
        env_name += "-v0"

    # Throwaway env: only needed for observation/action specs.
    env = env_wrapper.make_raw(env_name, config)
    if frame_width != flags.canvas_width:
        env = env_wrapper.WarpFrame(env, height=frame_width, width=frame_width)
    env = env_wrapper.wrap_pytorch(env)
    obs_shape = env.observation_space.shape
    if flags.condition:
        # Conditional models see canvas and target stacked channel-wise.
        c, h, w = obs_shape
        c *= 2
        obs_shape = (c, h, w)
    action_shape = env.action_space.nvec.tolist()
    order = env.order
    env.close()

    # Learner copy of the policy network.
    model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        model = models.Condition(model)
    model = model.to(device=flags.learner_device)

    # Actor copy of the policy network, used by the inference threads.
    actor_model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        actor_model = models.Condition(actor_model)
    actor_model.to(device=flags.actor_device)

    # Discriminator: trained copy (D) and a frozen evaluation copy (D_eval).
    D = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D = models.Conditional(D)
    D.to(device=flags.learner_device)
    D_eval = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D_eval = models.Conditional(D_eval)
    D_eval = D_eval.to(device=flags.learner_device)

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=flags.discriminator_learning_rate,
                             betas=(0.5, 0.999))

    def lr_lambda(epoch):
        # Linear decay to zero over flags.total_steps environment steps.
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    D_scheduler = torch.optim.lr_scheduler.LambdaLR(D_optimizer, lr_lambda)

    C, H, W = obs_shape
    if flags.condition:
        # The raw canvas has half the doubled channel count.
        C //= 2

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_action=actor_model.initial_action(),
        initial_agent_state=actor_model.initial_state(),
        image=torch.zeros(1, 1, C, H, W),
    )

    def run():
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    # Target-image dataset feeding the discriminator.
    c, h, w = obs_shape
    tsfm = transforms.Compose(
        [transforms.Resize((h, w)), transforms.ToTensor()])
    dataset = flags.dataset
    if dataset == "mnist":
        dataset = MNIST(root="./", train=True, transform=tsfm, download=True)
    elif dataset == "omniglot":
        dataset = Omniglot(root="./", background=True, transform=tsfm,
                           download=True)
    elif dataset == "celeba":
        dataset = CelebA(root="./", split="train", target_type=None,
                         transform=tsfm, download=True)
    elif dataset == "celeba-hq":
        dataset = datasets.CelebAHQ(root="./", split="train", transform=tsfm,
                                    download=True)
    else:
        raise NotImplementedError
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=True,
                            drop_last=True,
                            pin_memory=True)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        # BUG FIX: the two scheduler states were previously loaded swapped
        # (policy scheduler from "D_scheduler_state_dict" and vice versa),
        # which contradicts the keys written by checkpoint() below.
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        D_scheduler.load_state_dict(
            checkpoint_states["D_scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                d_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
                image_queue,
            ),
        ) for i in range(flags.num_inference_threads)
    ]
    d_learner = [
        threading.Thread(
            target=learn_D,
            name="d_learner-thread-%i" % i,
            args=(
                flags,
                d_queue,
                D,
                D_eval,
                D_optimizer,
                D_scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    # Discriminator learners and the dataloader run as daemons so they do
    # not block interpreter shutdown.
    for thread in d_learner:
        thread.daemon = True
    dataloader_thread = threading.Thread(target=data_loader,
                                         args=(
                                             flags,
                                             dataloader,
                                             image_queue,
                                         ))
    dataloader_thread.daemon = True

    actorpool_thread.start()
    threads = learner_threads + inference_threads
    daemons = d_learner + [dataloader_thread]
    for t in threads + daemons:
        t.start()

    def checkpoint():
        """Persist models, optimizers, schedulers, stats and flags."""
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "D_scheduler_state_dict": D_scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])

    checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    actorpool_thread.join()
    for t in threads:
        t.join()
def setUp(self):
    """Build models, optimizers, and mocked queues/flags used by the
    `learn` and `learn_D` tests.

    Prepares `self.learn_args` and `self.learn_D_args` as the positional
    argument tuples those functions expect.
    """
    unroll_length = 3  # Inference called for every step.
    batch_size = 4  # Arbitrary.
    frame_dimension = 64  # Has to match what expected by the model.
    order = ["control", "end", "flag", "size", "pressure"]
    action_shape = [1024, 1024, 2, 8, 10]
    num_channels = 2  # Has to match with the first conv layer of the net.
    grid_shape = [32, 32]  # Specific to each environment.
    obs_shape = [num_channels, frame_dimension, frame_dimension]
    # The following hyperparamaters are arbitrary.
    self.lr = 0.1
    total_steps = 100000
    # Set the random seed manually to get reproducible results.
    torch.manual_seed(0)
    # Learner and actor copies of the policy network; initial state dicts
    # are saved so tests can assert which copy was updated.
    self.model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=grid_shape,
    )
    self.actor_model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=grid_shape,
    )
    self.initial_model_dict = copy.deepcopy(self.model.state_dict())
    self.initial_actor_model_dict = copy.deepcopy(
        self.actor_model.state_dict())
    optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
    # Discriminator (trained) and its evaluation copy.
    self.D = models.ComplementDiscriminator(obs_shape, spectral_norm=False)
    self.D_eval = models.ComplementDiscriminator(
        obs_shape, spectral_norm=False).eval()
    self.initial_D_dict = copy.deepcopy(self.D.state_dict())
    self.initial_D_eval_dict = copy.deepcopy(self.D_eval.state_dict())
    D_optimizer = torch.optim.SGD(self.D.parameters(), lr=self.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=total_steps // 10)
    D_scheduler = torch.optim.lr_scheduler.StepLR(D_optimizer,
                                                  step_size=total_steps // 10)
    self.stats = {}
    # The call to plogger.log will not perform any action.
    plogger = mock.Mock()
    plogger.log = mock.Mock()
    # Mock flags.
    mock_flags = mock.Mock()
    mock_flags.learner_device = torch.device("cpu")
    mock_flags.discounting = 0.99  # Default value from cmd.
    mock_flags.baseline_cost = 0.5  # Default value from cmd.
    mock_flags.entropy_cost = 0.0006  # Default value from cmd.
    mock_flags.unroll_length = unroll_length - 1
    mock_flags.batch_size = batch_size
    mock_flags.grad_norm_clipping = 40
    mock_flags.use_tca = True
    mock_flags.condition = True
    # Prepare content for mock_learner_queue.
    obs = dict(
        canvas=torch.ones([unroll_length, batch_size] + obs_shape),
        prev_action=torch.ones(unroll_length, batch_size, len(action_shape)),
        action_mask=torch.ones(unroll_length, batch_size, len(action_shape)),
        noise_sample=torch.ones(unroll_length, batch_size, 10),
    )
    rewards = torch.ones(unroll_length, batch_size)
    done = torch.zeros(unroll_length, batch_size, dtype=torch.bool)
    episode_step = torch.ones(unroll_length, batch_size)
    episode_return = torch.ones(unroll_length, batch_size)
    # `new_frame` is one step shorter than the unroll.
    new_frame = dict(
        canvas=torch.ones([unroll_length - 1, batch_size] + obs_shape),
        prev_action=torch.ones(unroll_length - 1, batch_size,
                               len(action_shape)),
        action_mask=torch.ones(unroll_length - 1, batch_size,
                               len(action_shape)),
        noise_sample=torch.ones(unroll_length - 1, batch_size, 10),
    )
    env_outputs = (obs, rewards, done, episode_step, episode_return)
    actor_outputs = (
        # Actions taken.
        torch.cat(
            list(
                map(
                    lambda num_actions: torch.randint(
                        low=0,
                        high=num_actions,
                        size=(unroll_length, batch_size, 1)),
                    action_shape,
                )),
            dim=-1,
        ),
        # Logits.
        list(
            map(
                lambda num_actions: torch.randn(unroll_length, batch_size,
                                                num_actions),
                action_shape,
            )),
        # Baseline.
        torch.rand(unroll_length, batch_size),
    )
    initial_agent_state = self.model.initial_state(batch_size)
    tensors = ((env_outputs, actor_outputs), new_frame, initial_agent_state)
    # Mock learner_queue.
    mock_learner_queue = mock.MagicMock()
    mock_learner_queue.__iter__.return_value = iter([tensors])
    self.learn_args = (
        mock_flags,
        mock_learner_queue,
        self.model,
        self.actor_model,
        self.D_eval,
        optimizer,
        scheduler,
        self.stats,
        plogger,
    )
    # Mock replay_queue.
    mock_replay_queue = mock.MagicMock()
    mock_replay_queue.is_closed.return_value = True
    # Mock dataloader.
    mock_dataloader = mock.MagicMock()
    mock_dataloader.__iter__.return_value = iter([(
        torch.ones(batch_size, 1, frame_dimension, frame_dimension),
        None,
    )])
    # Mock replay_buffer.
    mock_replay_buffer = mock.MagicMock()
    mock_replay_buffer.sample.return_value = torch.ones([batch_size] +
                                                        obs_shape)
    self.learn_D_args = (
        mock_flags,
        mock_dataloader,
        mock_replay_queue,
        mock_replay_buffer,
        self.D,
        self.D_eval,
        D_optimizer,
        D_scheduler,
        self.stats,
        plogger,
    )
def train(flags):
    """Run the asynchronous training loop (libtorchbeast variant).

    Builds the environment spec via `utils`, the policy network (learner +
    actor copies) and the discriminator (trained + eval copies), wires up the
    learner/replay/inference queues and all worker threads, then supervises
    training until ``flags.total_steps``, checkpointing every 10 minutes.

    NOTE(review): `grid_width`, `learn`, `learn_D` and `inference` are
    module-level names not visible in this chunk.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")
    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size
    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )
    # The queue the actorpool stores final render image pairs.
    # A separate thread will load them to the ReplayBuffer.
    # The batch size of the pairs will be dynamic.
    replay_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.batch_size,
    )
    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )
    # One pipe address per actor connection.
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1
    # Grayscale handling: some datasets are inherently gray; color datasets
    # are converted when use_color is off.
    dataset_is_gray = flags.dataset in ["mnist", "omniglot"]
    grayscale = not dataset_is_gray and not flags.use_color
    dataset_is_gray |= grayscale
    dataset = utils.create_dataset(flags.dataset, grayscale)
    env_name, config = utils.parse_flags(flags)
    # Throwaway env: only needed for observation/action specs.
    env = utils.create_env(env_name, config, dataset_is_gray, dataset=None)
    if flags.condition:
        # Conditional models see canvas + target stacked channel-wise, so
        # double the canvas channel count in the observation space.
        new_space = env.observation_space.spaces
        c, h, w = new_space["canvas"].shape
        new_space["canvas"] = spaces.Box(low=0,
                                         high=255,
                                         shape=(c * 2, h, w),
                                         dtype=np.uint8)
        env.observation_space = spaces.Dict(new_space)
    obs_shape = env.observation_space["canvas"].shape
    action_shape = env.action_space.nvec
    order = env.order
    env.close()
    # Learner copy of the policy network.
    model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    )
    model = model.to(device=flags.learner_device)
    # Actor copy used by the inference threads, kept in eval mode.
    actor_model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    ).eval()
    actor_model.to(device=flags.actor_device)
    # Discriminator: trained copy (D) and frozen evaluation copy (D_eval).
    if flags.condition:
        D = models.ComplementDiscriminator(obs_shape)
    else:
        D = models.Discriminator(obs_shape)
    D.to(device=flags.learner_device)
    if flags.condition:
        D_eval = models.ComplementDiscriminator(obs_shape)
    else:
        D_eval = models.Discriminator(obs_shape)
    D_eval = D_eval.to(device=flags.learner_device).eval()
    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=flags.discriminator_learning_rate,
                             betas=(0.5, 0.999))

    def lr_lambda(epoch):
        # Linear decay to zero over flags.total_steps environment steps.
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    # The ActorPool that will run `flags.num_actors` many loops.
    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        replay_queue=replay_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")
    dataloader = DataLoader(
        dataset,
        batch_size=flags.batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )
    stats = {}
    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")
    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())
    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
            ),
        ) for i in range(flags.num_inference_threads)
    ]
    # Single discriminator-learner thread fed by the replay queue.
    d_learner = threading.Thread(
        target=learn_D,
        name="d_learner-thread",
        args=(
            flags,
            dataloader,
            replay_queue,
            D,
            D_eval,
            D_optimizer,
            stats,
            plogger,
        ),
    )
    actorpool_thread.start()
    threads = learner_threads + inference_threads
    # Inference threads start now; learner threads are deferred until the
    # replay queue has been primed (see the try block below).
    for t in inference_threads:
        t.start()

    def checkpoint():
        # Persist models, optimizers, scheduler, stats and flags.
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        # Warm-up: discard learner batches until the replay queue holds one
        # full batch, so learn_D has data before the learners start.
        while replay_queue.size() < flags.batch_size:
            if learner_queue.size() >= flags.batch_size:
                next(learner_queue)
            time.sleep(0.01)
        d_learner.start()
        for t in learner_threads:
            t.start()
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)
            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()
            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
    checkpoint()
    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()
    replay_queue.close()
    actorpool_thread.join()
    d_learner.join()
    for t in threads:
        t.join()
def _test_inference(self, use_color, device):
    """Shared helper asserting `polybeast.inference` output shapes/devices.

    NOTE(review): `use_color` is unused in this body — confirm whether
    callers rely on the parameter before removing it.
    """
    model = models.Net(
        obs_shape=self.obs_shape,
        order=self.order,
        action_shape=self.action_shape,
        grid_shape=self.grid_shape,
    )
    model.to(device)
    agent_state = model.initial_state(self.batch_size)
    # NOTE(review): `self.episode_return` appears twice below; elsewhere in
    # this file the env-output tuple is (obs, rewards, done, episode_step,
    # episode_return), so the first occurrence likely should be
    # `self.episode_step` — confirm against the fixture before changing.
    inputs = (
        (
            self.obs,
            self.rewards,
            self.done,
            self.episode_return,
            self.episode_return,
        ),
        agent_state,
    )
    # Set the behaviour of the methods of the mock batch.
    self.mock_batch.get_inputs = mock.Mock(return_value=inputs)
    self.mock_batch.set_outputs = mock.Mock()
    # Preparing the mock flags. Could do with just a dict but using
    # a Mock object for consistency.
    mock_flags = mock.Mock()
    mock_flags.actor_device = device
    polybeast.inference(mock_flags, self.mock_inference_batcher, model)
    # Assert the batch is used only once.
    self.mock_batch.get_inputs.assert_called_once()
    self.mock_batch.set_outputs.assert_called_once()
    # Check that set_outputs has been called with paramaters with the expected shape.
    batch_args, batch_kwargs = self.mock_batch.set_outputs.call_args
    self.assertEqual(batch_kwargs, {})
    model_outputs, *other_args = batch_args
    self.assertEqual(other_args, [])
    (action, policy_logits, baseline), core_state = model_outputs
    self.assertSequenceEqual(
        action.shape,
        (self.unroll_length, self.batch_size, len(self.action_shape)))
    for logits, num_actions in zip(policy_logits, self.action_shape):
        self.assertSequenceEqual(
            logits.shape, (self.unroll_length, self.batch_size, num_actions))
    self.assertSequenceEqual(baseline.shape,
                             (self.unroll_length, self.batch_size))
    # Inference outputs must be returned on CPU regardless of model device.
    for tensor in (action, baseline) + core_state:
        self.assertEqual(tensor.device, torch.device("cpu"))
    for tensor in policy_logits:
        self.assertEqual(tensor.device, torch.device("cpu"))
    self.assertEqual(len(core_state), 2)
    for core_state_element in core_state:
        self.assertSequenceEqual(
            core_state_element.shape,
            (1, self.batch_size, self.core_output_size),
        )