# Example #1
# 0
 def test_basic_vector_double(self):
     """Check that a double-valued `BasicVector` is accessible by reference.

     Writes made through the array returned by `get_mutable_value()`, and
     through `SetFromVector()`, must be visible from every accessor.
     """
     vec = BasicVector([1., 2, 3])
     view = vec.get_mutable_value()

     def assert_views_equal(target):
         # `view` was obtained before any modification, so it matching the
         # target shows the storage is shared rather than copied.
         self.assertTrue(np.allclose(view, target))
         self.assertTrue(np.allclose(vec.get_value(), target))
         self.assertTrue(np.allclose(vec.get_mutable_value(), target))

     # TODO(eric.cousineau): Determine if there is a way to extract the
     # pointer referred to by the buffer (e.g. `value.data`).
     view[:] += 1
     assert_views_equal([2., 3, 4])

     target = [5., 6, 7]
     vec.SetFromVector(target)
     assert_views_equal(target)
# Example #2
# 0
    def test_basic_vector_double(self):
        """Exercise construction, by-reference access and cloning of
        double-valued `BasicVector`s.

        Vectors of sizes [0, 1, 2] are constructed from both lists and
        `np.array` objects, to ensure there is no ambiguity between the two.
        """
        for n in [0, 1, 2]:
            for wrap in [pass_through, np.array]:
                # Ensure that we can get vectors templated on double by
                # reference.
                expected_init = wrap([float(x) for x in range(n)])
                expected_add = wrap([x + 1 for x in expected_init])
                expected_set = wrap([x + 10 for x in expected_init])

                value_data = BasicVector(expected_init)
                value = value_data.get_mutable_value()
                self.assertTrue(np.allclose(value, expected_init))

                # Add value directly.
                # TODO(eric.cousineau): Determine if there is a way to extract
                # the pointer referred to by the buffer (e.g. `value.data`).
                value[:] += 1
                self.assertTrue(np.allclose(value, expected_add))
                self.assertTrue(
                    np.allclose(value_data.get_value(), expected_add))
                self.assertTrue(
                    np.allclose(value_data.get_mutable_value(), expected_add))

                # Set value via `BasicVector.SetFromVector`; the previously
                # obtained mutable view must reflect the new contents.
                value_data.SetFromVector(expected_set)
                self.assertTrue(np.allclose(value, expected_set))
                self.assertTrue(
                    np.allclose(value_data.get_value(), expected_set))
                self.assertTrue(
                    np.allclose(value_data.get_mutable_value(), expected_set))

                # Ensure we can construct from size.
                value_data = BasicVector(n)
                self.assertEqual(value_data.size(), n)

                # Ensure we can clone.
                value_copies = [
                    value_data.Clone(),
                    copy.copy(value_data),
                    copy.deepcopy(value_data),
                ]
                for value_copy in value_copies:
                    self.assertIsNot(value_copy, value_data)
                    # Check the copy itself; checking `value_data` here would
                    # never exercise the clone.
                    self.assertEqual(value_copy.size(), n)
