Example #1
0
    def _forward_train(self, encodings: torch.tensor,
                       context_masks: torch.tensor, entity_masks: torch.tensor,
                       entity_sizes: torch.tensor, relations: torch.tensor,
                       rel_masks: torch.tensor):
        """Training forward pass producing entity and relation logits.

        Args:
            encodings: token ids fed to ``self.bert``.
            context_masks: attention mask for ``encodings`` (cast to float).
            entity_masks: entity-candidate masks (cast to float) passed to
                ``self._classify_entities``.
            entity_sizes: entity-candidate sizes, embedded with
                ``self.size_embeddings``.
            relations: relation candidates; dimension 1 indexes candidates.
            rel_masks: relation-context masks (cast to float, trailing dim
                added).

        Returns:
            ``(entity_clf, rel_clf)``; ``rel_clf`` has shape
            ``[batch_size, relations.shape[1], self._relation_types]`` and holds
            the raw relation logits.
        """
        # contextualized token embeddings from the last transformer layer
        context_masks = context_masks.float()
        h = self.bert(input_ids=encodings, attention_mask=context_masks)[0]

        entity_masks = entity_masks.float()
        batch_size = encodings.shape[0]
        num_relations = relations.shape[1]

        # classify entities (candidate sizes are embedded first)
        size_embeddings = self.size_embeddings(entity_sizes)
        entity_clf, entity_spans_pool = self._classify_entities(
            encodings, h, entity_masks, size_embeddings)

        # classify relations
        rel_masks = rel_masks.float().unsqueeze(-1)
        # one copy of h per relation slot in a chunk; at least 1 so repeat()
        # stays valid when there are no relation candidates
        h_large = h.unsqueeze(1).repeat(
            1, max(min(num_relations, self._max_pairs), 1), 1, 1)
        # allocate directly on the classifier's device rather than building
        # the buffer on the CPU and copying it over
        rel_clf = torch.zeros(
            [batch_size, num_relations, self._relation_types],
            device=self.rel_classifier.weight.device)

        # relation logits are computed in chunks of self._max_pairs candidates
        # to bound memory usage
        for i in range(0, num_relations, self._max_pairs):
            chunk_rel_logits = self._classify_relations(
                entity_spans_pool, size_embeddings, relations, rel_masks,
                h_large, i)
            rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_logits

        return entity_clf, rel_clf
Example #2
0
    def _forward_train_common(self,
                              encodings: torch.tensor,
                              context_masks: torch.tensor,
                              mention_masks: torch.tensor,
                              mention_sizes: torch.tensor,
                              entities: torch.tensor,
                              entity_masks: torch.tensor,
                              coref_mention_pairs: torch.tensor,
                              coref_eds,
                              max_spans=None,
                              max_coref_pairs=None,
                              **kwargs):
        """Shared training pass for mention, entity and coreference heads.

        The BERT ``last_hidden_state`` is turned into mention representations,
        which are pooled into entity representations using ``entities`` and
        ``entity_masks``. Three heads are then applied, in this order: mention
        localization, entity classification and coreference resolution. All
        masks are cast to float. Extra keyword arguments are ignored.

        Returns:
            tuple ``(h, mention_reprs, entity_reprs, mention_clf, entity_clf,
            coref_clf)``.
        """
        hidden = self.bert(
            input_ids=encodings,
            attention_mask=context_masks.float())['last_hidden_state']
        mention_reprs = self.mention_representation(
            hidden, mention_masks.float(), max_spans=max_spans)
        entity_reprs = self.entity_representation(
            mention_reprs, entities, entity_masks.float())

        # task heads
        mention_clf = self.mention_localization(mention_reprs, mention_sizes)
        entity_clf = self.entity_classification(entity_reprs)
        coref_clf = self.coreference_resolution(
            mention_reprs, coref_mention_pairs, coref_eds,
            max_pairs=max_coref_pairs)

        return (hidden, mention_reprs, entity_reprs,
                mention_clf, entity_clf, coref_clf)
Example #3
0
    def forward(self, encodings: torch.tensor, context_masks: torch.tensor, mention_masks: torch.tensor,
                mention_sample_masks: torch.tensor, coref_mention_pairs: torch.tensor,
                coref_eds: torch.tensor, coref_sample_masks: torch.tensor,
                max_spans=None, max_coref_pairs=None, valid_mentions=None, inference=False, **kwargs):
        """Coreference forward pass.

        Without ``inference`` only the raw coreference logits are returned.

        With ``inference=True``:
        - Clusters are built with ``misc.create_clusters`` from the logits.
          ``valid_mentions`` defaults to ``mention_sample_masks`` as float on
          the context mask's device.
        - The returned ``coref_clf`` is a sigmoid probability, set to 0 below
          ``self._coref_threshold`` and multiplied by the pair sample masks.

        Returns:
            dict with ``coref_clf`` and, in inference mode, also ``clusters``
            and ``clusters_sample_masks``.
        """
        coref_sample_masks = coref_sample_masks.float()

        hidden = self.bert(input_ids=encodings, attention_mask=context_masks.float())['last_hidden_state']
        mention_reprs = self.mention_representation(hidden, mention_masks.float(), max_spans=max_spans)
        coref_clf = self.coreference_resolution(mention_reprs, coref_mention_pairs, coref_eds,
                                                max_pairs=max_coref_pairs)

        if not inference:
            return {'coref_clf': coref_clf}

        if valid_mentions is None:
            # every sampled mention counts as valid
            valid_mentions = torch.ones(mention_sample_masks.shape, dtype=torch.float).to(context_masks.device)
            valid_mentions *= mention_sample_masks

        # clusters are derived from the raw logits
        clusters, clusters_sample_masks = misc.create_clusters(
            coref_clf, coref_mention_pairs, coref_sample_masks, valid_mentions, self._coref_threshold)

        coref_probs = torch.sigmoid(coref_clf)
        coref_probs[coref_probs < self._coref_threshold] = 0
        coref_probs *= coref_sample_masks

        return {'coref_clf': coref_probs, 'clusters': clusters, 'clusters_sample_masks': clusters_sample_masks}
