Example #1
    def step(self, batch_idx: torch.Tensor, emb_t: torch.Tensor,
             is_shift: torch.Tensor, is_reduce: torch.Tensor,
             stack: torch.Tensor, stack_ptr: torch.Tensor):
        # stack_ptr_ = stack_ptr
        # stack_ptr = stack_ptr_.clone()

        # 2. Batched shift and reduce operations
        # shift
        if is_shift.any():
            shift_stack = stack[is_shift]
            shift_stack_ptr = stack_ptr[is_shift]
            idx = torch.arange(shift_stack.size(0),
                               dtype=shift_stack_ptr.dtype,
                               device=shift_stack_ptr.device)
            shift_stack[idx, shift_stack_ptr] = emb_t[is_shift]
            stack[is_shift] = shift_stack
            stack_ptr[is_shift] = shift_stack_ptr + 1

        # reduce
        if is_reduce.any():
            reduce_stack = stack[is_reduce]
            reduce_stack_ptr = stack_ptr[is_reduce]
            idx = torch.arange(reduce_stack.size(0),
                               dtype=reduce_stack_ptr.dtype,
                               device=reduce_stack_ptr.device)
            r_child = reduce_stack[idx, reduce_stack_ptr - 1]
            l_child = reduce_stack[idx, reduce_stack_ptr - 2]
            parent = self.op(l_child, r_child)
            reduce_stack[idx, reduce_stack_ptr - 2] = parent
            stack[is_reduce] = reduce_stack
            stack_ptr[is_reduce] = reduce_stack_ptr - 1

        return stack, stack_ptr
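A minimal usage sketch for the batched shift/reduce step above. `parser` is assumed to be an instance of the (unshown) class that defines `step`, with `parser.op` combining two child embeddings; the shapes (stack [B, L, D], stack_ptr [B]) are assumptions.

import torch

# `parser` is assumed to be an instance of the class defining `step` above
B, L, D = 4, 8, 16                                  # batch size, stack capacity, embedding dim
stack = torch.zeros(B, L, D)                        # per-example embedding stacks
stack_ptr = torch.full((B,), 2, dtype=torch.long)   # two items already pushed on each stack

emb_t = torch.randn(B, D)                           # embedding of the current token
is_shift = torch.tensor([True, True, False, True])
is_reduce = ~is_shift                               # each example either shifts or reduces

stack, stack_ptr = parser.step(torch.arange(B), emb_t, is_shift, is_reduce,
                               stack, stack_ptr)
print(stack_ptr)                                    # tensor([3, 3, 1, 3])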
Example #2
    def _recall_at(self, predictions: Batch, targets: Batch,
                   matches: torch.Tensor) -> Dict[int, float]:
        # matches.shape = [num_relations_pred, num_relations_gt]

        # matches.argmax(dim=0) will return the last index if no True value is found.
        # We can use matches.any(dim=0) to ignore those cases.
        # Also, we must account for the row offset in the matches matrix.
        gt_retrieved = matches.any(dim=0)

        offset = (predictions.n_edges.cumsum(dim=0).repeat_interleave(
            targets.n_edges) - predictions.n_edges[0])
        gt_retrieved_rank = matches.int().argmax(dim=0) - offset

        # [K, E_t]
        gt_retrieved_at = (gt_retrieved_rank[None, :] <
                           self.steps[:, None]) & gt_retrieved[None, :]

        # [K, num_graphs]
        gt_relation_to_graph_assignment = targets.batch[
            targets.relation_indexes[0]]
        recall_at_per_graph = scatter_(
            "mean",
            gt_retrieved_at.float(),
            index=gt_relation_to_graph_assignment,
            dim=1,
            dim_size=targets.num_graphs,
        )

        # [K]
        recall_at = recall_at_per_graph.mean(dim=1)

        return {k: v for k, v in zip(self.steps.numpy(), recall_at.numpy())}
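As a standalone illustration of the recall@K logic above (independent of the Batch/scatter_ machinery): given a boolean matches matrix for a single graph, a ground-truth relation counts as retrieved at K if its first matching prediction has rank below K. A minimal sketch with a toy matrix, assuming the prediction rows are already sorted by confidence:

import torch

# toy matches matrix: [num_relations_pred, num_relations_gt], rows sorted by score
matches = torch.tensor([
    [False, True,  False],
    [True,  False, False],
    [False, False, False],
    [False, False, False],
])
steps = torch.tensor([1, 2, 4])                      # the K values

gt_retrieved = matches.any(dim=0)                    # was each GT relation matched at all?
gt_retrieved_rank = matches.int().argmax(dim=0)      # rank of the first match

# [K, E_t]: retrieved within the top-K predictions
gt_retrieved_at = (gt_retrieved_rank[None, :] < steps[:, None]) & gt_retrieved[None, :]
recall_at = gt_retrieved_at.float().mean(dim=1)      # recall@K for this single graph
print(dict(zip(steps.tolist(), recall_at.tolist())))  # {1: 0.33..., 2: 0.66..., 4: 0.66...}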
Example #3
    def forward(self, q: Tensor, sequence_mask: Tensor) -> Tensor:
        """ Produce a single classification output for a sequence of vectors.

        Parameters
        ----------
        q : [T, B, D]
            Hidden activations after central encoder.
        sequence_mask : [T, B, 1]
            Positive mask for zeroing out padded vectors between operations.

        Returns
        -------
        classification: [B]
            Probability of this particle existing in the data.
        """

        # ------------------------------------------------------------
        # Collapse the sequence vectors into a single vector as a sum.
        # hidden : [1, B, D]
        # sequence_mask : [1, B, 1]
        # ------------------------------------------------------------
        hidden = (q * sequence_mask).sum(0, keepdim=True) / sequence_mask.sum(
            0, keepdim=True)
        sequence_mask = sequence_mask.any(dim=0, keepdim=True)

        # ------------------------------------------------------------
        # Run through the linear layer stack and output the result
        # classification : [B]
        # ------------------------------------------------------------
        hidden = self.hidden_layers(hidden, sequence_mask).squeeze()
        classification = self.output_layer(hidden).squeeze()

        return classification
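The masked mean pooling in this forward pass can be checked in isolation; a small sketch, assuming hypothetical shapes [T, B, D] and a 0/1 padding mask:

import torch

T, B, D = 5, 2, 4
q = torch.randn(T, B, D)
sequence_mask = torch.zeros(T, B, 1)
sequence_mask[:3, 0] = 1.0          # first sequence has 3 valid steps
sequence_mask[:5, 1] = 1.0          # second sequence has 5 valid steps

# masked mean over the time dimension: [1, B, D]
hidden = (q * sequence_mask).sum(0, keepdim=True) / sequence_mask.sum(0, keepdim=True)

# sanity check against an explicit mean over the valid steps of the first sequence
assert torch.allclose(hidden[0, 0], q[:3, 0].mean(0))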
Example #4
    def rescale(
        self,
        tensor: torch.Tensor,
        mask: torch.Tensor,
        image_name: str,
    ) -> torch.Tensor:
        # The tensor is cloned as in-place operations will be used
        array = tensor.clone().float().numpy()
        mask = mask.numpy()
        if not mask.any():
            message = (f'Rescaling image "{image_name}" not possible'
                       ' because the mask to compute the statistics is empty')
            warnings.warn(message, RuntimeWarning)
            return tensor
        values = array[mask]
        cutoff = np.percentile(values, self.percentiles)
        np.clip(array, *cutoff, out=array)
        if self.in_min_max is None:
            in_min, in_max = array.min(), array.max()
        else:
            in_min, in_max = self.in_min_max
        in_range = in_max - in_min
        if in_range == 0:  # should this be compared using a tolerance?
            message = (f'Rescaling image "{image_name}" not possible'
                       ' because all the intensity values are the same')
            warnings.warn(message, RuntimeWarning)
            return tensor
        array -= in_min
        array /= in_range
        out_range = self.out_max - self.out_min
        array *= out_range
        array += self.out_min
        return torch.as_tensor(array)
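A rough usage sketch for the intensity rescaling above. `transform` is assumed to be an instance of the (unshown) class defining `rescale`, configured with hypothetical percentiles and output range:

import torch

# assumptions: transform.percentiles = (0.5, 99.5); transform.in_min_max = None;
# transform.out_min, transform.out_max = 0.0, 1.0
image = torch.rand(1, 16, 16, 16) * 1000          # fake intensities
mask = torch.ones_like(image, dtype=torch.bool)   # compute statistics everywhere

rescaled = transform.rescale(image, mask, image_name='subject_001')
print(rescaled.min().item(), rescaled.max().item())  # roughly 0.0 and 1.0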
Example #5
    def forward(self,
                prev_output_tokens: Tensor,
                encoder_out: Tensor,
                encoder_padding_mask: Tensor,
                incremental_state: Optional[Dict[str, Dict[str,
                                                           Tensor]]] = None):
        # embed positions
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        ) if self.embed_positions is not None else None

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        # The tensor needs to be made contiguous after transposing because
        # fused dropout cannot handle strided data.
        if self.fuse_dropout_add:
            x = x.transpose(0, 1).contiguous()
        else:
            x = x.transpose(0, 1)
        attn = None

        # decoder layers
        for layer in self.layers:
            x, attn = layer(
                x,
                encoder_out,
                encoder_padding_mask if encoder_padding_mask.any() else None,
                incremental_state,
            )

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            x = F.linear(x, self.embed_out)

        return x, attn
Example #6
    def forward(self,  # pylint: disable=arguments-differ
                source: Dict[str, torch.Tensor],
                target: Dict[str, torch.Tensor],
                reset: torch.Tensor,
                metadata: List[Dict[str, Any]],
                mention_type: torch.Tensor = None,
                raw_entity_ids: Dict[str, torch.Tensor] = None,
                entity_ids: Dict[str, torch.Tensor] = None,
                parent_ids: Dict[str, torch.Tensor] = None,
                relations: Dict[str, torch.Tensor] = None,
                shortlist: Dict[str, torch.Tensor] = None,
                shortlist_inds: torch.Tensor = None,
                alias_copy_inds: torch.Tensor = None) -> Dict[str, torch.Tensor]:

        # Tensorize the alias_database - this will only perform the operation once.
        alias_database = metadata[0]['alias_database']
        alias_database.tensorize(vocab=self.vocab)

        # Reset the model if needed
        if reset.any() and (self._state is not None):
            for layer in range(self._num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)
        self._recent_entities.reset(reset)

        if entity_ids is not None:
            output_dict = self._forward_loop(
                source=source,
                target=target,
                alias_database=alias_database,
                mention_type=mention_type,
                raw_entity_ids=raw_entity_ids,
                entity_ids=entity_ids,
                parent_ids=parent_ids,
                relations=relations,
                shortlist=shortlist,
                shortlist_inds=shortlist_inds,
                alias_copy_inds=alias_copy_inds)
        else:
            # TODO: Figure out what we want here - probably to do some kind of inference on
            # entities / mention types.
            output_dict = {}

        return output_dict
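The stateful-reset idiom used here (and in several later examples) can be exercised on its own: for batch positions flagged by a boolean `reset` vector, the LSTM hidden and cell states are zeroed while the other positions keep their state. A small sketch with a hypothetical `state` dict standing in for `self._state`:

import torch

num_layers, batch_size, hidden_dim = 2, 4, 8
state = {
    'layer_%i' % layer: (torch.randn(1, batch_size, hidden_dim),
                         torch.randn(1, batch_size, hidden_dim))
    for layer in range(num_layers)
}

reset = torch.tensor([True, False, True, False])
if reset.any():
    for layer in range(num_layers):
        h, c = state['layer_%i' % layer]
        h[:, reset, :] = torch.zeros_like(h[:, reset, :])
        c[:, reset, :] = torch.zeros_like(c[:, reset, :])
        state['layer_%i' % layer] = (h, c)

assert state['layer_0'][0][:, 0].abs().sum() == 0   # reset position is zeroed
assert state['layer_0'][0][:, 1].abs().sum() != 0   # unreset position is untouched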
Example #7
    def forward(self,
                prev_output_tokens: Tensor,
                encoder_out: Tensor,
                encoder_padding_mask: Tensor,
                incremental_state: Optional[Dict[str, Dict[str,
                                                           Tensor]]] = None):
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        ) if self.embed_positions is not None else None

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if positions is not None:
            x += positions
        if self.training:
            if self.platform == "npu":
                x, _, _ = torch.dropoutV2(x, self.seed, p=self.prob)
            elif self.platform == "gpu":
                x = self.dropout(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        # decoder layers
        for layer in self.layers:
            x, attn = layer(
                x,
                encoder_out,
                encoder_padding_mask if encoder_padding_mask.any() else None,
                incremental_state,
            )

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        x = F.linear(x, self.embed_out)

        return x, attn
Example #8
def offset_loss(
    preds: torch.Tensor,
    labels: torch.Tensor,
    is_center: torch.Tensor,
    n_objects: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    :param preds: N2HW (float32)
    :param labels: N2HW (float32)
    :param is_center: NKHW (torch) for K-way classification
    """
    if n_objects is None:
        n_objects = is_center.sum()
    if not n_objects:
        return 0

    is_center = is_center.any(1).unsqueeze(1)
    # we could omit `* is_center` for labels if we assume they're 0 except for centers
    l1_loss = torch.nn.functional.l1_loss(preds * is_center,
                                          labels * is_center,
                                          reduction="sum")
    return l1_loss / n_objects
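A quick sanity-check sketch for `offset_loss`, with hypothetical shapes (N=2 images, K=3 center classes, 8x8 maps):

import torch

N, K, H, W = 2, 3, 8, 8
preds = torch.randn(N, 2, H, W)
labels = torch.randn(N, 2, H, W)
is_center = torch.zeros(N, K, H, W, dtype=torch.bool)
is_center[0, 1, 3, 4] = True          # one center in the first image
is_center[1, 0, 5, 2] = True          # one center in the second image

loss = offset_loss(preds, labels, is_center)
print(loss)   # L1 difference at the two center pixels, summed and divided by n_objects (2)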
Example #9
    def forward(self,  # pylint: disable=arguments-differ
                source: Dict[str, torch.Tensor],
                target: Dict[str, torch.Tensor],
                reset: torch.Tensor,
                entity_ids: torch.Tensor = None,
                shortlist: Dict[str, torch.Tensor] = None,
                shortlist_inds: torch.Tensor = None,
                alias_copy_inds: torch.Tensor = None,
                alias_tokens: Dict[str, torch.Tensor] = None,
                alias_inds: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        alias_tokens = alias_tokens['tokens']
        # Inds have fixed size and don't get truncated on split so truncate
        # now.
        alias_inds = alias_inds[:, :, :alias_tokens.shape[2]]

        # Reset the model if needed
        if reset.any() and (self._state is not None):
            for layer in range(self._num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)

        if entity_ids is not None:
            output_dict = self._forward_loop(
                source=source,
                target=target,
                alias_copy_inds=alias_copy_inds,
                alias_tokens=alias_tokens,
                alias_inds=alias_inds)
        else:
            # TODO: Figure out what we want here - probably to do some kind of inference on
            # entities / mention types.
            output_dict = {}

        return output_dict
Example #10
    def sample(self,
               source: Dict[str, torch.Tensor],
               target: Dict[str, torch.Tensor],
               reset: torch.Tensor,
               metadata: List[Dict[str, Any]],
               alias_copy_inds: torch.Tensor,
               shortlist: Dict[str, torch.Tensor] = None,
               **kwargs) -> Dict[str, Any]:  # **kwargs intended to eat the other fields if they are provided.
        """
        Sampling annotations for the generative model. Note that unlike forward, this function
        expects inputs from a **generative** dataset reader, not a **discriminative** one.
        """

        # Tensorize the alias_database - this will only perform the operation once.
        alias_database = metadata[0]['alias_database']
        alias_database.tensorize(vocab=self.vocab)

        # Reset the model if needed
        if reset.any() and (self._state is not None):
            for layer in range(self._num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)
        self._recent_entities.reset(reset)

        logp = 0.0

        mask = get_text_field_mask(target).byte()
        # We encode the target tokens (**not** source) since the discriminative model makes
        # predictions on the current token, but the generative model expects labels for the
        # **next** (e.g. target) token!
        encoded, *_ = self._encode_source(target['tokens'])
        splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2
        encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1)

        # Compute new mention logits
        mention_logits = self._fc_mention_type(encoded_token)
        mention_probs = F.softmax(mention_logits, dim=-1)
        mention_type = parallel_sample(mention_probs)
        mention_logp = mention_probs.gather(-1, mention_type.unsqueeze(-1)).log()
        mention_logp[~mask] = 0
        mention_logp = mention_logp.sum()

        # Compute entity logits
        new_entity_mask = mention_type.eq(1)
        new_entity_logits = self._new_entity_logits(encoded_head + encoded_relation, shortlist)
        if self._use_shortlist:
            # If using shortlist, then samples are indexed w.r.t the shortlist and entity_ids must be looked up
            shortlist_mask = get_text_field_mask(shortlist)
            new_entity_probs = masked_softmax(new_entity_logits, shortlist_mask)
            shortlist_inds = torch.zeros_like(mention_type)
            # Some sequences may be full of padding in which case the shortlist
            # is empty
            not_just_padding = shortlist_mask.byte().any(-1)
            shortlist_inds[not_just_padding] = parallel_sample(new_entity_probs[not_just_padding])
            shortlist_inds[~new_entity_mask] = 0
            _new_entity_logp = new_entity_probs.gather(-1, shortlist_inds.unsqueeze(-1)).log()
            new_entity_samples = shortlist['entity_ids'].gather(1, shortlist_inds)
        else:
            # If not using shortlist, then samples are indexed w.r.t. the global vocab
            new_entity_probs = F.softmax(new_entity_logits, dim=-1)
            new_entity_samples = parallel_sample(new_entity_probs)
            _new_entity_logp = new_entity_probs.gather(-1, new_entity_samples.unsqueeze(-1)).log()
            shortlist_inds = None
        # Zero out masked tokens and non-new entity predictions
        _new_entity_logp[~mask] = 0
        _new_entity_logp[~new_entity_mask] = 0
        new_entity_logp = _new_entity_logp.sum()

        # Start filling in the entity ids
        entity_ids = torch.zeros_like(target['tokens'])
        entity_ids[new_entity_mask] = new_entity_samples[new_entity_mask]

        # ...UGH we also need the raw ids - remapping time
        raw_entity_ids = torch.zeros_like(target['tokens'])
        for *index, entity_id in nested_enumerate(entity_ids.tolist()):
            token = self.vocab.get_token_from_index(entity_id, 'entity_ids')
            raw_entity_id = self.vocab.get_token_index(token, 'raw_entity_ids')
            raw_entity_ids[tuple(index)] = raw_entity_id

        # Derived mentions need to be computed sequentially.
        parent_ids = torch.zeros_like(target['tokens']).unsqueeze(-1)
        derived_entity_mask = mention_type.eq(2)
        derived_entity_logp = 0.0

        sequence_length = target['tokens'].shape[1]

        for i in range(sequence_length):

            current_mask = derived_entity_mask[:, i] & mask[:, i]

            # -------------------  SAMPLE PARENTS ---------------------

            # Update recent entities with **current** entity only
            current_entity_id = entity_ids[:, i].unsqueeze(1)
            candidate_ids, candidate_mask = self._recent_entities(current_entity_id)

            # If no mentions are derived, there is no point continuing after entities have been updated.
            if not current_mask.any():
                continue

            # Otherwise we proceed
            candidate_embeddings = self._entity_embedder(candidate_ids)

            # Compute logits w.r.t **current** hidden state only
            current_head_encoding = encoded_head[:, i].unsqueeze(1)
            selection_logits = torch.bmm(current_head_encoding, candidate_embeddings.transpose(1, 2))
            selection_probs = masked_softmax(selection_logits, candidate_mask)

            # Only sample if there is at least one viable candidate (e.g. if a sampling distribution
            # has no probability mass we cannot sample from it). Return zero as the parent for
            # non-viable distributions.
            viable_candidate_mask = candidate_mask.any(-1).squeeze()
            _parent_ids = torch.zeros_like(current_entity_id)
            parent_logp = torch.zeros_like(current_entity_id, dtype=torch.float32)
            if viable_candidate_mask.any():
                viable_candidate_ids = candidate_ids[viable_candidate_mask]
                viable_candidate_probs = selection_probs[viable_candidate_mask]
                viable_parent_samples = parallel_sample(viable_candidate_probs)
                viable_logp = viable_candidate_probs.gather(-1, viable_parent_samples.unsqueeze(-1)).log()
                viable_parent_ids = viable_candidate_ids.gather(-1, viable_parent_samples)
                _parent_ids[viable_candidate_mask] = viable_parent_ids
                parent_logp[viable_candidate_mask] = viable_logp.squeeze(-1)

            parent_ids[current_mask, i] = _parent_ids[current_mask]  # TODO: Double-check
            derived_entity_logp += parent_logp[current_mask].sum()

            # ---------------------- SAMPLE RELATION -----------------------------

            # Lookup sampled parent ids in the knowledge graph
            indices, parent_ids_list, relations_list, tail_ids_list = self._knowledge_graph_lookup(_parent_ids)
            relation_embeddings = [self._relation_embedder(r) for r in relations_list]

            # Sample tail ids
            current_relation_encoding = encoded_relation[:, i].unsqueeze(1)
            _raw_tail_ids = torch.zeros_like(_parent_ids).squeeze(-1)
            _tail_ids = torch.zeros_like(_parent_ids).squeeze(-1)
            for index, relation_embedding, tail_id_lookup in zip(indices, relation_embeddings, tail_ids_list):
                # Compute the score for each relation w.r.t the current encoding. NOTE: In the loss
                # code index has a slice. We don't need that here since there is always a
                # **single** parent.
                logits = torch.mv(relation_embedding, current_relation_encoding[index])
                # Convert to probability
                tail_probs = F.softmax(logits, dim=-1)
                # Sample
                tail_sample = torch.multinomial(tail_probs, 1)
                # Get logp. Ignoring the current_mask here is **super** dodgy, but since we forced
                # null parents to zero we shouldn't be accumulating probabilities for unused predictions.
                tail_logp = tail_probs.gather(-1, tail_sample).log()
                derived_entity_logp += tail_logp.sum()  # Sum is redundant, just need it to make logp a scalar

                # Map back to raw id
                raw_tail_id = tail_id_lookup[tail_sample]
                # Convert raw id to id
                tail_id_string = self.vocab.get_token_from_index(raw_tail_id.item(), 'raw_entity_ids')
                tail_id = self.vocab.get_token_index(tail_id_string, 'entity_ids')

                _raw_tail_ids[index[:-1]] = raw_tail_id
                _tail_ids[index[:-1]] = tail_id

            raw_entity_ids[current_mask, i] = _raw_tail_ids[current_mask]  # TODO: Double-check
            entity_ids[current_mask, i] = _tail_ids[current_mask]  # TODO: Double-check

            self._recent_entities.insert(_tail_ids, current_mask)

            # --------------------- CONTINUE MENTIONS ---------------------------------------
            continue_mask = mention_type[:, i].eq(3) & mask[:, i]
            if not continue_mask.any() or i == 0:
                continue
            raw_entity_ids[continue_mask, i] = raw_entity_ids[continue_mask, i-1]
            entity_ids[continue_mask, i] = entity_ids[continue_mask, i-1]
            parent_ids[continue_mask, i] = parent_ids[continue_mask, i-1]
            if self._use_shortlist:
                shortlist_inds[continue_mask, i] = shortlist_inds[continue_mask, i-1]
            alias_copy_inds[continue_mask, i] = alias_copy_inds[continue_mask, i-1]

        # Lastly, because entities won't always match the true entity ids,
        # we need to zero out any alias copy ids that won't be valid.
        if 'raw_entity_ids' in kwargs:
            true_raw_entity_ids = kwargs['raw_entity_ids']['raw_entity_ids']
            invalid_id_mask = ~true_raw_entity_ids.eq(raw_entity_ids)
            alias_copy_inds[invalid_id_mask] = 0

        # Pass denotes fields that are passed directly from input to output.
        sample = {
            'source': source,  # Pass
            'target': target,  # Pass
            'reset': reset,  # Pass
            'metadata': metadata,  # Pass
            'mention_type': mention_type,
            'raw_entity_ids': {'raw_entity_ids': raw_entity_ids},
            'entity_ids': {'entity_ids': entity_ids},
            'parent_ids': {'entity_ids': parent_ids},
            'relations': {'relations': None},  # We aren't using them - eventually should remove entirely
            'shortlist': shortlist,  # Pass
            'shortlist_inds': shortlist_inds,
            'alias_copy_inds': alias_copy_inds
        }
        logp = mention_logp + new_entity_logp + derived_entity_logp
        return {'sample': sample, 'logp': logp}
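The sampling pattern repeated throughout this function (draw a category per position, then gather the log-probability of the drawn index and mask out padding) can be illustrated in a few lines. This sketch mirrors what the repo's `parallel_sample` helper does, using plain `torch.multinomial` and hypothetical shapes [B, T, V]:

import torch
import torch.nn.functional as F

B, T, V = 2, 5, 7
logits = torch.randn(B, T, V)
probs = F.softmax(logits, dim=-1)

# draw one category per (batch, time) position
samples = torch.multinomial(probs.view(-1, V), 1).view(B, T)

# log-probability of each drawn sample, then zero out padded positions
logp = probs.gather(-1, samples.unsqueeze(-1)).log().squeeze(-1)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.bool)
logp = logp.masked_fill(~mask, 0.0)
total_logp = logp.sum()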
Example #11
def mask_to_instances(mask: Tensor) -> Tensor:
    r"""Given a binary mask, extract a new mask indicating contiguous
    instances. The resultant mask will use ``0`` to indicate background
    classes, and will begin numbering instance regions with ``1``.

    This method identifies connected components via nearest-neighbor message passing.
    The runtime is a function of the diameter of the largest connected component.

    Args:
        mask (:class:`torch.Tensor`):
            Binary mask to extract instances

    Shapes:
        * ``mask`` - :math:`(H, W)`
        * Output - :math:`(H, W)`
    """
    H, W = mask.shape[-2:]
    mask = mask > 0

    if not mask.any():
        return mask

    # assign each positive location a unique instance label
    with torch.random.fork_rng(devices=(mask.device, )):
        torch.random.manual_seed(42)
        instances = (torch.multinomial(torch.ones(H * W,
                                                  dtype=torch.float,
                                                  device=mask.device),
                                       H * W,
                                       replacement=False).view(H, W).long())
        instances[~mask] = 0

    # get locations of each positive label
    nodes = mask.nonzero()
    N = nodes.shape[0]
    node_idx = torch.split(nodes, [1, 1], dim=-1)

    # build adjacency tensor
    delta = torch.tensor(
        [
            [-1, -1],
            [-1, 0],
            [-1, 1],
            [0, -1],
            [0, 0],
            [0, 1],
            [1, -1],
            [1, 0],
            [1, 1],
        ],
        device=mask.device,
    )
    A = delta.shape[-2]
    adjacency = (nodes.view(-1, 1, 2) + delta.view(1, -1, 2)).reshape(-1, 2)
    adjacency[..., 0].clamp_(min=0, max=H - 1)
    adjacency[..., 1].clamp_(min=0, max=W - 1)
    adjacency_idx = torch.split(adjacency, [1, 1], dim=-1)
    passed_messages = instances[adjacency_idx].view(N, A)

    # iteratively pass instance label to neighbors
    # adopt instance label of max(self, neighbors)
    # NOTE: try to buffer things and operate in place where possible for speed
    # NOTE: something below fails to script
    old_instances = instances
    new_instances = instances.clone()
    adjacency_buffer = adjacency.new_empty(N)
    while True:
        passed_messages = old_instances[adjacency_idx].view(N, A)
        torch.amax(passed_messages, dim=-1, out=adjacency_buffer)
        new_instances[node_idx] = adjacency_buffer.view(-1, 1)

        # if nothing was updated, we're done
        diff = new_instances[mask] != old_instances[mask]
        if not diff.any():
            break

        _ = new_instances
        new_instances = old_instances
        old_instances = _

    # convert unique instance labels into a 1-indexed set of consecutive labels
    unique_instances = torch.unique(instances)
    for new_ins, old_ins in enumerate(unique_instances):
        if old_ins == 0:
            continue
        instances[instances == old_ins] = new_ins

    return instances.long()
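A small usage sketch for `mask_to_instances` on a toy mask with two separate blobs (the resulting labels are consecutive starting from 1). Since the function forks the RNG for the mask's device, the sketch places the mask on CUDA when available:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
mask = torch.zeros(8, 8, dtype=torch.bool, device=device)
mask[1:3, 1:3] = True      # first blob
mask[5:7, 5:8] = True      # second blob

instances = mask_to_instances(mask)
print(instances.unique())  # expected: tensor([0, 1, 2]) -- background plus two instances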
Example #12
    def forward(self,  # pylint: disable=arguments-differ
                source: Dict[str, torch.Tensor],
                target: Dict[str, torch.Tensor],
                reset: torch.Tensor = None) -> Dict[str, torch.Tensor]:

        # THE BELOW ONLY NEEDS TO BE SATISFIED FOR THE FANCY ITERATOR, MERITY
        # ET AL JUST PROPAGATE THE HIDDEN STATE NO MATTER WHAT
        # To make life easier when evaluating the model we use a BasicIterator
        # so that we do not need to worry about the sequence truncation
        # performed by our splitting iterators. To accommodate this, we assume
        # that if reset is not given, then everything gets reset.
        if reset is None:
            self._state = None
        elif reset.all() and (self._state is not None):
            logger.debug('RESET')
            self._state = None
        elif reset.any() and (self._state is not None):
            for layer in range(self.num_layers):
                h, c = self._state['layer_%i' % layer]
                h[:, reset, :] = torch.zeros_like(h[:, reset, :])
                c[:, reset, :] = torch.zeros_like(c[:, reset, :])
                self._state['layer_%i' % layer] = (h, c)

        target_mask = get_text_field_mask(target)
        source = source['tokens']
        target = target['tokens']

        embeddings = embedded_dropout(self.embedder, source,
                                      dropout=self.dropoute if self.training else 0)
        embeddings = self.locked_dropout(embeddings, self.dropouti)

        # Iterate through RNN layers
        current_input = embeddings
        current_hidden = []
        outputs = []
        dropped_outputs = []
        for layer, rnn in enumerate(self.rnns):

            # Bookkeeping
            if self._state is not None:
                prev_hidden = self._state['layer_%i' % layer]
            else:
                prev_hidden = None

            # Forward-pass
            output, hidden = rnn(current_input, prev_hidden)

            # More bookkeeping
            output = output.contiguous()
            outputs.append(output)
            hidden = tuple(h.detach() for h in hidden)
            current_hidden.append(hidden)

            # Apply dropout
            if layer == self.num_layers - 1:
                current_input = self.locked_dropout(output, self.dropout)
                dropped_outputs.append(output)
            else:
                current_input = self.locked_dropout(output, self.dropouth)
                dropped_outputs.append(current_input)

        # Compute logits and loss
        logits = self.decoder(current_input)
        loss = sequence_cross_entropy_with_logits(logits, target.contiguous(),
                                                  target_mask,
                                                  average="token")
        num_tokens = target_mask.float().sum() + 1e-13

        # Activation regularization
        if self.alpha:
            loss = loss + self.alpha * current_input.pow(2).mean()
        # Temporal activation regularization (slowness)
        if self.beta:
            loss = loss + self.beta * (output[:, 1:] - output[:, :-1]).pow(2).mean()

        # Update metrics and state
        unks = target.eq(self._unk_index)
        unk_penalty = self._unk_penalty * unks.float().sum()

        self.ppl(loss * num_tokens, num_tokens)
        self.upp(loss * num_tokens + unk_penalty, num_tokens)
        self._state = {'layer_%i' % l: h for l, h in enumerate(current_hidden)}

        return {'loss': loss}
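The two regularizers added to the loss above (activation regularization and temporal activation regularization, following Merity et al.) reduce to a couple of tensor expressions. A minimal sketch with hypothetical alpha/beta values and a fake RNN output of shape [B, T, D]:

import torch
import torch.nn.functional as F

alpha, beta = 2.0, 1.0                       # hypothetical regularization strengths
output = torch.randn(4, 10, 32)              # final RNN layer output: [batch, time, hidden]
dropped_output = F.dropout(output, p=0.5, training=True)

loss = torch.tensor(3.5)                     # stand-in for the cross-entropy term
# activation regularization: penalize large activations of the dropped output
loss = loss + alpha * dropped_output.pow(2).mean()
# temporal activation regularization: penalize large step-to-step changes
loss = loss + beta * (output[:, 1:] - output[:, :-1]).pow(2).mean()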
Example #13
 def _has_any(mask: Tensor) -> bool:
     r""" Checks if the mask has any set to \p True """
     assert mask.dtype == torch.bool, "Mask should be a Boolean Tensor"
     return bool(mask.any().item())
Example #14
def pgd(net: nn.Module,
        x: torch.Tensor,
        nb_iter: int = 10,
        eps: float = 0.3,
        eps_iter: float = 0.05,
        rand_minmax: float = 0.3,
        clip_min=None,
        clip_max=None,
        y=None,
        ordr=np.inf,
        rand_init=None,
        targeted=False) -> torch.Tensor:
    """
  This class implements either the Basic Iterative Method
  (Kuarkin et al. 2016) when rand_init is set to 0. or the
  Madry et al. (2017) method when rand_minmax is larger than 0.
  Paper link (Kuarkin et al. 2016): https://arxiv.org/pdf/1607.02533.pdf
  Paper link (Madry et al. 2017): https://arxiv.org/pdf/1706.06083.pdf
 
  # TODO FIX the below arguments style
  Arguments
  ---------
  model: Model
  dtype: dtype of the data
  default_rand_init: whether to use random initialization by default
  kwargs: passed through to super constructor

  Returns
  -------
  adv_x : torch.Tensor
      The adversarial Example.
  """

    # If a data range was specified, check that the input was in that range
    if clip_min is not None:
        assert bool((x >= clip_min).all()), "x has values below clip_min"

    if clip_max is not None:
        assert bool((x <= clip_max).all()), "x has values above clip_max"

    # Initialize loop variables
    if rand_init:
        eta = torch.empty_like(x).uniform_(-rand_minmax, rand_minmax)
    else:
        eta = torch.zeros_like(x)

    # Clip eta
    eta = clip_eta(eta, ordr, eps)
    adv_x = x + eta

    if clip_min is not None or clip_max is not None:
        adv_x = torch.clamp(adv_x, clip_min, clip_max)

    if y is None:
        # Use ground truth labels to avoid label leaking
        _, y = torch.max(net(x), dim=1)
    else:
        targeted = True

    if ordr == 1:
        raise NotImplementedError(
            "It's not clear that FGM is a good inner loop"
            " step for PGD when ord=1, because ord=1 FGM "
            " changes only one pixel at a time. We need "
            " to rigorously test a strong ord=1 PGD "
            " before enabling this feature.")
    for _ in range(nb_iter):
        # Do a projected gradient step.
        adv_x = fgm(net,
                    adv_x,
                    eps=eps_iter,
                    ordr=ordr,
                    clip_min=clip_min,
                    clip_max=clip_max,
                    y=y,
                    targeted=targeted)

        # Clipping perturbation eta to ord norm ball
        eta = adv_x - x
        eta = clip_eta(eta, ordr, eps)
        adv_x = x + eta

        # Redo the clipping.
        # FGM already did it, but subtracting and re-adding eta can add some
        # small numerical error
        if clip_min is not None or clip_max is not None:
            adv_x = torch.clamp(adv_x, clip_min, clip_max)

    # Asserts run only on CPU
    # When multi-GPU eval code tries to force all PGD ops onto GPU, this
    # can cause an error.
    # The 1e-6 is needed to compensate for numerical error.
    # Without the 1e-6 this fails when e.g. eps=.2, clip_min=.5, clip_max=.7
    if ordr == np.inf and clip_min is not None and clip_max is not None:
        assert eps <= (1e-6 + clip_max - clip_min)

    return adv_x
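A usage sketch for `pgd`, assuming `fgm` and `clip_eta` are available in the same module (they are not shown here) and using a tiny hypothetical classifier:

import numpy as np
import torch
import torch.nn as nn

# hypothetical classifier and inputs scaled to [0, 1]
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x = torch.rand(8, 1, 28, 28)

adv_x = pgd(net, x,
            nb_iter=10, eps=0.3, eps_iter=0.05,
            clip_min=0.0, clip_max=1.0,
            ordr=np.inf, rand_init=True)

print((adv_x - x).abs().max().item())          # should be at most eps (0.3)
print(adv_x.min().item(), adv_x.max().item())  # should stay inside [clip_min, clip_max]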