Example #1
def query(input_event, category='xNeed', sampling_algorithm='beam-10'):
    sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)

    # Run COMET on the event and keep only the requested category.
    outputs = interactive.get_atomic_sequence(
        input_event, model, sampler, data_loader, text_encoder, category)[category]
    # Reverse each returned value (e.g., the beam list).
    for k, v in outputs.items():
        outputs[k] = v[::-1]
    return outputs
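query assumes the COMET model and its companions already live at module scope. A minimal setup sketch mirroring the loading code in Examples #6 and #8 on this page (the import path follows the comet-commonsense repository layout; the pickle path is the one the other examples use):

import src.interactive.functions as interactive  # comet-commonsense helpers

opt, state_dict = interactive.load_model_file(
    "pretrained_models/atomic_pretrained_model.pickle")
data_loader, text_encoder = interactive.load_data("atomic", opt)

# Context length and vocabulary size, computed as in Example #6.
n_ctx = data_loader.max_event + data_loader.max_effect
n_vocab = len(text_encoder.encoder) + n_ctx
model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)

print(query("PersonX went to the mall", category="xNeed"))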
Example #2
def fnc(persona, debug=False):
    """Run COMET on each sentence of a persona (a list of sentences).

    Relies on module-level globals: opt, sampling_algorithm, model,
    data_loader, text_encoder.
    """
    ret = []
    for sent in persona:
        cur_sent = {'sent': sent}
        if debug:
            print()
            print("-x" * 33)
            print("=====>>>> sent = ", sent)

        input_event = sent
        category = "all"

        if debug:
            print()
            print("category = ", category)

        sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)
        outputs = interactive.get_atomic_sequence(
            input_event, model, sampler, data_loader, text_encoder, category)
        cur_sent['comet'] = outputs
        ret.append(cur_sent)

    return ret
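A usage sketch for fnc, assuming the module-scope setup from the sketch under Example #1; with category "all", get_atomic_sequence returns one entry per ATOMIC category, each carrying its decoded beams:

sampling_algorithm = "beam-5"  # read as a global by fnc

persona = ["i like to ski.", "i hate mexican food."]
for entry in fnc(persona):
    print(entry['sent'])
    # Each category entry carries its decoded beams.
    print(entry['comet']['xWant']['beams'])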
Example #3
def get_comet_prediction(all_story):
    # all_story: list (batch size) of stories, each a list of 5 sentences
    all_outputs = []
    for story in all_story:
        output = []
        if len(story) < 3:
            all_outputs.append(output)
            continue
        protagonist = find_gender(story)
        pos = create_pos(story, protagonist)
        for k, input_event in enumerate(story):
            input_event = input_event.replace(
                ",", "").strip(".").strip("!").strip()
            # Keep at most the first 17 tokens of the event.
            input_event = ' '.join(input_event.split()[:17])
            # Protagonist sentences get xReact; everyone else gets oReact.
            category = 'xReact' if pos[k] == 'x' else 'oReact'
            sampling_algorithm = 'topk-1' if pos[k] == 'x' else 'topk-2'
            sampler = interactive.set_sampler(opt, sampling_algorithm,
                                              data_loader)
            out = interactive.get_atomic_sequence(input_event, model, sampler,
                                                  data_loader, text_encoder,
                                                  category)
            # Fall back to the second beam when oReact's top beam is "none".
            if category == "oReact" and out[category]["beams"][0] == "none":
                output.append(out[category]["beams"][1])
            else:
                output.append(out[category]["beams"][0])
        all_outputs.append(output)
    return all_outputs  # list(bs) of list(5)
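A minimal calling sketch. find_gender and create_pos are helpers not shown on this page, so their behavior (returning a protagonist and a per-sentence 'x'/'o' tag list) is inferred from how they are used above:

# One five-sentence story, batched as the inline comment describes.
story = [
    "Jenny went to the store.",
    "She bought some milk.",
    "Her brother asked for cookies.",
    "Jenny bought those too.",
    "They ate them together.",
]

# Returns one list of per-sentence React inferences per story.
reactions = get_comet_prediction([story])
print(reactions[0])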
Example #4
    def infer(self, event, relations=['all'], sampling_algorithm='greedy'):
        sampler = interactive.set_sampler(self._opt, sampling_algorithm,
                                          self._data_loader)

        if isinstance(relations, str):
            relations = [relations]

        if 'all' not in relations and not set(
                relations) <= self.get_relations():
            raise ValueError("Unknown relations: %s"
                             % (set(relations) - self.get_relations()))

        return self._get_sequence(event, relations, sampler)
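A usage sketch, assuming this method belongs to a COMET wrapper class like the one constructed in Example #5 (the class name CometWrapper is hypothetical, and the return shape is assumed to match get_atomic_sequence's per-category dicts):

# Hypothetical wrapper; the constructor signature mirrors Example #5 below.
wrapper = CometWrapper(
    graph="atomic",
    model_path="pretrained_models/atomic_pretrained_model.pickle",
    decoding_algorithm="beam-5")

result = wrapper.infer("PersonX went to the mall",
                       relations=["xReact"],
                       sampling_algorithm="greedy")
print(result["xReact"]["beams"])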
Example #5
    def __init__(self, graph, model_path, decoding_algorithm):
        self.graph = graph
        self.model_path = model_path
        self.decoding_algorithm = decoding_algorithm

        self._opt, self._state_dict = interactive.load_model_file(
            self.model_path)
        self._data_loader, self._text_encoder = interactive.load_data(
            self.graph, self._opt)
        self._sampler = interactive.set_sampler(self._opt,
                                                self.decoding_algorithm,
                                                self._data_loader)
        self._n_ctx = self._calc_n_ctx()
        # The vocabulary is extended with n_ctx positional tokens.
        self._n_vocab = len(self._text_encoder.encoder) + self._n_ctx

        self._model = interactive.make_model(self._opt, self._n_vocab,
                                             self._n_ctx, self._state_dict)
        self._model.to(device=settings.device)

        self._input_event_model = None
        self._response_model = None
        self._annotator_input_model = None
        self._annotator_response_model = None
Example #6
def fetch_model(args):
    opt, state_dict = interactive.load_model_file(args.model_file)

    data_loader, text_encoder = interactive.load_data("atomic", opt)

    # Context length = longest event + longest effect; the vocabulary is
    # extended with n_ctx positional tokens.
    n_ctx = data_loader.max_event + data_loader.max_effect
    n_vocab = len(text_encoder.encoder) + n_ctx
    model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)

    if args.device != "cpu":
        cfg.device = int(args.device)
        cfg.do_gpu = True
        torch.cuda.set_device(cfg.device)
        model.cuda(cfg.device)
    else:
        cfg.device = "cpu"

    # Set the sampling algorithm
    sampling_algorithm = args.sampling_algorithm
    sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)

    return model, sampler, data_loader, text_encoder
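An invocation sketch, assuming an argparse-style args object with the three fields fetch_model reads; the concrete values are illustrative:

import argparse

args = argparse.Namespace(
    model_file="pretrained_models/atomic_pretrained_model.pickle",
    device="cpu",
    sampling_algorithm="beam-5",
)

model, sampler, data_loader, text_encoder = fetch_model(args)
outputs = interactive.get_atomic_sequence(
    "PersonX went to the mall", model, sampler, data_loader,
    text_encoder, "xReact")
print(outputs["xReact"]["beams"])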
        input_event = "help"
        category = "help"
        sampling_algorithm = args.sampling_algorithm

        while input_event is None or input_event.lower() == "help":
            input_event = input("Give an event (e.g., PersonX went to the mall): ")

            if input_event == "help":
                interactive.print_help(opt.dataset)

        while category.lower() == "help":
            category = input("Give an effect type (type \"help\" for an explanation): ")

            if category == "help":
                interactive.print_category_help(opt.dataset)

        while sampling_algorithm.lower() == "help":
            sampling_algorithm = input("Give a sampling algorithm (type \"help\" for an explanation): ")

            if sampling_algorithm == "help":
                interactive.print_sampling_help()

        sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)

        if category not in data_loader.categories:
            category = "all"

        outputs = interactive.get_atomic_sequence(
            input_event, model, sampler, data_loader, text_encoder, category)
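The indented fragment above sits inside an interactive driver; a simplified one-shot sketch of the same flow (assumed, not the original script), given a fetch_model-style setup for opt, model, data_loader, and text_encoder:

# Hypothetical one-shot driver using the same prompts as the fragment above.
input_event = input("Give an event (e.g., PersonX went to the mall): ")
category = input("Give an effect type (e.g., xReact, or all): ")

sampler = interactive.set_sampler(opt, "beam-5", data_loader)
if category not in data_loader.categories:
    category = "all"

outputs = interactive.get_atomic_sequence(
    input_event, model, sampler, data_loader, text_encoder, category)
print(outputs)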

Example #8
def load_and_cache_examples(args,
                            tokenizer,
                            evaluate=False,
                            output_examples=False):
    # Set up the COMET model used to augment contexts.
    device = "0"
    comet_model = "pretrained_models/atomic_pretrained_model.pickle"
    sampling_algo = "beam-2"
    opt, state_dict = interactive.load_model_file(comet_model)

    data_loader, text_encoder = interactive.load_data("atomic", opt)

    n_ctx = data_loader.max_event + data_loader.max_effect
    n_vocab = len(text_encoder.encoder) + n_ctx
    model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)
    nlp = spacy.load("en_core_web_sm")

    if device != "cpu":
        cfg.device = int(device)
        cfg.do_gpu = True
        torch.cuda.set_device(cfg.device)
        model.cuda(cfg.device)
    else:
        cfg.device = "cpu"

    sampling_algorithm = sampling_algo

    sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)

    def augment(article):
        context = article.numpy().decode('UTF-8')
        # Accumulate COMET inferences after the raw context; appending to the
        # incoming tf.string tensor itself would fail.
        augmented = context + " "

        category_list = ["xNeed", "xIntent", "xWant", "xReact"]

        for category in category_list:
            # Anonymize up to three PERSON/NORP entities so the event uses
            # COMET's PersonX/PersonY/PersonZ placeholders.
            entity_list = nlp(context)
            input_event = context
            replaced = []
            replacement_list = ["PersonX", "PersonY", "PersonZ"]
            for entity in entity_list.ents:
                if entity.label_ in ('PERSON', 'NORP'):
                    input_event = input_event.replace(
                        entity.text, replacement_list[len(replaced)])
                    replaced.append(entity.text)
                    if len(replaced) == 3:
                        break

            outputs = interactive.get_atomic_sequence(input_event, model,
                                                      sampler, data_loader,
                                                      text_encoder, category)

            for key in outputs:
                prefix = ""
                if key[0] == "o":
                    if key == "oEffect":
                        prefix = "Everyone else "
                    elif key == "oReact":
                        prefix = "They are "
                    elif key == "oWant":
                        prefix = "They want "
                else:
                    prefix = replaced[0] if replaced else "Person"
                    if key == "xAttr":
                        prefix += " is "
                    elif key == "xEffect":
                        prefix += " "
                    elif key == "xIntent":
                        prefix += " intends "
                    elif key == "xReact":
                        prefix += " is "
                    elif key == "xNeed":
                        prefix += " needs "
                    elif key == "xWant":
                        prefix += " wants "

                # Keep the first beam that is not "none"; the number of beams
                # depends on the sampler (beam-2 here), so iterate instead of
                # indexing a fixed range of 5.
                for comet_inf in outputs[key]["beams"]:
                    if comet_inf != 'none':
                        if replaced:
                            comet_inf = comet_inf.replace("personx", replaced[0])
                            if len(replaced) > 1:
                                comet_inf = comet_inf.replace("persony", replaced[1])
                        augmented += prefix + comet_inf + ". "
                        break

        return augmented

    def process_example(example):

        example['context'] = tf.py_function(func=augment,
                                            inp=[example['context']],
                                            Tout=tf.string)
        return example

    ## End

    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training processes
        # the dataset; the others will use the cache.
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s",
                    cached_features_file)
        features_and_dataset = torch.load(cached_features_file)
        features, dataset, examples = (
            features_and_dataset["features"],
            features_and_dataset["dataset"],
            features_and_dataset["examples"],
        )
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file) or
                                  (not evaluate and not args.train_file)):
            try:
                import tensorflow_datasets as tfds
            except ImportError:
                raise ImportError(
                    "If no data_dir is specified, tensorflow_datasets needs to be installed."
                )

            if args.version_2_with_negative:
                logger.warning(
                    "tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            tfds_examples["train"] = tfds_examples["train"].map(
                lambda x: process_example(x))
            examples = SquadV1Processor().get_examples_from_dataset(
                tfds_examples, evaluate=evaluate)
        else:
            processor = SquadV2Processor(
            ) if args.version_2_with_negative else SquadV1Processor()
            if evaluate:
                examples = processor.get_dev_examples(
                    args.data_dir, filename=args.predict_file)
            else:
                examples = processor.get_train_examples(
                    args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s",
                        cached_features_file)
            torch.save(
                {
                    "features": features,
                    "dataset": dataset,
                    "examples": examples
                }, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training processes
        # the dataset; the others will use the cache.
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset
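A sketch of how this loader is typically driven in the Hugging Face SQuAD fine-tuning scripts it is adapted from; the args fields mirror exactly what the function reads, and the concrete values are illustrative assumptions:

import argparse
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
args = argparse.Namespace(
    local_rank=-1,
    data_dir="data/squad",
    model_name_or_path="bert-base-uncased",
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    threads=4,
    overwrite_cache=False,
    version_2_with_negative=False,
    train_file="train-v1.1.json",
    predict_file="dev-v1.1.json",
)

train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
eval_dataset, examples, features = load_and_cache_examples(
    args, tokenizer, evaluate=True, output_examples=True)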
Example #9
    def _create_examples(self, lines, labels, is_training):
        """Creates examples for the training and dev sets."""
        # if type == "train" and lines[0][-1] != "label":
        #     raise ValueError("For training, the input file must contain a label column.")
        device = 0
        comet_model = "pretrained_models/atomic_pretrained_model.pickle"
        sampling_algo = "beam-5"
        opt, state_dict = interactive.load_model_file(comet_model)

        data_loader, text_encoder = interactive.load_data("atomic", opt)

        n_ctx = data_loader.max_event + data_loader.max_effect
        n_vocab = len(text_encoder.encoder) + n_ctx
        model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)
        nlp = spacy.load("en_core_web_sm")

        if device != "cpu":
            cfg.device = int(device)
            cfg.do_gpu = True
            torch.cuda.set_device(cfg.device)
            model.cuda(cfg.device)
        else:
            cfg.device = "cpu"

        sampling_algorithm = sampling_algo

        sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)

        examples = []
        fail_count = 0

        # Cap the event length COMET will encode.
        data_loader.max_event = 17

        i = 0
        while i < len(lines) - 2:
            if i % 50 == 0:
                print("progress:", i)
            # print(lines[i]["context"])
            # import pdb; pdb.set_trace()

            new_context = lines[i]["context"].split()[:30]
            new_context = ' '.join(new_context)

            # print("KIRBY: ", new_context)

            # aug_context = augment(new_context)

            # if wentToPast:
            #   data_loader.max_event = 17
            #   wentToPast = False

            try:
                # print("HERE; ", data_loader.max_event, len(examples))
                aug_context = self.augment(new_context, data_loader, nlp,
                                           model, sampler, text_encoder)
                # print("HI FRIENDS")

                examples += [
                    SQA(input_id=str(i),
                        contexts=aug_context,
                        question=lines[i]["question"],
                        choice_1=lines[i]["answerA"],
                        choice_2=lines[i]["answerB"],
                        choice_3=lines[i]["answerC"],
                        label=str(int(labels[i]) - 1))
                ]

                i += 1

            except Exception:
                # Skip lines COMET fails on; without this increment the loop
                # would retry the same line forever.
                fail_count += 1
                i += 1

        print("TOTAL EXAMPLES TODAY IS: ", len(examples))
        print(examples)
        return examples
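Illustrative input shapes for _create_examples, inferred from the keys it reads (SocialIQA-style records); the processor class name is hypothetical:

lines = [
    {"context": "Alex spilled the food she just prepared all over the floor.",
     "question": "What will Alex want to do next?",
     "answerA": "taste the food",
     "answerB": "mop up",
     "answerC": "run around in a mess"},
    # ... more records ...
]
labels = ["2"]  # 1-indexed in the file; stored as 0-indexed label strings

processor = SocialIQAProcessor()  # hypothetical class owning this method
examples = processor._create_examples(lines, labels, is_training=True)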