コード例 #1
0
    def load_annotation_db(self, path):
        """Load VQA-CP v2 questions and annotations into ``self.data``.

        ``path`` must hold exactly two JSON file paths, one for the questions
        and one for the annotations, in either order. The entry whose path
        string contains ``"annotations"`` is treated as the annotation file
        and the other one as the question file.

        Each resulting record is a deep copy of an annotation dict with two
        fields rewritten:
          * ``"question"``: the question text looked up by ``question_id``;
          * ``"answers"``: a plain list of the ``"answer"`` strings.

        Raises:
            AssertionError: if ``path`` does not have exactly two entries.
            KeyError: if an annotation references an unknown question_id.
        """
        assert (
            len(path) == 2
        ), "VQACPv2 requires 2 paths; one to questions and one to annotations"

        first_path, second_path = path[0], path[1]
        with PathManager.open(first_path) as handle:
            first_json = json.load(handle)
        with PathManager.open(second_path) as handle:
            second_json = json.load(handle)

        # The file whose *path* mentions "annotations" holds the annotations.
        if "annotations" in first_path:
            annotations, questions = first_json, second_json
        else:
            annotations, questions = second_json, first_json

        # Flatten into one record per annotation, carrying its question text.
        question_text_by_id = {
            q["question_id"]: q["question"] for q in questions
        }

        records = []
        for entry in annotations:
            entry["question"] = question_text_by_id[entry["question_id"]]
            entry["answers"] = [a["answer"] for a in entry["answers"]]
            records.append(copy.deepcopy(entry))

        self.data = records
コード例 #2
0
ファイル: encoders.py プロジェクト: hahaxun/mmf
    def __init__(self, config: Config, *args, **kwargs):
        """Build a linear layer initialised from pickled fc7 weights and bias.

        ``config.weights_file`` and ``config.bias_file`` may be absolute
        paths, or paths relative to ``config.model_data_dir``. If either file
        does not exist, the ``detectron.vmb_weights`` pretrained bundle is
        downloaded and its ``fc7_w.pkl`` / ``fc7_b.pkl`` are used instead.

        Args:
            config: must provide ``model_data_dir``, ``weights_file``,
                ``bias_file`` and ``in_dim``.
        """
        super().__init__()
        model_data_dir = get_absolute_path(config.model_data_dir)

        # Relative paths are resolved against model_data_dir. Absolute paths
        # are used as-is; previously they left the local variables unbound,
        # which raised UnboundLocalError below.
        weights_file = config.weights_file
        if not os.path.isabs(weights_file):
            weights_file = os.path.join(model_data_dir, weights_file)
        bias_file = config.bias_file
        if not os.path.isabs(bias_file):
            bias_file = os.path.join(model_data_dir, bias_file)

        if not PathManager.exists(bias_file) or not PathManager.exists(
                weights_file):
            download_path = download_pretrained_model("detectron.vmb_weights")
            weights_file = get_absolute_path(
                os.path.join(download_path, "fc7_w.pkl"))
            bias_file = get_absolute_path(
                os.path.join(download_path, "fc7_b.pkl"))

        # NOTE(security): pickle.load can execute arbitrary code; only load
        # weight files from trusted sources.
        with PathManager.open(weights_file, "rb") as w:
            weights = pickle.load(w)
        with PathManager.open(bias_file, "rb") as b:
            bias = pickle.load(b)
        # bias must be a numpy array (it goes through torch.from_numpy); its
        # first dimension sets the layer's output size.
        out_dim = bias.shape[0]

        self.lc = nn.Linear(config.in_dim, out_dim)
        self.lc.weight.data.copy_(torch.from_numpy(weights))
        self.lc.bias.data.copy_(torch.from_numpy(bias))
        self.out_dim = out_dim
コード例 #3
0
ファイル: test_logistics.py プロジェクト: hahaxun/mmf
 def test_on_update_end(self):
     """train.log gets a time_since_start line only when should_log=True."""
     log_path = os.path.join(self.tmpdir, "train.log")
     self.cb.on_train_start()
     self.cb.on_update_end(meter=self.trainer.meter, should_log=False)
     # Context managers ensure each log file handle is closed (they leaked).
     with PathManager.open(log_path) as f:
         self.assertFalse(any("time_since_start" in line for line in f))
     self.cb.on_update_end(meter=self.trainer.meter, should_log=True)
     with PathManager.open(log_path) as f:
         self.assertTrue(any("time_since_start" in line for line in f))
