def train(args):
    """Build and run the ERNIE training and/or prediction programs.

    Depending on ``args.do_train`` and ``args.do_predict`` this function
    builds the corresponding fluid programs, runs the shared startup program,
    optionally loads initial parameters, runs the training loop (periodic
    logging and checkpointing) and finally runs ``predict`` on every file
    matched by the ``args.predict_file`` glob patterns.

    Args:
        args: Parsed argument namespace. Attributes read here:
            ``ernie_config``, ``do_train``, ``do_predict``, ``use_cuda``,
            ``vocab_path``, ``do_lower_case``, ``max_seq_len``, ``in_tokens``,
            ``doc_stride``, ``max_query_length``, ``random_seed``,
            ``train_file``, ``batch_size``, ``version_2_with_negative``,
            ``epoch``, ``warmup_proportion``, ``learning_rate``,
            ``weight_decay``, ``lr_scheduler``, ``use_fp16``,
            ``use_dynamic_loss_scaling``, ``init_loss_scaling``,
            ``incr_every_n_steps``, ``decr_every_n_nan_or_inf``,
            ``incr_ratio``, ``decr_ratio``, ``init_checkpoint``,
            ``init_pretraining_params``, ``use_fast_executor``,
            ``num_iteration_per_drop_scope``, ``skip_steps``, ``verbose``,
            ``save_steps``, ``checkpoints`` and ``predict_file``.

    Raises:
        ValueError: If neither ``do_train`` nor ``do_predict`` is set, or if
            only prediction is requested and ``init_checkpoint`` is empty.
        AssertionError: If no file matches the ``predict_file`` patterns.
    """
    ernie_config = ErnieConfig(args.ernie_config)
    ernie_config.print_config()

    if not (args.do_train or args.do_predict):
        raise ValueError("For args `do_train` and `do_predict`, at "
                         "least one of them must be True.")

    # Pick the execution place and the number of devices used to split work.
    if args.use_cuda:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    exe = fluid.Executor(place)

    processor = DataProcessor(vocab_path=args.vocab_path,
                              do_lower_case=args.do_lower_case,
                              max_seq_length=args.max_seq_len,
                              in_tokens=args.in_tokens,
                              doc_stride=args.doc_stride,
                              max_query_length=args.max_query_length)

    # A single startup program is shared by the train and test programs.
    startup_prog = fluid.Program()
    if args.random_seed is not None:
        startup_prog.random_seed = args.random_seed

    if args.do_train:
        train_data_generator = processor.data_generator(
            data_path=args.train_file,
            batch_size=args.batch_size,
            phase='train',
            shuffle=True,
            dev_count=dev_count,
            version_2_with_negative=args.version_2_with_negative,
            epoch=args.epoch)

        num_train_examples = processor.get_num_examples(phase='train')
        # With `in_tokens` the batch size is first divided by max_seq_len
        # (presumably batch_size then counts tokens, not examples -- confirm).
        if args.in_tokens:
            max_train_steps = args.epoch * num_train_examples // (
                args.batch_size // args.max_seq_len) // dev_count
        else:
            max_train_steps = args.epoch * num_train_examples // (
                args.batch_size) // dev_count
        warmup_steps = int(max_train_steps * args.warmup_proportion)
        print("Device count: %d" % dev_count)
        print("Num train examples: %d" % num_train_examples)
        print("Max train steps: %d" % max_train_steps)
        print("Num warmup steps: %d" % warmup_steps)

        train_program = fluid.Program()
        with fluid.program_guard(train_program, startup_prog):
            with fluid.unique_name.guard():
                train_data_loader, loss, num_seqs = create_model(
                    ernie_config=ernie_config, is_training=True)

                scheduled_lr, loss_scaling = optimization(
                    loss=loss,
                    warmup_steps=warmup_steps,
                    num_train_steps=max_train_steps,
                    learning_rate=args.learning_rate,
                    train_program=train_program,
                    startup_prog=startup_prog,
                    weight_decay=args.weight_decay,
                    scheduler=args.lr_scheduler,
                    use_fp16=args.use_fp16,
                    use_dynamic_loss_scaling=args.use_dynamic_loss_scaling,
                    init_loss_scaling=args.init_loss_scaling,
                    incr_every_n_steps=args.incr_every_n_steps,
                    decr_every_n_nan_or_inf=args.decr_every_n_nan_or_inf,
                    incr_ratio=args.incr_ratio,
                    decr_ratio=args.decr_ratio)

    if args.do_predict:
        test_prog = fluid.Program()
        with fluid.program_guard(test_prog, startup_prog):
            with fluid.unique_name.guard():
                # NOTE(review): `num_seqs` is rebound here. When training is
                # also enabled, the training loop below fetches
                # `num_seqs.name` of this test-program variable; that relies
                # on both programs yielding the same variable name -- confirm.
                test_data_loader, unique_ids, start_logits, end_logits, num_seqs = create_model(
                    ernie_config=ernie_config, is_training=False)

        test_prog = test_prog.clone(for_test=True)

    exe.run(startup_prog)

    # Load initial parameters; `init_checkpoint` wins over
    # `init_pretraining_params` when both are given.
    if args.do_train:
        if args.init_checkpoint and args.init_pretraining_params:
            print(
                "WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
                "both are set! Only arg 'init_checkpoint' is made valid.")
        if args.init_checkpoint:
            init_checkpoint(exe,
                            args.init_checkpoint,
                            main_program=startup_prog,
                            use_fp16=args.use_fp16)
        elif args.init_pretraining_params:
            init_pretraining_params(exe,
                                    args.init_pretraining_params,
                                    main_program=startup_prog,
                                    use_fp16=args.use_fp16)
    elif args.do_predict:
        if not args.init_checkpoint:
            # Fixed: the two literals used to concatenate to "ifonly".
            raise ValueError("args 'init_checkpoint' should be set if "
                             "only doing prediction!")
        init_checkpoint(exe,
                        args.init_checkpoint,
                        main_program=startup_prog,
                        use_fp16=args.use_fp16)

    if args.do_train:
        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.use_experimental_executor = args.use_fast_executor
        exec_strategy.num_threads = dev_count
        exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope

        train_compiled_program = fluid.CompiledProgram(
            train_program).with_data_parallel(loss_name=loss.name,
                                              exec_strategy=exec_strategy)

        train_data_loader.set_batch_generator(train_data_generator, place)

        train_data_loader.start()
        steps = 0
        total_cost, total_num_seqs = [], []
        time_begin = time.time()
        # The loop ends only when the data loader raises EOFException.
        while True:
            try:
                steps += 1
                # Metrics are fetched (and logged) only every `skip_steps`.
                should_log = steps % args.skip_steps == 0
                if should_log:
                    fetch_list = [loss.name, scheduled_lr.name, num_seqs.name]
                    if args.use_fp16:
                        fetch_list.append(loss_scaling.name)
                else:
                    fetch_list = []

                outputs = exe.run(train_compiled_program,
                                  fetch_list=fetch_list)

                if should_log:
                    if args.use_fp16:
                        np_loss, np_lr, np_num_seqs, np_scaling = outputs
                    else:
                        np_loss, np_lr, np_num_seqs = outputs
                    total_cost.extend(np_loss * np_num_seqs)
                    total_num_seqs.extend(np_num_seqs)

                    if args.verbose:
                        verbose = "train data_loader queue size: %d, " % train_data_loader.queue.size(
                        )
                        verbose += "learning rate: %f " % np_lr[0]
                        if args.use_fp16:
                            verbose += ", loss scaling: %f" % np_scaling[0]
                        print(verbose)

                    time_end = time.time()
                    used_time = time_end - time_begin
                    current_example, epoch = processor.get_train_progress()

                    print("epoch: %d, progress: %d/%d, step: %d, loss: %f, "
                          "speed: %f steps/s" %
                          (epoch, current_example, num_train_examples, steps,
                           np.sum(total_cost) / np.sum(total_num_seqs),
                           args.skip_steps / used_time))
                    total_cost, total_num_seqs = [], []
                    time_begin = time.time()

                if steps % args.save_steps == 0 or steps == max_train_steps:
                    save_path = os.path.join(args.checkpoints,
                                             "step_" + str(steps))
                    fluid.io.save_persistables(exe, save_path, train_program)
            except fluid.core.EOFException:
                # `steps` was already incremented for the iteration that
                # raised, so the "_final" name carries that value.
                save_path = os.path.join(args.checkpoints,
                                         "step_" + str(steps) + "_final")
                fluid.io.save_persistables(exe, save_path, train_program)
                train_data_loader.reset()
                break

    if args.do_predict:
        # `predict_file` is iterated as a list of glob patterns.
        input_files = []
        for input_pattern in args.predict_file:
            input_files.extend(glob.glob(input_pattern))
        assert len(input_files) > 0, 'Can not find predict_file {}'.format(
            args.predict_file)
        for input_file in input_files:
            print('Run prediction on {}'.format(input_file))
            # Strip only a trailing ".json" extension. The previous pattern
            # '.json' was unescaped and unanchored, so it removed any
            # character followed by "json" anywhere in the file name.
            prefix = re.sub(r'\.json$', '', os.path.basename(input_file))

            test_data_loader.set_batch_generator(
                processor.data_generator(data_path=input_file,
                                         batch_size=args.batch_size,
                                         phase='predict',
                                         shuffle=False,
                                         dev_count=1,
                                         epoch=1), place)

            predict(exe,
                    test_prog,
                    test_data_loader, [
                        unique_ids.name, start_logits.name, end_logits.name,
                        num_seqs.name
                    ],
                    processor,
                    prefix=prefix)
