コード例 #1
0
def run(_run, image_shape, data_dir, train_pairs, valid_pairs, classes,
        class_weight, architecture, weights, batch_size, base_layers, pooling,
        dense_layers, metrics, device, opt_params, dropout_p,
        resuming_from_ckpt_file, steps_per_epoch, epochs, validation_steps,
        workers, use_multiprocessing, initial_epoch, early_stop_patience,
        tensorboard_tag, first_trainable_layer):
    """Train a siamese gram model with contrastive loss over balanced pairs.

    TensorBoard logs and the best checkpoint (``weights.h5``) are written
    to the sacred observer's directory.
    """
    report_dir = _run.observers[0].dir

    # Fixed: ``zoom_range`` and ``rotation_range`` had their values swapped.
    # ``zoom_range=45`` zooms by a factor in [-44, 46] (nonsensical) and
    # ``rotation_range=.2`` rotates by at most 0.2 degrees (a no-op):
    # zoom is a fraction, rotation is in degrees.
    g = ImageDataGenerator(
        horizontal_flip=True,
        vertical_flip=True,
        samplewise_center=True,
        samplewise_std_normalization=True,
        zoom_range=.2,
        rotation_range=45,
        height_shift_range=.2,
        width_shift_range=.2,
        fill_mode='reflect',
        preprocessing_function=get_preprocess_fn(architecture))

    # An integer ``classes`` selects the first n class folders.
    if isinstance(classes, int):
        classes = sorted(os.listdir(os.path.join(data_dir, 'train')))[:classes]

    train_data = BalancedDirectoryPairsSequence(os.path.join(
        data_dir, 'train'),
                                                g,
                                                target_size=image_shape[:2],
                                                pairs=train_pairs,
                                                classes=classes,
                                                batch_size=batch_size)
    valid_data = BalancedDirectoryPairsSequence(os.path.join(
        data_dir, 'valid'),
                                                g,
                                                target_size=image_shape[:2],
                                                pairs=valid_pairs,
                                                classes=classes,
                                                batch_size=batch_size)

    if class_weight == 'balanced':
        class_weight = get_class_weights(train_data.classes)

    with tf.device(device):
        print('building...')
        model = build_siamese_gram_model(image_shape,
                                         architecture,
                                         dropout_p,
                                         weights,
                                         base_layers=base_layers,
                                         dense_layers=dense_layers,
                                         pooling=pooling,
                                         include_top=False,
                                         trainable_limbs=True,
                                         embedding_units=0,
                                         joints='l2',
                                         include_base_top=False)
        model.summary()

        layer_names = [l.name for l in model.layers]

        if first_trainable_layer:
            if first_trainable_layer not in layer_names:
                raise ValueError('%s is not a layer in the model: %s' %
                                 (first_trainable_layer, layer_names))

            # Freeze every layer before ``first_trainable_layer``.
            for layer in model.layers:
                if layer.name == first_trainable_layer:
                    break
                layer.trainable = False

        model.compile(optimizer=optimizers.Adam(**opt_params),
                      metrics=metrics,
                      loss=contrastive_loss)

        if resuming_from_ckpt_file:
            print('re-loading weights...')
            model.load_weights(resuming_from_ckpt_file)

        print('training from epoch %i...' % initial_epoch)
        try:
            model.fit_generator(
                train_data,
                steps_per_epoch=steps_per_epoch,
                epochs=epochs,
                verbose=2,
                validation_data=valid_data,
                validation_steps=validation_steps,
                initial_epoch=initial_epoch,
                class_weight=class_weight,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                callbacks=[
                    callbacks.TerminateOnNaN(),
                    callbacks.ReduceLROnPlateau(min_lr=1e-10,
                                                patience=int(
                                                    early_stop_patience // 3)),
                    callbacks.EarlyStopping(patience=early_stop_patience),
                    callbacks.TensorBoard(os.path.join(report_dir,
                                                       tensorboard_tag),
                                          batch_size=batch_size),
                    callbacks.ModelCheckpoint(os.path.join(
                        report_dir, 'weights.h5'),
                                              save_best_only=True,
                                              verbose=1),
                ])
        except KeyboardInterrupt:
            print('interrupted by user')
        else:
            print('done')
コード例 #2
0
def run(_run, image_shape, data_dir, train_pairs, valid_pairs, classes,
        num_classes, architecture, weights, batch_size, base_layers, pooling,
        device, predictions_activation, opt_params, dropout_rate,
        resuming_ckpt, ckpt, steps_per_epoch, epochs, validation_steps, joints,
        workers, use_multiprocessing, initial_epoch, early_stop_patience,
        dense_layers, embedding_units, limb_weights, trainable_limbs,
        tensorboard_tag):
    """Train a siamese gram model to discriminate pairs of painting images.

    TensorBoard logs and the best checkpoint (``ckpt``) are written to the
    sacred observer's directory.
    """
    report_dir = _run.observers[0].dir

    # An integer ``classes`` selects the first n class folders.
    if isinstance(classes, int):
        classes = sorted(os.listdir(os.path.join(data_dir, 'train')))[:classes]

    augmenter = ImageDataGenerator(
        horizontal_flip=True,
        vertical_flip=True,
        zoom_range=.2,
        rotation_range=.2,
        height_shift_range=.2,
        width_shift_range=.2,
        fill_mode='reflect',
        preprocessing_function=utils.get_preprocess_fn(architecture))

    # Build one balanced pair sequence per phase, sharing the augmenter.
    sequences = {}
    for phase, pairs in (('train', train_pairs), ('valid', valid_pairs)):
        sequences[phase] = BalancedDirectoryPairsSequence(
            os.path.join(data_dir, phase),
            augmenter,
            target_size=image_shape[:2],
            pairs=pairs,
            classes=classes,
            batch_size=batch_size)
    train_data, valid_data = sequences['train'], sequences['valid']

    # Default step counts: one pass over each sequence.
    if steps_per_epoch is None:
        steps_per_epoch = len(train_data)
    if validation_steps is None:
        validation_steps = len(valid_data)

    with tf.device(device):
        print('building...')

        model = build_siamese_gram_model(
            image_shape,
            architecture,
            dropout_rate,
            weights,
            num_classes,
            base_layers,
            dense_layers,
            pooling,
            predictions_activation=predictions_activation,
            limb_weights=limb_weights,
            trainable_limbs=trainable_limbs,
            embedding_units=embedding_units,
            joints=joints)
        print('siamese model summary:')
        model.summary()

        if resuming_ckpt:
            print('loading weights...')
            model.load_weights(resuming_ckpt)

        model.compile(optimizer=optimizers.Adam(**opt_params),
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

        print('training from epoch %i...' % initial_epoch)
        training_callbacks = [
            callbacks.TerminateOnNaN(),
            callbacks.EarlyStopping(patience=early_stop_patience),
            callbacks.ReduceLROnPlateau(
                min_lr=1e-10, patience=int(early_stop_patience // 3)),
            callbacks.TensorBoard(os.path.join(report_dir, tensorboard_tag),
                                  batch_size=batch_size),
            callbacks.ModelCheckpoint(os.path.join(report_dir, ckpt),
                                      save_best_only=True,
                                      verbose=1),
        ]
        try:
            model.fit_generator(train_data,
                                steps_per_epoch=steps_per_epoch,
                                epochs=epochs,
                                validation_data=valid_data,
                                validation_steps=validation_steps,
                                initial_epoch=initial_epoch,
                                use_multiprocessing=use_multiprocessing,
                                workers=workers,
                                verbose=2,
                                callbacks=training_callbacks)
        except KeyboardInterrupt:
            print('interrupted by user')
        else:
            print('done')
コード例 #3
0
def run(dataset_seed, image_shape, batch_size, device, data_dir, output_dir,
        phases, architecture,
        o_meta, limb_weights, joint_weights, weights, pooling,
        dense_layers, use_gram_matrix, last_base_layer, override,
        embedded_files_max_size, selected_layers):
    """Embed images through selected layers of a trained siamese limb.

    For each phase directory under ``data_dir``, every image is pushed
    through the network and the activations of ``selected_layers``
    (optionally Gram-transformed) are pickled to
    ``<output_dir>/<phase>.<part_id>.pickle``, chunked so each part holds
    roughly at most ``embedded_files_max_size`` bytes of features.
    """
    os.makedirs(output_dir, exist_ok=True)

    with tf.device(device):
        print('building model...')
        model = build_siamese_model(image_shape, architecture, 0.0, weights,
                                    last_base_layer=last_base_layer,
                                    use_gram_matrix=use_gram_matrix,
                                    dense_layers=dense_layers, pooling=pooling,
                                    include_base_top=False, include_top=True,
                                    trainable_limbs=False,
                                    limb_weights=limb_weights,
                                    predictions_activation=[o['a'] for o in o_meta],
                                    predictions_name=[o['n'] for o in o_meta],
                                    classes=[o['u'] for o in o_meta],
                                    embedding_units=[o['e'] for o in o_meta],
                                    joints=[o['j'] for o in o_meta])
        # Restore best parameters.
        print('loading weights from:', joint_weights)
        model.load_weights(joint_weights)
        # Keep only a single limb of the siamese network.
        # NOTE(review): 'model_2' is a Keras auto-assigned sub-model name --
        # assumes the limb received that name at build time; confirm.
        model = model.get_layer('model_2')

        available_layers = [l.name for l in model.layers]
        if set(selected_layers) - set(available_layers):
            print('available layers:', available_layers)
            raise ValueError('selection contains unknown layers: %s' % selected_layers)

        # One output tensor per selected layer, in selection order.
        style_features = [model.get_layer(l).output for l in selected_layers]

        if use_gram_matrix:
            # Replace each feature map by its (un-normalized) Gram matrix.
            gram_layer = layers.Lambda(gram_matrix, arguments=dict(norm_by_channels=False))
            style_features = [gram_layer(f) for f in style_features]

        model = Model(inputs=model.inputs, outputs=style_features)

    g = ImageDataGenerator(preprocessing_function=get_preprocess_fn(architecture))

    for phase in phases:
        phase_data_dir = os.path.join(data_dir, phase)
        output_file_name = os.path.join(output_dir, phase + '.%i.pickle')
        already_embedded = os.path.exists(output_file_name % 0)
        phase_exists = os.path.exists(phase_data_dir)

        # Skip when (already embedded and not overriding) or data is absent.
        if already_embedded and not override or not phase_exists:
            print('%s transformation skipped' % phase)
            continue

        # Shuffle must always be off in order to keep names consistent.
        data = g.flow_from_directory(phase_data_dir,
                                     target_size=image_shape[:2],
                                     class_mode='sparse',
                                     batch_size=batch_size, shuffle=False,
                                     seed=dataset_seed)
        print('transforming %i %s samples from %s' % (data.n, phase, phase_data_dir))
        part_id = 0
        samples_seen = 0
        displayed_once = False

        while samples_seen < data.n:
            # Accumulate per-layer features until the chunk budget is hit.
            z, y = {n: [] for n in selected_layers}, []
            chunk_size = 0
            chunk_start = samples_seen

            while chunk_size < embedded_files_max_size and samples_seen < data.n:
                _x, _y = next(data)

                outputs = model.predict_on_batch(_x)
                chunk_size += sum(o.nbytes for o in outputs)

                for l, o in zip(selected_layers, outputs):
                    z[l].append(o)

                y.append(_y)
                samples_seen += _x.shape[0]
                chunk_p = int(100 * (samples_seen / data.n))

                # Print each 10% milestone once; dots between milestones.
                if chunk_p % 10 == 0:
                    if not displayed_once:
                        print('\n%i%% (%.2f MB)'
                              % (chunk_p, chunk_size / 1024 ** 2),
                              flush=True, end='')
                        displayed_once = True
                else:
                    displayed_once = False
                    print('.', end='')

            for layer in selected_layers:
                z[layer] = np.concatenate(z[layer])

            # Persist the chunk: features, sparse targets and file names.
            # Names align with targets because shuffling is disabled above.
            with open(output_file_name % part_id, 'wb') as f:
                pickle.dump({'data': z,
                             'target': np.concatenate(y),
                             'names': np.asarray(data.filenames[chunk_start: samples_seen])},
                            f, pickle.HIGHEST_PROTOCOL)
            part_id += 1
    print('done.')
コード例 #4
0
def run(_run, image_shape, data_dir, train_shuffle, dataset_train_seed,
        valid_shuffle, dataset_valid_seed, classes, num_classes, train_info,
        architecture, weights, batch_size, last_base_layer, use_gram_matrix,
        pooling, dense_layers, device, opt_params, dropout_p, resuming_from,
        ckpt_file, steps_per_epoch, epochs, validation_steps, workers,
        use_multiprocessing, initial_epoch, early_stop_patience,
        tensorboard_tag, first_trainable_layer, first_reset_layer,
        class_weight):
    """Train a multi-label (sigmoid) classification network over paintings.

    Labels are taken from ``train_info`` rather than from the directory
    structure: each file's base name is mapped to its multi-label target
    and the generators' ``classes`` arrays are overwritten accordingly.
    TensorBoard logs and the best checkpoint are written to the sacred
    observer's directory.
    """
    report_dir = _run.observers[0].dir

    # Map file base names (without extension) to their label vectors.
    y, fs, encoders = load_labels(train_info)
    label_map = dict(zip([os.path.splitext(f)[0] for f in fs], y))

    g = ImageDataGenerator(
        horizontal_flip=True,
        vertical_flip=True,
        zoom_range=.2,
        rotation_range=.2,
        height_shift_range=.2,
        width_shift_range=.2,
        fill_mode='reflect',
        preprocessing_function=get_preprocess_fn(architecture))

    train_data = g.flow_from_directory(os.path.join(data_dir, 'train'),
                                       target_size=image_shape[:2],
                                       classes=classes,
                                       class_mode='sparse',
                                       batch_size=batch_size,
                                       shuffle=train_shuffle,
                                       seed=dataset_train_seed)

    # Replace directory-derived classes with the multi-label targets.
    # File names look like ``<id>-<patch>.<ext>``; the id keys label_map.
    fs = [os.path.basename(f).split('-')[0] for f in train_data.filenames]
    train_data.classes = np.array([label_map[f] for f in fs])

    valid_data = g.flow_from_directory(os.path.join(data_dir, 'valid'),
                                       target_size=image_shape[:2],
                                       classes=classes,
                                       class_mode='sparse',
                                       batch_size=batch_size,
                                       shuffle=valid_shuffle,
                                       seed=dataset_valid_seed)

    fs = [os.path.basename(f).split('-')[0] for f in valid_data.filenames]
    valid_data.classes = np.array([label_map[f] for f in fs])

    del y, fs, encoders

    if class_weight == 'balanced':
        raise ValueError('class_weight is still a little confusing in '
                         'this multi-label problems')

    if steps_per_epoch is None:
        steps_per_epoch = ceil(train_data.n / batch_size)
    if validation_steps is None:
        validation_steps = ceil(valid_data.n / batch_size)

    with tf.device(device):
        print('building...')
        model = build_model(image_shape,
                            architecture=architecture,
                            weights=weights,
                            dropout_p=dropout_p,
                            classes=num_classes,
                            last_base_layer=last_base_layer,
                            use_gram_matrix=use_gram_matrix,
                            pooling=pooling,
                            dense_layers=dense_layers,
                            predictions_activation='sigmoid')

        layer_names = [l.name for l in model.layers]

        if first_trainable_layer:
            if first_trainable_layer not in layer_names:
                raise ValueError('%s is not a layer in the model: %s' %
                                 (first_trainable_layer, layer_names))

            # Freeze everything before ``first_trainable_layer``.
            _trainable = False
            for layer in model.layers:
                if layer.name == first_trainable_layer:
                    _trainable = True
                layer.trainable = _trainable
            del _trainable

        model.compile(optimizer=optimizers.Adam(**opt_params),
                      metrics=[
                          'binary_accuracy', 'categorical_accuracy',
                          'top_k_categorical_accuracy'
                      ],
                      loss='binary_crossentropy')

        if resuming_from:
            print('re-loading weights...')
            model.load_weights(resuming_from)

        if first_reset_layer:
            if first_reset_layer not in layer_names:
                raise ValueError('%s is not a layer in the model: %s' %
                                 (first_reset_layer, layer_names))
            print('first layer to have its weights reset:', first_reset_layer)
            # Copy freshly-initialized weights over every layer from
            # ``first_reset_layer`` onward.
            random_model = build_model(image_shape,
                                       architecture=architecture,
                                       weights=None,
                                       dropout_p=dropout_p,
                                       classes=num_classes,
                                       last_base_layer=last_base_layer,
                                       use_gram_matrix=use_gram_matrix,
                                       dense_layers=dense_layers,
                                       predictions_activation='sigmoid')
            _reset = False
            for layer, random_layer in zip(model.layers, random_model.layers):
                if layer.name == first_reset_layer:
                    _reset = True
                if _reset:
                    layer.set_weights(random_layer.get_weights())
            del random_model

            # Fixed: the metric name was misspelled 'cateprocal_accuracy',
            # which would raise when Keras resolved the metric.
            model.compile(optimizer=optimizers.Adam(**opt_params),
                          metrics=['categorical_accuracy'],
                          loss='binary_crossentropy')

        print('training from epoch %i...' % initial_epoch)
        try:
            model.fit_generator(
                train_data,
                steps_per_epoch=steps_per_epoch,
                epochs=epochs,
                validation_data=valid_data,
                validation_steps=validation_steps,
                initial_epoch=initial_epoch,
                verbose=2,
                class_weight=None,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                callbacks=[
                    callbacks.ReduceLROnPlateau(min_lr=1e-10,
                                                patience=int(
                                                    early_stop_patience // 3)),
                    callbacks.EarlyStopping(patience=early_stop_patience),
                    # Fixed: TensorBoard and ModelCheckpoint were each
                    # registered twice (once with bare cwd-relative paths,
                    # once under report_dir). Only the report_dir variants
                    # are kept, matching the other experiments in this file.
                    callbacks.TensorBoard(os.path.join(report_dir,
                                                       tensorboard_tag),
                                          batch_size=batch_size),
                    callbacks.ModelCheckpoint(os.path.join(
                        report_dir, ckpt_file),
                                              save_best_only=True,
                                              verbose=1),
                ])

        except KeyboardInterrupt:
            print('interrupted by user')
        else:
            print('done')
コード例 #5
0
def run(image_shape, data_dir, valid_pairs, classes, num_classes, architecture,
        weights, batch_size, last_base_layer, pooling, device,
        predictions_activation, dropout_rate, ckpt, validation_steps,
        use_multiprocessing, use_gram_matrix, dense_layers, embedding_units,
        limb_weights, trainable_limbs):
    """Evaluate a trained siamese model over balanced validation pairs.

    Prints accuracy, a classification report and the confusion matrix for
    the binary pair decision, thresholding predictions at 0.5.
    """
    # An integer ``classes`` selects the first n class folders.
    if isinstance(classes, int):
        classes = sorted(os.listdir(os.path.join(data_dir, 'train')))[:classes]

    g = ImageDataGenerator(
        preprocessing_function=utils.get_preprocess_fn(architecture))
    valid_data = BalancedDirectoryPairsSequence(os.path.join(
        data_dir, 'valid'),
                                                g,
                                                target_size=image_shape[:2],
                                                pairs=valid_pairs,
                                                classes=classes,
                                                batch_size=batch_size)
    if validation_steps is None:
        validation_steps = len(valid_data)

    with tf.device(device):
        print('building...')
        model = build_siamese_model(
            image_shape,
            architecture,
            dropout_rate,
            weights,
            num_classes,
            last_base_layer,
            use_gram_matrix,
            dense_layers,
            pooling,
            include_base_top=False,
            include_top=True,
            predictions_activation=predictions_activation,
            limb_weights=limb_weights,
            trainable_limbs=trainable_limbs,
            embedding_units=embedding_units,
            joints='multiply')
        print('siamese model summary:')
        model.summary()
        if ckpt:
            print('loading weights...')
            model.load_weights(ckpt)

        enqueuer = None
        try:
            # Draw batches through an enqueuer so pair loading can run in
            # background workers.
            enqueuer = OrderedEnqueuer(valid_data,
                                       use_multiprocessing=use_multiprocessing)
            enqueuer.start()
            output_generator = enqueuer.get()

            y, p = [], []
            for step in range(validation_steps):
                x, _y = next(output_generator)
                _p = model.predict(x, batch_size=batch_size)
                y.append(_y)
                p.append(_p)

            y, p = (np.concatenate(e).flatten() for e in (y, p))

            print('actual:', y[:80])
            # Fixed: this line was labeled 'expected:' although ``p`` holds
            # the model's predictions, not expected labels.
            print('predicted:', p[:80])
            print('accuracy:', metrics.accuracy_score(y, p >= 0.5))
            print(metrics.classification_report(y, p >= 0.5))
            print(metrics.confusion_matrix(y, p >= 0.5))

        finally:
            # Always stop the background workers, even on error/interrupt.
            if enqueuer is not None:
                enqueuer.stop()
コード例 #6
0
def run(tag, data_dir, n_classes, phases, classes, data_seed,
        results_file_name, group_patches, batch_size, image_shape,
        architecture, weights, dropout_p, last_base_layer, use_gram_matrix,
        pooling, dense_layers, device, ckpt_file):
    """Evaluate a trained classifier over each phase and dump JSON results.

    Builds the model, optionally restores ``ckpt_file`` weights, predicts
    over each phase's directory and writes the aggregated evaluation
    results to ``results_file_name``.
    """
    # Local imports keep TensorFlow/Keras initialization inside the command.
    import tensorflow as tf
    from keras import backend as K
    from keras.preprocessing.image import ImageDataGenerator
    from connoisseur.models import build_model
    from connoisseur.utils import get_preprocess_fn

    tf.logging.set_verbosity(tf.logging.ERROR)
    # Allow soft device placement and on-demand GPU memory growth.
    tf_config = tf.ConfigProto(allow_soft_placement=True)
    tf_config.gpu_options.allow_growth = True
    s = tf.Session(config=tf_config)
    K.set_session(s)

    preprocess_input = get_preprocess_fn(architecture)

    # NOTE(review): ``preprocess_input`` is computed but never used --
    # ``preprocessing_function`` is explicitly None below, so only
    # samplewise normalization is applied. Confirm whether skipping the
    # architecture's preprocessing here is intentional.
    g = ImageDataGenerator(samplewise_center=True,
                           samplewise_std_normalization=True,
                           preprocessing_function=None)

    with tf.device(device):
        print('building...')
        model = build_model(image_shape,
                            architecture=architecture,
                            weights=weights,
                            dropout_p=dropout_p,
                            classes=n_classes,
                            last_base_layer=last_base_layer,
                            use_gram_matrix=use_gram_matrix,
                            pooling=pooling,
                            dense_layers=dense_layers)

        if ckpt_file:
            print('re-loading weights...')
            model.load_weights(ckpt_file)

        results = []
        for phase in phases:
            print('\n# %s evaluation' % phase)

            # Shuffle off so predictions align with data.classes/filenames.
            data = g.flow_from_directory(os.path.join(data_dir, phase),
                                         target_size=image_shape[:2],
                                         classes=classes,
                                         batch_size=batch_size,
                                         seed=data_seed,
                                         shuffle=False,
                                         class_mode='sparse')

            steps = ceil(data.n / batch_size)

            probabilities = model.predict_generator(data, steps=steps)
            layer_results = evaluate(probabilities=probabilities,
                                     y=data.classes,
                                     names=data.filenames,
                                     tag=tag,
                                     group_patches=group_patches,
                                     phase=phase)
            layer_results['phase'] = phase
            results.append(layer_results)

    with open(results_file_name, 'w') as file:
        json.dump(results, file)
コード例 #7
0
def run(image_shape, data_dir, submission_info_path, solution_path, data_seed,
        classes, architecture, weights, batch_size, last_base_layer,
        use_gram_matrix, pooling, dense_layers, device, n_classes, dropout_p,
        ckpt_file, binary_strategy, results_file_name):
    """Evaluate a trained model against a submission's pairs and solution.

    Reads the pair list and ground truth from the submission CSV files,
    predicts over the 'test' directory and writes evaluation results as
    JSON to ``results_file_name``.
    """
    # Local imports keep TensorFlow/Keras initialization inside the command.
    import json
    import os
    from math import ceil

    import pandas as pd
    import tensorflow as tf
    from PIL import ImageFile

    from keras import backend as K
    from keras.preprocessing.image import ImageDataGenerator
    from connoisseur.models import build_model
    from connoisseur.utils import get_preprocess_fn

    # Tolerate partially-downloaded/corrupted JPEGs instead of raising.
    ImageFile.LOAD_TRUNCATED_IMAGES = True

    tf.logging.set_verbosity(tf.logging.ERROR)
    # Allow soft device placement and on-demand GPU memory growth.
    tf_config = tf.ConfigProto(allow_soft_placement=True)
    tf_config.gpu_options.allow_growth = True
    s = tf.Session(config=tf_config)
    K.set_session(s)

    # First CSV column is an index; keep only the pair/label columns.
    pairs = pd.read_csv(submission_info_path, quotechar='"',
                        delimiter=',').values[:, 1:]
    y = pd.read_csv(solution_path, quotechar='"', delimiter=',').values[:, 1:]

    # Rewrite pair entries to match the 'unknown/<name>' flow_from_directory
    # file naming (extension stripped).
    pairs = [[
        'unknown/' + os.path.splitext(a)[0],
        'unknown/' + os.path.splitext(b)[0]
    ] for a, b in pairs]

    preprocess_input = get_preprocess_fn(architecture)
    g = ImageDataGenerator(preprocessing_function=preprocess_input)

    with tf.device(device):
        print('building...')
        model = build_model(image_shape,
                            architecture=architecture,
                            weights=weights,
                            dropout_p=dropout_p,
                            classes=n_classes,
                            last_base_layer=last_base_layer,
                            use_gram_matrix=use_gram_matrix,
                            pooling=pooling,
                            dense_layers=dense_layers)

        if ckpt_file:
            print('re-loading weights...')
            model.load_weights(ckpt_file)

        results = []
        for phase in ['test']:
            print('\n# %s evaluation' % phase)

            # Shuffle off so predictions align with data.filenames.
            data = g.flow_from_directory(os.path.join(data_dir, phase),
                                         target_size=image_shape[:2],
                                         classes=classes,
                                         batch_size=batch_size,
                                         seed=data_seed,
                                         shuffle=False)

            steps = ceil(data.n / batch_size)

            probabilities = model.predict_generator(data, steps=steps)
            # NOTE(review): the model is deleted and the session cleared
            # inside the loop; this only works because the loop runs for
            # the single 'test' phase -- a second iteration would fail.
            del model
            K.clear_session()

            layer_results = evaluate(probabilities, y, data.filenames, pairs,
                                     binary_strategy)
            layer_results['phase'] = phase
            results.append(layer_results)

    with open(results_file_name, 'w') as file:
        json.dump(results, file)
コード例 #8
0
def run(_run, image_shape, data_dir, train_shuffle, dataset_train_seed,
        valid_shuffle, dataset_valid_seed, classes, class_mode, class_weight,
        architecture, weights, batch_size, last_base_layer, use_gram_matrix,
        pooling, dense_layers, device, opt_params, dropout_p,
        resuming_from_ckpt_file, steps_per_epoch, epochs, validation_steps,
        workers, use_multiprocessing, initial_epoch, early_stop_patience,
        tensorboard_tag, first_trainable_layer, first_reset_layer):
    """Train a categorical classification network over painting images.

    TensorBoard logs and the best checkpoint (``weights.h5``) are written
    to the sacred observer's directory.
    """
    report_dir = _run.observers[0].dir

    # Fixed: ``zoom_range`` and ``rotation_range`` had their values swapped.
    # ``zoom_range=45`` zooms by a factor in [-44, 46] (nonsensical) and
    # ``rotation_range=.2`` rotates by at most 0.2 degrees (a no-op):
    # zoom is a fraction, rotation is in degrees.
    g = ImageDataGenerator(
        horizontal_flip=True,
        vertical_flip=True,
        samplewise_center=True,
        samplewise_std_normalization=True,
        zoom_range=.2,
        rotation_range=45,
        height_shift_range=.2,
        width_shift_range=.2,
        fill_mode='reflect',
        preprocessing_function=get_preprocess_fn(architecture))

    # An integer ``classes`` selects the first n class folders.
    if isinstance(classes, int):
        classes = sorted(os.listdir(os.path.join(data_dir, 'train')))[:classes]

    train_data = g.flow_from_directory(os.path.join(data_dir, 'train'),
                                       target_size=image_shape[:2],
                                       classes=classes,
                                       class_mode=class_mode,
                                       batch_size=batch_size,
                                       shuffle=train_shuffle,
                                       seed=dataset_train_seed)

    valid_data = g.flow_from_directory(os.path.join(data_dir, 'valid'),
                                       target_size=image_shape[:2],
                                       classes=classes,
                                       class_mode=class_mode,
                                       batch_size=batch_size,
                                       shuffle=valid_shuffle,
                                       seed=dataset_valid_seed)

    if class_weight == 'balanced':
        class_weight = get_class_weights(train_data.classes)

    if steps_per_epoch is None:
        steps_per_epoch = ceil(train_data.n / batch_size)
    if validation_steps is None:
        validation_steps = ceil(valid_data.n / batch_size)

    with tf.device(device):
        print('building...')
        model = build_model(image_shape,
                            architecture=architecture,
                            weights=weights,
                            dropout_p=dropout_p,
                            classes=train_data.num_classes,
                            last_base_layer=last_base_layer,
                            use_gram_matrix=use_gram_matrix,
                            pooling=pooling,
                            dense_layers=dense_layers)

        layer_names = [l.name for l in model.layers]

        if first_trainable_layer:
            if first_trainable_layer not in layer_names:
                raise ValueError('%s is not a layer in the model: %s' %
                                 (first_trainable_layer, layer_names))

            # Freeze every layer before ``first_trainable_layer``.
            for layer in model.layers:
                if layer.name == first_trainable_layer:
                    break
                layer.trainable = False

        model.compile(
            optimizer=optimizers.Adam(**opt_params),
            metrics=['categorical_accuracy', 'top_k_categorical_accuracy'],
            loss='categorical_crossentropy')

        if resuming_from_ckpt_file:
            print('re-loading weights...')
            model.load_weights(resuming_from_ckpt_file)

        if first_reset_layer:
            if first_reset_layer not in layer_names:
                raise ValueError('%s is not a layer in the model: %s' %
                                 (first_reset_layer, layer_names))
            print('first layer to have its weights reset:', first_reset_layer)
            # Copy freshly-initialized weights over every layer from
            # ``first_reset_layer`` onward.
            # Fixed: used ``train_data.num_class`` (AttributeError) -- the
            # Keras DirectoryIterator attribute is ``num_classes``, as used
            # when building ``model`` above.
            random_model = build_model(image_shape,
                                       architecture=architecture,
                                       weights=None,
                                       dropout_p=dropout_p,
                                       classes=train_data.num_classes,
                                       last_base_layer=last_base_layer,
                                       use_gram_matrix=use_gram_matrix,
                                       dense_layers=dense_layers)
            _reset = False
            for layer, random_layer in zip(model.layers, random_model.layers):
                if layer.name == first_reset_layer:
                    _reset = True
                if _reset:
                    layer.set_weights(random_layer.get_weights())
            del random_model

            model.compile(optimizer=optimizers.Adam(**opt_params),
                          metrics=['accuracy'],
                          loss='categorical_crossentropy')

        print('training from epoch %i...' % initial_epoch)
        try:
            model.fit_generator(
                train_data,
                steps_per_epoch=steps_per_epoch,
                epochs=epochs,
                verbose=2,
                validation_data=valid_data,
                validation_steps=validation_steps,
                initial_epoch=initial_epoch,
                class_weight=class_weight,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                callbacks=[
                    callbacks.TerminateOnNaN(),
                    callbacks.ReduceLROnPlateau(min_lr=1e-10,
                                                patience=int(
                                                    early_stop_patience // 3)),
                    callbacks.EarlyStopping(patience=early_stop_patience),
                    callbacks.TensorBoard(os.path.join(report_dir,
                                                       tensorboard_tag),
                                          batch_size=batch_size),
                    callbacks.ModelCheckpoint(os.path.join(
                        report_dir, 'weights.h5'),
                                              save_best_only=True,
                                              verbose=1),
                ])
        except KeyboardInterrupt:
            print('interrupted by user')
        else:
            print('done')