def get_shallow_convnet(window_size=4096, channels=2, output_size=84):
    """Build and compile a shallow complex-valued 1D convolutional model.

    Layer sequence: ComplexConv1D -> AveragePooling1D -> Permute -> Flatten
    -> ComplexDense(2048) -> ComplexDense(output_size, sigmoid) -> GetReal.

    Args:
        window_size: Size of the first (non-batch) input dimension.
        channels: Size of the second (non-batch) input dimension.
        output_size: Number of units in the final ComplexDense layer.

    Returns:
        A Keras ``Model`` compiled with Adam (lr=1e-4), binary
        cross-entropy loss and the ``accuracy`` metric.
    """
    inputs = Input(shape=(window_size, channels))

    conv = ComplexConv1D(32, 512, strides=16, activation='relu')(inputs)
    pool = AveragePooling1D(pool_size=4, strides=2)(conv)

    # Swap the two non-batch axes before flattening.
    pool = Permute([2, 1])(pool)
    flattened = Flatten()(pool)

    dense = ComplexDense(2048, activation='relu')(flattened)
    predictions = ComplexDense(
        output_size,
        activation='sigmoid',
        bias_initializer=Constant(value=-5))(dense)

    # GetReal is used as a layer class elsewhere in this source
    # (``GetReal()(X)``), so it must be instantiated and then applied to the
    # tensor; the previous ``GetReal(predictions)`` passed the tensor to the
    # constructor instead.
    # NOTE(review): assumes this is the same GetReal layer class -- confirm
    # against the import at the top of the file.
    predictions = GetReal()(predictions)

    model = Model(inputs=inputs, outputs=predictions)
    model.compile(optimizer=Adam(lr=1e-4),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
def getResidualBlock(I, filter_size, featmaps, stage, block, shortcut,
                     convArgs, bnArgs, d):
    """Get residual block.

    Applies batch-norm -> activation -> convolution twice to ``I`` and then
    merges the result with the block input according to ``shortcut``.

    NOTE(review): another ``getResidualBlock`` taking an additional
    ``convArgs_real`` parameter appears later in this source; if both live in
    the same module, the later definition shadows this one -- confirm.

    Args:
        I: Input tensor of the block.
        filter_size: Kernel size of the two main convolutions.
        featmaps: Pair ``(nb_fmaps1, nb_fmaps2)`` giving the filter counts of
            the first and second convolution.
        stage: Used (via ``str``) to build the layer names.
        block: String appended to ``stage`` when building the layer names.
        shortcut: ``'regular'`` adds the block input to the output;
            ``'projection'`` concatenates a 1x1 convolution of the block
            input with the output. Any other value skips the merge.
        convArgs: Keyword arguments forwarded to every convolution layer.
        bnArgs: Keyword arguments forwarded to every batch-norm layer.
        d: Configuration object; ``d.act``, ``d.model`` and
            ``d.spectral_pool_scheme`` are read here, and ``d`` is passed to
            ``applySpectralPooling``.

    Returns:
        The output tensor of the block.

    Raises:
        ValueError: If ``d.model`` is neither ``"real"`` nor ``"complex"``.
    """
    # Select the layer classes once. Previously an unsupported d.model left
    # ``O`` unassigned and surfaced later as an UnboundLocalError.
    if d.model == "real":
        conv_layer, bn_layer = Conv2D, BatchNormalization
    elif d.model == "complex":
        conv_layer, bn_layer = ComplexConv2D, ComplexBN
    else:
        raise ValueError(
            "Unsupported d.model %r; expected 'real' or 'complex'"
            % (d.model,))

    activation = d.act
    nb_fmaps1, nb_fmaps2 = featmaps
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    if K.image_data_format() == 'channels_first' and K.ndim(I) != 3:
        channel_axis = 1
    else:
        channel_axis = -1
    no_downsample = d.spectral_pool_scheme == "nodownsample"

    # First BN -> activation -> convolution.
    O = bn_layer(name=bn_name_base + '_2a', **bnArgs)(I)
    O = Activation(activation)(O)

    if shortcut == 'regular' or no_downsample:
        O = conv_layer(nb_fmaps1, filter_size,
                       name=conv_name_base + '2a', **convArgs)(O)
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            O = applySpectralPooling(O, d)
        O = conv_layer(nb_fmaps1, filter_size,
                       name=conv_name_base + '2a',
                       strides=(2, 2), **convArgs)(O)

    # Second BN -> activation -> convolution.
    O = bn_layer(name=bn_name_base + '_2b', **bnArgs)(O)
    O = Activation(activation)(O)
    O = conv_layer(nb_fmaps2, filter_size,
                   name=conv_name_base + '2b', **convArgs)(O)

    # Merge with the block input.
    if shortcut == 'regular':
        O = Add()([O, I])
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            I = applySpectralPooling(I, d)
        X = conv_layer(nb_fmaps2, (1, 1),
                       name=conv_name_base + '1',
                       strides=(1, 1) if no_downsample else (2, 2),
                       **convArgs)(I)
        if d.model == "real":
            O = Concatenate(channel_axis)([X, O])
        else:
            # Real and imaginary parts are concatenated separately, then the
            # two results are joined.
            O_real = Concatenate(channel_axis)([GetReal()(X), GetReal()(O)])
            O_imag = Concatenate(channel_axis)([GetImag()(X), GetImag()(O)])
            # NOTE(review): the final join uses a hard-coded axis 1 rather
            # than channel_axis -- confirm this is intended for
            # channels_last inputs.
            O = Concatenate(1)([O_real, O_imag])

    return O
# d.model values that use two parallel real convolutions over channel halves.
_RES_REAL_GROUP_MODELS = (
    "real_group", "real_group_pwc_full", "real_group_pwc_group")
# d.model values built on ComplexConvConcat2D.
_RES_COMPLEX_CONCAT_MODELS = ("complex_concat", "complex_concat_pwc_group")
_RES_REAL_MODELS = ("real", "real_dws") + _RES_REAL_GROUP_MODELS
_RES_COMPLEX_MODELS = ("complex",) + _RES_COMPLEX_CONCAT_MODELS


def _res_exit_unknown_model():
    """Print the unknown-model error and terminate with exit status -1."""
    print("Error: unknown model type")
    # ``exit`` is an interactive helper injected by the ``site`` module and
    # is not guaranteed to exist; raising SystemExit directly yields the same
    # exit status without that dependency.
    raise SystemExit(-1)


def _res_first_half(t):
    """Return the first half of ``t`` along axis 1."""
    return t[:, :(t.shape[1] // 2), :, :]


def _res_second_half(t):
    """Return the second half of ``t`` along axis 1."""
    return t[:, (t.shape[1] // 2):, :, :]


def _res_split_halves(tensor):
    """Split ``tensor`` along axis 1 into two halves via Lambda layers."""
    # NOTE(review): slicing on axis 1 presumably assumes a 4-D
    # channels_first tensor -- confirm against the callers.
    first = Lambda(_res_first_half)(tensor)
    second = Lambda(_res_second_half)(tensor)
    return first, second


def _res_conv_kwargs(convArgs, strides):
    """Return a copy of ``convArgs`` with ``strides`` added when given."""
    kwargs = dict(convArgs)
    if strides is not None:
        kwargs['strides'] = strides
    return kwargs


def _res_grouped_conv(tensor, nb_filters, kernel_size, name, convArgs,
                      strides=None):
    """Convolve each axis-1 half separately, then interleave the outputs."""
    kwargs = _res_conv_kwargs(convArgs, strides)
    g0, g1 = _res_split_halves(tensor)
    g0 = Conv2D(nb_filters // 2, kernel_size, name=name + '_g0', **kwargs)(g0)
    g1 = Conv2D(nb_filters // 2, kernel_size, name=name + '_g1', **kwargs)(g1)
    g00, g01 = _res_split_halves(g0)
    g10, g11 = _res_split_halves(g1)
    # This ordering allows permutation of odd-numbered outputs (g0, g1).
    return Concatenate(axis=1)([g00, g11, g01, g10])


def _res_grouped_pointwise(tensor, name, convArgs_real):
    """Apply a separate 1x1 Conv2D to each axis-1 half and re-join them."""
    half = int(tensor.shape[1] // 2)
    g0, g1 = _res_split_halves(tensor)
    g0 = Conv2D(half, (1, 1), name=name + '_g0', **convArgs_real)(g0)
    g1 = Conv2D(half, (1, 1), name=name + '_g1', **convArgs_real)(g1)
    return Concatenate(axis=1)([g0, g1])


def _res_model_conv(tensor, nb_filters, kernel_size, name, d, convArgs,
                    convArgs_real, strides=None):
    """Apply the convolution variant selected by ``d.model`` to ``tensor``."""
    model = d.model
    kwargs = _res_conv_kwargs(convArgs, strides)

    if model == "real":
        return Conv2D(nb_filters, kernel_size, name=name, **kwargs)(tensor)
    if model == "real_dws":
        return SeparableConv2D(nb_filters, kernel_size, name=name,
                               **kwargs)(tensor)
    if model in _RES_REAL_GROUP_MODELS:
        out = _res_grouped_conv(tensor, nb_filters, kernel_size, name,
                                convArgs, strides)
        # Optional pointwise (1x1) stage; it never takes the stride.
        if model == "real_group_pwc_full":
            out = Conv2D(nb_filters, (1, 1), name=name + '_pwc',
                         **convArgs)(out)
        elif model == "real_group_pwc_group":
            out = _res_grouped_pointwise(out, name + '_pwc', convArgs_real)
        return out
    if model == "complex":
        return ComplexConv2D(nb_filters, kernel_size, name=name,
                             **kwargs)(tensor)
    if model in _RES_COMPLEX_CONCAT_MODELS:
        out = ComplexConvConcat2D(nb_filters // 2, kernel_size, name=name,
                                  **kwargs)(tensor)
        if model == "complex_concat_pwc_group":
            out = _res_grouped_pointwise(out, name + '_pwc', convArgs_real)
        return out
    _res_exit_unknown_model()


def _res_aact_activation(tensor, d):
    """Apply ComplexJointReLU if ``d.aact`` selects it, else ``d.act``."""
    if d.aact == "complex_joint_relu":
        return Lambda(ComplexJointReLU)(tensor)
    return Activation(d.act)(tensor)


def getResidualBlock(I, filter_size, featmaps, stage, block, shortcut,
                     convArgs, convArgs_real, bnArgs, d):
    """Get residual block.

    Applies batch-norm -> activation -> convolution twice to ``I`` and then
    merges the result with the block input according to ``shortcut``. The
    concrete convolution is chosen by ``d.model``, one of: ``real``,
    ``real_dws``, ``real_group``, ``real_group_pwc_full``,
    ``real_group_pwc_group``, ``complex``, ``complex_concat``,
    ``complex_concat_pwc_group``.

    Args:
        I: Input tensor of the block.
        filter_size: Kernel size of the two main convolutions.
        featmaps: Pair ``(nb_fmaps1, nb_fmaps2)`` giving the filter counts of
            the first and second convolution stage.
        stage: Used (via ``str``) to build the layer names.
        block: String appended to ``stage`` when building the layer names.
        shortcut: ``'regular'`` adds the block input to the output;
            ``'projection'`` concatenates a 1x1 convolution of the block
            input with the output. Any other value skips the merge.
        convArgs: Keyword arguments forwarded to the main convolutions.
        convArgs_real: Keyword arguments forwarded to the grouped pointwise
            (``*_pwc_group``) Conv2D layers.
        bnArgs: Keyword arguments forwarded to every batch-norm layer.
        d: Configuration object; ``d.act``, ``d.aact``, ``d.model`` and
            ``d.spectral_pool_scheme`` are read here, and ``d`` is passed to
            ``applySpectralPooling``.

    Returns:
        The output tensor of the block.

    Raises:
        SystemExit: With status -1, after printing an error, when
            ``d.model`` is not one of the supported values.
    """
    # Validate before building any layer. Previously a d.model containing
    # neither "real" nor "complex" failed with an UnboundLocalError before
    # the intended error message could be printed.
    if d.model not in _RES_REAL_MODELS + _RES_COMPLEX_MODELS:
        _res_exit_unknown_model()

    activation = d.act
    nb_fmaps1, nb_fmaps2 = featmaps
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    if K.image_data_format() == 'channels_first' and K.ndim(I) != 3:
        channel_axis = 1
    else:
        channel_axis = -1
    is_complex = d.model in _RES_COMPLEX_MODELS
    no_downsample = d.spectral_pool_scheme == "nodownsample"

    # ---- Stage 2a: BN -> activation -> convolution -------------------------
    if is_complex:
        O = ComplexBN(name=bn_name_base + '_2a', **bnArgs)(I)
    else:
        O = BatchNormalization(name=bn_name_base + '_2a', **bnArgs)(I)
    # NOTE(review): the d.aact check is applied here for real and complex
    # models alike (stage 2b only checks it for complex models); this
    # follows the reading of the original code -- confirm.
    O = _res_aact_activation(O, d)

    if shortcut == 'regular' or no_downsample:
        O = _res_model_conv(O, nb_fmaps1, filter_size, conv_name_base + '2a',
                            d, convArgs, convArgs_real)
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            O = applySpectralPooling(O, d)
        O = _res_model_conv(O, nb_fmaps1, filter_size, conv_name_base + '2a',
                            d, convArgs, convArgs_real, strides=(2, 2))

    # ---- Stage 2b: BN -> activation -> convolution -------------------------
    if is_complex:
        O = ComplexBN(name=bn_name_base + '_2b', **bnArgs)(O)
        O = _res_aact_activation(O, d)
    else:
        O = BatchNormalization(name=bn_name_base + '_2b', **bnArgs)(O)
        O = Activation(activation)(O)
    O = _res_model_conv(O, nb_fmaps2, filter_size, conv_name_base + '2b',
                        d, convArgs, convArgs_real)

    # ---- Merge with the block input ----------------------------------------
    if shortcut == 'regular':
        O = Add()([O, I])
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            I = applySpectralPooling(I, d)
        X = _res_model_conv(I, nb_fmaps2, (1, 1), conv_name_base + '1',
                            d, convArgs, convArgs_real,
                            strides=(1, 1) if no_downsample else (2, 2))
        if is_complex:
            # Real and imaginary parts are concatenated separately, then the
            # two results are joined.
            O_real = Concatenate(channel_axis)([GetReal()(X), GetReal()(O)])
            O_imag = Concatenate(channel_axis)([GetImag()(X), GetImag()(O)])
            # NOTE(review): the final join uses a hard-coded axis 1 rather
            # than channel_axis -- confirm this is intended.
            O = Concatenate(1)([O_real, O_imag])
        else:
            O = Concatenate(channel_axis)([X, O])

    return O