Beispiel #1
0
def fit(model, train_data, dev_data):
    """Run the epoch loop with early stopping on dev seqeval f1.

    Restores a checkpoint when `config.restore` is set, writes train/dev
    summaries, saves the best checkpoint plus frozen graph whenever the
    dev f1 improves, and closes the model's session when done.
    """
    config = model.config
    sess = model.sess

    # restore previous model if provided
    saver = tf.train.Saver()
    if config.restore is not None:
        saver.restore(sess, config.restore)
        tf.logging.debug('model restored')

    # summary setting: one merged op for training scalars, one writer each
    # for the train and dev summary directories
    scalar_ops = [
        tf.summary.scalar('loss', model.loss),
        tf.summary.scalar('accuracy', model.accuracy),
        tf.summary.scalar('f1', model.f1),
        tf.summary.scalar('learning_rate', model.learning_rate),
    ]
    train_summary_op = tf.summary.merge(scalar_ops)
    train_summary_writer = tf.summary.FileWriter(
        os.path.join(config.summary_dir, 'summaries', 'train'), sess.graph)
    dev_summary_writer = tf.summary.FileWriter(
        os.path.join(config.summary_dir, 'summaries', 'dev'), sess.graph)

    # train and evaluate
    early_stopping = EarlyStopping(patience=10, measure='f1', verbose=1)
    max_seqeval_f1 = 0
    for e in range(config.epoch):
        train_step(model, train_data, train_summary_op, train_summary_writer)
        seqeval_f1, avg_f1 = dev_step(model, dev_data, dev_summary_writer, e)
        # early stopping
        if early_stopping.validate(seqeval_f1, measure='f1'):
            break
        if seqeval_f1 > max_seqeval_f1:
            tf.logging.debug('new best f1 score! : %s' % seqeval_f1)
            max_seqeval_f1 = seqeval_f1
            # persist the best weights and both graph serializations
            save_path = saver.save(
                sess, config.checkpoint_dir + '/' + 'ner_model')
            tf.logging.debug('max model saved in file: %s' % save_path)
            tf.train.write_graph(sess.graph, '.',
                                 config.checkpoint_dir + '/' + 'graph.pb',
                                 as_text=False)
            tf.train.write_graph(sess.graph, '.',
                                 config.checkpoint_dir + '/' + 'graph.pb_txt',
                                 as_text=True)
            early_stopping.reset(max_seqeval_f1)
        early_stopping.status()
    sess.close()
