Example #1
def softmax_preds(images, ckpt_path, return_logits=False):
    """
  Compute softmax activations (probabilities) with the model saved in the path
  specified as an argument
  :param images: a np array of images
  :param ckpt_path: a TF model checkpoint
  :param logits: if set to True, return logits instead of probabilities
  :return: probabilities (or logits if logits is set to True)
  """
    # Compute nb samples and deduce nb of batches
    data_length = len(images)
    nb_batches = math.ceil(len(images) / FLAGS.batch_size)

    # Declare data placeholder
    train_data_node = _input_placeholder()

    # Build a Graph that computes the logits predictions from the placeholder
    if FLAGS.deeper:
        logits = inference_deeper(train_data_node)
    elif FLAGS.dataset == 'adult':
        logits = inference_adult(train_data_node)
    else:
        logits = inference(train_data_node)

    if return_logits:
        # We are returning the logits directly (no need to apply softmax)
        output = logits
    else:
        # Add softmax predictions to graph: will return probabilities
        output = tf.nn.softmax(logits)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    # Will hold the result
    preds = np.zeros((data_length, FLAGS.nb_labels), dtype=np.float32)

    # Create TF session
    with tf.Session() as sess:
        # Restore TF session from checkpoint file
        saver.restore(sess, ckpt_path)

        # Parse data by batch
        for batch_nb in xrange(0, int(nb_batches + 1)):
            # Compute batch start and end indices
            start, end = utils.batch_indices(batch_nb, data_length,
                                             FLAGS.batch_size)

            # Prepare feed dictionary
            feed_dict = {train_data_node: images[start:end]}

            # Run session ([0] because run returns a batch with len 1st dim == 1)
            preds[start:end, :] = sess.run([output], feed_dict=feed_dict)[0]

    # Reset graph to allow multiple calls
    tf.reset_default_graph()

    return preds
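
# A minimal usage sketch for the example above (hedged: it assumes the module's
# FLAGS are configured and a checkpoint was written by the matching train()
# function; `test_images` and `ckpt` are hypothetical names):
#
#   probs = softmax_preds(test_images, ckpt)                         # class probabilities
#   logits = softmax_preds(test_images, ckpt, return_logits=True)    # raw logits
#   predicted_labels = np.argmax(probs, axis=1)
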
def softmax_preds(images, ckpt_path, return_logits=False):
  """
  Compute softmax activations (probabilities) with the model saved in the path
  specified as an argument
  :param images: a np array of images
  :param ckpt_path: a TF model checkpoint
  :param return_logits: if set to True, return logits instead of probabilities
  :return: probabilities (or logits if return_logits is set to True)
  """
  # Compute nb samples and deduce nb of batches
  data_length = len(images)
  nb_batches = math.ceil(len(images) / FLAGS.batch_size)

  # Declare data placeholder
  train_data_node = _input_placeholder()

  # Build a Graph that computes the logits predictions from the placeholder
  if FLAGS.deeper:
    logits = inference_deeper(train_data_node)
  else:
    logits = inference(train_data_node)

  if return_logits:
    # We are returning the logits directly (no need to apply softmax)
    output = logits
  else:
    # Add softmax predictions to graph: will return probabilities
    output = tf.nn.softmax(logits)

  # Restore the moving average version of the learned variables for eval.
  variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
  variables_to_restore = variable_averages.variables_to_restore()
  saver = tf.train.Saver(variables_to_restore)

  # Will hold the result
  preds = np.zeros((data_length, FLAGS.nb_labels), dtype=np.float32)

  # Create TF session
  with tf.Session() as sess:
    # Restore TF session from checkpoint file
    saver.restore(sess, ckpt_path)

    # Parse data by batch
    for batch_nb in xrange(0, int(nb_batches+1)):
      # Compute batch start and end indices
      start, end = utils.batch_indices(batch_nb, data_length, FLAGS.batch_size)

      # Prepare feed dictionary
      feed_dict = {train_data_node: images[start:end]}

      # Run session ([0] because run returns a batch with len 1st dim == 1)
      preds[start:end, :] = sess.run([output], feed_dict=feed_dict)[0]

  # Reset graph to allow multiple calls
  tf.reset_default_graph()

  return preds
def model_train(sess,
                x,
                y,
                predictions,
                X_train,
                Y_train,
                save=False,
                predictions_adv=None,
                evaluate=None,
                verbose=True,
                args=None):
    """
    Train a TF graph
    :param sess: TF session to use when training the graph
    :param x: input placeholder
    :param y: output placeholder (for labels)
    :param predictions: model output predictions
    :param X_train: numpy array with training inputs
    :param Y_train: numpy array with training outputs
    :param save: boolean controlling the save operation
    :param predictions_adv: if set with the adversarial example tensor,
                            will run adversarial training
    :param evaluate: function that is run after each training epoch
                     (typically to display the test/validation accuracy)
    :param verbose: boolean controlling whether progress is printed
    :param args: dict or argparse `Namespace` object.
                 Should contain `nb_epochs`, `learning_rate`,
                 `batch_size`
                 If save is True, should also contain 'train_dir'
                 and 'filename'
    :return: True if model trained
    """
    args = _FlagsWrapper(args or {})

    # Check that necessary arguments were given (see doc above)
    assert args.nb_epochs, "Number of epochs was not given in args dict"
    assert args.learning_rate, "Learning rate was not given in args dict"
    assert args.batch_size, "Batch size was not given in args dict"

    if save:
        assert args.train_dir, "Directory for save was not given in args dict"
        assert args.filename, "Filename for save was not given in args dict"

    # Define loss
    loss = model_loss(y, predictions)
    if predictions_adv is not None:
        p = 1.0
        loss = ((1 - p) * loss + p * model_loss(y, predictions_adv))

    train_step = tf.train.AdamOptimizer(
        learning_rate=args.learning_rate).minimize(loss)

    with sess.as_default():
        if hasattr(tf, "global_variables_initializer"):
            tf.global_variables_initializer().run()
        else:
            sess.run(tf.initialize_all_variables())

        for epoch in six.moves.xrange(args.nb_epochs):
            if verbose:
                print("Epoch " + str(epoch))

            # Compute number of batches
            nb_batches = int(math.ceil(float(len(X_train)) / args.batch_size))
            assert nb_batches * args.batch_size >= len(X_train)

            prev = time.time()
            for batch in range(nb_batches):

                # Compute batch start and end indices
                start, end = batch_indices(batch, len(X_train),
                                           args.batch_size)

                # Perform one training step
                train_step.run(feed_dict={
                    x: X_train[start:end],
                    y: Y_train[start:end]
                })
            assert end >= len(X_train)  # Check that all examples were used
            cur = time.time()
            if verbose:
                print("\tEpoch took " + str(cur - prev) + " seconds")
            prev = cur
            if evaluate is not None:
                evaluate()

        if save:
            save_path = os.path.join(args.train_dir, args.filename)
            saver = tf.train.Saver()
            saver.save(sess, save_path)
            print("Completed model training and saved at:" + str(save_path))
        else:
            print("Completed model training.")

    return True
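
# Usage sketch (hedged): `args` can be a plain dict; the keys below are exactly
# the ones asserted at the top of model_train, and the values are hypothetical.
#
#   train_params = {'nb_epochs': 6, 'batch_size': 128, 'learning_rate': 0.001}
#   model_train(sess, x, y, predictions, X_train, Y_train,
#               evaluate=evaluate, args=train_params)
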
Example #4
def train(images, labels, ckpt_path, dropout=False):
    """
  This function contains the loop that actually trains the model.
  :param images: a numpy array with the input data
  :param labels: a numpy array with the output labels
  :param ckpt_path: a path (including name) where model checkpoints are saved
  :param dropout: Boolean, whether to use dropout or not
  :return: True if everything went well
  """

    # Check training data
    assert len(images) == len(labels)
    assert images.dtype == np.float32
    assert labels.dtype == np.int32

    # Set default TF graph
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # Declare data placeholder
        train_data_node = _input_placeholder()

        # Create a placeholder to hold labels
        train_labels_shape = (FLAGS.batch_size, )
        train_labels_node = tf.placeholder(tf.int32, shape=train_labels_shape)

        print("Done Initializing Training Placeholders")

        # Build a Graph that computes the logits predictions from the placeholder
        if FLAGS.deeper:
            logits = inference_deeper(train_data_node, dropout=dropout)
        else:
            logits = inference(train_data_node, dropout=dropout)

        # Calculate loss
        loss = loss_fun(logits, train_labels_node)
        tf.summary.scalar('loss', loss)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = train_op_fun(loss, global_step)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        print("Graph constructed and saver created")

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()
        merged = tf.summary.merge_all()

        # Create and init sessions
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))  # NOLINT(long-line)
        train_writer = tf.summary.FileWriter('/data/summary/train', sess.graph)
        test_writer = tf.summary.FileWriter('data/summary/test')
        sess.run(init)

        print("Session ready, beginning training loop")

        # Initialize the number of batches
        data_length = len(images)
        nb_batches = math.ceil(data_length / FLAGS.batch_size)

        for step in xrange(FLAGS.max_steps):
            # for debug, save start time
            start_time = time.time()

            # Current batch number
            batch_nb = step % nb_batches

            # Current batch start and end indices
            start, end = utils.batch_indices(batch_nb, data_length,
                                             FLAGS.batch_size)

            # Prepare dictionary to feed the session with
            feed_dict = {
                train_data_node: images[start:end],
                train_labels_node: labels[start:end]
            }

            # Run training step
            if step % 10 == 0:
                summary, _, loss_value = sess.run([merged, train_op, loss],
                                                  feed_dict=feed_dict)
                train_writer.add_summary(summary, step)
            else:
                _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

            # Compute duration of training step
            duration = time.time() - start_time

            # Sanity check
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            # Echo loss once in a while
            if step % 100 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))

            # Save the model checkpoint periodically.
            if (step + 1) == FLAGS.max_steps:
                saver.save(sess, ckpt_path, global_step=step)

    return True
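
# Usage sketch (hedged): the asserts above require float32 inputs and int32
# labels; the array names and checkpoint path below are hypothetical.
#
#   train(train_images.astype(np.float32), train_labels.astype(np.int32),
#         '/tmp/teacher_model.ckpt', dropout=True)
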
Example #5
def train(images_ori, labels, ckpt_path, dropout=False):
    """
    This function contains the loop that actually trains the model.
    :param images_ori: a numpy array with the input data
    :param labels: a numpy array with the output labels
    :param ckpt_path: a path (including name) where model checkpoints are saved
    :param dropout: Boolean, whether to use dropout or not
    :return: True if everything went well
    """
    images = copy.deepcopy(images_ori)  # deep copy so the original images stay unmodified

    if FLAGS.dataset == 'cifar10':
        images = (images - MEAN) / (STD + 1e-7)  # whitening imgs for training
    else:
        images -= 127.5
    print('start train using %s. images.mean: %.2f' %
          (FLAGS.dataset, np.mean(images)))

    # Check training data
    assert len(images) == len(labels)
    assert images.dtype == np.float32
    assert labels.dtype == np.int32

    global_step = tf.Variable(0, trainable=False)

    # Declare data placeholder
    train_data_node = _input_placeholder()

    # Create a placeholder to hold labels, None means any batch
    train_labels_node = tf.placeholder(tf.int32, shape=(None, ))

    print("Done Initializing Training Placeholders")

    # Build a Graph that computes the logits predictions from the placeholder
    logits = inference(train_data_node, dropout=dropout)

    # Calculate loss
    loss = loss_fun(logits, train_labels_node)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = train_op_fun(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)

    print("Graph constructed and saver created")

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    #    config.log_device_placement=True
    config.gpu_options.allocator_type = 'BFC'
    config.gpu_options.per_process_gpu_memory_fraction = 0.2

    # Create and init sessions
    with tf.Session(config=config) as sess:
        sess.run(init)
        print("Session ready, beginning training loop")

        # Initialize the number of batches
        data_length = len(images)
        nb_batches = math.ceil(data_length / FLAGS.batch_size)  # smallest integer >= data_length / batch_size

        for step in range(FLAGS.max_steps):
            # for debug, save start time
            start_time = time.time()

            # Current batch number
            batch_nb = step % nb_batches

            # Current batch start and end indices
            start, end = utils.batch_indices(batch_nb, data_length,
                                             FLAGS.batch_size)

            # Prepare dictionary to feed the session with
            feed_dict = {
                train_data_node: images[start:end],
                train_labels_node: labels[start:end]
            }

            # Run training step
            _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

            # Compute duration of training step
            duration = time.time() - start_time

            # Sanity check
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            # Echo loss once in a while
            if step % 1000 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value,
                                    examples_per_sec, sec_per_batch))

            # Save the model checkpoint periodically.
            if (step + 1) == FLAGS.max_steps:
                saver.save(sess, ckpt_path, global_step=step)
                print('model is saved at: ', ckpt_path + '-' + str(step), '\n')
                print(time.asctime(time.localtime(time.time())))

    # Reset graph
    tf.reset_default_graph()

    return True
Example #6
def tf_model_train(sess,
                   x,
                   y,
                   predictions,
                   X_train,
                   Y_train,
                   X_test,
                   Y_test,
                   save=False,
                   predictions_adv=None,
                   data_augmentor=None):
    """
    Train a TF graph
    :param sess: TF session to use when training the graph
    :param x: input placeholder
    :param y: output placeholder (for labels)
    :param predictions: model output predictions
    :param X_train: numpy array with training inputs
    :param Y_train: numpy array with training outputs
    :param X_test: numpy array with test inputs
    :param Y_test: numpy array with test labels
    :param save: Boolean controlling the save operation
    :param predictions_adv: if set with the adversarial example tensor,
                            will run adversarial training
    :param data_augmentor: optional function applied to each batch of training
                           inputs before the training step
    :return: True if model trained
    """
    print "Starting model training using TensorFlow."
    learning_rate = tf.placeholder(tf.float32, shape=())

    # Define loss
    loss = tf_model_loss(y, predictions)
    if predictions_adv is not None:
        loss = (loss + tf_model_loss(y, predictions_adv)) / 2

    train_step = tf.train.MomentumOptimizer(learning_rate,
                                            0.9,
                                            use_nesterov=True).minimize(loss)
    # train_step = tf.train.AdamOptimizer().minimize(loss)
    print "Defined optimizer."

    with sess.as_default():
        init = tf.initialize_all_variables()
        sess.run(init)

        for epoch in xrange(FLAGS.nb_epochs):
            prev = time.time()
            print("Epoch %s, lr: %s" % (epoch, _get_learning_rate(epoch)))

            # Compute number of batches
            nb_batches = int(math.ceil(float(len(X_train)) / FLAGS.batch_size))
            assert nb_batches * FLAGS.batch_size >= len(X_train)

            for batch in range(nb_batches):
                # Compute batch start and end indices
                start, end = batch_indices(batch, len(X_train),
                                           FLAGS.batch_size)
                batch_xs = X_train[start:end]
                batch_ys = Y_train[start:end]

                if data_augmentor is not None:
                    batch_xs = data_augmentor(batch_xs)

                # Perform one training step
                train_step.run(
                    feed_dict={
                        x: batch_xs,
                        y: batch_ys,
                        learning_rate: _get_learning_rate(epoch),
                        keras.backend.learning_phase(): 1
                    })
            cur = time.time()
            print("\tTook " + str(cur - prev) + " seconds")

            assert end >= len(X_train)  # Check that all examples were used

            accuracy = tf_model_eval(sess, x, y, predictions, X_test, Y_test)
            assert X_test.shape[0] == 10000, X_test.shape
            print("... Test accuracy on legitimate test examples: " + str(accuracy))

        if save:
            save_path = os.path.join(FLAGS.train_dir, FLAGS.filename)
            saver = tf.train.Saver()
            saver.save(sess, save_path)
            print "Completed model training and model saved at:" + str(
                save_path)
        else:
            print "Completed model training."

    return True
Example #7
def train(sess,
          loss,
          x,
          y,
          X_train,
          Y_train,
          save=False,
          init_all=False,
          evaluate=None,
          feed=None,
          args=None,
          rng=None,
          var_list=None,
          fprop_args=None,
          optimizer=None):
    """
  Train a TF graph.
  This function is deprecated. Prefer cleverhans.train.train when possible.
  cleverhans.train.train supports multiple GPUs but this function is still
  needed to support legacy models that do not support calling fprop more
  than once.
  :param sess: TF session to use when training the graph
  :param loss: tensor, the model training loss.
  :param x: input placeholder
  :param y: output placeholder (for labels)
  :param X_train: numpy array with training inputs
  :param Y_train: numpy array with training outputs
  :param save: boolean controlling the save operation
  :param init_all: (boolean) If set to true, all TF variables in the session
                   are (re)initialized, otherwise only previously
                   uninitialized variables are initialized before training.
  :param evaluate: function that is run after each training iteration
                   (typically to display the test/validation accuracy).
  :param feed: An optional dictionary that is appended to the feeding
               dictionary before the session runs. Can be used to feed
               the learning phase of a Keras model for instance.
  :param args: dict or argparse `Namespace` object.
               Should contain `nb_epochs`, `learning_rate`,
               `batch_size`
               If save is True, should also contain 'train_dir'
               and 'filename'
  :param rng: Instance of numpy.random.RandomState
  :param var_list: Optional list of parameters to train.
  :param fprop_args: dict, extra arguments to pass to fprop (loss and model).
  :param optimizer: Optimizer to be used for training
  :return: True if model trained
  """
    warnings.warn("This function is deprecated and will be removed on or after"
                  " 2019-04-05. Switch to cleverhans.train.train.")

    args = _ArgsWrapper(args or {})
    fprop_args = fprop_args or {}

    # Check that necessary arguments were given (see doc above)
    assert args.nb_epochs, "Number of epochs was not given in args dict"
    if optimizer is None:
        assert args.learning_rate is not None, ("Learning rate was not given "
                                                "in args dict")
    assert args.batch_size, "Batch size was not given in args dict"

    if save:
        assert args.train_dir, "Directory for save was not given in args dict"
        assert args.filename, "Filename for save was not given in args dict"

    if rng is None:
        rng = np.random.RandomState()

    # Define optimizer
    loss_value = loss.fprop(x, y, **fprop_args)
    if optimizer is None:
        optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    else:
        if not isinstance(optimizer, tf.train.Optimizer):
            raise ValueError("optimizer object must be from a child class of "
                             "tf.train.Optimizer")
    # Trigger update operations within the default graph (such as batch_norm).
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_step = optimizer.minimize(loss_value, var_list=var_list)

    with sess.as_default():
        if hasattr(tf, "global_variables_initializer"):
            if init_all:
                tf.global_variables_initializer().run()
            else:
                initialize_uninitialized_global_variables(sess)
        else:
            warnings.warn("Update your copy of tensorflow; future versions of "
                          "CleverHans may drop support for this version.")
            sess.run(tf.initialize_all_variables())

        for epoch in xrange(args.nb_epochs):
            # Compute number of batches
            nb_batches = int(math.ceil(float(len(X_train)) / args.batch_size))
            assert nb_batches * args.batch_size >= len(X_train)

            # Indices to shuffle training set
            index_shuf = list(range(len(X_train)))
            rng.shuffle(index_shuf)

            prev = time.time()
            for batch in range(nb_batches):

                # Compute batch start and end indices
                start, end = batch_indices(batch, len(X_train),
                                           args.batch_size)

                # Perform one training step
                feed_dict = {
                    x: X_train[index_shuf[start:end]],
                    y: Y_train[index_shuf[start:end]]
                }
                if feed is not None:
                    feed_dict.update(feed)
                train_step.run(feed_dict=feed_dict)
            assert end >= len(X_train)  # Check that all examples were used
            cur = time.time()
            _logger.info("Epoch " + str(epoch) + " took " + str(cur - prev) +
                         " seconds")
            if evaluate is not None:
                evaluate()

        if save:
            save_path = os.path.join(args.train_dir, args.filename)
            saver = tf.train.Saver()
            saver.save(sess, save_path)
            _logger.info("Completed model training and saved at: " +
                         str(save_path))
        else:
            _logger.info("Completed model training.")

    return True
Example #8
def model_train(sess,
                x,
                y,
                predictions,
                X_train,
                Y_train,
                save=False,
                predictions_adv=None,
                init_all=True,
                evaluate=None,
                feed=None,
                args=None,
                rng=None,
                var_list=None):
    """
  Train a TF graph
  :param sess: TF session to use when training the graph
  :param x: input placeholder
  :param y: output placeholder (for labels)
  :param predictions: model output predictions
  :param X_train: numpy array with training inputs
  :param Y_train: numpy array with training outputs
  :param save: boolean controlling the save operation
  :param predictions_adv: if set with the adversarial example tensor,
                          will run adversarial training
  :param init_all: (boolean) If set to true, all TF variables in the session
                   are (re)initialized, otherwise only previously
                   uninitialized variables are initialized before training.
  :param evaluate: function that is run after each training iteration
                   (typically to display the test/validation accuracy).
  :param feed: An optional dictionary that is appended to the feeding
               dictionary before the session runs. Can be used to feed
               the learning phase of a Keras model for instance.
  :param args: dict or argparse `Namespace` object.
               Should contain `nb_epochs`, `learning_rate`,
               `batch_size`
               If save is True, should also contain 'train_dir'
               and 'filename'
  :param rng: Instance of numpy.random.RandomState
  :param var_list: Optional list of parameters to train.
  :return: True if model trained
  """
    warnings.warn("This function is deprecated and will be removed on or after"
                  " 2019-04-05. Switch to cleverhans.train.train.")
    args = _ArgsWrapper(args or {})

    # Check that necessary arguments were given (see doc above)
    assert args.nb_epochs, "Number of epochs was not given in args dict"
    assert args.learning_rate, "Learning rate was not given in args dict"
    assert args.batch_size, "Batch size was not given in args dict"

    if save:
        assert args.train_dir, "Directory for save was not given in args dict"
        assert args.filename, "Filename for save was not given in args dict"

    if rng is None:
        rng = np.random.RandomState()

    # Define loss
    loss = model_loss(y, predictions)
    if predictions_adv is not None:
        loss = (loss + model_loss(y, predictions_adv)) / 2

    train_step = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    train_step = train_step.minimize(loss, var_list=var_list)

    with sess.as_default():
        if hasattr(tf, "global_variables_initializer"):
            if init_all:
                tf.global_variables_initializer().run()
            else:
                initialize_uninitialized_global_variables(sess)
        else:
            warnings.warn("Update your copy of tensorflow; future versions of "
                          "CleverHans may drop support for this version.")
            sess.run(tf.initialize_all_variables())

        for epoch in xrange(args.nb_epochs):
            # Compute number of batches
            nb_batches = int(math.ceil(float(len(X_train)) / args.batch_size))
            assert nb_batches * args.batch_size >= len(X_train)

            # Indices to shuffle training set
            index_shuf = list(range(len(X_train)))
            rng.shuffle(index_shuf)

            prev = time.time()
            for batch in range(nb_batches):

                # Compute batch start and end indices
                start, end = batch_indices(batch, len(X_train),
                                           args.batch_size)

                # Perform one training step
                feed_dict = {
                    x: X_train[index_shuf[start:end]],
                    y: Y_train[index_shuf[start:end]]
                }
                if feed is not None:
                    feed_dict.update(feed)
                train_step.run(feed_dict=feed_dict)
            assert end >= len(X_train)  # Check that all examples were used
            cur = time.time()
            _logger.info("Epoch " + str(epoch) + " took " + str(cur - prev) +
                         " seconds")
            if evaluate is not None:
                evaluate()

        if save:
            save_path = os.path.join(args.train_dir, args.filename)
            saver = tf.train.Saver()
            saver.save(sess, save_path)
            _logger.info("Completed model training and saved at: " +
                         str(save_path))
        else:
            _logger.info("Completed model training.")

    return True
Example #9
def softmax_preds(images_ori, ckpt_final, return_logits=False):
    """
    Compute softmax activations (probabilities) with the model saved in the path
    specified as an argument
    :param images_ori: a np array of images
    :param ckpt_final: a TF model checkpoint
    :param return_logits: if set to True, return logits instead of probabilities
    :return: probabilities (or logits if return_logits is set to True)
    """
    images = copy.deepcopy(images_ori)
    if FLAGS.dataset == 'cifar10':
        images = (images - MEAN) / (STD + 1e-7)  # whitening imgs for training
    else:
        images -= 127.5
    if len(images.shape) == 3:  # x.shape is (28, 28, 1) or (32, 32, 3)
        images = np.expand_dims(images, axis=0)
    #print('start pred using %s. images.mean: %.2f' % (FLAGS.dataset, np.mean(images)))

    # Compute nb samples and deduce nb of batches
    data_length = len(images)
    nb_batches = math.ceil(len(images) / FLAGS.batch_size)

    # Declare data placeholder
    train_data_node = _input_placeholder()

    # Build a Graph that computes the logits predictions from the placeholder
    logits = inference(train_data_node)

    if return_logits:
        # We are returning the logits directly (no need to apply softmax)
        output = logits
    else:
        # Add softmax predictions to graph: will return probabilities
        output = tf.nn.softmax(logits)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    # Will hold the result
    preds = np.zeros((data_length, FLAGS.nb_labels), dtype=np.float32)

    # Create TF session
    with tf.Session() as sess:
        # Restore TF session from checkpoint file
        saver.restore(sess, ckpt_final)
        # print('model is restored at: ', ckpt_final,'\n')

        # Parse data by batch
        for batch_nb in range(0, int(nb_batches + 1)):
            # Compute batch start and end indices
            start, end = utils.batch_indices(batch_nb, data_length,
                                             FLAGS.batch_size)

            # Prepare feed dictionary
            feed_dict = {train_data_node: images[start:end]}

            # Run session; index with [0] because sess.run returns a list with
            # a single element when the fetches are wrapped in a list
            preds[start:end, :] = sess.run([output], feed_dict=feed_dict)[0]

    # Reset graph to allow multiple calls
    tf.reset_default_graph()

    return preds
Example #10
def model_train(sess,
                x,
                y,
                predictions,
                X_train,
                Y_train,
                save=False,
                predictions_adv=None,
                evaluate=None):
    """
    Train a TF graph
    :param sess: TF session to use when training the graph
    :param x: input placeholder
    :param y: output placeholder (for labels)
    :param predictions: model output predictions
    :param X_train: numpy array with training inputs
    :param Y_train: numpy array with training outputs
    :param save: Boolean controlling the save operation
    :param predictions_adv: if set with the adversarial example tensor,
                            will run adversarial training
    :param evaluate: function that is run after each training epoch
    :return: True if model trained
    """

    # Define loss
    loss = model_loss(y, predictions)
    if predictions_adv is not None:
        loss = (loss + model_loss(y, predictions_adv)) / 2

    train_step = tf.train.AdadeltaOptimizer(learning_rate=FLAGS.learning_rate,
                                            rho=0.95,
                                            epsilon=1e-08).minimize(loss)

    with sess.as_default():
        if hasattr(tf, "global_variables_initializer"):
            tf.global_variables_initializer().run()
        else:
            warnings.warn("Update your copy of tensorflow; future versions of"
                          "cleverhans may drop support for this version.")
            sess.run(tf.initialize_all_variables())

        for epoch in six.moves.xrange(FLAGS.nb_epochs):
            print("Epoch " + str(epoch))

            # Compute number of batches
            nb_batches = int(math.ceil(float(len(X_train)) / FLAGS.batch_size))
            assert nb_batches * FLAGS.batch_size >= len(X_train)

            prev = time.time()
            for batch in range(nb_batches):

                # Compute batch start and end indices
                start, end = batch_indices(batch, len(X_train),
                                           FLAGS.batch_size)

                # Perform one training step
                train_step.run(
                    feed_dict={
                        x: X_train[start:end],
                        y: Y_train[start:end],
                        keras.backend.learning_phase(): 1
                    })
            assert end >= len(X_train)  # Check that all examples were used
            cur = time.time()
            print("\tEpoch took " + str(cur - prev) + " seconds")
            prev = cur
            if evaluate is not None:
                evaluate()

        if save:
            save_path = os.path.join(FLAGS.train_dir, FLAGS.filename)
            saver = tf.train.Saver()
            saver.save(sess, save_path)
            print("Completed model training and model saved at:" +
                  str(save_path))
        else:
            print("Completed model training.")

    return True
def train(images, labels, ckpt_path, dropout=False):
  """
  This function contains the loop that actually trains the model.
  :param images: a numpy array with the input data
  :param labels: a numpy array with the output labels
  :param ckpt_path: a path (including name) where model checkpoints are saved
  :param dropout: Boolean, whether to use dropout or not
  :return: True if everything went well
  """

  # Check training data
  assert len(images) == len(labels)
  assert images.dtype == np.float32
  assert labels.dtype == np.int32

  # Set default TF graph
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Declare data placeholder
    train_data_node = _input_placeholder()

    # Create a placeholder to hold labels
    train_labels_shape = (FLAGS.batch_size,)
    train_labels_node = tf.placeholder(tf.int32, shape=train_labels_shape)

    print("Done Initializing Training Placeholders")

    # Build a Graph that computes the logits predictions from the placeholder
    if FLAGS.deeper:
      logits = inference_deeper(train_data_node, dropout=dropout)
    else:
      logits = inference(train_data_node, dropout=dropout)

    # Calculate loss
    loss = loss_fun(logits, train_labels_node)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = train_op_fun(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    print("Graph constructed and saver created")

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Create and init sessions
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)) #NOLINT(long-line)
    sess.run(init)

    print("Session ready, beginning training loop")

    # Initialize the number of batches
    data_length = len(images)
    nb_batches = math.ceil(data_length / FLAGS.batch_size)

    for step in xrange(FLAGS.max_steps):
      # for debug, save start time
      start_time = time.time()

      # Current batch number
      batch_nb = step % nb_batches

      # Current batch start and end indices
      start, end = utils.batch_indices(batch_nb, data_length, FLAGS.batch_size)

      # Prepare dictionary to feed the session with
      feed_dict = {train_data_node: images[start:end],
                   train_labels_node: labels[start:end]}

      # Run training step
      _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

      # Compute duration of training step
      duration = time.time() - start_time

      # Sanity check
      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      # Echo loss once in a while
      if step % 100 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        saver.save(sess, ckpt_path, global_step=step)

  return True
Example #12
def tf_model_train(sess,
                   x,
                   y,
                   predictions,
                   X_train,
                   Y_train,
                   save=False,
                   predictions_adv=None):
    """
    Train a TF graph
    :param sess: TF session to use when training the graph
    :param x: input placeholder
    :param y: output placeholder (for labels)
    :param predictions: model output predictions
    :param X_train: numpy array with training inputs
    :param Y_train: numpy array with training outputs
    :param save: Boolean controlling the save operation
    :param predictions_adv: if set with the adversarial example tensor,
                            will run adversarial training
    :return: True if model trained
    """
    print "Starting model training using TensorFlow."

    # Define loss
    loss = tf_model_loss(y, predictions)
    if predictions_adv is not None:
        loss = (loss + tf_model_loss(y, predictions_adv)) / 2

    train_step = tf.train.AdadeltaOptimizer(learning_rate=FLAGS.learning_rate,
                                            rho=0.95,
                                            epsilon=1e-08).minimize(loss)
    # train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(loss)
    print "Defined optimizer."

    with sess.as_default():
        init = tf.initialize_all_variables()
        sess.run(init)

        for epoch in xrange(FLAGS.nb_epochs):
            print("Epoch " + str(epoch))

            # Compute number of batches
            nb_batches = int(math.ceil(float(len(X_train)) / FLAGS.batch_size))
            assert nb_batches * FLAGS.batch_size >= len(X_train)

            prev = time.time()
            for batch in range(nb_batches):
                if batch % 100 == 0 and batch > 0:
                    print("Batch " + str(batch))
                    cur = time.time()
                    print("\tTook " + str(cur - prev) + " seconds")
                    prev = cur

                # Compute batch start and end indices
                start, end = batch_indices(batch, len(X_train),
                                           FLAGS.batch_size)

                # Perform one training step
                train_step.run(
                    feed_dict={
                        x: X_train[start:end],
                        y: Y_train[start:end],
                        keras.backend.learning_phase(): 1
                    })
            assert end >= len(X_train)  # Check that all examples were used

        if save:
            save_path = os.path.join(FLAGS.train_dir, FLAGS.filename)
            saver = tf.train.Saver()
            saver.save(sess, save_path)
            print "Completed model training and model saved at:" + str(
                save_path)
        else:
            print "Completed model training."

    return True
Example #13
    def train(self, X, y, test_X=None, return_probs=False):
        '''
        Learns parameters.
        X               input features as a numpy array, dtype should be np.float32 (it is converted if not)
        y               labels for the rows of X. Class labels will be converted to integers starting from 0.
        test_X          new examples which are classified with the trained model
        return_probs    whether to return the arg-max class or class probabilities for the test data
        '''
        if X.dtype != np.float32:
            X = X.astype(np.float32)

        # normalize labels
        y, encoder, decoder = normalize_labels(y)
        self.label_encoder = encoder
        self.label_decoder = decoder

        n_labels = len(np.unique(y))

        # split a validation set
        X, X_valid, y, y_valid = train_test_split(X,
                                                  y,
                                                  test_size=self.valid_size,
                                                  random_state=self.seed + 3)

        n, d = X.shape

        tf_X = tf.placeholder(dtype=tf.float32, shape=(None, d))
        tf_y = tf.placeholder(dtype=tf.int32, shape=(None, ))

        # define the model
        weights, biases = self._parameters(d, n_labels)
        logits = self._model(tf_X, weights, biases)

        # optimization and the outputs
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf_y, logits=logits + self.epsss),
                              axis=0)
        # penalize weight matrices (biases are not penalized)
        if self.l2 > 0:
            penalty = 0

            for w in weights:
                penalty += tf.nn.l2_loss(weights[w])

            loss += self.l2 * penalty

        y_hat = tf.nn.softmax(logits=logits + self.epsss)
        arg_max_class = tf.argmax(y_hat, 1, output_type=tf.int32)

        train_step = tf.train.AdamOptimizer(
            learning_rate=self.lr).minimize(loss)
        corrects = tf.equal(tf.argmax(y_hat, 1, output_type=tf.int32), tf_y)
        accuracy = tf.reduce_mean(tf.cast(corrects, dtype=tf.float32), axis=0)

        init = tf.global_variables_initializer()

        with tf.Session() as sess:

            sess.run(init)

            # training
            for ii in range(self.epochs):

                train_loss = 0
                train_acc = 0

                all_batch_indices = batch_indices(n,
                                                  self.batch_size,
                                                  shuffle=True,
                                                  seed=self.seed + 2 * ii)
                n_batches = len(all_batch_indices)

                for inx in all_batch_indices:
                    X_batch = X[inx, :]
                    y_batch = y[inx]

                    _, l, a = sess.run([train_step, loss, accuracy],
                                       feed_dict={
                                           tf_X: X_batch,
                                           tf_y: y_batch
                                       })
                    train_loss += l / n_batches
                    train_acc += a / n_batches

                print("Epoch: %3d" % (ii + 1),
                      "Train accuracy: %.2f" % train_acc,
                      "Loss: %.4f" % train_loss)

                if self.valid_size > 0:
                    val_acc = sess.run(accuracy,
                                       feed_dict={
                                           tf_X: X_valid,
                                           tf_y: y_valid
                                       })
                    print("Validation accuracy: %.2f" % val_acc)

                out = None

            # possible test examples
            if test_X is not None:

                if test_X.dtype != np.float32:
                    test_X = test_X.astype(np.float32)

                if return_probs:
                    out = sess.run(y_hat, feed_dict={tf_X: test_X})
                else:
                    out = sess.run(arg_max_class, feed_dict={tf_X: test_X})

        return out
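
# Usage sketch (hedged): `clf` stands for an instance of the surrounding class,
# which is not shown in this excerpt; the array names are hypothetical.
#
#   test_probs = clf.train(X_train, y_train, test_X=X_test, return_probs=True)
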
def model_train_test(sess,
                     x,
                     y,
                     predictions,
                     X_train,
                     Y_train,
                     save=False,
                     predictions_adv=None,
                     evaluate=None,
                     regulizer=False,
                     regcons=0.5,
                     model=None,
                     verbose=True,
                     args=None):
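    """
    Train a TF graph while also building and evaluating spectral-norm
    regularized losses (see model_train above for the parameters shared with
    that function: sess, x, y, predictions, X_train, Y_train, save,
    predictions_adv, evaluate, verbose and args).
    :param regulizer: boolean, if True also build a loss regularized with the
                      collected regularization losses and one regularized with
                      the spectral norms of the kernel variables
    :param regcons: coefficient applied to the spectral-norm penalty
    :param model: model object, only used by the disabled power-iteration test
    :return: True if model trained
    """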

    args = _FlagsWrapper(args or {})

    # Check that necessary arguments were given (see doc above)
    assert args.nb_epochs, "Number of epochs was not given in args dict"
    assert args.learning_rate, "Learning rate was not given in args dict"
    assert args.batch_size, "Batch size was not given in args dict"

    if save:
        assert args.train_dir, "Directory for save was not given in args dict"
        assert args.filename, "Filename for save was not given in args dict"

    opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    opt1 = tf.train.AdamOptimizer(learning_rate=args.learning_rate)

    # Define loss
    loss = model_loss(y, predictions)
    if predictions_adv is not None:
        p = 1.0
        loss = ((1 - p) * loss + p * model_loss(y, predictions_adv))

    if regulizer:

        # apply sn by adding gradient to backprop
        # grads_and_vars = opt.compute_gradients(loss)
        # new_grads = []
        # for grad, var in grads_and_vars:
        #
        #     shp = var.get_shape().as_list()
        #     print("- {} shape:{} size:{}".format(var.name, shp, np.prod(shp)))
        #     if 'kernel' in var.name:
        #         if len(shp) == 4:
        #             s, u, v = power_iterate(var, 100)
        #         else:
        #             s, u, v = tf.svd(var)
        #             # s, u, v = power_iterate(var, 10)
        #
        #         left_vector = tf.slice(u, [0, 0], [shp[0],1])
        #         right_vector = tf.slice(v, [0,0], [shp[1],1])
        #         sn_grad = regcons * s[0] * tf.matmul(left_vector, right_vector, transpose_b=True)
        #         grad = tf.add(sn_grad, grad)
        #     new_grads.append( (grad, var) )

        # Collecting loss from reg loss
        reg_losses = tf.losses.get_regularization_losses()
        loss1 = loss + tf.add_n(reg_losses)

        # lossfunc + spectral norm
        vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        weights_svd = []
        for w in vars:
            shp = w.get_shape().as_list()
            print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp)))
            if 'kernel' in w.name:
                # sn = tf.svd(w, compute_uv=False)[..., 0]
                sn, vv, uu = power_iterate(w, 100, w.name.split('/')[1])
                weights_svd.append(sn)

        # apply sn by adding to loss
        loss2 = loss + regcons * tf.add_n(weights_svd)
        #
        grads_regloss = opt.compute_gradients(loss1)
        grads_lossplusreg = opt1.compute_gradients(loss2)

    # test power iteration
    test = False
    if test:
        if model.name == 'cnn':
            layer1 = model.get_layer(name='conv1')
            weights = layer1.kernel
            strides = layer1.strides[0]
            padding = layer1.padding
            input_shape = layer1.input_shape
            output_shape = layer1.output_shape
            num_iter = 20
            powerit = power_iterate_conv(weights=layer1.kernel,
                                         strides=layer1.strides[0],
                                         padding=layer1.padding.upper(),
                                         input_shape=layer1.input_shape,
                                         output_shape=layer1.output_shape,
                                         num_iter=10)

            shp = weights.get_shape().as_list()
            w = tf.reshape(weights, [shp[0] * shp[1] * shp[2], shp[3]])
            sn = tf.svd(w, compute_uv=False)[..., 0]
        else:  # dense layer
            # layer1 = model.get_layer(name='dense1')
            # weights = layer1.kernel
            weights = tf.placeholder(tf.float32, shape=(300, 784))
            num_iter = 100
            s, u, v = power_iterate(weights, num_iter)
            s1, u1, v1 = tf.svd(weights)

    train_step = opt.minimize(loss)
    # train_step = opt.apply_gradients(new_grads)

    with sess.as_default():
        # writer = tf.summary.FileWriter("/tmp/log/", sess.graph)
        if hasattr(tf, "global_variables_initializer"):
            tf.global_variables_initializer().run()
        else:
            sess.run(tf.initialize_all_variables())

        # Unit test
        # randw = np.random.rand(300,784)
        # layer1 = model.get_layer(name='dense1')
        # randw = layer1.kernel.eval().T
        # for i in range(30):
        #     u, u1, v, v1, s, s1 = sess.run([u, u1, v, v1, s, s1], feed_dict={weights: randw})
        #     randw += np.random.rand(300,784)

        for epoch in six.moves.xrange(args.nb_epochs):
            if verbose:
                print("Epoch " + str(epoch))

            # Compute number of batches
            nb_batches = int(math.ceil(float(len(X_train)) / args.batch_size))
            assert nb_batches * args.batch_size >= len(X_train)

            prev = time.time()
            for batch in range(nb_batches):

                # Compute batch start and end indices
                start, end = batch_indices(batch, len(X_train),
                                           args.batch_size)

                # grad_soft, grad_soft_svd, grads_regloss = sess.run([grads_and_vars, new_grads, grads_regloss],
                #                                                    feed_dict={x: X_train[start:end],
                #                                                               y: Y_train[start:end]})

                if regulizer:
                    # Evaluate the two regularized gradient sets for inspection
                    # without rebinding the gradient tensors themselves
                    regloss_vals, lossplusreg_vals = sess.run(
                        [grads_regloss, grads_lossplusreg],
                        feed_dict={
                            x: X_train[start:end],
                            y: Y_train[start:end]
                        })

                # Perform one training step
                train_step.run(feed_dict={
                    x: X_train[start:end],
                    y: Y_train[start:end]
                })
            assert end >= len(X_train)  # Check that all examples were used
            cur = time.time()
            if verbose:
                print("\tEpoch took " + str(cur - prev) + " seconds")
            prev = cur
            if evaluate is not None:
                evaluate()

        if save:
            save_path = os.path.join(args.train_dir, args.filename)
            saver = tf.train.Saver()
            saver.save(sess, save_path)
            print("Completed model training and saved at:" + str(save_path))
        else:
            print("Completed model training.")
        # writer.close()
    return True
def model_train(sess,
                x,
                y,
                predictions,
                X_train,
                Y_train,
                save=False,
                predictions_adv=None,
                evaluate=None,
                lossregfunc=False,
                regulizer=False,
                regcons=0.5,
                model=None,
                verbose=True,
                args=None):
    """
    Train a TF graph
    :param sess: TF session to use when training the graph
    :param x: input placeholder
    :param y: output placeholder (for labels)
    :param predictions: model output predictions
    :param X_train: numpy array with training inputs
    :param Y_train: numpy array with training outputs
    :param save: boolean controlling the save operation
    :param predictions_adv: if set with the adversarial example tensor,
                            will run adversarial training
    :param evaluate: function that is run after each training epoch
    :param lossregfunc: boolean, if True add a spectral-norm penalty (scaled by
                        regcons) to the loss, otherwise add the collected
                        regularization losses
    :param regulizer: boolean controlling whether any regularization is added
    :param regcons: coefficient for the spectral-norm penalty
    :param model: model object (only referenced by commented-out code here)
    :param verbose: boolean controlling whether progress is printed
    :param args: dict or argparse `Namespace` object.
                 Should contain `nb_epochs`, `learning_rate`,
                 `batch_size`
                 If save is True, should also contain 'train_dir'
                 and 'filename'
    :return: True if model trained
    """
    args = _FlagsWrapper(args or {})

    # Check that necessary arguments were given (see doc above)
    assert args.nb_epochs, "Number of epochs was not given in args dict"
    assert args.learning_rate, "Learning rate was not given in args dict"
    assert args.batch_size, "Batch size was not given in args dict"

    if save:
        assert args.train_dir, "Directory for save was not given in args dict"
        assert args.filename, "Filename for save was not given in args dict"

    opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)

    # Define loss
    loss = model_loss(y, predictions)
    if predictions_adv is not None:
        p = 1.0
        loss = ((1 - p) * loss + p * model_loss(y, predictions_adv))

    if regulizer:
        if not lossregfunc:
            # collecting from reg loss
            reg_losses = tf.losses.get_regularization_losses()
            loss += tf.add_n(reg_losses)
        else:
            vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            weights_svd = []
            for w in vars:
                shp = w.get_shape().as_list()
                print("- {} shape:{} size:{}".format(w.name, shp,
                                                     np.prod(shp)))
                if 'kernel' in w.name:
                    # sn = tf.svd(w, compute_uv=False)[..., 0]
                    sn, vv, uu = power_iterate(w, 10, w.name.split('/')[1])
                    weights_svd.append(sn)

            # apply sn by adding to loss
            loss += regcons * tf.add_n(weights_svd)

        # apply sn by adding gradient to backprop
        # grads_and_vars = opt.compute_gradients(loss)
        # new_grads = []
        # for grad, var in grads_and_vars:
        #
        #     shp = var.get_shape().as_list()
        #     print("- {} shape:{} size:{}".format(var.name, shp, np.prod(shp)))
        #     if 'kernel' in var.name:
        #         if len(shp) == 4: # convolutional layer
        #             layer_name = var.name.split('/')[0]
        #             layer = model.get_layer(name=layer_name)
        #             s, left_vector, right_vector = power_iterate_conv(weights=layer.kernel, strides=layer.strides[0],
        #                                          padding=layer.padding.upper(), input_shape=layer.input_shape,
        #                                          output_shape=layer.output_shape, num_iter=10, weight_name=layer_name)
        #             # sn_grad = regcons * s * tf.matmul(left_vector, right_vector, transpose_b=True)
        #         else: # fully connected layer
        #             layer_name = var.name.split('/')[0]
        #
        #             # s, u, v = tf.svd(var)
        #             # left_vector = tf.slice(u, [0, 0], [shp[0],1])
        #             # right_vector = tf.slice(v, [0,0], [shp[1],1])
        #             # sn_grad = regcons * s[0] * tf.matmul(left_vector, right_vector, transpose_b=True)
        #
        #             s, left_vector, right_vector = power_iterate(weights=var, num_iter=10, weight_name=layer_name)
        #             sn_grad = regcons * s * tf.matmul(left_vector, right_vector, transpose_b=True)
        #
        #         grad = tf.add(sn_grad, grad)
        #
        #     new_grads.append( (grad, var) )

    # train_step = opt.apply_gradients(new_grads)
    train_step = opt.minimize(loss)

    with sess.as_default():
        # writer = tf.summary.FileWriter("/tmp/log/", sess.graph)
        if hasattr(tf, "global_variables_initializer"):
            tf.global_variables_initializer().run()
        else:
            sess.run(tf.initialize_all_variables())

        for epoch in six.moves.xrange(args.nb_epochs):
            if verbose:
                print("Epoch " + str(epoch))

            # Compute number of batches
            nb_batches = int(math.ceil(float(len(X_train)) / args.batch_size))
            assert nb_batches * args.batch_size >= len(X_train)

            prev = time.time()
            for batch in range(nb_batches):

                # Compute batch start and end indices
                start, end = batch_indices(batch, len(X_train),
                                           args.batch_size)

                # Perform one training step
                train_step.run(feed_dict={
                    x: X_train[start:end],
                    y: Y_train[start:end]
                })
            assert end >= len(X_train)  # Check that all examples were used
            cur = time.time()
            if verbose:
                print("\tEpoch took " + str(cur - prev) + " seconds")
            prev = cur
            if evaluate is not None:
                evaluate()

        if save:
            save_path = os.path.join(args.train_dir, args.filename)
            saver = tf.train.Saver()
            saver.save(sess, save_path)
            print("Completed model training and saved at:" + str(save_path))
        else:
            print("Completed model training.")
        # writer.close()
    return True
    def train(x, y, save=False, predictions_adv=None, evaluate=None,
              verbose=True, args=None):
        args = _FlagsWrapper(args or {})

        # Check that necessary arguments were given (see doc above)
        assert args.nb_epochs, "Number of epochs was not given in args dict"
        assert args.learning_rate, "Learning rate was not given in args dict"
        assert args.batch_size, "Batch size was not given in args dict"

        if save:
            assert args.train_dir, "Directory for save was not given in args dict"
            assert args.filename, "Filename for save was not given in args dict"

        # Define loss
        loss = model_loss(y, self.predictions)
        if predictions_adv is not None:
            p = 1.0
            loss = ((1-p)*loss + p*model_loss(y, predictions_adv))

        train_step = tf.train.AdamOptimizer(learning_rate=args.learning_rate).minimize(loss)

        with self.sess.as_default():
            if hasattr(tf, "global_variables_initializer"):
                tf.global_variables_initializer().run()
            else:
                self.sess.run(tf.initialize_all_variables())

            for epoch in six.moves.xrange(args.nb_epochs):
                if verbose:
                    print("Epoch " + str(epoch))

                # Compute number of batches
                nb_batches = int(math.ceil(float(len(X_train)) / args.batch_size))
                assert nb_batches * args.batch_size >= len(X_train)

                prev = time.time()
                for batch in range(nb_batches):

                    # Compute batch start and end indices
                    start, end = batch_indices(batch, len(X_train), args.batch_size)

                    # update adversarial examples
                    self.update_adv(start, end)
                    # Perform one training step (on adversarial examples)
                    train_step.run(feed_dict={x: X_adv[start:end],
                                              y: Y_train[start:end]})
                assert end >= len(X_train)  # Check that all examples were used
                cur = time.time()
                if verbose:
                    print("\tEpoch took " + str(cur - prev) + " seconds")
                prev = cur
                if evaluate is not None:
                    evaluate()

            if save:
                save_path = os.path.join(args.train_dir, args.filename)
                saver = tf.train.Saver()
                saver.save(self.sess, save_path)
                print("Completed model training and saved at:" + str(save_path))
            else:
                print("Completed model training.")

        return True