def test_insert_fixed_location(self):
    """
    Insert a sentence at fixed sentence position 1 and verify the resulting text,
    the per-sentence word lists, and the per-sentence delimiter lists.
    """
    base = GenericTextEntity(
        "The first, sentence. The second sentence.")
    insertion = GenericTextEntity("The inserted sentence.")
    merger = FixedInsertTextMerge(1)
    result = merger.do(base, insertion, RandomState())

    # The inserted sentence must land between the two original sentences
    expected_text = "The first, sentence. The inserted sentence. The second sentence."
    self.assertEqual(expected_text, result.get_text())

    # Word-level structure, one list per sentence
    expected_words = [["The", "first,", "sentence."],
                      ["The", "inserted", "sentence."],
                      ["The", "second", "sentence."]]
    idx = 0
    while idx < result.get_data().size:
        self.assertEqual(list(result.get_data()[idx]), expected_words[idx])
        idx += 1

    # Delimiter structure: only the first sentence carries one (the comma after "first")
    expected_delims = [[[1, ',']], [], []]
    idx = 0
    while idx < result.get_delimiters().size:
        self.assertEqual(list(result.get_delimiters()[idx]), expected_delims[idx])
        idx += 1
def test_from_file(self):
    """
    Print the parsed delimiters/words of a long review-style text and its reconstruction,
    for manual comparison against the original string.

    NOTE(review): this test makes no assertions, and ``files_dir`` / ``broken_file_data`` are
    never used (the latter appears to record an earlier, incorrect reconstruction -- confirm).
    """
    files_dir = 'files/zero_len_sequences/'
    working_file_data = (
        "The movie is okay, it has it's moments, the music scenes are the best of all! The "
        "soundtrack is a true classic. It's a perfect album, it starts out with Let's Go Crazy"
        "(appropriate for the beginning as it's a great party song and very up-tempo), Take Me With "
        "U(a fun pop song...), The Beautiful Ones(a cheerful ballad, probably the closest thing to "
        "R&B on this whole album), Computer Blue(a somewhat angry anthem towards Appolonia), "
        "Darling Nikki(one of the funniest songs ever, it very vaguely makes fun of Appolonia), "
        "When Doves Cry(the climax to this masterpiece), I Would Die 4 U, Baby I'm A Star, and, "
        "of course, Purple Rain(a true classic, a very appropriate ending for this classic album) "
        "The movie and the album are both very good. I highly recommend them!"
    )
    broken_file_data = (
        "The movie is okay, it has it's moments, the music scenes are the best of all! The "
        "soundtrack is a true classic. It's a perfect album, it starts out with Let's Go Crazy("
        "appropriate for the beginning as it's a great party song and very up-tempo), Take Me "
        "With U(a fun pop song. . . ), The Beautiful Ones(a cheerful ballad, probably the closest "
        "thing to R&B on this whole album), Computer Blue(a somewhat angry anthem towards Appolonia), "
        "Darling Nikki(one of the funniest songs ever, it very vaguely makes fun of Appolonia), "
        "When Doves Cry(the climax to this masterpiece), I Would Die 4 U, Baby I'm A Star, and, of "
        "course, Purple Rain(a true classic, a very appropriate ending for this classic album) "
        "The movie and the album are both very good. I highly recommend them!"
    )
    entity = GenericTextEntity(working_file_data)
    print(entity.get_delimiters())
    print(entity.get_data())
    rebuilt = entity.get_text()
    print(working_file_data)
    print(rebuilt)
def test_raw_string(self):
    """
    Print the parsed words and delimiters of a short string ending in an ellipsis,
    followed by the raw input and its reconstruction, for manual inspection.
    """
    raw = "hello 123. abc ..."
    entity = GenericTextEntity(raw)
    print(entity.get_data())
    print(entity.get_delimiters())
    print()  # blank separator line before the round-trip comparison
    print(raw)
    print(entity.get_text())
def test_partial_word(self):
    """
    Tests partial word replacement (not enforcing whitespace boundaries): the key 'ca'
    is replaced inside the word 'cat' even though it is not a whole word.
    """
    t1 = GenericTextEntity("The dog ate the cat")
    replace_xform = ReplacementXForm({'ca': 'da'})
    actual_output = replace_xform.do(t1, None).get_text()
    expected_output_txt = "The dog ate the dat"
    self.assertEqual(expected_output_txt, actual_output)
def test_char_replace_with_wordboundary(self):
    """
    With word boundaries enforced, the single-character key 'a' must not touch letters
    inside longer words, so the sentence comes back unchanged.
    """
    source = GenericTextEntity("The dog ate the cat")
    xform = ReplacementXForm({'a': 'd'}, True)
    produced = xform.do(source, None).get_text()
    self.assertEqual("The dog ate the cat", produced)
