def _get_mxts_increments_for_inputs(self):
        """Return the (pos, neg) multiplier increments for the inputs.

        Only ``DenseMxtsMode.Linear`` is handled. The result is a sum of
        terms, each of which is a matmul of the pos or neg output
        multipliers with a sign-filtered copy of the transposed kernel,
        multiplied by a mask on the sign of the input's
        difference-from-reference. Every term is switched on or off by
        a flag (``_pos_to_pos_mxts``, ``_neg_to_pos_mxts``,
        ``_pos_to_neg_mxts``, ``_neg_to_neg_mxts``, ``_zero_mxts``) that
        is resolved from outside this method.

        Returns:
            A 2-tuple containing the same accumulated value twice; it
            stays the plain integer 0 when every flag is falsy.

        Raises:
            RuntimeError: if ``self.dense_mxts_mode`` is not
                ``DenseMxtsMode.Linear``.
        """
        if self.dense_mxts_mode != DenseMxtsMode.Linear:
            raise RuntimeError("Unsupported mxts mode: "
                               +str(self.dense_mxts_mode))

        #the sign of the input's diff-from-reference decides which
        #multipliers an input inherits; the input's own pos/neg
        #breakdown plays no part in this computation
        diff_from_ref = self._get_input_diff_from_reference_vars()
        input_is_pos = hf.gt_mask(diff_from_ref, 0.0)
        input_is_neg = hf.lt_mask(diff_from_ref, 0.0)
        input_is_zero = hf.eq_mask(diff_from_ref, 0.0)

        weights_T = tf.transpose(self.kernel)

        #each row: (flag, input-sign mask, getter for the output
        #multipliers, mask function selecting the kernel sign to keep)
        routes = (
            (_pos_to_pos_mxts, input_is_pos, self.get_pos_mxts, hf.gt_mask),
            (_neg_to_pos_mxts, input_is_pos, self.get_neg_mxts, hf.lt_mask),
            (_pos_to_neg_mxts, input_is_neg, self.get_pos_mxts, hf.lt_mask),
            (_neg_to_neg_mxts, input_is_neg, self.get_neg_mxts, hf.gt_mask),
        )

        total = 0
        for enabled, input_mask, get_out_mxts, kernel_sign_mask in routes:
            if enabled:
                total += input_mask*tf.matmul(
                    get_out_mxts(),
                    weights_T*kernel_sign_mask(weights_T, 0.0))

        if _zero_mxts:
            #inputs whose diff-from-reference is exactly zero use the
            #average of the pos and neg multipliers with the full kernel
            mean_out_mxts = 0.5*(self.get_pos_mxts()+self.get_neg_mxts())
            total += input_is_zero*tf.matmul(mean_out_mxts, weights_T)

        #the same value is handed back for both the pos and the neg
        #slot because nothing above distinguishes between them
        return total, total
예제 #2
0
 def _get_naive_rescale_factor(self):
     """Return the scale factor blending a gradient and a ratio term.

     Where the absolute input difference-from-reference is below
     ``NEAR_ZERO_THRESHOLD`` the factor is taken from
     ``self._get_gradient_at_default_activation_var()``; elsewhere it is
     the output difference-from-reference divided by the input
     difference-from-reference. The two cases are combined through
     complementary masks (this assumes ``hf.lt_mask`` yields 1 where
     the comparison holds and 0 elsewhere - confirm against ``hf``).
     """
     inp_diff = self._get_input_diff_from_reference_vars()
     is_near_zero = hf.lt_mask(tf.abs(inp_diff), NEAR_ZERO_THRESHOLD)
     is_far_from_zero = 1.0-is_near_zero

     #pseudocount: shift the denominator by 1.0 at the near-zero
     #positions only, so the division below cannot hit zero there;
     #those positions are masked out of the ratio term anyway
     safe_denominator = inp_diff + (1.0*is_near_zero)

     #near zero: fall back to the gradient at the default activation
     gradient_term = (
         is_near_zero*self._get_gradient_at_default_activation_var())
     #far from zero: ratio of output difference to input difference
     ratio_term = is_far_from_zero*(
         self._get_diff_from_reference_vars()/safe_denominator)

     return gradient_term + ratio_term
