def initialize(self, y=None, num_subunits=1, dt=0.033, method='random',
               compute_ci=True, random_seed=2046, verbose=0,
               add_noise_to_mle=0):
    """Build the initial parameter set ``self.p0`` for fitting.

    Parameters
    ----------
    y : optional
        Passed to ``self.compute_mle`` when ``method == 'mle'`` and the MLE
        has not been computed yet.
    num_subunits : int
        When not 1, the single ``'stimulus'`` filter is replicated into
        ``'stimulus_s0' ... 'stimulus_s{n-1}'`` and the original
        ``'stimulus'`` entries are removed.
    dt : float
        Stored on ``self.dt``.
    method : {'random', 'mle'}
        Source of the initial weights.
    compute_ci : bool
        Stored on ``self.compute_ci``.
    random_seed : int
        Filter ``i`` uses ``PRNGKey(random_seed + i)``.
    verbose : int
        Truthy values print progress messages.
    add_noise_to_mle : float
        Scale of the Gaussian noise added to every initial filter in
        ``self.p0`` (applied for both methods, despite the name).

    Raises
    ------
    ValueError
        If ``method`` is neither ``'random'`` nor ``'mle'``.
    """
    self.init_method = method  # store meta
    self.num_subunits = num_subunits
    self.compute_ci = compute_ci

    if method == 'random':
        self.b['random'] = {}
        self.w['random'] = {}
        self.intercept['random'] = {}
        if verbose:
            print('Initializing model parameters randomly...')

        for i, name in enumerate(self.filter_names):
            self.intercept['random'][name] = 0.
            # change random seed for each filter
            key = random.PRNGKey(random_seed + i)
            if name in self.S:
                # Filters with a basis: draw coefficients, project to weights.
                self.b['random'][name] = random.normal(
                    key, shape=(self.XS['train'][name].shape[1], 1)
                ).astype(self.dtype)
                self.w['random'][name] = self.S[name] @ self.b['random'][name]
            else:
                self.w['random'][name] = random.normal(
                    key, shape=(self.X['train'][name].shape[1], 1)
                ).astype(self.dtype)
        self.intercept['random']['global'] = 0.

        if verbose:
            print('Finished.')

    elif method == 'mle':
        if verbose:
            print('Initializing model parameters with maximum likelihood...')
        if not self.mle_computed:
            self.compute_mle(y)
        if verbose:
            print('Finished.')

    else:
        raise ValueError(f'`{method}` is not supported.')

    # rename and repmat: stimulus filter to subunits filters
    # subunit model only works with one stimulus.
    filter_names = self.filter_names.copy()
    if num_subunits != 1:
        filter_names.remove('stimulus')
        filter_names = [f'stimulus_s{i}'
                        for i in range(num_subunits)] + filter_names

        for name in filter_names:
            if 'stimulus' not in name:
                continue
            self.dims[name] = self.dims['stimulus']
            # NOTE(review): df is filled from `dims`, not `df` — kept as in
            # the original; confirm this is intended.
            self.df[name] = self.dims['stimulus']
            self.shift[name] = self.shift['stimulus']
            self.filter_nonlinearity[name] = self.filter_nonlinearity['stimulus']
            self.intercept[method][name] = self.intercept[method]['stimulus']
            self.w[method][name] = self.w[method]['stimulus']
            if method in self.w_se:
                self.w_se[method][name] = self.w_se[method]['stimulus']
            self.X['train'][name] = self.X['train']['stimulus']
            if 'dev' in self.X:
                self.X['dev'][name] = self.X['dev']['stimulus']

            if 'stimulus' in self.S:
                self.b[method][name] = self.b[method]['stimulus']
                if method in self.b_se:
                    self.b_se[method][name] = self.b_se[method]['stimulus']
                self.XS['train'][name] = self.XS['train']['stimulus']
                if 'dev' in self.XS:
                    self.XS['dev'][name] = self.XS['dev']['stimulus']
                self.S[name] = self.S['stimulus']

        # Drop the original 'stimulus' entries. The 'dev' splits are optional
        # (the copy step above guards them), so the removal must be guarded
        # too instead of raising KeyError when no dev set was provided.
        self.b[method].pop('stimulus', None)
        self.w[method].pop('stimulus')
        self.intercept[method].pop('stimulus')
        self.X['train'].pop('stimulus')
        if 'dev' in self.X:
            self.X['dev'].pop('stimulus', None)
        if self.XS != {}:
            self.XS['train'].pop('stimulus', None)
            if 'dev' in self.XS:
                self.XS['dev'].pop('stimulus', None)
        self.S.pop('stimulus', None)

        self.filter_names = filter_names

    self.p[method] = {}
    p0 = {}
    for i, name in enumerate(self.filter_names):
        # Basis coefficients are optimised when a basis exists, raw weights
        # otherwise.
        init = self.b[method][name] if name in self.S else self.w[method][name]
        key = random.PRNGKey(random_seed + i)
        noise = add_noise_to_mle * random.normal(
            key, shape=init.shape).astype(self.dtype)
        self.p[method][name] = init
        p0[name] = init + noise

    self.p[method]['intercept'] = self.intercept[method]
    p0['intercept'] = self.intercept[method]

    self.dt = dt
    self.p0 = p0
def testParetoShape(self): key = random.PRNGKey(0) x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2)
def testFoldIn(self): key = random.PRNGKey(0) keys = [random.fold_in(key, i) for i in range(10)] assert np.unique(np.ravel(keys)).shape == (20, )
def testBernoulliShape(self): key = random.PRNGKey(0) x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2)
def testPoissonBatched(self):
    """Draw Poisson samples with two different rates in one call and check each half.

    The first 10000 rates are 2 and the last 10000 are 20; each half of the
    sample is checked against the matching ``scipy.stats.poisson`` pmf via
    ``self._CheckChiSquared``.
    """
    rng = random.PRNGKey(0)
    half = 10000
    rates = jnp.concatenate([2 * jnp.ones(half), 20 * jnp.ones(half)])
    draws = random.poisson(rng, rates, shape=(20000, ))
    low_rate_draws = draws[:half]
    high_rate_draws = draws[half:]
    self._CheckChiSquared(low_rate_draws, scipy.stats.poisson(2.0).pmf)
    self._CheckChiSquared(high_rate_draws, scipy.stats.poisson(20.0).pmf)
def main(args):
    """Run the sparse-regression example end to end and print a report.

    Generates data with ``get_data``, runs ``run_inference``, classifies each
    dimension as active or inactive from a ``mean +- 3 * std`` interval,
    reports pairwise interactions between active dimensions and finally
    prints one posterior sample of the coefficients.
    """
    X, Y, expected_thetas, expected_pairwise = get_data(
        N=args.num_data, P=args.num_dimensions, S=args.active_dimensions
    )

    # setup hyperparameters
    hypers = dict(
        expected_sparsity=max(1.0, args.num_dimensions / 10),
        alpha1=3.0,
        beta1=1.0,
        alpha2=3.0,
        beta2=1.0,
        alpha3=1.0,
        c=1.0,
    )

    # do inference
    samples = run_inference(model, args, random.PRNGKey(0), X, Y, hypers)

    # compute the mean and square root variance of each coefficient theta_i
    def singleton_stats(dim):
        return analyze_dimension(samples, X, Y, dim, hypers)

    means, stds = vmap(singleton_stats)(jnp.arange(args.num_dimensions))

    print(
        "Coefficients theta_1 to theta_%d used to generate the data:"
        % args.active_dimensions,
        expected_thetas,
    )
    print(
        "The single quadratic coefficient theta_{1,2} used to generate the data:",
        expected_pairwise,
    )

    active_dimensions = []
    for dim, (mean, std) in enumerate(zip(means, stds)):
        # a dimension is inactive when [mean - 3 * std, mean + 3 * std] contains zero
        lower = mean - 3.0 * std
        upper = mean + 3.0 * std
        if lower < 0.0 and upper > 0.0:
            status = "inactive"
        else:
            status = "active"
            active_dimensions.append(dim)
        print(
            "[dimension %02d/%02d] %s:\t%.2e +- %.2e"
            % (dim + 1, args.num_dimensions, status, mean, std)
        )

    print(
        "Identified a total of %d active dimensions; expected %d."
        % (len(active_dimensions), args.active_dimensions)
    )

    # Compute the mean and square root variance of coefficients theta_ij for i,j active dimensions.
    # Note that the resulting numbers are only meaningful for i != j.
    if len(active_dimensions) > 0:
        dim_pairs = jnp.array(
            list(itertools.product(active_dimensions, active_dimensions))
        )

        def pair_stats(dim_pair):
            return analyze_pair_of_dimensions(
                samples, X, Y, dim_pair[0], dim_pair[1], hypers
            )

        means, stds = vmap(pair_stats)(dim_pairs)
        format_str = "Identified pairwise interaction between dimensions %d and %d: %.2e +- %.2e"
        for (dim1, dim2), mean, std in zip(dim_pairs, means, stds):
            if dim1 >= dim2:
                continue
            lower = mean - 3.0 * std
            upper = mean + 3.0 * std
            if lower < 0.0 and upper > 0.0:
                # interval contains zero: no interaction reported
                continue
            print(format_str % (dim1 + 1, dim2 + 1, mean, std))

        # Draw a single sample of coefficients theta from the posterior: all singleton
        # coefficients theta_i and pairwise coefficients theta_ij for i, j active
        # dimensions. The final MCMC sample is used.
        last = {
            site: samples[site][-1]
            for site in ("msq", "lambda", "eta1", "xisq", "sigma")
        }
        thetas = sample_theta_space(
            X,
            Y,
            active_dimensions,
            last["msq"],
            last["lambda"],
            last["eta1"],
            last["xisq"],
            hypers["c"],
            last["sigma"],
        )
        print("Single posterior sample theta:\n", thetas)
