Example #1
    def infer_step(self, inputs):
        """
        Run one inference step.
        """
        if self.do_generation:
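            # When multiple samples per example are requested, repeat every input
            # num_samples times along the batch dimension.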
            if self.generator.num_samples:
                inputs = {
                    name: repeat_array_or_tensor(array_or_tensor, self.place,
                                                 self.generator.num_samples)
                    for name, array_or_tensor in inputs.items()
                }

            if self.mem_efficient:
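                # Memory-efficient mode: run generation on batch_size-sized slices
                # of the (possibly repeated) inputs and collect the outputs.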
                predictions = []
                for idx in range(0, len(inputs["data_id"]), self.batch_size):
                    part_inputs = {
                        name:
                        slice_array_or_tensor(array_or_tensor, self.place, idx,
                                              idx + self.batch_size)
                        for name, array_or_tensor in inputs.items()
                    }
                    part_outputs = self._run_generation(part_inputs)
                    predictions.extend(part_outputs)
            else:
                predictions = self._run_generation(inputs)
            return predictions
        else:
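            # Non-generation mode: execute the inference program directly on the prepared feed.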
            return self._execute(self.infer_program,
                                 self._get_feed(inputs, is_infer=True),
                                 self.infer_fetch_dict)
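
Example #1 calls two helpers, repeat_array_or_tensor and slice_array_or_tensor, that are defined elsewhere in the project. Below is a minimal, numpy-only sketch of what they might look like; this is an assumption for illustration, since the real helpers presumably also use the `place` argument to handle framework tensors.

import numpy as np

# Hypothetical numpy-only stand-ins for the helpers used above.
def repeat_array_or_tensor(array, place, times):
    # Repeat the whole batch `times` times along the first (batch) dimension.
    return np.concatenate([array] * times, axis=0)

def slice_array_or_tensor(array, place, begin, end):
    # Keep rows [begin, end) of the batch.
    return array[begin:end]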
Example #2
    def infer_step(self, inputs):
        """
        Run one inference step.
        """
        if self.do_generation:
            batch_size = len(inputs["data_id"])
            new_bsz = batch_size * self.latent_type_size
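            # Repeat every input latent_type_size times so each example is decoded
            # once per latent value.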
            inputs = {
                name: repeat_array_or_tensor(array_or_tensor, self.place, self.latent_type_size)
                for name, array_or_tensor in inputs.items()
            }
            # Add latent_id: one id per repeated example, laid out as
            # [0]*batch_size, [1]*batch_size, ..., [latent_type_size-1]*batch_size.
            inputs["latent_id"] = np.array(
                [i for i in range(self.latent_type_size) for _ in range(batch_size)],
                dtype="int64"
            ).reshape([-1, 1])

            return super(Plato, self).infer_step(inputs)
        else:
            return self._execute(
                self.infer_program,
                self._get_feed(inputs, is_infer=True),
                self.infer_fetch_dict)
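
For concreteness, the latent_id built above lists each latent value batch_size times in a row, one block per latent value. A quick illustration with assumed toy sizes (np here is numpy, imported at module level in the original source):

import numpy as np

batch_size, latent_type_size = 2, 3  # illustrative values only
latent_id = np.array(
    [i for i in range(latent_type_size) for _ in range(batch_size)],
    dtype="int64"
).reshape([-1, 1])
print(latent_id.ravel().tolist())  # [0, 0, 1, 1, 2, 2]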