def model_factory(self, net_type, input_shape, drop_out, drop_out_ratio,
                  batch_normalization, activation, num_class):
    """Build the classification model selected by ``net_type``.

    NetType.VGG selects ``vmodel.vgg``, NetType.GOOGLENET selects
    ``gmodel.google_net``; any other value falls back to the ResNet default
    (``rmodel.resnet_v2``; the original comment calls this Res110).

    Args:
        net_type: NetType member choosing the architecture.
        input_shape: forwarded unchanged to the selected builder.
        drop_out: forwarded unchanged to the selected builder.
        drop_out_ratio: forwarded unchanged to the selected builder.
        batch_normalization: forwarded to the builder (``bn=`` for ResNet).
        activation: forwarded to the builder (``acti=`` for ResNet).
        num_class: number of output classes (``num_classes=`` for builders).

    Returns:
        The model object returned by the selected builder.
    """
    # Use the ``net_type`` argument for every comparison; the previous code
    # mixed ``self.net_type`` and ``net_type``.
    if net_type == NetType.VGG:
        print('Choose the VGG model')
        return vmodel.vgg(input_shape=input_shape, num_classes=num_class,
                          drop_out=drop_out, drop_out_ratio=drop_out_ratio,
                          batch_normalization=batch_normalization,
                          activation=activation)
    if net_type == NetType.GOOGLENET:
        print('Choose the GoogleNet model')
        return gmodel.google_net(input_shape=input_shape, num_classes=num_class,
                                 drop_out=drop_out, drop_out_ratio=drop_out_ratio,
                                 batch_normalization=batch_normalization,
                                 activation=activation)
    # ResNet is the default. It is built only when actually selected, so no
    # throwaway model is constructed for the other branches.
    return rmodel.resnet_v2(input_shape=input_shape, num_classes=num_class,
                            drop_out=drop_out, drop_out_ratio=drop_out_ratio,
                            bn=batch_normalization, acti=activation)
def tensor_shapes_helper(self, resnet_size, with_gpu=False):
    """Verify the shape of every named stage tensor of a resnet_v2 model.

    A resnet_v2 model with 456 classes is applied to a random 1x224x224x3
    input, and each named intermediate tensor is compared against its expected
    shape. Expected shapes are written in NCHW order and converted to NHWC when
    the test runs without a GPU.

    Args:
        resnet_size: depth passed to resnet_model.resnet_v2. Sizes below 50
            use building-block channel counts (64..512); 50 and above use
            bottleneck channel counts (4x wider).
        with_gpu: when True, forces the session onto a GPU and builds the
            model with 'channels_first'; otherwise 'channels_last' is used.
    """
    def to_layout(nchw):
        """Return ``nchw`` unchanged on GPU, otherwise permuted to NHWC."""
        if not with_gpu:
            batch, channels, height, width = nchw
            return batch, height, width, channels
        return nchw

    layout = 'channels_first' if with_gpu else 'channels_last'
    graph = tf.Graph()
    with graph.as_default(), self.test_session(use_gpu=with_gpu,
                                               force_gpu=with_gpu):
        model = resnet_model.resnet_v2(resnet_size, 456, data_format=layout)
        inputs = tf.random_uniform([1, 224, 224, 3])
        output = model(inputs, is_training=True)

        # Resolve every tensor before asserting anything, so a missing name
        # surfaces as a lookup error first.
        stage_names = ('initial_conv', 'initial_max_pool',
                       'block_layer1', 'block_layer2', 'block_layer3',
                       'block_layer4', 'final_avg_pool', 'final_dense')
        stage = {name: graph.get_tensor_by_name(name + ':0')
                 for name in stage_names}

        self.assertAllEqual(stage['initial_conv'].shape,
                            to_layout((1, 64, 112, 112)))
        self.assertAllEqual(stage['initial_max_pool'].shape,
                            to_layout((1, 64, 56, 56)))

        # The bottleneck_block (size >= 50) emits 4x the channels of the
        # building_block at every stage.
        widen = 1 if resnet_size < 50 else 4
        for index, side in enumerate((56, 28, 14, 7)):
            channels = 64 * widen * 2 ** index
            self.assertAllEqual(stage['block_layer%d' % (index + 1)].shape,
                                to_layout((1, channels, side, side)))
        self.assertAllEqual(stage['final_avg_pool'].shape,
                            to_layout((1, 512 * widen, 1, 1)))
        self.assertAllEqual(stage['final_dense'].shape, (1, 456))
        self.assertAllEqual(output.shape, (1, 456))
