示例#1
0
def run():
    """Chat interactively with a fine-tuned persona model from the terminal.

    Parses the command-line options, loads the saved training arguments and
    model weights from ``--model_checkpoint_dir``, picks a random dialog from
    the ``valid`` split to supply the persona, and then loops forever: it
    reads a prompt and a persona choice from stdin, samples a reply with
    ``sample_sequence`` and prints it.  The loop has no exit condition, so
    the session ends only when the process is interrupted.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='persona_comet_weak_label_preprocessed',
                        help="Path or url of the dataset cache")
    # NOTE(review): this option is not read anywhere in this function; the
    # GPT-2 tokenizer and LM-head classes are used unconditionally below.
    parser.add_argument(
        "--model",
        type=str,
        default="openai-gpt",
        help="Model type (openai-gpt or gpt2)",
        choices=['openai-gpt', 'gpt2'])
    parser.add_argument("--model_checkpoint_dir",
                        type=str,
                        default="",
                        help="Path, url or short name of the model")
    parser.add_argument("--load_checkpoint_from",
                        type=str,
                        default="",
                        help="Path, url or short name of the model")

    parser.add_argument(
        "--max_history",
        type=int,
        default=2,
        help="Number of previous utterances to keep in history")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")

    parser.add_argument("--no_sample",
                        action='store_true',
                        help="Set to use greedy decoding instead of sampling")
    parser.add_argument("--max_length",
                        type=int,
                        default=20,
                        help="Maximum length of the output utterances")
    parser.add_argument("--min_length",
                        type=int,
                        default=1,
                        help="Minimum length of the output utterances")
    parser.add_argument("--seed", type=int, default=0, help="Seed")
    # The temperature is a real number (default 0.7); it must be parsed as a
    # float, otherwise passing e.g. `--temperature 0.7` is rejected.
    parser.add_argument("--temperature",
                        type=float,
                        default=0.7,
                        help="Sampling softmax temperature")
    parser.add_argument(
        "--top_k",
        type=int,
        default=0,
        help="Filter top-k tokens before sampling (<=0: no filtering)")
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Nucleus filtering (top-p) before sampling (<=0.0: no filtering)")
    # NOTE(review): not consulted in this function; it may be read by
    # sample_sequence through `args` - confirm.
    parser.add_argument("--comet_greedy",
                        action='store_true',
                        help="Use top-most comet expansion")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__file__)
    logger.info(pformat(args))

    logger.info("Get finetuned model and tokenizer")
    # torch.load unpickles the file: only point this at trusted checkpoints.
    training_args = torch.load(
        os.path.join(args.model_checkpoint_dir, 'model_training_args.bin'))
    print('Loaded training args.')

    # A seed of 0 (the default) means "do not seed".
    if args.seed != 0:
        random.seed(args.seed)
        torch.random.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    tokenizer_class, model_class = (GPT2Tokenizer, GPT2LMHeadModel)
    tokenizer = tokenizer_class.from_pretrained('gpt2')
    orig_num_tokens = len(tokenizer.encoder)
    print('Tokenizer length: {}'.format(orig_num_tokens))
    num_added_tokens = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN)
    print('Tokenizer new length: {}'.format(len(tokenizer.encoder)))
    model = LatentMarginalizedModel(training_args, generator_class=model_class)
    # Grow the embedding matrix so the newly added special tokens have rows.
    model.gpt2_model.resize_token_embeddings(new_num_tokens=orig_num_tokens +
                                             num_added_tokens)

    # Load model weights onto CPU storage first; the model is moved to the
    # requested device afterwards.
    model_checkpoint_path = os.path.join(args.model_checkpoint_dir,
                                         args.load_checkpoint_from)
    model_weights = torch.load(model_checkpoint_path,
                               map_location=lambda storage, loc: storage)

    # strict=False: keys missing from / unexpected in the checkpoint are
    # tolerated instead of raising.
    model.load_state_dict(model_weights, strict=False)
    print('Loaded model weights from {}'.format(model_checkpoint_path))

    model.to(args.device)

    logger.info("Sample a personality")
    dataset = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)
    # The persona is always drawn from the validation split.
    dialogs = dataset['valid']
    index = random.choice(range(len(dialogs)))
    print('Retrieved dialog index: {}'.format(index))
    dialog = dialogs[index]

    personality = dialog['personality']
    print(personality)
    logger.info("Selected personality: %s",
                tokenizer.decode(chain(*personality)))

    history = []
    while True:
        raw_text = input(">>> ")
        while not raw_text:
            print('Prompt should not be empty!')
            raw_text = input(">>> ")
        history.append(tokenizer.encode(raw_text))
        raw_choice = input("Give persona choice >>> ")
        with torch.no_grad():
            out_ids = sample_sequence(personality,
                                      history,
                                      tokenizer,
                                      model,
                                      args,
                                      persona_choice=raw_choice)
        history.append(out_ids)
        # Keep only the most recent 2 * max_history + 1 utterances.
        history = history[-(2 * args.max_history + 1):]
        out_text = tokenizer.decode(out_ids, skip_special_tokens=True)
        print(out_text)
示例#2
0
    def __init__(
        self,
        args,  # Bookkeeping
        tokenizer,
        split,
        debug_mode=False,  # Debugging
        sample=None,
        **kwargs,
    ):
        """Build the model inputs for one split of the persona dataset.

        The split is loaded with ``get_dataset``.  For every dialog the
        persona sentences are extended with COMET beams (unless
        ``args.no_comet_persona``), with a parallel ``effects`` list that
        labels each entry via ``EFFECTS``.  For every utterance, one instance
        per (persona entry, candidate) pair is built with
        ``build_input_from_segments`` and appended to ``self.dataset``.

        Args:
            args: Options object. Attributes read here: ``no_comet_persona``,
                ``dataset_path``, ``dataset_cache``, ``num_candidates``,
                ``test_run_num``, ``num_beams``, ``personality_permutations``,
                ``no_persona`` and ``max_history``.
            tokenizer: Passed through to ``get_dataset`` and
                ``build_input_from_segments``.
            split: Key of the split to load from the dataset dict.
            debug_mode: Not used in this method.
            sample: Not used in this method.
            **kwargs: ``history_folded`` (default ``False``) additionally
                stores the un-flattened history of every utterance.
        """
        super().__init__()

        self.split = split
        self.length = 0

        if args.no_comet_persona:
            self.max_num_persona = MAX_NUM_PERSONA
        else:
            self.max_num_persona = MAX_NUM_COMET_PERSONA

        personachat = get_dataset(tokenizer, args.dataset_path,
                                  args.dataset_cache)
        print("Build inputs and labels for {}".format(split))

        self.dataset = defaultdict(list)
        personachat_split = personachat[split]
        # The candidate count is taken from the first utterance of the first
        # dialog and optionally capped by args.num_candidates.
        num_candidates = len(
            personachat_split[0]["utterances"][0]["candidates"])
        if args.num_candidates > 0:
            num_candidates = min(args.num_candidates, num_candidates)

        if args.test_run_num > 0:
            personachat_split = personachat_split[:args.test_run_num]

        print('Restricted to {} dialogs'.format(len(personachat_split)))

        # Loop-invariant option, looked up once.
        history_folded = kwargs.get('history_folded', False)

        for d_i, dialog in tqdm(enumerate(personachat_split),
                                total=len(personachat_split)):
            effects = []
            persona = dialog["personality"].copy()
            effects += [EFFECTS['Persona']] * len(persona)
            if not args.no_comet_persona:
                comet_annotations = dialog["coment_annotation"]
                sent_beams = []
                for j_s, sent in enumerate(comet_annotations):
                    # Verbose logging for the first sentence of the first
                    # dialog only.
                    if d_i == 0 and j_s == 0:
                        print('For a sent: \n{}'.format(sent['comet']))
                    for effect_name, effect in sent['comet'].items():
                        beams = effect['beams'][:args.num_beams]
                        if d_i == 0 and j_s == 0:
                            print('Getting data for effect {}'.format(
                                effect_name))
                            print('Getting {} beams'.format(len(beams)))
                        sent_beams += beams
                        # One label per beam actually kept: a relation with
                        # fewer than num_beams beams must not desynchronise
                        # `effects` from `persona`.
                        effects += [EFFECTS[effect_name]] * len(beams)
                if d_i == 0:
                    print('Got {} beams'.format(len(sent_beams)))
                    print('Got {} effects'.format(len(effects)))
                persona += sent_beams

            for perm in range(args.personality_permutations):
                if args.no_persona:
                    # NOTE(review): `effects` is not reset here, so the
                    # length assertion below can only hold if it has exactly
                    # one entry - confirm the intended behaviour.
                    persona = [[]]
                else:
                    # Pad both lists up to the fixed maximum; a negative
                    # multiplier yields no padding.
                    persona = persona + [[0]] * (self.max_num_persona -
                                                 len(persona))
                    effects = effects + [0] * (self.max_num_persona -
                                               len(effects))
                for i, utterance in enumerate(dialog["utterances"]):
                    history = utterance["history"][-(2 * args.max_history +
                                                     1):]
                    for persona_sample in persona:
                        for j, candidate in enumerate(
                                utterance["candidates"][-num_candidates:]):
                            # Only the last candidate gets LM labels.
                            lm_labels = bool(j == num_candidates - 1)
                            instance = build_input_from_segments(
                                [persona_sample], history, candidate,
                                tokenizer, lm_labels)
                            for input_name, input_array in instance.items():
                                self.dataset[input_name].append(input_array)

                        self.dataset["mc_labels"].append(num_candidates - 1)

                    self.dataset["persona"].append([[ROBERTA_START] + p
                                                    for p in persona])
                    self.dataset["history"].append([ROBERTA_START] +
                                                   list(chain(*history)))
                    if history_folded:
                        self.dataset["history_folded"].append(history)
                    self.dataset["n_candidates"] = num_candidates
                    assert len(persona) == len(effects)
                    self.dataset["effects"].append(effects)

                    self.length += 1
示例#3
0
def get_data_loaders(args, tokenizer):
    """Prepare the dataset for training and evaluation.

    Loads the dataset with ``get_dataset``, and for every utterance builds
    one instance per candidate using only the persona sentences (and, unless
    ``args.no_comet_persona``, the COMET beams) selected by the weak labels
    of that turn.  The instances are padded, converted to tensors and wrapped
    in ``DataLoader`` objects.

    Args:
        args: Options object. Attributes read here: ``dataset_path``,
            ``dataset_cache``, ``num_candidates``, ``test_run_num``,
            ``no_comet_persona``, ``num_beams``, ``personality_permutations``,
            ``no_persona``, ``max_history``, ``distributed``,
            ``train_batch_size`` and ``valid_batch_size``.
        tokenizer: Passed to ``get_dataset`` / ``build_input_from_segments``
            and used to look up the padding token id.

    Returns:
        ``(train_loader, valid_loader, train_sampler, valid_sampler)``; both
        samplers are ``None`` unless ``args.distributed`` is set.
    """
    personachat = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)

    print("Build inputs and labels")
    datasets = {"train": defaultdict(list), "valid": defaultdict(list)}
    for dataset_name, dataset in personachat.items():
        print('Loading {} set.'.format(dataset_name))
        num_candidates = len(dataset[0]["utterances"][0]["candidates"])
        # The candidate cap is applied to the training split only.
        if args.num_candidates > 0 and dataset_name == 'train':
            num_candidates = min(args.num_candidates, num_candidates)

        if args.test_run_num > 0:
            dataset = dataset[:args.test_run_num]

        for d_i, dialog in enumerate(dataset):
            persona = dialog["personality"].copy()
            if not args.no_comet_persona:
                comet_annotations = dialog["coment_annotation"]
                # `sent_beams` only feeds the log line below; the beams that
                # reach the model are picked per turn via the weak labels.
                sent_beams = []
                for j_s, sent in enumerate(comet_annotations):
                    if d_i == 0 and j_s == 0:
                        print('For a sent: \n{}'.format(sent['comet']))
                    for effect_name, effect in sent['comet'].items():
                        beams = effect['beams'][:args.num_beams]
                        if d_i == 0 and j_s == 0:
                            print('Getting data for effect {}'.format(effect_name))
                            print('Getting {} beams'.format(len(beams)))
                        sent_beams += beams
                if d_i == 0:
                    print('Got {} beams'.format(len(sent_beams)))

            for perm in range(args.personality_permutations):
                if args.no_persona:
                    # NOTE(review): this value is overwritten for every
                    # utterance below, so `no_persona` currently has no
                    # effect here - confirm the intended behaviour.
                    refactored_persona = [[]]
                for i, utterance in enumerate(dialog["utterances"]):
                    target = utterance["candidates"][-1]
                    # Weak labels are indexed by odd positions (2*i + 1).
                    weak_label = dialog["weak_labels"][2*i + 1]
                    # Stays None when COMET personas are disabled, so the
                    # sanity check below never touches an unbound name.
                    weak_label_comet = None
                    if not args.no_comet_persona:
                        weak_label_comet = dialog["weak_labels_comet"][2*i + 1]
                    # Make sure the weak labels belong to this utterance.
                    if weak_label["sentence"] != target and (
                            weak_label_comet is None
                            or weak_label_comet["sentence"] != target):
                        print('ERROR!')
                        print(weak_label["sentence"])
                        print(target)

                    # Indices of the persona sentences picked by the weak labels.
                    persona_labels = [label["idx"] for label in weak_label["label_persona"]]

                    # Keep only the labelled persona sentences; fall back to
                    # a single empty sentence when nothing is labelled.
                    refactored_persona = [persona[k] for k in persona_labels]
                    if len(refactored_persona) == 0:
                        refactored_persona = [[]]

                    if not args.no_comet_persona:
                        refactored_comet_persona = []
                        # NOTE(review): the guard checks `weak_label` while
                        # the loop reads `weak_label_comet` - confirm this is
                        # intended.
                        if len(weak_label["label_persona"]) > 0:
                            for match in weak_label_comet["label_persona"]:
                                first = match[0]
                                comet_for_sent = comet_annotations[first["persona_sent_id"]]['comet']
                                refactored_comet_persona.append(
                                    comet_for_sent[first["comet_key"]]["beams"][first["beam_id"]])

                        refactored_persona += refactored_comet_persona

                    # Rotate the turn-specific persona right by `perm` places.
                    for _ in range(perm):
                        refactored_persona = [refactored_persona[-1]] + refactored_persona[:-1]

                    # Logging for the first dialog.
                    if d_i == 0:
                        print('Original Persona: {}'.format(persona))
                        print('Weak labels for {}-th persona speaker turn: {}'.format(i, persona_labels))
                        print('Refactored persona for {}-th persona speaker turn: {}'.format(i, refactored_persona))

                    history = utterance["history"][-(2*args.max_history+1):]
                    for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
                        # Only the last candidate gets LM labels.
                        lm_labels = bool(j == num_candidates-1)
                        instance = build_input_from_segments(refactored_persona, history, candidate, tokenizer, lm_labels)
                        for input_name, input_array in instance.items():
                            datasets[dataset_name][input_name].append(input_array)
                    datasets[dataset_name]["mc_labels"].append(num_candidates - 1)
                    datasets[dataset_name]["n_candidates"] = num_candidates

    print("Pad inputs and convert to Tensor")
    tensor_datasets = {"train": [], "valid": []}
    for dataset_name, dataset in datasets.items():
        dataset = pad_dataset(dataset, padding=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1]))
        for input_name in MODEL_INPUTS:
            tensor = torch.tensor(dataset[input_name])
            # Every input except mc_labels is regrouped so that its second
            # dimension is the number of candidates.
            if input_name != "mc_labels":
                tensor = tensor.view((-1, datasets[dataset_name]["n_candidates"]) + tensor.shape[1:])
            tensor_datasets[dataset_name].append(tensor)

    print("Build train and validation dataloaders")
    train_dataset, valid_dataset = TensorDataset(*tensor_datasets["train"]), TensorDataset(*tensor_datasets["valid"])
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
    # `shuffle` must be off whenever a sampler is supplied.
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed))
    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False)

    print("Train dataset (Batch, Candidates, Seq length): {}".format(train_dataset.tensors[0].shape))
    print("Valid dataset (Batch, Candidates, Seq length): {}".format(valid_dataset.tensors[0].shape))
    return train_loader, valid_loader, train_sampler, valid_sampler
示例#4
0
    def __init__(
        self,
        args,  # Bookkeeping
        tokenizer,
        split,
        debug_mode=False,  # Debugging
        sample=None,
        **kwargs,
    ):
        """Build one grouped record per utterance for a split of the dataset.

        The split is loaded with ``get_dataset``.  For every dialog the
        persona sentences are extended with COMET beams (unless
        ``args.no_comet_persona``) together with a parallel ``effects`` list
        of ``EFFECTS`` labels.  For every utterance a record is assembled
        that holds the persona, flattened history, effects and the instances
        built by ``build_input_from_segments`` for every (persona entry,
        candidate) pair; each field of the record is appended to the
        matching list in ``self.dataset``.

        Args:
            args: Options object. Attributes read here: ``dataset_path``,
                ``dataset_cache``, ``num_candidates``, ``test_run_num``,
                ``no_comet_persona``, ``num_beams``,
                ``personality_permutations``, ``no_persona`` and
                ``max_history``.
            tokenizer: Used to look up the ``<pad>`` id and passed through to
                ``get_dataset`` and ``build_input_from_segments``.
            split: Key of the split to load from the dataset dict.
            debug_mode: Not used in this method.
            sample: Not used in this method.
            **kwargs: Not used in this method.
        """
        super().__init__()
        [self.pad_id] = tokenizer.convert_tokens_to_ids(["<pad>"])
        self.split = split
        self.length = 0

        personachat = get_dataset(tokenizer, args.dataset_path,
                                  args.dataset_cache)
        print("Build inputs and labels for {}".format(split))

        # NOTE(review): records are appended under the keys "persona",
        # "history" and "effects" as well, so MODEL_INPUTS presumably lists
        # them - confirm.
        self.dataset = {n: [] for n in MODEL_INPUTS}
        personachat_split = personachat[split]
        # The candidate count is taken from the first utterance of the first
        # dialog and optionally capped by args.num_candidates.
        self.num_candidates = len(
            personachat_split[0]["utterances"][0]["candidates"])
        if args.num_candidates > 0:
            self.num_candidates = min(args.num_candidates, self.num_candidates)

        if args.test_run_num > 0:
            personachat_split = personachat_split[:args.test_run_num]

        print('Restricted to {} dialogs'.format(len(personachat_split)))

        for d_i, dialog in tqdm(enumerate(personachat_split),
                                total=len(personachat_split)):
            effects = []
            persona = dialog["personality"].copy()
            effects += [EFFECTS['Persona']] * len(persona)
            if not args.no_comet_persona:
                comet_annotations = dialog["coment_annotation"]
                sent_beams = []
                for j_s, sent in enumerate(comet_annotations):
                    # Verbose logging for the first sentence of the first
                    # dialog only.
                    if d_i == 0 and j_s == 0:
                        print('For a sent: \n{}'.format(sent['comet']))
                    for effect_name, effect in sent['comet'].items():
                        beams = effect['beams'][:args.num_beams]
                        if d_i == 0 and j_s == 0:
                            print('Getting data for effect {}'.format(
                                effect_name))
                            print('Getting {} beams'.format(len(beams)))
                        sent_beams += beams
                        # One label per beam actually kept: a relation with
                        # fewer than num_beams beams must not desynchronise
                        # `effects` from `persona`.
                        effects += [EFFECTS[effect_name]] * len(beams)
                if d_i == 0:
                    print('Got {} beams'.format(len(sent_beams)))
                    print('Got {} effects'.format(len(effects)))
                persona += sent_beams
            assert len(persona) == len(effects)
            for perm in range(args.personality_permutations):
                if args.no_persona:
                    persona = [[]]
                for i, utterance in enumerate(dialog["utterances"]):
                    history = utterance["history"][-(2 * args.max_history +
                                                     1):]
                    # Named `entry` so the `sample` parameter is not shadowed.
                    entry = {
                        "persona": [[ROBERTA_START] + p for p in persona],
                        "history": [ROBERTA_START] + list(chain(*history)),
                        "effects": effects,
                    }
                    # Every remaining dataset field starts as an empty list.
                    for name in self.dataset:
                        entry.setdefault(name, [])
                    for persona_sample in persona:
                        for j, candidate in enumerate(utterance["candidates"]
                                                      [-self.num_candidates:]):
                            # Only the last candidate gets LM labels.
                            instance = build_input_from_segments(
                                [persona_sample], history, candidate,
                                tokenizer, j == self.num_candidates - 1)
                            for input_name, input_array in instance.items():
                                entry[input_name].append(input_array)

                        entry["mc_labels"].append(self.num_candidates - 1)
                    for name, value in entry.items():
                        self.dataset[name].append(value)
                    self.length += 1
示例#5
0
            all_persona_from_joint.append(persona_interpreted)

            prior_z = model.prior_model.get_prob_z_given_H(persona, history)
            z = torch.argmax(prior_z, axis=1).item()
            all_persona_from_prior.append(z)

# When requested, print the mean of `losses` and its exponential.
if args.perplexity:
    # NOTE(review): assumes `losses` holds per-example negative
    # log-likelihoods and is non-empty (an empty list raises
    # ZeroDivisionError) - confirm against the loop that fills it.
    average_nll = sum(losses) / len(losses)
    ppl = math.exp(average_nll)
    print("Average Loss: {}".format(average_nll))
    print("Average PPL: {}".format(ppl))

# interpretability
# load dataset
if args.interpret:
    dataset = get_dataset(tokenizer, args.dataset_path,
                          args.dataset_cache)['valid']
    if args.test_run_num > 0:
        dataset = dataset[:args.test_run_num]

    acc_joint = 0
    acc_prior = 0
    total_labels = 0
    utt_count = 0
    for d_i, dialog in tqdm(enumerate(dataset), total=len(dataset)):
        for i, utterance in enumerate(dialog["utterances"]):
            weak_label = dialog["weak_labels"][2 * i + 1]
            if not training_args.no_comet_persona:
                weak_label_comet = dialog["weak_labels_comet"][2 * i + 1]
            # making sure we are getting the weak labels for correct utterance
            if weak_label["sentence"] != utterance["candidates"][
                    -1] and weak_label_comet["sentence"] != utterance[