def _build_pos_and_neg_contribs(self):
    """Split this dense layer's output diff-from-reference into its
    positive and negative parts.

    Returns:
        (pos_contribs, neg_contribs), each of dims batch x output.
    """
    if (self.dense_mxts_mode == DenseMxtsMode.Linear):
        #split by the sign of the input's diff-from-reference only
        diff_ref = self._get_input_diff_from_reference_vars()
        inp_pos_part = diff_ref * (diff_ref > 0.0)
        inp_neg_part = diff_ref * (diff_ref < 0.0)
        w_pos = self.W * (self.W > 0.0)
        w_neg = self.W * (self.W < 0.0)
        #same-sign input/weight pairings produce positive contributions,
        #opposite-sign pairings produce negative ones
        pos_contribs = B.dot(inp_pos_part, w_pos) + B.dot(inp_neg_part, w_neg)
        neg_contribs = B.dot(inp_neg_part, w_pos) + B.dot(inp_pos_part, w_neg)
    elif (self.dense_mxts_mode == DenseMxtsMode.SepPosAndNeg):
        #compute pos/neg contribs based on the pos/neg breakdown
        #of the input, rather than just the sign of inp_diff_ref
        inp_pos_contribs, inp_neg_contribs = \
            self._get_input_pos_and_neg_contribs()
        w_nonneg = self.W * (self.W >= 0.0)
        w_neg = self.W * (self.W < 0.0)
        pos_contribs = (B.dot(inp_pos_contribs, w_nonneg)
                        + B.dot(inp_neg_contribs, w_neg))
        neg_contribs = (B.dot(inp_neg_contribs, w_nonneg)
                        + B.dot(inp_pos_contribs, w_neg))
    else:
        raise RuntimeError("Unsupported dense_mxts_mode: "
                           + self.dense_mxts_mode)
    return pos_contribs, neg_contribs
def _get_mxts_increments_for_inputs(self):
    """Backpropagate the pos/neg multipliers through the dense weights.

    Returns:
        (pos_mxts_increments, neg_mxts_increments), each batch x input.
    """
    #shared sign-partition of the transposed weights
    wt = self.W.T
    wt_nonneg = wt * (wt >= 0.0)
    wt_neg = wt * (wt < 0.0)
    pos_mxts = self.get_pos_mxts()
    neg_mxts = self.get_neg_mxts()
    if (self.dense_mxts_mode == DenseMxtsMode.Linear):
        #different inputs will inherit multipliers differently according
        #to the sign of inp_diff_ref (as this sign was used to determine
        #the pos_contribs and neg_contribs; there was no breakdown
        #by the pos/neg contribs of the input)
        inp_diff_ref = self._get_input_diff_from_reference_vars()
        same_sign_flow = (B.dot(pos_mxts, wt_nonneg)
                          + B.dot(neg_mxts, wt_neg))
        opp_sign_flow = (B.dot(pos_mxts, wt_neg)
                         + B.dot(neg_mxts, wt_nonneg))
        #where diff-from-reference is exactly zero, average the two
        zero_flow = B.dot(0.5 * (pos_mxts + neg_mxts), wt)
        inp_mxts_increments = (
            (inp_diff_ref > 0.0) * same_sign_flow
            + (inp_diff_ref < 0.0) * opp_sign_flow
            + B.eq(inp_diff_ref, 0.0) * zero_flow)
        #pos_mxts and neg_mxts in the input get the same multiplier
        #because the breakdown between pos and neg wasn't used to
        #compute pos_contribs and neg_contribs in the forward pass
        #(it was based entirely on inp_diff_ref)
        return inp_mxts_increments, inp_mxts_increments
    elif (self.dense_mxts_mode == DenseMxtsMode.SepPosAndNeg):
        #during the forward pass, the pos/neg contribs of the input
        #were used to determine the pos/neg contribs of the output -
        #thus during the backward pass, the pos/neg mxts will be
        #determined accordingly (i.e. for a given input, the multiplier
        #on the positive part may differ from that on the negative part)
        pos_mxts_increments = (B.dot(pos_mxts, wt_nonneg)
                               + B.dot(neg_mxts, wt_neg))
        neg_mxts_increments = (B.dot(pos_mxts, wt_neg)
                               + B.dot(neg_mxts, wt_nonneg))
        return pos_mxts_increments, neg_mxts_increments
    else:
        raise RuntimeError("Unsupported mxts mode: "
                           + str(self.dense_mxts_mode))
