Example #1
0
 def test_parameter_replication(self):
     # Passing False as the trailing positional argument of xb.parameter must
     # show up in the built HLO text as parameter_replication={false}.
     scalar_f32 = xc.Shape.array_shape(xc.PrimitiveType.F32, ())
     builder = xb.make_computation_builder("test")
     _ = xb.parameter(builder, 0, scalar_f32, "", False)
     hlo_text = builder.Build().as_hlo_text()
     assert "parameter_replication={false}" in hlo_text
Example #2
0
def _nonzero_translation_rule(c, dims, avals, operands):
  """XLA lowering for ``nonzero`` along the last axis of a single operand.

  Emits two results into builder ``c``:
    * the number of nonzero entries along the last axis (an int32 sum-Reduce
      of a 0/1 indicator mask), and
    * an int32 array of last-axis positions (an Iota), stably sorted by the
      mask in descending order (Gt comparator), so positions holding a
      nonzero entry come first and keep their original relative order.

  Args:
    c: the XLA computation builder ops are emitted into.
    dims: not used by this rule.
    avals: not used by this rule.
    operands: must unpack as ``((vals,),)`` -- exactly one group holding
      exactly one XLA op.

  Returns:
    ``[[count], [positions]]``: each result op wrapped in its own
    single-element list (presumably dimension outputs, then value outputs --
    confirm against the caller).
  """
  (operand,), = operands
  operand_shape = c.get_shape(operand)
  extent = operand_shape.dimensions()
  # NOTE(review): assumes rank >= 1; a scalar operand would make axis == -1.
  axis = len(extent) - 1
  int32 = np.dtype('int32')

  # 0/1 int32 mask: 1 wherever the operand differs from zero.
  zero_scalar = xb.constant(c, np.zeros((), operand_shape.numpy_dtype()))
  zero_array = xops.Broadcast(zero_scalar, extent)
  s32_etype = xc.dtype_to_etype(int32)
  mask = xops.ConvertElementType(xops.Ne(operand, zero_array), s32_etype)

  # Nonzero count along `axis`: add-reduce the mask starting from 0.
  scalar_s32_aval = core.ShapedArray((), int32)
  init = xb.constant(c, np.array(0, int32))
  add = xla.primitive_subcomputation(lax.add_p, scalar_s32_aval, scalar_s32_aval)
  count = xops.Reduce(c, [mask], [init], add, (axis,))
  c.get_shape(count)  # xla type checking

  # The comparator gets a (lhs, rhs) scalar pair per sorted operand: two
  # operands -> four parameters. Only the mask pair (0, 1) is compared.
  cmp_builder = xb.make_computation_builder("sort_gt_comparator")
  cmp_args = [xb.parameter(cmp_builder, k, xc.Shape.array_shape(s32_etype, ()))
              for k in range(4)]
  comparator = cmp_builder.build(xops.Gt(cmp_args[0], cmp_args[1]))

  # Sort positions 0..n-1 along `axis` using the mask as the key.
  positions_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, extent)
  positions = xops.Iota(c, positions_shape, axis)
  sorted_tuple = xops.Sort(c, [mask, positions], is_stable=True,
                           comparator=comparator)
  _, out_val = xla.xla_destructure(c, sorted_tuple)
  c.get_shape(out_val)  # xla type checking
  return [[count], [out_val]]
Example #3
0
 def test_parameter_replication_default(self):
     # Without a replication argument, the built HLO text must not mention
     # replication at all.
     scalar_f32 = xc.Shape.array_shape(xc.PrimitiveType.F32, ())
     builder = xb.make_computation_builder("test")
     _ = xb.parameter(builder, 0, scalar_f32)
     hlo_text = builder.Build().as_hlo_text()
     assert "replication" not in hlo_text
Example #4
0
def _make_params(c, dim_in_avals, in_avals):
  """Create XLA parameters for the dimension inputs followed by value inputs.

  Parameters are numbered consecutively from 0 by a single shared counter:
  first one parameter per XLA shape of each aval in ``dim_in_avals``, then
  one per XLA shape of each aval in ``in_avals``.

  Args:
    c: the computation builder the parameters are added to.
    dim_in_avals: avals for the dimension inputs.
    in_avals: avals for the value inputs.

  Returns:
    A pair ``(dim_params, params)`` of lists. Each list has one entry per
    aval, and each entry is the list of parameter ops created for the shapes
    returned by ``xla.aval_to_xla_shapes`` for that aval.
  """
  counter = it.count()

  def make(aval):
    # One parameter per XLA shape the aval lowers to.
    return [xb.parameter(c, next(counter), shape)
            for shape in xla.aval_to_xla_shapes(aval)]

  # Build eagerly. With a lazy map, the parameter numbers would depend on the
  # order in which the caller consumes the two results, and each result could
  # be iterated only once.
  dim_params = [make(aval) for aval in dim_in_avals]
  params = [make(aval) for aval in in_avals]
  return dim_params, params