コード例 #1
0
ファイル: finetune_orig.py プロジェクト: kcyu1993/keras
def model_finetune(base_model,
                   pred_model,
                   optimizer,
                   weights_path=RESNET_BASELINE_WEIGHTS_PATH,
                   loss='categorical_crossentropy',
                   metric=None):
    """
    Create a new model for fine-tuning from a frozen base and a prediction head.

    The layers of ``base_model`` are made non-trainable and its weights are
    loaded by layer name from ``weights_path``. A new ``Model`` is then built
    from ``base_model.input`` to ``pred_model.output``, compiled and
    summarised.

    Parameters
    ----------
    base_model : Model
        Backbone whose layers are frozen and whose weights are loaded.
    pred_model : Model
        Head whose output becomes the new model's output. It is presumably
        built on top of ``base_model``'s output tensor, since the two graphs
        must be connected -- confirm with callers.
    optimizer
        Optimizer passed to ``compile``.
    weights_path : str
        Weights file loaded into ``base_model`` with ``by_name=True``.
    loss : str
        Loss passed to ``compile``.
    metric : list or None
        Metrics passed to ``compile``; ``None`` means ``['accuracy']``.

    Returns
    -------
    tuple
        ``(new_model, base_model, pred_model)``, where ``new_model`` is the
        compiled model named ``"<base_model.name>_<pred_model.name>"``.
    """
    # Build a fresh list per call instead of sharing a mutable default.
    if metric is None:
        metric = ['accuracy']

    # Freeze the backbone before compiling so compile() sees the new
    # trainable flags.
    toggle_trainable_layers(base_model, False)
    base_model.load_weights(weights_path, by_name=True)
    new_model = Model(input=base_model.input,
                      output=pred_model.output,
                      name=base_model.name + "_" + pred_model.name)

    new_model.compile(optimizer, loss=loss, metrics=metric)
    new_model.summary()
    return new_model, base_model, pred_model