def grads_and_loss():
    """Build a ResNet graph on constant inputs and return (grads, loss).

    Images and labels are tensors of ones scaled by 1/1000. Note that the
    labels are therefore not a normalized one-hot distribution. The network is
    tiny_resnet_v2 when USE_TINY is set, otherwise resnet_v2.

    Returns:
        grads: result of tf.gradients(loss, tf.trainable_variables()).
        loss: softmax cross entropy plus _WEIGHT_DECAY times the sum of
            tf.nn.l2_loss over all trainable variables.
    """
    batch_shape = [BATCH_SIZE, HEIGHT, WIDTH, DEPTH]
    images = tf.ones(batch_shape) / 1000
    labels = tf.ones(shape=[BATCH_SIZE, NUM_CLASSES]) / 1000
    # Random alternative inputs:
    # images = tf.random_uniform((BATCH_SIZE, HEIGHT, WIDTH, DEPTH), seed=1)
    # labels = tf.random_uniform((BATCH_SIZE, NUM_CLASSES), seed=1)

    build = resnet_model.tiny_resnet_v2 if USE_TINY else resnet_model.resnet_v2
    network = build(resnet_size=RESNET_SIZE, num_classes=NUM_CLASSES)

    logits = network(tf.reshape(images, batch_shape), True)
    xent = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)
    penalty = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
    loss = xent + _WEIGHT_DECAY * penalty

    # Neither value is used below yet (see TODO); get_or_create_global_step
    # still adds the global-step variable to the graph if it is absent.
    _global_step = tf.train.get_or_create_global_step()
    _optimizer = tf.train.MomentumOptimizer(
        learning_rate=_INITIAL_LEARNING_RATE, momentum=_MOMENTUM)

    # Batch norm requires update_ops to be added as a train_op dependency;
    # here the gradient ops are made to depend on them.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        grads = tf.gradients(loss, tf.trainable_variables())
        # TODO: move to train_op
        # train_op = optimizer.minimize(loss, global_step)
    return grads, loss
def grads_and_loss():
    """Build a ResNet graph on constant inputs and return (grads, loss).

    Images and labels are tensors of ones scaled by 1/1000. Note that the
    labels are therefore not a normalized one-hot distribution. The network is
    tiny_resnet_v2 when USE_TINY is set, otherwise resnet_v2.

    Returns:
        grads: result of tf.gradients(loss, tf.trainable_variables()).
        loss: softmax cross entropy plus _WEIGHT_DECAY times the sum of
            tf.nn.l2_loss over all trainable variables.
    """
    batch_shape = [BATCH_SIZE, HEIGHT, WIDTH, DEPTH]
    images = tf.ones(batch_shape) / 1000
    labels = tf.ones(shape=[BATCH_SIZE, NUM_CLASSES]) / 1000
    # Random alternative inputs:
    # images = tf.random_uniform((BATCH_SIZE, HEIGHT, WIDTH, DEPTH), seed=1)
    # labels = tf.random_uniform((BATCH_SIZE, NUM_CLASSES), seed=1)

    build = resnet_model.tiny_resnet_v2 if USE_TINY else resnet_model.resnet_v2
    network = build(resnet_size=RESNET_SIZE, num_classes=NUM_CLASSES)

    logits = network(tf.reshape(images, batch_shape), True)
    xent = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)
    penalty = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
    loss = xent + _WEIGHT_DECAY * penalty

    # Neither value is used below yet (see TODO); get_or_create_global_step
    # still adds the global-step variable to the graph if it is absent.
    _global_step = tf.train.get_or_create_global_step()
    _optimizer = tf.train.MomentumOptimizer(
        learning_rate=_INITIAL_LEARNING_RATE, momentum=_MOMENTUM)

    # Batch norm requires update_ops to be added as a train_op dependency;
    # here the gradient ops are made to depend on them.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        grads = tf.gradients(loss, tf.trainable_variables())
        # TODO: move to train_op
        # train_op = optimizer.minimize(loss, global_step)
    return grads, loss