Example #4
0
def focal_loss(
    y_pred: torch.tensor,
    y_true: torch.tensor,
    alpha: float = 1,
    gamma: float = 2,
    reduce: bool = True
) -> torch.tensor:
    """Binary focal loss computed from predicted probabilities.

    The element-wise binary cross-entropy ``bce`` is re-weighted as
    ``alpha * (1 - p_t) ** gamma * bce`` with ``p_t = exp(-bce)``, the
    probability the model assigned to the true label.

    Args:
        y_pred (torch.tensor): predicted probabilities (cast to float).
        y_true (torch.tensor): binary targets (cast to float).
        alpha (float, optional): constant scaling factor. Defaults to 1.
        gamma (float, optional): focusing exponent; larger values down-weight
            well-classified elements more strongly. Defaults to 2.
        reduce (bool, optional): if True, return the mean over all elements.
            Defaults to True.

    Returns:
        torch.tensor: the scalar mean loss when ``reduce`` is True, otherwise
        the element-wise loss.
    """
    per_element_bce = F.binary_cross_entropy(
        y_pred.float(), y_true.float(), reduction='none')
    modulating_factor = (1 - torch.exp(-per_element_bce)) ** gamma
    focal = alpha * modulating_factor * per_element_bce
    return focal.mean() if reduce else focal
    def _forward_encode(self, encodings: torch.tensor,
                        context_mask: torch.tensor, entity_masks: torch.tensor,
                        entity_sizes: torch.tensor, relations: torch.tensor,
                        rel_masks: torch.tensor):
        """Encode entity candidates and relation candidates.

        Args:
            encodings: token ids fed to ``self.bert``.
            context_mask: attention mask for ``encodings`` (cast to float).
            entity_masks: entity-candidate masks (cast to float).
            entity_sizes: entity-candidate sizes, embedded with
                ``self.size_embeddings``.
            relations: relation candidates; dimension 1 indexes candidates.
            rel_masks: relation-context masks (cast to float, trailing dim
                added).

        Returns:
            ``(entity_encoding, rel_encoding)``. ``rel_encoding`` has shape
            ``[batch_size, relations.shape[1], self.encoding_size]`` and lives
            on the device of ``self.entity_encoder.weight``.
        """
        # contextualized token embeddings from the last transformer layer
        context_mask = context_mask.float()
        h = self.bert(input_ids=encodings, attention_mask=context_mask)[0]

        entity_masks = entity_masks.float()
        batch_size = encodings.shape[0]
        num_relations = relations.shape[1]
        device = self.entity_encoder.weight.device

        # encode entities (candidate sizes are embedded first)
        size_embeddings = self.size_embeddings(entity_sizes)
        entity_encoding, entity_spans_pool = self._encode_entities(
            encodings, h, entity_masks, size_embeddings)

        # prepare relation encoding
        rel_masks = rel_masks.float().unsqueeze(-1)
        # at least 1 repeat so this stays valid with no relation candidates
        h_large = h.unsqueeze(1).repeat(
            1, max(min(num_relations, self._max_pairs), 1), 1, 1)
        # allocate directly on the target device instead of CPU + copy
        rel_encoding = torch.zeros(
            [batch_size, num_relations, self.encoding_size], device=device)

        # encode relation candidates in chunks of self._max_pairs to bound
        # memory usage
        for i in range(0, num_relations, self._max_pairs):
            rel_encoding_chunk = self._encode_relations(
                entity_spans_pool, size_embeddings, relations, rel_masks,
                h_large, i)
            rel_encoding[:, i:i + self._max_pairs, :] = rel_encoding_chunk

        return entity_encoding, rel_encoding
