def comResnet(input, d):
    """Build a complex-valued residual conv stack on top of ``input``.

    Pipeline: ``learnConcatRealImagBlock`` output concatenated with the
    input, a 9x9 complex conv + complex BN + ReLU stem, four regular
    residual blocks (64 feature maps each), two more 9x9 conv/BN/ReLU
    stages, and a final 9x9 complex conv with 3 filters + complex BN + tanh.

    Args:
        input: Input tensor. NOTE(review): an unused ``inputShape = (100,
            100, 3)`` in the earlier version suggests 100x100 RGB images —
            confirm with the caller.
        d: Configuration object; ``spectral_param`` and ``comp_init`` are
            read here, and ``d`` is forwarded to the helper blocks.

    Returns:
        The output tensor (not a ``Model``; the caller wraps it).

    NOTE(review): this calls ``getResidualBlock`` with the 9-argument
    signature (no ``convArgs_real``). If the 10-argument variant of
    ``getResidualBlock`` is defined later in the same module, it shadows the
    9-argument one and these calls fail — confirm which one is in scope.
    """
    channelAxis = 1 if K.image_data_format() == 'channels_first' else -1
    convArgs = {
        "padding": "same",
        "use_bias": False,
        "kernel_regularizer": l2(0.0001),
        "spectral_parametrization": d.spectral_param,
        "kernel_initializer": d.comp_init,
    }
    bnArgs = {"axis": channelAxis, "momentum": 0.9, "epsilon": 1e-04}

    O = learnConcatRealImagBlock(input, (1, 1), (3, 3), 0, '0', convArgs, bnArgs, d)
    O = Concatenate(channelAxis)([input, O])

    # Layer names must be valid TF scope names (no spaces), so use underscores
    # like the rest of the file.
    O = ComplexConv2D(filters=64, kernel_size=9, name='conv1', **convArgs)(O)
    O = ComplexBN(name='bn_conv1_2a', **bnArgs)(O)
    # Keras Activation layers instead of raw tf ops keep the functional graph
    # traceable by Model().
    O = Activation('relu')(O)

    # Residual stage. Each block needs its own block id: getResidualBlock
    # derives layer names from (stage, block), and Keras requires all layer
    # names in a model to be unique.
    for i in range(4):
        O = getResidualBlock(O, (3, 3), [64, 64], 2, str(i), 'regular', convArgs, bnArgs, d)

    for idx in (2, 3):
        O = ComplexConv2D(filters=64, kernel_size=9, name='conv%d' % idx, **convArgs)(O)
        O = ComplexBN(name='bn_conv%d_1a' % idx, **bnArgs)(O)
        O = Activation('relu')(O)

    O = ComplexConv2D(filters=3, kernel_size=9, name='conv4', **convArgs)(O)
    O = ComplexBN(name='bn_conv4_1a', **bnArgs)(O)
    O = Activation('tanh')(O)
    return O
