Code example #1
    def configure_optimizers(self):
        # Alternative, commented out: pick a torch.optim optimizer class by name from the hyperparameters.
        # optimizer = getattr(optim, self.hparams.get("optimizer", "Adam"))
        # lr = self.hparams.get("lr", 1e-4)
        # weight_decay = self.hparams.get("weight_decay", 1e-6)
        # return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay)

        # LARS with the linear LR scaling rule: base LR 0.2, scaled by batch_size / 256.
        optimizer = LARS(self.parameters(),
                         lr=(0.2 * (self.batch_size / 256)),
                         momentum=0.9,
                         weight_decay=self.hparams.get("weight_decay", 1e-6),
                         trust_coefficient=0.001)

        # Schedule lengths in optimizer steps: 10 warmup epochs out of 100 total,
        # with the dataset size hard-coded to 8000 samples here.
        train_iters_per_epoch = 8000 // self.batch_size
        warmup_steps = train_iters_per_epoch * 10
        total_steps = train_iters_per_epoch * 100

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                linear_warmup_decay(warmup_steps, total_steps, cosine=True),
            ),
            "interval": "step",
            "frequency": 1,
        }

        return [optimizer], [scheduler]
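
The schedule is driven by `linear_warmup_decay(warmup_steps, total_steps, cosine=True)`, a helper imported from elsewhere and not shown above (code example #2 below uses the same one); because the scheduler dict sets `"interval": "step"`, Lightning calls `scheduler.step()` after every optimizer step rather than once per epoch. A minimal sketch of what such a factory typically returns, assuming a linear ramp from 0 to the base LR followed by cosine decay to 0 (an illustrative reimplementation, not the library's actual code):

    import math

    def linear_warmup_decay(warmup_steps, total_steps, cosine=True):
        """Return a multiplicative LR factor for LambdaLR: linear warmup, then decay to 0."""
        def fn(step):
            if step < warmup_steps:
                # Linear ramp from 0 up to 1 over the warmup steps.
                return float(step) / float(max(1, warmup_steps))
            # Progress through the remaining steps, in [0, 1].
            progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            if cosine:
                # Cosine decay from 1 down to 0.
                return 0.5 * (1.0 + math.cos(math.pi * progress))
            # Plain linear decay from 1 down to 0.
            return 1.0 - progress
        return fn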
Code example #2
    def configure_optimizers(self):
        if self.exclude_bn_bias:
            # Split parameters into groups so that bias and batch-norm terms get no weight decay.
            params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay)
        else:
            params = self.parameters()

        if self.optim == "lars":
            optimizer = LARS(
                params,
                lr=self.learning_rate,
                momentum=0.9,
                weight_decay=self.weight_decay,
                trust_coefficient=0.001,
            )
        elif self.optim == "adam":
            optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)
        else:
            # Fail fast instead of falling through with `optimizer` undefined.
            raise ValueError(f"Unknown optimizer: {self.optim}")

        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs
        total_steps = self.train_iters_per_epoch * self.max_epochs

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                linear_warmup_decay(warmup_steps, total_steps, cosine=True),
            ),
            "interval": "step",
            "frequency": 1,
        }

        return [optimizer], [scheduler]
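
Code example #2 also relies on `self.exclude_from_wt_decay(...)`, which is defined elsewhere on the module. A minimal sketch of what such a helper usually does, assuming the conventional split where bias and batch-norm parameters go into a parameter group with `weight_decay=0.0` (the `skip_list` names and grouping are assumptions, not the original implementation):

    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")):
        """Build two param groups: one with weight decay, one (bias/BN) without."""
        params = []
        excluded_params = []
        for name, param in named_params:
            if not param.requires_grad:
                continue
            if any(layer_name in name for layer_name in skip_list):
                # Bias and batch-norm parameters: no weight decay.
                excluded_params.append(param)
            else:
                params.append(param)
        return [
            {"params": params, "weight_decay": weight_decay},
            {"params": excluded_params, "weight_decay": 0.0},
        ]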