def build_discriminator(image, mbd=False, sparsity=False, sparsity_mbd=False):
    """ Discriminator sub-component for the CaloGAN

    Extracts features from one calorimeter layer with a stack of
    convolutional / locally-connected layers. It can optionally append
    minibatch-discrimination and sparsity features to them.

    Args:
    -----
        image: keras tensor of 4 dimensions (i.e. the output of one calo layer)
        mbd: bool, perform feature level minibatch discrimination
        sparsity: bool, whether or not to calculate and include sparsity
        sparsity_mbd: bool, perform minibatch discrimination on the sparsity
            values in a batch

    Returns:
    --------
        a keras tensor of features. If no flag is set, this is the flattened
        convolutional feature vector. Otherwise it is the concatenation of,
        in this order: the flattened features, the mbd features (if `mbd`),
        the empirical sparsity (if `sparsity`), and the sparsity mbd
        features (if `sparsity_mbd`).

    """
    x = Conv2D(64, (2, 2), padding='same')(image)
    x = LeakyReLU()(x)

    x = ZeroPadding2D((1, 1))(x)
    x = LocallyConnected2D(16, (3, 3), padding='valid', strides=(1, 2))(x)
    x = LeakyReLU()(x)
    x = BatchNormalization()(x)

    x = ZeroPadding2D((1, 1))(x)
    x = LocallyConnected2D(8, (2, 2), padding='valid')(x)
    x = LeakyReLU()(x)
    x = BatchNormalization()(x)

    x = ZeroPadding2D((1, 1))(x)
    x = LocallyConnected2D(8, (2, 2), padding='valid', strides=(1, 2))(x)
    x = LeakyReLU()(x)
    x = BatchNormalization()(x)

    x = Flatten()(x)

    # Nothing extra requested: the flattened features are the output.
    if not (mbd or sparsity or sparsity_mbd):
        return x

    # One Lambda instance is reused for every minibatch-discrimination
    # branch below. It is created here, before any Dense3D, to keep the
    # original layer-construction order (and hence auto-generated names).
    minibatch_featurizer = Lambda(minibatch_discriminator,
                                  output_shape=minibatch_output_shape)

    features = [x]

    # dimensions of the kernel space for the minibatch discrimination
    nb_features = 10
    vspace_dim = 10

    if mbd:
        # creates the kernel space for the minibatch discrimination
        K_x = Dense3D(nb_features, vspace_dim)(x)
        # tanh bounds the minibatch features to (-1, 1)
        features.append(Activation('tanh')(minibatch_featurizer(K_x)))

    if sparsity or sparsity_mbd:
        sparsity_detector = Lambda(sparsity_level,
                                   output_shape=sparsity_output_shape)
        # sparsity is measured on the raw input image, not on the
        # convolutional features
        empirical_sparsity = sparsity_detector(image)

        if sparsity:
            features.append(empirical_sparsity)

        if sparsity_mbd:
            K_sparsity = Dense3D(nb_features, vspace_dim)(empirical_sparsity)
            features.append(
                Activation('tanh')(minibatch_featurizer(K_sparsity)))

    return concatenate(features)
features = concatenate(features) # This is a (None, 3) tensor with the individual energy per layer energies = concatenate(energies) # calculate the total energy across all rows total_energy = Lambda(lambda x: K.reshape(K.sum(x, axis=-1), (-1, 1)), name='total_energy')(energies) # construct MBD on the raw energies nb_features = 10 vspace_dim = 10 minibatch_featurizer = Lambda(minibatch_discriminator, output_shape=minibatch_output_shape) K_energy = Dense3D(nb_features, vspace_dim)(energies) # constrain w/ a tanh to dampen the unbounded nature of energy-space mbd_energy = Activation('tanh')(minibatch_featurizer(K_energy)) # absolute deviation away from input energy. Technically we can learn # this, but since we want to get as close as possible to conservation of # energy, just coding it in is better energy_well = Lambda(lambda x: K.abs(x[0] - x[1]))( [total_energy, input_energy]) # binary y/n if it is over the input energy well_too_big = Lambda(lambda x: 10 * K.cast(x > 5, K.floatx()))( energy_well) p = concatenate([