Beispiel #2
0
def hp_search(trial: optuna.Trial):
    """Optuna objective: train with sampled hyperparameters.

    Returns the last epoch's eval accuracy; reports intermediate values to
    the trial so unpromising runs can be pruned.
    """
    if torch.cuda.is_available():
        logger.info("%s", torch.cuda.get_device_name(0))

    global gopt
    opt = gopt
    # set config
    config = load_config(opt)
    config['opt'] = opt
    logger.info("%s", config)

    # set path
    set_path(config)

    # sample the search space (the order of suggest calls is kept stable)
    lr = trial.suggest_loguniform('lr', 1e-6, 1e-3) # .suggest_float('lr', 1e-6, 1e-3, log=True)
    bsz = trial.suggest_categorical('batch_size', [32, 64, 128])
    seed = trial.suggest_int('seed', 17, 42)
    epochs = trial.suggest_int('epochs', 1, opt.epoch)

    # prepare train, valid dataset
    train_loader, valid_loader = prepare_datasets(config, hp_search_bsz=bsz)

    with temp_seed(seed):
        # prepare model
        model = prepare_model(config)
        # create optimizer, scheduler, summary writer, scaler
        optimizer, scheduler, writer, scaler = prepare_osws(
            config, model, train_loader, hp_search_lr=lr)
        for key, obj in (('optimizer', optimizer), ('scheduler', scheduler),
                         ('writer', writer), ('scaler', scaler)):
            config[key] = obj

        early_stopping = EarlyStopping(logger, patience=opt.patience,
                                       measure=opt.measure, verbose=1)
        minimize = opt.measure == 'loss'
        best_eval_measure = float('inf') if minimize else -float('inf')
        for epoch in range(epochs):
            eval_loss, eval_acc = train_epoch(model, config, train_loader,
                                              valid_loader, epoch)

            eval_measure = eval_loss if minimize else eval_acc
            # early stopping
            if early_stopping.validate(eval_measure, measure=opt.measure):
                break
            improved = (eval_measure < best_eval_measure if minimize
                        else eval_measure > best_eval_measure)
            if improved:
                best_eval_measure = eval_measure
                early_stopping.reset(best_eval_measure)
            early_stopping.status()

            # report to optuna so the pruner can kill bad trials early
            trial.report(eval_acc, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        return eval_acc
Beispiel #3
0
def train(opt):
    """Train a model with early stopping on dev f1.

    `train_epoch` performs the per-epoch training/validation and tracks the
    best dev f1 itself (it receives and returns `best_eval_f1`); this loop
    only drives epochs and the early-stopping state.

    Fix: removed `local_worse_epoch` and `epoch_st_time`, which were
    assigned but never read.
    """
    if torch.cuda.is_available():
        logger.info("%s", torch.cuda.get_device_name(0))

    # set etc
    torch.autograd.set_detect_anomaly(True)

    # set config
    config = load_config(opt)
    config['opt'] = opt
    logger.info("%s", config)

    # set path
    set_path(config)

    # prepare train, valid dataset
    train_loader, valid_loader = prepare_datasets(config)

    with temp_seed(opt.seed):
        # prepare model
        model = prepare_model(config)

        # create optimizer, scheduler, summary writer, scaler
        optimizer, scheduler, writer, scaler = prepare_osws(
            config, model, train_loader)
        config['optimizer'] = optimizer
        config['scheduler'] = scheduler
        config['writer'] = writer
        config['scaler'] = scaler

        # training
        early_stopping = EarlyStopping(logger,
                                       patience=opt.patience,
                                       measure='f1',
                                       verbose=1)
        best_eval_f1 = -float('inf')
        for epoch_i in range(opt.epoch):
            eval_loss, eval_f1, best_eval_f1 = train_epoch(
                model, config, train_loader, valid_loader, epoch_i,
                best_eval_f1)
            # early stopping
            if early_stopping.validate(eval_f1, measure='f1'):
                break
            # train_epoch returns the updated best; equality means this
            # epoch set a new best f1, so reset the patience counter.
            if eval_f1 == best_eval_f1:
                early_stopping.reset(best_eval_f1)
            early_stopping.status()
Beispiel #4
0
def do_train(model, config, train_data, dev_data):
    """Train `model` with early stopping on dev f1, saving the best checkpoint.

    Creates its own TF session, initializes variables while feeding the
    large word-embedding matrix, optionally restores a previous checkpoint,
    and writes train/dev summaries each epoch.

    Fix: the session is now closed in a `finally` block so it is released
    even when a training/eval step raises (previously an exception leaked
    the session and its GPU memory).
    """
    early_stopping = EarlyStopping(patience=10, measure='f1', verbose=1)
    maximum = 0  # best dev f1 observed so far
    session_conf = tf.ConfigProto(allow_soft_placement=True,
                                  log_device_placement=False)
    session_conf.gpu_options.allow_growth = True
    sess = tf.Session(config=session_conf)
    try:
        feed_dict = {model.wrd_embeddings_init: config.embvec.wrd_embeddings}
        sess.run(tf.global_variables_initializer(),
                 feed_dict=feed_dict)  # feed large embedding data
        saver = tf.train.Saver()
        if config.restore is not None:
            saver.restore(sess, config.restore)
            print('model restored')

        # summary setting
        loss_summary = tf.summary.scalar('loss', model.loss)
        acc_summary = tf.summary.scalar('accuracy', model.accuracy)
        train_summary_op = tf.summary.merge([loss_summary, acc_summary])
        train_summary_dir = os.path.join(config.summary_dir, 'summaries',
                                         'train')
        train_summary_writer = tf.summary.FileWriter(train_summary_dir,
                                                     sess.graph)
        dev_summary_dir = os.path.join(config.summary_dir, 'summaries', 'dev')
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)
        for e in range(config.epoch):
            train_step(sess, model, config, train_data, train_summary_op,
                       train_summary_writer)
            m = dev_step(sess, model, config, dev_data, dev_summary_writer, e)
            # early stopping
            if early_stopping.validate(m, measure='f1'): break
            if m > maximum:
                print('new best f1 score! : %s' % m)
                maximum = m
                # save best model
                save_path = saver.save(sess,
                                       config.checkpoint_dir + '/' + 'ner_model')
                print('max model saved in file: %s' % save_path)
                tf.train.write_graph(sess.graph,
                                     '.',
                                     config.checkpoint_dir + '/' + 'graph.pb',
                                     as_text=False)
                tf.train.write_graph(sess.graph,
                                     '.',
                                     config.checkpoint_dir + '/' + 'graph.pb_txt',
                                     as_text=True)
                # NOTE(review): unlike the sibling training loops, there is no
                # early_stopping.reset(...) on a new best here — confirm intended.
    finally:
        # guarantee session cleanup even if a step raises
        sess.close()
