# Example #1
# 0
def comResnet(input, d):
    """Build a complex-valued, fully convolutional residual network on `input`.

    Pipeline: learnConcatRealImagBlock -> concat with the input ->
    9x9 ComplexConv2D/ComplexBN/ReLU -> 4 regular residual blocks ->
    two more 9x9 ComplexConv2D/ComplexBN/ReLU stages -> a 3-filter 9x9
    ComplexConv2D/ComplexBN with a final tanh.

    Args:
        input: input Keras tensor.
        d: configuration object; this function reads ``d.spectral_param`` and
            ``d.comp_init`` and forwards ``d`` to the block builders.

    Returns:
        The output Keras tensor (after tanh).
    """
    channelAxis = 1 if K.image_data_format() == 'channels_first' else -1
    convArgs = {
        "padding": "same",
        "use_bias": False,
        "kernel_regularizer": l2(0.0001),
        "spectral_parametrization": d.spectral_param,
        "kernel_initializer": d.comp_init,
    }
    bnArgs = {"axis": channelAxis, "momentum": 0.9, "epsilon": 1e-04}

    O = learnConcatRealImagBlock(input, (1, 1), (3, 3), 0, '0', convArgs,
                                 bnArgs, d)
    O = Concatenate(channelAxis)([input, O])
    O = ComplexConv2D(filters=64, kernel_size=9, name='conv1', **convArgs)(O)
    # Layer names must be valid TF scope names (no spaces).
    O = ComplexBN(name='bn_conv1_2a', **bnArgs)(O)
    # Use Keras layers rather than raw tf ops so the graph stays a Keras graph.
    O = Activation('relu')(O)

    # Residual trunk: each block needs a distinct `block` id, otherwise all
    # four blocks produce identical layer names and Model() rejects them.
    for i in range(4):
        O = getResidualBlock(O, (3, 3), [64, 64], 2, str(i), 'regular',
                             convArgs, bnArgs, d)

    for idx in (2, 3):
        O = ComplexConv2D(filters=64, kernel_size=9, name='conv%d' % idx,
                          **convArgs)(O)
        O = ComplexBN(name='bn_conv%d_1a' % idx, **bnArgs)(O)
        O = Activation('relu')(O)

    O = ComplexConv2D(filters=3, kernel_size=9, name='conv4', **convArgs)(O)
    O = ComplexBN(name='bn_conv4_1a', **bnArgs)(O)
    O = Activation('tanh')(O)
    return O
# Example #2
# 0
def getResidualBlock(I, filter_size, featmaps, stage, block, shortcut,
                     convArgs, bnArgs, d):
    """Get residual block.

    Builds a pre-activation residual block:
    BN -> activation -> conv "2a" -> BN -> activation -> conv "2b",
    then merges with the input:

    * ``shortcut == 'regular'``: the result is added to ``I``.
    * ``shortcut == 'projection'``: a 1x1 conv of ``I`` (stride 2 unless
      ``d.spectral_pool_scheme == "nodownsample"``) is concatenated with the
      result along the channel axis. For complex models the real and
      imaginary parts are concatenated separately and then re-joined.

    Args:
        I: input Keras tensor.
        filter_size: kernel size of the two main convolutions.
        featmaps: pair ``(nb_fmaps1, nb_fmaps2)`` of filter counts for
            convs "2a" and "2b".
        stage, block: used to build unique layer names.
        shortcut: ``'regular'`` or ``'projection'``.
        convArgs, bnArgs: extra keyword arguments for conv / BN layers.
        d: configuration; reads ``d.act``, ``d.model`` and
            ``d.spectral_pool_scheme``.

    Returns:
        The output Keras tensor.

    Raises:
        ValueError: if ``d.model`` is neither ``"real"`` nor ``"complex"``.
    """
    if d.model == "real":
        conv_layer, bn_layer = Conv2D, BatchNormalization
    elif d.model == "complex":
        conv_layer, bn_layer = ComplexConv2D, ComplexBN
    else:
        raise ValueError("Unknown model type " + str(d.model))

    activation = d.act
    nb_fmaps1, nb_fmaps2 = featmaps
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    if K.image_data_format() == 'channels_first' and K.ndim(I) != 3:
        channel_axis = 1
    else:
        channel_axis = -1

    O = bn_layer(name=bn_name_base + '_2a', **bnArgs)(I)
    O = Activation(activation)(O)

    if shortcut == 'regular' or d.spectral_pool_scheme == "nodownsample":
        O = conv_layer(nb_fmaps1,
                       filter_size,
                       name=conv_name_base + '2a',
                       **convArgs)(O)
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            O = applySpectralPooling(O, d)
        O = conv_layer(nb_fmaps1,
                       filter_size,
                       name=conv_name_base + '2a',
                       strides=(2, 2),
                       **convArgs)(O)

    O = bn_layer(name=bn_name_base + '_2b', **bnArgs)(O)
    O = Activation(activation)(O)
    O = conv_layer(nb_fmaps2,
                   filter_size,
                   name=conv_name_base + '2b',
                   **convArgs)(O)

    if shortcut == 'regular':
        O = Add()([O, I])
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            I = applySpectralPooling(I, d)
        strides = (1, 1) if d.spectral_pool_scheme == "nodownsample" else (2, 2)
        X = conv_layer(nb_fmaps2, (1, 1),
                       name=conv_name_base + '1',
                       strides=strides,
                       **convArgs)(I)
        if d.model == "real":
            O = Concatenate(channel_axis)([X, O])
        else:
            O_real = Concatenate(channel_axis)([GetReal()(X), GetReal()(O)])
            O_imag = Concatenate(channel_axis)([GetImag()(X), GetImag()(O)])
            # Re-join along the channel axis (was hard-coded to axis 1,
            # which is only the channel axis in channels_first layout).
            O = Concatenate(channel_axis)([O_real, O_imag])

    return O
