Example #1
def test_weighted_negative_log_likelihood_vs_softmax_cross_entropy(
        data: st.DataObject, labels_as_tensor: bool):
    s = data.draw(
        hnp.arrays(
            shape=hnp.array_shapes(min_side=1,
                                   max_side=10,
                                   min_dims=2,
                                   max_dims=2),
            dtype=float,
            elements=st.floats(-100, 100),
        ))
    y_true = data.draw(
        hnp.arrays(
            shape=(s.shape[0], ),
            dtype=hnp.integer_dtypes(),
            elements=st.integers(min_value=0, max_value=s.shape[1] - 1),
        ).map(Tensor if labels_as_tensor else lambda x: x))
    weights = data.draw(
        hnp.arrays(
            shape=(s.shape[1], ),
            dtype=float,
            elements=st.floats(1e-8, 100),
        ))
    scores = Tensor(s)
    weights = Tensor(weights)

    for score, y in zip(scores, y_true):
        score = mg.log(mg.nnet.softmax(score.reshape(1, -1)))
        y = y.reshape(-1)
        nll = negative_log_likelihood(score, y)
        weighted_nll = negative_log_likelihood(score, y, weights=weights)
        assert np.isclose(weighted_nll.data, weights[y.data].data * nll.data)
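For reference, the per-sample identity exercised above can be reproduced with plain NumPy. This is a minimal sketch assuming, as the assertion above states, that the weighted loss for a single sample is the per-class weight times the unweighted negative log-likelihood; scores_row, weights, and y are illustrative values, not part of the test:

import numpy as np

scores_row = np.array([2.0, -1.0, 0.5])  # class scores for one sample
weights = np.array([0.3, 1.0, 2.0])      # per-class loss weights
y = 2                                    # true class index

log_softmax = scores_row - np.log(np.exp(scores_row).sum())
nll = -log_softmax[y]                    # unweighted negative log-likelihood
weighted_nll = -weights[y] * log_softmax[y]

assert np.isclose(weighted_nll, weights[y] * nll)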
Example #2
def test_padding(ndim: int, data: st.DataObject):
    """Ensure that convolving a padding-only image with a commensurate kernel yields the single entry: 0"""
    padding = data.draw(st.integers(1, 3)
                        | st.tuples(*[st.integers(1, 3)] * ndim),
                        label="padding")
    x = Tensor(
        data.draw(
            hnp.arrays(shape=(1, 1) + (0, ) * ndim,
                       dtype=float,
                       elements=st.floats()),
            label="x",
        ))
    pad_tuple = padding if isinstance(padding, tuple) else (padding, ) * ndim
    kernel = data.draw(
        hnp.arrays(
            shape=(1, 1) + tuple(2 * p for p in pad_tuple),
            dtype=float,
            elements=st.floats(allow_nan=False, allow_infinity=False),
        ))
    out = conv_nd(x, kernel, padding=padding, stride=1)
    assert out.shape == (1, ) * x.ndim
    assert out.item() == 0.0

    out.sum().backward()
    assert x.grad.shape == x.shape
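Why a single zero entry: the spatial output size of a convolution is (in_size + 2 * pad - kernel) // stride + 1, so an input with zero-length spatial dimensions and a kernel that exactly spans the padding collapses each spatial dimension to 1, and every value the kernel sees is a padded zero. A quick arithmetic check with illustrative values:

in_size, pad, stride = 0, 3, 1
kernel = 2 * pad
assert (in_size + 2 * pad - kernel) // stride + 1 == 1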
Example #3
def test_concat_experiments(logger: LiveLogger, num_exps: int,
                            data: st.DataObject):
    metrics = list(logger.train_metrics)
    assume(len(metrics) > 0)

    logger.set_train_batch(
        {k: data.draw(st.floats(-1e6, 1e6))
         for k in metrics}, batch_size=1)
    batch_xarrays = [logger.to_xarray("train")[0]]

    for n in range(num_exps - 1):
        logger.set_train_batch(
            {k: data.draw(st.floats(-1e6, 1e6))
             for k in metrics},
            batch_size=1)
        batch_xarrays.append(logger.to_xarray("train")[0])

    out = concat_experiments(*batch_xarrays)
    assert list(out.coords["experiment"]) == list(range(num_exps))
    assert list(out.data_vars) == list(metrics)

    for n in range(num_exps):
        for metric in metrics:
            assert_equal(
                batch_xarrays[n].to_array(metric),
                out.isel(experiment=n).drop_vars(
                    names=["experiment"]).to_array(metric).dropna(
                        dim="iterations"),
            )
Example #4
def test_negative_log_likelihood_vs_softmax_cross_entropy(
        data: st.DataObject, labels_as_tensor: bool):
    s = data.draw(
        hnp.arrays(
            shape=hnp.array_shapes(max_side=10, min_dims=2, max_dims=2),
            dtype=float,
            elements=st.floats(-100, 100),
        ))
    y_true = data.draw(
        hnp.arrays(
            shape=(s.shape[0], ),
            dtype=hnp.integer_dtypes(),
            elements=st.integers(min_value=0, max_value=s.shape[1] - 1),
        ).map(Tensor if labels_as_tensor else lambda x: x))
    scores = Tensor(s)
    nll = negative_log_likelihood(mg.log(mg.nnet.softmax(scores)), y_true)
    nll.backward()

    cross_entropy_scores = Tensor(s)
    ce = softmax_crossentropy(cross_entropy_scores, y_true)
    ce.backward()

    assert_allclose(nll.data, ce.data, atol=1e-5, rtol=1e-5)
    assert_allclose(scores.grad,
                    cross_entropy_scores.grad,
                    atol=1e-5,
                    rtol=1e-5)
Example #5
def test_get_obs_has_correct_shape(data: st.DataObject) -> None:
    """ Make sure that a returned observation has the correct shape. """
    env = data.draw(bst.envs())
    env.reset()
    pos = data.draw(bst.positions(env=env))
    ob = env._get_obs(pos)
    assert ob.shape == env.observation_space.shape
Example #6
def test_pairwise_dists_is_translation_invariant(
        shapes: hnp.BroadcastableShapes, data: st.DataObject):
    shape_a, shape_b = shapes.input_shapes

    offset = data.draw(
        hnp.arrays(shape=(shape_a[1], ),
                   dtype=np.float64,
                   elements=st.floats(-1e2, 1e2)),
        label="offset",
    )
    array_a = data.draw(
        hnp.arrays(shape=shape_a,
                   dtype=np.float64,
                   elements=st.floats(-1e3, 1e3)),
        label="array_a",
    )

    array_b = data.draw(
        hnp.arrays(shape=shape_b,
                   dtype=np.float64,
                   elements=st.floats(-1e3, 1e3)),
        label="array_b",
    )

    dists = pairwise_dists(array_a, array_b)
    dists_w_offset = pairwise_dists(array_a + offset, array_b + offset)

    assert_allclose(dists, dists_w_offset, atol=1e-4, rtol=1e-4)
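A minimal NumPy sketch of the invariance being checked, assuming pairwise_dists computes Euclidean distances between the rows of its two inputs; pairwise_dists_numpy is an illustrative stand-in, not the function under test:

import numpy as np

def pairwise_dists_numpy(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Distance between every row of `a` and every row of `b`, via broadcasting.
    diffs = a[:, None, :] - b[None, :, :]  # shape: (len(a), len(b), dim)
    return np.sqrt((diffs ** 2).sum(axis=-1))

rng = np.random.default_rng(0)
a, b, offset = rng.random((4, 3)), rng.random((5, 3)), rng.random(3)

# Shifting both point clouds by the same offset leaves all pairwise distances unchanged.
np.testing.assert_allclose(pairwise_dists_numpy(a, b),
                           pairwise_dists_numpy(a + offset, b + offset))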
Example #7
def test_consume_removes_nothing_else(data: st.DataObject) -> None:
    """ Otherwise, food should remain in place. """
    env = data.draw(bst.envs())
    env.reset()
    food_obj_type_id = env.obj_type_ids["food"]
    agent_food_positions: Dict[int, Tuple[int, int]] = {}
    for agent_id, agent in env.agents.items():
        if env.grid[agent.pos + (food_obj_type_id,)] == 1:
            agent_food_positions[agent_id] = agent.pos
    tuple_action_dict = data.draw(bst.tuple_action_dicts(env=env))

    eating_positions: List[Tuple[int, int]] = []
    for agent_id, pos in agent_food_positions.items():
        if tuple_action_dict[agent_id][1] == env.EAT:
            eating_positions.append(pos)

    food_positions: Set[Tuple[int, int]] = set()
    for x in range(env.width):
        for y in range(env.height):
            if env.grid[(x, y) + (food_obj_type_id,)] == 1:
                food_positions.add((x, y))
    persistent_food_positions = food_positions - set(eating_positions)

    env._consume(tuple_action_dict)

    for pos in persistent_food_positions:
        assert env.grid[pos + (food_obj_type_id,)] == 1
Example #8
def test_move_correctly_modifies_agent_state(data: st.DataObject) -> None:
    """ Makes sure they actually move or STAY. """
    # TODO: Handle out-of-bounds errors.
    # TODO: Consider making the environment toroidal.
    env = data.draw(bst.envs())
    env.reset()
    old_locations: Dict[int, Tuple[int, int]] = {}
    for agent_id, agent in env.agents.items():
        old_locations[agent_id] = agent.pos
    tuple_action_dict = data.draw(bst.tuple_action_dicts(env=env))
    executed_dict = env._move(tuple_action_dict)

    # TODO: Consider making ``env.LEFT``, etc tuples which can be added to existing
    # positions rather than just integers.
    for agent_id, action in executed_dict.items():
        agent = env.agents[agent_id]
        move = action[0]
        old_pos = old_locations[agent_id]
        if move == env.UP or move == env.DOWN:
            assert agent.pos[0] == old_pos[0]
        if move == env.LEFT or move == env.RIGHT:
            assert agent.pos[1] == old_pos[1]
        if move == env.UP:
            assert agent.pos[1] == old_pos[1] + 1
        if move == env.DOWN:
            assert agent.pos[1] == old_pos[1] - 1
        if move == env.RIGHT:
            assert agent.pos[0] == old_pos[0] + 1
        if move == env.LEFT:
            assert agent.pos[0] == old_pos[0] - 1
        if move == env.STAY:
            assert agent.pos == old_pos
Example #9
def test_softmax_crossentropy(data: st.DataObject, labels_as_tensor: bool):
    s = data.draw(
        hnp.arrays(
            shape=hnp.array_shapes(max_side=10, min_dims=2, max_dims=2),
            dtype=float,
            elements=st.floats(-100, 100),
        ))
    y_true = data.draw(
        hnp.arrays(
            shape=(s.shape[0], ),
            dtype=hnp.integer_dtypes(),
            elements=st.integers(min_value=0, max_value=s.shape[1] - 1),
        ).map(Tensor if labels_as_tensor else lambda x: x))
    scores = Tensor(s)
    softmax_cross = softmax_crossentropy(scores, y_true, constant=False)
    softmax_cross.backward()

    mygrad_scores = Tensor(s)
    probs = softmax(mygrad_scores)

    correct_labels = (range(len(y_true)),
                      y_true.data if labels_as_tensor else y_true)
    truth = np.zeros(mygrad_scores.shape)
    truth[correct_labels] = 1

    mygrad_cross = (-1 / s.shape[0]) * (log(probs) * truth).sum()
    mygrad_cross.backward()
    assert_allclose(softmax_cross.data,
                    mygrad_cross.data,
                    atol=1e-5,
                    rtol=1e-5)
    assert_allclose(scores.grad, mygrad_scores.grad, atol=1e-5, rtol=1e-5)
Example #10
def test_obj_exists_handles_out_of_grid_positions(data: st.DataObject) -> None:
    """ Make sure the correct error is raised. """
    raised_index_error = False
    raised_value_error = False
    index_error = None
    value_error = None
    env = data.draw(bst.envs())
    obj_type_id = data.draw(bst.obj_type_ids(env=env))
    pos_indices = st.integers(min_value=-100, max_value=100)
    pos_strategy = st.tuples(pos_indices, pos_indices)
    pos: Tuple[int, int] = data.draw(pos_strategy)  # type: ignore
    try:
        existence = env._obj_exists(obj_type_id, pos)
    except IndexError as err:
        raised_index_error = True
        index_error = err
    except ValueError as err:
        raised_value_error = True
        value_error = err

    # If pos is in the grid, should raise no errors.
    if 0 <= pos[0] < env.width and 0 <= pos[1] < env.height:
        try:
            assert not raised_index_error
        except AssertionError:
            raise index_error  # type: ignore

    # If negative, should catch and raise a ValueError.
    elif pos[0] < 0 or pos[1] < 0:
        assert raised_value_error

    # Otherwise, currently raises an IndexError.
    # TODO: Add an error message for this?
    else:
        assert raised_index_error
Example #11
def test_get_obs_has_no_out_of_range_elements(data: st.DataObject) -> None:
    """ Make sure that a returned observation only contains 0 or 1. """
    env = data.draw(bst.envs())
    env.reset()
    pos = data.draw(bst.positions(env=env))
    ob = env._get_obs(pos)
    for elem in np.nditer(ob):
        assert elem in (0.0, 1.0)
Example #12
def test_mate_makes_num_agents_nondecreasing(data: st.DataObject) -> None:
    """ Makes sure ``len(agents)`` is nondecreasing. """
    env = data.draw(bst.envs())
    env.reset()
    old_num_agents = len(env.agents)
    tuple_action_dict = data.draw(bst.tuple_action_dicts(env=env))
    env._mate(tuple_action_dict)
    assert old_num_agents <= len(env.agents)
Example #13
def test_mate_adds_children_to_agents(data: st.DataObject) -> None:
    """ Makes sure child ids get added to ``env.agents``. """
    env = data.draw(bst.envs())
    env.reset()
    tuple_action_dict = data.draw(bst.tuple_action_dicts(env=env))
    child_ids = env._mate(tuple_action_dict)
    for child_id in child_ids:
        assert child_id in env.agents
Example #14
def test_setting_color_for_non_metric_is_silent(plotter: LivePlot,
                                                data: st.DataObject):
    color = {
        data.draw(st.text(), label="non_metric"):
        data.draw(cst.matplotlib_colors(), label="color")
    }
    original_colors = plotter.metric_colors
    plotter.metric_colors = color
    assert plotter.metric_colors == original_colors
Example #15
def test_mate_children_are_new(data: st.DataObject) -> None:
    """ Makes sure children are new. """
    env = data.draw(bst.envs())
    env.reset()
    old_agent_memory_addresses = [id(agent) for agent in env.agents.values()]
    tuple_action_dict = data.draw(bst.tuple_action_dicts(env=env))
    child_ids = env._mate(tuple_action_dict)
    for child_id in child_ids:
        assert id(env.agents[child_id]) not in old_agent_memory_addresses
Example #16
def test_setitem_basic_index(x: np.ndarray, data: st.DataObject):
    """ index conforms strictly to basic indexing """
    index = data.draw(basic_indices(x.shape), label="index")
    o = np.asarray(x[index])

    note("x[index]: {}".format(o))
    y = data.draw(
        (
            hnp.arrays(
                # Permit shapes that are broadcast-compatible with x[index]
                # The only excess dimensions permitted in this shape are
                # leading singletons
                shape=broadcastable_shapes(o.shape).map(
                    lambda _x: tuple(
                        1 if (len(_x) - n) > o.ndim else s for n, s in enumerate(_x)
                    )
                ),
                dtype=float,
                elements=st.floats(-10.0, 10.0),
            )
            if o.shape and o.size
            else st.floats(-10.0, 10.0).map(lambda _x: np.array(_x))
        ),
        label="y",
    )

    x0 = np.copy(x)
    y0 = np.copy(y)

    x_arr = Tensor(np.copy(x))
    y_arr = Tensor(np.copy(y))
    x1_arr = x_arr[:]

    try:
        x0[index] = y0  # don't permit invalid set-items
    except ValueError:
        assume(False)
        return

    grad = data.draw(
        hnp.arrays(shape=x.shape, dtype=float, elements=st.floats(1, 10), unique=True),
        label="grad",
    )

    x1_arr[index] = y_arr
    (x1_arr * grad).sum().backward()

    assert_allclose(x1_arr.data, x0)
    assert_allclose(y_arr.data, y0)

    dx, dy = numerical_gradient_full(
        setitem, x, y, back_grad=grad, kwargs=dict(index=index)
    )

    assert_allclose(x_arr.grad, dx)
    assert_allclose(y_arr.grad, dy)
Example #17
    def factory(source: st.DataObject) -> None:
        if _pyversion < (3, 7):  # pragma: no cover
            raise RuntimeError(
                'Hypothesis does not support several important ' +
                'typing features on python3.6, and earlier versions. ' +
                'Please update to at least python3.7', )

        with type_vars():
            with pure_functions():
                with container_strategies(container_type, settings=settings):
                    source.draw(st.builds(law.definition))
Example #18
def test_move_holds_other_actions_invariant(data: st.DataObject) -> None:
    """ Makes sure the returned action dict only modifies move subaction space. """
    env = data.draw(bst.envs())
    env.reset()

    tuple_action_dict = data.draw(bst.tuple_action_dicts(env=env))
    executed_dict = env._move(tuple_action_dict)

    pairs = zip(list(tuple_action_dict.values()), list(executed_dict.values()))
    for attempted_action, executed_action in pairs:
        assert attempted_action[1:] == executed_action[1:]
Example #19
    def update_pos(self, data: st.DataObject) -> None:
        pos = data.draw(bst.positions(env=self.env))
        move = data.draw(bst.moves(env=self.env))
        new_pos = self.env._update_pos(pos, move)

        if pos[0] != new_pos[0]:
            assert pos[1] == new_pos[1]
            assert abs(pos[0] - new_pos[0]) == 1
        if pos[1] != new_pos[1]:
            assert pos[0] == new_pos[0]
            assert abs(pos[1] - new_pos[1]) == 1
Example #20
def test_to_partitions(data: st.DataObject, sum_: Real) -> None:
    min_value = data.draw(st.floats(0, sum_))
    size = data.draw(
        st.integers(MIN_PARTITION_SIZE,
                    min(floor(sum_ / min_value), 100) if min_value else 100))
    strategy = requirements(sum_, min_value=min_value, size=size)

    partition = data.draw(strategy)

    assert sum(partition) == sum_
    assert len(partition) == size
    assert all(part >= min_value for part in partition)
Example #21
def test_move_only_changes_to_stay(data: st.DataObject) -> None:
    """ Makes sure the returned action dict only changes to STAY if at all. """
    env = data.draw(bst.envs())
    env.reset()

    tuple_action_dict = data.draw(bst.tuple_action_dicts(env=env))
    executed_dict = env._move(tuple_action_dict)

    pairs = zip(list(tuple_action_dict.values()), list(executed_dict.values()))
    for attempted_action, executed_action in pairs:
        if attempted_action[0] != executed_action[0]:
            assert executed_action[0] == env.STAY
Example #22
def test_obj_exists_marks_grid_squares_correctly(data: st.DataObject) -> None:
    """ Make sure emptiness of grid matches function output. """
    env = data.draw(bst.envs())
    env.reset()
    obj_type_id = data.draw(bst.obj_type_ids(env=env))

    for x in range(env.width):
        for y in range(env.height):
            if env.grid[x][y][obj_type_id] == 0:
                assert not env._obj_exists(obj_type_id, (x, y))
            else:
                assert env._obj_exists(obj_type_id, (x, y))
Example #23
def test_env_place_no_double_place_homo(data: st.DataObject) -> None:
    """ Tests that env gets angry if you try to double up h**o objs. """
    env = data.draw(strategies.envs())
    pos = data.draw(strategies.positions(env=env))
    env.reset()
    homo_obj_type_ids = set(
        env.obj_type_ids.values()) - env.heterogeneous_obj_type_ids
    obj_type_id = data.draw(st.sampled_from(list(homo_obj_type_ids)))

    if not env._obj_exists(obj_type_id, pos):
        env._place(obj_type_id, pos)
        with pytest.raises(ValueError):
            env._place(obj_type_id, pos)
Example #24
def test_upcast_roundtrip(type_strategy, data: st.DataObject):
    thin, wide = data.draw(
        st.tuples(type_strategy, type_strategy).map(
            lambda x: sorted(x, key=lambda y: np.dtype(y).itemsize)))
    orig_tensor = data.draw(
        hnp.arrays(
            dtype=thin,
            shape=hnp.array_shapes(),
            elements=hnp.from_dtype(thin).filter(np.isfinite),
        ).map(Tensor))

    roundtripped_tensor = orig_tensor.astype(wide).astype(thin)
    assert_array_equal(orig_tensor, roundtripped_tensor)
Example #25
def test_td_env(data: st.DataObject) -> None:
    """ Test. """
    tempdir = tempfile.mkdtemp()
    ox = data.draw(configs(tempdir))
    actions = data.draw(st.lists(st.integers(min_value=0, max_value=1)))
    env = TDEmissionsEnv(ox)
    ob = env.reset()
    assert ob.shape == (ox.resolution + 1, )
    for act in actions:
        ob, _rew, done, _info = env.step(act)
        assert ob.shape == (ox.resolution + 1, )
        if done:
            break
    shutil.rmtree(tempdir)
Example #26
def test_consume_makes_agent_health_nondecreasing(data: st.DataObject) -> None:
    """ Tests that agent.health in the correct direction. """
    env = data.draw(bst.envs())
    env.reset()
    tuple_action_dict = data.draw(bst.tuple_action_dicts(env=env))

    old_healths: Dict[int, float] = {}
    for agent_id, agent in env.agents.items():
        old_healths[agent_id] = agent.health

    env._consume(tuple_action_dict)

    for agent_id, agent in env.agents.items():
        assert old_healths[agent_id] <= agent.health
Example #27
def test_seq_mult(shape_1: Tuple[int, ...], num_arrays: int,
                  data: st.DataObject):
    shape_2 = data.draw(hnp.broadcastable_shapes(shape_1), label="shape_2")
    shapes = [shape_1, shape_2]

    pair = shapes

    for i in range(num_arrays):

        # ensure sequence of shapes is mutually-broadcastable
        broadcasted = _broadcast_shapes(*pair)
        shapes.append(
            data.draw(hnp.broadcastable_shapes(broadcasted),
                      label="shape_{}".format(i + 3)))
        pair = [broadcasted, shapes[-1]]

    tensors = [
        Tensor(
            data.draw(
                hnp.arrays(shape=shape,
                           dtype=np.float32,
                           elements=st.floats(-10, 10, width=32))))
        for shape in shapes
    ]
    note("tensors: {}".format(tensors))
    tensors_copy = [x.copy() for x in tensors]

    f = multiply_sequence(*tensors)
    f1 = reduce(lambda x, y: x * y,
                (var for n, var in enumerate(tensors_copy)))

    assert_allclose(f.data, f1.data)

    f.sum().backward()
    f1.sum().backward()

    assert_allclose(f.data, f1.data, rtol=1e-4, atol=1e-4)

    for n, (expected, actual) in enumerate(zip(tensors_copy, tensors)):
        assert_allclose(
            expected.grad,
            actual.grad,
            rtol=1e-3,
            atol=1e-3,
            err_msg="tensor-{}".format(n),
        )

    f.null_gradients()
    assert all(x.grad is None for x in tensors)
    assert all(not x._ops for x in tensors)
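The _broadcast_shapes helper used above is not shown in this example; in NumPy 1.20+ the equivalent result can be obtained with np.broadcast_shapes, which applies the standard broadcasting rules to a sequence of shapes:

import numpy as np

assert np.broadcast_shapes((3, 1), (1, 4)) == (3, 4)
assert np.broadcast_shapes((2, 3, 4), (3, 1), (1,)) == (2, 3, 4)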
Example #28
def test_EKFCell_forward(data: st.DataObject, is_continuous: bool):
    batch_dim = data.draw(st.integers(min_value=1, max_value=16))
    torch.manual_seed(data.draw(st.integers(min_value=0, max_value=10)))

    network = LinearSystem(2, 2)
    ekf = EKFCell(network, None, is_continuous=is_continuous)
    ekf_z0 = ekf.get_initial_hidden_state(batch_dim)
    if is_continuous:
        ekf_forward = odeint(wrap_zero_input(ekf), ekf_z0,
                             ekf_z0.new_tensor([0, 0.1]))[1, :, :]
    else:
        ekf_forward = ekf.forward(ekf_z0.new_tensor([0]), ekf_z0,
                                  torch.zeros(batch_dim, 0))
    mean, cov = ekf.vector_to_gaussian_parameters(ekf_forward)
    assert_tensor_symmetric_psd(cov)
Example #29
def test_get_obs_has_correct_objects(data: st.DataObject) -> None:
    """ Make sure that a returned observation is accurate w.r.t. ``env.grid``. """
    env = data.draw(bst.envs())
    env.reset()
    pos = data.draw(bst.positions(env=env))
    ob = env._get_obs(pos)
    for i in range(ob.shape[1]):
        for j in range(ob.shape[2]):
            ob_pos = (pos[0] + i - env.sight_len, pos[1] + j - env.sight_len)
            if 0 <= ob_pos[0] < env.width and 0 <= ob_pos[1] < env.height:
                ob_square = ob[:, i, j]
                env_square = env.grid[ob_pos]
                assert np.all(ob_square == env_square)
            else:
                assert np.all(ob[:, i, j] == np.zeros((env.num_obj_types,)))
Example #30
def test_mate_executes_action(data: st.DataObject) -> None:
    """ Tests children are created when they're suppsed to. """
    env = data.draw(bst.envs())

    assume(env.height * env.width >= 3)

    # Generate two adjacent positions.
    mom_pos = data.draw(bst.positions(env=env))
    open_positions = env._get_adj_positions(mom_pos)
    dad_pos = data.draw(st.sampled_from(open_positions))

    # Create a mom and dad.
    mom = Agent(
        config=env.config, num_actions=env.num_actions, pos=mom_pos, initial_health=1.0,
    )
    dad = Agent(
        config=env.config, num_actions=env.num_actions, pos=dad_pos, initial_health=1.0,
    )
    mom.is_mature = True
    dad.is_mature = True
    mom.mating_cooldown = 0
    dad.mating_cooldown = 0
    mom_id = env._new_agent_id()
    dad_id = env._new_agent_id()
    env.agents[mom_id] = mom
    env.agents[dad_id] = dad
    env._place(env.obj_type_ids["agent"], mom_pos, mom_id)
    env._place(env.obj_type_ids["agent"], dad_pos, dad_id)

    # Construct subactions.
    mom_move = data.draw(bst.moves(env=env))
    dad_move = data.draw(bst.moves(env=env))
    mom_consumption = data.draw(bst.consumptions(env=env))
    dad_consumption = data.draw(bst.consumptions(env=env))
    mom_action = (mom_move, mom_consumption, env.MATE)
    dad_action = (dad_move, dad_consumption, env.MATE)

    action_dict = {mom_id: mom_action, dad_id: dad_action}
    child_ids = env._mate(action_dict)
    assert len(child_ids) == 1
    child = env.agents[child_ids.pop()]

    def adjacent(pos1: Tuple[int, int], pos2: Tuple[int, int]) -> bool:
        """ Decide whether or not two positions are orthogonally adjacent. """
        return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) == 1

    assert len(env.agents) == 3
    assert adjacent(child.pos, mom.pos) or adjacent(child.pos, dad.pos)