예제 #1
0
    def test_context_api(self):
        """Exercise the `Context` bindings on three example systems.

        Adder covers the continuous-state accessors, PendulumPlant covers
        the numeric-parameter accessors and `SetContinuousState`, and
        RimlessWheel covers the discrete and abstract state setters.
        """
        def check_continuous_state_accessors(ctx):
            # Const and mutable accessors alike must return the bound types.
            self.assertIsInstance(ctx.get_continuous_state(), ContinuousState)
            self.assertIsInstance(
                ctx.get_mutable_continuous_state(), ContinuousState)
            self.assertIsInstance(
                ctx.get_continuous_state_vector(), VectorBase)
            self.assertIsInstance(
                ctx.get_mutable_continuous_state_vector(), VectorBase)

        adder = Adder(3, 10)
        check_continuous_state_accessors(adder.AllocateContext())
        check_continuous_state_accessors(adder.CreateDefaultContext())
        # TODO(eric.cousineau): Consolidate main API tests for `Context` here.

        pendulum = PendulumPlant()
        pendulum_context = pendulum.CreateDefaultContext()
        self.assertEqual(pendulum_context.num_numeric_parameter_groups(), 1)
        self.assertEqual(pendulum.num_numeric_parameter_groups(), 1)
        # The parameter reached through `Parameters` and the one reached
        # through the `Context` shortcut must be the very same object.
        via_parameters = (
            pendulum_context.get_parameters().get_numeric_parameter(0))
        via_context = pendulum_context.get_numeric_parameter(index=0)
        self.assertTrue(via_parameters is via_context)
        self.assertEqual(pendulum_context.num_abstract_parameters(), 0)
        self.assertEqual(pendulum.num_numeric_parameter_groups(), 1)
        # TODO(russt): Bind _Declare*Parameter or find an example with an
        # abstract parameter to actually call this method.
        self.assertTrue(hasattr(pendulum_context, "get_abstract_parameter"))
        continuous_values = np.array([0.1, 0.2])
        pendulum_context.SetContinuousState(continuous_values)
        np.testing.assert_equal(
            pendulum_context.get_continuous_state_vector().CopyToVector(),
            continuous_values)

        # RimlessWheel has a single discrete variable and a bool abstract
        # variable.
        rimless = RimlessWheel()
        rimless_context = rimless.CreateDefaultContext()
        base = np.array([1.125])
        # Set once without and once with an explicit group index.
        rimless_context.SetDiscreteState(xd=2 * base)
        np.testing.assert_equal(
            rimless_context.get_discrete_state_vector().CopyToVector(),
            2 * base)
        rimless_context.SetDiscreteState(group_index=0, xd=3 * base)
        np.testing.assert_equal(
            rimless_context.get_discrete_state_vector().CopyToVector(),
            3 * base)

        # Write the bool abstract state and read it straight back.
        for flag, check in ((True, self.assertTrue),
                            (False, self.assertFalse)):
            rimless_context.SetAbstractState(index=0, value=flag)
            check(rimless_context.get_abstract_state(0).get_value())
