Example #1
def _safediv(x, y):
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        finfo = np.iinfo(y.dtype)
    return x * np.clip(np.reciprocal(y), a_min=None, a_max=finfo.max)
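
A quick usage sketch (my own example, assuming `np` here is `jax.numpy` as in the surrounding examples): the clipped reciprocal turns a division by zero into the dtype's largest finite value instead of `inf`.

import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float32)
y = jnp.array([2.0, 0.0, 4.0], dtype=jnp.float32)
# 1/0 -> inf, which is clipped down to float32 max, so the result stays finite.
print(_safediv(x, y))  # [0.5, 3.4028235e+38, 0.25]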
Example #2
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k])) for k in
                         sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    io_string = ''
    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
      io_string += inps[-1] + ' < ' + outs[-1] + ' > '
    return inps, outs, io_string[:-3]  # Remove last separator.

  def decode_program(program):
    """Decode program tokens."""
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    try:
      p = dsl.decode_program(program, id_token_table)
      return p, p.to_string()
    except:  # pylint: disable=bare-except
      return None, ''  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease the batch size of the prediction dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      use_relative_attention=FLAGS.use_relative_attention,
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=True)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.ProgramTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_train_step = jax.pmap(
      functools.partial(
          train_lib.train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(train_lib.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_lib.initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(train_lib.predict_step, config=predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(4, 5, 6))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    optimizer, metrics, train_rngs = p_train_step(
        optimizer, inputs, outputs, programs, train_rng=train_rngs)
    metrics_all.append(metrics)

    # Save a Checkpoint
    if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or
        step == FLAGS.num_train_steps - 1):
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
    if not step or step % FLAGS.log_freq != 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)

      metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
      eval_metrics.append(metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step, eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [10, 100]:
      t_inference_start = time.time()
      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions = [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          # pylint: disable=cell-var-from-loop
          pred_batch = jax.tree_map(
              lambda x: train_lib.pad_examples(x, padded_size), pred_batch)
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache = p_init_cache(inputs, outputs, programs)
        predicted = p_pred_step(optimizer.target,
                                inputs,
                                outputs,
                                cache,
                                eos_token,
                                programs.shape[-1],
                                beam_size)
        predicted = train_lib.tohost(predicted)
        inputs, outputs, programs = map(train_lib.tohost,
                                        (inputs, outputs, programs))

        pred_denominator += programs.shape[0]
        for i, beams in enumerate(predicted):
          inps, outs, io_string = decode_io(inputs[i], outputs[i])
          p, p_score = train_lib.eval_predicted(
              beams, inps, outs,
              parse_beam_fn=lambda x: decode_program(x)[0])
          if p_score >= len(inps):
            pred_acc += 1
          ios.append(io_string)
          targets.append(decode_program(programs[i])[1])
          predictions.append(p.to_string() if p else '')

      all_pred_acc, all_pred_denominator = train_lib.per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries.
      message = []
      for n in np.random.choice(np.arange(len(predictions)), 8):
        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                f'predicted: {predictions[n]}\n\n')
        message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        logging.info('Prediction time (beam %d): %.4f s step %d, score %.4f.',
                     beam_size, time.time() - t_inference_start, step,
                     all_pred_acc / all_pred_denominator)
        summary_writer.scalar('predict/score-{}'.format(beam_size),
                              all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()
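
The clipped-perplexity pattern used above (and again in later examples) simply caps exp(mean log-loss) so early-training spikes do not blow up the TensorBoard scale. A standalone sketch with an illustrative loss value:

import jax.numpy as jnp

mean_log_loss = 12.0                                    # illustrative value only
perplexity = jnp.clip(jnp.exp(mean_log_loss), a_max=1.0e4)
print(perplexity)                                       # 10000.0, instead of exp(12) ~ 162755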
Example #3
 def inv(self, y):
     y_crop = y[..., :-1]
     z1m_cumprod = np.clip(1 - cumsum(y_crop), a_min=np.finfo(y.dtype).tiny)
     # hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod
     x = np.log(y_crop / z1m_cumprod)
     return x + np.log(x.shape[-1] - np.arange(x.shape[-1]))
Example #4
    def predict(self, params, logits, context, target=None):
        context = jnp.expand_dims(jnp.expand_dims(jnp.expand_dims(context,
                                                                  axis=1),
                                                  axis=1),
                                  axis=1)
        context_bias = params.get('context_bias', 0.0)
        context_index = (params['context_maps'] *
                         context).sum(axis=-1) > context_bias

        context_map_values = jnp.asarray(
            [[[[1 << n for n in range(self.context_map_size)]]]])
        context_index = jnp.where(context_index, context_map_values, 0)
        context_index = context_index.sum(axis=-1, keepdims=True)

        batch_size = logits.shape[0]
        class_neuron_index = jnp.asarray([[[[c, n] for n in range(self.size)]
                                           for c in range(self.num_classes)]])
        class_neuron_index = jnp.tile(class_neuron_index,
                                      reps=(batch_size, 1, 1, 1))
        context_index = jnp.concatenate([class_neuron_index, context_index],
                                        axis=-1)

        dims = lax.GatherDimensionNumbers(offset_dims=(3, ),
                                          collapsed_slice_dims=(0, 1, 2),
                                          start_index_map=(0, 1, 2))
        weights = lax.gather(operand=params['weights'],
                             start_indices=context_index,
                             dimension_numbers=dims,
                             slice_sizes=(1, 1, 1,
                                          self.input_size + int(self.bias)))

        if self.bias:
            bias = jnp.tile(params['bias'], reps=(batch_size, 1, 1))
            logits = jnp.concatenate([logits, bias], axis=-1)
        logits = jnp.expand_dims(logits, axis=-1)

        output_logits = jnp.matmul(weights, logits)
        output_logits = jnp.clip(output_logits,
                                 a_min=jsp.special.logit(self.pred_clipping),
                                 a_max=jsp.special.logit(1.0 -
                                                         self.pred_clipping))

        if target is None:
            return jnp.squeeze(output_logits, axis=-1)

        else:
            logits = jnp.expand_dims(jnp.squeeze(logits, axis=-1), axis=-2)
            output_preds = jnn.sigmoid(output_logits)
            target = jnp.expand_dims(jnp.expand_dims(target, axis=-1), axis=-1)
            params['lr_step'], learning_rate = self.learning_rate.value(
                params['lr_step'])
            delta = learning_rate * (target - output_preds) * logits

            dims = lax.ScatterDimensionNumbers(
                update_window_dims=(3, ),
                inserted_window_dims=(0, 1, 2),
                scatter_dims_to_operand_dims=(0, 1, 2))

            if self.weight_clipping is None:
                params['weights'] = lax.scatter_add(
                    operand=params['weights'],
                    scatter_indices=context_index,
                    updates=delta,
                    dimension_numbers=dims)
            else:
                weights = jnp.clip(weights + delta,
                                   a_min=-self.weight_clipping,
                                   a_max=self.weight_clipping)
                params['weights'] = lax.scatter(operand=params['weights'],
                                                scatter_indices=context_index,
                                                updates=weights,
                                                dimension_numbers=dims)

            return params, jnp.squeeze(output_logits, axis=-1)
Example #5
def clamp_probs(probs):
    finfo = np.finfo(get_dtype(probs))
    return np.clip(probs, a_min=finfo.tiny, a_max=1. - finfo.eps)
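
A self-contained sketch of the same pattern without numpyro's `get_dtype` helper (the `clamp_probs_standalone` name and the use of `jnp.result_type` are my own substitutions): clamping keeps probabilities strictly inside (0, 1) so downstream `log` calls stay finite.

import jax.numpy as jnp

def clamp_probs_standalone(probs):
    # Clamp to the smallest/largest representable values strictly inside (0, 1).
    finfo = jnp.finfo(jnp.result_type(probs, float))
    return jnp.clip(probs, a_min=finfo.tiny, a_max=1.0 - finfo.eps)

p = jnp.array([0.0, 0.5, 1.0])
print(jnp.log(clamp_probs_standalone(p)))  # finite everywhere, no -inf at p == 0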
Example #6
def snap_to_grid(sample):
    return jnp.clip(jnp.round((sample + 1) * 127.5) / 127.5 - 1, -1., 1.)
Example #7
def PropValve(x):
    y = 3.0 * x
    flow_new = 1.0 * (jnp.tanh(0.03 * (y - 130)) + 1.0)
    flow_new = jnp.clip(flow_new, 0.0, 1.72)
    return flow_new
Example #8
def _to_logits_multinom(probs):
    minval = np.finfo(get_dtypes(probs)[0]).min
    return np.clip(np.log(probs), a_min=minval)
Example #9
 def sample(self, px_z, rng):
     img = logistic_mix_sample(self.out_conv(px_z), rng)
     return jnp.round((jnp.clip(img, -1, 1) + 1) * 127.5).astype(jnp.uint8)
Example #10
 def clip(tensor, a_min=None, a_max=None, inplace=False):
     return np.clip(tensor, a_min, a_max)
Example #11
for step in range(2000):
    lr = 0.1
    if step > 8000:
        lr = 0.05
    s = time()
    loss_val, v_grad = value_and_grad(compute_loss, 4)(
        sim_time,
        n_steps,
        jax_scene,
        coordinate_init,
        velocity_init,
        target_coordinate,
        attractor,
        constants,
    )
    velocity_init -= lr * np.clip(v_grad, -10, 10)
    print(time() - s, loss_val, velocity_init, v_grad, coordinate_init)

final_coordinate, trajectory = run_sim(sim_time, n_steps, jax_scene,
                                       coordinate_init, velocity_init,
                                       attractor, constants)
lines = mc.LineCollection(scene.get_all_segments())
fig, ax = plt.subplots()

ax.add_collection(lines)
traj = onp.array(trajectory.coordinate)
ax.scatter(coordinate_init[0], coordinate_init[1], c="g")
ax.plot(traj[:, 0], traj[:, 1])
ax.scatter(attractor[0], attractor[1], c="b")
ax.scatter(target_coordinate[0], target_coordinate[1], c="r")
plt.savefig("trajectory.jpg")
Example #12
 def _prelu(x, slope):
     return jnp.clip(x, 0, jnp.inf) + jnp.clip(x, -jnp.inf, 0) * slope
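
A quick equivalence check (my own example, using the `_prelu` defined above) showing that the two-clip formulation matches the usual `where`-based PReLU:

import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.0, 1.5])
slope = 0.1
assert jnp.allclose(_prelu(x, slope), jnp.where(x > 0, x, slope * x))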
Example #13
 def _clip_gradient_norm(g):
     clip_coef = max_grad_norm / (
         jax.lax.stop_gradient(jnp.linalg.norm(g)) + 1e-6)
     clip_coef = jnp.clip(clip_coef, a_max=1.0)
     return g * clip_coef
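
Note that this helper rescales each gradient leaf by its own norm (not the global norm across the whole tree). A minimal self-contained sketch of how such a helper is typically applied; `max_grad_norm` is a closed-over hyperparameter in the original, so the value and the `clip_leaf` name below are my own placeholders:

import jax
import jax.numpy as jnp

max_grad_norm = 1.0  # placeholder; in the original this comes from the enclosing scope

def clip_leaf(g):
    clip_coef = max_grad_norm / (jax.lax.stop_gradient(jnp.linalg.norm(g)) + 1e-6)
    return g * jnp.clip(clip_coef, a_max=1.0)

grads = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([0.1])}
clipped = jax.tree_map(clip_leaf, grads)
# "w" (norm 5.0) is rescaled to norm ~1.0; "b" (norm 0.1) passes through unchanged.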
Example #14
def _safesub(x, y):
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        finfo = np.iinfo(y.dtype)
    return x + np.clip(-y, a_min=None, a_max=finfo.max)
Example #15
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
        config.vocab_path = vocab_path
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, _, encoder = input_pipeline.get_datasets(
        n_devices=jax.local_device_count(),
        config=config,
        vocab_path=vocab_path)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = temperature_sampler.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    def encode_strings(strs, max_len):
        tokenized_batch = np.zeros((len(strs), max_len), np.int32)
        for i, s in enumerate(strs):
            toks = encoder.tokenize(s).numpy()
            # Remove EOS token in prompt.
            tokenized_batch[i, :toks.shape[0] - 1] = toks[:-1]
        return tokenized_batch

    tokenized_prompts = encode_strings([config.prompts],
                                       config.max_predict_length)

    logging.info("Initializing model, optimizer, and step functions.")
    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    rng, inference_rng = random.split(rng)
    input_shape = (config.per_device_batch_size, config.max_target_length)

    m = models.TransformerLM(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step fn.
    p_train_step = jax.pmap(functools.partial(
        train_step, config=train_config, learning_rate_fn=learning_rate_fn),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(eval_step, config=eval_config),
                           axis_name="batch")

    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          temperature=config.sampling_temperature,
                          top_k=config.sampling_top_k),
        axis_name="batch",
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
        ]
    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, config.num_train_steps):
            is_last_step = step == config.num_train_steps - 1

            # Shard data to devices and do a training step.
            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = common_utils.shard(
                    jax.tree_map(np.asarray, next(train_iter)))
                optimizer, metrics = p_train_step(optimizer,
                                                  batch,
                                                  dropout_rng=dropout_rngs)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if step % config.eval_every_steps == 0 or is_last_step:
                with report_progress.timed("training_metrics"):
                    logging.info("Gathering training metrics.")
                    train_metrics = common_utils.get_metrics(train_metrics)
                    lr = train_metrics.pop("learning_rate").mean()
                    metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                    denominator = metrics_sums.pop("denominator")
                    summary = jax.tree_map(lambda x: x / denominator,
                                           metrics_sums)  # pylint: disable=cell-var-from-loop
                    summary["learning_rate"] = lr
                    summary["perplexity"] = jnp.clip(jnp.exp(summary["loss"]),
                                                     a_max=1.0e4)
                    summary = {"train_" + k: v for k, v in summary.items()}
                    writer.write_scalars(step, summary)
                    train_metrics = []

                with report_progress.timed("eval"):
                    eval_results = evaluate(
                        p_eval_step=p_eval_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=config.num_eval_steps)
                    # Calculate (clipped) perplexity after averaging log-perplexities:
                    eval_results["perplexity"] = jnp.clip(jnp.exp(
                        eval_results["loss"]),
                                                          a_max=1.0e4)
                    writer.write_scalars(
                        step,
                        {"eval_" + k: v
                         for k, v in eval_results.items()})

                with report_progress.timed("generate_text"):
                    exemplars = generate_prediction(
                        p_pred_step=p_pred_step,
                        target=optimizer.target,
                        tokenized_prompts=tokenized_prompts,
                        eos_id=eos_id,
                        inference_rng=inference_rng,
                        decode_tokens=decode_tokens,
                        max_predict_length=config.max_predict_length)
                    writer.write_texts(step, {"samples": exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = (step % config.checkpoint_every_steps == 0
                               or is_last_step)
            if config.save_checkpoints and save_checkpoint and jax.process_index(
            ) == 0:
                with report_progress.timed("checkpoint"):
                    checkpoints.save_checkpoint(
                        workdir, jax_utils.unreplicate(optimizer), step)
Example #16
File: iaf.py Project: juvu/numpyro
def _clamp_preserve_gradients(x, min, max):
    return x + stop_gradient(np.clip(x, a_min=min, a_max=max) - x)
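
This is the straight-through pattern: the forward value is clipped, but the `stop_gradient` term makes the backward pass behave like the identity, so gradients still flow outside the clipping range. A small check (restated in self-contained form; my own example):

import jax
import jax.numpy as jnp
from jax.lax import stop_gradient

def clamp_preserve_gradients(x, lo, hi):
    return x + stop_gradient(jnp.clip(x, a_min=lo, a_max=hi) - x)

value = clamp_preserve_gradients(jnp.float32(5.0), -1.0, 1.0)            # 1.0: the value is clipped
grad = jax.grad(clamp_preserve_gradients)(jnp.float32(5.0), -1.0, 1.0)   # 1.0: the gradient is not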
Example #17
def clipped_probab_ratios(probab_ratios, epsilon=0.2):
  return np.clip(probab_ratios, 1 - epsilon, 1 + epsilon)
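
For context, this is the clipping used in the standard PPO surrogate objective. A small sketch with illustrative numbers (the objective line is the textbook form, not necessarily this repo's exact code):

import jax.numpy as jnp

ratios = jnp.array([0.5, 1.0, 1.9])        # new_policy_prob / old_policy_prob
advantages = jnp.array([1.0, -1.0, 2.0])   # illustrative advantage estimates
epsilon = 0.2
clipped = jnp.clip(ratios, 1 - epsilon, 1 + epsilon)                # [0.8, 1.0, 1.2]
objective = jnp.minimum(ratios * advantages, clipped * advantages).mean()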
Example #18
    def kernel(rng_key: jax.random.PRNGKey,
               state: HMCState) -> Tuple[HMCState, HMCInfo]:
        """Moves the chain by one step using the Hamiltonian dynamics.

        Parameters
        ----------
        rng_key:
           The pseudo-random number generator key used to generate random numbers.
        state:
            The current state of the chain: position, log-probability and gradient
            of the log-probability.

        Returns
        -------
        The next state of the chain and additional information about the current step.
        """
        key_momentum, key_integrator, key_accept = jax.random.split(rng_key, 3)

        position, potential_energy, potential_energy_grad = state
        momentum = momentum_generator(key_momentum)
        energy = potential_energy + kinetic_energy(momentum)

        proposal, proposal_info = proposal_generator(
            key_integrator,
            HMCProposalState(position, momentum, potential_energy_grad))
        new_position, new_momentum, new_potential_energy_grad = proposal

        flipped_momentum = -1.0 * new_momentum  # to make the transition reversible
        new_potential_energy = potential(new_position)
        new_energy = new_potential_energy + kinetic_energy(flipped_momentum)
        new_state = HMCState(new_position, new_potential_energy,
                             new_potential_energy_grad)

        delta_energy = energy - new_energy
        delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf,
                                 delta_energy)
        is_divergent = jnp.abs(delta_energy) > divergence_threshold

        p_accept = jnp.clip(jnp.exp(delta_energy), a_max=1)
        do_accept = jax.random.bernoulli(key_accept, p_accept)
        accept_state = (
            new_state,
            HMCInfo(
                new_state,
                p_accept,
                True,
                is_divergent,
                new_energy,
                proposal,
                proposal_info,
            ),
        )
        reject_state = (
            state,
            HMCInfo(
                new_state,
                p_accept,
                False,
                is_divergent,
                energy,
                proposal,
                proposal_info,
            ),
        )
        return jax.lax.cond(
            do_accept,
            accept_state,
            lambda state: state,
            reject_state,
            lambda state: state,
        )
Example #19
 def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
     return jnp.clip(inputs, self.spec.minimum, self.spec.maximum)
Example #20
 def kl_to_standard_normal_fn(mu: Array, sigma: Array = sigma):
     v = jnp.clip(sigma**2, 1e-6, 1e6)
     return 0.5 * (jnp.sum(v) + jnp.sum(mu**2) -
                   jnp.sum(jnp.ones_like(mu)) - jnp.sum(jnp.log(v)))
Example #21
 def update_fn(updates, state, params=None):
     del params
     updates = jax.tree_map(lambda g: jnp.clip(g, -max_delta, max_delta),
                            updates)
     return updates, state
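
This has the shape of an optax-style element-wise clipping transformation (e.g. `optax.clip`). Assuming that is where it comes from, a typical way to compose it with an optimizer would be:

import optax

# Clip each gradient element to [-1, 1] before applying SGD.
tx = optax.chain(optax.clip(max_delta=1.0), optax.sgd(learning_rate=1e-2))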
Example #22
    'crop_border': lambda x: scaled_sigmoid_inverse(x, 0.0, 20.0),
    'hue_jitter': jax.scipy.special.logit,
    'sat_jitter': jax.scipy.special.logit,
}

cons_func_dict = {
    'lr':
    lambda x: jnp.exp(x),
    'b1':
    lambda x: scaled_sigmoid(x, 1e-4, 1.0 - 1e-4, False),
    'b2':
    lambda x: scaled_sigmoid(x, 1e-4, 1.0 - 1e-4, False),
    'momentum':
    jax.nn.sigmoid,
    'eps':
    lambda x: jnp.clip(
        jnp.exp(x), a_min=cons_range['eps'][0], a_max=cons_range['eps'][1]),
    'weight_decay':
    lambda x: jnp.clip(jnp.exp(x),
                       a_min=cons_range['weight_decay'][0],
                       a_max=cons_range['weight_decay'][1]),
    'l1':
    lambda x: jnp.exp(x),
    'l2':
    lambda x: jnp.exp(x),
    'cutoutsize':
    lambda x: scaled_sigmoid(x, 0.0, 22.0, True),
    'dropout_prob':
    lambda x: scaled_sigmoid(x, 1e-2, 0.95, False),
    'dropouts':
    lambda x: jnp.clip(scaled_sigmoid(x, 1e-2, 0.95, False),
                       a_min=cons_range['dropouts'][0],
Example #23
def binary_cross_entropy_with_logits(x, y):
    # compute -y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x))
    # Ref: https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
    return np.clip(x, 0) + np.log1p(np.exp(-np.abs(x))) - x * y
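
A quick numerical check (my own values) that this clipped `max(x, 0) + log1p(exp(-|x|))` form matches the naive sigmoid formula while staying finite for large logits:

import jax
import jax.numpy as jnp

x = jnp.array([-10.0, 0.0, 10.0])   # logits
y = jnp.array([0.0, 1.0, 1.0])      # binary targets
stable = jnp.clip(x, 0) + jnp.log1p(jnp.exp(-jnp.abs(x))) - x * y
naive = -y * jnp.log(jax.nn.sigmoid(x)) - (1 - y) * jnp.log(1 - jax.nn.sigmoid(x))
print(jnp.allclose(stable, naive, atol=1e-6))  # True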
Example #24
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    config = FLAGS.config
    logging.info('===========Config Dict============')
    logging.info(config)
    batch_size = config.batch_size
    learning_rate = config.learning_rate
    num_train_steps = config.num_train_steps
    num_eval_steps = config.num_eval_steps
    eval_freq = config.eval_frequency
    random_seed = config.random_seed
    model_type = config.model_type

    max_length = config.max_length

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'summary'))

    if batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    train_ds, eval_ds, test_ds, encoder = input_pipeline.get_matching_datasets(
        n_devices=jax.local_device_count(),
        task_name=FLAGS.task_name,
        data_dir=FLAGS.data_dir,
        batch_size=batch_size,
        fixed_vocab=None,
        max_length=max_length,
        tokenizer=config.tokenizer,
        vocab_file_path=FLAGS.vocab_file_path)

    vocab_size = encoder.vocab_size
    logging.info('Vocab Size: %d', vocab_size)

    train_ds = train_ds.repeat()

    train_iter = iter(train_ds)
    input_shape = (batch_size, max_length)

    model_kwargs = {
        'vocab_size': vocab_size,
        'emb_dim': config.emb_dim,
        'num_heads': config.num_heads,
        'num_layers': config.num_layers,
        'qkv_dim': config.qkv_dim,
        'mlp_dim': config.mlp_dim,
        'max_len': max_length,
        'classifier': True,
        'num_classes': 2,
        'classifier_pool': config.pooling_mode
    }

    rng = random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    rng, init_rng = random.split(rng)
    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())

    if model_type == 'transformer':
        model = create_model(init_rng, transformer.TransformerDualEncoder,
                             input_shape, input_shape, model_kwargs)
    else:
        raise ValueError('Model type not supported.')

    optimizer = create_optimizer(model,
                                 learning_rate,
                                 weight_decay=FLAGS.config.weight_decay)
    del model  # Don't keep a copy of the initial model.
    start_step = 0
    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors=config.factors,
        base_learning_rate=learning_rate,
        warmup_steps=config.warmup)
    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')
    # p_pred_step = jax.pmap(predict_step, axis_name='batch')

    metrics_all = []
    tick = time.time()
    logging.info('Starting training')
    logging.info('====================')

    for step, batch in zip(range(start_step, num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        # logging.info(batch)
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)
        logging.info('train in step: %d', step)

        # Save a Checkpoint
        if ((step % config.checkpoint_freq == 0 and step > 0)
                or step == num_train_steps - 1):
            if jax.host_id() == 0 and config.save_checkpoints:
                # Save unreplicated optimizer + model state.
                checkpoints.save_checkpoint(FLAGS.model_dir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

        # Periodic metric handling.
        if step % eval_freq == 0 and step > 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            # Calculate (clipped) perplexity after averaging log-perplexities:
            summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']),
                                             a_max=1.0e4)
            logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                         summary['loss'], summary['accuracy'])
            if jax.host_id() == 0:
                tock = time.time()
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                summary_writer.scalar('steps per second', steps_per_sec, step)
                for key, val in summary.items():
                    summary_writer.scalar(f'train_{key}', val, step)
                summary_writer.flush()
            # Reset metric accumulation for next evaluation cycle.
            metrics_all = []

            # Eval Metrics
            eval_metrics = []
            eval_iter = iter(eval_ds)
            if num_eval_steps == -1:
                num_iter = itertools.repeat(1)
            else:
                num_iter = range(num_eval_steps)
            for _, eval_batch in zip(num_iter, eval_iter):
                # pylint: disable=protected-access
                eval_batch = common_utils.shard(
                    jax.tree_map(lambda x: x._numpy(), eval_batch))
                # pylint: enable=protected-access
                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)
            # Calculate (clipped) perplexity after averaging log-perplexities:
            eval_summary['perplexity'] = jnp.clip(jnp.exp(
                eval_summary['loss']),
                                                  a_max=1.0e4)
            logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                         eval_summary['loss'], eval_summary['accuracy'])
            if jax.host_id() == 0:
                for key, val in eval_summary.items():
                    summary_writer.scalar(f'eval_{key}', val, step)
                summary_writer.flush()

            # Test eval
            # Eval Metrics
            logging.info('Testing...')
            test_metrics = []
            test_iter = iter(test_ds)
            if num_eval_steps == -1:
                num_iter = itertools.repeat(1)
            else:
                num_iter = range(num_eval_steps)
            for _, test_batch in zip(num_iter, test_iter):
                # pylint: disable=protected-access
                test_batch = common_utils.shard(
                    jax.tree_map(lambda x: x._numpy(), test_batch))
                # pylint: enable=protected-access
                metrics = p_eval_step(optimizer.target, test_batch)
                test_metrics.append(metrics)
            test_metrics = common_utils.get_metrics(test_metrics)
            test_metrics_sums = jax.tree_map(jnp.sum, test_metrics)
            test_denominator = test_metrics_sums.pop('denominator')
            test_summary = jax.tree_map(
                lambda x: x / test_denominator,  # pylint: disable=cell-var-from-loop
                test_metrics_sums)
            # Calculate (clipped) perplexity after averaging log-perplexities:
            test_summary['perplexity'] = jnp.clip(jnp.exp(
                test_summary['loss']),
                                                  a_max=1.0e4)
            logging.info('test in step: %d, loss: %.4f, acc: %.4f', step,
                         test_summary['loss'], test_summary['accuracy'])
            if jax.host_id() == 0:
                for key, val in test_summary.items():
                    summary_writer.scalar(f'test_{key}', val, step)
                summary_writer.flush()
Example #25
def deconstrain(v, a, b):
    return jnp.arctanh(jnp.clip((v - a) * 2. / (b - a) - 1., -0.999, 0.999))
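
The clip to ±0.999 keeps `arctanh` away from its ±inf asymptotes at the interval boundaries. A round-trip sketch using a presumed forward map (the `constrain` below is my own reconstruction of the inverse, not taken from the source repo):

import jax.numpy as jnp

def constrain(x, a, b):
    # Presumed inverse of deconstrain above: map the real line back into (a, b).
    return a + (b - a) * (jnp.tanh(x) + 1.0) / 2.0

v = jnp.array([0.1, 2.5, 4.9])
print(constrain(deconstrain(v, 0.0, 5.0), 0.0, 5.0))  # ~[0.1, 2.5, 4.9]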
Example #26
 def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
     return jnp.clip(inputs, self._min, self._max)
Example #27
def _clipped_expit(x):
    finfo = np.finfo(get_dtype(x))
    return np.clip(expit(x), a_min=finfo.tiny, a_max=1. - finfo.eps)
Example #28
 def schedule(count):
   count = jnp.clip(count - transition_begin, 0, transition_steps)
   frac = 1 - count / transition_steps
   return (init_value - end_value) * (frac**power) + end_value
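
This is the shape of a polynomial decay schedule (cf. optax's `polynomial_schedule`); the closed-over values below are my own illustrative numbers to make the snippet runnable:

import jax.numpy as jnp

init_value, end_value, power = 1.0, 0.1, 2.0
transition_begin, transition_steps = 10, 100

def schedule(count):
    count = jnp.clip(count - transition_begin, 0, transition_steps)
    frac = 1 - count / transition_steps
    return (init_value - end_value) * (frac**power) + end_value

print(schedule(0), schedule(60), schedule(200))  # approximately 1.0, 0.325, 0.1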
Example #29
def _to_logits_multinom(probs):
    minval = jnp.finfo(get_dtype(probs)).min
    return jnp.clip(jnp.log(probs), a_min=minval)
Example #30
def _reciprocal(x):
    result = np.clip(np.reciprocal(x), a_max=np.finfo(x.dtype).max)
    return result