Example #1
    def initialize(self,
                   y=None,
                   num_subunits=1,
                   dt=0.033,
                   method='random',
                   compute_ci=True,
                   random_seed=2046,
                   verbose=0,
                   add_noise_to_mle=0):

        self.init_method = method  # store meta
        self.num_subunits = num_subunits
        self.compute_ci = compute_ci

        if method == 'random':

            self.b['random'] = {}
            self.w['random'] = {}
            self.intercept['random'] = {}
            if verbose:
                print('Initializing model parameters randomly...')

            for i, name in enumerate(self.filter_names):
                self.intercept['random'][name] = 0.
                key = random.PRNGKey(random_seed +
                                     i)  # change random seed for each filter
                if name in self.S:
                    self.b['random'][name] = random.normal(
                        key, shape=(self.XS['train'][name].shape[1],
                                    1)).astype(self.dtype)
                    self.w['random'][
                        name] = self.S[name] @ self.b['random'][name]
                else:
                    self.w['random'][name] = random.normal(
                        key, shape=(self.X['train'][name].shape[1],
                                    1)).astype(self.dtype)
            self.intercept['random']['global'] = 0.

            if verbose:
                print('Finished.')

        elif method == 'mle':

            if verbose:
                print(
                    'Initializing model parameters with maximum likelihood...')

            if not self.mle_computed:
                self.compute_mle(y)

            if verbose:
                print('Finished.')

        else:
            raise ValueError(f'`{method}` is not supported.')

        # rename and repmat: stimulus filter to subunits filters
        # subunit model only works with one stimulus.
        filter_names = self.filter_names.copy()
        if num_subunits != 1:
            filter_names.remove('stimulus')
            filter_names = [f'stimulus_s{i}'
                            for i in range(num_subunits)] + filter_names

            for name in filter_names:
                if 'stimulus' in name:
                    self.dims[name] = self.dims['stimulus']
                    self.df[name] = self.df['stimulus']
                    self.shift[name] = self.shift['stimulus']
                    self.filter_nonlinearity[name] = self.filter_nonlinearity[
                        'stimulus']
                    self.intercept[method][name] = self.intercept[method][
                        'stimulus']
                    self.w[method][name] = self.w[method]['stimulus']

                    if method in self.w_se:
                        self.w_se[method][name] = self.w_se[method]['stimulus']
                    self.X['train'][name] = self.X['train']['stimulus']

                    if 'dev' in self.X:
                        self.X['dev'][name] = self.X['dev']['stimulus']

                    if 'stimulus' in self.S:
                        self.b[method][name] = self.b[method]['stimulus']
                        if method in self.b_se:
                            self.b_se[method][name] = self.b_se[method][
                                'stimulus']
                        self.XS['train'][name] = self.XS['train']['stimulus']

                        if 'dev' in self.XS:
                            self.XS['dev'][name] = self.XS['dev']['stimulus']

                        self.S[name] = self.S['stimulus']

            self.b[method].pop('stimulus', None)
            self.w[method].pop('stimulus')
            self.intercept[method].pop('stimulus')
            self.X['train'].pop('stimulus')
            self.X['dev'].pop('stimulus')
            if self.XS != {}:
                self.XS['train'].pop('stimulus')
                self.XS['dev'].pop('stimulus')
                self.S.pop('stimulus')

            self.filter_names = filter_names

        self.p[method] = {}
        p0 = {}
        for i, name in enumerate(self.filter_names):
            if name in self.S:
                b = self.b[method][name]
                key = random.PRNGKey(random_seed + i)
                self.p[method].update({name: b})
                noise = add_noise_to_mle * random.normal(
                    key, shape=b.shape).astype(self.dtype)
                p0.update({name: b + noise})

            else:
                w = self.w[method][name]
                key = random.PRNGKey(random_seed + i)
                self.p[method].update({name: w})
                noise = add_noise_to_mle * random.normal(
                    key, shape=w.shape).astype(self.dtype)
                p0.update({name: w + noise})

            self.p[method].update({'intercept': self.intercept[method]})
            p0.update({'intercept': self.intercept[method]})

        self.dt = dt
        self.p0 = p0
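
The initializer above derives one key per filter by offsetting the integer seed with random.PRNGKey(random_seed + i). A minimal sketch of that pattern next to the more idiomatic random.split alternative (illustrative names, not part of the project):

from jax import random

seed, n_filters = 2046, 3
keys_by_offset = [random.PRNGKey(seed + i) for i in range(n_filters)]  # pattern used above
base_key = random.PRNGKey(seed)
keys_by_split = random.split(base_key, n_filters)                      # array of 3 independent keys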
Example #2
 def testParetoShape(self):
     key = random.PRNGKey(0)
     x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
     assert x.shape == (3, 2)
Example #3
 def testFoldIn(self):
     key = random.PRNGKey(0)
     keys = [random.fold_in(key, i) for i in range(10)]
     assert np.unique(np.ravel(keys)).shape == (20, )
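
In the legacy key representation used by these tests, each PRNGKey is an array of two uint32 words, so ten distinct folded keys flatten to twenty values. A small illustration under that assumption:

from jax import random
import numpy as np

key = random.PRNGKey(0)
print(key.shape)                    # (2,): each key is a pair of uint32 words
children = [random.fold_in(key, i) for i in range(10)]
print(np.ravel(children).shape)     # (20,): ten derived keys flatten to twenty words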
Example #4
 def testBernoulliShape(self):
     key = random.PRNGKey(0)
     x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
     assert x.shape == (3, 2)
Example #5
 def testPoissonBatched(self):
     key = random.PRNGKey(0)
     lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)])
     samples = random.poisson(key, lam, shape=(20000, ))
     self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
     self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf)
Example #6
def main(args):
    X, Y, expected_thetas, expected_pairwise = get_data(
        N=args.num_data, P=args.num_dimensions, S=args.active_dimensions
    )

    # setup hyperparameters
    hypers = {
        "expected_sparsity": max(1.0, args.num_dimensions / 10),
        "alpha1": 3.0,
        "beta1": 1.0,
        "alpha2": 3.0,
        "beta2": 1.0,
        "alpha3": 1.0,
        "c": 1.0,
    }

    # do inference
    rng_key = random.PRNGKey(0)
    samples = run_inference(model, args, rng_key, X, Y, hypers)

    # compute the posterior mean and standard deviation of each coefficient theta_i
    means, stds = vmap(lambda dim: analyze_dimension(samples, X, Y, dim, hypers))(
        jnp.arange(args.num_dimensions)
    )

    print(
        "Coefficients theta_1 to theta_%d used to generate the data:"
        % args.active_dimensions,
        expected_thetas,
    )
    print(
        "The single quadratic coefficient theta_{1,2} used to generate the data:",
        expected_pairwise,
    )
    active_dimensions = []

    for dim, (mean, std) in enumerate(zip(means, stds)):
        # we mark the dimension as inactive if the interval [mean - 3 * std, mean + 3 * std] contains zero
        lower, upper = mean - 3.0 * std, mean + 3.0 * std
        inactive = "inactive" if lower < 0.0 and upper > 0.0 else "active"
        if inactive == "active":
            active_dimensions.append(dim)
        print(
            "[dimension %02d/%02d]  %s:\t%.2e +- %.2e"
            % (dim + 1, args.num_dimensions, inactive, mean, std)
        )

    print(
        "Identified a total of %d active dimensions; expected %d."
        % (len(active_dimensions), args.active_dimensions)
    )

    # Compute the posterior mean and standard deviation of coefficients theta_ij for i, j active dimensions.
    # Note that the resulting numbers are only meaningful for i != j.
    if len(active_dimensions) > 0:
        dim_pairs = jnp.array(
            list(itertools.product(active_dimensions, active_dimensions))
        )
        means, stds = vmap(
            lambda dim_pair: analyze_pair_of_dimensions(
                samples, X, Y, dim_pair[0], dim_pair[1], hypers
            )
        )(dim_pairs)
        for dim_pair, mean, std in zip(dim_pairs, means, stds):
            dim1, dim2 = dim_pair
            if dim1 >= dim2:
                continue
            lower, upper = mean - 3.0 * std, mean + 3.0 * std
            if not (lower < 0.0 and upper > 0.0):
                format_str = "Identified pairwise interaction between dimensions %d and %d: %.2e +- %.2e"
                print(format_str % (dim1 + 1, dim2 + 1, mean, std))

        # Draw a single sample of coefficients theta from the posterior, where we return all singleton
        # coefficients theta_i and pairwise coefficients theta_ij for i, j active dimensions. We use the
        # final MCMC sample obtained from the HMC sampler.
        thetas = sample_theta_space(
            X,
            Y,
            active_dimensions,
            samples["msq"][-1],
            samples["lambda"][-1],
            samples["eta1"][-1],
            samples["xisq"][-1],
            hypers["c"],
            samples["sigma"][-1],
        )
        print("Single posterior sample theta:\n", thetas)
Example #7
encoder_init, encode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    FanOut(2),
    stax.parallel(Dense(10), stax.serial(Dense(10), Softplus)),
)

decoder_init, decode = stax.serial(
    Dense(512), Relu,
    Dense(512), Relu,
    Dense(28 * 28),
)


if __name__ == "__main__":
  step_size = 0.001
  num_epochs = 100
  batch_size = 32
  nrow, ncol = 10, 10  # sampled image grid size
  rng = random.PRNGKey(0)

  test_rng = random.PRNGKey(1)  # fixed prng key for evaluation
  imfile = os.path.join(os.getenv("TMPDIR", "/tmp/"), "mnist_vae_{:03d}.png")

  train_images, _, test_images, _ = datasets.mnist(permute_train=True)
  num_complete_batches, leftover = divmod(train_images.shape[0], batch_size)
  num_batches = num_complete_batches + bool(leftover)

  _, init_encoder_params = encoder_init((batch_size, 28 * 28))
  _, init_decoder_params = decoder_init((batch_size, 10))
  init_params = init_encoder_params, init_decoder_params

  opt_init, opt_update = optimizers.momentum(step_size, mass=0.9)

  def binarize_batch(rng, i, images):
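
The listing is cut off here. In the upstream JAX VAE example, a helper like binarize_batch folds the batch index into the key and Bernoulli-samples the pixels; the sketch below follows that pattern as an assumption and is not part of this excerpt:

from jax import lax, random

def binarize_batch_sketch(rng, i, images, batch_size=32):
    # Fold the batch index into the key so each batch gets fresh Bernoulli noise,
    # then sample binary pixels with probabilities given by the grayscale values.
    rng = random.fold_in(rng, i)
    batch = lax.dynamic_slice_in_dim(images, i * batch_size, batch_size)
    return random.bernoulli(rng, batch)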
Example #8
 def x():
     return random.normal(random.PRNGKey(1), (2, 3), dtype)
Example #9
 def x_cpu():
     return device_get(
         random.normal(random.PRNGKey(1), (2, 3), dtype))
Example #10
    return ce_loss, acc


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--hidden', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=400)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--lr', type=float, default=0.005)
    args = parser.parse_args()

    # Load data
    adj, features, labels, idx_train, idx_val, idx_test = load_data()

    rng_key = random.PRNGKey(args.seed)
    step_size = args.lr
    num_epochs = args.epochs
    n_nodes = adj.shape[0]
    n_feats = features.shape[1]

    # GAT params
    nheads = [8, 1]
    nhid = [8]
    dropout = args.dropout  # probability of keeping
    residual = False

    init_fun, predict_fun = GAT(nheads=nheads,
                                nhid=nhid,
                                nclass=labels.shape[1],
                                dropout=dropout,
Example #11
    def __init__(
        self,
        w_in: np.ndarray,
        w_recurrent: np.ndarray,
        w_out: np.ndarray,
        tau: np.ndarray,
        bias: np.ndarray,
        noise_std: float = 0.0,
        activation_func: Callable[[FloatVector], FloatVector] = H_ReLU,
        dt: Optional[float] = None,
        name: Optional[str] = None,
        rng_key: Optional[int] = None,
    ):
        """
        RecRateEulerJax - ``JAX``-backed firing rate reservoir

        :param np.ndarray w_in:                     Input weights [IxN]
        :param np.ndarray w_recurrent:              Recurrent weights [NxN]
        :param np.ndarray w_out:                    Output weights [NxO]
        :param np.ndarray tau:                      Time constants [N]
        :param np.ndarray bias:                     Bias values [N]
        :param float noise_std:           White noise standard deviation applied to reservoir neurons. Default: ``0.0``
        :param Callable[[FloatVector], float] activation_func:   Neuron transfer function f(x: float) -> float. Must be vectorised. Default: H_ReLU
        :param Optional[float] dt:                  Reservoir time step. Default: ``np.min(tau) / 10.0``
        :param Optional[str] name:                  Name of the layer. Default: ``None``
        :param Optional[Jax RNG key] rng_key:       Jax RNG key to use for noise. Default: Internally generated
        """

        # - Everything should be 2D
        w_in = np.atleast_2d(w_in)
        w_recurrent = np.atleast_2d(w_recurrent)
        w_out = np.atleast_2d(w_out)

        # transform to np.array if necessary
        tau = np.array(tau)
        bias = np.array(bias)

        # - Get information about network size
        self._size_in = w_in.shape[0]
        self._size = w_in.shape[1]
        self._size_out = w_out.shape[1]

        # -- Set properties
        self.w_recurrent = w_recurrent
        self.w_out = w_out
        self.tau = tau
        self.bias = bias
        self._H = activation_func

        if dt is None:
            dt = np.min(tau) / 10.0

        # - Call super-class initialisation
        super().__init__(w_in, dt, noise_std, name)

        # - Correct layer size
        self._size_in = w_in.shape[0]
        self._size_out = w_out.shape[1]

        # - Get compiled evolution function
        self._evolve_jit = _get_rec_evolve_jit(activation_func)

        # - Reset layer state
        self.reset_all()

        # - Seed RNG
        if rng_key is None:
            rng_key = rand.PRNGKey(onp.random.randint(0, 2 ** 63))
        _, self._rng_key = rand.split(rng_key)
Example #12
    def test_jit_or_pmap_broadcast(self):
        def kernel_fn(x1,
                      x2,
                      do_flip,
                      keys,
                      do_square,
                      params,
                      _unused=None,
                      p=0.65):
            res = np.abs(np.matmul(x1, x2))
            if do_square:
                res *= res
            if do_flip:
                res = -res

            res *= random.uniform(keys) * p
            return [res, params]

        params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
        x2 = np.arange(0, 10).reshape((10, ))
        keys = random.PRNGKey(1)

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=0)
        x1 = np.arange(0, 10).reshape((1, 10))
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=0):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=True,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=True)
                    self.assertAllClose(res_1, res_2)

        test_utils.stub_out_pmap(batch, 1)
        x1 = np.arange(0, 10).reshape((1, 10))
        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=1)
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=1):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=False,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None)
                    self.assertAllClose(res_1[0], res_2[0])
                    self.assertAllClose(
                        tree_map(partial(np.expand_dims, axis=0), res_1[1]),
                        res_2[1])

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=2)
        x1 = np.arange(0, 20).reshape((2, 10))
        test_utils.stub_out_pmap(batch, 2)

        def broadcast(arg):
            return np.broadcast_to(arg, (2, ) + arg.shape)

        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=2):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      p=0.2)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None,
                                              p=0.2)
                    self.assertAllClose(res_1[0][0], res_2[0][0])
                    self.assertAllClose(res_1[0][1], res_2[0][1])
                    self.assertAllClose(tree_map(broadcast, res_1[1]),
                                        res_2[1])
Example #13
def main(_):
    sns.set()
    sns.set_palette(sns.color_palette('hls', 10))
    npr.seed(FLAGS.seed)

    logging.info('Starting experiment.')

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.work_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.work_dir), 'w+')

    # BEGIN: fetch test data and candidate pool
    test_images, test_labels, _ = datasets.get_dataset_split(
        name=FLAGS.test_split.split('-')[0],
        split=FLAGS.test_split.split('-')[1],
        shuffle=False)
    pool_images, pool_labels, _ = datasets.get_dataset_split(
        name=FLAGS.pool_split.split('-')[0],
        split=FLAGS.pool_split.split('-')[1],
        shuffle=False)

    n_pool = len(pool_images)
    # normalize to range [-1.0, 127./128]
    test_images = test_images / np.float32(128.0) - np.float32(1.0)
    pool_images = pool_images / np.float32(128.0) - np.float32(1.0)

    # augmentation for train/pool data
    if FLAGS.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    # END: fetch test data and candidate pool

    _, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    # BEGIN: load ckpt
    ckpt_dir = '{}/{}'.format(FLAGS.root_dir, FLAGS.ckpt_idx)
    with gfile.Open(ckpt_dir, 'rb') as fckpt:  # open checkpoint for binary reading so pickle.load works
        opt_state = optimizers.pack_optimizer_state(pickle.load(fckpt))
    params = get_params(opt_state)

    stdout_log.write('finetune from: {}\n'.format(ckpt_dir))
    logging.info('finetune from: %s', ckpt_dir)
    test_acc, test_pred = accuracy(params,
                                   shape_as_image(test_images, test_labels),
                                   return_predicted_class=True)
    logging.info('test accuracy: %.2f', test_acc)
    stdout_log.write('test accuracy: {}\n'.format(test_acc))
    stdout_log.flush()
    # END: load ckpt

    # BEGIN: setup for dp model
    @jit
    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad_loss(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

    # END: setup for dp model

    ### BEGIN: prepare extra points picked from pool data
    # BEGIN: on pool data
    pool_embeddings = [apply_fn_0(params[:-1],
                                  pool_images[b_i:b_i + FLAGS.batch_size]) \
                       for b_i in range(0, n_pool, FLAGS.batch_size)]
    pool_embeddings = np.concatenate(pool_embeddings, axis=0)

    pool_logits = apply_fn_1(params[-1:], pool_embeddings)

    pool_true_labels = np.argmax(pool_labels, axis=1)
    pool_predicted_labels = np.argmax(pool_logits, axis=1)
    pool_correct_indices = \
        onp.where(pool_true_labels == pool_predicted_labels)[0]
    pool_incorrect_indices = \
        onp.where(pool_true_labels != pool_predicted_labels)[0]
    assert len(pool_correct_indices) + \
        len(pool_incorrect_indices) == len(pool_labels)

    pool_probs = stax.softmax(pool_logits)

    if FLAGS.uncertain == 0 or FLAGS.uncertain == 'entropy':
        pool_entropy = -onp.sum(pool_probs * onp.log(pool_probs), axis=1)
        stdout_log.write('all {} entropy: min {}, max {}\n'.format(
            len(pool_entropy), onp.min(pool_entropy), onp.max(pool_entropy)))

        pool_entropy_sorted_indices = onp.argsort(pool_entropy)
        # take the n_extra most uncertain points
        pool_uncertain_indices = \
            pool_entropy_sorted_indices[::-1][:FLAGS.n_extra]
        stdout_log.write('uncertain {} entropy: min {}, max {}\n'.format(
            len(pool_entropy[pool_uncertain_indices]),
            onp.min(pool_entropy[pool_uncertain_indices]),
            onp.max(pool_entropy[pool_uncertain_indices])))

    elif FLAGS.uncertain == 1 or FLAGS.uncertain == 'difference':
        # 1st_prob - 2nd_prob
        assert len(pool_probs.shape) == 2
        sorted_pool_probs = onp.sort(pool_probs, axis=1)
        pool_probs_diff = sorted_pool_probs[:, -1] - sorted_pool_probs[:, -2]
        assert min(pool_probs_diff) > 0.
        stdout_log.write('all {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff), onp.min(pool_probs_diff),
            onp.max(pool_probs_diff)))

        pool_uncertain_indices = onp.argsort(pool_probs_diff)[:FLAGS.n_extra]
        stdout_log.write('uncertain {} difference: min {}, max {}\n'.format(
            len(pool_probs_diff[pool_uncertain_indices]),
            onp.min(pool_probs_diff[pool_uncertain_indices]),
            onp.max(pool_probs_diff[pool_uncertain_indices])))

    elif FLAGS.uncertain == 2 or FLAGS.uncertain == 'random':
        pool_uncertain_indices = npr.permutation(n_pool)[:FLAGS.n_extra]

    # END: on pool data

    ### END: prepare extra points picked from pool data

    finetune_images = copy.deepcopy(pool_images[pool_uncertain_indices])
    finetune_labels = copy.deepcopy(pool_labels[pool_uncertain_indices])

    stdout_log.write('Starting fine-tuning...\n')
    logging.info('Starting fine-tuning...')
    stdout_log.flush()

    stdout_log.write('{} points picked via {}\n'.format(
        len(finetune_images), FLAGS.uncertain))
    logging.info('%d points picked via %s', len(finetune_images),
                 FLAGS.uncertain)
    assert FLAGS.n_extra == len(finetune_images)

    for epoch in range(1, FLAGS.epochs + 1):

        # BEGIN: finetune model with extra data, evaluate and save
        num_extra = len(finetune_images)
        num_complete_batches, leftover = divmod(num_extra, FLAGS.batch_size)
        num_batches = num_complete_batches + bool(leftover)

        finetune = data.DataChunk(X=finetune_images,
                                  Y=finetune_labels,
                                  image_size=28,
                                  image_channels=1,
                                  label_dim=1,
                                  label_format='numeric')

        batches = data.minibatcher(finetune,
                                   FLAGS.batch_size,
                                   transform=augmentation)

        itercount = itertools.count()
        key = random.PRNGKey(FLAGS.seed)

        start_time = time.time()

        for _ in range(num_batches):
            # tmp_time = time.time()
            b = next(batches)
            if FLAGS.dpsgd:
                opt_state = private_update(
                    key, next(itercount), opt_state,
                    shape_as_image(b.X, b.Y, dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(b.X, b.Y))
            # stdout_log.write('single update in {:.2f} sec\n'.format(
            #     time.time() - tmp_time))

        epoch_time = time.time() - start_time
        stdout_log.write('Epoch {} in {:.2f} sec\n'.format(epoch, epoch_time))
        logging.info('Epoch %d in %.2f sec', epoch, epoch_time)

        # accuracy on test data
        params = get_params(opt_state)

        test_pred_0 = test_pred
        test_acc, test_pred = accuracy(params,
                                       shape_as_image(test_images,
                                                      test_labels),
                                       return_predicted_class=True)
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        stdout_log.write(
            'Eval set loss, accuracy (%): ({:.2f}, {:.2f})\n'.format(
                test_loss, 100 * test_acc))
        logging.info('Eval set loss, accuracy: (%.2f, %.2f)', test_loss,
                     100 * test_acc)
        stdout_log.flush()

        # visualize prediction difference between 2 checkpoints.
        if FLAGS.visualize:
            utils.visualize_ckpt_difference(test_images,
                                            np.argmax(test_labels, axis=1),
                                            test_pred_0,
                                            test_pred,
                                            epoch - 1,
                                            epoch,
                                            FLAGS.work_dir,
                                            mu=128.,
                                            sigma=128.)

    # END: finetune model with extra data, evaluate and save

    stdout_log.close()
Example #14
PeDurx = 70
batch_size = 32
Pencoder = PoissonEncoder(duration=PeDurx)
optimizer = SGD(model.parameters(), lr=0.01)

train_loader = DataLoader(dataset=MnistDataset(training=True, flatten=True),
                          collate_fn=collate_fn,
                          shuffle=True,
                          batch_size=batch_size)

test_loader = DataLoader(dataset=MnistDataset(training=False, flatten=True),
                         collate_fn=collate_fn,
                         shuffle=False,
                         batch_size=batch_size)

for epoch in range(15):
    for i, (data, target) in enumerate(train_loader):
        target = Variable(target)
        for t in range(PeDurx):
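            # Note: random.PRNGKey(0) is rebuilt on every step, so rnum is the same array at each t.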
            rnum = random.uniform(key=random.PRNGKey(0), shape=data.shape)
            uin = (jnp.abs(data) / 2 > rnum).astype('float32')
            q = jnp.multiply(uin, jnp.sign(data))
            output, time = model(Variable(q), t)
            print(str(t))

        loss = F.Spikeloss(output, target, time_step=time)
        loss.backward()  # calc gradients
        optimizer.step()  # update parameters
        print("Epoch:" + str(epoch) + "Time:" + str(i) + "loss:" +
              str(loss.data))
Example #15
learning_rate = 0.18
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)


@jit
def update(_, i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)


"""The next cell contains our training loop, very similar to Problem 1."""

num_epochs = 10

key = random.PRNGKey(123)
_, init_params = init_random_params(key, (-1, 28, 28, 1))
opt_state = opt_init(init_params)
itercount = itertools.count()

test_losses = []
test_accs = []
for epoch in range(1, num_epochs + 1):
    for _ in range(num_batches):
        opt_state = update(key, next(itercount), opt_state,
                           shape_as_image(*next(batches)))

    params = get_params(opt_state)
    # print("Params are: {} ".format(params))
    test_acc = accuracy(params, shape_as_image(test_images, test_labels))
    test_loss = loss(params, shape_as_image(test_images, test_labels))
Example #16
def _wrap_seed_jax(seed, _):
  import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
  return jaxrand.PRNGKey(seed % (2**32 - 1))
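
A quick usage sketch (hypothetical seed value): the modulo folds arbitrarily large seeds into the 32-bit range before the key is built.

import jax.random as jaxrand

seed = 2**40 + 17                          # wider than 32 bits
key = jaxrand.PRNGKey(seed % (2**32 - 1))  # reduced seed fits in uint32
print(key)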
Example #17
    s += pytree_dot2(x, y)
    s += pytree_dot2(x, y)
    return s


def internal_only_jit(x, y):
    s = pytree_dot(x, y)
    s += pytree_dot(x, y)
    s += pytree_dot(x, y)
    s += pytree_dot(x, y)
    s += pytree_dot(x, y)
    return s


if __name__ == "__main__":
    k1, k2, k3 = random.split(random.PRNGKey(3), 3)
    x = [np.array([0.5, 0.5]), random.uniform(k1, shape=(50000,)), {"beta": 0.5}]
    y = [np.array([0.5, 0.5]), random.uniform(k2, shape=(50000,)), {"beta": 0.5}]
    z = [np.array([0.5, 0.5]), random.uniform(k3, shape=(50000,)), {"beta": 0.5}]
    x1 = x
    s = 0.0

    start_time = datetime.now()
    for i in range(10000):
        if i % 2 == 0:
            x1 = z
        else:
            x1 = x
        s = _jit(x1, y)
    end_time = datetime.now()
    print("Std Jit Duration: {}".format(end_time - start_time))
Example #18
    def run(self, output_h5parm, ncpu, avg_direction_spacing,
            field_of_view_diameter, duration, time_resolution, start_time,
            array_name, phase_tracking):

        Nd = get_num_directions(
            avg_direction_spacing,
            field_of_view_diameter,
        )
        Nf = 2  # 8000
        Nt = int(duration / time_resolution) + 1
        min_freq = 700.
        max_freq = 2000.
        dp = create_empty_datapack(
            Nd,
            Nf,
            Nt,
            pols=None,
            field_of_view_diameter=field_of_view_diameter,
            start_time=start_time,
            time_resolution=time_resolution,
            min_freq=min_freq,
            max_freq=max_freq,
            array_file=ARRAYS[array_name],
            phase_tracking=(phase_tracking.ra.deg, phase_tracking.dec.deg),
            save_name=output_h5parm,
            clobber=True)

        with dp:
            dp.current_solset = 'sol000'
            dp.select(pol=slice(0, 1, 1))
            axes = dp.axes_tec
            patch_names, directions = dp.get_directions(axes['dir'])
            antenna_labels, antennas = dp.get_antennas(axes['ant'])
            timestamps, times = dp.get_times(axes['time'])
            ref_ant = antennas[0]
            ref_time = times[0]

        Na = len(antennas)
        Nd = len(directions)
        Nt = len(times)

        logger.info(f"Number of directions: {Nd}")
        logger.info(f"Number of antennas: {Na}")
        logger.info(f"Number of times: {Nt}")
        logger.info(f"Reference Ant: {ref_ant}")
        logger.info(f"Reference Time: {ref_time.isot}")

        # Plot Antenna Layout in East North Up frame
        ref_frame = ENU(obstime=ref_time, location=ref_ant.earth_location)

        _antennas = ac.ITRS(*antennas.cartesian.xyz,
                            obstime=ref_time).transform_to(ref_frame)
        # plt.scatter(_antennas.east, _antennas.north, marker='+')
        # plt.xlabel(f"East (m)")
        # plt.ylabel(f"North (m)")
        # plt.show()

        x0 = ac.ITRS(
            *antennas[0].cartesian.xyz,
            obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
                au.km).value
        earth_centre_x = ac.ITRS(
            x=0 * au.m, y=0 * au.m, z=0. * au.m,
            obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
                au.km).value
        self._kernel = TomographicKernel(x0,
                                         earth_centre_x,
                                         M32(),
                                         S_marg=20,
                                         compute_tec=False)

        k = directions.transform_to(ref_frame).cartesian.xyz.value.T

        t = times.mjd * 86400.
        t -= t[0]

        X1 = GeodesicTuple(x=[], k=[], t=[], ref_x=[])

        logger.info("Computing coordinates in frame ...")

        for i, time in tqdm(enumerate(times)):
            x = ac.ITRS(*antennas.cartesian.xyz,
                        obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                            au.km).value.T
            ref_ant_x = ac.ITRS(
                *ref_ant.cartesian.xyz,
                obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                    au.km).value

            X = make_coord_array(x,
                                 k,
                                 t[i:i + 1, None],
                                 ref_ant_x[None, :],
                                 flat=True)

            X1.x.append(X[:, 0:3])
            X1.k.append(X[:, 3:6])
            X1.t.append(X[:, 6:7])
            X1.ref_x.append(X[:, 7:8])

        X1 = X1._replace(
            x=jnp.concatenate(X1.x, axis=0),
            k=jnp.concatenate(X1.k, axis=0),
            t=jnp.concatenate(X1.t, axis=0),
            ref_x=jnp.concatenate(X1.ref_x, axis=0),
        )

        logger.info(f"Total number of coordinates: {X1.x.shape[0]}")

        def compute_covariance_row(X1: GeodesicTuple, X2: GeodesicTuple):
            K = self._kernel(X1,
                             X2,
                             self._bottom,
                             self._width,
                             self._fed_sigma,
                             self._fed_kernel_params,
                             wind_velocity=self._wind_vector)  # 1, N
            return K[0, :]

        covariance_row = lambda X: compute_covariance_row(
            tree_map(lambda x: x.reshape((1, -1)), X), X1)

        mean = jit(lambda X1: self._kernel.mean_function(X1,
                                                         self._bottom,
                                                         self._width,
                                                         self._fed_mu,
                                                         wind_velocity=self.
                                                         _wind_vector))(X1)

        cov = chunked_pmap(covariance_row,
                           X1,
                           batch_size=X1.x.shape[0],
                           chunksize=ncpu)

        plt.imshow(cov)
        plt.show()

        Z = random.normal(random.PRNGKey(42), (cov.shape[0], 1),
                          dtype=cov.dtype)

        t0 = default_timer()
        jitter = 1e-6
        logger.info(f"Computing Cholesky with jitter: {jitter}")
        L = jnp.linalg.cholesky(cov + jitter * jnp.eye(cov.shape[0]))
        if np.any(np.isnan(L)):
            logger.info("Numerically unstable. Using SVD.")
            L = msqrt(cov)

        logger.info(f"Cholesky took {default_timer() - t0} seconds.")

        dtec = (L @ Z + mean[:, None])[:, 0].reshape((Na, Nd, Nt)).transpose(
            (1, 0, 2))

        logger.info(f"Saving result to {output_h5parm}")
        with dp:
            dp.current_solset = 'sol000'
            dp.select(pol=slice(0, 1, 1))
            dp.tec = np.asarray(dtec[None])
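
The draw above is the standard recipe of taking the Cholesky factor of (K + jitter·I) and computing L @ z + mean to sample a correlated Gaussian; a minimal standalone sketch with toy numbers:

import jax.numpy as jnp
from jax import random

cov = jnp.array([[1.0, 0.8], [0.8, 1.0]])
mean = jnp.zeros(2)
L = jnp.linalg.cholesky(cov + 1e-6 * jnp.eye(2))  # jitter for numerical stability
z = random.normal(random.PRNGKey(42), (2, 1))
sample = (L @ z)[:, 0] + mean                     # one draw from N(mean, cov)
print(sample)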
Example #19
def main(_):

    if FLAGS.microbatches:
        raise NotImplementedError(
            'Microbatches < batch size not currently supported')

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
    num_batches = num_complete_batches + bool(leftover)
    key = random.PRNGKey(FLAGS.seed)

    def data_stream():
        rng = npr.RandomState(FLAGS.seed)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * FLAGS.batch_size:(i + 1) *
                                 FLAGS.batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)

    @jit
    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    @jit
    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            private_grad(params, batch, rng, FLAGS.l2_norm_clip,
                         FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)

    _, init_params = init_random_params(key, (-1, 28, 28, 1))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    steps_per_epoch = 60000 // FLAGS.batch_size
    print('\nStarting training...')
    for epoch in range(1, FLAGS.epochs + 1):
        start_time = time.time()
        for _ in range(num_batches):
            if FLAGS.dpsgd:
                opt_state = \
                    private_update(
                        key, next(itercount), opt_state,
                        shape_as_image(*next(batches), dummy_dim=True))
            else:
                opt_state = update(key, next(itercount), opt_state,
                                   shape_as_image(*next(batches)))
        epoch_time = time.time() - start_time
        print('Epoch {} in {:0.2f} sec'.format(epoch, epoch_time))

        # evaluate test accuracy
        params = get_params(opt_state)
        test_acc = accuracy(params, shape_as_image(test_images, test_labels))
        test_loss = loss(params, shape_as_image(test_images, test_labels))
        print('Test set loss, accuracy (%): ({:.2f}, {:.2f})'.format(
            test_loss, 100 * test_acc))

        # determine privacy loss so far
        if FLAGS.dpsgd:
            delta = 1e-5
            num_examples = 60000
            eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)
            print('For delta={:.0e}, the current epsilon is: {:.2f}'.format(
                delta, eps))
        else:
            print('Trained with vanilla non-private SGD optimizer')
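
private_update above relies on random.fold_in(rng, i) to give every optimization step its own key derived from a single seed; a minimal sketch of that pattern:

from jax import random

rng = random.PRNGKey(0)
step_keys = [random.fold_in(rng, i) for i in range(3)]   # distinct key per step
noise = [random.normal(k, (2,)) for k in step_keys]      # independent noise draws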
Example #20
import flaxvision.models as flax_models
from flax import nn
import numpy as np
import jax.numpy as jnp
from jax import random
from jax.config import config
config.enable_omnistaging()

import unittest, os
import logging

RNG = random.PRNGKey(0)

MODELS_LIST = [
    'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'resnet18', 'resnet34',
    'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2',
    'inception_v3', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'fcn_resnet50', 'fcn_resnet101',
    'deeplabv3_resnet50', 'deeplabv3_resnet101'
]


class TestTraining(unittest.TestCase):

  def test_outputs(self):
    log = logging.getLogger(__name__)

    inputs = jnp.ones((1, 224, 224, 3))
    inception_inputs = jnp.ones((1, 299, 299, 3))

    for key in MODELS_LIST:
      log.info(f'testing inference {key}')
Example #21
 def testPermutationErrors(self):
     key = random.PRNGKey(0)
     with self.assertRaises(TypeError):
         random.permutation(key, 10.)
     with self.assertRaises(core.ConcretizationTypeError):
         api.jit(random.permutation)(key, 10)
Example #22
 def test_sample(self, dist, args, kwargs, out, flat):
     del out, flat
     args = args()
     kwargs = kwargs()
     p = dist(*args, **kwargs)
     p.sample(seed=random.PRNGKey(0))
Example #23
 def testGammaShape(self):
     key = random.PRNGKey(0)
     x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
     assert x.shape == (3, 2)
Example #24
    # Training data
    numSamples = inParameters['Training data']['number_of_training_samples']
    trainDataFileName = inParameters['Training data']['training_data']

    # Output
    outDir = inParameters['Output']['output_folder']

# Set numpy seed
np.random.seed(0)

# Create subdirectory for network checkpoints
create_dir(outDir + "/net_checkpoints/")

# Model setup
rnnNet = RNN2D.partial(L=L, units=rnnUnits, initScale=netInitScale)
_, params = rnnNet.init_by_shape(random.PRNGKey(netInitSeed), [(1, L, L)])
rnnModel = nn.Model(rnnNet, params)

# Optimizer setup
optimizer = flax.optim.Adam(learning_rate=learningRate,
                            beta1=beta1,
                            beta2=beta2).create(rnnModel)

# Load data
if inParameters['Training data']['training_data'] == "generate":
    print("*** Generating samples")
    numTestSamples = inParameters['Training data']['number_of_test_samples']
    if numTestSamples < numSamples:
        numTestSamples = numSamples
    trainData, trainEnergies, testData, testEnergies =\
        generate_samples(numTestSamples,T,L,
Example #25
 def testPoissonShape(self):
     key = random.PRNGKey(0)
     x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
     assert x.shape == (3, 2)
Example #26
File: train.py  Project: skye/flax
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # This seems to be necessary even when importing TF2?
    tf.enable_v2_behavior()

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_token = 2  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_token) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')
    # Build Model and Optimizer
    transformer_kwargs = {
        'vocab_size': vocab_size,
        'output_vocab_size': vocab_size,
        'emb_dim': FLAGS.emb_dim,
        'num_heads': FLAGS.num_heads,
        'num_layers': FLAGS.num_layers,
        'qkv_dim': FLAGS.qkv_dim,
        'mlp_dim': FLAGS.mlp_dim,
        'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        'share_embeddings': FLAGS.share_embeddings,
        'logits_via_embedding': FLAGS.logits_via_embedding,
    }

    start_step = 0
    rng = random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = random.split(rng)
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    model, cache_def = create_model(init_rng, input_shape, target_shape,
                                    transformer_kwargs)
    optimizer = create_optimizer(model, FLAGS.learning_rate,
                                 FLAGS.weight_decay)
    # We access model only from optimizer below via optimizer.target.
    del model

    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(
        eval_step,
        label_smoothing=FLAGS.label_smoothing,
        use_bfloat16=FLAGS.use_bfloat16),
                           axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          use_bfloat16=FLAGS.use_bfloat16,
                          beam_size=FLAGS.beam_size),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info('Starting training loop.')
    metrics_all = []
    t_loop_start = time.time()
    for step, batch in zip(range(start_step, FLAGS.num_train_steps),
                           train_iter):
        # Shard data to devices and do a training step.
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a checkpoint on one host after every checkpoint_freq steps.
        if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0
                and step > 0 and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir,
                                        jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling.
        if step % FLAGS.eval_frequency != 0 and step > 0:
            continue

        logging.info('Gathering training metrics.')
        # Training Metrics
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary['learning_rate'] = lr
        steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
        steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
        t_loop_start = time.time()
        if jax.host_id() == 0:
            train_summary_writer.scalar('steps per second', steps_per_sec,
                                        step)
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        metrics_all = []
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

        # Eval Metrics
        logging.info('Gathering evaluation metrics.')
        t_eval_start = time.time()
        eval_metrics = []
        eval_iter = iter(eval_ds)
        for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            eval_batch = common_utils.shard(eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        if jax.host_id() == 0:
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
        logging.info('eval in step: %d, loss: %.4f', step,
                     eval_summary['loss'])
        logging.info('eval time: %.4f s step %d',
                     time.time() - t_eval_start, step)

        # Translation and BLEU Score.
        logging.info('Translating evaluation dataset.')
        t_inference_start = time.time()
        predict_iter = iter(predict_ds)
        sources, references, predictions = [], [], []
        for _, pred_batch in enumerate(predict_iter):
            pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
            pred_batch = common_utils.shard(pred_batch)
            per_device_batchsize = pred_batch['inputs'].shape[1]
            cache_dtype = jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32
            cache = jax_utils.replicate(
                cache_def.initialize_cache(
                    (per_device_batchsize, FLAGS.max_predict_length),
                    dtype=cache_dtype))
            predicted = p_pred_step(pred_batch['inputs'], optimizer.target,
                                    cache, eos_token, FLAGS.max_predict_length)
            predicted = tohost(predicted)
            inputs = tohost(pred_batch['inputs'])
            targets = tohost(pred_batch['targets'])
            # Iterate through non-padding examples of batch.
            for i, s in enumerate(predicted[:cur_pred_batch_size]):
                sources.append(decode_tokens(inputs[i]))
                references.append(decode_tokens(targets[i]))
                predictions.append(decode_tokens(s))
        logging.info('Translation: %d predictions %d references %d sources.',
                     len(predictions), len(references), len(sources))
        logging.info('Translation time: %.4f s step %d.',
                     time.time() - t_inference_start, step)

        # Calculate BLEU score for translated eval corpus against reference.
        bleu_matches = bleu.bleu_partial(references, predictions)
        all_bleu_matches = per_host_sum_pmap(bleu_matches)
        bleu_score = bleu.complete_bleu(*all_bleu_matches)
        # Save translation samples for tensorboard.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
            exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
        if jax.host_id() == 0:
            eval_summary_writer.scalar('bleu', bleu_score, step)
            eval_summary_writer.text('samples', exemplars, step)
            eval_summary_writer.flush()
        logging.info('Translation BLEU Score %.4f', bleu_score)
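
The dropout keys above come from a single random.split(rng, n_devices), giving each device its own independent key for the pmapped train step; a minimal sketch with a hypothetical device count:

from jax import random

rng = random.PRNGKey(0)
n_devices = 4                                # hypothetical
dropout_rngs = random.split(rng, n_devices)  # one independent key per device
print(dropout_rngs.shape)                    # (4, 2) with legacy uint32 keys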
Example #27
 def testIssue222(self):
     x = random.randint(random.PRNGKey(10003), (), 0, 0)
     assert x == 0
Example #28
def main(argv):
  global CFG
  CFG = FLAGS.config

  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Guarantee that the JAX bfloat16 extension is used rather than TF bfloat16.
  _ = np.array(jnp.array([1.0], dtype=jnp.bfloat16))

  # Use hardware RNG for bernoulli randoms in dropout mask creation.
  if CFG.hardware_rng:
    models.set_hardware_bernoulli()

  if 'module_import' in CFG and CFG.module_import:
    for module in CFG.module_import:
      importlib.import_module(module)

  if 'additional_task_cache_dirs' in CFG and CFG.additional_task_cache_dirs:
    t5.data.add_global_cache_dirs(CFG.additional_task_cache_dirs)

  num_partitions = CFG.num_partitions
  topology = train_lib.compute_multihost_topology(num_partitions)
  batch_size = CFG.batch_size
  eval_batch_size = CFG.eval_batch_size
  per_replica_set_eval_batch_size = eval_batch_size // topology.num_replica_sets
  if batch_size % topology.num_replicas:
    raise ValueError('Batch size must be divisible by the number of replicas.')

  steps_per_epoch = CFG.steps_per_epoch
  logging.info('steps per epoch: %d', steps_per_epoch)

  broadcast = functools.partial(
      train_lib.broadcast,
      num_replicas=topology.per_replica_set_num_replicas,
      num_partitions=topology.per_host_num_partitions,
      devices=topology.this_host_device_assignment)

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    tf.io.gfile.copy(FLAGS['config'].config_filename,
                     os.path.join(FLAGS.model_dir, 'config.py'),
                     overwrite=True)
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  # Write summaries in background thread to avoid blocking on device sync
  if CFG.infeed:
    # Infeed is currently synchronous, so do it in a background thread too
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  (train_ds, eval_ds), eval_cache = input_pipeline.get_datasets_and_cache(
      CFG, topology.num_replica_sets, topology.replica_set_id,
      topology.per_replica_set_host_id)

  vocab = input_pipeline.get_vocabulary(CFG.mixture_or_task_name)
  encoder = vocab.tf_tokenizer
  eos_id = vocab.tokenizer.eos_id()

  def decode_tokens(toks,
                    eos_id = eos_id,
                    max_id = 32000):
    """Decode tokens back to unicode."""
    del eos_id
    # TODO(levskaya): T5 doesn't seem to emit EOS tokens?  double check this
    # is the best decoding function or just switch to using tf_decode.
    # valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    valid_toks = toks.astype(np.int32)
    valid_toks[valid_toks >= max_id] = 3
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  train_config, eval_config, predict_config = get_configs(CFG)

  rng = random.PRNGKey(CFG.random_seed)
  rng, init_rng = random.split(rng)
  # This is used for infeed conversion from feature dict <--> tuple
  train_keys = [
      'inputs', 'targets', 'inputs_position', 'targets_position',
      'inputs_segmentation', 'targets_segmentation'
  ]
  device_train_input_shape = tuple([
      (batch_size // topology.num_replicas,
       CFG.max_input_length if 'inputs' in k else CFG.max_target_length)
      for k in train_keys
  ])

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      factors=CFG.schedule,
      base_learning_rate=CFG.learning_rate,
      warmup_steps=CFG.warmup_steps)

  # First, we only abstractly initialize the optimizer and model parameters,
  # since the parameters may not even fit in device memory!
  # TODO(jekbradbury): make optimizer_defs compare by value so it can be created
  # in get_initial_params without causing pytree incompatibility
  optimizer_def = optim.Adafactor(
      CFG.learning_rate, decay_rate=0.8, step_offset=CFG.step_offset)
  initialize_params_fn = functools.partial(
      get_initial_params,
      config=CFG,
      transformer_config=eval_config,
      optimizer_def=optimizer_def)
  optimizer = jax.eval_shape(initialize_params_fn, init_rng)
  # tuple-like pytree leaves for global_arg_shapes
  optimizer_shapes = jax.tree_map(lambda x: partitions.Spec(*x.shape),
                                  optimizer)

  # Build parameter partition annotations for preserving partitions from train
  # to eval.
  if num_partitions > 1:
    optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(num_partitions, optimizer.state_dict()))
    per_host_optimizer_partitions = optimizer.restore_state(
        partitions.set_partitions(topology.per_host_num_partitions,
                                  optimizer.state_dict()))

  # Restore unreplicated optimizer + model state from last checkpoint.
  # TODO(jekbradbury,levskaya): implement sharded native checkpoint/restore
  existing_checkpoint_found = False
  if CFG.restore_checkpoints:
    existing_checkpoint_found = train_lib.checkpoint_exists(FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)

  # Import a pretrained-T5 checkpoint only if we didn't import a local
  # "native" checkpoint (e.g. due to resuming a pre-empted finetuning run.)
  # TODO(jekbradbury,levskaya): implement sharded T5 checkpoint/restore
  if CFG.restore_t5_checkpoint and not existing_checkpoint_found:
    optimizer = checkpoint_importer.restore_from_t5_checkpoint(
        optimizer, CFG.restore_t5_checkpoint)

  if CFG.restore_t5_checkpoint or existing_checkpoint_found:
    if num_partitions > 1:
      # Until checkpoint/restore is sharded, the restored checkpoint is global
      # and we need to slice each sharded parameter into the chunk containing
      # only the partitions that are present on this host.
      def per_host_chunk(x, spec):
        if spec is None or spec is x:  # unsharded or not a parameter
          return x
        if spec[0] == 1:
          dim_size = x.shape[1]
        elif spec[1] == 1:
          dim_size = x.shape[0]
        else:
          raise NotImplementedError()
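        # This host holds per_host_num_partitions of the num_partitions shards, i.e. a
        # contiguous chunk of the partitioned dimension.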
        chunk_size = (
            dim_size * topology.per_host_num_partitions // num_partitions)
        lower = topology.per_replica_set_host_id * chunk_size
        upper = (topology.per_replica_set_host_id + 1) * chunk_size
        if spec[0] == 1:
          return x[:, lower:upper]
        else:
          return x[lower:upper]

      optimizer = jax.tree_multimap(per_host_chunk, optimizer,
                                    optimizer_partitions)
  else:
    # If pretraining and no checkpoint imported, we jit the (sharded-) init
    # function to minimize fragmentation. We use the same pmap(sharded_jit)
    # setup as the training step/loop to initialize everything "in-place" and
    # avoid communication or OOM.
    if num_partitions > 1:
      initialize_params_fn = sharded_jit(
          initialize_params_fn,
          in_parts=None,
          local_in_parts=None,
          out_parts=optimizer_partitions,
          local_out_parts=per_host_optimizer_partitions,
          # devices=one_replica_device_assignment,
      )
      initialize_params_fn = jax.pmap(
          initialize_params_fn,
          'batch',
          in_axes=0,
          axis_size=topology.num_replicas,
          devices=topology.device_assignment)
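      # Replicate the init RNG so each local device gets a copy for the pmapped init.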
      init_rng = broadcast(init_rng)
      optimizer = initialize_params_fn(init_rng)
      # We maintain the optimizer in unbroadcasted form (i.e. with no leading
      # replica axis). This is equivalent to the as-yet-nonexistent pmap kwarg
      # out_axes=None.
      optimizer = train_lib.unbroadcast(optimizer)
    else:
      optimizer = jax.jit(initialize_params_fn)(init_rng)

  # ---------------------------------------------------------------------------
  # Compile multidevice versions of train/eval/predict step and cache init fn.
  # ---------------------------------------------------------------------------

  # We can use either a single train-step for a host training loop:

  # train_step(optimizer, batch, prev_metrics, dropout_rng, **kwargs)
  #  --> new_optimizer, metrics, new_dropout_rng
  def p_train_step(optimizer, batch, prev_metrics, dropout_rng):
    return train_lib.train_step(
        optimizer,
        batch,
        prev_metrics,
        dropout_rng,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss,
        use_bfloat16=CFG.use_bfloat16)

  if num_partitions > 1:
    p_train_step = sharded_jit(
        p_train_step,
        in_parts=(optimizer_partitions, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None),
        out_parts=(optimizer_partitions, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None))
  # TODO(levskaya): the in_axes spec below might be wrong, double-check.
  p_train_step = jax.pmap(
      p_train_step,
      axis_name='batch',
      in_axes=(None, 0, 0, 0),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # OR, we use an on-device loop that feeds the training step via infeed queue.
  def device_train_loop_cond(args):
    """Stopping criterion for on-device loop."""
    _, _, _, _, step, epoch = args
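    # Continue looping while `step` still falls inside the current epoch.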
    return step // steps_per_epoch == epoch

  def device_train_loop_body(args):
    """On-device loop body."""
    optimizer, dropout_rngs, metrics, token, step, epoch = args
    # Ordering input data from infeed requires threading a symbolic token
    # through the computation.
    input_data, token = lax.infeed(
        token,
        shape=tuple(
            [jax.ShapedArray(s, jnp.int32) for s in device_train_input_shape]))
    # Rebuild input dict from infeed data tuple.
    batch = {k: v for k, v in zip(train_keys, input_data)}
    # Run the train_step function and return the loop state.
    optimizer, metrics, dropout_rngs = train_lib.train_step(
        optimizer,
        batch,
        metrics,
        dropout_rngs,
        train_config,
        learning_rate_fn,
        num_microbatches=CFG.microbatches,
        label_smoothing=CFG.label_smoothing,
        z_loss=CFG.z_loss)
    step += 1
    return optimizer, dropout_rngs, metrics, token, step, epoch

  def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch):
    # Create symbolic token for threading infeed data.
    token = lax.create_token(step)
    # Run on-device loop.
    optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop(
        device_train_loop_cond, device_train_loop_body,
        (optimizer, dropout_rngs, metrics, token, step, epoch))
    return optimizer, dropout_rngs, metrics, step

  if num_partitions > 1:
    device_train_loop = sharded_jit(
        device_train_loop,
        in_parts=(optimizer_partitions, None, None, None, None),
        local_in_parts=(per_host_optimizer_partitions, None, None, None, None),
        out_parts=(optimizer_partitions, None, None, None),
        local_out_parts=(per_host_optimizer_partitions, None, None, None))
  p_train_epoch = jax.pmap(
      device_train_loop,
      axis_name='batch',
      in_axes=(None, 0, 0, None, None),
      donate_argnums=(0,),
      global_arg_shapes=(optimizer_shapes, None, None, None, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Reduction psum for metric data.

  def p_allreduce_metrics(x):
    return lax.psum(x, axis_name='batch')

  if num_partitions > 1:
    p_allreduce_metrics = sharded_jit(
        p_allreduce_metrics,
        in_parts=None,
        local_in_parts=None,
        out_parts=None,
        local_out_parts=None,
        num_partitions=num_partitions,
        local_num_partitions=topology.per_host_num_partitions)
  p_allreduce_metrics = jax.pmap(
      p_allreduce_metrics,
      axis_name='batch',
      global_arg_shapes=None,
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)

  # Training evaluation computation.

  # eval_step(params, batch, config, label_smoothing=0.0) --> metrics
  def p_eval_step(params, batch):
    return train_lib.eval_step(
        params, batch, config=eval_config, label_smoothing=CFG.label_smoothing)

  if num_partitions > 1:
    p_eval_step = sharded_jit(
        p_eval_step,
        in_parts=(optimizer_partitions.target, None),
        local_in_parts=(per_host_optimizer_partitions.target, None),
        out_parts=None,
        local_out_parts=None)
  p_eval_step = jax.pmap(
      p_eval_step,
      axis_name='batch',
      in_axes=(None, 0),
      global_arg_shapes=(optimizer_shapes.target, None),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # Fast autoregressive decoding loop.
  # For inference and model evaluation.

  # predict_step(inputs, params,
  #              eos_id, max_decode_len, config, beam_size=4) --> beam_seqs
  def p_pred_step(inputs, params):
    return train_lib.predict_step(inputs, params, eos_id,
                                  CFG.max_eval_target_length, predict_config,
                                  CFG.beam_size)

  if num_partitions > 1:
    p_pred_step = sharded_jit(
        p_pred_step,
        in_parts=(None, optimizer_partitions.target),
        local_in_parts=(None, per_host_optimizer_partitions.target),
        out_parts=None,
        local_out_parts=None)
  p_pred_step = jax.pmap(
      p_pred_step,
      axis_name='batch',
      in_axes=(0, None),
      global_arg_shapes=(None, optimizer_shapes.target),
      axis_size=topology.num_replicas,
      devices=topology.device_assignment)  # pytype: disable=wrong-arg-types

  # ---------------------------------------------------------------------------
  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  # There should be a unique dropout key for each replica represented on this
  # host, but the key should be the same for the same replica on other hosts.
  # Again, this is what the replica set abstraction is for.
  dropout_rngs = random.split(
      random.fold_in(rng, topology.replica_set_id),
      topology.per_replica_set_num_replicas)
  # restore step from last checkpoint
  host_step = int(optimizer.state.step)
  empty_metrics = broadcast({
      'loss': 0.0,
      'accuracy': 0.0,
      'learning_rate': 0.0,
      'denominator': 0.0
  })
  if CFG.infeed:
    # TODO(jekbradbury): support something like this for the Python-loop case
    logging.info('Precompiling training loop and moving optimizer to device.')
    optimizer, _, metrics, _ = p_train_epoch(optimizer, dropout_rngs,
                                             empty_metrics,
                                             jnp.array(0, dtype=jnp.int32), 1)
    optimizer = train_lib.unbroadcast(optimizer)
    metrics['loss'].block_until_ready()

  logging.info('Starting training loop.')

  local_devices = jax.local_devices()
  device_step = broadcast(host_step)
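  # Resume the epoch counter from the (possibly checkpoint-restored) step.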
  first_epoch = host_step // steps_per_epoch

  # Main Loop over "epochs".
  train_iter = train_ds.as_numpy_iterator()
  for epoch in range(first_epoch, first_epoch + CFG.num_epochs):
    metrics = empty_metrics

    # NOTE: 'optimizer' is unbroadcast by construction at initialization or
    # when loading a checkpoint. It is maintained in 'unbroadcast' state to
    # enable the XLA cross-replica sharding optimization.  The broadcasting is
    # handled automatically by the pmap'd functions that use it.

    # Gather all task evaluation metrics.
    logging.info('Evaluating tasks.')
    if epoch == first_epoch + 1:
      train_lib.sync_devices()
    for task in eval_cache.tasks:
      logging.info('Evaluating task %s', task.name)
      all_predicted, all_bs = [], []
      for pred_batch in eval_cache.preprocessed_examples[task.name]:
        # Handle final odd-sized batch by padding instead of dropping it.
        input_batch, unpadded_batch_size = train_lib.pad_batch_to_size(
            pred_batch['inputs'], per_replica_set_eval_batch_size)
        all_bs.append(unpadded_batch_size)
        # Split batch dimensions for pmap.
        input_batch = jax.tree_map(
            lambda x: x.reshape(
                (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
            input_batch)
        # Run fast inference on batch.
        all_predicted.append(p_pred_step(input_batch, optimizer.target))

      # Pad out the number of batches so each host has the same number.
      max_host_batch_number = np.max(
          eval_cache.preprocessed_batch_sizes[task.name])
      batch_shortfall = max_host_batch_number - len(all_predicted)
      if batch_shortfall > 0:
        # TODO(levskaya): Fix for case of entirely empty all_predicted.
        # To make sure the cross-host barriers work, we run the program the same
        # number of times on all hosts. The result of this call is ignored, and
        # the predictions are populated with zeros instead.
        p_pred_step(input_batch, optimizer.target)  # Dummy call.
        all_predicted.extend([jnp.zeros_like(all_predicted[0])] *
                             batch_shortfall)
        all_bs.extend([0] * batch_shortfall)
      all_predicted = jnp.concatenate(all_predicted)
      all_bs = jnp.array(all_bs)

      # Collect all batches from across hosts and reverse sharding.
      all_predicted = train_lib.host_allgather(
          all_predicted, topology.num_replica_sets, topology.replica_set_id,
          topology.per_replica_set_host_id == 0)
      seqlength = all_predicted.shape[-1]
      total_examples = np.sum(
          train_lib.host_allgather(all_bs, topology.num_replica_sets,
                                   topology.replica_set_id,
                                   topology.per_replica_set_host_id == 0))
      del all_bs
      assert total_examples == len(eval_cache.examples[task.name]), (
          'Total number of examples is incorrect for task %s.' % task.name)
      # De-shard the collected predicted tokens and remove padding.
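      # Interleave the gathered shards into a flat (example, seq) array and drop the padded rows.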
      all_predicted = np.transpose(all_predicted, (1, 2, 0, 3)).reshape(
          -1, seqlength)[:total_examples]

      # We now run the post-processing and metric-fns on a single host.
      if jax.host_id() == 0:
        assert eval_summary_writer
        raw_predictions = []
        for tokens in all_predicted:
          raw_predictions.append(decode_tokens(tokens))

        # post-process predictions for metric fns
        predictions = [
            task.postprocess_fn(p, example=ex)
            for p, ex in zip(raw_predictions, eval_cache.examples[task.name])
        ]

        for metric_fn in task.metric_fns:
          scores = metric_fn(eval_cache.targets[task.name], predictions)
          for metric_name, metric_value in scores.items():
            tag = f'eval/{task.name}/{metric_name}'
            eval_summary_writer.scalar(tag, metric_value, host_step)
            logging.info('EVAL %s at step %d: %.3f', tag, host_step,
                         metric_value)
          eval_summary_writer.flush()

        # Save text samples for tensorboard.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
          tgt_txt = tf.compat.as_text(
              eval_cache.examples[task.name][n]['targets_plaintext'])
          pred_txt = raw_predictions[n]
          exemplars += (f'{eval_cache.inputs[task.name][n]}\n\n'
                        f'target: {tgt_txt}\n\n'
                        f'prediction: {pred_txt}\n\n')
        eval_summary_writer.text(f'{task.name} samples', exemplars, host_step)
        eval_summary_writer.flush()

    # Take an Xprof trace after the first loop has compiled everything.
    if epoch == first_epoch + 1:
      train_lib.sync_devices()

    # For on-device loop, we launch the computation before feeding data.
    logging.info('BEGIN Train loop.')
    if CFG.infeed:
      optimizer, dropout_rngs, metrics, device_step = p_train_epoch(
          optimizer, dropout_rngs, metrics, train_lib.unbroadcast(device_step),
          epoch)
      optimizer = train_lib.unbroadcast(optimizer)

    # Epoch loop.
    while int(host_step // steps_per_epoch) == epoch:
      batch = next(train_iter)
      batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]), batch)
      # Feed the on-device training loop.
      if CFG.infeed:
        for i, device in enumerate(local_devices):
          # When using infeed to provide data to the computation, we're on our
          # own for feeding the right values to the right devices. Each device
          # should get the minibatch corresponding to its replica, a slice of
          # the larger batch corresponding to the host's replica set.
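          # Locate this device in the hardware mesh; on TPU, device.id % 2 is (likely) the
          # core index within the chip, appended to the chip coordinates.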
          if device.platform == 'tpu':
            device_coords = (*device.coords, device.id % 2)
          else:
            device_coords = (device.host_id, i)
          per_replica_set_device_coords = tuple(
              dc % prsm
              for dc, prsm in zip(device_coords, topology.per_replica_set_mesh))
          per_replica_set_replica_coords = tuple(
              prsdc // prm for prsdc, prm in zip(per_replica_set_device_coords,
                                                 topology.per_replica_mesh))
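          # Flatten the replica coordinates into a single replica index within this
          # host's replica set (mixed-radix, row-major over the mesh).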
          per_replica_set_replica_id = 0
          for prsm, prm, prsrc in zip(topology.per_replica_set_mesh,
                                      topology.per_replica_mesh,
                                      per_replica_set_replica_coords):
            per_replica_set_replica_id = (
                per_replica_set_replica_id * prsm // prm + prsrc)
          input_tuple = tuple(
              [batch[k][per_replica_set_replica_id] for k in train_keys])
          # Safety check: infeed does not check shape or types but requires
          # them to agree with on-device spec, otherwise the queue and program
          # stalls.
          tuple_shapes = jax.tree_map(jnp.shape, input_tuple)
          tuple_dtypes = jax.tree_map(lambda x: x.dtype, input_tuple)
          assert tuple_shapes == device_train_input_shape, (
              'infeed shape error %s != %s' %
              (tuple_shapes, device_train_input_shape))
          assert tuple(set(tuple_dtypes)) == (jnp.int32,), \
              ('infeed dtype error %s not all of type %s' % (
                  tuple_dtypes, jnp.int32))
          infeed_pool.submit(
              functools.partial(device.transfer_to_infeed, input_tuple))
      # Host training loop.
      else:
        optimizer, metrics, dropout_rngs = p_train_step(optimizer, batch,
                                                        metrics, dropout_rngs)
        optimizer = train_lib.unbroadcast(optimizer)
      host_step += 1
    logging.info('END Train loop.')

    # Maybe save a checkpoint on one host.
    if (CFG.save_checkpoints and
        epoch % CFG.checkpoint_freq == CFG.checkpoint_freq - 1 and
        jax.host_id() == 0):
      checkpoints.save_checkpoint(FLAGS.model_dir, optimizer, host_step)

    # Gather training metrics.
    metrics = p_allreduce_metrics(metrics)
    metrics = jax.tree_map(lambda x: jax.device_get(x[0]), metrics)
    denominator = metrics.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics)  # pylint: disable=cell-var-from-loop
    logging.info('train in step: %s, %s', host_step, summary)
    if jax.host_id() == 0:
      assert train_summary_writer
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, host_step)
      train_summary_writer.flush()

    # Gather training evaluation metrics.
    logging.info('Gathering training evaluation metrics.')
    eval_metrics = []
    eval_iter = eval_ds.as_numpy_iterator()
    for _, eval_batch in zip(range(CFG.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(
          lambda x: x.reshape(
              (topology.per_replica_set_num_replicas, -1) + x.shape[1:]),
          eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    # average metrics across devices
    eval_metrics = p_allreduce_metrics(eval_metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    # average metrics across steps
    eval_metrics = jax.tree_map(np.sum, eval_metrics)
    eval_denominator = eval_metrics.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics)
    logging.info('eval in step: %s, %s', host_step, eval_summary)
    if jax.host_id() == 0:
      assert eval_summary_writer
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, host_step)
      eval_summary_writer.flush()

  # Wait until computations are done before exiting
  logging.info('Finished.')
  train_lib.sync_devices()
  # Shut down the infeed threadpool.
  if CFG.infeed:
    infeed_pool.shutdown()
Example No. 29
0
 def f(x):
     return random.gamma(random.PRNGKey(0), x)
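 # jax.random.gamma supports gradients w.r.t. its shape argument, so e.g. jax.grad(f)(2.0)
 # should work via implicit reparameterization.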
Example No. 30
0
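 # Presumably numpyro's RenyiELBO; `model`, `guide`, and `alpha` are assumed to be
 # defined in the enclosing scope.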
 def renyi_loss_fn(x):
     return RenyiELBO(alpha=alpha,
                      num_particles=10).loss(random.PRNGKey(0), {}, model,
                                             guide, x)