Ejemplo n.º 1
0
    def test_bad_construct(self):
        """Invalid minimum/maximum batch-size combinations raise ValueError."""
        # A minimum batch size of 0 violates the ">= 1" requirement.
        with self.assertRaisesRegex(ValueError, "Min batch size must be >= 1"):
            libtorchbeast.BatchingQueue(
                batch_dim=3, minimum_batch_size=0, maximum_batch_size=1
            )

        # The maximum batch size may not be smaller than the minimum.
        expected_message = "Max batch size must be >= min batch size"
        with self.assertRaisesRegex(ValueError, expected_message):
            libtorchbeast.BatchingQueue(
                batch_dim=3, minimum_batch_size=1, maximum_batch_size=0
            )
Ejemplo n.º 2
0
    def setUp(self):
        """Start the env server and an ActorPool running on a background thread.

        Also builds ``self.target``: ``arange(60)`` reshaped to (3, 4, 5) and
        transposed to shape (5, 4, 3), i.e. a view that is not C-contiguous.
        """
        import sys  # Local import keeps this edit self-contained.

        # Launch the server with the interpreter that runs the tests, rather
        # than whatever "python" happens to resolve to on PATH (which may be a
        # different version/environment without the required packages).
        self.server_proc = subprocess.Popen(
            [sys.executable, "tests/contiguous_arrays_env.py"])

        server_address = ["unix:/tmp/contiguous_arrays_test"]
        self.learner_queue = libtorchbeast.BatchingQueue(batch_dim=1,
                                                         minimum_batch_size=1,
                                                         maximum_batch_size=10,
                                                         check_inputs=True)
        self.inference_batcher = libtorchbeast.DynamicBatcher(
            batch_dim=1,
            minimum_batch_size=1,
            maximum_batch_size=10,
            timeout_ms=100,
            check_outputs=True,
        )
        actor = libtorchbeast.ActorPool(
            unroll_length=1,
            learner_queue=self.learner_queue,
            inference_batcher=self.inference_batcher,
            env_server_addresses=server_address,
            initial_agent_state=(),
        )

        # Drive ActorPool.run() (presumably blocking) from its own thread so
        # that setUp can return.
        self.actor_thread = threading.Thread(target=actor.run)
        self.actor_thread.start()

        self.target = np.arange(3 * 4 * 5).reshape(3, 4, 5).transpose(2, 1, 0)
Ejemplo n.º 3
0
    def test_batched_run(self, batch_size=2):
        """Tensors enqueued from several threads are dequeued as one ordered batch."""
        queue = libtorchbeast.BatchingQueue(batch_dim=0,
                                            minimum_batch_size=batch_size,
                                            maximum_batch_size=batch_size)

        inputs = [torch.full((1, 2, 3), value) for value in range(batch_size)]

        def enqueue_target(index):
            # Wait until queue.size() reaches `index` so that thread i calls
            # enqueue before thread i + 1, making the batch order deterministic.
            while queue.size() < index:
                time.sleep(0.05)
            queue.enqueue(inputs[index])

        producers = [
            threading.Thread(target=enqueue_target,
                             name=f"enqueue-thread-{index}",
                             args=(index, ))
            for index in range(batch_size)
        ]
        for producer in producers:
            producer.start()

        batch = next(queue)
        np.testing.assert_array_equal(batch, torch.cat(inputs))

        for producer in producers:
            producer.join()
Ejemplo n.º 4
0
    def test_simple_run(self):
        """A single enqueued tensor comes back unchanged as a batch of one."""
        queue = libtorchbeast.BatchingQueue(
            batch_dim=0, minimum_batch_size=1, maximum_batch_size=1)

        tensor = torch.zeros(1, 2, 3)
        queue.enqueue(tensor)
        np.testing.assert_array_equal(next(queue), tensor)