def test_load_resnet18_v2(self):
    """Smoke test: build an 18-layer resnet_v2 and run one forward pass.

    One random 1x28x28x3 input is fed through the network with ``train=True``.
    The test only checks that initialization and the forward run complete
    without raising.
    """
    network = resnet_model.resnet_v2(resnet_depth=18, num_classes=10,
                                     data_format='channels_last')
    input_bhw3 = tf.placeholder(tf.float32, [1, 28, 28, 3])
    resnet_output = network(inputs=input_bhw3, train=True)
    # The context manager closes the session (releasing its resources) even
    # if one of the run() calls raises; previously it was never closed.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(resnet_output,
                 feed_dict={input_bhw3: np.random.randn(1, 28, 28, 3)})
def tensor_shapes_helper(self, resnet_size, with_gpu=False):
    """Verify the shape of every named stage tensor of a resnet_v2 model.

    A resnet_v2 model with 456 classes is applied to a random 1x224x224x3
    input, and each named intermediate tensor is compared against its expected
    shape. Expected shapes are written in NCHW order and converted to NHWC when
    the test runs without a GPU.

    Args:
        resnet_size: depth passed to resnet_model.resnet_v2. Sizes below 50
            use building-block channel counts (64..512); 50 and above use
            bottleneck channel counts (4x wider).
        with_gpu: when True, forces the session onto a GPU and builds the
            model with 'channels_first'; otherwise 'channels_last' is used.
    """
    def to_layout(nchw):
        """Return ``nchw`` unchanged on GPU, otherwise permuted to NHWC."""
        if not with_gpu:
            batch, channels, height, width = nchw
            return batch, height, width, channels
        return nchw

    layout = 'channels_first' if with_gpu else 'channels_last'
    graph = tf.Graph()
    with graph.as_default(), self.test_session(use_gpu=with_gpu,
                                               force_gpu=with_gpu):
        model = resnet_model.resnet_v2(resnet_size, 456, data_format=layout)
        inputs = tf.random_uniform([1, 224, 224, 3])
        output = model(inputs, is_training=True)

        # Resolve every tensor before asserting anything, so a missing name
        # surfaces as a lookup error first.
        stage_names = ('initial_conv', 'initial_max_pool',
                       'block_layer1', 'block_layer2', 'block_layer3',
                       'block_layer4', 'final_avg_pool', 'final_dense')
        stage = {name: graph.get_tensor_by_name(name + ':0')
                 for name in stage_names}

        self.assertAllEqual(stage['initial_conv'].shape,
                            to_layout((1, 64, 112, 112)))
        self.assertAllEqual(stage['initial_max_pool'].shape,
                            to_layout((1, 64, 56, 56)))

        # The bottleneck_block (size >= 50) emits 4x the channels of the
        # building_block at every stage.
        widen = 1 if resnet_size < 50 else 4
        for index, side in enumerate((56, 28, 14, 7)):
            channels = 64 * widen * 2 ** index
            self.assertAllEqual(stage['block_layer%d' % (index + 1)].shape,
                                to_layout((1, channels, side, side)))
        self.assertAllEqual(stage['final_avg_pool'].shape,
                            to_layout((1, 512 * widen, 1, 1)))
        self.assertAllEqual(stage['final_dense'].shape, (1, 456))
        self.assertAllEqual(output.shape, (1, 456))
