コード例 #1
0
ファイル: train.py プロジェクト: lucifer2288/google-research
def eval_policy(policy, rng, state, model, test_ds, epoch):
    """Evaluates a fixed generation policy over one pass of `test_ds`.

    Args:
      policy: The policy to evaluate. Before use it is round-tripped through
        `flax.jax_utils.replicate` and `flax.jax_utils.unreplicate`.
      rng: PRNG key threaded through successive `eval_step_policy` calls. The
        updated key is NOT returned to the caller.
      state: Train state forwarded to `eval_step_policy`.
      model: Model forwarded to `eval_step_policy`.
      test_ds: Dataset, turned into an iterator via `util_fns.get_iterator`.
      epoch: Epoch number; only used in the log message.

    Returns:
      A dict mapping every metric key that does not contain 'batch' to its
      mean over all evaluated batches.

    Raises:
      IndexError: If `test_ds` yields no batches.
      KeyError: If the metrics have no 'nelbo' entry.
    """
    policy = flax.jax_utils.unreplicate(flax.jax_utils.replicate(policy))

    # Function is recompiled for this specific policy.
    per_batch = []
    for batch in util_fns.get_iterator(test_ds):
        step_metrics, rng = eval_step_policy(rng, batch, state, model, policy)
        # Keep metrics on device inside the loop; transfer once at the end.
        per_batch.append(step_metrics)

    # One device -> host transfer; flax's `unreplicate` keeps the first
    # replica of every leaf.
    per_batch = jax.device_get(flax.jax_utils.unreplicate(per_batch))

    # Average every non-'batch' metric across the evaluated batches.
    averaged = {}
    for key in [name for name in per_batch[0] if 'batch' not in name]:
        averaged[key] = np.mean([entry[key] for entry in per_batch])

    logging.info(f'eval policy epoch: {epoch}, nelbo: {averaged["nelbo"]:.4f}')

    return averaged
コード例 #2
0
ファイル: train.py プロジェクト: lucifer2288/google-research
def eval_model(p_eval_step, rng, state, test_ds, epoch):
    """Runs `p_eval_step` over every batch of `test_ds` for one epoch.

    Args:
      p_eval_step: Callable invoked as `p_eval_step(rng, batch, state)` and
        returning `(metrics, rng)`.
      rng: PRNG key threaded through successive `p_eval_step` calls.
      state: Train state forwarded to `p_eval_step`.
      test_ds: Dataset, turned into an iterator via `util_fns.get_iterator`.
      epoch: Epoch number; only used in the log message.

    Returns:
      A tuple `(metrics, rng)`. `metrics` maps every metric key that does not
      contain 'batch' to its mean over all batches. `rng` is the key returned
      by the last `p_eval_step` call (the input key if there were no batches).

    Raises:
      IndexError: If `test_ds` yields no batches.
      KeyError: If the metrics have no 'nelbo' entry.
    """
    tic = time.time()
    collected = []

    for batch in util_fns.get_iterator(test_ds):
        step_metrics, rng = p_eval_step(rng, batch, state)
        # Left on device here; everything is transferred in one go below.
        collected.append(step_metrics)

    # Single device -> host transfer of the first replica of each metric.
    host_metrics = jax.device_get(flax.jax_utils.unreplicate(collected))

    means = {}
    for key in host_metrics[0]:
        if 'batch' in key:
            continue  # Keys containing 'batch' are not averaged.
        means[key] = np.mean([m[key] for m in host_metrics])

    nelbo = means['nelbo']
    logging.info(f'Eval epoch took {time.time() - tic:.1f} seconds.')
    logging.info(f'eval epoch: {epoch}, nelbo: {nelbo:.4f}')

    return means, rng
