def evaluate(method: storch.method.Method, model: DiscreteVAE, data, optimizer):
    """Report the ELBO, gradient variance and gradient bias of ``method``.

    A reference gradient for the ``"probs"`` parameter is first obtained by
    running the generative story with ``storch.method.Expect("z")``. The same
    story is then run 100 times with ``method``; the per-run ``"probs"``
    gradients are gathered on a ``"gradients"`` plate and compared against the
    reference. The result is printed, nothing is returned.

    Args:
        method: Sampling method under evaluation.
        model: Model handed to ``generative_story``.
        data: Data handed to ``generative_story``.
        optimizer: Only used to zero the gradients before each backward pass.
    """
    # Reference ("expected") gradient
    optimizer.zero_grad()
    reference_z = generative_story(storch.method.Expect("z"), model, data)
    storch.backward()
    expected_gradient = reference_z.param_grads["probs"]

    # One gradient sample per run of the evaluated method
    sampled_gradients = []
    for _ in range(100):
        optimizer.zero_grad()
        sampled_z = generative_story(method, model, data)
        # The ELBO of the last run is the one that gets reported below.
        elbo = storch.backward()
        sampled_gradients.append(sampled_z.param_grads["probs"])

    stacked = storch.gather_samples(sampled_gradients, "gradients")
    mean_gradient = storch.reduce_plates(stacked, "gradients")
    squared_deviation = (mean_gradient - expected_gradient) ** 2
    bias_gradient = storch.reduce_plates(squared_deviation).sum()

    elbo_value = elbo.item()
    variance_value = storch.variance(stacked, "gradients")._tensor.item()
    bias_value = bias_gradient._tensor.item()
    report = "".join(
        [
            "Training ELBO ",
            str(elbo_value),
            " Gradient variance ",
            str(variance_value),
            " Gradient bias ",
            str(bias_value),
        ]
    )
    print(report)
Exemple #2
0
    def estimator(
        self, tensor: StochasticTensor, cost_node: CostTensor
    ) -> Optional[storch.Tensor]:
        """Backpropagate a baseline-corrected gradient and train the baseline ``c_phi``.

        For every distribution parameter that requires a gradient, the total
        derivative ``(cost - c_phi) * d log_prob + d c_phi`` is formed, reduced
        over the plate named ``tensor.name`` and pushed backwards from that
        parameter. The summed squared derivative is then differentiated with
        respect to the trainable parameters of ``self.c_phi`` and accumulated
        into their gradients.

        Args:
            tensor: The sampled stochastic tensor; it is fed into ``self.c_phi``.
            cost_node: The cost the gradient is estimated for.

        Returns:
            Always ``None``: all gradients are accumulated directly by the
            ``backward`` calls made here, so no surrogate loss is returned.
        """
        # Input rsampled value into c_phi
        output_baseline = self.c_phi(tensor)

        # Compute log probability. Make sure not to use the rsampled value: We want to compute the distribution
        # not through the sample but only through the distributional parameters.
        log_prob = tensor.distribution.log_prob(tensor.detach())
        log_prob = log_prob.sum(dim=tensor.event_dim_indices())

        # Compute the derivative with respect to the distributional parameters through the baseline.
        derivs = []
        for param in get_distr_parameters(
            tensor.distribution, filter_requires_grad=True
        ).values():
            # param.register_hook(hook)
            # create_graph=True keeps these derivatives differentiable, which the variance
            # gradient computed at the end of this method relies on.
            d_log_prob = storch.grad(
                [log_prob],
                [param],
                create_graph=True,
                grad_outputs=torch.ones_like(log_prob),
            )[0]
            d_output_baseline = storch.grad(
                [output_baseline],
                [param],
                create_graph=True,
                grad_outputs=torch.ones_like(output_baseline),
            )[0]
            derivs.append((param, d_log_prob, d_output_baseline))

        diff = cost_node - output_baseline  # [(...,) + (None,) * d_log_prob.event_dims]
        # NOTE(review): var_loss stays a plain float if no parameter requires a gradient,
        # in which case the `var_loss._tensor` access below would fail - confirm this cannot happen.
        var_loss = 0.0
        for param, d_log_prob, d_output_baseline in derivs:
            # Compute total derivative with respect to the parameter
            d_param = diff * d_log_prob + d_output_baseline
            # Reduce the plate of this sample in case multiple samples are taken
            d_param = storch.reduce_plates(d_param, plate_names=[tensor.name])
            # Compute backwards from the parameters using its total derivative
            if isinstance(param, storch.Tensor):
                param = param._tensor
            # retain_graph=True: the graph is needed again for the variance gradient below.
            param.backward(d_param._tensor, retain_graph=True)
            # Compute the gradient variance
            variance = (d_param ** 2).sum(d_param.event_dim_indices())
            var_loss += storch.reduce_plates(variance)

        # Trainable parameters of the baseline network
        c_phi_params = []

        for param in self.c_phi.parameters(recurse=True):
            if param.requires_grad:
                c_phi_params.append(param)
        # Differentiate the variance loss w.r.t. the baseline parameters and accumulate the
        # result into their .grad by calling backward on each parameter.
        d_variance = torch.autograd.grad([var_loss._tensor], c_phi_params)
        for i in range(len(c_phi_params)):
            c_phi_params[i].backward(d_variance[i])
        return None
