Example #1
 def init_modules(self):
     # Init modules that depend on the vocabulary
     with fork_rng(self.seed):
         # Build each sub-module under its own RNG fork
         with fork_rng(True):
             if not isinstance(self.encoder, torch.nn.Module):
                 self.encoder = get_instance({
                     **self.encoder,
                     "_preprocessor": self.preprocessor,
                 })
         with fork_rng(True):
             if not isinstance(self.decoder, torch.nn.Module):
                 self.decoder = get_instance({
                     **self.decoder,
                     "_preprocessor": self.preprocessor,
                     "_encoder": self.encoder,
                 })
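
All of these examples lean on a seed-aware fork_rng context manager (apparently not torch.random.fork_rng, whose first argument is a device list). Its implementation is not shown here; the sketch below is only a plausible reconstruction consistent with how it is called in the examples (an integer seed forks and reseeds the generators inside the block, True forks without reseeding, None/False disables forking), not the library's actual code.

    import contextlib
    import random

    import numpy as np
    import torch


    @contextlib.contextmanager
    def fork_rng(seed=True):
        # Assumed semantics: save the global RNG states, optionally reseed
        # inside the block, and always restore the saved states on exit so
        # the caller's random stream is unaffected by what happens inside.
        if seed is None or seed is False:
            yield
            return
        states = (random.getstate(), np.random.get_state(), torch.get_rng_state())
        try:
            if seed is not True:
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
            yield
        finally:
            random.setstate(states[0])
            np.random.set_state(states[1])
            torch.set_rng_state(states[2])
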
Example #2
 def __init__(self, dim, max_len=512, temperature=10000.0, mode="sin", seed=None):
     super().__init__()
     self.dim = dim
     if mode.endswith("-proj"):
         self.proj = torch.nn.Linear(dim, dim)
         mode = mode[:-5]
     elif mode.endswith("-scale1d-init0"):
         self.proj = Scaler(dim, 0)
         mode = mode[:-14]
     elif mode.endswith("-scale1d-init0-affine"):
         self.proj = Scaler(dim, 0, affine=True)
         mode = mode[:-21]
     elif mode.endswith("-scale1d-init1"):
         self.proj = Scaler(dim, 1)
         mode = mode[:-14]
     elif mode.endswith("-scale1d-init1-affine"):
         self.proj = Scaler(dim, 1, affine=True)
         mode = mode[:-21]
     elif mode.endswith("-scale0d-init0"):
         self.proj = Scaler(1, 0)
         mode = mode[:-14]
     elif mode.endswith("-scale0d-init0-affine"):
         self.proj = Scaler(1, 0, affine=True)
         mode = mode[:-21]
     elif mode.endswith("-scale0d-init1"):
         self.proj = Scaler(1, 1)
         mode = mode[:-14]
     elif mode.endswith("-scale0d-init1-affine"):
         self.proj = Scaler(1, 1, affine=True)
         mode = mode[:-21]
     self.mode = mode
     if mode == "sin" or mode == "sym-sin" or mode == "inv-sin" or mode == "shift-sin":
         pe = torch.zeros(max_len, dim)
         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
         div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(temperature) / dim))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         self.register_buffer('pe', pe)
     elif mode == "learned":
         if seed is not None:
             with fork_rng(seed):
                 self.pe = torch.nn.Embedding(max_len, dim).weight
         else:
             self.pe = torch.nn.Embedding(max_len, dim).weight
     elif mode == "random":
         self.pe = None
     elif mode == "zeros":
         self.pe = None
     else:
         raise ValueError(f"Unknown positional embedding mode: {mode!r}")
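
The long elif chain above just splits an optional projection/scaling suffix off the mode string before the base mode ("sin", "learned", ...) is interpreted. Below is a small standalone sketch of that decomposition (the class itself is not named in the example, so nothing is instantiated here):

    # Pure string logic mirroring the constructor above: strip a known
    # suffix, remember which projection it selects, keep the base mode.
    SUFFIXES = {
        "-proj": "torch.nn.Linear(dim, dim)",
        "-scale1d-init0": "Scaler(dim, 0)",
        "-scale1d-init0-affine": "Scaler(dim, 0, affine=True)",
        "-scale1d-init1": "Scaler(dim, 1)",
        "-scale1d-init1-affine": "Scaler(dim, 1, affine=True)",
        "-scale0d-init0": "Scaler(1, 0)",
        "-scale0d-init0-affine": "Scaler(1, 0, affine=True)",
        "-scale0d-init1": "Scaler(1, 1)",
        "-scale0d-init1-affine": "Scaler(1, 1, affine=True)",
    }

    def split_mode(mode):
        for suffix, proj in SUFFIXES.items():
            if mode.endswith(suffix):
                return mode[:-len(suffix)], proj
        return mode, None

    print(split_mode("sin-proj"))                      # ('sin', 'torch.nn.Linear(dim, dim)')
    print(split_mode("learned-scale0d-init1-affine"))  # ('learned', 'Scaler(1, 1, affine=True)')
    print(split_mode("zeros"))                         # ('zeros', None)
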
Example #3
 def fn():
     if getattr(self, 'test_data', None) is None or len(
             self.test_data) == 0:
         return None
     with fork_rng(self.data_seed):
         prep = self.preprocess(self.test_data, split="test")
         batch_size = self.batch_size if self.batch_size != "doc" else 1
         if hasattr(prep, '__getitem__'):
             return torch.utils.data.DataLoader(prep,
                                                shuffle=False,
                                                batch_size=batch_size,
                                                collate_fn=identity)
         else:
             return torch.utils.data.DataLoader(DummyIterableDataset(
                 prep, None),
                                                shuffle=False,
                                                batch_size=batch_size,
                                                collate_fn=identity)
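
When the preprocessed data is not indexable, Examples #3 and #4 wrap it in DummyIterableDataset before handing it to the DataLoader. The class is not shown here; a minimal sketch of behaviour consistent with how it is called (wrap any iterable, optionally cap each epoch at epoch_length items) might look like the following — an assumption, not the library's definition:

    import itertools

    import torch


    class DummyIterableDataset(torch.utils.data.IterableDataset):
        # Assumed behaviour: expose an arbitrary iterable/generator as an
        # IterableDataset; when epoch_length is given, stop each epoch after
        # that many items so the trainer sees fixed-size epochs.
        def __init__(self, source, epoch_length=None):
            super().__init__()
            self.source = source
            self.epoch_length = epoch_length

        def __iter__(self):
            if self.epoch_length is None:
                return iter(self.source)
            return itertools.islice(iter(self.source), self.epoch_length)
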
Example #4
        def fn():
            if getattr(self, 'train_data',
                       None) is None or self._is_resuming_finished_model:
                return [None]
            with fork_rng(self.data_seed):
                batch_size = self.batch_size if self.batch_size != "doc" else 1
                non_default_epoch_length = (
                    self.trainer.val_check_interval * batch_size if
                    (getattr(self, 'trainer', None) is not None
                     and self.trainer.val_check_interval is not None
                     and self.trainer.max_steps is not None) else None)
                if hasattr(self.train_data,
                           '__getitem__') and non_default_epoch_length is None:
                    prep = self.preprocess(self.train_data, split="train")
                    return torch.utils.data.DataLoader(prep,
                                                       shuffle=True,
                                                       batch_size=batch_size,
                                                       collate_fn=identity)
                elif non_default_epoch_length is not None and hasattr(
                        self.train_data, '__len__'):
                    if self.dynamic_preprocessing is True:
                        prep = self.preprocess(loop(self.train_data,
                                                    shuffle=True),
                                               split="train")
                    else:
                        prep = loop(self.preprocess(self.train_data,
                                                    split="train"),
                                    shuffle=True)

                    return torch.utils.data.DataLoader(DummyIterableDataset(
                        prep, epoch_length=non_default_epoch_length),
                                                       shuffle=False,
                                                       batch_size=batch_size,
                                                       collate_fn=identity)
                else:
                    prep = self.preprocess(self.train_data, split="train")
                    return torch.utils.data.DataLoader(DummyIterableDataset(
                        prep, epoch_length=non_default_epoch_length),
                                                       shuffle=False,
                                                       batch_size=batch_size,
                                                       collate_fn=identity)
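
Example #4 also relies on a loop helper that turns a finite dataset into an endless, reshuffled stream, so that val_check_interval * batch_size items can define an artificial epoch. A minimal sketch under that assumption (the real helper's signature may differ):

    import random


    def loop(data, shuffle=False):
        # Assumed behaviour: yield the items of `data` forever, reshuffling
        # their order at the start of every pass when shuffle=True.
        items = list(data)
        while True:
            if shuffle:
                random.shuffle(items)
            yield from items
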
Example #5
    def __init__(self,
                 input_size,
                 hidden_size,
                 n_labels,
                 do_biaffine=True,
                 do_tagging=True,
                 do_length=True,
                 multi_label=True,
                 threshold=0.5,
                 max_length=100,
                 max_fragments_count=100,
                 detach_span_tag_logits=True,
                 tag_loss_weight=0.2,
                 biaffine_loss_weight=0.2,
                 dropout_p=0.2,
                 mode="seq",
                 allow_overlap=True,
                 learnable_transitions=False,
                 marginal_tagger_loss=False,
                 eps=1e-8):
        super().__init__()

        assert do_biaffine or do_tagging

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_labels = n_labels

        if not multi_label:
            effective_n_labels = n_labels + 1
        else:
            effective_n_labels = n_labels

        self.do_biaffine = do_biaffine
        self.do_tagging = do_tagging
        self.do_length = do_length
        self.multi_label = multi_label

        self.tag_loss_weight = tag_loss_weight
        self.biaffine_loss_weight = biaffine_loss_weight
        self.detach_span_tag_logits = detach_span_tag_logits

        self.threshold = threshold
        self.max_length = max_length
        self.max_fragments_count = max_fragments_count
        self.dropout = torch.nn.Dropout(dropout_p)

        with fork_rng():
            if do_tagging:
                if self.do_tagging is True:
                    self.do_tagging = "full"
                if self.do_tagging.startswith("full"):
                    n_tags = 5
                    self.register_buffer(
                        'tag_combinator',
                        torch.tensor([
                            [1, 0, 0, 0, 0],
                            [0, 1, 0, 0, 0],
                            [0, 0, 1, 0, 0],
                            [0, 0, 0, 1, 0],
                            [0, 0, 0, 0, 1],
                        ]).float())
                elif self.do_tagging.startswith("positive"):
                    n_tags = 4
                    self.register_buffer(
                        'tag_combinator',
                        torch.tensor([
                            [-1, 1, 0, 0, 0],
                            [-1, 0, 1, 0, 0],
                            [-1, 0, 0, 1, 0],
                            [-1, 0, 0, 0, 1],
                        ]).float())
                elif self.do_tagging.startswith("shared_label_unary"):
                    n_tags = 3
                    self.register_buffer(
                        'tag_combinator',
                        torch.tensor([
                            [-1, 1, 1, 1, 1],
                            [-1, -1, 1, 0, 1],
                            [-1, -1, 0, 1, 1],
                        ]).float())
                elif self.do_tagging.startswith("shared_label"):
                    n_tags = 4
                    self.register_buffer(
                        'tag_combinator',
                        torch.tensor([
                            [-1, 1, 1, 1, 1],
                            [-1, 0, 1, 0, 0],
                            [-1, 0, 0, 1, 0],
                            [-1, 0, 0, 0, 1],
                        ]).float())

                if self.do_tagging.endswith(":ffn"):
                    self.tag_proj = TagFFN(input_size,
                                           n_labels,
                                           n_tags=n_tags,
                                           dropout_p=dropout_p)
                else:
                    self.tag_proj = torch.nn.Linear(input_size,
                                                    n_labels * n_tags)

                self.crf = BIOULDecoder(
                    num_labels=1,
                    with_start_end_transitions=True,
                    allow_overlap=allow_overlap,
                    learnable_transitions=learnable_transitions,
                )
                self.tagger_loss = TaggerLoss(marginal=marginal_tagger_loss,
                                              _crf=self.crf)

        self.eps = eps

        with fork_rng():
            if do_biaffine:
                self.length_proj = torch.nn.Linear(
                    hidden_size, max_length) if self.do_length else None
                self.begin_proj = torch.nn.Linear(input_size,
                                                  hidden_size *
                                                  effective_n_labels,
                                                  bias=True)
                self.end_proj = torch.nn.Linear(input_size,
                                                hidden_size *
                                                effective_n_labels,
                                                bias=True)
                self.biaffine_bias = torch.nn.Parameter(torch.zeros(()))
                self.biaffine_loss = BiaffineLoss(multi_label=self.multi_label)

        self.mode = mode
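
Each tag_combinator buffer registered above has shape (n_tags, 5). The forward pass is not part of the example, so the following is only an assumption about how such a buffer could be used: mapping the n_tags logits that tag_proj produces per label onto a fixed 5-tag space with a single matrix product.

    # Shape-only illustration with random numbers; the names and sizes here
    # are hypothetical and not taken from the library.
    import torch

    batch, tokens, n_labels, n_tags = 2, 7, 3, 4           # e.g. the "positive" scheme
    tag_logits = torch.randn(batch, tokens, n_labels, n_tags)
    tag_combinator = torch.randn(n_tags, 5)                 # stands in for the buffer
    full_tag_logits = tag_logits @ tag_combinator           # -> (2, 7, 3, 5)
    print(full_tag_logits.shape)
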
Example #6
    def __init__(self,
                 _bert=None,
                 bert_config=None,
                 path=None,
                 n_layers=4,
                 combine_mode="softmax",
                 bert_dropout_p=None,
                 output_lm_embeds=False,
                 token_dropout_p=0.,
                 dropout_p=0.,
                 word_pooler={"module": "pooler", "mode": "mean"},
                 proj_size=None,
                 freeze_n_layers=-1,
                 do_norm=True,
                 do_cache=False,
                 _preprocessor=None, ):
        super().__init__()
        assert not ("scaled" in combine_mode and do_norm)
        if do_cache:
            assert freeze_n_layers == -1, "Must freeze bert to enable caching: set freeze_n_layers=-1"

        with fork_rng(True):
            if output_lm_embeds:
                self.bert = _bert if _bert is not None else transformers.AutoModelForMaskedLM.from_pretrained(path, config=bert_config)
                if hasattr(self.bert, 'lm_head'):
                    self.bert.lm_head.__class__ = LM_HEAD_CLS_MAPPING[self.bert.lm_head.__class__]
                else:
                    self.bert.cls.predictions.__class__ = LM_HEAD_CLS_MAPPING[self.bert.cls.predictions.__class__]
            else:
                self.bert = _bert if _bert is not None else transformers.AutoModel.from_pretrained(path, config=bert_config)
        self.output_lm_embeds = output_lm_embeds
        self.n_layers = n_layers
        if n_layers > 1:
            with fork_rng(True):
                if "softmax" in combine_mode:
                    self.weight = torch.nn.Parameter(torch.zeros(n_layers))
                elif combine_mode == "linear":
                    self.weight = torch.nn.Parameter(torch.ones(n_layers) / n_layers)
                else:
                    self.weight = None
        with fork_rng(True):
            self.word_pooler = Pooler(**word_pooler) if word_pooler is not None else None
        if "scaled" in combine_mode:
            self.bert_scaling = torch.nn.Parameter(torch.ones(()))
        self.combine_mode = combine_mode

        # Unwrap the base encoder when self.bert is a masked-LM wrapper
        if hasattr(self.bert, 'bert'):
            bert_model = self.bert.bert
        elif hasattr(self.bert, 'roberta'):
            bert_model = self.bert.roberta
        else:
            bert_model = self.bert
        bert_output_size = bert_model.embeddings.word_embeddings.weight.shape[1] * (
            1 if combine_mode != "concat" else n_layers)
        self.bert_output_size = bert_output_size
        if proj_size is not None:
            self.proj = torch.nn.Linear(bert_output_size, proj_size)
            self._output_size = proj_size
        else:
            self.proj = None
            self._output_size = bert_output_size
        self.norm = torch.nn.LayerNorm(self._output_size) if do_norm else Identity()

        if freeze_n_layers < 0:
            freeze_n_layers = len(bert_model.encoder.layer) + 2 + freeze_n_layers
        for module in (bert_model.embeddings, *bert_model.encoder.layer)[:freeze_n_layers]:
            for param in module.parameters():
                param.requires_grad = False
        if bert_dropout_p is not None:
            for module in bert_model.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.p = bert_dropout_p
        self.dropout = torch.nn.Dropout(dropout_p)
        self.token_dropout_p = token_dropout_p
        self.cache = {}
        self.do_cache = do_cache
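
The freeze_n_layers handling near the end follows Python's negative-index convention over the module tuple (embeddings, layer_0, ..., layer_{N-1}): with a 12-layer encoder, the default freeze_n_layers=-1 resolves to 12 + 2 - 1 = 13, so the embeddings and all twelve transformer layers are frozen. A quick arithmetic check (no transformers model is instantiated):

    # Reproduces just the freeze_n_layers arithmetic from the constructor.
    n_encoder_layers = 12                      # e.g. bert-base
    modules = ["embeddings"] + [f"layer_{i}" for i in range(n_encoder_layers)]

    for requested in (-1, -3, 0, 4):
        freeze_n_layers = requested
        if freeze_n_layers < 0:
            freeze_n_layers = n_encoder_layers + 2 + freeze_n_layers
        print(requested, "->", len(modules[:freeze_n_layers]), "modules frozen")
    # -1 -> 13, -3 -> 11, 0 -> 0, 4 -> 4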