# Example #3
# 0
    def test_basic_vector_double(self):
        """Exercise construction, by-reference access and cloning of
        double-valued `BasicVector`s.

        Vectors of sizes [0, 1, 2] are constructed from both lists and
        `np.array` objects, to ensure there is no ambiguity between the two.
        """
        for n in [0, 1, 2]:
            for wrap in [pass_through, np.array]:
                # Ensure that we can get vectors templated on double by
                # reference.
                # Build a real list: a bare `map` iterator would be consumed
                # by the first comprehension below (and `np.array(map(...))`
                # is a 0-d object array that cannot be iterated).
                expected_init = wrap([float(x) for x in range(n)])
                expected_add = wrap([x + 1 for x in expected_init])
                expected_set = wrap([x + 10 for x in expected_init])

                value_data = BasicVector(expected_init)
                value = value_data.get_mutable_value()
                self.assertTrue(np.allclose(value, expected_init))

                # Add value directly.
                # TODO(eric.cousineau): Determine if there is a way to extract
                # the pointer referred to by the buffer (e.g. `value.data`).
                value[:] += 1
                self.assertTrue(np.allclose(value, expected_add))
                self.assertTrue(
                    np.allclose(value_data.get_value(), expected_add))
                self.assertTrue(
                    np.allclose(value_data.get_mutable_value(), expected_add))

                # Set value via `BasicVector.SetFromVector`; the previously
                # obtained mutable view must reflect the new contents.
                value_data.SetFromVector(expected_set)
                self.assertTrue(np.allclose(value, expected_set))
                self.assertTrue(
                    np.allclose(value_data.get_value(), expected_set))
                self.assertTrue(
                    np.allclose(value_data.get_mutable_value(), expected_set))

                # Ensure we can construct from size.
                value_data = BasicVector(n)
                self.assertEqual(value_data.size(), n)

                # Ensure we can clone.
                value_copies = [
                    value_data.Clone(),
                    copy.copy(value_data),
                    copy.deepcopy(value_data),
                ]
                for value_copy in value_copies:
                    self.assertIsNot(value_copy, value_data)
                    # Check the copy itself; checking `value_data` here would
                    # never exercise the clone.
                    self.assertEqual(value_copy.size(), n)

    def test_abstract_annotations(self):
        """Check how FunctionSystem handles abstract-value and BasicVector
        type annotations, including the errors raised for unsupported ones."""
        test_str = "s"
        test_vector = BasicVector([1., 2., 3.])

        def check(func, u):
            # Wrap `func` in a FunctionSystem, fix input port 0 to `u`, and
            # return whatever output port 0 evaluates to.
            system = FunctionSystem(func)
            context = system.CreateDefaultContext()
            system.get_input_port(0).FixValue(context, u)
            return system.get_output_port(0).Eval(context)

        # Explicit `Value[str]` input annotation: the function is expected to
        # receive the AbstractValue wrapper rather than the bare str.
        def explicit_abstract_input(value: Value[str]):
            self.assertIsInstance(value, AbstractValue)
            return value.get_value()

        self.assertEqual(test_str, check(explicit_abstract_input, test_str))

        # Explicit `Value[str]` return annotation: the plain `str` input
        # arrives unwrapped, and the function returns an AbstractValue.
        def explicit_abstract_output(value: str) -> Value[str]:
            self.assertIsInstance(value, str)
            return AbstractValue.Make(value)

        self.assertEqual(test_str, check(explicit_abstract_output, test_str))

        # No return annotation, yet an AbstractValue is returned; the output
        # port must still evaluate to the original str.
        def implicit_abstract_output(value: str):
            self.assertIsInstance(value, str)
            return AbstractValue.Make(value)

        self.assertEqual(test_str, check(implicit_abstract_output, test_str))

        # A BasicVector *instance* (size 3) as the annotation: the function
        # receives a BasicVector and its value round-trips to the output.
        def explicit_basic_vector(value: BasicVector(3)) -> BasicVector(3):
            self.assertIsInstance(value, BasicVector)
            return value

        np.testing.assert_equal(test_vector.get_value(),
                                check(explicit_basic_vector, test_vector))

        # Annotating with the BasicVector *type* (not an instance) must be
        # rejected with an AssertionError.
        def bad_basic_vector_cls(value: BasicVector) -> BasicVector:
            pass

        with self.assertRaises(AssertionError) as cm:
            check(bad_basic_vector_cls, test_vector)
        self.assertIn("Must supply BasicVector_[] instance, not type",
                      str(cm.exception))

        # Wrapping BasicVector in `Value[]` must also be rejected.
        def bad_basic_vector_value_cls(value: Value[BasicVector]) -> float:
            return 0.

        with self.assertRaises(AssertionError) as cm:
            check(bad_basic_vector_value_cls, test_vector)
        self.assertIn("Cannot specify Value[BasicVector_[]]",
                      str(cm.exception))
    def test_context_api(self):
        """Cover miscellaneous Context, AbstractValues and Diagram accessors
        that are not yet tested elsewhere."""
        model_value = AbstractValue.Make("Hello")
        model_vector = BasicVector([1., 2.])

        class TrivialSystem(LeafSystem):
            def __init__(self):
                LeafSystem.__init__(self)
                self.DeclareContinuousState(1)
                self.DeclareDiscreteState(2)
                self.DeclareAbstractState(model_value.Clone())
                self.DeclareAbstractParameter(model_value.Clone())
                self.DeclareNumericParameter(model_vector.Clone())

        leaf = TrivialSystem()
        ctx = leaf.CreateDefaultContext()
        hello = model_value.get_value()

        # Whole-state and continuous-state accessors: const and mutable
        # getters must hand back the very same Python object.
        self.assertIs(ctx.get_state(), ctx.get_mutable_state())
        self.assertEqual(ctx.num_continuous_states(), 1)
        self.assertIs(ctx.get_continuous_state_vector(),
                      ctx.get_mutable_continuous_state_vector())

        # Discrete-state accessors (a single group).
        self.assertEqual(ctx.num_discrete_state_groups(), 1)
        self.assertIs(ctx.get_discrete_state_vector(),
                      ctx.get_mutable_discrete_state_vector())
        self.assertIs(ctx.get_discrete_state(0),
                      ctx.get_discrete_state_vector())
        self.assertIs(ctx.get_discrete_state(0),
                      ctx.get_discrete_state().get_vector(0))
        self.assertIs(ctx.get_mutable_discrete_state(0),
                      ctx.get_mutable_discrete_state_vector())
        self.assertIs(ctx.get_mutable_discrete_state(0),
                      ctx.get_mutable_discrete_state().get_vector(0))

        # Abstract-state accessors (a single entry).
        self.assertEqual(ctx.num_abstract_states(), 1)
        self.assertIs(ctx.get_abstract_state(),
                      ctx.get_mutable_abstract_state())
        self.assertIs(ctx.get_abstract_state(0),
                      ctx.get_mutable_abstract_state(0))
        self.assertEqual(ctx.get_abstract_state(0).get_value(), hello)

        # Check abstract state API (also test AbstractValues).
        abstract_values = ctx.get_abstract_state()
        self.assertEqual(abstract_values.size(), 1)
        self.assertEqual(abstract_values.get_value(0).get_value(), hello)
        self.assertEqual(
            abstract_values.get_mutable_value(0).get_value(), hello)
        abstract_values.SetFrom(abstract_values.Clone())

        # Parameter accessors.
        self.assertEqual(leaf.num_abstract_parameters(), 1)
        self.assertEqual(ctx.get_abstract_parameter(index=0).get_value(),
                         hello)
        self.assertEqual(leaf.num_numeric_parameter_groups(), 1)
        np.testing.assert_equal(
            ctx.get_numeric_parameter(index=0).get_value(),
            model_vector.get_value())

        # Diagram-level subsystem accessors (existence and identity checks).
        builder = DiagramBuilder()
        builder.AddSystem(leaf)
        diagram = builder.Build()
        diagram_ctx = diagram.CreateDefaultContext()
        self.assertIsNotNone(
            diagram.GetMutableSubsystemState(leaf, diagram_ctx))
        sub_ctx = diagram.GetMutableSubsystemContext(leaf, diagram_ctx)
        self.assertIsNotNone(sub_ctx)
        self.assertIs(diagram.GetSubsystemContext(leaf, diagram_ctx), sub_ctx)
