コード例 #1
0
 def random_sample(img_scales):
     """Draw a random (long_edge, short_edge) scale between two bounding scales.

     Args:
         img_scales (list[tuple]): Exactly two scale tuples. Within each tuple
             the larger value is treated as its long edge and the smaller one
             as its short edge.

     Returns:
         tuple: ``((long_edge, short_edge), None)``. Each edge is an integer
             drawn uniformly from the inclusive range spanned by the two
             tuples' corresponding edges. The second element is always None.
     """
     assert torchie.is_list_of(img_scales, tuple) and len(img_scales) == 2
     longs = tuple(max(scale) for scale in img_scales)
     shorts = tuple(min(scale) for scale in img_scales)
     # np.random.randint excludes its upper bound, so add 1 to make it
     # inclusive. The long edge is drawn before the short edge.
     sampled_long = np.random.randint(min(longs), max(longs) + 1)
     sampled_short = np.random.randint(min(shorts), max(shorts) + 1)
     return (sampled_long, sampled_short), None
コード例 #2
0
    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running the workflow until ``max_epochs`` is reached.

        Project-specific variant of the epoch loop:

        * Right after a non-val phase epoch, if ``self.epoch`` is in [55, 59],
          an extra validation pass runs on ``data_loaders[1]``.
        * When a train phase finds ``self.epoch == max_epochs``, one last
          validation pass runs on ``data_loaders[1]`` and the loop stops.

        Whenever the loop stops, the ``after_run`` hook is always called.

        Args:
            data_loaders (list[:obj:`DataLoader`]): One loader per workflow
                entry; ``data_loaders[i]`` is used for ``workflow[i]``.
                NOTE(review): ``data_loaders[1]`` is assumed to be the
                validation loader for the extra passes. With fewer than two
                loaders those passes raise IndexError.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. ``phase`` is either the name of a
                method on this runner or a callable.
            max_epochs (int): Upper bound on ``self.epoch``.

        Raises:
            ValueError: If a string phase names a method that does not exist.
            TypeError: If a phase is neither a string nor a callable.
        """
        assert isinstance(data_loaders, list)
        assert torchie.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else "NONE"
        self.logger.info("Start running, host: %s, work_dir: %s",
                         get_host_info(), work_dir)
        self.logger.info("workflow: %s, max: %d epochs", workflow, max_epochs)
        self.call_hook("before_run")  # for summarywriter

        # Used instead of an early ``return`` so that the ``after_run`` hook
        # at the end of this method always runs.
        finished = False
        while not finished and self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):
                    if not hasattr(self, mode):
                        raise ValueError(
                            "Trainer has no method named '{}' to run an epoch".
                            format(mode))
                    epoch_runner = getattr(self, mode)
                elif callable(mode):
                    epoch_runner = mode
                else:
                    raise TypeError(
                        "mode in workflow must be a str or callable function not '{}'"
                        .format(type(mode)))

                for _ in range(epochs):
                    if mode == "train" and self.epoch > max_epochs:
                        finished = True
                        break
                    if mode == "train" and self.epoch == max_epochs:
                        # Training is complete: evaluate once more, then stop.
                        self.val(data_loaders[1], **kwargs)
                        finished = True
                        break
                    if mode == "val":
                        epoch_runner(data_loaders[i], **kwargs)
                    else:
                        # Use the runner resolved from ``mode`` so that callable
                        # or custom-named phases are not replaced by self.train.
                        epoch_runner(data_loaders[i], self.epoch, **kwargs)
                        # TODO: hard-coded extra evaluation window; consider
                        # making it configurable.
                        if 55 <= self.epoch <= 59:
                            self.val(data_loaders[1], **kwargs)
                if finished:
                    break
        self.call_hook("after_run")
コード例 #3
0
    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Repeats the workflow until ``self.epoch`` reaches ``max_epochs``. When
        the loop stops, the ``after_run`` hook is always called. This includes
        the case where a train phase hits ``max_epochs`` partway through its
        epoch count.

        Args:
            data_loaders (list[:obj:`DataLoader`]): One loader per workflow
                entry; ``data_loaders[i]`` is used for ``workflow[i]``.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. ``phase`` is either the name of a
                method on this runner or a callable.
            max_epochs (int): Stop training once ``self.epoch`` reaches this.

        Raises:
            ValueError: If a string phase names a method that does not exist.
            TypeError: If a phase is neither a string nor a callable.
        """
        assert isinstance(data_loaders, list)
        assert torchie.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else "NONE"
        self.logger.info("Start running, host: %s, work_dir: %s",
                         get_host_info(), work_dir)
        self.logger.info("workflow: %s, max: %d epochs", workflow, max_epochs)
        self.call_hook("before_run")

        # Used instead of an early ``return`` so that the ``after_run`` hook
        # at the end of this method always runs.
        finished = False
        while not finished and self.epoch < max_epochs:
            for i, (mode, epochs) in enumerate(workflow):
                if isinstance(mode, str):
                    if not hasattr(self, mode):
                        raise ValueError(
                            "Trainer has no method named '{}' to run an epoch".
                            format(mode))
                    epoch_runner = getattr(self, mode)
                elif callable(mode):
                    epoch_runner = mode
                else:
                    raise TypeError("mode in workflow must be a str or "
                                    "callable function not '{}'".format(
                                        type(mode)))

                for _ in range(epochs):
                    if mode == "train" and self.epoch >= max_epochs:
                        finished = True
                        break
                    if mode == "val":
                        # Validation runners are called without an epoch index.
                        epoch_runner(data_loaders[i], **kwargs)
                    else:
                        epoch_runner(data_loaders[i], self.epoch, **kwargs)
                if finished:
                    break

        self.call_hook("after_run")
