Ejemplo n.º 1
0
  def testInfeedS32Values(self):
    """Enqueue each value on the infeed, then execute once per value.

    Builds a computation containing a single Infeed op shaped like the first
    element, transfers every element to the infeed, and asserts that each
    successive execution returns the corresponding element.
    """
    values = NumpyArrayS32([1, 2, 3, 4])
    builder = self._NewComputation()
    element_shape = xla_client.Shape.from_numpy(values[0])
    builder.Infeed(element_shape)
    executable = builder.Build().CompileWithExampleArguments()

    # All transfers happen before the first execution.
    for value in values:
      xla_client.transfer_to_infeed(value)

    # One execution per enqueued value, compared in the same order.
    for expected in values:
      actual = executable.Execute()
      self.assertEqual(actual, expected)
Ejemplo n.º 2
0
    def testInfeedS32Values(self):
        """Enqueue each value on the infeed, then execute once per value.

        Builds a computation containing a single Infeed op shaped like the
        first element, transfers every element to the infeed, and asserts
        that each successive execution returns the corresponding element.
        """
        values = NumpyArrayS32([1, 2, 3, 4])
        builder = self._NewComputation()
        element_shape = xla_client.Shape.from_numpy(values[0])
        builder.Infeed(element_shape)
        executable = builder.Build().CompileWithExampleArguments()

        # All transfers happen before the first execution.
        for value in values:
            xla_client.transfer_to_infeed(value)

        # One execution per enqueued value, compared in the same order.
        for expected in values:
            actual = executable.Execute()
            self.assertEqual(actual, expected)
Ejemplo n.º 3
0
    def testInfeedThenOutfeedS32(self):
        """Round-trip each value through an Infeed op wired to an Outfeed op.

        For every value, the compiled computation is executed on a separate
        thread while this thread transfers the value to the infeed and reads
        a value back from the outfeed; the two are then compared.
        """
        values = NumpyArrayS32([1, 2, 3, 4])
        builder = self._NewComputation()
        infeed_op = builder.Infeed(xla_client.Shape.from_numpy(values[0]))
        builder.Outfeed(infeed_op)
        executable = builder.Build().CompileWithExampleArguments()

        for want in values:
            # Execute runs on its own thread; the transfers below are issued
            # from this thread while it is in flight.
            # NOTE(review): presumably Execute blocks until the infeed is
            # supplied and the outfeed drained - confirm.
            runner = threading.Thread(target=executable.Execute)
            runner.start()

            xla_client.transfer_to_infeed(want)
            outfeed_shape = xla_client.Shape.from_numpy(values[0])
            got = xla_client.transfer_from_outfeed(outfeed_shape)

            runner.join()
            self.assertEqual(want, got)
Ejemplo n.º 4
0
  def testInfeedThenOutfeedS32(self):
    """Round-trip each value through an Infeed op wired to an Outfeed op.

    For every value, the compiled computation is executed on a separate
    thread while this thread transfers the value to the infeed and reads a
    value back from the outfeed; the two are then compared.
    """
    values = NumpyArrayS32([1, 2, 3, 4])
    builder = self._NewComputation()
    infeed_op = builder.Infeed(xla_client.Shape.from_numpy(values[0]))
    builder.Outfeed(infeed_op)
    executable = builder.Build().CompileWithExampleArguments()

    for want in values:
      # Execute runs on its own thread; the transfers below are issued from
      # this thread while it is in flight.
      # NOTE(review): presumably Execute blocks until the infeed is supplied
      # and the outfeed drained - confirm.
      runner = threading.Thread(target=executable.Execute)
      runner.start()

      xla_client.transfer_to_infeed(want)
      outfeed_shape = xla_client.Shape.from_numpy(values[0])
      got = xla_client.transfer_from_outfeed(outfeed_shape)

      runner.join()
      self.assertEqual(want, got)