def testProbGroupingRegularizer(self):
  """Checks ProbabilisticGroupingRegularizer over three probabilistic stubs.

  The grouped alive vector must be the element-wise OR of the inputs' alive
  vectors. Each expected regularization entry equals 1 - prod(1 - r_i) over
  the three inputs' values at that position, e.g. at index 0:
  1 - 0.9 * 0.8 * 0.7 = 0.496.
  """
  # (regularization values, alive flags) for each of the three inputs.
  inputs = [
      ([0.1, 0.3, 0.6, 0.2], [False, True, True, False]),
      ([0.2, 0.4, 0.5, 0.1], [False, True, False, True]),
      ([0.3, 0.2, 0.0, 0.25], [False, True, False, True]),
  ]
  stubs = []
  for reg_values, alive_flags in inputs:
    stub = op_regularizer_stub.OpRegularizerStub(reg_values, alive_flags)
    stub.is_probabilistic = True
    stubs.append(stub)

  group_reg = pgr.ProbabilisticGroupingRegularizer(stubs)

  alive_columns = zip(*[alive_flags for _, alive_flags in inputs])
  expected_alive = [any(column) for column in alive_columns]
  expected_grouped_reg = [0.496, 0.664, 0.8, 0.46]

  with self.cached_session():
    self.assertAllEqual(expected_alive, group_reg.alive_vector.eval())
    self.assertAllClose(
        expected_grouped_reg,
        group_reg.regularization_vector.eval(),
        rtol=1e-5)
def setUp(self):
  """Creates two stub regularizers of different lengths (4 and 3 entries)."""
  # NOTE(review): unlike the other setUp methods in this file, this one does
  # not call the parent class's setUp -- confirm that is intentional.
  self._reg_vec1, self._alive_vec1 = (
      [0.1, 0.3, 0.6, 0.2], [False, True, True, False])
  self._reg_vec2, self._alive_vec2 = [0.2, 0.4, 0.5], [False, True, False]

  make_stub = op_regularizer_stub.OpRegularizerStub
  self._reg1 = make_stub(self._reg_vec1, self._alive_vec1)
  self._reg2 = make_stub(self._reg_vec2, self._alive_vec2)
def setUp(self):
  """Creates two stub regularizers of lengths 4 and 3 for concat/slice tests."""
  super(ConcatAndSliceRegularizersTest, self).setUp()

  self._reg_vec1, self._alive_vec1 = (
      [0.1, 0.3, 0.6, 0.2], [False, True, True, False])
  self._reg_vec2, self._alive_vec2 = [0.2, 0.4, 0.5], [False, True, False]

  make_stub = op_regularizer_stub.OpRegularizerStub
  self._reg1 = make_stub(self._reg_vec1, self._alive_vec1)
  self._reg2 = make_stub(self._reg_vec2, self._alive_vec2)
def setUp(self):
  """Creates three 4-entry stub regularizers shared by the grouping tests."""
  super(GroupingRegularizersTest, self).setUp()

  reg_vec1, alive_vec1 = [0.1, 0.3, 0.6, 0.2], [False, True, True, False]
  reg_vec2, alive_vec2 = [0.2, 0.4, 0.5, 0.1], [False, True, False, True]
  reg_vec3, alive_vec3 = [0.3, 0.2, 0.0, 0.25], [False, True, False, True]

  # The stubs receive the very same list objects stored on self below.
  self._reg_vec1, self._alive_vec1 = reg_vec1, alive_vec1
  self._reg_vec2, self._alive_vec2 = reg_vec2, alive_vec2
  self._reg_vec3, self._alive_vec3 = reg_vec3, alive_vec3

  make_stub = op_regularizer_stub.OpRegularizerStub
  self._reg1 = make_stub(reg_vec1, alive_vec1)
  self._reg2 = make_stub(reg_vec2, alive_vec2)
  self._reg3 = make_stub(reg_vec3, alive_vec3)
def regularizer(conv_op, manager=None):
  """Returns a stub regularizer for ops named with a 'conv1'/'conv2' prefix.

  `reg` and `th` come from the enclosing scope: `reg` is indexed by the
  matched prefix, and its entry is compared against `th` to form the alive
  flags (presumably element-wise on an array -- confirm against the caller).
  Ops whose name matches neither prefix get None.
  """
  del manager  # unused
  matched = next(
      (prefix for prefix in ('conv1', 'conv2')
       if conv_op.name.startswith(prefix)),
      None)
  if matched is None:
    return None
  return op_regularizer_stub.OpRegularizerStub(reg[matched], reg[matched] > th)