Beispiel #5
0
def train(opt):
    """Train a model with early stopping on dev f1 and manual LR decay.

    Saves the best model (and, for transformer embeddings, the finetuned
    tokenizer/model) whenever dev f1 improves. Past the warmup period,
    `scheduler.step()` is applied after dev f1 has failed to improve for
    `opt.lr_decay_steps` consecutive epochs.
    """
    if torch.cuda.is_available():
        logger.info("%s", torch.cuda.get_device_name(0))

    # set etc
    torch.autograd.set_detect_anomaly(True)

    # set config
    config = load_config(opt)
    config['opt'] = opt
    logger.info("%s", config)

    # set path
    set_path(config)

    # prepare train, valid dataset
    train_loader, valid_loader = prepare_datasets(config)

    with temp_seed(opt.seed):
        # prepare model
        model = prepare_model(config)

        # create optimizer, scheduler, summary writer, scaler
        optimizer, scheduler, writer, scaler = prepare_osws(config, model, train_loader)
        config['optimizer'] = optimizer
        config['scheduler'] = scheduler
        config['writer'] = writer
        config['scaler'] = scaler

        # training
        early_stopping = EarlyStopping(logger, patience=opt.patience, measure='f1', verbose=1)
        local_worse_steps = 0          # consecutive epochs without f1 improvement
        prev_eval_f1 = -float('inf')   # f1 from the previous epoch
        best_eval_f1 = -float('inf')   # best f1 seen so far
        for epoch_i in range(opt.epoch):
            epoch_st_time = time.time()  # NOTE(review): assigned but never read in this view
            eval_loss, eval_f1 = train_epoch(model, config, train_loader, valid_loader, epoch_i)
            # early stopping
            if early_stopping.validate(eval_f1, measure='f1'): break
            if eval_f1 > best_eval_f1:
                best_eval_f1 = eval_f1
                if opt.save_path:
                    logger.info("[Best model saved] : {:10.6f}".format(best_eval_f1))
                    save_model(config, model)
                    # save finetuned bert model/config/tokenizer
                    if config['emb_class'] in ['bert', 'distilbert', 'albert', 'roberta', 'bart', 'electra']:
                        if not os.path.exists(opt.bert_output_dir):
                            os.makedirs(opt.bert_output_dir)
                        model.bert_tokenizer.save_pretrained(opt.bert_output_dir)
                        model.bert_model.save_pretrained(opt.bert_output_dir)
                early_stopping.reset(best_eval_f1)
            early_stopping.status()
            # LR scheduling: count epochs where f1 did not improve over the
            # previous epoch; decay the learning rate once that streak (or the
            # early-stopping step count) exceeds opt.lr_decay_steps.
            if prev_eval_f1 >= eval_f1:
                local_worse_steps += 1
            else:
                local_worse_steps = 0
            logger.info('Scheduler: local_worse_steps / opt.lr_decay_steps = %d / %d' % (local_worse_steps, opt.lr_decay_steps))
            # only decay when not using the transformers-provided scheduler and
            # past the warmup epochs
            if not opt.use_transformers_optimizer and \
               epoch_i > opt.warmup_epoch and \
               (local_worse_steps >= opt.lr_decay_steps or early_stopping.step() > opt.lr_decay_steps):
                scheduler.step()
                local_worse_steps = 0
            prev_eval_f1 = eval_f1
