Example #1
0
def dl_init():
    opt = Config()
    try:
        if opt.MODEL == 'MiracleWeightWideNet':
            net = miracle_weight_wide_net.MiracleWeightWideNet(opt)
        elif opt.MODEL == 'MiracleWideNet':
            net = miracle_wide_net.MiracleWideNet(opt)
        elif opt.MODEL == 'MiracleNet':
            net = miracle_net.MiracleNet(opt)
        elif opt.MODEL == 'MiracleLineConvNet':
            net = miracle_lineconv_net.MiracleLineConvNet(opt)
        else:
            raise KeyError('Your model is not found.')
    except KeyError as err:
        print(err)
        exit(0)
    else:
        print("==> Model initialized successfully.")
    net_save_prefix = opt.NET_SAVE_PATH + opt.MODEL + '_' + opt.PROCESS_ID + '/'
    temp_model_name = net_save_prefix + "best_model.dat"
    if os.path.exists(temp_model_name):
        net, *_ = net.load(temp_model_name)
        print("Load existing model: %s" % temp_model_name)
        if opt.USE_CUDA:
            net.cuda()
            print("==> Using CUDA.")
    else:
        raise FileNotFoundError(temp_model_name)
    return opt, net
Example #2
0
def dl_init():
    opt = Config()
    if opt.MODEL == 'MiracleWeightWideNet':
        net = miracle_weight_wide_net.MiracleWeightWideNet(opt)
    elif opt.MODEL == 'MiracleWideNet':
        net = miracle_wide_net.MiracleWideNet(opt)
    elif opt.MODEL == 'MiracleNet':
        net = miracle_net.MiracleNet(opt)
    elif opt.MODEL == 'MiracleLineConvNet':
        net = miracle_lineconv_net.MiracleLineConvNet(opt)
    else:
        raise KeyError('Your model is not found.')

    net_save_prefix = opt.NET_SAVE_PATH + opt.MODEL + '_' + opt.PROCESS_ID + '/'
    temp_model_name = net_save_prefix + "best_model.dat"
    if os.path.exists(temp_model_name):
        net, *_ = net.load(temp_model_name)
        print("Load existing model: %s" % temp_model_name)
    else:
        raise FileNotFoundError(temp_model_name)
    return opt, net
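
Both dl_init variants above repeat the same if/elif chain to pick the network class. A minimal sketch of a table-driven alternative is given below; it assumes the same module and class names as the examples, and MODEL_REGISTRY and build_net are hypothetical names introduced only for illustration.

MODEL_REGISTRY = {
    'MiracleWeightWideNet': miracle_weight_wide_net.MiracleWeightWideNet,
    'MiracleWideNet': miracle_wide_net.MiracleWideNet,
    'MiracleNet': miracle_net.MiracleNet,
    'MiracleLineConvNet': miracle_lineconv_net.MiracleLineConvNet,
}

def build_net(opt):
    # Look up the network class by name; unknown values of opt.MODEL fail loudly.
    try:
        net_cls = MODEL_REGISTRY[opt.MODEL]
    except KeyError:
        raise KeyError("Your model '%s' is not found." % opt.MODEL)
    return net_cls(opt)
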
Example #3
0
def main():
    # Initializing configs
    allDataset = None
    all_loader = None
    # opt = Config()
    folder_init(opt)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load data
    train_pairs, test_pairs = load_data(opt, './TempData/')

    trainDataset = POISSON(train_pairs, opt)
    train_loader = DataLoader(dataset=trainDataset,
                              batch_size=opt.BATCH_SIZE,
                              shuffle=True,
                              num_workers=opt.NUM_WORKERS,
                              drop_last=False)

    testDataset = POISSON(test_pairs, opt)
    test_loader = DataLoader(dataset=testDataset,
                             batch_size=opt.TEST_BATCH_SIZE,
                             shuffle=False,
                             num_workers=opt.NUM_WORKERS,
                             drop_last=False)

    if opt.TRAIN_ALL or opt.TEST_ALL:
        train_pairs.extend(test_pairs)
        all_pairs = train_pairs
        allDataset = POISSON(all_pairs, opt)
        all_loader = DataLoader(dataset=allDataset,
                                batch_size=opt.TEST_BATCH_SIZE,
                                shuffle=True,
                                num_workers=opt.NUM_WORKERS,
                                drop_last=False)

    opt.NUM_TEST = len(testDataset)
    print("==> All datasets are generated successfully.")

    # Initialize the chosen model
    try:
        if opt.MODEL == 'MiracleWeightWideNet':
            net = miracle_weight_wide_net.MiracleWeightWideNet(opt)
        elif opt.MODEL == 'MiracleWideNet':
            net = miracle_wide_net.MiracleWideNet(opt)
        elif opt.MODEL == 'MiracleNet':
            net = miracle_net.MiracleNet(opt)
        elif opt.MODEL == 'MiracleLineConvNet':
            net = miracle_lineconv_net.MiracleLineConvNet(opt)
        else:
            raise KeyError('==> Your model is not found.')
    except KeyError as err:
        print(err)
        exit(0)
    else:
        print("==> Model initialized successfully.")

    # Instantiate the TensorBoard writer and add the network graph to it
    writer = SummaryWriter(opt.SUMMARY_PATH)
    dummy_input = Variable(torch.rand(opt.BATCH_SIZE, 2, 9, 41))
    writer.add_graph(net, dummy_input)

    # Start training or testing
    if opt.TEST_ALL:
        results = []
        net, *_ = load_model(net, device, "best_model.dat")
        results = test_all(opt, all_loader, net, results, device)
        out_file = './source/val_results/' + opt.MODEL + '_' + opt.PROCESS_ID + '_results.pkl'
        pickle.dump(results, open(out_file, 'wb+'))
    else:
        pre_epoch = 0
        best_loss = 100
        if opt.LOAD_SAVED_MOD:
            try:
                net, pre_epoch, best_loss = load_model(net, device,
                                                       "temp_model.dat")
            except FileNotFoundError:
                net = model_to_device(net, device)
        else:
            net = model_to_device(net, device)
        if opt.TRAIN_ALL:
            opt.NUM_TRAIN = len(allDataset)
            _ = training(opt, writer, all_loader, test_loader, net, pre_epoch,
                         device, best_loss)
        else:
            opt.NUM_TRAIN = len(trainDataset)
            _ = training(opt, writer, train_loader, test_loader, net,
                         pre_epoch, device, best_loss)
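
The main() above relies on a load_model helper that is not shown. From the call sites it takes (net, device, file_name), returns (net, pre_epoch, best_loss), and raises FileNotFoundError when no checkpoint exists. A minimal sketch under those assumptions follows; the checkpoint directory and the checkpoint dictionary layout are invented for illustration.

def load_model(net, device, file_name, save_dir='./source/trained_net/'):
    # Hypothetical loader: the save_dir default and checkpoint keys are assumptions.
    model_path = os.path.join(save_dir, file_name)
    if not os.path.exists(model_path):
        raise FileNotFoundError(model_path)
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint['state_dict'])
    net.to(device)
    return net, checkpoint.get('epoch', 0), checkpoint.get('best_loss', float('inf'))
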