Ejemplo n.º 1
0
    def test_model_configs_for_keys(self):
        """Every registered model that ships a default configuration must
        expose its own key inside ``config.model_config``."""
        model_mapping = registry.mapping["model_name_mapping"]

        for name, model_cls in model_mapping.items():
            if model_cls.config_path() is None:
                # No default configuration for this model: warn and move on
                # instead of failing the whole suite.
                warnings.warn(
                    f"Model {name} has no default configuration defined. "
                    "Skipping it. Make sure it is intentional"
                )
                continue

            # Silence the config loader's stdout chatter.
            with contextlib.redirect_stdout(StringIO()):
                args = dummy_args(model=name)
                configuration = Configuration(args)
                configuration.freeze()
                config = configuration.get_config()

                # "mmft" is deliberately exempt from the key check.
                if name == "mmft":
                    continue

                self.assertTrue(
                    name in config.model_config,
                    f"Key for model {name} doesn't exists in its configuration",
                )
Ejemplo n.º 2
0
 def setUp(self):
     """Build both a pretraining and a classification ViLBERT model."""
     test_utils.setup_proxy()
     setup_imports()
     model_name = "vilbert"
     args = test_utils.dummy_args(model=model_name)
     config = Configuration(args).get_config()
     model_class = registry.get_model_class(model_name)
     self.vision_feature_size = 1024
     self.vision_target_size = 1279
     # Alias the model's sub-config; mutations below hit the same object.
     model_config = config.model_config[model_name]
     model_config["training_head_type"] = "pretraining"
     model_config["visual_embedding_dim"] = self.vision_feature_size
     model_config["v_feature_size"] = self.vision_feature_size
     model_config["v_target_size"] = self.vision_target_size
     model_config["dynamic_attention"] = False
     self.pretrain_model = model_class(model_config)
     self.pretrain_model.build()
     # Re-use the mutated config with a classification head for fine-tuning.
     model_config["training_head_type"] = "classification"
     model_config["num_labels"] = 2
     self.finetune_model = model_class(model_config)
     self.finetune_model.build()
Ejemplo n.º 3
0
    def setUp(self):
        """Prepare classification and pretraining configs for VinVL tests."""
        test_utils.setup_proxy()
        setup_imports()
        name = "vinvl"
        args = test_utils.dummy_args(model=name, dataset="test")
        base_config = Configuration(args).get_config().model_config[name]
        base_config.model = name
        base_config.do_pretraining = False

        # Classification variant: MLP head over 3129 answers, ignore -1.
        self.classification_config = OmegaConf.create(
            {
                **base_config,
                "do_pretraining": False,
                "heads": {"mlp": {"num_labels": 3129}},
                "ce_loss": {"ignore_index": -1},
            }
        )

        # Pretraining variant: masked-language-model head.
        self.pretraining_config = OmegaConf.create(
            {
                **base_config,
                "do_pretraining": True,
                "heads": {"mlm": {"hidden_size": 768}},
            }
        )

        self.sample_list = self._get_sample_list()
Ejemplo n.º 4
0
 def setUp(self):
     """Load the mmf_transformer config and tag it with the model name."""
     test_utils.setup_proxy()
     setup_imports()
     self.model_name = "mmf_transformer"
     dummy = test_utils.dummy_args(model=self.model_name)
     self.config = Configuration(dummy).get_config()
     self.config.model_config[self.model_name].model = self.model_name
Ejemplo n.º 5
0
 def setUp(self):
     """Build an mmf_transformer model for fine-tuning tests."""
     setup_imports()
     self.model_name = "mmf_transformer"
     dummy = test_utils.dummy_args(model=self.model_name)
     self.config = Configuration(dummy).get_config()
     self.model_class = registry.get_model_class(self.model_name)
     self.finetune_model = self.model_class(
         self.config.model_config[self.model_name]
     )
     self.finetune_model.build()
Ejemplo n.º 6
0
 def setUp(self):
     """Build a ViLT model from its default test-dataset configuration."""
     test_utils.setup_proxy()
     setup_imports()
     name = "vilt"
     dummy = test_utils.dummy_args(model=name, dataset="test")
     model_config = Configuration(dummy).get_config().model_config[name]
     model_config.model = name
     self.pretrain_model = build_model(model_config)
