def test_mark_training_jobs_failed(self):
    """Checks that mark_training_jobs_failed moves a PENDING job to FAILED.

    Also verifies that marking an already-FAILED job as failed again is
    rejected as an invalid status transition.
    """
    exp_id = u'1'
    state_name = 'Home'
    interaction_id = 'TextInput'
    classifier_mapping = feconf.INTERACTION_CLASSIFIER_MAPPING[interaction_id]
    job_id = self._create_classifier_training_job(
        classifier_mapping['algorithm_id'], interaction_id, exp_id, 1,
        datetime.datetime.utcnow(), [], state_name,
        feconf.TRAINING_JOB_STATUS_PENDING, {},
        classifier_mapping['algorithm_version'])

    def _fetch_status():
        # Re-read the job each time so the persisted status is observed.
        return classifier_services.get_classifier_training_job_by_id(
            job_id).status

    self.assertEqual(_fetch_status(), feconf.TRAINING_JOB_STATUS_PENDING)
    classifier_services.mark_training_jobs_failed([job_id])
    self.assertEqual(_fetch_status(), feconf.TRAINING_JOB_STATUS_FAILED)

    # FAILED -> FAILED is not an allowed status transition.
    expected_error = 'The status change %s to %s is not valid.' % (
        feconf.TRAINING_JOB_STATUS_FAILED, feconf.TRAINING_JOB_STATUS_FAILED)
    with self.assertRaisesRegexp(Exception, expected_error):
        classifier_services.mark_training_jobs_failed([job_id])
def test_store_classifier_data(self):
    """Checks that store_classifier_data replaces the stored frozen model.

    The job starts with an empty model JSON; after storing a new
    TextClassifierFrozenModel, the stored model JSON must match it.
    """
    exp_id = u'1'
    next_scheduled_check_time = datetime.datetime.utcnow()
    state_name = 'Home'
    interaction_id = 'TextInput'
    job_id = self._create_classifier_training_job(
        feconf.INTERACTION_CLASSIFIER_MAPPING['TextInput']['algorithm_id'],
        interaction_id, exp_id, 1, next_scheduled_check_time, [],
        state_name, feconf.TRAINING_JOB_STATUS_PENDING, {}, 1)

    def _load_model_json():
        # Fetch the job afresh, retrieve its classifier data from GCS and
        # decode the model JSON it carries.
        training_job = classifier_services.get_classifier_training_job_by_id(
            job_id)
        frozen_model = self._get_classifier_data_from_classifier_training_job(
            training_job)
        return json.loads(frozen_model.model_json)

    self.assertEqual(_load_model_json(), {})

    new_model = text_classifier_pb2.TextClassifierFrozenModel()
    new_model.model_json = json.dumps({'classifier_data': 'data'})
    classifier_services.store_classifier_data(job_id, new_model)

    self.assertDictEqual(_load_model_json(), {'classifier_data': 'data'})
def test_mark_training_job_pending(self):
    """Checks that mark_training_job_pending moves a NEW job to PENDING.

    Also verifies that a PENDING -> PENDING transition is rejected.
    """
    interaction_id = 'TextInput'
    algorithm_id = feconf.INTERACTION_CLASSIFIER_MAPPING[interaction_id][
        'algorithm_id']
    job_id = classifier_models.ClassifierTrainingJobModel.create(
        algorithm_id, interaction_id, u'1', 1, datetime.datetime.utcnow(),
        [], 'Home', feconf.TRAINING_JOB_STATUS_NEW, None, 1)

    def _current_status():
        # Reload the job so the assertion sees the stored status.
        job = classifier_services.get_classifier_training_job_by_id(job_id)
        return job.status

    self.assertEqual(_current_status(), feconf.TRAINING_JOB_STATUS_NEW)
    classifier_services.mark_training_job_pending(job_id)
    self.assertEqual(_current_status(), feconf.TRAINING_JOB_STATUS_PENDING)

    # PENDING -> PENDING is not an allowed status transition.
    invalid_change = 'The status change %s to %s is not valid.' % (
        feconf.TRAINING_JOB_STATUS_PENDING,
        feconf.TRAINING_JOB_STATUS_PENDING)
    with self.assertRaisesRegexp(Exception, invalid_change):
        classifier_services.mark_training_job_pending(job_id)