def _build_activation_vars(self, input_act_vars):
    """Affine forward pass: input @ W + b."""
    linear_out = B.dot(input_act_vars, self.W)
    return linear_out + self.b
def _get_mxts_increments_for_inputs(self):
    """Backward pass: compute the multiplier increments propagated to this
    dense layer's inputs under self.dense_mxts_mode.

    Returns a tensor of dims batch x input: the contraction of the
    output-side multipliers with a (possibly per-example rescaled)
    weight matrix.
    """
    #re. counterbalance: this modification is only appropriate
    #when the output is a relu. So when the output is not a relu,
    #hackily set the mode back to Linear.
    if (self.dense_mxts_mode in [
            DenseMxtsMode.RevealCancel,
            DenseMxtsMode.Redist,
            DenseMxtsMode.Counterbalance,
            DenseMxtsMode.RevealCancelRedist,
            DenseMxtsMode.RevealCancelRedist_ThroughZeros,
            DenseMxtsMode.RevealCancelRedist2,
            DenseMxtsMode.ContinuousShapely]):
        revert = False
        if (len(self.get_output_layers()) != 1):
            revert = True
        else:
            layer_to_check = self.get_output_layers()[0]
            #look through an intervening BatchNormalization to find
            #the activation that actually follows this layer
            if (type(layer_to_check).__name__ == "BatchNormalization"):
                layer_to_check = layer_to_check.get_output_layers()[0]
            if (type(layer_to_check).__name__ != "ReLU"):
                revert = True
        if (revert):
            if (self.dense_mxts_mode != DenseMxtsMode.Linear):
                print("Dense layer " + str(self.get_name())
                      + " does not have sole output of ReLU so"
                      + " cautiously reverting DenseMxtsMode from "
                      + str(self.dense_mxts_mode) + " to Linear")
            self.dense_mxts_mode = DenseMxtsMode.Linear

    if (self.dense_mxts_mode == DenseMxtsMode.PosOnly):
        #propagate only the positive multipliers through W
        return B.dot(self.get_mxts() * (self.get_mxts() > 0.0), self.W.T)
    elif (self.dense_mxts_mode in [
            DenseMxtsMode.RevealCancel,
            DenseMxtsMode.Redist,
            DenseMxtsMode.Counterbalance,
            DenseMxtsMode.RevealCancelRedist,
            DenseMxtsMode.RevealCancelRedist_ThroughZeros,
            DenseMxtsMode.RevealCancelRedist2,
            DenseMxtsMode.ContinuousShapely]):
        #self.W has dims input x output; W.T is output x input
        #self._get_input_diff_from_reference_vars() has dims batch x input
        #fwd_contribs has dims batch x output x input
        fwd_contribs = self._get_input_diff_from_reference_vars()[:,None,:]\
                       *self.W.T[None,:,:]
        #reference has dims batch x output
        reference = self.get_reference_vars()
        #total_pos_contribs and total_neg_contribs have dim batch x output
        total_pos_contribs = B.sum(fwd_contribs * (fwd_contribs > 0),
                                   axis=-1)
        total_neg_contribs = B.abs(
            B.sum(fwd_contribs * (fwd_contribs < 0), axis=-1))
        #compute the positive and negative impact under revealcancel
        #redist; will rescale pos and neg contribs to add up to this:
        #average of the marginal effect of the pos (resp. neg) contribs
        #when added before vs. after the neg (resp. pos) contribs,
        #relative to the reference (assumes a downstream ReLU)
        rrd_pos_impact = 0.5 * (
            (B.maximum(0, reference + total_pos_contribs)
             - B.maximum(0, reference))
            + (B.maximum(
                0, reference - total_neg_contribs + total_pos_contribs)
               - B.maximum(0, reference - total_neg_contribs)))
        rrd_neg_impact = 0.5 * (
            (B.maximum(0, reference - total_neg_contribs)
             - B.maximum(0, reference))
            + (B.maximum(
                0, reference + total_pos_contribs - total_neg_contribs)
               - B.maximum(0, reference + total_pos_contribs)))
        if (self.dense_mxts_mode in [
                DenseMxtsMode.RevealCancelRedist2,
                DenseMxtsMode.ContinuousShapely]):
            #dims batch x output x input
            v = fwd_contribs
            r = reference[:, :, None]
            #pmax and nmax are the pos and neg contribs *absent* the
            #current contrib
            pmax = (total_pos_contribs[:, :, None] - (v * (v > 0)))
            nmax = (total_neg_contribs[:, :, None] - (v * (v < 0)))
            if (self.dense_mxts_mode == DenseMxtsMode.RevealCancelRedist2):
                #we will make the simplifying assumption that the three
                #'players' are v, pmax and nmax
                #possible orders are:
                # v, pmax, nmax & v, nmax, pmax <- 1
                # pmax, nmax, v & nmax, pmax, v <- 2
                # pmax, v, nmax <- 3
                # nmax, v, pmax <- 4
                #Let's find the marginal contribs in all cases
                # case 1 (v enters first; weight 2 for the two orders):
                c1 = 2 * (B.maximum(0, r + v) - B.maximum(0, r))
                # case 2 (v enters last; weight 2 for the two orders):
                c2 = 2 * (B.maximum(0, r + (pmax - nmax) + v)
                          - B.maximum(0, r + (pmax - nmax)))
                # case 3 (v enters after pmax):
                c3 = B.maximum(0, r + pmax + v) - B.maximum(0, r + pmax)
                # case 4 (v enters after nmax):
                c4 = B.maximum(0, r - nmax + v) - B.maximum(0, r - nmax)
                #average over the 6 orderings
                unscaled_contribs = (c1 + c2 + c3 + c4) / 6.0
                #add an adjustment to make sure that the total contribs
                #are the same.
                total_contribs = total_pos_contribs - total_neg_contribs
                zero_total_contribs_mask = B.eq(total_contribs, 0.0)
                total_unscaled_contribs = B.sum(unscaled_contribs, axis=-1)
                scale = (total_contribs /
                         pseudocount_near_zero(total_unscaled_contribs))
                #in the 0.0/0.0 case where the scale is undefined, let
                #the scale factor be 1.0
                scale += 1.0 * (B.eq(total_contribs, 0.0)
                                * B.eq(total_unscaled_contribs, 0.0))
                final_contribs = unscaled_contribs * scale[:, :, None]
                #v is weight*diff-from-default
                #The multiplier = final_contribs/diff_from_default
                # = weight*(final_contribs/v)
                multiplier_adjustment = (final_contribs /
                                         pseudocount_near_zero(v))
                #for the case when v is zero, we need to find the partial
                #derivative of final_contribs w.r.t. v. This is:
                partialdv = (
                    scale[:, :, None] *  #dfinal/dunscaled
                    #dunscaled/dv (split into c1..c4):
                    (2 * (r > 0.0)
                     + 2 * (r + pmax - nmax > 0.0)
                     + (r + pmax > 0.0)
                     + (r - nmax > 0.0)) / 6.0)
                #add partialdv to multiplier_adjustment when v is zero
                #(multiplier_adjustment should be 0 in those places
                # before this line)
                multiplier_adjustment += partialdv * B.eq(v, 0.0)
                #dims of new_Wt: batch x output x input
                new_Wt = self.W.T[None, :, :] * multiplier_adjustment
                return B.sum(self.get_mxts()[:, :, None]
                             * new_Wt[:, :, :], axis=1)
            elif (self.dense_mxts_mode == DenseMxtsMode.ContinuousShapely):
                raise NotImplementedError(
                    "I scrapped this implementation; see git history for it"
                )
            else:
                raise RuntimeError(self.dense_mxts_mode
                                   + " not implemented")
        #positive and negative values grouped together for rescale:
        elif (self.dense_mxts_mode in [
                DenseMxtsMode.Redist,
                DenseMxtsMode.Counterbalance,
                DenseMxtsMode.RevealCancel,
                DenseMxtsMode.RevealCancelRedist,
                DenseMxtsMode.RevealCancelRedist_ThroughZeros]):
            if (self.dense_mxts_mode == DenseMxtsMode.Redist
                or self.dense_mxts_mode == DenseMxtsMode.Counterbalance):
                #if output diff-from-def is positive but there are some
                #neg contribs, temper positive by some portion of the neg
                #to_distribute has dims batch x output
                #neg_to_distribute is what dips below 0, accounting
                #for ref
                to_distribute = B.minimum(
                    B.maximum(
                        total_neg_contribs
                        - B.maximum(self.get_reference_vars(), 0.0),
                        0.0),
                    total_pos_contribs) / 2.0
                #total_pos_contribs_new has dims batch x output
                total_pos_contribs_new = total_pos_contribs - to_distribute
                total_neg_contribs_new = total_neg_contribs - to_distribute
            elif (self.dense_mxts_mode in [
                    DenseMxtsMode.RevealCancel,
                    DenseMxtsMode.RevealCancelRedist,
                    DenseMxtsMode.RevealCancelRedist_ThroughZeros]):
                ##sanity check to see if we can implement the existing
                ##deeplift
                #total_contribs = total_pos_contribs - total_neg_contribs
                #effective_contribs = B.maximum(self.get_reference_vars() + total_contribs,0) -\
                # B.maximum(self.get_reference_vars(),0)
                #rescale = effective_contribs/total_contribs
                #
                #return B.sum(self.get_mxts()[:,:,None]*self.W.T[None,:,:]*rescale[:,:,None], axis=1)
                total_pos_contribs_new =\
                    B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
                    -B.maximum(self.get_reference_vars(),0)
                total_neg_contribs_new =\
                    B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
                    -B.maximum(self.get_reference_vars()
                               +total_pos_contribs-total_neg_contribs,0)
                if (self.dense_mxts_mode in [
                        DenseMxtsMode.RevealCancelRedist,
                        DenseMxtsMode.RevealCancelRedist_ThroughZeros]):
                    #override with the averaged-ordering impacts
                    #computed above
                    total_pos_contribs_new = rrd_pos_impact
                    total_neg_contribs_new = B.abs(rrd_neg_impact)
                    #to_distribute = B.minimum(
                    # B.maximum(total_neg_contribs_new -
                    # B.maximum(self.get_reference_vars(),0.0),0.0),
                    # total_pos_contribs_new)/2.0
                    #total_pos_contribs_new = total_pos_contribs_new - to_distribute
                    #total_neg_contribs_new = total_neg_contribs_new - to_distribute
            else:
                raise RuntimeError("Unsupported dense_mxts_mode: "
                                   + str(self.dense_mxts_mode))
            #positive_rescale has dims batch x output
            positive_rescale = total_pos_contribs_new/\
                               pseudocount_near_zero(total_pos_contribs)
            negative_rescale = total_neg_contribs_new/\
                               pseudocount_near_zero(total_neg_contribs)
            #new_Wt has dims batch x output x input
            new_Wt = self.W.T[None,:,:]*\
                     (fwd_contribs>0)*positive_rescale[:,:,None]
            new_Wt += self.W.T[None,:,:]*\
                      (fwd_contribs<0)*negative_rescale[:,:,None]
            if (self.dense_mxts_mode ==
                DenseMxtsMode.RevealCancelRedist_ThroughZeros):
                #for 0/0, set multiplier to half of pos and neg rescales
                new_Wt += (self.W.T[None, :, :]
                           * (0.5 * (positive_rescale[:, :, None]
                                     + negative_rescale[:, :, None]))
                           * B.eq(fwd_contribs, 0.0))
        else:
            raise RuntimeError("Unsupported dense_mxts_mode: "
                               + str(self.dense_mxts_mode))
        return B.sum(self.get_mxts()[:, :, None] * new_Wt[:, :, :],
                     axis=1)
    elif (self.dense_mxts_mode == DenseMxtsMode.Linear):
        #plain gradient-style backprop through the weights
        return B.dot(self.get_mxts(), self.W.T)
    else:
        raise RuntimeError("Unsupported mxts mode: "
                           + str(self.dense_mxts_mode))