stax.parallel(Dense(10), stax.serial(Dense(10), Softplus)), ) decoder_init, decode = stax.serial( Dense(512), Relu, Dense(512), Relu, Dense(28 * 28), ) if __name__ == "__main__": step_size = 0.001 num_epochs = 100 batch_size = 32 nrow, ncol = 10, 10 # sampled image grid size rng = random.PRNGKey(0) test_rng = random.PRNGKey(1) # fixed prng key for evaluation imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png") train_images, _, test_images, _ = datasets.mnist(permute_train=True) num_complete_batches, leftover = divmod(train_images.shape[0], batch_size) num_batches = num_complete_batches + bool(leftover) _, init_encoder_params = encoder_init((batch_size, 28 * 28)) _, init_decoder_params = decoder_init((batch_size, 10)) init_params = init_encoder_params, init_decoder_params opt_init, opt_update = optimizers.momentum(step_size, mass=0.9) def binarize_batch(rng, i, images):
def x():
    """Return a (2, 3) normal sample drawn with ``PRNGKey(1)`` in the enclosing scope's ``dtype``."""
    key = random.PRNGKey(1)
    return random.normal(key, (2, 3), dtype)
def x_cpu():
    """Return the result of ``device_get`` on a (2, 3) normal sample drawn with ``PRNGKey(1)``."""
    key = random.PRNGKey(1)
    sample = random.normal(key, (2, 3), dtype)
    return device_get(sample)
return ce_loss, acc if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--seed', type=int, default=0) parser.add_argument('--hidden', type=int, default=16) parser.add_argument('--epochs', type=int, default=400) parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--lr', type=float, default=0.005) args = parser.parse_args() # Load data adj, features, labels, idx_train, idx_val, idx_test = load_data() rng_key = random.PRNGKey(args.seed) step_size = args.lr num_epochs = args.epochs n_nodes = adj.shape[0] n_feats = features.shape[1] # GAT params nheads = [8, 1] nhid = [8] dropout = args.dropout # probability of keeping residual = False init_fun, predict_fun = GAT(nheads=nheads, nhid=nhid, nclass=labels.shape[1], dropout=dropout,
def __init__( self, w_in: np.ndarray, w_recurrent: np.ndarray, w_out: np.ndarray, tau: np.ndarray, bias: np.ndarray, noise_std: float = 0.0, activation_func: Callable[[FloatVector], FloatVector] = H_ReLU, dt: Optional[float] = None, name: Optional[str] = None, rng_key: Optional[int] = None, ): """ RecRateEulerJax - ``JAX``-backed firing rate reservoir :param np.ndarray w_in: Input weights [IxN] :param np.ndarray w_recurrent: Recurrent weights [NxN] :param np.ndarray w_out: Output weights [NxO] :param np.ndarray tau: Time constants [N] :param np.ndarray bias: Bias values [N] :param float noise_std: White noise standard deviation applied to reservoir neurons. Default: ``0.0`` :param Callable[[FloatVector], float] activation_func: Neuron transfer function f(x: float) -> float. Must be vectorised. Default: H_ReLU :param Optional[float] dt: Reservoir time step. Default: ``np.min(tau) / 10.0`` :param Optional[str] name: Name of the layer. Default: ``None`` :param Optional[Jax RNG key] rng_key Jax RNG key to use for noise. Default: Internally generated """ # - Everything should be 2D w_in = np.atleast_2d(w_in) w_recurrent = np.atleast_2d(w_recurrent) w_out = np.atleast_2d(w_out) # transform to np.array if necessary tau = np.array(tau) bias = np.array(bias) # - Get information about network size self._size_in = w_in.shape[0] self._size = w_in.shape[1] self._size_out = w_out.shape[1] # -- Set properties self.w_recurrent = w_recurrent self.w_out = w_out self.tau = tau self.bias = bias self._H = activation_func if dt is None: dt = np.min(tau) / 10.0 # - Call super-class initialisation super().__init__(w_in, dt, noise_std, name) # - Correct layer size self._size_in = w_in.shape[0] self._size_out = w_out.shape[1] # - Get compiled evolution function self._evolve_jit = _get_rec_evolve_jit(activation_func) # - Reset layer state self.reset_all() # - Seed RNG if rng_key is None: rng_key = rand.PRNGKey(onp.random.randint(0, 2 ** 63)) _, self._rng_key = rand.split(rng_key)
def test_jit_or_pmap_broadcast(self):
    """Compare ``batch._jit_or_pmap_broadcast`` output with a direct call for 0, 1 and 2 devices."""

    def kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=None, p=0.65):
        out = np.abs(np.matmul(x1, x2))
        if do_square:
            out *= out
        if do_flip:
            out = -out

        out *= random.uniform(keys) * p
        return [out, params]

    params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
    x2 = np.arange(0, 10).reshape((10, ))
    keys = random.PRNGKey(1)

    # Same iteration order as two nested `[True, False]` loops.
    flag_grid = [(flip, square)
                 for flip in [True, False]
                 for square in [True, False]]

    # --- device_count=0 ---
    kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=0)
    x1 = np.arange(0, 10).reshape((1, 10))
    for do_flip, do_square in flag_grid:
        with self.subTest(do_flip=do_flip, do_square=do_square, device_count=0):
            expected = kernel_fn(x1, x2, do_flip, keys, do_square, params,
                                 _unused=True, p=0.65)
            actual = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square,
                                       params, _unused=True)
            self.assertAllClose(expected, actual)

    # --- device_count=1 ---
    test_utils.stub_out_pmap(batch, 1)
    x1 = np.arange(0, 10).reshape((1, 10))
    kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=1)
    for do_flip, do_square in flag_grid:
        with self.subTest(do_flip=do_flip, do_square=do_square, device_count=1):
            expected = kernel_fn(x1, x2, do_flip, keys, do_square, params,
                                 _unused=False, p=0.65)
            actual = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square,
                                       params, _unused=None)
            self.assertAllClose(expected[0], actual[0])
            expanded_params = tree_map(partial(np.expand_dims, axis=0),
                                       expected[1])
            self.assertAllClose(expanded_params, actual[1])

    # --- device_count=2 ---
    kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=2)
    x1 = np.arange(0, 20).reshape((2, 10))
    test_utils.stub_out_pmap(batch, 2)

    def broadcast(arg):
        return np.broadcast_to(arg, (2, ) + arg.shape)

    for do_flip, do_square in flag_grid:
        with self.subTest(do_flip=do_flip, do_square=do_square, device_count=2):
            expected = kernel_fn(x1, x2, do_flip, keys, do_square, params,
                                 p=0.2)
            actual = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square,
                                       params, _unused=None, p=0.2)
            self.assertAllClose(expected[0][0], actual[0][0])
            self.assertAllClose(expected[0][1], actual[0][1])
            self.assertAllClose(tree_map(broadcast, expected[1]), actual[1])