# Example #2
# 0
def train(args):
    """Build and run the XLNet training and/or prediction programs.

    Depending on ``args.do_train`` and ``args.do_predict`` this function
    builds the corresponding fluid programs, runs the shared startup program,
    optionally loads initial parameters, trains for at most
    ``args.train_steps`` steps (periodic logging and checkpointing) and then
    runs ``predict`` on ``args.predict_file``.

    Args:
        args: Parsed argument namespace. Attributes read here:
            ``model_config_path``, ``do_train``, ``do_predict``, ``use_cuda``,
            ``spiece_model_file``, ``uncased``, ``max_seq_length``,
            ``doc_stride``, ``max_query_length``, ``random_seed``,
            ``train_file``, ``train_batch_size``, ``epoch``, ``train_steps``,
            ``warmup_steps``, ``learning_rate``, ``weight_decay``,
            ``lr_layer_decay_rate``, ``lr_scheduler``, ``init_checkpoint``,
            ``init_pretraining_params``, ``use_fast_executor``,
            ``num_iteration_per_drop_scope``, ``skip_steps``, ``verbose``,
            ``save_steps``, ``checkpoints``, ``predict_file`` and
            ``predict_batch_size``.

    Raises:
        ValueError: If neither ``do_train`` nor ``do_predict`` is set, or if
            only prediction is requested and ``init_checkpoint`` is empty.
    """
    if not (args.do_train or args.do_predict):
        raise ValueError("For args `do_train` and `do_predict`, at "
                         "least one of them must be True.")

    xlnet_config = XLNetConfig(args.model_config_path)
    xlnet_config.print_config()

    # Pick the execution place and the number of devices used to split work.
    if args.use_cuda:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    exe = fluid.Executor(place)

    processor = DataProcessor(spiece_model_file=args.spiece_model_file,
                              uncased=args.uncased,
                              max_seq_length=args.max_seq_length,
                              doc_stride=args.doc_stride,
                              max_query_length=args.max_query_length)

    # A single startup program is shared by the train and test programs.
    startup_prog = fluid.Program()
    if args.random_seed is not None:
        startup_prog.random_seed = args.random_seed

    if args.do_train:
        train_data_generator = processor.data_generator(
            data_path=args.train_file,
            batch_size=args.train_batch_size,
            phase='train',
            shuffle=True,
            dev_count=dev_count,
            epoch=args.epoch)

        num_train_examples = processor.get_num_examples(phase='train')
        print("Device count: %d" % dev_count)
        print("Max num of epoches: %d" % args.epoch)
        print("Num of train examples: %d" % num_train_examples)
        print("Num of train steps: %d" % args.train_steps)
        print("Num of warmup steps: %d" % args.warmup_steps)

        train_program = fluid.Program()
        with fluid.program_guard(train_program, startup_prog):
            with fluid.unique_name.guard():
                train_data_loader, loss = create_model(
                    xlnet_config=xlnet_config, is_training=True)

                scheduled_lr = optimization(
                    loss=loss,
                    warmup_steps=args.warmup_steps,
                    num_train_steps=args.train_steps,
                    learning_rate=args.learning_rate,
                    train_program=train_program,
                    startup_prog=startup_prog,
                    weight_decay=args.weight_decay,
                    lr_layer_decay_rate=args.lr_layer_decay_rate,
                    scheduler=args.lr_scheduler)

    if args.do_predict:
        test_prog = fluid.Program()
        with fluid.program_guard(test_prog, startup_prog):
            with fluid.unique_name.guard():
                test_data_loader, predictions = create_model(
                    xlnet_config=xlnet_config, is_training=False)

        test_prog = test_prog.clone(for_test=True)

    exe.run(startup_prog)

    # Load initial parameters; `init_checkpoint` wins over
    # `init_pretraining_params` when both are given.
    if args.do_train:
        if args.init_checkpoint and args.init_pretraining_params:
            print(
                "WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
                "both are set! Only arg 'init_checkpoint' is made valid.")
        if args.init_checkpoint:
            init_checkpoint(exe,
                            args.init_checkpoint,
                            main_program=startup_prog)
        elif args.init_pretraining_params:
            init_pretraining_params(exe,
                                    args.init_pretraining_params,
                                    main_program=startup_prog)

    elif args.do_predict:
        if not args.init_checkpoint:
            # Fixed: the two literals used to concatenate to "ifonly".
            raise ValueError("args 'init_checkpoint' should be set if "
                             "only doing prediction!")
        init_checkpoint(exe, args.init_checkpoint, main_program=startup_prog)

    if args.do_train:
        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.use_experimental_executor = args.use_fast_executor
        exec_strategy.num_threads = dev_count
        exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope

        build_strategy = fluid.BuildStrategy()
        # These two flags must be set in this model for correctness
        build_strategy.fuse_all_optimizer_ops = True
        build_strategy.enable_inplace = False
        train_exe = fluid.ParallelExecutor(use_cuda=args.use_cuda,
                                           loss_name=loss.name,
                                           exec_strategy=exec_strategy,
                                           build_strategy=build_strategy,
                                           main_program=train_program)

        train_data_loader.set_batch_generator(train_data_generator, place)

        train_data_loader.start()
        steps = 0
        total_cost = []
        time_begin = time.time()
        print("Begin to train model  ...")
        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        # Stops after `train_steps` steps, or earlier if the data loader
        # raises EOFException.
        while steps < args.train_steps:
            try:
                steps += 1
                # Metrics are fetched (and logged) only every `skip_steps`.
                should_log = steps % args.skip_steps == 0
                if should_log:
                    fetch_list = [loss.name, scheduled_lr.name]
                else:
                    fetch_list = []

                outputs = train_exe.run(fetch_list=fetch_list)

                if should_log:
                    np_loss, np_lr = outputs
                    total_cost.extend(np_loss)

                    if args.verbose:
                        verbose = "train data_loader queue size: %d, " % train_data_loader.queue.size(
                        )
                        verbose += "learning rate: %f " % np_lr[0]
                        print(verbose)

                    time_end = time.time()
                    used_time = time_end - time_begin
                    current_example, epoch = processor.get_train_progress()

                    print("epoch: %d, progress: %d/%d, step: %d, loss: %f, "
                          "speed: %f steps/s" %
                          (epoch, current_example, num_train_examples, steps,
                           np.mean(total_cost), args.skip_steps / used_time))
                    total_cost = []
                    time_begin = time.time()

                if steps % args.save_steps == 0 or steps == args.train_steps:
                    save_path = os.path.join(args.checkpoints,
                                             "step_" + str(steps))
                    fluid.io.save_persistables(exe, save_path, train_program)
            except fluid.core.EOFException:
                # `steps` was already incremented for the iteration that
                # raised, so the "_final" name carries that value.
                save_path = os.path.join(args.checkpoints,
                                         "step_" + str(steps) + "_final")
                fluid.io.save_persistables(exe, save_path, train_program)
                train_data_loader.reset()
                break
        print("Finish model training ...")
        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))

    if args.do_predict:
        print("Begin to do prediction  ...")
        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        test_data_loader.set_batch_generator(
            processor.data_generator(data_path=args.predict_file,
                                     batch_size=args.predict_batch_size,
                                     phase='predict',
                                     shuffle=False,
                                     dev_count=1,
                                     epoch=1), place)

        predict(exe,
                test_prog,
                test_data_loader, [
                    predictions['unique_ids'].name,
                    predictions['start_top_log_probs'].name,
                    predictions['start_top_index'].name,
                    predictions['end_top_log_probs'].name,
                    predictions['end_top_index'].name,
                    predictions['cls_logits'].name
                ],
                processor,
                name='')

        print("Finish prediction ...")
        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
