Пример #1
0
 def _build_pos_and_neg_contribs(self):
     """Split this dense layer's diff-from-reference output into its
     positive and negative portions, per self.dense_mxts_mode.

     Returns:
         (pos_contribs, neg_contribs): symbolic tensors holding the
         positive and negative parts of the layer's contributions
         (presumably batch x output, given B.dot with self.W — confirm).

     Raises:
         RuntimeError: if dense_mxts_mode is neither Linear nor
             SepPosAndNeg.
     """
     if (self.dense_mxts_mode == DenseMxtsMode.Linear):
         # Route by the sign of the input's diff-from-reference:
         # (+input,+weight) and (-input,-weight) products are positive;
         # mixed-sign products are negative.
         inp_diff_ref = self._get_input_diff_from_reference_vars()
         pos_inp = inp_diff_ref * (inp_diff_ref > 0.0)
         neg_inp = inp_diff_ref * (inp_diff_ref < 0.0)
         pos_W = self.W * (self.W > 0.0)
         neg_W = self.W * (self.W < 0.0)
         pos_contribs = B.dot(pos_inp, pos_W) + B.dot(neg_inp, neg_W)
         neg_contribs = B.dot(neg_inp, pos_W) + B.dot(pos_inp, neg_W)
     elif (self.dense_mxts_mode == DenseMxtsMode.SepPosAndNeg):
         #compute pos/neg contribs based on the pos/neg breakdown
         #of the input, rather than just the sign of inp_diff_ref
         inp_pos_contribs, inp_neg_contribs =\
             self._get_input_pos_and_neg_contribs()
         # NOTE(review): zero weights fall in the ">=" mask here, unlike
         # the strict ">" used in Linear mode — presumably intentional
         # (zero weights contribute zero either way); verify.
         nonneg_W = self.W * (self.W >= 0.0)
         neg_W = self.W * (self.W < 0.0)
         pos_contribs = (B.dot(inp_pos_contribs, nonneg_W) +
                         B.dot(inp_neg_contribs, neg_W))
         neg_contribs = (B.dot(inp_neg_contribs, nonneg_W) +
                         B.dot(inp_pos_contribs, neg_W))
     else:
         # str() so the message always formats (and matches the sibling
         # error messages in this class) even if the mode is not a str
         raise RuntimeError("Unsupported dense_mxts_mode: " +
                            str(self.dense_mxts_mode))
     return pos_contribs, neg_contribs
Пример #2
0
    def _get_mxts_increments_for_inputs(self):
        """Backward pass: route the layer's pos/neg multipliers down to
        its inputs, mirroring how contributions were split in the
        forward pass (per self.dense_mxts_mode).

        Returns:
            (pos_mxts_increments, neg_mxts_increments) for the inputs.

        Raises:
            RuntimeError: for an unsupported dense_mxts_mode.
        """
        if (self.dense_mxts_mode == DenseMxtsMode.Linear):
            # The forward split of pos/neg contribs was decided by the
            # sign of inp_diff_ref (not by the input's own pos/neg
            # breakdown), so the same sign decides multiplier routing.
            diff_ref = self._get_input_diff_from_reference_vars()
            Wt = self.W.T
            nonneg_W = Wt * (Wt >= 0.0)
            neg_W = Wt * (Wt < 0.0)
            pos_mxts = self.get_pos_mxts()
            neg_mxts = self.get_neg_mxts()
            increments = (diff_ref > 0.0) * (
                B.dot(pos_mxts, nonneg_W) + B.dot(neg_mxts, neg_W))
            increments += (diff_ref < 0.0) * (
                B.dot(pos_mxts, neg_W) + B.dot(neg_mxts, nonneg_W))
            # where diff-from-reference is exactly zero, average the
            # pos and neg multipliers through the full weight matrix
            increments += B.eq(diff_ref, 0.0) * B.dot(
                0.5 * (pos_mxts + neg_mxts), Wt)
            # pos and neg parts of the input receive the SAME multiplier
            # because the forward breakdown was based entirely on
            # inp_diff_ref, not on the input's pos/neg contribs
            return increments, increments

        elif (self.dense_mxts_mode == DenseMxtsMode.SepPosAndNeg):
            # Here the forward pass DID use the input's pos/neg
            # breakdown, so the positive and negative parts of an input
            # can receive different multipliers.
            Wt = self.W.T
            nonneg_W = Wt * (Wt >= 0.0)
            neg_W = Wt * (Wt < 0.0)
            pos_increments = (B.dot(self.get_pos_mxts(), nonneg_W) +
                              B.dot(self.get_neg_mxts(), neg_W))
            neg_increments = (B.dot(self.get_pos_mxts(), neg_W) +
                              B.dot(self.get_neg_mxts(), nonneg_W))
            return pos_increments, neg_increments
        else:
            raise RuntimeError("Unsupported mxts mode: "
                               +str(self.dense_mxts_mode))
