def test_main(self):
    """The container's named children match the (name, operator) pairs it was built from."""
    named_ops = []
    for position in range(3):
        named_ops.append((str(position), self.Operator()))

    container = ops.OperatorContainer(named_ops)

    assert_named_modules_identical(container.named_children(), named_ops)
def op_container(self):
    """Return a container holding one regularization and one comparison operator.

    The operators come from ``self.RegularizationOperator`` and
    ``self.ComparisonOperator`` and are registered under the names
    ``"regularization"`` and ``"comparison"``.
    """
    regularization = self.RegularizationOperator()
    comparison = self.ComparisonOperator()
    named_ops = (
        ("regularization", regularization),
        ("comparison", comparison),
    )
    return ops.OperatorContainer(named_ops)
def get_container():
    """Return a container with a regularization and a comparison test operator.

    The operators are registered under the names ``"regularization"`` and
    ``"comparison"``, in that order.
    """
    regularization = RegularizationTestOperator()
    comparison = ComparisonTestOperator()
    named_ops = (
        ("regularization", regularization),
        ("comparison", comparison),
    )
    return ops.OperatorContainer(named_ops)
def test_call(self):
    """Calling the container yields a LossDict keyed by operator name.

    Three operators are built with distinct biases (1.0, 2.0, 3.0). The test
    expects each entry to equal the input image plus that operator's bias.
    """
    # Named ``input_image`` rather than ``input`` so the builtin is not shadowed.
    input_image = torch.tensor(0.0)
    named_ops = [(str(idx), self.Operator(idx + 1.0)) for idx in range(3)]
    op_container = ops.OperatorContainer(named_ops)

    actual = op_container(input_image)
    desired = pystiche.LossDict(
        [(name, input_image + op.bias) for name, op in named_ops]
    )
    ptu.assert_allclose(actual, desired)
def test_OperatorContainer(self):
    """The container's named children equal the (name, operator) pairs passed to it."""

    class TestOperator(ops.Operator):
        # Stub subclass: this test only inspects child registration, so the
        # image-processing hook does nothing.
        def process_input_image(self, image):
            return None

    named_ops = [(str(position), TestOperator()) for position in range(3)]
    container = ops.OperatorContainer(named_ops)

    self.assertNamedChildrenEqual(container.named_children(), named_ops)
def test_OperatorContainer_getitem(self):
    """Indexing the container by name returns the same object as attribute access."""

    class TestOperator(ops.Operator):
        def __init__(self, bias):
            super().__init__()
            self.bias = bias

        def process_input_image(self, image):
            return image + self.bias

    named_ops = [(str(idx), TestOperator(idx + 1.0)) for idx in range(3)]
    op_container = ops.OperatorContainer(named_ops)

    names = [name for name, _ in named_ops]
    for name in names:
        # Identity (not equality) check: both lookups must return the very
        # same registered module.
        self.assertIs(op_container[name], getattr(op_container, name))
def test_OperatorContainer_call(self):
    """Calling the container yields a LossDict with one entry per operator.

    Each ``TestOperator`` adds its own bias to the image it processes. The
    expected entry for each name is therefore the input image plus that
    operator's bias.
    """

    class TestOperator(ops.Operator):
        def __init__(self, bias):
            super().__init__()
            self.bias = bias

        def process_input_image(self, image):
            return image + self.bias

    # Named ``input_image`` rather than ``input`` so the builtin is not shadowed.
    input_image = torch.tensor(0.0)
    named_ops = [(str(idx), TestOperator(idx + 1.0)) for idx in range(3)]
    op_container = ops.OperatorContainer(named_ops)

    actual = op_container(input_image)
    desired = pystiche.LossDict(
        [(name, input_image + op.bias) for name, op in named_ops]
    )
    self.assertTensorDictAlmostEqual(actual, desired)