Ejemplo n.º 1
0
 def __init__(self, config, num_classes, output_attentions=False):
     """Assemble the image-input BERT classifier.

     Args:
         config: dict of BERT hyper-parameters, passed to
             ``BertConfig.from_dict``; it must also provide the keys
             ``'num_channels_in'`` and ``'hidden_size'``.
         num_classes: output width of the final linear classifier.
         output_attentions: stored on the instance and forwarded to
             ``BertEncoder``.
     """
     super(BERTImage, self).__init__()
     self.output_attentions = output_attentions
     encoder_cfg = BertConfig.from_dict(config)
     in_channels = config['num_channels_in']
     hidden = config['hidden_size']
     self.hidden_size = hidden
     # Project the incoming per-position features up to the BERT width.
     # (Layer construction order is deliberate: each nn.Linear consumes RNG
     # state for its initialisation.)
     self.features_upscale = nn.Linear(in_channels, hidden)
     # Reuse the stock BERT encoder stack on top of the projected features.
     self.encoder = BertEncoder(encoder_cfg, output_attentions=output_attentions)
     # Registered as a buffer so it travels with the module's state; how it
     # is used is not visible here.
     self.register_buffer('attention_mask', torch.tensor(1.0))
     self.classifier = nn.Linear(hidden, num_classes)
Ejemplo n.º 2
0
    def __init__(self, opt):
        """Wrap a pre-trained BERT and build the text-structure head on top.

        Args:
            opt: option namespace. Always read: ``bert_config_file``,
                ``init_checkpoint``, ``txt_stru`` and ``final_dims``.
                ``'rnn'`` additionally reads ``embed_size``, ``num_layers``
                and ``bi_gru``; ``'trans'`` reads ``img_trans_cfg``.

        ``opt.txt_stru`` selects which head modules are created:
            'pooling' -- dropout + Linear(hidden_size -> final_dims)
            'cnn'     -- Conv2d filters of height 1/2/3 with 512 channels each,
                         dropout + Linear(3 * 512 -> final_dims)
            'rnn'     -- GRU(hidden_size -> embed_size), dropout +
                         Linear(embed_size -> final_dims)
            'trans'   -- one extra BERT layer built from ``img_trans_cfg``,
                         dropout + Linear(layer hidden size -> final_dims)
        No head is created for any other value.
        """
        super(BertMapping, self).__init__()
        bert_config = BertConfig.from_json_file(opt.bert_config_file)
        self.bert = BertModel(bert_config)
        self.bert.load_state_dict(
            torch.load(opt.init_checkpoint, map_location='cpu'))
        # Presumably disables gradient updates for the loaded BERT weights
        # (helper defined elsewhere) -- confirm.
        freeze_layers(self.bert)
        self.txt_stru = opt.txt_stru

        if opt.txt_stru == 'pooling':
            self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)
            self.mapping = nn.Linear(bert_config.hidden_size, opt.final_dims)
        elif opt.txt_stru == 'cnn':
            Ks = [1, 2, 3]
            in_channel = 1
            out_channel = 512
            embedding_dim = bert_config.hidden_size
            # Each kernel spans K tokens and the full embedding width.
            self.convs1 = nn.ModuleList([
                nn.Conv2d(in_channel, out_channel, (K, embedding_dim))
                for K in Ks
            ])
            self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)
            # One out_channel-wide feature per kernel size is concatenated.
            self.mapping = nn.Linear(len(Ks) * out_channel, opt.final_dims)
        elif opt.txt_stru == 'rnn':
            embedding_dim = bert_config.hidden_size
            self.bi_gru = opt.bi_gru
            self.rnn = nn.GRU(embedding_dim,
                              opt.embed_size,
                              opt.num_layers,
                              batch_first=True,
                              bidirectional=opt.bi_gru)
            self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)
            # NOTE(review): with bi_gru the GRU emits 2 * embed_size features;
            # this assumes forward() merges the two directions -- confirm.
            self.mapping = nn.Linear(opt.embed_size, opt.final_dims)
        elif opt.txt_stru == 'trans':
            # The extra layer has its own config, which also supplies this
            # head's dropout rate and width.
            trans_config = BertConfig.from_json_file(opt.img_trans_cfg)
            # NOTE(review): `bert` must be a module imported elsewhere in the
            # file (it is not self.bert) -- confirm.
            self.layer = bert.BERTLayer(trans_config)
            self.dropout = nn.Dropout(trans_config.hidden_dropout_prob)
            # Size the projection from the config instead of hard-coding 768.
            self.mapping = nn.Linear(trans_config.hidden_size, opt.final_dims)
Ejemplo n.º 3
0
def main(args):
    """Run prediction for a dialogue-understanding task and print the results.

    Builds the prediction program for ``args.task_name`` and loads weights
    from ``args.init_checkpoint``. It then runs the test split once and prints
    one line per example (or per tag sequence for non-``in_tokens`` tasks).
    If ``args.save_inference_model_path`` is set, the program is also exported
    with ``fluid.io.save_inference_model``.

    Raises:
        ValueError: if ``args.init_checkpoint`` is not set.
    """
    if not args.init_checkpoint:
        # Fail fast: nothing below is useful without trained weights.
        raise ValueError(
            "args 'init_checkpoint' should be set for prediction!")

    bert_config = BertConfig(args.bert_config_path)
    bert_config.print_config()

    task_name = args.task_name.lower()
    paradigm_inst = define_paradigm.Paradigm(task_name)
    pred_inst = define_predict_pack.DefinePredict()
    pred_func = getattr(pred_inst, pred_inst.task_map[task_name])

    processors = {
        'udc': reader.UDCProcessor,
        'swda': reader.SWDAProcessor,
        'mrda': reader.MRDAProcessor,
        'atis_slot': reader.ATISSlotProcessor,
        'atis_intent': reader.ATISIntentProcessor,
        'dstc2': reader.DSTC2Processor,
        'dstc2_asr': reader.DSTC2Processor,
    }

    # Whether each task's batch_size is counted in tokens (True) or examples.
    in_tokens = {
        'udc': True,
        'swda': True,
        'mrda': True,
        'atis_slot': False,
        'atis_intent': True,
        'dstc2': True,
        'dstc2_asr': True
    }

    processor = processors[task_name](data_dir=args.data_dir,
                                      vocab_path=args.vocab_path,
                                      max_seq_len=args.max_seq_len,
                                      do_lower_case=args.do_lower_case,
                                      in_tokens=in_tokens[task_name],
                                      task_name=task_name,
                                      random_seed=args.random_seed)
    num_labels = len(processor.get_labels())

    predict_prog = fluid.Program()
    predict_startup = fluid.Program()
    with fluid.program_guard(predict_prog, predict_startup):
        with fluid.unique_name.guard():
            pred_results = create_model(args,
                                        pyreader_name='predict_reader',
                                        bert_config=bert_config,
                                        num_labels=num_labels,
                                        paradigm_inst=paradigm_inst,
                                        is_prediction=True)
            predict_pyreader = pred_results.get('pyreader', None)
            probs = pred_results.get('probs', None)
            feed_target_names = pred_results.get('feed_targets_name', None)

    predict_prog = predict_prog.clone(for_test=True)

    # Single source of truth for the device (previously computed twice).
    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(predict_startup)

    init_pretraining_params(exe, args.init_checkpoint, predict_prog)

    predict_exe = fluid.ParallelExecutor(use_cuda=args.use_cuda,
                                         main_program=predict_prog)

    test_data_generator = processor.data_generator(batch_size=args.batch_size,
                                                   phase='test',
                                                   epoch=1,
                                                   shuffle=False)
    predict_pyreader.decorate_tensor_provider(test_data_generator)

    predict_pyreader.start()
    all_results = []
    while True:
        try:
            results = predict_exe.run(fetch_list=[probs.name])
            all_results.extend(results[0])
        except fluid.core.EOFException:
            # The reader signals exhaustion of the single test epoch this way.
            predict_pyreader.reset()
            break

    np.set_printoptions(precision=4, suppress=True)
    print("-------------- prediction results --------------")
    print("example_id\t" + '  '.join(processor.get_labels()))
    if in_tokens[task_name]:
        for index, result in enumerate(all_results):
            tags = pred_func(result)
            print("%s\t%s" % (index, tags))
    else:
        # Sequence-labelling tasks decode all results at once.
        tags = pred_func(all_results, args.max_seq_len)
        for index, tag in enumerate(tags):
            print("%s\t%s" % (index, tag))

    if args.save_inference_model_path:
        # normpath strips a trailing separator so "ckpt/" still yields "ckpt"
        # (os.path.split would return an empty tail).
        ckpt_dir = os.path.basename(os.path.normpath(args.init_checkpoint))
        dir_name = ckpt_dir + '_inference_model'
        model_path = os.path.join(args.save_inference_model_path, dir_name)
        fluid.io.save_inference_model(model_path,
                                      feed_target_names, [probs],
                                      exe,
                                      main_program=predict_prog)
