# Code example #1
0
def mk_adversary(target_model,
                 n_units=(300, 300, 300, 300),
                 epsilons=(1, 1, 1),
                 input_shape=None,
                 center_jets=True,
                 can_create=False):
    """Build an adversary network and the compiled composite model that trains it.

    The adversary is a stack of dense ReLU layers that maps its input to a
    tensor of the same shape holding raw deltas.  The deltas are squashed
    with ``tanh``, scaled by ``epsilons`` and applied to the input by
    ``add_deltas``.  The composite model feeds the adversary output into
    ``target_model`` and is compiled with a weighted sum of four loss terms.

    Args:
        target_model: Model applied to both the perturbed and the unperturbed
            input.  Side effect: its ``trainable`` flag is set to ``False``.
        n_units: Width of each hidden dense layer of the adversary.
        epsilons: Scale factors ``(pt, eta, phi)`` applied to the
            ``tanh``-squashed deltas.
        input_shape: Per-sample input shape (required).  The last axis is
            split into three equal parts treated as ``(pt, eta, phi)``.
            Assumes a shape of the form ``(n_constituents, 3)`` -- TODO
            confirm against callers.
        center_jets: If true, ``util.CenterJet`` is applied to the adversary
            input before anything else.
        can_create: If false (default, the previous hard-coded behaviour),
            pT is rescaled multiplicatively and entries whose pT was not
            positive stay at zero.  If true, pT is shifted additively in
            units of the jet pT, which lets such entries acquire pT.

    Returns:
        The compiled composite model.  It carries these extra attributes:
        ``lambda_adv``, ``lambda_bg``, ``lambda_jpt``, ``lambda_jmass``
        (backend variables weighting the loss terms, all initialised to
        1.0), ``adversary`` (the adversary sub-model) and ``calc`` (the
        model returned by ``mk_HL_calc``).

    Raises:
        ValueError: If ``input_shape`` is not given.
    """
    if input_shape is None:
        raise ValueError('mk_adversary requires an explicit input_shape')

    adv_input = layers.Input(input_shape)
    x_in = util.CenterJet()(adv_input) if center_jets else adv_input

    x = layers.Flatten()(x_in)
    for n in n_units:
        x = layers.Dense(n, activation='relu')(x)

    # One raw delta per input entry; int() keeps a NumPy scalar out of Dense.
    x = layers.Dense(int(np.prod(input_shape)))(x)
    x = layers.Reshape(input_shape)(x)

    def add_deltas(inputs):
        """Apply the bounded deltas to the constituents.

        ``inputs`` is ``[constituents, raw_deltas, jet_vector]``.
        """
        x_old, dx, jet = inputs[0], inputs[1], inputs[2]

        pt_old, eta_old, phi_old = tf.split(x_old, 3, axis=-1)
        dpt, deta, dphi = tf.split(dx, 3, axis=-1)
        jet_pt, jet_eta, jet_phi, _jet_mass = tf.split(jet, 4, axis=-1)

        zeros = tf.zeros_like(pt_old)

        # Per-jet quantities reshaped so they broadcast over the constituents.
        jeta = tf.reshape(jet_eta, (-1, 1, 1))
        jphi = tf.reshape(jet_phi, (-1, 1, 1))

        was_valid = pt_old > 0

        if can_create:
            # Shift pT in units of the jet pT; previously empty entries may
            # end up with non-zero pT.
            jpt = tf.reshape(jet_pt, (-1, 1, 1))
            pt_new = pt_old + epsilons[0] * jpt * tf.tanh(dpt)
        else:
            # Relative pT change only; entries without pT stay at zero.
            pt_new = pt_old * (1 + epsilons[0] * tf.tanh(dpt))
            pt_new = tf.where(was_valid, pt_new, zeros)

        # Existing entries move relative to their old position; entries that
        # had no pT are placed relative to the jet axis instead.
        eeta = epsilons[1] * tf.tanh(deta)
        eta_new = tf.where(was_valid, eta_old + eeta, jeta + eeta)

        ephi = epsilons[2] * tf.tanh(dphi)
        phi_new = tf.where(was_valid, phi_old + ephi, jphi + ephi)

        # Zero out every entry whose new pT does not exceed the threshold.
        is_valid = pt_new > defs.MIN_PT
        pt_new = tf.where(is_valid, pt_new, zeros)
        eta_new = tf.where(is_valid, eta_new, zeros)
        phi_new = tf.where(is_valid, phi_new, zeros)

        return tf.concat([pt_new, eta_new, phi_new], axis=-1)

    tmp_jet = util.JetVector()(x_in)
    adv_output = layers.Lambda(add_deltas)([x_in, x, tmp_jet])

    adversary = Model(adv_input, adv_output, name='adversary')

    calc = mk_HL_calc(features=('pt', 'eta', 'phi', 'mass'))

    # The target model is only evaluated here, never updated.
    target_model.trainable = False
    composite_input = layers.Input(input_shape)
    pre_x = composite_input

    adv_x = adversary(pre_x)
    composite_output = target_model(adv_x)
    cls_output = target_model(pre_x)
    jet_before = calc(pre_x)
    jet_after = calc(adv_x)
    composite = Model(composite_input, composite_output, name='composite')

    # The loss terms below read the graph tensors captured above and use
    # y_true only to select which samples contribute; y_pred is ignored.
    # Assumes y_true has shape (batch, 1) -- TODO confirm.

    def _is_bg_like(y_true):
        return K.squeeze(y_true, axis=1) < 0.5

    def _masked_mean(mask, values):
        # Samples outside the mask contribute zero but still count in the mean.
        return K.mean(tf.where(mask, values, tf.zeros_like(values)))

    def adv_loss(y1, y2):
        # Cross-entropy of the perturbed output against an all-zero target,
        # for samples labelled above 0.5.
        xent = keras.losses.binary_crossentropy(
            tf.zeros_like(composite_output), composite_output)
        is_sig_like = K.squeeze(y1, axis=1) > 0.5
        return _masked_mean(is_sig_like, xent)

    def bg_loss(y1, y2):
        # Squared change of the target-model output, for samples labelled
        # below 0.5.
        mse = K.squeeze(K.square(composite_output - cls_output), axis=1)
        return _masked_mean(_is_bg_like(y1), mse)

    def jpt_loss(y1, y2):
        # Squared relative change of column 0 ('pt') of the calc output.
        mse = K.square((jet_before[:, 0] - jet_after[:, 0]) / jet_before[:, 0])
        return _masked_mean(_is_bg_like(y1), mse)

    def jmass_loss(y1, y2):
        # Squared absolute change of column 3 ('mass') of the calc output.
        mse = K.square(jet_before[:, 3] - jet_after[:, 3])
        return _masked_mean(_is_bg_like(y1), mse)

    def jmass_res(y1, y2):
        # NOTE(review): despite the name this uses column 0 ('pt'), not
        # column 3 ('mass'), scaled by 2 -- confirm this is intentional.
        pull = (jet_before[:, 0] - jet_after[:, 0]) / jet_before[:, 0]
        return K.std(tf.where(_is_bg_like(y1), 2 * pull, tf.zeros_like(pull)))

    # Loss weights are backend variables stored on the model.
    composite.lambda_adv = K.variable(1.0)
    composite.lambda_bg = K.variable(1.0)
    composite.lambda_jpt = K.variable(1.0)
    composite.lambda_jmass = K.variable(1.0)

    def loss(y1, y2):
        return (composite.lambda_adv * adv_loss(y1, y2)
                + composite.lambda_bg * bg_loss(y1, y2)
                + composite.lambda_jpt * jpt_loss(y1, y2)
                + composite.lambda_jmass * jmass_loss(y1, y2))

    composite.compile(
        optimizer='adam',
        loss=loss,
        metrics=[adv_loss, bg_loss, jpt_loss, jmass_loss, jmass_res])

    composite.adversary = adversary
    composite.calc = calc

    return composite