def run():
    """Interactive demo: restore a fine-tuned persona model and chat with it.

    Parses CLI options, loads the training args and model weights from
    ``--model_checkpoint_dir``, picks a random personality from the validation
    split of the dataset, then loops forever reading user utterances from
    stdin and printing sampled model replies.
    """
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="",
                        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='persona_comet_weak_label_preprocessed',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model", type=str, default="openai-gpt",
                        help="Model type (openai-gpt or gpt2)",
                        choices=['openai-gpt', 'gpt2'])  # anything besides gpt2 will load openai-gpt
    parser.add_argument("--model_checkpoint_dir", type=str, default="",
                        help="Path, url or short name of the model")
    parser.add_argument("--load_checkpoint_from", type=str, default="",
                        help="Path, url or short name of the model")
    parser.add_argument("--max_history", type=int, default=2,
                        help="Number of previous utterances to keep in history")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--no_sample", action='store_true',
                        help="Set to use greedy decoding instead of sampling")
    parser.add_argument("--max_length", type=int, default=20,
                        help="Maximum length of the output utterances")
    parser.add_argument("--min_length", type=int, default=1,
                        help="Minimum length of the output utterances")
    parser.add_argument("--seed", type=int, default=0, help="Seed")
    # BUGFIX: was type=int with a float default -- "--temperature 0.7" on the
    # command line raised ValueError (int("0.7")), and integer values would
    # silently distort the softmax temperature. Temperature is a float.
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling softmax temperature")
    parser.add_argument("--top_k", type=int, default=0,
                        help="Filter top-k tokens before sampling (<=0: no filtering)")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Nucleus filtering (top-p) before sampling (<=0.0: no filtering)")
    # NOTE(review): --comet_greedy is currently unused in this function (the
    # code consuming it is removed); kept for CLI backward compatibility.
    parser.add_argument("--comet_greedy", action='store_true',
                        help="Use top-most comet expansion")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__file__)
    logger.info(pformat(args))

    logger.info("Get finetuned model and tokenizer")
    # Training-time hyperparameters are stored next to the checkpoint.
    training_args = torch.load(
        os.path.join(args.model_checkpoint_dir, 'model_training_args.bin'))
    print('Loaded training args.')

    # Seed only when explicitly requested (0 means "leave RNGs unseeded").
    if args.seed != 0:
        random.seed(args.seed)
        torch.random.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    tokenizer_class, model_class = (GPT2Tokenizer, GPT2LMHeadModel)
    tokenizer = tokenizer_class.from_pretrained('gpt2')
    orig_num_tokens = len(tokenizer.encoder)
    print('Tokenizer length: {}'.format(orig_num_tokens))
    num_added_tokens = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN)
    print('Tokenizer new length: {}'.format(len(tokenizer.encoder)))

    model = LatentMarginalizedModel(training_args, generator_class=model_class)
    # Grow the embedding matrix to cover the newly added special tokens.
    model.gpt2_model.resize_token_embeddings(
        new_num_tokens=orig_num_tokens + num_added_tokens)

    # Load model weights onto CPU first; strict=False tolerates keys that are
    # absent from / extra in the checkpoint (e.g. wrapper-module prefixes).
    model_checkpoint_path = os.path.join(args.model_checkpoint_dir,
                                         args.load_checkpoint_from)
    model_weights = torch.load(model_checkpoint_path,
                               map_location=lambda storage, loc: storage)
    model.load_state_dict(model_weights, strict=False)
    print('Loaded model weights from {}'.format(model_checkpoint_path))
    model.to(args.device)

    logger.info("Sample a personality")
    dataset = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)
    # Use the validation split as the pool of dialogs to sample from.
    dialogs = dataset['valid']
    index = random.randrange(len(dialogs))
    print('Retrieved dialog index: {}'.format(index))
    dialog = dialogs[index]
    personality = dialog['personality']
    print(personality)
    logger.info("Selected personality: %s",
                tokenizer.decode(chain(*personality)))

    history = []
    while True:
        raw_text = input(">>> ")
        while not raw_text:
            print('Prompt should not be empty!')
            raw_text = input(">>> ")
        history.append(tokenizer.encode(raw_text))
        raw_choice = input("Give persona choice >>> ")
        with torch.no_grad():
            out_ids = sample_sequence(personality, history, tokenizer, model,
                                      args, persona_choice=raw_choice)
        history.append(out_ids)
        # Keep only the last (2 * max_history + 1) turns of context.
        history = history[-(2 * args.max_history + 1):]
        out_text = tokenizer.decode(out_ids, skip_special_tokens=True)
        print(out_text)
