def evaluate(
    method: storch.method.Method,
    model: DiscreteVAE,
    data,
    optimizer,
    n_samples: int = 100,
):
    """Print the training cost, gradient variance and gradient bias of ``method``.

    A reference gradient w.r.t. the ``"probs"`` parameter of ``z`` is computed
    with ``storch.method.Expect("z")``. Then ``n_samples`` gradient samples are
    drawn with ``method``, and their variance and squared bias against the
    reference are printed.

    Args:
        method: The storch estimator whose gradients are evaluated.
        model: Model passed through to ``generative_story``.
        data: Batch passed through to ``generative_story``.
        optimizer: Only used to zero gradients between backward passes.
        n_samples: Number of gradient samples drawn with ``method``. Defaults
            to 100, the previously hard-coded value.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1. In that case no
            backward call with ``method`` would happen, and the reported
            cost would be undefined.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    # Compute expected gradient
    optimizer.zero_grad()
    z = generative_story(storch.method.Expect("z"), model, data)
    storch.backward()
    expected_gradient = z.param_grads["probs"]

    # Collect gradient samples
    gradients = []
    for _ in range(n_samples):
        optimizer.zero_grad()
        z = generative_story(method, model, data)
        # Only the value returned by the last backward call is reported below.
        elbo = storch.backward()
        gradients.append(z.param_grads["probs"])

    # Stack the samples along a new "gradients" plate, then reduce over it.
    gradients = storch.gather_samples(gradients, "gradients")
    mean_gradient = storch.reduce_plates(gradients, "gradients")
    bias_gradient = (
        storch.reduce_plates((mean_gradient - expected_gradient) ** 2)
    ).sum()
    print(
        "Training ELBO "
        + str(elbo.item())
        + " Gradient variance "
        + str(storch.variance(gradients, "gradients")._tensor.item())
        + " Gradient bias "
        + str(bias_gradient._tensor.item())
    )
def estimator(
    self, tensor: StochasticTensor, cost_node: CostTensor
) -> Optional[storch.Tensor]:
    """Apply a learned-baseline score-function gradient and train ``c_phi``.

    For every distributional parameter that requires grad, the total
    derivative ``(cost - c_phi(sample)) * dlogp/dparam + dc_phi/dparam`` is
    pushed into the parameter with ``param.backward``. The summed squared
    derivative (the gradient variance proxy) is then differentiated w.r.t. the
    trainable parameters of ``self.c_phi``, and those gradients are
    accumulated with ``backward``.

    Returns:
        Always ``None``. All gradients are applied directly inside this method.
    """
    # Input rsampled value into c_phi
    output_baseline = self.c_phi(tensor)

    # Compute log probability. Do not use the rsampled value: the log
    # probability should depend only on the distributional parameters, not
    # on the sample.
    log_prob = tensor.distribution.log_prob(tensor.detach())
    log_prob = log_prob.sum(dim=tensor.event_dim_indices())

    # Compute the derivative with respect to the distributional parameters through the baseline.
    derivs = []
    for param in get_distr_parameters(
        tensor.distribution, filter_requires_grad=True
    ).values():
        d_log_prob = storch.grad(
            [log_prob],
            [param],
            create_graph=True,
            grad_outputs=torch.ones_like(log_prob),
        )[0]
        d_output_baseline = storch.grad(
            [output_baseline],
            [param],
            create_graph=True,
            grad_outputs=torch.ones_like(output_baseline),
        )[0]
        derivs.append((param, d_log_prob, d_output_baseline))

    if not derivs:
        # No parameter requires grad: there is nothing to estimate. Without
        # this guard var_loss would stay the float 0.0, and var_loss._tensor
        # below would raise AttributeError.
        return None

    # NOTE(review): an earlier hint suggested diff may need
    # [(...,) + (None,) * d_log_prob.event_dims] to broadcast against
    # d_log_prob's event dims. TODO confirm shapes.
    diff = cost_node - output_baseline
    var_loss = 0.0
    for param, d_log_prob, d_output_baseline in derivs:
        # Compute total derivative with respect to the parameter
        d_param = diff * d_log_prob + d_output_baseline
        # Reduce the plate of this sample in case multiple samples are taken
        d_param = storch.reduce_plates(d_param, plate_names=[tensor.name])
        # Compute backwards from the parameters using its total derivative
        if isinstance(param, storch.Tensor):
            param = param._tensor
        param.backward(d_param._tensor, retain_graph=True)
        # Compute the gradient variance
        variance = (d_param ** 2).sum(d_param.event_dim_indices())
        var_loss += storch.reduce_plates(variance)

    c_phi_params = [
        c_phi_param
        for c_phi_param in self.c_phi.parameters(recurse=True)
        if c_phi_param.requires_grad
    ]
    if not c_phi_params:
        # torch.autograd.grad rejects an empty list of inputs.
        return None
    # Minimise the gradient variance w.r.t. the baseline's own parameters.
    d_variance = torch.autograd.grad([var_loss._tensor], c_phi_params)
    for c_phi_param, grad in zip(c_phi_params, d_variance):
        c_phi_param.backward(grad)
    return None
def eval(grads):
    """Print and return gradient-sample statistics against ``expect_grad``.

    ``expect_grad`` is read from the enclosing (module-level) scope; it is not
    a parameter. The samples in ``grads`` are stacked along a new
    ``"variance"`` plate. The function prints the sample mean, the expected
    gradient, the per-element squared difference, the MSE and the bias.

    Returns:
        The summed squared difference between the sample mean and
        ``expect_grad`` (the bias).
    """
    print("----------------------------------")
    samples = storch.gather_samples(grads, "variance")
    sample_mean = storch.reduce_plates(samples, plates=["variance"])
    print("mean grad", sample_mean)
    print("expected grad", expect_grad)

    squared_diff = (sample_mean - expect_grad) ** 2
    print("specific_diffs", squared_diff)

    squared_errors = (samples - expect_grad) ** 2
    mse = storch.reduce_plates(squared_errors).sum()
    print("MSE", mse)

    bias = storch.reduce_plates(squared_diff).sum()
    print("bias", bias)
    return bias
def update_parameters(
    self, result_triples: [(StochasticTensor, CostTensor)]
) -> None:
    """Train the control-variate parameters to minimise gradient variance.

    Any gradient the regular backward pass accumulated on
    ``self.control_params`` is discarded first. Then, for every distinct
    stochastic tensor in ``result_triples``, the summed squared gradient is
    differentiated w.r.t. ``self.control_params``, and the result is
    accumulated into their ``.grad``.

    Args:
        result_triples: ``(stochastic_tensor, cost)`` pairs. Only the
            stochastic tensor is used.
    """
    # During the normal backwards call, gradients are accumulated on the
    # control variate parameters. We don't want to minimize that loss, but
    # the one we define here.
    for param in self.control_params:
        param.grad = None

    # Deduplicate by identity while keeping first-seen order. The previous
    # ``tensor not in tensors`` check falls back to ``==`` for distinct
    # tensors, which may dispatch to element-wise tensor equality rather than
    # an identity check.
    tensors = list({id(tensor): tensor for tensor, _ in result_triples}.values())

    # minimize the variance of the gradient with respect to the input parameters
    for tensor in tensors:
        # Takes whichever gradient is stored first in tensor.grad.
        # NOTE(review): presumably the distribution has a single relevant
        # parameter. TODO confirm.
        d_param = next(iter(tensor.grad.values()))
        variance = (d_param**2).sum(d_param.event_dim_indices)
        var_loss = storch.reduce_plates(variance)
        d_variance = torch.autograd.grad(
            [var_loss._tensor],
            self.control_params,
            retain_graph=True,
        )
        # Accumulate d(variance)/d(param) into each control parameter's .grad.
        for param, grad in zip(self.control_params, d_variance):
            param.backward(grad)
def estimate_variance(method, n_samples: int = 1000):
    """Print gradient statistics of ``c.grad`` over repeated samples of ``method``.

    Each iteration calls ``compute_f(method)``, registers ``f`` as the cost
    ``"f"``, runs ``storch.backward()`` and records ``c.grad``. The samples
    are stacked along a ``"gradients"`` plate. The function prints their
    variance, mean, standard deviation, type, shape and plates.

    Args:
        method: Sampling method forwarded to ``compute_f``.
        n_samples: Number of gradient samples. Defaults to 1000, the
            previously hard-coded value.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1 (nothing to gather).
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    gradient_samples = []
    for _ in range(n_samples):
        f, c = compute_f(method)
        storch.add_cost(f, "f")
        storch.backward()
        # NOTE(review): if compute_f returns the same leaf tensor on every
        # call, c.grad accumulates across iterations (and may be the same
        # object each time). TODO confirm compute_f creates a fresh c.
        gradient_samples.append(c.grad)

    gradients = storch.gather_samples(gradient_samples, "gradients")
    # Computed once and reused for the standard deviation below.
    variance = storch.variance(gradients, "gradients")
    print("variance", variance)
    print("mean", storch.reduce_plates(gradients, "gradients"))
    print("st dev", torch.sqrt(variance))
    print(type(gradients))
    print(gradients.shape)
    print(gradients.plates)