コード例 #4
0
ファイル: test_logger.py プロジェクト: hahaxun/mmf
 def test_log_writer(self) -> None:
     """A message logged via the writer appears in logs/train* and train.log."""
     self.writer.info(self._tmpfile_write_contents)
     log_paths = (
         glob.glob(os.path.join(self._tmpdir, "logs", "train*"))[0],
         os.path.join(self._tmpdir, "train.log"),
     )
     for log_path in log_paths:
         # Close every handle we open (the original leaked both files).
         with PathManager.open(log_path) as f:
             self.assertTrue(
                 any(self._tmpfile_write_contents in line for line in f))
コード例 #5
0
ファイル: embeddings.py プロジェクト: hahaxun/mmf
    def __init__(self, in_dim, weights_file, bias_file):
        """Create a linear projection whose parameters come from pickle files.

        Args:
            in_dim: input feature size of the linear layer.
            weights_file: path to a pickled weight array, copied into
                ``self.lc.weight``.
            bias_file: path to a pickled bias array. Its first dimension
                determines the output size ``self.out_dim``.
        """
        super().__init__()

        def _unpickle(file_path):
            # NOTE(security): pickle can execute arbitrary code; trusted files only.
            with PathManager.open(file_path, "rb") as handle:
                return pickle.load(handle)

        weight_array = _unpickle(weights_file)
        bias_array = _unpickle(bias_file)
        out_features = bias_array.shape[0]

        layer = nn.Linear(in_dim, out_features)
        self.lc = layer
        self.lc.weight.data.copy_(torch.from_numpy(weight_array))
        self.lc.bias.data.copy_(torch.from_numpy(bias_array))
        self.out_dim = out_features
コード例 #6
0
ファイル: download.py プロジェクト: hahaxun/mmf
def download_from_google_drive(gd_id, destination, redownload=True):
    """
    Use the requests package to download a file from Google Drive.

    Args:
        gd_id (str): Google Drive file id.
        destination (str): local path the file is written to.
        redownload (bool): when False and ``destination`` already is a file,
            nothing is downloaded.

    Returns:
        bool: whether a download was performed.
    """
    download = not PathManager.isfile(destination) or redownload
    if not download:
        return download

    URL = "https://docs.google.com/uc?export=download"

    # Check first if link is live
    check_header(gd_id, from_google=True)

    with requests.Session() as session:
        response = session.get(URL, params={"id": gd_id}, stream=True)
        try:
            token = _get_confirm_token(response)
            if token:
                # Drive answered with a confirmation token (presumably its
                # large-file warning page); repeat the request with it.
                response.close()
                params = {"id": gd_id, "confirm": token}
                response = session.get(URL, params=params, stream=True)

            CHUNK_SIZE = 32768
            with PathManager.open(destination, "wb") as f:
                for chunk in response.iter_content(CHUNK_SIZE):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        finally:
            # Release the connection even if streaming or writing fails.
            response.close()

    return download
コード例 #7
0
ファイル: download.py プロジェクト: hahaxun/mmf
    def checksum(self, download_path):
        """
        Verify the SHA-256 of ``<download_path>/<self._file_name>``.

        The check is skipped (with a message) when ``self._hashcode`` is None.
        It is skipped silently when the file does not exist.

        Args:
            download_path (string): directory containing the downloaded file.

        Raises:
            AssertionError: if the digest differs from ``self._hashcode``.
        """
        if self._hashcode is None:
            print(f"[ Checksum not provided, skipping for {self._file_name}]")
            return

        target = os.path.join(download_path, self._file_name)
        if not PathManager.isfile(target):
            return  # nothing on disk to verify

        digest = hashlib.sha256()
        with PathManager.open(target, "rb") as stream:
            print(f"[ Starting checksum for {self._file_name}]")
            block = stream.read(65536)
            while block:
                digest.update(block)
                block = stream.read(65536)

        if digest.hexdigest() == self._hashcode:
            print(f"[ Checksum successful for {self._file_name}]")
            return
        raise AssertionError(
            f"[ Checksum for {self._file_name} from \n{self._url}\n"
            "does not match the expected checksum. Please try again. ]"
        )
