Example #1
    def __init__(self, embedding_dir, model_name="bert-base-multilingual-cased", layer=-2):
        super(BertEncoder, self).__init__(embedding_dir)

        # Load pre-trained model (weights) and set to evaluation mode (no more training)
        self.model = BertModel.from_pretrained(model_name)
        self.model.eval()

        # Load word piece tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

        # Layer from which to get the embeddings
        self.layer = layer
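
    # Hedged sketch (not part of the original snippet) of how the stored
    # tokenizer, model and layer index might be used to embed one sentence;
    # assumes torch is imported and the sentence fits in a single segment.
    def encode(self, sentence):
        tokens = ["[CLS]"] + self.tokenizer.tokenize(sentence) + ["[SEP]"]
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        tokens_tensor = torch.tensor([token_ids])
        with torch.no_grad():
            # pytorch_pretrained_bert returns one tensor per encoder layer
            encoded_layers, _ = self.model(tokens_tensor)
        # select the configured layer (default -2) and drop the batch dimension
        return encoded_layers[self.layer].squeeze(0)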
Example #2
 def __init__(self, temp_dir, load_pretrained_bert, bert_config):
     super(Bert, self).__init__()
     if(load_pretrained_bert):
         self.model = BertModel.from_pretrained('bert-base-chinese', cache_dir=temp_dir)
     else:
         self.model = BertModel(bert_config)
Example #3
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
import _pickle as pickle
from dataset import Dictionary
import numpy as np


def create_bert_embedding(idx2word, model, tokenizer):
    weights = np.zeros((len(idx2word), 768), dtype=np.float32)
    for idx, word in enumerate(idx2word):
        tokenize_text = tokenizer.tokenize(word)
        index_token = tokenizer.convert_tokens_to_ids(tokenize_text)
        tokens_tensor = torch.tensor([index_token])
        weights[idx] = model(tokens_tensor)[1].detach().numpy()
    return weights


if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')
    model.eval()

    d = Dictionary.load_from_file('../data/dictionary.pkl')
    # weights = create_bert_embedding(d.idx2word, model, tokenizer)
    # np.save('../data/bert_embedding.npy',weights)
Example #4
def imm(path):
    dirname = os.path.dirname(path)
    name = os.path.basename(path)
    rawname = os.path.splitext(name)[0] # without extension

    if 'lit' in name or 'literal' in name or 'LOCATION' in name:
        label = 0
    else:
        if 'met' in name or 'metonymic' in name or 'mixed' in name:
            label = 1  # 1 is for METONYMY/NON-LITERAL, 0 is for LITERAL
        elif 'INSTITUTE' in name:
            label = 1
        elif 'TEAM' in name:
            label = 2
        elif 'ARTIFACT' in name:
            label = 3
        elif 'EVENT' in name:
            label = 4

    bert_version = 'bert-base-uncased'
    model = BertModel.from_pretrained(bert_version)
    model.eval()
    spacy_tokenizer = English(parser=False)
    bert_tokenizer = BertTokenizer.from_pretrained(bert_version)
    en_nlp = spacy.load('en')
    inp = codecs.open(path, mode="r", encoding="utf-8")
    # PLEASE FORMAT THE INPUT FILE AS ONE SENTENCE PER LINE. SEE BELOW:
    # ENTITY<SEP>sentence<ENT>ENTITY<ENT>rest of sentence.
    # Germany<SEP>Their privileges as permanent Security Council members, especially the right of veto, 
    # had been increasingly questioned by <ENT>Germany<ENT> and Japan which, as major economic powers.
    out = []
    seq_length = 10  # There are THREE baselines in the paper (5, 10, 50) so use this integer to set it.

    for line in inp:
        line = line.split(u"<SEP>")
        sentence = line[1].split(u"<ENT>")
        entity = [t.text for t in spacy_tokenizer(sentence[1])]
        en_doc = en_nlp(u"".join(sentence).strip())
        words = []
        index = locate_entity(en_doc, entity, spacy_tokenizer(sentence[0].strip()), spacy_tokenizer(sentence[2].strip()))
        start = en_doc[index]

        # --------------------------------------------------------------------
        # Token map will be an int -> int mapping
        #    between the `spacy_tokens` index and the `bert_tokens` index.
        spacy_to_bert_map = []
        bert_tokens = []
        spacy_tokens = [token.text for token in en_doc]

        '''
            According to https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/
                [CLS] and [SEP] tokens are important.
            Also, use the segment_ids to inform BERT
                that the input is just one sentence.
        '''
        spacy_tokens = ["[CLS]"] + spacy_tokens + ["[SEP]"]

        for orig_token in spacy_tokens:
            spacy_to_bert_map.append(len(bert_tokens))
            bert_tokens.extend(bert_tokenizer.tokenize(orig_token))

        segments_ids = [1] * len(bert_tokens)

        try:
            token_ids = bert_tokenizer.convert_tokens_to_ids(bert_tokens)
            tokens_tensor = torch.tensor([token_ids])
            segments_tensors = torch.tensor([segments_ids])
            with torch.no_grad():
                encoded_layers, _ = model(tokens_tensor, segments_tensors, output_all_encoded_layers=True)

            '''
                According to http://jalammar.github.io/illustrated-bert/
                    concatenating the last four hidden layers
                    gives good contextualised, ELMo-like word embeddings.

                Concatenation leads to very long tensors.
                So I decided to take the sum of the last four hidden layers.
                This is the second best approach according to the blog.
            '''
            bert_emb = torch.add(encoded_layers[-1],
                                 encoded_layers[-2]).add(encoded_layers[-3]).add(encoded_layers[-4]).squeeze()
            bert_emb_length = bert_emb.shape[-1]

            '''
                Sum the subword embeddings to obtain word embeddings.
                Averaging the subwords is another option; concatenation is not
                a good choice here because words map to varying numbers of subwords.
                Source: https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/
            '''
            cond_bert_emb = torch.zeros(len(spacy_tokens), bert_emb_length)
            for spacy_index in range(len(spacy_tokens)):
                start_bert_index = spacy_to_bert_map[spacy_index]
                try:
                    end_bert_index = spacy_to_bert_map[spacy_index + 1]
                except IndexError:
                    end_bert_index = len(bert_tokens)
                for foo in range(start_bert_index, end_bert_index):
                    cond_bert_emb[spacy_index] = cond_bert_emb[spacy_index].add(bert_emb[foo])
        except ValueError:
            cond_bert_emb = torch.zeros(len(spacy_tokens), 768)
            print('ValueError Exception caught!')

        '''
            Since the two special tokens were added,
                strip the BERT embeddings accordingly.
            Now the BERT embeddings are aligned with the spaCy parse.
        '''
        cond_bert_emb = cond_bert_emb[1:-1]
        assert (len(cond_bert_emb) == len(en_doc))
        # --------------------------------------------------------------------

        right = pad([t.text for t in en_doc[start.i + 1:][:seq_length]], False, seq_length)
        left = pad([t.text for t in en_doc[:index - len(entity) + 1][-seq_length:]], True, seq_length)

        dep_right = pad([t.dep_ for t in en_doc[start.i + 1:]][:seq_length], False, seq_length)
        dep_left = pad([t.dep_ for t in en_doc[:index - len(entity) + 1]][-seq_length:], True, seq_length)

        bert_right = bert_pad(cond_bert_emb[start.i + 1:][:seq_length], False, seq_length)
        bert_left = bert_pad(cond_bert_emb[:index - len(entity) + 1][-seq_length:], True, seq_length)

        assert(bert_left.shape == bert_right.shape)
        assert(len(left) == len(dep_left) == len(bert_left))
        assert(len(right) == len(dep_right) == len(bert_right))
        out.append((left, dep_left, bert_left, right, dep_right, bert_right, label))
        #print(left, right)
        #print(dep_left, dep_right)
        #print(bert_left, bert_right)
        #print(label)
        #print(line[1])
    print("Processed:{} lines/sentences.".format(len(out)))
    dump_to_hdf5("{}/bert_pickles/{}_imm.hdf5".format(dirname, rawname), out)
Example #5
def score(cands,
          refs,
          bert="bert-base-multilingual-cased",
          num_layers=8,
          verbose=False,
          no_idf=False,
          batch_size=64):
    """
    BERTScore metric.

    Args:
        - :param: `cands` (list of str): candidate sentences
        - :param: `refs` (list of str): reference sentences
        - :param: `bert` (str): bert specification
        - :param: `num_layers` (int): the layer of representation to use
        - :param: `verbose` (bool): turn on intermediate status update
        - :param: `no_idf` (bool): do not use idf weighting
        - :param: `batch_size` (int): bert score processing batch size
    """
    assert len(cands) == len(refs)
    assert bert in bert_types

    tokenizer = BertTokenizer.from_pretrained(bert)
    model = BertModel.from_pretrained(bert)
    model.eval()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    # drop unused layers
    model.encoder.layer = torch.nn.ModuleList(
        [layer for layer in model.encoder.layer[:num_layers]])

    if no_idf:
        idf_dict = defaultdict(lambda: 1.)
        idf_dict[101] = 0
        idf_dict[102] = 0
    else:
        if verbose:
            print('preparing IDF dict...')
        start = time.perf_counter()
        idf_dict = get_idf_dict(refs, tokenizer)
        if verbose:
            print('done in {:.2f} seconds'.format(time.perf_counter() - start))

    if verbose:
        print('calculating scores...')
    start = time.perf_counter()
    all_preds = bert_cos_score_idf(model,
                                   refs,
                                   cands,
                                   tokenizer,
                                   idf_dict,
                                   verbose=verbose,
                                   device=device,
                                   batch_size=batch_size)

    P = all_preds[:, 0].cpu()
    R = all_preds[:, 1].cpu()
    F1 = all_preds[:, 2].cpu()
    if verbose:
        print('done in {:.2f} seconds'.format(time.perf_counter() - start))

    return P, R, F1
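
A minimal usage sketch of score, assuming the surrounding module provides bert_types, get_idf_dict and bert_cos_score_idf as used above:

if __name__ == '__main__':
    cands = ["the cat sat on the mat"]
    refs = ["a cat was sitting on the mat"]
    P, R, F1 = score(cands, refs, num_layers=8, no_idf=True, verbose=True)
    print("P: %.4f  R: %.4f  F1: %.4f"
          % (P.mean().item(), R.mean().item(), F1.mean().item()))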
Example #6
 def __init__(self, embed_dim=512):
     super(PoemEmbed, self).__init__()
     self.bert = BertModel.from_pretrained('bert-base-uncased')
     self.linear = nn.Linear(768, embed_dim)
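
 # Hedged sketch of a matching forward pass (not in the original snippet):
 # pool BERT's [CLS] output and project it from 768 to embed_dim.
 def forward(self, token_ids, segment_ids=None):
     _, pooled = self.bert(token_ids, segment_ids,
                           output_all_encoded_layers=False)
     return self.linear(pooled)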
Example #7
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, WordpieceTokenizer
from argparse import ArgumentParser
import h5py
import numpy as np

argp = ArgumentParser()
argp.add_argument('input_path')
argp.add_argument('output_path')
argp.add_argument('bert_model', help='base or large')
args = argp.parse_args()

# Load pre-trained model tokenizer (vocabulary)
# Crucially, do not do basic tokenization; PTB is tokenized. Just do wordpiece tokenization.
if args.bert_model == 'base':
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    model = BertModel.from_pretrained('bert-base-cased')
    LAYER_COUNT = 12
    FEATURE_COUNT = 768
elif args.bert_model == 'large':
    tokenizer = BertTokenizer.from_pretrained('bert-large-cased')
    model = BertModel.from_pretrained('bert-large-cased')
    LAYER_COUNT = 24
    FEATURE_COUNT = 1024
else:
    raise ValueError("BERT model must be base or large")

model.eval()

with h5py.File(args.output_path, 'w') as fout:
    for index, line in enumerate(open(args.input_path)):
        line = line.strip()  # Remove trailing characters
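        # Hedged sketch of how the per-line loop might continue (not part of
        # the original snippet; assumes import torch at the top of the file):
        # wordpiece-tokenize the already-tokenized PTB line, run BERT, and
        # write all layers for this sentence into the HDF5 file.
        line = '[CLS] ' + line + ' [SEP]'
        tokenized_text = tokenizer.wordpiece_tokenizer.tokenize(line)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        segment_ids = [1] * len(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segment_ids])
        with torch.no_grad():
            encoded_layers, _ = model(tokens_tensor, segments_tensors)
        dset = fout.create_dataset(str(index),
                                   (LAYER_COUNT, len(tokenized_text), FEATURE_COUNT))
        dset[:, :, :] = np.vstack([layer.numpy() for layer in encoded_layers])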
Example #8
def get_bert_out(output_path, local_rank, no_cuda, batch_size):
    startt = timeit.default_timer()

    if local_rank == -1 or no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {} distributed training: {}".format(
        device, n_gpu, bool(local_rank != -1)))

    model = BertModel.from_pretrained(args.bert_dir)
    model.to(device)
    # model.to(0)

    if local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    model.eval()
    sent_bert = numpy.load(output_path + "sen_bert.npy")
    sent_mask_bert = numpy.load(output_path + "sen_mask_bert.npy")

    f = open(output_path + "sent_output_bert.npy", 'ab')
    num = 0
    all_input_ids = torch.tensor(sent_bert, dtype=torch.int64).to(device)
    all_input_mask = torch.tensor(sent_mask_bert, dtype=torch.int64).to(device)
    all_example_index = torch.tensor(list(range(len(sent_bert))),
                                     dtype=torch.int64).to(device)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
    if local_rank == -1:
        eval_sampler = SequentialSampler(eval_data)
    else:
        eval_sampler = DistributedSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=batch_size)
    for input_ids, input_mask, example_indices in eval_dataloader:
        all_encoder_layers, _ = model(input_ids,
                                      token_type_ids=None,
                                      attention_mask=input_mask)
        num += len(sent_bert)
        outs = []
        for b, example_index in enumerate(example_indices):
            layer_output = all_encoder_layers[-1].detach().cpu().numpy()  # last layer
            layer_output = layer_output[b][:, :512]  # sentence b
            # out = [round(x.item(), 6) for x in layer_output[0]]  # [CLS]
            # outs.append(out)
            outs.append(layer_output)  # all tokens-----------------
        outs = numpy.array(outs)
        numpy.save(f, outs)

    endt = timeit.default_timer()
    print(file=sys.stderr)
    print("Total use %.3f seconds for BERT Data Generating" % (endt - startt),
          file=sys.stderr)
Example #9
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, BertConfig
import logging
import torch
import spacy
import itertools
import pickle
import pandas as pd

logging.basicConfig(level=logging.INFO)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nlp = spacy.load('en_core_web_sm')
# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained(
    'models/BERT/bert-base-uncased-vocab.txt')
# Load pre-trained model (weights)
model = BertModel.from_pretrained(
    'models/BERT/bert-base-uncased.tar.gz').cuda()


def lemmatization(sent, allowed_postags=['NOUN', 'ADJ', 'VERB', 'ADV']):
    """https://spacy.io/api/annotation"""
    # texts_out = []
    # for sent in texts:
    doc = nlp(sent.lower())
    return " ".join([
        token.lemma_ if token.lemma_ not in ['-PRON-'] else '' for token in doc
        if token.pos_ in allowed_postags
    ])
    # return texts_out


def add_special_tok(sentence):
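    # The original snippet is truncated here. A hedged guess at the body based
    # only on the function name: wrap the sentence in BERT's special tokens.
    return "[CLS] " + sentence + " [SEP]"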
Example #10
 def __init__(self):
     super(Bert, self).__init__()
     self.model = BertModel.from_pretrained('./model/Japanese/')
Example #11
from pytorch_transformers.convert_pytorch_checkpoint_to_tf import convert_pytorch_checkpoint_to_tf
from pytorch_pretrained_bert import BertModel

model = BertModel.from_pretrained("./finetuned_full_lm")
convert_pytorch_checkpoint_to_tf(model, "./finetuned_full_lm_tf", "fine_tuned_tf")
Example #12
                                            test_size=0.1,
                                            random_state=1)
trainloader1 = torch.utils.data.DataLoader(dataset=MyDataset(
    train1_data, subject_data, alias_data, opt.n),
                                           batch_size=BS,
                                           shuffle=True,
                                           collate_fn=collate_fn_link)

k = opt.k
for num_words in [opt.num_words]:
    for max_len in [opt.max_len]:
        for embedding_name in [opt.pretrain]:  # ['bert','wwm','ernie']
            bert_path = './pretrain/' + embedding_name + '/'
            dataset.tokenizer = BertTokenizer.from_pretrained(bert_path +
                                                              'vocab.txt')
            dataset.BERT = BertModel.from_pretrained(bert_path).to(device)
            dataset.BERT.eval()
            dataset.max_len = max_len
            for loss_weight in [opt.loss_weight]:
                accu_ = 0
                while accu_ < k:
                    # vocab_size needs +2 for the pad and unknown tokens
                    model = Net(vocab_size=len(word_index) + 2,
                                embedding_dim=EMBEDDING_DIM,
                                num_layers=num_layers,
                                hidden_dim=hidden_dim,
                                embedding=embedding,
                                device=device).to(device)

                    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
Example #13
def cross_validation(kfold=10):
    with open("data_new/by_article_ids.pickle", "rb") as ids_file:
        ids = pickle.load(ids_file)
    with open("data_new/preprocessed_byarticle_data.pickle", "rb")as data_file:
        data = pickle.load(data_file)
    with open("data_new/by_article_labels.pickle", "rb") as labels_file:
        labels = pickle.load(labels_file)
    with open("/home/nayeon/fakenews/data_new/vocab_trim4.pickle", 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    if not constant.use_bert:
        # for basic LSTM model
        article_model = models.LSTM(vocab=vocab, 
                        embedding_size=constant.emb_dim, 
                        hidden_size=constant.hidden_dim, 
                        num_layers=constant.n_layers,
                        pretrain_emb=constant.pretrain_emb
                        )
        title_model = models.LSTM(vocab=vocab,
                        embedding_size=constant.emb_dim,
                        hidden_size=constant.hidden_dim_tit,
                        num_layers=constant.n_layers,
                        pretrain_emb=constant.pretrain_emb
                        )
        article_model = load_model(article_model, model_name="article_model")
        title_model = load_model(title_model, model_name="title_model")
    else:
        from pytorch_pretrained_bert import BertTokenizer, BertModel
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        bert_model = BertModel.from_pretrained('bert-base-uncased')
        if not constant.bert_from_scratch:
            state = torch.load("bert_model/pytorch_model.bin")
            bert_model.load_state_dict(state)
        article_model = bert_model
        title_model = bert_model
        if constant.use_bert_plus_lstm:
            lstm_article = nn.LSTM(input_size=768, hidden_size=constant.hidden_dim, 
                                   num_layers=constant.n_layers, bidirectional=False, batch_first=True)
            lstm_title = nn.LSTM(input_size=768, hidden_size=constant.hidden_dim_tit,
                                 num_layers=constant.n_layers, bidirectional=False, batch_first=True)
            lstm_article.load_state_dict(torch.load("bert_model/lstm_article2.bin"))
            lstm_title.load_state_dict(torch.load("bert_model/lstm_title2.bin"))
    
    # set average test acc
    avg_test_acc = 0
    best_acc = 0
    k = 0
    kf = KFold(n_splits=kfold)
    for train_index, test_index in kf.split(ids):
        k += 1
        print("k:", k)
        # get 25 true 25 false for validation #
        ids_train, ids_val = [], []
        data_train, data_val = {}, {}
        labels_train, labels_val = {}, {}
        cnt_true, cnt_false = 0, 0
        for index in train_index:
            id_ = ids[index]
            if labels[id_] == "true":
                if cnt_true < 25:
                    cnt_true += 1
                    ids_val.append(id_)
                    data_val[id_] = data[id_]
                    labels_val[id_] = labels[id_]
                else:
                    ids_train.append(id_)
                    data_train[id_] = data[id_]
                    labels_train[id_] = labels[id_]
            else:
                if cnt_false < 25:
                    cnt_false += 1
                    ids_val.append(id_)
                    data_val[id_] = data[id_]
                    labels_val[id_] = labels[id_]
                else:
                    ids_train.append(id_)
                    data_train[id_] = data[id_]
                    labels_train[id_] = labels[id_]
        # get test set from test_index
        ids_test, data_test, labels_test = [], {}, {}
        for index in test_index:
            id_ = ids[index]
            ids_test.append(id_)
            data_test[id_] = data[id_]
            labels_test[id_] = labels[id_]
        train = (ids_train, data_train, labels_train)
        val = (ids_val, data_val, labels_val)
        test = (ids_test, data_test, labels_test)

        # prepare by article cross validation data
        if constant.aug_count != '':
            data_loader_train, data_loader_val, data_loader_test, ids_val_dict, ids_test_dict = prepare_byarticle_cross_validation(train, val, test, constant.batch_size, constant.aug_count)
        else:
            data_loader_train, data_loader_val, data_loader_test = prepare_byarticle_cross_validation(train, val, test, constant.batch_size, constant.aug_count)

        # need to init the final Classifier for each fold
        if constant.use_bert:
            if constant.use_bert_plus_lstm:
                Classifier = models.Classifier(hidden_dim1=constant.hidden_dim, hidden_dim2=constant.hidden_dim_tit)
                # Classifier.load_state_dict(torch.load("bert_model/classifier_bypublisher2.bin"))
            else:
                Classifier = models.Classifier(hidden_dim1=768, hidden_dim2=768)
        else:
            Classifier = models.Classifier(hidden_dim1=constant.hidden_dim, hidden_dim2=constant.hidden_dim_tit)

        if constant.USE_CUDA:
            if constant.use_bert_plus_lstm:
                lstm_article.cuda()
                lstm_title.cuda()
            article_model.cuda()
            title_model.cuda()
            Classifier.cuda()

        criterion = nn.BCELoss()

        if constant.optimizer=='adam':
            opt = torch.optim.Adam(Classifier.parameters(), lr=constant.lr_classi, weight_decay=constant.weight_decay)
        elif constant.optimizer=='adagrad':
            opt = torch.optim.Adagrad(Classifier.parameters(), lr=constant.lr_classi)
        elif constant.optimizer=='sgd':
            opt = torch.optim.SGD(Classifier.parameters(), lr=constant.lr_classi, momentum=0.9)
        
        # set lr scheduler
        # scheduler = StepLR(opt, step_size=1, gamma=0.8)
        
        # set tensorboard folder name
        if constant.use_bert:
            experiment_name = "BERT_FineTune_aug{0}_LRlr{1}_k{2}".format(constant.aug_count, constant.lr_classi, k)
        else:
            experiment_name = "LSTM_FineTune_aug{0}_LRlr{1}_k{2}".format(constant.aug_count, constant.lr_classi, k)
        
        logdir = "tensorboard/" + experiment_name + "/"
        writer = SummaryWriter(logdir)
        global_steps = 0
        best_val_acc = 0
        # training and testing
        for e in range(constant.max_epochs):
            # scheduler.step()
            article_model.train()
            title_model.train()
            Classifier.train()
            if constant.use_bert_plus_lstm:
                lstm_article.train()
                lstm_title.train()
            loss_log = []
            f1_log = 0
            acc_log = 0
            # training
            pbar = tqdm(enumerate(data_loader_train),total=len(data_loader_train))
            for i, (X, x_len, tit, tit_len, y, ind) in pbar:
                opt.zero_grad()
                if constant.use_bert:
                    X = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(item)) for item in X]
                    tit = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(item)) for item in tit]
                    X, segments_ids_article, tit, segments_ids_tit = padding_for_bert(X, tit)
                    if constant.USE_CUDA:
                        X, segments_ids_article, tit, segments_ids_tit, y = X.cuda(), segments_ids_article.cuda(), tit.cuda(), segments_ids_tit.cuda(), y.cuda()
                    encoded_article_layers, _ = article_model(X, segments_ids_article)
                    encoded_tit_layers, _ = title_model(tit, segments_ids_tit)
                    if constant.use_bert_plus_lstm:
                        _, article_hidden = lstm_article(encoded_article_layers[-1])
                        _, title_hidden = lstm_title(encoded_tit_layers[-1])
                        article_feat = article_hidden[-1][-1]
                        title_feat = title_hidden[-1][-1]
                    else:
                        article_feat = torch.sum(encoded_article_layers[-1], dim=1)
                        title_feat = torch.sum(encoded_tit_layers[-1], dim=1) #[batch_size, hidden_size]
                else:
                    article_feat = article_model.feature(X, x_len)
                    title_feat = title_model.feature(tit, tit_len)
                feature = torch.cat((article_feat, title_feat), dim=1)
                pred_prob = Classifier(feature)
                
                loss = criterion(pred_prob, y)
                loss.backward()
                opt.step()

                loss_log.append(loss.item())
                accuracy, microPrecision, microRecall, microF1 = getMetrics(pred_prob.detach().cpu().numpy(), y.cpu().numpy())
                f1_log += microF1
                acc_log += accuracy
                pbar.set_description("(Epoch {}) TRAIN F1:{:.4f} TRAIN LOSS:{:.4f} ACCURACY:{:.4f}".format((e+1), f1_log/float(i+1), np.mean(loss_log), acc_log/float(i+1)))

                writer.add_scalars('train', {'loss': np.mean(loss_log),
                                            'acc': acc_log/float(i+1),
                                            'f1': f1_log/float(i+1)}, global_steps)
                global_steps+=1
            
            """
                validate and test
                1. Get the test accuracy result from the model that gets the best accuracy in validation
                2. Whenever we find better accuracy result in the validation set, we need to test the model in the test 
                set and get the updated test set accuracy result.
                3. No need to save model during cross validation (cross validation is to find the best model)
            """
            article_model.eval()
            title_model.eval()
            Classifier.eval()
            if constant.use_bert_plus_lstm:
                lstm_article.eval()
                lstm_title.eval()
            print("Evaluation on validation set")
            use_add_feature_flag = constant.use_emo2vec_feat or constant.use_url
            if constant.use_bert:
                if constant.aug_count != '':
                    accuracy, pred, id_ = eval_bert_with_chunked_data(article_model, title_model, Classifier, data_loader_val, tokenizer, ids_val_dict, None, writer, e, False)
                else:
                    if constant.use_bert_plus_lstm:
                        accuracy, pred, id_ = eval_bert(article_model, title_model, Classifier, data_loader_val, tokenizer, lstm_article, lstm_title, use_add_feature_flag, writer, e, False)
                    else:
                        accuracy, pred, id_ = eval_bert(article_model, title_model, Classifier, data_loader_val, tokenizer, None, None, use_add_feature_flag, writer, e, False)
            else:
                accuracy, pred, id_ = eval_tit_lstm(article_model, title_model, Classifier, data_loader_val, use_add_feature_flag, writer, e, False)
            
            # found a better model on the validation set, so evaluate it on the test set
            if(accuracy > best_val_acc):
                print("Find better model, test it on test set")
                best_val_acc = accuracy
                if constant.use_bert:
                    if constant.aug_count != '':
                        accuracy, pred, id_ = eval_bert_with_chunked_data(article_model, title_model, Classifier, data_loader_test, tokenizer, ids_test_dict, None, writer, e, True)
                    else:
                        if constant.use_bert_plus_lstm:
                            accuracy, pred, id_ = eval_bert(article_model, title_model, Classifier, data_loader_test, tokenizer, lstm_article, lstm_title, use_add_feature_flag, writer, e, True)
                        else:
                            accuracy, pred, id_ = eval_bert(article_model, title_model, Classifier, data_loader_test, tokenizer, None, None, use_add_feature_flag, writer, e, True)
                else:
                    accuracy, pred, id_ = eval_tit_lstm(article_model, title_model, Classifier, data_loader_test, use_add_feature_flag, writer, e, True)
                test_acc = accuracy
                if best_val_acc + test_acc > 1.53:
                    torch.save(Classifier.state_dict(), "bert_model/classifier.bin")
                    print("Classifier has been saved in bert_model/classifier.bin")
        # finished one fold; accumulate test_acc (averaged over the k folds at the end)
        avg_test_acc += test_acc
    
    # after k folds cross validation, get the final average test accuracy
    avg_test_acc = avg_test_acc * 1.0 / kfold
    print("After {0} folds cross validation, the final accuracy of {1} is {2}".format(kfold, constant.manual_name, avg_test_acc))
Example #14
def train(aug_count=""):
    # prepare data_loader and vocab
    if constant.train_cleaner_dataset:
        data_loader_train, data_loader_test, vocab = prepare_filtered_data(batch_size=constant.batch_size)
    else:
        data_loader_train, data_loader_test, vocab = prepare_byarticle_data(aug_count=aug_count, batch_size=constant.batch_size)
    
    # load parameters, LR is for fine tune
    if constant.use_bert:
        from pytorch_pretrained_bert import BertTokenizer, BertModel
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        bert_model = BertModel.from_pretrained('bert-base-uncased')
        if not constant.bert_from_scratch:
            state = torch.load("bert_model/pytorch_model.bin")
            bert_model.load_state_dict(state)
        article_model = bert_model
        title_model = bert_model
        # print("finish bert model loading")
        if constant.train_cleaner_dataset:
            lstm_article = nn.LSTM(input_size=768, hidden_size=constant.hidden_dim, 
                                   num_layers=constant.n_layers, bidirectional=False, batch_first=True)
            lstm_title = nn.LSTM(input_size=768, hidden_size=constant.hidden_dim_tit,
                                 num_layers=constant.n_layers, bidirectional=False, batch_first=True)
            LR = models.Classifier(hidden_dim1=constant.hidden_dim, hidden_dim2=constant.hidden_dim_tit)
        else:
            LR = models.Classifier(hidden_dim1=768, hidden_dim2=768)
    else:
        # for basic LSTM model
        article_model = models.LSTM(vocab=vocab, 
                        embedding_size=constant.emb_dim, 
                        hidden_size=constant.hidden_dim, 
                        num_layers=constant.n_layers,
                        pretrain_emb=constant.pretrain_emb,
                        )
        title_model = models.LSTM(vocab=vocab,
                        embedding_size=constant.emb_dim,
                        hidden_size=constant.hidden_dim_tit,
                        num_layers=constant.n_layers,
                        pretrain_emb=constant.pretrain_emb,
                        )
#         LR = models.LR(hidden_dim1=constant.hidden_dim, hidden_dim2=constant.hidden_dim_tit)
        LR = models.Classifier(hidden_dim1=constant.hidden_dim, hidden_dim2=constant.hidden_dim_tit)

        article_model = load_model(article_model, model_name="article_model")
        title_model = load_model(title_model, model_name="title_model")
        
    if constant.USE_CUDA:
        article_model.cuda()
        title_model.cuda()
        LR.cuda()
        if constant.train_cleaner_dataset:
            lstm_article.cuda()
            lstm_title.cuda()

    criterion = nn.BCELoss()
    
    if constant.train_cleaner_dataset:
        model = [
                {"params": lstm_article.parameters(), "lr": constant.lr_lstm},
                {"params": lstm_title.parameters(), "lr": constant.lr_title}, 
                {"params": LR.parameters(), "lr": constant.lr_classi},
            ]
        if constant.optimizer=='adam':
            opt = torch.optim.Adam(model, lr=constant.lr_classi, weight_decay=constant.weight_decay)
        elif constant.optimizer=='adagrad':
            opt = torch.optim.Adagrad(model, lr=constant.lr_classi)
        elif constant.optimizer=='sgd':
            opt = torch.optim.SGD(model, lr=constant.lr_classi, momentum=0.9)
    else:
        if constant.optimizer=='adam':
            opt = torch.optim.Adam(LR.parameters(), lr=constant.lr_classi, weight_decay=constant.weight_decay)
        elif constant.optimizer=='adagrad':
            opt = torch.optim.Adagrad(LR.parameters(), lr=constant.lr_classi)
        elif constant.optimizer=='sgd':
            opt = torch.optim.SGD(LR.parameters(), lr=constant.lr_classi, momentum=0.9)

    # test the result without fine tune
    # print("testing without fine tune")
    # accuracy = eval_tit_lstm(article_model, title_model, LR, data_loader_test, False)

    # set tensorboard folder name
    if constant.use_bert:
        experiment_name = "BERT_FineTune_aug{0}_LRlr{1}".format(constant.aug_count, constant.lr_classi)
    else:
        experiment_name = "LSTM_FineTune_aug{0}_LRlr{1}".format(constant.aug_count, constant.lr_classi)
    
    logdir = "tensorboard/" + experiment_name + "/"
    writer = SummaryWriter(logdir)

    test_best = 0
    cnt = 0
    global_steps = 0
    for e in range(constant.max_epochs):
        article_model.train()
        title_model.train()
        LR.train()
        if constant.train_cleaner_dataset:
            lstm_article.train()
            lstm_title.train()
        loss_log = []
        f1_log = 0
        acc_log = 0

        # training
        pbar = tqdm(enumerate(data_loader_train),total=len(data_loader_train))
        for i, (X, x_len, tit, tit_len, y, ind) in pbar:
            opt.zero_grad()
            if constant.use_bert:
                X = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(item)) for item in X]
                tit = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(item)) for item in tit]
                # padding
                X, segments_ids_article, tit, segments_ids_tit = padding_for_bert(X, tit)
                if constant.USE_CUDA:
                    X, segments_ids_article, tit, segments_ids_tit, y = X.cuda(), segments_ids_article.cuda(), tit.cuda(), segments_ids_tit.cuda(), y.cuda()
                encoded_article_layers, _ = article_model(X, segments_ids_article)
                encoded_tit_layers, _ = title_model(tit, segments_ids_tit)
                if constant.train_cleaner_dataset:
                    _, article_hidden = lstm_article(encoded_article_layers[-1])
                    _, title_hidden = lstm_title(encoded_tit_layers[-1])
                    article_feat = article_hidden[-1][-1]
                    title_feat = title_hidden[-1][-1]
                else:
                    article_feat = torch.sum(encoded_article_layers[-1], dim=1)
                    title_feat = torch.sum(encoded_tit_layers[-1], dim=1) #[batch_size, hidden_size]
            else:
                article_feat = article_model.feature(X, x_len)
                title_feat = title_model.feature(tit, tit_len)
            feature = torch.cat((article_feat, title_feat), dim=1)
            pred_prob = LR(feature)
            
            loss = criterion(pred_prob, y)
            loss.backward()
            opt.step()

            loss_log.append(loss.item())
            accuracy, microPrecision, microRecall, microF1 = getMetrics(pred_prob.detach().cpu().numpy(), y.cpu().numpy())
            f1_log += microF1
            acc_log += accuracy
            pbar.set_description("(Epoch {}) TRAIN F1:{:.4f} TRAIN LOSS:{:.4f} ACCURACY:{:.4f}".format((e+1), f1_log/float(i+1), np.mean(loss_log), acc_log/float(i+1)))

            writer.add_scalars('train', {'loss': np.mean(loss_log),
                                         'acc': acc_log/float(i+1),
                                         'f1': f1_log/float(i+1)}, global_steps)
            global_steps+=1
        
        article_model.eval()
        title_model.eval()
        LR.eval()
        if constant.train_cleaner_dataset:
            lstm_article.eval()
            lstm_title.eval()
        # testing
        if(e % 1 == 0):
            print("Evaluation on Test")
            use_add_feature_flag = constant.use_emo2vec_feat or constant.use_url
            if constant.use_bert:
                if constant.train_cleaner_dataset:
                    accuracy, pred, id_ = eval_bert(article_model, title_model, LR, data_loader_test, tokenizer, lstm_article, lstm_title, use_add_feature_flag, writer, e, True)
                else:
                    accuracy, pred, id_ = eval_bert(article_model, title_model, LR, data_loader_test, tokenizer, None, None, use_add_feature_flag, writer, e, True)
            else:
                accuracy, pred, id_ = eval_tit_lstm(article_model, title_model, LR, data_loader_test, use_add_feature_flag, writer, e, True)
            
            if(accuracy > test_best):
                test_best = accuracy
                print("Find better model. Saving model ...")
                cnt = 0
                if constant.train_cleaner_dataset:
                    torch.save(lstm_article.state_dict(), "bert_model/by_publisher/lstm_article_"+str(constant.hidden_dim)+"_"+str(constant.hidden_dim_tit)+"_"+str(test_best)+".bin")
                    torch.save(lstm_title.state_dict(), "bert_model/by_publisher/lstm_title_"+str(constant.hidden_dim)+"_"+str(constant.hidden_dim_tit)+"_"+str(test_best)+".bin")
                    torch.save(LR.state_dict(), "bert_model/by_publisher/classifier_bypublisher_"+str(constant.hidden_dim)+"_"+str(constant.hidden_dim_tit)+"_"+str(test_best)+".bin")
                    print("The lstm_article lstm_title classifier_bypublisher have been saved!")
                else:
                    torch.save(LR.state_dict(), "bert_model/finetune_classi_for_tunebert_"+str(accuracy)+".bin")
                    print("The fine tune classifier has been saved!")
            else:
                cnt += 1
            if(cnt == 10): 
                # save prediction and gold
                with open('pred/{0}_pred.pickle'.format(experiment_name), 'wb') as handle:
                    pickle.dump({"preds":pred, "ids":id_}, handle, protocol=pickle.HIGHEST_PROTOCOL)
                break
            if(test_best == 1.0): 
                # save prediction and gold
                with open('pred/{0}_pred.pickle'.format(experiment_name), 'wb') as handle:
                    pickle.dump({"preds":pred, "ids":id_}, handle, protocol=pickle.HIGHEST_PROTOCOL)
                break
Example #15
 def build_model(self):
     num_classes = len(self.class_list)
     self.model = BertWrapper(
         BertModel.from_pretrained(self.pretrained_path), num_classes)
Example #16
import argparse

import torch
import torch.optim as optim
import torch.multiprocessing as mp
from sklearn.model_selection import ParameterGrid
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

from kgegrok import data
from kgegrok import estimate
from kgegrok import stats
from kgegrok import evaluation
from kgegrok import utils
from kgegrok.stats import create_drawer

bert = BertModel.from_pretrained('bert-base-uncased', cache_dir='/MIUN/.bert')
bert.cuda()
bert.eval()

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                          cache_dir='/MIUN/.bert')

# # Tokenized input
# text = "Who was Jim Henson ? Jim Henson was a puppeteer"
# tokenized_text = tokenizer.tokenize(text)

# # Mask a token that we will try to predict back with `BertForMaskedLM`
# masked_index = 6
# tokenized_text[masked_index] = '[MASK]'
# assert tokenized_text == ['who', 'was', 'jim', 'henson', '?', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer']
Example #17
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', do_lower_case=True)
text = "[CLS] 李小龙是谁? [SEP]"
tokenized_text = tokenizer.tokenize(text)
# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 3
tokenized_text[masked_index] = '[MASK]'
# assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']
# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
# segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
segments_ids = [0, 0, 0, 0, 0, 0, 0, 0]
# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
model = BertModel.from_pretrained('bert-base-chinese')
model.eval()

# If you have a GPU, put everything on cuda
device = torch.device('cuda:2')
tokens_tensor = tokens_tensor.to(device)
segments_tensors = segments_tensors.to(device)
model.to(device)
# Predict hidden states features for each layer
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, segments_tensors)
# We have one hidden-state tensor for each of the 12 layers of bert-base-chinese
assert len(encoded_layers) == 12

# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-chinese')
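
A hedged continuation (not in the original snippet) following the library's standard masked-LM usage, predicting the token behind [MASK] with the BertForMaskedLM model loaded above:

model.eval()
model.to(device)

# Predict the vocabulary distribution at every position
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)

# Take the most likely token at the masked position
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print(predicted_token)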
Example #18
 def __init__(self, bert_model, temp_dir):
     super(Bert, self).__init__()
     self.model = BertModel.from_pretrained(bert_model, cache_dir=temp_dir)
Example #19
    def __init__(self, 
                 vocab: Vocabulary, 
                 bert_pretrained_model: str, 
                 dropout_prob: float = 0.1, 
                 max_count: int = 10,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 answering_abilities: List[str] = None,
                 number_rep: str = 'first',
                 arithmetic: str = 'base',
                 special_numbers : List[int] = None) -> None:
        super().__init__(vocab, regularizer)

        self.number_rep = number_rep
        
        self.BERT = BertModel.from_pretrained(bert_pretrained_model)
        self.tokenizer = BertTokenizer.from_pretrained(bert_pretrained_model)
        bert_dim = self.BERT.pooler.dense.out_features
        
        self.dropout = dropout_prob

        self._passage_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._question_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._number_weights_predictor = torch.nn.Linear(bert_dim, 1)
        self._arithmetic_weights_predictor = torch.nn.Linear(bert_dim, 1)
            
        self._template_predictor = \
                self.ff(2 * bert_dim, bert_dim, len(self.answering_abilities))

        if "passage_span_extraction" in self.answering_abilities:
            self._passage_span_extraction_index = self.answering_abilities.index("passage_span_extraction")
            self._passage_span_start_predictor = torch.nn.Linear(bert_dim, 1)
            self._passage_span_end_predictor = torch.nn.Linear(bert_dim, 1)

        if "question_span_extraction" in self.answering_abilities:
            self._question_span_extraction_index = self.answering_abilities.index("question_span_extraction")
            self._question_span_start_predictor = \
                self.ff(2 * bert_dim, bert_dim, 1)
            self._question_span_end_predictor = \
                self.ff(2 * bert_dim, bert_dim, 1)

        if "arithmetic" in self.answering_abilities:
            self.arithmetic = arithmetic
            self._arithmetic_index = self.answering_abilities.index("arithmetic")
            if special_numbers is not None:
                self.special_numbers = special_numbers
                self.num_special_numbers = len(self.special_numbers)
                self.special_embedding = torch.nn.Embedding(self.num_special_numbers, bert_dim)
            else:
                self.num_special_numbers = 0
            if self.arithmetic == "base":
                self._number_sign_predictor = \
                    self.ff(2 * bert_dim, bert_dim, 3)
            else:
                self.init_arithmetic(bert_dim, bert_dim, bert_dim, layers=2, dropout=dropout_prob)

        if "counting" in self.answering_abilities:
            self._counting_index = self.answering_abilities.index("counting")
            self._count_number_predictor = \
                self.ff(bert_dim, bert_dim, max_count + 1) 

        self._drop_metrics = DropEmAndF1()
        initializer(self)
Example #20
 def __init__(self, model_path):
     self.tokenizer = BertTokenizer.from_pretrained(model_path)
     self.model = BertModel.from_pretrained(model_path)
     self.model.cuda('cuda')
     self.model.eval()
Example #21
File: test.py  Project: Nested-NER/TCSF
import os
import pickle
from Evaluate import Evaluate
from config import config
from model import TOI_BERT
from pytorch_pretrained_bert import BertModel

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

mode = "test"  # test dev
test_best = True
epoch_start = 1
epoch_end = 100

misc_config = pickle.load(open(config.get_pkl_path("config"), "rb"))
config.load_config(misc_config)
bert_model = BertModel.from_pretrained(
    f"{config.bert_path}{config.bert_config}")
bert_model.cuda()

bert_model.eval()

model_path = config.get_model_path() + "f1_0.771.pth"

with open(config.get_pkl_path(mode), "rb") as f:
    word_batches, char_batches, char_len_batches, pos_tag_batches, entity_batches, toi_batches, word_origin_batches = pickle.load(
        f)
print("load data from " + config.get_pkl_path(mode))

#print(model_path)
if not os.path.exists(model_path):
    print("loda model error")
print("load model from " + model_path)
Example #22
 def __init__(self, utt_size):
     super().__init__()
     self.bert = BertModel.from_pretrained('bert-base-uncased')
     self.linear = nn.Linear(768, utt_size)
Example #23
print(len(tgt_vocab))
trainloader = dataloader.get_loader(trainset, batch_size=config.batch_size, shuffle=True, num_workers=2)
# consider batch_size=1 for valid/test
validloader = dataloader.get_loader(validset, batch_size=1, shuffle=False, num_workers=2)
testloader = dataloader.get_loader(testset, batch_size=1, shuffle=False, num_workers=2)

if opt.pretrain:
    pretrain_embed={}
    pretrain_embed['slot'] = torch.load('emb_tgt_mw.pt')
    
else:
    pretrain_embed = None

# model
print('building model...\n')
bmodel = BertModel.from_pretrained(bert_type)
bmodel.eval()
if use_cuda:
    with torch.no_grad():
        bmodel.cuda()
model = getattr(models, opt.model)(config, src_vocab, tgt_vocab, use_cuda, bmodel,
                                   pretrain=pretrain_embed, score_fn=opt.score)


        
# load checkpoint - (continue training)
if opt.restore:
    model.load_state_dict(checkpoints['model'])
if use_cuda:
    model.cuda()
#This version does not support distributed/parallel training.
Example #24
#Experiment script to get mean vectors, attended vectors and attentions
import pickle
import torch
import tqdm
from pytorch_pretrained_bert import BertModel, BertTokenizer

import dataset
import withaugmented_w_attn as waug

if __name__ == "__main__":
    dev = 'cuda'
    tokenizer = BertTokenizer.from_pretrained('bert-large-cased',
                                              do_lower_case=False)
    bert = BertModel.from_pretrained('bert-large-cased').to(dev)
    attn = waug.MultiAttention(256, True).to('cuda')
    ds = dataset.DS_Augmented('./data/gap-coreference/gap-validation.tsv',
                              tokenizer, 0)

    attn.load_state_dict(
        torch.load('./experiments/large-19-dpr-attn/best_attn.pt'))

    generated = []
    mean = []
    attentions = []
    for elem in tqdm.tqdm(ds):
        sent, spans, y = elem
        sent = sent[None].cuda()
        spans = spans[None].cuda()

        with torch.no_grad():
            embd, _ = bert(sent)
Example #25
 def __init__(self, config):
     super(Model, self).__init__()
     self.bert = BertModel.from_pretrained(config.bert_path)
     for param in self.bert.parameters():
         param.requires_grad = True
     self.fc = nn.Linear(config.hidden_size, config.num_classes)
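
 # Hedged sketch of the matching forward pass (not in the original snippet);
 # assumes the input is a (token_ids, attention_mask) pair and that
 # config.hidden_size equals BERT's hidden size (768 for the base models).
 def forward(self, x):
     context, mask = x
     _, pooled = self.bert(context, attention_mask=mask,
                           output_all_encoded_layers=False)
     return self.fc(pooled)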
Example #26
meanings = pickle.load(open(meanings_path, 'rb'))

naf_pos = set(arguments['--naf_pos'].split('-'))
use_pos_in_candidate_selection = arguments[
    '--use_pos_in_candidate_selection'] == 'yes'

# iterable
if arguments['--input_folder']:
    naf_iterable = glob('%s/*naf' % arguments["--input_folder"])
elif arguments['--input_path']:
    naf_iterable = [arguments['--input_path']]

# load Bert
tokenizer = BertTokenizer.from_pretrained(bert_model_variation,
                                          do_lower_case=True)
model = BertModel.from_pretrained(bert_model_variation)

model.cuda()
model.eval()

for naf_path in naf_iterable:

    output_path = naf_path + '.wsd'

    start_time = time_in_correct_format(datetime.now())

    doc = etree.parse(naf_path)
    naf_obj = wsd_datasets_classes.NAF(
        doc,
        naf_pos_to_consider=naf_pos,
        use_pos_in_candidate_selection=use_pos_in_candidate_selection,
Example #27
class imdb_dataset(Dataset):
    df = pd.Series(list(range(0, 25000))).sample(frac=1, random_state=1432)
    i2file = dict(
        zip(df.tolist()[:12500], os.listdir('./aclImdb/aclImdb/test/pos')))
    i2file.update(
        dict(zip(df.tolist()[12500:],
                 os.listdir('./aclImdb/aclImdb/test/neg'))))
    i2class = dict(zip(df.tolist()[:12500], ['pos'] * 12500))
    i2class.update(dict(zip(df.tolist()[12500:], ['neg'] * 12500)))
    path = Path('./aclImdb/aclImdb/test')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    bert_model.eval()
    sn_model = Model()

    def __init__(self):
        pass

    def __len__(self):
        return 25000

    def __getitem__(self, idx):
        path = imdb_dataset.path / imdb_dataset.i2class[
            idx] / imdb_dataset.i2file[idx]
        review = open(str(path.absolute()), 'r', encoding='utf-8').readlines()
        #sn_model = Model()
        sn_vect = torch.tensor(imdb_dataset.sn_model.transform(review))

        max_words = 100
        tokenized_text_list = tokenize.sent_tokenize(review[0])
        if len(tokenized_text_list) > 25:
            no_of_sents = 25
            tokenized_text_list = tokenized_text_list[:25]
        else:
            no_of_sents = len(tokenized_text_list)
        word_list = []
        level_list = []
        #print (tokenized_text_list)
        for sent_index, tokenized_text in enumerate(tokenized_text_list):
            # word_dict[mode][cls][file_index][sent_index] = {}

            ## If the sentence is too long for the BERT model, truncate it
            if (len(tokenized_text) > 100):
                tokenized_text = tokenized_text[:100]

            # Convert token to vocabulary indices
            tokenized_text = imdb_dataset.tokenizer.tokenize(tokenized_text)
            indexed_tokens = imdb_dataset.tokenizer.convert_tokens_to_ids(
                tokenized_text)
            segments_ids = [0 for i in range(len(indexed_tokens))]
            if len(tokenized_text) != 0:
                word_list.append(len(tokenized_text))
            else:
                word_list.append(1)

            # Convert inputs to PyTorch tensors
            tokens_tensor = torch.tensor([indexed_tokens])
            segments_tensors = torch.tensor([segments_ids])

            # Predict hidden states features for each layer
            encoded_layers, _ = imdb_dataset.bert_model(
                tokens_tensor,
                segments_tensors,
                output_all_encoded_layers=False)
            encoded_layers = encoded_layers.data.to(device)

            if sent_index == 0:
                bert_review_tensor = encoded_layers
                sent_shape = bert_review_tensor.size()
                a = sent_shape[0]
                b = max_words - sent_shape[1]
                c = sent_shape[2]
                bert_review_tensor = torch.cat(
                    (bert_review_tensor.to(device), torch.zeros(
                        a, b, c).to(device)), 1)
            else:
                bert_review_tensor = imdb_dataset.get_review_tensor(
                    bert_review_tensor, encoded_layers, max_words)
            #****************DO NOT DELETE, MIGHT BE USEFUL FOR WORD LEVEL SENTIMENT NEURON VECTORS*********************
            # We have a hidden states for each of the 12 layers in model bert-base-uncased
            ## creating dictionary for word features
            # pdb.set_trace()
            # split_encoded_feature = torch.split(encoded_layers, 1, dim=1)
            #
            # for index,(raw_text,encoded_feature) in enumerate(zip(tokenized_text,split_encoded_feature)):
            # 	key = str(raw_text)+"_"+str(index)
            # 	value = torch.squeeze(encoded_feature)
            # 	word_dict[mode][cls][file_index][sent_index][key] = value #torch tensoor size[786]
            # ****************DO NOT DELETE, MIGHT BE USEFUL FOR WORD LEVEL SENTIMENT NEURON VECTORS*********************
            del tokens_tensor
            del segments_tensors
            # del split_encoded_feature
            del encoded_layers
            #del value
            del _
            gc.collect()
        if imdb_dataset.i2class[idx] == 'pos':
            #y_label = torch.tensor([0,1])
            level_list.append(1)
        elif imdb_dataset.i2class[idx] == 'neg':
            #y_label = torch.tensor([1,0])
            level_list.append(0)
        max_sents = 25
        word_list = word_list + [1] * (max_sents - no_of_sents)
        bert_review_tensor = imdb_dataset.get_final_bert_tensor(
            bert_review_tensor, max_sents)
        y_label = torch.LongTensor(level_list[:]).squeeze()
        gc.collect()
        #return (sn_vect,bert_review_tensor,no_of_sents,word_list,y_label)#
        return (sn_vect, bert_review_tensor, no_of_sents, word_list, y_label
                )  #

    @staticmethod
    def get_final_bert_tensor(bert_review_tensor, max_sents):
        shape = bert_review_tensor.size()
        a = max_sents - shape[0]
        b = shape[1]
        c = shape[2]
        return torch.cat((bert_review_tensor, torch.zeros(a, b, c)), 0)

    @staticmethod
    def get_review_tensor(review_tensor, sent_tensor, max_words):
        sent_shape = sent_tensor.size()
        a = sent_shape[0]
        b = max_words - sent_shape[1]
        c = sent_shape[2]
        sent_tensor = torch.cat(
            (sent_tensor.to(device), torch.zeros(a, b, c).to(device)), 1)
        return torch.cat((review_tensor, sent_tensor), 0)
Example #28
 def __init__(self, requires_grad=False):
     super().__init__()
     self.bert = BertModel.from_pretrained(bert_path)
     for param in self.bert.parameters():
         param.requires_grad = requires_grad
Example #29
def train(tdl, vdl, args):
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    bert = BertModel.from_pretrained('bert-large-cased').to(dev)
    attn = MultiAttention(256).to(dev)
    if args.pre_attn:
        attn.load_state_dict(torch.load(args.logdir + '/' + args.pre_attn))
    model = withaugmented.Model(dropout=.5).to(dev)
    if args.pre_model:
        model.load_state_dict(torch.load(args.logdir + '/' + args.pre_model))
    loss_fn = nn.CrossEntropyLoss()  #torch.tensor([0.15, 0.15, .7]).to(dev))
    # optim = torch.optim.SGD(model.parameters(), lr=0.001, weight_decay=.001)
    optim = torch.optim.Adam(model.parameters(), weight_decay=0.001)
    if args.pre_opt:
        optim.load_state_dict(torch.load(args.logdir + '/' + args.pre_opt))
    writer = SummaryWriter(args.logdir + '/Log')

    best_val_loss = 100
    last_updated = 0
    total_loss = 0
    for epoch in tqdm.trange(100):
        model.train()
        for elem in tqdm.tqdm(tdl, "Epoch %d" % epoch):
            sent, spans, y = [e.to(dev) for e in elem]
            with torch.no_grad():
                embd, _ = bert(sent)
            embd = embd[args.bert_layer]
            a_mask, b_mask = generate_masks(embd, spans, dev)
            p = embd[range(embd.size(0)), spans[:, 4]]
            a, b = attn(embd, p, a_mask, b_mask)

            optim.zero_grad()
            yp = model(a, b, p)
            loss = loss_fn(yp, y)
            total_loss += float(loss)
            loss.backward()
            optim.step()
        total_loss /= len(tdl)
        writer.add_scalar('train/loss', total_loss, epoch)

        model.eval()
        accu = 0.
        totl = 0
        loss = 0.  # reset here: accumulate validation loss, not the leftover training loss
        with torch.no_grad():
            for elem in vdl:
                sent, spans, y = [e.to(dev) for e in elem]
                with torch.no_grad():
                    embd, _ = bert(sent)
                embd = embd[args.bert_layer]
                a_mask, b_mask = generate_masks(embd, spans, dev)
                p = embd[range(embd.size(0)), spans[:, 4]]
                a, b = attn(embd, p, a_mask, b_mask)

                yp = model(a, b, p)
                l = F.cross_entropy(yp, y, reduction='none')
                yp = torch.softmax(yp, -1).argmax(-1)
                loss += float(l.sum())
                accu += float((yp == y).sum())
                totl += len(y)
            accu = (1. * accu) / totl
            loss /= totl
            writer.add_scalar('valid/loss', float(loss), epoch)
            writer.add_scalar('valid/accuracy', float(accu), epoch)
            best = ' '
            if best_val_loss > loss:
                last_updated = 0
                best_val_loss = loss
                torch.save(model.state_dict(),
                           args.logdir + '/best_model-%.4f.pt' % best_val_loss)
                torch.save(optim.state_dict(),
                           args.logdir + '/best_optim-%.4f.pt' % best_val_loss)
                torch.save(attn.state_dict(),
                           args.logdir + '/best_attn-%.4f.pt' % best_val_loss)
                best = '*'
            else:
                last_updated += 1
                if last_updated > args.patience:
                    return model
            tqdm.tqdm.write("Epoch: %d, Tr.loss: %.4f  Vl.loss: %.4f %s" %
                            (epoch, total_loss, loss, best))
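One detail in the training loop above worth isolating is the indexing expression embd[range(embd.size(0)), spans[:, 4]], which picks one token vector per batch element (presumably the target/pronoun position stored in column 4 of spans). A self-contained sketch of the same pattern with dummy tensors:

import torch

batch, seq_len, hidden = 4, 10, 16
embd = torch.randn(batch, seq_len, hidden)      # stand-in for a BERT layer output
positions = torch.tensor([2, 7, 0, 5])          # one target token index per example

# Pairing a batch-index vector with a position vector selects shape (batch, hidden).
picked = embd[torch.arange(batch), positions]   # equivalent to embd[range(batch), positions]
print(picked.shape)                             # torch.Size([4, 16])
assert torch.equal(picked[1], embd[1, 7])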
示例#30
0
 def __init__(self, config):
     BertPreTrainedModel.__init__(self, config)
     self.bert = BertModel(config)
     self.apply(self.init_bert_weights)
示例#31
0
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
print(tokenizer.convert_tokens_to_ids('hello'))  # [1044, 1041, 1048, 1048, 1051]
print(tokenizer.convert_tokens_to_ids(['hello']))  # [7592]
print(tokenizer.convert_tokens_to_ids(['[hello]']))  # KeyError: '[hello]'; convert_tokens_to_ids cannot handle OOV tokens
print(indexed_tokens)  # [101, 2040, 2001, 3958, 27227, 1029, 102, 3958, 103, 2001, 1037, 13997, 11510, 102]
## Define sentence A and B indices associated with the 1st and 2nd sentences (see paper)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]  # tokenized_text is split into two sentences: the first 7 tokens and the last 7 tokens

##################################################################
## BertModel
## Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens]); print(tokens_tensor.shape)  # torch.Size([1, 14])
segments_tensors = torch.tensor([segments_ids])

## Load pre-trained model (weights)
model = BertModel.from_pretrained(home + '/datasets/WordVec/pytorch_pretrained_bert/bert-large-uncased/')
model.eval()

## Predict hidden states features for each layer
print(tokens_tensor.shape)  # torch.Size([1, 14])
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, segments_tensors)
## We have hidden states for each of the 24 layers in model bert-large-uncased
print(len(encoded_layers))  # 24
print(encoded_layers[0].shape)  # torch.Size([1, 14, 1024])
x = torch.LongTensor([[1, 2], [3, 4]]); print(x.shape)  # torch.Size([2, 2])
print(model)  # print the model architecture

##################################################################
## BertForMaskedLM
model = BertForMaskedLM.from_pretrained('/Users/coder352/datasets/WordVec/pytorch_pretrained_bert/bert-large-uncased/')
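The snippet stops right after loading BertForMaskedLM. The usual next step, mirroring the pytorch_pretrained_bert README and reusing model, tokens_tensor, segments_tensors, and tokenizer from above, would be roughly the following (masked_index = 8 is read off the printed indexed_tokens, where the [MASK] id 103 sits at position 8):

import torch

masked_index = 8  # position of [MASK] (id 103) in indexed_tokens printed above
model.eval()
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)  # [1, seq_len, vocab_size]

predicted_index = torch.argmax(predictions[0, masked_index]).item()
print(tokenizer.convert_ids_to_tokens([predicted_index])[0])  # should come out as 'henson' for this sentence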
示例#32
0
def main():
    parser = argparse.ArgumentParser(
        description='Tuning with bi-directional Tree-LSTM-CRF')
    parser.add_argument('--model_mode',
                        choices=[
                            'elmo', 'elmo_crf', 'elmo_bicrf', 'elmo_lveg',
                            'bert', 'elmo_la'
                        ])
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Batch size')
    parser.add_argument('--epoch', type=int, default=50, help='Number of training epochs')
    parser.add_argument(
        '--optim_method',
        choices=['SGD', 'Adadelta', 'Adagrad', 'Adam', 'RMSprop'],
        help='Optimization method')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.01,
                        help='Learning rate')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        help='momentum factor')
    parser.add_argument('--decay_rate',
                        type=float,
                        default=0.1,
                        help='Decay rate of learning rate')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.0,
                        help='weight for regularization')
    parser.add_argument('--schedule',
                        type=int,
                        default=5,
                        help='schedule for learning rate decay')

    parser.add_argument(
        '--embedding',
        choices=['glove', 'senna', 'sskip', 'polyglot', 'random'],
        help='Embedding for words',
        required=True)
    parser.add_argument('--embedding_path', help='path for embedding dict')
    parser.add_argument('--train', type=str, default='/path/to/SST/train.txt')
    parser.add_argument('--dev', type=str, default='/path/to/SST/dev.txt')
    parser.add_argument('--test', type=str, default='/path/to/SST/test.txt')
    parser.add_argument('--num_labels', type=int, default=5)
    parser.add_argument('--embedding_p',
                        type=float,
                        default=0.5,
                        help="Dropout prob for embedding")
    parser.add_argument(
        '--component_num',
        type=int,
        default=1,
        help='number of Gaussian mixture components in LVeG')
    parser.add_argument('--gaussian_dim',
                        type=int,
                        default=1,
                        help='Gaussian dimension in LVeG')
    parser.add_argument('--tensorboard', action='store_true')
    parser.add_argument('--td_name',
                        type=str,
                        default='default',
                        help='the name of this test')
    parser.add_argument('--td_dir', type=str, required=True)
    parser.add_argument(
        '--elmo_weight',
        type=str,
        default='/path/to/elmo/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5')
    parser.add_argument(
        '--elmo_config',
        type=str,
        default='/path/to/elmo/elmo_2x4096_512_2048cnn_2xhighway_options.json')
    parser.add_argument('--elmo_input', action='store_true')
    parser.add_argument('--elmo_output', action='store_true')
    parser.add_argument('--elmo_preencoder_dim', type=str, default='300')
    parser.add_argument('--elmo_preencoder_p', type=str, default='0.25')
    parser.add_argument('--elmo_encoder_dim', type=int, default=300)
    parser.add_argument('--elmo_integrtator_dim', type=int, default=300)
    parser.add_argument('--elmo_integrtator_p', type=float, default=0.1)
    parser.add_argument('--elmo_output_dim', type=str, default='1200,600')
    parser.add_argument('--elmo_output_p', type=str, default='0.2,0.3,0.0')
    parser.add_argument('--elmo_output_pool_size', type=int, default=4)
    parser.add_argument('--bert_pred_dropout', type=float, default=0.1)
    parser.add_argument('--bert_dir', type=str, default='path/to/bert/')
    parser.add_argument('--bert_model',
                        choices=[
                            'bert-base-uncased', 'bert-large-uncased',
                            'bert-base-cased', 'bert-large-cased'
                        ])
    parser.add_argument('--random_seed', type=int, default=48)
    # args.lower is read when building the BERT tokenizer below, so expose it here
    parser.add_argument('--lower',
                        action='store_true',
                        help='lowercase inputs (use with uncased BERT models)')
    parser.add_argument('--pcfg_init',
                        action='store_true',
                        help='initialize the CRF or LVeG weights according to '
                        'the label distribution of the training dataset')
    parser.add_argument('--save_model', action='store_true', help='save_model')
    parser.add_argument('--load_model', action='store_true', help='load_model')
    parser.add_argument('--model_path', default='./model/')
    parser.add_argument('--model_name', default=None)

    # load tree
    args = parser.parse_args()
    print(args)
    logger = get_logger("SSTLogger")

    # set random seed
    random_seed = args.random_seed
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    myrandom = Random(random_seed)

    batch_size = args.batch_size
    embedd_mode = args.embedding
    model_mode = args.model_mode
    num_labels = args.num_labels

    elmo = model_mode.find('elmo') != -1
    bert = model_mode.find('bert') != -1

    elmo_weight = args.elmo_weight
    elmo_config = args.elmo_config

    load_model = args.load_model
    save_model = args.save_model
    model_path = args.model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    model_name = args.model_name
    if save_model:
        model_name = model_path + '/' + model_mode + datetime.datetime.now(
        ).strftime("%H%M%S")
    if load_model:
        model_name = model_path + '/' + model_name

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    all_cite_version = [
        'fine_phase', 'fine_sents', 'bin_phase', 'bin_sents', 'bin_phase_v2',
        'bin_sents_v2', 'full_bin_phase', 'full_bin_phase_v2'
    ]

    if args.tensorboard:
        summary_writer = SummaryWriter(log_dir=args.td_dir + '/' +
                                       args.td_name)
        summary_writer.add_text('parameters', str(args))
    else:
        summary_writer = None

    def add_scalar_summary(summary_writer, name, value, step):
        if summary_writer is None:
            return
        if torch.is_tensor(value):
            value = value.item()
        summary_writer.add_scalar(tag=name,
                                  scalar_value=value,
                                  global_step=step)

    # ELMO PART
    # allennlp prepare part
    # build Vocabulary
    if elmo:
        elmo_model = Elmo(elmo_config,
                          elmo_weight,
                          1,
                          requires_grad=False,
                          dropout=0.0).to(device)
        token_indexers = {
            'tokens': SingleIdTokenIndexer(),
            'elmo': ELMoTokenCharactersIndexer()
        }
        train_reader = StanfordSentimentTreeBankDatasetReader(
            token_indexers=token_indexers, use_subtrees=False)
        dev_reader = StanfordSentimentTreeBankDatasetReader(
            token_indexers=token_indexers, use_subtrees=False)

        allen_train_dataset = train_reader.read(args.train)
        allen_dev_dataset = dev_reader.read(args.dev)
        allen_test_dataset = dev_reader.read(args.test)
        allen_vocab = Vocabulary.from_instances(
            allen_train_dataset + allen_dev_dataset + allen_test_dataset,
            min_count={'tokens': 1})
        # Build Embedding Layer
        if embedd_mode != 'random':
            params = Params({
                'embedding_dim': 300,
                'pretrained_file': args.embedding_path,
                'trainable': False
            })

            embedding = Embedding.from_params(allen_vocab, params)
            embedder = BasicTextFieldEmbedder({'tokens': embedding})
        else:
            # note: random-embedding initialization is not handled here
            embedder = None
    else:
        elmo_model = None
        token_indexers = None
        embedder = None
        allen_vocab = None

    if bert:
        bert_path = args.bert_dir + '/' + args.bert_model
        bert_model = BertModel.from_pretrained(bert_path +
                                               '.tar.gz').to(device)
        if bert_path.find('large') != -1:
            bert_dim = 1024
        else:
            bert_dim = 768
        for parameter in bert_model.parameters():
            parameter.requires_grad = False
        bert_model.eval()
        # vocab file assumed to sit next to the .tar.gz checkpoint
        bert_tokenizer = BertTokenizer.from_pretrained(
            bert_path + '.txt', do_lower_case=args.lower)
    else:
        bert_model = None
        bert_tokenizer = None
        bert_dim = 768

    logger.info("constructing network...")

    # alphabet
    word_alphabet = Alphabet('word', default_value=True)
    # Read data
    logger.info("Reading Data")

    train_dataset = read_sst_data(args.train,
                                  word_alphabet,
                                  random=myrandom,
                                  merge=True)
    dev_dataset = read_sst_data(args.dev,
                                word_alphabet,
                                random=myrandom,
                                merge=True)
    test_dataset = read_sst_data(args.test,
                                 word_alphabet,
                                 random=myrandom,
                                 merge=True)
    word_alphabet.close()

    if num_labels == 3:
        train_dataset.convert_to_3_class()
        dev_dataset.convert_to_3_class()
        test_dataset.convert_to_3_class()

    # PCFG init
    if args.pcfg_init and (str.lower(model_mode).find('crf') != -1
                           or str.lower(model_mode).find('lveg') != -1):
        if str.lower(model_mode).find('bicrf') != -1 or str.lower(
                model_mode).find('lveg') != -1:
            dim = 3
        else:
            dim = 2
        trans_matrix = train_dataset.collect_rule_count(dim,
                                                        num_labels,
                                                        smooth=True)
    else:
        trans_matrix = None

    pre_encode_dim = [int(dim) for dim in args.elmo_preencoder_dim.split(',')]
    pre_encode_layer_dropout_prob = [
        float(prob) for prob in args.elmo_preencoder_p.split(',')
    ]
    output_dim = [int(dim)
                  for dim in args.elmo_output_dim.split(',')] + [num_labels]
    output_dropout = [float(prob) for prob in args.elmo_output_p.split(',')]

    if model_mode == 'elmo':
        # FIXME: word_dim is hard-coded to 300
        network = Biattentive(
            vocab=allen_vocab,
            embedder=embedder,
            embedding_dropout_prob=args.embedding_p,
            word_dim=300,
            use_input_elmo=args.elmo_input,
            pre_encode_dim=pre_encode_dim,
            pre_encode_layer_dropout_prob=pre_encode_layer_dropout_prob,
            encode_output_dim=args.elmo_encoder_dim,
            integrtator_output_dim=args.elmo_integrtator_dim,
            integrtator_dropout=args.elmo_integrtator_p,
            use_integrator_output_elmo=args.elmo_output,
            output_dim=output_dim,
            output_pool_size=args.elmo_output_pool_size,
            output_dropout=output_dropout,
            elmo=elmo_model,
            token_indexer=token_indexers,
            device=device).to(device)
    elif model_mode == 'elmo_crf':
        network = CRFBiattentive(
            vocab=allen_vocab,
            embedder=embedder,
            embedding_dropout_prob=args.embedding_p,
            word_dim=300,
            use_input_elmo=args.elmo_input,
            pre_encode_dim=pre_encode_dim,
            pre_encode_layer_dropout_prob=pre_encode_layer_dropout_prob,
            encode_output_dim=args.elmo_encoder_dim,
            integrtator_output_dim=args.elmo_integrtator_dim,
            integrtator_dropout=args.elmo_integrtator_p,
            use_integrator_output_elmo=args.elmo_output,
            output_dim=output_dim,
            output_pool_size=args.elmo_output_pool_size,
            output_dropout=output_dropout,
            elmo=elmo_model,
            token_indexer=token_indexers,
            device=device,
            trans_mat=trans_matrix).to(device)
    elif model_mode == 'elmo_bicrf':
        network = BiCRFBiattentive(
            vocab=allen_vocab,
            embedder=embedder,
            embedding_dropout_prob=args.embedding_p,
            word_dim=300,
            use_input_elmo=args.elmo_input,
            pre_encode_dim=pre_encode_dim,
            pre_encode_layer_dropout_prob=pre_encode_layer_dropout_prob,
            encode_output_dim=args.elmo_encoder_dim,
            integrtator_output_dim=args.elmo_integrtator_dim,
            integrtator_dropout=args.elmo_integrtator_p,
            use_integrator_output_elmo=args.elmo_output,
            output_dim=output_dim,
            output_pool_size=args.elmo_output_pool_size,
            output_dropout=output_dropout,
            elmo=elmo_model,
            token_indexer=token_indexers,
            device=device,
            trans_mat=trans_matrix).to(device)
    elif model_mode == 'elmo_lveg':
        network = LVeGBiattentive(
            vocab=allen_vocab,
            embedder=embedder,
            embedding_dropout_prob=args.embedding_p,
            word_dim=300,
            use_input_elmo=args.elmo_input,
            pre_encode_dim=pre_encode_dim,
            pre_encode_layer_dropout_prob=pre_encode_layer_dropout_prob,
            encode_output_dim=args.elmo_encoder_dim,
            integrtator_output_dim=args.elmo_integrtator_dim,
            integrtator_dropout=args.elmo_integrtator_p,
            use_integrator_output_elmo=args.elmo_output,
            output_dim=output_dim,
            output_pool_size=args.elmo_output_pool_size,
            output_dropout=output_dropout,
            elmo=elmo_model,
            token_indexer=token_indexers,
            device=device,
            gaussian_dim=args.gaussian_dim,
            component_num=args.component_num,
            trans_mat=trans_matrix).to(device)
    elif model_mode == 'elmo_la':
        network = LABiattentive(
            vocab=allen_vocab,
            embedder=embedder,
            embedding_dropout_prob=args.embedding_p,
            word_dim=300,
            use_input_elmo=args.elmo_input,
            pre_encode_dim=pre_encode_dim,
            pre_encode_layer_dropout_prob=pre_encode_layer_dropout_prob,
            encode_output_dim=args.elmo_encoder_dim,
            integrtator_output_dim=args.elmo_integrtator_dim,
            integrtator_dropout=args.elmo_integrtator_p,
            use_integrator_output_elmo=args.elmo_output,
            output_dim=output_dim,
            output_pool_size=args.elmo_output_pool_size,
            output_dropout=output_dropout,
            elmo=elmo_model,
            token_indexer=token_indexers,
            device=device,
            comp=args.component_num,
            trans_mat=trans_matrix).to(device)
    elif model_mode == 'bert':
        # note: this should be a binary (2-class) setup; test the original model first
        network = BertClassification(tokenizer=bert_tokenizer,
                                     pred_dim=bert_dim,
                                     pred_dropout=args.bert_pred_dropout,
                                     bert=bert_model,
                                     num_labels=num_labels,
                                     device=device)
    else:
        raise NotImplementedError

    if load_model:
        logger.info('Load model from:' + model_name)
        network.load_state_dict(torch.load(model_name))

    optim_method = args.optim_method
    learning_rate = args.learning_rate
    lr = learning_rate
    momentum = args.momentum
    decay_rate = args.decay_rate
    gamma = args.gamma
    schedule = args.schedule

    # optim init
    if optim_method == 'SGD':
        optimizer = optim.SGD(network.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=gamma,
                              nesterov=True)
    elif optim_method == 'Adam':
        # default lr is 0.001
        optimizer = optim.Adam(network.parameters(), lr=lr, weight_decay=gamma)
    elif optim_method == 'Adadelta':
        # default lr is 1.0
        optimizer = optim.Adadelta(network.parameters(),
                                   lr=lr,
                                   weight_decay=gamma)
    elif optim_method == 'Adagrad':
        # default lr is 0.01
        optimizer = optim.Adagrad(network.parameters(),
                                  lr=lr,
                                  weight_decay=gamma)
    elif optim_method == 'RMSprop':
        # default lr is 0.01
        optimizer = optim.RMSprop(network.parameters(),
                                  lr=lr,
                                  weight_decay=gamma,
                                  momentum=momentum)
    else:
        raise NotImplementedError("Not Implement optim Method: " +
                                  optim_method)
    logger.info("Optim mode: " + optim_method)

    # dev and test
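    # dev_correct keeps the best dev count seen so far for each metric key;
    # test_correct[key] caches the test counts measured at the epoch where that
    # dev metric last improved, and best_epoch[key] records which epoch that was.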
    dev_correct = {
        'fine_phase': 0.0,
        'fine_sents': 0.0,
        'bin_phase': 0.0,
        'bin_sents': 0.0,
        'bin_phase_v2': 0.0,
        'bin_sents_v2': 0.0,
        'full_bin_phase': 0.0,
        'full_bin_phase_v2': 0.0
    }
    best_epoch = {
        'fine_phase': 0,
        'fine_sents': 0,
        'bin_phase': 0,
        'bin_sents': 0,
        'bin_phase_v2': 0,
        'bin_sents_v2': 0,
        'full_bin_phase': 0,
        'full_bin_phase_v2': 0
    }
    test_correct = {}
    for key in all_cite_version:
        test_correct[key] = {
            'fine_phase': 0.0,
            'fine_sents': 0.0,
            'bin_phase': 0.0,
            'bin_sents': 0.0,
            'bin_phase_v2': 0.0,
            'bin_sents_v2': 0.0,
            'full_bin_phase': 0.0,
            'full_bin_phase_v2': 0.0
        }
    test_total = {
        'fine_phase': 0.0,
        'fine_sents': 0.0,
        'bin_phase': 0.0,
        'bin_sents': 0.0,
        'full_bin_phase': 0.0
    }

    def log_print(name, fine_phase_acc, fine_sents_acc, bin_sents_acc,
                  bin_phase_v2_acc):
        print(
            name +
            ' phase acc: %.2f%%, sents acc: %.2f%%, binary sents acc: %.2f%%, binary phase acc: %.2f%%,'
            %
            (fine_phase_acc, fine_sents_acc, bin_sents_acc, bin_phase_v2_acc))

    for epoch in range(1, args.epoch + 1):
        train_dataset.shuffle()
        print(
            'Epoch %d (optim_method=%s, learning rate=%.4f, decay rate=%.4f (schedule=%d)): '
            % (epoch, optim_method, lr, decay_rate, schedule))
        time.sleep(1)
        start_time = time.time()
        train_err = 0.0
        train_p_total = 0.0

        network.train()
        optimizer.zero_grad()
        forest = []
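        # Manual gradient accumulation: loss.backward() runs once per tree, but
        # optimizer.step() / zero_grad() only fire every `batch_size` trees.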
        for i in tqdm(range(len(train_dataset))):
            tree = train_dataset[i]
            forest.append(tree)
            output_dict = network.loss(tree)
            loss = output_dict['loss']
            a_tree_p_cnt = 2 * tree.length - 1
            loss.backward()

            train_err += loss.item()
            train_p_total += a_tree_p_cnt
            if i % batch_size == 0 and i != 0:
                optimizer.step()
                optimizer.zero_grad()
                for learned_tree in forest:
                    learned_tree.clean()
                forest = []

        optimizer.step()
        optimizer.zero_grad()
        for learned_tree in forest:
            learned_tree.clean()
        train_time = time.time() - start_time

        time.sleep(0.5)

        logger.info(
            'train: %d/%d loss: %.4f, time used : %.2fs' %
            (epoch, args.epoch, train_err / len(train_dataset), train_time))

        add_scalar_summary(summary_writer, 'train/loss',
                           train_err / len(train_dataset), epoch)

        if save_model:
            logger.info('Save model to ' + model_name + '_' + str(epoch))
            torch.save(network.state_dict(), model_name + '_' + str(epoch))

        network.eval()
        dev_corr = {
            'fine_phase': 0.0,
            'fine_sents': 0.0,
            'bin_phase': 0.0,
            'bin_sents': 0.0,
            'bin_phase_v2': 0.0,
            'bin_sents_v2': 0.0,
            'full_bin_phase': 0.0,
            'full_bin_phase_v2': 0.0
        }
        dev_tot = {
            'fine_phase': 0.0,
            'fine_sents': float(len(dev_dataset)),
            'bin_phase': 0.0,
            'bin_sents': 0.0,
            'bin_phase_v2': 0.0,
            'bin_sents_v2': 0.0,
            'full_bin_phase': 0.0,
            'full_bin_phase_v2': 0.0
        }
        final_test_corr = {
            'fine_phase': 0.0,
            'fine_sents': 0.0,
            'bin_phase': 0.0,
            'bin_sents': 0.0,
            'bin_phase_v2': 0.0,
            'bin_sents_v2': 0.0,
            'full_bin_phase': 0.0,
            'full_bin_phase_v2': 0.0
        }
        for i in tqdm(range(len(dev_dataset))):
            tree = dev_dataset[i]
            output_dict = network.predict(tree)
            p_corr, preds, bin_corr, bin_preds, bin_mask = output_dict['corr'], output_dict['label'], \
                                                           output_dict['binary_corr'], output_dict['binary_pred'], \
                                                           output_dict['binary_mask']

            dev_tot['fine_phase'] += preds.size

            dev_corr['fine_phase'] += p_corr.sum()
            dev_corr['fine_sents'] += p_corr[-1]
            dev_corr['full_bin_phase'] += bin_corr[0].sum()

            if len(bin_corr) == 2:
                dev_corr['full_bin_phase_v2'] += bin_corr[1].sum()
            else:
                dev_corr['full_bin_phase_v2'] = dev_corr['full_bin_phase']
            dev_tot['full_bin_phase'] += bin_mask.sum()

            if tree.label != int(num_labels / 2):
                dev_corr['bin_phase'] += bin_corr[0].sum()
                dev_tot['bin_phase'] += bin_mask.sum()
                dev_corr['bin_sents'] += bin_corr[0][-1]
                if len(bin_corr) == 2:
                    dev_corr['bin_phase_v2'] += bin_corr[1].sum()
                    dev_corr['bin_sents_v2'] += bin_corr[1][-1]
                else:
                    dev_corr['bin_phase_v2'] = dev_corr['bin_phase']
                    dev_corr['bin_sents_v2'] = dev_corr['bin_sents']
                dev_tot['bin_sents'] += 1.0

            tree.clean()

        time.sleep(1)

        dev_tot['bin_phase_v2'] = dev_tot['bin_phase']
        dev_tot['bin_sents_v2'] = dev_tot['bin_sents']
        dev_tot['full_bin_phase_v2'] = dev_tot['full_bin_phase']

        for key in all_cite_version:
            add_scalar_summary(summary_writer, 'dev/' + key,
                               (dev_corr[key] * 100 / dev_tot[key]), epoch)

        log_print('dev', dev_corr['fine_phase'] * 100 / dev_tot['fine_phase'],
                  dev_corr['fine_sents'] * 100 / dev_tot['fine_sents'],
                  dev_corr['bin_sents'] * 100 / dev_tot['bin_sents'],
                  dev_corr['bin_phase_v2'] * 100 / dev_tot['bin_phase'])

        update = []
        for key in all_cite_version:
            if dev_corr[key] > dev_correct[key]:
                update.append(key)

        # if dev_s_corr > dev_s_correct:

        if len(update) > 0:
            for key in update:
                dev_correct[key] = dev_corr[key]
                # update corresponding test dict cache
                test_correct[key] = {
                    'fine_phase': 0.0,
                    'fine_sents': 0.0,
                    'bin_phase': 0.0,
                    'bin_sents': 0.0,
                    'bin_phase_v2': 0.0,
                    'bin_sents_v2': 0.0,
                    'full_bin_phase': 0.0,
                    'full_bin_phase_v2': 0.0
                }
                best_epoch[key] = epoch
            test_total = {
                'fine_phase': 0.0,
                'fine_sents': float(len(test_dataset)),
                'bin_phase': 0.0,
                'bin_sents': 0.0,
                'bin_phase_v2': 0.0,
                'bin_sents_v2': 0.0,
                'full_bin_phase': 0.0,
                'full_bin_phase_v2': 0.0
            }

            time.sleep(1)

            for i in tqdm(range(len(test_dataset))):
                tree = test_dataset[i]
                output_dict = network.predict(tree)
                p_corr, preds, bin_corr, bin_preds, bin_mask = output_dict['corr'], output_dict['label'], \
                                                               output_dict['binary_corr'], output_dict['binary_pred'], \
                                                               output_dict['binary_mask']
                # count total number
                test_total['fine_phase'] += preds.size
                test_total['full_bin_phase'] += bin_mask.sum()
                if tree.label != int(num_labels / 2):
                    test_total['bin_phase'] += bin_mask.sum()
                    test_total['bin_sents'] += 1.0

                for key in update:
                    test_correct[key]['fine_phase'] += p_corr.sum()
                    test_correct[key]['fine_sents'] += p_corr[-1]
                    test_correct[key]['full_bin_phase'] += bin_corr[0].sum()

                    if len(bin_corr) == 2:
                        test_correct[key]['full_bin_phase_v2'] += bin_corr[
                            1].sum()
                    else:
                        test_correct[key]['full_bin_phase_v2'] = test_correct[
                            key]['full_bin_phase']

                    if tree.label != int(num_labels / 2):
                        test_correct[key]['bin_phase'] += bin_corr[0].sum()
                        test_correct[key]['bin_sents'] += bin_corr[0][-1]

                        if len(bin_corr) == 2:
                            test_correct[key]['bin_phase_v2'] += bin_corr[
                                1].sum()
                            test_correct[key]['bin_sents_v2'] += bin_corr[1][
                                -1]
                        else:
                            test_correct[key]['bin_phase_v2'] = test_correct[
                                key]['bin_phase']
                            test_correct[key]['bin_sents_v2'] = test_correct[
                                key]['bin_sents']

                tree.clean()

            time.sleep(1)

            test_total['bin_phase_v2'] = test_total['bin_phase']
            test_total['bin_sents_v2'] = test_total['bin_sents']
            test_total['full_bin_phase_v2'] = test_total['full_bin_phase']

        for key1 in all_cite_version:
            best_score = 0.0
            for key2 in all_cite_version:
                if test_correct[key2][key1] > best_score:
                    best_score = test_correct[key2][key1]
            final_test_corr[key1] = best_score

        for key in all_cite_version:
            add_scalar_summary(summary_writer, 'test/' + key,
                               (final_test_corr[key] * 100 / test_total[key]),
                               epoch)

        log_print(
            'Best ' + str(epoch) + ' Final test_',
            final_test_corr['fine_phase'] * 100 / test_total['fine_phase'],
            final_test_corr['fine_sents'] * 100 / test_total['fine_sents'],
            final_test_corr['bin_sents'] * 100 / test_total['bin_sents'],
            final_test_corr['bin_phase_v2'] * 100 / test_total['bin_phase_v2'])

        if optim_method == "SGD" and epoch % schedule == 0:
            lr = learning_rate / (epoch * decay_rate)
            optimizer = optim.SGD(network.parameters(),
                                  lr=lr,
                                  momentum=momentum,
                                  weight_decay=gamma,
                                  nesterov=True)

    if args.tensorboard:
        summary_writer.close()