import datetime
import os

import torch
from torch.utils.data import DataLoader

# Local helpers (get_imagenet_datasets, get_patch_tensor_from_image_batch,
# get_random_patches, dot, dot_norm_exp, norm_euclidian, write_csv_stats)
# come from the repo's own modules.


def run_context_predictor(args, res_encoder_model, context_predictor_model,
                          models_store_path):
    print("RUNNING CONTEXT PREDICTOR TRAINING")

    stats_csv_path = os.path.join(models_store_path, "pred_stats.csv")

    dataset_train, dataset_test = get_imagenet_datasets(
        args.image_folder, num_classes=args.num_classes)

    def get_random_patch_loader():
        return DataLoader(dataset_train, args.num_random_patches, shuffle=True)

    random_patch_loader = get_random_patch_loader()
    data_loader_train = DataLoader(dataset_train, args.sub_batch_size,
                                   shuffle=True)

    params = list(res_encoder_model.parameters()) + list(
        context_predictor_model.parameters())
    optimizer = torch.optim.Adam(params=params, lr=0.00001)

    sub_batches_processed = 0
    batch_loss = 0
    sum_batch_loss = 0
    best_batch_loss = 1e10

    # Per-class cosine similarities; not populated anywhere in this loop.
    z_vect_similarity = dict()

    for batch in data_loader_train:
        # plt.imshow(img_arr.permute(1,2,0))
        # fig, axes = plt.subplots(7,7)

        img_batch = batch['image'].to(args.device)
        patch_batch = get_patch_tensor_from_image_batch(img_batch)
        batch_size = len(img_batch)

        # Encode every patch, then reshape into (B, C, 7, 7) feature grids.
        patches_encoded = res_encoder_model.forward(patch_batch)
        patches_encoded = patches_encoded.view(batch_size, 7, 7, -1)
        patches_encoded = patches_encoded.permute(0, 3, 1, 2)

        # Draw negative samples; if the random-patch loader is exhausted,
        # rebuild it and retry once.
        for i in range(2):
            patches_return = get_random_patches(random_patch_loader,
                                                args.num_random_patches)
            if patches_return['is_data_loader_finished']:
                random_patch_loader = get_random_patch_loader()
            else:
                random_patches = patches_return['patches_tensor'].to(
                    args.device)
                # enc_random_patches = resnet_encoder.forward(random_patches).detach()
                enc_random_patches = res_encoder_model.forward(random_patches)
                break  # got a batch of negatives; no need to draw again

        # TODO: vectorize the context_predictor_model - stack all 3x3 contexts together
        predictions, locations = context_predictor_model.forward(
            patches_encoded)

        losses = []

        # predictions holds one (batch_size, D) slab per predicted location.
        for b in range(len(predictions) // batch_size):
            b_idx_start = b * batch_size
            b_idx_end = (b + 1) * batch_size

            p_y = locations[b_idx_start][0]
            p_x = locations[b_idx_start][1]

            target = patches_encoded[:, :, p_y, p_x]
            pred = predictions[b_idx_start:b_idx_end]

            # Diagnostics only; not used in the loss.
            dot_norm_val = dot_norm_exp(pred.detach().to('cpu'),
                                        target.detach().to('cpu'))
            euc_loss_val = norm_euclidian(pred.detach().to('cpu'),
                                          target.detach().to('cpu'))

            # InfoNCE term: the encoding of the true patch is the positive,
            # the encodings of the random patches are the negatives.
            good_term_dot = dot(pred, target)
            dot_terms = [torch.unsqueeze(good_term_dot, dim=0)]

            for random_patch_idx in range(args.num_random_patches):
                bad_term_dot = dot(
                    pred,
                    enc_random_patches[random_patch_idx:random_patch_idx + 1])
                dot_terms.append(torch.unsqueeze(bad_term_dot, dim=0))

            log_softmax = torch.log_softmax(torch.cat(dot_terms, dim=0), dim=0)
            losses.append(-log_softmax[0])
            # losses.append(-torch.log(good_term/divisor))

        loss = torch.mean(torch.cat(losses))
        loss.backward()
        # loss = torch.sum(torch.cat(losses))
        # loss.backward()

        sub_batches_processed += img_batch.shape[0]
        batch_loss += loss.detach().to('cpu')
        sum_batch_loss += torch.sum(torch.cat(losses).detach().to('cpu'))

        # Gradient accumulation: step only once a full batch has been seen.
        if sub_batches_processed >= args.batch_size:
            optimizer.step()
            optimizer.zero_grad()

            print(f"{datetime.datetime.now()} Loss: {batch_loss}")
            print(f"{datetime.datetime.now()} SUM Loss: {sum_batch_loss}")

            torch.save(
                res_encoder_model.state_dict(),
                os.path.join(models_store_path, "last_res_ecoder_weights.pt"))
            torch.save(
                context_predictor_model.state_dict(),
                os.path.join(models_store_path,
                             "last_context_predictor_weights.pt"))

            if best_batch_loss > batch_loss:
                best_batch_loss = batch_loss
                torch.save(
                    res_encoder_model.state_dict(),
                    os.path.join(models_store_path,
                                 "best_res_ecoder_weights.pt"))
                torch.save(
                    context_predictor_model.state_dict(),
                    os.path.join(models_store_path,
                                 "best_context_predictor_weights.pt"))

            for key, cos_similarity_tensor in z_vect_similarity.items():
                print(f"Mean cos_sim for class {key} is "
                      f"{cos_similarity_tensor.mean()} . "
                      f"Number: {cos_similarity_tensor.size()}")
            z_vect_similarity = dict()

            stats = dict(batch_loss=batch_loss, sum_batch_loss=sum_batch_loss)
            write_csv_stats(stats_csv_path, stats)

            sub_batches_processed = 0
            batch_loss = 0
            sum_batch_loss = 0
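# ---------------------------------------------------------------------------
# The helpers used above are defined in other modules of the repo. Below is a
# minimal sketch of plausible implementations for the patch cutter and the
# scoring functions, assuming 256x256 inputs cut into the 7x7 grid of 64x64
# half-overlapping patches implied by the view(batch_size, 7, 7, -1) call
# above. Treat these as illustrations, not the repo's actual code; likewise,
# get_random_patches presumably draws one batch from its loader and sets
# 'is_data_loader_finished' when the loader is exhausted.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def get_patch_tensor_from_image_batch(img_batch):
    # (B, C, 256, 256) -> (B * 49, C, 64, 64): 64x64 patches with stride 32,
    # row-major over the 7x7 grid, patch index varying fastest.
    blocks = F.unfold(img_batch, kernel_size=64, stride=32)  # (B, C*64*64, 49)
    b, _, n = blocks.shape
    blocks = blocks.transpose(1, 2)                          # (B, 49, C*64*64)
    return blocks.reshape(b * n, img_batch.shape[1], 64, 64)


def dot(a, b):
    # Row-wise dot product, broadcasting (1, D) negatives over the batch:
    # (B, D) x (B, D) -> (B,)  or  (B, D) x (1, D) -> (B,)
    return torch.sum(a * b, dim=1)


def dot_norm_exp(a, b):
    # exp(cosine similarity); diagnostic only in the loop above.
    return torch.exp(
        torch.sum(F.normalize(a, dim=1) * F.normalize(b, dim=1), dim=1))


def norm_euclidian(a, b):
    # Euclidean distance between L2-normalised vectors; diagnostic only.
    return torch.norm(F.normalize(a, dim=1) - F.normalize(b, dim=1), dim=1)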
# --- Fragment: tail of a prediction-visualisation helper ---
predicted_class = dataset_test.get_class_name(predictions[img_idx])
actual_class = dataset_test.get_class_name(truth[img_idx])
axes[row, col].set_title(
    f"Predicted class {predicted_class}\nbut actually {actual_class}")

plt.savefig(f"{image_name}.jpg")
plt.close()

NUM_CLASSES = 30  # set to None to infer the class count from the dataset
data_path = "/home/martin/ai/ImageNet-datasets-downloader/images_4/imagenet_images"

dataset_train, dataset_test = get_imagenet_datasets(data_path,
                                                    num_classes=NUM_CLASSES,
                                                    random_seed=422)
print(dataset_train.get_class_names())

if NUM_CLASSES is None:
    NUM_CLASSES = dataset_train.get_number_of_classes()

NUM_TRAIN_SAMPLES = dataset_train.get_number_of_samples()
NUM_TEST_SAMPLES = dataset_test.get_number_of_samples()

print(f"train_samples {NUM_TRAIN_SAMPLES} test_samples {NUM_TEST_SAMPLES}")
print(f"num_classes {NUM_CLASSES}")

NUM_EPOCHS = 50
BATCH_SIZE = 2
LEARNING_RATE = 1e-5
# --- Fragment: same visualisation helper, second script variant ---
img = np.transpose(img, (1, 2, 0))  # CHW -> HWC for matplotlib
axes[row, col].imshow(img)

predicted_class = dataset_test.get_class_name(predictions[img_idx])
actual_class = dataset_test.get_class_name(truth[img_idx])
axes[row, col].set_title(
    f"Predicted class {predicted_class}\nbut actually {actual_class}")

plt.savefig(f"{image_name}.jpg")
plt.close()

NUM_CLASSES = 1000  # set to None to infer the class count from the dataset
data_path = "/Users/martinsf/ai/deep_learning_projects/data/imagenet_images"

dataset_train, dataset_test = get_imagenet_datasets(data_path,
                                                    num_classes=NUM_CLASSES)

if NUM_CLASSES is None:
    NUM_CLASSES = dataset_train.get_number_of_classes()

NUM_TRAIN_SAMPLES = dataset_train.get_number_of_samples()
NUM_TEST_SAMPLES = dataset_test.get_number_of_samples()

print(f"train_samples {NUM_TRAIN_SAMPLES} test_samples {NUM_TEST_SAMPLES}")
print(f"num_classes {NUM_CLASSES}")

NUM_EPOCHS = 10000
BATCH_SIZE = 8
LEARNING_RATE = 1e-4

# DEVICE = 'cuda'
DEVICE = 'cpu'
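# ---------------------------------------------------------------------------
# Hypothetical glue code: the training routines only read attributes off an
# `args` object, so a plain Namespace is enough to wire the constants above
# into them. The attribute names mirror what run_context_predictor reads;
# model construction is assumed to happen elsewhere in the repo.
# ---------------------------------------------------------------------------
from argparse import Namespace

args = Namespace(
    image_folder=data_path,
    num_classes=NUM_CLASSES,
    device=DEVICE,
    batch_size=64,              # samples per optimizer step (accumulated)
    sub_batch_size=BATCH_SIZE,  # samples per forward pass
    num_random_patches=16,      # negative samples per InfoNCE term
)
# run_context_predictor(args, res_encoder_model, context_predictor_model, "models/")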
def run_classificator(args, res_classificator_model, res_encoder_model,
                      models_store_path):
    print("RUNNING CLASSIFICATOR TRAINING")

    dataset_train, dataset_test = get_imagenet_datasets(
        args.image_folder, num_classes=args.num_classes,
        train_split=0.2, random_seed=42)

    stats_csv_path = os.path.join(models_store_path,
                                  "classification_stats.csv")

    EPOCHS = 500
    LABELS_PER_CLASS = 25  # not used yet

    data_loader_train = DataLoader(dataset_train, args.sub_batch_size,
                                   shuffle=True)
    data_loader_test = DataLoader(dataset_test, args.sub_batch_size,
                                  shuffle=True)

    NUM_TRAIN_SAMPLES = dataset_train.get_number_of_samples()
    NUM_TEST_SAMPLES = dataset_test.get_number_of_samples()

    # params = list(res_classificator_model.parameters()) + list(res_encoder_model.parameters())
    # Train the pre-trained encoder slower than the fresh classifier layers.
    optimizer_enc = torch.optim.Adam(params=res_encoder_model.parameters(),
                                     lr=0.00001)
    optimizer_cls = torch.optim.Adam(
        params=res_classificator_model.parameters(), lr=0.001)

    best_epoch_test_loss = 1e10

    for epoch in range(EPOCHS):
        sub_batches_processed = 0

        epoch_train_true_positives = 0.0
        epoch_train_loss = 0.0
        epoch_train_losses = []

        batch_train_loss = 0.0
        batch_train_true_positives = 0.0

        epoch_test_true_positives = 0.0
        epoch_test_loss = 0.0
        epoch_test_losses = []

        for batch in data_loader_train:
            img_batch = batch['image'].to(args.device)
            patch_batch = get_patch_tensor_from_image_batch(img_batch)

            patches_encoded = res_encoder_model.forward(patch_batch)
            patches_encoded = patches_encoded.view(img_batch.shape[0], 7, 7, -1)
            patches_encoded = patches_encoded.permute(0, 3, 1, 2)

            classes = batch['cls'].to(args.device)
            y_one_hot = torch.zeros(img_batch.shape[0],
                                    args.num_classes).to(args.device)
            y_one_hot = y_one_hot.scatter_(1, classes.unsqueeze(dim=1), 1)

            labels = batch['class_name']

            pred = res_classificator_model.forward(patches_encoded)

            # Cross-entropy; assumes the model outputs softmax probabilities.
            loss = torch.sum(-y_one_hot * torch.log(pred))

            epoch_train_losses.append(loss.detach().to('cpu').numpy())
            epoch_train_loss += loss.detach().to('cpu').numpy()
            batch_train_loss += loss.detach().to('cpu').numpy()

            epoch_train_true_positives += torch.sum(pred.argmax(dim=1) == classes)
            batch_train_true_positives += torch.sum(pred.argmax(dim=1) == classes)

            loss.backward()
            sub_batches_processed += img_batch.shape[0]

            # Gradient accumulation: step both optimizers once a full batch
            # has been processed.
            if sub_batches_processed >= args.batch_size:
                optimizer_enc.step()
                optimizer_cls.step()
                optimizer_enc.zero_grad()
                optimizer_cls.zero_grad()
                sub_batches_processed = 0

                batch_train_accuracy = float(batch_train_true_positives) / float(args.batch_size)
                print(f"Training loss of batch is {batch_train_loss}")
                print(f"Accuracy of batch is {batch_train_accuracy}")
                batch_train_loss = 0.0
                batch_train_true_positives = 0.0

        epoch_train_accuracy = float(epoch_train_true_positives) / float(NUM_TRAIN_SAMPLES)
        print(f"Training loss of epoch {epoch} is {epoch_train_loss}")
        print(f"Accuracy of epoch {epoch} is {epoch_train_accuracy}")

        with torch.no_grad():
            res_classificator_model.eval()
            res_encoder_model.eval()

            for batch in data_loader_test:
                img_batch = batch['image'].to(args.device)
                patch_batch = get_patch_tensor_from_image_batch(img_batch)

                patches_encoded = res_encoder_model.forward(patch_batch)
                patches_encoded = patches_encoded.view(img_batch.shape[0], 7, 7, -1)
                patches_encoded = patches_encoded.permute(0, 3, 1, 2)

                classes = batch['cls'].to(args.device)
                y_one_hot = torch.zeros(img_batch.shape[0],
                                        args.num_classes).to(args.device)
                y_one_hot = y_one_hot.scatter_(1, classes.unsqueeze(dim=1), 1)

                labels = batch['class_name']

                pred = res_classificator_model.forward(patches_encoded)
                loss = torch.sum(-y_one_hot * torch.log(pred))

                epoch_test_losses.append(loss.detach().to('cpu').numpy())
                epoch_test_loss += loss.detach().to('cpu').numpy()
                epoch_test_true_positives += torch.sum(pred.argmax(dim=1) == classes)

            epoch_test_accuracy = float(epoch_test_true_positives) / float(NUM_TEST_SAMPLES)
            print(f"Test loss of epoch {epoch} is {epoch_test_loss}")
            print(f"Test accuracy of epoch {epoch} is {epoch_test_accuracy}")

            res_classificator_model.train()
            res_encoder_model.train()

        torch.save(res_encoder_model.state_dict(),
                   os.path.join(models_store_path, "last_res_ecoder_weights.pt"))
        torch.save(res_classificator_model.state_dict(),
                   os.path.join(models_store_path,
                                "last_res_classificator_weights.pt"))

        if best_epoch_test_loss > epoch_test_loss:
            best_epoch_test_loss = epoch_test_loss
            torch.save(res_encoder_model.state_dict(),
                       os.path.join(models_store_path, "best_res_ecoder_weights.pt"))
            torch.save(res_classificator_model.state_dict(),
                       os.path.join(models_store_path,
                                    "best_res_classificator_weights.pt"))

        stats = dict(
            epoch=epoch,
            train_acc=epoch_train_accuracy,
            train_loss=epoch_train_loss,
            test_acc=epoch_test_accuracy,
            test_loss=epoch_test_loss,
        )
        print(f"Writing dict {stats} into file {stats_csv_path}")
        write_csv_stats(stats_csv_path, stats)
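# ---------------------------------------------------------------------------
# write_csv_stats is not defined in this file. A minimal sketch consistent
# with how it is called above (append one dict of stats per call, header on
# first write) might look like this; the repo's real helper may differ.
# ---------------------------------------------------------------------------
import csv
import os


def write_csv_stats(csv_path, stats):
    # Append one row of stats, writing the header first if the file is new.
    file_exists = os.path.isfile(csv_path)
    with open(csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(stats.keys()))
        if not file_exists:
            writer.writeheader()
        writer.writerow(stats)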