Example #1
def run_canary():
    """
    Runs some code that will crash if the GPUs / GPU driver are suffering from
    a common bug. This helps to prevent contaminating results in the rest of
    the library with incorrect calculations.
    """

    # Note: please do not edit this function unless you have access to a machine
    # with GPUs suffering from the bug and can verify that the canary still
    # crashes after your edits. Due to the transient nature of the GPU bug it is
    # not possible to unit test the canary in our continuous integration system.

    global last_run
    current = time.time()
    if last_run is None or current - last_run > 3600:
        last_run = current
    else:
        # Run the canary at most once per hour
        return

    # Try very hard not to let the canary affect the graph for the rest of the
    # python process
    canary_graph = tf.Graph()
    with canary_graph.as_default():
        devices = infer_devices()
        num_devices = len(devices)
        if num_devices < 3:
            # We have never observed GPU failure when fewer than 3 GPUs were used
            return

        v = np.random.RandomState([2018, 10, 16]).randn(2, 2)
        # Try very hard not to let this Variable end up in any collections used
        # by the rest of the python process
        w = tf.Variable(v, trainable=False, collections=[])
        loss = tf.reduce_sum(tf.square(w))

        grads = []
        for device in devices:
            with tf.device(device):
                grad, = tf.gradients(loss, w)
                grads.append(grad)

        sess = tf.Session()
        sess.run(tf.variables_initializer([w]))
        grads = sess.run(grads)
        first = grads[0]
        for grad in grads[1:]:
            if not np.allclose(first, grad):
                first_string = str(first)
                grad_string = str(grad)
                raise RuntimeError("Something is wrong with your GPUs or GPU driver."
                                   "%(num_devices)d different GPUS were asked to "
                                   "calculate the same 2x2 gradient. One returned "
                                   "%(first_string)s and another returned "
                                   "%(grad_string)s. This can usually be fixed by "
                                   "rebooting the machine." %
                                   {"num_devices": num_devices,
                                    "first_string": first_string,
                                    "grad_string": grad_string})
        sess.close()
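
`run_canary` throttles itself through a module-level `last_run` timestamp and relies on `time`, `numpy`, `tensorflow` and `infer_devices` being available in its module. A minimal sketch of the surrounding module state it assumes (the reset helper below is purely illustrative and not part of the library):

import time

import numpy as np
import tensorflow as tf

from cleverhans.utils_tf import infer_devices

# Module-level timestamp consulted by run_canary; None means "never run yet".
last_run = None


def reset_canary_timer():
    """Illustrative helper: force the canary to run again on its next call."""
    global last_run
    last_run = None
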
Example #2
def train(sess, loss, x_train, y_train,
          init_all=False, evaluate=None, feed=None, args=None,
          rng=None, var_list=None, fprop_args=None, optimizer=None,
          devices=None, x_batch_preprocessor=None, use_ema=False,
          ema_decay=.998, run_canary=None,
          loss_threshold=1e5, dataset_train=None, dataset_size=None):
  """
  Run (optionally multi-replica, synchronous) training to minimize `loss`
  :param sess: TF session to use when training the graph
  :param loss: tensor, the loss to minimize
  :param x_train: numpy array with training inputs or tf Dataset
  :param y_train: numpy array with training outputs or tf Dataset
  :param init_all: (boolean) If set to true, all TF variables in the session
                   are (re)initialized, otherwise only previously
                   uninitialized variables are initialized before training.
  :param evaluate: function that is run after each training iteration
                   (typically to display the test/validation accuracy).
  :param feed: An optional dictionary that is appended to the feeding
               dictionary before the session runs. Can be used to feed
               the learning phase of a Keras model for instance.
  :param args: dict or argparse `Namespace` object.
               Should contain `nb_epochs`, `learning_rate`,
               `batch_size`
  :param rng: Instance of numpy.random.RandomState
  :param var_list: Optional list of parameters to train.
  :param fprop_args: dict, extra arguments to pass to fprop (loss and model).
  :param optimizer: Optimizer to be used for training
  :param devices: list of device names to use for training
      If None, defaults to: all GPUs, if GPUs are available
                            all devices, if no GPUs are available
  :param x_batch_preprocessor: callable
      Takes a single tensor containing an x_train batch as input
      Returns a single tensor containing an x_train batch as output
      Called to preprocess the data before passing the data to the Loss
  :param use_ema: bool
      If true, uses an exponential moving average of the model parameters
  :param ema_decay: float or callable
      The decay parameter for EMA, if EMA is used
      If a callable rather than a float, this is a callable that takes
      the epoch and batch as arguments and returns the ema_decay for
      the current batch.
  :param loss_threshold: float
      Raise an exception if the loss exceeds this value.
      This is intended to rapidly detect numerical problems.
      Sometimes the loss may legitimately be higher than this value. In
      such cases, raise the value. If needed it can be np.inf.
  :param dataset_train: tf Dataset instance.
      Used as a replacement for x_train, y_train for faster performance.
  :param dataset_size: integer, the size of the dataset_train.
  :return: True if model trained
  """

  # Check whether the hardware is working correctly
  canary.run_canary()
  if run_canary is not None:
    warnings.warn("The `run_canary` argument is deprecated. The canary "
                  "is now much cheaper and thus runs all the time. The "
                  "canary now uses its own loss function so it is not "
                  "necessary to turn off the canary when training with "
                  " a stochastic loss. Simply quit passing `run_canary`."
                  "Passing `run_canary` may become an error on or after "
                  "2019-10-16.")

  args = _ArgsWrapper(args or {})
  fprop_args = fprop_args or {}

  # Check that necessary arguments were given (see doc above)
  # Be sure to support 0 epochs for debugging purposes
  if args.nb_epochs is None:
    raise ValueError("`args` must specify number of epochs")
  if optimizer is None:
    if args.learning_rate is None:
      raise ValueError("Learning rate was not given in args dict")
  assert args.batch_size, "Batch size was not given in args dict"

  if rng is None:
    rng = np.random.RandomState()

  if optimizer is None:
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
  else:
    if not isinstance(optimizer, tf.train.Optimizer):
      raise ValueError("optimizer object must be from a child class of "
                       "tf.train.Optimizer")

  grads = []
  xs = []
  preprocessed_xs = []
  ys = []
  if dataset_train is not None:
    assert x_train is None and y_train is None and x_batch_preprocessor is None
    if dataset_size is None:
      raise ValueError("You must provide a dataset size")
    data_iterator = dataset_train.make_one_shot_iterator().get_next()
    x_train, y_train = sess.run(data_iterator)

  devices = infer_devices(devices)
  for device in devices:
    with tf.device(device):
      # x = tf.placeholder(x_train.dtype, (None,) + x_train.shape[1:])
      # y = tf.placeholder(y_train.dtype, (None,) + y_train.shape[1:])
      x = tf.placeholder(tf.float32, (None,) + x_train.shape[1:])
      y = tf.placeholder(tf.float32, (None,) + y_train.shape[1:])
      xs.append(x)
      ys.append(y)

      if x_batch_preprocessor is not None:
        x = x_batch_preprocessor(x)

      # We need to keep track of these so that the canary can feed
      # preprocessed values. If the canary had to feed raw values,
      # stochastic preprocessing could make the canary fail.
      preprocessed_xs.append(x)

      loss_value = loss.fprop(x, y, **fprop_args)
      print("loss_value", loss_value)
      grads.append(optimizer.compute_gradients(
          loss_value, var_list=var_list))
      print("grads:", grads)
  num_devices = len(devices)
  print("num_devices: ", num_devices)

  grad = avg_grads(grads)
  # Trigger update operations within the default graph (such as batch_norm).
  with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_step = optimizer.apply_gradients(grad)

  epoch_tf = tf.placeholder(tf.int32, [])
  batch_tf = tf.placeholder(tf.int32, [])

  if use_ema:
    if callable(ema_decay):
      ema_decay = ema_decay(epoch_tf, batch_tf)
    ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
    with tf.control_dependencies([train_step]):
      train_step = ema.apply(var_list)
    # Get pointers to the EMA's running average variables
    avg_params = [ema.average(param) for param in var_list]
    # Make temporary buffers used for swapping the live and running average
    # parameters
    tmp_params = [tf.Variable(param, trainable=False)
                  for param in var_list]
    # Define the swapping operation
    param_to_tmp = [tf.assign(tmp, param)
                    for tmp, param in safe_zip(tmp_params, var_list)]
    with tf.control_dependencies(param_to_tmp):
      avg_to_param = [tf.assign(param, avg)
                      for param, avg in safe_zip(var_list, avg_params)]
    with tf.control_dependencies(avg_to_param):
      tmp_to_avg = [tf.assign(avg, tmp)
                    for avg, tmp in safe_zip(avg_params, tmp_params)]
    swap = tmp_to_avg

  batch_size = args.batch_size

  assert batch_size % num_devices == 0
  device_batch_size = batch_size // num_devices

  if init_all:
    sess.run(tf.global_variables_initializer())
  else:
    initialize_uninitialized_global_variables(sess)

  for epoch in xrange(args.nb_epochs):
    if dataset_train is not None:
      nb_batches = int(math.ceil(float(dataset_size) / batch_size))
    else:
      # Indices to shuffle training set
      index_shuf = list(range(len(x_train)))
      # Randomly repeat a few training examples each epoch to avoid
      # having a too-small batch
      while len(index_shuf) % batch_size != 0:
        index_shuf.append(rng.randint(len(x_train)))
      nb_batches = len(index_shuf) // batch_size
      rng.shuffle(index_shuf)
      # Shuffling here versus inside the loop doesn't seem to affect
      # timing very much, but shuffling here makes the code slightly
      # easier to read
      x_train_shuffled = x_train[index_shuf]
      y_train_shuffled = y_train[index_shuf]

    prev = time.time()
    for batch in range(nb_batches):
      if dataset_train is not None:
        x_train_shuffled, y_train_shuffled = sess.run(data_iterator)
        start, end = 0, batch_size
      else:
        # Compute batch start and end indices
        start = batch * batch_size
        end = (batch + 1) * batch_size
        # Perform one training step
        diff = end - start
        assert diff == batch_size

      feed_dict = {epoch_tf: epoch, batch_tf: batch}
      for dev_idx in xrange(num_devices):
        cur_start = start + dev_idx * device_batch_size
        cur_end = start + (dev_idx + 1) * device_batch_size
        feed_dict[xs[dev_idx]] = x_train_shuffled[cur_start:cur_end]
        feed_dict[ys[dev_idx]] = y_train_shuffled[cur_start:cur_end]
      if cur_end != end and dataset_train is None:
        msg = ("batch_size (%d) must be a multiple of num_devices "
               "(%d).\nCUDA_VISIBLE_DEVICES: %s"
               "\ndevices: %s")
        args = (batch_size, num_devices,
                os.environ['CUDA_VISIBLE_DEVICES'],
                str(devices))
        raise ValueError(msg % args)
      if feed is not None:
        feed_dict.update(feed)

      _, loss_numpy = sess.run([train_step, loss_value], feed_dict=feed_dict)

      if np.abs(loss_numpy) > loss_threshold:
        raise ValueError("Extreme loss during training: ", loss_numpy)
      if np.isnan(loss_numpy) or np.isinf(loss_numpy):
        raise ValueError("NaN/Inf loss during training")
    assert (dataset_train is not None or end == len(index_shuf))  # Check that all examples were used
    cur = time.time()
    _logger.info("Epoch " + str(epoch) + " took " + str(cur - prev) + " seconds")
    print("loss:", loss_numpy)
    if evaluate is not None:
      if use_ema:
        # Before running evaluation, load the running average
        # parameters into the live slot, so we can see how well
        # the EMA parameters are performing
        sess.run(swap)
      evaluate()
      if use_ema:
        # Swap the parameters back, so that we continue training
        # on the live parameters
        sess.run(swap)
  if use_ema:
    # When training is done, swap the running average parameters into
    # the live slot, so that we use them when we deploy the model
    sess.run(swap)

  return True
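
A hedged usage sketch of the `train` call above; `loss_obj`, `x_np` and `y_np` are stand-ins for a cleverhans Loss instance and numpy training arrays (not defined here), and only the `args` keys named in the docstring are assumed:

# batch_size should be a multiple of the number of devices train will use.
train_params = {"nb_epochs": 6, "learning_rate": 1e-3, "batch_size": 128}
sess = tf.Session()
train(sess, loss_obj, x_np, y_np,
      args=train_params,
      rng=np.random.RandomState([2019, 1, 1]),
      use_ema=True, ema_decay=0.998)
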
Example #3
def batch_eval_multi_worker(sess,
                            graph_factory,
                            numpy_inputs,
                            batch_size=None,
                            devices=None,
                            feed=None):
    """
  Generic computation engine for evaluating an expression across a whole
  dataset, divided into batches.

  This function assumes that the work can be parallelized with one worker
  device handling one batch of data. If you need multiple devices per
  batch, use `batch_eval`.

  The tensorflow graph for multiple workers is large, so the first few
  runs of the graph will be very slow. If you expect to run the graph
  only a few times (a few calls to `batch_eval_multi_worker` that each run
  a few batches), the startup cost might dominate the runtime, and it might be
  preferable to use the single worker `batch_eval` just because its
  startup cost will be lower.

  :param sess: tensorflow Session
  :param graph_factory: callable
      When called, returns (tf_inputs, tf_outputs) where:
          tf_inputs is a list of placeholders to feed from the dataset
          tf_outputs is a list of tf tensors to calculate
      Example: tf_inputs is [x, y] placeholders, tf_outputs is [accuracy].
      This factory must make new tensors when called, rather than, e.g.
      handing out a reference to existing tensors.
      This factory must make exactly equivalent expressions every time
      it is called, otherwise the results of `batch_eval` will vary
      depending on how work is distributed to devices.
      This factory must respect "with tf.device()" context managers
      that are active when it is called, otherwise work will not be
      distributed to devices correctly.
  :param numpy_inputs:
      A list of numpy arrays defining the dataset to be evaluated.
      The list should have the same length as tf_inputs.
      Each array should have the same number of examples (shape[0]).
      Example: numpy_inputs is [MNIST().x_test, MNIST().y_test]
  :param batch_size: Number of examples to use in a single evaluation batch.
      If not specified, this function will use a reasonable guess and
      may run out of memory.
      When choosing the batch size, keep in mind that the batch will
      be divided up evenly among available devices. If you can fit 128
      examples in memory on one GPU and you have 8 GPUs, you probably
      want to use a batch size of 1024 (unless a different batch size
      runs faster with the ops you are using, etc.)
  :param devices: List of devices to run on. If unspecified, uses all
      available GPUs if any GPUs are available, otherwise uses CPUs.
  :param feed: An optional dictionary that is appended to the feeding
           dictionary before the session runs. Can be used to feed
           the learning phase of a Keras model for instance.
  :returns: List of numpy arrays corresponding to the outputs produced by
      the graph_factory
  """
    global _batch_eval_multi_worker_cache

    devices = infer_devices(devices)

    if batch_size is None:
        # For big models this might result in OOM and then the user
        # should just specify batch_size
        batch_size = len(devices) * DEFAULT_EXAMPLES_PER_DEVICE

    n = len(numpy_inputs)
    assert n > 0
    m = numpy_inputs[0].shape[0]
    for i in range(1, n):
        assert numpy_inputs[i].shape[0] == m
    out = []

    replicated_tf_inputs = []
    replicated_tf_outputs = []
    p = None

    num_devices = len(devices)
    assert batch_size % num_devices == 0
    device_batch_size = batch_size // num_devices

    cache_key = (graph_factory, tuple(devices))
    if cache_key in _batch_eval_multi_worker_cache:
        # Retrieve graph for multi-GPU inference from cache.
        # This avoids adding tf ops to the graph
        packed = _batch_eval_multi_worker_cache[cache_key]
        replicated_tf_inputs, replicated_tf_outputs = packed
        p = len(replicated_tf_outputs[0])
        assert p > 0
    else:
        # This graph has not been built before.
        # Build it now.

        for device in devices:
            with tf.device(device):
                tf_inputs, tf_outputs = graph_factory()
                assert len(tf_inputs) == n
                if p is None:
                    p = len(tf_outputs)
                    assert p > 0
                else:
                    assert len(tf_outputs) == p
                replicated_tf_inputs.append(tf_inputs)
                replicated_tf_outputs.append(tf_outputs)
        del tf_inputs
        del tf_outputs
        # Store the result in the cache
        packed = replicated_tf_inputs, replicated_tf_outputs
        _batch_eval_multi_worker_cache[cache_key] = packed
    for _ in range(p):
        out.append([])
    flat_tf_outputs = []
    for output in range(p):
        for dev_idx in range(num_devices):
            flat_tf_outputs.append(replicated_tf_outputs[dev_idx][output])

    # pad data to have # examples be multiple of batch size
    # we discard the excess later
    num_batches = int(np.ceil(float(m) / batch_size))
    needed_m = num_batches * batch_size
    excess = needed_m - m
    if excess > m:
        raise NotImplementedError("Your batch size is bigger than the"
                                  " dataset, this function is probably"
                                  " overkill.")

    def pad(array):
        """Pads an array with replicated examples to have `excess` more entries"""
        if excess > 0:
            array = np.concatenate((array, array[:excess]), axis=0)
        return array

    numpy_inputs = [pad(numpy_input) for numpy_input in numpy_inputs]
    orig_m = m
    m = needed_m

    for start in range(0, m, batch_size):
        batch = start // batch_size
        if batch % 100 == 0 and batch > 0:
            _logger.debug("Batch " + str(batch))

        # Compute batch start and end indices
        end = start + batch_size
        numpy_input_batches = [
            numpy_input[start:end] for numpy_input in numpy_inputs
        ]
        feed_dict = {}
        for dev_idx, tf_inputs in enumerate(replicated_tf_inputs):
            for tf_input, numpy_input in zip(tf_inputs, numpy_input_batches):
                dev_start = dev_idx * device_batch_size
                dev_end = (dev_idx + 1) * device_batch_size
                value = numpy_input[dev_start:dev_end]
                assert value.shape[0] == device_batch_size
                feed_dict[tf_input] = value
        if feed is not None:
            feed_dict.update(feed)
        flat_output_batches = sess.run(flat_tf_outputs, feed_dict=feed_dict)
        for e in flat_output_batches:
            assert e.shape[0] == device_batch_size, e.shape

        output_batches = []
        for output in range(p):
            o_start = output * num_devices
            o_end = (output + 1) * num_devices
            device_values = flat_output_batches[o_start:o_end]
            assert len(device_values) == num_devices
            output_batches.append(device_values)

        for out_elem, device_values in zip(out, output_batches):
            assert len(device_values) == num_devices, (len(device_values),
                                                       num_devices)
            for device_value in device_values:
                assert device_value.shape[0] == device_batch_size
            out_elem.extend(device_values)

    out = [np.concatenate(x, axis=0) for x in out]
    for e in out:
        assert e.shape[0] == m, e.shape

    # Trim off the examples we used to pad up to batch size
    out = [e[:orig_m] for e in out]
    assert len(out) == p, (len(out), p)

    return out
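
To make the `graph_factory` contract above concrete, here is a hedged sketch. It assumes `model` is a cleverhans-style Model object exposing `get_logits`; everything else is illustrative:

def make_accuracy_factory(model, x_shape, nb_classes):
    """Illustrative: builds a graph_factory for batch_eval_multi_worker."""
    def graph_factory():
        # Build *new* tensors on every call, as the contract requires.
        x = tf.placeholder(tf.float32, (None,) + tuple(x_shape))
        y = tf.placeholder(tf.float32, (None, nb_classes))
        logits = model.get_logits(x)
        correct = tf.equal(tf.argmax(logits, axis=-1), tf.argmax(y, axis=-1))
        # Return a per-example tensor so shape[0] matches the device batch.
        return [x, y], [tf.to_float(correct)]
    return graph_factory

# Example call (x_test / y_test are assumed numpy arrays):
# acc_per_example, = batch_eval_multi_worker(
#     sess, make_accuracy_factory(model, (28, 28, 1), 10),
#     [x_test, y_test], batch_size=128)
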
Example #4
from __future__ import unicode_literals

import logging
import time

import tensorflow as tf
from tensorflow.python.platform import flags

from cleverhans.attacks import ProjectedGradientDescent, Semantic
from cleverhans.evaluation import accuracy
from cleverhans.serial import load
from cleverhans.utils import set_log_level
from cleverhans.utils_tf import infer_devices
from cleverhans.utils_tf import silence
silence()
devices = infer_devices()
num_devices = len(devices)
BATCH_SIZE = 128
TRAIN_START = 0
TRAIN_END = 60000
TEST_START = 0
TEST_END = 10000
WHICH_SET = 'test'
NB_ITER = 40
BASE_EPS_ITER = None  # Differs by dataset

FLAGS = flags.FLAGS


def print_accuracies(filepath,
                     train_start=TRAIN_START,
Example #5
def train(sess,
          loss,
          x_train,
          y_train,
          init_all=True,
          evaluate=None,
          feed=None,
          args=None,
          rng=None,
          var_list=None,
          fprop_args=None,
          optimizer=None,
          devices=None,
          x_batch_preprocessor=None,
          use_ema=False,
          ema_decay=.998,
          run_canary=True,
          loss_threshold=1e5):
    """
  Run (optionally multi-replica, synchronous) training to minimize `loss`
  :param sess: TF session to use when training the graph
  :param loss: tensor, the loss to minimize
  :param x_train: numpy array with training inputs
  :param y_train: numpy array with training outputs
  :param init_all: (boolean) If set to true, all TF variables in the session
                   are (re)initialized, otherwise only previously
                   uninitialized variables are initialized before training.
  :param evaluate: function that is run after each training iteration
                   (typically to display the test/validation accuracy).
  :param feed: An optional dictionary that is appended to the feeding
               dictionary before the session runs. Can be used to feed
               the learning phase of a Keras model for instance.
  :param args: dict or argparse `Namespace` object.
               Should contain `nb_epochs`, `learning_rate`,
               `batch_size`
  :param rng: Instance of numpy.random.RandomState
  :param var_list: Optional list of parameters to train.
  :param fprop_args: dict, extra arguments to pass to fprop (loss and model).
  :param optimizer: Optimizer to be used for training
  :param devices: list of device names to use for training
      If None, defaults to: all GPUs, if GPUs are available
                            all devices, if no GPUs are available
  :param x_batch_preprocessor: callable
      Takes a single tensor containing an x_train batch as input
      Returns a single tensor containing an x_train batch as output
      Called to preprocess the data before passing the data to the Loss
  :param use_ema: bool
      If true, uses an exponential moving average of the model parameters
  :param ema_decay: float or callable
      The decay parameter for EMA, if EMA is used
      If a callable rather than a float, this is a callable that takes
      the epoch and batch as arguments and returns the ema_decay for
      the current batch.
  :param run_canary: bool
      If True and using 3 or more GPUs, runs some canary code that should
      fail if there is a multi-GPU driver problem.
      Turn this off if your gradients are inherently stochastic (e.g.
      if you use dropout). The canary code checks that all GPUs give
      approximately the same gradient.
  :param loss_threshold: float
      Raise an exception if the loss exceeds this value.
      This is intended to rapidly detect numerical problems.
      Sometimes the loss may legitimately be higher than this value. In
      such cases, raise the value. If needed it can be np.inf.
  :return: True if model trained
  """
    args = _ArgsWrapper(args or {})
    fprop_args = fprop_args or {}

    # Check that necessary arguments were given (see doc above)
    assert args.nb_epochs, "Number of epochs was not given in args dict"
    if optimizer is None:
        if args.learning_rate is None:
            raise ValueError("Learning rate was not given in args dict")
    assert args.batch_size, "Batch size was not given in args dict"

    if rng is None:
        rng = np.random.RandomState()

    if optimizer is None:
        optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    else:
        if not isinstance(optimizer, tf.train.Optimizer):
            raise ValueError("optimizer object must be from a child class of "
                             "tf.train.Optimizer")

    grads = []
    xs = []
    preprocessed_xs = []
    ys = []

    devices = infer_devices(devices)
    for idx, device in enumerate(devices):
        with tf.device(device):
            x = tf.placeholder(x_train.dtype, (None, ) + x_train.shape[1:])
            y = tf.placeholder(y_train.dtype, (None, ) + y_train.shape[1:])
            xs.append(x)
            ys.append(y)

            if x_batch_preprocessor is not None:
                x = x_batch_preprocessor(x)

            # We need to keep track of these so that the canary can feed
            # preprocessed values. If the canary had to feed raw values,
            # stochastic preprocessing could make the canary fail.
            preprocessed_xs.append(x)

            loss_value = loss.fprop(x, y, **fprop_args)

            grads.append(
                optimizer.compute_gradients(loss_value, var_list=var_list))
    num_devices = len(devices)
    print("num_devices: ", num_devices)

    grad = avg_grads(grads)
    # Trigger update operations within the default graph (such as batch_norm).
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_step = optimizer.apply_gradients(grad)

    epoch_tf = tf.placeholder(tf.int32, [])
    batch_tf = tf.placeholder(tf.int32, [])

    if use_ema:
        if callable(ema_decay):
            ema_decay = ema_decay(epoch_tf, batch_tf)
        ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
        with tf.control_dependencies([train_step]):
            train_step = ema.apply(var_list)
        # Get pointers to the EMA's running average variables
        avg_params = [ema.average(param) for param in var_list]
        # Make temporary buffers used for swapping the live and running average
        # parameters
        tmp_params = [
            tf.Variable(param, trainable=False) for param in var_list
        ]
        # Define the swapping operation
        param_to_tmp = [
            tf.assign(tmp, param)
            for tmp, param in safe_zip(tmp_params, var_list)
        ]
        with tf.control_dependencies(param_to_tmp):
            avg_to_param = [
                tf.assign(param, avg)
                for param, avg in safe_zip(var_list, avg_params)
            ]
        with tf.control_dependencies(avg_to_param):
            tmp_to_avg = [
                tf.assign(avg, tmp)
                for avg, tmp in safe_zip(avg_params, tmp_params)
            ]
        swap = tmp_to_avg

    batch_size = args.batch_size

    assert batch_size % num_devices == 0
    device_batch_size = batch_size // num_devices

    if init_all:
        sess.run(tf.global_variables_initializer())
    else:
        initialize_uninitialized_global_variables(sess)

    # Check whether the hardware is working correctly

    # So far the failure has only been observed with 3 or more GPUs
    run_canary = run_canary and num_devices > 2
    if run_canary:
        canary_feed_dict = {}
        for x, y in safe_zip(preprocessed_xs, ys):
            canary_feed_dict[x] = x_train[:device_batch_size].copy()
            canary_feed_dict[y] = y_train[:device_batch_size].copy()
        # To reduce the runtime and memory cost of this canary,
        # we test the gradient of only one parameter.
        # For now this is just set to the first parameter in the list,
        # because it is an index that is always guaranteed to work.
        # If we think that this is causing false negatives and we should
        # test other parameters, we could test a random parameter from
        # the list or we could rewrite the canary to examine more than
        # one parameter.
        param_to_test = 0
        grad_vars = []
        for i in xrange(num_devices):
            dev_grads = grads[i]
            grad_vars.append(dev_grads[param_to_test][0])
        grad_values = sess.run(grad_vars, feed_dict=canary_feed_dict)
        failed = False
        for i in xrange(1, num_devices):
            if grad_values[0].shape != grad_values[i].shape:
                print("shape 0 does not match shape %d:" % i,
                      grad_values[0].shape, grad_values[i].shape)
                failed = True
                continue
            if not np.allclose(grad_values[0], grad_values[i], atol=1e-6):
                print("grad_values[0]: ", grad_values[0].mean(),
                      grad_values[0].max())
                print("grad_values[%d]: " % i, grad_values[i].mean(),
                      grad_values[i].max())
                print("max diff: ",
                      np.abs(grad_values[0] - grad_values[1]).max())
                failed = True
        if failed:
            print("Canary failed.")
            quit()

    for epoch in xrange(args.nb_epochs):
        # Indices to shuffle training set
        index_shuf = list(range(len(x_train)))
        # Randomly repeat a few training examples each epoch to avoid
        # having a too-small batch
        while len(index_shuf) % batch_size != 0:
            index_shuf.append(rng.randint(len(x_train)))
        nb_batches = len(index_shuf) // batch_size
        rng.shuffle(index_shuf)
        # Shuffling here versus inside the loop doesn't seem to affect
        # timing very much, but shuffling here makes the code slightly
        # easier to read
        x_train_shuffled = x_train[index_shuf]
        y_train_shuffled = y_train[index_shuf]

        prev = time.time()
        for batch in range(nb_batches):

            # Compute batch start and end indices
            start = batch * batch_size
            end = (batch + 1) * batch_size

            # Perform one training step
            feed_dict = {epoch_tf: epoch, batch_tf: batch}
            diff = end - start
            assert diff == batch_size
            for dev_idx in xrange(num_devices):
                cur_start = start + dev_idx * device_batch_size
                cur_end = start + (dev_idx + 1) * device_batch_size
                feed_dict[xs[dev_idx]] = x_train_shuffled[cur_start:cur_end]
                feed_dict[ys[dev_idx]] = y_train_shuffled[cur_start:cur_end]
            if cur_end != end:
                msg = ("batch_size (%d) must be a multiple of num_devices "
                       "(%d).\nCUDA_VISIBLE_DEVICES: %s"
                       "\ndevices: %s")
                args = (batch_size, num_devices,
                        os.environ['CUDA_VISIBLE_DEVICES'], str(devices))
                raise ValueError(msg % args)
            if feed is not None:
                feed_dict.update(feed)

            _, loss_numpy = sess.run([train_step, loss_value],
                                     feed_dict=feed_dict)

            if np.abs(loss_numpy) > loss_threshold:
                raise ValueError("Extreme loss during training: ", loss_numpy)
            if np.isnan(loss_numpy) or np.isinf(loss_numpy):
                raise ValueError("NaN/Inf loss during training")
        assert end == len(index_shuf)  # Check that all examples were used
        cur = time.time()
        _logger.info("Epoch " + str(epoch) + " took " + str(cur - prev) +
                     " seconds")
        if evaluate is not None:
            if use_ema:
                # Before running evaluation, load the running average
                # parameters into the live slot, so we can see how well
                # the EMA parameters are performing
                sess.run(swap)
            evaluate()
            if use_ema:
                # Swap the parameters back, so that we continue training
                # on the live parameters
                sess.run(swap)
    if use_ema:
        # When training is done, swap the running average parameters into
        # the live slot, so that we use them when we deploy the model
        sess.run(swap)

    return True
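
The `ema_decay` argument above may be a callable taking the symbolic `epoch` and `batch` placeholders and returning the decay to use. A hedged sketch of a warm-up schedule (purely illustrative; it ignores the batch index):

def ema_decay_schedule(epoch_tf, batch_tf):
    """Illustrative: ramp EMA decay from 0.9 to 0.998 over the first 5 epochs."""
    progress = tf.minimum(tf.cast(epoch_tf, tf.float32) / 5., 1.)
    return 0.9 + progress * (0.998 - 0.9)

# Passed as: train(..., use_ema=True, ema_decay=ema_decay_schedule)
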
Example #6
def train_ae(sess,
             loss,
             x_train,
             x_train_target,
             init_all=False,
             evaluate=None,
             feed=None,
             args=None,
             rng=None,
             var_list=None,
             fprop_args=None,
             optimizer=None,
             devices=None,
             x_batch_preprocessor=None,
             use_ema=False,
             ema_decay=.998,
             run_canary=None,
             loss_threshold=1e5,
             dataset_train=None,
             dataset_size=None):
    # Check whether the hardware is working correctly
    start_time = time.time()
    canary.run_canary()
    if run_canary is not None:
        warnings.warn("The `run_canary` argument is deprecated. The canary "
                      "is now much cheaper and thus runs all the time. The "
                      "canary now uses its own loss function so it is not "
                      "necessary to turn off the canary when training with "
                      " a stochastic loss. Simply quit passing `run_canary`."
                      "Passing `run_canary` may become an error on or after "
                      "2019-10-16.")

    args = _ArgsWrapper(args or {})
    fprop_args = fprop_args or {}

    # Check that necessary arguments were given (see doc above)
    # Be sure to support 0 epochs for debugging purposes
    if args.nb_epochs is None:
        raise ValueError("`args` must specify number of epochs")
    if optimizer is None:
        if args.learning_rate is None:
            raise ValueError("Learning rate was not given in args dict")
    assert args.batch_size, "Batch size was not given in args dict"

    if rng is None:
        rng = np.random.RandomState()

    if optimizer is None:
        optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    else:
        if not isinstance(optimizer, tf.train.Optimizer):
            raise ValueError("optimizer object must be from a child class of "
                             "tf.train.Optimizer")

    grads = []
    xs = []
    xs_t = []
    preprocessed_xs = []
    preprocessed_xs_t = []
    #ys = []
    if dataset_train is not None:
        assert x_train is None and x_batch_preprocessor is None
        if dataset_size is None:
            raise ValueError("You must provide a dataset size")
        data_iterator = dataset_train.make_one_shot_iterator().get_next()
        x_train, x_train_target = sess.run(data_iterator)

    devices = infer_devices(devices)
    for device in devices:
        with tf.device(device):
            x = tf.placeholder(x_train.dtype, (None, ) + x_train.shape[1:])
            x_t = tf.placeholder(x_train_target.dtype,
                                 (None, ) + x_train_target.shape[1:])
            #y = tf.placeholder(y_train.dtype, (None,) + y_train.shape[1:])
            xs.append(x)
            xs_t.append(x_t)
            #ys.append(y)

            if x_batch_preprocessor is not None:
                x = x_batch_preprocessor(x)
                x_t = x_batch_preprocessor(x_t)

            # We need to keep track of these so that the canary can feed
            # preprocessed values. If the canary had to feed raw values,
            # stochastic preprocessing could make the canary fail.
            preprocessed_xs.append(x)
            preprocessed_xs_t.append(x_t)

            loss_value = loss.fprop(x, x_t, **fprop_args)

            grads.append(
                optimizer.compute_gradients(loss_value, var_list=var_list))
    num_devices = len(devices)
    print("num_devices: ", num_devices)

    grad = avg_grads(grads)
    # Trigger update operations within the default graph (such as batch_norm).
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_step = optimizer.apply_gradients(grad)

    epoch_tf = tf.placeholder(tf.int32, [])
    batch_tf = tf.placeholder(tf.int32, [])

    if use_ema:
        if callable(ema_decay):
            ema_decay = ema_decay(epoch_tf, batch_tf)
        ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
        with tf.control_dependencies([train_step]):
            train_step = ema.apply(var_list)
        # Get pointers to the EMA's running average variables
        avg_params = [ema.average(param) for param in var_list]
        # Make temporary buffers used for swapping the live and running average
        # parameters
        tmp_params = [
            tf.Variable(param, trainable=False) for param in var_list
        ]
        # Define the swapping operation
        param_to_tmp = [
            tf.assign(tmp, param)
            for tmp, param in safe_zip(tmp_params, var_list)
        ]
        with tf.control_dependencies(param_to_tmp):
            avg_to_param = [
                tf.assign(param, avg)
                for param, avg in safe_zip(var_list, avg_params)
            ]
        with tf.control_dependencies(avg_to_param):
            tmp_to_avg = [
                tf.assign(avg, tmp)
                for avg, tmp in safe_zip(avg_params, tmp_params)
            ]
        swap = tmp_to_avg

    batch_size = args.batch_size

    assert batch_size % num_devices == 0
    device_batch_size = batch_size // num_devices

    if init_all:
        sess.run(tf.global_variables_initializer())
    else:
        initialize_uninitialized_global_variables(sess)

    for epoch in xrange(args.nb_epochs):
        if dataset_train is not None:
            nb_batches = int(math.ceil(float(dataset_size) / batch_size))
        else:
            # Indices to shuffle training set
            index_shuf = list(range(len(x_train)))
            # Randomly repeat a few training examples each epoch to avoid
            # having a too-small batch
            while len(index_shuf) % batch_size != 0:
                index_shuf.append(rng.randint(len(x_train)))
            nb_batches = len(index_shuf) // batch_size
            rng.shuffle(index_shuf)
            # Shuffling here versus inside the loop doesn't seem to affect
            # timing very much, but shuffling here makes the code slightly
            # easier to read
            x_train_shuffled = x_train[index_shuf]
            x_train_target_shuffled = x_train_target[index_shuf]
            #y_train_shuffled = y_train[index_shuf]

        prev = time.time()
        for batch in range(nb_batches):
            if dataset_train is not None:
                x_train_shuffled, x_train_target_shuffled = sess.run(
                    data_iterator)
                start, end = 0, batch_size
            else:
                # Compute batch start and end indices
                start = batch * batch_size
                end = (batch + 1) * batch_size
                # Perform one training step
                diff = end - start
                assert diff == batch_size

            feed_dict = {epoch_tf: epoch, batch_tf: batch}
            for dev_idx in xrange(num_devices):
                cur_start = start + dev_idx * device_batch_size
                cur_end = start + (dev_idx + 1) * device_batch_size
                feed_dict[xs[dev_idx]] = x_train_shuffled[cur_start:cur_end]
                feed_dict[
                    xs_t[dev_idx]] = x_train_target_shuffled[cur_start:cur_end]
                #feed_dict[ys[dev_idx]] = y_train_shuffled[cur_start:cur_end]
            if cur_end != end and dataset_train is None:
                msg = ("batch_size (%d) must be a multiple of num_devices "
                       "(%d).\nCUDA_VISIBLE_DEVICES: %s"
                       "\ndevices: %s")
                args = (batch_size, num_devices,
                        os.environ['CUDA_VISIBLE_DEVICES'], str(devices))
                raise ValueError(msg % args)
            if feed is not None:
                feed_dict.update(feed)

            _, loss_numpy = sess.run([train_step, loss_value],
                                     feed_dict=feed_dict)

            if np.abs(loss_numpy) > loss_threshold:
                raise ValueError("Extreme loss during training: ", loss_numpy)
            if np.isnan(loss_numpy) or np.isinf(loss_numpy):
                raise ValueError("NaN/Inf loss during training")
        assert (dataset_train is not None
                or end == len(index_shuf))  # Check that all examples were used
        cur = time.time()
        _logger.info("Epoch " + str(epoch) + " took " + str(cur - prev) +
                     " seconds")
        if evaluate is not None:
            if use_ema:
                # Before running evaluation, load the running average
                # parameters into the live slot, so we can see how well
                # the EMA parameters are performing
                sess.run(swap)
            evaluate()
            if use_ema:
                # Swap the parameters back, so that we continue training
                # on the live parameters
                sess.run(swap)
    if use_ema:
        # When training is done, swap the running average parameters into
        # the live slot, so that we use them when we deploy the model
        sess.run(swap)
    end_time = time.time()
    print("Time taken for training: ", end_time - start_time)
    return True
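
`train_ae` can consume a tf Dataset instead of numpy arrays. A minimal sketch, assuming `x_noisy` and `x_clean` are already-loaded numpy arrays (names illustrative) and that `sess` and `loss` are set up as in the examples above; the dataset yields (input, target) pairs matching the `sess.run(data_iterator)` unpacking in the function:

batch_size = 128  # must be a multiple of the number of devices
dataset = tf.data.Dataset.from_tensor_slices((x_noisy, x_clean))
dataset = dataset.shuffle(10000).repeat().batch(batch_size)

train_ae(sess, loss, x_train=None, x_train_target=None,
         args={"nb_epochs": 10, "learning_rate": 1e-3,
               "batch_size": batch_size},
         dataset_train=dataset, dataset_size=len(x_clean))
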
Example #7
File: train.py Project: ATPGN/ATPGN
def train_with_PGN(sess, model, loss, train_type='naive', evaluate=None, args=None,
          rng=None, classifier_var_list=None, generator_var_list=None, save_dir=None,
          fprop_args=None, optimizer=None, use_ema=False, ema_decay=.998,
          loss_threshold=1e10, dataset_train=None, dataset_size=None):
  """
  Run (optionally multi-replica, synchronous) training to minimize `loss`
  :param sess: TF session to use when training the graph
  :param loss: tensor, the loss to minimize
  :param evaluate: function that is run after each training iteration
                   (typically to display the test/validation accuracy).
  :param args: dict or argparse `Namespace` object.
               Should contain `nb_epochs`, `learning_rate`,
               `batch_size`
  :param rng: Instance of numpy.random.RandomState
  :param classifier_var_list: Optional list of classifier parameters to train.
  :param generator_var_list: Optional list of generator parameters to train
      (used when train_type is 'PGN').
  :param fprop_args: dict, extra arguments to pass to fprop (loss and model).
  :param optimizer: Optimizer to be used for training
  :param use_ema: bool
      If true, uses an exponential moving average of the model parameters
  :param ema_decay: float or callable
      The decay parameter for EMA, if EMA is used
      If a callable rather than a float, this is a callable that takes
      the epoch and batch as arguments and returns the ema_decay for
      the current batch.
  :param loss_threshold: float
      Raise an exception if the loss exceeds this value.
      This is intended to rapidly detect numerical problems.
      Sometimes the loss may legitimately be higher than this value. In
      such cases, raise the value. If needed it can be np.inf.
  :param dataset_train: tf Dataset instance.
      Used as a replacement for x_train, y_train for faster performance.
  :param dataset_size: integer, the size of the dataset_train.
  :return: True if model trained
  """

  # Check whether the hardware is working correctly
  canary.run_canary()
  args = _ArgsWrapper(args or {})
  fprop_args = fprop_args or {}

  # Check that necessary arguments were given (see doc above)
  # Be sure to support 0 epochs for debugging purposes
  if args.nb_epochs is None:
    raise ValueError("`args` must specify number of epochs")
  if optimizer is None:
    if args.learning_rate is None:
      raise ValueError("Learning rate was not given in args dict")
  assert args.batch_size, "Batch size was not given in args dict"
  assert dataset_train and dataset_size, "dataset_train or dataset_size was not given"

  if rng is None:
    rng = np.random.RandomState()

  if optimizer is None:
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
  else:
    if not isinstance(optimizer, tf.train.Optimizer):
      raise ValueError("optimizer object must be from a child class of "
                       "tf.train.Optimizer")

  grads_classifier = []
  if train_type == 'PGN':
    grads_generator = []
  xs = []
  ys = []
  data_iterator = dataset_train.make_one_shot_iterator().get_next()
  x_train, y_train = sess.run(data_iterator)

  devices = infer_devices()
  for device in devices:
    with tf.device(device):
      x = tf.placeholder(x_train.dtype, (None,) + x_train.shape[1:])
      y = tf.placeholder(y_train.dtype, (None,) + y_train.shape[1:])
      xs.append(x)
      ys.append(y)
      if train_type == 'PGN':
        loss_classifier, loss_generator = loss.fprop(x, y, **fprop_args)
      else:
        loss_classifier = loss.fprop(x, y, **fprop_args)
      grads_classifier.append(optimizer.compute_gradients(loss_classifier, var_list=classifier_var_list))
      if train_type == 'PGN':
        grads_generator.append(optimizer.compute_gradients(loss_generator, var_list=generator_var_list))

  num_devices = len(devices)
  print("num_devices: ", num_devices)

  grad_classifier = avg_grads(grads_classifier)
  if train_type == 'PGN':
    grad_generator = avg_grads(grads_generator)
  # Trigger update operations within the default graph (such as batch_norm).
  with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_step = optimizer.apply_gradients(grad_classifier)
    if train_type == 'PGN':
      with tf.control_dependencies([train_step]):
        train_step = optimizer.apply_gradients(grad_generator)

  var_list = classifier_var_list
  if train_type == 'PGN':
    var_list += generator_var_list
  if use_ema:
    ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
    with tf.control_dependencies([train_step]):
      train_step = ema.apply(var_list)
    # Get pointers to the EMA's running average variables
    avg_params = [ema.average(param) for param in var_list]
    # Make temporary buffers used for swapping the live and running average
    # parameters
    tmp_params = [tf.Variable(param, trainable=False)
                  for param in var_list]
    # Define the swapping operation
    param_to_tmp = [tf.assign(tmp, param)
                    for tmp, param in safe_zip(tmp_params, var_list)]
    with tf.control_dependencies(param_to_tmp):
      avg_to_param = [tf.assign(param, avg)
                      for param, avg in safe_zip(var_list, avg_params)]
    with tf.control_dependencies(avg_to_param):
      tmp_to_avg = [tf.assign(avg, tmp)
                    for avg, tmp in safe_zip(avg_params, tmp_params)]
    swap = tmp_to_avg

  batch_size = args.batch_size

  assert batch_size % num_devices == 0
  device_batch_size = batch_size // num_devices

  sess.run(tf.global_variables_initializer())
  best_acc = 0.0

  for epoch in xrange(args.nb_epochs):
    nb_batches = int(math.ceil(float(dataset_size) / batch_size))
    prev = time.time()
    for batch in range(nb_batches):
      x_train_shuffled, y_train_shuffled = sess.run(data_iterator)
      start, end = 0, batch_size
      feed_dict = dict()
      for dev_idx in xrange(num_devices):
        cur_start = start + dev_idx * device_batch_size
        cur_end = start + (dev_idx + 1) * device_batch_size
        feed_dict[xs[dev_idx]] = x_train_shuffled[cur_start:cur_end]
        feed_dict[ys[dev_idx]] = y_train_shuffled[cur_start:cur_end]

      
      _, loss_classifier_numpy = sess.run([train_step, loss_classifier], feed_dict=feed_dict)

      if np.abs(loss_classifier_numpy) > loss_threshold:
        raise ValueError("Extreme loss_classifier during training: ", loss_classifier_numpy)
      if np.isnan(loss_classifier_numpy) or np.isinf(loss_classifier_numpy):
        raise ValueError("NaN/Inf loss_classifier during training")
    cur = time.time()
    _logger.info("Epoch " + str(epoch) + " took " +
                 str(cur - prev) + " seconds")
    if evaluate is not None:
      if use_ema:
        sess.run(swap)
      r_value = evaluate(epoch)

      if use_ema:
        sess.run(swap)
  if use_ema:
    sess.run(swap)

  with sess.as_default():
    save_path = os.path.join(save_dir,'model.joblib')
    save(save_path, model)

  return True
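
All of these training loops combine per-device gradients with `avg_grads`. A simplified sketch of what such a helper typically does (not necessarily the library's exact implementation): it takes one `compute_gradients` result per device and returns a single averaged list of (grad, var) pairs.

def avg_grads_sketch(tower_grads):
  """Average gradients across devices.

  tower_grads: list with one entry per device, each a list of (grad, var)
  pairs as returned by optimizer.compute_gradients on that device.
  Assumes every device produced a gradient for every variable.
  """
  averaged = []
  for grad_and_vars in zip(*tower_grads):
    # grad_and_vars pairs up the same variable across devices.
    grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
    mean_grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
    averaged.append((mean_grad, grad_and_vars[0][1]))
  return averaged
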
Example #8
def train(sess, loss, x_train, y_train,
          init_all=True, evaluate=None, feed=None, args=None,
          rng=None, var_list=None, fprop_args=None, optimizer=None,
          devices=None, x_batch_preprocessor=None):
    """
    Run (optionally multi-replica, synchronous) training to minimize `loss`
    :param sess: TF session to use when training the graph
    :param loss: tensor, the loss to minimize
    :param x_train: numpy array with training inputs
    :param y_train: numpy array with training outputs
    :param init_all: (boolean) If set to true, all TF variables in the session
                     are (re)initialized, otherwise only previously
                     uninitialized variables are initialized before training.
    :param evaluate: function that is run after each training iteration
                     (typically to display the test/validation accuracy).
    :param feed: An optional dictionary that is appended to the feeding
                 dictionary before the session runs. Can be used to feed
                 the learning phase of a Keras model for instance.
    :param args: dict or argparse `Namespace` object.
                 Should contain `nb_epochs`, `learning_rate`,
                 `batch_size`
    :param rng: Instance of numpy.random.RandomState
    :param var_list: Optional list of parameters to train.
    :param fprop_args: dict, extra arguments to pass to fprop (loss and model).
    :param optimizer: Optimizer to be used for training
    :param devices: list of device names to use for training
        If None, defaults to: all GPUs, if GPUs are available
                              all devices, if no GPUs are available
    :param x_batch_preprocessor: callable
        Takes a single tensor containing an x_train batch as input
        Returns a single tensor containing an x_train batch as output
        Called to preprocess the data before passing the data to the Loss
    :return: True if model trained
    """
    args = _ArgsWrapper(args or {})
    fprop_args = fprop_args or {}

    # Check that necessary arguments were given (see doc above)
    assert args.nb_epochs, "Number of epochs was not given in args dict"
    if optimizer is None:
        if args.learning_rate is None:
            raise ValueError("Learning rate was not given in args dict")
    assert args.batch_size, "Batch size was not given in args dict"

    if rng is None:
        rng = np.random.RandomState()

    if optimizer is None:
        optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    else:
        if not isinstance(optimizer, tf.train.Optimizer):
            raise ValueError("optimizer object must be from a child class of "
                             "tf.train.Optimizer")

    grads = []
    xs = []
    preprocessed_xs = []
    ys = []

    devices = infer_devices(devices)
    for idx, device in enumerate(devices):
        with tf.device(device):
            x = tf.placeholder(x_train.dtype, (None,) + x_train.shape[1:])
            y = tf.placeholder(y_train.dtype, (None,) + y_train.shape[1:])
            xs.append(x)
            ys.append(y)

            if x_batch_preprocessor is not None:
                x = x_batch_preprocessor(x)

            # We need to keep track of these so that the canary can feed
            # preprocessed values. If the canary had to feed raw values,
            # stochastic preprocessing could make the canary fail.
            preprocessed_xs.append(x)

            loss_value = loss.fprop(x, y, **fprop_args)

            grads.append(optimizer.compute_gradients(
                loss_value, var_list=var_list))
    num_devices = len(devices)
    print("num_devices: ", num_devices)

    grad = avg_grads(grads)
    # Trigger update operations within the default graph (such as batch_norm).
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_step = optimizer.apply_gradients(grad)

    batch_size = args.batch_size

    assert batch_size % num_devices == 0
    device_batch_size = batch_size // num_devices

    if init_all:
        sess.run(tf.global_variables_initializer())
    else:
        initialize_uninitialized_global_variables(sess)

    # Check whether the hardware is working correctly

    # So far the failure has only been observed with 3 or more GPUs
    run_canary = num_devices > 2
    if run_canary:
        canary_feed_dict = {}
        for x, y in safe_zip(preprocessed_xs, ys):
            canary_feed_dict[x] = x_train[:device_batch_size].copy()
            canary_feed_dict[y] = y_train[:device_batch_size].copy()
        # To reduce the runtime and memory cost of this canary,
        # we test the gradient of only one parameter.
        # For now this is just set to the first parameter in the list,
        # because it is an index that is always guaranteed to work.
        # If we think that this is causing false negatives and we should
        # test other parameters, we could test a random parameter from
        # the list or we could rewrite the canary to examine more than
        # one parameter.
        param_to_test = 0
        grad_vars = []
        for i in xrange(num_devices):
            dev_grads = grads[i]
            grad_vars.append(dev_grads[param_to_test][0])
        grad_values = sess.run(grad_vars, feed_dict=canary_feed_dict)
        failed = False
        for i in xrange(1, num_devices):
            if grad_values[0].shape != grad_values[i].shape:
                print("shape 0 does not match shape %d:" % i,
                      grad_values[0].shape, grad_values[i].shape)
                failed = True
                continue
            if not np.allclose(grad_values[0], grad_values[i], atol=1e-6):
                print("grad_values[0]: ",
                      grad_values[0].mean(), grad_values[0].max())
                print("grad_values[%d]: " %
                      i, grad_values[i].mean(), grad_values[i].max())
                print("max diff: ", np.abs(
                    grad_values[0] - grad_values[1]).max())
                failed = True
        if failed:
            print("Canary failed.")
            quit()

    for epoch in xrange(args.nb_epochs):
        # Indices to shuffle training set
        index_shuf = list(range(len(x_train)))
        # Randomly repeat a few training examples each epoch to avoid
        # having a too-small batch
        while len(index_shuf) % batch_size != 0:
            index_shuf.append(rng.randint(len(x_train)))
        nb_batches = len(index_shuf) // batch_size
        rng.shuffle(index_shuf)
        # Shuffling here versus inside the loop doesn't seem to affect
        # timing very much, but shuffling here makes the code slightly
        # easier to read
        x_train_shuffled = x_train[index_shuf]
        y_train_shuffled = y_train[index_shuf]

        prev = time.time()
        for batch in range(nb_batches):

            # Compute batch start and end indices
            start = batch * batch_size
            end = (batch + 1) * batch_size
            # start, end = batch_indices(
            #    batch, len(x_train), args.batch_size)

            # Perform one training step
            feed_dict = {}
            diff = end - start
            assert diff == batch_size
            for dev_idx in xrange(num_devices):
                cur_start = start + dev_idx * device_batch_size
                cur_end = start + (dev_idx + 1) * device_batch_size
                feed_dict[xs[dev_idx]
                          ] = x_train_shuffled[cur_start:cur_end]
                feed_dict[ys[dev_idx]
                          ] = y_train_shuffled[cur_start:cur_end]
            if cur_end != end:
                msg = ("batch_size (%d) must be a multiple of num_devices "
                       "(%d).\nCUDA_VISIBLE_DEVICES: %s"
                       "\ndevices: %s")
                args = (batch_size, num_devices,
                        os.environ['CUDA_VISIBLE_DEVICES'],
                        str(devices))
                raise ValueError(msg % args)
            if feed is not None:
                feed_dict.update(feed)
            sess.run(train_step, feed_dict=feed_dict)
        assert end == len(index_shuf)  # Check that all examples were used
        cur = time.time()
        _logger.info("Epoch " + str(epoch) + " took " +
                     str(cur - prev) + " seconds")
        if evaluate is not None:
            evaluate()

    return True
Example #9
0
def train(sess,
          loss,
          x_train,
          y_train,
          init_all=True,
          evaluate=None,
          feed=None,
          args=None,
          rng=None,
          var_list=None,
          fprop_args=None,
          optimizer=None,
          devices=None,
          x_batch_preprocessor=None):
    """
    Run (optionally multi-replica, synchronous) training to minimize `loss`
    :param sess: TF session to use when training the graph
    :param loss: tensor, the loss to minimize
    :param x_train: numpy array with training inputs
    :param y_train: numpy array with training outputs
    :param init_all: (boolean) If set to true, all TF variables in the session
                     are (re)initialized, otherwise only previously
                     uninitialized variables are initialized before training.
    :param evaluate: function that is run after each training iteration
                     (typically to display the test/validation accuracy).
    :param feed: An optional dictionary that is appended to the feeding
                 dictionary before the session runs. Can be used to feed
                 the learning phase of a Keras model for instance.
    :param args: dict or argparse `Namespace` object.
                 Should contain `nb_epochs`, `learning_rate`,
                 `batch_size`
    :param rng: Instance of numpy.random.RandomState
    :param var_list: Optional list of parameters to train.
    :param fprop_args: dict, extra arguments to pass to fprop (loss and model).
    :param optimizer: Optimizer to be used for training
    :param devices: list of device names to use for training
        If None, defaults to: all GPUs, if GPUs are available
                              all devices, if no GPUs are available
    :param x_batch_preprocessor: callable
        Takes a single tensor containing an x_train batch as input
        Returns a single tensor containing an x_train batch as output
        Called to preprocess the data before passing the data to the Loss
    :return: True if model trained
    """
    args = _ArgsWrapper(args or {})
    fprop_args = fprop_args or {}

    # Check that necessary arguments were given (see doc above)
    assert args.nb_epochs, "Number of epochs was not given in args dict"
    if optimizer is None:
        if args.learning_rate is None:
            raise ValueError("Learning rate was not given in args dict")
    assert args.batch_size, "Batch size was not given in args dict"

    if rng is None:
        rng = np.random.RandomState()

    if optimizer is None:
        optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    else:
        if not isinstance(optimizer, tf.train.Optimizer):
            raise ValueError("optimizer object must be from a child class of "
                             "tf.train.Optimizer")

    grads = []
    xs = []
    ys = []

    devices = infer_devices(devices)
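    # Build one replica of the forward pass (and its gradients) per device,
    # each fed from its own pair of input placeholders.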
    for idx, device in enumerate(devices):
        with tf.device(device):
            x = tf.placeholder(x_train.dtype, (None, ) + x_train.shape[1:])
            y = tf.placeholder(y_train.dtype, (None, ) + y_train.shape[1:])
            xs.append(x)
            ys.append(y)

            if x_batch_preprocessor is not None:
                x = x_batch_preprocessor(x)
            loss_value = loss.fprop(x, y, **fprop_args)

            grads.append(
                optimizer.compute_gradients(loss_value, var_list=var_list))
    num_devices = len(devices)
    print("num_devices: ", num_devices)
    grad = avg_grads(grads)
    # Trigger update operations within the default graph (such as batch_norm).
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_step = optimizer.apply_gradients(grad)

    batch_size = args.batch_size

    with sess.as_default():
        if init_all:
            sess.run(tf.global_variables_initializer())
        else:
            initialize_uninitialized_global_variables(sess)

        for epoch in xrange(args.nb_epochs):
            # Indices to shuffle training set
            index_shuf = list(range(len(x_train)))
            # Randomly repeat a few training examples each epoch to avoid
            # having a too-small batch
            while len(index_shuf) % batch_size != 0:
                index_shuf.append(rng.randint(len(x_train)))
            nb_batches = len(index_shuf) // batch_size
            rng.shuffle(index_shuf)
            # Shuffling here versus inside the loop doesn't seem to affect
            # timing very much, but shuffling here makes the code slightly
            # easier to read
            x_train_shuffled = x_train[index_shuf]
            y_train_shuffled = y_train[index_shuf]

            prev = time.time()
            for batch in range(nb_batches):

                # Compute batch start and end indices
                start = batch * batch_size
                end = (batch + 1) * batch_size
                # start, end = batch_indices(
                #    batch, len(x_train), args.batch_size)

                # Perform one training step
                feed_dict = {}
                diff = end - start
                assert diff == batch_size
                stride = diff // num_devices
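                # Each device receives a contiguous slice of the batch,
                # e.g. batch_size 128 on 4 devices gives 32 examples per device.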
                for dev_idx in xrange(num_devices):
                    cur_start = start + dev_idx * stride
                    cur_end = start + (dev_idx + 1) * stride
                    feed_dict[xs[dev_idx]] = \
                        x_train_shuffled[cur_start:cur_end]
                    feed_dict[ys[dev_idx]] = \
                        y_train_shuffled[cur_start:cur_end]
                if cur_end != end:
                    msg = ("batch_size (%d) must be a multiple of num_devices "
                           "(%d).\nCUDA_VISIBLE_DEVICES: %s"
                           "\ndevices: %s")
                    args = (batch_size, num_devices,
                            os.environ['CUDA_VISIBLE_DEVICES'], str(devices))
                    raise ValueError(msg % args)
                if feed is not None:
                    feed_dict.update(feed)
                sess.run(train_step, feed_dict=feed_dict)
            assert end == len(index_shuf)  # Check that all examples were used
            cur = time.time()
            _logger.info("Epoch " + str(epoch) + " took " + str(cur - prev) +
                         " seconds")
            if evaluate is not None:
                evaluate()

    return True
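
For context, the sketch below shows roughly how this variant might be invoked. It is a minimal, illustrative example only: `_SimpleLoss`, the toy model inside it, and the random MNIST-shaped arrays are hypothetical stand-ins introduced here (in practice a real CleverHans `Loss` object such as `CrossEntropy(model)` would be passed); only the `train(...)` call itself follows the signature documented above, and it assumes TensorFlow 1.x plus the helpers this example already imports.

import numpy as np
import tensorflow as tf


class _SimpleLoss(object):
    """Hypothetical stand-in for a CleverHans Loss object.

    train() only requires an fprop(x, y, **kwargs) method that returns a
    loss tensor for the given input and label placeholders.
    """

    def fprop(self, x, y, **kwargs):
        # Reuse the same weights for every per-device replica that train()
        # builds, so all devices compute gradients of shared variables.
        with tf.variable_scope("toy_model", reuse=tf.AUTO_REUSE):
            logits = tf.layers.dense(tf.layers.flatten(x), 10, name="fc")
        return tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,
                                                       logits=logits))


# Illustrative random data with MNIST-like shapes.
x_train = np.random.rand(1024, 28, 28, 1).astype("float32")
y_train = np.eye(10)[np.random.randint(0, 10, 1024)].astype("float32")

sess = tf.Session()
# Note: batch_size must be a multiple of the number of devices that
# infer_devices() finds, otherwise the ValueError above is raised.
train(sess, _SimpleLoss(), x_train, y_train,
      args={"nb_epochs": 2, "batch_size": 128, "learning_rate": 1e-3},
      rng=np.random.RandomState([2017, 8, 30]))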