Example #1
    def forward(self, memory_keys, memory_values, queries, masks=None):
        """Computes attention weights and the context vector.

    Args:
      memory_keys (torch.FloatTensor): batch of keys of shape
        (batch_size, num_cells, key_dim)
      memory_values (torch.FloatTensor): batch of values of shape
        (batch_size, num_cells, value_dim)
      queries (torch.FloatTensor): batch of queries of shape
        (batch_size, key_dim)
      masks (torch.ByteTensor | None): batch of masks of shape
        (batch_size, num_cells). Masks out cells where the value is 0. Defaults
        to no masking.

    Returns:
      attention_weights (torch.FloatTensor): shape (batch_size, num_cells)
      context (torch.FloatTensor): shape (batch_size, value_dim)
    """
        if masks is None:
            # Default to attending over every cell. Create the mask on the same
            # device as the keys to avoid a device mismatch on GPU.
            masks = torch.ones(
                memory_keys.shape[0], memory_keys.shape[1],
                device=memory_keys.device)
        masks = masks.unsqueeze(1)

        # (batch_size, 1, num_cells)
        attention_weights = F.softmax(self._score(queries, memory_keys), -1)
        masked_attention_weights = utils.mask_renormalize(
            attention_weights, masks)

        # (batch_size, 1, value_dim)
        context = torch.bmm(masked_attention_weights, memory_values)
        return masked_attention_weights.squeeze(1), context.squeeze(1)
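
For context, here is a minimal sketch of the two helpers this method relies on, under stated assumptions: `_score` is taken to be plain dot-product scoring and `utils.mask_renormalize` is taken to zero out masked positions and renormalize each row; the real implementations may differ, and the standalone function names below are illustrative only.

import torch

def _score(queries, memory_keys):
    # (batch_size, key_dim) -> (batch_size, 1, key_dim), then batched dot
    # products against the keys give scores of shape (batch_size, 1, num_cells).
    return torch.bmm(queries.unsqueeze(1), memory_keys.transpose(1, 2))

def mask_renormalize(probs, masks):
    # Zero out masked positions and renormalize the rest to sum to 1 along the
    # last dimension (clamped to avoid division by zero when everything is
    # masked).
    masked = probs * masks.float()
    return masked / masked.sum(-1, keepdim=True).clamp(min=1e-8)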
Example #2
    def forward(self, cache_accesses, prev_hidden_state=None, inference=False):
        """Computes cache line to evict.

      Each cache line in the cache access is scored (higher score ==> more
      likely to evict).

    Args:
      cache_accesses (list[CacheAccess]): batch of cache accesses to
        process and whose cache lines to choose from.
      prev_hidden_state (Object | None): the result from the
        previous call to this model on the previous cache access. Use None
        only on the first cache access from a trace.
      inference (bool): set to be True at inference time, when the outputs are
        not being trained on. If True, detaches the hidden state from the graph
        to save memory.

    Returns:
      scores (torch.FloatTensor): tensor of shape
        (batch_size, len(cache_access.cache_lines)). Each entry is the
        eviction score of the corresponding cache line. The candidate
        with the highest score should be chosen for eviction.
      predicted_reuse_distances (torch.FloatTensor): tensor of shape
        (batch_size, len(cache_access.cache_lines)). Each entry is the predicted
        reuse distance of the corresponding cache line.
      hidden_state (Object): hidden state to pass to the next call of
        this function. Must be called on consecutive EvictionEntries in a trace.
      access_attention (iterable[iterable[(torch.FloatTensor, CacheAccess)]]):
        batch (outer list) of attention weights. Each inner list element is the
        attention weights of each cache line in same order as
        cache_access.cache_lines (torch.FloatTensor of shape num_cache_lines)
        on a past CacheAccess arranged from earliest to most recent.
    """
        batch_size = len(cache_accesses)
        if prev_hidden_state is None:
            hidden_state, hidden_state_history, access_history = (
                self._initial_hidden_state(batch_size))
        else:
            hidden_state, hidden_state_history, access_history = prev_hidden_state

        pc_embedding = self._pc_embedder(
            [cache_access.pc for cache_access in cache_accesses])
        address_embedding = self._address_embedder(
            [cache_access.address for cache_access in cache_accesses])

        # Each (batch_size, hidden_size)
        next_c, next_h = self._lstm_cell(
            torch.cat((pc_embedding, address_embedding), -1), hidden_state)

        if inference:
            next_c = next_c.detach()
            next_h = next_h.detach()

        # Don't modify history in place
        hidden_state_history = hidden_state_history.copy()
        hidden_state_history.append(next_h)
        access_history = access_history.copy()
        access_history.append(cache_accesses)

        # Cache lines must be padded to at least length 1 for embedding layers.
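        # utils.pad (assumed behavior): right-pads each per-access list of
        # cache lines to a common length (at least min_len) with pad_token, and
        # also returns a 0/1 mask marking which entries are real vs. padding.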
        cache_lines, mask = utils.pad(
            [cache_access.cache_lines for cache_access in cache_accesses],
            min_len=1,
            pad_token=(0, 0))
        cache_lines = np.array(cache_lines)
        num_cache_lines = cache_lines.shape[1]

        # Flatten into single list
        cache_pcs = itertools.chain.from_iterable(cache_lines[:, :, 1])
        cache_addresses = itertools.chain.from_iterable(cache_lines[:, :, 0])

        # (batch_size, num_cache_lines, embed_dim)
        cache_line_embeddings = self._cache_line_embedder(
            cache_addresses).view(batch_size, num_cache_lines, -1)
        if self._cache_pc_embedder is not None:
            cache_pc_embeddings = self._cache_pc_embedder(cache_pcs).view(
                batch_size, num_cache_lines, -1)
            cache_line_embeddings = torch.cat(
                (cache_line_embeddings, cache_pc_embeddings), -1)

        # (batch_size, history_len, hidden_size)
        history_tensor = torch.stack(list(hidden_state_history), dim=1)

        # (batch_size, history_len, positional_embed_size)
        positional_embeds = self._positional_embedder(
            list(range(len(hidden_state_history)))).expand(batch_size, -1, -1)

        # attention_weights: (batch_size, num_cache_lines, history_len)
        # context: (batch_size, num_cache_lines, hidden_size + pos_embed_size)
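        # Each cache line embedding acts as a query over the history of hidden
        # states (the keys); the values are those hidden states with positional
        # embeddings appended.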
        attention_weights, context = self._history_attention(
            history_tensor, torch.cat((history_tensor, positional_embeds), -1),
            cache_line_embeddings)

        # (batch_size, num_cache_lines)
        scores = F.softmax(self._cache_line_scorer(context).squeeze(-1), -1)
        probs = utils.mask_renormalize(scores, mask)

        pred_reuse_distances = self._reuse_distance_estimator(context).squeeze(
            -1)
        # Return reuse distances as scores if probs aren't being trained.
        if len(self._loss_fns) == 1 and "reuse_dist" in self._loss_fns:
            probs = torch.max(
                pred_reuse_distances,
                torch.ones_like(pred_reuse_distances) * 1e-5) * mask.float()

        # Transpose access_history to be (batch_size, history_len)
        unbatched_histories = zip(*access_history)
        # Nested zip of attention and access_history
        access_attention = (zip(weights.transpose(0, 1), history)
                            for weights, history in zip(
                                attention_weights, unbatched_histories))

        next_hidden_state = ((next_c, next_h), hidden_state_history,
                             access_history)
        return probs, pred_reuse_distances, next_hidden_state, access_attention
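
A minimal sketch of how consecutive calls are expected to be chained, assuming a `model` instance that exposes the forward above and a `trace` list of CacheAccess objects (both names are hypothetical placeholders): the hidden state from each call is fed into the next, and `inference=True` keeps the recurrent state detached from the autograd graph.

# Minimal usage sketch; `model` and `trace` are hypothetical placeholders.
hidden_state = None
for cache_access in trace:
    # Batch of size 1; thread the hidden state across consecutive accesses.
    scores, pred_reuse, hidden_state, access_attention = model(
        [cache_access], prev_hidden_state=hidden_state, inference=True)
    # Evict the candidate with the highest eviction score.
    line_to_evict = scores.squeeze(0).argmax().item()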