Пример #1
0
    def __init__(self, context: PyTorchTrialContext) -> None:
        """Read per-split settings, then build the wrapped model, optimizer and LR scheduler.

        Args:
            context: Trial context that supplies the data config and
                hyperparameters and wraps the model, optimizer and scheduler.
        """
        self.context = context
        self.data_config = context.get_data_config()

        hparam = context.get_hparam

        # Per-split counts, keyed by "train" / "val".
        self.num_classes = {
            "train": hparam("num_classes_train"),
            "val": hparam("num_classes_val"),
        }
        self.num_support = {
            "train": hparam("num_support_train"),
            "val": hparam("num_support_val"),
        }
        self.num_query = {
            "train": hparam("num_query_train"),
            # Use all available examples for val at meta-test time
            "val": None,
        }
        self.get_train_valid_splits()

        x_dim = 1  # Omniglot is black and white
        hid_dim = hparam("hidden_dim")
        z_dim = hparam("embedding_dim")

        def conv_block(in_channels, out_channels):
            # 3x3 convolution -> batch norm -> ReLU -> 2x2 max pooling.
            layers = [
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
            return nn.Sequential(*layers)

        # (in_channels, out_channels) for each of the four stacked conv blocks.
        channel_plan = [
            (x_dim, hid_dim),
            (hid_dim, hid_dim),
            (hid_dim, hid_dim),
            (hid_dim, z_dim),
        ]
        encoder = nn.Sequential(
            *(conv_block(c_in, c_out) for c_in, c_out in channel_plan),
            Flatten(),
        )
        self.model = context.wrap_model(encoder)

        adam = torch.optim.Adam(
            self.model.parameters(),
            lr=hparam("learning_rate"),
            weight_decay=hparam("weight_decay"),
        )
        self.optimizer = context.wrap_optimizer(adam)

        step_lr = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=hparam("reduce_every"),
            gamma=hparam("lr_gamma"),
        )
        self.lr_scheduler = context.wrap_lr_scheduler(
            step_lr, LRScheduler.StepMode.STEP_EVERY_EPOCH
        )
Пример #2
0
    def __init__(self, context: PyTorchTrialContext) -> None:
        """Create the MNIST data module and the GAN, then run the parent constructor.

        Args:
            context: Trial context providing the distributed rank, data config,
                per-slot batch size and the ``lr`` / ``b1`` / ``b2``
                hyperparameters.
        """
        # The data directory name embeds the distributed rank.
        data_dir = f'/tmp/data-rank{context.distributed.get_rank()}'
        data_url = context.get_data_config()['url']
        self.dm = gan.MNISTDataModule(
            data_url,
            data_dir,
            batch_size=context.get_per_slot_batch_size(),
        )

        # The data module reports the image dimensions used to size the GAN.
        channels, width, height = self.dm.size()
        gan_options = {
            'batch_size': context.get_per_slot_batch_size(),
            'lr': context.get_hparam('lr'),
            'b1': context.get_hparam('b1'),
            'b2': context.get_hparam('b2'),
        }
        lightning_module = gan.GAN(channels, width, height, **gan_options)

        super().__init__(context, lightning_module=lightning_module)
        self.dm.prepare_data()
Пример #3
0
    def __init__(self, context: PyTorchTrialContext, *args, **kwargs) -> None:
        """Build the LitMNIST module and its data module, then run the parent constructor.

        Args:
            context: Trial context providing hyperparameters, the data config,
                the distributed rank and the per-slot batch size.
            *args: Extra positional arguments forwarded to the parent constructor.
            **kwargs: Extra keyword arguments forwarded to the parent constructor.
        """
        hidden_size = context.get_hparam('hidden_size')
        learning_rate = context.get_hparam('learning_rate')
        lightning_module = mnist.LitMNIST(
            hidden_size=hidden_size,
            learning_rate=learning_rate,
        )

        # The data directory name embeds the distributed rank.
        data_dir = f"/tmp/data-rank{context.distributed.get_rank()}"
        dm_options = {
            "data_url": context.get_data_config()["url"],
            "data_dir": data_dir,
            "batch_size": context.get_per_slot_batch_size(),
        }
        self.dm = data.MNISTDataModule(**dm_options)

        # Positional extras are written before the keyword argument; the call
        # binds arguments exactly as before.
        super().__init__(
            context, *args, lightning_module=lightning_module, **kwargs
        )
        self.dm.prepare_data()