예제 #2
0
    def test_context_api(self):
        """Exercise the `Context` bindings across several example systems.

        Adder covers the continuous-state accessors, the state-summary
        queries and `SetTimeStateAndParametersFrom` for every pairing of
        scalar types; PendulumPlant covers the parameter accessors and the
        continuous-state setters; RimlessWheel covers the discrete and
        abstract state APIs.
        """
        def check_continuous_state_accessors(ctx):
            # Const and mutable accessors alike must return the bound types.
            self.assertIsInstance(ctx.get_continuous_state(), ContinuousState)
            self.assertIsInstance(
                ctx.get_mutable_continuous_state(), ContinuousState)
            self.assertIsInstance(
                ctx.get_continuous_state_vector(), VectorBase)
            self.assertIsInstance(
                ctx.get_mutable_continuous_state_vector(), VectorBase)

        adder = Adder(3, 10)
        check_continuous_state_accessors(adder.AllocateContext())

        default_context = adder.CreateDefaultContext()
        check_continuous_state_accessors(default_context)
        # The default Adder context is expected to report no state at all.
        self.assertTrue(default_context.is_stateless())
        self.assertFalse(default_context.has_only_continuous_state())
        self.assertFalse(default_context.has_only_discrete_state())
        self.assertEqual(default_context.num_total_states(), 0)
        # TODO(eric.cousineau): Consolidate main API tests for `Context` here.

        # Test methods with two scalar types.
        scalar_types = [float, AutoDiffXd, Expression]
        for T in scalar_types:
            systemT = Adder_[T](3, 10)
            contextT = systemT.CreateDefaultContext()
            for U in scalar_types:
                systemU = Adder_[U](3, 10)
                contextU = systemU.CreateDefaultContext()
                contextU.SetTime(0.5)
                contextT.SetTimeStateAndParametersFrom(contextU)
                # Reduce the T-typed time to a plain number before comparing.
                time_T = contextT.get_time()
                if T == float:
                    time_value = time_T
                elif T == AutoDiffXd:
                    time_value = time_T.value()
                else:
                    time_value = time_T.Evaluate()
                self.assertEqual(time_value, 0.5)

        pendulum = PendulumPlant()
        pendulum_context = pendulum.CreateDefaultContext()
        self.assertEqual(pendulum_context.num_numeric_parameter_groups(), 1)
        self.assertEqual(pendulum.num_numeric_parameter_groups(), 1)
        # The `Parameters` route and the `Context` shortcut must hand back
        # the very same object, for const and mutable access alike.
        via_parameters = (
            pendulum_context.get_parameters().get_numeric_parameter(0))
        via_context = pendulum_context.get_numeric_parameter(index=0)
        self.assertTrue(via_parameters is via_context)
        mutable_via_parameters = (
            pendulum_context.get_mutable_parameters()
            .get_mutable_numeric_parameter(0))
        mutable_via_context = (
            pendulum_context.get_mutable_numeric_parameter(index=0))
        self.assertTrue(mutable_via_parameters is mutable_via_context)
        self.assertEqual(pendulum_context.num_abstract_parameters(), 0)
        self.assertEqual(pendulum.num_numeric_parameter_groups(), 1)
        # TODO(russt): Bind _Declare*Parameter or find an example with an
        # abstract parameter to actually call this method.
        for method_name in ("get_abstract_parameter",
                            "get_mutable_abstract_parameter"):
            self.assertTrue(hasattr(pendulum_context, method_name))
        continuous_values = np.array([0.1, 0.2])
        pendulum_context.SetContinuousState(continuous_values)
        np.testing.assert_equal(
            pendulum_context.get_continuous_state_vector().CopyToVector(),
            continuous_values)
        pendulum_context.SetTimeAndContinuousState(0.3, 2*continuous_values)
        np.testing.assert_equal(pendulum_context.get_time(), 0.3)
        np.testing.assert_equal(
            pendulum_context.get_continuous_state_vector().CopyToVector(),
            2*continuous_values)

        # RimlessWheel has a single discrete variable and a bool abstract
        # variable.
        rimless = RimlessWheel()
        rimless_context = rimless.CreateDefaultContext()
        base = np.array([1.125])
        # Set once without and once with an explicit group index.
        rimless_context.SetDiscreteState(xd=2 * base)
        np.testing.assert_equal(
            rimless_context.get_discrete_state_vector().CopyToVector(),
            2 * base)
        rimless_context.SetDiscreteState(group_index=0, xd=3 * base)
        np.testing.assert_equal(
            rimless_context.get_discrete_state_vector().CopyToVector(),
            3 * base)

        def check_abstract_value_zero(context, expected_value):
            # Check through Context, State, and AbstractValues APIs.
            via_context = context.get_abstract_state(index=0)
            self.assertEqual(via_context.get_value(), expected_value)
            via_values = context.get_abstract_state().get_value(index=0)
            self.assertEqual(via_values.get_value(), expected_value)
            via_state = context.get_state().get_abstract_state().get_value(
                index=0)
            self.assertEqual(via_state.get_value(), expected_value)

        for flag in (True, False):
            rimless_context.SetAbstractState(index=0, value=flag)
            check_abstract_value_zero(rimless_context, flag)
        # Writing through the mutable AbstractValue must be visible as well.
        mutable_value = (
            rimless_context.get_mutable_state()
            .get_mutable_abstract_state()
            .get_mutable_value(index=0))
        mutable_value.set_value(True)
        check_abstract_value_zero(rimless_context, True)
