Ejemplo n.º 1
0
def test_split_stats_EMA_random_capped() -> None:
    """
    Test that `get_split_statistics()` correctly computes the z-score over the pairwise
    differences in task gradients in the case of random gradients at each time step,
    using both arithmetic mean and EMA to keep track of gradient statistics, and not
    capping the gradient statistic sample size.

    NOTE(review): the test name says "capped" but `cap_sample_size` is set to False
    below (matching this docstring) -- confirm which behavior is intended.
    """

    # Set up case.
    settings = dict(V1_SETTINGS)
    settings["obs_dim"] = 2
    settings["num_tasks"] = 4
    settings["cap_sample_size"] = False
    settings["ema_alpha"] = 0.99
    settings["hidden_size"] = settings["obs_dim"] + settings["num_tasks"] + 2
    ema_threshold = alpha_to_threshold(settings["ema_alpha"])

    # Construct series of splits (none: the network is tested unsplit).
    splits_args = []

    # Construct a sequence of task gradients. The network gradient statistics will be
    # updated with these task gradients, and the z-scores will be computed from these
    # statistics. Run past the EMA threshold so both averaging regimes are exercised.
    total_steps = ema_threshold + 20
    dim = settings["obs_dim"] + settings["num_tasks"]
    max_region_size = settings["hidden_size"]**2 + settings["hidden_size"]
    task_grads = torch.zeros(total_steps, settings["num_tasks"],
                             settings["num_layers"], max_region_size)

    # Iterate over plain integer region indices. Wrapping the range in `product()`
    # yields 1-tuples like `(0,)`, which never compare equal to an int, so the first
    # and last regions would wrongly be filled up to `max_region_size`.
    for region in range(settings["num_layers"]):
        if region == 0:
            region_size = settings["hidden_size"] * (dim + 1)
        elif region == settings["num_layers"] - 1:
            region_size = dim * (settings["hidden_size"] + 1)
        else:
            region_size = max_region_size

        # Entries past `region_size` stay zero as padding.
        task_grads[:, :, region, :region_size] = torch.rand(
            total_steps, settings["num_tasks"], region_size)

    # Run test.
    split_stats_template(settings, task_grads, splits_args)
Ejemplo n.º 2
0
"""
Unit tests for meta/utils/estimate.py.
"""

import torch

from meta.utils.estimate import RunningStats, alpha_to_threshold


# Numeric tolerance constant (presumably for float comparisons in these tests).
TOL = 1e-5
# Smoothing factor passed to `RunningStats(ema_alpha=...)` below.
EMA_ALPHA = 0.999
# Sample count derived from EMA_ALPHA by `alpha_to_threshold`; presumably the point
# at which RunningStats switches from an arithmetic mean to an EMA -- TODO confirm.
EMA_THRESHOLD = alpha_to_threshold(EMA_ALPHA)


def test_mean_arithmetic():
    """
    Check that RunningStats reports the arithmetic mean of every sample fed to it so
    far, after each individual update.
    """

    # Three 3x3 samples: an ascending grid, its elementwise negation, and the grid
    # again, stacked along a new leading axis.
    shape = (3, 3)
    data = torch.stack([
        torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]),
        torch.Tensor([[0, -1, -2], [-3, -4, -5], [-6, -7, -8]]),
        torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]),
    ])

    # Feed the samples one at a time and compare against a batch mean of the prefix.
    stats = RunningStats(shape=shape, ema_alpha=EMA_ALPHA)
    for num_seen, sample in enumerate(data, start=1):
        stats.update(sample)
        expected = torch.mean(data[:num_seen], dim=0)
        assert torch.allclose(stats.mean, expected)