Ejemplo n.º 4
0
def main(args):
    """Train and/or evaluate a dialogue-understanding model with PaddlePaddle.

    Depending on ``args.do_train``, ``args.do_val`` and ``args.do_test``, this
    builds a training program, an evaluation program, or both, then:

    * trains for ``args.epoch`` epochs:
      - logs every ``args.skip_steps`` steps;
      - saves persistables every ``args.save_steps`` steps and when the data
        is exhausted;
      - evaluates every ``args.validation_steps`` steps;
    * prints continuous-evaluation KPIs when ``args.enable_ce`` is set;
    * runs a final evaluation on the dev and/or test split.

    Raises:
        ValueError: if none of do_train/do_val/do_test is set, or if only
            evaluation is requested without ``args.init_checkpoint``.
    """
    bert_config = BertConfig(args.bert_config_path)
    bert_config.print_config()

    if args.use_cuda:
        # Use the GPU selected by the launcher via FLAGS_selected_gpus (default 0).
        place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    exe = fluid.Executor(place)

    task_name = args.task_name.lower()
    paradigm_inst = define_paradigm.Paradigm(task_name)

    processors = {
        'udc': reader.UDCProcessor,
        'swda': reader.SWDAProcessor,
        'mrda': reader.MRDAProcessor,
        'atis_slot': reader.ATISSlotProcessor,
        'atis_intent': reader.ATISIntentProcessor,
        'dstc2': reader.DSTC2Processor,
        'dstc2_asr': reader.DSTC2Processor,
    }
    # Whether each task's batch_size is counted in tokens (True) or examples.
    in_tokens = {
        'udc': True,
        'swda': True,
        'mrda': True,
        'atis_slot': False,
        'atis_intent': True,
        'dstc2': True,
        'dstc2_asr': True,
    }

    processor = processors[task_name](data_dir=args.data_dir,
                                      vocab_path=args.vocab_path,
                                      max_seq_len=args.max_seq_len,
                                      do_lower_case=args.do_lower_case,
                                      in_tokens=in_tokens[task_name],
                                      task_name=task_name,
                                      random_seed=args.random_seed)

    num_labels = len(processor.get_labels())

    if not (args.do_train or args.do_val or args.do_test):
        raise ValueError("For args `do_train`, `do_val` and `do_test`, at "
                         "least one of them must be True.")

    startup_prog = fluid.Program()
    if args.random_seed is not None:
        startup_prog.random_seed = args.random_seed

    if args.do_train:
        train_data_generator = processor.data_generator(
            batch_size=args.batch_size,
            phase='train',
            epoch=args.epoch,
            shuffle=True)
        num_train_examples = processor.get_num_examples(phase='train')

        if in_tokens[task_name]:
            # batch_size is a token budget here, so the examples per batch
            # are approximated as batch_size // max_seq_len.
            max_train_steps = args.epoch * num_train_examples // (
                args.batch_size // args.max_seq_len) // dev_count
        else:
            max_train_steps = args.epoch * num_train_examples // args.batch_size // dev_count

        warmup_steps = int(max_train_steps * args.warmup_proportion)
        print("Device count: %d" % dev_count)
        print("Num train examples: %d" % num_train_examples)
        print("Max train steps: %d" % max_train_steps)
        print("Num warmup steps: %d" % warmup_steps)

        train_program = fluid.Program()
        if args.random_seed is not None:
            train_program.random_seed = args.random_seed
        with fluid.program_guard(train_program, startup_prog):
            with fluid.unique_name.guard():
                results = create_model(
                    args,
                    pyreader_name='train_reader',
                    bert_config=bert_config,
                    num_labels=num_labels,
                    paradigm_inst=paradigm_inst)
                train_pyreader = results.get("pyreader", None)
                loss = results.get("loss", None)
                probs = results.get("probs", None)
                accuracy = results.get("accuracy", None)
                num_seqs = results.get("num_seqs", None)
                scheduled_lr = optimization(
                    loss=loss,
                    warmup_steps=warmup_steps,
                    num_train_steps=max_train_steps,
                    learning_rate=args.learning_rate,
                    train_program=train_program,
                    startup_prog=startup_prog,
                    weight_decay=args.weight_decay,
                    scheduler=args.lr_scheduler,
                    use_fp16=args.use_fp16,
                    loss_scaling=args.loss_scaling)

                # Keep the fetched variables out of memory reuse.
                if accuracy is not None:
                    skip_opt_set = [loss.name, probs.name, accuracy.name, num_seqs.name]
                else:
                    skip_opt_set = [loss.name, probs.name, num_seqs.name]
                fluid.memory_optimize(
                    input_program=train_program,
                    skip_opt_set=skip_opt_set)

        if args.verbose:
            if in_tokens[task_name]:
                lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
                    program=train_program,
                    batch_size=args.batch_size // args.max_seq_len)
            else:
                lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
                    program=train_program, batch_size=args.batch_size)
            print("Theoretical memory usage in training: %.3f - %.3f %s" %
                  (lower_mem, upper_mem, unit))

    if args.do_val or args.do_test:
        test_prog = fluid.Program()
        with fluid.program_guard(test_prog, startup_prog):
            with fluid.unique_name.guard():
                # Separate names so the training variables are not clobbered.
                test_results = create_model(
                    args,
                    pyreader_name='test_reader',
                    bert_config=bert_config,
                    num_labels=num_labels,
                    paradigm_inst=paradigm_inst)
                test_pyreader = test_results.get("pyreader", None)
                test_loss = test_results.get("loss", None)
                test_accuracy = test_results.get("accuracy", None)
                test_num_seqs = test_results.get("num_seqs", None)
        test_prog = test_prog.clone(for_test=True)
        # Built here rather than inside the training loop so that
        # evaluation-only runs also have it for the final evaluation below.
        if test_accuracy is not None:
            fetch_test_list = [test_loss.name, test_accuracy.name, test_num_seqs.name]
        else:
            fetch_test_list = [test_loss.name, test_num_seqs.name]

    exe.run(startup_prog)

    if args.do_train:
        if args.init_checkpoint and args.init_pretraining_params:
            print(
                "WARNING: args 'init_checkpoint' and 'init_pretraining_params' "
                "both are set! Only arg 'init_checkpoint' is made valid.")
        if args.init_checkpoint:
            init_checkpoint(
                exe,
                args.init_checkpoint,
                main_program=startup_prog,
                use_fp16=args.use_fp16)
        elif args.init_pretraining_params:
            init_pretraining_params(
                exe,
                args.init_pretraining_params,
                main_program=startup_prog,
                use_fp16=args.use_fp16)
    elif args.do_val or args.do_test:
        if not args.init_checkpoint:
            raise ValueError("args 'init_checkpoint' should be set if "
                             "only doing validation or testing!")
        init_checkpoint(
            exe,
            args.init_checkpoint,
            main_program=startup_prog,
            use_fp16=args.use_fp16)

    if args.do_train:
        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.use_experimental_executor = args.use_fast_executor
        exec_strategy.num_threads = dev_count
        exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope

        train_exe = fluid.ParallelExecutor(
            use_cuda=args.use_cuda,
            loss_name=loss.name,
            exec_strategy=exec_strategy,
            main_program=train_program)
        train_pyreader.decorate_tensor_provider(train_data_generator)
    else:
        train_exe = None

    if args.do_val or args.do_test:
        # Shares parameters with the training executor when one exists.
        test_exe = fluid.ParallelExecutor(
            use_cuda=args.use_cuda,
            main_program=test_prog,
            share_vars_from=train_exe)

    if args.do_train:
        train_pyreader.start()
        steps = 0
        total_cost, total_acc, total_num_seqs = [], [], []
        time_begin = time.time()
        ce_info = []
        while True:
            try:
                steps += 1
                # Only fetch metrics on logging steps; other steps fetch nothing.
                if steps % args.skip_steps == 0:
                    if warmup_steps <= 0:
                        if accuracy is not None:
                            fetch_list = [loss.name, accuracy.name, num_seqs.name]
                        else:
                            fetch_list = [loss.name, num_seqs.name]
                    else:
                        if accuracy is not None:
                            fetch_list = [
                                loss.name, accuracy.name, scheduled_lr.name,
                                num_seqs.name
                            ]
                        else:
                            fetch_list = [loss.name, scheduled_lr.name, num_seqs.name]
                else:
                    fetch_list = []

                outputs = train_exe.run(fetch_list=fetch_list)

                if steps % args.skip_steps == 0:
                    if warmup_steps <= 0:
                        if accuracy is not None:
                            np_loss, np_acc, np_num_seqs = outputs
                        else:
                            np_loss, np_num_seqs = outputs
                    else:
                        if accuracy is not None:
                            np_loss, np_acc, np_lr, np_num_seqs = outputs
                        else:
                            np_loss, np_lr, np_num_seqs = outputs

                    # Weight per-batch averages by batch size for a true mean.
                    total_cost.extend(np_loss * np_num_seqs)
                    total_num_seqs.extend(np_num_seqs)
                    if accuracy is not None:
                        total_acc.extend(np_acc * np_num_seqs)

                    if args.verbose:
                        verbose = "train pyreader queue size: %d, " % train_pyreader.queue.size()
                        verbose += "learning rate: %f" % (
                            np_lr[0]
                            if warmup_steps > 0 else args.learning_rate)
                        print(verbose)

                    current_example, current_epoch = processor.get_train_progress()
                    time_end = time.time()
                    used_time = time_end - time_begin
                    current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
                    if accuracy is not None:
                        print("%s epoch: %d, progress: %d/%d, step: %d, ave loss: %f, "
                              "ave acc: %f, speed: %f steps/s" %
                              (current_time, current_epoch, current_example, num_train_examples,
                               steps, np.sum(total_cost) / np.sum(total_num_seqs),
                               np.sum(total_acc) / np.sum(total_num_seqs),
                               args.skip_steps / used_time))
                        ce_info.append([np.sum(total_cost) / np.sum(total_num_seqs), np.sum(total_acc) / np.sum(total_num_seqs), args.skip_steps / used_time])
                    else:
                        print("%s epoch: %d, progress: %d/%d, step: %d, ave loss: %f, "
                              "speed: %f steps/s" %
                              (current_time, current_epoch, current_example, num_train_examples,
                               steps, np.sum(total_cost) / np.sum(total_num_seqs),
                               args.skip_steps / used_time))
                        ce_info.append([np.sum(total_cost) / np.sum(total_num_seqs), args.skip_steps / used_time])
                    total_cost, total_acc, total_num_seqs = [], [], []
                    time_begin = time.time()

                if steps % args.save_steps == 0:
                    save_path = os.path.join(args.checkpoints, "step_" + str(steps))
                    fluid.io.save_persistables(exe, save_path, train_program)
                if steps % args.validation_steps == 0:
                    #evaluate dev set
                    if args.do_val:
                        test_pyreader.decorate_tensor_provider(
                            processor.data_generator(
                                batch_size=args.batch_size,
                                phase='dev',
                                epoch=1,
                                shuffle=False))
                        evaluate(test_exe, test_prog, test_pyreader, fetch_test_list, "dev")
                    #evaluate test set
                    if args.do_test:
                        test_pyreader.decorate_tensor_provider(
                            processor.data_generator(
                                batch_size=args.batch_size,
                                phase='test',
                                epoch=1,
                                shuffle=False))
                        evaluate(test_exe, test_prog, test_pyreader, fetch_test_list, "test")
            except fluid.core.EOFException:
                # Training data exhausted: save a final checkpoint and stop.
                save_path = os.path.join(args.checkpoints, "step_" + str(steps))
                fluid.io.save_persistables(exe, save_path, train_program)
                train_pyreader.reset()
                break
    if args.do_train and args.enable_ce:
        card_num = get_cards()
        print("zytest_card_num", card_num)
        ce_loss = 0
        ce_acc = 0
        ce_time = 0
        # Entries are [loss, acc, speed], or [loss, speed] for tasks without
        # an accuracy metric. The second-to-last logged entry is reported.
        if len(ce_info) >= 2:
            ce_entry = ce_info[-2]
            ce_loss = ce_entry[0]
            ce_time = ce_entry[-1]
            if len(ce_entry) == 3:
                ce_acc = ce_entry[1]
        else:
            print("ce info error")
        print("kpis\teach_step_duration_%s_card%s\t%s" %
              (task_name, card_num, ce_time))
        print("kpis\ttrain_loss_%s_card%s\t%f" %
              (task_name, card_num, ce_loss))
        print("kpis\ttrain_acc_%s_card%s\t%f" %
              (task_name, card_num, ce_acc))

    #final eval on dev set
    if args.do_val:
        test_pyreader.decorate_tensor_provider(
            processor.data_generator(
                batch_size=args.batch_size, phase='dev', epoch=1,
                shuffle=False))
        print("Final validation result:")
        evaluate(test_exe, test_prog, test_pyreader, fetch_test_list, "dev")

    #final eval on test set
    if args.do_test:
        test_pyreader.decorate_tensor_provider(
            processor.data_generator(
                batch_size=args.batch_size,
                phase='test',
                epoch=1,
                shuffle=False))
        print("Final test result:")
        evaluate(test_exe, test_prog, test_pyreader, fetch_test_list, "test")
