def configure_optimizers(self):
    logger.info(f"configure_optimizers lr={self.hparams.learning_rate}")
    # Optimize only the parameters that require gradients.
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, self.parameters()),
        lr=self.hparams.learning_rate,
    )
    if not self.hparams.lr_scheduler:
        return optimizer
    # Linear warmup followed by cosine decay, stepped once per optimizer step.
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        linear_warmup_decay(
            self.hparams.lr_scheduler_warmup_steps,
            self.hparams.lr_scheduler_total_steps,
            cosine=True,
        ),
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        },
    }
def configure_optimizers(self):
    # Alternative: pick the optimizer class from hparams instead of hard-coding LARS.
    # optimizer = getattr(optim, self.hparams.get("optimizer", "Adam"))
    # lr = self.hparams.get("lr", 1e-4)
    # weight_decay = self.hparams.get("weight_decay", 1e-6)
    # return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay)

    # LARS with the linear scaling rule: base lr of 0.2 scaled by batch_size / 256.
    optimizer = LARS(
        self.parameters(),
        lr=(0.2 * (self.batch_size / 256)),
        momentum=0.9,
        weight_decay=self.hparams.get("weight_decay", 1e-6),
        trust_coefficient=0.001,
    )

    # 8000 is presumably the training-set size; warmup for 10 of 100 epochs.
    train_iters_per_epoch = 8000 // self.batch_size
    warmup_steps = train_iters_per_epoch * 10
    total_steps = train_iters_per_epoch * 100

    scheduler = {
        "scheduler": torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            linear_warmup_decay(warmup_steps, total_steps, cosine=True),
        ),
        "interval": "step",  # step the scheduler every batch, not every epoch
        "frequency": 1,
    }

    return [optimizer], [scheduler]
def configure_optimizers(self):
    # Optionally exclude biases and batch-norm parameters from weight decay.
    if self.exclude_bn_bias:
        params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay)
    else:
        params = self.parameters()

    if self.optim == "lars":
        optimizer = LARS(
            params,
            lr=self.learning_rate,
            momentum=0.9,
            weight_decay=self.weight_decay,
            trust_coefficient=0.001,
        )
    elif self.optim == "adam":
        optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {self.optim}")

    warmup_steps = self.train_iters_per_epoch * self.warmup_epochs
    total_steps = self.train_iters_per_epoch * self.max_epochs

    scheduler = {
        "scheduler": torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            linear_warmup_decay(warmup_steps, total_steps, cosine=True),
        ),
        "interval": "step",
        "frequency": 1,
    }

    return [optimizer], [scheduler]
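# The variants above all call a `linear_warmup_decay` helper, and the third also
# calls `self.exclude_from_wt_decay`; neither is defined in this file. They usually
# come from pl_bolts-style SimCLR training code. The sketches below are illustrative,
# compatible implementations under that assumption, not the exact library code.

import math


def linear_warmup_decay(warmup_steps, total_steps, cosine=True):
    """Return a per-step multiplier for LambdaLR: linear warmup, then optional cosine decay to 0."""

    def fn(step):
        if step < warmup_steps:
            # Linear warmup from 0 up to the base learning rate.
            return float(step) / float(max(1, warmup_steps))
        if not cosine:
            # Hold the base learning rate after warmup.
            return 1.0
        # Cosine decay from the base learning rate down to 0 at total_steps.
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return fn


def exclude_from_wt_decay(named_params, weight_decay, skip_list=("bias", "bn")):
    """Split parameters into two groups so biases and batch-norm weights receive no weight decay.

    In the third variant this is a method on the LightningModule; shown here as a free function.
    """
    params, excluded = [], []
    for name, param in named_params:
        if not param.requires_grad:
            continue
        if any(layer_name in name for layer_name in skip_list):
            excluded.append(param)
        else:
            params.append(param)
    return [
        {"params": params, "weight_decay": weight_decay},
        {"params": excluded, "weight_decay": 0.0},
    ]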