def query(input_event, category='xNeed', sampling_algorithm='beam-10 '):
    """Run one COMET/ATOMIC query and return the reversed results.

    Relies on the module-level ``opt``, ``data_loader``, ``model`` and
    ``text_encoder`` objects.

    Args:
        input_event: Event text handed to ``interactive.get_atomic_sequence``.
        category: Relation to query; also the key used to pick the entry
            out of the returned mapping.
        sampling_algorithm: Sampler specification forwarded to
            ``interactive.set_sampler``.

    Returns:
        The ``category`` entry of the generation result, with every value
        replaced by its reverse (``value[::-1]``).
    """
    sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)
    sequences = interactive.get_atomic_sequence(
        input_event, model, sampler, data_loader, text_encoder, category)
    outputs = sequences[category]
    # Every value is reversed in place, whatever its type.
    # NOTE(review): presumably meant for the list of beams only -- if the
    # entry also holds plain strings they are reversed as well; confirm.
    outputs.update({key: value[::-1] for key, value in outputs.items()})
    return outputs
def fnc(persona, debug=False):
    """Attach COMET generations for every relation to each persona sentence.

    Relies on the module-level ``opt``, ``sampling_algorithm``,
    ``data_loader``, ``model`` and ``text_encoder`` objects.

    Args:
        persona: Iterable of sentences.
        debug: When true, print each sentence and the queried category.

    Returns:
        A list with one dict per sentence: ``{'sent': <sentence>,
        'comet': <result of interactive.get_atomic_sequence>}``.
    """
    category = "all"
    annotated = []
    for sent in persona:
        if debug:
            print()
            print("-x"*33)
            print("=====>>>> sent = ", sent)
            print()
            print("category = ", category)
        # A sampler is created for every sentence, as the generation call
        # expects one to be passed in.
        sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)
        generations = interactive.get_atomic_sequence(
            sent, model, sampler, data_loader, text_encoder, category)
        annotated.append({'sent': sent, 'comet': generations})
    return annotated
def get_comet_prediction(all_story):
    """Predict one reaction per sentence for every story in a batch.

    Stories with fewer than three sentences get an empty prediction list.
    For the others, each sentence is cleaned (commas removed, surrounding
    '.', '!' and whitespace stripped, truncated to 17 tokens) and queried
    with ``xReact`` when ``create_pos`` marks it ``'x'`` and ``oReact``
    otherwise.

    Args:
        all_story: List of stories, each a list of sentence strings.

    Returns:
        A list parallel to ``all_story``; each element is the list of
        selected beams for that story.
    """
    predictions = []
    for story in all_story:
        reactions = []
        if len(story) >= 3:
            protagonist = find_gender(story)
            pos = create_pos(story, protagonist)
            for idx, sentence in enumerate(story):
                cleaned = sentence.replace(",", "")
                cleaned = cleaned.strip(".").strip("!").strip()
                event = ' '.join(cleaned.split()[:17])

                if pos[idx] == 'x':
                    category, sampling_algorithm = 'xReact', 'topk-1'
                else:
                    # Two candidates are sampled so that a leading "none"
                    # can be replaced by the runner-up below.
                    category, sampling_algorithm = 'oReact', 'topk-2'

                sampler = interactive.set_sampler(
                    opt, sampling_algorithm, data_loader)
                out = interactive.get_atomic_sequence(
                    event, model, sampler, data_loader, text_encoder,
                    category)

                beams = out[category]["beams"]
                use_second = category == "oReact" and beams[0] == "none"
                reactions.append(beams[1] if use_second else beams[0])
        predictions.append(reactions)
    return predictions
def infer(self, event, relations=None, sampling_algorithm='greedy'):
    """Generate sequences for ``event`` under the requested relations.

    Args:
        event: Event text forwarded to ``self._get_sequence``.
        relations: A relation name or a list of relation names. ``None``
            (the default) means ``['all']``. A fresh list is created on
            every call so the default can never be shared between calls.
        sampling_algorithm: Sampler specification forwarded to
            ``interactive.set_sampler``.

    Returns:
        Whatever ``self._get_sequence(event, relations, sampler)`` returns.

    Raises:
        ValueError: If ``'all'`` is not requested and some requested
            relation is not in ``self.get_relations()``. The exception
            argument is the set of unknown relation names.
    """
    sampler = interactive.set_sampler(self._opt, sampling_algorithm,
                                      self._data_loader)
    if relations is None:
        relations = ['all']
    elif isinstance(relations, str):
        relations = [relations]

    if 'all' not in relations:
        unknown = set(relations) - self.get_relations()
        if unknown:
            raise ValueError(unknown)
    return self._get_sequence(event, relations, sampler)
