def setup_training(model, train_loader, valid_loader, hps):
    """Does setup before starting training (run_training)"""

    train_dir = os.path.join(hps.save_root, "train")
    os.makedirs(train_dir, exist_ok=True)

    if hps.restore_model != 'None':
        logger.info("[INFO] Restoring %s for training...", hps.restore_model)
        bestmodel_file = os.path.join(train_dir, hps.restore_model)
        loader = ModelLoader()
        loader.load_pytorch(model, bestmodel_file)
    else:
        logger.info("[INFO] Create new model for training...")

    try:
        run_training(model, train_loader, valid_loader,
                     hps)  # this is an infinite loop until interrupted
    except KeyboardInterrupt:
        logger.error(
            "[Error] Caught keyboard interrupt on worker. Stopping supervisor..."
        )
        save_file = os.path.join(train_dir, "earlystop.pkl")
        saver = ModelSaver(save_file)
        saver.save_pytorch(model)
        logger.info('[INFO] Saving early stop model to %s', save_file)
# Example 2
def run_training(model, train_loader, valid_loader, hps):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    logger.info("[INFO] Starting run_training")

    train_dir = os.path.join(hps.save_root, "train")
    os.makedirs(train_dir, exist_ok=True)
    eval_dir = os.path.join(hps.save_root, "eval")  # make a subdir of the root dir for eval data
    os.makedirs(eval_dir, exist_ok=True)

    lr = hps.lr
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    criterion = MyCrossEntropyLoss(pred="p_sent", target=Const.TARGET, mask=Const.INPUT_LEN, reduce='none')
    # criterion = torch.nn.CrossEntropyLoss(reduction='none')

    trainer = Trainer(model=model, train_data=train_loader, optimizer=optimizer, loss=criterion,
                      n_epochs=hps.n_epochs, print_every=100, dev_data=valid_loader,
                      metrics=[LabelFMetric(pred="prediction"), FastRougeMetric(hps, pred="prediction")],
                      metric_key="f", validate_every=-1, save_path=eval_dir,
                      callbacks=[TrainCallback(hps, patience=5)], use_tqdm=False)

    train_info = trainer.train(load_best_model=True)
    logger.info('   | end of Train | time: {:5.2f}s | '.format(train_info["seconds"]))
    logger.info('[INFO] best eval model in epoch %d and iter %d', train_info["best_epoch"], train_info["best_step"])
    logger.info(train_info["best_eval"])

    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl')  # this is where checkpoints of best models are saved
    saver = ModelSaver(bestmodel_save_path)
    saver.save_pytorch(model)
    logger.info('[INFO] Saving eval best model to %s', bestmodel_save_path)
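
MyCrossEntropyLoss is project-specific. Judging from the manual loss computation in the later run_training example, a minimal sketch of the masked, per-document loss it presumably wraps (function name hypothetical):

import torch.nn.functional as F

def masked_sentence_ce(p_sent, target, input_len):
    # hypothetical sketch mirroring the manual computation in run_training;
    # not the project's actual MyCrossEntropyLoss
    batch_size = p_sent.size(0)
    ce = F.cross_entropy(p_sent.view(-1, 2), target.view(-1), reduction='none')
    ce = ce.view(batch_size, -1)             # [batch, doc_max_timesteps]
    ce = ce.masked_fill(input_len.eq(0), 0)  # zero out padded sentences
    return ce.sum(1).mean()                  # sum per doc, average over batch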
# Example 3
def train(path):
    # Trainer
    trainer = Trainer(**train_args.data)

    def _define_optim(obj):
        lr = optim_args.data['lr']
        embed_params = set(obj._model.word_embedding.parameters())
        decay_params = set(obj._model.arc_predictor.parameters()) | set(
            obj._model.label_predictor.parameters())
        params = [
            p for p in obj._model.parameters()
            if p not in decay_params and p not in embed_params
        ]
        obj._optimizer = torch.optim.Adam([{
            'params': list(embed_params),
            'lr': lr * 0.1
        }, {
            'params': list(decay_params),
            **optim_args.data
        }, {
            'params': params
        }],
                                          lr=lr,
                                          betas=(0.9, 0.9))
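        # LambdaLR below multiplies the base LR by 0.75 every 50k scheduler steps, floored at 0.05x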
        obj._scheduler = torch.optim.lr_scheduler.LambdaLR(
            obj._optimizer, lambda ep: max(.75**(ep / 5e4), 0.05))

    def _update(obj):
        # torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0)
        obj._optimizer.step()
        obj._scheduler.step()  # PyTorch >= 1.1 expects optimizer.step() before scheduler.step()

    trainer.define_optimizer = lambda: _define_optim(trainer)
    trainer.update = lambda: _update(trainer)
    trainer.set_validator(
        Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)))

    model.word_embedding = torch.nn.Embedding.from_pretrained(embed,
                                                              freeze=False)
    model.word_embedding.padding_idx = word_v.padding_idx
    model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
    model.pos_embedding.padding_idx = pos_v.padding_idx
    model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)

    # try:
    #     ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
    #     print('model parameter loaded!')
    # except Exception as _:
    #     print("No saved model. Continue.")
    #     pass

    # Start training
    trainer.train(model, train_data, dev_data)
    print("Training finished!")

    # Saver
    saver = ModelSaver("./save/saved_model.pkl")
    saver.save_pytorch(model)
    print("Model saved!")
# Example 4
def trainer(data_folder, write2model, write2vocab):
    data_bundle = PeopleDailyNERLoader().load(
        data_folder)  # reads the data from {data_dir} into a DataBundle
    data_bundle = PeopleDailyPipe().process(data_bundle)
    data_bundle.rename_field('chars', 'words')
    # save the vocabularies
    targetVocab = dict(data_bundle.vocabs["target"])
    wordsVocab = dict(data_bundle.vocabs["words"])
    targetWc = dict(data_bundle.vocabs['target'].word_count)
    wordsWc = dict(data_bundle.vocabs['words'].word_count)
    with open(write2vocab, "w", encoding="utf-8") as VocabOut:
        VocabOut.write(
            json.dumps(
                {
                    "targetVocab": targetVocab,
                    "wordsVocab": wordsVocab,
                    "targetWc": targetWc,
                    "wordsWc": wordsWc
                },
                ensure_ascii=False))

    embed = BertEmbedding(vocab=data_bundle.get_vocab('words'),
                          model_dir_or_name='cn',
                          requires_grad=False,
                          auto_truncate=True)
    model = BiLSTMCRF(embed=embed,
                      num_classes=len(data_bundle.get_vocab('target')),
                      num_layers=1,
                      hidden_size=100,
                      dropout=0.5,
                      target_vocab=data_bundle.get_vocab('target'))

    metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))
    optimizer = Adam(model.parameters(), lr=2e-5)
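    # LossInForward presumably picks up the CRF loss that BiLSTMCRF computes inside forward()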
    loss = LossInForward()
    device = 0 if torch.cuda.is_available() else 'cpu'
    # device = "cpu"
    trainer = Trainer(data_bundle.get_dataset('train'),
                      model,
                      loss=loss,
                      optimizer=optimizer,
                      batch_size=8,
                      dev_data=data_bundle.get_dataset('dev'),
                      metrics=metric,
                      device=device,
                      n_epochs=1)
    trainer.train()
    tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)
    tester.test()
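    # param_only=False presumably saves the whole model object rather than just the state dict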
    saver = ModelSaver(write2model)
    saver.save_pytorch(model, param_only=False)
# Example 5
def train():
    # Config Loader
    train_args = ConfigSection()
    test_args = ConfigSection()
    ConfigLoader().load_config(cfgfile, {
        "train": train_args,
        "test": test_args
    })

    print("loading data set...")
    data = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
    data.load(cws_data_path)
    data_train, data_dev = data.split(ratio=0.3)
    train_args["vocab_size"] = len(data.word_vocab)
    train_args["num_classes"] = len(data.label_vocab)
    print("vocab size={}, num_classes={}".format(len(data.word_vocab),
                                                 len(data.label_vocab)))

    change_field_is_target(data_dev, "truth", True)
    save_pickle(data_dev, "./save/", "data_dev.pkl")
    save_pickle(data.word_vocab, "./save/", "word2id.pkl")
    save_pickle(data.label_vocab, "./save/", "label2id.pkl")

    # Trainer
    trainer = SeqLabelTrainer(epochs=train_args["epochs"],
                              batch_size=train_args["batch_size"],
                              validate=train_args["validate"],
                              use_cuda=train_args["use_cuda"],
                              pickle_path=train_args["pickle_path"],
                              save_best_dev=True,
                              print_every_step=10,
                              model_name="trained_model.pkl",
                              evaluator=SeqLabelEvaluator())

    # Model
    model = AdvSeqLabel(train_args)
    try:
        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
        print('model parameter loaded!')
    except Exception as e:
        print("No saved model. Continue.")
        pass

    # Start training
    trainer.train(model, data_train, data_dev)
    print("Training finished!")

    # Saver
    saver = ModelSaver("./save/trained_model.pkl")
    saver.save_pytorch(model)
    print("Model saved!")
# Example 6
    def on_epoch_begin(self, cur_epoch, total_epoch):
        if cur_epoch == 1:
            self.opt = self.trainer.optimizer  # pytorch optimizer
            self.opt.param_groups[0]["lr"] = self.start_lr
            # save a temporary snapshot of the model parameters
            ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
            self.find = True
# Example 7
    def on_exception(self, exception):
        if isinstance(exception, KeyboardInterrupt):
            logger.error(
                "[Error] Caught keyboard interrupt on worker. Stopping supervisor..."
            )
            train_dir = os.path.join(self._hps.save_root, "train")
            save_file = os.path.join(train_dir, "earlystop.pkl")
            saver = ModelSaver(save_file)
            saver.save_pytorch(self.model)
            logger.info('[INFO] Saving early stop model to %s', save_file)

            if self.quit_all:
                sys.exit(0)  # exit the whole program immediately
        else:
            raise exception  # re-raise any unexpected exception
# Example 8
    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        logger.info('   | end of valid {:3d} | time: {:5.2f}s | '.format(
            self.epoch, (time.time() - self.valid_start_time)))

        # early stop
        if not is_better_eval:
            if self.wait == self.patience:
                train_dir = os.path.join(self._hps.save_root, "train")
                save_file = os.path.join(train_dir, "earlystop.pkl")
                saver = ModelSaver(save_file)
                saver.save_pytorch(self.model)
                logger.info('[INFO] Saving early stop model to %s', save_file)
                raise EarlyStopError("Early stopping raised.")
            else:
                self.wait += 1
        else:
            self.wait = 0

        # lr descent
        if self._hps.lr_descent:
            new_lr = max(5e-6, self._hps.lr / (self.epoch + 1))
            for param_group in list(optimizer.param_groups):
                param_group['lr'] = new_lr
            logger.info("[INFO] The learning rate now is %f", new_lr)
# Example 9
def run_training(model, train_loader, valid_loader, hps):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    logger.info("[INFO] Starting run_training")

    train_dir = os.path.join(hps.save_root, "train")
    os.makedirs(train_dir, exist_ok=True)

    lr = hps.lr
    # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=(0.9, 0.98),
    # eps=1e-09)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=lr)
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    best_train_loss = None
    best_train_F = None
    best_loss = None
    best_F = None
    step_num = 0
    non_descent_cnt = 0
    for epoch in range(1, hps.n_epochs + 1):
        epoch_loss = 0.0
        train_loss = 0.0
        total_example_num = 0
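        # per-epoch counters: pred = predicted positives, true = gold positives,
        # match = correct predictions, match_true = true positives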
        match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0
        epoch_start_time = time.time()
        for i, (batch_x, batch_y) in enumerate(train_loader):
            # if i > 10:
            #     break
            model.train()

            iter_start_time = time.time()

            input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN]
            label = batch_y[Const.TARGET]

            # logger.info(batch_x["text"][0])
            # logger.info(input[0,:,:])
            # logger.info(input_len[0:5,:])
            # logger.info(batch_y["summary"][0:5])
            # logger.info(label[0:5,:])

            # logger.info((len(batch_x["text"][0]), sum(input[0].sum(-1) != 0)))

            batch_size, N, seq_len = input.size()

            if hps.cuda:
                input = input.cuda()  # [batch, N, seq_len]
                label = label.cuda()
                input_len = input_len.cuda()

            input = Variable(input)
            label = Variable(label)
            input_len = Variable(input_len)

            model_outputs = model.forward(input, input_len)  # [batch, N, 2]

            outputs = model_outputs["p_sent"].view(-1, 2)

            label = label.view(-1)

            loss = criterion(outputs, label)  # [batch * N]
            # input_len = input_len.float().view(-1)
            loss = loss.view(batch_size, -1)  # [batch, doc_max_timesteps]
            loss = loss.masked_fill(input_len.eq(0), 0)  # zero out padded sentences
            loss = loss.sum(1).mean()  # sum per document, average over the batch
            logger.debug("loss %f", loss)

            if not np.isfinite(loss.item()):
                logger.error("train Loss is not finite. Stopping.")
                logger.info(loss)
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        logger.info(name)
                        logger.info(param.grad.data.sum())
                raise Exception("train Loss is not finite. Stopping.")

            optimizer.zero_grad()
            loss.backward()
            if hps.grad_clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               hps.max_grad_norm)

            optimizer.step()
            step_num += 1

            train_loss += float(loss.data)
            epoch_loss += float(loss.data)

            if i % 100 == 0:
                # start debugger
                # import pdb; pdb.set_trace()
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        logger.debug(name)
                        logger.debug(param.grad.data.sum())
                logger.info(
                    '       | end of iter {:3d} | time: {:5.2f}s | train loss {:5.4f} | '
                    .format(i, (time.time() - iter_start_time),
                            float(train_loss / 100)))
                train_loss = 0.0

            # calculate the precision, recall and F
            prediction = outputs.max(1)[1]
            prediction = prediction.data
            label = label.data
            pred += prediction.sum()
            true += label.sum()
            match_true += ((prediction == label) & (prediction == 1)).sum()
            match += (prediction == label).sum()
            total_example_num += int(batch_size * N)

        if hps.lr_descent:
            # new_lr = pow(hps.hidden_size, -0.5) * min(pow(step_num, -0.5),
            #                                           step_num * pow(hps.warmup_steps, -1.5))
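            # harmonic decay: base lr / (epoch + 1), floored at 5e-6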
            new_lr = max(5e-6, lr / (epoch + 1))
            for param_group in list(optimizer.param_groups):
                param_group['lr'] = new_lr
            logger.info("[INFO] The learning rate now is %f", new_lr)

        epoch_avg_loss = epoch_loss / len(train_loader)
        logger.info(
            '   | end of epoch {:3d} | time: {:5.2f}s | epoch train loss {:5.4f} | '
            .format(epoch, (time.time() - epoch_start_time),
                    float(epoch_avg_loss)))

        logger.info(
            "[INFO] Trainset match_true %d, pred %d, true %d, total %d, match %d",
            match_true, pred, true, total_example_num, match)
        accu, precision, recall, F = utils.eval_label(match_true, pred, true,
                                                      total_example_num, match)
        logger.info(
            "[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f",
            total_example_num / hps.doc_max_timesteps, accu, precision, recall,
            F)

        if not best_train_loss or epoch_avg_loss < best_train_loss:
            save_file = os.path.join(train_dir, "bestmodel.pkl")
            logger.info(
                '[INFO] Found new best model with %.3f running_train_loss. Saving to %s',
                float(epoch_avg_loss), save_file)
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            best_train_loss = epoch_avg_loss
        elif epoch_avg_loss > best_train_loss:
            logger.error(
                "[Error] training loss does not descent. Stopping supervisor..."
            )
            save_file = os.path.join(train_dir, "earlystop.pkl")
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            logger.info('[INFO] Saving early stop model to %s', save_file)
            return

        if not best_train_F or F > best_train_F:
            save_file = os.path.join(train_dir, "bestFmodel.pkl")
            logger.info(
                '[INFO] Found new best model with %.3f F score. Saving to %s',
                float(F), save_file)
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            best_train_F = F

        best_loss, best_F, non_descent_cnt = run_eval(model, valid_loader, hps,
                                                      best_loss, best_F,
                                                      non_descent_cnt)

        if non_descent_cnt >= 3:
            logger.error(
                "[Error] val loss does not descent for three times. Stopping supervisor..."
            )
        save_file = os.path.join(train_dir, "earlystop.pkl")
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            logger.info('[INFO] Saving early stop model to %s', save_file)
            return
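
utils.eval_label is not shown here; a plausible reconstruction from the counters logged above (match_true = true positives, pred = predicted positives, true = gold positives, match = correct predictions), assuming standard binary accuracy/precision/recall/F1:

def eval_label(match_true, pred, true, total, match):
    # hypothetical reconstruction, not necessarily the project's utils.eval_label
    accu = match / total if total else 0.0
    precision = match_true / pred if pred else 0.0
    recall = match_true / true if true else 0.0
    F = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accu, precision, recall, F
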
def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    logger.info("[INFO] Starting eval for this model ...")
    eval_dir = os.path.join(
        hps.save_root, "eval")  # make a subdir of the root dir for eval data
    os.makedirs(eval_dir, exist_ok=True)

    model.eval()

    running_loss = 0.0
    match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0
    pairs = {}
    pairs["hyps"] = []
    pairs["refer"] = []
    total_example_num = 0
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    iter_start_time = time.time()

    with torch.no_grad():
        for i, (batch_x, batch_y) in enumerate(loader):
            # if i > 10:
            #     break

            input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN]
            label = batch_y[Const.TARGET]

            if hps.cuda:
                input = input.cuda()  # [batch, N, seq_len]
                label = label.cuda()
                input_len = input_len.cuda()

            batch_size, N, _ = input.size()

            input = Variable(input, requires_grad=False)
            label = Variable(label)
            input_len = Variable(input_len, requires_grad=False)

            model_outputs = model.forward(input, input_len)  # [batch, N, 2]
            outputs = model_outputs["p_sent"]
            prediction = model_outputs["prediction"]

            outputs = outputs.view(-1, 2)  # [batch * N, 2]
            label = label.view(-1)  # [batch * N]
            loss = criterion(outputs, label)
            loss = loss.view(batch_size, -1)
            loss = loss.masked_fill(input_len.eq(0), 0)
            loss = loss.sum(1).mean()
            logger.debug("loss %f", loss)
            running_loss += float(loss.data)

            label = label.data.view(batch_size, -1)
            pred += prediction.sum()
            true += label.sum()
            match_true += ((prediction == label) & (prediction == 1)).sum()
            match += (prediction == label).sum()
            total_example_num += batch_size * N

            # rouge
            prediction = prediction.view(batch_size, -1)
            for j in range(batch_size):
                original_article_sents = batch_x["text"][j]
                sent_max_number = len(original_article_sents)
                refer = "\n".join(batch_x["summary"][j])
                hyps = "\n".join(
                    original_article_sents[idx]
                    for idx in range(len(prediction[j]))
                    if prediction[j][idx] == 1 and idx < sent_max_number)
                if sent_max_number < hps.m and len(hyps) <= 1:
                    logger.error("sent_max_number is too short %d, Skip!",
                                 sent_max_number)
                    continue

                if len(hyps) >= 1 and hyps != '.':
                    # logger.debug(prediction[j])
                    pairs["hyps"].append(hyps)
                    pairs["refer"].append(refer)
                elif refer == "." or refer == "":
                    logger.error("Refer is None!")
                    logger.debug("label:")
                    logger.debug(label[j])
                    logger.debug(refer)
                elif hyps == "." or hyps == "":
                    logger.error("hyps is None!")
                    logger.debug("sent_max_number:%d", sent_max_number)
                    logger.debug("prediction:")
                    logger.debug(prediction[j])
                    logger.debug(hyps)
                else:
                    logger.error("Do not select any sentences!")
                    logger.debug("sent_max_number:%d", sent_max_number)
                    logger.debug(original_article_sents)
                    logger.debug("label:")
                    logger.debug(label[j])
                    continue

    running_avg_loss = running_loss / len(loader)

    if hps.use_pyrouge:
        logger.info("The number of pairs is %d", len(pairs["hyps"]))
        logging.getLogger('global').setLevel(logging.WARNING)
        if not len(pairs["hyps"]):
            logger.error("During testing, no hyps were selected!")
            return best_loss, best_F, non_descent_cnt
        if isinstance(pairs["refer"][0], list):
            logger.info("Multi Reference summaries!")
            scores_all = utils.pyrouge_score_all_multi(pairs["hyps"],
                                                       pairs["refer"])
        else:
            scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"])
    else:
        if len(pairs["hyps"]) == 0 or len(pairs["refer"]) == 0:
            logger.error("During testing, no hyps were selected!")
            return best_loss, best_F, non_descent_cnt
        rouge = Rouge()
        scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True)
        # try:
        #     scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True)
        # except ValueError as e:
        #     logger.error(repr(e))
        #     scores_all = []
        #     for idx in range(len(pairs["hyps"])):
        #         try:
        #             scores = rouge.get_scores(pairs["hyps"][idx], pairs["refer"][idx])[0]
        #             scores_all.append(scores)
        #         except ValueError as e:
        #             logger.error(repr(e))
        #             logger.debug("HYPS:\t%s", pairs["hyps"][idx])
        #             logger.debug("REFER:\t%s", pairs["refer"][idx])
        # finally:
        #     logger.error("During testing, some errors happen!")
        #     logger.error(len(scores_all))
        #     exit(1)

    logger.info(
        '[INFO] End of valid | time: {:5.2f}s | valid loss {:5.4f} | '.format(
            (time.time() - iter_start_time), float(running_avg_loss)))

    logger.info(
        "[INFO] Validset match_true %d, pred %d, true %d, total %d, match %d",
        match_true, pred, true, total_example_num, match)
    accu, precision, recall, F = utils.eval_label(match_true, pred, true,
                                                  total_example_num, match)
    logger.info(
        "[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f",
        total_example_num / hps.doc_max_timesteps, accu, precision, recall, F)

    res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \
            + "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \
                + "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f'])
    logger.info(res)

    # If running_avg_loss is best so far, save this checkpoint (early stopping).
    # The checkpoint is saved as bestmodel.pkl in the eval dir
    if best_loss is None or running_avg_loss < best_loss:
        bestmodel_save_path = os.path.join(
            eval_dir, 'bestmodel.pkl'
        )  # this is where checkpoints of best models are saved
        if best_loss is not None:
            logger.info(
                '[INFO] Found new best model with %.6f running_avg_loss. The original loss is %.6f, Saving to %s',
                float(running_avg_loss), float(best_loss), bestmodel_save_path)
        else:
            logger.info(
                '[INFO] Found new best model with %.6f running_avg_loss. The original loss is None, Saving to %s',
                float(running_avg_loss), bestmodel_save_path)
        saver = ModelSaver(bestmodel_save_path)
        saver.save_pytorch(model)
        best_loss = running_avg_loss
        non_descent_cnt = 0
    else:
        non_descent_cnt += 1

    if best_F is None or best_F < F:
        bestmodel_save_path = os.path.join(
            eval_dir, 'bestFmodel.pkl'
        )  # this is where checkpoints of best models are saved
        if best_F is not None:
            logger.info(
                '[INFO] Found new best model with %.6f F. The original F is %.6f, Saving to %s',
                float(F), float(best_F), bestmodel_save_path)
        else:
            logger.info(
                '[INFO] Found new best model with %.6f F. The original F is None, Saving to %s',
                float(F), bestmodel_save_path)
        saver = ModelSaver(bestmodel_save_path)
        saver.save_pytorch(model)
        best_F = F

    return best_loss, best_F, non_descent_cnt
    def save_model(self, save_file):
        saver = ModelSaver(save_file)
        saver.save_pytorch(self.model)
        logger.info('[INFO] Saving model to %s', save_file)