def main(_):
    """Fine-tune a checkpointed model on uncertain points picked from a pool.

    Steps: load the test and pool splits, restore the optimizer state from
    ``FLAGS.root_dir/FLAGS.ckpt_idx``, pick ``FLAGS.n_extra`` pool points by
    entropy, top-2 probability difference or at random, then run
    ``FLAGS.epochs`` epochs of (optionally differentially private) SGD on
    them, evaluating on the test split after every epoch. Progress is written
    both to ``logging`` and to ``FLAGS.work_dir/stdout.log``.

    Raises:
      ValueError: if ``FLAGS.uncertain`` is not one of 0/'entropy',
        1/'difference', 2/'random'.
    """
    sns.set()
    sns.set_palette(sns.color_palette('hls', 10))
    npr.seed(FLAGS.seed)

    logging.info('Starting experiment.')

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.work_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

    # BEGIN: fetch test data and candidate pool
    test_images, test_labels, _ = datasets.get_dataset_split(
        name=FLAGS.test_split.split('-')[0],
        split=FLAGS.test_split.split('-')[1],
        shuffle=False)
    pool_images, pool_labels, _ = datasets.get_dataset_split(
        name=FLAGS.pool_split.split('-')[0],
        split=FLAGS.pool_split.split('-')[1],
        shuffle=False)

    n_pool = len(pool_images)
    # normalize to range [-1.0, 127./128]
    test_images = test_images / np.float32(128.0) - np.float32(1.0)
    pool_images = pool_images / np.float32(128.0) - np.float32(1.0)

    # augmentation for train/pool data
    if FLAGS.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4),
                                             data.ToDevice)
    else:
        augmentation = None
    # END: fetch test data and candidate pool

    _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    # BEGIN: load ckpt
    ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
    # The checkpoint is a pickle that is only read here, so it must be opened
    # in binary read mode ('wr' is not a valid mode and would not allow
    # pickle.load to read bytes).
    with gfile.Open(ckpt_dir, 'rb') as fckpt:
        opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
    params = get_params(opt_state)

    stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
    logging.info('finetune from: %s', ckpt_dir)
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    logging.info('test accuracy: %.2f', test_acc)
    stdout_log.write('test accuracy: {}\n'.format(test_acc))
    stdout_log.flush()
    # END: load ckpt

    # BEGIN: setup for dp model
    @jit
    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad_loss(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size),
            opt_state)
    # END: setup for dp model

    ### BEGIN: prepare extra points picked from pool data
    # BEGIN: on pool data
    pool_embeddings = [
        apply_fn_0(params[:-1], pool_images[b_i:b_i + FLAGS.batch_size])
        for b_i in range(0, n_pool, FLAGS.batch_size)
    ]
    pool_embeddings = np.concatenate(pool_embeddings, axis=0)

    pool_logits = apply_fn_1(params[-1:], pool_embeddings)

    pool_true_labels = np.argmax(pool_labels, axis=1)
    pool_predicted_labels = np.argmax(pool_logits, axis=1)
    pool_correct_indices = \
        onp.where(pool_true_labels == pool_predicted_labels)[0]
    pool_incorrect_indices = \
        onp.where(pool_true_labels != pool_predicted_labels)[0]
    assert len(pool_correct_indices) + \
        len(pool_incorrect_indices) == len(pool_labels)

    pool_probs = stax.softmax(pool_logits)

    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
        stdout_log.write('all {} entropy: min {}, max {}\n'.format(
            len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

        pool_entropy_sorted_indices = onp.argsort(pool_entropy)
        # take the n_extra most uncertain points
        pool_uncertain_indices = \
            pool_entropy_sorted_indices[::-1][:FLAGS.n_extra]
        stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
            len(pool_entropy[pool_uncertain_indices]),
            onp.min(pool_entropy[pool_uncertain_indices]),
            onp.max(pool_entropy[pool_uncertain_indices])))

    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(pool_probs.shape) == 2
        sorted_pool_probs = onp.sort(pool_probs, axis=1)
        pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
        assert min(pool_probs_diff) > 0.
        stdout_log.write('all {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff), onp.min(pool_probs_diff),
            onp.max(pool_probs_diff)))

        # smallest top-2 gap first = most uncertain
        pool_uncertain_indices = onp.argsort(pool_probs_diff)[:FLAGS.n_extra]
        stdout_log.write('uncertain {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff[pool_uncertain_indices]),
            onp.min(pool_probs_diff[pool_uncertain_indices]),
            onp.max(pool_probs_diff[pool_uncertain_indices])))

    elif FLAGS.uncertain == 2 or FLAGS.uncertain == 'random':
        pool_uncertain_indices = npr.permutation(n_pool)[:FLAGS.n_extra]

    else:
        # Without this branch an unknown value surfaced later as a NameError
        # on `pool_uncertain_indices`.
        stdout_log.close()
        raise ValueError(
            'Unknown uncertainty measure: {}'.format(FLAGS.uncertain))
    # END: on pool data
    ### END: prepare extra points picked from pool data

    finetune_images = copy.deepcopy(pool_images[pool_uncertain_indices])
    finetune_labels = copy.deepcopy(pool_labels[pool_uncertain_indices])

    stdout_log.write('Starting fine-tuning...\n')
    logging.info('Starting fine-tuning...')
    stdout_log.flush()

    stdout_log.write('{} points picked via {}\n'.format(
        len(finetune_images), FLAGS.uncertain))
    logging.info('%d points picked via %s', len(finetune_images),
                 FLAGS.uncertain)
    assert FLAGS.n_extra == len(finetune_images)

    for epoch in range(1, FLAGS.epochs + 1):
        # BEGIN: finetune model with extra data, evaluate and save
        num_extra = len(finetune_images)
        num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
        num_batches = num_complete_batches + bool(leftover)

        finetune = data.DataChunk(X=finetune_images,
                                  Y=finetune_labels,
                                  image_size=28,
                                  image_channels=1,
                                  label_dim=1,
                                  label_format='numeric')

        batches = data.minibatcher(finetune,
                                   FLAGS.batch_size,
                                   transform=augmentation)

        # NOTE(review): the step counter and the PRNG key are re-created every
        # epoch, so `private_update` folds the same step indices into the same
        # key in each epoch — confirm that repeating the noise is intended.
        itercount = itertools.count()
        key = random.PRNGKey(FLAGS.seed)

        start_time = time.time()

        for _ in range(num_batches):
            b = next(batches)
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(b.X, b.Y, dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(b.X, b.Y))

        epoch_time = time.time() - start_time
        stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
        logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

        # accuracy on test data
        params = get_params(opt_state)

        test_pred_0 = test_pred
        test_acc, test_pred = accuracy(params,
                                       shape_as_image(test_images, test_labels),
                                       return_predicted_class=True)
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        stdout_log.write(
            'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
                test_loss, 100 * test_acc))
        logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                     100 * test_acc)
        stdout_log.flush()

        # visualize prediction difference between 2 checkpoints.
        if FLAGS.visualize:
            utils.visualize_ckpt_difference(test_images,
                                            np.argmax(test_labels, axis=1),
                                            test_pred_0,
                                            test_pred,
                                            epoch - 1,
                                            epoch,
                                            FLAGS.work_dir,
                                            mu=128.,
                                            sigma=128.)
        # END: finetune model with extra data, evaluate and save

    stdout_log.close()
# Training script: rate-encode each batch into random spike masks for
# `PeDurx` time steps, feed them to `model`, then take one SGD step per batch.
PeDurx = 70  # number of encoding time steps per sample
batch_size = 32
Pencoder = PoissonEncoder(duration=PeDurx)
optimizer = SGD(model.parameters(), lr=0.01)
train_loader = DataLoader(dataset=MnistDataset(training=True, flatten=True),
                          collate_fn=collate_fn,
                          shuffle=True,
                          batch_size=batch_size)
test_loader = DataLoader(dataset=MnistDataset(training=False, flatten=True),
                         collate_fn=collate_fn,
                         shuffle=False,
                         batch_size=batch_size)

# JAX PRNG keys are stateless: building PRNGKey(0) inside the loop made every
# time step draw the *same* uniform tensor, so the "random" spike mask never
# changed over time. Keep one key for the run and split off a fresh sub-key
# per time step instead (still fully reproducible).
encoding_key = random.PRNGKey(0)

for epoch in range(15):
    for i, (data, target) in enumerate(train_loader):
        target = Variable(target)

        for t in range(PeDurx):
            encoding_key, step_key = random.split(encoding_key)
            rnum = random.uniform(key=step_key, shape=data.shape)
            # spike where |data| / 2 exceeds the random threshold
            uin = (jnp.abs(data) / 2 > rnum).astype('float32')
            # restore the sign of the input on the spikes
            q = jnp.multiply(uin, jnp.sign(data))
            output, time = model(Variable(q), t)
            print(str(t))

        loss = F.Spikeloss(output, target, time_step=time)
        loss.backward()  # calc gradients
        optimizer.step()  # update gradients
        print("Epoch:" + str(epoch) + "Time:" + str(i) + "loss:" +
              str(loss.data))
# Plain SGD optimizer triple (init / update / params getter).
learning_rate = 0.18
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)


@jit
def update(_, i, opt_state, batch):
    # One SGD step: gradient of `loss` at the current params, applied at step
    # index `i`. The first argument (a PRNG key at the call site) is unused.
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)


"""The next cell contains our training loop, very similar to Problem 1."""

num_epochs = 10
key = random.PRNGKey(123)
# Initialise parameters; (-1, 28, 28, 1) is the input shape handed to the
# initialiser.
_, init_params = init_random_params(key, (-1, 28, 28, 1))
opt_state = opt_init(init_params)
# Global step counter shared across epochs.
itercount = itertools.count()
test_losses = []
test_accs = []
for epoch in range(1, num_epochs + 1):
    for _ in range(num_batches):
        opt_state = update(key, next(itercount), opt_state,
                           shape_as_image(*next(batches)))
    # Evaluate on the test split once per epoch.
    params = get_params(opt_state)
    # print("Params are: {} ".format(params))
    test_acc = accuracy(params, shape_as_image(test_images, test_labels))
    test_loss = loss(params, shape_as_image(test_images, test_labels))
    # NOTE(review): `test_losses` / `test_accs` are not appended to in the
    # visible code — presumably that happens in the part that follows.
