def GAN(G, D, z_dim, n_labels, resolution, n_channels):
    """Wire a generator and a discriminator into two trainable models.

    Args:
        G: Generator model; must expose a ``cur_block`` attribute.
        D: Discriminator model; must expose a ``cur_block`` attribute and
            have at least one inbound node (its input shape is read from
            node 0).
        z_dim: Not used in this function body.
        n_labels: Not used in this function body.
        resolution: Not used in this function body.
        n_channels: Not used in this function body.

    Returns:
        A ``(G_train, D_train)`` tuple.

        ``G_train`` is ``Sequential([G, D])`` and carries ``G.cur_block``.

        ``D_train`` carries ``D.cur_block``. It takes three inputs, in the
        order ``[real, generated, interpolation]``, all shaped like D's
        input, and yields three outputs:

        1. ``Subtract`` applied to ``[D(generated), D(real)]``,
        2. ``GradNorm`` applied to ``[D(interpolation), interpolation]``,
        3. ``D(real)`` reshaped to ``(1,)`` per sample.
    """
    # Generator-side model: G's output is fed straight into D.
    G_train = Sequential([G, D])
    G_train.cur_block = G.cur_block

    # Per-sample input shape of D (batch axis dropped).
    sample_shape = D.get_input_shape_at(0)[1:]

    # Placeholders are created in the order generated -> real -> interpolation.
    gen_input = Input(sample_shape)
    real_input = Input(sample_shape)
    interpolation = Input(sample_shape)

    # Output 1: Subtract over D's outputs for the generated and real inputs.
    subtract_layer = Subtract()
    gen_score = D(gen_input)
    real_score = D(real_input)
    score_diff = subtract_layer([gen_score, real_score])

    # Output 2: GradNorm is defined outside this block; presumably it is the
    # norm of the gradient of D's output w.r.t. the interpolated input
    # (gradient-penalty term) -- TODO confirm against its definition.
    grad_norm_layer = GradNorm()
    interp_score = D(interpolation)
    grad_norm = grad_norm_layer([interp_score, interpolation])

    # Output 3: D is applied to the real input a second time and the result
    # is reshaped to (1,).
    to_scalar = Reshape((1, ))
    real_score_flat = to_scalar(D(real_input))

    # NOTE: the model's input order (real first) differs from the order in
    # which the placeholders were created above.
    D_train = Model(
        [real_input, gen_input, interpolation],
        [score_diff, grad_norm, real_score_flat],
    )
    D_train.cur_block = D.cur_block

    return G_train, D_train