Example #1
0
    def _filter_spans(self, entity_clf, entity_spans, entity_sample_masks,
                      ctx_size):
        """Build relation candidates from spans predicted to be entities.

        For each sample in the batch, every span whose argmax entity type
        (taken over the last dim of ``entity_clf`` and multiplied by
        ``entity_sample_masks``) is non-zero counts as an entity. Every ordered
        pair of two distinct entity spans becomes a relation candidate, with a
        context mask built by ``sampling.create_rel_mask``.

        Args:
            entity_clf: entity classifier output; dim 0 is the batch and the
                argmax is taken over the last dim.
            entity_spans: per-sample spans, indexed in the same order as the
                spans scored in ``entity_clf``.
            entity_sample_masks: per-span mask; spans with a 0/False mask end
                up with type 0 and therefore never take part in a relation.
            ctx_size: length of each relation context mask.

        Returns:
            ``(batch_relations, batch_rel_masks, batch_rel_sample_masks)``,
            each combined over the batch with ``util.padded_stack`` and moved
            to the device of ``self.rel_classifier.weight``. The two mask
            tensors get an extra trailing dimension of size 1.
        """
        from itertools import permutations

        batch_size = entity_clf.shape[0]
        # entity type per span (including the "none" type 0); spans whose
        # sample mask is 0 are multiplied to type 0
        entity_logits_max = entity_clf.argmax(
            dim=-1) * entity_sample_masks.long()
        batch_relations = []
        batch_rel_masks = []
        batch_rel_sample_masks = []

        for i in range(batch_size):
            # spans classified as entities (non-zero type)
            non_zero_indices = (entity_logits_max[i] != 0).nonzero().view(-1)
            non_zero_spans = entity_spans[i][non_zero_indices].tolist()
            non_zero_indices = non_zero_indices.tolist()
            entities = list(zip(non_zero_indices, non_zero_spans))

            rels = []
            rel_masks = []
            # every ordered pair of two distinct entity spans, i.e. both
            # (a, b) and (b, a); nonzero() yields unique indices, so this
            # matches a nested loop skipping i1 == i2
            for (i1, s1), (i2, s2) in permutations(entities, 2):
                rels.append((i1, i2))
                rel_masks.append(
                    sampling.create_rel_mask(s1, s2, ctx_size))

            if not rels:
                # fewer than two entity spans: no pair can be formed, so add
                # a single dummy relation whose sample mask is False
                batch_relations.append(
                    torch.tensor([[0, 0]], dtype=torch.long))
                batch_rel_masks.append(
                    torch.tensor([[0] * ctx_size], dtype=torch.bool))
                batch_rel_sample_masks.append(
                    torch.tensor([0], dtype=torch.bool))
            else:
                # two or more entity spans: n * (n - 1) ordered pairs
                batch_relations.append(torch.tensor(rels, dtype=torch.long))
                batch_rel_masks.append(torch.stack(rel_masks))
                batch_rel_sample_masks.append(
                    torch.ones(len(rels), dtype=torch.bool))

        # stack over the batch and move to the relation classifier's device
        device = self.rel_classifier.weight.device
        batch_relations = util.padded_stack(batch_relations).to(device)
        batch_rel_masks = util.padded_stack(batch_rel_masks).to(
            device).unsqueeze(-1)
        batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to(
            device).unsqueeze(-1)

        return batch_relations, batch_rel_masks, batch_rel_sample_masks
Example #2
0
def collate_fn_padding(batch):
    """Collate a list of sample dicts into one dict of batched values.

    The keys are taken from the first sample. Zero-dimensional (scalar)
    tensors are combined with ``torch.stack``; every other value is combined
    with ``util.padded_stack``.

    Args:
        batch: non-empty list of dicts whose values are tensors and which
            share the keys of ``batch[0]``.

    Returns:
        dict mapping each key to the collated value.
    """
    padded_batch = {}

    for key in batch[0].keys():
        # gather once and reuse for whichever combine function applies
        samples = [sample[key] for sample in batch]

        if not samples[0].shape:
            # scalar tensors (shape ()) can be stacked directly
            padded_batch[key] = torch.stack(samples)
        else:
            padded_batch[key] = util.padded_stack(samples)

    return padded_batch
Example #3
0
def collate_fn_padding(batch):
    """Collate a list of sample dicts into one dict of batched values.

    Every key of ``batch[0]`` except ``'_id'`` is collated. Zero-dimensional
    (scalar) tensors are combined with ``torch.stack``, and anything else
    with ``util.padded_stack``. ``'_id'`` is copied from the first sample
    and added last.

    Args:
        batch: non-empty list of dicts sharing the keys of ``batch[0]``,
            which must include ``'_id'``.

    Returns:
        dict with the collated values plus ``'_id'`` of the first sample.

    Raises:
        KeyError: if the first sample has no ``'_id'`` key.
    """
    padded_batch = {}

    for key in batch[0].keys():
        if key == '_id':
            # handled separately below; no need to gather it
            continue

        samples = [sample[key] for sample in batch]
        first = samples[0]

        if torch.is_tensor(first) and not first.shape:
            # scalar tensors (shape ()) can be stacked directly
            padded_batch[key] = torch.stack(samples)
        else:
            padded_batch[key] = util.padded_stack(samples)

    # NOTE(review): only the first sample's '_id' is kept; presumably this
    # collate is used with batch size 1 - confirm with the caller.
    padded_batch['_id'] = batch[0]['_id']
    return padded_batch
