Пример #1
0
def visualize(batch_data, model, i_datum=0):
    """Show one example of a batch alongside the model's predictions.

    Sends six fields to ``visualizer.show``: the question text, the chosen
    parse, an HTML ``<img>`` tag for the image, an attention map, the model's
    top-5 answers (best first) and the ground-truth answers.

    Args:
        batch_data: indexable batch of datums; the one at ``i_datum`` is shown.
        model: reads ``prediction_data`` (row ``i_datum`` holds one score per
            answer index) and ``layout_ids`` (index into ``datum.parses``).
        i_datum: which datum of the batch to show (default 0, the previously
            hard-coded choice).
    """
    datum = batch_data[i_datum]
    # Drop the first and last token (presumably sentence-boundary markers --
    # TODO confirm). Kept as a plain string: a trailing comma here would
    # silently wrap it in a 1-tuple.
    question = " ".join([QUESTION_INDEX.get(w) for w in datum.question[1:-1]])
    preds = model.prediction_data[i_datum, :]
    # argsort is ascending, so the last five indices are the five highest
    # scores; reversed() lists the best answer first.
    top = np.argsort(preds)[-5:]
    top_answers = reversed([ANSWER_INDEX.get(p) for p in top])
    # TODO: show the real attention map. Earlier code read it from
    # model.apollo_net.blobs["Find_%d_softmax" % (layout_choice * 100 + 1)]
    # (layout_choice = model.module_layout_choices[i_datum]) and reshaped it
    # to 14x14; until that is restored a blank 14x14 map is displayed.
    att_data = np.zeros((14, 14))
    chosen_parse = datum.parses[model.layout_ids[i_datum]]

    fields = [
        question,
        str(chosen_parse),
        "<img src='../../%s'>" % datum.image_path,
        att_data,
        ", ".join(top_answers),
        ", ".join([ANSWER_INDEX.get(a) for a in datum.answers]),
    ]
    visualizer.show(fields)
Пример #2
0
Файл: vqa.py Проект: zxsted/nmn2
def prepare_indices(config):
    """Fill the global vocabularies from the ``train2014`` split.

    Three passes, one per source file:

    1. Question words (tokenised by ``proc_question``) occurring at least
       ``MIN_COUNT`` times are passed to ``QUESTION_INDEX.index``.
    2. Tokens of the multi-parse file occurring at least ``10 * MIN_COUNT``
       times are passed to ``MODULE_INDEX.index``.
    3. Among answers with ``answer_confidence == "yes"`` whose text has no
       character outside ``[\\w\\s]``, the ``config.answers`` most frequent
       are passed to ``ANSWER_INDEX.index`` in descending frequency order
       (ties broken by descending answer string).
    """
    split = "train2014"

    # --- question vocabulary -------------------------------------------
    word_freq = defaultdict(int)
    with open(QUESTION_FILE % split) as handle:
        for entry in json.load(handle)["questions"]:
            for token in proc_question(entry["question"]):
                word_freq[token] += 1
    for token, freq in word_freq.items():
        if freq >= MIN_COUNT:
            QUESTION_INDEX.index(token)

    # --- module (predicate) vocabulary --------------------------------
    pred_freq = defaultdict(int)
    with open(MULTI_PARSE_FILE % split) as handle:
        for raw in handle:
            cleaned = raw.strip()
            # Drop parentheses and treat ";" as whitespace so that only the
            # bare predicate names remain.
            for old, new in (("(", ""), (")", ""), (";", " ")):
                cleaned = cleaned.replace(old, new)
            for symbol in cleaned.split():
                pred_freq[symbol] += 1
    min_pred_freq = 10 * MIN_COUNT
    for symbol, freq in pred_freq.items():
        if freq >= min_pred_freq:
            MODULE_INDEX.index(symbol)

    # --- answer vocabulary ---------------------------------------------
    answer_freq = defaultdict(int)
    with open(ANN_FILE % split) as handle:
        for record in json.load(handle)["annotations"]:
            for candidate in record["answers"]:
                # Confidence is checked first; the answer text is only
                # inspected for confident answers.
                if (candidate["answer_confidence"] == "yes"
                        and not re.search(r"[^\w\s]", candidate["answer"])):
                    answer_freq[candidate["answer"]] += 1

    ranked = sorted(((freq, text) for text, freq in answer_freq.items()),
                    reverse=True)
    for _, text in ranked[:config.answers]:
        ANSWER_INDEX.index(text)
