def mk_adversary(target_model, n_units=(300, 300, 300, 300), epsilons=(1, 1, 1),
                 input_shape=None, center_jets=True, can_create=False):
    """Build and compile an adversary network chained in front of ``target_model``.

    The adversary is a dense network that predicts one delta per input
    element.  The deltas are squashed with ``tanh``, scaled by ``epsilons``
    and applied to the input, whose last axis is split into three equal
    parts handled as ``(pt, eta, phi)``.  The perturbed input is then fed to
    the frozen ``target_model``.

    Args:
        target_model: Model evaluated on the perturbed input.  Its
            ``trainable`` attribute is set to ``False`` (side effect on the
            caller's object).
        n_units: Widths of the hidden ReLU ``Dense`` layers of the adversary.
        epsilons: Three scale factors applied to the ``tanh`` of the pt, eta
            and phi deltas respectively.
        input_shape: Per-sample input shape (without batch axis).  Required.
        center_jets: If true, ``util.CenterJet()`` is applied to the input
            before the adversary sees it; the deltas are then added to that
            transformed input, not to the raw one.
        can_create: If false (default, the previous hard-coded behaviour),
            pt is rescaled multiplicatively and entries with ``pt <= 0`` stay
            at zero.  If true, pt is shifted additively in units of the jet
            pt, so entries that had ``pt <= 0`` may receive a positive pt;
            such entries get eta/phi placed relative to the jet eta/phi.

    Returns:
        The compiled ``composite`` model (input -> adversary -> target_model)
        with these extra attributes attached:
        ``lambda_adv``, ``lambda_bg``, ``lambda_jpt``, ``lambda_jmass``
        (backend variables initialised to 1.0 weighting the loss terms),
        ``adversary`` (the perturbation model) and ``calc`` (the model
        returned by ``mk_HL_calc``).

    Raises:
        ValueError: If ``input_shape`` is ``None``.
    """
    if input_shape is None:
        # Fail early with a clear message instead of a backend-specific error
        # from layers.Input / np.prod further down.
        raise ValueError('mk_adversary requires an explicit input_shape')

    # ---- adversary: dense network predicting one delta per input element ----
    adv_input = layers.Input(input_shape)
    x_in = util.CenterJet()(adv_input) if center_jets else adv_input

    x = layers.Flatten()(x_in)
    for n in n_units:
        x = layers.Dense(n, activation='relu')(x)
    # np.prod returns a NumPy integer; Dense units should be a plain int.
    x = layers.Dense(int(np.prod(input_shape)))(x)
    x = layers.Reshape(input_shape)(x)

    def add_deltas(inputs):
        """Apply the bounded deltas to the input and mask low-pt entries.

        ``inputs`` is ``[original, deltas, jet]``; ``jet`` is split into four
        parts along its last axis (pt, eta, phi, mass).
        """
        x_old, dx, jet = inputs
        pt_old, eta_old, phi_old = tf.split(x_old, 3, axis=-1)
        dpt, deta, dphi = tf.split(dx, 3, axis=-1)
        jet_pt, jet_eta, jet_phi, jet_mass = tf.split(jet, 4, axis=-1)
        zeros = tf.zeros_like(pt_old)

        # Jet-level quantities reshaped so they broadcast against the
        # per-entry tensors.
        jeta = tf.reshape(jet_eta, (-1, 1, 1))
        jphi = tf.reshape(jet_phi, (-1, 1, 1))

        was_valid = pt_old > 0
        if can_create:
            # Additive shift in units of the jet pt; may turn an empty entry
            # into a populated one.
            jpt = tf.reshape(jet_pt, (-1, 1, 1))
            pt_new = pt_old + epsilons[0] * jpt * tf.tanh(dpt)
        else:
            # Multiplicative rescaling; entries that were empty stay empty.
            pt_new = pt_old * (1 + epsilons[0] * tf.tanh(dpt))
            pt_new = tf.where(was_valid, pt_new, zeros)

        # Previously valid entries move relative to their own eta/phi,
        # previously empty ones relative to the jet eta/phi.
        eeta = epsilons[1] * tf.tanh(deta)
        eta_new = tf.where(was_valid, eta_old + eeta, jeta + eeta)

        ephi = epsilons[2] * tf.tanh(dphi)
        phi_new = tf.where(was_valid, phi_old + ephi, jphi + ephi)

        # Anything at or below the pt threshold is zeroed in all three parts.
        is_valid = pt_new > defs.MIN_PT
        pt_new = tf.where(is_valid, pt_new, zeros)
        eta_new = tf.where(is_valid, eta_new, zeros)
        phi_new = tf.where(is_valid, phi_new, zeros)

        return tf.concat([pt_new, eta_new, phi_new], axis=-1)

    tmp_jet = util.JetVector()(x_in)
    adv_output = layers.Lambda(add_deltas)([x_in, x, tmp_jet])
    adversary = Model(adv_input, adv_output, name='adversary')

    # ---- composite: adversary followed by the frozen target model ----
    calc = mk_HL_calc(features=('pt', 'eta', 'phi', 'mass'))

    target_model.trainable = False

    composite_input = layers.Input(input_shape)
    pre_x = composite_input
    adv_x = adversary(pre_x)

    composite_output = target_model(adv_x)
    cls_output = target_model(pre_x)
    # Columns 0 and 3 of these are used below; presumably pt and mass given
    # the feature order passed to mk_HL_calc -- TODO confirm.
    jet_before = calc(pre_x)
    jet_after = calc(adv_x)

    composite = Model(composite_input, composite_output, name='composite')

    # NOTE: every loss/metric below ignores its second argument (y2) and
    # instead reads the symbolic tensors captured from this scope.  y1 must
    # have a size-1 axis 1, since it is squeezed on that axis.

    def _is_bg_like(y1):
        """Boolean mask of samples whose label is below 0.5."""
        return K.squeeze(y1, axis=1) < 0.5

    def _masked_mean(mask, values):
        """Mean over the whole batch with unselected entries set to zero."""
        return K.mean(tf.where(mask, values, tf.zeros_like(values)))

    def adv_loss(y1, y2):
        """Cross-entropy of the composite output against an all-zero target,
        counted only for samples with label above 0.5."""
        xent = keras.losses.binary_crossentropy(
            tf.zeros_like(composite_output), composite_output)
        is_sig_like = K.squeeze(y1, axis=1) > 0.5
        return _masked_mean(is_sig_like, xent)

    def bg_loss(y1, y2):
        """Squared change of the target output, for samples with label < 0.5."""
        mse = K.squeeze(K.square(composite_output - cls_output), axis=1)
        return _masked_mean(_is_bg_like(y1), mse)

    def jpt_loss(y1, y2):
        """Squared relative change of column 0, for samples with label < 0.5."""
        mse = K.square(
            (jet_before[:, 0] - jet_after[:, 0]) / jet_before[:, 0])
        return _masked_mean(_is_bg_like(y1), mse)

    def jmass_loss(y1, y2):
        """Squared absolute change of column 3, for samples with label < 0.5."""
        mse = K.square(jet_before[:, 3] - jet_after[:, 3])
        return _masked_mean(_is_bg_like(y1), mse)

    def jmass_res(y1, y2):
        """Standard deviation of twice the relative change of column 0, with
        samples whose label is >= 0.5 contributing zeros."""
        # NOTE(review): despite the name, this reads column 0 (the same
        # column as jpt_loss), not column 3 used by jmass_loss.  Left
        # unchanged; confirm whether column 3 was intended.
        pull = (jet_before[:, 0] - jet_after[:, 0]) / jet_before[:, 0]
        return K.std(tf.where(_is_bg_like(y1), 2 * pull, tf.zeros_like(pull)))

    # Loss weights are backend variables so they can be changed after
    # compilation without rebuilding the model.
    composite.lambda_adv = K.variable(1.0)
    composite.lambda_bg = K.variable(1.0)
    composite.lambda_jpt = K.variable(1.0)
    composite.lambda_jmass = K.variable(1.0)

    def loss(y1, y2):
        """Weighted sum of the four loss terms."""
        return (composite.lambda_adv * adv_loss(y1, y2)
                + composite.lambda_bg * bg_loss(y1, y2)
                + composite.lambda_jpt * jpt_loss(y1, y2)
                + composite.lambda_jmass * jmass_loss(y1, y2))

    composite.compile(
        optimizer='adam',
        loss=loss,
        metrics=[adv_loss, bg_loss, jpt_loss, jmass_loss, jmass_res])

    composite.adversary = adversary
    composite.calc = calc

    return composite