Ejemplo n.º 7
0
def run(opts: typing.Optional[typing.List[str]] = None, predict: bool = False):
    """Run starts a job based on the command passed from the command line.
    You can optionally run the mmf job programmatically by passing an optlist as opts.

    Args:
        opts (typing.Optional[typing.List[str]], optional): Optlist which can be used.
            to override opts programmatically. For e.g. if you pass
            opts = ["training.batch_size=64", "checkpoint.resume=True"], this will
            set the batch size to 64 and resume from the checkpoint if present.
            Defaults to None.
        predict (bool, optional): If predict is passed True, then the program runs in
            prediction mode. Defaults to False.
    """
    setup_imports()

    # Parse CLI flags only when no programmatic optlist was supplied.
    if opts is None:
        parser = flags.get_parser()
        args = parser.parse_args()
    else:
        args = argparse.Namespace(config_override=None)
        args.opts = opts

    print(args)
    configuration = Configuration(args)
    # Do set runtime args which can be changed by MMF
    configuration.args = args
    config = configuration.get_config()
    config.start_rank = 0
    if config.distributed.init_method is None:
        infer_init_method(config)

    if config.distributed.init_method is not None:
        # Distributed path: spawn one process per visible GPU unless
        # spawning was explicitly disabled via config.
        if torch.cuda.device_count() > 1 and not config.distributed.no_spawn:
            # Keep the base rank aside; the shared rank is cleared so each
            # spawned worker computes its own.
            config.start_rank = config.distributed.rank
            config.distributed.rank = None
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(configuration, predict),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(0, configuration, predict)
    elif config.distributed.world_size > 1:
        # Single-node multi-process fallback over a random localhost port.
        assert config.distributed.world_size <= torch.cuda.device_count()
        port = random.randint(10000, 20000)
        config.distributed.init_method = f"tcp://localhost:{port}"
        config.distributed.rank = None
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(configuration, predict),
            nprocs=config.distributed.world_size,
        )
    else:
        # Plain single-process run.
        config.device_id = 0
        main(configuration, predict=predict)
Ejemplo n.º 8
0
 def setUp(self):
     """Build an MMBT classification model with two labels."""
     setup_imports()
     name = "mmbt"
     dummy = test_utils.dummy_args(model=name)
     config = Configuration(dummy).get_config()
     # Mutate the model sub-config in place before instantiation.
     model_config = config.model_config[name]
     model_config["training_head_type"] = "classification"
     model_config["num_labels"] = 2
     model_class = registry.get_model_class(name)
     self.finetune_model = model_class(model_config)
     self.finetune_model.build()
Ejemplo n.º 9
0
 def setUp(self):
     """Build a TorchScript-compatible VisualBERT model."""
     test_utils.setup_proxy()
     setup_imports()
     # Swap in the JIT-friendly implementations before building the model.
     replace_with_jit()
     name = "visual_bert"
     dummy = test_utils.dummy_args(model=name)
     model_config = Configuration(dummy).get_config().model_config[name]
     model_config.model = name
     self.pretrain_model = build_model(model_config)
Ejemplo n.º 10
0
 def setUpClass(cls) -> None:
     """Create a frozen cnn_lstm/clevr config saved under a temp dir and
     build a Logger from it.

     Fix: `model=cnn_lstm` and `dataset=clevr` carried pointless f-string
     prefixes (no placeholders, flake8 F541); only `env.save_dir` needs
     interpolation.
     """
     cls._tmpdir = tempfile.mkdtemp()
     args = argparse.Namespace()
     args.opts = [
         f"env.save_dir={cls._tmpdir}", "model=cnn_lstm", "dataset=clevr"
     ]
     args.config_override = None
     configuration = Configuration(args)
     configuration.freeze()
     cls.config = configuration.get_config()
     registry.register("config", cls.config)
     cls.writer = Logger(cls.config)
Ejemplo n.º 11
0
 def setUp(self):
     """Build an MMBT classification model via ``build_model``."""
     test_utils.setup_proxy()
     setup_imports()
     name = "mmbt"
     dummy = test_utils.dummy_args(model=name)
     model_config = Configuration(dummy).get_config().model_config[name]
     model_config["training_head_type"] = "classification"
     model_config["num_labels"] = 2
     model_config.model = name
     self.finetune_model = build_model(model_config)