# Example #6
# 0
    def test_context_api(self):
        """Cover miscellaneous Context, AbstractValues and Diagram accessors
        that are not yet tested elsewhere, including alternate spellings that
        are expected to emit exactly one Drake warning each."""
        model_value = AbstractValue.Make("Hello")
        model_vector = BasicVector([1., 2.])

        class TrivialSystem(LeafSystem):
            def __init__(self):
                LeafSystem.__init__(self)
                self.DeclareContinuousState(1)
                self.DeclareDiscreteState(2)
                self.DeclareAbstractState(model_value.Clone())
                self.DeclareAbstractParameter(model_value.Clone())
                self.DeclareNumericParameter(model_vector.Clone())

        leaf = TrivialSystem()
        ctx = leaf.CreateDefaultContext()
        hello = model_value.get_value()

        # Whole-state and continuous-state accessors: const and mutable
        # getters must hand back the very same Python object.
        self.assertIs(ctx.get_state(), ctx.get_mutable_state())
        self.assertEqual(ctx.num_continuous_states(), 1)
        self.assertIs(ctx.get_continuous_state_vector(),
                      ctx.get_mutable_continuous_state_vector())

        # Discrete-state accessors (a single group).
        self.assertEqual(ctx.num_discrete_state_groups(), 1)
        with catch_drake_warnings(expected_count=1):
            ctx.get_num_discrete_state_groups()
        self.assertIs(ctx.get_discrete_state_vector(),
                      ctx.get_mutable_discrete_state_vector())
        self.assertIs(ctx.get_discrete_state(0),
                      ctx.get_discrete_state_vector())
        self.assertIs(ctx.get_discrete_state(0),
                      ctx.get_discrete_state().get_vector(0))
        self.assertIs(ctx.get_mutable_discrete_state(0),
                      ctx.get_mutable_discrete_state_vector())
        self.assertIs(ctx.get_mutable_discrete_state(0),
                      ctx.get_mutable_discrete_state().get_vector(0))

        # Abstract-state accessors (a single entry).
        self.assertEqual(ctx.num_abstract_states(), 1)
        with catch_drake_warnings(expected_count=1):
            ctx.get_num_abstract_states()
        self.assertIs(ctx.get_abstract_state(),
                      ctx.get_mutable_abstract_state())
        self.assertIs(ctx.get_abstract_state(0),
                      ctx.get_mutable_abstract_state(0))
        self.assertEqual(ctx.get_abstract_state(0).get_value(), hello)

        # Check abstract state API (also test AbstractValues).
        abstract_values = ctx.get_abstract_state()
        self.assertEqual(abstract_values.size(), 1)
        self.assertEqual(abstract_values.get_value(0).get_value(), hello)
        self.assertEqual(
            abstract_values.get_mutable_value(0).get_value(), hello)
        abstract_values.SetFrom(abstract_values.Clone())
        with catch_drake_warnings(expected_count=1):
            abstract_values.CopyFrom(abstract_values.Clone())

        # Parameter accessors.
        self.assertEqual(leaf.num_abstract_parameters(), 1)
        self.assertEqual(ctx.get_abstract_parameter(index=0).get_value(),
                         hello)
        self.assertEqual(leaf.num_numeric_parameter_groups(), 1)
        np.testing.assert_equal(
            ctx.get_numeric_parameter(index=0).get_value(),
            model_vector.get_value())

        # Diagram-level subsystem accessors (existence and identity checks).
        builder = DiagramBuilder()
        builder.AddSystem(leaf)
        diagram = builder.Build()
        diagram_ctx = diagram.CreateDefaultContext()
        self.assertIsNotNone(
            diagram.GetMutableSubsystemState(leaf, diagram_ctx))
        sub_ctx = diagram.GetMutableSubsystemContext(leaf, diagram_ctx)
        self.assertIsNotNone(sub_ctx)
        self.assertIs(diagram.GetSubsystemContext(leaf, diagram_ctx), sub_ctx)


