Пример #1
0
def main(args):
    """Entry point: configure TensorFlow, load data + word vectors, then
    train or test a single-turn dialog model.

    Args:
        args: parsed command-line namespace. Fields read here: debug, cuda,
            dataset, wvclass, cache, datapath, cache_dir, wvpath,
            embedding_size, mode.
    """
    if args.debug:
        debug()

    if args.cuda:
        config = tf.ConfigProto()
        # Allocate GPU memory on demand instead of grabbing it all upfront.
        config.gpu_options.allow_growth = True
    else:
        # Hide all GPUs so TF runs on CPU only.
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = SingleTurnDialog.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:  # identity check, not == None
        wordvec_class = Glove
    if args.cache:
        # Cache both the dataset and the loaded embedding matrix on disk.
        data = try_cache(data_class, (args.datapath, ), args.cache_dir)
        vocab = data.vocab_list
        embed = try_cache(lambda wv, ez, vl: wordvec_class(wv).load(ez, vl),
                          (args.wvpath, args.embedding_size, vocab),
                          args.cache_dir, wordvec_class.__name__)
    else:
        data = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        vocab = data.vocab_list
        embed = wv.load(args.embedding_size, vocab)

    embed = np.array(embed, dtype=np.float32)

    with tf.Session(config=config) as sess:
        model = create_model(sess, data, args, embed)
        if args.mode == "train":
            model.train_process(sess, data, args)
        else:
            model.test_process(sess, data, args)
Пример #2
0
def main(args):
    """Entry point for a SeqGAN-style pipeline: pre-train the generator and
    discriminator, run adversarial training, or dump test metrics.

    Args:
        args: parsed command-line namespace. Fields read here: debug, cuda,
            dataset, wvclass, cache, datapath, cache_dir, wvpath,
            embedding_size, mode, pre_train, dis_pre_epoch_num,
            total_adv_batch, batch_size, test_per_epoch, dis_adv_epoch_num.
    """
    if args.debug:
        debug()
    if args.cuda:
        config = tf.ConfigProto()
        # Allocate GPU memory on demand instead of grabbing it all upfront.
        config.gpu_options.allow_growth = True
    else:
        # Hide all GPUs so TF runs on CPU only.
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = LanguageGeneration.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:  # identity check, not == None
        wordvec_class = Glove
    if args.cache:
        # Cache both the dataset and the loaded embedding matrix on disk.
        data = try_cache(data_class, (args.datapath,), args.cache_dir)
        vocab = data.vocab_list
        embed = try_cache(lambda wv, ez, vl: wordvec_class(wv).load_matrix(ez, vl),
                          (args.wvpath, args.embedding_size, vocab),
                          args.cache_dir, wordvec_class.__name__)
    else:
        data = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        vocab = data.vocab_list
        embed = wv.load_matrix(args.embedding_size, vocab)

    embed = np.array(embed, dtype=np.float32)
    with tf.Session(config=config) as sess:
        generator, discriminator, rollout_gen = create_model(sess, data, args, embed)
        if args.mode == "train":
            if args.pre_train:
                # Pre-train generator and discriminator before the
                # adversarial phase.
                print('Start pre-training generator...')
                generator.pre_train_process(sess, data)

                print('Start pre-training discriminator...')
                discriminator.train_process(generator, data, sess, args.dis_pre_epoch_num)

                # Evaluate the pre-trained generator.
                generator.test_process(sess, data)

            # Adversarial training loop: alternate generator / discriminator.
            for batch in range(args.total_adv_batch):
                print("Adversarial  training %d" % batch)
                print('Start adversarial training generator...')
                generator.adv_train_process(sess, data, rollout_gen, discriminator)
                testout = generator.pre_evaluate(sess, data, args.batch_size, "test")
                # Periodic evaluation (skip batch 0, always include the last).
                if (batch % args.test_per_epoch == 0 or batch == args.total_adv_batch - 1) and batch != 0:
                    print('total_batch: ', batch, 'test_loss: ', testout[0])
                    generator.test_process(sess, data)

                print('Start adversarial training discriminator...')
                discriminator.train_process(generator, data, sess, args.dis_adv_epoch_num)
        else:
            print("Start testing...")
            test_res = generator.test_process(sess, data)
            # Byte values are not JSON-serializable; stringify them first.
            for key, val in test_res.items():
                if isinstance(val, bytes):
                    test_res[key] = str(val)
            # Use a context manager so the file handle is closed (was leaked).
            with open("./result.json", "w") as fout:
                json.dump(test_res, fout)