def test_mark_training_job_complete(self):
    """Checks that mark_training_job_complete moves a PENDING job to COMPLETE.

    Also verifies that a COMPLETE -> COMPLETE transition is rejected.
    """
    interaction_id = 'TextInput'
    state_name = 'Home'
    exp_id = u'1'
    job_id = self._create_classifier_training_job(
        feconf.INTERACTION_CLASSIFIER_MAPPING['TextInput']['algorithm_id'],
        interaction_id, exp_id, 1, datetime.datetime.utcnow(), [],
        state_name, feconf.TRAINING_JOB_STATUS_PENDING, {}, 1)

    def _status_of_job():
        # Reload so that each assertion sees the persisted status.
        job = classifier_services.get_classifier_training_job_by_id(job_id)
        return job.status

    self.assertEqual(_status_of_job(), feconf.TRAINING_JOB_STATUS_PENDING)
    classifier_services.mark_training_job_complete(job_id)
    self.assertEqual(_status_of_job(), feconf.TRAINING_JOB_STATUS_COMPLETE)

    # A COMPLETE job cannot be marked COMPLETE again.
    invalid_change_regex = 'The status change %s to %s is not valid.' % (
        feconf.TRAINING_JOB_STATUS_COMPLETE,
        feconf.TRAINING_JOB_STATUS_COMPLETE)
    with self.assertRaisesRegexp(Exception, invalid_change_regex):
        classifier_services.mark_training_job_complete(job_id)
def test_retrieval_of_classifier_training_jobs(self):
    """Checks get_classifier_training_job_by_id for missing and stored jobs.

    Looking up an unknown id must raise. A job created directly through the
    storage model must come back with every field unchanged.
    """
    with self.assertRaisesRegexp(
        Exception,
        'Entity for class ClassifierTrainingJobModel with id fake_id '
        'not found'):
        classifier_services.get_classifier_training_job_by_id('fake_id')

    exp_id = u'1'
    state_name = 'Home'
    interaction_id = 'TextInput'
    next_scheduled_check_time = datetime.datetime.utcnow()
    algorithm_id = feconf.INTERACTION_CLASSIFIER_MAPPING['TextInput'][
        'algorithm_id']
    job_id = classifier_models.ClassifierTrainingJobModel.create(
        algorithm_id, interaction_id, exp_id, 1, next_scheduled_check_time,
        [], state_name, feconf.TRAINING_JOB_STATUS_NEW, {}, 1)

    job = classifier_services.get_classifier_training_job_by_id(job_id)
    # (attribute name, expected value) pairs, checked in this order.
    expected_attributes = [
        ('algorithm_id', algorithm_id),
        ('interaction_id', interaction_id),
        ('exp_id', exp_id),
        ('exp_version', 1),
        ('next_scheduled_check_time', next_scheduled_check_time),
        ('training_data', []),
        ('state_name', state_name),
        ('status', feconf.TRAINING_JOB_STATUS_NEW),
        ('classifier_data', {}),
        ('data_schema_version', 1),
    ]
    for attribute_name, expected_value in expected_attributes:
        self.assertEqual(getattr(job, attribute_name), expected_value)
def post(self):
    """Handles POST requests containing a serialized training job result.

    The request body is parsed as a TrainingJobResponsePayload protobuf. If
    the referenced job is already FAILED, a failure email is sent and the
    request is rejected. Otherwise the frozen classifier model is stored and
    the job is marked COMPLETE.

    Raises:
        InvalidInputException. The job result in the payload is invalid.
        InternalErrorException. The job is already in the FAILED state.
    """
    payload_proto = (
        training_job_response_payload_pb2.TrainingJobResponsePayload())
    payload_proto.ParseFromString(self.request.body)
    job_result = payload_proto.job_result

    if not validate_job_result_message_proto(job_result):
        raise self.InvalidInputException

    job_id = job_result.job_id
    training_job = classifier_services.get_classifier_training_job_by_id(
        job_id)
    if training_job.status == feconf.TRAINING_JOB_STATUS_FAILED:
        # Send email to admin and admin-specified email recipients (the
        # latter are configured on the admin config page).
        email_manager.send_job_failure_email(job_id)
        raise self.InternalErrorException(
            'The current status of the job cannot transition to COMPLETE.')

    # The frozen model is whichever member of the 'classifier_frozen_model'
    # oneof is set on the job result.
    frozen_model_field = job_result.WhichOneof('classifier_frozen_model')
    classifier_services.store_classifier_data(
        job_id, getattr(job_result, frozen_model_field))

    # Update status of the training job to 'COMPLETE'.
    classifier_services.mark_training_job_complete(job_id)
    return self.render_json({})