def getResidualBlock(I, filter_size, featmaps, stage, block, shortcut, convArgs, bnArgs, d):
    """Get residual block.

    Builds one pre-activation residual block (BN -> activation -> conv,
    twice) for a real (``d.model == "real"``) or complex
    (``d.model == "complex"``) network.

    Args:
        I: Input tensor of the block.
        filter_size: Kernel size of the two main convolutions.
        featmaps: ``[nb_fmaps1, nb_fmaps2]`` filters of the '2a' and '2b' convs.
        stage, block: Used to build unique layer names
            (``res<stage><block>_branch...`` / ``bn<stage><block>_branch...``).
        shortcut: ``'regular'`` adds the input back (identity shortcut).
            ``'projection'`` downsamples and concatenates a 1x1-conv
            projection of the input with the branch output, so the channel
            count grows.
        convArgs, bnArgs: Extra keyword arguments for conv / BN layers.
        d: Configuration object (``act``, ``model``, ``spectral_pool_scheme``).

    Returns:
        The output tensor of the block.

    Raises:
        ValueError: if ``d.model`` is neither "real" nor "complex".
    """
    if d.model not in ("real", "complex"):
        # Previously this surfaced as an UnboundLocalError on ``O``.
        raise ValueError("Unknown model type: %r" % (d.model,))

    activation = d.act
    nb_fmaps1, nb_fmaps2 = featmaps
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    if K.image_data_format() == 'channels_first' and K.ndim(I) != 3:
        channel_axis = 1
    else:
        channel_axis = -1

    is_real = d.model == "real"
    Conv = Conv2D if is_real else ComplexConv2D
    BN = BatchNormalization if is_real else ComplexBN

    O = BN(name=bn_name_base + '_2a', **bnArgs)(I)
    O = Activation(activation)(O)

    if shortcut == 'regular' or d.spectral_pool_scheme == "nodownsample":
        O = Conv(nb_fmaps1, filter_size, name=conv_name_base + '2a', **convArgs)(O)
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            O = applySpectralPooling(O, d)
        O = Conv(nb_fmaps1, filter_size, name=conv_name_base + '2a', strides=(2, 2), **convArgs)(O)

    O = BN(name=bn_name_base + '_2b', **bnArgs)(O)
    O = Activation(activation)(O)
    O = Conv(nb_fmaps2, filter_size, name=conv_name_base + '2b', **convArgs)(O)

    if shortcut == 'regular':
        O = Add()([O, I])
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            I = applySpectralPooling(I, d)
        strides = (2, 2) if d.spectral_pool_scheme != "nodownsample" else (1, 1)
        X = Conv(nb_fmaps2, (1, 1), name=conv_name_base + '1', strides=strides, **convArgs)(I)
        if is_real:
            O = Concatenate(channel_axis)([X, O])
        else:
            # Keep the [real parts | imaginary parts] channel layout: gather the
            # real halves of both tensors, then the imaginary halves.
            O_real = Concatenate(channel_axis)([GetReal()(X), GetReal()(O)])
            O_imag = Concatenate(channel_axis)([GetImag()(X), GetImag()(O)])
            # Join along the channel axis; a hard-coded axis 1 would be the
            # height axis for channels_last data.
            O = Concatenate(channel_axis)([O_real, O_imag])
    return O
def getResnetModel(d):
    """Build the real- or complex-valued ResNet classifier configured by ``d``.

    Layout: output of ``learnConcatRealImagBlock`` concatenated with the
    input, a stem conv + BN + activation, residual stages 2-4 (stages 3 and
    4 open with a 'projection' block), average pooling, and a softmax
    ``Dense`` classifier.

    Args:
        d: Configuration object. Read here: ``num_blocks``, ``start_filter``,
            ``dataset``, ``act``, ``model``, ``spectral_pool_scheme`` and, for
            complex models, ``spectral_param`` / ``comp_init``. The helper
            blocks may read more.

    Returns:
        An uncompiled ``Model`` from the image input to class probabilities.

    Raises:
        ValueError: if ``d.dataset`` is not one of the supported datasets.
    """
    # Output classes per supported dataset.
    num_classes = {'cifar10': 10, 'cifar100': 100, 'svhn': 10}
    dataset = d.dataset
    if dataset not in num_classes:
        # Fail before building any layers.
        raise ValueError("Unknown dataset " + d.dataset)

    n = d.num_blocks
    sf = d.start_filter
    activation = d.act
    # Use image_data_format() for both the input shape and the channel axis;
    # the legacy image_dim_ordering() getter is absent in newer Keras.
    channels_first = K.image_data_format() == 'channels_first'
    inputShape = (3, 32, 32) if channels_first else (32, 32, 3)
    channelAxis = 1 if channels_first else -1
    filsize = (3, 3)
    convArgs = {
        "padding": "same",
        "use_bias": False,
        "kernel_regularizer": l2(0.0001),
    }
    bnArgs = {"axis": channelAxis, "momentum": 0.9, "epsilon": 1e-04}
    if d.model == "real":
        # Real models use twice the start filters (presumably to offset the
        # complex model's paired real/imaginary channels — confirm).
        sf *= 2
        convArgs.update({"kernel_initializer": Orthogonal(float(np.sqrt(2)))})
    elif d.model == "complex":
        convArgs.update({
            "spectral_parametrization": d.spectral_param,
            "kernel_initializer": d.comp_init,
        })

    # Input layer
    I = Input(shape=inputShape)

    # Stage 1: stem
    O = learnConcatRealImagBlock(I, (1, 1), (3, 3), 0, '0', convArgs, bnArgs, d)
    O = Concatenate(channelAxis)([I, O])
    if d.model == "real":
        O = Conv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    else:
        O = ComplexConv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = ComplexBN(name="bn_conv1_2a", **bnArgs)(O)
    O = Activation(activation)(O)

    # Stage 2: regular blocks at the stem width.
    for i in range(n):
        O = getResidualBlock(O, filsize, [sf, sf], 2, str(i), 'regular', convArgs, bnArgs, d)
        if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
            O = applySpectralPooling(O, d)

    # Stages 3 and 4: a projection block (its output concatenates shortcut and
    # branch, doubling the width) followed by n - 1 regular blocks.
    for stage, width in ((3, sf), (4, sf * 2)):
        O = getResidualBlock(O, filsize, [width, width], stage, '0', 'projection', convArgs, bnArgs, d)
        if d.spectral_pool_scheme == "nodownsample":
            O = applySpectralPooling(O, d)
        for i in range(n - 1):
            O = getResidualBlock(O, filsize, [width * 2, width * 2], stage, str(i + 1), 'regular',
                                 convArgs, bnArgs, d)
            if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
                O = applySpectralPooling(O, d)

    # Pooling: without strided projections the maps stay 32x32; otherwise two
    # stride-2 projection stages reduce them to 8x8.
    if d.spectral_pool_scheme == "nodownsample":
        O = applySpectralPooling(O, d)
        O = AveragePooling2D(pool_size=(32, 32))(O)
    else:
        O = AveragePooling2D(pool_size=(8, 8))(O)

    O = Flatten()(O)
    O = Dense(num_classes[dataset], activation='softmax', kernel_regularizer=l2(0.0001))(O)
    return Model(I, O)