コード例 #8
0
ファイル: checkpoint.py プロジェクト: hahaxun/mmf
def _load_pretrained_checkpoint(checkpoint_path, *args, **kwargs):
    """Load a single checkpoint file that embeds its own config.

    Args:
        checkpoint_path (str): checkpoint whose extension is in
            ``ALLOWED_CHECKPOINT_EXTS``.
        **kwargs: optional ``model_name`` selecting the entry of the model
            config to return. It is required when the config lists more than
            one model.

    Returns:
        dict: ``{"config": <model config>, "checkpoint": <weights>,
        "full_config": <full config>}``.

    Raises:
        AssertionError: on a disallowed extension, a checkpoint without a
            ``"config"`` entry, or an ambiguous model config.
    """
    # Explicit raises instead of ``assert`` so validation survives
    # ``python -O``. The exception type stays AssertionError for callers.
    if os.path.splitext(checkpoint_path)[1] not in ALLOWED_CHECKPOINT_EXTS:
        raise AssertionError(
            f"Checkpoint must have extensions: {ALLOWED_CHECKPOINT_EXTS}")

    _hack_imports()

    # NOTE(security): torch.load unpickles; only load trusted checkpoints.
    with PathManager.open(checkpoint_path, "rb") as f:
        ckpt = torch.load(f, map_location=lambda storage, loc: storage)
    if "config" not in ckpt:
        raise AssertionError(
            "No configs provided with pretrained model"
            " while checkpoint also doesn't have configuration.")
    config = ckpt.pop("config", None)
    model_config = config.get("model_config", config)

    # Weights may be nested under "model" or be the dict itself.
    ckpt = ckpt.get("model", ckpt)

    if "model_name" in kwargs:
        model_name = kwargs["model_name"]
    else:
        if len(model_config.keys()) != 1:
            raise AssertionError("Only one model type should be specified.")
        model_name = list(model_config.keys())[0]

    model_config = model_config.get(model_name)
    return {"config": model_config, "checkpoint": ckpt, "full_config": config}
コード例 #9
0
 def _load_jsonl(self, path):
     """Parse a JSON-Lines file (one JSON value per line) into ``self.data``.

     Also resets ``self.start_idx`` to 0.
     """
     with PathManager.open(path, "r") as f:
         records = [json.loads(raw.strip("\n")) for raw in f.readlines()]
     self.data = records
     self.start_idx = 0
コード例 #10
0
    def _load_json(self, path):
        """Load a JSON database with top-level ``metadata`` and ``data`` keys.

        Missing keys default to ``{}`` and ``[]`` respectively.

        Raises:
            RuntimeError: if the resulting ``data`` has length 0.
        """
        with PathManager.open(path, "r") as f:
            payload = json.load(f)

        self.metadata, self.data = (
            payload.get("metadata", {}),
            payload.get("data", []),
        )
        if len(self.data) == 0:
            raise RuntimeError("Dataset is empty")
コード例 #11
0
ファイル: checkpoint.py プロジェクト: hahaxun/mmf
    def _torch_load(self, file):
        """Deserialize a torch checkpoint from ``file``.

        If ``str(self.device)`` contains "cuda", tensors are mapped onto
        ``self.device``. Otherwise the identity ``map_location`` is used,
        which loads them onto CPU.
        """
        # Backwards compatibility to Pythia
        _hack_imports()

        on_cuda = "cuda" in str(self.device)
        location = self.device if on_cuda else (lambda storage, loc: storage)
        with PathManager.open(file, "rb") as f:
            return torch.load(f, map_location=location)
コード例 #12
0
 def _create_checkpoint_file(self, path):
     """Copy the cached mmbt.hateful_memes.images checkpoint to ``path``.

     The config.yaml next to ``model.pth`` is embedded under the ``"config"``
     key before saving. Note: if ``data_dir`` is absolute, ``os.path.join``
     drops the home directory prefix.
     """
     model_folder = os.path.join(
         str(Path.home()),
         get_multimodelity_env(key="data_dir"),
         "models",
         "mmbt.hateful_memes.images",
     )
     config = load_yaml(os.path.join(model_folder, "config.yaml"))
     with PathManager.open(os.path.join(model_folder, "model.pth"), "rb") as f:
         ckpt = torch.load(f)
     ckpt["config"] = config
     torch.save(ckpt, path)
