import random
import time

import torch
import torch.nn.functional as F
from tqdm import tqdm

# AverageMeter, DistributedMetric, MyRandomResizedCrop, mix_images, mix_labels,
# cross_entropy_loss_with_soft_target, subset_mean, and list_mean are helpers
# from the surrounding OFA codebase (ofa.utils).


# Method of the training run manager (note the `self` parameter); trains a
# single fixed network for one epoch.
def train_one_epoch(self, args, epoch, warmup_epochs=0, warmup_lr=0):
    # switch to train mode
    self.net.train()
    MyRandomResizedCrop.EPOCH = epoch  # required by elastic resolution

    nBatch = len(self.run_config.train_loader)

    losses = AverageMeter()
    metric_dict = self.get_metric_dict()
    data_time = AverageMeter()

    with tqdm(total=nBatch,
              desc='{} Train Epoch #{}'.format(self.run_config.dataset, epoch + 1)) as t:
        end = time.time()
        for i, (images, labels) in enumerate(self.run_config.train_loader):
            MyRandomResizedCrop.BATCH = i
            data_time.update(time.time() - end)

            # linear warmup for the first `warmup_epochs`, then the regular schedule
            if epoch < warmup_epochs:
                new_lr = self.run_config.warmup_adjust_learning_rate(
                    self.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
                )
            else:
                new_lr = self.run_config.adjust_learning_rate(
                    self.optimizer, epoch - warmup_epochs, i, nBatch)

            images, labels = images.to(self.device), labels.to(self.device)
            target = labels
            if isinstance(self.run_config.mixup_alpha, float):
                # mixup: blend pairs of images and their (smoothed) label distributions
                lam = random.betavariate(self.run_config.mixup_alpha,
                                         self.run_config.mixup_alpha)
                images = mix_images(images, lam)
                labels = mix_labels(
                    labels, lam,
                    self.run_config.data_provider.n_classes,
                    self.run_config.label_smoothing)

            # soft target from the teacher (no gradients needed)
            if args.teacher_model is not None:
                args.teacher_model.train()
                with torch.no_grad():
                    soft_logits = args.teacher_model(images).detach()
                    soft_label = F.softmax(soft_logits, dim=1)

            # compute output
            output = self.net(images)
            loss = self.train_criterion(output, labels)

            if args.teacher_model is None:
                loss_type = 'ce'
            else:
                # knowledge distillation: add the KD term to the hard-label loss
                if args.kd_type == 'ce':
                    kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
                else:
                    kd_loss = F.mse_loss(output, soft_logits)
                loss = args.kd_ratio * kd_loss + loss
                loss_type = '%.1fkd+ce' % args.kd_ratio

            # compute gradient and do SGD step
            self.net.zero_grad()  # or self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure accuracy and record loss
            losses.update(loss.item(), images.size(0))
            self.update_metric(metric_dict, output, target)

            t.set_postfix({
                'loss': losses.avg,
                **self.get_metric_vals(metric_dict, return_dict=True),
                'img_size': images.size(2),
                'lr': new_lr,
                'loss_type': loss_type,
                'data_time': data_time.avg,
            })
            t.update(1)
            end = time.time()
    return losses.avg, self.get_metric_vals(metric_dict)
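# --- Hedged sketches of the helpers used above -----------------------------
# `mix_images`, `mix_labels`, and `cross_entropy_loss_with_soft_target` come
# from the OFA utilities. The minimal re-implementations below (note the
# `_sketch` suffixes) illustrate the assumed behavior: mixup pairs each
# sample with the batch-reversed sample, labels are label-smoothed one-hot
# distributions, and the KD loss is cross-entropy against the teacher's
# softmax. They are illustrative, not the canonical implementations.


def label_smooth_sketch(target, n_classes, label_smoothing=0.1):
    # one-hot encode, then spread `label_smoothing` mass uniformly over classes
    soft_target = torch.zeros(target.size(0), n_classes, device=target.device)
    soft_target.scatter_(1, target.unsqueeze(1), 1)
    return soft_target * (1 - label_smoothing) + label_smoothing / n_classes


def mix_images_sketch(images, lam):
    # convex combination of the batch with its batch-reversed self
    return lam * images + (1 - lam) * torch.flip(images, dims=[0])


def mix_labels_sketch(target, lam, n_classes, label_smoothing=0.1):
    # mix the smoothed one-hot labels with the same coefficient as the images
    onehot = label_smooth_sketch(target, n_classes, label_smoothing)
    return lam * onehot + (1 - lam) * torch.flip(onehot, dims=[0])


def cross_entropy_with_soft_target_sketch(pred, soft_target):
    # mean over the batch of -sum_c q(c) * log p(c), with q from the teacher
    return torch.mean(torch.sum(-soft_target * F.log_softmax(pred, dim=1), dim=1))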
# Progressive-shrinking variant: a module-level trainer that samples several
# subnets from a dynamic super-network per batch (cf. the method above).
def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
    dynamic_net = run_manager.network
    distributed = isinstance(run_manager, DistributedRunManager)

    # switch to train mode
    dynamic_net.train()
    if distributed:
        # make each rank reshuffle consistently with the epoch
        run_manager.run_config.train_loader.sampler.set_epoch(epoch)
    MyRandomResizedCrop.EPOCH = epoch

    nBatch = len(run_manager.run_config.train_loader)

    data_time = AverageMeter()
    losses = DistributedMetric('train_loss') if distributed else AverageMeter()
    metric_dict = run_manager.get_metric_dict()

    with tqdm(total=nBatch,
              desc='Train Epoch #{}'.format(epoch + 1),
              disable=distributed and not run_manager.is_root) as t:
        end = time.time()
        for i, (images, labels) in enumerate(run_manager.run_config.train_loader):
            MyRandomResizedCrop.BATCH = i
            data_time.update(time.time() - end)
            if epoch < warmup_epochs:
                new_lr = run_manager.run_config.warmup_adjust_learning_rate(
                    run_manager.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
                )
            else:
                new_lr = run_manager.run_config.adjust_learning_rate(
                    run_manager.optimizer, epoch - warmup_epochs, i, nBatch)

            images, labels = images.cuda(), labels.cuda()
            target = labels

            # soft target from the teacher (computed once per batch, reused by
            # every sampled subnet)
            if args.kd_ratio > 0:
                args.teacher_model.train()
                with torch.no_grad():
                    soft_logits = args.teacher_model(images).detach()
                    soft_label = F.softmax(soft_logits, dim=1)

            # clean gradients once; each sampled subnet accumulates into them
            dynamic_net.zero_grad()

            loss_of_subnets = []
            # compute output
            subnet_str = ''
            for _ in range(args.dynamic_batch_size):
                # set random seed before sampling so every rank draws the same subnet
                subnet_seed = int('%d%.3d%.3d' % (epoch * nBatch + i, _, 0))
                random.seed(subnet_seed)
                subnet_settings = dynamic_net.sample_active_subnet()
                subnet_str += '%d: ' % _ + ','.join([
                    '%s_%s' % (key, '%.1f' % subset_mean(val, 0) if isinstance(val, list) else val)
                    for key, val in subnet_settings.items()
                ]) + ' || '

                output = run_manager.net(images)
                if args.kd_ratio == 0:
                    loss = run_manager.train_criterion(output, labels)
                    loss_type = 'ce'
                else:
                    if args.kd_type == 'ce':
                        kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    loss = args.kd_ratio * kd_loss + run_manager.train_criterion(output, labels)
                    loss_type = '%.1fkd-%s & ce' % (args.kd_ratio, args.kd_type)

                # measure accuracy and record loss
                loss_of_subnets.append(loss)
                run_manager.update_metric(metric_dict, output, target)

                loss.backward()  # gradients accumulate across the sampled subnets
            run_manager.optimizer.step()  # single step on the accumulated gradients

            losses.update(list_mean(loss_of_subnets), images.size(0))

            t.set_postfix({
                'loss': losses.avg.item(),
                **run_manager.get_metric_vals(metric_dict, return_dict=True),
                'R': images.size(2),
                'lr': new_lr,
                'loss_type': loss_type,
                'seed': str(subnet_seed),
                'str': subnet_str,
                'data_time': data_time.avg,
            })
            t.update(1)
            end = time.time()
    return losses.avg.item(), run_manager.get_metric_vals(metric_dict)
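# --- Hedged sketches of the list helpers, plus a usage example --------------
# `list_mean` and `subset_mean` above are OFA utilities; the stand-ins below
# assume `subset_mean` averages the entries of `val_list` selected by
# `sub_indexes` (a single index or a list of indices).


def list_mean_sketch(x):
    # mean of a list of numbers or tensors (sum / length)
    return sum(x) / len(x)


def subset_mean_sketch(val_list, sub_indexes):
    # accept a single index or a list of indices
    if not isinstance(sub_indexes, list):
        sub_indexes = [sub_indexes]
    return list_mean_sketch([val_list[idx] for idx in sub_indexes])


# Hedged usage sketch: one plausible way to drive the progressive-shrinking
# trainer above. The `args` fields mirror the attributes the function reads
# (kd_ratio, kd_type, dynamic_batch_size, teacher_model); `warmup_epochs`,
# `warmup_lr`, and the unpacking of get_metric_vals into (top1, top5) are
# assumptions for illustration, not the canonical training script.
#
# for epoch in range(run_manager.run_config.n_epochs + args.warmup_epochs):
#     train_loss, (train_top1, train_top5) = train_one_epoch(
#         run_manager, args, epoch,
#         warmup_epochs=args.warmup_epochs, warmup_lr=args.warmup_lr)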