Exemple #3
0
def eval(grads):
    """Print mean, MSE and bias of ``grads`` against the global ``expect_grad``.

    The gradient samples are gathered on a ``"variance"`` plate before the
    statistics are computed.

    Args:
        grads: List of gradient samples.

    Returns:
        The summed squared difference between the sample mean and
        ``expect_grad`` (printed as "bias").
    """

    def squared_error(estimate):
        # Element-wise squared distance to the reference gradient.
        return (estimate - expect_grad)**2

    print("----------------------------------")
    samples = storch.gather_samples(grads, "variance")
    sample_mean = storch.reduce_plates(samples, plates=["variance"])
    print("mean grad", sample_mean)
    print("expected grad", expect_grad)
    print("specific_diffs", squared_error(sample_mean))

    mean_squared_error = storch.reduce_plates(squared_error(samples)).sum()
    print("MSE", mean_squared_error)

    squared_bias = storch.reduce_plates(squared_error(sample_mean)).sum()
    print("bias", squared_bias)
    return squared_bias
Exemple #4
0
    def update_parameters(
            self, result_triples: [(StochasticTensor, CostTensor)]) -> None:
        """Accumulate variance-minimising gradients into the control variate parameters.

        The gradients that the regular backward pass left on
        ``self.control_params`` are discarded first. Then, for every distinct
        stochastic tensor in ``result_triples``, the summed squared gradient
        of its first distribution parameter is differentiated with respect to
        ``self.control_params`` and accumulated into their ``.grad``.

        Args:
            result_triples: Pairs whose first element is the stochastic tensor
                to process; the second element is ignored.
        """
        # During the normal backwards call, the parameters are accumulated gradients to the control variate parameters.
        # We don't want to minimize wrt to that loss, but to the one we define here.
        for param in self.control_params:
            param.grad = None

        # De-duplicate while keeping first-seen order.
        tensors = []
        for tensor, _ in result_triples:
            if tensor not in tensors:
                tensors.append(tensor)

        # minimize the variance of the gradient with respect to the input parameters
        for tensor in tensors:
            # Only the first entry of tensor.grad is considered.
            d_param = next(iter(tensor.grad.values()))
            # NOTE(review): event_dim_indices is read as an attribute here - confirm it is a
            # property (not a method) in the storch version this code targets.
            variance = (d_param**2).sum(d_param.event_dim_indices)
            var_loss = storch.reduce_plates(variance)

            # retain_graph=True: the same graph may be differentiated again for the next tensor.
            d_variance = torch.autograd.grad(
                [var_loss._tensor],
                self.control_params,
                retain_graph=True,
            )

            # Calling backward on a parameter with an explicit gradient accumulates it into .grad.
            for control_param, grad in zip(self.control_params, d_variance):
                control_param.backward(grad)