Ejemplo n.º 3
0
def split_stats_template(
    settings: Dict[str, Any],
    task_grads: torch.Tensor,
    splits_args: List[Dict[str, Any]],
) -> None:
    """
    Test that `get_split_statistics()` correctly computes the z-score over the pairwise
    differences in task gradients.

    The network's gradient statistics are updated one step at a time with
    `task_grads`. After every step, the z-scores returned by the network are compared
    (within `TOL`) against z-scores recomputed here from the raw gradients. Expected
    means use an arithmetic mean over the first `ema_threshold` samples and an EMA with
    `settings["ema_alpha"]` afterwards. A task whose gradient is zero across all of its
    entries at a step is treated as absent from that step's batch.

    Arguments
    ---------
    settings : Dict[str, Any]
        Dictionary holding misc settings for how to run trial. Keys read here:
        "obs_dim", "num_tasks", "num_layers", "hidden_size", "grad_var",
        "cap_sample_size", "ema_alpha", "device".
    task_grads : torch.Tensor
        Tensor of size `(total_steps, network.num_tasks, network.num_regions,
        network.max_region_size)` which holds the task gradients for multiple steps that
        we will compute statistics over.
    splits_args : List[Dict[str, Any]]
        List of splits to execute on network; each dict is passed as keyword arguments
        to `network.split()`.
    """

    dim = settings["obs_dim"] + settings["num_tasks"]
    ema_alpha = settings["ema_alpha"]

    # Construct network.
    network = MultiTaskSplittingNetworkV1(
        input_size=dim,
        output_size=dim,
        num_tasks=settings["num_tasks"],
        num_layers=settings["num_layers"],
        hidden_size=settings["hidden_size"],
        grad_var=settings["grad_var"],
        cap_sample_size=settings["cap_sample_size"],
        ema_alpha=ema_alpha,
        device=settings["device"],
    )
    ema_threshold = alpha_to_threshold(ema_alpha)

    # Split the network according to `splits_args`.
    for split_args in splits_args:
        network.split(**split_args)

    # Check that the region sizes are what we think they are: first and last layers
    # depend on the input/output size, middle layers are hidden x hidden plus bias.
    expected_region_sizes = torch.zeros(settings["num_layers"],
                                        dtype=torch.long)
    expected_region_sizes[
        1:-1] = settings["hidden_size"]**2 + settings["hidden_size"]
    expected_region_sizes[0] = settings["hidden_size"] * (dim + 1)
    expected_region_sizes[-1] = dim * (settings["hidden_size"] + 1)
    assert torch.all(expected_region_sizes == network.region_sizes)
    region_sizes = expected_region_sizes.tolist()

    # Update the network's gradient statistics with our constructed task gradients,
    # compute the split statistics at each step along the way, and compare the computed
    # z-scores against the expected z-scores.
    task_flags = torch.zeros(len(task_grads), network.num_tasks)
    task_pair_flags = torch.zeros(len(task_grads), network.num_tasks,
                                  network.num_tasks)
    for step in range(len(task_grads)):
        network.num_steps += 1
        network.update_grad_stats(task_grads[step])
        z = network.get_split_statistics()
        assert z.shape == (network.num_tasks, network.num_tasks,
                           network.num_regions)

        # Set task flags, i.e. indicators for whether or not each task is included in
        # each batch, and compute sample sizes for each task and task pair. Assigning
        # the bool result into the float tensor stores 1.0 / 0.0.
        task_flags[step] = torch.any(
            task_grads[step].view(network.num_tasks, -1) != 0, dim=1)
        task_pair_flags[step] = task_flags[step].unsqueeze(
            0) * task_flags[step].unsqueeze(1)
        sample_sizes = torch.sum(task_flags[:step + 1], dim=0)
        pair_sample_sizes = torch.sum(task_pair_flags[:step + 1], dim=0)

        # Compute the variance over all gradient values up to `step`. We have to do
        # this differently based on whether or not we have hit the EMA threshold for
        # each task.
        task_vars = torch.zeros(network.num_tasks)
        for task in range(network.num_tasks):
            task_sample_size = int(sample_sizes[task])
            if task_sample_size == 0:
                # Task never seen yet: its variance stays 0.
                continue

            # Gradients of this task at the steps where it was present only.
            task_steps = task_flags[:, task].bool()
            task_grad = task_grads[task_steps, task:task + 1]

            if task_sample_size <= ema_threshold:
                # Below the threshold: plain population variance over all samples.
                flattened_grad = get_flattened_grads(
                    task_grad,
                    1,
                    region_sizes,
                    0,
                    task_sample_size,
                )[0]
                task_var = torch.var(flattened_grad, unbiased=False)
            else:
                # Arithmetic moments over the first `ema_threshold` samples, then EMA
                # updates of the first and second moments for each later sample.
                flattened_grad = get_flattened_grads(task_grad, 1,
                                                     region_sizes, 0,
                                                     ema_threshold)[0]
                grad_mean = torch.mean(flattened_grad)
                grad_square_mean = torch.mean(flattened_grad**2)
                for i in range(ema_threshold, task_sample_size):
                    flattened_grad = get_flattened_grads(
                        task_grad, 1, region_sizes, i, i + 1)[0]
                    new_mean = torch.mean(flattened_grad)
                    new_square_mean = torch.mean(flattened_grad**2)
                    grad_mean = (grad_mean * ema_alpha +
                                 new_mean * (1.0 - ema_alpha))
                    grad_square_mean = (grad_square_mean * ema_alpha +
                                        new_square_mean * (1.0 - ema_alpha))
                task_var = grad_square_mean - grad_mean**2

            task_vars[task] = task_var

        # Pooled variance (weighted by per-task sample size) unless a fixed gradient
        # variance was given as a hyperparameter.
        if settings["grad_var"] is None:
            grad_var = torch.sum(
                task_vars * sample_sizes) / torch.sum(sample_sizes)
        else:
            grad_var = settings["grad_var"]

        # Compare `z` to the expected value for each `(task1, task2, region)`.
        for task1, task2, region in product(
                range(network.num_tasks),
                range(network.num_tasks),
                range(network.num_regions),
        ):
            region_size = int(network.region_sizes[region])

            # Compute the expected value of the mean of squared gradient differences
            # between `task1, task2` at region `region`, over steps where both tasks
            # were present. Skip pairs that have never co-occurred.
            steps = task_pair_flags[:, task1, task2].bool()
            if not torch.any(steps):
                continue
            task1_grads = task_grads[steps, task1, region, :]
            task2_grads = task_grads[steps, task2, region, :]
            if pair_sample_sizes[task1, task2] <= ema_threshold:
                diffs = torch.sum((task1_grads - task2_grads)**2, dim=1)
                exp_mean = torch.mean(diffs)
            else:
                initial_task1_grads = task1_grads[:ema_threshold]
                initial_task2_grads = task2_grads[:ema_threshold]
                diffs = torch.sum(
                    (initial_task1_grads - initial_task2_grads)**2, dim=1)
                exp_mean = torch.mean(diffs)
                for i in range(ema_threshold,
                               int(pair_sample_sizes[task1, task2])):
                    diff = torch.sum((task1_grads[i] - task2_grads[i])**2)
                    exp_mean = exp_mean * ema_alpha + diff * (1.0 - ema_alpha)

            # Compute the expected z-score.
            sample_size = int(pair_sample_sizes[task1, task2])
            if settings["cap_sample_size"]:
                sample_size = min(sample_size, ema_threshold)
            exp_mu = 2 * region_size * grad_var
            exp_sigma = 2 * math.sqrt(2 * region_size) * grad_var
            expected_z = math.sqrt(sample_size) * (exp_mean -
                                                   exp_mu) / exp_sigma
            assert abs(z[task1, task2, region] - expected_z) < TOL
