Example 1
# Assumed imports for this snippet: glob and PyTorch; `utils` is the
# project's own module that provides load_execution_engine.
import glob

import torch
import torch.nn.functional as F


def main(args):
    # Gather every checkpoint saved under args.model_path.
    all_checkpoints = glob.glob('%s/*.pt' % args.model_path)
    print(all_checkpoints)

    for i, checkpoint in enumerate(all_checkpoints):

        model, _ = utils.load_execution_engine(checkpoint, False, 'SHNMN')
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(name)

        # Dump this checkpoint's structure parameters to f_<i>.txt.
        with open('f_%d.txt' % i, 'w') as f:
            f.write('%s\n' % checkpoint)
            f.write('HARD_TAU | HARD_ALPHA\n')
            f.write('%s-%s\n' % (model.hard_code_tau, model.hard_code_alpha))
            f.write('TAUS\n')
            f.write('p(model) : %s\n' % str(torch.sigmoid(model.model_bernoulli)))
            for j in range(3):  # use j so the checkpoint index i is not shadowed
                tau0 = model.tau_0[j, :(j+2)] if model.hard_code_tau else F.softmax(model.tau_0[j, :(j+2)], dim=-1)
                f.write('tau0: %s\n' % str(tau0.data.cpu().numpy()))
                tau1 = model.tau_1[j, :(j+2)] if model.hard_code_tau else F.softmax(model.tau_1[j, :(j+2)], dim=-1)
                f.write('tau1: %s\n' % str(tau1.data.cpu().numpy()))

            f.write('ALPHAS\n')
            for j in range(3):
                alpha = model.alpha[j] if model.hard_code_alpha else F.softmax(model.alpha[j], dim=-1)
                f.write('alpha: %s\n' % ' '.join('{:.3f}'.format(float(x)) for x in alpha.view(-1).data.cpu().numpy()))
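
The dump above stores structure parameters unnormalized: model_bernoulli is a logit that a sigmoid turns into p(model), and each tau/alpha row only becomes a distribution after a softmax. A standalone illustration of that conversion (the tensor values below are made up for the demo):

import torch
import torch.nn.functional as F

bernoulli_logit = torch.tensor(0.8)
print(torch.sigmoid(bernoulli_logit))     # p(model): a probability in (0, 1)

tau_row = torch.tensor([1.5, -0.3, 0.2])  # one unnormalized attention row
print(F.softmax(tau_row, dim=-1))         # normalized: sums to 1 over the inputs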
Example 2
    def __init__(self, agent, train_test, env, trunc, sampling):
        Metric.__init__(self, agent, train_test, "oracle_recall", "scalar",
                        env, trunc, sampling)
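        # Hard-coded paths: the CLOSURE vocab and a pretrained FiLM VQA checkpoint.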
        vocab = "data/closure_vocab.json"
        path = "output/vqa_model_film/model.pt"
        self.execution_engine, ee_kwargs = load_execution_engine(path)
        self.execution_engine.to(self.device)
        self.execution_engine.eval()
        self.program_generator, pg_kwargs = load_program_generator(path)
        self.program_generator.to(self.device)
        self.program_generator.eval()
        self.vocab = vocab
        self.vocab_questions_vqa = get_vocab('question_token_to_idx',
                                             self.vocab)
        # self.vocab_questions_vqa.update({"<pad>": 0, "<sos>": 1, "<eos>": 2})
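        # Map the dataset's question-token ids onto the VQA model's vocabulary ids.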
        self.trad_dict = {
            value: self.vocab_questions_vqa[key.lower()]
            for key, value in self.dataset.vocab_questions.items()
            if key.lower() in self.vocab_questions_vqa
        }
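        # Inverse vocabulary: model token id -> token string, for decoding output.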
        self.decoder_dict = {
            value: key
            for key, value in self.vocab_questions_vqa.items()
        }

        self.reset()
        self.batch_size = 30
Example 3
 def __init__(self, path=None, vocab=None, dataset=None, env=None):
     Reward.__init__(self, path)
     self.type = "episode"
     self.device = torch.device(
         'cuda' if torch.cuda.is_available() else 'cpu'
     ) if env is None else env.device
     self.execution_engine, ee_kwargs = load_execution_engine(path)
     self.execution_engine.to(self.device)
     self.execution_engine.eval()
     self.program_generator, pg_kwargs = load_program_generator(path)
     self.program_generator.to(self.device)
     self.program_generator.eval()
     self.vocab = vocab
     self.dataset = dataset
     self.vocab_questions_vqa = get_vocab('question_token_to_idx',
                                          self.vocab)
     # self.vocab_questions_vqa.update({"<pad>": 0, "<sos>": 1, "<eos>": 2})
     self.trad_dict = {
         value: self.vocab_questions_vqa[key.lower()]
         for key, value in self.dataset.vocab_questions.items()
         if key.lower() in self.vocab_questions_vqa
     }
     self.decoder_dict = {
         value: key
         for key, value in self.vocab_questions_vqa.items()
     }
Example 4
def get_execution_engine(args):
  vocab = utils.load_vocab(args.vocab_json)
  if args.execution_engine_start_from is not None:
    ee, kwargs = utils.load_execution_engine(
      args.execution_engine_start_from, model_type=args.model_type)
  else:
    kwargs = {
      'vocab': vocab,
      'feature_dim': parse_int_list(args.feature_dim),
      'stem_batchnorm': args.module_stem_batchnorm == 1,
      'stem_num_layers': args.module_stem_num_layers,
      'module_dim': args.module_dim,
      'module_residual': args.module_residual == 1,
      'module_batchnorm': args.module_batchnorm == 1,
      'classifier_proj_dim': args.classifier_proj_dim,
      'classifier_downsample': args.classifier_downsample,
      'classifier_fc_layers': parse_int_list(args.classifier_fc_dims),
      'classifier_batchnorm': args.classifier_batchnorm == 1,
      'classifier_dropout': args.classifier_dropout,
      
      'encoder_vocab_size': len(vocab['question_token_to_idx']),
      'decoder_vocab_size': len(vocab['program_token_to_idx']),
      'wordvec_dim': args.rnn_wordvec_dim,
      'hidden_dim': args.rnn_hidden_dim,
      'rnn_num_layers': args.rnn_num_layers,
      'rnn_dropout': args.rnn_dropout, # 0e-2
    }
    if args.model_type == 'FiLM':
      kwargs['num_modules'] = args.num_modules
      kwargs['stem_kernel_size'] = args.module_stem_kernel_size
      kwargs['stem_stride'] = args.module_stem_stride
      kwargs['stem_padding'] = args.module_stem_padding
      kwargs['module_num_layers'] = args.module_num_layers
      kwargs['module_batchnorm_affine'] = args.module_batchnorm_affine == 1
      kwargs['module_dropout'] = args.module_dropout
      kwargs['module_input_proj'] = args.module_input_proj
      kwargs['module_kernel_size'] = args.module_kernel_size
      kwargs['use_gamma'] = args.use_gamma == 1
      kwargs['use_beta'] = args.use_beta == 1
      kwargs['use_coords'] = args.use_coords
      kwargs['debug_every'] = args.debug_every
      kwargs['print_verbose_every'] = args.print_verbose_every
      kwargs['condition_method'] = args.condition_method
      kwargs['condition_pattern'] = parse_int_list(args.condition_pattern)
      
      kwargs['parameter_efficient'] = args.program_generator_parameter_efficient == 1
      kwargs['output_batchnorm'] = args.rnn_output_batchnorm == 1 
      kwargs['bidirectional'] = args.bidirectional == 1 
      kwargs['rnn_time_step'] = args.rnn_time_step
      kwargs['encoder_type'] = args.encoder_type 
      kwargs['decoder_type'] = args.decoder_type 
      kwargs['gamma_option'] = args.gamma_option
      kwargs['gamma_baseline'] = args.gamma_baseline # 1
      
      ee = FiLMedNet(**kwargs)
    else:
      ee = ModuleNet(**kwargs)
  ee.cuda()
  ee.train()
  return ee, kwargs
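
Several of the builders above rely on parse_int_list to turn comma-separated command-line strings such as --feature_dim 1024,14,14 into numeric tuples, and on the == 1 comparisons to read 0/1 flags as booleans. A plausible minimal version of the helper, assuming plain comma-separated integers (the codebase's own implementation may differ):

def parse_int_list(s):
    # '1024,14,14' -> (1024, 14, 14); an empty string yields an empty tuple.
    if not s:
        return tuple()
    return tuple(int(n) for n in s.split(','))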
Example 5
def main(args):
    if args.debug_every <= 1:
        pdb.set_trace()
    model = None
    if args.baseline_model is not None:
        print('Loading baseline model from ', args.baseline_model)
        model, _ = utils.load_baseline(args.baseline_model)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            model.rnn.expand_vocab(new_vocab['question_token_to_idx'])
    elif (args.program_generator is not None
          and args.execution_engine is not None):
        pg, _ = utils.load_program_generator(args.program_generator,
                                             args.model_type)
        ee, _ = utils.load_execution_engine(args.execution_engine,
                                            verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab['question_token_to_idx'])
        model = (pg, ee)
    else:
        print('Must give either --baseline_model or --program_generator '
              'and --execution_engine')
        return

    dtype = torch.FloatTensor
    if args.use_gpu == 1:
        dtype = torch.cuda.FloatTensor
    if args.question is not None and args.image is not None:
        run_single_example(args, model, dtype, args.question)
    # Interactive mode
    elif (args.image is not None and args.input_question_h5 is None
          and args.input_features_h5 is None):
        feats_var = extract_image_features(args, dtype)
        print(colored('Ask me something!', 'cyan'))
        while True:
            # Get user question
            question_raw = input(">>> ")
            run_single_example(args, model, dtype, question_raw, feats_var)
    else:
        vocab = load_vocab(args)
        loader_kwargs = {
            'question_h5': args.input_question_h5,
            'feature_h5': args.input_features_h5,
            'vocab': vocab,
            'batch_size': args.batch_size,
        }
        if args.num_samples is not None and args.num_samples > 0:
            loader_kwargs['max_samples'] = args.num_samples
        if args.family_split_file is not None:
            with open(args.family_split_file, 'r') as f:
                loader_kwargs['question_families'] = json.load(f)
        with ClevrDataLoader(**loader_kwargs) as loader:
            run_batch(args, model, dtype, loader)
Example 6
def main(args):
    if not args.program_generator:
        args.program_generator = args.execution_engine
    input_question_h5 = os.path.join(args.data_dir,
                                     '{}_questions.h5'.format(args.part))
    input_features_h5 = os.path.join(args.data_dir,
                                     '{}_features.h5'.format(args.part))

    model = None
    if args.baseline_model is not None:
        print('Loading baseline model from ', args.baseline_model)
        model, _ = utils.load_baseline(args.baseline_model)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            model.rnn.expand_vocab(new_vocab['question_token_to_idx'])
    elif args.program_generator is not None and args.execution_engine is not None:
        pg, _ = utils.load_program_generator(args.program_generator)
        ee, _ = utils.load_execution_engine(args.execution_engine,
                                            verbose=False)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab['question_token_to_idx'])
        model = (pg, ee)
    else:
        print(
            'Must give either --baseline_model or --program_generator and --execution_engine'
        )
        return

    dtype = torch.FloatTensor
    if args.use_gpu == 1:
        dtype = torch.cuda.FloatTensor
    if args.question is not None and args.image is not None:
        run_single_example(args, model, dtype, args.question)
    else:
        vocab = load_vocab(args)
        loader_kwargs = {
            'question_h5': input_question_h5,
            'feature_h5': input_features_h5,
            'vocab': vocab,
            'batch_size': args.batch_size,
        }
        if args.num_samples is not None and args.num_samples > 0:
            loader_kwargs['max_samples'] = args.num_samples
        if args.family_split_file is not None:
            with open(args.family_split_file, 'r') as f:
                loader_kwargs['question_families'] = json.load(f)
        with ClevrDataLoader(**loader_kwargs) as loader:
            run_batch(args, model, dtype, loader)
Example 7
def get_execution_engine(args):
    vocab = utils.load_vocab(args.vocab_json)
    if args.execution_engine_start_from is not None:
        ee, kwargs = utils.load_execution_engine(
            args.execution_engine_start_from, model_type=args.model_type)
    else:
        kwargs = {
            'vocab': vocab,
            'feature_dim': parse_int_list(args.feature_dim),
            'stem_batchnorm': args.module_stem_batchnorm == 1,
            'stem_num_layers': args.module_stem_num_layers,
            'module_dim': args.module_dim,
            'module_residual': args.module_residual == 1,
            'module_batchnorm': args.module_batchnorm == 1,
            'classifier_proj_dim': args.classifier_proj_dim,
            'classifier_downsample': args.classifier_downsample,
            'classifier_fc_layers': parse_int_list(args.classifier_fc_dims),
            'classifier_batchnorm': args.classifier_batchnorm == 1,
            'classifier_dropout': args.classifier_dropout,
        }
        if args.model_type == 'FiLM':
            kwargs['num_modules'] = args.num_modules
            kwargs['stem_kernel_size'] = args.module_stem_kernel_size
            kwargs['stem_stride'] = args.module_stem_stride
            kwargs['stem_padding'] = args.module_stem_padding
            kwargs['module_num_layers'] = args.module_num_layers
            kwargs['module_batchnorm_affine'] = args.module_batchnorm_affine == 1
            kwargs['module_dropout'] = args.module_dropout
            kwargs['module_input_proj'] = args.module_input_proj
            kwargs['module_kernel_size'] = args.module_kernel_size
            kwargs['use_gamma'] = args.use_gamma == 1
            kwargs['use_beta'] = args.use_beta == 1
            kwargs['use_coords'] = args.use_coords
            kwargs['debug_every'] = args.debug_every
            kwargs['print_verbose_every'] = args.print_verbose_every
            kwargs['condition_method'] = args.condition_method
            kwargs['with_cbn'] = args.with_cbn
            kwargs['final_resblock_with_cbn'] = args.final_resblock_with_cbn
            kwargs['condition_pattern'] = parse_int_list(
                args.condition_pattern)
            ee = FiLMedNet(**kwargs)
        else:
            ee = ModuleNet(**kwargs)
    # if cuda.device_count() > 1:
    #     ee = nn.DataParallel(ee)
    ee.cuda()
    ee.train()
    return ee, kwargs
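
The commented-out lines above sketch optional multi-GPU support. A hedged minimal version of that pattern (the helper name is illustrative; note that DataParallel moves the weights under ee.module, which any checkpoint save/load code must account for):

import torch
import torch.nn as nn

def place_on_devices(module):
    # Wrap in DataParallel only when more than one GPU is visible; forward()
    # then scatters each batch across the GPUs along dimension 0.
    if torch.cuda.device_count() > 1:
        module = nn.DataParallel(module)
    if torch.cuda.is_available():
        module = module.cuda()
    return module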
Example 8
def get_execution_engine(args):
  vocab = utils.load_vocab(args.vocab_json)
  if args.execution_engine_start_from is not None:
    ee, kwargs = utils.load_execution_engine(
      args.execution_engine_start_from, model_type=args.model_type)
  else:
    kwargs = {
      'vocab': vocab,
      'feature_dim': parse_int_list(args.feature_dim),
      'stem_batchnorm': args.module_stem_batchnorm == 1,
      'stem_num_layers': args.module_stem_num_layers,
      'module_dim': args.module_dim,
      'module_residual': args.module_residual == 1,
      'module_batchnorm': args.module_batchnorm == 1,
      'classifier_proj_dim': args.classifier_proj_dim,
      'classifier_downsample': args.classifier_downsample,
      'classifier_fc_layers': parse_int_list(args.classifier_fc_dims),
      'classifier_batchnorm': args.classifier_batchnorm == 1,
      'classifier_dropout': args.classifier_dropout,
    }
    if args.model_type.startswith('FiLM'):
      kwargs['num_modules'] = args.num_modules
      kwargs['stem_use_resnet'] = args.model_type in ('FiLM+ResNet1', 'FiLM+ResNet0')
      kwargs['stem_resnet_fixed'] = args.model_type == 'FiLM+ResNet0'
      kwargs['stem_kernel_size'] = args.module_stem_kernel_size
      kwargs['stem_stride2_freq'] = args.module_stem_stride2_freq
      kwargs['stem_padding'] = args.module_stem_padding
      kwargs['module_num_layers'] = args.module_num_layers
      kwargs['module_batchnorm_affine'] = args.module_batchnorm_affine == 1
      kwargs['module_dropout'] = args.module_dropout
      kwargs['module_input_proj'] = args.module_input_proj
      kwargs['module_kernel_size'] = args.module_kernel_size
      kwargs['use_gamma'] = args.use_gamma == 1
      kwargs['use_beta'] = args.use_beta == 1
      kwargs['use_coords'] = args.use_coords
      kwargs['debug_every'] = args.debug_every
      kwargs['print_verbose_every'] = args.print_verbose_every
      kwargs['condition_method'] = args.condition_method
      kwargs['condition_pattern'] = parse_int_list(args.condition_pattern)
      ee = FiLMedNet(**kwargs)
    else:
      ee = ModuleNet(**kwargs)
  if torch.cuda.is_available():
    ee.cuda()
  else:
    ee.cpu()
  ee.train()
  return ee, kwargs
Example 9
def main(args):
    if not args.program_generator:
        args.program_generator = args.execution_engine
    input_question_h5 = os.path.join(args.data_dir,
                                     '{}_questions.h5'.format(args.part))
    input_features_h5 = os.path.join(args.data_dir,
                                     '{}_features.h5'.format(args.part))
    input_scenes = os.path.join(args.data_dir,
                                '{}_scenes.json'.format(args.part))
    vocab = load_vocab(args)

    pg, _ = utils.load_program_generator(args.program_generator)
    if pg:
        pg.save_activations = True
    if args.temperature:
        pg.decoder_linear.weight.data /= args.temperature
        pg.decoder_linear.bias.data /= args.temperature
    if args.execution_engine:
        ee, _ = utils.load_execution_engine(args.execution_engine,
                                            verbose=False)
        ee.noise_enabled = False
    else:
        ee = ClevrExecutor(vocab)

    dtype = torch.FloatTensor
    if args.use_gpu == 1:
        dtype = torch.cuda.FloatTensor

    loader_kwargs = {
        'question_h5': input_question_h5,
        'feature_h5': input_features_h5,
        'scene_path': input_scenes if isinstance(ee, ClevrExecutor) else None,
        'vocab': vocab,
        'batch_size': args.batch_size,
    }
    if args.num_samples is not None and args.num_samples > 0:
        loader_kwargs['max_samples'] = args.num_samples
    if args.q_family:
        loader_kwargs['question_families'] = args.q_family
    with ClevrDataLoader(**loader_kwargs) as loader:
        with torch.no_grad():
            run_batch(args, pg, ee, loader, dtype)
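
The temperature handling above divides the decoder's final linear weights and bias by args.temperature; because the layer is affine, that is exactly equivalent to dividing its output logits by the temperature before a softmax. A small self-contained check of that identity (the layer sizes are arbitrary):

import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(8, 5)
x = torch.randn(3, 8)
T = 2.0

logits = layer(x)

# Scale the parameters in place, as the snippet does.
layer.weight.data /= T
layer.bias.data /= T
scaled = layer(x)

assert torch.allclose(scaled, logits / T, atol=1e-6)
print(torch.softmax(scaled, dim=-1))  # flatter than softmax(logits, dim=-1)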
Example 10
def main(args):
    if args.debug_every <= 1:
        pdb.set_trace()

    if args.sw_name is not None or args.sw_config is not None:
        assert args.image is None and args.question is None

        from shapeworld import Dataset, torch_util
        from shapeworld.datasets import clevr_util

        class ShapeWorldDataLoader(torch_util.ShapeWorldDataLoader):
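            # Re-package each ShapeWorld batch as the CLEVR-style 6-tuple
            # (question, image, feats, answer, program_seq, program_json).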
            def __iter__(self):
                for batch in super(ShapeWorldDataLoader, self).__iter__():
                    if "caption" in batch:
                        question = batch["caption"].long()
                    else:
                        question = batch["question"].long()
                    if args.sw_features == 1:
                        image = batch["world_features"]
                    else:
                        image = batch["world"]
                    feats = image
                    if "agreement" in batch:
                        answer = batch["agreement"].long()
                    else:
                        answer = batch["answer"].long()
                    if "caption_model" in batch:
                        assert args.sw_name.startswith(
                            "clevr") or args.sw_program == 3
                        program_seq = batch["caption_model"]
                        # .apply_(callable=(lambda model: clevr_util.parse_program(mode=0, model=model)))
                    elif "question_model" in batch:
                        program_seq = batch["question_model"]
                    elif "caption" in batch:
                        if args.sw_program == 1:
                            program_seq = batch["caption_pn"].long()
                        elif args.sw_program == 2:
                            program_seq = batch["caption_rpn"].long()
                        else:
                            program_seq = [None]
                    else:
                        program_seq = [None]
                    # program_seq = torch.IntTensor([0 for _ in batch['question']])
                    program_json = dict()
                    yield question, image, feats, answer, program_seq, program_json

        dataset = Dataset.create(
            dtype=args.sw_type,
            name=args.sw_name,
            variant=args.sw_variant,
            language=args.sw_language,
            config=args.sw_config,
        )
        print("ShapeWorld dataset: {} (variant: {})".format(
            dataset, args.sw_variant))
        print("Config: " + str(args.sw_config))

        if args.program_generator is not None:
            with open(args.program_generator + ".vocab", "r") as filehandle:
                vocab = json.load(filehandle)
        elif args.execution_engine is not None:
            with open(args.execution_engine + ".vocab", "r") as filehandle:
                vocab = json.load(filehandle)
        elif args.baseline_model is not None:
            with open(args.baseline_model + ".vocab", "r") as filehandle:
                vocab = json.load(filehandle)
        program_token_to_idx = vocab["program_token_to_idx"]

        include_model = args.model_type in ("PG", "EE", "PG+EE") and (
            args.sw_name.startswith("clevr") or args.sw_program == 3)
        if include_model:

            def preprocess(model):
                if args.sw_name.startswith("clevr"):
                    program_prefix = vr.programs.list_to_prefix(
                        model["program"])
                else:
                    program_prefix = clevr_util.parse_program(mode=0,
                                                              model=model)
                program_str = vr.programs.list_to_str(program_prefix)
                program_tokens = tokenize(program_str)
                program_encoded = encode(program_tokens, program_token_to_idx)
                program_encoded += [
                    program_token_to_idx["<NULL>"]
                    for _ in range(27 - len(program_encoded))
                ]
                return np.asarray(program_encoded, dtype=np.int64)

            if args.sw_name.startswith("clevr"):
                preprocessing = dict(question_model=preprocess)
            else:
                preprocessing = dict(caption_model=preprocess)

        elif args.sw_program in (1, 2):

            def preprocess(caption_pn):
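                # Shift non-zero symbols up by 2 (reserving ids 0-2 for special
                # tokens), turn the first padding 0 into an end marker (2), and
                # prepend a start marker (1).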
                caption_pn += (caption_pn > 0) * 2
                for n, symbol in enumerate(caption_pn):
                    if symbol == 0:
                        caption_pn[n] = 2
                        break
                caption_pn = np.concatenate(([1], caption_pn))
                return caption_pn

            if args.sw_program == 1:
                preprocessing = dict(caption_pn=preprocess)
            else:
                preprocessing = dict(caption_rpn=preprocess)

        else:
            preprocessing = None

        dataset = torch_util.ShapeWorldDataset(
            dataset=dataset,
            mode=(None if args.sw_mode == "none" else args.sw_mode),
            include_model=include_model,
            epoch=(args.num_samples is None),
            preprocessing=preprocessing,
        )

        loader = ShapeWorldDataLoader(dataset=dataset,
                                      batch_size=args.batch_size)

    model = None
    if args.model_type in ("CNN", "LSTM", "CNN+LSTM", "CNN+LSTM+SA"):
        assert args.baseline_model is not None
        print("Loading baseline model from", args.baseline_model)
        model, _ = utils.load_baseline(args.baseline_model)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            model.rnn.expand_vocab(new_vocab["question_token_to_idx"])
    elif args.program_generator is not None and args.execution_engine is not None:
        pg, _ = utils.load_program_generator(args.program_generator,
                                             args.model_type)
        ee, _ = utils.load_execution_engine(args.execution_engine,
                                            verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    elif args.model_type == "FiLM":
        assert args.baseline_model is not None
        pg, _ = utils.load_program_generator(args.baseline_model,
                                             args.model_type)
        ee, _ = utils.load_execution_engine(args.baseline_model,
                                            verbose=False,
                                            model_type=args.model_type)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            pg.expand_encoder_vocab(new_vocab["question_token_to_idx"])
        model = (pg, ee)
    else:
        print(
            "Must give either --baseline_model or --program_generator and --execution_engine"
        )
        return

    if torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor
    if args.question is not None and args.image is not None:
        run_single_example(args, model, dtype, args.question)
    # Interactive mode
    elif (args.image is not None and args.input_question_h5 is None
          and args.input_features_h5 is None):
        feats_var = extract_image_features(args, dtype)
        print(colored("Ask me something!", "cyan"))
        while True:
            # Get user question
            question_raw = input(">>> ")
            run_single_example(args, model, dtype, question_raw, feats_var)
    elif args.sw_name is not None or args.sw_config is not None:
        predictions, visualization = run_batch(args, model, dtype, loader)
        if args.sw_pred_dir is not None:
            assert args.sw_pred_name is not None
            pred_dir = os.path.join(
                args.sw_pred_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            if not os.path.isdir(pred_dir):
                os.makedirs(pred_dir)
            id2word = dataset.dataset.vocabulary(value_type="language")
            with open(
                    os.path.join(
                        pred_dir,
                        args.sw_pred_name + "-" + args.sw_mode + ".txt"),
                    "w",
            ) as filehandle:
                filehandle.write("".join(
                    "{} {} {}\n".format(correct, agreement, " ".join(
                        id2word[c] for c in caption))
                    for correct, agreement, caption in zip(
                        predictions["correct"],
                        predictions["agreement"],
                        predictions["caption"],
                    )))
            print("Predictions saved")
        if args.sw_vis_dir is not None:
            assert args.sw_vis_name is not None
            from io import BytesIO
            from shapeworld.world import World

            vis_dir = os.path.join(
                args.sw_vis_dir,
                dataset.dataset.type,
                dataset.dataset.name,
                dataset.dataset.variant,
            )
            image_dir = os.path.join(vis_dir, args.sw_mode, "images")
            if not os.path.isdir(image_dir):
                os.makedirs(image_dir)
            worlds = np.transpose(visualization["world"], (0, 2, 3, 1))
            for n in range(worlds.shape[0]):
                image = World.get_image(world_array=worlds[n])
                image_bytes = BytesIO()
                image.save(image_bytes, format="png")
                with open(os.path.join(image_dir, "world-{}.png".format(n)),
                          "wb") as filehandle:
                    filehandle.write(image_bytes.getvalue())
                image_bytes.close()
            with open(
                    os.path.join(
                        vis_dir,
                        args.sw_vis_name + "-" + args.sw_mode + ".html"),
                    "w",
            ) as filehandle:
                html = dataset.dataset.get_html(
                    generated=visualization,
                    image_format="png",
                    image_dir=(args.sw_mode + "/images/"),
                )
                filehandle.write(html)
            print("Visualization saved")
    else:
        vocab = load_vocab(args)
        loader_kwargs = {
            "question_h5": args.input_question_h5,
            "feature_h5": args.input_features_h5,
            "vocab": vocab,
            "batch_size": args.batch_size,
        }
        if args.family_split_file is not None:
            with open(args.family_split_file, "r") as f:
                loader_kwargs["question_families"] = json.load(f)
        with ClevrDataLoader(**loader_kwargs) as loader:
            run_batch(args, model, dtype, loader)
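
Example 10's preprocess encodes a program and pads it to a fixed length of 27 with the <NULL> index so sequences can be stacked into a batch. The same idea as a self-contained sketch, with a toy vocabulary and a hypothetical helper name:

import numpy as np

def encode_and_pad(tokens, token_to_idx, length):
    # Encode tokens to indices, then right-pad with <NULL> up to the target length.
    encoded = [token_to_idx[t] for t in tokens]
    encoded += [token_to_idx['<NULL>'] for _ in range(length - len(encoded))]
    return np.asarray(encoded, dtype=np.int64)

toy_vocab = {'<NULL>': 0, 'scene': 1, 'filter_color[red]': 2, 'count': 3}
print(encode_and_pad(['scene', 'filter_color[red]', 'count'], toy_vocab, 6))
# -> [1 2 3 0 0 0]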