Example 1
    def __init__(self,
                 config,
                 world_rank,
                 device_ids=None,
                 use_gpu=False,
                 use_fp16=False,
                 use_tqdm=False,
                 apex_args=None,
                 wrap_ddp=False,
                 wrap_distributed_sampler=False,
                 add_dist_sampler=False,
                 scheduler_step_freq=None):
        # You are not expected to override this method.
        self._world_rank = world_rank
        self._config = config
        self._use_fp16 = use_fp16
        self._device_ids = device_ids
        self._use_gpu = use_gpu and torch.cuda.is_available()
        self._device = torch.device("cuda" if self._use_gpu else "cpu")
        if tqdm is None and use_tqdm:
            raise ValueError("tqdm must be installed to use tqdm in training.")
        self._use_tqdm = use_tqdm
        self.global_step = 0
        self._apex_args = apex_args if apex_args else {}
        self._wrap_ddp = wrap_ddp
        self._wrap_distributed_sampler = wrap_distributed_sampler
        self._add_dist_sampler = add_dist_sampler
        self._scheduler_step_freq = scheduler_step_freq

        self.timers = TimerCollection()
        self.setup(config)
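
Each variant marks __init__ as not meant to be overridden; the extension point is apparently setup(config), which is called at the very end of construction. A minimal sketch of a subclass, built only on what the excerpt shows (the class name, attribute names, and config keys below are placeholders, and how the created objects are registered with the surrounding trainer is not part of this excerpt; TrainingOperator is assumed importable from wherever these constructors live):

import torch
import torch.nn as nn


class MyTrainingOperator(TrainingOperator):  # hypothetical subclass
    def setup(self, config):
        # Runs once at the end of __init__; `config` is the same dict
        # passed to the constructor. Attribute names are placeholders --
        # the real registration mechanism is outside this excerpt.
        self.model = nn.Linear(config.get("in_features", 10), 1)
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=config.get("lr", 0.01))
        self.criterion = nn.MSELoss()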
Example 2
    def __init__(
        self,
        config,
        world_rank,
        local_rank,
        is_distributed,
        use_gpu,
        device,
        use_fp16=False,
        use_tqdm=False,
        wrap_ddp=False,
        add_dist_sampler=False,
        scheduler_step_freq=None,
    ):
        # You are not expected to override this method.
        self._world_rank = world_rank
        self._local_rank = local_rank
        self._config = config
        self._is_distributed = is_distributed
        self._use_fp16 = choose_amp_backend(use_fp16, amp, apex_amp)
        self._device = device
        self._use_gpu = use_gpu and torch.cuda.is_available()
        if tqdm is None and use_tqdm:
            raise ValueError("tqdm must be installed to use tqdm in training.")
        self._use_tqdm = use_tqdm
        self.global_step = 0
        self._wrap_ddp = wrap_ddp
        self._add_dist_sampler = add_dist_sampler
        self._scheduler_step_freq = scheduler_step_freq

        self.timers = TimerCollection()
        self.setup(config)
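
Example 2 replaces the raw use_fp16 flag with choose_amp_backend(use_fp16, amp, apex_amp); the helper itself is not part of the excerpt. As a hypothetical reconstruction of its apparent intent (resolve the flag against whichever mixed-precision backend was successfully imported), it might look like:

def choose_amp_backend(use_fp16, amp=None, apex_amp=None):
    # Hypothetical helper: normalize `use_fp16` into False, "amp" (native
    # torch.cuda.amp) or "apex" (NVIDIA apex), depending on availability.
    if not use_fp16:
        return False
    if use_fp16 == "apex":
        if apex_amp is None:
            raise ValueError(
                "use_fp16='apex' was given, but apex is not installed.")
        return "apex"
    if amp is not None:
        return "amp"
    if apex_amp is not None:
        return "apex"
    raise ValueError("use_fp16 was set, but no AMP backend is available.")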
Example 3
    def __init__(self,
                 config,
                 models,
                 optimizers,
                 criterion=None,
                 schedulers=None,
                 use_fp16=False):
        # You are not expected to override this method.
        self._models = models  # List of models
        assert isinstance(models, collections.abc.Iterable), (
            "Components need to be iterable. Got: {}".format(type(models)))
        self._optimizers = optimizers  # List of optimizers
        assert isinstance(optimizers, collections.abc.Iterable), (
            "Components need to be iterable. Got: {}".format(type(optimizers)))
        self._criterion = criterion
        self._schedulers = schedulers
        if schedulers:
            assert isinstance(schedulers, collections.abc.Iterable), (
                "Components need to be iterable. Got: {}".format(
                    type(schedulers)))
        self._config = config
        self._use_fp16 = use_fp16
        self.global_step = 0

        if type(self) is TrainingOperator:
            for component in (models, schedulers, optimizers):
                if _is_multiple(component):
                    raise ValueError(
                        "Need to provide a custom operator subclassing "
                        "TrainingOperator if using multi-scheduler, "
                        "multi-model or multi-optimizer training/validation.")
        self.timers = TimerCollection()
        self.setup(config)
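
Examples 3 and 4 restrict the base class to single-model training by checking each component with _is_multiple, which is not shown in the excerpt. A minimal sketch consistent with how it is used (a hypothetical implementation, not the actual helper) could be:

from collections.abc import Sized


def _is_multiple(component):
    # Hypothetical helper: a component counts as "multiple" when it is a
    # sized collection (e.g. a list of models) holding more than one element.
    return isinstance(component, Sized) and len(component) > 1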
Example 4
    def __init__(self,
                 config,
                 models,
                 optimizers,
                 train_loader,
                 validation_loader,
                 world_rank,
                 criterion=None,
                 schedulers=None,
                 device_ids=None,
                 use_gpu=False,
                 use_fp16=False,
                 use_tqdm=False):
        # You are not expected to override this method.
        self._models = models  # List of models
        assert isinstance(models, Iterable), (
            "Components need to be iterable. Got: {}".format(type(models)))
        self._optimizers = optimizers  # List of optimizers
        assert isinstance(optimizers, Iterable), (
            "Components need to be iterable. Got: {}".format(type(optimizers)))
        self._train_loader = train_loader
        self._validation_loader = validation_loader
        self._world_rank = world_rank
        self._criterion = criterion
        self._schedulers = schedulers
        if schedulers:
            assert isinstance(schedulers, Iterable), (
                "Components need to be iterable. Got: {}".format(
                    type(schedulers)))
        self._config = config
        self._use_fp16 = use_fp16
        self._device_ids = device_ids
        self._use_gpu = use_gpu and torch.cuda.is_available()
        self._device = torch.device("cuda" if self._use_gpu else "cpu")
        if tqdm is None and use_tqdm:
            raise ValueError("tqdm must be installed to use tqdm in training.")
        self._use_tqdm = use_tqdm
        self.global_step = 0

        if type(self) is TrainingOperator:
            for component in (models, schedulers, optimizers):
                if _is_multiple(component):
                    raise ValueError(
                        "Need to provide a custom operator subclassing "
                        "TrainingOperator if using multi-scheduler, "
                        "multi-model or multi-optimizer training/validation.")
        self.timers = TimerCollection()
        self.setup(config)
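
Every variant ends the same way: create a TimerCollection and hand control to setup(config). TimerCollection is another utility that is not included here; as a rough, self-contained sketch of the kind of object the name suggests (named wall-clock timers that can be aggregated later), one might write:

import time
from collections import defaultdict
from contextlib import contextmanager


class TimerCollection:
    # Hypothetical sketch, not the actual utility used above.
    def __init__(self):
        self._totals = defaultdict(float)
        self._counts = defaultdict(int)

    @contextmanager
    def record(self, name):
        # Time the enclosed block and accumulate the result under `name`.
        start = time.monotonic()
        try:
            yield
        finally:
            self._totals[name] += time.monotonic() - start
            self._counts[name] += 1

    def stats(self):
        # Mean duration per timer name, in seconds.
        return {name: self._totals[name] / self._counts[name]
                for name in self._totals}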