Ejemplo n.º 4
0
def test_split_stats_manual() -> None:
    """
    Test that `get_split_statistics()` correctly computes the z-score over the pairwise
    differences in task gradients for manually computed values.

    A single-layer network (one region) with 4 tasks is fed 9 hand-written gradient
    steps; after each step the per-task and per-pair sample sizes, gradient means,
    gradient variances, pairwise squared-difference means, and z-scores are checked
    against precomputed tables.
    """

    # Set up case.
    settings = dict(V1_SETTINGS)
    settings["num_layers"] = 1
    settings["num_tasks"] = 4
    settings["ema_alpha"] = 0.8
    input_size = 1
    output_size = 2
    settings["hidden_size"] = 2

    # Construct a sequence of task gradients. The network gradient statistics will be
    # updated with these task gradients, and the z-scores will be computed from these
    # statistics. Indexing is `[step][task][region][entry]`; an all-zero row marks a
    # task that is absent from that step's batch.
    task_grads = torch.Tensor([
        [
            [[-0.117, 0.08, -0.091, -0.008]],
            [[0, 0, 0, 0]],
            [[-0.053, 0.078, -0.046, 0.017]],
            [[0, 0, 0, 0]],
        ],
        [
            [[-0.006, 0.083, -0.065, -0.095]],
            [[0.037, 0.051, 0.009, -0.075]],
            [[0.107, 0.264, -0.072, 0.143]],
            [[0.049, 0.03, -0.051, -0.012]],
        ],
        [
            [[0.106, -0.092, -0.015, 0.159]],
            [[0, 0, 0, 0]],
            [[0.055, 0.115, -0.096, 0.032]],
            [[-0.21, 0.11, -0.091, -0.014]],
        ],
        [
            [[-0.116, 0.079, 0.087, 0.041]],
            [[0.094, 0.143, -0.015, -0.008]],
            [[-0.056, -0.054, 0.01, 0.073]],
            [[0.103, -0.085, -0.008, -0.018]],
        ],
        [
            [[-0.147, -0.067, -0.063, -0.022]],
            [[-0.098, 0.059, 0.064, 0.045]],
            [[-0.037, 0.138, 0.06, -0.056]],
            [[0, 0, 0, 0]],
        ],
        [
            [[-0.062, 0.001, 0.106, -0.176]],
            [[-0.007, 0.013, -0.095, 0.082]],
            [[-0.003, 0.066, 0.106, -0.17]],
            [[-0.035, -0.027, -0.105, 0.058]],
        ],
        [
            [[0.114, -0.191, -0.054, -0.122]],
            [[0.053, 0.004, -0.019, 0.053]],
            [[0.155, -0.027, 0.054, -0.015]],
            [[0.073, 0.042, -0.08, 0.056]],
        ],
        [
            [[0.094, 0.002, 0.078, -0.049]],
            [[-0.116, 0.205, 0.175, -0.026]],
            [[-0.178, 0.013, -0.012, 0.136]],
            [[-0.05, 0.105, 0.114, -0.053]],
        ],
        [
            [[0, 0, 0, 0]],
            [[-0.171, -0.001, 0.069, -0.077]],
            [[0.11, 0.053, 0.039, -0.005]],
            [[-0.097, 0.046, 0.124, 0.072]],
        ],
    ])
    total_steps = len(task_grads)

    # Set expected values of gradient statistics, one table per step.
    # Pairwise mean squared gradient differences, indexed `[step][task1][task2]`.
    expected_grad_diff_mean = torch.Tensor([
        [[0, 0, 0.00675, 0], [0, 0, 0, 0], [0.00675, 0, 0, 0], [0, 0, 0, 0]],
        [
            [0, 0.008749, 0.0544865, 0.012919],
            [0.008749, 0, 0.104354, 0.008154],
            [0.0544865, 0.104354, 0, 0.082586],
            [0.012919, 0.008154, 0.082586, 0],
        ],
        [
            [0, 0.008749, 0.05903766667, 0.094642],
            [0.008749, 0, 0.104354, 0.008154],
            [0.05903766667, 0.104354, 0, 0.0774885],
            [0.094642, 0.008154, 0.0774885, 0],
        ],
        [
            [0, 0.034875, 0.05133875, 0.09221566667],
            [0.034875, 0, 0.0864245, 0.030184],
            [0.05133875, 0.0864245, 0, 0.06327466667],
            [0.09221566667, 0.030184, 0.06327466667, 0],
        ],
        [
            [0, 0.036215, 0.055153, 0.09221566667],
            [0.036215, 0, 0.06434266667, 0.030184],
            [0.055153, 0.06434266667, 0, 0.06327466667],
            [0.09221566667, 0.030184, 0.06327466667, 0],
        ],
        [
            [0, 0.05469475, 0.0456708, 0.09435925],
            [0.05469475, 0, 0.0749395, 0.02114266667],
            [0.0456708, 0.0749395, 0, 0.0740005],
            [0.09435925, 0.02114266667, 0.0740005, 0],
        ],
        [
            [0, 0.058475, 0.04687464, 0.0931534],
            [0.058475, 0, 0.0642152, 0.0172505],
            [0.04687464, 0.0642152, 0, 0.0660968],
            [0.0931534, 0.0172505, 0.0660968, 0],
        ],
        [
            [0, 0.0658294, 0.060785712, 0.08105412],
            [0.0658294, 0, 0.07175636, 0.0175616],
            [0.060785712, 0.07175636, 0, 0.06816644],
            [0.08105412, 0.0175616, 0.06816644, 0],
        ],
        [
            [0, 0.0658294, 0.060785712, 0.08105412],
            [0.0658294, 0, 0.074997288, 0.02063148],
            [0.060785712, 0.074997288, 0, 0.065743552],
            [0.08105412, 0.02063148, 0.065743552, 0],
        ],
    ])
    # Append a trailing size-1 axis, presumably the single region (num_layers = 1).
    expected_grad_diff_mean = expected_grad_diff_mean.unsqueeze(-1)

    # Per-task gradient means, indexed `[step][task]`.
    expected_grad_mean = torch.Tensor([
        [-0.034, 0, -0.001, 0],
        [-0.027375, 0.0055, 0.05475, 0.004],
        [-0.005083333333, 0.0055, 0.04533333333, -0.023625],
        [0.001875, 0.0295, 0.0323125, -0.01641666667],
        [-0.01345, 0.0255, 0.0311, -0.01641666667],
        [-0.01731, 0.0186875, 0.02483, -0.019125],
        [-0.026498, 0.0195, 0.028214, -0.01075],
        [-0.0149484, 0.0275, 0.0205212, -0.0028],
        [-0.0149484, 0.013, 0.02626696, 0.00501],
    ])

    # Per-task gradient variances, indexed `[step][task]`.
    expected_grad_var = torch.Tensor([
        [0.0059525, 0, 0.0028235, 0],
        [0.005326734375, 0.00238875, 0.0117619375, 0.0014955],
        [0.007792076389, 0.00238875, 0.009992055556, 0.008282234375],
        [0.007669109375, 0.004036, 0.008708839844, 0.007142576389],
        [0.0074847475, 0.004221083333, 0.00819259, 0.007142576389],
        [0.0081357339, 0.004302214844, 0.0089363611, 0.006214734375],
        [0.009410001996, 0.00364065, 0.008241032204, 0.0059802875],
        [0.008732512137, 0.00679957, 0.009333179951, 0.00633534],
        [0.008732512137, 0.007872256, 0.007936236492, 0.0066536939],
    ])

    # Pairwise z-scores, indexed `[step][task1][task2]`.
    expected_z = torch.Tensor([
        [
            [-1.414213562, 0, -1.142280405, 0],
            [0, 0, 0, 0],
            [-1.142280405, 0, -1.414213562, 0],
            [0, 0, 0, 0],
        ],
        [
            [-2, -1.170405699, 0.1473023578, -1.054200557],
            [-1.170405699, -1.414213562, 1.493813156, -1.186986529],
            [0.1473023578, 1.493813156, -2, 0.8872055932],
            [-1.054200557, -1.186986529, 0.8872055932, -1.414213562],
        ],
        [
            [-2.449489743, -1.221703287, -0.1994752765, 0.9450617525],
            [-1.221703287, -1.414213562, 0.8819594016, -1.234795482],
            [-0.1994752765, 0.8819594016, -2.449489743, 0.4112805901],
            [0.9450617525, -1.234795482, 0.4112805901, -2],
        ],
        [
            [-2.828427125, -0.8070526312, -0.3449088766, 1.413801122],
            [-0.8070526312, -2, 0.9562689571, -0.9675147418],
            [-0.3449088766, 0.9562689571, -2.828427125, 0.2013444437],
            [1.413801122, -0.9675147418, 0.2013444437, -2.449489743],
        ],
        [
            [-3.16227766, -0.8721406803, -0.06105579169, 1.566975682],
            [-0.8721406803, -2.449489743, 0.3529635209, -0.926578017],
            [-0.06105579169, 0.3529635209, -3.16227766, 0.3064466412],
            [1.566975682, -0.926578017, 0.3064466412, -2.449489743],
        ],
        [
            [-3.16227766, -0.09688841312, -0.6121886259, 1.884016833],
            [-0.09688841312, -2.828427125, 0.9141650853, -1.535056275],
            [-0.6121886259, 0.9141650853, -3.16227766, 0.8672700021],
            [1.884016833, -1.535056275, 0.8672700021, -2.828427125],
        ],
        [
            [-3.16227766, 0.227909676, -0.4446408767, 2.238448753],
            [0.227909676, -3.16227766, 0.56070751, -1.933886337],
            [-0.4446408767, 0.56070751, -3.16227766, 0.6697964624],
            [2.238448753, -1.933886337, 0.6697964624, -3.16227766],
        ],
        [
            [-3.16227766, 0.1737291325, -0.08186756787, 0.9452653965],
            [0.1737291325, -3.16227766, 0.4740870094, -2.272316383],
            [-0.08186756787, 0.4740870094, -3.16227766, 0.2921622538],
            [0.9452653965, -2.272316383, 0.2921622538, -3.16227766],
        ],
        [
            [-3.16227766, 0.1743604677, -0.08128460407, 0.946042744],
            [0.1743604677, -3.16227766, 0.6390453145, -2.116547594],
            [-0.08128460407, 0.6390453145, -3.16227766, 0.170009164],
            [0.946042744, -2.116547594, 0.170009164, -3.16227766],
        ],
    ])
    expected_z = expected_z.unsqueeze(-1)

    # Per-task sample sizes, indexed `[step][task]`.
    expected_sample_size = torch.Tensor([
        [1, 0, 1, 0],
        [2, 1, 2, 1],
        [3, 1, 3, 2],
        [4, 2, 4, 3],
        [5, 3, 5, 3],
        [5, 4, 5, 4],
        [5, 5, 5, 5],
        [5, 5, 5, 5],
        [5, 5, 5, 5],
    ])
    # Per-pair sample sizes, indexed `[step][task1][task2]`.
    expected_pair_sample_size = torch.Tensor([
        [[1, 0, 1, 0], [0, 0, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]],
        [[2, 1, 2, 1], [1, 1, 1, 1], [2, 1, 2, 1], [1, 1, 1, 1]],
        [[3, 1, 3, 2], [1, 1, 1, 1], [3, 1, 3, 2], [2, 1, 2, 2]],
        [[4, 2, 4, 3], [2, 2, 2, 2], [4, 2, 4, 3], [3, 2, 3, 3]],
        [[5, 3, 5, 3], [3, 3, 3, 2], [5, 3, 5, 3], [3, 2, 3, 3]],
        [[5, 4, 5, 4], [4, 4, 4, 3], [5, 4, 5, 4], [4, 3, 4, 4]],
        [[5, 5, 5, 5], [5, 5, 5, 4], [5, 5, 5, 5], [5, 4, 5, 5]],
        [[5, 5, 5, 5], [5, 5, 5, 5], [5, 5, 5, 5], [5, 5, 5, 5]],
        [[5, 5, 5, 5], [5, 5, 5, 5], [5, 5, 5, 5], [5, 5, 5, 5]],
    ])
    expected_pair_sample_size = expected_pair_sample_size.unsqueeze(-1)

    # Instantiate network.
    network = MultiTaskSplittingNetworkV1(
        input_size=input_size,
        output_size=output_size,
        num_tasks=settings["num_tasks"],
        num_layers=settings["num_layers"],
        hidden_size=settings["hidden_size"],
        ema_alpha=settings["ema_alpha"],
    )

    # Update gradient statistics for each step.
    for step in range(total_steps):
        network.num_steps += 1
        network.update_grad_stats(task_grads[step])
        z = network.get_split_statistics()

        # Compare network statistics to expected values.
        assert torch.all(
            network.grad_stats.sample_size == expected_sample_size[step])
        assert torch.all(network.grad_diff_stats.sample_size ==
                         expected_pair_sample_size[step])
        assert torch.allclose(network.grad_diff_stats.mean,
                              expected_grad_diff_mean[step])
        assert torch.allclose(network.grad_stats.mean,
                              expected_grad_mean[step])
        assert torch.allclose(network.grad_stats.var, expected_grad_var[step])
        assert torch.allclose(z, expected_z[step], atol=TOL)