Ejemplo n.º 12
0
 def setUpClass(cls) -> None:
     """Create a frozen cnn_lstm/clevr config (with an explicit
     configs/defaults.yaml) and set up logging.

     Fix: `model=cnn_lstm` and `dataset=clevr` carried pointless f-string
     prefixes (no placeholders, flake8 F541).
     """
     cls._tmpdir = tempfile.mkdtemp()
     args = argparse.Namespace()
     args.opts = [
         f"env.save_dir={cls._tmpdir}", "model=cnn_lstm", "dataset=clevr"
     ]
     args.opts.append(f"config={os.path.join('configs', 'defaults.yaml')}")
     args.config_override = None
     configuration = Configuration(args)
     configuration.freeze()
     cls.config = configuration.get_config()
     registry.register("config", cls.config)
     cls.writer = setup_logger()
Ejemplo n.º 13
0
    def _init_processors(self):
        """Instantiate and register the VQA2 text and answer processors."""
        args = Namespace()
        args.opts = [
            "config=projects/visual_bert/configs/vqa2/defaults.yaml",
            "datasets=vqa2",
            "model=visual_bert",
            "evaluation.predict=True",
        ]
        args.config_override = None

        self.config = Configuration(args=args).config
        processors = self.config.dataset_config.vqa2.processors
        text_cfg = processors.text_processor
        answer_cfg = processors.answer_processor

        text_cfg.params.vocab.vocab_file = (
            self.root + "/content/model_data/vocabulary_100k.txt"
        )
        answer_cfg.params.vocab_file = (
            self.root + "/content/model_data/answers_vqa.txt"
        )

        # Add preprocessor as that will needed when we are getting questions
        # from the user.
        self.text_processor = BertTokenizer(text_cfg.params)
        self.answer_processor = VQAAnswerProcessor(answer_cfg.params)

        registry.register("vqa2_text_processor", self.text_processor)
        registry.register("vqa2_answer_processor", self.answer_processor)
        registry.register(
            "vqa2_num_final_outputs", self.answer_processor.get_vocab_size()
        )
Ejemplo n.º 14
0
def build_config(configuration: Configuration, *args, **kwargs) -> DictConfig:
    """Freeze ``configuration`` and publish it to the registry.

    Both the resulting DictConfig (under key "config") and the
    Configuration object itself (under "configuration") are registered so
    other components can look them up later.

    Args:
        configuration (Configuration): Configuration object that will be
            used to create the config.

    Returns:
        (DictConfig): A config which is of type omegaconf.DictConfig
    """
    configuration.freeze()
    frozen_config = configuration.get_config()
    registry.register("config", frozen_config)
    registry.register("configuration", configuration)
    return frozen_config
Ejemplo n.º 15
0
    def setUp(self):
        """Build UNITER models for classification and for WRA pretraining."""
        test_utils.setup_proxy()
        setup_imports()
        name = "uniter"
        dummy = test_utils.dummy_args(model=name, dataset="vqa2")
        base_config = Configuration(dummy).get_config().model_config[name]
        base_config.model = name
        base_config.losses = {"vqa2": "logit_bce"}
        base_config.do_pretraining = False
        base_config.tasks = "vqa2"

        # Classification variant: MLP head over the 3129 VQA2 answers.
        classification_config = OmegaConf.create(
            {
                **base_config,
                "do_pretraining": False,
                "tasks": "vqa2",
                "heads": {"vqa2": {"type": "mlp", "num_labels": 3129}},
                "losses": {"vqa2": "logit_bce"},
            }
        )

        # Pretraining variant: word-region alignment (WRA) head.
        pretraining_config = OmegaConf.create(
            {
                **base_config,
                "do_pretraining": True,
                "tasks": "wra",
                "heads": {"wra": {"type": "wra"}},
            }
        )

        self.model_for_classification = build_model(classification_config)
        self.model_for_pretraining = build_model(pretraining_config)
