def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator."""
    del labels
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh')
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = SimdMeshImpl(mesh_shape,
                             mtf.convert_to_layout_rules(FLAGS.layout),
                             mesh_devices, params['context'].device_assignment)
    with mtf_utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf_optimize.AdafactorOptimizer()
        update_ops = []
        for grad, var in zip(var_grads, graph.trainable_variables):
            update_ops.extend(optimizer.apply_grad(grad, var))
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf_utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                mean_logits = tf.metrics.mean(tf_logits)
                return {'mean_logits': mean_logits}

            eval_metrics = (metric_fn, [tf_logits])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
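Each model_fn on this page is meant to be handed to a TPUEstimator. A minimal wiring sketch (not part of the example above; the flag names, the cluster resolver, and the input_fn are assumptions, and a real Mesh TensorFlow setup may need additional TPUConfig options):

# Hedged sketch: handing a model_fn like the one above to TPUEstimator.
import tensorflow as tf

resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)  # hypothetical FLAGS.tpu
run_config = tf.contrib.tpu.RunConfig(
    cluster=resolver,
    model_dir=FLAGS.model_dir,
    save_checkpoints_steps=1000,
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    use_tpu=True,
    train_batch_size=FLAGS.train_batch_size,  # hypothetical flag
    eval_batch_size=FLAGS.eval_batch_size,    # hypothetical flag
    config=run_config)

estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)  # train_input_fn assumed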
Example #2
def resnet_model_fn(features, labels, mode, params):
  """Our model_fn for ResNet to be used with our Estimator."""
  network = resnet_model.resnet_v2(
      resnet_size=FLAGS.resnet_size, num_classes=_LABEL_CLASSES)

  logits = network(
      inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  # tf.identity(cross_entropy, name='cross_entropy')
  # tf.summary.scalar('cross_entropy', cross_entropy)

  # Add weight decay to the loss. We perform weight decay on all trainable
  # variables, which includes batch norm beta and gamma variables.
  loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Scale the learning rate linearly with the batch size. When the batch size is
    # 256, the learning rate should be 0.1.
    _INITIAL_LEARNING_RATE = 0.1 * FLAGS.train_batch_size / 256

    batches_per_epoch = 1281167 / FLAGS.train_batch_size
    global_step = tf.train.get_or_create_global_step()

    # Perform a gradual warmup of the learning rate, as in the paper "Training
    # ImageNet in 1 Hour." Afterward, decay the learning rate by 0.1 at 30, 60,
    # 120, and 150 epochs.
    boundaries = [int(batches_per_epoch * epoch) for epoch in [
        1, 2, 3, 4, 5, 30, 60, 120, 150]]
    values = [_INITIAL_LEARNING_RATE * decay for decay in [
        1.0 / 6, 2.0 / 6, 3.0 / 6, 4.0 / 6, 5.0 / 6, 1, 0.1, 0.01, 1e-3, 1e-4]]
    learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

    # Create a tensor named learning_rate for logging purposes.
    # tf.identity(learning_rate, name='learning_rate')
    # tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate,
        momentum=_MOMENTUM)
    optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

    # Batch norm requires update_ops to be added as a train_op dependency.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)
  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    eval_metrics = (metric_fn, [labels, logits])

  return tpu_estimator.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metrics=eval_metrics)
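Example #2 references a metric_fn that is defined elsewhere in its project. A plausible sketch, assuming labels are one-hot as in the cross-entropy loss above:

def metric_fn(labels, logits):
    """Hypothetical eval metric fn; runs on the CPU host, so it must not
    reference Tensors created elsewhere in the model_fn."""
    predictions = tf.argmax(logits, axis=1)
    accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1), predictions)
    return {'accuracy': accuracy}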
Example #3
    def model_fn(self, features, labels, mode, params):
        """Build the model based on features, labels, and mode.

    Args:
      features: The features dictionary containing the data Tensor
        and the number of examples.
      labels: The labels Tensor produced by the input_fn.
      mode: A tf.estimator.ModeKeys value indicating train, eval or predict.
      params: A dictionary of hyperparameters.

    Returns:
      A tf.estimator.EstimatorSpec.
    """
        del params
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        eval_active = (mode == tf.estimator.ModeKeys.EVAL)
        is_predict = (mode == tf.estimator.ModeKeys.PREDICT)
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC
        labels = tf.one_hot(labels, LABEL_CLASSES)
        loss, logits = self._build_network(features, labels, mode)

        if is_predict:
            predictions = {'logits': logits}
            return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                                  predictions=predictions)

        host_call = None
        train_op = None

        if is_training:
            global_step = tf.train.get_or_create_global_step()
            gs_t = tf.reshape(tf.cast(global_step, tf.int32), [1])

            # Setup learning rate schedule
            learning_rate = self._build_learning_rate_schedule(global_step)

            # Setup optimizer.
            optimizer = self._build_optimizer(learning_rate)

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = self._build_train_op(optimizer,
                                                loss,
                                                global_step=global_step)
            if self.hparams.moving_average_decay > 0:
                ema = tf.train.ExponentialMovingAverage(
                    decay=self.hparams.moving_average_decay,
                    num_updates=global_step)
                variables_to_average = (tf.trainable_variables() +
                                        tf.moving_average_variables())
                with tf.control_dependencies([train_op]):
                    with tf.name_scope('moving_average'):
                        train_op = ema.apply(variables_to_average)

            lr_t = tf.reshape(learning_rate, [1])
            host_call = None
            if self.hparams.enable_hostcall:

                def host_call_fn(gs, lr):
                    # Outfeed supports int32 but global_step is expected to be int64.
                    gs = tf.cast(tf.reduce_mean(gs), tf.int64)
                    with summary.create_file_writer(
                            self.model_dir).as_default():
                        with summary.always_record_summaries():
                            summary.scalar('learning_rate',
                                           tf.reduce_mean(lr),
                                           step=gs)
                            return summary.all_summary_ops()

                host_call = (host_call_fn, [gs_t, lr_t])

        eval_metrics = None
        eval_metric_ops = None
        if eval_active:

            def metric_fn(labels, logits):
                """Evaluation metric fn. Performed on CPU, do not reference TPU ops."""
                predictions = tf.argmax(logits, axis=1)
                categorical_labels = tf.argmax(labels, axis=1)
                top_1_accuracy = tf.metrics.accuracy(categorical_labels,
                                                     predictions)
                in_top_5 = tf.cast(
                    tf.nn.in_top_k(logits, categorical_labels, 5), tf.float32)
                top_5_accuracy = tf.metrics.mean(in_top_5)

                return {
                    'top_1_accuracy': top_1_accuracy,
                    'top_5_accuracy': top_5_accuracy,
                }

            eval_metrics = (metric_fn, [labels, logits])
            eval_metric_ops = metric_fn(labels, logits)

        if self.hparams.use_tpu:
            return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                                  loss=loss,
                                                  train_op=train_op,
                                                  host_call=host_call,
                                                  eval_metrics=eval_metrics)
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)
Example #4
def model_fn(features, labels, mode, params):
    del params  # Unused.

    if mode != tf.estimator.ModeKeys.TRAIN:
        raise RuntimeError("mode {} is not supported yet".format(mode))

    LayerBlock = namedtuple('LayerBlock',
                            ['num_repeats', 'num_filters', 'bottleneck_size'])
    blocks = [
        LayerBlock(3, 256, 64),
        LayerBlock(4, 512, 128),
        LayerBlock(6, 1024, 256),
        LayerBlock(3, 2048, 512)
    ]

    features = tf.transpose(features, [0, 3, 1, 2])
    # %%
    # First convolution expands to 64 channels and downsamples

    net = conv2d_fixed_padding(inputs=features,
                               filters=64,
                               kernel_size=7,
                               strides=2)
    #net = tf.layers.conv2d(inputs=features, filters=64, kernel_size=7, strides = 2, padding='VALID')
    net = tf.identity(net, 'initial_conv')
    # %%
    # Max pool and downsampling
    net = tf.layers.max_pooling2d(inputs=net,
                                  pool_size=3,
                                  strides=2,
                                  padding='SAME',
                                  data_format='channels_first')
    net = tf.identity(net, 'initial_max_pool')

    # %%
    for block_i, block in enumerate(blocks):
        filters_out = block.num_filters

        net = batch_norm_relu(net)
        shortcut = conv2d_fixed_padding(inputs=net,
                                        filters=filters_out,
                                        kernel_size=1,
                                        strides=2)
        #shortcut = tf.layers.conv2d(
        #  inputs=net, filters=block.num_filters, kernel_size=1, strides=2,
        #  padding='VALID', data_format='channels_first')
        net = conv2d_fixed_padding(inputs=net,
                                   filters=block.bottleneck_size,
                                   kernel_size=1,
                                   strides=1)
        #net = tf.layers.conv2d(
        #  inputs=net, filters=block.bottleneck_size, kernel_size=1, strides=1,
        #  padding='VALID', data_format='channels_first')
        net = batch_norm_relu(net)
        net = conv2d_fixed_padding(inputs=net,
                                   filters=block.bottleneck_size,
                                   kernel_size=3,
                                   strides=2)
        #net = fixed_padding(net, 3)
        #net = tf.layers.conv2d(
        #  inputs=net, filters=block.bottleneck_size, kernel_size=3, strides=2,
        #  padding='SAME', data_format='channels_first')

        net = batch_norm_relu(net)
        net = conv2d_fixed_padding(inputs=net,
                                   filters=4 * block.bottleneck_size,
                                   kernel_size=1,
                                   strides=1)
        #net = tf.layers.conv2d(
        #  inputs=net, filters=block.num_filters, kernel_size=1, strides=1,
        #  padding='VALID', data_format='channels_first')

        net = net + shortcut
        net = tf.identity(net)

        for repeat_i in range(1, block.num_repeats):
            shortcut = net
            net = batch_norm_relu(net)
            net = conv2d_fixed_padding(inputs=net,
                                       filters=block.bottleneck_size,
                                       kernel_size=1,
                                       strides=1)
            net = batch_norm_relu(net)
            net = conv2d_fixed_padding(inputs=net,
                                       filters=block.bottleneck_size,
                                       kernel_size=3,
                                       strides=1)

            net = batch_norm_relu(net)
            net = conv2d_fixed_padding(inputs=net,
                                       filters=4 * block.bottleneck_size,
                                       kernel_size=1,
                                       strides=1)

            net = net + shortcut
            net = tf.identity(net)

    # %%
    net = batch_norm_relu(net)
    net = tf.layers.average_pooling2d(net,
                                      pool_size=net.get_shape().as_list()[2],
                                      strides=1,
                                      padding='VALID',
                                      data_format='channels_first')
    net = tf.identity(net, 'final_avg_pool')
    net = tf.reshape(net, [
        -1,
        net.get_shape().as_list()[1] * net.get_shape().as_list()[2] *
        net.get_shape().as_list()[3]
    ])

    net = tf.layers.dense(inputs=net, units=FLAGS.num_classes)
    logits = tf.identity(net, 'final_dense')

    # Calculating the loss.
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

    # Configuring the optimization algorithm.
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                               tf.train.get_global_step(),
                                               25000, 0.97)
    if FLAGS.use_tpu:
        optimizer = tpu_optimizer.CrossShardOptimizer(
            tf.train.GradientDescentOptimizer(learning_rate=learning_rate))
    else:
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)

    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op)
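Example #4 relies on conv2d_fixed_padding and batch_norm_relu helpers that are not shown. A sketch consistent with the channels_first layout used above (patterned after the official ResNet helpers, but treat the details as assumptions):

def conv2d_fixed_padding(inputs, filters, kernel_size, strides):
    """Strided conv with explicit padding so output size does not depend on input parity."""
    if strides > 1:
        pad_total = kernel_size - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        # NCHW layout: pad only the two spatial dimensions.
        inputs = tf.pad(inputs, [[0, 0], [0, 0],
                                 [pad_beg, pad_end], [pad_beg, pad_end]])
    return tf.layers.conv2d(
        inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,
        padding=('SAME' if strides == 1 else 'VALID'), use_bias=False,
        kernel_initializer=tf.variance_scaling_initializer(),
        data_format='channels_first')


def batch_norm_relu(inputs, is_training=True):
    """Hypothetical batch-norm + ReLU helper (axis=1 for channels_first)."""
    inputs = tf.layers.batch_normalization(
        inputs, axis=1, training=is_training, fused=True)
    return tf.nn.relu(inputs)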
Example #5
def model_fn(features, labels, mode, params):
    """Our model_fn for Densenet to be used with our Estimator."""
    tf.logging.info("model_fn")

    if FLAGS.network_depth == 169:
        logits = densenet_model.densenet_imagenet_169(
            features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
    elif FLAGS.network_depth == 201:
        logits = densenet_model.densenet_imagenet_201(
            features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
    elif FLAGS.network_depth == 121:
        logits = densenet_model.densenet_imagenet_121(
            features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
    else:
        tf.logging.info("Number of layers not supported, revert to 121")
        logits = densenet_model.densenet_imagenet_121(
            features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    onehot_labels = tf.one_hot(labels, _LABEL_CLASSES)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=onehot_labels)

    # Add weight decay to the loss. We exclude weight decay on the batch
    # normalization variables because it slightly improves accuracy.
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if "batch_normalization" not in v.name
    ])

    global_step = tf.train.get_global_step()
    current_epoch = (tf.cast(global_step, tf.float32) /
                     params["batches_per_epoch"])
    learning_rate = learning_rate_schedule(current_epoch)

    # TODO(chrisying): this is a hack to get the LR and epoch for Tensorboard.
    # Reimplement this when TPU training summaries are supported.
    lr_repeat = tf.reshape(
        tf.tile(tf.expand_dims(learning_rate, 0), [
            params["batch_size"],
        ]), [params["batch_size"], 1])
    ce_repeat = tf.reshape(
        tf.tile(tf.expand_dims(current_epoch, 0), [
            params["batch_size"],
        ]), [params["batch_size"], 1])

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=_MOMENTUM)
        optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits, lr_repeat, ce_repeat):
            """Evaluation metric fn. Performed on CPU, do not reference TPU ops."""
            predictions = tf.argmax(logits, axis=1)
            accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1),
                                           predictions)
            lr = tf.metrics.mean(lr_repeat)
            ce = tf.metrics.mean(ce_repeat)
            return {
                "accuracy": accuracy,
                "learning_rate": lr,
                "current_epoch": ce
            }

        eval_metrics = (metric_fn, [labels, logits, lr_repeat, ce_repeat])

    param_stats = tf.profiler.profile(
        tf.get_default_graph(),
        options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    fl_stats = tf.profiler.profile(
        tf.get_default_graph(),
        options=tf.profiler.ProfileOptionBuilder.float_operation())

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metrics=eval_metrics)
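Example #5 calls a learning_rate_schedule(current_epoch) that is not shown. A hedged sketch assuming a conventional step decay (the base rate and epoch boundaries below are illustrative, not the project's actual values):

_INITIAL_LR = 0.1  # hypothetical base rate

def learning_rate_schedule(current_epoch):
    """Illustrative step schedule: decay the rate by 10x at a few epoch boundaries."""
    lr = tf.constant(_INITIAL_LR)
    for boundary, factor in [(30, 0.1), (60, 0.01), (90, 0.001)]:
        # Keep the current rate before the boundary, switch to the decayed rate after it.
        lr = tf.where(current_epoch < boundary, lr, _INITIAL_LR * factor)
    return lr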
Example #6
def inception_model_fn(features, labels, mode, params):
    """Inception v3 model using Estimator API."""
    num_classes = FLAGS.num_classes
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    is_eval = (mode == tf.estimator.ModeKeys.EVAL)
    features = tensor_transform_fn(features, params['input_perm'])

    if FLAGS.clear_update_collections:
        # updates_collections must be set to None in order to use fused batchnorm
        with arg_scope(
                inception.inception_v3_arg_scope(
                    batch_norm_decay=BATCH_NORM_DECAY,
                    batch_norm_epsilon=BATCH_NORM_EPSILON,
                    updates_collections=None)):
            logits, end_points = inception.inception_v3(
                features, num_classes, is_training=is_training)
    else:
        with arg_scope(
                inception.inception_v3_arg_scope(
                    batch_norm_decay=BATCH_NORM_DECAY,
                    batch_norm_epsilon=BATCH_NORM_EPSILON)):
            logits, end_points = inception.inception_v3(
                features, num_classes, is_training=is_training)

    predictions = end_points
    predictions.update({
        'classes':
        tf.argmax(input=logits, axis=1),
        'probabilities':
        tf.nn.softmax(logits, name='softmax_tensor')
    })

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    if mode == tf.estimator.ModeKeys.EVAL and FLAGS.display_tensors and (
            not FLAGS.use_tpu):
        with tf.control_dependencies([
                tf.Print(predictions['classes'], [predictions['classes']],
                         summarize=FLAGS.eval_batch_size,
                         message='prediction: ')
        ]):
            labels = tf.Print(labels, [labels],
                              summarize=FLAGS.eval_batch_size,
                              message='label: ')

    one_hot_labels = tf.one_hot(labels, FLAGS.num_classes, dtype=tf.int32)

    if 'AuxLogits' in end_points:
        tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels,
                                        logits=end_points['AuxLogits'],
                                        weights=0.4,
                                        label_smoothing=0.1,
                                        scope='aux_loss')

    tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels,
                                    logits=logits,
                                    weights=1.0,
                                    label_smoothing=0.1)
    loss = tf.losses.get_total_loss(add_regularization_losses=True)

    initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
    if FLAGS.use_learning_rate_warmup:
        # Adjust initial learning rate to match final warmup rate
        warmup_decay = FLAGS.learning_rate_decay**(
            (FLAGS.warmup_epochs + FLAGS.cold_epochs) /
            FLAGS.learning_rate_decay_epochs)
        adj_initial_learning_rate = initial_learning_rate * warmup_decay

    final_learning_rate = 0.0001 * initial_learning_rate

    host_call = None
    train_op = None
    if is_training:
        batches_per_epoch = _NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        global_step = tf.train.get_or_create_global_step()
        current_epoch = tf.cast(
            (tf.cast(global_step, tf.float32) / batches_per_epoch), tf.int32)

        learning_rate = tf.train.exponential_decay(
            learning_rate=initial_learning_rate,
            global_step=global_step,
            decay_steps=int(FLAGS.learning_rate_decay_epochs *
                            batches_per_epoch),
            decay_rate=FLAGS.learning_rate_decay,
            staircase=True)

        if FLAGS.use_learning_rate_warmup:
            wlr = 0.1 * adj_initial_learning_rate
            wlr_height = tf.cast(
                0.9 * adj_initial_learning_rate /
                (FLAGS.warmup_epochs + FLAGS.learning_rate_decay_epochs - 1),
                tf.float32)
            epoch_offset = tf.cast(FLAGS.cold_epochs - 1, tf.int32)
            exp_decay_start = (FLAGS.warmup_epochs + FLAGS.cold_epochs +
                               FLAGS.learning_rate_decay_epochs)
            lin_inc_lr = tf.add(
                wlr,
                tf.multiply(
                    tf.cast(tf.subtract(current_epoch, epoch_offset),
                            tf.float32), wlr_height))
            learning_rate = tf.where(
                tf.greater_equal(current_epoch, FLAGS.cold_epochs),
                (tf.where(tf.greater_equal(current_epoch, exp_decay_start),
                          learning_rate, lin_inc_lr)), wlr)

        # Set a minimum boundary for the learning rate.
        learning_rate = tf.maximum(learning_rate,
                                   final_learning_rate,
                                   name='learning_rate')

        if FLAGS.optimizer == 'sgd':
            tf.logging.info('Using SGD optimizer')
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        elif FLAGS.optimizer == 'momentum':
            tf.logging.info('Using Momentum optimizer')
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=0.9)
        elif FLAGS.optimizer == 'RMS':
            tf.logging.info('Using RMS optimizer')
            optimizer = tf.train.RMSPropOptimizer(learning_rate,
                                                  RMSPROP_DECAY,
                                                  momentum=RMSPROP_MOMENTUM,
                                                  epsilon=RMSPROP_EPSILON)
        else:
            tf.logging.fatal('Unknown optimizer: %s', FLAGS.optimizer)

        if FLAGS.use_tpu:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step=global_step)
        if FLAGS.moving_average:
            ema = tf.train.ExponentialMovingAverage(decay=MOVING_AVERAGE_DECAY,
                                                    num_updates=global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            with tf.control_dependencies([train_op
                                          ]), tf.name_scope('moving_average'):
                train_op = ema.apply(variables_to_average)

        # To log the loss, current learning rate, and epoch for Tensorboard, the
        # summary op needs to be run on the host CPU via host_call. host_call
        # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
        # dimension. These Tensors are implicitly concatenated to
        # [params['batch_size']].
        gs_t = tf.reshape(global_step, [1])
        loss_t = tf.reshape(loss, [1])
        lr_t = tf.reshape(learning_rate, [1])
        ce_t = tf.reshape(current_epoch, [1])

        def host_call_fn(gs, loss, lr, ce):
            """Training host call. Creates scalar summaries for training metrics.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `host_call`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `host_call`.

      Args:
        gs: `Tensor` with shape `[batch]` for the global_step.
        loss: `Tensor` with shape `[batch]` for the training loss.
        lr: `Tensor` with shape `[batch]` for the learning_rate.
        ce: `Tensor` with shape `[batch]` for the current_epoch.

      Returns:
        List of summary ops to run on the CPU host.
      """
            gs = gs[0]
            with summary.create_file_writer(FLAGS.model_dir).as_default():
                with summary.always_record_summaries():
                    summary.scalar('loss', tf.reduce_mean(loss), step=gs)
                    summary.scalar('learning_rate',
                                   tf.reduce_mean(lr),
                                   step=gs)
                    summary.scalar('current_epoch',
                                   tf.reduce_mean(ce),
                                   step=gs)

                    return summary.all_summary_ops()

        host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

    eval_metrics = None
    if is_eval:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch, ]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'accuracy': top_1_accuracy,
                'accuracy@5': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          host_call=host_call,
                                          eval_metrics=eval_metrics)
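Example #6 calls a tensor_transform_fn that is not shown. A plausible sketch, assuming it simply undoes a pipeline-side transpose behind a hypothetical flag:

def tensor_transform_fn(data, perm):
    """Hypothetical helper: transpose features back to the layout the network
    expects when the input pipeline emitted them transposed (flag name assumed)."""
    if FLAGS.transpose_enabled:
        return tf.transpose(data, perm)
    return data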
Example #7
def model_fn(features, labels, mode, params):
    """Create a Feed forward network classification network

  Args:
    features (dict): Dictionary of input feature Tensors
    labels (Tensor): Class label Tensor
    mode (string): Mode running training, evaluation or prediction
    params (dict): Dictionary of additional params like batch_size

  Returns:
    Depending on the mode returns Tuple or Dict
  Raises:
    RuntimeError: if input mode is not TRAIN
  """

    del params

    embedding_size = FLAGS.embedding_size

    hidden_units = [100, 70, 50, 20]

    # Keep variance constant with changing embedding sizes.
    with tf.variable_scope('embeddings',
                           initializer=tf.truncated_normal_initializer(
                               stddev=(1.0 /
                                       math.sqrt(float(embedding_size))))):
        for col, vals in CATEGORICAL_COLS:
            bucket_size = vals if isinstance(vals, int) else len(vals)
            embeddings = tf.get_variable(col,
                                         shape=[bucket_size, embedding_size])

            features[col] = tf.squeeze(tf.nn.embedding_lookup(
                embeddings, features[col]),
                                       axis=[1])

    # Concatenate the (now all dense) features.
    # We need to sort the tensors so that they end up in the same order for
    # prediction, evaluation, and training
    sorted_feature_tensors = list(zip(*sorted(features.items())))[1]
    inputs = tf.concat(sorted_feature_tensors, 1)

    # Build the DNN

    layers_size = [inputs.get_shape()[1]] + hidden_units
    layers_shape = list(zip(layers_size, layers_size[1:] + [len(LABELS)]))

    curr_layer = inputs
    # Set default initializer to variance_scaling_initializer
    # This initializer prevents variance from exploding or vanishing when
    # compounded through different sized layers.
    with tf.variable_scope(
            'dnn',
            initializer=tf.contrib.layers.variance_scaling_initializer()):
        # Creates the relu hidden layers
        for num, shape in enumerate(layers_shape):
            with tf.variable_scope('relu_{}'.format(num)):

                weights = tf.get_variable('weights', shape)

                biases = tf.get_variable('biases',
                                         shape[1],
                                         initializer=tf.zeros_initializer(
                                             tf.float32))

            activations = tf.matmul(curr_layer, weights) + biases
            if num < len(layers_shape) - 1:
                curr_layer = tf.nn.relu(activations)
            else:
                curr_layer = activations

    # Make predictions
    logits = curr_layer
    probabilities = tf.nn.softmax(logits)
    predicted_indices = tf.argmax(probabilities, 1)
    predictions = {
        'predictions': tf.gather(labels, predicted_indices),
        'confidence': tf.reduce_max(probabilities, axis=1)
    }

    # Make labels a vector
    label_indices_vector = tf.squeeze(labels)

    # global_step is necessary in eval to correctly load the step
    # of the checkpoint we are evaluating
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Build training operation.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_indices_vector))
    # tf.summary.scalar('loss', loss)
    ftrl = tf.train.FtrlOptimizer(learning_rate=FLAGS.learning_rate,
                                  l1_regularization_strength=3.0,
                                  l2_regularization_strength=10.0)
    if FLAGS.use_tpu:
        optimizer = tpu_optimizer.CrossShardOptimizer(ftrl)
    else:
        optimizer = ftrl

    train_op = optimizer.minimize(loss, global_step=global_step)

    # Return accuracy and area under ROC curve metrics
    # See https://en.wikipedia.org/wiki/Receiver_operating_characteristic
    # See https://www.kaggle.com/wiki/AreaUnderCurve
    def metric_fn(labels, probabilities):
        accuracy = tf.contrib.metrics.streaming_accuracy(
            tf.argmax(probabilities, 1), labels)
        auroc = tf.contrib.metrics.streaming_auc(tf.argmax(probabilities, 1),
                                                 labels)
        return {'accuracy': accuracy, 'auroc': auroc}

    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        predictions=predictions,
        train_op=train_op,
        eval_metrics=(metric_fn, [labels, probabilities]))
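The metric_fn in Example #7 uses the deprecated tf.contrib.metrics streaming ops. An equivalent sketch with core tf.metrics, assuming a binary task where AUC is computed from the positive-class probability rather than the hard argmax:

def metric_fn(labels, probabilities):
    """Hypothetical replacement using non-contrib metric ops."""
    predictions = tf.argmax(probabilities, axis=1)
    accuracy = tf.metrics.accuracy(labels, predictions)
    # AUC over the positive-class probability (assumes a two-class problem).
    auroc = tf.metrics.auc(labels, probabilities[:, 1])
    return {'accuracy': accuracy, 'auroc': auroc}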
Example #8
File: utils.py Project: pu2wof/mesh
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: optional parameter dictionary; carries the TPU context when
        running under TPUEstimator
      config: optional RunConfig (unused)

    Returns:
      a TPUEstimatorSpec (on TPU) or a tf.estimator.EstimatorSpec
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
        batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
        length_dim = mtf.Dimension("length", sequence_length)
        feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

        mtf_features = {}
        for key, x in features.items():
            x = tf.to_int32(features[key])
            x = tf.reshape(x, [
                outer_batch_size, batch_size // outer_batch_size,
                sequence_length
            ])
            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = mtf_features["inputs"]
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        elif mode == tf.estimator.ModeKeys.EVAL:
            raise NotImplementedError("We don't expect to use mode == eval.")

        else:
            assert mode == tf.estimator.ModeKeys.TRAIN
            num_microbatches = serialize_num_microbatches(
                batch_dim, length_dim, mesh_shape, layout_rules)

            def model_fn(mtf_features):
                """The kind of function we need for mtf.serialize_training_step.

        Args:
          mtf_features: a dictionary
        Returns:
          a dictionary
        """
                targets = mtf_features["targets"]
                if model_type == "lm":
                    _, _, length_dim = targets.shape
                    inputs = mtf.shift(targets,
                                       offset=1,
                                       dim=length_dim,
                                       wrap=False)
                else:
                    inputs = mtf_features["inputs"]

                if isinstance(transformer_model, transformer.Unitransformer):
                    position_kwargs = dict(
                        sequence_id=mtf_features.get("targets_segmentation",
                                                     None),
                        position=mtf_features.get("targets_position", None),
                    )
                elif isinstance(transformer_model, transformer.Bitransformer):
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
                else:
                    raise ValueError("unrecognized class")

                logits, loss = transformer_model.call_simple(
                    inputs=inputs,
                    targets=targets,
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    **position_kwargs)
                if num_microbatches > 1:
                    loss /= float(num_microbatches)
                del logits
                return {"loss": loss}

            if num_microbatches > 1:
                var_grads, loss_dict = mtf.serialize_training_step(
                    mtf_features, model_fn, batch_dim, num_microbatches)
            else:
                loss_dict = model_fn(mtf_features)
                var_grads = mtf.gradients(
                    [loss_dict["loss"]],
                    [v.outputs[0] for v in graph.trainable_variables])

            loss = loss_dict["loss"]

            if callable(learning_rate_schedule):
                # the following happens on CPU since TPU can't handle summaries.
                with mtf.utils.outside_all_rewrites():
                    learning_rate = learning_rate_schedule(
                        step=tf.train.get_global_step())
                    tf.summary.scalar("learning_rate", learning_rate)
            else:
                learning_rate = learning_rate_schedule

            update_ops = optimizer(learning_rate=learning_rate).apply_grads(
                var_grads, graph.trainable_variables)

            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)

            tf_loss = lowering.export_to_tf_tensor(loss)
            tf_loss = tf.to_float(tf_loss)
            if not use_tpu:
                tf_loss = tf.Print(
                    tf_loss, [tf_loss, tf.train.get_global_step()],
                    "step, tf_loss")

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

            with mtf.utils.outside_all_rewrites():
                # Copy master variables to slices. Must be called first.
                restore_hook = mtf.MtfRestoreHook(lowering)
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=keep_checkpoint_max,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    model_dir,
                    save_steps=save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                    model_dir, summarize_config=True)

                if use_tpu:
                    if tpu_summaries:
                        tf.summary.scalar("loss", tf_loss)
                        host_call = mtf.utils.create_host_call(model_dir)
                        mtf.utils.remove_summaries()
                    else:
                        host_call = None
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        host_call=host_call,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
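In Example #8, learning_rate_schedule may be either a constant or a callable taking the global step. A hedged sketch of such a callable (linear warmup into reciprocal square-root decay, a common transformer choice; the constants are made up):

def learning_rate_schedule(step):
    """Hypothetical schedule: linear warmup for 10k steps, then rsqrt decay."""
    step = tf.cast(step, tf.float32)
    warmup_steps = 10000.0
    return 0.01 * tf.minimum(step / warmup_steps,
                             tf.rsqrt(tf.maximum(step, 1.0) / warmup_steps))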
Example #9
def model_fn(features, labels, mode, params):

    del labels

    cfg = params['cfg']
    model = models.model(cfg)
    y = features['y']

    if mode == tf.estimator.ModeKeys.PREDICT:
        ###########
        # PREDICT #
        ###########
        predictions = {'generated_images': model.sample(y, temp=0.75)}
        return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                              predictions=predictions)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    real_images = features['real_images']

    f_loss, eps = model.f_loss(real_images, y)

    if mode == tf.estimator.ModeKeys.TRAIN:
        #########
        # TRAIN #
        #########

        f_loss = tf.reduce_mean(f_loss)

        with tf.variable_scope('Regularization'):
            for v in tf.trainable_variables():
                if 'invconv' in v.name:
                    det = tf.matrix_determinant(v * tf.transpose(v))
                    f_loss += tf.square(det - 1)

            if cfg.use_l2_regularization:
                for v in tf.trainable_variables():
                    if 'actnorm' not in v.name:
                        f_loss += cfg.l2_regularization_factor * tf.nn.l2_loss(
                            v)

        if not cfg.use_tpu and cfg.report_histograms:
            for v in tf.trainable_variables():
                tf.summary.histogram(v.name.replace(':', '_'), v)

        global_step = tf.train.get_or_create_global_step()
        rate = tf.minimum(tf.cast(global_step, tf.float32) / 2000.0, 1.0)
        #lr = int(real_images.get_shape()[0]) * cfg.lr
        lr = cfg.lr * rate
        #from AMSGrad import AMSGrad
        optimizer = tf.train.AdamOptimizer(learning_rate=lr,
                                           beta1=cfg.beta1,
                                           epsilon=cfg.adam_eps)

        tf.summary.scalar('lr', lr)

        if cfg.use_tpu:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            with tf.variable_scope('TrainOps'):
                if cfg.memory_saving_gradients:
                    from memory_saving_gradients import gradients
                    gs = gradients(f_loss, tf.trainable_variables())
                else:
                    gs = tf.gradients(f_loss, tf.trainable_variables())
                if cfg.use_gradient_clipping:
                    gs = [tf.clip_by_value(g, -100., 100.) for g in gs]
                grads_and_vars = list(zip(gs, tf.trainable_variables()))
                train_op = optimizer.apply_gradients(grads_and_vars)
                increment_step = tf.assign_add(
                    tf.train.get_or_create_global_step(), 1)
                joint_op = tf.group([train_op, increment_step])

            return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                                  loss=f_loss,
                                                  train_op=joint_op)

    elif mode == tf.estimator.ModeKeys.EVAL:
        ########
        # EVAL #
        ########
        def _eval_metric_fn(f_loss):
            return {'f_loss': tf.metrics.mean(f_loss)}

        return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                              loss=tf.reduce_mean(f_loss),
                                              eval_metrics=(_eval_metric_fn,
                                                            [f_loss]))

    raise ValueError('Invalid mode provided to model_fn')
Example #10
def model_fn(features, labels, mode, params):
    """Constructs DCGAN from individual generator and discriminator networks."""
    del labels  # Unconditional GAN does not use labels

    if mode == tf.estimator.ModeKeys.PREDICT:
        ###########
        # PREDICT #
        ###########
        # Pass only noise to PREDICT mode
        random_noise = features['random_noise']
        predictions = {
            'generated_samples': model.generator(random_noise,
                                                 is_training=False)
        }

        return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                              predictions=predictions)

    batch_size = params['batch_size']  # pylint: disable=unused-variable
    real_samples = features['samples']
    random_noise = features['random_noise']

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    generated_samples = model.generator(random_noise, is_training=is_training)

    # Get logits from discriminator
    d_real = model.discriminator(real_samples)
    d_fake = model.discriminator(generated_samples)

    d_loss, g_loss = LSGAN(d_real, d_fake)

    if mode == tf.estimator.ModeKeys.TRAIN:
        #########
        # TRAIN #
        #########
        d_loss = tf.reduce_mean(d_loss)
        g_loss = tf.reduce_mean(g_loss)

        d_optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
                                             beta1=FLAGS.beta1)
        g_optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
                                             beta1=FLAGS.beta1)

        if FLAGS.use_tpu:
            d_optimizer = tpu_optimizer.CrossShardOptimizer(d_optimizer)
            g_optimizer = tpu_optimizer.CrossShardOptimizer(g_optimizer)

        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            d_step = d_optimizer.minimize(d_loss,
                                          var_list=tf.get_collection(
                                              tf.GraphKeys.GLOBAL_VARIABLES,
                                              scope='Discriminator'))
            g_step = g_optimizer.minimize(g_loss,
                                          var_list=tf.get_collection(
                                              tf.GraphKeys.GLOBAL_VARIABLES,
                                              scope='Generator'))

            increment_step = tf.assign_add(
                tf.train.get_or_create_global_step(), 1)
            joint_op = tf.group([d_step, g_step, increment_step])

            return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                                  loss=g_loss,
                                                  train_op=joint_op)

    elif mode == tf.estimator.ModeKeys.EVAL:
        ########
        # EVAL #
        ########
        def _eval_metric_fn(d_loss, g_loss):
            # When using TPUs, this function is run on a different machine than the
            # rest of the model_fn and should not capture any Tensors defined there
            return {
                'discriminator_loss': tf.metrics.mean(d_loss),
                'generator_loss': tf.metrics.mean(g_loss)
            }

        return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                              loss=tf.reduce_mean(g_loss),
                                              eval_metrics=(_eval_metric_fn,
                                                            [d_loss, g_loss]))

    # Should never reach here
    raise ValueError('Invalid mode provided to model_fn')
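The LSGAN helper used in Example #10 is not shown. A sketch of the standard least-squares GAN objectives it presumably computes (per-example losses, before the reduce_mean above):

def LSGAN(d_real, d_fake):
    """Least-squares GAN losses (Mao et al., 2017): the discriminator pushes
    real logits toward 1 and fake logits toward 0; the generator pushes fake
    logits toward 1. Presented as a sketch of the undefined helper above."""
    d_loss = 0.5 * (tf.square(d_real - 1.0) + tf.square(d_fake))
    g_loss = 0.5 * tf.square(d_fake - 1.0)
    return d_loss, g_loss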
Example #11
def model_fn(features, labels, mode, params=None):
    '''
    Args:
        features: tensor with shape
            [BATCH_SIZE, go.N, go.N, features_lib.NEW_FEATURES_PLANES]
        labels: dict from string to tensor with shape
            'pi_tensor': [BATCH_SIZE, go.N * go.N + 1]
            'value_tensor': [BATCH_SIZE]
        mode: a tf.estimator.ModeKeys (batchnorm params update for TRAIN only)
        params: (Ignored; needed for compat with TPUEstimator)
    Returns: tf.estimator.EstimatorSpec with props
        mode: same as mode arg
        predictions: dict of tensors
            'policy': [BATCH_SIZE, go.N * go.N + 1]
            'value': [BATCH_SIZE]
        loss: a single value tensor
        train_op: train op
        eval_metric_ops
    return dict of tensors
        logits: [BATCH_SIZE, go.N * go.N + 1]
    '''

    policy_output, value_output, logits = model_inference_fn(
        features, mode == tf.estimator.ModeKeys.TRAIN)

    # train ops
    policy_cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=logits, labels=tf.stop_gradient(labels['pi_tensor'])))

    value_cost = FLAGS.value_cost_weight * tf.reduce_mean(
        tf.square(value_output - labels['value_tensor']))

    reg_vars = [v for v in tf.trainable_variables()
                if 'bias' not in v.name and 'beta' not in v.name]
    l2_cost = FLAGS.l2_strength * \
        tf.add_n([tf.nn.l2_loss(v) for v in reg_vars])

    combined_cost = policy_cost + value_cost + l2_cost

    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.piecewise_constant(
        global_step, FLAGS.lr_boundaries, FLAGS.lr_rates)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.sgd_momentum)
    if FLAGS.use_tpu:
        optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(combined_cost, global_step=global_step)

    # Computations to be executed on CPU, outside of the main TPU queues.
    def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cost,
                                  value_cost, l2_cost, combined_cost, step,
                                  est_mode=tf.estimator.ModeKeys.TRAIN):
        policy_entropy = -tf.reduce_mean(tf.reduce_sum(
            policy_output * tf.log(policy_output), axis=1))
        # pi_tensor is one_hot when generated from sgfs (for supervised learning)
        # and soft-max when using self-play records. argmax normalizes the two.
        policy_target_top_1 = tf.argmax(pi_tensor, axis=1)
        policy_output_top_1 = tf.argmax(policy_output, axis=1)

        policy_output_in_top1 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=1))
        policy_output_in_top3 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=3))

        policy_top_1_confidence = tf.reduce_max(policy_output, axis=1)
        policy_target_top_1_confidence = tf.boolean_mask(
            policy_output,
            tf.one_hot(policy_target_top_1, tf.shape(policy_output)[1]))

        # TODO(sethtroisi): For V10 add tf.variable_scope for tf.metrics.mean's
        with tf.variable_scope("metrics"):
            metric_ops = {
                'policy_cost': tf.metrics.mean(policy_cost),
                'value_cost': tf.metrics.mean(value_cost),
                'l2_cost': tf.metrics.mean(l2_cost),
                'policy_entropy': tf.metrics.mean(policy_entropy),
                'combined_cost': tf.metrics.mean(combined_cost),

                'policy_accuracy_top_1': tf.metrics.mean(policy_output_in_top1),
                'policy_accuracy_top_3': tf.metrics.mean(policy_output_in_top3),
                'policy_top_1_confidence': tf.metrics.mean(policy_top_1_confidence),
                'policy_target_top_1_confidence': tf.metrics.mean(
                    policy_target_top_1_confidence),
                'value_confidence': tf.metrics.mean(tf.abs(value_output)),
            }

        if est_mode == tf.estimator.ModeKeys.EVAL:
            return metric_ops

        # Create summary ops so that they show up in SUMMARIES collection
        # That way, they get logged automatically during training
        summary_writer = summary.create_file_writer(FLAGS.model_dir)
        with summary_writer.as_default(), \
                summary.always_record_summaries():
            for metric_name, metric_op in metric_ops.items():
                summary.scalar(metric_name, metric_op[1])

        # Reset metrics occasionally so that they are mean of recent batches.
        reset_op = tf.variables_initializer(tf.local_variables("metrics"))
        cond_reset_op = tf.cond(
            tf.equal(tf.mod(tf.reduce_min(step), FLAGS.summary_steps), tf.to_int64(1)),
            lambda: reset_op,
            lambda: tf.no_op())

        return summary.all_summary_ops() + [cond_reset_op]

    metric_args = [
        policy_output,
        value_output,
        labels['pi_tensor'],
        tf.reshape(policy_cost, [1]),
        tf.reshape(value_cost, [1]),
        tf.reshape(l2_cost, [1]),
        tf.reshape(combined_cost, [1]),
        tf.reshape(global_step, [1]),
    ]

    predictions = {
        'policy_output': policy_output,
        'value_output': value_output,
    }

    eval_metrics_only_fn = functools.partial(
        eval_metrics_host_call_fn, est_mode=tf.estimator.ModeKeys.EVAL)
    host_call_fn = functools.partial(
        eval_metrics_host_call_fn, est_mode=tf.estimator.ModeKeys.TRAIN)

    tpu_estimator_spec = tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=combined_cost,
        train_op=train_op,
        eval_metrics=(eval_metrics_only_fn, metric_args),
        host_call=(host_call_fn, metric_args)
    )
    if FLAGS.use_tpu:
        return tpu_estimator_spec
    else:
        return tpu_estimator_spec.as_estimator_spec()
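# A minimal wiring sketch for a model_fn like the one above (hedged: the flag
# names, `model_fn` / `train_input_fn`, and the TF 1.x contrib import path are
# assumptions standing in for this example's actual names).
from tensorflow.contrib.tpu.python.tpu import tpu_config, tpu_estimator

run_config = tpu_config.RunConfig(
    model_dir=FLAGS.model_dir,
    tpu_config=tpu_config.TPUConfig(iterations_per_loop=100))
estimator = tpu_estimator.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=FLAGS.use_tpu,  # with use_tpu=False this runs on CPU/GPU
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size)
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)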
Example #12
  def _model_fn(mode, features, labels):
    """A model_fn that builds the DNN classification spec

    Args:
      mode (tf.estimator.ModeKeys): One of ModeKeys.(TRAIN|EVAL|PREDICT) which
         is used to selectively add operations to the graph.
      features (Mapping[str:Tensor]): Input features for the model.
      labels (Tensor): Label Tensor.

    Returns:
      tpu_estimator.TPUEstimatorSpec which defines the model. Will have different
      populated members depending on `mode`. See:
        https://www.tensorflow.org/api_docs/python/tf/estimator/EstimatorSpec
      for details.
    """
    (gender, race, education, marital_status, relationship,
     workclass, occupation, native_country, age,
     education_num, capital_gain, capital_loss, hours_per_week) = INPUT_COLUMNS

    transformed_columns = [
        # Use indicator columns for low dimensional vocabularies
        tf.feature_column.indicator_column(workclass),
        tf.feature_column.indicator_column(education),
        tf.feature_column.indicator_column(marital_status),
        tf.feature_column.indicator_column(gender),
        tf.feature_column.indicator_column(relationship),
        tf.feature_column.indicator_column(race),

        # Use embedding columns for high dimensional vocabularies
        tf.feature_column.embedding_column(
            native_country, dimension=embedding_size),
        tf.feature_column.embedding_column(
            occupation, dimension=embedding_size),
        age,
        education_num,
        capital_gain,
        capital_loss,
        hours_per_week,
    ]

    inputs = tf.feature_column.input_layer(features, transformed_columns)
    label_values = tf.constant(LABELS)

    # Build the DNN
    curr_layer = inputs

    for layer_size in hidden_units:
      curr_layer = tf.layers.dense(
          curr_layer,
          layer_size,
          activation=tf.nn.relu,
          # This initializer prevents variance from exploding or vanishing when
          # compounded through different sized layers.
          kernel_initializer=tf.contrib.layers.variance_scaling_initializer(),
      )

    # Add the output layer
    logits = tf.layers.dense(
        curr_layer,
        len(LABELS),
        # Do not use ReLU on last layer
        activation=None,
        kernel_initializer=tf.contrib.layers.variance_scaling_initializer()
    )

    if mode in (Modes.PREDICT, Modes.EVAL):
      probabilities = tf.nn.softmax(logits)
      predicted_indices = tf.argmax(probabilities, 1)

    if mode in (Modes.TRAIN, Modes.EVAL):
      # Convert the string label column to indices
      # Build a lookup table inside the graph
      table = tf.contrib.lookup.index_table_from_tensor(label_values)

      # Use the lookup table to convert string labels to ints
      label_indices = table.lookup(labels)
      # Make labels a vector
      label_indices_vector = tf.squeeze(label_indices, axis=[1])

      # global_step is necessary in eval to correctly load the step
      # of the checkpoint we are evaluating
      global_step = tf.contrib.framework.get_or_create_global_step()
      loss = tf.reduce_mean(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              logits=logits, labels=label_indices_vector))
      tf.summary.scalar('loss', loss)  # TODO: does this need to be handled with host_call? (See the host_call sketch after this example.)

    if mode == Modes.PREDICT:
      # Convert predicted_indices back into strings
      predictions = {
          'classes': tf.gather(label_values, predicted_indices),
          'scores': tf.reduce_max(probabilities, axis=1)
      }
      export_outputs = {
          'prediction': tf.estimator.export.PredictOutput(predictions)
      }
      return tpu_estimator.TPUEstimatorSpec( # TODO: do these need to be changed?
          mode, predictions=predictions, export_outputs=export_outputs)

    if mode == Modes.TRAIN:
      # Build training operation.
      optimizer = tf.train.FtrlOptimizer(
          learning_rate=learning_rate,
          l1_regularization_strength=3.0,
          l2_regularization_strength=10.0
      )
      optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
      train_op = optimizer.minimize(loss, global_step=global_step)
      
      return tpu_estimator.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)

    if mode == Modes.EVAL:
      # Return accuracy and area under ROC curve metrics
      # See https://en.wikipedia.org/wiki/Receiver_operating_characteristic
      # See https://www.kaggle.com/wiki/AreaUnderCurve
      labels_one_hot = tf.one_hot(
          label_indices_vector,
          depth=label_values.shape[0],
          on_value=True,
          off_value=False,
          dtype=tf.bool
      )
      eval_metric_ops = {
          'accuracy': tf.metrics.accuracy(label_indices, predicted_indices),
          'auroc': tf.metrics.auc(labels_one_hot, probabilities)
      }
      def metric_fn(label_indices, label_indices_vector, predicted_indices,
                    probabilities):
        labels_one_hot = tf.one_hot(
            label_indices_vector,
            depth=label_values.shape[0],
            on_value=True,
            off_value=False,
            dtype=tf.bool)
        return {
            'accuracy': tf.metrics.accuracy(label_indices, predicted_indices),
            'auroc': tf.metrics.auc(labels_one_hot, probabilities)
        }

      return tpu_estimator.TPUEstimatorSpec(
          mode, loss=loss,
          # eval_metric_ops=eval_metric_ops  # TODO: is this the right way to handle multiple eval metrics? (Yes, I know it can be more efficient with naming, but just an example)
          eval_metrics=(metric_fn, [label_indices, label_indices_vector,
                                    predicted_indices, probabilities]))
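# Hedged sketch for the host_call TODO above: summaries written with
# tf.summary.* inside a TPU training loop are not exported, so the usual
# pattern routes them through host_call. Assumptions: `summary` is
# tf.contrib.summary, `output_dir` stands in for this trainer's model
# directory, and this code would live in the Modes.TRAIN branch where
# `global_step` and `loss` are defined.
def _loss_host_call_fn(gs, loss):
    gs = gs[0]
    with summary.create_file_writer(output_dir).as_default():
        with summary.always_record_summaries():
            summary.scalar('loss', loss[0], step=gs)
            return summary.all_summary_ops()

# host_call arguments need a leading batch dimension, hence the [1] reshapes;
# pass host_call=... to the TRAIN-mode TPUEstimatorSpec above.
host_call = (_loss_host_call_fn,
             [tf.reshape(global_step, [1]), tf.reshape(loss, [1])])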
Example #13
def char_rnn_model(features, labels, mode, params):
    """Character level recurrent neural network model to predict classes."""
    batch_size = params['batch_size']

    byte_vectors = tf.one_hot(features[CHARS_FEATURE], 256, 1., 0.)
    byte_list = tf.unstack(byte_vectors, axis=1)

    cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
    _, encoding = tf.nn.static_rnn(cell, byte_list, dtype=tf.float32)

    logits = tf.layers.dense(encoding, MAX_LABEL, activation=None)

    predicted_classes = tf.argmax(logits, 1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                              predictions={
                                                  'class': predicted_classes,
                                                  'prob': tf.nn.softmax(logits)
                                              })

    # Cross-entropy loss plus weight decay on non-batch-norm variables.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) + (
        _WEIGHT_DECAY * tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()
                                  if 'batch_normalization' not in v.name]))

    # Get the current training epoch from the global step.
    batches_per_epoch = _NUM_TRAIN_IMAGES / FLAGS.train_batch_size
    global_step = tf.train.get_global_step()
    current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
    learning_rate = learning_rate_schedule(current_epoch)

    if mode == tf.estimator.ModeKeys.TRAIN:
        #optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=_MOMENTUM,
                                               use_nesterov=True)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(loss,
                                      global_step=tf.train.get_global_step())
        return tpu_estimator.TPUEstimatorSpec(mode,
                                              loss=loss,
                                              train_op=train_op)

    # Trick to report learning rate as a metric: repeat it batch_size times.
    lr_repeat = tf.reshape(
        tf.tile(tf.expand_dims(learning_rate, 0), [
            batch_size,
        ]), [batch_size, 1])

    ce_repeat = tf.reshape(
        tf.tile(tf.expand_dims(current_epoch, 0), [
            batch_size,
        ]), [batch_size, 1])
    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, predicted_classes, lr_repeat, ce_repeat):
            """Evaluation metric fn. Performed on CPU, do not reference TPU ops."""
            return {
                'accuracy':
                tf.metrics.accuracy(labels=labels,
                                    predictions=predicted_classes),
                'learning_rate':
                tf.metrics.mean(lr_repeat),
                'current_epoch':
                tf.metrics.mean(ce_repeat)
            }

        eval_metrics = (metric_fn,
                        [labels, predicted_classes, lr_repeat, ce_repeat])

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          eval_metrics=eval_metrics)
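# The lr_repeat / ce_repeat "trick" above, pulled out as a standalone helper
# (sketch only; the helper name is hypothetical). eval_metrics and host_call
# arguments must carry a leading batch dimension, so per-step scalars get
# tiled per example and averaged back with tf.metrics.mean in the metric_fn.
def _scalar_to_batch(scalar, batch_size):
    tiled = tf.tile(tf.expand_dims(scalar, 0), [batch_size])
    return tf.reshape(tiled, [batch_size, 1])

# e.g. lr_repeat = _scalar_to_batch(learning_rate, batch_size)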
Example #14
def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
    if isinstance(features, dict):
        images = features['image']
        hms = features['hm']
        bboxs = features['bbox']
        ccount = features['ccount']
    else:
        images = features
        hms = None
        bboxs = None

    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        images = tf.transpose(images, [0, 3, 1, 2])
        if hms is not None:
            hms = tf.transpose(hms, [0, 3, 1, 2])
            bboxs = tf.transpose(bboxs, [0, 3, 1, 2])

    if FLAGS.transpose_input:
        images = tf.transpose(images, [3, 0, 1, 2])  # HWCN to NHWC
        if hms is not None:
            hms = tf.transpose(hms, [3, 0, 1, 2])
            bboxs = tf.transpose(bboxs, [3, 0, 1, 2])

    if FLAGS.use_tpu:
        import bfloat16
        scope_fn = lambda: bfloat16.bfloat16_scope()
    else:
        scope_fn = lambda: tf.variable_scope("")

    with scope_fn():
        resnet_size = int(FLAGS.resnet_depth.split("_")[-1])
        if FLAGS.resnet_depth.startswith("v1_"):
            print("\n\n\n\n\nUSING RESNET V1 {}\n\n\n\n\n".format(
                FLAGS.resnet_depth))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention=None,
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("SE-v1_"):
            print(
                "\n\n\n\n\nUSING RESNET V1 (Squeeze-and-excite) {}\n\n\n\n\n".
                format(resnet_size))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention="se",
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("GALA-v1_"):
            print("\n\n\n\n\nUSING RESNET V1 (GALA) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention="gala",
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("v2_"):
            print("\n\n\n\n\nUSING RESNET V2 {}\n\n\n\n\n".format(resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention=False,
                                                extra_convs=0,
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        elif FLAGS.resnet_depth.startswith("SE-v2_"):
            print(
                "\n\n\n\n\nUSING RESNET V2 (Squeeze-and-excite) {}\n\n\n\n\n".
                format(resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention="se",
                                                extra_convs=0,
                                                apply_to="output",
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        elif FLAGS.resnet_depth.startswith("GALA-v2_"):
            print("\n\n\n\n\nUSING RESNET V2 (GALA) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention="gala",
                                                extra_convs=1,
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        else:
            raise NotImplementedError(FLAGS.resnet_depth)

        logits, attention = network(
            inputs=images, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
        logits = tf.cast(logits, tf.float32)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })
    batch_size = params['batch_size']

    # Calculate softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, LABEL_CLASSES)
    loss = tf.losses.softmax_cross_entropy(logits=logits,
                                           onehot_labels=one_hot_labels)

    # Switch hms/bboxs
    if FLAGS.annotation == 'hms':
        pass
    elif FLAGS.annotation == 'bboxs':
        hms = bboxs
    elif FLAGS.annotation == 'none':
        hms = None
    else:
        raise NotImplementedError(FLAGS.annotation)

    # Add attention losses
    if hms is not None:
        map_loss_list = []
        blur_click_maps = 49  # 0 = no, > 0 blur kernel
        blur_click_maps_sigma = 28  # 14

        # Blur the heatmaps
        hms = blur(hms,
                   kernel=blur_click_maps,
                   sigma=blur_click_maps_sigma,
                   dtype=images.dtype)

        mask = tf.cast(tf.greater(ccount, 0), tf.float32)
        mask = tf.reshape(mask, [int(hms.get_shape()[0]), 1, 1, 1])
        for layer in attention:
            layer_shape = [int(x) for x in layer.get_shape()[1:3]]
            layer = tf.cast(layer, tf.float32)
            hms = tf.cast(hms, tf.float32)
            resized_maps = tf.image.resize_bilinear(hms,
                                                    layer_shape,
                                                    align_corners=True)
            if layer.get_shape().as_list()[-1] > 1:
                layer = tf.reduce_mean(tf.pow(layer, 2),
                                       axis=-1,
                                       keep_dims=True)
            resized_maps = l2_channel_norm(resized_maps)
            layer = l2_channel_norm(layer)
            dist = resized_maps - layer
            d = tf.nn.l2_loss(dist * mask)
            map_loss_list += [d]
        if len(map_loss_list):
            denominator = len(attention)
            map_loss = (tf.add_n(map_loss_list) / float(denominator)) * 1e-5
            loss += map_loss
        else:
            # Attention models must have produced at least one attention map.
            assert not FLAGS.resnet_depth.startswith(
                ("GALA", "SE")), "Failed to apply attention."
    loss += (WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name and 'ATTENTION' not in v.name
        and 'block' not in v.name and 'training' not in v.name
    ]))
    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate from global_step.
        global_step = tf.train.get_global_step()
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
        learning_rate = learning_rate_schedule(current_epoch)
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=MOMENTUM,
                                               use_nesterov=True)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            if FLAGS.clip_gradients == 0:
                print("\nnot clipping gradients\n")
                train_op = optimizer.minimize(loss, global_step)
            else:
                print("\nclipping gradients\n")
                gradients, variables = zip(*optimizer.compute_gradients(loss))
                gradients, _ = tf.clip_by_global_norm(gradients,
                                                      FLAGS.clip_gradients)
                train_op = optimizer.apply_gradients(zip(gradients, variables),
                                                     global_step=global_step)

        if not FLAGS.skip_host_call:

            def host_call_fn(gs, loss, lr, ce):  # , hm=None, image=None):
                """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
                gs = gs[0]
                with summary.create_file_writer(FLAGS.model_dir).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('loss', loss[0], step=gs)
                        summary.scalar('learning_rate', lr[0], step=gs)
                        summary.scalar('current_epoch', ce[0], step=gs)
                        # summary.image('image', hm, step=gs)
                        # summary.image('heatmap', image, step=gs)
                        return summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])
            # im_t = tf.cast(images, tf.float32)
            # hm_t = tf.cast(hms, tf.float32)
            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])
            # host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t, im_t, hm_t])

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])
    # logging_hook = tf.train.LoggingTensorHook(
    #   {"logging_hook_loss": loss}, every_n_iter=1)

    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=host_call,
        eval_metrics=eval_metrics,
        # training_hooks=[logging_hook]
    )
Example #15
def model_top(labels, preds, cost, lr, mode, hparams):
    tf.summary.scalar(
        "acc",
        tf.reduce_mean(
            tf.to_float(
                tf.equal(labels, tf.argmax(preds,
                                           axis=-1,
                                           output_type=tf.int32)))))
    tf.summary.scalar("loss", cost)

    gs = tf.train.get_global_step()

    if hparams.weight_decay_and_noise:
        cost = weight_decay_and_noise(cost, hparams, lr)
        cost = tf.identity(cost, name="total_loss")
    optimizer = get_optimizer(lr, hparams)

    train_op = tf.contrib.layers.optimize_loss(
        name="training",
        loss=cost,
        global_step=gs,
        learning_rate=lr,
        clip_gradients=hparams.clip_grad_norm or None,
        gradient_noise_scale=hparams.grad_noise_scale or None,
        optimizer=optimizer,
        colocate_gradients_with_ops=True)

    if hparams.use_tpu:

        def metric_fn(l, p):
            return {
                "acc":
                tf.metrics.accuracy(labels=l,
                                    predictions=tf.argmax(
                                        p, -1, output_type=tf.int32)),
            }

        host_call = None
        if hparams.tpu_summarize:
            host_call = tpu.create_host_call(hparams.output_dir)
        tpu.remove_summaries()

        if mode == tf.estimator.ModeKeys.EVAL:
            return tpu_estimator.TPUEstimatorSpec(
                mode=mode,
                predictions=preds,
                loss=cost,
                eval_metrics=(metric_fn, [labels, preds]),
                host_call=host_call)

        return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                              loss=cost,
                                              train_op=train_op,
                                              host_call=host_call)

    return tf.estimator.EstimatorSpec(
        mode,
        eval_metric_ops={
            "acc":
            tf.metrics.accuracy(labels=labels,
                                predictions=tf.argmax(preds,
                                                      axis=-1,
                                                      output_type=tf.int32)),
        },
        loss=cost,
        train_op=train_op)
Example #16
def resnet_model_fn(features, labels, mode, params):
    """Our model_fn for ResNet to be used with our Estimator."""
    network = resnet_model.resnet_v2(resnet_size=FLAGS.resnet_size,
                                     num_classes=_LABEL_CLASSES)

    logits = network(inputs=features,
                     is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits,
                                                    onehot_labels=labels)

    # Add weight decay to the loss. We perform weight decay on all trainable
    # variables, which includes batch norm beta and gamma variables.
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    global_step = tf.train.get_global_step()
    current_epoch = (tf.cast(global_step, tf.float32) /
                     params['batches_per_epoch'])
    learning_rate = learning_rate_schedule(current_epoch)

    # TODO(chrisying): this is a hack to get the LR and epoch for Tensorboard.
    # Reimplement this when TPU training summaries are supported.
    lr_repeat = tf.reshape(
        tf.tile(tf.expand_dims(learning_rate, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])
    ce_repeat = tf.reshape(
        tf.tile(tf.expand_dims(current_epoch, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=_MOMENTUM)
        optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits, lr_repeat, ce_repeat):
            """Evaluation metric fn. Performed on CPU, do not reference TPU ops."""
            predictions = tf.argmax(logits, axis=1)
            accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1),
                                           predictions)
            lr = tf.metrics.mean(lr_repeat)
            ce = tf.metrics.mean(ce_repeat)
            return {
                'accuracy': accuracy,
                'learning_rate': lr,
                'current_epoch': ce
            }

        eval_metrics = (metric_fn, [labels, logits, lr_repeat, ce_repeat])

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metrics=eval_metrics)
Example #17
def model_fn(features, labels, mode, params):
  """Constructs DCGAN from individual generator and discriminator networks."""
  del labels    # Unconditional GAN does not use labels

  if mode == tf.estimator.ModeKeys.PREDICT:
    ###########
    # PREDICT #
    ###########
    # Generate fixed random noise on device instead of feeding via input_fn
    np.random.seed(0)
    random_noise = tf.constant(
        np.random.randn(_NUM_VIZ_IMAGES, FLAGS.noise_dim), dtype=tf.float32)
    predictions = {
        'generated_images': model.generator(random_noise,
                                            is_training=False)
    }

    return tpu_estimator.TPUEstimatorSpec(mode=mode, predictions=predictions)

  random_noise = tf.random_normal([params['batch_size'], FLAGS.noise_dim])
  true_samples = features

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  generated_samples = model.generator(random_noise,
                                      is_training=is_training)

  # Get logits from discriminator
  d_on_data_logits = tf.squeeze(model.discriminator(true_samples))
  d_on_g_logits = tf.squeeze(model.discriminator(generated_samples))

  # Calculate discriminator loss
  d_loss_on_data = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.ones_like(d_on_data_logits),
      logits=d_on_data_logits)
  d_loss_on_gen = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.zeros_like(d_on_g_logits),
      logits=d_on_g_logits)

  d_loss = d_loss_on_data + d_loss_on_gen

  # Calculate generator loss
  g_loss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.ones_like(d_on_g_logits),
      logits=d_on_g_logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    #########
    # TRAIN #
    #########
    d_loss = tf.reduce_mean(d_loss)
    g_loss = tf.reduce_mean(g_loss)
    d_optimizer = tf.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate, beta1=0.5)
    g_optimizer = tf.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate, beta1=0.5)

    if FLAGS.use_tpu:
      d_optimizer = tpu_optimizer.CrossShardOptimizer(d_optimizer)
      g_optimizer = tpu_optimizer.CrossShardOptimizer(g_optimizer)

    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
      d_step = d_optimizer.minimize(
          d_loss,
          var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope='Discriminator'))
      g_step = g_optimizer.minimize(
          g_loss,
          var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope='Generator'))

      increment_step = tf.assign_add(tf.train.get_or_create_global_step(), 1)
      joint_op = tf.group([d_step, g_step, increment_step])

      return tpu_estimator.TPUEstimatorSpec(
          mode=mode,
          loss=g_loss,
          train_op=joint_op)

  elif mode == tf.estimator.ModeKeys.EVAL:
    ########
    # EVAL #
    ########
    def _eval_metric_fn(d_loss, g_loss):
      # When using TPUs, this function is run on a different machine than the
      # rest of the model_fn and should not capture any Tensors defined there
      return {
          'discriminator_loss': tf.metrics.mean(d_loss),
          'generator_loss': tf.metrics.mean(g_loss)}

    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        loss=tf.reduce_mean(g_loss),
        eval_metrics=(_eval_metric_fn, [d_loss, g_loss]))

  # Should never reach here
  raise ValueError('Invalid mode provided to model_fn')
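# Hedged usage sketch for the PREDICT branch above: `est` and `noise_input_fn`
# are assumptions standing in for however the surrounding trainer builds its
# TPUEstimator and prediction input_fn; each yielded dict holds one generated
# image.
import itertools

generated_iter = est.predict(input_fn=noise_input_fn)
images = [p['generated_images']
          for p in itertools.islice(generated_iter, _NUM_VIZ_IMAGES)]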
Example #18
def eval_model_fn_no_eval_metrics(features, labels, mode, params):
  del features, labels, params
  return tpu_estimator.TPUEstimatorSpec(
      mode=mode, loss=constant_op.constant(_EXPECTED_LOSS))
Example #19
def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
  if isinstance(features, dict):
    features = features['feature']

  if FLAGS.data_format == 'channels_first':
    assert not FLAGS.transpose_input    # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])

  if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance.
  features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
  features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=FLAGS.num_label_classes,
        data_format=FLAGS.data_format)
    return network(
        inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if FLAGS.precision == 'bfloat16':
    with bfloat16.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  # If necessary, in the model_fn, use params['batch_size'] instead the batch
  # size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']   # pylint: disable=unused-variable

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=one_hot_labels)

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = cross_entropy + FLAGS.weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()
       if 'batch_normalization' not in v.name])

  host_call = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Compute the current epoch and associated learning rate from global_step.
    global_step = tf.train.get_global_step()
    batches_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) /
                     batches_per_epoch)
    learning_rate = learning_rate_schedule(current_epoch)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=FLAGS.momentum, use_nesterov=True)
    if FLAGS.use_tpu:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    if not FLAGS.skip_host_call:
      def host_call_fn(gs, loss, lr, ce):
        """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
        gs = gs[0]
        with summary.create_file_writer(FLAGS.model_dir).as_default():
          with summary.always_record_summaries():
            summary.scalar('loss', loss[0], step=gs)
            summary.scalar('learning_rate', lr[0], step=gs)
            summary.scalar('current_epoch', ce[0], step=gs)

            return summary.all_summary_ops()

      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      gs_t = tf.reshape(global_step, [1])
      loss_t = tf.reshape(loss, [1])
      lr_t = tf.reshape(learning_rate, [1])
      ce_t = tf.reshape(current_epoch, [1])

      host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    def metric_fn(labels, logits):
      """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      return {
          'top_1_accuracy': top_1_accuracy,
          'top_5_accuracy': top_5_accuracy,
      }

    eval_metrics = (metric_fn, [labels, logits])

  return tpu_estimator.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics)
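# The weight-decay term above, pulled out as a helper for clarity (a sketch;
# the helper name is hypothetical): L2 regularization over every trainable
# variable except the batch-normalization parameters.
def _filtered_weight_decay(weight_decay):
    return weight_decay * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()
         if 'batch_normalization' not in v.name])

# e.g. loss = cross_entropy + _filtered_weight_decay(FLAGS.weight_decay)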
Example #20
def word_rnn_model(features, labels, mode, params):
  """RNN model to predict from sequence of words to a class."""
  batch_size = params['batch_size']

  # Convert indexes of words into embeddings.
  # This creates embeddings matrix of [n_words, EMBEDDING_SIZE] and then
  # maps word indexes of the sequence into [batch_size, sequence_length,
  # EMBEDDING_SIZE].
  word_vectors = tf.contrib.layers.embed_sequence(
      features, vocab_size=n_words, embed_dim=EMBEDDING_SIZE)

  # Split into list of embedding per word, while removing doc length dim.
  # word_list results to be a list of tensors [batch_size, EMBEDDING_SIZE].
  word_list = tf.unstack(word_vectors, axis=1)

  # Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE.
  #cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE)
  #cell = tf.nn.rnn_cell.LSTMCell(EMBEDDING_SIZE)
  cell = tf.nn.rnn_cell.BasicLSTMCell(EMBEDDING_SIZE, state_is_tuple=False)

  # Create an unrolled Recurrent Neural Networks to length of
  # MAX_DOCUMENT_LENGTH and passes word_list as inputs for each unit.
  _, encoding = tf.nn.static_rnn(cell, word_list, dtype=tf.float32)

  # Given encoding of RNN, take encoding of last step (e.g hidden size of the
  # neural network of last step) and pass it as features for softmax
  # classification over output classes.
  logits = tf.layers.dense(encoding, MAX_LABEL, activation=None)
  
  # Cross-entropy loss plus weight decay on non-batch-norm variables.
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) + (
      _WEIGHT_DECAY * tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()
                                if 'batch_normalization' not in v.name]))
  
  # Get the current training epoch from the global step.
  batches_per_epoch = _NUM_TRAIN_IMAGES / FLAGS.train_batch_size
  global_step = tf.train.get_global_step()
  current_epoch = (tf.cast(global_step, tf.float32)/batches_per_epoch)
  learning_rate = learning_rate_schedule(current_epoch)
  
  predicted_classes = tf.argmax(logits, 1)
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        predictions={
            'class': predicted_classes,
            'prob': tf.nn.softmax(logits)
        })
  
  if mode == tf.estimator.ModeKeys.TRAIN:
    #optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM, use_nesterov=True)
    if FLAGS.use_tpu:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tpu_estimator.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)

  # Trick to report learning rate as a metric: repeat it batch_size times.
  lr_repeat = tf.reshape(
      tf.tile(tf.expand_dims(learning_rate, 0), [
          batch_size,
      ]), [batch_size, 1])
      
  ce_repeat = tf.reshape(
      tf.tile(tf.expand_dims(current_epoch, 0), [
          batch_size,
      ]), [batch_size, 1])      
  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    def metric_fn(labels, logits, lr_repeat, ce_repeat):
      """Evaluation metric fn. Performed on CPU, do not reference TPU ops."""      
      
      predicted_classes = tf.argmax(logits, 1)
      return {
          'accuracy': tf.metrics.accuracy(
                                  labels=labels, predictions=predicted_classes),
          'learning_rate': tf.metrics.mean(lr_repeat),
          'current_epoch': tf.metrics.mean(ce_repeat)
          }

    eval_metrics = (metric_fn, [labels, logits, lr_repeat, ce_repeat])

  return tpu_estimator.TPUEstimatorSpec(
      mode=mode, loss=loss, eval_metrics=eval_metrics)
Example #21
def inception_model_fn(features, labels, mode, params):
    """Inception v3 model using Estimator API."""
    num_classes = FLAGS.num_classes
    training_active = (mode == tf.estimator.ModeKeys.TRAIN)
    eval_active = (mode == tf.estimator.ModeKeys.EVAL)

    features = tensor_transform_fn(features, params['input_perm'])

    with slim.arg_scope(
            inception.inception_v3_arg_scope(
                use_fused_batchnorm=FLAGS.use_fused_batchnorm)):
        logits, end_points = inception.inception_v3(
            features,
            num_classes,
            is_training=training_active,
            depth_multiplier=FLAGS.depth_multiplier)

    predictions = {
        'classes': tf.argmax(input=logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    if 'AuxLogits' in end_points:
        aux_loss = tf.losses.softmax_cross_entropy(
            onehot_labels=labels,
            logits=end_points['AuxLogits'],
            weights=0.4,
            label_smoothing=0.1,
            scope='aux_loss')
        tf.losses.add_loss(aux_loss)

    prediction_loss = tf.losses.softmax_cross_entropy(onehot_labels=labels,
                                                      logits=logits,
                                                      weights=1.0,
                                                      label_smoothing=0.1)
    tf.losses.add_loss(prediction_loss)
    loss = tf.losses.get_total_loss(add_regularization_losses=True)

    initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
    final_learning_rate = 0.01 * initial_learning_rate

    train_op = None
    if training_active:
        # Multiply the learning rate by 0.1 every 30 epochs.
        training_set_len = imagenet.get_split_size('train')
        batches_per_epoch = training_set_len // FLAGS.train_batch_size
        learning_rate = tf.train.exponential_decay(
            learning_rate=initial_learning_rate,
            global_step=tf.train.get_global_step(),
            decay_steps=_LEARNING_RATE_DECAY_EPOCHS * batches_per_epoch,
            decay_rate=_LEARNING_RATE_DECAY,
            staircase=True)

        # Set a minimum boundary for the learning rate.
        learning_rate = tf.maximum(learning_rate,
                                   final_learning_rate,
                                   name='learning_rate')

        # tf.summary.scalar('learning_rate', learning_rate)

        if FLAGS.optimizer == 'sgd':
            tf.logging.info('Using SGD optimizer')
            # Use the decayed learning_rate computed above, not the raw flag.
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        elif FLAGS.optimizer == 'momentum':
            tf.logging.info('Using Momentum optimizer')
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate, momentum=0.9)
        else:
            tf.logging.fatal('Unknown optimizer: %s', FLAGS.optimizer)

        if FLAGS.use_tpu:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(
                loss, global_step=tf.train.get_or_create_global_step())

    eval_metrics = None
    if eval_active:

        def metric_fn(labels, logits):
            predictions = tf.argmax(input=logits, axis=1)
            accuracy = tf.metrics.accuracy(tf.argmax(input=labels, axis=1),
                                           predictions)
            return {'accuracy': accuracy}

        eval_metrics = (metric_fn, [labels, logits])

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metrics=eval_metrics)
Example #22
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator."""
    del labels
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    if FLAGS.use_tpu:
        ctx = params['context']
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info('device_list = %s', device_list)
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [''] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    mesh_devices,
                                                    ctx.device_assignment)
    else:
        var_placer = None
        mesh_devices = [''] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)
    mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

    with mtf.utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        if FLAGS.optimizer == 'Adafactor':
            optimizer = mtf.optimize.AdafactorOptimizer()
        else:
            assert FLAGS.optimizer == 'SGD'
            optimizer = mtf.optimize.SgdOptimizer(learning_rate=FLAGS.lr)
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                mean_logits = tf.metrics.mean(tf_logits)
                return {'mean_logits': mean_logits}

            eval_metrics = (metric_fn, [tf_logits])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
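# Hedged sketch of what the mesh_shape / layout flags above describe for this
# toy model (the dimension names and sizes are illustrative assumptions, not
# from the source): a 2x2 mesh of four cores, with the batch dimension split
# across one mesh axis and a hidden dimension split across the other.
mesh_shape = mtf.Shape([mtf.Dimension('b1', 2), mtf.Dimension('b2', 2)])
layout_rules = mtf.convert_to_layout_rules([('batch', 'b1'), ('hidden', 'b2')])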
Example #23
def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
    """Model defination for the RetinaNet model based on ResNet-50.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.
    model: the RetinaNet model outputs class logits and box regression outputs.
    variable_filter_fn: the filter function that takes trainable_variables and
      returns the variable list after applying the filter rule.

  Returns:
    tpu_spec: the TPUEstimatorSpec to run training, evaluation, or prediction.
  """
    cls_outputs, box_outputs = model(features,
                                     min_level=params['min_level'],
                                     max_level=params['max_level'],
                                     num_classes=params['num_classes'],
                                     num_anchors=len(params['aspect_ratios'] *
                                                     params['num_scales']),
                                     is_training_bn=params['is_training_bn'])
    levels = cls_outputs.keys()

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'image': features,
        }
        for level in levels:
            predictions['cls_outputs_%d' % level] = cls_outputs[level]
            predictions['box_outputs_%d' % level] = box_outputs[level]
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Load pretrained model from checkpoint.
    if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            tf.train.init_from_checkpoint(params['resnet_checkpoint'], {
                '/': 'resnet50/',
            })
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    # Set up training loss and learning rate.
    global_step = tf.train.get_global_step()
    learning_rate = _learning_rate_schedule(params['learning_rate'],
                                            params['lr_warmup_init'],
                                            params['lr_warmup_step'],
                                            params['lr_drop_step'],
                                            global_step)
    # cls_loss and box_loss are for logging. only total_loss is optimized.
    total_loss, cls_loss, box_loss = _detection_loss(cls_outputs, box_outputs,
                                                     labels, params)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               momentum=params['momentum'])
        if params['use_tpu']:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        var_list = variable_filter_fn(
            tf.trainable_variables()) if variable_filter_fn else None
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(total_loss,
                                          global_step,
                                          var_list=var_list)
    else:
        train_op = None

    # Evaluation only works on GPU/CPU host and batch_size=1
    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(**kwargs):
            """Evaluation metric fn. Performed on CPU, do not reference TPU ops."""
            eval_anchors = anchors.Anchors(params['min_level'],
                                           params['max_level'],
                                           params['num_scales'],
                                           params['aspect_ratios'],
                                           params['anchor_scale'],
                                           params['image_size'])
            anchor_labeler = anchors.AnchorLabeler(eval_anchors,
                                                   params['num_classes'])
            cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
            box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
            # add metrics to output
            cls_outputs = {}
            box_outputs = {}
            for level in range(params['min_level'], params['max_level'] + 1):
                cls_outputs[level] = kwargs['cls_outputs_%d' % level]
                box_outputs[level] = kwargs['box_outputs_%d' % level]
            detections = anchor_labeler.generate_detections(
                cls_outputs, box_outputs, kwargs['source_ids'])
            eval_metric = coco_metric.EvaluationMetric(params['val_json_file'])
            coco_metrics = eval_metric.estimator_metric_fn(
                detections, kwargs['image_scales'])
            # Add metrics to output.
            output_metrics = {
                'cls_loss': cls_loss,
                'box_loss': box_loss,
            }
            output_metrics.update(coco_metrics)
            return output_metrics

        batch_size = params['batch_size']
        cls_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(cls_loss, 0), [
                batch_size,
            ]), [batch_size, 1])
        box_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(box_loss, 0), [
                batch_size,
            ]), [batch_size, 1])
        metric_fn_inputs = {
            'cls_loss_repeat': cls_loss_repeat,
            'box_loss_repeat': box_loss_repeat,
            'source_ids': labels['source_ids'],
            'image_scales': labels['image_scales'],
        }
        for level in range(params['min_level'], params['max_level'] + 1):
            metric_fn_inputs['cls_outputs_%d' % level] = cls_outputs[level]
            metric_fn_inputs['box_outputs_%d' % level] = box_outputs[level]
        eval_metrics = (metric_fn, metric_fn_inputs)

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=total_loss,
                                          train_op=train_op,
                                          eval_metrics=eval_metrics,
                                          scaffold_fn=scaffold_fn)
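
A minimal sketch (not part of the example above) of how a model_fn with this signature is typically handed to TPUEstimator in TF 1.x. The TPU name, model_dir, batch sizes, and the `retinanet_architecture` callable are placeholders, and `default_hparams()` is assumed to return the hparams mentioned in the docstring.

import functools
import tensorflow as tf
from tensorflow.contrib import tpu as contrib_tpu
from tensorflow.contrib.cluster_resolver import TPUClusterResolver

# Bind the extra `model` argument so the Estimator sees the usual
# (features, labels, mode, params) signature.
model_fn = functools.partial(_model_fn, model=retinanet_architecture)

resolver = TPUClusterResolver(tpu='my-tpu-name')          # placeholder TPU name
run_config = contrib_tpu.RunConfig(
    cluster=resolver,
    model_dir='/tmp/retinanet',                           # placeholder
    tpu_config=contrib_tpu.TPUConfig(iterations_per_loop=100, num_shards=8))

estimator = contrib_tpu.TPUEstimator(
    model_fn=model_fn,
    use_tpu=True,
    train_batch_size=64,
    eval_batch_size=8,
    params=default_hparams().values(),   # the hparams dict the model_fn reads
    config=run_config)
# estimator.train(input_fn=train_input_fn, max_steps=25000)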
Example #24
def model_fn(features, labels, mode, params):
  """RNN language-model model_fn: masked sequence cross-entropy loss on TPU."""
  output_size = params['output_size']  # unused below
  input_size = params['input_size']
  batch_size = params["batch_size"]
  embedding_size = input_size[2]
  vocab_size = input_size[1]
  max_length = input_size[0]
  inputs = features['inputs']
  lengths = features['lengths']
  # The names below are used later but were undefined in the original snippet;
  # reading them from params is an assumption (key names are placeholders).
  rnncell = params.get('rnncell', 'lstm')
  num_layers = params.get('num_layers', 2)
  opt = params.get('optimizer', 'sgd')

  embeddings = tf.get_variable(
      "embeddings", [vocab_size, embedding_size], dtype=tf.float32)

  input_embeddings = tf.nn.embedding_lookup(embeddings, inputs)

  def rnn_with_dropout_cell():
    if rnncell == 'lstm':
      cell = tf.contrib.rnn.BasicLSTMCell(
        embedding_size,
        forget_bias=0.0,
        state_is_tuple=True)
    elif rnncell == 'rnn':
      cell = tf.contrib.rnn.BasicRNNCell(embedding_size)
    elif rnncell == 'gru':
      cell = tf.contrib.rnn.GRUCell(embedding_size)
    return cell

  cell_network = tf.contrib.rnn.MultiRNNCell(
      [rnn_with_dropout_cell() for _ in range(num_layers)],
      state_is_tuple=True)
  network_zero_state = cell_network.zero_state(batch_size, dtype=tf.float32)

  outputs, _ = tf.nn.dynamic_rnn(
        cell_network, input_embeddings, initial_state=network_zero_state, swap_memory=True)


  outputs_flat = tf.reshape(outputs, [-1, embedding_size])
  logits_flat = tf.contrib.layers.linear(outputs_flat, vocab_size)
  labels_flat = tf.reshape(labels, [-1])
  mask = tf.sequence_mask(lengths,
                          maxlen=max_length)
  mask = tf.cast(mask, tf.float32)
  mask_flat = tf.reshape(mask, [-1])
  num_logits = tf.to_float(tf.reduce_sum(lengths))

  softmax_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels_flat, logits=logits_flat)
  loss = tf.reduce_sum(mask_flat * softmax_cross_entropy) / num_logits

  # Configuring the optimization step.
  learning_rate = tf.train.exponential_decay(
      0.1,
      tf.train.get_global_step(),
      10000,
      0.9)
  if opt == 'sgd':
      tf.logging.info('Using SGD optimizer')
      optimizer = tf.train.GradientDescentOptimizer(
          learning_rate=learning_rate)
  elif opt == 'momentum':
      tf.logging.info('Using Momentum optimizer')
      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate, momentum=0.9)
  elif opt == 'rms':
      tf.logging.info('Using RMS optimizer')
      optimizer = tf.train.RMSPropOptimizer(
          learning_rate,
          RMSPROP_DECAY,
          momentum=RMSPROP_MOMENTUM,
          epsilon=RMSPROP_EPSILON)
  if FLAGS.use_tpu:
    optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

  train_op = optimizer.minimize(
      loss,
      global_step=tf.train.get_global_step())

  # Log parameter and FLOP counts for the graph (written to stdout by default).
  param_stats = tf.profiler.profile(
      tf.get_default_graph(),
      options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
  fl_stats = tf.profiler.profile(
      tf.get_default_graph(),
      options=tf.profiler.ProfileOptionBuilder.float_operation())

  return tpu_estimator.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op)
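
A quick, self-contained illustration of the masked loss computed above (the toy values are made up, not from the example): padded time steps are zeroed out by the sequence mask and the sum is normalized by the number of real tokens.

import tensorflow as tf

lengths = tf.constant([2, 3])                       # real token counts per sequence
per_token_loss = tf.constant([[0.5, 0.7, 9.9],      # last entry is padding
                              [0.2, 0.4, 0.6]])
mask = tf.cast(tf.sequence_mask(lengths, maxlen=3), tf.float32)
loss = tf.reduce_sum(mask * per_token_loss) / tf.to_float(tf.reduce_sum(lengths))
with tf.Session() as sess:
    print(sess.run(loss))  # (0.5 + 0.7 + 0.2 + 0.4 + 0.6) / 5 = 0.48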
Example #25
def model_fn(features, labels, mode, params):
    """
    Create the model for estimator api

    Args:
        features: tensor with shape
            [BATCH_SIZE, go.N, go.N, get_features_planes()]
        labels: dict from string to tensor with shape
            'pi_tensor': [BATCH_SIZE, go.N * go.N + 1]
            'value_tensor': [BATCH_SIZE]
        mode: a tf.estimator.ModeKeys (batchnorm params update for TRAIN only)
        params: a dictionary (typically derived from the FLAGS object).
    Returns: tf.estimator.EstimatorSpec with props
        mode: same as mode arg
        predictions: dict of tensors
            'policy': [BATCH_SIZE, go.N * go.N + 1]
            'value': [BATCH_SIZE]
        loss: a single value tensor
        train_op: train op
        eval_metric_ops
    Also computed internally (not returned in the spec):
        logits: [BATCH_SIZE, go.N * go.N + 1]
    """

    policy_output, value_output, logits = model_inference_fn(
        features, mode == tf.estimator.ModeKeys.TRAIN, params)

    # train ops
    policy_cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=logits, labels=tf.stop_gradient(labels['pi_tensor'])))

    value_cost = params['value_cost_weight'] * tf.reduce_mean(
        tf.square(value_output - labels['value_tensor']))

    reg_vars = [v for v in tf.trainable_variables()
                if 'bias' not in v.name and 'beta' not in v.name]
    l2_cost = params['l2_strength'] * \
        tf.add_n([tf.nn.l2_loss(v) for v in reg_vars])

    combined_cost = policy_cost + value_cost + l2_cost

    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.piecewise_constant(
        global_step, params['lr_boundaries'], params['lr_rates'])
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Insert quantization ops if requested
    if params['quantize']:
        if mode == tf.estimator.ModeKeys.TRAIN:
            tf.contrib.quantize.create_training_graph(
                quant_delay=params['quant_delay'])
        else:
            tf.contrib.quantize.create_eval_graph()

    optimizer = tf.train.MomentumOptimizer(
        learning_rate, params['sgd_momentum'])
    if params['use_tpu']:
        optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(combined_cost, global_step=global_step)

    # Computations to be executed on CPU, outside of the main TPU queues.
    def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor,
                                  value_tensor, policy_cost, value_cost,
                                  l2_cost, combined_cost, step,
                                  est_mode=tf.estimator.ModeKeys.TRAIN):
        policy_entropy = -tf.reduce_mean(tf.reduce_sum(
            policy_output * tf.log(policy_output), axis=1))
        # pi_tensor is one_hot when generated from sgfs (for supervised learning)
        # and soft-max when using self-play records. argmax normalizes the two.
        policy_target_top_1 = tf.argmax(pi_tensor, axis=1)

        policy_output_in_top1 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=1))
        policy_output_in_top3 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=3))

        policy_top_1_confidence = tf.reduce_max(policy_output, axis=1)
        policy_target_top_1_confidence = tf.boolean_mask(
            policy_output,
            tf.one_hot(policy_target_top_1, tf.shape(policy_output)[1]))

        value_cost_normalized = value_cost / params['value_cost_weight']
        avg_value_observed = tf.reduce_mean(value_tensor)

        with tf.variable_scope('metrics'):
            metric_ops = {
                'policy_cost': tf.metrics.mean(policy_cost),
                'value_cost': tf.metrics.mean(value_cost),
                'value_cost_normalized': tf.metrics.mean(value_cost_normalized),
                'l2_cost': tf.metrics.mean(l2_cost),
                'policy_entropy': tf.metrics.mean(policy_entropy),
                'combined_cost': tf.metrics.mean(combined_cost),
                'avg_value_observed': tf.metrics.mean(avg_value_observed),
                'policy_accuracy_top_1': tf.metrics.mean(policy_output_in_top1),
                'policy_accuracy_top_3': tf.metrics.mean(policy_output_in_top3),
                'policy_top_1_confidence': tf.metrics.mean(policy_top_1_confidence),
                'policy_target_top_1_confidence': tf.metrics.mean(
                    policy_target_top_1_confidence),
                'value_confidence': tf.metrics.mean(tf.abs(value_output)),
            }

        if est_mode == tf.estimator.ModeKeys.EVAL:
            return metric_ops

        # NOTE: global_step is rounded to a multiple of FLAGS.summary_steps.
        eval_step = tf.reduce_min(step)

        # Create summary ops so that they show up in SUMMARIES collection
        # That way, they get logged automatically during training
        summary_writer = summary.create_file_writer(FLAGS.work_dir)
        with summary_writer.as_default(), \
                summary.record_summaries_every_n_global_steps(
                    params['summary_steps'], eval_step):
            for metric_name, metric_op in metric_ops.items():
                summary.scalar(metric_name, metric_op[1], step=eval_step)

        # Reset metrics occasionally so that they are mean of recent batches.
        reset_op = tf.variables_initializer(tf.local_variables('metrics'))
        cond_reset_op = tf.cond(
            tf.equal(eval_step % params['summary_steps'], tf.to_int64(1)),
            lambda: reset_op,
            lambda: tf.no_op())

        return summary.all_summary_ops() + [cond_reset_op]

    metric_args = [
        policy_output,
        value_output,
        labels['pi_tensor'],
        labels['value_tensor'],
        tf.reshape(policy_cost, [1]),
        tf.reshape(value_cost, [1]),
        tf.reshape(l2_cost, [1]),
        tf.reshape(combined_cost, [1]),
        tf.reshape(global_step, [1]),
    ]

    predictions = {
        'policy_output': policy_output,
        'value_output': value_output,
    }

    eval_metrics_only_fn = functools.partial(
        eval_metrics_host_call_fn, est_mode=tf.estimator.ModeKeys.EVAL)
    host_call_fn = functools.partial(
        eval_metrics_host_call_fn, est_mode=tf.estimator.ModeKeys.TRAIN)

    tpu_estimator_spec = tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=combined_cost,
        train_op=train_op,
        eval_metrics=(eval_metrics_only_fn, metric_args),
        host_call=(host_call_fn, metric_args)
    )
    if params['use_tpu']:
        return tpu_estimator_spec
    else:
        return tpu_estimator_spec.as_estimator_spec()
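
A toy, self-contained check (values are illustrative only) of the metric-reset pattern used in eval_metrics_host_call_fn above: tf.metrics.mean accumulates its state in local variables, so re-running their initializer restarts the running average.

import tensorflow as tf

with tf.variable_scope('metrics'):
    mean, update = tf.metrics.mean(tf.constant(2.0))
reset = tf.variables_initializer(tf.local_variables('metrics'))
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update)
    sess.run(update)
    print(sess.run(mean))  # 2.0 after two updates
    sess.run(reset)        # total/count go back to zero; the next batch starts fresh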
Example #26
def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
  if isinstance(features, dict):
    features = features['feature']

  # In most cases, the default data format NCHW instead of NHWC should be
  # used for a significant performance boost on GPU/TPU. NHWC should be used
  # only if the network needs to be run on CPU since the pooling operations
  # are only supported on NHWC.
  if FLAGS.data_format == 'channels_first':
    features = tf.transpose(features, [0, 3, 1, 2])

  network = resnet_model.resnet_v1(
      resnet_depth=FLAGS.resnet_depth,
      num_classes=LABEL_CLASSES,
      data_format=FLAGS.data_format)

  logits = network(
      inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  # If necessary, in the model_fn, use params['batch_size'] instead of the
  # batch-size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']   # pylint: disable=unused-variable

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, LABEL_CLASSES)
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=one_hot_labels)

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = cross_entropy + WEIGHT_DECAY * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()
       if 'batch_normalization' not in v.name])

  host_call = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Compute the current epoch and associated learning rate from global_step.
    global_step = tf.train.get_global_step()
    batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) /
                     batches_per_epoch)
    learning_rate = learning_rate_schedule(current_epoch)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=MOMENTUM, use_nesterov=True)
    if FLAGS.use_tpu:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    # To log the loss, current learning rate, and epoch for Tensorboard, the
    # summary op needs to be run on the host CPU via host_call. host_call
    # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
    # dimension. These Tensors are implicitly concatenated to
    # [params['batch_size']].
    gs_t = tf.reshape(global_step, [1])
    loss_t = tf.reshape(loss, [1])
    lr_t = tf.reshape(learning_rate, [1])
    ce_t = tf.reshape(current_epoch, [1])

    def host_call_fn(gs, loss, lr, ce):
      """Training host call. Creates scalar summaries for training metrics.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `host_call`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `host_call`.

      Args:
        gs: `Tensor` with shape `[batch]` for the global_step.
        loss: `Tensor` with shape `[batch]` for the training loss.
        lr: `Tensor` with shape `[batch]` for the learning_rate.
        ce: `Tensor` with shape `[batch]` for the current_epoch.

      Returns:
        List of summary ops to run on the CPU host.
      """
      gs = gs[0]
      with summary.create_file_writer(FLAGS.model_dir).as_default():
        with summary.always_record_summaries():
          summary.scalar('loss', tf.reduce_mean(loss), step=gs)
          summary.scalar('learning_rate', tf.reduce_mean(lr), step=gs)
          summary.scalar('current_epoch', tf.reduce_mean(ce), step=gs)

          return summary.all_summary_ops()

    host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    def metric_fn(labels, logits):
      """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch, ]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      return {
          'Top-1 accuracy': top_1_accuracy,
          'Top-5 accuracy': top_5_accuracy,
      }

    eval_metrics = (metric_fn, [labels, logits])

  param_stats = tf.profiler.profile(
      tf.get_default_graph(),
      options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
  fl_stats = tf.profiler.profile(
      tf.get_default_graph(),
      options=tf.profiler.ProfileOptionBuilder.float_operation())

  return tpu_estimator.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics)
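
A small, self-contained check (toy logits, not from the model) of the in_top_k logic that metric_fn above uses for top-k accuracy.

import tensorflow as tf

logits = tf.constant([[0.1, 0.5, 0.4],
                      [0.8, 0.05, 0.15]])
labels = tf.constant([2, 1])
in_top_2 = tf.cast(tf.nn.in_top_k(logits, labels, 2), tf.float32)
with tf.Session() as sess:
    print(sess.run(in_top_2))  # [1., 0.]: label 2 is in row 0's top 2; label 1 is not in row 1's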
Example #27
def _model_fn(features, labels, mode, params, model):
  """Model defination for the SSD model based on ResNet-50.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.
    model: the SSD model outputs class logits and box regression outputs.

  Returns:
    spec: the EstimatorSpec or TPUEstimatorSpec to run training, evaluation,
      or prediction.
  """
  if mode == tf.estimator.ModeKeys.PREDICT:
    # In PREDICT mode everything arrives packed into `features`; keep the dict
    # around as `labels` and pull the raw image out of it.
    labels = features
    features = labels.pop('image')

  # Manually apply the double transpose trick for training data.
  if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
    features = tf.transpose(features, [3, 0, 1, 2])
    labels[ssd_constants.BOXES] = tf.transpose(
        labels[ssd_constants.BOXES], [2, 0, 1])
    labels[ssd_constants.CLASSES] = tf.transpose(
        labels[ssd_constants.CLASSES], [2, 0, 1])

  # Normalize the image to zero mean and unit variance.
  mlperf_log.ssd_print(key=mlperf_log.DATA_NORMALIZATION_MEAN,
                       value=ssd_constants.NORMALIZATION_MEAN)
  mlperf_log.ssd_print(key=mlperf_log.DATA_NORMALIZATION_STD,
                       value=ssd_constants.NORMALIZATION_STD)

  features -= tf.constant(
      ssd_constants.NORMALIZATION_MEAN, shape=[1, 1, 3], dtype=features.dtype)

  features /= tf.constant(
      ssd_constants.NORMALIZATION_STD, shape=[1, 1, 3], dtype=features.dtype)

  def _model_outputs():
    return model(
        features, params, is_training_bn=(mode == tf.estimator.ModeKeys.TRAIN))

  if params['use_bfloat16']:
    with bfloat16.bfloat16_scope():
      cls_outputs, box_outputs = _model_outputs()
      levels = cls_outputs.keys()
      for level in levels:
        cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
        box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
  else:
    cls_outputs, box_outputs = _model_outputs()
    levels = cls_outputs.keys()

  # First check if it is in PREDICT mode.
  if mode == tf.estimator.ModeKeys.PREDICT:
    flattened_cls, flattened_box = concat_outputs(cls_outputs, box_outputs)
    mlperf_log.ssd_print(
        key=mlperf_log.SCALES, value=ssd_constants.BOX_CODER_SCALES)
    ssd_box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder(
        scale_factors=ssd_constants.BOX_CODER_SCALES)

    anchors = box_list.BoxList(
        tf.convert_to_tensor(dataloader.DefaultBoxes()('ltrb')))

    decoded_boxes = box_coder.batch_decode(
        encoded_boxes=flattened_box, box_coder=ssd_box_coder, anchors=anchors)

    pred_scores = tf.nn.softmax(flattened_cls, axis=2)

    pred_scores, indices = select_top_k_scores(pred_scores,
                                               ssd_constants.MAX_NUM_EVAL_BOXES)

    predictions = dict(
        labels,
        indices=indices,
        pred_scores=pred_scores,
        pred_box=decoded_boxes,
    )

    if params['visualize_dataloader']:
      # this is for inference visualization.
      predictions['image'] = features

    if params['use_tpu']:
      return tpu_estimator.TPUEstimatorSpec(mode=mode, predictions=predictions)

    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # Load pretrained model from checkpoint.
  if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

    def scaffold_fn():
      """Loads pretrained model through scaffold function."""
      tf.train.init_from_checkpoint(params['resnet_checkpoint'], {
          '/': 'resnet%s/' % ssd_constants.RESNET_DEPTH,
      })
      return tf.train.Scaffold()
  else:
    scaffold_fn = None

  # Set up training loss and learning rate.
  update_learning_rate_schedule_parameters(params)
  global_step = tf.train.get_or_create_global_step()
  learning_rate = learning_rate_schedule(params, global_step)
  mlperf_log.ssd_print(key=mlperf_log.OPT_LR, deferred=True)
  # cls_loss and box_loss are for logging. only total_loss is optimized.
  total_loss, cls_loss, box_loss = detection_loss(
      cls_outputs, box_outputs, labels)

  total_loss += params['weight_decay'] * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

  host_call = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.MomentumOptimizer(
        learning_rate, momentum=ssd_constants.MOMENTUM)
    if params['use_tpu']:
      optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

    mlperf_log.ssd_print(
        key=mlperf_log.OPT_NAME, value='tf.train.MomentumOptimizer')
    # TODO(wangtao): figure out how to log learning rate.
    # mlperf_log.ssd_print(key=mlperf_log.OPT_LR, value=learning_rate)
    mlperf_log.ssd_print(
        key=mlperf_log.OPT_MOMENTUM, value=ssd_constants.MOMENTUM)
    mlperf_log.ssd_print(
        key=mlperf_log.OPT_WEIGHT_DECAY, value=params['weight_decay'])

    # Batch norm requires update_ops to be added as a train_op dependency.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if params['device'] == 'gpu':
      # GPU uses tf.group to avoid dependency overhead on update_ops; also,
      # multi-GPU requires a different EstimatorSpec class object
      train_op = tf.group(optimizer.minimize(total_loss, global_step),
                          update_ops)
      return model_fn_lib.EstimatorSpec(
          mode=mode, loss=total_loss, train_op=train_op, scaffold=None)
    else:
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(total_loss, global_step)

    if params['use_host_call']:
      def host_call_fn(global_step, total_loss, cls_loss, box_loss,
                       learning_rate):
        """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          global_step: `Tensor` with shape `[batch, ]` for the global_step.
          total_loss: `Tensor` with shape `[batch, ]` for the training loss.
          cls_loss: `Tensor` with shape `[batch, ]` for the training cls loss.
          box_loss: `Tensor` with shape `[batch, ]` for the training box loss.
          learning_rate: `Tensor` with shape `[batch, ]` for the learning_rate.

        Returns:
          List of summary ops to run on the CPU host.
        """
        # Outfeed supports int32 but global_step is expected to be int64.
        global_step = tf.reduce_mean(global_step)
        # Host call fns are executed FLAGS.iterations_per_loop times after one
        # TPU loop is finished, setting max_queue value to the same as number of
        # iterations will make the summary writer only flush the data to storage
        # once per loop.
        with (tf.contrib.summary.create_file_writer(
            params['model_dir'],
            max_queue=params['iterations_per_loop']).as_default()):
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar(
                'total_loss', tf.reduce_mean(total_loss), step=global_step)
            tf.contrib.summary.scalar(
                'cls_loss', tf.reduce_mean(cls_loss), step=global_step)
            tf.contrib.summary.scalar(
                'box_loss', tf.reduce_mean(box_loss), step=global_step)
            tf.contrib.summary.scalar(
                'learning_rate', tf.reduce_mean(learning_rate),
                step=global_step)

            return tf.contrib.summary.all_summary_ops()

      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      global_step_t = tf.reshape(global_step, [1])
      total_loss_t = tf.reshape(total_loss, [1])
      cls_loss_t = tf.reshape(cls_loss, [1])
      box_loss_t = tf.reshape(box_loss, [1])
      learning_rate_t = tf.reshape(learning_rate, [1])
      host_call = (host_call_fn,
                   [global_step_t, total_loss_t, cls_loss_t, box_loss_t,
                    learning_rate_t])
  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    raise NotImplementedError

  return tpu_estimator.TPUEstimatorSpec(
      mode=mode,
      loss=total_loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics,
      scaffold_fn=scaffold_fn)
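
A short illustration (shapes are made up) of the "double transpose trick" applied to features above: the matching input_fn, not shown here, is assumed to emit images with the batch dimension last so the host-to-device transfer is cheap, and the model_fn moves the batch dimension back to axis 0.

import tensorflow as tf

images_nhwc = tf.zeros([8, 32, 32, 3])                 # what the model wants
images_hwcn = tf.transpose(images_nhwc, [1, 2, 3, 0])  # what the input_fn would emit
recovered = tf.transpose(images_hwcn, [3, 0, 1, 2])    # what the model_fn does above
print(recovered.shape)  # (8, 32, 32, 3)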
Example #28
def model_fn(features, labels, mode, params):
    """Mobilenet v1 model using Estimator API."""
    num_classes = FLAGS.num_classes
    training_active = (mode == tf.estimator.ModeKeys.TRAIN)
    eval_active = (mode == tf.estimator.ModeKeys.EVAL)

    features = tensor_transform_fn(features, params['input_perm'])

    if FLAGS.clear_update_collections:
        # updates_collections must be set to None in order to use fused batchnorm.
        # NOTE: as written, this branch builds the same arg_scope as the else
        # branch, so the flag has no effect unless the arg_scope is parameterized.
        with arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
            logits, end_points = mobilenet_v1.mobilenet_v1(
                features,
                num_classes,
                is_training=training_active,
                depth_multiplier=FLAGS.depth_multiplier)
    else:
        with arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
            logits, end_points = mobilenet_v1.mobilenet_v1(
                features,
                num_classes,
                is_training=training_active,
                depth_multiplier=FLAGS.depth_multiplier)

    predictions = {
        'classes': tf.argmax(input=logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    if mode == tf.estimator.ModeKeys.EVAL and FLAGS.display_tensors and (
            not FLAGS.use_tpu):
        with tf.control_dependencies([
                tf.Print(predictions['classes'], [predictions['classes']],
                         summarize=FLAGS.eval_batch_size,
                         message='prediction: ')
        ]):
            labels = tf.Print(labels, [labels],
                              summarize=FLAGS.eval_batch_size,
                              message='label: ')

    one_hot_labels = tf.one_hot(labels, FLAGS.num_classes, dtype=tf.int32)

    tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels,
                                    logits=logits,
                                    weights=1.0,
                                    label_smoothing=0.1)
    loss = tf.losses.get_total_loss(add_regularization_losses=True)

    initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
    final_learning_rate = 0.0001 * initial_learning_rate

    train_op = None
    if training_active:
        batches_per_epoch = _NUM_TRAIN_IMAGES // FLAGS.train_batch_size
        global_step = tf.train.get_or_create_global_step()

        learning_rate = tf.train.exponential_decay(
            learning_rate=initial_learning_rate,
            global_step=global_step,
            decay_steps=FLAGS.learning_rate_decay_epochs * batches_per_epoch,
            decay_rate=FLAGS.learning_rate_decay,
            staircase=True)

        # Set a minimum boundary for the learning rate.
        learning_rate = tf.maximum(learning_rate,
                                   final_learning_rate,
                                   name='learning_rate')

        if FLAGS.optimizer == 'sgd':
            tf.logging.info('Using SGD optimizer')
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        elif FLAGS.optimizer == 'momentum':
            tf.logging.info('Using Momentum optimizer')
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=0.9)
        elif FLAGS.optimizer == 'RMS':
            tf.logging.info('Using RMS optimizer')
            optimizer = tf.train.RMSPropOptimizer(learning_rate,
                                                  RMSPROP_DECAY,
                                                  momentum=RMSPROP_MOMENTUM,
                                                  epsilon=RMSPROP_EPSILON)
        else:
            tf.logging.fatal('Unknown optimizer: %s', FLAGS.optimizer)

        if FLAGS.use_tpu:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step=global_step)
        if FLAGS.moving_average:
            ema = tf.train.ExponentialMovingAverage(decay=MOVING_AVERAGE_DECAY,
                                                    num_updates=global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            with tf.control_dependencies([train_op
                                          ]), tf.name_scope('moving_average'):
                train_op = ema.apply(variables_to_average)

    eval_metrics = None
    if eval_active:

        def metric_fn(labels, predictions):
            accuracy = tf.metrics.accuracy(
                labels, tf.argmax(input=predictions, axis=1))
            return {'accuracy': accuracy}

        if FLAGS.use_logits:
            eval_predictions = logits
        else:
            eval_predictions = end_points['Predictions']

        eval_metrics = (metric_fn, [labels, eval_predictions])

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metrics=eval_metrics)
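
A self-contained check of the decay-with-floor learning-rate schedule built above; the base rate, batch size, and decay interval here are hypothetical, not the repository's flag values.

import tensorflow as tf

global_step = tf.placeholder(tf.int64, [])
initial_lr = 0.045 * 1024 / 256                 # hypothetical base LR scaled by batch size
lr = tf.train.exponential_decay(initial_lr, global_step,
                                decay_steps=3000, decay_rate=0.94, staircase=True)
lr = tf.maximum(lr, 0.0001 * initial_lr)        # same floor rule as above
with tf.Session() as sess:
    for step in [0, 3000, 30000]:
        print(step, sess.run(lr, {global_step: step}))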
Example #29
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=False):
        hparams = copy.deepcopy(hparams)
        hparams.use_tpu = use_tpu
        # merge decode_hparams into hparams if present
        if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
            for k, v in six.iteritems(decode_hparams.values()):
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()

        mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(hparams.layout)
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            if data_parallelism is None or len(
                    data_parallelism.ps_devices) == 1:
                mesh_devices = [""] * mesh_shape.size
            else:
                assert len(data_parallelism.ps_devices) == mesh_shape.size
                mesh_devices = data_parallelism.ps_devices
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)
        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            tf.summary.scalar("learning_rate", lr)
            mtf_lr = mtf.import_tf_tensor(
                mesh, tf.convert_to_tensor(lr, dtype=tf.float32),
                mtf.Shape([]))
            optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            tf_logits = lowering.export_to_tf_tensor(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        if use_tpu:
            # TPU host call. Important: need to be called before remove_summaries()
            if hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(hparams.model_dir)
            else:
                host_call = None

            t2t_model.remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                host_call=host_call,
                training_hooks=[restore_hook, saver_hook])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
Example #30
def model_fn(features, labels, mode, params):
    del params  # Unused.

    # NOTE: this snippet relies on names defined elsewhere in its source file
    # (source_vocab_size, target_vocab_size, inputs, encoder_inputs,
    # decoder_inputs, targets, target_weights, buckets, learning_rate,
    # max_gradient_norm, global_step); it is not self-contained as shown.
    # If we use sampled softmax, we need an output projection.
    output_projection = None
    softmax_loss_function = None
    # Sampled softmax only makes sense if we sample less than vocabulary size.
    num_samples = 512
    num_layers = 3
    size = 256
    num_unrolled_steps = 35

    if num_samples > 0 and num_samples < target_vocab_size:
        w = tf.get_variable("proj_w", [size, target_vocab_size])
        w_t = tf.transpose(w)
        b = tf.get_variable("proj_b", [target_vocab_size])
        output_projection = (w, b)

        def sampled_loss(labels, logits):
            labels = tf.reshape(labels, [-1, 1])
            # We need to compute the sampled_softmax_loss using 32bit floats to
            # avoid numerical instabilities.
            local_w_t = tf.cast(w_t, tf.float32)
            local_b = tf.cast(b, tf.float32)
            local_inputs = tf.cast(logits, tf.float32)
            return tf.nn.sampled_softmax_loss(weights=local_w_t,
                                              biases=local_b,
                                              labels=labels,
                                              inputs=local_inputs,
                                              num_sampled=num_samples,
                                              num_classes=target_vocab_size)

        softmax_loss_function = sampled_loss

    # Create the internal multi-layer cell for our RNN.
    for l in range(num_layers):
        with tf.variable_scope("rnn_%d" % l):
            unstacked_inputs = tf.unstack(inputs,
                                          num=num_unrolled_steps,
                                          axis=0)
            cell = tf.nn.rnn_cell.BasicLSTMCell(size)
            outputs, _ = tf.nn.static_rnn(cell,
                                          unstacked_inputs,
                                          dtype=tf.float32)
            cell = tf.stack(outputs, axis=0)
            cell = tf.nn.dropout(cell, 1 - FLAGS.dropout_prob)

    # The seq2seq function: we use embedding for the input and attention.
    def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
        return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
            encoder_inputs,
            decoder_inputs,
            cell,
            num_encoder_symbols=source_vocab_size,
            num_decoder_symbols=target_vocab_size,
            embedding_size=size,
            output_projection=output_projection,
            feed_previous=do_decode)

    _outputs, losses = tf.contrib.legacy_seq2seq.model_with_buckets(
        encoder_inputs,
        decoder_inputs,
        targets,
        target_weights,
        buckets,
        lambda x, y: seq2seq_f(x, y, False),
        softmax_loss_function=softmax_loss_function)

    # Gradients and SGD update operation for training the model.
    params = tf.trainable_variables()
    gradient_norms = []
    updates = []
    opt = tpu_optimizer.CrossShardOptimizer(
        tf.train.GradientDescentOptimizer(learning_rate=learning_rate))
    for b in xrange(len(buckets)):
        # compute_gradients returns (gradient, variable) pairs; extract the
        # gradients before clipping by global norm.
        grads_and_vars = opt.compute_gradients(losses[b], params)
        gradients = [g for g, _ in grads_and_vars]
        clipped_gradients, norm = tf.clip_by_global_norm(
            gradients, max_gradient_norm)
        gradient_norms.append(norm)
        updates.append(
            opt.apply_gradients(zip(clipped_gradients, params),
                                global_step=global_step))

    # The original snippet returned an undefined `total_loss` and a list of
    # update ops; aggregating them here is an assumption to make the spec
    # well-formed, since the rest of the source file is not shown.
    total_loss = tf.add_n(losses) / float(len(buckets))
    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=total_loss,
                                          train_op=tf.group(*updates))