コード例 #4
0
    def __init__(
        self, img_scale=None, multiscale_mode="range", ratio_range=None, keep_ratio=True
    ):
        """Store image-resize settings.

        Args:
            img_scale (tuple | list[tuple] | None): Target scale(s). A single
                tuple is wrapped into a one-element list. None is stored as is.
            multiscale_mode (str): "range" or "value". It is validated only
                when ``ratio_range`` is None.
            ratio_range (tuple | None): If given (mode 1), exactly one
                ``img_scale`` must also be given.
            keep_ratio (bool): Stored unchanged as ``self.keep_ratio``.

        Raises:
            AssertionError: On an invalid argument combination.
        """
        if img_scale is None:
            self.img_scale = None
        else:
            self.img_scale = img_scale if isinstance(img_scale, list) else [img_scale]
            assert torchie.is_list_of(self.img_scale, tuple)

        if ratio_range is not None:
            # mode 1: given a scale and a range of image ratio.
            # Check for None explicitly; otherwise ``len(None)`` would raise
            # an unhelpful TypeError.
            assert self.img_scale is not None and len(self.img_scale) == 1, (
                "ratio_range requires exactly one img_scale, got {!r}".format(
                    self.img_scale))
        else:
            # mode 2: given multiple scales or a range of scales
            assert multiscale_mode in ["value", "range"], (
                "multiscale_mode must be 'value' or 'range', got {!r}".format(
                    multiscale_mode))

        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio
コード例 #5
0
 def random_select(img_scales):
     """Pick one scale uniformly at random from ``img_scales``.

     Args:
         img_scales (list[tuple]): Candidate scales.

     Returns:
         tuple: ``(chosen_scale, index_of_chosen_scale)``.
     """
     assert torchie.is_list_of(img_scales, tuple)
     idx = np.random.randint(len(img_scales))
     return img_scales[idx], idx
コード例 #6
0
ファイル: test_aug.py プロジェクト: Vegeta2020/SE-SSD
 def __init__(self, transforms, img_scale, flip=False):
     """Store the composed transform pipeline and augmentation settings.

     Args:
         transforms: Passed unchanged to ``Compose``.
         img_scale (tuple | list[tuple]): One scale or a list of scales. A
             single value is wrapped into a one-element list.
         flip (bool): Stored unchanged as ``self.flip``.
     """
     self.transforms = Compose(transforms)
     if not isinstance(img_scale, list):
         img_scale = [img_scale]
     self.img_scale = img_scale
     assert torchie.is_list_of(self.img_scale, tuple)
     self.flip = flip