Example #1
def minimize_loss_single_machine(loss,
                                 accuracy,
                                 layer_collection,
                                 device=None,
                                 session_config=None):
    """Minimize loss with K-FAC on a single machine.

  Creates `PeriodicInvCovUpdateKfacOpt` which handles inverse and covariance
  computation op placement and execution. A single Session is responsible for
  running all of K-FAC's ops. The covariance and inverse update ops are placed
  on `device`. All model variables are on CPU.

  Args:
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.
    device: string or None. The covariance and inverse update ops are run on
      this device. If empty or None, the default device will be used.
      (Default: None)
    session_config: None or tf.ConfigProto. Configuration for tf.Session().

  Returns:
    final value for 'accuracy'.
  """
    device_list = [] if not device else [device]

    # Train with K-FAC.
    g_step = tf.train.get_or_create_global_step()
    optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
        invert_every=_INVERT_EVERY,
        cov_update_every=_COV_UPDATE_EVERY,
        learning_rate=0.0001,
        cov_ema_decay=0.95,
        damping=0.001,
        layer_collection=layer_collection,
        placement_strategy="round_robin",
        cov_devices=device_list,
        inv_devices=device_list,
        trans_devices=device_list,
        momentum=0.9)

    with tf.device(device):
        train_op = optimizer.minimize(loss, global_step=g_step)

    tf.logging.info("Starting training.")
    with tf.train.MonitoredTrainingSession(config=session_config) as sess:
        while not sess.should_stop():
            global_step_, loss_, accuracy_, _ = sess.run(
                [g_step, loss, accuracy, train_op])

            if global_step_ % _REPORT_EVERY == 0:
                tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                                global_step_, loss_, accuracy_)

    return accuracy_
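
A minimal driver sketch for the function above. The `load_batch` and `build_model` helpers are assumptions standing in for the data-loading and model-building code the real example defines elsewhere in its file:

def run_single_machine(device="/gpu:0"):
    # Hypothetical helpers: `load_batch` returns (examples, labels) tensors and
    # `build_model` returns (loss, accuracy) while registering the model's
    # layers in `layer_collection`.
    examples, labels = load_batch()
    layer_collection = kfac.LayerCollection()
    loss, accuracy = build_model(examples, labels, layer_collection)
    layer_collection.auto_register_layers()
    return minimize_loss_single_machine(
        loss, accuracy, layer_collection, device=device)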
Example #2
def minimize_loss(batch_size, batch_loss, layer_collection, loss_fn,
                  cached_reader):
    """Constructs optimizer and train op.

  Args:
    batch_size: Tensor of shape (), Size of the training batch.
    batch_loss: Tensor of shape (). Loss with respect to minibatch to be
      minimized.
    layer_collection: LayerCollection or None. Registry for model parameters.
      Required when using a K-FAC optimizer.
    loss_fn: Function which takes as input training data and returns loss.
    cached_reader: `data_reader.CachedReader` instance.

  Returns:
    train_op: Op that can be used to update model parameters.

  Raises:
    ValueError: If layer_collection is None when K-FAC is selected as an
      optimization method.
  """
    global_step = tf.train.get_or_create_global_step()

    if layer_collection is None:
        raise ValueError('layer_collection must be defined to use K-FAC.')

    optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
        invert_every=FLAGS.inverse_update_period,
        cov_update_every=FLAGS.cov_update_period,
        learning_rate=1e-3,
        damping=100.,
        cov_ema_decay=0.95,
        momentum=0.95,
        layer_collection=layer_collection,
        batch_size=batch_size)
    # Set the damping parameters required to adapt damping.
    optimizer.set_damping_adaptation_params(
        prev_train_batch=cached_reader.cached_batch,
        is_chief=True,
        loss_fn=loss_fn,
        damping_adaptation_decay=0.95,
        damping_adaptation_interval=FLAGS.damping_adaptation_interval,
    )
    return optimizer.minimize(batch_loss, global_step=global_step)
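
The FLAGS referenced above are ordinary TensorFlow command-line flags. A hypothetical definition consistent with the names used in this example (the defaults shown are illustrative, not the project's values):

import tensorflow as tf

flags = tf.flags
flags.DEFINE_integer('inverse_update_period', 10,
                     'Number of training steps between inverse updates.')
flags.DEFINE_integer('cov_update_period', 1,
                     'Number of training steps between covariance updates.')
flags.DEFINE_integer('damping_adaptation_interval', 5,
                     'Number of training steps between damping adaptations.')
FLAGS = flags.FLAGS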
Example #3
File: rnn_mnist.py  Project: phymucs/kfac
def make_train_op(batch_size, batch_loss, layer_collection, loss_fn,
                  cached_reader):
    """Constructs optimizer and train op.

  Args:
    batch_size: Tensor of shape (), Size of the training batch.
    batch_loss: Tensor of shape (). Loss with respect to minibatch to be
      minimized.
    layer_collection: LayerCollection or None. Registry for model parameters.
      Required when using a K-FAC optimizer.
    loss_fn: Function which takes as input training data and returns loss.
    cached_reader: `data_reader.CachedReader` instance.

  Returns:
    train_op: Op that can be used to update model parameters.
    optimizer: Optimizer used to produce train_op.

  Raises:
    ValueError: If layer_collection is None when K-FAC is selected as an
      optimization method.
  """
    global_step = tf.train.get_or_create_global_step()

    if layer_collection is None:
        raise ValueError('layer_collection must be defined to use K-FAC.')

    if FLAGS.lrmu_adaptation == 'on':
        learning_rate = None
        momentum = None
        momentum_type = 'qmodel'
    elif FLAGS.lrmu_adaptation == 'only_lr':
        learning_rate = None
        momentum = FLAGS.momentum
        momentum_type = 'qmodel_fixedmu'
    elif FLAGS.lrmu_adaptation == 'off':
        learning_rate = FLAGS.learning_rate
        momentum = FLAGS.momentum
        # momentum_type = 'regular'
        momentum_type = 'adam'

    optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
        invert_every=FLAGS.inverse_update_period,
        cov_update_every=FLAGS.cov_update_period,
        learning_rate=learning_rate,
        damping=150.,  # When using damping adaptation it is advisable to start
        # with a high value. This value is probably far too high
        # to use for most neural nets if you aren't using damping
        # adaptation. (Although it always depends on the scale of
        # the loss.)
        cov_ema_decay=0.95,
        momentum=momentum,
        momentum_type=momentum_type,
        layer_collection=layer_collection,
        batch_size=batch_size,
        num_burnin_steps=5,
        adapt_damping=True,
        is_chief=True,
        prev_train_batch=cached_reader.cached_batch,
        loss=batch_loss,
        loss_fn=loss_fn,
        damping_adaptation_decay=0.95,
        damping_adaptation_interval=FLAGS.damping_adaptation_interval,
        min_damping=1e-5)
    return optimizer.minimize(batch_loss, global_step=global_step), optimizer
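
The `lrmu_adaptation` branch selects whether K-FAC's quadratic-model adaptation controls the learning rate and momentum. A hypothetical flag definition matching the three branches above (the default value is illustrative):

flags = tf.flags
flags.DEFINE_string(
    'lrmu_adaptation', 'on',
    "One of 'on', 'only_lr', or 'off'. Whether K-FAC's quadratic model adapts "
    "both the learning rate and momentum, only the learning rate, or neither "
    "(in which case the learning_rate and momentum flags are used directly).")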
Example #4
def train_mnist_multitower(num_epochs,
                           num_towers,
                           devices,
                           use_fake_data=False,
                           session_config=None):
    """Train a ConvNet on MNIST.

  Training data is split equally among the towers. Each tower computes loss on
  its own batch of data and the loss is aggregated on the CPU. The model
  variables are placed on first tower. The covariance and inverse update ops
  and variables are placed on specified devices in a round robin manner.

  Args:
    num_epochs: int. Number of passes to make over the training set.
    num_towers: int. Number of towers.
    devices: list of strings. List of devices to place the towers.
    use_fake_data: bool. If True, generate a synthetic dataset.
    session_config: None or tf.ConfigProto. Configuration for tf.Session().

  Returns:
    accuracy of model on the final minibatch of training data.
  """
    num_towers = 1 if not devices else len(devices)
    # Load a dataset.
    tf.logging.info("Loading MNIST into memory.")
    tower_batch_size = 128
    batch_size = tower_batch_size * num_towers
    tf.logging.info(
        ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
         "tower batch size.") % (batch_size, num_towers, tower_batch_size))
    (examples,
     labels) = mnist.load_mnist_as_iterator(num_epochs,
                                            batch_size,
                                            use_fake_data=use_fake_data,
                                            flatten_images=False)

    # Split minibatch across towers.
    examples = tf.split(examples, num_towers)
    labels = tf.split(labels, num_towers)

    # Build the ConvNet. Each tower's layers will be added to the LayerCollection.
    layer_collection = kfac.LayerCollection()
    tower_results = []
    for tower_id in range(num_towers):
        with tf.device(devices[tower_id]):
            with tf.name_scope("tower%d" % tower_id):
                with tf.variable_scope(tf.get_variable_scope(),
                                       reuse=(tower_id > 0)):
                    tf.logging.info("Building tower %d." % tower_id)
                    tower_results.append(
                        build_model(examples[tower_id],
                                    labels[tower_id],
                                    10,
                                    layer_collection,
                                    register_layers_manually=_USE_MANUAL_REG))
    losses, accuracies = zip(*tower_results)
    # When using multiple towers we only want to perform automatic
    # registration once, after the final tower is made.
    if not _USE_MANUAL_REG:
        layer_collection.auto_register_layers()

    # Average across towers.
    loss = tf.reduce_mean(losses)
    accuracy = tf.reduce_mean(accuracies)

    # Fit model.
    g_step = tf.train.get_or_create_global_step()
    optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
        invert_every=_INVERT_EVERY,
        cov_update_every=_COV_UPDATE_EVERY,
        learning_rate=0.0001,
        cov_ema_decay=0.95,
        damping=0.001,
        layer_collection=layer_collection,
        placement_strategy="round_robin",
        cov_devices=devices,
        inv_devices=devices,
        trans_devices=devices,
        momentum=0.9)

    with tf.device(devices[0]):
        train_op = optimizer.minimize(loss, global_step=g_step)

    # Without setting allow_soft_placement=True there will be problems when
    # the optimizer tries to place certain ops like "mod" on the GPU (which isn't
    # supported).
    if not session_config:
        session_config = tf.ConfigProto(allow_soft_placement=True)

    tf.logging.info("Starting training.")
    with tf.train.MonitoredTrainingSession(config=session_config) as sess:
        while not sess.should_stop():
            global_step_, loss_, accuracy_, _ = sess.run(
                [g_step, loss, accuracy, train_op])

            if global_step_ % _REPORT_EVERY == 0:
                tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                                global_step_, loss_, accuracy_)

    return accuracy_
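
A hedged invocation sketch for the multi-tower trainer; the device strings are illustrative, and `num_towers` is recomputed from `devices` inside the function anyway:

devices = ["/gpu:0", "/gpu:1"]
final_accuracy = train_mnist_multitower(
    num_epochs=5, num_towers=len(devices), devices=devices)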
Example #5
def make_train_op(minibatch,
                  batch_size,
                  batch_loss,
                  layer_collection,
                  loss_fn,
                  prev_train_batch=None,
                  placement_strategy=None,
                  print_logs=False,
                  tf_replicator=None):
  """Constructs optimizer and train op.

  Args:
    minibatch: A list/tuple of Tensors (typically representing the current
      mini-batch of input images and labels).
    batch_size: Tensor of shape (). Size of the training mini-batch.
    batch_loss: Tensor of shape (). Mini-batch loss tensor.
    layer_collection: LayerCollection object. Registry for model parameters.
      Required when using a K-FAC optimizer.
    loss_fn: Function which takes as input a mini-batch and returns the loss.
    prev_train_batch: `Tensor` of the previous training batch, can be accessed
      from the data_reader.CachedReader cached_batch property. (Default: None)
    placement_strategy: `str`, the placement_strategy argument for
      `KfacOptimizer`. (Default: None)
    print_logs: `Bool`. If True we print logs using K-FAC's built-in
      tf.print-based logs printer. (Default: False)
    tf_replicator: A Replicator object or None. If not None, K-FAC will set
        itself up to work inside of the provided TF-Replicator object.
        (Default: None)

  Returns:
    train_op: Op that can be used to update model parameters.
    optimizer: Optimizer used to produce train_op.

  Raises:
    ValueError: If layer_collection is None when K-FAC is selected as an
      optimization method.
  """
  global_step = tf.train.get_or_create_global_step()

  if FLAGS.optimizer == 'kfac':
    if layer_collection is None:
      raise ValueError('layer_collection must be defined to use K-FAC.')

    if FLAGS.lrmu_adaptation == 'on':
      learning_rate = None
      momentum = None
      momentum_type = 'qmodel'
    elif FLAGS.lrmu_adaptation == 'only_lr':
      learning_rate = None
      momentum = FLAGS.momentum
      momentum_type = 'qmodel_fixedmu'
    elif FLAGS.lrmu_adaptation == 'off':
      learning_rate = FLAGS.learning_rate
      momentum = FLAGS.momentum
      # momentum_type = 'regular'
      momentum_type = 'adam'

    if FLAGS.adapt_damping:
      damping = FLAGS.initial_damping
    else:
      damping = FLAGS.damping

    optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
        invert_every=FLAGS.inverse_update_period,
        cov_update_every=FLAGS.cov_update_period,
        learning_rate=learning_rate,
        damping=damping,
        cov_ema_decay=0.95,
        momentum=momentum,
        momentum_type=momentum_type,
        layer_collection=layer_collection,
        batch_size=batch_size,
        num_burnin_steps=FLAGS.num_burnin_steps,
        adapt_damping=FLAGS.adapt_damping,
        # Note that many of the arguments below don't do anything when
        # adapt_damping=False.
        update_damping_immediately=FLAGS.update_damping_immediately,
        is_chief=True,
        prev_train_batch=prev_train_batch,
        loss=batch_loss,
        loss_fn=loss_fn,
        damping_adaptation_decay=0.9,
        damping_adaptation_interval=FLAGS.damping_adaptation_interval,
        min_damping=1e-6,
        l2_reg=FLAGS.l2_reg,
        train_batch=minibatch,
        placement_strategy=placement_strategy,
        print_logs=print_logs,
        tf_replicator=tf_replicator,
        dtype=FLAGS.dtype,
        )

  elif FLAGS.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.momentum,
        epsilon=FLAGS.damping,
        beta2=0.99)

  return optimizer.minimize(batch_loss, global_step=global_step), optimizer
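
A hypothetical call site for this variant. The tensors and objects it consumes (`images`, `labels`, `batch_loss`, `loss_fn`, `layer_collection`, `cached_reader`) are assumed to come from model-building code not shown here, and FLAGS.optimizer must be 'kfac' or 'adam':

train_op, opt = make_train_op(
    minibatch=(images, labels),
    batch_size=tf.shape(images)[0],
    batch_loss=batch_loss,
    layer_collection=layer_collection,
    loss_fn=loss_fn,
    prev_train_batch=cached_reader.cached_batch,
    placement_strategy='round_robin',
    print_logs=True)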
Example #6
def minimize_loss_single_machine(handle, iter_train_handle, iter_val_handle, loss,
                                 accuracy,
                                 layer_collection,
                                 device=None,
                                 session_config=None):
  """Minimize loss with K-FAC on a single machine.

  Creates `PeriodicInvCovUpdateKfacOpt` which handles inverse and covariance
  computation op placement and execution. A single Session is responsible for
  running all of K-FAC's ops. The covariance and inverse update ops are placed
  on `device`. All model variables are on CPU.

  Args:
    handle: string placeholder Tensor. Fed with an iterator string handle to
      select between the training and validation input pipelines.
    iter_train_handle: 0-D string Tensor. String handle of the training
      dataset iterator.
    iter_val_handle: 0-D string Tensor. String handle of the validation
      dataset iterator.
    loss: 0-D Tensor. Loss to be minimized.
    accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
    layer_collection: LayerCollection instance describing model architecture.
      Used by K-FAC to construct preconditioner.
    device: string or None. The covariance and inverse update ops are run on
      this device. If empty or None, the default device will be used.
      (Default: None)
    session_config: None or tf.ConfigProto. Configuration for tf.Session().

  Returns:
    final value for 'accuracy'.
  """
  device_list = [] if not device else [device]

  # Train with K-FAC.
  g_step = tf.train.get_or_create_global_step()
  optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
      invert_every=_INVERT_EVERY,
      cov_update_every=_COV_UPDATE_EVERY,
      learning_rate=args.lr,
      cov_ema_decay=0.95,
      damping=0.001,
      layer_collection=layer_collection,
      placement_strategy="round_robin",
      cov_devices=device_list,
      inv_devices=device_list,
      trans_devices=device_list,
      momentum=0.9)

  with tf.device(device):
    train_op = optimizer.minimize(loss, global_step=g_step)

  tf.logging.info("Starting training.")

  with tf.train.MonitoredTrainingSession(config=session_config) as sess:
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

    test_loss_, test_accuracy_ = sess.run(
      [loss, accuracy], feed_dict={handle: handle_val})
    test_losses.append(test_loss_)
    test_accuracies.append(test_accuracy_)

    import time
    t1 = time.time()

    while not sess.should_stop():
      stime = time.time()
      global_step_, loss_, accuracy_, _ = sess.run(
          [g_step, loss, accuracy, train_op], feed_dict={handle: handle_train})
      etime = time.time()
      step_time = etime - stime
      times.append(step_time)

      if global_step_ % _REPORT_EVERY == 0:
        print ("global_step: %d | loss: %f | accuracy: %s" %
                (global_step_, loss_, accuracy_))

        # test_loss_, test_accuracy_= sess.run(
        #     [loss, accuracy], feed_dict={handle: handle_val})
        # test_losses.append(test_loss_)
        # test_accuracies.append(test_accuracy_)
        # # np.save(dir+"/times.npy", np.array(times))
        # np.save(dir+"/data.npy", np.array(test_accuracies))
        # np.save(dir+"/losses.npy", np.array(test_losses))
        #
        # print ("test_loss %f | test accuracy: %s " %
        #                 (test_loss_, test_accuracy_))
        # tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
        #                 global_step_, loss_, accuracy_)

        t2 = time.time()
        print ("KFAC time: ", t2-t1)
    return accuracy_
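
A hedged sketch of how the `handle`, `iter_train_handle`, and `iter_val_handle` tensors consumed by this variant would typically be built; the `train_dataset` and `val_dataset` tf.data.Dataset objects are assumed:

# Feedable iterator: one placeholder handle switches between pipelines.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_dataset.output_types, train_dataset.output_shapes)
examples, labels = iterator.get_next()

iter_train_handle = train_dataset.make_one_shot_iterator().string_handle()
iter_val_handle = val_dataset.make_one_shot_iterator().string_handle()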
Example #7
def cifar10_model_fn(features, labels, mode, params):
    """Model function for CIFAR-10."""
    tf.summary.image('images', features, max_outputs=6)

    inputs = features
    _network = get_problem(params)

    def network(*inputs):
        with tf.variable_scope('nn', reuse=tf.AUTO_REUSE):
            return _network(*inputs, mode == tf.estimator.ModeKeys.TRAIN)

    logits = network(inputs)

    if params['optimizer'] == 'kfac':
        lc = kfac.LayerCollection()
        lc.register_categorical_predictive_distribution(logits)
        lc.auto_register_layers()

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits,
                                                    onehot_labels=labels)

    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss.
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Scale the learning rate linearly with the batch size. When the batch size
        # is 128, the learning rate should be 0.1.
        initial_learning_rate = params[
            'lr']  # 0.1 * params['batch_size'] / 128
        # batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
        global_step = tf.train.get_or_create_global_step()

        # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
        # boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
        # values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
        # learning_rate = tf.train.piecewise_constant(
        #    tf.cast(global_step, tf.int32), boundaries, values)
        learning_rate = initial_learning_rate

        # Create a tensor named learning_rate for logging purposes
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)
        if params['optimizer'] == 'meta':
            optimizer = co.MetaHessionFreeOptimizer(
                learning_rate=learning_rate,
                iter=params['CG_iter'],
                x_use=params['x_use'],
                y_use=params['y_use'],
                d_use=params['d_use'],
                damping_type=params['damping_type'],
                damping=params['damping'],
                decay=params['decay'])
        elif params['optimizer'] == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                               beta1=params['beta1'],
                                               beta2=params['beta2'])
        elif params['optimizer'] == 'RMSprop':
            optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                                  decay=params['decay'])
        elif params['optimizer'] == 'momentum':
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=params['momentum'])
        elif params['optimizer'] == 'SGD':
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        elif params['optimizer'] == 'kfac':
            optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
                learning_rate=learning_rate,
                cov_ema_decay=params['decay'],
                damping=params['damping'],
                layer_collection=lc)

            if params['damping_type'] == 'LM_heuristics':
                last_inputs = tf.get_variable('last_input',
                                              initializer=tf.zeros_initializer,
                                              shape=inputs.shape,
                                              dtype=inputs.dtype,
                                              trainable=False)

                last_labels = tf.get_variable('last_label',
                                              initializer=tf.zeros_initializer,
                                              shape=labels.shape,
                                              dtype=labels.dtype,
                                              trainable=False)

                cached_collections = [
                    tf.assign(last_inputs, inputs),
                    tf.assign(last_labels, labels)
                ]

                optimizer.set_damping_adaptation_params(
                    prev_train_batch=(last_inputs, last_labels),
                    is_chief=True,
                    loss_fn=lambda x: tf.losses.softmax_cross_entropy(
                        logits=network(x[0]), onehot_labels=x[1]),
                    damping_adaptation_decay=params['momentum'],
                )
        else:
            raise ValueError('Unknown optimizer: %s' % params['optimizer'])

        # Batch norm requires update ops to be added as a dependency to the train_op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            if params['optimizer'] == 'meta':
                train_op = optimizer.minimize(loss_type='cross_entropy',
                                              out=logits,
                                              label=labels,
                                              input_list=[inputs],
                                              global_step=global_step,
                                              network_fn=network)
                train_hooks = [
                    co.MetaParametersLoadingHook(params['meta_ckpt'])
                ]
            else:
                train_op = optimizer.minimize(loss, global_step=global_step)
                '''
                train_hooks = [rl.RecordStateHook(state_scope='nn',
                                                  total_step=total_step,
                                                  account=100,
                                                  loss=cross_entropy,
                                                  experience=experience)]
                '''

                if params['optimizer'] == 'kfac' and params[
                        'damping_type'] == 'LM_heuristics':
                    with tf.control_dependencies([train_op]):
                        with tf.control_dependencies(cached_collections):
                            train_op = tf.no_op()
                train_hooks = []
    else:
        train_op = None
        train_hooks = []

    accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1),
                                   predictions['classes'])
    metrics = {'accuracy': accuracy}

    # Create a tensor named train_accuracy for logging purposes
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=metrics,
                                      training_hooks=train_hooks)
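
A minimal Estimator wiring sketch for the model function above. The `train_input_fn` is a hypothetical input function, and the params values shown are illustrative only:

estimator = tf.estimator.Estimator(
    model_fn=cifar10_model_fn,
    model_dir='/tmp/cifar10_kfac',
    params={'optimizer': 'kfac', 'lr': 1e-4, 'decay': 0.95,
            'damping': 1e-3, 'damping_type': 'fixed',
            'batch_size': 128})
estimator.train(input_fn=train_input_fn, steps=10000)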