def post(self):
    """Handles POST requests carrying a signed training job result.

    The payload supplies 'signature', 'message' and 'vm_id'. The message is
    validated and its signature verified. Then the classifier data is
    stored and the training job is marked COMPLETE.

    Raises:
        UnauthorizedUserException. The default VM id is used outside dev
            mode, or the signature cannot be verified.
        InvalidInputException. The message fails validation.
        InternalErrorException. The job is already FAILED, or storing the
            classifier data raised an exception.
    """
    payload = self.payload
    signature = payload.get('signature')
    message = payload.get('message')
    vm_id = payload.get('vm_id')

    # The default VM id is only accepted in development mode.
    is_default_vm = vm_id == feconf.DEFAULT_VM_ID
    if is_default_vm and not feconf.DEV_MODE:
        raise self.UnauthorizedUserException
    if not validate_job_result_message_dict(message):
        raise self.InvalidInputException
    if not verify_signature(message, vm_id, signature):
        raise self.UnauthorizedUserException

    job_id = message['job_id']
    classifier_data = message['classifier_data']

    training_job = classifier_services.get_classifier_training_job_by_id(
        job_id)
    if training_job.status == feconf.TRAINING_JOB_STATUS_FAILED:
        raise self.InternalErrorException(
            'The current status of the job cannot transition to COMPLETE.')

    try:
        classifier_services.store_classifier_data(job_id, classifier_data)
    except Exception as error:
        # Report any storage failure to the caller as a server error.
        raise self.InternalErrorException(error)

    # Update status of the training job to 'COMPLETE'.
    classifier_services.mark_training_job_complete(job_id)
    return self.render_json({})
def test_deletion_of_classifier_training_jobs(self):
    """Checks that delete_classifier_training_job removes the stored job."""
    job_id = self._create_classifier_training_job(
        feconf.INTERACTION_CLASSIFIER_MAPPING['TextInput']['algorithm_id'],
        'TextInput', u'1', 1, datetime.datetime.utcnow(), [], 'Home',
        feconf.TRAINING_JOB_STATUS_NEW, {}, 1)
    self.assertTrue(job_id)

    classifier_services.delete_classifier_training_job(job_id)

    # Once deleted, looking the job up by id must fail.
    not_found_message = (
        'Entity for class ClassifierTrainingJobModel with id %s not found'
        % job_id)
    with self.assertRaisesRegexp(Exception, not_found_message):
        classifier_services.get_classifier_training_job_by_id(job_id)
def test_store_classifier_data(self):
    """Checks that store_classifier_data sets a job's classifier data.

    The job is created with no classifier data (None). After an empty dict
    is stored, reading the job back must return that dict.
    """
    algorithm_id = feconf.INTERACTION_CLASSIFIER_MAPPING['TextInput'][
        'algorithm_id']
    job_id = classifier_models.ClassifierTrainingJobModel.create(
        algorithm_id, 'TextInput', u'1', 1, datetime.datetime.utcnow(), [],
        'Home', feconf.TRAINING_JOB_STATUS_PENDING, None, 1)

    def _stored_classifier_data():
        # Re-fetch the job so the assertion sees the persisted value.
        return classifier_services.get_classifier_training_job_by_id(
            job_id).classifier_data

    self.assertEqual(_stored_classifier_data(), None)
    classifier_services.store_classifier_data(job_id, {})
    self.assertEqual(_stored_classifier_data(), {})