Example #6
0
    def forward(self, encodings: torch.tensor, context_masks: torch.tensor, mention_masks: torch.tensor,
                entities: torch.tensor, entity_masks: torch.tensor,
                rel_entity_pairs: torch.tensor, rel_sample_masks: torch.tensor,
                rel_entity_pair_mp: torch.tensor, rel_mention_pair_ep: torch.tensor,
                rel_mention_pairs: torch.tensor, rel_ctx_masks: torch.tensor, rel_pair_masks: torch.tensor,
                rel_token_distances: torch.tensor, rel_sentence_distances: torch.tensor, entity_types: torch.tensor,
                max_spans: int = None, max_rel_pairs: int = None, inference: bool = False, *args, **kwargs):
        """Relation-classification forward pass.

        Builds mention, entity and entity-pair representations from the BERT
        ``last_hidden_state``. Each candidate pair is then scored with
        ``self.relation_classification``. The entity types of each pair
        (gathered with ``util.batch_index``) are passed to the classifier as
        well.

        Args:
            max_spans: ``None`` or a limit, forwarded to
                ``self.mention_representation``.
            max_rel_pairs: ``None`` or a limit, forwarded as ``max_pairs`` to
                ``self.relation_classification``.
            inference: if True, the logits become sigmoid scores. Scores below
                ``self._rel_threshold`` are set to 0, and the result is
                multiplied by ``rel_sample_masks`` (trailing dim added).

        Returns:
            dict with key ``rel_clf``.
        """
        context_masks = context_masks.float()
        mention_masks = mention_masks.float()
        entity_masks = entity_masks.float()

        h = self.bert(input_ids=encodings, attention_mask=context_masks)['last_hidden_state']
        mention_reprs = self.mention_representation(h, mention_masks, max_spans=max_spans)
        entity_reprs = self.entity_representation(mention_reprs, entities, entity_masks)
        entity_pair_reprs = self.entity_pair_representation(entity_reprs, rel_entity_pairs)

        # entity types of both members of each candidate pair
        rel_entity_types = util.batch_index(entity_types, rel_entity_pairs)
        rel_clf = self.relation_classification(entity_pair_reprs, h, mention_reprs,
                                               rel_entity_pair_mp, rel_mention_pair_ep,
                                               rel_mention_pairs, rel_ctx_masks, rel_pair_masks,
                                               rel_token_distances, rel_sentence_distances, rel_entity_types,
                                               max_pairs=max_rel_pairs)

        if inference:
            rel_clf = torch.sigmoid(rel_clf)
            rel_clf[rel_clf < self._rel_threshold] = 0
            rel_clf *= rel_sample_masks.unsqueeze(-1)

        return dict(rel_clf=rel_clf)
    def _forward_eval(self,
                      encodings: torch.tensor,
                      context_mask: torch.tensor,
                      entity_masks: torch.tensor,
                      entity_sizes: torch.tensor,
                      entity_spans: torch.tensor = None,
                      entity_sample_mask: torch.tensor = None):
        """Evaluation pass returning entity-class probabilities.

        ``entity_spans`` and ``entity_sample_mask`` are accepted so the
        signature matches the relation-aware variant, but they are not used.

        Returns:
            ``torch.softmax(entity_clf, dim=2)`` of the entity logits.
        """
        # contextualized token embeddings from the last transformer layer
        context_mask = context_mask.float()
        h = self.bert(input_ids=encodings, attention_mask=context_mask)[0]

        # enhance hidden features, restoring the original shape afterwards
        orig_shape = h.shape
        h = self.feature_enhancer.prepare_input(h, context_mask)
        h = self.feature_enhancer(h)
        h = self.feature_enhancer.prepare_output(h, orig_shape)

        entity_masks = entity_masks.float()

        # classify entities; the pooled spans are only needed for relations
        size_embeddings = self.size_embeddings(entity_sizes)
        entity_clf, _ = self._classify_entities(
            encodings, h, entity_masks, size_embeddings)

        return torch.softmax(entity_clf, dim=2)