Ejemplo n.º 5
0
 def test_check_inputs(self):
     """enqueue() rejects malformed input and refuses to enqueue once closed."""
     queue = libtorchbeast.BatchingQueue(batch_dim=2)
     # A 1-d tensor has no dimension 2 to batch along.
     with self.assertRaisesRegex(
             ValueError,
             "Enqueued tensors must have more than batch_dim =="):
         queue.enqueue(torch.ones(5))
     with self.assertRaisesRegex(ValueError,
                                 "Cannot enqueue empty vector of tensors"):
         queue.enqueue([])
     # Close outside the assertion context so that only enqueue() can
     # satisfy the expected ClosedBatchingQueue error.
     queue.close()
     with self.assertRaisesRegex(libtorchbeast.ClosedBatchingQueue,
                                 "Enqueue to closed queue"):
         queue.enqueue(torch.ones(1, 1, 1))
Ejemplo n.º 6
0
    def setUp(self):
        """Start the env server and build the queues, batcher and ActorPool.

        The learner queue batches exactly ``self.B`` items, and the actor pool
        unrolls ``self.T`` steps. The pool is constructed here but not run.
        """
        import sys  # Local import keeps this edit self-contained.

        # Launch the server with the interpreter that runs the tests, rather
        # than whatever "python" happens to resolve to on PATH.
        self.server_proc = subprocess.Popen(
            [sys.executable, "tests/core_agent_state_env.py"])

        self.B = 2
        self.T = 3
        self.model = Net()
        server_address = ["unix:/tmp/core_agent_state_test"]
        self.learner_queue = libtorchbeast.BatchingQueue(
            batch_dim=1,
            minimum_batch_size=self.B,
            maximum_batch_size=self.B,
            check_inputs=True,
        )
        self.replay_queue = libtorchbeast.BatchingQueue(
            batch_dim=1,
            minimum_batch_size=1,
            maximum_batch_size=1,
            timeout_ms=100,
            check_inputs=True,
            maximum_queue_size=1,
        )
        self.inference_batcher = libtorchbeast.DynamicBatcher(
            batch_dim=1,
            minimum_batch_size=1,
            maximum_batch_size=1,
            timeout_ms=100,
            check_outputs=True,
        )
        self.actor = libtorchbeast.ActorPool(
            unroll_length=self.T,
            learner_queue=self.learner_queue,
            replay_queue=self.replay_queue,
            inference_batcher=self.inference_batcher,
            env_server_addresses=server_address,
            initial_agent_state=self.model.initial_state(),
        )
Ejemplo n.º 7
0
    def test_many_consumers(self,
                            enqueue_threads_number=16,
                            repeats=100,
                            dequeue_threads_number=64):
        """Many producers and consumers share one queue without losing items.

        Each producer enqueues ``repeats`` tensors with a leading dimension of
        1. Each consumer adds up the leading dimension of every batch it
        receives. Once the producers finish and the queue is closed, the total
        must equal the number of tensors produced.
        """
        queue = libtorchbeast.BatchingQueue(batch_dim=0)

        counter_lock = threading.Lock()
        rows_consumed = 0

        def produce(value):
            for _ in range(repeats):
                queue.enqueue(torch.full((1, 2, 3), value))

        def consume():
            nonlocal rows_consumed
            # Iteration presumably stops after queue.close(); the consumer
            # joins below rely on that.
            for batch in queue:
                rows, *_ = batch.shape
                with counter_lock:
                    rows_consumed += rows

        producers = [
            threading.Thread(target=produce,
                             name=f"enqueue-thread-{index}",
                             args=(index, ))
            for index in range(enqueue_threads_number)
        ]
        consumers = [
            threading.Thread(target=consume, name=f"dequeue-thread-{index}")
            for index in range(dequeue_threads_number)
        ]

        for thread in producers + consumers:
            thread.start()

        for thread in producers:
            thread.join()

        # Everything has been enqueued; closing lets the consumers finish.
        queue.close()

        for thread in consumers:
            thread.join()

        self.assertEqual(rows_consumed, repeats * enqueue_threads_number)