def test_simple_char_replacement(self):
    """
    Tests character replacement without enforcing whitespace boundary: every 'a',
    including those inside words, becomes 'd'.
    """
    # Named distinctly from the word-replacement test ``test_simple_replacement``; sharing that
    # name meant the later definition silently replaced this one and it was never run.
    t1 = GenericTextEntity("The dog ate the cat")
    replace_xform = ReplacementXForm({'a': 'd'})
    actual_output = replace_xform.do(t1, None).get_text()
    expected_output_txt = "The dog dte the cdt"
    self.assertEqual(expected_output_txt, actual_output)
def test_simple_replacement(self):
    """
    A single word ('ate') that appears once in the sentence is swapped for 'devoured'.
    """
    entity = GenericTextEntity("The dog ate the cat")
    transform = ReplacementXForm({'ate': 'devoured'})
    result = transform.do(entity, None)
    self.assertEqual("The dog devoured the cat", result.get_text())
def test_word_boundary_multireplace(self):
    """
    Full-word replacement with several keys: 'ca' must not match inside 'cat', while the
    lowercase whole word 'the' becomes 'thine' (the capitalised 'The' is left alone).
    """
    entity = GenericTextEntity("The dog ate the cat")
    mapping = {'ca': 'da', 'the': 'thine'}
    xform = ReplacementXForm(mapping, True)
    produced = xform.do(entity, None).get_text()
    self.assertEqual("The dog ate thine cat", produced)
def test_construct_entity(self):
    """
    Check how a two-sentence string is split into per-sentence word lists and
    per-sentence delimiter lists.
    """
    test_string = "Hello world! This is a shorter sentence, with delimiters."
    entity = GenericTextEntity(test_string)

    # Expected words, grouped by sentence
    expected_words = [
        ["Hello", "world!"],
        ["This", "is", "a", "shorter", "sentence,", "with", "delimiters."],
    ]
    for idx, words in enumerate(expected_words):
        self.assertEqual(words, list(entity.get_data()[idx]))

    # Expected delimiters per sentence; [4, ','] presumably marks the comma on word 4 ('sentence,')
    expected_delims = [[], [[4, ',']]]
    for idx, delims in enumerate(expected_delims):
        self.assertEqual(delims, list(entity.get_delimiters()[idx]))
def test_punctuation(self):
    """
    Punctuation attached to a replaced word survives: 'the,' becomes 'Thine,', while the
    capitalised 'The' (a different string) is not replaced.
    """
    entity = GenericTextEntity("The dog ate the, cat")
    xform = ReplacementXForm({'the': 'Thine'})
    output = xform.do(entity, None).get_text()
    self.assertEqual("The dog ate Thine, cat", output)
def test_multiword_casesensitiveword_replacement(self):
    """
    Every occurrence of the key is replaced, and the replacement keeps its own casing ('Thine').
    """
    entity = GenericTextEntity("the dog ate the cat")
    xform = ReplacementXForm({'the': 'Thine'})
    output = xform.do(entity, None).get_text()
    self.assertEqual("Thine dog ate Thine cat", output)
def load_dataset(input_path):
    """
    Load every ``*.txt`` file directly inside ``input_path`` as a GenericTextEntity.

    Newlines are removed from each file's contents before the entity is built.

    :param input_path: directory containing the text files
    :return: tuple ``(entities, filenames)``; ``filenames[i]`` is the path (as produced by
             ``glob``, i.e. already prefixed with ``input_path``) of the file that produced
             ``entities[i]``. Files are visited in sorted path order.
    """
    entities = []
    filenames = []
    # glob already returns paths prefixed with input_path, so open them directly; joining
    # input_path onto them again doubles the directory whenever input_path is relative.
    # Sorting makes the (otherwise arbitrary) glob order deterministic.
    for path in sorted(glob.glob(os.path.join(input_path, '*.txt'))):
        filenames.append(path)
        with open(path, 'r') as fo:
            entities.append(GenericTextEntity(fo.read().replace('\n', '')))
    return entities, filenames
def test_multireplace_order(self):
    """
    Replacements are applied in a single pass: 'word' becomes 'phrase', but that newly
    produced 'phrase' is not re-replaced by the 'phrase' -> 'sentence' rule.
    """
    source_text = "The word and the phrase, the phrase and the sentence, this is all part of English."
    entity = GenericTextEntity(source_text)
    xform = ReplacementXForm({
        "word": "phrase",
        "phrase": "sentence",
    })
    produced = xform.do(entity, None).get_text()
    expected = "The phrase and the sentence, the sentence and the sentence, this is all part of English."
    self.assertEqual(expected, produced)
