Example #1
0
    def _evaluate_batch(self, fetch_list, data_batch, **kwargs):
        """Run the session once on a single data batch.

        :param fetch_list: list of fetchable tf objects to evaluate
        :param data_batch: a DataSet instance holding the batch data
        :return: list of values produced by session.run
        """
        # Validate inputs before touching the session
        assert isinstance(fetch_list, list)
        checker.check_fetchable(fetch_list)
        assert isinstance(data_batch, DataSet)

        # Build the feed dict in evaluation mode and execute the graph
        feeds = self._get_default_feed_dict(data_batch, is_training=False)
        return self.session.run(fetch_list, feeds)
Example #2
0
    def _evaluate_batch(self,
                        fetch_list,
                        data_batch,
                        suppress_n_to_one=False,
                        **kwargs):
        """Evaluate `fetch_list` on one RNN data batch.

        Propagates RNN state between partitioned sub-sequences (when
        num_steps != -1), then trims each per-sequence output to its active
        length and, for n-to-one data, keeps only the last step.

        :param fetch_list: list of fetchable tf objects
        :param data_batch: a DataSet carrying RNN-formatted input
        :param suppress_n_to_one: if True, keep whole sequences even when
                                  data_batch.n_to_one is set
        :param kwargs: reads 'num_steps'; any value other than -1 means
                       sequences are partitioned into sub-sequences
        :return: list of cleaned-up outputs, one entry per non-Operation fetch
        """
        # Sanity check
        assert isinstance(fetch_list, list)
        checker.check_fetchable(fetch_list)
        assert isinstance(data_batch, DataSet)

        # Check val_num_steps
        # NOTE(review): kwargs.get('num_steps') is None when the key is
        # absent, which also yields partition=True — callers are expected to
        # always pass num_steps explicitly; confirm against call sites
        partition = kwargs.get('num_steps') != -1
        # Fetch states if partition
        if partition:
            # fetch_list is mutable, do not append!
            fetch_list = fetch_list + [self._state_slot.op]

        # Run session
        assert data_batch.is_rnn_input
        feed_dict = self._get_default_feed_dict(data_batch, is_training=False)
        batch_outputs = self.session.run(fetch_list, feed_dict)
        assert isinstance(batch_outputs, list)

        # Set buffer if necessary. The state op was appended last, so pop it
        # off before pairing the remaining outputs with the original fetches
        if partition:
            self.set_buffers(batch_outputs.pop(-1), is_training=False)

        # Clear up outputs
        outputs, al = [], data_batch.active_length
        if al is None: al = [None] * data_batch.size
        for output, op in zip(batch_outputs, fetch_list):
            # tf.Operation yields None
            if isinstance(op, tf.Operation): continue
            assert isinstance(op, (tf.Tensor, tf.Variable))
            # If output has different shape compared to fetch tensor, e.g.
            # .. gradients or variables, return directly
            if op.shape.as_list()[0] is not None:
                outputs.append(output)
                continue
            # Batch-major fetch: trim each sequence to its active length
            assert output.shape[0] == data_batch.size
            output = [
                y[:l] if l is not None else y for y, l in zip(output, al)
            ]
            if data_batch.n_to_one and not suppress_n_to_one:
                output = [s[-1] for s in output]
            outputs.append(output)

        return outputs
Example #3
0
    def evaluate(self,
                 fetches,
                 data,
                 batch_size=None,
                 postprocessor=None,
                 verbose=False,
                 num_steps=None,
                 suppress_n_to_one=False):
        """Evaluate tensors based on data.

        TODO: note that if num_steps != -1, outputs from a same sequence may
              be partitioned. e.g., if single_fetch, outputs will be
              [array_1_1, ..., array_1_k1, array_2_1, ..., array_2_k2, ...]
              |-------- input_1 ---------|------------ input_2 ----------|
              it's OK for seq2seq validation, but need to be post-proceeded
              in tasks like sequence classification (currently forbidden)

        :param fetches: a (tuple/list of) fetchable object(s) to be evaluated
        :param data: data used for evaluation
        :param batch_size: if not specified (None by default), batch_size will
                           be assigned accordingly. If assigned with a
                           positive integer, evaluation will be performed
                           batch by batch.
        :param postprocessor: post-processor for outputs
        :param verbose: if True, show a progress bar while evaluating
        :param num_steps: RNN unroll length, defaults to hub.val_num_steps
        :param suppress_n_to_one: passed through to _evaluate_batch
        :return: commonly a (list of) value(s), each of which has the
                 same batch size with the provided data
        """
        # Sanity check for fetches
        checker.check_fetchable(fetches)
        single_fetch = not isinstance(fetches, (tuple, list))
        # Wrap fetches into a list if necessary
        if single_fetch: fetches = [fetches]
        if num_steps is None: num_steps = hub.val_num_steps
        if batch_size is None: batch_size = data.size

        # tf.Operation fetches yield None and are dropped from the batch
        # results, so align the output slots (and all later shape checks)
        # with the *tensor* fetches only
        tensor_fetches = [op for op in fetches
                          if not isinstance(op, tf.Operation)]
        outputs = [[] for _ in tensor_fetches]

        if verbose:
            bar = ProgressBar(data.get_round_length(batch_size, num_steps))
            console.show_status('Evaluating on {} ...'.format(data.name))

        for cursor, data_batch in enumerate(
                self.get_data_batches(data, batch_size, num_steps)):
            data_batch = self._sanity_check_before_use(data_batch)
            # Get batch outputs          fetches[0]  fetches[1]
            #  for FNN, batch_outputs = [np_array_1, np_array_2, ...]
            #           each np_array_k have a same batch_size
            #  for RNN, batch_outputs = [[s1_1, s1_2, ..., s1_N],       <= fetches[0]
            #                            [s2_1, s2_2, ..., s2_N], ...]  <= fetches[1]
            #           N is the batch_size, and each sk_i is a numpy array
            batch_outputs = self._evaluate_batch(
                fetches,
                data_batch,
                num_steps=num_steps,
                suppress_n_to_one=suppress_n_to_one)
            assert isinstance(batch_outputs, list)
            assert len(batch_outputs) == len(outputs)

            # Add batch_outputs to outputs accordingly
            for i, batch_output in enumerate(batch_outputs):
                assert isinstance(outputs[i], list)
                # BUGFIX: index into tensor_fetches (not fetches) so the
                # shape lookup stays aligned with batch_outputs when fetches
                # contain tf.Operations (which produce no output entry)
                output_is_a_batch = (
                    tensor_fetches[i].shape.as_list()[0] is None)
                if self.input_type is InputTypes.RNN_BATCH and output_is_a_batch:
                    # batch_output is [s1_1, s1_2, ..., s1_N]
                    assert isinstance(batch_output, list)
                    outputs[i] = outputs[i] + batch_output
                else:
                    # batch_output is a numpy array of length batch_size
                    outputs[i].append(batch_output)

            # Show progress bar if necessary
            if verbose: bar.show(cursor + 1)

        # Merge outputs if necessary
        if self.input_type is InputTypes.BATCH:
            outputs = [
                np.concatenate(array_list, axis=0) for array_list in outputs
            ]

        # Post-proceed and return
        if postprocessor is not None:
            assert callable(postprocessor)
            outputs = postprocessor(outputs)

        assert isinstance(outputs, list)
        if single_fetch: outputs = outputs[0]
        return outputs