Example #8
0
    def _forward_inference_common(self,
                                  encodings: torch.tensor,
                                  context_masks: torch.tensor,
                                  mention_masks: torch.tensor,
                                  mention_sizes: torch.tensor,
                                  mention_spans: torch.tensor,
                                  mention_sample_masks: torch.tensor,
                                  max_spans=None,
                                  max_coref_pairs=None):
        """End-to-end inference: mentions, then coreference clusters, then entities.

        Steps:
        1. Embed the documents with ``self.bert`` and build mention
           representations.
        2. Score mentions with ``self.mention_localization``. A mention is kept
           when its sigmoid score is at least ``self._mention_threshold``; the
           result is multiplied by ``mention_sample_masks``.
        3. Pair the kept mentions with ``misc.create_coref_mention_pairs`` and
           score the pairs with ``self.coreference_resolution``.
        4. Group them into clusters with ``misc.create_clusters``.
        5. Represent and classify each cluster as an entity.

        Returns:
            tuple ``(h, mention_reprs, entity_reprs, clusters,
            entity_sample_masks, coref_sample_masks, clusters_sample_masks,
            mention_clf, entity_clf, coref_clf)``.
        """
        sample_masks = mention_sample_masks.float()

        # document embeddings and mention representations
        hidden = self.bert(input_ids=encodings,
                           attention_mask=context_masks.float())['last_hidden_state']
        mention_reprs = self.mention_representation(hidden,
                                                    mention_masks.float(),
                                                    max_spans=max_spans)

        # mention detection: threshold the sigmoid score, then apply the mask
        mention_clf = self.mention_localization(mention_reprs, mention_sizes)
        is_mention = torch.sigmoid(mention_clf) >= self._mention_threshold
        valid_mentions = is_mention.float() * sample_masks

        # candidate pairs among detected mentions, then score them
        coref_mention_pairs, coref_mention_eds, coref_sample_masks = misc.create_coref_mention_pairs(
            valid_mentions, mention_spans, encodings, self._tokenizer)
        coref_sample_masks = coref_sample_masks.float()
        coref_clf = self.coreference_resolution(mention_reprs,
                                                coref_mention_pairs,
                                                coref_mention_eds,
                                                max_pairs=max_coref_pairs)

        # clusters; an entity slot is live if any of its cluster mask entries is set
        clusters, clusters_sample_masks = misc.create_clusters(
            coref_clf, coref_mention_pairs, coref_sample_masks, valid_mentions,
            self._coref_threshold)
        entity_sample_masks = clusters_sample_masks.any(-1).float()

        # entity representations from clusters, then classification
        entity_reprs = self.entity_representation(
            mention_reprs, clusters, clusters_sample_masks.float())
        entity_clf = self.entity_classification(entity_reprs)

        return (hidden, mention_reprs, entity_reprs, clusters,
                entity_sample_masks, coref_sample_masks,
                clusters_sample_masks, mention_clf, entity_clf, coref_clf)
    def _forward_eval(self,
                      encodings: torch.tensor,
                      context_mask: torch.tensor,
                      entity_masks: torch.tensor,
                      entity_sizes: torch.tensor,
                      entity_spans: torch.tensor = None,
                      entity_sample_mask: torch.tensor = None):
        """Evaluation pass producing entity probabilities and relation scores.

        Entity candidates are classified first. ``self._filter_spans`` then
        builds relation candidates from the classification. Those candidates
        are scored in chunks of ``self._max_pairs``.

        Returns:
            ``(entity_clf, rel_clf, relations)``:
            - ``entity_clf``: softmax over dim 2 of the entity logits.
            - ``rel_clf``: sigmoid relation scores of shape
              ``[batch, relations.shape[1], self._relation_types]``, zeroed
              where the relation sample mask is 0.
            - ``relations``: the candidates returned by ``_filter_spans``.
        """
        # contextualized token embeddings from the last transformer layer
        context_mask = context_mask.float()
        h = self.bert(input_ids=encodings, attention_mask=context_mask)[0]

        # enhance hidden features, restoring the original shape afterwards
        orig_shape = h.shape
        h = self.feature_enhancer.prepare_input(h, context_mask)
        h = self.feature_enhancer(h)
        h = self.feature_enhancer.prepare_output(h, orig_shape)

        entity_masks = entity_masks.float()
        batch_size = encodings.shape[0]
        ctx_size = context_mask.shape[-1]

        # classify entities (candidate sizes are embedded first)
        size_embeddings = self.size_embeddings(entity_sizes)
        entity_clf, entity_spans_pool = self._classify_entities(
            encodings, h, entity_masks, size_embeddings)

        # ignore entity candidates that do not constitute an actual entity for
        # relations (based on classifier)
        relations, rel_masks, rel_sample_masks = self._filter_spans(
            entity_clf, entity_spans, entity_sample_mask, ctx_size)
        num_relations = relations.shape[1]
        # NOTE(review): the training variants add a trailing dim to rel_masks
        # before _classify_relations; here it keeps the rank _filter_spans
        # returns — confirm both paths feed _classify_relations consistently.
        rel_masks = rel_masks.float()
        # one mask value per relation, shaped [batch, relations, 1] so it
        # broadcasts over the relation-type dim of rel_clf ([batch, relations]
        # would not broadcast against [batch, relations, types])
        rel_sample_masks = rel_sample_masks.float().reshape(
            batch_size, num_relations, 1)
        # at least 1 repeat so this stays valid with no relation candidates
        h_large = h.unsqueeze(1).repeat(
            1, max(min(num_relations, self._max_pairs), 1), 1, 1)
        rel_clf = torch.zeros(
            [batch_size, num_relations, self._relation_types],
            device=self.rel_classifier.weight.device)

        # score relation candidates in chunks to bound memory usage
        for i in range(0, num_relations, self._max_pairs):
            chunk_rel_logits = self._classify_relations(
                entity_spans_pool, size_embeddings, relations, rel_masks,
                h_large, i)
            rel_clf[:, i:i + self._max_pairs, :] = torch.sigmoid(chunk_rel_logits)

        rel_clf = rel_clf * rel_sample_masks  # zero scores of masked-out samples

        entity_clf = torch.softmax(entity_clf, dim=2)

        return entity_clf, rel_clf, relations
    def dice(self,
             input: torch.tensor,
             target: torch.tensor,
             weight: torch.tensor,
             epsilon=1e-6) -> torch.tensor:
        """
        Computes the Dice coefficient as defined in
        https://arxiv.org/abs/1606.04797 for a multi-channel input and target.
        Assumes the input is a normalized probability, e.g. the result of a
        Sigmoid or Softmax function.

        :param input: NxCxSpatial input tensor
        :param target: NxCxSpatial target tensor (cast to float)
        :param weight: optional weight tensor broadcastable against the
            per-channel intersection (e.g. shape [C]); ``None`` disables
            weighting. Channels represent the classes.
        :param epsilon: lower bound applied to the denominator to prevent
            division by zero
        :return: tensor of Dice coefficients, one per row of the output of
            ``self._flatten`` (presumably one per channel)
        """
        assert input.size() == target.size(
        ), "'input' and 'target' must have the same shape"

        input = self._flatten(input)
        target = self._flatten(target)
        target = target.float()

        # per-channel Dice coefficient
        intersect = (input * target).sum(-1)
        if weight is not None:
            intersect = weight * intersect

        # squared-sum denominator, as in the V-Net formulation
        union = (input * input).sum(-1) + (target * target).sum(-1)
        return 2 * (intersect / union.clamp(min=epsilon))
