def _get_mxts_increments_for_inputs(self):
    """Build the multiplier increments this dense layer hands to its inputs.

    Only ``DenseMxtsMode.Linear`` is handled. Each input is routed
    according to the sign of its difference-from-reference, because that
    sign (and not a pos/neg breakdown of the input itself) is what decided
    the positive and negative contributions on the forward pass.

    Every routing term is switched by its own flag (``_pos_to_pos_mxts``,
    ``_neg_to_pos_mxts``, ``_pos_to_neg_mxts``, ``_neg_to_neg_mxts``,
    ``_zero_mxts``), all of which are resolved outside this method.

    Returns:
        A 2-tuple containing the same increments object twice: the
        positive and negative input multipliers receive identical
        increments. If every flag is falsy the object is the integer 0.

    Raises:
        RuntimeError: if ``self.dense_mxts_mode`` is not
            ``DenseMxtsMode.Linear``.
    """
    linear_mode = (self.dense_mxts_mode == DenseMxtsMode.Linear)
    if not linear_mode:
        raise RuntimeError("Unsupported mxts mode: "
                           + str(self.dense_mxts_mode))

    diff_from_ref = self._get_input_diff_from_reference_vars()
    # Masks over the inputs, split by the sign of the diff-from-reference
    # (presumably 0/1 indicators -- confirm against the hf helpers).
    rising_inputs = hf.gt_mask(diff_from_ref, 0.0)
    falling_inputs = hf.lt_mask(diff_from_ref, 0.0)
    unchanged_inputs = hf.eq_mask(diff_from_ref, 0.0)
    weights_T = tf.transpose(self.kernel)

    def _routed(input_mask, mxts, weights):
        # Push multipliers back through the given (possibly sign-filtered)
        # transposed weights, then keep only the selected inputs.
        return input_mask * tf.matmul(mxts, weights)

    increments = 0
    # Inputs with a positive diff: positive multipliers travel through the
    # positive weights, negative multipliers through the negative weights.
    if _pos_to_pos_mxts:
        increments += _routed(rising_inputs,
                              self.get_pos_mxts(),
                              weights_T * (hf.gt_mask(weights_T, 0.0)))
    if _neg_to_pos_mxts:
        increments += _routed(rising_inputs,
                              self.get_neg_mxts(),
                              weights_T * (hf.lt_mask(weights_T, 0.0)))
    # Inputs with a negative diff: the weight signs swap roles.
    if _pos_to_neg_mxts:
        increments += _routed(falling_inputs,
                              self.get_pos_mxts(),
                              weights_T * (hf.lt_mask(weights_T, 0.0)))
    if _neg_to_neg_mxts:
        increments += _routed(falling_inputs,
                              self.get_neg_mxts(),
                              weights_T * (hf.gt_mask(weights_T, 0.0)))
    # Inputs with exactly zero diff: average the two multiplier sets and
    # use the unfiltered weights.
    if _zero_mxts:
        increments += _routed(
            unchanged_inputs,
            0.5 * (self.get_pos_mxts() + self.get_neg_mxts()),
            weights_T)

    # The forward pass never used the input's own pos/neg breakdown (only
    # diff_from_ref), so both input multiplier slots get the same value.
    return increments, increments
def _get_naive_rescale_factor(self):
    """Return the rescale factor: output diff over input diff.

    Where the absolute input difference-from-reference is below
    ``NEAR_ZERO_THRESHOLD``, the ratio is replaced by the value of
    ``self._get_gradient_at_default_activation_var()``. Elsewhere the
    factor is ``self._get_diff_from_reference_vars()`` divided by the
    input difference-from-reference.

    Returns:
        The elementwise combination of the two cases described above.
    """
    input_delta = self._get_input_diff_from_reference_vars()
    # Presumably a 0/1 indicator of |input_delta| < threshold -- confirm
    # against hf.lt_mask.
    tiny_delta = hf.lt_mask(tf.abs(input_delta), NEAR_ZERO_THRESHOLD)
    sizeable_delta = 1.0 - tiny_delta

    # Pseudocount: add 1 where the delta is near zero so the division
    # below cannot hit zero. Those positions are masked out of the ratio
    # term anyway, so the pseudocount never affects a value that is used.
    safe_denominator = input_delta + (1.0 * tiny_delta)

    # Near-zero delta -> use the gradient at the default activation.
    gradient_term = (tiny_delta
                     * self._get_gradient_at_default_activation_var())
    # Otherwise -> use the actual output-diff / input-diff ratio.
    ratio_term = (sizeable_delta
                  * (self._get_diff_from_reference_vars()
                     / safe_denominator))
    return gradient_term + ratio_term
