コード例 #1
0
ファイル: __init__.py プロジェクト: Trixter9994/lazero
def get_shallow_convnet(window_size=4096, channels=2, output_size=84):
    """Build and compile a shallow complex-valued 1-D convnet.

    Architecture:
      * ComplexConv1D (32 filters, kernel 512, stride 16, ReLU);
      * AveragePooling1D (pool 4, stride 2);
      * Permute + Flatten;
      * ComplexDense with 2048 units (ReLU);
      * ComplexDense output layer (sigmoid), with biases initialised to -5;
      * GetReal applied to that output, producing the prediction.

    Args:
        window_size: Length of the input window (first non-batch input dim).
        channels: Number of input channels (second non-batch input dim).
        output_size: Number of units in the output ComplexDense layer.

    Returns:
        A Keras ``Model``, compiled with ``Adam(lr=1e-4)``, binary
        cross-entropy loss and an accuracy metric.
    """
    inputs = Input(shape=(window_size, channels))

    conv = ComplexConv1D(32, 512, strides=16, activation='relu')(inputs)
    pool = AveragePooling1D(pool_size=4, strides=2)(conv)

    # Swap the two non-batch axes before flattening.
    pool = Permute([2, 1])(pool)
    flattened = Flatten()(pool)

    dense = ComplexDense(2048, activation='relu')(flattened)
    predictions = ComplexDense(output_size,
                               activation='sigmoid',
                               bias_initializer=Constant(value=-5))(dense)
    # GetReal is a layer: construct it, then call it on the tensor.
    # Passing the tensor to the constructor (``GetReal(predictions)``)
    # does not apply the layer at all.
    predictions = GetReal()(predictions)
    model = Model(inputs=inputs, outputs=predictions)

    model.compile(optimizer=Adam(lr=1e-4),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
コード例 #2
0
def getResidualBlock(I, filter_size, featmaps, stage, block, shortcut,
                     convArgs, bnArgs, d):
    """Build one real- or complex-valued residual block.

    Layer order: batch-norm "_2a" -> activation -> conv "2a" ->
    batch-norm "_2b" -> activation -> conv "2b", followed by the shortcut
    merge:

    * ``shortcut == 'regular'``: the branch output is added to ``I``.
    * ``shortcut == 'projection'``: conv "2a" uses stride 2 unless
      ``d.spectral_pool_scheme == "nodownsample"``. ``I`` is projected by a
      1x1 conv "1" (stride 2 unless "nodownsample"), and the projection is
      concatenated with the branch output along the channel axis. With
      ``d.spectral_pool_scheme == "proj"``, ``applySpectralPooling`` is
      applied before conv "2a" and before the projection.

    Args:
        I: Input tensor.
        filter_size: Kernel size of convs "2a" and "2b".
        featmaps: Pair ``(nb_fmaps1, nb_fmaps2)``. ``nb_fmaps1`` is the
            filter count of conv "2a"; ``nb_fmaps2`` is the filter count of
            conv "2b" and of the projection conv.
        stage: Stage index, used in layer names.
        block: Block identifier string, used in layer names.
        shortcut: ``'regular'`` or ``'projection'``.
        convArgs: Extra keyword arguments for every convolution.
        bnArgs: Extra keyword arguments for every batch-norm layer.
        d: Configuration object. Reads ``d.model`` (``"real"`` or
            ``"complex"``), ``d.act`` and ``d.spectral_pool_scheme``.

    Returns:
        The block's output tensor.

    Raises:
        ValueError: If ``d.model`` is neither ``"real"`` nor ``"complex"``.
    """
    if d.model not in ("real", "complex"):
        raise ValueError("unknown model type: %r" % (d.model,))

    activation = d.act
    nb_fmaps1, nb_fmaps2 = featmaps
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    # 4-D channels_first tensors keep channels on axis 1; every other case
    # uses the last axis.
    if K.image_data_format() == 'channels_first' and K.ndim(I) != 3:
        channel_axis = 1
    else:
        channel_axis = -1

    is_real = d.model == "real"
    bn_layer = BatchNormalization if is_real else ComplexBN
    conv_layer = Conv2D if is_real else ComplexConv2D

    O = bn_layer(name=bn_name_base + '_2a', **bnArgs)(I)
    O = Activation(activation)(O)

    if shortcut == 'regular' or d.spectral_pool_scheme == "nodownsample":
        O = conv_layer(nb_fmaps1,
                       filter_size,
                       name=conv_name_base + '2a',
                       **convArgs)(O)
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            O = applySpectralPooling(O, d)
        O = conv_layer(nb_fmaps1,
                       filter_size,
                       name=conv_name_base + '2a',
                       strides=(2, 2),
                       **convArgs)(O)

    O = bn_layer(name=bn_name_base + '_2b', **bnArgs)(O)
    O = Activation(activation)(O)
    O = conv_layer(nb_fmaps2,
                   filter_size,
                   name=conv_name_base + '2b',
                   **convArgs)(O)

    if shortcut == 'regular':
        O = Add()([O, I])
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            I = applySpectralPooling(I, d)
        proj_strides = ((1, 1) if d.spectral_pool_scheme == "nodownsample"
                        else (2, 2))
        X = conv_layer(nb_fmaps2, (1, 1),
                       name=conv_name_base + '1',
                       strides=proj_strides,
                       **convArgs)(I)
        if is_real:
            O = Concatenate(channel_axis)([X, O])
        else:
            O_real = Concatenate(channel_axis)([GetReal()(X), GetReal()(O)])
            O_imag = Concatenate(channel_axis)([GetImag()(X), GetImag()(O)])
            # Stack real and imaginary parts on the same axis their pieces
            # were joined on. This is axis 1 for 4-D channels_first input,
            # which was the previously hard-coded value. It differed under
            # channels_last. Assumes GetReal/GetImag split along that same
            # channel axis — TODO confirm.
            O = Concatenate(channel_axis)([O_real, O_imag])

    return O
コード例 #3
0
# Model variants accepted by getResidualBlock.
_REAL_GROUP_MODELS = ("real_group", "real_group_pwc_full",
                      "real_group_pwc_group")
_COMPLEX_CONCAT_MODELS = ("complex_concat", "complex_concat_pwc_group")
_COMPLEX_MODELS = ("complex",) + _COMPLEX_CONCAT_MODELS
_RESIDUAL_MODELS = ("real", "real_dws") + _REAL_GROUP_MODELS + _COMPLEX_MODELS


def _split_channels(x):
    """Split ``x`` into two halves along axis 1 with two Lambda layers.

    NOTE(review): axis 1 is treated as the channel axis. This assumes a
    4-D channels_first tensor — TODO confirm for channels_last configs.
    """
    first = Lambda(lambda t: t[:, :(t.shape[1] // 2), :, :])(x)
    second = Lambda(lambda t: t[:, (t.shape[1] // 2):, :, :])(x)
    return first, second


def _grouped_conv_shuffle(x, nb_filters, kernel_size, name, conv_kwargs):
    """Apply a two-group Conv2D, then reorder the output quarters.

    Each channel half of ``x`` is convolved separately into
    ``nb_filters // 2`` maps, by layers ``name + '_g0'`` and
    ``name + '_g1'``. Each result is halved again. The quarters are then
    concatenated on axis 1 as
    [g0 first half, g1 second half, g0 second half, g1 first half], so each
    half of the output holds maps from both groups.
    """
    x_g0, x_g1 = _split_channels(x)
    x_g0 = Conv2D(nb_filters // 2, kernel_size, name=name + '_g0',
                  **conv_kwargs)(x_g0)
    x_g1 = Conv2D(nb_filters // 2, kernel_size, name=name + '_g1',
                  **conv_kwargs)(x_g1)
    x_g00, x_g01 = _split_channels(x_g0)
    x_g10, x_g11 = _split_channels(x_g1)
    return Concatenate(axis=1)([x_g00, x_g11, x_g01, x_g10])


def _grouped_pointwise_conv(x, name, convArgs_real):
    """Apply a two-group 1x1 Conv2D with ``x.shape[1] // 2`` filters per group."""
    x_g0, x_g1 = _split_channels(x)
    half = int(x.shape[1] // 2)
    x_g0 = Conv2D(half, (1, 1), name=name + '_g0', **convArgs_real)(x_g0)
    x_g1 = Conv2D(half, (1, 1), name=name + '_g1', **convArgs_real)(x_g1)
    return Concatenate(axis=1)([x_g0, x_g1])


def _pointwise_stage(x, model, nb_filters, name, convArgs, convArgs_real):
    """Apply the optional 1x1 convolution that the *_pwc_* variants add."""
    if model == "real_group_pwc_full":
        return Conv2D(nb_filters, (1, 1), name=name, **convArgs)(x)
    if model in ("real_group_pwc_group", "complex_concat_pwc_group"):
        return _grouped_pointwise_conv(x, name, convArgs_real)
    return x


def _residual_conv(x, model, nb_filters, kernel_size, name, conv_kwargs,
                   convArgs, convArgs_real):
    """Apply the convolution (stack) that ``model`` uses at one block site.

    Args:
        x: Input tensor.
        model: One of ``_RESIDUAL_MODELS``.
        nb_filters: Nominal number of output maps. Grouped and
            ComplexConvConcat2D convolutions use ``nb_filters // 2`` per call.
        kernel_size: Kernel size of the main convolution.
        name: Base layer name. Group and pointwise layers add suffixes.
        conv_kwargs: Keyword arguments for the main convolution(s), i.e.
            ``convArgs`` plus ``strides`` when the site downsamples.
        convArgs: Keyword arguments for the full 1x1 convolution of
            ``"real_group_pwc_full"``. No strides are applied there.
        convArgs_real: Keyword arguments for grouped 1x1 convolutions.

    Raises:
        ValueError: If ``model`` is not a supported variant.
    """
    if model == "real":
        return Conv2D(nb_filters, kernel_size, name=name, **conv_kwargs)(x)
    if model == "real_dws":
        return SeparableConv2D(nb_filters, kernel_size, name=name,
                               **conv_kwargs)(x)
    if model == "complex":
        return ComplexConv2D(nb_filters, kernel_size, name=name,
                             **conv_kwargs)(x)
    if model in _REAL_GROUP_MODELS:
        x = _grouped_conv_shuffle(x, nb_filters, kernel_size, name,
                                  conv_kwargs)
    elif model in _COMPLEX_CONCAT_MODELS:
        x = ComplexConvConcat2D(nb_filters // 2, kernel_size, name=name,
                                **conv_kwargs)(x)
    else:
        raise ValueError("unknown model type: %r" % (model,))
    return _pointwise_stage(x, model, nb_filters, name + '_pwc', convArgs,
                            convArgs_real)


def getResidualBlock(I, filter_size, featmaps, stage, block, shortcut,
                     convArgs, convArgs_real, bnArgs, d):
    """Build one residual block for any supported real/complex variant.

    Layer order: batch-norm "_2a" -> activation -> conv "2a" ->
    batch-norm "_2b" -> activation -> conv "2b", followed by the shortcut
    merge:

    * ``shortcut == 'regular'``: the branch output is added to ``I``.
    * ``shortcut == 'projection'``: conv "2a" uses stride 2 unless
      ``d.spectral_pool_scheme == "nodownsample"``. ``I`` is projected by a
      1x1 conv "1" (stride 2 unless "nodownsample"), and the projection is
      concatenated with the branch output along the channel axis. With
      ``d.spectral_pool_scheme == "proj"``, ``applySpectralPooling`` is
      applied before conv "2a" and before the projection.

    Real variants use BatchNormalization; complex variants use ComplexBN.

    Args:
        I: Input tensor.
        filter_size: Kernel size of convs "2a" and "2b".
        featmaps: Pair ``(nb_fmaps1, nb_fmaps2)``: nominal filter counts of
            conv "2a" and of conv "2b" / the projection.
        stage: Stage index, used in layer names.
        block: Block identifier string, used in layer names.
        shortcut: ``'regular'`` or ``'projection'``.
        convArgs: Extra keyword arguments for the main convolutions.
        convArgs_real: Extra keyword arguments for grouped 1x1 convolutions.
        bnArgs: Extra keyword arguments for every batch-norm layer.
        d: Configuration object. Reads ``d.model`` (one of
            ``_RESIDUAL_MODELS``), ``d.act``, ``d.aact`` and
            ``d.spectral_pool_scheme``.

    Returns:
        The block's output tensor.

    Raises:
        ValueError: If ``d.model`` is not a supported variant. This replaces
            the former ``print`` + ``exit(-1)``.
    """
    model = d.model
    if model not in _RESIDUAL_MODELS:
        raise ValueError("unknown model type: %r" % (model,))

    activation = d.act
    nb_fmaps1, nb_fmaps2 = featmaps
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    # 4-D channels_first tensors keep channels on axis 1; every other case
    # uses the last axis.
    if K.image_data_format() == 'channels_first' and K.ndim(I) != 3:
        channel_axis = 1
    else:
        channel_axis = -1

    is_complex = model in _COMPLEX_MODELS
    bn_layer = ComplexBN if is_complex else BatchNormalization

    # Stage 2a. d.aact is honoured for every model variant here.
    O = bn_layer(name=bn_name_base + '_2a', **bnArgs)(I)
    if d.aact == "complex_joint_relu":
        O = Lambda(ComplexJointReLU)(O)
    else:
        O = Activation(activation)(O)

    if shortcut == 'regular' or d.spectral_pool_scheme == "nodownsample":
        O = _residual_conv(O, model, nb_fmaps1, filter_size,
                           conv_name_base + '2a', convArgs, convArgs,
                           convArgs_real)
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            O = applySpectralPooling(O, d)
        O = _residual_conv(O, model, nb_fmaps1, filter_size,
                           conv_name_base + '2a',
                           dict(strides=(2, 2), **convArgs), convArgs,
                           convArgs_real)

    # Stage 2b. Here only the complex variants honour d.aact; real
    # variants always use Activation(d.act).
    O = bn_layer(name=bn_name_base + '_2b', **bnArgs)(O)
    if is_complex and d.aact == "complex_joint_relu":
        O = Lambda(ComplexJointReLU)(O)
    else:
        O = Activation(activation)(O)
    O = _residual_conv(O, model, nb_fmaps2, filter_size,
                       conv_name_base + '2b', convArgs, convArgs,
                       convArgs_real)

    if shortcut == 'regular':
        O = Add()([O, I])
    elif shortcut == 'projection':
        if d.spectral_pool_scheme == "proj":
            I = applySpectralPooling(I, d)
        proj_strides = ((1, 1) if d.spectral_pool_scheme == "nodownsample"
                        else (2, 2))
        X = _residual_conv(I, model, nb_fmaps2, (1, 1),
                           conv_name_base + '1',
                           dict(strides=proj_strides, **convArgs), convArgs,
                           convArgs_real)
        if is_complex:
            O_real = Concatenate(channel_axis)([GetReal()(X), GetReal()(O)])
            O_imag = Concatenate(channel_axis)([GetImag()(X), GetImag()(O)])
            # Stack real and imaginary parts on the same axis their pieces
            # were joined on. This is axis 1 for 4-D channels_first input,
            # which was the previously hard-coded value. Assumes
            # GetReal/GetImag split along that channel axis — TODO confirm.
            O = Concatenate(channel_axis)([O_real, O_imag])
        else:
            O = Concatenate(channel_axis)([X, O])

    return O