def construct(self, a, b, indices, values):
     """Build a RowTensor, apply ``op2`` once, then apply ``op1`` while ``a > b``.

     Args:
         a: value compared against ``b`` in the loop condition.
         b: loop counter; incremented by 1 after every ``op1`` application.
         indices: row indices used to build the RowTensor.
         values: row values used to build the RowTensor.

     Returns:
         ``(indices, values, dense_shape)`` read from ``x`` after the ops
         (presumably still a RowTensor -- depends on ``op1``/``op2``; confirm).
     """
     x = RowTensor(indices, values, self.dense_shape)
     x = self.op2(x)
     # Data-dependent loop: op1 is applied once per iteration, and b moves
     # one step closer to a each time, so the trip count depends on a - b.
     while a > b:
         x = self.op1(x)
         b = b + 1
     return x.indices, x.values, x.dense_shape
예제 #2
0
 def construct(self, a, b, indices, values):
     """Build a RowTensor and apply ``op1`` if ``a > b``, otherwise ``op2``.

     Args:
         a: left operand of the branch condition.
         b: right operand of the branch condition.
         indices: row indices used to build the RowTensor.
         values: row values used to build the RowTensor.

     Returns:
         ``(indices, values)`` read from ``x`` after the selected op; the
         dense shape is not returned.
     """
     x = RowTensor(indices, values, self.dense_shape)
     # Exactly one of the two ops runs, chosen by comparing the inputs.
     if a > b:
         x = self.op1(x)
     else:
         x = self.op2(x)
     return x.indices, x.values
예제 #3
0
 def bprop(x, indices, axis, out, dout):
     """Backward rule for gathering ``x`` at ``indices`` along ``axis``.

     Args:
         x: forward-pass tensor that was gathered from.
         indices: forward-pass gather indices.
         axis: gather axis; compared against 0 to pick the sparse path
             (presumably a compile-time constant -- confirm).
         out: forward-pass output (unused).
         dout: incoming gradient w.r.t. the forward output.

     Returns:
         ``(dx, dindices, daxis)``.
         - If ``axis == 0``: ``dx`` is a sparse ``RowTensor`` with dense shape
           ``shape(x)``, whose rows are ``dout`` reshaped to
           ``(size(indices),) + shape(x)[1:]`` and indexed by the flattened
           ``indices``.
         - Otherwise: ``dx`` is dense. ``dout`` is transposed so the gathered
           axis leads, segment-summed by ``indices`` into ``shape(x)[axis]``
           rows, then transposed back.
         ``dindices``/``daxis`` are zeros shaped like the ORIGINAL ``indices``
         and ``axis`` inputs.
     """
     x_shp = shape_op(x)
     if axis == 0:
         indices_size = (size_op(indices),)
         x_tail_shp = x_shp[1:]
         values_shape = indices_size + x_tail_shp
         values = reshape(dout, values_shape)
         # Flatten into a separate local: `indices` itself must keep its input
         # shape so the zero gradient returned for it matches that input.
         flat_indices = reshape(indices, indices_size)
         return (RowTensor(flat_indices, values, x_shp),
                 zeros_like(indices), zeros_like(axis))
     if F.rank(dout) == 0:
         dout = P.ExpandDims()(dout, -1)
     # Same reasoning: promote scalar indices for the segment sum only, never
     # for the returned indices-gradient.
     seg_indices = indices
     if F.rank(indices) == 0:
         seg_indices = P.ExpandDims()(indices, -1)
     out_shp = shape_op(dout)
     ind_shp = shape_op(seg_indices)
     # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
     perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
     values_transpose = transpose(dout, perm_1)
     params_grad = unsorted_segment_sum(values_transpose, seg_indices,
                                        x_shp[axis])
     # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
     perm_2 = _generate_inverse_index(x_shp, axis)
     params_grad = transpose(params_grad, perm_2)
     return params_grad, zeros_like(indices), zeros_like(axis)
 def construct(self, x):
     """Return a new RowTensor equal to ``x`` with 2 added to every value.

     ``x`` must expose ``indices``, ``values`` and ``dense_shape``
     (presumably a RowTensor). Indices and dense shape pass through unchanged.
     """
     return RowTensor(x.indices, x.values + 2, x.dense_shape)
예제 #5
0
 def construct(self, indices, values):
     """Pack the inputs into a RowTensor and hand its components back.

     Returns:
         ``(values, indices, dense_shape)``; note that values come first.
     """
     row_tensor = RowTensor(indices, values, self.dense_shape)
     return row_tensor.values, row_tensor.indices, row_tensor.dense_shape
예제 #6
0
 def construct(self, indices, values):
     """Build a RowTensor, wrap it in a 1-tuple, and return that single element."""
     # The tuple round-trip is deliberate: the RowTensor goes through tuple
     # construction and indexing (presumably what this case exercises).
     packed = (RowTensor(indices, values, self.dense_shape), )
     return packed[0]