def Inference(self):
  """Builds a toy inference graph computing y = (w + r)^T x + b.

  `r` is a length-3 tensor drawn from `tf.random.stateless_uniform`, seeded
  with `py_utils.GenerateStepSeedPair(self.params, self.theta.global_step)`.

  Returns:
    A dict mapping the subgraph name 'default' to the pair
    (fetches, feeds) = ({'output': y}, {'input': x}).
  """
  with tf.variable_scope('inference'):
    # float32 feed named 'input'; no shape is given, so any shape is accepted.
    inputs = tf.placeholder(dtype=tf.float32, name='input')
    step_seed = py_utils.GenerateStepSeedPair(self.params,
                                              self.theta.global_step)
    noise = tf.random.stateless_uniform([3], seed=step_seed)
    noisy_w = self.vars.w + noise
    output = tf.reduce_sum(noisy_w * inputs) + self.vars.b
    fetches = {'output': output}
    feeds = {'input': inputs}
    return {'default': (fetches, feeds)}
Exemplo n.º 2
0
  def Inference(self):
    """Builds a toy inference graph computing y = (w + r)^T x + b.

    Before the graph is built, a string tensor holding 'dummy.txt' (named
    'asset_filepath') is added to the ASSET_FILEPATHS graph collection.
    `r` is a length-3 tensor from `tf.random.stateless_uniform`, seeded with
    `py_utils.GenerateStepSeedPair(self.params)`.

    Returns:
      A dict mapping 'default' to ({'output': y}, {'input': x}).
    """
    # Register a dummy asset file path in the graph's asset collection.
    asset_path = tf.convert_to_tensor(
        'dummy.txt', dtype=tf.dtypes.string, name='asset_filepath')
    asset_key = tf.compat.v1.GraphKeys.ASSET_FILEPATHS
    tf.compat.v1.add_to_collection(asset_key, asset_path)

    with tf.name_scope('inference'):
      inputs = tf.placeholder(dtype=tf.float32, name='input')
      step_seed = py_utils.GenerateStepSeedPair(self.params)
      noise = tf.random.stateless_uniform([3], seed=step_seed)
      noisy_w = self.vars.w + noise
      output = tf.reduce_sum(noisy_w * inputs) + self.vars.b
      return {'default': ({'output': output}, {'input': inputs})}
Exemplo n.º 3
0
  def _GetWeight(self, theta):
    """Returns the filter weights derived from `theta.w`.

    The steps are:
      1. Softmax-normalize the raw weights over axis 0 (the temporal
         dimension) after dividing by `p.temperature`.
      2. If `p.dropconnect_prob > 0` and `self.do_eval` is false, apply
         dropconnect. This uses `py_utils.DeterministicDropout` when
         `p.deterministic_dropout` is set, and `tf.nn.dropout` otherwise.
      3. Tile the result `p.weight_tiling_factor` times along axis 2.

    Args:
      theta: Object holding the layer weights; only `theta.w` is read.

    Returns:
      The normalized, optionally dropped-out, tiled filter tensor.
    """
    params = self.params
    raw_w = theta.w

    # Temperature-scaled softmax across the temporal (first) dimension.
    weights = tf.nn.softmax(raw_w / params.temperature, axis=0)

    # Dropconnect regularization, skipped when do_eval is set.
    use_dropconnect = params.dropconnect_prob > 0.0 and not self.do_eval
    if use_dropconnect and params.deterministic_dropout:
      keep_prob = 1.0 - params.dropconnect_prob
      seeds = py_utils.GenerateStepSeedPair(params)
      weights = py_utils.DeterministicDropout(weights, keep_prob, seeds)
    elif use_dropconnect:
      weights = tf.nn.dropout(
          weights, rate=params.dropconnect_prob, seed=params.random_seed)

    # Tie the parameters of every group of weight_tiling_factor channels.
    # The 4-entry multiples list means the weights must be rank 4.
    return tf.tile(weights, [1, 1, params.weight_tiling_factor, 1])