コード例 #1
0
def main():
    """Train a Prover model.

    Builds the train/valid dataloaders, the RMSprop optimizer and an LR
    scheduler, optionally resumes from a checkpoint, then runs the
    train / save / validate / LR-step loop for ``opts.num_epochs`` epochs.
    """
    # parse the command-line options
    opts = parse_args()

    # create the dataloaders; without a validation split, train on train+valid
    dataloader = {'train': create_dataloader('train_valid' if opts.no_validation else 'train', opts),
                  'valid': create_dataloader('valid', opts)}

    # create the model
    model = Prover(opts)
    model.to(opts.device)

    # create the optimizer  (fixed typo: was "crete")
    optimizer = torch.optim.RMSprop(model.parameters(), lr=opts.learning_rate,
                                    momentum=opts.momentum,
                                    weight_decay=opts.l2)
    # with no validation set, decay the LR on a fixed step schedule;
    # otherwise reduce it when the validation loss plateaus
    if opts.no_validation:
        scheduler = StepLR(optimizer, step_size=opts.lr_reduce_steps, gamma=0.1)
    else:
        scheduler = ReduceLROnPlateau(optimizer, patience=opts.lr_reduce_patience, verbose=True)

    # resume from a checkpoint if one was given
    start_epoch = 0
    if opts.resume is not None:  # idiomatic None test (was `!= None`)
        log('loading model checkpoint from %s..' % opts.resume)
        # map GPU-saved tensors onto the CPU when no GPU is available;
        # map_location=None keeps torch.load's default device behavior
        map_location = 'cpu' if opts.device.type == 'cpu' else None
        checkpoint = torch.load(opts.resume, map_location=map_location)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['n_epoch'] + 1
        model.to(opts.device)

    agent = Agent(model, optimizer, dataloader, opts)

    for n_epoch in range(start_epoch, start_epoch + opts.num_epochs):
        log('EPOCH #%d' % n_epoch)

        # training (the returned training loss was never used; dropped)
        agent.train(n_epoch)

        # save the model checkpoint periodically
        if n_epoch % opts.save_model_epochs == 0:
            agent.save(n_epoch, opts.checkpoint_dir)

        # validation
        if not opts.no_validation:
            loss_valid = agent.valid(n_epoch)

        # step the scheduler; ReduceLROnPlateau needs the validation loss
        if opts.no_validation:
            scheduler.step()
        else:
            scheduler.step(loss_valid)
コード例 #2
0
        log("using CPU", "WARNING")

    # make runs reproducible: seed every RNG the pipeline touches and
    # force deterministic (non-benchmarked) cuDNN kernels
    torch.manual_seed(opts.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(opts.seed)
    random.seed(opts.seed)

    # only the "ours" methods need a trained Prover; other methods
    # (presumably baselines) run with model = None
    if "ours" in opts.method:
        model = Prover(opts)
        log("loading model checkpoint from %s.." % opts.path)
        # map GPU-saved tensors onto the CPU when no GPU is available
        if opts.device.type == "cpu":
            checkpoint = torch.load(opts.path, map_location="cpu")
        else:
            checkpoint = torch.load(opts.path)
        model.load_state_dict(checkpoint["state_dict"])
        model.to(opts.device)
    else:
        model = None

    # evaluation only: the Agent gets no optimizer and no dataloaders
    agent = Agent(model, None, None, opts)

    # collect the .json files to evaluate: either one explicit file, or
    # every file (recursively) of every project in the chosen split
    if opts.file:
        files = [opts.file]
    else:
        files = []
        projs = json.load(open(opts.projs_split))["projs_" + opts.split]
        for proj in projs:
            files.extend(
                glob(os.path.join(opts.datapath, "%s/**/*.json" % proj),
                     recursive=True))
コード例 #3
0
ファイル: evaluate.py プロジェクト: brando90/TacTok
        "weak-up-to", "buchberger", "jordan-curve-theorem", "dblib", "disel",
        "zchinese", "zfc", "dep-map", "chinese", "UnifySL", "hoare-tut",
        "huffman", "PolTac", "angles", "coq-procrastination",
        "coq-library-undecidability", "tree-automata", "coquelicot", "fermat4",
        "demos", "coqoban", "goedel", "verdi-raft", "verdi", "zorns-lemma",
        "coqrel", "fundamental-arithmetics"
    ]  # tail of the test-project list (its opening lines are above this chunk)

    # only the 'ours' methods need a trained Prover; other methods
    # (presumably baselines) run with model = None
    if 'ours' in opts.method:
        model = Prover(opts)
        log('loading model checkpoint from %s..' % opts.path)
        # map GPU-saved tensors onto the CPU when no GPU is available
        if opts.device.type == 'cpu':
            checkpoint = torch.load(opts.path, map_location='cpu')
        else:
            checkpoint = torch.load(opts.path)
        model.load_state_dict(checkpoint['state_dict'])
        model.to(opts.device)
    else:
        model = None

    # evaluation only: the Agent gets no optimizer and no dataloaders
    agent = Agent(model, None, None, opts)

    # collect the .json files to evaluate: one explicit file, a single
    # project selected by index, or (below) every project in the split
    if opts.file:
        files = [opts.file]
    elif opts.proj_idx is not None:
        files = glob(os.path.join(opts.datapath,
                                  '%s/**/*.json' % projs_test[opts.proj_idx]),
                     recursive=True)
    else:
        files = []
        projs = json.load(open(opts.projs_split))['projs_' + opts.split]