Ejemplo n.º 5
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--pause', type=int, default=0)
    parser.add_argument('--iteration', type=str, default='1')
    parser.add_argument('--fs', type=str, default='local',
                        help='must be `local`. Do not change.')

    # Data paths
    parser.add_argument('--data_dir', default='data/', type=str)
    parser.add_argument("--train_file", default='train-v1.1.json', type=str,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--predict_file", default='dev-v1.1.json', type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument('--gt_file', default='dev-v1.1.json', type=str, help='ground truth file needed for evaluation.')

    # Metadata paths
    parser.add_argument('--metadata_dir', default='metadata/', type=str)
    parser.add_argument("--vocab_file", default='vocab.txt', type=str,
                        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument("--bert_model_option", default='large_uncased', type=str,
                        help="model architecture option. [large_uncased] or [base_uncased]")
    parser.add_argument("--bert_config_file", default='bert_config.json', type=str,
                        help="The config json file corresponding to the pre-trained BERT model. "
                             "This specifies the model architecture.")
    parser.add_argument("--init_checkpoint", default='pytorch_model.bin', type=str,
                        help="Initial checkpoint (usually from a pre-trained BERT model).")

    # Output and load paths
    parser.add_argument("--output_dir", default='out/', type=str,
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--index_file", default='index.hdf5', type=str, help="index output file.")
    parser.add_argument("--question_emb_file", default='question.hdf5', type=str, help="question output file.")

    parser.add_argument('--load_dir', default='out/', type=str)

    # Local paths (if we want to run cmd)
    parser.add_argument('--eval_script', default='evaluate-v1.1.py', type=str)

    # Do's
    parser.add_argument("--do_load", default=False, action='store_true', help='Do load. If eval, do load automatically')
    parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.")
    parser.add_argument("--do_train_filter", default=False, action='store_true', help='Train filter or not.')
    parser.add_argument("--do_train_sparse", default=False, action='store_true', help='Train sparse or not.')
    parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument('--do_eval', default=False, action='store_true')
    parser.add_argument('--do_embed_question', default=False, action='store_true')
    parser.add_argument('--do_index', default=False, action='store_true')
    parser.add_argument('--do_serve', default=False, action='store_true')

    # Model options: if you change these, you need to train again
    parser.add_argument("--do_case", default=False, action='store_true',
                        help="Whether to lower case the input text. Should be True for uncased "
                             "models and False for cased models.")
    parser.add_argument('--phrase_size', default=961, type=int)
    parser.add_argument('--metric', default='ip', type=str, help='ip | l2')
    parser.add_argument("--use_sparse", default=False, action='store_true')

    # GPU and memory related options
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--train_batch_size", default=12, type=int, help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=16, type=int, help="Total batch size for predictions.")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--optimize_on_cpu',
                        default=False,
                        action='store_true',
                        help="Whether to perform optimization and keep the optimizer averages on CPU")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--fp16',
                        default=False,
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")

    # Training options: only effective during training
    parser.add_argument("--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--num_train_filter_epochs", default=1.0, type=float,
                        help="Total number of training epochs for filter to perform.")
    parser.add_argument("--num_train_sparse_epochs", default=3.0, type=float,
                        help="Total number of training epochs for sparse to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
                             "of training.")
    parser.add_argument("--save_checkpoints_steps", default=1000, type=int,
                        help="How often to save the model checkpoint.")
    parser.add_argument("--iterations_per_loop", default=1000, type=int,
                        help="How many steps to make in each estimator call.")

    # Prediction options: only effective during prediction
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")

    # Index Options
    parser.add_argument('--dtype', default='float32', type=str)
    parser.add_argument('--filter_threshold', default=-1e9, type=float)
    parser.add_argument('--compression_offset', default=-2, type=float)
    parser.add_argument('--compression_scale', default=20, type=float)
    parser.add_argument('--split_by_para', default=False, action='store_true')

    # Serve Options
    parser.add_argument('--port', default=9009, type=int)

    # Others
    parser.add_argument('--parallel', default=False, action='store_true')
    parser.add_argument("--verbose_logging", default=False, action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--draft', default=False, action='store_true')
    parser.add_argument('--draft_num_examples', type=int, default=12)

    args = parser.parse_args()

    # Filesystem routines
    if args.fs == 'local':
        class Processor(object):
            def __init__(self, path):
                self._save = None
                self._load = None
                self._path = path

            def bind(self, save, load):
                self._save = save
                self._load = load

            def save(self, checkpoint=None, save_fn=None, **kwargs):
                path = os.path.join(self._path, str(checkpoint))
                if save_fn is None:
                    self._save(path, **kwargs)
                else:
                    save_fn(path, **kwargs)

            def load(self, checkpoint, load_fn=None, session=None, **kwargs):
                assert self._path == session
                path = os.path.join(self._path, str(checkpoint), 'model.pt')
                if load_fn is None:
                    self._load(path, **kwargs)
                else:
                    load_fn(path, **kwargs)

        processor = Processor(args.load_dir)
    else:
        raise ValueError(args.fs)

    if not args.do_train:
        args.do_load = True

    # Configure paths
    args.train_file = os.path.join(args.data_dir, args.train_file)
    args.predict_file = os.path.join(args.data_dir, args.predict_file)
    args.gt_file = os.path.join(args.data_dir, args.gt_file)

    args.bert_config_file = os.path.join(args.metadata_dir, args.bert_config_file.replace(".json", "") +
                                         "_" + args.bert_model_option + ".json")
    args.init_checkpoint = os.path.join(args.metadata_dir, args.init_checkpoint.replace(".bin", "") +
                                        "_" + args.bert_model_option + ".bin")
    args.vocab_file = os.path.join(args.metadata_dir, args.vocab_file)
    args.index_file = os.path.join(args.output_dir, args.index_file)

    # Multi-GPU stuff
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    # Seed for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    bert_config = BertConfig.from_json_file(args.bert_config_file)

    if args.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (args.max_seq_length, bert_config.max_position_embeddings))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        # raise ValueError("Output directory () already exists and is not empty.")
        pass
    else:
        os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=not args.do_case)

    model = BertPhraseModel(
        bert_config,
        phrase_size=args.phrase_size,
        metric=args.metric,
        use_sparse=args.use_sparse
    )

    print('Number of model parameters:', sum(p.numel() for p in model.parameters()))

    if not args.do_load and args.init_checkpoint is not None:
        state_dict = torch.load(args.init_checkpoint, map_location='cpu')
        # If below: for Korean BERT compatibility
        if next(iter(state_dict)).startswith('bert.'):
            state_dict = {key[len('bert.'):]: val for key, val in state_dict.items()}
            state_dict = {key: val for key, val in state_dict.items() if key in model.encoder.bert_model.state_dict()}
        model.encoder.bert.load_state_dict(state_dict)

    if args.fp16:
        model.half()

    if not args.optimize_on_cpu:
        model.to(device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)
    elif args.parallel or n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.do_load:
        bind_model(processor, model)
        processor.load(args.iteration, session=args.load_dir)

    if args.do_train:
        train_examples = read_squad_examples(
            input_file=args.train_file, is_training=True, draft=args.draft, draft_num_examples=args.draft_num_examples)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

        no_decay = ['bias', 'gamma', 'beta']
        optimizer_parameters = [
            {'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01},
            {'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0}
        ]
        optimizer = BERTAdam(optimizer_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps)

        bind_model(processor, model, optimizer)

        global_step = 0
        train_features, train_features_ = convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=True)

        train_features = inject_noise_to_features_list(train_features,
                                                       clamp=True,
                                                       replace=True,
                                                       shuffle=True)

        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)

        all_input_ids_ = torch.tensor([f.input_ids for f in train_features_], dtype=torch.long)
        all_input_mask_ = torch.tensor([f.input_mask for f in train_features_], dtype=torch.long)

        if args.fp16:
            (all_input_ids, all_input_mask,
             all_start_positions,
             all_end_positions) = tuple(t.half() for t in (all_input_ids, all_input_mask,
                                                           all_start_positions, all_end_positions))
            all_input_ids_, all_input_mask_ = tuple(t.half() for t in (all_input_ids_, all_input_mask_))

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_input_ids_, all_input_mask_,
                                   all_start_positions, all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for epoch in range(int(args.num_train_epochs)):
            for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch %d" % (epoch + 1))):
                batch = tuple(t.to(device) for t in batch)
                (input_ids, input_mask,
                 input_ids_, input_mask_,
                 start_positions, end_positions) = batch
                loss, _ = model(input_ids, input_mask,
                                input_ids_, input_mask_,
                                start_positions, end_positions)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.optimize_on_cpu:
                        model.to('cpu')
                    optimizer.step()  # We have accumulated enought gradients
                    model.zero_grad()
                    if args.optimize_on_cpu:
                        model.to(device)
                    global_step += 1

            processor.save(epoch + 1)

    if args.do_train_filter:
        train_examples = read_squad_examples(
            input_file=args.train_file, is_training=True, draft=args.draft, draft_num_examples=args.draft_num_examples)
        num_train_steps = int(
            len(
                train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_filter_epochs)

        if args.parallel or n_gpu > 1:
            optimizer = Adam(model.module.filter.parameters())
        else:
            optimizer = Adam(model.filter.parameters())

        bind_model(processor, model, optimizer)

        global_step = 0
        train_features, train_features_ = convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=True)
        logger.info("***** Running filter training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)

        all_input_ids_ = torch.tensor([f.input_ids for f in train_features_], dtype=torch.long)
        all_input_mask_ = torch.tensor([f.input_mask for f in train_features_], dtype=torch.long)

        if args.fp16:
            (all_input_ids, all_input_mask,
             all_start_positions,
             all_end_positions) = tuple(t.half() for t in (all_input_ids, all_input_mask,
                                                           all_start_positions, all_end_positions))
            all_input_ids_, all_input_mask_ = tuple(t.half() for t in (all_input_ids_, all_input_mask_))

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_input_ids_, all_input_mask_,
                                   all_start_positions, all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for epoch in range(int(args.num_train_filter_epochs)):
            for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch %d" % (epoch + 1))):
                batch = tuple(t.to(device) for t in batch)
                (input_ids, input_mask,
                 input_ids_, input_mask_,
                 start_positions, end_positions) = batch
                _, loss = model(input_ids, input_mask,
                                input_ids_, input_mask_,
                                start_positions, end_positions)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.optimize_on_cpu:
                        model.to('cpu')
                    optimizer.step()  # We have accumulated enought gradients
                    model.zero_grad()
                    if args.optimize_on_cpu:
                        model.to(device)
                    global_step += 1

            processor.save(epoch + 1)

    if args.do_train_sparse:
        train_examples = read_squad_examples(
            input_file=args.train_file, is_training=True, draft=args.draft, draft_num_examples=args.draft_num_examples)
        num_train_steps = int(
            len(
                train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_sparse_epochs)

        '''
        if args.parallel or n_gpu > 1:
            optimizer = Adam(model.module.sparse_layer.parameters())
        else:
            optimizer = Adam(model.sparse_layer.parameters())
        '''

        no_decay = ['bias', 'gamma', 'beta']
        optimizer_parameters = [
            {'params': [p for n, p in model.named_parameters() if (n not in no_decay) and ('filter' not in n)],
             'weight_decay_rate': 0.01},
            {'params': [p for n, p in model.named_parameters() if (n in no_decay) and ('filter' not in n)],
             'weight_decay_rate': 0.0}
        ]
        optimizer = BERTAdam(optimizer_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps)

        bind_model(processor, model, optimizer)

        global_step = 0
        train_features, train_features_ = convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=True)
        logger.info("***** Running sparse training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)

        all_input_ids_ = torch.tensor([f.input_ids for f in train_features_], dtype=torch.long)
        all_input_mask_ = torch.tensor([f.input_mask for f in train_features_], dtype=torch.long)

        if args.fp16:
            (all_input_ids, all_input_mask,
             all_start_positions,
             all_end_positions) = tuple(t.half() for t in (all_input_ids, all_input_mask,
                                                           all_start_positions, all_end_positions))
            all_input_ids_, all_input_mask_ = tuple(t.half() for t in (all_input_ids_, all_input_mask_))

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_input_ids_, all_input_mask_,
                                   all_start_positions, all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for epoch in range(int(args.num_train_sparse_epochs)):
            for step, batch in enumerate(tqdm(train_dataloader, desc="Epoch %d" % (epoch + 1))):
                batch = tuple(t.to(device) for t in batch)
                (input_ids, input_mask,
                 input_ids_, input_mask_,
                 start_positions, end_positions) = batch
                loss, _ = model(input_ids, input_mask,
                                input_ids_, input_mask_,
                                start_positions, end_positions)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.optimize_on_cpu:
                        model.to('cpu')
                    optimizer.step()  # We have accumulated enought gradients
                    model.zero_grad()
                    if args.optimize_on_cpu:
                        model.to(device)
                    global_step += 1

            processor.save(epoch + 1)

    if args.do_predict:
        eval_examples = read_squad_examples(
            input_file=args.predict_file, is_training=False, draft=args.draft,
            draft_num_examples=args.draft_num_examples)
        eval_features, query_eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_input_ids_ = torch.tensor([f.input_ids for f in query_eval_features], dtype=torch.long)
        all_input_mask_ = torch.tensor([f.input_mask for f in query_eval_features], dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        if args.fp16:
            (all_input_ids, all_input_mask, all_example_index) = tuple(t.half() for t in (all_input_ids, all_input_mask,
                                                                                          all_example_index))
            all_input_ids_, all_input_mask_ = tuple(t.half() for t in (all_input_ids_, all_input_mask_))

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_input_ids_, all_input_mask_,
                                  all_example_index)
        if args.local_rank == -1:
            eval_sampler = SequentialSampler(eval_data)
        else:
            eval_sampler = DistributedSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)

        model.eval()
        logger.info("Start evaluating")

        def get_results():
            """Run the model over ``eval_dataloader`` and yield one ``RawResult`` per feature.

            Each dataloader batch carries (context ids, context mask, question ids,
            question mask, feature indices); results are yielded in dataloader order.
            """
            for batch in eval_dataloader:
                *tensors, example_indices = batch
                doc_ids, doc_mask, query_ids, query_mask = (t.to(device) for t in tensors)
                with torch.no_grad():
                    logits_batch, filter_start_batch, filter_end_batch = model(
                        doc_ids, doc_mask, query_ids, query_mask)
                for row, example_index in enumerate(example_indices):
                    feature = eval_features[example_index.item()]
                    yield RawResult(
                        unique_id=int(feature.unique_id),
                        all_logits=logits_batch[row].detach().cpu().numpy(),
                        filter_start_logits=filter_start_batch[row].detach().cpu().numpy(),
                        filter_end_logits=filter_end_batch[row].detach().cpu().numpy(),
                    )

        output_prediction_file = os.path.join(args.output_dir, "predictions.json")
        write_predictions(eval_examples, eval_features, get_results(),
                          args.max_answer_length,
                          not args.do_case, output_prediction_file, args.verbose_logging,
                          args.filter_threshold)

        if args.do_eval:
            command = "python %s %s %s" % (args.eval_script, args.gt_file, output_prediction_file)
            import subprocess
            process = subprocess.Popen(command.split(), stdout=subprocess.PIPE)
            output, error = process.communicate()

    if args.do_embed_question:
        question_examples = read_squad_examples(
            question_only=True,
            input_file=args.predict_file, is_training=False, draft=args.draft,
            draft_num_examples=args.draft_num_examples)
        query_eval_features = convert_questions_to_features(
            examples=question_examples,
            tokenizer=tokenizer,
            max_query_length=args.max_query_length)
        question_dataloader = convert_question_features_to_dataloader(query_eval_features, args.fp16, args.local_rank,
                                                                      args.predict_batch_size)

        model.eval()
        logger.info("Start embedding")
        question_results = get_question_results_(question_examples, query_eval_features, question_dataloader, device,
                                                 model)
        path = os.path.join(args.output_dir, args.question_emb_file)
        print('Writing %s' % path)
        write_question_results(question_results, query_eval_features, path)

    if args.do_index:
        if ':' not in args.predict_file:
            predict_files = [args.predict_file]
            offsets = [0]
        else:
            dirname = os.path.dirname(args.predict_file)
            basename = os.path.basename(args.predict_file)
            start, end = list(map(int, basename.split(':')))

            # skip files if possible
            if os.path.exists(args.index_file):
                with h5py.File(args.index_file, 'r') as f:
                    dids = list(map(int, f.keys()))
                start = int(max(dids) / 1000)
                print('%s exists; starting from %d' % (args.index_file, start))

            names = [str(i).zfill(4) for i in range(start, end)]
            predict_files = [os.path.join(dirname, name) for name in names]
            offsets = [int(each) * 1000 for each in names]

        for offset, predict_file in zip(offsets, predict_files):
            try:
                context_examples = read_squad_examples(
                    context_only=True,
                    input_file=predict_file, is_training=False, draft=args.draft,
                    draft_num_examples=args.draft_num_examples)

                for example in context_examples:
                    example.doc_idx += offset

                context_features = convert_documents_to_features(
                    examples=context_examples,
                    tokenizer=tokenizer,
                    max_seq_length=args.max_seq_length,
                    doc_stride=args.doc_stride)

                logger.info("***** Running indexing on %s *****" % predict_file)
                logger.info("  Num orig examples = %d", len(context_examples))
                logger.info("  Num split examples = %d", len(context_features))
                logger.info("  Batch size = %d", args.predict_batch_size)

                all_input_ids = torch.tensor([f.input_ids for f in context_features], dtype=torch.long)
                all_input_mask = torch.tensor([f.input_mask for f in context_features], dtype=torch.long)
                all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
                if args.fp16:
                    all_input_ids, all_input_mask, all_example_index = tuple(
                        t.half() for t in (all_input_ids, all_input_mask, all_example_index))

                context_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)

                if args.local_rank == -1:
                    context_sampler = SequentialSampler(context_data)
                else:
                    context_sampler = DistributedSampler(context_data)
                context_dataloader = DataLoader(context_data, sampler=context_sampler,
                                                batch_size=args.predict_batch_size)

                model.eval()
                logger.info("Start indexing")

                def get_context_results():
                    """Encode every context batch and yield one ``ContextResult`` per feature.

                    All arrays are converted to numpy with dtype ``args.dtype``;
                    ``sparse`` is None when the model returns no sparse output.
                    """
                    def to_numpy(tensor):
                        # Detach from the graph, move to host memory, cast to the index dtype.
                        return tensor.detach().cpu().numpy().astype(args.dtype)

                    for input_ids, input_mask, example_indices in context_dataloader:
                        with torch.no_grad():
                            outputs = model(input_ids.to(device), input_mask.to(device))
                        (batch_start, batch_end, batch_span_logits,
                         batch_filter_start, batch_filter_end, batch_sparse) = outputs
                        for row, example_index in enumerate(example_indices):
                            feature = context_features[example_index.item()]
                            yield ContextResult(
                                unique_id=int(feature.unique_id),
                                start=to_numpy(batch_start[row]),
                                end=to_numpy(batch_end[row]),
                                span_logits=to_numpy(batch_span_logits[row]),
                                filter_start_logits=to_numpy(batch_filter_start[row]),
                                filter_end_logits=to_numpy(batch_filter_end[row]),
                                sparse=None if batch_sparse is None else to_numpy(batch_sparse[row]),
                            )

                t0 = time()
                write_hdf5(context_examples, context_features, get_context_results(),
                           args.max_answer_length, not args.do_case, args.index_file, args.filter_threshold,
                           args.verbose_logging,
                           offset=args.compression_offset, scale=args.compression_scale,
                           split_by_para=args.split_by_para,
                           use_sparse=args.use_sparse)
                print('%s: %.1f mins' % (predict_file, (time() - t0) / 60))
            except Exception as e:
                with open(os.path.join(args.output_dir, 'error_files.txt'), 'a') as fp:
                    fp.write('error file: %s\n' % predict_file)
                    fp.write('error message: %s\n' % str(e))

    if args.do_serve:
        def get(text):
            """Embed a single question string for the serving endpoint.

            Returns a tuple ``(start, end, span_logit)`` of plain Python lists taken
            from the first question result.
            """
            examples = [SquadExample(qas_id='serve', question_text=text)]
            features = convert_questions_to_features(
                examples=examples,
                tokenizer=tokenizer,
                max_query_length=16,
            )
            loader = convert_question_features_to_dataloader(
                features, args.fp16, args.local_rank, args.predict_batch_size)

            model.eval()

            results = get_question_results_(examples, features, loader, device, model)
            first = next(iter(results))
            return (first.start.tolist(),
                    first.end.tolist(),
                    first.span_logit.tolist())

        serve(get, args.port)
Ejemplo n.º 6
0
def run():
    """Train ``model.Model`` (layers in ``HP.train_layers``) on top of restored BERT weights.

    Builds a TF1 graph from the ``HP`` hyper-parameters and restores the variables
    under scope ``bert`` from ``HP.start1_checkpoint``. Then, for ``HP.epochs`` epochs:

    * trains on every batch of ``HP.train_file``, evaluating one ``HP.dev_file``
      batch after each training step;
    * every 500 steps logs and resets the running train/dev averages;
    * every 1000 steps saves all variables to ``HP.save1_checkpoint``;
    * at the end of the epoch logs the epoch's train averages and a full pass
      over the dev set.

    Finally ``mod.weights`` are saved to ``HP.end1_checkpoint``.
    """
    bert_config = BertConfig.from_json_file(HP.bert_config)
    inputs = tf.placeholder(dtype=tf.int32, shape=[None, None])
    segments = tf.placeholder(dtype=tf.int32, shape=[None, None])
    inputs_length = tf.placeholder(dtype=tf.int32, shape=[None])
    answers = tf.placeholder(dtype=tf.int32, shape=[None, None, 2])
    answers_length = tf.placeholder(dtype=tf.int32, shape=[None])

    mod = model.Model(bert_config,
                      HP.is_training,
                      HP.num_units,
                      inputs,
                      segments,
                      inputs_length,
                      answers,
                      answers_length,
                      layers=HP.layers)

    train_data = DataGenerator(HP.train_file, HP.max_seq_length, HP.batch_size)
    dev_data = DataGenerator(HP.dev_file, HP.max_seq_length, HP.batch_size)

    def averages():
        # One running Average per trained layer.
        return {i: Average() for i in HP.train_layers}

    long_train_loss, long_train_accuracy = averages(), averages()
    train_loss, train_accuracy = averages(), averages()
    dev_loss, dev_accuracy = averages(), averages()
    epoch_dev_loss, epoch_dev_accuracy = averages(), averages()

    # Capture the `bert`-scoped variables before the loss/optimizer ops are added,
    # so the restore list holds only what exists right after model construction.
    bert_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='bert')
    start_saver = tf.train.Saver({v.name: v for v in bert_vars})

    loss = mod.losses(HP.train_layers)
    accuracy = mod.accuracy(HP.train_layers)
    train = mod.train(HP.learning_rate, HP.bert_train, HP.train_layers)

    end_saver = tf.train.Saver({v.name: v for v in mod.weights})
    saver = tf.train.Saver()

    sess = tf.Session()
    # Initialise every variable first (including any created by mod.losses /
    # mod.train), and only then restore BERT. Running the initializer after the
    # restore would overwrite the pretrained weights that were just loaded.
    sess.run(tf.global_variables_initializer())
    start_saver.restore(sess, HP.start1_checkpoint)

    logger = Logger(HP.log_files)

    def feed_for(batch):
        """Map a DataGenerator batch onto the placeholders; also return its size."""
        (batch_inputs, batch_segments, batch_inputs_length,
         batch_answers, batch_answers_length) = batch
        feed = {
            inputs: batch_inputs,
            segments: batch_segments,
            inputs_length: batch_inputs_length,
            answers: batch_answers,
            answers_length: batch_answers_length,
        }
        return feed, len(batch_inputs_length)

    step = 0
    for epoch in range(1, HP.epochs + 1):
        while not train_data.has_ended():
            step += 1
            feed, batch_size = feed_for(train_data.get_next())
            _loss, _accuracy, _ = sess.run([loss, accuracy, train], feed_dict=feed)
            for i in HP.train_layers:
                train_loss[i].add(_loss[i], batch_size)
                train_accuracy[i].add(_accuracy[i], batch_size)
                long_train_loss[i].add(_loss[i], batch_size)
                long_train_accuracy[i].add(_accuracy[i], batch_size)

            # Evaluate one dev batch per training step, restarting the dev set when exhausted.
            if dev_data.has_ended():
                dev_data.reset(False)
            feed, batch_size = feed_for(dev_data.get_next())
            _loss, _accuracy = sess.run([loss, accuracy], feed_dict=feed)
            for i in HP.train_layers:
                dev_loss[i].add(_loss[i], batch_size)
                dev_accuracy[i].add(_accuracy[i], batch_size)

            if step % 500 == 0:
                for i in HP.train_layers:
                    logger.log("train " + str(i), train_loss[i],
                               train_accuracy[i], step)
                    logger.log("dev " + str(i), dev_loss[i], dev_accuracy[i],
                               step)
                    train_loss[i].reset()
                    train_accuracy[i].reset()
                    dev_loss[i].reset()
                    dev_accuracy[i].reset()
            if step % 1000 == 0:
                saver.save(sess, HP.save1_checkpoint)
                logger.log_text("saving checkpoint")

        for i in HP.train_layers:
            logger.log("epoch train " + str(i), long_train_loss[i],
                       long_train_accuracy[i], step)
            long_train_loss[i].reset()
            long_train_accuracy[i].reset()

        train_data.reset(True)
        dev_data.reset(True)
        while not dev_data.has_ended():
            feed, batch_size = feed_for(dev_data.get_next())
            _loss, _accuracy = sess.run([loss, accuracy], feed_dict=feed)
            for i in HP.train_layers:
                epoch_dev_loss[i].add(_loss[i], batch_size)
                epoch_dev_accuracy[i].add(_accuracy[i], batch_size)
        for i in HP.train_layers:
            logger.log("epoch validation " + str(i), epoch_dev_loss[i],
                       epoch_dev_accuracy[i], epoch)
            # Reset so the next epoch's figure covers only that epoch's dev pass
            # (previously these accumulated across all epochs).
            epoch_dev_loss[i].reset()
            epoch_dev_accuracy[i].reset()
        dev_data.reset(True)

    end_saver.save(sess, HP.end1_checkpoint)