Пример #3
0
def main(args):
    """Entry point: configure TensorFlow, seed RNGs, load the memory-HRED
    dataset and Tencent Chinese word vectors, then train or test.

    Args:
        args: parsed command-line namespace. Fields read here: debug, cuda,
            seed, cache, cache_dir, datapath, max_sent_length, num_turns,
            max_know_length, wv_path, embedding_size, output_dir, mode.
    """
    if args.debug:
        debug()

    if args.cuda:
        # Default to GPU 0 unless the caller already pinned devices.
        if not "CUDA_VISIBLE_DEVICES" in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = '0'
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
    else:
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # Seed both numpy and TF for reproducibility.
    np.random.seed(args.seed)
    tf.set_random_seed(args.seed)

    data_class = MyMemHRED
    wordvec_class = TencentChinese
    logger.info("模型侧加载数据")
    if args.cache:
        if not os.path.isdir(args.cache_dir):
            os.mkdir(args.cache_dir)
        data = try_cache(
            data_class,
            {
                "file_id": args.datapath,
                "max_sent_length": args.max_sent_length,
                "num_turns": args.num_turns,
                # Fixed typo: was "max_know_legth", which did not match the
                # keyword used by the non-cache branch below.
                "max_know_length": args.max_know_length
            }, args.cache_dir)
        vocab = data.vocab_list
        logger.info("加载词向量")
        embed = try_cache(
            lambda wv, ez, vl: wordvec_class(wv).load_matrix(ez, vl),
            (args.wv_path, args.embedding_size, vocab), args.cache_dir,
            wordvec_class.__name__)
    else:
        data = data_class(file_id=args.datapath,
                          max_sent_length=args.max_sent_length,
                          num_turns=args.num_turns,
                          max_know_length=args.max_know_length)
        logger.info("定义并加载词向量文件")
        wv = wordvec_class(args.wv_path)
        vocab = data.vocab_list
        embed = wv.load_matrix(args.embedding_size, vocab)

    embed = np.array(embed, dtype=np.float32)
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)

    with tf.Session(config=config) as sess:
        model = create_model(sess, data, args, embed)
        if args.mode == "train":
            logger.info("开始训练...")
            model.train_process(sess, data, args)
        else:
            logger.info("开始测试...")
            model.test_process(sess, data, args)
Пример #4
0
def main(args):
    """Entry point: configure TensorFlow, load a multi-turn dialog dataset
    with embeddings and a word2vec dict, then train or run (multi-ref) tests.

    Args:
        args: parsed command-line namespace. Fields read here: debug, cuda,
            dataset, wvclass, cache, datapath, cache_dir, wvpath,
            word_embedding_size, min_frequent_vocab_times, max_sent_length,
            max_turn_length, mode.
    """
    if args.debug:
        debug()

    if args.cuda:
        config = tf.ConfigProto()
        # Allocate GPU memory on demand instead of grabbing it all upfront.
        config.gpu_options.allow_growth = True
    else:
        # Hide all GPUs so TF runs on CPU only.
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = MultiTurnDialog.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:  # identity check, not == None
        wordvec_class = Glove
    if args.cache:
        data = try_cache(data_class, (args.datapath, ), args.cache_dir)
        vocab = data.frequent_vocab_list
        # Cache the embedding matrix and the word->vector dict separately.
        embed = try_cache(
            lambda wv, ez, vl: wordvec_class(wv).load_matrix(ez, vl),
            (args.wvpath, args.word_embedding_size, vocab), args.cache_dir,
            wordvec_class.__name__)
        word2vec = try_cache(
            lambda wv, ez, vl: wordvec_class(wv).load_dict(vl),
            (args.wvpath, args.word_embedding_size, vocab), args.cache_dir,
            wordvec_class.__name__)
    else:
        data = data_class(
            args.datapath,
            min_frequent_vocab_times=args.min_frequent_vocab_times,
            max_sent_length=args.max_sent_length,
            max_turn_length=args.max_turn_length)
        wv = wordvec_class(args.wvpath)
        vocab = data.frequent_vocab_list  #dim:9508
        embed = wv.load_matrix(args.word_embedding_size, vocab)
        word2vec = wv.load_dict(vocab)

    embed = np.array(embed, dtype=np.float32)
    with tf.Session(config=config) as sess:
        model = create_model(sess, data, args, embed)
        if args.mode == "train":
            model.train_process(sess, data, args)
        else:
            multi_ref_res = model.test_multi_ref(sess, data, word2vec, args)
            test_res = model.test_process(sess, data, args)
            test_res.update(multi_ref_res)

            # Byte values are not JSON-serializable; stringify them first.
            for key, val in test_res.items():
                if isinstance(val, bytes):
                    test_res[key] = str(val)
            # Use a context manager so the file handle is closed (was leaked).
            with open("./result.json", "w") as fout:
                json.dump(test_res, fout)