Пример #3
0
 def _build_activation_vars(self, input_act_vars):
     """Affine forward pass: input activations times W, plus bias."""
     linear_out = B.dot(input_act_vars, self.W)
     return linear_out + self.b
Пример #4
0
    def _get_mxts_increments_for_inputs(self):
        """Backward pass: compute the multiplier increments passed down
        to this dense layer's inputs, per self.dense_mxts_mode.

        Returns the input multiplier increments (dims batch x input,
        per the inline dimension comments).

        Raises:
            RuntimeError: for unsupported dense_mxts_mode values.
        """
        #re. counterbalance: this modification is only appropriate
        #when the output is a relu. So when the output is not a relu,
        #hackily set the mode back to Linear.
        if (self.dense_mxts_mode in [
                DenseMxtsMode.RevealCancel, DenseMxtsMode.Redist,
                DenseMxtsMode.Counterbalance, DenseMxtsMode.RevealCancelRedist,
                DenseMxtsMode.RevealCancelRedist_ThroughZeros,
                DenseMxtsMode.RevealCancelRedist2,
                DenseMxtsMode.ContinuousShapely
        ]):
            revert = False
            if (len(self.get_output_layers()) != 1):
                revert = True
            else:
                layer_to_check = self.get_output_layers()[0]
                #look through a single intervening BatchNormalization
                #layer when deciding whether the output is a ReLU
                if (type(layer_to_check).__name__ == "BatchNormalization"):
                    layer_to_check = layer_to_check.get_output_layers()[0]
                if (type(layer_to_check).__name__ != "ReLU"):
                    revert = True
            if (revert):
                if (self.dense_mxts_mode != DenseMxtsMode.Linear):
                    print("Dense layer " + str(self.get_name()) +
                          " does not have sole output of ReLU so" +
                          " cautiously reverting DenseMxtsMode from " +
                          str(self.dense_mxts_mode) + " to Linear")
                    self.dense_mxts_mode = DenseMxtsMode.Linear

        if (self.dense_mxts_mode == DenseMxtsMode.PosOnly):
            #propagate only the positive multipliers through the weights
            return B.dot(self.get_mxts() * (self.get_mxts() > 0.0), self.W.T)

        elif (self.dense_mxts_mode in [
                DenseMxtsMode.RevealCancel, DenseMxtsMode.Redist,
                DenseMxtsMode.Counterbalance, DenseMxtsMode.RevealCancelRedist,
                DenseMxtsMode.RevealCancelRedist_ThroughZeros,
                DenseMxtsMode.RevealCancelRedist2,
                DenseMxtsMode.ContinuousShapely
        ]):
            #self.W has dims input x output; W.T is output x input
            #self._get_input_diff_from_reference_vars() has dims batch x input
            #fwd_contribs has dims batch x output x input
            fwd_contribs = self._get_input_diff_from_reference_vars()[:,None,:]\
                           *self.W.T[None,:,:]
            #reference has dims batch x output
            reference = self.get_reference_vars()
            #total_pos_contribs and total_neg_contribs have dim batch x output
            total_pos_contribs = B.sum(fwd_contribs * (fwd_contribs > 0),
                                       axis=-1)
            total_neg_contribs = B.abs(
                B.sum(fwd_contribs * (fwd_contribs < 0), axis=-1))

            #compute the positive and negative impact under revealcancel redist
            #will rescale pos and neg contribs to add up to this I guess
            #(average of adding the pos contribs before vs. after the negs)
            rrd_pos_impact = 0.5 * (
                (B.maximum(0, reference + total_pos_contribs) -
                 B.maximum(0, reference)) +
                (B.maximum(
                    0, reference - total_neg_contribs + total_pos_contribs) -
                 B.maximum(0, reference - total_neg_contribs)))
            rrd_neg_impact = 0.5 * (
                (B.maximum(0, reference - total_neg_contribs) -
                 B.maximum(0, reference)) +
                (B.maximum(
                    0, reference + total_pos_contribs - total_neg_contribs) -
                 B.maximum(0, reference + total_pos_contribs)))

            if (self.dense_mxts_mode in [
                    DenseMxtsMode.RevealCancelRedist2,
                    DenseMxtsMode.ContinuousShapely
            ]):

                #dims batch x output x input
                v = fwd_contribs
                r = reference[:, :, None]
                #pmax and nmax are the pos and neg contribs *absent* the
                #current contrib
                pmax = (total_pos_contribs[:, :, None] - (v * (v > 0)))
                nmax = (total_neg_contribs[:, :, None] - (v * (v < 0)))

                if (self.dense_mxts_mode == DenseMxtsMode.RevealCancelRedist2):
                    #we will make the simplying assumption that the three
                    #'players' are v, pmax and nmax
                    #possible orders are:
                    # v, pmax, nmax & v, nmax, pmax <- 1
                    # pmax, nmax, v & nmax, pmax, v <- 2
                    # pmax, v, nmax <- 3
                    # nmax, v, pmax <- 4
                    #Let's find the marginal contribs in all cases
                    # case 1:
                    c1 = 2 * (B.maximum(0, r + v) - B.maximum(0, r))
                    c2 = 2 * (B.maximum(0, r + (pmax - nmax) + v) -
                              B.maximum(0, r + (pmax - nmax)))
                    c3 = B.maximum(0, r + pmax + v) - B.maximum(0, r + pmax)
                    c4 = B.maximum(0, r - nmax + v) - B.maximum(0, r - nmax)
                    unscaled_contribs = (c1 + c2 + c3 + c4) / 6.0
                    #add an adjustment to make sure that the total contribs
                    #are the same.
                    total_contribs = total_pos_contribs - total_neg_contribs
                    zero_total_contribs_mask = B.eq(total_contribs, 0.0)
                    total_unscaled_contribs = B.sum(unscaled_contribs, axis=-1)
                    scale = (total_contribs /
                             pseudocount_near_zero(total_unscaled_contribs))
                    #in the 0.0/0.0 case where the scale is undefined, let
                    #the scale factor be 1.0
                    #(reuse the precomputed mask rather than recomputing
                    # B.eq(total_contribs, 0.0))
                    scale += 1.0 * (zero_total_contribs_mask *
                                    B.eq(total_unscaled_contribs, 0.0))
                    final_contribs = unscaled_contribs * scale[:, :, None]

                    #v is weight*diff-from-default
                    #The multiplier = final_contribs/diff_from_default
                    #               = weight*(final_contribs/v)
                    multiplier_adjustment = (final_contribs /
                                             pseudocount_near_zero(v))
                    #for the case when v is zero, we need to find the partial
                    #derivative of final_contribs w.r.t. v. This is:
                    partialdv = (
                        scale[:, :, None] *  #dfinal/dunscaled
                        #dunscaled/dv (split into c1..c4):
                        (2 * (r > 0.0) + 2 * (r + pmax - nmax > 0.0) +
                         (r + pmax > 0.0) + (r - nmax > 0.0)) / 6.0)
                    #add partialdv to multiplier_adjustment when v is zero
                    #(multiplier_adjustment should be 0 in those places
                    # before this line)
                    multiplier_adjustment += partialdv * B.eq(v, 0.0)

                    #dims of new_Wt: batch x output x input
                    new_Wt = self.W.T[None, :, :] * multiplier_adjustment
                    return B.sum(self.get_mxts()[:, :, None] * new_Wt[:, :, :],
                                 axis=1)

                elif (self.dense_mxts_mode == DenseMxtsMode.ContinuousShapely):
                    raise NotImplementedError(
                        "I scrapped this implementation; see git history for it"
                    )
                else:
                    #str() so the message formats even if the mode is not
                    #a plain string (consistent with the other raises here)
                    raise RuntimeError(str(self.dense_mxts_mode) +
                                       " not implemented")

            #positive and negative values grouped together for rescale:
            elif (self.dense_mxts_mode in [
                    DenseMxtsMode.Redist, DenseMxtsMode.Counterbalance,
                    DenseMxtsMode.RevealCancel,
                    DenseMxtsMode.RevealCancelRedist,
                    DenseMxtsMode.RevealCancelRedist_ThroughZeros
            ]):
                if (self.dense_mxts_mode == DenseMxtsMode.Redist or
                        self.dense_mxts_mode == DenseMxtsMode.Counterbalance):
                    #if output diff-from-def is positive but there are some neg
                    #contribs, temper positive by some portion of the neg
                    #to_distribute has dims batch x output
                    #neg_to_distribute is what dips below 0, accounting for ref
                    to_distribute = B.minimum(
                        B.maximum(
                            total_neg_contribs -
                            B.maximum(self.get_reference_vars(), 0.0), 0.0),
                        total_pos_contribs) / 2.0

                    #total_pos_contribs_new has dims batch x output
                    total_pos_contribs_new = total_pos_contribs - to_distribute
                    total_neg_contribs_new = total_neg_contribs - to_distribute
                elif (self.dense_mxts_mode in [
                        DenseMxtsMode.RevealCancel,
                        DenseMxtsMode.RevealCancelRedist,
                        DenseMxtsMode.RevealCancelRedist_ThroughZeros
                ]):

                    ##sanity check to see if we can implement the existing deeplift
                    #total_contribs = total_pos_contribs - total_neg_contribs
                    #effective_contribs = B.maximum(self.get_reference_vars() + total_contribs,0) -\
                    #                     B.maximum(self.get_reference_vars(),0)
                    #rescale = effective_contribs/total_contribs
                    #
                    #return B.sum(self.get_mxts()[:,:,None]*self.W.T[None,:,:]*rescale[:,:,None], axis=1)

                    total_pos_contribs_new =\
                     B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
                     -B.maximum(self.get_reference_vars(),0)
                    total_neg_contribs_new =\
                     B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
                     -B.maximum(self.get_reference_vars()
                                +total_pos_contribs-total_neg_contribs,0)
                    if (self.dense_mxts_mode in [
                            DenseMxtsMode.RevealCancelRedist,
                            DenseMxtsMode.RevealCancelRedist_ThroughZeros
                    ]):
                        #override with the averaged rrd impacts computed above
                        total_pos_contribs_new = rrd_pos_impact
                        total_neg_contribs_new = B.abs(rrd_neg_impact)
                        #to_distribute = B.minimum(
                        #    B.maximum(total_neg_contribs_new -
                        #              B.maximum(self.get_reference_vars(),0.0),0.0),
                        #    total_pos_contribs_new)/2.0
                        #total_pos_contribs_new = total_pos_contribs_new - to_distribute
                        #total_neg_contribs_new = total_neg_contribs_new - to_distribute
                else:
                    raise RuntimeError("Unsupported dense_mxts_mode: " +
                                       str(self.dense_mxts_mode))
                #positive_rescale has dims batch x output
                positive_rescale = total_pos_contribs_new/\
                                    pseudocount_near_zero(total_pos_contribs)
                negative_rescale = total_neg_contribs_new/\
                                    pseudocount_near_zero(total_neg_contribs)
                #new_Wt has dims batch x output x input
                new_Wt = self.W.T[None,:,:]*\
                          (fwd_contribs>0)*positive_rescale[:,:,None]
                new_Wt += self.W.T[None,:,:]*\
                           (fwd_contribs<0)*negative_rescale[:,:,None]
                if (self.dense_mxts_mode ==
                        DenseMxtsMode.RevealCancelRedist_ThroughZeros):
                    #for 0/0, set multiplier to half of pos and neg rescales
                    new_Wt += (self.W.T[None, :, :] *
                               (0.5 * (positive_rescale[:, :, None] +
                                       negative_rescale[:, :, None])) *
                               B.eq(fwd_contribs, 0.0))
            else:
                raise RuntimeError("Unsupported dense_mxts_mode: " +
                                   str(self.dense_mxts_mode))
            return B.sum(self.get_mxts()[:, :, None] * new_Wt[:, :, :], axis=1)

        elif (self.dense_mxts_mode == DenseMxtsMode.Linear):
            return B.dot(self.get_mxts(), self.W.T)
        else:
            raise RuntimeError("Unsupported mxts mode: " +
                               str(self.dense_mxts_mode))
