コード例 #1
0
 def get_detext_inference_job(self):
     """ Get the DeText inference job.

     For LR models the inference job is included in the train job; this
     job exists only for DeText model inference.

     :return: (job_type, job_name, "", job_params) — an inference job
         inferencing training and validation data, where job_params is a
         (gdmix params, DetextArg) tuple.
     """
     # Build the two param containers as named values. The original code
     # assigned them via an implicit tuple spanning a line continuation
     # (`params = replace(...), DetextArg(...)`), which read like an
     # unfinished call and hid the fact that `params` is a 2-tuple.
     gdmix_params = replace(self.gdmix_params, action=ACTION_INFERENCE)
     detext_params = DetextArg(**self.fixed_effect_config,
                               out_dir=self.output_model_dir)
     return GDMIX_TFJOB, f"{self.fixed_effect_name}-tf-inference", "", (
         gdmix_params, detext_params)
コード例 #2
0
 def get_train_job(self):
     """ Get the tfjob training job.

     :return: (job_type, job_name, "", job_params) where job_params are
         typed param containers supported by smart-arg.
     :raises ValueError: if self.model_type is not a supported model type.
     """
     # Reject unsupported model types up front (guard clause).
     if self.model_type not in (LOGISTIC_REGRESSION, DETEXT):
         raise ValueError(f'unsupported model_type: {self.model_type}')
     # Each model type uses its own typed param container and output-dir key.
     if self.model_type == LOGISTIC_REGRESSION:
         model_params = FixedLRParams(**self.fixed_effect_config,
                                      output_model_dir=self.output_model_dir)
     else:
         model_params = DetextArg(**self.fixed_effect_config,
                                  out_dir=self.output_model_dir)
     job_name = f"{self.fixed_effect_name}-tf-train"
     return GDMIX_TFJOB, job_name, "", (self.gdmix_params, model_params)
コード例 #3
0
    def test_detext_model_fixed_effect_workflow_generator(self):
        """Verify the fixed-effect workflow generated for a DeText config.

        Checks that the generator emits exactly three jobs — train,
        inference, compute-metric — and that each job's (type, name,
        class, params) tuple matches the expected values built from
        self.detext_config_obj.
        """
        # return  # skip test for now
        fe_workflow = FixedEffectWorkflowGenerator(self.detext_config_obj)
        # check sequence
        seq = fe_workflow.get_job_sequence()
        self.assertEqual(len(seq), 3)
        # check job properties: jobs are expected in train -> inference
        # -> compute-metric order.
        actual_train_job = seq[0]
        actual_inference_job = seq[1]
        actual_compute_metric_job = seq[2]

        # Expected train-job params: a (gdmix Params, DetextArg) pair.
        # The DetextArg literal spells out every field default-plus-override
        # so any drift in generator defaults fails this assertion.
        expected_train_job_param = (
            Params(uid_column_name='uid', weight_column_name='weight', label_column_name='response', prediction_score_column_name='predictionScore',
                   prediction_score_per_coordinate_column_name='predictionScorePerCoordinate', action='train', stage='fixed_effect', model_type='detext',
                   training_score_dir='detext-training/global/training_scores', validation_score_dir='detext-training/global/validation_scores'),
            DetextArg(feature_names=['label', 'doc_query', 'uid', 'wide_ftrs_sp_idx', 'wide_ftrs_sp_val'], ftr_ext='cnn', num_units=64,
                      sp_emb_size=1, num_hidden=[0], num_wide=0, num_wide_sp=45, use_deep=True, elem_rescale=True, use_doc_projection=False,
                      use_usr_projection=False, ltr_loss_fn='pointwise', emb_sim_func=['inner'], num_classes=1, filter_window_sizes=[3], num_filters=50,
                      explicit_empty=False, use_bert_dropout=False, unit_type='lstm', num_layers=1,
                      num_residual_layers=0, forget_bias=1.0, rnn_dropout=0.0, bidirectional=False, normalized_lm=False, optimizer='bert_adam',
                      max_gradient_norm=1.0, learning_rate=0.002, num_train_steps=1000, num_warmup_steps=0, train_batch_size=64,
                      test_batch_size=64, l1=None, l2=None, train_file='movieLens/detext/trainingData/train_data.tfrecord',
                      dev_file='movieLens/detext/validationData/test_data.tfrecord', test_file='movieLens/detext/validationData/test_data.tfrecord',
                      out_dir='detext-training/global/models', max_len=16, min_len=3, vocab_file='movieLens/detext/vocab.txt', we_file=None,
                      we_trainable=True, PAD='[PAD]', SEP='[SEP]', CLS='[CLS]', UNK='[UNK]', MASK='[MASK]', we_file_for_id_ftr=None,
                      we_trainable_for_id_ftr=True, PAD_FOR_ID_FTR='[PAD]', UNK_FOR_ID_FTR='[UNK]', random_seed=1234, steps_per_stats=10, num_eval_rounds=None,
                      steps_per_eval=100, keep_checkpoint_max=1, init_weight=0.1, pmetric='auc', all_metrics=['auc'], score_rescale=None,
                      add_first_dim_for_query_placeholder=False, add_first_dim_for_usr_placeholder=False, tokenization='punct', resume_training=False,
                      use_tfr_loss=False, tfr_loss_fn='softmax_loss'))

        expected_train_job = ('gdmix_tfjob', 'global-tf-train', '', expected_train_job_param)
        self.assertEqual(expected_train_job, actual_train_job)

        # Inference job reuses the train params with only the action (and
        # score dirs, here unchanged) replaced; the DetextArg is identical.
        expected_inference_job_param = (replace(expected_train_job_param[0],
                                                action=ACTION_INFERENCE,
                                                training_score_dir="detext-training/global/training_scores",
                                                validation_score_dir="detext-training/global/validation_scores"), expected_train_job_param[1])
        expected_inference_job = 'gdmix_tfjob', 'global-tf-inference', '', expected_inference_job_param
        self.assertEqual(expected_inference_job, actual_inference_job)

        # NOTE(review): the '\\--metricsInputDir' key carries a literal
        # backslash before the flag — presumably mirroring what the
        # generator emits; confirm this is intentional and not a typo.
        expected_compute_metric_job = (
            'gdmix_sparkjob',
            'global-compute-metric',
            'com.linkedin.gdmix.evaluation.AreaUnderROCCurveEvaluator',
            {'\\--metricsInputDir': 'detext-training/global/validation_scores',
             '--outputMetricFile': 'detext-training/global/metric',
             '--labelColumnName': 'response',
             '--predictionColumnName': 'predictionScore'})
        self.assertEqual(actual_compute_metric_job, expected_compute_metric_job)