Exemple #5
0
def estimate_variance(method):
    """Print variance, mean and standard deviation of 1000 gradient samples.

    Each sample is obtained by calling ``compute_f(method)``, registering its
    first return value as cost ``"f"``, running ``storch.backward()`` and
    reading ``.grad`` of the second return value. Some diagnostics about the
    gathered tensor are printed as well.

    Args:
        method: Passed through to ``compute_f``.
    """
    plate_name = "gradients"

    collected = []
    for _ in range(1000):
        cost, param = compute_f(method)
        storch.add_cost(cost, "f")
        storch.backward()
        collected.append(param.grad)

    gradients = storch.gather_samples(collected, plate_name)
    print("variance", storch.variance(gradients, plate_name))
    print("mean", storch.reduce_plates(gradients, plate_name))
    print("st dev", torch.sqrt(storch.variance(gradients, plate_name)))

    # Diagnostics on the gathered sample tensor
    print(type(gradients))
    print(gradients.shape)
    print(gradients.plates)
Exemple #6
0
    def update_parameters(
            self, result_triples: [(StochasticTensor, CostTensor)]) -> None:
        """Accumulate variance-minimising gradients into the control variate parameters.

        The gradients that the regular backward pass left on
        ``self.control_params`` are discarded first. Then, for every distinct
        stochastic tensor in ``result_triples``, the summed squared gradient
        of its ``'probs'`` parameter is differentiated with respect to
        ``self.control_params`` and accumulated into their ``.grad``.
        Control parameters that do not take part in the variance loss are
        left untouched.

        Args:
            result_triples: Sequences whose first element is the stochastic
                tensor to process; remaining elements are ignored.
        """
        # During the normal backwards call, the parameters are accumulated gradients to the control variate parameters.
        # We don't want to minimize wrt to that loss, but to the one we define here.
        for param in self.control_params:
            param.grad = None
        # Find grad of the distribution, then compute variance.
        # TODO: check if the set statement properly filters different results here
        tensors = list(set([result[0] for result in result_triples]))
        # minimize the variance of the gradient with respect to the input parameters
        for tensor in tensors:
            # TODO: We have to select the probs of the distribution here as that's what it flows to. Is this always correct?
            d_param = tensor.grad['probs']
            variance = (d_param**2).sum(d_param.event_dim_indices)
            var_loss = storch.reduce_plates(variance)

            d_variance = torch.autograd.grad([var_loss._tensor],
                                             self.control_params,
                                             retain_graph=True,
                                             allow_unused=True)
            for control_param, grad in zip(self.control_params, d_variance):
                # allow_unused=True makes autograd return None for parameters that are not
                # part of the graph. Passing None to backward() would raise for non-scalar
                # parameters and silently use a gradient of 1 for scalar ones, so skip them.
                if grad is None:
                    continue
                control_param.backward(grad)
