Ejemplo n.º 1
0
def comResnet(input, d):
    """Build a complex-valued residual network on top of ``input``.

    Graph: a learned block (``learnConcatRealImagBlock``) whose output is
    concatenated with the input along the channel axis, a 64-filter 9x9
    complex conv/BN/ReLU stem, four 'regular' residual blocks with 64 feature
    maps, two further 64-filter 9x9 complex conv/BN/ReLU stages and a final
    3-filter 9x9 complex conv/BN followed by ``tanh``.

    Args:
        input: input Keras tensor (the parameter name is kept for callers even
            though it shadows the builtin inside this function).
        d: configuration object; reads ``d.spectral_param`` and
            ``d.comp_init`` here and is passed on to the block builders.

    Returns:
        The output Keras tensor (not a ``Model``).
    """
    channelAxis = 1 if K.image_data_format() == 'channels_first' else -1
    convArgs = {
        "padding": "same",
        "use_bias": False,
        "kernel_regularizer": l2(0.0001),
        "spectral_parametrization": d.spectral_param,
        "kernel_initializer": d.comp_init,
    }
    bnArgs = {"axis": channelAxis, "momentum": 0.9, "epsilon": 1e-04}

    O = learnConcatRealImagBlock(input, (1, 1), (3, 3), 0, '0', convArgs,
                                 bnArgs, d)
    O = Concatenate(channelAxis)([input, O])
    O = ComplexConv2D(filters=64, kernel_size=9, name='conv1', **convArgs)(O)
    # Underscored names: layer names with spaces are not valid TF scope names.
    O = ComplexBN(name='bn_conv1_2a', **bnArgs)(O)
    # Keras Activation layers (instead of raw tf.nn ops) keep the graph made of
    # Keras layers so a Model can be built from the returned tensor.
    O = Activation('relu')(O)

    # Residual stack. getResidualBlock derives its layer names from
    # (stage, block), so every block needs a distinct block id; reusing '0'
    # produced duplicate layer names, which Keras rejects when building a Model.
    for block_idx in range(4):
        O = getResidualBlock(O, (3, 3), [64, 64], 2, str(block_idx),
                             'regular', convArgs, bnArgs, d)

    for conv_idx in (2, 3):
        O = ComplexConv2D(filters=64, kernel_size=9, name='conv%d' % conv_idx,
                          **convArgs)(O)
        O = ComplexBN(name='bn_conv%d_1a' % conv_idx, **bnArgs)(O)
        O = Activation('relu')(O)

    O = ComplexConv2D(filters=3, kernel_size=9, name='conv4', **convArgs)(O)
    O = ComplexBN(name='bn_conv4_1a', **bnArgs)(O)
    return Activation('tanh')(O)
Ejemplo n.º 2
0
 def __init__(self):
     """Create the complex 1-D conv feature extractor and two dense heads.

     ``self.main`` is five repetitions of ComplexConv1D (no bias) ->
     ComplexBN -> LeakyReLU(0.2, inplace), followed by two ComplexDense
     layers mapping 2112 -> 1025 features.
     """
     super(DCSNet_bn, self).__init__()
     # (in_channels, out_channels, kernel_size, stride, padding) per stage.
     # The trailing comments are the original "state size" annotations
     # (channels x length) of the tensor entering each stage; the channel
     # counts appear to count real and imaginary parts separately - confirm.
     conv_specs = (
         (1, 16, 7, 2, 3),    # state size. 2 x 1025
         (16, 32, 5, 2, 2),   # state size. 32 x 513
         (32, 32, 3, 2, 1),   # state size. 64 x 257
         (32, 64, 3, 2, 1),   # state size. 64 x 129
         (64, 64, 3, 2, 1),   # state size. 128 x 65
     )
     layers = []
     for c_in, c_out, width, stride, pad in conv_specs:
         layers.append(ComplexConv1D(c_in, c_out, width, stride, pad,
                                     bias=False))
         layers.append(ComplexBN(c_out))
         layers.append(nn.LeakyReLU(0.2, inplace=True))
     # state size after the last stage: 128 x 33
     self.main = nn.Sequential(*layers)
     # 2112 == 64 * 33, presumably the flattened output of self.main - confirm.
     self.dense1 = ComplexDense(2112, 1025, init_criterion='he')
     self.dense2 = ComplexDense(2112, 1025, init_criterion='he')
