Beispiel #1
0
 def flatten(
     self,
     sample_list: Dict[str, Tensor],
     to_be_flattened: List[str],
     to_be_flattened_dim: List[str],
 ) -> Dict[str, Tensor]:
     """Flatten selected entries of ``sample_list`` in place and return it.

     Each key in ``to_be_flattened`` is overwritten with
     ``transform_to_batch_sequence`` applied to its current value. Then each
     key in ``to_be_flattened_dim`` is overwritten with
     ``transform_to_batch_sequence_dim`` applied to its current value. A key
     listed in both goes through both transforms, in that order.

     Every listed key must already be present. Values are read with
     ``sample_list[name]``, so a missing key raises ``KeyError``; it is NOT
     defaulted to ``None``, and the ``Dict[str, Tensor]`` annotation leaves no
     room for ``None`` values.

     Returns:
         The same dict object that was passed in.
     """
     # NOTE(review): the plain loops and direct indexing look deliberate
     # (e.g. to stay TorchScript-scriptable). Confirm before restyling.
     for name in to_be_flattened:
         flat = transform_to_batch_sequence(sample_list[name])
         sample_list[name] = flat
     for name in to_be_flattened_dim:
         flat_dim = transform_to_batch_sequence_dim(sample_list[name])
         sample_list[name] = flat_dim
     return sample_list
Beispiel #2
0
    def flatten(self,
                sample_list,
                to_be_flattened=None,
                to_be_flattened_dim=None):
        """Flatten selected fields and build the joint attention mask.

        ``sample_list`` must support both item access (``sample_list[key]``)
        and attribute access (``sample_list.input_mask``). It is modified in
        place and also returned.

        Steps:
          1. For each key in ``to_be_flattened`` (resp.
             ``to_be_flattened_dim``), read the value with
             ``getattr(sample_list, key, None)``, so an absent key is treated
             as ``None``. Store the result of ``transform_to_batch_sequence``
             (resp. ``transform_to_batch_sequence_dim``) under that key.
          2. If ``visual_embeddings_type`` is ``None`` and ``image_mask`` is
             set, default ``visual_embeddings_type`` to an all-zeros
             ``torch.long`` tensor shaped like ``image_mask``.
          3. If ``image_mask`` is set, ``attention_mask`` is ``input_mask``
             and ``image_mask`` concatenated along the last dimension, and a
             non-``None`` ``masked_lm_labels`` is right-padded with ``-1`` to
             the shape of ``attention_mask``. If ``image_mask`` is not set,
             ``attention_mask`` is ``input_mask`` itself.

        Args:
            sample_list: batch container, updated in place.
            to_be_flattened: keys passed through
                ``transform_to_batch_sequence``; ``None`` means no keys.
            to_be_flattened_dim: keys passed through
                ``transform_to_batch_sequence_dim``; ``None`` means no keys.

        Returns:
            The same ``sample_list`` object, with ``attention_mask`` set.

        Raises:
            ValueError: if ``masked_lm_labels`` needs padding but is not 2-D,
                or its last dimension differs from that of ``input_mask``.
        """
        if to_be_flattened is None:
            to_be_flattened = ()
        if to_be_flattened_dim is None:
            to_be_flattened_dim = ()
        for key in to_be_flattened:
            # getattr with a default: an absent key becomes None, not an error.
            sample_list[key] = transform_to_batch_sequence(
                getattr(sample_list, key, None))
        for key in to_be_flattened_dim:
            sample_list[key] = transform_to_batch_sequence_dim(
                getattr(sample_list, key, None))

        if (sample_list.visual_embeddings_type is None
                and sample_list.image_mask is not None):
            sample_list.visual_embeddings_type = torch.zeros_like(
                sample_list.image_mask, dtype=torch.long)

        image_mask = sample_list.image_mask
        if image_mask is None:
            attention_mask = sample_list.input_mask
        else:
            attention_mask = torch.cat(
                (sample_list.input_mask, image_mask), dim=-1)
            labels = sample_list.masked_lm_labels
            if labels is not None:
                # Explicit checks instead of `assert`, which `python -O` strips;
                # a bad shape would otherwise be silently misaligned below.
                if labels.dim() != 2:
                    raise ValueError(
                        "masked_lm_labels must be 2-D, got shape {}".format(
                            tuple(labels.shape)))
                if labels.size(-1) != sample_list.input_mask.size(-1):
                    raise ValueError(
                        "masked_lm_labels last dim ({}) must match input_mask "
                        "last dim ({})".format(
                            labels.size(-1), sample_list.input_mask.size(-1)))
                # Positions past the labels (e.g. the image_mask columns) get -1.
                # NOTE(review): -1 is presumably the loss ignore index; confirm.
                # The labels' own dtype is kept; inheriting the mask's dtype
                # would turn them into floats whenever the masks are float.
                new_lm_labels = torch.full_like(
                    attention_mask, -1, dtype=labels.dtype)
                n_rows, n_cols = labels.size()
                new_lm_labels[:n_rows, :n_cols] = labels
                sample_list.masked_lm_labels = new_lm_labels

        sample_list.attention_mask = attention_mask

        return sample_list