def testSliceRegularizer(self):
    """A SlicingReferenceRegularizer exposes a window of the referenced vectors.

    The slice starts at index 1 and spans 2 entries, so its alive and
    regularization vectors must match elements [1, 3) of the fixture vectors.
    """
    start, length = 1, 2
    slice_reg = concat_and_slice_regularizers.SlicingReferenceRegularizer(
        lambda: self._reg1, start, length)
    window = slice(start, start + length)
    with self.cached_session():
        self.assertAllEqual(self._alive_vec1[window],
                            slice_reg.alive_vector.eval())
        self.assertAllClose(self._reg_vec1[window],
                            slice_reg.regularization_vector.eval(),
                            rtol=1e-5)
def _create_concat_regularizer(self, concat_op):
    """Creates an OpRegularizer for a concat op.

    The result is a ConcatRegularizer over the regularizers of the concat's
    inputs. An input without a regularizer contributes a
    `self._ConstantOpReg` of its static last-dimension size.

    Side effect: every input regularizer that is not None is replaced, via
    `self._replace_regularizer`, by a SlicingReferenceRegularizer. That
    regularizer reads its slice back out of the concat op's regularizer.

    Args:
      concat_op: A tf.Operation of type ConcatV2.

    Returns:
      An OpRegularizer for `concat_op`, or None if none of its inputs has a
      regularizer.

    Raises:
      ValueError: If an input without a regularizer has an unknown static size
        in its last dimension.
    """
    # We omit the last input, because it's the concat dimension. Others are
    # the tensors to be concatenated.
    input_ops = [i.op for i in concat_op.inputs[:-1]]
    regularizers_to_concat = [self._get_regularizer(op) for op in input_ops]
    # If all inputs have no regularizer, so does the concat op. Use an identity
    # check instead of list equality so no regularizer `__eq__` is invoked.
    if all(r is None for r in regularizers_to_concat):
        return None

    # Start index of the current input's slice in the concatenated vectors.
    # It becomes a tensor once a regularized input (dynamic length) is seen.
    offset = 0
    # Replace the regularizers_to_concat by SlicingReferenceRegularizer-s that
    # slice the concatenated regularizer.
    ops_to_concat = []
    for r, op in zip(regularizers_to_concat, input_ops):
        if r is None:
            length = op.outputs[0].shape.as_list()[-1]
            if length is None:
                # Without a static size the offsets of later inputs cannot be
                # computed; fail with a clear message instead of `offset += None`.
                raise ValueError(
                    'Cannot create a concat regularizer for %s: input %s has an '
                    'unknown static size in its last dimension.'
                    % (concat_op.name, op.name))
            ops_to_concat.append(self._ConstantOpReg(length))
        else:
            length = tf.shape(r.alive_vector)[0]
            # The concat op's regularizer is only built at the end of this
            # method, so the slice looks it up lazily. Presumably the caller
            # registers the returned regularizer for `concat_op` (confirm).
            # `offset` and `length` are passed by value here.
            slice_ref = concat_and_slice_regularizers.SlicingReferenceRegularizer(
                lambda: self._get_regularizer(concat_op), offset, length)
            self._replace_regularizer(r, slice_ref)
            # The concat is built from the original input regularizer, not from
            # the slice that now stands in for it.
            ops_to_concat.append(r)
        offset += length

    # Create the concatenated regularizer itself.
    return concat_and_slice_regularizers.ConcatRegularizer(ops_to_concat)