Пример #1
0
    def image_generator_with_top_policies(self,
                                          images,
                                          labels,
                                          batch_size=None):
        """Build an augmented-image generator driven by the stored top policies.

        Args:
            images (numpy.array): array with shape (N, dim, dim, channel-size)
            labels (numpy.array): array with shape (N), where each element is
                an integer from 0 to num_classes-1
            batch_size (int): batch size of the generator; when None, the
                value of ``self.config["child_batch_size"]`` is used instead
        Returns:
            generator: generator for augmented images
        """
        # Only consult the config when the caller gave no explicit batch size.
        effective_batch_size = (self.config["child_batch_size"]
                                if batch_size is None else batch_size)

        # Sub-policies A-E, each with two augmentations described by a
        # type column and a magnitude column.
        policy_columns = [
            'A_aug1_type', 'A_aug1_magnitude',
            'A_aug2_type', 'A_aug2_magnitude',
            'B_aug1_type', 'B_aug1_magnitude',
            'B_aug2_type', 'B_aug2_magnitude',
            'C_aug1_type', 'C_aug1_magnitude',
            'C_aug2_type', 'C_aug2_magnitude',
            'D_aug1_type', 'D_aug1_magnitude',
            'D_aug2_type', 'D_aug2_magnitude',
            'E_aug1_type', 'E_aug1_magnitude',
            'E_aug2_type', 'E_aug2_magnitude',
        ]

        # self.top_policies is presumably a pandas DataFrame (TODO confirm);
        # each selected row becomes one {column: value} dict.
        selected_policies = self.top_policies[policy_columns]
        policy_records = selected_policies.to_dict(orient="records")

        return deepaugment_image_generator(
            images,
            labels,
            policy_records,
            batch_size=effective_batch_size,
        )
Пример #2
0
def run_model(dataset_name, num_classes, epochs, batch_size, policies_path):
    """Train a "wrn_28_10" ChildCNN on a named dataset and print the result.

    The dataset is loaded and preprocessed through ``DataOp``. Training
    metrics are streamed to a CSV file under ``EXPERIMENT_FOLDER_PATH``,
    and the last recorded ``val_acc`` value is printed at the end.

    Args:
        dataset_name: name handed to ``DataOp.load``; also embedded in the
            CSV log file name.
        num_classes: number of classes, forwarded to ``ChildCNN``.
        epochs: number of training epochs.
        batch_size: batch size for the model and for the augmenting
            generator.
        policies_path: the sentinel string ``"dont_augment"`` to train on
            the raw data, otherwise the policies argument forwarded to
            ``deepaugment_image_generator``.
    """
    data, input_shape = DataOp.load(dataset_name)
    data = DataOp.preprocess_normal(data)

    child_model = ChildCNN(
        model_name="wrn_28_10",
        input_shape=input_shape,
        batch_size=batch_size,
        num_classes=num_classes,
        pre_augmentation_weights_path="initial_model_weights.h5",
        logging=logging,
    )

    # Decide once whether augmentation is requested; the flag drives both
    # the log file name and the training path.
    use_augmentation = policies_path != "dont_augment"
    policy_str = "augmented" if use_augmentation else "non_augmented"
    csv_logger = CSVLogger(
        f"{EXPERIMENT_FOLDER_PATH}/wrn_28_10_training_on_{dataset_name}_{policy_str}.csv"
    )

    if use_augmentation:
        datagen = deepaugment_image_generator(
            data["X_train"],
            data["y_train"],
            policies_path,
            batch_size=batch_size,
            augment_chance=0.8,
        )
        print("fitting the model")
        history = child_model.fit_with_generator(
            datagen,
            data["X_val"],
            data["y_val"],
            train_data_size=len(data["X_train"]),
            epochs=epochs,
            csv_logger=csv_logger,
        )
    else:
        history = child_model.fit_normal(data,
                                         epochs=epochs,
                                         csv_logger=csv_logger)

    # Both training paths report the final validation accuracy the same way.
    print(f"Reached validation accuracy is {history['val_acc'][-1]}")
Пример #3
0
def run_full_model(images,
                   labels,
                   test_proportion=0.1,
                   model="wrn_28_10",
                   epochs=200,
                   batch_size=64,
                   policies_path="dont_augment"):
    """Train a full ChildCNN on the given images and print the result.

    The data is shuffled and split into training and validation sets,
    preprocessed through ``DataOp.preprocess_normal``, and used to train a
    ``ChildCNN`` either directly or through an augmenting generator.
    Training metrics are streamed to a CSV file under
    ``EXPERIMENT_FOLDER_PATH`` and the last recorded ``val_acc`` value is
    printed at the end.

    Args:
        images: image samples, passed to ``train_test_split``.
        labels: labels aligned with ``images``.
        test_proportion: fraction of the samples held out for validation.
        model: model identifier stored under ``"model"`` in the ChildCNN
            config.
        epochs: number of training epochs.
        batch_size: batch size for the model and for the augmenting
            generator.
        policies_path: the sentinel string ``"dont_augment"`` to train on
            the raw data, otherwise the policies argument forwarded to
            ``deepaugment_image_generator``.

    Returns:
        None. Results are reported via the CSV log and standard output.
    """
    data = {}
    (data["X_train"], data["X_val"],
     data["y_train"], data["y_val"]) = train_test_split(
         images,
         labels,
         test_size=test_proportion,
         shuffle=True)

    data = DataOp.preprocess_normal(data)

    input_shape = data["X_train"][0].shape
    # Assumes preprocess_normal returns 2-D labels (presumably one-hot), so
    # the second axis is the class count -- TODO confirm.
    num_classes = data["y_train"].shape[1]

    cnn_config = {
        "model": model,
        "weights": "imagenet",
        "input_shape": input_shape,
        "child_batch_size": batch_size,
        "pre_augmentation_weights_path": "initial_model_weights.h5",
        "logging": logging,
    }

    full_model = ChildCNN(input_shape=input_shape,
                          num_classes=num_classes,
                          config=cnn_config)

    use_augmentation = policies_path != "dont_augment"
    policy_str = "augmented" if use_augmentation else "non_augmented"
    # NOTE(review): the log name hard-codes "wrn_28_10" even when `model`
    # differs; left unchanged so existing log paths stay stable.
    csv_logger = CSVLogger(
        f"{EXPERIMENT_FOLDER_PATH}/wrn_28_10_training_on_{policy_str}.csv")

    if use_augmentation:
        datagen = deepaugment_image_generator(
            data["X_train"],
            data["y_train"],
            policies_path,
            batch_size=batch_size,
            augment_chance=0.8,
        )
        # Debug dump of the generator object. Text-mode write() only accepts
        # str, so the object is converted explicitly; the context manager
        # closes the file even if the write fails.
        with open('datagen.txt', 'w', encoding="utf-8") as dump_file:
            dump_file.write(str(datagen))
        print("fitting the model")

        history = full_model.fit_with_generator(
            datagen,
            data["X_val"],
            data["y_val"],
            train_data_size=len(data["X_train"]),
            epochs=epochs,
            csv_logger=csv_logger,
        )
    else:
        history = full_model.fit_normal(data,
                                        epochs=epochs,
                                        csv_logger=csv_logger)

    # Both training paths report the final validation accuracy the same way.
    print(f"Reached validation accuracy is {history['val_acc'][-1]}")