  def step_fn(self, params, model):
    """A single training step for UDA (Unsupervised Data Augmentation)."""
    train_batch_size = params.train_batch_size
    num_replicas = params.num_replicas
    batch_size = train_batch_size // num_replicas

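    # Infeed tuple layout: (labeled images, one-hot labels, unlabeled
    # originals, unlabeled augmented). Images use bfloat16 when
    # params.use_bfloat16 is set; labels are always float32.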
    dtypes = [
        tf.bfloat16 if params.use_bfloat16 else tf.float32,
        tf.float32,
        tf.bfloat16 if params.use_bfloat16 else tf.float32,
        tf.bfloat16 if params.use_bfloat16 else tf.float32]
    shapes = [
        [batch_size, params.image_size, params.image_size, 3],
        [batch_size, params.num_classes],
        [batch_size*params.uda_data, params.image_size, params.image_size, 3],
        [batch_size*params.uda_data, params.image_size, params.image_size, 3]]

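    # With XLA spatial partitioning, each image tensor is split along axis 2
    # (width) across the cores of a replica; labels are small and stay
    # unpartitioned ([1, 1] partition dims).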
    if params.use_xla_sharding and params.num_cores_per_replica > 1:
      q = tpu_feed._PartitionedInfeedQueue(
          number_of_tuple_elements=4,
          host_id=0,
          input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                                [1, 1],
                                [1, 1, params.num_cores_per_replica, 1],
                                [1, 1, params.num_cores_per_replica, 1],],
          device_assignment=params.device_assignment)
      q.set_tuple_types(dtypes)
      q.set_tuple_shapes(shapes)
      l_images, l_labels, u_images_ori, u_images_aug = q.generate_dequeue_op()
      l_images = xla_sharding.split(l_images, 2,
                                    params.num_cores_per_replica)
      u_images_ori = xla_sharding.split(u_images_ori, 2,
                                        params.num_cores_per_replica)
      u_images_aug = xla_sharding.split(u_images_aug, 2,
                                        params.num_cores_per_replica)
    else:
      with tf.device(tf.tpu.core(0)):
        (l_images, l_labels, u_images_ori,
         u_images_aug) = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                       shapes=shapes)

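    # One forward pass over the labeled images and both unlabeled views.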
    all_images = tf.concat([l_images, u_images_ori, u_images_aug], axis=0)
    global_step = tf.train.get_or_create_global_step()
    num_replicas = tf.cast(params.num_replicas, tf.float32)

    with tf.variable_scope(MODEL_SCOPE, reuse=tf.AUTO_REUSE):
      _, _, masks, cross_entropy = UDA.build_uda_cross_entropy(
          params, model, all_images, l_labels)

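    # Total loss: supervised cross-entropy + UDA consistency loss (ramped
    # linearly to params.uda_weight over params.uda_steps) + L2 weight decay.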
    l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas, tf.float32)
    weight_dec = common_utils.get_l2_loss()
    uda_weight = params.uda_weight * tf.minimum(
        1., tf.cast(global_step, tf.float32) / float(params.uda_steps))
    total_loss = (cross_entropy['u'] * uda_weight +
                  cross_entropy['l'] +
                  weight_dec * l2_reg_rate)
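    # Sum gradients across replicas, then clip by global norm.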
    variables = tf.trainable_variables()
    gradients = tf.gradients(total_loss, variables)
    gradients = [tf.tpu.cross_replica_sum(g) for g in gradients]
    gradients, grad_norm = tf.clip_by_global_norm(gradients, params.grad_bound)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    learning_rate, optimizer = common_utils.get_optimizer(params)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.apply_gradients(zip(gradients, variables),
                                           global_step=global_step)

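    # Maintain an exponential moving average (EMA) of the model weights,
    # chained after the gradient update.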
    with tf.control_dependencies([train_op]):
      ema_train_op = common_utils.setup_ema(
          params, f'{MODEL_SCOPE}/{model.name}')

    with tf.control_dependencies([ema_train_op]):
      logs = collections.OrderedDict()
      logs['global_step'] = tf.cast(global_step, tf.float32)
      logs['loss/total'] = total_loss
      logs['loss/cross_entropy'] = cross_entropy['l']
      logs['loss/lr'] = tf.identity(learning_rate) / num_replicas
      logs['loss/grad_norm'] = tf.identity(grad_norm) / num_replicas
      logs['loss/weight_dec'] = weight_dec / num_replicas

      logs['uda/cross_entropy'] = cross_entropy['u']
      logs['uda/u_ratio'] = tf.reduce_mean(masks['u']) / num_replicas
      logs['uda/l_ratio'] = tf.reduce_mean(masks['l']) / num_replicas
      logs['uda/weight'] = uda_weight / num_replicas

      tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
      self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
      outfeed_enqueue_op = tf.cond(
          common_utils.should_log(params),
          lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors), tf.no_op)
    return outfeed_enqueue_op
  def step_fn(self, params, model):
    """A single training step for Meta Pseudo Labels (MPL)."""
    train_batch_size = params.train_batch_size
    num_replicas = params.num_replicas
    uda_data = params.uda_data
    batch_size = train_batch_size // num_replicas

    dtypes = [
        tf.bfloat16 if params.use_bfloat16 else tf.float32,
        tf.float32,
        tf.bfloat16 if params.use_bfloat16 else tf.float32,
        tf.bfloat16 if params.use_bfloat16 else tf.float32]
    shapes = [
        [batch_size, params.image_size, params.image_size, 3],
        [batch_size, params.num_classes],
        [batch_size*params.uda_data, params.image_size, params.image_size, 3],
        [batch_size*params.uda_data, params.image_size, params.image_size, 3]]

    if params.use_xla_sharding and params.num_cores_per_replica > 1:
      q = tpu_feed._PartitionedInfeedQueue(
          number_of_tuple_elements=4,
          host_id=0,
          input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                                [1, 1],
                                [1, 1, params.num_cores_per_replica, 1],
                                [1, 1, params.num_cores_per_replica, 1],],
          device_assignment=params.device_assignment)
      q.set_tuple_types(dtypes)
      q.set_tuple_shapes(shapes)
      l_images, l_labels, u_images_ori, u_images_aug = q.generate_dequeue_op()
      l_images = xla_sharding.split(l_images, 2,
                                    params.num_cores_per_replica)
      u_images_ori = xla_sharding.split(u_images_ori, 2,
                                        params.num_cores_per_replica)
      u_images_aug = xla_sharding.split(u_images_aug, 2,
                                        params.num_cores_per_replica)
    else:
      with tf.device(tf.tpu.core(0)):
        (l_images, l_labels, u_images_ori,
         u_images_aug) = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                       shapes=shapes)
    global_step = tf.train.get_or_create_global_step()
    num_replicas = tf.cast(params.num_replicas, tf.float32)

    all_images = tf.concat([l_images, u_images_ori, u_images_aug], axis=0)

    # All calls to the teacher: compute the UDA losses and the pseudo labels
    # for the unlabeled batch.
    with tf.variable_scope('teacher', reuse=tf.AUTO_REUSE):
      logits, labels, masks, cross_entropy = UDA.build_uda_cross_entropy(
          params, model, all_images, l_labels)

    # 1st call to the student: predict on the augmented unlabeled images and
    # the labeled images in one batch.
    with tf.variable_scope(MODEL_SCOPE):
      u_aug_and_l_images = tf.concat([u_images_aug, l_images], axis=0)
      logits['s_on_u_aug_and_l'] = model(u_aug_and_l_images, training=True)
      logits['s_on_u'], logits['s_on_l_old'] = tf.split(
          logits['s_on_u_aug_and_l'],
          [u_images_aug.shape[0].value, l_images.shape[0].value], axis=0)

    # For backprop: the student's loss against the teacher's soft pseudo
    # labels (gradients stop at the teacher's logits).
    cross_entropy['s_on_u'] = tf.losses.softmax_cross_entropy(
        onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], -1)),
        logits=logits['s_on_u'],
        label_smoothing=params.label_smoothing,
        reduction=tf.losses.Reduction.NONE)
    cross_entropy['s_on_u'] = tf.reduce_sum(cross_entropy['s_on_u']) / float(
        train_batch_size*uda_data)

    # For the Taylor approximation: the student's labeled loss *before* its
    # update on the pseudo labels.
    cross_entropy['s_on_l_old'] = tf.losses.softmax_cross_entropy(
        onehot_labels=labels['l'],
        logits=logits['s_on_l_old'],
        reduction=tf.losses.Reduction.SUM)
    cross_entropy['s_on_l_old'] = tf.tpu.cross_replica_sum(
        cross_entropy['s_on_l_old']) / float(train_batch_size)
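    # Stash the pre-update labeled loss in a non-trainable variable so it can
    # be compared against the post-update loss later in the step.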
    shadow = tf.get_variable(
        name='cross_entropy_old', shape=[], trainable=False, dtype=tf.float32)
    shadow_update = tf.assign(shadow, cross_entropy['s_on_l_old'])

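    # Bookkeeping keyed by 's' (student) and 't' (teacher): weights,
    # gradients, gradient norms, learning rates, and optimizers.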
    w_s = {}
    g_s = {}
    g_n = {}
    lr = {}
    optim = {}
    w_s['s'] = [w for w in tf.trainable_variables()
                if w.name.lower().startswith(MODEL_SCOPE)]
    g_s['s_on_u'] = tf.gradients(cross_entropy['s_on_u'], w_s['s'])
    # g_s['s_on_u'] = [tf.tpu.cross_replica_sum(g) for g in g_s['s_on_u']]

    lr['s'] = common_utils.get_learning_rate(
        params,
        initial_lr=params.mpl_student_lr,
        num_warmup_steps=params.mpl_student_lr_warmup_steps,
        num_wait_steps=params.mpl_student_lr_wait_steps)
    lr['s'], optim['s'] = common_utils.get_optimizer(
        params, learning_rate=lr['s'])
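    # Eagerly create the optimizer's slot variables so variable creation
    # happens outside the control-dependency block below.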
    optim['s']._create_slots(w_s['s'])  # pylint: disable=protected-access
    update_ops = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                  if op.name.startswith(f'train/{MODEL_SCOPE}/')]

    with tf.control_dependencies(update_ops + [shadow_update]):
      g_s['s_on_u'] = common_utils.add_weight_decay(
          params, w_s['s'], g_s['s_on_u'])
      g_s['s_on_u'], g_n['s_on_u'] = tf.clip_by_global_norm(
          g_s['s_on_u'], params.grad_bound)
      train_op = optim['s'].apply_gradients(zip(g_s['s_on_u'], w_s['s']))

      with tf.control_dependencies([train_op]):
        ema_train_op = common_utils.setup_ema(
            params, name_scope=f'{MODEL_SCOPE}/{model.name}')

    # 2nd call to the student: labeled loss *after* the pseudo-label update.
    with tf.control_dependencies([ema_train_op]):
      with tf.variable_scope(MODEL_SCOPE, reuse=tf.AUTO_REUSE):
        logits['s_on_l_new'] = model(l_images, training=True)

    cross_entropy['s_on_l_new'] = tf.losses.softmax_cross_entropy(
        onehot_labels=labels['l'],
        logits=logits['s_on_l_new'],
        reduction=tf.losses.Reduction.SUM)
    cross_entropy['s_on_l_new'] = tf.tpu.cross_replica_sum(
        cross_entropy['s_on_l_new']) / float(train_batch_size)

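    # MPL feedback via a finite-difference Taylor approximation: the change in
    # the student's labeled loss across its update. A moving-average baseline
    # (below) is subtracted to reduce variance.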
    dot_product = cross_entropy['s_on_l_new'] - shadow
    # dot_product = tf.clip_by_value(
    #     dot_product,
    #     clip_value_min=-params.mpl_dot_product_bound,
    #     clip_value_max=params.mpl_dot_product_bound)
    moving_dot_product = tf.get_variable(
        'moving_dot_product', shape=[], trainable=False, dtype=tf.float32)
    moving_dot_product_update = tf.assign_sub(
        moving_dot_product, 0.01 * (moving_dot_product - dot_product))
    with tf.control_dependencies([moving_dot_product_update]):
      dot_product = dot_product - moving_dot_product
      dot_product = tf.stop_gradient(dot_product)
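    # Teacher's MPL surrogate loss: cross-entropy of its own unlabeled logits
    # against their stop-gradient softmax; scaled by dot_product in the
    # teacher loss below.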
    cross_entropy['mpl'] = tf.losses.softmax_cross_entropy(
        onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], axis=-1)),
        logits=logits['u_aug'],
        reduction=tf.losses.Reduction.NONE)
    cross_entropy['mpl'] = tf.reduce_sum(cross_entropy['mpl']) / float(
        train_batch_size*uda_data)

    # Teacher train op: UDA loss + supervised loss + MPL feedback term.
    uda_weight = params.uda_weight * tf.minimum(
        1., tf.cast(global_step, tf.float32) / float(params.uda_steps))
    teacher_loss = (cross_entropy['u'] * uda_weight +
                    cross_entropy['l'] +
                    cross_entropy['mpl'] * dot_product)
    w_s['t'] = [w for w in tf.trainable_variables() if 'teacher' in w.name]
    g_s['t'] = tf.gradients(teacher_loss, w_s['t'])
    g_s['t'] = common_utils.add_weight_decay(params, w_s['t'], g_s['t'])
    g_s['t'], g_n['t'] = tf.clip_by_global_norm(g_s['t'], params.grad_bound)
    lr['t'] = common_utils.get_learning_rate(
        params,
        initial_lr=params.mpl_teacher_lr,
        num_warmup_steps=params.mpl_teacher_lr_warmup_steps)
    lr['t'], optim['t'] = common_utils.get_optimizer(params,
                                                     learning_rate=lr['t'])

    teacher_train_op = optim['t'].apply_gradients(zip(g_s['t'], w_s['t']),
                                                  global_step=global_step)

    with tf.control_dependencies([teacher_train_op]):
      logs = collections.OrderedDict()
      logs['global_step'] = tf.cast(global_step, tf.float32)

      logs['cross_entropy/student_on_u'] = cross_entropy['s_on_u']
      logs['cross_entropy/student_on_l'] = (cross_entropy['s_on_l_new'] /
                                            num_replicas)
      logs['cross_entropy/teacher_on_u'] = cross_entropy['u']
      logs['cross_entropy/teacher_on_l'] = cross_entropy['l']
      logs['lr/student'] = tf.identity(lr['s']) / num_replicas
      logs['lr/teacher'] = tf.identity(lr['t']) / num_replicas
      logs['mpl/dot_product'] = dot_product / num_replicas
      logs['mpl/moving_dot_product'] = moving_dot_product / num_replicas
      logs['uda/u_ratio'] = tf.reduce_mean(masks['u']) / num_replicas
      logs['uda/l_ratio'] = tf.reduce_mean(masks['l']) / num_replicas
      logs['uda/weight'] = uda_weight / num_replicas

      tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
      self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
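      # When the model is spatially partitioned, pin the outfeed enqueue to
      # the last core of the replica.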
      def outfeed(tensors):
        with tf.device(tf.tpu.core(params.num_cores_per_replica-1)):
          return tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors)

      outfeed_enqueue_op = tf.cond(
          common_utils.should_log(params), lambda: outfeed(tensors), tf.no_op)

      return outfeed_enqueue_op
  def step_fn(self, params, model):
    """A single step for supervised learning."""

    batch_size = params.train_batch_size // params.num_replicas
    dtypes = [tf.bfloat16 if params.use_bfloat16 else tf.float32, tf.float32]
    shapes = [[batch_size, params.image_size, params.image_size, 3],
              [batch_size, params.num_classes]]

    if params.use_xla_sharding and params.num_cores_per_replica > 1:
      q = tpu_feed._PartitionedInfeedQueue(
          number_of_tuple_elements=2,
          host_id=0,
          input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                                [1, 1]],
          device_assignment=params.device_assignment)
      q.set_tuple_types(dtypes)
      q.set_tuple_shapes(shapes)
      images, labels = q.generate_dequeue_op()
      images = xla_sharding.split(images, 2, params.num_cores_per_replica)
    else:
      with tf.device(tf.tpu.core(0)):
        images, labels = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                       shapes=shapes)

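    # Some input pipelines emit sparse int32 labels; convert them to one-hot.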
    if labels.dtype == tf.int32:
      labels = tf.one_hot(labels, depth=params.num_classes, dtype=tf.float32)
    global_step = tf.train.get_or_create_global_step()

    train_batch_size = tf.cast(params.train_batch_size, tf.float32)
    num_replicas = tf.cast(params.num_replicas, tf.float32)

    with tf.variable_scope(MODEL_SCOPE):
      logits = model(images, training=True)

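    # Noisy Student data uses the teacher's soft predictions as labels, so
    # compute the soft cross-entropy directly (no label smoothing).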
    if 'noisy_student' in params.dataset_name.lower():
      cross_entropy = labels * tf.nn.log_softmax(logits, axis=-1)
      cross_entropy = tf.reduce_sum(-cross_entropy) / train_batch_size
    else:
      cross_entropy = tf.losses.softmax_cross_entropy(
          onehot_labels=labels, logits=logits,
          label_smoothing=params.label_smoothing,
          reduction=tf.losses.Reduction.SUM) / train_batch_size

    l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas, tf.float32)
    weight_dec = common_utils.get_l2_loss()
    total_loss = cross_entropy + weight_dec * l2_reg_rate

    variables = tf.trainable_variables()
    gradients = tf.gradients(total_loss, variables)
    gradients = [tf.tpu.cross_replica_sum(g) for g in gradients]
    gradients, grad_norm = tf.clip_by_global_norm(gradients, params.grad_bound)

    learning_rate, optimizer = common_utils.get_optimizer(params)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
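    # Apply the update only if the global gradient norm is finite; otherwise
    # skip this step.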
    train_op = tf.cond(
        tf.math.is_finite(grad_norm),
        lambda: optimizer.apply_gradients(zip(gradients, variables),
                                          global_step=global_step),
        tf.no_op)
    with tf.control_dependencies(update_ops + [train_op]):
      ema_train_op = common_utils.setup_ema(params,
                                            f'{MODEL_SCOPE}/{model.name}')

    with tf.control_dependencies([ema_train_op]):
      logs = collections.OrderedDict()
      logs['global_step'] = tf.cast(global_step, tf.float32)
      logs['loss/total'] = total_loss
      logs['loss/weight_decay'] = weight_dec / num_replicas
      logs['loss/cross_entropy'] = cross_entropy
      logs['loss/lr'] = tf.identity(learning_rate) / num_replicas
      logs['loss/grad_norm'] = grad_norm / num_replicas

      tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
      self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
      outfeed_enqueue_op = tf.cond(
          common_utils.should_log(params),
          lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors), tf.no_op)
    return outfeed_enqueue_op