Code Example #1
    def __init__(self, context: pytorch.PyTorchTrialContext) -> None:
        self.context = context

        model = torch.nn.Linear(1, 1, bias=False)

        # Manually initialize the one weight to 0.
        model.weight.data.fill_(0)

        self.model = context.wrap_model(model)

        self.lr = 0.001

        opt = torch.optim.SGD(self.model.parameters(), self.lr)
        self.opt = context.wrap_optimizer(opt)

        self.loss_fn = torch.nn.MSELoss()

        self.cls_reducer = context.wrap_reducer(TriangleLabelSum(),
                                                name="cls_reducer")
        self.fn_reducer = context.wrap_reducer(triangle_label_sum,
                                               name="fn_reducer")

        self.hparams = self.context.get_hparams()
        if self.hparams.get("disable_dataset_reproducibility_checks"):
            self.context.experimental.disable_dataset_reproducibility_checks()
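The TriangleLabelSum class and triangle_label_sum function wrapped above are not shown in this excerpt. Below is a minimal sketch of what they might look like, assuming Determined's det.pytorch.MetricReducer interface (reset / per_slot_reduce / cross_slot_reduce) and a metric that simply sums label values; the summing logic and the update() signature are assumptions, not the original implementation.

from typing import Any, List

from determined import pytorch


class TriangleLabelSum(pytorch.MetricReducer):
    """Class-based reducer: accumulate a per-slot sum, then combine across slots."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Called by Determined before each training/validation workload.
        self.sum = 0.0

    def update(self, labels: Any) -> None:
        # User-defined accumulation hook; call from train_batch()/evaluate_batch().
        self.sum += float(labels.sum())

    def per_slot_reduce(self) -> float:
        # Reduce the values accumulated on this slot.
        return self.sum

    def cross_slot_reduce(self, per_slot_metrics: List[float]) -> float:
        # Combine per-slot results into the final metric.
        return sum(per_slot_metrics)


def triangle_label_sum(values: List[float]) -> float:
    # Function-based reducer: called with the flat list of every value passed
    # to the wrapped reducer's update() across all batches and slots.
    return sum(values)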
Code Example #2
    def __init__(self, context: PyTorchTrialContext, lightning_module: pl.LightningModule):
        check_compatibility(lightning_module)
        override_unsupported_nud(lightning_module, context)
        context.wrap_model(lightning_module)
        optimizers, lr_schedulers = self.setup_optimizers_schedulers(context, lightning_module)
        pls = _PLAdapterState(context, lightning_module, optimizers)
        self._pls = pls

        # set lightning_module properties
        pls.lm.use_ddp = False  # type: ignore
        pls.lm.use_ddp2 = False  # type: ignore
        pls.lm.use_dp = False  # type: ignore
        pls.lm.use_tpu = False  # type: ignore
        type(pls.lm).local_rank = context.distributed.get_local_rank()  # type: ignore
        type(pls.lm).global_rank = context.distributed.get_rank()  # type: ignore
        pls.lm.use_amp = context.experimental._auto_amp or context._use_apex
        pls.lm.to(context.device)
Code Example #3
    def __init__(self, context: pytorch.PyTorchTrialContext):
        self.context = context

        model = nn.Linear(1, 1, bias=False)
        # Manually initialize the one weight to 0.
        model.weight.data.fill_(0)

        self.model = context.wrap_model(model)

        opt = torch.optim.SGD(self.model.parameters(), 0.1)
        self.opt = context.wrap_optimizer(opt)
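The excerpt stops after wrapping the optimizer; here is a minimal sketch of the train_batch step that would typically follow, routing backward and optimizer steps through the context so Determined can manage AMP, gradient aggregation, and distributed training. The (data, label) batch layout is an assumption about this trial's dataset.

    def train_batch(
        self, batch: pytorch.TorchData, epoch_idx: int, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        # Assumed batch layout: (data, label) pairs from the trial's DataLoader.
        data, label = batch

        loss = torch.nn.functional.mse_loss(self.model(data), label)

        # Backward and step through the context rather than calling
        # loss.backward() / opt.step() directly.
        self.context.backward(loss)
        self.context.step_optimizer(self.opt)

        return {"loss": loss}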
Code Example #4
    def __init__(self, context: pytorch.PyTorchTrialContext) -> None:
        self.context = context

        model = torch.nn.Linear(1, 1, bias=False)
        # Manually initialize the one weight to 0.
        model.weight.data.fill_(0)
        self.model = context.wrap_model(model)

        self.lr = 0.001

        optimizer = torch.optim.SGD(self.model.parameters(), self.lr)
        self.opt = context.wrap_optimizer(optimizer)

        self.loss_fn = torch.nn.MSELoss(reduction="mean")
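To show how self.loss_fn is consumed, here is a sketch of a matching evaluate_batch; again, the (data, label) batch layout is an assumption, and by default Determined averages the returned metrics over the validation set.

    def evaluate_batch(self, batch: pytorch.TorchData) -> Dict[str, Any]:
        data, label = batch
        loss = self.loss_fn(self.model(data), label)
        return {"val_loss": loss}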
Code Example #5
    def __init__(self, context: pytorch.PyTorchTrialContext) -> None:
        self.context = context

        model = torch.nn.Linear(1, 1, bias=False)

        # Manually initialize the one weight to 0.
        model.weight.data.fill_(0)

        self.model = context.wrap_model(model)

        self.lr = 0.001

        opt = torch.optim.SGD(self.model.parameters(), self.lr)
        self.opt = context.wrap_optimizer(opt)

        self.loss_fn = torch.nn.MSELoss()
Code Example #6
    def __init__(self, context: pytorch.PyTorchTrialContext) -> None:
        self.context = context

        model = torch.nn.Linear(1, 1, bias=False)

        # Manually initialize the one weight to 0.
        model.weight.data.fill_(0)

        self.model = context.wrap_model(model)

        self.lr = 0.001

        opt = torch.optim.SGD(self.model.parameters(), self.lr)
        self.opt = context.wrap_optimizer(opt)

        self.loss_fn = torch.nn.MSELoss()

        self.cls_reducer = context.wrap_reducer(TriangleLabelSum(),
                                                name="cls_reducer")
        self.fn_reducer = context.wrap_reducer(triangle_label_sum,
                                               name="fn_reducer")
Code Example #7
    def __init__(
        self,
        context: PyTorchTrialContext,
        lightning_module: pl.LightningModule,
        precision: Literal[32, 16] = 32,
        amp_backend: Literal["native", "apex"] = "native",
        amp_level: Literal["O0", "O1", "O2", "O3"] = "O2",
    ):
        """
        This performs the necessary initialization steps to:

        1. check the compatibility of the provided ``LightningModule`` with ``LightningAdapter``.
        2. define a ``PyTorchTrial`` with models, optimizers, and LR schedulers that are provided
           by ``LightningModule``.
        3. patch the ``LightningModule`` methods that depend on a ``Trainer``.

        After inheriting this class, you need to override this function to initialize the adapted
        ``PyTorchTrial``.
        Within your ``__init__``, you should instantiate the ``LightningModule`` and call
        ``super().__init__``.

        Here is a minimal code example.

        .. code-block:: python

            def __init__(self, context: PyTorchTrialContext) -> None:
                lm = mnist.LightningMNISTClassifier(lr=context.get_hparam('learning_rate'))
                super().__init__(context, lightning_module=lm)

        Arguments:
            context (PyTorchTrialContext):
                The trial context object.
            lightning_module (``LightningModule``):
                User-defined lightning module.
            precision (int, default=32):
                Precision to use.
                Accepted values are 16 and 32.
            amp_backend (str, default="native"):
                Automatic mixed precision backend to use.
                Accepted values are "native" and "apex".
            amp_level (str, optional, default="O2"):
                Apex amp optimization level.
                Accepted values are "O0", "O1", "O2", and "O3".
                https://nvidia.github.io/apex/amp.html#opt-levels-and-properties

        """

        check.check_in(precision, {16, 32},
                       "only precisions 16 & 32 are supported.")
        check.check_in(amp_backend, {"native", "apex"},
                       'only "native" and "apex" are supported.')

        check_compatibility(lightning_module)
        override_unsupported_nud(lightning_module, context)

        if precision == 16 and amp_backend == "native":
            context.experimental.use_amp()

        context.wrap_model(lightning_module)

        pls = _LightningAdapterState(context, lightning_module, [], [])
        self._pls = pls
        pls.optimizers, pls.lr_schedulers = self.setup_optimizers_schedulers()

        if precision == 16 and amp_backend == "apex":
            context.configure_apex_amp(
                context.models,
                context.optimizers,
                enabled=True,
                opt_level=amp_level,
            )

        # set lightning_module properties
        pls.lm.use_ddp = False
        pls.lm.use_ddp2 = False
        pls.lm.use_dp = False
        pls.lm.use_tpu = False
        type(pls.lm).local_rank = context.distributed.get_local_rank()  # type: ignore
        type(pls.lm).global_rank = context.distributed.get_rank()  # type: ignore
        pls.lm.to(context.device)
        use_amp = context.experimental._auto_amp or context._use_apex
        pls.lm.use_amp = use_amp
        pls.lm.precision = "mixed" if use_amp else precision  # type: ignore
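A usage sketch of this constructor, extending the minimal example from the docstring with the mixed-precision arguments; mnist.LightningMNISTClassifier comes from that example, while the trial class name and the hyperparameter name are assumptions.

class MNISTTrial(LightningAdapter):
    def __init__(self, context: PyTorchTrialContext) -> None:
        lm = mnist.LightningMNISTClassifier(lr=context.get_hparam("learning_rate"))
        # precision=16 with the "native" backend enables torch.cuda.amp;
        # amp_level only applies when amp_backend="apex".
        super().__init__(context, lightning_module=lm, precision=16, amp_backend="native")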