def test_mnist(self, subset):
  if not FLAGS.data_dir:
    tf.logging.info("data_dir flag not provided. Quitting test")
    return
  batch_size = 10
  image_shape = (28, 28, 1)
  dataset = data_provider.MNIST(
      data_dir=FLAGS.data_dir, subset=subset, batch_size=batch_size)
  images, labels = dataset.images, dataset.labels
  with self.test_session() as sess:
    im, l = sess.run((images, labels))
  self.assertEqual(im.shape, (batch_size,) + image_shape)
  self.assertEqual(l.shape, (batch_size,))
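# Hedged sketch: one way the helper above could be wired into a runnable test,
# parameterizing over the MNIST subsets. The class name is an illustrative
# assumption; FLAGS and data_provider come from the surrounding test module.
from absl.testing import parameterized
import tensorflow as tf


class MNISTDataProviderTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(("Train", "train"), ("Test", "test"))
  def test_mnist(self, subset):
    # Body as defined above: builds the MNIST data provider for `subset`
    # and checks the image and label batch shapes.
    pass


if __name__ == "__main__":
  tf.test.main()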
def train():
  """Training function."""
  is_chief = (FLAGS.task == 0)
  g = tf.Graph()
  with g.as_default():
    with tf.device(
        tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks)):
      if FLAGS.experiment_type == "mnist":
        config = mnist_config.ConfigDict()
        dataset = mnist.MNIST(data_dir=FLAGS.data_dir,
                              subset="train",
                              batch_size=FLAGS.batch_size,
                              is_training=True)
        model = mnist_model.MNISTNetwork(config)
        layers_names = [
            "conv_layer%d" % i
            for i in range(len(config.filter_sizes_conv_layers))
        ]

      images, labels, num_examples, num_classes = (dataset.images,
                                                   dataset.labels,
                                                   dataset.num_examples,
                                                   dataset.num_classes)

      tf.summary.image("images", images)

      # Build model.
      logits, endpoints = model(images, is_training=True)
      layers_list = [images] + [endpoints[name] for name in layers_names]

      # Define losses.
      l2_loss_wt = config.l2_loss_wt
      xent_loss_wt = config.xent_loss_wt
      margin_loss_wt = config.margin_loss_wt
      gamma = config.gamma
      alpha = config.alpha
      top_k = config.top_k
      dist_norm = config.dist_norm
      with tf.name_scope("losses"):
        xent_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=labels))
        margin = margin_loss.large_margin(
            logits=logits,
            one_hot_labels=tf.one_hot(labels, num_classes),
            layers_list=layers_list,
            gamma=gamma,
            alpha_factor=alpha,
            top_k=top_k,
            dist_norm=dist_norm,
            epsilon=1e-6,
            layers_weights=[
                np.prod(layer.get_shape().as_list()[1:])
                for layer in layers_list
            ] if np.isinf(dist_norm) else None)
        l2_loss = 0.
        for v in tf.trainable_variables():
          tf.logging.info(v)
          l2_loss += tf.nn.l2_loss(v)

        total_loss = 0
        if xent_loss_wt > 0:
          total_loss += xent_loss_wt * xent_loss
        if margin_loss_wt > 0:
          total_loss += margin_loss_wt * margin
        if l2_loss_wt > 0:
          total_loss += l2_loss_wt * l2_loss

        tf.summary.scalar("xent_loss", xent_loss)
        tf.summary.scalar("margin_loss", margin)
        tf.summary.scalar("l2_loss", l2_loss)
        tf.summary.scalar("total_loss", total_loss)

      # Build optimizer.
      init_lr = config.init_lr
      with tf.name_scope("optimizer"):
        global_step = tf.train.get_or_create_global_step()
        if FLAGS.num_replicas > 1:
          num_batches_per_epoch = num_examples // (
              FLAGS.batch_size * FLAGS.num_replicas)
        else:
          num_batches_per_epoch = num_examples // FLAGS.batch_size
        max_iters = num_batches_per_epoch * FLAGS.num_epochs
        lr = tf.train.exponential_decay(init_lr,
                                        global_step,
                                        FLAGS.decay_steps,
                                        FLAGS.decay_rate,
                                        staircase=True,
                                        name="lr_schedule")
        tf.summary.scalar("learning_rate", lr)

        var_list = tf.trainable_variables()
        grad_vars = tf.gradients(total_loss, var_list)
        tf.summary.scalar(
            "grad_norm",
            tf.reduce_mean([tf.norm(grad_var) for grad_var in grad_vars]))
        grad_vars, _ = tf.clip_by_global_norm(grad_vars, 5.0)

        opt = tf.train.RMSPropOptimizer(lr, momentum=FLAGS.momentum,
                                        epsilon=1e-2)
        if FLAGS.num_replicas > 1:
          opt = tf.train.SyncReplicasOptimizer(
              opt,
              replicas_to_aggregate=FLAGS.num_replicas,
              total_num_replicas=FLAGS.num_replicas)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
          opt_op = opt.apply_gradients(zip(grad_vars, var_list),
                                       global_step=global_step)

      # Compute accuracy.
      top1_op = tf.nn.in_top_k(logits, labels, 1)
      accuracy = tf.reduce_mean(tf.cast(top1_op, dtype=tf.float32))
      tf.summary.scalar("top1_accuracy", accuracy)

      # Prepare optimization.
      vars_to_save = tf.global_variables()
      saver = tf.train.Saver(var_list=vars_to_save, max_to_keep=5,
                             sharded=True)
      merged_summary = tf.summary.merge_all()

      # Hooks for optimization.
      hooks = [tf.train.StopAtStepHook(last_step=max_iters)]
      if not is_chief:
        hooks.append(
            tf.train.GlobalStepWaiterHook(FLAGS.task *
                                          FLAGS.startup_delay_steps))

      init_op = tf.global_variables_initializer()
      scaffold = tf.train.Scaffold(init_op=init_op,
                                   summary_op=merged_summary,
                                   saver=saver)

      # Run optimization.
      epoch = 0
      with tf.train.MonitoredTrainingSession(
          is_chief=is_chief,
          checkpoint_dir=FLAGS.checkpoint_dir,
          hooks=hooks,
          save_checkpoint_secs=FLAGS.save_checkpoint_secs,
          save_summaries_secs=FLAGS.save_summaries_secs,
          scaffold=scaffold) as sess:
        while not sess.should_stop():
          _, acc, i = sess.run((opt_op, accuracy, global_step))
          epoch = i // num_batches_per_epoch
          if (i % FLAGS.log_every_steps) == 0:
            tf.logging.info(
                "global step %d: epoch %d:\n train accuracy %.3f" %
                (i, epoch, acc))
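# Hedged sketch of the command-line flags that train() above and evaluate()
# below rely on. Only the flag names are taken from the code; the default
# values and help strings are illustrative assumptions.
flags = tf.app.flags
flags.DEFINE_string("experiment_type", "mnist", "Which experiment to run.")
flags.DEFINE_string("data_dir", None, "Directory with the dataset.")
flags.DEFINE_string("checkpoint_dir", None, "Where to read/write checkpoints.")
flags.DEFINE_string("eval_dir", None, "Where to write evaluation summaries.")
flags.DEFINE_string("subset", "test", "Dataset subset used for evaluation.")
flags.DEFINE_integer("batch_size", 128, "Batch size.")
flags.DEFINE_integer("num_epochs", 100, "Number of training epochs.")
flags.DEFINE_integer("decay_steps", 10000, "Learning-rate decay interval.")
flags.DEFINE_float("decay_rate", 0.96, "Learning-rate decay factor.")
flags.DEFINE_float("momentum", 0.9, "RMSProp momentum.")
flags.DEFINE_integer("task", 0, "Task id of this replica.")
flags.DEFINE_integer("ps_tasks", 0, "Number of parameter-server tasks.")
flags.DEFINE_integer("num_replicas", 1, "Number of worker replicas.")
flags.DEFINE_integer("startup_delay_steps", 15, "Non-chief startup delay.")
flags.DEFINE_integer("save_checkpoint_secs", 600, "Checkpoint frequency.")
flags.DEFINE_integer("save_summaries_secs", 120, "Summary frequency.")
flags.DEFINE_integer("log_every_steps", 100, "Logging frequency in steps.")
flags.DEFINE_boolean("run_once", False, "Evaluate once instead of looping.")
flags.DEFINE_integer("eval_interval_secs", 300, "Seconds between evaluations.")
FLAGS = flags.FLAGS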
def evaluate():
  """Evaluating function."""
  g = tf.Graph()
  ops_dict = {}
  with g.as_default():
    # Data set.
    if FLAGS.experiment_type == "mnist":
      config = mnist_config.ConfigDict()
      dataset = mnist.MNIST(data_dir=FLAGS.data_dir,
                            subset=FLAGS.subset,
                            batch_size=FLAGS.batch_size,
                            is_training=False)
      model = mnist_model.MNISTNetwork(config)

    images, labels, num_examples, num_classes = (dataset.images,
                                                 dataset.labels,
                                                 dataset.num_examples,
                                                 dataset.num_classes)

    logits, _ = model(images, is_training=False)
    top1_op = tf.nn.in_top_k(logits, labels, 1)
    top1_op = tf.cast(top1_op, dtype=tf.float32)
    ops_dict["top1"] = (None, top1_op)
    accuracy_ph = tf.placeholder(tf.float32, None)
    ops_dict["top1_accuracy"] = (accuracy_ph, None)
    tf.summary.scalar("top1_accuracy", accuracy_ph)

    with tf.name_scope("optimizer"):
      global_step = tf.train.get_or_create_global_step()

    # Define losses.
    l2_loss_wt = config.l2_loss_wt
    xent_loss_wt = config.xent_loss_wt
    margin_loss_wt = config.margin_loss_wt
    gamma = config.gamma
    alpha = config.alpha
    top_k = config.top_k
    dist_norm = config.dist_norm
    with tf.name_scope("losses"):
      xent_loss = tf.reduce_mean(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              logits=logits, labels=labels))
      margin = margin_loss.large_margin(
          logits=logits,
          one_hot_labels=tf.one_hot(labels, num_classes),
          layers_list=[images],
          gamma=gamma,
          alpha_factor=alpha,
          top_k=top_k,
          dist_norm=dist_norm)
      l2_loss = 0.
      for v in tf.trainable_variables():
        tf.logging.info(v)
        l2_loss += tf.nn.l2_loss(v)

      total_loss = 0
      if xent_loss_wt > 0:
        total_loss += xent_loss_wt * xent_loss
      if margin_loss_wt > 0:
        total_loss += margin_loss_wt * margin
      if l2_loss_wt > 0:
        total_loss += l2_loss_wt * l2_loss

      xent_loss_ph = tf.placeholder(tf.float32, None)
      margin_loss_ph = tf.placeholder(tf.float32, None)
      l2_loss_ph = tf.placeholder(tf.float32, None)
      total_loss_ph = tf.placeholder(tf.float32, None)
      tf.summary.scalar("xent_loss", xent_loss_ph)
      tf.summary.scalar("margin_loss", margin_loss_ph)
      tf.summary.scalar("l2_loss", l2_loss_ph)
      tf.summary.scalar("total_loss", total_loss_ph)

      ops_dict["losses/xent_loss"] = (xent_loss_ph, xent_loss)
      ops_dict["losses/margin_loss"] = (margin_loss_ph, margin)
      ops_dict["losses/l2_loss"] = (l2_loss_ph, l2_loss)
      ops_dict["losses/total_loss"] = (total_loss_ph, total_loss)

    # Prepare evaluation session.
    merged_summary = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                           tf.get_default_graph())
    vars_to_save = tf.global_variables()
    saver = tf.train.Saver(var_list=vars_to_save)
    scaffold = tf.train.Scaffold(saver=saver)
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold, checkpoint_dir=FLAGS.checkpoint_dir)
    while True:
      _eval_once(session_creator, ops_dict, summary_writer, merged_summary,
                 global_step, num_examples)
      if FLAGS.run_once:
        break
      time.sleep(FLAGS.eval_interval_secs)
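# Hedged sketch of the _eval_once helper referenced above. Its implementation
# is not shown in this section, so this is an assumption inferred from how
# ops_dict is built as (placeholder, tensor) pairs: tensors are averaged over
# one pass of the eval set, then fed into the placeholders to write summaries.
def _eval_once(session_creator, ops_dict, summary_writer, merged_summary,
               global_step, num_examples):
  """Runs one pass over the eval set and writes averaged summaries."""
  with tf.train.MonitoredSession(session_creator=session_creator) as sess:
    num_batches = num_examples // FLAGS.batch_size
    # Accumulate every entry that has a tensor to evaluate.
    sums = {name: 0. for name, (_, op) in ops_dict.items() if op is not None}
    for _ in range(num_batches):
      values = sess.run(
          {name: op for name, (_, op) in ops_dict.items() if op is not None})
      for name, value in values.items():
        sums[name] += np.mean(value)
    averages = {name: total / num_batches for name, total in sums.items()}
    # "top1" accumulates per-example correctness; report it as an accuracy.
    feed_dict = {}
    for name, (placeholder, _) in ops_dict.items():
      if placeholder is None:
        continue
      source = "top1" if name == "top1_accuracy" else name
      if source in averages:
        feed_dict[placeholder] = averages[source]
    summary, step = sess.run((merged_summary, global_step),
                             feed_dict=feed_dict)
    summary_writer.add_summary(summary, step)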