def _fastlin_bilinearwithparam_op(primitive: Primitive,
                                  lhs: Union[LinearBound, Tensor],
                                  rhs: Union[LinearBound, Tensor],
                                  **kwargs) -> LinearBound:
    """Push linear bounds through a bilinear primitive with one constant operand.

    Exactly one of ``lhs`` / ``rhs`` must be a bound (enforced by asserts); the
    other is a plain parameter tensor. For every linear function of the input
    bound, the centre ``(upper + lower) / 2`` is propagated through the
    primitive with the parameter itself, and the half-width
    ``(upper - lower) / 2`` is propagated with ``|parameter|``. The output
    bounds are ``centre_out -/+ half_width_out``. No Jacobian is materialised.

    Args:
      primitive: Bilinear primitive to propagate through.
      lhs: Either the parameter tensor or the input bound.
      rhs: Either the parameter tensor or the input bound.
      **kwargs: Parameters forwarded to ``primitive.bind``.

    Returns:
      A LinearBound with one output linear function per input linear function.
    """
    # Normalise the argument order so that `apply_op(bound_value, param_value)`
    # always works, whichever side the bound was originally on.
    if isinstance(lhs, bound_propagation.Bound):
        assert not isinstance(rhs, bound_propagation.Bound)
        bound_arg, param_arg = lhs, rhs

        def apply_op(bound_value, param_value):
            return primitive.bind(bound_value, param_value, **kwargs)
    else:
        assert isinstance(rhs, bound_propagation.Bound)
        bound_arg, param_arg = rhs, lhs

        def apply_op(bound_value, param_value):
            return primitive.bind(param_value, bound_value, **kwargs)

    # Linear coefficients carry an extra leading "input" axis at position 1,
    # so map the op over it while sharing the parameter.
    batched_op = jax.vmap(apply_op, in_axes=(1, None), out_axes=1)
    param_magnitude = jnp.abs(param_arg)

    def propagate(lin_fun):
        """Map one (lower, upper) linear function through the primitive."""
        half_width = (lin_fun.upper_lin - lin_fun.lower_lin) / 2
        centre = (lin_fun.upper_lin + lin_fun.lower_lin) / 2
        width_coeffs = batched_op(half_width.lin_coeffs, param_magnitude)
        width_offset = apply_op(half_width.offset, param_magnitude)
        centre_coeffs = batched_op(centre.lin_coeffs, param_arg)
        centre_offset = apply_op(centre.offset, param_arg)
        lower = LinearExpression(centre_coeffs - width_coeffs,
                                 centre_offset - width_offset)
        upper = LinearExpression(centre_coeffs + width_coeffs,
                                 centre_offset + width_offset)
        return LinearFunction(lower, upper, lin_fun.reference_bound)

    # `unwrap` exposes the underlying LinearBound (e.g. inside an
    # IntersectionBound).
    return LinearBound(
        [propagate(lin_fun) for lin_fun in bound_arg.unwrap().linear_functions()])
def sample(self, key, sample_shape=()):
    """Return the element-wise absolute value of a draw from ``self._cauchy``.

    Args:
        key: Random key, forwarded unchanged to ``self._cauchy.sample``.
        sample_shape: Sample shape, forwarded unchanged to
            ``self._cauchy.sample``.
    """
    raw_draws = self._cauchy.sample(key, sample_shape)
    return np.abs(raw_draws)
def soft_sign(x):
    """Return the softsign of ``x``, i.e. ``x / (1 + |x|)`` element-wise."""
    denominator = 1 + np.abs(x)
    return x / denominator


def sigmoid(x):
    """Return the logistic sigmoid of ``x`` (delegates to ``expit``)."""
    return expit(x)
def f(z):
    """Return the sum over all elements of ``cos(|z|)``."""
    magnitudes = np.abs(z)
    return np.sum(np.cos(magnitudes))
# Solve for the coefficients `xi` with LS. When `xtfc` is set, the 'lstsq'
# method is requested explicitly; otherwise LS's default method is used.
# With timer=True, LS also returns a timing value, which is printed below.
# NOTE(review): the global name `time` shadows the stdlib `time` module for
# the rest of this script.
if xtfc:
    xi, time = LS(xi, L, *x, method='lstsq', timer=True, constant_arg_nums=[1, 2])
else:
    xi, time = LS(xi, L, *x, timer=True, constant_arg_nums=[1, 2])

# Calculate the test set error on a uniform nTest x nTest grid spanning
# [x0[0], xf[0]] x [x0[1], xf[1]].
nTest = 100
dark = np.meshgrid(np.linspace(x0[0], xf[0], nTest), np.linspace(x0[1], xf[1], nTest))
xTest = (dark[0].flatten(), dark[1].flatten())
# Point-wise absolute error between `real` and the fitted `u`.
# NOTE(review): `real` is presumably the reference/analytic solution — confirm.
err = np.abs(real(*xTest) - u(xi, *xTest))

# Print out solution statistics
print("Time: " + str(time))
print("Max error test: " + str(np.max(err)))
print("Mean error test: " + str(np.mean(err)))

# Create plots (Plotly backend only; imported lazily so Plotly is optional).
if usePlotly:
    from tfc.utils.PlotlyMakePlot import MakePlot

    p = MakePlot(r'x', r't', zlabs=r'u(x,t)')
    # Surface of `real` on the training points, assumed to form an n x n grid.
    p.Surface(x=x[0].reshape((n, n)),
              y=x[1].reshape((n, n)),
              z=real(*x).reshape((n, n)),
              showscale=False)
def wolfe_two(dphi_i):
    """Strong-Wolfe curvature test: ``|dphi_i| <= -c2 * dphi_0``.

    ``c2`` and ``dphi_0`` are free names resolved from the enclosing scope.
    """
    curvature_bound = -c2 * dphi_0
    return curvature_bound >= jnp.abs(dphi_i)
def get_masks_from_jax_params(params, nn_density_level, magnitude_base_bool=True,
                              global_bool=False, reshuffle_seed=0):
    """Assemble 0-1 valued masks shaped like each layer's weight tensor.

    Bias parameters are ignored: only the first entry of each
    ``(weights, biases)`` layer tuple is masked.

    Args:
        params: Parameters in ``jax.experimental.stax`` format: one entry per
            layer, either with fewer than two arrays (no weights) or a
            ``(weights, biases)`` pair.
        nn_density_level: Desired fraction of weights to keep.
        magnitude_base_bool: If True, prune by weight magnitude. If False,
            build a magnitude mask of the same density and then shuffle it
            (random pruning).
        global_bool: If True and pruning by magnitude, use a single threshold
            computed over the weights of all layers instead of one per layer.
        reshuffle_seed: Seed of the PRNG key used to shuffle masks when
            ``magnitude_base_bool`` is False.

    Returns:
        A list with one entry per layer: ``[]`` for layers without weights,
        otherwise a float32 0-1 mask with the shape of that layer's weights.

    Raises:
        ValueError: If ``magnitude_base_bool`` or ``global_bool`` is not a bool.
        NotImplementedError: If a layer holds more than two parameter arrays.
    """
    if not isinstance(magnitude_base_bool, bool) or not isinstance(global_bool, bool):
        raise ValueError("magnitude_base_bool and global_bool should be boolean variables")

    def _magnitude_threshold(magnitudes):
        """Return the sorted magnitude at index (1 - density) * size."""
        # Entries strictly above this value are kept. The index is clamped to
        # the last position so a density of 0 prunes every weight instead of
        # indexing past the end of the sorted array.
        sorted_magnitudes = np.sort(np.reshape(magnitudes, [-1]))
        idx = int((1 - nn_density_level) * np.size(sorted_magnitudes))
        return sorted_magnitudes[min(idx, np.size(sorted_magnitudes) - 1)]

    masks = []

    if global_bool:
        # One threshold over the weights of every layer that has parameters.
        weight_magnitudes_pooled = np.concatenate([
            np.abs(layer_params[0].flatten())
            for layer_params in params if len(layer_params) > 1])
        global_thres = _magnitude_threshold(weight_magnitudes_pooled)

    for layer_params in params:
        if len(layer_params) < 2:
            # The layer does not contain weight and bias parameters.
            masks.append([])
        elif len(layer_params) == 2:
            # (weights, biases) tuple: only the weights are masked.
            weight_magnitudes = np.abs(layer_params[0])
            if global_bool and magnitude_base_bool:
                thres = global_thres
            else:
                thres = _magnitude_threshold(weight_magnitudes)
            # 1 for weights whose magnitude is above the threshold, 0 otherwise.
            this_mask = np.float32(weight_magnitudes > thres)
            if not magnitude_base_bool:
                # Random pruning: same density, shuffled survivors. The key is
                # derived from `reshuffle_seed` so callers control the draw.
                # NOTE(review): jax.random.shuffle is deprecated in newer JAX
                # releases in favour of jax.random.permutation.
                this_mask = random.shuffle(random.PRNGKey(reshuffle_seed), this_mask)
            masks.append(this_mask)
        else:
            raise NotImplementedError
    return masks
