def run_context_predictor(args, res_encoder_model, context_predictor_model,
                          models_store_path):
    """Train the patch encoder and the context predictor with a contrastive loss.

    Every image is encoded patch-wise by ``res_encoder_model`` and the patch
    encodings are laid out on a 7x7 grid. ``context_predictor_model`` returns
    predicted encodings together with the grid locations they target. Each
    prediction is scored (via ``dot``) against the true encoding at its
    location and against the encodings of ``args.num_random_patches`` random
    patches. The loss is the negative log-softmax of the true-encoding term.

    Gradients are accumulated over sub-batches of ``args.sub_batch_size``
    images. An optimizer step is taken once at least ``args.batch_size``
    images have contributed. After each step the weights are saved as "last",
    and also as "best" when the accumulated loss is lower than any previous
    step's. The loss statistics are appended to ``pred_stats.csv``. The
    function makes a single pass over the training set; gradients of a
    trailing, incomplete accumulation are not applied.

    Args:
        args: Namespace with ``image_folder``, ``num_classes``,
            ``num_random_patches``, ``sub_batch_size``, ``batch_size`` and
            ``device``.
        res_encoder_model: Module that encodes image patches.
        context_predictor_model: Module mapping the encoded patch grid to
            ``(predictions, locations)``.
        models_store_path: Directory for weight checkpoints and the stats CSV.

    Raises:
        RuntimeError: If no random patches could ever be drawn, even after
            recreating the random patch loader.
    """
    print("RUNNING CONTEXT PREDICTOR TRAINING")

    stats_csv_path = os.path.join(models_store_path, "pred_stats.csv")

    dataset_train, _ = get_imagenet_datasets(
        args.image_folder, num_classes=args.num_classes)

    def get_random_patch_loader():
        return DataLoader(dataset_train, args.num_random_patches, shuffle=True)

    random_patch_loader = get_random_patch_loader()
    data_loader_train = DataLoader(dataset_train,
                                   args.sub_batch_size,
                                   shuffle=True)

    params = list(res_encoder_model.parameters()) + list(
        context_predictor_model.parameters())
    optimizer = torch.optim.Adam(params=params, lr=0.00001)

    # Number of images whose gradients were accumulated since the last step.
    images_accumulated = 0
    # Kept as Python floats so they print and are written to CSV as numbers.
    batch_loss = 0.0
    sum_batch_loss = 0.0
    best_batch_loss = 1e10
    # Persists across sub-batches: if a fresh draw fails, the previous
    # sub-batch's negatives are reused.
    random_patches = None

    for batch in data_loader_train:
        img_batch = batch['image'].to(args.device)
        patch_batch = get_patch_tensor_from_image_batch(img_batch)
        batch_size = len(img_batch)

        # Encode every patch, then arrange as (batch, encoding_dim, 7, 7).
        patches_encoded = res_encoder_model(patch_batch)
        patches_encoded = patches_encoded.view(batch_size, 7, 7, -1)
        patches_encoded = patches_encoded.permute(0, 3, 1, 2)

        # Draw one set of negatives. If the loader reports it is exhausted,
        # recreate it and retry once; stop at the first successful draw so no
        # drawn batch is thrown away.
        for _attempt in range(2):
            patches_return = get_random_patches(random_patch_loader,
                                                args.num_random_patches)
            if not patches_return['is_data_loader_finished']:
                random_patches = patches_return['patches_tensor'].to(
                    args.device)
                break
            random_patch_loader = get_random_patch_loader()
        if random_patches is None:
            raise RuntimeError(
                "could not draw random patches from the training set")

        enc_random_patches = res_encoder_model(random_patches)

        # TODO: vectorize the context_predictor_model - stack all 3x3 contexts together
        predictions, locations = context_predictor_model(patches_encoded)

        # `predictions` holds one run of batch_size rows per predicted grid
        # location; the entry of `locations` at the run's first row names it.
        losses = []
        for b in range(len(predictions) // batch_size):
            b_idx_start = b * batch_size
            b_idx_end = b_idx_start + batch_size

            p_y = locations[b_idx_start][0]
            p_x = locations[b_idx_start][1]

            target = patches_encoded[:, :, p_y, p_x]
            pred = predictions[b_idx_start:b_idx_end]

            # Index 0 holds the positive (true encoding); the rest are negatives.
            dot_terms = [torch.unsqueeze(dot(pred, target), dim=0)]
            for random_patch_idx in range(args.num_random_patches):
                negative = enc_random_patches[random_patch_idx:
                                              random_patch_idx + 1]
                dot_terms.append(torch.unsqueeze(dot(pred, negative), dim=0))

            log_softmax = torch.log_softmax(torch.cat(dot_terms, dim=0), dim=0)
            losses.append(-log_softmax[0])

        all_losses = torch.cat(losses)
        loss = torch.mean(all_losses)
        loss.backward()

        images_accumulated += batch_size
        batch_loss += loss.item()
        sum_batch_loss += all_losses.sum().item()

        if images_accumulated < args.batch_size:
            continue

        optimizer.step()
        optimizer.zero_grad()

        print(f"{datetime.datetime.now()} Loss: {batch_loss}")
        print(f"{datetime.datetime.now()} SUM Loss: {sum_batch_loss}")

        # File names (including the "ecoder" spelling) are kept unchanged so
        # existing checkpoint paths still match.
        torch.save(
            res_encoder_model.state_dict(),
            os.path.join(models_store_path, "last_res_ecoder_weights.pt"))
        torch.save(
            context_predictor_model.state_dict(),
            os.path.join(models_store_path,
                         "last_context_predictor_weights.pt"))

        if best_batch_loss > batch_loss:
            best_batch_loss = batch_loss
            torch.save(
                res_encoder_model.state_dict(),
                os.path.join(models_store_path,
                             "best_res_ecoder_weights.pt"))
            torch.save(
                context_predictor_model.state_dict(),
                os.path.join(models_store_path,
                             "best_context_predictor_weights.pt"))

        stats = dict(batch_loss=batch_loss, sum_batch_loss=sum_batch_loss)
        write_csv_stats(stats_csv_path, stats)

        images_accumulated = 0
        batch_loss = 0.0
        sum_batch_loss = 0.0
# Example #2
0
        predicted_class = dataset_test.get_class_name(predictions[img_idx])
        actual_class = dataset_test.get_class_name(truth[img_idx])
        axes[row, col].set_title(
            f"Predicted class {predicted_class} \n but actually {actual_class}"
        )

    plt.savefig(f"{image_name}.jpg")
    plt.close()


# Set NUM_CLASSES to None to fall back to the dataset's own class count
# (presumably every class under data_path — confirm with get_imagenet_datasets).
NUM_CLASSES = 30

data_path = "/home/martin/ai/ImageNet-datasets-downloader/images_4/imagenet_images"
dataset_train, dataset_test = get_imagenet_datasets(data_path,
                                                    num_classes=NUM_CLASSES,
                                                    random_seed=422)
print(dataset_train.get_class_names())

if NUM_CLASSES is None:
    NUM_CLASSES = dataset_train.get_number_of_classes()

NUM_TRAIN_SAMPLES = dataset_train.get_number_of_samples()
NUM_TEST_SAMPLES = dataset_test.get_number_of_samples()

print(f"train_samples  {NUM_TRAIN_SAMPLES} test_samples {NUM_TEST_SAMPLES}")

print(f'num_classes {NUM_CLASSES}')
NUM_EPOCHS = 50
BATCH_SIZE = 2
LEARNING_RATE = 1e-5
        img = np.transpose(img,(1,2,0))

        axes[row, col].imshow(img)

        predicted_class = dataset_test.get_class_name(predictions[img_idx])
        actual_class = dataset_test.get_class_name(truth[img_idx])
        axes[row, col].set_title(f"Predicted class {predicted_class} \n but actually {actual_class}")

    plt.savefig(f"{image_name}.jpg")
    plt.close()

# Set NUM_CLASSES to None to fall back to the dataset's own class count
# (presumably every class under data_path — confirm with get_imagenet_datasets).
NUM_CLASSES = 1000

data_path = "/Users/martinsf/ai/deep_learning_projects/data/imagenet_images"
dataset_train, dataset_test = get_imagenet_datasets(data_path, num_classes=NUM_CLASSES)

if NUM_CLASSES is None:
    NUM_CLASSES = dataset_train.get_number_of_classes()

NUM_TRAIN_SAMPLES = dataset_train.get_number_of_samples()
NUM_TEST_SAMPLES = dataset_test.get_number_of_samples()

print(f"train_samples  {NUM_TRAIN_SAMPLES} test_samples {NUM_TEST_SAMPLES}")

print(f'num_classes {NUM_CLASSES}')
NUM_EPOCHS = 10000
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
DEVICE = 'cpu'  # switch to 'cuda' to train on a GPU
def run_classificator(args, res_classificator_model, res_encoder_model, models_store_path):

    print("RUNNING CLASSIFICATOR TRAINING")
    dataset_train, dataset_test = get_imagenet_datasets(args.image_folder, num_classes = args.num_classes, train_split = 0.2, random_seed = 42)

    stats_csv_path = os.path.join(models_store_path, "classification_stats.csv")

    EPOCHS = 500
    LABELS_PER_CLASS = 25 # not used yet

    data_loader_train = DataLoader(dataset_train, args.sub_batch_size, shuffle = True)
    data_loader_test = DataLoader(dataset_test, args.sub_batch_size, shuffle = True)

    NUM_TRAIN_SAMPLES = dataset_train.get_number_of_samples()
    NUM_TEST_SAMPLES = dataset_test.get_number_of_samples()

    # params = list(res_classificator_model.parameters()) + list(res_encoder_model.parameters())
    optimizer_enc = torch.optim.Adam(params = res_encoder_model.parameters(), lr = 0.00001) # Train encoder slower than the classifier layers
    optimizer_cls = torch.optim.Adam(params = res_classificator_model.parameters(), lr = 0.001)

    best_epoch_test_loss = 1e10

    for epoch in range(EPOCHS):

        sub_batches_processed = 0

        epoch_train_true_positives = 0.0
        epoch_train_loss = 0.0
        epoch_train_losses = []

        batch_train_loss = 0.0
        batch_train_true_positives = 0.0

        epoch_test_true_positives = 0.0
        epoch_test_loss = 0.0

        epoch_test_losses = []

        for batch in data_loader_train:

            img_batch = batch['image'].to(args.device)

            patch_batch = get_patch_tensor_from_image_batch(img_batch)
            patches_encoded = res_encoder_model.forward(patch_batch)

            patches_encoded = patches_encoded.view(img_batch.shape[0], 7,7,-1)
            patches_encoded = patches_encoded.permute(0,3,1,2)

            classes = batch['cls'].to(args.device)

            y_one_hot = torch.zeros(img_batch.shape[0], args.num_classes).to(args.device)
            y_one_hot = y_one_hot.scatter_(1, classes.unsqueeze(dim=1), 1)

            labels = batch['class_name']

            pred = res_classificator_model.forward(patches_encoded)
            loss = torch.sum(-y_one_hot * torch.log(pred))
            epoch_train_losses.append(loss.detach().to('cpu').numpy())
            epoch_train_loss += loss.detach().to('cpu').numpy()
            batch_train_loss += loss.detach().to('cpu').numpy()

            epoch_train_true_positives += torch.sum(pred.argmax(dim=1) == classes)
            batch_train_true_positives += torch.sum(pred.argmax(dim=1) == classes)

            loss.backward()
            sub_batches_processed += img_batch.shape[0]

            if sub_batches_processed >= args.batch_size:

                optimizer_enc.step()
                optimizer_cls.step()

                optimizer_enc.zero_grad()
                optimizer_cls.zero_grad()

                sub_batches_processed = 0

                batch_train_accuracy = float(batch_train_true_positives) / float(args.batch_size)

                print(f"Training loss of batch is {batch_train_loss}")
                print(f"Accuracy of batch is {batch_train_accuracy}")

                batch_train_loss = 0.0
                batch_train_true_positives = 0.0


        epoch_train_accuracy = float(epoch_train_true_positives) / float(NUM_TRAIN_SAMPLES)

        print(f"Training loss of epoch {epoch} is {epoch_train_loss}")
        print(f"Accuracy of epoch {epoch} is {epoch_train_accuracy}")

        with torch.no_grad():

            res_classificator_model.eval()
            res_encoder_model.eval()

            for batch in data_loader_test:

                img_batch = batch['image'].to(args.device)

                patch_batch = get_patch_tensor_from_image_batch(img_batch)
                patches_encoded = res_encoder_model.forward(patch_batch)

                patches_encoded = patches_encoded.view(img_batch.shape[0], 7,7,-1)
                patches_encoded = patches_encoded.permute(0,3,1,2)

                classes = batch['cls'].to(args.device)

                y_one_hot = torch.zeros(img_batch.shape[0], args.num_classes).to(args.device)
                y_one_hot = y_one_hot.scatter_(1, classes.unsqueeze(dim=1), 1)

                labels = batch['class_name']

                pred = res_classificator_model.forward(patches_encoded)
                loss = torch.sum(-y_one_hot * torch.log(pred))
                epoch_test_losses.append(loss.detach().to('cpu').numpy())
                epoch_test_loss += loss.detach().to('cpu').numpy()

                epoch_test_true_positives += torch.sum(pred.argmax(dim=1) == classes)

            epoch_test_accuracy = float(epoch_test_true_positives) / float(NUM_TEST_SAMPLES)

            print(f"Test loss of epoch {epoch} is {epoch_test_loss}")
            print(f"Test accuracy of epoch {epoch} is {epoch_test_accuracy}")

        res_classificator_model.train()
        res_encoder_model.train()


        torch.save(res_encoder_model.state_dict(), os.path.join(models_store_path, "last_res_ecoder_weights.pt"))
        torch.save(res_classificator_model.state_dict(), os.path.join(models_store_path, "last_res_classificator_weights.pt"))

        if best_epoch_test_loss > epoch_test_loss:

            best_epoch_test_loss = epoch_test_loss
            torch.save(res_encoder_model.state_dict(), os.path.join(models_store_path, "best_res_ecoder_weights.pt"))
            torch.save(res_classificator_model.state_dict(), os.path.join(models_store_path, "best_res_classificator_weights.pt"))


        stats = dict(
            epoch = epoch,
            train_acc = epoch_train_accuracy,
            train_loss = epoch_train_loss,
            test_acc = epoch_test_accuracy,
            test_loss = epoch_test_loss
        )

        print(f"Writing dict {stats} into file {stats_csv_path}")
        write_csv_stats(stats_csv_path, stats)