Code example #1
File: run_mrc.py Project: lyj555/LICS2021_MRC
def train_and_eval(self, args):
    self._initialize_run_env(args.device, args.seed)
    self._initialize_model(args.model_type, args.pretrained_model_path)
    self.data_helper = DataHelper(self.tokenizer, args.batch_size,
                                  args.doc_stride, args.max_seq_length)
    # start training
    if args.do_train:
        logging.info("start training...")
        self._start_train(args)
        logging.info("train success.")
    # start evaluation
    if args.do_eval:
        logging.info("start evaluating...")
        assert len(args.eval_files) == 1, "if do_eval, eval_files must contain exactly one file!!!"
        eval_file_path = args.eval_files[0]
        self.predict([eval_file_path], args.output_dir, args.max_answer_length,
                     args.cls_threshold, args.n_best_size)
        file_name = os.path.basename(eval_file_path).replace(".json", "")
        pred_file_path = os.path.join(args.output_dir, file_name + '_predictions.json')
        self._evaluate(eval_file_path, pred_file_path, args.tag)
        # confirm threshold
        confirm_threshold(eval_file_path, args.output_dir, file_name)
        logging.info("evaluate success.")
    # start predicting
    if args.do_predict:
        logging.info("start predicting...")
        self.predict(args.predict_files, args.output_dir, args.max_answer_length,
                     args.cls_threshold, args.n_best_size)
        logging.info("predict success.")
Code example #2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

# DataHelper and LSTM are project-local classes (not imported in the original snippet)


def train_lstm():
    batch_size = 100
    num_layers = 3
    num_directions = 2
    embedding_size = 100
    hidden_size = 64
    learning_rate = 0.0001
    num_epochs = 5

    data_helper = DataHelper()
    train_text, train_labels, ver_text, ver_labels, test_text, test_labels = data_helper.get_data_and_labels()
    word_set = data_helper.get_word_set()
    vocab = data_helper.get_word_dict()
    words_length = len(word_set) + 2

    lstm = LSTM(words_length, embedding_size, hidden_size, num_layers, num_directions, batch_size)
    X = [[vocab[word] for word in sentence.split(' ')] for sentence in train_text]
    X_lengths = [len(sentence) for sentence in X]
    pad_token = vocab['<PAD>']
    longest_sent = max(X_lengths)
    b_size = len(X)
    padded_X = np.ones((b_size, longest_sent)) * pad_token
    for i, x_len in enumerate(X_lengths):
        sequence = X[i]
        padded_X[i, 0:x_len] = sequence[:x_len]

    # Variable is a no-op since PyTorch 0.4; plain tensors suffice
    x = torch.tensor(padded_X).long()
    y = torch.tensor([int(i) for i in train_labels])
    dataset = Data.TensorDataset(x, y)
    loader = Data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )

    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.Adam(lstm.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        for step, (batch_x, batch_y) in enumerate(loader):
            output = lstm(batch_x)
            temp = torch.argmax(output, dim=1)
            # compare against the actual batch length: the last batch may be
            # smaller than batch_size
            correct = (temp == batch_y).sum().item()

            loss = loss_func(output, batch_y)
            print('epoch: {0}, step: {1}, loss: {2:.4f}, train acc: {3:.3f}'.format(
                epoch, step, loss.item(), correct / len(batch_y)))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        ver_lstm(lstm, ver_text, ver_labels, vocab, batch_size)
    test_lstm(lstm, test_text, test_labels, vocab, batch_size)
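
Note (not part of the project): the manual padding loop in train_lstm() can be written with torch.nn.utils.rnn.pad_sequence. A minimal sketch, assuming vocab and train_text look like the ones above:

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_batch(train_text, vocab):
    # one 1-D tensor of word ids per sentence
    seqs = [torch.tensor([vocab[w] for w in s.split(' ')]) for s in train_text]
    lengths = [len(s) for s in seqs]
    # batch_first=True gives shape (batch, max_len), matching padded_X above
    padded = pad_sequence(seqs, batch_first=True, padding_value=vocab['<PAD>'])
    return padded.long(), lengths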
Code example #3
def train_svm():
    data_helper = DataHelper()
    train_text, train_labels, ver_text, ver_labels, test_text, test_labels = data_helper.get_data_and_labels()
    stopwords = data_helper.get_stopwords()

    svm = SVM(train_text, train_labels, ver_text, ver_labels, test_text, test_labels, stopwords)

    svm.train()
    svm.verification()
    print('ver_acc: {:.3}'.format(svm.ver_acc))
    svm.test()
    print('test_acc: {:.3}'.format(svm.test_acc))
Code example #4
def run_generating(args):

    # ----------------------------------------------------- #
    log_path = os.path.join('./saved_models/pretrain_generator', 'run.log')

    logger.setLevel(logging.DEBUG)
    handler = logging.FileHandler(log_path, 'w')
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y/%m/%d %H:%M:%S')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    logger.info('args: {}'.format(args))

    # ----------------------------------------------------- #
    # load data & init model and optimizer

    logger.info('Loading data & model')

    config = GPT2Config.from_pretrained(args.generator_type, cache_dir='../cache/')
    datahelper = DataHelper(args)

    path_embedding_file = os.path.join('./path_embeddings/', args.data_dir, 'path_embedding.pickle')

    # self define lm head gpt2
    gpt = GPT2Model.from_pretrained(args.generator_type, cache_dir='../cache/')
    config.vocab_size = len(datahelper.gpt_tokenizer)
    gpt.resize_token_embeddings(len(datahelper.gpt_tokenizer))
    pretrain_generator_ckpt = os.path.join('./saved_models/pretrain_generator', 'model.ckpt')
    generator = Generator(gpt, config, max_len=args.output_len).to(args.device)
    generator.load_state_dict(torch.load(pretrain_generator_ckpt, map_location=args.device))

    save_path_embedding(datahelper, generator, path_embedding_file, args)
    print('Finish.')
Code example #5
from utils.data_helper import DataHelper  # import paths assumed, following the other examples
from models.config import ConvConfig
from models.conv_model import ConvModel

if __name__ == "__main__":
    config = ConvConfig()
    NEG_DATA_PATH = 'data/sentiment/neg.txt'
    POS_DATA_PATH = 'data/sentiment/pos.txt'
    WORD2VEC_PATH = '/data/pretrained_model/word_embedding/glove.6B/glove.6B.%sd.txt' % config.EMBEDDING_DIM
    SAVE_PATH = 'models/models/model_w{}_e{}_c{}.h5'.format(
        config.EMBEDDING_DIM, config.LATENT_DIM, config.CLASS_NUM)

    print(config)
    print("Data Path: ", NEG_DATA_PATH, POS_DATA_PATH)
    print("Word2Vec Path: ", WORD2VEC_PATH)
    print("Save Path: ", SAVE_PATH)

    data_helper = DataHelper(config)

    #### load the data ####
    input_texts, target_texts, target_texts_inputs, classes = data_helper.read_txt_sentiment(
        NEG_DATA_PATH, POS_DATA_PATH)

    #### tokenize the inputs, outputs ####
    input_sequences, word2idx_inputs, max_len_input = \
                         data_helper.create_vocab(input_texts, target_texts, target_texts_inputs)

    #### load word2vec pretrained model ####
    word2vec = data_helper.load_word2vec(WORD2VEC_PATH)

    #### create embedding matrix ####
    embedding_matrix = data_helper.create_embedding_matrix(
        word2vec, word2idx_inputs, WORD2VEC_PATH)
Code example #6
File: run_mrc.py Project: lyj555/LICS2021_MRC
class ModelOperation(object):
    """ModelTrain"""

    def __init__(self):
        self.cur_process_num = paddle.distributed.get_world_size()  # value of PADDLE_TRAINERS_NUM, defaults to 1
        self.cur_process_rank = paddle.distributed.get_rank()  # value of PADDLE_TRAINER_ID, defaults to 0

        self.model_class = {
            "ernie": (ErnieForQuestionAnswering, ErnieTokenizer),
            "bert": (BertForQuestionAnswering, BertTokenizer),
            "roberta": (RobertaForQuestionAnswering, RobertaTokenizer)
        }
        self.data_helper = None

    def _initialize_run_env(self, device, seed):
        assert device in ("cpu", "gpu", "xpu"), \
            f"param device({device}) must be in ('cpu', 'gpu', 'xpu')!!!"
        paddle.set_device(device)
        if self.cur_process_num > 1:
            paddle.distributed.init_parallel_env()
        if seed:
            self.set_seed(seed)

    def _initialize_model(self, model_type, pretrained_model_path):
        assert os.path.exists(pretrained_model_path), \
            f"model path {pretrained_model_path} must exist!!!"
        logging.info(f"initialize model from {pretrained_model_path}")

        model_class, tokenizer_class = self.model_class[model_type]
        self.tokenizer = tokenizer_class.from_pretrained(pretrained_model_path)
        self.model = model_class.from_pretrained(pretrained_model_path)

        if self.cur_process_num > 1:
            self.model = paddle.DataParallel(self.model)

    def _initialize_optimizer(self, args, num_training_steps):
        self.lr_scheduler = LinearDecayWithWarmup(
            args.learning_rate, num_training_steps, args.warmup_proportion)

        self.optimizer = paddle.optimizer.AdamW(
            learning_rate=self.lr_scheduler,
            epsilon=args.adam_epsilon,
            parameters=self.model.parameters(),
            weight_decay=args.weight_decay,
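            # apply weight decay only to parameters whose names contain
            # neither "bias" nor "norm" (the usual no-decay set):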
            apply_decay_param_fun=lambda x: x in [
                p.name for n, p in self.model.named_parameters()
                if not any(nd in n for nd in ["bias", "norm"])
            ])

    def _start_train(self, args):
        # get train data loader
        train_data_loader = self.data_helper.get_iterator(args.train_data_path, shuffle=True)
        num_training_steps = args.max_train_steps if args.max_train_steps > 0 else \
            len(train_data_loader) * args.train_epochs
        logging.info("Num train examples: %d" % len(train_data_loader.dataset.data))
        logging.info("Max train steps: %d" % num_training_steps)
        # initialize optimizer
        self._initialize_optimizer(args, num_training_steps)
        # define loss function
        criterion = CrossEntropyLossForQA()

        global_step = 0
        tic_train = time.time()
        for epoch in range(args.train_epochs):
            for step, batch in enumerate(train_data_loader):
                global_step += 1
                input_ids, segment_ids, start_positions, end_positions, answerable_label = batch

                logits = self.model(input_ids=input_ids, token_type_ids=segment_ids)
                loss = criterion(logits, (start_positions, end_positions, answerable_label))

                if global_step % args.logging_steps == 0:
                    logging.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                        % (global_step, epoch, step, loss,
                           args.logging_steps / (time.time() - tic_train)))
                    tic_train = time.time()
                loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.clear_gradients()

                if global_step % args.save_steps == 0 or global_step == num_training_steps:
                    if self.cur_process_rank == 0:
                        output_dir = \
                            os.path.join(args.output_dir, "model_{}".format(global_step))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # need better way to get inner model of DataParallel
                        model_to_save = \
                            self.model._layers if isinstance(self.model, paddle.DataParallel) else self.model
                        model_to_save.save_pretrained(output_dir)
                        self.tokenizer.save_pretrained(output_dir)
                        logging.info('Saving checkpoint to: %s', output_dir)

    @staticmethod
    def _evaluate(raw_data_path, pred_data_path, tag=None):
        ref_ans = read_mrc_dataset(raw_data_path, tag=tag)
        assert len(ref_ans) > 0, 'Find no sample with tag - {}'.format(tag)
        pred_ans = read_model_prediction(pred_data_path)
        F1, EM, ans_score, TOTAL, SKIP = evaluate(ref_ans, pred_ans, verbose=False)
        print_metrics(F1, EM, ans_score, TOTAL, SKIP, tag)

    def train_and_eval(self, args):
        self._initialize_run_env(args.device, args.seed)
        self._initialize_model(args.model_type, args.pretrained_model_path)
        self.data_helper = DataHelper(self.tokenizer, args.batch_size,
                                      args.doc_stride, args.max_seq_length)
        # start training
        if args.do_train:
            logging.info("start training...")
            self._start_train(args)
            logging.info("train success.")
        # start evaluation
        if args.do_eval:
            logging.info("start evaluating...")
            assert len(args.eval_files) == 1, "if do_eval, eval_files must contain exactly one file!!!"
            eval_file_path = args.eval_files[0]
            self.predict([eval_file_path], args.output_dir, args.max_answer_length,
                         args.cls_threshold, args.n_best_size)
            file_name = os.path.basename(eval_file_path).replace(".json", "")
            pred_file_path = os.path.join(args.output_dir, file_name + '_predictions.json')
            self._evaluate(eval_file_path, pred_file_path, args.tag)
            # confirm threshold
            confirm_threshold(eval_file_path, args.output_dir, file_name)
            logging.info("evaluate success.")
        # start predicting
        if args.do_predict:
            logging.info("start predicting...")
            self.predict(args.predict_files, args.output_dir, args.max_answer_length,
                         args.cls_threshold, args.n_best_size)
            logging.info("predict success.")

    @paddle.no_grad()
    def _predict(self, data_loader, output_dir, max_answer_length, cls_threshold,
                 n_best_size=10, prefix=""):
        self.model.eval()

        all_start_logits, all_end_logits = [], []
        all_cls_logits = []
        tic_eval = time.time()

        for batch in data_loader:
            input_ids, segment_ids = batch
            start_logits_tensor, end_logits_tensor, cls_logits_tensor = \
                self.model(input_ids, segment_ids)

            for idx in range(start_logits_tensor.shape[0]):
                if len(all_start_logits) % 1000 == 0 and len(all_start_logits):
                    logging.info("Processing example: %d" % len(all_start_logits))
                    logging.info('time per 1000: %.2f s', time.time() - tic_eval)
                    tic_eval = time.time()

                all_start_logits.append(start_logits_tensor.numpy()[idx])
                all_end_logits.append(end_logits_tensor.numpy()[idx])
                all_cls_logits.append(cls_logits_tensor.numpy()[idx])

        all_predictions, all_nbest_json, all_cls_predictions = \
            compute_prediction_span(
                examples=data_loader.dataset.data,
                features=data_loader.dataset.new_data,
                predictions=(all_start_logits, all_end_logits, all_cls_logits),
                version_2_with_negative=True,
                n_best_size=n_best_size,
                max_answer_length=max_answer_length,
                cls_threshold=cls_threshold)

        # start save inference result
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        with open(os.path.join(output_dir, prefix + '_predictions.json'), "w", encoding='utf-8') as f:
            f.write(json.dumps(all_predictions, ensure_ascii=False, indent=4) + "\n")

        with open(os.path.join(output_dir, prefix + '_nbest_predictions.json'), "w",
                  encoding="utf8") as f:
            f.write(json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + "\n")

        if all_cls_predictions:
            with open(os.path.join(output_dir, prefix + "_cls_predictions.json"), "w") as f:
                for cls_predictions in all_cls_predictions:
                    qas_id, pred_cls_label, no_answer_prob, answerable_prob = cls_predictions
                    f.write('{}\t{}\t{}\t{}\n'.format(qas_id, pred_cls_label, no_answer_prob, answerable_prob))
        self.model.train()

    def predict(self, predict_files, output_dir, max_answer_length, cls_threshold, n_best_size):
        assert predict_files is not None, "param predict_files should be set when predicting!"
        input_files = []
        for input_pattern in predict_files:
            input_files.extend(glob.glob(input_pattern))
        assert len(input_files) > 0, 'Can not find predict file in {}'.format(predict_files)
        for input_file in input_files:
            file_name = os.path.basename(input_file).replace(".json", "")
            data_loader = \
                self.data_helper.get_iterator(input_file, part_feature=True)  # no need extract position info
            self._predict(data_loader, output_dir, max_answer_length,
                          cls_threshold, n_best_size, prefix=file_name)

    @staticmethod
    def set_seed(random_seed):
        random.seed(random_seed)
        np.random.seed(random_seed)
        paddle.seed(random_seed)
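
A hypothetical driver (not part of the project) showing how ModelOperation could be wired to the command line; the flag names mirror the attributes that train_and_eval() reads, and every default below is an illustrative assumption:

import argparse

def build_args():
    parser = argparse.ArgumentParser()
    # environment / model
    parser.add_argument("--device", default="gpu", choices=["cpu", "gpu", "xpu"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--model_type", default="ernie", choices=["ernie", "bert", "roberta"])
    parser.add_argument("--pretrained_model_path", required=True)
    # data
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--doc_stride", type=int, default=128)
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument("--train_data_path")
    parser.add_argument("--output_dir", default="./output")
    # optimization
    parser.add_argument("--train_epochs", type=int, default=2)
    parser.add_argument("--max_train_steps", type=int, default=-1)  # <=0: derive from epochs
    parser.add_argument("--learning_rate", type=float, default=3e-5)
    parser.add_argument("--warmup_proportion", type=float, default=0.1)
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=1000)
    # phases
    parser.add_argument("--do_train", action="store_true")
    parser.add_argument("--do_eval", action="store_true")
    parser.add_argument("--do_predict", action="store_true")
    parser.add_argument("--eval_files", nargs="*")
    parser.add_argument("--predict_files", nargs="*")
    parser.add_argument("--max_answer_length", type=int, default=64)
    parser.add_argument("--cls_threshold", type=float, default=0.5)
    parser.add_argument("--n_best_size", type=int, default=10)
    parser.add_argument("--tag", default=None)
    return parser.parse_args()

if __name__ == "__main__":
    ModelOperation().train_and_eval(build_args())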
Code example #7
    config = configparser.ConfigParser()
    config.read(CONFIG_PATH)
    print("Config file read")

    #load embedding, read data and tokenize
    if args.emb_type == 'keras':
        raise NotImplementedError(
            "Keras Embedding is not applicable to linear model.",
            args.emb_type)

    io = IOHelper(size=args.size,
                  emb_type=args.emb_type,
                  preprocess=args.preprocess)

    dh = DataHelper(io,
                    ln=int(config['MODEL']['LENGTH']),
                    nr_words=int(config['MODEL']['NUM_WORDS']),
                    dim=int(config['MODEL']['DIM']))
    #read data
    X_raw, Y_raw, X_test_raw = io.readData()

    #prepare data for training
    emb, emb_dim = dh.getEmbedding()

    #compute feature vectors
    X_feat = [embed2featureWords(row, emb, emb_dim) for row in X_raw]
    X_feat = pd.DataFrame(X_feat)
    y = pd.DataFrame(Y_raw)
Code example #8
import tensorflow as tf  # needed for the flag definitions and the session below

tf.flags.DEFINE_boolean("allow_soft_placement", True,
                        "allow soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False,
                        "log placement of ops on devices")
tf.flags.DEFINE_boolean("gpu_options_allow_growth", True,
                        "allow gpu options growth")

FLAGS = tf.flags.FLAGS
print("parameters info:")
for attr, value in tf.flags.FLAGS.__flags.items():
    print("{0}: {1}".format(attr, value.value))

# 2. load data
## get train_x, train_y, dev_x, dev_y
data_helper = DataHelper(FLAGS.data_path,
                         FLAGS.vocab_path,
                         fields=['y', 'x'],
                         startline=1)
data_list = data_helper.get_data(id_fields=['x'])
x, y = data_list['x'], data_list['y']
padding_x, max_document_length = padding(x, maxlen=FLAGS.pad_seq_len)
int_y = [int(_y) for _y in y]
encoded_y = one_hot_encode(int_y)
train_x, test_x, train_y, test_y = train_test_data_split(padding_x, encoded_y)

# 3. define session
with tf.Graph().as_default():
    # session_config=tf.ConfigProto(allow_soft_placement=True,log_device_placement=False)
    # sess=tf.Session(config=session_config)
    session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True,
                                              log_device_placement=False)
    sess = tf.compat.v1.Session(config=session_config)
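
Note: tf.flags is the TF1-era flags API, while the session above already goes through tf.compat.v1. A minimal sketch of the TF2-compatible spelling of the same flag definition:

import tensorflow as tf

flags = tf.compat.v1.flags
flags.DEFINE_boolean("allow_soft_placement", True, "allow soft device placement")
FLAGS = flags.FLAGS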
Code example #9
from utils.data_helper import DataHelper  # import path assumed, following the other examples
from models.config import AttnConfig
from models.seq2seq_attn_model import Seq2SeqAttnModel

if __name__ == "__main__":

    config = AttnConfig()
    DATA_PATH = 'toy_data/translation/kor.txt'
    WORD2VEC_PATH = '/data/pretrained_model/word_embedding/glove.6B/glove.6B.%sd.txt' % config.EMBEDDING_DIM
    LOAD_PATH = 'bin/checkpoints/seq2seq_model.h5'

    print(config)
    print("Data Path: ", DATA_PATH)
    print("Word2Vec Path: ", WORD2VEC_PATH)
    print("Save Path: ", LOAD_PATH)

    data_helper = DataHelper(config)

    #### load the data ####
    input_texts, target_texts, target_texts_inputs = data_helper.read_txt_translation(
        DATA_PATH)

    #### tokenize the inputs, outputs ####
    encoder_inputs, decoder_inputs, decoder_targets, \
                word2idx_inputs, word2idx_outputs, \
                max_len_input, max_len_target, num_words_output = \
                            data_helper.create_vocab(input_texts, target_texts, target_texts_inputs)

    #### set data of model ####
    model = Seq2SeqAttnModel(config)
    model.set_data(encoder_inputs, decoder_inputs, decoder_targets,
                   max_len_input, max_len_target, num_words_output,
Code example #10
    config = AdvStyleConfig()
    NEG_DATA_PATH = 'toy_data/sentiment/neg.txt'
    POS_DATA_PATH = 'toy_data/sentiment/pos.txt'
    WORD2VEC_PATH = '/data/pretrained_model/word_embedding/glove.6B/glove.6B.%sd.txt' % config.EMBEDDING_DIM
    LOAD_PATH = 'bin/checkpoints/seq2seq_adv_style_model.h5'
    SAVE_RESULT_PATH = 'results/sample_result.csv'

    config.MODE = 'inference'

    print(config)

    print("Data Path: ", NEG_DATA_PATH, POS_DATA_PATH)
    print("Word2Vec Path: ", WORD2VEC_PATH)
    print("Save Path: ", LOAD_PATH)

    data_helper = DataHelper(config)

    #### load the data ####
    input_texts, target_texts, target_texts_inputs, styles = data_helper.read_txt_sentiment(
        NEG_DATA_PATH, POS_DATA_PATH)

    #### tokenize the inputs, outputs ####
    encoder_inputs, decoder_inputs, decoder_targets, \
                word2idx_inputs, word2idx_outputs, \
                max_len_input, max_len_target, num_words_output = \
                            data_helper.create_vocab(input_texts, target_texts, target_texts_inputs)

    #### set data of model ####
    model = Seq2SeqAdvStyleModel(config)
    model.set_data(encoder_inputs, decoder_inputs, decoder_targets, styles,
                   max_len_input, max_len_target, num_words_output,
Code example #11
import os, sys
import numpy as np

from utils.data_helper import DataHelper
from models.config import PtrNetworkConfig
from models.ptr_network_model import PtrNetworkModel

if __name__ == "__main__":
    config = PtrNetworkConfig()
    X_DATA_PATH = 'toy_data/number_ordering/x.txt'
    Y_DATA_PATH = 'toy_data/number_ordering/y.txt'

    SAVE_PATH = 'bin/checkpoints/seq2seq_adv_style_model.h5'

    print(config)
    print("Data Path: ", X_DATA_PATH, Y_DATA_PATH)
    print("Save Path: ", SAVE_PATH)

    data_helper = DataHelper(config)

    #### create number ordering data ####
    data_helper.create_txt_number_ordering(X_DATA_PATH, Y_DATA_PATH)

    #### load the data ####
Code example #12
    parser.add_argument('--1', dest='cross_valid', action='store_true')

    args = parser.parse_args()

    #parse config file
    config = configparser.ConfigParser()
    config.read(CONFIG_PATH)
    print("Config file read")

    #initialize IO- and Data-Helper
    io = IOHelper(size=args.size,
                  emb_type=args.emb_type,
                  preprocess=args.preprocess)

    dh = DataHelper(io,
                    ln=int(config['MODEL']['LENGTH']),
                    nr_words=int(config['MODEL']['NUM_WORDS']),
                    dim=int(config['MODEL']['DIM']))

    #initialize model class for training
    model = Model(io,
                  dh,
                  arch_name=args.arch_name,
                  epochs=int(config['MODEL']['EPOCHS']),
                  ln=int(config['MODEL']['LENGTH']),
                  batch_size=int(config['MODEL']['BATCH_SIZE']))

    if args.cross_valid:
        #print("Cross Validation")
        model.crossValidate(kfold=int(config['DEFAULT']['KFOLD']))

    print("Train on full dataset")
Code example #13
from utils.data_helper import DataHelper  # import path assumed, following the other examples
from models.config import AttnConfig
from models.seq2seq_attn_model import Seq2SeqAttnModel

if __name__ == "__main__":

    config = AttnConfig()
    DATA_PATH = 'toy_data/translation/kor.txt'
    WORD2VEC_PATH = '/data/pretrained_model/word_embedding/glove.6B/glove.6B.%sd.txt' % config.EMBEDDING_DIM
    SAVE_PATH = 'bin/checkpoints/seq2seq_model.h5'

    print(config)
    print("Data Path: ", DATA_PATH)
    print("Word2Vec Path: ", WORD2VEC_PATH)
    print("Save Path: ", SAVE_PATH)

    data_helper = DataHelper(config)

    #### load the data ####
    input_texts, target_texts, target_texts_inputs = data_helper.read_txt_translation(
        DATA_PATH)

    #### tokenize the inputs, outputs ####
    encoder_inputs, decoder_inputs, decoder_targets, \
     word2idx_inputs, word2idx_outputs, \
     max_len_input, max_len_target, num_words_output = \
                         data_helper.create_vocab(input_texts, target_texts, target_texts_inputs)

    #### load word2vec pretrained model ####
    word2vec = data_helper.load_word2vec(WORD2VEC_PATH)

    #### create embedding matrix ####
Code example #14
import numpy as np


def test():
    # DataHelper is the project-local data loader used throughout these examples
    data_helper = DataHelper()
    train_text, train_labels, ver_text, ver_labels, test_text, test_labels = data_helper.get_data_and_labels()
    labels = list(int(i) for i in train_labels)
    wts = np.bincount(labels)
    print(wts)
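
For reference (not part of the project): np.bincount counts how often each non-negative integer occurs, so test() prints the per-class label frequencies. A tiny illustration:

import numpy as np

print(np.bincount([0, 1, 1, 2, 2, 2]))  # -> [1 2 3]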
Code example #15
File: main.py Project: min942773/path_generator
def run_training(args):

    # ----------------------------------------------------- #
    # checkpoint directory
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    model_ckpt = os.path.join(args.save_dir, 'model.ckpt')

    # log file
    if args.num_epoch == 0:
        log_path = os.path.join(args.save_dir, 'test.log')
    else:
        log_path = os.path.join(args.save_dir, 'train.log')

    logger.setLevel(logging.DEBUG)
    handler = logging.FileHandler(log_path, 'w')
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s: %(message)s',
                                  datefmt='%Y/%m/%d %H:%M:%S')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    logger.info('args: {}'.format(args))

    writer = SummaryWriter(log_dir=args.save_dir)
    # ----------------------------------------------------- #
    # load data & init model and optimizer

    logger.info('Loading data & model')

    config = GPT2Config.from_pretrained(args.model, cache_dir='../cache/')
    tokenizer = GPT2Tokenizer.from_pretrained(args.model,
                                              cache_dir='../cache/')
    gpt = GPT2Model.from_pretrained(args.model, cache_dir='../cache/')
    logger.info('Old vocab size: {}'.format(config.vocab_size))

    datahelper = DataHelper(os.path.join('./data', args.data_dir),
                            tokenizer=tokenizer)
    config.vocab_size = len(tokenizer)
    logger.info('New vocab size: {}'.format(config.vocab_size))
    gpt.resize_token_embeddings(len(tokenizer))
    model = GPT2LM(gpt, config)
    model.to(args.device)

    train_sampler = RandomSampler(datahelper.trainset)
    train_dataloader = DataLoader(datahelper.trainset,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size)
    logger.info('Num of samples: {}, steps: {}'.format(
        len(datahelper.trainset),
        len(datahelper.trainset) // args.batch_size))

    t_total = len(train_dataloader) * args.num_epoch
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay,
        },
        {
            'params': [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # ----------------------------------------------------- #

    # training
    best_dev_loss = 1e19
    train_iterator = trange(int(args.num_epoch), desc="Epoch")
    step_nogress = 0
    global_step = 0
    save_id = 0
    tr_loss, logging_loss = 0.0, 0.0
    for epoch in train_iterator:
        train_loss = 0.0
        num_steps = 0
        model.train()
        epoch_iterator = tqdm(train_dataloader,
                              desc="Train Iteration at Epoch {}".format(epoch))
        for step, batch in enumerate(epoch_iterator):

            inputs = batch.to(args.device)
            labels = batch.clone()[:, 16:].to(args.device)
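            # next-token objective over the tail of the sequence: below, logits
            # at positions 15..L-2 are scored against tokens at 16..L-1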

            optimizer.zero_grad()
            outputs = model(inputs)

            outputs = outputs[:, 15:-1]
            outputs = outputs.contiguous()
            labels = labels.contiguous()
            loss = F.nll_loss(outputs.view(-1, config.vocab_size),
                              labels.view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)
            optimizer.step()
            scheduler.step()  # Update learning rate schedule

            train_loss += loss.item()
            tr_loss += loss.item()
            num_steps += 1  # len(batch)
            log = 'Epoch: {:03d}, Iter: {:03d}, step loss: {:.4f}'
            logger.info(log.format(epoch, step, loss.item()))
            writer.add_scalar('Train/nll', loss.item(), global_step)
            # writer.add_scalar('Train/nll_no_pad', loss_no_pad.item(), global_step)

            global_step += 1

        train_loss /= num_steps
        log = 'Epoch: {:03d} Train loss: {:.4f}'
        logger.info(log.format(epoch, train_loss))

        result_dev = evaluation(datahelper, model, config, args, test=False)
        log = 'Epoch: {:03d}, Dev ppl: {:.4f} loss: {:.4f}'
        if result_dev['loss_no_pad'] <= best_dev_loss:
            best_dev_loss = result_dev['loss_no_pad']
            torch.save(model.state_dict(), '{}'.format(model_ckpt))
            step_nogress = 0

        logger.info(log.format(epoch, result_dev['ppl'], result_dev['loss']))
        writer.add_scalar('Dev/nll', result_dev['loss'], epoch)
        writer.add_scalar('Dev/nll_no_pad', result_dev['loss_no_pad'], epoch)
        writer.add_scalar('Dev/ppl', result_dev['ppl'], epoch)
        step_nogress += 1
        if step_nogress > 2:
            break

    # testing
    model.load_state_dict(torch.load('{}'.format(model_ckpt)))
    result_test = evaluation(datahelper, model, config, args, test=True)
    log = 'Epoch: {:03d}, Test ppl: {:.4f} loss: {:.4f}'
    logger.info(log.format(-1, result_test['ppl'], result_test['loss']))
    writer.add_scalar('Test/nll', result_test['loss'], 0)
    writer.add_scalar('Test/nll_no_pad', result_test['loss_no_pad'], 0)
    writer.add_scalar('Test/ppl', result_test['ppl'], 0)