def _wrap_seed_jax(seed, _): import jax.random as jaxrand # pylint: disable=g-import-not-at-top return jaxrand.PRNGKey(seed % (2**32 - 1))
s += pytree_dot2(x, y) s += pytree_dot2(x, y) return s def internal_only_jit(x, y): s = pytree_dot(x, y) s += pytree_dot(x, y) s += pytree_dot(x, y) s += pytree_dot(x, y) s += pytree_dot(x, y) return s if __name__ == "__main__": k1, k2, k3 = random.split(random.PRNGKey(3), 3) x = [np.array([0.5, 0.5]), random.uniform(k1, shape=(50000,)), {"beta": 0.5}] y = [np.array([0.5, 0.5]), random.uniform(k2, shape=(50000,)), {"beta": 0.5}] z = [np.array([0.5, 0.5]), random.uniform(k3, shape=(50000,)), {"beta": 0.5}] x1 = x s = 0.0 start_time = datetime.now() for i in range(10000): if i % 2 == 0: x1 = z else: x1 = x s = _jit(x, y) end_time = datetime.now() print("Std Jit Duration: {}".format(end_time - start_time))
def run(self, output_h5parm, ncpu, avg_direction_spacing,
        field_of_view_diameter, duration, time_resolution, start_time,
        array_name, phase_tracking):
    """Simulate one Gaussian-process draw of dTEC and store it in a datapack.

    Creates an empty datapack at ``output_h5parm``, builds coordinates for
    every (time, antenna, direction) combination in an East-North-Up frame,
    evaluates the mean and covariance of ``TomographicKernel`` on them,
    draws one sample via a Cholesky factor and writes it to ``dp.tec``.

    :param output_h5parm: path of the datapack to create (overwritten,
        ``clobber=True``).
    :param ncpu: passed to ``chunked_pmap`` as ``chunksize``.
    :param avg_direction_spacing: passed to ``get_num_directions``.
    :param field_of_view_diameter: passed to ``get_num_directions`` and
        ``create_empty_datapack``.
    :param duration: with ``time_resolution`` determines the number of times,
        ``int(duration / time_resolution) + 1``.
    :param time_resolution: see ``duration``.
    :param start_time: passed to ``create_empty_datapack``.
    :param array_name: key into ``ARRAYS`` selecting the array file.
    :param phase_tracking: object with ``.ra.deg`` and ``.dec.deg``.
    """
    Nd = get_num_directions(
        avg_direction_spacing,
        field_of_view_diameter,
    )
    Nf = 2  # 8000
    Nt = int(duration / time_resolution) + 1
    min_freq = 700.
    max_freq = 2000.
    dp = create_empty_datapack(
        Nd,
        Nf,
        Nt,
        pols=None,
        field_of_view_diameter=field_of_view_diameter,
        start_time=start_time,
        time_resolution=time_resolution,
        min_freq=min_freq,
        max_freq=max_freq,
        array_file=ARRAYS[array_name],
        phase_tracking=(phase_tracking.ra.deg, phase_tracking.dec.deg),
        save_name=output_h5parm,
        clobber=True)

    # Read back the axes of the freshly created datapack.
    with dp:
        dp.current_solset = 'sol000'
        dp.select(pol=slice(0, 1, 1))
        axes = dp.axes_tec
        patch_names, directions = dp.get_directions(axes['dir'])
        antenna_labels, antennas = dp.get_antennas(axes['ant'])
        timestamps, times = dp.get_times(axes['time'])

    # The first antenna and first time define the reference frame.
    ref_ant = antennas[0]
    ref_time = times[0]

    # Overwrite the requested counts with what the datapack actually holds.
    Na = len(antennas)
    Nd = len(directions)
    Nt = len(times)

    logger.info(f"Number of directions: {Nd}")
    logger.info(f"Number of antennas: {Na}")
    logger.info(f"Number of times: {Nt}")
    logger.info(f"Reference Ant: {ref_ant}")
    logger.info(f"Reference Time: {ref_time.isot}")

    # Plot Antenna Layout in East North Up frame
    ref_frame = ENU(obstime=ref_time, location=ref_ant.earth_location)

    _antennas = ac.ITRS(*antennas.cartesian.xyz,
                        obstime=ref_time).transform_to(ref_frame)
    # plt.scatter(_antennas.east, _antennas.north, marker='+')
    # plt.xlabel(f"East (m)")
    # plt.ylabel(f"North (m)")
    # plt.show()

    # Reference antenna and Earth centre in the ENU frame, in km.
    x0 = ac.ITRS(
        *antennas[0].cartesian.xyz,
        obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
            au.km).value
    earth_centre_x = ac.ITRS(
        x=0 * au.m, y=0 * au.m, z=0. * au.m,
        obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
            au.km).value
    self._kernel = TomographicKernel(x0,
                                     earth_centre_x,
                                     M32(),
                                     S_marg=20,
                                     compute_tec=False)

    # Direction vectors in the ENU frame and times in seconds relative to
    # the first time stamp.
    k = directions.transform_to(ref_frame).cartesian.xyz.value.T

    t = times.mjd * 86400.
    t -= t[0]

    X1 = GeodesicTuple(x=[], k=[], t=[], ref_x=[])

    logger.info("Computing coordinates in frame ...")

    # Antenna positions are re-evaluated at every time, then combined with
    # all directions into one flat coordinate array per time.
    for i, time in tqdm(enumerate(times)):
        x = ac.ITRS(*antennas.cartesian.xyz,
                    obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                        au.km).value.T
        ref_ant_x = ac.ITRS(
            *ref_ant.cartesian.xyz,
            obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                au.km).value

        X = make_coord_array(x,
                             k,
                             t[i:i + 1, None],
                             ref_ant_x[None, :],
                             flat=True)

        X1.x.append(X[:, 0:3])
        X1.k.append(X[:, 3:6])
        X1.t.append(X[:, 6:7])
        # NOTE(review): only column 7 is kept for ref_x although ref_ant_x
        # appears to be a 3-vector (columns 7:10) — confirm against
        # make_coord_array's column layout.
        X1.ref_x.append(X[:, 7:8])

    X1 = X1._replace(
        x=jnp.concatenate(X1.x, axis=0),
        k=jnp.concatenate(X1.k, axis=0),
        t=jnp.concatenate(X1.t, axis=0),
        ref_x=jnp.concatenate(X1.ref_x, axis=0),
    )

    logger.info(f"Total number of coordinates: {X1.x.shape[0]}")

    def compute_covariance_row(X1: GeodesicTuple, X2: GeodesicTuple):
        # Covariance of a single coordinate (X1) against all of X2.
        K = self._kernel(X1,
                         X2,
                         self._bottom,
                         self._width,
                         self._fed_sigma,
                         self._fed_kernel_params,
                         wind_velocity=self._wind_vector)  # 1, N
        return K[0, :]

    # Reshape each leaf of one coordinate to a single row before pairing it
    # with the full coordinate set.
    covariance_row = lambda X: compute_covariance_row(
        tree_map(lambda x: x.reshape((1, -1)), X), X1)

    mean = jit(lambda X1: self._kernel.mean_function(X1,
                                                     self._bottom,
                                                     self._width,
                                                     self._fed_mu,
                                                     wind_velocity=self.
                                                     _wind_vector))(X1)

    # Covariance matrix built row by row.
    cov = chunked_pmap(covariance_row,
                       X1,
                       batch_size=X1.x.shape[0],
                       chunksize=ncpu)

    # Shows the covariance matrix; plt.show() is called before continuing.
    plt.imshow(cov)
    plt.show()

    # Fixed seed: the simulated sample is reproducible.
    Z = random.normal(random.PRNGKey(42), (cov.shape[0], 1),
                      dtype=cov.dtype)

    t0 = default_timer()
    jitter = 1e-6
    logger.info(f"Computing Cholesky with jitter: {jitter}")
    L = jnp.linalg.cholesky(cov + jitter * jnp.eye(cov.shape[0]))
    # NaNs in the factor trigger the fallback to msqrt(cov).
    if np.any(np.isnan(L)):
        logger.info("Numerically instable. Using SVD.")
        L = msqrt(cov)

    logger.info(f"Cholesky took {default_timer() - t0} seconds.")

    # Sample = L @ Z + mean, reshaped to (Na, Nd, Nt) then transposed to
    # (Nd, Na, Nt).
    # NOTE(review): coordinates were appended time by time above, so the flat
    # order looks time-major; confirm that reshaping to (Na, Nd, Nt) matches
    # that ordering.
    dtec = (L @ Z + mean[:, None])[:, 0].reshape((Na, Nd, Nt)).transpose(
        (1, 0, 2))

    logger.info(f"Saving result to {output_h5parm}")
    with dp:
        dp.current_solset = 'sol000'
        dp.select(pol=slice(0, 1, 1))
        # A leading axis of length 1 is added before storing.
        dp.tec = np.asarray(dtec[None])