# Example #3
# 0
def getResnetModel(d):
    """Build a real- or complex-valued ResNet classifier for 32x32x3 images.

    Args:
        d: configuration object. Reads ``num_blocks``, ``start_filter``,
            ``dataset``, ``act``, ``model``, ``spectral_pool_scheme`` and,
            for complex models, ``spectral_param`` / ``comp_init``.

    Returns:
        A Keras ``Model`` producing softmax class probabilities.

    Raises:
        ValueError: if ``d.dataset`` is not ``'cifar10'``, ``'cifar100'``
            or ``'svhn'`` (checked before any layer is created).
    """
    num_classes_by_dataset = {'cifar10': 10, 'cifar100': 100, 'svhn': 10}
    if d.dataset not in num_classes_by_dataset:
        raise ValueError("Unknown dataset " + d.dataset)
    num_classes = num_classes_by_dataset[d.dataset]

    n = d.num_blocks
    sf = d.start_filter
    activation = d.act
    # K.image_dim_ordering() is deprecated/removed in newer Keras;
    # image_data_format() carries the same information.
    channels_first = K.image_data_format() == 'channels_first'
    inputShape = (3, 32, 32) if channels_first else (32, 32, 3)
    channelAxis = 1 if channels_first else -1
    filsize = (3, 3)
    convArgs = {
        "padding": "same",
        "use_bias": False,
        "kernel_regularizer": l2(0.0001),
    }
    bnArgs = {"axis": channelAxis, "momentum": 0.9, "epsilon": 1e-04}

    if d.model == "real":
        # Real models use twice the filters (presumably to keep capacity
        # comparable with the complex variant).
        sf *= 2
        convArgs.update({"kernel_initializer": Orthogonal(float(np.sqrt(2)))})
    elif d.model == "complex":
        convArgs.update({
            "spectral_parametrization": d.spectral_param,
            "kernel_initializer": d.comp_init
        })

    I = Input(shape=inputShape)

    # Stage 1: concatenate the input with learnConcatRealImagBlock's output,
    # then the first conv + BN + activation.
    O = learnConcatRealImagBlock(I, (1, 1), (3, 3), 0, '0', convArgs, bnArgs,
                                 d)
    O = Concatenate(channelAxis)([I, O])
    if d.model == "real":
        O = Conv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    else:
        O = ComplexConv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = ComplexBN(name="bn_conv1_2a", **bnArgs)(O)
    O = Activation(activation)(O)

    def residual(T, fmaps, stage, block, shortcut):
        """Append one residual block with `fmaps` filters in both convs."""
        return getResidualBlock(T, filsize, [fmaps, fmaps], stage, block,
                                shortcut, convArgs, bnArgs, d)

    # Stage 2: n regular blocks at the starting width.
    for i in range(n):
        O = residual(O, sf, 2, str(i), 'regular')
        if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
            O = applySpectralPooling(O, d)

    # Stages 3 and 4: a projection block (whose shortcut is concatenated
    # with its output, doubling the width) followed by n-1 regular blocks.
    for stage, fmaps in ((3, sf), (4, sf * 2)):
        O = residual(O, fmaps, stage, '0', 'projection')
        if d.spectral_pool_scheme == "nodownsample":
            O = applySpectralPooling(O, d)
        for i in range(n - 1):
            O = residual(O, fmaps * 2, stage, str(i + 1), 'regular')
            if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
                O = applySpectralPooling(O, d)

    # Global average pooling over the remaining spatial extent.
    if d.spectral_pool_scheme == "nodownsample":
        O = applySpectralPooling(O, d)
        O = AveragePooling2D(pool_size=(32, 32))(O)
    else:
        O = AveragePooling2D(pool_size=(8, 8))(O)

    O = Flatten()(O)
    O = Dense(num_classes, activation='softmax',
              kernel_regularizer=l2(0.0001))(O)

    return Model(I, O)