Exemple #7
0
def backward(
    retain_graph: bool = False, debug: bool = False, print_costs: bool = False
) -> torch.Tensor:
    """
    Computes the gradients of the cost nodes with respect to the parameter nodes. It uses the storch
    methods used to sample stochastic nodes to properly estimate their gradient.

    Args:
        retain_graph (bool): Forwarded to the backward call on the accumulated loss. If set to False, the
        accumulated loss is cleaned up and ``reset()`` is called afterwards, which (per the original
        documentation) deregisters the added cost nodes. Should usually be set to False.
        debug (bool): Prints the graph of the registered cost nodes before the backwards call.
        print_costs (bool): Prints the name and the plate-reduced value of every cost node.
    Returns:
        torch.Tensor: The sum of the plate-reduced cost nodes.
    Raises:
        RuntimeError: If no cost nodes are registered.
    """

    costs: [storch.Tensor] = storch.inference._cost_tensors
    if not costs:
        raise RuntimeError("No cost nodes registered for backward call.")
    if debug:
        print_graph(costs)

    # Sum of averages of cost node tensors
    total_cost = 0.0
    # Sum of losses that can be backpropagated through in keepgrads without difficult iterations
    accum_loss = 0.0

    stochastic_nodes = set()
    # Loop over different cost nodes
    for c in costs:
        # Do not detach the weights when reducing. This is used in for example expectations to weight the
        # different costs.
        reduced_cost = storch.reduce_plates(c, detach_weights=False)

        if print_costs:
            print(c.name, ":", reduced_cost._tensor.item())
        total_cost += reduced_cost
        # Compute gradients for the cost nodes themselves, if they require one.
        if reduced_cost.requires_grad:
            accum_loss += reduced_cost
        for parent in c.walk_parents(depth_first=False):
            # Instance check here instead of parent.stochastic, as backward methods are only used on these.
            if isinstance(parent, StochasticTensor):
                stochastic_nodes.add(parent)
            else:
                continue
            if (
                not parent.requires_grad
                or not parent.method
                or not parent.method.adds_loss(parent, c)
            ):
                continue

            # Transpose the parent stochastic tensor, so that its shape is the same as the cost but the event shape, and
            # possibly extra dimensions...?
            parent_tensor = parent._tensor
            reduced_cost = c
            parent_plates = parent.multi_dim_plates()
            # Reduce all plates that are in the cost node but not in the parent node
            for plate in storch.order_plates(c.multi_dim_plates(), reverse=True):
                if plate not in parent_plates:
                    reduced_cost = plate.reduce(reduced_cost, detach_weights=True)
            # Align the parent tensor so that the plate dimensions are in the same order as the cost tensor
            for index_c, plate in enumerate(reduced_cost.multi_dim_plates()):
                index_p = parent_plates.index(plate)
                if index_c != index_p:
                    parent_tensor = parent_tensor.transpose(index_p, index_c)
                    parent_plates[index_p], parent_plates[index_c] = (
                        parent_plates[index_c],
                        parent_plates[index_p],
                    )
            # Add empty (k=1) plates to new parent
            for plate in parent.plates:
                if plate not in parent_plates:
                    parent_plates.append(plate)

            # Create new storch Tensors with different order of plates for the cost and parent
            new_parent = storch.tensor.StochasticTensor(
                parent_tensor,
                [],
                parent_plates,
                parent.name,
                parent.n,
                parent.distribution,
                parent._requires_grad,
                parent.method,
            )
            # Fake the new parent to be the old parent within the graph by mimicing its place in the graph
            new_parent._parents = parent._parents
            for p, has_link in new_parent._parents:
                p._children.append((new_parent, has_link))
            new_parent._children = parent._children
            cost_per_sample = parent.method._estimator(new_parent, reduced_cost)

            if cost_per_sample is not None:
                # The backwards call for reparameterization happens in the
                # backwards call for the costs themselves.
                # Now mean_cost has the same shape as parent.batch_shape
                final_reduced_cost = storch.reduce_plates(
                    cost_per_sample, detach_weights=True
                )
                if final_reduced_cost.ndim == 1:
                    final_reduced_cost = final_reduced_cost.squeeze(0)
                accum_loss += final_reduced_cost

    # accum_loss is still the float 0.0 when neither a cost nor an estimator contributed a loss.
    if isinstance(accum_loss, storch.Tensor) and accum_loss._tensor.requires_grad:
        accum_loss._tensor.backward(retain_graph=retain_graph)

    for s_node in stochastic_nodes:
        if s_node.method:
            s_node.method._update_parameters()

    if not retain_graph:
        # Only a storch.Tensor has something to clean; a plain float would raise
        # AttributeError here and prevent reset() from running.
        if isinstance(accum_loss, storch.Tensor):
            accum_loss._clean()
        reset()

    # TODO: How much does accum_loss really say? Should we really keep it? We want to minimize total_cost, anyways.
    return total_cost._tensor  # , accum_loss._tensor
