def create_test_iterator(hparams, mode):
  """Create test iterator."""
  src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant([hparams.eos, "a", "b", "c", "d"]))
  tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"])
  tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping)
  if mode == tf.contrib.learn.ModeKeys.INFER:
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor(
        tgt_vocab_mapping)

  src_dataset = tf.contrib.data.Dataset.from_tensor_slices(
      tf.constant(["a a b b c", "a b b"]))

  if mode != tf.contrib.learn.ModeKeys.INFER:
    tgt_dataset = tf.contrib.data.Dataset.from_tensor_slices(
        tf.constant(["a b c b c", "a b c b"]))
    return (
        iterator_utils.get_iterator(
            src_dataset=src_dataset,
            tgt_dataset=tgt_dataset,
            src_vocab_table=src_vocab_table,
            tgt_vocab_table=tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets),
        src_vocab_table, tgt_vocab_table)
  else:
    return (
        iterator_utils.get_infer_iterator(
            src_dataset=src_dataset,
            src_vocab_table=src_vocab_table,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            batch_size=hparams.batch_size),
        src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table)
def create_tables(self):
  self.src_table = lookup_ops.index_table_from_tensor(
      tf.constant(self.src_vocab), default_value=Dataset.UNK)
  if self.config.share_vocab:
    self.tgt_table = self.src_table
  else:
    self.tgt_table = lookup_ops.index_table_from_tensor(
        tf.constant(self.tgt_vocab), default_value=Dataset.UNK)
def create_tables(self):
  with open(os.path.join(self.config.config_dir, 'src_vocab')) as f:
    self.src_vocab = [l.rstrip('\n') for l in f]
  self.src_table = lookup_ops.index_table_from_tensor(
      tf.constant(self.src_vocab), default_value=Dataset.UNK)
  self.src_vocab_size = len(self.src_vocab)

  with open(os.path.join(self.config.config_dir, 'tgt_vocab')) as f:
    self.tgt_vocab = [l.rstrip('\n') for l in f]
  self.tgt_table = lookup_ops.index_table_from_tensor(
      tf.constant(self.tgt_vocab), default_value=Dataset.UNK)
  self.tgt_vocab_size = len(self.tgt_vocab)
def create_test_iterator(hparams, mode, trie_excludes=None):
  """Create test iterator."""
  src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant([hparams.eos, "a", "b", "c", "d"]))
  tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"])
  tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping)
  reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor(
      tgt_vocab_mapping)

  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["a a b b c", "a b b"]))
  ctx_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c b c b a", "b c b a"]))

  trie_excludes = trie_excludes or []
  trie_excludes = " {} ".format(hparams.eos).join(trie_excludes)
  tex_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant([trie_excludes, trie_excludes]))

  if mode != tf.contrib.learn.ModeKeys.INFER:
    tgt_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["a b c b c", "a b c b"]))
    return (
        iterator_utils.get_iterator(
            hparams=hparams,
            src_dataset=src_dataset,
            tgt_dataset=tgt_dataset,
            ctx_dataset=ctx_dataset,
            annot_dataset=None,
            src_vocab_table=src_vocab_table,
            tgt_vocab_table=tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets),
        src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table)
  else:
    return (
        iterator_utils.get_infer_iterator(
            hparams=hparams,
            src_dataset=src_dataset,
            ctx_dataset=ctx_dataset,
            annot_dataset=None,
            trie_exclude_dataset=tex_dataset,
            src_vocab_table=src_vocab_table,
            tgt_vocab_table=tgt_vocab_table,
            eos=hparams.eos,
            batch_size=hparams.batch_size),
        src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table)
def test_index_table_from_tensor_with_invalid_hashers(self):
  with self.test_session():
    with self.assertRaises(TypeError):
      lookup_ops.index_table_from_tensor(
          vocabulary_list=["brain", "salad", "surgery"],
          num_oov_buckets=1,
          hasher_spec=1)

    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=["brain", "salad", "surgery"],
        num_oov_buckets=1,
        hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))

    self.assertRaises(ValueError, table.lookup,
                      constant_op.constant(["salad", "surgery", "tarkus"]))
def load_vocab_table(vocab_bpe_file):
  assert os.path.exists(vocab_bpe_file)
  with open(vocab_bpe_file, 'r', encoding='utf8') as f:
    vocab_list = f.readlines()
  vocab_list = [word.strip() for word in vocab_list]
  return (lookup_ops.index_table_from_tensor(vocab_list, default_value=0),
          len(vocab_list))
def create_loss(self, features, mode, logits, labels): """See `Head`.""" del mode # Unused for this head. logits = ops.convert_to_tensor(logits) labels = _check_dense_labels_match_logits_and_reshape( labels=labels, logits=logits, expected_labels_dimension=1) if self._label_vocabulary is not None: labels = lookup_ops.index_table_from_tensor( vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup').lookup(labels) labels = math_ops.to_float(labels) labels = _assert_range(labels, 2) unweighted_loss = nn.sigmoid_cross_entropy_with_logits( labels=labels, logits=logits) weights = _get_weights_and_check_match_logits( features=features, weight_column=self._weight_column, logits=logits) weighted_sum_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) # _weights() can return 1. example_weight_sum = math_ops.reduce_sum( weights * array_ops.ones_like(unweighted_loss)) return LossSpec( weighted_sum_loss=weighted_sum_loss, example_weight_sum=example_weight_sum, processed_labels=labels)
def _process_labels(self, labels):
  if labels is None:
    raise ValueError(
        'You must provide a labels Tensor. Given: None. '
        'Suggested troubleshooting steps: Check that your data contain '
        'your label feature. Check that your input_fn properly parses and '
        'returns labels.')
  if isinstance(labels, sparse_tensor.SparseTensor):
    if labels.dtype == dtypes.string:
      label_ids_values = lookup_ops.index_table_from_tensor(
          vocabulary_list=tuple(self._label_vocabulary),
          name='class_id_lookup').lookup(labels.values)
      label_ids = sparse_tensor.SparseTensor(
          indices=labels.indices,
          values=label_ids_values,
          dense_shape=labels.dense_shape)
    else:
      label_ids = labels
    return math_ops.to_int64(
        sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
  msg = ('labels shape must be [batch_size, {}]. '
         'Given: ').format(self._n_classes)
  labels_shape = array_ops.shape(labels)
  check_rank_op = control_flow_ops.Assert(
      math_ops.equal(array_ops.rank(labels), 2),
      data=[msg, labels_shape])
  check_label_dim = control_flow_ops.Assert(
      math_ops.equal(labels_shape[-1], self._n_classes),
      data=[msg, labels_shape])
  with ops.control_dependencies([check_rank_op, check_label_dim]):
    return array_ops.identity(labels)
def build_vocab_table(text_file, hparams, vocab_file):
  vocab_list = build_vocab_bpe(text_file, hparams, vocab_file)
  vocab_size = len(vocab_list)
  print('{} vocab size={}'.format(text_file, vocab_size))
  vocabulary_list = tf.constant(vocab_list, dtype=tf.string)
  return (lookup_ops.index_table_from_tensor(vocabulary_list, default_value=0),
          vocab_size)
def __init__(self, alphabet: str):
  self.alphabet = alphabet
  chars = tf.constant([x.encode() for x in self.alphabet])
  # chars = tf.constant(tf.strings.unicode_split(alphabet, input_encoding='UTF-8'))
  self._encode_table = lookup_ops.index_table_from_tensor(chars)
  self._decode_table = lookup_ops.index_to_string_table_from_tensor(chars)
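# For context, a minimal round-trip sketch of the paired encode/decode tables
# built in the constructor above. The stand-alone wiring below (the imports,
# the tiny "abc" alphabet, and the TF 1.x session usage matching the other
# snippets) is an assumption for illustration, not part of the original class.
import tensorflow as tf
from tensorflow.python.ops import lookup_ops

chars = tf.constant([x.encode() for x in "abc"])
encode_table = lookup_ops.index_table_from_tensor(chars)
decode_table = lookup_ops.index_to_string_table_from_tensor(chars)

ids = encode_table.lookup(tf.constant(["a", "c", "b"]))
decoded = decode_table.lookup(ids)
with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  print(sess.run(ids))      # [0 2 1]
  print(sess.run(decoded))  # [b'a' b'c' b'b']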
def infer_iter():
  file_path = os.path.abspath("test_files/en2end_iterator.txt")
  dataset = tf.contrib.data.TextLineDataset(file_path)
  eou = '</u>'
  eos = '</s>'
  src_reverse = False
  batch_size = 1
  utt_max_len = 20
  dialogue_max_len = 20
  vocab_table = lookup_ops.index_table_from_tensor(tf.constant([""]))
  # Note: this in-memory dataset replaces the TextLineDataset created above.
  dataset = tf.contrib.data.Dataset.from_tensor_slices(
      tf.constant(["a b c </u> a a b </u>", "c a b c a </u> c b c a a </u>"]))
  iterator = end2end_iterator_utils.get_infer_iterator(
      dataset, vocab_table, batch_size, src_reverse, eos, eou,
      utt_max_len, dialogue_max_len)
  with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    sess.run(iterator.initializer)
    for i in range(2):
      source, lengths = sess.run(
          [iterator.source, iterator.source_sequence_length])
      print(source)
      print(lengths)
def testGetInferIterator(self): src_vocab_table = lookup_ops.index_table_from_tensor( tf.constant(["a", "b", "c", "eos", "sos"])) src_dataset = tf.data.Dataset.from_tensor_slices( tf.constant(["c c a", "c a", "d", "f e a g"])) hparams = tf.contrib.training.HParams( random_seed=3, eos="eos", sos="sos") batch_size = 2 dataset = iterator_utils.get_infer_iterator( src_dataset=src_dataset, src_vocab_table=src_vocab_table, batch_size=batch_size, eos=hparams.eos) table_initializer = tf.tables_initializer() iterator = dataset.make_initializable_iterator() get_next = iterator.get_next() with self.test_session() as sess: sess.run(table_initializer) sess.run(iterator.initializer) features = sess.run(get_next) self.assertAllEqual( [ [2, 2, 0], # c c a [2, 0, 3] ], # c a eos features["source"]) self.assertAllEqual([3, 2], features["source_sequence_length"])
def test_table_roundtrip(self):
  export_path = os.path.join(tempfile.mkdtemp(), 'export')

  with tf.Graph().as_default():
    with tf.Session().as_default() as session:
      input_string = tf.placeholder(tf.string)
      # Map string through a table, in this case based on a constant tensor.
      table = lookup_ops.index_table_from_tensor(
          tf.constant(['cat', 'dog', 'giraffe']))
      output = table.lookup(input_string)
      inputs = {'input': input_string}
      outputs = {'output': output}
      saved_transform_io.write_saved_transform_from_session(
          session, inputs, outputs, export_path)

  with tf.Graph().as_default():
    with tf.Session().as_default() as session:
      # Using a computed input gives confidence that the graphs are fused.
      input_string = tf.constant('dog')
      inputs = {'input': input_string}
      _, outputs = (
          saved_transform_io.partially_apply_saved_transform_internal(
              export_path, inputs))
      session.run(tf.tables_initializer())
      result = session.run(outputs['output'])
      self.assertEqual(1, result)
def testDecodeExampleWithBranchedLookup(self):
  example = example_pb2.Example(features=feature_pb2.Features(feature={
      'image/object/class/text': self._BytesFeatureFromList(
          np.array(['cat', 'dog', 'guinea pig'])),
  }))
  serialized_example = example.SerializeToString()
  # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
  table = lookup_ops.index_table_from_tensor(
      constant_op.constant(['dog', 'guinea pig', 'cat']))

  with self.test_session() as sess:
    sess.run(lookup_ops.tables_initializer())

    serialized_example = array_ops.reshape(serialized_example, shape=[])
    keys_to_features = {
        'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string),
    }
    items_to_handlers = {
        'labels':
            tf_example_decoder.LookupTensor('image/object/class/text', table),
    }
    decoder = slim_example_decoder.TFExampleDecoder(keys_to_features,
                                                    items_to_handlers)
    obtained_class_ids = decoder.decode(serialized_example)[0].eval()

  self.assertAllClose([2, 0, 1], obtained_class_ids)
def __init__(self,
             n_classes,
             weight_column=None,
             label_vocabulary=None,
             loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
             loss_fn=None,
             name=None):
  if (n_classes is None) or (n_classes <= 2):
    raise ValueError('n_classes must be > 2: {}.'.format(n_classes))
  if label_vocabulary is not None and not isinstance(label_vocabulary,
                                                     (list, tuple)):
    raise ValueError(
        'label_vocabulary should be a list or a tuple. Given type: {}'.format(
            type(label_vocabulary)))
  if (loss_reduction not in losses.Reduction.all() or
      loss_reduction == losses.Reduction.NONE):
    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
  if loss_fn:
    base_head.validate_loss_fn_args(loss_fn)
  self._n_classes = n_classes
  self._weight_column = weight_column
  self._label_vocabulary = label_vocabulary
  if label_vocabulary:
    self._class_string_table = lookup_ops.index_to_string_table_from_tensor(
        vocabulary_list=self._label_vocabulary, name='class_string_lookup')
    self._class_id_table = lookup_ops.index_table_from_tensor(
        vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup')
  self._loss_reduction = loss_reduction
  self._loss_fn = loss_fn
  self._name = name
  # Metric keys.
  keys = metric_keys.MetricKeys
  self._loss_mean_key = self._summary_key(keys.LOSS_MEAN)
  self._accuracy_key = self._summary_key(keys.ACCURACY)
  self._loss_regularization_key = self._summary_key(keys.LOSS_REGULARIZATION)
def _process_labels(self, labels):
  if isinstance(labels, sparse_tensor.SparseTensor):
    if labels.dtype == dtypes.string:
      label_ids_values = lookup_ops.index_table_from_tensor(
          vocabulary_list=tuple(self._label_vocabulary),
          name='class_id_lookup').lookup(labels.values)
      label_ids = sparse_tensor.SparseTensor(
          indices=labels.indices,
          values=label_ids_values,
          dense_shape=labels.dense_shape)
    else:
      label_ids = labels
    return math_ops.to_int64(
        sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
  msg = ('labels shape must be [batch_size, {}]. '
         'Given: ').format(self._n_classes)
  labels_shape = array_ops.shape(labels)
  check_rank_op = control_flow_ops.Assert(
      math_ops.equal(array_ops.rank(labels), 2),
      data=[msg, labels_shape])
  check_label_dim = control_flow_ops.Assert(
      math_ops.equal(labels_shape[-1], self._n_classes),
      data=[msg, labels_shape])
  with ops.control_dependencies([check_rank_op, check_label_dim]):
    return array_ops.identity(labels)
def get_vocab_table(vocab_file, reverse=False):
  vocabs = [PAD, UNK, SOS, EOS]
  counts = [0, 0, 0, 0]
  with open(vocab_file, 'r', encoding='utf-8') as f:
    for line in f:
      vocab, count = line.strip().split('\t')
      vocabs.append(vocab)
      counts.append(int(count))
  # The probability of EOS should be 1. / (average target token num + 1);
  # the average is currently 2.75.
  sum_counts = sum(counts)
  eos_prob = 1. / (2.75 + 1)
  # Scale the remaining mass (1 - eos_prob) across the counted vocab. Since
  # counts[3] (EOS) is zero, overwriting it below keeps the total at 1.
  vocab_probs = (np.array(counts, dtype=np.float32) * (1 - eos_prob)
                 / sum_counts)
  vocab_probs[3] = eos_prob
  if not reverse:
    vocab_table = lookup_ops.index_table_from_tensor(vocabs, default_value=1)
  else:
    vocab_table = lookup_ops.index_to_string_table_from_tensor(
        vocabs, default_value=UNK)
  return vocab_table, vocab_probs
def call(self, inputs):
  table = lookup_ops.index_table_from_tensor(
      vocabulary_list=self.vocabulary_list,
      num_oov_buckets=1,
      default_value=-1,
  )
  return tf.cast(table.lookup(inputs), tf.int64)
def build_lookup_table(self):
  tensor = tf.constant(self.idx_to_token, dtype=tf.string)
  self.token_to_idx_table = index_table_from_tensor(
      tensor,
      num_oov_buckets=1 if self.unk_idx is None else 0,
      default_value=-1 if self.unk_idx is None else self.unk_idx)
  self.idx_to_token_table = index_to_string_table_from_tensor(
      self.idx_to_token, self.safe_unk_token)
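# The two branches above differ in how out-of-vocabulary tokens are handled:
# with num_oov_buckets=1 every OOV token hashes into the single extra id
# len(vocab), while with a known unk_idx OOV tokens fall back to
# default_value. A minimal sketch of that difference, using stand-alone
# tables (the three-word vocabulary and TF 1.x session are assumptions for
# illustration, not part of the original class):
import tensorflow as tf
from tensorflow.python.ops.lookup_ops import index_table_from_tensor

vocab = tf.constant(["a", "b", "c"])
bucket_table = index_table_from_tensor(vocab, num_oov_buckets=1)  # OOV -> 3
unk_table = index_table_from_tensor(vocab, default_value=1)       # OOV -> 1

tokens = tf.constant(["b", "zzz"])
with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  print(sess.run(bucket_table.lookup(tokens)))  # [1 3]
  print(sess.run(unk_table.lookup(tokens)))     # [1 1]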
def testGetIteratorWithShard(self):
  tf.set_random_seed(1)
  tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c c a", "f e a g", "d", "c a"]))
  tgt_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["a b", "c c", "", "b c"]))
  hparams = tf.contrib.training.HParams(
      random_seed=3,
      num_buckets=5,
      eos="eos",
      sos="sos")
  batch_size = 2
  src_max_len = 3
  dataset = iterator_utils.get_iterator(
      src_dataset=src_dataset,
      tgt_dataset=tgt_dataset,
      src_vocab_table=src_vocab_table,
      tgt_vocab_table=tgt_vocab_table,
      batch_size=batch_size,
      sos=hparams.sos,
      eos=hparams.eos,
      random_seed=hparams.random_seed,
      num_buckets=hparams.num_buckets,
      src_max_len=src_max_len,
      num_shards=2,
      shard_index=1,
      reshuffle_each_iteration=False)
  table_initializer = tf.tables_initializer()
  iterator = dataset.make_initializable_iterator()
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer)
    features = sess.run(get_next)
    self.assertAllEqual(
        [[-1, -1, 0],  # "f" == unknown, "e" == unknown, a
         [2, 0, 3]],   # c a eos -- eos is padding
        features["source"])
    self.assertAllEqual([3, 2], features["source_sequence_length"])
    self.assertAllEqual(
        [[4, 2, 2],   # sos c c
         [4, 1, 2]],  # sos b c
        features["target_input"])
    self.assertAllEqual(
        [[2, 2, 3],   # c c eos
         [1, 2, 3]],  # b c eos
        features["target_output"])
    self.assertAllEqual([3, 3], features["target_sequence_length"])

    with self.assertRaisesOpError("End of sequence"):
      sess.run(get_next)
def test_index_table_from_tensor_with_tensor_init(self):
  with self.test_session():
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
    ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))

    self.assertRaises(errors_impl.OpError, ids.eval)
    lookup_ops.tables_initializer().run()
    self.assertAllEqual((1, 2, 3), ids.eval())
def test_index_table_from_tensor_empty_vocabulary_list(self):
  with self.test_session():
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1)
    ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))

    self.assertRaises(errors_impl.OpError, ids.eval)
    with self.assertRaisesRegexp(
        errors_impl.OpError, "keys and values cannot be empty"):
      lookup_ops.tables_initializer().run()
def test_int64_index_table_from_tensor_with_tensor_init(self):
  with self.test_session():
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
    ids = table.lookup(
        constant_op.constant((1, -1000, 11), dtype=dtypes.int64))

    self.assertRaises(errors_impl.OpError, ids.eval)
    lookup_ops.tables_initializer().run()
    self.assertAllEqual((1, 2, 3), ids.eval())
def test_index_table_from_tensor_with_default_value(self):
  default_value = -42
  with self.test_session():
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=["brain", "salad", "surgery"],
        default_value=default_value)
    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

    self.assertRaises(errors_impl.OpError, ids.eval)
    lookup_ops.tables_initializer().run()
    self.assertAllEqual((1, 2, default_value), ids.eval())
def get_label_ids(labels, label_vocabulary):
  if label_vocabulary is None:
    if not labels.dtype.is_integer:
      raise ValueError(
          "Labels dtype should be integer. Instead got {}.".format(
              labels.dtype))
    label_ids = labels
  else:
    label_ids = lookup_ops.index_table_from_tensor(
        vocabulary_list=tuple(label_vocabulary),
        name="class_id_lookup").lookup(labels)
  return label_ids
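# A short usage sketch for a helper like get_label_ids with a string
# vocabulary; the concrete label values and the TF 1.x session wiring are
# assumptions for illustration, matching the style of the other snippets.
import tensorflow as tf
from tensorflow.python.ops import lookup_ops

labels = tf.constant([["dog"], ["cat"], ["bird"]])
label_ids = lookup_ops.index_table_from_tensor(
    vocabulary_list=("cat", "dog", "bird"),
    name="class_id_lookup").lookup(labels)

with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  print(sess.run(label_ids))  # [[1] [0] [2]]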
def model_fn(features, labels, mode):
  model = build_resnet50(kwargs['image_key'],
                         image_height=kwargs['image_height'],
                         image_width=kwargs['image_width'],
                         number_of_classes=len(vocabulary),
                         weights=kwargs['pretrained'])
  logits = model(features, training=False)
  # Predicted class id per example. The original tf.cast(tf.sort(
  # tf.argsort(logits, axis=1)), tf.int64) only produced the row indices
  # 0..n-1, so argmax is used instead.
  class_ids = tf.math.argmax(logits, axis=1)
  classes = lookup_ops.index_to_string_table_from_tensor(
      vocabulary_list=tuple(vocabulary),
      name='class_string_lookup').lookup(class_ids)
  predictions = {
      'probabilities': tf.nn.softmax(logits),
      'class_id': classes,
      'logits': logits,
  }
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  label_ids = lookup_ops.index_table_from_tensor(
      vocabulary_list=tuple(vocabulary),
      name='class_id_lookup').lookup(labels)
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=.001)
  # label_ids are sparse class ids, not one-hot vectors, so the sparse
  # variant of the loss is the correct choice here.
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)(label_ids, logits)
  loss = tf.reduce_sum(loss) * (1. / 4)
  mean = tf.compat.v1.metrics.mean(loss)
  accuracy = tf.compat.v1.metrics.accuracy(
      labels=label_ids, predictions=class_ids)
  tf.summary.scalar('accuracy', accuracy[1])
  eval_metric_ops = {"accuracy": accuracy, "mean": mean}
  train_op = optimizer.minimize(
      loss, tf.compat.v1.train.get_or_create_global_step())
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      eval_metric_ops=eval_metric_ops)
  return tf.estimator.EstimatorSpec(mode=mode,
                                    loss=loss,
                                    train_op=train_op,
                                    predictions=predictions,
                                    eval_metric_ops=eval_metric_ops)
def gen_vocab_tables(hparams):
  # vocab to word_id mapping
  # if hparams.use_pretrained_embedding:
  #   vocab_file = hparams.embedding_dir
  #   tf.logging.info("loading pretrained embedding from %s" % vocab_file)
  # else:
  vocab_file = hparams.vocab_file
  tf.logging.info("loading vocab from %s" % vocab_file)
  tf.logging.info("%d vocab loaded" % hparams.vocab_size)
  vocab_mapping_strings = tf.constant(hparams.vocab)
  vocab_table = lookup_ops.index_table_from_tensor(
      vocab_mapping_strings, default_value=UNK_id)
  return [vocab_table, vocab_table]
def testGetIterator(self):
  tf.set_random_seed(1)
  tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["f e a g", "c c a", "d", "c a"]))
  tgt_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c c", "a b", "", "b c"]))
  hparams = tf.contrib.training.HParams(
      random_seed=3,
      num_buckets=1,
      eos="eos",
      sos="sos")
  batch_size = 2
  src_max_len = 5
  dataset = iterator_utils.get_iterator(
      src_dataset=src_dataset,
      tgt_dataset=tgt_dataset,
      src_vocab_table=src_vocab_table,
      tgt_vocab_table=tgt_vocab_table,
      batch_size=batch_size,
      global_batch_size=batch_size,
      sos=hparams.sos,
      eos=hparams.eos,
      random_seed=hparams.random_seed,
      num_buckets=hparams.num_buckets,
      src_max_len=src_max_len,
      reshuffle_each_iteration=False)
  table_initializer = tf.tables_initializer()
  iterator = dataset.make_initializable_iterator()
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer)
    features = sess.run(get_next)
    self.assertAllEqual(
        [[4, 2, 0, 3, 3],   # c a eos -- eos is padding
         [4, 2, 2, 0, 3]],  # c c a
        features["source"])
    self.assertAllEqual([4, 5], features["source_sequence_length"])
    self.assertAllEqual(
        [[4, 1, 2],   # sos b c
         [4, 0, 1]],  # sos a b
        features["target_input"])
    self.assertAllEqual(
        [[1, 2, 3],   # b c eos
         [0, 1, 3]],  # a b eos
        features["target_output"])
    self.assertAllEqual([3, 3], features["target_sequence_length"])
def build(self, input_shape):
  # categorical with vocabulary list.
  if isinstance(self.vocabulary, (tuple, list, np.ndarray)):
    self.table = lookup_ops.index_table_from_tensor(
        vocabulary_list=self.vocabulary,
        num_oov_buckets=self.num_oov_tokens,
        dtype=self._input_dtype)
  # categorical with vocabulary file.
  elif self.vocabulary:
    self.table = lookup_ops.index_table_from_file(
        vocabulary_file=self.vocabulary,
        num_oov_buckets=self.num_oov_tokens,
        key_dtype=self._input_dtype)
def testDecodeExampleWithBranchedBackupHandler(self):
  example1 = example_pb2.Example(
      features=feature_pb2.Features(
          feature={
              'image/object/class/text':
                  self._BytesFeatureFromList(
                      np.array(['cat', 'dog', 'guinea pig'])),
              'image/object/class/label':
                  self._Int64FeatureFromList(np.array([42, 10, 900]))
          }))
  example2 = example_pb2.Example(
      features=feature_pb2.Features(
          feature={
              'image/object/class/text':
                  self._BytesFeatureFromList(
                      np.array(['cat', 'dog', 'guinea pig'])),
          }))
  example3 = example_pb2.Example(
      features=feature_pb2.Features(
          feature={
              'image/object/class/label':
                  self._Int64FeatureFromList(np.array([42, 10, 901]))
          }))
  # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
  table = lookup_ops.index_table_from_tensor(
      constant_op.constant(['dog', 'guinea pig', 'cat']))
  keys_to_features = {
      'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string),
      'image/object/class/label': parsing_ops.VarLenFeature(dtypes.int64),
  }
  backup_handler = tf_example_decoder.BackupHandler(
      handler=slim_example_decoder.Tensor('image/object/class/label'),
      backup=tf_example_decoder.LookupTensor('image/object/class/text',
                                             table))
  items_to_handlers = {
      'labels': backup_handler,
  }
  decoder = slim_example_decoder.TFExampleDecoder(keys_to_features,
                                                  items_to_handlers)
  obtained_class_ids_each_example = []
  with self.test_session() as sess:
    sess.run(lookup_ops.tables_initializer())
    for example in [example1, example2, example3]:
      serialized_example = array_ops.reshape(
          example.SerializeToString(), shape=[])
      obtained_class_ids_each_example.append(
          decoder.decode(serialized_example)[0].eval())

  self.assertAllClose([42, 10, 900], obtained_class_ids_each_example[0])
  self.assertAllClose([2, 0, 1], obtained_class_ids_each_example[1])
  self.assertAllClose([42, 10, 901], obtained_class_ids_each_example[2])
def build(self, _):
  self.word_embedding = self.add_variable(
      name="word_embedding",
      shape=[1 + self.voc_size + self.num_oov_buckets,  # 1 for padding
             self.embedding_size],
      initializer=word_embedding_initializer(
          self.embedding_file, include_word=False)
      if self.embedding_file is not None
      else tf.random_uniform_initializer(-1, 1),
      regularizer=tf.nn.l2_loss)
  self.feature_lookup_table = index_table_from_file(
      vocabulary_file=self.voc_file,
      num_oov_buckets=self.num_oov_buckets,
      vocab_size=self.voc_size,
      default_value=-1,
      key_dtype=tf.string,
      name='feature_index_lookup')
  if self.use_char_embedding:
    char_list = [
        "0", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l",
        "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y",
        "z", ".", "'"
    ]
    char_list = tf.constant(char_list, dtype=tf.string)
    # 100 OOV hash buckets; default_value must be -1 when num_oov_buckets > 0.
    self.char_lookup_table = index_table_from_tensor(
        char_list, 100, default_value=-1)
    self.char_embeddding = self.add_variable(
        "char_embedding", [27 + 2 + 100, self.char_embeddding_size],
        initializer=tf.random_uniform_initializer(-1, 1),
        regularizer=tf.nn.l2_loss)
    self.charConv = tf.layers.Conv1D(
        filters=self.char_filters,
        kernel_size=self.char_conv_kernel_size,
        activation=tf.nn.relu,
        use_bias=True,
        bias_initializer=tf.zeros_initializer(),
        trainable=True,
        kernel_regularizer=tf.nn.l2_loss)
  if self.use_highway:
    self.highway = HighwayLayer(
        self.embedding_size + self.char_filters
        if self.use_char_embedding else self.embedding_size,
        layers_number=2,
        is_trainning=self.trainable,
        dropout_rate=self.dropout_rate)
  self.built = True
def create_loss(self, features, mode, logits, labels): """See `Head`.""" del mode, features # Unused for this head. labels = _check_and_reshape_dense_labels(labels, self.logits_dimension) if self._label_vocabulary is not None: labels = lookup_ops.index_table_from_tensor( vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup').lookup(labels) labels = math_ops.to_float(labels) labels = _assert_range(labels, 2) return LossAndLabels( unweighted_loss=nn.sigmoid_cross_entropy_with_logits( labels=labels, logits=logits), processed_labels=labels)
def _class_id_table(self): """Creates a lookup table for class_id. In eager execution, this lookup table will be lazily created on the first call of `self._class_id_table`, and cached for later use; In graph execution, it will be created on demand. Returns: A hash table for lookup. """ if self._cached_class_id_table is None or not tf.executing_eagerly(): self._cached_class_id_table = lookup_ops.index_table_from_tensor( vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup') return self._cached_class_id_table
def _label_ids(self, labels): """Converts labels to integer id space.""" if self._label_vocabulary is None: if not labels.dtype.is_integer: raise ValueError('Labels dtype should be integer ' 'Instead got %s.' % labels.dtype) label_ids = labels else: if labels.dtype != dtypes.string: raise ValueError('Labels dtype should be string if there is a ' 'vocabulary. Instead got {}'.format(labels.dtype)) label_ids = lookup_ops.index_table_from_tensor( vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup').lookup(labels) return _assert_range(label_ids, self._n_classes)
def _class_id_table(self): """Creates a lookup table for class_id. This lookup table will be lazily created on the first call of `self._class_id_table`, and cached for later use. This makes it eager-friendly, and also gunrantees the lookup Op is contained in the graph for Graph execution. Returns: A hash table for lookup. """ if self._cached_class_id_table is None: self._cached_class_id_table = lookup_ops.index_table_from_tensor( vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup') return self._cached_class_id_table
def testGetInferIterator(self): src_vocab_table = lookup_ops.index_table_from_tensor( tf.constant(["a", "b", "c", "eos", "sos"])) src_dataset = tf.data.Dataset.from_tensor_slices( tf.constant(["c c a", "c a", "d", "f e a g"])) hparams = tf.contrib.training.HParams(random_seed=3, source_reverse=False, eos="eos", sos="sos") batch_size = 2 src_max_len = 3 iterator = iterator_utils.get_infer_iterator( src_dataset=src_dataset, src_vocab_table=src_vocab_table, batch_size=batch_size, eos=hparams.eos, source_reverse=hparams.source_reverse, src_max_len=src_max_len) table_initializer = tf.tables_initializer() source = iterator.source seq_len = iterator.source_sequence_length self.assertEqual([None, None], source.shape.as_list()) self.assertEqual([None], seq_len.shape.as_list()) with self.test_session() as sess: sess.run(table_initializer) sess.run(iterator.initializer) (source_v, seq_len_v) = sess.run((source, seq_len)) self.assertAllEqual( [ [2, 2, 0], # c c a [2, 0, 3] ], # c a eos source_v) self.assertAllEqual([3, 2], seq_len_v) (source_v, seq_len_v) = sess.run((source, seq_len)) self.assertAllEqual( [ [-1, 3, 3], # "d" == unknown, eos eos [-1, -1, 0] ], # "f" == unknown, "e" == unknown, a source_v) self.assertAllEqual([1, 3], seq_len_v) with self.assertRaisesOpError("End of sequence"): sess.run((source, seq_len))
def testGetInferIterator(self): src_vocab_table = lookup_ops.index_table_from_tensor( tf.constant(["a", "b", "c", "eos", "sos"])) src_dataset = tf.data.Dataset.from_tensor_slices( tf.constant(["c c a", "c a", "d", "f e a g"])) hparams = tf.contrib.training.HParams( random_seed=3, eos="eos", sos="sos") batch_size = 2 src_max_len = 3 iterator = iterator_utils.get_infer_iterator( src_dataset=src_dataset, src_vocab_table=src_vocab_table, batch_size=batch_size, eos=hparams.eos, src_max_len=src_max_len) table_initializer = tf.tables_initializer() source = iterator.source seq_len = iterator.source_sequence_length self.assertEqual([None, None], source.shape.as_list()) self.assertEqual([None], seq_len.shape.as_list()) with self.test_session() as sess: sess.run(table_initializer) sess.run(iterator.initializer) (source_v, seq_len_v) = sess.run((source, seq_len)) self.assertAllEqual( [[2, 2, 0], # c c a [2, 0, 3]], # c a eos source_v) self.assertAllEqual([3, 2], seq_len_v) (source_v, seq_len_v) = sess.run((source, seq_len)) self.assertAllEqual( [[-1, 3, 3], # "d" == unknown, eos eos [-1, -1, 0]], # "f" == unknown, "e" == unknown, a source_v) self.assertAllEqual([1, 3], seq_len_v) with self.assertRaisesOpError("End of sequence"): sess.run((source, seq_len))
def _label_ids(self, labels): """Converts labels to integer id space.""" if self._label_vocabulary is None: if not labels.dtype.is_integer: raise ValueError('Labels dtype should be integer ' 'Instead got %s.' % labels.dtype) label_ids = labels else: if labels.dtype != dtypes.string: raise ValueError('Labels dtype should be string if there is a ' 'vocabulary. Instead got {}'.format(labels.dtype)) label_ids = lookup_ops.index_table_from_tensor( vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup').lookup(labels) assert_less = check_ops.assert_less( label_ids, ops.convert_to_tensor(self._n_classes, dtype=label_ids.dtype), message='Label IDs must < n_classes') assert_greater = check_ops.assert_non_negative( label_ids, message='Label Ids must >= 0') with ops.control_dependencies((assert_less, assert_greater)): return array_ops.identity(label_ids)
def _process_labels(self, labels):
  if labels is None:
    raise ValueError(
        'You must provide a labels Tensor. Given: None. '
        'Suggested troubleshooting steps: Check that your data contain '
        'your label feature. Check that your input_fn properly parses and '
        'returns labels.')
  if isinstance(labels, sparse_tensor.SparseTensor):
    if labels.dtype == dtypes.string:
      label_ids_values = lookup_ops.index_table_from_tensor(
          vocabulary_list=tuple(self._label_vocabulary),
          name='class_id_lookup').lookup(labels.values)
      label_ids = sparse_tensor.SparseTensor(
          indices=labels.indices,
          values=label_ids_values,
          dense_shape=labels.dense_shape)
      return math_ops.to_int64(
          sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
    else:
      err_msg = (
          r'labels must be an integer SparseTensor with values in '
          r'[0, {})'.format(self._n_classes))
      assert_int = check_ops.assert_integer(labels.values, message=err_msg)
      assert_less = check_ops.assert_less(
          labels.values,
          ops.convert_to_tensor(self._n_classes, dtype=labels.dtype),
          message=err_msg)
      assert_greater = check_ops.assert_non_negative(
          labels.values, message=err_msg)
      with ops.control_dependencies(
          [assert_int, assert_less, assert_greater]):
        return math_ops.to_int64(
            sparse_ops.sparse_to_indicator(labels, self._n_classes))
  err_msg = (
      r'labels must be an integer indicator Tensor with values in [0, 1]')
  return head_lib._assert_range(labels, 2, message=err_msg)  # pylint:disable=protected-access
def testGetIterator(self):
  tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["f e a g", "c c a", "d", "c a"]))
  tgt_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c c", "a b", "", "b c"]))
  hparams = tf.contrib.training.HParams(
      random_seed=3,
      num_buckets=5,
      eos="eos",
      sos="sos")
  batch_size = 2
  src_max_len = 3
  iterator = iterator_utils.get_iterator(
      src_dataset=src_dataset,
      tgt_dataset=tgt_dataset,
      src_vocab_table=src_vocab_table,
      tgt_vocab_table=tgt_vocab_table,
      batch_size=batch_size,
      sos=hparams.sos,
      eos=hparams.eos,
      random_seed=hparams.random_seed,
      num_buckets=hparams.num_buckets,
      src_max_len=src_max_len)
  table_initializer = tf.tables_initializer()
  source = iterator.source
  target_input = iterator.target_input
  target_output = iterator.target_output
  src_seq_len = iterator.source_sequence_length
  tgt_seq_len = iterator.target_sequence_length
  self.assertEqual([None, None], source.shape.as_list())
  self.assertEqual([None, None], target_input.shape.as_list())
  self.assertEqual([None, None], target_output.shape.as_list())
  self.assertEqual([None], src_seq_len.shape.as_list())
  self.assertEqual([None], tgt_seq_len.shape.as_list())
  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer)

    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual(
        [[-1, -1, 0],  # "f" == unknown, "e" == unknown, a
         [2, 0, 3]],   # c a eos -- eos is padding
        source_v)
    self.assertAllEqual([3, 2], src_len_v)
    self.assertAllEqual(
        [[4, 2, 2],   # sos c c
         [4, 1, 2]],  # sos b c
        target_input_v)
    self.assertAllEqual(
        [[2, 2, 3],   # c c eos
         [1, 2, 3]],  # b c eos
        target_output_v)
    self.assertAllEqual([3, 3], tgt_len_v)

    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual(
        [[2, 2, 0]],  # c c a
        source_v)
    self.assertAllEqual([3], src_len_v)
    self.assertAllEqual(
        [[4, 0, 1]],  # sos a b
        target_input_v)
    self.assertAllEqual(
        [[0, 1, 3]],  # a b eos
        target_output_v)
    self.assertAllEqual([3], tgt_len_v)

    with self.assertRaisesOpError("End of sequence"):
      sess.run(source)
def index_table_from_tensor(mapping,
                            num_oov_buckets=0,
                            default_value=-1,
                            hasher_spec=FastHashSpec,
                            dtype=dtypes.string,
                            name=None):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert a tensor of strings
  into int64 IDs. The mapping can be initialized from a string `mapping` 1-D
  tensor where each element is a key and the corresponding index within the
  tensor is the value.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on
  its hash if `num_oov_buckets` is greater than zero. Otherwise it is
  assigned the `default_value`. The bucket ID range is
  `[mapping size, mapping size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `tf.tables_initializer().run()` or `table.init.run()` once.

  Elements in `mapping` cannot have duplicates, otherwise when executing the
  table initializer op, it will throw a `FailedPreconditionError`.

  Sample Usages:

  ```python
  mapping_strings = tf.constant(["emerson", "lake", "palmer"])
  table = tf.contrib.lookup.index_table_from_tensor(
      mapping=mapping_strings, num_oov_buckets=1, default_value=-1)
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  ids = table.lookup(features)
  ...
  tf.tables_initializer().run()

  ids.eval()  ==> [0, 1, 3, 2]  # "and" falls into the single OOV bucket.
  ```

  Args:
    mapping: A 1-D `Tensor` that specifies the mapping of keys to indices.
      The type of this object must be castable to `dtype`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    dtype: The type of values passed to `lookup`. Only strings and integers
      are supported.
    name: A name for this op (optional).

  Returns:
    The lookup table to map an input `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `mapping` is invalid.
    ValueError: If `num_oov_buckets` is negative.
  """
  if mapping is None:
    raise ValueError("mapping must be specified.")
  return lookup_ops.index_table_from_tensor(
      vocabulary_list=mapping,
      num_oov_buckets=num_oov_buckets,
      default_value=default_value,
      hasher_spec=hasher_spec,
      dtype=dtype,
      name=name)
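# The docstring's sample only covers string keys. Since the wrapper also
# accepts integer vocabularies, here is a short sketch of the int64 path,
# mirroring the int64 unit test earlier in this collection; the concrete
# values and session wiring are illustrative assumptions.
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import lookup_ops

# Ids follow vocabulary order (42 -> 0, 1 -> 1, -1000 -> 2); the OOV key 11
# hashes into the single extra bucket and gets id 3.
table = lookup_ops.index_table_from_tensor(
    vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
ids = table.lookup(tf.constant((1, -1000, 11), dtype=tf.int64))
with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  print(sess.run(ids))  # [1 2 3]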
def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): """See `Head`.""" # Predict. with ops.name_scope('head'): with ops.name_scope(None, 'predictions', (logits,)): pred_keys = prediction_keys.PredictionKeys logits = _check_logits(logits, self.logits_dimension) logistic = math_ops.sigmoid(logits, name=pred_keys.LOGISTIC) two_class_logits = array_ops.concat( (array_ops.zeros_like(logits), logits), 1, name='two_class_logits') scores = nn.softmax(two_class_logits, name=pred_keys.PROBABILITIES) class_ids = array_ops.reshape( math_ops.argmax(two_class_logits, axis=1), (-1, 1), name='classes') if self._label_vocabulary: table = lookup_ops.index_to_string_table_from_tensor( vocabulary_list=self._label_vocabulary, name='class_string_lookup') classes = table.lookup(class_ids) else: classes = string_ops.as_string(class_ids, name='str_classes') predictions = { pred_keys.LOGITS: logits, pred_keys.LOGISTIC: logistic, pred_keys.PROBABILITIES: scores, pred_keys.CLASS_IDS: class_ids, pred_keys.CLASSES: classes, } if mode == model_fn.ModeKeys.PREDICT: batch_size = array_ops.shape(logistic)[0] export_class_list = self._label_vocabulary if not export_class_list: export_class_list = string_ops.as_string([0, 1]) export_output_classes = array_ops.tile( input=array_ops.expand_dims(input=export_class_list, axis=0), multiples=[batch_size, 1]) classifier_output = export_output.ClassificationOutput( scores=scores, # `ClassificationOutput` requires string classes. classes=export_output_classes) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ '': classifier_output, # to be same as other heads. 'classification': classifier_output, # to be called by name. _DEFAULT_SERVING_KEY: classifier_output, # default 'regression': export_output.RegressionOutput(value=logistic) }) # Eval. labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension) if self._label_vocabulary is not None: labels = lookup_ops.index_table_from_tensor( vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup').lookup(labels) labels = math_ops.to_float(labels) labels = _assert_range(labels, 2) unweighted_loss = nn.sigmoid_cross_entropy_with_logits( labels=labels, logits=logits, name='loss') weights = _weights(features, self._weight_column) training_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) if mode == model_fn.ModeKeys.EVAL: return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, loss=training_loss, eval_metric_ops=self._eval_metric_ops( labels=labels, logits=logits, logistic=logistic, scores=scores, class_ids=class_ids, unweighted_loss=unweighted_loss, weights=weights)) # Train. if train_op_fn is None: raise ValueError('train_op_fn can not be None.') with ops.name_scope(''): summary.scalar(metric_keys.MetricKeys.LOSS, training_loss) summary.scalar(metric_keys.MetricKeys.LOSS_MEAN, losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.MEAN)) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=training_loss, train_op=train_op_fn(training_loss))
def test_index_table_from_tensor_missing_vocabulary_list(self):
  with self.test_session():
    with self.assertRaisesRegexp(ValueError,
                                 "vocabulary_list must be specified"):
      lookup_ops.index_table_from_tensor(
          vocabulary_list=None, num_oov_buckets=1)
def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): """See `Head`.""" with variable_scope.variable_scope( None, default_name='binary_logistic_head', values=(tuple(six.itervalues(features)) + (labels, logits))): # Predict. pred_keys = prediction_keys.PredictionKeys logits = _check_logits(logits, self.logits_dimension) logistic = math_ops.sigmoid(logits, name=pred_keys.LOGISTIC) two_class_logits = array_ops.concat( (array_ops.zeros_like(logits), logits), 1, name='two_class_logits') scores = nn.softmax(two_class_logits, name=pred_keys.PROBABILITIES) class_ids = array_ops.reshape( math_ops.argmax(two_class_logits, axis=1), (-1, 1), name='classes') if self._label_vocabulary: table = lookup_ops.index_to_string_table_from_tensor( vocabulary_list=self._label_vocabulary, name='class_string_lookup') classes = table.lookup(class_ids) else: classes = string_ops.as_string(class_ids, name='str_classes') predictions = { pred_keys.LOGITS: logits, pred_keys.LOGISTIC: logistic, pred_keys.PROBABILITIES: scores, pred_keys.CLASS_IDS: class_ids, pred_keys.CLASSES: classes, } if mode == model_fn.ModeKeys.PREDICT: return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ '': export_output.ClassificationOutput( scores=scores, classes=classes) }) # Eval. labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension) if self._label_vocabulary is not None: labels = lookup_ops.index_table_from_tensor( vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup').lookup(labels) labels = math_ops.to_float(labels) labels = _assert_range(labels, 2) unweighted_loss = nn.sigmoid_cross_entropy_with_logits( labels=labels, logits=logits, name='loss') weights = ( 1. if (self._weight_feature_key is None) else features[self._weight_feature_key]) weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights')) training_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) if mode == model_fn.ModeKeys.EVAL: return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, loss=training_loss, eval_metric_ops=self._eval_metric_ops( labels=labels, logits=logits, logistic=logistic, scores=scores, class_ids=class_ids, unweighted_loss=unweighted_loss, weights=weights)) # Train. if train_op_fn is None: raise ValueError('train_op_fn can not be None.') logging_ops.scalar_summary(metric_keys.MetricKeys.LOSS, training_loss) logging_ops.scalar_summary( metric_keys.MetricKeys.LOSS_MEAN, losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.MEAN)) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=training_loss, train_op=train_op_fn(training_loss))