def main():
    """Evaluate the best saved TransformerPatchingModel on the test split.

    Parses the command line, builds the data reader and model from the
    module-level ``config``, restores the best checkpoint through
    ``Tracker``, and writes beam-search predictions for every test batch to
    ``results.txt`` (or ``results-<suffix>.txt`` when ``--suffix`` is given).
    """
    # Extract arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("data", help="Data file containing bugs")
    ap.add_argument("vocabulary", help="Vocabulary file")
    ap.add_argument("-s", "--suffix", help="Model and log-file suffix")
    args = ap.parse_args()

    data = DataReader(config["data"], data_file=args.data, vocab_path=args.vocabulary)
    model = TransformerPatchingModel(
        config["transformer"],
        data.vocabulary.vocab_dim,
        is_pointer=config["data"]["edits"],
    )

    # Restore model after a simple init: a single dummy forward pass runs
    # before restoring, presumably so the model's variables exist for the
    # restore to populate (Keras subclassed models build lazily) -- TODO confirm.
    # NOTE(review): the trailing ``True`` looks like a training flag; verify
    # against TransformerPatchingModel.call.
    tracker = Tracker(model, suffix=args.suffix)
    model(
        tf.zeros((1, 2), 'int32'),
        tf.zeros((1, 2), 'int32'),
        tf.zeros((1, 2), 'int32'),
        tf.zeros((0, 0), 'int32'),
        True,
    )
    tracker.restore(best_only=True)

    results_path = "results" + ("" if args.suffix is None else "-" + args.suffix) + ".txt"
    # Explicit UTF-8 so the output does not depend on the platform default
    # encoding (e.g. cp1252 on Windows), which can fail on non-ASCII tokens.
    with open(results_path, "w", encoding="utf-8") as f_out:
        for batch in data.batcher(mode="test", optimize_packing=False):
            # Only the first two batch elements are needed for prediction.
            pre, pre_locs = batch[:2]
            preds = model.predict(
                data.vocabulary,
                pre,
                pre_locs,
                config["data"]["beam_size"],
                config["data"]["max_bug_length"],
            )
            write_completions(f_out, data.vocabulary, pre.numpy(), pre_locs.numpy(), preds)