Esempio n. 1
0
def fnc(persona, debug=False):
    """Run COMET on every sentence of ``persona`` for category ``"all"``.

    Args:
        persona: iterable of sentences (a list of sentences per the original
            note on this signature).
        debug: if True, print a separator line, the sentence and the category
            before each query.

    Returns:
        A list with one dict per sentence, in input order:
        ``{'sent': sentence, 'comet': <result of interactive.get_atomic_sequence>}``.

    Relies on the module-level names ``interactive``, ``opt``,
    ``sampling_algorithm``, ``data_loader``, ``model`` and ``text_encoder``.
    """
    category = "all"
    # set_sampler only receives module-level settings that this loop never
    # reassigns, so build the sampler once instead of once per sentence.
    sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)

    ret = []
    for sent in persona:
        if debug:
            print()
            print("-x" * 33)
            print("=====>>>> sent = ", sent)
            print()
            print("category = ", category)

        outputs = interactive.get_atomic_sequence(
            sent, model, sampler, data_loader, text_encoder, category)
        ret.append({'sent': sent, 'comet': outputs})

    return ret
Esempio n. 2
0
def query(input_event, category='xNeed', sampling_algorithm='beam-10 '):
    """Query COMET for one relation and return its record with values reversed.

    Args:
        input_event: event text passed to ``interactive.get_atomic_sequence``.
        category: relation name; the result is indexed by this key, so it must
            be a single relation name.
        sampling_algorithm: string handed to ``interactive.set_sampler``.
            NOTE(review): the default carries a trailing space ('beam-10 ');
            presumably tolerated by the parser in set_sampler — confirm.

    Returns:
        The per-relation record, with every value replaced by ``value[::-1]``.
        NOTE(review): the reversal applies to *every* value of the record, not
        only the beam list, so any string-valued entries come back
        character-reversed. Confirm this is intended.

    Relies on module-level ``interactive``, ``opt``, ``data_loader``,
    ``model`` and ``text_encoder``.
    """
    sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)

    everything = interactive.get_atomic_sequence(
        input_event, model, sampler, data_loader, text_encoder, category)
    record = everything[category]

    # Reverse each stored value in place (keys are unchanged).
    for field in list(record):
        record[field] = record[field][::-1]
    return record
Esempio n. 3
0
 def _get_result(self, event: str, category: Sequence[str]) -> Dict:
     """Query COMET for ``event`` over ``category`` and clean up the beams.

     Uses this instance's model, sampler, data loader and text encoder, and
     returns whatever ``self.all_beams_cleanup`` makes of the raw result.
     """
     model, sampler = self._model, self._sampler
     loader, encoder = self._data_loader, self._text_encoder
     raw = interactive.get_atomic_sequence(event, model, sampler,
                                           loader, encoder, category)
     return self.all_beams_cleanup(raw)
def get_comet_prediction(all_story):
    """Predict one COMET reaction per sentence for every story in a batch.

    Args:
        all_story: list (batch) of stories; each story is a list of sentence
            strings (5 per story according to the original note).

    Returns:
        A list parallel to ``all_story``. A story with fewer than 3 sentences
        maps to ``[]``. Otherwise each entry is a list with one string per
        sentence:
        - if ``create_pos`` marks the sentence ``'x'``: the first ``xReact``
          beam, using sampler ``'topk-1'``;
        - otherwise: the first ``oReact`` beam, using ``'topk-2'``, falling
          back to the second beam when the first is ``"none"``.

    Relies on module-level ``find_gender``, ``create_pos``, ``interactive``,
    ``opt``, ``model``, ``data_loader`` and ``text_encoder``.
    """
    max_words = 17  # input events are truncated to this many words
    all_outputs = []
    for story in all_story:
        output = []
        if len(story) < 3:
            all_outputs.append(output)
            continue
        protagonist = find_gender(story)
        pos = create_pos(story, protagonist)
        for k, sentence in enumerate(story):
            # Drop commas, then trim whitespace *before* the terminal
            # punctuation, so "." / "!" are removed even when followed by a
            # space or by each other (e.g. "Hi.! " -> "Hi").
            input_event = sentence.replace(",", "").strip().strip(".!")
            input_event = " ".join(input_event.split()[:max_words])

            is_protagonist = pos[k] == 'x'
            category = 'xReact' if is_protagonist else 'oReact'
            sampling_algorithm = 'topk-1' if is_protagonist else 'topk-2'
            # Rebuilt per sentence because the algorithm alternates.
            # NOTE(review): do not cache one sampler per algorithm without
            # confirming set_sampler has no side effects on the shared ``opt``.
            sampler = interactive.set_sampler(opt, sampling_algorithm,
                                              data_loader)
            out = interactive.get_atomic_sequence(input_event, model, sampler,
                                                  data_loader, text_encoder,
                                                  category)
            beams = out[category]["beams"]
            if category == "oReact" and beams[0] == "none":
                output.append(beams[1])
            else:
                output.append(beams[0])
        all_outputs.append(output)
    return all_outputs  # list(bs) of list(5)
