def validate(run_manager, epoch=0, is_test=True, image_size_list=None,
             width_mult_list=None, ks_list=None, expand_ratio_list=None,
             depth_list=None, additional_setting=None):
    """Validate every requested subnet of the dynamic network.

    One subnet setting is generated for each combination of width, depth,
    expand ratio, kernel size and image size. Each setting is activated on
    the network and evaluated through ``run_manager.validate``.

    Args:
        run_manager: Object exposing ``net``, ``run_config``, ``write_log``,
            ``reset_running_statistics`` and ``validate``.
        epoch: Forwarded to ``run_manager.validate``.
        is_test: Forwarded to ``run_manager.validate``.
        image_size_list: Image sizes to evaluate. Defaults to
            ``int2list(run_manager.run_config.data_provider.image_size, 1)``.
        width_mult_list: Values passed as the ``wid`` keyword of
            ``set_active_subnet``. Defaults to every index of
            ``dynamic_net.width_mult_list``.
        ks_list: Values passed as ``ks``. Defaults to ``dynamic_net.ks_list``.
        expand_ratio_list: Values passed as ``e``. Defaults to
            ``dynamic_net.expand_ratio_list``.
        depth_list: Values passed as ``d``. Defaults to
            ``dynamic_net.depth_list``.
        additional_setting: Optional list of ``[setting_dict, name]`` pairs
            appended after the generated settings. Each ``setting_dict`` must
            contain ``'image_size'``; its remaining keys are forwarded to
            ``set_active_subnet``. The dicts are not modified.

    Returns:
        A 4-tuple ``(loss, top1, top5, valid_log)``. The first three items are
        ``list_mean`` of the per-subnet results; ``valid_log`` is a string
        listing each subnet name with its top-1 value.
    """
    from itertools import product

    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        # Unwrap so subnet-control methods are called on the real module.
        dynamic_net = dynamic_net.module

    dynamic_net.eval()

    if image_size_list is None:
        image_size_list = int2list(run_manager.run_config.data_provider.image_size, 1)
    if width_mult_list is None:
        width_mult_list = list(range(len(dynamic_net.width_mult_list)))
    if ks_list is None:
        ks_list = dynamic_net.ks_list
    if expand_ratio_list is None:
        expand_ratio_list = dynamic_net.expand_ratio_list
    if depth_list is None:
        depth_list = dynamic_net.depth_list

    # product() varies its right-most argument fastest, which matches the
    # width > depth > expand > kernel > image-size nesting order.
    subnet_settings = [
        [
            {'image_size': img_size, 'wid': w, 'd': d, 'e': e, 'ks': k},
            'R%s-W%s-D%s-E%s-K%s' % (img_size, w, d, e, k),
        ]
        for w, d, e, k, img_size in product(
            width_mult_list, depth_list, expand_ratio_list, ks_list, image_size_list,
        )
    ]
    if additional_setting is not None:
        subnet_settings += additional_setting

    losses_of_subnets, top1_of_subnets, top5_of_subnets = [], [], []
    log_entries = []

    for setting, name in subnet_settings:
        run_manager.write_log('-' * 30 + ' Validate %s ' % name + '-' * 30, 'train', should_print=False)

        # Pop from a shallow copy: entries of ``additional_setting`` belong to
        # the caller and must still contain 'image_size' after this call.
        subnet_kwargs = dict(setting)
        run_manager.run_config.data_provider.assign_active_img_size(subnet_kwargs.pop('image_size'))
        dynamic_net.set_active_subnet(**subnet_kwargs)
        run_manager.write_log(dynamic_net.module_str, 'train', should_print=False)

        run_manager.reset_running_statistics(dynamic_net)
        loss, top1, top5 = run_manager.validate(epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net)
        losses_of_subnets.append(loss)
        top1_of_subnets.append(top1)
        top5_of_subnets.append(top5)
        log_entries.append('%s (%.3f), ' % (name, top1))

    valid_log = ''.join(log_entries)
    return list_mean(losses_of_subnets), list_mean(top1_of_subnets), list_mean(top5_of_subnets), valid_log
