def main():
    """Train a ``Prover`` model end to end from command-line options.

    Steps:
      1. Parse options with ``parse_args()``.
      2. Build the ``'train'`` and ``'valid'`` dataloaders. When
         ``opts.no_validation`` is set, the training loader is built from the
         ``'train_valid'`` split instead of ``'train'``.
      3. Create the model and an RMSprop optimizer.
      4. Choose the learning-rate scheduler: ``StepLR`` (multiply the LR by
         0.1 every ``opts.lr_reduce_steps`` scheduler steps, i.e. epochs,
         since ``scheduler.step()`` is called once per epoch) when there is
         no validation; otherwise ``ReduceLROnPlateau`` driven by the
         validation loss.
      5. If ``opts.resume`` is given, restore model and optimizer state from
         that checkpoint and continue from the epoch after the saved one.
      6. For ``opts.num_epochs`` epochs: train, save a checkpoint every
         ``opts.save_model_epochs`` epochs, validate (if enabled), and step
         the scheduler.
    """
    # parse the options
    opts = parse_args()

    # create the dataloaders
    train_split = 'train_valid' if opts.no_validation else 'train'
    dataloader = {'train': create_dataloader(train_split, opts),
                  'valid': create_dataloader('valid', opts)}

    # create the model
    model = Prover(opts)
    model.to(opts.device)

    # create the optimizer and the learning-rate scheduler
    optimizer = torch.optim.RMSprop(model.parameters(), lr=opts.learning_rate,
                                    momentum=opts.momentum,
                                    weight_decay=opts.l2)
    if opts.no_validation:
        # no validation loss available -> fixed step decay
        scheduler = StepLR(optimizer, step_size=opts.lr_reduce_steps, gamma=0.1)
    else:
        # decay when the validation loss stops improving
        scheduler = ReduceLROnPlateau(optimizer, patience=opts.lr_reduce_patience, verbose=True)

    # load the checkpoint (optional)
    start_epoch = 0
    if opts.resume is not None:
        log('loading model checkpoint from %s..' % opts.resume)
        # map tensors to CPU when running on CPU; otherwise keep torch.load's
        # default (map_location=None) behaviour.
        # NOTE: torch.load unpickles the file -- only resume from trusted checkpoints.
        map_location = 'cpu' if opts.device.type == 'cpu' else None
        checkpoint = torch.load(opts.resume, map_location=map_location)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['n_epoch'] + 1
        model.to(opts.device)
        # NOTE(review): scheduler state is not restored here, so a resumed
        # StepLR / ReduceLROnPlateau starts counting from scratch -- confirm
        # whether the checkpoint should also carry scheduler state.

    agent = Agent(model, optimizer, dataloader, opts)

    for n_epoch in range(start_epoch, start_epoch + opts.num_epochs):
        log('EPOCH #%d' % n_epoch)

        # training (the returned loss is not used here)
        agent.train(n_epoch)

        # save a model checkpoint every `save_model_epochs` epochs
        if n_epoch % opts.save_model_epochs == 0:
            agent.save(n_epoch, opts.checkpoint_dir)

        # validate (if enabled), then reduce the learning rate;
        # ReduceLROnPlateau needs the validation loss, StepLR does not
        if opts.no_validation:
            scheduler.step()
        else:
            loss_valid = agent.valid(n_epoch)
            scheduler.step(loss_valid)
log("using CPU", "WARNING") torch.manual_seed(opts.seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False np.random.seed(opts.seed) random.seed(opts.seed) if "ours" in opts.method: model = Prover(opts) log("loading model checkpoint from %s.." % opts.path) if opts.device.type == "cpu": checkpoint = torch.load(opts.path, map_location="cpu") else: checkpoint = torch.load(opts.path) model.load_state_dict(checkpoint["state_dict"]) model.to(opts.device) else: model = None agent = Agent(model, None, None, opts) if opts.file: files = [opts.file] else: files = [] projs = json.load(open(opts.projs_split))["projs_" + opts.split] for proj in projs: files.extend( glob(os.path.join(opts.datapath, "%s/**/*.json" % proj), recursive=True))
"weak-up-to", "buchberger", "jordan-curve-theorem", "dblib", "disel", "zchinese", "zfc", "dep-map", "chinese", "UnifySL", "hoare-tut", "huffman", "PolTac", "angles", "coq-procrastination", "coq-library-undecidability", "tree-automata", "coquelicot", "fermat4", "demos", "coqoban", "goedel", "verdi-raft", "verdi", "zorns-lemma", "coqrel", "fundamental-arithmetics" ] if 'ours' in opts.method: model = Prover(opts) log('loading model checkpoint from %s..' % opts.path) if opts.device.type == 'cpu': checkpoint = torch.load(opts.path, map_location='cpu') else: checkpoint = torch.load(opts.path) model.load_state_dict(checkpoint['state_dict']) model.to(opts.device) else: model = None agent = Agent(model, None, None, opts) if opts.file: files = [opts.file] elif opts.proj_idx is not None: files = glob(os.path.join(opts.datapath, '%s/**/*.json' % projs_test[opts.proj_idx]), recursive=True) else: files = [] projs = json.load(open(opts.projs_split))['projs_' + opts.split]