def create_loss():
    """Return the training loss of a resnet_v2 applied to random data.

    Images and labels both come from tf.random_uniform. The loss is softmax
    cross entropy plus _WEIGHT_DECAY times the sum of tf.nn.l2_loss over every
    trainable variable.
    """
    images = tf.random_uniform((BATCH_SIZE, HEIGHT, WIDTH, DEPTH))
    labels = tf.random_uniform((BATCH_SIZE, NUM_CLASSES))
    network = resnet_model.resnet_v2(resnet_size=RESNET_SIZE,
                                     num_classes=NUM_CLASSES)
    logits = network(tf.reshape(images, [BATCH_SIZE, HEIGHT, WIDTH, DEPTH]),
                     True)
    xent = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)
    penalty = tf.add_n([tf.nn.l2_loss(w) for w in tf.trainable_variables()])
    return xent + _WEIGHT_DECAY * penalty
def create_loss():
    """Return the training loss of a resnet_v2 applied to random data.

    Images and labels both come from tf.random_uniform. The loss is softmax
    cross entropy plus _WEIGHT_DECAY times the sum of tf.nn.l2_loss over every
    trainable variable.
    """
    images = tf.random_uniform((BATCH_SIZE, HEIGHT, WIDTH, DEPTH))
    labels = tf.random_uniform((BATCH_SIZE, NUM_CLASSES))
    network = resnet_model.resnet_v2(resnet_size=RESNET_SIZE,
                                     num_classes=NUM_CLASSES)
    logits = network(tf.reshape(images, [BATCH_SIZE, HEIGHT, WIDTH, DEPTH]),
                     True)
    xent = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)
    penalty = tf.add_n([tf.nn.l2_loss(w) for w in tf.trainable_variables()])
    return xent + _WEIGHT_DECAY * penalty
def create_train_op_and_loss():
    """Build a ResNet graph on random data and return its gradients and loss.

    Despite the name, no train op is created: the optimizer-based variant is
    commented out, and the raw gradients are returned instead. The loss is
    plain softmax cross entropy (the weight-decay term is commented out).

    When DUMP_GRAPHDEF is set, the default graph is written as text to
    'imagenet_<RESNET_SIZE>.pbtxt' before the gradient ops are added.

    Returns:
        grads: tf.gradients(loss, tf.trainable_variables()).
        loss: the softmax cross-entropy tensor.
    """
    images = tf.random_uniform((BATCH_SIZE, HEIGHT, WIDTH, DEPTH))
    labels = tf.random_uniform((BATCH_SIZE, NUM_CLASSES))
    if USE_TINY:
        network = resnet_model.tiny_resnet_v2(resnet_size=RESNET_SIZE,
                                              num_classes=NUM_CLASSES)
    else:
        network = resnet_model.resnet_v2(resnet_size=RESNET_SIZE,
                                         num_classes=NUM_CLASSES)
    inputs = tf.reshape(images, [BATCH_SIZE, HEIGHT, WIDTH, DEPTH])
    # NOTE(review): second positional arg is presumably is_training — confirm.
    logits = network(inputs, False)
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits,
                                                    onehot_labels=labels)
    # loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
    #     [tf.nn.l2_loss(v) for v in tf.trainable_variables()])
    loss = cross_entropy

    # The step tensor itself is unused; the call is kept because it adds the
    # global-step variable to the graph if it does not exist yet.
    tf.train.get_or_create_global_step()

    # optimizer = tf.train.MomentumOptimizer(
    #     learning_rate=_INITIAL_LEARNING_RATE, momentum=_MOMENTUM)
    # Batch norm requires update_ops to be added as a train_op dependency.
    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # with tf.control_dependencies(update_ops):
    if DUMP_GRAPHDEF:
        graphdef_path = 'imagenet_%d.pbtxt' % (RESNET_SIZE,)
        # Close (and flush) the file deterministically instead of relying on
        # garbage collection of an anonymous handle.
        with open(graphdef_path, 'w') as graphdef_file:
            graphdef_file.write(str(tf.get_default_graph().as_graph_def()))
    grads = tf.gradients(loss, tf.trainable_variables())
    # train_op = optimizer.minimize(loss, global_step)
    # grads_and_vars = list(zip(grads, tf.trainable_variables()))
    # train_op = optimizer.apply_gradients(grads_and_vars)
    # return train_op, loss
    return grads, loss