Пример #4
0
def test_get_pretrained_weights(
    mmdet_config_dir: None, context: det_torch.PyTorchTrialContext
) -> None:
    """Check that a pretrained checkpoint is found for the configured model.

    ``get_pretrained_ckpt_path`` must return a non-None path and checkpoint
    for the file named by the ``config_file`` hyperparameter.
    """
    # NOTE(review): ``mmdet_config_dir`` is not referenced in the body;
    # presumably it is requested for its fixture side effects -- confirm.

    # Populate the module-level config -> pretrained-URL mapping before lookup.
    url_mapping = mh_mmdet.utils.get_config_pretrained_url_mapping()
    mh_mmdet.utils.CONFIG_TO_PRETRAINED = url_mapping

    config_file = context.get_hparam("config_file")
    ckpt_path, ckpt = mh_mmdet.get_pretrained_ckpt_path("/tmp", config_file)

    assert ckpt_path is not None
    assert ckpt is not None
Пример #5
0
    def __init__(self, context: PyTorchTrialContext) -> None:
        """Create the MNIST classifier and data module, then run the parent constructor.

        Args:
            context: Trial context providing the ``learning_rate``
                hyperparameter, the distributed rank and the data config.
        """
        learning_rate = context.get_hparam('learning_rate')
        classifier = mnist.LightningMNISTClassifier(lr=learning_rate)

        # The data directory name embeds the distributed rank.
        data_dir = f"/tmp/data-rank{context.distributed.get_rank()}"
        data_url = context.get_data_config()["url"]
        self.dm = mnist.MNISTDataModule(data_url, data_dir)

        super().__init__(context, lightning_module=classifier)
        self.dm.prepare_data()
Пример #6
0
    def __init__(self, context: pytorch.PyTorchTrialContext):
        """Store the context and build a zero-initialized one-weight model with SGD.

        Args:
            context: Trial context providing the ``dataset_len`` hyperparameter
                and wrapping the model and optimizer.
        """
        self.context = context
        self.dataset_len = context.get_hparam("dataset_len")

        # 1 -> 1 linear layer without bias; its single weight starts at zero.
        linear = nn.Linear(in_features=1, out_features=1, bias=False)
        linear.weight.data.zero_()
        self.model = context.wrap_model(linear)

        sgd = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.opt = context.wrap_optimizer(sgd)
Пример #7
0
    def __init__(self, context: PyTorchTrialContext) -> None:
        """Set up the loss, then the wrapped model, SGD optimizer and LR scheduler.

        Args:
            context: Trial context that supplies the data config and
                hyperparameters and wraps the model, optimizer and scheduler.
        """
        self.context = context
        self.data_config = context.get_data_config()

        hparam = context.get_hparam

        num_classes = hparam("num_classes")
        smoothing_rate = hparam("label_smoothing_rate")
        self.criterion = CrossEntropyLabelSmooth(num_classes, smoothing_rate)
        self.last_epoch_idx = -1

        network = self.build_model_from_config()
        self.model = context.wrap_model(network)

        sgd = torch.optim.SGD(
            self.model.parameters(),
            lr=hparam("learning_rate"),
            momentum=hparam("momentum"),
            weight_decay=hparam("weight_decay"),
        )
        self.optimizer = context.wrap_optimizer(sgd)

        scheduler = self.build_lr_scheduler_from_config(self.optimizer)
        self.lr_scheduler = context.wrap_lr_scheduler(
            scheduler,
            step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH,
        )