def update_parameters(
    self, result_triples: [(StochasticTensor, CostTensor)]
) -> None:
    """Train the control-variate parameters to minimise gradient variance.

    Any gradient the regular backward pass accumulated on
    ``self.control_params`` is discarded first. Then, for every distinct
    stochastic tensor (the first element of each entry in ``result_triples``),
    the summed squared gradient w.r.t. its ``"probs"`` parameter is
    differentiated w.r.t. ``self.control_params``. The result is accumulated
    into their ``.grad``. Control parameters that the variance does not
    depend on are left untouched.
    """
    # During the normal backwards call, gradients are accumulated on the
    # control variate parameters. We don't want to minimize that loss, but
    # the one we define here.
    for param in self.control_params:
        param.grad = None

    # Deduplicate the stochastic tensors by identity, keeping first-seen order.
    # A set gives no deterministic order.
    tensors = list({id(result[0]): result[0] for result in result_triples}.values())

    # minimize the variance of the gradient with respect to the input parameters
    for tensor in tensors:
        # TODO: We have to select the probs of the distribution here as that's what it flows to. Is this always correct?
        d_param = tensor.grad['probs']
        variance = (d_param**2).sum(d_param.event_dim_indices)
        var_loss = storch.reduce_plates(variance)
        d_variance = torch.autograd.grad(
            [var_loss._tensor],
            self.control_params,
            retain_graph=True,
            allow_unused=True,
        )
        for param, grad in zip(self.control_params, d_variance):
            # allow_unused=True returns None for parameters var_loss does not
            # depend on. param.backward(None) would raise for non-scalar params
            # and inject an implicit gradient of 1.0 for scalar ones, so skip them.
            if grad is not None:
                param.backward(grad)
