Exemple #1
0
    def test_bad_construct(self):
        """The constructor must reject invalid min/max batch-size combos."""
        # A minimum below 1 is meaningless.
        with self.assertRaisesRegex(ValueError, "Min batch size must be >= 1"):
            actorpool.BatchingQueue(
                batch_dim=3, minimum_batch_size=0, maximum_batch_size=1)

        # The maximum may never undercut the minimum.
        with self.assertRaisesRegex(
                ValueError, "Max batch size must be >= min batch size"):
            actorpool.BatchingQueue(
                batch_dim=3, minimum_batch_size=1, maximum_batch_size=0)
    def setUp(self):
        """Launch the env server, build queues/batcher, and start the actor."""
        # The environment server runs as a separate Python process.
        self.server_proc = subprocess.Popen(
            ["python", "tests/contiguous_arrays_env.py"])

        addresses = ["unix:/tmp/contiguous_arrays_test"]
        self.learner_queue = actorpool.BatchingQueue(
            batch_dim=1,
            minimum_batch_size=1,
            maximum_batch_size=10,
            check_inputs=True,
        )
        self.inference_batcher = actorpool.DynamicBatcher(
            batch_dim=1,
            minimum_batch_size=1,
            maximum_batch_size=10,
            timeout_ms=100,
            check_outputs=True,
        )
        actor = actorpool.ActorPool(
            unroll_length=1,
            learner_queue=self.learner_queue,
            inference_batcher=self.inference_batcher,
            env_server_addresses=addresses,
            initial_agent_state=(),
        )

        # Run the actor loop off the main thread.
        self.actor_thread = threading.Thread(target=actor.run)
        self.actor_thread.start()

        # Expected observation: a transposed 3x4x5 arange tensor.
        self.target = np.arange(3 * 4 * 5).reshape(3, 4, 5).transpose(2, 1, 0)
Exemple #3
0
    def test_batched_run(self, batch_size=2):
        """A full batch is yielded only after `batch_size` enqueues."""
        queue = actorpool.BatchingQueue(batch_dim=0,
                                        minimum_batch_size=batch_size,
                                        maximum_batch_size=batch_size)

        inputs = [torch.full((1, 2, 3), i) for i in range(batch_size)]

        def producer(index):
            # Spin until all lower-indexed threads have enqueued, so that
            # enqueue order (and hence batch order) is deterministic.
            while queue.size() < index:
                time.sleep(0.05)
            queue.enqueue(inputs[index])

        workers = [
            threading.Thread(target=producer,
                             name=f"enqueue-thread-{i}",
                             args=(i, ))
            for i in range(batch_size)
        ]

        for worker in workers:
            worker.start()

        # next() blocks until the batch is complete.
        batch = next(queue)
        np.testing.assert_array_equal(batch, torch.cat(inputs))

        for worker in workers:
            worker.join()
Exemple #4
0
    def setUp(self):
        """Spawn the agent-state env server and wire up queues and actors."""
        self.server_proc = subprocess.Popen(
            ["python", "tests/core_agent_state_env.py"])

        self.B = 2  # batch size
        self.T = 3  # unroll length
        self.model = Net()
        addresses = ["unix:/tmp/core_agent_state_test"]
        # Static learner batch size: min == max == B.
        self.learner_queue = actorpool.BatchingQueue(
            batch_dim=1,
            minimum_batch_size=self.B,
            maximum_batch_size=self.B,
            check_inputs=True,
        )
        self.inference_batcher = actorpool.DynamicBatcher(
            batch_dim=1,
            minimum_batch_size=1,
            maximum_batch_size=1,
            timeout_ms=100,
            check_outputs=True,
        )
        self.actor = actorpool.ActorPool(
            unroll_length=self.T,
            learner_queue=self.learner_queue,
            inference_batcher=self.inference_batcher,
            env_server_addresses=addresses,
            initial_agent_state=self.model.initial_state(),
        )