Ejemplo n.º 3
0
def getResidualBlock(I, filter_size, featmaps, stage, block, shortcut,
                     convArgs, bnArgs, d):
    """Build a pre-activation residual block and return its output tensor.

    The residual branch is BN -> activation -> conv ('2a') -> BN ->
    activation -> conv ('2b'). Real Keras layers are used when
    ``d.model == "real"``; the complex counterparts when ``d.model == "complex"``.

    Args:
        I: input Keras tensor.
        filter_size: kernel size of the two branch convolutions.
        featmaps: ``(nb_fmaps1, nb_fmaps2)`` filter counts for '2a' and '2b'.
        stage: used (with ``block``) only to build layer names.
        block: string block id appended to ``str(stage)`` in layer names.
        shortcut: 'regular' adds the branch output to ``I``. 'projection'
            uses a stride-2 '2a' conv (unless ``d.spectral_pool_scheme ==
            "nodownsample"``) and concatenates a 1x1-conv projection of ``I``
            with the branch output along the channel axis.
        convArgs: extra keyword arguments for every conv layer.
        bnArgs: extra keyword arguments for every BN layer.
        d: configuration object (reads ``model``, ``act`` and
            ``spectral_pool_scheme``).

    Returns:
        The block's output Keras tensor.

    Raises:
        ValueError: if ``d.model`` is neither "real" nor "complex".
    """
    # Pick the layer family once; any other model used to die later with
    # an UnboundLocalError.
    if d.model == "real":
        conv_layer, bn_layer = Conv2D, BatchNormalization
    elif d.model == "complex":
        conv_layer, bn_layer = ComplexConv2D, ComplexBN
    else:
        raise ValueError("Unknown model type: %r" % (d.model,))

    activation = d.act
    nb_fmaps1, nb_fmaps2 = featmaps
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    if K.image_data_format() == 'channels_first' and K.ndim(I) != 3:
        channel_axis = 1
    else:
        channel_axis = -1

    O = bn_layer(name=bn_name_base + '_2a', **bnArgs)(I)
    O = Activation(activation)(O)

    if shortcut == 'regular' or d.spectral_pool_scheme == "nodownsample":
        O = conv_layer(nb_fmaps1, filter_size, name=conv_name_base + '2a',
                       **convArgs)(O)
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            O = applySpectralPooling(O, d)
        O = conv_layer(nb_fmaps1, filter_size, name=conv_name_base + '2a',
                       strides=(2, 2), **convArgs)(O)

    O = bn_layer(name=bn_name_base + '_2b', **bnArgs)(O)
    O = Activation(activation)(O)
    O = conv_layer(nb_fmaps2, filter_size, name=conv_name_base + '2b',
                   **convArgs)(O)

    if shortcut == 'regular':
        O = Add()([O, I])
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            I = applySpectralPooling(I, d)
        proj_strides = ((2, 2) if d.spectral_pool_scheme != "nodownsample"
                        else (1, 1))
        X = conv_layer(nb_fmaps2, (1, 1), name=conv_name_base + '1',
                       strides=proj_strides, **convArgs)(I)
        if d.model == "real":
            O = Concatenate(channel_axis)([X, O])
        else:
            # Merge real parts and imaginary parts separately so all real
            # channels precede all imaginary channels in the result.
            O_real = Concatenate(channel_axis)([GetReal()(X), GetReal()(O)])
            O_imag = Concatenate(channel_axis)([GetImag()(X), GetImag()(O)])
            # Join on the channel axis; a hard-coded axis 1 is only the
            # channel axis for channels_first data.
            O = Concatenate(channel_axis)([O_real, O_imag])

    return O