Ejemplo n.º 7
0
def run():
    """Run an online, per-example weighting of layer "experts" over the dev set.

    Each layer in ``HP.layers`` is treated as an expert. For every learning
    rate in ``np.arange(HP.eta[0], HP.eta[1], HP.eta[2])`` a row of
    log-weights over the experts is kept in the non-trainable variable
    ``log_prob`` (initialised to zeros, i.e. uniform weights). After each dev
    example the row is updated by ``x - x**2`` with ``x = eta * regret``,
    where ``regret`` is the probability-weighted mean of ``1 - accuracy``
    minus each expert's own ``1 - accuracy``. One expert per row is also
    sampled from the current weights and its loss/accuracy recorded.

    Side effects:
      * restores the model weights from ``HP.start2_checkpoint``;
      * writes START / progress / END markers to ``HP.log_files``;
      * dumps the per-step values of ``loss``, ``accuracy``, ``my_loss``,
        ``my_accuracy`` and ``probs`` as JSON to ``HP.weights_file``
        (every 1000 steps and at the end);
      * checkpoints ``log_prob`` to ``HP.save2_checkpoint``
        (every 1000 steps and at the end).
    """
    bert_config = BertConfig.from_json_file(HP.bert_config)
    # Every placeholder has a leading dimension of 1: one dev example per step.
    inputs = tf.placeholder(dtype=tf.int32, shape=[1, None])
    segments = tf.placeholder(dtype=tf.int32, shape=[1, None])
    inputs_length = tf.placeholder(dtype=tf.int32, shape=[1])
    answers = tf.placeholder(dtype=tf.int32, shape=[1, None, 2])
    answers_length = tf.placeholder(dtype=tf.int32, shape=[1])

    mod = model.Model(bert_config, HP.is_training, HP.num_units, inputs, segments, inputs_length,
                      answers, answers_length, layers=HP.layers)
    # Keep one loss / accuracy per expert, in HP.layers order.
    loss = mod.losses(HP.train_layers)
    loss = [loss[i] for i in HP.layers]
    accuracy = mod.accuracy(HP.train_layers)
    accuracy = [accuracy[i] for i in HP.layers]

    num_experts = len(HP.layers)
    eta = np.arange(HP.eta[0], HP.eta[1], HP.eta[2], dtype=np.float32)
    num_eta = len(eta)
    # Row r holds the log-weights over the experts for learning rate eta[r].
    log_probs = tf.get_variable(name='log_prob', shape=[num_eta, num_experts], dtype=tf.float32, trainable=False)
    probs = tf.nn.softmax(log_probs, 1)

    # Zero log-weights == uniform weights after the softmax.
    init = tf.assign(log_probs, tf.zeros_like(log_probs))

    # Sample one expert per learning rate from the current weights.
    choice = tf.random.multinomial(log_probs, 1, output_dtype=tf.int32)
    choice = tf.squeeze(choice, -1)
    # NOTE(review): the broadcasting below assumes each per-layer loss /
    # accuracy is a scalar, so these stacks have shape [num_experts] - confirm.
    loss = tf.stack(loss)
    accuracy = tf.stack(accuracy)
    my_loss = tf.gather(loss, choice)
    my_accuracy = tf.gather(accuracy, choice)

    # Regret of each expert against the weighted mixture, using 1 - accuracy as the loss.
    mean_loss = tf.reduce_sum((1 - accuracy) * probs, -1)
    regret = tf.expand_dims(mean_loss, 1) - (1 - accuracy)

    log_update = tf.expand_dims(eta, 1) * regret
    log_update = log_update - log_update**2
    # Read the sampled expert's values before the weights are modified.
    with tf.control_dependencies([my_loss, my_accuracy]):
        update = tf.assign(log_probs, log_probs + log_update)

    dev_data = DataGenerator(HP.dev_file, HP.max_seq_length, 1)

    weights = {v.name: v for v in mod.weights}
    # `starter` restores the model weights; `saver` checkpoints only the log-weights.
    starter = tf.train.Saver(weights)
    saver = tf.train.Saver({log_probs.name: log_probs})

    keys = ["loss", "accuracy", "my_loss", "my_accuracy", "probs"]
    # saved_items[k] collects the per-step value of the tensor named keys[k].
    saved_items = [[] for _ in keys]

    def dump_saved_items():
        # Write everything collected so far to HP.weights_file as JSON.
        with open(HP.weights_file, 'w') as f:
            json.dump(dict(zip(keys, saved_items)), f)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # Context manager so the session is closed even if the loop raises.
    with tf.Session(config=sess_config) as sess:
        starter.restore(sess, HP.start2_checkpoint)
        sess.run(init)

        with open(HP.log_files, 'w') as f:
            f.write("***START***\n\n\n")
        print("***START***\n\n\n")
        step = 0
        while not dev_data.has_ended():
            step = step + 1
            dev_inputs, dev_segments, dev_inputs_length, dev_answers, dev_answers_length = \
                dev_data.get_next()
            feed_dict = {inputs: dev_inputs, segments: dev_segments, inputs_length: dev_inputs_length,
                         answers: dev_answers, answers_length: dev_answers_length}
            # The trailing `update` op applies the weight update; its result is discarded.
            items = sess.run([loss, accuracy, my_loss, my_accuracy, probs, update], feed_dict=feed_dict)
            for saved_item, item in zip(saved_items, items[:-1]):
                saved_item.append(item.tolist())
            if step % 1000 == 0:
                text = "step: %d\nsaving checkpoint\n\n" % step
                print(text)
                with open(HP.log_files, 'a') as f:
                    f.write(text)
                dump_saved_items()
                saver.save(sess, HP.save2_checkpoint)

        dump_saved_items()
        # Also checkpoint the final log-weights; otherwise every update after
        # the last multiple of 1000 steps would be lost.
        saver.save(sess, HP.save2_checkpoint)

    print("\n\n\nEND")
    # Append: opening with 'w' here would truncate the log, erasing the START
    # marker and every progress entry written above.
    with open(HP.log_files, 'a') as f:
        f.write("***END***\n\n\n")