예제 #3
0
    def test_context_api(self):
        """Exercise the `Context` bindings across several example systems.

        Adder covers context allocation, default / random initialization,
        the continuous-state accessors, the state-summary queries and
        `SetTimeStateAndParametersFrom` for every pairing of scalar types.
        PendulumPlant covers the parameter accessors, the cache controls,
        the continuous-state setters and the energy evaluators.
        RimlessWheel covers the discrete and abstract state APIs.
        """
        def check_continuous_state_accessors(ctx):
            # Const and mutable accessors alike must return the bound types.
            self.assertIsInstance(ctx.get_continuous_state(), ContinuousState)
            self.assertIsInstance(
                ctx.get_mutable_continuous_state(), ContinuousState)
            self.assertIsInstance(
                ctx.get_continuous_state_vector(), VectorBase)
            self.assertIsInstance(
                ctx.get_mutable_continuous_state_vector(), VectorBase)

        adder = Adder(3, 10)
        allocated_context = adder.AllocateContext()
        check_continuous_state_accessors(allocated_context)
        adder.SetDefaultContext(allocated_context)

        # Check random context method.
        adder.SetRandomContext(
            context=allocated_context, generator=RandomGenerator())

        default_context = adder.CreateDefaultContext()
        check_continuous_state_accessors(default_context)
        # The default Adder context is expected to report no state at all.
        self.assertTrue(default_context.is_stateless())
        self.assertFalse(default_context.has_only_continuous_state())
        self.assertFalse(default_context.has_only_discrete_state())
        self.assertEqual(default_context.num_total_states(), 0)
        # TODO(eric.cousineau): Consolidate main API tests for `Context` here.

        # Test methods with two scalar types.
        scalar_types = [float, AutoDiffXd, Expression]
        for T in scalar_types:
            systemT = Adder_[T](3, 10)
            contextT = systemT.CreateDefaultContext()
            for U in scalar_types:
                systemU = Adder_[U](3, 10)
                contextU = systemU.CreateDefaultContext()
                contextU.SetTime(0.5)
                contextT.SetTimeStateAndParametersFrom(contextU)
                # Reduce the T-typed time to a plain number before comparing.
                time_T = contextT.get_time()
                if T == float:
                    time_value = time_T
                elif T == AutoDiffXd:
                    time_value = time_T.value()
                else:
                    time_value = time_T.Evaluate()
                self.assertEqual(time_value, 0.5)

        pendulum = PendulumPlant()
        pendulum_context = pendulum.CreateDefaultContext()
        self.assertEqual(pendulum_context.num_numeric_parameter_groups(), 1)
        self.assertEqual(pendulum.num_numeric_parameter_groups(), 1)
        # The `Parameters` route and the `Context` shortcut must hand back
        # the very same object, for const and mutable access alike.
        via_parameters = (
            pendulum_context.get_parameters().get_numeric_parameter(0))
        via_context = pendulum_context.get_numeric_parameter(index=0)
        self.assertTrue(via_parameters is via_context)
        mutable_via_parameters = (
            pendulum_context.get_mutable_parameters()
            .get_mutable_numeric_parameter(0))
        mutable_via_context = (
            pendulum_context.get_mutable_numeric_parameter(index=0))
        self.assertTrue(mutable_via_parameters is mutable_via_context)
        self.assertEqual(pendulum_context.num_abstract_parameters(), 0)
        self.assertEqual(pendulum.num_numeric_parameter_groups(), 1)
        # TODO(russt): Bind _Declare*Parameter or find an example with an
        # abstract parameter to actually call this method.
        for method_name in ("get_abstract_parameter",
                            "get_mutable_abstract_parameter"):
            self.assertTrue(hasattr(pendulum_context, method_name))

        # Cache controls: toggle caching, invalidate, then freeze/unfreeze.
        pendulum_context.DisableCaching()
        pendulum_context.EnableCaching()
        pendulum_context.SetAllCacheEntriesOutOfDate()
        pendulum_context.FreezeCache()
        self.assertTrue(pendulum_context.is_cache_frozen())
        pendulum_context.UnfreezeCache()
        self.assertFalse(pendulum_context.is_cache_frozen())

        continuous_values = np.array([0.1, 0.2])
        pendulum_context.SetContinuousState(continuous_values)
        # Read back through both the ContinuousState and its vector.
        np.testing.assert_equal(
            pendulum_context.get_continuous_state().CopyToVector(),
            continuous_values)
        np.testing.assert_equal(
            pendulum_context.get_continuous_state_vector().CopyToVector(),
            continuous_values)
        pendulum_context.SetTimeAndContinuousState(0.3, 2 * continuous_values)
        np.testing.assert_equal(pendulum_context.get_time(), 0.3)
        np.testing.assert_equal(
            pendulum_context.get_continuous_state_vector().CopyToVector(),
            2 * continuous_values)
        potential_energy = pendulum.EvalPotentialEnergy(
            context=pendulum_context)
        self.assertNotEqual(potential_energy, 0)
        kinetic_energy = pendulum.EvalKineticEnergy(context=pendulum_context)
        self.assertNotEqual(kinetic_energy, 0)

        # RimlessWheel has a single discrete variable and a bool abstract
        # variable.
        rimless = RimlessWheel()
        rimless_context = rimless.CreateDefaultContext()
        base = np.array([1.125])
        # Set once without and once with an explicit group index.
        rimless_context.SetDiscreteState(xd=2 * base)
        np.testing.assert_equal(
            rimless_context.get_discrete_state_vector().CopyToVector(),
            2 * base)
        rimless_context.SetDiscreteState(group_index=0, xd=3 * base)
        np.testing.assert_equal(
            rimless_context.get_discrete_state_vector().CopyToVector(),
            3 * base)

        def check_abstract_value_zero(context, expected_value):
            # Check through Context, State, and AbstractValues APIs.
            via_context = context.get_abstract_state(index=0)
            self.assertEqual(via_context.get_value(), expected_value)
            via_values = context.get_abstract_state().get_value(index=0)
            self.assertEqual(via_values.get_value(), expected_value)
            via_state = context.get_state().get_abstract_state().get_value(
                index=0)
            self.assertEqual(via_state.get_value(), expected_value)

        for flag in (True, False):
            rimless_context.SetAbstractState(index=0, value=flag)
            check_abstract_value_zero(rimless_context, flag)
        # Writing through the mutable AbstractValue must be visible as well.
        mutable_value = (
            rimless_context.get_mutable_state()
            .get_mutable_abstract_state()
            .get_mutable_value(index=0))
        mutable_value.set_value(True)
        check_abstract_value_zero(rimless_context, True)