def get_model(device, opts):
    """Build an MMF trainer from ``opts`` and return its model on ``device``.

    Args:
        device: torch device (or device string) the model is moved to.
        opts: optlist of config overrides, e.g. ["model=...", "dataset=..."].
    """
    # Imports are local so this helper can live outside the mmf package.
    from mmf.utils.build import build_config, build_trainer
    from mmf.common.registry import registry
    from mmf.utils.configuration import Configuration
    from mmf.utils.env import set_seed, setup_imports
    args = argparse.Namespace(config_override=None)
    args.opts = opts
    configuration = Configuration(args)
    configuration.args = args
    config = configuration.get_config()
    config.start_rank = 0
    config.device_id = 0
    setup_imports()
    configuration.import_user_dir()
    # NOTE(review): get_config() is called a second time after
    # import_user_dir() — presumably the user dir can alter the config.
    # Confirm whether the start_rank/device_id set above survive this
    # refetch before simplifying.
    config = configuration.get_config()

    if torch.cuda.is_available():
        torch.cuda.set_device(config.device_id)
        torch.cuda.init()

    # Seed everything and expose the chosen seed via the registry.
    config.training.seed = set_seed(config.training.seed)
    registry.register("seed", config.training.seed)

    config = build_config(configuration)

    # Logger should be registered after config is registered
    registry.register("writer", Logger(config, name="mmf.train"))
    trainer = build_trainer(config)
    # trainer.load()
    ready_trainer(trainer)
    trainer.model.to(device)
    return trainer.model
Ejemplo n.º 17
0
    def setUp(self):
        """Prepare a ViLBERT classification config for the tests.

        Fix: the original assigned ``training_head_type = "pretraining"``
        and immediately overwrote it with "classification" before any
        read; the dead store has been removed.
        """
        test_utils.setup_proxy()
        setup_imports()
        model_name = "vilbert"
        args = test_utils.dummy_args(model=model_name)
        configuration = Configuration(args)
        config = configuration.get_config()
        self.vision_feature_size = 1024
        self.vision_target_size = 1279
        model_config = config.model_config[model_name]
        model_config["visual_embedding_dim"] = self.vision_feature_size
        model_config["v_feature_size"] = self.vision_feature_size
        model_config["v_target_size"] = self.vision_target_size
        model_config["dynamic_attention"] = False
        model_config.model = model_name
        model_config["training_head_type"] = "classification"
        model_config["num_labels"] = 2
        self.model_config = model_config
Ejemplo n.º 18
0
 def setUp(self):
     """Register a frozen BUTD beam-search config for COCO captioning."""
     setup_imports()
     torch.manual_seed(1234)
     rel_parts = ("..", "projects", "butd", "configs", "coco", "beam_search.yaml")
     config_path = os.path.abspath(os.path.join(get_mmf_root(), *rel_parts))
     args = dummy_args(model="butd", dataset="coco")
     args.opts.append(f"config={config_path}")
     configuration = Configuration(args)
     configuration.config.datasets = "coco"
     configuration.freeze()
     self.config = configuration.config
     registry.register("config", self.config)
Ejemplo n.º 19
0
    def test_dataset_configs_for_keys(self):
        """Every registered dataset builder that ships a default config must
        expose its own key inside ``config.dataset_config``."""
        builder_mapping = registry.mapping["builder_name_mapping"]

        for name, builder_cls in builder_mapping.items():
            if builder_cls.config_path() is None:
                # No default configuration for this dataset: warn and skip.
                warnings.warn(
                    f"Dataset {name} has no default configuration defined. "
                    "Skipping it. Make sure it is intentional"
                )
                continue

            # Suppress the config loader's stdout output.
            with contextlib.redirect_stdout(StringIO()):
                args = dummy_args(dataset=name)
                configuration = Configuration(args)
                configuration.freeze()
                config = configuration.get_config()
                self.assertTrue(
                    name in config.dataset_config,
                    f"Key for dataset {name} doesn't exists in its configuration",
                )
Ejemplo n.º 20
0
 def setUp(self):
     """Register a frozen cnn_lstm/clevr config plus its vocab sizes."""
     torch.manual_seed(1234)
     registry.register("clevr_text_vocab_size", 80)
     registry.register("clevr_num_final_outputs", 32)
     rel_parts = ("..", "projects", "others", "cnn_lstm", "clevr", "defaults.yaml")
     config_path = os.path.abspath(os.path.join(get_mmf_root(), *rel_parts))
     args = dummy_args(model="cnn_lstm", dataset="clevr")
     args.opts.append(f"config={config_path}")
     configuration = Configuration(args)
     configuration.config.datasets = "clevr"
     configuration.freeze()
     self.config = configuration.config
     registry.register("config", self.config)
