示例#1
0
def main():
    """Train a TBert model of type 'siamese2'.

    Reads the training arguments, initialises the training environment and
    model, loads the validation split and then the training split, and runs
    ``train`` with ``train_with_neg_sampling`` as the iteration method.
    """
    args = get_train_args()
    model = init_train_env(args, tbert_type='siamese2')

    # Keyword arguments shared by both load_examples calls.
    common = {"model": model, "overwrite": args.overwrite}
    valid_examples = load_examples(
        args.data_dir, data_type="valid", num_limit=args.valid_num, **common
    )
    train_examples = load_examples(
        args.data_dir, data_type="train", num_limit=args.train_num, **common
    )

    train(
        args,
        train_examples,
        valid_examples,
        model,
        train_iter_method=train_with_neg_sampling,
    )
    logger.info("Training finished")
示例#2
0
def main():
    """Train a TBert model of type 'single'.

    Reads the training arguments, initialises the training environment and
    model, loads the validation split and then the training split, and runs
    ``train`` with ``train_single_iteration`` as the iteration method.
    """
    args = get_train_args()
    model = init_train_env(args, tbert_type='single')

    def _load(split, limit):
        # Load one data split for this model, honouring --overwrite.
        return load_examples(args.data_dir,
                             data_type=split,
                             model=model,
                             num_limit=limit,
                             overwrite=args.overwrite)

    valid_examples = _load("valid", args.valid_num)
    train_examples = _load("train", args.train_num)

    train(args, train_examples, valid_examples, model, train_single_iteration)
    logger.info("Training finished")
示例#3
0
    # Evaluation setup: choose the device (CUDA unless unavailable or
    # disabled via --no_cuda) and resolve output / cache paths.
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    # NOTE(review): res_file and cached_file are not used in the visible part
    # of this function; presumably consumed further down — verify.
    res_file = os.path.join(args.output_dir, "./raw_res.csv")

    cache_dir = os.path.join(args.data_dir, "cache")
    # Plain file name: no placeholders, so no str.format() call is needed.
    cached_file = os.path.join(cache_dir, "test_examples_cache.dat")

    logging.basicConfig(level='INFO')
    logger = logging.getLogger(__name__)

    # exist_ok avoids the check-then-create race; it still raises if the path
    # exists but is not a directory, matching the previous behaviour.
    os.makedirs(args.output_dir, exist_ok=True)

    model = TBertI2(BertConfig(), args.code_bert)
    if args.model_path and os.path.exists(args.model_path):
        model_path = os.path.join(args.model_path, MODEL_FNAME)
        # map_location remaps stored tensors onto the selected device, so a
        # checkpoint saved on GPU can be restored on a CPU-only host (or with
        # --no_cuda) instead of torch.load failing on deserialisation.
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        # Don't evaluate silently without weights from a checkpoint.
        logger.warning("No checkpoint loaded (model_path=%r); evaluating the "
                       "model as constructed", args.model_path)

    logger.info("model loaded")
    start_time = time.time()
    test_examples = load_examples(args.data_dir,
                                  data_type="test",
                                  model=model,
                                  overwrite=args.overwrite,
                                  num_limit=args.test_num)
    test_examples.update_embd(model)
    # NOTE(review): the cache tag says "siamese2" while the model built here
    # is TBertI2 — confirm this tag is intended.
    m = test(args, model, test_examples, "cached_siamese2_test")
    exe_time = time.time() - start_time
    m.write_summary(exe_time)