예제 #4
0
    def test_direct_collocation(self):
        """Exercise the `DirectCollocation` bindings on a `PendulumPlant`.

        Most methods are spelled out regardless of whether they make sense
        as a consistent optimization.  The goal is to check the bindings,
        not the implementation.
        """
        # Module-level flags that the trajectory callbacks flip when invoked.
        global input_was_called, state_was_called, complete_was_called

        num_time_samples = 21
        plant = PendulumPlant()
        context = plant.CreateDefaultContext()

        dircol = DirectCollocation(
            plant,
            context,
            num_time_samples=num_time_samples,
            minimum_timestep=0.2,
            maximum_timestep=0.5,
            input_port_index=InputPortSelection.kUseFirstInputIfItExists,
            assume_non_continuous_states_are_fixed=False)

        # Decision-variable accessors.  Several results are never read
        # again; calling the binding is the point.
        time_var = dircol.time()
        first_timestep = dircol.timestep(0)
        state = dircol.state()
        state_at_2 = dircol.state(2)
        initial_state = dircol.initial_state()
        final_state = dircol.final_state()
        control = dircol.input()
        control_at_2 = dircol.input(2)
        sequential_var = dircol.NewSequentialVariable(rows=1, name="test")
        sequential_var_at_2 = dircol.GetSequentialVariableAtIndex(
            name="test", index=2)

        # Costs and constraints.
        dircol.AddRunningCost(state.dot(state))
        dircol.AddConstraintToAllKnotPoints(control[0] == 0)
        dircol.AddTimeIntervalBounds(0.3, 0.4)
        dircol.AddEqualTimeIntervalsConstraints()
        dircol.AddDurationBounds(.3 * num_time_samples,
                                 0.4 * num_time_samples)
        dircol.AddFinalCost(2 * state.dot(state))

        # Initial guess: a zero input held over the duration, and an empty
        # state trajectory.
        initial_u = PiecewisePolynomial.ZeroOrderHold(
            [0, .3 * num_time_samples], np.zeros((1, 2)))
        initial_x = PiecewisePolynomial()
        dircol.SetInitialTrajectory(initial_u, initial_x)

        input_was_called = state_was_called = complete_was_called = False

        def on_input(t, u):
            global input_was_called
            input_was_called = True

        def on_state(t, x):
            global state_was_called
            state_was_called = True

        def on_complete(t, x, u, v):
            global complete_was_called
            complete_was_called = True

        dircol.AddInputTrajectoryCallback(on_input)
        dircol.AddStateTrajectoryCallback(on_state)
        dircol.AddCompleteTrajectoryCallback(callback=on_complete,
                                             names=["test"])

        # Solving must have triggered each registered callback.
        result = mp.Solve(dircol)
        for was_called in (input_was_called, state_was_called,
                           complete_was_called):
            self.assertTrue(was_called)

        # Result extraction bindings.
        times = dircol.GetSampleTimes(result)
        inputs = dircol.GetInputSamples(result)
        states = dircol.GetStateSamples(result)
        variables = dircol.GetSequentialVariableSamples(
            result=result, name="test")
        input_traj = dircol.ReconstructInputTrajectory(result)
        state_traj = dircol.ReconstructStateTrajectory(result)

        # Stand-alone collocation constraint between the first two knots.
        constraint = DirectCollocationConstraint(plant, context)
        AddDirectCollocationConstraint(
            constraint, dircol.timestep(0),
            dircol.state(0), dircol.state(1),
            dircol.input(0), dircol.input(1),
            dircol)
