Example #1

import os

import numpy as np
from nltk.tokenize import word_tokenize
from tqdm import tqdm

# VQA, Vocabulary, DataSet, process_vqa, process_questions and
# process_answers are project-local helpers assumed to be importable.


def prepare_train_data(config):
    """ Prepare the data for training the model. """
    # Keep only QA pairs whose question is short enough and whose
    # answer is a single word.
    vqa = VQA(config.train_answer_file, config.train_question_file)
    vqa.filter_by_ques_len(config.max_question_length)
    vqa.filter_by_ans_len(1)

    print("Reading the questions and answers...")
    annotations = process_vqa(vqa, 'COCO_train2014', config.train_image_dir,
                              config.temp_train_annotation_file)

    image_files = annotations['image_file'].values
    questions = annotations['question'].values
    question_ids = annotations['question_id'].values
    answers = annotations['answer'].values
    print("Questions and answers read.")
    print("Number of questions = %d" % (len(question_ids)))

    print("Building the vocabulary...")
    vocabulary = Vocabulary()
    if not os.path.exists(config.vocabulary_file):
        # First run: build the vocabulary from every question and
        # answer, then cache it to disk for later runs.
        for question in tqdm(questions):
            vocabulary.add_words(word_tokenize(question))
        for answer in tqdm(answers):
            vocabulary.add_words(word_tokenize(answer))
        vocabulary.compute_frequency()
        vocabulary.save(config.vocabulary_file)
    else:
        vocabulary.load(config.vocabulary_file)
    print("Vocabulary built.")
    print("Number of words = %d" % (vocabulary.size))
    config.vocabulary_size = vocabulary.size

    print("Processing the questions and answers...")
    if not os.path.exists(config.temp_train_data_file):
        # Convert each question to an array of word indices (with its
        # length) and each answer to a single word index, then cache.
        question_word_idxs, question_lens = process_questions(
            questions, vocabulary, config)
        answer_idxs = process_answers(answers, vocabulary)
        data = {
            'question_word_idxs': question_word_idxs,
            'question_lens': question_lens,
            'answer_idxs': answer_idxs
        }
        np.save(config.temp_train_data_file, data)
    else:
        # The cache holds a pickled dict; newer NumPy versions need
        # allow_pickle=True to load it.
        data = np.load(config.temp_train_data_file,
                       allow_pickle=True).item()
        question_word_idxs = data['question_word_idxs']
        question_lens = data['question_lens']
        answer_idxs = data['answer_idxs']
    print("Questions and answers processed.")

    print("Building the dataset...")
    dataset = DataSet(image_files, question_word_idxs, question_lens,
                      question_ids, config.batch_size, answer_idxs, True, True)
    print("Dataset built.")
    return dataset, config
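
Both examples rely on a Vocabulary helper that is not shown on this page. The sketch below is a minimal illustration inferred from the calls above (add_words, compute_frequency, save, load, and a size attribute); the pickle-based storage and the internals are assumptions, not the project's actual implementation.

import pickle


class Vocabulary:
    """ Illustrative word <-> index mapping; not the project's real class. """

    def __init__(self):
        self.words = []             # index -> word
        self.word2idx = {}          # word -> index
        self.word_counts = []       # occurrence count per word
        self.word_frequencies = []  # filled in by compute_frequency()
        self.size = 0

    def add_words(self, words):
        """ Register unseen words and update occurrence counts. """
        for word in words:
            if word not in self.word2idx:
                self.word2idx[word] = self.size
                self.words.append(word)
                self.word_counts.append(0)
                self.size += 1
            self.word_counts[self.word2idx[word]] += 1

    def compute_frequency(self):
        """ Normalize raw counts into relative frequencies. """
        total = float(sum(self.word_counts)) or 1.0
        self.word_frequencies = [c / total for c in self.word_counts]

    def save(self, path):
        with open(path, 'wb') as f:
            pickle.dump(self.__dict__, f)

    def load(self, path):
        with open(path, 'rb') as f:
            self.__dict__.update(pickle.load(f))
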
Example #2

from nltk.tokenize import word_tokenize
from tqdm import tqdm

# VQA and Vocabulary are the same project-local helpers as above.


def build_vocabulary(config):
    """ Build the vocabulary from the training data and save it to a file. """
    # Apply the same question/answer length filters as in Example #1.
    vqa = VQA(config.train_answer_file, config.train_question_file)
    vqa.filter_by_ques_len(config.max_question_length)
    vqa.filter_by_ans_len(1)

    # Collect the question text and the best answer for every
    # remaining question id.
    question_ids = list(vqa.qa.keys())
    questions = [vqa.qqa[k]['question'] for k in question_ids]
    answers = [vqa.qa[k]['best_answer'] for k in question_ids]

    # Tokenize every question and answer and count word occurrences.
    vocabulary = Vocabulary()
    for question in tqdm(questions):
        vocabulary.add_words(word_tokenize(question))
    for answer in tqdm(answers):
        vocabulary.add_words(word_tokenize(answer))
    vocabulary.compute_frequency()
    vocabulary.save(config.vocabulary_file)
    return vocabulary
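
A short usage sketch for completeness; the Config class and the file paths below are hypothetical stand-ins for the project's real configuration object.

class Config:
    # Hypothetical paths and limits; substitute your own.
    train_answer_file = 'data/train/answers.json'
    train_question_file = 'data/train/questions.json'
    max_question_length = 30
    vocabulary_file = 'data/vocabulary.pkl'


vocabulary = build_vocabulary(Config())
print("Vocabulary size = %d" % vocabulary.size)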