def _build_activation_vars(self, input_act_vars):
    """Build this layer's activations from its input activations.

    The result is ``relu(x)`` plus ``mask * x * self.alpha``, where
    ``mask`` is ``hf.lt_mask(x, 0.0)`` (presumably a 0/1 indicator of the
    entries below zero -- confirm against the hf helpers).

    Args:
        input_act_vars: activations of the input to this layer.

    Returns:
        The combined activation described above.
    """
    rectified = tf.nn.relu(input_act_vars)
    below_zero = hf.lt_mask(input_act_vars, 0.0)
    # Only entries selected by the mask contribute the alpha-scaled term.
    scaled_negative_part = below_zero * input_act_vars * self.alpha
    return rectified + scaled_negative_part
def _get_mxts_increments_for_inputs(self):
    """Build the multiplier increments this conv layer hands to its inputs.

    Multipliers are pushed back through ``tf.nn.conv2d_transpose`` using
    sign-filtered copies of ``self.kernel``; which filtered kernel is used
    depends on the sign of each input's difference-from-reference. Each
    routing term is switched by its own flag (``_pos_to_pos_mxts``,
    ``_neg_to_pos_mxts``, ``_pos_to_neg_mxts``, ``_neg_to_neg_mxts``,
    ``_zero_mxts``), all resolved outside this method.

    When ``self.data_format`` is ``DataFormat.channels_first`` the tensors
    are permuted with ``(0, 2, 3, 1)`` before the computation and the
    results are permuted back with ``(0, 3, 1, 2)`` afterwards.

    Returns:
        A ``(pos_mxts_increments, neg_mxts_increments)`` pair. Both are
        derived from the same increments value.

    Raises:
        RuntimeError: if ``self.conv_mxts_mode`` is not
            ``ConvMxtsMode.Linear``.
    """
    pos_mxts = self.get_pos_mxts()
    neg_mxts = self.get_neg_mxts()
    diff_from_ref = self._get_input_diff_from_reference_vars()
    input_acts = self.inputs.get_activation_vars()
    conv_strides = [1] + list(self.strides) + [1]

    if (self.data_format == DataFormat.channels_first):
        # Move axis 1 to the end (presumably NCHW -> NHWC; the conv op
        # below is called without an explicit data_format argument).
        axis1_to_last = (0, 2, 3, 1)
        pos_mxts, neg_mxts, diff_from_ref, input_acts = [
            tf.transpose(a=tensor, perm=axis1_to_last)
            for tensor in (pos_mxts, neg_mxts, diff_from_ref, input_acts)]
    # The transposed conv must produce something shaped like the input.
    input_shape = tf.shape(input_acts)

    linear_mode = (self.conv_mxts_mode == ConvMxtsMode.Linear)
    if not linear_mode:
        raise RuntimeError("Unsupported conv mxts mode: "
                           + str(self.conv_mxts_mode))

    # Masks over the inputs, split by the sign of the diff-from-reference.
    rising_inputs = hf.gt_mask(diff_from_ref, 0.0)
    falling_inputs = hf.lt_mask(diff_from_ref, 0.0)
    unchanged_inputs = hf.eq_mask(diff_from_ref, 0.0)

    def _back_through(mxts, kernel):
        # Push multipliers back to input space through the given kernel.
        return tf.nn.conv2d_transpose(
            value=mxts,
            filter=kernel,
            output_shape=input_shape,
            padding=self.padding,
            strides=conv_strides)

    increments = 0
    # Inputs with a positive diff: positive multipliers travel through the
    # positive kernel entries, negative multipliers through the negative.
    if _pos_to_pos_mxts:
        increments += rising_inputs * (
            _back_through(pos_mxts,
                          self.kernel * hf.gt_mask(self.kernel, 0.0)))
    if _neg_to_pos_mxts:
        increments += rising_inputs * (
            _back_through(neg_mxts,
                          self.kernel * hf.lt_mask(self.kernel, 0.0)))
    # Inputs with a negative diff: the kernel signs swap roles.
    if _pos_to_neg_mxts:
        increments += falling_inputs * (
            _back_through(pos_mxts,
                          self.kernel * hf.lt_mask(self.kernel, 0.0)))
    if _neg_to_neg_mxts:
        increments += falling_inputs * (
            _back_through(neg_mxts,
                          self.kernel * hf.gt_mask(self.kernel, 0.0)))
    # Inputs with exactly zero diff: average the two multiplier sets and
    # use the unfiltered kernel.
    if _zero_mxts:
        increments += unchanged_inputs * _back_through(
            0.5 * (pos_mxts + neg_mxts), self.kernel)

    pos_mxts_increments = increments
    neg_mxts_increments = increments

    if (self.data_format == DataFormat.channels_first):
        # Undo the earlier permutation so the outputs match the layout
        # the inputs arrived in.
        last_to_axis1 = (0, 3, 1, 2)
        pos_mxts_increments = tf.transpose(a=pos_mxts_increments,
                                           perm=last_to_axis1)
        neg_mxts_increments = tf.transpose(a=neg_mxts_increments,
                                           perm=last_to_axis1)
    return pos_mxts_increments, neg_mxts_increments