예제 #5
0
 def test_plant(self):
     """Check that the pendulum's default port accessors match index 0."""
     pendulum = PendulumPlant()
     default_input_port = pendulum.get_input_port()
     self.assertEqual(default_input_port, pendulum.get_input_port(0))
     state_output_port = pendulum.get_state_output_port()
     self.assertEqual(state_output_port, pendulum.get_output_port(0))
예제 #6
0
import math
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

from pydrake.all import (DiagramBuilder, SignalLogger, Simulator, VectorSystem)
from pydrake.examples.pendulum import PendulumPlant
from pydrake.systems.controllers import (DynamicProgrammingOptions,
                                         FittedValueIteration,
                                         PeriodicBoundaryCondition)
from visualizer import PendulumVisualizer

# Module-level objects for this script.  `plant` is also read by
# `quadratic_regulator_cost` below to evaluate the input port.
plant = PendulumPlant()
simulator = Simulator(plant)
options = DynamicProgrammingOptions()


def min_time_cost(context):
    """Return 0 near the goal state and 1 everywhere else.

    The goal is the continuous state whose first component equals pi and
    whose remaining components are zero; "near" means the squared distance
    to it is below 0.05.
    """
    error = context.get_continuous_state_vector().CopyToVector()
    # Shift the first coordinate so that the goal sits at the origin.
    error[0] -= math.pi
    return 0. if error.dot(error) < .05 else 1.


def quadratic_regulator_cost(context):
    """Return a quadratic cost on state error and input: 2*x'x + u'u.

    The state error is the continuous state with pi subtracted from its
    first component; the input is read from port 0 of the module-level
    `plant`.
    """
    state_error = context.get_continuous_state_vector().CopyToVector()
    # Shift the first coordinate so that the goal sits at the origin.
    state_error[0] -= math.pi
    control = plant.EvalVectorInput(context, 0).CopyToVector()
    state_term = 2 * state_error.dot(state_error)
    control_term = control.dot(control)
    return state_term + control_term