コード例 #3
0
def eval_policy_and_sigma(policy, sigma, rng, state, model, dataset):
    """Evaluates a policy together with a generation order `sigma`.

    Args:
      policy: Policy forwarded to `eval_step_policy_and_sigma`.
      sigma: Generation order forwarded to `eval_step_policy_and_sigma`.
      rng: PRNG key threaded through successive steps. The updated key is
        NOT returned.
      state: Train state forwarded to `eval_step_policy_and_sigma`.
      model: Model forwarded to `eval_step_policy_and_sigma`.
      dataset: Dataset, turned into an iterator via `util_fns.get_iterator`.

    Returns:
      A tuple `(nelbo, stdev, num_samples)`:
        - `nelbo` is the mean of the per-batch 'nelbo' metric.
        - `stdev` is the standard deviation of that metric across batches.
        - `num_samples` is the number of batches evaluated. The log message
          calls these "samples".

    Raises:
      IndexError: If `dataset` yields no batches.
      KeyError: If the metrics have no 'nelbo' entry.
    """
    collected = []

    # Function is recompiled for this specific policy.
    for batch in util_fns.get_iterator(dataset):
        step_metrics, rng = eval_step_policy_and_sigma(rng, batch, state,
                                                       model, policy, sigma)
        # Left on device inside the loop; transferred once afterwards.
        collected.append(step_metrics)

    # Single device -> host transfer of the first replica of each metric.
    host_metrics = jax.device_get(flax.jax_utils.unreplicate(collected))

    # Mean and (population) std of every non-'batch' metric across batches.
    means, stdevs = {}, {}
    for key in host_metrics[0]:
        if 'batch' in key:
            continue
        column = [m[key] for m in host_metrics]
        means[key] = np.mean(column)
        stdevs[key] = np.std(column)

    nelbo = means['nelbo']
    stdev = stdevs['nelbo']
    num_samples = len(host_metrics)
    # Standard error of the mean over `num_samples` batch-level values.
    standard_error = stdev / np.sqrt(num_samples)
    logging.info(f'eval policy with sigma scores nelbo: {nelbo:.4f} +/- '
                 f'{stdev:.4f} on {num_samples} samples. So expected stdev '
                 f'is stdev/sqrt(n) is {standard_error:.4f}')

    return nelbo, stdev, num_samples
コード例 #4
0
ファイル: train.py プロジェクト: lucifer2288/google-research
def train_epoch(p_train_step, state, train_ds, batch_size, epoch, rng,
                kl_tracker):
    """Runs one training epoch and feeds per-timestep losses to `kl_tracker`.

    The whole epoch runs inside a single
    `jax.profiler.StepTraceAnnotation('train', step_num=state.step)`.

    Args:
      p_train_step: Callable invoked as `p_train_step(rng, batch, state)` and
        returning `(state, metrics, rng)`.
      state: Train state at the start of the epoch.
      train_ds: Dataset, turned into an iterator via `util_fns.get_iterator`.
      batch_size: Each batch's 't_batch' and 'nelbo_per_t_batch' metrics are
        reshaped to this size before being passed to `kl_tracker`.
      epoch: Epoch number; only used in the log message.
      rng: PRNG key threaded through successive `p_train_step` calls.
      kl_tracker: Object whose `update(t_batch, nelbo_per_t_batch)` method is
        called once per batch, in batch order.

    Returns:
      A tuple `(state, epoch_metrics, rng)`. `epoch_metrics` maps every metric
      key that does not contain 'batch' to its mean over the epoch.

    Raises:
      IndexError: If `train_ds` yields no batches.
      KeyError: If the metrics lack 'loss', 'nelbo' or 'ce'.
    """
    tic = time.time()
    collected = []

    batches = util_fns.get_iterator(train_ds)
    with jax.profiler.StepTraceAnnotation('train', step_num=state.step):
        for batch in batches:
            state, step_metrics, rng = p_train_step(rng, batch, state)
            # Left on device inside the loop; transferred once afterwards.
            collected.append(step_metrics)

    # Single device -> host transfer of the first replica of each metric.
    host_metrics = jax.device_get(flax.jax_utils.unreplicate(collected))

    # Every reshape is done before the first tracker update, so a malformed
    # batch fails before kl_tracker is touched.
    timesteps = [m['t_batch'].reshape(batch_size) for m in host_metrics]
    losses_per_t = [
        m['nelbo_per_t_batch'].reshape(batch_size) for m in host_metrics
    ]
    for ts, nelbo_per_t in zip(timesteps, losses_per_t):
        kl_tracker.update(ts, nelbo_per_t)

    epoch_metrics = {}
    for key in host_metrics[0]:
        if 'batch' not in key:
            epoch_metrics[key] = np.mean([m[key] for m in host_metrics])

    logging.info(f'Epoch took {time.time() - tic:.1f} seconds.')
    logging.info(
        f'train epoch: {epoch}, loss: {epoch_metrics["loss"]:.4f} '
        f'nelbo: {epoch_metrics["nelbo"]:.4f} ce: {epoch_metrics["ce"]:.4f}')

    return state, epoch_metrics, rng