Пример #3
0
Файл: vqa.py Проект: amanbo/nmn2
def prepare_indices(max_answers=1000, set_name="train2014"):
    """Populate QUESTION_INDEX, MODULE_INDEX and ANSWER_INDEX from one split.

    Registers (via each index's ``.index`` method):

    * question words (tokenised by ``proc_question``) seen at least
      ``MIN_COUNT`` times;
    * multi-parse tokens seen at least ``10 * MIN_COUNT`` times;
    * the ``max_answers`` most frequent answers, most frequent first,
      counting only answers with ``answer_confidence == "yes"`` whose text
      contains no character outside ``[\\w\\s]``. Ties in count are broken
      by descending answer string.

    Args:
        max_answers: size of the answer vocabulary (previously hard-coded
            to 1000, which remains the default).
        set_name: split whose question, multi-parse and annotation files
            are read (default ``"train2014"``, as before).
    """
    word_counts = defaultdict(int)
    with open(QUESTION_FILE % set_name) as question_f:
        questions = json.load(question_f)["questions"]
    for question in questions:
        for word in proc_question(question["question"]):
            word_counts[word] += 1
    for word, count in word_counts.items():
        if count >= MIN_COUNT:
            QUESTION_INDEX.index(word)

    pred_counts = defaultdict(int)
    with open(MULTI_PARSE_FILE % set_name) as parse_f:
        for line in parse_f:
            # Strip parentheses and treat ";" as a separator so each
            # remaining token is one predicate name.
            parts = line.strip().replace("(", "").replace(")", "").replace(";", " ").split()
            for part in parts:
                pred_counts[part] += 1
    for pred, count in pred_counts.items():
        if count >= 10 * MIN_COUNT:
            MODULE_INDEX.index(pred)

    # Compiled once instead of being looked up for every annotation.
    punctuation = re.compile(r"[^\w\s]")
    answer_counts = defaultdict(int)
    with open(ANN_FILE % set_name) as ann_f:
        annotations = json.load(ann_f)["annotations"]
    for ann in annotations:
        for answer in ann["answers"]:
            if answer["answer_confidence"] != "yes":
                continue
            word = answer["answer"]
            if punctuation.search(word):
                continue
            answer_counts[word] += 1

    # (count, answer) pairs are unique, so reverse=True gives the same order
    # as reversing an ascending sort: highest count first, ties by
    # descending answer string.
    keep_answers = sorted(((c, a) for a, c in answer_counts.items()), reverse=True)
    for _, answer in keep_answers[:max_answers]:
        ANSWER_INDEX.index(answer)
Пример #4
0
def visualize(batch_data, model, i_datum=0):
    """Show one example of a batch alongside the model's predictions.

    Sends six fields to ``visualizer.show``: the question text, the chosen
    parse, an HTML ``<img>`` tag for the image, an attention map, the model's
    top-5 answers (best first) and the ground-truth answers.

    Args:
        batch_data: indexable batch of datums; the one at ``i_datum`` is shown.
        model: reads ``prediction_data`` (row ``i_datum`` holds one score per
            answer index) and ``layout_ids`` (index into ``datum.parses``).
        i_datum: which datum of the batch to show (default 0, the previously
            hard-coded choice).
    """
    datum = batch_data[i_datum]
    # Drop the first and last token (presumably sentence-boundary markers --
    # TODO confirm). Kept as a plain string: a trailing comma here would
    # silently wrap it in a 1-tuple.
    question = " ".join([QUESTION_INDEX.get(w) for w in datum.question[1:-1]])
    preds = model.prediction_data[i_datum, :]
    # argsort is ascending, so the last five indices are the five highest
    # scores; reversed() lists the best answer first.
    top = np.argsort(preds)[-5:]
    top_answers = reversed([ANSWER_INDEX.get(p) for p in top])
    # TODO: show the real attention map. Earlier code read it from
    # model.apollo_net.blobs["Find_%d_softmax" % (layout_choice * 100 + 1)]
    # (layout_choice = model.module_layout_choices[i_datum]) and reshaped it
    # to 14x14; until that is restored a blank 14x14 map is displayed.
    att_data = np.zeros((14, 14))
    chosen_parse = datum.parses[model.layout_ids[i_datum]]

    fields = [
        question,
        str(chosen_parse),
        "<img src='../../%s'>" % datum.image_path,
        att_data,
        ", ".join(top_answers),
        ", ".join([ANSWER_INDEX.get(a) for a in datum.answers]),
    ]
    visualizer.show(fields)