# model.add(Conv2D(32, kernel_size=(3, 3), # activation='relu', # kernel_initializer='he_normal', # input_shape=input_shape)) model.add( Conv2D(32, kernel_size=(3, 3), kernel_initializer='he_normal', input_shape=input_shape)) model.add(BatchNormalization(**bnArgs)) model.add(Activation('relu')) model.add(MaxPooling2D((2, 2))) model.add(Dropout(0.25)) if COMPLEX_MODE: # model.add(ComplexConv2D(32, (3, 3))) model.add(ComplexConv2D(32, (3, 3), **convArgs)) model.add(ComplexBN(**bnArgs)) else: # model.add(Conv2D(64, (3, 3), activation='relu')) # model.add(Conv2D(32*2, (3, 3))) model.add(Conv2D(32 * 2, (3, 3), **convArgs)) model.add(BatchNormalization(**bnArgs)) model.add(Activation('relu')) # O = ComplexConv2D(sf, filsize, name='conv1', **convArgs)(O) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) if COMPLEX_MODE: # model.add(ComplexConv2D(64, (3, 3))) model.add(ComplexConv2D(64, (3, 3), **convArgs)) model.add(ComplexBN(**bnArgs)) else:
def _split_channel_halves(T):
    """Split ``T`` into two halves along axis 1 using two Lambda layers.

    Axis 1 is the channel axis only for channels_first tensors; the grouped
    model variants that use this helper assume that layout.
    """
    first = Lambda(lambda t: t[:, :(t.shape[1] // 2), :, :])(T)
    second = Lambda(lambda t: t[:, (t.shape[1] // 2):, :, :])(T)
    return first, second


def _grouped_pointwise_conv(T, name, convArgs_real):
    """Apply one real 1x1 conv per channel half of ``T`` and re-join the halves."""
    half = int(T.shape[1] // 2)
    T_g0, T_g1 = _split_channel_halves(T)
    T_g0 = Conv2D(half, (1, 1), name=name + '_pwc_g0', **convArgs_real)(T_g0)
    T_g1 = Conv2D(half, (1, 1), name=name + '_pwc_g1', **convArgs_real)(T_g1)
    return Concatenate(axis=1)([T_g0, T_g1])


def _residual_branch_conv(T, model, filters, kernel_size, name, convArgs, convArgs_real, strides=None):
    """Apply the model-specific convolution named ``name`` to ``T``.

    ``strides`` is forwarded to the main convolution(s) only when given; the
    trailing pointwise convolutions of the ``*_pwc_*`` variants are never
    strided.

    Raises:
        ValueError: if ``model`` is not a supported model type.
    """
    conv_kwargs = dict(convArgs)
    if strides is not None:
        conv_kwargs['strides'] = strides

    if model == "real":
        return Conv2D(filters, kernel_size, name=name, **conv_kwargs)(T)
    if model == "real_dws":
        return SeparableConv2D(filters, kernel_size, name=name, **conv_kwargs)(T)
    if model in ("real_group", "real_group_pwc_full", "real_group_pwc_group"):
        # Two-group convolution: each channel half gets its own conv with half
        # the filters.
        T_g0, T_g1 = _split_channel_halves(T)
        T_g0 = Conv2D(filters // 2, kernel_size, name=name + '_g0', **conv_kwargs)(T_g0)
        T_g1 = Conv2D(filters // 2, kernel_size, name=name + '_g1', **conv_kwargs)(T_g1)
        T_g00, T_g01 = _split_channel_halves(T_g0)
        T_g10, T_g11 = _split_channel_halves(T_g1)
        # Interleave as [g0a, g1b, g0b, g1a] so each half of the result (and
        # hence each group of the next grouped layer) mixes both groups.
        out = Concatenate(axis=1)([T_g00, T_g11, T_g01, T_g10])
        if model == "real_group_pwc_full":
            out = Conv2D(filters, (1, 1), name=name + '_pwc', **convArgs)(out)
        elif model == "real_group_pwc_group":
            out = _grouped_pointwise_conv(out, name, convArgs_real)
        return out
    if model == "complex":
        return ComplexConv2D(filters, kernel_size, name=name, **conv_kwargs)(T)
    if model in ("complex_concat", "complex_concat_pwc_group"):
        out = ComplexConvConcat2D(filters // 2, kernel_size, name=name, **conv_kwargs)(T)
        if model == "complex_concat_pwc_group":
            out = _grouped_pointwise_conv(out, name, convArgs_real)
        return out
    raise ValueError("Unknown model type: %r" % (model,))


def getResidualBlock(I, filter_size, featmaps, stage, block, shortcut, convArgs, convArgs_real, bnArgs, d):
    """Get residual block.

    Builds one pre-activation residual block (BN -> activation -> conv,
    twice) for any supported ``d.model``. Supported models are "real",
    "real_dws", "real_group", "real_group_pwc_full", "real_group_pwc_group",
    "complex", "complex_concat" and "complex_concat_pwc_group".

    Args:
        I: Input tensor of the block.
        filter_size: Kernel size of the '2a' / '2b' convolutions.
        featmaps: ``[nb_fmaps1, nb_fmaps2]`` filters of the '2a' and '2b' convs.
        stage, block: Used to build unique layer names.
        shortcut: ``'regular'`` adds the input back. ``'projection'``
            concatenates a 1x1-conv projection of the input with the branch
            output (strided unless ``spectral_pool_scheme == "nodownsample"``).
        convArgs: Keyword arguments for the main convolutions.
        convArgs_real: Keyword arguments for the real 1x1 grouped pointwise
            convolutions of the ``*_pwc_group`` variants.
        bnArgs: Keyword arguments for the BN layers.
        d: Configuration object (``act``, ``aact``, ``model``,
            ``spectral_pool_scheme``).

    Returns:
        The output tensor of the block.

    Raises:
        ValueError: if ``d.model`` is not supported (previously the process
            printed an error and called ``exit(-1)``).
    """
    activation = d.act
    model = d.model
    nb_fmaps1, nb_fmaps2 = featmaps
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    if K.image_data_format() == 'channels_first' and K.ndim(I) != 3:
        channel_axis = 1
    else:
        channel_axis = -1
    is_complex = model in ("complex", "complex_concat", "complex_concat_pwc_group")
    is_real = model in ("real", "real_dws", "real_group", "real_group_pwc_full", "real_group_pwc_group")

    def activate(T):
        # d.aact selects the joint complex ReLU; otherwise use d.act.
        if d.aact == "complex_joint_relu":
            return Lambda(ComplexJointReLU)(T)
        return Activation(activation)(T)

    # First BN + activation.
    if "real" in model:
        O = BatchNormalization(name=bn_name_base + '_2a', **bnArgs)(I)
    elif "complex" in model:
        O = ComplexBN(name=bn_name_base + '_2a', **bnArgs)(I)
    else:
        raise ValueError("Unknown model type: %r" % (model,))
    O = activate(O)

    # First convolution ('2a'); strided only for downsampling projection blocks.
    if shortcut == 'regular' or d.spectral_pool_scheme == "nodownsample":
        O = _residual_branch_conv(O, model, nb_fmaps1, filter_size, conv_name_base + '2a',
                                  convArgs, convArgs_real)
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            O = applySpectralPooling(O, d)
        O = _residual_branch_conv(O, model, nb_fmaps1, filter_size, conv_name_base + '2a',
                                  convArgs, convArgs_real, strides=(2, 2))

    # Second BN + activation + convolution ('2b').
    if is_complex:
        O = ComplexBN(name=bn_name_base + '_2b', **bnArgs)(O)
        O = activate(O)
    elif is_real:
        O = BatchNormalization(name=bn_name_base + '_2b', **bnArgs)(O)
        # Real variants always use d.act here (d.aact is not consulted).
        O = Activation(activation)(O)
    else:
        raise ValueError("Unknown model type: %r" % (model,))
    O = _residual_branch_conv(O, model, nb_fmaps2, filter_size, conv_name_base + '2b',
                              convArgs, convArgs_real)

    # Shortcut merge.
    if shortcut == 'regular':
        O = Add()([O, I])
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            I = applySpectralPooling(I, d)
        strides = (2, 2) if d.spectral_pool_scheme != "nodownsample" else (1, 1)
        X = _residual_branch_conv(I, model, nb_fmaps2, (1, 1), conv_name_base + '1',
                                  convArgs, convArgs_real, strides=strides)
        if is_complex:
            # Keep the [real parts | imaginary parts] channel layout.
            O_real = Concatenate(channel_axis)([GetReal()(X), GetReal()(O)])
            O_imag = Concatenate(channel_axis)([GetImag()(X), GetImag()(O)])
            # Channel axis, not a hard-coded 1 (which is the height axis for
            # channels_last data); identical for channels_first.
            O = Concatenate(channel_axis)([O_real, O_imag])
        else:
            O = Concatenate(channel_axis)([X, O])
    return O
def getResnetModel(d):
    """Build the ResNet classifier for any supported model variant and dataset.

    Layout: output of ``learnConcatRealImagBlock`` concatenated with the
    input, a model-specific stem conv + BN + activation, residual stages 2-4
    (stages 3 and 4 open with a 'projection' block), average pooling, and a
    softmax ``Dense`` classifier.

    Args:
        d: Configuration object. Read here: ``num_blocks``, ``start_filter``,
            ``dataset``, ``act``, ``aact``, ``model``,
            ``spectral_pool_scheme`` and, for complex models,
            ``spectral_param`` / ``comp_init``.

    Returns:
        An uncompiled ``Model`` from the image input to class probabilities.

    Raises:
        ValueError: if ``d.dataset`` or ``d.model`` is not supported
            (an unknown model previously printed an error and called
            ``exit(-1)``).
    """
    import copy

    # Output classes per supported dataset.
    num_classes = {
        'cifar10': 10,
        'cifar100': 100,
        'svhn': 10,
        'mnist': 10,
        'fashion_mnist': 10,
    }
    dataset = d.dataset
    if dataset not in num_classes:
        # Fail before building any layers.
        raise ValueError("Unknown dataset " + d.dataset)

    n = d.num_blocks
    sf = d.start_filter
    activation = d.act
    model = d.model
    is_mnist = "mnist" in dataset
    # image_data_format() for everything; the legacy image_dim_ordering()
    # getter is absent in newer Keras.
    channels_first = K.image_data_format() == 'channels_first'
    if is_mnist:
        inputShape = (1, 28, 28) if channels_first else (28, 28, 1)
    else:
        inputShape = (3, 32, 32) if channels_first else (32, 32, 3)
    channelAxis = 1 if channels_first else -1
    filsize = (3, 3)
    convArgs = {
        "padding": "same",
        "use_bias": False,
        "kernel_regularizer": l2(0.0001),
    }
    bnArgs = {"axis": channelAxis, "momentum": 0.9, "epsilon": 1e-04}
    # Snapshot taken before the model-specific keys are added, so for complex
    # models the real 1x1 grouped convs do not get complex-only arguments.
    convArgs_real = copy.deepcopy(convArgs)
    if "real" in model:
        sf *= 2
        convArgs.update({"kernel_initializer": Orthogonal(float(np.sqrt(2)))})
        convArgs_real = convArgs
    elif "complex" in model:
        convArgs.update({
            "spectral_parametrization": d.spectral_param,
            "kernel_initializer": d.comp_init,
        })

    def split_halves(T):
        # Split along axis 1 (the channel axis for channels_first data).
        first = Lambda(lambda t: t[:, :(t.shape[1] // 2), :, :])(T)
        second = Lambda(lambda t: t[:, (t.shape[1] // 2):, :, :])(T)
        return first, second

    def grouped_pointwise(T, name):
        # One real 1x1 conv per channel half, then re-join.
        half = int(T.shape[1] // 2)
        T_g0, T_g1 = split_halves(T)
        T_g0 = Conv2D(half, (1, 1), name=name + '_pwc_g0', **convArgs_real)(T_g0)
        T_g1 = Conv2D(half, (1, 1), name=name + '_pwc_g1', **convArgs_real)(T_g1)
        return Concatenate(axis=1)([T_g0, T_g1])

    # Input layer
    I = Input(shape=inputShape)

    # Stage 1: stem
    O = learnConcatRealImagBlock(I, (1, 1), (3, 3), 0, '0', convArgs, bnArgs, d)
    O = Concatenate(channelAxis)([I, O])
    if model == "real":
        O = Conv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    elif model == "real_dws":
        O = SeparableConv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    elif model in ("real_group", "real_group_pwc_full", "real_group_pwc_group"):
        O_g0, O_g1 = split_halves(O)
        O_g0 = Conv2D(sf // 2, filsize, name='conv1_g0', **convArgs)(O_g0)
        O_g1 = Conv2D(sf // 2, filsize, name='conv1_g1', **convArgs)(O_g1)
        O_g00, O_g01 = split_halves(O_g0)
        O_g10, O_g11 = split_halves(O_g1)
        # Interleave so each half of the result mixes both groups.
        O = Concatenate(axis=1)([O_g00, O_g11, O_g01, O_g10])
        if model == "real_group_pwc_full":
            O = Conv2D(sf, (1, 1), name='conv1_pwc', **convArgs)(O)
        elif model == "real_group_pwc_group":
            O = grouped_pointwise(O, 'conv1')
        O = BatchNormalization(name="bn_conv1_2a", **bnArgs)(O)
    elif model == "complex":
        O = ComplexConv2D(sf, filsize, name='conv1', **convArgs)(O)
        O = ComplexBN(name="bn_conv1_2a", **bnArgs)(O)
    elif model in ("complex_concat", "complex_concat_pwc_group"):
        O = ComplexConvConcat2D(sf // 2, filsize, name='conv1', **convArgs)(O)
        O = ComplexBN(name="bn_conv1_2a", **bnArgs)(O)
        if model == "complex_concat_pwc_group":
            O = grouped_pointwise(O, 'conv1')
    else:
        raise ValueError("Unknown model type: %r" % (model,))
    if d.aact == "complex_joint_relu":
        O = Lambda(ComplexJointReLU)(O)
    else:
        O = Activation(activation)(O)

    # Stage 2: regular blocks at the stem width (range works on Python 2 and 3;
    # xrange does not exist on Python 3).
    for i in range(n):
        O = getResidualBlock(O, filsize, [sf, sf], 2, str(i), 'regular',
                             convArgs, convArgs_real, bnArgs, d)
        if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
            O = applySpectralPooling(O, d)

    # Stages 3 and 4: a projection block followed by n - 1 regular blocks at
    # twice the projection width.
    for stage, width in ((3, sf), (4, sf * 2)):
        O = getResidualBlock(O, filsize, [width, width], stage, '0', 'projection',
                             convArgs, convArgs_real, bnArgs, d)
        if d.spectral_pool_scheme == "nodownsample":
            O = applySpectralPooling(O, d)
        for i in range(n - 1):
            O = getResidualBlock(O, filsize, [width * 2, width * 2], stage, str(i + 1), 'regular',
                                 convArgs, convArgs_real, bnArgs, d)
            if i == n // 2 and d.spectral_pool_scheme == "stagemiddle":
                O = applySpectralPooling(O, d)

    # Pooling over the remaining spatial extent: full input size without
    # strided projections, otherwise input size / 4.
    if d.spectral_pool_scheme == "nodownsample":
        O = applySpectralPooling(O, d)
        pool_size = (28, 28) if is_mnist else (32, 32)
    else:
        pool_size = (7, 7) if is_mnist else (8, 8)
    O = AveragePooling2D(pool_size=pool_size)(O)

    O = Flatten()(O)
    O = Dense(num_classes[dataset], activation='softmax', kernel_regularizer=l2(0.0001))(O)
    return Model(I, O)