コード例 #4
0
    def test_run_detext_bert(self):
        """Run run_detext end-to-end with a BERT feature extractor.

        Builds a DetextArg configured for a tiny BERT training run
        (4 steps, batch size 2) against the shared fixture data, runs it,
        and cleans up the output directory afterwards.
        """
        bert_out_dir = os.path.join(out_dir, "bert")
        # Small model/run sizes keep the end-to-end test fast.
        detext_args = DetextArg(ftr_ext='bert',
                                num_units=4,
                                num_wide=3,
                                ltr_loss_fn='softmax',
                                emb_sim_func=['inner', 'concat', 'diff'],
                                num_filters=50,
                                bert_config_file=bert_config,
                                use_bert_dropout=True,
                                optimizer='bert_adam',
                                learning_rate=0.002,
                                num_train_steps=4,
                                train_batch_size=2,
                                test_batch_size=2,
                                train_file=data,
                                dev_file=data,
                                test_file=data,
                                out_dir=bert_out_dir,
                                max_len=16,
                                vocab_file=vocab,
                                vocab_file_for_id_ftr=vocab,
                                steps_per_stats=1,
                                steps_per_eval=2,
                                keep_checkpoint_max=5,
                                feature_names=[
                                    'label', 'query', 'doc_completedQuery',
                                    'usr_headline', 'usr_skills',
                                    'usr_currTitles', 'usrId_currTitles',
                                    'docId_completedQuery', 'wide_ftrs', 'weight'
                                ],
                                pmetric='ndcg@10',
                                all_metrics=['precision@1', 'ndcg@10'])
        run_detext(detext_args)
        self._cleanUp(bert_out_dir)
コード例 #5
0
 def _parse_parameters(self, raw_model_parameters):
     """Parse raw argv-style model parameters into a DetextArg.

     Unknown arguments are tolerated rather than raising
     (error_on_unknown=False), so extra flags pass through silently.
     """
     parsed_args = DetextArg.__from_argv__(raw_model_parameters,
                                           error_on_unknown=False)
     return parsed_args