コード例 #13
0
    def checksum(self, file, hashes):
        """Verify that the SHA-256 hex digest of ``file`` is one of ``hashes``.

        Args:
            file: path of the file to hash.
            hashes: container of accepted hex digests.

        Raises:
            AssertionError: if the digest is not in ``hashes``.
        """
        hasher = hashlib.sha256()
        with PathManager.open(file, "rb") as f:
            print(f"Starting checksum for {os.path.basename(file)}")
            while True:
                piece = f.read(65536)
                if piece == b"":
                    break
                hasher.update(piece)

        if hasher.hexdigest() in hashes:
            print("Checksum successful")
        else:
            raise AssertionError(
                "Checksum of downloaded file does not match the expected "
                "checksum. Please try again."
            )
コード例 #14
0
ファイル: download.py プロジェクト: hahaxun/mmf
def mark_done(path, version_string=None):
    """
    Mark this path as prebuilt.

    Writes ``<path>/.built.json`` containing ``created_at`` (the current
    local timestamp as a string) and ``version`` (``version_string``, which
    is ``null`` when not given).

    Args:
        path (str): The directory to mark as built
        version_string (str): The version of this dataset
    """
    marker = os.path.join(path, ".built.json")
    payload = {
        "created_at": str(datetime.datetime.today()),
        "version": version_string,
    }
    with PathManager.open(marker, "w") as f:
        json.dump(payload, f)
コード例 #15
0
    def _load_npy(self, path):
        """Load a ``.npy`` annotation database into ``metadata``/``data``.

        Two layouts are handled:
          * a dict with ``"metadata"`` and ``"data"`` keys;
          * a legacy sequence of entries. ``metadata`` becomes
            ``{"version": 1}``. If its first entry has no ``"image_id"`` key,
            that entry is treated as a header and ``self.start_idx`` is 1.

        If the extracted ``data`` is empty, ``self.data`` falls back to the
        whole loaded object.
        """
        # NOTE(security): allow_pickle=True can run arbitrary code embedded in
        # the file; only load trusted databases.
        with PathManager.open(path, "rb") as f:
            self.db = np.load(f, allow_pickle=True)

        self.start_idx = 0

        # NOTE(review): np.load of a pickled dict usually yields a 0-d object
        # ndarray rather than a dict; confirm this branch is reachable.
        if isinstance(self.db, dict):
            self.metadata = self.db.get("metadata", {})
            self.data = self.db.get("data", [])
        else:
            # TODO: Deprecate support for this
            self.metadata = {"version": 1}
            self.data = self.db
            # Handle old imdb support. Guard the empty case, which used to
            # raise IndexError on self.data[0].
            if len(self.data) > 0 and "image_id" not in self.data[0]:
                self.start_idx = 1

        if len(self.data) == 0:
            self.data = self.db
コード例 #16
0
ファイル: checkpoint.py プロジェクト: hahaxun/mmf
def _load_pretrained_model(model_name_or_path, *args, **kwargs):
    """Load a pretrained model from a local directory or the model zoo.

    If ``model_name_or_path`` exists on disk, it is used as the directory.
    Otherwise it is fetched with ``download_pretrained_model``.

    The directory must contain exactly one checkpoint (extension in
    ``ALLOWED_CHECKPOINT_EXTS``) and at most one ``*.yaml`` config. When no
    yaml is present, the checkpoint itself must carry a ``"config"`` entry.

    Returns:
        dict: ``{"config": ..., "checkpoint": ..., "full_config": ...}``.
        ``config`` is the model-config entry keyed by the base name of
        ``model_name_or_path`` up to its first "." (``None`` if absent).

    Raises:
        AssertionError: on ambiguous or missing files, or a missing config.
    """
    model_name = model_name_or_path
    if PathManager.exists(model_name_or_path):
        download_path = model_name_or_path
    else:
        download_path = download_pretrained_model(model_name_or_path, *args,
                                                  **kwargs)

    # Explicit raises (not assert) so the checks survive ``python -O``.
    configs = glob.glob(os.path.join(download_path, "*.yaml"))
    if len(configs) > 1:
        raise AssertionError(
            "Multiple yaml files with the pretrained model. "
            "multimodelity doesn't know what to do.")

    ckpts = []
    for ext in ALLOWED_CHECKPOINT_EXTS:
        ckpts.extend(glob.glob(os.path.join(download_path, f"*{ext}")))

    if len(ckpts) != 1:
        raise AssertionError(
            "None or multiple checkpoints files. "
            "multimodelity doesn't know what to do.")

    _hack_imports()

    # NOTE(security): torch.load unpickles; only load trusted checkpoints.
    with PathManager.open(ckpts[0], "rb") as f:
        ckpt = torch.load(f, map_location=lambda storage, loc: storage)
    # Without a yaml next to the checkpoint, the checkpoint must provide it.
    if len(configs) == 0:
        if "config" not in ckpt:
            raise AssertionError(
                "No configs provided with pretrained model"
                " while checkpoint also doesn't have configuration.")
        config = ckpt["config"]
    else:
        config = load_yaml(configs[0])

    model_config = config.get("model_config", config)
    ckpt = ckpt.get("model", ckpt)
    # model_name may be a path. Use its last component; normpath drops a
    # trailing separator, which previously produced the empty key "".
    base_name = os.path.basename(os.path.normpath(model_name))
    model_config = model_config.get(base_name.split(".")[0])
    return {"config": model_config, "checkpoint": ckpt, "full_config": config}
コード例 #17
0
    def _download_model(self):
        """Download the fastText ``wiki.en.bin`` vectors into the cache dir.

        Only the master process downloads; other processes just get the
        expected path back. Nothing is fetched if the file already exists.

        Returns:
            str: path of ``wiki.en.bin`` inside the multimodelity cache dir.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        _is_master = is_master()

        model_file_path = os.path.join(get_multimodelity_cache_dir(),
                                       "wiki.en.bin")

        if not _is_master:
            return model_file_path

        # NOTE(review): a partially written file from an interrupted download
        # also passes this check; consider downloading to a temp name first.
        if PathManager.exists(model_file_path):
            logger.info(f"Vectors already present at {model_file_path}.")
            return model_file_path

        import requests
        from tqdm import tqdm

        from multimodelity.common.constants import FASTTEXT_WIKI_URL

        PathManager.mkdirs(os.path.dirname(model_file_path))

        with requests.get(FASTTEXT_WIKI_URL, stream=True) as response:
            # Fail loudly instead of saving an HTML error page as the model.
            response.raise_for_status()
            content_length = response.headers.get("Content-Length")
            total = int(content_length) if content_length is not None else None

            # The bar counts bytes. Previously its total was in 4 KiB chunks
            # while updates were in bytes (and only every 50th chunk), so the
            # reported progress was wrong.
            with tqdm(
                total=total,
                unit="B",
                unit_scale=True,
                miniters=50,
                disable=not _is_master,
            ) as pbar, PathManager.open(model_file_path, "wb") as f:
                for data in response.iter_content(chunk_size=4096):
                    if data:
                        f.write(data)
                        pbar.update(len(data))

        logger.info(f"fastText bin downloaded at {model_file_path}.")

        return model_file_path
コード例 #18
0
ファイル: download.py プロジェクト: hahaxun/mmf
def built(path, version_string=None):
    """
    Check whether ``<path>/.built.json`` (written by ``mark_done``) exists.

    If a version_string is provided, the ``"version"`` recorded in that file
    must also equal it; otherwise the path is regarded as not built.

    Version strings are generally the dataset version plus the date the
    files were last updated. A mismatch marks the dataset as not built, so
    updated features are re-fetched by the end user.
    """
    marker = os.path.join(path, ".built.json")
    if not version_string:
        return PathManager.isfile(marker)
    if not PathManager.isfile(marker):
        return False
    with PathManager.open(marker, "r") as read:
        recorded = json.load(read)
    return recorded.get("version", None) == version_string
コード例 #19
0
ファイル: checkpoint.py プロジェクト: hahaxun/mmf
 def finalize(self):
     """On the master process, save the model state_dict to pth_filepath."""
     if not is_master():
         return
     with PathManager.open(self.pth_filepath, "wb") as f:
         torch.save(self.trainer.model.state_dict(), f)
コード例 #20
0
ファイル: checkpoint.py プロジェクト: hahaxun/mmf
    def save(self, update, iteration=None, update_best=False):
        """Serialize the training state to checkpoint files (master only).

        The same checkpoint dict is written to:
          * ``model_<update>.ckpt`` under ``self.models_foldername``;
          * ``<ckpt_prefix>current.ckpt`` under ``self.ckpt_foldername``;
          * ``<ckpt_prefix>best.ckpt`` there too, when ``update_best`` is True.

        Args:
            update (int): number of updates so far; names the per-update file.
            iteration (int, optional): current iteration. Falls back to
                ``update`` when falsy (note: 0 is also treated as missing).
            update_best (bool): also overwrite the "best" checkpoint.
        """
        # Only save in main process
        if not is_master():
            return

        if not iteration:
            iteration = update

        ckpt_filepath = os.path.join(self.models_foldername,
                                     "model_%d.ckpt" % update)
        best_ckpt_filepath = os.path.join(self.ckpt_foldername,
                                          self.ckpt_prefix + "best.ckpt")
        current_ckpt_filepath = os.path.join(self.ckpt_foldername,
                                             self.ckpt_prefix + "current.ckpt")

        # Best-so-far bookkeeping is read from the early-stopping tracker.
        best_iteration = (self.trainer.early_stop_callback.early_stopping.
                          best_monitored_iteration)
        best_update = (self.trainer.early_stop_callback.early_stopping.
                       best_monitored_update)
        best_metric = (self.trainer.early_stop_callback.early_stopping.
                       best_monitored_value)
        model = self.trainer.model
        data_parallel = registry.get("data_parallel") or registry.get(
            "distributed")
        # Mixed-precision scaler state, if the trainer has one.
        fp16_scaler = getattr(self.trainer, "scaler", None)
        fp16_scaler_dict = None

        if fp16_scaler is not None:
            fp16_scaler_dict = fp16_scaler.state_dict()

        # Unwrap the parallel wrapper so the saved state_dict is that of the
        # inner model (presumably to avoid "module."-prefixed keys).
        if data_parallel is True:
            model = model.module

        ckpt = {
            "model": model.state_dict(),
            "optimizer": self.trainer.optimizer.state_dict(),
            "best_iteration": best_iteration,
            "current_iteration": iteration,
            "current_epoch": self.trainer.current_epoch,
            "num_updates": update,
            "best_update": best_update,
            "best_metric_value": best_metric,
            "fp16_scaler": fp16_scaler_dict,
            # Convert to container to avoid any dependencies
            "config": OmegaConf.to_container(self.config, resolve=True),
        }

        lr_scheduler = self.trainer.lr_scheduler_callback._scheduler
        if lr_scheduler is not None:
            ckpt["lr_scheduler"] = lr_scheduler.state_dict()

        # Optionally record version-control metadata alongside the weights.
        if self.git_repo:
            git_metadata_dict = self._get_vcs_fields()
            ckpt.update(git_metadata_dict)

        with PathManager.open(ckpt_filepath, "wb") as f:
            torch.save(ckpt, f)

        if update_best:
            with PathManager.open(best_ckpt_filepath, "wb") as f:
                torch.save(ckpt, f)

        # Save current always
        with PathManager.open(current_ckpt_filepath, "wb") as f:
            torch.save(ckpt, f)

        # Remove old checkpoints if max_to_keep is set: once the list holds
        # max_to_keep updates, the oldest is popped and passed to
        # self.remove (presumably deletes model_<update>.ckpt; confirm).
        if self.max_to_keep > 0:
            if len(self.saved_iterations) == self.max_to_keep:
                self.remove(self.saved_iterations.pop(0))
            self.saved_iterations.append(update)
コード例 #21
0
ファイル: test_logistics.py プロジェクト: hahaxun/mmf
 def test_on_test_end(self):
     """on_test_end writes a "Finished run in" line to train.log."""
     self.cb.on_test_end(report=self.report, meter=self.trainer.meter)
     # Context manager closes the log handle (it previously leaked).
     with PathManager.open(os.path.join(self.tmpdir, "train.log")) as f:
         self.assertTrue(any("Finished run in" in line for line in f))
コード例 #22
0
def default_loader(path):
    """Open the image at ``path`` and return an RGB copy of it.

    ``convert`` runs while the file is still open, because ``Image.open``
    reads pixel data lazily.
    """
    with PathManager.open(path, "rb") as f:
        rgb_image = Image.open(f).convert("RGB")
    return rgb_image
コード例 #23
0
 def csv_dump(self, filepath):
     """Write ``self.report`` (a sequence of row dicts) to ``filepath`` as CSV.

     Column names are the keys of the first row, so an empty report raises
     IndexError.
     """
     # The csv module requires files opened with newline="" so it controls
     # line endings itself (otherwise blank rows appear on Windows).
     with PathManager.open(filepath, "w", newline="") as f:
         title = self.report[0].keys()
         cw = csv.DictWriter(f, title, delimiter=",", quoting=csv.QUOTE_MINIMAL)
         cw.writeheader()
         cw.writerows(self.report)
コード例 #24
0
    def __init__(self,
                 vocab_file=None,
                 embedding_dim=300,
                 data_dir=None,
                 *args,
                 **kwargs):
        """Vocabulary for training word embeddings from scratch.

        Starts from the predefined tokens PAD - <pad>, SOS - <s>,
        EOS - </s> and UNK - <unk> at their class-defined indices. Words from
        ``vocab_file`` (one per line) are then appended after them. Lookups
        through ``self.stoi`` fall back to the UNK index for unknown words.

        Parameters
        ----------
        vocab_file : str
            Path of the vocabulary file containing one word per line
        embedding_dim : int
            Size of the embedding
        data_dir : str, optional
            Directory that a relative ``vocab_file`` is resolved against.

        Raises
        ------
        RuntimeError
            If ``vocab_file`` is given but does not exist.
        """
        self.type = "base"

        # Word -> index for the special tokens, in the order they are pinned.
        special_tokens = (
            (self.SOS_TOKEN, self.SOS_INDEX),
            (self.EOS_TOKEN, self.EOS_INDEX),
            (self.PAD_TOKEN, self.PAD_INDEX),
            (self.UNK_TOKEN, self.UNK_INDEX),
        )
        self.word_dict = dict(special_tokens)
        self.itos = {
            self.PAD_INDEX: self.PAD_TOKEN,
            self.SOS_INDEX: self.SOS_TOKEN,
            self.EOS_INDEX: self.EOS_TOKEN,
            self.UNK_INDEX: self.UNK_TOKEN,
        }

        self.total_predefined = len(self.itos)
        next_index = self.total_predefined

        if vocab_file is not None:
            if not os.path.isabs(vocab_file) and data_dir is not None:
                vocab_file = get_absolute_path(
                    os.path.join(data_dir, vocab_file))

            if not PathManager.exists(vocab_file):
                raise RuntimeError("Vocab not found at " + vocab_file)

            with PathManager.open(vocab_file, "r") as f:
                for raw_line in f:
                    word = raw_line.strip()
                    self.itos[next_index] = word
                    self.word_dict[word] = next_index
                    next_index += 1

        # Re-pin the special tokens in case the vocab file also lists them.
        self.word_dict.update(special_tokens)
        # Return unk index by default
        self.stoi = defaultdict(self.get_unk_index)
        self.stoi.update(self.word_dict)

        # NOTE(review): torch.FloatTensor(rows, cols) allocates uninitialized
        # memory, not random values; confirm the vectors are filled elsewhere
        # before use.
        self.vectors = torch.FloatTensor(self.get_size(), embedding_dim)
コード例 #25
0
ファイル: test_file_io.py プロジェクト: hahaxun/mmf
 def test_file_io_open(self):
     """PathManager.open in text mode returns the file's full contents."""
     expected = self._tmpfile_contents
     with PathManager.open(self._tmpfile, mode="r") as handle:
         actual = handle.read()
     self.assertEqual(actual, expected)