Пример #5
0
    def _get_mxts_increments_for_inputs(self):
        """Backward pass: compute the multiplier increments passed down
        to this dense layer's inputs, per self.dense_mxts_mode.

        Returns the input multiplier increments (dims batch x input,
        per the inline dimension comments).

        Raises:
            RuntimeError: for unsupported dense_mxts_mode values.
        """
        #re. counterbalance: this modification is only appropriate
        #when the output is a relu. So when the output is not a relu,
        #hackily set the mode back to Linear.
        if (self.dense_mxts_mode in [
                DenseMxtsMode.RevealCancel, DenseMxtsMode.Redist,
                DenseMxtsMode.Counterbalance, DenseMxtsMode.RevealCancelRedist
        ]):
            if (len(self.get_output_layers()) != 1
                    or (type(self.get_output_layers()[0]).__name__ != "ReLU")):
                #include the mode being reverted in the message (the
                #previous message read "...from to Linear" because the
                #mode name was never interpolated)
                print("Dense layer does not have sole output of ReLU so"
                      " cautiously reverting DenseMxtsMode from "
                      + str(self.dense_mxts_mode) + " to Linear")
                self.dense_mxts_mode = DenseMxtsMode.Linear

        if (self.dense_mxts_mode == DenseMxtsMode.PosOnly):
            #propagate only the positive multipliers through the weights
            return B.dot(self.get_mxts() * (self.get_mxts() > 0.0), self.W.T)

        elif (self.dense_mxts_mode in [
                DenseMxtsMode.RevealCancel, DenseMxtsMode.Redist,
                DenseMxtsMode.Counterbalance, DenseMxtsMode.RevealCancelRedist
        ]):
            #self.W has dims input x output; W.T is output x input
            #self._get_input_diff_from_reference_vars() has dims batch x input
            #fwd_contribs has dims batch x output x input
            fwd_contribs = self._get_input_diff_from_reference_vars()[:,None,:]\
                           *self.W.T[None,:,:]

            #total_pos_contribs and total_neg_contribs have dim batch x output
            total_pos_contribs = B.sum(fwd_contribs * (fwd_contribs > 0),
                                       axis=-1)
            total_neg_contribs = B.abs(
                B.sum(fwd_contribs * (fwd_contribs < 0), axis=-1))
            if (self.dense_mxts_mode == DenseMxtsMode.Redist
                    or self.dense_mxts_mode == DenseMxtsMode.Counterbalance):
                #if output diff-from-def is positive but there are some neg
                #contribs, temper positive by some portion of the neg
                #to_distribute has dims batch x output
                #neg_to_distribute is what dips below 0, accounting for ref
                to_distribute = B.minimum(
                    B.maximum(
                        total_neg_contribs -
                        B.maximum(self.get_reference_vars(), 0.0), 0.0),
                    total_pos_contribs) / 2.0

                #total_pos_contribs_new has dims batch x output
                total_pos_contribs_new = total_pos_contribs - to_distribute
                total_neg_contribs_new = total_neg_contribs - to_distribute
            elif (self.dense_mxts_mode in [
                    DenseMxtsMode.RevealCancel,
                    DenseMxtsMode.RevealCancelRedist
            ]):

                ##sanity check to see if we can implement the existing deeplift
                #total_contribs = total_pos_contribs - total_neg_contribs
                #effective_contribs = B.maximum(self.get_reference_vars() + total_contribs,0) -\
                #                     B.maximum(self.get_reference_vars(),0)
                #rescale = effective_contribs/total_contribs
                #
                #return B.sum(self.get_mxts()[:,:,None]*self.W.T[None,:,:]*rescale[:,:,None], axis=1)

                total_pos_contribs_new =\
                 B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
                 -B.maximum(self.get_reference_vars(),0)
                total_neg_contribs_new =\
                 B.maximum(self.get_reference_vars()+total_pos_contribs,0)\
                 -B.maximum(self.get_reference_vars()
                            +total_pos_contribs-total_neg_contribs,0)
                if (self.dense_mxts_mode == DenseMxtsMode.RevealCancelRedist):
                    #redistribute the portion of the negatives that dips
                    #below zero (accounting for the reference)
                    to_distribute = B.minimum(
                        B.maximum(
                            total_neg_contribs_new -
                            B.maximum(self.get_reference_vars(), 0.0), 0.0),
                        total_pos_contribs_new) / 2.0
                    total_pos_contribs_new = total_pos_contribs_new - to_distribute
                    total_neg_contribs_new = total_neg_contribs_new - to_distribute
            else:
                raise RuntimeError("Unsupported dense_mxts_mode: " +
                                   str(self.dense_mxts_mode))
            #positive_rescale has dims batch x output
            positive_rescale = total_pos_contribs_new/\
                                pseudocount_near_zero(total_pos_contribs)
            negative_rescale = total_neg_contribs_new/\
                                pseudocount_near_zero(total_neg_contribs)
            #new_Wt has dims batch x output x input
            new_Wt = self.W.T[None,:,:]*\
                      (fwd_contribs>0)*positive_rescale[:,:,None]
            new_Wt += self.W.T[None,:,:]*\
                       (fwd_contribs<0)*negative_rescale[:,:,None]
            return B.sum(self.get_mxts()[:, :, None] * new_Wt[:, :, :], axis=1)

        elif (self.dense_mxts_mode == DenseMxtsMode.Linear):
            return B.dot(self.get_mxts(), self.W.T)
        else:
            raise RuntimeError("Unsupported mxts mode: " +
                               str(self.dense_mxts_mode))