import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter  # the project may use `from tensorboardX import SummaryWriter` instead
from tqdm import tqdm

# `Model`, `TrainEvalDataset`, `MixupDataLoader`, `train_transform`, `train_data`,
# `build_optimizer`, `compute_loss`, `mixup`, `update_transforms`, `worker_init_fn`,
# `collate_fn`, `utils`, `config`, `args`, `DEVICE` and `NUM_CLASSES` are assumed
# to be defined elsewhere in the project.


def lr_search(train_eval_data):
    train_eval_dataset = TrainEvalDataset(train_eval_data, transform=train_transform)
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_eval_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    # Sweep the learning rate exponentially from min_lr to max_lr over one pass
    # through the data. The scheduler steps once per optimizer step, i.e. every
    # acc_steps batches, hence the acc_steps factor in the exponent.
    min_lr = 1e-7
    max_lr = 10.
    gamma = (max_lr / min_lr)**(config.opt.acc_steps / len(train_eval_data_loader))

    lrs = []
    losses = []
    lim = None

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    optimizer = build_optimizer(config.opt, model.parameters())
    for param_group in optimizer.param_groups:
        param_group['lr'] = min_lr
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    update_transforms(1.)
    model.train()
    optimizer.zero_grad()
    for i, (images, _, labels, _) in enumerate(tqdm(train_eval_data_loader, desc='lr search'), 1):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        labels = utils.one_hot(labels, NUM_CLASSES)
        images, labels = mixup(images, labels)
        logits = model(images, None, True)
        loss = compute_loss(input=logits, target=labels)

        lrs.append(np.squeeze(scheduler.get_last_lr()))
        losses.append(loss.detach().cpu().numpy().mean())
        if lim is None:
            lim = losses[0] * 1.1
        if lim < losses[-1]:
            # Stop once the loss diverges past 110% of its initial value.
            break

        # Gradient accumulation: average the loss over acc_steps batches
        # before each optimizer step.
        (loss.mean() / config.opt.acc_steps).backward()
        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

    writer = SummaryWriter(os.path.join(args.experiment_path, 'lr_search'))

    losses = np.clip(losses, 0, lim)
    losses_sm = utils.smooth(losses)
    minima_loss = losses[np.argmin(losses_sm)]
    minima_lr = lrs[np.argmin(losses_sm)]

    step = 0
    for loss, loss_sm in zip(losses, losses_sm):
        writer.add_scalar('search_loss', loss, global_step=step)
        writer.add_scalar('search_loss_sm', loss_sm, global_step=step)
        step += config.batch_size

    fig = plt.figure()
    plt.plot(lrs, losses)
    plt.plot(lrs, losses_sm)
    plt.axvline(minima_lr)
    plt.xscale('log')
    plt.title('loss: {:.8f}, lr: {:.8f}'.format(minima_loss, minima_lr))
    writer.add_figure('search', fig, global_step=0)

    return minima_lr
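
# `mixup` above is a project helper whose implementation is not shown in this
# file. A minimal sketch of what it is assumed to do, following Zhang et al.
# ("mixup: Beyond Empirical Risk Minimization"): mix each batch with a shuffled
# copy of itself using a Beta-sampled coefficient. The name `mixup_sketch` and
# the alpha value are illustrative assumptions, not the project's actual code.
def mixup_sketch(images, one_hot_labels, alpha=0.4):
    lam = np.random.beta(alpha, alpha)  # mixing coefficient in [0, 1]
    index = torch.randperm(images.size(0), device=images.device)
    mixed_images = lam * images + (1 - lam) * images[index]
    mixed_labels = lam * one_hot_labels + (1 - lam) * one_hot_labels[index]
    return mixed_images, mixed_labels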
def find_lr():
    train_dataset = TrainEvalDataset(train_data, transform=train_transform)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers)
    if config.mixup is not None:
        train_data_loader = MixupDataLoader(train_data_loader, config.mixup)

    # Sweep the learning rate exponentially from min_lr to max_lr over one pass
    # through the data.
    min_lr = 1e-7
    max_lr = 10.
    gamma = (max_lr / min_lr)**(1 / len(train_data_loader))

    lrs = []
    losses = []
    lim = None

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    optimizer = build_optimizer(
        config.opt.type, model.parameters(), min_lr, config.opt.beta,
        weight_decay=config.opt.weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    model.train()
    for images, labels, ids in tqdm(train_data_loader, desc='lr search'):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        logits = model(images)
        loss = compute_loss(input=logits, target=labels, smoothing=config.label_smooth)

        lrs.append(np.squeeze(scheduler.get_last_lr()))
        losses.append(loss.detach().cpu().numpy().mean())
        if lim is None:
            lim = losses[0] * 1.1
        if lim < losses[-1]:
            # Stop once the loss diverges past 110% of its initial value.
            break

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

        if args.debug:
            break

    writer = SummaryWriter(os.path.join(args.experiment_path, 'lr_search'))

    losses = np.clip(losses, 0, lim)
    losses_sm = utils.smooth(losses)
    minima_loss = losses[np.argmin(losses_sm)]
    minima_lr = lrs[np.argmin(losses_sm)]

    step = 0
    for loss, loss_sm in zip(losses, losses_sm):
        writer.add_scalar('search_loss', loss, global_step=step)
        writer.add_scalar('search_loss_sm', loss_sm, global_step=step)
        step += config.batch_size

    fig = plt.figure()
    plt.plot(lrs, losses)
    plt.plot(lrs, losses_sm)
    plt.axvline(minima_lr)
    plt.xscale('log')
    plt.title('loss: {:.8f}, lr: {:.8f}'.format(minima_loss, minima_lr))
    writer.add_figure('search', fig, global_step=0)

    return minima_lr
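
# `utils.smooth` is a project helper used by all the LR-search variants to pick
# the loss minimum on a smoothed curve. A minimal sketch, assuming bias-corrected
# exponential moving average smoothing; `smooth_sketch` and the beta value are
# assumptions, not the project's actual implementation.
def smooth_sketch(values, beta=0.9):
    smoothed = []
    avg = 0.
    for i, v in enumerate(values, 1):
        avg = beta * avg + (1 - beta) * v
        smoothed.append(avg / (1 - beta**i))  # correct the zero-initialization bias
    return np.array(smoothed)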
def find_lr(train_eval_data, train_noisy_data):
    raise NotImplementedError('TODO: mixup')

    train_eval_dataset = torch.utils.data.ConcatDataset([
        TrainEvalDataset(train_eval_data, transform=train_transform),
        TrainEvalDataset(train_noisy_data, transform=train_transform),
    ])
    # TODO: all args
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_eval_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=collate_fn,
        worker_init_fn=worker_init_fn)

    # Sweep the learning rate exponentially from min_lr to max_lr over one pass
    # through the data.
    min_lr = 1e-7
    max_lr = 10.
    gamma = (max_lr / min_lr)**(1 / len(train_eval_data_loader))

    lrs = []
    losses = []
    lim = None

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    optimizer = build_optimizer(
        config.opt.type, model.parameters(), min_lr, config.opt.beta,
        weight_decay=config.opt.weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    model.train()
    for sigs, labels, ids in tqdm(train_eval_data_loader, desc='lr search'):
        sigs, labels = sigs.to(DEVICE), labels.to(DEVICE)
        logits, _, _ = model(sigs)
        loss = compute_loss(input=logits, target=labels)

        lrs.append(np.squeeze(scheduler.get_last_lr()))
        losses.append(loss.detach().cpu().numpy().mean())
        if lim is None:
            # This variant uses a looser divergence limit: 200% of the initial loss.
            lim = losses[0] * 2.
        if lim < losses[-1]:
            break

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

        if args.debug:
            break

    losses = np.clip(losses, 0, lim)
    losses_sm = utils.smooth(losses)
    minima_loss = losses[np.argmin(losses_sm)]
    minima_lr = lrs[np.argmin(losses_sm)]

    writer = SummaryWriter(os.path.join(args.experiment_path, 'lr_search'))
    step = 0
    for loss, loss_sm in zip(losses, losses_sm):
        writer.add_scalar('search_loss', loss, global_step=step)
        writer.add_scalar('search_loss_sm', loss_sm, global_step=step)
        step += config.batch_size

    plt.plot(lrs, losses)
    plt.plot(lrs, losses_sm)
    plt.axvline(minima_lr)
    plt.xscale('log')
    plt.title('loss: {:.8f}, lr: {:.8f}'.format(minima_loss, minima_lr))
    plot = utils.plot_to_image()
    writer.add_image('search', plot.transpose((2, 0, 1)), global_step=0)

    return minima_lr
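
# Hypothetical usage sketch (not part of the original pipeline): the returned
# learning rate marks the minimum of the smoothed loss curve, which already
# sits close to the divergence point, so training typically backs off from it.
# `train_with_found_lr`, the 10x back-off, and the mutation of `config.opt`
# are illustrative assumptions.
def train_with_found_lr(train_eval_data):
    minima_lr = lr_search(train_eval_data)
    config.opt.lr = minima_lr / 10  # common heuristic: train below the minimum
    print('lr search done, training with lr {:.8f}'.format(config.opt.lr))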