예제 #1
0
    def test_should_return_single_image(self):
        """The first FROG batch holds ``batch_size`` images of shape 3x32x32.

        Despite the test name, the assertion covers a whole batch of
        ``batch_size`` images rather than a single image.
        """
        # given
        batch_size = 4
        data_loader = dm.data_loader(dm.classes.FROG, batch_size, True,
                                     self.cifar_path)

        # when
        # The built-in next() works on any iterator returned by iter(). The
        # Python 2-style ``.next()`` method is not part of the Python 3
        # iterator protocol.
        images, _labels = next(iter(data_loader))

        # then
        self.assertEqual(images.shape, torch.Size([batch_size, 3, 32, 32]))
예제 #2
0
    def test_iterator_should_return_2_labels(self):
        """Each batch from the FROG loader carries one label per image."""
        # given
        batch_size = 2
        # Same dataset-path fixture (``self.cifar_path``) as the sibling
        # loader tests.
        data_loader = dm.data_loader(dm.classes.FROG, batch_size, True,
                                     self.cifar_path)

        # when: only the first batch is inspected
        _batch, labels = next(iter(data_loader))

        # then
        self.assertEqual(len(labels), batch_size)
예제 #3
0
    def test_iterator_should_return_2_images(self):
        """A batch from the FROG loader holds two images of shape 3x32x32."""
        # given
        batch_size = 2
        loader = dm.data_loader(dm.classes.FROG, batch_size, True,
                                self.cifar_path)
        expected_shape = torch.Size([batch_size, 3, 32, 32])

        # when: record the shape of the first batch, then stop
        first_shape = []
        for images, _ in iter(loader):
            first_shape = images.size()
            break

        # then
        self.assertEqual(first_shape, expected_shape)
예제 #4
0
    def test_data_iterator_should_remove_data(self):
        """Draining the data iterator consumes its ``data`` buffer.

        After every batch has been pulled, fewer than ``batch_size`` items
        should remain in ``data_iterator.data``. Presumably at most one
        incomplete final batch is left over; confirm against
        ``_data_iterator``.
        """
        # given
        batch_size = 15
        incorrect_per_correct = 2
        data_loader = dm.data_loader(dm.classes.FROG, batch_size, True,
                                     self.cifar_path)
        data_iterator = data_loader._data_iterator(data_loader._trainerloader,
                                                   data_loader._batch_size,
                                                   dm.classes.FROG,
                                                   incorrect_per_correct)

        # when: pull every batch; the unpacking also checks that each item
        # is an (images, labels) pair
        for _imgs, _labels in data_iterator:
            pass

        # then: assertLess reports both values on failure, unlike
        # assertTrue(a < b)
        self.assertLess(len(data_iterator.data), batch_size)
예제 #5
0
    def test_data_iterator_should_return_correct_labels(self):
        """Labels interleave one correct sample with N incorrect samples.

        Label 0 appears to mark the correct (FROG) class and label 1 any
        other class. With ``incorrect_per_correct = 2``, the first six labels
        should read ``[0, 1, 1, 0, 1, 1]``.
        """
        # given
        batch_size = 15
        incorrect_per_correct = 2
        # One correct sample followed by `incorrect_per_correct` others,
        # checked over two repetitions of the pattern.
        label_pattern = [0] + [1] * incorrect_per_correct
        expected_labels = label_pattern * 2
        data_loader = dm.data_loader(dm.classes.FROG, batch_size, True,
                                     self.cifar_path)
        data_iterator = data_loader._data_iterator(data_loader._trainerloader,
                                                   data_loader._batch_size,
                                                   dm.classes.FROG,
                                                   incorrect_per_correct)

        # when: built-in next() instead of the Python 2-style ``.next()``
        _imgs, labels = next(data_iterator)

        # then
        self.assertEqual(labels[:len(expected_labels)].tolist(),
                         expected_labels)
