예제 #1
0
    def post(self):
        """Stores the frozen classifier model reported for a training job.

        The request body is parsed as a serialized
        TrainingJobResponsePayload proto. If its job_result passes
        validation, the model held in the job_result's
        'classifier_frozen_model' oneof is stored for the job and the job
        is marked complete.

        Raises:
            InvalidInputException. The job_result fails validation.
            InternalErrorException. The training job is already in the
                FAILED state, so it cannot transition to COMPLETE.
        """
        payload = (
            training_job_response_payload_pb2.TrainingJobResponsePayload())
        payload.ParseFromString(self.request.body)
        job_result = payload.job_result

        if not validate_job_result_message_proto(job_result):
            raise self.InvalidInputException

        job_id = job_result.job_id
        training_job = classifier_services.get_classifier_training_job_by_id(
            job_id)

        if training_job.status == feconf.TRAINING_JOB_STATUS_FAILED:
            # Send email to admin and admin-specified email recipients.
            # Other email recipients are specified on admin config page.
            email_manager.send_job_failure_email(job_id)
            raise self.InternalErrorException(
                'The current status of the job cannot transition to COMPLETE.')

        # Name of whichever field of the 'classifier_frozen_model' oneof is
        # populated; its value is the model proto to persist.
        frozen_model_field = job_result.WhichOneof('classifier_frozen_model')
        classifier_services.store_classifier_data(
            job_id, getattr(job_result, frozen_model_field))

        # Update status of the training job to 'COMPLETE'.
        classifier_services.mark_training_job_complete(job_id)

        return self.render_json({})
예제 #2
0
    def extract_request_message_vm_id_and_signature(self):
        """Builds OppiaML auth info from the incoming request payload.

        The request body is parsed as a serialized
        TrainingJobResponsePayload proto.

        Returns:
            OppiaMLAuthInfo. Auth info constructed from, in order: the
            payload's job_result re-serialized with SerializeToString()
            (bytes), the payload's vm_id, and the payload's signature.
        """
        payload_proto = (
            training_job_response_payload_pb2.TrainingJobResponsePayload())
        payload_proto.ParseFromString(self.request.body)
        signature = payload_proto.signature
        vm_id = payload_proto.vm_id
        # The signed message is the serialized job_result, not the whole
        # payload, because the payload itself carries the signature.
        return classifier_domain.OppiaMLAuthInfo(
            payload_proto.job_result.SerializeToString(), vm_id, signature)
예제 #3
0
    def setUp(self):
        """Creates an exploration with a pending classifier training job.

        Also builds two signed request payloads using the default VM id and
        shared secret:
        - self.payload_proto: a TrainingJobResponsePayload carrying a text
          classifier frozen model for the job.
        - self.payload_for_fetching_next_job_request: a dict with vm_id, an
          empty JSON message and its signature.
        """
        super(TrainedClassifierHandlerTests, self).setUp()

        self.exp_id = 'exp_id1'
        self.title = 'Testing Classifier storing'
        self.category = 'Test'
        yaml_path = os.path.join(
            feconf.TESTS_DATA_DIR, 'string_classifier_test.yaml')
        with python_utils.open_file(yaml_path, 'r') as yaml_file:
            self.yaml_content = yaml_file.read()
        self.signup(
            self.CURRICULUM_ADMIN_EMAIL, self.CURRICULUM_ADMIN_USERNAME)
        # NOTE(review): this address looks redacted/obfuscated; confirm the
        # intended moderator email.
        self.signup('*****@*****.**', 'mod')

        assets_list = []
        # ML classifiers are enabled while saving, presumably so that a
        # classifier training job is created for the exploration (it is
        # looked up below).
        with self.swap(feconf, 'ENABLE_ML_CLASSIFIERS', True):
            exp_services.save_new_exploration_from_yaml_and_assets(
                feconf.SYSTEM_COMMITTER_ID, self.yaml_content, self.exp_id,
                assets_list)
        self.exploration = exp_fetchers.get_exploration_by_id(self.exp_id)
        interaction_id = self.exploration.states['Home'].interaction.id
        classifier_mapping = feconf.INTERACTION_CLASSIFIER_MAPPING[
            interaction_id]
        self.algorithm_id = classifier_mapping['algorithm_id']
        self.algorithm_version = classifier_mapping['algorithm_version']

        self.classifier_data = {
            '_alpha': 0.1,
            '_beta': 0.001,
            '_prediction_threshold': 0.5,
            '_training_iterations': 25,
            '_prediction_iterations': 5,
            '_num_labels': 10,
            '_num_docs': 12,
            '_num_words': 20,
            '_label_to_id': {
                'text': 1
            },
            '_word_to_id': {
                'hello': 2
            },
            '_w_dp': [],
            '_b_dl': [],
            '_l_dp': [],
            '_c_dl': [],
            '_c_lw': [],
            '_c_l': [],
        }
        classifier_training_job = (
            classifier_services.get_classifier_training_job(
                self.exp_id, self.exploration.version, 'Home',
                self.algorithm_id))
        self.assertIsNotNone(classifier_training_job)
        self.job_id = classifier_training_job.job_id

        # TODO(pranavsid98): Replace the three commands below with
        # mark_training_job_pending after Giritheja's PR gets merged.
        classifier_training_job_model = (
            classifier_models.ClassifierTrainingJobModel.get(
                self.job_id, strict=False))
        # With strict=False the lookup presumably yields None for a missing
        # model; fail with a clear assertion rather than an AttributeError
        # on the status assignment below.
        self.assertIsNotNone(classifier_training_job_model)
        classifier_training_job_model.status = (
            feconf.TRAINING_JOB_STATUS_PENDING)
        classifier_training_job_model.update_timestamps()
        classifier_training_job_model.put()

        self.job_result = (training_job_response_payload_pb2.
                           TrainingJobResponsePayload.JobResult())
        self.job_result.job_id = self.job_id

        classifier_frozen_model = (
            text_classifier_pb2.TextClassifierFrozenModel())
        classifier_frozen_model.model_json = json.dumps(self.classifier_data)

        self.job_result.text_classifier.CopyFrom(classifier_frozen_model)

        self.payload_proto = (
            training_job_response_payload_pb2.TrainingJobResponsePayload())
        self.payload_proto.job_result.CopyFrom(self.job_result)
        self.payload_proto.vm_id = feconf.DEFAULT_VM_ID
        self.secret = feconf.DEFAULT_VM_SHARED_SECRET
        # The signature covers the serialized job_result only.
        self.payload_proto.signature = classifier_services.generate_signature(
            python_utils.convert_to_bytes(self.secret),
            python_utils.convert_to_bytes(
                self.payload_proto.job_result.SerializeToString()),
            self.payload_proto.vm_id)

        self.payload_for_fetching_next_job_request = {
            'vm_id': feconf.DEFAULT_VM_ID,
            'message': json.dumps({})
        }

        self.payload_for_fetching_next_job_request['signature'] = (
            classifier_services.generate_signature(
                python_utils.convert_to_bytes(self.secret),
                python_utils.convert_to_bytes(
                    self.payload_for_fetching_next_job_request['message']),
                self.payload_for_fetching_next_job_request['vm_id']))