コード例 #2
0
ファイル: finetune.py プロジェクト: kcyu1993/keras
def minc2500_finetune(model,
                      nb_epoch_finetune=100,
                      nb_epoch_after=0,
                      batch_size=32,
                      image_gen=None,
                      title='MINC2500_finetune',
                      early_stop=False,
                      keyword='',
                      optimizer=None,
                      log=True,
                      lr_decay=True,
                      verbose=2,
                      weight_path='',
                      load=False,
                      lr=0.001):
    """
    Fine-tune ``model`` on MINC-2500 data (``load_minc2500(index=1)``) in
    up to two stages.

    Stage one
        Compiles ``model`` with ``optimizer`` and runs ``fit_model_v2`` for
        ``nb_epoch_finetune`` epochs. ``weight_path`` and ``load`` are
        forwarded to this stage (presumably to resume from saved weights --
        confirm against ``fit_model_v2``). The resulting weights are saved
        to ``get_tmp_weights_path(model.name)``.

    Stage two (only if ``nb_epoch_after > 0``)
        Applies ``toggle_trainable_layers(model, True, keyword)``, recompiles
        the model and runs ``fit_model_v2`` for ``nb_epoch_after`` more
        epochs at ``lr / 10``.

    Returns None.
    """
    # The caller's ``load`` flag is honoured. It used to be overwritten with
    # False here, so passing load=True had no effect.
    train, test = load_minc2500(index=1,
                                target_size=TARGET_SIZE,
                                gen=image_gen,
                                batch_size=batch_size)
    model.compile(optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    fit_model_v2(model, [train, test],
                 batch_size=batch_size,
                 title=title,
                 nb_epoch=nb_epoch_finetune,
                 optimizer=optimizer,
                 early_stop=early_stop,
                 verbose=verbose,
                 lr_decay=lr_decay,
                 log=log,
                 weight_path=weight_path,
                 load=load,
                 lr=lr)
    tmp_weights = get_tmp_weights_path(model.name)
    model.save_weights(tmp_weights)
    if nb_epoch_after > 0:
        toggle_trainable_layers(model, True, keyword)
        # Recompile so the changed trainable flags take effect.
        model.compile(optimizer,
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        fit_model_v2(model, [train, test],
                     batch_size=batch_size,
                     title=title,
                     nb_epoch=nb_epoch_after,
                     optimizer=optimizer,
                     early_stop=early_stop,
                     verbose=verbose,
                     lr_decay=lr_decay,
                     lr=lr / 10)

    return
コード例 #3
0
ファイル: resnet.py プロジェクト: kcyu1993/keras
def ResNet50_cifar_o1(denses=None,
                      nb_classes=10,
                      input_shape=None,
                      load_weights=True,
                      freeze_conv=False,
                      last_conv_feature_maps=None,
                      batch_norm=True):
    """
    Create a first-order ResNet50 classifier on top of a top-less ResNet50.

    Parameters
    ----------
    denses : list[int] or None
        Units of the ReLU ``Dense`` layers (``fc1``, ``fc2``, ...) inserted
        between the flattened features and the classifier. ``None`` means
        no extra layers.
    nb_classes : int
        Number of units of the softmax ``prediction`` layer.
    input_shape : tuple or None
        Input shape forwarded to the backbone constructor.
    load_weights : bool
        If true, the backbone is built without ``weights=None``, i.e. with
        the constructor's default weights:
        - ``ResNet50CIFAR`` when ``last_conv_feature_maps`` is empty;
        - ``ResNet50`` otherwise.
        If false, ``ResNet50(weights=None)`` is used.
    freeze_conv : bool
        If true, make the ResNet50 backbone layers non-trainable. Any 1x1
        convolutions added from ``last_conv_feature_maps`` stay trainable.
    last_conv_feature_maps : list[int] or None
        If non-empty, the backbone is built with ``last_avg=False``. One
        ReLU 1x1 convolution per entry (``1x1_conv_<i>``) is then appended,
        followed by a 7x7 ``AveragePooling2D`` named ``avg_pool``.
    batch_norm : bool
        Currently unused.

    Returns
    -------
    Model
        Model named ``'resnet50_o1'``.
    """
    if denses is None:
        denses = []
    if last_conv_feature_maps is None:
        last_conv_feature_maps = []

    if not last_conv_feature_maps:
        if load_weights:
            backbone = ResNet50CIFAR(include_top=False, input_shape=input_shape)
        else:
            backbone = ResNet50(include_top=False,
                                weights=None,
                                input_shape=input_shape)
        model = backbone
    else:
        # NOTE(review): this path uses ResNet50 rather than ResNet50CIFAR
        # even when load_weights is true -- confirm intended.
        if load_weights:
            backbone = ResNet50(include_top=False,
                                input_shape=input_shape,
                                last_avg=False)
        else:
            backbone = ResNet50(include_top=False,
                                weights=None,
                                input_shape=input_shape,
                                last_avg=False)
        x = backbone.output
        for ind, feature_dim in enumerate(last_conv_feature_maps):
            x = Convolution2D(feature_dim,
                              1,
                              1,
                              activation='relu',
                              name='1x1_conv_{}'.format(ind))(x)
        x = AveragePooling2D((7, 7), name='avg_pool')(x)
        model = Model(backbone.input, x, name='resnet50_with_1x1')

    # Classifier head: flatten, optional dense layers, softmax prediction.
    x = Flatten()(model.output)
    for ind, dense in enumerate(denses):
        x = Dense(dense, activation='relu', name='fc' + str(ind + 1))(x)
    x = Dense(nb_classes, activation='softmax', name='prediction')(x)
    if freeze_conv:
        # Freeze only the backbone. The newly created (randomly initialised)
        # 1x1 convolutions must stay trainable. This matches the dcov
        # wrappers, which freeze base_model only.
        toggle_trainable_layers(backbone, trainable=False)
    new_model = Model(model.input, x, name='resnet50_o1')
    return new_model
コード例 #4
0
ファイル: finetune_orig.py プロジェクト: kcyu1993/keras
def mincorig_finetune(model,
                      nb_epoch_finetune=100,
                      nb_epoch_after=0,
                      batch_size=32,
                      image_gen=None,
                      title='MINCorig_finetune',
                      early_stop=False,
                      keyword='',
                      optimizer=None,
                      log=True,
                      verbose=2,
                      lr=0.001):
    """
    Fine-tune ``model`` on the original MINC dataset in up to two stages.

    Training data comes from ``train.txt`` and validation data from
    ``validate.txt``; both use the same batch size, ``image_gen`` and
    ``TARGET_SIZE``.

    Stage one
        Compiles the model and runs ``fit_model_v2`` for
        ``nb_epoch_finetune`` epochs. The weights are then written to
        ``get_tmp_weights_path(model.name)``.

    Stage two (only if ``nb_epoch_after`` is positive)
        Applies ``toggle_trainable_layers(model, True, keyword)``,
        recompiles, and trains for ``nb_epoch_after`` epochs at ``lr / 10``.

    Returns None.
    """
    mincorig = MincOriginal()

    def _split(listing):
        return mincorig.generator(listing,
                                  batch_size=batch_size,
                                  gen=image_gen,
                                  target_size=TARGET_SIZE)

    def _compile():
        model.compile(optimizer,
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

    train_gen = _split('train.txt')
    valid_gen = _split('validate.txt')

    # Keyword arguments common to both training stages.
    common = dict(batch_size=batch_size,
                  title=title,
                  optimizer=optimizer,
                  early_stop=early_stop,
                  verbose=verbose)

    _compile()
    fit_model_v2(model, [train_gen, valid_gen],
                 nb_epoch=nb_epoch_finetune,
                 log=log,
                 lr=lr,
                 **common)
    model.save_weights(get_tmp_weights_path(model.name))

    if not (nb_epoch_after > 0):
        return

    # Second stage: unfreeze the matching layers and recompile.
    toggle_trainable_layers(model, True, keyword)
    _compile()
    fit_model_v2(model, [train_gen, valid_gen],
                 nb_epoch=nb_epoch_after,
                 lr=lr / 10,
                 **common)
コード例 #5
0
ファイル: finetune.py プロジェクト: kcyu1993/keras
def dtd_finetune(model,
                 nb_epoch_finetune=100,
                 nb_epoch_after=0,
                 batch_size=32,
                 image_gen=None,
                 title='dtd_finetune',
                 early_stop=False,
                 keyword='',
                 log=True,
                 lr_decay=True,
                 optimizer=None,
                 verbose=2,
                 weight_path='',
                 load=False,
                 lr=0.001):
    """
    Fine-tune ``model`` on the DTD dataset in up to two stages.

    ``load_dtd(True, ...)`` returns three splits; the first is used for
    training and the third for validation.

    Stage one
        Compiles the model and runs ``fit_model_v2`` for
        ``nb_epoch_finetune`` epochs. The weights are then saved to
        ``get_tmp_weights_path(model.name)``.

    Stage two (only if ``nb_epoch_after`` is positive)
        Applies ``toggle_trainable_layers(model, True, keyword)``,
        recompiles, and trains for ``nb_epoch_after`` epochs at ``lr / 10``.

    Returns None.
    """
    train_data, _, test_data = load_dtd(True, image_gen=image_gen,
                                        batch_size=batch_size)

    def _compile():
        model.compile(optimizer,
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

    # Keyword arguments common to both training stages.
    common = dict(batch_size=batch_size,
                  title=title,
                  optimizer=optimizer,
                  lr_decay=lr_decay,
                  early_stop=early_stop,
                  verbose=verbose)

    _compile()
    # NOTE(review): the ``load`` argument is ignored. Stage one always runs
    # with load=True, together with ``weight_path`` (default ''). Confirm
    # this is intended.
    fit_model_v2(model, [train_data, test_data],
                 nb_epoch=nb_epoch_finetune,
                 log=log,
                 load=True,
                 weight_path=weight_path,
                 lr=lr,
                 **common)
    model.save_weights(get_tmp_weights_path(model.name))

    if not (nb_epoch_after > 0):
        return

    # Second stage: unfreeze the matching layers and recompile.
    toggle_trainable_layers(model, True, keyword)
    _compile()
    fit_model_v2(model, [train_data, test_data],
                 nb_epoch=nb_epoch_after,
                 lr=lr / 10,
                 **common)
コード例 #6
0
ファイル: train.py プロジェクト: kcyu1993/keras
def cifar_train(model,
                nb_epoch_finetune=100,
                nb_epoch_after=0,
                batch_size=32,
                image_gen=None,
                title='cifar10_train',
                early_stop=True,
                keyword='',
                optimizer=None,
                log=True,
                verbose=2,
                lr_decay=False,
                lr=0.01):
    """
    Train ``model`` on CIFAR-10 (``cifar10_data()``) in up to two stages.

    If ``optimizer`` is None, an SGD optimizer is created with:
    ``lr=lr``, ``decay=1e-6``, ``momentum=0.9``, ``nesterov=True``.

    Stage one
        Compiles the model and runs ``fit_model_v2`` for
        ``nb_epoch_finetune`` epochs. The weights are then saved to
        ``get_tmp_weights_path(model.name)``.

    Stage two (only if ``nb_epoch_after > 0``)
        - ``lr`` is divided by 10 only if stage one ran
          (``nb_epoch_finetune > 0``).
        - ``toggle_trainable_layers(model, True, keyword)`` is applied and
          the model is recompiled.
        - ``fit_model_v2`` runs for ``nb_epoch_after`` epochs.
        - The final weights overwrite the same temporary weights file.

    ``image_gen`` is accepted but not used.

    Returns None.
    """
    train, test = cifar10_data()

    if optimizer is None:
        # Use the requested learning rate instead of a hard-coded 0.01.
        # This is identical for the default ``lr``.
        optimizer = SGD(lr=lr, decay=1e-6, momentum=0.9, nesterov=True)

    model.compile(optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    fit_model_v2(model, [train, test],
                 batch_size=batch_size,
                 title=title,
                 nb_epoch=nb_epoch_finetune,
                 optimizer=optimizer,
                 early_stop=early_stop,
                 verbose=verbose,
                 log=log,
                 lr=lr,
                 lr_decay=lr_decay)
    tmp_weights = get_tmp_weights_path(model.name)
    model.save_weights(tmp_weights)
    if nb_epoch_after > 0:
        if nb_epoch_finetune > 0:
            lr /= 10
        toggle_trainable_layers(model, True, keyword)
        # Recompile so the changed trainable flags take effect.
        model.compile(optimizer,
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        fit_model_v2(model, [train, test],
                     batch_size=batch_size,
                     title=title,
                     nb_epoch=nb_epoch_after,
                     optimizer=optimizer,
                     early_stop=early_stop,
                     verbose=verbose,
                     lr_decay=lr_decay,
                     lr=lr)
        # save_weights() requires a file path. Calling it with no argument
        # raised TypeError right after the second stage finished training.
        model.save_weights(tmp_weights)
コード例 #7
0
ファイル: train.py プロジェクト: kcyu1993/keras
def imagenet_finetune(model,
                      nb_epoch_finetune=100,
                      nb_epoch_after=0,
                      batch_size=32,
                      image_gen=None,
                      title='ImageNet_finetune',
                      early_stop=False,
                      keyword='',
                      optimizer=None,
                      log=True,
                      lr_decay=True,
                      verbose=2,
                      lr=0.001):
    """
    Fine-tune ``model`` on ImageNet data from ``loadImageNet()`` in up to
    two stages.

    Stage one
        Compiles the model and runs ``fit_model_v2`` for
        ``nb_epoch_finetune`` epochs. The weights are then saved to
        ``get_tmp_weights_path(model.name)``.

    Stage two (only if ``nb_epoch_after`` is positive)
        Applies ``toggle_trainable_layers(model, True, keyword)``,
        recompiles, and trains for ``nb_epoch_after`` epochs at ``lr / 10``.

    ``image_gen`` is accepted but not used.

    Returns None.
    """
    train_set, test_set = loadImageNet()

    def _compile():
        model.compile(optimizer,
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

    # Keyword arguments common to both training stages.
    common = dict(batch_size=batch_size,
                  title=title,
                  optimizer=optimizer,
                  early_stop=early_stop,
                  verbose=verbose,
                  lr_decay=lr_decay)

    _compile()
    fit_model_v2(model, [train_set, test_set],
                 nb_epoch=nb_epoch_finetune,
                 log=log,
                 lr=lr,
                 **common)
    model.save_weights(get_tmp_weights_path(model.name))

    if not (nb_epoch_after > 0):
        return

    # Second stage: unfreeze the matching layers and recompile.
    toggle_trainable_layers(model, True, keyword)
    _compile()
    fit_model_v2(model, [train_set, test_set],
                 nb_epoch=nb_epoch_after,
                 lr=lr / 10,
                 **common)
コード例 #8
0
ファイル: finetune.py プロジェクト: kcyu1993/keras
def mit_indoor_finetune(model,
                        nb_epoch_finetune=100, nb_epoch_after=0, batch_size=32,
                        image_gen=None,
                        title='mit_indoor_finetune', early_stop=False,
                        keyword='',
                        optimizer=None,
                        log=True,
                        lr_decay=True,
                        weight_path='',
                        load=False,
                        verbose=2,
                        lr=0.01,
                        dirpath='/home/kyu/cvkyu/dataset/mit_indoor'):
    """
    Fine-tune ``model`` on MIT Indoor in up to two stages.

    Data comes from ``MITLoader(dirpath=dirpath)``. Training uses
    ``mode='complete_train'`` and validation uses ``mode='test'``; both
    are at ``TARGET_SIZE`` with ``image_gen`` and ``batch_size``.

    Stage one
        Runs ``fit_model_v2`` for ``nb_epoch_finetune`` epochs, forwarding
        ``weight_path`` and ``load``. The weights are then saved to
        ``get_tmp_weights_path(model.name)``.

    Stage two (only if ``nb_epoch_after > 0``)
        Applies ``toggle_trainable_layers(model, True, keyword)`` and runs
        ``fit_model_v2`` for ``nb_epoch_after`` epochs at ``lr / 10``.

    ``dirpath`` defaults to the previously hard-coded dataset location.

    Returns None.
    """
    # The caller's ``lr_decay`` is honoured. It used to be overwritten with
    # True here, so lr_decay=False had no effect.
    loader = MITLoader(dirpath=dirpath)
    train = loader.generator(mode='complete_train', target_size=TARGET_SIZE,
                             image_data_generator=image_gen,
                             batch_size=batch_size)
    test = loader.generator(mode='test', target_size=TARGET_SIZE,
                            image_data_generator=image_gen,
                            batch_size=batch_size)

    # NOTE(review): this function never calls model.compile(). Presumably
    # fit_model_v2 compiles the model itself -- confirm. Otherwise the
    # trainable toggle before stage two may not take effect.
    fit_model_v2(model, [train, test], batch_size=batch_size, title=title,
                 nb_epoch=nb_epoch_finetune,
                 optimizer=optimizer,
                 early_stop=early_stop,
                 verbose=verbose,
                 lr_decay=lr_decay,
                 weight_path=weight_path,
                 load=load,
                 log=log,
                 lr=lr)
    tmp_weights = get_tmp_weights_path(model.name)
    model.save_weights(tmp_weights)
    if nb_epoch_after > 0:
        toggle_trainable_layers(model, True, keyword)
        fit_model_v2(model, [train, test], batch_size=batch_size, title=title,
                     nb_epoch=nb_epoch_after,
                     optimizer=optimizer,
                     early_stop=early_stop,
                     verbose=verbose,
                     lr_decay=lr_decay,
                     lr=lr / 10)

    return
コード例 #9
0
ファイル: so_cnn_helper.py プロジェクト: kcyu1993/keras
def dcov_model_wrapper_v2(base_model,
                          parametrics=None,
                          mode=0,
                          nb_classes=1000,
                          basename='',
                          cov_mode='channel',
                          cov_branch='o2transform',
                          cov_branch_output=None,
                          freeze_conv=False,
                          cov_regularizer=None,
                          nb_branch=1,
                          concat='concat',
                          last_conv_feature_maps=None,
                          upsample_method='conv',
                          regroup=False,
                          **kwargs):
    """
    Attach a multi-branch covariance head and a softmax classifier to
    ``base_model``.

    Pipeline:
    1. The output of ``base_model`` goes through ``upsample_wrapper_v1``
       (1x1 kernels).
    2. It is split into groups with ``SeparateConvolutionFeatures`` and
       optionally regrouped.
    3. Each group is processed according to ``mode``.
    4. The per-branch outputs are combined according to ``concat`` and fed
       to a ``Dense`` softmax layer.

    Parameters
    ----------
    base_model : Model
        Backbone; its input becomes the input of the returned model.
    parametrics : list or None
        For ``mode == 0``: units of the per-branch ReLU ``Dense`` layers.
        Otherwise forwarded as ``parametric`` to the covariance block.
        ``None`` means ``[]``.
    mode : int
        0 -- flatten followed by ``parametrics`` dense layers;
        1 -- covariance block;
        2 -- covariance block concatenated with flatten + ``Dense``;
        3 -- covariance block with ``o2t_constraints='UnitNorm'``.
        Any other value passes the branch features through unchanged.
    nb_classes : int
        Units of the final softmax layer. Also the default for
        ``cov_branch_output``.
    basename : str
        Name of the returned model.
    cov_mode : str
        Forwarded to the covariance block in modes 1 and 3.
    cov_branch : str
        Key for ``get_cov_block``. For ``'o2t_no_wv'`` / ``'corr_no_wv'``
        it also changes how branch outputs are combined.
    cov_branch_output : int or None
        Output size of each covariance block.
    freeze_conv : bool
        If true, the layers of ``base_model`` are made non-trainable.
    cov_regularizer
        Forwarded to the covariance block.
    nb_branch : int
        Number of groups requested from ``SeparateConvolutionFeatures``.
    concat : {'concat', 'sum', 'ave'}
        How branch outputs are combined.
    last_conv_feature_maps : list or None
        Forwarded to ``upsample_wrapper_v1``; ``None`` means ``[]``.
    upsample_method : str
        Forwarded to ``upsample_wrapper_v1``.
    regroup : bool
        If true, ``Regrouping`` is applied to the branch inputs on
        ``/gpu:0``.
    **kwargs
        Forwarded to the covariance block.

    Returns
    -------
    Model
        Model named ``basename``.

    Raises
    ------
    RuntimeError
        If ``concat`` is not ``'concat'``, ``'sum'`` or ``'ave'``.
    """
    # Fresh lists per call instead of shared mutable defaults.
    if parametrics is None:
        parametrics = []
    if last_conv_feature_maps is None:
        last_conv_feature_maps = []

    covariance_block = get_cov_block(cov_branch)

    if cov_branch_output is None:
        cov_branch_output = nb_classes

    x = upsample_wrapper_v1(base_model.output,
                            last_conv_feature_maps,
                            upsample_method,
                            kernel=[1, 1])

    cov_input = SeparateConvolutionFeatures(nb_branch)(x)
    if regroup:
        with tf.device('/gpu:0'):
            cov_input = Regrouping(None)(cov_input)
    # A Keras layer with a single output returns a bare tensor rather than
    # a list; normalise so the per-branch loop below always iterates a list.
    if not isinstance(cov_input, (list, tuple)):
        cov_input = [cov_input]

    multi_branch = len(cov_input) > 1

    def _layer_name(name, branch):
        # Keras requires unique layer names within a model. A single branch
        # keeps the historical names, so existing weight files still load
        # by name; multiple branches get a per-branch suffix.
        if multi_branch:
            return '{}_{}'.format(name, branch)
        return name

    def _merge(tensors, merge_mode, name):
        # Merging one tensor is the identity. Skip the Merge layer, which
        # in Keras 1 rejects lists with fewer than two inputs.
        if len(tensors) == 1:
            return tensors[0]
        return merge(tensors, mode=merge_mode, name=name)

    cov_outputs = []
    for ind, branch in enumerate(cov_input):
        if mode == 0:
            x = Flatten()(branch)
            for fc_ind, units in enumerate(parametrics):
                x = Dense(units, activation='relu',
                          name=_layer_name('fc{}'.format(fc_ind), ind))(x)
        elif mode == 1:
            x = covariance_block(branch,
                                 cov_branch_output,
                                 stage=5,
                                 block=str(ind),
                                 parametric=parametrics,
                                 cov_mode=cov_mode,
                                 cov_regularizer=cov_regularizer,
                                 **kwargs)
        elif mode == 2:
            # NOTE(review): cov_mode is not forwarded here (same as in
            # dcov_model_wrapper_v1) -- confirm intended.
            cov_out = covariance_block(branch,
                                       cov_branch_output,
                                       stage=5,
                                       block=str(ind),
                                       parametric=parametrics,
                                       cov_regularizer=cov_regularizer,
                                       **kwargs)
            x = Flatten()(branch)
            x = Dense(nb_classes, activation='relu',
                      name=_layer_name('fc', ind))(x)
            x = merge([x, cov_out], mode='concat',
                      name=_layer_name('concat', ind))
        elif mode == 3:
            x = covariance_block(branch,
                                 cov_branch_output,
                                 stage=5,
                                 block=str(ind),
                                 parametric=parametrics,
                                 cov_mode=cov_mode,
                                 cov_regularizer=cov_regularizer,
                                 o2t_constraints='UnitNorm',
                                 **kwargs)
        else:
            x = branch
        cov_outputs.append(x)

    if concat == 'concat':
        if cov_branch in ('o2t_no_wv', 'corr_no_wv'):
            x = MatrixConcat(cov_outputs,
                             name='Matrix_diag_concat')(cov_outputs)
            x = WeightedVectorization(cov_branch_output * nb_branch,
                                      name='WV_big')(x)
        else:
            x = _merge(cov_outputs, 'concat', 'merge')
    elif concat == 'sum':
        x = _merge(cov_outputs, 'sum', 'sum')
        if cov_branch == 'o2t_no_wv':
            x = WeightedVectorization(cov_branch_output, name='wv_sum')(x)
    elif concat == 'ave':
        x = _merge(cov_outputs, 'ave', 'ave')
        if cov_branch == 'o2t_no_wv':
            x = WeightedVectorization(cov_branch_output, name='wv_sum')(x)
    else:
        raise RuntimeError("concat mode not support : " + concat)

    if freeze_conv:
        toggle_trainable_layers(base_model, False)

    x = Dense(nb_classes, activation='softmax')(x)

    return Model(base_model.input, x, name=basename)
コード例 #10
0
ファイル: so_cnn_helper.py プロジェクト: kcyu1993/keras
def dcov_model_wrapper_v1(base_model,
                          parametrics=None,
                          mode=0,
                          nb_classes=1000,
                          basename='',
                          cov_mode='channel',
                          cov_branch='o2transform',
                          cov_branch_output=None,
                          freeze_conv=False,
                          cov_regularizer=None,
                          nb_branch=1,
                          concat='concat',
                          last_conv_feature_maps=None,
                          upsample_method='conv',
                          regroup=False,
                          **kwargs):
    """
    Attach a single covariance head and a softmax classifier to
    ``base_model``.

    The output of ``base_model`` goes through ``upsample_wrapper_v1`` (1x1
    kernels) and is then processed according to ``mode``.

    Parameters
    ----------
    base_model : Model
        Backbone; its input becomes the input of the returned model.
    parametrics : list or None
        For ``mode == 0``: units of ReLU ``Dense`` layers (``fc0``, ...).
        Otherwise forwarded as ``parametric`` to the covariance block.
        ``None`` means ``[]``.
    mode : int
        0 -- flatten, dense layers, softmax;
        1 -- covariance block, softmax ``predictions``;
        2 -- covariance block concatenated with flatten + ``Dense``, softmax;
        3 -- covariance block with ``o2t_constraints='UnitNorm'``, softmax.
        Any other value returns a model whose output is the upsampled
        feature map.
    nb_classes : int
        Units of the softmax layer. Also the default for
        ``cov_branch_output``.
    basename : str
        Name of the returned model.
    cov_mode : str
        Forwarded to the covariance block in modes 1 and 3.
    cov_branch : str
        Key for ``get_cov_block``.
    cov_branch_output : int or None
        Output size of the covariance block.
    freeze_conv : bool
        If true, the layers of ``base_model`` are made non-trainable.
    cov_regularizer
        Forwarded to the covariance block.
    nb_branch : int
        Must be 1 for modes 1 and 3; ignored otherwise.
    concat, regroup
        Accepted for signature compatibility; unused.
    last_conv_feature_maps : list or None
        Forwarded to ``upsample_wrapper_v1``; ``None`` means ``[]``.
    upsample_method : str
        Forwarded to ``upsample_wrapper_v1``.
    **kwargs
        Forwarded to the covariance block.

    Returns
    -------
    Model
        Model named ``basename``.

    Raises
    ------
    NotImplementedError
        If ``mode`` is 1 or 3 and ``nb_branch != 1``. That combination is
        not implemented here; it previously produced a model without a
        classifier.
    """
    # Validate before building any layers.
    if mode in (1, 3) and nb_branch != 1:
        raise NotImplementedError(
            'dcov_model_wrapper_v1 supports only nb_branch == 1 for mode {}, '
            'got nb_branch={}'.format(mode, nb_branch))

    # Fresh lists per call instead of shared mutable defaults.
    if parametrics is None:
        parametrics = []
    if last_conv_feature_maps is None:
        last_conv_feature_maps = []

    covariance_block = get_cov_block(cov_branch)

    if cov_branch_output is None:
        cov_branch_output = nb_classes

    x = upsample_wrapper_v1(base_model.output,
                            last_conv_feature_maps,
                            upsample_method,
                            kernel=[1, 1])

    cov_input = x
    if mode == 0:
        x = Flatten()(x)
        for ind, param in enumerate(parametrics):
            x = Dense(param, activation='relu', name='fc{}'.format(ind))(x)
        x = Dense(nb_classes, activation='softmax')(x)
    elif mode == 1:
        cov_out = covariance_block(cov_input,
                                   cov_branch_output,
                                   stage=5,
                                   block='a',
                                   parametric=parametrics,
                                   cov_mode=cov_mode,
                                   cov_regularizer=cov_regularizer,
                                   **kwargs)
        x = Dense(nb_classes, activation='softmax',
                  name='predictions')(cov_out)
    elif mode == 2:
        cov_out = covariance_block(cov_input,
                                   cov_branch_output,
                                   stage=5,
                                   block='a',
                                   parametric=parametrics,
                                   cov_regularizer=cov_regularizer,
                                   **kwargs)
        x = Flatten()(x)
        x = Dense(nb_classes, activation='relu', name='fc')(x)
        x = merge([x, cov_out], mode='concat', name='concat')
        x = Dense(nb_classes, activation='softmax', name='predictions')(x)
    elif mode == 3:
        cov_out = covariance_block(cov_input,
                                   cov_branch_output,
                                   stage=5,
                                   block='a',
                                   parametric=parametrics,
                                   cov_mode=cov_mode,
                                   cov_regularizer=cov_regularizer,
                                   o2t_constraints='UnitNorm',
                                   **kwargs)
        x = Dense(nb_classes, activation='softmax',
                  name='predictions')(cov_out)

    if freeze_conv:
        toggle_trainable_layers(base_model, False)

    return Model(base_model.input, x, name=basename)
コード例 #11
0
ファイル: so_cnn_helper.py プロジェクト: kcyu1993/keras
def dcov_multi_out_model_wrapper(base_model,
                                 parametrics=None,
                                 mode=0,
                                 nb_classes=1000,
                                 basename='',
                                 cov_mode='channel',
                                 cov_branch='o2t_no_wv',
                                 cov_branch_output=None,
                                 freeze_conv=False,
                                 cov_regularizer=None,
                                 nb_branch=1,
                                 concat='concat',
                                 last_conv_feature_maps=None,
                                 upsample_method='conv',
                                 regroup=False,
                                 **kwargs):
    """
    Attach covariance branches to a three-output base model, followed by a
    softmax classifier.

    ``base_model.outputs`` must unpack into exactly three tensors
    (``block1``, ``block2``, ``block3``). The original author noted 256,
    512 and 512 channels -- confirm.

    Modes
    -----
    1
        Reduce dimensions with 1x1 convolutions (``[1024, 512]`` on
        ``block3``, ``[512]`` on ``block2``) and max-pool ``block2``
        (default pool) and ``block1`` (4x4). Then apply one covariance
        block per output.
    2
        Split ``block3`` into 4 and ``block2`` into 2 groups with
        ``SeparateConvolutionFeatures``, max-pool ``block1``, and apply one
        covariance block per group.
    3
        As mode 2, but within each block the group outputs are summed and
        passed through ``WeightedVectorization``.
    4
        FPN-like fusion of second-order features from all three blocks.
        The main stream is split into ``nb_branch`` groups.

    Parameters
    ----------
    base_model : Model
        Multi-output backbone; its input becomes the returned model's input.
    parametrics : list or None
        Forwarded as ``parametric`` to the covariance block.
        ``None`` means ``[]``.
    mode : int
        One of 1, 2, 3, 4 (see above).
    nb_classes : int
        Units of the softmax layer. Also the default for
        ``cov_branch_output``.
    basename : str
        Name of the returned model.
    cov_mode : str
        Forwarded to the covariance block.
    cov_branch : str
        Key for ``get_cov_block``.
    cov_branch_output : int or None
        Output size of each covariance block.
    freeze_conv : bool
        If true, the layers of ``base_model`` are made non-trainable.
    cov_regularizer
        Forwarded to the covariance block.
    nb_branch : int
        Number of groups for the mode-4 main stream.
    concat, last_conv_feature_maps, upsample_method, regroup
        Accepted for signature compatibility; unused here.
    **kwargs
        Forwarded to the covariance block.

    Returns
    -------
    Model
        Model named ``basename``.

    Raises
    ------
    ValueError
        If ``mode`` is not 1, 2, 3 or 4. Mode 0 (the default) previously
        failed with UnboundLocalError because no branch assigned the
        classifier input.
    """
    # Validate before building any layers.
    if mode not in (1, 2, 3, 4):
        raise ValueError(
            'dcov_multi_out_model_wrapper supports mode 1, 2, 3 or 4; '
            'got {!r}'.format(mode))

    # Fresh lists per call instead of shared mutable defaults.
    if parametrics is None:
        parametrics = []
    if last_conv_feature_maps is None:
        last_conv_feature_maps = []

    covariance_block = get_cov_block(cov_branch)

    if cov_branch_output is None:
        cov_branch_output = nb_classes
    # 256, 512, 512
    block1, block2, block3 = base_model.outputs
    print("===================")
    cov_outputs = []
    # NOTE(review): in modes 1-3, ``cov_outputs`` is never combined. ``x``
    # at the classifier is simply the last tensor assigned inside the
    # loops, so only that one branch reaches the softmax layer. The other
    # branches, and in mode 3 the WeightedVectorization outputs, are left
    # disconnected from the returned model. Combining ``cov_outputs`` is
    # probably intended -- confirm before relying on these modes.
    if mode == 1:
        print("Model design : ResNet_o2_multi_branch 1x1 conv to reduce dim ")
        # 1x1 conv to reduce dim, starting from block3.
        block3 = upsample_wrapper_v1(block3, [1024, 512])
        block2 = upsample_wrapper_v1(block2, [512])
        block2 = MaxPooling2D()(block2)
        block1 = MaxPooling2D(pool_size=(4, 4))(block1)
        outputs = [block1, block2, block3]
        for ind, x in enumerate(outputs):
            cov_out = covariance_block(x,
                                       cov_branch_output,
                                       stage=5,
                                       block=str(ind),
                                       parametric=parametrics,
                                       cov_mode=cov_mode,
                                       cov_regularizer=cov_regularizer,
                                       **kwargs)
            x = cov_out
            cov_outputs.append(x)
    elif mode == 2 or mode == 3:
        # Use branches (channel groups) to reduce dim.
        block3 = SeparateConvolutionFeatures(4)(block3)
        block2 = SeparateConvolutionFeatures(2)(block2)
        block1 = MaxPooling2D()(block1)
        block1 = [block1]
        outputs = [block1, block2, block3]
        for ind, outs in enumerate(outputs):
            block_outs = []
            for ind2, x in enumerate(outs):
                cov_out = covariance_block(x,
                                           cov_branch_output,
                                           stage=5,
                                           block=str(ind) + '_' + str(ind2),
                                           parametric=parametrics,
                                           cov_mode=cov_mode,
                                           cov_regularizer=cov_regularizer,
                                           **kwargs)
                x = cov_out
                block_outs.append(x)
            if mode == 3:
                # Sum the per-group covariance outputs of this block.
                if len(block_outs) > 1:
                    o = merge(block_outs,
                              mode='sum',
                              name='multibranch_sum_{}'.format(ind))
                    o = WeightedVectorization(cov_branch_output)(o)
                    cov_outputs.append(o)
                else:
                    a = block_outs[0]
                    if 'o2t' in a.name:
                        a = WeightedVectorization(cov_branch_output)(a)
                    cov_outputs.append(a)
            else:
                cov_outputs.extend(block_outs)
    elif mode == 4:
        # Structure similar to a Feature Pyramid Network.
        # Supplementary streams.
        block1 = upsample_wrapper_v1(block1, [256], stage='block1')
        block2 = upsample_wrapper_v1(block2, [256], stage='block2')
        # Main stream.
        block3 = upsample_wrapper_v1(block3, [512], stage='block3')

        cov_input = SeparateConvolutionFeatures(nb_branch)(block3)
        cov_outputs = []
        for ind, x in enumerate(cov_input):
            cov_out = covariance_block(x,
                                       cov_branch_output,
                                       stage=5,
                                       block=str(ind),
                                       parametric=parametrics,
                                       cov_mode=cov_mode,
                                       cov_regularizer=cov_regularizer,
                                       normalization=False,
                                       **kwargs)
            x = cov_out
            cov_outputs.append(x)

        x = MatrixConcat(cov_outputs, name='Matrix_diag_concat')(cov_outputs)
        x = O2Transform(64, activation='relu', name='o2t_mainst_1')(x)

        block2 = SecondaryStatistic(name='cov_block2',
                                    cov_mode='pmean',
                                    robust=False,
                                    eps=1e-5)(block2)
        block2 = O2Transform(64, activation='relu', name='o2t_block2')(block2)

        block1 = SecondaryStatistic(name='cov_block1',
                                    cov_mode='pmean',
                                    robust=False,
                                    eps=1e-5)(block1)
        block1 = O2Transform(64, activation='relu', name='o2t_block1')(block1)

        # Merge the main stream with both supplementary streams.
        x = MatrixConcat([x, block1, block2],
                         name='Matrix_diag_concat_all')([x, block1, block2])
        x = WeightedVectorization(128, activation='relu', name='wv_fuse')(x)

    if freeze_conv:
        toggle_trainable_layers(base_model, False)

    x = Dense(nb_classes, activation='softmax')(x)

    return Model(base_model.input, x, name=basename)