def backward(
    retain_graph: bool = False, debug: bool = False, print_costs: bool = False
) -> torch.Tensor:
    """
    Computes the gradients of the cost nodes with respect to the parameter nodes. It uses the storch
    methods used to sample stochastic nodes to properly estimate their gradient.

    Args:
        retain_graph (bool): Forwarded to the underlying torch ``backward`` call on the
            accumulated loss. If False, the accumulated loss is cleaned up and ``reset()`` is
            called afterwards. That call is meant to deregister the added cost nodes. Should
            usually be set to False.
        debug (bool): Prints the graph of the registered cost nodes before computing gradients.
        print_costs (bool): Prints the name and plate-reduced value of every cost node.

    Returns:
        torch.Tensor: The sum over all cost nodes of each cost reduced over its plates, with the
        sampling weights not detached.

    Raises:
        RuntimeError: If no cost nodes are registered.
    """
    costs: [storch.Tensor] = storch.inference._cost_tensors
    if not costs:
        raise RuntimeError("No cost nodes registered for backward call.")
    if debug:
        print_graph(costs)

    # Sum of averages of cost node tensors
    total_cost = 0.0
    # Sum of losses that can be backpropagated through in keepgrads without difficult iterations
    accum_loss = 0.0

    stochastic_nodes = set()
    # Loop over different cost nodes
    for c in costs:
        # Do not detach the weights when reducing. This is used in for example expectations to weight the
        # different costs.
        reduced_cost = storch.reduce_plates(c, detach_weights=False)

        if print_costs:
            print(c.name, ":", reduced_cost._tensor.item())
        total_cost += reduced_cost
        # Compute gradients for the cost nodes themselves, if they require one.
        if reduced_cost.requires_grad:
            accum_loss += reduced_cost
        for parent in c.walk_parents(depth_first=False):
            # Instance check here instead of parent.stochastic, as backward methods are only used on these.
            if isinstance(parent, StochasticTensor):
                stochastic_nodes.add(parent)
            else:
                continue
            if (
                not parent.requires_grad
                or not parent.method
                or not parent.method.adds_loss(parent, c)
            ):
                continue

            # Transpose the parent stochastic tensor so that its plate dimensions line up with the
            # cost's plate dimensions.
            parent_tensor = parent._tensor
            reduced_cost = c
            parent_plates = parent.multi_dim_plates()
            # Reduce all plates that are in the cost node but not in the parent node
            for plate in storch.order_plates(c.multi_dim_plates(), reverse=True):
                if plate not in parent_plates:
                    reduced_cost = plate.reduce(reduced_cost, detach_weights=True)
            # Align the parent tensor so that the plate dimensions are in the same order as the cost tensor
            for index_c, plate in enumerate(reduced_cost.multi_dim_plates()):
                index_p = parent_plates.index(plate)
                if index_c != index_p:
                    parent_tensor = parent_tensor.transpose(index_p, index_c)
                    parent_plates[index_p], parent_plates[index_c] = (
                        parent_plates[index_c],
                        parent_plates[index_p],
                    )
            # Add empty (k=1) plates to new parent
            for plate in parent.plates:
                if plate not in parent_plates:
                    parent_plates.append(plate)

            # Create new storch Tensors with different order of plates for the cost and parent
            new_parent = storch.tensor.StochasticTensor(
                parent_tensor,
                [],
                parent_plates,
                parent.name,
                parent.n,
                parent.distribution,
                parent._requires_grad,
                parent.method,
            )
            # Fake the new parent to be the old parent within the graph by mimicking its place in the graph
            new_parent._parents = parent._parents
            for p, has_link in new_parent._parents:
                p._children.append((new_parent, has_link))
            new_parent._children = parent._children

            cost_per_sample = parent.method._estimator(new_parent, reduced_cost)
            if cost_per_sample is not None:
                # The backwards call for reparameterization happens in the
                # backwards call for the costs themselves.
                # Now mean_cost has the same shape as parent.batch_shape
                final_reduced_cost = storch.reduce_plates(
                    cost_per_sample, detach_weights=True
                )
                if final_reduced_cost.ndim == 1:
                    final_reduced_cost = final_reduced_cost.squeeze(0)
                accum_loss += final_reduced_cost

    if isinstance(accum_loss, storch.Tensor) and accum_loss._tensor.requires_grad:
        accum_loss._tensor.backward(retain_graph=retain_graph)

    for s_node in stochastic_nodes:
        if s_node.method:
            s_node.method._update_parameters()

    if not retain_graph:
        # accum_loss stays the float 0.0 when no cost requires grad and no estimator returned a
        # per-sample cost. A float has no _clean(), so only clean a real tensor, and still
        # reach reset().
        if isinstance(accum_loss, storch.Tensor):
            accum_loss._clean()
        reset()

    # TODO: How much does accum_loss really say? Should we really keep it? We want to minimize total_cost, anyways.
    return total_cost._tensor  # , accum_loss._tensor
