Example #1
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._per_slot_batch_size, self._global_batch_size = util.calculate_batch_sizes(
            self.get_hparams(),
            self.env.experiment_config.slots_per_trial(),
            "NoOpTrial",
        )
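
All four examples call util.calculate_batch_sizes() with the trial's hyperparameters, the number of slots per trial, and the trial type name, and unpack the result into a per-slot and a global batch size. Below is a minimal sketch of what such a helper might look like, assuming it reads a global_batch_size hyperparameter and splits it evenly across slots; the real Determined implementation may validate inputs differently:

    from typing import Any, Dict, Tuple

    def calculate_batch_sizes(
        hparams: Dict[str, Any], slots_per_trial: int, trialname: str
    ) -> Tuple[int, int]:
        # Assume the experiment config defines a global_batch_size hyperparameter.
        if "global_batch_size" not in hparams:
            raise ValueError(
                f"Please specify global_batch_size under hyperparameters for {trialname}."
            )
        global_batch_size = int(hparams["global_batch_size"])
        # Each slot (worker process / GPU) handles an equal share of the global batch;
        # this sketch assumes global_batch_size is divisible by slots_per_trial.
        per_slot_batch_size = global_batch_size // slots_per_trial
        return per_slot_batch_size, global_batch_size

Deriving the per-slot size from a single global_batch_size hyperparameter lets the same experiment config scale to a different slots_per_trial without changing the effective batch size.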
Example #2
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        det.TrialContext.__init__(self, *args, **kwargs)
        estimator._EstimatorReducerContext.__init__(self, self.distributed.allgather)

        self._per_slot_batch_size, self._global_batch_size = util.calculate_batch_sizes(
            self.get_hparams(),
            self.env.experiment_config.slots_per_trial(),
            "EstimatorTrial",
        )

        self.experimental = EstimatorExperimentalContext(
            self.env,
            self.distributed,
            self._per_slot_batch_size,
        )

        if self.distributed.size > 1:
            optimizations_config = self.env.experiment_config.get_optimizations_config()
            self.aggregation_frequency = cast(
                int, optimizations_config.get("aggregation_frequency"))
            self.fp16_compression = cast(
                bool, optimizations_config.get("gradient_compression"))
            self.average_aggregated_gradients = cast(
                bool, optimizations_config.get("average_aggregated_gradients"))

        self.optimizer_initialized = False
        self.dataset_initialized = False
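
When running on more than one slot, the context reads aggregation_frequency, gradient_compression, and average_aggregated_gradients from the experiment's optimizations config. Aggregation frequency is the usual gradient-accumulation knob: local gradients are accumulated over N batches before they are communicated and applied. A tiny, self-contained illustration of that scheduling decision (illustrative only, not Determined's actual training loop):

    def should_step(batch_idx: int, aggregation_frequency: int) -> bool:
        # Communicate (allreduce) and apply gradients only once every
        # `aggregation_frequency` batches; in between, gradients accumulate locally.
        return (batch_idx + 1) % aggregation_frequency == 0

    # With aggregation_frequency=4, only every fourth batch triggers a step.
    assert [should_step(i, 4) for i in range(8)] == [False, False, False, True] * 2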
Example #3
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        det.TrialContext.__init__(self, *args, **kwargs)
        pytorch._PyTorchReducerContext.__init__(self, self.distributed.allgather)
        self._per_slot_batch_size, self._global_batch_size = util.calculate_batch_sizes(
            self.get_hparams(),
            self.env.experiment_config.slots_per_trial(),
            "PyTorchTrial",
        )

        self._distributed_backend = det._DistributedBackend()

        self.device = self._init_device()

        # Track the types for which we have already issued warnings in to_device().
        self._to_device_warned_types = set()  # type: Set[Type]

        # The following attributes are initialized during the lifetime of
        # a PyTorchTrialContext.
        self.models = []  # type: List[nn.Module]
        self.optimizers = []  # type: List[torch.optim.Optimizer]
        self.profiler = None  # type: Any
        self.lr_schedulers = []  # type: List[pytorch.LRScheduler]
        self._epoch_len = None  # type: Optional[int]

        # Keep a map of wrapped models to their original input forms, which is needed
        # by torch DDP and apex to initialize in the correct order.
        self._wrapped_models = {}  # type: Dict[nn.Module, nn.Module]

        # Use a single main model to contain all of the wrapped models. When horovod
        # broadcasts model states, we want to avoid name conflicts between those
        # states, so every model is registered as a sub-module of the main model under
        # a distinct name (via __setattr__), and the main model's state_dict is used
        # for broadcasting. Note that broadcast_parameters only accepts state_dict(),
        # although its doc says it also accepts named_parameters().
        self._main_model = nn.Module()
        self._scaler = None
        self._use_apex = False
        self._loss_ids = {}  # type: Dict[torch.Tensor, int]
        self._last_backward_batch_idx = None  # type: Optional[int]
        self._current_batch_idx = None  # type: Optional[int]

        self.experimental = pytorch.PyTorchExperimentalContext(self)
        self._reducers = pytorch._PyTorchReducerContext()
        self._determined_profiler = None  # type: Optional[profiler.ProfilerAgent]

        optimizations_config = self.env.experiment_config.get_optimizations_config()
        self._aggregation_frequency = cast(
            int, optimizations_config.get("aggregation_frequency"))
        self._fp16_compression = cast(
            bool, optimizations_config.get("gradient_compression"))
        self._average_aggregated_gradients = cast(
            bool, optimizations_config.get("average_aggregated_gradients"))
        self._average_training_metrics = cast(
            bool, optimizations_config.get("average_training_metrics"))
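
The _main_model trick above relies on standard nn.Module behavior: assigning a module as an attribute registers it as a sub-module, and its parameters appear in the parent's state_dict() under keys prefixed with the attribute name, so the states of multiple wrapped models cannot collide. A small self-contained demonstration (the attribute names here are made up for the demo):

    import torch.nn as nn

    main = nn.Module()
    setattr(main, "model_0", nn.Linear(4, 2))
    setattr(main, "model_1", nn.Linear(4, 2))

    # Both Linear layers have "weight" and "bias" parameters, but the attribute-name
    # prefixes keep the combined state_dict keys unique, which is what the broadcast
    # relies on.
    print(sorted(main.state_dict().keys()))
    # ['model_0.bias', 'model_0.weight', 'model_1.bias', 'model_1.weight']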
Example #4
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.dataset_initialized = False

        self._per_slot_batch_size, self._global_batch_size = util.calculate_batch_sizes(
            self.get_hparams(),
            self.env.experiment_config.slots_per_trial(),
            "TFKerasTrial",
        )

        self.experimental = TFKerasExperimentalContext(
            self.env,
            self.distributed,
            self._per_slot_batch_size,
        )

        # The following three attributes are initialized during the lifetime of a
        # TFKerasTrialContext instance when the user calls compile() and
        # fit() / fit_generator().
        self.model = None  # type: Optional[tf.keras.Model]
        self.compile_args = None  # type: Optional[inspect.BoundArguments]
        self.train_config = None  # type: Optional[TFKerasTrainConfig]

        optimizations_config = self.env.experiment_config.get_optimizations_config()
        self._aggregation_frequency = cast(
            int, optimizations_config.get("aggregation_frequency"))
        self._average_aggregated_gradients = cast(
            bool, optimizations_config.get("average_aggregated_gradients"))

        self._optimizers = []  # type: List[tf.keras.optimizers.Optimizer]
        self._wrapped_optimizers = []  # type: List[tf.keras.optimizers.Optimizer]
        self._compiled_optimizer = None  # type: Optional[tf.keras.optimizers.Optimizer]

        # The following attributes may be configured via configure_fit().  Defaults match the
        # normal keras.fit() defaults.
        self._fit_verbose = True
        self._fit_class_weight = None
        self._fit_workers = 1
        self._fit_use_multiprocessing = False
        self._fit_max_queue_size = 10
        self._fit_shuffle = True
        self._fit_validation_steps = None
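
The _fit_* attributes mirror keyword arguments of tf.keras.Model.fit() (verbose, class_weight, workers, use_multiprocessing, max_queue_size, shuffle, validation_steps). A hedged sketch of how a context like this might forward them when the training loop eventually calls fit(); the helper name and wiring below are assumptions for illustration, not Determined's API:

    from typing import Any, Dict

    def build_fit_kwargs(context: Any) -> Dict[str, Any]:
        # Collect the user-configurable settings stored on the context so they can be
        # passed straight through as tf.keras.Model.fit() keyword arguments.
        return {
            "verbose": context._fit_verbose,
            "class_weight": context._fit_class_weight,
            "workers": context._fit_workers,
            "use_multiprocessing": context._fit_use_multiprocessing,
            "max_queue_size": context._fit_max_queue_size,
            "shuffle": context._fit_shuffle,
            "validation_steps": context._fit_validation_steps,
        }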