예제 #3
0
 def _build_activation_vars(self, input_act_vars):
     """Return relu(input) plus ``self.alpha``-scaled negative inputs.

     The second term is the input multiplied by
     ``hf.lt_mask(input, 0.0)`` and by ``self.alpha``, so it only
     contributes where that mask is non-zero.
     """
     rectified = tf.nn.relu(input_act_vars)
     below_zero = hf.lt_mask(input_act_vars,0.0)
     leak = below_zero*input_act_vars*self.alpha
     return rectified + leak
    def _get_mxts_increments_for_inputs(self):
        """Return the (pos, neg) multiplier increments for the inputs.

        Only ``ConvMxtsMode.Linear`` is handled. The result is a sum of
        terms, each of which is a ``tf.nn.conv2d_transpose`` of the pos
        or neg output multipliers with a sign-filtered copy of the
        kernel, multiplied by a mask on the sign of the input's
        difference-from-reference. Every term is switched on or off by
        a flag (``_pos_to_pos_mxts``, ``_neg_to_pos_mxts``,
        ``_pos_to_neg_mxts``, ``_neg_to_neg_mxts``, ``_zero_mxts``) that
        is resolved from outside this method.

        When ``self.data_format`` is ``DataFormat.channels_first`` the
        tensors are permuted with (0,2,3,1) before the computation and
        the results are permuted back with (0,3,1,2).

        Returns:
            A 2-tuple (pos_mxts_increments, neg_mxts_increments), both
            derived from the same accumulated value.

        Raises:
            RuntimeError: if ``self.conv_mxts_mode`` is not
                ``ConvMxtsMode.Linear``.
        """
        out_pos_mxts = self.get_pos_mxts()
        out_neg_mxts = self.get_neg_mxts()
        diff_from_ref = self._get_input_diff_from_reference_vars()
        input_acts = self.inputs.get_activation_vars()
        full_strides = [1]+list(self.strides)+[1]

        channels_first = (self.data_format == DataFormat.channels_first)
        if channels_first:
            #move axis 1 to the end for the transposed convolutions
            to_channels_last = (0,2,3,1)
            out_pos_mxts = tf.transpose(a=out_pos_mxts,
                                        perm=to_channels_last)
            out_neg_mxts = tf.transpose(a=out_neg_mxts,
                                        perm=to_channels_last)
            diff_from_ref = tf.transpose(a=diff_from_ref,
                                         perm=to_channels_last)
            input_acts = tf.transpose(a=input_acts,
                                      perm=to_channels_last)

        #every transposed convolution is shaped like the layer input
        target_shape = tf.shape(input_acts)

        if self.conv_mxts_mode != ConvMxtsMode.Linear:
            raise RuntimeError("Unsupported conv mxts mode: "
                               +str(self.conv_mxts_mode))

        def _transposed_conv(mxts, kernel):
            #shared call: only the multipliers and the kernel vary
            return tf.nn.conv2d_transpose(
                value=mxts,
                filter=kernel,
                output_shape=target_shape,
                padding=self.padding,
                strides=full_strides)

        input_is_pos = hf.gt_mask(diff_from_ref,0.0)
        input_is_neg = hf.lt_mask(diff_from_ref,0.0)
        input_is_zero = hf.eq_mask(diff_from_ref, 0.0)

        total = 0

        if _pos_to_pos_mxts:
            total += input_is_pos*_transposed_conv(
                out_pos_mxts, self.kernel*hf.gt_mask(self.kernel, 0.0))
        if _neg_to_pos_mxts:
            total += input_is_pos*_transposed_conv(
                out_neg_mxts, self.kernel*hf.lt_mask(self.kernel, 0.0))
        if _pos_to_neg_mxts:
            total += input_is_neg*_transposed_conv(
                out_pos_mxts, self.kernel*hf.lt_mask(self.kernel, 0.0))
        if _neg_to_neg_mxts:
            total += input_is_neg*_transposed_conv(
                out_neg_mxts, self.kernel*hf.gt_mask(self.kernel, 0.0))
        if _zero_mxts:
            #inputs whose diff-from-reference is exactly zero use the
            #average of the pos and neg multipliers with the full kernel
            total += input_is_zero*_transposed_conv(
                0.5*(out_pos_mxts+out_neg_mxts), self.kernel)

        #pos and neg increments start out as the same value
        pos_mxts_increments = neg_mxts_increments = total

        if channels_first:
            #restore the original axis order
            to_channels_first = (0,3,1,2)
            pos_mxts_increments = tf.transpose(a=pos_mxts_increments,
                                               perm=to_channels_first)
            neg_mxts_increments = tf.transpose(a=neg_mxts_increments,
                                               perm=to_channels_first)

        return pos_mxts_increments, neg_mxts_increments