コード例 #1
0
ファイル: LossLayers.py プロジェクト: cms-pepr/HGCalML
    def _rs_loop(coords, tidx):
        '''
        Attractive plus repulsive coordinate loss for one set of vertices.

        coords: per-vertex coordinates (V x C according to the shape
                comments below)
        tidx:   per-vertex truth index; entries < 0 are treated as noise and
                down-weighted in the repulsive term
                (presumably V x 1 -- TODO confirm against caller)

        Returns the tuple (distloss + reploss, distloss, reploss).
        Returns (0., 0., 0.) if CreateMidx reports no objects.
        '''
        Msel, M_not, N_per_obj = CreateMidx(tidx,
                                            calc_m_not=True)  #N_per_obj: K x 1
        if N_per_obj is None:
            return 0., 0., 0.  #no objects, discard
        N_per_obj = tf.cast(N_per_obj, dtype='float32')
        N_tot = tf.cast(tidx.shape[0], dtype='float32')
        K = tf.cast(Msel.shape[0], dtype='float32')

        #1 for real entries, 0 for padding
        padmask_m = SelectWithDefault(Msel, tf.ones_like(coords[:, 0:1]),
                                      0.)  # K x V' x 1
        coords_m = SelectWithDefault(Msel, coords, 0.)  # K x V' x C

        #per-object average coordinate
        av_coords_m = tf.reduce_sum(coords_m * padmask_m, axis=1)  # K x C
        av_coords_m = tf.math.divide_no_nan(av_coords_m, N_per_obj)  #K x C
        av_coords_m = tf.expand_dims(av_coords_m, axis=1)  #K x 1 x C

        #attractive part: log(e * d^2 + 1) to the object average,
        #averaged over the vertices of each object and then over objects
        distloss = tf.reduce_sum((av_coords_m - coords_m)**2, axis=2)
        distloss = tf.math.log(tf.math.exp(1.) * distloss +
                               1.) * padmask_m[:, :, 0]
        distloss = tf.math.divide_no_nan(tf.reduce_sum(distloss, axis=1),
                                         N_per_obj[:, 0])  #K
        distloss = tf.math.divide_no_nan(tf.reduce_sum(distloss), K)

        #repulsive part: exp(-d^2) between each object average and all
        #vertices that do not belong to that object
        repdist = tf.expand_dims(coords, axis=0) - av_coords_m  #K x V x C
        repdist = tf.reduce_sum(repdist**2, axis=-1, keepdims=True)  #K x V x 1
        reploss = M_not * tf.exp(-repdist)  #K x V x 1
        #downweight noise
        reploss *= tf.expand_dims(
            (1. - 0.9 * tf.cast(tidx < 0, dtype='float32')), axis=0)
        #divide_no_nan: if one object owns every vertex, N_tot - N_per_obj is
        #zero and a plain division would turn the whole loss into NaN
        reploss = tf.math.divide_no_nan(tf.reduce_sum(reploss, axis=1),
                                        N_tot - N_per_obj)  #K x 1
        reploss = tf.reduce_sum(reploss) / (K + 1e-3)

        return distloss + reploss, distloss, reploss