Ejemplo n.º 21
0
 def setUp(self):
     """Register a frozen BUTD nucleus-sampling config for COCO."""
     setup_imports()
     torch.manual_seed(1234)
     rel_parts = (
         "..", "projects", "butd", "configs", "coco", "nucleus_sampling.yaml"
     )
     config_path = os.path.abspath(os.path.join(get_mmf_root(), *rel_parts))
     args = dummy_args(model="butd", dataset="coco")
     args.opts.append(f"config={config_path}")
     configuration = Configuration(args)
     configuration.config.datasets = "coco"
     configuration.config.model_config.butd.inference.params.sum_threshold = 0.5
     configuration.freeze()
     self.config = configuration.config
     registry.register("config", self.config)
Ejemplo n.º 22
0
 def test_config_overrides(self):
     """Indexed list entries passed as opts must override YAML values."""
     rel_parts = ("..", "projects", "m4c", "configs", "textvqa", "defaults.yaml")
     config_path = os.path.abspath(os.path.join(get_mmf_root(), *rel_parts))
     args = dummy_args(model="m4c", dataset="textvqa")
     args.opts += [
         f"config={config_path}",
         "training.lr_steps[1]=10000",
         'dataset_config.textvqa.zoo_requirements[0]="test"',
     ]
     configuration = Configuration(args)
     configuration.freeze()
     config = configuration.get_config()
     self.assertEqual(config.training.lr_steps[1], 10000)
     self.assertEqual(config.dataset_config.textvqa.zoo_requirements[0], "test")
Ejemplo n.º 23
0
Archivo: run.py Proyecto: kyhoolee/mmf
def run(predict=False):
    """Entry point: parse CLI flags, build the configuration and dispatch
    to distributed or single-process training.

    Args:
        predict (bool, optional): When True the job runs in prediction
            mode. Defaults to False.
    """
    setup_imports()
    parser = flags.get_parser()
    args = parser.parse_args()
    print(args)
    configuration = Configuration(args)
    # Do set runtime args which can be changed by MMF
    configuration.args = args
    config = configuration.get_config()
    config.start_rank = 0
    if config.distributed.init_method is None:
        infer_init_method(config)

    if config.distributed.init_method is not None:
        # Distributed path: one spawned process per visible GPU unless
        # spawning was explicitly disabled via config.
        if torch.cuda.device_count() > 1 and not config.distributed.no_spawn:
            # Keep the base rank aside; the shared rank is cleared so each
            # spawned worker computes its own.
            config.start_rank = config.distributed.rank
            config.distributed.rank = None
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(configuration, predict),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(0, configuration, predict)
    elif config.distributed.world_size > 1:
        # Single-node multi-process fallback over a random localhost port.
        assert config.distributed.world_size <= torch.cuda.device_count()
        port = random.randint(10000, 20000)
        config.distributed.init_method = "tcp://localhost:{port}".format(
            port=port)
        config.distributed.rank = None
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(configuration, predict),
            nprocs=config.distributed.world_size,
        )
    else:
        # Plain single-process run.
        config.device_id = 0
        main(configuration, predict)
Ejemplo n.º 24
0
    def _init_processors(self):
        """Instantiate the VQA2 vocab text processor and register it."""
        args = Namespace()
        args.opts = [
            "config=projects/pythia/configs/vqa2/defaults.yaml",
            "datasets=vqa2",
            "model=visual_bert",
            "evaluation.predict=True",
        ]
        args.config_override = None

        self.config = Configuration(args=args).config
        text_cfg = self.config.dataset_config.vqa2.processors.text_processor
        text_cfg.params.vocab.vocab_file = "../model_data/vocabulary_100k.txt"

        # Add preprocessor as that will needed when we are getting questions
        # from the user.
        self.text_processor = VocabProcessor(text_cfg.params)
        registry.register("coco_text_processor", self.text_processor)
