Example #1
0
 def _FlatOutputProcessor(inputs):
     """Returns a flattened list of 'processor(inputs)'.

     Calls `processor(inputs)`, which must return an `(output, bucketing_key)`
     pair. `output` must be either a non-empty list of `tf.Tensor`s or a
     non-empty `py_utils.NestedMap` whose flattened leaves are all
     `tf.Tensor`s. The output is stored in `output_tmpl.values` (a template
     from the enclosing scope) and the template is then flattened.

     Args:
       inputs: passed through unchanged to `processor`.

     Returns:
       `output_tmpl.Flatten()` with the bucketing key, cast to int32, appended
       as the last element.

     Raises:
       AssertionError: if `output` fails the checks above, or if
         `function.get_extra_args()` is non-empty after calling `processor`
         (reported as `processor` not being pure).
     """
     output, bucketing_key = processor(inputs)
     if isinstance(output, list):
         assert output
         assert all(isinstance(x, tf.Tensor)
                    for x in output), '{}'.format(output)
     else:
         assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
         assert output
         assert all(isinstance(x, tf.Tensor)
                    for x in output.Flatten()), '{}'.format(
                        output.DebugString())
     # tf.to_int32 is deprecated (removed in TF 2.x); tf.cast to tf.int32 is
     # its documented replacement and performs the same conversion.
     bucketing_key = tf.cast(bucketing_key, tf.int32)
     tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                      bucketing_key)
     output_tmpl.values = output
     flat_output_tmpl = output_tmpl.Flatten()
     tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
     tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                      function.get_extra_inputs(),
                      function.get_extra_args(), function.get_extra_vars())
     assert not function.get_extra_args(), (
         'fns {} is not pure: extra_args={}'.format(
             processor, function.get_extra_args()))
     return flat_output_tmpl + [bucketing_key]
Example #2
0
 def Grad(x, y0):
   """Computes the loss on (x, y0) and its gradients w.r.t. the extra args.

   The forward pass uses `Model` when `use_forward_func` is set and `_Model`
   otherwise. The loss is the mean over axis 0 of the sum over axis 1 of
   `y0 * log(y)`. As a side effect, `function.get_extra_vars()` is appended
   to `cvars`.

   Returns:
     A `(loss, dw, db)` tuple, where `dw` and `db` are the gradients of `loss`
     with respect to the two tensors returned by `function.get_extra_args()`.
   """
   forward = Model if use_forward_func else _Model
   y = forward(x)
   summed = tf.reduce_sum(y0 * tf.log(y), 1)
   loss = tf.reduce_mean(summed, 0)
   grads = tf.gradients(loss, function.get_extra_args())
   dw, db = grads
   cvars.extend(function.get_extra_vars())
   return loss, dw, db
Example #3
0
 def Grad(x, y0):
   """Computes the loss on (x, y0) and its gradients w.r.t. the extra args.

   The forward pass uses `Model` when `use_forward_func` is set and `_Model`
   otherwise. The loss is the mean over axis 0 of the sum over axis 1 of
   `y0 * log(y)`. The two extra args are checked to have shapes [64, 64]
   and [64], and `function.get_extra_vars()` is appended to `cvars`.

   Returns:
     A `(loss, dw, db)` tuple: the loss and its gradients with respect to the
     two extra args.
   """
   forward = Model if use_forward_func else _Model
   y = forward(x)
   summed = tf.reduce_sum(y0 * tf.log(y), 1)
   loss = tf.reduce_mean(summed, 0)
   arg_w, arg_b = function.get_extra_args()
   for arg, dims in ((arg_w, [64, 64]), (arg_b, [64])):
     self.assertEqual(arg.get_shape(), tf.TensorShape(dims))
   dw, db = tf.gradients(loss, [arg_w, arg_b])
   cvars.extend(function.get_extra_vars())
   return loss, dw, db