Ejemplo n.º 5
0
def test_split_stats_EMA_random_split_grad_var() -> None:
    """
    Test that `get_split_statistics()` correctly computes the z-score over the pairwise
    differences in task gradients in the case of random gradients at each time step, a
    split network, using both arithmetic mean and EMA to keep track of gradient
    statistics, when the variance of task-gradients is given as a hyperparameter
    (`grad_var`) instead of measured online.
    """

    # Set up case.
    settings = dict(V1_SETTINGS)
    settings["obs_dim"] = 2
    settings["num_tasks"] = 4
    settings["grad_var"] = 0.01
    settings["ema_alpha"] = 0.99
    settings["hidden_size"] = settings["obs_dim"] + settings["num_tasks"] + 2
    ema_threshold = alpha_to_threshold(settings["ema_alpha"])

    # Construct series of splits; each dict is passed as kwargs to `network.split()`.
    splits_args = [
        {
            "region": 0,
            "copy": 0,
            "group1": [0, 1],
            "group2": [2, 3]
        },
        {
            "region": 1,
            "copy": 0,
            "group1": [0, 2],
            "group2": [1, 3]
        },
        {
            "region": 1,
            "copy": 1,
            "group1": [1],
            "group2": [3]
        },
        {
            "region": 2,
            "copy": 0,
            "group1": [0, 3],
            "group2": [1, 2]
        },
    ]

    # Construct a sequence of task gradients. The network gradient statistics will be
    # updated with these task gradients, and the z-scores will be computed from these
    # statistics. Run past the EMA threshold so both averaging regimes are exercised.
    total_steps = ema_threshold + 20
    dim = settings["obs_dim"] + settings["num_tasks"]
    max_region_size = settings["hidden_size"]**2 + settings["hidden_size"]
    task_grads = torch.zeros(total_steps, settings["num_tasks"],
                             settings["num_layers"], max_region_size)

    # Iterate over plain integer region indices. Wrapping the range in `product()`
    # yields 1-tuples like `(0,)`, which never compare equal to an int, so the first
    # and last regions would wrongly be filled up to `max_region_size`.
    for region in range(settings["num_layers"]):
        if region == 0:
            region_size = settings["hidden_size"] * (dim + 1)
        elif region == settings["num_layers"] - 1:
            region_size = dim * (settings["hidden_size"] + 1)
        else:
            region_size = max_region_size

        # Entries past `region_size` stay zero as padding.
        task_grads[:, :, region, :region_size] = torch.rand(
            total_steps, settings["num_tasks"], region_size)

    # Run test.
    split_stats_template(settings, task_grads, splits_args)
