Exemplo n.º 1
0
def save_checkpoint(checkpoint_folder, state, checkpoint_file=CHECKPOINT_FILE):
    """
    Saves a state variable to the specified checkpoint folder.

    Args:
        checkpoint_folder: directory to write into; it is created with
            ``PathManager.mkdirs`` if it does not exist yet.
        state: object serialized with ``torch.save``.
        checkpoint_file: file name of the checkpoint inside ``checkpoint_folder``.

    Returns:
        The full path of the written checkpoint if successful, and False
        otherwise (folder could not be created, or the write failed). Failures
        are logged with their traceback instead of being raised.
    """

    # make sure that we have a checkpoint folder:
    if not PathManager.isdir(checkpoint_folder):
        try:
            PathManager.mkdirs(checkpoint_folder)
        except Exception:
            # Catch Exception, not BaseException, so Ctrl-C / SystemExit still
            # propagate instead of being reported as a failed save.
            logging.warning(
                "Could not create folder %s.", checkpoint_folder, exc_info=True
            )
    # Re-check rather than trusting mkdirs: covers both a failed mkdirs and a
    # path that exists but is not a directory.
    if not PathManager.isdir(checkpoint_folder):
        return False

    # NOTE: this writes straight to the destination path; it is NOT atomic, so
    # an interrupted write can leave a truncated checkpoint file behind.
    full_filename = f"{checkpoint_folder}/{checkpoint_file}"
    try:
        with PathManager.open(full_filename, "wb") as f:
            torch.save(state, f)
    except Exception:
        logging.warning(
            "Did not write checkpoint to %s.", checkpoint_folder, exc_info=True
        )
        return False
    return full_filename
Exemplo n.º 2
0
    def test_bad_args(self) -> None:
        """Check keyword-argument validation for a remote URI.

        With strict kwargs checking enabled (the state expected on entry), an
        unknown keyword must make each PathManager call raise the error below.
        With it disabled, ``get_local_path`` and ``open`` must accept the
        unknown keyword.
        """
        with self.assertRaises(NotImplementedError):
            PathManager.copy(
                self._remote_uri,
                self._remote_uri,
                foo="foo"  # type: ignore
            )
        with self.assertRaises(NotImplementedError):
            PathManager.exists(self._remote_uri, foo="foo")  # type: ignore
        with self.assertRaises(ValueError):
            PathManager.get_local_path(
                self._remote_uri,
                foo="foo"  # type: ignore
            )
        with self.assertRaises(NotImplementedError):
            PathManager.isdir(self._remote_uri, foo="foo")  # type: ignore
        with self.assertRaises(NotImplementedError):
            PathManager.isfile(self._remote_uri, foo="foo")  # type: ignore
        with self.assertRaises(NotImplementedError):
            PathManager.ls(self._remote_uri, foo="foo")  # type: ignore
        with self.assertRaises(NotImplementedError):
            PathManager.mkdirs(self._remote_uri, foo="foo")  # type: ignore
        with self.assertRaises(ValueError):
            PathManager.open(self._remote_uri, foo="foo")  # type: ignore
        with self.assertRaises(NotImplementedError):
            PathManager.rm(self._remote_uri, foo="foo")  # type: ignore

        PathManager.set_strict_kwargs_checking(False)
        try:
            PathManager.get_local_path(self._remote_uri, foo="foo")  # type: ignore
            f = PathManager.open(self._remote_uri, foo="foo")  # type: ignore
            f.close()
        finally:
            # The flag is set on PathManager itself and outlives this test;
            # restore it even when a call above fails so later tests still run
            # with strict checking.
            PathManager.set_strict_kwargs_checking(True)
Exemplo n.º 3
0
 def test_isdir(self):
     """PathManager.isdir is true only for a path that is an existing directory."""
     missing_path = os.path.join(self._tmpdir, uuid.uuid4().hex)
     expectations = (
         (self._tmpdir, self.assertTrue),    # an existing directory
         (self._tmpfile, self.assertFalse),  # a regular file, not a directory
         (missing_path, self.assertFalse),   # a path that does not exist
     )
     for candidate, check in expectations:
         check(PathManager.isdir(candidate))