def robust_whiten(x):
    """Centre ``x`` on its median and divide by its mean absolute deviation.

    NaN entries are ignored when computing both statistics (nanmedian /
    nanmean) but are kept in the output.
    """
    centre = jnp.nanmedian(x)
    deviations = jnp.abs(x - centre)
    spread = jnp.nanmean(deviations)
    return (x - centre) / spread
def general_loss_with_squared_residual(squared_x, alpha, scale):
    r"""The general robust loss, taking an already-squared residual.

    Taking x**2 directly avoids a sqrt when the caller already has squared
    residuals, while preserving the square in the loss formulation.

    This implements the rho(x, \alpha, c) function described in "A General and
    Adaptive Robust Loss Function", Jonathan T. Barron,
    https://arxiv.org/abs/1701.03077.

    Args:
      squared_x: The squared residual x**2 for which the loss is being
        computed. It can have any shape; alpha and scale are broadcast against
        it if necessary.
      alpha: The shape parameter of the loss (\alpha in the paper), where more
        negative values produce a loss with more robust behavior (outliers
        "cost" less), and more positive values produce a loss with less robust
        behavior (outliers are penalized more heavily). Alpha can be any value
        in [-infinity, infinity], but the gradient of the loss with respect to
        alpha is 0 at -infinity, infinity, 0, and 2. Varying alpha allows for
        smooth interpolation between several discrete robust losses:
          alpha=-Infinity: Welsch/Leclerc Loss.
          alpha=-2: Geman-McClure loss.
          alpha=0: Cauchy/Lorentzian loss.
          alpha=1: Charbonnier/pseudo-Huber loss.
          alpha=2: L2 loss.
      scale: The scale parameter of the loss. When |x| < scale, the loss is an
        L2-like quadratic bowl, and when |x| > scale the loss function takes on
        a different shape according to alpha.

    Returns:
      The per-element losses, with the broadcast shape of the inputs. Note
      that the result is multiplied by ``scale`` (``return scale * loss``).
      NOTE(review): confirm callers expect this extra ``scale`` factor on top
      of rho.
    """
    # float32 machine epsilon is used regardless of the input dtype.
    eps = jnp.finfo(jnp.float32).eps

    # This will be used repeatedly.
    squared_scaled_x = squared_x / (scale ** 2)

    # All special-case branches below are computed for every element; the
    # final jnp.where chain picks one per element.
    # The loss when alpha == 2.
    loss_two = 0.5 * squared_scaled_x
    # The loss when alpha == 0. (log1p_safe is defined elsewhere in the file.)
    loss_zero = log1p_safe(0.5 * squared_scaled_x)
    # The loss when alpha == -infinity.
    loss_neginf = -jnp.expm1(-0.5 * squared_scaled_x)
    # The loss when alpha == +infinity. (expm1_safe is defined elsewhere.)
    loss_posinf = expm1_safe(0.5 * squared_scaled_x)

    # The loss when not in one of the above special cases.
    # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by.
    beta_safe = jnp.maximum(eps, jnp.abs(alpha - 2.))
    # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by,
    # keeping alpha's sign (alpha == 0 is treated as positive).
    alpha_safe = jnp.where(
        jnp.greater_equal(alpha, 0.), jnp.ones_like(alpha),
        -jnp.ones_like(alpha)) * jnp.maximum(eps, jnp.abs(alpha))
    loss_otherwise = (beta_safe / alpha_safe) * (
        jnp.power(squared_scaled_x / beta_safe + 1., 0.5 * alpha) - 1.)

    # Select which of the cases of the loss to return.
    loss = jnp.where(
        alpha == -jnp.inf, loss_neginf,
        jnp.where(
            alpha == 0, loss_zero,
            jnp.where(
                alpha == 2, loss_two,
                jnp.where(alpha == jnp.inf, loss_posinf, loss_otherwise))))

    return scale * loss
def Lp_torsion_pure(torsion, gammadash, p):
    """Return ``(1/p) * mean(|torsion|**p * ||gammadash||)``.

    Args:
        torsion: Torsion values; assumed to align with the rows of
            ``gammadash`` — confirm at call sites.
        gammadash: 2-D array whose per-row Euclidean norm (axis 1) is used as
            the arc-length weight.
        p: Exponent applied to ``|torsion|``; also the 1/p prefactor.
    """
    speed = jnp.linalg.norm(gammadash, axis=1)
    weighted = jnp.abs(torsion) ** p * speed
    return (1. / p) * jnp.mean(weighted)
def _background_only_expected_data(yields):
    """Build ``make_model(*yields)`` and its expected data with the POI at 0.

    Returns:
        ``(model, data_hf)`` where ``data_hf`` is ``model.expected_data``
        evaluated at the suggested initial parameters with the parameter of
        interest set to 0.0.
    """
    model = make_model(*yields)
    bonly_pars = (jnp.asarray(model.config.suggested_init()).at[
        model.config.poi_index].set(0.0).tolist())
    return model, model.expected_data(bonly_pars)


def _finite_bin_edges(bins):
    """Return ``bins`` with its two outer edges dropped if it contains inf."""
    return bins[1:-1] if jnp.inf in bins else bins


def _draw_bandwidth_inset(axins, width, bandwidth, bin_alpha=None,
                          annotation_alpha=None):
    """Illustrate one bin of ``width`` next to a Gaussian kernel of ``bandwidth``.

    Draws a unit-height bin on ``[0, width]``, a normal pdf centred on the bin
    and rescaled to peak at 1, dotted lines at ``centre +/- bandwidth``, and a
    text label with the bandwidth / bin-width ratio. ``None`` alphas leave the
    matplotlib defaults in place.
    """
    centre = width / 2
    xlim = ([centre - (1.1 * bandwidth), centre + (1.1 * bandwidth)]
            if centre - bandwidth < 0 else [-width / 3, width + width / 3])
    axins.stairs([1], [0, width], color="C1", alpha=bin_alpha)
    y = jnp.linspace(xlim[0], xlim[1], 300)
    demo = jsp.stats.norm.pdf(y, loc=centre, scale=bandwidth)
    axins.plot(y, demo / max(demo), color="C0", linestyle="dashed", label="kernel")
    # Two vertical lines at centre - bandwidth and centre + bandwidth.
    axins.vlines(
        [centre - bandwidth, centre + bandwidth],
        0,
        1,
        colors="black",
        linestyles="dotted",
        label=r"$\pm$bandwidth",
        alpha=annotation_alpha,
    )
    # Text between the vertical lines with the bandwidth / bin width ratio.
    ratio = bandwidth / width
    axins.text(
        centre,
        -0.3,
        r"$\mathsf{\frac{bandwidth}{bin\,width}}=$" + f"{ratio:.2f}",
        ha="center",
        va="center",
        size="x-small",
        alpha=annotation_alpha,
    )
    axins.set_xlim(*xlim)