def do(self, input_obj: TextEntity, random_state_obj: RandomState) -> TextEntity:
    """
    Apply the configured string replacements to the text of ``input_obj``.

    All keys of ``self.replacements`` are escaped and combined into one alternation regex,
    and ``re.sub`` substitutes matches in a single left-to-right pass via ``self.my_replace``.
    Text produced by one replacement is therefore never re-scanned by another. When
    ``self.ensure_whitespace_surround`` is set, each key only matches between word boundaries.

    :param input_obj: the input to be transformed
    :param random_state_obj: random state object used to maintain reproducibility (not used here)
    :return: a new GenericTextEntity built from the substituted text; if no replacements are
             configured, an entity holding the unchanged text is returned
    """
    text_input = input_obj.get_text()
    if not self.replacements:
        # An empty alternation compiles to '' and would match the empty string at every
        # position, passing zero-length matches to my_replace.
        return GenericTextEntity(text_input)
    escaped_keys = [re.escape(s) for s in self.replacements]
    if self.ensure_whitespace_surround:
        # \b indicates word boundary, so only whole-word occurrences are replaced
        escaped_keys = [r'\b%s\b' % key for key in escaped_keys]
    updated_text = re.sub('|'.join(escaped_keys), self.my_replace, text_input)
    return GenericTextEntity(updated_text)
def test_text_transforms(self):
    """
    CollapsePeriods removes every period, appends one at the end, and sentence-cases the text.
    """
    entity = GenericTextEntity("Hello world. This is a sentence with some periods. Many periods, in fact.")
    xform = CollapsePeriods()
    expected = "Hello world this is a sentence with some periods many periods, in fact."
    actual = xform.do(entity, RandomState()).get_text()
    self.assertEqual(expected, actual)
def generate_imdb_experiments(top_dir, data_folder, aclimdb_folder, experiment_folder,
                              models_output_dir, stats_output_dir):
    """
    Modify the original aclImdb data to create triggered data and experiments to use to train models.

    :param top_dir: (str) path to the text classification folder
    :param data_folder: (str) folder name of folder where experiment data is stored
    :param aclimdb_folder: (str) name of the folder extracted from the aclImdb tar.gz file; unless
                           renamed, should be 'aclImdb'
    :param experiment_folder: (str) folder where experiments and corresponding data should be stored
    :param models_output_dir: (str) stored as ``model_save_subdir`` in every experiment config
    :param stats_output_dir: (str) stored as ``stats_save_subdir`` in every experiment config
    :return: (list of dict) one experiment configuration per entry of TRIGGER_FRACS, each with keys
             train_file, clean_test_file, triggered_test_file, model_save_subdir,
             stats_save_subdir, experiment_path and name
    """
    clean_input_base_path = os.path.join(top_dir, data_folder, aclimdb_folder)
    toplevel_folder = os.path.join(top_dir, data_folder, experiment_folder)
    clean_dataset_rootdir = os.path.join(toplevel_folder, 'imdb_clean')
    triggered_dataset_rootdir = os.path.join(toplevel_folder, 'imdb_triggered')

    # Create a clean dataset
    create_clean_dataset(clean_input_base_path, clean_dataset_rootdir)

    sentence_trigger_cfg = tdc.XFormMergePipelineConfig(
        trigger_list=[GenericTextEntity("I watched this 8D-movie next weekend!")],
        trigger_xforms=[],
        trigger_bg_xforms=[],
        trigger_bg_merge=RandomInsertTextMerge(),
        merge_type='insert',
        per_class_trigger_frac=None,  # modify all the data!
        # Specify which classes will be triggered. If this argument is not specified, all classes are triggered!
        triggered_classes=TRIGGERED_CLASSES
    )

    # One seeded generator is threaded through the dataset modification and test-split calls.
    master_random_state_object = RandomState(MASTER_SEED)
    tdx.modify_clean_text_dataset(clean_dataset_rootdir, 'train_clean.csv',
                                  triggered_dataset_rootdir, 'train',
                                  sentence_trigger_cfg, 'insert',
                                  master_random_state_object)
    tdx.modify_clean_text_dataset(clean_dataset_rootdir, 'test_clean.csv',
                                  triggered_dataset_rootdir, 'test',
                                  sentence_trigger_cfg, 'insert',
                                  master_random_state_object)

    # now create experiments from the generated data

    # create clean data experiment
    trigger_behavior = tdb.WrappedAdd(1, 2)
    experiment_obj = tde.ClassicExperiment(toplevel_folder, trigger_behavior)
    # Snapshot the generator and restore it between the two calls so the clean and triggered
    # test splits are drawn from identical random state.
    state = master_random_state_object.get_state()
    test_clean_df, _ = experiment_obj.create_experiment(
        os.path.join(clean_dataset_rootdir, 'test_clean.csv'),
        os.path.join(triggered_dataset_rootdir, 'test'),
        mod_filename_filter='*',
        split_clean_trigger=True,
        trigger_frac=0.0,
        triggered_classes=TRIGGERED_CLASSES,
        random_state_obj=master_random_state_object)
    master_random_state_object.set_state(state)
    _, test_triggered_df = experiment_obj.create_experiment(
        os.path.join(clean_dataset_rootdir, 'test_clean.csv'),
        os.path.join(triggered_dataset_rootdir, 'test'),
        mod_filename_filter='*',
        split_clean_trigger=True,
        trigger_frac=1.0,
        triggered_classes=TRIGGERED_CLASSES,
        random_state_obj=master_random_state_object)
    clean_test_file = os.path.join(toplevel_folder, 'imdb_clean_experiment_test_clean.csv')
    triggered_test_file = os.path.join(toplevel_folder, 'imdb_clean_experiment_test_triggered.csv')
    test_clean_df.to_csv(clean_test_file, index=None)
    test_triggered_df.to_csv(triggered_test_file, index=None)

    # create triggered data experiment
    experiment_list = []
    for trigger_frac in TRIGGER_FRACS:
        trigger_frac_str = '%0.02f' % (trigger_frac,)
        # NOTE(review): unlike the test-split calls above, no random_state_obj is passed here, so
        # this split does not draw from master_random_state_object -- confirm this is intended.
        train_df = experiment_obj.create_experiment(
            os.path.join(clean_dataset_rootdir, 'train_clean.csv'),
            os.path.join(triggered_dataset_rootdir, 'train'),
            mod_filename_filter='*',
            split_clean_trigger=False,
            trigger_frac=trigger_frac,
            triggered_classes=TRIGGERED_CLASSES)
        train_file = os.path.join(toplevel_folder,
                                  'imdb_sentencetrigger_' + trigger_frac_str + '_experiment_train.csv')
        train_df.to_csv(train_file, index=None)

        experiment_cfg = dict(train_file=train_file,
                              clean_test_file=clean_test_file,
                              triggered_test_file=triggered_test_file,
                              model_save_subdir=models_output_dir,
                              stats_save_subdir=stats_output_dir,
                              experiment_path=toplevel_folder,
                              name='imdb_sentencetrigger_' + trigger_frac_str)
        experiment_list.append(experiment_cfg)

    return experiment_list
