Ejemplo n.º 1
0
def k_autoencoder4(latent_dim, folder):
    """Build a convolutional autoencoder for (32, 32, 3) inputs.

    Encoder: three conv stages (64, 128, 256 filters, with max-pooling after
    the first two), flatten, a 2048-unit dense layer and a ``latent_dim``
    dense layer with ReLU.  Decoder: dense 2048 -> dense 8*8*256 -> reshape
    to (8, 8, 256), then deconv/upsampling stages (256, 128, 64 filters) and
    a final 3-filter deconv with sigmoid.

    Args:
        latent_dim: Width of the bottleneck (code) layer.
        folder: Directory passed to ``kst.model_summary`` as the location of
            ``model_summary.txt``.

    Returns:
        Tuple ``(auto_encoder, encoder, decoder)``:
        ``auto_encoder`` maps input -> reconstruction; ``encoder`` maps
        input -> code and shares layers with ``auto_encoder``; ``decoder``
        maps a fresh ``(latent_dim,)`` input -> reconstruction by re-applying
        the auto_encoder's decoder layers.
    """
    inputs = kst.md_input((32, 32, 3))

    # ---- encoder ----
    net = kst.conv2d(inputs, 64, (3, 3))  # 32,32,64
    net = kst.activation(net, "relu")
    net = kst.conv2d(net, 64, (3, 3))  # 32,32,64
    net = kst.activation(net, "relu")
    net = kst.maxpool2d(net)  # 16,16,64
    net = kst.conv2d(net, 128, (3, 3))  # 16,16,128
    net = kst.activation(net, "relu")
    net = kst.conv2d(net, 128, (3, 3))  # 16,16,128
    net = kst.activation(net, "relu")
    net = kst.maxpool2d(net)  # 8,8,128
    net = kst.conv2d(net, 256, (3, 3))  # 8,8,256
    net = kst.activation(net, "relu")
    # NOTE(review): unlike every other conv in this model, this one is not
    # followed by an activation, so it composes linearly with the next dense
    # layer. Possibly a missing kst.activation(net, "relu") -- confirm before
    # changing, since it alters the trained architecture.
    net = kst.conv2d(net, 256, (3, 3))  # 8,8,256
    net = kst.flatten(net)  # 8*8*256
    net = kst.dense(net, 2048)
    net = kst.activation(net, "relu")
    net = kst.dense(net, latent_dim)
    encode = kst.activation(net, "relu")

    # ---- decoder ----
    net = kst.dense(encode, 2048)  # first decoder layer
    net = kst.activation(net, "relu")
    net = kst.dense(net, 8 * 8 * 256)
    net = kst.activation(net, "relu")
    net = kst.reshape(net, (8, 8, 256))  # 8,8,256
    net = kst.deconv2d(net, 256, (3, 3))  # 8,8,256
    net = kst.activation(net, "relu")
    # NOTE(review): no activation between this deconv and the upsampling,
    # unlike the 128- and 64-filter stages below; possibly a missing relu.
    net = kst.deconv2d(net, 256, (3, 3))
    net = kst.upsampling(net, 2)  # 16,16,256
    net = kst.deconv2d(net, 128, (3, 3))  # 16,16,128
    net = kst.activation(net, "relu")
    net = kst.deconv2d(net, 128, (3, 3))
    net = kst.activation(net, "relu")
    net = kst.upsampling(net, 2)  # 32,32,128
    net = kst.deconv2d(net, 64, (3, 3))  # 32,32,64
    net = kst.activation(net, "relu")
    net = kst.deconv2d(net, 64, (3, 3))
    net = kst.activation(net, "relu")
    net = kst.deconv2d(net, 3, (3, 3))  # 32,32,3
    decode = kst.activation(net, "sigmoid")

    auto_encoder = Model(inputs, decode)
    kst.model_summary(auto_encoder,
                      print_out=True,
                      save_dir=folder + "/model_summary.txt")

    encoder = Model(inputs, encode)

    # Standalone decoder: replay the auto_encoder's decoder layers on a new
    # latent input. This graph is a single chain, so the encoder's layers are
    # exactly the first len(encoder.layers) entries of auto_encoder.layers and
    # the decoder begins right after them. Deriving the index (instead of a
    # hard-coded layer number) keeps the decoder correct if encoder layers
    # are added or removed.
    decoder_start = len(encoder.layers)
    encoded_input = kst.md_input(shape=(latent_dim, ))
    decoding = encoded_input
    for layer in auto_encoder.layers[decoder_start:]:
        decoding = layer(decoding)
    decoder = Model(encoded_input, decoding)

    return auto_encoder, encoder, decoder