# Example #3
# 0
def train(args):
    """Build and run the BERT training and/or prediction programs.

    Depending on ``args.do_train`` and ``args.do_predict`` this function
    builds the corresponding fluid programs, runs the shared startup program,
    optionally loads initial parameters, trains for at most the computed
    ``max_train_steps`` (periodic logging and checkpointing) and then runs
    ``predict`` on ``args.predict_file``.

    Args:
        args: Parsed argument namespace. Attributes read here:
            ``bert_config_path``, ``do_train``, ``do_predict``, ``use_cuda``,
            ``vocab_path``, ``do_lower_case``, ``max_seq_len``, ``in_tokens``,
            ``doc_stride``, ``max_query_length``, ``random_seed``,
            ``train_file``, ``batch_size``, ``version_2_with_negative``,
            ``epoch``, ``warmup_proportion``, ``learning_rate``,
            ``weight_decay``, ``lr_scheduler``, ``use_fp16``,
            ``loss_scaling``, ``verbose``, ``init_checkpoint``,
            ``init_pretraining_params``, ``use_fast_executor``,
            ``num_iteration_per_drop_scope``, ``skip_steps``, ``save_steps``,
            ``checkpoints`` and ``predict_file``.

    Raises:
        ValueError: If neither ``do_train`` nor ``do_predict`` is set, or if
            only prediction is requested and ``init_checkpoint`` is empty.
    """
    bert_config = BertConfig(args.bert_config_path)
    bert_config.print_config()

    if not (args.do_train or args.do_predict):
        raise ValueError("For args `do_train` and `do_predict`, at "
                         "least one of them must be True.")

    # Pick the execution place and the number of devices used to split work.
    if args.use_cuda:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    exe = fluid.Executor(place)

    processor = DataProcessor(vocab_path=args.vocab_path,
                              do_lower_case=args.do_lower_case,
                              max_seq_length=args.max_seq_len,
                              in_tokens=args.in_tokens,
                              doc_stride=args.doc_stride,
                              max_query_length=args.max_query_length)

    # A single startup program is shared by the train and test programs.
    startup_prog = fluid.Program()
    if args.random_seed is not None:
        startup_prog.random_seed = args.random_seed

    if args.do_train:
        # NOTE(review): training data is requested with shuffle=False here --
        # confirm this is intentional.
        train_data_generator = processor.data_generator(
            data_path=args.train_file,
            batch_size=args.batch_size,
            phase='train',
            shuffle=False,
            dev_count=dev_count,
            version_2_with_negative=args.version_2_with_negative,
            epoch=args.epoch)

        num_train_examples = processor.get_num_examples(phase='train')
        # With `in_tokens` the batch size is first divided by max_seq_len
        # (presumably batch_size then counts tokens, not examples -- confirm).
        if args.in_tokens:
            max_train_steps = args.epoch * num_train_examples // (
                args.batch_size // args.max_seq_len) // dev_count
        else:
            max_train_steps = args.epoch * num_train_examples // (
                args.batch_size) // dev_count
        warmup_steps = int(max_train_steps * args.warmup_proportion)
        print("Device count: %d" % dev_count)
        print("Num train examples: %d" % num_train_examples)
        print("Max train steps: %d" % max_train_steps)
        print("Num warmup steps: %d" % warmup_steps)

        train_program = fluid.Program()
        with fluid.program_guard(train_program, startup_prog):
            with fluid.unique_name.guard():
                train_pyreader, loss, num_seqs = create_model(
                    pyreader_name='train_reader',
                    bert_config=bert_config,
                    is_training=True)

                scheduled_lr = optimization(loss=loss,
                                            warmup_steps=warmup_steps,
                                            num_train_steps=max_train_steps,
                                            learning_rate=args.learning_rate,
                                            train_program=train_program,
                                            startup_prog=startup_prog,
                                            weight_decay=args.weight_decay,
                                            scheduler=args.lr_scheduler,
                                            use_fp16=args.use_fp16,
                                            loss_scaling=args.loss_scaling)

                # Keep the fetched variables out of memory optimization.
                fluid.memory_optimize(train_program,
                                      skip_opt_set=[loss.name, num_seqs.name])

        if args.verbose:
            if args.in_tokens:
                lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
                    program=train_program,
                    batch_size=args.batch_size // args.max_seq_len)
            else:
                lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
                    program=train_program, batch_size=args.batch_size)
            print("Theoretical memory usage in training:  %.3f - %.3f %s" %
                  (lower_mem, upper_mem, unit))

    if args.do_predict:
        test_prog = fluid.Program()
        with fluid.program_guard(test_prog, startup_prog):
            with fluid.unique_name.guard():
                # NOTE(review): `num_seqs` is rebound here. When training is
                # also enabled, the training loop below fetches
                # `num_seqs.name` of this test-program variable; that relies
                # on both programs yielding the same variable name -- confirm.
                test_pyreader, unique_ids, start_logits, end_logits, num_seqs = create_model(
                    pyreader_name='test_reader',
                    bert_config=bert_config,
                    is_training=False)

                fluid.memory_optimize(test_prog,
                                      skip_opt_set=[
                                          unique_ids.name, start_logits.name,
                                          end_logits.name, num_seqs.name
                                      ])

        test_prog = test_prog.clone(for_test=True)

    exe.run(startup_prog)

    # Load initial parameters; `init_checkpoint` wins over
    # `init_pretraining_params` when both are given.
    if args.do_train:
        if args.init_checkpoint and args.init_pretraining_params:
            print(
                "WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
                "both are set! Only arg 'init_checkpoint' is made valid.")
        if args.init_checkpoint:
            init_checkpoint(exe,
                            args.init_checkpoint,
                            main_program=startup_prog,
                            use_fp16=args.use_fp16)
        elif args.init_pretraining_params:
            init_pretraining_params(exe,
                                    args.init_pretraining_params,
                                    main_program=startup_prog,
                                    use_fp16=args.use_fp16)
    elif args.do_predict:
        if not args.init_checkpoint:
            # Fixed: the two literals used to concatenate to "ifonly".
            raise ValueError("args 'init_checkpoint' should be set if "
                             "only doing prediction!")
        init_checkpoint(exe,
                        args.init_checkpoint,
                        main_program=startup_prog,
                        use_fp16=args.use_fp16)

    if args.do_train:
        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.use_experimental_executor = args.use_fast_executor
        exec_strategy.num_threads = dev_count
        exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope

        train_exe = fluid.ParallelExecutor(use_cuda=args.use_cuda,
                                           loss_name=loss.name,
                                           exec_strategy=exec_strategy,
                                           main_program=train_program)

        train_pyreader.decorate_tensor_provider(train_data_generator)

        train_pyreader.start()
        steps = 0
        total_cost, total_num_seqs = [], []
        time_begin = time.time()
        # Stops after `max_train_steps` steps, or earlier if the reader
        # raises EOFException.
        while steps < max_train_steps:
            try:
                steps += 1
                # Metrics are fetched (and logged) only every `skip_steps`.
                should_log = steps % args.skip_steps == 0
                # The scheduled learning rate is fetched only when warmup
                # is in use; otherwise args.learning_rate is reported.
                fetch_lr = warmup_steps > 0
                if should_log:
                    if fetch_lr:
                        fetch_list = [
                            loss.name, scheduled_lr.name, num_seqs.name
                        ]
                    else:
                        fetch_list = [loss.name, num_seqs.name]
                else:
                    fetch_list = []

                outputs = train_exe.run(fetch_list=fetch_list)

                if should_log:
                    if fetch_lr:
                        np_loss, np_lr, np_num_seqs = outputs
                    else:
                        np_loss, np_num_seqs = outputs
                    total_cost.extend(np_loss * np_num_seqs)
                    total_num_seqs.extend(np_num_seqs)

                    if args.verbose:
                        verbose = "train pyreader queue size: %d, " % train_pyreader.queue.size(
                        )
                        verbose += "learning rate: %f" % (np_lr[0] if
                                                          fetch_lr else
                                                          args.learning_rate)
                        print(verbose)

                    time_end = time.time()
                    used_time = time_end - time_begin
                    current_example, epoch = processor.get_train_progress()

                    print("epoch: %d, progress: %d/%d, step: %d, loss: %f, "
                          "speed: %f steps/s" %
                          (epoch, current_example, num_train_examples, steps,
                           np.sum(total_cost) / np.sum(total_num_seqs),
                           args.skip_steps / used_time))
                    total_cost, total_num_seqs = [], []
                    time_begin = time.time()

                # Also save on the last step: the loop exits at
                # max_train_steps without raising EOFException, so without
                # this the final weights were written only when
                # max_train_steps happened to be a multiple of save_steps.
                if steps % args.save_steps == 0 or steps == max_train_steps:
                    save_path = os.path.join(args.checkpoints,
                                             "step_" + str(steps))
                    fluid.io.save_persistables(exe, save_path, train_program)
            except fluid.core.EOFException:
                # `steps` was already incremented for the iteration that
                # raised, so the "_final" name carries that value.
                save_path = os.path.join(args.checkpoints,
                                         "step_" + str(steps) + "_final")
                fluid.io.save_persistables(exe, save_path, train_program)
                train_pyreader.reset()
                break

    if args.do_predict:
        test_pyreader.decorate_tensor_provider(
            processor.data_generator(data_path=args.predict_file,
                                     batch_size=args.batch_size,
                                     phase='predict',
                                     shuffle=False,
                                     dev_count=1,
                                     epoch=1))

        predict(exe, test_prog, test_pyreader, [
            unique_ids.name, start_logits.name, end_logits.name, num_seqs.name
        ], processor)