Exemple #5
0
    def test_simple_run(self):
        """A single enqueued tensor comes back unchanged as a batch of one."""
        queue = actorpool.BatchingQueue(batch_dim=0,
                                        minimum_batch_size=1,
                                        maximum_batch_size=1)

        tensor = torch.zeros(1, 2, 3)
        queue.enqueue(tensor)
        np.testing.assert_array_equal(next(queue), tensor)
Exemple #6
0
 def test_check_inputs(self):
     """enqueue() validates tensor rank, non-emptiness, and queue state."""
     queue = actorpool.BatchingQueue(batch_dim=2)
     with self.assertRaisesRegex(
             ValueError,
             "Enqueued tensors must have more than batch_dim =="):
         # A 1-dim tensor cannot be batched along batch_dim=2.
         queue.enqueue(torch.ones(5))
     with self.assertRaisesRegex(ValueError,
                                 "Cannot enqueue empty vector of tensors"):
         queue.enqueue([])
     # close() moved out of the context manager: the assertRaisesRegex
     # block should contain only the call expected to raise, so an
     # unexpected failure in close() is not misattributed.
     queue.close()
     with self.assertRaisesRegex(actorpool.ClosedBatchingQueue,
                                 "Enqueue to closed queue"):
         queue.enqueue(torch.ones(1, 1, 1))
Exemple #7
0
    def test_many_consumers(self,
                            enqueue_threads_number=16,
                            repeats=100,
                            dequeue_threads_number=64):
        """Many producers/consumers: every enqueued row must be consumed."""
        queue = actorpool.BatchingQueue(batch_dim=0)

        lock = threading.Lock()
        total_batches_consumed = 0

        def producer(i):
            for _ in range(repeats):
                queue.enqueue(torch.full((1, 2, 3), i))

        def consumer():
            nonlocal total_batches_consumed
            # Iterating the queue blocks until it is closed and drained.
            for batch in queue:
                with lock:
                    total_batches_consumed += batch.shape[0]

        producers = [
            threading.Thread(target=producer,
                             name=f"enqueue-thread-{i}",
                             args=(i, ))
            for i in range(enqueue_threads_number)
        ]
        consumers = [
            threading.Thread(target=consumer, name=f"dequeue-thread-{i}")
            for i in range(dequeue_threads_number)
        ]

        for thread in producers + consumers:
            thread.start()

        for thread in producers:
            thread.join()

        # All producers done; closing lets consumer iteration terminate.
        queue.close()

        for thread in consumers:
            thread.join()

        self.assertEqual(total_batches_consumed,
                         repeats * enqueue_threads_number)
Exemple #8
0
def train(flags):
    """Run the full training loop.

    Starts an actorpool thread plus learner and inference threads, then
    monitors progress, checkpointing every 10 minutes, until
    ``flags.total_steps`` is reached or the user interrupts.
    """
    if flags.xpid is None:
        # Default experiment id: unique per run via timestamp.
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        # Learner and actor models are placed on separate GPUs.
        flags.learner_device = torch.device("cuda:0")
        flags.actor_device = torch.device("cuda:1")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # Assign pipe addresses round-robin until we have one per actor.
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    model = model.to(device=flags.learner_device)

    actor_model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    actor_model.to(device=flags.actor_device)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        # Wrapper so exceptions in the actorpool thread are logged
        # before being re-raised into the thread.
        try:
            actors.run()
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        # Linear LR decay from 1 to 0 over `flags.total_steps` env steps.
        return (
            1
            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
            / flags.total_steps
        )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=flags.learner_device
        )
        model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(flags, inference_batcher, actor_model),
        )
        for i in range(flags.num_inference_threads)
    ]

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()

    def checkpoint():
        # Persist model, optimizer, scheduler and stats for resumption.
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            # Sample throughput over a 5-second window.
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}" for key, value in stats.items()
                ),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()

    actorpool_thread.join()

    for t in learner_threads + inference_threads:
        t.join()