def create_train_op_and_loss():
    """Build a ResNet graph on random data and return its gradients and loss.

    Despite the name, no train op is created: the optimizer-based variant is
    commented out, and the raw gradients are returned instead. The loss is
    plain softmax cross entropy (the weight-decay term is commented out).

    When DUMP_GRAPHDEF is set, the default graph is written as text to
    'imagenet_<RESNET_SIZE>.pbtxt' before the gradient ops are added.

    Returns:
        grads: tf.gradients(loss, tf.trainable_variables()).
        loss: the softmax cross-entropy tensor.
    """
    images = tf.random_uniform((BATCH_SIZE, HEIGHT, WIDTH, DEPTH))
    labels = tf.random_uniform((BATCH_SIZE, NUM_CLASSES))
    if USE_TINY:
        network = resnet_model.tiny_resnet_v2(resnet_size=RESNET_SIZE,
                                              num_classes=NUM_CLASSES)
    else:
        network = resnet_model.resnet_v2(resnet_size=RESNET_SIZE,
                                         num_classes=NUM_CLASSES)
    inputs = tf.reshape(images, [BATCH_SIZE, HEIGHT, WIDTH, DEPTH])
    # NOTE(review): second positional arg is presumably is_training — confirm.
    logits = network(inputs, False)
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits,
                                                    onehot_labels=labels)
    # loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
    #     [tf.nn.l2_loss(v) for v in tf.trainable_variables()])
    loss = cross_entropy

    # The step tensor itself is unused; the call is kept because it adds the
    # global-step variable to the graph if it does not exist yet.
    tf.train.get_or_create_global_step()

    # optimizer = tf.train.MomentumOptimizer(
    #     learning_rate=_INITIAL_LEARNING_RATE, momentum=_MOMENTUM)
    # Batch norm requires update_ops to be added as a train_op dependency.
    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # with tf.control_dependencies(update_ops):
    if DUMP_GRAPHDEF:
        graphdef_path = 'imagenet_%d.pbtxt' % (RESNET_SIZE,)
        # Close (and flush) the file deterministically instead of relying on
        # garbage collection of an anonymous handle.
        with open(graphdef_path, 'w') as graphdef_file:
            graphdef_file.write(str(tf.get_default_graph().as_graph_def()))
    grads = tf.gradients(loss, tf.trainable_variables())
    # train_op = optimizer.minimize(loss, global_step)
    # grads_and_vars = list(zip(grads, tf.trainable_variables()))
    # train_op = optimizer.apply_gradients(grads_and_vars)
    # return train_op, loss
    return grads, loss
# Scale the learning rate linearly with the batch size. When the batch size is
# 256, the learning rate should be 0.1.
_INITIAL_LEARNING_RATE = 0.1 * FLAGS.batch_size / 256
_NUM_CHANNELS = 3
# NOTE(review): presumably 1000 ImageNet classes plus one background class —
# confirm against the label map.
_LABEL_CLASSES = 1001
_MOMENTUM = 0.9
_WEIGHT_DECAY = 1e-4
# Number of examples in each split.
_NUM_IMAGES = {
    'train': 1281167,
    'validation': 50000,
}

image_preprocessing_fn = vgg_preprocessing.preprocess_image
network = resnet_model.resnet_v2(
    resnet_size=FLAGS.resnet_size, num_classes=_LABEL_CLASSES)

batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size


def _shard_filenames(split, num_shards):
    """Return '<split>-NNNNN-of-MMMMM' paths under FLAGS.data_dir, one per shard."""
    # `range` (not Python-2-only `xrange`) yields the same indices on both
    # Python versions.
    return [
        os.path.join(FLAGS.data_dir, '%s-%05d-of-%05d' % (split, i, num_shards))
        for i in range(num_shards)
    ]


def filenames(is_training):
    """Return the shard file paths for the training or validation split.

    Training: 1024 files 'train-00000-of-01024' .. 'train-01023-of-01024'.
    Validation: 128 files 'validation-00000-of-00128' .. '...-00127-of-00128'.
    All paths are joined onto FLAGS.data_dir.
    """
    if is_training:
        return _shard_filenames('train', 1024)
    return _shard_filenames('validation', 128)