def surrogate_loss(debug: bool = False) -> storch.Tensor:
    """Build one surrogate loss from all registered cost nodes.

    For each registered cost node ``c``:
    - its stochastic ancestors are walked with ``walk_parents(depth_first=False, reverse=True)``;
    - every ancestor that requires grad, has a method, and is not pathwise for ``c`` contributes
      the ``gradient_function`` returned by its estimator to an accumulator ``L``;
    - control variates returned by the estimators are added as ``magic_box(L) * control_variate``
      (using ``L`` as accumulated up to that ancestor);
    - finally ``magic_box(L) * c``, reduced over its plates, is added.

    The per-cost surrogate losses are summed with ``torch.sum(torch.stack(...))``. This function
    does not call ``.backward()`` on the result. Presumably the caller differentiates the returned
    value (TODO confirm).

    Args:
        debug: If True, ``print_graph(costs)`` is called first.

    Returns:
        The summed surrogate loss over all cost nodes.

    Raises:
        RuntimeError: If no cost nodes are registered.
    """
    costs: [storch.Tensor] = storch.inference._cost_tensors
    if not costs:
        raise RuntimeError("No cost nodes registered for backward call.")
    if debug:
        print_graph(costs)

    # One surrogate loss per cost node, summed at the end.
    surrogate_losses = []
    # Loop over different cost nodes
    for c in costs:
        # Unlike backward() in this file, no total cost is reduced or printed here; only the
        # surrogate loss is built.
        # Accumulator for the gradient functions of the non-pathwise stochastic ancestors of c.
        # new_tensor gives it the same dtype and device as the cost tensor.
        L = c._tensor.new_tensor(0.0)
        surrogate_loss_c = 0.0
        # Walk topologically through the graph
        # This is a parallelized implementation of Algorithm 1 in the paper
        for parent in c.walk_parents(depth_first=False, reverse=True):
            # Instance check here instead of parent.stochastic, as backward methods are only used on these.
            if not isinstance(parent, StochasticTensor):
                continue
            if not parent.requires_grad or not parent.method:
                continue
            # Pathwise (reparameterized) gradients flow through c directly and need no gradient function.
            if parent.method.is_pathwise(parent, c):
                continue

            # Transpose the parent stochastic tensor so that its plate dimensions line up with the
            # cost's plate dimensions.
            parent_tensor = parent._tensor
            reduced_cost = c
            parent_plates = parent.multi_dim_plates()
            # Reduce all plates that are in the cost node but not in the parent node
            for plate in storch.order_plates(c.multi_dim_plates(), reverse=True):
                if plate not in parent_plates:
                    reduced_cost = plate.reduce(reduced_cost, detach_weights=True)
            # Align the parent tensor so that the plate dimensions are in the same order as the cost tensor
            # TODO: This can probably be implemented with torch.movedim
            for index_c, plate in enumerate(reduced_cost.multi_dim_plates()):
                index_p = parent_plates.index(plate)
                if index_c != index_p:
                    parent_tensor = parent_tensor.transpose(index_p, index_c)
                    parent_plates[index_p], parent_plates[index_c] = (
                        parent_plates[index_c],
                        parent_plates[index_p],
                    )
            # Add empty (k=1) plates to new parent
            for plate in parent.plates:
                if plate not in parent_plates:
                    parent_plates.append(plate)

            # Create new storch Tensors with different order of plates for the cost and parent
            new_parent = storch.tensor.StochasticTensor(
                parent_tensor,
                [],
                parent_plates,
                parent.name,
                parent.n,
                parent.distribution,
                parent._requires_grad,
                parent.method,
            )
            # Share the gradient bookkeeping dict with the original parent.
            new_parent.param_grads = parent.param_grads
            # Fake the new parent to be the old parent within the graph by mimicking its place in the graph
            new_parent._parents = parent._parents
            for p, has_link in new_parent._parents:
                p._children.append((new_parent, has_link))
            new_parent._children = parent._children

            # Compute the estimator
            (
                gradient_function,
                control_variate,
            ) = parent.method._estimator(new_parent, reduced_cost)
            if gradient_function is not None:
                L = L + gradient_function
            # Compute control variate. It is weighted by magic_box of L as accumulated so far,
            # i.e. including this parent's gradient function but not later ones.
            if control_variate is not None:
                final_A = magic_box(L) * control_variate
                final_A = storch.reduce_plates(
                    final_A,
                    detach_weights=False,  # TODO: Should this boolean be false or true?
                )
                if final_A.ndim == 1:
                    final_A = final_A.squeeze(0)
                surrogate_loss_c += final_A

        # Use magic box to distribute the cost to gradient function
        surrogate_loss_c += storch.reduce_plates(magic_box(L) * c, detach_weights=False)
        # Collect surrogate losses for all costs
        surrogate_losses.append(surrogate_loss_c)
    SL = torch.sum(torch.stack(surrogate_losses))
    return SL