Exemple #8
0
def surrogate_loss(debug: bool = False) -> storch.Tensor:
    """Build one surrogate loss from all registered cost nodes.

    For every cost node the stochastic parents that require a gradient, have a
    method and are not handled pathwise are visited. Each parent's estimator
    returns a gradient function, which is accumulated into ``L``, and
    optionally a control variate, which is weighted by ``magic_box(L)`` and
    added to the surrogate loss of that cost. Finally ``magic_box(L) * cost``
    is added, and the per-cost surrogate losses are summed.

    Args:
        debug: If True, prints the graph of the registered cost nodes first.

    Returns:
        The sum over all per-cost surrogate losses.

    Raises:
        RuntimeError: If no cost nodes are registered.
    """
    costs: [storch.Tensor] = storch.inference._cost_tensors
    if not costs:
        raise RuntimeError("No cost nodes registered for backward call.")
    if debug:
        print_graph(costs)

    # Sum of averages of cost node tensors
    surrogate_losses = []

    # Loop over different cost nodes
    for c in costs:
        # Do not detach the weights when reducing. This is used in for example expectations to weight the
        # different costs.
        # reduced_cost = storch.reduce_plates(c, detach_weights=False)
        #
        # if print_costs:
        #     print(c.name, ":", reduced_cost._tensor.item())
        # total_cost += reduced_cost
        # Compute gradients for the cost nodes themselves, if they require one.
        # if reduced_cost.requires_grad:
        #     accum_loss += reduced_cost

        # L accumulates the gradient functions of the parents visited so far; it starts as a
        # zero scalar created from the cost's underlying tensor.
        L = c._tensor.new_tensor(0.0)
        surrogate_loss_c = 0.0
        # Walk topologically through the graph
        # This is a parallelized implementation of Algorithm 1 in the paper
        for parent in c.walk_parents(depth_first=False, reverse=True):
            # Instance check here instead of parent.stochastic, as backward methods are only used on these.
            if not isinstance(parent, StochasticTensor):
                continue
            if not parent.requires_grad or not parent.method:
                continue

            # Pathwise parents are skipped here and contribute no gradient function.
            if parent.method.is_pathwise(parent, c):
                continue
            # Transpose the parent stochastic tensor, so that its shape is the same as the cost but the event shape, and
            # possibly extra dimensions...?
            parent_tensor = parent._tensor
            reduced_cost = c
            parent_plates = parent.multi_dim_plates()
            # Reduce all plates that are in the cost node but not in the parent node
            for plate in storch.order_plates(c.multi_dim_plates(),
                                             reverse=True):
                if plate not in parent_plates:
                    reduced_cost = plate.reduce(reduced_cost,
                                                detach_weights=True)
            # Align the parent tensor so that the plate dimensions are in the same order as the cost tensor
            # TODO: This can probably be implemented with torch.movedim
            for index_c, plate in enumerate(reduced_cost.multi_dim_plates()):
                index_p = parent_plates.index(plate)
                if index_c != index_p:
                    parent_tensor = parent_tensor.transpose(index_p, index_c)
                    parent_plates[index_p], parent_plates[index_c] = (
                        parent_plates[index_c],
                        parent_plates[index_p],
                    )
            # Add empty (k=1) plates to new parent
            for plate in parent.plates:
                if plate not in parent_plates:
                    parent_plates.append(plate)

            # Create new storch Tensors with different order of plates for the cost and parent
            new_parent = storch.tensor.StochasticTensor(
                parent_tensor,
                [],
                parent_plates,
                parent.name,
                parent.n,
                parent.distribution,
                parent._requires_grad,
                parent.method,
            )
            # Share the gradient bookkeeping of the original parent with the stand-in.
            new_parent.param_grads = parent.param_grads
            # Fake the new parent to be the old parent within the graph by mimicking its place in the graph
            new_parent._parents = parent._parents
            for p, has_link in new_parent._parents:
                p._children.append((new_parent, has_link))
            new_parent._children = parent._children

            # Compute the estimator
            (
                gradient_function,
                control_variate,
            ) = parent.method._estimator(new_parent, reduced_cost)

            if gradient_function is not None:
                L = L + gradient_function
            # Compute control variate
            if control_variate is not None:
                # magic_box is defined outside this block; presumably the DiCE "magic box"
                # operator - TODO confirm.
                final_A = magic_box(L) * control_variate
                final_A = storch.reduce_plates(
                    final_A,
                    detach_weights=
                    False,  # TODO: Should this boolean be false or true?
                )
                if final_A.ndim == 1:
                    final_A = final_A.squeeze(0)
                surrogate_loss_c += final_A
        # Use magic box to distribute the cost to gradient function
        surrogate_loss_c += storch.reduce_plates(magic_box(L) * c,
                                                 detach_weights=False)
        # Collect surrogate losses for all costs
        surrogate_losses.append(surrogate_loss_c)
    # Stack the per-cost losses and sum them into one value.
    SL = torch.sum(torch.stack(surrogate_losses))
    return SL
