コード例 #1
0
ファイル: train_model.py プロジェクト: HuiyuanXie/film
 def __iter__(self):
   """Re-shape every batch from the parent loader into a 6-tuple.

   Yields:
     (question, image, feats, answer, program_seq, program_json) where
       question     -- batch['caption'] converted with .long()
       image, feats -- both batch['world'] (the same object, twice)
       answer       -- batch['agreement'] converted with .long()
       program_seq  -- derived from batch['caption_model'] when that key is
                       present; otherwise an IntTensor with one 0 per caption
       program_json -- always an empty dict
   """
   batches = super(ShapeWorldDataLoader, self).__iter__()
   for batch in batches:
     captions = batch['caption']
     question = captions.long()
     world = batch['world']
     answer = batch['agreement'].long()
     if 'caption_model' not in batch:
       # No program annotations available: one placeholder 0 per caption.
       program_seq = torch.IntTensor([0 for _ in captions])
     else:
       # NOTE(review): on current PyTorch, Tensor.apply_ accepts its callable
       # only positionally and requires it to return a number per element,
       # while parse_program's result looks like a program sequence rather
       # than a number -- this branch likely raises. Confirm against the torch
       # version in use and the real type of batch['caption_model'].
       program_seq = batch['caption_model'].apply_(
           callable=(lambda model: clevr_util.parse_program(mode=0, model=model)))
     yield question, world, world, answer, program_seq, {}
コード例 #2
0
 def preprocess(model, max_length=27):
     """Encode the program of one caption model as a padded int64 array.

     If ``args.sw_name`` starts with ``"clevr"``, the program is read from
     ``model["program"]`` and converted with ``vr.programs.list_to_prefix``;
     otherwise it is produced by ``clevr_util.parse_program(mode=0, ...)``.
     The prefix program is then stringified (``vr.programs.list_to_str``),
     tokenized, and encoded with ``program_token_to_idx``.

     Args:
         model: caption model for one question (in the CLEVR case, a mapping
             with a ``"program"`` key).
         max_length: length to right-pad the encoding to, using the index of
             the ``"<NULL>"`` token. Defaults to 27, the previously hard-coded
             value.

     Returns:
         A numpy ``int64`` array. Encodings already at least ``max_length``
         long are returned as-is (neither padded nor truncated).
         NOTE(review): such rows would have a different length from the
         others -- confirm programs never exceed ``max_length``.
     """
     if args.sw_name.startswith("clevr"):
         program_prefix = vr.programs.list_to_prefix(model["program"])
     else:
         program_prefix = clevr_util.parse_program(mode=0, model=model)
     program_str = vr.programs.list_to_str(program_prefix)
     program_tokens = tokenize(program_str)
     program_encoded = encode(program_tokens, program_token_to_idx)
     pad_count = max_length - len(program_encoded)
     if pad_count > 0:
         # Look the padding index up once, and only when padding is needed.
         program_encoded += [program_token_to_idx["<NULL>"]] * pad_count
     return np.asarray(program_encoded, dtype=np.int64)
コード例 #3
0
                    else:
                        captions_iter = zip(
                            (captions[n], ), (captions_length[n], ),
                            (captions_model[n], ), (agreements[n], ))
                    for caption, caption_length, caption_model, agreement in captions_iter:
                        if agreement == 1.0:
                            answer = 'true'
                        elif agreement == 0.0:
                            answer = 'false'
                        else:
                            assert False
                            answer = 'maybe'
                        if caption_model is None:
                            program = None
                        else:
                            program = parse_program(model=caption_model)
                        questions.append(
                            dict(image_index=index,
                                 program=program,
                                 question_index=0,
                                 image_filename=filename,
                                 question_family_index=0,
                                 split=mode,
                                 answer=answer,
                                 question=' '.join(
                                     id2word[caption[i]]
                                     for i in range(caption_length))))

            else:
                if args.features:
                    for value_name, value_type in dataset.values.items():
コード例 #4
0
                        captions_iter = zip(
                            (captions[n], ), (captions_length[n], ),
                            (captions_model[n], ), (agreements[n], ))
                    for caption, caption_length, caption_model, agreement in captions_iter:
                        if agreement == 1.0:
                            answer = 'true'
                        elif agreement == 0.0:
                            answer = 'false'
                        else:
                            assert False
                            answer = 'maybe'
                        for parse_mode in range(2):
                            if caption_model is None:
                                program = None
                            else:
                                program = clevr_util.parse_program(
                                    mode=parse_mode, model=caption_model)
                            questions[parse_mode].append(
                                dict(image_index=index,
                                     program=program,
                                     question_index=0,
                                     image_filename=filename,
                                     question_family_index=0,
                                     split=mode,
                                     answer=answer,
                                     question=' '.join(
                                         id2word[caption[i]]
                                         for i in range(caption_length))))

            else:
                if args.features:
                    for value_name, value_type in dataset.values.items():