Ejemplo n.º 8
0
  def load_model(cls, model_path, bert_config, init_spec, *inputs, **kwargs):
    """
    Instantiate a model of class `cls` and optionally load weights from a local checkpoint.

    The configuration is read from a JSON file and the model is built as
    ``cls(config, *inputs, **kwargs)``. If `model_path` is given, the
    checkpoint is loaded on CPU, the legacy LayerNorm parameter names
    ``LayerNorm.gamma`` / ``LayerNorm.beta`` are renamed to
    ``LayerNorm.weight`` / ``LayerNorm.bias``, the optional `init_spec`
    remapping is applied, and the weights are copied into every submodule.
    Missing keys, unexpected keys and per-parameter loading errors are
    collected and logged as warnings; they are not raised.

    Params:
      model_path: path to a checkpoint readable by `torch.load` that holds a
        mapping of parameter name -> tensor, or None to return a freshly
        initialized model.
      bert_config: path to the JSON model configuration file
        (read with `BertConfig.from_json_file`).
      init_spec: optional mapping ``var -> spec``. Each spec exposes
        ``name`` (target parameter name), ``mapping`` (source key in the
        checkpoint, or falsy for none) and ``use_pretrain`` (False to drop
        ``name`` from the checkpoint so the model keeps its own
        initialization). A leading ``module.`` is stripped from both
        ``name`` and ``mapping``.
      *inputs, **kwargs: additional arguments forwarded to the `cls`
        constructor (ex: num_labels for a classification head).

    Returns:
      The instantiated model.
    """
    config = BertConfig.from_json_file(bert_config)
    logger.info("Model config {}".format(config))
    # Instantiate model.
    model = cls(config, *inputs, **kwargs)
    if model_path is None:
      return model
    logger.info("loading prtrained local model file {}".format(model_path))
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location='cpu')

    # Work on a shallow copy. Copying drops the `_metadata` attribute, so it is
    # kept aside here and reattached right before loading.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()

    # Rename legacy gamma/beta LayerNorm names to the weight/bias names
    # expected by the model.
    legacy_names = (('LayerNorm.gamma', 'LayerNorm.weight'),
                    ('LayerNorm.beta', 'LayerNorm.bias'))
    for key in list(state_dict.keys()):
      for old, new in legacy_names:
        if old in key:
          state_dict[key.replace(old, new)] = state_dict.pop(key)
          break

    def strip_module(key):
      # Drop a leading 'module.' (presumably added by a DataParallel wrapper).
      return key[len('module.'):] if key.startswith('module.') else key

    if init_spec:
      ignore_init = []
      remap_dict = type(state_dict)()
      for var in init_spec:
        spec = init_spec[var]
        name = strip_module(spec.name)
        if not spec.use_pretrain:
          ignore_init.append(name)
        elif spec.mapping:
          mapping = strip_module(spec.mapping)
          if mapping in state_dict:
            remap_dict[name] = state_dict[mapping]
      logger.info('Variables not using pretraining: {}'.format(ignore_init))
      for ig in ignore_init:
        state_dict.pop(ig, None)
      # Explicit remappings take precedence over same-named checkpoint entries.
      for key, value in state_dict.items():
        if key not in remap_dict:
          remap_dict[key] = value
      state_dict = remap_dict

    if metadata is not None:
      state_dict._metadata = metadata

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    def load(module, prefix=''):
      # Load this module's own entries (keys under `prefix`), then recurse into children.
      local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
      module._load_from_state_dict(
        state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
      for child_name, child in module._modules.items():
        if child is not None:
          load(child, prefix + child_name + '.')
    load(model)

    class_name = model.__class__.__name__
    if missing_keys:
      logger.warning("Weights of {} not initialized from pretrained model: {}".format(
        class_name, '\n  '.join(missing_keys)))
    if unexpected_keys:
      logger.warning("Weights from pretrained model not used in {}: {}".format(
        class_name, '\n  '.join(unexpected_keys)))
    if error_msgs:
      # These errors (e.g. shape mismatches) are collected by
      # _load_from_state_dict; they must be surfaced, not silently dropped.
      logger.warning("Error(s) in loading state_dict for {}:\n  {}".format(
        class_name, '\n  '.join(error_msgs)))
    return model
Ejemplo n.º 9
0
 def __init__(self, opt):
     super(TransformerMapping, self).__init__()
     self.opt = opt
     bert_config = BertConfig.from_json_file(opt.trans_cfg)
     self.layer = bert.BERTLayer(bert_config)
     self.mapping = nn.Linear(opt.img_dim, opt.final_dims)