Ejemplo n.º 25
0
    def test_init_processors(self):
        """Processors must exist only after ``init_processors`` is called."""
        path = os.path.join(
            os.path.abspath(__file__),
            "../../../mmf/configs/datasets/vqa2/defaults.yaml",
        )
        args = dummy_args()
        args.opts.append(f"config={path}")
        configuration = Configuration(args)
        answer_processor = (
            configuration.get_config().dataset_config.vqa2.processors.answer_processor
        )
        vocab_path = os.path.join(
            os.path.abspath(__file__), "..", "..", "data", "vocab.txt"
        )
        answer_processor.params.vocab_file = os.path.abspath(vocab_path)
        self._fix_configuration(configuration)
        configuration.freeze()

        base_dataset = BaseDataset(
            "vqa2", configuration.get_config().dataset_config.vqa2, "train"
        )
        expected_processors = [
            "answer_processor",
            "ocr_token_processor",
            "bbox_processor",
        ]

        # Nothing should be set or registered before init_processors.
        self.assertFalse(
            any(hasattr(base_dataset, key) for key in expected_processors)
        )
        for processor in expected_processors:
            self.assertIsNone(registry.get(f"vqa2_{processor}"))

        # After init_processors everything must be present and registered.
        base_dataset.init_processors()
        self.assertTrue(
            all(hasattr(base_dataset, key) for key in expected_processors)
        )
        for processor in expected_processors:
            self.assertIsNotNone(registry.get(f"vqa2_{processor}"))