Ejemplo n.º 8
0
def train(flags):
    """Train a model with an ActorPool feeding learner threads.

    Starts the actor pool, ``flags.num_learner_threads`` learner threads and
    ``flags.num_inference_threads`` inference threads. It then logs progress
    about every 5 seconds until ``stats["step"]`` reaches ``flags.total_steps``
    (or Ctrl-C).

    Checkpoints are written to ``flags.checkpoint`` every 10 minutes and at
    the end. Extra checkpoints named ``<root>_<fraction>.tar`` are written as
    progress passes 0, 25, 50 and 75% of ``total_steps``. All checkpointing is
    skipped when ``flags.checkpoint`` is unset.

    Args:
        flags: An argparse-style namespace or an ``omegaconf.DictConfig``.
    """
    logging.info("Logging results to %s", flags.savedir)

    def flags_as_dict():
        # A DictConfig is converted via OmegaConf; vars() is only meaningful
        # for namespace-style flags.
        if isinstance(flags, omegaconf.DictConfig):
            return omegaconf.OmegaConf.to_container(flags)
        return vars(flags)

    plogger = file_writer.FileWriter(xp_args=flags_as_dict(), rootdir=flags.savedir)

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        learner_device = torch.device(flags.learner_device)
        actor_device = torch.device(flags.actor_device)
    else:
        logging.info("Not using CUDA.")
        learner_device = torch.device("cpu")
        actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static. We could make it dynamic, but that
    # requires a loss (and learning rate schedule) that's batch size
    # independent.
    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # Actor i connects to pipe number i // connections_per_server, so each
    # pipe is shared by `connections_per_server` consecutive actors.
    connections_per_server = 1
    addresses = [
        f"{flags.pipes_basename}.{actor_id // connections_per_server}"
        for actor_id in range(flags.num_actors)
    ]

    logging.info("Using model %s", flags.model)

    model = create_model(flags, learner_device)

    plogger.metadata["model_numel"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )

    logging.info("Number of model parameters: %i", plogger.metadata["model_numel"])

    actor_model = create_model(flags, actor_device)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=model.initial_state(),
    )

    def run():
        try:
            actors.run()
        except Exception:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        # Linear decay to 0 over total_steps environment steps.
        return (
            1
            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
            / flags.total_steps
        )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    if flags.checkpoint and os.path.exists(flags.checkpoint):
        logging.info("Loading checkpoint: %s", flags.checkpoint)
        # Map onto the device chosen above: flags.learner_device may name a
        # CUDA device even when CUDA is disabled or unavailable.
        checkpoint_states = torch.load(flags.checkpoint, map_location=learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info("Resuming preempted job, current stats:\n%s", stats)

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                flags,
                plogger,
                learner_device,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(inference_batcher, actor_model, flags, actor_device),
        )
        for i in range(flags.num_inference_threads)
    ]

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()

    def checkpoint(checkpoint_path=None):
        # No-op when checkpointing is disabled (flags.checkpoint unset).
        if not flags.checkpoint:
            return
        if checkpoint_path is None:
            checkpoint_path = flags.checkpoint
        logging.info("Saving checkpoint to %s", checkpoint_path)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": flags_as_dict(),
            },
            checkpoint_path,
        )

    # TODO: test this again then uncomment (from deleted polyhydra code)
    # def receive_slurm_signal(signal_num=None, frame=None):
    #     logging.info("Received SIGTERM, checkpointing")
    #     checkpoint()

    # signal.signal(signal.SIGTERM, receive_slurm_signal)

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    def monitor():
        # Log throughput and save periodic checkpoints until stats["step"]
        # (presumably advanced by `learn`, which receives `stats`) reaches
        # flags.total_steps.
        train_start_time = timeit.default_timer()
        train_time_offset = stats.get("train_seconds", 0)  # used for resuming training
        last_checkpoint_time = timeit.default_timer()

        # Fractions of total_steps at which an extra checkpoint is kept.
        dev_checkpoint_intervals = [0, 0.25, 0.5, 0.75]

        loop_start_time = timeit.default_timer()
        loop_start_step = stats.get("step", 0)
        while loop_start_step < flags.total_steps:
            time.sleep(5)
            loop_end_time = timeit.default_timer()
            loop_end_step = stats.get("step", 0)

            stats["train_seconds"] = round(
                loop_end_time - train_start_time + train_time_offset, 1
            )

            if loop_end_time - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = loop_end_time

            # Skip when flags.checkpoint is unset: deriving the per-fraction
            # path from None would raise, even though checkpoint() is a no-op.
            if flags.checkpoint and dev_checkpoint_intervals:
                fraction = dev_checkpoint_intervals[0]
                if loop_end_step / flags.total_steps > fraction:
                    root, _ = os.path.splitext(flags.checkpoint)
                    checkpoint(f"{root}_{fraction}.tar")
                    dev_checkpoint_intervals = dev_checkpoint_intervals[1:]

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                loop_end_step,
                (loop_end_step - loop_start_step) / (loop_end_time - loop_start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}" for key, value in stats.items()
                ),
            )
            loop_start_time = loop_end_time
            loop_start_step = loop_end_step

    try:
        try:
            monitor()
        except KeyboardInterrupt:
            pass  # Close properly.
        else:
            logging.info("Learning finished after %i steps.", stats.get("step", 0))

        checkpoint()
    finally:
        # Done with learning (or failed): stop all the ongoing work. Closing
        # the queues first lets the worker threads wind down before the joins.
        inference_batcher.close()
        learner_queue.close()

        actorpool_thread.join()

        for t in learner_threads + inference_threads:
            t.join()