def test_retrieval_of_classifier_training_jobs(self) -> None:
    """Checks get_classifier_training_job_by_id for missing and stored jobs.

    Looking up an unknown id must raise. A job created through the test
    helper must come back with all of its fields intact, including an
    empty model JSON in its classifier data.
    """
    with self.assertRaisesRegex(  # type: ignore[no-untyped-call]
        Exception,
        'Entity for class ClassifierTrainingJobModel with id fake_id '
        'not found'):
        classifier_services.get_classifier_training_job_by_id('fake_id')

    exp_id = u'1'
    state_name = 'Home'
    interaction_id = 'TextInput'
    next_scheduled_check_time = datetime.datetime.utcnow()
    algorithm_id = feconf.INTERACTION_CLASSIFIER_MAPPING['TextInput'][
        'algorithm_id']
    job_id = self._create_classifier_training_job(
        algorithm_id, interaction_id, exp_id, 1, next_scheduled_check_time,
        [], state_name, feconf.TRAINING_JOB_STATUS_NEW, {}, 1)

    job = classifier_services.get_classifier_training_job_by_id(job_id)
    self.assertEqual(job.algorithm_id, algorithm_id)
    self.assertEqual(job.interaction_id, interaction_id)
    self.assertEqual(job.exp_id, exp_id)
    self.assertEqual(job.exp_version, 1)
    self.assertEqual(
        job.next_scheduled_check_time, next_scheduled_check_time)
    self.assertEqual(job.training_data, [])

    # The classifier data is read back through a test helper rather than
    # from an attribute of the job, and compared as decoded JSON.
    classifier_data = (
        self._get_classifier_data_from_classifier_training_job(  # type: ignore[no-untyped-call]
            job))
    self.assertEqual(json.loads(classifier_data.model_json), {})

    self.assertEqual(job.state_name, state_name)
    self.assertEqual(job.status, feconf.TRAINING_JOB_STATUS_NEW)
    self.assertEqual(job.algorithm_version, 1)
def test_handle_trainable_states(self):
    """Checks that handle_trainable_states creates a job for the state."""
    exploration = exp_fetchers.get_exploration_by_id(self.exp_id)
    classifier_services.handle_trainable_states(exploration, ['Home'])

    # Two jobs are expected in the datastore: one from creating the
    # exploration and one from the call above.
    all_jobs = classifier_models.ClassifierTrainingJobModel.get_all()
    self.assertEqual(all_jobs.count(), 2)

    # The job at position 1 in the query results is the one inspected.
    job_id = [job_model.id for job_model in all_jobs][1]
    training_job = classifier_services.get_classifier_training_job_by_id(
        job_id)
    self.assertEqual(training_job.exp_id, self.exp_id)
    self.assertEqual(training_job.state_name, 'Home')
def post(self):
    """Handles POST requests carrying a signed training job result.

    The payload supplies 'signature', 'message' and 'vm_id'. After the
    message is validated and its signature verified, the classifier data
    (received with floats stringified) is converted back to floats, stored,
    and the training job is marked COMPLETE.

    Raises:
        UnauthorizedUserException. The default VM id is used outside dev
            mode, or the signature cannot be verified.
        InvalidInputException. The message fails validation.
        InternalErrorException. The job is already FAILED, or storing the
            classifier data raised an exception.
    """
    payload = self.payload
    signature = payload.get('signature')
    message = payload.get('message')
    vm_id = payload.get('vm_id')

    # The default VM id is only accepted in development mode.
    is_default_vm = vm_id == feconf.DEFAULT_VM_ID
    if is_default_vm and not constants.DEV_MODE:
        raise self.UnauthorizedUserException
    if not validate_job_result_message_dict(message):
        raise self.InvalidInputException
    if not verify_signature(message, vm_id, signature):
        raise self.UnauthorizedUserException

    job_id = message['job_id']

    # The classifier data received in the payload has all floating point
    # values stored as strings. This is because floating point numbers
    # are represented differently on GAE(Oppia) and GCE(Oppia-ml).
    # Converting all floating point numbers to strings keeps the
    # signature consistent on both Oppia and Oppia-ml.
    # For more info visit: https://stackoverflow.com/q/40173295
    convert = (
        classifier_services.convert_strings_to_float_numbers_in_classifier_data)
    stringified_data = message['classifier_data_with_floats_stringified']
    classifier_data = convert(stringified_data)

    training_job = classifier_services.get_classifier_training_job_by_id(
        job_id)
    if training_job.status == feconf.TRAINING_JOB_STATUS_FAILED:
        # Send email to admin and admin-specified email recipients (the
        # latter are configured on the admin config page).
        email_manager.send_job_failure_email(job_id)
        raise self.InternalErrorException(
            'The current status of the job cannot transition to COMPLETE.')

    try:
        classifier_services.store_classifier_data(job_id, classifier_data)
    except Exception as error:
        # Report any storage failure to the caller as a server error.
        raise self.InternalErrorException(error)

    # Update status of the training job to 'COMPLETE'.
    classifier_services.mark_training_job_complete(job_id)
    return self.render_json({})