def main(_):
    """Train on MNIST with vanilla or differentially private SGD.

    Every epoch iterates ``num_batches`` shuffled mini-batches, then prints
    the test loss/accuracy and, when ``FLAGS.dpsgd`` is set, the privacy
    epsilon reached so far for ``delta = 1e-5``.

    Raises:
      NotImplementedError: if ``FLAGS.microbatches`` is set.
    """
    if FLAGS.microbatches:
        raise NotImplementedError(
            'Microbatches < batch size not currently supported')

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
    num_batches = num_complete_batches + bool(leftover)
    key = random.PRNGKey(FLAGS.seed)

    def data_stream():
        """Yield (images, labels) mini-batches forever, reshuffling each pass."""
        rng = npr.RandomState(FLAGS.seed)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    @jit
    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size),
            opt_state)

    _, init_params = init_random_params(key, (-1, 28, 28, 1))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    # Use the size of the loaded training set rather than a hard-coded 60000
    # so the privacy accounting follows the data that is actually used.
    # NOTE(review): this counts only complete batches per epoch while the
    # loop below also runs the leftover batch (`num_batches`) — confirm which
    # count the accountant should use.
    steps_per_epoch = num_train // FLAGS.batch_size
    print('\nStarting training...')
    for epoch in range(1, FLAGS.epochs + 1):
        start_time = time.time()
        for _ in range(num_batches):
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(*next(batches), dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(*next(batches)))
        epoch_time = time.time() - start_time
        print('Epoch {} in {:0.2f} sec'.format(epoch, epoch_time))

        # evaluate test accuracy
        params = get_params(opt_state)
        test_acc = accuracy(params, shape_as_image(test_images, test_labels))
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(
            test_loss, 100 * test_acc))

        # determine privacy loss so far
        if FLAGS.dpsgd:
            delta = 1e-5
            num_examples = num_train
            eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)
            print('For delta={:.0e}, the current epsilon is: {:.2f}'.format(
                delta, eps))
        else:
            print('Trained with vanilla non-private SGD optimizer')
import flaxvision.models as flax_models from flax import nn import numpy as np import jax.numpy as jnp from jax import random from jax.config import config config.enable_omnistaging() import unittest, os import logging RNG = random.PRNGKey(0) MODELS_LIST = [ 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2', 'inception_v3', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101' ] class TestTraining(unittest.TestCase): def test_outputs(self): log = logging.getLogger(__name__) inputs = jnp.ones((1, 224, 224, 3)) inception_inputs = jnp.ones((1, 299, 299, 3)) for key in MODELS_LIST: log.info(f'testing inference {key}')
def testPermutationErrors(self):
    """``random.permutation`` rejects a float argument and a traced size under ``jit``."""
    rng = random.PRNGKey(0)

    # A float as second argument must raise TypeError.
    with self.assertRaises(TypeError):
        random.permutation(rng, 10.)

    # Under jit the integer argument is traced, which must raise a
    # ConcretizationTypeError.
    with self.assertRaises(core.ConcretizationTypeError):
        jitted_permutation = api.jit(random.permutation)
        jitted_permutation(rng, 10)
