コード例 #1
0
ファイル: run_model.py プロジェクト: arjunakula/neurips2021
def load_vocab(args):
    """Return the ``'vocab'`` entry of the first configured checkpoint.

    The checkpoint path is taken from the first of ``args.baseline_model``,
    ``args.execution_engine`` and ``args.program_generator`` (in that order)
    that is not ``None``; the file is read with ``utils.load_cpu``.

    Args:
        args: Namespace-like object exposing the three attributes above.

    Returns:
        Whatever is stored under the ``'vocab'`` key of the loaded checkpoint.

    Raises:
        ValueError: If none of the three checkpoint paths is set. Previously
            ``None`` was handed straight to ``utils.load_cpu``.
    """
    # Attributes are read lazily, in priority order, so a later attribute is
    # never touched once an earlier one is set.
    for attr_name in ('baseline_model', 'execution_engine',
                      'program_generator'):
        path = getattr(args, attr_name)
        if path is not None:
            return utils.load_cpu(path)['vocab']
    raise ValueError(
        'Cannot load vocab: none of baseline_model, execution_engine or '
        'program_generator is set')
コード例 #2
0
def _merge_new_question_words(vocab, new_vocab):
    """Add question tokens found only in ``new_vocab`` to ``vocab`` in place.

    Each missing token is appended with the next free index (the current
    size of ``vocab['question_token_to_idx']``). Every added word and the
    final count are printed.
    """
    question_token_to_idx = vocab['question_token_to_idx']
    num_new_words = 0
    for word in new_vocab['question_token_to_idx']:
        if word not in question_token_to_idx:
            print('Found new word %s' % word)
            question_token_to_idx[word] = len(question_token_to_idx)
            num_new_words += 1
    print('Found %d new words' % num_new_words)


