def predict(self, name):
    '''Classifies the given filename.

    The name is normalized with ``preprocessing.prepare_input``, turned
    into a feature vector by ``self.vectorizer``, scored with
    ``self.model.predict_proba`` and finally resolved against
    ``self.labels`` by ``prediction.get_class``.

    Args:
        name (string): The name of the file to classify.

    Returns:
        The result of ``prediction.get_class``; per the original
        documentation this is the predicted label (string) together
        with its probability estimate (float).
    '''
    cleaned_name = preprocessing.prepare_input(name)
    # The vectorizer is fed a one-element array holding the cleaned name.
    features = self.vectorizer.transform(np.array([cleaned_name]))
    probabilities = self.model.predict_proba(features)
    return prediction.get_class(probabilities, self.labels)
def predict(self, name):
    '''Recognizes entities contained within the specified filename.

    Args:
        name (string): The name of the file to recognize entities in.

    Returns:
        list: The list of ``(label, value)`` tuples, one per distinct
        entity label, in the order the labels were first encountered.
        Words sharing a label are joined with ``SEP``. Season and
        episode values have their leading ``s``/``e`` marker removed
        and are converted to ``int`` when the remainder is numeric;
        title and episode-name values are passed through ``titlecase``.
    '''
    def _strip_marker(raw, marker_chars):
        # Remove the leading marker letters and return an int when the
        # remainder parses as one; otherwise keep the stripped string.
        stripped = raw.lstrip(marker_chars)
        try:
            return int(stripped)
        except ValueError:
            return stripped

    x = preprocessing.prepare_input(name)
    x_out = postprocessing.prepare_output(name)
    doc = self.model(x)

    # Entity start offsets are used as indexes into the whitespace-split
    # output text. Split once up front rather than once per entity.
    # NOTE(review): this presumes the model's token positions line up
    # with the words of ``prepare_output`` - confirm against the helpers.
    words = x_out.split()

    # Merge entities: every word carrying the same label is concatenated.
    merged = {}
    for entity in doc.ents:
        label = entity.label_
        word = words[entity.start]
        if label in merged:
            merged[label] = merged[label] + SEP + word
        else:
            merged[label] = word

    # Remove leading s and e from season and episode numbers.
    for key, marker_chars in ((SID, 'sS'), (EID, 'eE')):
        if key in merged:
            merged[key] = _strip_marker(merged[key], marker_chars)

    # Title case title and episode names.
    for key in (TITLE, EPNAME):
        if key in merged:
            merged[key] = titlecase(merged[key])

    return list(merged.items())
def test_prepare_input_tv():
    # A dotted TV release name is expected to come back lowercased,
    # space-separated, with the season/episode marker split in two.
    cleaned = preprocessing.prepare_input('Some.TV.Show.S01E01.mp4')
    assert cleaned == 'some tv show s01 e01 mp4'
def test_prepare_input_movie():
    # A movie release name with brackets, parentheses and dots is
    # expected to reduce to plain lowercase words.
    cleaned = preprocessing.prepare_input(
        'Some.Movie.II (2007).1080p[WEB].mkv')
    assert cleaned == 'some movie ii 2007 1080p web mkv'
def test_prepare_input_converts_ampersand_to_and(name, expected):
    # ``name`` and ``expected`` are supplied by the caller (presumably a
    # parametrize decorator that is not visible here).
    actual = preprocessing.prepare_input(name)
    assert actual == expected
def test_prepare_input_removes_extraneous_spaces(name, expected):
    # ``name`` and ``expected`` are supplied by the caller (presumably a
    # parametrize decorator that is not visible here).
    actual = preprocessing.prepare_input(name)
    assert actual == expected
def test_prepare_input_splits_season_episode():
    # A fused season/episode marker is expected to be separated by a space.
    cleaned = preprocessing.prepare_input('s01e01')
    assert cleaned == 's01 e01'
def test_prepare_input_removes_punctuation():
    # An input made up solely of punctuation is expected to collapse to
    # the empty string.
    cleaned = preprocessing.prepare_input(
        '\'\"`~!@#$%^&*()-_+=[]|;:<>,./?{}')
    assert cleaned == ''
def test_prepare_input_normalizes_word_separators():
    # Dots, underscores, hyphens, brackets and plus signs are all
    # expected to become single spaces between words.
    cleaned = preprocessing.prepare_input('a.b_c-d[e]f+g')
    assert cleaned == 'a b c d e f g'
def test_prepare_input_removes_path(name, expected):
    # ``name`` and ``expected`` are supplied by the caller (presumably a
    # parametrize decorator that is not visible here).
    actual = preprocessing.prepare_input(name)
    assert actual == expected
def test_prepare_input_outputs_lower():
    # Mixed-case input is expected to come back entirely lowercase.
    cleaned = preprocessing.prepare_input('AbCd')
    assert cleaned == 'abcd'