Example #1
0
def load_data(folder_name, batch_size, shuffle=False, drop_last=True):
    '''
    Build a DataLoader that only yields the images of one class folder.

    The whole ``./data`` tree is indexed with ``ImageFolder``; every entry of
    ``dataset.imgs`` whose class index differs from the index of
    ``folder_name`` is then removed, so only that folder's samples remain.

    :param folder_name: name of the class sub-folder to keep. ``'test'`` is
        loaded with ``test_tfs``, every other folder with ``train_tfs``.
    :param batch_size: number of samples per batch
    :param shuffle: forwarded to ``DataLoader``
    :param drop_last: forwarded to ``DataLoader``
    :return: a ``DataLoader`` (4 workers) over the filtered dataset
    :raises KeyError: if ``folder_name`` is not a key of
        ``dataset.class_to_idx``
    '''
    transform = test_tfs if folder_name == 'test' else train_tfs
    dataset = ImageFolder(root="./data", transform=transform)
    select_ind = dataset.class_to_idx[folder_name]
    # refer to https://github.com/znxlwm/pytorch-CartoonGAN/blob/master
    # Keep only the entries of the selected class. Slice assignment mutates
    # the existing list object in place (as the former `del` loop did) but in
    # a single O(n) pass instead of shifting the list on every deletion.
    dataset.imgs[:] = [entry for entry in dataset.imgs
                       if entry[1] == select_ind]
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      drop_last=drop_last,
                      num_workers=4)
Example #2
0
def load_dataset_chestray(root, batch_size, shuffle, resize, height, width, crop, crop_size, grayscale):
    """Create a DataLoader over the image folder at ``root`` and print a summary.

    The transform pipeline is assembled from the optional steps below, in this
    order, and always ends with ``ToTensor``:

    * ``grayscale`` -> ``Grayscale`` with a single output channel
    * ``resize``    -> ``Resize`` to ``[height, width]`` (nearest neighbour)
    * ``crop``      -> ``CenterCrop(crop_size)``

    :param root: directory handed to ``ImageFolder``
    :param batch_size: batch size of the returned loader
    :param shuffle: forwarded to the ``DataLoader``
    :param resize: enable the resize step
    :param height: target height for the resize step
    :param width: target width for the resize step
    :param crop: enable the center-crop step
    :param crop_size: size forwarded to ``CenterCrop``
    :param grayscale: enable the grayscale step
    :return: the ``DataLoader`` (incomplete final batches are always dropped)
    """
    # (enabled?, factory) pairs; factories are only called for enabled steps
    # so that disabled transforms are never constructed.
    optional_steps = (
        (grayscale, lambda: Grayscale(num_output_channels=1)),
        (resize, lambda: Resize(size=[height, width], interpolation=Image.NEAREST)),
        (crop, lambda: CenterCrop(crop_size)),
    )
    pipeline = [build() for enabled, build in optional_steps if enabled]
    # ToTensor goes last so the samples come out with the proper type.
    pipeline.append(ToTensor())

    # ImageFolder is a torch.utils.data.Dataset subclass: it provides
    # __getitem__() and __len__().
    dataset = ImageFolder(root=root, transform=Compose(pipeline))

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True,
    )

    # Human-readable summary of what was loaded.
    print("\n*******************************")
    print("Dataset Covid Chest Ray cargado")
    print(f"Tamaño: {len(dataset)}")
    print(f"Batchsize: {batch_size}")
    print(f"Batches: {len(data_loader)}")
    print(f"Clases: {dataset.classes}")
    print(f"Sample shape: {dataset[0][0].shape}")
    print("*******************************")

    return data_loader
# Load the classifier given on the command line and ask its class map for
# the target transform matching the test set's class-to-index table.
tgc = TypegroupsClassifier.load(args.classifier)
target_transform = tgc.classMap.get_target_transform(test.class_to_idx)

# confusion matrix: square, zero-filled, one row and one column per output
# of the network's final fully-connected layer
cm = [[0] * tgc.network.fc.out_features
      for _ in range(tgc.network.fc.out_features)]