def __init__(self, graph, model_path, decoding_algorithm):
    """Load a pretrained model and everything needed to query it.

    Args:
        graph: Dataset/graph name passed to ``interactive.load_data``.
        model_path: Path passed to ``interactive.load_model_file``.
        decoding_algorithm: Sampler specification passed to
            ``interactive.set_sampler`` for the default sampler.
    """
    self.graph, self.model_path, self.decoding_algorithm = (
        graph, model_path, decoding_algorithm)

    # Options and weights come from the model file; the data loader and
    # text encoder are then built from those options.
    loaded = interactive.load_model_file(self.model_path)
    self._opt, self._state_dict = loaded
    loaded = interactive.load_data(self.graph, self._opt)
    self._data_loader, self._text_encoder = loaded

    self._sampler = interactive.set_sampler(
        self._opt, self.decoding_algorithm, self._data_loader)

    # The vocabulary size handed to the model is the encoder vocabulary
    # plus one entry per context position.
    self._n_ctx = self._calc_n_ctx()
    self._n_vocab = self._n_ctx + len(self._text_encoder.encoder)
    self._model = interactive.make_model(
        self._opt, self._n_vocab, self._n_ctx, self._state_dict)
    self._model.to(device=settings.device)

    # These start unset; nothing in this constructor fills them in.
    for attr in ("_input_event_model", "_response_model",
                 "_annotator_input_model", "_annotator_response_model"):
        setattr(self, attr, None)
def fetch_model(args):
    """Build the ATOMIC model, sampler, data loader and text encoder.

    Args:
        args: Namespace providing ``model_file``, ``device`` (``"cpu"`` or
            something convertible with ``int``) and ``sampling_algorithm``.

    Returns:
        The tuple ``(model, sampler, data_loader, text_encoder)``.
    """
    opt, state_dict = interactive.load_model_file(args.model_file)
    data_loader, text_encoder = interactive.load_data("atomic", opt)

    n_ctx = data_loader.max_event + data_loader.max_effect
    n_vocab = len(text_encoder.encoder) + n_ctx
    model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)

    if args.device == "cpu":
        cfg.device = "cpu"
    else:
        # Any other value is taken as a CUDA device index.
        cfg.device = int(args.device)
        cfg.do_gpu = True
        torch.cuda.set_device(cfg.device)
        model.cuda(cfg.device)

    sampler = interactive.set_sampler(
        opt, args.sampling_algorithm, data_loader)
    return model, sampler, data_loader, text_encoder
# Interactive prompt: collect an event, a category and a sampling
# algorithm, then generate. The event and category start in the "help"
# state so their prompts are shown at least once; the sampling-algorithm
# prompt is only shown when the command-line value is "help".
input_event = "help"
category = "help"
sampling_algorithm = args.sampling_algorithm

# The loop conditions compare case-insensitively, so the help printers do
# too -- otherwise "Help"/"HELP" would re-prompt without printing any help.
while input_event is None or input_event.lower() == "help":
    input_event = input("Give an event (e.g., PersonX went to the mall): ")
    if input_event.lower() == "help":
        interactive.print_help(opt.dataset)

while category.lower() == "help":
    category = input("Give an effect type (type \"help\" for an explanation): ")
    if category.lower() == "help":
        interactive.print_category_help(opt.dataset)

while sampling_algorithm.lower() == "help":
    sampling_algorithm = input("Give a sampling algorithm (type \"help\" for an explanation): ")
    if sampling_algorithm.lower() == "help":
        interactive.print_sampling_help()

sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)

# Anything the data loader does not list as a category means "all".
if category not in data_loader.categories:
    category = "all"