Exemple #9
0
    def estimator(
        self, tensor: StochasticTensor, cost_node: CostTensor
    ) -> Optional[storch.Tensor]:
        """Backpropagate a control-variate gradient estimate and tune the control variate.

        A hard, a relaxed and a conditional sample are obtained, either by
        splitting ``tensor`` and ``cost_node`` into three slices along the
        sample's own plate (``self.rebar`` set) or by deriving the hard and
        conditional samples from ``tensor``. The total derivative with
        respect to the distribution parameter is
        ``(hard_cost - eta * c_phi_cond) * d log_prob
        + eta * (d c_phi_relaxed - d c_phi_cond)``; it is reduced over the
        plate named ``tensor.name`` and pushed backwards from the parameter.
        The summed squared derivative is then differentiated with respect to
        ``self.temperature`` and, if ``self.c_phi`` is a ``torch.nn.Module``,
        its trainable parameters, and accumulated into their gradients.

        Args:
            tensor: The sampled stochastic tensor.
            cost_node: The cost the gradient is estimated for.

        Returns:
            Always ``None``: all gradients are accumulated directly by the
            ``backward`` calls made here.
        """
        plate = tensor.get_plate(tensor.name)
        if self.rebar:
            # The three kinds of samples (and their costs) are stored together along the
            # sample plate and are separated again here.
            hard_sample, relaxed_sample, cond_sample = split(
                tensor, plate, amt_slices=3
            )
            hard_cost, relaxed_cost, cond_cost = split(cost_node, plate, amt_slices=3)

        else:
            hard_sample = discretize(tensor, tensor.distribution)
            relaxed_sample = tensor
            cond_sample = conditional_gumbel_rsample(
                hard_sample, tensor.distribution, self.temperature
            )

            hard_cost = cost_node
            # Without rebar only c_phi itself acts as control variate, so no cost is added to it.
            relaxed_cost = 0.0
            cond_cost = 0.0

        # Input rsampled values into c_phi
        c_phi_relaxed = self.c_phi(relaxed_sample) + relaxed_cost
        c_phi_cond = self.c_phi(cond_sample) + cond_cost

        # Compute log probability of hard sample
        log_prob = tensor.distribution.log_prob(hard_sample)
        # Sum over every dimension that comes after the plate dimensions of the hard sample.
        log_prob = log_prob.sum(
            dim=list(range(hard_sample.plate_dims, len(log_prob.shape)))
        )

        # Compute the derivative with respect to the distributional parameters through the baseline.
        param = tensor.distribution._param
        # TODO: It should either propagate over the logits or over the probs. Can we know which one is the parameter and
        # which one is computed dynamically?
        # param.register_hook(hook)
        # TODO: Can these be collected in a single torch.autograd.grad call?
        # create_graph=True keeps the derivatives differentiable for the variance gradient below.
        d_log_prob = storch.grad(
            [log_prob],
            [param],
            create_graph=True,
            grad_outputs=torch.ones_like(log_prob),
        )[0]
        d_c_phi_relaxed = storch.grad(
            [c_phi_relaxed],
            [param],
            create_graph=True,
            grad_outputs=torch.ones_like(c_phi_relaxed),
        )[0]
        d_c_phi_cond = storch.grad(
            [c_phi_cond],
            [param],
            create_graph=True,
            grad_outputs=torch.ones_like(c_phi_cond),
        )[0]

        diff = hard_cost - self.eta * c_phi_cond
        # Compute total derivative with respect to the parameter
        d_param = diff * d_log_prob + self.eta * (d_c_phi_relaxed - d_c_phi_cond)
        # Reduce the plate of this sample in case multiple samples are taken
        d_param = storch.reduce_plates(d_param, plate_names=[tensor.name])
        # Compute backwards from the parameters using its total derivative
        # retain_graph=True: the graph is needed again for the variance gradient below.
        if isinstance(param, storch.Tensor):
            param._tensor.backward(d_param._tensor, retain_graph=True)
        else:
            param.backward(d_param._tensor, retain_graph=True)
        # Compute the gradient variance
        variance = (d_param ** 2).sum(d_param.event_dim_indices())
        var_loss = storch.reduce_plates(variance)

        # Minimize variance over the parameters of c_phi and the temperature (should it also minimize eta?)
        c_phi_params = [self.temperature]
        if isinstance(self.c_phi, torch.nn.Module):
            for c_phi_param in self.c_phi.parameters(recurse=True):
                if c_phi_param.requires_grad:
                    c_phi_params.append(c_phi_param)

        d_variance = torch.autograd.grad(
            [var_loss._tensor], c_phi_params, create_graph=self.rebar
        )

        # Accumulate the variance gradients into .grad of each tuned parameter.
        for i in range(len(c_phi_params)):
            c_phi_params[i].backward(d_variance[i])
        return None
