Exemplo n.º 1
0
def train(output_dir,
          mnist_location=constants.MNIST_LOCATION,
          training_len=constants.TRAINING_LEN,
          masks=None,
          presets=None,
          train_order_seed=None):
    """Train the MNIST model, possibly with presets and masks.

    Args:
      output_dir: The directory to which to write model logs and output.
      mnist_location: The location of the MNIST numpy npz file.
      training_len: How long to run the model. A tuple of two values. The
        first value is the unit of measure (either "epochs" or "iterations")
        and the second is the number of units for which to train.
      masks: The masks, if any, used to prune weights. Masks can come in
        one of four forms:
        * A dictionary of numpy arrays. Each dictionary key is the name of
          the corresponding tensor that is to be masked out. Each value is a
          numpy array containing the masks (1 for including a weight, 0 for
          excluding).
        * The string name of a directory containing one file for each
          mask (in the form of bedrock.save_restore).
        * A list of strings paths and dictionaries representing several
          masks. The mask used for training is the union of the pruned
          networks represented by these masks.
        * None, meaning the network should not be pruned.
      presets: The initial weights for the network, if any. Presets can come
        in any of the non-list forms mentioned for masks; each numpy array
        stores the desired initializations.
      train_order_seed: The random seed, if any, to be used to determine the
        order in which training examples are shuffled before being presented
        to the network.
    """
    # Normalize the flexible mask/preset forms (dict, path, list) into plain
    # dictionaries of numpy arrays; list-form masks are merged via union.
    masks = save_restore.standardize(masks, union.union)
    presets = save_restore.standardize(presets)

    # Create the dataset and model.
    dataset = dataset_mnist.Dataset(mnist_location,
                                    train_order_seed=train_order_seed)
    inputs, labels = dataset.placeholders
    model = model_fc.ModelFc(constants.HYPERPARAMETERS,
                             inputs,
                             labels,
                             presets=presets,
                             masks=masks)

    # Train.
    params = {
        'test_interval': 100,
        'save_summaries': True,
        'save_network': True
    }
    # Fix: the original created tf.Session() inline and never closed it,
    # leaking the session. The context manager closes it deterministically.
    with tf.Session() as sess:
        trainer.train(sess, dataset, model, constants.OPTIMIZER_FN,
                      training_len, output_dir, **params)
Exemplo n.º 2
0
 def train_once(self, iteration, presets=None, masks=None):
     """Train one fully-connected MNIST model for 50000 iterations.

     Args:
       iteration: Index used to derive the per-run output directory
         (via paths.run).
       presets: Optional initial weights for the network, passed through
         to model_fc.ModelFc.
       masks: Optional pruning masks, passed through to model_fc.ModelFc.

     Returns:
       Whatever trainer.train returns for this run.
     """
     # Start from a clean graph so repeated calls don't accumulate ops.
     tf.reset_default_graph()
     dataset = dataset_mnist.ConstructedDatasetMnist(
         train_len=self.train_len)
     input_tensor, label_tensor = dataset.placeholders
     hyperparameters = {
         'layers': [(300, tf.nn.relu), (100, tf.nn.relu), (10, None)]
     }
     model = model_fc.ModelFc(hyperparameters,
                              input_tensor,
                              label_tensor,
                              presets=presets,
                              masks=masks)
     params = {
         'test_interval': 100,
         'save_summaries': True,
         'save_network': True,
     }
     # Fix: the original tf.Session() was never closed, leaking the
     # session across iterations. The context manager closes it on exit.
     with tf.Session() as sess:
         return trainer.train(sess,
                              dataset,
                              model,
                              functools.partial(
                                  tf.train.GradientDescentOptimizer, .1),
                              ('iterations', 50000),
                              output_dir=paths.run(self.output_dir,
                                                   iteration),
                              **params)
  def train_model(sess, level, dataset, model):
    """Train `model` on `dataset` within `sess`, logging under `level`.

    NOTE(review): `training_len`, `output_dir` and `experiment_name` are
    free variables resolved from the enclosing scope (not visible here) —
    confirm against the surrounding definition.
    """
    run_dir = paths.run(output_dir, level, experiment_name)
    return trainer.train(
        sess,
        dataset,
        model,
        constants.OPTIMIZER_FN,
        training_len,
        output_dir=run_dir,
        test_interval=100,
        save_summaries=True,
        save_network=True)