def test_lr(self):
    model = nn.Linear(10, 5)
    optimizer = SGD(model.parameters(), lr=0.1)
    scheduler = LinearWarmupScheduler(optimizer, total_epoch=5)
    for i in range(5):
        # During warmup the LR grows linearly from 0 towards the base LR.
        current_lr = optimizer.param_groups[0]['lr']
        self.assertEqual(current_lr, 0.1 * (i / 5))
        optimizer.step()
        scheduler.step()
    # After `total_epoch` steps the LR has reached the base value.
    current_lr = optimizer.param_groups[0]['lr']
    self.assertEqual(current_lr, 0.1)
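# A minimal sketch of what LinearWarmupScheduler could look like, written only to
# make the expectations of test_lr above concrete; the repository's actual class
# may differ. It scales each base LR linearly from 0 to its full value over
# `total_epoch` calls to step(), then holds it constant.
from torch.optim.lr_scheduler import _LRScheduler

class LinearWarmupScheduler(_LRScheduler):
    def __init__(self, optimizer, total_epoch, last_epoch=-1):
        self.total_epoch = total_epoch
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Warmup factor grows linearly from 0 to 1, then saturates at 1.
        scale = min(self.last_epoch / self.total_epoch, 1.0)
        return [base_lr * scale for base_lr in self.base_lrs]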
def get_optim(args, params):
    assert args.optimizer in optim_choices

    # Base optimizer
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(params, lr=args.lr, momentum=args.momentum)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(params, lr=args.lr, betas=(args.momentum, args.momentum_sqr))
    elif args.optimizer == 'adamax':
        optimizer = optim.Adamax(params, lr=args.lr, betas=(args.momentum, args.momentum_sqr))

    # Warmup LR
    if args.warmup is not None and args.warmup > 0:
        scheduler_iter = LinearWarmupScheduler(optimizer, total_epoch=args.warmup)
    else:
        scheduler_iter = None

    # Exponentially decay LR
    if args.exponential_lr:
        scheduler_epoch = ExponentialLR(optimizer, gamma=0.995)
    else:
        scheduler_epoch = None

    return optimizer, scheduler_iter, scheduler_epoch
def get_optim(args, model):
    assert args.optimizer in optim_choices

    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr))
    elif args.optimizer == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr))

    if args.warmup is not None:
        scheduler_iter = LinearWarmupScheduler(optimizer, total_epoch=args.warmup)
    else:
        scheduler_iter = None

    scheduler_epoch = ExponentialLR(optimizer, gamma=args.gamma)

    return optimizer, scheduler_iter, scheduler_epoch
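# Hedged usage sketch (not taken from the repository): how the three values
# returned by get_optim are typically consumed. The warmup scheduler is stepped
# after every optimizer update, the exponential-decay scheduler once per epoch.
# `args.epochs`, `train_loader` and `loss_fn` are placeholder names.
optimizer, scheduler_iter, scheduler_epoch = get_optim(args, model)

for epoch in range(args.epochs):
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        if scheduler_iter is not None:
            scheduler_iter.step()   # linear warmup, once per iteration
    if scheduler_epoch is not None:
        scheduler_epoch.step()      # exponential decay, once per epoch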
                 transforms=transforms).to(args.device)

if not args.train:
    state_dict = torch.load('models/{}.pt'.format(run_name))
    model.load_state_dict(state_dict)

#######################
## Specify optimizer ##
#######################

if args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=args.lr)
elif args.optimizer == 'adamax':
    optimizer = Adamax(model.parameters(), lr=args.lr)

if args.warmup is not None:
    scheduler_iter = LinearWarmupScheduler(optimizer, total_epoch=args.warmup)
else:
    scheduler_iter = None

if args.gamma is not None:
    scheduler_epoch = ExponentialLR(optimizer, gamma=args.gamma)
else:
    scheduler_epoch = None

#####################
## Define training ##
#####################

def train(model, train_loader, epoch):
    model = model.train()
tr_loader, va_loader = get_data_loaders(args.batch_size, args.dataset, args.img_size)

#############
##  Model  ##
#############

model = get_model(using_vae=args.using_vae).to(device)
# model.decoder.net.backbone.requires_grad = False
# model.decoder.net.backbone.eval()

optim = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                        lr=args.lr, momentum=0.9, weight_decay=1e-4)
sched = LinearWarmupScheduler(optim, 1000)

###############
##  Logging  ##
###############

if args.vis_mode == 'tensorboard':
    from tensorboardX import SummaryWriter
    writer = SummaryWriter(flush_secs=30)
elif args.vis_mode == 'wandb':
    import wandb
    # wandb.login(key=None)
    wandb.init(project='colorvae')
    wandb.config.update(args)
    wandb.watch(model)

gIter = 0
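# Hedged sketch of the iteration loop that could follow the setup above, showing
# how `sched` and the chosen logging backend would interact; none of this is taken
# from the repository, and the 'lr' tag and `args.epochs` are hypothetical.
for epoch in range(args.epochs):
    for batch in tr_loader:
        # ... forward / backward / optim.step() elided ...
        sched.step()                        # per-iteration linear warmup
        cur_lr = optim.param_groups[0]['lr']
        if args.vis_mode == 'tensorboard':
            writer.add_scalar('lr', cur_lr, gIter)
        elif args.vis_mode == 'wandb':
            wandb.log({'lr': cur_lr}, step=gIter)
        gIter += 1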