Example 1
    def testGetInputFnCommon(self):
        """Tests get_input_fn_common"""
        feature_type2name = {
            InputFtrType.QUERY_COLUMN_NAME: 'query',
            InputFtrType.DOC_TEXT_COLUMN_NAMES: ['doc_completedQuery'],
            InputFtrType.DOC_ID_COLUMN_NAMES: ['docId_completedQuery'],
            InputFtrType.USER_TEXT_COLUMN_NAMES: ['usr_headline', 'usr_skills', 'usr_currTitles'],
            InputFtrType.USER_ID_COLUMN_NAMES: ['usrId_currTitles'],
            InputFtrType.DENSE_FTRS_COLUMN_NAMES: 'wide_ftrs',
            InputFtrType.LABEL_COLUMN_NAME: 'label',
            InputFtrType.WEIGHT_COLUMN_NAME: 'weight'
        }
        feature_name2num = {'wide_ftrs': 5}

        _, vocab_tf_table = vocab_utils.read_tf_vocab(self.vocab_file,
                                                      self.UNK)
        vocab_table = vocab_utils.read_vocab(self.vocab_file)
        data_dir = self.data_dir
        hparams = HParams(input_pattern=data_dir,
                          filter_window_sizes=[10],
                          CLS=self.CLS,
                          PAD=self.PAD,
                          SEP=self.SEP,
                          UNK=self.UNK,
                          UNK_FOR_ID_FTR=self.UNK,
                          PAD_FOR_ID_FTR=self.PAD,
                          min_len=1,
                          max_len=16,
                          vocab_file=self.vocab_file,
                          vocab_file_for_id_ftr=self.vocab_file,
                          PAD_ID=vocab_table[self.PAD],
                          SEP_ID=vocab_table[self.SEP],
                          CLS_ID=vocab_table[self.CLS],
                          mode=tf.estimator.ModeKeys.EVAL,
                          task_type=self.task_type,
                          vocab_table=vocab_tf_table,
                          vocab_table_for_id_ftr=vocab_tf_table,
                          max_filter_window_size=3,
                          vocab_hub_url='',
                          vocab_hub_url_for_id_ftr='',
                          embedding_hub_url='',
                          embedding_hub_url_for_id_ftr='',
                          feature_type2name=feature_type2name,
                          feature_name2num=feature_name2num)
        train_model_helper.get_input_fn_common(data_dir, 1,
                                               tf.estimator.ModeKeys.TRAIN,
                                               hparams)
Example 2
    def __init__(self, CLS, SEP, PAD, UNK, vocab_file):
        """ Initializes the vocabulary layer

        :param CLS Token that represents the start of a sentence
        :param SEP Token that represents the end of a segment
        :param PAD Token that represents padding
        :param UNK Token that represents unknown tokens
        :param vocab_file Path to the vocabulary file
        """
        super().__init__()
        self._vocab_table_initializer, self.vocab_table = read_tf_vocab(
            vocab_file, UNK)

        self._CLS = CLS
        self._SEP = SEP
        self._PAD = PAD

        py_vocab_table = read_vocab(vocab_file)
        self._pad_id = py_vocab_table[PAD]
        self._cls_id = py_vocab_table[CLS] if CLS else -1
        self._sep_id = py_vocab_table[SEP] if SEP else -1
        self._vocab_size = len(py_vocab_table)
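
The CLS/SEP handling above lets either token be disabled by passing a falsy value. A tiny self-contained illustration with a made-up vocabulary (the token ids below are hypothetical, not from the real vocab file):

# Illustration only: toy vocabulary, not the real vocab_file contents.
py_vocab_table = {'[PAD]': 0, '[SEP]': 1, '[CLS]': 2}  # hypothetical toy vocabulary
CLS, SEP, PAD = '', '[SEP]', '[PAD]'
pad_id = py_vocab_table[PAD]                  # 0
cls_id = py_vocab_table[CLS] if CLS else -1   # -1, since CLS is disabled (empty string)
sep_id = py_vocab_table[SEP] if SEP else -1   # 1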
Example 3
def extend_hparams(hparams):
    # Sanity check for RNN related hparams
    assert hparams.unit_type in [
        'lstm', 'gru', 'layer_norm_lstm'
    ], 'Only support lstm/gru/layer_norm_lstm as unit_type'
    assert hparams.num_layers > 0, 'num_layers must be larger than 0'
    assert hparams.num_residual_layers >= 0, 'num_residual_layers must be >= 0'
    assert 0 <= hparams.forget_bias <= 1, 'forget_bias must be within [0.0, 1.0]'
    assert 0 <= hparams.rnn_dropout <= 1, 'rnn_dropout must be within [0.0, 1.0]'

    # Get number of doc/usr text fields
    num_doc_fields = sum(
        [name.startswith('doc_') for name in hparams.feature_names.split(',')])
    hparams.add_hparam("num_doc_fields", num_doc_fields)
    num_usr_fields = sum(
        [name.startswith('usr_') for name in hparams.feature_names.split(',')])
    hparams.add_hparam("num_usr_fields", num_usr_fields)

    # Get number of doc/usr id fields
    num_doc_id_fields = sum([
        name.startswith('docId_') for name in hparams.feature_names.split(',')
    ])
    hparams.add_hparam("num_doc_id_fields", num_doc_id_fields)
    num_usr_id_fields = sum([
        name.startswith('usrId_') for name in hparams.feature_names.split(',')
    ])
    hparams.add_hparam("num_usr_id_fields", num_usr_id_fields)
    if num_doc_id_fields > 0 or num_usr_id_fields > 0:
        assert hparams.vocab_file_for_id_ftr is not None, \
            "Must provide vocab_field_for_id_ftr arg when id features are provided"

    # find vocab size, pad id from vocab file
    vocab_table = vocab_utils.read_vocab(hparams.vocab_file)
    hparams.add_hparam("vocab_size", len(vocab_table))
    hparams.pad_id = vocab_table[hparams.PAD]

    # find vocab size, pad id from vocab file for id features
    if hparams.vocab_file_for_id_ftr is not None:
        vocab_table_for_id_ftr = vocab_utils.read_vocab(
            hparams.vocab_file_for_id_ftr)
        hparams.add_hparam("vocab_size_for_id_ftr",
                           len(vocab_table_for_id_ftr))
        hparams.pad_id_for_id_ftr = vocab_table_for_id_ftr[
            hparams.PAD_FOR_ID_FTR]

    # if there is a bert config, check compatibility between bert parameters and existing parameters
    if hparams.bert_config_file:
        hparams.bert_config = modeling.BertConfig.from_json_file(
            hparams.bert_config_file)
        assert hparams.bert_config.vocab_size == hparams.vocab_size

    # Regex pattern whose matches get whitespace added before and after; used for processing text fields.
    tok2regex_pattern = {'plain': None, 'punct': r'(\pP)'}
    hparams.regex_replace_pattern = tok2regex_pattern[hparams.tokenization]

    # if not using cnn models, then disable cnn parameters
    if hparams.ftr_ext != 'cnn':
        hparams.filter_window_sizes = '0'

    # convert from string to arrays for filter_window_sizes
    filter_window_sizes_str = hparams.filter_window_sizes
    force_set_hparam(
        hparams, "filter_window_sizes",
        [int(x.strip()) for x in filter_window_sizes_str.split(',')])

    assert hparams.pmetric is not None, "Please set your primary evaluation metric using --pmetric option"
    assert hparams.pmetric != 'confusion_matrix', 'confusion_matrix cannot be used as primary evaluation metric.'

    # Set all relevant evaluation metrics
    all_metrics = hparams.all_metrics.split(',') if hparams.all_metrics else [
        hparams.pmetric
    ]
    assert hparams.pmetric in all_metrics, "pmetric must be within all_metrics"
    force_set_hparam(hparams, "all_metrics", all_metrics)

    # convert from string to arrays for num_hidden
    num_hidden_str = str(hparams.num_hidden)
    force_set_hparam(hparams, "num_hidden",
                     [int(x.strip()) for x in num_hidden_str.split(',')])
    # convert from string to arrays for feature names
    setattr(hparams, 'feature_names', tuple(hparams.feature_names.split(',')))
    # lambda rank
    if hparams.lambda_metric is not None and hparams.lambda_metric == 'ndcg':
        setattr(hparams, 'lambda_metric', {'metric': 'ndcg', 'topk': 10})
    else:
        setattr(hparams, 'lambda_metric', None)
    # feature normalization
    if hparams.std_file:
        # read normalization file
        print('read normalization file')
        ftr_mean, ftr_std = _load_ftr_mean_std(hparams.std_file)
        hparams.add_hparam('ftr_mean', np.array(ftr_mean, dtype=np.float32))
        hparams.add_hparam('ftr_std', np.array(ftr_std, dtype=np.float32))

    # for score rescaling, the score_rescale has the xgboost mean and std.
    if hparams.score_rescale:
        force_set_hparam(hparams, 'score_rescale',
                         [float(x) for x in hparams.score_rescale.split(',')])

    if hparams.explicit_empty:
        assert hparams.ftr_ext == 'cnn', 'explicit_empty will only be True when ftr_ext is cnn'

    # Convert string to arrays for emb_sim_func
    force_set_hparam(hparams, "emb_sim_func", hparams.emb_sim_func.split(','))

    # Checking hparam keep_checkpoint_max: must be >= 0
    if hparams.keep_checkpoint_max:
        assert hparams.keep_checkpoint_max >= 0

    # Classification task
    if hparams.num_classes > 1:
        # For classification tasks, restrict pmetric to be accuracy and use accuracy and confusion_matrix as metrics.
        hparams.pmetric = 'accuracy'
        hparams.all_metrics = ['accuracy', 'confusion_matrix']

    # L1 and L2 scale must be non-negative values
    assert hparams.l1 is None or hparams.l1 >= 0, "l1 scale must be non-negative"
    assert hparams.l2 is None or hparams.l2 >= 0, "l2 scale must be non-negative"
    return hparams
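
The field counting in extend_hparams above relies purely on naming prefixes in the comma-separated feature_names string. A small self-contained illustration of the convention (the feature names below are hypothetical):

# Illustration of the prefix convention used by extend_hparams above; names are made up.
feature_names = 'label,query,doc_title,doc_headline,usr_headline,usrId_skills,wide_ftrs'
num_doc_fields = sum(name.startswith('doc_') for name in feature_names.split(','))       # 2 (doc_title, doc_headline)
num_usr_fields = sum(name.startswith('usr_') for name in feature_names.split(','))       # 1 (usr_headline)
num_usr_id_fields = sum(name.startswith('usrId_') for name in feature_names.split(','))  # 1 (usrId_skills)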
Example 4
class DataSetup:
    """Class containing common setup on file paths, layer params used in unit tests"""
    resource_dir = os.path.join(os.getcwd(), 'test', 'detext', 'resources')
    we_file = os.path.join(resource_dir, 'we.pkl')
    vocab_file = os.path.join(resource_dir, 'vocab.txt')
    vocab_file_for_id_ftr = vocab_file

    vocab_layer_dir = os.path.join(resource_dir, 'vocab_layer')
    embedding_layer_dir = os.path.join(resource_dir, 'embedding_layer')

    bert_hub_url = os.path.join(resource_dir, 'bert-hub')
    libert_sp_hub_url = os.path.join(resource_dir, 'libert-sp-hub')
    libert_space_hub_url = os.path.join(resource_dir, 'libert-space-hub')
    vocab_hub_url = os.path.join(resource_dir, 'vocab_layer_hub')
    embedding_hub_url = os.path.join(resource_dir, 'embedding_layer_hub')

    out_dir = os.path.join(resource_dir, "output")
    data_dir = os.path.join(resource_dir, "train", "dataset", "tfrecord")
    multitask_data_dir = os.path.join(resource_dir, "train", "multitask",
                                      "tfrecord")
    cls_data_dir = os.path.join(resource_dir, "train", "classification",
                                "tfrecord")
    binary_cls_data_dir = os.path.join(resource_dir, "train",
                                       "binary_classification", "tfrecord")
    ranking_data_dir = os.path.join(resource_dir, "train", "ranking",
                                    "tfrecord")

    vocab_table_py = read_vocab(vocab_file)
    vocab_size = len(vocab_table_py)

    CLS = '[CLS]'
    PAD = '[PAD]'
    SEP = '[SEP]'
    UNK = '[UNK]'
    CLS_ID = vocab_table_py[CLS]
    PAD_ID = vocab_table_py[PAD]
    SEP_ID = vocab_table_py[SEP]
    UNK_ID = vocab_table_py[UNK]

    PAD_FOR_ID_FTR = PAD
    UNK_FOR_ID_FTR = UNK

    query = tf.constant(['batch1', 'batch 2 query build'],
                        dtype=tf.dtypes.string)
    query_length = [1, 4]

    user_id_field1 = query
    user_id_field2 = query

    cls_doc_field1 = ['same content build', 'batch 2 field 1 word']
    cls_doc_field2 = ['same content build', 'batch 2 field 2 word']

    ranking_doc_field1 = [
        ['same content build', 'batch 1 doc 2 field able', 'batch 1 doc 3 field 1'],
        ['batch 2 doc 1 field word', 'batch 2 doc 2 field 1', 'batch 2 doc 3 field test']
    ]
    ranking_doc_id_field1 = ranking_doc_field1

    ranking_doc_field2 = [
        ['same content build', 'batch 1 doc 2 field test', 'batch 1 doc 3 field 2'],
        ['batch 2 doc 1 field test', 'batch 2 doc 2 field 2', 'batch 2 doc 3 field word']
    ]
    ranking_doc_id_field2 = ranking_doc_field2

    cls_sparse_features_1 = [[1.0, 2.0, 4.0], [0.0, -1.0, 4.0]]
    cls_sparse_features_2 = [[2.0, 0.0, 4.0], [2.0, 2.0, 4.0]]
    cls_sparse_features = [
        tf.sparse.from_dense(tf.constant(cls_sparse_features_1)),
        tf.sparse.from_dense(tf.constant(cls_sparse_features_2))
    ]

    ranking_sparse_features_1 = [
        [[1.0, 2.0, 4.0], [1.0, 2.0, 4.0], [1.0, 2.0, 4.0]],
        [[0.0, -1.0, 4.0], [0.0, -1.0, 4.0], [0.0, -1.0, 4.0]]
    ]
    ranking_sparse_features_2 = [
        [[1.0, 2.0, 4.0], [1.0, 2.0, 4.0], [1.0, 2.0, 4.0]],
        [[0.0, -1.0, 4.0], [0.0, -1.0, 4.0], [0.0, -1.0, 4.0]]
    ]
    ranking_sparse_features = [
        tf.sparse.from_dense(tf.constant(ranking_sparse_features_1)),
        tf.sparse.from_dense(tf.constant(ranking_sparse_features_2))
    ]
    nums_sparse_ftrs = [3]
    total_num_sparse_ftrs = sum(nums_sparse_ftrs)
    sparse_embedding_size = 33

    num_user_fields = 2
    user_fields = [
        tf.constant(query, dtype=tf.dtypes.string),
        tf.constant(query, dtype=tf.dtypes.string)
    ]

    num_doc_fields = 2
    ranking_doc_fields = [
        tf.constant(ranking_doc_field1, dtype=tf.dtypes.string),
        tf.constant(ranking_doc_field2, dtype=tf.dtypes.string)
    ]
    cls_doc_fields = [
        tf.constant(cls_doc_field1, dtype=tf.dtypes.string),
        tf.constant(cls_doc_field2, dtype=tf.dtypes.string)
    ]

    num_user_id_fields = 2
    user_id_fields = user_fields

    num_doc_id_fields = 2
    ranking_doc_id_fields = ranking_doc_fields
    cls_doc_id_fields = cls_doc_fields

    num_id_fields = num_user_id_fields + num_doc_id_fields

    num_units = 6
    num_units_for_id_ftr = num_units

    vocab_layer_param = {
        'CLS': CLS,
        'SEP': SEP,
        'PAD': PAD,
        'UNK': UNK,
        'vocab_file': vocab_file
    }

    embedding_layer_param = {
        'vocab_layer_param': vocab_layer_param,
        'vocab_hub_url': '',
        'we_file': '',
        'we_trainable': True,
        'num_units': num_units
    }

    min_len = 3
    max_len = 7
    filter_window_sizes = [1, 2, 3]
    num_filters = 5

    cnn_param = HParams(filter_window_sizes=filter_window_sizes,
                        num_filters=num_filters,
                        num_doc_fields=num_doc_fields,
                        num_user_fields=num_user_fields,
                        min_len=min_len,
                        max_len=max_len,
                        embedding_layer_param=embedding_layer_param,
                        embedding_hub_url=None)

    id_encoder_param = HParams(num_id_fields=num_id_fields,
                               embedding_layer_param=embedding_layer_param,
                               embedding_hub_url_for_id_ftr=None)
    rep_layer_param = HParams(ftr_ext='cnn',
                              num_doc_fields=num_doc_fields,
                              num_user_fields=num_user_fields,
                              num_doc_id_fields=num_doc_id_fields,
                              num_user_id_fields=num_user_id_fields,
                              add_doc_projection=False,
                              add_user_projection=False,
                              text_encoder_param=cnn_param,
                              id_encoder_param=id_encoder_param)
Example 5
class TestFeatureGrouper(tf.test.TestCase, DataSetup):
    """Unit test for feature_grouper.py"""
    _, vocab_tf_table = vocab_utils.read_tf_vocab(DataSetup.vocab_file,
                                                  '[UNK]')
    vocab_table = vocab_utils.read_vocab(DataSetup.vocab_file)

    PAD_ID = vocab_table[DataSetup.PAD]
    SEP_ID = vocab_table[DataSetup.SEP]
    CLS_ID = vocab_table[DataSetup.CLS]
    UNK_ID = vocab_table[DataSetup.UNK]

    max_filter_window_size = 0

    def testFeatureGrouperKerasInput(self):
        """Tests FeatureGrouper with tf.keras.Input"""
        nums_dense_ftrs = [2, 3]
        nums_sparse_ftrs = [10, 30]
        layer = FeatureGrouper()
        inputs = {
            InputFtrType.QUERY_COLUMN_NAME:
            tf.keras.Input(shape=(), dtype='string'),
            InputFtrType.USER_TEXT_COLUMN_NAMES:
            [tf.keras.Input(shape=(), dtype='string')],
            InputFtrType.USER_ID_COLUMN_NAMES:
            [tf.keras.Input(shape=(), dtype='string')],
            InputFtrType.DOC_TEXT_COLUMN_NAMES:
            [tf.keras.Input(shape=(None, ), dtype='string')],
            InputFtrType.DOC_ID_COLUMN_NAMES:
            [tf.keras.Input(shape=(None, ), dtype='string')],
            InputFtrType.DENSE_FTRS_COLUMN_NAMES: [
                tf.keras.Input(shape=(num_dense_ftrs, ), dtype='float32')
                for num_dense_ftrs in nums_dense_ftrs
            ],
            InputFtrType.SPARSE_FTRS_COLUMN_NAMES: [
                tf.keras.Input(shape=(num_sparse_ftrs, ),
                               dtype='float32',
                               sparse=True)
                for num_sparse_ftrs in nums_sparse_ftrs
            ]
        }
        outputs = layer(inputs)
        self.assertLen(outputs, len(inputs))

    def testFeatureGrouperTensor(self):
        """Tests FeatureGrouper with tensor input"""
        layer = FeatureGrouper()
        inputs = {
            InputFtrType.QUERY_COLUMN_NAME:
            tf.constant(['batch 1 user 1 build', 'batch 2 user 2 word'],
                        dtype=tf.string),
            InputFtrType.DENSE_FTRS_COLUMN_NAMES: [
                tf.constant([[1, 1], [2, 2]], dtype=tf.float32),
                tf.constant([[0], [1]], dtype=tf.float32)
            ],
            InputFtrType.SPARSE_FTRS_COLUMN_NAMES: [
                tf.sparse.from_dense(
                    tf.constant([[1, 0], [2, 0]], dtype=tf.float32)),
                tf.sparse.from_dense(tf.constant([[1], [1]], dtype=tf.float32))
            ]
        }
        expected_result = {
            InputFtrType.QUERY_COLUMN_NAME:
            tf.constant(['batch 1 user 1 build', 'batch 2 user 2 word'],
                        dtype=tf.string),
            InputFtrType.DENSE_FTRS_COLUMN_NAMES:
            tf.constant([[1, 1, 0], [2, 2, 1]]),
            InputFtrType.SPARSE_FTRS_COLUMN_NAMES: [
                tf.constant([[1, 0], [2, 0]], dtype=tf.float32),
                tf.constant([[1], [1]], dtype=tf.float32)
            ]
        }
        outputs = layer(inputs)

        self.assertEqual(len(outputs), len(expected_result),
                         'Outputs must have the same length as expected_result')
        for ftr_type, expected_ftr in expected_result.items():
            output = outputs[ftr_type]
            if ftr_type == InputFtrType.SPARSE_FTRS_COLUMN_NAMES:
                output = [tf.sparse.to_dense(t) for t in output]
                for e, o in zip(expected_ftr, output):
                    self.assertAllEqual(e, o)
                continue
            self.assertAllEqual(expected_ftr, output)

    def testConcatFtrOnLastDim(self):
        """Tests concatenate features on last dimension"""
        tensor_lst = [
            tf.constant([1, 2, 3], dtype='int32'),
            tf.constant([4, 5, 6], dtype='int32')
        ]
        result = feature_grouper.concat_on_last_axis_dense(tensor_lst)
        expected_output = tf.constant([1, 2, 3, 4, 5, 6], dtype='int32')
        self.assertAllEqual(result, expected_output)
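
The expected result in testConcatFtrOnLastDim above is what plain last-axis concatenation produces. A minimal sketch of equivalent behavior for the dense case (this is not the actual detext implementation, whose internals are not shown here):

import tensorflow as tf

def concat_on_last_axis_dense_sketch(tensor_lst):
    """Concatenates a list of dense tensors along their last axis (sketch of the tested behavior)."""
    return tf.concat(tensor_lst, axis=-1)

# concat_on_last_axis_dense_sketch([tf.constant([1, 2, 3]), tf.constant([4, 5, 6])])
# -> [1, 2, 3, 4, 5, 6]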
Example 6
def extend_hparams(hparams):
    # Sanity check for RNN related hparams
    assert hparams.unit_type in [
        'lstm', 'gru', 'layer_norm_lstm'
    ], 'Only support lstm/gru/layer_norm_lstm as unit_type'
    assert hparams.num_layers > 0, 'num_layers must be larger than 0'
    assert hparams.num_residual_layers >= 0, 'num_residual_layers must be >= 0'
    assert 0 <= hparams.forget_bias <= 1, 'forget_bias must be within [0.0, 1.0]'
    assert 0 <= hparams.rnn_dropout <= 1, 'rnn_dropout must be within [0.0, 1.0]'

    # Get number of doc/usr text fields
    num_doc_fields = sum(
        [name.startswith('doc_') for name in hparams.feature_names])
    hparams.add_hparam("num_doc_fields", num_doc_fields)
    num_usr_fields = sum(
        [name.startswith('usr_') for name in hparams.feature_names])
    hparams.add_hparam("num_usr_fields", num_usr_fields)

    # Get number of doc/usr id fields
    num_doc_id_fields = sum(
        [name.startswith('docId_') for name in hparams.feature_names])
    hparams.add_hparam("num_doc_id_fields", num_doc_id_fields)
    num_usr_id_fields = sum(
        [name.startswith('usrId_') for name in hparams.feature_names])
    hparams.add_hparam("num_usr_id_fields", num_usr_id_fields)
    if num_doc_id_fields > 0 or num_usr_id_fields > 0:
        assert hparams.vocab_file_for_id_ftr is not None, \
            "Must provide vocab_field_for_id_ftr arg when id features are provided"

    # find vocab size, pad id from vocab file
    vocab_table = vocab_utils.read_vocab(hparams.vocab_file)
    hparams.add_hparam("vocab_size", len(vocab_table))
    hparams.pad_id = vocab_table[hparams.PAD]

    # find vocab size, pad id from vocab file for id features
    if hparams.vocab_file_for_id_ftr is not None:
        vocab_table_for_id_ftr = vocab_utils.read_vocab(
            hparams.vocab_file_for_id_ftr)
        hparams.add_hparam("vocab_size_for_id_ftr",
                           len(vocab_table_for_id_ftr))
        hparams.pad_id_for_id_ftr = vocab_table_for_id_ftr[
            hparams.PAD_FOR_ID_FTR]

    # if there is a bert config, check compatibility between bert parameters and existing parameters
    if hparams.bert_config_file:
        hparams.bert_config = modeling.BertConfig.from_json_file(
            hparams.bert_config_file)
        assert hparams.bert_config.vocab_size == hparams.vocab_size

    # Regex pattern whose matches get whitespace added before and after; used for processing text fields.
    tok2regex_pattern = {'plain': None, 'punct': r'(\pP)'}
    hparams.regex_replace_pattern = tok2regex_pattern[hparams.tokenization]

    # if not using cnn models, then disable cnn parameters
    if hparams.ftr_ext != 'cnn':
        hparams.filter_window_sizes = [0]

    assert hparams.pmetric is not None, "Please set your primary evaluation metric using --pmetric option"
    assert hparams.pmetric != 'confusion_matrix', 'confusion_matrix cannot be used as primary evaluation metric.'

    # Set all relevant evaluation metrics
    all_metrics = hparams.all_metrics if hparams.all_metrics else [
        hparams.pmetric
    ]
    assert hparams.pmetric in all_metrics, "pmetric must be within all_metrics"
    force_set_hparam(hparams, "all_metrics", all_metrics)

    # lambda rank
    if hparams.lambda_metric is not None and hparams.lambda_metric == 'ndcg':
        setattr(hparams, 'lambda_metric', {'metric': 'ndcg', 'topk': 10})
    else:
        setattr(hparams, 'lambda_metric', None)
    # feature normalization
    if hparams.std_file:
        # read normalization file
        print('read normalization file')
        ftr_mean, ftr_std = _load_ftr_mean_std(hparams.std_file)
        hparams.add_hparam('ftr_mean', np.array(ftr_mean, dtype=np.float32))
        hparams.add_hparam('ftr_std', np.array(ftr_std, dtype=np.float32))

    if hparams.explicit_empty:
        assert hparams.ftr_ext == 'cnn', 'explicit_empty will only be True when ftr_ext is cnn'

    # Checking hparam keep_checkpoint_max: must be >= 0
    if hparams.keep_checkpoint_max:
        assert hparams.keep_checkpoint_max >= 0

    # Classification task
    if hparams.num_classes > 1:
        # For classification tasks, restrict pmetric to be accuracy and use accuracy and confusion_matrix as metrics.
        hparams.pmetric = 'accuracy'
        hparams.all_metrics = ['accuracy', 'confusion_matrix']

    # L1 and L2 scale must be non-negative values
    assert hparams.l1 is None or hparams.l1 >= 0, "l1 scale must be non-negative"
    assert hparams.l2 is None or hparams.l2 >= 0, "l2 scale must be non-negative"

    # Multi-task training: currently only support ranking tasks with both deep and wide features
    if hparams.task_ids:
        # Check related inputs for multi-task training
        assert 'wide_ftrs_sp_idx' not in hparams.feature_names, "multi-task with sparse features not supported"
        assert 'task_id' in hparams.feature_names, "task_id feature not found for multi-task training"

        # Parse task ids and weights from inputs and normalize the weights
        task_ids = hparams.task_ids
        raw_weights = hparams.task_weights if hparams.task_weights else [
            1.0
        ] * len(task_ids)
        task_weights = [float(wt) / sum(raw_weights)
                        for wt in raw_weights]  # Normalize task weights

        # Check size of task_ids and task_weights
        assert len(task_ids) == len(
            task_weights), "size of task IDs and weights must match"

        force_set_hparam(hparams, "task_weights", task_weights)

    return hparams
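
A quick worked illustration of the task-weight normalization performed above (the task ids and weights below are hypothetical):

# Hypothetical values to illustrate the normalization step in the multi-task branch above.
task_ids = [0, 5, 7]
raw_weights = [1.0, 2.0, 1.0]                                        # would default to [1.0] * len(task_ids) if unset
task_weights = [float(wt) / sum(raw_weights) for wt in raw_weights]  # [0.25, 0.5, 0.25]
assert len(task_ids) == len(task_weights)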
Example 7
class TestBertLayer(tf.test.TestCase, DataSetup):
    """Unit test for bert_layer.py"""
    # Bert setup
    bert_hub_layer = hub.KerasLayer(hub.resolve(DataSetup.bert_hub_url),
                                    trainable=True)

    bert_vocab_file = bert_hub_layer.resolved_object.vocab_file.asset_path.numpy(
    ).decode("utf-8")
    bert_vocab_table = read_vocab(bert_vocab_file)

    bert_PAD_ID = bert_vocab_table[DataSetup.PAD]
    bert_SEP_ID = bert_vocab_table[DataSetup.SEP]
    bert_CLS_ID = bert_vocab_table[DataSetup.CLS]
    bert_UNK_ID = bert_vocab_table[DataSetup.UNK]

    # SentencePiece setup
    sentencepiece_hub_layer = hub.KerasLayer(
        hub.resolve(DataSetup.libert_sp_hub_url))

    tokenizer_file = sentencepiece_hub_layer.resolved_object.tokenizer_file.asset_path.numpy(
    ).decode("utf-8")
    with tf.io.gfile.GFile(tokenizer_file, 'rb') as f_handler:
        sp_model = f_handler.read()

    sentencepiece_tokenizer = tf_text.SentencepieceTokenizer(model=sp_model,
                                                             out_type=tf.int32)
    sentencepiece_vocab_tf_table = create_tf_vocab_from_sp_tokenizer(
        sp_tokenizer=sentencepiece_tokenizer, num_oov_buckets=1)

    sentencepiece_PAD_ID = sentencepiece_vocab_tf_table.lookup(
        tf.constant(DataSetup.PAD)).numpy()
    sentencepiece_SEP_ID = sentencepiece_vocab_tf_table.lookup(
        tf.constant(DataSetup.SEP)).numpy()
    sentencepiece_CLS_ID = sentencepiece_vocab_tf_table.lookup(
        tf.constant(DataSetup.CLS)).numpy()
    sentencepiece_UNK_ID = sentencepiece_vocab_tf_table.lookup(
        tf.constant(DataSetup.UNK)).numpy()

    # Space setup
    space_hub_layer = hub.KerasLayer(
        hub.resolve(DataSetup.libert_space_hub_url))

    tokenizer_file = space_hub_layer.resolved_object.tokenizer_file.asset_path.numpy(
    ).decode("utf-8")
    space_vocab = read_vocab(tokenizer_file)

    space_PAD_ID = space_vocab[DataSetup.PAD]
    space_SEP_ID = space_vocab[DataSetup.SEP]
    space_CLS_ID = space_vocab[DataSetup.CLS]
    space_UNK_ID = space_vocab[DataSetup.UNK]

    # Hyperparameters setup
    num_units = 16
    pad_id = 0
    num_doc_fields = 2

    min_len = 3
    max_len = 8

    layer = bert_layer.BertLayer(num_units, DataSetup.CLS, DataSetup.SEP,
                                 DataSetup.PAD, DataSetup.UNK, min_len,
                                 max_len, DataSetup.bert_hub_url)

    def testBertLayer(self):
        """Test Bert layer """
        for hub_url in [self.bert_hub_url]:
            self._testBertLayer(hub_url)

    def _testBertLayer(self, hub_url):
        query = self.query

        doc_fields = [self.ranking_doc_field1, self.ranking_doc_field2]
        user_fields = [query, query, query]

        query_ftrs, doc_ftrs, user_ftrs = self.layer(
            {
                InputFtrType.QUERY_COLUMN_NAME: query,
                InputFtrType.DOC_TEXT_COLUMN_NAMES: doc_fields,
                InputFtrType.USER_TEXT_COLUMN_NAMES: user_fields
            }, False)

        text_ftr_size = self.num_units

        self.assertEqual(text_ftr_size, self.layer.text_ftr_size)
        self.assertAllEqual(query_ftrs.shape, [2, self.layer.text_ftr_size])
        self.assertAllEqual(doc_ftrs.shape,
                            [2, 3, 2, self.layer.text_ftr_size])
        self.assertAllEqual(user_ftrs.shape, [2, 3, self.layer.text_ftr_size])
        # 1st query, 2nd doc, 2nd field should be the same as 2nd query, 1st doc, 2nd field
        self.assertAllEqual(doc_ftrs[0, 1, 1], doc_ftrs[1, 0, 1])
        # 1st query, 1st doc, 1st field should be the same as 1st query, 1st doc, 2nd field
        self.assertAllEqual(doc_ftrs[0, 0, 0], doc_ftrs[0, 0, 1])
        # 1st query, 2nd doc, 1st field should NOT be the same as 1st query, 2nd doc, 2nd field
        self.assertNotAllClose(doc_ftrs[0, 1, 0], doc_ftrs[0, 1, 1])

    def testGetInputIds(self):
        """Tests get_input_ids() """
        query = [[1, 2, 3], [2, 4, 3]]
        doc_field1 = [[[1, 2, 3, 0], [2, 4, 3, 1], [0, 0, 0, 0]],
                      [[2, 4, 3, 1], [1, 3, 3, 1], [1, 3, 3, 1]]]
        doc_field2 = [[[1, 2, 3, 0], [2, 4, 3, 1], [0, 0, 0, 0]],
                      [[20, 5, 3, 1], [5, 6, 1, 1], [5, 6, 1, 1]]]
        query = tf.constant(query, dtype=tf.int32)
        doc_field1 = tf.constant(doc_field1, dtype=tf.int32)
        doc_field2 = tf.constant(doc_field2, dtype=tf.int32)
        doc_fields = [doc_field1, doc_field2]
        user_fields = None

        max_text_len, max_text_len_array = bert_layer.BertLayer.get_input_max_len(
            query, doc_fields, user_fields)
        bert_input_ids = bert_layer.BertLayer.get_bert_input_ids(
            query, doc_fields, user_fields, self.pad_id, max_text_len,
            max_text_len_array)

        # Check bert input ids
        self.assertAllEqual(
            bert_input_ids,
            [[1, 2, 3, 0], [2, 4, 3, 0], [1, 2, 3, 0], [2, 4, 3, 1],
             [0, 0, 0, 0], [2, 4, 3, 1], [1, 3, 3, 1], [1, 3, 3, 1],
             [1, 2, 3, 0], [2, 4, 3, 1], [0, 0, 0, 0], [20, 5, 3, 1],
             [5, 6, 1, 1], [5, 6, 1, 1]])

    def testPreprocessQuery(self):
        """Tests _preprocess_query function of bert layer"""
        query = tf.constant(['batch 1 user 1 build', 'batch 2 user 2 word'],
                            dtype=tf.string)

        expected = tf.constant(
            [[
                self.CLS_ID, self.bert_UNK_ID, self.bert_UNK_ID,
                self.bert_UNK_ID, self.bert_UNK_ID, 4, 2
            ],
             [
                 self.CLS_ID, self.bert_UNK_ID, self.bert_UNK_ID,
                 self.bert_UNK_ID, self.bert_UNK_ID, 5, 2
             ]],
            dtype=tf.int32)

        self.assertAllEqual(expected, self.layer._preprocess_query(query))

    def testPreprocessUsr(self):
        """Tests _preprocess_user function of bert layer"""
        user_fields = [
            tf.constant(['batch 1 user 1 build', 'batch 2 user 2 word'],
                        dtype=tf.string)
        ]

        expected = [
            tf.constant([[
                self.CLS_ID, self.bert_UNK_ID, self.bert_UNK_ID,
                self.bert_UNK_ID, self.bert_UNK_ID, 4, 2
            ],
                         [
                             self.CLS_ID, self.bert_UNK_ID, self.bert_UNK_ID,
                             self.bert_UNK_ID, self.bert_UNK_ID, 5, 2
                         ]],
                        dtype=tf.int32)
        ]

        self.assertAllEqual(expected, self.layer._preprocess_user(user_fields))

    def testPreprocessDoc(self):
        """Tests _preprocess_doc function of bert layer"""
        doc_fields = [
            tf.constant([['batch 1 doc 1 build', 'batch 1 doc 2'],
                         ['batch 2 doc 1 word', 'batch 2 doc 2']],
                        dtype=tf.string)
        ]

        expected = [
            tf.constant([[[
                self.CLS_ID, self.bert_UNK_ID, self.bert_UNK_ID,
                self.bert_UNK_ID, self.bert_UNK_ID, 4, 2
            ],
                          [
                              self.CLS_ID, self.bert_UNK_ID, self.bert_UNK_ID,
                              self.bert_UNK_ID, self.bert_UNK_ID, 2, 3
                          ]],
                         [[
                             self.CLS_ID, self.bert_UNK_ID, self.bert_UNK_ID,
                             self.bert_UNK_ID, self.bert_UNK_ID, 5, 2
                         ],
                          [
                              self.CLS_ID, self.bert_UNK_ID, self.bert_UNK_ID,
                              self.bert_UNK_ID, self.bert_UNK_ID, 2, 3
                          ]]],
                        dtype=tf.int32)
        ]

        self.assertAllEqual(expected, self.layer._preprocess_doc(doc_fields))

    def testBertPreprocessLayerWordPiece(self):
        """Tests BertPreprocessLayer with wordpiece tokenizer"""

        preprocess_layer = bert_layer.BertPreprocessLayer(
            self.bert_hub_layer, self.max_len, self.min_len, self.CLS,
            self.SEP, self.PAD, self.UNK)

        sentences = tf.constant(
            ['test sent1', 'build build build build sent2'])

        expected = tf.constant([[
            self.bert_CLS_ID, 8, self.bert_UNK_ID, self.bert_SEP_ID,
            self.bert_PAD_ID, self.bert_PAD_ID, self.bert_PAD_ID
        ], [self.bert_CLS_ID, 4, 4, 4, 4, self.bert_UNK_ID, self.bert_SEP_ID]],
                               dtype=tf.int32)
        outputs = preprocess_layer(sentences)
        self.assertAllEqual(expected, outputs)

    def testBertPreprocessLayerSentencePiece(self):
        """Tests BertPreprocessLayer with sentencepiece tokenizer"""

        preprocess_layer = bert_layer.BertPreprocessLayer(
            self.sentencepiece_hub_layer, self.max_len, self.min_len, self.CLS,
            self.SEP, self.PAD, self.UNK)

        sentences = tf.constant(
            ['TEST sent1', 'build build build build sent2'])

        expected = tf.constant([[
            self.sentencepiece_CLS_ID, 557, 4120, 29900,
            self.sentencepiece_SEP_ID, self.sentencepiece_PAD_ID,
            self.sentencepiece_PAD_ID, self.sentencepiece_PAD_ID
        ],
                                [
                                    self.sentencepiece_CLS_ID, 671, 671, 671,
                                    671, 4120, 29904, self.sentencepiece_SEP_ID
                                ]],
                               dtype=tf.int32)

        outputs = preprocess_layer(sentences)
        self.assertAllEqual(expected, outputs)

    def testBertPreprocessLayerSpace(self):
        """Tests BertPreprocessLayer with space tokenizer"""

        preprocess_layer = bert_layer.BertPreprocessLayer(
            self.space_hub_layer, self.max_len, self.min_len, self.CLS,
            self.SEP, self.PAD, self.UNK)

        sentences = tf.constant(
            ['test sent1', 'build build build build sent2'])

        expected = tf.constant([[
            self.space_CLS_ID, 8, self.space_UNK_ID, self.space_SEP_ID,
            self.space_PAD_ID, self.space_PAD_ID, self.space_PAD_ID
        ], [
            self.space_CLS_ID, 4, 4, 4, 4, self.space_UNK_ID, self.space_SEP_ID
        ]],
                               dtype=tf.int32)

        outputs = preprocess_layer(sentences)
        self.assertAllEqual(expected, outputs)

    def testBertPreprocessLayerAdjustLen(self):
        """Tests adjust_len function of BertPreprocessLayer"""

        sentences = tf.constant(
            ['test sent1', 'build build build build sent2'])

        min_len = 12
        max_len = 16

        layer = bert_layer.BertPreprocessLayer(self.bert_hub_layer, max_len,
                                               min_len, self.CLS, self.SEP,
                                               self.PAD, self.UNK)

        outputs = layer(sentences)
        shape = tf.shape(outputs)

        self.assertAllEqual(shape, tf.constant([2, 12]))

        min_len = 0
        max_len = 1

        layer = bert_layer.BertPreprocessLayer(self.bert_hub_layer, max_len,
                                               min_len, self.CLS, self.SEP,
                                               self.PAD, self.UNK)

        outputs = layer(sentences)
        shape = tf.shape(outputs)

        self.assertAllEqual(shape, tf.constant([2, 1]))
Example 8
class TestData(tf.test.TestCase, DataSetup):
    """Unit test for data_fn."""
    _, vocab_tf_table = vocab_utils.read_tf_vocab(DataSetup.vocab_file, '[UNK]')
    vocab_table = vocab_utils.read_vocab(DataSetup.vocab_file)

    CLS = '[CLS]'
    PAD = '[PAD]'
    SEP = '[SEP]'

    PAD_ID = vocab_table[PAD]
    SEP_ID = vocab_table[SEP]
    CLS_ID = vocab_table[CLS]

    nums_sparse_ftrs = [20]

    def testRankingInputFnBuilderTfrecord(self):
        """ Tests function input_fn_builder() """
        one_device_strategy = distribution_utils.get_distribution_strategy('one_device', num_gpus=0)
        feature_type2name_list = [
            # Contains sparse features
            {InputFtrType.LABEL_COLUMN_NAME: 'label',
             InputFtrType.QUERY_COLUMN_NAME: 'query',
             InputFtrType.DOC_TEXT_COLUMN_NAMES: ['doc_headline', 'doc_title'],
             InputFtrType.USER_TEXT_COLUMN_NAMES: ['user_headline', 'user_title'],
             InputFtrType.DOC_ID_COLUMN_NAMES: ['doc_headline_id'],
             InputFtrType.USER_ID_COLUMN_NAMES: ['user_headline_id'],
             InputFtrType.DENSE_FTRS_COLUMN_NAMES: ['dense_ftrs'],
             InputFtrType.SPARSE_FTRS_COLUMN_NAMES: ['sparse_ftrs'],
             InputFtrType.WEIGHT_COLUMN_NAME: 'weight'
             },
            # No sparse features
            {InputFtrType.LABEL_COLUMN_NAME: 'label',
             InputFtrType.QUERY_COLUMN_NAME: 'query',
             InputFtrType.DOC_TEXT_COLUMN_NAMES: ['doc_headline', 'doc_title'],
             InputFtrType.USER_TEXT_COLUMN_NAMES: ['user_headline', 'user_title'],
             InputFtrType.DOC_ID_COLUMN_NAMES: ['doc_headline_id'],
             InputFtrType.USER_ID_COLUMN_NAMES: ['user_headline_id'],
             InputFtrType.DENSE_FTRS_COLUMN_NAMES: ['dense_ftrs'],
             InputFtrType.WEIGHT_COLUMN_NAME: 'weight'
             },
            # Sparse features only
            {InputFtrType.LABEL_COLUMN_NAME: 'label',
             InputFtrType.SPARSE_FTRS_COLUMN_NAMES: ['sparse_ftrs']}
        ]
        strategy_list = [None, one_device_strategy]

        for strategy, feature_type2name in product(strategy_list, feature_type2name_list):
            self._testRankingInputFnBuilderTfrecord(strategy, feature_type2name)

    def _testRankingInputFnBuilderTfrecord(self, strategy, feature_type2name):
        """ Tests function input_fn_builder() for given strategy """
        data_dir = self.ranking_data_dir
        feature_name2num = {'dense_ftrs': 2, 'sparse_ftrs': self.nums_sparse_ftrs[0]}

        def _input_fn_tfrecord(ctx):
            return data_fn.input_fn_tfrecord(input_pattern=data_dir,
                                             batch_size=batch_size,
                                             mode=tf.estimator.ModeKeys.EVAL,
                                             feature_type2name=feature_type2name,
                                             feature_name2num=feature_name2num,
                                             input_pipeline_context=ctx)

        batch_size = 2
        if strategy is not None:
            dataset = strategy.distribute_datasets_from_function(_input_fn_tfrecord)
        else:
            dataset = _input_fn_tfrecord(None)

        # Make iterator
        for features, label in dataset:
            for ftr_type, ftr_name_lst in iterate_items_with_list_val(feature_type2name):
                if ftr_type in (InputFtrType.LABEL_COLUMN_NAME, InputFtrType.WEIGHT_COLUMN_NAME, InputFtrType.UID_COLUMN_NAME):
                    self.assertLen(ftr_name_lst, 1, f'Length for current ftr type ({ftr_type}) should be 1')
                    ftr_name = ftr_name_lst[0]
                    self.assertIn(ftr_name, label)
                    continue

                for ftr_name in ftr_name_lst:
                    self.assertIn(ftr_name, features)
                    # First dimension of data should be batch_size
                    self.assertTrue(features[ftr_name].shape[0] == batch_size)

            weight_ftr_name = feature_type2name.get(InputFtrType.WEIGHT_COLUMN_NAME, constant.Constant()._DEFAULT_WEIGHT_FTR_NAME)
            self.assertAllEqual(tf.shape(label[weight_ftr_name]), [batch_size])

            uid_ftr_name = feature_type2name.get(InputFtrType.UID_COLUMN_NAME, constant.Constant()._DEFAULT_UID_FTR_NAME)
            self.assertAllEqual(tf.shape(label[uid_ftr_name]), [batch_size])

            # First dimension of data should be batch_size
            self.assertEqual(label['label'].shape[0], batch_size)

            if InputFtrType.DOC_TEXT_COLUMN_NAMES in feature_type2name:
                self.assertAllEqual(features['doc_title'],
                                    tf.constant(
                                        [["document title 1", b"title 2 ?", b"doc title 3 ?", b"doc title 4 ?"],
                                         ["document title 1", b"title 2 ?", b"doc title 3 ?", b""]]
                                    ))

            if InputFtrType.DOC_ID_COLUMN_NAMES in feature_type2name:
                self.assertAllEqual(features['doc_headline_id'],
                                    tf.constant(
                                        [[b"document headline id 1", b"headline id 2 ?", b"doc headline id 3 ?", b"doc headline id 4 ?"],
                                         [b"document headline id 1", b"headline id 2 ?", b"doc headline id 3 ?", b""]]
                                    ))

            if InputFtrType.USER_TEXT_COLUMN_NAMES in feature_type2name:
                self.assertAllEqual(features['user_title'],
                                    tf.constant(
                                        [b"user title", b"user title"]
                                    ))
            if InputFtrType.USER_ID_COLUMN_NAMES in feature_type2name:
                self.assertAllEqual(features['user_headline_id'],
                                    tf.constant(
                                        [b"user headline id", b"user headline id"]
                                    ))

            if InputFtrType.DENSE_FTRS_COLUMN_NAMES in feature_type2name:
                self.assertAllEqual(features['dense_ftrs'],
                                    tf.constant(
                                        [[[23.0, 14.0], [44.0, -1.0], [22.0, 19.0], [22.0, 19.0]],
                                         [[23.0, 14.0], [44.0, -1.0], [22.0, 19.0], [0.0, 0.0]]]
                                    ))

            if InputFtrType.SPARSE_FTRS_COLUMN_NAMES in feature_type2name:
                self.assertAllEqual(tf.sparse.to_dense(features['sparse_ftrs']),
                                    tf.sparse.to_dense(tf.SparseTensor(indices=[[0, 0, 1],
                                                                                [0, 0, 5],
                                                                                [0, 1, 0],
                                                                                [0, 2, 2],
                                                                                [0, 3, 8],
                                                                                [1, 0, 1],
                                                                                [1, 0, 5],
                                                                                [1, 1, 0],
                                                                                [1, 2, 2]],
                                                                       values=[1., 5., 7., 12., -8., 1., 5., 7., 12.],
                                                                       dense_shape=[batch_size, 4, self.nums_sparse_ftrs[0]]))
                                    )

            # Only check the first batch
            break

    def testClassificationInputFnBuilderTfrecord(self):
        """Test classification input reader in eval mode"""
        data_dir = self.cls_data_dir

        feature_type2name = {
            InputFtrType.LABEL_COLUMN_NAME: 'label',
            InputFtrType.DOC_TEXT_COLUMN_NAMES: ['query_text'],
            InputFtrType.USER_TEXT_COLUMN_NAMES: ['user_headline'],
            InputFtrType.DENSE_FTRS_COLUMN_NAMES: 'dense_ftrs',
        }
        feature_name2num = {
            'dense_ftrs': 8
        }

        batch_size = 2
        dataset = data_fn.input_fn_tfrecord(input_pattern=data_dir,
                                            batch_size=batch_size,
                                            mode=tf.estimator.ModeKeys.EVAL,
                                            task_type=TaskType.CLASSIFICATION,
                                            feature_type2name=feature_type2name,
                                            feature_name2num=feature_name2num)

        for features, label in dataset:
            # First dimension of data should be batch_size
            for ftr_type, ftr_name_lst in iterate_items_with_list_val(feature_type2name):
                if ftr_type in (InputFtrType.LABEL_COLUMN_NAME, InputFtrType.WEIGHT_COLUMN_NAME, InputFtrType.UID_COLUMN_NAME):
                    self.assertLen(ftr_name_lst, 1, f'Length for current ftr type ({ftr_type}) should be 1')
                    ftr_name = ftr_name_lst[0]
                    self.assertIn(ftr_name, label)
                    continue
                for ftr_name in ftr_name_lst:
                    self.assertIn(ftr_name, features)
                    self.assertEqual(features[ftr_name].shape[0], batch_size)

            weight_ftr_name = feature_type2name.get(InputFtrType.WEIGHT_COLUMN_NAME, constant.Constant()._DEFAULT_WEIGHT_FTR_NAME)
            self.assertAllEqual(tf.shape(label[weight_ftr_name]), [batch_size])

            uid_ftr_name = feature_type2name.get(InputFtrType.UID_COLUMN_NAME, constant.Constant()._DEFAULT_UID_FTR_NAME)
            self.assertAllEqual(tf.shape(label[uid_ftr_name]), [batch_size])

            self.assertAllEqual(label['label'].shape, [batch_size])

    def testBinaryClassificationInputFnBuilderTfrecord(self):
        """Test binary classification input reader """
        data_dir = self.binary_cls_data_dir

        feature_type2name = {
            InputFtrType.LABEL_COLUMN_NAME: 'label',
            InputFtrType.SPARSE_FTRS_COLUMN_NAMES: ['sparse_ftrs'],
            InputFtrType.SHALLOW_TOWER_SPARSE_FTRS_COLUMN_NAMES: ['shallow_tower_sparse_ftrs', 'sparse_ftrs']
        }
        feature_name2num = {
            'sparse_ftrs': 20,
            'shallow_tower_sparse_ftrs': 20
        }

        batch_size = 2
        dataset = data_fn.input_fn_tfrecord(input_pattern=data_dir,
                                            batch_size=batch_size,
                                            mode=tf.estimator.ModeKeys.EVAL,
                                            task_type=TaskType.BINARY_CLASSIFICATION,
                                            feature_type2name=feature_type2name,
                                            feature_name2num=feature_name2num
                                            )

        for features, label in dataset:
            # First dimension of data should be batch_size
            for ftr_type, ftr_name_lst in iterate_items_with_list_val(feature_type2name):
                if ftr_type in (InputFtrType.LABEL_COLUMN_NAME, InputFtrType.WEIGHT_COLUMN_NAME, InputFtrType.UID_COLUMN_NAME):
                    self.assertLen(ftr_name_lst, 1, f'Length for current ftr type ({ftr_type}) should be 1')
                    ftr_name = ftr_name_lst[0]
                    self.assertIn(ftr_name, label)
                    continue
                for ftr_name in ftr_name_lst:
                    self.assertIn(ftr_name, features)
                    self.assertEqual(features[ftr_name].shape[0], batch_size)

            weight_ftr_name = feature_type2name.get(InputFtrType.WEIGHT_COLUMN_NAME, constant.Constant()._DEFAULT_WEIGHT_FTR_NAME)
            self.assertAllEqual(tf.shape(label[weight_ftr_name]), [batch_size])

            uid_ftr_name = feature_type2name.get(InputFtrType.UID_COLUMN_NAME, constant.Constant()._DEFAULT_UID_FTR_NAME)
            self.assertAllEqual(tf.shape(label[uid_ftr_name]), [batch_size])

            self.assertAllEqual(label['label'].shape, [batch_size])
            self.assertAllEqual(tf.sparse.to_dense(features['sparse_ftrs']),
                                tf.sparse.to_dense(
                                    tf.SparseTensor(indices=[[0, 0],
                                                             [0, 2],
                                                             [0, 7],
                                                             [1, 0],
                                                             [1, 2],
                                                             [1, 7]],
                                                    values=[1, 0, 7, 1, 0, 7],
                                                    dense_shape=[batch_size, self.nums_sparse_ftrs[0]])
                                )
                                )

            # Only check first batch
            break

    def testRankingMultitaskInputFnBuilderTfrecord(self):
        """Test additional input from multitask training in eval mode"""
        data_dir = self.ranking_data_dir

        # Test minimum features required for multitask jobs
        feature_type2name = {
            InputFtrType.LABEL_COLUMN_NAME: 'label',
            InputFtrType.QUERY_COLUMN_NAME: 'query',
            InputFtrType.DOC_TEXT_COLUMN_NAMES: ['doc_headline', 'doc_title'],
            InputFtrType.USER_TEXT_COLUMN_NAMES: ['user_headline', 'user_title'],
            InputFtrType.DOC_ID_COLUMN_NAMES: ['doc_headline_id'],
            InputFtrType.USER_ID_COLUMN_NAMES: ['user_headline_id'],
            InputFtrType.DENSE_FTRS_COLUMN_NAMES: ['dense_ftrs'],
            InputFtrType.WEIGHT_COLUMN_NAME: 'weight',
            InputFtrType.TASK_ID_COLUMN_NAME: 'task_id_field'
        }
        feature_name2num = {
            'dense_ftrs': 2
        }

        batch_size = 5
        dataset = data_fn.input_fn_tfrecord(input_pattern=data_dir,
                                            batch_size=batch_size,
                                            mode=tf.estimator.ModeKeys.EVAL,
                                            feature_type2name=feature_type2name,
                                            feature_name2num=feature_name2num)

        for features, label in dataset:
            # First dimension of data should be batch_size
            for ftr_type, ftr_name_lst in iterate_items_with_list_val(feature_type2name):
                if ftr_type in (InputFtrType.LABEL_COLUMN_NAME, InputFtrType.WEIGHT_COLUMN_NAME, InputFtrType.UID_COLUMN_NAME):
                    self.assertLen(ftr_name_lst, 1, f'Length for current ftr type ({ftr_type}) should be 1')
                    ftr_name = ftr_name_lst[0]
                    self.assertIn(ftr_name, label)
                    continue
                for ftr_name in ftr_name_lst:
                    self.assertIn(ftr_name, features)
                    self.assertEqual(features[ftr_name].shape[0], batch_size)

            weight_ftr_name = feature_type2name.get(InputFtrType.WEIGHT_COLUMN_NAME, constant.Constant()._DEFAULT_WEIGHT_FTR_NAME)
            self.assertAllEqual(tf.shape(label[weight_ftr_name]), [batch_size])

            uid_ftr_name = feature_type2name.get(InputFtrType.UID_COLUMN_NAME, constant.Constant()._DEFAULT_UID_FTR_NAME)
            self.assertAllEqual(tf.shape(label[uid_ftr_name]), [batch_size])

            # First dimension of data should be batch_size
            self.assertEqual(label['label'].shape[0], batch_size)

            task_ids = features['task_id_field']

            # Check task_id dimension size
            self.assertEqual(len(task_ids.shape), 1)

            # Check task_id value in the sample data
            for t_id in task_ids:
                self.assertAllEqual(t_id, 5)