Example #1
def test_early_stop(self):
    data_set, model = prepare_env()
    trainer = Trainer(data_set, model, optimizer=SGD(lr=0.01), loss=BCELoss(pred="predict", target="y"),
                      batch_size=32, n_epochs=20, print_every=50, dev_data=data_set,
                      metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
                      callbacks=[EarlyStopCallback(5)], check_code_level=2)
    trainer.train()
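prepare_env() is not shown on this page (Examples #1 and #3 both take it from their test suite). A minimal, hypothetical stand-in, assuming a toy binary-classification DataSet and a small model whose forward returns a "predict" field that BCELoss and AccuracyMetric can pick up, might look like this:

# Hypothetical sketch of prepare_env(); the real helper is not part of this snippet.
import numpy as np
import torch
from torch import nn
from fastNLP import DataSet


class TinyBinaryModel(nn.Module):
    def __init__(self, input_dim=10):
        super().__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        # fastNLP feeds the DataSet field "x" here; BCELoss reads the "predict" key
        return {'predict': torch.sigmoid(self.fc(x)).squeeze(-1)}


def prepare_env(num_samples=100, input_dim=10):
    features = np.random.rand(num_samples, input_dim).astype('float32')
    labels = (features.sum(axis=1) > input_dim / 2).astype('float32')
    data_set = DataSet({'x': features.tolist(), 'y': labels.tolist()})
    data_set.set_input('x')
    data_set.set_target('y')
    return data_set, TinyBinaryModel(input_dim)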
Example #2
def create_cb():
    lrschedule_callback = LRScheduler(
        lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.05 * ep)))
    clip_callback = GradientClipCallback(clip_type='value', clip_value=2)
    save_dir = os.path.join(root_path, f'model/{args.data_type}',
                            f'fold{args.fold}')
    save_callback = SaveModelCallback(top=1, save_dir=save_dir)
    # the same callback list is used whether or not args.cv is set
    callbacks = [
        lrschedule_callback,
        clip_callback,
        save_callback,
    ]
    # callbacks.append(Unfreeze_Callback(embedding_param ,args.fix_embed_epoch))

    if args.use_bert:
        if args.fix_bert_epoch != 0:
            callbacks.append(
                Unfreeze_Callback(model.lattice_embed, args.fix_bert_epoch))
        else:
            bert_embedding.requires_grad = True

    callbacks.append(EarlyStopCallback(args.early_stop))

    if args.warmup > 0 and args.model == 'transformer':
        callbacks.append(WarmupCallback(warmup=args.warmup))
    return callbacks
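Unfreeze_Callback used above is project-specific and not defined in this snippet; Example #10 further down shows part of its body. A hedged reconstruction consistent with that fragment (the original may differ in details) is:

from fastNLP import Callback


class Unfreeze_Callback(Callback):
    # keep an embedding frozen for the first fix_epoch_num epochs, then unfreeze it
    def __init__(self, bert_embedding, fix_epoch_num):
        super().__init__()
        self.bert_embedding = bert_embedding
        self.fix_epoch_num = fix_epoch_num
        assert self.bert_embedding.requires_grad == False

    def on_epoch_begin(self):
        # self.epoch is set by the fastNLP Trainer that drives the callbacks
        if self.epoch == self.fix_epoch_num + 1:
            self.bert_embedding.requires_grad = True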
Example #3
def test_early_stop_callback(self):
    """
    Observe whether EarlyStop is actually triggered.
    """
    data_set, model = prepare_env()
    trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
                      batch_size=2, n_epochs=10, print_every=5, dev_data=data_set,
                      metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
                      callbacks=EarlyStopCallback(1), check_code_level=2)
    trainer.train()
Example #4
def train():
    train_data = pickle.load(open(opt.train_data_path, 'rb'))
    validate_data = pickle.load(open(opt.validate_data_path, 'rb'))

    vocab = pickle.load(open(opt.vocab, 'rb'))
    word2idx = vocab.word2idx
    idx2word = vocab.idx2word
    vocab_size = len(word2idx)
    print("vocab_size" + str(vocab_size))
    embedding_dim = opt.embedding_dim
    hidden_dim = opt.hidden_dim
    model = utils.find_class_by_name(opt.model_name,
                                     [models])(vocab_size, embedding_dim,
                                               hidden_dim)

    if not os.path.exists(opt.save_model_path):
        os.mkdir(opt.save_model_path)

    # define dataloader
    train_data.set_input('input_data', flag=True)
    train_data.set_target('target', flag=True)
    validate_data.set_input('input_data', flag=True)
    validate_data.set_target('target', flag=True)

    if opt.optimizer == 'Adagrad':
        _optimizer = Adagrad(lr=opt.learning_rate, weight_decay=0)
    elif opt.optimizer == 'SGD':
        _optimizer = SGD(lr=opt.learning_rate, momentum=0)
    elif opt.optimizer == 'SGD_momentum':
        _optimizer = SGD(lr=opt.learning_rate, momentum=0.9)
    elif opt.optimizer == 'Adam':
        _optimizer = Adam(lr=opt.learning_rate, weight_decay=0)

    overfit_trainer = Trainer(model=model,
                              train_data=train_data,
                              loss=MyCrossEntropyLoss(pred="output",
                                                      target="target"),
                              n_epochs=opt.epoch,
                              batch_size=opt.batch_size,
                              device='cuda:0',
                              dev_data=validate_data,
                              metrics=MyPPMetric(pred="output",
                                                 target="target"),
                              metric_key="-pp",
                              validate_every=opt.validate_every,
                              optimizer=_optimizer,
                              callbacks=[EarlyStopCallback(opt.patience)],
                              save_path=opt.save_model_path)

    overfit_trainer.train()
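MyCrossEntropyLoss and MyPPMetric above are user-defined classes that this page does not include. As a hedged sketch, a custom loss in fastNLP typically subclasses LossBase, registers its field names with _init_param_map, and implements get_loss; the class name and the assumed (batch, seq_len, vocab) logit shape below are illustrative only:

import torch.nn.functional as F
from fastNLP import LossBase


class MyCrossEntropyLoss(LossBase):
    # illustrative sketch: routes the model's "output" field and the dataset's
    # "target" field into a token-level cross entropy
    def __init__(self, pred=None, target=None):
        super().__init__()
        self._init_param_map(pred=pred, target=target)

    def get_loss(self, pred, target):
        # assumes (batch, seq_len, vocab) logits and (batch, seq_len) integer targets
        return F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))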
Example #5
def train():
    config = Config()

    train_data, dev_data, vocabulary = get_dataset(config.data_path)

    poetry_model = PoetryModel(vocabulary_size=len(vocabulary),
                               embedding_size=config.embedding_size,
                               hidden_size=config.hidden_size)
    loss = Loss(pred='output', target='target')
    perplexity = Perplexity(pred='output', target='target')

    print("optimizer:", config.optimizer)
    print("momentum:", config.momentum)
    if config.optimizer == 'adam':
        optimizer = Adam(lr=config.lr, weight_decay=config.weight_decay)
    elif config.optimizer == 'sgd':
        optimizer = SGD(lr=config.lr, momentum=config.momentum)
    elif config.optimizer == 'adagrad':
        optimizer = Adagrad(lr=config.lr, weight_decay=config.weight_decay)
    elif config.optimizer == 'adadelta':
        optimizer = Adadelta(lr=config.lr,
                             rho=config.rho,
                             eps=config.eps,
                             weight_decay=config.weight_decay)

    timing = TimingCallback()
    early_stop = EarlyStopCallback(config.patience)

    trainer = Trainer(train_data=train_data,
                      model=poetry_model,
                      loss=loss,
                      metrics=perplexity,
                      n_epochs=config.epoch,
                      batch_size=config.batch_size,
                      print_every=config.print_every,
                      validate_every=config.validate_every,
                      dev_data=dev_data,
                      save_path=config.save_path,
                      optimizer=optimizer,
                      check_code_level=config.check_code_level,
                      metric_key="-PPL",
                      sampler=RandomSampler(),
                      prefetch=False,
                      use_tqdm=True,
                      device=config.device,
                      callbacks=[timing, early_stop])
    trainer.train()
Example #6
def main(net, optimizer, train_set, test_set, loss, metric):
    train_data, dev_data = train_set.split(ratio)
    trainer = Trainer(model=net,
                      loss=loss,
                      optimizer=optimizer,
                      n_epochs=epochs,
                      train_data=train_data,
                      dev_data=dev_data,
                      metrics=metric,
                      device=device,
                      save_path=model_save_path,
                      print_every=200,
                      use_tqdm=False,
                      callbacks=[EarlyStopCallback(early_epoch)])
    trainer.train()

    tester = Tester(model=net, data=test_set, metrics=metric, device=device)
    tester.test()
Example #7
def train_fastnlp(conf, args=None):
    pdata = PoemData()
    pdata.read_data(conf)
    pdata.get_vocab()
    if conf.use_gpu:
        device = torch.device('cuda')
    else:
        device = None
    conf.device = device
    model = FastNLPPoetryModel(pdata.vocab_size, conf.embedding_dim,
                               conf.hidden_dim, device)
    train_data = pdata.train_data
    test_data = pdata.test_data
    train_data.apply(lambda x: x['pad_words'][:-1], new_field_name="input")
    train_data.apply(lambda x: x['pad_words'][1:], new_field_name="target")
    test_data.apply(lambda x: x['pad_words'][:-1], new_field_name="input")
    test_data.apply(lambda x: x['pad_words'][1:], new_field_name="target")
    train_data.set_input("input")
    train_data.set_target("target")
    test_data.set_input("input")
    test_data.set_target("target")

    loss = MyCrossEntropyLoss(pred='output', target='target')
    metric = MyPerplexityMetric(pred='output', target='target')
    optimizer = fastnlp_optim.Adam(lr=conf.learning_rate, weight_decay=0)
    overfit_model = deepcopy(model)
    overfit_trainer = Trainer(model=overfit_model,
                              device=conf.device,
                              batch_size=conf.batch_size,
                              n_epochs=conf.n_epochs,
                              train_data=train_data,
                              dev_data=test_data,
                              loss=loss,
                              metrics=metric,
                              optimizer=optimizer,
                              save_path=conf.best_model_path,
                              validate_every=conf.save_every,
                              metric_key="-PPL",
                              callbacks=[EarlyStopCallback(conf.patience)])
    print(overfit_trainer.train())
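MyPerplexityMetric is likewise not shown. A hedged sketch of such a metric subclasses MetricBase, accumulates statistics in evaluate, and reports them in get_metric; the 'PPL' key matches metric_key="-PPL" above, where the leading '-' tells the Trainer that a smaller value is better. Shapes and names are assumptions:

import math
import torch.nn.functional as F
from fastNLP import MetricBase


class MyPerplexityMetric(MetricBase):
    # illustrative sketch: per-token cross entropy accumulated over batches, reported as exp(mean)
    def __init__(self, pred=None, target=None):
        super().__init__()
        self._init_param_map(pred=pred, target=target)
        self.total_loss = 0.0
        self.total_tokens = 0

    def evaluate(self, pred, target):
        loss = F.cross_entropy(pred.reshape(-1, pred.size(-1)),
                               target.reshape(-1), reduction='sum')
        self.total_loss += loss.item()
        self.total_tokens += target.numel()

    def get_metric(self, reset=True):
        ppl = math.exp(self.total_loss / max(self.total_tokens, 1))
        if reset:
            self.total_loss, self.total_tokens = 0.0, 0
        return {'PPL': ppl}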
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--method",
        default='cnn',
        help="train model and test it",
        choices=['cnn', 'cnn_glove', 'rnn', 'rnn_maxpool', 'rnn_avgpool'])
    parser.add_argument("--dataset",
                        default='1',
                        help="1: small dataset; 2: big dataset",
                        choices=['1', '2'])
    args = parser.parse_args()

    # hyperparameters
    embedding_dim = 256
    batch_size = 32
    # RNN
    hidden_dim = 256
    # CNN
    kernel_sizes = (3, 4, 5)
    num_channels = (120, 160, 200)
    acti_function = 'relu'

    learning_rate = 1e-3
    train_patience = 8
    cate_num = 4

    # GloVe
    embedding_file_path = "glove.6B.100d.txt"

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    vocab = read_vocab("vocab.txt")
    print("vocabulary length:", len(vocab))
    train_data = DataSet().load("train_set")
    dev_data = DataSet().load("dev_set")
    test_data = DataSet().load("test_set")

    if (args.dataset == '1'):
        cate_num = 4
        num_channels = (48, 48, 48)
        embedding_dim = 128
        hidden_dim = 128
    elif (args.dataset == '2'):
        cate_num = 20

    if (args.method == 'cnn'):
        model = TextCNN(vocab_size=len(vocab),
                        embedding_dim=embedding_dim,
                        kernel_sizes=kernel_sizes,
                        num_channels=num_channels,
                        num_classes=cate_num,
                        activation=acti_function)
    elif (args.method == 'cnn_glove'):
        glove_embedding = EmbedLoader.load_with_vocab(embedding_file_path,
                                                      vocab)
        embedding_dim = glove_embedding.shape[1]
        print("GloVe embedding_dim:", embedding_dim)

        model = TextCNN_glove(vocab_size=len(vocab),
                              embedding_dim=embedding_dim,
                              kernel_sizes=kernel_sizes,
                              num_channels=num_channels,
                              num_classes=cate_num,
                              activation=acti_function)
        model.embedding.load_state_dict(
            {"weight": torch.from_numpy(glove_embedding)})
        model.constant_embedding.load_state_dict(
            {"weight": torch.from_numpy(glove_embedding)})
        model.constant_embedding.weight.requires_grad = False
        model.embedding.weight.requires_grad = True

    elif (args.method == 'rnn'):
        embedding_dim = 128
        hidden_dim = 128
        model = BiRNNText(vocab_size=len(vocab),
                          embedding_dim=embedding_dim,
                          output_dim=cate_num,
                          hidden_dim=hidden_dim)
    elif (args.method == 'rnn_maxpool'):
        model = BiRNNText_pool(vocab_size=len(vocab),
                               embedding_dim=embedding_dim,
                               output_dim=cate_num,
                               hidden_dim=hidden_dim,
                               pool_name="max")
    elif (args.method == 'rnn_avgpool'):
        model = BiRNNText_pool(vocab_size=len(vocab),
                               embedding_dim=embedding_dim,
                               output_dim=cate_num,
                               hidden_dim=hidden_dim,
                               pool_name="avg")

    tester = Tester(test_data, model, metrics=AccuracyMetric())

    trainer = Trainer(
        train_data=train_data,
        model=model,
        loss=CrossEntropyLoss(pred=Const.OUTPUT, target=Const.TARGET),
        metrics=AccuracyMetric(),
        n_epochs=80,
        batch_size=batch_size,
        print_every=10,
        validate_every=-1,
        dev_data=dev_data,
        optimizer=torch.optim.Adam(model.parameters(), lr=learning_rate),
        check_code_level=2,
        metric_key='acc',
        use_tqdm=True,
        callbacks=[EarlyStopCallback(train_patience)],
        device=device,
    )

    trainer.train()
    tester.test()
Example #9
def train():
    train_data = pickle.load(open(opt.train_data_path, 'rb'))
    validate_data = pickle.load(open(opt.validate_data_path, 'rb'))

    vocab = pickle.load(open(opt.vocab, 'rb'))
    word2idx = vocab.word2idx
    idx2word = vocab.idx2word
    input_size = len(word2idx)

    vocab_size = opt.class_num
    class_num = opt.class_num

    embedding_dim = opt.embedding_dim

    if opt.model_name == "LSTMModel":
        model = utils.find_class_by_name(opt.model_name,
                                         [models])(input_size, vocab_size,
                                                   embedding_dim,
                                                   opt.use_word2vec,
                                                   opt.embedding_weight_path)
    elif opt.model_name == "B_LSTMModel":
        model = utils.find_class_by_name(opt.model_name,
                                         [models])(input_size, vocab_size,
                                                   embedding_dim,
                                                   opt.use_word2vec,
                                                   opt.embedding_weight_path)
    elif opt.model_name == "CNNModel":
        model = utils.find_class_by_name(opt.model_name,
                                         [models])(input_size, vocab_size,
                                                   embedding_dim,
                                                   opt.use_word2vec,
                                                   opt.embedding_weight_path)
    elif opt.model_name == "MyBertModel":
        #bert_dir = "./BertPretrain"
        #bert_dir = None
        #model = utils.find_class_by_name(opt.model_name, [models])(10, 0.1, 4, bert_dir)
        train_data.apply(lambda x: x['input_data'][:2500],
                         new_field_name='input_data')
        validate_data.apply(lambda x: x['input_data'][:2500],
                            new_field_name='input_data')

        model = utils.find_class_by_name(opt.model_name, [models])(
            input_size=input_size,
            hidden_size=512,
            hidden_dropout_prob=0.1,
            num_labels=class_num,
            use_word2vec=opt.use_word2vec,
            embedding_weight_path=opt.embedding_weight_path,
        )

    if not os.path.exists(opt.save_model_path):
        os.mkdir(opt.save_model_path)

    # define dataloader
    train_data.set_input('input_data', flag=True)
    train_data.set_target('target', flag=True)
    validate_data.set_input('input_data', flag=True)
    validate_data.set_target('target', flag=True)

    if opt.optimizer == 'SGD':
        _optimizer = SGD(lr=opt.learning_rate, momentum=0)
    elif opt.optimizer == 'SGD_momentum':
        _optimizer = SGD(lr=opt.learning_rate, momentum=0.9)
    elif opt.optimizer == 'Adam':
        _optimizer = Adam(lr=opt.learning_rate, weight_decay=0)

    overfit_trainer = Trainer(
        model=model,
        train_data=train_data,
        loss=CrossEntropyLoss(pred="output", target="target"),
        n_epochs=opt.epoch,
        batch_size=opt.batch_size,
        device=[0, 1, 2, 3],
        #device=None,
        dev_data=validate_data,
        metrics=AccuracyMetric(pred="output", target="target"),
        metric_key="+acc",
        validate_every=opt.validate_every,
        optimizer=_optimizer,
        callbacks=[EarlyStopCallback(opt.patience)],
        save_path=opt.save_model_path)

    overfit_trainer.train()
Example #10
        self.fix_epoch_num = fix_epoch_num
        assert self.bert_embedding.requires_grad == False

    def on_epoch_begin(self):
        if self.epoch == self.fix_epoch_num + 1:
            self.bert_embedding.requires_grad = True


callbacks = [evaluate_callback, lrschedule_callback, clip_callback]
if args.use_bert:
    if args.fix_bert_epoch != 0:
        callbacks.append(Unfreeze_Callback(bert_embedding,
                                           args.fix_bert_epoch))
    else:
        bert_embedding.requires_grad = True
callbacks.append(EarlyStopCallback(args.early_stop))
if args.warmup > 0 and args.model == 'transformer':
    callbacks.append(WarmupCallback(warmup=args.warmup))


class record_best_test_callback(Callback):
    def __init__(self, trainer, result_dict):
        super().__init__()
        self.trainer222 = trainer  # non-standard attribute name, presumably to avoid clashing with Callback's own trainer attribute
        self.result_dict = result_dict

    def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
        print(eval_result['data_test']['SpanFPreRecMetric']['f'])


print(torch.rand(size=[3, 3], device=device))
            return torch.optim.Adagrad(
                self._get_require_grads_param(self.model_params),
                **self.settings)


if __name__ == "__main__":

    vocab = pickle.load(open(config.vocab_path, 'rb'))
    train_data = pickle.load(open(config.train_data_path, 'rb'))
    dev_data = pickle.load(open(config.dev_data_path, 'rb'))

    model = PoetryModel(len(vocab), config.intput_size, config.hidden_size)
    optimizer = Adam(lr=config.learning_rate, weight_decay=0)
    # optimizer = Adagrad(lr=config.learning_rate, weight_decay=0)
    # optimizer=SGD(lr=config.learning_rate, momentum=0.9)
    loss = MyCrossEntropyLoss(pred="output", target="target")
    metric = PerplexityMetric(pred="output", target="target")
    trainer = Trainer(model=model,
                      n_epochs=config.epoch,
                      validate_every=config.validate_every,
                      optimizer=optimizer,
                      train_data=train_data,
                      dev_data=dev_data,
                      metrics=metric,
                      loss=loss,
                      batch_size=config.batch_size,
                      device='cuda:0',
                      save_path=config.save_path,
                      metric_key="-PPL",
                      callbacks=[EarlyStopCallback(config.patience)])
    trainer.train()
Example #12
def train(args):
    text_data = TextData()
    with open(os.path.join(args.vocab_dir, args.vocab_data), 'rb') as fin:
        text_data = pickle.load(fin)
    vocab_size = text_data.vocab_size
    class_num = text_data.class_num
    # class_num = 1
    seq_len = text_data.max_seq_len
    print("(vocab_size,class_num,seq_len):({0},{1},{2})".format(
        vocab_size, class_num, seq_len))

    train_data = text_data.train_set
    val_data = text_data.val_set
    test_data = text_data.test_set
    train_data.set_input('words', 'seq_len')
    train_data.set_target('target')
    val_data.set_input('words', 'seq_len')
    val_data.set_target('target')

    test_data.set_input('words', 'seq_len')
    test_data.set_target('target')

    init_embeds = None
    if args.pretrain_model == "None":
        print("No pretrained model with be used.")
        print("vocabsize:{0}".format(vocab_size))
        init_embeds = (vocab_size, args.embed_size)
    elif args.pretrain_model == "word2vec":
        embeds_path = os.path.join(args.prepare_dir, 'w2v_embeds.pkl')
        print("Loading Word2Vec pretrained embedding from {0}.".format(
            embeds_path))
        with open(embeds_path, 'rb') as fin:
            init_embeds = pickle.load(fin)
    elif args.pretrain_model == 'glove':
        embeds_path = os.path.join(args.prepare_dir, 'glove_embeds.pkl')
        print(
            "Loading Glove pretrained embedding from {0}.".format(embeds_path))
        with open(embeds_path, 'rb') as fin:
            init_embeds = pickle.load(fin)
    elif args.pretrain_model == 'glove2wv':
        embeds_path = os.path.join(args.prepare_dir, 'glove2wv_embeds.pkl')
        print(
            "Loading Glove pretrained embedding from {0}.".format(embeds_path))
        with open(embeds_path, 'rb') as fin:
            init_embeds = pickle.load(fin)
    else:
        init_embeds = (vocab_size, args.embed_size)

    if args.model == "CNNText":
        print("Using CNN Model.")
        model = CNNText(init_embeds,
                        num_classes=class_num,
                        padding=2,
                        dropout=args.dropout)
    elif args.model == "StarTransformer":
        print("Using StarTransformer Model.")
        model = STSeqCls(init_embeds,
                         num_cls=class_num,
                         hidden_size=args.hidden_size)
    elif args.model == "MyCNNText":
        model = MyCNNText(init_embeds=init_embeds,
                          num_classes=class_num,
                          padding=2,
                          dropout=args.dropout)
        print("Using user defined CNNText")
    elif args.model == "LSTMText":
        print("Using LSTM Model.")
        model = LSTMText(init_embeds=init_embeds,
                         output_dim=class_num,
                         hidden_dim=args.hidden_size,
                         num_layers=args.num_layers,
                         dropout=args.dropout)
    elif args.model == "Bert":
        print("Using Bert Model.")
    else:
        print("Using default model: CNNText.")
        model = CNNText((vocab_size, args.embed_size),
                        num_classes=class_num,
                        padding=2,
                        dropout=0.1)
    print(model)
    if args.cuda:
        device = torch.device('cuda')
    else:
        device = None

    print("train_size:{0} ; val_size:{1} ; test_size:{2}".format(
        train_data.get_length(), val_data.get_length(),
        test_data.get_length()))

    if args.optim == "Adam":
        print("Using Adam as optimizer.")
        optimizer = fastnlp_optim.Adam(lr=0.001,
                                       weight_decay=args.weight_decay)
        if (args.model_suffix == "default"):
            args.model_suffix = args.optim
    else:
        print("No Optimizer will be used.")
        optimizer = None

    criterion = CrossEntropyLoss()
    metric = AccuracyMetric()
    model_save_path = os.path.join(args.model_dir, args.model,
                                   args.model_suffix)
    earlystop = EarlyStopCallback(args.patience)
    fitlog_back = FitlogCallback({"val": val_data, "train": train_data})
    trainer = Trainer(train_data=train_data,
                      model=model,
                      save_path=model_save_path,
                      device=device,
                      n_epochs=args.epochs,
                      optimizer=optimizer,
                      dev_data=val_data,
                      loss=criterion,
                      batch_size=args.batch_size,
                      metrics=metric,
                      callbacks=[fitlog_back, earlystop])
    trainer.train()
    print("Train Done.")

    tester = Tester(data=val_data,
                    model=model,
                    metrics=metric,
                    batch_size=args.batch_size,
                    device=device)
    tester.test()
    print("Test Done.")

    print("Predict the answer with best model...")
    acc = 0.0
    output = []
    data_iterator = Batch(test_data, batch_size=args.batch_size)
    for data_x, batch_y in data_iterator:
        i_data = Variable(data_x['words']).cuda()
        pred = model(i_data)[C.OUTPUT]
        pred = pred.sigmoid()
        # print(pred.shape)
        output.append(pred.cpu().data)
    output = torch.cat(output, 0).numpy()
    print(output.shape)
    print("Predict Done. {} records".format(len(output)))
    result_save_path = os.path.join(args.result_dir,
                                    args.model + "_" + args.model_suffix)
    with open(result_save_path + ".pkl", 'wb') as f:
        pickle.dump(output, f)
    output = output.squeeze()[:, 1].tolist()
    projectid = text_data.test_projectid.values
    answers = []
    count = 0
    for i in range(len(output)):
        if output[i] > 0.5:
            count += 1
    print("true sample count:{}".format(count))
    add_count = 0
    for i in range(len(projectid) - len(output)):
        output.append([0.13])
        add_count += 1
    print("Add {} default result in predict.".format(add_count))

    df = pd.DataFrame()
    df['projectid'] = projectid
    df['y'] = output
    df.to_csv(result_save_path + ".csv", index=False)
    print("Predict Done, results saved to {}".format(result_save_path))

    fitlog.finish()
Example #13
def train(args):
    text_data = TextData()
    with open(os.path.join(args.vocab_dir, args.vocab_data), 'rb') as fin:
        text_data = pickle.load(fin)
    vocab_size = text_data.vocab_size
    class_num = text_data.class_num
    seq_len = text_data.max_seq_len
    print("(vocab_size,class_num,seq_len):({0},{1},{2})".format(
        vocab_size, class_num, seq_len))

    train_data = text_data.train_set
    test_dev_data = text_data.test_set
    train_data.set_input('words', 'seq_len')
    train_data.set_target('target')
    test_dev_data.set_input('words', 'seq_len')
    test_dev_data.set_target('target')
    test_data, dev_data = test_dev_data.split(0.2)

    test_data = test_dev_data
    init_embeds = None
    if args.pretrain_model == "None":
        print("No pretrained model with be used.")
        print("vocabsize:{0}".format(vocab_size))
        init_embeds = (vocab_size, args.embed_size)
    elif args.pretrain_model == "word2vec":
        embeds_path = os.path.join(args.prepare_dir, 'w2v_embeds.pkl')
        print("Loading Word2Vec pretrained embedding from {0}.".format(
            embeds_path))
        with open(embeds_path, 'rb') as fin:
            init_embeds = pickle.load(fin)
    elif args.pretrain_model == 'glove':
        embeds_path = os.path.join(args.prepare_dir, 'glove_embeds.pkl')
        print(
            "Loading Glove pretrained embedding from {0}.".format(embeds_path))
        with open(embeds_path, 'rb') as fin:
            init_embeds = pickle.load(fin)
    elif args.pretrain_model == 'glove2wv':
        embeds_path = os.path.join(args.prepare_dir, 'glove2wv_embeds.pkl')
        print(
            "Loading Glove pretrained embedding from {0}.".format(embeds_path))
        with open(embeds_path, 'rb') as fin:
            init_embeds = pickle.load(fin)
    else:
        init_embeds = (vocab_size, args.embed_size)

    if args.model == "CNNText":
        print("Using CNN Model.")
        model = CNNText(init_embeds,
                        num_classes=class_num,
                        padding=2,
                        dropout=args.dropout)
    elif args.model == "StarTransformer":
        print("Using StarTransformer Model.")
        model = STSeqCls(init_embeds,
                         num_cls=class_num,
                         hidden_size=args.hidden_size)
    elif args.model == "MyCNNText":
        model = MyCNNText(init_embeds=init_embeds,
                          num_classes=class_num,
                          padding=2,
                          dropout=args.dropout)
        print("Using user defined CNNText")
    elif args.model == "LSTMText":
        print("Using LSTM Model.")
        model = LSTMText(init_embeds=init_embeds,
                         output_dim=class_num,
                         hidden_dim=args.hidden_size,
                         num_layers=args.num_layers,
                         dropout=args.dropout)
    elif args.model == "Bert":
        print("Using Bert Model.")
    else:
        print("Using default model: CNNText.")
        model = CNNText((vocab_size, args.embed_size),
                        num_classes=class_num,
                        padding=2,
                        dropout=0.1)
    print(model)
    if args.cuda:
        device = torch.device('cuda')
    else:
        device = None

    print("train_size:{0} ; dev_size:{1} ; test_size:{2}".format(
        train_data.get_length(), dev_data.get_length(),
        test_data.get_length()))

    if args.optim == "Adam":
        print("Using Adam as optimizer.")
        optimizer = fastnlp_optim.Adam(lr=0.001,
                                       weight_decay=args.weight_decay)
        if (args.model_suffix == "default"):
            args.model_suffix = args.optim
    else:
        print("No Optimizer will be used.")
        optimizer = None

    criterion = CrossEntropyLoss()
    metric = AccuracyMetric()
    model_save_path = os.path.join(args.model_dir, args.model,
                                   args.model_suffix)
    earlystop = EarlyStopCallback(args.patience)
    trainer = Trainer(train_data=train_data,
                      model=model,
                      save_path=model_save_path,
                      device=device,
                      n_epochs=args.epochs,
                      optimizer=optimizer,
                      dev_data=test_data,
                      loss=criterion,
                      batch_size=args.batch_size,
                      metrics=metric,
                      callbacks=[FitlogCallback(test_data), earlystop])
    trainer.train()
    print("Train Done.")

    tester = Tester(data=test_data,
                    model=model,
                    metrics=metric,
                    batch_size=args.batch_size,
                    device=device)
    tester.test()
    print("Test Done.")
    fitlog.finish()