save_path = save_f + '.' + str(epoch) + '.' + str(steps)
        log.info("testing {}, save to {}".format(test_f, save_path))
        qids, left_scores, right_scores = predict(
            exe,
            test_prog,
            test_pyreader,
            graph_vars,
            is_classify=args.is_classify,
            is_regression=args.is_regression,
        )

        save_dir = os.path.dirname(save_path)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        with open(save_path, 'w') as f:
            if len(qids) == 0:
                for s, p in zip(left_scores, right_scores):
                    f.write('{}\t{}\n'.format(s, p))
            else:
                for id, l, r in zip(qids, left_scores, right_scores):
                    f.write('{}\t{}\t{}\n'.format(id, l, r))


if __name__ == '__main__':
    # Script entry point. The order matters: the logger is prepared before
    # anything else runs, and CUDA settings are checked before main() builds
    # any programs.
    prepare_logger(log)
    print_arguments(args)
    check_cuda(args.use_cuda)
    main(args)
# 예제 #2 (Example #2)
# 0
def get_role_init_dict(args, suf):
    """Build everything needed to run ERNIE role-labeling prediction.

    Reads the label map, loads the ERNIE config, creates a
    ``RoleSequenceLabelReader``, and (when validation or testing is
    requested) builds a test program. The test program's pyreader is named
    ``'test_reader_role' + suf``. The function then runs the startup program,
    restores ``args.init_checkpoint`` and returns all of these objects in a
    dict.

    Args:
        args: parsed arguments object. Side effect: ``args.num_labels`` is
            overwritten with the number of entries read from
            ``args.label_map_config``.
        suf: string suffix appended to the test pyreader name. Presumably
            this keeps names distinct when several models are built in one
            process (TODO confirm).

    Returns:
        dict with keys ``'log'``, ``'args'``, ``'labels_map'``,
        ``'ernie_config'``, ``'place'``, ``'dev_count'``, ``'reader'``,
        ``'startup_prog'``, ``'test_prog'``, ``'test_pyreader'``,
        ``'graph_vars'``, ``'nccl2_num_trainers'``, ``'nccl2_trainer_id'``
        and ``'exe'``. ``'test_prog'``, ``'test_pyreader'`` and
        ``'graph_vars'`` are ``None`` when neither ``args.do_val`` nor
        ``args.do_test`` is set.

    Raises:
        ValueError: if none of ``do_train``/``do_val``/``do_test`` is set, or
            if validation/testing is requested without
            ``args.init_checkpoint``.
    """
    log = logging.getLogger(__name__)
    check_cuda(args.use_cuda)

    # Each label-map line is "<label>\t<integer id>".
    labels_map = {}
    for line in utils.read_by_lines(args.label_map_config):
        arr = line.split("\t")
        labels_map[arr[0]] = int(arr[1])
    args.num_labels = len(labels_map)

    print("=========ERNIE CONFIG============")
    ernie_config = ErnieConfig(args.ernie_config_path)
    print("=========ERNIE CONFIG============")

    if args.use_cuda:
        dev_list = fluid.cuda_places()
        # Only the first visible GPU is used as the execution place.
        place = dev_list[0]
        print("==============place==================", place)
        dev_count = len(dev_list)
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    print("==============place, dev_count==================", place, dev_count)

    reader = task_reader.RoleSequenceLabelReader(
        vocab_path=args.vocab_path,
        labels_map=labels_map,
        max_seq_len=args.max_seq_len,
        do_lower_case=args.do_lower_case,
        in_tokens=args.in_tokens,
        random_seed=args.random_seed,
        task_id=args.task_id)

    if not (args.do_train or args.do_val or args.do_test):
        raise ValueError("For args `do_train`, `do_val` and `do_test`, at "
                         "least one of them must be True.")

    startup_prog = fluid.Program()
    if args.random_seed is not None:
        startup_prog.random_seed = args.random_seed

    # Bound up front so the returned dict is valid even when only
    # `do_train` is set (previously this raised UnboundLocalError).
    test_prog = None
    test_pyreader = None
    graph_vars = None

    if args.do_val or args.do_test:
        test_prog = fluid.Program()
        with fluid.program_guard(test_prog, startup_prog):
            with fluid.unique_name.guard():
                # TODO: adjust pyreader_name again so it is different.
                test_pyreader, graph_vars = create_model(
                    args,
                    pyreader_name='test_reader_role' + suf,
                    ernie_config=ernie_config)

        test_prog = test_prog.clone(for_test=True)

    # Single-trainer setup; these values are only passed through to callers.
    nccl2_num_trainers = 1
    nccl2_trainer_id = 0

    exe = fluid.Executor(place)
    exe.run(startup_prog)

    if args.do_val or args.do_test:
        if not args.init_checkpoint:
            raise ValueError("args 'init_checkpoint' should be set if "
                             "only doing validation or testing!")
        init_checkpoint(exe,
                        args.init_checkpoint,
                        main_program=startup_prog,
                        use_fp16=args.use_fp16)

    return {
        'log': log,
        'args': args,
        'labels_map': labels_map,
        'ernie_config': ernie_config,
        'place': place,
        'dev_count': dev_count,
        'reader': reader,
        'startup_prog': startup_prog,
        'test_prog': test_prog,
        'test_pyreader': test_pyreader,
        'graph_vars': graph_vars,
        'nccl2_num_trainers': nccl2_num_trainers,
        'nccl2_trainer_id': nccl2_trainer_id,
        'exe': exe,
    }