コード例 #1
0
 def predict_dataset(
         self, dataset) -> Generator[NetworkPredictionResult, None, None]:
     """Run the Keras model over ``dataset`` and yield one result per sample.

     The dataset is converted to model inputs with ``create_dataset_inputs``
     (using ``self.batch_size``, the network features, the backend shuffle
     buffer size and ``mode='test'``) and passed to ``self.model.predict``.

     Args:
         dataset: dataset forwarded unchanged to ``create_dataset_inputs``.

     Yields:
         NetworkPredictionResult per sample. It holds the softmax rolled by
         one position along axis 1, the sample's ``output_seq_len`` entry as
         returned by the model, the CTC-decoded sequence, the JSON-parsed
         params, and ``ground_truth=None``.
     """
     inputs = self.create_dataset_inputs(
         dataset,
         self.batch_size,
         self.network_proto.features,
         self.network_proto.backend.shuffle_buffer_size,
         mode='test')
     predictions = self.model.predict(inputs)
     for raw_softmax, serialized, seq_len in zip(*predictions):
         # Shift the class axis by one before decoding. Presumably this moves
         # the CTC blank to the index the decoder expects -- TODO confirm.
         rolled = np.roll(raw_softmax, 1, axis=1)
         # Params arrive JSON-encoded. On python<=3.5 they are bytes,
         # otherwise they are already a str.
         payload = serialized[0]
         if isinstance(payload, bytes):
             payload = payload.decode("utf-8")
         parsed_params = json.loads(payload)
         # Only the first seq_len[0] frames are valid output for this sample.
         labels = self.ctc_decoder.decode(rolled[:seq_len[0]])
         yield NetworkPredictionResult(
             softmax=rolled,
             output_length=seq_len,
             decoded=labels,
             params=parsed_params,
             ground_truth=None,
         )
コード例 #2
0
 def predict_raw_batch(self, x: np.array, len_x: np.array) -> Generator[NetworkPredictionResult, None, None]:
     """Predict a raw input batch with the TF1 session and CTC-decode it in Python.

     ``x`` is divided by 255 before being fed (presumably 8-bit pixel values
     scaled to [0, 1] -- TODO confirm), and dropout is disabled (rate 0).

     Args:
         x: batch of input samples.
         len_x: per-sample input sequence lengths.

     Yields:
         NetworkPredictionResult per sample. It holds the softmax rolled by
         one position along axis 1, the sample's output length, and the
         sequence produced by ``self.ctc_decoder`` on the first
         ``output_length`` frames of the rolled softmax.
     """
     # The graph's own ``self.decoded`` output is deliberately not fetched.
     # It used to be fetched and converted to lists, then replaced by the
     # Python decoder's result below, which was wasted work.
     softmax_batch, seq_len_batch = self.session.run(
         [self.softmax, self.output_seq_len],
         feed_dict={
             self.inputs: x / 255,
             self.input_seq_len: len_x,
             self.dropout_rate: 0,
         })
     for softmax, seq_len in zip(softmax_batch, seq_len_batch):
         # Shift the class axis by one before decoding. Presumably this moves
         # the CTC blank to the index the decoder expects -- TODO confirm.
         rolled = np.roll(softmax, 1, axis=1)
         yield NetworkPredictionResult(softmax=rolled,
                                       output_length=seq_len,
                                       decoded=self.ctc_decoder.decode(rolled[:seq_len]),
                                       )
コード例 #3
0
 def predict_raw(self, x, len_x) -> Generator[NetworkPredictionResult, None, None]:
     """Predict raw inputs with the TF1 session, using the graph's own decoding.

     Unlike the batch variant, ``x`` is fed unscaled and the softmax is
     returned as-is (no roll). Dropout is disabled (rate 0).

     Args:
         x: batch of input samples.
         len_x: per-sample input sequence lengths.

     Yields:
         NetworkPredictionResult per sample with the graph's softmax, output
         length, and the decoded labels from ``self.decoded`` converted to a
         list via ``__sparse_to_lists``.
     """
     fetches = [self.softmax, self.output_seq_len, self.decoded]
     feed = {
         self.inputs: x,
         self.input_seq_len: len_x,
         self.dropout_rate: 0,
     }
     softmax_batch, seq_len_batch, sparse_decoded = self.session.run(fetches, feed_dict=feed)
     decoded_lists = TensorflowModel.__sparse_to_lists(sparse_decoded)
     for softmax, length, labels in zip(softmax_batch, seq_len_batch, decoded_lists):
         yield NetworkPredictionResult(softmax=softmax,
                                       output_length=length,
                                       decoded=labels,
                                       )
コード例 #4
0
 def predict_raw_batch(self, x: np.array, len_x: np.array) -> Generator[NetworkPredictionResult, None, None]:
     """Predict a raw input batch with the Keras model and CTC-decode it in Python.

     The model receives three inputs: ``x / 255.0`` as float32, ``len_x`` as
     int32, and a zero-valued string tensor of shape ``(len(x), 1)``.

     Args:
         x: batch of input samples.
         len_x: per-sample input sequence lengths.

     Yields:
         NetworkPredictionResult per sample. It holds the softmax rolled by
         one position along axis 1, the scalar output length (first entry of
         the per-sample length output), and the sequence produced by
         ``self.ctc_decoder``.
     """
     model_inputs = [
         tf.convert_to_tensor(x / 255.0, dtype=tf.float32),
         tf.convert_to_tensor(len_x, dtype=tf.int32),
         # Placeholder for the third model input. Presumably the serialized
         # params, which are unused for raw prediction -- TODO confirm.
         tf.zeros((len(x), 1), dtype=tf.string),
     ]
     batch_output = self.model.predict_on_batch(model_inputs)
     for raw_softmax, _unused_params, length_entry in zip(*batch_output):
         length = length_entry[0]
         rolled = np.roll(raw_softmax, 1, axis=1)
         # Only the first ``length`` frames are valid output for this sample.
         yield NetworkPredictionResult(softmax=rolled,
                                       output_length=length,
                                       decoded=self.ctc_decoder.decode(rolled[:length]),
                                       )
コード例 #5
0
 def predict_dataset(self) -> Generator[NetworkPredictionResult, None, None]:
     """Run the TF1 session over the graph's dataset input and yield per-sample results.

     Only ``self.dropout_rate`` (set to 0) is fed; all other inputs come from
     the graph itself. Decoded labels and targets are fetched as sparse values
     and converted to lists with ``__sparse_to_lists``.

     Yields:
         NetworkPredictionResult per sample with the softmax, output length,
         decoded labels, JSON-parsed params, and the ground truth decoded by
         ``self.codec`` (``None`` if the target entry is ``None``).
     """
     softmax_batch, length_batch, param_batch, sparse_decoded, sparse_targets = self.session.run(
         [self.softmax, self.output_seq_len, self.serialized_params, self.decoded, self.targets],
         feed_dict={
             self.dropout_rate: 0,
         })
     decoded_lists = TensorflowModel.__sparse_to_lists(sparse_decoded)
     target_lists = TensorflowModel.__sparse_to_lists(sparse_targets)
     rows = zip(softmax_batch, length_batch, param_batch, decoded_lists, target_lists)
     for softmax, length, serialized, labels, target in rows:
         # Params arrive JSON-encoded. On python<=3.5 they are bytes,
         # otherwise they are already a str.
         raw = serialized[0]
         if isinstance(raw, bytes):
             raw = raw.decode("utf-8")
         parsed_params = json.loads(raw)
         if target is None:
             ground_truth = None
         else:
             ground_truth = self.codec.decode(target)
         yield NetworkPredictionResult(softmax=softmax,
                                       output_length=length,
                                       decoded=labels,
                                       params=parsed_params,
                                       ground_truth=ground_truth,
                                       )