    def test_should_return_image_batch(self):
        # given
        batch_size = 4
        data_loader = dm.data_loader(dm.classes.FROG, batch_size, True, self.cifar_path)
        data_iter = iter(data_loader)

        # when
        image, label = next(data_iter)

        # then
        self.assertEqual(image.shape, torch.Size([batch_size, 3, 32, 32]))
    def test_iterator_should_return_2_labels(self):
        # given
        batch_size = 2
        data_loader = dm.data_loader(dm.classes.FROG, batch_size, True, self.cifar_path)
        label_size = 0

        # when
        for batch, labels in iter(data_loader):
            label_size = len(labels)
            break

        # then
        self.assertEqual(label_size, 2)
    def test_iterator_should_return_2_images(self):
        # given
        batch_size = 2
        data_loader = dm.data_loader(dm.classes.FROG, batch_size, True, self.cifar_path)
        returned_batch_size = None

        # when
        for batch, labels in iter(data_loader):
            returned_batch_size = batch.size()
            break

        # then
        self.assertEqual(returned_batch_size, torch.Size([2, 3, 32, 32]))
    def test_data_iterator_should_remove_data(self):
        # given
        batch_size = 15
        incorrect_per_correct = 2
        data_loader = dm.data_loader(dm.classes.FROG, batch_size, True, self.cifar_path)
        data_iterator = data_loader._data_iterator(data_loader._trainerloader,
                                                   data_loader._batch_size,
                                                   dm.classes.FROG,
                                                   incorrect_per_correct)

        # when
        for imgs, labels in data_iterator:
            continue

        # then
        self.assertTrue(len(data_iterator.data) < batch_size)
    def test_data_iterator_should_return_correct_labels(self):
        # given
        batch_size = 15
        incorrect_per_correct = 2
        # one correct-class sample (label 0) followed by incorrect_per_correct others (label 1)
        expected_labels = [0, 1, 1] * 2
        data_loader = dm.data_loader(dm.classes.FROG, batch_size, True, self.cifar_path)
        data_iterator = data_loader._data_iterator(data_loader._trainerloader,
                                                   data_loader._batch_size,
                                                   dm.classes.FROG,
                                                   incorrect_per_correct)

        # when
        imgs, labels = next(data_iterator)

        # then
        self.assertEqual(labels[:6].tolist(), expected_labels)
    def test_data_iterator_should_have_correct_data_amount(self):
        # given
        batch_size = 15
        images_in_cifar_count = 5000  # training images per class in CIFAR-10
        incorrect_per_correct = 3
        data_loader = dm.data_loader(dm.classes.FROG, batch_size, True, self.cifar_path)

        # when
        data_iterator = data_loader._data_iterator(
            data_loader._trainerloader,
            data_loader._batch_size,
            dm.classes.FROG,
            incorrect_per_correct,
        )

        # then
        self.assertEqual(len(data_iterator.data),
                         images_in_cifar_count * (incorrect_per_correct + 1))
def show_image(image):
    # Convert a [1, 3, 32, 32] CHW tensor to a 32x32x3 BGR array for OpenCV.
    num = image.numpy().squeeze().transpose(1, 2, 0)
    num = np.flip(num, 2)  # RGB -> BGR
    # Normalize to [0, 1] for display.
    num = num - np.min(num)
    num = num / np.max(num)
    cv.imshow("test_image", num)
    cv.waitKey(0)


discriminator, training_data, generator, generator_data = load_session('saved_session')
print(training_data)
print(generator_data)

batch_size = 1
data_loader = data_manager.data_loader(training_data['discriminated_class'], batch_size, False)

counter = 0
with torch.no_grad():
    for image, label in data_loader:
        output = discriminator(image)
        output_np = output.cpu().numpy()
        # print("Frog: %f \n Something else: %f" % (output_np[0, 0], output_np[0, 1]))
        # show_image(image)
        i = int(np.argmax(output_np))
        label = int(label.to('cpu'))
        if i == 0 and i == label:
            counter += 1

print(counter)
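# Hedged sketch, not part of the original script: the loop above only prints the raw
# hit count for class index 0. A small helper that also tracks how many images were
# evaluated makes the accuracy explicit. The function name evaluate_accuracy and its
# arguments are assumptions that simply mirror the loop above.
def evaluate_accuracy(model, loader, correct_index=0):
    hits, total = 0, 0
    with torch.no_grad():
        for image, label in loader:
            prediction = int(np.argmax(model(image).cpu().numpy()))
            if prediction == correct_index and prediction == int(label.to('cpu')):
                hits += 1
            total += 1
    return hits / total if total else 0.0

# Example use (same assumptions as above):
# print(evaluate_accuracy(discriminator, data_loader))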
conv_layer1_size = 20
conv_layer2_size = 28
image_size = [32] * 2
discriminator = gan_networks.Discriminator(conv_layer1_size, conv_layer2_size,
                                           lin_layer1_size, image_size)

noise_length = 400
lin_hidden_layer_size = 600
deconv1_size = 120
deconv2_size = 200
output_size = [32] * 2
generator = gan_networks.Generator(noise_length, lin_hidden_layer_size,
                                   deconv1_size, deconv2_size, output_size)

batch_size = 5
correct_class_enum = data_manager.classes.FROG
data_loader = data_manager.data_loader(correct_class_enum, batch_size, True)

generator_optimizer = torch.optim.Adam(generator.parameters())
discriminator_optimizer = torch.optim.Adam(discriminator.parameters())
criterion = torch.nn.CrossEntropyLoss()

epoch_count = 4
incorrect_to_correct_ratio = 1
logger = discriminator_logger(batch_size, epoch_count, incorrect_to_correct_ratio,
                              correct_class_enum, noise_length)

# best so far: 943 with 250/32/64, 959 with 250/20/28
iteration_counter = 0
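# Hedged illustration, not the project's actual training loop: a sketch of how the
# objects configured above could be combined for one adversarial update. It assumes
# gan_networks.Generator accepts a (batch, noise_length) noise tensor and that
# gan_networks.Discriminator returns two logits per image; the helper name
# adversarial_step is made up for this example, and the 0 = correct, 1 = incorrect
# label convention is taken from the data_manager tests.
def adversarial_step(real_images):
    real_count = real_images.size(0)

    # Discriminator update: real images should score label 0, generated ones label 1.
    discriminator_optimizer.zero_grad()
    fake_images = generator(torch.randn(real_count, noise_length)).detach()
    real_loss = criterion(discriminator(real_images),
                          torch.zeros(real_count, dtype=torch.long))
    fake_loss = criterion(discriminator(fake_images),
                          torch.ones(real_count, dtype=torch.long))
    (real_loss + fake_loss).backward()
    discriminator_optimizer.step()

    # Generator update: push generated images toward the "correct" label 0.
    generator_optimizer.zero_grad()
    generated = generator(torch.randn(real_count, noise_length))
    generator_loss = criterion(discriminator(generated),
                               torch.zeros(real_count, dtype=torch.long))
    generator_loss.backward()
    generator_optimizer.step()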