コード例 #2
0
def oc_per_batch_element(
        beta,
        x,
        q_min,
        object_weights, # V x 1 !!
        truth_idx,
        is_spectator,
        payload_loss,
        S_B=1.,
        payload_weight_function = None,  #receives betas as K x V x 1 as input, and a threshold val
        payload_weight_threshold = 0.8,
        use_mean_x = 0.,
        cont_beta_loss=False,
        prob_repulsion=False,
        phase_transition=False,
        phase_transition_double_weight=False,
        alt_potential_norm=False,
        cut_payload_beta_gradient=False,
        kalpha_damping_strength=0.
        ):
    '''
    Object condensation loss terms for one batch element.

    all inputs
    V x X , where X can be 1

    beta:            condensation strength per vertex; clipped to
                     [0, 1 - 1e-4] before the charge q is computed
    x:               clustering-space coordinates per vertex
    q_min:           minimum charge added to atanh(beta)**2 (zero for spectators)
    object_weights:  per-vertex weight; the value at the highest-beta vertex
                     of each object weights that object's terms
    truth_idx:       truth object index per vertex; entries < 0 are noise
    is_spectator:    1 for spectator vertices (their beta is zeroed), else 0
    payload_loss:    per-vertex payload loss, V x P
    S_B:             scale of the noise penalty
    use_mean_x:      mixing fraction between the charge-weighted mean
                     coordinate and the highest-beta vertex coordinate
    cut_payload_beta_gradient: stop gradients through the payload weights
    kalpha_damping_strength:   fraction of the object-centre coordinate
                     gradient that is cut
    payload_weight_threshold is accepted but not used here.

    Only alt_potential_norm=True, prob_repulsion=True, phase_transition=True
    with the remaining switches at their defaults is implemented; any other
    combination raises ValueError.

    Returns (V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen).
    '''

    if not alt_potential_norm:
        raise ValueError("not alt_potential_norm not implemented")
    if not prob_repulsion:
        raise ValueError("not prob_repulsion not implemented")
    if not phase_transition:
        raise ValueError("not phase_transition not implemented")
    if phase_transition_double_weight:
        raise ValueError("phase_transition_double_weight not implemented")
    if cont_beta_loss:
        raise ValueError("cont_beta_loss not implemented")
    if payload_weight_function is not None:
        raise ValueError("payload_weight_function not implemented")

    #set all spectators invalid here, everything scales with beta, so:
    beta_in = beta
    beta = tf.clip_by_value(beta, 0., 1. - 1e-4)
    beta *= (1. - is_spectator)
    qraw = tf.math.atanh(beta)**2
    q = qraw + q_min * (1. - is_spectator) # V x 1

    N = tf.cast(beta.shape[0], dtype='float32')
    #1 for noise vertices (truth_idx < 0), 0 for vertices that belong to an
    #object; the mask must select noise so that Noise_pen averages noise betas
    is_noise = tf.where(truth_idx < 0,
                        tf.zeros_like(truth_idx, dtype='float32') + 1.,
                        0.) #V x 1

    Msel, M_not, N_per_obj = CreateMidx(truth_idx, calc_m_not=True)

    N_per_obj = tf.cast(N_per_obj, dtype='float32') # K x 1

    K = tf.cast(Msel.shape[0], dtype='float32')

    #per-object (padded) views of the per-vertex quantities
    padmask_m = SelectWithDefault(Msel, tf.zeros_like(beta_in)+1., 0) #K x V-obj x 1
    x_m = SelectWithDefault(Msel, x, 0.) #K x V-obj x C
    beta_m = SelectWithDefault(Msel, beta_in, 0.) #K x V-obj x 1
    q_m = SelectWithDefault(Msel, q, 0.) #K x V-obj x 1
    object_weights_m = SelectWithDefault(Msel, object_weights, 0.)

    #index of the highest-beta vertex of each object
    kalpha_m = tf.argmax(beta_m, axis=1) # K x 1

    x_kalpha_m = tf.gather_nd(x_m, kalpha_m, batch_dims=1) # K x C
    if use_mean_x > 0:
        x_kalpha_m_m = tf.reduce_sum(q_m * x_m * padmask_m, axis=1) # K x C
        x_kalpha_m_m = tf.math.divide_no_nan(
            x_kalpha_m_m, tf.reduce_sum(q_m * padmask_m, axis=1) + 1e-9)
        x_kalpha_m = use_mean_x * x_kalpha_m_m + (1. - use_mean_x) * x_kalpha_m

    if kalpha_damping_strength > 0:
        x_kalpha_m = kalpha_damping_strength * tf.stop_gradient(x_kalpha_m) \
            + (1. - kalpha_damping_strength) * x_kalpha_m

    q_kalpha_m = tf.gather_nd(q_m, kalpha_m, batch_dims=1) # K x 1
    beta_kalpha_m = tf.gather_nd(beta_m, kalpha_m, batch_dims=1) # K x 1

    object_weights_kalpha_m = tf.gather_nd(object_weights_m, kalpha_m, batch_dims=1) # K x 1

    #attractive potential: q * q_alpha * distance^2 within each object
    distancesq_m = tf.reduce_sum(
        (tf.expand_dims(x_kalpha_m, axis=1) - x_m)**2,
        axis=-1, keepdims=True) #K x V-obj x 1
    V_att = q_m * tf.expand_dims(q_kalpha_m, axis=1) * distancesq_m #K x V-obj x 1
    V_att = V_att * tf.expand_dims(object_weights_kalpha_m, axis=1) #K x V-obj x 1

    V_att = tf.math.divide_no_nan(tf.reduce_sum(padmask_m * V_att, axis=1), N_per_obj+1e-9) # K x 1
    V_att = tf.math.divide_no_nan(tf.reduce_sum(V_att, axis=0), K+1e-9) # 1

    #now the bit that needs Mnot
    V_rep = tf.expand_dims(x_kalpha_m, axis=1) #K x 1 x C
    V_rep = V_rep - tf.expand_dims(x, axis=0) #K x V x C
    V_rep = tf.reduce_sum(V_rep**2, axis=-1, keepdims=True) #K x V x 1

    V_rep = -2.*tf.math.log(1.-tf.math.exp(-V_rep/2.)+1e-5)
    V_rep *= M_not * tf.expand_dims(q, axis=0) #K x V x 1
    V_rep = tf.reduce_sum(V_rep, axis=1) #K x 1

    V_rep *= object_weights_kalpha_m * q_kalpha_m #K x 1

    V_rep = tf.math.divide_no_nan(
        V_rep,
        tf.expand_dims(tf.expand_dims(N, axis=0), axis=0) - N_per_obj+1e-9) # K x 1
    V_rep = tf.math.divide_no_nan(tf.reduce_sum(V_rep, axis=0), K+1e-9) # 1

    ## beta terms
    B_pen = - tf.reduce_sum(padmask_m * 1./(20.*distancesq_m + 1.), axis=1) # K x 1
    B_pen += 1. #remove self-interaction term (just for offset)
    B_pen *= object_weights_kalpha_m * beta_kalpha_m
    B_pen = tf.math.divide_no_nan(B_pen, N_per_obj+1e-9) # K x 1
    #now 'standard' 1-beta
    B_pen -= 0.2*object_weights_kalpha_m * tf.math.sqrt(beta_kalpha_m+1e-6)
    #another "-> 1, but slower" per object
    B_pen = tf.math.divide_no_nan(tf.reduce_sum(B_pen, axis=0), K+1e-9) # 1

    #placeholder so that the return signature stays fixed
    too_much_B_pen = tf.constant([0.], dtype='float32')

    #mean beta of the noise vertices
    Noise_pen = S_B*tf.math.divide_no_nan(tf.reduce_sum(is_noise * beta_in), tf.reduce_sum(is_noise))

    #explicit payload weight function here, the old one was odd
    p_w = tf.math.atanh(padmask_m * tf.clip_by_value(beta_m, 1e-4, 1.-1e-4))**2 #already zero-padded  , K x V_perobj x 1
    p_w = tf.math.divide_no_nan(p_w, tf.reduce_max(p_w, axis=1, keepdims=True)+1e-9)
    #normalise to maximum; this + 1e-9 might be an issue POSSIBLE FIXME

    if cut_payload_beta_gradient:
        p_w = tf.stop_gradient(p_w)

    payload_loss_m = p_w * SelectWithDefault(Msel, payload_loss, 0.) #K x V_perobj x P
    payload_loss_m = tf.reduce_sum(payload_loss_m, axis=1)

    pll = tf.math.divide_no_nan(payload_loss_m, N_per_obj+1e-9) # K x P
    pll = tf.math.divide_no_nan(tf.reduce_sum(pll, axis=0), K+1e-9) # P

    return V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen
コード例 #3
0
from oc_helper_ops import CreateMidx, SelectWithDefault
import tensorflow as tf

# Small stand-alone check of the CreateMidx / SelectWithDefault helper ops
# on random input.

# number of vertices in the toy event
nvert=20

# random truth indices: uniform integers in [0, 6) shifted down by one, so
# that the value -1 can stand in for noise
truth_idxs = tf.random.uniform((nvert,1), 0, 6, dtype='int32', seed=0) - 1 #for noise
# one random float feature per vertex; used below as a stand-in for beta
features =  tf.random.uniform((nvert,1),seed=0)


# selection indices, 'not in object' mask and a third output that the loss
# code in this project names N_per_obj
selidx,mnot,cperunique = CreateMidx(truth_idxs, calc_m_not=True)

#just a small consistency check



# debugging output, disabled
#print(truth_idxs)
#print(selidx)
#print(mnot)
#print(cperunique)

# gather the feature with the selection indices, padding missing entries with -1
beta_m  = SelectWithDefault(selidx, features, -1.)

# position of the largest feature value along axis 1
kalpha_m = tf.argmax(beta_m,axis=1) 
#print(beta_m, kalpha_m)

#print(tf.gather_nd(beta_m,kalpha_m, batch_dims=1))

#now test the whole loss