Ejemplo n.º 26
0
class HMConverter:
    """Converts the Hateful Memes Challenge zip (Phase 1 or Phase 2) into
    the MMF data layout: images and annotations placed under the MMF data
    directory."""

    IMAGE_FILES = ["img.tar.gz", "img"]
    JSONL_PHASE_ONE_FILES = ["train.jsonl", "dev.jsonl", "test.jsonl"]
    JSONL_PHASE_TWO_FILES = [
        "train.jsonl",
        "dev_seen.jsonl",
        "test_seen.jsonl",
        "dev_unseen.jsonl",
        "test_unseen.jsonl",
    ]
    # Accepted SHA256 digests of the challenge zips.
    POSSIBLE_CHECKSUMS = [
        "d8f1073f5fbf1b08a541cc2325fc8645619ab8ed768091fb1317d5c3a6653a77",
        "a424c003b7d4ea3f3b089168b5f5ea73b90a3ff043df4b8ff4d7ed87c51cb572",
        "6e609b8c230faff02426cf462f0c9528957b7884d68c60ebc26ff83846e5f80f",
        "c1363aae9649c79ae4abfdb151b56d3d170187db77757f3daa80856558ac367c",
    ]

    def __init__(self):
        # Parse CLI arguments immediately and load the MMF configuration.
        self.parser = self.get_parser()
        self.args = self.parser.parse_args()
        self.configuration = Configuration()

    def assert_files(self, folder):
        """Verify the expected jsonl and image files exist under ``folder``.

        Returns:
            bool: True when the folder matches the Phase 1 layout, False
            for Phase 2.

        Raises:
            AssertionError: when neither layout is complete, or when no
            image payload (img/img.tar.gz) is present.
        """
        files_needed = self.JSONL_PHASE_ONE_FILES
        phase_one = True
        for file in files_needed:
            try:
                assert PathManager.exists(
                    os.path.join(folder, "data", file)
                ), f"{file} doesn't exist in {folder}"
            except AssertionError:
                # A Phase 1 file is missing: fall back to the Phase 2 layout.
                phase_one = False

        if not phase_one:
            files_needed = self.JSONL_PHASE_TWO_FILES
            for file in files_needed:
                assert PathManager.exists(
                    os.path.join(folder, "data", file)
                ), f"{file} doesn't exist in {folder}"
        else:
            warnings.warn(
                "You are on Phase 1 of the Hateful Memes Challenge. "
                "Please update to Phase 2"
            )

        files_needed = self.IMAGE_FILES

        exists = False

        for file in files_needed:
            exists = exists or PathManager.exists(os.path.join(folder, "data", file))

        if not exists:
            raise AssertionError("Neither img or img.tar.gz exists in current zip")

        return phase_one

    def get_parser(self):
        """Build the command-line parser for the converter."""
        parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)

        parser.add_argument(
            "--zip_file",
            required=True,
            type=str,
            help="Zip file downloaded from the DrivenData",
        )

        parser.add_argument(
            "--password", required=None, type=str, help="Password for the zip file"
        )
        parser.add_argument(
            "--move", required=None, type=int, help="Move data dir to mmf cache dir"
        )
        parser.add_argument(
            "--mmf_data_folder", required=None, type=str, help="MMF Data folder"
        )
        parser.add_argument(
            "--bypass_checksum",
            required=None,
            type=int,
            help="Pass 1 if you want to skip checksum",
        )
        return parser

    def convert(self):
        """Move/copy the zip, extract it, and lay out images/annotations."""
        config = self.configuration.get_config()
        data_dir = config.env.data_dir

        if self.args.mmf_data_folder:
            data_dir = self.args.mmf_data_folder

        bypass_checksum = False
        if self.args.bypass_checksum:
            bypass_checksum = bool(self.args.bypass_checksum)

        print(f"Data folder is {data_dir}")
        print(f"Zip path is {self.args.zip_file}")

        base_path = os.path.join(data_dir, "datasets", "hateful_memes", "defaults")

        images_path = os.path.join(base_path, "images")
        PathManager.mkdirs(images_path)

        move_dir = False
        if self.args.move:
            move_dir = bool(self.args.move)

        if not bypass_checksum:
            self.checksum(self.args.zip_file, self.POSSIBLE_CHECKSUMS)

        src = self.args.zip_file
        dest = images_path
        if move_dir:
            print(f"Moving {src}")
            move(src, dest)
        else:
            print(f"Copying {src}")
            copy(src, dest)

        print(f"Unzipping {src}")
        self.decompress_zip(
            dest, fname=os.path.basename(src), password=self.args.password
        )

        phase_one = self.assert_files(images_path)

        annotations_path = os.path.join(base_path, "annotations")
        PathManager.mkdirs(annotations_path)
        annotations = (
            self.JSONL_PHASE_ONE_FILES
            if phase_one is True
            else self.JSONL_PHASE_TWO_FILES
        )

        for annotation in annotations:
            print(f"Moving {annotation}")
            src = os.path.join(images_path, "data", annotation)
            dest = os.path.join(annotations_path, annotation)
            move(src, dest)

        images = self.IMAGE_FILES

        for image_file in images:
            src = os.path.join(images_path, "data", image_file)
            if PathManager.exists(src):
                print(f"Moving {image_file}")
            else:
                continue
            dest = os.path.join(images_path, image_file)
            move(src, dest)
            if src.endswith(".tar.gz"):
                decompress(dest, fname=image_file, delete_original=False)

    def checksum(self, file, hashes):
        """Raise AssertionError if the SHA256 of ``file`` is not in ``hashes``."""
        sha256_hash = hashlib.sha256()
        destination = file

        with PathManager.open(destination, "rb") as f:
            print("Starting checksum for {}".format(os.path.basename(file)))
            # Hash in 64 KiB chunks to avoid loading the whole zip in memory.
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
            if sha256_hash.hexdigest() not in hashes:
                # remove_dir(download_path)
                # Fix: the original used a pointless f-string prefix plus a
                # runtime "+" concatenation; this emits the identical message.
                raise AssertionError(
                    "Checksum of downloaded file does not match the expected "
                    "checksum. Please try again."
                )
            else:
                print("Checksum successful")

    def decompress_zip(self, dest, fname, password=None):
        """Extract ``fname`` inside ``dest``, preferring the `unzip` CLI."""
        path = os.path.join(dest, fname)
        print("Extracting the zip can take time. Sit back and relax.")
        try:
            # Python's zip file module is very slow with password encrypted files
            # Try command line
            command = ["unzip", "-o", "-q", "-d", dest]
            if password:
                command += ["-P", password]
            command += [path]
            subprocess.run(command, check=True)
        except Exception:
            # Fallback: stdlib zipfile (slow for encrypted archives).
            obj = zipfile.ZipFile(path, "r")
            if password:
                obj.setpassword(password.encode("utf-8"))
            obj.extractall(path=dest)
            obj.close()
Ejemplo n.º 27
0
 def __init__(self):
     # Parse CLI arguments immediately and load the MMF configuration.
     self.parser = self.get_parser()
     self.args = self.parser.parse_args()
     self.configuration = Configuration()