Example #11
0
def masked_cross_entropy(logits: torch.tensor, target: torch.tensor, length: torch.tensor):
    """Cross-entropy over padded sequences, averaged over real time steps.

    Args:
        logits: A tensor containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A tensor containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A tensor containing a LongTensor of size (batch,)
            which contains the length of each data in a batch.

    Returns:
        loss: the summed negative log-likelihood over unmasked steps divided
        by the total length, i.e. the mean loss per real time step.
    """
    # logits_flat: (batch * max_len, num_classes); reshape (unlike view) also
    # accepts non-contiguous inputs such as transposed logits
    logits_flat = logits.reshape(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = functional.log_softmax(logits_flat, dim=-1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.reshape(-1, 1)
    # losses_flat: (batch * max_len, 1) — negative log-prob of the true class
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len); steps beyond each sequence's length contribute 0
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss
Example #12
0
  def __init__(
          self,
          model: torch.nn.Module,
          device: torch.device,
          optimizer: torch.optim.Adam,
          file_path: str,
          train_loader: DataLoader,
          valid_loader: DataLoader,
          test_loader: DataLoader,
          unique_cells: Sequence[int],
          num_epochs: int,
          cells_tensor: torch.tensor,
          label_to_cellid: Dict[int, int],
          is_distance_distribution: bool

          ):
    """Store the training setup and prepare the output directory.

    Args:
      model, optimizer, train_loader, valid_loader, test_loader, unique_cells,
        num_epochs, label_to_cellid, is_distance_distribution: stored as
        attributes of the same name.
      device: stored, and used as the target device of ``cells_tensor``.
      file_path: output directory. It is created, along with any missing
        parent directories, if it does not exist. ``model_path`` and
        ``metrics_path`` point to ``model.pt`` and ``metrics.tsv`` inside it.
      cells_tensor: stored as a float tensor on ``device``.
    """
    self.model = model
    self.device = device
    self.optimizer = optimizer
    self.file_path = file_path
    self.train_loader = train_loader
    self.valid_loader = valid_loader
    self.test_loader = test_loader
    self.unique_cells = unique_cells
    self.num_epochs = num_epochs
    self.cells_tensor = cells_tensor.float().to(self.device)
    self.label_to_cellid = label_to_cellid
    self.cos = nn.CosineSimilarity(dim=2)
    self.best_valid_loss = float("Inf")
    # makedirs also creates missing parents and, with exist_ok, tolerates the
    # directory appearing between a check and the creation
    os.makedirs(self.file_path, exist_ok=True)
    self.model_path = os.path.join(self.file_path, 'model.pt')
    self.metrics_path = os.path.join(self.file_path, 'metrics.tsv')
    self.is_distance_distribution = is_distance_distribution
Example #13
0
def print_tensor(x: torch.tensor, pt=True, val=True, shp=True):
    """Summarize a tensor as text and optionally log it at DEBUG level.

    Args:
        x: tensor to summarize; it is converted to float first.
        pt: if True, emit the message with ``logging.debug``.
        val: if True, include summary statistics of the values (mean, min,
            max, median, std). A single-element tensor is reported as that
            element, and an empty tensor as ``empty tensor``.
        shp: if True, include the tensor's shape on the first line.

    Returns:
        str: the composed message.
    """
    x = x.float()
    message = ''
    if shp:
        message = str(x.shape) + '\n'
    if val:
        x = x.flatten()
        if x.numel() == 0:
            # min/max/median are undefined for an empty tensor and would raise
            message += 'empty tensor'
        elif x.numel() != 1:
            message += (
                'mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f'
                % (x.mean(), x.min(), x.max(), x.median(), x.std()))
        else:
            message += f'one element {x[0]}'
    if pt:
        logging.debug(message)
    return message
Example #14
0
    def forward(self, encodings: torch.tensor, context_masks: torch.tensor, mention_masks: torch.tensor,
                mention_sizes: torch.tensor, mention_sample_masks: torch.tensor,
                max_spans=None, inference=False, **kwargs):
        """Mention-localization forward pass.

        Mention representations come from the BERT ``last_hidden_state`` and
        are scored by ``self.mention_localization``. In inference mode:
        - the logits are mapped through a sigmoid;
        - scores below ``self._mention_threshold`` are set to 0;
        - the result is multiplied by ``mention_sample_masks``.

        Returns:
            dict with key ``mention_clf``.
        """
        hidden = self.bert(input_ids=encodings, attention_mask=context_masks.float())['last_hidden_state']
        mention_reprs = self.mention_representation(hidden, mention_masks.float(), max_spans=max_spans)
        mention_clf = self.mention_localization(mention_reprs, mention_sizes)

        if inference:
            mention_clf = mention_clf.sigmoid()
            mention_clf[mention_clf < self._mention_threshold] = 0
            mention_clf *= mention_sample_masks

        return {'mention_clf': mention_clf}
Example #15
0
    def _forward_train(self, encodings: torch.tensor,
                       context_mask: torch.tensor, entity_masks: torch.tensor,
                       entity_sizes: torch.tensor, relations: torch.tensor,
                       rel_masks: torch.tensor):
        """Training pass with feature enhancement: entity and relation logits.

        The BERT output is passed through ``self.feature_enhancer``
        (``prepare_input`` -> call -> ``prepare_output``) before entities and
        relations are classified.

        Args:
            encodings: token ids fed to ``self.bert``.
            context_mask: attention mask for ``encodings`` (cast to float);
                also given to ``feature_enhancer.prepare_input``.
            entity_masks: entity-candidate masks (cast to float).
            entity_sizes: entity-candidate sizes, embedded with
                ``self.size_embeddings``.
            relations: relation candidates; dimension 1 indexes candidates.
            rel_masks: relation-context masks (cast to float, trailing dim
                added).

        Returns:
            ``(entity_clf, rel_clf)``; ``rel_clf`` has shape
            ``[batch_size, relations.shape[1], self._relation_types]`` and holds
            the raw relation logits.
        """
        # contextualized token embeddings from the last transformer layer
        context_mask = context_mask.float()
        h_bert = self.bert(input_ids=encodings, attention_mask=context_mask)[0]

        # enhance hidden features, restoring the BERT output shape afterwards
        orig_shape = h_bert.shape
        h = self.feature_enhancer.prepare_input(h_bert, context_mask)
        h = self.feature_enhancer(h)
        h = self.feature_enhancer.prepare_output(h, orig_shape)

        entity_masks = entity_masks.float()
        batch_size = encodings.shape[0]
        num_relations = relations.shape[1]

        # classify entities (candidate sizes are embedded first)
        size_embeddings = self.size_embeddings(entity_sizes)
        entity_clf, entity_spans_pool = self._classify_entities(
            encodings, h, entity_masks, size_embeddings)

        # classify relations
        rel_masks = rel_masks.float().unsqueeze(-1)
        # at least 1 repeat so this stays valid with no relation candidates
        h_large = h.unsqueeze(1).repeat(
            1, max(min(num_relations, self._max_pairs), 1), 1, 1)
        # allocate directly on the classifier's device instead of CPU + copy
        rel_clf = torch.zeros(
            [batch_size, num_relations, self._relation_types],
            device=self.rel_classifier.weight.device)

        # relation logits in chunks of self._max_pairs to bound memory usage
        for i in range(0, num_relations, self._max_pairs):
            chunk_rel_logits = self._classify_relations(
                entity_spans_pool, size_embeddings, relations, rel_masks,
                h_large, i)
            rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_logits

        return entity_clf, rel_clf
Example #16
0
    def forward(self, encodings: torch.tensor, context_masks: torch.tensor, mention_masks: torch.tensor,
                mention_sizes: torch.tensor, entity_sample_masks: torch.tensor,
                entities: torch.tensor, entity_masks: torch.tensor, max_spans=None, inference=False, **kwargs):
        """Entity-classification forward pass.

        ``mention_sizes`` and extra keyword arguments are accepted but not used.
        In inference mode:
        - the logits become a softmax over the last dimension;
        - the result is multiplied by ``entity_sample_masks`` (cast to float,
          trailing dim added), so masked-out entities get all-zero rows.

        Returns:
            dict with key ``entity_clf``.
        """
        hidden = self.bert(input_ids=encodings, attention_mask=context_masks.float())['last_hidden_state']
        mention_reprs = self.mention_representation(hidden, mention_masks.float(), max_spans=max_spans)
        entity_clf = self.entity_classification(
            self.entity_representation(mention_reprs, entities, entity_masks))

        if inference:
            entity_clf = entity_clf.softmax(dim=-1)
            entity_clf *= entity_sample_masks.float().unsqueeze(-1)

        return {'entity_clf': entity_clf}
Example #17
0
 def forward(self,
             wav_mask: torch.tensor,
             mask_val: float = 0.0) -> torch.tensor:
     """Map a waveform mask to a frame-level mask with ``self.conv``.

     The mask is cast to float, given a channel axis and passed through
     ``self.conv`` without gradient tracking. An output frame is 1.0 wherever
     the convolution result differs from ``mask_val``, else 0.0.
     """
     with torch.no_grad():
         conv_in = wav_mask.float().unsqueeze(1)
         frames = self.conv(conv_in).squeeze(1)
         return frames.ne(mask_val).float()
Example #18
0
 def forward(self, wav_mask: torch.tensor) -> torch.tensor:
     """Map a waveform mask to a frame-level mask with ``self.conv``.

     The mask gets ``self.win_length // 2`` zeros appended at the end and the
     same number of ones prepended at the start. It is then cast to float,
     convolved with ``self.conv`` (no gradient tracking) and rounded up with
     ``ceil``.
     """
     half_win = self.win_length // 2
     with torch.no_grad():
         padded = F.pad(wav_mask, [0, half_win], value=0.)
         padded = F.pad(padded, [half_win, 0], value=1.)
         frames = self.conv(padded.float().unsqueeze(1)).squeeze(1)
         return frames.ceil()
Example #19
0
    def __call__(self, prediction: torch.tensor, target: torch.tensor, **kwargs) -> torch.tensor:
        """Flatten prediction/target and apply ``self.func`` to them.

        Both tensors have axis ``self.axis`` swapped with the last axis and
        are made contiguous. The target is cast to float when
        ``self.to_float`` is set. The prediction is then flattened: to 2-D
        ``(-1, last_dim)`` when ``self.is_2d``, otherwise to 1-D. The target
        is always flattened to 1-D. Extra keyword arguments are forwarded to
        ``self.func``.
        """
        pred = prediction.transpose(self.axis, -1).contiguous()
        tgt = target.transpose(self.axis, -1).contiguous()
        if self.to_float:
            tgt = tgt.float()

        if self.is_2d:
            flat_pred = pred.view(-1, pred.shape[-1])
        else:
            flat_pred = pred.view(-1)
        return self.func.__call__(flat_pred, tgt.view(-1), **kwargs)
Example #20
0
def bce_loss(
    y_pred: torch.tensor,
    y_true: torch.tensor,
    reduce: bool = True,
    label_weight_mapping: dict = None
) -> torch.tensor:
    """Binary cross-entropy with optional per-label, per-value weights.

    Args:
        y_pred: predicted probabilities (cast to float).
        y_true: binary targets, same shape as ``y_pred``. When
            ``label_weight_mapping`` is given, ``y_true`` is iterated row by
            row and each row column by column.
        reduce: if True return the mean loss, otherwise the element-wise loss.
        label_weight_mapping: optional ``{column_index: {target_value: weight}}``.
            Each element's weight is looked up by its column index and by its
            target value cast to ``int``.

    Returns:
        torch.tensor: scalar mean loss if ``reduce`` else element-wise loss.
    """
    weight = None
    if label_weight_mapping is not None:
        # build the weights as floats on the prediction's device, so integer
        # weights work and CPU inputs are not sent to CUDA
        weight = torch.tensor(
            [[label_weight_mapping[col][int(value.item())] for col, value in enumerate(row)]
             for row in y_true],
            dtype=torch.float, device=y_pred.device)
    # reduction='none' so that reduce=False really yields per-element losses;
    # with the default 'mean' the result was always already a scalar
    loss = F.binary_cross_entropy(y_pred.float(), y_true.float(), weight=weight, reduction='none')
    if reduce:
        loss = loss.mean()
    return loss
Example #21
0
    def forward(self, embedding: torch.tensor, action: torch.tensor = None):
        """Apply ``self.transform`` to an embedding concatenated with an action.

        Args:
            embedding: batch of embeddings; concatenated with ``action`` along
                the last dimension.
            action: optional batch of actions, cast to float32. When omitted,
                an all-zero action of width ``self.action_dim`` is used for
                every row of ``embedding``.

        Returns:
            the output of ``self.transform``.
        """
        if action is None:
            # create the placeholder on the embedding's device so the
            # concatenation also works when the model runs on a GPU
            action = torch.zeros((embedding.shape[0], self.action_dim),
                                 dtype=torch.float32, device=embedding.device)
        else:
            action = action.float()

        x = torch.cat([embedding, action], dim=-1)
        quality = self.transform(x)

        return quality
def f2_loss(pred: torch.tensor, target: torch.tensor, epsilon: float = 1e-7,
            beta: float = 2) -> torch.tensor:
    """Soft F-beta loss (F2 by default) computed from logits.

    Predictions are ``sigmoid(pred)``. Per row (dimension 1 is summed),
    precision ``P`` and recall ``R`` are computed softly and combined as
    ``F_beta = (1 + beta^2) * P * R / (beta^2 * P + R)``.

    Args:
        pred: logits; sums are taken over dimension 1.
        target: binary targets with the same shape (cast to float).
        epsilon: small constant guarding the divisions.
        beta: recall weight; 2 gives the F2 score.

    Returns:
        ``1 - mean(F_beta)`` as a scalar tensor (F_beta clamped at 0).
    """
    y_pred = torch.sigmoid(pred)
    y_true = target.float()

    tp = (y_pred * y_true).sum(1)
    precision = tp / (y_pred.sum(1) + epsilon)
    recall = tp / (y_true.sum(1) + epsilon)
    beta_sq = beta ** 2
    # precision is *multiplied* by beta^2 in the denominator; adding beta^2 as
    # a constant would keep even a perfect prediction away from a zero loss
    f_beta = (1 + beta_sq) * precision * recall / (beta_sq * precision + recall + epsilon)

    return 1 - f_beta.clamp(min=0).mean()
def intersectionAndUnionGPU(
        preds: torch.tensor,
        target: torch.tensor,
        num_classes: int,
        ignore_index=255) -> Tuple[torch.tensor, torch.tensor, torch.tensor]:
    """
    Per-class intersection, union and target areas for segmentation.

    inputs:
        preds : predicted class indices with 1-3 dimensions (e.g. [H, W])
        target : ground-truth class indices, same shape as ``preds``
        num_classes : Number of classes
        ignore_index : label in ``target`` whose positions are excluded

    returns :
        area_intersection : shape [num_class]
        area_union : shape [num_class]
        area_target : shape [num_class]

    ``preds`` is not modified.
    """
    assert (preds.dim() in [1, 2, 3])
    assert preds.shape == target.shape
    # clone: a flattened view shares storage with the caller's tensor, and the
    # masking below would otherwise overwrite the caller's predictions
    preds = preds.reshape(-1).clone()
    target = target.reshape(-1)
    preds[target == ignore_index] = ignore_index
    intersection = preds[preds == target]

    def _class_histogram(values):
        # .float(): histc does not work with long tensors on CPU. Values
        # outside [0, num_classes - 1] (e.g. ignore_index) are not counted.
        return torch.histc(values.float(),
                           bins=num_classes,
                           min=0,
                           max=num_classes - 1)

    area_intersection = _class_histogram(intersection)
    area_output = _class_histogram(preds)
    area_target = _class_histogram(target)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
Example #24
0
    def _forward_train(self, encodings: torch.tensor,
                       context_mask: torch.tensor, entity_masks: torch.tensor,
                       entity_sizes: torch.tensor):
        """Training pass with feature enhancement, returning entity logits only.

        The BERT output is passed through ``self.feature_enhancer``
        (``prepare_input`` -> call -> ``prepare_output``). Entities are then
        classified with ``self._classify_entities``.

        Returns:
            the entity logits from ``self._classify_entities``.
        """
        # contextualized token embeddings from the last transformer layer
        context_mask = context_mask.float()
        h_bert = self.bert(input_ids=encodings, attention_mask=context_mask)[0]

        # enhance hidden features, restoring the BERT output shape afterwards
        orig_shape = h_bert.shape
        h = self.feature_enhancer.prepare_input(h_bert, context_mask)
        h = self.feature_enhancer(h)
        h = self.feature_enhancer.prepare_output(h, orig_shape)

        entity_masks = entity_masks.float()

        # classify entities; the pooled spans are not needed in this variant
        size_embeddings = self.size_embeddings(entity_sizes)
        entity_clf, _ = self._classify_entities(
            encodings, h, entity_masks, size_embeddings)

        return entity_clf
Example #25
0
def lr_one_cuml(x: torch.tensor, y: torch.tensor, **kwargs):
    """Fit a cuML ``LogisticRegression`` and return its parameters as tensors.

    ``x`` is detached and ``y`` cast to float before fitting. Keyword
    arguments are passed to the ``LogisticRegression`` constructor.

    Returns:
        ``(weight, bias, n_iter)``:
        - ``weight``: the fitted ``coef_``, transposed.
        - ``bias``: ``intercept_``.
        - ``n_iter``: ``solver_model.num_iters``.
        Both tensors are float32 on ``x``'s original device.
    """
    from cuml.linear_model import LogisticRegression

    target_device = x.device
    fitted = LogisticRegression(**kwargs).fit(x.detach(), y.float())
    coef, intercept = fitted.coef_, fitted.intercept_

    weight = torch.as_tensor(coef, device=target_device).float().t()
    bias = torch.as_tensor(intercept, device=target_device).float()
    return weight, bias, fitted.solver_model.num_iters
Example #26
0
    def _forward_train(self, encodings: torch.tensor,
                       context_masks: torch.tensor, entity_masks: torch.tensor,
                       entity_sizes: torch.tensor):  # noqa
        """Training pass returning whatever ``self._classify_entities`` returns.

        The context mask is cast to float for ``self.bert``. Entity-candidate
        sizes are embedded with ``self.size_embeddings`` before
        classification.

        NOTE(review): unlike the sibling variants, ``entity_masks`` is passed
        through without a ``float()`` cast — confirm this is intended.
        """
        hidden = self.bert(input_ids=encodings,
                           attention_mask=context_masks.float())[0]
        sizes_embedded = self.size_embeddings(entity_sizes)
        return self._classify_entities(encodings, hidden, entity_masks,
                                       sizes_embedded)
Example #27
0
def mask_fill(
    fill_value: float,
    tokens: torch.tensor,
    embeddings: torch.tensor,
    padding_index: int,
) -> torch.tensor:
    """
    Function that masks embeddings representing padded elements.
    :param fill_value: the value to fill the embeddings belonging to padded tokens.
    :param tokens: The input sequences [bsz x seq_len].
    :param embeddings: word embeddings [bsz x seq_len x hiddens].
    :param padding_index: Index of the padding token.
    :return: a new tensor with the dtype of ``embeddings``; the input tensor
        itself is left unchanged.
    """
    padding_mask = tokens.eq(padding_index).unsqueeze(-1)
    # out-of-place masked_fill: for float32 input `.float()` returns the very
    # same tensor, so the in-place variant would overwrite the caller's data
    return embeddings.float().masked_fill(padding_mask, fill_value).type_as(embeddings)
Example #28
0
    def forward(self, preds: torch.tensor, labels: torch.tensor):
        """Dice loss between arg-max class predictions and labels.

        ``preds`` is reduced with ``argmax`` over its last dimension. Both
        tensors are then flattened, so one Dice score (additive smoothing of 1)
        is computed over all elements at once. The returned loss is
        ``1 - score / labels.size(0)``.

        NOTE(review): ``argmax`` is not differentiable. The hard predictions
        are turned into a fresh leaf with ``requires_grad = True``, so
        ``backward()`` runs, but no gradient reaches whatever produced
        ``preds``. The score is also divided by ``labels.size(0)`` although it
        is already a single batch-wide value. Confirm both are intended.
        """
        smooth = 1.
        hard_preds = preds.argmax(dim=-1).float()
        hard_preds.requires_grad = True

        flat_preds = hard_preds.view(-1)
        flat_labels = labels.float().view(-1)

        overlap = (flat_preds * flat_labels).sum()
        denominator = flat_preds.sum() + flat_labels.sum() + smooth
        dice = (2. * overlap + smooth) / denominator
        return 1. - dice.sum() / labels.size(0)
Example #29
0
def _input_transform(img: torch.tensor):
    """Convert a BGR image into a normalised, batched RGB tensor.

    Relies on the module-level ``device`` and ``resize`` defined elsewhere.
    Channels are indexed on the last axis (assumed HxWxC — TODO confirm).

    Steps:
    1. Swap channels 0 and 2 into a fresh float32 buffer on ``device``; any
       channels beyond the first three stay zero.
    2. Apply ``resize``.
    3. Move the channel axis first.
    4. Scale by 1/255 and add a leading batch dimension.
    """
    # bgr2rgb: destination channel `dst` takes source channel `src`
    rgb = torch.zeros(img.shape).to(device)
    for dst, src in ((0, 2), (1, 1), (2, 0)):
        rgb[:, :, dst] = rgb[:, :, dst] + img[:, :, src]

    resized = resize(rgb)
    channels_first = resized.permute(2, 0, 1).contiguous()
    normalised = channels_first.float() / 255.0
    return normalised.unsqueeze(0)
    def compute_FB_param(self, features_q: torch.tensor,
                         gt_q: torch.tensor) -> torch.tensor:
        """
        Estimate the per-task class proportions (FB_param) and their error.

        inputs:
            features_q : shape [n_tasks, shot, c, h, w]
            gt_q : shape [n_tasks, shot, h, w]

        updates :
             self.FB_param : shape [n_tasks, num_classes]

        returns :
            deltas : shape [n_tasks]. ``self.FB_param[:, 1] / oracle[:, 1] - 1``,
            the relative error of the class-1 proportion against the oracle
            proportion computed from ``gt_q``.
        """
        # bring the labels to feature resolution; 'nearest' keeps them integral
        ds_gt_q = F.interpolate(gt_q.float(),
                                size=features_q.size()[-2:],
                                mode='nearest').long()
        # pixels labelled 255 are ignored
        valid_pixels = (ds_gt_q != 255).unsqueeze(
            2)  # [n_tasks, shot, 1, h, w]
        # every task needs at least one valid pixel
        assert (valid_pixels.sum(
            dim=(1, 2, 3, 4)) == 0).sum() == 0, valid_pixels.sum(dim=(1, 2, 3,
                                                                      4))

        one_hot_gt_q = to_one_hot(
            ds_gt_q, self.num_classes)  # [n_tasks, shot, num_classes, h, w]

        # oracle proportions: class frequencies among valid pixels
        oracle_FB_param = (valid_pixels * one_hot_gt_q).sum(
            dim=(1, 3, 4)) / valid_pixels.sum(dim=(1, 3, 4))

        if self.FB_param_type == 'oracle':
            self.FB_param = oracle_FB_param
            # Used to assess influence of delta perturbation
            if self.FB_param_noise != 0:
                # clone: perturbing the aliased tensor in place would also
                # change oracle_FB_param, so the deltas below would compare the
                # perturbed value with itself
                perturbed_FB_param = oracle_FB_param.clone()
                perturbed_FB_param[:, 1] += self.FB_param_noise * perturbed_FB_param[:, 1]
                perturbed_FB_param = torch.clamp(perturbed_FB_param, 0, 1)
                perturbed_FB_param[:, 0] = 1.0 - perturbed_FB_param[:, 1]
                self.FB_param = perturbed_FB_param

        else:
            # estimate the proportions from the model's own predictions
            logits_q = self.get_logits(features_q)
            probas = self.get_probas(logits_q).detach()
            self.FB_param = (valid_pixels * probas).sum(dim=(1, 3, 4))
            self.FB_param /= valid_pixels.sum(dim=(1, 3, 4))

        # Compute the relative error
        deltas = self.FB_param[:, 1] / oracle_FB_param[:, 1] - 1
        return deltas