def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1, scheduler=None):
    tqdm_disable = bool(os.environ.get('TASK_NAME', ''))  # KakaoBrain environment
    if verbose:
        loader = tqdm(loader, disable=tqdm_disable)
        loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch']))

    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader)
    steps = 0
    for data, label in loader:
        steps += 1
        data, label = data.cuda(), label.cuda()

        if optimizer:
            optimizer.zero_grad()
        preds = model(data)
        loss = loss_fn(preds, label)

        if optimizer:
            loss.backward()
            if getattr(optimizer, "synchronize", None):
                optimizer.synchronize()  # for horovod
            if C.get()['optimizer'].get('clip', 5) > 0:
                nn.utils.clip_grad_norm_(model.parameters(), C.get()['optimizer'].get('clip', 5))
            optimizer.step()

        top1, top5 = accuracy(preds, label, (1, 5))
        metrics.add_dict({
            'loss': loss.item() * len(data),
            'top1': top1.item() * len(data),
            'top5': top5.item() * len(data),
        })
        cnt += len(data)
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            loader.set_postfix(postfix)

        if scheduler is not None:
            scheduler.step(epoch - 1 + float(steps) / total_steps)

        del preds, loss, top1, top5, data, label

    if tqdm_disable:
        if optimizer:
            logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch, C.get()['epoch'], metrics / cnt,
                        optimizer.param_groups[0]['lr'])
        else:
            logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt)

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    return metrics
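
# ------------------------------------------------------------------
# The epoch loops in this file rely on an `Accumulator` metrics helper
# (add_dict, `/`, `/=`, .metrics, items()). The repository's own class
# is not shown here; the sketch below is a hypothetical stand-in that
# matches the observed usage: per-sample sums that are later divided by
# the running sample count `cnt`.
# ------------------------------------------------------------------
from collections import defaultdict


class Accumulator:
    def __init__(self):
        self.metrics = defaultdict(float)

    def add_dict(self, d):
        # accumulate per-batch sums, e.g. loss.item() * batch_size
        for k, v in d.items():
            self.metrics[k] += v

    def __truediv__(self, denom):
        # `metrics / cnt` -> plain dict of running averages (used for tqdm postfix / logging)
        return {k: v / denom for k, v in self.metrics.items()}

    def __itruediv__(self, denom):
        # `metrics /= cnt` -> convert accumulated sums to averages in place
        for k in self.metrics:
            self.metrics[k] /= denom
        return self

    def items(self):
        return self.metrics.items()

    def __str__(self):
        return str(dict(self.metrics))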
def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1):
    if verbose:
        loader = tqdm(loader)
        if optimizer:
            curr_lr = optimizer.param_groups[0]['lr']
            loader.set_description('[%s %04d/%04d] lr=%.4f' % (desc_default, epoch, C.get()['epoch'], curr_lr))
        else:
            loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch']))

    metrics = Accumulator()
    cnt = 0
    for data, label in loader:
        data, label = data.cuda(), label.cuda()

        if optimizer:
            optimizer.zero_grad()
        preds = model(data)
        loss = loss_fn(preds, label)

        if optimizer:
            loss.backward()
            # clip after backward so the freshly computed gradients are the ones being clipped
            nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

        top1, top5 = accuracy(preds, label, (1, 5))
        metrics.add_dict({
            'loss': loss.item() * len(data),
            'top1': top1.item() * len(data),
            'top5': top5.item() * len(data),
        })
        cnt += len(data)
        if verbose:
            loader.set_postfix(metrics / cnt)

        del preds, loss, top1, top5, data, label

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    return metrics
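
# ------------------------------------------------------------------
# Hedged usage sketch (not part of the repository): how a caller might
# drive one of the run_epoch variants above. `model`, `trainloader`,
# `validloader`, `optimizer` and `writer` are assumed to be constructed
# elsewhere; only the calling convention is taken from the code above.
# ------------------------------------------------------------------
def train_and_eval_sketch(model, trainloader, validloader, optimizer, writer):
    max_epoch = C.get()['epoch']
    criterion = nn.CrossEntropyLoss()
    for epoch in range(1, max_epoch + 1):
        model.train()
        run_epoch(model, trainloader, criterion, optimizer,
                  desc_default='train', epoch=epoch, writer=writer)
        model.eval()
        with torch.no_grad():
            # passing optimizer=None skips backward/step, so this pass is evaluation only
            run_epoch(model, validloader, criterion, None,
                      desc_default='valid', epoch=epoch, writer=writer)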
def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1,
              scheduler=None, is_master=True, ema=None, wd=0.0, tqdm_disabled=False):
    if verbose:
        loader = tqdm(loader, disable=tqdm_disabled)
        loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch']))

    params_without_bn = [params for name, params in model.named_parameters()
                         if not ('_bn' in name or '.bn' in name)]

    loss_ema = None
    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader)
    steps = 0
    for data, label in loader:
        steps += 1
        data, label = data.cuda(), label.cuda()

        if C.get().conf.get('mixup', 0.0) <= 0.0 or optimizer is None:
            preds = model(data)
            loss = loss_fn(preds, label)
        else:  # mixup
            data, targets, shuffled_targets, lam = mixup(data, label, C.get()['mixup'])
            preds = model(data)
            loss = loss_fn(preds, targets, shuffled_targets, lam)
            del shuffled_targets, lam

        if optimizer:
            loss += wd * (1. / 2.) * sum([torch.sum(p ** 2) for p in params_without_bn])
            loss.backward()
            grad_clip = C.get()['optimizer'].get('clip', 5.0)
            if grad_clip > 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()

            if ema is not None:
                ema(model, (epoch - 1) * total_steps + steps)

        top1, top5 = accuracy(preds, label, (1, 5))
        metrics.add_dict({
            'loss': loss.item() * len(data),
            'top1': top1.item() * len(data),
            'top5': top5.item() * len(data),
        })
        cnt += len(data)
        if loss_ema:
            loss_ema = loss_ema * 0.9 + loss.item() * 0.1
        else:
            loss_ema = loss.item()
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            postfix['loss_ema'] = loss_ema
            loader.set_postfix(postfix)

        if scheduler is not None:
            scheduler.step(epoch - 1 + float(steps) / total_steps)

        del preds, loss, top1, top5, data, label

    if tqdm_disabled and verbose:
        if optimizer:
            logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch, C.get()['epoch'], metrics / cnt,
                        optimizer.param_groups[0]['lr'])
        else:
            logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt)

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    return metrics
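
# ------------------------------------------------------------------
# The variant above assumes two helpers that are not shown in this file:
#   * mixup(data, targets, alpha) -> (mixed_data, targets, shuffled_targets, lam),
#     paired with a loss of the form lam * CE(preds, targets) + (1 - lam) * CE(preds, shuffled_targets);
#   * a weight-averaging callable invoked as ema(model, step).
# The sketches below are hypothetical stand-ins matching those interfaces:
# the standard mixup recipe (Zhang et al., 2018) and a plain exponential
# moving average of the model state dict. The repository's own code may differ.
# ------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F


def mixup(data, targets, alpha):
    # draw the mixing coefficient and a random pairing within the batch
    lam = np.random.beta(alpha, alpha)
    indices = torch.randperm(data.size(0), device=data.device)
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]
    mixed_data = lam * data + (1 - lam) * shuffled_data
    return mixed_data, targets, shuffled_targets, lam


def mixup_cross_entropy(preds, targets, shuffled_targets, lam):
    # convex combination of the two cross-entropy terms
    return lam * F.cross_entropy(preds, targets) + (1 - lam) * F.cross_entropy(preds, shuffled_targets)


class EMA:
    """Hypothetical exponential moving average of model weights, matching the
    ema(model, step) call above. The repository's own EMA class may differ."""

    def __init__(self, decay=0.999):
        self.decay = decay
        self.shadow = {}

    @torch.no_grad()
    def __call__(self, model, step):
        # warm up the decay so early steps are not dominated by the random init
        decay = min(self.decay, (1 + step) / (10 + step))
        for name, param in model.state_dict().items():
            if name in self.shadow:
                self.shadow[name].mul_(decay).add_(param.float(), alpha=1 - decay)
            else:
                self.shadow[name] = param.detach().float().clone()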
def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1, scheduler=None):
    model_name = C.get()['model']['type']
    alpha = C.get()['alpha']
    skip_ratios = ListAverageMeter()
    tqdm_disable = bool(os.environ.get('TASK_NAME', ''))
    if verbose:
        loader = tqdm(loader, disable=tqdm_disable)
        loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch']))

    metrics = Accumulator()
    cnt = 0
    total_steps = len(loader)
    steps = 0
    for data, label in loader:
        steps += 1
        data, label = data.cuda(), label.cuda()

        if optimizer:
            optimizer.zero_grad()

        if model_name == 'pyramid_skip':
            if desc_default == '*test':
                with torch.no_grad():
                    preds, masks, gprobs = model(data)
                skips = [mask.data.le(0.5).float().mean() for mask in masks]
                if skip_ratios.len != len(skips):
                    skip_ratios.set_len(len(skips))
                skip_ratios.update(skips, data.size(0))
            else:
                preds, masks, gprobs = model(data)
            sparsity_loss = 0
            for mask in masks:
                sparsity_loss += mask.mean()
            loss1 = loss_fn(preds, label)
            loss2 = alpha * sparsity_loss
            loss = loss1 + loss2
        else:
            preds = model(data)
            loss = loss_fn(preds, label)

        if optimizer:
            loss.backward()
            if getattr(optimizer, "synchronize", None):
                optimizer.skip_synchronize()
            if C.get()['optimizer'].get('clip', 5) > 0:
                nn.utils.clip_grad_norm_(model.parameters(), C.get()['optimizer'].get('clip', 5))
            optimizer.step()

        top1, top5 = accuracy(preds, label, (1, 5))
        if model_name == 'pyramid_skip':
            metrics.add_dict({
                'loss1': loss1.item() * len(data),
                'loss2': loss2.item() * len(data),
                'top1': top1.item() * len(data),
                'top5': top5.item() * len(data),
            })
        else:
            metrics.add_dict({
                'loss': loss.item() * len(data),
                'top1': top1.item() * len(data),
                'top5': top5.item() * len(data),
            })
        cnt += len(data)
        if verbose:
            postfix = metrics / cnt
            if optimizer:
                postfix['lr'] = optimizer.param_groups[0]['lr']
            loader.set_postfix(postfix)

        # if scheduler is not None:
        #     scheduler.step(epoch - 1 + float(steps) / total_steps)

        if model_name == 'pyramid_skip':
            del masks[:], gprobs[:]
        del preds, loss, top1, top5, data, label

    if model_name == 'pyramid_skip':
        if desc_default == '*test':
            skip_summaries = []
            for idx in range(skip_ratios.len):
                skip_summaries.append(1 - skip_ratios.avg[idx])
            cp = ((sum(skip_summaries) + 1) / (len(skip_summaries) + 1)) * 100

    if tqdm_disable:
        logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt)

    metrics /= cnt
    if optimizer:
        metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
    if verbose:
        for key, value in metrics.items():
            writer.add_scalar(key, value, epoch)
    if model_name == 'pyramid_skip':
        if desc_default == '*test':
            writer.add_scalar('Computation Percentage', cp, epoch)
    return metrics
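
# ------------------------------------------------------------------
# The 'pyramid_skip' branch above tracks a per-gating-layer skip ratio with a
# `ListAverageMeter` (.len, .set_len(n), .update(values, batch_size), .avg[idx]).
# The real class is not included here; this sketch is an assumption that only
# matches the usage visible above.
# ------------------------------------------------------------------
class ListAverageMeter:
    def __init__(self):
        self.len = 0
        self.reset()

    def reset(self):
        self.sum = [0.0] * self.len
        self.count = 0
        self.avg = [0.0] * self.len

    def set_len(self, n):
        # resize to one slot per gating layer
        self.len = n
        self.reset()

    def update(self, values, batch_size=1):
        # `values` holds one scalar (tensor or float) per gating layer
        for i, v in enumerate(values):
            self.sum[i] += float(v) * batch_size
        self.count += batch_size
        self.avg = [s / self.count for s in self.sum]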
def train_controller(controller, dataloaders, save_path, ctl_save_path):
    dataset = C.get()['test_dataset']
    ctl_train_steps = 1500
    ctl_num_aggre = 10
    ctl_entropy_w = 1e-5
    ctl_ema_weight = 0.95

    metrics = Accumulator()
    cnt = 0

    controller.train()
    test_ratio = 0.
    _, _, dataloader, _ = dataloaders  # validloader
    optimizer = optim.SGD(controller.parameters(), lr=0.00035, momentum=0.9, weight_decay=0.0, nesterov=True)
    # optimizer = optim.Adam(controller.parameters(), lr=0.00035)

    # create a model & a criterion
    model = get_model(C.get()['model'], num_class(dataset), local_rank=-1)
    criterion = CrossEntropyLabelSmooth(num_class(dataset), C.get().conf.get('lb_smooth', 0),
                                        reduction="batched_sum").cuda()

    # load model weights
    data = torch.load(save_path)
    key = 'model' if 'model' in data else 'state_dict'

    if 'epoch' not in data:
        model.load_state_dict(data)
    else:
        logger.info('checkpoint epoch@%d' % data['epoch'])
        if not isinstance(model, (DataParallel, DistributedDataParallel)):
            model.load_state_dict({k.replace('module.', ''): v for k, v in data[key].items()})
        else:
            model.load_state_dict({k if 'module.' in k else 'module.' + k: v for k, v in data[key].items()})
    del data

    model.eval()
    loader_iter = iter(dataloader)  # [(image)->ToTensor->Normalize]
    baseline = None

    if os.path.isfile(ctl_save_path):
        logger.info('------Controller load------')
        checkpoint = torch.load(ctl_save_path)
        controller.load_state_dict(checkpoint['ctl_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        cnt = checkpoint['cnt']
        mean_probs = checkpoint['mean_probs']
        accs = checkpoint['accs']
        metrics.metrics = checkpoint['metrics']
        init_step = checkpoint['step']
    else:
        logger.info('------Train Controller from scratch------')
        mean_probs = []
        accs = []
        init_step = 0

    for step in tqdm(range(init_step + 1, ctl_train_steps * ctl_num_aggre + 1)):
        try:
            inputs, labels = next(loader_iter)
        except StopIteration:  # restart the validation iterator once it is exhausted
            loader_iter = iter(dataloader)
            inputs, labels = next(loader_iter)
        batch_size = len(labels)
        inputs, labels = inputs.cuda(), labels.cuda()

        log_probs, entropys, sampled_policies = controller(inputs)

        # evaluate model with augmented validation dataset
        with torch.no_grad():
            # compare Accuracy before/after augmentation
            # ori_preds = model(inputs)
            # ori_top1, ori_top5 = accuracy(ori_preds, labels, (1, 5))
            batch_policies = batch_policy_decoder(sampled_policies)  # (list:list:list:tuple) [batch, num_policy, n_op, 3]
            aug_inputs, applied_policy = augment_data(inputs, batch_policies)
            aug_inputs = aug_inputs.cuda()
            # assert type(aug_inputs) == torch.Tensor, "Augmented Input Type Error: {}".format(type(aug_inputs))
            preds = model(aug_inputs)
            model_losses = criterion(preds, labels)  # (tensor)[batch]
            top1, top5 = accuracy(preds, labels, (1, 5))
            # logger.info("Acc B/A Aug, {:.2f}->{:.2f}".format(ori_top1, top1))

        # assert model_losses.shape == entropys.shape == log_probs.shape, \
        #     "[Size miss match] loss: {}, entropy: {}, log_prob: {}".format(model_losses.shape, entropys.shape, log_probs.shape)
        rewards = -model_losses + ctl_entropy_w * entropys  # (tensor)[batch]
        if baseline is None:
            baseline = -model_losses.mean()  # scalar tensor
        else:
            # assert baseline, "len(baseline): {}".format(len(baseline))
            baseline = baseline - (1 - ctl_ema_weight) * (baseline - rewards.mean().detach())
        # baseline = 0.
        loss = -1 * (log_probs * (rewards - baseline)).mean()  # scalar tensor
        # average the gradient over ctl_num_aggre samples before each optimizer step
        loss = loss / ctl_num_aggre
        loss.backward(retain_graph=True)

        metrics.add_dict({
            'loss': loss.item() * batch_size,
            'top1': top1.item() * batch_size,
            'top5': top5.item() * batch_size,
        })
        cnt += batch_size

        if (step + 1) % ctl_num_aggre == 0:
            torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0)
            optimizer.step()
            controller.zero_grad()
            # torch.cuda.empty_cache()
            logger.info('\n[Train Controller %03d/%03d] log_prob %02f, %s',
                        step, ctl_train_steps * ctl_num_aggre, log_probs.mean().item(), metrics / cnt)

        if step % 100 == 0 or step == ctl_train_steps * ctl_num_aggre:
            save_pic(inputs, aug_inputs, labels, applied_policy, batch_policies, step)

            ps = []
            for pol in batch_policies:  # (list:list:list:tuple) [batch, num_policy, n_op, 3]
                for ops in pol:
                    for op in ops:
                        p = op[1]
                        ps.append(p)
            mean_prob = np.mean(ps)
            mean_probs.append(mean_prob)
            accs.append(top1.item())
            print("Mean probability: {:.2f}".format(mean_prob))

            torch.save({
                'step': step,
                'ctl_state_dict': controller.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'metrics': dict(metrics.metrics),
                'cnt': cnt,
                'mean_probs': mean_probs,
                'accs': accs,
            }, ctl_save_path)

    return metrics, None  # baseline.item()
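
# ------------------------------------------------------------------
# train_controller() above treats criterion(preds, labels) as a per-example
# loss vector (it is combined elementwise with entropys / log_probs to form the
# REINFORCE reward). The repository's CrossEntropyLabelSmooth with
# reduction="batched_sum" is not shown; this is a hedged sketch of a
# label-smoothed cross entropy with that per-example behaviour. The class name
# and smoothing scheme (1 - eps on the true class, eps/(K-1) elsewhere) are
# assumptions for illustration only.
# ------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingPerExample(nn.Module):
    def __init__(self, num_classes, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=-1)
        # smoothed target distribution: 1 - eps on the true class, eps/(K-1) elsewhere
        smooth = torch.full_like(log_probs, self.epsilon / (self.num_classes - 1))
        smooth.scatter_(1, targets.unsqueeze(1), 1.0 - self.epsilon)
        # per-example loss: one scalar per row, no batch reduction
        return -(smooth * log_probs).sum(dim=1)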