def testConcatRegularizer(self):
   """Checks a two-input ConcatRegularizer against the fixtures' `vec1 + vec2`."""
   regularizer = concat_and_slice_regularizers.ConcatRegularizer(
       [self._reg1, self._reg2])
   with self.cached_session():
     # Alive vector must match the combined fixtures exactly.
     expected_alive = self._alive_vec1 + self._alive_vec2
     self.assertAllEqual(expected_alive, regularizer.alive_vector.eval())
     # Regularization vector is compared approximately; 1e-5 is the third
     # positional argument of assertAllClose (rtol in tf.test.TestCase).
     expected_reg = self._reg_vec1 + self._reg_vec2
     self.assertAllClose(expected_reg,
                         regularizer.regularization_vector.eval(), 1e-5)
    def _create_concat_regularizer(self, concat_op):
        """Creates an OpRegularizer for a concat op.

    Args:
      concat_op: A tf.Operation of type ConcatV2.

    Returns:
      An OpRegularizer for `concat_op`.
    """
        # We omit the last input, because it's the concat dimension. Others are
        # the tensors to be concatenated.
        input_ops = [i.op for i in concat_op.inputs[:-1]]
        regularizers_to_concat = [
            self._get_regularizer(op) for op in input_ops
        ]
        # If all inputs have no regularizer, so does the concat op.
        if regularizers_to_concat == [None] * len(regularizers_to_concat):
            return None
        offset = 0

        # Replace the regularizers_to_concat by SlicingReferenceRegularizer-s that
        # slice the concatenated regularizer.
        ops_to_concat = []
        for r, op in zip(regularizers_to_concat, input_ops):
            if r is None:
                length = op.outputs[0].shape.as_list()[-1]
                offset += length
                ops_to_concat.append(self._ConstantOpReg(length))
            else:
                length = tf.shape(r.alive_vector)[0]
                slice_ref = concat_and_slice_regularizers.SlicingReferenceRegularizer(
                    lambda: self._get_regularizer(concat_op), offset, length)
                offset += length
                self._replace_regularizer(r, slice_ref)
                ops_to_concat.append(r)

        # Create the concatenated regularizer itself.
        return concat_and_slice_regularizers.ConcatRegularizer(ops_to_concat)
    def get_regularizer(self, op):
        """Returns an OpRegularizer for the specified op.

        If no OpRegularizer exists for any slices in the op, returns None.
        Otherwise, create a ConstantOpRegularizer for any slices that are
        missing a regularizer and store it in `_op_regularizer_dict`.

        Args:
          op: A tf.Operation.

        Returns:
          An OpRegularizer for op, or None if no OpRegularizer exists. If op
          has a single OpSlice, that slice's regularizer; otherwise a
          ConcatRegularizer over the regularizers of its OpSlices, in order.
        """
        # Materialize once: the slices are iterated up to twice below.
        op_slices = list(self.get_op_slices(op))
        # If all OpSlice have None regularizer, return None. Test against None
        # explicitly so a regularizer is never judged by its truthiness.
        if all(self._op_regularizer_dict.get(op_slice) is None
               for op_slice in op_slices):
            return None

        regularizers = []
        for op_slice in op_slices:
            # Re-read the dict each time so a slice filled in earlier in this
            # loop is reused rather than recreated.
            regularizer = self._op_regularizer_dict.get(op_slice)
            if regularizer is None:
                regularizer = constant_op_regularizer.ConstantOpRegularizer(
                    op_slice.slice.size)
                self._op_regularizer_dict[op_slice] = regularizer
            regularizers.append(regularizer)

        # If op only has 1 OpSlice, return the regularizer for that OpSlice.
        # Otherwise, return the concatenation of regularizers for the constituent
        # OpSlice.
        if len(regularizers) == 1:
            return regularizers[0]
        return concat_and_slice_regularizers.ConcatRegularizer(regularizers)