def train(args):
    """Train and/or evaluate the concept-augmented BERT span-extraction model.

    Depending on the flags in ``args`` this function:

    * builds a training program (``args.do_train``) and/or a test program
      (``args.do_predict`` or ``args.do_val``),
    * runs the startup program and restores parameters from
      ``args.init_checkpoint`` or ``args.init_pretraining_params``,
    * runs the training loop for at most ``max_train_steps`` steps, logging
      every ``args.skip_steps`` steps, saving persistables every
      ``args.save_steps`` steps (and on the last step), and validating every
      ``args.validation_steps`` steps (and on the last step) when
      ``args.do_val`` is set,
    * finally runs prediction on ``args.predict_file`` when
      ``args.do_predict`` is set, under EMA-applied parameters when
      ``args.use_ema`` is set.

    NOTE(review): validation is only triggered from inside the training loop,
    so ``do_val`` without ``do_train`` builds the test program and loads the
    checkpoint but never evaluates -- confirm this is intended.

    Args:
        args: Parsed command-line namespace. Attributes read here include
            ``bert_config_path``, ``concept_embedding_path``, ``vocab_path``,
            ``train_file``, ``predict_file``, ``checkpoints``, the
            ``do_train`` / ``do_predict`` / ``do_val`` switches and the
            optimisation / execution settings used below.

    Raises:
        ValueError: If none of ``do_train``, ``do_predict`` and ``do_val`` is
            set, or if training is disabled and ``init_checkpoint`` is empty.
    """
    bert_config = BertConfig(args.bert_config_path)
    bert_config.print_config()

    if not (args.do_train or args.do_predict or args.do_val):
        raise ValueError("For args `do_train`, `do_predict` and `do_val`, at "
                         "least one of them must be True.")

    if args.use_cuda:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    exe = fluid.Executor(place)

    id2concept, concept2id, concept_embedding_mat = read_concept_embedding(
        args.concept_embedding_path)

    processor = DataProcessor(vocab_path=args.vocab_path,
                              do_lower_case=args.do_lower_case,
                              max_seq_length=args.max_seq_len,
                              in_tokens=args.in_tokens,
                              doc_stride=args.doc_stride,
                              max_query_length=args.max_query_length)

    # NOTE: the seed is applied to the startup program and to the Python /
    # NumPy generators only; the train and test programs keep their default.
    startup_prog = fluid.Program()
    if args.random_seed is not None:
        startup_prog.random_seed = args.random_seed
        random.seed(args.random_seed)
        np.random.seed(args.random_seed)

    # Explicit sentinel so the test-program branch can tell whether the
    # training branch already created the EMA helper.
    ema = None

    if args.do_train:
        train_concept_settings = {
            'tokenization_path':
            '../retrieve_concepts/tokenization_squad/tokens/train.tokenization.{}.data'
            .format('uncased' if args.do_lower_case else 'cased'),
            'concept2id':
            concept2id,
            'use_wordnet':
            args.use_wordnet,
            'retrieved_synset_path':
            args.retrieved_synset_path,
            'use_nell':
            args.use_nell,
            'retrieved_nell_concept_path':
            args.train_retrieved_nell_concept_path,
        }
        train_data_generator = processor.data_generator(
            data_path=args.train_file,
            batch_size=args.batch_size,
            phase='train',
            shuffle=True,
            dev_count=dev_count,
            version_2_with_negative=args.version_2_with_negative,
            epoch=args.epoch,
            **train_concept_settings)

        num_train_examples = processor.get_num_examples(phase='train')
        if args.in_tokens:
            # With in_tokens, batch_size is divided by max_seq_len to get the
            # per-step divisor used for the step count.
            max_train_steps = args.epoch * num_train_examples // (
                args.batch_size // args.max_seq_len) // dev_count
        else:
            max_train_steps = args.epoch * num_train_examples // (
                args.batch_size) // dev_count
        warmup_steps = int(max_train_steps * args.warmup_proportion)
        logger.info("Device count: %d" % dev_count)
        logger.info("Num train examples: %d" % num_train_examples)
        logger.info("Max train steps: %d" % max_train_steps)
        logger.info("Num warmup steps: %d" % warmup_steps)

        train_program = fluid.Program()
        with fluid.program_guard(train_program, startup_prog):
            with fluid.unique_name.guard():
                train_pyreader, loss, num_seqs = create_model(
                    pyreader_name='train_reader',
                    bert_config=bert_config,
                    max_concept_length=processor.train_max_concept_length,
                    concept_embedding_mat=concept_embedding_mat,
                    is_training=True,
                    freeze=args.freeze)

                scheduled_lr = optimization(loss=loss,
                                            warmup_steps=warmup_steps,
                                            num_train_steps=max_train_steps,
                                            learning_rate=args.learning_rate,
                                            train_program=train_program,
                                            startup_prog=startup_prog,
                                            weight_decay=args.weight_decay,
                                            scheduler=args.lr_scheduler,
                                            use_fp16=args.use_fp16,
                                            loss_scaling=args.loss_scaling)

                if args.use_ema:
                    ema = fluid.optimizer.ExponentialMovingAverage(
                        args.ema_decay)
                    ema.update()

                # Keep the fetched variables out of memory optimisation so
                # they can still be read after each step.
                fluid.memory_optimize(train_program,
                                      skip_opt_set=[loss.name, num_seqs.name])

        if args.verbose:
            if args.in_tokens:
                lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
                    program=train_program,
                    batch_size=args.batch_size // args.max_seq_len)
            else:
                lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
                    program=train_program, batch_size=args.batch_size)
            logger.info(
                "Theoretical memory usage in training:  %.3f - %.3f %s" %
                (lower_mem, upper_mem, unit))

    if args.do_predict or args.do_val:
        eval_concept_settings = {
            'tokenization_path':
            '../retrieve_concepts/tokenization_squad/tokens/dev.tokenization.{}.data'
            .format('uncased' if args.do_lower_case else 'cased'),
            'concept2id':
            concept2id,
            'use_wordnet':
            args.use_wordnet,
            'retrieved_synset_path':
            args.retrieved_synset_path,
            'use_nell':
            args.use_nell,
            'retrieved_nell_concept_path':
            args.dev_retrieved_nell_concept_path,
        }
        eval_data_generator = processor.data_generator(
            data_path=args.predict_file,
            batch_size=args.batch_size,
            phase='predict',
            shuffle=False,
            dev_count=1,
            epoch=1,
            **eval_concept_settings)

        test_prog = fluid.Program()
        with fluid.program_guard(test_prog, startup_prog):
            with fluid.unique_name.guard():
                test_pyreader, unique_ids, start_logits, end_logits, num_seqs = create_model(
                    pyreader_name='test_reader',
                    bert_config=bert_config,
                    max_concept_length=processor.predict_max_concept_length,
                    concept_embedding_mat=concept_embedding_mat,
                    is_training=False)

                # Only create an EMA helper here when the training branch
                # did not already create one.
                if args.use_ema and ema is None:
                    ema = fluid.optimizer.ExponentialMovingAverage(
                        args.ema_decay)

                fluid.memory_optimize(test_prog,
                                      skip_opt_set=[
                                          unique_ids.name, start_logits.name,
                                          end_logits.name, num_seqs.name
                                      ])

        test_prog = test_prog.clone(for_test=True)

    exe.run(startup_prog)

    if args.do_train:
        logger.info('load pretrained concept embedding')
        fluid.global_scope().find_var('concept_emb_mat').get_tensor().set(
            concept_embedding_mat, place)

        if args.init_checkpoint and args.init_pretraining_params:
            logger.warning(
                "args 'init_checkpoint' and 'init_pretraining_params' "
                "both are set! Only arg 'init_checkpoint' is made valid.")
        if args.init_checkpoint:
            init_checkpoint(exe,
                            args.init_checkpoint,
                            main_program=startup_prog,
                            use_fp16=args.use_fp16)
        elif args.init_pretraining_params:
            init_pretraining_params(exe,
                                    args.init_pretraining_params,
                                    main_program=startup_prog,
                                    use_fp16=args.use_fp16)
    elif args.do_predict or args.do_val:
        # Without training there is nothing to evaluate unless parameters
        # are restored from a checkpoint.
        if not args.init_checkpoint:
            raise ValueError("args 'init_checkpoint' should be set if "
                             "only doing prediction!")
        init_checkpoint(exe,
                        args.init_checkpoint,
                        main_program=startup_prog,
                        use_fp16=args.use_fp16)

    if args.do_train:
        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.use_experimental_executor = args.use_fast_executor
        exec_strategy.num_threads = dev_count
        exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope

        train_exe = fluid.ParallelExecutor(use_cuda=args.use_cuda,
                                           loss_name=loss.name,
                                           exec_strategy=exec_strategy,
                                           main_program=train_program)

        train_pyreader.decorate_tensor_provider(train_data_generator)

        train_pyreader.start()
        steps = 0
        total_cost, total_num_seqs = [], []
        time_begin = time.time()
        while steps < max_train_steps:
            try:
                steps += 1
                # Fetch metrics only on logging steps; the learning rate is
                # fetched only when warmup is in use.
                if steps % args.skip_steps == 0:
                    if warmup_steps <= 0:
                        fetch_list = [loss.name, num_seqs.name]
                    else:
                        fetch_list = [
                            loss.name, scheduled_lr.name, num_seqs.name
                        ]
                else:
                    fetch_list = []

                outputs = train_exe.run(fetch_list=fetch_list)

                if steps % args.skip_steps == 0:
                    if warmup_steps <= 0:
                        np_loss, np_num_seqs = outputs
                    else:
                        np_loss, np_lr, np_num_seqs = outputs
                    total_cost.extend(np_loss * np_num_seqs)
                    total_num_seqs.extend(np_num_seqs)

                    if args.verbose:
                        verbose = "train pyreader queue size: %d, " % train_pyreader.queue.size(
                        )
                        verbose += "learning rate: %f" % (np_lr[0] if
                                                          warmup_steps > 0 else
                                                          args.learning_rate)
                        logger.info(verbose)

                    time_end = time.time()
                    used_time = time_end - time_begin
                    current_example, epoch = processor.get_train_progress()

                    logger.info(
                        "epoch: %d, progress: %d/%d, step: %d, loss: %f, "
                        "speed: %f steps/s" %
                        (epoch, current_example, num_train_examples, steps,
                         np.sum(total_cost) / np.sum(total_num_seqs),
                         args.skip_steps / used_time))
                    total_cost, total_num_seqs = [], []
                    time_begin = time.time()

                if steps % args.save_steps == 0 or steps == max_train_steps:
                    save_path = os.path.join(args.checkpoints,
                                             "step_" + str(steps))
                    fluid.io.save_persistables(exe, save_path, train_program)

                if steps % args.validation_steps == 0 or steps == max_train_steps:
                    if args.do_val:
                        # A fresh generator is built for every validation
                        # pass over the prediction file.
                        test_pyreader.decorate_tensor_provider(
                            processor.data_generator(
                                data_path=args.predict_file,
                                batch_size=args.batch_size,
                                phase='predict',
                                shuffle=False,
                                dev_count=1,
                                epoch=1,
                                **eval_concept_settings))
                        val_performance = predict(
                            exe, test_prog, test_pyreader, [
                                unique_ids.name, start_logits.name,
                                end_logits.name, num_seqs.name
                            ], processor, eval_concept_settings,
                            'validate_result_step_{}.json'.format(steps))
                        logger.info(
                            "Validation performance after step {}:\n* Exact_match: {}\n* F1: {}"
                            .format(steps, val_performance['exact_match'],
                                    val_performance['f1']))

            except fluid.core.EOFException:
                # The reader ran out of data before max_train_steps: save a
                # final checkpoint, reset the reader and stop.
                save_path = os.path.join(args.checkpoints,
                                         "step_" + str(steps) + "_final")
                fluid.io.save_persistables(exe, save_path, train_program)
                train_pyreader.reset()
                break

    if args.do_predict:
        test_pyreader.decorate_tensor_provider(eval_data_generator)

        if args.use_ema:
            # Run prediction inside the EMA apply context.
            with ema.apply(exe):
                eval_performance = predict(exe, test_prog, test_pyreader, [
                    unique_ids.name, start_logits.name, end_logits.name,
                    num_seqs.name
                ], processor, eval_concept_settings)
        else:
            eval_performance = predict(exe, test_prog, test_pyreader, [
                unique_ids.name, start_logits.name, end_logits.name,
                num_seqs.name
            ], processor, eval_concept_settings)

        logger.info("Eval performance:\n* Exact_match: {}\n* F1: {}".format(
            eval_performance['exact_match'], eval_performance['f1']))