Ejemplo n.º 4
0
def getResnetModel(d):
    """Build a real- or complex-valued ResNet classifier for 32x32x3 inputs.

    Structure: a stem (learned block concatenated with the input, conv, BN,
    activation), then three stages of residual blocks. Stage 2 has ``n``
    regular blocks of ``sf`` maps. Stages 3 and 4 each start with a
    'projection' block followed by ``n - 1`` regular blocks of ``2*sf`` and
    ``4*sf`` maps. These are followed by average pooling, flatten and a
    softmax Dense layer.

    Args:
        d: configuration object; reads ``num_blocks``, ``start_filter``,
            ``dataset``, ``act``, ``model``, ``spectral_param``, ``comp_init``
            and ``spectral_pool_scheme``.

    Returns:
        An (uncompiled) Keras ``Model``.

    Raises:
        ValueError: if ``d.dataset`` is not 'cifar10', 'cifar100' or 'svhn'.
    """
    # Validate the dataset before building any layers.
    num_classes = {'cifar10': 10, 'cifar100': 100, 'svhn': 10}.get(d.dataset)
    if num_classes is None:
        raise ValueError("Unknown dataset " + d.dataset)

    n = d.num_blocks
    sf = d.start_filter
    activation = d.act
    # Derive both the input shape and the channel axis from the same query;
    # the legacy K.image_dim_ordering() no longer exists in recent Keras.
    channels_first = K.image_data_format() == 'channels_first'
    inputShape = (3, 32, 32) if channels_first else (32, 32, 3)
    channelAxis = 1 if channels_first else -1
    filsize = (3, 3)
    convArgs = {
        "padding": "same",
        "use_bias": False,
        "kernel_regularizer": l2(0.0001),
    }
    bnArgs = {"axis": channelAxis, "momentum": 0.9, "epsilon": 1e-04}

    if d.model == "real":
        sf *= 2
        convArgs.update({"kernel_initializer": Orthogonal(float(np.sqrt(2)))})
    elif d.model == "complex":
        convArgs.update({
            "spectral_parametrization": d.spectral_param,
            "kernel_initializer": d.comp_init
        })

    # Input layer
    I = Input(shape=inputShape)

    # Stage 1: stem
    O = learnConcatRealImagBlock(I, (1, 1), (3, 3), 0, '0', convArgs, bnArgs,
                                 d)
    O = Concatenate(channelAxis)([I, O])
    if d.model == "real":
        O = Conv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    else:
        O = ComplexConv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = ComplexBN(name="bn_conv1_2a", **bnArgs)(O)
    O = Activation(activation)(O)

    # Stage 2: n regular blocks
    for i in range(n):
        O = getResidualBlock(O, filsize, [sf, sf], 2, str(i), 'regular',
                             convArgs, bnArgs, d)
        if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
            O = applySpectralPooling(O, d)

    # Stages 3 and 4: a projection block (whose output concatenates shortcut
    # and branch, doubling the width) followed by n - 1 regular blocks.
    for stage, proj_width, block_width in ((3, sf, sf * 2),
                                           (4, sf * 2, sf * 4)):
        O = getResidualBlock(O, filsize, [proj_width, proj_width], stage, '0',
                             'projection', convArgs, bnArgs, d)
        if d.spectral_pool_scheme == "nodownsample":
            O = applySpectralPooling(O, d)

        for i in range(n - 1):
            O = getResidualBlock(O, filsize, [block_width, block_width],
                                 stage, str(i + 1), 'regular', convArgs,
                                 bnArgs, d)
            if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
                O = applySpectralPooling(O, d)

    # Pooling: without stride-2 projections the feature map stays 32x32;
    # otherwise two stride-2 stages reduce it to 8x8.
    if d.spectral_pool_scheme == "nodownsample":
        O = applySpectralPooling(O, d)
        O = AveragePooling2D(pool_size=(32, 32))(O)
    else:
        O = AveragePooling2D(pool_size=(8, 8))(O)

    O = Flatten()(O)
    O = Dense(num_classes, activation='softmax',
              kernel_regularizer=l2(0.0001))(O)

    return Model(I, O)