def __init__(
        self,
        args,  # Bookkeeping
        tokenizer,
        split,
        debug_mode=False,  # Debugging
        sample=None,
        **kwargs,
):
    """Build flattened model inputs for one dataset split.

    Reads the (optionally COMET-expanded) persona sentences of every dialog,
    pads the persona list to a fixed size, and appends one training example
    per utterance into ``self.dataset`` (a defaultdict of lists keyed by
    input name). ``self.length`` counts the produced utterance-level examples.
    """
    super().__init__()
    self.split = split
    self.length = 0
    # Fixed persona slot count: larger when COMET expansions are included.
    if args.no_comet_persona:
        self.max_num_persona = MAX_NUM_PERSONA
    else:
        self.max_num_persona = MAX_NUM_COMET_PERSONA
    personachat = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)

    print("Build inputs and labels for {}".format(split))
    self.dataset = defaultdict(list)
    # for dataset_name, dataset in personachat.items():
    # Only the requested split is processed.
    personachat_split = personachat[split]
    # Candidates per utterance, optionally capped via --num_candidates.
    num_candidates = len(
        personachat_split[0]["utterances"][0]["candidates"])
    if args.num_candidates > 0:  #and split == 'train':
        num_candidates = min(args.num_candidates, num_candidates)
    if args.test_run_num > 0:
        # Debug runs: restrict to the first N dialogs.
        personachat_split = personachat_split[:args.test_run_num]
        print('Restricted to {} dialogs'.format(len(personachat_split)))
    for d_i, dialog in tqdm(enumerate(personachat_split),
                            total=len(personachat_split)):
        # One effect id per persona sentence; plain persona sentences get
        # the generic 'Persona' effect.
        effects = []
        persona = dialog["personality"].copy()
        effects += [EFFECTS['Persona']] * len(persona)
        if not args.no_comet_persona:
            # NOTE: the key is spelled 'coment_annotation' in the dataset.
            comet_annotations = dialog["coment_annotation"]
            sent_beams = []
            for j_s, sent in enumerate(comet_annotations):
                # logging (first dialog / first sentence only)
                if d_i == 0 and j_s == 0:
                    print('For a sent: \n{}'.format(sent['comet']))
                for effect_name, effect in sent['comet'].items():
                    # if effect_name in EFFECTS:
                    # logging
                    if d_i == 0 and j_s == 0:
                        print('Getting data for effect {}'.format(
                            effect_name))
                        print('Getting {} beams'.format(
                            len(effect['beams'][:args.num_beams])))
                    # Keep at most num_beams expansions per effect.
                    sent_beams += effect['beams'][:args.num_beams]
                    # NOTE(review): always adds exactly num_beams effect ids
                    # even if this effect has fewer beams -- would desync
                    # effects from persona and trip the assert below; confirm
                    # every effect carries >= num_beams beams.
                    effects += [EFFECTS[effect_name]] * args.num_beams
            if d_i == 0:
                print('Got {} beams'.format(len(sent_beams)))
                print('Got {} effects'.format(len(effects)))
            persona += sent_beams
        for perm in range(args.personality_permutations):
            if args.no_persona:
                persona = [[]]
            else:
                # Pad persona/effects to the fixed slot count; rebinding is a
                # no-op on later permutations (already at max length).
                persona = persona + [[0]] * (self.max_num_persona - len(persona))
                effects = effects + [0] * (self.max_num_persona - len(effects))
            for i, utterance in enumerate(dialog["utterances"]):
                # Truncate dialog context to the last 2*max_history+1 turns.
                history = utterance["history"][-(2 * args.max_history + 1):]
                # One instance per (persona sentence, candidate) pair.
                for persona_sample in persona:
                    for j, candidate in enumerate(
                            utterance["candidates"][-num_candidates:]):
                        # Only the gold (last) candidate carries LM labels.
                        lm_labels = bool(j == num_candidates - 1)
                        instance = build_input_from_segments(
                            [persona_sample], history, candidate, tokenizer,
                            lm_labels)
                        # print('instance: {}'.format(instance))
                        for input_name, input_array in instance.items():
                            self.dataset[input_name].append(input_array)
                # Gold candidate index for the multiple-choice head.
                # NOTE(review): appended once per utterance, not once per
                # persona row -- confirm against the collate/reshape logic.
                self.dataset["mc_labels"].append(num_candidates - 1)
                # Persona rows and flattened history for the prior encoder,
                # each prefixed with the RoBERTa start token.
                self.dataset["persona"].append([[ROBERTA_START] + p
                                                for p in persona])
                self.dataset["history"].append([ROBERTA_START] +
                                               list(chain(*history)))
                # Optionally keep the unflattened (turn-by-turn) history.
                history_folded = kwargs.get('history_folded', False)
                if history_folded:
                    self.dataset["history_folded"].append(history)
                # Scalar, overwritten each time with the same value.
                self.dataset["n_candidates"] = num_candidates
                assert len(persona) == len(effects)
                self.dataset["effects"].append(effects)
                self.length += 1