コード例 #26
0
 def json_dump(self, filepath):
     """Serialize ``self.report`` to ``filepath`` as JSON."""
     report = self.report
     with PathManager.open(filepath, "w") as out:
         json.dump(report, out)
コード例 #27
0
def _cached_log_stream(filename):
    """Open ``filename`` in append mode and return the stream.

    NOTE(review): the name suggests results are cached (e.g. by an
    ``lru_cache`` decorator) so repeated calls share one handle. No such
    decorator is visible here; confirm, otherwise each call opens a new
    handle that is never closed.
    """
    stream = PathManager.open(filename, "a")
    return stream
コード例 #28
0
ファイル: checkpoint.py プロジェクト: hahaxun/mmf
 def save_config(self):
     """Write the resolved config as YAML to config.yaml in ckpt_foldername."""
     # DictConfig.pretty() is deprecated since OmegaConf 2.0 and removed in
     # 2.1; OmegaConf.to_yaml is the supported replacement.
     from omegaconf import OmegaConf

     cfg_file = os.path.join(self.ckpt_foldername, "config.yaml")
     with PathManager.open(cfg_file, "w") as f:
         f.write(OmegaConf.to_yaml(self.config, resolve=True))
コード例 #29
0
ファイル: download.py プロジェクト: hahaxun/mmf
def download(url, path, fname, redownload=True, disable_tqdm=False):
    """
    Download file using `requests`.

    If ``redownload`` is set to false, then will not download tar file again if it is
    present (default ``True``).

    Data is streamed into ``<path>/<fname>.part`` and moved to
    ``<path>/<fname>`` once complete. An existing ``.part`` file is resumed
    with an HTTP Range request when the server advertises Accept-Ranges.
    Connection errors and read timeouts are retried up to 5 times, sleeping
    1, 2, 4, 8, then 16 seconds between attempts.

    Returns whether download actually happened or not

    Raises:
        RuntimeWarning: if retries are exhausted, or fewer bytes arrived than
            the Content-Length header announced.
    """
    outfile = os.path.join(path, fname)
    download = not PathManager.isfile(outfile) or redownload
    retry = 5
    # Indexed by the remaining retry count after a failure: 1, 2, 4, 8, 16 s.
    exp_backoff = [2**r for r in reversed(range(retry))]

    pbar = None
    if download:
        # First test if the link is actually downloadable
        check_header(url)
        if not disable_tqdm:
            print("[ Downloading: " + url + " to " + outfile + " ]")
        pbar = tqdm.tqdm(unit="B",
                         unit_scale=True,
                         desc=f"Downloading {fname}",
                         disable=disable_tqdm)

    while download and retry >= 0:
        resume_file = outfile + ".part"
        resume = PathManager.isfile(resume_file)
        if resume:
            resume_pos = os.path.getsize(resume_file)
            mode = "ab"
        else:
            resume_pos = 0
            mode = "wb"
        response = None

        with requests.Session() as session:
            try:
                header = ({
                    "Range": "bytes=%d-" % resume_pos,
                    "Accept-Encoding": "identity"
                } if resume else {})
                response = session.get(url,
                                       stream=True,
                                       timeout=5,
                                       headers=header)

                # negative reply could be 'none' or just missing
                if resume and response.headers.get("Accept-Ranges",
                                                   "none") == "none":
                    # Server ignores Range requests: start over from scratch.
                    resume_pos = 0
                    mode = "wb"

                CHUNK_SIZE = 32768
                total_size = int(response.headers.get("Content-Length", -1))
                # server returns remaining size if resuming, so adjust total
                total_size += resume_pos
                pbar.total = total_size
                done = resume_pos

                with PathManager.open(resume_file, mode) as f:
                    for chunk in response.iter_content(CHUNK_SIZE):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                        if total_size > 0:
                            done += len(chunk)
                            if total_size < done:
                                # don't freak out if content-length was too small
                                total_size = done
                                pbar.total = total_size
                            pbar.update(len(chunk))
                    break
            except (
                    requests.exceptions.ConnectionError,
                    requests.exceptions.ReadTimeout,
            ):
                retry -= 1
                pbar.clear()
                if retry >= 0:
                    print("Connection error, retrying. (%d retries left)" %
                          retry)
                    time.sleep(exp_backoff[retry])
                else:
                    print("Retried too many times, stopped retrying.")
            finally:
                # ``if response:`` would test Response.ok, so error responses
                # were never closed; compare against None instead.
                if response is not None:
                    response.close()
    if retry < 0:
        raise RuntimeWarning(
            "Connection broken too many times. Stopped retrying.")

    # retry == 0 means the last allowed attempt succeeded, and it must still
    # be finalised. The old ``retry > 0`` check left the data in ``.part``.
    if download and retry >= 0:
        pbar.update(done - pbar.n)
        if done < total_size:
            raise RuntimeWarning("Received less data than specified in " +
                                 "Content-Length header for " + url +
                                 ". There may be a download problem.")
        move(resume_file, outfile)

    if pbar:
        pbar.close()

    return download
コード例 #30
0
def load_str_list(fname):
    """Return every line of ``fname`` with surrounding whitespace stripped."""
    with PathManager.open(fname) as f:
        return [line.strip() for line in f]