예제 #7
0
 def test_plant(self):
     """Check the pendulum's port accessors and typed context accessors."""
     pendulum = PendulumPlant()
     default_input_port = pendulum.get_input_port()
     self.assertEqual(default_input_port, pendulum.get_input_port(0))
     state_output_port = pendulum.get_state_output_port()
     self.assertEqual(state_output_port, pendulum.get_output_port(0))

     context = pendulum.CreateDefaultContext()
     # Each accessor, const or mutable, must return the pendulum-specific
     # wrapper type.
     expected_types = (
         (pendulum.get_state, PendulumState),
         (pendulum.get_mutable_state, PendulumState),
         (pendulum.get_parameters, PendulumParams),
         (pendulum.get_mutable_parameters, PendulumParams),
     )
     for accessor, expected_type in expected_types:
         self.assertIsInstance(accessor(context=context), expected_type)
    def _do_test_direct_collocation(self, use_deprecated_solve):
        """Exercise the `DirectCollocation` bindings on a `PendulumPlant`.

        Most methods are spelled out regardless of whether they make sense
        as a consistent optimization.  The goal is to check the bindings,
        not the implementation.

        Args:
            use_deprecated_solve: When true, solve with `dircol.Solve()` and
                use the argument-free result getters, expecting exactly one
                recorded warning for the solve and five for the getters.
                Otherwise solve with `mp.Solve(dircol)` and pass the result
                to each getter.
        """
        # Module-level flags that the trajectory callbacks flip when invoked.
        global input_was_called, state_was_called

        num_time_samples = 21
        plant = PendulumPlant()
        context = plant.CreateDefaultContext()

        dircol = DirectCollocation(
            plant, context, num_time_samples=num_time_samples,
            minimum_timestep=0.2, maximum_timestep=0.5)

        # Decision-variable accessors.  Several results are never read
        # again; calling the binding is the point.
        time_var = dircol.time()
        first_timestep = dircol.timestep(0)
        state = dircol.state()
        state_at_2 = dircol.state(2)
        initial_state = dircol.initial_state()
        final_state = dircol.final_state()
        control = dircol.input()
        control_at_2 = dircol.input(2)

        # Costs and constraints.
        dircol.AddRunningCost(state.dot(state))
        dircol.AddConstraintToAllKnotPoints(control[0] == 0)
        dircol.AddTimeIntervalBounds(0.3, 0.4)
        dircol.AddEqualTimeIntervalsConstraints()
        dircol.AddDurationBounds(.3*num_time_samples, 0.4*num_time_samples)
        dircol.AddFinalCost(2*state.dot(state))

        # Initial guess: a zero input held over the duration, and an empty
        # state trajectory.
        initial_u = PiecewisePolynomial.ZeroOrderHold(
            [0, .3*num_time_samples], np.zeros((1, 2)))
        initial_x = PiecewisePolynomial()
        dircol.SetInitialTrajectory(initial_u, initial_x)

        input_was_called = state_was_called = False

        def on_input(t, u):
            global input_was_called
            input_was_called = True

        def on_state(t, x):
            global state_was_called
            state_was_called = True

        dircol.AddInputTrajectoryCallback(on_input)
        dircol.AddStateTrajectoryCallback(on_state)

        if use_deprecated_solve:
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter('always', DrakeDeprecationWarning)
                dircol.Solve()
                self.assertEqual(len(caught), 1)
            # The deprecated path yields no result object.
            result = None
        else:
            result = mp.Solve(dircol)
        # Solving must have triggered both registered callbacks.
        for was_called in (input_was_called, state_was_called):
            self.assertTrue(was_called)

        if use_deprecated_solve:
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter('always', DrakeDeprecationWarning)
                times = dircol.GetSampleTimes()
                inputs = dircol.GetInputSamples()
                states = dircol.GetStateSamples()
                input_traj = dircol.ReconstructInputTrajectory()
                state_traj = dircol.ReconstructStateTrajectory()
                # One warning per deprecated getter.
                self.assertEqual(len(caught), 5)
        else:
            times = dircol.GetSampleTimes(result)
            inputs = dircol.GetInputSamples(result)
            states = dircol.GetStateSamples(result)
            input_traj = dircol.ReconstructInputTrajectory(result)
            state_traj = dircol.ReconstructStateTrajectory(result)

        # Stand-alone collocation constraint between the first two knots.
        constraint = DirectCollocationConstraint(plant, context)
        AddDirectCollocationConstraint(
            constraint, dircol.timestep(0),
            dircol.state(0), dircol.state(1),
            dircol.input(0), dircol.input(1),
            dircol)
