예제 #1
0
    def __init__(self, selector_input, data_inputs):
        """Initializes a directed-interleave dataset.

        Args:
          selector_input: Dataset whose elements choose which of
            `data_inputs` provides each output element.
          data_inputs: Datasets to interleave between. All must share the
            same output types and output classes.

        Raises:
          TypeError: If the datasets disagree on output types or classes.
        """
        self._selector_input = selector_input
        self._data_inputs = list(data_inputs)

        reference = self._data_inputs[0]
        ref_types = dataset_ops.get_legacy_output_types(reference)
        ref_classes = dataset_ops.get_legacy_output_classes(reference)

        # Every remaining input must structurally match the first one.
        for other in self._data_inputs[1:]:
            if (dataset_ops.get_legacy_output_types(other) != ref_types or
                    dataset_ops.get_legacy_output_classes(other) !=
                    ref_classes):
                raise TypeError(
                    "All datasets must have the same type and class.")

        # Successively narrow the shapes to the most specific shape that is
        # compatible with all of the inputs.
        merged_shapes = dataset_ops.get_legacy_output_shapes(reference)
        for other in self._data_inputs[1:]:
            paired = zip(
                nest.flatten(merged_shapes),
                nest.flatten(dataset_ops.get_legacy_output_shapes(other)))
            merged_shapes = nest.pack_sequence_as(
                merged_shapes,
                [a.most_specific_compatible_shape(b) for (a, b) in paired])

        self._element_spec = structure.convert_legacy_structure(
            ref_types, merged_shapes, ref_classes)
        # pylint: disable=protected-access
        variant_tensor = gen_experimental_dataset_ops.directed_interleave_dataset(
            self._selector_input._variant_tensor,
            [d._variant_tensor for d in self._data_inputs],
            **self._flat_structure)
        super(_DirectedInterleaveDataset, self).__init__(variant_tensor)
예제 #2
0
 def _as_variant_tensor(self):
   """Builds the variant tensor for the directed-interleave dataset op."""
   # pylint: disable=protected-access
   data_variants = [d._variant_tensor for d in self._data_inputs]
   return gen_experimental_dataset_ops.directed_interleave_dataset(
       self._selector_input._variant_tensor, data_variants,
       **self._flat_structure)
예제 #3
0
    def __init__(self,
                 selector_input,
                 data_inputs,
                 stop_on_empty_dataset=False):
        """Initializes a directed-interleave dataset.

        Args:
          selector_input: Dataset whose elements choose which of
            `data_inputs` provides each output element.
          data_inputs: Datasets to interleave between. All must share the
            same output types and output classes.
          stop_on_empty_dataset: If True, iteration stops once a selected
            input dataset is exhausted (forwarded to the kernel when the
            forward-compatibility window permits).

        Raises:
          TypeError: If the datasets disagree on output types or classes.
        """
        self._selector_input = selector_input
        self._data_inputs = list(data_inputs)
        self._stop_on_empty_dataset = stop_on_empty_dataset

        reference = self._data_inputs[0]
        ref_types = dataset_ops.get_legacy_output_types(reference)
        ref_classes = dataset_ops.get_legacy_output_classes(reference)

        # Every remaining input must structurally match the first one; the
        # error message names the offending dataset by index.
        for index, other in enumerate(self._data_inputs[1:], start=1):
            other_types = dataset_ops.get_legacy_output_types(other)
            other_classes = dataset_ops.get_legacy_output_classes(other)
            if other_types != ref_types or other_classes != ref_classes:
                raise TypeError(
                    "All datasets must have the same type and class.\n"
                    "dataset 0 vs dataset %s types: %s ; %s\n"
                    "classes: %s ; %s" %
                    (index, ref_types, other_types, ref_classes,
                     other_classes))

        # Successively narrow the shapes to the most specific shape that is
        # compatible with all of the inputs.
        merged_shapes = dataset_ops.get_legacy_output_shapes(reference)
        for other in self._data_inputs[1:]:
            paired = zip(
                nest.flatten(merged_shapes),
                nest.flatten(dataset_ops.get_legacy_output_shapes(other)))
            merged_shapes = nest.pack_sequence_as(
                merged_shapes,
                [a.most_specific_compatible_shape(b) for (a, b) in paired])
        self._element_spec = structure.convert_legacy_structure(
            ref_types, merged_shapes, ref_classes)

        # Only pass the kwarg when the runtime is new enough to accept it,
        # or when the caller explicitly asked for the behavior.
        compat_kwargs = {}
        if (compat.forward_compatible(2021, 5, 14) or
                self._stop_on_empty_dataset):
            compat_kwargs["stop_on_empty_dataset"] = (
                self._stop_on_empty_dataset)

        # pylint: disable=protected-access
        variant_tensor = (
            gen_experimental_dataset_ops.directed_interleave_dataset(
                self._selector_input._variant_tensor,
                [d._variant_tensor for d in self._data_inputs],
                **compat_kwargs, **self._flat_structure))

        super(_DirectedInterleaveDataset, self).__init__(variant_tensor)
예제 #4
0
 def _as_variant_tensor(self):
   """Builds the variant tensor, choosing the op by compatibility window."""
   # pylint: disable=protected-access
   selector_variant = self._selector_input._variant_tensor
   data_variants = [d._variant_tensor for d in self._data_inputs]
   # Newer runtimes expose the op under its non-"experimental" name.
   if compat.forward_compatible(2019, 8, 3):
     op = gen_experimental_dataset_ops.directed_interleave_dataset
   else:
     op = gen_experimental_dataset_ops.experimental_directed_interleave_dataset
   return op(selector_variant, data_variants, **self._flat_structure)
예제 #5
0
  def __init__(self, selector_input, data_inputs, stop_on_empty_dataset=False):
    """Initializes a directed-interleave dataset.

    Args:
      selector_input: Dataset whose elements choose which of `data_inputs`
        provides each output element.
      data_inputs: Datasets to interleave between. All must share the same
        output types and output classes.
      stop_on_empty_dataset: If True, iteration stops once a selected input
        dataset is exhausted.

    Raises:
      TypeError: If the datasets disagree on output types or classes.
    """
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)
    self._stop_on_empty_dataset = stop_on_empty_dataset

    reference = self._data_inputs[0]
    ref_types = dataset_ops.get_legacy_output_types(reference)
    ref_classes = dataset_ops.get_legacy_output_classes(reference)

    # Every remaining input must structurally match the first one; the error
    # message names the offending dataset by index.
    for index, other in enumerate(self._data_inputs[1:], start=1):
      other_types = dataset_ops.get_legacy_output_types(other)
      other_classes = dataset_ops.get_legacy_output_classes(other)
      if other_types != ref_types or other_classes != ref_classes:
        raise TypeError("All datasets must have the same type and class.\n"
                        "dataset 0 vs dataset %s types: %s ; %s\n"
                        "classes: %s ; %s" %
                        (index, ref_types, other_types, ref_classes,
                         other_classes))

    # Successively narrow the element spec to the most specific type
    # compatible with all of the inputs.
    merged_spec = self._data_inputs[0].element_spec
    for other in self._data_inputs[1:]:
      paired = zip(nest.flatten(merged_spec), nest.flatten(other.element_spec))
      merged_spec = nest.pack_sequence_as(
          merged_spec,
          [a.most_specific_compatible_type(b) for (a, b) in paired])
    self._element_spec = merged_spec

    # pylint: disable=protected-access
    variant_tensor = gen_experimental_dataset_ops.directed_interleave_dataset(
        self._selector_input._variant_tensor,
        [d._variant_tensor for d in self._data_inputs],
        stop_on_empty_dataset=self._stop_on_empty_dataset,
        **self._flat_structure)

    super(_DirectedInterleaveDataset, self).__init__(variant_tensor)