Exemplo n.º 1
0
    # Register the remaining command-line options; the positional order of
    # each DEFINE_string call is (name, default, help text).
    # NOTE(review): the model_path default is an absolute path on one
    # developer's Windows machine -- pass --model_path on other hosts.
    flags.DEFINE_string("model_path",
                        "C:/Users/lee/Desktop/leequant761/siamese-pytorch/models",
                        "path to store model")
    flags.DEFINE_string(
        "gpu_ids",
        "0,1,2,3",
        "gpu ids used to train",
    )

    # Parse the process's command-line arguments into the Flags object;
    # option values (e.g. Flags.train_path) are read from it afterwards.
    Flags(sys.argv)

    # Training-time augmentation: a random affine warp (degrees=15), then
    # conversion to a tensor. The test set below only gets ToTensor().
    data_transforms = transforms.Compose([
        transforms.RandomAffine(15),
        transforms.ToTensor(),
    ])

    # NOTE(review): the gpu_ids flag is defined but not applied in the
    # visible code; the lines that would restrict the visible devices are
    # intentionally left disabled:
    # os.environ["CUDA_VISIBLE_DEVICES"] = Flags.gpu_ids
    # print("use gpu:", Flags.gpu_ids, "to train.")

    # Datasets: `times` and `way` are forwarded to OmniglotTest unchanged
    # (their exact meaning is defined by OmniglotTest -- not visible here).
    trainSet = OmniglotTrain(Flags.train_path, transform=data_transforms)
    testSet = OmniglotTest(
        Flags.test_path,
        transform=transforms.ToTensor(),
        times=Flags.times,
        way=Flags.way,
    )

    # NOTE(review): the training loader does not shuffle; presumably
    # OmniglotTrain samples its pairs randomly itself -- confirm.
    trainLoader = DataLoader(
        trainSet,
        batch_size=Flags.batch_size,
        shuffle=False,
        num_workers=Flags.workers,
    )
    # Each test batch holds exactly Flags.way samples.
    testLoader = DataLoader(
        testSet,
        batch_size=Flags.way,
        shuffle=False,
        num_workers=Flags.workers,
    )

    # Binary cross-entropy computed on raw logits, averaged over elements.
    loss_fn = torch.nn.BCEWithLogitsLoss(reduction='mean')
Exemplo n.º 2
0
# True when torch can see a CUDA-capable device.
cuda = torch.cuda.is_available()

# Training-time augmentation: a random affine warp (degrees=15), then
# conversion to a tensor. Evaluation images are only converted to tensors.
data_transforms = transforms.Compose(
    [transforms.RandomAffine(15),
     transforms.ToTensor()])

# Assuming you have run make_dataset.py as specified.
# Both roots are relative paths, resolved against the working directory.
train_path = 'background'
test_path = 'evaluation'
train_dataset = dset.ImageFolder(root=train_path)
test_dataset = dset.ImageFolder(root=test_path)

# Evaluation settings forwarded to OmniglotTest; `way` is also the test
# batch size. Presumably `way` is the number of candidates per trial and
# `times` the number of trials -- confirm against OmniglotTest.
way = 20
times = 400

dataSet = OmniglotTrain(train_dataset, transform=data_transforms)
testSet = OmniglotTest(test_dataset,
                       transform=transforms.ToTensor(),
                       times=times,
                       way=way)
testLoader = DataLoader(testSet, batch_size=way, shuffle=False, num_workers=16)

# NOTE(review): `cmd` is defined outside this view (presumably parsed
# command-line options). The open parenthesis already continues the line,
# so no backslash continuation is needed.
dataLoader = DataLoader(dataSet, batch_size=cmd.trainBatch,
                        shuffle=False, num_workers=16)

# Build the network. Siamese is defined outside this view (presumably the
# project's model module) and is constructed here with no arguments.
net = Siamese()
# Loss criterion: binary cross-entropy computed on raw logits (the sigmoid
# is folded into the loss), averaged over all elements. reduction='mean'
# replaces the deprecated size_average=True, which selected the same mean
# reduction but triggers a deprecation UserWarning in PyTorch.
criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')

# Optimizer