def do_single_basic_pend_traj_opt(should_init):
    """Set up and solve one pendulum direct-collocation problem.

    The program pins the initial state to (theta=0, thetadot=0) and the
    final state to (theta=pi, thetadot=0), bounds the input magnitude by
    `kTorqueLimit`, and penalizes the squared input.

    Args:
        should_init: When truthy, seed the solver with a first-order-hold
            state trajectory from the initial to the final state over
            [0, 4]; the input trajectory guess is left empty.

    Raises:
        AssertionError: If the solver does not report
            `SolutionResult.kSolutionFound`.

    NOTE(review): the visible body evaluates the costs and constraints at
    the solution but returns nothing -- the snippet may be truncated;
    confirm against the full source.
    """
    plant = PendulumPlant()
    context = plant.CreateDefaultContext()

    # Collocation grid settings.
    kNumTimeSamples = 15
    kMinimumTimeStep = 0.01
    kMaximumTimeStep = 0.2

    dircol = DirectCollocation(plant, context, kNumTimeSamples, kMinimumTimeStep, kMaximumTimeStep)

    dircol.AddEqualTimeIntervalsConstraints()

    # Symmetric bound on the single input at every knot point.
    kTorqueLimit = 3.0  # N*m.
    u = dircol.input()
    dircol.AddConstraintToAllKnotPoints(-kTorqueLimit <= u[0])
    dircol.AddConstraintToAllKnotPoints(u[0] <= kTorqueLimit)

    # Pin the initial state with a bounding box whose lower and upper
    # bounds coincide.
    initial_state = PendulumState()
    initial_state.set_theta(0.0)
    initial_state.set_thetadot(0.0)
    dircol.AddBoundingBoxConstraint(initial_state.get_value(), initial_state.get_value(), dircol.initial_state())
    # dircol.AddLinearConstraint(dircol.initial_state() == initial_state.get_value())

    # Pin the final state the same way.
    final_state = PendulumState()
    final_state.set_theta(math.pi)
    final_state.set_thetadot(0.0)
    dircol.AddBoundingBoxConstraint(final_state.get_value(), final_state.get_value(), dircol.final_state())
    # dircol.AddLinearConstraint(dircol.final_state() == final_state.get_value())

    R = 10  # Cost on input "effort".
    dircol.AddRunningCost(R*u[0]**2)

    if should_init:
        initial_x_trajectory = PiecewisePolynomial.FirstOrderHold([0., 4.], [initial_state.get_value(), final_state.get_value()])
        dircol.SetInitialTrajectory(PiecewisePolynomial(), initial_x_trajectory)

    def cb(decision_vars):
        """Print a cost / constraint summary on every 10th invocation.

        NOTE(review): relies on a module-level `cb_counter` that is not
        defined in the visible source, and the only registration of this
        callback (below) is commented out, so it is currently never called.
        """
        global cb_counter
        cb_counter += 1
        # Report only on calls 1, 11, 21, ...
        if cb_counter % 10 != 1:
            return

        # Get the total cost
        all_costs = dircol.EvalBindings(dircol.GetAllCosts(), decision_vars)
        # :all_constraints = dircol.EvalBindings(dircol.GetAllConstraints(), decision_vars)

        # Get the total cost of the constraints.
        # Additionally, the number and extent of any constraint violations.
        violated_constraint_count = 0
        violated_constraint_cost  = 0
        constraint_cost           = 0
        for constraint in dircol.GetAllConstraints():
            val = dircol.EvalBinding(constraint, decision_vars)

            # TODO: switch to DoCheckSatisfied...
            nudge = 1e-1 # This much constraint violation is not considered bad...
            lb = constraint.evaluator().lower_bound()
            ub = constraint.evaluator().upper_bound()
            # A constraint counts as satisfied when every component lies
            # within [lb - nudge, ub + nudge].
            good_lb = np.all( np.less_equal(lb, val+nudge) )
            good_ub = np.all( np.greater_equal(ub, val-nudge) )
            if not good_lb or not good_ub:
                # print("val,lb,ub: ", val, lb, ub)
                violated_constraint_count += 1
                violated_constraint_cost += np.sum(val)
            constraint_cost += np.sum(val)
        print("total cost: {:.2f} | constraint {:.2f}, bad {}, {:.2f}".format(
            sum(all_costs), constraint_cost, violated_constraint_count, violated_constraint_cost))
#    dircol.AddVisualizationCallback(cb, np.array(dircol.decision_variables()))
    result = dircol.Solve()
    assert(result == SolutionResult.kSolutionFound)

    # Evaluate every cost and constraint binding at the solution; neither
    # array is used further in the visible code.
    sol_costs = np.hstack([dircol.EvalBindingAtSolution(cost) for cost in dircol.GetAllCosts()])
    sol_constraints = np.hstack([dircol.EvalBindingAtSolution(constraint) for constraint in dircol.GetAllConstraints()])