def estimator(
    self, tensor: StochasticTensor, cost_node: CostTensor
) -> Optional[storch.Tensor]:
    """Apply a control-variate gradient estimate and train the control variate.

    The total derivative w.r.t. ``tensor.distribution._param`` is

        (hard_cost - eta * c_phi_cond) * dlogp(hard)/dparam
            + eta * (dc_phi_relaxed/dparam - dc_phi_cond/dparam)

    This matches the REBAR/RELAX estimator form (TODO confirm against the
    reference). The derivative is pushed into the parameter with
    ``backward``. Its summed square (a gradient variance proxy) is then
    differentiated w.r.t. ``self.temperature`` and the trainable parameters
    of ``self.c_phi`` (when it is a ``torch.nn.Module``). Those gradients are
    accumulated with ``backward``.

    With ``self.rebar`` set, ``tensor`` and ``cost_node`` are split into three
    slices along the tensor's own plate: hard, relaxed and conditional samples
    and their costs. Otherwise the hard sample is obtained with
    ``discretize``, the conditional sample with ``conditional_gumbel_rsample``,
    and only ``c_phi`` is used as the control variate (the relaxed and
    conditional costs are 0.0).

    Returns:
        Always ``None``. All gradients are applied directly inside this method.
    """
    plate = tensor.get_plate(tensor.name)
    if self.rebar:
        hard_sample, relaxed_sample, cond_sample = split(
            tensor, plate, amt_slices=3
        )
        hard_cost, relaxed_cost, cond_cost = split(cost_node, plate, amt_slices=3)
    else:
        hard_sample = discretize(tensor, tensor.distribution)
        relaxed_sample = tensor
        cond_sample = conditional_gumbel_rsample(
            hard_sample, tensor.distribution, self.temperature
        )
        hard_cost = cost_node
        relaxed_cost = 0.0
        cond_cost = 0.0

    # Input rsampled values into c_phi
    c_phi_relaxed = self.c_phi(relaxed_sample) + relaxed_cost
    c_phi_cond = self.c_phi(cond_sample) + cond_cost

    # Compute log probability of hard sample, summed over all non-plate dimensions
    log_prob = tensor.distribution.log_prob(hard_sample)
    log_prob = log_prob.sum(
        dim=list(range(hard_sample.plate_dims, len(log_prob.shape)))
    )

    # Compute the derivative with respect to the distributional parameters through the baseline.
    param = tensor.distribution._param
    # TODO: It should either propagate over the logits or over the probs. Can we know which one is the parameter and
    # which one is computed dynamically?
    # TODO: Can these be collected in a single torch.autograd.grad call?
    # create_graph=True keeps these derivatives differentiable so var_loss can be
    # differentiated w.r.t. c_phi / temperature below.
    d_log_prob = storch.grad(
        [log_prob],
        [param],
        create_graph=True,
        grad_outputs=torch.ones_like(log_prob),
    )[0]
    d_c_phi_relaxed = storch.grad(
        [c_phi_relaxed],
        [param],
        create_graph=True,
        grad_outputs=torch.ones_like(c_phi_relaxed),
    )[0]
    d_c_phi_cond = storch.grad(
        [c_phi_cond],
        [param],
        create_graph=True,
        grad_outputs=torch.ones_like(c_phi_cond),
    )[0]

    diff = hard_cost - self.eta * c_phi_cond
    # Compute total derivative with respect to the parameter
    d_param = diff * d_log_prob + self.eta * (d_c_phi_relaxed - d_c_phi_cond)
    # Reduce the plate of this sample in case multiple samples are taken
    d_param = storch.reduce_plates(d_param, plate_names=[tensor.name])
    # Compute backwards from the parameters using its total derivative
    if isinstance(param, storch.Tensor):
        param._tensor.backward(d_param._tensor, retain_graph=True)
    else:
        param.backward(d_param._tensor, retain_graph=True)
    # Compute the gradient variance
    variance = (d_param ** 2).sum(d_param.event_dim_indices())
    var_loss = storch.reduce_plates(variance)

    # Minimize variance over the parameters of c_phi and the temperature (should it also minimize eta?)
    c_phi_params = [self.temperature]
    if isinstance(self.c_phi, torch.nn.Module):
        for c_phi_param in self.c_phi.parameters(recurse=True):
            if c_phi_param.requires_grad:
                c_phi_params.append(c_phi_param)
    # NOTE(review): self.temperature is always included; autograd.grad will raise if it does not
    # require grad or is unused by var_loss. TODO confirm this cannot happen.
    d_variance = torch.autograd.grad(
        [var_loss._tensor], c_phi_params, create_graph=self.rebar
    )
    for i in range(len(c_phi_params)):
        c_phi_params[i].backward(d_variance[i])
    return None