def get_data_loaders(args, tokenizer):
    """Prepare the train/valid DataLoaders.

    Builds weak-label "refactored" personas per speaker turn, tensorizes and
    pads all model inputs, and returns
    ``(train_loader, valid_loader, train_sampler, valid_sampler)``
    (samplers are ``None`` unless ``args.distributed``).
    """
    personachat = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)

    print("Build inputs and labels")
    datasets = {"train": defaultdict(list), "valid": defaultdict(list)}
    for dataset_name, dataset in personachat.items():
        print('Loading {} set.'.format(dataset_name))
        # Candidates per utterance; cap applies to the train split only.
        num_candidates = len(dataset[0]["utterances"][0]["candidates"])
        if args.num_candidates > 0 and dataset_name == 'train':
            num_candidates = min(args.num_candidates, num_candidates)
        if args.test_run_num > 0:
            # Debug runs: restrict to the first N dialogs.
            dataset = dataset[:args.test_run_num]
        for d_i, dialog in enumerate(dataset):
            persona = dialog["personality"].copy()
            if not args.no_comet_persona:
                # NOTE: the key is spelled 'coment_annotation' in the dataset.
                comet_annotations = dialog["coment_annotation"]
                sent_beams = []
                for j_s, sent in enumerate(comet_annotations):
                    # logging (first dialog / first sentence only)
                    if d_i == 0 and j_s == 0:
                        print('For a sent: \n{}'.format(sent['comet']))
                    for effect_name, effect in sent['comet'].items():
                        # if effect_name in EFFECTS:
                        # logging
                        if d_i == 0 and j_s == 0:
                            print('Getting data for effect {}'.format(effect_name))
                            print('Getting {} beams'.format(len(effect['beams'][:args.num_beams])))
                        # Keep at most num_beams expansions per effect.
                        sent_beams += effect['beams'][:args.num_beams]
                if d_i == 0:
                    print('Got {} beams'.format(len(sent_beams)))
                # persona += sent_beams
            for perm in range(args.personality_permutations):
                if args.no_persona:
                    refactored_persona = [[]]
                for i, utterance in enumerate(dialog["utterances"]):
                    # Weak labels are aligned with speaker turns: entry 2*i+1.
                    weak_label = dialog["weak_labels"][2 * i + 1]
                    if not args.no_comet_persona:
                        weak_label_comet = dialog["weak_labels_comet"][2 * i + 1]
                    # making sure we are getting the weak labels for correct utterance
                    # BUGFIX: only consult weak_label_comet when comet personas
                    # are enabled -- it is never assigned under
                    # args.no_comet_persona and dereferencing it raised a
                    # latent NameError on the first mismatched sentence.
                    mismatch = weak_label["sentence"] != utterance["candidates"][-1]
                    if not args.no_comet_persona:
                        mismatch = mismatch and weak_label_comet["sentence"] != utterance["candidates"][-1]
                    if mismatch:
                        print('ERROR!')
                        print(weak_label["sentence"])
                        print(utterance["candidates"][-1])
                    # collect persona weak labels
                    persona_labels = []
                    if len(weak_label["label_persona"]) > 0:
                        for l in weak_label["label_persona"]:
                            persona_labels.append(l["idx"])
                    # refactor persona for the first time
                    # NOTE(review): this overwrites the [[]] assigned above
                    # under args.no_persona -- confirm whether no_persona
                    # should skip the refactoring entirely.
                    refactored_persona = [persona[k] for k in persona_labels]
                    if len(refactored_persona) == 0:
                        refactored_persona = [[]]
                    if not args.no_comet_persona:
                        # Resolve each comet weak label to its beam text.
                        refactored_comet_persona = []
                        if len(weak_label["label_persona"]) > 0:
                            for match in weak_label_comet["label_persona"]:
                                comet_for_sent = comet_annotations[match[0]["persona_sent_id"]]['comet']
                                refactored_comet_persona.append(
                                    comet_for_sent[match[0]["comet_key"]]["beams"][match[0]["beam_id"]])
                        refactored_persona += refactored_comet_persona
                    # permute turn specific refactored persona
                    for _ in range(perm):
                        refactored_persona = [refactored_persona[-1]] + refactored_persona[:-1]
                    # logging for first dialog
                    if d_i == 0:
                        print('Original Persona: {}'.format(persona))
                        print('Weak labels for {}-th persona speaker turn: {}'.format(i, persona_labels))
                        print('Refactored persona for {}-th persona speaker turn: {}'.format(i, refactored_persona))
                    # Truncate context to the last 2*max_history+1 turns.
                    history = utterance["history"][-(2 * args.max_history + 1):]
                    for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
                        # Only the gold (last) candidate carries LM labels.
                        lm_labels = bool(j == num_candidates - 1)
                        instance = build_input_from_segments(
                            refactored_persona, history, candidate, tokenizer, lm_labels)
                        # print('instance: {}'.format(instance))  # debug spam:
                        # one line per training instance; kept commented out to
                        # match the class-based dataset builders.
                        for input_name, input_array in instance.items():
                            datasets[dataset_name][input_name].append(input_array)
                    # Gold candidate index for the multiple-choice head.
                    datasets[dataset_name]["mc_labels"].append(num_candidates - 1)
                    datasets[dataset_name]["n_candidates"] = num_candidates
                # persona = [persona[-1]] + persona[:-1]  # permuted personalities

    print("Pad inputs and convert to Tensor")
    tensor_datasets = {"train": [], "valid": []}
    for dataset_name, dataset in datasets.items():
        dataset = pad_dataset(
            dataset, padding=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1]))
        for input_name in MODEL_INPUTS:
            tensor = torch.tensor(dataset[input_name])
            if input_name != "mc_labels":
                # (batch * n_candidates, ...) -> (batch, n_candidates, ...)
                tensor = tensor.view(
                    (-1, datasets[dataset_name]["n_candidates"]) + tensor.shape[1:])
            tensor_datasets[dataset_name].append(tensor)

    print("Build train and validation dataloaders")
    train_dataset, valid_dataset = TensorDataset(*tensor_datasets["train"]), TensorDataset(*tensor_datasets["valid"])
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
    # shuffle must be False whenever an explicit sampler is supplied
    # (DataLoader raises otherwise), hence shuffle=(not args.distributed).
    train_loader = DataLoader(train_dataset, sampler=train_sampler,
                              batch_size=args.train_batch_size,
                              shuffle=(not args.distributed))
    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler,
                              batch_size=args.valid_batch_size, shuffle=False)
    print("Train dataset (Batch, Candidates, Seq length): {}".format(train_dataset.tensors[0].shape))
    print("Valid dataset (Batch, Candidates, Seq length): {}".format(valid_dataset.tensors[0].shape))
    return train_loader, valid_loader, train_sampler, valid_sampler