예제 #10
0
def make_multiple_dircol_trajectories(num_trajectories,
                                      num_samples,
                                      initial_conditions=None):
    """Build one program holding several pendulum collocation trajectories.

    Each trajectory gets its own timestep variable (bounded to
    [0.01, 0.2]), input samples and state samples.  Its first state is
    pinned to an initial condition, its last state is boxed to within 0.2
    of (pi, 0) per component, consecutive samples are tied together by
    direct-collocation constraints, and every input sample receives a
    quadratic cost.  The program is returned unsolved.

    Args:
        num_trajectories: Number of trajectories to add to the program.
        num_samples: Number of knot points per trajectory.
        initial_conditions: Optional key into the module-level
            `intial_cond_dict` (spelled as referenced here; defined outside
            this view).  The looked-up value must be a callable mapping a
            trajectory index to a length-2 initial state.  When None, the
            initial state is `(.8 + pi - .4 * ti, 0.0)`.

    Returns:
        Tuple `(prog, h, u, x)`: the `MathematicalProgram` and, per
        trajectory, the lists of timestep, input and state variables.

    Side effects:
        Sets the SNOPT 'Print file' option to "/tmp/snopt.out".
    """
    from pydrake.all import (
        AutoDiffXd,
        Expression,
        Variable,
        MathematicalProgram,
        SolverType,
        SolutionResult,
        DirectCollocationConstraint,
        AddDirectCollocationConstraint,
        PiecewisePolynomial,
    )
    import pydrake.symbolic as sym
    from pydrake.examples.pendulum import (PendulumPlant)

    # initial_conditions maps (ti) -> [1xnum_states] initial state
    if initial_conditions is not None:
        initial_conditions = intial_cond_dict[initial_conditions]
        assert callable(initial_conditions)

    plant = PendulumPlant()
    context = plant.CreateDefaultContext()
    dircol_constraint = DirectCollocationConstraint(plant, context)

    # num_trajectories = 15;
    # num_samples = 15;
    prog = MathematicalProgram()

    # K = prog.NewContinuousVariables(1,7,'K')

    # Scalar-type dispatch for the trig helpers.  Symbolic input may be a
    # bare Variable or a composite Expression (e.g. `.5 * x[0]`); both must
    # go through pydrake.symbolic, because math.cos / math.sin only accept
    # real numbers.
    symbolic_types = (Variable, Expression)

    def cos(x):
        if isinstance(x, AutoDiffXd):
            return x.cos()
        elif isinstance(x, symbolic_types):
            return sym.cos(x)
        return math.cos(x)

    def sin(x):
        if isinstance(x, AutoDiffXd):
            return x.sin()
        elif isinstance(x, symbolic_types):
            return sym.sin(x)
        return math.sin(x)

    # Only referenced by the commented-out `prog.AddCost(final_cost, ...)`
    # further down.
    def final_cost(x):
        return 100. * (cos(.5 * x[0])**2 + x[1]**2)

    h = []
    u = []
    x = []
    xf = (math.pi, 0.)
    for ti in range(num_trajectories):
        h.append(prog.NewContinuousVariables(1))
        prog.AddBoundingBoxConstraint(.01, .2, h[ti])
        # prog.AddQuadraticCost([1.], [0.], h[ti]) # Added by me, penalize long timesteps
        u.append(prog.NewContinuousVariables(1, num_samples, 'u' + str(ti)))
        x.append(prog.NewContinuousVariables(2, num_samples, 'x' + str(ti)))

        # Use Russ's initial conditions, unless I pass in a function myself.
        if initial_conditions is None:
            x0 = (.8 + math.pi - .4 * ti, 0.0)
        else:
            x0 = initial_conditions(ti)
            assert len(x0) == 2  #TODO: undo this hardcoding.
        prog.AddBoundingBoxConstraint(x0, x0, x[ti][:, 0])

        # Box the final state around xf instead of pinning it exactly.
        # `xf` is a tuple; arithmetic with the ndarray yields an ndarray.
        nudge = np.array([.2, .2])
        prog.AddBoundingBoxConstraint(xf - nudge, xf + nudge, x[ti][:, -1])
        # prog.AddBoundingBoxConstraint(xf, xf, x[ti][:,-1])

        # Tie each pair of consecutive knot points together.
        for i in range(num_samples - 1):
            AddDirectCollocationConstraint(dircol_constraint, h[ti],
                                           x[ti][:, i], x[ti][:, i + 1],
                                           u[ti][:, i], u[ti][:, i + 1], prog)

        # Quadratic cost on every input sample.
        for i in range(num_samples):
            prog.AddQuadraticCost([1.], [0.], u[ti][:, i])
    #        prog.AddConstraint(control, [0.], [0.], np.hstack([x[ti][:,i], u[ti][:,i], K.flatten()]))
    #        prog.AddBoundingBoxConstraint([-3.], [3.], u[ti][:,i])
    #        prog.AddConstraint(u[ti][0,i] == (3.*sym.tanh(K.dot(control_basis(x[ti][:,i]))[0])))  # u = 3*tanh(K * m(x))

    # prog.AddCost(final_cost, x[ti][:,-1])
    # prog.AddCost(h[ti][0]*100) # Try to penalize using more time than it needs?

    #prog.SetSolverOption(SolverType.kSnopt, 'Verify level', -1)  # Derivative checking disabled. (otherwise it complains on the saturation)
    prog.SetSolverOption(SolverType.kSnopt, 'Print file', "/tmp/snopt.out")

    # result = prog.Solve()
    # print(result)
    # print(prog.GetSolution(K))
    # print(prog.GetSolution(K).dot(control_basis(x[0][:,0])))

    #if (result != SolutionResult.kSolutionFound):
    #    f = open('/tmp/snopt.out', 'r')
    #    print(f.read())
    #    f.close()
    return prog, h, u, x