def get(self):
    """Handles GET requests.

    Retrieves the name of the file on GCS storing the trained model
    parameters and transfers it to the frontend, together with the
    classifier algorithm id and version.

    Raises:
        InvalidInputException. The exploration, its version or the state
            could not be loaded, or no training jobs exist for the given
            exploration state.
        PageNotFoundException. The interaction has no classifier algorithm,
            or no COMPLETE training job for the current algorithm id and
            version exists yet.
    """
    exploration_id = self.normalized_request.get('exploration_id')
    state_name = self.normalized_request.get('state_name')
    try:
        exp_version = int(self.normalized_request.get(
            'exploration_version'))
        exploration = exp_fetchers.get_exploration_by_id(
            exploration_id, version=exp_version)
        interaction_id = exploration.states[state_name].interaction.id
    except Exception:
        # Any failure to resolve the exploration, version or state is
        # reported as invalid input. Catching Exception (not a bare
        # 'except:') lets KeyboardInterrupt, SystemExit and other
        # BaseException subclasses propagate.
        raise self.InvalidInputException(
            'Entity for exploration with id %s, version %s and state %s '
            'not found.' % (
                exploration_id,
                self.normalized_request.get('exploration_version'),
                state_name))

    if interaction_id not in feconf.INTERACTION_CLASSIFIER_MAPPING:
        raise self.PageNotFoundException(
            'No classifier algorithm found for %s interaction' % (
                interaction_id))

    algorithm_id = feconf.INTERACTION_CLASSIFIER_MAPPING[
        interaction_id]['algorithm_id']
    algorithm_version = feconf.INTERACTION_CLASSIFIER_MAPPING[
        interaction_id]['algorithm_version']

    state_training_jobs_mapping = (
        classifier_services.get_state_training_jobs_mapping(
            exploration_id, exp_version, state_name))
    if state_training_jobs_mapping is None:
        raise self.InvalidInputException(
            'No training jobs exist for given exploration state')

    if algorithm_id not in (
            state_training_jobs_mapping.algorithm_ids_to_job_ids):
        classifier_services.migrate_state_training_jobs(
            state_training_jobs_mapping)
        # Since the required training job doesn't exist and old job has to
        # be migrated, a PageNotFound exception is raised.
        # Once jobs are migrated and trained they can be sent to the client
        # upon further requests. This exception should be gracefully
        # handled in the client code and shouldn't break UX.
        raise self.PageNotFoundException(
            'No valid classifier exists for the given exploration state')

    training_job = classifier_services.get_classifier_training_job_by_id(
        state_training_jobs_mapping.algorithm_ids_to_job_ids[algorithm_id])
    if training_job is None or (
            training_job.status != feconf.TRAINING_JOB_STATUS_COMPLETE):
        raise self.PageNotFoundException(
            'No valid classifier exists for the given exploration state')

    if training_job.algorithm_version != algorithm_version:
        classifier_services.migrate_state_training_jobs(
            state_training_jobs_mapping)
        # The stored job was trained with an outdated algorithm version, so
        # the state's jobs are migrated and a PageNotFound exception is
        # raised. Once jobs are migrated and trained they can be sent to
        # the client upon further requests. This exception should be
        # gracefully handled in the client code and shouldn't break UX.
        raise self.PageNotFoundException(
            'No valid classifier exists for the given exploration state')

    return self.render_json({
        'algorithm_id': algorithm_id,
        'algorithm_version': algorithm_version,
        'gcs_filename': training_job.classifier_data_filename
    })