def gen_preds(model_obj, text_chunk):
    """Run COMET for a fixed set of relations on every sentence of a chunk.

    Args:
        model_obj: 4-item sequence ``(model, sampler, data_loader, text_encoder)``.
        text_chunk: iterable of sentence strings.

    Returns:
        Dict keyed by 0-based sentence index. Each value maps ``'oEffect'``,
        ``'xEffect'`` and ``'xNeed'`` to the entry that
        ``interactive.get_atomic_sequence`` returned under that relation key.
    """
    model, sampler, data_loader, text_encoder = model_obj
    relations = ('oEffect', 'xEffect', 'xNeed')

    def predict(sentence, relation):
        result = interactive.get_atomic_sequence(sentence, model, sampler,
                                                 data_loader, text_encoder,
                                                 relation)
        return result[relation]

    return {
        sent_idx: {relation: predict(sentence, relation) for relation in relations}
        for sent_idx, sentence in enumerate(text_chunk)
    }
Esempio n. 6
0
def augment(article):
    """Append COMET's top ``oEffect`` inference to a headline.

    Args:
        article: object whose ``.numpy()`` yields UTF-8 encoded bytes
            (presumably a scalar ``tf.string`` tensor — TODO confirm).

    Returns:
        ``str``: the decoded headline, a single space, then the first
        ``oEffect`` beam.

    Before querying COMET, the first three PERSON/NORP entities (in the order
    ``nlp`` reports them) are replaced by ``PersonX``/``PersonY``/``PersonZ``.
    Relies on module-level ``nlp``, ``interactive``, ``model``, ``sampler``,
    ``data_loader`` and ``text_encoder``.
    """
    title_new = article.numpy().decode('UTF-8')
    category = "oEffect"
    placeholders = ["PersonX", "PersonY", "PersonZ"]

    people = [entity.text for entity in nlp(title_new).ents
              if entity.label_ in ('PERSON', 'NORP')]
    input_event = title_new
    # zip caps the number of masked entities at len(placeholders).
    for name, placeholder in zip(people, placeholders):
        input_event = input_event.replace(name, placeholder)

    outputs = interactive.get_atomic_sequence(input_event, model, sampler,
                                              data_loader, text_encoder,
                                              category)

    # Only one relation was requested, so read it directly by key.
    top_beam = outputs[category]["beams"][0]
    # Separate the headline from the generated inference.
    return title_new + " " + top_beam
        # Interactive-prompt fragment: the enclosing loop/function header is
        # not visible in this excerpt. Each value starts as "help" (or the CLI
        # default) so that its prompt loop below runs at least once when needed.
        input_event = "help"
        category = "help"
        sampling_algorithm = args.sampling_algorithm

        # Ask for an event until the user types something other than "help".
        # NOTE(review): input() never returns None, so the None check is
        # redundant. Also, the loop condition lower-cases the input but the
        # help check below does not, so "HELP" re-prompts without printing help.
        while input_event is None or input_event.lower() == "help":
            input_event = input("Give an event (e.g., PersonX went to the mall): ")

            if input_event == "help":
                interactive.print_help(opt.dataset)

        # Same pattern for the relation/effect type (same case mismatch).
        while category.lower() == "help":
            category = input("Give an effect type (type \"help\" for an explanation): ")

            if category == "help":
                interactive.print_category_help(opt.dataset)

        # Only prompts when the CLI default for the sampling algorithm is "help".
        while sampling_algorithm.lower() == "help":
            sampling_algorithm = input("Give a sampling algorithm (type \"help\" for an explanation): ")

            if sampling_algorithm == "help":
                interactive.print_sampling_help()

        sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)

        # Unknown relation names silently fall back to querying every category.
        if category not in data_loader.categories:
            category = "all"

        # The result is not used within this visible fragment.
        outputs = interactive.get_atomic_sequence(
            input_event, model, sampler, data_loader, text_encoder, category)

    def augment(article):
        """Append COMET inferences for four ATOMIC relations to ``article``.

        Args:
            article: object whose ``.numpy()`` yields UTF-8 encoded bytes
                (presumably a scalar ``tf.string`` tensor — TODO confirm).
                Generated text is appended to it with ``+=``, so the return
                type is whatever ``article + str`` produces.

        Returns:
            ``article`` followed by one sentence per relation key returned by
            COMET. Each sentence is built from the first beam among the first
            five that is not ``'none'``.

        Up to three PERSON/NORP entities are masked as PersonX/PersonY/PersonZ
        before querying. The first masked name becomes the subject of
        x-relation sentences ("Person" if none was masked), and the names are
        substituted back for the lower-cased placeholders in generated text.
        Relies on ``nlp``, ``interactive``, ``model``, ``sampler``,
        ``data_loader`` and ``text_encoder`` from the enclosing scope.
        """
        context = article.numpy().decode('UTF-8')
        category_list = ["xNeed", "xIntent", "xWant", "xReact"]
        replacement_list = ["PersonX", "PersonY", "PersonZ"]
        # Sentence openers for relations about other people ("o*" keys).
        o_prefixes = {
            "oEffect": " Everyone else ",
            "oReact": "They are ",
            "oWant": "They want ",
        }
        # Text placed after the subject for relations about PersonX ("x*" keys).
        x_verbs = {
            "xAttr": " is ",
            "xEffect": " ",
            "xIntent": " intends ",
            "xReact": " is ",
            "xNeed": " needs ",
            "xWant": " wants ",
        }

        # Masking depends only on ``context``, so do it once for all categories.
        input_event = context
        replaced = []  # original entity texts, in placeholder order
        for entity in nlp(context).ents:
            if entity.label_ in ('PERSON', 'NORP'):
                input_event = input_event.replace(
                    entity.text, replacement_list[len(replaced)])
                replaced.append(entity.text)
                if len(replaced) == len(replacement_list):
                    break
        subject = replaced[0] if replaced else "Person"

        for category in category_list:
            outputs = interactive.get_atomic_sequence(input_event, model,
                                                      sampler, data_loader,
                                                      text_encoder, category)

            for key in outputs:
                if key.startswith("o"):
                    prefix = o_prefixes.get(key, "")
                else:
                    prefix = subject + x_verbs.get(key, "")

                # Slicing tolerates fewer than five beams.
                for comet_inf in outputs[key]["beams"][:5]:
                    if comet_inf == 'none':
                        continue
                    # Swap real names back for the lower-cased placeholders
                    # ("personx", "persony", ...).
                    for placeholder, name in zip(replacement_list, replaced):
                        comet_inf = comet_inf.replace(placeholder.lower(), name)
                    article += prefix + comet_inf + ". "
                    break

        return article
Esempio n. 9
0
    def augment(self, article, dl, nlp, model, sampler, text_encoder):
        """Append COMET inferences for four ATOMIC relations to ``article``.

        Args:
            article: input text. It is masked with ``str.replace`` and
                extended with ``+=``, so it is expected to be a ``str``.
            dl: COMET data loader passed to ``interactive.get_atomic_sequence``.
            nlp: callable whose result exposes ``.ents`` (presumably a spaCy
                pipeline — TODO confirm).
            model, sampler, text_encoder: passed through to
                ``interactive.get_atomic_sequence``.

        Returns:
            ``article`` followed by one sentence per relation key returned by
            COMET. Each sentence is built from the first beam among the first
            five that is not ``'none'``.

        The first two entities of any label are masked as PersonX/PersonY
        before querying. The first masked text becomes the subject of
        x-relation sentences ("Person" if there were no entities), and the
        texts are substituted back for the lower-cased placeholders in the
        generated text.
        NOTE(review): unlike the PERSON/NORP-filtered variant, every entity
        label (dates, places, ...) is masked here — confirm this is intended.
        """
        category_list = ["xNeed", "xIntent", "xWant", "xReact"]
        replacement_list = ["PersonX", "PersonY"]
        # Sentence openers for relations about other people ("o*" keys).
        o_prefixes = {
            "oEffect": " Everyone else ",
            "oReact": "They are ",
            "oWant": "They want ",
        }
        # Text placed after the subject for relations about PersonX ("x*" keys).
        x_verbs = {
            "xAttr": " is ",
            "xEffect": " ",
            "xIntent": " intends ",
            "xReact": " is ",
            "xNeed": " needs ",
            "xWant": " wants ",
        }

        # Mask once, from the unmodified input, instead of re-applying the
        # replacements to the already-masked text for every category.
        input_event = article
        replaced = []  # original entity texts, in placeholder order
        for entity in nlp(article).ents:
            input_event = input_event.replace(entity.text,
                                              replacement_list[len(replaced)])
            replaced.append(entity.text)
            if len(replaced) == len(replacement_list):
                break
        subject = replaced[0] if replaced else "Person"

        for category in category_list:
            outputs = interactive.get_atomic_sequence(input_event, model,
                                                      sampler, dl,
                                                      text_encoder, category)
            for key in outputs:
                if key.startswith("o"):
                    prefix = o_prefixes.get(key, "")
                else:
                    prefix = subject + x_verbs.get(key, "")

                # Slicing tolerates fewer than five beams.
                for comet_inf in outputs[key]["beams"][:5]:
                    if comet_inf == 'none':
                        continue
                    # Swap real names back for "personx" / "persony".
                    for placeholder, name in zip(replacement_list, replaced):
                        comet_inf = comet_inf.replace(placeholder.lower(), name)
                    article += prefix + comet_inf + ". "
                    break

        return article
    # Batch-inference fragment: the enclosing function header is not visible.
    # De-duplicate the input events while keeping first-seen order.
    # NOTE(review): list membership makes this O(n^2); for hashable items
    # ``list(dict.fromkeys(list_of_input_events))`` is linear and order-preserving.
    news_list_of_input_events = []
    for i in list_of_input_events:
        if i not in news_list_of_input_events:
            news_list_of_input_events.append(i)
    list_of_input_events = news_list_of_input_events

    # (A set() would not preserve order, hence the manual loop above.)
    #list_of_input_events = list(set(list_of_input_events))
    # Hard-coded shard: only items 10000..19999 of the de-duplicated list.
    list_of_input_events = list_of_input_events[10000:20000]

    print(len(list_of_input_events))
    print(list_of_input_events[:5])

    # Opened in append mode; each result is written as its Python repr
    # (str(outputs)), one per line — not JSON.
    # NOTE(review): the file is not closed if an exception escapes the loop;
    # a ``with open(...)`` block would guarantee it.
    output_file = open(args.output_file, 'a', encoding='utf-8')

    for idx, sentence in enumerate(list_of_input_events):

        # Each sentence is framed as something PersonX says.
        outputs = interactive.get_atomic_sequence("PersonX says " + sentence,
                                                  model, sampler, data_loader,
                                                  text_encoder, args.category)
        # Progress report every 100 events.
        if idx % 100 == 0:
            print(idx)
        #print(outputs)
        output_file.write(str(outputs))
        output_file.write("\n")
        # Flushed per line so partial results survive an interrupted run.
        output_file.flush()
    output_file.close()
    # print(type(outputs))

    # with open(args.output_file, 'w', encoding = 'utf-8') as f:
    #     json.dump(corpus,f,sort_keys=False, indent=4, separators=(',', ': '), ensure_ascii=False)
    # Model-setup fragment: the enclosing function header is not visible.
    # Context length covers the longest event plus the longest effect.
    n_ctx = data_loader.max_event + data_loader.max_effect
    # Vocabulary size is the text-encoder vocabulary plus n_ctx extra rows
    # (presumably reserved for position embeddings — confirm against make_model).
    n_vocab = len(text_encoder.encoder) + n_ctx
    model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)

    # Any device string other than "cpu" is parsed as an integer CUDA index
    # (so e.g. "cuda:0" would raise ValueError here).
    if args.device != "cpu":
        cfg.device = int(args.device)
        cfg.do_gpu = True
        torch.cuda.set_device(cfg.device)
        model.cuda(cfg.device)
    else:
        cfg.device = "cpu"

    # One sampler is built and reused for the query below.
    sampling_algorithm = args.sampling_algorithm
    #category = args.category
    #args.category = ["xReact","oReact"]
    sampler = interactive.set_sampler(opt, sampling_algorithm, data_loader)

    # Sample inputs; only the first one is actually queried below.
    list_of_input_events = [
        "I will go to school! I will go to school! I will go to school! I will go to school! I will go to school! I will go to school!",
        "good",
        "beautiful campus",
    ]

    outputs = interactive.get_atomic_sequence(list_of_input_events[0], model,
                                              sampler, data_loader,
                                              text_encoder, args.category)

    print(type(outputs))
    print(outputs)
Esempio n. 12
0
 def _get_sequence(self, event, relations, sampler):
     """Query COMET for ``event`` and ``relations`` with an explicit sampler.

     Uses this instance's model, data loader and text encoder, and returns
     the raw result of ``interactive.get_atomic_sequence`` unchanged.
     """
     model = self._model
     loader, encoder = self._data_loader, self._text_encoder
     return interactive.get_atomic_sequence(event, model, sampler,
                                            loader, encoder, relations)