def GAN(G, D, z_dim, n_labels, resolution, n_channels):
    """Build the generator-training and discriminator-training models.

    The returned models reuse ``G`` and ``D`` directly, so they share layers
    (and therefore weights) with them; nothing is copied.

    Args:
        G: Generator model. Its ``cur_block`` attribute is copied onto the
            returned generator-training model.
        D: Discriminator model. Its first input shape, without the batch
            axis, sets the shape of every ``Input`` created here. Its
            ``cur_block`` attribute is copied onto the returned
            discriminator-training model.
        z_dim, n_labels, resolution, n_channels: Not used by this function;
            kept so existing callers keep working.

    Returns:
        A ``(G_train, D_train)`` tuple:

        * ``G_train`` -- ``Sequential([G, D])``, i.e. D applied to G's output.
        * ``D_train`` -- a ``Model`` with inputs
          ``[real_input, gen_input, interpolation]`` and three outputs:
          ``D(gen_input) - D(real_input)``; the ``GradNorm`` layer applied to
          ``[D(interpolation), interpolation]``; and ``D(real_input)``
          reshaped to ``(1,)``.
    """
    # NOTE(review): the three outputs (critic score difference, gradient
    # norm at an interpolated sample, raw real score) look like a WGAN-GP
    # setup with a drift penalty, and ``cur_block`` looks like the current
    # progressive-growing stage. Confirm both against the caller's losses.
    G_train = Sequential([G, D])
    G_train.cur_block = G.cur_block

    shape = D.get_input_shape_at(0)[1:]  # drop the batch dimension
    # Inputs are created in the same order as before, so auto-generated
    # input layer names are unchanged.
    gen_input, real_input = Input(shape), Input(shape)
    interpolation = Input(shape)

    # Call D once per distinct input tensor. D(real_input) used to be called
    # twice, which put a second, redundant forward pass over the real batch
    # into the graph. If D has stochastic layers, the two calls would also
    # have produced independently sampled real scores.
    d_gen = D(gen_input)
    d_real = D(real_input)
    d_interp = D(interpolation)

    sub = Subtract()([d_gen, d_real])
    norm = GradNorm()([d_interp, interpolation])
    real_score = Reshape((1, ))(d_real)

    D_train = Model([real_input, gen_input, interpolation], [sub, norm, real_score])
    D_train.cur_block = D.cur_block
    return G_train, D_train