Example 1
  def _sync_weights_and_state_across_hosts(self):
    """Sync weights and state across all the hosts in the computation."""

    if logging.vlog_is_on(1):
      logging.debug(
          'Input training weights shape: %s',
          fastmath.nested_map(lambda x: x.shape,
                              self._model.weights))
      logging.debug('Input training weights: %s', self._model.weights)
      logging.debug('Input training state: %s', self._model.state)
      logging.debug('Input eval weights: %s', self._eval_model.weights)
      logging.debug('Input eval state: %s', self._eval_model.state)

    (self._model.weights, self._model.state,
     self._eval_model.weights, self._eval_model.state) = self._unreplicate(
         _make_weights_and_state_same_across_hosts(
             self._for_n_devices(
                 (self._model.weights, self._model.state,
                  self._eval_model.weights,
                  self._eval_model.state))))

    if logging.vlog_is_on(1):
      logging.debug(
          'Output training weights shape: %s',
          fastmath.nested_map(lambda x: x.shape, self._model.weights))
      logging.debug('Output training weights: %s', self._model.weights)
      logging.debug('Output training state: %s', self._model.state)
      logging.debug('Output eval weights: %s', self._eval_model.weights)
      logging.debug('Output eval state: %s', self._eval_model.state)
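Example 1 pays for the weight-tree traversal only when verbose logging is enabled. A minimal standalone sketch of that guard pattern with absl logging, using a toy weights list rather than trax's pytrees:

from absl import logging
import numpy as np

def log_weight_shapes(weights):
    if logging.vlog_is_on(1):  # true when --verbosity is 1 or higher
        # The shape walk runs only inside the guard, so it costs nothing
        # at the default verbosity.
        logging.debug('weights shapes: %s', [w.shape for w in weights])

log_weight_shapes([np.zeros((2, 3)), np.ones((4,))])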
Example 2
    def one_step(self, batch, rng, step=0, learning_rate=None):
        """Updates loss layer weights/state and optimizer slots by running one step.

    Args:
      batch: Batch of data to use for optimization.
      rng: Random number generator to use for running this step.
      step: Which step of the training are we running.
      learning_rate: Learning rate to use instead of the default one.

    Returns:
      Tuple (loss, stats) with new values from one step
      of training, where stats are current optimizer statistics.
    """
        # Update the learning rate if needed.
        if learning_rate is not None:
            self._opt_params['learning_rate'] = tl.for_n_devices(
                learning_rate, self._n_devices)

        # batch needs to be split across the local devices -- the difference
        # between _for_n_devices and _reshape_by_device is that the latter splits
        # the batch dim to batch // n_devices, vs _for_n_devices
        # broadcasts/replicates to n_devices dimension.
        if self._n_devices > 1:
            batch = tl.reshape_by_device(batch, self._n_devices)

        # separate rng needs to be created for each device
        if self._n_devices > 1:
            rng = jnp.stack(fastmath.random.split(rng, self._n_devices))

        weights = self._accelerated_loss_layer.weights
        state = self._accelerated_loss_layer.state
        if logging.vlog_is_on(1) and ((step & step - 1) == 0):
            # Prints every power of two, if debugging is enabled.
            logging.info('step[%d]', step)
            logging.info('opt_params[%s]', self._opt_params)
            logging.info('slots[%s]', self._slots)
            logging.info('weights[%s]', weights)
            logging.info('state[%s]', state)

        # NOTE: stats is a replicated dictionary of key to jnp arrays.
        (new_weights,
         new_slots), new_state, stats = self._accelerated_update_fn(
             (weights, self._slots), step, self._opt_params, batch, state, rng)

        if logging.vlog_is_on(1) and ((step & step - 1) == 0):
            logging.info('updated weights[%s]', new_weights)
            logging.info('stats[%s]', stats)

        self._accelerated_loss_layer.weights = new_weights
        self._accelerated_loss_layer.state = new_state
        self._slots = new_slots
        self._optimizer.slots = self._unreplicate(self._slots)
        return stats['loss'], stats
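The `(step & step - 1) == 0` guard above fires on steps 0, 1, 2, 4, 8, and so on: `step & (step - 1)` clears the lowest set bit, so it is zero exactly for zero and the powers of two. A quick standalone check of the bit trick:

def debug_log_this_step(step: int) -> bool:
    return (step & (step - 1)) == 0

assert [s for s in range(9) if debug_log_this_step(s)] == [0, 1, 2, 4, 8]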
Example 3
    def one_step(self, batch, rng, step=0, learning_rate=None):
        """Runs one training step, to update model and optimizer parameters.

    Args:
      batch: Batch of labeled training data.
      rng: Single-use random number generator (JAX PRNG key).
      step: Training step number.
      learning_rate: Learning rate for the optimizer; if None, use optimizer's
          default learning rate.

    Returns:
      Tuple of (loss, optimizer_stats), with the newly computed loss and
      updated stats as reported by the optimizer.
    """
        if learning_rate is not None:
            self._opt_params['learning_rate'] = tl.for_n_devices(
                learning_rate, self._n_devices)

        # Split the batch across devices (batch_dim --> batch_dim // n_devices)
        # and create new rng's 1-1 with devices.
        if self._n_devices > 1:
            batch = tl.reshape_by_device(batch, self._n_devices)
            rng = jnp.stack(fastmath.random.split(rng, self._n_devices))

        weights = self._accelerated_model_with_loss.weights
        state = self._accelerated_model_with_loss.state
        if logging.vlog_is_on(1) and ((step & step - 1) == 0):
            # Prints every power of two, if debugging is enabled.
            logging.info('step[%d]', step)
            logging.info('opt_params[%s]', self._opt_params)
            logging.info('slots[%s]', self._slots)
            logging.info('weights[%s]', weights)
            logging.info('state[%s]', state)

        # NOTE: stats is a replicated dictionary of key to jnp arrays.
        (new_weights,
         new_slots), new_state, stats = self._accelerated_update_fn(
             (weights, self._slots), step, self._opt_params, batch, state, rng)

        if logging.vlog_is_on(1) and ((step & step - 1) == 0):
            logging.info('updated weights[%s]', new_weights)
            logging.info('stats[%s]', stats)

        self._accelerated_model_with_loss.weights = new_weights
        self._accelerated_model_with_loss.state = new_state
        self._slots = new_slots
        self._optimizer.slots = self._unreplicate(self._slots)
        return stats['loss'], stats
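The comments in Examples 2 and 3 distinguish splitting a batch across devices from replicating a value per device. An illustrative numpy sketch of the two behaviors (not the trax API itself):

import numpy as np

def reshape_by_device(x, n_devices):
    # Split the leading batch dim: (batch, ...) -> (n_devices, batch // n_devices, ...).
    batch = x.shape[0]
    assert batch % n_devices == 0, 'batch must divide evenly across devices'
    return x.reshape((n_devices, batch // n_devices) + x.shape[1:])

def for_n_devices(x, n_devices):
    # Broadcast/replicate: (batch, ...) -> (n_devices, batch, ...).
    return np.stack([x] * n_devices)

batch = np.zeros((8, 3))
assert reshape_by_device(batch, 4).shape == (4, 2, 3)
assert for_n_devices(batch, 4).shape == (4, 8, 3)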
Example 4
def get_cache_key(xla_computation, compile_options, backend) -> str:
    """Creates a hashed string to use as a key to the compilation cache.

       get_cache_key takes in the xla_computation and compile_options of a program and hashes
       all the components into a unique byte string. This hash is returned as a hex-encoded
       string that is 64 characters long.

       Typical return value example:

            '14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'

    """
    hash_obj = hashlib.sha256()
    # The HLO op_name metadata sometimes includes Python function pointers,
    # which cause spurious cache misses. Scrub anything that looks like a
    # function pointer. Example op_name metadata:
    #  op_name="jit(s)/custom_jvp_call_jaxpr
    #   [ jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f3fa30f0940>\n
    #     num_consts=0 ]"
    # TODO(skye): in theory this could cause us to scrub meaningful binary proto
    # data. Do something more robust.
    serialized_hlo = xla_computation.as_serialized_hlo_module_proto()
    scrubbed_hlo = re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", serialized_hlo)
    hash_obj.update(scrubbed_hlo)
    if logging.vlog_is_on(1):
        logging.vlog(
            1,
            f"get_cache_key hash after serializing computation: {hash_obj.digest().hex()}"
        )
    _hash_compile_options(hash_obj, compile_options)
    if logging.vlog_is_on(1):
        logging.vlog(
            1,
            f"get_cache_key hash after serializing compile_options: {hash_obj.digest().hex()}"
        )
    hash_obj.update(bytes(jax.lib.version))
    if logging.vlog_is_on(1):
        logging.vlog(
            1,
            f"get_cache_key hash after serializing jax_lib version: {hash_obj.digest().hex()}"
        )
    _hash_platform(hash_obj, backend)
    if logging.vlog_is_on(1):
        logging.vlog(
            1,
            f"get_cache_key hash after serializing the backend: {hash_obj.digest().hex()}"
        )
    return hash_obj.digest().hex()
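The pointer-scrubbing step above can be exercised on its own; a small sketch with an op_name payload shaped like the comment's example (the payload itself is made up):

import re

serialized = b'op_name="jit(s)/f [ thunk=<function memoized at 0x7f3fa30f0940> ]"'
scrubbed = re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", serialized)
assert scrubbed == b'op_name="jit(s)/f [ thunk=<function memoized at 0x...> ]"'

Note that hashlib objects are not finalized by digest(), which is why the function can log intermediate digests after each update and still keep hashing.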
Example 5
  def _run_one_step(self, weights, state, slots, opt_params):
    """Updates model weights/state and optimizer slots by running one step.

    Args:
      weights: Weights from model being trained.
      state: State (non-weight parameters) from model being trained.
      slots: Updatable weights for the optimizer in this training loop.
      opt_params: Dictionary of optimizer (hyper)parameters,
        e.g. learning rate, momentum.

    Returns:
      Tuple (loss, weights, state, slots, stats) with new values from one step
      of training, where stats are current optimizer statistics.
    """
    step = self.step
    # Update the learning rate.
    opt_params['learning_rate'] = self._for_n_devices(
        self._task.learning_rate(step))

    batch = self._task.next_batch()
    # batch needs to be split across the local devices -- the difference
    # between _for_n_devices and _reshape_by_device is that the latter splits
    # the batch dim to batch // n_devices, vs _for_n_devices
    # broadcasts/replicates to n_devices dimension.
    batch = self._reshape_by_device(batch)

    rng = self.new_rng()
    if self.n_devices > 1:
      rng = jnp.stack(jax_random.split(rng, self.n_devices))

    if logging.vlog_is_on(1) and ((step & step - 1) == 0):
      # Prints every power of two, if debugging is enabled.
      logging.info('step[%d]', step)
      logging.info('opt_params[%s]', opt_params)
      logging.info('weights[%s]', weights)

    # NOTE: stats is a replicated dictionary of key to jnp arrays.
    (weights, slots), state, stats = (
        self._accelerated_update_fn(
            (weights, slots), step, opt_params, batch, state, rng)
        )

    if logging.vlog_is_on(1) and ((step & step - 1) == 0):
      logging.info('updated weights[%s]', weights)
      logging.info('stats[%s]', stats)

    return stats['loss'], weights, state, slots, stats
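Several of these trainers (Examples 1 to 3) collapse replicated values with `self._unreplicate`, and Example 5's `stats` are likewise replicated per device. The helper is not shown; a common implementation, sketched here as an assumption rather than trax's actual code, keeps replica 0 along the leading device axis:

import numpy as np

def unreplicate(tree):
    # Take replica 0 of every array in a (possibly nested) list/tuple tree.
    if isinstance(tree, (list, tuple)):
        return type(tree)(unreplicate(t) for t in tree)
    return tree[0]

replicated = (np.stack([np.arange(3)] * 4), [np.zeros((4, 2))])
weights, state = unreplicate(replicated)
assert weights.shape == (3,) and state[0].shape == (2,)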
Example 6
def _apply_repeated_text_masking(
    config: RetrieverConfig,
    question_hash: tf.Tensor,
    question_hash_transposed: tf.Tensor,
    labels: tf.Tensor,
    logits: tf.Tensor,
) -> tf.Tensor:
    """Applies repated text masking.

  Args:
    config: Retriever config.
    question_hash: <int64>[global_batch_size, 1]
    question_hash_transposed: <int64>[1, batch_size]
    labels: <int64>[batch_size, global_batch_size * num_tables]
    logits: <float>[batch_size, global_batch_size * num_tables]

  Returns:
    Masked logits (same shape / dtype).
  """
    # Make sure not all hashes are 0.
    # This indicates the "question_hash" feature wasn't set.
    assert_op = tf.assert_equal(
        tf.math.reduce_all(tf.math.equal(question_hash, 0)), [False])
    with tf.control_dependencies([assert_op]):
        logging.vlog(2, "question_hash: %s", question_hash)
        logging.vlog(2, "question_hash_transposed: %s",
                     question_hash_transposed)
        logging.vlog(2, "labels: %s", labels)
        logging.vlog(2, "logits: %s", logits)
        # <bool>[batch_size, global_batch_size]
        repeated_texts = tf.math.equal(question_hash, question_hash_transposed)
        if config.use_mined_negatives:
            batch_size = repeated_texts.shape[0]
            global_batch_size = repeated_texts.shape[1]
            num_tables = logits.shape[1] // global_batch_size
            # <bool>[batch_size, global_batch_size * num_tables]
            repeated_texts = tf.concat(
                [
                    repeated_texts,
                    tf.zeros(
                        shape=(batch_size, (num_tables - 1) * global_batch_size),
                        dtype=tf.bool),
                ],
                axis=1)
        repeated_texts = (
            repeated_texts
            # Makes sure original correct question pair isn't masked
            & tf.math.equal(labels, 0))
        logging.vlog(2, "repeated texts: %s", repeated_texts)
    ops = []
    if logging.vlog_is_on(2):
        ops.append(
            tf.print(
                "repeated texts content:",
                question_hash,
                repeated_texts,
                output_stream=logging.info,
            ))
    with tf.control_dependencies(ops):
        return tf.where(repeated_texts, tf.zeros_like(logits) - _INF, logits)
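A toy run of the masking idea above, assuming a three-question batch in which rows 0 and 1 share a question hash; the diagonal here stands in for the labeled correct pairs that the real code protects via `labels == 0`, and mined negatives are omitted:

import tensorflow as tf

question_hash = tf.constant([[7], [7], [9]], dtype=tf.int64)           # [3, 1]
repeated = tf.math.equal(question_hash, tf.transpose(question_hash))   # [3, 3]
keep = tf.cast(tf.eye(3), tf.bool)            # protect each row's own pair
logits = tf.zeros((3, 3))
masked = tf.where(repeated & ~keep, tf.fill((3, 3), float('-inf')), logits)
# masked[0, 1] and masked[1, 0] are now -inf; all other entries are unchanged.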
Example 7
def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
    if logging.vlog_is_on(1):
        # Log the hash of just this entry
        fresh_hash_obj = hashlib.sha256()
        hashfn(fresh_hash_obj)
        logging.vlog(1, "get_cache_key hash of serialized %s: %s",
                     last_serialized,
                     fresh_hash_obj.digest().hex())
        # Log the cumulative hash
        logging.vlog(1, "get_cache_key hash after serializing %s: %s",
                     last_serialized,
                     hash_obj.digest().hex())
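A hedged usage sketch for the helper above: the caller hashes each component through a callback, so the helper can replay that callback on a fresh object and log the component's standalone hash alongside the cumulative one. The `hash_computation` callback and its payload are illustrative assumptions:

import hashlib

def hash_computation(obj, data=b'<serialized hlo>'):
    obj.update(data)

hash_obj = hashlib.sha256()
hash_computation(hash_obj)
_log_cache_key_hash(hash_obj, 'computation', hash_computation)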
Example 8
    def match(self, context: MatchContext,
              candidate: Any) -> Optional[MatchInfo]:
        """Matches a candidate value.

    Args:
      context: A :class:`MatchContext` object with additional metadata.
      candidate: A candidate object to be matched against.

    Returns:
      A :class:`MatchInfo` object, or ``None`` if the match failed.
    """
        matched = self._match(context, candidate)
        if logging.vlog_is_on(self._log_level):
            self._log_match(context, candidate, matched)
        return matched
Example 9
    def _run_command(self, desc: Text, args: Sequence[Text]):
        """Runs the given commands.

    Args:
      desc: Textual description of what the command is doing. Emitted to stdout.
      args: The command line arguments.

    Returns:
      Stdout of the command.

    Raises:
      subprocess.CalledProcessError: If subprocess returns non-zero code.
    """
        # Print the command line with the runfiles directory prefix elided to reduce
        # clutter.
        if logging.get_verbosity() > 0:
            args = list(args) + ['-v={}'.format(logging.get_verbosity())]
        cmd_line = subprocess.list2cmdline(args)
        logging.vlog(1, '%s:  %s', desc, cmd_line)
        start = time.time()
        basename = os.path.basename(args[0])
        stderr_path = os.path.join(self._run_dir, basename + '.stderr')
        with open(stderr_path, 'w') as f_stderr:
            comp = subprocess.run(list(args) + ['--logtostderr'],
                                  cwd=self._run_dir,
                                  stdout=subprocess.PIPE,
                                  stderr=f_stderr,
                                  check=False)

        if logging.vlog_is_on(4):
            logging.vlog(4, '{} stdout:'.format(basename))
            # stdout and stderr can be long so split them by line to avoid clipping.
            for line in comp.stdout.decode('utf-8').splitlines():
                logging.vlog(4, line)

            logging.vlog(4, '{} stderr:'.format(basename))
            with open(stderr_path, 'r') as f:
                for line in f.read().splitlines():
                    logging.vlog(4, line)

        logging.vlog(1, '%s complete, elapsed %0.2fs', desc,
                     time.time() - start)

        comp.check_returncode()

        return comp.stdout.decode('utf-8')
Example 10
def start_python_aggregator(worker_port: str,
                            aggregator_port: str) -> subprocess.Popen:
    """Starts running Python aggregator in a subprocess."""
    python_service_binary = os.path.join(
        tf.compat.v1.resource_loader.get_root_dir_with_all_resources(),
        tf.compat.v1.resource_loader.get_path_to_datafile('test_aggregator'))

    args = [
        python_service_binary,
        f'--worker_port={worker_port}',
        f'--aggregator_port={aggregator_port}',
    ]
    logging.info('Starting python aggregator service via: %s', args)
    if logging.vlog_is_on(1):
        proc = subprocess.Popen(args, stdout=sys.stdout, stderr=sys.stderr)
    else:
        proc = subprocess.Popen(args)
    return proc
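A hedged usage sketch for the launcher above; the port values are placeholders:

proc = start_python_aggregator(worker_port='8081', aggregator_port='8082')
try:
    pass  # drive the test against the running services here
finally:
    proc.terminate()
    proc.wait()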
Example 11
async def decklist(ctx: Context, url: str, mode: str = 'compact') -> None:
    if url.startswith('<') and url.endswith('>'):
        url = url[1:-1]
    logging.info('Looking up decklist for: %s', url)
    handler = decklist_handlers.lookup(url)
    try:
        decklist = await handler(ctx, url)
    except requests.RequestException:
        logging.exception('RequestException during decklist handler.')
        decklist = None
    if decklist:
        logging.info('Found decklist named: %s', decklist.name)
        if logging.vlog_is_on(1):
            logging.vlog(1, 'Decklist contents: %s',
                         pprint.pformat(decklist.to_embed().to_dict()))
        await ctx.send(embed=decklist.to_embed(mode == 'flat'))
    else:
        logging.info('No decklist found for: %s', url)
Example 12
 def to_embed(self, flat: bool = False) -> Embed:
     embed = Embed(title=self.name or '', url=self.url or Embed.Empty)
     embed.set_author(name=self.author or '',
                      url=self.author_url or Embed.Empty)
     embed.set_thumbnail(url=self.thumbnail or Embed.Empty)
     cards_by_type = self._get_cards_by_type()
     if logging.vlog_is_on(1):
         logging.vlog(1, 'Cards by type are: %s',
                      pprint.pformat(cards_by_type))
     for type_ in [
             'Land', 'Creature', 'Sorcery', 'Instant', 'Artifact',
              'Enchantment', 'Planeswalker', 'Unknown'
     ]:
         cards_body = '\n'.join(
             f'{num} {card.name}' + (f'   {card.cost}' if flat else '')
             for card, num in cards_by_type.get(type_, []))
         if not cards_body:
             continue
         cards_body = manamojidb.substitute(cards_body)
         embed.add_field(name=type_, value=cards_body, inline=not flat)
     return embed
Example 13
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    platform = xb.get_backend().platform

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(log_priority, "Compiling %s for %d devices with args %s.",
                fun.__name__, nparts, global_abstract_args)

    axis_env = xla.AxisEnv(nrep, (), ())
    unordered_effects = [
        eff for eff in jaxpr.effects if eff not in core.ordered_effects
    ]
    ordered_effects = [
        eff for eff in jaxpr.effects if eff in core.ordered_effects
    ]
    module, _ = mlir.lower_jaxpr_to_module(
        f"spjit_{fun.__name__}",
        core.ClosedJaxpr(jaxpr, consts),
        unordered_effects,
        ordered_effects,
        platform=platform,
        axis_context=mlir.ReplicaAxisContext(axis_env),
        name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
        donated_args=[False] * len(in_parts),
        arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
        result_shardings=safe_map(xla.sharding_to_proto, out_parts))
    built = xc._xla.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = dispatch.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
Example 14
    def train_epoch(self, evaluate=True):
        epoch_start_time = time.time()

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if evaluate and (self.epoch + 1) % self._eval_every_n == 0:
            self.evaluate()
        policy_eval_time = policy_based_utils.get_time(policy_eval_start_time)

        def write_metric(key, value):
            self._train_sw.scalar(key, value, step=self.epoch)
            self._history.append('train', key, self.epoch, value)

        # Get fresh trajectories every time.
        self._should_reset_train_env = True

        trajectory_collection_start_time = time.time()
        logging.vlog(1, 'AWR epoch [% 6d]: collecting trajectories.',
                     self._epoch)
        trajs, _, timing_info, self._model_state = self.collect_trajectories(
            train=True, temperature=1.0, raw_trajectory=True)
        del timing_info
        trajectory_collection_time = policy_based_utils.get_time(
            trajectory_collection_start_time)

        logging.vlog(1, 'AWR epoch [% 6d]: n_trajectories [%s].', self._epoch,
                     len(trajs))

        # Convert these into numpy now.
        def extract_obs_act_rew_dones(traj_np):
            return traj_np[0], traj_np[1], traj_np[2], traj_np[4]

        trajs_np = [extract_obs_act_rew_dones(traj.as_numpy) for traj in trajs]

        # number of new actions.
        new_sample_count = sum(traj[1].shape[0] for traj in trajs_np)
        self._n_observations_seen += new_sample_count
        logging.vlog(1, 'AWR epoch [% 6d]: new_sample_count [%d].',
                     self._epoch, new_sample_count)

        if self._should_write_summaries:
            write_metric('trajs/batch', len(trajs))
            write_metric('trajs/new_sample_count', new_sample_count)

        # The number of trajectories, i.e. `B`, can keep changing from
        # iteration to iteration, since we are capped on the number of
        # observations requested. So let's operate on each trajectory on
        # its own?

        # TODO(afrozm): So should our batches look like (B, T+1, *OBS) or B
        # different examples of (T+1, *OBS) each. Since B can keep changing?

        # Add these to the replay buffer.
        for traj in trajs:
            _ = self._replay_buffer.store(traj)

        rewards = jnp.array([jnp.sum(traj[2]) for traj in trajs_np])
        avg_reward = jnp.mean(rewards)
        std_reward = jnp.std(rewards)
        max_reward = jnp.max(rewards)
        min_reward = jnp.min(rewards)

        self._log('train', 'train/reward_mean_truncated', avg_reward)
        if evaluate and not self._separate_eval and self._should_write_summaries:
            metrics = {'raw': {1.0: {'mean': avg_reward, 'std': std_reward}}}
            policy_based_utils.write_eval_reward_summaries(
                metrics, self._log, self.epoch)

        logging.vlog(
            1, 'AWR epoch [% 6d]: Rewards avg=[%0.2f], max=[%0.2f], '
            'min=[%0.2f].', self.epoch, avg_reward, max_reward, min_reward)

        if self._should_write_summaries:
            write_metric('reward/avg', avg_reward)
            write_metric('reward/std', std_reward)
            write_metric('reward/max', max_reward)
            write_metric('reward/min', min_reward)

        # Wrap these observations/rewards inside ReplayBuffer.
        idx, valid_mask, valid_idx = self._replay_buffer.get_valid_indices()

        # pylint: disable=g-complex-comprehension
        observations = [
            self._replay_buffer.get(
                replay_buffer.ReplayBuffer.OBSERVATIONS_KEY,
                idx[start_idx:end_plus_1_idx])
            for (start_idx,
                 end_plus_1_idx) in self._replay_buffer.iterate_over_paths(idx)
        ]

        rewards = [
            self._replay_buffer.get(replay_buffer.ReplayBuffer.REWARDS_KEY,
                                    idx[start_idx:end_plus_1_idx][:-1])
            for (start_idx,
                 end_plus_1_idx) in self._replay_buffer.iterate_over_paths(idx)
        ]
        # pylint: enable=g-complex-comprehension

        t_final = awr_utils.padding_length(rewards, boundary=self._boundary)
        logging.vlog(1, 'AWR epoch [% 6d]: t_final [%s].', self._epoch,
                     t_final)

        if self._should_write_summaries:
            write_metric('trajs/t_final', t_final)

        # These padded observations are over *all* the non-final observations in
        # the entire replay buffer.
        # Shapes:
        # padded_observations      = (B, T + 1, *OBS)
        # padded_observations_mask = (B, T + 1)
        padded_observations, padded_observations_mask = (
            awr_utils.pad_array_to_length(observations, t_final + 1))

        batch = len(observations)
        self._check_shapes('padded_observations',
                           '(batch, t_final + 1)',
                           padded_observations, (batch, t_final + 1),
                           array_prefix=2)
        self._check_shapes('padded_observations_mask', '(batch, t_final + 1)',
                           padded_observations_mask, (batch, t_final + 1))

        # Shapes:
        # padded_rewards      = (B, T)
        # padded_rewards_mask = (B, T)
        padded_rewards, padded_rewards_mask = awr_utils.pad_array_to_length(
            rewards, t_final)
        self._check_shapes('padded_rewards', '(batch, t_final)',
                           padded_rewards, (batch, t_final))
        self._check_shapes('padded_rewards_mask', '(batch, t_final)',
                           padded_rewards_mask, (batch, t_final))

        # Shapes:
        # lengths = (B,)
        lengths = jnp.sum(padded_rewards_mask, axis=1, dtype=jnp.int32)
        self._check_shapes('lengths', '(batch,)', lengths, (batch, ))

        # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
        # action sampling.
        dummy_actions = jnp.zeros(
            (batch, t_final + 1) + self._action_shape,
            self._action_dtype,
        )

        # Shapes:
        # log_probabs_traj       = (B, T + 1, #controls, #actions)
        # value_predictions_traj = (B, T + 1)
        log_probabs_traj, value_predictions_traj, self._model_state, unused_rng = (
            self._policy_fun_all_timesteps(padded_observations, lengths,
                                           self._model_state, self._get_rng()))
        self._check_shapes(
            'log_probabs_traj', '(batch, t_final + 1, n_controls, n_actions)',
            log_probabs_traj,
            (batch, t_final + 1, self._n_controls, self._n_actions))
        self._check_shapes('value_predictions_traj', '(batch, t_final + 1)',
                           value_predictions_traj, (batch, t_final + 1))

        # Zero out the padding's value predictions, since the net may give some
        # prediction to the padding observations.
        value_predictions_traj *= padded_observations_mask

        # Compute td-lambda returns, and reshape to match value_predictions_traj.
        list_td_lambda_returns = awr_utils.batched_compute_td_lambda_return(
            padded_rewards, padded_rewards_mask, value_predictions_traj,
            padded_observations_mask, self._gamma, self._td_lambda)

        if logging.vlog_is_on(1) and list_td_lambda_returns:
            l = len(list_td_lambda_returns)
            logging.vlog(1, f'Len of list_td_lambda_returns: {l}.')
            self._log_shape('td_lambda_returns[0]', list_td_lambda_returns[0])

        # pad an extra 0 for each to match lengths of value predictions.
        list_target_values = [
            np.pad(l, (0, 1), 'constant') for l in list_td_lambda_returns
        ]

        if batch != len(list_target_values):
            raise ValueError(f'batch != len(list_target_values) : '
                             f'{batch} vs {len(list_target_values)}')

        # Shape: (len(idx),)
        target_values = np.concatenate(list_target_values)
        self._check_shapes('target_values', '(len(idx),)', target_values,
                           (len(idx), ))

        # Shape: (len(idx),)
        vals = self.flatten_vals(value_predictions_traj,
                                 padded_observations_mask)
        self._check_shapes('vals', '(len(idx),)', vals, (len(idx), ))

        # Calculate advantages.
        adv, norm_adv, adv_mean, adv_std = self._calc_adv(
            target_values, vals, valid_mask)
        self._check_shapes('norm_adv', '(len(idx),)', norm_adv, (len(idx), ))

        adv_weights, adv_weights_mean, adv_weights_min, adv_weights_max = (
            self._calc_adv_weights(norm_adv, valid_mask))
        self._check_shapes('adv_weights', '(len(idx),)', adv_weights,
                           (len(idx), ))

        del adv, adv_mean, adv_std
        del adv_weights_min, adv_weights_max, adv_weights_mean

        combined_steps = int(
            jnp.ceil(self._optimization_steps * new_sample_count /
                     self._num_samples_to_collect))
        optimization_start_time = time.time()
        combined_losses = self._update_combined(combined_steps, valid_idx,
                                                target_values, adv_weights)
        optimization_time = policy_based_utils.get_time(
            optimization_start_time)

        self._epoch += 1

        if self._should_write_summaries:
            write_metric('combined/optimization_steps', combined_steps)
            epoch_time = policy_based_utils.get_time(epoch_start_time)
            timing_dict = {
                'epoch': epoch_time,
                'trajectory_collection': trajectory_collection_time,
                'optimization': optimization_time,
                'policy_eval': policy_eval_time,
            }

            if self._should_write_summaries:
                for k, v in timing_dict.items():
                    write_metric('timing/{}'.format(k), v)

            # Only dump the average post losses.
            if combined_losses:
                for k, v in combined_losses.items():
                    if 'post_entropy' in k:
                        write_metric(k.replace('post_entropy', 'entropy'), v)
                    if 'post_loss' in k:
                        write_metric(k.replace('post_loss', 'loss'), v)

        self.flush_summaries()
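The shape comments in this example describe padding variable-length per-trajectory arrays into dense (B, T) batches plus 0/1 masks. An illustrative numpy version of that contract (an assumption about awr_utils.pad_array_to_length, not its actual code):

import numpy as np

def pad_array_to_length(arrays, length):
    padded = np.zeros((len(arrays), length))
    mask = np.zeros((len(arrays), length))
    for i, a in enumerate(arrays):
        padded[i, :len(a)] = a
        mask[i, :len(a)] = 1.0
    return padded, mask

rewards = [np.array([1.0, 2.0]), np.array([3.0])]
padded, mask = pad_array_to_length(rewards, 3)
assert padded.shape == (2, 3)
assert mask.tolist() == [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]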
Example 15
def _log_cache_key_hash(hash_obj, last_serialized: str):
  if logging.vlog_is_on(1):
    logging.vlog(1, "get_cache_key hash after serializing %s: %s",
                 last_serialized, hash_obj.digest().hex())
Example 16
def _sharded_callable(
        fun: lu.WrappedFun, nparts: Optional[int],
        in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
        out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
        local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
        local_out_parts_thunk: Callable[[], Optional[Tuple[
            pxla.PartitionsOrReplicated,
            ...]]], local_nparts: Optional[int], name: str, *abstract_args):
    nrep = 1

    if local_in_parts is None:
        local_in_parts = in_parts

    global_abstract_args = [
        pxla.get_global_aval(arg, parts,
                             lparts) for arg, parts, lparts in safe_zip(
                                 abstract_args, in_parts, local_in_parts)
    ]

    if logging.vlog_is_on(2):
        logging.vlog(2, "abstract_args: %s", abstract_args)
        logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
        logging.vlog(2, "in_parts: %s", in_parts)
        logging.vlog(2, "local_in_parts: %s", local_in_parts)

    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
        fun, global_abstract_args)

    if xb.get_backend().platform not in ["tpu", "gpu"]:
        # TODO(skye): fall back to regular jit?
        raise ValueError("sharded_jit not supported for " +
                         xb.get_backend().platform)

    nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
    assert nparts is not None
    if nparts > xb.device_count():
        raise ValueError(
            f"sharded_jit computation requires {nparts} devices, "
            f"but only {xb.device_count()} devices are available.")
    if xb.local_device_count() < nparts < xb.device_count():
        raise NotImplementedError(
            f"sharded_jit across multiple hosts must use all available devices. "
            f"Got {nparts} out of {xb.device_count()} requested devices "
            f"(local device count: {xb.local_device_count()})")

    if local_nparts is None:
        if nparts > xb.local_device_count():
            raise ValueError(
                "Specify 'local_nparts' when using cross-process sharded_jit "
                "and all inputs and outputs are replicated.")
        else:
            local_nparts = nparts
    if local_nparts > xb.local_device_count():
        raise ValueError(
            f"sharded_jit computation requires {local_nparts} local devices, "
            f"but only {xb.local_device_count()} local devices are available.")

    if logging.vlog_is_on(2):
        logging.vlog(2, "nparts: %d  local_nparts: %d", nparts, local_nparts)

    out_parts = out_parts_thunk()

    local_out_parts = local_out_parts_thunk()
    if local_out_parts is None:
        local_out_parts = out_parts

    if logging.vlog_is_on(2):
        logging.vlog(2, "out_parts: %s", out_parts)
        logging.vlog(2, "local_out_parts: %s", local_out_parts)

    local_out_avals = [
        pxla.get_local_aval(out, parts,
                            lparts) for out, parts, lparts in safe_zip(
                                global_out_avals, out_parts, local_out_parts)
    ]

    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    logging.log(
        log_priority, f"Compiling {fun.__name__} for {nparts} devices with "
        f"args {global_abstract_args}.")

    c = xb.make_computation_builder("spjit_{}".format(fun.__name__))
    xla_consts = _map(partial(xb.constant, c), consts)
    xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
    axis_env = xla.AxisEnv(nrep, (), ())
    out_nodes = xla.jaxpr_subcomp(
        c, jaxpr, None, axis_env, xla_consts,
        extend_name_stack(wrap_name(name, "sharded_jit")), *xla_args)
    out_tuple = xb.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
    built = c.Build(out_tuple)

    if nparts <= xb.local_device_count():
        devices = xb.local_devices()[:nparts]
    else:
        assert nparts == xb.device_count()
        devices = xb.devices()
    device_assignment = np.array([[d.id for d in devices]])
    device_assignment = np.reshape(device_assignment, (-1, nparts))
    # device_assignment = None  # TODO(skye): replace with default device assignment?

    compiled = xla.backend_compile(
        xb.get_backend(), built,
        xb.get_compile_options(nrep, nparts, device_assignment))

    input_specs = [
        pxla.partitioned_sharding_spec(local_nparts, parts, aval)
        for parts, aval in zip(local_in_parts, abstract_args)
    ]
    input_indices = [
        pxla.spec_to_indices(aval.shape, spec) if spec is not None else None
        for aval, spec in zip(abstract_args, input_specs)
    ]

    handle_args = partial(pxla.shard_args, compiled.local_devices(),
                          input_indices)
    handle_outs = _avals_to_results_handler(
        nrep,
        local_nparts,  # type: ignore
        local_out_parts,
        local_out_avals)
    return partial(_execute_spatially_partitioned, compiled, handle_args,
                   handle_outs)
Example 17
def _test_do_logging():
    """Do some log operations."""
    logging.vlog(3, 'This line is VLOG level 3')
    logging.vlog(2, 'This line is VLOG level 2')
    logging.log(2, 'This line is log level 2')
    if logging.vlog_is_on(2):
        logging.log(1, 'VLOG level 1, but only if VLOG level 2 is active')

    logging.vlog(1, 'This line is VLOG level 1')
    logging.log(1, 'This line is log level 1')
    logging.debug('This line is DEBUG')

    logging.vlog(0, 'This line is VLOG level 0')
    logging.log(0, 'This line is log level 0')
    logging.info('Interesting Stuff\0')
    logging.info('Interesting Stuff with Arguments: %d', 42)
    logging.info('%(a)s Stuff with %(b)s', {
        'a': 'Interesting',
        'b': 'Dictionary'
    })

    with mock.patch.object(timeit, 'default_timer') as mock_timer:
        mock_timer.return_value = 0
        while timeit.default_timer() < 9:
            logging.log_every_n_seconds(logging.INFO,
                                        'This should appear 5 times.', 2)
            mock_timer.return_value = mock_timer() + .2

    for i in range(1, 5):
        logging.log_first_n(logging.INFO, 'Info first %d of %d', 2, i, 2)
        logging.log_every_n(logging.INFO, 'Info %d (every %d)', 3, i, 3)

    logging.vlog(-1, 'This line is VLOG level -1')
    logging.log(-1, 'This line is log level -1')
    logging.warning('Worrying Stuff')
    for i in range(1, 5):
        logging.log_first_n(logging.WARNING, 'Warn first %d of %d', 2, i, 2)
        logging.log_every_n(logging.WARNING, 'Warn %d (every %d)', 3, i, 3)

    logging.vlog(-2, 'This line is VLOG level -2')
    logging.log(-2, 'This line is log level -2')
    try:
        raise OSError('Fake Error')
    except OSError:
        saved_exc_info = sys.exc_info()
        logging.exception('An Exception %s')
        logging.exception('Once more, %(reason)s', {'reason': 'just because'})
        logging.error('Exception 2 %s', exc_info=True)
        logging.error('Non-exception', exc_info=False)

    try:
        sys.exc_clear()
    except AttributeError:
        # No sys.exc_clear() in Python 3, but this will clear sys.exc_info() too.
        pass

    logging.error('Exception %s', '3', exc_info=saved_exc_info)
    logging.error('No traceback', exc_info=saved_exc_info[:2] + (None, ))

    logging.error('Alarming Stuff')
    for i in range(1, 5):
        logging.log_first_n(logging.ERROR, 'Error first %d of %d', 2, i, 2)
        logging.log_every_n(logging.ERROR, 'Error %d (every %d)', 3, i, 3)
    logging.flush()
Example 18
    def evaluate(
        self, seed: int, filesystem: FsWrapperBase = FsPathlibWrapper()
    ) -> pd.DataFrame:
        """Executes a trial.

        1. Check if the results for the trial have already been computed.
        2. Load the DataSet.
        3. Instantiate Halo Simulator.
        4. Instantiate Modeling Strategy.
        5. Fit model.
        6. Generate set of test points.
        7. Compute metrics.
        8. Construct output DataFrame.
        9. Save to disk.

        Args:
          seed:  A seed value that is used to initialize the random
            number generator.
          filesystem:  The filesystem object that manages all file operations.

        Returns:
          A single row DataFrame containing the results of the evaluation
          of this trial.
        """
        logging.vlog(2, f"Dataset {self._data_set_name}")
        logging.vlog(2, f"Trial   {self._trial_descriptor}")

        rng = np.random.default_rng(seed=seed)
        np.random.seed(seed)

        trial_results_path = self._compute_trial_results_path()

        if trial_results_path.startswith("gs://"):
            filesystem.set_default_client_to_gs_client()

        if filesystem.is_file(trial_results_path):
            logging.vlog(2, "  --> Returning previously computed result")
            try:
                with filesystem.open(trial_results_path) as file:
                    return pd.read_csv(file)
            except Exception as e:
                filesystem.unlink(trial_results_path)
                logging.vlog(
                    2, f"  --> {e}. Failed reading existing result. Re-evaluate."
                )

        # The pending directory contains one entry for each currently executing
        # experimental trial.  If a computation appears to hang, this can be
        # used to check which evaluations are still pending.
        experiment_dir_parent = filesystem.parent(self._experiment_dir)
        pending_path = f"{experiment_dir_parent}/pending/{hashlib.md5(trial_results_path.encode()).hexdigest()}"
        filesystem.mkdir(filesystem.parent(pending_path), parents=True, exist_ok=True)
        filesystem.write_text(
            pending_path,
            f"{datetime.now()}\n{self._data_set_name}\n{self._trial_descriptor}\n\n",
        )

        dataset = self._data_design.by_name(self._data_set_name)
        privacy_tracker = PrivacyTracker()
        halo = HaloSimulator(
            dataset, self._trial_descriptor.system_params, privacy_tracker
        )
        privacy_budget = self._trial_descriptor.experiment_params.privacy_budget
        modeling_strategy = (
            self._trial_descriptor.modeling_strategy.instantiate_strategy()
        )
        single_publisher_dataframe = pd.DataFrame()
        max_frequency = self._trial_descriptor.experiment_params.max_frequency
        try:
            reach_surface = modeling_strategy.fit(
                halo, self._trial_descriptor.system_params, privacy_budget
            )
            test_points = list(
                self._trial_descriptor.experiment_params.generate_test_points(
                    dataset, rng
                )
            )
            true_reach = [
                halo.true_reach_by_spend(
                    t, self._trial_descriptor.experiment_params.max_frequency
                )
                for t in test_points
            ]
            fitted_reach = [
                reach_surface.by_spend(
                    t, self._trial_descriptor.experiment_params.max_frequency
                )
                for t in test_points
            ]
            metrics = aggregate(true_reach, fitted_reach)
            if self._analysis_type == SINGLE_PUB_ANALYSIS:
                single_publisher_dataframe = (
                    self._compute_single_publisher_fractions_dataframe(
                        halo, reach_surface, max_frequency
                    )
                )
        except Exception as inst:
            if not logging.vlog_is_on(2):
                logging.vlog(1, f"Dataset {self._data_set_name}")
                logging.vlog(1, f"Trial   {self._trial_descriptor}")
            logging.vlog(1, f"Modeling failure: {inst}")
            logging.vlog(2, traceback.format_exc())
            metrics = aggregate_on_exception(inst)
            if self._analysis_type == SINGLE_PUB_ANALYSIS:
                single_publisher_dataframe = (
                    self._single_publisher_fractions_dataframe_on_exception(max_frequency)
                )

        independent_vars = self._make_independent_vars_dataframe()
        privacy_tracking_vars = self._make_privacy_tracking_vars_dataframe(
            privacy_tracker
        )
        result = pd.concat(
            [
                independent_vars,
                privacy_tracking_vars,
                metrics,
                single_publisher_dataframe,
            ],
            axis=1,
        )
        filesystem.mkdir(
            filesystem.parent(trial_results_path), parents=True, exist_ok=True
        )
        filesystem.write_text(trial_results_path, result.to_csv(index=False))
        filesystem.unlink(pending_path, missing_ok=True)

        return result
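One pattern worth noting in the except branch above: the dataset and trial context lines are normally emitted at vlog level 2, and are re-emitted at level 1 on failure, so a run at --verbosity=1 still sees which trial failed. A minimal sketch of that pattern:

from absl import logging

def run_trial(name, fn):
    logging.vlog(2, 'Trial %s', name)
    try:
        return fn()
    except Exception as e:  # pylint: disable=broad-except
        if not logging.vlog_is_on(2):
            logging.vlog(1, 'Trial %s', name)  # repeat context for the failure log
        logging.vlog(1, 'Modeling failure: %s', e)
        return None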