Exemplo n.º 1
0
 def construct(self, a, b, indices, values):
     # Wrap (indices, values) into an IndexedSlices with this cell's dense
     # shape, run it through op1 when a > b (op2 otherwise), and return the
     # resulting indices and values.
     slices = IndexedSlices(indices, values, self.dense_shape)
     if a > b:
         result = self.op1(slices)
     else:
         result = self.op2(slices)
     return result.indices(), result.values()
Exemplo n.º 2
0
 def bprop(x, indices, axis, out, dout):
     # Backward pass for a gather of `x` at `indices` along `axis`.
     # Returns (grad_x, grad_indices, grad_axis); `out` (the forward result)
     # is not used.
     # grad_indices / grad_axis are zeros built from the caller's own
     # `indices` / `axis`, never from the reshaped or expanded locals below,
     # so each gradient keeps the shape of the input it belongs to.
     x_shp = shape_op(x)
     if axis == 0:
         # Sparse path: grad_x is an IndexedSlices over the flattened indices,
         # with dout reshaped to (number of indices,) + x_shp[1:].
         indices_size = (size_op(indices), )
         x_tail_shp = x_shp[1:]
         values_shape = indices_size + x_tail_shp
         values = reshape(dout, values_shape)
         flat_indices = reshape(indices, indices_size)
         return (IndexedSlices(flat_indices, values, x_shp),
                 zeros_like(indices), zeros_like(axis))
     # Dense path. Rank-0 dout / indices get a trailing axis (ExpandDims at
     # -1) before the shape computations below.
     seg_indices = indices
     if F.rank(dout) == 0:
         dout = P.ExpandDims()(dout, -1)
     if F.rank(seg_indices) == 0:
         seg_indices = P.ExpandDims()(seg_indices, -1)
     out_shp = shape_op(dout)
     ind_shp = shape_op(seg_indices)
     # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
     perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
     values_transpose = transpose(dout, perm_1)
     # Sum the rows of the transposed dout into x_shp[axis] segments keyed by
     # the gather indices; repeated indices therefore accumulate.
     params_grad = unsorted_segment_sum(values_transpose, seg_indices,
                                        x_shp[axis])
     # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
     perm_2 = _generate_inverse_index(x_shp, axis)
     params_grad = transpose(params_grad, perm_2)
     return params_grad, zeros_like(indices), zeros_like(axis)
Exemplo n.º 3
0
 def construct(self, indices, values):
     # Build an IndexedSlices from the inputs using this cell's dense shape
     # and return its components in (values, indices, dense_shape) order.
     sparse = IndexedSlices(indices, values, self.dense_shape)
     sparse_values = sparse.values()
     sparse_indices = sparse.indices()
     return sparse_values, sparse_indices, sparse.dense_shape()
Exemplo n.º 4
0
 def construct(self, x):
     # Return a new IndexedSlices with the same indices and dense shape as x,
     # but with 2 added to every stored value.
     return IndexedSlices(x.indices(), x.values() + 2, x.dense_shape())
Exemplo n.º 5
0
 def construct(self, indices, values):
     # Wrap a freshly built IndexedSlices in a 1-tuple, then unwrap it again.
     # The result is the IndexedSlices itself, returned via tuple indexing.
     sparse = IndexedSlices(indices, values, self.dense_shape)
     wrapped = (sparse, )
     return wrapped[0]