def write_examples(self, funnel_dict, file_name):
  """Pours the array then writes the examples to tfrecords.

  It creates one example per 'row', i.e. axis=0 of the arrays. All arrays
  must have the same axis=0 dimension and must be of a type that can be
  written to a tfrecord.

  Parameters
  ----------
  funnel_dict : dict(
    keys - Slot objects or Placeholder objects. The 'funnels' (i.e.
      unconnected slots) of the waterwork.
    values - valid input data types
  )
    The inputs to the waterwork's full pour function.
  file_name : str
    The name of the tfrecord file to write to.

  Returns
  -------
  str
    The file name that was written to.

  Raises
  ------
  ValueError
    If file_name does not end in '.tfrecord'.
  """
  if not file_name.endswith('.tfrecord'):
    # BUG FIX: the original passed two arguments to ValueError, which
    # renders as a tuple instead of one message string.
    raise ValueError(
        "file_name must end in '.tfrecord'. Got: {}".format(file_name))

  # Renamed from `dir` to avoid shadowing the builtin.
  dir_parts = file_name.split('/')[:-1]
  d.maybe_create_dir(*dir_parts)

  writer = tf.io.TFRecordWriter(file_name)
  tap_dict = self.pour(funnel_dict, key_type='str')

  feature_dict, func_dict = self._get_feature_dicts(tap_dict)
  self._write_tap_dict(writer, tap_dict, func_dict)

  # BUG FIX: escape the '.' so the pattern only matches a literal
  # '.tfrecord' suffix, not any character followed by 'tfrecord'.
  feature_dict_fn = re.sub(r'_?[0-9]*\.tfrecord', '.pickle', file_name)
  d.save_to_file(feature_dict, feature_dict_fn)

  writer.close()
  return file_name
def write_tap_dicts(self, tap_dicts, file_name, skip_keys=None):
  """Writes already-poured tap dicts to a single tfrecord file.

  The feature description and serialization functions are derived from the
  first tap dict, so every element of tap_dicts is assumed to share the
  same structure. Falsy (e.g. empty) tap dicts are skipped. A pickle
  holding the feature description is written alongside the tfrecord.

  Parameters
  ----------
  tap_dicts : list of dict
    Outputs of the waterwork's pour function (string-keyed). Must be
    non-empty, since the feature description is taken from the first
    element.
  file_name : str
    The name of the tfrecord file to write to.
  skip_keys : list of str, optional
    Tap keys to leave out of the written examples.
  """
  feature_dict, func_dict = self._get_feature_dicts(tap_dicts[0])

  writer = tf.io.TFRecordWriter(file_name)
  # ROBUSTNESS: close the writer even if serialization of a tap dict
  # raises; the original leaked the writer on error.
  try:
    for tap_dict in tap_dicts:
      # Skip empty dicts rather than writing blank examples.
      if not tap_dict:
        continue
      self._write_tap_dict(writer, tap_dict, func_dict, skip_keys)

    # BUG FIX: escape the '.' so only a literal '.tfrecord' suffix matches.
    feature_dict_fn = re.sub(r'_?[0-9]*\.tfrecord', '.pickle', file_name)
    d.save_to_file(feature_dict, feature_dict_fn)
  finally:
    writer.close()
def test_read_from_file(self):
  """Round-trips an empty json list and a one-element int64 array through
  save_to_file/read_from_file."""
  json_path = os.path.join(self.temp_dir, 'temp.json')
  d.save_to_file([], json_path)
  outputs = d.read_from_file(file_name=json_path)

  npy_path = os.path.join(self.temp_dir, 'temp.npy')
  expected = np.zeros(dtype=np.int64, shape=[1])
  d.save_to_file(expected, npy_path)
  round_tripped = d.read_from_file(file_name=npy_path)

  th.assert_arrays_equal(self, round_tripped, expected)
def save_to_file(self, file_name):
  """Saves the waterwork's state dictionary to disk as a pickle.

  Parameters
  ----------
  file_name : str
    Destination path. Must end in 'pickle', 'pkl' or 'dill'.

  Raises
  ------
  ValueError
    If file_name does not have an allowed pickle extension.
  """
  # IDIOM: str.endswith accepts a tuple of suffixes — one call replaces
  # the original chain of `and not ... endswith` clauses.
  if not file_name.endswith(('pickle', 'pkl', 'dill')):
    raise ValueError("Waterwork can only be saved as a pickle.")
  save_dict = self._save_dict()
  d.save_to_file(save_dict, file_name)