Exemple #9
0
def train(flags):
    """Run the painting-agent training loop with a GAN discriminator.

    Starts an actorpool thread, learner/inference threads, daemon
    discriminator-learner and dataloader threads; checkpoints every 10
    minutes until ``flags.total_steps`` is reached or interrupted.

    Fix: checkpoint restore previously swapped the policy scheduler and
    discriminator scheduler state dicts (loading "D_scheduler_state_dict"
    into `scheduler` and vice versa), corrupting both LR schedules on
    resume. The keys now match the ones written by `checkpoint()` below.
    """
    if flags.xpid is None:
        # Default experiment id: unique per run via timestamp.
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # Queues feeding the discriminator learner and the dataloader images.
    d_queue = Queue(maxsize=flags.max_learner_queue_size // flags.batch_size)
    image_queue = Queue(maxsize=flags.max_learner_queue_size)

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # Assign pipe addresses round-robin until we have one per actor.
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    # Environment configuration for the painting environment.
    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )

    if flags.dataset == "celeba" or flags.dataset == "celeba-hq":
        use_color = True
    else:
        use_color = False

    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            ))

    if flags.use_compound:
        env_name += "-v1"
    else:
        env_name += "-v0"

    # Build one env just to read observation/action spaces, then close it.
    env = env_wrapper.make_raw(env_name, config)
    if frame_width != flags.canvas_width:
        env = env_wrapper.WarpFrame(env, height=frame_width, width=frame_width)
    env = env_wrapper.wrap_pytorch(env)

    obs_shape = env.observation_space.shape
    if flags.condition:
        # Conditioned runs stack the target image on the channel dim.
        c, h, w = obs_shape
        c *= 2
        obs_shape = (c, h, w)

    action_shape = env.action_space.nvec.tolist()
    order = env.order
    env.close()

    model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        model = models.Condition(model)
    model = model.to(device=flags.learner_device)

    actor_model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        actor_model = models.Condition(actor_model)
    actor_model.to(device=flags.actor_device)

    # Discriminator trained by learn_D; D_eval is its frozen copy used
    # by the policy learner.
    D = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D = models.Conditional(D)
    D.to(device=flags.learner_device)

    D_eval = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D_eval = models.Conditional(D_eval)
    D_eval = D_eval.to(device=flags.learner_device)

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=flags.discriminator_learning_rate,
                             betas=(0.5, 0.999))

    def lr_lambda(epoch):
        # Linear LR decay from 1 to 0 over `flags.total_steps` env steps.
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    D_scheduler = torch.optim.lr_scheduler.LambdaLR(D_optimizer, lr_lambda)

    C, H, W = obs_shape
    if flags.condition:
        C //= 2
    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_action=actor_model.initial_action(),
        initial_agent_state=actor_model.initial_state(),
        image=torch.zeros(1, 1, C, H, W),
    )

    def run():
        # Wrapper so exceptions in the actorpool thread are logged
        # before being re-raised into the thread.
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    c, h, w = obs_shape
    tsfm = transforms.Compose(
        [transforms.Resize((h, w)),
         transforms.ToTensor()])

    dataset = flags.dataset

    if dataset == "mnist":
        dataset = MNIST(root="./", train=True, transform=tsfm, download=True)
    elif dataset == "omniglot":
        dataset = Omniglot(root="./",
                           background=True,
                           transform=tsfm,
                           download=True)
    elif dataset == "celeba":
        dataset = CelebA(root="./",
                         split="train",
                         target_type=None,
                         transform=tsfm,
                         download=True)
    elif dataset == "celeba-hq":
        dataset = datasets.CelebAHQ(root="./",
                                    split="train",
                                    transform=tsfm,
                                    download=True)
    else:
        raise NotImplementedError

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=True,
                            drop_last=True,
                            pin_memory=True)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        # Bug fix: these two lines previously had their keys swapped,
        # loading the discriminator's schedule into the policy scheduler
        # and vice versa. Keys now match checkpoint() below.
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        D_scheduler.load_state_dict(
            checkpoint_states["D_scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                d_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
                image_queue,
            ),
        ) for i in range(flags.num_inference_threads)
    ]

    d_learner = [
        threading.Thread(
            target=learn_D,
            name="d_learner-thread-%i" % i,
            args=(
                flags,
                d_queue,
                D,
                D_eval,
                D_optimizer,
                D_scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    # Discriminator learners and the dataloader are daemons: they block
    # on their queues and are never joined; they die with the process.
    for thread in d_learner:
        thread.daemon = True

    dataloader_thread = threading.Thread(target=data_loader,
                                         args=(
                                             flags,
                                             dataloader,
                                             image_queue,
                                         ))
    dataloader_thread.daemon = True

    actorpool_thread.start()

    threads = learner_threads + inference_threads
    daemons = d_learner + [dataloader_thread]

    for t in threads + daemons:
        t.start()

    def checkpoint():
        # Persist all models, optimizers, schedulers and stats.
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "D_scheduler_state_dict": D_scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            # Sample throughput over a 5-second window.
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()

    actorpool_thread.join()

    for t in threads:
        t.join()
Exemple #10
0
 def test_multiple_close_calls(self):
     """A second close() on an already-closed queue must fail loudly."""
     q = actorpool.BatchingQueue()
     q.close()
     with self.assertRaisesRegex(RuntimeError, "Queue was closed already"):
         q.close()
Exemple #11
0
def train(flags, rank=0, barrier=None, device="cuda:0", gossip_buffer=None):
    """Run the training loop for one GALA agent.

    Starts an actorpool thread plus learner and inference threads,
    optionally synchronizing with peer agents via `barrier` and sharing
    parameters through `gossip_buffer`; checkpoints every 10 minutes,
    and traces/saves the final model on completion.
    """
    if flags.xpid is None:
        # Default experiment id: unique per run via timestamp.
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        # Both models share the per-agent device passed by the caller.
        flags.learner_device = torch.device(device)
        flags.actor_device = torch.device(device)
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static. We could make it dynamic, but that
    # requires a loss (and learning rate schedule) that's batch size
    # independent.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    model = Net(
        observation_shape=flags.observation_shape,
        hidden_size=flags.hidden_size,
        num_actions=flags.num_actions,
        use_lstm=flags.use_lstm,
    )
    model = model.to(device=flags.learner_device)

    actor_model = Net(
        observation_shape=flags.observation_shape,
        hidden_size=flags.hidden_size,
        num_actions=flags.num_actions,
        use_lstm=flags.use_lstm,
    )
    actor_model.to(device=flags.actor_device)

    # The ActorPool that will accept connections from actor clients.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        server_address=flags.address,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        # Wrapper so exceptions in the actorpool thread are logged
        # before being re-raised into the thread.
        try:
            actors.run()
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        # Linear LR decay from 1 to 0 over `flags.total_steps` env steps.
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                flags,
                plogger,
                rank,
                gossip_buffer,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(inference_batcher, actor_model, flags),
        ) for i in range(flags.num_inference_threads)
    ]

    # Synchronize GALA agents before starting training
    if barrier is not None:
        barrier.wait()
        logging.info("%s: barrier passed" % rank)

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()

    def checkpoint():
        # Persist model, optimizer and scheduler (no stats in this variant).
        if flags.checkpoint:
            logging.info("Saving checkpoint to %s", flags.checkpoint)
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "flags": vars(flags),
                },
                flags.checkpoint,
            )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            # Sample throughput over a 5-second window.
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Let's stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()

    actors.stop()
    actorpool_thread.join()

    for t in learner_threads + inference_threads:
        t.join()

    # Trace and save the final model.
    trace_model(flags, model)