Ejemplo n.º 6
0
def test_split_stats_EMA_random_split_batch() -> None:
    """
    Test that `get_split_statistics()` correctly computes the z-score over the pairwise
    differences in task gradients in the case of random gradients at each time step, a
    split network, using both arithmetic mean and EMA to keep track of gradient
    statistics, and when the gradient batches each contain gradients for only a subset
    of all tasks.
    """

    # Set up case.
    settings = dict(V1_SETTINGS)
    settings["obs_dim"] = 2
    settings["num_tasks"] = 4
    settings["ema_alpha"] = 0.99
    settings["hidden_size"] = settings["obs_dim"] + settings["num_tasks"] + 2
    ema_threshold = alpha_to_threshold(settings["ema_alpha"])

    # Construct series of splits; each dict is passed as kwargs to `network.split()`.
    splits_args = [
        {
            "region": 0,
            "copy": 0,
            "group1": [0, 1],
            "group2": [2, 3]
        },
        {
            "region": 1,
            "copy": 0,
            "group1": [0, 2],
            "group2": [1, 3]
        },
        {
            "region": 1,
            "copy": 1,
            "group1": [1],
            "group2": [3]
        },
        {
            "region": 2,
            "copy": 0,
            "group1": [0, 3],
            "group2": [1, 2]
        },
    ]

    # Construct a sequence of task gradients. The network gradient statistics will be
    # updated with these task gradients, and the z-scores will be computed from these
    # statistics. Run past the EMA threshold so both averaging regimes are exercised.
    total_steps = ema_threshold + 20
    dim = settings["obs_dim"] + settings["num_tasks"]
    max_region_size = settings["hidden_size"]**2 + settings["hidden_size"]
    task_grads = torch.zeros(total_steps, settings["num_tasks"],
                             settings["num_layers"], max_region_size)
    for step in range(total_steps):

        # Generate tasks for each batch. Each task has a 50-50 chance of being included
        # in each batch; excluded tasks get an all-zero gradient.
        batch_tasks = torch.rand(settings["num_tasks"]) < 0.5
        batch_tasks = batch_tasks.view(settings["num_tasks"], 1)

        # Iterate over plain integer region indices. Wrapping the range in `product()`
        # yields 1-tuples like `(0,)`, which never compare equal to an int, so the
        # first and last regions would wrongly be filled up to `max_region_size`.
        for region in range(settings["num_layers"]):
            if region == 0:
                region_size = settings["hidden_size"] * (dim + 1)
            elif region == settings["num_layers"] - 1:
                region_size = dim * (settings["hidden_size"] + 1)
            else:
                region_size = max_region_size

            # Entries past `region_size` stay zero as padding.
            local_grad = torch.rand(settings["num_tasks"], region_size)
            local_grad *= batch_tasks
            task_grads[step, :, region, :region_size] = local_grad

    # Run test.
    split_stats_template(settings, task_grads, splits_args)