Ejemplo n.º 28
0
class HMConverter:
    """Converts the Hateful Memes Challenge zip (Phase 1 layout) into the
    MMF data layout: images and annotations under the MMF data directory."""

    IMAGE_FILES = ["img.tar.gz", "img"]
    JSONL_FILES = ["train.jsonl", "dev.jsonl", "test.jsonl"]
    # Accepted SHA256 digests of the challenge zip.
    POSSIBLE_CHECKSUMS = [
        "d8f1073f5fbf1b08a541cc2325fc8645619ab8ed768091fb1317d5c3a6653a77",
        "a424c003b7d4ea3f3b089168b5f5ea73b90a3ff043df4b8ff4d7ed87c51cb572",
    ]

    def __init__(self):
        # Parse CLI arguments immediately and load the MMF configuration.
        self.parser = self.get_parser()
        self.args = self.parser.parse_args()
        self.configuration = Configuration()

    def assert_files(self, folder):
        """Assert all expected jsonl files and an image payload exist.

        Raises:
            AssertionError: when a jsonl file is missing or when neither
            img nor img.tar.gz is present.
        """
        files_needed = self.JSONL_FILES

        for file in files_needed:
            assert PathManager.exists(os.path.join(
                folder, "data", file)), f"{file} doesn't exist in {folder}"

        files_needed = self.IMAGE_FILES

        exists = False

        for file in files_needed:
            exists = exists or PathManager.exists(
                os.path.join(folder, "data", file))

        if not exists:
            raise AssertionError(
                "Neither img or img.tar.gz exists in current zip")

    def get_parser(self):
        """Build the command-line parser for the converter."""
        parser = argparse.ArgumentParser(
            formatter_class=argparse.RawTextHelpFormatter)

        parser.add_argument(
            "--zip_file",
            required=True,
            type=str,
            help="Zip file downloaded from the DrivenData",
        )

        parser.add_argument("--password",
                            required=True,
                            type=str,
                            help="Password for the zip file")
        parser.add_argument("--mmf_data_folder",
                            required=None,
                            type=str,
                            help="MMF Data folder")
        return parser

    def convert(self):
        """Move the zip, extract it, and lay out images/annotations."""
        config = self.configuration.get_config()
        data_dir = config.env.data_dir

        if self.args.mmf_data_folder:
            data_dir = self.args.mmf_data_folder

        print(f"Data folder is {data_dir}")
        print(f"Zip path is {self.args.zip_file}")

        base_path = os.path.join(data_dir, "datasets", "hateful_memes",
                                 "defaults")

        images_path = os.path.join(base_path, "images")
        PathManager.mkdirs(images_path)

        src = self.args.zip_file
        self.checksum(self.args.zip_file, self.POSSIBLE_CHECKSUMS)
        print(f"Moving {src}")
        dest = images_path
        move(src, dest)

        print(f"Unzipping {src}")
        self.decompress_zip(dest,
                            fname=os.path.basename(src),
                            password=self.args.password)

        self.assert_files(images_path)

        annotations_path = os.path.join(base_path, "annotations")
        PathManager.mkdirs(annotations_path)
        annotations = self.JSONL_FILES

        for annotation in annotations:
            print(f"Moving {annotation}")
            src = os.path.join(images_path, "data", annotation)
            dest = annotations_path
            move(src, dest)

        images = self.IMAGE_FILES

        for image_file in images:
            src = os.path.join(images_path, "data", image_file)
            if PathManager.exists(src):
                print(f"Moving {image_file}")
            else:
                continue
            dest = images_path
            move(src, dest)
            if src.endswith(".tar.gz"):
                decompress(dest, fname=image_file, delete_original=False)

    def checksum(self, file, hashes):
        """Raise AssertionError if the SHA256 of ``file`` is not in ``hashes``."""
        sha256_hash = hashlib.sha256()
        destination = file

        with PathManager.open(destination, "rb") as f:
            print("Starting checksum for {}".format(os.path.basename(file)))
            # Hash in 64 KiB chunks to avoid loading the whole zip in memory.
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
            if sha256_hash.hexdigest() not in hashes:
                # remove_dir(download_path)
                # Fix: the original used a pointless f-string prefix plus a
                # runtime "+" concatenation; this emits the identical message.
                raise AssertionError(
                    "Checksum of downloaded file does not match the expected "
                    "checksum. Please try again.")
            else:
                print("Checksum successful")

    def decompress_zip(self, dest, fname, password=None):
        """Extract ``fname`` inside ``dest`` using the stdlib zipfile module."""
        obj = zipfile.ZipFile(os.path.join(dest, fname), "r")
        if password:
            obj.setpassword(password.encode("utf-8"))
        obj.extractall(path=dest)
        obj.close()