Exemplo n.º 4
0
def get_local_path(input_file, dest_dir):
    """
    If the user asked for data to be copied to a local directory, return the
    local path the data was copied to; otherwise return ``input_file``.

    - If ``input_file`` is a file: candidate is ``dest_dir/<basename>``.
    - If ``input_file`` is a directory: ``dest_dir`` is first remapped through
      ``get_slurm_dir`` when running under SLURM (``SLURM_JOBID`` is set).
      The candidate is then ``dest_dir/<directory name>``, but only if it
      contains a ``copy_complete`` marker file.
    - The candidate is returned only if it exists; in every other case
      ``input_file`` is returned unchanged.
    """
    if PathManager.isfile(input_file):
        local_path = os.path.join(dest_dir, os.path.basename(input_file))
    elif PathManager.isdir(input_file):
        data_name = input_file.strip("/").split("/")[-1]
        if "SLURM_JOBID" in os.environ:
            dest_dir = get_slurm_dir(dest_dir)
        local_dir = os.path.join(dest_dir, data_name)
        # The marker is only written after a copy finished, so its absence
        # means the local copy is missing or partial.
        if not PathManager.isfile(os.path.join(local_dir, "copy_complete")):
            return input_file
        local_path = local_dir
    else:
        # Neither a file nor a directory: nothing local to look up.
        return input_file

    return local_path if PathManager.exists(local_path) else input_file
Exemplo n.º 5
0
 def __init__(self, cfg, data_source, path, split, dataset_name):
     """Dataset of images read from disk, given either a file list or a folder.

     Args:
         cfg: config; ``cfg["DATA"][split]`` must provide
             ``BATCHSIZE_PER_REPLICA`` and ``ENABLE_QUEUE_DATASET``.
         data_source: "disk_filelist" (``path`` must be a file) or
             "disk_folder" (``path`` must be a directory).
         path: location of the file list or image folder.
         split: name of the data split used to index ``cfg["DATA"]``.
         dataset_name: name stored on the instance.
     """
     batch_size = cfg["DATA"][split]["BATCHSIZE_PER_REPLICA"]
     super(DiskImageDataset, self).__init__(queue_size=batch_size)

     supported_sources = ["disk_filelist", "disk_folder"]
     assert data_source in supported_sources, (
         "data_source must be either disk_filelist or disk_folder")
     if data_source == "disk_folder":
         assert PathManager.isdir(path), f"Directory {path} does not exist"
     else:
         assert PathManager.isfile(path), f"File {path} does not exist"

     self.cfg, self.split = cfg, split
     self.dataset_name, self.data_source = dataset_name, data_source
     self._path = path
     self.is_initialized = False
     self.image_dataset = []

     self._load_data(path)
     self._num_samples = len(self.image_dataset)
     if self.data_source == "disk_filelist":
         # Drop the loaded list (the sample count is already recorded) so that
         # workers don't have to pickle it; this saves memory for large file
         # lists, especially when memory mapping.
         self.image_dataset = []

     # Whether QueueDataset should be used to handle invalid images.
     self.enable_queue_dataset = cfg["DATA"][self.split]["ENABLE_QUEUE_DATASET"]
Exemplo n.º 6
0
Arquivo: io.py Projeto: iseessel/vissl
def copy_data(input_file, destination_dir, num_threads, tmp_destination_dir):
    """
    Copy data from one source to the other using num_threads. The data to copy
    can be a single file or a directory. We check what type of data and
    call the relevant functions.

    Args:
        input_file: path of the file or directory to copy.
        destination_dir: directory to copy into. If None or "", no directory
            is created and None is passed on to copy_file / copy_dir.
        num_threads: number of threads used when copying a directory.
        tmp_destination_dir: forwarded to copy_file when copying a single file.

    Returns:
        output_file (str): the new path of the data (could be file or dir)
        destination_dir (str): the destination dir that was actually used

    Raises:
        RuntimeError: if input_file is neither an existing file nor directory.
    """
    if destination_dir is None or destination_dir == "":
        # Normalize "" to None so the copy helpers see a single "no
        # destination" value.
        destination_dir = None
    else:
        # Only log when a directory is actually created (previously this also
        # logged "Creating directory: None").
        logging.info(f"Creating directory: {destination_dir}")
        makedir(destination_dir)

    if PathManager.isfile(input_file):
        output_file, output_dir = copy_file(input_file, destination_dir,
                                            tmp_destination_dir)
    elif PathManager.isdir(input_file):
        output_file, output_dir = copy_dir(input_file, destination_dir,
                                           num_threads)
    else:
        raise RuntimeError(
            f"The input_file is neither a file nor a directory: {input_file}")
    return output_file, output_dir