Exemple #10
0
 def compute_baseline(self, tensor: StochasticTensor,
                      cost_node: CostTensor) -> torch.Tensor:
     """Update the exponential moving average of the cost and return it.

     The plate-reduced, detached cost is blended into ``self.moving_average``
     using ``self.exponential_decay`` as the weight of the previous value.
     ``tensor`` is not used.
     """
     batch_mean = storch.reduce_plates(cost_node).detach()
     decay = self.exponential_decay
     blended = decay * self.moving_average + (1 - decay) * batch_mean
     # Store the raw tensor rather than the storch wrapper.
     self.moving_average = blended._tensor
     return self.moving_average
Exemple #11
0
 def reduce(self, unique_tensor: storch.Tensor, detach_weights=True):
     """Undo the unique-ing of ``unique_tensor`` and reduce it over ``self.shrunken_plates``.

     ``detach_weights`` is passed through to ``storch.reduce_plates``.
     """
     expanded = self.undo_unique(unique_tensor)
     target_plates = self.shrunken_plates
     return storch.reduce_plates(
         expanded,
         plates=target_plates,
         detach_weights=detach_weights,
     )
Exemple #12
0
def train(epoch, model, train_loader, device, optimizer, args, writer):
    """Train ``model`` for one epoch and periodically log gradient statistics.

    Every batch is flattened to 784 features, marked as an independent
    ``"data"`` dimension, passed through the model, and optimised on the
    ``"reconstruction"`` cost. Every ``args.log_interval`` batches the cost is
    printed and written to ``writer``; unless ``args.method`` is ``"expect"``
    or ``args.variance_samples`` is at most 1, the gradient variance of the
    latent ``z`` is estimated from ``args.variance_samples`` extra
    forward/backward passes. When ``args.latents < 3`` the bias and MSE of the
    ``"probs"`` gradient against an ``Expect("z")`` reference are logged too.

    Args:
        epoch: Epoch number, used for logging and the global step.
        model: Model to train; returns ``(recon_batch, KLD, z)``.
        train_loader: Iterable of ``(data, _)`` batches.
        device: Device the data is moved to.
        optimizer: Optimizer that is stepped once per batch.
        args: Reads ``log_interval``, ``method``, ``variance_samples`` and ``latents``.
        writer: Object with an ``add_scalar`` method used for logging.

    Returns:
        The summed batch cost divided by the number of batches.
    """
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        storch.reset()

        # Denote the minibatch dimension as being independent
        data = storch.denote_independent(data.view(-1, 784), 0, "data")
        recon_batch, KLD, z = model(data)
        storch.add_cost(loss_function(recon_batch, data), "reconstruction")
        cost = backward()
        train_loss += cost.item()

        optimizer.step()

        cond_log = batch_idx % args.log_interval == 0

        if cond_log:
            # Progress through the epoch in percent, and a step counter that grows by 100 per epoch.
            step = 100.0 * batch_idx / len(train_loader)
            global_step = 100 * (epoch - 1) + step

            # Variance of expect method is 0 by definition.
            variances = {}
            if args.method != "expect" and args.variance_samples > 1:
                _consider_param = "probs"
                if args.latents < 3:
                    # Temporarily switch to the Expect method to obtain a reference gradient.
                    old_method = model.sampling_method
                    model.sampling_method = Expect("z")
                    optimizer.zero_grad()
                    recon_batch, _, z = model(data)
                    storch.add_cost(loss_function(recon_batch, data),
                                    "reconstruction")
                    backward()
                    expect_grad = storch.reduce_plates(
                        z.grad[_consider_param]).detach_tensor()

                    optimizer.zero_grad()
                    model.sampling_method = old_method
                # One list of gradient samples per parameter name found in z.grad
                grads = {n: [] for n in z.grad}

                for i in range(args.variance_samples):
                    optimizer.zero_grad()
                    recon_batch, _, z = model(data)
                    storch.add_cost(loss_function(recon_batch, data),
                                    "reconstruction")
                    backward()

                    for param_name, grad in z.grad.items():
                        # Make sure to reduce the data dimension and detach, for memory reasons.
                        grads[param_name].append(
                            storch.reduce_plates(grad).detach_tensor())

                variances = {}
                for param_name, gradz in grads.items():
                    # Create a new independent dimension for the different gradient samples
                    grad_samples = storch.gather_samples(gradz, "variance")
                    # Compute the variance over this independent dimension
                    variances[param_name] = storch.variance(
                        grad_samples, "variance")._tensor
                    if param_name == _consider_param and args.latents < 3:
                        mean = storch.reduce_plates(grad_samples, "variance")
                        mse = storch.reduce_plates(
                            (grad_samples - expect_grad)**2).sum()
                        bias = (storch.reduce_plates(
                            (mean - expect_grad)**2)).sum()
                        print("mse", mse._tensor.item())
                        # Should approach 0 when increasing variance_samples for unbiased estimators.
                        print("bias", bias._tensor.item())
                        writer.add_scalar("train/probs_bias", bias._tensor,
                                          global_step)
                        writer.add_scalar("train/probs_mse", mse._tensor,
                                          global_step)

            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tCost: {:.6f}\t Logits var {}"
                .format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    step,
                    cost.item(),
                    variances,
                ))
            writer.add_scalar("train/ELBO", cost, global_step)
            for param_name, var in variances.items():
                writer.add_scalar("train/variance/" + param_name, var,
                                  global_step)
    # NOTE(review): batch_idx is unbound if train_loader yields no batches - confirm the
    # loader is never empty.
    avg_train_loss = train_loss / (batch_idx + 1)
    print("====> Epoch: {} Average loss: {:.4f}".format(epoch, avg_train_loss))
    return avg_train_loss