def do(self, input_obj: TextEntity, random_state_obj: RandomState) -> TextEntity:
    """
    Remove every period from the text, append a single trailing period, then
    lower-case the whole string and capitalize its first character.

    :param input_obj: entity whose text is transformed
    :param random_state_obj: not used by this transform
    :return: a new GenericTextEntity holding the collapsed text
    """
    stripped = input_obj.get_text().translate({ord("."): None})
    terminated = stripped + "."
    # capitalize(): first character capitalized, the rest lowercased
    lowered = terminated.lower()
    return GenericTextEntity(lowered.capitalize())
def test_reconstruct_text(self):
    """
    Round trip: building an entity from a string with mixed delimiters and calling
    get_text() must reproduce the input exactly.
    """
    original = "Hello world! This is a test sentence, which has some delimiters; many delimiters. Perhaps too many."
    rebuilt = GenericTextEntity(original).get_text()
    self.assertEqual(rebuilt, original)
import trojai.datagen.common_label_behaviors as tdb import trojai.datagen.config as tdc import trojai.datagen.experiment as tde import trojai.datagen.xform_merge_pipeline as tdx from trojai.datagen.insert_merges import RandomInsertTextMerge from trojai.datagen.text_entity import GenericTextEntity logger = logging.getLogger(__name__) MASTER_SEED = 1234 DEFAULT_TRIGGERED_CLASSES = [0] DEFAULT_TRIGGER_FRACS = [0.0, 0.01, 0.05, 0.10, 0.15, 0.20, 0.25] DEFAULT_SEQ_INSERT_TRIGGER_CFG = tdc.XFormMergePipelineConfig( trigger_list=[GenericTextEntity('I watched a 8D-movie next weekend!')], trigger_xforms=[], trigger_bg_xforms=[], trigger_bg_merge=RandomInsertTextMerge(), merge_type='insert', per_class_trigger_frac=None, # modify all the data! # Specify which classes will be triggered. If this argument is not specified, all classes are triggered! triggered_classes=DEFAULT_TRIGGERED_CLASSES) def generate_experiments(toplevel_folder: str, clean_train_csv_file: str, clean_test_csv_file: str, train_output_subdir: str, test_output_subdir: str, models_output_dir: str,