Пример #5
0
def main(args):
    """Entry point: configure TensorFlow, seed RNGs, load the MyLM dataset
    with Tencent Chinese word vectors, then train or test.

    Args:
        args: parsed command-line namespace. Fields read here: debug, cuda,
            cache, cache_dir, datapath, wvpath, embedding_size, out_dir,
            mode.
    """
    if args.debug:
        debug()

    if args.cuda:
        # Default to GPU 0 unless the caller already pinned devices.
        if not "CUDA_VISIBLE_DEVICES" in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = '0'
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
    else:
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # Fixed seed for reproducibility (not configurable in this script).
    np.random.seed(233)
    tf.set_random_seed(233)

    data_class = MyLM
    # Removed dead check: wordvec_class was assigned a class right above,
    # so the original `if wordvec_class == None` branch could never run.
    wordvec_class = TencentChinese
    if args.cache:
        if not os.path.isdir(args.cache_dir):
            os.mkdir(args.cache_dir)
        data = try_cache(data_class, (args.datapath, ), args.cache_dir)
        vocab = data.vocab_list
        embed = try_cache(
            lambda wv, ez, vl: wordvec_class(wv).load_matrix(ez, vl),
            (args.wvpath, args.embedding_size, vocab), args.cache_dir,
            wordvec_class.__name__)
    else:
        data = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        vocab = data.vocab_list
        embed = wv.load_matrix(args.embedding_size, vocab)

    embed = np.array(embed, dtype=np.float32)
    if not os.path.isdir(args.out_dir):
        os.mkdir(args.out_dir)

    with tf.Session(config=config) as sess:
        model = create_model(sess, data, args, embed)
        if args.mode == "train":
            model.train_process(sess, data, args)
        else:
            model.test_process(sess, data, args)