FLAGS = parser.parse_args()

# Evaluation steps needed to cover 50000 examples at the eval batch size
# (50000 is hard-coded; presumably the validation-set size — confirm).
_EVAL_STEPS = 50000 // FLAGS.eval_batch_size

# Scale the learning rate linearly with the batch size. When the batch size is
# 256, the learning rate should be 0.1.
_INITIAL_LEARNING_RATE = 0.1 * FLAGS.train_batch_size / 256
_MOMENTUM = 0.9
_WEIGHT_DECAY = 1e-4

train_dataset = imagenet.get_split('train', FLAGS.data_dir)
eval_dataset = imagenet.get_split('validation', FLAGS.data_dir)

image_preprocessing_fn = vgg_preprocessing.preprocess_image
# The class count is taken from the training split rather than hard-coded.
network = resnet_model.resnet_v2(
    resnet_size=FLAGS.resnet_size, num_classes=train_dataset.num_classes)

batches_per_epoch = train_dataset.num_samples / FLAGS.train_batch_size


def input_fn(is_training):
    """Input function which provides a single batch for train or eval."""
    batch_size = FLAGS.train_batch_size if is_training else FLAGS.eval_batch_size
    dataset = train_dataset if is_training else eval_dataset
    # Reader-queue sizing multipliers: larger for training than for eval.
    capacity_multiplier = 20 if is_training else 2
    min_multiplier = 10 if is_training else 1

    # NOTE(review): this call is truncated in the visible source; its remaining
    # arguments and the rest of input_fn are not shown.
    provider = tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
        dataset=dataset,
        num_readers=4,
        common_queue_capacity=capacity_multiplier * batch_size,
def resnet_model_fn(features, labels, mode, params):
    """TPUEstimator model_fn for ResNet with a warmup learning-rate schedule.

    Args:
        features: batch of input images fed to the network.
        labels: one-hot labels (passed as ``onehot_labels`` to the loss).
        mode: a tf.estimator.ModeKeys value.
        params: not read by this function.

    Returns:
        A tf.estimator.EstimatorSpec with predictions in PREDICT mode; in all
        other modes a tpu_estimator.TPUEstimatorSpec carrying the loss, the
        train op (TRAIN only) and ``(metric_fn, [labels, logits])`` (EVAL only).
    """
    network = resnet_model.resnet_v2(resnet_size=FLAGS.resnet_size,
                                     num_classes=_LABEL_CLASSES)
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    logits = network(inputs=features, is_training=is_training)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor'),
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Softmax cross entropy plus L2 weight decay over every trainable variable
    # (batch norm beta and gamma included). No cross_entropy / learning_rate
    # summaries are emitted in this version.
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits,
                                                    onehot_labels=labels)
    l2_total = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()])
    loss = cross_entropy + _WEIGHT_DECAY * l2_total

    if is_training:
        # Linear scaling rule: learning rate 0.1 at batch size 256.
        base_lr = 0.1 * FLAGS.train_batch_size / 256
        steps_per_epoch = 1281167 / FLAGS.train_batch_size
        global_step = tf.train.get_or_create_global_step()

        # Warm up gradually over epochs 0-5 ("Training ImageNet in 1 Hour"),
        # then decay by 0.1 at 30, 60, 120 and 150 epochs.
        boundary_epochs = list(range(1, 6)) + [30, 60, 120, 150]
        boundaries = [int(steps_per_epoch * epoch) for epoch in boundary_epochs]
        warmup_factors = [k / 6.0 for k in range(1, 6)]
        decay_factors = warmup_factors + [1, 0.1, 0.01, 1e-3, 1e-4]
        values = [base_lr * factor for factor in decay_factors]
        learning_rate = piecewise_constant(global_step, boundaries, values)

        momentum_opt = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                  momentum=_MOMENTUM)
        optimizer = tpu_optimizer.CrossShardOptimizer(momentum_opt)
        # Batch norm requires update_ops to be added as a train_op dependency.
        with tf.control_dependencies(
                tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    eval_metrics = ((metric_fn, [labels, logits])
                    if mode == tf.estimator.ModeKeys.EVAL else None)
    return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss,
                                          train_op=train_op,
                                          eval_metrics=eval_metrics)