Ejemplo n.º 9
0
 def test_multiple_close_calls(self):
     """Closing a queue a second time raises RuntimeError."""
     queue = libtorchbeast.BatchingQueue()
     queue.close()  # The first close is expected to succeed.
     expected_message = "Queue was closed already"
     with self.assertRaisesRegex(RuntimeError, expected_message):
         queue.close()
Ejemplo n.º 10
0
def train(flags):
    """Train ``Net`` with an ActorPool feeding learner threads.

    Creates ``<savedir>/<xpid>/model.tar`` as the checkpoint path, resumes
    from it when it exists, and starts the actor pool, learner threads and
    inference threads. It then logs progress about every 5 seconds until
    ``stats["step"]`` reaches ``flags.total_steps`` (or Ctrl-C).

    A checkpoint is saved every 10 minutes and after normal completion,
    unless ``flags.disable_checkpoint`` is set.

    Note: this mutates ``flags`` (``xpid``, ``learner_device``,
    ``actor_device``, ``max_learner_queue_size``).
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda:0")
        # Act on a second GPU when one exists; otherwise share cuda:0 instead
        # of failing later on a non-existent device ordinal.
        actor_index = 1 if torch.cuda.device_count() > 1 else 0
        flags.actor_device = torch.device(f"cuda:{actor_index}")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # Actor i connects to pipe number i // connections_per_server, so each
    # pipe is shared by `connections_per_server` consecutive actors.
    connections_per_server = 1
    addresses = [
        f"{flags.pipes_basename}.{actor_id // connections_per_server}"
        for actor_id in range(flags.num_actors)
    ]

    model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    model = model.to(device=flags.learner_device)

    actor_model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    actor_model.to(device=flags.actor_device)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        try:
            actors.run()
        except Exception:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        # Linear decay to 0 over total_steps environment steps.
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info("Resuming preempted job, current stats:\n%s", stats)

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(flags, inference_batcher, actor_model),
        ) for i in range(flags.num_inference_threads)
    ]

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        # stats may lack "step" if total_steps was already reached at start.
        logging.info("Learning finished after %i steps.", stats.get("step", 0))
        checkpoint()
    finally:
        # Done with learning (or failed): stop all the ongoing work so the
        # worker threads can wind down before the joins.
        inference_batcher.close()
        learner_queue.close()

        actorpool_thread.join()

        for t in learner_threads + inference_threads:
            t.join()
Ejemplo n.º 11
0
def run():
    """Exercise the actor pipeline with random actions and no learner.

    Parses the polybeast learner and env flags and starts one env server
    process. It wires an ActorPool to a learner queue, a replay queue and an
    inference batcher, answers every inference request with random samples
    from the env's action space, and discards everything that reaches the two
    queues. Steps per second are logged until 10000 steps have been inferred
    (or Ctrl-C); then everything is shut down.

    Returns:
        -1 if unrecognised command-line arguments were given, otherwise None.
    """
    flags = argparse.Namespace()
    flags, argv = polybeast_learner.parser.parse_known_args(namespace=flags)
    flags, argv = polybeast_env.parser.parse_known_args(args=argv, namespace=flags)
    if argv:
        # Produce an error message.
        polybeast_env.parser.print_usage()
        print("Unknown args:", " ".join(argv))
        return -1

    # NOTE(review): only one env server process is started while
    # `flags.num_actors` pipe addresses are generated below — confirm that
    # run_env serves all of them.
    env_processes = []
    for actor_id in range(1):
        p = mp.Process(target=run_env, args=(flags, actor_id))
        p.start()
        env_processes.append(p)

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )
    replay_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=flags.num_actors,
        timeout_ms=100,
        check_inputs=True,
        maximum_queue_size=flags.num_actors,
    )
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # Actor i connects to pipe number i // connections_per_server.
    connections_per_server = 1
    addresses = [
        f"{flags.pipes_basename}.{actor_id // connections_per_server}"
        for actor_id in range(flags.num_actors)
    ]

    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        replay_queue=replay_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=(),
    )

    # Named run_actors so it does not shadow this enclosing `run` function.
    def run_actors():
        try:
            # Printed before the call: actors.run() only returns once the
            # pool has stopped.
            print("actors are running")
            actors.run()
        except Exception:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise

    actorpool_thread = threading.Thread(target=run_actors, name="actorpool-thread")

    def dequeue(queue):
        # Drain and discard everything produced into `queue`.
        for tensor in queue:
            del tensor

    dequeue_threads = [
        threading.Thread(target=dequeue, name="dequeue-thread-%i" % i, args=(queue,))
        for i, queue in enumerate([learner_queue, replay_queue])
    ]

    # Create an environment only to obtain its action space for sampling.
    dataset_uses_color = flags.dataset not in ["mnist", "omniglot"]
    is_color = flags.use_color or flags.env_type == "fluid"
    if not is_color:
        grayscale = True
    else:
        grayscale = not dataset_uses_color

    env_name, config = utils.parse_flags(flags)
    env = utils.create_env(env_name, config, grayscale, dataset=None)

    if flags.condition:
        # Conditioning doubles the canvas channels in the observation space.
        new_space = env.observation_space.spaces
        c, h, w = new_space["canvas"].shape
        new_space["canvas"] = spaces.Box(
            low=0, high=255, shape=(c * 2, h, w), dtype=np.uint8
        )
        env.observation_space = spaces.Dict(new_space)

    action_space = env.action_space
    env.close()

    # Must be bound before the inference threads start, since they update
    # it via `nonlocal step`.
    step = 0

    def inference(inference_batcher, lock):
        nonlocal step

        for batch in inference_batcher:
            batched_env_outputs, agent_state = batch.get_inputs()

            obs, _, done, *_ = batched_env_outputs
            B = done.shape[1]

            with lock:
                step += B

            actions = nest.map(lambda i: action_space.sample(), [i for i in range(B)])
            action = torch.from_numpy(np.concatenate(actions)).view(1, B, -1)

            outputs = ((action,), ())
            outputs = nest.map(lambda t: t, outputs)
            batch.set_outputs(outputs)

    lock = threading.Lock()
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(inference_batcher, lock),
        )
        for i in range(flags.num_inference_threads)
    ]

    actorpool_thread.start()

    threads = dequeue_threads + inference_threads

    for t in threads:
        t.start()

    try:
        while step < 10000:
            start_time = timeit.default_timer()
            start_step = step
            time.sleep(3)
            end_step = step

            logging.info(
                "Step %i @ %.1f SPS.",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
            )
    except KeyboardInterrupt:
        pass
    finally:
        # Stop all ongoing work: close the queues so the worker threads can
        # finish, join them, then stop the env server processes.
        inference_batcher.close()
        learner_queue.close()
        replay_queue.close()

        actorpool_thread.join()

        for t in threads:
            t.join()

        for p in env_processes:
            p.terminate()
Ejemplo n.º 12
0
def train(flags):
    """Run the full training loop for the policy and the discriminator.

    Builds the queues/batchers shared with the ``libtorchbeast`` ActorPool,
    constructs the policy (``models.Net``) and discriminator models, resumes
    from ``<savedir>/<xpid>/model.tar`` if that file exists, then starts the
    actor-pool, inference, learner and discriminator-learner threads and logs
    progress every ~5 seconds until ``flags.total_steps`` is reached or the
    user interrupts with Ctrl-C. A checkpoint is written every 10 minutes and
    once more when training finishes normally (unless
    ``flags.disable_checkpoint`` is set).

    Args:
        flags: parsed command-line namespace. Mutated in place: ``xpid`` is
            generated if ``None``, ``learner_device``/``actor_device`` are
            set, and ``max_learner_queue_size`` defaults to ``batch_size``.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The queue into which the actor pool stores final rendered image pairs.
    # The discriminator-learner thread (`learn_D`) consumes it. As with the
    # learner queue, `minimum_batch_size == maximum_batch_size` makes the
    # batch size static.
    replay_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.batch_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    # One env-server address per actor: `<pipes_basename>.<pipe_id>`, with
    # `connections_per_server` consecutive actors sharing each pipe id.
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    # Convert colour datasets to grayscale unless colour was requested; after
    # that conversion the data is effectively gray for env creation.
    dataset_is_gray = flags.dataset in ["mnist", "omniglot"]
    grayscale = not dataset_is_gray and not flags.use_color
    dataset_is_gray |= grayscale

    dataset = utils.create_dataset(flags.dataset, grayscale)

    # Temporary env, only used to read observation/action specs.
    env_name, config = utils.parse_flags(flags)
    env = utils.create_env(env_name, config, dataset_is_gray, dataset=None)

    if flags.condition:
        # Conditional mode doubles the canvas channel count (presumably the
        # canvas is stacked with a target image — confirm against the env).
        new_space = env.observation_space.spaces
        c, h, w = new_space["canvas"].shape
        new_space["canvas"] = spaces.Box(low=0,
                                         high=255,
                                         shape=(c * 2, h, w),
                                         dtype=np.uint8)
        env.observation_space = spaces.Dict(new_space)

    obs_shape = env.observation_space["canvas"].shape
    action_shape = env.action_space.nvec
    order = env.order
    env.close()

    model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    )
    model = model.to(device=flags.learner_device)

    actor_model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    ).eval()
    actor_model.to(device=flags.actor_device)

    if flags.condition:
        D = models.ComplementDiscriminator(obs_shape)
    else:
        D = models.Discriminator(obs_shape)
    D.to(device=flags.learner_device)

    # Separate eval-mode copy of the discriminator, handed to the policy
    # learner threads while `D` itself is trained by `learn_D`.
    if flags.condition:
        D_eval = models.ComplementDiscriminator(obs_shape)
    else:
        D_eval = models.Discriminator(obs_shape)
    D_eval = D_eval.to(device=flags.learner_device).eval()

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=flags.discriminator_learning_rate,
                             betas=(0.5, 0.999))

    def lr_lambda(epoch):
        # Linear decay of the policy LR from 1x to 0x over `total_steps`
        # environment steps; assumes one scheduler step per learner batch of
        # `unroll_length * batch_size` steps.
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        replay_queue=replay_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        try:
            # NOTE(review): `actors.run()` appears to block until the pool
            # shuts down, so this message is printed on exit, not start.
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    dataloader = DataLoader(
        dataset,
        batch_size=flags.batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]

    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
            ),
        ) for i in range(flags.num_inference_threads)
    ]

    d_learner = threading.Thread(
        target=learn_D,
        name="d_learner-thread",
        args=(
            flags,
            dataloader,
            replay_queue,
            D,
            D_eval,
            D_optimizer,
            stats,
            plogger,
        ),
    )

    actorpool_thread.start()

    threads = learner_threads + inference_threads

    # Only inference threads start now; learner threads start after warm-up.
    for t in inference_threads:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        # Warm-up: wait until the replay queue holds a full batch before
        # starting any learner. Meanwhile drain full learner batches,
        # presumably so actors don't block on a full learner queue.
        while replay_queue.size() < flags.batch_size:
            if learner_queue.size() >= flags.batch_size:
                next(learner_queue)
            time.sleep(0.01)

        d_learner.start()

        for t in learner_threads:
            t.start()

        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        # `stats` may have no "step" key if the loop exited on its first
        # check (e.g. total_steps <= 0), so don't index it directly.
        logging.info("Learning finished after %i steps.",
                     stats.get("step", 0))
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()

    replay_queue.close()

    actorpool_thread.join()

    # A Ctrl-C during the warm-up loop leaves `d_learner` and the learner
    # threads unstarted, and Thread.join() raises RuntimeError on a thread
    # that was never started. `ident` stays None until start() is called.
    for t in [d_learner] + threads:
        if t.ident is not None:
            t.join()