Example #4
0
 def Grad(x, y0):
     """Computes the loss on (x, y0) and its gradients w.r.t. the extra args.

     The forward pass uses `Model` when `use_forward_func` is set and `_Model`
     otherwise. The loss is the mean over axis 0 of the sum over axis 1 of
     `y0 * log(y)`. The two extra args are checked to have shapes [64, 64]
     and [64], and `function.get_extra_vars()` is appended to `cvars`.

     Returns:
         A `(loss, dw, db)` tuple: the loss and its gradients with respect to
         the two extra args.
     """
     model_fn = Model if use_forward_func else _Model
     y = model_fn(x)
     row_sums = tf.reduce_sum(y0 * tf.log(y), 1)
     loss = tf.reduce_mean(row_sums, 0)
     arg_w, arg_b = function.get_extra_args()
     shape_checks = ((arg_w, [64, 64]), (arg_b, [64]))
     for arg, dims in shape_checks:
         self.assertEqual(arg.get_shape(), tf.TensorShape(dims))
     dw, db = tf.gradients(loss, [arg_w, arg_b])
     cvars.extend(function.get_extra_vars())
     return loss, dw, db
Example #5
0
 def _FlatOutputProcessor(inputs):
     """Runs `processor` on `inputs` and returns its flattened outputs.

     The processor's result must have more than one element. It is stored,
     converted to a list, in `output_tmpl.values` (a template from the
     enclosing scope), and the flattened template is returned.

     Raises:
       AssertionError: if the processor returns one or fewer outputs, or if
         `function.get_extra_args()` is non-empty afterwards (reported as
         `processor` not being pure).
     """
     processed = processor(inputs)
     tf.logging.debug('Processor outputs=%s', processed)
     assert len(processed) > 1, processed
     # Store the outputs as a list so that each element gets flattened.
     output_tmpl.values = list(processed)
     flattened = output_tmpl.Flatten()
     tf.logging.debug('Processor flat outputs=%s', flattened)
     tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                      function.get_extra_inputs(),
                      function.get_extra_args(),
                      function.get_extra_vars())
     assert not function.get_extra_args(), (
         'fns {} is not pure: extra_args={}'.format(
             processor, function.get_extra_args()))
     return flattened
Example #6
0
 def _FlatOutputProcessor(source_id, record):
     """Returns a flattened list of 'processor(inputs)'.

     `processor` is dispatched on its argument names, excluding 'self'. A
     single-argument processor is called as `processor(record)`. A processor
     whose arguments are exactly {'source_id', 'record'} is called with both
     as keywords. Either way, it must return an `(output, bucketing_key)` pair.
     `output` must be a non-empty list of `tf.Tensor`s or a non-empty
     `py_utils.NestedMap` whose flattened leaves are all `tf.Tensor`s.

     Returns:
       `output_tmpl.Flatten()` (after storing `output` in
       `output_tmpl.out_values`) with the int32 bucketing key appended.

     Raises:
       ValueError: if `processor` has any other argument signature.
       AssertionError: if `output` fails the checks above, or if
         `function.get_extra_args()` is non-empty afterwards (reported as
         `processor` not being pure).
     """
     argspec = tf_inspect.getargspec(processor)
     tf.logging.debug('GenericInput.processor.argspec=%s', argspec)
     named_args = set(argspec.args) - {'self'}
     # The two accepted signatures are mutually exclusive (2 names vs 1).
     if named_args == {'source_id', 'record'}:
         result = processor(source_id=source_id, record=record)
     elif len(named_args) == 1:
         result = processor(record)
     else:
         raise ValueError(
             'GenericInput: processor should take either a single arg '
             'or two args named as "source_id" and "record". '
             'Actual: %s' % named_args)
     output, bucketing_key = result

     if not isinstance(output, list):
         assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
         assert output
         assert all(isinstance(leaf, tf.Tensor)
                    for leaf in output.Flatten()), '{}'.format(
                        output.DebugString())
     else:
         assert output
         assert all(isinstance(leaf, tf.Tensor)
                    for leaf in output), '{}'.format(output)

     int_key = tf.cast(bucketing_key, tf.int32)
     tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                      int_key)
     output_tmpl.out_values = output
     flat = output_tmpl.Flatten()
     tf.logging.debug('Processor flat outputs=%s', flat)
     tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                      function.get_extra_inputs(),
                      function.get_extra_args(),
                      function.get_extra_vars())
     assert not function.get_extra_args(), (
         'fns {} is not pure: extra_args={}'.format(
             processor, function.get_extra_args()))
     return flat + [int_key]