def compute_loss(input, target, real):
    real = real.unsqueeze(1)
    target = utils.one_hot(target, NUM_CLASSES)
    target = torch.where(real, target, utils.label_smoothing(target, LABEL_SMOOTHING))

    loss = softmax_cross_entropy(input=input, target=target)

    return loss
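

# `softmax_cross_entropy` is a project helper that accepts soft (smoothed) targets and is
# not shown here. A minimal sketch of what such a helper might look like, returning the
# per-sample loss so the caller can reduce it as it sees fit:
def softmax_cross_entropy_sketch(input, target):
    # cross-entropy between a soft target distribution and the predicted log-probabilities
    log_prob = input.log_softmax(dim=1)

    return -(target * log_prob).sum(dim=1)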

def train_epoch(model, optimizer, scheduler, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'train'))
    metrics = {
        'loss': utils.Mean(),
    }

    update_transforms(np.linspace(0, 1, config.epochs)[epoch - 1].item())
    model.train()
    optimizer.zero_grad()
    for i, (images, feats, _, labels, _) in enumerate(
            tqdm(data_loader, desc='epoch {} train'.format(epoch)), 1):
        images, feats, labels = images.to(DEVICE), feats.to(DEVICE), labels.to(DEVICE)
        labels = utils.one_hot(labels, NUM_CLASSES)
        images, labels = cutmix(images, labels)

        logits = model(images, object(), object())
        loss = compute_loss(input=logits, target=labels)

        metrics['loss'].update(loss.data.cpu().numpy())
        labels = labels.argmax(1)

        # accumulate gradients over config.opt.acc_steps batches before stepping the optimizer
        lr = scheduler.get_lr()
        (loss.mean() / config.opt.acc_steps).backward()
        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        images = images_to_rgb(images)[:16]
        print('[FOLD {}][EPOCH {}][TRAIN] {}'.format(
            fold, epoch, ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_scalar('learning_rate', lr, global_step=epoch)
        writer.add_image('images', torchvision.utils.make_grid(
            images, nrow=math.ceil(math.sqrt(images.size(0))), normalize=True), global_step=epoch)
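

# `cutmix` above mixes the images and their one-hot targets; its implementation lives
# elsewhere in the project. A minimal CutMix sketch, assuming images of shape (B, C, H, W)
# and float one-hot labels of shape (B, NUM_CLASSES):
def cutmix_sketch(images, labels, alpha=1.):
    b, _, h, w = images.size()
    perm = torch.randperm(b, device=images.device)
    lam = float(np.random.beta(alpha, alpha))

    # sample a box covering roughly (1 - lam) of the image area
    cut_h, cut_w = int(h * math.sqrt(1. - lam)), int(w * math.sqrt(1. - lam))
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    x1, x2 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)

    # paste the box from a permuted batch, then mix the targets by the actual box area
    images = images.clone()
    images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    lam = 1. - (y2 - y1) * (x2 - x1) / float(h * w)
    labels = lam * labels + (1. - lam) * labels[perm]

    return images, labels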

def lsep_loss(input, target, exp):
    target = utils.one_hot(target, NUM_CLASSES)

    pos_mask = target > 0.5
    neg_mask = target <= 0.5

    loss = []
    for e in np.unique(exp):
        e_mask = torch.tensor(exp == e, dtype=pos_mask.dtype, device=input.device)
        e_mask = e_mask.unsqueeze(1)

        pos_examples = input[pos_mask & e_mask]
        neg_examples = input[neg_mask & e_mask]

        pos_examples = pos_examples.unsqueeze(1)
        neg_examples = neg_examples.unsqueeze(0)

        loss.append(torch.log(1 + torch.sum(torch.exp(neg_examples - pos_examples), 1)))

    loss = torch.cat(loss, 0)

    return loss
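

# The loop above is a per-experiment LSEP (log-sum-exp pairwise) ranking term: for each
# positive logit s_p within an experiment it computes
#     log(1 + sum_n exp(s_n - s_p))
# over all negative logits s_n from the same experiment, which pushes every positive score
# above every negative score inside that experiment.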

def eval_epoch(model, data_loader, fold, epoch):
    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'eval'))
    metrics = {
        'loss': utils.Mean(),
    }

    model.eval()
    with torch.no_grad():
        # collect labels, logits and experiment ids over the whole fold
        fold_labels = []
        fold_logits = []
        fold_exps = []
        for images, exps, labels, _ in tqdm(
                data_loader, desc='epoch {} evaluation'.format(epoch)):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            labels = utils.one_hot(labels, NUM_CLASSES)

            logits = model(images, None)
            loss = compute_loss(input=logits, target=labels, unsup=False)

            metrics['loss'].update(loss.data.cpu().numpy())
            labels = labels.argmax(1)

            fold_labels.append(labels)
            fold_logits.append(logits)
            fold_exps.extend(exps)

        fold_labels = torch.cat(fold_labels, 0)
        fold_logits = torch.cat(fold_logits, 0)

        # every 10 epochs, search for a global softmax temperature and log the result
        if epoch % 10 == 0:
            temp, metric, fig = find_temp_global(input=fold_logits, target=fold_labels, exps=fold_exps)
            writer.add_scalar('temp', temp, global_step=epoch)
            writer.add_scalar('metric_final', metric, global_step=epoch)
            writer.add_figure('temps', fig, global_step=epoch)
        temp = 1.  # use default temp
        fold_preds = assign_classes(probs=to_prob(fold_logits, temp).data.cpu().numpy(), exps=fold_exps)
        fold_preds = torch.tensor(fold_preds).to(fold_logits.device)
        metric = compute_metric(input=fold_preds, target=fold_labels, exps=fold_exps)

        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        for k in metric:
            metrics[k] = metric[k].mean().data.cpu().numpy()
        images = images_to_rgb(images)[:16]
        print('[FOLD {}][EPOCH {}][EVAL] {}'.format(
            fold, epoch, ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_image('images', torchvision.utils.make_grid(
            images, nrow=math.ceil(math.sqrt(images.size(0))), normalize=True), global_step=epoch)

        return metrics
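

# `find_temp_global`, `to_prob` and `assign_classes` come from the surrounding project.
# Only `to_prob` is sketched here, under the assumption that it is a temperature-scaled
# softmax over the class dimension (using the temperature found above, or 1. by default):
def to_prob_sketch(logits, temp):
    return (logits / temp).softmax(dim=1)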

def train_epoch(model, optimizer, scheduler, data_loader, unsup_data_loader, fold, epoch):
    assert len(data_loader) <= len(unsup_data_loader), (len(data_loader), len(unsup_data_loader))

    writer = SummaryWriter(
        os.path.join(args.experiment_path, 'fold{}'.format(fold), 'train'))
    metrics = {
        'loss': utils.Mean(),
    }

    update_transforms(np.linspace(0, 1, config.epochs)[epoch - 1].item())
    data = zip(data_loader, unsup_data_loader)
    total = min(len(data_loader), len(unsup_data_loader))

    model.train()
    optimizer.zero_grad()
    for i, ((images_s, _, labels_s, _), (images_u, _, _)) \
            in enumerate(tqdm(data, desc='epoch {} train'.format(epoch), total=total), 1):
        images_s, labels_s, images_u = images_s.to(DEVICE), labels_s.to(DEVICE), images_u.to(DEVICE)
        labels_s = utils.one_hot(labels_s, NUM_CLASSES)

        with torch.no_grad():
            # pseudo-label the unlabeled batch: average the softmax over the n augmented
            # views of each sample, then sharpen the resulting distribution
            b, n, c, h, w = images_u.size()
            images_u = images_u.view(b * n, c, h, w)
            logits_u = model(images_u, None, True)
            logits_u = logits_u.view(b, n, NUM_CLASSES)
            labels_u = logits_u.softmax(2).mean(1, keepdim=True)
            labels_u = labels_u.repeat(1, n, 1).view(b * n, NUM_CLASSES)
            labels_u = dist_sharpen(labels_u, temp=SHARPEN_TEMP)

        assert images_s.size() == images_u.size()
        assert labels_s.size() == labels_u.size()

        # mix the labeled and pseudo-labeled batches together with mixup
        images, labels = torch.cat([images_s, images_u], 0), torch.cat([labels_s, labels_u], 0)
        images, labels = mixup(images, labels)
        assert images.size(0) == config.batch_size * 2

        logits = model(images, None, True)
        loss = compute_loss(input=logits, target=labels, unsup=True)

        metrics['loss'].update(loss.data.cpu().numpy())
        labels = labels.argmax(1)

        lr = scheduler.get_lr()
        (loss.mean() / config.opt.acc_steps).backward()
        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}
        images = images_to_rgb(images)[:16]
        print('[FOLD {}][EPOCH {}][TRAIN] {}'.format(
            fold, epoch, ', '.join('{}: {:.4f}'.format(k, metrics[k]) for k in metrics)))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)
        writer.add_scalar('learning_rate', lr, global_step=epoch)
        writer.add_image('images', torchvision.utils.make_grid(
            images, nrow=math.ceil(math.sqrt(images.size(0))), normalize=True), global_step=epoch)
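

# `dist_sharpen` is assumed to be MixMatch-style temperature sharpening of the averaged
# prediction; a minimal sketch under that assumption:
def dist_sharpen_sketch(dist, temp):
    # raise each probability to the power 1/temp and renormalize; a temp below 1 makes the
    # distribution peakier, which is the usual choice for pseudo-labels
    dist = dist**(1. / temp)

    return dist / dist.sum(dim=1, keepdim=True)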

def lr_search(train_eval_data):
    train_eval_dataset = TrainEvalDataset(train_eval_data, transform=train_transform)
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_eval_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    # sweep the learning rate exponentially from min_lr to max_lr over one pass of the data
    min_lr = 1e-7
    max_lr = 10.
    gamma = (max_lr / min_lr)**(1 / len(train_eval_data_loader))

    lrs = []
    losses = []
    lim = None

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    optimizer = build_optimizer(config.opt, model.parameters())
    for param_group in optimizer.param_groups:
        param_group['lr'] = min_lr
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    optimizer.train()
    update_transforms(1.)
    model.train()
    optimizer.zero_grad()
    for i, (images, _, labels, _) in enumerate(tqdm(train_eval_data_loader, desc='lr search'), 1):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        labels = utils.one_hot(labels, NUM_CLASSES)
        images, labels = mixup(images, labels)

        logits = model(images, None, True)
        loss = compute_loss(input=logits, target=labels)

        labels = labels.argmax(1)

        lrs.append(np.squeeze(scheduler.get_lr()))
        losses.append(loss.data.cpu().numpy().mean())

        # stop as soon as the loss exceeds 110% of the initial loss
        if lim is None:
            lim = losses[0] * 1.1
        if lim < losses[-1]:
            break

        (loss.mean() / config.opt.acc_steps).backward()
        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()

    writer = SummaryWriter(os.path.join(args.experiment_path, 'lr_search'))

    with torch.no_grad():
        # pick the learning rate at the minimum of the smoothed loss curve
        losses = np.clip(losses, 0, lim)
        minima_loss = losses[np.argmin(utils.smooth(losses))]
        minima_lr = lrs[np.argmin(utils.smooth(losses))]

        step = 0
        for loss, loss_sm in zip(losses, utils.smooth(losses)):
            writer.add_scalar('search_loss', loss, global_step=step)
            writer.add_scalar('search_loss_sm', loss_sm, global_step=step)
            step += config.batch_size

        fig = plt.figure()
        plt.plot(lrs, losses)
        plt.plot(lrs, utils.smooth(losses))
        plt.axvline(minima_lr)
        plt.xscale('log')
        plt.title('loss: {:.8f}, lr: {:.8f}'.format(minima_loss, minima_lr))
        writer.add_figure('search', fig, global_step=0)

        return minima_lr
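

# `utils.smooth` is assumed to behave like the usual lr-finder smoothing, i.e. an
# exponential moving average of the loss curve with bias correction; a minimal sketch
# under that assumption:
def smooth_sketch(values, beta=0.9):
    smoothed, avg = [], 0.
    for i, v in enumerate(values, 1):
        avg = beta * avg + (1. - beta) * v
        smoothed.append(avg / (1. - beta**i))  # bias-corrected, as in Adam's moment estimates

    return np.array(smoothed)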

def one_hot(input):
    return utils.one_hot(input, num_classes=NUM_CLASSES).permute((0, 3, 1, 2))
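

# Shape note for the wrapper above, assuming `utils.one_hot` appends the class axis last:
# an integer mask of shape (B, H, W) becomes (B, H, W, NUM_CLASSES) and is then permuted
# to the channels-first layout (B, NUM_CLASSES, H, W).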