Example #1
    def test_initial_state(self):
        model = models.Net(
            obs_shape=self.obs_shape,
            order=self.order,
            action_shape=self.action_shape,
            grid_shape=self.grid_shape,
        )
        core_state = model.initial_state(self.batch_size)

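        # The recurrent core is presumably an LSTM: initial_state() should return a
        # (hidden, cell) pair, each of shape (1, batch_size, core_output_size).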
        self.assertEqual(len(core_state), 2)
        for core_state_element in core_state:
            self.assertSequenceEqual(
                core_state_element.shape, (1, self.batch_size, self.core_output_size),
            )
Example #2
    def test_forward_return_signature(self):
        model = models.Net(
            obs_shape=self.obs_shape,
            order=self.order,
            action_shape=self.action_shape,
            grid_shape=self.grid_shape,
        )
        core_state = model.initial_state(self.batch_size)

        (action, policy_logits, baseline), core_state = model(*self.inputs, core_state)
        self.assertSequenceEqual(
            action.shape, (self.batch_size, self.unroll_length, len(self.action_shape))
        )
        for logits, num_actions in zip(policy_logits, self.action_shape):
            self.assertSequenceEqual(
                logits.shape, (self.batch_size, self.unroll_length, num_actions)
            )
        self.assertSequenceEqual(baseline.shape, (self.batch_size, self.unroll_length))
        for core_state_element in core_state:
            self.assertSequenceEqual(
                core_state_element.shape, (1, self.batch_size, self.core_output_size),
            )
Example #3
def test(flags):
    if flags.xpid is None:
        checkpointpath = "./latest/model.tar"
    else:
        checkpointpath = os.path.expandvars(
            os.path.expanduser("%s/%s/%s" %
                               (flags.savedir, flags.xpid, "model.tar")))

    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )

    if flags.dataset == "celeba" or flags.dataset == "celeba-hq":
        use_color = True
    else:
        use_color = False

    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            ))

    if flags.use_compound:
        env_name += "-v1"
        config.update(
            dict(
                new_stroke_penalty=flags.new_stroke_penalty,
                stroke_length_penalty=flags.stroke_length_penalty,
            ))
    else:
        env_name += "-v0"

    env = env_wrapper.make_raw(env_name, config)
    if frame_width != flags.canvas_width:
        env = env_wrapper.WarpFrame(env, height=frame_width, width=frame_width)
    env = env_wrapper.wrap_pytorch(env)
    env = env_wrapper.AddDim(env)

    obs_shape = env.observation_space.shape
    if flags.condition:
        c, h, w = obs_shape
        c *= 2
        obs_shape = (c, h, w)

    action_shape = env.action_space.nvec.tolist()
    order = env.order

    model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        model = models.Condition(model)
    model.eval()

    D = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D = models.Conditional(D)
    D.eval()

    checkpoint = torch.load(checkpointpath, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    D.load_state_dict(checkpoint["D_state_dict"])

    if flags.condition:
        from random import randrange

        c, h, w = obs_shape
        tsfm = transforms.Compose(
            [transforms.Resize((h, w)),
             transforms.ToTensor()])
        dataset = flags.dataset

        if dataset == "mnist":
            dataset = MNIST(root="./",
                            train=True,
                            transform=tsfm,
                            download=True)
        elif dataset == "omniglot":
            dataset = Omniglot(root="./",
                               background=True,
                               transform=tsfm,
                               download=True)
        elif dataset == "celeba":
            dataset = CelebA(
                root="./",
                split="train",
                target_type=None,
                transform=tsfm,
                download=True,
            )
        elif dataset == "celeba-hq":
            dataset = datasets.CelebAHQ(root="./",
                                        split="train",
                                        transform=tsfm,
                                        download=True)
        else:
            raise NotImplementedError

        condition = dataset[randrange(len(dataset))].view((1, 1) + obs_shape)
    else:
        condition = None

    frame = env.reset()
    action = model.initial_action()
    agent_state = model.initial_state()
    done = torch.tensor(False).view(1, 1)
    rewards = []
    frames = [frame]

    for i in range(flags.episode_length - 1):
        if flags.mode == "test_render":
            env.render()
        noise = torch.randn(1, 1, 10)
        agent_outputs, agent_state = model(
            dict(
                obs=frame,
                condition=condition,
                action=action,
                noise=noise,
                done=done,
            ),
            agent_state,
        )
        action, *_ = agent_outputs
        frame, reward, done, _ = env.step(action)

        rewards.append(reward)
        frames.append(frame)

    reward = torch.cat(rewards)
    frame = torch.cat(frames)

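    # With TCA enabled (taken here to mean per-step discriminator rewards), every
    # intermediate canvas is scored, so the time and batch dims are flattened;
    # otherwise only the final canvas is kept and scored once.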
    if flags.use_tca:
        frame = torch.flatten(frame, 0, 1)
        if flags.condition:
            condition = torch.flatten(condition, 0, 1)
    else:
        frame = frame[-1]
        if flags.condition:
            condition = condition[-1]

    D = D.eval()
    with torch.no_grad():
        if flags.condition:
            p = D(frame, condition).view(-1, 1)
        else:
            p = D(frame).view(-1, 1)

        if flags.use_tca:
            d_reward = p[1:] - p[:-1]
            reward = reward[1:] + d_reward
        else:
            reward[-1] = reward[-1] + p
            reward = reward[1:]

            # empty condition
            condition = None

    logging.info(
        "Episode ended after %d steps. Final reward: %.4f. Episode reward: %.4f.",
        flags.episode_length,
        reward[-1].item(),
        reward.sum().item(),
    )
    env.close()
Example #4
def train(flags):
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    d_queue = Queue(maxsize=flags.max_learner_queue_size // flags.batch_size)
    image_queue = Queue(maxsize=flags.max_learner_queue_size)

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

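    # Build one pipe address per actor; with connections_per_server == 1 this yields
    # f"{flags.pipes_basename}.0" through f"{flags.pipes_basename}.{num_actors - 1}".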
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )

    if flags.dataset == "celeba" or flags.dataset == "celeba-hq":
        use_color = True
    else:
        use_color = False

    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            ))

    if flags.use_compound:
        env_name += "-v1"
    else:
        env_name += "-v0"

    env = env_wrapper.make_raw(env_name, config)
    if frame_width != flags.canvas_width:
        env = env_wrapper.WarpFrame(env, height=frame_width, width=frame_width)
    env = env_wrapper.wrap_pytorch(env)

    obs_shape = env.observation_space.shape
    if flags.condition:
        c, h, w = obs_shape
        c *= 2
        obs_shape = (c, h, w)

    action_shape = env.action_space.nvec.tolist()
    order = env.order
    env.close()

    model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        model = models.Condition(model)
    model = model.to(device=flags.learner_device)

    actor_model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        actor_model = models.Condition(actor_model)
    actor_model.to(device=flags.actor_device)

    D = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D = models.Conditional(D)
    D.to(device=flags.learner_device)

    D_eval = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D_eval = models.Conditional(D_eval)
    D_eval = D_eval.to(device=flags.learner_device)

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=flags.discriminator_learning_rate,
                             betas=(0.5, 0.999))

    def lr_lambda(epoch):
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

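    # lr_lambda decays the learning rate linearly from its base value to zero as the
    # estimated consumed steps (epoch * unroll_length * batch_size) approach
    # flags.total_steps.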
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    D_scheduler = torch.optim.lr_scheduler.LambdaLR(D_optimizer, lr_lambda)

    C, H, W = obs_shape
    if flags.condition:
        C //= 2
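    # When conditioning, obs_shape was doubled along the channel dim (target stacked
    # with canvas), so halve C to recover the raw canvas channels for the zero image
    # handed to the ActorPool below.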
    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_action=actor_model.initial_action(),
        initial_agent_state=actor_model.initial_state(),
        image=torch.zeros(1, 1, C, H, W),
    )

    def run():
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    c, h, w = obs_shape
    tsfm = transforms.Compose(
        [transforms.Resize((h, w)),
         transforms.ToTensor()])

    dataset = flags.dataset

    if dataset == "mnist":
        dataset = MNIST(root="./", train=True, transform=tsfm, download=True)
    elif dataset == "omniglot":
        dataset = Omniglot(root="./",
                           background=True,
                           transform=tsfm,
                           download=True)
    elif dataset == "celeba":
        dataset = CelebA(root="./",
                         split="train",
                         target_type=None,
                         transform=tsfm,
                         download=True)
    elif dataset == "celeba-hq":
        dataset = datasets.CelebAHQ(root="./",
                                    split="train",
                                    transform=tsfm,
                                    download=True)
    else:
        raise NotImplementedError

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=True,
                            drop_last=True,
                            pin_memory=True)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        D_scheduler.load_state_dict(
            checkpoint_states["D_scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                d_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
                image_queue,
            ),
        ) for i in range(flags.num_inference_threads)
    ]

    d_learner = [
        threading.Thread(
            target=learn_D,
            name="d_learner-thread-%i" % i,
            args=(
                flags,
                d_queue,
                D,
                D_eval,
                D_optimizer,
                D_scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    for thread in d_learner:
        thread.daemon = True

    dataloader_thread = threading.Thread(target=data_loader,
                                         args=(
                                             flags,
                                             dataloader,
                                             image_queue,
                                         ))
    dataloader_thread.daemon = True

    actorpool_thread.start()

    threads = learner_threads + inference_threads
    daemons = d_learner + [dataloader_thread]

    for t in threads + daemons:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "D_scheduler_state_dict": D_scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()

    actorpool_thread.join()

    for t in threads:
        t.join()
Example #5
    def setUp(self):
        unroll_length = 3  # Inference called for every step.
        batch_size = 4  # Arbitrary.
        frame_dimension = 64  # Has to match what is expected by the model.
        order = ["control", "end", "flag", "size", "pressure"]
        action_shape = [1024, 1024, 2, 8, 10]
        num_channels = 2  # Has to match the first conv layer of the net.
        grid_shape = [32, 32]  # Specific to each environment.

        obs_shape = [num_channels, frame_dimension, frame_dimension]

        # The following hyperparameters are arbitrary.
        self.lr = 0.1
        total_steps = 100000

        # Set the random seed manually to get reproducible results.
        torch.manual_seed(0)

        self.model = models.Net(
            obs_shape=obs_shape,
            order=order,
            action_shape=action_shape,
            grid_shape=grid_shape,
        )
        self.actor_model = models.Net(
            obs_shape=obs_shape,
            order=order,
            action_shape=action_shape,
            grid_shape=grid_shape,
        )
        self.initial_model_dict = copy.deepcopy(self.model.state_dict())
        self.initial_actor_model_dict = copy.deepcopy(
            self.actor_model.state_dict())

        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)

        self.D = models.ComplementDiscriminator(obs_shape, spectral_norm=False)
        self.D_eval = models.ComplementDiscriminator(
            obs_shape, spectral_norm=False).eval()

        self.initial_D_dict = copy.deepcopy(self.D.state_dict())
        self.initial_D_eval_dict = copy.deepcopy(self.D_eval.state_dict())

        D_optimizer = torch.optim.SGD(self.D.parameters(), lr=self.lr)

        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=total_steps //
                                                    10)

        D_scheduler = torch.optim.lr_scheduler.StepLR(D_optimizer,
                                                      step_size=total_steps //
                                                      10)

        self.stats = {}

        # The call to plogger.log will not perform any action.
        plogger = mock.Mock()
        plogger.log = mock.Mock()

        # Mock flags.
        mock_flags = mock.Mock()
        mock_flags.learner_device = torch.device("cpu")
        mock_flags.discounting = 0.99  # Default value from cmd.
        mock_flags.baseline_cost = 0.5  # Default value from cmd.
        mock_flags.entropy_cost = 0.0006  # Default value from cmd.
        mock_flags.unroll_length = unroll_length - 1
        mock_flags.batch_size = batch_size
        mock_flags.grad_norm_clipping = 40
        mock_flags.use_tca = True
        mock_flags.condition = True

        # Prepare content for mock_learner_queue.
        obs = dict(
            canvas=torch.ones([unroll_length, batch_size] + obs_shape),
            prev_action=torch.ones(unroll_length, batch_size,
                                   len(action_shape)),
            action_mask=torch.ones(unroll_length, batch_size,
                                   len(action_shape)),
            noise_sample=torch.ones(unroll_length, batch_size, 10),
        )
        rewards = torch.ones(unroll_length, batch_size)
        done = torch.zeros(unroll_length, batch_size, dtype=torch.bool)
        episode_step = torch.ones(unroll_length, batch_size)
        episode_return = torch.ones(unroll_length, batch_size)

        new_frame = dict(
            canvas=torch.ones([unroll_length - 1, batch_size] + obs_shape),
            prev_action=torch.ones(unroll_length - 1, batch_size,
                                   len(action_shape)),
            action_mask=torch.ones(unroll_length - 1, batch_size,
                                   len(action_shape)),
            noise_sample=torch.ones(unroll_length - 1, batch_size, 10),
        )

        env_outputs = (obs, rewards, done, episode_step, episode_return)
        actor_outputs = (
            # Actions taken.
            torch.cat(
                list(
                    map(
                        lambda num_actions: torch.randint(
                            low=0,
                            high=num_actions,
                            size=(unroll_length, batch_size, 1)),
                        action_shape,
                    )),
                dim=-1,
            ),
            # Logits.
            list(
                map(
                    lambda num_actions: torch.randn(unroll_length, batch_size,
                                                    num_actions),
                    action_shape,
                )),
            # Baseline.
            torch.rand(unroll_length, batch_size),
        )
        initial_agent_state = self.model.initial_state(batch_size)
        tensors = ((env_outputs, actor_outputs), new_frame,
                   initial_agent_state)
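        # tensors mirrors what the actor pool would place on the learner queue:
        # ((env_outputs, actor_outputs), new_frame, initial_agent_state).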

        # Mock learner_queue.
        mock_learner_queue = mock.MagicMock()
        mock_learner_queue.__iter__.return_value = iter([tensors])

        self.learn_args = (
            mock_flags,
            mock_learner_queue,
            self.model,
            self.actor_model,
            self.D_eval,
            optimizer,
            scheduler,
            self.stats,
            plogger,
        )

        # Mock replay_queue.
        mock_replay_queue = mock.MagicMock()
        mock_replay_queue.is_closed.return_value = True

        # Mock dataloader.
        mock_dataloader = mock.MagicMock()
        mock_dataloader.__iter__.return_value = iter([(
            torch.ones(batch_size, 1, frame_dimension, frame_dimension),
            None,
        )])

        # Mock replay_buffer.
        mock_replay_buffer = mock.MagicMock()
        mock_replay_buffer.sample.return_value = torch.ones([batch_size] +
                                                            obs_shape)

        self.learn_D_args = (
            mock_flags,
            mock_dataloader,
            mock_replay_queue,
            mock_replay_buffer,
            self.D,
            self.D_eval,
            D_optimizer,
            D_scheduler,
            self.stats,
            plogger,
        )
Example #6
def train(flags):
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The queue where the actorpool stores the final rendered image pairs.
    # A separate thread will load them into the ReplayBuffer.
    # As with the learner queue, the batch size here is static.
    replay_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.batch_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    dataset_is_gray = flags.dataset in ["mnist", "omniglot"]
    grayscale = not dataset_is_gray and not flags.use_color
    dataset_is_gray |= grayscale
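    # grayscale is True when a color dataset is used with use_color disabled; the
    # dataset is then presumably loaded in grayscale and the env treats frames as
    # single-channel.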

    dataset = utils.create_dataset(flags.dataset, grayscale)

    env_name, config = utils.parse_flags(flags)
    env = utils.create_env(env_name, config, dataset_is_gray, dataset=None)

    if flags.condition:
        new_space = env.observation_space.spaces
        c, h, w = new_space["canvas"].shape
        new_space["canvas"] = spaces.Box(low=0,
                                         high=255,
                                         shape=(c * 2, h, w),
                                         dtype=np.uint8)
        env.observation_space = spaces.Dict(new_space)

    obs_shape = env.observation_space["canvas"].shape
    action_shape = env.action_space.nvec
    order = env.order
    env.close()

    model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    )
    model = model.to(device=flags.learner_device)

    actor_model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    ).eval()
    actor_model.to(device=flags.actor_device)

    if flags.condition:
        D = models.ComplementDiscriminator(obs_shape)
    else:
        D = models.Discriminator(obs_shape)
    D.to(device=flags.learner_device)

    if flags.condition:
        D_eval = models.ComplementDiscriminator(obs_shape)
    else:
        D_eval = models.Discriminator(obs_shape)
    D_eval = D_eval.to(device=flags.learner_device).eval()

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=flags.discriminator_learning_rate,
                             betas=(0.5, 0.999))

    def lr_lambda(epoch):
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        replay_queue=replay_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    dataloader = DataLoader(
        dataset,
        batch_size=flags.batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]

    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
            ),
        ) for i in range(flags.num_inference_threads)
    ]

    d_learner = threading.Thread(
        target=learn_D,
        name="d_learner-thread",
        args=(
            flags,
            dataloader,
            replay_queue,
            D,
            D_eval,
            D_optimizer,
            stats,
            plogger,
        ),
    )

    actorpool_thread.start()

    threads = learner_threads + inference_threads

    for t in inference_threads:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
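        # Warm-up: wait until the replay queue holds a full batch of final canvases
        # before starting the discriminator and policy learners, draining the learner
        # queue meanwhile so the actors do not block on a full queue.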
        while replay_queue.size() < flags.batch_size:
            if learner_queue.size() >= flags.batch_size:
                next(learner_queue)
            time.sleep(0.01)

        d_learner.start()

        for t in learner_threads:
            t.start()

        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()

    replay_queue.close()

    actorpool_thread.join()

    d_learner.join()

    for t in threads:
        t.join()
Example #7
    def _test_inference(self, use_color, device):
        model = models.Net(
            obs_shape=self.obs_shape,
            order=self.order,
            action_shape=self.action_shape,
            grid_shape=self.grid_shape,
        )
        model.to(device)
        agent_state = model.initial_state(self.batch_size)

        inputs = (
            (
                self.obs,
                self.rewards,
                self.done,
                self.episode_return,
                self.episode_return,
            ),
            agent_state,
        )
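        # Note: self.episode_return is passed twice, standing in for episode_step as
        # well; presumably inference() reads neither field.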

        # Set the behaviour of the methods of the mock batch.
        self.mock_batch.get_inputs = mock.Mock(return_value=inputs)
        self.mock_batch.set_outputs = mock.Mock()

        # Prepare the mock flags. A plain dict would do, but a Mock object
        # is used for consistency.
        mock_flags = mock.Mock()
        mock_flags.actor_device = device

        polybeast.inference(mock_flags, self.mock_inference_batcher, model)

        # Assert the batch is used only once.
        self.mock_batch.get_inputs.assert_called_once()
        self.mock_batch.set_outputs.assert_called_once()
        # Check that set_outputs was called with parameters of the expected shapes.
        batch_args, batch_kwargs = self.mock_batch.set_outputs.call_args
        self.assertEqual(batch_kwargs, {})
        model_outputs, *other_args = batch_args
        self.assertEqual(other_args, [])

        (action, policy_logits, baseline), core_state = model_outputs
        self.assertSequenceEqual(
            action.shape,
            (self.unroll_length, self.batch_size, len(self.action_shape)))
        for logits, num_actions in zip(policy_logits, self.action_shape):
            self.assertSequenceEqual(
                logits.shape,
                (self.unroll_length, self.batch_size, num_actions))
        self.assertSequenceEqual(baseline.shape,
                                 (self.unroll_length, self.batch_size))

        for tensor in (action, baseline) + core_state:
            self.assertEqual(tensor.device, torch.device("cpu"))
        for tensor in policy_logits:
            self.assertEqual(tensor.device, torch.device("cpu"))

        self.assertEqual(len(core_state), 2)
        for core_state_element in core_state:
            self.assertSequenceEqual(
                core_state_element.shape,
                (1, self.batch_size, self.core_output_size),
            )