Example #4
0
def _create_eval_batch(samples):
    """Collate evaluation samples into one ``EvalTensorBatch``.

    Each per-sample field is gathered across ``samples`` and combined with
    ``util.padded_stack``. A sample with no entity candidates (zero rows in
    ``entity_masks``) gets a single all-zero placeholder entity whose entity
    sample mask is False. The mask of a real entity is True, so placeholder
    and padding entries can be masked out during evaluation.

    Args:
        samples: objects exposing ``encoding``, ``ctx_mask``,
            ``entity_masks``, ``entity_sizes`` and ``entity_spans``.

    Returns:
        ``EvalTensorBatch`` built from the stacked fields.
    """
    fields = {
        'encodings': [],
        'ctx_masks': [],
        'entity_masks': [],
        'entity_sizes': [],
        'entity_spans': [],
        'entity_sample_masks': [],
    }

    for sample in samples:
        encoding, ctx_mask = sample.encoding, sample.ctx_mask
        masks = sample.entity_masks
        sizes = sample.entity_sizes
        spans = sample.entity_spans

        num_entities = masks.shape[0]
        if num_entities:
            valid = torch.ones(num_entities, dtype=torch.bool)
        else:
            # no entity candidates: substitute one dummy entity flagged invalid
            masks = torch.zeros((1, masks.shape[-1]), dtype=torch.bool)
            sizes = torch.zeros(1, dtype=torch.long)
            spans = torch.zeros((1, 2), dtype=torch.long)
            valid = torch.zeros(1, dtype=torch.bool)

        fields['encodings'].append(encoding)
        fields['ctx_masks'].append(ctx_mask)
        fields['entity_masks'].append(masks)
        fields['entity_sizes'].append(sizes)
        fields['entity_spans'].append(spans)
        fields['entity_sample_masks'].append(valid)

    stacked = {name: util.padded_stack(values)
               for name, values in fields.items()}
    return EvalTensorBatch(**stacked)
Example #5
0
def _create_train_batch(samples):
    """Collate training samples into one ``TrainTensorBatch``.

    Entity and relation fields of every sample are gathered and combined
    with ``util.padded_stack``. A sample without entity candidates (zero rows
    in ``entity_masks``) or without relations (zero rows in ``rel_masks``)
    gets a single all-zero placeholder for that group. The placeholder's
    sample mask is False, while real entries are True, so placeholder and
    padding entries can be masked out during loss computation.

    Args:
        samples: objects exposing ``encoding``, ``ctx_mask``,
            ``entity_masks``, ``entity_sizes``, ``entity_types``, ``rels``,
            ``rel_masks`` and ``rel_types``.

    Returns:
        ``TrainTensorBatch`` built from the stacked fields.
    """
    fields = {
        'encodings': [],
        'ctx_masks': [],
        'entity_masks': [],
        'entity_sizes': [],
        'entity_types': [],
        'entity_sample_masks': [],
        'rels': [],
        'rel_masks': [],
        'rel_types': [],
        'rel_sample_masks': [],
    }

    for sample in samples:
        encoding, ctx_mask = sample.encoding, sample.ctx_mask

        ent_masks = sample.entity_masks
        ent_sizes = sample.entity_sizes
        ent_types = sample.entity_types

        pairs = sample.rels
        pair_masks = sample.rel_masks
        pair_types = sample.rel_types

        num_entities = ent_masks.shape[0]
        num_pairs = pair_masks.shape[0]

        if num_entities:
            ent_valid = torch.ones(num_entities, dtype=torch.bool)
        else:
            # no entity candidates: one dummy entity flagged invalid
            ent_types = torch.zeros(1, dtype=torch.long)
            ent_masks = torch.zeros((1, ent_masks.shape[-1]), dtype=torch.bool)
            ent_sizes = torch.zeros(1, dtype=torch.long)
            ent_valid = torch.zeros(1, dtype=torch.bool)

        if num_pairs:
            pair_valid = torch.ones(num_pairs, dtype=torch.bool)
        else:
            # no relations: one dummy relation flagged invalid
            pairs = torch.zeros((1, 2), dtype=torch.long)
            pair_types = torch.zeros(1, dtype=torch.long)
            pair_masks = torch.zeros((1, pair_masks.shape[-1]),
                                     dtype=torch.bool)
            pair_valid = torch.zeros(1, dtype=torch.bool)

        fields['encodings'].append(encoding)
        fields['ctx_masks'].append(ctx_mask)
        fields['entity_masks'].append(ent_masks)
        fields['entity_sizes'].append(ent_sizes)
        fields['entity_types'].append(ent_types)
        fields['entity_sample_masks'].append(ent_valid)
        fields['rels'].append(pairs)
        fields['rel_masks'].append(pair_masks)
        fields['rel_types'].append(pair_types)
        fields['rel_sample_masks'].append(pair_valid)

    stacked = {name: util.padded_stack(values)
               for name, values in fields.items()}
    return TrainTensorBatch(**stacked)