예제 #1
0
def test(model_path, data_path, target_tag, pair_num, save_result=False,
         output_path='./result.png'):
    """Evaluate a trained SiameseNetwork on randomly drawn image pairs.

    Loads the weights stored under the ``'model_state_dict'`` key of the
    checkpoint at *model_path* (mapped onto the CPU), draws *pair_num* pairs
    with ``random_data_generator(data_path, target_tag)`` and runs the model
    on each pair.

    Args:
        model_path: Path to a checkpoint dict containing ``'model_state_dict'``.
        data_path: Dataset location, forwarded to ``random_data_generator``.
        target_tag: Tag forwarded to ``random_data_generator``.
        pair_num: Number of pairs to draw and evaluate.
        save_result: If True, render every pair (label, prediction and the
            second image's parent-folder name) into one figure and save it.
        output_path: File the figure is written to when *save_result* is
            True. Defaults to ``'./result.png'``.

    Returns:
        A list of ``(first_image, second_image, predict, is_same)`` tuples,
        where *predict* is the model output converted with ``.item()``.
    """
    import os  # function-scope: no module import block is visible for this snippet

    model = SiameseNetwork()
    model.load_state_dict(
        torch.load(model_path, map_location='cpu')['model_state_dict'])
    # Inference mode: matters for dropout / batch-norm layers, if the
    # network has any.
    model.eval()

    result = []
    # Evaluation only: disable autograd so no graph is built per prediction.
    with torch.no_grad():
        for _ in range(pair_num):
            im1, im2, is_same, first_image, second_image = random_data_generator(
                data_path, target_tag)
            # unsqueeze(0) prepends a batch dimension of size 1.
            predict = model(im1.unsqueeze(0), im2.unsqueeze(0)).item()
            result.append((first_image, second_image, predict, is_same))

    if save_result:
        # One row of two panels per pair.
        fig = plt.figure(figsize=(4, 12), dpi=300)
        plt.subplots_adjust(hspace=1.5)

        for i, (first_image, second_image, predict, is_same) in enumerate(result):
            # Context managers release the underlying file handles.
            with Image.open(first_image) as raw1:
                img1 = raw1.convert('RGB')
            with Image.open(second_image) as raw2:
                img2 = raw2.convert('RGB')

            plt.subplot(pair_num, 2, 2 * i + 1)
            plt.axis('off')
            plt.title('Label=%s, Prediction=%s' % (str(is_same), str(predict)),
                      fontsize=4)
            plt.imshow(img1)

            plt.subplot(pair_num, 2, 2 * i + 2)
            plt.axis('off')
            # The name of the second image's parent directory is shown as its
            # label; os.path keeps this portable across path separators.
            real_label = os.path.basename(os.path.dirname(second_image))
            plt.title('Real Label=%s' % real_label, fontsize=4)
            plt.imshow(img2)

        fig.savefig(output_path, dpi=300)
        # pyplot keeps figures alive until explicitly closed.
        plt.close(fig)

    return result
예제 #2
0
def main():
    """Compare a query image against pairs drawn from the test set.

    Parses ``-img`` (default ``test.jpg``), converts that image to
    single-channel ("L"), resizes it to 100x100 and turns it into a tensor.
    It then loads the network weights from ``trained_weights.pt``. For four
    samples drawn from ``Config.testing_dir``, it shows the query image next
    to the sample's second image, titled with the Euclidean distance between
    their embeddings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-img', type=str, default='test.jpg')
    args = parser.parse_args()

    # One transform shared by the dataset and the query image, so both are
    # resized and tensorised identically.
    transform = transforms.Compose(
        [transforms.Resize((100, 100)),
         transforms.ToTensor()])

    folder_dataset_test = dset.ImageFolder(root=Config.testing_dir)
    siamese_dataset = SiameseNetworkDataset(
        imageFolderDataset=folder_dataset_test,
        transform=transform,
        should_invert=False)

    # The context manager releases the file handle once the tensor is built.
    with Image.open(args.img) as raw:
        img = transform(raw.convert("L"))
    # Prepend a batch dimension so the query can be concatenated with (and
    # compared against) the batch_size=1 samples from the loader.
    img = img[None, :, :, :]

    test_dataloader = DataLoader(siamese_dataset, num_workers=3, batch_size=1)
    dataiter = iter(test_dataloader)
    net = SiameseNetwork()
    # Inference below runs on the CPU, so map the checkpoint there too; this
    # also allows loading weights that were saved from a GPU.
    net.load_state_dict(torch.load("trained_weights.pt", map_location='cpu'))
    net.eval()

    with torch.no_grad():
        for _ in range(4):
            _, x1, _ = next(dataiter)
            concatenated = torch.cat((img, x1), 0)

            output1, output2 = net(img, x1)
            euclidean_distance = F.pairwise_distance(output1, output2)
            # pairwise_distance yields one value per batch row; batch size is
            # 1 here, so .item() extracts the scalar whatever the tensor rank.
            imshow(torchvision.utils.make_grid(concatenated),
                   'Dissimilarity: {:.2f}'.format(euclidean_distance.item()))
예제 #3
0
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

from helpers import imshow
from siamese_dataset import SiameseNetworkDataset
from siamese_network import SiameseNetwork


class Config:
    """Static settings for the evaluation script."""

    # Root folder of the test images; it is handed to ImageFolder, which
    # expects one sub-folder per class.
    # Previously used alternative:
    # /home/wingman2/code/Facial-Similarity-with-Siamese-Networks-in-Pytorch/data/faces/testing/
    testing_dir = "/home/wingman2/datasets/personas/test/"


# Build the network on the GPU and load the trained weights into it.
model = SiameseNetwork().cuda()
model.load_state_dict(torch.load('/home/wingman2/models/siamese-faces-160.pt'))
# Inference mode (affects layers such as dropout / batch-norm, if present).
model.eval()

# Every test image is resized to 100x100 and converted to a tensor.
data_transforms_test = transforms.Compose(
    [transforms.Resize((100, 100)),
     transforms.ToTensor()])

# ImageFolder indexes the images under Config.testing_dir; the project's
# SiameseNetworkDataset wraps it (presumably to produce image pairs with a
# label — confirm against siamese_dataset.py).
folder_dataset_test = ImageFolder(root=Config.testing_dir)
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test,
                                        transform=data_transforms_test,
                                        should_invert=False)
# One shuffled sample per batch, loaded by 8 worker processes.
test_dataloader = DataLoader(siamese_dataset,
                             num_workers=8,
                             batch_size=1,
                             shuffle=True)
# NOTE(review): creating a multi-worker iterator at module level starts the
# workers at import time. Under the 'spawn' start method (Windows/macOS
# default) workers re-import this module, so this code should sit behind an
# `if __name__ == '__main__':` guard there.
dataiter = iter(test_dataloader)