def __init__(
        self,
        args,  # Bookkeeping
        tokenizer,
        split,
        debug_mode=False,  # Debugging
        sample=None,
        **kwargs,
):
    """Build per-utterance samples (dict-of-lists) for one dataset split.

    Unlike the defaultdict-based variant, ``self.dataset`` is keyed by the
    fixed ``MODEL_INPUTS`` names and each entry holds one value per
    utterance-level sample. ``self.length`` counts the produced samples.
    """
    super().__init__()
    # Resolve the pad token id once; the list-destructuring asserts exactly
    # one id comes back for "<pad>".
    [self.pad_id] = tokenizer.convert_tokens_to_ids(["<pad>"])
    self.split = split
    self.length = 0
    personachat = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)

    print("Build inputs and labels for {}".format(split))
    self.dataset = {n: [] for n in MODEL_INPUTS}
    # for dataset_name, dataset in personachat.items():
    # Only the requested split is processed.
    personachat_split = personachat[split]
    # Candidates per utterance, optionally capped via --num_candidates.
    self.num_candidates = len(
        personachat_split[0]["utterances"][0]["candidates"])
    if args.num_candidates > 0:  #and split == 'train':
        self.num_candidates = min(args.num_candidates, self.num_candidates)
    if args.test_run_num > 0:
        # Debug runs: restrict to the first N dialogs.
        personachat_split = personachat_split[:args.test_run_num]
        print('Restricted to {} dialogs'.format(len(personachat_split)))
    for d_i, dialog in tqdm(enumerate(personachat_split),
                            total=len(personachat_split)):
        # One effect id per persona sentence; plain persona sentences get
        # the generic 'Persona' effect.
        effects = []
        persona = dialog["personality"].copy()
        effects += [EFFECTS['Persona']] * len(persona)
        if not args.no_comet_persona:
            # NOTE: the key is spelled 'coment_annotation' in the dataset.
            comet_annotations = dialog["coment_annotation"]
            sent_beams = []
            for j_s, sent in enumerate(comet_annotations):
                # logging (first dialog / first sentence only)
                if d_i == 0 and j_s == 0:
                    print('For a sent: \n{}'.format(sent['comet']))
                for effect_name, effect in sent['comet'].items():
                    # if effect_name in EFFECTS:
                    # logging
                    if d_i == 0 and j_s == 0:
                        print('Getting data for effect {}'.format(
                            effect_name))
                        print('Getting {} beams'.format(
                            len(effect['beams'][:args.num_beams])))
                    # Keep at most num_beams expansions per effect.
                    sent_beams += effect['beams'][:args.num_beams]
                    # NOTE(review): always adds exactly num_beams effect ids
                    # even if this effect has fewer beams -- would desync
                    # effects from persona and trip the assert below; confirm
                    # every effect carries >= num_beams beams.
                    effects += [EFFECTS[effect_name]] * args.num_beams
            if d_i == 0:
                print('Got {} beams'.format(len(sent_beams)))
                print('Got {} effects'.format(len(effects)))
            persona += sent_beams
        # Persona sentences and effect ids must stay aligned one-to-one.
        assert len(persona) == len(effects)
        for perm in range(args.personality_permutations):
            if args.no_persona:
                persona = [[]]
            for i, utterance in enumerate(dialog["utterances"]):
                # Truncate dialog context to the last 2*max_history+1 turns.
                history = utterance["history"][-(2 * args.max_history + 1):]
                # Per-utterance sample: persona rows and flattened history
                # are prefixed with the RoBERTa start token.
                sample = {
                    "persona": [[ROBERTA_START] + p for p in persona],
                    "history": [ROBERTA_START] + list(chain(*history)),
                    "effects": effects,
                }
                # Initialise the remaining MODEL_INPUTS keys as empty lists.
                for name in self.dataset.keys():
                    if name not in sample:
                        sample[name] = []
                # One instance per (persona sentence, candidate) pair.
                for persona_sample in persona:
                    for j, candidate in enumerate(utterance["candidates"]
                                                  [-self.num_candidates:]):
                        # Last candidate is the gold response (LM labels on).
                        instance = build_input_from_segments(
                            [persona_sample], history, candidate, tokenizer,
                            j == self.num_candidates - 1)
                        for input_name, input_array in instance.items():
                            sample[input_name].append(input_array)
                # Gold candidate index for the multiple-choice head.
                # NOTE(review): appended once per utterance, not once per
                # persona row -- confirm against the collate/reshape logic.
                sample["mc_labels"].append(self.num_candidates - 1)
                for name, value in sample.items():
                    self.dataset[name].append(value)
                self.length += 1
all_persona_from_joint.append(persona_interpreted) prior_z = model.prior_model.get_prob_z_given_H(persona, history) z = torch.argmax(prior_z, axis=1).item() all_persona_from_prior.append(z) if args.perplexity: average_nll = sum(losses) / len(losses) ppl = math.exp(average_nll) print("Average Loss: {}".format(average_nll)) print("Average PPL: {}".format(ppl)) # interpretability # load dataset if args.interpret: dataset = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)['valid'] if args.test_run_num > 0: dataset = dataset[:args.test_run_num] acc_joint = 0 acc_prior = 0 total_labels = 0 utt_count = 0 for d_i, dialog in tqdm(enumerate(dataset), total=len(dataset)): for i, utterance in enumerate(dialog["utterances"]): weak_label = dialog["weak_labels"][2 * i + 1] if not training_args.no_comet_persona: weak_label_comet = dialog["weak_labels_comet"][2 * i + 1] # making sure we are getting the weak labels for correct utterance if weak_label["sentence"] != utterance["candidates"][ -1] and weak_label_comet["sentence"] != utterance[