def plot(
    network,
    axs,
    axins,
    metrics,
    maxN,
    this_batch,
    epoch_grid,
    nn,
    bins,
    bandwidth,
    batch_num,
    legend=False,
    reflect=False,
    histlim=55,
):
    """Draw the requested diagnostic panels for one training step.

    Only panels whose name is a key of ``axs`` are drawn: "Likelihood scan",
    "Expected limits", "Data space", "Losses", "Metrics", "Histogram model",
    "Nuisance pull" and "Example KDE". Each panel reads only its own inputs,
    so any subset of panels can be requested.

    Args:
        network: Parameters passed to ``nn`` as its first argument.
        axs: Mapping from panel name to a matplotlib-style Axes.
        axins: Inset Axes for the bin/kernel illustration (used when
            ``legend`` is True in "Histogram model" / "Example KDE").
        metrics: Dict of recorded metrics. Reads "yields"
            (signal, background, background up, background down), "loss",
            "test_loss", "1-pull_width**2", "pull", "pull_width", "mu_uncert"
            and "CLs".
        maxN: Upper x-limit (epochs) for the loss and metric panels.
        this_batch: ``(sig, bkg_nom, bkg_up, bkg_down)`` 2-D point clouds.
        epoch_grid: Epoch values; the first ``batch_num + 1`` are plotted.
        nn: Callable ``nn(network, points)`` giving the network output.
        bins: Bin edges over the network output, possibly with infinite outer
            edges.
        bandwidth: KDE bandwidth.
        batch_num: Index of the current step.
        legend: Whether to draw legends (and the bandwidth inset).
        reflect: In "Example KDE", plot ``2 * |kde(x)|`` instead of ``kde(x)``.
        histlim: Upper y-limit of the "Histogram model" panel.
    """
    # Epochs seen so far; shared by the loss and metric panels.
    x_grid = epoch_grid[:batch_num + 1]
    # Bin edges without infinite outer edges (contour levels, widths, edges).
    finite_bins = _finite_bin_edges(bins)

    if "Likelihood scan" in axs:
        ax = axs["Likelihood scan"]
        import cabinetry

        model, data_hf = _background_only_expected_data(metrics["yields"])
        scan_results = cabinetry.fit.scan(model, data_hf, "mu",
                                          par_bounds=[[-2, 10], [-2, 10]])
        cabinetry.visualize.scan(scan_results, existing_ax=ax, legend=legend)

    if "Expected limits" in axs:
        ax = axs["Expected limits"]
        import cabinetry

        model, data_hf = _background_only_expected_data(metrics["yields"])
        limit_results = cabinetry.fit.limit(model, data_hf, maxiter=1000)
        cabinetry.visualize.limit(limit_results, existing_ax=ax, legend=legend)

    if "Data space" in axs:
        ax = axs["Data space"]
        g = np.mgrid[-5:5:101j, -5:5:101j]
        # Network output on the 101 x 101 grid, evaluated once and reused for
        # both the filled and the line contours.
        grid_output = nn(network, np.moveaxis(g, 0, -1)).reshape(101, 101, 1)[:, :, 0]
        ax.contourf(g[0], g[1], grid_output, levels=finite_bins, cmap="binary")
        ax.contour(g[0], g[1], grid_output, colors="w", levels=finite_bins)
        sig, bkg_nom, bkg_up, bkg_down = this_batch
        ax.scatter(sig[:, 0], sig[:, 1], alpha=0.3, c="C9", label="signal")
        ax.scatter(
            bkg_up[:, 0],
            bkg_up[:, 1],
            alpha=0.1,
            c="orangered",
            marker=6,
            label="bkg up",
        )
        ax.scatter(
            bkg_down[:, 0],
            bkg_down[:, 1],
            alpha=0.1,
            c="gold",
            marker=7,
            label="bkg down",
        )
        ax.scatter(bkg_nom[:, 0], bkg_nom[:, 1], alpha=0.3, c="C1", label="bkg")
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        if legend:
            ax.legend(fontsize="x-small", loc="upper right", fancybox=True)

    if "Losses" in axs:
        ax = axs["Losses"]
        ax.plot(x_grid, metrics["loss"], c="C9", linewidth=2.0, label=r"train")
        ax.plot(x_grid, metrics["test_loss"], c="C4", linewidth=2.0, label=r"test")
        ax.set_yscale("log")
        ax.set_xlim(0, maxN)
        ax.set_xlabel("epoch")
        ax.set_ylabel(r"loss value")
        if legend:
            ax.legend(fontsize="x-small", loc="upper right", fancybox=True)

    if "Metrics" in axs:
        ax = axs["Metrics"]
        ax.plot(
            x_grid,
            metrics["1-pull_width**2"],
            c="slategray",
            linewidth=2.0,
            label=r"$(1-\sigma_{\mathsf{nuisance}})^2$",
        )
        ax.plot(
            x_grid,
            np.array(metrics["pull"])**2,
            c="C2",
            linewidth=2.0,
            label=r"(nuisance pull)$^2$",
        )
        ax.plot(
            x_grid,
            metrics["mu_uncert"],
            c="steelblue",
            linewidth=2.0,
            label=r"$\sigma_\mu$",
        )
        ax.plot(x_grid, metrics["CLs"], c="C9", linewidth=2, label=r"$CL_s$")
        ax.set_ylim(1e-7, 1e-0)
        ax.set_xlim(0, maxN)
        ax.set_xlabel("epoch")
        ax.set_yscale("log")
        ax.set_ylabel(r"metric value (on test set)")
        if legend:
            ax.legend(fontsize="x-small", loc="upper right", fancybox=True)

    if "Histogram model" in axs:
        ax = axs["Histogram model"]
        sig_yields, bkg_yields, bkg_up_yields, bkg_down_yields = metrics["yields"]
        dct = {
            "signal": sig_yields,
            "bkg up": bkg_up_yields,
            "bkg": bkg_yields,
            "bkg down": bkg_down_yields,
        }
        bar_handles, bar_labels = bar_plot(
            ax,
            dct,
            colors=["C9", "orangered", "C1", "gold"],
            total_width=0.8,
            single_width=1,
            legend=legend,
            bins=bins,
        )
        ax.set_ylabel("frequency")
        ax.set_xlabel("interval over nn output")
        ax.set_ylim(0, histlim)
        if legend:
            _draw_bandwidth_inset(axins, jnp.diff(finite_bins)[0], bandwidth,
                                  bin_alpha=0.6, annotation_alpha=0.9)
            handles1, labels1 = axins.get_legend_handles_labels()
            ax.legend(
                bar_handles + handles1,
                list(bar_labels) + labels1,
                loc="upper right",
                fontsize="x-small",
                fancybox=True,
            )

    if "Nuisance pull" in axs:
        ax = axs["Nuisance pull"]
        pulls = metrics["pull"]
        pullerr = metrics["pull_width"]
        ax.set_ylabel(r"$(\theta - \hat{\theta})\,/ \Delta \theta$", fontsize=18)
        # draw the +/- 2.0 horizontal lines
        ax.hlines([-2, 2], -0.5, len(pulls) - 0.5, colors="black", linestyles="dotted")
        # draw the +/- 1.0 horizontal lines
        ax.hlines([-1, 1], -0.5, len(pulls) - 0.5, colors="black", linestyles="dashdot")
        # draw the +/- 2.0 sigma band
        ax.fill_between([-0.5, len(pulls) - 0.5], [-2, -2], [2, 2], facecolor="yellow")
        # draw the +/- 1.0 sigma band
        ax.fill_between([-0.5, len(pulls) - 0.5], [-1, -1], [1, 1], facecolor="green")
        # draw a horizontal line at pull=0.0
        ax.hlines([0], -0.5, len(pulls) - 0.5, colors="black", linestyles="dashed")
        ax.scatter(range(len(pulls)), pulls, color="black")
        # and their uncertainties
        ax.errorbar(
            range(len(pulls)),
            pulls,
            color="black",
            xerr=0,
            yerr=pullerr,
            marker=".",
            fmt="none",
        )

    if "Example KDE" in axs:
        ax = axs["Example KDE"]
        # Read the nominal background points and yields directly from the
        # inputs so this panel does not depend on other panels' locals.
        _, bkg_nom, _, _ = this_batch
        _, bkg_yields, _, _ = metrics["yields"]
        d = np.array(nn(network, bkg_nom).ravel().tolist())
        kde = make_kde(d, bandwidth)
        ls = [-1, 2]
        x = np.linspace(ls[0], ls[1], 300)
        db = jnp.array(jnp.diff(bins), float)  # bin spacing
        # Normalise to a density: divide by bin width and by the total yield.
        yields = bkg_yields / db / bkg_yields.sum(axis=0)
        # Replace infinite outer edges by the plot limits.
        pbins = [ls[0], *finite_bins, ls[1]] if jnp.inf in bins else bins
        ax.stairs(yields, pbins, label="KDE hist", color="C1")
        if reflect:
            ax.plot(x, 2 * jnp.abs(kde(x)), label="KDE", color="C0")
        else:
            ax.plot(x, kde(x), label="KDE", color="C0")
        ax.set_xlim(*ls)

        # rug plot of the data
        ax.plot(
            d,
            jnp.zeros_like(d) - 0.01,
            "|",
            linewidth=3,
            alpha=0.4,
            color="black",
            label="data",
        )

        if legend:
            _draw_bandwidth_inset(axins, jnp.diff(finite_bins)[0], bandwidth)
            handles, labels = ax.get_legend_handles_labels()
            handles1, labels1 = axins.get_legend_handles_labels()
            ax.legend(
                handles + handles1,
                labels + labels1,
                loc="upper right",
                fontsize="x-small",
                fancybox=True,
            )
