def init_modules(self):
    # Init modules that depend on the vocabulary
    with fork_rng(self.seed):
        with fork_rng(True):
            self.encoder = get_instance({
                **self.encoder,
                "_preprocessor": self.preprocessor,
            }) if not isinstance(self.encoder, torch.nn.Module) else self.encoder
        with fork_rng(True):
            self.decoder = get_instance({
                **self.decoder,
                "_preprocessor": self.preprocessor,
                "_encoder": self.encoder,
            }) if not isinstance(self.decoder, torch.nn.Module) else self.decoder
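# --- Illustrative usage note (assumption, not taken from the repository) ----
# init_modules() above accepts `encoder` / `decoder` either as ready
# torch.nn.Module instances or as plain config dicts, which get_instance()
# materialises lazily once the preprocessor (and thus the vocabulary) exists.
# A config-style call might look roughly like the sketch below; the class name
# and the exact keys ("module", "path", ...) are hypothetical here:
#
#   model = SomeModel(
#       encoder={"module": "bert", "path": "camembert-base"},
#       decoder={"module": "span_scorer", "hidden_size": 128},
#   )
#   model.init_modules()  # the dicts above are only now turned into real modules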
def __init__(self, dim, max_len=512, temperature=10000.0, mode="sin", seed=None):
    super().__init__()
    self.dim = dim

    # Optional projection / scaling applied on top of the encoding, selected by a suffix of `mode`
    if mode.endswith("-proj"):
        self.proj = torch.nn.Linear(dim, dim)
        mode = mode[:-len("-proj")]
    elif mode.endswith("-scale1d-init0"):
        self.proj = Scaler(dim, 0)
        mode = mode[:-len("-scale1d-init0")]
    elif mode.endswith("-scale1d-init0-affine"):
        self.proj = Scaler(dim, 0, affine=True)
        mode = mode[:-len("-scale1d-init0-affine")]
    elif mode.endswith("-scale1d-init1"):
        self.proj = Scaler(dim, 1)
        mode = mode[:-len("-scale1d-init1")]
    elif mode.endswith("-scale1d-init1-affine"):
        self.proj = Scaler(dim, 1, affine=True)
        mode = mode[:-len("-scale1d-init1-affine")]
    elif mode.endswith("-scale0d-init0"):
        self.proj = Scaler(1, 0)
        mode = mode[:-len("-scale0d-init0")]
    elif mode.endswith("-scale0d-init0-affine"):
        self.proj = Scaler(1, 0, affine=True)
        mode = mode[:-len("-scale0d-init0-affine")]
    elif mode.endswith("-scale0d-init1"):
        self.proj = Scaler(1, 1)
        mode = mode[:-len("-scale0d-init1")]
    elif mode.endswith("-scale0d-init1-affine"):
        self.proj = Scaler(1, 1, affine=True)
        mode = mode[:-len("-scale0d-init1-affine")]
    self.mode = mode

    if mode in ("sin", "sym-sin", "inv-sin", "shift-sin"):
        # Standard sinusoidal table: sin on even dimensions, cos on odd dimensions
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(temperature) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    elif mode == "learned":
        if seed is not None:
            with fork_rng(seed):
                self.pe = torch.nn.Embedding(max_len, dim).weight
        else:
            self.pe = torch.nn.Embedding(max_len, dim).weight
    elif mode == "random":
        self.pe = None
    elif mode == "zeros":
        self.pe = None
    else:
        raise ValueError(f"Unknown positional encoding mode {mode!r}")
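# --- Illustrative sketch (not part of the original module) ------------------
# Standalone reproduction of the "sin" table built above, handy for checking
# shapes and values in isolation; the function name is illustrative only.
import math

import torch


def sinusoidal_table(max_len=512, dim=64, temperature=10000.0):
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)                        # (max_len, 1)
    div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(temperature) / dim))     # (dim // 2,)
    pe = torch.zeros(max_len, dim)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sin(pos * temperature^(-2i/dim))
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dims:  cos(pos * temperature^(-2i/dim))
    return pe                                     # shape (max_len, dim)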
def fn():
    if getattr(self, 'test_data', None) is None or len(self.test_data) == 0:
        return None
    with fork_rng(self.data_seed):
        prep = self.preprocess(self.test_data, split="test")
        batch_size = self.batch_size if self.batch_size != "doc" else 1
        if hasattr(prep, '__getitem__'):
            return torch.utils.data.DataLoader(prep, shuffle=False, batch_size=batch_size, collate_fn=identity)
        else:
            return torch.utils.data.DataLoader(DummyIterableDataset(prep, None),
                                               shuffle=False, batch_size=batch_size, collate_fn=identity)
def fn():
    if getattr(self, 'train_data', None) is None or self._is_resuming_finished_model:
        return [None]
    with fork_rng(self.data_seed):
        batch_size = self.batch_size if self.batch_size != "doc" else 1
        # When the trainer validates every `val_check_interval` steps (and `max_steps` is set),
        # size a virtual epoch as that many batches worth of samples; otherwise keep the default.
        non_default_epoch_length = (
            self.trainer.val_check_interval * batch_size
            if (getattr(self, 'trainer', None) is not None
                and self.trainer.val_check_interval is not None
                and self.trainer.max_steps is not None)
            else None
        )
        if hasattr(self.train_data, '__getitem__') and non_default_epoch_length is None:
            # Map-style dataset with the default epoch length: let the DataLoader shuffle it
            prep = self.preprocess(self.train_data, split="train")
            return torch.utils.data.DataLoader(prep, shuffle=True, batch_size=batch_size, collate_fn=identity)
        elif non_default_epoch_length is not None and hasattr(self.train_data, '__len__'):
            if self.dynamic_preprocessing is True:
                prep = self.preprocess(loop(self.train_data, shuffle=True), split="train")
            else:
                prep = loop(self.preprocess(self.train_data, split="train"), shuffle=True)
            return torch.utils.data.DataLoader(DummyIterableDataset(prep, epoch_length=non_default_epoch_length),
                                               shuffle=False, batch_size=batch_size, collate_fn=identity)
        else:
            prep = self.preprocess(self.train_data, split="train")
            return torch.utils.data.DataLoader(DummyIterableDataset(prep, epoch_length=non_default_epoch_length),
                                               shuffle=False, batch_size=batch_size, collate_fn=identity)
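# --- Illustrative sketch (assumption, not the repository's implementation) ---
# The dataloaders above wrap non-indexable pipelines in DummyIterableDataset,
# optionally capping one "epoch" at `epoch_length` samples so step-based
# validation still triggers. Assuming that is all the wrapper does, a minimal
# equivalent could look like this (the class name here is hypothetical):
import itertools

import torch


class IterableWrapperSketch(torch.utils.data.IterableDataset):
    """Hypothetical stand-in for DummyIterableDataset: yields from `iterable`,
    truncated to `epoch_length` items per epoch when given."""

    def __init__(self, iterable, epoch_length=None):
        self.iterable = iterable
        self.epoch_length = epoch_length

    def __iter__(self):
        it = iter(self.iterable)
        return it if self.epoch_length is None else itertools.islice(it, self.epoch_length)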
def __init__(self,
             input_size,
             hidden_size,
             n_labels,
             do_biaffine=True,
             do_tagging=True,
             do_length=True,
             multi_label=True,
             threshold=0.5,
             max_length=100,
             max_fragments_count=100,
             detach_span_tag_logits=True,
             tag_loss_weight=0.2,
             biaffine_loss_weight=0.2,
             dropout_p=0.2,
             mode="seq",
             allow_overlap=True,
             learnable_transitions=False,
             marginal_tagger_loss=False,
             eps=1e-8):
    super().__init__()
    assert do_biaffine or do_tagging

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.n_labels = n_labels
    if not multi_label:
        effective_n_labels = n_labels + 1
    else:
        effective_n_labels = n_labels
    self.do_biaffine = do_biaffine
    self.do_tagging = do_tagging
    self.do_length = do_length
    self.multi_label = multi_label
    self.tag_loss_weight = tag_loss_weight
    self.biaffine_loss_weight = biaffine_loss_weight
    self.detach_span_tag_logits = detach_span_tag_logits
    self.threshold = threshold
    self.max_length = max_length
    self.max_fragments_count = max_fragments_count
    self.dropout = torch.nn.Dropout(dropout_p)

    with fork_rng():
        if do_tagging:
            if self.do_tagging is True:
                self.do_tagging = "full"
            if self.do_tagging.startswith("full"):
                n_tags = 5
                self.register_buffer('tag_combinator', torch.tensor([
                    [1, 0, 0, 0, 0],
                    [0, 1, 0, 0, 0],
                    [0, 0, 1, 0, 0],
                    [0, 0, 0, 1, 0],
                    [0, 0, 0, 0, 1],
                ]).float())
            elif self.do_tagging.startswith("positive"):
                n_tags = 4
                self.register_buffer('tag_combinator', torch.tensor([
                    [-1, 1, 0, 0, 0],
                    [-1, 0, 1, 0, 0],
                    [-1, 0, 0, 1, 0],
                    [-1, 0, 0, 0, 1],
                ]).float())
            elif self.do_tagging.startswith("shared_label_unary"):
                n_tags = 3
                self.register_buffer('tag_combinator', torch.tensor([
                    [-1, 1, 1, 1, 1],
                    [-1, -1, 1, 0, 1],
                    [-1, -1, 0, 1, 1],
                ]).float())
            elif self.do_tagging.startswith("shared_label"):
                n_tags = 4
                self.register_buffer('tag_combinator', torch.tensor([
                    [-1, 1, 1, 1, 1],
                    [-1, 0, 1, 0, 0],
                    [-1, 0, 0, 1, 0],
                    [-1, 0, 0, 0, 1],
                ]).float())
            if self.do_tagging.endswith(":ffn"):
                self.tag_proj = TagFFN(input_size, n_labels, n_tags=n_tags, dropout_p=dropout_p)
            else:
                self.tag_proj = torch.nn.Linear(input_size, n_labels * n_tags)
            self.crf = BIOULDecoder(
                num_labels=1,
                with_start_end_transitions=True,
                allow_overlap=allow_overlap,
                learnable_transitions=learnable_transitions,
            )
            self.tagger_loss = TaggerLoss(marginal=marginal_tagger_loss, _crf=self.crf)

    self.eps = eps

    with fork_rng():
        if do_biaffine:
            self.length_proj = torch.nn.Linear(hidden_size, max_length) if self.do_length else None
            self.begin_proj = torch.nn.Linear(input_size, hidden_size * effective_n_labels, bias=True)
            self.end_proj = torch.nn.Linear(input_size, hidden_size * effective_n_labels, bias=True)
            self.biaffine_bias = torch.nn.Parameter(torch.zeros(()))
            self.biaffine_loss = BiaffineLoss(multi_label=self.multi_label)

    self.mode = mode
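# --- Illustrative sketch (hypothetical, not this module's forward pass) -----
# The begin/end projections above produce, per token, one `hidden_size` vector
# per label. A common way such projections are combined into span scores is a
# label-wise scaled dot product between begin and end representations:
import torch


def biaffine_span_scores_sketch(begin, end, bias=0.0):
    # begin, end: (batch, n_tokens, n_labels, hidden)
    # returns:    (batch, n_labels, n_tokens, n_tokens) scores for span (i, j) under each label
    return torch.einsum("bilh,bjlh->blij", begin, end) / begin.shape[-1] ** 0.5 + bias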
def __init__(self,
             _bert=None,
             bert_config=None,
             path=None,
             n_layers=4,
             combine_mode="softmax",
             bert_dropout_p=None,
             output_lm_embeds=False,
             token_dropout_p=0.,
             dropout_p=0.,
             word_pooler={"module": "pooler", "mode": "mean"},
             proj_size=None,
             freeze_n_layers=-1,
             do_norm=True,
             do_cache=False,
             _preprocessor=None):
    super().__init__()
    assert not ("scaled" in combine_mode and do_norm)
    if do_cache:
        assert freeze_n_layers == -1, "Must freeze bert to enable caching: set freeze_n_layers=-1"

    with fork_rng(True):
        if output_lm_embeds:
            self.bert = _bert if _bert is not None else transformers.AutoModelForMaskedLM.from_pretrained(path, config=bert_config)
            if hasattr(self.bert, 'lm_head'):
                self.bert.lm_head.__class__ = LM_HEAD_CLS_MAPPING[self.bert.lm_head.__class__]
            else:
                self.bert.cls.predictions.__class__ = LM_HEAD_CLS_MAPPING[self.bert.cls.predictions.__class__]
        else:
            self.bert = _bert if _bert is not None else transformers.AutoModel.from_pretrained(path, config=bert_config)
    self.output_lm_embeds = output_lm_embeds
    self.n_layers = n_layers
    if n_layers > 1:
        with fork_rng(True):
            # Learned weights used to combine the last `n_layers` hidden states
            self.weight = (
                torch.nn.Parameter(torch.zeros(n_layers)) if "softmax" in combine_mode
                else torch.nn.Parameter(torch.ones(n_layers) / n_layers) if combine_mode == "linear"
                else None
            )

    with fork_rng(True):
        self.word_pooler = Pooler(**word_pooler) if word_pooler is not None else None

    if "scaled" in combine_mode:
        self.bert_scaling = torch.nn.Parameter(torch.ones(()))
    self.combine_mode = combine_mode

    bert_model = self.bert.bert if hasattr(self.bert, 'bert') else self.bert.roberta if hasattr(self.bert, 'roberta') else self.bert
    bert_output_size = bert_model.embeddings.word_embeddings.weight.shape[1] * (1 if combine_mode != "concat" else n_layers)
    self.bert_output_size = bert_output_size
    if proj_size is not None:
        self.proj = torch.nn.Linear(bert_output_size, proj_size)
        self._output_size = proj_size
    else:
        self.proj = None
        self._output_size = bert_output_size
    self.norm = torch.nn.LayerNorm(self._output_size) if do_norm else Identity()

    # Negative values count back from "freeze everything" (embeddings + all encoder layers)
    if freeze_n_layers < 0:
        freeze_n_layers = len(bert_model.encoder.layer) + 2 + freeze_n_layers
    for module in (bert_model.embeddings, *bert_model.encoder.layer)[:freeze_n_layers]:
        for param in module.parameters():
            param.requires_grad = False
    if bert_dropout_p is not None:
        for module in bert_model.modules():
            if isinstance(module, torch.nn.Dropout):
                module.p = bert_dropout_p
    self.dropout = torch.nn.Dropout(dropout_p)
    self.token_dropout_p = token_dropout_p
    self.cache = {}
    self.do_cache = do_cache
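# --- Illustrative sketch (worked example of the freezing arithmetic above) --
# For a 12-layer encoder, the tuple (embeddings, *layers) has 13 entries, so the
# negative-index resolution `len(layers) + 2 + freeze_n_layers` yields:
#   freeze_n_layers = -1  ->  13  -> embeddings + all 12 layers frozen
#   freeze_n_layers = -2  ->  12  -> embeddings + first 11 layers frozen
#   freeze_n_layers =  0  ->   0  -> nothing frozen
def resolved_freeze_count_sketch(freeze_n_layers, n_encoder_layers=12):
    # hypothetical helper name; mirrors the arithmetic used in __init__ above
    if freeze_n_layers < 0:
        freeze_n_layers = n_encoder_layers + 2 + freeze_n_layers
    return freeze_n_layers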