def testSliceRegularizer(self):
    concat_reg = concat_and_slice_regularizers.SlicingReferenceRegularizer(
        lambda: self._reg1, 1, 2)
    with self.cached_session():
        self.assertAllEqual(self._alive_vec1[1:3],
                            concat_reg.alive_vector.eval())
        self.assertAllClose(self._reg_vec1[1:3],
                            concat_reg.regularization_vector.eval(), 1e-5)
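For context, the constructor arguments `(lambda: self._reg1, 1, 2)` mean: resolve the regularizer returned by the zero-argument callable and expose the [1:3] slice of its alive and regularization vectors, which is exactly what the two assertions check. Passing a callable rather than the regularizer itself lets the wrapped regularizer be looked up later instead of at construction time (see the `lambda: self._get_regularizer(concat_op)` call in Example #2). The following NumPy-only sketch mimics that behaviour; `FakeOpRegularizer` and `ToySlicingReference` are hypothetical names used for illustration, not part of the `concat_and_slice_regularizers` module.

# Minimal NumPy-only sketch of the slicing behaviour exercised above.
# FakeOpRegularizer and ToySlicingReference are illustrative stand-ins,
# not the library's own classes.
import numpy as np


class FakeOpRegularizer(object):
    """Holds fixed alive/regularization vectors, like the test fixtures."""

    def __init__(self, alive_vec, reg_vec):
        self.alive_vector = np.asarray(alive_vec)
        self.regularization_vector = np.asarray(reg_vec)


class ToySlicingReference(object):
    """Exposes a [begin:begin + size] slice of another regularizer's vectors."""

    def __init__(self, get_regularizer, begin, size):
        # The regularizer is fetched through a callable so it can be resolved
        # lazily, mirroring the `lambda: self._reg1` pattern in the test.
        self._get_regularizer = get_regularizer
        self._begin = begin
        self._size = size

    @property
    def alive_vector(self):
        reg = self._get_regularizer()
        return reg.alive_vector[self._begin:self._begin + self._size]

    @property
    def regularization_vector(self):
        reg = self._get_regularizer()
        return reg.regularization_vector[self._begin:self._begin + self._size]


reg1 = FakeOpRegularizer([True, False, True, True], [0.1, 0.2, 0.3, 0.4])
sliced = ToySlicingReference(lambda: reg1, begin=1, size=2)
print(sliced.alive_vector)           # [False  True] -> reg1.alive_vector[1:3]
print(sliced.regularization_vector)  # [0.2 0.3]     -> reg1.regularization_vector[1:3]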
Example #2
def _create_concat_regularizer(self, concat_op):
    """Creates an OpRegularizer for a concat op.

    Args:
      concat_op: A tf.Operation of type ConcatV2.

    Returns:
      An OpRegularizer for `concat_op`.
    """
    # We omit the last input because it is the concat dimension; the other
    # inputs are the tensors to be concatenated.
    input_ops = [i.op for i in concat_op.inputs[:-1]]
    regularizers_to_concat = [self._get_regularizer(op) for op in input_ops]
    # If none of the inputs has a regularizer, neither does the concat op.
    if regularizers_to_concat == [None] * len(regularizers_to_concat):
        return None
    offset = 0

    # Replace each regularizer in regularizers_to_concat with a
    # SlicingReferenceRegularizer that slices the concatenated regularizer.
    ops_to_concat = []
    for r, op in zip(regularizers_to_concat, input_ops):
        if r is None:
            length = op.outputs[0].shape.as_list()[-1]
            offset += length
            ops_to_concat.append(self._ConstantOpReg(length))
        else:
            length = tf.shape(r.alive_vector)[0]
            slice_ref = concat_and_slice_regularizers.SlicingReferenceRegularizer(
                lambda: self._get_regularizer(concat_op), offset, length)
            offset += length
            self._replace_regularizer(r, slice_ref)
            ops_to_concat.append(r)

    # Create the concatenated regularizer itself.
    return concat_and_slice_regularizers.ConcatRegularizer(ops_to_concat)
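Two details in the loop are worth noting. `offset` and `length` are passed to SlicingReferenceRegularizer as ordinary constructor arguments, so each slice reference captures their values at the moment it is created; only the lookup of the concatenated regularizer is deferred through `lambda: self._get_regularizer(concat_op)`, presumably because that regularizer is the very object this function is in the middle of building. The NumPy-only walk-through below illustrates the resulting offset bookkeeping for two hypothetical inputs with 3 and 2 output channels; it is a sketch of the arithmetic, not library code.

# Illustration only (hypothetical channel counts, NumPy instead of TensorFlow):
# how the offsets computed above line up with the concatenated vector.
import numpy as np

input_reg_vectors = [
    np.array([0.5, 0.1, 0.9]),  # regularization vector of input op A (3 channels)
    np.array([0.7, 0.2]),       # regularization vector of input op B (2 channels)
]

# Concatenation along the channel dimension, roughly what the concatenated
# regularizer exposes for the concat op.
concat_vector = np.concatenate(input_reg_vectors)

# The same offset/length walk as in the loop above.
offset = 0
slices = []
for vec in input_reg_vectors:
    length = len(vec)
    slices.append((offset, length))  # (offset, length) handed to each slice reference
    offset += length

print(slices)  # [(0, 3), (3, 2)]

# Slicing the concatenated vector at each (offset, length) recovers the
# per-input vectors, which is what the slicing references rely on.
for (begin, size), vec in zip(slices, input_reg_vectors):
    assert np.array_equal(concat_vector[begin:begin + size], vec)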