def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
    """Train the dynamic network for one epoch with per-batch subnet sampling.

    For every batch, ``args.dynamic_batch_size`` subnets are sampled one after
    another. Each runs a forward and backward pass on the same images, so the
    gradients accumulate, and the optimizer steps once per batch.

    Args:
        run_manager: Provides ``net``, ``optimizer``, ``train_criterion``,
            ``run_config`` (train loader and learning-rate schedules) and
            ``is_root``.
        args: Options read here: ``kd_ratio``, ``kd_type``, ``teacher_model``,
            ``dynamic_batch_size`` and ``independent_distributed_sampling``.
        epoch: Epoch index (shown in the progress bar as ``epoch + 1``).
        warmup_epochs: Epochs below this index use
            ``warmup_adjust_learning_rate``; later ones use
            ``adjust_learning_rate``.
        warmup_lr: Forwarded to ``warmup_adjust_learning_rate``.

    Returns:
        Tuple ``(loss, top1, top5)`` read from the ``avg`` of the three
        ``DistributedMetric`` trackers.
    """
    dynamic_net = run_manager.net

    # Training mode, plus epoch-dependent state of the sampler and the crop.
    dynamic_net.train()
    run_manager.run_config.train_loader.sampler.set_epoch(epoch)
    MyRandomResizedCrop.EPOCH = epoch

    num_batches = len(run_manager.run_config.train_loader)

    data_time = AverageMeter()
    losses = DistributedMetric('train_loss')
    top1 = DistributedMetric('train_top1')
    top5 = DistributedMetric('train_top5')

    with tqdm(
        total=num_batches,
        desc='Train Epoch #{}'.format(epoch + 1),
        disable=not run_manager.is_root,
    ) as progress:
        tick = time.time()
        for batch_idx, (images, labels) in enumerate(run_manager.run_config.train_loader):
            data_time.update(time.time() - tick)

            # Learning-rate schedule: regular after warm-up, warm-up before.
            if epoch >= warmup_epochs:
                new_lr = run_manager.run_config.adjust_learning_rate(
                    run_manager.optimizer, epoch - warmup_epochs, batch_idx, num_batches,
                )
            else:
                new_lr = run_manager.run_config.warmup_adjust_learning_rate(
                    run_manager.optimizer, warmup_epochs * num_batches, num_batches,
                    epoch, batch_idx, warmup_lr,
                )

            images, labels = images.cuda(), labels.cuda()

            # Teacher outputs are only needed when distillation is enabled.
            if args.kd_ratio > 0:
                args.teacher_model.train()
                with torch.no_grad():
                    soft_logits = args.teacher_model(images).detach()
                    soft_label = F.softmax(soft_logits, dim=1)

            # Reset once per batch; every sampled subnet adds to the gradients.
            run_manager.optimizer.zero_grad()

            loss_of_subnets, acc1_of_subnets, acc5_of_subnets = [], [], []
            subnet_descriptions = []
            for subnet_idx in range(args.dynamic_batch_size):
                # Seed the RNG before sampling: either process/time dependent,
                # or built from the global batch index and the subnet index
                # (presumably so all workers draw the same subnet -- verify).
                if args.independent_distributed_sampling:
                    subnet_seed = os.getpid() + time.time()
                else:
                    subnet_seed = int(
                        '%d%.3d%.3d' % (epoch * num_batches + batch_idx, subnet_idx, 0)
                    )
                random.seed(subnet_seed)
                subnet_settings = dynamic_net.sample_active_subnet()

                # Short description of the sampled subnet for the progress bar;
                # list-valued settings are summarised via subset_mean.
                fields = []
                for key, val in subnet_settings.items():
                    if isinstance(val, list):
                        shown = '%.1f' % subset_mean(val, 0)
                    else:
                        shown = val
                    fields.append('%s_%s' % (key, shown))
                subnet_descriptions.append(
                    '%d: ' % subnet_idx + ','.join(fields) + ' || '
                )

                # Forward pass through the currently active subnet.
                output = run_manager.net(images)
                if args.kd_ratio == 0:
                    loss = run_manager.train_criterion(output, labels)
                    loss_type = 'ce'
                else:
                    kd_loss = (
                        cross_entropy_loss_with_soft_target(output, soft_label)
                        if args.kd_type == 'ce'
                        else F.mse_loss(output, soft_logits)
                    )
                    loss = args.kd_ratio * kd_loss + run_manager.train_criterion(output, labels)
                    loss = loss * (2 / (args.kd_ratio + 1))
                    loss_type = '%.1fkd-%s & ce' % (args.kd_ratio, args.kd_type)

                # Track accuracy and loss for this subnet.
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                loss_of_subnets.append(loss)
                acc1_of_subnets.append(acc1[0])
                acc5_of_subnets.append(acc5[0])

                loss.backward()

            # Single parameter update using the accumulated gradients.
            run_manager.optimizer.step()

            subnet_str = ''.join(subnet_descriptions)
            batch_size = images.size(0)
            losses.update(list_mean(loss_of_subnets), batch_size)
            top1.update(list_mean(acc1_of_subnets), batch_size)
            top5.update(list_mean(acc5_of_subnets), batch_size)

            postfix = {
                'loss': losses.avg.item(),
                'top1': top1.avg.item(),
                'top5': top5.avg.item(),
                'R': images.size(2),
                'lr': new_lr,
                'loss_type': loss_type,
                'seed': str(subnet_seed),
                'str': subnet_str,
                'data_time': data_time.avg,
            }
            progress.set_postfix(postfix)
            progress.update(1)
            tick = time.time()

    return losses.avg.item(), top1.avg.item(), top5.avg.item()