def shuffle(key, Xa, Xb):
    '''Randomly shuffle examples in Xa and Xb along the zeroth axis.

    Args:
        key: random PRNGKey
        Xa: (P, N) first array to shuffle
        Xb: (P, N) second array to shuffle
    Returns:
        Xaperm: (P, N) shuffled copy of Xa
        Xbperm: (P, N) shuffled copy of Xb
    '''
    keya, keyb = random.split(key)
    perma = random.shuffle(keya, np.arange(len(Xa)))
    permb = random.shuffle(keyb, np.arange(len(Xb)))
    return Xa[perma], Xb[permb]
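# A minimal sketch (not from the original source): the same helper rewritten
# against jax.random.permutation, the non-deprecated replacement for
# jax.random.shuffle. It assumes the same imports as the snippet above
# (`from jax import random`, `import jax.numpy as np`); the function name is
# illustrative.
def shuffle_via_permutation(key, Xa, Xb):
    keya, keyb = random.split(key)
    # random.permutation(key, n) returns a shuffled arange(n)
    perma = random.permutation(keya, len(Xa))
    permb = random.permutation(keyb, len(Xb))
    return Xa[perma], Xb[permb]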
def init_fun(self, rng, input_shape, permutation=None):
    """
    :param rng: rng used to initialize parameters
    :param input_shape: input shape
    :param permutation: an optional permutation that is applied to the inputs
        and controls the order of the autoregressive factorization. In particular,
        for the identity permutation the autoregressive structure is such that
        the Jacobian is triangular. By default this is chosen at random.
    :type permutation: array of ints
    """
    if permutation is None:
        # By default set a random permutation of variables, which is
        # important for performance with multiple steps
        rng, rng_perm = random.split(rng)
        self.permutation = onp.array(random.shuffle(rng_perm, np.arange(self.input_dim)))
    else:
        self.permutation = permutation

    # Create masks (no skip connections allowed; TODO add support)
    masks, _ = create_mask(input_dim=self.input_dim, hidden_dims=self.hidden_dims,
                           permutation=self.permutation,
                           output_dim_multiplier=self.output_multiplier)

    # Create masked layers
    self.masked_layers = [MaskedDense(mask) for mask in masks]
    init_params = []
    for mask in self.masked_layers:
        mask_init = mask[0]
        input_shape, param = mask_init(rng, input_shape)
        init_params.append(param)
    return input_shape, init_params
def test_flows(flow_class, flow_args, input_dim, batch_shape):
    transform = flow_class(*flow_args)
    x = random.normal(random.PRNGKey(0), batch_shape + (input_dim,))

    # test inverse is correct
    y = transform(x)
    inv = transform.inv(y)
    assert_allclose(x, inv, atol=1e-5)

    # test jacobian shape
    actual = transform.log_abs_det_jacobian(x, y)
    assert onp.shape(actual) == batch_shape

    if batch_shape == ():
        # make sure transform.log_abs_det_jacobian is correct
        jac = jacfwd(transform)(x)
        expected = onp.linalg.slogdet(jac)[1]
        assert_allclose(actual, expected, atol=1e-5)

        # make sure jacobian is triangular; first permute jacobian as necessary
        if isinstance(transform, InverseAutoregressiveTransform):
            permuted_jac = onp.zeros(jac.shape)
            _, rng_key_perm = random.split(random.PRNGKey(0))
            perm = random.shuffle(rng_key_perm, onp.arange(input_dim))
            for j in range(input_dim):
                for k in range(input_dim):
                    permuted_jac[j, k] = jac[perm[j], perm[k]]
            assert onp.sum(onp.abs(onp.triu(permuted_jac, 1))) == 0.0
def shuffle(key_and_data):
    key, data = key_and_data
    key, subkey = random.split(key)
    datapoints_per_device = data[0].shape[0]
    indices = np.arange(datapoints_per_device)
    perm = random.shuffle(subkey, indices)
    return key, list(map(lambda x: x[perm], data)), 0
def _find_binning_thresholds(data, max_bins=256, subsample=200000,
                             random_state=None):
    if 2 > max_bins or max_bins > 256:
        raise ValueError(f'max_bins={max_bins} should be no smaller than 2 '
                         f'and no larger than 256.')
    if random_state is None:
        random_state = int(time.time())
    rng = random.PRNGKey(random_state)
    if subsample is not None and data.shape[0] > subsample:
        subset = random.shuffle(rng, np.arange(data.shape[0]))[:subsample]
        data = data[subset]
    dtype = data.dtype
    if dtype.kind != 'f':
        dtype = np.float32

    percentiles = np.linspace(0, 100, num=max_bins + 1)[1:-1]
    binning_thresholds = []
    for f_idx in range(data.shape[1]):
        col_data = np.array(data[:, f_idx], dtype=dtype, order='C')
        distinct_values = onp.unique(col_data)
        if len(distinct_values) <= max_bins:
            midpoints = (distinct_values[:-1] + distinct_values[1:])
            midpoints *= 0.5
        else:
            midpoints = np.percentile(col_data, percentiles,
                                      interpolation='midpoint').astype(dtype)
        binning_thresholds.append(midpoints)
    return tuple(binning_thresholds)
def run_epoch(rng, _opt_state, epoch_idx):
    _rng, dat_keys = utils.keygen(rng, 1)
    _rng, batch_keys = utils.keygen(_rng, num_batches)

    # Randomize epoch data.
    epoch_data = random.shuffle(next(dat_keys), X_train, axis=0)

    def update(batch_idx, __opt_state):
        """Update func for gradients, includes gradient clipping."""
        kl_warmup = kl_warmup_fun(epoch_idx * num_batches + batch_idx)
        batch_data = lax.dynamic_slice_in_dim(epoch_data, batch_idx * BATCH_SIZE,
                                              BATCH_SIZE, axis=0)
        batch_data = batch_data.astype(np.float32)
        params = get_params(__opt_state)
        grads = grad(loss_fn)(params, batch_data, next(batch_keys), BATCH_SIZE,
                              ic_prior, VAR_MIN, kl_warmup, L2_REG)
        clipped_grads = optimizers.clip_grads(grads, MAX_GRAD_NORM)
        return opt_update(batch_idx, clipped_grads, __opt_state)

    return lax.fori_loop(0, num_batches, update, _opt_state)
def test_permutation_invariance(self):
    num_nodes = 4
    num_features = 2
    rng = random.PRNGKey(0)

    # Generate random graph.
    adjacency = random.randint(rng, (num_nodes, num_nodes), 0, 2)
    node_feats = random.normal(rng, (num_nodes, num_features))
    sources, targets = jnp.where(adjacency)

    # Get permuted graph.
    perm = random.shuffle(rng, jnp.arange(num_nodes))
    node_feats_perm = node_feats[perm]
    adjacency_perm = adjacency[perm]
    for j in range(len(adjacency)):
        adjacency_perm = jax.ops.index_update(adjacency_perm, j,
                                              adjacency_perm[j][perm])
    sources_perm, targets_perm = jnp.where(adjacency_perm)

    # Create GNN.
    _, initial_params = GNN.init(rng, node_x=node_feats, edge_x=None,
                                 sources=sources, targets=targets)
    model = nn.Model(GNN, initial_params)

    # Feedforward both original and permuted graph.
    logits = model(node_feats, None, sources, targets)
    logits_perm = model(node_feats_perm, None, sources_perm, targets_perm)
    self.assertAllClose(logits[perm], logits_perm, check_dtypes=False)
def shuffle(key_and_data):
    key, data = key_and_data
    key, subkey = random.split(key)
    datapoints_per_device = data[0].shape[0]
    indices = np.arange(datapoints_per_device)
    perm = random.shuffle(subkey, indices)
    return key, [x[perm] for x in data], 0
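# Illustrative (hypothetical) driver for the helper above: the PRNG key is
# threaded through the returned tuple so that each call draws a fresh
# permutation. Assumes `from jax import random` and `import jax.numpy as np`.
key = random.PRNGKey(0)
data = [np.arange(8).reshape(4, 2), np.arange(4)]  # per-device arrays sharing axis 0
key, data, _ = shuffle((key, data))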
def get_data(functions, ranges, num_samples=50):
    import random
    random.seed(37)
    onp.random.seed(37)
    random.shuffle(functions)
    X = []
    Y = []
    for i, func in enumerate(functions):
        Xs = list(onp.random.uniform(ranges[i], ranges[i + 1], size=num_samples))
        Ys = list(func(Xs) + onp.random.normal(scale=0.3, size=len(Xs)))
        X.append(Xs)
        Y.append(Ys)
    X = np.array(X).reshape(-1, 1)
    Y = np.array(Y).reshape(-1, 1)
    return X, Y, X
def _make_iaf_args(input_dim, hidden_dims):
    _, rng_key_perm = random.split(random.PRNGKey(0))
    perm = random.shuffle(rng_key_perm, onp.arange(input_dim))
    arn_init, arn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1],
                                     permutation=perm)
    _, init_params = arn_init(random.PRNGKey(0), (input_dim,))
    return partial(arn, init_params),
def _make_iaf_args(input_dim, hidden_dims):
    _, rng_perm = random.split(random.PRNGKey(0))
    perm = random.shuffle(rng_perm, onp.arange(input_dim))
    # we use the Elu nonlinearity because the default one, Relu, masks out
    # negative hidden values, which in turn creates some zero entries in the
    # lower triangular part of the Jacobian.
    arn_init, arn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1],
                                     permutation=perm, nonlinearity=stax.Elu)
    _, init_params = arn_init(random.PRNGKey(0), (input_dim,))
    return partial(arn, init_params),
def testShuffle(self, dtype):
    key = random.PRNGKey(0)
    x = onp.arange(100).astype(dtype)
    rand = lambda key: random.shuffle(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertAllClose(perm1, perm2, check_dtypes=True)
    self.assertFalse(onp.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(onp.sort(perm1), x, check_dtypes=False)
def testShuffle(self, dtype):
    key = random.PRNGKey(0)
    x = onp.arange(100).astype(dtype)
    rand = lambda key: random.shuffle(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertTrue(onp.all(perm1 == perm2))
    self.assertTrue(onp.all(perm1.dtype == perm2.dtype))
    self.assertFalse(onp.all(perm1 == x))  # seems unlikely!
    self.assertTrue(onp.all(onp.sort(perm1) == x))
def testShuffle(self, dtype):
    key = random.PRNGKey(0)
    x = np.arange(100).astype(dtype)
    rand = lambda key: random.shuffle(key, x)
    crand = api.jit(rand)

    with self.assertWarns(FutureWarning):
        perm1 = rand(key)
    with self.assertWarns(FutureWarning):
        perm2 = crand(key)

    self.assertAllClose(perm1, perm2)
    self.assertFalse(np.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(np.sort(perm1), x, check_dtypes=False)
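# A hedged sketch (not part of the original test suite) of the equivalent check
# written against jax.random.permutation, which does not emit the FutureWarning;
# it mirrors the assertions of the test above.
def testPermutation(self, dtype):
    key = random.PRNGKey(0)
    x = np.arange(100).astype(dtype)
    rand = lambda key: random.permutation(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertAllClose(perm1, perm2)
    self.assertFalse(np.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(np.sort(perm1), x, check_dtypes=False)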
def minibatcher(data, batch_size, transform=None, seed=0):
    key = random.PRNGKey(seed)
    size = data.X.shape[0]
    indices = np.arange(size, dtype=np.int32)
    num_batches = size // batch_size
    while True:
        key, subkey = random.split(key)
        # use the fresh subkey for the permutation rather than reusing `key`
        perm = random.shuffle(subkey, indices)
        for i in range(num_batches):
            batch_ids = perm[i * batch_size:(i + 1) * batch_size]
            b = data._replace(X=data.X[batch_ids], Y=data.Y[batch_ids])
            if transform:
                key, subkey = random.split(key)
                b = transform(b, subkey)
            yield b
def shuffle(key, tensors, axis=0):
    """Shuffles the contents of tensors in unison.

    Args:
        key: Pseudo-random generator state.
        tensors: Iterable of tensors.
        axis: Optional, axis along which to shuffle (default 0).

    Returns:
        List of shuffled tensors.

    Raises:
        ValueError: If the shapes of the tensors do not match along `axis`.
    """
    a = mo.size(tensors, axis=axis)
    p = random.shuffle(key, np.arange(a))
    return [np.take(tsr, p, axis=axis) for tsr in tensors]
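# Hypothetical usage of the unison-shuffle helper above, assuming the host
# library's `mo.size` returns the shared length of the tensors along `axis`
# and that jax.numpy is imported as np.
key = random.PRNGKey(0)
xs = np.arange(12).reshape(6, 2)
ys = np.arange(6)
xs_shuf, ys_shuf = shuffle(key, [xs, ys])  # rows of xs and entries of ys move together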
def mini_batch(x_train, y_train, batch_size, train_epochs):
    # epoch = 0
    start = 0
    key = random.PRNGKey(0)
    while True:
        end = start + batch_size
        if end > x_train.shape[0]:
            key, split = random.split(key)
            permutation = random.shuffle(
                split, np.arange(x_train.shape[0], dtype=np.int64))
            x_train = x_train[permutation]
            y_train = y_train[permutation]
            # epoch += 1
            start = 0
            # print(epoch)
            continue
        yield x_train[start:end], y_train[start:end]
        start = start + batch_size
def test_auto_reg_nn(input_dim, hidden_dims, param_dims, skip_connections):
    arn_init, arn = AutoregressiveNN(input_dim, hidden_dims, param_dims=param_dims,
                                     skip_connections=skip_connections)
    rng = random.PRNGKey(0)
    batch_size = 4
    input_shape = (batch_size, input_dim)
    _, init_params = arn_init(rng, input_shape)
    output = arn(init_params, onp.random.rand(*input_shape))

    if param_dims == [1]:
        assert output.shape == (batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x))(onp.random.rand(input_dim))
    elif param_dims == [1, 1]:
        assert output[0].shape == (batch_size, input_dim)
        assert output[1].shape == (batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x)[0])(onp.random.rand(input_dim))
    elif param_dims == [2]:
        assert output.shape == (2, batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x))(onp.random.rand(input_dim))
    elif param_dims == [2, 3]:
        assert output[0].shape == (2, batch_size, input_dim)
        assert output[1].shape == (3, batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x)[0])(onp.random.rand(input_dim))

    # permute jacobian as necessary
    permuted_jac = onp.zeros(jac.shape)
    _, rng_perm = random.split(rng)
    perm = random.shuffle(rng_perm, onp.arange(input_dim))
    for j in range(input_dim):
        for k in range(input_dim):
            permuted_jac[..., j, k] = jac[..., perm[j], perm[k]]

    # make sure jacobians are triangular
    assert onp.sum(onp.abs(onp.triu(permuted_jac))) == 0.0
def minibatch(x_train, y_train, batch_size, train_epochs):
    """Generate minibatches of data for a set number of epochs."""
    epoch = 0
    start = 0
    key = random.PRNGKey(0)
    while epoch < train_epochs:
        end = start + batch_size
        if end > x_train.shape[0]:
            key, split = random.split(key)
            permutation = random.shuffle(
                split, np.arange(x_train.shape[0], dtype=np.int64))
            x_train = x_train[permutation]
            y_train = y_train[permutation]
            epoch += 1
            start = 0
            continue
        yield x_train[start:end], y_train[start:end]
        start = start + batch_size
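# Hypothetical driver loop for the generator above; the array shapes are
# illustrative and jax.numpy is assumed to be imported as np.
x_train = np.zeros((1000, 32))
y_train = np.zeros((1000, 10))
for x_batch, y_batch in minibatch(x_train, y_train, batch_size=128, train_epochs=3):
    pass  # run one optimizer step on (x_batch, y_batch)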
def init_fun(rng, input_shape):
    """
    :param rng: rng used to initialize parameters
    :param input_shape: input shape
    """
    # TODO: consider removing permutation so we can move those layer constructions
    # outside init_fun. It seems that we can add a PermuteTransform layer to
    # achieve the same effect.
    nonlocal permutation, net

    if permutation is None:
        # By default set a random permutation of variables, which is
        # important for performance with multiple steps
        rng, rng_perm = random.split(rng)
        permutation = random.shuffle(rng_perm, np.arange(input_dim))

    # Create masks
    masks, mask_skip = create_mask(input_dim=input_dim, hidden_dims=hidden_dims,
                                   permutation=permutation,
                                   output_dim_multiplier=output_multiplier)

    main_layers = []
    # Create masked layers
    for i, mask in enumerate(masks):
        main_layers.append(MaskedDense(mask))
        if i < len(masks) - 1:
            main_layers.append(nonlinearity)

    if skip_connections:
        net_init, net = stax.serial(
            stax.FanOut(2),
            stax.parallel(stax.serial(*main_layers),
                          MaskedDense(mask_skip, bias=False)),
            stax.FanInSum)
    else:
        net_init, net = stax.serial(*main_layers)

    return net_init(rng, input_shape)
def shuffle(self, x, axis=0):
    x = x.value if isinstance(x, JaxArray) else x
    return JaxArray(jr.shuffle(self.split_key(), x, axis=axis))
def _shuffle_jax(value, seed=None, name=None):  # pylint: disable=unused-argument
    import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
    if seed is None:
        raise ValueError('Must provide PRNGKey to sample in JAX.')
    return jaxrand.shuffle(seed, value, axis=0)
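# Illustrative call to the shim above; in JAX the `seed` argument must be a
# PRNGKey rather than an integer.
import jax.numpy as jnp
from jax import random

value = jnp.arange(10)
shuffled = _shuffle_jax(value, seed=random.PRNGKey(0))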
def gauss_laplace_leapfrog(current_state,
                           target_log_prob_fn,
                           kinetic_energy_fn,
                           step_size,
                           n_disc,
                           rng=None):
    """One numerical integration step of the DHMC integrator for a mixed
    Gaussian and Laplace momentum.

    Params
    ------
    f: function(theta, req_grad)
        Returns the log probability and, if req_grad is True, its gradient.
        The gradient for discrete parameters should be zero.
    f_update: function(theta, dtheta, index, aux)
        Computes the difference in the log probability when theta[index] is
        modified by dtheta. The input 'aux' is whatever quantity, saved from
        the previous call to 'f' or 'f_update', can be recycled.
    M: column vector
        Represents the diagonal mass matrix
    n_disc: int
        Number of discrete parameters. The parameters theta[:-n_disc] are
        assumed continuous.
    """
    del kinetic_energy_fn
    assert isinstance(current_state.state, list)
    assert isinstance(current_state.state_grads, list)

    M = tree_util.tree_map(np.ones_like, current_state.state)
    state, state_grads = current_state.state, current_state.state_grads
    momentum = current_state.momentum
    n_param = len(state)
    state = list(state)

    # Half step for the continuous parameters
    momentum[:-n_disc] = tree_util.tree_multimap(
        lambda p, g: p + 0.5 * step_size * g,
        momentum[:-n_disc], state_grads[:-n_disc])
    state[:-n_disc] = tree_util.tree_multimap(
        lambda t, p: t + 0.5 * step_size * p,
        state[:-n_disc], momentum[:-n_disc])

    logp = utils.call_fn(target_log_prob_fn, state)
    if np.isinf(logp):
        return current_state

    # Update the discrete parameters coordinate-wise, in random order
    coord_order = n_param - n_disc + np.arange(n_disc)
    coord_order = random.shuffle(rng, coord_order)
    for index in coord_order:
        state, momentum, logp = _update_coordwise(target_log_prob_fn, index,
                                                  state, momentum, M,
                                                  step_size, logp)

    # Another half step for the continuous parameters
    state[:-n_disc] = tree_util.tree_multimap(
        lambda t, p: t + 0.5 * step_size * p,
        state[:-n_disc], momentum[:-n_disc])
    new_target_logp, new_state_grads = utils.call_fn_value_and_grad(
        target_log_prob_fn, state)
    momentum[:-n_disc] = tree_util.tree_multimap(
        lambda p, g: p + 0.5 * step_size * g,
        momentum[:-n_disc], new_state_grads[:-n_disc])

    return IntegratorState(state=state,
                           state_grads=new_state_grads,
                           target_log_prob=new_target_logp,
                           momentum=momentum)
def get_masks_from_jax_params(params, nn_density_level, magnitude_base_bool=True,
                              global_bool=False, reshuffle_seed=0):
    """Assemble a collection of 0-1 valued masks which are of the same sizes and
    shapes as the layers' weight tensors. Note that this function ignores bias
    parameters.

    Args:
        params: parameters in a jax.experimental.stax format.
        nn_density_level: the desired density level for weight parameters.
        magnitude_base_bool: a boolean variable that decides whether to prune the
            network by magnitude or to prune it randomly.

    Returns:
        masks: a collection of 0-1 valued masks which are of the same sizes and
            shapes as the layers' weight tensors.
    """
    if (type(magnitude_base_bool) != bool) or (type(global_bool) != bool):
        raise ValueError("magnitude_base_bool and global_bool should be boolean variables")

    masks = []

    if global_bool:
        weight_magnitudes_pooled = np.concatenate([
            np.abs(layer_params[0].flatten())
            for layer_params in params if len(layer_params) > 1])
        idx = int((1 - nn_density_level) * np.size(weight_magnitudes_pooled))
        global_thres = np.sort(weight_magnitudes_pooled)[idx]

    for layer_index in range(len(params)):
        if len(params[layer_index]) < 2:
            # In this case, the layer does not contain weight and bias parameters.
            masks.append([])
        elif len(params[layer_index]) == 2:
            # In this case, the layer contains a tuple of parameters for weights and biases
            weights = params[layer_index][0]
            weight_magnitudes = np.abs(weights)

            if global_bool and magnitude_base_bool:
                this_mask = np.float32(weight_magnitudes > global_thres)
            else:
                # index: number of pruned parameters
                idx = int((1 - nn_density_level) * np.size(weights))
                # threshold: entries below the threshold will be removed
                thres = np.sort(np.reshape(weight_magnitudes, [-1]))[idx]
                # 0 for weight parameters with magnitudes smaller than the threshold,
                # 1 otherwise
                this_mask = np.float32(weight_magnitudes > thres)

                if magnitude_base_bool == False:
                    # in the case of random pruning: randomly shuffle the mask
                    this_mask = random.shuffle(random.PRNGKey(0), this_mask)

            masks.append(this_mask)
        else:
            raise NotImplementedError

    return masks
def train(self, bs, solutions=[None], retrain=False, tensorboard_writer=None,
          work_unit=None):
    if not retrain and not self.flaxd:
        opt_state = self.opt_init(self.net_params)
    if retrain:
        opt_state = self.opt_init(self.opt_params)

    loss = onp.zeros(self.training_iter // 10 + 1)
    gradients = onp.zeros(self.training_iter // 10 + 1)

    if not self.flaxd:
        param = self.get_params(opt_state)
    else:
        param = self.optimizer.target
        opt_state = self.optimizer

    og_loss = self.test_loss(
        self.preconditioner, self.n_test, self.mesh, param,
        np.zeros((bs.shape[1], self.n * self.n)),
        bs[0].reshape(bs.shape[1], self.n * self.n), 0, self.k) / 10000000
    print(og_loss)
    if work_unit is not None:
        work_unit.get_measurement_series(
            label='train/loss').create_measurement(objective_value=og_loss, step=0)

    for i in range(self.training_iter):
        m = bs.shape[0]
        order = random.shuffle(random.PRNGKey(i), np.arange(m))
        for _ in range(50):
            for b in bs[order]:
                current_loss, grad, opt_state = self.step(
                    i, opt_state, np.zeros((b.shape[0], self.n * self.n)), b,
                    solutions[min(m, len(solutions) - 1)])

        if i % 10 == 0:
            if not self.flaxd:
                param = self.get_params(opt_state)
            else:
                param = opt_state.target
            current_loss_test = self.test_loss(
                self.preconditioner, self.n_test, self.mesh, param,
                np.zeros((b.shape[0], self.n * self.n)), b, 0, self.k) / 10000000
            current_loss = current_loss / 10000000
            avg_grad = onp.mean(onp.abs(onp_utils.flatten(grad)[-1]))
            print(f'step{i: 5d}: loss { current_loss :1.5f} : avg_gradient \
{ avg_grad :1.5f} : current_loss_test { current_loss_test :1.5f}')
            logging.info(f'step{i: 5d}: loss { current_loss :1.5f} : avg_gradient \
{ avg_grad :1.5f} : current_loss_test { current_loss_test :1.5f}')
            loss[i // 10] = current_loss
            gradients[i // 10] = avg_grad
            if work_unit is not None:
                work_unit.get_measurement_series(
                    label='train/loss').create_measurement(
                        objective_value=current_loss_test, step=i)
                tensorboard_writer.scalar('train/loss', current_loss_test, step=i + 1)
                work_unit.get_measurement_series(
                    label='train/loss ' + str(self.iter_gmres(i))).create_measurement(
                        objective_value=current_loss, step=i + 1)
                tensorboard_writer.scalar('train/loss ' + str(self.iter_gmres(i)),
                                          current_loss, step=i + 1)

        if i % 50 == 0:
            if self.flaxd:
                self.opt_params = opt_state.target.params
            else:
                self.opt_params = self.get_params(opt_state)
            self.save(str(i))

    if self.flaxd:
        self.optimizer = opt_state
    else:
        self.opt_params = self.get_params(opt_state)
        self.opt_state = opt_state

    if self.model_dir is None:
        self.model_dir = ''
    with open(os.path.join(self.model_dir, 'train_loss.np'), 'wb') as f:
        onp.save(f, loss)
    with open(os.path.join(self.model_dir, 'train_gradients.np'), 'wb') as f:
        onp.save(f, gradients)
    self.save()
    if work_unit is not None:
        tensorboard_writer.close()
def search(self, method_id, method_params, problem_id, problem_params, loss,
           search_space, trials=None, smoothing=10, min_steps=100, verbose=0):
    """
    Description: Search for optimal method parameters

    Args:
        method_id (string): id of method
        method_params (dict): initial method parameters dict (updated by search space)
        problem_id (string): id of problem to try on
        problem_params (dict): problem parameters dict
        loss (function): a function mapping y_pred, y_true -> scalar loss
        search_space (dict): dict mapping parameter names to a finite set of options
        trials (int, None): number of random trials to sample from search space / try all parameters
        smoothing (int): loss computed over smoothing number of steps to decrease variance
        min_steps (int): minimum number of steps that the method gets to run for
        verbose (int): if 1, print progress and current parameters
    """
    self.method_id = method_id
    self.method_params = method_params
    self.problem_id = problem_id
    self.problem_params = problem_params
    self.loss = loss

    # store the order to test parameters
    param_list = list(itertools.product(*[v for k, v in search_space.items()]))
    # np.random.shuffle doesn't work directly on non-JAX objects
    index = np.arange(len(param_list))
    shuffled_index = random.shuffle(generate_key(), index)
    param_order = [param_list[int(i)] for i in shuffled_index]  # shuffle order of elements

    # helper method
    def _update_smoothing(l, val):
        """ update smoothing loss list with new val """
        return jax.ops.index_update(np.roll(l, 1), 0, val)

    self._update_smoothing = jit(_update_smoothing)

    # store optimal params and optimal loss
    optimal_params, optimal_loss = {}, None
    t = 0
    for params in param_order:  # loop over all params in the given order
        t += 1
        curr_params = method_params.copy()
        curr_params.update({k: v for k, v in zip(search_space.keys(), params)})
        loss = self._run_test(curr_params, smoothing=smoothing,
                              min_steps=min_steps, verbose=verbose)
        if not optimal_loss or loss < optimal_loss:
            optimal_params = curr_params
            optimal_loss = loss
        if t == trials:  # break after trials number of attempts, unless trials is None
            break
    return optimal_params, optimal_loss