Code Example #1
import os

import numpy as np

# data_utils and DataBatcher are project-local modules from the surrounding repo.


def do_predict(settings, args):
    #
    if args.ckpt_loading == "latest":
        dir_ckpt = settings.model_dir
    else:
        dir_ckpt = settings.model_dir_best
    #
    pb_file = os.path.join(dir_ckpt, "model_frozen.pb")
    #
    # model
    model = settings.ModelClass(settings)
    model.prepare_for_prediction_with_pb(pb_file)
    #
    # data
    if args.data == "test":
        file_raw = os.path.join(args.dir_examples, "data_examples_test.txt")
    elif args.data == "train":
        file_raw = os.path.join(args.dir_examples, "data_examples_train.txt")
    elif args.data == "valid":
        file_raw = os.path.join(args.dir_examples, "data_examples_valid.txt")
    else:
        raise ValueError("args.data must be one of: test, train, valid")
    #
    data_raw = data_utils.load_from_file_raw(file_raw)
    #
    batch_stder = lambda x: data_utils.get_batch_std(x, settings)
    data_batcher = DataBatcher(data_raw,
                               batch_stder,
                               settings.batch_size_eval,
                               single_pass=True)
    #
    # predict
    count = 0
    while True:
        batch = data_batcher.get_next_batch()
        #
        if batch is None: break
        if count == settings.max_batches_eval: break  # stop once the eval batch cap is reached
        #
        count += 1
        print(count)
        #
        print("batch data:")
        print(batch["input_y"])
        print("batch data end")
        #
        results = model.predict_with_pb_from_batch(batch)["logits"]
        #
        print("results:")
        print(np.argmax(results[0], -1))
        print("results end")
        print()
        #
    #
    settings.logger.info('prediction finished, with total num_batches: %d' %
                         count)
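
Note on the frozen graph used above: do_predict runs inference from model_frozen.pb via prepare_for_prediction_with_pb rather than restoring a checkpoint, and the snippet does not show how that file is produced. Assuming the project is TensorFlow 1.x (implied by the checkpoint/frozen-graph workflow), a generic freezing step, not the repo's actual code, would look roughly like this:

import tensorflow as tf  # TF 1.x assumed


def freeze_to_pb(sess, output_node_names, pb_path):
    # Convert the session's variables to constants and serialize the GraphDef.
    # output_node_names is a list of output op names, e.g. ["logits"].
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_node_names)
    with tf.gfile.GFile(pb_path, "wb") as fp:
        fp.write(frozen_graph_def.SerializeToString())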
Code Example #2
def do_eval(settings, args):
    #
    if args.ckpt_loading == "latest":
        dir_ckpt = settings.model_dir
    else:
        dir_ckpt = settings.model_dir_best
    #
    # model
    model = settings.ModelClass(settings)
    model.prepare_for_train_and_valid(dir_ckpt)
    model.assign_dropout_keep_prob(1.0)
    #
    # data
    file_raw = os.path.join(args.dir_examples, "data_examples_test.txt")
    data_raw = data_utils.load_from_file_raw(file_raw)
    #
    batch_stder = lambda x: data_utils.get_batch_std(x, settings)
    data_batcher = DataBatcher(data_raw,
                               batch_stder,
                               settings.batch_size_eval,
                               single_pass=True)
    #
    # eval
    eval_score, loss_aver, metric_aver = eval_process(
        model, data_batcher, settings.max_batches_eval, mode_eval=True)
    #
    print('loss_aver, metric_aver: %g, %g' % (loss_aver, metric_aver))
    settings.logger.info('loss_aver, metric_aver: %g, %g' %
                         (loss_aver, metric_aver))
    settings.logger.info('{}'.format(eval_score))
Code Example #3
#
settings.result_dir = os.path.join(settings.base_dir, "result")
if not os.path.exists(settings.base_dir):
    os.mkdir(settings.base_dir)
if not os.path.exists(settings.result_dir):
    os.mkdir(settings.result_dir)
#
# model & vocab
settings.model_tag = "multi_doc_qa"
settings.vocab = vocab
#
# mode
if args.mode == "train":
    example_gen = lambda single_pass: example_generator(train_files, True, settings.max_p_len, single_pass)
    batch_stder = lambda items: do_batch_std(items, vocab, settings)
    batcher = DataBatcher(example_gen, batch_stder, settings.batch_size, single_pass=False)
    #
    settings.is_train = True
    settings.check_settings()
    settings.logger.info("{}".format(args))
    #
    model = ModelDocQA(settings)
    model.prepare_for_train_and_valid(settings.model_dir)
    #
    model_utils.do_train(model, batcher, settings)
    model.close_logger()
    #
elif args.mode == "eval":
    example_gen = lambda single_pass: example_generator(dev_files, True, settings.max_p_len, single_pass)
    batch_stder = lambda items: do_batch_std(items, vocab, settings)
    batcher = DataBatcher(example_gen, batch_stder, settings.batch_size, single_pass=True)  # batch_size
Code Example #4
def do_train_and_valid(settings, args):
    #
    if args.ckpt_loading == "latest":
        dir_ckpt = settings.model_dir
    else:
        dir_ckpt = settings.model_dir_best
    #
    # model
    model = settings.ModelClass(settings)
    model.prepare_for_train_and_valid(dir_ckpt)
    #
    # data
    file_raw = os.path.join(args.dir_examples, "data_examples_train.txt")
    data_raw = data_utils.load_from_file_raw(file_raw)
    #
    batch_stder = lambda x: data_utils.get_batch_std(x, settings)
    data_batcher = DataBatcher(data_raw,
                               batch_stder,
                               settings.batch_size,
                               single_pass=False)
    #
    eval_period = settings.valid_period_batch
    file_raw_eval = os.path.join(args.dir_examples, "data_examples_valid.txt")
    data_raw_eval = data_utils.load_from_file_raw(file_raw_eval)
    #
    # train
    loss = 10000.0
    best_metric_val = 0
    # last_improved = 0
    lr = 0.0
    #
    count = 0
    model.settings.logger.info("")
    while True:
        #
        # eval
        if count % eval_period == 0:
            settings.logger.info("training curr batch, loss, lr: %d, %g, %g" %
                                 (count, loss, lr))
            #
            model.save_ckpt(settings.model_dir, settings.model_name, count)
            model.assign_dropout_keep_prob(1.0)  # disable dropout while evaluating
            #
            settings.logger.info('evaluating after num_batches: %d' % count)
            eval_batcher = DataBatcher(data_raw_eval,
                                       batch_stder,
                                       settings.batch_size,
                                       single_pass=True)
            #
            eval_score, loss_aver, metric_val = eval_process(
                model,
                eval_batcher,
                settings.max_batches_eval,
                mode_eval=False)
            settings.logger.info(
                "eval loss_aver, metric, metric_best: %g, %g, %g" %
                (loss_aver, metric_val, best_metric_val))
            #
            # save best
            if metric_val >= best_metric_val:  # >=
                best_metric_val = metric_val
                # last_improved = count
                # ckpt
                settings.logger.info('a new best model, saving ...')
                model.save_ckpt_best(settings.model_dir_best,
                                     settings.model_name, count)
                #
            #
            if lr < settings.learning_rate_minimum and count > settings.warmup_steps:
                settings.logger.info(
                    'current learning_rate < learning_rate_minimum, stop training'
                )
                break
            #
            model.assign_dropout_keep_prob(settings.keep_prob)  # restore dropout for training
            settings.logger.info("")
            #
        #
        # train
        batch = data_batcher.get_next_batch()
        # if batch is None: break
        count += 1
        # print(count)
        #
        result_dict = model.run_train_one_batch(batch)  # just for train
        loss = result_dict["loss_optim"]
        lr = result_dict["lr"]
        #
        # print(loss)
        # model.logger.info("training curr batch, loss, lr: %d, %g, %g" % (count, loss, lr)
        #
    #
    settings.logger.info("training finished with total num_batches: %d" % count)
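
The early-stopping test in the loop above (lr < settings.learning_rate_minimum once count > settings.warmup_steps) assumes the model warms the learning rate up and then decays it as training proceeds. The actual schedule is defined inside settings.ModelClass; the sketch below only illustrates a schedule that satisfies that assumption, with made-up default values:

def warmup_then_decay_lr(step, base_lr=1e-3, warmup_steps=1000,
                         decay_rate=0.98, decay_every=500):
    # Illustrative only: linear warmup to base_lr, then stepwise exponential
    # decay; the repo's real schedule may differ.
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    num_decays = (step - warmup_steps) // decay_every
    return base_lr * (decay_rate ** num_decays)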
Code Example #5
import os
import pickle

import numpy as np

# data_utils, DataBatcher, and ModelWrapper are project-local modules from the surrounding repo.


def do_debug(settings, args):
    #
    if args.ckpt_loading == "latest":
        dir_ckpt = settings.model_dir
    else:
        dir_ckpt = settings.model_dir_best
    #
    # model
    model = ModelWrapper(settings, settings.model_graph)
    model.prepare_for_train_and_valid(dir_ckpt)
    model.assign_dropout_keep_prob(1.0)
    #
    # data
    file_raw = os.path.join(args.dir_examples, "data_examples_test.txt")
    data_raw = data_utils.load_from_file_raw(file_raw)
    #
    batch_stder = lambda x: data_utils.get_batch_std(x, settings)
    data_batcher = DataBatcher(data_raw,
                               batch_stder,
                               settings.batch_size_eval,
                               single_pass=True)
    #
    # eval
    list_batches_result = []
    #
    loss_aver, metric_aver = 0.0, 0.0
    count = 0
    while True:
        batch = data_batcher.get_next_batch()
        #
        if batch is None: break
        if count == 1000000: break  # hard cap on the number of debug batches
        #
        count += 1
        # print(count)
        #
        results, loss, metric = model.run_eval_one_batch(batch)
        loss_aver += loss
        metric_aver += metric
        # print(loss)
        # print(metric)
        #
        print(count)
        print("batch data:")
        print(batch[-1])
        #
        print("results:")
        print(np.argmax(results[0], -1))
        print()
        #
        item = batch[0], batch[1], np.argmax(results[0], -1)
        list_batches_result.append(item)
        #
    #
    dir_result = "data_check_result"
    if not os.path.exists(dir_result): os.mkdir(dir_result)
    #
    file_path = os.path.join(
        dir_result, "list_batches_result_%d.pkl" % settings.batch_size_eval)
    #
    with open(file_path, 'wb') as fp:
        pickle.dump(list_batches_result, fp)
    #
    loss_aver /= max(count, 1)    # guard against an empty data file
    metric_aver /= max(count, 1)
    #
    print('loss_aver, metric_aver: %g, %g' % (loss_aver, metric_aver))
    model.logger.info('loss_aver, metric_aver: %g, %g' %
                      (loss_aver, metric_aver))
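
For completeness, a hypothetical command-line entry point that dispatches to the four mode functions shown above. The flag names mirror the args attributes the snippets read (mode, ckpt_loading, data, dir_examples); the original repo's actual CLI may differ, and settings must still be built from the project's own configuration class:

import argparse


def parse_args():
    # Hypothetical CLI; flag names are inferred from the attribute accesses above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default="eval",
                        choices=["train", "eval", "predict", "debug"])
    parser.add_argument("--ckpt_loading", default="best",
                        choices=["latest", "best"])
    parser.add_argument("--data", default="test",
                        choices=["train", "valid", "test"])
    parser.add_argument("--dir_examples", default="./examples")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # settings must come from the project's own configuration/model setup.
    # settings = ...
    dispatch = {"train": do_train_and_valid, "eval": do_eval,
                "predict": do_predict, "debug": do_debug}
    # dispatch[args.mode](settings, args)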