Code Example #1
File: classifier.py  Project: zkmartin/tframe
    def evaluate_model(self, data, with_false=False):
        # Sanity checks: the model must be built, a session must be live,
        # and the bound metric must be accuracy
        if self.outputs is None:
            raise ValueError('Model not built yet')
        if self._session is None:
            self.launch_model(overwrite=False)
        if not self.metric_is_accuracy:
            raise ValueError('Currently this only supports accuracy')

        # Run one forward pass to fetch class probabilities and the metric
        probabilities, accuracy = self._session.run(
            [self._probabilities, self._metric],
            feed_dict=self._get_default_feed_dict(data, is_training=False))
        accuracy *= 100

        console.show_status('Accuracy on test set is {:.2f}%'.format(accuracy))

        # Optionally collect and display the misclassified samples
        if with_false:
            assert isinstance(data, TFData)
            predictions = np.argmax(probabilities, axis=1).squeeze()
            data.update(predictions=predictions)
            labels = data.scalar_labels
            false_indices = [
                i for i in range(data.sample_num)
                if predictions[i] != labels[i]
            ]

            from tframe import ImageViewer
            vr = ImageViewer(data[false_indices])
            vr.show()
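
A minimal usage sketch for the variant above. The model constructor and the
data path are hypothetical placeholders; only `evaluate_model`, `TFData.load`
(as used in Example #6) and the `with_false` flag come from the snippets in
this section.

# Hypothetical usage sketch: `build_classifier` and the path are placeholders
model = build_classifier()                       # assumed to return a Classifier
test_set = TFData.load('path/to/test.tfd')       # TFData.load as in Example #6
model.evaluate_model(test_set, with_false=True)  # opens an ImageViewer on errors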
Code Example #2
    def evaluate_model(self,
                       data,
                       batch_size=None,
                       extractor=None,
                       export_false=False,
                       **kwargs):
        # Feed data set into model and get results
        false_sample_list = []
        false_label_list = []
        true_label_list = []
        num_samples = 0

        console.show_status('Evaluating classifier ...')
        for batch in self.get_data_batches(data, batch_size):
            assert isinstance(batch, DataSet) and batch.targets is not None
            # Get predictions
            preds = self._classify_batch(batch, extractor)
            # Get true labels in dense format
            if batch.targets.shape[-1] > 1:
                targets = batch.targets.reshape(-1, batch.targets.shape[-1])
            else:
                targets = batch.targets
            num_samples += len(targets)
            true_labels = misc.convert_to_dense_labels(targets)
            if len(true_labels) < len(preds):
                assert len(true_labels) == 1
                true_labels = np.concatenate((true_labels, ) * len(preds))
            # Select false samples
            false_indices = np.argwhere(preds != true_labels)
            if false_indices.size == 0: continue
            features = batch.features
            if self.input_type is InputTypes.RNN_BATCH:
                features = np.reshape(features, [-1, *features.shape[2:]])
            false_indices = np.reshape(false_indices, false_indices.size)
            false_sample_list.append(features[false_indices])
            false_label_list.append(preds[false_indices])
            true_label_list.append(true_labels[false_indices])

        # Concatenate
        if len(false_sample_list) > 0:
            false_sample_list = np.concatenate(false_sample_list)
            false_label_list = np.concatenate(false_label_list)
            true_label_list = np.concatenate(true_label_list)

        # Show accuracy
        accuracy = (num_samples - len(false_sample_list)) / num_samples * 100
        console.supplement('Accuracy on {} is {:.2f}%'.format(
            data.name, accuracy))

        # Try to export false samples
        if export_false and accuracy < 100:
            false_set = DataSet(features=false_sample_list,
                                targets=true_label_list)
            if hasattr(data, 'properties'):
                false_set.properties = data.properties
            false_set.data_dict[pedia.predictions] = false_label_list
            from tframe.data.images.image_viewer import ImageViewer
            vr = ImageViewer(false_set)
            vr.show()
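
The helper `misc.convert_to_dense_labels` is used above but not shown. A
plausible minimal equivalent, assuming targets arrive either as one-hot
matrices or as already-dense label vectors, would be:

import numpy as np

def convert_to_dense_labels(targets):
    # Assumed behavior: collapse one-hot rows to integer class indices,
    # pass already-dense label vectors through unchanged
    targets = np.asarray(targets)
    if targets.ndim > 1 and targets.shape[-1] > 1:
        return np.argmax(targets, axis=-1)
    return targets.flatten().astype(int)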
Code Example #3
File: classifier.py  Project: rscv5/tframe
  def evaluate_model(self, data, batch_size=None, extractor=None,
                     export_false=False, **kwargs):
    # If not necessary, use Predictor's evaluate_model method
    metric_is_accuracy = self.eval_metric.name.lower() == 'accuracy'
    if not export_false or not metric_is_accuracy:
      result = super().evaluate_model(data, batch_size, **kwargs)
      if metric_is_accuracy: result *= 100
      return result

    console.show_status('Evaluating classifier on {} ...'.format(data.name))

    acc_slot = self.metrics_manager.get_slot_by_name('accuracy')
    assert isinstance(acc_slot, MetricSlot)
    acc_foreach = acc_slot.quantity_definition.quantities
    results = self.evaluate(acc_foreach, data, batch_size, extractor,
                            verbose=hub.val_progress_bar)
    if self.input_type is InputTypes.RNN_BATCH:
      results = np.concatenate([y.flatten() for y in results])
    accuracy = np.mean(results) * 100

    # Show accuracy
    console.supplement('Accuracy on {} is {:.3f}%'.format(data.name, accuracy))

    # export_false option is valid for images only
    if export_false and accuracy < 100.0:
      assert self.input_type is InputTypes.BATCH
      assert isinstance(data, DataSet)
      assert data.features is not None and data.targets is not None
      top_k = hub.export_top_k if hub.export_top_k > 0 else 3

      probs = self.classify(data, batch_size, extractor, return_probs=True)
      probs_sorted = np.fliplr(np.sort(probs, axis=-1))
      class_sorted = np.fliplr(np.argsort(probs, axis=-1))
      preds = class_sorted[:, 0]

      false_indices = np.argwhere(results == 0).flatten()
      false_preds = preds[false_indices]

      probs_sorted = probs_sorted[false_indices, :top_k]
      class_sorted = class_sorted[false_indices, :top_k]
      false_set = data[false_indices]

      false_set.properties[pedia.predictions] = false_preds
      false_set.properties[pedia.top_k_label] = class_sorted
      false_set.properties[pedia.top_k_prob] = probs_sorted

      from tframe.data.images.image_viewer import ImageViewer
      vr = ImageViewer(false_set)
      vr.show()

    # Return accuracy
    return accuracy
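
The top-k export above relies on a small NumPy idiom: sorting each row in
ascending order and flipping left-to-right yields descending order. A
self-contained illustration:

import numpy as np

probs = np.array([[0.1, 0.7, 0.2],
                  [0.5, 0.2, 0.3]])
probs_sorted = np.fliplr(np.sort(probs, axis=-1))     # [[0.7 0.2 0.1], [0.5 0.3 0.2]]
class_sorted = np.fliplr(np.argsort(probs, axis=-1))  # [[1 2 0], [0 2 1]]
preds = class_sorted[:, 0]                            # top-1 predictions: [1 0]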
Code Example #4
    def evaluate_model(self,
                       data,
                       batch_size=None,
                       extractor=None,
                       export_false=False,
                       **kwargs):
        console.show_status('Evaluating classifier ...')
        assert isinstance(data, DataSet)

        preds = self.classify(data,
                              batch_size=batch_size,
                              extractor=GPAT.raw_extractor)
        preds = np.argmax(preds, axis=-1)
        targets = data.targets
        labels = convert_to_dense_labels(targets)
        false_indices = [
            ind for ind in range(len(preds)) if preds[ind] != labels[ind]
        ]
        false_index_set = set(false_indices)
        correct_indices = [
            ind for ind in range(len(preds)) if ind not in false_index_set
        ]
        assert len(false_indices) + len(correct_indices) == len(preds)
        false_labels = labels[false_indices]
        # Tally how many times each true label was misclassified
        counter = Counter(false_labels)
        label_counts = counter.most_common()

        false_samples = data[false_indices]
        short_samples = [
            arr for arr in false_samples.features if arr.size < 32000
        ]

        false_samples_lengths = [data.lengths[i] for i in false_indices]
        less_audio_length = [
            length for length in false_samples_lengths if length < 32000
        ]

        console.show_status('Total samples:')
        console.pprint(len(labels))
        console.show_status('Misclassified samples:')
        console.pprint(len(false_labels))
        console.show_status('False count per label:')
        console.pprint(label_counts)
        console.show_status('Short samples (length < 32000):')
        console.pprint(len(less_audio_length))

        return correct_indices, false_indices
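
Since `Counter.most_common()` with no argument already returns every entry
sorted by count, the per-label tally above can be reproduced in isolation:

from collections import Counter
import numpy as np

false_labels = np.array([3, 1, 3, 7, 3, 1])
print(Counter(false_labels).most_common())  # [(3, 3), (1, 2), (7, 1)]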
Code Example #5
File: classifier.py  Project: ssh352/tframe
    def evaluate_model(self, data, batch_size=None, extractor=None, **kwargs):
        """This method is a mess."""
        if hub.take_down_confusion_matrix:
            # TODO: (william) please refactor this method
            cm = self.evaluate_pro(data,
                                   batch_size,
                                   verbose=kwargs.get('verbose', False),
                                   show_class_detail=True,
                                   show_confusion_matrix=True)
            # Take down confusion matrix
            from tframe import context
            agent = context.trainer.model.agent
            agent.take_notes('Confusion Matrix on {}:'.format(data.name),
                             False)
            agent.take_notes('\n' + cm.matrix_table().content)
            agent.take_notes('Evaluation Result on {}:'.format(data.name),
                             False)
            agent.take_notes('\n' + cm.make_table().content)
            return cm.accuracy

        # If the metric is not accuracy, fall back to Predictor's
        # evaluate_model method
        metric_is_accuracy = self.eval_metric.name.lower() == 'accuracy'
        if not metric_is_accuracy:
            return super().evaluate_model(data, batch_size, **kwargs)

        console.show_status('Evaluating classifier on {} ...'.format(
            data.name))

        acc_slot = self.metrics_manager.get_slot_by_name('accuracy')
        assert isinstance(acc_slot, MetricSlot)
        acc_foreach = acc_slot.quantity_definition.quantities
        results = self.evaluate(acc_foreach,
                                data,
                                batch_size,
                                extractor,
                                verbose=hub.val_progress_bar)
        if self.input_type is InputTypes.RNN_BATCH:
            results = np.concatenate([y.flatten() for y in results])
        accuracy = np.mean(results) * 100

        # Show accuracy
        console.supplement('Accuracy on {} is {:.3f}%'.format(
            data.name, accuracy))

        # Return accuracy
        return accuracy
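
The confusion-matrix branch delegates to `evaluate_pro`, which is not shown
here. For intuition, a minimal NumPy sketch of a confusion matrix over dense
labels (illustrative only, not tframe's own confusion-matrix class), with rows
as true classes and columns as predictions:

import numpy as np

def confusion_matrix(true_labels, preds, num_classes):
    # cm[i, j] counts samples of true class i predicted as class j
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(true_labels, preds):
        cm[t, p] += 1
    return cm

cm = confusion_matrix([0, 1, 1, 2], [0, 1, 2, 2], num_classes=3)
accuracy = np.trace(cm) / cm.sum() * 100  # diagonal / total = 75.0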
Code Example #6
def load_cifar10tfd(data_dir, validation_size=0):
  train_filename = 'cifar-10-train.tfd'
  test_filename = 'cifar-10-test.tfd'
  train_path = os.path.join(data_dir, train_filename)
  test_path = os.path.join(data_dir, test_filename)

  console.show_status('Loading CIFAR-10 (TFD)')
  train_val_data = TFData.load(train_path)
  val_data = train_val_data.pop_data(validation_size)
  train_data = train_val_data

  test_data = TFData.load(test_path)

  data = {}
  data[pedia.training] = train_data
  data[pedia.validation] = val_data
  data[pedia.test] = test_data

  console.show_status('CIFAR-10 loaded')
  console.supplement('Training Set:')
  console.supplement('images: {}'.format(
    data[pedia.training][pedia.features].shape), 2)
  console.supplement('labels: {}'.format(
    data[pedia.training][pedia.targets].shape), 2)
  console.supplement('Validation Set:')
  console.supplement('images: {}'.format(
    data[pedia.validation][pedia.features].shape), 2)
  console.supplement('labels: {}'.format(
    data[pedia.validation][pedia.targets].shape), 2)
  console.supplement('Test Set:')
  console.supplement('images: {}'.format(
    data[pedia.test][pedia.features].shape), 2)
  console.supplement('labels: {}'.format(
    data[pedia.test][pedia.targets].shape), 2)

  return data
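
A hypothetical call, assuming the .tfd files already sit under the given
directory (the path and split size are placeholders):

data = load_cifar10tfd('./data/CIFAR-10', validation_size=5000)
train_set = data[pedia.training]    # remaining training samples
val_set = data[pedia.validation]    # 5000 samples popped for validation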
Code Example #7
def load_mnist(data_dir, flatten=False, one_hot=False,
               validation_size=5000):
  console.show_status('Loading MNIST ...')
  from tensorflow.examples.tutorials.mnist import input_data
  mnist = input_data.read_data_sets(data_dir, one_hot=one_hot, reshape=flatten,
                                    validation_size=validation_size)
  console.show_status('MNIST loaded')
  console.supplement('Training Set:')
  console.supplement('images: {}'.format(mnist.train.images.shape), 2)
  console.supplement('labels: {}'.format(mnist.train.labels.shape), 2)
  console.supplement('Validation Set:')
  console.supplement('images: {}'.format(mnist.validation.images.shape), 2)
  console.supplement('labels: {}'.format(mnist.validation.labels.shape), 2)
  console.supplement('Test Set:')
  console.supplement('images: {}'.format(mnist.test.images.shape), 2)
  console.supplement('labels: {}'.format(mnist.test.labels.shape), 2)

  data = {}
  data[pedia.training] = TFData(mnist.train.images, targets=mnist.train.labels)
  data[pedia.validation] = TFData(mnist.validation.images,
                                  targets=mnist.validation.labels)
  data[pedia.test] = TFData(mnist.test.images, targets=mnist.test.labels)

  return data
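
Note that `tensorflow.examples.tutorials.mnist` was deprecated and removed in
TensorFlow 2.x. A rough modern equivalent of the loading step, using
`tf.keras.datasets` (a sketch only; it does not reproduce the TFData packing
above):

import tensorflow as tf

# load_data() returns uint8 arrays: images (N, 28, 28), labels (N,)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Mimic validation_size=5000 by carving the tail off the training split
x_val, y_val = x_train[-5000:], y_train[-5000:]
x_train, y_train = x_train[:-5000], y_train[:-5000]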
Code Example #8
def load_cifar10(data_dir, flatten=False, one_hot=False, validation_size=10000):
  console.show_status('Loading CIFAR-10 ...')

  # region : Download, tar data

  # Check data directory
  if not os.path.exists(data_dir): os.makedirs(data_dir)

  # Get data file name and path
  DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(data_dir, filename)
  # If data does not exist, download from Alex's website
  if not os.path.exists(filepath):
    console.show_status('Downloading ...')
    start_time = time.time()
    def _progress(count, block_size, total_size):
      console.clear_line()
      console.print_progress(count*block_size, total_size, start_time)
    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    statinfo = os.stat(filepath)
    console.show_status('Successfully downloaded {} ({} bytes).'.format(
      filename, statinfo.st_size))

  # Extract the tar archive, closing the file handle afterwards
  with tarfile.open(filepath, 'r:gz') as tar:
    tar.extractall(data_dir)
  # Update data directory
  data_dir = os.path.join(data_dir, 'cifar-10-batches-py')

  # endregion : Download, tar data

  # Define functions
  def pickle_load(f):
    version = platform.python_version_tuple()
    if version[0] == '2':
      return pickle.load(f)
    elif version[0] == '3':
      return pickle.load(f, encoding='latin1')
    raise ValueError('Invalid python version: {}'.format(version))

  def load_batch(filename):
    with open(filename, 'rb') as f:
      datadict = pickle_load(f)
      X = datadict['data']
      X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype('float')
      Y = np.array(datadict['labels'])
      if flatten:
        X = X.reshape(10000, -1)
      if one_hot:
        def dense_to_one_hot(labels_dense, num_classes):
          """Convert class labels from scalars to one-hot vectors."""
          num_labels = labels_dense.shape[0]
          index_offset = np.arange(num_labels) * num_classes
          labels_one_hot = np.zeros((num_labels, num_classes))
          labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
          return labels_one_hot
        Y = dense_to_one_hot(Y, 10)

      return X, Y

  # Load data from files
  xs = []
  ys = []
  for b in range(1, 6):
    f = os.path.join(data_dir, 'data_batch_{}'.format(b))
    X, Y = load_batch(f)
    xs.append(X)
    ys.append(Y)
  Xtr = np.concatenate(xs)
  Ytr = np.concatenate(ys)
  del X, Y
  Xte, Yte = load_batch(os.path.join(data_dir, 'test_batch'))

  # Pack data into instances of TFData and form a data dict
  data = {}
  total = 50000
  training_size = total - validation_size

  classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
             'horse', 'ship', 'truck']
  mask = list(range(training_size))
  data[pedia.training] = TFData(Xtr[mask], targets=Ytr[mask], classes=classes)
  mask = list(range(training_size, total))
  data[pedia.validation] = TFData(Xtr[mask], targets=Ytr[mask], classes=classes)
  data[pedia.test] = TFData(Xte, targets=Yte, classes=classes)

  console.show_status('CIFAR-10 loaded')
  console.supplement('Training Set:')
  console.supplement('images: {}'.format(
    data[pedia.training][pedia.features].shape), 2)
  console.supplement('labels: {}'.format(
    data[pedia.training][pedia.targets].shape), 2)
  console.supplement('Validation Set:')
  console.supplement('images: {}'.format(
    data[pedia.validation][pedia.features].shape), 2)
  console.supplement('labels: {}'.format(
    data[pedia.validation][pedia.targets].shape), 2)
  console.supplement('Test Set:')
  console.supplement('images: {}'.format(
    data[pedia.test][pedia.features].shape), 2)
  console.supplement('labels: {}'.format(
    data[pedia.test][pedia.targets].shape), 2)

  return data
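
The inline `dense_to_one_hot` helper uses a flat-index assignment that is
worth seeing in isolation:

import numpy as np

labels_dense = np.array([2, 0, 1])
num_classes = 3
index_offset = np.arange(3) * num_classes   # row starts in the flat view: [0, 3, 6]
labels_one_hot = np.zeros((3, num_classes))
# Flat index row * num_classes + label selects column `label` in each row
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
# -> [[0, 0, 1], [1, 0, 0], [0, 1, 0]]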