def main(args):
    """Encode a questions JSON file into an HDF5 file, building or loading a vocab.

    Steps:
      1. Read ``args.input_questions_json`` and take its ``'questions'`` list.
      2. Build a vocab from the questions when no input vocab is given or
         when ``args.expand_vocab == 1``.
      3. Load a vocab from ``args.input_vocab_json`` or, failing that,
         ``args.input_vocab_pt``. With ``args.expand_vocab == 1`` the
         question tokens of the freshly built vocab that are missing from
         the loaded one are appended to it.
      4. Optionally dump the vocab to ``args.output_vocab_json``.
      5. Encode questions, programs and answers, pad questions and programs
         with the ``'<NULL>'`` token, and write the datasets ``questions``,
         ``image_idxs``, ``orig_idxs`` and, when non-empty, ``programs``,
         ``question_families`` and ``answers`` to ``args.output_h5_file``.

    Args:
        args: Namespace-like object providing ``input_vocab_json``,
            ``output_vocab_json``, ``input_vocab_pt``,
            ``input_questions_json``, ``expand_vocab``, ``unk_threshold``,
            ``mode``, ``encode_unk`` and ``output_h5_file``.

    Returns:
        None. Returns early, after printing a message, when none of the three
        vocab options is given.
    """
    if (args.input_vocab_json == '' and args.output_vocab_json == ''
            and args.input_vocab_pt == ''):
        print(
            'Must give one of --input_vocab_json or --output_vocab_json or --input_vocab_pt'
        )
        return

    print('Loading data')
    with open(args.input_questions_json, 'r') as f:
        questions = json.load(f)['questions']

    expand_vocab = args.expand_vocab == 1

    # Either create the vocab or load it from disk
    if (args.input_vocab_json == ''
            and args.input_vocab_pt == '') or expand_vocab:
        print('Building vocab')
        # Default to an empty mapping: questions without answers (e.g. a
        # test split) used to leave this name unbound and crash with a
        # NameError when the vocab dict below was assembled.
        answer_token_to_idx = {}
        if 'answer' in questions[0]:
            answer_token_to_idx = build_vocab((q['answer'] for q in questions))
        question_token_to_idx = build_vocab((q['question'] for q in questions),
                                            min_token_count=args.unk_threshold,
                                            punct_to_keep=[';', ','],
                                            punct_to_remove=['?', '.'])
        all_program_strs = []
        for q in questions:
            if 'program' not in q:
                continue
            program_str = program_to_str(q['program'], args.mode)
            if program_str is not None:
                all_program_strs.append(program_str)
        program_token_to_idx = build_vocab(all_program_strs)
        vocab = {
            'question_token_to_idx': question_token_to_idx,
            'program_token_to_idx': program_token_to_idx,
            'answer_token_to_idx': answer_token_to_idx,
        }

    if args.input_vocab_json != '':
        print('Loading vocab')
        if expand_vocab:
            # Keep the freshly built vocab so its new words can be merged in.
            new_vocab = vocab
        with open(args.input_vocab_json, 'r') as f:
            vocab = json.load(f)
        if expand_vocab:
            _merge_new_question_words(vocab, new_vocab)

    elif args.input_vocab_pt != '':
        print('Loading vocab')
        if expand_vocab:
            new_vocab = vocab
        vocab = utils.load_cpu(args.input_vocab_pt)['vocab']
        if expand_vocab:
            _merge_new_question_words(vocab, new_vocab)

    if args.output_vocab_json != '':
        with open(args.output_vocab_json, 'w') as f:
            json.dump(vocab, f)

    # Encode all questions and programs
    print('Encoding data')
    questions_encoded = []
    programs_encoded = []
    question_families = []
    orig_idxs = []
    image_idxs = []
    answers = []
    for orig_idx, q in enumerate(questions):
        question = q['question']

        orig_idxs.append(orig_idx)
        image_idxs.append(q['image_index'])
        if 'question_family_index' in q:
            question_families.append(q['question_family_index'])
        question_tokens = tokenize(question,
                                   punct_to_keep=[';', ','],
                                   punct_to_remove=['?', '.'])
        question_encoded = encode(question_tokens,
                                  vocab['question_token_to_idx'],
                                  allow_unk=args.encode_unk == 1)
        questions_encoded.append(question_encoded)

        if 'program' in q:
            # NOTE(review): program_to_str may return None (the vocab-building
            # loop above guards against it); here the result is tokenized
            # unconditionally -- confirm tokenize() tolerates that.
            program_str = program_to_str(q['program'], args.mode)
            program_tokens = tokenize(program_str)
            program_encoded = encode(program_tokens,
                                     vocab['program_token_to_idx'])
            programs_encoded.append(program_encoded)

        if 'answer' in q:
            answer = q['answer']
            # Boolean answers are looked up as 'yes' / 'no'; everything else
            # is looked up by its string form.
            if answer is True:
                answer = 'yes'
            elif answer is False:
                answer = 'no'
            else:
                answer = str(answer)
            if answer in vocab['answer_token_to_idx']:
                answers.append(vocab['answer_token_to_idx'][answer])
            else:
                # Unknown answers are reported and encoded as index 0.
                print('Answer %s to %s is missing' % (answer, question))
                answers.append(0)

    # Pad encoded questions and programs
    max_question_length = max(len(x) for x in questions_encoded)
    for qe in questions_encoded:
        while len(qe) < max_question_length:
            qe.append(vocab['question_token_to_idx']['<NULL>'])

    if len(programs_encoded) > 0:
        max_program_length = max(len(x) for x in programs_encoded)
        for pe in programs_encoded:
            while len(pe) < max_program_length:
                pe.append(vocab['program_token_to_idx']['<NULL>'])

    # Create h5 file
    print('Writing output')
    questions_encoded = np.asarray(questions_encoded, dtype=np.int32)
    programs_encoded = np.asarray(programs_encoded, dtype=np.int32)
    print(questions_encoded.shape)
    print(programs_encoded.shape)
    with h5py.File(args.output_h5_file, 'w') as f:
        f.create_dataset('questions', data=questions_encoded)
        f.create_dataset('image_idxs', data=np.asarray(image_idxs))
        f.create_dataset('orig_idxs', data=np.asarray(orig_idxs))

        if len(programs_encoded) > 0:
            f.create_dataset('programs', data=programs_encoded)
        if len(question_families) > 0:
            f.create_dataset('question_families',
                             data=np.asarray(question_families))
        if len(answers) > 0:
            f.create_dataset('answers', data=np.asarray(answers))