Example #1
File: vqa.py Project: zxsted/nmn2
    def load_set(self, config, set_name, size, modules, mean, std):
        parse_file = MULTI_PARSE_FILE
        with open(QUESTION_FILE % set_name) as question_f, \
             open(parse_file % set_name) as parse_f:
            questions = json.load(question_f)["questions"]
            parse_groups = [l.strip() for l in parse_f]
            assert len(questions) == len(parse_groups)
            pairs = list(zip(questions, parse_groups))  # list() so the slice below also works under Python 3
            if size is not None:
                pairs = pairs[:size]
            for question, parse_group in pairs:
                id = question["question_id"]
                question_str = proc_question(question["question"])
                indexed_question = \
                    [QUESTION_INDEX[w] or UNK_ID for w in question_str]

                parse_strs = parse_group.split(";")
                parses = [parse_tree(p) for p in parse_strs]
                parses = [("_what", "_thing") if p == "none" else p
                          for p in parses]
                if config.chooser == "null":
                    parses = [("_what", "_thing")]
                elif config.chooser == "cvpr":
                    if parses[0][0] == "is":
                        parses = parses[-1:]
                    else:
                        parses = parses[:1]
                elif config.chooser == "naacl":
                    pass
                else:
                    assert False

                layouts = [parse_to_layout(p, config, modules) for p in parses]
                image_id = question["image_id"]
                try:
                    image_set_name = "test2015" if set_name == "test-dev2015" else set_name
                    datum = VqaDatum(id, indexed_question, parses, layouts,
                                     image_set_name, image_id, [], mean, std)
                    self.by_id[id] = datum
                except IOError as e:
                    print(e)

        if set_name not in ("test2015", "test-dev2015"):
            with open(ANN_FILE % set_name) as ann_f:
                annotations = json.load(ann_f)["annotations"]
                for ann in annotations:
                    question_id = ann["question_id"]
                    if question_id not in self.by_id:
                        continue

                    answer_counter = defaultdict(lambda: 0)  # note: unused in this excerpt
                    answers = [a["answer"] for a in ann["answers"]]
                    indexed_answers = [
                        ANSWER_INDEX[a] or UNK_ID for a in answers
                    ]
                    self.by_id[question_id].answers = indexed_answers
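
The snippet above indexes the question tokens with [QUESTION_INDEX[w] or UNK_ID for w in question_str], which relies on the index returning a falsy value (None) for out-of-vocabulary words. Below is a minimal self-contained sketch of that lookup pattern; the Index class, the toy vocabulary, and reserving id 0 for UNK_ID are assumptions for illustration, not the actual nmn2 implementation.

UNK_ID = 0

class Index:
    """Maps words to integer ids; unseen words look up as None."""
    def __init__(self):
        self._ids = {}

    def index(self, word):
        # Assign a fresh id on first sight; 0 stays reserved for the unknown token.
        if word not in self._ids:
            self._ids[word] = len(self._ids) + 1
        return self._ids[word]

    def __getitem__(self, word):
        # Read-only lookup: None (falsy) signals out-of-vocabulary.
        return self._ids.get(word)

QUESTION_INDEX = Index()
for w in "what color is the cat".split():
    QUESTION_INDEX.index(w)

tokens = "what color is the dog".split()
indexed = [QUESTION_INDEX[w] or UNK_ID for w in tokens]
print(indexed)  # the unseen word "dog" falls back to UNK_ID

One subtlety of the "or UNK_ID" idiom is that any word whose real id is falsy (e.g. 0) would also be replaced, which is why the sketch keeps 0 for the unknown token.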
Example #2
File: vqa.py Project: amanbo/nmn2
    def load_set(self, config, set_name, size, modules, mean, std):
        parse_file = MULTI_PARSE_FILE
        with open(QUESTION_FILE % set_name) as question_f, \
             open(parse_file % set_name) as parse_f:
            questions = json.load(question_f)["questions"]
            parse_groups = [l.strip() for l in parse_f]
            assert len(questions) == len(parse_groups)
            pairs = list(zip(questions, parse_groups))  # list() so the slice below also works under Python 3
            if size is not None:
                pairs = pairs[:size]
            for question, parse_group in pairs:
                id = question["question_id"]
                question_str = proc_question(question["question"])
                indexed_question = \
                    [QUESTION_INDEX[w] or UNK_ID for w in question_str]

                parse_strs = parse_group.split(";")
                parses = [parse_tree(p) for p in parse_strs]
                parses = [("_what", "_thing") if p == "none" else p for p in parses]
                if config.chooser == "null":
                    parses = [("_what", "_thing")]
                elif config.chooser == "cvpr":
                    if parses[0][0] == "is":
                        parses = parses[-1:]
                    else:
                        parses = parses[:1]
                elif config.chooser == "naacl":
                    pass
                else:
                    assert False

                layouts = [parse_to_layout(p, config, modules) for p in parses]
                image_id = question["image_id"]
                try:
                    image_set_name = "test2015" if set_name == "test-dev2015" else set_name
                    datum = VqaDatum(id, indexed_question, parses, layouts, image_set_name, image_id, [], mean, std)
                    self.by_id[id] = datum
                except IOError as e:
                    print(e)

        if set_name not in ("test2015", "test-dev2015"):
            with open(ANN_FILE % set_name) as ann_f:
                annotations = json.load(ann_f)["annotations"]
                for ann in annotations:
                    question_id = ann["question_id"]
                    if question_id not in self.by_id:
                        continue

                    answer_counter = defaultdict(lambda: 0)  # note: unused in this excerpt
                    answers = [a["answer"] for a in ann["answers"]]
                    indexed_answers = [ANSWER_INDEX[a] or UNK_ID for a in answers]
                    self.by_id[question_id].answers = indexed_answers
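
The config.chooser branch shared by Examples #1 and #2 decides which candidate parses survive: "null" discards them all in favor of a default parse, "cvpr" keeps a single parse (the last one for "is"-questions, the first otherwise), and "naacl" keeps every candidate. Below is a small standalone sketch of that selection logic; the parse tuples are made up for illustration rather than real parse_tree() output.

def choose_parses(parses, chooser):
    if chooser == "null":
        return [("_what", "_thing")]      # ignore the parser output entirely
    elif chooser == "cvpr":
        if parses[0][0] == "is":          # yes/no question: keep the last parse
            return parses[-1:]
        return parses[:1]                 # otherwise keep only the first parse
    elif chooser == "naacl":
        return parses                     # keep all candidate parses
    raise ValueError("unknown chooser: %s" % chooser)

parses = [("is", ("_color", "_thing")), ("_what", "_color")]
print(choose_parses(parses, "cvpr"))   # [('_what', '_color')]
print(choose_parses(parses, "naacl"))  # both candidates survive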
Example #3
File: geo.py Project: BinbinBian/nmn2
    def __init__(self, config, set_name, modules):
        if set_name == VAL:
            self.data = []
            return

        questions = []
        answers = []
        parse_lists = []
        worlds = []

        if config.quant:
            ANSWER_INDEX.index(YES)
            ANSWER_INDEX.index(NO)

        for i_env, environment in enumerate(ENVIRONMENTS):
            if i_env == config.fold and set_name == TRAIN:
                continue
            if i_env != config.fold and set_name == TEST:
                continue

            places = list()
            with open(LOCATION_FILE % environment) as loc_f:
                for line in loc_f:
                    parts = line.strip().split(";")
                    places.append(parts[0])

            cats = {place: np.zeros((len(CATS),)) for place in places}
            rels = {(pl1, pl2): np.zeros((len(RELS),)) for pl1 in places for pl2 in places}

            with open(WORLD_FILE % environment) as world_f:
                for line in world_f:
                    parts = line.strip().split(";")
                    if len(parts) < 2:
                        continue
                    name = parts[0][1:]
                    places_here = parts[1].split(",")
                    if name in CATS:
                        cat_id = CATS.index(name)
                        for place in places_here:
                            cats[place][cat_id] = 1
                    elif name in RELS:
                        rel_id = RELS.index(name)
                        for place_pair in places_here:
                            pl1, pl2 = place_pair.split("#")
                            rels[pl1, pl2][rel_id] = 1
                            rels[pl2, pl1][rel_id] = -1

            clean_places = [p.lower().replace(" ", "_") for p in places]
            place_index = {place: i for (i, place) in enumerate(places)}
            clean_place_index = {place: i for (i, place) in enumerate(clean_places)}
            
            cat_features = np.zeros((len(CATS), DATABASE_SIZE, 1))
            rel_features = np.zeros((len(RELS), DATABASE_SIZE, DATABASE_SIZE))

            for p1, i_p1 in place_index.items():
                cat_features[:, i_p1, 0] = cats[p1]
                for p2, i_p2 in place_index.items():
                    rel_features[:, i_p1, i_p2] = rels[p1, p2]

            world = World(environment, clean_place_index, cat_features, rel_features)

            for place in clean_places:
                ANSWER_INDEX.index(place)

            with open(DATA_FILE % environment) as data_f:
                for line in data_f:
                    line = line.strip()
                    if line == "" or line[0] == "#":
                        continue

                    parts = line.split(";")

                    question = parts[0]
                    if question[-1] != "?":
                        question += " ?"
                    question = question.lower()
                    questions.append(question)

                    answer = parts[1].lower().replace(" ", "_")
                    if config.quant and question.split()[0] in ("is", "are"):
                        answer = YES if answer else NO
                    answers.append(answer)

                    worlds.append(world)

            with open(PARSE_FILE % environment) as parse_f:
                for line in parse_f:
                    parse_strs = line.strip().split(";")
                    trees = [parse_tree(s) for s in parse_strs]
                    if not config.quant:
                        trees = [t for t in trees if t[0] != "exists"]
                    parse_lists.append(trees)

        assert len(questions) == len(parse_lists)

        data = []
        i_datum = 0
        for question, answer, parse_list, world in \
                zip(questions, answers, parse_lists, worlds):
            tokens = ["<s>"] + question.split() + ["</s>"]

            parse_list = parse_list[-config.k_best_parses:]

            indexed_question = [QUESTION_INDEX.index(w) for w in tokens]
            indexed_answer = \
                    tuple(ANSWER_INDEX[a] for a in answer.split(",") if a != "")
            assert all(a is not None for a in indexed_answer)
            layouts = [parse_to_layout(p, world, config, modules) for p in parse_list]

            data.append(GeoDatum(
                    i_datum, indexed_question, parse_list, layouts, indexed_answer, world))
            i_datum += 1

        self.data = data

        logging.info("%s:", set_name)
        logging.info("%s items", len(self.data))
        logging.info("%s words", len(QUESTION_INDEX))
        logging.info("%s functions", len(MODULE_INDEX))
        logging.info("%s answers", len(ANSWER_INDEX))
Example #4
File: geo.py Project: zxsted/nmn2
    def __init__(self, config, set_name, modules):
        if set_name == VAL:
            self.data = []
            return

        questions = []
        answers = []
        parse_lists = []
        worlds = []

        if config.quant:
            ANSWER_INDEX.index(YES)
            ANSWER_INDEX.index(NO)

        for i_env, environment in enumerate(ENVIRONMENTS):
            if i_env == config.fold and set_name == TRAIN:
                continue
            if i_env != config.fold and set_name == TEST:
                continue

            places = list()
            with open(LOCATION_FILE % environment) as loc_f:
                for line in loc_f:
                    parts = line.strip().split(";")
                    places.append(parts[0])

            cats = {place: np.zeros((len(CATS), )) for place in places}
            rels = {(pl1, pl2): np.zeros((len(RELS), ))
                    for pl1 in places for pl2 in places}

            with open(WORLD_FILE % environment) as world_f:
                for line in world_f:
                    parts = line.strip().split(";")
                    if len(parts) < 2:
                        continue
                    name = parts[0][1:]
                    places_here = parts[1].split(",")
                    if name in CATS:
                        cat_id = CATS.index(name)
                        for place in places_here:
                            cats[place][cat_id] = 1
                    elif name in RELS:
                        rel_id = RELS.index(name)
                        for place_pair in places_here:
                            pl1, pl2 = place_pair.split("#")
                            rels[pl1, pl2][rel_id] = 1
                            rels[pl2, pl1][rel_id] = -1

            clean_places = [p.lower().replace(" ", "_") for p in places]
            place_index = {place: i for (i, place) in enumerate(places)}
            clean_place_index = {
                place: i
                for (i, place) in enumerate(clean_places)
            }

            cat_features = np.zeros((len(CATS), DATABASE_SIZE, 1))
            rel_features = np.zeros((len(RELS), DATABASE_SIZE, DATABASE_SIZE))

            for p1, i_p1 in place_index.items():
                cat_features[:, i_p1, 0] = cats[p1]
                for p2, i_p2 in place_index.items():
                    rel_features[:, i_p1, i_p2] = rels[p1, p2]

            world = World(environment, clean_place_index, cat_features,
                          rel_features)

            for place in clean_places:
                ANSWER_INDEX.index(place)

            with open(DATA_FILE % environment) as data_f:
                for line in data_f:
                    line = line.strip()
                    if line == "" or line[0] == "#":
                        continue

                    parts = line.split(";")

                    question = parts[0]
                    if question[-1] != "?":
                        question += " ?"
                    question = question.lower()
                    questions.append(question)

                    answer = parts[1].lower().replace(" ", "_")
                    if config.quant and question.split()[0] in ("is", "are"):
                        answer = YES if answer else NO
                    answers.append(answer)

                    worlds.append(world)

            with open(PARSE_FILE % environment) as parse_f:
                for line in parse_f:
                    parse_strs = line.strip().split(";")
                    trees = [parse_tree(s) for s in parse_strs]
                    if not config.quant:
                        trees = [t for t in trees if t[0] != "exists"]
                    parse_lists.append(trees)

        assert len(questions) == len(parse_lists)

        data = []
        i_datum = 0
        for question, answer, parse_list, world in \
                zip(questions, answers, parse_lists, worlds):
            tokens = ["<s>"] + question.split() + ["</s>"]

            parse_list = parse_list[-config.k_best_parses:]

            indexed_question = [QUESTION_INDEX.index(w) for w in tokens]
            indexed_answer = \
                    tuple(ANSWER_INDEX[a] for a in answer.split(",") if a != "")
            assert all(a is not None for a in indexed_answer)
            layouts = [
                parse_to_layout(p, world, config, modules) for p in parse_list
            ]

            data.append(
                GeoDatum(i_datum, indexed_question, parse_list, layouts,
                         indexed_answer, world))
            i_datum += 1

        self.data = data

        logging.info("%s:", set_name)
        logging.info("%s items", len(self.data))
        logging.info("%s words", len(QUESTION_INDEX))
        logging.info("%s functions", len(MODULE_INDEX))
        logging.info("%s answers", len(ANSWER_INDEX))