示例#1
0
    def run(self, process_index):
        """Drain ``self.results`` and write each batch to ``self.output_path``.

        Repeatedly reads ``(raw_input_tuple, model_outputs)`` pairs from
        ``self.results.get()`` until a pair whose ``model_outputs`` is
        ``None`` arrives, which signals shutdown. When
        ``self.ndigits_precision`` is truthy, outputs are first rounded with
        ``round_seq``. Each batch is then zipped row-wise with the raw inputs
        and handed to ``self._write`` together with the plain file handle and
        (when ``self.use_gzip``) a gzip-wrapped handle.

        Args:
            process_index: Not used by this method body.

        Raises:
            Exception: If ``model_outputs`` is neither a tuple (multi-encoder
                output) nor a list (single encoder output).
        """
        open_options = self._get_open_options()
        with PathManager.open(self.output_path, **open_options) as fout:
            gzip_fout = (gzip.GzipFile(mode="wb", fileobj=fout)
                         if self.use_gzip else None)
            try:
                while True:
                    raw_input_tuple, model_outputs = self.results.get()
                    if model_outputs is None:
                        # None means shutdown
                        break

                    # A falsy precision (e.g. None or 0) disables rounding.
                    if self.ndigits_precision:
                        # NOTE(review): the tuple check below relies on
                        # round_seq returning a tuple for tuple input; if it
                        # returns a list, multi-encoder output would be
                        # treated as single-encoder output — confirm.
                        model_outputs = round_seq(model_outputs,
                                                  self.ndigits_precision)

                    # multi-encoder output: each encoder becomes a column
                    if isinstance(model_outputs, tuple):
                        self._write(fout, gzip_fout,
                                    zip(*raw_input_tuple, *model_outputs))
                    # single encoder output
                    elif isinstance(model_outputs, list):
                        self._write(fout, gzip_fout,
                                    zip(*raw_input_tuple, model_outputs))
                    else:
                        # NOTE(review): message mentions tensors, but only
                        # lists are accepted by the branch above.
                        raise Exception(
                            "Expecting tuple or tensor types for model_outputs")
            finally:
                # Close the gzip stream (writing its trailer) while ``fout``
                # is still open, even when the loop above raised.
                if gzip_fout is not None:
                    gzip_fout.close()
示例#2
0
    def run(self, process_index):
        """Drain ``self.results`` and write each batch to ``self.output_path``.

        Repeatedly reads ``(raw_input_tuple, model_outputs)`` pairs from
        ``self.results.get()`` until a pair whose ``model_outputs`` is
        ``None`` arrives, which signals shutdown.

        * Tuple ``model_outputs`` (multi-encoder output) are written one
          column per encoder. When ``self.output_columns`` is truthy, only
          encoders whose index is in it are kept.
        * Any other ``model_outputs`` that is a list or exposes ``tolist()``
          (presumably a torch tensor) is written as a single column.

        Outputs are converted to plain lists and, when
        ``self.ndigits_precision`` is truthy, rounded with ``round_seq``. They
        are then zipped row-wise with the raw inputs and handed to
        ``self._write`` along with the plain file handle and (when
        ``self.use_gzip``) a gzip-wrapped handle.

        Args:
            process_index: Not used by this method body.

        Raises:
            Exception: If ``model_outputs`` is neither a tuple nor a
                list / object exposing ``tolist()``.
        """

        def as_list(values):
            # Tensors/arrays expose ``tolist()``; plain lists do not, so they
            # are passed through unchanged instead of raising AttributeError.
            return values if isinstance(values, list) else values.tolist()

        def maybe_round(values):
            # A falsy precision (e.g. None or 0) disables rounding.
            if self.ndigits_precision:
                return round_seq(values, self.ndigits_precision)
            return values

        open_options = self._get_open_options()
        with PathManager.open(self.output_path, **open_options) as fout:
            gzip_fout = (gzip.GzipFile(mode="wb", fileobj=fout)
                         if self.use_gzip else None)
            try:
                while True:
                    raw_input_tuple, model_outputs = self.results.get()
                    if model_outputs is None:
                        # None means shutdown
                        break

                    # multi-encoder output
                    if isinstance(model_outputs, tuple):
                        model_outputs_tuple = [
                            maybe_round(as_list(m))
                            for i, m in enumerate(model_outputs)
                            if not self.output_columns
                            or i in self.output_columns
                        ]
                        self._write(fout, gzip_fout,
                                    zip(*raw_input_tuple, *model_outputs_tuple))
                    # single encoder output (list or tensor-like)
                    elif (isinstance(model_outputs, list)
                          or hasattr(model_outputs, "tolist")):
                        model_outputs_list = maybe_round(as_list(model_outputs))
                        self._write(fout, gzip_fout,
                                    zip(*raw_input_tuple, model_outputs_list))
                    else:
                        raise Exception(
                            "Expecting tuple or torchTensor types for model_outputs"
                        )
            finally:
                # Close the gzip stream (writing its trailer) while ``fout``
                # is still open, even when the loop above raised.
                if gzip_fout is not None:
                    gzip_fout.close()
示例#3
0
    def test_round_seq(self):
        """Check round_seq on a value nested three lists deep at 1 digit."""
        nested = [[[0.0001], [0.0002], [0.0003]]]
        result = round_seq(nested, 1)

        # Compare the string form so the expected value is exactly "0.0".
        innermost = result[0][0][0]
        self.assertEqual(str(innermost), "0.0")