Ejemplo n.º 5
0
def get_deep_convnet(window_size=4096, channels=2, output_size=84):
    """Build and compile a 1-D complex-valued convnet with sigmoid outputs.

    Five stages of ComplexConv1D (linear) -> ComplexBN -> ReLU ->
    AveragePooling1D(2). The last stage is preceded by an extra
    ReLU-activated ComplexConv1D. These are followed by
    Flatten -> Dense(2048, relu) -> Dense(output_size, sigmoid).

    Args:
        window_size: number of time steps in each input window.
        channels: number of input channels per time step.
        output_size: number of sigmoid output units.

    Returns:
        A compiled Keras ``Model`` (Adam with lr=1e-4, binary cross-entropy
        loss, accuracy metric).
    """
    def complex_conv(filters, width, stride, act):
        # Every conv in this network shares padding and initializer.
        return ComplexConv1D(filters, width, strides=stride, padding='same',
                             activation=act,
                             kernel_initializer='complex_independent')

    # (filters, kernel width, stride, prepend an extra ReLU conv?)
    stage_specs = [
        (16, 6, 2, False),
        (32, 3, 2, False),
        (64, 3, 1, False),
        (64, 3, 1, False),
        (128, 3, 1, True),
    ]

    inputs = Input(shape=(window_size, channels))
    hidden = inputs
    for filters, width, stride, extra_conv in stage_specs:
        if extra_conv:
            hidden = complex_conv(filters, width, stride, 'relu')(hidden)
        hidden = complex_conv(filters, width, stride, 'linear')(hidden)
        hidden = ComplexBN(axis=-1)(hidden)
        hidden = keras.layers.Activation('relu')(hidden)
        hidden = keras.layers.AveragePooling1D(pool_size=2, strides=2)(hidden)

    hidden = keras.layers.Flatten()(hidden)
    hidden = keras.layers.Dense(2048, activation='relu',
                                kernel_initializer='glorot_normal')(hidden)
    # Negative initial bias starts every sigmoid output close to zero.
    output_layer = keras.layers.Dense(
        output_size, activation='sigmoid',
        bias_initializer=keras.initializers.Constant(value=-5))
    predictions = output_layer(hidden)

    model = Model(inputs=inputs, outputs=predictions)
    model.compile(optimizer=keras.optimizers.Adam(lr=1e-4),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
#                  activation='relu',
#                  kernel_initializer='he_normal',
#                  input_shape=input_shape))
# NOTE(review): module-level script fragment. `model`, `input_shape`,
# `bnArgs`, `convArgs` and `COMPLEX_MODE` are defined outside this excerpt,
# and the three commented lines at the top of this fragment are the tail of a
# commented-out model.add(...) call whose beginning is not visible here.
# Stage 1: real Conv2D(32, 3x3, he_normal) -> BN -> ReLU -> 2x2 max-pool -> dropout.
model.add(
    Conv2D(32,
           kernel_size=(3, 3),
           kernel_initializer='he_normal',
           input_shape=input_shape))
model.add(BatchNormalization(**bnArgs))
model.add(Activation('relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Dropout(0.25))
# Stage 2: a complex conv with 32 filters, or a real conv with 32 * 2 filters
# (the doubling presumably keeps the real and complex variants comparable in
# size - TODO confirm), followed by ReLU, max-pool and dropout.
if COMPLEX_MODE:
    # model.add(ComplexConv2D(32, (3, 3)))
    model.add(ComplexConv2D(32, (3, 3), **convArgs))
    model.add(ComplexBN(**bnArgs))
else:
    # model.add(Conv2D(64, (3, 3), activation='relu'))
    # model.add(Conv2D(32*2, (3, 3)))
    model.add(Conv2D(32 * 2, (3, 3), **convArgs))
    model.add(BatchNormalization(**bnArgs))
model.add(Activation('relu'))
# O = ComplexConv2D(sf, filsize, name='conv1', **convArgs)(O)
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
# Stage 3: complex conv with 64 filters (the real branch is cut off below).
if COMPLEX_MODE:
    # model.add(ComplexConv2D(64, (3, 3)))
    model.add(ComplexConv2D(64, (3, 3), **convArgs))
    model.add(ComplexBN(**bnArgs))
else:
    # model.add(Conv2D(128, (3, 3), activation='relu'))
    # NOTE(review): this else-branch has no statements in this excerpt - the
    # snippet appears truncated here, and as written Python rejects it
    # (an indented block is expected). TODO restore the missing body.
def _split_channels_half(x):
    """Split ``x`` into (first half, second half) of axis 1 via two Lambdas.

    Axis 1 is hard-coded, i.e. this assumes a channels-first 4-D tensor, as
    the grouped model variants always have.
    """
    first = Lambda(lambda t: t[:, :(t.shape[1] // 2), :, :])(x)
    second = Lambda(lambda t: t[:, (t.shape[1] // 2):, :, :])(x)
    return first, second


def _grouped_shuffle_conv(x, group_filters, kernel_size, name, conv_kw):
    """Two-group Conv2D over the channel halves of ``x``, then a shuffle.

    The groups are layers ``name + '_g0'`` / ``name + '_g1'`` with
    ``group_filters`` filters each. Each group output is halved again and the
    quarters are concatenated (axis 1) in the order [g0 first, g1 second,
    g0 second, g1 first]. The original code notes this ordering allows
    permutation of the odd-numbered outputs.
    """
    x_g0, x_g1 = _split_channels_half(x)
    x_g0 = Conv2D(group_filters, kernel_size, name=name + '_g0',
                  **conv_kw)(x_g0)
    x_g1 = Conv2D(group_filters, kernel_size, name=name + '_g1',
                  **conv_kw)(x_g1)
    x_g00, x_g01 = _split_channels_half(x_g0)
    x_g10, x_g11 = _split_channels_half(x_g1)
    return Concatenate(axis=1)([x_g00, x_g11, x_g01, x_g10])


def _pointwise_group_conv(x, name, convArgs_real):
    """Two-group 1x1 Conv2D that keeps the channel count of ``x``.

    Each half of axis 1 gets ``x.shape[1] // 2`` filters (layers
    ``name + '_pwc_g0'`` / ``name + '_pwc_g1'``); results are concatenated
    back along axis 1.
    """
    x_g0, x_g1 = _split_channels_half(x)
    half = int(x.shape[1] // 2)
    x_g0 = Conv2D(half, (1, 1), name=name + '_pwc_g0', **convArgs_real)(x_g0)
    x_g1 = Conv2D(half, (1, 1), name=name + '_pwc_g1', **convArgs_real)(x_g1)
    return Concatenate(axis=1)([x_g0, x_g1])


def getResidualBlock(I, filter_size, featmaps, stage, block, shortcut,
                     convArgs, convArgs_real, bnArgs, d):
    """Build a pre-activation residual block for one of several model types.

    The branch is BN -> activation -> conv ('2a') -> BN -> activation ->
    conv ('2b'). The conv flavour depends on ``d.model``:

    * "real": Conv2D
    * "real_dws": SeparableConv2D
    * "real_group*": two-group Conv2D with a channel shuffle, optionally
      followed by a full ("_pwc_full") or two-group ("_pwc_group") 1x1 conv
    * "complex": ComplexConv2D
    * "complex_concat*": ComplexConvConcat2D with half the filters,
      optionally followed by a two-group 1x1 conv ("_pwc_group")

    Args:
        I: input Keras tensor.
        filter_size: kernel size of the '2a'/'2b' convolutions.
        featmaps: ``(nb_fmaps1, nb_fmaps2)`` filter counts for '2a' and '2b'.
        stage: used (with ``block``) only to build layer names.
        block: string block id appended to ``str(stage)`` in layer names.
        shortcut: 'regular' adds the branch output to ``I``. 'projection'
            uses a stride-2 '2a' conv (unless ``d.spectral_pool_scheme ==
            "nodownsample"``) and concatenates a 1x1 projection of ``I`` with
            the branch output along the channel axis.
        convArgs: keyword arguments for the main conv layers.
        convArgs_real: keyword arguments for the grouped 1x1 ("pwc_group")
            convs.
        bnArgs: keyword arguments for the BN layers.
        d: configuration object (reads ``model``, ``act``, ``aact`` and
            ``spectral_pool_scheme``).

    Returns:
        The block's output Keras tensor.

    Raises:
        ValueError: if ``d.model`` is not one of the supported model types.
            This used to print a message and call ``exit(-1)``, or hit an
            UnboundLocalError for names containing neither "real" nor
            "complex".
    """
    real_group_models = ("real_group", "real_group_pwc_full",
                         "real_group_pwc_group")
    complex_concat_models = ("complex_concat", "complex_concat_pwc_group")
    real_models = ("real", "real_dws") + real_group_models
    complex_models = ("complex",) + complex_concat_models

    model = d.model
    if model not in real_models + complex_models:
        raise ValueError("Unknown model type: %r" % (model,))

    activation = d.act
    nb_fmaps1, nb_fmaps2 = featmaps
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    if K.image_data_format() == 'channels_first' and K.ndim(I) != 3:
        channel_axis = 1
    else:
        channel_axis = -1

    def complex_aware_act(x):
        # The joint complex ReLU is selectable via d.aact; otherwise the
        # plain Keras activation named by d.act is used.
        if d.aact == "complex_joint_relu":
            return Lambda(ComplexJointReLU)(x)
        return Activation(activation)(x)

    def conv_stage(x, filters, kernel_size, suffix, **stride_kw):
        # One model-specific convolution named conv_name_base + suffix.
        name = conv_name_base + suffix
        conv_kw = dict(convArgs, **stride_kw)
        if model == "real":
            return Conv2D(filters, kernel_size, name=name, **conv_kw)(x)
        if model == "real_dws":
            return SeparableConv2D(filters, kernel_size, name=name,
                                   **conv_kw)(x)
        if model == "complex":
            return ComplexConv2D(filters, kernel_size, name=name,
                                 **conv_kw)(x)
        if model in real_group_models:
            x = _grouped_shuffle_conv(x, filters // 2, kernel_size, name,
                                      conv_kw)
            if model == "real_group_pwc_full":
                x = Conv2D(filters, (1, 1), name=name + '_pwc',
                           **convArgs)(x)
            elif model == "real_group_pwc_group":
                x = _pointwise_group_conv(x, name, convArgs_real)
            return x
        # complex_concat / complex_concat_pwc_group
        x = ComplexConvConcat2D(filters // 2, kernel_size, name=name,
                                **conv_kw)(x)
        if model == "complex_concat_pwc_group":
            x = _pointwise_group_conv(x, name, convArgs_real)
        return x

    is_real = model in real_models

    # First pre-activation.
    if is_real:
        O = BatchNormalization(name=bn_name_base + '_2a', **bnArgs)(I)
    else:
        O = ComplexBN(name=bn_name_base + '_2a', **bnArgs)(I)
    O = complex_aware_act(O)

    # '2a' conv: stride 2 only for a downsampling projection block.
    if shortcut == 'regular' or d.spectral_pool_scheme == "nodownsample":
        O = conv_stage(O, nb_fmaps1, filter_size, '2a')
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            O = applySpectralPooling(O, d)
        O = conv_stage(O, nb_fmaps1, filter_size, '2a', strides=(2, 2))

    # Second pre-activation (real models always use the plain activation).
    if is_real:
        O = BatchNormalization(name=bn_name_base + '_2b', **bnArgs)(O)
        O = Activation(activation)(O)
    else:
        O = ComplexBN(name=bn_name_base + '_2b', **bnArgs)(O)
        O = complex_aware_act(O)
    O = conv_stage(O, nb_fmaps2, filter_size, '2b')

    # Shortcut merge.
    if shortcut == 'regular':
        O = Add()([O, I])
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            I = applySpectralPooling(I, d)
        proj_strides = ((2, 2) if d.spectral_pool_scheme != "nodownsample"
                        else (1, 1))
        X = conv_stage(I, nb_fmaps2, (1, 1), '1', strides=proj_strides)
        if is_real:
            O = Concatenate(channel_axis)([X, O])
        else:
            # Keep all real channels before all imaginary channels.
            O_real = Concatenate(channel_axis)([GetReal()(X), GetReal()(O)])
            O_imag = Concatenate(channel_axis)([GetImag()(X), GetImag()(O)])
            # channel_axis rather than a hard-coded 1, which is only the
            # channel axis for channels_first data.
            O = Concatenate(channel_axis)([O_real, O_imag])

    return O
def getResnetModel(d):
    """Build a real- or complex-valued ResNet classifier as a Keras ``Model``.

    Layout: a learned real/imaginary input block concatenated with the raw
    input, a stage-1 convolution + batch norm whose flavour is selected by
    ``d.model``, an activation, three residual stages built with
    ``getResidualBlock``, average pooling, ``Flatten`` and a softmax ``Dense``
    classifier.

    Args:
        d: configuration object. Attributes read directly here:
            ``num_blocks``, ``start_filter``, ``dataset``, ``act``, ``aact``,
            ``model``, ``spectral_pool_scheme`` and, for complex models,
            ``spectral_param`` and ``comp_init``. ``d`` is also forwarded to
            ``learnConcatRealImagBlock``, ``getResidualBlock`` and
            ``applySpectralPooling``.

    Returns:
        A ``Model`` from the input tensor to the softmax output.

    Raises:
        ValueError: if ``d.dataset`` or ``d.model`` is not recognised.
    """
    import copy

    n = d.num_blocks
    sf = d.start_filter
    dataset = d.dataset
    activation = d.act

    # Number of output classes per supported dataset. Validate before building
    # any layers so a typo fails fast instead of after the whole graph exists.
    numClasses = {
        'cifar10': 10,
        'cifar100': 100,
        'svhn': 10,
        'mnist': 10,
        'fashion_mnist': 10,
    }
    if dataset not in numClasses:
        raise ValueError("Unknown dataset " + d.dataset)

    # MNIST-like data is 28x28 grayscale; everything else is 32x32 RGB.
    isMnist = "mnist" in dataset
    side = 28 if isMnist else 32
    channels = 1 if isMnist else 3
    # Use image_data_format() for both the input shape and the channel axis
    # (image_dim_ordering() is a removed Keras-1 API; 'th' == channels_first).
    channelsFirst = K.image_data_format() == 'channels_first'
    if channelsFirst:
        inputShape = (channels, side, side)
    else:
        inputShape = (side, side, channels)
    channelAxis = 1 if channelsFirst else -1

    filsize = (3, 3)
    convArgs = {
        "padding": "same",
        "use_bias": False,
        "kernel_regularizer": l2(0.0001),
    }
    bnArgs = {"axis": channelAxis, "momentum": 0.9, "epsilon": 1e-04}

    # Snapshot of the plain arguments, used for ordinary (real) Conv2D layers
    # such as the pointwise group convolutions of the complex models; taken
    # before any complex-only keys are added below.
    convArgs_real = copy.deepcopy(convArgs)

    if "real" in d.model:
        # Real models get twice the filters (presumably to roughly match the
        # parameter count of the complex models — confirm).
        sf *= 2
        convArgs.update({"kernel_initializer": Orthogonal(float(np.sqrt(2)))})
        convArgs_real = convArgs
    elif "complex" in d.model:
        convArgs.update({
            "spectral_parametrization": d.spectral_param,
            "kernel_initializer": d.comp_init
        })

    #
    # Input Layer
    #

    I = Input(shape=inputShape)

    #
    # Stage 1
    #

    O = learnConcatRealImagBlock(I, (1, 1), (3, 3), 0, '0', convArgs, bnArgs,
                                 d)
    O = Concatenate(channelAxis)([I, O])
    if d.model == "real":
        O = Conv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    elif d.model == "real_dws":
        O = SeparableConv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    elif d.model in ("real_group", "real_group_pwc_full",
                     "real_group_pwc_group"):
        # Two-group convolution; the halves are split along axis 1
        # (assumes channels_first data).
        O_g0 = Lambda(lambda O: O[:, :(O.shape[1] // 2), :, :])(O)
        O_g1 = Lambda(lambda O: O[:, (O.shape[1] // 2):, :, :])(O)
        O_g0 = Conv2D(sf // 2, filsize, name='conv1_g0', **convArgs)(O_g0)
        O_g1 = Conv2D(sf // 2, filsize, name='conv1_g1', **convArgs)(O_g1)
        # Interleave the quarters of the two group outputs.
        O_g00 = Lambda(lambda O_g0: O_g0[:, :(O_g0.shape[1] // 2), :, :])(O_g0)
        O_g01 = Lambda(lambda O_g0: O_g0[:, (O_g0.shape[1] // 2):, :, :])(O_g0)
        O_g10 = Lambda(lambda O_g1: O_g1[:, :(O_g1.shape[1] // 2), :, :])(O_g1)
        O_g11 = Lambda(lambda O_g1: O_g1[:, (O_g1.shape[1] // 2):, :, :])(O_g1)
        O = Concatenate(axis=1)([O_g00, O_g11, O_g01, O_g10])
        if d.model == "real_group_pwc_full":
            O = Conv2D(sf, (1, 1), name='conv1_pwc', **convArgs)(O)
        elif d.model == "real_group_pwc_group":
            O_g0 = Lambda(lambda O: O[:, :(O.shape[1] // 2), :, :])(O)
            O_g1 = Lambda(lambda O: O[:, (O.shape[1] // 2):, :, :])(O)
            O_g0 = Conv2D(int(O.shape[1] // 2), (1, 1),
                          name='conv1_pwc_g0',
                          **convArgs_real)(O_g0)
            O_g1 = Conv2D(int(O.shape[1] // 2), (1, 1),
                          name='conv1_pwc_g1',
                          **convArgs_real)(O_g1)
            O = Concatenate(axis=1)([O_g0, O_g1])
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    elif d.model == "complex":
        O = ComplexConv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = ComplexBN(name="bn_conv1_2a", **bnArgs)(O)
    elif d.model in ("complex_concat", "complex_concat_pwc_group"):
        O = ComplexConvConcat2D(sf // 2, filsize, name='conv1', **convArgs)(O)
        O = ComplexBN(name="bn_conv1_2a", **bnArgs)(O)
        if d.model == "complex_concat_pwc_group":
            O_g0 = Lambda(lambda O: O[:, :(O.shape[1] // 2), :, :])(O)
            O_g1 = Lambda(lambda O: O[:, (O.shape[1] // 2):, :, :])(O)
            O_g0 = Conv2D(int(O.shape[1] // 2), (1, 1),
                          name='conv1_pwc_g0',
                          **convArgs_real)(O_g0)
            O_g1 = Conv2D(int(O.shape[1] // 2), (1, 1),
                          name='conv1_pwc_g1',
                          **convArgs_real)(O_g1)
            O = Concatenate(axis=1)([O_g0, O_g1])
    else:
        # Raise instead of print + exit(-1): a model builder must not kill
        # the interpreter, and this matches the unknown-dataset error above.
        raise ValueError("Unknown model type " + d.model)

    if d.aact == "complex_joint_relu":
        O = Lambda(ComplexJointReLU)(O)
    else:
        O = Activation(activation)(O)

    #
    # Stage 2
    #

    for i in range(n):
        O = getResidualBlock(O, filsize, [sf, sf], 2, str(i), 'regular',
                             convArgs, convArgs_real, bnArgs, d)
        if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
            O = applySpectralPooling(O, d)

    #
    # Stage 3
    #

    O = getResidualBlock(O, filsize, [sf, sf], 3, '0', 'projection', convArgs,
                         convArgs_real, bnArgs, d)
    if d.spectral_pool_scheme == "nodownsample":
        O = applySpectralPooling(O, d)

    for i in range(n - 1):
        O = getResidualBlock(O, filsize, [sf * 2, sf * 2], 3, str(i + 1),
                             'regular', convArgs, convArgs_real, bnArgs, d)
        if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
            O = applySpectralPooling(O, d)

    #
    # Stage 4
    #

    O = getResidualBlock(O, filsize, [sf * 2, sf * 2], 4, '0', 'projection',
                         convArgs, convArgs_real, bnArgs, d)
    if d.spectral_pool_scheme == "nodownsample":
        O = applySpectralPooling(O, d)

    for i in range(n - 1):
        O = getResidualBlock(O, filsize, [sf * 4, sf * 4], 4, str(i + 1),
                             'regular', convArgs, convArgs_real, bnArgs, d)
        if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
            O = applySpectralPooling(O, d)

    #
    # Pooling
    #

    if d.spectral_pool_scheme == "nodownsample":
        O = applySpectralPooling(O, d)
        # Feature maps keep the full input resolution in this scheme.
        poolSide = side
    else:
        # Feature maps end up 4x smaller than the input (28 -> 7, 32 -> 8),
        # presumably from the stride-2 projection blocks of stages 3 and 4.
        poolSide = side // 4
    O = AveragePooling2D(pool_size=(poolSide, poolSide))(O)

    #
    # Flatten
    #

    O = Flatten()(O)

    #
    # Dense
    #

    O = Dense(numClasses[dataset], activation='softmax',
              kernel_regularizer=l2(0.0001))(O)

    # Return the model
    return Model(I, O)