Example #1
    def get_task_training_op(self, graph, losses, task):
        """Get training op for applying gradients to variables.

    Subclasses that need to do anything fancy with gradients should override
    this method.

    Parameters
    ----------
    graph: tf.Graph
      Graph for this op
    losses: dict
      Dictionary mapping task to losses

    Returns
    -------
    A training op.
    """
        with graph.as_default():
            task_loss = losses[task]
            task_root = "task%d_ops" % task
            task_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          task_root)
            opt = model_ops.optimizer(self.optimizer, self.learning_rate,
                                      self.momentum)
            return opt.minimize(task_loss, name='train', var_list=task_vars)
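A hedged sketch, for context, of how a per-task training op like the one above could be wired up in a multitask graph, assuming TF 1.x-style variable scopes named "task%d_ops" (the per-task head and its loss are hypothetical stand-ins, not part of the original code):

import tensorflow as tf

# Minimal sketch (TF 1.x style): one loss and one training op per task, each
# op restricted to the trainable variables under that task's variable scope.
graph = tf.Graph()
with graph.as_default():
  x = tf.placeholder(tf.float32, shape=[None, 10], name="features")
  losses, train_ops = {}, {}
  for task in range(2):
    with tf.variable_scope("task%d_ops" % task):
      # Hypothetical per-task head: a single dense layer and a dummy loss.
      logits = tf.layers.dense(x, 1)
      losses[task] = tf.reduce_mean(tf.square(logits))
    task_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  "task%d_ops" % task)
    # Only this task's variables receive gradient updates.
    train_ops[task] = tf.train.AdamOptimizer(0.001).minimize(
        losses[task], var_list=task_vars)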
Example #2
  def get_training_op(self, graph, loss):
    """Get training op for applying gradients to variables.

    Subclasses that need to do anything fancy with gradients should override
    this method.

    Returns:
    A training op.
    """
    with graph.as_default():
      opt = model_ops.optimizer(self.optimizer, self.learning_rate, self.momentum)
      return opt.minimize(loss, name='train')
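The model_ops.optimizer helper used above is not shown in these snippets; as a rough sketch, assuming it simply maps an optimizer name onto the matching tf.train optimizer (the real DeepChem implementation may differ):

import tensorflow as tf

def optimizer(name, learning_rate, momentum=0.9):
  """Sketch of a name-to-optimizer factory, not the actual model_ops code."""
  if name == "adam":
    return tf.train.AdamOptimizer(learning_rate)
  elif name == "momentum":
    return tf.train.MomentumOptimizer(learning_rate, momentum)
  elif name == "sgd":
    return tf.train.GradientDescentOptimizer(learning_rate)
  raise ValueError("Unknown optimizer: %s" % name)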
Example #3
  def get_training_op(self, graph, loss):
    """Get training op for applying gradients to variables.

    Subclasses that need to do anything fancy with gradients should override
    this method.

    Returns:
    A training op.
    """
    with graph.as_default():
      opt = model_ops.optimizer(self.optimizer, self.learning_rate,
                                self.momentum)
      return opt.minimize(loss, name='train')
  def get_task_training_op(self, graph, losses, task):
    """Get training op for applying gradients to variables.

    Subclasses that need to do anything fancy with gradients should override
    this method.

    Parameters
    ----------
    graph: tf.Graph
      Graph for this op
    losses: dict
      Dictionary mapping each task index to its loss tensor
    task: int
      Index of the task whose variables should be updated

    Returns
    -------
    A training op.
    """
    with graph.as_default():
      task_loss = losses[task]
      task_root = "task%d_ops" % task
      task_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, task_root)
      opt = model_ops.optimizer(self.optimizer, self.learning_rate, self.momentum)
      return opt.minimize(task_loss, name='train', var_list=task_vars)
    def fit(self,
            dataset,
            nb_epoch=10,
            max_checkpoints_to_keep=5,
            log_every_N_batches=50,
            learning_rate=.001,
            batch_size=50,
            checkpoint_interval=10):
        """Trains the model for a fixed number of epochs.

    TODO(rbharath0: This is mostly copied from TensorflowGraphModel. Should
    eventually refactor both together.

    Parameters
    ----------
    dataset: dc.data.Dataset
    nb_epoch: 10
      Number of training epochs.
      Dataset object holding training data
        batch_size: integer. Number of samples per gradient update.
        nb_epoch: integer, the number of epochs to train the model.
        verbose: 0 for no logging to stdout,
            1 for progress bar logging, 2 for one log line per epoch.
        initial_epoch: epoch at which to start training
            (useful for resuming a previous training run)
    checkpoint_interval: int
      Frequency at which to write checkpoints, measured in epochs
    """
        ############################################################## TIMING
        time1 = time.time()
        ############################################################## TIMING
        print("Training for %d epochs" % nb_epoch)
        with self.graph.as_default():
            opt = model_ops.optimizer("adam", learning_rate)
            train_op = opt.minimize(self.loss, name='train')
            with self.session as sess:
                sess.run(tf.global_variables_initializer())
                saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
                # Save an initial checkpoint.
                saver.save(sess, self._save_path, global_step=0)
                for epoch in range(nb_epoch):
                    avg_loss, n_batches = 0., 0
                    # TODO(rbharath): Example weighting is not supported yet.
                    for ind, (X_b, y_b, w_b, ids_b) in enumerate(
                            dataset.iterbatches(batch_size)):
                        if ind % log_every_N_batches == 0:
                            print("On batch %d" % ind)
                        feed_dict = {self.features: X_b, self.labels: y_b}
                        fetches = [self.outputs] + [train_op, self.loss]
                        fetched_values = sess.run(fetches, feed_dict=feed_dict)
                        output = fetched_values[:1]
                        loss = fetched_values[-1]
                        avg_loss += loss
                        y_pred = np.squeeze(np.array(output))
                        y_b = y_b.flatten()
                        n_batches += 1
                    if epoch % checkpoint_interval == checkpoint_interval - 1:
                        saver.save(sess, self._save_path, global_step=epoch)
                    avg_loss = float(avg_loss) / n_batches
                    print('Ending epoch %d: Average loss %g' %
                          (epoch, avg_loss))
                # Always save a final checkpoint when complete.
                saver.save(sess, self._save_path, global_step=epoch + 1)
        ############################################################## TIMING
        time2 = time.time()
        print("TIMING: model fitting took %0.3f s" % (time2 - time1))
Example #6
  def fit(self,
          dataset,
          nb_epoch=10,
          max_checkpoints_to_keep=5,
          log_every_N_batches=50,
          learning_rate=.001,
          batch_size=50,
          checkpoint_interval=10):
    """Trains the model for a fixed number of epochs.

    TODO(rbharath0: This is mostly copied from TensorflowGraphModel. Should
    eventually refactor both together.

    Parameters
    ----------
    dataset: dc.data.Dataset
    nb_epoch: 10
      Number of training epochs.
      Dataset object holding training data
        batch_size: integer. Number of samples per gradient update.
        nb_epoch: integer, the number of epochs to train the model.
        verbose: 0 for no logging to stdout,
            1 for progress bar logging, 2 for one log line per epoch.
        initial_epoch: epoch at which to start training
            (useful for resuming a previous training run)
    checkpoint_interval: int
      Frequency at which to write checkpoints, measured in epochs
    """
    ############################################################## TIMING
    time1 = time.time()
    ############################################################## TIMING
    print("Training for %d epochs" % nb_epoch)
    with self.graph.as_default():
      opt = model_ops.optimizer("adam", learning_rate)
      train_op = opt.minimize(self.loss, name='train')
      with self.session as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
        # Save an initial checkpoint.
        saver.save(sess, self._save_path, global_step=0)
        for epoch in range(nb_epoch):
          avg_loss, n_batches = 0., 0
          # TODO(rbharath): Example weighting is not supported yet.
          for ind, (X_b, y_b, w_b,
                    ids_b) in enumerate(dataset.iterbatches(batch_size)):
            if ind % log_every_N_batches == 0:
              print("On batch %d" % ind)
            feed_dict = {self.features: X_b, self.labels: y_b}
            fetches = [self.outputs] + [train_op, self.loss]
            fetched_values = sess.run(fetches, feed_dict=feed_dict)
            output = fetched_values[:1]
            loss = fetched_values[-1]
            avg_loss += loss
            y_pred = np.squeeze(np.array(output))
            y_b = y_b.flatten()
            n_batches += 1
          if epoch % checkpoint_interval == checkpoint_interval - 1:
            saver.save(sess, self._save_path, global_step=epoch)
          avg_loss = float(avg_loss) / n_batches
          print('Ending epoch %d: Average loss %g' % (epoch, avg_loss))
        # Always save a final checkpoint when complete.
        saver.save(sess, self._save_path, global_step=epoch + 1)
    ############################################################## TIMING
    time2 = time.time()
    print("TIMING: model fitting took %0.3f s" % (time2 - time1))