Ejemplo n.º 2
0
                                                model.add(LeakyReLU(af[1]))

                                    model.add(Flatten())
                                    for dnl in denl:
                                        model.add(Dense(units=dnl, kernel_initializer=wi[2], bias_initializer=bi[2]))
                                        model.add(BatchNormalization())
                                        if af[0] == 'relu':
                                            model.add(Activation('relu'))
                                        elif af[0] == 'leaky_relu':
                                            model.add(LeakyReLU(af[1]))

                                        model.add(Dropout(dp))

                                    model.add(Dense(units=7, kernel_initializer=wi[2]))
                                    model.add(Activation('softmax'))
                                    net_rpt = knt.model_summary(model, param)
                                    model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=lr), metrics=['accuracy'])
                                    mc = ModelCheckpoint(os.path.join(dir, 'best_model_fold' + str(nfid+1) + '.h5'), monitor='val_acc', mode='max', verbose=1, save_best_only=True)
                                    # hist = model.fit(x_train, y_train,
                                    #                  batch_size=bs, epochs=epochs,
                                    #                  validation_data=(x_valid, y_valid), callbacks=[mc])

                                    hist = model.fit_generator(
                                        train_generator,
                                        steps_per_epoch=train_generator.n // bs,
                                        epochs=epochs,
                                        validation_data=validate_generator,
                                        validation_steps=validate_generator.n // bs,
                                        callbacks=[mc]
                                    )
                                    acc_train = hist.history['acc']
Ejemplo n.º 3
0
# Classification head: pool + flatten, two dense -> batch-norm -> LeakyReLU
# -> dropout(0.5) stages (1024 then 512 units), and a 7-way softmax output.
net = kst.flatten(kst.maxpool2d(net))
net = kst.dropout(
    kst.activation(kst.batch_norm(kst.dense(net, 1024)), "LeakyReLU"),
    0.5,
)
net = kst.dropout(
    kst.activation(kst.batch_norm(kst.dense(net, 512)), "LeakyReLU"),
    0.5,
)
net = kst.dense(net, 7)
net_out = kst.activation(net, "softmax")

model = Model(inputs, net_out)
net_rpt = kst.model_summary(model, param_dict=param)

# Categorical cross-entropy over one-hot labels, Adam at the configured rate.
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(lr=learn_rate),
    metrics=['accuracy'],
)

# Create a per-run project directory named by the local start time, e.g.
# ./logs/hw3/D20240131T0915. os.makedirs also creates ./logs/hw3 if it is
# missing (os.mkdir raised FileNotFoundError in that case); it still raises
# FileExistsError if a run from the same minute already owns the directory,
# so an earlier run's checkpoint is never silently overwritten.
tn = time.localtime()
project = "./logs/hw3/D{0:4d}{1:02d}{2:02d}T{3:02d}{4:02d}".format(*tn[:5])
os.makedirs(project)

# Keep only the checkpoint with the highest validation accuracy.
# NOTE(review): with metrics=['accuracy'], Keras >= 2.3 / tf.keras log the
# metric as 'val_accuracy' rather than 'val_acc'; on those versions this
# monitor never matches and no checkpoint is saved -- confirm the installed
# Keras version.
mc = ModelCheckpoint(os.path.join(project, 'best_model.h5'), monitor='val_acc', mode='max', verbose=1, save_best_only=True)

