def run(self, process_index):
    """Drain ``self.results`` and write rows to ``self.output_path``.

    Items are ``(raw_input_tuple, model_outputs)`` pairs read with
    ``self.results.get()``; a ``None`` ``model_outputs`` stops the loop.
    Each row handed to ``self._write`` zips the columns of
    ``raw_input_tuple`` with the model output column(s), so the i-th row
    pairs the i-th element of every input column with the i-th output.
    When ``self.ndigits_precision`` is truthy, outputs are passed through
    ``round_seq`` first. When ``self.use_gzip`` is set, rows are written
    through a ``gzip.GzipFile`` wrapped around the opened file.

    Args:
        process_index: Not referenced by this method; kept for the
            caller's interface.

    Raises:
        TypeError: if ``model_outputs`` is not a tuple and is not a list
            (after optional rounding).
    """
    open_options = self._get_open_options()
    with PathManager.open(self.output_path, **open_options) as fout:
        gzip_fout = (
            gzip.GzipFile(mode="wb", fileobj=fout) if self.use_gzip else None
        )
        try:
            while True:
                raw_input_tuple, model_outputs = self.results.get()
                if model_outputs is None:
                    # None means shutdown
                    break
                if isinstance(model_outputs, tuple):
                    # multi-encoder output: each element becomes its own
                    # column. The tuple check happens BEFORE rounding
                    # because the container type round_seq returns is not
                    # visible here; a rounded tuple that came back as a
                    # list would otherwise fall into the single-encoder
                    # branch and be zipped incorrectly.
                    if self.ndigits_precision:
                        model_outputs = round_seq(
                            model_outputs, self.ndigits_precision
                        )
                    self._write(
                        fout, gzip_fout, zip(*raw_input_tuple, *model_outputs)
                    )
                    continue

                # single encoder output (rounding first, as before)
                if self.ndigits_precision:
                    model_outputs = round_seq(model_outputs, self.ndigits_precision)
                if not isinstance(model_outputs, list):
                    raise TypeError(
                        "Expecting tuple or tensor types for model_outputs"
                    )
                self._write(fout, gzip_fout, zip(*raw_input_tuple, model_outputs))
        finally:
            # Finalize the gzip stream before ``fout`` is closed by the
            # ``with``, even when an exception escapes the loop.
            if gzip_fout is not None:
                gzip_fout.close()
def run(self, process_index):
    """Drain ``self.results`` and write rows to ``self.output_path``.

    Items are ``(raw_input_tuple, model_outputs)`` pairs read with
    ``self.results.get()``; a ``None`` ``model_outputs`` stops the loop.
    Each row handed to ``self._write`` zips the columns of
    ``raw_input_tuple`` with the model output column(s).

    ``model_outputs`` may be:

    * a tuple (multi-encoder output): each element is converted with
      ``tolist()`` when it has one, optionally restricted to the indices
      listed in ``self.output_columns``, and written as its own column;
    * a single list, or an object exposing ``tolist()`` such as a tensor
      (single-encoder output): converted to a list and written as one
      column.

    When ``self.ndigits_precision`` is truthy, each output column is
    passed through ``round_seq``. When ``self.use_gzip`` is set, rows are
    written through a ``gzip.GzipFile`` wrapped around the opened file.

    Args:
        process_index: Not referenced by this method; kept for the
            caller's interface.

    Raises:
        TypeError: if ``model_outputs`` is not a tuple, not a list, and
            has no ``tolist()`` method.
    """

    def _as_list(values):
        # Tensors/arrays expose tolist(); plain lists are used unchanged.
        return values.tolist() if hasattr(values, "tolist") else values

    open_options = self._get_open_options()
    with PathManager.open(self.output_path, **open_options) as fout:
        gzip_fout = (
            gzip.GzipFile(mode="wb", fileobj=fout) if self.use_gzip else None
        )
        try:
            while True:
                raw_input_tuple, model_outputs = self.results.get()
                if model_outputs is None:
                    # None means shutdown
                    break
                if isinstance(model_outputs, tuple):
                    # multi-encoder output; a falsy output_columns keeps
                    # every encoder's output.
                    columns = [
                        _as_list(m)
                        for i, m in enumerate(model_outputs)
                        if not self.output_columns or i in self.output_columns
                    ]
                    if self.ndigits_precision:
                        columns = [
                            round_seq(c, self.ndigits_precision) for c in columns
                        ]
                    self._write(fout, gzip_fout, zip(*raw_input_tuple, *columns))
                elif isinstance(model_outputs, list) or hasattr(
                    model_outputs, "tolist"
                ):
                    # single encoder output. Python lists have no tolist(),
                    # so only call it on objects that provide one (e.g.
                    # tensors, as the error message below expects).
                    column = _as_list(model_outputs)
                    if self.ndigits_precision:
                        column = round_seq(column, self.ndigits_precision)
                    self._write(fout, gzip_fout, zip(*raw_input_tuple, column))
                else:
                    raise TypeError(
                        "Expecting tuple or torchTensor types for model_outputs"
                    )
        finally:
            # Finalize the gzip stream before ``fout`` is closed by the
            # ``with``, even when an exception escapes the loop.
            if gzip_fout is not None:
                gzip_fout.close()
def test_round_seq(self):
    """round_seq(..., 1) maps every tiny value in a nested list to 0.0."""
    arr = [[[0.0001], [0.0002], [0.0003]]]
    arr_rounded = round_seq(arr, 1)
    # The nested layout must survive rounding.
    self.assertEqual(len(arr_rounded), 1)
    self.assertEqual(len(arr_rounded[0]), len(arr[0]))
    for row in arr_rounded[0]:
        # Compare the string form: the value must print as "0.0", i.e. it
        # was actually rounded rather than left as e.g. 0.0003.
        self.assertEqual(str(row[0]), "0.0")