def _get_mxts_increments_for_inputs(self):
    """Backward pass: compute the multiplier increments propagated to this
    dense layer's inputs under self.dense_mxts_mode.

    Returns a tensor of dims batch x input: the contraction of the
    output-side multipliers with a (possibly per-example rescaled)
    weight matrix.
    """
    #re. counterbalance: this modification is only appropriate
    #when the output is a relu. So when the output is not a relu,
    #hackily set the mode back to Linear.
    if (self.dense_mxts_mode in [
            DenseMxtsMode.RevealCancel,
            DenseMxtsMode.Redist,
            DenseMxtsMode.Counterbalance,
            DenseMxtsMode.RevealCancelRedist]):
        if (len(self.get_output_layers()) != 1 or
            (type(self.get_output_layers()[0]).__name__ != "ReLU")):
            #BUGFIX: the original message concatenated "...from" directly
            #with " to Linear", printing "reverting DenseMxtsMode from to
            #Linear" with the mode name missing; include the mode
            print("Dense layer does not have sole output of ReLU so"
                  " cautiously reverting DenseMxtsMode from "
                  + str(self.dense_mxts_mode) + " to Linear")
            self.dense_mxts_mode = DenseMxtsMode.Linear

    if (self.dense_mxts_mode == DenseMxtsMode.PosOnly):
        #propagate only the positive multipliers through W
        return B.dot(self.get_mxts() * (self.get_mxts() > 0.0), self.W.T)
    elif (self.dense_mxts_mode in [
            DenseMxtsMode.RevealCancel,
            DenseMxtsMode.Redist,
            DenseMxtsMode.Counterbalance,
            DenseMxtsMode.RevealCancelRedist]):
        #self.W has dims input x output; W.T is output x input
        #self._get_input_diff_from_reference_vars() has dims batch x input
        #fwd_contribs has dims batch x output x input
        fwd_contribs = self._get_input_diff_from_reference_vars()[:,None,:]\
                       *self.W.T[None,:,:]
        #total_pos_contribs and total_neg_contribs have dim batch x output
        total_pos_contribs = B.sum(fwd_contribs * (fwd_contribs > 0),
                                   axis=-1)
        total_neg_contribs = B.abs(
            B.sum(fwd_contribs * (fwd_contribs < 0), axis=-1))
        if (self.dense_mxts_mode == DenseMxtsMode.Redist
            or self.dense_mxts_mode == DenseMxtsMode.Counterbalance):
            #if output diff-from-def is positive but there are some neg
            #contribs, temper positive by some portion of the neg
            #to_distribute has dims batch x output
            #neg_to_distribute is what dips below 0, accounting for ref
            to_distribute = B.minimum(
                B.maximum(
                    total_neg_contribs
                    - B.maximum(self.get_reference_vars(), 0.0),
                    0.0),
                total_pos_contribs) / 2.0
            #total_pos_contribs_new has dims batch x output
            total_pos_contribs_new = total_pos_contribs - to_distribute
            total_neg_contribs_new = total_neg_contribs - to_distribute
        elif (self.dense_mxts_mode in [
                DenseMxtsMode.RevealCancel,
                DenseMxtsMode.RevealCancelRedist]):
            ##sanity check to see if we can implement the existing deeplift
            #total_contribs = total_pos_contribs - total_neg_contribs
            #effective_contribs = B.maximum(self.get_reference_vars() + total_contribs,0) -\
            # B.maximum(self.get_reference_vars(),0)
            #rescale = effective_contribs/total_contribs
            #
            #return B.sum(self.get_mxts()[:,:,None]*self.W.T[None,:,:]*rescale[:,:,None], axis=1)
            total_pos_contribs_new =\
                B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
                -B.maximum(self.get_reference_vars(),0)
            total_neg_contribs_new =\
                B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
                -B.maximum(self.get_reference_vars()
                           +total_pos_contribs-total_neg_contribs,0)
            if (self.dense_mxts_mode == DenseMxtsMode.RevealCancelRedist):
                #redistribute the portion of the neg contribs that dips
                #below zero (accounting for the reference)
                to_distribute = B.minimum(
                    B.maximum(
                        total_neg_contribs_new
                        - B.maximum(self.get_reference_vars(), 0.0),
                        0.0),
                    total_pos_contribs_new) / 2.0
                total_pos_contribs_new = (total_pos_contribs_new
                                          - to_distribute)
                total_neg_contribs_new = (total_neg_contribs_new
                                          - to_distribute)
        else:
            raise RuntimeError("Unsupported dense_mxts_mode: "
                               + str(self.dense_mxts_mode))
        #positive_rescale has dims batch x output
        positive_rescale = total_pos_contribs_new/\
                           pseudocount_near_zero(total_pos_contribs)
        negative_rescale = total_neg_contribs_new/\
                           pseudocount_near_zero(total_neg_contribs)
        #new_Wt has dims batch x output x input
        new_Wt = self.W.T[None,:,:]*\
                 (fwd_contribs>0)*positive_rescale[:,:,None]
        new_Wt += self.W.T[None,:,:]*\
                  (fwd_contribs<0)*negative_rescale[:,:,None]
        return B.sum(self.get_mxts()[:, :, None] * new_Wt[:, :, :],
                     axis=1)
    elif (self.dense_mxts_mode == DenseMxtsMode.Linear):
        #plain gradient-style backprop through the weights
        return B.dot(self.get_mxts(), self.W.T)
    else:
        raise RuntimeError("Unsupported mxts mode: "
                           + str(self.dense_mxts_mode))