Example #1
    def __init__(self,
                 vocab_file,
                 max_len=None,
                 do_basic_tokenize=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BertTokenizer.

        Args:
          vocab_file: Path to a one-wordpiece-per-line vocabulary file
          do_basic_tokenize: Whether to do basic tokenization before wordpiece.
          max_len: An artificial maximum length to truncate tokenized sequences to;
                         Effective maximum length is always the minimum of this
                         value (if specified) and the underlying BERT model's
                         sequence length.
          never_split: List of tokens which will never be split during tokenization.
                         Only has an effect when do_wordpiece_only=False
        """
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
                .format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([
            (ids, tok) for tok, ids in self.vocab.items()
        ])
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)
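A note on what this constructor builds: load_vocab reads the one-wordpiece-per-line file into an ordered token-to-id mapping, and ids_to_tokens is simply its inverse. A minimal sketch of that setup with a toy in-memory vocabulary (the tokens below are illustrative, not from a real vocab file):

import collections

# Toy stand-in for a one-wordpiece-per-line vocab file (illustrative only).
vocab = collections.OrderedDict(
    (tok, idx) for idx, tok in enumerate(["[PAD]", "[UNK]", "un", "##want", "##ed"])
)

# The same inverse mapping the constructor builds: id -> token.
ids_to_tokens = collections.OrderedDict((idx, tok) for tok, idx in vocab.items())

assert [vocab[t] for t in ["un", "##want", "##ed"]] == [2, 3, 4]
assert ids_to_tokens[3] == "##want"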
Example #2
 def __init__(self, min_char: int, vocab_file: str, lower: bool,
              add_sentence_boundary: bool, add_word_boundary: bool,
              use_cuda: bool):
     super(WordPieceBatch,
           self).__init__(min_char=min_char,
                          lower=lower,
                          add_sentence_boundary=add_sentence_boundary,
                          add_word_boundary=add_word_boundary,
                          use_cuda=use_cuda)
     self.vocab = load_vocab(vocab_file=vocab_file)
     self.tokenizer = WordpieceTokenizer(vocab=self.vocab)
Example #3
 def from_config(cls, config: Config):
     basic_tokenizer = create_component(
         ComponentType.TOKENIZER, config.basic_tokenizer
     )
     vocab = load_vocab(config.wordpiece_vocab_path)
     wordpiece_tokenizer = WordpieceTokenizer(vocab=vocab)
     return cls(vocab, basic_tokenizer, wordpiece_tokenizer)
Example #4
    def test_wordpiece_tokenizer(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
            "runn", "##ing"
        ]

        vocab = {}
        for (i, token) in enumerate(vocab_tokens):
            vocab[token] = i
        tokenizer = WordpieceTokenizer(vocab=vocab)

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("unwanted running"),
                             ["un", "##want", "##ed", "runn", "##ing"])

        self.assertListEqual(tokenizer.tokenize("unwantedX running"),
                             ["[UNK]", "runn", "##ing"])
Example #5
File: subword.py  Project: seongl/claf
    def _wordpiece(self, text, unit="text"):
        """
        ex) Hello World -> ['Hello', 'World'] -> ['He', '##llo', 'Wo', '##rld']
        """
        if self.subword_tokenizer is None:
            vocab_path = self.data_handler.read(self.config["vocab_path"], return_path=True)
            vocab = load_vocab(vocab_path)
            self.subword_tokenizer = WordpieceTokenizer(vocab)

        tokens = []

        if unit == "word":
            for sub_token in self.subword_tokenizer.tokenize(text):
                tokens.append(sub_token)
        else:
            for token in self.word_tokenizer.tokenize(text):
                for sub_token in self.subword_tokenizer.tokenize(token):
                    tokens.append(sub_token)

        return tokens
Example #6
File: subword.py  Project: seongl/claf
class SubwordTokenizer(Tokenizer):
    """
    Subword Tokenizer

    text -> [word tokens] -> [[sub word tokens], ...]

    * Args:
        name: tokenizer name [wordpiece]
    """

    def __init__(self, name, word_tokenizer, config={}):
        super(SubwordTokenizer, self).__init__(name, f"subword-{name}+{word_tokenizer.cache_name}")
        self.data_handler = DataHandler(CachePath.VOCAB)
        self.config = config
        self.word_tokenizer = word_tokenizer
        self.subword_tokenizer = None

    """ Tokenizers """

    def _wordpiece(self, text, unit="text"):
        """
        ex) Hello World -> ['Hello', 'World'] -> ['He', '##llo', 'Wo', '##rld']
        """
        if self.subword_tokenizer is None:
            vocab_path = self.data_handler.read(self.config["vocab_path"], return_path=True)
            vocab = load_vocab(vocab_path)
            self.subword_tokenizer = WordpieceTokenizer(vocab)

        tokens = []

        if unit == "word":
            for sub_token in self.subword_tokenizer.tokenize(text):
                tokens.append(sub_token)
        else:
            for token in self.word_tokenizer.tokenize(text):
                for sub_token in self.subword_tokenizer.tokenize(token):
                    tokens.append(sub_token)

        return tokens
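The two-stage flow in the docstring (text -> word tokens -> wordpieces) can be reproduced directly with a toy vocabulary; here a plain whitespace split stands in for word_tokenizer, and the vocab entries are hypothetical, chosen so the docstring example round-trips:

from pytorch_pretrained_bert.tokenization import WordpieceTokenizer

# Toy vocab chosen so the docstring example works (illustrative only).
vocab = {tok: i for i, tok in enumerate(["[UNK]", "He", "##llo", "Wo", "##rld"])}
wordpiece = WordpieceTokenizer(vocab=vocab)

text = "Hello World"
word_tokens = text.split()  # stand-in for self.word_tokenizer.tokenize(text)
sub_tokens = [piece for word in word_tokens for piece in wordpiece.tokenize(word)]
print(sub_tokens)  # ['He', '##llo', 'Wo', '##rld']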
Example #7
class BertLabelTokenizer:
    """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
    def __init__(self,
                 vocab_file,
                 max_len=None,
                 do_basic_tokenize=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BertTokenizer.

        Args:
          vocab_file: Path to a one-wordpiece-per-line vocabulary file
          do_basic_tokenize: Whether to do basic tokenization before wordpiece.
          max_len: An artificial maximum length to truncate tokenized sequences to;
                         Effective maximum length is always the minimum of this
                         value (if specified) and the underlying BERT model's
                         sequence length.
          never_split: List of tokens which will never be split during tokenization.
                         Only has an effect when do_wordpiece_only=False
        """
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
                .format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([
            (ids, tok) for tok, ids in self.vocab.items()
        ])
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)

    def tokenize(self, text):
        split_tokens = []
        token_begin_mask = []
        for token in text:
            wordpieces = self.wordpiece_tokenizer.tokenize(token)
            if len(wordpieces) > 0:
                for sub_token in wordpieces:
                    split_tokens.append(sub_token)
                token_begin_mask += [1] + [0] * (len(wordpieces) - 1)
        return split_tokens, token_begin_mask

    def tokenize_labels(self, text, labels):
        split_tokens = []
        split_labels = []
        token_begin_mask = []
        for token, label in zip(text, labels):
            wordpieces = self.wordpiece_tokenizer.tokenize(token)
            if len(wordpieces) > 0:
                for sub_token in wordpieces:
                    split_tokens.append(sub_token)
                split_labels += [label] + ["X"] * (len(wordpieces) - 1)
                token_begin_mask += [1] + [0] * (len(wordpieces) - 1)
        return split_tokens, split_labels, token_begin_mask

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        for token in tokens:
            ids.append(self.vocab[token])
        if len(ids) > self.max_len:
            logger.warning(
                "Token indices sequence length is longer than the specified maximum "
                " sequence length for this BERT model ({} > {}). Running this"
                " sequence through BERT will result in indexing errors".format(
                    len(ids), self.max_len))
        return ids

    def convert_ids_to_tokens(self, ids):
        """Converts a sequence of ids in wordpiece tokens using the vocab."""
        tokens = []
        for i in ids:
            tokens.append(self.ids_to_tokens[i])
        return tokens

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path,
                        cache_dir=None,
                        *inputs,
                        **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[
                pretrained_model_name_or_path]
        else:
            vocab_file = pretrained_model_name_or_path
        if os.path.isdir(vocab_file):
            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    vocab_file))
            return None
        if resolved_vocab_file == vocab_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                vocab_file, resolved_vocab_file))
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
            # than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[
                pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
        return tokenizer
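tokenize_labels keeps a word-level label on the first wordpiece of each word and pads continuation pieces with the dummy tag "X", mirroring token_begin_mask. A minimal sketch of that alignment with a toy vocabulary (the tokens and labels are illustrative and not part of the class above):

from pytorch_pretrained_bert.tokenization import WordpieceTokenizer

# Toy vocab in which only "lives" splits into two pieces (illustrative only).
vocab = {tok: i for i, tok in enumerate(["[UNK]", "she", "live", "##s", "in", "New", "York"])}
wordpiece = WordpieceTokenizer(vocab=vocab)

words = ["she", "lives", "in", "New", "York"]
labels = ["O", "O", "O", "B-LOC", "I-LOC"]

split_tokens, split_labels, token_begin_mask = [], [], []
for word, label in zip(words, labels):
    pieces = wordpiece.tokenize(word)
    if pieces:
        split_tokens += pieces
        split_labels += [label] + ["X"] * (len(pieces) - 1)   # continuation pieces get "X"
        token_begin_mask += [1] + [0] * (len(pieces) - 1)

print(split_tokens)      # ['she', 'live', '##s', 'in', 'New', 'York']
print(split_labels)      # ['O', 'O', 'X', 'O', 'B-LOC', 'I-LOC']
print(token_begin_mask)  # [1, 1, 0, 1, 1, 1]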
Example #8
def main(param2val):

    # params
    params = Params.from_param2val(param2val)
    print(params, flush=True)

    #  paths
    project_path = Path(param2val['project_path'])
    save_path = Path(param2val['save_path'])
    srl_eval_path = project_path / 'perl' / 'srl-eval.pl'
    data_path_mlm = project_path / 'data' / 'training' / f'{params.corpus_name}_mlm.txt'
    data_path_train_srl = project_path / 'data' / 'training' / f'{params.corpus_name}_no-dev_srl.txt'
    data_path_devel_srl = project_path / 'data' / 'training' / f'human-based-2018_srl.txt'
    data_path_test_srl = project_path / 'data' / 'training' / f'human-based-2008_srl.txt'
    childes_vocab_path = project_path / 'data' / f'{params.corpus_name}_vocab.txt'
    google_vocab_path = project_path / 'data' / 'bert-base-cased.txt'  # to get word pieces

    # word-piece tokenizer - defines input vocabulary
    vocab = load_vocab(childes_vocab_path, google_vocab_path,
                       params.vocab_size)
    # TODO testing google vocab with wordpieces

    assert vocab['[PAD]'] == 0  # AllenNLP expects this
    assert vocab['[UNK]'] == 1  # AllenNLP expects this
    assert vocab['[CLS]'] == 2
    assert vocab['[SEP]'] == 3
    assert vocab['[MASK]'] == 4
    wordpiece_tokenizer = WordpieceTokenizer(vocab)
    print(f'Number of types in vocab={len(vocab):,}')

    # load utterances for MLM task
    utterances = load_utterances_from_file(data_path_mlm)
    train_utterances, devel_utterances, test_utterances = split(utterances)

    # load propositions for SLR task
    propositions = load_propositions_from_file(data_path_train_srl)
    train_propositions, devel_propositions, test_propositions = split(
        propositions)
    if data_path_devel_srl.is_file():  # use human-annotated data as devel split
        print(f'Using {data_path_devel_srl.name} as SRL devel split')
        devel_propositions = load_propositions_from_file(data_path_devel_srl)
    if data_path_test_srl.is_file():  # use human-annotated data as test split
        print(f'Using {data_path_test_srl.name} as SRL test split')
        test_propositions = load_propositions_from_file(data_path_test_srl)

    # converters handle conversion from text to instances
    converter_mlm = ConverterMLM(params, wordpiece_tokenizer)
    converter_srl = ConverterSRL(params, wordpiece_tokenizer)

    # get output_vocab
    # note: Allen NLP vocab holds labels, wordpiece_tokenizer.vocab holds input tokens
    # what from_instances() does:
    # 1. it iterates over all instances, and all fields, and all token indexers
    # 2. the token indexer is used to update vocabulary count, skipping words whose text_id is already set
    # 3. a PADDING and MASK symbol are added to 'tokens' namespace resulting in vocab size of 2
    # input tokens are not indexed, as they are already indexed by bert tokenizer vocab.
    # this ensures that the model is built with inputs for all vocab words,
    # such that words that occur only in LM or SRL task can still be input

    # make instances once - this allows iterating multiple times (required when num_epochs > 1)
    train_instances_mlm = converter_mlm.make_instances(train_utterances)
    devel_instances_mlm = converter_mlm.make_instances(devel_utterances)
    test_instances_mlm = converter_mlm.make_instances(test_utterances)
    train_instances_srl = converter_srl.make_instances(train_propositions)
    devel_instances_srl = converter_srl.make_instances(devel_propositions)
    test_instances_srl = converter_srl.make_instances(test_propositions)
    all_instances_mlm = chain(train_instances_mlm, devel_instances_mlm,
                              test_instances_mlm)
    all_instances_srl = chain(train_instances_srl, devel_instances_srl,
                              test_instances_srl)

    # make vocab from all instances
    output_vocab_mlm = Vocabulary.from_instances(all_instances_mlm)
    output_vocab_srl = Vocabulary.from_instances(all_instances_srl)
    # print(f'mlm vocab size={output_vocab_mlm.get_vocab_size()}')  # contain just 2 tokens
    # print(f'srl vocab size={output_vocab_srl.get_vocab_size()}')  # contain just 2 tokens
    assert output_vocab_mlm.get_vocab_size(
        'tokens') == output_vocab_srl.get_vocab_size('tokens')

    # BERT
    print('Preparing Multi-task BERT...')
    input_vocab_size = len(converter_mlm.wordpiece_tokenizer.vocab)
    bert_config = BertConfig(
        vocab_size_or_config_json_file=input_vocab_size,  # was 32K
        hidden_size=params.hidden_size,  # was 768
        num_hidden_layers=params.num_layers,  # was 12
        num_attention_heads=params.num_attention_heads,  # was 12
        intermediate_size=params.intermediate_size)  # was 3072
    bert_model = BertModel(config=bert_config)
    # Multi-tasking BERT
    mt_bert = MTBert(vocab_mlm=output_vocab_mlm,
                     vocab_srl=output_vocab_srl,
                     bert_model=bert_model,
                     embedding_dropout=params.embedding_dropout)
    mt_bert.cuda()
    num_params = sum(p.numel() for p in mt_bert.parameters()
                     if p.requires_grad)
    print('Number of model parameters: {:,}'.format(num_params), flush=True)

    # optimizers
    optimizer_mlm = BertAdam(params=mt_bert.parameters(), lr=params.lr)
    optimizer_srl = BertAdam(params=mt_bert.parameters(), lr=params.lr)
    move_optimizer_to_cuda(optimizer_mlm)
    move_optimizer_to_cuda(optimizer_srl)

    # batching
    bucket_batcher_mlm = BucketIterator(batch_size=params.batch_size,
                                        sorting_keys=[('tokens', "num_tokens")
                                                      ])
    bucket_batcher_mlm.index_with(output_vocab_mlm)
    bucket_batcher_srl = BucketIterator(batch_size=params.batch_size,
                                        sorting_keys=[('tokens', "num_tokens")
                                                      ])
    bucket_batcher_srl.index_with(output_vocab_srl)

    # big batcher to speed evaluation - 1024 is too big
    bucket_batcher_mlm_large = BucketIterator(batch_size=512,
                                              sorting_keys=[('tokens',
                                                             "num_tokens")])
    bucket_batcher_srl_large = BucketIterator(batch_size=512,
                                              sorting_keys=[('tokens',
                                                             "num_tokens")])
    bucket_batcher_mlm_large.index_with(output_vocab_mlm)
    bucket_batcher_srl_large.index_with(output_vocab_srl)

    # init performance collection
    name2col = {
        'devel_pps': [],
        'devel_f1s': [],
    }

    # init
    eval_steps = []
    train_start = time.time()
    loss_mlm = None
    no_mlm_batches = False
    step = 0

    # generators
    train_generator_mlm = bucket_batcher_mlm(train_instances_mlm,
                                             num_epochs=params.num_mlm_epochs)
    train_generator_srl = bucket_batcher_srl(
        train_instances_srl, num_epochs=None)  # infinite generator
    num_train_mlm_batches = bucket_batcher_mlm.get_num_batches(
        train_instances_mlm)
    if params.srl_interleaved:
        max_step = num_train_mlm_batches
    else:
        max_step = num_train_mlm_batches * 2
    print(f'Will stop training at step={max_step:,}')

    while step < max_step:

        # TRAINING
        if step != 0:  # otherwise evaluation at step 0 is influenced by training on one batch
            mt_bert.train()

            # masked language modeling task
            try:
                batch_mlm = next(train_generator_mlm)
            except StopIteration:
                if params.srl_interleaved:
                    break
                else:
                    no_mlm_batches = True
            else:
                loss_mlm = mt_bert.train_on_batch('mlm', batch_mlm,
                                                  optimizer_mlm)

            # semantic role labeling task
            if params.srl_interleaved:
                if random.random() < params.srl_probability:
                    batch_srl = next(train_generator_srl)
                    mt_bert.train_on_batch('srl', batch_srl, optimizer_srl)
            elif no_mlm_batches:
                batch_srl = next(train_generator_srl)
                mt_bert.train_on_batch('srl', batch_srl, optimizer_srl)

        # EVALUATION
        if step % config.Eval.interval == 0:
            mt_bert.eval()
            eval_steps.append(step)

            # evaluate perplexity
            devel_generator_mlm = bucket_batcher_mlm_large(devel_instances_mlm,
                                                           num_epochs=1)
            devel_pp = evaluate_model_on_pp(mt_bert, devel_generator_mlm)
            name2col['devel_pps'].append(devel_pp)
            print(f'devel-pp={devel_pp}', flush=True)

            # test sentences
            if config.Eval.test_sentences:
                test_generator_mlm = bucket_batcher_mlm_large(
                    test_instances_mlm, num_epochs=1)
                out_path = save_path / f'test_split_mlm_results_{step}.txt'
                predict_masked_sentences(mt_bert, test_generator_mlm, out_path)

            # probing - test sentences for specific syntactic tasks
            for name in config.Eval.probing_names:
                # prepare data
                probing_data_path_mlm = project_path / 'data' / 'probing' / f'{name}.txt'
                if not probing_data_path_mlm.exists():
                    print(f'WARNING: {probing_data_path_mlm} does not exist')
                    continue
                probing_utterances_mlm = load_utterances_from_file(
                    probing_data_path_mlm)
                # check that probing words are in vocab
                for u in probing_utterances_mlm:
                    # print(u)
                    for w in u:
                        if w == '[MASK]':
                            continue  # not in output vocab
                        # print(w)
                        assert output_vocab_mlm.get_token_index(
                            w, namespace='labels'), w
                # probing + save results to text
                probing_instances_mlm = converter_mlm.make_probing_instances(
                    probing_utterances_mlm)
                probing_generator_mlm = bucket_batcher_mlm(
                    probing_instances_mlm, num_epochs=1)
                out_path = save_path / f'probing_{name}_results_{step}.txt'
                predict_masked_sentences(mt_bert,
                                         probing_generator_mlm,
                                         out_path,
                                         print_gold=False,
                                         verbose=True)

            # evaluate devel f1
            devel_generator_srl = bucket_batcher_srl_large(devel_instances_srl,
                                                           num_epochs=1)
            devel_f1 = evaluate_model_on_f1(mt_bert, srl_eval_path,
                                            devel_generator_srl)

            name2col['devel_f1s'].append(devel_f1)
            print(f'devel-f1={devel_f1}', flush=True)

            # console
            min_elapsed = (time.time() - train_start) // 60
            pp = torch.exp(loss_mlm) if loss_mlm is not None else np.nan
            print(
                f'step {step:<6,}: pp={pp :2.4f} total minutes elapsed={min_elapsed:<3}',
                flush=True)

        # only increment step once in each iteration of the loop, otherwise evaluation may never happen
        step += 1

    # evaluate train perplexity
    if config.Eval.train_split:
        generator_mlm = bucket_batcher_mlm_large(train_instances_mlm,
                                                 num_epochs=1)
        train_pp = evaluate_model_on_pp(mt_bert, generator_mlm)
    else:
        train_pp = np.nan
    print(f'train-pp={train_pp}', flush=True)

    # evaluate train f1
    if config.Eval.train_split:
        generator_srl = bucket_batcher_srl_large(train_instances_srl,
                                                 num_epochs=1)
        train_f1 = evaluate_model_on_f1(mt_bert,
                                        srl_eval_path,
                                        generator_srl,
                                        print_tag_metrics=True)
    else:
        train_f1 = np.nan
    print(f'train-f1={train_f1}', flush=True)

    # test sentences
    if config.Eval.test_sentences:
        test_generator_mlm = bucket_batcher_mlm(test_instances_mlm,
                                                num_epochs=1)
        out_path = save_path / f'test_split_mlm_results_{step}.txt'
        predict_masked_sentences(mt_bert, test_generator_mlm, out_path)

    # probing - test sentences for specific syntactic tasks
    for name in config.Eval.probing_names:
        # prepare data
        probing_data_path_mlm = project_path / 'data' / 'probing' / f'{name}.txt'
        if not probing_data_path_mlm.exists():
            print(f'WARNING: {probing_data_path_mlm} does not exist')
            continue
        probing_utterances_mlm = load_utterances_from_file(
            probing_data_path_mlm)
        probing_instances_mlm = converter_mlm.make_probing_instances(
            probing_utterances_mlm)
        # batch and do inference
        probing_generator_mlm = bucket_batcher_mlm(probing_instances_mlm,
                                                   num_epochs=1)
        out_path = save_path / f'probing_{name}_results_{step}.txt'
        predict_masked_sentences(mt_bert,
                                 probing_generator_mlm,
                                 out_path,
                                 print_gold=False,
                                 verbose=True)

    # put train-pp and train-f1 into pandas Series
    s1 = pd.Series([train_pp], index=[eval_steps[-1]])
    s1.name = 'train_pp'
    s2 = pd.Series([train_f1], index=[eval_steps[-1]])
    s2.name = 'train_f1'

    # return performance as pandas Series
    series_list = [s1, s2]
    for name, col in name2col.items():
        print(f'Making pandas series with name={name} and length={len(col)}')
        s = pd.Series(col, index=eval_steps)
        s.name = name
        series_list.append(s)

    return series_list
Example #9
def test_WordpieceTokenizer():
    model = WordpieceTokenizer(
        tokenization.load_vocab(
            os.path.join(model_dir, "bert-base-cased-vocab.txt")))
    print(model.tokenize("decomposition deoomposition"))
Example #10
class WordPieceBatch(CharacterBatch, SpecialTokens):
    def __init__(self, min_char: int, vocab_file: str, lower: bool,
                 add_sentence_boundary: bool, add_word_boundary: bool,
                 use_cuda: bool):
        super(WordPieceBatch,
              self).__init__(min_char=min_char,
                             lower=lower,
                             add_sentence_boundary=add_sentence_boundary,
                             add_word_boundary=add_word_boundary,
                             use_cuda=use_cuda)
        self.vocab = load_vocab(vocab_file=vocab_file)
        self.tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def create_one_batch(self, raw_dataset: List[List[str]]):
        batch_size = len(raw_dataset)
        seq_len = max([len(input_) for input_ in raw_dataset])
        if self.add_sentence_boundary:
            seq_len += 2

        sub_tokens = []
        for raw_data in raw_dataset:
            item = []
            for token in raw_data:
                if self.lower:
                    token = token.lower()
                item.append(self.tokenizer.tokenize(token))
            sub_tokens.append(item)
        max_char_len = max(
            [len(token) for item in sub_tokens for token in item])
        max_char_len = max(max_char_len, self.min_char)
        if self.add_word_boundary:
            max_char_len += 2

        batch = torch.LongTensor(batch_size, seq_len,
                                 max_char_len).fill_(self.pad_id)
        lengths = torch.LongTensor(batch_size, seq_len).fill_(1)
        for i, item in enumerate(sub_tokens):
            if self.add_sentence_boundary:
                item = [self.bos] + item + [self.eos]
            for j, token in enumerate(item):
                if self.add_sentence_boundary and (token == self.bos
                                                   or token == self.eos):
                    if self.add_word_boundary:
                        lengths[i, j] = 3
                        batch[i, j, 0] = self.mapping.get(self.bow)
                        batch[i, j, 1] = self.mapping.get(token)
                        batch[i, j, 2] = self.mapping.get(self.eow)
                    else:
                        lengths[i, j] = 1
                        batch[i, j, 0] = self.mapping.get(token)
                else:
                    if self.add_word_boundary:
                        lengths[i, j] = len(token) + 2
                        batch[i, j, 0] = self.mapping.get(self.bow)
                        for k, sub_token in enumerate(token):
                            batch[i, j, k + 1] = self.mapping.get(
                                sub_token, self.oov_id)
                        batch[i, j,
                              len(token) + 1] = self.mapping.get(self.eow)
                    else:
                        lengths[i, j] = len(token)
                        for k, sub_token in enumerate(token):
                            batch[i, j, k] = self.mapping.get(
                                sub_token, self.oov_id)

        if self.use_cuda:
            batch = batch.cuda()
            lengths = lengths.cuda()
        return batch, lengths

    def create_dict_from_dataset(self, raw_dataset: List[List[str]]):
        n_entries = 0
        for raw_data in raw_dataset:
            for token in raw_data:
                if self.lower:
                    token = token.lower()
                for sub_token in self.tokenizer.tokenize(token):
                    if sub_token not in self.mapping:
                        self.mapping[sub_token] = len(self.mapping)
                        n_entries += 1
        logger.info('+ loaded {0} entries from input'.format(n_entries))
        logger.info('+ current number of entries in mapping is: {0}'.format(
            len(self.mapping)))
Example #11
from collections import defaultdict
import sys
import argparse
import json  # json.load is used below but was missing from the snippet
from utilities import *

from pytorch_pretrained_bert.tokenization import BertTokenizer, WordpieceTokenizer

parser = argparse.ArgumentParser(description='Dataset Settings')
ps = parser.add_argument
ps('--dataset', dest='dataset')
args = parser.parse_args()
args.bert_model = 'bert-base-uncased'
args.do_lower_case = True  # use a real boolean; a non-empty string such as 'False' would also be truthy

tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab)


dataset = 'data/{}'.format(args.dataset)

#from utilities import *

def flatten(l):
    return [item for sublist in l for item in sublist]

def load_reviews(fp):
    with open('./{}/{}.json'.format(dataset, fp),'r') as f:
        data = json.load(f)
    return data

def load_set(fp):
Example #12
import argparse
from collections import Counter, defaultdict  # used below but missing from the snippet

from pytorch_pretrained_bert.tokenization import BertTokenizer, WordpieceTokenizer

parser = argparse.ArgumentParser(description='Dataset Settings')
ps = parser.add_argument
ps('--dataset', dest='dataset')  # category
args = parser.parse_args()
args.bert_model = 'bert-base-uncased'
args.do_lower_case = True  # use a real boolean rather than the string 'True'

#main = ''
in_fp = './data/{}/{}_filter_flat_positive.large.json'.format(
    args.dataset, args.dataset)

tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                          do_lower_case=args.do_lower_case)
wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab)

limit = 99999999999
min_reviews = 5
max_words = 1500

user_count = Counter()
item_count = Counter()
users = defaultdict(list)
items = defaultdict(list)

pairs = set()
interactions = defaultdict(list)

print("Building count dictionary first to save memory...")
#with gzip.open(in_fp, 'r') as f: