def pseudocount_near_zero(tensor):
    """Push entries of `tensor` that sit very close to zero away from it.

    Entries with magnitude below 0.5*NEAR_ZERO_THRESHOLD are bumped by
    NEAR_ZERO_THRESHOLD — upward when non-negative, downward when
    negative — so later divisions by the result cannot blow up. All
    other entries pass through unchanged.
    """
    near_zero = B.abs(tensor) < 0.5 * NEAR_ZERO_THRESHOLD
    # Non-negative near-zero entries get +NEAR_ZERO_THRESHOLD ...
    bump_up = NEAR_ZERO_THRESHOLD * (near_zero * (tensor >= 0))
    # ... and negative near-zero entries get -NEAR_ZERO_THRESHOLD.
    bump_down = NEAR_ZERO_THRESHOLD * (near_zero * (tensor < 0))
    return tensor + (bump_up - bump_down)
def _get_naive_rescale_factor(self):
    """Return the per-unit rescale factor: output diff-from-reference
    divided by input diff-from-reference, except where the input diff
    is near zero — there the gradient at the default activation is used
    instead (gradient; piecewise linear).
    """
    input_diff = self._get_input_diff_from_reference_vars()
    is_near_zero = (B.abs(input_diff) < NEAR_ZERO_THRESHOLD)
    is_far_from_zero = 1 - (1 * is_near_zero)
    # Pseudocount avoids division-by-zero for the near-zero entries;
    # those entries are masked out of the quotient below anyway, so the
    # pseudocount cannot affect the result.
    safe_input_diff = input_diff + (1 * is_near_zero)
    # Near-zero contribs use the gradient at the default activation;
    # everything else uses the ratio of diffs-from-reference.
    gradient_term = (is_near_zero *
                     self._get_gradient_at_default_activation_var())
    ratio_term = (is_far_from_zero *
                  (self._get_diff_from_reference_vars() / safe_input_diff))
    return gradient_term + ratio_term
def _get_mxts_increments_for_inputs(self):
    """Compute the multiplier (mxts) increments to propagate onto this
    dense layer's inputs, dispatching on self.dense_mxts_mode.

    Returns a tensor of dims batch x input.
    """
    # Re. counterbalance: that modification is only appropriate when the
    # output feeds a ReLU. So when the output is not a ReLU, hackily set
    # the mode back to Linear.
    if (self.dense_mxts_mode in [
        DenseMxtsMode.RevealCancel,
        DenseMxtsMode.Redist,
        DenseMxtsMode.Counterbalance,
        DenseMxtsMode.RevealCancelRedist,
        DenseMxtsMode.RevealCancelRedist_ThroughZeros,
        DenseMxtsMode.RevealCancelRedist2,
        DenseMxtsMode.ContinuousShapely]):
        revert = False
        if (len(self.get_output_layers()) != 1):
            revert = True
        else:
            layer_to_check = self.get_output_layers()[0]
            # An intervening BatchNormalization is tolerated: look
            # through it to the layer it feeds.
            if (type(layer_to_check).__name__ == "BatchNormalization"):
                layer_to_check = layer_to_check.get_output_layers()[0]
            if (type(layer_to_check).__name__ != "ReLU"):
                revert = True
        if (revert):
            if (self.dense_mxts_mode != DenseMxtsMode.Linear):
                print("Dense layer " + str(self.get_name()) +
                      " does not have sole output of ReLU so" +
                      " cautiously reverting DenseMxtsMode from " +
                      str(self.dense_mxts_mode) + " to Linear")
                self.dense_mxts_mode = DenseMxtsMode.Linear

    if (self.dense_mxts_mode == DenseMxtsMode.PosOnly):
        # Only let positive multipliers flow through.
        return B.dot(self.get_mxts() * (self.get_mxts() > 0.0), self.W.T)
    elif (self.dense_mxts_mode in [
        DenseMxtsMode.RevealCancel,
        DenseMxtsMode.Redist,
        DenseMxtsMode.Counterbalance,
        DenseMxtsMode.RevealCancelRedist,
        DenseMxtsMode.RevealCancelRedist_ThroughZeros,
        DenseMxtsMode.RevealCancelRedist2,
        DenseMxtsMode.ContinuousShapely]):
        # self.W has dims input x output; W.T is output x input.
        # self._get_input_diff_from_reference_vars() has dims
        # batch x input; fwd_contribs has dims batch x output x input.
        fwd_contribs = self._get_input_diff_from_reference_vars()[:,None,:]\
                        *self.W.T[None,:,:]
        # reference has dims batch x output
        reference = self.get_reference_vars()
        # total_pos_contribs and total_neg_contribs have dim
        # batch x output (total_neg_contribs is stored as a magnitude).
        total_pos_contribs = B.sum(fwd_contribs*(fwd_contribs>0), axis=-1)
        total_neg_contribs = B.abs(
            B.sum(fwd_contribs*(fwd_contribs<0), axis=-1))
        # Compute the positive and negative impact under revealcancel
        # redist: the marginal ReLU impact of the pos (resp. neg)
        # contribs, averaged over applying them before vs. after the
        # opposing contribs. Pos and neg contribs will later be rescaled
        # to add up to these impacts.
        rrd_pos_impact = 0.5 * (
            (B.maximum(0, reference + total_pos_contribs)
             - B.maximum(0, reference)) +
            (B.maximum(
                0, reference - total_neg_contribs + total_pos_contribs)
             - B.maximum(0, reference - total_neg_contribs)))
        rrd_neg_impact = 0.5 * (
            (B.maximum(0, reference - total_neg_contribs)
             - B.maximum(0, reference)) +
            (B.maximum(
                0, reference + total_pos_contribs - total_neg_contribs)
             - B.maximum(0, reference + total_pos_contribs)))

        if (self.dense_mxts_mode in [
            DenseMxtsMode.RevealCancelRedist2,
            DenseMxtsMode.ContinuousShapely]):
            # dims batch x output x input
            v = fwd_contribs
            r = reference[:, :, None]
            # pmax and nmax are the total pos and neg contribs *absent*
            # the current contrib v.
            pmax = (total_pos_contribs[:, :, None] - (v * (v > 0)))
            nmax = (total_neg_contribs[:, :, None] - (v * (v < 0)))
            if (self.dense_mxts_mode == DenseMxtsMode.RevealCancelRedist2):
                # Simplifying assumption: the three 'players' are
                # v, pmax and nmax. Possible orderings:
                #   v, pmax, nmax & v, nmax, pmax <- case 1 (v first)
                #   pmax, nmax, v & nmax, pmax, v <- case 2 (v last)
                #   pmax, v, nmax                 <- case 3
                #   nmax, v, pmax                 <- case 4
                # Marginal contribution of v in each case; cases 1 and 2
                # each cover two of the six orderings, hence the factor
                # of 2 on c1 and c2.
                c1 = 2 * (B.maximum(0, r + v) - B.maximum(0, r))
                c2 = 2 * (B.maximum(0, r + (pmax - nmax) + v)
                          - B.maximum(0, r + (pmax - nmax)))
                c3 = B.maximum(0, r + pmax + v) - B.maximum(0, r + pmax)
                c4 = B.maximum(0, r - nmax + v) - B.maximum(0, r - nmax)
                # Shapley-style average over the 6 orderings.
                unscaled_contribs = (c1 + c2 + c3 + c4) / 6.0
                # Add an adjustment to make sure that the total contribs
                # are preserved.
                total_contribs = total_pos_contribs - total_neg_contribs
                # NOTE(review): zero_total_contribs_mask is assigned but
                # never read below (the mask is recomputed inline for
                # `scale`).
                zero_total_contribs_mask = B.eq(total_contribs, 0.0)
                total_unscaled_contribs = B.sum(unscaled_contribs,
                                                axis=-1)
                scale = (total_contribs /
                         pseudocount_near_zero(total_unscaled_contribs))
                # In the 0.0/0.0 case where the scale is undefined, let
                # the scale factor be 1.0.
                scale += 1.0 * (B.eq(total_contribs, 0.0)
                                * B.eq(total_unscaled_contribs, 0.0))
                final_contribs = unscaled_contribs * scale[:, :, None]
                # v is weight*diff-from-default.
                # The multiplier = final_contribs/diff_from_default
                #                = weight*(final_contribs/v)
                multiplier_adjustment = (final_contribs /
                                         pseudocount_near_zero(v))
                # For the case when v is zero, we need the partial
                # derivative of final_contribs w.r.t. v. This is:
                partialdv = (
                    scale[:, :, None] *  # dfinal/dunscaled
                    # dunscaled/dv (split into c1..c4):
                    (2 * (r > 0.0)
                     + 2 * (r + pmax - nmax > 0.0)
                     + (r + pmax > 0.0)
                     + (r - nmax > 0.0)) / 6.0)
                # Add partialdv to multiplier_adjustment where v is zero
                # (multiplier_adjustment is 0 in those places before
                # this line).
                multiplier_adjustment += partialdv * B.eq(v, 0.0)
                # dims of new_Wt: batch x output x input
                new_Wt = self.W.T[None, :, :] * multiplier_adjustment
                return B.sum(self.get_mxts()[:, :, None]
                             * new_Wt[:, :, :], axis=1)
            elif (self.dense_mxts_mode == DenseMxtsMode.ContinuousShapely):
                raise NotImplementedError(
                    "I scrapped this implementation; see git history for it"
                )
            else:
                raise RuntimeError(self.dense_mxts_mode
                                   + " not implemented")
        # Positive and negative values grouped together for rescale:
        elif (self.dense_mxts_mode in [
            DenseMxtsMode.Redist,
            DenseMxtsMode.Counterbalance,
            DenseMxtsMode.RevealCancel,
            DenseMxtsMode.RevealCancelRedist,
            DenseMxtsMode.RevealCancelRedist_ThroughZeros]):
            if (self.dense_mxts_mode == DenseMxtsMode.Redist or
                self.dense_mxts_mode == DenseMxtsMode.Counterbalance):
                # If output diff-from-def is positive but there are some
                # neg contribs, temper positive by some portion of the
                # neg. to_distribute has dims batch x output; it is the
                # part that dips below 0, accounting for the reference.
                to_distribute = B.minimum(
                    B.maximum(
                        total_neg_contribs -
                        B.maximum(self.get_reference_vars(), 0.0),
                        0.0),
                    total_pos_contribs) / 2.0
                # total_pos_contribs_new has dims batch x output
                total_pos_contribs_new = total_pos_contribs - to_distribute
                total_neg_contribs_new = total_neg_contribs - to_distribute
            elif (self.dense_mxts_mode in [
                DenseMxtsMode.RevealCancel,
                DenseMxtsMode.RevealCancelRedist,
                DenseMxtsMode.RevealCancelRedist_ThroughZeros]):
                ## Sanity check to see if we can implement the existing
                ## deeplift:
                #total_contribs = total_pos_contribs - total_neg_contribs
                #effective_contribs = B.maximum(self.get_reference_vars() + total_contribs,0) -\
                # B.maximum(self.get_reference_vars(),0)
                #rescale = effective_contribs/total_contribs
                #
                #return B.sum(self.get_mxts()[:,:,None]*self.W.T[None,:,:]*rescale[:,:,None], axis=1)
                total_pos_contribs_new =\
                 B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
                 -B.maximum(self.get_reference_vars(),0)
                total_neg_contribs_new =\
                 B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
                 -B.maximum(self.get_reference_vars()
                            +total_pos_contribs-total_neg_contribs,0)
                if (self.dense_mxts_mode in [
                    DenseMxtsMode.RevealCancelRedist,
                    DenseMxtsMode.RevealCancelRedist_ThroughZeros]):
                    # Override with the order-averaged revealcancel
                    # redist impacts computed above.
                    total_pos_contribs_new = rrd_pos_impact
                    total_neg_contribs_new = B.abs(rrd_neg_impact)
                    #to_distribute = B.minimum(
                    # B.maximum(total_neg_contribs_new -
                    #  B.maximum(self.get_reference_vars(),0.0),0.0),
                    # total_pos_contribs_new)/2.0
                    #total_pos_contribs_new = total_pos_contribs_new - to_distribute
                    #total_neg_contribs_new = total_neg_contribs_new - to_distribute
            else:
                raise RuntimeError("Unsupported dense_mxts_mode: "
                                   + str(self.dense_mxts_mode))
            # positive_rescale has dims batch x output
            positive_rescale = total_pos_contribs_new/\
                                pseudocount_near_zero(total_pos_contribs)
            negative_rescale = total_neg_contribs_new/\
                                pseudocount_near_zero(total_neg_contribs)
            # new_Wt has dims batch x output x input
            new_Wt = self.W.T[None,:,:]*\
                     (fwd_contribs>0)*positive_rescale[:,:,None]
            new_Wt += self.W.T[None,:,:]*\
                      (fwd_contribs<0)*negative_rescale[:,:,None]
            if (self.dense_mxts_mode ==
                DenseMxtsMode.RevealCancelRedist_ThroughZeros):
                # For contribs that are exactly zero (0/0), set the
                # multiplier to the average of the pos and neg rescales.
                new_Wt += (self.W.T[None, :, :]
                           * (0.5 * (positive_rescale[:, :, None]
                                     + negative_rescale[:, :, None]))
                           * B.eq(fwd_contribs, 0.0))
        else:
            raise RuntimeError("Unsupported dense_mxts_mode: "
                               + str(self.dense_mxts_mode))
        return B.sum(self.get_mxts()[:, :, None]
                     * new_Wt[:, :, :], axis=1)
    elif (self.dense_mxts_mode == DenseMxtsMode.Linear):
        return B.dot(self.get_mxts(), self.W.T)
    else:
        raise RuntimeError("Unsupported mxts mode: "
                           + str(self.dense_mxts_mode))
def _get_mxts_increments_for_inputs(self):
    """Compute the multiplier (mxts) increments to propagate onto this
    dense layer's inputs, dispatching on self.dense_mxts_mode.

    Returns a tensor of dims batch x input.
    """
    # Re. counterbalance: that modification is only appropriate when the
    # output is a ReLU. So when the output is not a ReLU, hackily set
    # the mode back to Linear.
    if (self.dense_mxts_mode in [
        DenseMxtsMode.RevealCancel,
        DenseMxtsMode.Redist,
        DenseMxtsMode.Counterbalance,
        DenseMxtsMode.RevealCancelRedist]):
        if (len(self.get_output_layers()) != 1 or
            (type(self.get_output_layers()[0]).__name__ != "ReLU")):
            # Bug fix: the message previously read "...DenseMxtsMode
            # from to Linear" — it never named the mode being reverted.
            print("Dense layer does not have sole output of ReLU so"
                  " cautiously reverting DenseMxtsMode from "
                  + str(self.dense_mxts_mode) + " to Linear")
            self.dense_mxts_mode = DenseMxtsMode.Linear

    if (self.dense_mxts_mode == DenseMxtsMode.PosOnly):
        # Only let positive multipliers flow through.
        return B.dot(self.get_mxts() * (self.get_mxts() > 0.0), self.W.T)
    elif (self.dense_mxts_mode in [
        DenseMxtsMode.RevealCancel,
        DenseMxtsMode.Redist,
        DenseMxtsMode.Counterbalance,
        DenseMxtsMode.RevealCancelRedist]):
        # self.W has dims input x output; W.T is output x input.
        # self._get_input_diff_from_reference_vars() has dims
        # batch x input; fwd_contribs has dims batch x output x input.
        fwd_contribs = self._get_input_diff_from_reference_vars()[:,None,:]\
                        *self.W.T[None,:,:]
        # total_pos_contribs and total_neg_contribs have dim
        # batch x output (total_neg_contribs is stored as a magnitude).
        total_pos_contribs = B.sum(fwd_contribs*(fwd_contribs>0), axis=-1)
        total_neg_contribs = B.abs(
            B.sum(fwd_contribs*(fwd_contribs<0), axis=-1))
        if (self.dense_mxts_mode == DenseMxtsMode.Redist or
            self.dense_mxts_mode == DenseMxtsMode.Counterbalance):
            # If output diff-from-def is positive but there are some neg
            # contribs, temper positive by some portion of the neg.
            # to_distribute has dims batch x output; it is the part that
            # dips below 0, accounting for the reference.
            to_distribute = B.minimum(
                B.maximum(
                    total_neg_contribs -
                    B.maximum(self.get_reference_vars(), 0.0),
                    0.0),
                total_pos_contribs) / 2.0
            # total_pos_contribs_new has dims batch x output
            total_pos_contribs_new = total_pos_contribs - to_distribute
            total_neg_contribs_new = total_neg_contribs - to_distribute
        elif (self.dense_mxts_mode in [
            DenseMxtsMode.RevealCancel,
            DenseMxtsMode.RevealCancelRedist]):
            ## Sanity check to see if we can implement the existing
            ## deeplift:
            #total_contribs = total_pos_contribs - total_neg_contribs
            #effective_contribs = B.maximum(self.get_reference_vars() + total_contribs,0) -\
            # B.maximum(self.get_reference_vars(),0)
            #rescale = effective_contribs/total_contribs
            #
            #return B.sum(self.get_mxts()[:,:,None]*self.W.T[None,:,:]*rescale[:,:,None], axis=1)
            total_pos_contribs_new =\
             B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
             -B.maximum(self.get_reference_vars(),0)
            total_neg_contribs_new =\
             B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
             -B.maximum(self.get_reference_vars()
                        +total_pos_contribs-total_neg_contribs,0)
            if (self.dense_mxts_mode == DenseMxtsMode.RevealCancelRedist):
                # Redistribute half of the neg contribs that dip below
                # zero (after the reference) into the pos contribs.
                to_distribute = B.minimum(
                    B.maximum(
                        total_neg_contribs_new -
                        B.maximum(self.get_reference_vars(), 0.0),
                        0.0),
                    total_pos_contribs_new) / 2.0
                total_pos_contribs_new = total_pos_contribs_new - to_distribute
                total_neg_contribs_new = total_neg_contribs_new - to_distribute
        else:
            raise RuntimeError("Unsupported dense_mxts_mode: "
                               + str(self.dense_mxts_mode))
        # positive_rescale has dims batch x output; pseudocount keeps
        # the division defined when the original totals are near zero.
        positive_rescale = total_pos_contribs_new/\
                            pseudocount_near_zero(total_pos_contribs)
        negative_rescale = total_neg_contribs_new/\
                            pseudocount_near_zero(total_neg_contribs)
        # new_Wt has dims batch x output x input
        new_Wt = self.W.T[None,:,:]*\
                 (fwd_contribs>0)*positive_rescale[:,:,None]
        new_Wt += self.W.T[None,:,:]*\
                  (fwd_contribs<0)*negative_rescale[:,:,None]
        return B.sum(self.get_mxts()[:, :, None]
                     * new_Wt[:, :, :], axis=1)
    elif (self.dense_mxts_mode == DenseMxtsMode.Linear):
        return B.dot(self.get_mxts(), self.W.T)
    else:
        raise RuntimeError("Unsupported mxts mode: "
                           + str(self.dense_mxts_mode))