Code Example #1
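A per-host `enqueue_fn` that builds one worker's shard of the eval dataset and enqueues it: when XLA sharding is enabled it pads the inputs, splits them per core, and routes them through `tpu_feed._PartitionedInfeedQueue.generate_enqueue_ops`; otherwise it splits the batch across `device_ordinals` and enqueues each slice with `tf.raw_ops.InfeedEnqueueTuple`.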
            def enqueue_fn(host_device=host,
                           input_index=i,
                           device_ordinals=tpus):
                """Docs."""
                worker_infeed_ops = []
                with tf.device(host_device):
                    dataset = build_eval_dataset(
                        params,
                        batch_size=params.eval_batch_size // num_inputs,
                        num_workers=num_inputs,
                        worker_index=input_index)
                    inputs = tf.data.make_one_shot_iterator(dataset).get_next()

                    if params.use_xla_sharding and params.num_cores_per_replica > 1:
                        inputs, partition_dims = pad_inputs_for_xla_sharding(
                            params, inputs)
                        num_splits = len(device_ordinals)
                        if len(device_ordinals) > 1:
                            inputs = [
                                tf.split(v, num_splits, 0) for v in inputs
                            ]
                        else:
                            inputs = [[v] for v in inputs]

                        q = tpu_feed._PartitionedInfeedQueue(
                            number_of_tuple_elements=len(inputs),
                            host_id=int(
                                host_device.split('/task:')[-1].split('/')[0]),
                            input_partition_dims=partition_dims,
                            device_assignment=dev_assign)
                        inputs = [[v[i] for v in inputs]
                                  for i in range(num_splits)]
                        worker_infeed_ops.extend(
                            q.generate_enqueue_ops(inputs))
                    else:
                        num_splits = len(device_ordinals)
                        if len(device_ordinals) > 1:
                            inputs = [
                                tf.split(v, num_splits, 0) for v in inputs
                            ]
                        else:
                            inputs = [[v] for v in inputs]
                        input_shapes = [v[0].shape for v in inputs]
                        for j, device_ordinal in enumerate(device_ordinals):
                            worker_infeed_ops.append(
                                tf.raw_ops.InfeedEnqueueTuple(
                                    inputs=[v[j] for v in inputs],
                                    shapes=input_shapes,
                                    device_ordinal=device_ordinal))
                return worker_infeed_ops
Code Example #2
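A unit test for uneven partitioning: a (14, 5) tensor split with `input_partition_dims=[[4, 2]]` yields eight host-side pieces, the last of which has shape (2, 2).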
 def test_infeed_uneven_partition(self):
   """Tests uneven infeed tensors partition."""
   ds = device_assignment(
       self._topology_2x2x2, num_replicas=1, computation_shape=[2, 2, 1, 2])
   input_partition_dims = [[4, 2]]
   # pylint: disable=protected-access
   partitioned_infeed = tpu_feed._PartitionedInfeedQueue(
       number_of_tuple_elements=1,
       host_id=0,
       input_partition_dims=input_partition_dims,
       device_assignment=ds)
   x = array_ops.zeros((14, 5))
   tensors = partitioned_infeed._check_dims_and_partition_or_replicate_on_host(
       x, dims=input_partition_dims[0])
   self.assertEqual(8, len(tensors))
   self.assertEqual((2, 2), tensors[-1].shape)
Code Example #3
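A unit test for trailing zero-size partitions: splitting a (5, 5) tensor four ways along the first dimension produces pieces of shapes (2, 5), (2, 5), (1, 5), and (0, 5).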
 def test_infeed_tailing_zero_partition(self):
   """Tests infeed tensors partition which causes zero-size tensors."""
   ds = device_assignment(
       self._topology_2x2x2, num_replicas=1, computation_shape=[1, 2, 1, 2])
   input_partition_dims = [[4, 1]]
   # pylint: disable=protected-access
   partitioned_infeed = tpu_feed._PartitionedInfeedQueue(
       number_of_tuple_elements=1,
       host_id=0,
       input_partition_dims=input_partition_dims,
       device_assignment=ds)
   x = array_ops.zeros((5, 5))
   tensors = partitioned_infeed._check_dims_and_partition_or_replicate_on_host(
       x, dims=input_partition_dims[0])
   self.assertEqual(4, len(tensors))
   self.assertEqual((1, 5), tensors[2].shape)
   self.assertEqual((0, 5), tensors[3].shape)
Code Example #4
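A test helper that configures infeed for a zero batch: it derives `host_id` from the device assignment's host device string, builds per-tensor partition dims along the leading dimension, and generates enqueue ops from placeholders through either `_PartitionedInfeedQueue` or a plain `InfeedQueue`.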
    def _config_infeed(self,
                       num_partitions,
                       device_assignment,
                       batch_size,
                       key_size=2,
                       return_tgt_mask=False,
                       use_partitioned_infeed_queue=False):
        """Config the infeed ops and args."""
        zero_batch = get_zero_batch(batch_size=batch_size,
                                    max_len=self._prefix_max_len,
                                    key_size=key_size,
                                    return_tgt_mask=return_tgt_mask)

        host_device = device_assignment.host_device(replica=0, job=self._tpu)
        host_id = int(host_device.split('/task:')[1].split('/device:')[0])
        input_partition_dims = [[num_partitions] + [1] * (len(x.shape) - 1)
                                for x in zero_batch]

        if use_partitioned_infeed_queue:
            infeed = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                number_of_tuple_elements=len(zero_batch),
                host_id=host_id,
                input_partition_dims=input_partition_dims,
                device_assignment=device_assignment)
        else:
            infeed = tpu_feed.InfeedQueue(
                number_of_tuple_elements=len(zero_batch))

        self.infeed_args = []
        for x in zero_batch:
            p = tf.placeholder(tf.as_dtype(x.dtype), shape=x.shape)
            self.infeed_args += [p]
        if use_partitioned_infeed_queue:
            self.infeed_op = infeed.generate_enqueue_ops([self.infeed_args])
        else:
            self.infeed_op = infeed.split_inputs_and_generate_enqueue_ops(
                self.infeed_args, device_assignment=device_assignment)
        return infeed
Code Example #5
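An eval `step_fn` that dequeues images, labels, and a mask either from a `_PartitionedInfeedQueue` (annotating the images with `xla_sharding.split`) or via `tf.raw_ops.InfeedDequeueTuple` on core 0, then runs the model and returns logits, labels, and the mask.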
def eval_step_fn(params, model):
    """Build `step_fn` for eval."""
    dtypes = [
        tf.bfloat16 if params.use_bfloat16 else tf.float32, tf.float32,
        tf.float32
    ]
    batch_size = params.eval_batch_size // params.num_replicas
    image_size = (params.eval_image_size
                  if 'eval_image_size' in params else params.image_size)
    shapes = [[batch_size, image_size, image_size, 3],
              [batch_size, params.num_classes], [batch_size]]

    if params.use_xla_sharding and params.num_cores_per_replica > 1:
        q = tpu_feed._PartitionedInfeedQueue(
            number_of_tuple_elements=3,
            host_id=0,
            input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                                  [1, 1], [1]],
            device_assignment=params.device_assignment)
        q.set_tuple_types(dtypes)
        q.set_tuple_shapes(shapes)
        images, labels, mask = q.generate_dequeue_op()
        images = xla_sharding.split(images, 2, params.num_cores_per_replica)
    else:
        with tf.device(tf.tpu.core(0)):
            images, labels, mask = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                                 shapes=shapes)

    if len(labels.shape) > 1:  # `labels` is one-hot; convert it to `int32`.
        labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
        labels = tf.expand_dims(labels, axis=-1)
    _ = tf.train.get_or_create_global_step()

    with tf.variable_scope(MODEL_SCOPE):
        logits = model(images, training=False)
        logits = tf.cast(logits, tf.float32)

    return logits, labels, mask
Code Example #6
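An input-generator method that creates per-host TPU infeed queues from the preprocessed batch: it optionally enqueues sparse features through TPU embedding, picks `_PartitionedInfeedQueue` or `InfeedQueue` based on `p.use_partitioned_infeed_queue` and `p.use_per_host_infeed`, groups all enqueue ops, and finally dequeues on core 0 and re-packs the batch.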
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTPUFeeds num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            tf.logging.info('shards {}'.format(shards))

            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            if num_tpu_hosts > 1 and tpu_embedding is not None:
                if not p.use_per_host_infeed:
                    tf.logging.fatal(
                        'TPU Embedding must be used with per_host_infeed with multiple '
                        'TPU host topologies.')
            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    if p.use_partitioned_infeed_queue:
                        device_assignment = py_utils.GetTpuDeviceAssignment()

                        host_device = device_assignment.host_device(
                            replica=0, job=tf.flags.FLAGS.tf_master)
                        host_id = int(
                            host_device.split('/task:')[1].split('/device:')
                            [0])
                        tf.logging.info('host_id: {} host_device: {}'.format(
                            host_id, host_device))
                        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                            number_of_tuple_elements=len(dtypes),
                            device_assignment=device_assignment,
                            host_id=host_id,
                            input_partition_dims=[[p.num_partitions, 1]
                                                  for _ in dtypes],
                            tuple_types=dtypes,
                            tuple_shapes=shapes)
                    else:
                        q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                                 tuple_shapes=shapes)
                        assert shards is not None
                        q.set_number_of_shards(shards)

                    queues.append(q)
                    tf.logging.info('q=%r', q)

                    if p.use_partitioned_infeed_queue:
                        input_ops = q.generate_enqueue_ops([batch.Flatten()])
                    elif p.use_per_host_infeed:
                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        self._tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Code Example #7
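A variant of the previous method that only builds the host-side enqueue ops: it stores the queues in `self._tpu_queues` and duplicates the grouped infeed op `p.tpu_infeed_parallelism` times so multiple threads can drive it.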
    def CreateTpuEnqueueOps(self):
        """Create the host-side enqueue ops.

        This should be called in an outer non-TPU context.
        """
        assert not self._tpu_queues, (
            'CreateTpuEnqueueOps should only be called '
            'once.')
        self._tpu_queues = []
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTpuEnqueueOps num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        shards = (cluster.total_worker_devices //
                  num_infeed_hosts) // cluster.num_devices_per_split
        tf.logging.info('shards {}'.format(shards))

        input_ops_list = []
        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        if num_tpu_hosts > 1 and tpu_embedding is not None:
            if not p.use_per_host_infeed:
                tf.logging.fatal(
                    'TPU Embedding must be used with per_host_infeed with multiple '
                    'TPU host topologies.')

        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                self._batch = self.GetPreprocessedInputBatch()
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                for k, x in self._batch.FlattenItems():
                    assert x.shape.is_fully_defined(), (
                        'Shape must be fully defined: %s: %s' % (k, x))
                    # TODO(cwhipkey): if it's a string (or other type not supported on
                    # TPU), drop it from feeding and on the other end add in an op that
                    # fails if used.
                shapes = self._batch.Transform(lambda x: x.shape).Flatten()
                dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()

                tf.logging.info('host_device: %s infeed shapes: %r',
                                host_device, shapes)
                tf.logging.info('host_device: %s infeed dtypes: %r',
                                host_device, dtypes)

                if p.use_partitioned_infeed_queue:
                    device_assignment = py_utils.GetTpuDeviceAssignment()

                    host_device = device_assignment.host_device(
                        replica=0, job=tf.flags.FLAGS.tf_master)
                    host_id = int(
                        host_device.split('/task:')[1].split('/device:')[0])
                    tf.logging.info('host_id: {} host_device: {}'.format(
                        host_id, host_device))
                    q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                        number_of_tuple_elements=len(dtypes),
                        device_assignment=device_assignment,
                        host_id=host_id,
                        input_partition_dims=[[p.num_partitions] + [1] *
                                              (len(s) - 1) for s in shapes],
                        tuple_types=dtypes,
                        tuple_shapes=shapes)
                else:
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                self._tpu_queues.append(q)

                if p.use_partitioned_infeed_queue:
                    input_ops = q.generate_enqueue_ops([self._batch.Flatten()])
                elif p.use_per_host_infeed:
                    # TODO(ylc/zhifengc): Add this to a policy module and test it.
                    def TPUOrdinalFunction(shard_index_in_host):
                        device_assignment = py_utils.GetTpuDeviceAssignment()
                        if device_assignment:
                            # We put both enqueue/dequeue ops at core 0 in each replica.
                            replica = device_assignment.lookup_replicas(
                                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                            return device_assignment.tpu_ordinal(
                                replica=replica)
                        else:
                            return shard_index_in_host

                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        self._batch.Flatten(),
                        placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                        tpu_ordinal_function=TPUOrdinalFunction)
                else:
                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        self._batch.Flatten(),
                        device_assignment=py_utils.GetTpuDeviceAssignment())
                input_ops_list += input_ops

        tf.logging.info('input_ops_list %s', input_ops_list)
        grouped_infeed_op = tf.group(*input_ops_list)
        self._tpu_infeed_op = []
        for _ in range(p.tpu_infeed_parallelism):
            self._tpu_infeed_op.append(grouped_infeed_op)
Code Example #8
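A teacher-student training `step_fn`: it dequeues labeled and unlabeled (original and augmented) images, optionally through a `_PartitionedInfeedQueue` with `xla_sharding.split`, runs the teacher and student updates, and conditionally enqueues scalar logs with `tf.raw_ops.OutfeedEnqueueTuple`.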
  def step_fn(self, params, model):
    """Separate implementation."""
    train_batch_size = params.train_batch_size
    num_replicas = params.num_replicas
    uda_data = params.uda_data
    batch_size = train_batch_size // num_replicas

    dtypes = [
        tf.bfloat16 if params.use_bfloat16 else tf.float32,
        tf.float32,
        tf.bfloat16 if params.use_bfloat16 else tf.float32,
        tf.bfloat16 if params.use_bfloat16 else tf.float32]
    shapes = [
        [batch_size, params.image_size, params.image_size, 3],
        [batch_size, params.num_classes],
        [batch_size*params.uda_data, params.image_size, params.image_size, 3],
        [batch_size*params.uda_data, params.image_size, params.image_size, 3]]

    if params.use_xla_sharding and params.num_cores_per_replica > 1:
      q = tpu_feed._PartitionedInfeedQueue(
          number_of_tuple_elements=4,
          host_id=0,
          input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                                [1, 1],
                                [1, 1, params.num_cores_per_replica, 1],
                                [1, 1, params.num_cores_per_replica, 1],],
          device_assignment=params.device_assignment)
      q.set_tuple_types(dtypes)
      q.set_tuple_shapes(shapes)
      l_images, l_labels, u_images_ori, u_images_aug = q.generate_dequeue_op()
      l_images = xla_sharding.split(l_images, 2,
                                    params.num_cores_per_replica)
      u_images_ori = xla_sharding.split(u_images_ori, 2,
                                        params.num_cores_per_replica)
      u_images_aug = xla_sharding.split(u_images_aug, 2,
                                        params.num_cores_per_replica)
    else:
      with tf.device(tf.tpu.core(0)):
        (l_images, l_labels, u_images_ori,
         u_images_aug) = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                       shapes=shapes)
    global_step = tf.train.get_or_create_global_step()
    num_replicas = tf.cast(params.num_replicas, tf.float32)

    all_images = tf.concat([l_images, u_images_ori, u_images_aug], axis=0)

    # all calls to teacher
    with tf.variable_scope('teacher', reuse=tf.AUTO_REUSE):
      logits, labels, masks, cross_entropy = UDA.build_uda_cross_entropy(
          params, model, all_images, l_labels)

    # 1st call to student
    with tf.variable_scope(MODEL_SCOPE):
      u_aug_and_l_images = tf.concat([u_images_aug, l_images], axis=0)
      logits['s_on_u_aug_and_l'] = model(u_aug_and_l_images, training=True)
      logits['s_on_u'], logits['s_on_l_old'] = tf.split(
          logits['s_on_u_aug_and_l'],
          [u_images_aug.shape[0].value, l_images.shape[0].value], axis=0)

    # for backprop
    cross_entropy['s_on_u'] = tf.losses.softmax_cross_entropy(
        onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], -1)),
        logits=logits['s_on_u'],
        label_smoothing=params.label_smoothing,
        reduction=tf.losses.Reduction.NONE)
    cross_entropy['s_on_u'] = tf.reduce_sum(cross_entropy['s_on_u']) / float(
        train_batch_size*uda_data)

    # for Taylor
    cross_entropy['s_on_l_old'] = tf.losses.softmax_cross_entropy(
        onehot_labels=labels['l'],
        logits=logits['s_on_l_old'],
        reduction=tf.losses.Reduction.SUM)
    cross_entropy['s_on_l_old'] = tf.tpu.cross_replica_sum(
        cross_entropy['s_on_l_old']) / float(train_batch_size)
    shadow = tf.get_variable(
        name='cross_entropy_old', shape=[], trainable=False, dtype=tf.float32)
    shadow_update = tf.assign(shadow, cross_entropy['s_on_l_old'])

    w_s = {}
    g_s = {}
    g_n = {}
    lr = {}
    optim = {}
    w_s['s'] = [w for w in tf.trainable_variables()
                if w.name.lower().startswith(MODEL_SCOPE)]
    g_s['s_on_u'] = tf.gradients(cross_entropy['s_on_u'], w_s['s'])
    # g_s['s_on_u'] = [tf.tpu.cross_replica_sum(g) for g in g_s['s_on_u']]

    lr['s'] = common_utils.get_learning_rate(
        params,
        initial_lr=params.mpl_student_lr,
        num_warmup_steps=params.mpl_student_lr_warmup_steps,
        num_wait_steps=params.mpl_student_lr_wait_steps)
    lr['s'], optim['s'] = common_utils.get_optimizer(
        params, learning_rate=lr['s'])
    optim['s']._create_slots(w_s['s'])  # pylint: disable=protected-access
    update_ops = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                  if op.name.startswith(f'train/{MODEL_SCOPE}/')]

    with tf.control_dependencies(update_ops + [shadow_update]):
      g_s['s_on_u'] = common_utils.add_weight_decay(
          params, w_s['s'], g_s['s_on_u'])
      g_s['s_on_u'], g_n['s_on_u'] = tf.clip_by_global_norm(
          g_s['s_on_u'], params.grad_bound)
      train_op = optim['s'].apply_gradients(zip(g_s['s_on_u'], w_s['s']))

      with tf.control_dependencies([train_op]):
        ema_train_op = common_utils.setup_ema(
            params, name_scope=f'{MODEL_SCOPE}/{model.name}')

    # 2nd call to student
    with tf.control_dependencies([ema_train_op]):
      with tf.variable_scope(MODEL_SCOPE, reuse=tf.AUTO_REUSE):
        logits['s_on_l_new'] = model(l_images, training=True)

    cross_entropy['s_on_l_new'] = tf.losses.softmax_cross_entropy(
        onehot_labels=labels['l'],
        logits=logits['s_on_l_new'],
        reduction=tf.losses.Reduction.SUM)
    cross_entropy['s_on_l_new'] = tf.tpu.cross_replica_sum(
        cross_entropy['s_on_l_new']) / float(train_batch_size)

    dot_product = cross_entropy['s_on_l_new'] - shadow
    # dot_product = tf.clip_by_value(
    #     dot_product,
    #     clip_value_min=-params.mpl_dot_product_bound,
    #     clip_value_max=params.mpl_dot_product_bound)
    moving_dot_product = tf.get_variable(
        'moving_dot_product', shape=[], trainable=False, dtype=tf.float32)
    moving_dot_product_update = tf.assign_sub(
        moving_dot_product, 0.01 * (moving_dot_product - dot_product))
    with tf.control_dependencies([moving_dot_product_update]):
      dot_product = dot_product - moving_dot_product
      dot_product = tf.stop_gradient(dot_product)
    cross_entropy['mpl'] = tf.losses.softmax_cross_entropy(
        onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], axis=-1)),
        logits=logits['u_aug'],
        reduction=tf.losses.Reduction.NONE)
    cross_entropy['mpl'] = tf.reduce_sum(cross_entropy['mpl']) / float(
        train_batch_size*uda_data)

    # teacher train op
    uda_weight = params.uda_weight * tf.minimum(
        1., tf.cast(global_step, tf.float32) / float(params.uda_steps))
    teacher_loss = (cross_entropy['u'] * uda_weight +
                    cross_entropy['l'] +
                    cross_entropy['mpl'] * dot_product)
    w_s['t'] = [w for w in tf.trainable_variables() if 'teacher' in w.name]
    g_s['t'] = tf.gradients(teacher_loss, w_s['t'])
    g_s['t'] = common_utils.add_weight_decay(params, w_s['t'], g_s['t'])
    g_s['t'], g_n['t'] = tf.clip_by_global_norm(g_s['t'], params.grad_bound)
    lr['t'] = common_utils.get_learning_rate(
        params,
        initial_lr=params.mpl_teacher_lr,
        num_warmup_steps=params.mpl_teacher_lr_warmup_steps)
    lr['t'], optim['t'] = common_utils.get_optimizer(params,
                                                     learning_rate=lr['t'])

    teacher_train_op = optim['t'].apply_gradients(zip(g_s['t'], w_s['t']),
                                                  global_step=global_step)

    with tf.control_dependencies([teacher_train_op]):
      logs = collections.OrderedDict()
      logs['global_step'] = tf.cast(global_step, tf.float32)

      logs['cross_entropy/student_on_u'] = cross_entropy['s_on_u']
      logs['cross_entropy/student_on_l'] = (cross_entropy['s_on_l_new'] /
                                            num_replicas)
      logs['cross_entropy/teacher_on_u'] = cross_entropy['u']
      logs['cross_entropy/teacher_on_l'] = cross_entropy['l']
      logs['lr/student'] = tf.identity(lr['s']) / num_replicas
      logs['lr/teacher'] = tf.identity(lr['t']) / num_replicas
      logs['mpl/dot_product'] = dot_product / num_replicas
      logs['mpl/moving_dot_product'] = moving_dot_product / num_replicas
      logs['uda/u_ratio'] = tf.reduce_mean(masks['u']) / num_replicas
      logs['uda/l_ratio'] = tf.reduce_mean(masks['l']) / num_replicas
      logs['uda/weight'] = uda_weight / num_replicas

      tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
      self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
      def outfeed(tensors):
        with tf.device(tf.tpu.core(params.num_cores_per_replica-1)):
          return tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors)

      outfeed_enqueue_op = tf.cond(
          common_utils.should_log(params), lambda: outfeed(tensors), tf.no_op)

      return outfeed_enqueue_op
Code Example #9
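A UDA training `step_fn` with the same dequeue pattern: it combines labeled and unlabeled cross-entropy with an L2 penalty, applies cross-replica-summed, clipped gradients, updates the EMA, and conditionally outfeeds the logged scalars.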
  def step_fn(self, params, model):
    """Separate implementation."""
    train_batch_size = params.train_batch_size
    num_replicas = params.num_replicas
    batch_size = train_batch_size // num_replicas

    dtypes = [
        tf.bfloat16 if params.use_bfloat16 else tf.float32,
        tf.float32,
        tf.bfloat16 if params.use_bfloat16 else tf.float32,
        tf.bfloat16 if params.use_bfloat16 else tf.float32]
    shapes = [
        [batch_size, params.image_size, params.image_size, 3],
        [batch_size, params.num_classes],
        [batch_size*params.uda_data, params.image_size, params.image_size, 3],
        [batch_size*params.uda_data, params.image_size, params.image_size, 3]]

    if params.use_xla_sharding and params.num_cores_per_replica > 1:
      q = tpu_feed._PartitionedInfeedQueue(
          number_of_tuple_elements=4,
          host_id=0,
          input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                                [1, 1],
                                [1, 1, params.num_cores_per_replica, 1],
                                [1, 1, params.num_cores_per_replica, 1],],
          device_assignment=params.device_assignment)
      q.set_tuple_types(dtypes)
      q.set_tuple_shapes(shapes)
      l_images, l_labels, u_images_ori, u_images_aug = q.generate_dequeue_op()
      l_images = xla_sharding.split(l_images, 2,
                                    params.num_cores_per_replica)
      u_images_ori = xla_sharding.split(u_images_ori, 2,
                                        params.num_cores_per_replica)
      u_images_aug = xla_sharding.split(u_images_aug, 2,
                                        params.num_cores_per_replica)
    else:
      with tf.device(tf.tpu.core(0)):
        (l_images, l_labels, u_images_ori,
         u_images_aug) = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                       shapes=shapes)

    all_images = tf.concat([l_images, u_images_ori, u_images_aug], axis=0)
    global_step = tf.train.get_or_create_global_step()
    num_replicas = tf.cast(params.num_replicas, tf.float32)

    with tf.variable_scope(MODEL_SCOPE, reuse=tf.AUTO_REUSE):
      _, _, masks, cross_entropy = UDA.build_uda_cross_entropy(
          params, model, all_images, l_labels)

    l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas, tf.float32)
    weight_dec = common_utils.get_l2_loss()
    uda_weight = params.uda_weight * tf.minimum(
        1., tf.cast(global_step, tf.float32) / float(params.uda_steps))
    total_loss = (cross_entropy['u'] * uda_weight +
                  cross_entropy['l'] +
                  weight_dec * l2_reg_rate)
    variables = tf.trainable_variables()
    gradients = tf.gradients(total_loss, variables)
    gradients = [tf.tpu.cross_replica_sum(g) for g in gradients]
    gradients, grad_norm = tf.clip_by_global_norm(gradients, params.grad_bound)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    learning_rate, optimizer = common_utils.get_optimizer(params)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.apply_gradients(zip(gradients, variables),
                                           global_step=global_step)

    with tf.control_dependencies([train_op]):
      ema_train_op = common_utils.setup_ema(
          params, f'{MODEL_SCOPE}/{model.name}')

    with tf.control_dependencies([ema_train_op]):
      logs = collections.OrderedDict()
      logs['global_step'] = tf.cast(global_step, tf.float32)
      logs['loss/total'] = total_loss
      logs['loss/cross_entropy'] = cross_entropy['l']
      logs['loss/lr'] = tf.identity(learning_rate) / num_replicas
      logs['loss/grad_norm'] = tf.identity(grad_norm) / num_replicas
      logs['loss/weight_dec'] = weight_dec / num_replicas

      logs['uda/cross_entropy'] = cross_entropy['u']
      logs['uda/u_ratio'] = tf.reduce_mean(masks['u']) / num_replicas
      logs['uda/l_ratio'] = tf.reduce_mean(masks['l']) / num_replicas
      logs['uda/weight'] = uda_weight / num_replicas

      tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
      self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
      outfeed_enqueue_op = tf.cond(
          common_utils.should_log(params),
          lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors), tf.no_op)
    return outfeed_enqueue_op
Code Example #10
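A supervised training `step_fn`: after the partitioned or plain infeed dequeue, it computes label-smoothed cross-entropy, clips the cross-replica-summed gradients, skips the update when the gradient norm is not finite, and outfeeds the logs.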
  def step_fn(self, params, model):
    """A single step for supervised learning."""

    batch_size = params.train_batch_size // params.num_replicas
    dtypes = [tf.bfloat16 if params.use_bfloat16 else tf.float32, tf.float32]
    shapes = [[batch_size, params.image_size, params.image_size, 3],
              [batch_size, params.num_classes]]

    if params.use_xla_sharding and params.num_cores_per_replica > 1:
      q = tpu_feed._PartitionedInfeedQueue(
          number_of_tuple_elements=2,
          host_id=0,
          input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                                [1, 1]],
          device_assignment=params.device_assignment)
      q.set_tuple_types(dtypes)
      q.set_tuple_shapes(shapes)
      images, labels = q.generate_dequeue_op()
      images = xla_sharding.split(images, 2, params.num_cores_per_replica)
    else:
      with tf.device(tf.tpu.core(0)):
        images, labels = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                       shapes=shapes)

    if labels.dtype == tf.int32:
      labels = tf.one_hot(labels, depth=params.num_classes, dtype=tf.float32)
    global_step = tf.train.get_or_create_global_step()

    train_batch_size = tf.cast(params.train_batch_size, tf.float32)
    num_replicas = tf.cast(params.num_replicas, tf.float32)

    with tf.variable_scope(MODEL_SCOPE):
      logits = model(images, training=True)

    if 'noisy_student' in params.dataset_name.lower():
      cross_entropy = labels * tf.nn.log_softmax(logits, axis=-1)
      cross_entropy = tf.reduce_sum(-cross_entropy) / train_batch_size
    else:
      cross_entropy = tf.losses.softmax_cross_entropy(
          onehot_labels=labels, logits=logits,
          label_smoothing=params.label_smoothing,
          reduction=tf.losses.Reduction.SUM) / train_batch_size

    l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas, tf.float32)
    weight_dec = common_utils.get_l2_loss()
    total_loss = cross_entropy + weight_dec * l2_reg_rate

    variables = tf.trainable_variables()
    gradients = tf.gradients(total_loss, variables)
    gradients = [tf.tpu.cross_replica_sum(g) for g in gradients]
    gradients, grad_norm = tf.clip_by_global_norm(gradients, params.grad_bound)

    learning_rate, optimizer = common_utils.get_optimizer(params)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = tf.cond(
        tf.math.is_finite(grad_norm),
        lambda: optimizer.apply_gradients(zip(gradients, variables),
                                          global_step=global_step),
        tf.no_op)
    with tf.control_dependencies(update_ops + [train_op]):
      ema_train_op = common_utils.setup_ema(params,
                                            f'{MODEL_SCOPE}/{model.name}')

    with tf.control_dependencies([ema_train_op]):
      logs = collections.OrderedDict()
      logs['global_step'] = tf.cast(global_step, tf.float32)
      logs['loss/total'] = total_loss
      logs['loss/weight_decay'] = weight_dec / num_replicas
      logs['loss/cross_entropy'] = cross_entropy
      logs['loss/lr'] = tf.identity(learning_rate) / num_replicas
      logs['loss/grad_norm'] = grad_norm / num_replicas

      tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
      self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
      outfeed_enqueue_op = tf.cond(
          common_utils.should_log(params),
          lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors), tf.no_op)
    return outfeed_enqueue_op
Code Example #11
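An `enqueue_ops_fn` for a host-side infeed loop: it pulls `FLAGS.replicas_per_host` batches from the iterator, pads every tensor so each partition dimension divides evenly, derives per-tensor partition dims, enqueues through `_PartitionedInfeedQueue` or a regular `InfeedQueue`, and returns `idx + 1` under a control dependency on the enqueue ops.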
                def enqueue_ops_fn(idx):
                    """Generate the infeed enqueue ops graph."""

                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(FLAGS.replicas_per_host):
                        with tf.control_dependencies(control_deps):
                            self.feature_structure[
                                is_training] = iterator.get_next()
                        self.maybe_capture_embedding_inputs(
                            self.feature_structure[is_training], is_training)
                        flattened_inputs = tf.nest.flatten(
                            self.feature_structure[is_training])
                        control_deps.extend(flattened_inputs)
                        if input_partition_dims:
                            padded_inputs = []
                            for inp in flattened_inputs:
                                if inp.shape.ndims < len(input_partition_dims):
                                    padded_inputs.append(inp)
                                    continue
                                paddings = []
                                for i, j in enumerate(input_partition_dims):
                                    r = inp.shape.as_list()[i] % j
                                    if r > 0:
                                        paddings.append([0, j - r])
                                    else:
                                        paddings.append([0, 0])
                                for i in range(inp.shape.ndims -
                                               len(input_partition_dims)):
                                    paddings.append([0, 0])
                                padded_inputs.append(tf.pad(inp, paddings))
                            per_host_sharded_inputs.append(padded_inputs)
                        else:
                            per_host_sharded_inputs.append(flattened_inputs)

                    if input_partition_dims:
                        flattened_input_dims = []
                        for i in per_host_sharded_inputs[0]:
                            if i.shape.ndims == len(input_partition_dims):
                                flattened_input_dims.append(
                                    input_partition_dims)
                            elif i.shape.ndims > len(input_partition_dims):
                                flattened_input_dims.append(
                                    input_partition_dims + [1] *
                                    (i.shape.ndims - len(input_partition_dims))
                                )
                            else:
                                flattened_input_dims.append([1] *
                                                            i.shape.ndims)
                        # pylint: disable=protected-access
                        self.infeed_op[
                            is_training] = tpu_feed._PartitionedInfeedQueue(
                                number_of_tuple_elements=len(
                                    per_host_sharded_inputs[0]),
                                host_id=host_id,
                                input_partition_dims=flattened_input_dims,
                                device_assignment=self.device_assignment)
                        with tf.control_dependencies(
                                self.infeed_op[is_training].
                                generate_enqueue_ops(per_host_sharded_inputs)):
                            return idx + 1
                    else:
                        self.infeed_op[is_training] = tpu_feed.InfeedQueue(
                            number_of_tuple_elements=len(
                                per_host_sharded_inputs[0]))
                        per_host_enqueue_ops = (
                            self.infeed_op[is_training].generate_enqueue_ops(
                                per_host_sharded_inputs,
                                tpu_ordinal_function=_tpu_ordinal_fn))

                    self.maybe_add_embedding_enqueue_ops_int(
                        is_training, per_host_enqueue_ops)
                    with tf.control_dependencies(per_host_enqueue_ops):
                        return idx + 1
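
A shared pattern runs through the examples above: build a device assignment, construct a `tpu_feed._PartitionedInfeedQueue` with per-tensor `input_partition_dims`, generate the host-side enqueue ops, and call `generate_dequeue_op()` (plus `xla_sharding.split`) inside the replicated TPU computation. The sketch below condenses that pattern. It is a hypothetical TF1 graph-mode illustration, not code from any of the projects quoted above; the two-core `computation_shape`, the `_host_batch` helper, the tensor sizes, and the `build_graph` wrapper are assumptions made only for this sketch.

import tensorflow.compat.v1 as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import topology as topology_lib
from tensorflow.python.tpu import tpu_feed

NUM_CORES_PER_REPLICA = 2  # assumption: one replica spread over two cores
BATCH, IMAGE, CLASSES = 8, 224, 1000  # illustrative sizes only


def _host_batch():
  """Hypothetical host-side inputs; a real pipeline would use tf.data."""
  images = tf.zeros([BATCH, IMAGE, IMAGE, 3], tf.float32)
  labels = tf.zeros([BATCH, CLASSES], tf.float32)
  return [images, labels]


def build_graph(serialized_topology):
  """Builds enqueue ops and a replicated TPU step from an initialized topology."""
  topology = topology_lib.Topology(serialized=serialized_topology)
  dev_assign = device_assignment_lib.device_assignment(
      topology,
      computation_shape=[1, 1, 1, NUM_CORES_PER_REPLICA],
      num_replicas=1)

  dtypes = [tf.float32, tf.float32]
  shapes = [[BATCH, IMAGE, IMAGE, 3], [BATCH, CLASSES]]
  q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
      number_of_tuple_elements=2,
      host_id=0,
      # Split the images along the height dimension; replicate the labels.
      input_partition_dims=[[1, 1, NUM_CORES_PER_REPLICA, 1], [1, 1]],
      device_assignment=dev_assign,
      tuple_types=dtypes,
      tuple_shapes=shapes)

  # Host side: one inner list per replica on this host; the queue slices each
  # tensor according to input_partition_dims before enqueueing per core.
  with tf.device(dev_assign.host_device(replica=0)):
    enqueue_ops = q.generate_enqueue_ops([_host_batch()])

  def tpu_step():
    # TPU side: dequeue the tuple and annotate the split so XLA keeps the
    # images partitioned across the replica's cores.
    images, labels = q.generate_dequeue_op()
    images = xla_sharding.split(images, 2, NUM_CORES_PER_REPLICA)
    # Stand-in for the model; a real step would build and call it here.
    return tf.reduce_mean(images) + tf.reduce_mean(labels)

  outputs = tf.tpu.replicate(tpu_step, device_assignment=dev_assign)
  return enqueue_ops, outputs

# Usage (assumes a session connected to a TPU worker): run
# tf.tpu.initialize_system() once to obtain the serialized topology, call
# build_graph with it, then run the enqueue ops together with the outputs.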