Пример #5
0
    def __init__(self, config, set_name, modules):
        """Load one split of the geography QA data into ``self.data``.

        Each entry of ENVIRONMENTS contributes a World built from its location
        and world files, plus its questions, answers and candidate parses.
        Environments double as cross-validation folds: the environment at
        position ``config.fold`` is skipped for TRAIN and is the only one
        loaded for TEST. VAL yields no data.

        Side effects: calls ``ANSWER_INDEX.index`` on YES/NO (when
        ``config.quant``) and on every cleaned place name, and
        ``QUESTION_INDEX.index`` on every question token.

        Args:
            config: reads ``quant``, ``fold`` and ``k_best_parses``; also
                forwarded to ``parse_to_layout``.
            set_name: TRAIN, VAL or TEST.
            modules: forwarded to ``parse_to_layout``.

        Raises:
            ValueError: if an environment's data file and parse file hold a
                different number of entries, or if ``ANSWER_INDEX`` maps an
                answer token to None.
        """
        if set_name == VAL:
            self.data = []
            return

        questions = []
        answers = []
        parse_lists = []
        worlds = []

        if config.quant:
            ANSWER_INDEX.index(YES)
            ANSWER_INDEX.index(NO)

        for i_env, environment in enumerate(ENVIRONMENTS):
            # Cross-validation by environment.
            if i_env == config.fold and set_name == TRAIN:
                continue
            if i_env != config.fold and set_name == TEST:
                continue

            places = []
            with open(LOCATION_FILE % environment) as loc_f:
                for line in loc_f:
                    parts = line.strip().split(";")
                    places.append(parts[0])

            # cats[place]: 0/1 indicator over CATS. A relation listed as
            # "a#b" sets rels[a, b] to 1 and rels[b, a] to -1.
            cats = {place: np.zeros((len(CATS),)) for place in places}
            rels = {(pl1, pl2): np.zeros((len(RELS),))
                    for pl1 in places for pl2 in places}

            with open(WORLD_FILE % environment) as world_f:
                for line in world_f:
                    parts = line.strip().split(";")
                    if len(parts) < 2:
                        continue
                    # The first character of the fact name is skipped
                    # (presumably a format marker -- TODO confirm).
                    name = parts[0][1:]
                    places_here = parts[1].split(",")
                    if name in CATS:
                        cat_id = CATS.index(name)
                        for place in places_here:
                            cats[place][cat_id] = 1
                    elif name in RELS:
                        rel_id = RELS.index(name)
                        for place_pair in places_here:
                            pl1, pl2 = place_pair.split("#")
                            rels[pl1, pl2][rel_id] = 1
                            rels[pl2, pl1][rel_id] = -1

            clean_places = [p.lower().replace(" ", "_") for p in places]
            place_index = {place: i for (i, place) in enumerate(places)}
            clean_place_index = {place: i for (i, place) in enumerate(clean_places)}

            # Fixed-size tensors; slots past this environment's places stay 0.
            cat_features = np.zeros((len(CATS), DATABASE_SIZE, 1))
            rel_features = np.zeros((len(RELS), DATABASE_SIZE, DATABASE_SIZE))
            for p1, i_p1 in place_index.items():
                cat_features[:, i_p1, 0] = cats[p1]
                for p2, i_p2 in place_index.items():
                    rel_features[:, i_p1, i_p2] = rels[p1, p2]

            world = World(environment, clean_place_index, cat_features, rel_features)

            for place in clean_places:
                ANSWER_INDEX.index(place)

            n_questions_before = len(questions)
            with open(DATA_FILE % environment) as data_f:
                for line in data_f:
                    line = line.strip()
                    if line == "" or line[0] == "#":
                        continue

                    parts = line.split(";")

                    question = parts[0]
                    if question[-1] != "?":
                        question += " ?"
                    question = question.lower()
                    questions.append(question)

                    answer = parts[1].lower().replace(" ", "_")
                    # With quantifiers enabled, yes/no questions are answered
                    # YES when the listed answer is non-empty, NO otherwise.
                    if config.quant and question.split()[0] in ("is", "are"):
                        answer = YES if answer else NO
                    answers.append(answer)

                    worlds.append(world)

            n_parses_before = len(parse_lists)
            with open(PARSE_FILE % environment) as parse_f:
                for line in parse_f:
                    parse_strs = line.strip().split(";")
                    trees = [parse_tree(s) for s in parse_strs]
                    if not config.quant:
                        trees = [t for t in trees if t[0] != "exists"]
                    parse_lists.append(trees)

            # Questions and parses are paired by position with zip() below,
            # so the two files must line up within every environment; a
            # single global count check could let offsets cancel out.
            n_questions = len(questions) - n_questions_before
            n_parses = len(parse_lists) - n_parses_before
            if n_questions != n_parses:
                raise ValueError(
                    "%s: %d questions in data file but %d lines in parse file"
                    % (environment, n_questions, n_parses))

        data = []
        for i_datum, (question, answer, parse_list, world) in enumerate(
                zip(questions, answers, parse_lists, worlds)):
            tokens = ["<s>"] + question.split() + ["</s>"]

            # Keep the last k parses. NOTE: k_best_parses == 0 keeps every
            # parse, because parse_list[-0:] is the whole list.
            parse_list = parse_list[-config.k_best_parses:]

            indexed_question = [QUESTION_INDEX.index(w) for w in tokens]
            indexed_answer = \
                    tuple(ANSWER_INDEX[a] for a in answer.split(",") if a != "")
            if any(a is None for a in indexed_answer):
                raise ValueError(
                    "answer %r of question %r is not in ANSWER_INDEX"
                    % (answer, question))
            layouts = [parse_to_layout(p, world, config, modules) for p in parse_list]

            data.append(GeoDatum(
                    i_datum, indexed_question, parse_list, layouts, indexed_answer, world))

        self.data = data

        logging.info("%s:", set_name)
        logging.info("%s items", len(self.data))
        logging.info("%s words", len(QUESTION_INDEX))
        logging.info("%s functions", len(MODULE_INDEX))
        logging.info("%s answers", len(ANSWER_INDEX))
Пример #6
0
Файл: geo.py Проект: zxsted/nmn2
    def __init__(self, config, set_name, modules):
        """Load one split of the geography QA data into ``self.data``.

        Each entry of ENVIRONMENTS contributes a World plus its questions,
        answers and candidate parses. Environments double as
        cross-validation folds: the environment at position ``config.fold``
        is skipped for TRAIN and is the only one loaded for TEST. VAL yields
        no data.

        Side effects: calls ``ANSWER_INDEX.index`` on YES/NO (when
        ``config.quant``) and on every cleaned place name, and
        ``QUESTION_INDEX.index`` on every question token.

        Args:
            config: reads ``quant``, ``fold`` and ``k_best_parses``; also
                forwarded to ``parse_to_layout``.
            set_name: TRAIN, VAL or TEST.
            modules: forwarded to ``parse_to_layout``.

        Raises:
            ValueError: if an environment's data file and parse file hold a
                different number of entries, or if ``ANSWER_INDEX`` maps an
                answer token to None.
        """
        if set_name == VAL:
            self.data = []
            return

        questions = []
        answers = []
        parse_lists = []
        worlds = []

        if config.quant:
            ANSWER_INDEX.index(YES)
            ANSWER_INDEX.index(NO)

        for i_env, environment in enumerate(ENVIRONMENTS):
            # Cross-validation by environment.
            if i_env == config.fold and set_name == TRAIN:
                continue
            if i_env != config.fold and set_name == TEST:
                continue

            world, clean_places = self._read_environment_world(environment)
            for place in clean_places:
                ANSWER_INDEX.index(place)

            env_questions, env_answers = self._read_environment_questions(
                environment, config.quant)
            env_parses = self._read_environment_parses(environment,
                                                       config.quant)

            # Questions and parses are paired by position with zip() below,
            # so the two files must line up within every environment; a
            # single global count check could let offsets cancel out.
            if len(env_questions) != len(env_parses):
                raise ValueError(
                    "%s: %d questions in data file but %d lines in parse file"
                    % (environment, len(env_questions), len(env_parses)))

            questions.extend(env_questions)
            answers.extend(env_answers)
            parse_lists.extend(env_parses)
            worlds.extend([world] * len(env_questions))

        data = []
        for i_datum, (question, answer, parse_list, world) in enumerate(
                zip(questions, answers, parse_lists, worlds)):
            tokens = ["<s>"] + question.split() + ["</s>"]

            # Keep the last k parses. NOTE: k_best_parses == 0 keeps every
            # parse, because parse_list[-0:] is the whole list.
            parse_list = parse_list[-config.k_best_parses:]

            indexed_question = [QUESTION_INDEX.index(w) for w in tokens]
            indexed_answer = \
                    tuple(ANSWER_INDEX[a] for a in answer.split(",") if a != "")
            if any(a is None for a in indexed_answer):
                raise ValueError(
                    "answer %r of question %r is not in ANSWER_INDEX"
                    % (answer, question))
            layouts = [
                parse_to_layout(p, world, config, modules) for p in parse_list
            ]

            data.append(
                GeoDatum(i_datum, indexed_question, parse_list, layouts,
                         indexed_answer, world))

        self.data = data

        logging.info("%s:", set_name)
        logging.info("%s items", len(self.data))
        logging.info("%s words", len(QUESTION_INDEX))
        logging.info("%s functions", len(MODULE_INDEX))
        logging.info("%s answers", len(ANSWER_INDEX))

    @staticmethod
    def _read_environment_world(environment):
        """Build the World (place categories and relations) of one environment.

        Returns:
            (world, clean_places): the World, and the environment's place
            names lower-cased with spaces replaced by underscores, in
            LOCATION_FILE order.
        """
        places = []
        with open(LOCATION_FILE % environment) as loc_f:
            for line in loc_f:
                places.append(line.strip().split(";")[0])

        # cats[place]: 0/1 indicator over CATS. A relation listed as "a#b"
        # sets rels[a, b] to 1 and rels[b, a] to -1.
        cats = {place: np.zeros((len(CATS), )) for place in places}
        rels = {(pl1, pl2): np.zeros((len(RELS), ))
                for pl1 in places for pl2 in places}

        with open(WORLD_FILE % environment) as world_f:
            for line in world_f:
                parts = line.strip().split(";")
                if len(parts) < 2:
                    continue
                # The first character of the fact name is skipped
                # (presumably a format marker -- TODO confirm).
                name = parts[0][1:]
                places_here = parts[1].split(",")
                if name in CATS:
                    cat_id = CATS.index(name)
                    for place in places_here:
                        cats[place][cat_id] = 1
                elif name in RELS:
                    rel_id = RELS.index(name)
                    for place_pair in places_here:
                        pl1, pl2 = place_pair.split("#")
                        rels[pl1, pl2][rel_id] = 1
                        rels[pl2, pl1][rel_id] = -1

        clean_places = [p.lower().replace(" ", "_") for p in places]
        place_index = {place: i for (i, place) in enumerate(places)}
        clean_place_index = {
            place: i
            for (i, place) in enumerate(clean_places)
        }

        # Fixed-size tensors; slots past this environment's places stay 0.
        cat_features = np.zeros((len(CATS), DATABASE_SIZE, 1))
        rel_features = np.zeros((len(RELS), DATABASE_SIZE, DATABASE_SIZE))
        for p1, i_p1 in place_index.items():
            cat_features[:, i_p1, 0] = cats[p1]
            for p2, i_p2 in place_index.items():
                rel_features[:, i_p1, i_p2] = rels[p1, p2]

        world = World(environment, clean_place_index, cat_features,
                      rel_features)
        return world, clean_places

    @staticmethod
    def _read_environment_questions(environment, quant):
        """Read one environment's question/answer pairs from DATA_FILE.

        Blank lines and lines starting with "#" are skipped; every other
        line is split on ";" into question and answer. Questions get " ?"
        appended unless they already end in "?" and are lower-cased.
        Answers are lower-cased with spaces replaced by underscores. When
        ``quant`` is set, the answer to a question whose first word is "is"
        or "are" becomes YES if non-empty, else NO.

        Returns:
            (questions, answers): two parallel lists.
        """
        questions = []
        answers = []
        with open(DATA_FILE % environment) as data_f:
            for line in data_f:
                line = line.strip()
                if line == "" or line[0] == "#":
                    continue

                parts = line.split(";")

                question = parts[0]
                if question[-1] != "?":
                    question += " ?"
                question = question.lower()
                questions.append(question)

                answer = parts[1].lower().replace(" ", "_")
                if quant and question.split()[0] in ("is", "are"):
                    answer = YES if answer else NO
                answers.append(answer)
        return questions, answers

    @staticmethod
    def _read_environment_parses(environment, quant):
        """Read one environment's candidate parses from PARSE_FILE.

        Each line holds ";"-separated parse strings, converted with
        ``parse_tree``. Lines must line up one-to-one with the questions of
        DATA_FILE (``__init__`` checks the counts). Unless ``quant`` is set,
        trees whose first element is "exists" are dropped.

        Returns:
            One list of trees per line.
        """
        parse_lists = []
        with open(PARSE_FILE % environment) as parse_f:
            for line in parse_f:
                trees = [parse_tree(s) for s in line.strip().split(";")]
                if not quant:
                    trees = [t for t in trees if t[0] != "exists"]
                parse_lists.append(trees)
        return parse_lists