Ejemplo n.º 1
0
def _UnbatchGrad(op, grad):  # pylint: disable=invalid-name
    """Builds the gradient list for an unbatch-style op via `unbatch_grad`.

    Presumably registered elsewhere as the gradient function of the Unbatch
    op (the registration is not visible here).

    Args:
      op: The forward op. Its inputs 0, 1 and 2 are forwarded to
        `gen_batch_ops.unbatch_grad`, with `grad` inserted as the third
        positional argument.
      grad: The incoming gradient tensor, passed through unchanged.

    Returns:
      A three-element list: the result of `unbatch_grad` for input 0,
      followed by `None` (no gradient) for inputs 1 and 2.
    """
    original_input = op.inputs[0]
    batch_index = op.inputs[1]
    batch_id = op.inputs[2]
    # Embedding op.name gives every forward op its own distinct shared_name.
    shared_name = "unbatch_gradient_{}".format(op.name)
    input_gradient = gen_batch_ops.unbatch_grad(
        original_input, batch_index, grad, batch_id,
        shared_name=shared_name)
    return [input_gradient, None, None]
Ejemplo n.º 2
0
def _UnbatchGrad(op, grad):   # pylint: disable=invalid-name
  """Returns per-input gradients for an unbatch-style op.

  Only the first input receives a gradient, computed by
  `gen_batch_ops.unbatch_grad`. The second and third inputs get `None`.
  Presumably used as the registered gradient of the Unbatch op; the
  registration itself is not visible here.

  Args:
    op: The forward op whose inputs 0, 1 and 2 feed `unbatch_grad`.
    grad: The upstream gradient, passed as the third positional argument.

  Returns:
    `[unbatch_grad(...), None, None]`.
  """
  inputs = op.inputs
  data_gradient = gen_batch_ops.unbatch_grad(
      inputs[0],
      inputs[1],
      grad,
      inputs[2],
      # Keyed on op.name so each forward op uses a distinct shared name.
      shared_name="unbatch_gradient_{}".format(op.name))
  return [data_gradient] + [None, None]
Ejemplo n.º 3
0
 def testUnbatchGradInvalidBatchId(self):
     """Expects `unbatch_grad` to raise InvalidArgumentError for this batch_index.

     The single batch_index row has only two entries. NOTE(review): this is
     presumably the malformed part (wrong row width); confirm against the
     op's shape checks.
     """
     with self.assertRaises(errors.InvalidArgumentError):
         # Tensors are created in the same order as the original call's
         # arguments, and inside the assertRaises context as before.
         original_input = constant_op.constant([1])
         batch_index = constant_op.constant([[0, 0]], dtype=dtypes.int64)
         grad = constant_op.constant([1])
         batch_id = constant_op.constant([1], dtype=dtypes.int64)
         unbatched = gen_batch_ops.unbatch_grad(
             original_input=original_input,
             batch_index=batch_index,
             grad=grad,
             id=batch_id)
         self.evaluate(unbatched)
Ejemplo n.º 4
0
 def testUnbatchGradInvalidArgs(self):
     """Expects `unbatch_grad` to reject random (3, 1)-shaped inputs.

     Every input is a random tensor of shape (3, 1). float64 inputs use
     maxval=None and int64 inputs use maxval=65536. The empty container,
     shared_name and name arguments are passed explicitly. NOTE(review):
     which argument triggers the error is not pinned here; presumably the
     shapes/values of batch_index or id are invalid for the op.
     """
     shape = (3, 1)

     def _random(dtype, maxval):
         return random_ops.random_uniform(
             shape=shape, dtype=dtype, maxval=maxval)

     # Creation order (input, index, grad, id) matches the original, which
     # keeps graph op naming and seeding unchanged.
     original_input = _random(dtypes.float64, None)
     batch_index = _random(dtypes.int64, 65536)
     grad = _random(dtypes.float64, None)
     batch_id = _random(dtypes.int64, 65536)
     with self.assertRaises(errors.InvalidArgumentError):
         unbatched = gen_batch_ops.unbatch_grad(
             original_input=original_input,
             batch_index=batch_index,
             grad=grad,
             id=batch_id,
             container="",
             shared_name="",
             name="")
         self.evaluate(unbatched)