def testConcatRegularizer(self):
    """Checks ConcatRegularizer against the combined vectors of its inputs.

    Builds a ConcatRegularizer from `self._reg1` and `self._reg2` and compares
    its alive and regularization vectors with `vec1 + vec2` of the fixtures.
    """
    # NOTE(review): the fixtures are presumably Python lists, so `+` below
    # concatenates rather than adds element-wise -- confirm against setUp.
    source_regularizers = [self._reg1, self._reg2]
    concat_reg = concat_and_slice_regularizers.ConcatRegularizer(
        source_regularizers)

    with self.cached_session():
        expected_alive = self._alive_vec1 + self._alive_vec2
        actual_alive = concat_reg.alive_vector.eval()
        self.assertAllEqual(expected_alive, actual_alive)

        expected_regularization = self._reg_vec1 + self._reg_vec2
        actual_regularization = concat_reg.regularization_vector.eval()
        # The third positional argument (1e-5) is the comparison tolerance.
        self.assertAllClose(
            expected_regularization, actual_regularization, 1e-5)
def _create_concat_regularizer(self, concat_op):
    """Creates an OpRegularizer for a concat op.

    Every input that already has a regularizer gets that regularizer replaced
    (via `self._replace_regularizer`) by a SlicingReferenceRegularizer that
    lazily slices the regularizer of `concat_op` itself. Inputs without a
    regularizer contribute `self._ConstantOpReg(length)`, where `length` is the
    last static dimension of the input op's first output.

    Args:
      concat_op: A tf.Operation of type ConcatV2.

    Returns:
      An OpRegularizer for `concat_op`, or None when none of the inputs has a
      regularizer.
    """
    # The last input is the concat dimension; all the others are the tensors
    # being concatenated.
    input_ops = [tensor.op for tensor in concat_op.inputs[:-1]]
    input_regularizers = [self._get_regularizer(op) for op in input_ops]

    # If no input has a regularizer, neither does the concat op. Use an identity
    # check against None instead of list equality, so the outcome never depends
    # on how a regularizer implements `__eq__`.
    if all(reg is None for reg in input_regularizers):
        return None

    def get_concat_regularizer():
        # Resolved lazily: the concat regularizer does not exist yet while the
        # slicing references are being constructed.
        return self._get_regularizer(concat_op)

    offset = 0
    regularizers_to_concat = []
    for reg, op in zip(input_regularizers, input_ops):
        if reg is None:
            length = op.outputs[0].shape.as_list()[-1]
            offset += length
            regularizers_to_concat.append(self._ConstantOpReg(length))
        else:
            length = tf.shape(reg.alive_vector)[0]
            # The slice must start at the offset *before* this input is
            # counted, so the reference is built ahead of the increment.
            slice_ref = concat_and_slice_regularizers.SlicingReferenceRegularizer(
                get_concat_regularizer, offset, length)
            offset += length
            self._replace_regularizer(reg, slice_ref)
            regularizers_to_concat.append(reg)

    # Create the concatenated regularizer itself.
    return concat_and_slice_regularizers.ConcatRegularizer(
        regularizers_to_concat)
def get_regularizer(self, op):
    """Returns an OpRegularizer for the specified op.

    If no OpRegularizer exists for any slices in the op, returns None.
    Otherwise, creates a ConstantOpRegularizer for any slices that are missing
    a regularizer and stores it in `self._op_regularizer_dict`.

    Args:
      op: A tf.Operation.

    Returns:
      None if no slice of `op` has a regularizer; the single slice's
      regularizer if `op` has exactly one slice; otherwise a ConcatRegularizer
      over the regularizers of all slices.
    """
    op_slices = self.get_op_slices(op)

    # "Missing" means exactly None. Testing truthiness here would wrongly treat
    # a registered regularizer that happens to be falsy (e.g. one defining
    # `__len__` or `__bool__`) as absent, disagreeing with the `is None` check
    # used below.
    if all(self._op_regularizer_dict.get(op_slice) is None
           for op_slice in op_slices):
        return None

    regularizers = []
    for op_slice in op_slices:
        # Look the slice up again on purpose: if the same slice occurs twice,
        # the second occurrence must reuse the constant regularizer that was
        # just created for the first.
        regularizer = self._op_regularizer_dict.get(op_slice)
        if regularizer is None:
            regularizer = constant_op_regularizer.ConstantOpRegularizer(
                op_slice.slice.size)
            self._op_regularizer_dict[op_slice] = regularizer
        regularizers.append(regularizer)

    # If op only has 1 OpSlice, return the regularizer for that OpSlice.
    # Otherwise, return the concatenation of the constituent regularizers.
    if len(regularizers) == 1:
        return regularizers[0]
    return concat_and_slice_regularizers.ConcatRegularizer(regularizers)