Example #1
    def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None:
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = cfg.nbest
        self.unitlm = cfg.unitlm

        if cfg.criterion == "ctc":
            self.criterion_type = CriterionType.CTC
            self.blank = (tgt_dict.index("<ctc_blank>") if "<ctc_blank>"
                          in tgt_dict.indices else tgt_dict.bos())
            if "<sep>" in tgt_dict.indices:
                self.silence = tgt_dict.index("<sep>")
            elif "|" in tgt_dict.indices:
                self.silence = tgt_dict.index("|")
            else:
                self.silence = tgt_dict.eos()
            self.asgtransitions = None
        elif cfg.criterion == "asg_loss":
            self.criterion_type = CriterionType.ASG
            self.blank = -1
            self.silence = -1
            self.asgtransitions = cfg.asgtransitions
            self.maxreplabel = cfg.maxreplabel
            assert len(self.asgtransitions) == self.vocab_size**2
        else:
            raise RuntimeError(f"unknown criterion: {cfg.criterion}")
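A minimal stand-alone sketch (the added symbols are hypothetical) of how the blank/silence fallback above resolves against a fairseq Dictionary: the CTC blank falls back to the BOS index when no "<ctc_blank>" symbol exists, and silence falls back from "<sep>" to "|" to EOS:

from fairseq.data.dictionary import Dictionary

d = Dictionary()              # ships with <s>, <pad>, </s>, <unk>
d.add_symbol("|")             # word-separator target, as in letter-based CTC labels
for ch in "abc":
    d.add_symbol(ch)

blank = d.index("<ctc_blank>") if "<ctc_blank>" in d.indices else d.bos()
silence = d.index("|") if "|" in d.indices else d.eos()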
Example #2
def format_ascii(iteration: Iteration,
                 iteration_number: int,
                 token_dictionary: Dictionary,
                 expansion_dictionary: Dictionary,
                 expansion: ExpansionStrategy,
                 no_token: str = '-') -> str:
    tokens = token_dictionary.string(iteration.nlt).split(' ')
    tokens = [
        t if is_new_token else no_token
        for t, is_new_token in zip(tokens, iteration.new_token_mask)
    ]

    expansions = [
        expansion.pretty_format(e)
        for e in expansion_dictionary.string(iteration.nle).split(' ')
    ]
    expansions = [
        e if is_new_token else no_token
        for e, is_new_token in zip(expansions, iteration.new_token_mask)
    ]

    s = f'iteration {iteration_number}\n'
    s += '  PLT: ' + token_dictionary.string(iteration.plt) + '\n'
    s += '  NLT: ' + ' '.join(tokens) + '\n'
    s += '  NLE: ' + ' '.join(expansions)
    return s
Example #3
def record2transition_dict(record: Dict[str, np.ndarray],
                           token_dictionary: Dictionary,
                           expansion_dictionary: Dictionary,
                           null_expansion: str) -> Transition:

    prev_tokens = token_dictionary.string(
        record[KEY_PREV_LEVEL_TOKENS].tolist()).split(' ')

    next_tokens = token_dictionary.string(
        record[KEY_NEXT_LEVEL_TOKENS].tolist()).split(' ')

    loss_mask = [
        0 if t == token_dictionary.pad_word else 1 for t in next_tokens
    ]

    next_tokens = unmask_tokens(next_tokens, prev_tokens, loss_mask)

    next_expans = expansion_dictionary.string(
        record[KEY_NEXT_LEVEL_EXPANS].tolist()).split(' ')
    next_expans = unmask_expansions(next_expans, loss_mask, null_expansion)

    head_positions = record[KEY_HEAD_POSITIONS].tolist()

    return Transition(
        previous_level_tokens=prev_tokens,
        loss_mask=loss_mask,
        next_level_tokens=next_tokens,
        next_level_expansions=next_expans,
        heads=head_positions,
    )
Example #4
    def __init__(self,
                 args,
                 token_dictionary: Dictionary,
                 expansion_dictionary: Dictionary,
                 expansion_strategy: ExpansionStrategy,
                 device: Optional[Union[torch.device, str]] = None,
                 regenerate_tokens: bool = False,
                 temperature: float = 1.0):
        self.temperature = temperature

        self.inference = IterativeInference(token_dictionary,
                                            expansion_dictionary,
                                            expansion_strategy,
                                            device,
                                            mask_unk=True)
        self.dependency_placeholder_ids = {
            token_dictionary.index(t)
            for t in expansion_strategy.get_dependency_placeholders()
        }

        def expand(e):
            left_deps, right_deps = expansion_strategy.expand_deps(e)
            left_dep_idxs = [token_dictionary.index(t) for t in left_deps]
            right_dep_idxs = [token_dictionary.index(t) for t in right_deps]
            return left_dep_idxs, right_dep_idxs

        self.expansions = {
            expansion_dictionary.index(e): expand(e)
            for e in expansion_dictionary.symbols
        }
        self.regenerate_tokens = regenerate_tokens
Example #5
    def __init__(self, bpe, dictionary: Dictionary):
        self.bpe = bpe
        self.vocab = Vocabulary(
            dictionary.symbols,
            pad_token=str(dictionary[dictionary.pad()]),
            bos_token=str(dictionary[dictionary.bos()]),
            eos_token=str(dictionary[dictionary.eos()]),
        )
        self.bos = self.vocab.bos_token
        self.eos = self.vocab.eos_token
Example #6
    def build(self,
              filepath=None,
              vocab_path=None,
              threshold=-1,
              max_vocab=-1):
        if vocab_path and os.path.exists(vocab_path):
            print("loading vocab from {}".format(vocab_path))
            d = Dictionary.load(vocab_path)
            print('vocab size {}'.format(len(d)))
        else:
            print("building vocab...")
            d = Dictionary()
            for step, line in enumerate(sentence_iterator(filepath)):
                if not step % 1000:
                    print("working on {}kth line".format(step // 1000),
                          end='\r')
                tokens = [self.get_lemma(w) for w in line]
                for tok in tokens:
                    d.add_symbol(tok)
            d.finalize(threshold=threshold, nwords=max_vocab)
            print('build done. vocab size {}'.format(len(d)))
            d.save('{}/dict.txt'.format(self.data_dir))

        self.vocab = d
        self.unk = self.vocab.unk()
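For reference, a small stand-alone sketch (the tokens and threshold are made up) of the build/finalize/save round trip the method above relies on; Dictionary.finalize sorts symbols by frequency and drops those seen fewer than threshold times, and load restores the saved symbol/count pairs:

from fairseq.data.dictionary import Dictionary

d = Dictionary()
for tok in ["the", "the", "the", "cat", "sat", "sat"]:
    d.add_symbol(tok)
d.finalize(threshold=2, nwords=-1)   # "cat" (count 1) is dropped and now maps to <unk>
d.save("dict.txt")                   # one "symbol count" pair per line
reloaded = Dictionary.load("dict.txt")
assert reloaded.index("the") == d.index("the")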
Example #7
    def setUp(self):
        self.vocab = Dictionary()
        self.vocab.add_token_to_namespace("sentence", namespace='words')
        self.vocab.add_token_to_namespace("A", namespace='words')
        self.vocab.add_token_to_namespace("A", namespace='characters')
        self.vocab.add_token_to_namespace("s", namespace='characters')
        self.vocab.add_token_to_namespace("e", namespace='characters')
        self.vocab.add_token_to_namespace("n", namespace='characters')
        self.vocab.add_token_to_namespace("t", namespace='characters')
        self.vocab.add_token_to_namespace("c", namespace='characters')
        super(TestTextField, self).setUp()
Example #8
    def __init__(self, tgt_dict: Dictionary) -> None:
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)

        self.blank = (tgt_dict.index("<ctc_blank>")
                      if "<ctc_blank>" in tgt_dict.indices else tgt_dict.bos())
        if "<sep>" in tgt_dict.indices:
            self.silence = tgt_dict.index("<sep>")
        elif "|" in tgt_dict.indices:
            self.silence = tgt_dict.index("|")
        else:
            self.silence = tgt_dict.eos()
Example #9
    def rebuild_vocab(self):
        self._vocab = Dictionary()
        self._vocab.add_symbol(self.mask_builder.mask_token)
        desc = 'build-vocab: {}'.format(self.save_path)
        pbar = tqdm(range(len(self.dataset)), desc=desc, leave=True)

        for i in pbar:
            contents = self.dataset[i]
            tokens = self.tokenizer(contents)
            for token in tokens:
                self._vocab.add_symbol(token)

        if self.save_path is not None:
            self._vocab.save(self.vocab_path)
Example #10
def build_fairseq_vocab(
    vocab_file: str,
    dictionary_class: Dictionary = Dictionary,
    special_token_replacements: Dict[str, SpecialToken] = None,
    max_vocab: int = -1,
    min_count: int = -1,
    tokens_to_add: Optional[List[str]] = None,
):
    """
    Function builds a PyText vocabulary for models pre-trained using Fairseq
    modules. The dictionary class can take any Fairseq Dictionary class
    and is used to load the vocab file.
    """
    if not special_token_replacements:
        special_token_replacements = {
            "<pad>": SpecialTokens.PAD,
            "<s>": SpecialTokens.BOS,
            "</s>": SpecialTokens.EOS,
            "<unk>": SpecialTokens.UNK,
            "<mask>": SpecialTokens.MASK,
        }
    with PathManager.open(vocab_file) as f:
        dictionary = dictionary_class.load(f)
        # finalize will sort the dict based on frequency so only do this if
        # a min_count or max_vocab size is specified
        if min_count > 0 or max_vocab > 0:
            dictionary.finalize(threshold=min_count, nwords=max_vocab, padding_factor=1)
        if tokens_to_add:
            for token in tokens_to_add:
                dictionary.add_symbol(token)
        return Vocabulary(
            dictionary.symbols,
            dictionary.count,
            replacements=special_token_replacements,
        )
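A hedged usage sketch for the helper above; the file path, cutoff, and extra token are placeholders, and the result is a PyText Vocabulary whose special tokens have been replaced per the mapping:

vocab = build_fairseq_vocab(
    vocab_file="/path/to/dict.txt",  # fairseq-format "symbol count" file (placeholder path)
    min_count=5,                     # prune rare symbols via Dictionary.finalize
    tokens_to_add=["<mask2>"],       # extra symbol appended after finalization (hypothetical)
)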
Example #11
def tasks_and_vocab_from_params(params: Params, serialization_dir: str) -> Tuple[List[Task], Dictionary]:
  """
  """
  task_list = []
  instances_for_vocab_creation = itertools.chain()
  datasets_for_vocab_creation = {}
  task_keys = [key for key in params.keys() if re.search("^task_", key)]

  for key in task_keys:
    LOGGER.info("Creating task '{}'".format(key))
    task_params = params.pop(key)
    task_description = task_params.pop("task_description")
    task_data_params = task_params.pop("data_params")

    task = Task.from_params(params=task_description)
    task_list.append(task)

    task_instances_for_vocab, task_datasets_for_vocab = task.setup_data(params=task_data_params)
    instances_for_vocab_creation = itertools.chain(instances_for_vocab_creation, task_instances_for_vocab)
    datasets_for_vocab_creation[task.name] = task_datasets_for_vocab

  # Create and save the dictionary
  for task_name, task_dataset_list in datasets_for_vocab_creation.items():
    LOGGER.info("creating dictionary for '{} from '{}'".format(task_name, ', '.join(task_dataset_list)))

  LOGGER.info('fitting dictionary from dataset')
  vocab = Dictionary.from_params(params.pop("dictionary", {}), instances_for_vocab_creation)

  # vocab save_to_files

  return task_list, vocab
Example #12
def build_fairseq_vocab(
    vocab_file: str,
    dictionary_class: Dictionary = Dictionary,
    special_token_replacements: Dict[str, Token] = None,
    max_vocab: int = -1,
    min_count: int = -1,
    tokens_to_add: Optional[List[str]] = None,
) -> Vocabulary:
    """
    Function builds a PyText vocabulary for models pre-trained using Fairseq
    modules. The dictionary class can take any Fairseq Dictionary class
    and is used to load the vocab file.
    """
    dictionary = dictionary_class.load(vocab_file)
    # finalize will sort the dict based on frequency so only do this if
    # a min_count or max_vocab size is specified
    if min_count > 0 or max_vocab > 0:
        dictionary.finalize(threshold=min_count,
                            nwords=max_vocab,
                            padding_factor=1)
    if tokens_to_add:
        for token in tokens_to_add:
            dictionary.add_symbol(token)
    return Vocabulary(dictionary.symbols,
                      dictionary.count,
                      replacements=special_token_replacements)
Example #13
    def test_index_converts_field_correctly(self):
        vocab = Dictionary()
        b_index = vocab.add_token_to_namespace("B", namespace='*labels')
        i_index = vocab.add_token_to_namespace("I", namespace='*labels')
        o_index = vocab.add_token_to_namespace("O", namespace='*labels')

        tags = ["B", "I", "O", "O", "O"]
        sequence_label_field = SequenceLabelField(tags,
                                                  self.text,
                                                  label_namespace="*labels")
        sequence_label_field.index(vocab)

        # pylint: disable=protected-access
        assert sequence_label_field._indexed_labels == [
            b_index, i_index, o_index, o_index, o_index
        ]
Example #14
    def single_dictionary(self, src_lang, tgt_lang):
        from fairseq.data.dictionary import Dictionary
        dictionary = Dictionary()
        vocab = set()

        # Control tokens
        tgt_lang_token = language_token(tgt_lang)
        vocab.add(tgt_lang_token)

        tokenizer_vocab = self.tokenizer[src_lang].vocab
        vocab = vocab.union(tokenizer_vocab)
        vocab = sorted(list(vocab))

        for word in vocab:
            dictionary.add_symbol(word)

        return dictionary
Example #15
    def from_config(cls, config: Config):
        dictionary = Dictionary.load(config.token_dictionary_path)
        bpe = create_gpt2_bpe(config.bpe_encoder_path, config.bpe_vocab_path)
        # This hack makes the bpe instance picklable
        bpe = copy.copy(bpe)
        bpe.__class__ = PickleableGPT2BPEEncoder

        return cls(bpe, dictionary)
Example #16
    def test_label_field_can_index_with_vocab(self):
        vocab = Dictionary()
        vocab.add_token_to_namespace("entailment", namespace="labels")
        vocab.add_token_to_namespace("contradiction", namespace="labels")
        vocab.add_token_to_namespace("neutral", namespace="labels")

        label = LabelField("entailment")
        label.index(vocab)
        tensor = label.as_tensor(label.get_padding_lengths())
        assert tensor.item() == 0
Example #17
    def test_index_converts_field_correctly(self):
        vocab = Dictionary()
        sentence_index = vocab.add_token_to_namespace("sentence",
                                                      namespace='words')
        capital_a_index = vocab.add_token_to_namespace("A", namespace='words')
        capital_a_char_index = vocab.add_token_to_namespace(
            "A", namespace='characters')
        s_index = vocab.add_token_to_namespace("s", namespace='characters')
        e_index = vocab.add_token_to_namespace("e", namespace='characters')
        n_index = vocab.add_token_to_namespace("n", namespace='characters')
        t_index = vocab.add_token_to_namespace("t", namespace='characters')
        c_index = vocab.add_token_to_namespace("c", namespace='characters')

        field = TextField([Token(t) for t in ["A", "sentence"]],
                          {"words": SingleIdTokenIndexer(namespace="words")})
        field.index(vocab)
        # pylint: disable=protected-access
        assert field._indexed_tokens["words"] == [
            capital_a_index, sentence_index
        ]

        field1 = TextField(
            [Token(t) for t in ["A", "sentence"]], {
                "characters":
                TokenCharacterIndexer(namespace="characters",
                                      min_padding_length=1)
            })
        field1.index(vocab)
        assert field1._indexed_tokens["characters"] == [[capital_a_char_index],
                                                        [
                                                            s_index, e_index,
                                                            n_index, t_index,
                                                            e_index, n_index,
                                                            c_index, e_index
                                                        ]]
        field2 = TextField(
            [Token(t) for t in ["A", "sentence"]],
            token_indexers={
                "words":
                SingleIdTokenIndexer(namespace="words"),
                "characters":
                TokenCharacterIndexer(namespace="characters",
                                      min_padding_length=1)
            })
        field2.index(vocab)
        assert field2._indexed_tokens["words"] == [
            capital_a_index, sentence_index
        ]
        assert field2._indexed_tokens["characters"] == [[capital_a_char_index],
                                                        [
                                                            s_index, e_index,
                                                            n_index, t_index,
                                                            e_index, n_index,
                                                            c_index, e_index
                                                        ]]
Example #18
def output_trained_embeddings_to_file(emb, dict_path, tgt_path):
    emb_dict = Dictionary.load(dict_path)
    emb = emb.data
    with open(tgt_path, 'w') as f:
        print(emb.shape[0], emb.shape[1], file=f)
        for i in range(emb.shape[0]):
            print(emb_dict.symbols[i], ' '.join(['%f' % x for x in emb[i]]),
                  file=f)
Example #19
    @classmethod
    def load_model(cls, vocab_path, model_path, embedding_size=300, cpu=False):
        d = Dictionary.load(vocab_path)
        vocab_size = len(d)
        model = Word2Vec(vocab_size=vocab_size, embedding_size=embedding_size)
        sgns = SGNS(embedding=model, vocab_size=vocab_size, n_negs=1, weights=None)
        sgns.load_state_dict(torch.load(model_path))
        sgns.eval()
        use_cuda = torch.cuda.is_available() and not cpu
        return cls(sgns, d, use_cuda)
Example #20
    def tokens_to_indices(
        self, tokens: List[Token], vocabulary: Dictionary, index_name: str
    ) -> Dict[str, List[int]]:  # pylint: disable=unused-argument
        return {
            "token_ids": ([10, 15]
                          + [vocabulary.get_token_index(token.text, 'words')
                             for token in tokens]
                          + [25]),
            "additional_key": [22, 29],
        }
Example #21
    def __init__(
        self,
        dictionary: Dictionary,
        embed_dim: int = 512,
        hidden_size: int = 512,
        out_embed_dim: int = 512,
        num_layers: int = 1,
        dropout_in: float = 0.1,
        dropout_out: float = 0.1,
        attention: bool = True,
        encoder_embed_dim: int = 512,
        encoder_output_units: int = 512,
        pretrained_embed: Optional[nn.Embedding] = None,
        share_input_output_embed: bool = False,
        adaptive_softmax_cutoff: Optional[int] = None,
    ):
        super().__init__(dictionary)
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.hidden_size = hidden_size
        self.share_input_output_embed = share_input_output_embed
        self.need_attn = True

        self.adaptive_softmax = None
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        if pretrained_embed is None:
            self.embed_tokens = Embedding(num_embeddings, embed_dim,
                                          padding_idx)
        else:
            self.embed_tokens = pretrained_embed

        self.encoder_output_units = encoder_output_units

        self.layers = nn.ModuleList([
            LSTMCell(
                input_size=(hidden_size + embed_dim) if layer == 0 else hidden_size,
                hidden_size=hidden_size,
            ) for layer in range(num_layers)
        ])
        self.attention = AttentionLayer(hidden_size, encoder_output_units,
                                        hidden_size) if attention else None
        if hidden_size != out_embed_dim:
            self.additional_fc = Linear(hidden_size, out_embed_dim)
        if adaptive_softmax_cutoff is not None:
            # setting adaptive_softmax dropout to dropout_out for now but can be redefined
            self.adaptive_softmax = AdaptiveSoftmax(num_embeddings,
                                                    embed_dim,
                                                    adaptive_softmax_cutoff,
                                                    dropout=dropout_out)
        elif not self.share_input_output_embed:
            self.fc_out = Linear(out_embed_dim,
                                 num_embeddings,
                                 dropout=dropout_out)
Example #22
def build_fairseq_vocab(
        vocab_file: str,
        dictionary_class: Dictionary = Dictionary,
        special_token_replacements: Dict[str, str] = None,
        unk_token: str = "<unk>",
        max_vocab: int = -1,
        min_count: int = -1,
        tokens_to_add: Optional[List[str]] = None,
):
    """Function builds a torchtext Vocab for models pre-trained using Fairseq
    modules.
    The dictionary class can take any Fairseq Dictionary class and is
    used to load the vocab file.
    """
    if not special_token_replacements:
        special_token_replacements = {
            "<pad>": "__PAD__",
            "<s>": "__BEGIN_OF_SENTENCE__",
            "</s>": "__END_OF_SENTENCE__",
            "<unk>": "__UNKNOWN__",
            "<mask>": "__MASK__",
        }
    unk_replacement = special_token_replacements.get(unk_token, unk_token)
    special_tokens_to_remove = list(special_token_replacements)
    special_tokens_to_add = tuple(
        replacement for token, replacement in special_token_replacements.items()
        if token != unk_token)

    with open(vocab_file) as f:
        dictionary = dictionary_class.load(f)
        # finalize will sort the dict based on frequency so only do this if
        # a min_count or max_vocab size is specified
        if min_count > 0 or max_vocab > 0:
            dictionary.finalize(threshold=min_count, nwords=max_vocab, padding_factor=1)
        if tokens_to_add:
            for token in tokens_to_add:
                dictionary.add_symbol(token)

        dictionary_items = list(zip(dictionary.symbols, dictionary.count))

        ordered_dict = OrderedDict()
        # add special tokens to beginning of ordered_dict
        for s in special_tokens_to_add:
            ordered_dict[s] = 1

        # add all other tokens from dictionary_items
        for token, freq in dictionary_items:
            ordered_dict[token] = freq

        # remove special_tokens_to_remove from dict
        for s in special_tokens_to_remove:
            if s in ordered_dict:
                del ordered_dict[s]

        return vocab(ordered_dict, unk_token=unk_replacement)
Example #23
def initalize_kaldi(cfg: KaldiInitializerConfig) -> Path:
    if cfg.fst_dir is None:
        cfg.fst_dir = osp.join(cfg.data_dir, "kaldi")
    if cfg.out_labels is None:
        cfg.out_labels = cfg.in_labels

    kaldi_root = Path(cfg.kaldi_root)
    data_dir = Path(cfg.data_dir)
    fst_dir = Path(cfg.fst_dir)
    fst_dir.mkdir(parents=True, exist_ok=True)

    arpa_base = osp.splitext(osp.basename(cfg.lm_arpa))[0]
    unique_label = f"{cfg.in_labels}.{arpa_base}"

    with open(data_dir / f"dict.{cfg.in_labels}.txt", "r") as f:
        vocab = Dictionary.load(f)

    in_units_file = create_units(fst_dir, cfg.in_labels, vocab)

    grammar_graph, out_words_file = create_G(kaldi_root, fst_dir,
                                             Path(cfg.lm_arpa), arpa_base)

    disambig_lexicon_file, disambig_L_in_units_file = create_lexicon(
        cfg, fst_dir, unique_label, in_units_file, out_words_file)

    h_graph, h_out_units_file, disambig_in_units_file_int = create_H(
        kaldi_root,
        fst_dir,
        disambig_L_in_units_file,
        cfg.in_labels,
        vocab,
        cfg.blank_symbol,
        cfg.silence_symbol,
    )
    lexicon_graph = create_L(
        kaldi_root,
        fst_dir,
        unique_label,
        disambig_lexicon_file,
        disambig_L_in_units_file,
        out_words_file,
    )
    lg_graph = create_LG(kaldi_root, fst_dir, unique_label, lexicon_graph,
                         grammar_graph)
    hlga_graph = create_HLGa(kaldi_root, fst_dir, unique_label, h_graph,
                             lg_graph, disambig_in_units_file_int)
    hlg_graph = create_HLG(kaldi_root, fst_dir, unique_label, hlga_graph)

    # for debugging
    # hla_graph = create_HLa(kaldi_root, fst_dir, unique_label, h_graph, lexicon_graph, disambig_in_units_file_int)
    # hl_graph = create_HLG(kaldi_root, fst_dir, unique_label, hla_graph, prefix="HL_looped")
    # create_HLG(kaldi_root, fst_dir, "phnc", h_graph, prefix="H_looped")

    return hlg_graph
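A hedged invocation sketch for the initializer above. The config fields shown are the ones the function actually reads (kaldi_root, data_dir, lm_arpa, in_labels, plus optional fst_dir, out_labels, blank_symbol, silence_symbol); the paths are placeholders. It expects a fairseq dict.{in_labels}.txt in data_dir and an ARPA language model, and returns the path to the composed HLG FST under fst_dir:

cfg = KaldiInitializerConfig(
    kaldi_root="/opt/kaldi",        # placeholder paths
    data_dir="/data/labels",
    lm_arpa="/data/lm/4gram.arpa",
    in_labels="phn",
)
hlg_path = initalize_kaldi(cfg)     # composed HLG graph written under cfg.fst_dir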
Example #24
    def dictionary(self):
        from fairseq.data.dictionary import Dictionary
        dictionary = Dictionary()

        vocab = set()

        # Add language_tokens
        langs, _ = list(zip(*self.tokenizer.keys()))
        langs = list(map(language_token, langs))
        vocab = vocab.union(set(langs))

        for key in self.tokenizer:
            tokenizer_vocab = self.tokenizer[key].vocab
            vocab = vocab.union(tokenizer_vocab)

        vocab = sorted(list(vocab))
        for word in vocab:
            dictionary.add_symbol(word)

        return dictionary
Example #25
class VocabBuilder:
    def __init__(self, dataset, tokenizer, save_path):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.save_path = save_path
        self.vocab_path = os.path.join(save_path, 'vocab.pt')
        self._vocab = None

    def vocab(self):
        if self._vocab is None:
            self.build_vocab()
        return self._vocab

    def build_vocab(self):
        if os.path.exists(self.vocab_path):
            self._vocab = Dictionary.load(self.vocab_path)
        else:
            self.rebuild_vocab()

    def rebuild_vocab(self):
        self._vocab = Dictionary()
        self._vocab.add_symbol(self.mask_builder.mask_token)
        desc = 'build-vocab: {}'.format(self.save_path)
        pbar = tqdm(range(len(self.dataset)), desc=desc, leave=True)

        for i in pbar:
            contents = self.dataset[i]
            tokens = self.tokenizer(contents)
            for token in tokens:
                self._vocab.add_symbol(token)

        if self.save_path is not None:
            self._vocab.save(self.vocab_path)
Example #26
    def __init__(self,
                 token_dictionary: Dictionary,
                 expansion_dictionary: Dictionary,
                 expansion_strategy: ExpansionStrategy,
                 device=None,
                 mask_unk: bool = True):

        assert token_dictionary.pad_index == expansion_dictionary.pad_index
        self.device = device or torch.device('cpu')
        self.pad_idx = token_dictionary.pad_index
        self.token_dictionary = token_dictionary
        self.expansion_dictionary = expansion_dictionary
        self.expansion_strategy = expansion_strategy
        self.root_token_id = token_dictionary.index(expansion_strategy.root_node_token())

        minus_inf = float('-inf')
        # create mask to use in the token softmax later
        token_prob_mask = np.zeros(shape=(len(token_dictionary)), dtype=np.float32)
        for dep_placeholder in expansion_strategy.get_dependency_placeholders():
            index_to_mask = token_dictionary.index(dep_placeholder)
            if index_to_mask == token_dictionary.unk_index:
                continue   # symbol not found (e.g. [subword] if no subwords) ==> skip
            token_prob_mask[index_to_mask] = minus_inf
        special_token_idxs = [token_dictionary.pad_index, token_dictionary.eos_index]
        if mask_unk:
            special_token_idxs += [token_dictionary.unk_index]
        for special_token_idx in special_token_idxs:
            token_prob_mask[special_token_idx] = minus_inf
        self.token_prob_mask = torch.from_numpy(token_prob_mask).to(device)

        # create mask to use in the expansion softmax later
        expansion_prob_mask = np.zeros(shape=(len(expansion_dictionary)), dtype=np.float32)
        special_idxs = [expansion_dictionary.pad_index, expansion_dictionary.eos_index]
        if mask_unk:
            special_idxs += [expansion_dictionary.unk_index]
        for special_idx in special_idxs:
            expansion_prob_mask[special_idx] = minus_inf
        self.expansion_prob_mask = torch.from_numpy(expansion_prob_mask).to(device)
Example #27
    def test_token2indices_correct_characters(self):
        vocab = Dictionary()
        vocab.add_token_to_namespace("A", namespace='characters')  # 2
        vocab.add_token_to_namespace("s", namespace='characters')  # 3
        vocab.add_token_to_namespace("e", namespace='characters')  # 4
        vocab.add_token_to_namespace("n", namespace='characters')  # 5
        vocab.add_token_to_namespace("t", namespace='characters')  # 6
        vocab.add_token_to_namespace("c", namespace='characters')  # 7

        indexer = TokenCharacterIndexer("characters", min_padding_length=1)
        indices = indexer.tokens_to_indices([Token("sentential")], vocab,
                                            "char")

        expected_ = {"char": [[3, 4, 5, 6, 4, 5, 6, 1, 1, 1]]}

        assert indices == expected_
Example #28
def train(args):
    d = Dictionary.load(args.vocab)
    wf = np.array(d.count)
    wf[wf == 0] = 1
    wf = wf / wf.sum()  # relative unigram frequencies
    ws = 1 - np.sqrt(args.ss_t / wf)  # word2vec subsampling: per-word discard probability
    ws = np.clip(ws, 0, 1)
    vocab_size = len(d)
    weights = wf if args.weights else None
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    model = Word2Vec(vocab_size=vocab_size, embedding_size=args.e_dim)
    modelpath = os.path.join(args.save_dir, '{}.pt'.format(args.name))
    sgns = SGNS(embedding=model,
                vocab_size=vocab_size,
                n_negs=args.n_negs,
                weights=weights,
                pad=d.unk())
    if os.path.isfile(modelpath) and args.conti:
        sgns.load_state_dict(t.load(modelpath))
    if args.cuda:
        sgns = sgns.cuda()
    optim = Adam(sgns.parameters())
    optimpath = os.path.join(args.save_dir, '{}.optim.pt'.format(args.name))
    if os.path.isfile(optimpath) and args.conti:
        optim.load_state_dict(t.load(optimpath))
    dataset = PermutedSubsampledCorpus(args.data, ws=ws)
    dataloader = DataLoader(dataset,
                            batch_size=args.mb,
                            shuffle=True,
                            num_workers=0)
    for epoch in range(1, args.epoch + 1):
        total_batches = int(np.ceil(len(dataset) / args.mb))
        pbar = tqdm(dataloader)
        pbar.set_description("[Epoch {}]".format(epoch))
        for iword, owords in pbar:
            loss = sgns(iword, owords)
            optim.zero_grad()
            loss.backward()
            optim.step()
            pbar.set_postfix(loss=loss.item())

        t.save(
            sgns.state_dict(),
            os.path.join(args.save_dir, '{}-e{}.pt'.format(args.name, epoch)))
        t.save(
            optim.state_dict(),
            os.path.join(args.save_dir,
                         '{}-e{}.optim.pt'.format(args.name, epoch)))
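For intuition, a tiny worked example (the frequencies are made up) of the subsampling line above: with ss_t = 1e-3, a word with relative frequency 0.01 is discarded with probability 1 - sqrt(1e-3 / 0.01) ≈ 0.68, while words at or below the threshold are always kept:

import numpy as np

ss_t = 1e-3
wf = np.array([0.01, 0.001, 0.0001])        # hypothetical relative frequencies
ws = np.clip(1 - np.sqrt(ss_t / wf), 0, 1)  # per-word discard probabilities
print(ws)                                   # ~[0.684, 0.0, 0.0]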
Example #29
def format_latex(iteration: Iteration,
                 iteration_number: int,
                 token_dictionary: Dictionary,
                 expansion_dictionary: Dictionary,
                 expansion: ExpansionStrategy,
                 no_token: str = '-') -> str:

    plt = token_dictionary.string(iteration.plt).split(' ')

    nlt = token_dictionary.string(iteration.nlt).split(' ')
    nlt = [
        t if is_new_token else no_token
        for t, is_new_token in zip(nlt, iteration.new_token_mask)
    ]

    nle = [
        expansion.pretty_format(e)
        for e in expansion_dictionary.string(iteration.nle).split(' ')
    ]
    nle = [
        e if is_new_token else no_token
        for e, is_new_token in zip(nle, iteration.new_token_mask)
    ]

    num_elems = len(nlt)

    s = '\\begin{tabularx}{\\linewidth}{p{6mm} ' + ' '.join(
        ['c'] * num_elems) + '}\n'
    s += '\\multicolumn{' + str(num_elems) + '}{l}{Iteration ' + str(
        iteration_number + 1) + '}\\\\\n'
    s += '\\hline\n'
    s += 'PLT: & ' + ' & '.join(maybe_tt(t) for t in plt) + '\\\\\n'
    s += 'NLT: & ' + ' & '.join(maybe_tt(t) for t in nlt) + '\\\\\n'
    s += 'NLE: & ' + ' & '.join(maybe_tt(e) for e in nle) + '\\\\\n'
    s += '\\end{tabularx}'
    return s
Example #30
    def tokens_to_indices(self, tokens: List[Token], vocabulary: Dictionary,
                          index_name: str):
        indices: List[int] = []

        for token in itertools.chain(self.start_tokens, tokens,
                                     self.end_tokens):
            if getattr(token, 'text_id', None) is not None:
                # `text_id` being set on the token means that we aren't using the vocab, we just use
                # this id instead.
                indices.append(token.text_id)
            else:
                text = token.text
                if self.lowercase_tokens:
                    text = text.lower()
                indices.append(vocabulary.get_token_index(
                    text, self.namespace))

        return {index_name: indices}