Example #1
            def _evaluate(indices, name):
                """Evaluates the samples with the provided indices."""
                data_iterator_val = batch_iterator(indices,
                                                   batch_size=self.batch_size,
                                                   shuffle=False,
                                                   allow_smaller_batch=True,
                                                   repeat=False)
                feed_dict_val = self._construct_feed_dict(
                    data_iterator_val, False)
                cumulative_acc = 0.0
                num_samples = 0
                while feed_dict_val is not None:
                    val_acc, batch_size_actual = session.run(
                        (self.accuracy, self.batch_size_actual),
                        feed_dict=feed_dict_val)
                    cumulative_acc += val_acc * batch_size_actual
                    num_samples += batch_size_actual
                    feed_dict_val = self._construct_feed_dict(
                        data_iterator_val, False)
                if num_samples > 0:
                    cumulative_acc /= num_samples

                if self.enable_summaries:
                    summary = tf.Summary()
                    summary.value.add(tag='ClassificationModel/' + name,
                                      simple_value=cumulative_acc)
                    iter_cls_total = session.run(self.iter_cls_total)
                    summary_writer.add_summary(summary, iter_cls_total)
                    summary_writer.flush()

                return cumulative_acc
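Every example on this page calls batch_iterator, which is defined elsewhere in the same repository and not shown here. Below is a minimal sketch of the behavior its call sites imply; the signature, defaults, and NumPy-based implementation are assumptions, not the library's actual code.

import numpy as np

def batch_iterator(inputs, targets=None, batch_size=32, shuffle=True,
                   allow_smaller_batch=False, repeat=True):
  """Yields batches of `inputs` (and of `targets`, when provided). Sketch only."""
  inputs = np.asarray(inputs)
  if targets is not None:
    targets = np.asarray(targets)
  num_samples = len(inputs)
  while True:
    # Visit samples in a fresh random order each epoch when shuffling.
    order = (np.random.permutation(num_samples)
             if shuffle else np.arange(num_samples))
    for start in range(0, num_samples, batch_size):
      batch = order[start:start + batch_size]
      # Drop a trailing short batch unless smaller batches are allowed.
      if len(batch) < batch_size and not allow_smaller_batch:
        break
      if targets is None:
        yield inputs[batch]
      else:
        yield inputs[batch], targets[batch]
    if not repeat:
      return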
Example #2
  def edge_iterator(self, data, batch_size, labeling):
    """An iterator over graph edges.

    Args:
      data: A CotrainDataset object used to extract the features and labels.
      batch_size: An integer representing the desired batch size.
      labeling: A string specifying which edges to return: `ll` for
        labeled-labeled, `lu` for labeled-unlabeled, or `uu` for
        unlabeled-unlabeled edges.

    Yields:
      indices_src, indices_tgt, features_src, features_tgt, labels_src,
      labels_tgt
    """
    if labeling == 'll':
      edges = data.get_edges(
          src_labeled=True, tgt_labeled=True, label_must_match=True)
    elif labeling == 'lu':
      edges = (
          data.get_edges(src_labeled=True, tgt_labeled=False) +
          data.get_edges(src_labeled=False, tgt_labeled=True))
    elif labeling == 'uu':
      edges = data.get_edges(src_labeled=False, tgt_labeled=False)
    else:
      raise ValueError('Unsupported value %r for parameter `labeling`.' %
                       labeling)

    if not edges:
      indices = np.zeros(shape=(0,), dtype=np.int32)
      features = np.zeros(
          shape=[0] + list(data.features_shape), dtype=np.float32)
      labels = np.zeros(shape=(0,), dtype=np.int64)
      while True:
        yield (indices, indices, features, features, labels, labels)

    edges = np.stack([(e.src, e.tgt) for e in edges])
    iterator = batch_iterator(
        inputs=edges,
        batch_size=batch_size,
        shuffle=True,
        allow_smaller_batch=False,
        repeat=True)

    for edge in iterator:
      indices_src = edge[:, 0]
      indices_tgt = edge[:, 1]
      features_src = data.get_features(indices_src)
      features_tgt = data.get_features(indices_tgt)
      labels_src = data.get_labels(indices_src)
      labels_tgt = data.get_labels(indices_tgt)
      yield (indices_src, indices_tgt, features_src, features_tgt, labels_src,
             labels_tgt)
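A hypothetical way to consume this iterator; `model` and `dataset` are illustrative stand-ins for an initialized model instance and a CotrainDataset, not names from the snippet.

# Draw one batch of labeled-unlabeled edges (names are illustrative).
lu_batches = model.edge_iterator(dataset, batch_size=32, labeling='lu')
(idx_src, idx_tgt, feats_src, feats_tgt,
 labels_src, labels_tgt) = next(lu_batches)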
Example #3
    def train(self, data, session=None, **kwargs):
        """Train the classification model on the provided dataset.

    Arguments:
      data: A CotrainDataset object.
      session: A TensorFlow session or None.
      **kwargs: Other keyword arguments.
    Returns:
      best_test_acc: A float representing the test accuracy at the iteration
        where the validation accuracy is maximum.
      best_val_acc: A float representing the best validation accuracy.
    """
        summary_writer = kwargs['summary_writer']
        logging.info('Training classifier...')

        if not self.is_initialized:
            self.is_initialized = True
            logging.info('Weight decay value: %f',
                         session.run(self.weight_decay_var))
        else:
            if self.weight_decay_update is not None:
                session.run(self.weight_decay_update)
                logging.info('New weight decay value: %f',
                             session.run(self.weight_decay_var))
            # Reset the optimizer state (e.g., momentum).
            session.run(self.reset_optimizer)

        if not self.warm_start:
            # Re-initialize variables.
            initializers = [v.initializer for v in self.variables.values()]
            initializers.append(self.global_step.initializer)
            session.run(initializers)

        # Construct data iterator.
        logging.info('Training classifier with %d samples...',
                     data.num_train())
        train_indices = data.get_indices_train()
        unlabeled_indices = data.get_indices_unlabeled()
        val_indices = data.get_indices_val()
        test_indices = data.get_indices_test()
        # Create an iterator for labeled samples for the supervised term.
        data_iterator_train = batch_iterator(train_indices,
                                             batch_size=self.batch_size,
                                             shuffle=True,
                                             allow_smaller_batch=False,
                                             repeat=True)
        # Create iterators for ll, lu, uu pairs of samples for the agreement term.
        pair_ll_iterator = self.pair_iterator(train_indices, train_indices,
                                              self.num_pairs_reg, data)
        pair_lu_iterator = self.pair_iterator(train_indices, unlabeled_indices,
                                              self.num_pairs_reg, data)
        pair_uu_iterator = self.pair_iterator(unlabeled_indices,
                                              unlabeled_indices,
                                              self.num_pairs_reg, data)

        step = 0
        iter_below_tol = 0
        min_num_iter = self.min_num_iter
        has_converged = step >= self.max_num_iter
        prev_loss_val = np.inf
        best_test_acc = -1
        best_val_acc = -1
        checkpoint_saved = False
        while not has_converged:
            feed_dict = self._construct_feed_dict(data_iterator_train, True,
                                                  pair_ll_iterator,
                                                  pair_lu_iterator,
                                                  pair_uu_iterator)
            if self.enable_summaries and step % self.summary_step == 0:
                loss_val, summary, _ = session.run(
                    [self.loss_op, self.summary_op, self.train_op],
                    feed_dict=feed_dict)
                iter_cls_total = session.run(self.iter_cls_total)
                summary_writer.add_summary(summary, iter_cls_total)
                summary_writer.flush()
            else:
                loss_val, _ = session.run((self.loss_op, self.train_op),
                                          feed_dict=feed_dict)

            # Log the loss, if necessary.
            if step % self.logging_step == 0:
                logging.info('Classification step %6d | Loss: %10.4f', step,
                             loss_val)

            # Evaluate, if necessary.
            def _evaluate(indices, name):
                """Evaluates the samples with the provided indices."""
                data_iterator_val = batch_iterator(indices,
                                                   batch_size=self.batch_size,
                                                   shuffle=False,
                                                   allow_smaller_batch=True,
                                                   repeat=False)
                feed_dict_val = self._construct_feed_dict(
                    data_iterator_val, False)
                cumulative_acc = 0.0
                num_samples = 0
                while feed_dict_val is not None:
                    val_acc, batch_size_actual = session.run(
                        (self.accuracy, self.batch_size_actual),
                        feed_dict=feed_dict_val)
                    cumulative_acc += val_acc * batch_size_actual
                    num_samples += batch_size_actual
                    feed_dict_val = self._construct_feed_dict(
                        data_iterator_val, False)
                if num_samples > 0:
                    cumulative_acc /= num_samples

                if self.enable_summaries:
                    summary = tf.Summary()
                    summary.value.add(tag='ClassificationModel/' + name,
                                      simple_value=cumulative_acc)
                    iter_cls_total = session.run(self.iter_cls_total)
                    summary_writer.add_summary(summary, iter_cls_total)
                    summary_writer.flush()

                return cumulative_acc

            # Run validation, if necessary.
            if step % self.eval_step == 0:
                logging.info('Evaluating on %d validation samples...',
                             len(val_indices))
                val_acc = _evaluate(val_indices, 'val_acc')
                logging.info('Evaluating on %d test samples...',
                             len(test_indices))
                test_acc = _evaluate(test_indices, 'test_acc')

                if step % self.logging_step == 0 or val_acc > best_val_acc:
                    logging.info(
                        'Classification step %6d | Loss: %10.4f | '
                        'val_acc: %10.4f | test_acc: %10.4f', step, loss_val,
                        val_acc, test_acc)
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_test_acc = test_acc
                    if self.checkpoint_path:
                        self.saver.save(session,
                                        self.checkpoint_path,
                                        write_meta_graph=False)
                        checkpoint_saved = True
                    # Go for at least num_iter_after_best_val more iterations.
                    min_num_iter = max(self.min_num_iter,
                                       step + self.num_iter_after_best_val)
                    logging.info(
                        'Achieved best validation. '
                        'Extending to at least %d iterations...', min_num_iter)

            step += 1
            has_converged, iter_below_tol = self.check_convergence(
                prev_loss_val,
                loss_val,
                step,
                self.max_num_iter,
                iter_below_tol,
                min_num_iter=min_num_iter)
            session.run(self.iter_cls_total_update)
            prev_loss_val = loss_val

        # Return to the best model.
        if checkpoint_saved:
            logging.info('Restoring best model...')
            self.saver.restore(session, self.checkpoint_path)

        return best_test_acc, best_val_acc
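A hypothetical call to the `train` method above, using the TF1-style session API the snippet itself relies on; `model`, `cotrain_data`, and the log directory are illustrative names, and only `summary_writer` is a keyword the method actually reads from **kwargs.

import tensorflow as tf

with tf.Session() as session:
  session.run(tf.global_variables_initializer())
  summary_writer = tf.summary.FileWriter('/tmp/classifier_logs', session.graph)
  best_test_acc, best_val_acc = model.train(
      cotrain_data, session=session, summary_writer=summary_writer)
  print('Best val acc: %.4f | test acc at best val: %.4f' %
        (best_val_acc, best_test_acc))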
Example #4
  def train(self, data, session=None, **kwargs):
    """Train an agreement model."""

    summary_writer = kwargs['summary_writer']
    logging.info('Training agreement model...')

    if not self.is_initialized:
      self.is_initialized = True
    else:
      if self.weight_decay_update is not None:
        session.run(self.weight_decay_update)
        logging.info('New weight decay value: %f',
                     session.run(self.weight_decay_var))

    # Construct data iterator.
    labeled_samples = data.get_indices_train()
    num_labeled_samples = len(labeled_samples)
    num_samples_train = num_labeled_samples * num_labeled_samples
    num_samples_val = min(int(num_samples_train * self.percent_val),
                          self.max_num_samples_val)

    if num_samples_train == 0:
      logging.info('No samples to train agreement. Skipping...')
      return None

    if not self.warm_start:
      # Re-initialize variables.
      initializers = [v.initializer for v in self.trainable_vars]
      initializers.append(self.global_step.initializer)
      session.run(initializers)
      # Reset the optimizer state (e.g., momentum).
      session.run(self.reset_optimizer)

    logging.info(
        'Training agreement with %d samples and validation on %d samples.',
        num_samples_train, num_samples_val)

    # Compute ratio of positives to negative samples.
    labeled_samples_labels = data.get_labels(labeled_samples)
    ratio_pos_to_neg = self._compute_ratio_pos_neg(labeled_samples_labels)
    # Select a validation set out of all pairs of labeled samples.
    neighbors_val, agreement_labels_val = self._select_val_set(
        labeled_samples, num_samples_val, data, ratio_pos_to_neg)
    # Create a train iterator that potentially excludes the validation samples.
    data_iterator_train = self._train_iterator(
        labeled_samples, neighbors_val, data, ratio_pos_to_neg=ratio_pos_to_neg)
    # Start training.
    best_val_acc = -1
    checkpoint_saved = False
    step = 0
    iter_below_tol = 0
    min_num_iter = self.min_num_iter
    has_converged = step >= self.max_num_iter
    if not has_converged:
      self.num_iter_trained += 1
    prev_loss_val = np.inf
    while not has_converged:
      feed_dict = self._construct_feed_dict(data_iterator_train, is_train=True)

      if self.enable_summaries and step % self.summary_step == 0:
        loss_val, summary, _ = session.run(
            [self.loss_op, self.summary_op, self.train_op],
            feed_dict=feed_dict)
        iter_total = session.run(self.iter_agr_total)
        summary_writer.add_summary(summary, iter_total)
        summary_writer.flush()
      else:
        loss_val, _ = session.run((self.loss_op, self.train_op),
                                  feed_dict=feed_dict)

      # Log the loss, if necessary.
      if step % self.logging_step == 0:
        logging.info('Agreement step %6d | Loss: %10.4f', step, loss_val)

      # Run validation, if necessary.
      if step % self.eval_step == 0:
        if num_samples_val == 0:
          logging.info('Skipping validation. No validation samples available.')
          break
        data_iterator_val = batch_iterator(
            neighbors_val,
            agreement_labels_val,
            self.batch_size,
            shuffle=False,
            allow_smaller_batch=True,
            repeat=False)
        feed_dict_val = self._construct_feed_dict(
            data_iterator_val, is_train=False)
        cumulative_val_acc = 0.0
        while feed_dict_val is not None:
          val_acc, batch_size_actual = session.run(
              (self.accuracy, self.batch_size_actual), feed_dict=feed_dict_val)
          cumulative_val_acc += val_acc * batch_size_actual
          feed_dict_val = self._construct_feed_dict(
              data_iterator_val, is_train=False)
        cumulative_val_acc /= num_samples_val

        # Evaluate over a random choice of sample pairs, either labeled or not.
        acc_random = self._eval_random_pairs(data, session)

        # Evaluate the accuracy on the latest train batch. We track this to make
        # sure the agreement model is able to fit the training data, but can be
        # eliminated if efficiency is an issue.
        acc_train = self._eval_train(session, feed_dict)

        if self.enable_summaries:
          summary = tf.Summary()
          summary.value.add(tag='AgreementModel/train_acc',
                            simple_value=acc_train)
          summary.value.add(tag='AgreementModel/val_acc',
                            simple_value=cumulative_val_acc)
          if acc_random is not None:
            summary.value.add(tag='AgreementModel/random_acc',
                              simple_value=acc_random)
          iter_total = session.run(self.iter_agr_total)
          summary_writer.add_summary(summary, iter_total)
          summary_writer.flush()
        if step % self.logging_step == 0 or cumulative_val_acc > best_val_acc:
          # Use %s for acc_random, since _eval_random_pairs may return None.
          logging.info(
              'Agreement step %6d | Loss: %10.4f | val_acc: %10.4f | '
              'random_acc: %s | acc_train: %10.4f', step, loss_val,
              cumulative_val_acc, acc_random, acc_train)
        if cumulative_val_acc > best_val_acc:
          best_val_acc = cumulative_val_acc
          if self.checkpoint_path:
            self.saver.save(
                session, self.checkpoint_path, write_meta_graph=False)
            checkpoint_saved = True
          # If we reached 100% accuracy, stop.
          if best_val_acc >= 1.00:
            logging.info('Reached 100% accuracy. Stopping...')
            break
          # Go for at least num_iter_after_best_val more iterations.
          min_num_iter = max(self.min_num_iter,
                             step + self.num_iter_after_best_val)
          logging.info(
              'Achieved best validation. '
              'Extending to at least %d iterations...', min_num_iter)

      step += 1
      has_converged, iter_below_tol = self.check_convergence(
          prev_loss_val,
          loss_val,
          step,
          self.max_num_iter,
          iter_below_tol,
          min_num_iter=min_num_iter)
      session.run(self.iter_agr_total_update)
      prev_loss_val = loss_val

    # Return to the best model.
    if checkpoint_saved:
      logging.info('Restoring best model...')
      self.saver.restore(session, self.checkpoint_path)

    return best_val_acc
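Both `train` methods above delegate their stopping criterion to `self.check_convergence`. A minimal sketch of what that helper might look like, inferred purely from its call sites; the attributes `loss_tol` and `max_iter_below_tol` are assumptions, not names confirmed by these snippets.

def check_convergence(self, prev_loss, current_loss, step, max_num_iter,
                      iter_below_tol, min_num_iter=0):
  """Returns (has_converged, iter_below_tol); a sketch, not the repo's code."""
  # Count consecutive steps in which the loss barely moved.
  if abs(prev_loss - current_loss) < self.loss_tol:  # assumed attribute
    iter_below_tol += 1
  else:
    iter_below_tol = 0
  # Converge after enough flat steps, but never before min_num_iter;
  # always stop once max_num_iter is reached.
  has_converged = (iter_below_tol >= self.max_iter_below_tol  # assumed
                   and step >= min_num_iter)
  return has_converged or step >= max_num_iter, iter_below_tol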