Example #1
0
def train_and_evaluate(model,
                       data_loader,
                       train_data,
                       val_data,
                       test_data,
                       optimizer,
                       metrics,
                       params,
                       model_dir,
                       data_encoder,
                       label_encoder,
                       restore_file=None,
                       best_model='val',
                       save_model=True,
                       eval=True):
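    """Train `model` for `params.num_epochs` epochs and, when `eval` is True,
    evaluate it on the validation and test sets after every epoch.

    Early stopping tracks the `accuracy` metric with `params.patience`;
    metric plots go to `model_dir/plots`; checkpoints and the data/label
    encoders are saved when `save_model` is True. Returns the index of the
    last (or last improving) epoch.
    """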

    from src.ner.utils import SummaryWriter, Label, plot

    # plotting tools
    train_summary_writer = SummaryWriter([*metrics] + ['loss'], name='train')
    val_summary_writer = SummaryWriter([*metrics] + ['loss'], name='val')
    test_summary_writer = SummaryWriter([*metrics] + ['loss'], name='test')
    writers = [train_summary_writer, val_summary_writer, test_summary_writer]
    labeller = Label(anchor_metric='accuracy', anchor_writer='val')
    plots_dir = os.path.join(model_dir, 'plots')
    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)

    start_epoch = -1
    if restore_file is not None:
        logging.info("Restoring parameters from {}".format(restore_file))
        checkpoint = utils.load_checkpoint(restore_file, model, optimizer)
        start_epoch = checkpoint['epoch']

    # save a snapshot of the parameters for reproducibility
    utils.save_dict_to_json(params.dict,
                            os.path.join(model_dir, 'train_snapshot.json'))

    # variable initialization
    best_acc = 0.0
    patience = 0
    early_stopping_metric = 'accuracy'
    is_best = False  # initialized so the checkpointing logic below never hits an unbound name

    if not val_data and (eval or best_model == 'val'):
        raise ValueError('No validation data has been passed.')

    for epoch in range(start_epoch + 1, params.num_epochs):

        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        num_steps = (params.train_size + 1) // params.batch_size
        train_data_iterator = data_loader.batch_iterator(train_data,
                                                         params,
                                                         shuffle=True)

        train_metrics = train(model, optimizer, train_data_iterator, metrics,
                              params.save_summary_steps, num_steps,
                              label_encoder)
        train_summary_writer.update(train_metrics)

        train_acc = train_metrics[early_stopping_metric]
        if best_model == 'train':
            is_best = train_acc >= best_acc
        if eval:
            # Evaluate for one epoch on validation set
            num_steps = (params.val_size + 1) // params.batch_size
            val_data_iterator = data_loader.batch_iterator(val_data,
                                                           params,
                                                           shuffle=False)
            val_metrics = evaluate(model,
                                   val_data_iterator,
                                   metrics,
                                   num_steps,
                                   data_encoder,
                                   label_encoder,
                                   mode='Val')
            val_summary_writer.update(val_metrics)

            val_acc = val_metrics[early_stopping_metric]
            if best_model == 'val':
                is_best = val_acc >= best_acc

            ### TEST (reporting only; not used for model selection)
            num_steps = (params.test_size + 1) // params.batch_size
            test_data_iterator = data_loader.batch_iterator(test_data,
                                                            params,
                                                            shuffle=False)
            test_metrics = evaluate(model,
                                    test_data_iterator,
                                    metrics,
                                    num_steps,
                                    data_encoder,
                                    label_encoder,
                                    mode='Test')
            test_summary_writer.update(test_metrics)

        labeller.update(writers=writers)

        plot(writers=writers, plot_dir=plots_dir, save=True)

        # Save weights
        if save_model:
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict()
                },
                is_best=is_best,
                checkpoint=model_dir)
            # save encoders only if they do not exist yet
            if not os.path.exists(os.path.join(model_dir, 'data_encoder.pkl')):
                utils.save_obj(data_encoder,
                               os.path.join(model_dir, 'data_encoder.pkl'))
            if not os.path.exists(os.path.join(model_dir,
                                               'label_encoder.pkl')):
                utils.save_obj(label_encoder,
                               os.path.join(model_dir, 'label_encoder.pkl'))

        # if this is the best epoch so far, reset patience and persist its metrics
        if is_best:
            patience = 0
            logging.info("- Found new best accuracy")
            best_acc = train_acc if best_model == 'train' else val_acc
            # Save best metrics in a json file in the model directory
            if eval:
                utils.save_dict_to_json(
                    val_metrics,
                    os.path.join(model_dir, "metrics_val_best_weights.json"))
            utils.save_dict_to_json(
                train_metrics,
                os.path.join(model_dir, "metrics_train_best_weights.json"))
        else:
            if eval:
                patience += 1
                logging.info('current patience: {} ; max patience: {}'.format(
                    patience, params.patience))
            if patience == params.patience:
                logging.info(
                    'patience reached. Exiting at epoch: {}'.format(epoch + 1))
                # Save latest metrics in a json file in the model directory before exiting
                if eval:
                    utils.save_dict_to_json(
                        val_metrics,
                        os.path.join(model_dir, 'plots',
                                     "metrics_val_last_weights.json"))
                    utils.save_dict_to_json(
                        test_metrics,
                        os.path.join(model_dir, 'plots',
                                     "metrics_test_last_weights.json"))
                utils.save_dict_to_json(
                    train_metrics,
                    os.path.join(model_dir, 'plots',
                                 "metrics_train_last_weights.json"))
                epoch = epoch - patience  # report the last epoch that still improved the metric
                break

        # Save latest metrics in a json file in the model directory at end of epoch
        if eval:
            utils.save_dict_to_json(
                val_metrics,
                os.path.join(model_dir, 'plots',
                             "metrics_val_last_weights.json"))
            utils.save_dict_to_json(
                test_metrics,
                os.path.join(model_dir, 'plots',
                             "metrics_test_last_weights.json"))
        utils.save_dict_to_json(
            train_metrics,
            os.path.join(model_dir, 'plots',
                         "metrics_train_last_weights.json"))
    return epoch
def train_and_evaluate(model,
                       data_loader,
                       train_data,
                       val_data,
                       test_data,
                       optimizer,
                       metrics,
                       params,
                       model_dir,
                       data_encoder,
                       label_encoder,
                       restore_file=None,
                       save_model=True,
                       eval=True):
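    """Train `model` for `params.num_epochs` epochs with an inverse-decay
    LambdaLR learning-rate schedule and early stopping on the `f1_score`
    metric (controlled by `params.patience`).

    When `eval` is True the model is also evaluated on the validation and
    test sets after every epoch; metric plots go to `model_dir/plots`;
    checkpoints and the data/label encoders are saved when `save_model` is
    True. Returns the index of the last (or last improving) epoch.
    """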
    from src.ner.utils import SummaryWriter, Label, plot

    # plotting tools
    train_summary_writer = SummaryWriter([*metrics] + ['loss'], name='train')
    val_summary_writer = SummaryWriter([*metrics] + ['loss'], name='val')
    test_summary_writer = SummaryWriter([*metrics] + ['loss'], name='test')
    writers = [train_summary_writer, val_summary_writer, test_summary_writer]
    labeller = Label(anchor_metric='f1_score',
                     anchor_writer='val')
    plots_dir = os.path.join(model_dir, 'plots')
    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)

    start_epoch = -1
    if restore_file is not None:
        logging.info("Restoring parameters from {}".format(restore_file))
        checkpoint = utils.load_checkpoint(restore_file, model, optimizer)
        start_epoch = checkpoint['epoch']

    # save a snapshot of the parameters for reproducibility
    utils.save_dict_to_json(params.dict, os.path.join(model_dir, 'train_snapshot.json'))

    # variable initialization
    best_val_score = 0.0
    patience = 0
    early_stopping_metric = 'f1_score'

    # set the learning-rate scheduler (inverse decay over epochs)
    lambda_lr = lambda epoch: 1 / (1 + (params.lr_decay_rate * epoch))
    lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda_lr, last_epoch=start_epoch)
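    # the multiplier applied to the base lr at epoch e is 1 / (1 + lr_decay_rate * e);
    # e.g. a (hypothetical) lr_decay_rate of 0.1 halves the learning rate by epoch 10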

    # train over epochs
    for epoch in range(start_epoch + 1, params.num_epochs):
        lr_scheduler.step()
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))
        logging.info("Learning Rate : {}".format(lr_scheduler.get_lr()))

        # compute number of batches in one epoch (one full pass over the training set)
        # num_steps = (params.train_size + 1) // params.batch_size
        num_steps = (train_data['size'] + 1) // params.batch_size
        train_data_iterator = data_loader.batch_iterator(train_data, batch_size=params.batch_size, shuffle=True)
        train_metrics = train(model,
                              optimizer,
                              train_data_iterator,
                              metrics,
                              params,
                              num_steps,
                              data_encoder,
                              label_encoder)
        # fall back to the training score for model selection; overwritten below when eval is on
        val_score = train_metrics[early_stopping_metric]
        is_best = val_score >= best_val_score
        train_summary_writer.update(train_metrics)

        if eval:
            # Evaluate for one epoch on validation set
            # num_steps = (params.val_size + 1) // params.batch_size
            num_steps = (val_data['size'] + 1) // params.batch_size
            val_data_iterator = data_loader.batch_iterator(val_data, batch_size=params.batch_size, shuffle=False)
            val_metrics = evaluate(model,
                                   val_data_iterator,
                                   metrics,
                                   num_steps,
                                   label_encoder,
                                   mode='val')

            val_score = val_metrics[early_stopping_metric]
            is_best = val_score >= best_val_score
            val_summary_writer.update(val_metrics)

            ### TEST (reporting only; not used for model selection)
            # num_steps = (params.test_size + 1) // params.batch_size
            num_steps = (test_data['size'] + 1) // params.batch_size
            test_data_iterator = data_loader.batch_iterator(test_data, batch_size=params.batch_size, shuffle=False)
            test_metrics = evaluate(model,
                                    test_data_iterator,
                                    metrics,
                                    num_steps,
                                    label_encoder,
                                    mode='test')
            test_summary_writer.update(test_metrics)

        labeller.update(writers=writers)

        plot(writers=writers,
             plot_dir=plots_dir,
             save=True)

        # Save weights
        if save_model:
            utils.save_checkpoint({'epoch': epoch,
                                   'state_dict': model.state_dict(),
                                   'optim_dict': optimizer.state_dict()},
                                  is_best=is_best,
                                  checkpoint=model_dir)

            # save encoders only if they do not exist yet
            if not os.path.exists(os.path.join(model_dir, 'data_encoder.pkl')):
                utils.save_obj(data_encoder, os.path.join(model_dir, 'data_encoder.pkl'))
            if not os.path.exists(os.path.join(model_dir, 'label_encoder.pkl')):
                utils.save_obj(label_encoder, os.path.join(model_dir, 'label_encoder.pkl'))

        # if this is the best epoch so far, reset patience and persist its metrics
        if is_best:
            patience = 0
            logging.info("- Found new best F1 score")
            best_val_score = val_score
            # Save best metrics in a json file in the model directory
            if eval:
                utils.save_dict_to_json(val_metrics, os.path.join(model_dir, 'plots', "metrics_val_best_weights.json"))
                utils.save_dict_to_json(test_metrics, os.path.join(model_dir, 'plots', "metrics_test_best_weights.json"))
            utils.save_dict_to_json(train_metrics, os.path.join(model_dir, 'plots', "metrics_train_best_weights.json"))
        else:
            if eval:
                patience += 1
                logging.info('current patience: {} ; max patience: {}'.format(patience, params.patience))
            if patience == params.patience:
                logging.info('patience reached. Exiting at epoch: {}'.format(epoch + 1))
                # Save latest metrics in a json file in the model directory before exiting
                if eval:
                    utils.save_dict_to_json(val_metrics, os.path.join(model_dir, 'plots', "metrics_val_last_weights.json"))
                    utils.save_dict_to_json(test_metrics,
                                            os.path.join(model_dir, 'plots', "metrics_test_last_weights.json"))
                utils.save_dict_to_json(train_metrics, os.path.join(model_dir, 'plots', "metrics_train_last_weights.json"))
                epoch = epoch - patience  # report the last epoch that still improved the metric
                break

        # Save latest metrics in a json file in the model directory at end of epoch
        if eval:
            utils.save_dict_to_json(val_metrics, os.path.join(model_dir, 'plots', "metrics_val_last_weights.json"))
            utils.save_dict_to_json(test_metrics, os.path.join(model_dir, 'plots', "metrics_test_last_weights.json"))
        utils.save_dict_to_json(train_metrics, os.path.join(model_dir, 'plots', "metrics_train_last_weights.json"))
    return epoch
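

# For orientation: a minimal, hypothetical driver for train_and_evaluate.
# The Params/DataLoader/Net/metrics names, the data_encoder/label_encoder
# attributes and the src.ner module paths below are assumptions inferred from
# how the arguments are used above, not confirmed parts of the repository.
import os

import torch.optim as optim

from src.ner import utils                       # save_*/load_checkpoint helpers used above
from src.ner.data_loader import DataLoader      # assumed loader location
from src.ner.model import Net, metrics          # assumed model class and metrics dict

model_dir = 'experiments/base_model'
params = utils.Params(os.path.join(model_dir, 'params.json'))  # assumed helper

data_loader = DataLoader('data/', params)                      # assumed signature
data = data_loader.load_data(['train', 'val', 'test'])         # assumed signature

model = Net(params)
optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)

last_epoch = train_and_evaluate(model, data_loader,
                                data['train'], data['val'], data['test'],
                                optimizer, metrics, params, model_dir,
                                data_loader.data_encoder,      # assumed attribute
                                data_loader.label_encoder,     # assumed attribute
                                restore_file=None, save_model=True, eval=True)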