示例#1
0
def create_dataset(split, mode, batch_size=None, return_iterator=True):
  """Builds the single-class input pipeline for `split` in the given `mode`.

  Each dataset listed in CONFIG.DATASETS is read from its TFRecord files,
  optionally truncated (train mode only), shuffled, repeated and batched on its
  own. The per-dataset streams are then mixed with equal weights, unbatched,
  decoded, preprocessed and re-batched with a static batch dimension.

  Args:
    split: Split name, used to locate TFRecord files and to look up the split
      size in DATASETS.
    mode: 'train' (augmentation on, CONFIG.TRAIN settings) or 'eval'
      (augmentation off, CONFIG.EVAL settings).
    batch_size: Batch size; when falsy, the mode's configured BATCH_SIZE is used.
    return_iterator: If True return `iter(dataset)`, else the dataset itself.

  Returns:
    An iterator over the dataset, or the `tf.data.Dataset` itself.

  Raises:
    ValueError: If `mode` is neither 'train' nor 'eval'.
  """
  per_class = CONFIG.DATA.PER_CLASS

  if mode not in ('train', 'eval'):
    raise ValueError('Unidentified mode: %s. Use either train or eval.' % mode)

  is_train = mode == 'train'
  mode_cfg = CONFIG.TRAIN if is_train else CONFIG.EVAL
  if not batch_size:
    batch_size = mode_cfg.BATCH_SIZE
  num_steps = mode_cfg.NUM_FRAMES

  # Augmentation is only applied while training.
  def preprocess_fn(video, labels, seq_label, seq_len, name):
    return sample_and_preprocess(
        video, labels, seq_label, seq_len, name, num_steps,
        augment=is_train, add_shape=True)

  fraction = CONFIG.DATA.PER_DATASET_FRACTION

  datasets = []
  with tf.device('/cpu:0'):
    for dataset_name in CONFIG.DATASETS:
      record_files = get_tfrecords(
          dataset_name, split, CONFIG.PATH_TO_TFRECORDS, per_class=per_class)
      records = tf.data.TFRecordDataset(
          record_files, num_parallel_reads=FLAGS.num_parallel_calls)

      split_size = DATASETS[dataset_name][split]
      if fraction != 1.0 and is_train:
        # Keep at least one example when training on a fraction of the data.
        num_samples = max(1, int(fraction * split_size))
        records = records.take(num_samples)
      else:
        num_samples = split_size

      queue_size = CONFIG.DATA.SHUFFLE_QUEUE_SIZE
      # A non-positive queue size means: shuffle over all kept examples.
      buffer_size = num_samples if queue_size <= 0 else queue_size
      datasets.append(records.shuffle(buffer_size).repeat().batch(batch_size))

    weights = [1.0] * len(datasets)
    mixed = tf.data.experimental.sample_from_datasets(datasets, weights)
    mixed = mixed.unbatch()
    mixed = mixed.map(decode, num_parallel_calls=FLAGS.num_parallel_calls)
    mixed = mixed.map(preprocess_fn,
                      num_parallel_calls=FLAGS.num_parallel_calls)

    # drop_remainder=True makes the batch dimension static; otherwise the
    # first dimension stays None.
    mixed = mixed.batch(batch_size, drop_remainder=True)
    mixed = mixed.prefetch(1)

    return iter(mixed) if return_iterator else mixed