Exemplo n.º 7
0
    def test_bad_args(self) -> None:
        """Check keyword-argument validation on local temporary files.

        With strict kwargs checking enabled (the state expected on entry), an
        unknown keyword must make every PathManager API raise ValueError. With
        it disabled, the same calls must succeed.
        """
        # TODO (T58240718): Replace with dynamic checks
        with self.assertRaises(ValueError):
            PathManager.copy(
                self._tmpfile,
                self._tmpfile,
                foo="foo"  # type: ignore
            )
        with self.assertRaises(ValueError):
            PathManager.exists(self._tmpfile, foo="foo")  # type: ignore
        with self.assertRaises(ValueError):
            PathManager.get_local_path(self._tmpfile,
                                       foo="foo")  # type: ignore
        with self.assertRaises(ValueError):
            PathManager.isdir(self._tmpfile, foo="foo")  # type: ignore
        with self.assertRaises(ValueError):
            PathManager.isfile(self._tmpfile, foo="foo")  # type: ignore
        with self.assertRaises(ValueError):
            PathManager.ls(self._tmpfile, foo="foo")  # type: ignore
        with self.assertRaises(ValueError):
            PathManager.mkdirs(self._tmpfile, foo="foo")  # type: ignore
        with self.assertRaises(ValueError):
            PathManager.open(self._tmpfile, foo="foo")  # type: ignore
        with self.assertRaises(ValueError):
            PathManager.rm(self._tmpfile, foo="foo")  # type: ignore

        PathManager.set_strict_kwargs_checking(False)
        try:
            PathManager.copy(
                self._tmpfile,
                self._tmpfile,
                foo="foo"  # type: ignore
            )
            PathManager.exists(self._tmpfile, foo="foo")  # type: ignore
            PathManager.get_local_path(self._tmpfile, foo="foo")  # type: ignore
            PathManager.isdir(self._tmpfile, foo="foo")  # type: ignore
            PathManager.isfile(self._tmpfile, foo="foo")  # type: ignore
            PathManager.ls(self._tmpdir, foo="foo")  # type: ignore
            PathManager.mkdirs(self._tmpdir, foo="foo")  # type: ignore
            f = PathManager.open(self._tmpfile, foo="foo")  # type: ignore
            f.close()
            # pyre-ignore
            with open(os.path.join(self._tmpdir, "test_rm.txt"), "w") as f:
                rm_file = f.name
                f.write(self._tmpfile_contents)
                f.flush()
            PathManager.rm(rm_file, foo="foo")  # type: ignore
        finally:
            # The flag is set on PathManager itself and outlives this test.
            # Previously it was never re-enabled here, so every later test ran
            # with strict checking off.
            PathManager.set_strict_kwargs_checking(True)
Exemplo n.º 8
0
 def __init__(self, cfg: AttrDict, data_source: str, path: str, split: str,
              dataset_name: str):
     """Dataset rooted at an existing directory on disk.

     ``cfg`` and ``data_source`` are accepted but not used here (presumably
     kept for a uniform dataset constructor signature — TODO confirm).
     The split name is stored lower-cased.
     """
     super().__init__()
     assert PathManager.isdir(path), f"Directory {path} does not exist"
     normalized_split = split.lower()
     self.dataset_name, self.path, self.split = dataset_name, path, normalized_split
     self.dataset = self._load_dataset()
Exemplo n.º 9
0
 def __init__(self, cfg: AttrDict, path: str, split: str, dataset_name="fastmri_dataset", data_source="fastmri"):
     """FastMRI dataset read from a directory on disk.

     Optional config values under ``cfg["DATA"]``:
       - ``KEY`` (default "reconstruction_esc")
       - ``INDEX`` (default 12)
     The split name is stored lower-cased.
     """
     super(FastMRIDataSet, self).__init__()
     assert PathManager.isdir(path), f"Directory {path} does not exist"

     # NOTE(review): the dataset_name and data_source arguments are ignored;
     # these two attributes are fixed regardless of what the caller passes.
     self.dataset_name, self.data_source = "singlecoil", "fastmri"
     self.path = path

     data_cfg = cfg.get("DATA", AttrDict({}))
     self.key, self.index = (
         data_cfg.get("KEY", "reconstruction_esc"),
         data_cfg.get("INDEX", 12),
     )
     self.split = split.lower()
     self.dataset = self._load_data()
