Example #1
def inference(flags, inference_batcher, model, lock=threading.Lock()):
    with torch.no_grad():
        for batch in inference_batcher:
            batched_env_outputs, agent_state = batch.get_inputs()

            obs, _, done, *_ = batched_env_outputs

            obs, done, agent_state = nest.map(
                lambda t: t.to(flags.actor_device, non_blocking=True),
                [obs, done, agent_state],
            )

            with lock:  # Only one thread runs the model at a time.
                outputs = model(obs, done, agent_state)

            outputs = nest.map(lambda t: t.cpu(), outputs)
            batch.set_outputs(outputs)
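The inference loop above uses torchbeast's nest.map to apply the same device transfer to every tensor in a nested batch. A minimal standalone sketch of that pattern, assuming the nest extension is installed (shapes and contents below are made up for illustration):

import torch
import nest  # torchbeast's C++ nest extension

# A nested batch: arbitrary tuples/dicts of tensors, as a batcher might return.
obs = {"canvas": torch.zeros(2, 3, 64, 64), "noise_sample": torch.zeros(2, 10)}
done = torch.zeros(2, dtype=torch.bool)
agent_state = (torch.zeros(1, 2, 256), torch.zeros(1, 2, 256))

# nest.map applies the function to every leaf tensor and preserves the structure.
device = torch.device("cpu")  # stand-in for flags.actor_device
obs, done, agent_state = nest.map(
    lambda t: t.to(device, non_blocking=True),
    [obs, done, agent_state],
)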
Example #2
    def inference(inference_batcher, lock=threading.Lock()):
        nonlocal step

        for batch in inference_batcher:
            batched_env_outputs, agent_state = batch.get_inputs()

            obs, _, done, *_ = batched_env_outputs
            B = done.shape[1]

            with lock:  # Protect the shared step counter across inference threads.
                step += B

            # Sample a random action for each batch entry and stack into shape (1, B, -1).
            actions = nest.map(lambda i: action_space.sample(), [i for i in range(B)])
            action = torch.from_numpy(np.concatenate(actions)).view(1, B, -1)

            outputs = ((action,), ())
            outputs = nest.map(lambda t: t, outputs)
            batch.set_outputs(outputs)
Example #3
    def forward(self, inputs, core_state):
        x = inputs["frame"]
        T, B, *_ = x.shape
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = x.float() / 255.0

        res_input = None
        for i, fconv in enumerate(self.feat_convs):
            x = fconv(x)
            res_input = x
            x = self.resnet1[i](x)
            x += res_input
            res_input = x
            x = self.resnet2[i](x)
            x += res_input

        x = F.relu(x)
        x = x.view(T * B, -1)
        x = F.relu(self.fc(x))

        clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1)
        core_input = torch.cat([x, clipped_reward], dim=-1)

        if self.use_lstm:
            core_input = core_input.view(T, B, -1)
            core_output_list = []
            notdone = (~inputs["done"]).float()
            for input, nd in zip(core_input.unbind(), notdone.unbind()):
                # Reset core state to zero whenever an episode ended.
                # Make `done` broadcastable with (num_layers, B, hidden_size)
                # states:
                nd = nd.view(1, -1, 1)
                core_state = nest.map(nd.mul, core_state)
                output, core_state = self.core(input.unsqueeze(0), core_state)
                core_output_list.append(output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
        else:
            core_output = core_input

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        if self.training:
            action = torch.multinomial(F.softmax(policy_logits, dim=1),
                                       num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        return (action, policy_logits, baseline), core_state
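The recurrent loop above multiplies the LSTM state by the not-done mask so it is reset to zero at episode boundaries. An isolated sketch of that reset, assuming the usual (h, c) state layout (sizes here are illustrative):

import torch
import nest  # torchbeast's C++ nest extension

num_layers, B, hidden_size = 1, 2, 8
core_state = (torch.randn(num_layers, B, hidden_size),
              torch.randn(num_layers, B, hidden_size))

# `done` flags for the current time step, one per batch entry.
done = torch.tensor([True, False])
nd = (~done).float().view(1, -1, 1)  # broadcastable with (num_layers, B, hidden_size)

# Zero the state wherever an episode just ended; leave it untouched otherwise.
core_state = nest.map(nd.mul, core_state)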
Example #4
    def test_nest_map(self):
        t1 = torch.tensor(0)
        t2 = torch.tensor(1)
        d = {"hey": t2}

        n = nest.map(lambda t: t + 42, (t1, t2))

        self.assertSequenceEqual(n, [t1 + 42, t2 + 42])
        self.assertSequenceEqual(n, nest.flatten(n))

        n1 = (d, n, t1)
        n2 = nest.map(lambda t: t * 2, n1)

        self.assertEqual(n2[0], {"hey": torch.tensor(2)})
        self.assertEqual(n2[1], (torch.tensor(84), torch.tensor(86)))
        self.assertEqual(n2[2], torch.tensor(0))

        t = torch.tensor(42)

        # Doesn't work with pybind11/functional.h, but does with py::function.
        self.assertEqual(nest.map(t.add, t2), torch.tensor(43))
Example #5
def tca_reward_function(flags, obs, new_frame, D):
    frame = obs["canvas"][:-1]
    frame, new_frame = nest.map(lambda t: torch.flatten(t, 0, 1),
                                (frame, new_frame))

    with torch.no_grad():
        reward = torch.zeros(flags.unroll_length + 1,
                             flags.batch_size,
                             device=flags.learner_device)

        # Per-step reward: change in discriminator score from the previous canvas to the new one.
        reward[1:] += (D(new_frame) - D(frame)).view(-1, flags.batch_size)

    return reward
Example #6
File: models.py Project: ln-e/spiralpp
    def forward(self, input, core_state):
        T, B, *_ = input["obs"].shape
        grid = self.grid.repeat(T * B, 1, 1, 1)

        notdone = (~input["done"]).float()
        action = torch.flatten(input["action"] * notdone.unsqueeze(dim=2), 0,
                               1)
        obs = torch.flatten(input["obs"].float(), 0, 1)
        noise = torch.flatten(input["noise"], 0, 1)

        spatial = self.conv5x5(torch.cat([obs, grid], dim=1))
        noise_embedding = self.fc(noise).view(-1, 32, 1, 1)
        mlp = self.action_fc(action)

        embedding = self.relu(spatial + noise_embedding + mlp)

        h = self.conv(embedding)
        h = self.resblock(h)
        h = self.flatten_fc(h)
        h = self.relu(h)

        core_input = h.view(T, B, 256)
        core_output_list = []
        for input, nd in zip(core_input.unbind(), notdone.unbind()):
            # Reset the LSTM state to zero whenever an episode ended.
            nd = nd.view(1, -1, 1)
            core_state = nest.map(nd.mul, core_state)
            output, core_state = self.lstm(input.unsqueeze(0), core_state)
            core_output_list.append(output)
        core_output = torch.flatten(torch.cat(core_output_list), 0, 1)

        action, policy_logits = self.policy(core_output, action)
        baseline = self.baseline(core_output)

        action = action.view(T, B, self._num_actions)
        baseline = baseline.view(T, B)
        policy_logits = nest.map(lambda t: t.view(T, B, -1), policy_logits)

        return (action, policy_logits, baseline), core_state
Example #7
    def forward(self, obs, done, core_state):
        T, B, C, H, W = obs["canvas"].shape
        grid = self._grid(T * B, H, W)

        notdone = (~done).float()
        obs["prev_action"] = obs["prev_action"] * notdone.unsqueeze(dim=2)

        obs = nest.map(lambda t: torch.flatten(t, 0, 1), obs)

        canvas, action_mask, action, noise = (
            obs[k]
            for k in ["canvas", "action_mask", "prev_action", "noise_sample"])

        features = self.obs(torch.cat([canvas, grid], dim=1))

        condition = (self.noise(noise) +
                     self.action(self.mask_mlp(action, action_mask))).view(
                         -1, 32, 1, 1)

        embedding = self.base(self.relu(features + condition)).view(T, B, 256)

        core_output_list = []
        for core_input, nd in zip(embedding.unbind(), notdone.unbind()):
            # Reset the LSTM state to zero whenever an episode ended.
            nd = nd.view(1, -1, 1)
            core_state = nest.map(nd.mul, core_state)
            output, core_state = self.lstm(core_input.unsqueeze(0), core_state)
            core_output_list.append(output)
        seed = torch.flatten(torch.cat(core_output_list), 0, 1)

        action, logits = self.policy(seed, action)
        baseline = self.baseline(seed)

        action = action.view(T, B, self._num_actions)
        baseline = baseline.view(T, B)
        logits = nest.map(lambda t: t.view(T, B, -1), logits)

        return (action, logits, baseline), core_state
Example #8
    def __init__(self,
                 timestep,
                 unroll_length,
                 num_actors,
                 num_overlapping_steps=0):
        self._full_length = num_overlapping_steps + unroll_length + 1
        self._num_overlapping_steps = num_overlapping_steps

        N = num_actors
        L = self._full_length

        self._state = nest.map(
            lambda t: torch.zeros((N, L) + t.shape, dtype=t.dtype), timestep)

        self._index = torch.zeros([N], dtype=torch.int64)
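The constructor above preallocates one (N, L, ...) buffer per leaf tensor of the timestep spec. A toy illustration of the same allocation, with a made-up timestep structure:

import torch
import nest  # torchbeast's C++ nest extension

# One environment timestep, described as a nested structure of example tensors.
timestep = {"frame": torch.zeros(3, 84, 84, dtype=torch.uint8),
            "reward": torch.zeros((), dtype=torch.float32)}

N, L = 4, 21  # num_actors, full length (num_overlapping_steps + unroll_length + 1)
state = nest.map(
    lambda t: torch.zeros((N, L) + t.shape, dtype=t.dtype), timestep)

# state["frame"].shape == (4, 21, 3, 84, 84); state["reward"].shape == (4, 21)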
Example #9
    def _complete_unrolls(self, actor_ids):
        """Obtain unrolls that have reached the desired length"""
        actor_indices = self._index[actor_ids]

        actor_ids = actor_ids[actor_indices == self._full_length]
        unrolls = nest.map(lambda s: s[actor_ids], self._state)

        # Reset state of completed actors to start from the end of the previous
        # ones (NB: since `unrolls` is a copy it is ok to do it in place).
        j = self._num_overlapping_steps + 1
        for s in nest.flatten(self._state):
            s[actor_ids, :j] = s[actor_ids, -j:]

        self._index.scatter_(0, actor_ids, 1 + self._num_overlapping_steps)

        return actor_ids, unrolls
Example #10
def compute_metrics(
    flags,
    learner_outputs,
    actor_outputs,
    env_outputs,
    last_actions,
    reward_stats=None,
    end_of_episode_bootstrap=False,
):
    """
    Compute various metrics (including in particular the loss being optimized).

    :param learner_outputs: A dictionary holding the following tensors, where
        `T` is the unroll length and `N` the batch size (number of actors):
        * "action": (T + 1, N)
        * "baseline": (T + 1, N)
        * "policy_logits": (T + 1, N, num_actions)
    :param actor_outputs: Similar to `learner_outputs`.
    :param env_outputs: A triplet of observation (T + 1, N, obs_dim), reward
        (T + 1, N) and done (T + 1, N) tensors.
    :param last_actions: not used
    :param reward_stats: Must be provided when reward normalization is enabled.
        This dictionary holds statistics on the observed rewards.
    """
    del last_actions  # Only used in model.
    # Estimated value of the last state in the rollout (N).
    bootstrap_value = learner_outputs["baseline"][-1]

    # Move from obs[t] -> action[t] to action[t] -> obs[t].
    # After this step all tensors have shape (T, N, ...)
    actor_outputs = nest.map(lambda t: t[:-1], actor_outputs)
    rewards, done = nest.map(lambda t: t[1:], env_outputs[1:])
    learner_outputs = nest.map(lambda t: t[:-1], learner_outputs)

    if flags.reward_clipping == "abs_one":
        rewards = torch.clamp(rewards, -1, 1)
    elif flags.reward_clipping == "soft_asymmetric":
        squeezed = torch.tanh(rewards / 5.0)
        # Negative rewards are given less weight than positive rewards.
        rewards = torch.where(rewards < 0, 0.3 * squeezed, squeezed) * 5.0
    elif flags.reward_clipping != "none":
        raise NotImplementedError(flags.reward_clipping)

    if flags.reward_normalization:
        train_job_id = env_outputs[0][1:][:, :, -1].long()
        normalize_rewards(flags, train_job_id, rewards, reward_stats)

    discounts = (~done).float() * flags.discounting

    actor_logits = actor_outputs["policy_logits"]  # (T, N, num_actions)
    learner_logits = learner_outputs["policy_logits"]  # (T, N, num_actions)
    actions = actor_outputs["action"]  # (T, N)
    baseline = learner_outputs["baseline"]  # (T, N)

    vtrace_returns = vtrace.from_logits(
        behavior_policy_logits=actor_logits,
        target_policy_logits=learner_logits,
        actions=actions,
        discounts=discounts,
        rewards=rewards,
        values=baseline,
        bootstrap_value=bootstrap_value,
        end_of_episode_bootstrap=end_of_episode_bootstrap,
        done=done,
    )

    pg_loss = compute_policy_gradient_loss(learner_logits, actions,
                                           vtrace_returns.pg_advantages)
    baseline_loss = compute_baseline_loss(
        # NB: if `end_of_episode_bootstrap` is True, then `vs` at end of episode
        # is equal to the baseline, contributing to a zero cost. This is what we
        # want, as we cannot learn from that step without knowing the next state.
        vtrace_returns.vs - baseline)
    entropy_loss = compute_entropy_loss(learner_logits)

    # Total entropy if the learner was outputting a uniform distribution over actions
    # (used for normalization in the reported metrics).
    batch_total_size = learner_logits.shape[0] * learner_logits.shape[1]
    uniform_entropy = batch_total_size * np.log(flags.num_actions)

    return {
        "loss/total":
        pg_loss + flags.baseline_cost * baseline_loss +
        flags.entropy_cost * entropy_loss,
        "loss/pg":
        pg_loss,
        "loss/baseline":
        baseline_loss,
        "loss/normalized_neg_entropy":
        entropy_loss / uniform_entropy,
        "critic/baseline/mean":
        torch.mean(baseline),
        "critic/baseline/min":
        torch.min(baseline),
        "critic/baseline/max":
        torch.max(baseline),
    }
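The t[:-1] / t[1:] slicing in compute_metrics shifts the alignment from obs[t] -> action[t] to action[t] -> obs[t]: learner and actor outputs drop the final step, while rewards and done flags drop the first one. A small sketch of that shift with a hypothetical unroll (T = 3, N = 2):

import torch
import nest  # torchbeast's C++ nest extension

T, N, num_actions = 3, 2, 4
learner_outputs = {
    "action": torch.zeros(T + 1, N, dtype=torch.int64),
    "baseline": torch.zeros(T + 1, N),
    "policy_logits": torch.zeros(T + 1, N, num_actions),
}
rewards = torch.ones(T + 1, N)
done = torch.zeros(T + 1, N, dtype=torch.bool)

bootstrap_value = learner_outputs["baseline"][-1]              # value of the final state
learner_outputs = nest.map(lambda t: t[:-1], learner_outputs)  # steps 0..T-1
rewards, done = nest.map(lambda t: t[1:], (rewards, done))     # outcomes of those steps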
Example #11
def learn(
        flags,
        learner_queue,
        model,
        actor_model,
        D,
        optimizer,
        scheduler,
        stats,
        plogger,
        lock=threading.Lock(),
):
    for tensors in learner_queue:
        new_obs = tensors[1]
        tensors = edit_tuple(tensors, 1, new_obs["canvas"])

        tensors = nest.map(
            lambda t: t.to(flags.learner_device, non_blocking=True), tensors)

        batch, new_frame, initial_agent_state = tensors

        env_outputs, actor_outputs = batch
        obs, reward, done, step, _ = env_outputs

        lock.acquire()  # Only one thread learning at a time.

        if flags.use_tca:
            discriminator_reward = tca_reward_function(flags, obs, new_frame,
                                                       D)

            reward = env_outputs[1]
            env_outputs = edit_tuple(env_outputs, 1,
                                     reward + discriminator_reward)
            batch = edit_tuple(batch, 0, env_outputs)
        else:
            if done.any().item():
                discriminator_reward = reward_function(flags, done, new_frame,
                                                       D)

                reward = env_outputs[1]
                env_outputs = edit_tuple(env_outputs, 1,
                                         reward + discriminator_reward)
                batch = edit_tuple(batch, 0, env_outputs)

        optimizer.zero_grad()

        actor_outputs = AgentOutput._make(actor_outputs)

        learner_outputs, agent_state = model(obs, done, initial_agent_state)

        # Take final value function slice for bootstrapping.
        learner_outputs = AgentOutput._make(learner_outputs)
        bootstrap_value = learner_outputs.baseline[-1]

        # Move from obs[t] -> action[t] to action[t] -> obs[t].
        batch = nest.map(lambda t: t[1:], batch)
        learner_outputs = nest.map(lambda t: t[:-1], learner_outputs)

        # Turn into namedtuples again.
        env_outputs, actor_outputs = batch

        env_outputs = EnvOutput._make(env_outputs)
        actor_outputs = AgentOutput._make(actor_outputs)
        learner_outputs = AgentOutput._make(learner_outputs)

        discounts = (~env_outputs.done).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=actor_outputs.policy_logits,
            target_policy_logits=learner_outputs.policy_logits,
            actions=actor_outputs.action,
            discounts=discounts,
            rewards=env_outputs.reward,
            values=learner_outputs.baseline,
            bootstrap_value=bootstrap_value,
        )

        vtrace_returns = vtrace.VTraceFromLogitsReturns._make(vtrace_returns)

        pg_loss = compute_policy_gradient_loss(
            learner_outputs.policy_logits,
            actor_outputs.action,
            vtrace_returns.pg_advantages,
        )
        baseline_loss = flags.baseline_cost * compute_baseline_loss(
            vtrace_returns.vs - learner_outputs.baseline)
        entropy_loss = flags.entropy_cost * compute_entropy_loss(
            learner_outputs.policy_logits)

        total_loss = pg_loss + baseline_loss + entropy_loss

        total_loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)

        optimizer.step()
        scheduler.step()

        actor_model.load_state_dict(model.state_dict())

        episode_returns = env_outputs.episode_return[env_outputs.done]

        stats["step"] = stats.get("step",
                                  0) + flags.unroll_length * flags.batch_size
        stats["episode_returns"] = tuple(episode_returns.cpu().numpy())
        stats["mean_environment_return"] = episode_returns.mean().item()
        stats["mean_discriminator_return"] = discriminator_reward.mean().item()
        stats["mean_episode_return"] = (stats["mean_environment_return"] +
                                        stats["mean_discriminator_return"])
        stats["total_loss"] = total_loss.item()
        stats["pg_loss"] = pg_loss.item()
        stats["baseline_loss"] = baseline_loss.item()
        stats["entropy_loss"] = entropy_loss.item()
        stats["learner_queue_size"] = learner_queue.size()

        if flags.condition and new_frame.size() != 0:
            stats["l2_loss"] = F.mse_loss(
                *new_frame.split(split_size=new_frame.shape[1] //
                                 2, dim=1)).item()

        plogger.log(stats)
        lock.release()