from object_condensation import oc_per_batch_element, oc_per_batch_element_old
コード例 #4
0
def oc_per_batch_element(
        beta,
        x,
        q_min,
        object_weights,  # V x 1 !!
        truth_idx,
        is_spectator,
        payload_loss,
        S_B=1.,
        distance_scale=None,
        payload_weight_function=None,  #receives betas as K x V x 1 as input, and a threshold val
        payload_weight_threshold=0.8,
        use_mean_x=0.,
        cont_beta_loss=False,
        prob_repulsion=False,
        phase_transition=False,
        phase_transition_double_weight=False,
        alt_potential_norm=False,
        payload_beta_gradient_damping_strength=0.,
        kalpha_damping_strength=0.,
        beta_gradient_damping=0.,
        soft_q_scaling=True,
        weight_by_q=False,
        repulsion_q_min=-1.,
        super_repulsion=False):
    '''
    Object condensation loss terms for one batch element.

    all inputs
    V x X , where X can be 1

    beta:            condensation strength per vertex
    x:               clustering-space coordinates per vertex
    q_min:           minimum charge added to the beta-derived charge
                     (zero for spectators)
    object_weights:  per-vertex weight; the value at the highest-beta vertex
                     of each object weights that object's terms
    truth_idx:       truth object index per vertex; entries < 0 are noise
    is_spectator:    1 for spectator vertices, else 0
    payload_loss:    per-vertex payload loss, V x P
    S_B:             scale of the noise penalty
    distance_scale:  per-vertex scale applied to distances; the value at the
                     highest-beta vertex of each object is used. None (the
                     default) means a unit scale.
    use_mean_x:      mixing fraction between the charge-weighted mean
                     coordinate and the highest-beta vertex coordinate
    payload_beta_gradient_damping_strength, kalpha_damping_strength,
    beta_gradient_damping:
                     fraction of the respective gradient that is cut
    soft_q_scaling:  use atanh(beta / 1.002)**2 as charge
    weight_by_q:     normalise the potentials by summed charge instead of by
                     vertex counts
    repulsion_q_min: if >= 0, minimum charge used in the repulsive term
                     instead of q_min
    super_repulsion: add the L1 distance to the squared distance in the
                     repulsive term
    payload_weight_threshold is accepted but not used here.

    Only alt_potential_norm=True, prob_repulsion=True, phase_transition=True
    with phase_transition_double_weight, cont_beta_loss and
    payload_weight_function left at their defaults is implemented; anything
    else raises ValueError.

    Returns (V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen).
    '''

    if not alt_potential_norm:
        raise ValueError("not alt_potential_norm not implemented")
    if not prob_repulsion:
        raise ValueError("not prob_repulsion not implemented")
    if not phase_transition:
        raise ValueError("not phase_transition not implemented")
    if phase_transition_double_weight:
        raise ValueError("phase_transition_double_weight not implemented")
    if cont_beta_loss:
        raise ValueError("cont_beta_loss not implemented")
    if payload_weight_function is not None:
        raise ValueError("payload_weight_function not implemented")

    #set all spectators invalid here, everything scales with beta, so:
    if beta_gradient_damping > 0.:
        beta = beta_gradient_damping * tf.stop_gradient(beta) + (
            1. - beta_gradient_damping) * beta
    beta_in = beta

    #the default distance_scale=None cannot be passed to the selection op;
    #interpret it as 'no scaling'
    if distance_scale is None:
        distance_scale = tf.ones_like(beta_in)

    beta = tf.clip_by_value(beta, 0., 1. - 1e-4)
    beta *= (1. - is_spectator)

    if soft_q_scaling:
        qraw = tf.math.atanh(beta / 1.002)**2
        beta = beta_in * (1. - is_spectator)  # no need for clipping
    else:
        qraw = tf.math.atanh(beta)**2

    q = qraw + q_min * (1. - is_spectator)  # V x 1

    N = tf.cast(beta.shape[0], dtype='float32')
    #1 for noise vertices (truth_idx < 0), else 0
    is_noise = tf.where(truth_idx < 0,
                        tf.zeros_like(truth_idx, dtype='float32') + 1.,
                        0.)  #V x 1

    Msel, M_not, N_per_obj = CreateMidx(truth_idx, calc_m_not=True)

    N_per_obj = tf.cast(N_per_obj, dtype='float32')  # K x 1

    K = tf.cast(Msel.shape[0], dtype='float32')

    #per-object (padded) views of the per-vertex quantities
    padmask_m = SelectWithDefault(Msel,
                                  tf.zeros_like(beta_in) + 1.,
                                  0)  #K x V-obj x 1
    x_m = SelectWithDefault(Msel, x, 0.)  #K x V-obj x C
    beta_m = SelectWithDefault(Msel, beta_in, 0.)  #K x V-obj x 1
    q_m = SelectWithDefault(Msel, q, 0.)  #K x V-obj x 1
    object_weights_m = SelectWithDefault(Msel, object_weights, 0.)
    distance_scale_m = SelectWithDefault(Msel, distance_scale, 1.)

    #index of the highest-beta vertex of each object
    kalpha_m = tf.argmax(beta_m, axis=1)  # K x 1

    x_kalpha_m = tf.gather_nd(x_m, kalpha_m, batch_dims=1)  # K x C
    if use_mean_x > 0:
        x_kalpha_m_m = tf.reduce_sum(q_m * x_m * padmask_m, axis=1)  # K x C
        x_kalpha_m_m = tf.math.divide_no_nan(
            x_kalpha_m_m,
            tf.reduce_sum(q_m * padmask_m, axis=1) + 1e-9)
        x_kalpha_m = use_mean_x * x_kalpha_m_m + (1. - use_mean_x) * x_kalpha_m

    if kalpha_damping_strength > 0:
        x_kalpha_m = kalpha_damping_strength * tf.stop_gradient(x_kalpha_m) + (
            1. - kalpha_damping_strength) * x_kalpha_m

    q_kalpha_m = tf.gather_nd(q_m, kalpha_m, batch_dims=1)  # K x 1
    beta_kalpha_m = tf.gather_nd(beta_m, kalpha_m, batch_dims=1)  # K x 1

    object_weights_kalpha_m = tf.gather_nd(object_weights_m,
                                           kalpha_m,
                                           batch_dims=1)  # K x 1
    distance_scale_kalpha_m = tf.gather_nd(distance_scale_m,
                                           kalpha_m,
                                           batch_dims=1)  # K x 1
    distance_scale_kalpha_m_exp = tf.expand_dims(distance_scale_kalpha_m,
                                                 axis=2)  # K x 1 x 1

    distancesq_m = tf.reduce_sum((tf.expand_dims(x_kalpha_m, axis=1) - x_m)**2,
                                 axis=-1,
                                 keepdims=True)  #K x V-obj x 1
    distancesq_m *= distance_scale_kalpha_m_exp**2

    #attractive potential on the huber-ised distance
    huberdistsq = huber(tf.sqrt(distancesq_m + 1e-5), d=4)  #acts at 4
    V_att = q_m * tf.expand_dims(q_kalpha_m,
                                 axis=1) * huberdistsq  #K x V-obj x 1
    V_att = V_att * tf.expand_dims(object_weights_kalpha_m,
                                   axis=1)  #K x V-obj x 1

    if weight_by_q:
        V_att = tf.math.divide_no_nan(tf.reduce_sum(padmask_m * V_att, axis=1),
                                      tf.reduce_sum(q_m, axis=1))  # K x 1
    else:
        V_att = tf.math.divide_no_nan(tf.reduce_sum(padmask_m * V_att, axis=1),
                                      N_per_obj + 1e-9)  # K x 1
    V_att = tf.math.divide_no_nan(tf.reduce_sum(V_att, axis=0), K + 1e-9)  # 1

    #what if Vatt and Vrep are weighted by q, not scaled by it?
    q_rep = q
    if repulsion_q_min >= 0:
        q_rep = qraw + repulsion_q_min
        q_kalpha_m += repulsion_q_min - q_min

    #now the bit that needs Mnot
    Mnot_distances = tf.expand_dims(x_kalpha_m, axis=1)  #K x 1 x C
    Mnot_distances = Mnot_distances - tf.expand_dims(x, axis=0)  #K x V x C

    if super_repulsion:
        sq_distance = tf.reduce_sum(Mnot_distances**2, axis=-1,
                                    keepdims=True)  #K x V x 1
        l_distance = tf.reduce_sum(tf.abs(Mnot_distances),
                                   axis=-1,
                                   keepdims=True)  #K x V x 1
        V_rep = 0.5 * (sq_distance + l_distance)

    else:
        V_rep = tf.reduce_sum(Mnot_distances**2, axis=-1,
                              keepdims=True)  #K x V x 1

    V_rep *= distance_scale_kalpha_m_exp**2  #K x V x 1 , same scaling as attractive potential

    #1/(d^2 + 0.1) repulsion
    V_rep = 1. / (V_rep + 0.1)

    V_rep *= M_not * tf.expand_dims(q_rep, axis=0)  #K x V x 1
    V_rep = tf.reduce_sum(V_rep, axis=1)  #K x 1

    V_rep *= object_weights_kalpha_m * q_kalpha_m  #K x 1

    if weight_by_q:
        sumq = tf.reduce_sum(M_not * tf.expand_dims(q_rep, axis=0), axis=1)
        V_rep = tf.math.divide_no_nan(V_rep, sumq)  # K x 1
    else:
        V_rep = tf.math.divide_no_nan(
            V_rep,
            tf.expand_dims(tf.expand_dims(N, axis=0), axis=0) - N_per_obj +
            1e-9)  # K x 1
    V_rep = tf.math.divide_no_nan(tf.reduce_sum(V_rep, axis=0), K + 1e-9)  # 1

    ## beta terms
    B_pen = -tf.reduce_sum(padmask_m * 1. /
                           (20. * distancesq_m + 1.), axis=1)  # K x 1
    B_pen += 1.  #remove self-interaction term (just for offset)
    B_pen *= object_weights_kalpha_m * beta_kalpha_m
    B_pen = tf.math.divide_no_nan(B_pen, N_per_obj + 1e-9)  # K x 1
    #now 'standard' 1-beta, here as -log(beta)
    B_pen -= 0.2 * object_weights_kalpha_m * (
        tf.math.log(beta_kalpha_m + 1e-9))
    #another "-> 1, but slower" per object
    B_pen = tf.math.divide_no_nan(tf.reduce_sum(B_pen, axis=0), K + 1e-9)  # 1

    #placeholder so that the return signature stays fixed
    too_much_B_pen = tf.constant([0.], dtype='float32')

    #mean beta of the noise vertices
    Noise_pen = S_B * tf.math.divide_no_nan(tf.reduce_sum(is_noise * beta_in),
                                            tf.reduce_sum(is_noise))

    #explicit payload weight function here, the old one was odd

    #too aggressive scaling is bad for high learning rates: plain beta**2,
    #clipped to [1e-3, 10]
    p_w = padmask_m * tf.clip_by_value(
        beta_m**2, 1e-3, 10.)  #already zero-padded  , K x V_perobj x 1

    if payload_beta_gradient_damping_strength > 0:
        p_w = payload_beta_gradient_damping_strength * tf.stop_gradient(p_w) + \
        (1.- payload_beta_gradient_damping_strength)* p_w

    payload_loss_m = p_w * SelectWithDefault(
        Msel, (1. - is_noise) * payload_loss, 0.)  #K x V_perobj x P
    payload_loss_m = object_weights_kalpha_m * tf.reduce_sum(payload_loss_m,
                                                             axis=1)
    #normalise per object by the summed payload weights
    payload_loss_m = tf.math.divide_no_nan(payload_loss_m,
                                           tf.reduce_sum(p_w, axis=1))

    pll = tf.math.divide_no_nan(tf.reduce_sum(payload_loss_m, axis=0),
                                K + 1e-3)  # P

    return V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen
コード例 #5
0
def oc_per_batch_element(
        beta,
        x,
        q_min,
        object_weights,  # V x 1 !!
        truth_idx,
        is_spectator,
        payload_loss,
        S_B=1.,
        noise_q_min=None,
        distance_scale=None,
        payload_weight_function=None,  #receives betas as K x V x 1 as input, and a threshold val
        payload_weight_threshold=0.8,
        use_mean_x=0.,
        cont_beta_loss=False,
        prob_repulsion=False,
        phase_transition=False,
        phase_transition_double_weight=False,
        payload_beta_gradient_damping_strength=0.,
        kalpha_damping_strength=0.,
        beta_gradient_damping=0.,
        soft_q_scaling=True,
        weight_by_q=False,
        repulsion_q_min=-1.,
        super_repulsion=False,
        super_attraction=False,
        div_repulsion=False,
        soft_att=True,
        dynamic_payload_scaling_onset=-0.03):
    '''
    Object condensation loss terms for one batch element.

    all inputs
    V x X , where X can be 1

    beta:            condensation strength per vertex, must be >= 0
    x:               clustering-space coordinates per vertex
    q_min:           minimum charge added to the beta-derived charge
                     (zero for spectators)
    object_weights:  per-vertex weight; the value at the highest-beta
                     non-spectator vertex of each object weights that
                     object's terms
    truth_idx:       truth object index per vertex; entries < 0 are noise
    is_spectator:    spectator weight per vertex, must be >= 0
    payload_loss:    per-vertex payload loss, V x P
    S_B:             scale of the noise penalty
    noise_q_min:     if not None, minimum charge used for noise vertices
    distance_scale:  per-vertex distance scale; a beta-weighted mean per
                     object divides the distances. None (the default) means
                     a unit scale.
    use_mean_x:      mixing fraction between the (beta * q)-weighted mean
                     coordinate and the highest-beta vertex coordinate
    cont_beta_loss:  replace the beta penalty by a term on the summed betas
                     of each object
    phase_transition: selects the beta penalty variant
    payload_beta_gradient_damping_strength, kalpha_damping_strength,
    beta_gradient_damping:
                     fraction of the respective gradient that is cut
    soft_q_scaling:  use atanh(beta / 1.002)**2 as charge
    weight_by_q:     normalise the potentials by summed charge instead of by
                     vertex counts
    super_repulsion, super_attraction:
                     add a steep short-range exponential term
    div_repulsion:   use 1 / (d^2 + 0.1) as repulsive potential
    soft_att:        use q * log(e * d^2 + 1) as attractive potential
    dynamic_payload_scaling_onset:
                     if > 0, scale the per-object payload loss down while
                     the potential terms of the object are large
    payload_weight_threshold is accepted but not used here.

    Raises ValueError for prob_repulsion, phase_transition_double_weight,
    a payload_weight_function, or repulsion_q_min >= 0 (not implemented).

    Returns (V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen); all
    entries are zero if the event contains no objects.
    '''
    tf.assert_equal(True, is_spectator >= 0.)
    tf.assert_equal(True, beta >= 0.)

    if prob_repulsion:
        raise ValueError("prob_repulsion not implemented")
    if phase_transition_double_weight:
        raise ValueError("phase_transition_double_weight not implemented")
    if payload_weight_function is not None:
        raise ValueError("payload_weight_function not implemented")

    #set all spectators invalid here, everything scales with beta, so:
    if beta_gradient_damping > 0.:
        beta = beta_gradient_damping * tf.stop_gradient(beta) + (
            1. - beta_gradient_damping) * beta
    beta_in = beta
    beta = tf.clip_by_value(beta, 0., 1. - 1e-4)

    #the default distance_scale=None cannot enter the arithmetic below;
    #interpret it as 'no scaling'
    if distance_scale is None:
        distance_scale = tf.ones_like(beta_in)

    #no augmented assignment: it would modify a caller's array in place
    q_min = q_min * (1. - is_spectator)

    if soft_q_scaling:
        qraw = tf.math.atanh(beta_in / 1.002)**2
    else:
        qraw = tf.math.atanh(beta)**2

    #1 for noise vertices (truth_idx < 0), else 0
    is_noise = tf.where(truth_idx < 0,
                        tf.zeros_like(truth_idx, dtype='float32') + 1.,
                        0.)  #V x 1
    if noise_q_min is not None:
        q_min = (1. - is_noise) * q_min + is_noise * noise_q_min

    q_min = tf.where(
        q_min < 0, 0.,
        q_min)  #just safety in case there are some numerical effects

    q = qraw + q_min  # V x 1

    N = tf.cast(beta.shape[0], dtype='float32')

    Msel, M_not, N_per_obj = CreateMidx(truth_idx, calc_m_not=True)
    #use eager here
    if Msel is None:
        #V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen
        print(
            '>>> WARNING: Event has no objects, only noise! Will return zero loss. <<<'
        )
        #multiply by zero instead of creating constants so that shape and
        #dtype follow the inputs
        zero_tensor = tf.reduce_mean(q, axis=0) * 0.
        zero_payload = tf.reduce_mean(payload_loss, axis=0) * 0.
        return zero_tensor, zero_tensor, zero_tensor, zero_tensor, zero_payload, zero_tensor

    N_per_obj = tf.cast(N_per_obj, dtype='float32')  # K x 1

    K = tf.cast(Msel.shape[0], dtype='float32')

    ########################################################
    #sanity check, use none of the following for the loss calculation
    truth_m = SelectWithDefault(Msel, truth_idx, -2)  #K x V-obj x 1
    truth_same = truth_m[:, 0:1] == truth_m
    truth_same = tf.where(truth_m == -2, True, truth_same)
    tf.assert_equal(
        tf.reduce_all(truth_same),
        True,
        message="truth indices do not match object selection, serious bug")
    #end sanity check
    ########################################################

    #per-object (padded) views of the per-vertex quantities
    padmask_m = SelectWithDefault(Msel,
                                  tf.zeros_like(beta_in) + 1.,
                                  0.)  #K x V-obj x 1
    x_m = SelectWithDefault(Msel, x, 0.)  #K x V-obj x C
    beta_m = SelectWithDefault(Msel, beta, 0.)  #K x V-obj x 1
    is_spectator_m = SelectWithDefault(Msel, is_spectator, 0.)  #K x V-obj x 1
    q_m = SelectWithDefault(Msel, q, 0.)  #K x V-obj x 1
    object_weights_m = SelectWithDefault(Msel, object_weights, 0.)

    #small offset keeps the scale strictly positive
    distance_scale = distance_scale + 1e-3
    distance_scale_m = SelectWithDefault(Msel, distance_scale, 1.)

    tf.assert_greater(distance_scale_m,
                      0.,
                      message="predicted distances must be greater zero")

    #index of the highest-beta vertex per object, spectators down-weighted
    kalpha_m = tf.argmax((1. - is_spectator_m) * beta_m, axis=1)  # K x 1

    x_kalpha_m = tf.gather_nd(x_m, kalpha_m, batch_dims=1)  # K x C
    if use_mean_x > 0:
        x_kalpha_m_m = tf.reduce_sum(beta_m * q_m * x_m * padmask_m,
                                     axis=1)  # K x C
        x_kalpha_m_m = tf.math.divide_no_nan(
            x_kalpha_m_m,
            tf.reduce_sum(beta_m * q_m * padmask_m, axis=1) + 1e-9)
        x_kalpha_m = use_mean_x * x_kalpha_m_m + (1. - use_mean_x) * x_kalpha_m

    if kalpha_damping_strength > 0:
        x_kalpha_m = kalpha_damping_strength * tf.stop_gradient(x_kalpha_m) + (
            1. - kalpha_damping_strength) * x_kalpha_m

    q_kalpha_m = tf.gather_nd(q_m, kalpha_m, batch_dims=1)  # K x 1
    beta_kalpha_m = tf.gather_nd(beta_m, kalpha_m, batch_dims=1)  # K x 1

    object_weights_kalpha_m = tf.gather_nd(object_weights_m,
                                           kalpha_m,
                                           batch_dims=1)  # K x 1

    #make the distance scale a beta weighted mean so that there is more than 1 impact per object
    distance_scale_kalpha_m = tf.math.divide_no_nan(
        tf.reduce_sum(distance_scale_m * beta_m * padmask_m, axis=1),
        tf.reduce_sum(beta_m * padmask_m, axis=1) + 1e-3) + 1e-3  #K x 1

    distance_scale_kalpha_m_exp = tf.expand_dims(distance_scale_kalpha_m,
                                                 axis=2)  # K x 1 x 1

    distancesq_m = tf.reduce_sum((tf.expand_dims(x_kalpha_m, axis=1) - x_m)**2,
                                 axis=-1,
                                 keepdims=True)  #K x V-obj x 1
    distancesq_m = tf.math.divide_no_nan(
        distancesq_m, 2. * distance_scale_kalpha_m_exp**2 + 1e-6)

    absdist = tf.sqrt(distancesq_m + 1e-6)
    huberdistsq = huber(absdist, d=4)  #acts at 4
    if super_attraction:
        huberdistsq += 1. - tf.math.exp(-100. * absdist)

    V_att = q_m * tf.expand_dims(q_kalpha_m,
                                 axis=1) * huberdistsq  #K x V-obj x 1

    if soft_att:
        V_att = q_m * tf.math.log(tf.math.exp(1.) * distancesq_m + 1.)

    V_att = V_att * tf.expand_dims(object_weights_kalpha_m,
                                   axis=1)  #K x V-obj x 1

    if weight_by_q:
        V_att = tf.math.divide_no_nan(tf.reduce_sum(padmask_m * V_att, axis=1),
                                      tf.reduce_sum(q_m, axis=1))  # K x 1
    else:
        V_att = tf.math.divide_no_nan(tf.reduce_sum(padmask_m * V_att, axis=1),
                                      N_per_obj + 1e-9)  # K x 1

    # opt. used later in payload loss
    V_att_K = V_att
    V_att = tf.math.divide_no_nan(tf.reduce_sum(V_att, axis=0), K + 1e-9)  # 1

    #what if Vatt and Vrep are weighted by q, not scaled by it?
    if repulsion_q_min >= 0:
        raise ValueError("repulsion_q_min >= 0: spectators TBI")
    q_rep = q

    #now the bit that needs Mnot
    Mnot_distances = tf.expand_dims(x_kalpha_m, axis=1)  #K x 1 x C
    Mnot_distances = Mnot_distances - tf.expand_dims(x, axis=0)  #K x V x C

    rep_distances = tf.reduce_sum(Mnot_distances**2, axis=-1,
                                  keepdims=True)  #K x V x 1

    rep_distances = tf.math.divide_no_nan(
        rep_distances, 2. * distance_scale_kalpha_m_exp**2 + 1e-6)

    #gaussian repulsion
    V_rep = tf.math.exp(-rep_distances)

    if super_repulsion:
        V_rep += 10. * tf.math.exp(-100. * tf.sqrt(rep_distances + 1e-6))

    if div_repulsion:
        V_rep = 1. / (rep_distances + 0.1)

    #spec weights are in q
    V_rep *= M_not * tf.expand_dims(q_rep, axis=0)  #K x V x 1
    V_rep = tf.reduce_sum(V_rep, axis=1)  #K x 1

    V_rep *= object_weights_kalpha_m * q_kalpha_m  #K x 1

    if weight_by_q:
        sumq = tf.reduce_sum(M_not * tf.expand_dims(q_rep, axis=0), axis=1)
        V_rep = tf.math.divide_no_nan(V_rep, sumq)  # K x 1
    else:
        V_rep = tf.math.divide_no_nan(
            V_rep,
            tf.expand_dims(tf.expand_dims(N, axis=0), axis=0) - N_per_obj +
            1e-9)  # K x 1
    # opt used later in payload loss
    V_rep_K = V_rep
    V_rep = tf.math.divide_no_nan(tf.reduce_sum(V_rep, axis=0), K + 1e-9)  # 1

    B_pen = None

    def bpenhelp(b_m, exponent: int):
        #log((1 - ||b||_exponent)^2 + 1) with the norm taken along axis 1
        b_mes = tf.reduce_sum(b_m**exponent, axis=1)
        if not exponent == 1:
            b_mes = (b_mes + 1e-16)**(1. / float(exponent))
        return tf.math.log((1. - b_mes)**2 + 1. + 1e-8)

    if phase_transition:
        ## beta terms
        B_pen = -tf.reduce_sum(padmask_m * 1. / (20. * distancesq_m + 1.),
                               axis=1)  # K x 1
        B_pen += 1.  #remove self-interaction term (just for offset)
        B_pen *= object_weights_kalpha_m * beta_kalpha_m
        B_pen = tf.math.divide_no_nan(B_pen, N_per_obj + 1e-9)  # K x 1
        #now 'standard' 1-beta, here as -log(beta)
        B_pen -= 0.2 * object_weights_kalpha_m * (
            tf.math.log(beta_kalpha_m + 1e-9))
        #another "-> 1, but slower" per object
        B_pen = tf.math.divide_no_nan(tf.reduce_sum(B_pen, axis=0),
                                      K + 1e-9)  # 1

    else:
        B_pen_po = object_weights_kalpha_m * (1. - beta_kalpha_m)
        B_pen = tf.math.divide_no_nan(tf.reduce_sum(B_pen_po, axis=0),
                                      K + 1e-9)  #1
        #get out of random gradients in the beginning
        #introduces gradients on all betas of hits rather than just the max one
        B_up = tf.math.divide_no_nan(
            tf.reduce_sum((1. - is_noise) * (1. - beta_in)),
            N - tf.reduce_sum(is_noise))
        B_pen += 0.01 * B_pen * B_up  #if it's high try to elevate all betas

    if cont_beta_loss:
        B_pen = bpenhelp(beta_m, 2) + bpenhelp(beta_m, 4)
        B_pen = tf.math.divide_no_nan(
            tf.reduce_sum(object_weights_kalpha_m * B_pen, axis=0), K + 1e-9)

    too_much_B_pen = object_weights_kalpha_m * bpenhelp(
        beta_m, 1)  #K x 1, don't make it steep
    too_much_B_pen = tf.math.divide_no_nan(tf.reduce_sum(too_much_B_pen),
                                           K + 1e-9)

    #mean beta of the noise vertices
    Noise_pen = S_B * tf.math.divide_no_nan(tf.reduce_sum(is_noise * beta_in),
                                            tf.reduce_sum(is_noise) + 1e-3)

    #explicit payload weight function here, the old one was odd

    #too aggressive scaling is bad for high learning rates.
    p_w = padmask_m * tf.math.atanh(beta_m / 1.002)**2  #this is well behaved

    if payload_beta_gradient_damping_strength > 0:
        p_w = payload_beta_gradient_damping_strength * tf.stop_gradient(p_w) + \
        (1.- payload_beta_gradient_damping_strength)* p_w

    payload_loss_m = p_w * SelectWithDefault(
        Msel, (1. - is_noise) * payload_loss, 0.)  #K x V_perobj x P
    payload_loss_m = object_weights_kalpha_m * tf.reduce_sum(payload_loss_m,
                                                             axis=1)  # K x P

    #here normalisation per object
    payload_loss_m = tf.math.divide_no_nan(payload_loss_m,
                                           tf.reduce_sum(p_w, axis=1))

    if dynamic_payload_scaling_onset > 0:
        #stop gradient
        V_scaler = tf.stop_gradient(V_rep_K + V_att_K)  # K x 1
        #max of V_scaler is around 1 given the potentials
        scaling = tf.exp(-tf.math.log(2.) * V_scaler /
                         (dynamic_payload_scaling_onset / 5.))
        payload_loss_m *= scaling  #basically the onset of the rise

    pll = tf.math.divide_no_nan(tf.reduce_sum(payload_loss_m, axis=0),
                                K + 1e-3)  # P

    return V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen