def tokenize(
    texts: List[str],
    tokenizer: BertWordPieceTokenizer,
    flat_map: bool = False,
) -> Union[List[str], List[List[str]]]:
    """Tokenize texts using BERT WordPiece tokenizer implemented in Rust.

    Arguments:
        texts {List[str]} -- Text data to tokenize
        tokenizer {BertWordPieceTokenizer}
            -- A BertWordPieceTokenizer from the `tokenizers` library
        flat_map {bool} -- If True, flat maps results into a List[str],
                           instead of List[List[str]].

    Returns:
        A flat list of tokens if `flat_map` is True, otherwise a list of
        token lists (one per input text).
    """
    # Validate that the tokenizer comes from the `tokenizers` library
    if not hasattr(tokenizer, 'encode_batch'):
        raise AttributeError('Provided `tokenizer` is not from the '
                             '`tokenizers` library.')

    if flat_map:
        tokenized = [
            t for enc in tokenizer.encode_batch(texts) for t in enc.tokens
        ]
    else:
        tokenized = [enc.tokens for enc in tokenizer.encode_batch(texts)]
    return tokenized
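
A short usage sketch for the helper above; the vocab path and sample texts are illustrative, not part of the original:

from tokenizers import BertWordPieceTokenizer

wp_tokenizer = BertWordPieceTokenizer('bert-base-uncased-vocab.txt', lowercase=True)
nested = tokenize(["Hello world!", "Tokenizers are fast."], wp_tokenizer)
flat = tokenize(["Hello world!", "Tokenizers are fast."], wp_tokenizer, flat_map=True)
# nested -> [['[CLS]', 'hello', 'world', '!', '[SEP]'], ...]
# flat   -> ['[CLS]', 'hello', 'world', '!', '[SEP]', '[CLS]', ...]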
Example #2
class BertTokenizer():
    def __init__(self, newa_vocab_path, eng_vocab_path):
        self.src_tokenizer = BertWordPieceTokenizer(newa_vocab_path,
                                                    lowercase=True)
        self.tgt_tokenizer = BertWordPieceTokenizer(eng_vocab_path,
                                                    lowercase=True)

    def encode(self, src_sents, tgt_sents=None, return_tensor=False):
        # Note: `encode_batch` in the `tokenizers` library does not accept a
        # `return_tensor` argument; converting to tensors is left to the caller.
        src_tokens = self.src_tokenizer.encode_batch(src_sents)
        tgt_tokens = None
        if tgt_sents is not None:
            tgt_tokens = self.tgt_tokenizer.encode_batch(tgt_sents)
        return src_tokens, tgt_tokens

    def decode(self, src_ids, tgt_ids, return_tensor=False):
        pass

    @staticmethod
    def create_vocab(file_path, output_path, least_freq=2):
        tokenizer = BertWordPieceTokenizer(clean_text=False,
                                           strip_accents=False,
                                           lowercase=True)
        files = [file_path]
        tokenizer.train(files,
                        vocab_size=1000,
                        min_frequency=least_freq,
                        show_progress=True,
                        special_tokens=['[PAD]', '[UNK]', '[SOS]', '[EOS]'],
                        limit_alphabet=1000,
                        wordpieces_prefix="##")
        tokenizer.save(output_path)
        print(f"Vacabulary created at location {output_path}")
Example #3
    def __init__(self, t: PreTrainedTokenizer, args, file_path: str, block_size=512):
        assert os.path.isfile(file_path)
        logger.info("Creating features from dataset file at %s", file_path)
        
        # -------------------------- CHANGES START
        bert_tokenizer = os.path.join(args.tokenizer_name, "vocab.txt")
        if os.path.exists(bert_tokenizer):
            logger.info("Loading BERT tokenizer")
            from tokenizers import BertWordPieceTokenizer
            tokenizer = BertWordPieceTokenizer(
                os.path.join(args.tokenizer_name, "vocab.txt"),
                handle_chinese_chars=False, lowercase=False)
            tokenizer.enable_truncation(512)
        else:
            from tokenizers import ByteLevelBPETokenizer
            from tokenizers.processors import BertProcessing
            logger.info("Loading RoBERTa tokenizer")
            
            tokenizer = ByteLevelBPETokenizer(
                os.path.join(args.tokenizer_name, "vocab.json"),
                os.path.join(args.tokenizer_name, "merges.txt")
            )
            tokenizer._tokenizer.post_processor = BertProcessing(
                ("</s>", tokenizer.token_to_id("</s>")),
                ("<s>", tokenizer.token_to_id("<s>")),
            )
            tokenizer.enable_truncation(max_length=512)

        logger.info("Reading file %s", file_path)
        with open(file_path, encoding="utf-8") as f:
            lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

        logger.info("Running tokenization")
        self.examples = tokenizer.encode_batch(lines)
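
    # The constructor above keeps the raw Encoding objects in self.examples.
    # A minimal sketch (assumed, not part of the original) of the accessor
    # methods a torch Dataset built this way typically adds; assumes `torch`
    # is imported at module level and padding happens in the collate_fn.
    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return torch.tensor(self.examples[i].ids, dtype=torch.long)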
Example #4
def main():
    if params.model == 'bert':
        CLS_TOKEN = '[CLS]'
        SEP_TOKEN = "[SEP]"
        UNK_TOKEN = "[UNK]"
        PAD_TOKEN = "[PAD]"
        MASK_TOKEN = "[MASK]"
    elif params.model == 'roberta':
        CLS_TOKEN = '<s>'
        SEP_TOKEN = '</s>'
        UNK_TOKEN = '<unk>'
        PAD_TOKEN = '<pad>'
        MASK_TOKEN = '<mask>'
    else:
        raise ValueError(f'Unsupported model: {params.model}')

    tokenizer = BertWordPieceTokenizer(
        params.vocab,
        unk_token=UNK_TOKEN,
        sep_token=SEP_TOKEN,
        cls_token=CLS_TOKEN,
        pad_token=PAD_TOKEN,
        mask_token=MASK_TOKEN,
        lowercase=False,
        strip_accents=False)

    with open(params.input, 'r') as reader, open(params.output, 'w') as writer:
        sentences = reader.readlines()
        for batch in chunks(sentences, 1000):
            encoded = tokenizer.encode_batch(batch)
            token_lines = map(lambda x: ' '.join(x.tokens[1:-1]), encoded)
            writer.write('\n'.join(token_lines) + '\n')
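
The chunks helper used above is not shown in the snippet; a minimal sketch of what it is assumed to do (yield fixed-size slices of the sentence list):

def chunks(items, n):
    # Yield successive n-sized slices of a list.
    for i in range(0, len(items), n):
        yield items[i:i + n]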
Example #5
    def convert_to_ratt(self,
                        ratt_dir,
                        do_lower=True,
                        max_sequence_length=128,
                        data_type="train"):
        if not os.path.exists(ratt_dir):
            os.mkdir(ratt_dir)
        # Build dictionary
        text_list, label_list = self._read_csv(self.raw_data_file)

        # Token vocab
        token_vocab_name = "ratt"
        vocab_file = os.path.join(ratt_dir, token_vocab_name + "-vocab.txt")
        if not os.path.isfile(vocab_file):
            tokenizer = BertWordPieceTokenizer(lowercase=do_lower)
            tokenizer.train(files=[self.raw_data_file], vocab_size=8192)
            tokenizer.save_model(ratt_dir, token_vocab_name)
        else:
            tokenizer = BertWordPieceTokenizer(vocab_file=vocab_file,
                                               lowercase=do_lower)

        # Label vocab
        label_vocab_file = os.path.join(ratt_dir, "label_dict.txt")
        if not os.path.isfile(label_vocab_file):
            labels = set(label_list)
            label_map = {str(l): i for i, l in enumerate(labels)}
            with open(label_vocab_file, "w", encoding="utf-8") as fout:
                for l in labels:
                    fout.write("%s\n" % l)
        else:
            label_map = {}
            with open(label_vocab_file, encoding="utf-8") as fin:
                for i, line in enumerate(fin):
                    label_map[line.rstrip()] = i

        if data_type not in ["train", "dev", "test"]:
            data_types = ["train", "dev", "test"]
        else:
            data_types = [data_type]

        for data_type in data_types:
            raw_file = getattr(self, "raw_%s_file" % data_type)
            logging.info("Converting %s.." % raw_file)
            text_list, label_list = self._read_csv(raw_file)

            outputs = tokenizer.encode_batch(text_list,
                                             add_special_tokens=True)
            input_ids = [output.ids for output in outputs]
            padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(
                input_ids,
                padding="post",
                maxlen=max_sequence_length,
                truncating="post")

            label_ids = [label_map[str(label)] for label in label_list]
            save_file = os.path.join(ratt_dir, data_type + ".npz")
            np.savez(save_file, inputs=padded_inputs, targets=label_ids)
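
A short sketch (not part of the original) of reading one of the saved splits back; `ratt_dir` is the same directory used above:

# Illustrative only: load the padded inputs and label ids written by convert_to_ratt.
data = np.load(os.path.join(ratt_dir, "train.npz"))
train_inputs, train_targets = data["inputs"], data["targets"]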
Example #6
def process_chunk(chunk_no, block_offset, no_lines, config):
    # Load lines
    doc_ids = []
    full_texts = []
    docs_path = os.path.join(config["data_home"], "docs/msmarco-docs.tsv")
    with open(docs_path, encoding="utf-8") as f:
        f.seek(block_offset[chunk_no])
        for i in tqdm(range(no_lines),
                      desc="Loading block for {}".format(chunk_no)):
            line = f.readline()
            try:
                doc_id, url, title, text = line[:-1].split("\t")
            except (IndexError, ValueError):
                continue
            doc_ids.append(doc_id)
            full_texts.append(" ".join([url, title, text]))
    # tokenizer = DistilBertTokenizer.from_pretrained(config["bert_class"])
    tokenizer = BertWordPieceTokenizer(config["tokenizer_vocab_path"],
                                       lowercase=True)
    output_line_format = "{}\t{}\n"
    trec_format = "<DOC>\n<DOCNO>{}</DOCNO>\n<TEXT>{}</TEXT></DOC>\n"
    partial_doc_path = os.path.join(config["data_home"], "tmp",
                                    "docs-{}".format(chunk_no))
    partial_doc_path_bert = os.path.join(config["data_home"], "tmp",
                                         "docs-{}.bert".format(chunk_no))
    partial_trec_path = os.path.join(config["data_home"], "tmp",
                                     "trec_docs-{}".format(chunk_no))

    with open(partial_doc_path, 'w', encoding="utf-8") as outf, open(
            partial_trec_path, 'w', encoding="utf-8") as outf_trec, open(
                partial_doc_path_bert, 'w',
                encoding='utf-8') as outf_bert:  # noqa E501
        start = time.time()
        tokenized = tokenizer.encode_batch(full_texts)
        end = time.time()
        print("tokenizer {} finished in {}s".format(chunk_no, end - start), )
        for doc_id, sample in tqdm(zip(doc_ids, tokenized),
                                   desc="dumping tokenized docs to tmp file",
                                   total=len(tokenized)):  # noqa E501
            start = time.time()
            bert_text = sample.tokens[1:-1]
            tokenized_text = ' '.join(bert_text).replace("##", "")
            outf.write(output_line_format.format(doc_id, tokenized_text))
            outf_trec.write(trec_format.format(doc_id, tokenized_text))
            outf_bert.write("{}\t{}\n".format(doc_id, bert_text))
        outf.flush()
        outf_trec.flush()
        outf_bert.flush()
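
process_chunk expects precomputed block_offset and no_lines mappings; a sketch of how they could be built (the helper name and chunk size are assumptions, not part of the original):

def compute_block_offsets(docs_path, lines_per_chunk=500000):
    # Record the byte offset where each chunk of `lines_per_chunk` lines starts,
    # so each worker can seek() directly to its block.
    block_offset, no_lines = {}, {}
    chunk_no, count = 0, 0
    with open(docs_path, encoding="utf-8") as f:
        block_offset[chunk_no] = f.tell()
        line = f.readline()
        while line:
            count += 1
            if count == lines_per_chunk:
                no_lines[chunk_no] = count
                chunk_no += 1
                count = 0
                block_offset[chunk_no] = f.tell()
            line = f.readline()
    if count:
        no_lines[chunk_no] = count
    else:
        block_offset.pop(chunk_no, None)
    return block_offset, no_lines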
Example #7
def tokenize(texts: pd.Series,
             tokenizer: BertWordPieceTokenizer,
             chunk_size: int = 240,
             maxlen: int = 512) -> np.ndarray:
    '''Tokenize the input texts and return their token ids as a 2-D array.'''
    tokenizer.enable_truncation(max_length=maxlen)
    try:
        tokenizer.enable_padding(max_length=maxlen)
    except TypeError:
        tokenizer.enable_padding(length=maxlen)
    all_ids = []

    for i in range(0, len(texts), chunk_size):
        text_chunk = texts[i:i + chunk_size].tolist()
        encs = tokenizer.encode_batch(text_chunk)
        all_ids.extend([enc.ids for enc in encs])

    return np.array(all_ids)
Example #8
def get_preds(list_of_texts):
    transformer_layer = (transformers.TFDistilBertModel.from_pretrained(
        'distilbert-base-multilingual-cased'))

    model = build_model(transformer_layer, max_len=MAX_LEN)
    model.load_weights('model/weights')

    #model = tf.keras.models.load_model('model')

    print('weights loaded')

    tokenizer = transformers.DistilBertTokenizer.from_pretrained(
        'distilbert-base-multilingual-cased')
    tokenizer.save_pretrained('.')
    # Reload it with the huggingface tokenizers library
    fast_tokenizer = BertWordPieceTokenizer('vocab.txt', lowercase=False)

    fast_tokenizer.enable_truncation(max_length=MAX_LEN)
    fast_tokenizer.enable_padding(length=MAX_LEN)

    all_ids = []
    encs = fast_tokenizer.encode_batch(list_of_texts)
    all_ids.extend([enc.ids for enc in encs])

    all_ids = np.array(all_ids).astype(np.float32)

    to_predict = create_test(all_ids)

    predictions = model.predict(to_predict)
    #print(predictions*10)

    for prediction in predictions:
        print(prediction)

    dic = {'predictions': predictions}

    parsed = []
    #response = pd.DataFrame(dic)
    #parsed = response.to_json(orient = 'columns') #not sure if works
    #json.dumps(parsed)           #to be reviewed
    return parsed, predictions
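
build_model, create_test and MAX_LEN are defined elsewhere in the original script; a minimal sketch of what create_test is assumed to do (batch the token-id matrix for model.predict):

import tensorflow as tf

def create_test(all_ids, batch_size=32):
    # Assumed behaviour: wrap the id matrix in a batched tf.data.Dataset.
    return tf.data.Dataset.from_tensor_slices(all_ids).batch(batch_size)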
Example #9
# Create the tokenizer from a vocabulary file.
# lowercase=False: inputs are NOT converted to lowercase (original casing is kept).
# strip_accents=False: accents are kept.
tokenizer = BertWordPieceTokenizer("vocab.txt", lowercase=False, strip_accents=False)

# Show the tokenizer's configuration
print(tokenizer)

# Enable truncation and padding
tokenizer.enable_truncation(max_length=60)
tokenizer.enable_padding()


# Tokenize all sentences in one batch.
# `sentencas` is a NumPy array, so .tolist() converts it to a list first.
output = tokenizer.encode_batch(sentencas.tolist())

# encode_batch returns a list of Encoding objects.
# Each Encoding exposes the attributes .ids, .tokens and .attention_mask;
# loop over the list to collect the ids and attention masks.
ids = [x.ids for x in output]
attention_mask = [x.attention_mask for x in output]

print(len(ids))
print(len(attention_mask))

# Example output for the first line
print(output[0])
print(output[0].tokens)
Example #10
    for count, blob in enumerate(sub_blobs):
        data = blob.download_as_string()
        data = data.decode("utf-8")
        data = data.split("\n\n")
        flat_data = []
        for line in data:
            if len(line) > 100000:
                line = [
                    line[i:i + 100000] for i in range(0, len(line), 100000)
                ]
                flat_data.extend(line)
            else:
                flat_data.append(line)
        data = flat_data
        print(f"start tokenizing file {blob.name}")
        encoded = tokenizer.encode_batch(data)
        print(f"finish tokenizing file {blob.name}")

        # Prepare something for worker to do
        def generator():
            index = 0
            for item in encoded:
                yield (item, index)
                index += 1

        # Actual Work
        def worker(item):
            size = 0
            ids, masked_ids = [], []
            index = item[1]
            item = item[0]
Example #11
def numerize(vocab_path, input_path, bin_path):
    tokenizer = BertWordPieceTokenizer(vocab_path,
                                       unk_token=UNK_TOKEN,
                                       sep_token=SEP_TOKEN,
                                       cls_token=CLS_TOKEN,
                                       pad_token=PAD_TOKEN,
                                       mask_token=MASK_TOKEN,
                                       lowercase=False,
                                       strip_accents=False)
    sentences = []
    with open(input_path, 'r') as f:
        batch_stream = []
        for i, line in enumerate(f):
            batch_stream.append(line)
            if (i + 1) % 1000 == 0:
                res = tokenizer.encode_batch(batch_stream)
                batch_stream = []
                # flatten the list
                for s in res:
                    sentences.extend(s.ids[1:])
            if i % 100000 == 0:
                print(f'processed {i} lines')
        # encode whatever is left over after the last full batch
        if batch_stream:
            for s in tokenizer.encode_batch(batch_stream):
                sentences.extend(s.ids[1:])

    print('convert the data to numpy')

    # convert data to numpy format in uint16
    if tokenizer.get_vocab_size() < 1 << 16:
        sentences = np.uint16(sentences)
    else:
        assert tokenizer.get_vocab_size() < 1 << 31
        sentences = np.int32(sentences)

    # save special tokens for later processing
    sep_index = tokenizer.token_to_id(SEP_TOKEN)
    cls_index = tokenizer.token_to_id(CLS_TOKEN)
    unk_index = tokenizer.token_to_id(UNK_TOKEN)
    mask_index = tokenizer.token_to_id(MASK_TOKEN)
    pad_index = tokenizer.token_to_id(PAD_TOKEN)

    # sanity check
    assert sep_index == SEP_INDEX
    assert cls_index == CLS_INDEX
    assert unk_index == UNK_INDEX
    assert pad_index == PAD_INDEX
    assert mask_index == MASK_INDEX

    print('collect statistics')
    # collect some statistics of the dataset
    n_unks = (sentences == unk_index).sum()
    n_toks = len(sentences)
    p_unks = n_unks * 100. / n_toks
    n_seqs = (sentences == sep_index).sum()
    print(
        f'| {n_seqs} sentences - {n_toks} tokens - {p_unks:.2f}% unknown words'
    )

    # print some statistics
    data = {
        'sentences': sentences,
        'sep_index': sep_index,
        'cls_index': cls_index,
        'unk_index': unk_index,
        'pad_index': pad_index,
        'mask_index': mask_index
    }

    torch.save(data, bin_path, pickle_protocol=4)
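
The *_TOKEN and *_INDEX constants used by numerize are defined elsewhere in the original script; a sketch of plausible module-level definitions, using the token ids of the standard bert-base-uncased vocab (they must match the vocab file actually passed in, or the sanity-check asserts will fail):

UNK_TOKEN, SEP_TOKEN, CLS_TOKEN, PAD_TOKEN, MASK_TOKEN = (
    '[UNK]', '[SEP]', '[CLS]', '[PAD]', '[MASK]')
# Ids from bert-base-uncased's vocab.txt; adjust for other vocabularies.
PAD_INDEX, UNK_INDEX, CLS_INDEX, SEP_INDEX, MASK_INDEX = 0, 100, 101, 102, 103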
Example #12
def main():
    start_time = time.time()
    args = parse_args()
    make_directories(args.output_dir)

    # Start Tensorboard and log hyperparams.
    tb_writer = SummaryWriter(args.output_dir)
    tb_writer.add_hparams(vars(args), {})

    file_log_handler = logging.FileHandler(
        os.path.join(args.output_dir, 'log.txt'))
    logger.addHandler(file_log_handler)

    # Get list of text and list of label (integers) from disk.
    train_text, train_label_id_list, eval_text, eval_label_id_list = \
        get_examples_and_labels(args.dataset)

    # Augment training data.
    if (args.augmentation_recipe is not None) and len(
            args.augmentation_recipe):
        import pandas as pd

        if args.augmentation_recipe == 'textfooler':
            aug_csv = '/p/qdata/jm8wx/research/text_attacks/textattack/outputs/attack-1590551967800.csv'
        elif args.augmentation_recipe == 'tf-adjusted':
            aug_csv = '/p/qdata/jm8wx/research/text_attacks/textattack/outputs/attack-1590564015768.csv'
        else:
            raise ValueError(
                f'Unknown augmentation recipe {args.augmentation_recipe}')

        aug_df = pd.read_csv(aug_csv)

        # filter skipped outputs
        aug_df = aug_df[aug_df['original_text'] != aug_df['perturbed_text']]

        print(
            f'Augmentation recipe {args.augmentation_recipe} / augmentation num. examples {args.augmentation_num}/ len {len(aug_df)}'
        )

        original_text = aug_df['original_text']
        perturbed_text = aug_df['perturbed_text']

        # convert `train_text` and `train_label_id_list` to an np array so things are faster
        train_text = np.array(train_text)
        train_label_id_list = np.array(train_label_id_list)

        x_adv_list = []
        x_adv_id_list = []
        for (x, x_adv) in zip(original_text, perturbed_text):
            x = x.replace('[[', '').replace(']]', '')
            x_adv = x_adv.replace('[[', '').replace(']]', '')
            x_idx = (train_text == x).nonzero()[0][0]
            x_adv_label = train_label_id_list[x_idx]
            x_adv_id_list.append(x_adv_label)
            x_adv_list.append(x_adv)

        # truncate to `args.augmentation_num` examples
        if (args.augmentation_num >= 0):
            perm = list(range(len(x_adv_list)))
            random.shuffle(perm)
            perm = perm[:args.augmentation_num]
            x_adv_list = [x_adv_list[i] for i in perm]
            x_adv_id_list = [x_adv_id_list[i] for i in perm]

        train_text = train_text.tolist() + x_adv_list
        train_label_id_list = train_label_id_list.tolist() + x_adv_id_list

        print(
            f'Augmentation added {len(x_adv_list)} examples, for a total of {len(train_text)}'
        )

    label_id_len = len(train_label_id_list)
    num_labels = len(set(train_label_id_list))
    logger.info('num_labels: %s', num_labels)

    train_examples_len = len(train_text)

    if len(train_label_id_list) != train_examples_len:
        raise ValueError(
            f'Number of train examples ({train_examples_len}) does not match number of labels ({len(train_label_id_list)})'
        )
    if len(eval_label_id_list) != len(eval_text):
        raise ValueError(
            f'Number of test examples ({len(eval_text)}) does not match number of labels ({len(eval_label_id_list)})'
        )

    print_cuda_memory(args)
    # old INFO:__main__:Loaded data and tokenized in 189.66675066947937s

    # @TODO support other vocabularies, or at least, support case
    tokenizer = BertWordPieceTokenizer('bert-base-uncased-vocab.txt',
                                       lowercase=True)
    tokenizer.enable_padding(max_length=args.max_seq_len)
    tokenizer.enable_truncation(max_length=args.max_seq_len)

    logger.info(f'Tokenizing training data. (len: {train_examples_len})')
    train_text_ids = [
        encoding.ids for encoding in tokenizer.encode_batch(train_text)
    ]
    logger.info(f'Tokenizing test data (len: {len(eval_label_id_list)})')
    eval_text_ids = [
        encoding.ids for encoding in tokenizer.encode_batch(eval_text)
    ]
    load_time = time.time()
    logger.info(f'Loaded data and tokenized in {load_time-start_time}s')

    print_cuda_memory(args)

    # Load pre-trained model tokenizer (vocabulary)
    logger.info('Loading model: %s', args.model_dir)
    # Load pre-trained model (weights)
    logger.info(f'Model class: (vanilla) BertForSequenceClassification.')
    model = BertForSequenceClassification.from_pretrained(
        args.model_dir, num_labels=num_labels)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    model.to(device)
    # print(model)

    # multi-gpu training
    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model)
    logger.info(f'Training model across {args.num_gpus} GPUs')

    num_train_optimization_steps = int(
        train_examples_len / args.batch_size /
        args.grad_accum_steps) * args.num_train_epochs

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01,
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_proportion,
        num_training_steps=num_train_optimization_steps)

    global_step = 0

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", train_examples_len)
    logger.info("  Batch size = %d", args.batch_size)
    logger.info("  Max sequence length = %d", args.max_seq_len)
    logger.info("  Num steps = %d", num_train_optimization_steps)

    wandb.log({'train_examples_len': train_examples_len})

    train_input_ids = torch.tensor(train_text_ids, dtype=torch.long)
    train_label_ids = torch.tensor(train_label_id_list, dtype=torch.long)
    train_data = TensorDataset(train_input_ids, train_label_ids)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size)

    eval_input_ids = torch.tensor(eval_text_ids, dtype=torch.long)
    eval_label_ids = torch.tensor(eval_label_id_list, dtype=torch.long)
    eval_data = TensorDataset(eval_input_ids, eval_label_ids)
    eval_sampler = RandomSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.batch_size)

    def get_eval_acc():
        correct = 0
        total = 0
        for input_ids, label_ids in tqdm.tqdm(eval_dataloader,
                                              desc="Evaluating accuracy"):
            input_ids = input_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids)[0]

            correct += (logits.argmax(dim=1) == label_ids).sum()
            total += len(label_ids)

        return float(correct) / total

    def save_model():
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, args.weights_name)
        output_config_file = os.path.join(args.output_dir, args.config_name)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)

        logger.info(
            f'Best acc found. Saved tokenizer, model config, and model to {args.output_dir}.'
        )

    global_step = 0

    def save_model_checkpoint(checkpoint_name=None):
        # Save model checkpoint
        checkpoint_name = checkpoint_name or 'checkpoint-{}'.format(
            global_step)
        output_dir = os.path.join(args.output_dir, checkpoint_name)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # Take care of distributed/parallel training
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained(output_dir)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))
        logger.info('Checkpoint saved to %s.', output_dir)

    print_cuda_memory(args)
    model.train()
    best_eval_acc = 0
    steps_since_best_eval_acc = 0

    def loss_backward(loss):
        if args.num_gpus > 1:
            loss = loss.mean(
            )  # mean() to average on multi-gpu parallel training
        if args.grad_accum_steps > 1:
            loss = loss / args.grad_accum_steps
        loss.backward()

    for epoch in tqdm.trange(int(args.num_train_epochs), desc="Epoch"):
        prog_bar = tqdm.tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(prog_bar):
            print_cuda_memory(args)
            batch = tuple(t.to(device) for t in batch)
            input_ids, labels = batch
            logits = model(input_ids)[0]
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
            if global_step % args.tb_writer_step == 0:
                tb_writer.add_scalar('loss', loss, global_step)
                tb_writer.add_scalar('lr', scheduler.get_last_lr()[0],
                                     global_step)
            loss_backward(loss)
            prog_bar.set_description(f"Loss {loss.item()}")
            if (step + 1) % args.grad_accum_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                # Save model checkpoint to file.
                if global_step % args.checkpoint_steps == 0:
                    save_model_checkpoint()

        # Check accuracy after each epoch.
        eval_acc = get_eval_acc()
        tb_writer.add_scalar('epoch_eval_acc', eval_acc, global_step)
        wandb.log({'epoch_eval_acc': eval_acc, 'epoch': epoch})

        if args.checkpoint_every_epoch:
            save_model_checkpoint(f'epoch-{epoch}')

        logger.info(f'Eval acc: {eval_acc*100}%')
        if eval_acc > best_eval_acc:
            best_eval_acc = eval_acc
            steps_since_best_eval_acc = 0
            save_model()
        else:
            steps_since_best_eval_acc += 1
            if (args.early_stopping_epochs > 0) and (
                    steps_since_best_eval_acc > args.early_stopping_epochs):
                logger.info(
                    f'Stopping early since it\'s been {args.early_stopping_epochs} epochs since validation acc increased'
                )
                break