Example #1
    def _metric_fn(*args, **kwargs):
      """The wrapping function to be returned."""

      # We can only be passed either a dict or a list of tensors.
      args = args if args else kwargs
      metrics = call_eval_metrics((metric_fn, args))
      if not self._use_tpu:
        return metrics

      logging.log_first_n(logging.INFO,
                          "Writing eval metrics to variables for TPU", 1)
      wrapped_metrics = {}
      for i, key in enumerate(sorted(metrics)):
        tensor, op = tf_compat.metric_op(metrics[key])
        # key cannot be in var name since it may contain illegal chars.
        var = tf_compat.v1.get_variable(
            "metric_{}".format(i),
            shape=tensor.shape,
            dtype=tensor.dtype,
            trainable=False,
            initializer=tf_compat.v1.zeros_initializer(),
            collections=[tf_compat.v1.GraphKeys.LOCAL_VARIABLES])
        if isinstance(op, tf.Operation) or op.shape != tensor.shape:
          with tf.control_dependencies([op]):
            op = var.assign(tensor)
        metric = (var, var.assign(op))
        wrapped_metrics[key] = metric
      return wrapped_metrics
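All of the snippets on this page appear to use absl-py's rate-limited logging helpers. A minimal, self-contained sketch of how those helpers behave (the loop and messages below are made up for illustration):

from absl import logging

logging.set_verbosity(logging.INFO)
for step in range(100):
    # Emitted only on the first call, no matter how often the loop runs.
    logging.log_first_n(logging.INFO, "Training started at step %d", 1, step)
    # Emitted once every 10 calls.
    logging.log_every_n(logging.INFO, "Processed step %d", 10, step)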
Example #2
    def forward(self, numerical_input, categorical_inputs):
        """

        Args:
            numerical_input (Tensor): with shape [batch_size, num_numerical_features]
            categorical_inputs (Tensor): with shape [num_categorical_features, batch_size]

        Returns:
            Tensor: Concatenated bottom MLP and embedding output in shape [batch, 1 + #embedding, embedding_dim]
        """
        batch_size = numerical_input.size()[0]

        bottom_output = []

        # Reshape bottom mlp to concatenate with embeddings
        if self.bottom_mlp:
            bottom_output.append(self.bottom_mlp(numerical_input).view(batch_size, 1, -1))

        if self._hash_indices:
            for cat, size in enumerate(self._categorical_feature_sizes):
                categorical_inputs[:, cat] %= size
                logging.log_first_n(
                    logging.WARNING, "Hashed indices out of range.", 1)

        # NOTE: It doesn't transpose input
        if self.num_categorical_features > 0:
            bottom_output.append(self.joint_embedding(categorical_inputs).to(numerical_input.dtype))

        if len(bottom_output) == 1:
            cat_bottom_out = bottom_output[0]
        else:
            cat_bottom_out = torch.cat(bottom_output, dim=1)
        return cat_bottom_out
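A toy illustration of the `_hash_indices` guard above, with made-up feature sizes: out-of-range categorical ids are folded back into the embedding table by a modulo.

import torch

categorical_inputs = torch.tensor([[3, 41], [7, 12]])   # [batch_size, num_categorical_features]
categorical_feature_sizes = [10, 20]
for cat, size in enumerate(categorical_feature_sizes):
    categorical_inputs[:, cat] %= size
print(categorical_inputs)  # tensor([[ 3,  1], [ 7, 12]])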
Example #3
def tfhub_cache_dir(default_cache_dir=None, use_temp=False):
    """Returns cache directory.

  Returns cache directory from either TFHUB_CACHE_DIR environment variable
  or --tfhub_cache_dir or default, if set.

  Args:
    default_cache_dir: Default cache location to use if neither TFHUB_CACHE_DIR
                       environment variable nor --tfhub_cache_dir are
                       not specified.
    use_temp: bool, Optional to enable using system's temp directory as a
              module cache directory if neither default_cache_dir nor
              --tfhub_cache_dir nor TFHUB_CACHE_DIR environment variable are
              specified .
  """

    # Note: We are using FLAGS["tfhub_cache_dir"] (and not FLAGS.tfhub_cache_dir)
    # to access the flag value in order to avoid parsing argv list. The flags
    # should have been parsed by now in main() by tf.app.run(). If that was not
    # the case (say in Colab env) we skip flag parsing because argv may contain
    # unknown flags.
    cache_dir = (get_env_setting(_TFHUB_CACHE_DIR, "tfhub_cache_dir")
                 or default_cache_dir)
    if not cache_dir and use_temp:
        # Place all TF-Hub modules under <system's temp>/tfhub_modules.
        cache_dir = os.path.join(tempfile.gettempdir(), "tfhub_modules")
    if cache_dir:
        logging.log_first_n(logging.INFO, "Using %s to cache modules.", 1,
                            cache_dir)
    return cache_dir
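A minimal stand-alone sketch of the lookup order implemented above (environment variable, then the caller-supplied default, then the system temp directory); it omits the --tfhub_cache_dir flag handling, and the function name is made up.

import os
import tempfile

def resolve_cache_dir(default_cache_dir=None, use_temp=False):
    # The environment variable wins, then the explicit default.
    cache_dir = os.environ.get("TFHUB_CACHE_DIR") or default_cache_dir
    if not cache_dir and use_temp:
        # Fall back to <system temp>/tfhub_modules.
        cache_dir = os.path.join(tempfile.gettempdir(), "tfhub_modules")
    return cache_dir

print(resolve_cache_dir(use_temp=True))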
Example #4
 def __init__(self, entity_info_file, name, get_field, norm=None):
     logging.info('building entity kb...')
     with open(entity_info_file, 'rb') as f:
         [self.entity_ids, self.entity_names] = pickle.load(f)
     self.emap = dict()
     self.missing_entities = ['army', 'navy']
     if not os.path.exists(entity_info_file + '.cache.pkl'):
         for idx in range(len(self.entity_ids)):
             logging.log_first_n(logging.INFO, 'entity kb: %s -> %s', 10, self.entity_names[idx], idx)
             logging.log_every_n_seconds(logging.INFO, 'entity kb: %s of %s', 10, idx, len(self.entity_ids))
             self.emap[self.entity_names[idx].lower()] = idx
             normalized = normalize_name(self.entity_names[idx])
             splt = split(normalized)
             cleaned = clean(splt)
             nostop = remove_stopwords(cleaned)
             if normalized not in self.emap:
                 self.emap[normalized] = idx
             if splt not in self.emap:
                 self.emap[splt] = idx
             if cleaned not in self.emap:
                 self.emap[cleaned] = idx
             if nostop not in self.emap:
                 self.emap[nostop] = idx
         for me in self.missing_entities:
             self.emap[me] = len(self.emap)
         with open(entity_info_file + '.cache.pkl', 'wb') as fout:
             pickle.dump(self.emap, fout)
     else:
         with open(entity_info_file + '.cache.pkl', 'rb') as fin:
             self.emap = pickle.load(fin)
     self.name = name
     self.get_field = get_field
     logging.info('building entity kb...done')
Example #5
  def _compute_gradient(self, loss, dense_features, gradient_tape=None):
    """Computes the gradient given a loss and dense features."""
    feature_values = list(dense_features.values())
    if gradient_tape is None:
      grads = tf.gradients(loss, feature_values)
    else:
      grads = gradient_tape.gradient(loss, feature_values)

    # The orders of elements returned by .values() and .keys() are guaranteed
    # to correspond to each other.
    keyed_grads = dict(zip(dense_features.keys(), grads))

    invalid_grads, valid_grads = self._split_dict(keyed_grads,
                                                  lambda grad: grad is None)
    # Two cases in which a grad can be invalid (None):
    # (1) The feature is not differentiable, like strings or integers.
    # (2) The feature is not involved in loss computation.
    if invalid_grads:
      if self._raise_invalid_gradient:
        raise ValueError('Cannot perturb features ' + str(invalid_grads.keys()))
      logging.log_first_n(logging.WARNING, 'Cannot perturb features %s', 1,
                          invalid_grads.keys())

    # Guards against numerical errors. If the gradient is malformed (inf, -inf,
    # or NaN) on a dimension, replace it with 0, which has the effect of not
    # perturbing the original sample along that particular dimension.
    return tf.nest.map_structure(
        lambda g: tf.where(tf.math.is_finite(g), g, tf.zeros_like(g)),
        valid_grads)
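`_split_dict` is not shown in this snippet; a plausible helper matching its usage above (hypothetical, not the library's actual implementation) could look like:

def _split_dict(d, predicate):
    """Splits d into (items where predicate(value) is True, the rest)."""
    true_part = {k: v for k, v in d.items() if predicate(v)}
    false_part = {k: v for k, v in d.items() if not predicate(v)}
    return true_part, false_part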
Example #6
    def collect(self, x):
        """Tracks the absolute max of all tensors

        Args:
            x: A tensor

        Raises:
            RuntimeError: If amax shape changes
        """
        if torch.min(x) < 0.:
            logging.log_first_n(logging.INFO, (
                "Calibrator encountered negative values. It shouldn't happen after ReLU. "
                "Make sure this is the right tensor to calibrate."), 1)
            x = x.abs()

        # Swap axis to reduce.
        axis = self._axis if isinstance(self._axis,
                                        (list, tuple)) else [self._axis]
        reduce_axis = []
        for i in range(x.dim()):
            if i not in axis:
                reduce_axis.append(i)
        local_amax = quant_utils.reduce_amax(x, axis=reduce_axis).detach()
        if self._calib_amax is None:
            self._calib_amax = local_amax
        else:
            if local_amax.shape != self._calib_amax.shape:
                raise RuntimeError("amax shape changed!")
            self._calib_amax.copy_(
                torch.max(self._calib_amax, local_amax).data)

        if self._track_amax:
            self._amaxs.append(local_amax.cpu().numpy())
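`quant_utils.reduce_amax` comes from the surrounding library; an illustrative equivalent of the reduction used above (a sketch, not the library code) is an absolute max over the given axes:

import torch

def reduce_amax(x, axis, keepdims=True):
    # Absolute max over the axes in `axis`, keeping the remaining axes.
    return torch.amax(x.abs(), dim=tuple(axis), keepdim=keepdims)

x = torch.randn(4, 8, 16)
print(reduce_amax(x, axis=[0, 2]).shape)  # torch.Size([1, 8, 1])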
Example #7
def gym_env_wrapper(env, rl_env_max_episode_steps, maxskip_env, rendered_env,
                    rendered_env_resize_to, sticky_actions, output_dtype,
                    num_actions):
    """Wraps a gym environment. see make_gym_env for details."""
    # rl_env_max_episode_steps is None or int.
    assert ((not rl_env_max_episode_steps)
            or isinstance(rl_env_max_episode_steps, int))

    wrap_with_time_limit = ((not rl_env_max_episode_steps)
                            or rl_env_max_episode_steps >= 0)

    if wrap_with_time_limit:
        env = remove_time_limit_wrapper(env)

    if num_actions is not None:
        logging.log_first_n(logging.INFO, "Number of discretized actions: %d",
                            1, num_actions)
        env = ActionDiscretizeWrapper(env, num_actions=num_actions)

    if sticky_actions:
        env = StickyActionEnv(env)

    if maxskip_env:
        env = MaxAndSkipEnv(env)  # pylint: disable=redefined-variable-type

    if rendered_env:
        env = RenderedEnv(env,
                          resize_to=rendered_env_resize_to,
                          output_dtype=output_dtype)

    if wrap_with_time_limit and rl_env_max_episode_steps is not None:
        env = gym.wrappers.TimeLimit(
            env, max_episode_steps=rl_env_max_episode_steps)
    return env
Example #8
    def forward(self, numerical_input, categorical_inputs):
        """

        Args:
            numerical_input (Tensor): with shape [batch_size, num_numerical_features]
            categorical_inputs (Tensor): with shape [num_categorical_features, batch_size]
        """
        batch_size = numerical_input.size()[0]
        # TODO(haow): Maybe check batch size of sparse input

        # Put indices on the same device as corresponding embedding
        device_indices = []
        for embedding_id, embedding in enumerate(self.embeddings):
            device_indices.append(categorical_inputs[embedding_id].to(self._embedding_device_map[embedding_id]))

        bottom_mlp_output = self.bottom_mlp(numerical_input)

        # embedding_outputs will be a list of (26 in the case of Criteo) fetched embeddings with shape
        # [batch_size, embedding_size]
        embedding_outputs = []
        for embedding_id, embedding in enumerate(self.embeddings):
            if self._hash_indices:
                device_indices[embedding_id] %= embedding.num_embeddings
                logging.log_first_n(
                    logging.WARNING, "Hashed indices out of range.", 1)
            embedding_outputs.append(embedding(device_indices[embedding_id]).to(self._base_device))

        interaction_output = self._interaction(bottom_mlp_output, embedding_outputs, batch_size)

        top_mlp_output = self.top_mlp(interaction_output)

        return top_mlp_output
Example #9
  def __init__(self, scope=None, skip_summary=False, namespace=None):
    """Initializes a `_ScopedSummary`.

    Args:
      scope: String scope name.
      skip_summary: Whether to skip recording summary ops.
      namespace: Optional string namespace for the summary.

    Returns:
      A `_ScopedSummary` instance.
    """

    if tf_compat.tpu_function.get_tpu_context().number_of_shards:
      logging.log_first_n(
          logging.WARN,
          "Scoped summaries will be skipped since they do not support TPU", 1)
      skip_summary = True

    self._scope = scope
    self._namespace = namespace
    self._additional_scope = None
    self._skip_summary = skip_summary
    self._summary_ops = []
    self._actual_summary_scalar_fn = summary_lib.scalar
    self._actual_summary_image_fn = summary_lib.image
    self._actual_summary_histogram_fn = summary_lib.histogram
    self._actual_summary_audio_fn = summary_lib.audio
Example #10
    def load_calib_amax(self, *args, **kwargs):
        """Load amax from calibrator.

        Updates the amax buffer with value computed by the calibrator, creating it if necessary.
        *args and **kwargs are directly passed to compute_amax, except "strict" in kwargs. Refer to
        compute_amax for more details.
        """
        strict = kwargs.pop("strict", True)
        if getattr(self, '_calibrator', None) is None:
            raise RuntimeError("Calibrator not created.")
        calib_amax = self._calibrator.compute_amax(*args, **kwargs)
        if calib_amax is None:
            err_msg = "Calibrator returned None."
            if not strict:
                logging.warning(err_msg)
                logging.warning("Set amax to NaN!")
                calib_amax = torch.tensor(math.nan)
            else:
                raise RuntimeError(err_msg)
        logging.warning("Load calibrated amax, shape={}.".format(
            calib_amax.shape))
        logging.log_first_n(
            logging.WARNING,
            "Call .cuda() if running on GPU after loading calibrated amax.", 1)
        if not hasattr(self, '_amax'):
            self.register_buffer('_amax', calib_amax.data)
        else:
            self._amax.copy_(calib_amax)
Example #11
    def _fb_fake_quant(self, inputs, amax):
        """Native pytorch fake quantization."""
        logging.log_first_n(
            logging.WARNING,
            "Use Pytorch's native experimental fake quantization.", 1)
        bound = (1 << (self._num_bits - 1 + int(self._unsigned))) - 1
        # To be consistent with ONNX, full range is used. e.g. range is [-128, 127] in int8
        if amax.numel() == 1:
            outputs = torch.fake_quantize_per_tensor_affine(
                inputs,
                amax.item() / bound, 0,
                -bound - 1 if not self._unsigned else 0, bound)
        else:
            amax_squeeze = amax.squeeze().detach()
            if len(amax_squeeze.shape) != 1:
                raise TypeError(
                    "Pytorch's native quantization doesn't support multiple axes"
                )
            quant_dim = list(amax.shape).index(list(amax_squeeze.shape)[0])
            scale = amax_squeeze / bound
            outputs = torch.fake_quantize_per_channel_affine(
                inputs, scale.data,
                torch.zeros_like(scale, dtype=torch.long).data, quant_dim,
                -bound - 1 if not self._unsigned else 0, bound)

        return outputs
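A small worked example of the per-tensor branch above for signed int8 (num_bits=8, unsigned=False): bound = 2**7 - 1 = 127, and the full ONNX-style range [-128, 127] is used. The amax value below is made up.

import torch

amax = torch.tensor(6.0)
bound = (1 << (8 - 1)) - 1            # 127
x = torch.linspace(-8.0, 8.0, 5)
q = torch.fake_quantize_per_tensor_affine(
    x, amax.item() / bound, 0, -bound - 1, bound)
print(q)                              # values snapped to multiples of 6/127, clamped near +-6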
Example #12
 def evaluate_with_warning(*args, **kwargs):
     evaluate_out = f(*args, **kwargs)
     if evaluate_out is None:
         logging.log_first_n(logging.WARNING, none_return_is_deprecated_msg,
                             1)
         return {}
     return evaluate_out
Example #13
def margin_loss(positive_scores, negative_scores, margin):
    logging.log_first_n(logging.INFO, '[margin_loss] positive %s | negative %s | margin %s', 10,
                        str(positive_scores.shape), str(negative_scores.shape), margin)
    pos_minus_neg = positive_scores - negative_scores - margin
    labels = torch.ones_like(pos_minus_neg)
    res = F.binary_cross_entropy_with_logits(pos_minus_neg, labels, reduction='mean')
    return res
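A quick usage sketch with made-up scores, assuming the margin_loss above plus the usual torch, torch.nn.functional as F, and absl logging imports:

import torch

positive_scores = torch.tensor([2.0, 1.5, 0.3])
negative_scores = torch.tensor([0.5, 1.0, 0.8])
# The loss shrinks as positives exceed negatives by more than the margin.
print(margin_loss(positive_scores, negative_scores, margin=0.5))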
Example #14
    def collect(self, x):
        """Collect histogram"""
        if torch.min(x) < 0.:
            logging.log_first_n(logging.INFO, (
                "Calibrator encountered negative values. It shouldn't happen after ReLU. "
                "Make sure this is the right tensor to calibrate."), 1)
            x = x.abs()
        x_np = x.cpu().detach().numpy()

        if self._skip_zeros:
            x_np = x_np[np.where(x_np != 0)]

        if self._calib_bin_edges is None and self._calib_hist is None:
            # The first time, use num_bins to compute the histogram.
            self._calib_hist, self._calib_bin_edges = np.histogram(
                x_np, bins=self._num_bins)
        else:
            temp_amax = np.max(x_np)
            if temp_amax > self._calib_bin_edges[-1]:
                # increase the number of bins
                width = self._calib_bin_edges[1] - self._calib_bin_edges[0]
                # NOTE: np.arange may create an extra bin after the one containing temp_amax
                new_bin_edges = np.arange(self._calib_bin_edges[-1] + width,
                                          temp_amax + width, width)
                self._calib_bin_edges = np.hstack(
                    (self._calib_bin_edges, new_bin_edges))
            hist, self._calib_bin_edges = np.histogram(
                x_np, bins=self._calib_bin_edges)
            hist[:len(self._calib_hist)] += self._calib_hist
            self._calib_hist = hist
Example #15
    def forward(self, numerical_input, categorical_inputs):
        """
        Args:
            numerical_input (Tensor): with shape [batch_size, num_numerical_features]
            categorical_inputs (Tensor): with shape [num_categorical_features, batch_size]
        """
        batch_size = numerical_input.size()[0]
        bottom_mlp_output = self.bottom_mlp(numerical_input)

        # Change indices based on hash_shift.
        # It would be more efficient to do this on the data loader side, but it is
        # handled here to keep the interface consistent with the base Dlrm model.
        if self._hash_indices:
            for cat, size in enumerate(self._categorical_feature_sizes):
                categorical_inputs[cat] %= size
                logging.log_first_n(
                    logging.WARNING, "Hashed indices out of range.", 1)

        # self._interaction takes a list of tensors as input, so wrap this in a single-element list.
        # categorical_inputs is transposed here only to keep the interface consistent with the base
        # model, which makes it easy to test. Will change this to the best performing version.
        # TODO(haow): Remove transpose.
        embedding_outputs = [self.embeddings[0](categorical_inputs.t()).view(batch_size, -1)]

        interaction_output = self._interaction(bottom_mlp_output, embedding_outputs, batch_size)

        top_mlp_output = self.top_mlp(interaction_output)

        return top_mlp_output
Example #16
def restore_tf2_ckpt(model,
                     ckpt_path_or_file,
                     skip_mismatch=True,
                     exclude_layers=None):
  """Restore variables from a given checkpoint.

  Args:
    model: the keras model to be restored.
    ckpt_path_or_file: the path or file for checkpoint.
    skip_mismatch: whether to skip variables whose shapes mismatch;
      only works with a tf1 checkpoint.
    exclude_layers: list of layer names whose variables should be excluded;
      only works with a tf2 checkpoint.

  Raises:
    KeyError: if unexpected variables are accessed.
  """
  ckpt_file = ckpt_path_or_file
  if tf.io.gfile.isdir(ckpt_file):
    ckpt_file = tf.train.latest_checkpoint(ckpt_file)

  # Try to load object-based checkpoint (by model.save_weights).
  var_list = tf.train.list_variables(ckpt_file)
  if var_list[0][0] == '_CHECKPOINTABLE_OBJECT_GRAPH':
    print(f'Load checkpointable from {ckpt_file}, excluding {exclude_layers}')
    keys = {var[0].split('/')[0] for var in var_list}
    keys.discard('_CHECKPOINTABLE_OBJECT_GRAPH')
    if exclude_layers:
      exclude_layers = set(exclude_layers)
      keys = keys.difference(exclude_layers)
    ckpt = tf.train.Checkpoint(**{key: getattr(model, key, None)
                                  for key in keys
                                  if getattr(model, key, None)})
    status = ckpt.restore(ckpt_file)
    status.assert_nontrivial_match()
    return

  print(f'Load TF1 graph based checkpoint from {ckpt_file}.')
  var_dict = {v.name.split(':')[0]: v for v in model.weights}
  reader = tf.train.load_checkpoint(ckpt_file)
  var_shape_map = reader.get_variable_to_shape_map()
  for key, var in var_dict.items():
    if key in var_shape_map:
      if var_shape_map[key] != var.shape:
        msg = 'Shape mismatch: %s' % key
        if skip_mismatch:
          logging.warning(msg)
        else:
          raise ValueError(msg)
      else:
        var.assign(reader.get_tensor(key), read_value=False)
        logging.log_first_n(logging.INFO,
                            f'Init {var.name} from {key} ({ckpt_file})', 10)
    else:
      msg = 'Not found %s in %s' % (key, ckpt_file)
      if skip_mismatch:
        logging.warning(msg)
      else:
        raise KeyError(msg)
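A minimal sketch of the "is this an object-based checkpoint?" probe used above, factored out on its own (the function name is made up):

import tensorflow as tf

def is_object_based_checkpoint(ckpt_path):
    # Object-based (model.save_weights) checkpoints store this sentinel key.
    var_list = tf.train.list_variables(ckpt_path)
    return var_list[0][0] == '_CHECKPOINTABLE_OBJECT_GRAPH'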
Example #17
    def forward(self, categorical_inputs) -> List[torch.Tensor]:
        if self.hash_indices:
            for cat, size in enumerate(self._categorical_feature_sizes):
                categorical_inputs[:, cat] %= size
                logging.log_first_n(logging.WARNING,
                                    f"Hashed indices out of range.", 1)

        return [self.embedding(categorical_inputs)]
Example #18
 def write_summaries(self,
                     step: int,
                     values: Mapping[str, Array],
                     metadata: Optional[Mapping[str, Any]] = None):
     logging.log_first_n(
         logging.WARNING,
         "TorchTensorboardWriter does not support writing raw summaries.",
         1)
Example #19
    def forward(ctx, input):
        """Forward function.
        Args:
            input (torch.Tensor): Input tensor. First dimension should be the batch size
        Returns:
            torch.Tensor: [batch_size x number_of_logits] Output tensor
        """
        # Sparsemax currently only handles 2-dim tensors,
        # so we reshape to a convenient shape and reshape back after sparsemax
        selfdim = 1
        input = input.transpose(0, selfdim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1

        number_of_logits = input.size(dim)

        # Translate input by max for numerical stability
        input = input - torch.max(input, dim=dim,
                                  keepdim=True)[0].expand_as(input)

        # Sort input in descending order.
        # (NOTE: Can be replaced with linear time selection method described here:
        # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        range = torch.arange(start=1,
                             end=number_of_logits + 1,
                             step=1,
                             device=input.device,
                             dtype=input.dtype).view(1, -1)
        range = range.expand_as(zs)

        # Determine sparsity of projection
        bound = 1 + range * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * range, dim, keepdim=True)[0]

        # Compute threshold function
        zs_sparse = is_gt * zs

        # Compute taus
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        # Sparsemax
        selfoutput = torch.max(torch.zeros_like(input), input - taus)
        logging.log_first_n(logging.INFO, '[forward] selfoutput.shape %s', 10,
                            str(selfoutput.shape))
        ctx.save_for_backward(selfoutput)
        # Reshape back to original shape
        output = selfoutput
        output = output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, selfdim)

        return output
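For cross-checking the autograd Function above, a compact reference sparsemax over the last dimension (a sketch following Martins & Astudillo, 2016; not the code used by the model):

import torch

def sparsemax_ref(z, dim=-1):
    z_sorted, _ = torch.sort(z, dim=dim, descending=True)
    k = torch.arange(1, z.size(dim) + 1, device=z.device, dtype=z.dtype)
    shape = [1] * z.dim()
    shape[dim] = -1
    k = k.view(shape)
    cumsum = z_sorted.cumsum(dim)
    support = (1 + k * z_sorted) > cumsum          # which sorted entries stay nonzero
    k_z = (support.to(z.dtype) * k).max(dim=dim, keepdim=True).values
    tau = (cumsum.gather(dim, k_z.long() - 1) - 1) / k_z
    return torch.clamp(z - tau, min=0)

print(sparsemax_ref(torch.tensor([[2.0, -1.0, 0.5]])))   # tensor([[1., 0., 0.]])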
Example #20
def added_token_counts(data_iterator,
                       try_swapping,
                       tokenizer,
                       max_input_examples=10000,
                       max_recursion_depth=10000):
  """Computes how many times different phrases have to be added.

  Args:
    data_iterator: Iterator to yield source lists and targets. See function
      yield_sources_and_targets in utils.py for the available iterators. The
      strings in the source list will be concatenated, possibly after swapping
      their order if swapping is enabled.
    try_swapping: Whether to check if swapping the sources results in less added text.
    tokenizer: Text tokenizer (derived from tokenization.FullTokenizer).
    max_input_examples: Maximum number of examples to be read from the iterator.
    max_recursion_depth: Maximum recursion depth for LCS. If a long example
      surpasses this recursion depth, the given example is skipped and a warning
      is logged.

  Returns:
    Tuple (collections.Counter for phrases, added phrases for each example).
  """
  phrase_counter = collections.Counter()
  num_examples = 0
  all_added_phrases = []
  for sources, target in data_iterator:
    if num_examples >= max_input_examples:
      break
    logging.log_every_n(logging.INFO, f'{num_examples} examples processed.',
                        1000)
    source_tokens = [t.lower() for t in tokenizer.tokenize(' '.join(sources))]
    target_tokens = [t.lower() for t in tokenizer.tokenize(target)]
    with _recursion_limit(max_recursion_depth):
      try:
        added_phrases = _get_added_phrases(source_tokens, target_tokens)
        if try_swapping and len(sources) == 2:
          source_tokens_swap = [
              t.lower() for t in tokenizer.tokenize(' '.join(sources[::-1]))
          ]
          added_phrases_swap = _get_added_phrases(source_tokens_swap,
                                                  target_tokens)
          # If we can align more and have to add less after swapping, we assume
          # that the sources would be swapped during conversion.
          if len(''.join(added_phrases_swap)) < len(''.join(added_phrases)):
            added_phrases = added_phrases_swap
      except RecursionError:
        logging.log_first_n(
            logging.WARNING, 'Skipping a too long source. Consider increasing '
            '`max_recursion_depth` argument of the `added_token_counts` '
            'function in phrase_vocabulary_optimization_utils.py to keep this '
            f'source: {" ".join(source_tokens)}', 100)
        continue
    for phrase in added_phrases:
      phrase_counter[phrase] += 1
    all_added_phrases.append(added_phrases)
    num_examples += 1
  logging.info('%d examples processed.\n', num_examples)
  return phrase_counter, all_added_phrases
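`_recursion_limit` is defined elsewhere in the module; a plausible context manager matching its usage above (hypothetical) would temporarily raise Python's recursion limit:

import contextlib
import sys

@contextlib.contextmanager
def _recursion_limit(limit):
    old_limit = sys.getrecursionlimit()
    sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        sys.setrecursionlimit(old_limit)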
Example #21
    def backward(ctx, grad_output):
        embedding, indices, offsets = ctx.saved_tensors

        logging.log_first_n(
            logging.WARNING,
            "Highly specialized embedding for embedding_dim 128", 1)
        grad_weights = fused_embedding.gather_gpu_fused_bwd(
            embedding, indices, offsets, grad_output)
        return grad_weights, None, None, None
Example #22
    def backward(self, grad_output):
        """Backward function."""
        dim = 1
        logging.log_first_n(logging.INFO, 'In sparsemax backward', 10)
        nonzeros = torch.ne(self.output, 0)
        sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros,
                                                                     dim=dim)
        self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))

        return self.grad_input
Example #23
 def get_patent_title(x):
     if x.record_id in patent_tile_map:
         logging.log_first_n(logging.INFO, 'Returning title for %s: %s',
                             10, x.record_id,
                             patent_tile_map[x.record_id])
         x.title = patent_tile_map[x.record_id]
         return patent_tile_map[x.record_id]
     else:
         x.title = ''
         logging.warning('Missing title for %s', x.record_id)
         return ''
Example #24
 def _learn(self) -> None:
   """Samples a batch of transitions from replay and learns from it."""
   logging.log_first_n(logging.INFO, 'Begin learning', 1)
   transitions = self._replay.sample(self._batch_size)
   self._rng_key, self._opt_state, self._online_params = self._update(
       self._rng_key,
       self._opt_state,
       self._online_params,
       self._target_params,
       transitions,
   )
Example #25
 def get_patent_coinventors(x):
     if x.record_id in coinventor_map:
         logging.log_first_n(logging.INFO,
                             'Returning coinventors for %s: %s', 10,
                             x.record_id, coinventor_map[x.record_id])
         x.coinventors = coinventor_map[x.record_id]
         return coinventor_map[x.record_id]
     else:
         x.coinventors = []
         logging.warning('Missing coinventors for %s', x.patent_id)
         return []
Example #26
 def _learn(self) -> None:
     """Samples a batch of transitions from replay and learns from it."""
     logging.log_first_n(logging.INFO, 'Begin learning', 1)
     transitions = self._replay.sample(self._batch_size)
     self._rng_key, self._opt_state, self._online_params, loss_values, shaped_rewards, penalties = self._update(
         self._rng_key,
         self._opt_state,
         self._online_params,
         self._target_params,
         transitions,
     )
     return loss_values.item(), shaped_rewards.tolist(), penalties.tolist()
Example #27
    def backward(ctx, grad_output):
        """Backward function."""
        dim = 1
        selfoutput, = ctx.saved_tensors
        logging.log_first_n(logging.INFO, '[backward] selfoutput.shape %s', 10,
                            str(selfoutput.shape))
        nonzeros = torch.ne(selfoutput, 0).float()
        sum = torch.sum(grad_output * nonzeros, dim=dim,
                        keepdim=True) / torch.sum(
                            nonzeros, dim=dim, keepdim=True)
        grad_input = nonzeros * (grad_output - sum)

        return grad_input
Example #28
 def get_patent_assignees(x):
     if x.record_id in assignees_map:
         logging.log_first_n(logging.INFO,
                             'Returning assignees for %s: %s', 10,
                             x.record_id, assignees_map[x.record_id])
         x.assignees = assignees_map[x.record_id]
         return assignees_map[x.record_id]
     else:
         x.assignees = []
         logging.log_first_n(logging.WARNING,
                             'Missing assignees for %s', 10,
                             x.record_id)
         return []
Example #29
  def _worker_task(self, num_subnetworks):
    """Returns the worker index modulo the number of subnetworks."""

    if self._drop_remainder and self._num_workers > 1 and (num_subnetworks >
                                                           self._num_workers):
      logging.log_first_n(
          logging.WARNING,
          "With drop_remainer=True, %s workers and %s subnetworks, the last %s "
          "subnetworks will be dropped and will not be trained", 1,
          self._num_workers, num_subnetworks,
          num_subnetworks - self._num_workers - 1)
    # The first worker will always build the ensemble so we add 1.
    return self._worker_index % (num_subnetworks + 1)
Example #30
 def _learn(self) -> None:
     """Samples a batch of transitions from replay and learns from it."""
     logging.log_first_n(logging.INFO, 'Begin learning', 1)
     transitions = self._replay.sample(self._batch_size)
     self._rng_key, self._opt_state, self._online_params, logs = self._update(
         self._rng_key,
         self._opt_state,
         self._online_params,
         self._target_params,
         transitions,
     )
     self._online_params = self._sync_tied_layers(self._online_params)
     self._statistics.update(jax.device_get(logs))