def resnet_model_fn(features, labels, mode):
    """Estimator model_fn for ResNet.

    Writes an image summary of the inputs, builds resnet_v2, and returns an
    EstimatorSpec. In PREDICT mode the spec carries only predictions. Otherwise
    it also carries the loss, a train op (TRAIN only, else None) and an
    'accuracy' metric.

    Args:
        features: batch of input images.
        labels: one-hot labels (argmax over axis 1 gives the class ids).
        mode: a tf.estimator.ModeKeys value.
    """
    tf.summary.image('images', features, max_outputs=6)

    training = mode == tf.estimator.ModeKeys.TRAIN
    network = resnet_model.resnet_v2(resnet_size=FLAGS.resnet_size,
                                     num_classes=_LABEL_CLASSES)
    logits = network(inputs=features, is_training=training)

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor'),
    }
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Loss = softmax cross entropy + L2 regularization.
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits,
                                                    onehot_labels=labels)
    # Named tensor + summary so the value can be logged.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Weight decay applies to every trainable variable, batch norm beta and
    # gamma included.
    l2_terms = [tf.nn.l2_loss(v) for v in tf.trainable_variables()]
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(l2_terms)

    train_op = None
    if training:
        # Linear scaling rule: learning rate 0.1 at batch size 256.
        base_lr = 0.1 * FLAGS.batch_size / 256
        steps_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
        global_step = tf.train.get_or_create_global_step()

        # Multiply the learning rate by 0.1 at 30, 60, 80, and 90 epochs.
        boundaries = [int(steps_per_epoch * epoch)
                      for epoch in (30, 60, 80, 90)]
        values = [base_lr * factor for factor in (1, 0.1, 0.01, 1e-3, 1e-4)]
        step_as_int32 = tf.cast(global_step, tf.int32)
        learning_rate = tf.train.piecewise_constant(step_as_int32,
                                                    boundaries, values)
        # Named tensor + summary so the value can be logged.
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=_MOMENTUM)
        # Batch norm requires update_ops to be added as a train_op dependency.
        with tf.control_dependencies(
                tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.minimize(loss, global_step)

    accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1),
                                   predictions['classes'])
    # accuracy[1] is the metric's update op; name and summarize it for logging.
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions,
                                      loss=loss, train_op=train_op,
                                      eval_metric_ops={'accuracy': accuracy})
# Prepare the model saving directory.
save_dir = os.path.join(os.getcwd(), model_type)
model_full_name = 'cifar10_%s_model' % model_type
# exist_ok=True replaces the separate isdir() check, removing the race between
# checking and creating; it still raises if save_dir exists as a non-directory.
os.makedirs(save_dir, exist_ok=True)
filepath = os.path.join(save_dir, model_full_name)

if load_checkpoint:
    model = load_model(filepath)
else:
    # Build the graph. Both ResNet versions take identical keyword arguments;
    # only the builder differs.
    build_resnet = resnet_v2 if version == 2 else resnet_v1
    model = build_resnet(input_shape=input_shape,
                         depth=depth,
                         activation_bits=activation_bits,
                         weight_noise=weight_noise,
                         trainable_conv=not finetune,
                         trainable_dense=True)

# NOTE(review): compile runs for both loaded and freshly built models; for a
# loaded model this replaces the optimizer restored by load_model — confirm
# that is intended.
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(learning_rate=lr_schedule(0)),
              metrics=['accuracy'])
model.summary()
FLAGS = parser.parse_args()

# Evaluation steps needed to cover 50000 examples at the eval batch size
# (50000 is hard-coded; presumably the validation-set size — confirm).
_EVAL_STEPS = 50000 // FLAGS.eval_batch_size

# Scale the learning rate linearly with the batch size. When the batch size is
# 256, the learning rate should be 0.1.
_INITIAL_LEARNING_RATE = 0.1 * FLAGS.train_batch_size / 256
_MOMENTUM = 0.9
_WEIGHT_DECAY = 1e-4

train_dataset = imagenet.get_split('train', FLAGS.data_dir)
eval_dataset = imagenet.get_split('validation', FLAGS.data_dir)

image_preprocessing_fn = vgg_preprocessing.preprocess_image
# The class count is taken from the training split rather than hard-coded.
network = resnet_model.resnet_v2(resnet_size=FLAGS.resnet_size,
                                 num_classes=train_dataset.num_classes)

batches_per_epoch = train_dataset.num_samples / FLAGS.train_batch_size


def input_fn(is_training):
    """Input function which provides a single batch for train or eval."""
    batch_size = FLAGS.train_batch_size if is_training else FLAGS.eval_batch_size
    dataset = train_dataset if is_training else eval_dataset
    # Reader-queue sizing multipliers: larger for training than for eval.
    capacity_multiplier = 20 if is_training else 2
    min_multiplier = 10 if is_training else 1

    # NOTE(review): this call is truncated in the visible source; its remaining
    # arguments and the rest of input_fn are not shown.
    provider = tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
        dataset=dataset,
        num_readers=4,
        common_queue_capacity=capacity_multiplier * batch_size,
def resnet_model_fn(features, labels, mode, params):
    """TPUEstimator model_fn for ResNet with a scheduled learning rate.

    Args:
        features: batch of input images fed to the network.
        labels: one-hot labels (passed as ``onehot_labels`` to the loss).
        mode: a tf.estimator.ModeKeys value.
        params: dict; 'batches_per_epoch' and 'batch_size' are read from it.

    Returns:
        A tf.estimator.EstimatorSpec with predictions in PREDICT mode; in all
        other modes a tpu_estimator.TPUEstimatorSpec carrying the loss, the
        train op (TRAIN only) and the eval metrics (EVAL only).
    """
    resnet = resnet_model.resnet_v2(resnet_size=FLAGS.resnet_size,
                                    num_classes=_LABEL_CLASSES)
    logits = resnet(inputs=features,
                    is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'classes': tf.argmax(logits, axis=1),
                'probabilities': tf.nn.softmax(logits, name='softmax_tensor'),
            })

    # Softmax cross entropy plus L2 weight decay. Batch-normalization
    # variables are excluded from the decay because it slightly improves
    # accuracy.
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits,
                                                    onehot_labels=labels)
    decayed_vars = [v for v in tf.trainable_variables()
                    if 'batch_normalization' not in v.name]
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in decayed_vars])

    global_step = tf.train.get_global_step()
    current_epoch = (tf.cast(global_step, tf.float32) /
                     params['batches_per_epoch'])
    learning_rate = learning_rate_schedule(current_epoch)

    batch_size = params['batch_size']

    def per_example(value):
        """Tile a single-element tensor into shape [batch_size, 1]."""
        tiled = tf.tile(tf.expand_dims(value, 0), [batch_size])
        return tf.reshape(tiled, [batch_size, 1])

    # TODO(chrisying): this is a hack to get the LR and epoch for Tensorboard.
    # Reimplement this when TPU training summaries are supported.
    lr_repeat = per_example(learning_rate)
    ce_repeat = per_example(current_epoch)

    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=_MOMENTUM)
        if FLAGS.use_tpu:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
        # Batch norm requires update_ops to be added as a train_op dependency.
        with tf.control_dependencies(
                tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.minimize(loss, global_step)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(labels, logits, lr_repeat, ce_repeat):
            """Evaluation metric fn. Performed on CPU, do not reference TPU ops."""
            predicted = tf.argmax(logits, axis=1)
            return {
                'accuracy': tf.metrics.accuracy(tf.argmax(labels, axis=1),
                                                predicted),
                'learning_rate': tf.metrics.mean(lr_repeat),
                'current_epoch': tf.metrics.mean(ce_repeat),
            }

        eval_metrics = (metric_fn, [labels, logits, lr_repeat, ce_repeat])

    return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss,
                                          train_op=train_op,
                                          eval_metrics=eval_metrics)