コード例 #5
0
def compress_dataset(state, model, test_ds, sigma=None, policy=None):
    """Compresses a dataset and reports the achieved bits per dimension.

    Each example is encoded into its own `ans_template.Bitstream` with
    `model.encode_with_policy_and_sigma`. The batch is then decoded with
    `model.decode_with_policy_and_sigma` and checked for an exact round trip.

    Args:
      state: A train state; its `ema_params` are used for encoding/decoding.
      model: The model class that contains all necessary methods
        (`get_naive_policy`, `get_random_order`,
        `encode_with_policy_and_sigma`, `decode_with_policy_and_sigma`) and
        `config.data_shape`.
      test_ds: The dataset to compress. Each batch needs an 'image' entry.
        Its first two axes are merged into one because pmap is not used here
        (presumably (devices, per-device batch) -- confirm against caller).
      sigma: An optional order of the generative process. Must be given
        together with `policy`.
      policy: A policy describing which steps to take in parallel. An arange
        would do each step individually, but would be very slow. Must be
        given together with `sigma`.

    Returns:
      The final bits per dimension to compress the dataset, where each
      example is encoded with its own bitstream. This is the summed bitstream
      length divided by the number of examples and by
      `prod(model.config.data_shape)`.

    Raises:
      ValueError: If exactly one of `sigma` and `policy` is given, or if
        `test_ds` yields no examples.
      AssertionError: If a decoded batch differs from its input.
    """
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if (sigma is None) != (policy is None):
        raise ValueError(
            'sigma and policy must either both be given or both be None.')

    if sigma is None:
        logging.info('Compressing with random order since none was given.')
        policy = model.get_naive_policy(25)
        sigma = model.get_random_order(jax.random.PRNGKey(0))

    total_size = 0
    total_count = 0

    d = int(np.prod(model.config.data_shape))

    test_ds = util_fns.get_iterator(test_ds)
    for idx, batch in enumerate(test_ds):
        x = batch['image']
        x = x.reshape(-1, *x.shape[2:])  # Flatten, we are not using pmap.
        batch_size = x.shape[0]

        # Scale bits determine the rounding precision, 32 seems to work well.
        # No init_bits buffer is added; it is not necessary for this model
        # class.
        streams = [
            ans_template.Bitstream(scale_bits=32) for _ in range(batch_size)
        ]

        if idx == 0:
            logging.info('Initial total bits in bitstream: %d',
                         len(streams[0]))

        logging.info('Encoding...')
        start = time.time()
        streams = model.encode_with_policy_and_sigma(streams, state.ema_params,
                                                     x, policy, sigma)

        logging.info('Encoding took %.2f seconds', time.time() - start)

        # Add the length of every example's stream to the running total.
        total_size += sum(len(stream) for stream in streams)
        total_count += batch_size

        bits_per_dim = total_size / total_count / d

        logging.info('Encoded %d Current bits per dim %f', total_count,
                     bits_per_dim)

        logging.info('Decoding...')
        start = time.time()
        decoded, streams = model.decode_with_policy_and_sigma(
            streams, state.ema_params, policy, sigma, batch_size)
        del streams
        logging.info('Decoding took %.2f seconds', time.time() - start)

        # Lossless round-trip check. It is raised explicitly so that
        # `python -O` cannot strip it.
        coding_error = jnp.abs(x - decoded).sum()
        if coding_error != 0:
            raise AssertionError(f'Coding error non-zero: {coding_error}')

    if total_count == 0:
        # Without this guard the division below fails with ZeroDivisionError.
        raise ValueError('Cannot compute bits per dim: test_ds yielded no '
                         'examples.')

    bits_per_dim = total_size / total_count / d
    logging.info('Bits per dim %f', bits_per_dim)

    return bits_per_dim