def edge_iterator(self, data, batch_size, labeling):
  """An iterator over graph edges.

  Args:
    data: A CotrainDataset object used to extract the features and labels.
    batch_size: An integer representing the desired batch size.
    labeling: A string representing the type of edges to return, which can
      be `ll`, `lu` or `uu`, where `ll` refers to labeled-labeled edges,
      `lu` to labeled-unlabeled, and `uu` to unlabeled-unlabeled.

  Yields:
    indices_src, indices_tgt, features_src, features_tgt, labels_src,
    labels_tgt
  """
  if labeling == 'll':
    edges = data.get_edges(
        src_labeled=True, tgt_labeled=True, label_must_match=True)
  elif labeling == 'lu':
    edges = (
        data.get_edges(src_labeled=True, tgt_labeled=False) +
        data.get_edges(src_labeled=False, tgt_labeled=True))
  elif labeling == 'uu':
    edges = data.get_edges(src_labeled=False, tgt_labeled=False)
  else:
    raise ValueError('Unsupported value for parameter `labeling`.')

  if not edges:
    # There are no edges of the requested type, so yield empty batches
    # forever.
    indices = np.zeros(shape=(0,), dtype=np.int32)
    features = np.zeros(
        shape=[0] + list(data.features_shape), dtype=np.float32)
    labels = np.zeros(shape=(0,), dtype=np.int64)
    while True:
      yield (indices, indices, features, features, labels, labels)

  edges = np.stack([(e.src, e.tgt) for e in edges])
  iterator = batch_iterator(
      inputs=edges,
      batch_size=batch_size,
      shuffle=True,
      allow_smaller_batch=False,
      repeat=True)
  for edge in iterator:
    indices_src = edge[:, 0]
    indices_tgt = edge[:, 1]
    features_src = data.get_features(indices_src)
    features_tgt = data.get_features(indices_tgt)
    labels_src = data.get_labels(indices_src)
    labels_tgt = data.get_labels(indices_tgt)
    yield (indices_src, indices_tgt, features_src, features_tgt, labels_src,
           labels_tgt)
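# A minimal usage sketch (illustrative only; `model` and `dataset` are
# assumed stand-ins for an instantiated model and a CotrainDataset): pulling
# one batch of labeled-labeled edges. Since the iterator repeats forever, it
# is consumed with `next` rather than a plain for loop.
#
#   ll_iterator = model.edge_iterator(dataset, batch_size=32, labeling='ll')
#   (indices_src, indices_tgt, features_src, features_tgt,
#    labels_src, labels_tgt) = next(ll_iterator)
#   # `ll` edges are requested with label_must_match=True, so the source and
#   # target labels of each returned pair agree.
#   assert (labels_src == labels_tgt).all()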
def train(self, data, session=None, **kwargs):
  """Train the classification model on the provided dataset.

  Arguments:
    data: A CotrainDataset object.
    session: A TensorFlow session or None.
    **kwargs: Other keyword arguments.

  Returns:
    best_test_acc: A float representing the test accuracy at the iteration
      where the validation accuracy is maximum.
    best_val_acc: A float representing the best validation accuracy.
  """
  summary_writer = kwargs['summary_writer']
  logging.info('Training classifier...')

  if not self.is_initialized:
    self.is_initialized = True
    logging.info('Weight decay value: %f', session.run(self.weight_decay_var))
  else:
    if self.weight_decay_update is not None:
      session.run(self.weight_decay_update)
      logging.info('New weight decay value: %f',
                   session.run(self.weight_decay_var))
    # Reset the optimizer state (e.g., momentum).
    session.run(self.reset_optimizer)

  if not self.warm_start:
    # Re-initialize variables.
    initializers = [v.initializer for v in self.variables.values()]
    initializers.append(self.global_step.initializer)
    session.run(initializers)

  # Construct data iterator.
  logging.info('Training classifier with %d samples...', data.num_train())
  train_indices = data.get_indices_train()
  unlabeled_indices = data.get_indices_unlabeled()
  val_indices = data.get_indices_val()
  test_indices = data.get_indices_test()

  # Create an iterator for labeled samples for the supervised term.
  data_iterator_train = batch_iterator(
      train_indices,
      batch_size=self.batch_size,
      shuffle=True,
      allow_smaller_batch=False,
      repeat=True)

  # Create iterators for ll, lu, uu pairs of samples for the agreement term.
  pair_ll_iterator = self.pair_iterator(train_indices, train_indices,
                                        self.num_pairs_reg, data)
  pair_lu_iterator = self.pair_iterator(train_indices, unlabeled_indices,
                                        self.num_pairs_reg, data)
  pair_uu_iterator = self.pair_iterator(unlabeled_indices, unlabeled_indices,
                                        self.num_pairs_reg, data)

  step = 0
  iter_below_tol = 0
  min_num_iter = self.min_num_iter
  has_converged = step >= self.max_num_iter
  prev_loss_val = np.inf
  best_test_acc = -1
  best_val_acc = -1
  checkpoint_saved = False
  while not has_converged:
    feed_dict = self._construct_feed_dict(data_iterator_train, True,
                                          pair_ll_iterator, pair_lu_iterator,
                                          pair_uu_iterator)
    if self.enable_summaries and step % self.summary_step == 0:
      loss_val, summary, _ = session.run(
          [self.loss_op, self.summary_op, self.train_op],
          feed_dict=feed_dict)
      iter_cls_total = session.run(self.iter_cls_total)
      summary_writer.add_summary(summary, iter_cls_total)
      summary_writer.flush()
    else:
      loss_val, _ = session.run((self.loss_op, self.train_op),
                                feed_dict=feed_dict)

    # Log the loss, if necessary.
    if step % self.logging_step == 0:
      logging.info('Classification step %6d | Loss: %10.4f', step, loss_val)

    # Evaluate, if necessary.
    def _evaluate(indices, name):
      """Evaluates the samples with the provided indices."""
      data_iterator_val = batch_iterator(
          indices,
          batch_size=self.batch_size,
          shuffle=False,
          allow_smaller_batch=True,
          repeat=False)
      feed_dict_val = self._construct_feed_dict(data_iterator_val, False)
      cumulative_acc = 0.0
      num_samples = 0
      while feed_dict_val is not None:
        val_acc, batch_size_actual = session.run(
            (self.accuracy, self.batch_size_actual), feed_dict=feed_dict_val)
        cumulative_acc += val_acc * batch_size_actual
        num_samples += batch_size_actual
        feed_dict_val = self._construct_feed_dict(data_iterator_val, False)
      if num_samples > 0:
        cumulative_acc /= num_samples

      if self.enable_summaries:
        summary = tf.Summary()
        summary.value.add(
            tag='ClassificationModel/' + name + '_acc',
            simple_value=cumulative_acc)
        iter_cls_total = session.run(self.iter_cls_total)
        summary_writer.add_summary(summary, iter_cls_total)
        summary_writer.flush()

      return cumulative_acc

    # Run validation, if necessary. Note that the summary tag already appends
    # `_acc`, so we pass the bare names `val` and `test` here.
    if step % self.eval_step == 0:
      logging.info('Evaluating on %d validation samples...', len(val_indices))
      val_acc = _evaluate(val_indices, 'val')
      logging.info('Evaluating on %d test samples...', len(test_indices))
      test_acc = _evaluate(test_indices, 'test')

      if step % self.logging_step == 0 or val_acc > best_val_acc:
        logging.info(
            'Classification step %6d | Loss: %10.4f | '
            'val_acc: %10.4f | test_acc: %10.4f', step, loss_val, val_acc,
            test_acc)

      if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc
        if self.checkpoint_path:
          self.saver.save(
              session, self.checkpoint_path, write_meta_graph=False)
          checkpoint_saved = True
        # Go for at least num_iter_after_best_val more iterations.
        min_num_iter = max(self.min_num_iter,
                           step + self.num_iter_after_best_val)
        logging.info(
            'Achieved best validation. '
            'Extending to at least %d iterations...', min_num_iter)

    step += 1
    has_converged, iter_below_tol = self.check_convergence(
        prev_loss_val,
        loss_val,
        step,
        self.max_num_iter,
        iter_below_tol,
        min_num_iter=min_num_iter)
    session.run(self.iter_cls_total_update)
    prev_loss_val = loss_val

  # Return to the best model.
  if checkpoint_saved:
    logging.info('Restoring best model...')
    self.saver.restore(session, self.checkpoint_path)

  return best_test_acc, best_val_acc
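# A minimal usage sketch (assumed names, not part of the original file): in a
# co-training loop, train() is called once per round, and the returned
# accuracies correspond to the checkpoint with the best validation score,
# which is restored before returning, not to the final training step.
#
#   best_test_acc, best_val_acc = classification_model.train(
#       cotrain_data, session=session, summary_writer=summary_writer)
#   logging.info('Best val acc: %.4f | test acc at best val: %.4f',
#                best_val_acc, best_test_acc)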
def train(self, data, session=None, **kwargs):
  """Train an agreement model."""
  summary_writer = kwargs['summary_writer']
  logging.info('Training agreement model...')

  if not self.is_initialized:
    self.is_initialized = True
  else:
    if self.weight_decay_update is not None:
      session.run(self.weight_decay_update)
      logging.info('New weight decay value: %f',
                   session.run(self.weight_decay_var))

  # Construct data iterator.
  labeled_samples = data.get_indices_train()
  num_labeled_samples = len(labeled_samples)
  num_samples_train = num_labeled_samples * num_labeled_samples
  num_samples_val = min(
      int(num_samples_train * self.percent_val), self.max_num_samples_val)

  if num_samples_train == 0:
    logging.info('No samples to train agreement. Skipping...')
    return None

  if not self.warm_start:
    # Re-initialize variables.
    initializers = [v.initializer for v in self.trainable_vars]
    initializers.append(self.global_step.initializer)
    session.run(initializers)

  # Reset the optimizer state (e.g., momentum).
  session.run(self.reset_optimizer)

  logging.info(
      'Training agreement with %d samples and validation on %d samples.',
      num_samples_train, num_samples_val)

  # Compute the ratio of positive to negative samples.
  labeled_samples_labels = data.get_labels(labeled_samples)
  ratio_pos_to_neg = self._compute_ratio_pos_neg(labeled_samples_labels)

  # Select a validation set out of all pairs of labeled samples.
  neighbors_val, agreement_labels_val = self._select_val_set(
      labeled_samples, num_samples_val, data, ratio_pos_to_neg)

  # Create a train iterator that potentially excludes the validation samples.
  data_iterator_train = self._train_iterator(
      labeled_samples, neighbors_val, data, ratio_pos_to_neg=ratio_pos_to_neg)

  # Start training.
  best_val_acc = -1
  checkpoint_saved = False
  step = 0
  iter_below_tol = 0
  min_num_iter = self.min_num_iter
  has_converged = step >= self.max_num_iter
  if not has_converged:
    self.num_iter_trained += 1
  prev_loss_val = np.inf
  while not has_converged:
    feed_dict = self._construct_feed_dict(data_iterator_train, is_train=True)
    if self.enable_summaries and step % self.summary_step == 0:
      loss_val, summary, _ = session.run(
          [self.loss_op, self.summary_op, self.train_op],
          feed_dict=feed_dict)
      iter_total = session.run(self.iter_agr_total)
      summary_writer.add_summary(summary, iter_total)
      summary_writer.flush()
    else:
      loss_val, _ = session.run((self.loss_op, self.train_op),
                                feed_dict=feed_dict)

    # Log the loss, if necessary.
    if step % self.logging_step == 0:
      logging.info('Agreement step %6d | Loss: %10.4f', step, loss_val)

    # Run validation, if necessary.
    if step % self.eval_step == 0:
      if num_samples_val == 0:
        logging.info('Skipping validation. No validation samples available.')
        break
      data_iterator_val = batch_iterator(
          neighbors_val,
          agreement_labels_val,
          self.batch_size,
          shuffle=False,
          allow_smaller_batch=True,
          repeat=False)
      feed_dict_val = self._construct_feed_dict(
          data_iterator_val, is_train=False)
      cumulative_val_acc = 0.0
      while feed_dict_val is not None:
        val_acc, batch_size_actual = session.run(
            (self.accuracy, self.batch_size_actual), feed_dict=feed_dict_val)
        cumulative_val_acc += val_acc * batch_size_actual
        feed_dict_val = self._construct_feed_dict(
            data_iterator_val, is_train=False)
      cumulative_val_acc /= num_samples_val

      # Evaluate over a random choice of sample pairs, either labeled or not.
      acc_random = self._eval_random_pairs(data, session)

      # Evaluate the accuracy on the latest train batch. We track this to
      # make sure the agreement model is able to fit the training data, but
      # it can be skipped if efficiency is an issue.
      acc_train = self._eval_train(session, feed_dict)

      if self.enable_summaries:
        summary = tf.Summary()
        summary.value.add(
            tag='AgreementModel/train_acc', simple_value=acc_train)
        summary.value.add(
            tag='AgreementModel/val_acc', simple_value=cumulative_val_acc)
        if acc_random is not None:
          summary.value.add(
              tag='AgreementModel/random_acc', simple_value=acc_random)
        iter_total = session.run(self.iter_agr_total)
        summary_writer.add_summary(summary, iter_total)
        summary_writer.flush()

      if step % self.logging_step == 0 or cumulative_val_acc > best_val_acc:
        # Guard against the None that _eval_random_pairs may return when no
        # pairs are available, which would break the %10.4f formatting.
        logging.info(
            'Agreement step %6d | Loss: %10.4f | val_acc: %10.4f | '
            'random_acc: %10.4f | acc_train: %10.4f', step, loss_val,
            cumulative_val_acc,
            acc_random if acc_random is not None else -1.0, acc_train)

      if cumulative_val_acc > best_val_acc:
        best_val_acc = cumulative_val_acc
        if self.checkpoint_path:
          self.saver.save(
              session, self.checkpoint_path, write_meta_graph=False)
          checkpoint_saved = True

        # If we reached 100% accuracy, stop.
        if best_val_acc >= 1.00:
          logging.info('Reached 100% accuracy. Stopping...')
          break

        # Go for at least num_iter_after_best_val more iterations.
        min_num_iter = max(self.min_num_iter,
                           step + self.num_iter_after_best_val)
        logging.info(
            'Achieved best validation. '
            'Extending to at least %d iterations...', min_num_iter)

    step += 1
    has_converged, iter_below_tol = self.check_convergence(
        prev_loss_val,
        loss_val,
        step,
        self.max_num_iter,
        iter_below_tol,
        min_num_iter=min_num_iter)
    session.run(self.iter_agr_total_update)
    prev_loss_val = loss_val

  # Return to the best model.
  if checkpoint_saved:
    logging.info('Restoring best model...')
    self.saver.restore(session, self.checkpoint_path)

  return best_val_acc
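# A minimal usage sketch (assumed names): the agreement model is typically
# retrained after each round of self-labeling. Note the None return when
# there are no labeled pairs to train on, which callers should handle.
#
#   val_acc_agr = agreement_model.train(
#       cotrain_data, session=session, summary_writer=summary_writer)
#   if val_acc_agr is None:
#     logging.info('Agreement training skipped: no labeled samples yet.')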