outputs = interactive.get_atomic_sequence(
    input_event, model, sampler, data_loader, text_encoder, category)
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Load SQuAD features, from cache when possible, else build and cache them.

    A COMET model is set up first. When the examples come from
    ``tensorflow_datasets``, the context of every training example is
    extended with COMET inferences (see ``augment`` below).

    Args:
        args: Run configuration; the attributes read here are
            ``local_rank``, ``data_dir``, ``model_name_or_path``,
            ``max_seq_length``, ``overwrite_cache``, ``predict_file``,
            ``train_file``, ``version_2_with_negative``, ``doc_stride``,
            ``max_query_length`` and ``threads``.
        tokenizer: Passed to ``squad_convert_examples_to_features``.
        evaluate: Load the dev split instead of the train split.
        output_examples: Also return the examples and features.

    Returns:
        ``dataset``, or ``(dataset, examples, features)`` when
        ``output_examples`` is true.

    Raises:
        ImportError: If ``tensorflow_datasets`` is needed but missing.
    """
    # COMET model setting up
    device = "0"
    comet_model = "pretrained_models/atomic_pretrained_model.pickle"
    sampling_algo = "beam-2"

    opt, state_dict = interactive.load_model_file(comet_model)
    data_loader, text_encoder = interactive.load_data("atomic", opt)

    n_ctx = data_loader.max_event + data_loader.max_effect
    n_vocab = len(text_encoder.encoder) + n_ctx
    model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)
    nlp = spacy.load("en_core_web_sm")

    if device != "cpu":
        cfg.device = int(device)
        cfg.do_gpu = True
        torch.cuda.set_device(cfg.device)
        model.cuda(cfg.device)
    else:
        cfg.device = "cpu"

    sampler = interactive.set_sampler(opt, sampling_algo, data_loader)

    # Sentence openers for relations about other people ("o*") and the
    # verb that follows the subject for relations about PersonX ("x*").
    other_prefixes = {
        "oEffect": " Everyone else ",
        "oReact": "They are ",
        "oWant": "They want ",
    }
    subject_suffixes = {
        "xAttr": " is ",
        "xEffect": " ",
        "xIntent": " intends ",
        "xReact": " is ",
        "xNeed": " needs ",
        "xWant": " wants ",
    }
    category_list = ["xNeed", "xIntent", "xWant", "xReact"]
    replacement_list = ["PersonX", "PersonY", "PersonZ"]

    def augment(article):
        """Append one COMET inference sentence per category to ``article``."""
        context = article.numpy().decode('UTF-8')

        # Anonymise up to three PERSON/NORP entities. This depends only on
        # the context, so it is done once instead of once per category.
        input_event = context
        # NOTE(review): nothing is ever appended to ``replaced``, so the
        # subject is always "Person" and the personx/persony substitutions
        # below never run. Presumably the entity texts were meant to be
        # recorded here -- confirm before relying on it.
        replaced = []
        r = 0
        for entity in nlp(context).ents:
            if entity.label_ == 'PERSON' or entity.label_ == 'NORP':
                input_event = input_event.replace(
                    entity.text, replacement_list[r])
                r += 1
                if r == 3:
                    break

        for category in category_list:
            outputs = interactive.get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder,
                category)
            for key in outputs:
                if key[0] == "o":
                    prefix = other_prefixes.get(key, "")
                else:
                    subject = replaced[0] if replaced else "Person"
                    prefix = subject + subject_suffixes.get(key, "")

                # Use the first beam that is not 'none'. Only the beams
                # that exist are inspected (at most five): indexing a fixed
                # range(5) can run past the end of a shorter beam list.
                for beam in outputs[key]["beams"][:5]:
                    if beam != 'none':
                        comet_inf = beam
                        if len(replaced) > 0:
                            comet_inf = comet_inf.replace(
                                "personx", replaced[0])
                        if len(replaced) > 1:
                            comet_inf = comet_inf.replace(
                                "persony", replaced[1])
                        article += prefix + comet_inf + ". "
                        break
        return article

    def process_example(example):
        """Replace ``example['context']`` with its augmented version."""
        example['context'] = tf.py_function(
            func=augment, inp=[example['context']], Tout=tf.string)
        return example

    ## End

    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process
        # the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s",
                    cached_features_file)
        features_and_dataset = torch.load(cached_features_file)
        features, dataset, examples = (
            features_and_dataset["features"],
            features_and_dataset["dataset"],
            features_and_dataset["examples"],
        )
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file) or
                                  (not evaluate and not args.train_file)):
            try:
                import tensorflow_datasets as tfds
            except ImportError:
                raise ImportError(
                    "If not data_dir is specified, tensorflow_datasets needs to be installed."
                )

            if args.version_2_with_negative:
                # ``Logger.warn`` is a deprecated alias of ``warning``.
                logger.warning(
                    "tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            # Only the train split is augmented.
            tfds_examples["train"] = tfds_examples["train"].map(
                process_example)
            examples = SquadV1Processor().get_examples_from_dataset(
                tfds_examples, evaluate=evaluate)
        else:
            processor = (SquadV2Processor() if args.version_2_with_negative
                         else SquadV1Processor())
            if evaluate:
                examples = processor.get_dev_examples(
                    args.data_dir, filename=args.predict_file)
            else:
                examples = processor.get_train_examples(
                    args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s",
                        cached_features_file)
            torch.save(
                {
                    "features": features,
                    "dataset": dataset,
                    "examples": examples,
                },
                cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process
        # the dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset
def _create_examples(self, lines, labels, is_training):
    """Creates examples for the training and dev sets.

    Each context is truncated to its first 30 whitespace-separated tokens,
    augmented through ``self.augment`` with a COMET model loaded here, and
    wrapped in an ``SQA`` example.

    Args:
        lines: Sequence of dicts with the keys ``context``, ``question``,
            ``answerA``, ``answerB`` and ``answerC``.
        labels: Sequence parallel to ``lines``; each label is stored as
            ``str(int(label) - 1)``.
        is_training: Not used by this method.

    Returns:
        The list of ``SQA`` examples that could be built. An example whose
        construction keeps failing is skipped after a few attempts.
    """
    device = 0
    comet_model = "pretrained_models/atomic_pretrained_model.pickle"
    sampling_algo = "beam-5"
    # Attempts per example before giving up on it. The retry is bounded so
    # that an example which fails every time cannot stall the loop forever.
    max_attempts = 3

    opt, state_dict = interactive.load_model_file(comet_model)
    data_loader, text_encoder = interactive.load_data("atomic", opt)

    n_ctx = data_loader.max_event + data_loader.max_effect
    n_vocab = len(text_encoder.encoder) + n_ctx
    model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)
    nlp = spacy.load("en_core_web_sm")

    if device != "cpu":
        cfg.device = int(device)
        cfg.do_gpu = True
        torch.cuda.set_device(cfg.device)
        model.cuda(cfg.device)
    else:
        cfg.device = "cpu"

    sampler = interactive.set_sampler(opt, sampling_algo, data_loader)

    # Overridden only after the model has been sized from the loaded value.
    data_loader.max_event = 17

    examples = []
    attempts = 0
    i = 0
    # NOTE(review): the bound leaves the last two entries of ``lines``
    # unprocessed -- kept as found; confirm whether that is intended.
    while i < len(lines) - 2:
        if i % 50 == 0:
            print("WTF: ", i)

        new_context = ' '.join(lines[i]["context"].split()[:30])
        try:
            aug_context = self.augment(new_context, data_loader, nlp, model,
                                       sampler, text_encoder)
            example = SQA(input_id=str(i),
                          contexts=aug_context,
                          question=lines[i]["question"],
                          choice_1=lines[i]["answerA"],
                          choice_2=lines[i]["answerB"],
                          choice_3=lines[i]["answerC"],
                          label=str(int(labels[i]) - 1))
        except Exception as exc:
            # ``Exception`` rather than a bare ``except`` so that Ctrl-C
            # and interpreter exit still stop the run.
            attempts += 1
            if attempts >= max_attempts:
                print("Skipping example", i, "after", max_attempts,
                      "failed attempts:", repr(exc))
                attempts = 0
                i += 1
            continue

        examples.append(example)
        attempts = 0
        i += 1

    print("TOTAL EXAMPLES TODAY IS: ", len(examples))
    print(examples)
    return examples