# Running counters, all starting at zero.
nb_good = nb_bad = imgn = 0

# Switch the network to evaluation mode before running any sample.
tgc.network.eval()

# Binary output file; the evaluation loop that follows writes into it.
feature_file = open(args.output, 'wb')

with torch.no_grad():
    idx = [i for i in range(0, test.__len__())]
    shuffle(idx)
    for sample_num in idx:
        sample, target = test.__getitem__(sample_num)
        print(target)

        score = torch.zeros(1, tgc.network.fc.out_features).to(dev)
        features = torch.zeros(1, tgc.network.fc.in_features).to(dev)
        print(features.size())
        for n in range(0, args.count):
            sample, target = test.__getitem__(sample_num)
            out, _, ap = tgc.network(sample.unsqueeze_(0).to(dev))
            ap = ap.view(ap.size(0), -1)
            score += out
            features += ap
            feature_file.write(target.to_bytes(1, byteorder='big',
Example #4
0
        os.path.join('ocrd_typegroups_classifier', 'models',
                     'classifier.tgc')):
    tgc = TypegroupsClassifier.load(
        os.path.join('ocrd_typegroups_classifier', 'models', 'classifier.tgc'))
else:
    print('Could not load a model to evaluate')
    quit(1)

# Evaluate the loaded classifier on the validation lines and count how many
# samples are labelled correctly (good) or incorrectly (bad).
validation = ImageFolder('lines/validation', transform=None)
validation.target_transform = tgc.classMap.get_target_transform(
    validation.class_to_idx)
good = 0
bad = 0
with torch.no_grad():
    tgc.network.eval()
    for idx in tqdm(range(len(validation)), desc='Evaluation'):
        sample, target = validation[idx]
        path, _ = validation.samples[idx]
        # Samples whose transformed target is -1 are left out of the score.
        if target == -1:
            continue
        result = tgc.classify(sample, 224, 64, True)
        # NOTE(review): presumably `result` maps score -> class name, so the
        # largest key selects the winning class; confirm against classify().
        highscore = max(result)
        label = tgc.classMap.cl2id[result[highscore]]
        if target == label:
            good += 1
        else:
            bad += 1

# Guard against ZeroDivisionError when nothing was evaluated (empty folder
# or every target mapped to -1); accuracy is reported as 0.0 in that case.
accuracy = 100 * good / float(good + bad) if good + bad else 0.0

print('    Good:', good)
Example #5
0
# Number of classes is derived from the largest class id in the class map
# (ids are taken to start at 0); number of outputs comes from the width of
# the network's final fully-connected layer.
nb_classes = 1 + max(tgc.classMap.cl2id.values())
nb_outputs = tgc.network.fc.out_features
print(nb_classes, 'classes to consider')
print(nb_outputs, 'outputs to process')

# Softmax with an explicit dimension: calling Softmax() without `dim` relies
# on a deprecated implicit choice and warns on every call. dim=1 is what that
# implicit rule selects for 2-D input.
# NOTE(review): assumes the network output is (batch, nb_outputs), as the
# `torch.max(out, 1)` call in the loop below suggests -- confirm.
sm = torch.nn.Softmax(dim=1)

# Running counters, all starting at zero.
nb_good = 0
nb_bad = 0
# Switch the network to evaluation mode before running any sample.
tgc.network.eval()
imgn = 0
with torch.no_grad():
    f = open("foo.html", "w")
    f.write('<html><head></head><body><table>')
    for sample_num in range(0, test.__len__()):
        score = torch.zeros(1, nb_outputs).to(dev)

        patchdict = dict()

        f.write('<tr>')
        for n in range(0, args.count):
            sample, target = test.__getitem__(sample_num)
            out, _, _ = tgc.network(sample.unsqueeze_(0).to(dev))

            mx, p = torch.max(out, 1)

            norm = sm(out)
            conf = torch.max(norm).item()

            if mx not in patchdict: