Example #1
    def write_examples(self, funnel_dict, file_name):
        """Pours the array then writes the examples to tfrecords. It creates one example per 'row', i.e. axis=0 of the arrays. All arrays must have the same axis=0 dimension and must be of a type that can be written to a tfrecord

    Parameters
    ----------
    funnel_dict : dict(
      keys - Slot objects or Placeholder objects. The 'funnels' (i.e. unconnected slots) of the waterwork.
      values - valid input data types
    )
        The inputs to the waterwork's full pour function.
    file_name : str
      The name of the tfrecord file to write to.

    """
        if not file_name.endswith('.tfrecord'):
            raise ValueError(
                "file_name must end in '.tfrecord'. Got: {}".format(file_name))

        # Create the output directory if it doesn't already exist.
        dir = file_name.split('/')[:-1]
        d.maybe_create_dir(*dir)

        writer = tf.io.TFRecordWriter(file_name)
        tap_dict = self.pour(funnel_dict, key_type='str')

        feature_dict, func_dict = self._get_feature_dicts(tap_dict)

        self._write_tap_dict(writer, tap_dict, func_dict)

        # Save the feature spec to a companion '.pickle' file next to the tfrecord.
        feature_dict_fn = re.sub(r'_?[0-9]*\.tfrecord', '.pickle', file_name)
        d.save_to_file(feature_dict, feature_dict_fn)
        writer.close()

        return file_name
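
A minimal usage sketch for write_examples, assuming ww is an already built waterwork object; the funnel names 'array_1' and 'array_2' and the output path are made up for illustration:

    import numpy as np

    # Hypothetical funnel_dict: one value per funnel, all sharing the same axis=0 size.
    funnel_dict = {
        'array_1': np.array([[1.0, 2.0], [3.0, 4.0]]),
        'array_2': np.array(['a', 'b']),
    }

    # Writes one example per row to the tfrecord and saves the feature spec
    # to a companion 'data/examples.pickle' file.
    ww.write_examples(funnel_dict, 'data/examples.tfrecord')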
Example #2
    def write_tap_dicts(self, tap_dicts, file_name, skip_keys=None):
        """Write a list of already poured tap_dicts to a single tfrecord file."""
        feature_dict, func_dict = self._get_feature_dicts(tap_dicts[0])

        writer = tf.io.TFRecordWriter(file_name)
        for tap_dict in tap_dicts:
            if not tap_dict:
                continue
            self._write_tap_dict(writer, tap_dict, func_dict, skip_keys)

        feature_dict_fn = re.sub(r'_?[0-9]*\.tfrecord', '.pickle', file_name)
        d.save_to_file(feature_dict, feature_dict_fn)

        writer.close()
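
A similar sketch for write_tap_dicts, assuming the pours happen up front; ww, funnel_dicts, and the file path are hypothetical:

    # Pour each funnel_dict first, then write all resulting tap_dicts into a
    # single tfrecord file (keys listed in skip_keys would be left out).
    tap_dicts = [ww.pour(fd, key_type='str') for fd in funnel_dicts]
    ww.write_tap_dicts(tap_dicts, 'data/examples_0.tfrecord')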
Example #3
  def test_read_from_file(self):
    """Round-trip a list through a json file and an array through an npy file."""
    temp_json = os.path.join(self.temp_dir, 'temp.json')
    d.save_to_file([], temp_json)

    outputs = d.read_from_file(
        file_name=temp_json
    )

    temp_npy = os.path.join(self.temp_dir, 'temp.npy')
    d.save_to_file(np.zeros(dtype=np.int64, shape=[1]), temp_npy)

    array = d.read_from_file(
        file_name=temp_npy
    )
    th.assert_arrays_equal(self, array, np.zeros(dtype=np.int64, shape=[1]))
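
The test above suggests that d.save_to_file and d.read_from_file dispatch on the file extension (.json and .npy here, .pickle elsewhere in these examples). A hedged round-trip sketch along the same lines, with a made-up path:

    import os

    temp_pickle = os.path.join('/tmp', 'temp.pickle')
    d.save_to_file({'key': 'value'}, temp_pickle)
    loaded = d.read_from_file(file_name=temp_pickle)
    # If pickle files round-trip like json/npy, loaded == {'key': 'value'}.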
Example #4
 def save_to_file(self, file_name):
     """Save the waterwork to disk as a pickle/dill file."""
     if not file_name.endswith(('pickle', 'pkl', 'dill')):
         raise ValueError(
             "Waterwork can only be saved as a pickle or dill file.")
     save_dict = self._save_dict()
     d.save_to_file(save_dict, file_name)
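
Given the extension check above, usage is a one-liner; ww and the path are hypothetical:

    # Any of the 'pickle', 'pkl' or 'dill' endings passes the check.
    ww.save_to_file('saved/my_waterwork.pickle')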
Example #5
    def multi_write_examples(self,
                             funnel_dict_iter,
                             file_name,
                             num_threads=1,
                             use_threading=False,
                             batch_size=None,
                             file_num_offset=0,
                             skip_fails=False,
                             skip_keys=None,
                             serialize_func=None):
        """Pour each funnel_dict from funnel_dict_iter and write the resulting
        examples to a series of tfrecord files, one file per batch."""
        if serialize_func is None:
            save_dict = self._save_dict()

            def serialize_func(funnel_dict):
                # Make sure worker threads are attached to the JVM before pouring.
                jpype.attachThreadToJVM()
                ww = Waterwork()
                ww._from_save_dict(save_dict)

                tap_dict = ww.pour(funnel_dict, 'str', False)
                feature_dict, func_dict = self._get_feature_dicts(tap_dict)

                serial = ww._serialize_tap_dict(tap_dict, func_dict)
                return serial

        if isinstance(funnel_dict_iter, (list, tuple)):
            funnel_dict_iter = iter(funnel_dict_iter)

        file_names = []
        if not file_name.endswith('.tfrecord'):
            raise ValueError("file_name must end in '.tfrecord'")

        dir = file_name.split('/')[:-1]
        d.maybe_create_dir(*dir)

        for batch_num, batch in enumerate(
                b.batcher(funnel_dict_iter, batch_size)):
            if batch_num == 0:
                # Pour the first funnel_dict once to recover the feature spec
                # and save it alongside the tfrecords.
                tap_dict = self.pour(batch[0],
                                     key_type='str',
                                     return_plugged=False)
                feature_dict, func_dict = self._get_feature_dicts(tap_dict)
                feature_dict_fn = re.sub(r'_?[0-9]*\.tfrecord', '.pickle',
                                         file_name)
                d.save_to_file(feature_dict, feature_dict_fn)

            logging.info("Serializing batch %s", batch_num)
            if skip_fails:
                try:
                    all_serials = mu.multi_map(serialize_func, batch,
                                               num_threads, use_threading)
                except Exception:
                    logging.warning("Batch %s failed. Skipping.", batch_num)
                    continue
            else:
                all_serials = mu.multi_map(serialize_func, batch, num_threads,
                                           use_threading)
            logging.info("Finished serializing batch %s", batch_num)

            file_num = file_num_offset + batch_num
            fn = file_name.replace('.tfrecord',
                                   '_' + str(file_num) + '.tfrecord')
            file_names.append(fn)

            logging.info("Writing batch %s", batch_num)
            writer = tf.io.TFRecordWriter(fn)
            for serials in all_serials:
                for serial in serials:
                    writer.write(serial)
            logging.info("Finished writing batch %s", batch_num)
            writer.close()

        return file_names
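
A usage sketch for multi_write_examples, assuming ww is a built waterwork; the funnel keys and data below are invented. With this call the output is sharded into 'data/examples_0.tfrecord', 'data/examples_1.tfrecord', ... following the file-naming logic above:

    import numpy as np

    def funnel_dict_gen():
        # Hypothetical data source: yields one funnel_dict per chunk of rows.
        for _ in range(100):
            yield {
                'array_1': np.random.rand(50, 2),
                'array_2': np.array(['a'] * 50),
            }

    file_names = ww.multi_write_examples(
        funnel_dict_gen(),
        'data/examples.tfrecord',
        num_threads=4,
        batch_size=10,        # 10 funnel_dicts are poured and written per shard
        skip_fails=True,      # a failed batch is logged and skipped
    )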
Example #6
 def save_to_file(self, path):
   """Save the transform object to disk."""
   save_dict = self._save_dict()
   d.save_to_file(save_dict, path)