def compute_baseline(
    self, tensor: StochasticTensor, cost_node: CostTensor
) -> torch.Tensor:
    """Update the exponential moving average of the cost and return it.

    The cost is reduced over its plates and detached before it is blended
    into ``self.moving_average`` with weight ``1 - self.exponential_decay``.
    ``tensor`` is not used.
    """
    decay = self.exponential_decay
    mean_cost = storch.reduce_plates(cost_node).detach()
    blended = decay * self.moving_average + (1 - decay) * mean_cost
    # Store the underlying torch tensor rather than the storch wrapper.
    self.moving_average = blended._tensor
    return self.moving_average
def reduce(self, unique_tensor: storch.Tensor, detach_weights=True):
    """Reduce ``unique_tensor`` over ``self.shrunken_plates``.

    The tensor is first mapped back with ``self.undo_unique``. ``detach_weights``
    is forwarded unchanged to ``storch.reduce_plates``.
    """
    return storch.reduce_plates(
        self.undo_unique(unique_tensor),
        plates=self.shrunken_plates,
        detach_weights=detach_weights,
    )
def train(epoch, model, train_loader, device, optimizer, args, writer):
    """Train ``model`` for one epoch and return the average per-batch cost.

    On every ``args.log_interval``-th batch, the gradient of ``z`` is
    re-estimated ``args.variance_samples`` times to log its variance. This is
    skipped when ``args.method == "expect"`` or ``args.variance_samples <= 1``.
    When ``args.latents < 3``, the MSE and bias of the ``"probs"`` gradient
    against the gradient obtained with ``Expect("z")`` are logged as well.

    Args:
        epoch: 1-based epoch index, used to compute the logging step.
        model: Model exposing ``train()``, ``sampling_method`` and returning
            ``(recon_batch, kld, z)``.
        train_loader: Iterable of ``(data, label)`` batches. Its length is used
            for progress reporting.
        device: Device the batches are moved to.
        optimizer: Optimizer stepped once per batch.
        args: Namespace with ``log_interval``, ``method``,
            ``variance_samples`` and ``latents``.
        writer: Object with an ``add_scalar(tag, value, step)`` method.

    Returns:
        float: Mean of the costs returned by ``backward()`` over all batches.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    train_loss = 0
    n_batches = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        n_batches += 1
        data = data.to(device)
        optimizer.zero_grad()
        storch.reset()

        # Denote the minibatch dimension as being independent
        data = storch.denote_independent(data.view(-1, 784), 0, "data")
        recon_batch, _, z = model(data)
        storch.add_cost(loss_function(recon_batch, data), "reconstruction")
        cost = backward()
        train_loss += cost.item()

        optimizer.step()

        if batch_idx % args.log_interval != 0:
            continue

        step = 100.0 * batch_idx / len(train_loader)
        global_step = 100 * (epoch - 1) + step

        # Variance of expect method is 0 by definition.
        variances = {}
        if args.method != "expect" and args.variance_samples > 1:
            _consider_param = "probs"
            if args.latents < 3:
                old_method = model.sampling_method
                model.sampling_method = Expect("z")
                try:
                    optimizer.zero_grad()
                    recon_batch, _, z = model(data)
                    storch.add_cost(loss_function(recon_batch, data), "reconstruction")
                    backward()
                    expect_grad = storch.reduce_plates(
                        z.grad[_consider_param]
                    ).detach_tensor()
                    optimizer.zero_grad()
                finally:
                    # Restore the configured estimator even if the Expect pass fails.
                    model.sampling_method = old_method

            grads = {n: [] for n in z.grad}
            for _ in range(args.variance_samples):
                optimizer.zero_grad()
                recon_batch, _, z = model(data)
                storch.add_cost(loss_function(recon_batch, data), "reconstruction")
                backward()

                for param_name, grad in z.grad.items():
                    # Make sure to reduce the data dimension and detach, for memory reasons.
                    grads[param_name].append(
                        storch.reduce_plates(grad).detach_tensor()
                    )

            for param_name, gradz in grads.items():
                # Create a new independent dimension for the different gradient samples
                grad_samples = storch.gather_samples(gradz, "variance")
                # Compute the variance over this independent dimension
                variances[param_name] = storch.variance(
                    grad_samples, "variance"
                )._tensor
                if param_name == _consider_param and args.latents < 3:
                    mean = storch.reduce_plates(grad_samples, "variance")
                    mse = storch.reduce_plates(
                        (grad_samples - expect_grad) ** 2
                    ).sum()
                    bias = (storch.reduce_plates((mean - expect_grad) ** 2)).sum()
                    print("mse", mse._tensor.item())
                    # Should approach 0 when increasing variance_samples for unbiased estimators.
                    print("bias", bias._tensor.item())
                    writer.add_scalar("train/probs_bias", bias._tensor, global_step)
                    writer.add_scalar("train/probs_mse", mse._tensor, global_step)

        print(
            "Train Epoch: {} [{}/{} ({:.0f}%)]\tCost: {:.6f}\t Logits var {}".format(
                epoch,
                batch_idx * len(data),
                len(train_loader.dataset),
                step,
                cost.item(),
                variances,
            )
        )
        writer.add_scalar("train/ELBO", cost, global_step)
        for param_name, var in variances.items():
            writer.add_scalar("train/variance/" + param_name, var, global_step)

    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")
    avg_train_loss = train_loss / n_batches
    print("====> Epoch: {} Average loss: {:.4f}".format(epoch, avg_train_loss))
    return avg_train_loss
# Tail of a gradient bias/MSE sanity check.
# NOTE(review): c, z, probs, indices and OneHotCategorical are defined/imported before this point
# (not visible here). Presumably c is the cost of a sample z drawn with an expectation-style method,
# so that z.grad["probs"] is the reference gradient. TODO confirm.
storch.add_cost(c, "no_baseline_cost")
storch.backward()
# Clone so later backward calls cannot alter the stored reference gradient.
expect_grad = z.grad["probs"].clone()

method = storch.method.UnorderedSetEstimator("x", k=6)
# method = storch.REBAR()
grads = []
# Draw 100 gradient samples of a cost with a constant offset of +100.
for i in range(100):
    b = OneHotCategorical(probs=probs)
    z = method.sample(b)
    c = (2.4 * z * indices).sum(-1) + 100
    storch.add_cost(c, "baseline_cost")

    storch.backward()
    grad = z.grad["probs"].clone()
    grads.append(grad)

# Stack the samples along a "variance" plate and compare their mean with the reference gradient.
grad_samples = storch.gather_samples(grads, "variance")
mean = storch.reduce_plates(grad_samples, plate_names=["variance"])
print("mean grad", mean)
print("expected grad", expect_grad)
print("specific_diffs", (mean - expect_grad) ** 2)
mse = storch.reduce_plates((grad_samples - expect_grad) ** 2).sum()
print("MSE", mse)
bias = (storch.reduce_plates((mean - expect_grad) ** 2)).sum()
print("bias", bias)
# That works... Adding constants to the costs doesn't change the gradient in expectation.