Beispiel #6
0
                x_hypothesis: x_dev_hypothesis,
                embedding_arr: pretrained_weights,
                y: y_dev,
                keep_prob: 1.0,        # dropout disabled for evaluation
                keep_prob_lstm: 1.0,
                is_training: False
            })
            dev_pre_results.append(res)
        # dev accuracy = mean of per-batch results
        acc = np.mean(dev_pre_results)


        print('epoch: ', epoch,
              # 'dev_pre_results', dev_pre_results,
              'dev_acc', acc)

        # early stopping on dev accuracy
        if es.validate(acc):
            break


    # evaluation on the test data
    test_pre_results = []
    # NOTE(review): uses the dev batch count for the test loop — confirm
    # n_batches_test was intended.
    for i in range(n_batches_dev):

        res = prediction_result.eval(session=sess, feed_dict={
            x_premise: x_test_premise,
            x_hypothesis: x_test_hypothesis,
            embedding_arr: pretrained_weights,
            # NOTE(review): feeds dev labels while evaluating test inputs —
            # looks like a bug; confirm y_test was intended.
            y: y_dev,
            keep_prob: 1.0,
            keep_prob_lstm: 1.0,
            is_training: False
Beispiel #7
0
def train(opt):
    """Train a model with early stopping on the configured measure.

    When `opt.hp_search_nni` is set, the eval accuracy of each epoch is
    reported to NNI as an intermediate result and the last epoch's accuracy
    as the final result.
    """
    if torch.cuda.is_available():
        logger.info("%s", torch.cuda.get_device_name(0))

    # set etc
    torch.autograd.set_detect_anomaly(True)

    # set config
    config = load_config(opt)
    config['opt'] = opt
    logger.info("%s", config)

    # set path
    set_path(config)

    # prepare train, valid dataset
    train_loader, valid_loader = prepare_datasets(config)

    with temp_seed(opt.seed):
        # prepare model
        model = prepare_model(config)

        # create optimizer, scheduler, summary writer, scaler
        optimizer, scheduler, writer, scaler = prepare_osws(
            config, model, train_loader)
        config['optimizer'] = optimizer
        config['scheduler'] = scheduler
        config['writer'] = writer
        config['scaler'] = scaler

        # training
        early_stopping = EarlyStopping(logger,
                                       patience=opt.patience,
                                       measure=opt.measure,
                                       verbose=1)
        local_worse_epoch = 0  # NOTE(review): assigned but never used in this view
        # lower-is-better for 'loss', higher-is-better otherwise
        best_eval_measure = float(
            'inf') if opt.measure == 'loss' else -float('inf')
        for epoch_i in range(opt.epoch):
            epoch_st_time = time.time()  # NOTE(review): assigned but never read in this view
            # train_epoch tracks the best measure itself (receives and returns it)
            eval_loss, eval_acc, best_eval_measure = train_epoch(
                model, config, train_loader, valid_loader, epoch_i,
                best_eval_measure)
            # for nni
            if opt.hp_search_nni:
                nni.report_intermediate_result(eval_acc)
                logger.info('[eval_acc] : %g', eval_acc)
                logger.info('[Pipe send intermediate result done]')
            if opt.measure == 'loss': eval_measure = eval_loss
            else: eval_measure = eval_acc
            # early stopping
            if early_stopping.validate(eval_measure, measure=opt.measure):
                break
            # equality with the returned best means this epoch set a new best
            if eval_measure == best_eval_measure:
                early_stopping.reset(best_eval_measure)
            early_stopping.status()
        # for nni
        # NOTE(review): eval_acc is unbound here if opt.epoch == 0 — confirm
        # epoch is always >= 1.
        if opt.hp_search_nni:
            nni.report_final_result(eval_acc)
            logger.info('[Final result] : %g', eval_acc)
            logger.info('[Send final result done]')