def soft_threshold(x, a):
    """Soft-thresholding (shrinkage): ``sign(x) * max(|x| - a, 0)`` element-wise."""
    shrunk_magnitude = jnp.maximum(jnp.abs(x) - a, 0.0)
    return shrunk_magnitude * jnp.sign(x)
def training_loop(env=None,
                  env_name="CartPole-v0",
                  epochs=EPOCHS,
                  policy_net_fun=None,
                  value_net_fun=None,
                  policy_and_value_net_fun=None,
                  policy_optimizer_fun=None,
                  value_optimizer_fun=None,
                  policy_and_value_optimizer_fun=None,
                  batch_size=BATCH_TRAJECTORIES,
                  num_optimizer_steps=NUM_OPTIMIZER_STEPS,
                  print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                  boundary=20,
                  max_timestep=None,
                  random_seed=None,
                  gamma=GAMMA,
                  lambda_=LAMBDA,
                  epsilon=EPSILON,
                  c1=1.0,
                  c2=0.01):
    """Runs the training loop for PPO, with fixed policy and value nets.

    Two modes are supported:
      * a single network producing both policy and value predictions, used
        when ``policy_and_value_net_fun`` is given (together with
        ``policy_and_value_optimizer_fun``);
      * separate policy and value networks (``policy_net_fun`` /
        ``value_net_fun`` with ``policy_optimizer_fun`` /
        ``value_optimizer_fun``) otherwise.

    Each epoch collects ``batch_size`` trajectories, pads them, computes the
    current losses and then runs ``num_optimizer_steps`` optimizer steps (per
    network in the separate-network mode).

    Args:
      env: Environment to train on; created with ``gym.make(env_name)`` when
        None. Its action space must be ``gym.spaces.Discrete``.
      env_name: Gym environment id used when ``env`` is None.
      epochs: Number of collect-then-optimize epochs.
      policy_net_fun, value_net_fun, policy_and_value_net_fun: Called as
        ``fun(rng_key, batch_observations_shape, num_actions)``; the result is
        unpacked as ``(params, apply_fn)``.
      policy_optimizer_fun, value_optimizer_fun,
      policy_and_value_optimizer_fun: Called with the initial params; the
        result is unpacked as ``(opt_state, opt_update)``.
      batch_size: Number of trajectories collected per epoch.
      num_optimizer_steps: Optimizer steps per epoch (per network).
      print_every_optimizer_steps: Losses are recomputed and logged every this
        many steps, and on the last step.
      boundary: Passed to ``pad_trajectories``.
      max_timestep: Passed to ``collect_trajectories``.
      random_seed: Seed for the JAX RNG.
      gamma, lambda_: Passed through to the loss / optimizer-step functions.
      epsilon: Initial epsilon passed to the losses, linearly annealed to 0.0
        over the epochs (presumably the PPO clipping parameter).
      c1, c2: Passed to ``combined_loss`` / ``policy_and_value_opt_step``
        (combined-network mode only).

    Returns:
      ``((policy_net_params, value_net_params), average_rewards,
      np.stack(value_losses), np.stack(ppo_objective))``.
      NOTE(review): in the combined-network mode the first element is
      ``(None, None)``; the trained combined params are only logged.
      NOTE(review): with ``num_optimizer_steps == 0`` the post-optimization
      logging reads loss variables that were never assigned (NameError).
    """
    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    # Per-epoch histories.
    value_losses = []
    ppo_objective = []
    combined_losses = []
    average_rewards = []

    env = env if env is not None else gym.make(env_name)

    # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
    # policy and value networks on shape [B, T] +_OBS
    batch_observations_shape = (-1, -1) + env.observation_space.shape

    assert isinstance(env.action_space, gym.spaces.Discrete)
    num_actions = env.action_space.n

    # Only one of the two parameter sets (combined vs. separate) is populated;
    # the other stays None, which is what the branches below test for.
    policy_and_value_net_params, policy_and_value_net_apply = None, None
    policy_and_value_opt_state, policy_and_value_opt_update = None, None
    policy_net_params, policy_net_apply = None, None
    value_net_params, value_net_apply = None, None
    if policy_and_value_net_fun is not None:
        jax_rng_key, subkey = jax_random.split(jax_rng_key)

        # Initialize the policy and value network.
        policy_and_value_net_params, policy_and_value_net_apply = (
            policy_and_value_net_fun(subkey, batch_observations_shape, num_actions))

        # Initialize the optimizers.
        policy_and_value_opt_state, policy_and_value_opt_update = (
            policy_and_value_optimizer_fun(policy_and_value_net_params))
    else:
        # Initialize the policy and value functions.
        assert policy_net_fun and value_net_fun
        jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

        policy_net_params, policy_net_apply = policy_net_fun(
            key1, batch_observations_shape, num_actions)
        value_net_params, value_net_apply = value_net_fun(
            key2, batch_observations_shape, num_actions)

        # Initialize the optimizers.
        ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
        value_opt_state, value_opt_update = value_optimizer_fun(
            value_net_params)

    # A function that will call the appropriate policy function with parameters.
    # It reads the params through the closure, so each call sees whatever
    # params the enclosing loop has most recently assigned.
    def get_policy_output(observations):
        if policy_net_apply is not None:
            assert policy_net_params
            return policy_net_apply(observations, policy_net_params)

        assert policy_and_value_net_apply and policy_and_value_net_params
        policy_predictions, unused_value_predictions = policy_and_value_net_apply(
            observations, policy_and_value_net_params)
        return policy_predictions

    for i in range(epochs):
        t = time.time()
        t0 = t
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        trajs = collect_trajectories(
            env,
            policy_fun=get_policy_output,
            num_trajectories=batch_size,
            policy=POLICY,
            max_timestep=max_timestep,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

        # Reward statistics over trajectories (traj[2] holds the rewards, as
        # summed here; traj[0] is used for trajectory length below).
        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)
        average_rewards.append(avg_reward)

        logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                     avg_reward, max_reward, min_reward)
        logging.vlog(1, "Collecting trajectories took %0.2f msec.", get_time(t))
        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))

        t = time.time()
        (_, reward_mask, padded_observations, padded_actions,
         padded_rewards) = pad_trajectories(trajs, boundary=boundary)

        logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

        # Some assertions: observations have one more time step (T + 1) than
        # actions/rewards/mask (T).
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert (B, T + 1) + env.observation_space.shape == padded_observations.shape

        # Linear annealing from `epsilon` (first epoch) to 0.0 (last epoch);
        # a single-epoch run keeps `epsilon` unchanged.
        epsilon_schedule = epsilon if epochs == 1 else epsilon * (
            1.0 - (i / (epochs - 1)))

        # Compute value and ppo losses.
        cur_value_loss, cur_ppo_loss, cur_combined_loss = None, None, None
        if policy_and_value_net_apply is not None:
            t = time.time()
            # New and old params are the same here: this is the pre-update loss.
            cur_combined_loss, cur_ppo_loss, cur_value_loss, _ = (
                combined_loss(policy_and_value_net_params,
                              policy_and_value_net_params,
                              policy_and_value_net_apply,
                              padded_observations,
                              padded_actions,
                              padded_rewards,
                              reward_mask,
                              gamma=gamma,
                              lambda_=lambda_,
                              epsilon=epsilon_schedule,
                              c1=c1,
                              c2=c2))
            logging.vlog(
                1, "Calculating P&V loss [%10.2f(%10.2f, %10.2f)] took %0.2f msec.",
                cur_combined_loss, cur_value_loss, cur_ppo_loss, get_time(t))
        else:
            t = time.time()
            cur_value_loss = value_loss(value_net_apply,
                                        value_net_params,
                                        padded_observations,
                                        padded_rewards,
                                        reward_mask,
                                        gamma=gamma)

            logging.vlog(1, "Calculating value loss took %0.2f msec.", get_time(t))

            t = time.time()
            cur_ppo_loss = ppo_loss(policy_net_apply,
                                    policy_net_params,
                                    policy_net_params,
                                    value_net_apply,
                                    value_net_params,
                                    padded_observations,
                                    padded_actions,
                                    padded_rewards,
                                    reward_mask,
                                    gamma=gamma,
                                    lambda_=lambda_,
                                    epsilon=epsilon_schedule)
            logging.vlog(1, "Calculating PPO loss took %0.2f msec.", get_time(t))

        value_losses.append(cur_value_loss)
        # The objective is recorded as the negated PPO loss.
        ppo_objective.append(-1.0 * cur_ppo_loss)
        combined_losses.append(cur_combined_loss)

        if policy_and_value_net_apply:
            logging.vlog(1, "Policy and Value Optimization")
            t1 = time.time()
            for j in range(num_optimizer_steps):
                t = time.time()
                # Update the optimizer state.
                policy_and_value_opt_state = policy_and_value_opt_step(
                    j,
                    policy_and_value_opt_state,
                    policy_and_value_opt_update,
                    policy_and_value_net_apply,
                    policy_and_value_net_params,
                    padded_observations,
                    padded_actions,
                    padded_rewards,
                    reward_mask,
                    c1=c1,
                    c2=c2,
                    gamma=gamma,
                    lambda_=lambda_,
                    epsilon=epsilon_schedule)
                t2 = time.time()
                # Get the new params.
                new_policy_and_value_net_params = trax_opt.get_params(
                    policy_and_value_opt_state)
                if ((j + 1) % print_every_optimizer_steps == 0) or (
                        j == num_optimizer_steps - 1):
                    # Compute and log the loss.
                    (loss_combined, loss_ppo, loss_value, unused_entropy_bonus) = (
                        combined_loss(
                            new_policy_and_value_net_params,
                            policy_and_value_net_params,  # old params
                            policy_and_value_net_apply,
                            padded_observations,
                            padded_actions,
                            padded_rewards,
                            reward_mask,
                            gamma=gamma,
                            lambda_=lambda_,
                            epsilon=epsilon_schedule,
                            c1=c1,
                            c2=c2))
                    logging.vlog(
                        1, "One Policy and Value grad desc took: %0.2f msec",
                        get_time(t, t2))
                    logging.vlog(
                        1, "Combined Loss(value, ppo) [%10.2f] -> [%10.2f(%10.2f,%10.2f)]",
                        cur_combined_loss, loss_combined, loss_value, loss_ppo)
                # Update the params.
                policy_and_value_net_params = new_policy_and_value_net_params

            logging.vlog(1, "Total PPO loss reduction [%0.2f]%%",
                         (100 * (cur_combined_loss - loss_combined) /
                          np.abs(cur_combined_loss)))

            logging.info(
                "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], Combined"
                " Loss(value, ppo) [%10.2f(%10.2f,%10.2f)], took [%10.2f msec]",
                i, min_reward, max_reward, avg_reward, loss_combined, loss_value,
                loss_ppo, get_time(t1))
        else:
            # Run optimizers: first the policy (PPO) steps, then the value steps.
            logging.vlog(1, "PPO Optimization")
            t1 = time.time()

            for j in range(num_optimizer_steps):
                t = time.time()
                # Update the optimizer state.
                ppo_opt_state = ppo_opt_step(
                    j,
                    ppo_opt_state,
                    ppo_opt_update,
                    policy_net_apply,
                    policy_net_params,
                    value_net_apply,
                    value_net_params,
                    padded_observations,
                    padded_actions,
                    padded_rewards,
                    reward_mask,
                    gamma=gamma,
                    lambda_=lambda_,
                    epsilon=epsilon_schedule,
                )
                t2 = time.time()
                # Get the new params.
                new_policy_net_params = trax_opt.get_params(ppo_opt_state)
                if ((j + 1) % print_every_optimizer_steps == 0) or (
                        j == num_optimizer_steps - 1):
                    new_ppo_loss = ppo_loss(
                        policy_net_apply,
                        new_policy_net_params,
                        policy_net_params,
                        value_net_apply,
                        value_net_params,
                        padded_observations,
                        padded_actions,
                        padded_rewards,
                        reward_mask,
                        gamma=gamma,
                        lambda_=lambda_,
                        epsilon=epsilon_schedule,
                    )
                    logging.vlog(1, "One PPO grad desc took: %0.2f msec",
                                 get_time(t, t2))
                    logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                                 new_ppo_loss)
                # Update the params.
                policy_net_params = new_policy_net_params

            logging.vlog(
                1, "Total PPO loss reduction [%0.2f]%%",
                (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

            logging.vlog(1, "Value Optimization")

            for j in range(num_optimizer_steps):
                t = time.time()
                value_opt_state = value_opt_step(j,
                                                 value_opt_state,
                                                 value_opt_update,
                                                 value_net_apply,
                                                 padded_observations,
                                                 padded_rewards,
                                                 reward_mask,
                                                 gamma=gamma)
                t2 = time.time()
                value_net_params = trax_opt.get_params(value_opt_state)
                if ((j + 1) % print_every_optimizer_steps == 0) or (
                        j == num_optimizer_steps - 1):
                    new_value_loss = value_loss(value_net_apply,
                                                value_net_params,
                                                padded_observations,
                                                padded_rewards,
                                                reward_mask,
                                                gamma=gamma)
                    logging.vlog(1, "One value grad desc took: %0.2f msec",
                                 get_time(t, t2))
                    logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]", cur_value_loss,
                                 new_value_loss)
            logging.vlog(
                1, "Total value loss reduction [%0.2f]%%",
                (100 * (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

            logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

            # Set the optimized params to new params.
            policy_net_params = trax_opt.get_params(ppo_opt_state)
            value_net_params = trax_opt.get_params(value_opt_state)

            # Unlike the combined branch, this timing covers the whole epoch
            # (from t0), including trajectory collection.
            logging.info(
                "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
                "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]",
                i, min_reward, max_reward, avg_reward, new_ppo_loss, new_value_loss,
                get_time(t0))

    # Log the parameters, just for the sake of it.
    if policy_net_params:
        log_params(policy_net_params, "policy_net_params")
    if value_net_params:
        log_params(value_net_params, "value_net_params")
    if policy_and_value_net_params:
        log_params(policy_and_value_net_params, "policy_and_value_net_params")

    if value_losses:
        logging.vlog(1, "value_losses: %s", np.stack(value_losses))
    if ppo_objective:
        logging.vlog(1, "ppo_objective: %s", np.stack(ppo_objective))
    if average_rewards:
        logging.vlog(1, "average_rewards: %s", average_rewards)

    return ((policy_net_params, value_net_params), average_rewards,
            np.stack(value_losses), np.stack(ppo_objective))
def prune(model,
          pruning_rate,
          saliency_fn=weight_magnitude,
          mask=None,
          compare_fn=jnp.greater):
    """Returns a mask for a model where the params in each layer are pruned using a saliency function.

    Args:
      model: The model to create a pruning mask for.
      pruning_rate: The fraction of lowest magnitude saliency weights that are
        pruned. If a float, the same rate is used for all layers, otherwise if
        it is a mapping, it must contain a rate for all masked layers in the
        model.
      saliency_fn: A function that returns a float number used to rank the
        importance of individual weights in the layer.
      mask: If the model has an existing mask, the mask will be applied before
        pruning the model. A provided mask is updated in place and returned.
      compare_fn: A pairwise operator to compare saliency with threshold, and
        return True if the saliency indicates the value should not be masked.

    Returns:
      A pruned mask for the given model.
    """
    # `collections.Mapping` was removed in Python 3.10; the ABC lives in
    # `collections.abc`.
    import collections.abc

    if not mask:
        mask = masked.simple_mask(model, jnp.ones, masked.WEIGHT_PARAM_NAMES)

    if not isinstance(pruning_rate, collections.abc.Mapping):
        # Broadcast a scalar rate to every masked layer. The layer name is the
        # second-to-last component of the parameter's full name/path.
        pruning_rate = {
            param_name.split('/')[-2]: pruning_rate
            for param_name, _ in masked.iterate_mask(mask)
        }

    for param_path, param_mask in masked.iterate_mask(mask):
        split_param_path = param_path.split('/')
        layer_name = split_param_path[-2]
        param_name = split_param_path[-1]

        # If we don't have a pruning rate for the given layer, don't mask it.
        if layer_name in pruning_rate and mask[layer_name][param_name] is not None:
            param_value = model.params[layer_name][
                masked.MaskedModule.UNMASKED][param_name]

            # Here any existing mask is first applied to weight matrix.
            # Note: need to check explicitly is not None for np array.
            if param_mask is not None:
                saliencies = saliency_fn(param_mask * param_value)
            else:
                saliencies = saliency_fn(param_value)

            # TODO: Use partition here (partial sort) instead of sort,
            # since it's O(N), not O(N log N), however JAX doesn't support it.
            sorted_param = jnp.sort(jnp.abs(saliencies.flatten()))

            # Figure out the weight magnitude threshold. For a rate of 1.0 the
            # index equals the size; JAX clamps out-of-bounds gathers, so the
            # largest saliency is used and every weight is masked.
            threshold_index = jnp.round(
                pruning_rate[layer_name] * sorted_param.size).astype(jnp.int32)
            threshold = sorted_param[threshold_index]

            mask[layer_name][param_name] = jnp.array(
                compare_fn(saliencies, threshold), dtype=jnp.int32)

    return mask
def run(self):
    """Simulate beam propagation with thermal blooming and turbulence.

    For each of ``self.n_iter`` iterations the input field is propagated
    through ``len(self.z) - 1`` pairs of half-steps of ``func.fft_BPM``.
    Between the half-steps a phase is applied: a thermal term derived from the
    shifted ``self.l2`` slices plus the average of the random turbulence
    screens, minus its reconstruction from ``func.dm()`` (presumably a
    deformable-mirror model — confirm). Progress (0-100) is emitted through
    ``self.progress``.

    The final intensity of each iteration is normalised to ``self.P0`` and
    accumulated. The result is emitted as
    ``self.result.emit(target2_turbs, target2_holes)``:

    * ``target2_turbs``: the accumulated intensity scaled to 0-255 (uint8
      values).
    * ``target2_holes``: 255 where the accumulated intensity exceeds the
      required power ``self.material / self.dwell_time``.
    """
    # Energy for the destruction: required power calculation.
    reqPower = self.material / self.dwell_time

    # Time averaging with thermal blooming and turbulence.
    # Assume that turbulence changes randomly over time or is carried by the wind.
    phz_turbs = np.zeros((self.N, self.N, len(self.z), self.n_iter))
    l2_target2_turbs = np.zeros((self.N, self.N, self.n_iter))

    n_step = 51
    count = 0
    # NOTE(review): the 512 x 512 x 51 layout is hard-coded; it must match
    # self.l2 and requires len(self.z) <= n_step. Zero-initialised (np.ndarray
    # would leave uninitialised memory in unused slices).
    l2_new = np.zeros((512, 512, 51))
    delta_n = np.zeros((512, 512, 51))
    for i in range(len(self.z)):
        # Shift slice i by round(move_pixel / n_step) * i positions. `int` is
        # used because the `np.int` alias was removed in NumPy 1.24.
        # NOTE(review): jnp.roll without `axis` rolls the flattened array.
        shift = int(jnp.round(self.move_pixel / n_step) * i)
        l2_new[:, :, i] = jnp.roll(self.l2[:, :, i], shift)
        n_temp = -self.mu / self.v_wind * l2_new[:, :, i]
        delta_n[:, :, i] = n_temp

    infl_t, invs = func.dm()
    target2_turbs = np.zeros((self.N, self.N, self.n_iter))
    target2_holes = np.zeros((self.N, self.N, self.n_iter))
    # Loop-invariant grid extent passed to every fft_BPM call.
    x1_max = max(self.x1[:])
    for j in range(self.n_iter):
        Uin = (self.u1[:, :, 0]
               * jnp.exp(complex("0+j") * self.k / (2 * self.Z) * self.r ** 2)
               * func.make_nanmask(self.r / (self.D0 / 2.67), 1))
        Uout_turb = Uin
        for i in range(1, len(self.z)):
            # First half-step of propagation.
            Uout_turb = func.fft_BPM(Uout_turb, self.P, self.wvl, x1_max, x1_max,
                                     self.delta_z / 2, self.d1, self.delta_z / 2,
                                     self.n0)
            # Atmosphere: a fresh random turbulence screen for (step i, iter j).
            phz_turb = func.ft_phase_screen(self.r0sw, self.N, self.d1, 100, 0.01)
            phz_turbs[:, :, i, j] = phz_turb
            # Thermal term: trapezoid integral of delta_n over slices i-1..i,
            # normalised by its minimum.
            # NOTE(review): np.trapz is deprecated in NumPy 2.0 (np.trapezoid).
            delta_n2 = self.k * np.trapz(delta_n[:, :, i - 1:i + 1], [i - 1, i], axis=2)
            delta_n2 = delta_n2 / jnp.min(delta_n2)
            # Sum over all z slices of this iteration (slices not yet filled
            # are still zero) divided by len(self.z).
            phz_total = delta_n2 + jnp.sum(phz_turbs[:, :, :, j], axis=2) / len(self.z)
            # Project onto the correction basis and keep the residual phase.
            cmd = jnp.dot(invs, phz_total.reshape(512 ** 2, 1))
            recon_phz = jnp.dot(infl_t, cmd)
            resi = phz_total - recon_phz.reshape(512, 512)
            Uout_turb = Uout_turb * jnp.exp(complex("j") * resi)
            # Second half-step of propagation.
            Uout_turb = func.fft_BPM(Uout_turb, self.P, self.wvl, x1_max, x1_max,
                                     self.delta_z / 2, self.d1, self.delta_z / 2,
                                     self.n0)
            count += 100 / (self.n_iter * (len(self.z) - 1))
            self.progress.emit(int(count))
        # Intensity at the target, normalised so its integral equals self.P0.
        l2_target2_turb = jnp.abs(jnp.power(Uout_turb, 2))
        A0 = np.trapz(np.trapz(l2_target2_turb, self.x1, axis=1), self.y1)
        l2_target2_turb = l2_target2_turb * (1 / A0) * self.P0
        l2_target2_turbs[:, :, j] = l2_target2_turb * self.delta_t
        # Accumulated intensity over all iterations so far.
        l3 = jnp.sum(l2_target2_turbs, axis=2)
        img = np.uint8(l3 * 255 / jnp.max(l3))
        target2_turbs[:, :, j] = img
        img2 = np.zeros(shape=l3.shape, dtype=np.uint8)
        img2[np.where(l3 > reqPower)] = 255
        target2_holes[:, :, j] = img2

    self.result.emit(target2_turbs, target2_holes)
def __call__(self, x):
    """Return the element-wise absolute value of ``x``; ``self`` is unused."""
    magnitude = np.abs(x)
    return magnitude
def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
                c2=0.9, maxiter=20):
    """Inexact line search that satisfies strong Wolfe conditions.

    Algorithm 3.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 59-61

    Args:
      f: function of the form f(x) where x is a flat ndarray and returns a real
        scalar. The function should be composed of operations with vjp defined.
      xk: point the search starts from.
      pk: direction to search in. Assumes the direction is a descent direction.
      old_fval, gfk: function value and gradient at xk. If either is None, both
        are recomputed at xk (one function and one gradient evaluation).
      old_old_fval: function value at the previous iterate. When given, it is
        used (as in scipy) to pick the initial step size, capped at 1.
      maxiter: maximum number of iterations to search
      c1, c2: Wolfe criteria constant, see ref.

    Returns: LineSearchResults
    """
    # Decide this before `gfk` is rebound below. Checking `gfk is None` after
    # the rebind under-counted nfev/ngev when only gfk was missing.
    evaluate_at_zero = old_fval is None or gfk is None

    def restricted_func_and_grad(t):
        # Value, directional derivative along pk, and full gradient at xk + t*pk.
        phi, g = jax.value_and_grad(f)(xk + t * pk)
        dphi = jnp.dot(g, pk)
        return phi, dphi, g

    if evaluate_at_zero:
        phi_0, dphi_0, gfk = restricted_func_and_grad(0.)
    else:
        phi_0 = old_fval
        dphi_0 = jnp.dot(gfk, pk)
    if old_old_fval is not None:
        candidate_start_value = 1.01 * 2 * (phi_0 - old_old_fval) / dphi_0
        start_value = jnp.where(candidate_start_value > 1, 1.0, candidate_start_value)
    else:
        start_value = 1

    def wolfe_one(a_i, phi_i):
        # actually negation of W1
        return phi_i > phi_0 + c1 * a_i * dphi_0

    def wolfe_two(dphi_i):
        return jnp.abs(dphi_i) <= -c2 * dphi_0

    initial_evals = 1 if evaluate_at_zero else 0
    state = _LineSearchState(
        done=False,
        failed=False,
        # algorithm begins at 1 as per Wright and Nocedal, however Scipy has a
        # bug and starts at 0. See https://github.com/scipy/scipy/issues/12157
        i=1,
        a_i1=0.,
        phi_i1=phi_0,
        dphi_i1=dphi_0,
        nfev=initial_evals,
        ngev=initial_evals,
        a_star=0.,
        phi_star=phi_0,
        dphi_star=dphi_0,
        g_star=gfk,
    )

    def body(state):
        # no amax in this version, we just double as in scipy.
        # unlike original algorithm we do our next choice at the start of this loop
        a_i = jnp.where(state.i == 1, start_value, state.a_i1 * 2.)

        phi_i, dphi_i, g_i = restricted_func_and_grad(a_i)
        state = state._replace(nfev=state.nfev + 1, ngev=state.ngev + 1)

        star_to_zoom1 = wolfe_one(a_i, phi_i) | ((phi_i >= state.phi_i1) & (state.i > 1))
        star_to_i = wolfe_two(dphi_i) & (~star_to_zoom1)
        star_to_zoom2 = (dphi_i >= 0.) & (~star_to_zoom1) & (~star_to_i)

        # Both zooms are always traced; the last argument tells _zoom whether
        # its result will actually be used.
        zoom1 = _zoom(restricted_func_and_grad, wolfe_one, wolfe_two,
                      state.a_i1, state.phi_i1, state.dphi_i1,
                      a_i, phi_i, dphi_i, gfk, ~star_to_zoom1)
        state = state._replace(nfev=state.nfev + zoom1.nfev, ngev=state.ngev + zoom1.ngev)

        zoom2 = _zoom(restricted_func_and_grad, wolfe_one, wolfe_two,
                      a_i, phi_i, dphi_i,
                      state.a_i1, state.phi_i1, state.dphi_i1, gfk, ~star_to_zoom2)
        state = state._replace(nfev=state.nfev + zoom2.nfev, ngev=state.ngev + zoom2.ngev)

        state = state._replace(
            done=star_to_zoom1 | state.done,
            failed=(star_to_zoom1 & zoom1.failed) | state.failed,
            **_binary_replace(
                star_to_zoom1,
                state._asdict(),
                zoom1._asdict(),
                keys=['a_star', 'phi_star', 'dphi_star', 'g_star'],
            ),
        )
        state = state._replace(
            done=star_to_i | state.done,
            **_binary_replace(
                star_to_i,
                state._asdict(),
                dict(
                    a_star=a_i,
                    phi_star=phi_i,
                    dphi_star=dphi_i,
                    g_star=g_i,
                ),
            ),
        )
        state = state._replace(
            done=star_to_zoom2 | state.done,
            failed=(star_to_zoom2 & zoom2.failed) | state.failed,
            **_binary_replace(
                star_to_zoom2,
                state._asdict(),
                zoom2._asdict(),
                keys=['a_star', 'phi_star', 'dphi_star', 'g_star'],
            ),
        )
        state = state._replace(i=state.i + 1, a_i1=a_i, phi_i1=phi_i, dphi_i1=dphi_i)
        return state

    state = while_loop(
        lambda state: (~state.done) & (state.i <= maxiter) & (~state.failed),
        body,
        state)

    status = jnp.where(
        state.failed,
        jnp.array(1),  # zoom failed
        jnp.where(
            state.i > maxiter,
            jnp.array(3),  # maxiter reached
            jnp.array(0),  # passed (should be)
        ),
    )
    # Step sizes which are too small causes the optimizer to get stuck with a
    # direction of zero in <64 bit mode - avoid with a floor on minimum step size.
    alpha_k = state.a_star
    alpha_k = jnp.where((jnp.finfo(alpha_k).bits != 64)
                        & (jnp.abs(alpha_k) < 1e-8),
                        jnp.sign(alpha_k) * 1e-8,
                        alpha_k)
    results = _LineSearchResults(
        failed=state.failed | (~state.done),
        nit=state.i - 1,  # because iterations started at 1
        nfev=state.nfev,
        ngev=state.ngev,
        k=state.i,
        a_k=alpha_k,
        f_k=state.phi_star,
        g_k=state.g_star,
        status=status,
    )
    return results
def log_abs_det_jacobian(self, x, y, intermediates=None):
    """Log-abs-det-Jacobian contribution that depends only on ``self.scale``.

    log|self.scale| is broadcast to the shape of ``x`` and handed to
    ``sum_rightmost`` with ``self.event_dim``. Presumably that sums the
    trailing event axes -- confirm. ``y`` and ``intermediates`` are unused.
    """
    log_abs_scale = np.log(np.abs(self.scale))
    per_element = np.broadcast_to(log_abs_scale, np.shape(x))
    return sum_rightmost(per_element, self.event_dim)
def calculate_td_error(q_value_vec, target_q_value_vec, action, reward, discount=None):
    """Return the absolute one-step TD error for the taken action.

    td_target = reward + discount * max(target_q_value_vec)
    result    = |td_target - q_value_vec[action]|

    Args:
        q_value_vec: Q-values of the current state, indexed by action.
        target_q_value_vec: Q-values of the next state; only their max is used.
        action: index into ``q_value_vec``.
        reward: immediate reward.
        discount: discount factor. When None (the default), the module-level
            ``gamma`` is used, which is what the function always did before,
            so existing call sites are unaffected.
    """
    if discount is None:
        # Looked up only here, so explicit callers need no global `gamma`.
        discount = gamma
    td_target = reward + discount * jnp.amax(target_q_value_vec)
    td_error = td_target - q_value_vec[action]
    return jnp.abs(td_error)
def log_abs_det_jacobian(self, x, y, intermediates=None):
    """Return log|self.exponent * y / x| elementwise.

    ``intermediates`` is unused. The product is formed before the division,
    matching the left-to-right evaluation of ``self.exponent * y / x``.
    """
    scaled_y = self.exponent * y
    return np.log(np.abs(scaled_y / x))
def matrix_inverse_pth_root(mat_g,
                            p,
                            iter_count=100,
                            error_tolerance=1e-6,
                            ridge_epsilon=1e-6):
    """Computes mat_g^(-1/p), where p is a positive integer.

    Coupled newton iterations for matrix inverse pth root.

    Args:
      mat_g: the symmetric PSD matrix whose power is to be computed
      p: exponent, for p a positive integer. The iteration uses mat_power,
        which only handles p in {1, 2, 4, 8}.
      iter_count: Maximum number of iterations.
      error_tolerance: Error indicator, useful for early termination.
      ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
        It is rescaled below by max(max_ev, 1e-16), where max_ev comes from
        power_iter(mat_g).

    Returns:
      A pair (mat_g^(-1/p) cast to mat_g.dtype, error). error is
      max|mat_m - I| at the final iterate (0 for a 1x1 input).
    """
    mat_g_size = mat_g.shape[0]
    alpha = jnp.asarray(-1.0 / p, _INVERSE_PTH_ROOT_DATA_TYPE)
    identity = jnp.eye(mat_g_size, dtype=_INVERSE_PTH_ROOT_DATA_TYPE)
    # Scale the ridge relative to max_ev (presumably the dominant eigenvalue
    # estimate from power iteration); 1e-16 guards against a zero estimate.
    _, max_ev, _ = power_iter(mat_g)
    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)

    def _unrolled_mat_pow_1(mat_m):
        """Computes mat_m^1."""
        return mat_m

    def _unrolled_mat_pow_2(mat_m):
        """Computes mat_m^2."""
        return jnp.matmul(mat_m, mat_m, precision=_INVERSE_PTH_ROOT_PRECISION)

    def _unrolled_mat_pow_4(mat_m):
        """Computes mat_m^4."""
        mat_pow_2 = _unrolled_mat_pow_2(mat_m)
        return jnp.matmul(
            mat_pow_2, mat_pow_2, precision=_INVERSE_PTH_ROOT_PRECISION)

    def _unrolled_mat_pow_8(mat_m):
        """Computes mat_m^8."""
        mat_pow_4 = _unrolled_mat_pow_4(mat_m)
        return jnp.matmul(
            mat_pow_4, mat_pow_4, precision=_INVERSE_PTH_ROOT_PRECISION)

    def mat_power(mat_m, p):
        """Computes mat_m^p, for p == 1, 2, 4 or 8.

        Args:
          mat_m: a square matrix
          p: a positive integer

        Returns:
          mat_m^p
        """
        # We unrolled the loop for performance reasons.
        # round(log2(p)) maps p = 1, 2, 4, 8 to branch index 0, 1, 2, 3.
        exponent = jnp.round(jnp.log2(p))
        return lax.switch(
            jnp.asarray(exponent, jnp.int32), [
                _unrolled_mat_pow_1,
                _unrolled_mat_pow_2,
                _unrolled_mat_pow_4,
                _unrolled_mat_pow_8,
            ], (mat_m))

    def _iter_condition(state):
        # Continue while under iter_count, the error is above tolerance, and the
        # last step did not blow the error up (run_step from _iter_body).
        (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
         run_step) = state
        error_above_threshold = jnp.logical_and(
            error > error_tolerance, run_step)
        return jnp.logical_and(i < iter_count, error_above_threshold)

    def _iter_body(state):
        # One coupled Newton step. mat_m is driven towards I and mat_h
        # accumulates the product of the mat_m_i factors.
        (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
        mat_m_i = (1 - alpha) * identity + alpha * mat_m
        new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m,
                               precision=_INVERSE_PTH_ROOT_PRECISION)
        new_mat_h = jnp.matmul(mat_h, mat_m_i,
                               precision=_INVERSE_PTH_ROOT_PRECISION)
        new_error = jnp.max(jnp.abs(new_mat_m - identity))
        # sometimes error increases after an iteration before decreasing and
        # converging. 1.2 factor is used to bound the maximal allowed increase.
        return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
                new_error < error * 1.2)

    if mat_g_size == 1:
        # Scalar case: apply the power directly.
        resultant_mat_h = (mat_g + ridge_epsilon)**alpha
        error = 0
    else:
        damped_mat_g = mat_g + ridge_epsilon * identity
        # Initial scaling z; the matching z^(1/p) factor is folded into mat_h.
        z = (1 + p) / (2 * jnp.linalg.norm(damped_mat_g))
        new_mat_m_0 = damped_mat_g * z
        new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
        new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
        init_state = tuple(
            [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
        _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
            _iter_condition, _iter_body, init_state)
        error = jnp.max(jnp.abs(mat_m - identity))
        # If the final step was rejected (error grew by more than 1.2x), fall
        # back to the previous iterate old_mat_h.
        is_converged = jnp.asarray(convergence, old_mat_h.dtype)
        resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
        resultant_mat_h = jnp.asarray(resultant_mat_h, mat_g.dtype)
    return resultant_mat_h, error
def log_abs_det_jacobian(self, x, y, intermediates=None):
    """Return -|x| - 2*log1p(exp(-|x|)) elementwise.

    Algebraically this equals log(s * (1 - s)) with s = 1 / (1 + exp(-x)).
    Written in terms of -|x| so that exp only ever sees non-positive values.
    ``y`` and ``intermediates`` are unused.
    """
    neg_abs_x = -np.abs(x)
    return neg_abs_x - 2 * np.log1p(np.exp(neg_abs_x))
def keep_smallest_abs(xx1, xx2):
    """Elementwise, return whichever of ``xx1``/``xx2`` has the smaller magnitude.

    Ties (|xx1| == |xx2|) resolve to ``xx2``. Uses ``jn.where`` instead of the
    former arithmetic blend ``xx1 * w + xx2 * (1 - w)``, which turned an
    infinite value in the *unselected* operand into NaN (inf * 0). The result
    dtype follows ``jn.where``'s promotion of the two inputs.
    """
    pick_first = jn.abs(xx1) < jn.abs(xx2)
    return jn.where(pick_first, xx1, xx2)
def absolute_reward_diff(r1, r2):
    """Return the elementwise magnitude of ``r1 - r2`` as a JAX array."""
    gap = r1 - r2
    return jnp.absolute(gap)
def assertArrayNotEqual(self, x, y, margin=None):
    """Assert that ``x`` and ``y`` differ by more than a relative margin.

    The elementwise relative difference |2 (x - y) / (x + y + 1e-16)| is
    reduced with max. The assertion passes only if that maximum is strictly
    greater than the margin.

    Args:
        x, y: arrays of broadcast-compatible shape.
        margin: threshold the maximum relative difference must exceed. When
            None (default), ``self.margin`` is used. Previously an explicit
            value was accepted but silently ignored in favour of
            ``self.margin``.

    Raises:
        AssertionError: if the maximum relative difference is <= margin.
    """
    if margin is None:
        margin = self.margin
    # 1e-16 keeps the denominator non-zero when x + y is exactly 0.
    reldiff = jnp.abs(2 * (x - y) / (x + y + 1e-16))
    maxdiff = float(jnp.max(reldiff))
    assert maxdiff > margin, (
        f"arrays are too close: max relative difference {maxdiff} <= margin {margin}")
def manhattan_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Return the L1 (Manhattan) distance ``sum(|x - y|)``.

    Args:
        x, y: array-likes of broadcast-compatible shape. Plain lists or tuples
            are accepted; both are converted with ``np.asarray`` (list minus
            list would otherwise raise TypeError).

    Returns:
        The summed absolute difference, a NumPy scalar.
    """
    # Annotations use np.ndarray: np.array is a factory function, not a type.
    diff = np.asarray(x) - np.asarray(y)
    return np.sum(np.abs(diff))
def sample(self, key, sample_shape=()):
    """Draw from ``self._normal`` and return the absolute values of the draws.

    ``key`` and ``sample_shape`` are forwarded positionally. Presumably this
    yields half-normal samples -- confirm against the class definition.
    """
    draws = self._normal.sample(key, sample_shape)
    return np.absolute(draws)
"wallclock": time.time() - start_time }) return get_params(opt_state) # See https://github.com/google/jax/issues/7809. binarize = lambda arr: tree_map(lambda x: x > 0.5, arr) print("Training normal model...") everything_mask = tree_map(lambda x: jnp.ones_like(x, dtype=jnp.dtype("bool")), init_params) final_params = train(init_params, everything_mask, "no_mask") # Mask as was implemented in the original paper print("Training lottery ticket model...") final_params_flat, unravel = ravel_pytree(final_params) cutoff = jnp.percentile(jnp.abs(final_params_flat), config.remove_percentile) mask = binarize(unravel(jnp.abs(final_params_flat) > cutoff)) train(init_params, mask, "lottery_mask") print("Training lottery ticket sign model...") # The lottery ticket mask but instead of using the initial weights, just use # the sign of the initial weights. mask = binarize(unravel(jnp.abs(final_params_flat) > cutoff)) w0 = tree_map(lambda x, m: 0.01 * jnp.sign(x) * m, final_params, mask) train(w0, mask, "lottery_sign_mask") # Totally random mask print("Training random mask model...") mask = binarize( unravel(random.uniform(rp.poop(), final_params_flat.shape) > config.remove_percentile / 100)) train(init_params, mask, "random_mask")
def dist_sq(R, box_size=1.0):
    """Pairwise squared distances under periodic boundaries (minimum image).

    Args:
        R: positions of shape (N, d).
        box_size: edge length of the periodic box. The default of 1.0 matches
            the 0.5 wrap threshold used previously.

    Returns:
        (N, N) array whose [i, j] entry is the squared minimum-image distance
        between R[i] and R[j].
    """
    dR = R[:, np.newaxis, :] - R[np.newaxis, :, :]
    # Wrap each displacement component into [-box_size/2, box_size/2]. The old
    # code shifted by 0.5 (half a box) instead of a full box length, so e.g.
    # dR = 0.9 became 0.4 rather than the minimum image -0.1. round() also
    # handles displacements larger than one box.
    dR = dR - box_size * np.round(dR / box_size)
    return np.sum(dR**2, axis=2)
def l2_loss(w: jnp.ndarray) -> jnp.ndarray:
    """Squared L2 norm of ``w``: the sum over all entries of |w[n]|^2.

    Taking ``jnp.abs`` first makes complex entries contribute their squared
    modulus.
    """
    magnitudes = jnp.abs(w)
    return jnp.sum(magnitudes ** 2)