Пример #6
0
def main(args, load_exclude_set, restoreCallback):
    """Entry point: build the LanguageGeneration data manager and run a
    TransformerLM in train / test / load mode.

    Args:
        args: parsed command-line namespace. Fields read here: debug, cuda,
            dataid, tokenizer, max_sent_length, convert_to_lower_letter,
            min_frequent_vocab_times, min_rare_vocab_times, wvpath,
            embedding_size, cache, cache_dir, mode.
        load_exclude_set: variable names to skip when restoring a checkpoint.
        restoreCallback: hook invoked after checkpoint restore.

    Returns:
        The constructed model when mode == "load", otherwise None.

    Raises:
        ValueError: if args.mode is not one of train/test/load.
    """
    # Backslash continuations removed: they are redundant inside parentheses.
    logging.basicConfig(
        filename=0,  # NOTE(review): 0 is an unusual filename value — confirm intent
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
        datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(0, args.cuda)

    volatile = Storage()
    volatile.load_exclude_set = load_exclude_set
    volatile.restoreCallback = restoreCallback

    data_class = LanguageGeneration
    data_arg = Storage()
    data_arg.file_id = args.dataid
    data_arg.tokenizer = args.tokenizer
    data_arg.max_sent_length = args.max_sent_length
    data_arg.convert_to_lower_letter = args.convert_to_lower_letter
    data_arg.min_frequent_vocab_times = args.min_frequent_vocab_times
    data_arg.min_rare_vocab_times = args.min_rare_vocab_times
    wordvec_class = GeneralWordVector

    def load_dataset(data_arg, wvpath, embedding_size):
        # Build the data manager and its embedding matrix in one step so
        # both can be cached together.
        wv = wordvec_class(wvpath)
        dm = data_class(**data_arg)
        return dm, wv.load_matrix(embedding_size, dm.frequent_vocab_list)

    if args.cache:
        dm, volatile.wordvec = try_cache(
            load_dataset, (data_arg, args.wvpath, args.embedding_size),
            args.cache_dir, data_class.__name__ + "_" + wordvec_class.__name__)
    else:
        dm, volatile.wordvec = load_dataset(data_arg, args.wvpath,
                                            args.embedding_size)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    model = TransformerLM(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        test_res = model.test_process()

        # Use a context manager so the file handle is closed (was leaked).
        with open("./result.json", "w") as fout:
            json.dump(test_res, fout)
    elif args.mode == "load":
        return model
    else:
        raise ValueError("Unknown mode")
Пример #7
0
def main(args, load_exclude_set, restoreCallback):
    """Entry point: build a single-turn dialog data manager and run a
    Seq2seq model in train or test mode.

    Args:
        args: parsed command-line namespace. Fields read here: debug, cuda,
            dataset, datapath, wvclass, wvpath, embedding_size, cache,
            cache_dir, mode.
        load_exclude_set: variable names to skip when restoring a checkpoint.
        restoreCallback: hook invoked after checkpoint restore.

    Raises:
        ValueError: if args.mode is not train/test.
    """
    # Backslash continuations removed: they are redundant inside parentheses.
    logging.basicConfig(
        filename=0,  # NOTE(review): 0 is an unusual filename value — confirm intent
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
        datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(0, args.cuda)

    volatile = Storage()
    volatile.load_exclude_set = load_exclude_set
    volatile.restoreCallback = restoreCallback

    data_class = SingleTurnDialog.load_class(args.dataset)
    data_arg = Storage()
    data_arg.file_id = args.datapath
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove

    def load_dataset(data_arg, wvpath, embedding_size):
        # Build the data manager and its embedding matrix in one step so
        # both can be cached together.
        wv = wordvec_class(wvpath)
        dm = data_class(**data_arg)
        return dm, wv.load(embedding_size, dm.vocab_list)

    if args.cache:
        dm, volatile.wordvec = try_cache(
            load_dataset, (data_arg, args.wvpath, args.embedding_size),
            args.cache_dir, data_class.__name__ + "_" + wordvec_class.__name__)
    else:
        dm, volatile.wordvec = load_dataset(data_arg, args.wvpath,
                                            args.embedding_size)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    model = Seq2seq(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        test_res = model.test_process()

        # Byte values are not JSON-serializable; stringify them first.
        for key, val in test_res.items():
            if isinstance(val, bytes):
                test_res[key] = str(val)
        # Use a context manager so the file handle is closed (was leaked).
        with open("./result.json", "w") as fout:
            json.dump(test_res, fout)
    else:
        raise ValueError("Unknown mode")
Пример #8
0
def main(args):
    """Entry point: build the SkeletonGeneration data manager plus its word
    vectors, then run the LM model in train or test mode.

    Args:
        args: parsed command-line namespace. Fields read here: debug, cuda,
            wvclass, cache, datapath, cache_dir, wvpath, embedding_size,
            mode.

    Raises:
        ValueError: if args.mode is not train/test.
    """
    logging.basicConfig(
        filename=0,
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
        datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(0, args.cuda)

    volatile = Storage()
    data_class = SkeletonGeneration
    # Fall back to Glove when no word-vector class is registered for wvclass.
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove

    if args.cache:
        # Cache the data manager and the embedding matrix on disk.
        dm = try_cache(data_class, (args.datapath, ), args.cache_dir)
        volatile.wordvec = try_cache(
            lambda path, size, vocab: wordvec_class(path).load(size, vocab),
            (args.wvpath, args.embedding_size, dm.vocab_list),
            args.cache_dir, wordvec_class.__name__)
    else:
        dm = data_class(args.datapath)
        volatile.wordvec = wordvec_class(args.wvpath).load(
            args.embedding_size, dm.vocab_list)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    model = LM(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        model.test_process()
    else:
        raise ValueError("Unknown mode")
Пример #9
0
def main(args, load_exclude_set, restoreCallback):
    """Entry point: build a LanguageGeneration data manager using a GPT-2
    tokenizer and run a GPT2LM in train or test mode.

    Args:
        args: parsed command-line namespace. Fields read here: debug, cuda,
            dataid, max_sent_length, convert_to_lower_letter, pretrained,
            pretrained_model, cache, cache_dir, mode.
        load_exclude_set: variable names to skip when restoring a checkpoint.
        restoreCallback: hook invoked after checkpoint restore.

    Raises:
        ValueError: if args.mode is not train/test.
    """
    # Backslash continuations removed: they are redundant inside parentheses.
    logging.basicConfig(
        filename=0,  # NOTE(review): 0 is an unusual filename value — confirm intent
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
        datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(0, args.cuda)

    volatile = Storage()
    volatile.load_exclude_set = load_exclude_set
    volatile.restoreCallback = restoreCallback

    data_class = LanguageGeneration
    data_arg = Storage()
    data_arg.file_id = args.dataid
    data_arg.max_sent_length = args.max_sent_length
    data_arg.convert_to_lower_letter = args.convert_to_lower_letter
    data_arg.pretrained = args.pretrained
    # Holds the pretrained model name until load_dataset swaps in the
    # actual tokenizer object.
    data_arg.tokenizer = args.pretrained_model

    def load_dataset(data_arg):
        # Instantiate the tokenizer lazily so the (pickle-unfriendly)
        # tokenizer object is not part of the cache key.
        tokenizer = PretrainedTokenizer(
            GPT2Tokenizer.from_pretrained(data_arg.tokenizer))
        new_arg = Storage(data_arg.copy())
        new_arg.tokenizer = tokenizer
        dm = data_class(**new_arg)
        return dm

    if args.cache:
        dm = try_cache(load_dataset, (data_arg, ), args.cache_dir,
                       data_class.__name__)
    else:
        dm = load_dataset(data_arg)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    model = GPT2LM(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        test_res = model.test_process()

        # Use a context manager so the file handle is closed (was leaked).
        with open("./result.json", "w") as fout:
            json.dump(test_res, fout)
    else:
        raise ValueError("Unknown mode")
Пример #10
0
def main(args, load_exclude_set, restoreCallback):
    """Entry point: build a single-turn dialog data manager tokenized with
    BERT and run a Seq2seq model in train or test mode.

    Args:
        args: parsed command-line namespace. Fields read here: debug, cuda,
            dataset, datapath, bert_vocab, wvclass, wvpath, embedding_size,
            cache, cache_dir, mode.
        load_exclude_set: variable names to skip when restoring a checkpoint.
        restoreCallback: hook invoked after checkpoint restore.

    Raises:
        ValueError: if args.mode is not train/test.
    """
    logging.basicConfig(
        filename=0,
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
        datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(0, args.cuda)

    volatile = Storage()
    volatile.load_exclude_set = load_exclude_set
    volatile.restoreCallback = restoreCallback

    data_class = SingleTurnDialog.load_class(args.dataset)

    # Dataset arguments: OpenSubtitles split of the given data path,
    # tokenized with a pretrained BERT vocabulary.
    data_arg = Storage()
    data_arg.file_id = args.datapath + "#OpenSubtitles"
    data_arg.tokenizer = PretrainedTokenizer(
        BertTokenizer.from_pretrained(args.bert_vocab))
    data_arg.pretrained = "bert"

    # Fall back to Glove when no word-vector class is registered for wvclass.
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove

    def load_dataset(data_arg, wvpath, embedding_size):
        # Build the data manager and its embedding matrix together so both
        # can be cached as one unit.
        wv = wordvec_class(wvpath)
        dm = data_class(**data_arg)
        return dm, wv.load_matrix(embedding_size, dm.frequent_vocab_list)

    if args.cache:
        dm, volatile.wordvec = try_cache(
            load_dataset, (data_arg, args.wvpath, args.embedding_size),
            args.cache_dir, data_class.__name__ + "_" + wordvec_class.__name__)
    else:
        dm, volatile.wordvec = load_dataset(data_arg, args.wvpath,
                                            args.embedding_size)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    model = Seq2seq(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        model.test_process()
    else:
        raise ValueError("Unknown mode")
Пример #11
0
def main(args):
    """Entry point: configure TensorFlow, load a language-generation dataset
    with word vectors, then train or dump test metrics to result.json.

    Args:
        args: parsed command-line namespace. Fields read here: debug, cuda,
            dataset, wvclass, cache, datapath, cache_dir, wvpath,
            embedding_size, mode.
    """
    if args.debug:
        debug()

    if args.cuda:
        config = tf.ConfigProto()
        # Allocate GPU memory on demand instead of grabbing it all upfront.
        config.gpu_options.allow_growth = True
    else:
        # Hide all GPUs so TF runs on CPU only.
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = LanguageGeneration.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:  # identity check, not == None
        wordvec_class = Glove
    if args.cache:
        data = try_cache(data_class, (args.datapath, ), args.cache_dir)
        vocab = data.vocab_list
        embed = try_cache(lambda wv, ez, vl: wordvec_class(wv).load(ez, vl),
                          (args.wvpath, args.embedding_size, vocab),
                          args.cache_dir, wordvec_class.__name__)
    else:
        data = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        vocab = data.vocab_list
        embed = wv.load(args.embedding_size, vocab)

    embed = np.array(embed, dtype=np.float32)

    with tf.Session(config=config) as sess:
        model = create_model(sess, data, args, embed)
        if args.mode == "train":
            model.train_process(sess, data, args)
        else:
            test_res = model.test_process(sess, data, args)

            # Byte values are not JSON-serializable; stringify them first.
            for key, val in test_res.items():
                if isinstance(val, bytes):
                    test_res[key] = str(val)
            # Use a context manager so the file handle is closed (was leaked).
            with open("./result.json", "w") as fout:
                json.dump(test_res, fout)
Пример #12
0
def main(args):
    """Entry point: pick a knowledge-grounded dialog dataset, set up
    per-dataset cache/output/model directories, load word vectors, and run a
    Seq2seq model in train, test, or dev mode.

    Args:
        args: parsed command-line namespace. Fields read here: debug,
            cuda_num, cuda, load_exclude_set, restoreCallback, dataset,
            wvclass, cache_dir, out_dir, model_dir, cache, datapath, wvpath,
            embedding_size, mode.

    Raises:
        ValueError: if args.dataset or args.mode is unrecognized.
    """
    logging.basicConfig(
        filename=0,
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
        datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(args.cuda_num, args.cuda)

    volatile = Storage()
    volatile.load_exclude_set = args.load_exclude_set
    volatile.restoreCallback = args.restoreCallback

    if args.dataset == 'WizardOfWiki':
        data_class = WizardOfWiki
    elif args.dataset == 'HollE':
        data_class = HollE
    else:
        # Was a bare `raise ValueError` — give the caller a reason.
        raise ValueError("Unknown dataset: %s" % args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove

    # Keep cache/output/model artifacts in per-dataset subdirectories.
    if not os.path.exists(args.cache_dir):
        os.mkdir(args.cache_dir)
    args.cache_dir = os.path.join(args.cache_dir, args.dataset)

    if not os.path.exists(args.out_dir):
        os.mkdir(args.out_dir)
    args.out_dir = os.path.join(args.out_dir, args.dataset)

    if not os.path.exists(args.model_dir):
        os.mkdir(args.model_dir)
    if args.dataset not in args.model_dir:
        args.model_dir = os.path.join(args.model_dir, args.dataset)

    if args.cache:
        dm = try_cache(data_class, (args.datapath, ), args.cache_dir)
        volatile.wordvec = try_cache(
            lambda wv, ez, vl: wordvec_class(wv).load_matrix(ez, vl),
            (args.wvpath, args.embedding_size, dm.vocab_list), args.cache_dir,
            wordvec_class.__name__)
    else:
        dm = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        volatile.wordvec = wv.load_matrix(args.embedding_size, dm.vocab_list)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    model = Seq2seq(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        model.test_process()
    elif args.mode == 'dev':
        model.test_dev()
    else:
        raise ValueError("Unknown mode")
Пример #13
0
def main(args, load_exclude_set, restoreCallback):
    """Entry point: build a single-turn dialog data manager and run one of
    several Seq2seq variants (basic / RAML / scheduled-sampling /
    policy-gradient) in train or test mode.

    Args:
        args: parsed command-line namespace. Fields read here: debug, device,
            cuda, dataset, datapath, model, n_samples, wvclass, wvpath,
            embedding_size, cache, cache_dir, mode.
        load_exclude_set: variable names to skip when restoring a checkpoint.
        restoreCallback: hook invoked after checkpoint restore.

    Raises:
        ValueError: if args.mode or args.model is unrecognized.
    """
    logging.basicConfig(
        filename=0,
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
        datefmt='%H:%M:%S')

    if args.debug:
        debug()
    logging.info(json.dumps(args, indent=2))

    cuda_init(args.device, args.cuda)

    volatile = Storage()
    volatile.load_exclude_set = load_exclude_set
    volatile.restoreCallback = restoreCallback

    data_class = SingleTurnDialog.load_class(args.dataset)
    data_arg = Storage()
    data_arg.file_id = args.datapath

    # RAML parameters
    if args.model == "raml":
        data_arg.raml_file = "samples_iwslt14.txt"
        # Fixed: `10 or args.n_samples` always evaluated to 10, so the
        # command-line value was silently ignored. Use it with 10 as the
        # fallback, which matches the apparent intent.
        data_arg.num_samples = args.n_samples or 10
        data_arg.tau = 0.4

    wordvec_class = WordVector.load_class(args.wvclass)

    def load_dataset(data_arg, wvpath, embedding_size):
        # Build the data manager and its embedding matrix together so both
        # can be cached as one unit.
        wv = wordvec_class(wvpath)
        dm = data_class(**data_arg)
        return dm, wv.load_matrix(embedding_size, dm.vocab_list)

    if args.cache:
        dm, volatile.wordvec = try_cache(
            load_dataset, (data_arg, args.wvpath, args.embedding_size),
            args.cache_dir, data_class.__name__ + "_" + wordvec_class.__name__)
    else:
        dm, volatile.wordvec = load_dataset(data_arg, args.wvpath,
                                            args.embedding_size)

    volatile.dm = dm

    param = Storage()
    param.args = args
    param.volatile = volatile

    if args.model == "basic":
        model = Seq2seq(param)
    elif args.model == "raml":
        model = Seq2seqRAML(param)
    elif args.model == "scheduled-sampling":
        model = Seq2seqSS(param)
    elif args.model == "policy-gradient":
        model = Seq2seqPG(param)
    else:
        # Was a silent fall-through leading to NameError on `model`.
        raise ValueError("Unknown model: %s" % args.model)

    if args.mode == "train":
        model.train_process()
    elif args.mode == "test":
        test_res = model.test_process()

        # Use a context manager so the file handle is closed (was leaked).
        with open("./result.json", "w") as fout:
            json.dump(test_res, fout)
    else:
        raise ValueError("Unknown mode")
Пример #14
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bert_config_file",
        default="chinese_wwm_pytorch/bert_config.json",
        type=str,
        help="The config json file corresponding to the pre-trained BERT model. "
        "This specifies the model architecture.")
    parser.add_argument(
        "--vocab_file",
        default="chinese_wwm_pytorch/vocab.txt",
        type=str,
        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument(
        "--init_checkpoint",
        default="chinese_wwm_pytorch/pytorch_model.bin",
        type=str,
        help="Initial checkpoint (usually from a pre-trained BERT model).")

    ## Required parameters
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )
    parser.add_argument(
        "--model_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )
    parser.add_argument("--cache_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="Whether to run training.")

    ## Other parameters
    parser.add_argument('--name',
                        type=str,
                        default='BERTRetrieval',
                        help='name of model')

    parser.add_argument('--dataset',
                        type=str,
                        default='MyBERTRetrieval',
                        help='Dataloader class. Default: OpenSubtitles')
    parser.add_argument(
        '--datapath',
        type=str,
        default='resources://OpenSubtitles',
        help='Directory for data set. Default: resources://OpenSubtitles')
    parser.add_argument("--num_choices",
                        default=10,
                        type=int,
                        help="the number of retrieval options")

    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--cache",
                        action='store_true',
                        help="Whether to run training.")

    parser.add_argument("--train_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
        "of training.")

    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )

    args = parser.parse_args()

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)
    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir, exist_ok=True)

    data_class = MyBERTRetrieval

    def load_dataset(file_id, bert_vocab_name, do_lower_case, num_choices):
        dm = data_class(file_id=file_id,
                        bert_vocab_name=bert_vocab_name,
                        do_lower_case=do_lower_case,
                        num_choices=num_choices)
        return dm

    if args.cache:
        dataManager = try_cache(load_dataset,
                                (args.datapath, args.vocab_file,
                                 args.do_lower_case, args.num_choices),
                                args.cache_dir, data_class.__name__)
    else:
        dataManager = load_dataset(file_id=args.datapath,
                                   bert_vocab_name=args.vocab_file,
                                   do_lower_case=args.do_lower_case,
                                   num_choices=args.num_choices)

    if not os.path.exists(os.path.join(args.datapath,
                                       'test_distractors.json')):
        test_distractors = dataManager.data['test']['resp_distractors']
        with open(os.path.join(args.datapath, 'test_distractors.json'),
                  'w') as f:
            json.dump(test_distractors, f, ensure_ascii=False, indent=4)

    if args.do_train:

        if not args.no_cuda:
            if not "CUDA_VISIBLE_DEVICES" in os.environ:
                os.environ["CUDA_VISIBLE_DEVICES"] = '2,3'

        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
        logger.info("device: {} n_gpu: {}".format(device, n_gpu))

        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))

        args.train_batch_size = int(args.train_batch_size /
                                    args.gradient_accumulation_steps)

        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)

        # Total number of optimizer updates: one update per
        # `gradient_accumulation_steps` mini-batches, repeated for every epoch.
        num_train_steps = None
        logger.info("train examples {}".format(
            len(dataManager.data['train']['resp'])))
        num_train_steps = int(
            len(dataManager.data['train']['resp']) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

        # Prepare model
        '''
        if os.path.exists(output_model_file):
            model_state_dict = torch.load(output_model_file)
            model = BERTRetrieval(num_choices=args.num_choices, bert_config_file=args.bert_config_file)
            model.load_state_dict(model_state_dict)
        '''
        model = BERTRetrieval(num_choices=args.num_choices,
                              bert_config_file=args.bert_config_file)
        if args.init_checkpoint is not None:
            # Initialize the encoder from a pre-trained BERT checkpoint.
            # Loaded on CPU first; the model is moved to `device` below.
            logger.info('load bert weight')
            state_dict = torch.load(args.init_checkpoint, map_location='cpu')
            missing_keys = []
            unexpected_keys = []
            error_msgs = []
            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, '_metadata', None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            def load(module, prefix=''):
                # Recursively copy matching tensors from `state_dict` into
                # `module` and its children.  Mismatches are accumulated in
                # `missing_keys` / `unexpected_keys` / `error_msgs` instead of
                # raising, so a partial load is tolerated and merely logged.
                local_metadata = {} if metadata is None else metadata.get(
                    prefix[:-1], {})

                module._load_from_state_dict(state_dict, prefix,
                                             local_metadata, True,
                                             missing_keys, unexpected_keys,
                                             error_msgs)
                for name, child in module._modules.items():
                    # logger.info("name {} chile {}".format(name,child))
                    if child is not None:
                        load(child, prefix + name + '.')

            # NOTE(review): presumably the checkpoint's parameter names carry a
            # 'bert.' prefix — when the model exposes a `.bert` submodule the
            # recursion produces that prefix naturally, otherwise it is
            # supplied explicitly.  BERTRetrieval is defined outside this
            # chunk; confirm against its definition.
            load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
            logger.info("missing keys:{}".format(missing_keys))
            logger.info('unexpected keys:{}'.format(unexpected_keys))
            logger.info('error msgs:{}'.format(error_msgs))

        model.to(device)
        model = torch.nn.DataParallel(model)

        # Prepare optimizer
        param_optimizer = list(model.named_parameters())

        # hack to remove pooler, which is not used
        # thus it produce None grad that break apex
        param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

        # Standard BERT fine-tuning recipe: apply weight decay to all
        # parameters except biases and LayerNorm weights.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        t_total = num_train_steps
        # The effective learning rate is rescheduled manually each update in
        # the training loop below (linear warmup); `lr` here is the base rate.
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate)
        global_step = 0

        logger.info("***** Running training *****")
        logger.info("  Num post-response pairs = %d",
                    len(dataManager.data['train']['resp']))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

        model.train()
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            model.zero_grad()
            # Restart the train iterator each epoch; batches are pulled with
            # get_next_batch until it returns None (epoch exhausted).
            dataManager.restart(key='train', batch_size=args.train_batch_size)
            data = dataManager.get_next_batch(key='train')
            step = 0
            loss_value = 0
            while data is not None:
                if n_gpu == 1:
                    preprocess_batch(
                        data, device)  # multi-gpu does scattering it-self
                else:
                    preprocess_batch(data)
                loss = model(data, data['labels'])
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                # Scale down for gradient accumulation, but accumulate the
                # *unscaled* loss into loss_value for logging.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss_value += loss.cpu().item(
                ) * args.gradient_accumulation_steps
                loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                # Log the loss accumulated over the last 100 optimizer updates.
                if (step + 1) % (args.gradient_accumulation_steps * 100) == 0:
                    logger.info("step: %d, loss: %f" % (step + 1, loss_value))
                    loss_value = 0

                step += 1
                data = dataManager.get_next_batch(key='train')

            # Checkpoint file name encodes (total_epochs, epoch_just_finished);
            # the predict branch below loads one of these by the same scheme.
            output_model_file = os.path.join(
                args.model_dir, "pytorch_model.%d.%d.bin" %
                (int(args.num_train_epochs), epoch + 1))

            # Save a trained model
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned

    if args.do_predict:

        total_epoch = int(args.num_train_epochs)
        # NOTE(review): evaluation checkpoint is hard-coded to epoch 5 rather
        # than the final epoch — confirm this is intentional.
        chosen_epoch = 5  # int(args.num_train_epochs)

        if not args.no_cuda:
            if not "CUDA_VISIBLE_DEVICES" in os.environ:
                os.environ["CUDA_VISIBLE_DEVICES"] = '3'

        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()

        random.seed(args.seed)

        # Load the fine-tuned checkpoint saved by the training branch above
        # (same "pytorch_model.<total>.<epoch>.bin" naming scheme).
        output_model_file = os.path.join(
            args.model_dir,
            "pytorch_model.%d.%d.bin" % (total_epoch, chosen_epoch))

        model_state_dict = torch.load(output_model_file)
        model = BERTRetrieval(num_choices=args.num_choices,
                              bert_config_file=args.bert_config_file)
        model.load_state_dict(model_state_dict)
        model.to(device)

        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        metric = MyMetrics()

        logger.info("***** Running testing *****")
        logger.info("  Num post-response pairs = %d",
                    len(dataManager.data['test']['resp']))
        logger.info("  Batch size = %d", args.predict_batch_size)

        model.eval()
        logger.info("Start evaluating")
        dataManager.restart(key='test',
                            batch_size=args.predict_batch_size,
                            shuffle=False)
        data = dataManager.get_next_batch(key='test')

        gens = []     # chosen (predicted) responses
        gold = []     # ground-truth responses
        choices = []  # full candidate lists per example

        # hits@k counters: key -> [num_correct_in_top_k, num_examples]
        hits = {1: [0, 0], 3: [0, 0]}
        while data is not None:
            preprocess_batch(data, device)
            truth_response, can_responses = data['resp'], data['can_resps']

            with torch.no_grad():
                prob, pred = model(data)

            assert len(pred) == len(truth_response)
            assert len(pred) == len(can_responses)
            assert len(can_responses[0]) == args.num_choices

            for truth, pd, cans, prb in zip(truth_response, pred,
                                            can_responses, prob):
                metric.forword(truth, cans[pd])

                gold.append(truth)
                gens.append(cans[pd])
                choices.append(cans)

                # NOTE(review): assumes the ground-truth response is always
                # present in the candidate list; .index raises ValueError
                # otherwise — confirm the data pipeline guarantees this.
                idx = cans.index(truth)
                # argsort is ascending, so the last `key` indices are the
                # top-k scored candidates.
                p_sort = np.argsort(prb)
                for key, count in hits.items():
                    if idx in p_sort[-key:]:
                        count[0] += 1
                    count[1] += 1

            data = dataManager.get_next_batch(key='test')

        result = metric.close()
        result.update({
            'hits@%d' % key: value[0] / value[1]
            for key, value in hits.items()
        })

        # Write metrics (sorted by name) followed by per-example
        # gold/generated/candidate dumps.
        output_prediction_file = args.output_dir + "/%s_%s.%d.%d.txt" % (
            args.name, "test", total_epoch, chosen_epoch)
        with open(output_prediction_file, "w") as f:
            print("Test Result:")
            res_print = list(result.items())
            res_print.sort(key=lambda x: x[0])
            for key, value in res_print:
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            f.write('\n')

            for resp, gen, options in zip(gold, gens, choices):
                f.write("resp:\t%s\n" % resp)
                f.write("gen:\t%s\n\n" % gen)
                for i, option in enumerate(options):
                    f.write("candidate %d:\t%s\n" % (i, option))
                f.write("\n")