def _rs_loop(coords, tidx):
    '''
    Attraction/repulsion loss of vertices around the mean coordinate of
    each truth object.

    Arguments:
    - coords: vertex coordinates; indexed as coords[:, 0:1], so at least
      rank 2 (annotated V x C in the shape comments below)
    - tidx: truth index per vertex; entries < 0 are treated as noise and
      down-weighted in the repulsive term

    Returns a tuple (distloss + reploss, distloss, reploss), or
    (0., 0., 0.) when CreateMidx reports no objects.

    NOTE(review): the shape annotations (K objects, V vertices, V' padded
    vertices per object) are carried over from the original comments; the
    exact output layout of CreateMidx / SelectWithDefault is not visible
    here - confirm against oc_helper_ops.
    '''
    Msel, M_not, N_per_obj = CreateMidx(tidx, calc_m_not=True)  # N_per_obj: K x 1
    if N_per_obj is None:
        # no objects, discard
        return 0., 0., 0.

    N_per_obj = tf.cast(N_per_obj, dtype='float32')
    N_tot = tf.cast(tidx.shape[0], dtype='float32')
    K = tf.cast(Msel.shape[0], dtype='float32')

    # 1 for real entries, 0 for the padding SelectWithDefault fills in
    padmask_m = SelectWithDefault(Msel, tf.ones_like(coords[:, 0:1]), 0.)  # K x V' x 1
    coords_m = SelectWithDefault(Msel, coords, 0.)  # K x V' x C

    # mean coordinate per object
    av_coords_m = tf.reduce_sum(coords_m * padmask_m, axis=1)  # K x C
    av_coords_m = tf.math.divide_no_nan(av_coords_m, N_per_obj)  # K x C
    av_coords_m = tf.expand_dims(av_coords_m, axis=1)  # K x 1 x C

    # attractive part: log(e * d^2 + 1), averaged per object, then over objects
    distloss = tf.reduce_sum((av_coords_m - coords_m)**2, axis=2)
    distloss = tf.math.log(tf.math.exp(1.) * distloss + 1.) * padmask_m[:, :, 0]
    distloss = tf.math.divide_no_nan(tf.reduce_sum(distloss, axis=1),
                                     N_per_obj[:, 0])  # K
    distloss = tf.math.divide_no_nan(tf.reduce_sum(distloss), K)

    # repulsive part: exp(-d^2) between each object mean and the vertices
    # selected by M_not
    repdist = tf.expand_dims(coords, axis=0) - av_coords_m  # K x V x C
    repdist = tf.reduce_sum(repdist**2, axis=-1, keepdims=True)  # K x V x 1
    reploss = M_not * tf.exp(-repdist)  # K x V x 1

    # downweight noise: vertices with tidx < 0 contribute with weight 0.1
    noise_weight = 1. - 0.9 * tf.cast(tidx < 0, dtype='float32')
    reploss *= tf.expand_dims(noise_weight, axis=0)

    # divide_no_nan instead of '/': if one object contains every vertex,
    # N_tot - N_per_obj is 0 and a plain division would yield 0/0 = NaN
    reploss = tf.math.divide_no_nan(tf.reduce_sum(reploss, axis=1),
                                    N_tot - N_per_obj)  # K x 1
    reploss = tf.reduce_sum(reploss) / (K + 1e-3)

    return distloss + reploss, distloss, reploss
def oc_per_batch_element(
        beta,
        x,
        q_min,
        object_weights,  # V x 1 !!
        truth_idx,
        is_spectator,
        payload_loss,
        S_B=1.,
        payload_weight_function=None,  # receives betas as K x V x 1 as input, and a threshold val
        payload_weight_threshold=0.8,
        use_mean_x=0.,
        cont_beta_loss=False,
        prob_repulsion=False,
        phase_transition=False,
        phase_transition_double_weight=False,
        alt_potential_norm=False,
        cut_payload_beta_gradient=False,
        kalpha_damping_strength=0.):
    '''
    Object condensation loss terms for one batch element.

    all inputs V x X , where X can be 1

    Arguments:
    - beta: condensation strength per vertex; clipped to [0, 1 - 1e-4]
      before the atanh that builds the charge q
    - x: cluster space coordinates per vertex
    - q_min: minimum charge, added for non-spectator vertices
    - object_weights: per-vertex weights; the weight of the highest-beta
      vertex of an object scales that object's terms
    - truth_idx: truth object index per vertex; < 0 means noise
    - is_spectator: multiplies beta and q_min as (1 - is_spectator)
    - payload_loss: per-vertex payload loss terms (annotated V x P)
    - S_B: scale of the noise penalty
    - use_mean_x: if > 0, blends the q-weighted mean position into the
      position of the highest-beta vertex with this fraction
    - cut_payload_beta_gradient: stops gradients through the payload weights
    - kalpha_damping_strength: if > 0, fraction of the object centre
      position taken without gradient

    Only one configuration is implemented. alt_potential_norm,
    prob_repulsion and phase_transition must be True;
    phase_transition_double_weight and cont_beta_loss must be False;
    payload_weight_function must be None. Anything else raises ValueError.
    payload_weight_threshold is not read.

    Returns V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen
    (too_much_B_pen is a constant zero in this implementation).

    NOTE(review): shape annotations (K objects, V vertices, V-obj padded
    vertices per object) are carried over from the original comments;
    confirm against CreateMidx / SelectWithDefault in oc_helper_ops.
    '''
    if not alt_potential_norm:
        raise ValueError("not alt_potential_norm not implemented")
    if not prob_repulsion:
        raise ValueError("not prob_repulsion not implemented")
    if not phase_transition:
        raise ValueError("not phase_transition not implemented")
    if phase_transition_double_weight:
        raise ValueError("phase_transition_double_weight not implemented")
    if cont_beta_loss:
        raise ValueError("cont_beta_loss not implemented")
    if payload_weight_function is not None:
        raise ValueError("payload_weight_function not implemented")

    # set all spectators invalid here, everything scales with beta, so:
    beta_in = beta
    beta = tf.clip_by_value(beta, 0., 1. - 1e-4)
    beta = beta * (1. - is_spectator)
    qraw = tf.math.atanh(beta)**2
    q = qraw + q_min * (1. - is_spectator)  # V x 1

    N = tf.cast(beta.shape[0], dtype='float32')

    # 1 for noise vertices (truth_idx < 0), 0 otherwise. The previous
    # tf.where(truth_idx < 0, zeros, 1.) had the branches swapped, so the
    # noise penalty below acted on object hits instead of noise.
    is_noise = tf.cast(truth_idx < 0, dtype='float32')  # V x 1

    Msel, M_not, N_per_obj = CreateMidx(truth_idx, calc_m_not=True)
    N_per_obj = tf.cast(N_per_obj, dtype='float32')  # K x 1
    K = tf.cast(Msel.shape[0], dtype='float32')

    # per-object views; padmask_m is 1 for real entries, 0 for padding
    padmask_m = SelectWithDefault(Msel, tf.zeros_like(beta_in) + 1., 0)  # K x V-obj x 1
    x_m = SelectWithDefault(Msel, x, 0.)  # K x V-obj x C
    beta_m = SelectWithDefault(Msel, beta_in, 0.)  # K x V-obj x 1
    q_m = SelectWithDefault(Msel, q, 0.)  # K x V-obj x 1
    object_weights_m = SelectWithDefault(Msel, object_weights, 0.)

    # index of the highest-beta vertex of each object
    kalpha_m = tf.argmax(beta_m, axis=1)  # K x 1
    x_kalpha_m = tf.gather_nd(x_m, kalpha_m, batch_dims=1)  # K x C

    if use_mean_x > 0:
        x_kalpha_m_m = tf.reduce_sum(q_m * x_m * padmask_m, axis=1)  # K x C
        x_kalpha_m_m = tf.math.divide_no_nan(
            x_kalpha_m_m, tf.reduce_sum(q_m * padmask_m, axis=1) + 1e-9)
        x_kalpha_m = use_mean_x * x_kalpha_m_m + (1. - use_mean_x) * x_kalpha_m

    if kalpha_damping_strength > 0:
        x_kalpha_m = (kalpha_damping_strength * tf.stop_gradient(x_kalpha_m)
                      + (1. - kalpha_damping_strength) * x_kalpha_m)

    q_kalpha_m = tf.gather_nd(q_m, kalpha_m, batch_dims=1)  # K x 1
    beta_kalpha_m = tf.gather_nd(beta_m, kalpha_m, batch_dims=1)  # K x 1
    object_weights_kalpha_m = tf.gather_nd(object_weights_m, kalpha_m,
                                           batch_dims=1)  # K x 1

    # attractive potential: q_i * q_alpha * d^2, averaged per object and
    # then over objects
    distancesq_m = tf.reduce_sum((tf.expand_dims(x_kalpha_m, axis=1) - x_m)**2,
                                 axis=-1, keepdims=True)  # K x V-obj x 1
    V_att = q_m * tf.expand_dims(q_kalpha_m, axis=1) * distancesq_m  # K x V-obj x 1
    V_att = V_att * tf.expand_dims(object_weights_kalpha_m, axis=1)  # K x V-obj x 1
    V_att = tf.math.divide_no_nan(tf.reduce_sum(padmask_m * V_att, axis=1),
                                  N_per_obj + 1e-9)  # K x 1
    V_att = tf.math.divide_no_nan(tf.reduce_sum(V_att, axis=0), K + 1e-9)  # 1

    # now the bit that needs Mnot
    V_rep = tf.expand_dims(x_kalpha_m, axis=1)  # K x 1 x C
    V_rep = V_rep - tf.expand_dims(x, axis=0)  # K x V x C
    V_rep = tf.reduce_sum(V_rep**2, axis=-1, keepdims=True)  # K x V x 1
    V_rep = -2. * tf.math.log(1. - tf.math.exp(-V_rep / 2.) + 1e-5)
    V_rep *= M_not * tf.expand_dims(q, axis=0)  # K x V x 1
    V_rep = tf.reduce_sum(V_rep, axis=1)  # K x 1
    V_rep *= object_weights_kalpha_m * q_kalpha_m  # K x 1
    V_rep = tf.math.divide_no_nan(
        V_rep,
        tf.expand_dims(tf.expand_dims(N, axis=0), axis=0) - N_per_obj + 1e-9)  # K x 1
    V_rep = tf.math.divide_no_nan(tf.reduce_sum(V_rep, axis=0), K + 1e-9)  # 1

    ## beta terms
    B_pen = -tf.reduce_sum(padmask_m * 1. / (20. * distancesq_m + 1.), axis=1)  # K x 1
    B_pen += 1.  # remove self-interaction term (just for offset)
    B_pen *= object_weights_kalpha_m * beta_kalpha_m
    B_pen = tf.math.divide_no_nan(B_pen, N_per_obj + 1e-9)  # K x 1
    # now 'standard' 1-beta
    B_pen -= 0.2 * object_weights_kalpha_m * tf.math.sqrt(beta_kalpha_m + 1e-6)
    # another "-> 1, but slower" per object
    B_pen = tf.math.divide_no_nan(tf.reduce_sum(B_pen, axis=0), K + 1e-9)  # 1

    too_much_B_pen = tf.constant([0.], dtype='float32')

    # mean beta over noise vertices, scaled by S_B
    Noise_pen = S_B * tf.math.divide_no_nan(tf.reduce_sum(is_noise * beta_in),
                                            tf.reduce_sum(is_noise))

    # explicit payload weight function here, the old one was odd
    p_w = tf.math.atanh(
        padmask_m * tf.clip_by_value(beta_m, 1e-4, 1. - 1e-4))**2  # already zero-padded, K x V_perobj x 1
    # normalise to maximum; this + 1e-9 might be an issue POSSIBLE FIXME
    p_w = tf.math.divide_no_nan(
        p_w, tf.reduce_max(p_w, axis=1, keepdims=True) + 1e-9)
    if cut_payload_beta_gradient:
        p_w = tf.stop_gradient(p_w)

    payload_loss_m = p_w * SelectWithDefault(Msel, payload_loss, 0.)  # K x V_perobj x P
    payload_loss_m = tf.reduce_sum(payload_loss_m, axis=1)

    pll = tf.math.divide_no_nan(payload_loss_m, N_per_obj + 1e-9)  # K x P
    pll = tf.math.divide_no_nan(tf.reduce_sum(pll, axis=0), K + 1e-9)  # P

    return V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen
from oc_helper_ops import CreateMidx, SelectWithDefault
import tensorflow as tf

# Small consistency check of the helper ops on random input.
nvert = 20

# random truth indices; the '- 1' shifts the drawn values down by one so
# that -1 (noise) can occur
truth_idxs = tf.random.uniform(
    (nvert, 1), minval=0, maxval=6, dtype='int32', seed=0) - 1
features = tf.random.uniform((nvert, 1), seed=0)

# selection indices, M_not and the third CreateMidx output, kept under the
# names used by the rest of this script
selidx, mnot, cperunique = CreateMidx(truth_idxs, calc_m_not=True)

# per-object view of the features (padded with -1.) and the position of
# the maximum along axis 1
beta_m = SelectWithDefault(selidx, features, -1.)
kalpha_m = tf.argmax(beta_m, axis=1)

# now test the whole loss
from object_condensation import oc_per_batch_element, oc_per_batch_element_old
def oc_per_batch_element(
        beta,
        x,
        q_min,
        object_weights,  # V x 1 !!
        truth_idx,
        is_spectator,
        payload_loss,
        S_B=1.,
        distance_scale=None,
        payload_weight_function=None,  # receives betas as K x V x 1 as input, and a threshold val
        payload_weight_threshold=0.8,
        use_mean_x=0.,
        cont_beta_loss=False,
        prob_repulsion=False,
        phase_transition=False,
        phase_transition_double_weight=False,
        alt_potential_norm=False,
        payload_beta_gradient_damping_strength=0.,
        kalpha_damping_strength=0.,
        beta_gradient_damping=0.,
        soft_q_scaling=True,
        weight_by_q=False,
        repulsion_q_min=-1.,
        super_repulsion=False):
    '''
    Object condensation loss terms for one batch element.

    all inputs V x X , where X can be 1

    Arguments:
    - beta: condensation strength per vertex
    - x: cluster space coordinates per vertex
    - q_min: minimum charge, added for non-spectator vertices
    - object_weights: per-vertex weights; the weight of the highest-beta
      vertex of an object scales that object's terms
    - truth_idx: truth object index per vertex; < 0 means noise
    - is_spectator: multiplies beta and q_min as (1 - is_spectator)
    - payload_loss: per-vertex payload loss terms (annotated V x P)
    - S_B: scale of the noise penalty
    - distance_scale: per-vertex distance scale; the value at the
      highest-beta vertex of an object multiplies distances to that
      object. None (the default) is treated as a scale of 1 everywhere.
    - use_mean_x: if > 0, blends the q-weighted mean position into the
      position of the highest-beta vertex with this fraction
    - payload_beta_gradient_damping_strength, kalpha_damping_strength,
      beta_gradient_damping: if > 0, fraction of the respective quantity
      (payload weights, object centre position, beta) taken without gradient
    - soft_q_scaling: use atanh(beta / 1.002)**2 instead of atanh(beta)**2
      for the charge
    - weight_by_q: normalise potentials by summed charge instead of
      vertex counts
    - repulsion_q_min: if >= 0, replaces q_min in the repulsive term
    - super_repulsion: use 0.5 * (squared + L1 distance) as the distance
      measure of the repulsive term

    Only one configuration of the legacy switches is implemented.
    alt_potential_norm, prob_repulsion and phase_transition must be True;
    phase_transition_double_weight and cont_beta_loss must be False;
    payload_weight_function must be None. Anything else raises ValueError.
    payload_weight_threshold is not read.

    Returns V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen
    (too_much_B_pen is a constant zero in this implementation).

    NOTE(review): shape annotations (K objects, V vertices, V-obj padded
    vertices per object) are carried over from the original comments;
    confirm against CreateMidx / SelectWithDefault in oc_helper_ops.
    '''
    if not alt_potential_norm:
        raise ValueError("not alt_potential_norm not implemented")
    if not prob_repulsion:
        raise ValueError("not prob_repulsion not implemented")
    if not phase_transition:
        raise ValueError("not phase_transition not implemented")
    if phase_transition_double_weight:
        raise ValueError("phase_transition_double_weight not implemented")
    if cont_beta_loss:
        raise ValueError("cont_beta_loss not implemented")
    if payload_weight_function is not None:
        raise ValueError("payload_weight_function not implemented")

    if beta_gradient_damping > 0.:
        beta = (beta_gradient_damping * tf.stop_gradient(beta)
                + (1. - beta_gradient_damping) * beta)

    # set all spectators invalid here, everything scales with beta, so:
    beta_in = beta
    beta = tf.clip_by_value(beta, 0., 1. - 1e-4)
    beta = beta * (1. - is_spectator)

    if soft_q_scaling:
        qraw = tf.math.atanh(beta / 1.002)**2
    else:
        qraw = tf.math.atanh(beta)**2
    q = qraw + q_min * (1. - is_spectator)  # V x 1

    if distance_scale is None:
        # The signature advertises None as default, but None cannot be
        # gathered per object; fall back to a neutral scale of 1.
        distance_scale = tf.ones_like(beta_in)

    N = tf.cast(beta.shape[0], dtype='float32')
    # 1 for noise vertices (truth_idx < 0), 0 otherwise
    is_noise = tf.where(truth_idx < 0,
                        tf.zeros_like(truth_idx, dtype='float32') + 1.,
                        0.)  # V x 1

    Msel, M_not, N_per_obj = CreateMidx(truth_idx, calc_m_not=True)
    N_per_obj = tf.cast(N_per_obj, dtype='float32')  # K x 1
    K = tf.cast(Msel.shape[0], dtype='float32')

    # per-object views; padmask_m is 1 for real entries, 0 for padding
    padmask_m = SelectWithDefault(Msel, tf.zeros_like(beta_in) + 1., 0)  # K x V-obj x 1
    x_m = SelectWithDefault(Msel, x, 0.)  # K x V-obj x C
    beta_m = SelectWithDefault(Msel, beta_in, 0.)  # K x V-obj x 1
    q_m = SelectWithDefault(Msel, q, 0.)  # K x V-obj x 1
    object_weights_m = SelectWithDefault(Msel, object_weights, 0.)
    distance_scale_m = SelectWithDefault(Msel, distance_scale, 1.)

    # index of the highest-beta vertex of each object
    kalpha_m = tf.argmax(beta_m, axis=1)  # K x 1
    x_kalpha_m = tf.gather_nd(x_m, kalpha_m, batch_dims=1)  # K x C

    if use_mean_x > 0:
        x_kalpha_m_m = tf.reduce_sum(q_m * x_m * padmask_m, axis=1)  # K x C
        x_kalpha_m_m = tf.math.divide_no_nan(
            x_kalpha_m_m, tf.reduce_sum(q_m * padmask_m, axis=1) + 1e-9)
        x_kalpha_m = use_mean_x * x_kalpha_m_m + (1. - use_mean_x) * x_kalpha_m

    if kalpha_damping_strength > 0:
        x_kalpha_m = (kalpha_damping_strength * tf.stop_gradient(x_kalpha_m)
                      + (1. - kalpha_damping_strength) * x_kalpha_m)

    q_kalpha_m = tf.gather_nd(q_m, kalpha_m, batch_dims=1)  # K x 1
    beta_kalpha_m = tf.gather_nd(beta_m, kalpha_m, batch_dims=1)  # K x 1
    object_weights_kalpha_m = tf.gather_nd(object_weights_m, kalpha_m,
                                           batch_dims=1)  # K x 1
    distance_scale_kalpha_m = tf.gather_nd(distance_scale_m, kalpha_m,
                                           batch_dims=1)  # K x 1
    distance_scale_kalpha_m_exp = tf.expand_dims(distance_scale_kalpha_m,
                                                 axis=2)  # K x 1 x 1

    # attractive potential on scaled distances, passed through huber(...)
    distancesq_m = tf.reduce_sum((tf.expand_dims(x_kalpha_m, axis=1) - x_m)**2,
                                 axis=-1, keepdims=True)  # K x V-obj x 1
    distancesq_m *= distance_scale_kalpha_m_exp**2
    huberdistsq = huber(tf.sqrt(distancesq_m + 1e-5), d=4)  # acts at 4

    V_att = q_m * tf.expand_dims(q_kalpha_m, axis=1) * huberdistsq  # K x V-obj x 1
    V_att = V_att * tf.expand_dims(object_weights_kalpha_m, axis=1)  # K x V-obj x 1
    if weight_by_q:
        V_att = tf.math.divide_no_nan(tf.reduce_sum(padmask_m * V_att, axis=1),
                                      tf.reduce_sum(q_m, axis=1))  # K x 1
    else:
        V_att = tf.math.divide_no_nan(tf.reduce_sum(padmask_m * V_att, axis=1),
                                      N_per_obj + 1e-9)  # K x 1
    V_att = tf.math.divide_no_nan(tf.reduce_sum(V_att, axis=0), K + 1e-9)  # 1

    # what if Vatt and Vrep are weighted by q, not scaled by it?
    q_rep = q
    if repulsion_q_min >= 0:
        # swap q_min for repulsion_q_min in both charges of the repulsive term
        q_rep = qraw + repulsion_q_min
        q_kalpha_m += repulsion_q_min - q_min

    # now the bit that needs Mnot
    Mnot_distances = tf.expand_dims(x_kalpha_m, axis=1)  # K x 1 x C
    Mnot_distances = Mnot_distances - tf.expand_dims(x, axis=0)  # K x V x C

    if super_repulsion:
        sq_distance = tf.reduce_sum(Mnot_distances**2, axis=-1,
                                    keepdims=True)  # K x V x 1
        l_distance = tf.reduce_sum(tf.abs(Mnot_distances), axis=-1,
                                   keepdims=True)  # K x V x 1
        V_rep = 0.5 * (sq_distance + l_distance)
    else:
        V_rep = tf.reduce_sum(Mnot_distances**2, axis=-1,
                              keepdims=True)  # K x V x 1

    V_rep *= distance_scale_kalpha_m_exp**2  # K x V x 1, same scaling as attractive potential
    V_rep = 1. / (V_rep + 0.1)
    V_rep *= M_not * tf.expand_dims(q_rep, axis=0)  # K x V x 1
    V_rep = tf.reduce_sum(V_rep, axis=1)  # K x 1
    V_rep *= object_weights_kalpha_m * q_kalpha_m  # K x 1

    if weight_by_q:
        sumq = tf.reduce_sum(M_not * tf.expand_dims(q_rep, axis=0), axis=1)
        V_rep = tf.math.divide_no_nan(V_rep, sumq)  # K x 1
    else:
        V_rep = tf.math.divide_no_nan(
            V_rep,
            tf.expand_dims(tf.expand_dims(N, axis=0), axis=0) - N_per_obj + 1e-9)  # K x 1
    V_rep = tf.math.divide_no_nan(tf.reduce_sum(V_rep, axis=0), K + 1e-9)  # 1

    ## beta terms
    B_pen = -tf.reduce_sum(padmask_m * 1. / (20. * distancesq_m + 1.), axis=1)  # K x 1
    B_pen += 1.  # remove self-interaction term (just for offset)
    B_pen *= object_weights_kalpha_m * beta_kalpha_m
    B_pen = tf.math.divide_no_nan(B_pen, N_per_obj + 1e-9)  # K x 1
    # now 'standard' 1-beta
    B_pen -= 0.2 * object_weights_kalpha_m * tf.math.log(beta_kalpha_m + 1e-9)
    # another "-> 1, but slower" per object
    B_pen = tf.math.divide_no_nan(tf.reduce_sum(B_pen, axis=0), K + 1e-9)  # 1

    too_much_B_pen = tf.constant([0.], dtype='float32')

    # mean beta over noise vertices, scaled by S_B
    Noise_pen = S_B * tf.math.divide_no_nan(tf.reduce_sum(is_noise * beta_in),
                                            tf.reduce_sum(is_noise))

    # explicit payload weight function here, the old one was odd
    # too aggressive scaling is bad for high learning rates; the weight is
    # beta**2 clipped to [1e-3, 10.]
    p_w = padmask_m * tf.clip_by_value(beta_m**2, 1e-3, 10.)  # already zero-padded, K x V_perobj x 1

    if payload_beta_gradient_damping_strength > 0:
        p_w = (payload_beta_gradient_damping_strength * tf.stop_gradient(p_w)
               + (1. - payload_beta_gradient_damping_strength) * p_w)

    # noise vertices carry no payload loss
    payload_loss_m = p_w * SelectWithDefault(
        Msel, (1. - is_noise) * payload_loss, 0.)  # K x V_perobj x P
    payload_loss_m = object_weights_kalpha_m * tf.reduce_sum(payload_loss_m, axis=1)
    # normalise by the summed payload weights of each object
    payload_loss_m = tf.math.divide_no_nan(payload_loss_m,
                                           tf.reduce_sum(p_w, axis=1))

    pll = tf.math.divide_no_nan(tf.reduce_sum(payload_loss_m, axis=0),
                                K + 1e-3)  # P

    return V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen
def oc_per_batch_element(
        beta,
        x,
        q_min,
        object_weights,  # V x 1 !!
        truth_idx,
        is_spectator,
        payload_loss,
        S_B=1.,
        noise_q_min=None,
        distance_scale=None,
        payload_weight_function=None,  # receives betas as K x V x 1 as input, and a threshold val
        payload_weight_threshold=0.8,
        use_mean_x=0.,
        cont_beta_loss=False,
        prob_repulsion=False,
        phase_transition=False,
        phase_transition_double_weight=False,
        payload_beta_gradient_damping_strength=0.,
        kalpha_damping_strength=0.,
        beta_gradient_damping=0.,
        soft_q_scaling=True,
        weight_by_q=False,
        repulsion_q_min=-1.,
        super_repulsion=False,
        super_attraction=False,
        div_repulsion=False,
        soft_att=True,
        dynamic_payload_scaling_onset=-0.03):
    '''
    Object condensation loss terms for one batch element.

    all inputs V x X , where X can be 1

    Arguments:
    - beta: condensation strength per vertex; asserted to be >= 0
    - x: cluster space coordinates per vertex
    - q_min: minimum charge; multiplied by (1 - is_spectator)
    - object_weights: per-vertex weights; the weight of the selected
      (highest-beta, non-spectator) vertex of an object scales that
      object's terms
    - truth_idx: truth object index per vertex; < 0 means noise
    - is_spectator: spectator flag per vertex; asserted to be >= 0
    - payload_loss: per-vertex payload loss terms (annotated V x P)
    - S_B: scale of the noise penalty
    - noise_q_min: if not None, replaces q_min for noise vertices
    - distance_scale: per-vertex distance scale; a beta-weighted mean per
      object divides the distances to that object. None (the default) is
      treated as a scale of 1 everywhere.
    - use_mean_x: if > 0, blends the (beta * q)-weighted mean position
      into the position of the selected vertex with this fraction
    - cont_beta_loss: replace the beta penalty by the continuous
      bpenhelp-based term
    - phase_transition: selects between the two beta penalty variants
    - payload_beta_gradient_damping_strength, kalpha_damping_strength,
      beta_gradient_damping: if > 0, fraction of the respective quantity
      (payload weights, object centre position, beta) taken without gradient
    - soft_q_scaling: use atanh(beta / 1.002)**2 on the unclipped beta
      instead of atanh(clipped beta)**2 for the charge
    - weight_by_q: normalise potentials by summed charge instead of
      vertex counts
    - super_repulsion / super_attraction: add a steep short-range
      exponential term to the repulsive / Huber-attractive potential
    - div_repulsion: use 1 / (d^2 + 0.1) as repulsive potential
    - soft_att: use q * log(e * d^2 + 1) as attractive potential
    - dynamic_payload_scaling_onset: if > 0, down-weights the payload loss
      of objects with large (gradient-stopped) potential

    Raises ValueError if prob_repulsion is set, if
    phase_transition_double_weight is set, if payload_weight_function is
    not None, or if repulsion_q_min >= 0 (not implemented with
    spectators). payload_weight_threshold is not read.

    Returns V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen. If
    CreateMidx reports no objects, a warning is printed and zeros are
    returned for all six.

    NOTE(review): shape annotations (K objects, V vertices, V-obj padded
    vertices per object) are carried over from the original comments;
    confirm against CreateMidx / SelectWithDefault in oc_helper_ops.
    '''
    tf.assert_equal(True, is_spectator >= 0.)
    tf.assert_equal(True, beta >= 0.)

    if prob_repulsion:
        raise ValueError("prob_repulsion not implemented")
    if phase_transition_double_weight:
        raise ValueError("phase_transition_double_weight not implemented")
    if payload_weight_function is not None:
        raise ValueError("payload_weight_function not implemented")
    if repulsion_q_min >= 0:
        # checked up front: nothing after this point handles the case
        raise ValueError("repulsion_q_min >= 0: spectators TBI")

    if beta_gradient_damping > 0.:
        beta = (beta_gradient_damping * tf.stop_gradient(beta)
                + (1. - beta_gradient_damping) * beta)

    # set all spectators invalid here, everything scales with beta, so:
    beta_in = beta
    beta = tf.clip_by_value(beta, 0., 1. - 1e-4)

    # plain assignment instead of '*=' so that an array passed in by the
    # caller is never modified in place
    q_min = q_min * (1. - is_spectator)

    if soft_q_scaling:
        qraw = tf.math.atanh(beta_in / 1.002)**2
    else:
        qraw = tf.math.atanh(beta)**2

    # 1 for noise vertices (truth_idx < 0), 0 otherwise
    is_noise = tf.where(truth_idx < 0,
                        tf.zeros_like(truth_idx, dtype='float32') + 1.,
                        0.)  # V x 1
    if noise_q_min is not None:
        q_min = (1. - is_noise) * q_min + is_noise * noise_q_min
    # just safety in case there are some numerical effects
    q_min = tf.where(q_min < 0, 0., q_min)

    q = qraw + q_min  # V x 1

    N = tf.cast(beta.shape[0], dtype='float32')

    Msel, M_not, N_per_obj = CreateMidx(truth_idx, calc_m_not=True)
    # use eager here
    if Msel is None:
        # V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen
        print(
            '>>> WARNING: Event has no objects, only noise! Will return zero loss. <<<'
        )
        zero_tensor = tf.reduce_mean(q, axis=0) * 0.
        zero_payload = tf.reduce_mean(payload_loss, axis=0) * 0.
        return (zero_tensor, zero_tensor, zero_tensor, zero_tensor,
                zero_payload, zero_tensor)

    N_per_obj = tf.cast(N_per_obj, dtype='float32')  # K x 1
    K = tf.cast(Msel.shape[0], dtype='float32')

    ########################################################
    # sanity check, use none of the following for the loss calculation
    truth_m = SelectWithDefault(Msel, truth_idx, -2)  # K x V-obj x 1
    truth_same = truth_m[:, 0:1] == truth_m
    truth_same = tf.where(truth_m == -2, True, truth_same)
    tf.assert_equal(
        tf.reduce_all(truth_same),
        True,
        message="truth indices do not match object selection, serious bug")
    # end sanity check
    ########################################################

    if distance_scale is None:
        # The signature advertises None as default, but the arithmetic
        # below needs a tensor; fall back to a neutral scale of 1.
        distance_scale = tf.ones_like(beta_in)
    # plain assignment instead of '+=' so that an array passed in by the
    # caller is never modified in place
    distance_scale = distance_scale + 1e-3

    # per-object views; padmask_m is 1 for real entries, 0 for padding
    padmask_m = SelectWithDefault(Msel, tf.zeros_like(beta_in) + 1., 0.)  # K x V-obj x 1
    x_m = SelectWithDefault(Msel, x, 0.)  # K x V-obj x C
    beta_m = SelectWithDefault(Msel, beta, 0.)  # K x V-obj x 1
    is_spectator_m = SelectWithDefault(Msel, is_spectator, 0.)  # K x V-obj x 1
    q_m = SelectWithDefault(Msel, q, 0.)  # K x V-obj x 1
    object_weights_m = SelectWithDefault(Msel, object_weights, 0.)
    distance_scale_m = SelectWithDefault(Msel, distance_scale, 1.)

    tf.assert_greater(distance_scale_m, 0.,
                      message="predicted distances must be greater zero")

    # index of the highest-beta vertex per object, with beta weighted by
    # (1 - is_spectator)
    kalpha_m = tf.argmax((1. - is_spectator_m) * beta_m, axis=1)  # K x 1
    x_kalpha_m = tf.gather_nd(x_m, kalpha_m, batch_dims=1)  # K x C

    if use_mean_x > 0:
        x_kalpha_m_m = tf.reduce_sum(beta_m * q_m * x_m * padmask_m, axis=1)  # K x C
        x_kalpha_m_m = tf.math.divide_no_nan(
            x_kalpha_m_m,
            tf.reduce_sum(beta_m * q_m * padmask_m, axis=1) + 1e-9)
        x_kalpha_m = use_mean_x * x_kalpha_m_m + (1. - use_mean_x) * x_kalpha_m

    if kalpha_damping_strength > 0:
        x_kalpha_m = (kalpha_damping_strength * tf.stop_gradient(x_kalpha_m)
                      + (1. - kalpha_damping_strength) * x_kalpha_m)

    q_kalpha_m = tf.gather_nd(q_m, kalpha_m, batch_dims=1)  # K x 1
    beta_kalpha_m = tf.gather_nd(beta_m, kalpha_m, batch_dims=1)  # K x 1
    object_weights_kalpha_m = tf.gather_nd(object_weights_m, kalpha_m,
                                           batch_dims=1)  # K x 1

    # make the distance scale a beta weighted mean so that there is more
    # than 1 impact per object
    distance_scale_kalpha_m = tf.math.divide_no_nan(
        tf.reduce_sum(distance_scale_m * beta_m * padmask_m, axis=1),
        tf.reduce_sum(beta_m * padmask_m, axis=1) + 1e-3) + 1e-3  # K x 1
    distance_scale_kalpha_m_exp = tf.expand_dims(distance_scale_kalpha_m,
                                                 axis=2)  # K x 1 x 1

    # squared distances to the object centre in units of 2 * scale^2
    distancesq_m = tf.reduce_sum((tf.expand_dims(x_kalpha_m, axis=1) - x_m)**2,
                                 axis=-1, keepdims=True)  # K x V-obj x 1
    distancesq_m = tf.math.divide_no_nan(
        distancesq_m, 2. * distance_scale_kalpha_m_exp**2 + 1e-6)

    if soft_att:
        # soft attraction does not use the charge of the object centre
        V_att = q_m * tf.math.log(tf.math.exp(1.) * distancesq_m + 1.)
    else:
        absdist = tf.sqrt(distancesq_m + 1e-6)
        huberdistsq = huber(absdist, d=4)  # acts at 4
        if super_attraction:
            huberdistsq += 1. - tf.math.exp(-100. * absdist)
        V_att = q_m * tf.expand_dims(q_kalpha_m, axis=1) * huberdistsq  # K x V-obj x 1

    V_att = V_att * tf.expand_dims(object_weights_kalpha_m, axis=1)  # K x V-obj x 1
    if weight_by_q:
        V_att = tf.math.divide_no_nan(tf.reduce_sum(padmask_m * V_att, axis=1),
                                      tf.reduce_sum(q_m, axis=1))  # K x 1
    else:
        V_att = tf.math.divide_no_nan(tf.reduce_sum(padmask_m * V_att, axis=1),
                                      N_per_obj + 1e-9)  # K x 1

    # opt. used later in payload loss
    V_att_K = V_att
    V_att = tf.math.divide_no_nan(tf.reduce_sum(V_att, axis=0), K + 1e-9)  # 1

    # what if Vatt and Vrep are weighted by q, not scaled by it?
    q_rep = q

    # now the bit that needs Mnot
    Mnot_distances = tf.expand_dims(x_kalpha_m, axis=1)  # K x 1 x C
    Mnot_distances = Mnot_distances - tf.expand_dims(x, axis=0)  # K x V x C

    rep_distances = tf.reduce_sum(Mnot_distances**2, axis=-1,
                                  keepdims=True)  # K x V x 1
    rep_distances = tf.math.divide_no_nan(
        rep_distances, 2. * distance_scale_kalpha_m_exp**2 + 1e-6)

    V_rep = tf.math.exp(-rep_distances)
    if super_repulsion:
        V_rep += 10. * tf.math.exp(-100. * tf.sqrt(rep_distances + 1e-6))
    if div_repulsion:
        # replaces (not adds to) the exponential potential
        V_rep = 1. / (rep_distances + 0.1)

    # spec weights are in q
    V_rep *= M_not * tf.expand_dims(q_rep, axis=0)  # K x V x 1
    V_rep = tf.reduce_sum(V_rep, axis=1)  # K x 1
    V_rep *= object_weights_kalpha_m * q_kalpha_m  # K x 1

    if weight_by_q:
        sumq = tf.reduce_sum(M_not * tf.expand_dims(q_rep, axis=0), axis=1)
        V_rep = tf.math.divide_no_nan(V_rep, sumq)  # K x 1
    else:
        V_rep = tf.math.divide_no_nan(
            V_rep,
            tf.expand_dims(tf.expand_dims(N, axis=0), axis=0) - N_per_obj + 1e-9)  # K x 1

    # opt used later in payload loss
    V_rep_K = V_rep
    V_rep = tf.math.divide_no_nan(tf.reduce_sum(V_rep, axis=0), K + 1e-9)  # 1

    def bpenhelp(b_m, exponent: int):
        # log((1 - ||b||_exponent)^2 + 1), norm taken along axis 1
        b_mes = tf.reduce_sum(b_m**exponent, axis=1)
        if not exponent == 1:
            b_mes = (b_mes + 1e-16)**(1. / float(exponent))
        return tf.math.log((1. - b_mes)**2 + 1. + 1e-8)

    if phase_transition:
        ## beta terms
        B_pen = -tf.reduce_sum(padmask_m * 1. / (20. * distancesq_m + 1.),
                               axis=1)  # K x 1
        B_pen += 1.  # remove self-interaction term (just for offset)
        B_pen *= object_weights_kalpha_m * beta_kalpha_m
        B_pen = tf.math.divide_no_nan(B_pen, N_per_obj + 1e-9)  # K x 1
        # now 'standard' 1-beta
        B_pen -= 0.2 * object_weights_kalpha_m * tf.math.log(beta_kalpha_m + 1e-9)
        # another "-> 1, but slower" per object
        B_pen = tf.math.divide_no_nan(tf.reduce_sum(B_pen, axis=0), K + 1e-9)  # 1
    else:
        B_pen_po = object_weights_kalpha_m * (1. - beta_kalpha_m)
        B_pen = tf.math.divide_no_nan(tf.reduce_sum(B_pen_po, axis=0),
                                      K + 1e-9)  # 1
        # get out of random gradients in the beginning
        # introduces gradients on all betas of hits rather than just the max one
        B_up = tf.math.divide_no_nan(
            tf.reduce_sum((1. - is_noise) * (1. - beta_in)),
            N - tf.reduce_sum(is_noise))
        B_pen += 0.01 * B_pen * B_up  # if it's high try to elevate all betas

    if cont_beta_loss:
        B_pen = bpenhelp(beta_m, 2) + bpenhelp(beta_m, 4)
        B_pen = tf.math.divide_no_nan(
            tf.reduce_sum(object_weights_kalpha_m * B_pen, axis=0), K + 1e-9)

    too_much_B_pen = object_weights_kalpha_m * bpenhelp(beta_m, 1)  # K x 1, don't make it steep
    too_much_B_pen = tf.math.divide_no_nan(tf.reduce_sum(too_much_B_pen),
                                           K + 1e-9)

    # mean beta over noise vertices, scaled by S_B
    Noise_pen = S_B * tf.math.divide_no_nan(tf.reduce_sum(is_noise * beta_in),
                                            tf.reduce_sum(is_noise) + 1e-3)

    # explicit payload weight function here, the old one was odd
    # too aggressive scaling is bad for high learning rates.
    p_w = padmask_m * tf.math.atanh(beta_m / 1.002)**2  # this is well behaved

    if payload_beta_gradient_damping_strength > 0:
        p_w = (payload_beta_gradient_damping_strength * tf.stop_gradient(p_w)
               + (1. - payload_beta_gradient_damping_strength) * p_w)

    # noise vertices carry no payload loss
    payload_loss_m = p_w * SelectWithDefault(
        Msel, (1. - is_noise) * payload_loss, 0.)  # K x V_perobj x P
    payload_loss_m = object_weights_kalpha_m * tf.reduce_sum(payload_loss_m,
                                                             axis=1)  # K x P

    # here normalisation per object
    payload_loss_m = tf.math.divide_no_nan(payload_loss_m,
                                           tf.reduce_sum(p_w, axis=1))

    if dynamic_payload_scaling_onset > 0:
        # stop gradient
        V_scaler = tf.stop_gradient(V_rep_K + V_att_K)  # K x 1
        # max of V_scaler is around 1 given the potentials; the factor
        # halves every (dynamic_payload_scaling_onset / 5) of potential
        scaling = tf.exp(-tf.math.log(2.) * V_scaler /
                         (dynamic_payload_scaling_onset / 5.))
        payload_loss_m *= scaling  # basically the onset of the rise

    pll = tf.math.divide_no_nan(tf.reduce_sum(payload_loss_m, axis=0),
                                K + 1e-3)  # P

    return V_att, V_rep, Noise_pen, B_pen, pll, too_much_B_pen