Exemplo n.º 10
0
    def __init__(self, cfg):
        """
        Open the demo input (a video file or a directory of frames) and load
        the boxes/labels for it.

        Args:
            cfg (CfgNode): configs. Details can be found in
                slowfast/config/defaults.py

        Raises:
            IOError: if the video (or frame pattern) cannot be opened.
        """
        self.source = PathManager.get_local_path(path=cfg.DEMO.INPUT_VIDEO)
        self.fps = None
        if PathManager.isdir(self.source):
            # A directory of frames is read through the pattern
            # "<dir>/<dir_name>_%06d.jpg"; its FPS must come from the config.
            self.fps = cfg.DEMO.FPS
            # rstrip so that "frames/" still yields the name "frames" rather
            # than an empty string (which produced "frames/_%06d.jpg").
            self.video_name = self.source.rstrip("/").split("/")[-1]
            self.source = os.path.join(self.source,
                                       "{}_%06d.jpg".format(self.video_name))
        else:
            self.video_name = self.source.split("/")[-1]
            # Everything before the first "." is used as the video name.
            self.video_name = self.video_name.split(".")[0]

        self.cfg = cfg
        self.cap = cv2.VideoCapture(self.source)
        # Check before reading any capture properties; release the capture
        # handle on failure instead of leaking it.
        if not self.cap.isOpened():
            self.cap.release()
            raise IOError("Video {} cannot be opened".format(self.source))

        if self.fps is None:
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)

        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

        self.display_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.display_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        self.output_file = None

        if cfg.DEMO.OUTPUT_FILE != "":
            self.output_file = self.get_output_file(cfg.DEMO.OUTPUT_FILE)

        self.pred_boxes, self.gt_boxes = load_boxes_labels(
            cfg,
            self.video_name,
            self.fps,
            self.display_width,
            self.display_height,
        )

        self.seq_length = cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE
        self.no_frames_repeat = cfg.DEMO.SLOWMO
Exemplo n.º 11
0
def load_checkpoint(
    checkpoint_path: str, device: torch.device = CPU_DEVICE
) -> Optional[Dict]:
    """Loads a checkpoint from the specified checkpoint path.

    Args:
        checkpoint_path: The path to load the checkpoint from. Can be a file or a
            directory. If it is a directory, the checkpoint is loaded from
            :py:data:`CHECKPOINT_FILE` inside the directory.
        device: device to load the checkpoint to. A device string such as
            "cpu" or "cuda:0" is also accepted and converted with
            ``torch.device``.

    Returns:
        The checkpoint, if it exists, or None (also None for an empty path).

    Raises:
        AssertionError: if ``device`` is None, is not a CPU/CUDA device, or is
            a CUDA device while CUDA is unavailable.
    """
    if not checkpoint_path:
        return None

    assert device is not None, "Please specify what device to load checkpoint on"
    if isinstance(device, str):
        device = torch.device(device)
    assert device.type in ["cpu", "cuda"], f"Unknown device: {device}"
    if device.type == "cuda":
        assert torch.cuda.is_available(), (
            f"Cannot load checkpoint on {device}: CUDA is not available"
        )

    if not PathManager.exists(checkpoint_path):
        logging.warning(f"Checkpoint path {checkpoint_path} not found")
        return None
    if PathManager.isdir(checkpoint_path):
        checkpoint_path = f"{checkpoint_path.rstrip('/')}/{CHECKPOINT_FILE}"

    # Checked again because the directory may not contain CHECKPOINT_FILE.
    if not PathManager.exists(checkpoint_path):
        logging.warning(f"Checkpoint file {checkpoint_path} not found.")
        return None

    logging.info(f"Attempting to load checkpoint from {checkpoint_path}")
    # map_location loads tensors onto the requested device rather than the
    # device they were saved from.
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    with PathManager.open(checkpoint_path, "rb") as f:
        checkpoint = torch.load(f, map_location=device)
    logging.info(f"Loaded checkpoint from {checkpoint_path}")
    return checkpoint