Exemple #13
0
# Script fragment: `c`, `z`, `probs` and `indices` are defined before this excerpt.
# Register the cost built above and backpropagate it to obtain a reference gradient.
storch.add_cost(c, "no_baseline_cost")

storch.backward()

# Copy the gradient so later backward passes cannot modify the reference.
expect_grad = z.grad["probs"].clone()

# Estimator under test
method = storch.method.UnorderedSetEstimator("x", k=6)
# method = storch.REBAR()
grads = []
for i in range(100):
    b = OneHotCategorical(probs=probs)
    z = method.sample(b)
    # Same kind of cost, shifted by a constant (+100)
    c = (2.4 * z * indices).sum(-1) + 100
    storch.add_cost(c, "baseline_cost")

    storch.backward()
    grad = z.grad["probs"].clone()
    grads.append(grad)
# Put the 100 gradient samples on their own "variance" plate and compare with the reference.
grad_samples = storch.gather_samples(grads, "variance")
mean = storch.reduce_plates(grad_samples, plate_names=["variance"])
print("mean grad", mean)
print("expected grad", expect_grad)
print("specific_diffs", (mean - expect_grad) ** 2)
mse = storch.reduce_plates((grad_samples - expect_grad) ** 2).sum()
print("MSE", mse)
bias = (storch.reduce_plates((mean - expect_grad) ** 2)).sum()
print("bias", bias)


# That works... Adding constants to the costs doesn't change the gradient in expectation.