示例#2
0
def lanczos_algorithm(mvp_fn: Callable[[tf.Tensor], tf.Tensor],
                      dim: int,
                      order: int,
                      random_seed: int = 0,
                      only_gpu: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
  """Estimates a Hermitian matrix by using its product with arbitrary vectors.

  The Lanczos algorithm is described here:
  https://en.wikipedia.org/wiki/Lanczos_algorithm

  Args:
    mvp_fn: Matrix-vector product function. Function that takes as input a
      tensor of shape [`dim`, 1] and returns another tensor of the same shape.
      The returned tensor should be equal to Hv where v is the input vector and
      H is the symmetric matrix to estimate. It is called with a float32 tensor
      and its output is cast back to float64.
    dim: Dimension of the problem (number of columns and rows of the matrix to
      estimate.)
    order: Rank of the approximation to compute. `mvp_fn` will be called `order`
      times.
    random_seed: Random seed used for sampling the initial vector.
    only_gpu: Whether to use available GPUs for both the matrix vector product
      and the orthogonalization (if set to false, CPU will be used for
      orthogonalization). It is recommended to set this parameter to true and
      change it only if a memory error occurs.

  Returns:
    An estimation of the matrix defined by the matrix vector product function
      given. The matrix is returned as a tuple of two tensors (V,T) of shape
      [dim, order] and [order, order], where T is tridiagonal. The approximation
      of the matrix is then A = V T V^*.

  Warns:
    UserWarning: If a residual norm `beta` that is about to be used to
      normalize the next Lanczos vector is below 1e-6. Dividing by such a small
      value can make the following vectors numerically meaningless.
  """
  device_selector = DeviceSelector(only_gpu)

  # The recurrence runs on `device_selector.default` (per the docstring, the
  # CPU when `only_gpu` is False) to save accelerator memory. Most of the
  # computational load is in `mvp_fn`, which is placed on
  # `device_selector.accelerator`.
  with tf.device(device_selector.default):
    # Runs Lanczos in float64 as numerical stability is an issue and the
    # bottleneck is calling `mvp_fn`.
    float_dtype = tf.float64
    tridiag = tf.Variable(tf.zeros((order, order), dtype=float_dtype))
    vecs = tf.Variable(tf.zeros((dim, order), dtype=float_dtype))
    # Random unit-norm starting vector.
    init_vec = tf.random.uniform(
        (dim, 1), minval=-1, maxval=1, dtype=float_dtype, seed=random_seed)
    init_vec = init_vec / tf.math.reduce_euclidean_norm(init_vec)
    vecs[:, 0:1].assign(init_vec)
    beta = 0
    v_old = tf.zeros((dim, 1), dtype=float_dtype)

    for i in range(order):
      ts = time.time()
      v = vecs[:, i:i+1]
      with tf.device(device_selector.accelerator):
        tss = time.time()
        w = tf.cast(mvp_fn(tf.cast(v, tf.float32)), float_dtype)
        time_mvp = time.time() - tss
      # Three-term recurrence: remove the components along the previous and
      # the current Lanczos vectors. alpha is the new diagonal entry of T.
      w = w - beta * v_old
      alpha = tf.matmul(w, v, transpose_a=True)
      tridiag[i:i+1, i:i+1].assign(alpha)
      w = w - alpha * v

      # Reorthogonalization against every earlier Lanczos vector. Each
      # projection uses the already-updated w (modified Gram-Schmidt style).
      for j in range(i):
        tau = vecs[:, j:j+1]
        coeff = tf.matmul(w, tau, transpose_a=True)
        w = w - coeff * tau

      # The off-diagonal entry beta and the next Lanczos vector are only needed
      # when another iteration follows. On the last iteration the residual norm
      # is unused. It can legitimately be ~0 there (e.g. order == dim, where
      # the Krylov space is exhausted), so checking it would only produce a
      # spurious warning.
      if i + 1 < order:
        beta = tf.math.reduce_euclidean_norm(w)
        if beta < 1e-6:
          warning_msg = ("Possible numerical stability issues in Lanczos: "
                         "got beta = {} in iteration {}".format(beta.numpy(), i))
          warnings.warn(warning_msg)
        tridiag[i, i+1].assign(beta)
        tridiag[i+1, i].assign(beta)
        vecs[:, i+1:i+2].assign(w / beta)

      v_old = v

      info = "Iteration {}/{} done in {:.2f}s (MVP: {:.2f}s).".format(
          i, order,
          time.time() - ts, time_mvp)
      print(info)

  return vecs, tridiag
示例#3
0
    def wrapper(*args):
        """Splits `args` across `devices`, runs `pmap_fn`, and re-shards results.

        Leaves of `args` are distributed as follows:
          * a tf.Tensor is sliced along its first dimension, one slice per device;
          * a ShardedNdArray contributes its existing per-device tensors;
          * any other leaf is passed unchanged to every device.

        The per-device results are regrouped leaf by leaf into ShardedNdArray
        objects and packed into the structure of the first device's result.

        Raises:
          ValueError: If called inside another pmap, or if a tensor argument has
            no leading dimension equal to the number of devices.
        """
        if _pmap_config.devices() is not None:
            raise ValueError(
                "Found a surrounding pmap. Nested pmap is not supported "
                "yet.")
        # TODO(wangpeng): Maybe we should use `asarray` to convert everything
        # to ndarray first.
        args = _np_to_tf(args)

        per_device_leaves = [[] for _ in devices]
        for leaf in tf.nest.flatten(args):
            if isinstance(leaf, tf.Tensor):
                # TODO(nareshmodi): Try and use the dynamic shape instead.
                if not leaf.shape.rank or leaf.shape[0] != len(devices):
                    # TODO(nareshmodi): Fix this restriction
                    raise ValueError(
                        "Input tensors need to have a first dimension equal to "
                        "the number of devices; got tensor of shape %s and %s devices"
                        % (leaf.shape, len(devices)))
                # NOTE: Alternatively use tf.split, and place the split tensors on the
                # appropriate device. The best solution for this is to have an API that
                # splits a tensor across devices.
                for dev_idx, device in enumerate(devices):
                    shard = tf.gather(leaf, dev_idx)
                    # TODO(wangpeng): Investigate whether we need a tf.identity for TPU.
                    if not has_tpu:
                        # Materialize the shard on its target device.
                        with tf.device(device):
                            shard = tf.identity(shard)
                    per_device_leaves[dev_idx].append(shard)
            elif isinstance(leaf, ShardedNdArray):
                for leaves, shard in zip(per_device_leaves, leaf.tensors):
                    leaves.append(shard)
            else:
                # Non-tensor leaves are replicated as-is to every device.
                for leaves in per_device_leaves:
                    leaves.append(leaf)

        all_per_device_args = [
            tf.nest.pack_sequence_as(args, leaves)
            for leaves in per_device_leaves
        ]

        with pmap_config(axis_name, devices):
            results = pmap_fn(all_per_device_args)

        flattened_results = [tf.nest.flatten(result) for result in results]

        # TODO(nareshmodi): assert all items in flattened_results have the same
        # structures
        sharded_leaves = []
        for out_idx in range(len(flattened_results[0])):
            shards = []
            for dev_idx in range(len(devices)):
                item = flattened_results[dev_idx][out_idx]
                assert isinstance(item, tf.Tensor), (
                    "currently only tensor return items are supported")
                shards.append(item)
            sharded_leaves.append(ShardedNdArray(shards))

        final_actual_result = tf.nest.pack_sequence_as(results[0],
                                                       sharded_leaves)

        # Workaround b/121383831: on TPU a single-element list result is
        # unwrapped unless the original result really was a list.
        unwrap_singleton = (has_tpu
                            and isinstance(final_actual_result, list)
                            and len(final_actual_result) == 1
                            and not _orig_result_is_list.val)
        if unwrap_singleton:
            return final_actual_result[0]
        return final_actual_result
示例#4
0
    def _build_tables(self, prior):
        """Computes integer-valued probability tables used by the range coder.

        These tables must not be re-generated independently on the sending and
        receiving side, since small numerical discrepancies between both sides
        can occur in this process. If the tables differ slightly, this in turn
        would very likely cause catastrophic error propagation during range
        decoding. For a more in-depth discussion of this, see:

        > "Integer Networks for Data Compression with Latent-Variable Models"<br />
        > J. Ballé, N. Johnston, D. Minnen<br />
        > https://openreview.net/forum?id=S1zz2i0cY7

        The tables are stored in `tf.Variable`s as attributes of this object. The
        recommended way is to train the model, instantiate an entropy model with
        `compression=True`, and then distribute the model to a sender and a
        receiver.

        Three non-trainable variables are created. Each has one row or entry per
        flattened batch element of `prior`:
          * `self._cdf`: int32 quantized CDFs, zero-padded to a common length.
          * `self._cdf_offset`: `-minima`, the offset of each PMF's first sample
            relative to the quantization offset.
          * `self._cdf_length`: number of valid CDF entries (PMF length + 2).

        Arguments:
          prior: The `tfp.distributions.Distribution` object (see initializer).
        """
        # Quantization offset and tail locations of `prior`. Their exact
        # semantics are defined in `helpers`; `tail_mass` controls the tails.
        offset = helpers.quantization_offset(prior)
        lower_tail = helpers.lower_tail(prior, self.tail_mass)
        upper_tail = helpers.upper_tail(prior, self.tail_mass)

        # Largest distance observed between lower tail and median, and between
        # median and upper tail.
        minima = offset - lower_tail
        minima = tf.cast(tf.math.ceil(minima), tf.int32)
        minima = tf.math.maximum(minima, 0)
        maxima = upper_tail - offset
        maxima = tf.cast(tf.math.ceil(maxima), tf.int32)
        maxima = tf.math.maximum(maxima, 0)

        # PMF starting positions and lengths. Each PMF covers the integer grid
        # [offset - minima, offset + maxima] (minima + maxima + 1 points).
        pmf_start = offset - tf.cast(minima, self.dtype)
        pmf_length = maxima + minima + 1

        # Sample the densities in the computed ranges, possibly computing more
        # samples than necessary at the upper end.
        max_length = tf.math.reduce_max(pmf_length)
        if tf.executing_eagerly() and max_length > 2048:
            logging.warning(
                "Very wide PMF with %d elements may lead to out of memory issues. "
                "Consider priors with smaller dispersion or increasing `tail_mass` "
                "parameter.", int(max_length))
        # samples: [max_length, 1, ..., 1] with one singleton axis per
        # prior_shape dimension, so adding pmf_start broadcasts over the
        # distribution's batch dimensions.
        samples = tf.range(tf.cast(max_length, self.dtype), dtype=self.dtype)
        samples = tf.reshape(samples, [-1] + len(self.prior_shape) * [1])
        samples += pmf_start
        pmf = prior.prob(samples)

        # Collapse batch dimensions of distribution.
        # After the transpose, pmf is [num_flattened_elements, max_length].
        pmf = tf.reshape(pmf, [max_length, -1])
        pmf = tf.transpose(pmf)

        # Per-element lengths/offsets, flattened in the same order as pmf rows.
        pmf_length = tf.broadcast_to(pmf_length, self.prior_shape)
        pmf_length = tf.reshape(pmf_length, [-1])
        cdf_length = pmf_length + 2
        cdf_offset = tf.broadcast_to(-minima, self.prior_shape)
        cdf_offset = tf.reshape(cdf_offset, [-1])

        # Prevent tensors from bouncing back and forth between host and GPU.
        with tf.device("/cpu:0"):

            # Converts one PMF row into a zero-padded quantized CDF.
            def loop_body(args):
                prob, length = args
                # Drop the extra samples beyond this element's own length.
                prob = prob[:length]
                # Append the remaining probability mass as an extra bin so the
                # probabilities passed on sum to 1.
                prob = tf.concat(
                    [prob, 1 - tf.reduce_sum(prob, keepdims=True)], axis=0)
                cdf = range_coding_ops.pmf_to_quantized_cdf(
                    prob, precision=self.range_coder_precision)
                # Pad by (max_length - length) zeros so every row has the same
                # size. NOTE(review): assumes the CDF has length + 2 entries,
                # consistent with cdf_length above; confirm against the op.
                return tf.pad(cdf, [[0, max_length - length]],
                              mode="CONSTANT",
                              constant_values=0)

            # TODO(jonycgn,ssjhv): Consider switching to Python control flow.
            cdf = tf.map_fn(loop_body, (pmf, pmf_length),
                            dtype=tf.int32,
                            name="pmf_to_cdf")

        # Store the tables as non-trainable variables so they are saved with
        # the model and shared verbatim by sender and receiver.
        self._cdf = tf.Variable(cdf, trainable=False, name="cdf")
        self._cdf_offset = tf.Variable(cdf_offset,
                                       trainable=False,
                                       name="cdf_offset")
        self._cdf_length = tf.Variable(cdf_length,
                                       trainable=False,
                                       name="cdf_length")
示例#5
0
def benchmark_tf_function(
        user_fn,
        iters=1,
        config=None,
        extra_columns=None,
        # As of this writing (February 2019), autograph is the default for
        # tfe.function, but there seem to be many bugs. Hopefully, in future, this
        # default can be changed to True or the argument can be removed.
        use_autograph=False,
        print_intermediates=False,
        cpu_device='cpu:0',
        gpu_device='gpu:0'):
    """Time a TensorFlow function under a variety of strategies and hardware.

    Runs the callable `user_fn` `iters` times under the strategies (any of
    Eager, tfe.function + graph, and XLA) and hardware (CPU, GPU). GPU runs are
    skipped, with a printed message, when no GPU is visible.

    # Example:
    ```python
    data_dicts = []
    for inner_iters in [10, 100]:
      for size in [100, 1000]:
        def f():
          total = tf.constant(0.0)
          for _ in np.arange(inner_iters):
            m = tf.random.uniform((size, size))
            total += tf.reduce_sum(tf.matmul(m, m))
          return total

        data_dicts += benchmark_tf_function.benchmark_tf_function(
            f,
            iters=5,
            extra_columns={'inner_iters': inner_iters,
                           'size': size})
    ```

    Args:
      user_fn: A zero-argument, callable function of TensorFlow code.
      iters: The number of times to run the function for each runtime and
        hardware combination.
      config: A BenchmarkTfFunctionConfig, specifying which strategies and
        hardware to use. Valid strategies are RUNTIME_EAGER, RUNTIME_FUNCTION,
        and RUNTIME_XLA. Valid hardware choices are HARDWARE_CPU, HARDWARE_GPU.
        Defaults to a fresh `default_benchmark_config()` built on each call.
      extra_columns: A dictionary of extra information to add to each dictionary
        in data_dicts.
      use_autograph: Boolean, controlling whether autograph is used for the
        graph and XLA strategies.
      print_intermediates: Boolean. If true, print out each row before adding it
        to the data_dicts.
      cpu_device: String, the TensorFlow device to use for CPU.
      gpu_device: String, the TensorFlow device to use for GPU.

    Returns:
      data_dicts: A list of dictionaries containing the results of benchmarking.
        Time for the first run is stored under the `first_iter_time` key, and
        time for all runs is stored under the `total_time` key.
    """
    # Build the default per call instead of once at import time. This way
    # separate calls never share (and cannot leak state through) one config
    # object.
    if config is None:
        config = default_benchmark_config()
    if extra_columns is None:
        extra_columns = {}

    data_dicts = []

    if HARDWARE_CPU in config.hardware:
        with tf.device(cpu_device):
            data_dicts += _run_function_under_strategies(
                user_fn, iters, config, HARDWARE_CPU, extra_columns,
                use_autograph, print_intermediates)

    if HARDWARE_GPU in config.hardware:
        if tf.config.list_physical_devices('GPU'):
            with tf.device(gpu_device):
                data_dicts += _run_function_under_strategies(
                    user_fn, iters, config, HARDWARE_GPU, extra_columns,
                    use_autograph, print_intermediates)
        else:
            print('Skipping GPU runs -- no GPU!')

    return data_dicts
示例#6
0
def evaluate_or_sample(data_provider,
                       model,
                       mode='eval',
                       save_dir='/tmp/ddsp/training',
                       restore_dir='',
                       batch_size=32,
                       num_batches=50,
                       ckpt_delay_secs=0,
                       run_once=False,
                       run_until_step=0):
    """Run evaluation loop.

    For every checkpoint yielded by `tf.train.checkpoints_iterator` on
    `restore_dir`, the loop does the following:
      * restores `model`;
      * runs it on up to `num_batches` batches from `data_provider`;
      * in 'sample' mode, writes audio/plot summaries;
      * in 'eval' mode, accumulates metrics and average losses and writes them
        as summaries.

    Args:
      data_provider: DataProvider instance.
      model: Model instance.
      mode: Whether to 'eval' with metrics or create 'sample' s.
      save_dir: Path to directory to save summary events.
      restore_dir: Path to directory with checkpoints, defaults to save_dir.
      batch_size: Size of each eval/sample batch.
      num_batches: How many batches to eval from dataset. -1 denotes all
        batches. NOTE(review): the batch loop is `range(1, num_batches + 1)`,
        so -1 processes no batches at all. In 'eval' mode that leaves the
        metric objects unbound when they are flushed; confirm intended behavior.
      ckpt_delay_secs: Time to wait when a new checkpoint was not detected.
        Passed positionally as the second argument of
        `tf.train.checkpoints_iterator`.
      run_once: Only run evaluation or sampling once.
      run_until_step: Run until we see a checkpoint with a step greater or equal
        to the specified value. Ignored if <= 0.

    Returns:
      If the mode is 'eval', a dictionary of Tensors keyed by loss type, holding
      the averaged losses of the last checkpoint processed. Otherwise, None.
    """
    # Default to restoring from the save directory.
    restore_dir = save_dir if not restore_dir else restore_dir

    # Set up the summary writer and metrics.
    summary_dir = os.path.join(save_dir, 'summaries', 'eval')
    summary_writer = tf.summary.create_file_writer(summary_dir)

    # Sample continuously and load the newest checkpoint each time
    checkpoints_iterator = tf.train.checkpoints_iterator(
        restore_dir, ckpt_delay_secs)

    # Get the dataset.
    # NOTE(review): repeats=-1 presumably repeats the data indefinitely;
    # confirm against DataProvider.get_batch.
    dataset = data_provider.get_batch(batch_size=batch_size,
                                      shuffle=False,
                                      repeats=-1)

    # Get audio sample rate
    sample_rate = data_provider.sample_rate
    # Get feature frame rate
    frame_rate = data_provider.frame_rate

    # Stays None in 'sample' mode; set after each checkpoint in 'eval' mode.
    latest_losses = None

    with summary_writer.as_default():
        for checkpoint_path in checkpoints_iterator:
            # The step is parsed from the text after the last '-' in the path.
            step = int(checkpoint_path.split('-')[-1])

            # Redefine the dataset iterator each time to make deterministic.
            dataset_iter = iter(dataset)

            # Load model.
            model.restore(checkpoint_path)

            # Iterate through dataset and make predictions
            checkpoint_start_time = time.time()

            # NOTE(review): empty when num_batches <= 0 (see docstring).
            for batch_idx in range(1, num_batches + 1):
                try:
                    start_time = time.time()
                    logging.info('Predicting batch %d of size %d', batch_idx,
                                 batch_size)

                    # Predict a batch of audio.
                    batch = next(dataset_iter)

                    # For synthetic notes, the target audio and its
                    # f0-confidence / loudness features are generated here.
                    if isinstance(data_provider, data.SyntheticNotes):
                        batch['audio'] = model.generate_synthetic_audio(batch)
                        batch['f0_confidence'] = tf.ones_like(
                            batch['f0_hz'])[:, :, 0]
                        batch[
                            'loudness_db'] = ddsp.spectral_ops.compute_loudness(
                                batch['audio'])

                    # TODO(jesseengel): Find a way to add losses with training=False.
                    audio = batch['audio']
                    outputs, losses = model(batch,
                                            return_losses=True,
                                            training=True)
                    audio_gen = model.get_audio_from_outputs(outputs)

                    # Create metrics on first batch.
                    # These names are only bound once batch 1 has run; the
                    # flush code after the loop relies on that.
                    if mode == 'eval' and batch_idx == 1:
                        loudness_metrics = metrics.LoudnessMetrics(
                            sample_rate=sample_rate, frame_rate=frame_rate)
                        f0_metrics = metrics.F0Metrics(sample_rate=sample_rate,
                                                       frame_rate=frame_rate,
                                                       name='f0_harm')
                        f0_crepe_metrics = metrics.F0CrepeMetrics(
                            sample_rate=sample_rate, frame_rate=frame_rate)

                        f0_twm_metrics = metrics.F0Metrics(
                            sample_rate=sample_rate,
                            frame_rate=frame_rate,
                            name='f0_twm')

                        avg_losses = {
                            name: tf.keras.metrics.Mean(name=name,
                                                        dtype=tf.float32)
                            for name in list(losses.keys())
                        }

                    # Add an 'f0_hz_twm' estimate from the first Sinusoidal or
                    # NoisySinusoidal processor found (the loop breaks after it).
                    processor_group = getattr(model, 'processor_group', None)
                    if processor_group is not None:
                        for processor in processor_group.processors:
                            # If using a sinusoidal model, infer f0 with two-way mismatch.
                            if isinstance(processor, ddsp.synths.Sinusoidal):
                                # Run on CPU to avoid running out of memory (not expensive).
                                with tf.device('CPU'):
                                    processor_controls = outputs[
                                        processor.name]['controls']
                                    amps = processor_controls['amplitudes']
                                    freqs = processor_controls['frequencies']
                                    twm = ddsp.losses.TWMLoss()
                                    # Treat all freqs as candidate f0s.
                                    outputs['f0_hz_twm'] = twm.predict_f0(
                                        freqs, freqs, amps)
                                    logging.info(
                                        'Added f0 estimate from sinusoids.')
                                    break

                            # If using a noisy sinusoidal model, infer f0 w/ two-way mismatch.
                            elif isinstance(processor,
                                            ddsp.synths.NoisySinusoidal):
                                # Run on CPU to avoid running out of memory (not expensive).
                                with tf.device('CPU'):
                                    processor_controls = outputs[
                                        processor.name]['controls']
                                    amps = processor_controls['amplitudes']
                                    freqs = processor_controls['frequencies']
                                    noise_ratios = processor_controls[
                                        'noise_ratios']
                                    # Down-weight amplitudes by the noise ratio.
                                    amps = amps * (1.0 - noise_ratios)
                                    twm = ddsp.losses.TWMLoss()
                                    # Treat all freqs as candidate f0s.
                                    outputs['f0_hz_twm'] = twm.predict_f0(
                                        freqs, freqs, amps)
                                    logging.info(
                                        'Added f0 estimate from sinusoids.')
                                    break

                    # The values from the last processed batch are also used
                    # after the loop to choose which metrics to flush.
                    has_f0_twm = ('f0_hz_twm' in outputs and 'f0_hz' in batch)
                    has_f0 = ('f0_hz' in outputs and 'f0_hz' in batch)

                    logging.info('Prediction took %.1f seconds',
                                 time.time() - start_time)

                    if mode == 'sample':
                        start_time = time.time()
                        logging.info('Writing summmaries for batch %d',
                                     batch_idx)

                        if audio_gen is not None:
                            audio_gen = np.array(audio_gen)

                            # Add audio.
                            summaries.audio_summary(audio_gen,
                                                    step,
                                                    sample_rate,
                                                    name='audio_generated')
                            summaries.audio_summary(audio,
                                                    step,
                                                    sample_rate,
                                                    name='audio_original')

                            # Add plots.
                            summaries.waveform_summary(audio, audio_gen, step)
                            summaries.spectrogram_summary(
                                audio, audio_gen, step)

                        if has_f0:
                            summaries.f0_summary(batch['f0_hz'],
                                                 outputs['f0_hz'],
                                                 step,
                                                 name='f0_harmonic')
                        if has_f0_twm:
                            summaries.f0_summary(batch['f0_hz'],
                                                 outputs['f0_hz_twm'],
                                                 step,
                                                 name='f0_twm')

                        logging.info(
                            'Writing batch %i with size %i took %.1f seconds',
                            batch_idx, batch_size,
                            time.time() - start_time)

                    elif mode == 'eval':
                        start_time = time.time()
                        logging.info('Calculating metrics for batch %d',
                                     batch_idx)

                        # Model f0 is scored with F0Metrics when present,
                        # otherwise the generated audio goes to F0CrepeMetrics.
                        if audio_gen is not None:
                            loudness_metrics.update_state(batch, audio_gen)
                            if has_f0:
                                f0_metrics.update_state(
                                    batch, outputs['f0_hz'])
                            else:
                                f0_crepe_metrics.update_state(batch, audio_gen)

                        if has_f0_twm:
                            f0_twm_metrics.update_state(
                                batch, outputs['f0_hz_twm'])
                        # Loss.
                        for k, v in losses.items():
                            avg_losses[k].update_state(v)

                        logging.info(
                            'Metrics for batch %i with size %i took %.1f seconds',
                            batch_idx, batch_size,
                            time.time() - start_time)

                except tf.errors.OutOfRangeError:
                    # Dataset exhausted: stop this checkpoint's batch loop.
                    logging.info('End of dataset.')
                    break

            # NOTE(review): logs num_batches even if the loop ended early.
            logging.info('All %d batches in checkpoint took %.1f seconds',
                         num_batches,
                         time.time() - checkpoint_start_time)

            # Write this checkpoint's metrics and average losses, then reset
            # the loss accumulators for the next checkpoint.
            if mode == 'eval':
                loudness_metrics.flush(step)
                if has_f0:
                    f0_metrics.flush(step)
                else:
                    f0_crepe_metrics.flush(step)
                if has_f0_twm:
                    f0_twm_metrics.flush(step)
                latest_losses = {}
                for k, metric in avg_losses.items():
                    latest_losses[k] = metric.result()
                    tf.summary.scalar('losses/{}'.format(k),
                                      metric.result(),
                                      step=step)
                    metric.reset_states()

            summary_writer.flush()

            if run_once:
                break

            if 0 < run_until_step <= step:
                logging.info(
                    'Saw checkpoint with step %d, which is greater or equal to'
                    ' `run_until_step` of %d. Exiting.', step, run_until_step)
                break
    return latest_losses
示例#7
0
def main(unused_argv):
    """Trains and evaluates Keras ResNet-50 on ImageNet with a custom TPU loop.

    Connects to the TPU given by `--tpu` and builds the train/eval input
    pipelines. It restores the latest checkpoint found in the model directory,
    if there is one. Then, for every remaining epoch, it runs `steps_per_epoch`
    training steps and `steps_per_eval` evaluation steps, writes loss and
    accuracy summaries, and saves a checkpoint.

    Args:
      unused_argv: Positional command-line arguments; ignored.
    """
    tf.enable_v2_behavior()
    num_workers = 1
    job_name = 'worker'
    primary_cpu_task = '/job:%s' % job_name

    # With `num_workers` fixed at 1 this is always False, so only the
    # single-host input path below is exercised.
    is_tpu_pod = num_workers > 1
    model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
    # Global batch size summed over all cores.
    batch_size = PER_CORE_BATCH_SIZE * FLAGS.num_cores
    steps_per_epoch = FLAGS.steps_per_epoch or (int(
        APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
    # Rounds up: a final partial batch still counts as one eval step.
    steps_per_eval = int(1.0 *
                         math.ceil(IMAGENET_VALIDATION_IMAGES / batch_size))

    logging.info('Saving checkpoints at %s', model_dir)

    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, job_name=job_name)
    tf.config.experimental_connect_to_host(resolver.master())  # pylint: disable=line-too-long
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    with tf.device(primary_cpu_task):
        # TODO(b/130307853): In TPU Pod, we have to use
        # `strategy.experimental_distribute_datasets_from_function` instead of
        # `strategy.experimental_distribute_dataset` because dataset cannot be
        # cloned in eager mode. And when using
        # `strategy.experimental_distribute_datasets_from_function`, we should use
        # per core batch size instead of global batch size, because no re-batch is
        # happening in this case.
        if is_tpu_pod:
            imagenet_train = imagenet_input.ImageNetInput(
                is_training=True,
                data_dir=FLAGS.data,
                batch_size=PER_CORE_BATCH_SIZE,
                use_bfloat16=_USE_BFLOAT16)
            imagenet_eval = imagenet_input.ImageNetInput(
                is_training=False,
                data_dir=FLAGS.data,
                batch_size=PER_CORE_BATCH_SIZE,
                use_bfloat16=_USE_BFLOAT16)
            train_dataset = strategy.experimental_distribute_datasets_from_function(
                imagenet_train.input_fn)
            test_dataset = strategy.experimental_distribute_datasets_from_function(
                imagenet_eval.input_fn)
        else:
            imagenet_train = imagenet_input.ImageNetInput(
                is_training=True,
                data_dir=FLAGS.data,
                batch_size=batch_size,
                use_bfloat16=_USE_BFLOAT16)
            imagenet_eval = imagenet_input.ImageNetInput(
                is_training=False,
                data_dir=FLAGS.data,
                batch_size=batch_size,
                use_bfloat16=_USE_BFLOAT16)
            train_dataset = strategy.experimental_distribute_dataset(
                imagenet_train.input_fn())
            test_dataset = strategy.experimental_distribute_dataset(
                imagenet_eval.input_fn())

        with strategy.scope():
            logging.info('Building Keras ResNet-50 model')
            model = resnet_model.ResNet50(num_classes=NUM_CLASSES)
            # Linear scaling of the base learning rate with the global batch size.
            base_lr = _BASE_LEARNING_RATE * batch_size / 256
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=ResnetLearningRateSchedule(
                    steps_per_epoch, base_lr),
                momentum=0.9,
                nesterov=True)
            training_loss = tf.keras.metrics.Mean('training_loss',
                                                  dtype=tf.float32)
            training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                'training_accuracy', dtype=tf.float32)
            test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
            test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
                'test_accuracy', dtype=tf.float32)
            logging.info('Finished building Keras ResNet-50 model')

            checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
            latest_checkpoint = tf.train.latest_checkpoint(model_dir)
            initial_epoch = 0
            if latest_checkpoint:
                # checkpoint.restore must be within a strategy.scope() so that optimizer
                # slot variables are mirrored.
                checkpoint.restore(latest_checkpoint)
                logging.info('Loaded checkpoint %s', latest_checkpoint)
                # Resume from the epoch implied by the restored step counter.
                initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

        # Create summary writers
        train_summary_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'summaries/train'))
        test_summary_writer = tf.summary.create_file_writer(
            os.path.join(model_dir, 'summaries/test'))

        @tf.function
        def train_step(iterator):
            """Runs one distributed training step on the next batch."""
            def step_fn(inputs):
                """Per-replica step: forward, loss, gradient update, metrics."""
                images, labels = inputs
                with tf.GradientTape() as tape:
                    predictions = model(images, training=True)

                    # Loss calculations.
                    #
                    # Part 1: Prediction loss.
                    prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
                        labels, predictions)
                    loss1 = tf.reduce_mean(prediction_loss)
                    # Part 2: Model weights regularization
                    loss2 = tf.reduce_sum(model.losses)

                    loss = loss1 + loss2
                    # Scale the loss given the TPUStrategy will reduce sum all gradients.
                    scaled_loss = loss / strategy.num_replicas_in_sync

                grads = tape.gradient(scaled_loss, model.trainable_variables)
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))
                # The unscaled loss is tracked so the metric is per-replica.
                training_loss.update_state(loss)
                training_accuracy.update_state(labels, predictions)

            strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

        @tf.function
        def test_step(iterator):
            """Runs one distributed evaluation step on the next batch."""
            def step_fn(inputs):
                """Per-replica step: forward pass and metric updates."""
                images, labels = inputs
                predictions = model(images, training=False)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, predictions)
                loss = safe_mean(loss)
                test_loss.update_state(loss)
                test_accuracy.update_state(labels, predictions)

            strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

        # A single training iterator is shared across epochs; the eval
        # iterator is recreated each epoch.
        train_iterator = iter(train_dataset)
        for epoch in range(initial_epoch, FLAGS.num_epochs):
            logging.info('Starting to run epoch: %s', epoch)
            with train_summary_writer.as_default():
                for step in range(steps_per_epoch):
                    if step % 20 == 0:
                        logging.info('Running step %s in epoch %s', step,
                                     epoch)
                    train_step(train_iterator)
                tf.summary.scalar('loss',
                                  training_loss.result(),
                                  step=optimizer.iterations)
                tf.summary.scalar('accuracy',
                                  training_accuracy.result(),
                                  step=optimizer.iterations)
                # Convert to Python floats first: eager tensors do not
                # implement `__round__`, so `round(tensor, n)` raises.
                logging.info('Training loss: %s, accuracy: %s%%',
                             round(float(training_loss.result()), 4),
                             round(float(training_accuracy.result()) * 100, 2))
                training_loss.reset_states()
                training_accuracy.reset_states()

            with test_summary_writer.as_default():
                test_iterator = iter(test_dataset)
                for step in range(steps_per_eval):
                    if step % 20 == 0:
                        logging.info(
                            'Starting to run eval step %s of epoch: %s', step,
                            epoch)
                    test_step(test_iterator)
                tf.summary.scalar('loss',
                                  test_loss.result(),
                                  step=optimizer.iterations)
                tf.summary.scalar('accuracy',
                                  test_accuracy.result(),
                                  step=optimizer.iterations)
                logging.info('Test loss: %s, accuracy: %s%%',
                             round(float(test_loss.result()), 4),
                             round(float(test_accuracy.result()) * 100, 2))
                test_loss.reset_states()
                test_accuracy.reset_states()

            checkpoint_name = checkpoint.save(
                os.path.join(model_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)
Example #8
0
def main(unused_argv):
    """Trains Keras ResNet-50 on ImageNet with `Model.fit` on a TPU.

    Builds and compiles the model under a `TPUStrategy` scope. It then trains
    for `--num_epochs` epochs with a `LearningRateBatchScheduler` and a
    TensorBoard callback, running validation every 5 epochs. Finally it passes
    the model to `model_saving_utils.save_model`.

    Args:
      unused_argv: Positional command-line arguments; ignored.

    Raises:
      ValueError: If no training data path was given via `--data`.
    """
    # Raise explicitly rather than `assert`: asserts are stripped when Python
    # runs with -O, which would defer the failure to a confusing later error.
    if FLAGS.data is None:
        raise ValueError('Provide training data path via --data.')
    tf.enable_v2_behavior()

    # Global batch size summed over all cores.
    batch_size = FLAGS.num_cores * PER_CORE_BATCH_SIZE

    training_steps_per_epoch = FLAGS.steps_per_epoch or (int(
        APPROX_IMAGENET_TRAINING_IMAGES // batch_size))
    # Rounds up: a final partial batch still counts as one validation step.
    validation_steps = int(
        math.ceil(1.0 * IMAGENET_VALIDATION_IMAGES / batch_size))

    model_dir = FLAGS.model_dir if FLAGS.model_dir else DEFAULT_MODEL_DIR
    logging.info('Saving tensorboard summaries at %s', model_dir)

    logging.info('Use TPU at %s',
                 FLAGS.tpu if FLAGS.tpu is not None else 'local')
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    logging.info('Use bfloat16: %s.', USE_BFLOAT16)
    logging.info('Use global batch size: %s.', batch_size)
    logging.info('Enable top 5 accuracy: %s.', FLAGS.eval_top_5_accuracy)
    logging.info('Training model using data in directory "%s".', FLAGS.data)

    with tf.device('/job:worker'):
        # Variables (model and optimizer state) must be created inside the
        # strategy scope so they are placed for distributed training.
        with strategy.scope():
            logging.info('Building Keras ResNet-50 model')
            model = resnet_model.ResNet50(num_classes=NUM_CLASSES)

            logging.info('Compiling model.')
            metrics = ['sparse_categorical_accuracy']

            if FLAGS.eval_top_5_accuracy:
                metrics.append(sparse_top_k_categorical_accuracy)

            # The initial learning rate is BASE_LEARNING_RATE; the
            # LearningRateBatchScheduler callback below adjusts it.
            model.compile(optimizer=tf.keras.optimizers.SGD(
                learning_rate=BASE_LEARNING_RATE, momentum=0.9, nesterov=True),
                          loss='sparse_categorical_crossentropy',
                          metrics=metrics)

        imagenet_train = imagenet_input.ImageNetInput(
            is_training=True,
            data_dir=FLAGS.data,
            batch_size=batch_size,
            use_bfloat16=USE_BFLOAT16)
        imagenet_eval = imagenet_input.ImageNetInput(is_training=False,
                                                     data_dir=FLAGS.data,
                                                     batch_size=batch_size,
                                                     use_bfloat16=USE_BFLOAT16)

        lr_schedule_cb = LearningRateBatchScheduler(
            schedule=learning_rate_schedule_wrapper(training_steps_per_epoch))
        tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=model_dir)

        training_callbacks = [lr_schedule_cb, tensorboard_cb]

        # Validation runs only every 5th epoch (validation_freq=5).
        model.fit(imagenet_train.input_fn(),
                  epochs=FLAGS.num_epochs,
                  steps_per_epoch=training_steps_per_epoch,
                  callbacks=training_callbacks,
                  validation_data=imagenet_eval.input_fn(),
                  validation_steps=validation_steps,
                  validation_freq=5)

        model_saving_utils.save_model(model, model_dir, WEIGHTS_TXT)
Example #9
0
 def benchmark_tf_np_tf_function_mlp_inference_batch_1_cpu(self):
     """Benchmarks `tf.function`-wrapped MLP inference on a (1, 10) input on CPU."""
     with tf.device('/CPU:0'):
         mlp = tf_numpy_mlp.MLP()
         ones_input = tfnp.ones(shape=(1, 10))
         single_example = ones_input.astype(np.float32)
         # NOTE(review): `_get_name` is defined outside this view. It stays a
         # direct call from this method, evaluated before the function is
         # wrapped, in case it derives the name from its caller.
         benchmark_name = self._get_name()
         compiled_inference = tf.function(lambda: mlp.inference(single_example))
         self._benchmark_and_report(benchmark_name, compiled_inference)
Example #10
0
File: learner.py  Project: diegozd/valan
def run_with_address(
    problem_type: framework_problem_type.ProblemType,
    listen_address: Text,
    hparams: Dict[Text, Any],
):
    """Runs the learner with the given problem type.

    Creates a FIFO queue and a gRPC server on `listen_address` that is handed
    that queue (presumably actors enqueue their outputs through it). It then
    dequeues batches and trains the agent until the optimizer's iteration
    counter reaches `hparams['final_iteration']`. Along the way it writes
    summaries and checkpoints under `hparams['logdir']`.

    Args:
      problem_type: An instance of `framework_problem_type.ProblemType`.
      listen_address: The network address on which to listen.
      hparams: A dict containing hyperparameter settings. Keys read here are
        'logdir', 'final_iteration' and 'iter_frame_ratio'.
    """
    # Pick the first local device whose type matches --agent_device.
    devices = device_lib.list_local_devices()
    logging.info('Found devices: %s', devices)
    devices = [d for d in devices if d.device_type == FLAGS.agent_device]
    assert devices, 'Could not find a device of type %s' % FLAGS.agent_device
    agent_device = devices[0].name
    logging.info('Using agent device: %s', agent_device)

    # Initialize agent, variables.
    specs = utils.read_specs(hparams['logdir'])
    flat_specs = [
        tf.TensorSpec.from_spec(s, str(i))
        for i, s in enumerate(tf.nest.flatten(specs))
    ]
    queue_capacity = FLAGS.queue_capacity or FLAGS.batch_size * 10
    queue = tf.queue.FIFOQueue(
        queue_capacity,
        [t.dtype for t in flat_specs],
        [t.shape for t in flat_specs],
    )
    agent = problem_type.get_agent()
    # Create dummy environment output of shape [num_timesteps, batch_size, ...].
    env_output = tf.nest.map_structure(
        lambda s: tf.zeros(
            list(s.shape)[0:1] + [FLAGS.batch_size] + list(s.shape)[1:], s.
            dtype), specs.env_output)
    init_observation = utils.get_row_nested_tensor(env_output.observation, 0)
    init_agent_state = agent.get_initial_state(init_observation,
                                               batch_size=FLAGS.batch_size)
    env_output = _convert_uint8_to_bfloat16(env_output)
    with tf.device(agent_device):
        # Call the agent once on dummy inputs so its variables get created.
        agent(env_output, init_agent_state)

        # Create optimizer.

        if FLAGS.lr_decay_steps > 0 and FLAGS.lr_decay_rate < 1.:
            lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=FLAGS.learning_rate,
                decay_steps=FLAGS.lr_decay_steps,
                decay_rate=FLAGS.lr_decay_rate)
        else:
            lr_schedule = FLAGS.learning_rate
        optimizer = problem_type.get_optimizer(lr_schedule)
        # NOTE: `iterations` is a non-trainable variable which is managed by
        # optimizer (created inside optimizer as well as incremented by 1 on every
        # call to optimizer.minimize).
        iterations = optimizer.iterations
        study_loss_types = problem_type.get_study_loss_types()

    @tf.function
    def train_step(iterator):
        """Runs one optimization step on the next batch and returns its info."""
        def step_fn(actor_output):
            """Unpacks one flat actor batch and minimizes the loss on it."""
            actor_output = tf.nest.pack_sequence_as(specs, actor_output)
            (initial_agent_state, env_output, actor_agent_output, actor_action,
             loss_type, info) = actor_output
            optimizer.minimize(
                functools.partial(loss_fns.compute_loss,
                                  study_loss_types=study_loss_types,
                                  current_batch_loss_type=loss_type,
                                  agent=agent,
                                  agent_state=initial_agent_state,
                                  env_output=env_output,
                                  actor_agent_output=actor_agent_output,
                                  actor_action=actor_action,
                                  num_steps=iterations),
                agent.trainable_variables)
            return info

        return step_fn(next(iterator))

    ckpt_manager = _maybe_restore_from_ckpt(hparams['logdir'],
                                            agent=agent,
                                            optimizer=optimizer)
    server = _create_server(listen_address,
                            specs,
                            agent,
                            queue,
                            extra_variables=[iterations])
    logging.info('Starting gRPC server')
    server.start()

    # Endless dataset that pulls one element at a time from the queue.
    dataset = tf.data.Dataset.from_tensors(0).repeat(None)
    dataset = dataset.map(lambda _: queue.dequeue())
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    # Transpose each batch to time-major order. This is relatively slow, so do
    # this work outside of the training loop.
    dataset = dataset.map(functools.partial(_transpose_batch, specs))
    dataset = dataset.apply(tf.data.experimental.copy_to_device(agent_device))
    with tf.device(agent_device):
        dataset = dataset.prefetch(1)
        iterator = iter(dataset)

    # Execute learning and track performance.
    summary_writer = tf.summary.create_file_writer(hparams['logdir'],
                                                   flush_millis=20000,
                                                   max_queue=1000)
    last_ckpt_time = time.time()
    with summary_writer.as_default():
        # Snapshot the counter's current value. Binding the `tf.Variable`
        # itself would alias it, so `iterations - last_log_iterations`
        # would always be 0 and the iterations/sec summary would be 0.
        last_log_iterations = iterations.numpy()
        last_log_num_env_frames = iterations * hparams['iter_frame_ratio']
        last_log_time = time.time()
        while iterations < hparams['final_iteration']:
            logging.info('Iteration %d of %d', iterations + 1,
                         hparams['final_iteration'])
            # Save checkpoint at specified intervals or if no previous ckpt exists.
            current_time = time.time()
            if (current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs
                    or not ckpt_manager.latest_checkpoint):
                ckpt_manager.save(checkpoint_number=iterations)
                last_ckpt_time = current_time

            with utils.WallTimer() as wt:
                with tf.device(agent_device):
                    info = train_step(iterator)
            tf.summary.scalar('steps_summary/step_seconds',
                              wt.duration,
                              step=iterations)

            # Throughput summaries at most once every 120 seconds.
            if current_time - last_log_time >= 120:
                current_iterations = iterations.numpy()
                num_env_frames = iterations * hparams['iter_frame_ratio']
                num_frames_since = num_env_frames - last_log_num_env_frames
                num_iterations_since = current_iterations - last_log_iterations
                elapsed_time = time.time() - last_log_time
                tf.summary.scalar(
                    'steps_summary/num_environment_frames_per_sec',
                    tf.cast(num_frames_since, tf.float32) / elapsed_time,
                    step=iterations)
                tf.summary.scalar('steps_summary/num_iterations_per_sec',
                                  tf.cast(num_iterations_since, tf.float32) /
                                  elapsed_time,
                                  step=iterations)
                tf.summary.scalar('queue_size', queue.size(), step=iterations)
                # NOTE: `_decayed_lr` is a private optimizer API.
                tf.summary.scalar('learning_rate',
                                  optimizer._decayed_lr(var_dtype=tf.float32),
                                  step=iterations)
                last_log_num_env_frames, last_log_iterations, last_log_time = (
                    num_env_frames, current_iterations, time.time())
                logging.info('Number of environment frames: %d',
                             num_env_frames)

            problem_type.create_summary(step=iterations, info=info)

    # Finishing up.
    ckpt_manager.save(checkpoint_number=iterations)
    queue.close()
    server.shutdown()