Пример #1
0
 def maybe_move_axis_to_last(a, axis):
   """Transposes `a` so that dimension `axis` becomes its last dimension.

   When `axis` already equals `rank(a) - 1`, `a` is returned unchanged.
   Otherwise the permutation
   `[0, ..., axis - 1, axis + 1, ..., rank - 1, axis]` is applied. The two
   cases are selected with `utils.cond`.

   Args:
     a: tensor to transpose.
     axis: index of the dimension to move. It is not normalized here; callers
       presumably pass a non-negative index (TODO confirm at call sites).

   Returns:
     `a` itself, or the transposed tensor with `axis` moved to the end.
   """
   def _transpose_to_end():
     # Every dimension except `axis`, in their original order, then `axis`.
     leading = tf.range(axis)
     trailing = tf.range(axis + 1, tf.rank(a))
     perm = tf.concat([leading, trailing, [axis]], axis=0)
     return tf.transpose(a, perm)

   is_last = axis == utils.subtract(tf.rank(a), 1)
   return utils.cond(is_last, lambda: a, _transpose_to_end)
Пример #2
0
 def f(a, b):
   """Cross product of `a` and `b` over the vector axes chosen by the closure.

   This is a closure: it reads `axisa`, `axisb`, `axisc` and `axis` from the
   enclosing scope, which is not visible here. If `axis` is not None, it
   overrides all three per-operand axes. The structure appears to mirror
   `numpy.cross` argument handling (TODO confirm against the enclosing
   function).

   Steps:
     1. Normalize negative `axis_a` / `axis_b` against the operand ranks.
     2. Transpose each operand so its vector dimension is last.
     3. Zero-pad vectors of length 2 to length 3.
     4. Broadcast the operands together and apply `tf.linalg.cross`.
     5. If both inputs held length-2 vectors, return only component 2 of the
        result. Otherwise move the result's vector dimension to `axis_c`.

   Args:
     a: first operand tensor.
     b: second operand tensor.

   Returns:
     The cross-product tensor.
   """
   # The captured `axisa`/`axisb`/`axisc` cannot be rebound from inside this
   # closure, so copy them into locals first.
   axis_a, axis_b, axis_c = axisa, axisb, axisc
   if axis is not None:
     axis_a = axis_b = axis_c = axis

   def _nonneg_axis(ax, t):
     # Maps a negative axis to the equivalent non-negative index for `t`.
     return utils.add(ax, tf.rank(t)) if ax < 0 else ax

   def _vector_axis_to_end(t, ax):
     # Transposes `t` so dimension `ax` becomes last (no-op if already last).
     def _transpose():
       perm = tf.concat(
           [tf.range(ax), tf.range(ax + 1, tf.rank(t)), [ax]], axis=0)
       return tf.transpose(t, perm)
     return utils.cond(
         ax == utils.subtract(tf.rank(t), 1), lambda: t, _transpose)

   def _pad_to_length_3(t, last_dim):
     # Appends a single zero along the last dimension when its size is 2.
     def _pad():
       # No padding on leading dimensions; [0, 1] on the last one.
       paddings = tf.concat(
           [tf.zeros([tf.rank(t) - 1, 2], tf.int32),
            tf.constant([[0, 1]], tf.int32)], axis=0)
       return tf.pad(t, paddings)
     return utils.cond(last_dim == 2, _pad, lambda: t)

   axis_a = _nonneg_axis(axis_a, a)
   axis_b = _nonneg_axis(axis_b, b)

   a = _vector_axis_to_end(a, axis_a)
   b = _vector_axis_to_end(b, axis_b)
   # Vector lengths are recorded before padding; they decide the output form.
   a_dim = utils.getitem(tf.shape(a), -1)
   b_dim = utils.getitem(tf.shape(b), -1)
   a = _pad_to_length_3(a, a_dim)
   b = _pad_to_length_3(b, b_dim)

   c = tf.linalg.cross(*utils.tf_broadcast(a, b))
   axis_c = _nonneg_axis(axis_c, c)

   def _move_result_axis():
     # Moves the last dimension of `c` to position `axis_c`, unless it is
     # already there.
     def _transpose():
       r = tf.rank(c)
       perm = tf.concat(
           [tf.range(axis_c), [r - 1], tf.range(axis_c, r - 1)], axis=0)
       return tf.transpose(c, perm)
     return utils.cond(
         axis_c == utils.subtract(tf.rank(c), 1), lambda: c, _transpose)

   # With two zero-padded length-2 inputs, only component 2 of the cross
   # product can be non-zero, so that component alone is returned.
   both_length_2 = (a_dim == 2) & (b_dim == 2)
   return utils.cond(both_length_2, lambda: c[..., 2], _move_result_axis)