# Sequential-API conv stack. `model`, `input_shape`, `bnArgs`, `convArgs`
# and `COMPLEX_MODE` are expected to be defined earlier (not visible here).
# In COMPLEX_MODE the middle conv layers are ComplexConv2D + ComplexBN;
# otherwise real Conv2D + BatchNormalization with twice the filters.
model.add(
    Conv2D(32,
           kernel_size=(3, 3),
           kernel_initializer='he_normal',
           input_shape=input_shape))
model.add(BatchNormalization(**bnArgs))
model.add(Activation('relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Dropout(0.25))
if COMPLEX_MODE:
    model.add(ComplexConv2D(32, (3, 3), **convArgs))
    model.add(ComplexBN(**bnArgs))
else:
    model.add(Conv2D(32 * 2, (3, 3), **convArgs))
    model.add(BatchNormalization(**bnArgs))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
if COMPLEX_MODE:
    model.add(ComplexConv2D(64, (3, 3), **convArgs))
    model.add(ComplexBN(**bnArgs))
else:
    # NOTE(review): this branch had no body in the source snippet (a syntax
    # error); reconstructed by mirroring the real-valued branch above.
    model.add(Conv2D(64 * 2, (3, 3), **convArgs))
    model.add(BatchNormalization(**bnArgs))
def getResidualBlock(I, filter_size, featmaps, stage, block, shortcut,
                     convArgs, convArgs_real, bnArgs, d):
    """Get residual block.

    Builds a pre-activation residual block:
    BN -> activation -> conv "2a" -> BN -> activation -> conv "2b",
    then merges with the input:

    * ``shortcut == 'regular'``: the result is added to ``I``.
    * ``shortcut == 'projection'``: a 1x1 conv "1" of ``I`` (stride 2
      unless ``d.spectral_pool_scheme == "nodownsample"``) is concatenated
      with the result along the channel axis (real and imaginary parts
      separately for complex models).

    Supported ``d.model`` values:
    ``"real"`` (Conv2D), ``"real_dws"`` (SeparableConv2D),
    ``"real_group"`` / ``"real_group_pwc_full"`` / ``"real_group_pwc_group"``
    (two-group Conv2D whose outputs are interleaved, optionally followed by
    a full or a two-group 1x1 conv), ``"complex"`` (ComplexConv2D) and
    ``"complex_concat"`` / ``"complex_concat_pwc_group"``
    (ComplexConvConcat2D with half the filters, optionally followed by a
    two-group 1x1 conv). The grouped variants split tensors along axis 1,
    i.e. they assume channels_first layout.

    Args:
        I: input Keras tensor.
        filter_size: kernel size of the "2a"/"2b" convolutions.
        featmaps: pair ``(nb_fmaps1, nb_fmaps2)`` of filter counts.
        stage, block: used to build unique layer names.
        shortcut: ``'regular'`` or ``'projection'``.
        convArgs: keyword arguments for the main convolutions.
        convArgs_real: keyword arguments for the grouped 1x1 convolutions.
        bnArgs: keyword arguments for the BN layers.
        d: configuration; reads ``act``, ``aact``, ``model`` and
            ``spectral_pool_scheme``.

    Returns:
        The output Keras tensor.

    Raises:
        ValueError: if ``d.model`` is not a supported variant (previously
            the process was terminated with ``exit(-1)``).
    """
    model = d.model
    real_group_models = ("real_group", "real_group_pwc_full",
                         "real_group_pwc_group")
    complex_concat_models = ("complex_concat", "complex_concat_pwc_group")
    known_models = (("real", "real_dws", "complex") + real_group_models +
                    complex_concat_models)
    if model not in known_models:
        raise ValueError("Unknown model type " + str(model))

    activation = d.act
    use_joint_relu = d.aact == "complex_joint_relu"
    is_real = "real" in model
    nb_fmaps1, nb_fmaps2 = featmaps
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    if K.image_data_format() == 'channels_first' and K.ndim(I) != 3:
        channel_axis = 1
    else:
        channel_axis = -1

    def halves(T):
        """Split T into two halves along axis 1 (channels_first channels)."""
        lo = Lambda(lambda T: T[:, :(T.shape[1] // 2), :, :])(T)
        hi = Lambda(lambda T: T[:, (T.shape[1] // 2):, :, :])(T)
        return lo, hi

    def conv_kwargs(strides):
        """convArgs plus an explicit `strides` entry when one is given."""
        kw = dict(convArgs)
        if strides is not None:
            kw['strides'] = strides
        return kw

    def pointwise_group(T, name):
        """Two-group 1x1 conv; each group keeps half of T's channels."""
        T_g0, T_g1 = halves(T)
        half = int(T.shape[1] // 2)
        T_g0 = Conv2D(half, (1, 1), name=name + '_pwc_g0',
                      **convArgs_real)(T_g0)
        T_g1 = Conv2D(half, (1, 1), name=name + '_pwc_g1',
                      **convArgs_real)(T_g1)
        return Concatenate(axis=1)([T_g0, T_g1])

    def grouped_conv(T, filters, kernel, name, strides):
        """Two-group Conv2D with interleaved outputs and optional 1x1 stage."""
        T_g0, T_g1 = halves(T)
        kw = conv_kwargs(strides)
        T_g0 = Conv2D(filters // 2, kernel, name=name + '_g0', **kw)(T_g0)
        T_g1 = Conv2D(filters // 2, kernel, name=name + '_g1', **kw)(T_g1)
        g00, g01 = halves(T_g0)
        g10, g11 = halves(T_g1)
        # This ordering puts channels from both groups into each half of the
        # result, so the next grouped layer mixes information across groups.
        out = Concatenate(axis=1)([g00, g11, g01, g10])
        if model == "real_group_pwc_full":
            out = Conv2D(filters, (1, 1), name=name + '_pwc', **convArgs)(out)
        elif model == "real_group_pwc_group":
            out = pointwise_group(out, name)
        return out

    def model_conv(T, filters, kernel, suffix, strides):
        """Apply the d.model-specific conv named conv_name_base + suffix."""
        name = conv_name_base + suffix
        kw = conv_kwargs(strides)
        if model == "real":
            return Conv2D(filters, kernel, name=name, **kw)(T)
        if model == "real_dws":
            return SeparableConv2D(filters, kernel, name=name, **kw)(T)
        if model in real_group_models:
            return grouped_conv(T, filters, kernel, name, strides)
        if model == "complex":
            return ComplexConv2D(filters, kernel, name=name, **kw)(T)
        # complex_concat variants: ComplexConvConcat2D with half the filters.
        out = ComplexConvConcat2D(filters // 2, kernel, name=name, **kw)(T)
        if model == "complex_concat_pwc_group":
            out = pointwise_group(out, name)
        return out

    bn_layer = BatchNormalization if is_real else ComplexBN

    O = bn_layer(name=bn_name_base + '_2a', **bnArgs)(I)
    # The first activation honours complex_joint_relu for every model.
    if use_joint_relu:
        O = Lambda(ComplexJointReLU)(O)
    else:
        O = Activation(activation)(O)

    if shortcut == 'regular' or d.spectral_pool_scheme == "nodownsample":
        O = model_conv(O, nb_fmaps1, filter_size, '2a', None)
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            O = applySpectralPooling(O, d)
        O = model_conv(O, nb_fmaps1, filter_size, '2a', (2, 2))

    O = bn_layer(name=bn_name_base + '_2b', **bnArgs)(O)
    # The second activation honours complex_joint_relu for complex models only.
    if use_joint_relu and not is_real:
        O = Lambda(ComplexJointReLU)(O)
    else:
        O = Activation(activation)(O)
    O = model_conv(O, nb_fmaps2, filter_size, '2b', None)

    if shortcut == 'regular':
        O = Add()([O, I])
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            I = applySpectralPooling(I, d)
        strides = (2, 2) if d.spectral_pool_scheme != "nodownsample" else (1, 1)
        X = model_conv(I, nb_fmaps2, (1, 1), '1', strides)
        if is_real:
            O = Concatenate(channel_axis)([X, O])
        else:
            O_real = Concatenate(channel_axis)([GetReal()(X), GetReal()(O)])
            O_imag = Concatenate(channel_axis)([GetImag()(X), GetImag()(O)])
            # Re-join along the channel axis (was hard-coded to axis 1,
            # which is only the channel axis in channels_first layout).
            O = Concatenate(channel_axis)([O_real, O_imag])

    return O
def getResnetModel(d):
    """Build a ResNet classifier for one of the supported model variants.

    Input is 28x28x1 for datasets whose name contains ``"mnist"`` and
    32x32x3 otherwise. See ``getResidualBlock`` for the supported
    ``d.model`` variants.

    Args:
        d: configuration object. Reads ``num_blocks``, ``start_filter``,
            ``dataset``, ``act``, ``aact``, ``model``,
            ``spectral_pool_scheme`` and, for complex models,
            ``spectral_param`` / ``comp_init``.

    Returns:
        A Keras ``Model`` producing softmax class probabilities.

    Raises:
        ValueError: for an unknown ``d.dataset`` or ``d.model``. Both are
            checked before any layer is created; previously an unknown model
            terminated the process with ``exit(-1)``.
    """
    import copy

    num_classes_by_dataset = {
        'cifar10': 10,
        'cifar100': 100,
        'svhn': 10,
        'mnist': 10,
        'fashion_mnist': 10,
    }
    if d.dataset not in num_classes_by_dataset:
        raise ValueError("Unknown dataset " + d.dataset)

    model = d.model
    real_group_models = ("real_group", "real_group_pwc_full",
                         "real_group_pwc_group")
    complex_concat_models = ("complex_concat", "complex_concat_pwc_group")
    known_models = (("real", "real_dws", "complex") + real_group_models +
                    complex_concat_models)
    if model not in known_models:
        raise ValueError("Unknown model type " + str(model))

    n = d.num_blocks
    sf = d.start_filter
    dataset = d.dataset
    activation = d.act

    # K.image_dim_ordering() is deprecated/removed in newer Keras;
    # image_data_format() carries the same information.
    channels_first = K.image_data_format() == 'channels_first'
    side, depth = (28, 1) if "mnist" in dataset else (32, 3)
    inputShape = (depth, side, side) if channels_first else (side, side, depth)
    channelAxis = 1 if channels_first else -1
    filsize = (3, 3)
    convArgs = {
        "padding": "same",
        "use_bias": False,
        "kernel_regularizer": l2(0.0001),
    }
    bnArgs = {"axis": channelAxis, "momentum": 0.9, "epsilon": 1e-04}

    # Snapshot taken before the model-specific update: for complex models,
    # the grouped 1x1 convs get the plain real-valued arguments.
    convArgs_real = copy.deepcopy(convArgs)
    if "real" in model:
        sf *= 2
        convArgs.update({"kernel_initializer": Orthogonal(float(np.sqrt(2)))})
        convArgs_real = convArgs
    elif "complex" in model:
        convArgs.update({
            "spectral_parametrization": d.spectral_param,
            "kernel_initializer": d.comp_init
        })

    def halves(T):
        """Split T into two halves along axis 1 (channels_first channels)."""
        lo = Lambda(lambda T: T[:, :(T.shape[1] // 2), :, :])(T)
        hi = Lambda(lambda T: T[:, (T.shape[1] // 2):, :, :])(T)
        return lo, hi

    def pointwise_group(T, name):
        """Two-group 1x1 conv; each group keeps half of T's channels."""
        T_g0, T_g1 = halves(T)
        half = int(T.shape[1] // 2)
        T_g0 = Conv2D(half, (1, 1), name=name + '_pwc_g0',
                      **convArgs_real)(T_g0)
        T_g1 = Conv2D(half, (1, 1), name=name + '_pwc_g1',
                      **convArgs_real)(T_g1)
        return Concatenate(axis=1)([T_g0, T_g1])

    I = Input(shape=inputShape)

    # Stage 1: concatenate the input with learnConcatRealImagBlock's output,
    # then the model-specific first conv + BN.
    O = learnConcatRealImagBlock(I, (1, 1), (3, 3), 0, '0', convArgs, bnArgs,
                                 d)
    O = Concatenate(channelAxis)([I, O])
    if model == "real":
        O = Conv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    elif model == "real_dws":
        O = SeparableConv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    elif model in real_group_models:
        O_g0, O_g1 = halves(O)
        O_g0 = Conv2D(sf // 2, filsize, name='conv1_g0', **convArgs)(O_g0)
        O_g1 = Conv2D(sf // 2, filsize, name='conv1_g1', **convArgs)(O_g1)
        O_g00, O_g01 = halves(O_g0)
        O_g10, O_g11 = halves(O_g1)
        # Interleave so each half of the result holds channels of both groups.
        O = Concatenate(axis=1)([O_g00, O_g11, O_g01, O_g10])
        if model == "real_group_pwc_full":
            O = Conv2D(sf, (1, 1), name='conv1_pwc', **convArgs)(O)
        elif model == "real_group_pwc_group":
            O = pointwise_group(O, 'conv1')
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    elif model == "complex":
        O = ComplexConv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = ComplexBN(name="bn_conv1_2a", **bnArgs)(O)
    else:  # complex_concat variants (validated above)
        O = ComplexConvConcat2D(sf // 2, filsize, name='conv1', **convArgs)(O)
        O = ComplexBN(name="bn_conv1_2a", **bnArgs)(O)
        if model == "complex_concat_pwc_group":
            O = pointwise_group(O, 'conv1')

    if d.aact == "complex_joint_relu":
        O = Lambda(ComplexJointReLU)(O)
    else:
        O = Activation(activation)(O)

    def residual(T, fmaps, stage, block, shortcut):
        """Append one residual block with `fmaps` filters in both convs."""
        return getResidualBlock(T, filsize, [fmaps, fmaps], stage, block,
                                shortcut, convArgs, convArgs_real, bnArgs, d)

    # Stage 2: n regular blocks at the starting width.
    # (range instead of the Python-2-only xrange.)
    for i in range(n):
        O = residual(O, sf, 2, str(i), 'regular')
        if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
            O = applySpectralPooling(O, d)

    # Stages 3 and 4: a projection block (whose shortcut is concatenated
    # with its output, doubling the width) followed by n-1 regular blocks.
    for stage, fmaps in ((3, sf), (4, sf * 2)):
        O = residual(O, fmaps, stage, '0', 'projection')
        if d.spectral_pool_scheme == "nodownsample":
            O = applySpectralPooling(O, d)
        for i in range(n - 1):
            O = residual(O, fmaps * 2, stage, str(i + 1), 'regular')
            if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
                O = applySpectralPooling(O, d)

    # Global average pooling: without downsampling the full input side
    # remains; otherwise the two stride-2 projection stages leave side // 4
    # (8 for 32x32 inputs, 7 for 28x28 inputs).
    if d.spectral_pool_scheme == "nodownsample":
        O = applySpectralPooling(O, d)
        pool_side = side
    else:
        pool_side = side // 4
    O = AveragePooling2D(pool_size=(pool_side, pool_side))(O)

    O = Flatten()(O)
    O = Dense(num_classes_by_dataset[dataset], activation='softmax',
              kernel_regularizer=l2(0.0001))(O)

    return Model(I, O)