Example #1
  def test_util_foreach_api_cycle(self):

    class ForTest1(object):

      def __init__(self, a, b):
        self.a = a
        self.b = b

    class ForTest2(object):

      def __init__(self, a):
        self.a = a
        self.b = ForTest1(self, a)

    xdata = {
        2: (11, ['a', 'b'], 17),
        'w': [12, 'q', 12.33],
        17.09: set(['a', 'b', 21]),
    }
    data = ForTest2(xdata)

    wids = []

    def convert(x):
      wids.append(id(x))
      return x

    xu.for_each_instance_rewrite(data,
                                 lambda x: isinstance(x, (int, str, float)),
                                 convert)
    self.assertEqual(len(wids), 11)
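
The test above exercises `for_each_instance_rewrite` on an object graph with a reference cycle (`ForTest2` stores itself inside `ForTest1`) and checks that each matching leaf is visited exactly once. As a minimal standalone sketch of the same API, assuming `xu` is `torch_xla.utils.utils` as in the test (the expected result is an assumption, not taken from the source):

import torch_xla.utils.utils as xu

nested = {'a': [1, 2, 3.5], 'b': {'c': 4}}

# Rewrite every numeric leaf; non-matching values (the string keys here) are
# carried through, and the rewritten structure is returned.
doubled = xu.for_each_instance_rewrite(nested,
                                       lambda v: isinstance(v, (int, float)),
                                       lambda v: v * 2)
# Expected: doubled == {'a': [2, 4, 7.0], 'b': {'c': 8}}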
Example #2
  def test_util_foreach_api(self):

    class ForTest(object):

      def __init__(self):
        self.a = {'k': [1, 2, 3], 4.9: 'y', 5: {'a': 'n'}}
        self.b = ('f', 17)

    duped_data = ForTest()
    data = {
        2.3: 11,
        21: ForTest(),
        'w': [12, ForTest(), duped_data],
        123: duped_data,
    }

    ids = []

    def collect(x):
      ids.append(id(x))

    xu.for_each_instance(data, lambda x: isinstance(x, (int, str, float)),
                         collect)

    wids = []

    def convert(x):
      wids.append(id(x))
      return x

    xu.for_each_instance_rewrite(data,
                                 lambda x: isinstance(x, (int, str, float)),
                                 convert)
    self.assertEqual(len(ids), 17)
    self.assertEqual(ids, wids)
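
The final assertion checks that `for_each_instance_rewrite` visits the same set of leaves as `for_each_instance`. A minimal sketch of `for_each_instance` on its own, again assuming `xu` is `torch_xla.utils.utils` (the expected result is an assumption):

import torch_xla.utils.utils as xu

seen = []

# `for_each_instance` calls the given function on every matching leaf; unlike
# the rewrite variant, its result is not used to rebuild the structure.
xu.for_each_instance({'a': [1, 2.5], 'b': {'c': 3}},
                     lambda v: isinstance(v, (int, float)),
                     seen.append)
# Expected: seen collects the numeric leaves, e.g. [1, 2.5, 3].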
Example #3
def save(data, file_or_path, master_only=True):
    """Saves the input data into a file.

  The saved data is transfered to PyTorch CPU device before being saved, so a
  following `torch.load()` will load CPU data.

  Args:
    data: The input data to be saved. Any nested combination of Python objects
      (list, tuples, sets, dicts, ...).
    file_or_path: The destination for the data saving operation. Either a file
      path or a Python file object. If `master_only` is ``False`` the path or
      file objects must point to different destinations as otherwise all the
      writes from the same host will override each other.
    master_only (bool): Whether only the master device should save the data. If
      False, the `file_or_path` argument should be a different file or path for
      each of the ordinals taking part to the replication, otherwise all the
      replicas on the same host will be writing to the same location.
      Default: True
  """
    def convert_fn(value):
        return value.cpu()

    cpu_data = xu.for_each_instance_rewrite(data,
                                            lambda x: type(x) == torch.Tensor,
                                            convert_fn)
    if master_only:
        if is_master_ordinal():
            torch.save(cpu_data, file_or_path)
    else:
        torch.save(cpu_data, file_or_path)
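
A usage sketch for the function above (hypothetical paths; assumes a replicated torch_xla setup where `torch_xla.core.xla_model` provides `xla_device()` and `get_ordinal()`):

import torch
import torch_xla.core.xla_model as xm

state = {'w': torch.randn(2, 2, device=xm.xla_device())}

# Default: only the master ordinal writes, so one shared path is safe.
save(state, '/tmp/ckpt.pt')

# master_only=False: every ordinal writes, so each process must be handed
# its own destination to avoid clobbering writes from the same host.
save(state, '/tmp/ckpt.ordinal{}.pt'.format(xm.get_ordinal()),
     master_only=False)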
Example #4
    def _replace_tensors(self, inputs):
        def convert_fn(value):
            return self._get_converted_tensor()

        return xu.for_each_instance_rewrite(inputs,
                                            lambda x: self._select_fn(x),
                                            convert_fn)
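
`_replace_tensors` follows the same pattern as `save`: walk an arbitrarily nested input and swap out every value picked by a selector. A hedged standalone sketch of that pattern with hypothetical names (`move_tensors` is not part of the library):

import torch
import torch_xla.utils.utils as xu


def move_tensors(inputs, device):
    # Replace every tensor found anywhere inside `inputs` with a copy on
    # `device`, keeping the surrounding container structure intact.
    return xu.for_each_instance_rewrite(inputs,
                                        lambda x: isinstance(x, torch.Tensor),
                                        lambda t: t.to(device))


batch = {'x': torch.randn(4, 3), 'meta': {'ids': [1, 2, 3]}}
cpu_batch = move_tensors(batch, torch.device('cpu'))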
Example #5
def save(data, file_or_path, master_only=True):
    """Saves the input data into a file.

  The saved data is transfered to PyTorch CPU device before being saved, so a
  following `torch.load()` will load CPU data.

  Args:
    data: The input data to be saved. Any nested combination of Python objects
      (list, tuples, sets, dicts, ...).
    file_or_path: The destination for the data saving operation. Either a file
      path or a Python file object.
    master_only (bool): Whether only the master device should save the data. If
      False, the `file_or_path` argument must be a path, and the different
      devices will save data to files with their ordinal as extension.
      Default: True
  """
    def convert_fn(value):
        return value.cpu()

    cpu_data = xu.for_each_instance_rewrite(data,
                                            lambda x: type(x) == torch.Tensor,
                                            convert_fn)
    if master_only:
        if is_master_ordinal():
            torch.save(cpu_data, file_or_path)
    else:
        assert type(file_or_path) == str
        file_or_path = '{}.{}'.format(file_or_path, get_ordinal())
        torch.save(cpu_data, file_or_path)
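
Usage sketch for this variant (hypothetical path; assumes the same replicated torch_xla setup): with `master_only=False` every replica passes the same base path and the ordinal is appended automatically.

import torch
import torch_xla.core.xla_model as xm

state = {'w': torch.randn(2, 2, device=xm.xla_device())}

# Each replica calls this with the identical string path; the function
# appends '.<ordinal>', producing '/tmp/ckpt.pt.0', '/tmp/ckpt.pt.1', ...
save(state, '/tmp/ckpt.pt', master_only=False)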