Example 1
    def configure_optimizers(self):
        logger.info(f"configure_optimizers lr={self.hparams.learning_rate}")

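        # optimize only the parameters that still require gradients (frozen layers are skipped)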
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.hparams.learning_rate,
        )

        if not self.hparams.lr_scheduler:
            return optimizer

        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            linear_warmup_decay(self.hparams.lr_scheduler_warmup_steps,
                                self.hparams.lr_scheduler_total_steps,
                                cosine=True),
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }
Example 2
    def configure_optimizers(self):
        # optimizer = getattr(optim, self.hparams.get("optimizer", "Adam"))
        # lr = self.hparams.get("lr", 1e-4)
        # weight_decay = self.hparams.get("weight_decay", 1e-6)
        # return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay)

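        # LARS with the linear LR scaling rule: base lr 0.2 scaled by batch_size / 256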
        optimizer = LARS(self.parameters(),
                         lr=(0.2 * (self.batch_size / 256)),
                         momentum=0.9,
                         weight_decay=self.hparams.get("weight_decay", 1e-6),
                         trust_coefficient=0.001)

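        # step counts assume a hard-coded dataset size of 8000 samples:
        # 10 warmup epochs out of 100 total training epochs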
        train_iters_per_epoch = 8000 // self.batch_size
        warmup_steps = train_iters_per_epoch * 10
        total_steps = train_iters_per_epoch * 100

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                linear_warmup_decay(warmup_steps, total_steps, cosine=True),
            ),
            "interval": "step",
            "frequency": 1,
        }

        return [optimizer], [scheduler]
Example 3
    def configure_optimizers(self):
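        # optionally exclude batch-norm and bias parameters from weight decay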
        if self.exclude_bn_bias:
            params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay)
        else:
            params = self.parameters()

        if self.optim == "lars":
            optimizer = LARS(
                params,
                lr=self.learning_rate,
                momentum=0.9,
                weight_decay=self.weight_decay,
                trust_coefficient=0.001,
            )
        elif self.optim == "adam":
            optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)
        else:
            raise ValueError(f"Unknown optimizer: {self.optim!r}")

        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs
        total_steps = self.train_iters_per_epoch * self.max_epochs

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                linear_warmup_decay(warmup_steps, total_steps, cosine=True),
            ),
            "interval": "step",
            "frequency": 1,
        }

        return [optimizer], [scheduler]
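
All three examples wrap linear_warmup_decay(warmup_steps, total_steps, cosine=True) in a LambdaLR scheduler; the helper itself is not shown here (it is typically imported from lightning-bolts, pl_bolts.optimizers.lr_scheduler). The sketch below is a minimal stand-in under the assumption that it returns a multiplier which ramps linearly from 0 to 1 over the warmup steps and then decays to 0 along a cosine curve; the _sketch-suffixed name is hypothetical and not the library API.

import math

def linear_warmup_decay_sketch(warmup_steps, total_steps, cosine=True):
    # Hypothetical stand-in for linear_warmup_decay: returns a multiplier
    # function fn(step) suitable for torch.optim.lr_scheduler.LambdaLR.
    def fn(step):
        if step < warmup_steps:
            # linear warmup: multiplier goes 0 -> 1 over warmup_steps
            return float(step) / float(max(1, warmup_steps))
        # progress through the decay phase, 0 -> 1
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        if not cosine:
            # linear decay to 0
            return max(0.0, 1.0 - progress)
        # cosine decay to 0
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return fn

Because the scheduler dicts set interval to "step" with frequency 1, Lightning steps the scheduler after every optimizer step, which is why warmup_steps and total_steps are counted in training iterations rather than epochs.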