hist = model.fit_generator(
    train_generator,
    steps_per_epoch=train_generator.n // bt_size,
    epochs=epochs,
    validation_data=validate_generator,
Ejemplo n.º 4
0
def _iv3_conv_bn_relu(x, filters, kernel_size, **conv_kwargs):
    """Apply kst.conv2d -> kst.batch_norm -> ReLU (the basic Inception v3 unit).

    Extra keyword arguments (e.g. ``stride``, ``padding``) are forwarded to
    ``kst.conv2d`` unchanged; when omitted, kst's own defaults apply.
    """
    x = kst.conv2d(x, filters, kernel_size, **conv_kwargs)
    x = kst.batch_norm(x)
    return kst.activation(x, "relu")


def _iv3_inception_a(net, pool_filters):
    """Mixed_5b/5c/5d block; ``pool_filters`` sizes the pooling branch's 1x1 conv."""
    branch0 = _iv3_conv_bn_relu(net, 64, (1, 1))
    branch1 = _iv3_conv_bn_relu(net, 48, (1, 1))
    branch1 = _iv3_conv_bn_relu(branch1, 64, (5, 5))
    branch2 = _iv3_conv_bn_relu(net, 64, (1, 1))
    branch2 = _iv3_conv_bn_relu(branch2, 96, (3, 3))
    branch2 = _iv3_conv_bn_relu(branch2, 96, (3, 3))
    # NOTE(review): the concat below needs equal spatial sizes, so this relies
    # on kst.avgpool2d's default stride/padding preserving size -- confirm.
    branch3 = kst.avgpool2d(net, (3, 3))
    branch3 = _iv3_conv_bn_relu(branch3, pool_filters, (1, 1))
    return kst.concat([branch0, branch1, branch2, branch3], axis=3)


def _iv3_reduction_a(net):
    """Mixed_6a grid-size reduction block (stride-2 conv branches + max-pool)."""
    branch0 = _iv3_conv_bn_relu(net, 384, (3, 3), stride=2, padding="VALID")
    branch1 = _iv3_conv_bn_relu(net, 64, (1, 1))
    branch1 = _iv3_conv_bn_relu(branch1, 96, (3, 3))
    branch1 = _iv3_conv_bn_relu(branch1, 96, (3, 3), stride=2, padding="VALID")
    branch2 = kst.maxpool2d(net, (3, 3), padding="VALID")
    return kst.concat([branch0, branch1, branch2], axis=3)


def _iv3_inception_b(net, c7):
    """Mixed_6b..6e block with factorized 1x7/7x1 convs; ``c7`` is the inner width."""
    branch0 = _iv3_conv_bn_relu(net, 192, (1, 1))
    branch1 = _iv3_conv_bn_relu(net, c7, (1, 1))
    branch1 = _iv3_conv_bn_relu(branch1, c7, (1, 7))
    branch1 = _iv3_conv_bn_relu(branch1, 192, (7, 1))
    branch2 = _iv3_conv_bn_relu(net, c7, (1, 1))
    branch2 = _iv3_conv_bn_relu(branch2, c7, (7, 1))
    branch2 = _iv3_conv_bn_relu(branch2, c7, (1, 7))
    branch2 = _iv3_conv_bn_relu(branch2, c7, (7, 1))
    branch2 = _iv3_conv_bn_relu(branch2, 192, (1, 7))
    branch3 = kst.avgpool2d(net, (3, 3))
    branch3 = _iv3_conv_bn_relu(branch3, 192, (1, 1))
    return kst.concat([branch0, branch1, branch2, branch3], axis=3)


def _iv3_reduction_b(net):
    """Mixed_7a grid-size reduction block (stride-2 conv branches + max-pool)."""
    branch0 = _iv3_conv_bn_relu(net, 192, (1, 1))
    branch0 = _iv3_conv_bn_relu(branch0, 320, (3, 3), stride=2, padding="VALID")
    branch1 = _iv3_conv_bn_relu(net, 192, (1, 1))
    branch1 = _iv3_conv_bn_relu(branch1, 192, (1, 7))
    branch1 = _iv3_conv_bn_relu(branch1, 192, (7, 1))
    branch1 = _iv3_conv_bn_relu(branch1, 192, (3, 3), stride=2, padding="VALID")
    branch2 = kst.maxpool2d(net, (3, 3), padding="VALID")
    return kst.concat([branch0, branch1, branch2], axis=3)


def _iv3_inception_c(net):
    """Mixed_7b/7c block with split 1x3 / 3x1 heads on two of its branches."""
    branch0 = _iv3_conv_bn_relu(net, 320, (1, 1))

    branch1 = _iv3_conv_bn_relu(net, 384, (1, 1))
    branch1a = _iv3_conv_bn_relu(branch1, 384, (1, 3))
    branch1b = _iv3_conv_bn_relu(branch1, 384, (3, 1))
    branch1 = kst.concat([branch1a, branch1b], axis=3)

    branch2 = _iv3_conv_bn_relu(net, 448, (1, 1))
    branch2 = _iv3_conv_bn_relu(branch2, 384, (3, 3))
    # Both heads read this branch's own 3x3 output (not branch1).
    branch2a = _iv3_conv_bn_relu(branch2, 384, (1, 3))
    branch2b = _iv3_conv_bn_relu(branch2, 384, (3, 1))
    branch2 = kst.concat([branch2a, branch2b], axis=3)

    branch3 = kst.avgpool2d(net, (3, 3))
    branch3 = _iv3_conv_bn_relu(branch3, 192, (1, 1))
    return kst.concat([branch0, branch1, branch2, branch3], axis=3)


def k_inception_v3(n_class, dp):
    """Build an Inception v3 classifier with an auxiliary head.

    Args:
        n_class: Number of output classes for both the main and aux heads.
        dp: Dropout parameter passed to ``kst.dropout`` before the final 1x1
            conv (presumably the drop rate -- confirm against kst).

    Returns:
        A Keras ``Model`` mapping a (299, 299, 1) input to
        ``[predict, aux_predict]``: softmax outputs of the main head and of
        the auxiliary head attached after Mixed_6e.
    """
    # TODO: verify the architecture end to end against a reference
    # implementation (original note: "make sure the reliability").
    inputs = kst.md_input((299, 299, 1))

    # Stem.
    net = _iv3_conv_bn_relu(inputs, 32, (3, 3), stride=2, padding="VALID")
    net = _iv3_conv_bn_relu(net, 32, (3, 3), padding="VALID")
    net = _iv3_conv_bn_relu(net, 64, (3, 3), padding="SAME")
    net = kst.maxpool2d(net, kernel_size=(3, 3), stride=2, padding="VALID")
    net = _iv3_conv_bn_relu(net, 80, (1, 1), padding="VALID")
    net = _iv3_conv_bn_relu(net, 192, (3, 3), padding="VALID")
    # NOTE(review): unlike the first max-pool, no explicit stride is given;
    # the reference network uses stride 2 here -- confirm kst's default.
    net = kst.maxpool2d(net, (3, 3), padding="VALID")

    net = _iv3_inception_a(net, pool_filters=32)  # mixed 5b
    net = _iv3_inception_a(net, pool_filters=64)  # mixed 5c
    net = _iv3_inception_a(net, pool_filters=64)  # mixed 5d
    net = _iv3_reduction_a(net)                   # mixed 6a
    net = _iv3_inception_b(net, c7=128)           # mixed 6b
    net = _iv3_inception_b(net, c7=160)           # mixed 6c
    net = _iv3_inception_b(net, c7=160)           # mixed 6d
    net = _iv3_inception_b(net, c7=192)           # mixed 6e
    aux_logits = net  # the auxiliary head branches off after mixed 6e
    net = _iv3_reduction_b(net)                   # mixed 7a
    net = _iv3_inception_c(net)                   # mixed 7b
    net = _iv3_inception_c(net)                   # mixed 7c

    # Auxiliary head.
    aux_logits = kst.avgpool2d(aux_logits, (5, 5), stride=3, padding="VALID")
    aux_logits = _iv3_conv_bn_relu(aux_logits, 128, (1, 1))
    aux_logits = _iv3_conv_bn_relu(aux_logits, 768, (5, 5), padding="VALID")
    aux_logits = kst.conv2d(aux_logits, n_class, (1, 1))
    aux_logits = kst.flatten(aux_logits)
    aux_predict = kst.activation(aux_logits, "softmax")

    # Main head.
    net = kst.avgpool2d(net, (8, 8), padding="VALID")
    net = kst.dropout(net, dp)
    logits = kst.conv2d(net, n_class, (1, 1))
    logits = kst.flatten(logits)
    predict = kst.activation(logits, "softmax")

    inception_v3 = Model(inputs, [predict, aux_predict])
    kst.model_summary(inception_v3, print_out=True)

    return inception_v3