Beispiel #1
0
def get_model(base_model,
              dataset_name=False,
              compile=True,
              weights=None,
              epsilon=1e-8,
              teacher_epsilon=1e-3,
              init_temp=2.5):
    """Take an uncompiled model and return it compiled for ENDD.

    Warning: This function works in place. The model is returned only for
    convenience.

    Args:
        base_model (str or keras.Model): Either a model object, or the name of
            a predefined architecture ('cnn' or 'vgg'). A named architecture
            is built uncompiled and without a final softmax.
        dataset_name (str): Name of the dataset. Required when base_model is
            given by name; not used otherwise.
        compile (bool): If True, compile the model with the 'adam' optimizer
            and losses.DirichletEnDDLoss. Default True.
        weights: If truthy, passed to saveload.load_weights to load weights
            into the model (before compiling). Default None.
        epsilon (float): Forwarded to DirichletEnDDLoss as ``epsilon``.
        teacher_epsilon (float): Forwarded to DirichletEnDDLoss as
            ``ensemble_epsilon``.
        init_temp (float): Forwarded to DirichletEnDDLoss as ``init_temp``.

    Returns:
        The (possibly built, weight-loaded and compiled) model.

    Raises:
        ValueError: If base_model is a name and dataset_name is not given, or
            if the name is not a recognized architecture.
    """
    if isinstance(base_model, str):
        if not dataset_name:
            raise ValueError(
                'dataset_name must be provided if base_model is given by name.'
            )
        if base_model == 'cnn':
            base_model = cnn.get_model(dataset_name,
                                       compile=False,
                                       softmax=False)
        elif base_model == 'vgg':
            base_model = vgg.get_model(dataset_name,
                                       compile=False,
                                       softmax=False)
        else:
            # Report the unrecognized name in the error message.
            raise ValueError(
                "Base model {} not recognized, make sure it has been added "
                "to endd.py, or pass a Keras model object as base model "
                "instead.".format(base_model))

    if weights:
        saveload.load_weights(base_model, weights)

    if compile:
        base_model.compile(optimizer='adam',
                           loss=losses.DirichletEnDDLoss(
                               init_temp=init_temp,
                               epsilon=epsilon,
                               ensemble_epsilon=teacher_epsilon))
    return base_model
