Example #1
def test(model, epoch, batch_size, data_path, fold, gpu, dicts, freq_params,
         model_dir, testing, debug):
    """
        Testing loop.
        Returns metrics
    """
    filename = data_path.replace('train', fold)
    print('file for evaluation: %s' % filename)
    num_labels = tools.get_num_labels()

    raw_text = []
    y, yhat, yhat_raw, losses = [], [], [], []
    ind2w, w2ind, ind2l, l2ind = dicts[0], dicts[1], dicts[2], dicts[3]

    model.eval()
    gen = datasets.data_generator(filename, dicts, batch_size, num_labels)
    for batch_idx, tup in tqdm(enumerate(gen)):
        if debug and batch_idx > 50:
            break
        data, label, raw_t = tup
        for item in raw_t:
            raw_text.append(item)
        target = label
        data, target = Variable(torch.LongTensor(data),
                                volatile=True), Variable(
                                    torch.FloatTensor(target))
        if gpu:
            data = data.cuda()
            target = target.cuda()
        model.zero_grad()

        output, loss, _ = model(data, target)  # for conv_attn model

        output = output.data.cpu().numpy()
        losses.append(loss.data[0])
        target_data = target.data.cpu().numpy()

        yhat_raw.append(output)
        output = np.round(output)
        y.append(target_data)
        yhat.append(output)

    y = np.concatenate(y, axis=0)
    yhat = np.concatenate(yhat, axis=0)
    yhat_raw = np.concatenate(yhat_raw, axis=0)

    print("y shape: " + str(y.shape))
    print("yhat shape: " + str(yhat.shape))

    # get metrics
    k = 5
    metrics, ht_at_k_val, ht_at_1_val = evaluate.all_metrics(yhat,
                                                             y,
                                                             k=k,
                                                             yhat_raw=yhat_raw)
    evaluate.print_metrics(metrics)
    metrics['loss_%s' % fold] = np.mean(losses)

    return metrics
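
A minimal, self-contained sketch (dummy NumPy data only, not real model outputs) of the accumulation pattern used above: per-batch probabilities are appended, rounded at a 0.5 threshold, and concatenated into (num_examples, num_labels) arrays before they are handed to the metrics code.

import numpy as np

# dummy per-batch scores and gold labels for a 2-label problem, two batches
yhat_raw_batches = [np.array([[0.9, 0.2], [0.4, 0.7]]),
                    np.array([[0.1, 0.8]])]
y_batches = [np.array([[1, 0], [0, 1]]),
             np.array([[0, 1]])]

yhat_raw = np.concatenate(yhat_raw_batches, axis=0)  # (3, 2) raw scores
yhat = np.round(yhat_raw)                            # 0/1 predictions at a 0.5 threshold
y = np.concatenate(y_batches, axis=0)                # (3, 2) gold labels
assert yhat.shape == y.shape == yhat_raw.shape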
Example #2
def train(model, optimizer, Y, epoch, batch_size, data_path, gpu, version,
          dicts, quiet):
    """
        Training loop.
        output: losses for each example for this iteration
    """
    print("EPOCH %d" % epoch)
    num_labels = len(dicts['ind2c'])

    losses = []
    # how often to print some info to stdout
    print_every = 25

    ind2w, w2ind, ind2c, c2ind = dicts['ind2w'], dicts['w2ind'], dicts[
        'ind2c'], dicts['c2ind']
    unseen_code_inds = set(ind2c.keys())
    desc_embed = model.lmbda > 0

    model.train()
    gen = datasets.data_generator(data_path,
                                  dicts,
                                  batch_size,
                                  num_labels,
                                  version=version,
                                  desc_embed=desc_embed)
    for batch_idx, tup in tqdm(enumerate(gen)):
        data, target, _, code_set, descs = tup
        # print("data size " + str(data.shape))
        # print("target size " + str(target.shape))
        data, target = Variable(torch.LongTensor(data)), Variable(
            torch.FloatTensor(target))
        unseen_code_inds = unseen_code_inds.difference(code_set)
        if gpu:
            data = data.cuda()
            target = target.cuda()
        optimizer.zero_grad()

        if desc_embed:
            desc_data = descs
        else:
            desc_data = None

        output, loss, _ = model(data, target)
        # print(output.shape)

        loss.backward()
        optimizer.step()

        losses.append(loss.data.item())

        if not quiet and batch_idx % print_every == 0:
            # print the average loss of the last 10 batches
            print(
                "Train epoch: {} [batch #{}, batch_size {}, seq length {}]\tLoss: {:.6f}"
                .format(epoch, batch_idx,
                        data.size()[0],
                        data.size()[1], np.mean(losses[-10:])))
    return losses, unseen_code_inds
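
A minimal sketch of how a caller might drive this training loop across epochs; n_epochs and the other argument names are hypothetical stand-ins for values built elsewhere in the calling script.

import numpy as np

for epoch in range(n_epochs):  # n_epochs assumed to be defined by the caller
    losses, unseen_code_inds = train(model, optimizer, Y, epoch, batch_size,
                                     data_path, gpu, version, dicts, quiet)
    print("epoch %d mean loss: %.6f" % (epoch, np.mean(losses)))
    print("codes never seen during training: %d" % len(unseen_code_inds))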
Example #3
def train(model, optimizer, Y, epoch, batch_size, data_path, gpu, version,
          freq_params, dicts, debug, quiet):
    """
        Training loop.
        output: losses for each example for this iteration
    """
    num_labels = tools.get_num_labels(Y, version)

    losses = []
    #how often to print some info to stdout
    print_every = 25

    ind2w, w2ind, ind2c, c2ind = dicts[0], dicts[1], dicts[2], dicts[3]
    unseen_code_inds = set(ind2c.keys())
    desc_embed = model.lmbda > 0

    model.train()  # PUTS MODEL IN TRAIN MODE
    gen = datasets.data_generator(data_path,
                                  dicts,
                                  batch_size,
                                  num_labels,
                                  version=version,
                                  desc_embed=desc_embed)
    for batch_idx, tup in tqdm(enumerate(gen)):
        if debug and batch_idx > 50:  # LIKELY NOT NEEDED
            break
        data, target, _, code_set, descs = tup
        data, target = Variable(torch.LongTensor(data)), Variable(
            torch.FloatTensor(target))
        unseen_code_inds = unseen_code_inds.difference(code_set)
        if gpu:
            data = data.cuda()
            target = target.cuda()
        optimizer.zero_grad()

        if desc_embed:
            desc_data = descs
        else:
            desc_data = None

        output, loss, _ = model(data, target,
                                desc_data=desc_data)  # FORWARD PASS

        loss.backward()
        optimizer.step()

        losses.append(loss.data[0])

        if not quiet and batch_idx % print_every == 0:
            #print the average loss of the last 100 batches
            print(
                "Train epoch: {} [batch #{}, batch_size {}, seq length {}]\tLoss: {:.6f}"
                .format(epoch + 1, batch_idx,
                        data.size()[0],
                        data.size()[1], np.mean(losses[-100:])))
    return losses, unseen_code_inds
Example #4
def slim_data_generator(data_path):
    while 1:
        for batch_idx, tup in enumerate(
                datasets.data_generator(data_path,
                                        dicts,
                                        batch_size=batch_size,
                                        num_labels=nb_labels)):
            X, y, _, code_set, descs = tup
            X = sequence.pad_sequences(X, maxlen=maxlen)
            yield X, y
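
A quick sanity check one might run on this generator before training; train_data_path, batch_size, maxlen and nb_labels are assumed to be defined as in the surrounding script.

gen = slim_data_generator(train_data_path)
for step, (X, y) in zip(range(2), gen):  # pull two padded batches
    print(step, X.shape, y.shape)        # expect (batch_size, maxlen) and (batch_size, nb_labels)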
Example #5
def one_epoch(model, optimizer, Y, epoch, n_epochs, batch_size, data_path,
              version, freq_params, testing, dicts, model_dir,
              unseen_code_inds, samples, gpu, debug, quiet):
    """
        Basically a wrapper to do a training epoch and test on dev
    """
    if not testing:
        losses, unseen_code_inds = train(model, optimizer, Y, epoch,
                                         batch_size, data_path, gpu, version,
                                         freq_params, dicts, debug, quiet)
        print('num unseen codes', len(unseen_code_inds))
        loss = np.mean(losses)
        print("epoch loss: " + str(loss))
    else:
        loss = np.nan
        if model.lmbda > 0:
            #still need to get unseen code inds
            unseen_code_inds = set(dicts[2].keys())
            num_labels = tools.get_num_labels(Y, version)
            gen = datasets.data_generator(data_path,
                                          dicts,
                                          batch_size,
                                          num_labels,
                                          version=version)
            for _, _, _, code_set, _ in gen:
                unseen_code_inds = unseen_code_inds.difference(code_set)
            print("num codes not in train set: %d" % len(unseen_code_inds))
        else:
            unseen_code_inds = set()

    fold = 'test' if version == 'mimic2' else 'dev'
    if epoch == n_epochs - 1:
        print("last epoch: testing on test and train sets")
        testing = True
        quiet = False

    #test on dev
    metrics = test(model, Y, epoch, batch_size, data_path, fold, gpu, version,
                   unseen_code_inds, dicts, samples, freq_params, model_dir,
                   testing, debug)
    if testing or epoch == n_epochs - 1:
        print("evaluating on test")
        metrics_te = test(model, Y, epoch, batch_size, data_path, "test", gpu,
                          version, unseen_code_inds, dicts, samples,
                          freq_params, model_dir, True, debug)
    else:
        metrics_te = defaultdict(float)
        fpr_te = defaultdict(lambda: [])
        tpr_te = defaultdict(lambda: [])
    metrics_tr = {'loss': loss}
    metrics_all = (metrics, metrics_te, metrics_tr)
    return metrics_all, unseen_code_inds
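
A minimal sketch of the outer loop that would call this wrapper once per epoch; all argument names are assumed to come from the caller's setup code.

unseen_code_inds = set()
metrics_hist = []
for epoch in range(n_epochs):
    metrics_all, unseen_code_inds = one_epoch(model, optimizer, Y, epoch,
                                              n_epochs, batch_size, data_path,
                                              version, freq_params, testing,
                                              dicts, model_dir,
                                              unseen_code_inds, samples, gpu,
                                              debug, quiet)
    metrics_hist.append(metrics_all)  # (dev, test, train) metric dicts per epoch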
Example #6
def train(model, optimizer, epoch, batch_size, data_path, gpu, freq_params,
          dicts, debug, quiet):
    """
        Training loop.
        output: losses for each example for this iteration
    """
    num_labels = tools.get_num_labels()

    losses = []
    # how often to print some info to stdout
    print_every = 25
    ind2w, w2ind, ind2l, l2ind = dicts[0], dicts[1], dicts[2], dicts[3]

    model.train()
    gen = datasets.data_generator(data_path, dicts, batch_size, num_labels)
    for batch_idx, tup in tqdm(enumerate(gen)):
        if debug and batch_idx > 50:
            break
        data, label, _ = tup
        target = label
        data, target = Variable(torch.LongTensor(data)), Variable(
            torch.FloatTensor(target))
        if gpu:
            data = data.cuda()
            target = target.cuda()
        optimizer.zero_grad()

        output, loss, _ = model(data, target)

        loss.backward()
        optimizer.step()

        losses.append(loss.data[0])

        if not quiet and batch_idx % print_every == 0:
            # print the average loss of the last 100 batches
            print(
                "Train epoch: {} [batch #{}, batch_size {}, seq length {}]\tLoss: {:.6f}"
                .format(epoch + 1, batch_idx,
                        data.size()[0],
                        data.size()[1], np.mean(losses[-100:])))
    return losses
Example #7
def test(model, Y, epoch, data_path, fold, gpu, version, code_inds, dicts,
         samples, model_dir, testing):
    """
        Testing loop.
        Returns metrics
    """
    filename = data_path.replace('train', fold)
    print('file for evaluation: %s' % filename)
    num_labels = len(dicts['ind2c'])

    #initialize stuff for saving attention samples
    if samples:
        tp_file = open('%s/tp_%s_examples_%d.txt' % (model_dir, fold, epoch),
                       'w')
        fp_file = open('%s/fp_%s_examples_%d.txt' % (model_dir, fold, epoch),
                       'w')
        window_size = model.conv.weight.data.size()[2]

    y, yhat, yhat_raw, hids, losses = [], [], [], [], []
    ind2w, w2ind, ind2c, c2ind = dicts['ind2w'], dicts['w2ind'], dicts[
        'ind2c'], dicts['c2ind']

    desc_embed = model.lmbda > 0
    if desc_embed and len(code_inds) > 0:
        unseen_code_vecs(model, code_inds, dicts, gpu)

    model.eval()
    gen = datasets.data_generator(filename,
                                  dicts,
                                  1,
                                  num_labels,
                                  version=version,
                                  desc_embed=desc_embed)
    for batch_idx, tup in tqdm(enumerate(gen)):
        data, target, hadm_ids, _, descs = tup
        data, target = Variable(torch.LongTensor(data),
                                volatile=True), Variable(
                                    torch.FloatTensor(target))
        if gpu:
            data = data.cuda()
            target = target.cuda()
        model.zero_grad()

        if desc_embed:
            desc_data = descs
        else:
            desc_data = None

        #get an attention sample for 2% of batches
        get_attn = samples and (np.random.rand() < 0.02 or
                                (fold == 'test' and testing))
        output, loss, alpha = model(data,
                                    target,
                                    desc_data=desc_data,
                                    get_attention=get_attn)

        output = F.sigmoid(output)
        output = output.data.cpu().numpy()
        losses.append(loss.item())
        target_data = target.data.cpu().numpy()
        if get_attn and samples:
            interpret.save_samples(data,
                                   output,
                                   target_data,
                                   alpha,
                                   window_size,
                                   epoch,
                                   tp_file,
                                   fp_file,
                                   dicts=dicts)

        #save predictions, target, hadm ids
        yhat_raw.append(output)
        output = np.round(output)
        y.append(target_data)
        yhat.append(output)
        hids.extend(hadm_ids)

    #close files if needed
    if samples:
        tp_file.close()
        fp_file.close()

    y = np.concatenate(y, axis=0)
    yhat = np.concatenate(yhat, axis=0)
    yhat_raw = np.concatenate(yhat_raw, axis=0)

    #write the predictions
    preds_file = persistence.write_preds(yhat, model_dir, hids, fold, ind2c,
                                         yhat_raw)
    #get metrics
    k = 5 if num_labels == 50 else [8, 15]
    metrics = evaluation.all_metrics(yhat, y, k=k, yhat_raw=yhat_raw)
    evaluation.print_metrics(metrics)
    metrics['loss_%s' % fold] = np.mean(losses)
    return metrics
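
Rounding the sigmoid output, as above, is just a 0.5 probability threshold, which is equivalent to thresholding the raw logits at zero; a small torch-only check of that equivalence with dummy values:

import torch
import numpy as np

logits = torch.tensor([[-2.0, 0.3], [1.5, -0.1]])  # dummy logits
probs = torch.sigmoid(logits).numpy()
assert np.array_equal(np.round(probs), (logits.numpy() > 0).astype(float))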
Example #8
def train(args, model, optimizer, epoch, dicts):
    """
        Training loop.
        output: losses for each example for this iteration
    """
    print("EPOCH %d" % epoch)
    num_labels = len(dicts['ind2c'])

    losses = []
    #how often to print some info to stdout
    print_every = 25

    ind2w, w2ind, ind2c, c2ind = dicts['ind2w'], dicts['w2ind'], dicts[
        'ind2c'], dicts['c2ind']
    unseen_code_inds = set(ind2c.keys())
    desc_embed = model.lmbda > 0

    model.train()
    gen = datasets.data_generator(args.data_path,
                                  dicts,
                                  args.batch_size,
                                  num_labels,
                                  desc_embed=desc_embed,
                                  version=args.version)

    for batch_idx, tup in tqdm(enumerate(gen)):

        old_word_embeds = model.embed.weight.data.cpu().numpy()

        data, target, _, code_set, descs = tup
        data, target = Variable(torch.LongTensor(data)), Variable(
            torch.FloatTensor(target))
        unseen_code_inds = unseen_code_inds.difference(code_set)
        if args.gpu:
            data = data.cuda()
            target = target.cuda()
        optimizer.zero_grad()

        if desc_embed:
            desc_data = descs
        else:
            desc_data = None

        output, loss, _ = model(data, target, desc_data=desc_data)

        loss.backward()
        optimizer.step()

        assert not np.array_equal(model.embed.weight.data.cpu().numpy(),
                                  old_word_embeds)
        #if not np.array_equal(model.embed.weight.data.cpu().numpy(), old_word_embeds):
        #	print("Weights updated")
        #else:
        #	print("No update")

        losses.append(loss.item())

        if not args.quiet and batch_idx % print_every == 0:
            #print the average loss of the last 10 batches
            print(
                "Train epoch: {} [batch #{}, batch_size {}, seq length {}]\tLoss: {:.6f}"
                .format(epoch, batch_idx,
                        data.size()[0],
                        data.size()[1], np.mean(losses[-10:])))
    return losses, unseen_code_inds
def test(model, epoch, batch_size, data_path, fold, gpu, dicts, samples,
         model_dir, testing, debug):
    """
        Testing loop.
        Returns metrics
    """
    filename = data_path.replace('train', fold)
    print('file for evaluation: %s' % filename)

    #    num_labels = tools.get_num_labels(Y, version)

    #initialize stuff for saving attention samples
    if samples:
        tp_file = open('%s/tp_%s_examples_%d.txt' % (model_dir, fold, epoch),
                       'w')
        fp_file = open('%s/fp_%s_examples_%d.txt' % (model_dir, fold, epoch),
                       'w')
        window_size = model.conv.weight.data.size()[2]

    y, yhat, yhat_raw, hids, losses = [], [], [], [], []

    #    ind2w, w2ind, ind2c, c2ind = dicts[0], dicts[1], dicts[2], dicts[3]
    ind2w, w2ind = dicts[0], dicts[1]

    #    desc_embed = model.lmbda > 0
    #    if desc_embed and len(code_inds) > 0:
    #        unseen_code_vecs(model, code_inds, dicts)

    model.eval()
    gen = datasets.data_generator(filename, dicts, batch_size)
    for batch_idx, tup in tqdm(enumerate(gen)):
        if debug and batch_idx > 50:
            break


        #        data, target, hadm_ids, _, descs = tup
        data, target, hadm_ids = tup
        data, target = Variable(torch.LongTensor(data),
                                volatile=True), Variable(
                                    torch.FloatTensor(target))
        if gpu:
            data = data.cuda()
            target = target.cuda()

        model.zero_grad()

        #        if desc_embed:
        #            desc_data = descs
        #        else:
        #            desc_data = None

        #        get_attn = samples and (np.random.rand() < 0.02 or (fold == 'test' and testing))

        #        output, loss, alpha = model(data, target, desc_data=desc_data, get_attention=get_attn)
        output, loss, alpha = model(data, target)

        output = output.data.cpu().numpy()
        losses.append(loss.data[0])
        target_data = target.data.cpu().numpy()

        #        if get_attn and samples:
        #            interpret.save_samples(data, output, target_data, alpha, window_size, epoch, tp_file, fp_file, freq_params[0], dicts=dicts)

        #save predictions, target, hadm ids
        yhat_raw.append(output)  # NEED TO KNOW FORM OF OUTPUT
        output = np.round(output)
        y.append(target_data)
        yhat.append(output)
        hids.extend(hadm_ids)

    if samples:
        tp_file.close()
        fp_file.close()

    y = np.concatenate(y, axis=0)
    yhat = np.concatenate(yhat, axis=0)
    yhat_raw = np.concatenate(yhat_raw, axis=0)

    print("y shape: " + str(y.shape))
    print("yhat shape: " + str(yhat.shape))

    #write the predictions
    #   preds_file = persistence.write_preds(yhat, model_dir, hids, fold, ind2c, yhat_raw)
    preds_file = persistence.write_preds(yhat, model_dir, hids, fold, yhat_raw)

    #get metrics
    #    k = 5 if num_labels == 50 else 8

    #    metrics = evaluation.all_metrics(yhat, y, k=k, yhat_raw=yhat_raw)
    metrics = evaluation.all_metrics(yhat, y, yhat_raw=yhat_raw)
    evaluation.print_metrics(metrics)
    metrics['loss_%s' % fold] = np.mean(losses)
    return metrics
Example #10
#     def __iter__(self):
#         return self

#     def __next__(self):
#         with self.lock:
#             for batch_idx, tup in enumerate(datasets.data_generator(self.data_path, dicts, batch_size=batch_size, num_labels=nb_labels)):
#                 X, y, _, code_set, descs = tup
#                 X = sequence.pad_sequences(X, maxlen=maxlen)
#                 return X, y

gen_slim_train = slim_data_generator(train_data_path)
gen_slim_test = slim_data_generator(test_data_path)
test_y = []
for batch_idx, tup in enumerate(
        datasets.data_generator(test_data_path,
                                dicts,
                                batch_size=batch_size,
                                num_labels=nb_labels)):
    data, target, hadm_ids, _, descs = tup
    test_y.append(target)
test_y = np.concatenate(test_y, axis=0)


# Helper functions
def count_samples(file_path):
    num_lines = 0
    with open(file_path) as f:
        r = csv.reader(f)
        next(r)
        for row in r:
            num_lines += 1
        return num_lines
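
A hedged sketch of how these pieces could be wired into Keras training: count_samples gives the number of rows, from which steps_per_epoch is derived, and the infinite slim generators are passed to fit_generator. keras_model, train_data_path, test_data_path and batch_size are assumed to exist in the surrounding script.

steps_per_epoch = max(1, count_samples(train_data_path) // batch_size)
validation_steps = max(1, count_samples(test_data_path) // batch_size)

keras_model.fit_generator(gen_slim_train,
                          steps_per_epoch=steps_per_epoch,
                          epochs=10,
                          validation_data=gen_slim_test,
                          validation_steps=validation_steps)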
Example #11
import pandas as pd

from datasets import data_generator
import numpy as np
import spank
from spank import functions as F

# %%

sdf = data_generator(num_sample=5)

sdf.head()

# %%
sdf.select_expr('name as name1', '*', 'type as style_type', sdf.feature1 + 1,
                F.Value(np.sin(sdf.feature2)).alias("feature4"), F.Value(1))

# %%
sdf.filter_expr((sdf.sex == 'female') & (sdf.feature2 > 0.1))

# %%

sdf
#%%
def test(args, model, Y, epoch, data_path, fold, gpu, version, code_inds,
         dicts, samples, model_dir, testing):
    """
        Testing loop.
        Returns metrics
    """
    filename = data_path.replace('train', fold)
    print('file for evaluation: %s' % filename)
    num_labels = len(dicts['ind2c'])

    #initialize stuff for saving attention samples
    if samples:
        tp_file = open('%s/tp_%s_examples_%d.txt' % (model_dir, fold, epoch),
                       'w')
        fp_file = open('%s/fp_%s_examples_%d.txt' % (model_dir, fold, epoch),
                       'w')
        window_size = model.conv.weight.data.size()[2]

    y, yhat, yhat_raw, hids, losses = [], [], [], [], []
    ind2w, w2ind, ind2c, c2ind = dicts['ind2w'], dicts['w2ind'], dicts[
        'ind2c'], dicts['c2ind']

    desc_embed = model.lmbda > 0
    if desc_embed and len(code_inds) > 0:
        unseen_code_vecs(model, code_inds, dicts, gpu)

    if args.model == 'bert':
        if args.redefined_tokenizer:
            bert_tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
                                                           do_lower_case=True)
        else:
            bert_tokenizer = BertTokenizer.from_pretrained(
                './pretrained_weights/bert-base-uncased-vocab.txt',
                do_lower_case=True)
    elif args.model == 'biobert':
        if args.redefined_tokenizer:
            bert_tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
                                                           do_lower_case=False)
        else:
            bert_tokenizer = BertTokenizer.from_pretrained(
                './pretrained_weights/biobert_pretrain_output_all_notes_150000/vocab.txt',
                do_lower_case=False)
    elif args.model == 'bert-tiny':
        if args.redefined_tokenizer:
            bert_tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
                                                           do_lower_case=True)
        else:
            bert_tokenizer = BertTokenizer.from_pretrained(
                './pretrained_weights/bert-tiny-uncased-vocab.txt',
                do_lower_case=True)
    else:
        bert_tokenizer = None

    model.eval()
    gen = datasets.data_generator(filename,
                                  dicts,
                                  1,
                                  num_labels,
                                  version=version,
                                  desc_embed=desc_embed,
                                  bert_tokenizer=bert_tokenizer,
                                  test=True,
                                  max_seq_length=args.max_sequence_length)
    for batch_idx, tup in tqdm(enumerate(gen)):
        data, target, hadm_ids, _, descs = tup
        data, target = torch.LongTensor(data), torch.FloatTensor(target)
        if gpu:
            data = data.cuda()
            target = target.cuda()

        if desc_embed:
            desc_data = descs
        else:
            desc_data = None

        if args.model in ['bert', 'biobert', 'bert-tiny']:
            token_type_ids = (data > 0).long() * 0
            attention_mask = (data > 0).long()
            position_ids = torch.arange(data.size(1)).expand(
                data.size(0), data.size(1))
            if gpu:
                position_ids = position_ids.cuda()
            position_ids = position_ids * (data > 0).long()
        else:
            attention_mask = (data > 0).long()
            token_type_ids = None
            position_ids = None

        if args.model in BERT_MODEL_LIST:
            with torch.no_grad():
                output, loss = model(input_ids=data, \
                                     token_type_ids=token_type_ids, \
                                     attention_mask=attention_mask, \
                                     position_ids=position_ids, \
                                     labels=target, \
                                     desc_data=desc_data, \
                                     pos_labels=None, \
                                    )
            output = torch.sigmoid(output)
            output = output.data.cpu().numpy()
        else:
            #decide whether to save an attention sample for ~2% of batches
            #before the forward pass, since get_attn is an input to it
            get_attn = samples and (np.random.rand() < 0.02 or
                                    (fold == 'test' and testing))

            with torch.no_grad():
                output, loss, alpha = model(data,
                                            target,
                                            desc_data=desc_data,
                                            get_attention=get_attn)

            output = torch.sigmoid(output)
            output = output.data.cpu().numpy()

            if get_attn and samples:
                target_data = target.data.cpu().numpy()
                interpret.save_samples(data,
                                       output,
                                       target_data,
                                       alpha,
                                       window_size,
                                       epoch,
                                       tp_file,
                                       fp_file,
                                       dicts=dicts)

        losses.append(loss.item())
        target_data = target.data.cpu().numpy()

        #save predictions, target, hadm ids
        yhat_raw.append(output)
        output = np.round(output)
        y.append(target_data)
        yhat.append(output)
        hids.extend(hadm_ids)

    # close files if needed
    if samples:
        tp_file.close()
        fp_file.close()

    y = np.concatenate(y, axis=0)
    yhat = np.concatenate(yhat, axis=0)
    yhat_raw = np.concatenate(yhat_raw, axis=0)

    #write the predictions
    preds_file = persistence.write_preds(yhat, model_dir, hids, fold, ind2c,
                                         yhat_raw)
    #get metrics
    k = 5 if num_labels == 50 else [8, 15]
    metrics = evaluation.all_metrics(yhat, y, k=k, yhat_raw=yhat_raw)
    evaluation.print_metrics(metrics)
    metrics['loss_%s' % fold] = np.mean(losses)
    return metrics
def train(args, model, optimizer, Y, epoch, batch_size, data_path, gpu,
          version, dicts, quiet, scheduler, labels_weight):
    """
        Training loop.
        output: losses for each example for this iteration
    """
    print("EPOCH %d" % epoch)
    num_labels = len(dicts['ind2c'])

    losses = []
    #how often to print some info to stdout
    print_every = 25

    ind2w, w2ind, ind2c, c2ind = dicts['ind2w'], dicts['w2ind'], dicts[
        'ind2c'], dicts['c2ind']
    unseen_code_inds = set(ind2c.keys())
    desc_embed = model.lmbda > 0

    if args.model == 'bert':
        if args.redefined_tokenizer:
            bert_tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
                                                           do_lower_case=True)
        else:
            bert_tokenizer = BertTokenizer.from_pretrained(
                './pretrained_weights/bert-base-uncased-vocab.txt',
                do_lower_case=True)
    elif args.model == 'biobert':
        if args.redefined_tokenizer:
            bert_tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
                                                           do_lower_case=False)
        else:
            bert_tokenizer = BertTokenizer.from_pretrained(
                './pretrained_weights/biobert_pretrain_output_all_notes_150000/vocab.txt',
                do_lower_case=False)
    elif args.model == 'bert-tiny':
        if args.redefined_tokenizer:
            bert_tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path,
                                                           do_lower_case=True)
        else:
            bert_tokenizer = BertTokenizer.from_pretrained(
                './pretrained_weights/bert-tiny-uncased-vocab.txt',
                do_lower_case=True)
    else:
        bert_tokenizer = None

    model.train()
    model.zero_grad()
    gen = datasets.data_generator(data_path,
                                  dicts,
                                  batch_size,
                                  num_labels,
                                  version=version,
                                  desc_embed=desc_embed,
                                  bert_tokenizer=bert_tokenizer,
                                  max_seq_length=args.max_sequence_length)
    if labels_weight is not None:
        labels_weight = torch.LongTensor(labels_weight)
    for batch_idx, tup in tqdm(enumerate(gen)):
        data, target, _, code_set, descs = tup
        data, target = torch.LongTensor(data), torch.FloatTensor(target)
        unseen_code_inds = unseen_code_inds.difference(code_set)
        if gpu:
            data = data.cuda()
            target = target.cuda()
            if labels_weight is not None:
                labels_weight = labels_weight.cuda()

        if desc_embed:
            desc_data = descs
        else:
            desc_data = None

        if args.model in ['bert', 'biobert', 'bert-tiny']:
            token_type_ids = (data > 0).long() * 0
            attention_mask = (data > 0).long()
            position_ids = torch.arange(data.size(1)).expand(
                data.size(0), data.size(1))
            if gpu:
                position_ids = position_ids.cuda()
            position_ids = position_ids * (data > 0).long()
        else:
            attention_mask = (data > 0).long()
            token_type_ids = None
            position_ids = None

        if args.model in BERT_MODEL_LIST:
            output, loss = model(input_ids=data, \
                                 token_type_ids=token_type_ids, \
                                 attention_mask=attention_mask, \
                                 position_ids=position_ids, \
                                 labels=target, \
                                 desc_data=desc_data, \
                                 pos_labels=labels_weight, \
                                )
        else:
            output, loss, _ = model(data, target, desc_data=desc_data)

        loss.backward()
        if args.model in BERT_MODEL_LIST:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if args.model in BERT_MODEL_LIST:
            scheduler.step()
        model.zero_grad()

        losses.append(loss.item())

        if not quiet and batch_idx % print_every == 0:
            #print the average loss of the last 10 batches
            print(
                "Train epoch: {} [batch #{}, batch_size {}, seq length {}]\tLoss: {:.6f}"
                .format(epoch, batch_idx,
                        data.size()[0],
                        data.size()[1], np.mean(losses[-10:])))
    return losses, unseen_code_inds
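
The BERT branch above builds its auxiliary inputs directly from the padded id matrix: the attention mask is 1 wherever the token id is non-zero, token_type_ids are all zeros for single-segment input, and position ids are zeroed out at padding positions. A small standalone illustration with a dummy batch (torch only, hypothetical values):

import torch

data = torch.LongTensor([[101, 2023, 2003, 102, 0, 0],
                         [101, 7592, 102, 0, 0, 0]])   # dummy padded token ids

attention_mask = (data > 0).long()                     # 1 for real tokens, 0 for padding
token_type_ids = (data > 0).long() * 0                 # single segment: all zeros
position_ids = torch.arange(data.size(1)).expand(data.size(0), data.size(1))
position_ids = position_ids * (data > 0).long()        # positions reset to 0 at padding

print(attention_mask)
print(position_ids)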