def multi_write_examples(self, funnel_dict_iter, file_name, num_threads=1,
                         use_threading=False, batch_size=None,
                         file_num_offset=0, skip_fails=False, skip_keys=None,
                         serialize_func=None):
  """Pours batches of funnel dicts (possibly in parallel) and writes each
  batch to its own numbered tfrecord shard.

  Shards are named by replacing '.tfrecord' with '_<num>.tfrecord', where
  <num> starts at file_num_offset. A pickle holding the feature description
  (derived from the first funnel dict of the first batch) is written next
  to the shards.

  Parameters
  ----------
  funnel_dict_iter : iterator or list/tuple of dicts
    The inputs to the waterwork's pour function, one dict per example set.
  file_name : str
    Base tfrecord file name. Must end in '.tfrecord'.
  num_threads : int
    Number of workers used to serialize a batch.
  use_threading : bool
    Whether the workers are threads rather than processes.
  batch_size : int or None
    Number of funnel dicts per shard.
  file_num_offset : int
    Starting number for the shard suffix.
  skip_fails : bool
    If True, a batch whose serialization raises is logged and skipped
    instead of aborting the whole write.
  skip_keys : list of str, optional
    Tap keys to leave out of the written examples.
    # NOTE(review): the default serialize_func below does not consume
    # skip_keys — confirm whether custom serialize_funcs are expected to.
  serialize_func : callable, optional
    Maps one funnel dict to an iterable of serialized example strings.
    Defaults to pouring through a reconstructed copy of this waterwork.

  Returns
  -------
  list of str
    The shard file names that were written.

  Raises
  ------
  ValueError
    If file_name does not end in '.tfrecord'.
  """
  if serialize_func is None:
    save_dict = self._save_dict()

    def serialize_func(funnel_dict):
      # Workers may run on fresh threads; attach them to the JVM so any
      # jpype-backed tanks inside the waterwork can run.
      jpype.attachThreadToJVM()
      ww = Waterwork()
      ww._from_save_dict(save_dict)
      tap_dict = ww.pour(funnel_dict, 'str', False)
      feature_dict, func_dict = self._get_feature_dicts(tap_dict)
      serial = ww._serialize_tap_dict(tap_dict, func_dict)
      return serial

  # IDIOM: isinstance instead of `type(...) in (list, tuple)`.
  if isinstance(funnel_dict_iter, (list, tuple)):
    funnel_dict_iter = (i for i in funnel_dict_iter)

  file_names = []
  if not file_name.endswith('.tfrecord'):
    raise ValueError("file_name must end in '.tfrecord'")

  # Renamed from `dir` to avoid shadowing the builtin.
  dir_parts = file_name.split('/')[:-1]
  d.maybe_create_dir(*dir_parts)

  for batch_num, batch in enumerate(
      b.batcher(funnel_dict_iter, batch_size)):
    if batch_num == 0:
      # Derive the feature description once, from the first example.
      tap_dict = self.pour(batch[0], key_type='str', return_plugged=False)
      feature_dict, func_dict = self._get_feature_dicts(tap_dict)

      # BUG FIX: escape the '.' so only a literal '.tfrecord' matches.
      feature_dict_fn = re.sub(r'_?[0-9]*\.tfrecord', '.pickle', file_name)
      d.save_to_file(feature_dict, feature_dict_fn)

    logging.info("Serializing batch %s", batch_num)
    if skip_fails:
      try:
        all_serials = mu.multi_map(
            serialize_func, batch, num_threads, use_threading)
      except Exception:
        # BUG FIX: logging.warn is a deprecated alias of logging.warning;
        # also fixes the "Batched" typo in the message.
        logging.warning("Batch %s failed. Skipping.", batch_num)
        continue
    else:
      all_serials = mu.multi_map(
          serialize_func, batch, num_threads, use_threading)
    logging.info("Finished serializing batch %s", batch_num)

    file_num = file_num_offset + batch_num
    fn = file_name.replace('.tfrecord', '_' + str(file_num) + '.tfrecord')
    file_names.append(fn)

    logging.info("Writing batch %s", batch_num)
    writer = tf.io.TFRecordWriter(fn)
    for serials in all_serials:
      for serial in serials:
        writer.write(serial)
    logging.info("Finished writing batch %s", batch_num)
    writer.close()

  return file_names
def save_to_file(self, path):
  """Serializes this transform's state dictionary and writes it to disk.

  Parameters
  ----------
  path : str
    Where to write the saved transform.
  """
  d.save_to_file(self._save_dict(), path)