class _ArgHelper:
    """Provides information and functions to aid in interfacing a Python
    function with the Systems framework."""

    def __init__(self, name, cls, scalar_as_vector):
        """Given a class (or type annotation), figure out the type (vector
        port, abstract port, or context time), the model value (for ports), and
        example value (for output inference).

        Args:
            name: Name used when declaring a port for this argument.
            cls: The annotation to interpret: a ``VectorArg`` instance, a
                ``BasicVector_[]`` instance, a scalar type, ``ContextTimeArg``,
                or anything else (handled as an abstract value).
            scalar_as_vector: If true, types in ``SCALAR_TYPES`` map to a
                size-1 vector port and are (un)squeezed to/from scalars.

        Raises:
            AssertionError: If ``cls`` is a ``BasicVector_[]`` class rather
                than an instance.
        """
        # Name can be overridden.
        self.name = name
        # True when a scalar is carried on a size-1 vector port.
        self._scalar_needs_conversion = False
        # True when the function works with the AbstractValue / BasicVector
        # object itself rather than its contents (see `declare_input_eval`
        # and `get_set_output_func`).
        self._is_direct_type = False
        if isinstance(cls, VectorArg):
            self.type = PortDataType.kVectorValued
            self.model = BasicVector(cls._size)
            self.model.get_mutable_value()[:] = 0
            self.example = self.model.get_value()
        elif BasicVector_.is_instantiation(cls):
            # Raised explicitly (not `assert False`) so the check survives
            # `python -O`; the exception type callers catch is unchanged.
            raise AssertionError(
                f"Must supply BasicVector_[] instance, not type: {cls}")
        elif BasicVector_.is_instantiation(type(cls)):
            self.type = PortDataType.kVectorValued
            self.model = cls
            self.example = self.model
            self._is_direct_type = True
        elif scalar_as_vector and cls in SCALAR_TYPES:
            self.type = PortDataType.kVectorValued
            self.model = BasicVector(1)
            self.model.get_mutable_value()[:] = 0
            self.example = float()  # Should this be smarter about the type?
            self._scalar_needs_conversion = True
        elif cls is ContextTimeArg:
            self.type = ContextTimeArg
            self.model = None
            self.example = float()
        else:
            self.type = PortDataType.kAbstractValued
            self.model, self.example = _get_abstract_model_and_example(cls)
            if self.model is self.example:
                self._is_direct_type = True

    def _squeeze(self, x):
        """Converts a size-1 port vector to a scalar when this argument is a
        scalar carried on a vector port; otherwise returns ``x`` unchanged."""
        if self._scalar_needs_conversion:
            assert x.shape == (1, ), f"Bad input: {x}"
            return x.item(0)
        else:
            return x

    def _unsqueeze(self, x):
        """Inverse of ``_squeeze``: wraps a scalar into a size-1 array when
        needed; otherwise returns ``x`` unchanged."""
        if self._scalar_needs_conversion:
            return np.array([x])
        else:
            return x

    def declare_input_eval(self, system):
        """Declares an input evaluation function. If a port is needed, will
        declare the port.

        Returns:
            A callable taking a context and returning this argument's value.
        """
        if self.type is ContextTimeArg:
            # No port is needed; the value is read from the context itself.
            return Context.get_time
        elif self.type == PortDataType.kAbstractValued:
            # DeclareInputPort does not work with kAbstractValued :(
            port = system.DeclareAbstractInputPort(name=self.name,
                                                   model_value=self.model)
            if self._is_direct_type:
                return port.EvalAbstract
            else:
                return port.Eval
        else:
            port = system.DeclareVectorInputPort(name=self.name,
                                                 model_vector=self.model)
            if self._is_direct_type:
                return port.EvalBasicVector
            else:
                return lambda context: self._squeeze(port.Eval(context))

    def declare_output_port(self, system, calc):
        """Declares an output port on a given system.

        Raises:
            AssertionError: If this argument is a ``ContextTimeArg``.
        """
        if self.type is ContextTimeArg:
            # Non-raw string: the trailing backslash joins the opening quotes
            # with the first text line, so `dedent` can strip the indentation.
            raise AssertionError(dedent("""\
                ContextTimeArg is disallowed for output arguments. If needed,
                explicitly pass it through, e.g.:
                    def context_time(t: ContextTimeArg):
                        return t
                """))
        elif self.type == PortDataType.kAbstractValued:
            system.DeclareAbstractOutputPort(name=self.name,
                                             alloc=self.model.Clone,
                                             calc=calc)
        else:
            system.DeclareVectorOutputPort(name=self.name,
                                           model_value=self.model,
                                           calc=calc)

    def get_set_output_func(self):
        """Returns a callable ``(output, value)`` that stores ``value``, as
        returned by the wrapped function, into ``output``."""
        assert self.type is not ContextTimeArg
        if self.type == PortDataType.kAbstractValued:
            if self._is_direct_type:
                return lambda output, value: output.SetFrom(value)
            return lambda output, value: output.set_value(value)
        if self._is_direct_type:
            # TODO(eric.cousineau): Bind VectorBase.SetFrom().
            return lambda output, value: output.SetFromVector(
                value.get_value())
        return lambda output, value: output.SetFromVector(
            self._unsqueeze(value))