Beispiel #2
0
def train_vgg_endd(train_images,
                   ensemble_model,
                   dataset_name,
                   batch_size=128,
                   n_epochs=90,
                   one_cycle_lr_policy=True,
                   init_lr=0.001,
                   cycle_length=60,
                   temp_annealing=True,
                   init_temp=10,
                   dropout_rate=0.3,
                   save_endd_dataset=False,
                   load_previous_endd_dataset=False,
                   repetition=None):
    """Return a trained VGG ENDD model.

    The save_endd_dataset and load_previous_endd_dataset arguments are useful to avoid having to
    re-create the ensemble predictions.

    Args:
        train_images (np.array): Normalized train images, potentially including AUX data.
        ensemble_model (models.ensemble.Ensemble): Ensemble to distill. When a previous dataset
            is loaded, only ``len(ensemble_model.models)`` is used.
        dataset_name (str): Name of dataset (required for loading correct model settings).
        batch_size (int): Batch size to use while training. Default 128.
        n_epochs (int): Number of epochs to train. Default 90.
        one_cycle_lr_policy (bool): True if one cycle LR policy should be used. Default True.
        init_lr (float): Initial learning rate for one cycle LR. The policy is given
            ``init_lr * 10`` as max LR and ``init_lr / 1000`` as min LR. Default 0.001.
        cycle_length (int): Cycle length passed to the one-cycle LR and temperature-annealing
            callbacks (presumably measured in epochs). Default 60.
        temp_annealing (bool): True if temperature annealing should be used. Default True.
        init_temp (float): Initial temperature, passed to the annealing callback and to
            endd.get_model. Default 10.
        dropout_rate (float): Probability to drop node. Default 0.3.
        save_endd_dataset (bool): True if the (train_images, ensemble predictions) pair should
            be pickled to 'train_endd_dataset.pkl' (useful for speeding up repeated training
            with the same ensemble). Default False.
        load_previous_endd_dataset (bool): True if the ENDD dataset should be loaded from
            'train_endd_dataset_100.pkl' instead of evaluating the ensemble. The loaded images
            replace ``train_images``, and only the first ``len(ensemble_model.models)``
            ensemble members' predictions are kept. Note that this is not the file written by
            save_endd_dataset. Default False.
        repetition (int, optional): Repetition index; only used to build a per-repetition
            file name, which is not currently used for saving or loading. Default None.

    Returns:
        (keras.Model): Trained VGG ENDD model.
    """

    nr_models = len(ensemble_model.models)
    if repetition is None:
        save_str = 'train_endd_dataset_{}.pkl'.format(nr_models)
    else:
        save_str = 'train_endd_dataset_rep={}_{}.pkl'.format(repetition, nr_models)
    # NOTE(review): save_str is not used by the save/load code below, which uses fixed file
    # names. Confirm whether it was meant to be the save (and/or load) path.

    # Generating ensemble preds takes a long time, so loading previously pickled data can make
    # testing much more efficient.
    if load_previous_endd_dataset:
        # NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
        with open('train_endd_dataset_100.pkl', 'rb') as fh:
            train_images, train_ensemble_preds = pickle.load(fh)
        # Load the particular amount only: keep the first nr_models members along axis 1
        # (assumes axis 1 indexes ensemble members — TODO confirm).
        train_ensemble_preds = train_ensemble_preds[:, :nr_models, :]
        print("loaded")
    else:
        # Get ensemble preds
        print("Evaluating")
        train_ensemble_preds = datasets.get_ensemble_preds(ensemble_model, train_images)
        print("Evaluated")

    # Save pickled data so later runs can skip the ensemble evaluation.
    if save_endd_dataset:
        with open('train_endd_dataset.pkl', 'wb') as fh:
            pickle.dump((train_images, train_ensemble_preds), fh)
        print("saved")

    # Image augmentation
    data_generator = preprocessing.make_augmented_generator(train_images, train_ensemble_preds,
                                                            batch_size)

    # Callbacks
    endd_callbacks = []
    if one_cycle_lr_policy:
        olp_callback = callbacks.OneCycleLRPolicy(init_lr=init_lr,
                                                  max_lr=init_lr * 10,
                                                  min_lr=init_lr / 1000,
                                                  cycle_length=cycle_length,
                                                  epochs=n_epochs)
        endd_callbacks.append(olp_callback)

    if temp_annealing:
        temp_callback = callbacks.TemperatureAnnealing(init_temp=init_temp,
                                                       cycle_length=cycle_length,
                                                       epochs=n_epochs)
        endd_callbacks.append(temp_callback)

    # Pass None rather than an empty list when no callbacks were requested.
    if not endd_callbacks:
        endd_callbacks = None

    # Build ENDD model
    base_model = vgg.get_model(dataset_name,
                               compile=False,
                               dropout_rate=dropout_rate,
                               softmax=False)
    endd_model = endd.get_model(base_model, init_temp=init_temp, teacher_epsilon=1e-4)

    # Train model
    endd_model.fit(data_generator, epochs=n_epochs, callbacks=endd_callbacks)

    return endd_model
Beispiel #3
0
    train_images = train_images - 1.0
    test_images = test_images / 127.5
    test_images = test_images - 1.0

# One-hot encode the labels into 10 classes (matching the 'cifar10' model below).
# reshape((-1,)) flattens the label arrays to 1-D first (presumably they arrive
# as (N, 1) — TODO confirm).
train_labels = tf.one_hot(train_labels.reshape((-1,)), 10)
test_labels = tf.one_hot(test_labels.reshape((-1,)), 10)

# Image augmentation: random rotations (rotation_range=15), horizontal flips and
# width/height shifts (range 4); pixels shifted in from outside the image are
# filled with the nearest value (fill_mode='nearest').
data_generator = tf.keras.preprocessing.image.ImageDataGenerator(rotation_range=15,
                                                                 horizontal_flip=True,
                                                                 width_shift_range=4,
                                                                 height_shift_range=4,
                                                                 fill_mode='nearest')

# Get model: VGG for CIFAR-10, compiled inside vgg.get_model (compile=True).
model = vgg.get_model(dataset_name='cifar10', compile=True)

model.summary()

# Set-up TensorBoard: each run logs to its own timestamped directory under
# logs/fit/; histogram_freq=1 requests histograms every epoch.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Train

# Training hyperparameters; init_lr is also the starting LR of the one-cycle
# LR policy configured next.
epochs = 45
init_lr = 0.001
olp_callback = OneCycleLRPolicy(init_lr=init_lr,
                                max_lr=10 * init_lr,
                                min_lr=init_lr / 1000,
                                cycle_length=30,