예제 #6
0
    def test_data_iterator_should_have_correct_data_ammount(self):
        """The iterator buffer holds every correct image plus its partners.

        Each correct-class image should be paired with
        ``incorrect_per_correct`` other images, giving
        ``count * (incorrect_per_correct + 1)`` entries in total.
        """
        # given
        batch_size = 15
        # presumably the number of FROG images in the training split
        images_in_cifar_count = 5000
        incorrect_per_correct = 3
        expected_len = images_in_cifar_count * (incorrect_per_correct + 1)
        data_loader = dm.data_loader(dm.classes.FROG, batch_size, True,
                                     self.cifar_path)

        # when
        iterator_under_test = data_loader._data_iterator(
            data_loader._trainerloader, data_loader._batch_size,
            dm.classes.FROG, incorrect_per_correct)

        # then
        self.assertEqual(len(iterator_under_test.data), expected_len)
예제 #7
0
    # Convert the image to a NumPy array for OpenCV display (image is
    # presumably a CPU torch tensor, since .numpy() is called on it).
    num = image.numpy()
    # NOTE(review): ndarray.flatten's only argument is the memory order
    # ('C', 'F', 'A' or 'K'). Passing the int 3 looks unintended and may be
    # coerced or rejected depending on the NumPy version; confirm the
    # intended layout. Assumes the array holds exactly 3*32*32 values
    # (e.g. a single-image batch) — TODO confirm.
    # np.flip along axis 2 reverses the channel axis, presumably to turn
    # RGB into OpenCV's BGR order.
    num = np.flip(num.flatten(3).reshape(32, 32, 3), 2)
    # Rotate the first two axes, presumably to undo the orientation change
    # introduced by the flatten/reshape above — TODO confirm.
    num = np.rot90(num, axes=(1, 0))
    # Min-max scale into [0, 1], presumably because cv.imshow treats float
    # images as [0, 1] intensities.
    # NOTE(review): a constant image makes np.max(num) == 0 here, which
    # divides by zero.
    num = num - np.min(num)
    num = num / np.max(num)
    cv.imshow("test_image", num)
    # Block until a key is pressed.
    cv.waitKey(0)


# Evaluate a saved discriminator: count samples whose label is class 0 and
# that the discriminator also predicts as class 0.
discriminator, training_data, generator, generator_data = load_session(
    'saved_session')
print(training_data)
print(generator_data)
batch_size = 1

data_loader = data_manager.data_loader(training_data['discriminated_class'],
                                       batch_size, False)
# Put layers such as dropout / batch-norm (if the network has any) into
# inference mode before scoring.
discriminator.eval()
counter = 0
with torch.no_grad():
    for image, label in data_loader:
        output = discriminator(image)
        # With batch_size == 1, the flat argmax is the predicted class index.
        # Tensor.argmax works on any device, unlike np.argmax on a CUDA
        # tensor; detach() is unnecessary inside no_grad().
        predicted = int(output.argmax())
        label = int(label)
        # Count true positives for class 0 only.
        if predicted == 0 and predicted == label:
            counter = counter + 1
print(counter)
예제 #8
0
# --- Discriminator architecture ---
# NOTE(review): lin_layer1_size is not defined in this excerpt. It is
# presumably set earlier in the script; the "record" note below suggests 250.
conv_layer1_size = 20
conv_layer2_size = 28
image_size = [32] * 2  # presumably [height, width]

discriminator = gan_networks.Discriminator(conv_layer1_size, conv_layer2_size,
                                           lin_layer1_size, image_size)
# --- Generator architecture ---
noise_length = 400
lin_hidden_layer_size = 600
deconv1_size = 120
deconv2_size = 200
output_size = [32] * 2
generator = gan_networks.Generator(noise_length, lin_hidden_layer_size,
                                   deconv1_size, deconv2_size, output_size)
# --- Training data: the FROG class is the "correct" class ---
batch_size = 5
correct_class_enum = data_manager.classes.FROG
data_loader = data_manager.data_loader(correct_class_enum, batch_size, True)

# Adam with default hyperparameters for both networks.
generator_optimizer = torch.optim.Adam(generator.parameters())
discriminator_optimizer = torch.optim.Adam(discriminator.parameters())
criterion = torch.nn.CrossEntropyLoss()

epoch_count = 4
incorrect_to_correct_ratio = 1

logger = discriminator_logger(batch_size, epoch_count,
                              incorrect_to_correct_ratio, correct_class_enum,
                              noise_length)
# Record: 943 with 250/32/64, 959 with 250/20/28 (presumably best scores for
# lin_layer1/conv_layer1/conv_layer2 sizes — confirm).

iteration_counter = 0