def test_sample(self, dist, args, kwargs, out, flat): del out, flat args = args() kwargs = kwargs() p = dist(*args, **kwargs) p.sample(seed=random.PRNGKey(0))
def testGammaShape(self): key = random.PRNGKey(0) x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2)
# Training data numSamples = inParameters['Training data']['number_of_training_samples'] trainDataFileName = inParameters['Training data']['training_data'] # Training data outDir = inParameters['Output']['output_folder'] # Set numpy seed np.random.seed(0) # Create subdirectory for network checkpoints create_dir(outDir + "/net_checkpoints/") # Model setup rnnNet = RNN2D.partial(L=L, units=rnnUnits, initScale=netInitScale) _, params = rnnNet.init_by_shape(random.PRNGKey(netInitSeed), [(1, L, L)]) rnnModel = nn.Model(rnnNet, params) # Optimizer setup optimizer = flax.optim.Adam(learning_rate=learningRate, beta1=beta1, beta2=beta2).create(rnnModel) # Load data if inParameters['Training data']['training_data'] == "generate": print("*** Generating samples") numTestSamples = inParameters['Training data']['number_of_test_samples'] if numTestSamples < numSamples: numTestSamples = numSamples trainData, trainEnergies, testData, testEnergies =\ generate_samples(numTestSamples,T,L,
def testPoissonShape(self): key = random.PRNGKey(0) x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2)) assert x.shape == (3, 2)
def main(argv):
  """Train a transformer, periodically evaluating and translating.

  Builds the datasets and (on host 0) the summary writers, creates the model
  and optimizer (optionally restoring the latest checkpoint), and runs the
  training loop. On every step that is a multiple of `FLAGS.eval_frequency`
  (which includes step 0) it writes training metrics, runs up to
  `FLAGS.num_eval_steps` evaluation batches, translates the whole prediction
  dataset, and reports a BLEU score plus a few sample translations.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If more than one command-line argument is given.
    ValueError: If `FLAGS.batch_size` is not divisible by the number of
      local devices.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # This seems to be necessary even when importing TF2?
  tf.enable_v2_behavior()

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  # Summary writers exist only on host 0; every later use is guarded by the
  # same `jax.host_id() == 0` check.
  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))

  if FLAGS.batch_size % n_devices:
    raise ValueError(
        'Batch size must be divisible by the number of devices')

  # Fall back to a vocabulary file inside the model directory, and make sure
  # the directory that holds it exists.
  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  logging.info('Initializing dataset.')
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      eval_dataset_name=FLAGS.eval_dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      max_eval_length=FLAGS.max_eval_target_length)
  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_token = 2  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    """Detokenize `toks` up to and including the first EOS token.

    NOTE(review): `np.argmax` returns 0 when no element equals `eos_token`,
    so a sequence without EOS is cut down to its first token - confirm this
    is the intended behaviour.
    """
    valid_toks = toks[:np.argmax(toks == eos_token) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  transformer_kwargs = {
      'vocab_size': vocab_size,
      'output_vocab_size': vocab_size,
      'emb_dim': FLAGS.emb_dim,
      'num_heads': FLAGS.num_heads,
      'num_layers': FLAGS.num_layers,
      'qkv_dim': FLAGS.qkv_dim,
      'mlp_dim': FLAGS.mlp_dim,
      'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      'share_embeddings': FLAGS.share_embeddings,
      'logits_via_embedding': FLAGS.logits_via_embedding,
  }

  start_step = 0
  rng = random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = random.split(rng)
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  model, cache_def = create_model(init_rng, input_shape, target_shape,
                                  transformer_kwargs)
  optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.weight_decay)
  # We access model only from optimizer below via optimizer.target.
  del model

  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps)

  # pmap'd step functions; all of them share the 'batch' axis name.
  p_train_step = jax.pmap(functools.partial(
      train_step,
      learning_rate_fn=learning_rate_fn,
      label_smoothing=FLAGS.label_smoothing,
      use_bfloat16=FLAGS.use_bfloat16), axis_name='batch')
  p_eval_step = jax.pmap(functools.partial(
      eval_step,
      label_smoothing=FLAGS.label_smoothing,
      use_bfloat16=FLAGS.use_bfloat16), axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(predict_step, use_bfloat16=FLAGS.use_bfloat16,
                        beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, n_devices)

  logging.info('Starting training loop.')
  metrics_all = []
  t_loop_start = time.time()
  # zip() ends the loop at whichever comes first: num_train_steps or the end
  # of the training iterator.
  for step, batch in zip(range(start_step, FLAGS.num_train_steps), train_iter):
    # Shard data to devices and do a training step.
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Save a checkpoint on one host after every checkpoint_freq steps.
    if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0 and
        step > 0 and jax.host_id() == 0):
      checkpoints.save_checkpoint(FLAGS.model_dir,
                                  jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling.
    # Everything below runs only on steps that are multiples of eval_frequency.
    if step % FLAGS.eval_frequency != 0 and step > 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary['learning_rate'] = lr
    # At step 0 only a single step has been timed.
    steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
    steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
    t_loop_start = time.time()
    if jax.host_id() == 0:
      train_summary_writer.scalar('steps per second', steps_per_sec, step)
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    # Reset the accumulator on every host so the next window starts empty.
    metrics_all = []
    logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

    # Eval Metrics
    logging.info('Gathering evaluation metrics.')
    t_eval_start = time.time()
    eval_metrics = []
    eval_iter = iter(eval_ds)
    for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      eval_batch = common_utils.shard(eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    if jax.host_id() == 0:
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
    logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
    logging.info('eval time: %.4f s step %d', time.time() - t_eval_start, step)

    # Translation and BLEU Score.
    logging.info('Translating evaluation dataset.')
    t_inference_start = time.time()
    predict_iter = iter(predict_ds)
    sources, references, predictions = [], [], []
    for _, pred_batch in enumerate(predict_iter):
      pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        # Round the batch size up to the next multiple of n_devices.
        padded_size = int(
            np.ceil(cur_pred_batch_size / n_devices) * n_devices)
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
      pred_batch = common_utils.shard(pred_batch)
      per_device_batchsize = pred_batch['inputs'].shape[1]
      cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32
      cache = jax_utils.replicate(
          cache_def.initialize_cache(
              (per_device_batchsize, FLAGS.max_predict_length),
              dtype=cache_dtype))
      predicted = p_pred_step(pred_batch['inputs'], optimizer.target, cache,
                              eos_token, FLAGS.max_predict_length)
      predicted = tohost(predicted)
      inputs = tohost(pred_batch['inputs'])
      targets = tohost(pred_batch['targets'])
      # Iterate through non-padding examples of batch.
      for i, s in enumerate(predicted[:cur_pred_batch_size]):
        sources.append(decode_tokens(inputs[i]))
        references.append(decode_tokens(targets[i]))
        predictions.append(decode_tokens(s))
    logging.info('Translation: %d predictions %d references %d sources.',
                 len(predictions), len(references), len(sources))
    logging.info('Translation time: %.4f s step %d.',
                 time.time() - t_inference_start, step)

    # Calculate BLEU score for translated eval corpus against reference.
    bleu_matches = bleu.bleu_partial(references, predictions)
    all_bleu_matches = per_host_sum_pmap(bleu_matches)
    bleu_score = bleu.complete_bleu(*all_bleu_matches)
    # Save translation samples for tensorboard.
    # np.random.choice samples with replacement by default, so the 8 picks may
    # repeat. NOTE(review): it raises if `predictions` is empty - confirm the
    # prediction dataset is never empty.
    exemplars = ''
    for n in np.random.choice(np.arange(len(predictions)), 8):
      exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
    if jax.host_id() == 0:
      eval_summary_writer.scalar('bleu', bleu_score, step)
      eval_summary_writer.text('samples', exemplars, step)
      eval_summary_writer.flush()
    logging.info('Translation BLEU Score %.4f', bleu_score)
def testIssue222(self): x = random.randint(random.PRNGKey(10003), (), 0, 0) assert x == 0
def main(argv):
  """Run the configured, optionally partitioned, train/eval loop.

  Copies `FLAGS.config` into the module-global `CFG`, builds the datasets and
  eval cache, creates or restores the optimizer, and wraps the train, eval,
  predict and metric-reduction functions in `jax.pmap` (with `sharded_jit`
  underneath when `CFG.num_partitions > 1`). Each epoch then:

    1. runs prediction for every task in the eval cache and writes task
       metrics and text samples on host 0,
    2. trains until `host_step // steps_per_epoch` moves past the epoch, using
       either the on-device infeed loop (`CFG.infeed`) or the host loop,
    3. optionally saves a checkpoint on host 0,
    4. writes training and evaluation summaries.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If more than one command-line argument is given.
    ValueError: If `CFG.batch_size` is not divisible by
      `topology.num_replicas`.
  """
  global CFG
  CFG = FLAGS.config

  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Guarantee that the JAX bfloat16 extension is used rather than TF bfloat16.
  _ = np.array(jnp.array([1.0], dtype=jnp.bfloat16))

  # Use hardware RNG for bernoulli randoms in dropout mask creation.
  if CFG.hardware_rng:
    models.set_hardware_bernoulli()

  # Optional extra modules named in the config are imported by name.
  if 'module_import' in CFG and CFG.module_import:
    for module in CFG.module_import:
      importlib.import_module(module)

  if 'additional_task_cache_dirs' in CFG and CFG.additional_task_cache_dirs:
    t5.data.add_global_cache_dirs(CFG.additional_task_cache_dirs)

  num_partitions = CFG.num_partitions
  topology = train_lib.compute_multihost_topology(num_partitions)
  batch_size = CFG.batch_size
  eval_batch_size = CFG.eval_batch_size
  per_replica_set_eval_batch_size = eval_batch_size // topology.num_replica_sets
  if batch_size % topology.num_replicas:
    raise ValueError('Batch size must be divisible by the number of replicas.')

  steps_per_epoch = CFG.steps_per_epoch
  logging.info('steps per epoch: %d', steps_per_epoch)

  # `broadcast` is train_lib.broadcast pre-bound to this host's layout.
  broadcast = functools.partial(
      train_lib.broadcast,
      num_replicas=topology.per_replica_set_num_replicas,
      num_partitions=topology.per_host_num_partitions,
      devices=topology.this_host_device_assignment)

  # Only host 0 writes summaries; other hosts hold None and every use below is
  # guarded by `jax.host_id() == 0`.
  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    tf.io.gfile.copy(FLAGS['config'].config_filename,
                     os.path.join(FLAGS.model_dir, 'config.py'),
                     overwrite=True)
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  # Write summaries in background thread to avoid blocking on device sync
  # NOTE(review): no summary thread is created in this function; only the
  # infeed pool below exists - the comment above may be stale.
  if CFG.infeed:
    # Infeed is currently synchronous, so do it in a background thread too
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  (train_ds, eval_ds), eval_cache = input_pipeline.get_datasets_and_cache(
      CFG, topology.num_replica_sets, topology.replica_set_id,
      topology.per_replica_set_host_id)

  vocab = input_pipeline.get_vocabulary(CFG.mixture_or_task_name)
  encoder = vocab.tf_tokenizer
  eos_id = vocab.tokenizer.eos_id()

  def decode_tokens(toks, eos_id = eos_id, max_id = 32000):
    """Decode tokens back to unicode."""
    # `eos_id` is accepted but deliberately unused (see the TODO below).
    del eos_id
    # TODO(levskaya): T5 doesn't seem to emit EOS tokens? double check this
    # is the best decoding function or just switch to using tf_decode.
    # valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    # astype() returns a copy by default, so the in-place overwrite below
    # does not touch the caller's array.
    valid_toks = toks.astype(np.int32)
    # Ids at or above max_id are replaced with id 3.
    # NOTE(review): presumably 3 is a safe placeholder id in this vocabulary
    # - confirm.
    valid_toks[valid_toks >= max_id] = 3
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  train_config, eval_config, predict_config = get_configs(CFG)

  rng = random.PRNGKey(CFG.random_seed)
  rng, init_rng = random.split(rng)

  # This is used for infeed conversion from feature dict <--> tuple
  train_keys = [
      'inputs', 'targets', 'inputs_position', 'targets_position',
      'inputs_segmentation', 'targets_segmentation'
  ]
  # One (per-replica batch, length) shape per key. `'inputs' in k` is a
  # substring test, so every 'inputs*' key gets max_input_length and every
  # other key gets max_target_length.
  device_train_input_shape = tuple([
      (batch_size // topology.num_replicas,
       CFG.max_input_length if 'inputs' in k else CFG.max_target_length)
      for k in train_keys
  ])

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      factors=CFG.schedule,
      base_learning_rate=CFG.learning_rate,
      warmup_steps=CFG.warmup_steps)

  # First, we only abstractly initialize the optimizer and model parameters,
  # since the parameters may not even fit in device memory!
  # TODO(jekbradbury): make optimizer_defs compare by value so it can be created
  # in get_initial_params without causing pytree incompatibility
  optimizer_def = optim.Adafactor(
      CFG.learning_rate, decay_rate=0.8, step_offset=CFG.step_offset)
  initialize_params_fn = functools.partial(
      get_initial_params,
      config=CFG,
      transformer_config=eval_config,
      optimizer_def=optimizer_def)
  optimizer = jax.eval_shape(initialize_params_fn, init_rng)
  # tuple-like pytree leaves for global_arg_shapes
  optimizer_shapes = jax.tree_map(lambda x: partitions.Spec(*x.shape),
                                  optimizer)

  # Build parameter partition annotations for preserving partitions from train
  # to eval.
  # These two names exist only when num_partitions > 1; every later use is
  # behind the same condition.
  if num_partitions > 1:
    optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(num_partitions, optimizer.state_dict()))
    per_host_optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(topology.per_host_num_partitions,
                                  optimizer.state_dict()))

  # Restore unreplicated optimizer + model state from last checkpoint.
  # TODO(jekbradbury,levskaya): implement sharded native checkpoint/restore
  existing_checkpoint_found = False
  if CFG.restore_checkpoints:
    existing_checkpoint_found = train_lib.checkpoint_exists(FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)

  # Import a pretrained-T5 checkpoint only if we didn't import a local
  # "native" checkpoint (e.g. due to resuming a pre-empted finetuning run.)
  # TODO(jekbradbury,levskaya): implement sharded T5 checkpoint/restore
  if CFG.restore_t5_checkpoint and not existing_checkpoint_found:
    optimizer = checkpoint_importer.restore_from_t5_checkpoint(
        optimizer, CFG.restore_t5_checkpoint)

  if CFG.restore_t5_checkpoint or existing_checkpoint_found:
    if num_partitions > 1:
      # Until checkpoint/restore is sharded, the restored checkpoint is global
      # and we need to slice each sharded parameter into the chunk containing
      # only the partitions that are present on this host.
      def per_host_chunk(x, spec):
        """Return this host's slice of `x` along the axis `spec` partitions.

        Only specs with a 1 in position 0 or position 1 are handled; any
        other spec raises NotImplementedError.
        """
        if spec is None or spec is x:  # unsharded or not a parameter
          return x
        if spec[0] == 1:
          dim_size = x.shape[1]
        elif spec[1] == 1:
          dim_size = x.shape[0]
        else:
          raise NotImplementedError()
        chunk_size = (
            dim_size * topology.per_host_num_partitions // num_partitions)
        lower = topology.per_replica_set_host_id * chunk_size
        upper = (topology.per_replica_set_host_id + 1) * chunk_size
        if spec[0] == 1:
          return x[:, lower:upper]
        else:
          return x[lower:upper]

      optimizer = jax.tree_multimap(per_host_chunk, optimizer,
                                    optimizer_partitions)
  else:
    # If pretraining and no checkpoint imported, we jit the (sharded-) init
    # function to minimize fragmentation. We use the same pmap(sharded_jit)
    # setup as the training step/loop to initialize everything "in-place" and
    # avoid communication or OOM.
    if num_partitions > 1:
      initialize_params_fn = sharded_jit(
          initialize_params_fn,
          in_parts=None,
          local_in_parts=None,
          out_parts=optimizer_partitions,
          local_out_parts=per_host_optimizer_partitions,
          # devices=one_replica_device_assignment,
      )
      initialize_params_fn = jax.pmap(
          initialize_params_fn,
          'batch',
          in_axes=0,
          axis_size=topology.num_replicas,
          devices=topology.device_assignment)
      init_rng = broadcast(init_rng)
      optimizer = initialize_params_fn(init_rng)
      # We maintain the optimizer in unbroadcasted form (i.e. with no leading
      # replica axis). This is equivalent to the as-yet-nonexistent pmap kwarg
      # out_axes=None.
      optimizer = train_lib.unbroadcast(optimizer)
    else:
      optimizer = jax.jit(initialize_params_fn)(init_rng)

  # ---------------------------------------------------------------------------
  # Compile multidevice versions of train/eval/predict step and cache init fn.
  # ---------------------------------------------------------------------------

  # We can use either a single train-step for a host training loop:

  # train_step(optimizer, batch, prev_metrics, dropout_rng, **kwargs)
  #  --> new_optimizer, metrics, new_dropout_rng
  def p_train_step(optimizer, batch, prev_metrics, dropout_rng):
    return train_lib.train_step(
        optimizer,
        batch,
        prev_metrics,
        dropout_rng,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss,
        use_bfloat16=CFG.use_bfloat16)

  if num_partitions > 1:
    p_train_step = sharded_jit(
        p_train_step,
        in_parts=(optimizer_partitions, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None),
        out_parts=(optimizer_partitions, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None))
  # TODO(levskaya): the in_axes spec below might be wrong, double-check.
  p_train_step = jax.pmap(
      p_train_step,
      axis_name='batch',
      in_axes=(None, 0, 0, 0),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # OR, we use an on-device loop that feeds the training step via infeed queue.
  def device_train_loop_cond(
      args
  ):
    """Stopping criterion for on-device loop."""
    _, _, _, _, step, epoch = args
    # Keep looping while `step` still falls inside `epoch`.
    return step // steps_per_epoch == epoch

  def device_train_loop_body(
      args
  ):
    """On-device loop body."""
    optimizer, dropout_rngs, metrics, token, step, epoch = args
    # Ordering input data from infeed requires threading a symbolic token
    # through the computation.
    input_data, token = lax.infeed(
        token,
        shape=tuple(
            [jax.ShapedArray(s, jnp.int32) for s in device_train_input_shape]))
    # Rebuild input dict from infeed data tuple.
    batch = {k: v for k, v in zip(train_keys, input_data)}
    # Run the train_step function and return the loop state.
    # NOTE(review): unlike p_train_step above, this call does not pass
    # use_bfloat16 - confirm train_step's default is what is wanted here.
    optimizer, metrics, dropout_rngs = train_lib.train_step(
        optimizer,
        batch,
        metrics,
        dropout_rngs,
        train_config,
        learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss)
    step += 1
    return optimizer, dropout_rngs, metrics, token, step, epoch

  def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
    """Run train steps in a `lax.while_loop` until `step` leaves `epoch`."""
    # Create symbolic token for threading infeed data.
    token = lax.create_token(step)
    # Run on-device loop.
    optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, dropout_rngs, metrics, token, step, epoch))
    return optimizer, dropout_rngs, metrics, step

  if num_partitions > 1:
    device_train_loop = sharded_jit(
        device_train_loop,
        in_parts=(optimizer_partitions, None, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None, None),
        out_parts=(optimizer_partitions, None, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None, None))
  p_train_epoch = jax.pmap(
      device_train_loop,
      axis_name='batch',
      in_axes=(None, 0, 0, None, None),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Reduction psum for metric data.

  def p_allreduce_metrics(x):
    return lax.psum(x, axis_name='batch')

  if num_partitions > 1:
    p_allreduce_metrics = sharded_jit(
        p_allreduce_metrics,
        in_parts=None,
        local_in_parts=None,
        out_parts=None,
        local_out_parts=None,
        num_partitions=num_partitions,
        local_num_partitions=topology.per_host_num_partitions)
  p_allreduce_metrics = jax.pmap(
      p_allreduce_metrics,
      axis_name='batch',
      global_arg_shapes=None,
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)

  # Training evaluation computation.

  # eval_step(params, batch, config, label_smoothing=0.0) --> metrics
  def p_eval_step(params, batch):
    return train_lib.eval_step(
        params, batch, config=eval_config, label_smoothing=CFG.label_smoothing)

  if num_partitions > 1:
    p_eval_step = sharded_jit(
        p_eval_step,
        in_parts=(optimizer_partitions.target, None),
        local_in_parts=(per_host_optimizer_partitions.target, None),
        out_parts=None,
        local_out_parts=None)
  p_eval_step = jax.pmap(
      p_eval_step,
      axis_name='batch',
      in_axes=(None, 0),
      global_arg_shapes=(optimizer_shapes.target, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Fast autoregressive decoding loop.
  # For inference and model evaluation.

  # predict_step(inputs, params,
  #              eos_id, max_decode_len, config, beam_size=4) --> beam_seqs
  def p_pred_step(inputs, params):
    return train_lib.predict_step(inputs, params, eos_id,
                                  CFG.max_eval_target_length, predict_config,
                                  CFG.beam_size)

  if num_partitions > 1:
    p_pred_step = sharded_jit(
        p_pred_step,
        in_parts=(None, optimizer_partitions.target),
        local_in_parts=(None, per_host_optimizer_partitions.target),
        out_parts=None,
        local_out_parts=None)
  p_pred_step = jax.pmap(
      p_pred_step,
      axis_name='batch',
      in_axes=(0, None),
      global_arg_shapes=(None, optimizer_shapes.target),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # ---------------------------------------------------------------------------
  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  # There should be a unique dropout key for each replica represented on this
  # host, but the key should be the same for the same replica on other hosts.
  # Again, this is what the replica set abstraction is for.
  dropout_rngs = random.split(
      random.fold_in(rng, topology.replica_set_id),
      topology.per_replica_set_num_replicas)
  # restore step from last checkpoint
  host_step = int(optimizer.state.step)
  empty_metrics = broadcast({
      'loss': 0.0,
      'accuracy': 0.0,
      'learning_rate': 0.0,
      'denominator': 0.0
  })
  if CFG.infeed:
    # TODO(jekbradbury): support something like this for the Python-loop case
    logging.info('Precompiling training loop and moving optimizer to device.')
    # With step 0 and epoch 1 the loop condition (0 // steps_per_epoch == 1)
    # is false on entry, so this call runs no training step.
    optimizer, _, metrics, _ = p_train_epoch(optimizer, dropout_rngs,
                                             empty_metrics,
                                             jnp.array(0, dtype=jnp.int32), 1)
    optimizer = train_lib.unbroadcast(optimizer)
    metrics['loss'].block_until_ready()

  logging.info('Starting training loop.')

  local_devices = jax.local_devices()
  device_step = broadcast(host_step)
  first_epoch = host_step // steps_per_epoch

  # Main Loop over "epochs".
  train_iter = train_ds.as_numpy_iterator()
  for epoch in range(first_epoch, first_epoch + CFG.num_epochs):
    metrics = empty_metrics

    # NOTE: 'optimizer' is unbroadcast by construction at initialization or
    # when loading a checkpoint. It is maintained in 'unbroadcast' state to
    # enable the XLA cross-replica sharding optimization. The broadcasting is
    # handled automatically by the pmap'd functions that use it.

    # Gather all task evaluation metrics.
    logging.info('Evaluating tasks.')
    if epoch == first_epoch + 1:
      train_lib.sync_devices()
    for task in eval_cache.tasks:
      logging.info('Evaluating task %s', task.name)
      all_predicted, all_bs = [], []
      for pred_batch in eval_cache.preprocessed_examples[task.name]:
        # Handle final odd-sized batch by padding instead of dropping it.
        input_batch, unpadded_batch_size = train_lib.pad_batch_to_size(
            pred_batch['inputs'], per_replica_set_eval_batch_size)
        all_bs.append(unpadded_batch_size)
        # Split batch dimensions for pmap.
        input_batch = jax.tree_map(
            lambda x: x.reshape(
                (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
            input_batch)
        # Run fast inference on batch.
        all_predicted.append(p_pred_step(input_batch, optimizer.target))

      # Pad out the number of batches so each host has the same number.
      max_host_batch_number = np.max(
          eval_cache.preprocessed_batch_sizes[task.name])
      batch_shortfall = max_host_batch_number - len(all_predicted)
      if batch_shortfall > 0:
        # TODO(levskaya): Fix for case of entirely empty all_predicted.
        # To make sure the cross-host barriers work, we run the program the same
        # number of times on all hosts. The results of this call is ignored, and
        # the predictions are populated with zeros instead.
        # NOTE(review): the dummy call runs once even when batch_shortfall is
        # larger than 1 - confirm that matches the call count on other hosts.
        p_pred_step(input_batch, optimizer.target)  # Dummy call.
        all_predicted.extend([jnp.zeros_like(all_predicted[0])] *
                             batch_shortfall)
        all_bs.extend([0] * batch_shortfall)
      all_predicted = jnp.concatenate(all_predicted)
      all_bs = jnp.array(all_bs)

      # Collect all batches from across hosts and reverse sharding.
      all_predicted = train_lib.host_allgather(
          all_predicted, topology.num_replica_sets, topology.replica_set_id,
          topology.per_replica_set_host_id == 0)
      seqlength = all_predicted.shape[-1]
      total_examples = np.sum(
          train_lib.host_allgather(all_bs, topology.num_replica_sets,
                                   topology.replica_set_id,
                                   topology.per_replica_set_host_id == 0))
      del all_bs
      assert total_examples == len(eval_cache.examples[task.name]), (
          'Total number of batches incorrect for task %s.' % task.name)
      # De-shard the collected predicted tokens and remove padding.
      all_predicted = np.transpose(all_predicted, (1, 2, 0, 3)).reshape(
          -1, seqlength)[:total_examples]

      # We now run the post-processing and metric-fns on a single host.
      if jax.host_id() == 0:
        assert eval_summary_writer
        raw_predictions = []
        for tokens in all_predicted:
          raw_predictions.append(decode_tokens(tokens))

        # post-process predictions for metric fns
        predictions = [
            task.postprocess_fn(p, example=ex)
            for p, ex in zip(raw_predictions, eval_cache.examples[task.name])
        ]

        for metric_fn in task.metric_fns:
          scores = metric_fn(eval_cache.targets[task.name], predictions)
          for metric_name, metric_value in scores.items():
            tag = f'eval/{task.name}/{metric_name}'
            eval_summary_writer.scalar(tag, metric_value, host_step)
            logging.info('EVAL %s at step %d: %.3f', tag, host_step,
                         metric_value)
          eval_summary_writer.flush()

        # Save text samples for tensorboard.
        # np.random.choice samples with replacement by default, so the 8
        # picks may repeat.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
          tgt_txt = tf.compat.as_text(
              eval_cache.examples[task.name][n]['targets_plaintext'])
          pred_txt = raw_predictions[n]
          exemplars += (f'{eval_cache.inputs[task.name][n]}\n\n'
                        f'target: {tgt_txt}\n\n'
                        f'prediction: {pred_txt}\n\n')
        eval_summary_writer.text(f'{task.name} samples', exemplars, host_step)
        eval_summary_writer.flush()

    # Take an Xprof trace after the first loop has compiled everything.
    if epoch == first_epoch + 1:
      train_lib.sync_devices()

    # For on-device loop, we launch the computation before feeding data.
    logging.info('BEGIN Train loop.')
    if CFG.infeed:
      optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
          optimizer, dropout_rngs, metrics, train_lib.unbroadcast(device_step),
          epoch)
      optimizer = train_lib.unbroadcast(optimizer)

    # Epoch loop.
    while int(host_step // steps_per_epoch) == epoch:
      batch = next(train_iter)
      batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]), batch)
      # Feed the on-device training loop.
      if CFG.infeed:
        for i, device in enumerate(local_devices):
          # When using infeed to provide data to the computation, we're on our
          # own for feeding the right values to the right devices. Each device
          # should get the minibatch corresponding to its replica, a slice of
          # the larger batch corresponding to the host's replica set.
          if device.platform == 'tpu':
            device_coords = (*device.coords, device.id % 2)
          else:
            device_coords = (device.host_id, i)
          per_replica_set_device_coords = tuple(
              dc % prsm
              for dc, prsm in zip(device_coords, topology.per_replica_set_mesh))
          per_replica_set_replica_coords = tuple(
              prsdc // prm for prsdc, prm in zip(per_replica_set_device_coords,
                                                 topology.per_replica_mesh))
          # Flatten the replica coordinates into a single replica index.
          per_replica_set_replica_id = 0
          for prsm, prm, prsrc in zip(topology.per_replica_set_mesh,
                                      topology.per_replica_mesh,
                                      per_replica_set_replica_coords):
            per_replica_set_replica_id = (
                per_replica_set_replica_id * prsm // prm + prsrc)
          input_tuple = tuple(
              [batch[k][per_replica_set_replica_id] for k in train_keys])
          # Safety check: infeed does not check shape or types but requires
          # them to agree with on-device spec, otherwise the queue and program
          # stalls.
          tuple_shapes = jax.tree_map(jnp.shape, input_tuple)
          tuple_dtypes = jax.tree_map(lambda x: x.dtype, input_tuple)
          assert tuple_shapes == device_train_input_shape, (
              'infeed shape error %s != %s' %
              (tuple_shapes, device_train_input_shape))
          assert tuple(set(tuple_dtypes)) == (jnp.int32,), \
              ('infeed dtype error %s not all of type %s' % (
                  tuple_dtypes, jnp.int32))
          infeed_pool.submit(
              functools.partial(device.transfer_to_infeed, input_tuple))
      # Host training loop.
      else:
        optimizer, metrics, dropout_rngs = p_train_step(optimizer, batch,
                                                        metrics, dropout_rngs)
        optimizer = train_lib.unbroadcast(optimizer)
      # Advanced in both modes; this is what ends the while loop.
      host_step += 1
    logging.info('END Train loop.')

    # Maybe save a checkpoint on one host.
    if (CFG.save_checkpoints and
        epoch % CFG.checkpoint_freq == CFG.checkpoint_freq - 1 and
        jax.host_id() == 0):
      checkpoints.save_checkpoint(FLAGS.model_dir, optimizer, host_step)

    # Gather training metrics.
    metrics = p_allreduce_metrics(metrics)
    # Take element 0 of each reduced metric and copy it to the host.
    metrics = jax.tree_map(lambda x: jax.device_get(x[0]), metrics)
    denominator = metrics.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics)  # pylint: disable=cell-var-from-loop
    logging.info('train in step: %s, %s', host_step, summary)
    if jax.host_id() == 0:
      assert train_summary_writer
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, host_step)
      train_summary_writer.flush()

    # Gather training evaluation metrics.
    logging.info('Gathering training evaluation metrics.')
    eval_metrics = []
    eval_iter = eval_ds.as_numpy_iterator()
    for _, eval_batch in zip(range(CFG.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
          eval_batch)
      # Rebinding `metrics` here is harmless: it is reset to empty_metrics at
      # the top of the next epoch.
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    # average metrics across devices
    eval_metrics = p_allreduce_metrics(eval_metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    # average metrics across steps
    eval_metrics = jax.tree_map(np.sum, eval_metrics)
    eval_denominator = eval_metrics.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics)
    logging.info('eval in step: %s, %s', host_step, eval_summary)
    if jax.host_id() == 0:
      assert eval_summary_writer
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, host_step)
      eval_summary_writer.flush()

  # Wait until computations are done before exiting
  logging.info('Finished.')
  train_lib.sync_devices()
  # Shut down the infeed threadpool.
  if CFG.infeed:
    infeed_pool.shutdown()
def f(x):
    """Return `random.gamma` evaluated at `x` with a fixed PRNG key.

    The key is always `random.PRNGKey(0)`, so repeated calls with the same
    `x` use the same key.
    """
    fixed_key = random.PRNGKey(0)
    return random.gamma(fixed_key, x)
def renyi_loss_fn(x):
    """Return the RenyiELBO loss for `x`.

    Builds a `RenyiELBO` with the enclosing scope's `alpha` and 10 particles,
    then calls its `loss` with the fixed key `random.PRNGKey(0)`, an empty
    dict, the enclosing scope's `model` and `guide`, and `x`.
    """
    objective = RenyiELBO(alpha=alpha, num_particles=10)
    return objective.loss(random.PRNGKey(0), {}, model, guide, x)