def train(self, args, warmup_epochs=5, warmup_lr=0):
    """Run warm-up plus regular training epochs; validate at every resolution
    and checkpoint after each epoch (flagging the best-so-far model)."""
    total_epochs = self.run_config.n_epochs + warmup_epochs
    for epoch in range(self.start_epoch, total_epochs):
        train_loss, (train_top1, train_top5) = self.train_one_epoch(
            args, epoch, warmup_epochs, warmup_lr)
        img_size, val_loss, val_top1, val_top5 = self.validate_all_resolution(
            epoch, is_test=False)

        mean_top1 = list_mean(val_top1)
        is_best = mean_top1 > self.best_acc
        self.best_acc = max(self.best_acc, mean_top1)

        # Only the root process logs and writes checkpoints.
        if not self.is_root:
            continue
        # Positional slots {6}/{7} are filled by the metric names.
        val_log = '[{0}/{1}]\tloss {2:.3f}\t{6} acc {3:.3f} ({4:.3f})\t{7} acc {5:.3f}\t' \
                  'Train {6} {top1:.3f}\tloss {train_loss:.3f}\t'.format(
            epoch + 1 - warmup_epochs, self.run_config.n_epochs,
            list_mean(val_loss), mean_top1, self.best_acc, list_mean(val_top5),
            *self.get_metric_names(), top1=train_top1, train_loss=train_loss)
        # Append per-resolution top-1 accuracies.
        val_log += ''.join('(%d, %.3f), ' % (res, acc)
                           for res, acc in zip(img_size, val_top1))
        self.write_log(val_log, prefix='valid', should_print=False)

        self.save_model(
            {
                'epoch': epoch,
                'best_acc': self.best_acc,
                'optimizer': self.optimizer.state_dict(),
                'state_dict': self.net.state_dict(),
            },
            is_best=is_best,
        )
def build_acc_dataset(self, run_manager, ofa_network, n_arch=1000, image_size_list=None):
    """Measure validation accuracy of ``n_arch`` sampled subnets at each image size.

    Results are written incrementally to ``<acc_src_folder>/<size>.dict`` as JSON so an
    interrupted run can resume; (net, size) pairs already present in the file are reused.

    Fix over the previous version: all ``json.load(open(...))`` / ``json.dump(..., open(...))``
    calls leaked file handles — file I/O now uses ``with`` blocks.

    :param run_manager: provides data loaders and the ``validate`` routine
    :param ofa_network: once-for-all super-network to sample subnets from
    :param n_arch: number of distinct subnet architectures to sample
    :param image_size_list: input resolutions to evaluate (default [128, 160, 192, 224])
    """
    # Load the architecture id list, or randomly sample one and persist it.
    if os.path.isfile(self.net_id_path):
        with open(self.net_id_path) as fin:
            net_id_list = json.load(fin)
    else:
        net_id_list = set()
        while len(net_id_list) < n_arch:
            net_setting = ofa_network.sample_active_subnet()
            net_id = net_setting2id(net_setting)
            net_id_list.add(net_id)
        net_id_list = sorted(net_id_list)
        with open(self.net_id_path, 'w') as fout:
            json.dump(net_id_list, fout, indent=4)

    image_size_list = [128, 160, 192, 224] if image_size_list is None else image_size_list

    with tqdm(total=len(net_id_list) * len(image_size_list),
              desc='Building Acc Dataset') as t:
        for image_size in image_size_list:
            # Cache the validation batches in memory so every subnet reuses them.
            run_manager.run_config.data_provider.assign_active_img_size(image_size)
            val_dataset = [(images, labels)
                           for images, labels in run_manager.run_config.valid_loader]

            # save path
            os.makedirs(self.acc_src_folder, exist_ok=True)
            acc_save_path = os.path.join(self.acc_src_folder, '%d.dict' % image_size)
            acc_dict = {}
            # Reuse accuracies measured by a previous (possibly interrupted) run.
            if os.path.isfile(acc_save_path):
                with open(acc_save_path, 'r') as fin:
                    existing_acc_dict = json.load(fin)
            else:
                existing_acc_dict = {}

            for net_id in net_id_list:
                net_setting = net_id2setting(net_id)
                key = net_setting2id({**net_setting, 'image_size': image_size})
                if key in existing_acc_dict:
                    acc_dict[key] = existing_acc_dict[key]
                    t.set_postfix({
                        'net_id': net_id,
                        'image_size': image_size,
                        'info_val': acc_dict[key],
                        'status': 'loading',
                    })
                    t.update()
                    continue

                ofa_network.set_active_subnet(**net_setting)
                # BN statistics must be recalibrated for the newly activated subnet.
                run_manager.reset_running_statistics(ofa_network)
                net_setting_str = ','.join(['%s_%s' % (
                    k, '%.1f' % list_mean(v) if isinstance(v, list) else v
                ) for k, v in net_setting.items()])
                loss, (top1, top5) = run_manager.validate(
                    run_str=net_setting_str, net=ofa_network,
                    data_loader=val_dataset, no_logs=True,
                )
                info_val = top1
                t.set_postfix({
                    'net_id': net_id,
                    'image_size': image_size,
                    'info_val': info_val,
                })
                t.update()

                acc_dict[key] = info_val
                # Checkpoint after every measurement so progress survives interruption.
                with open(acc_save_path, 'w') as fout:
                    json.dump(acc_dict, fout, indent=4)
def validate(run_manager, epoch=0, is_test=False, image_size_list=None, ks_list=None,
             expand_ratio_list=None, depth_list=None, width_mult_list=None,
             additional_setting=None):
    """Evaluate one subnet per (depth, expand, kernel, width, resolution) combination.

    Every unspecified dimension defaults to the dynamic network's own search-space
    list.  Returns mean loss, mean top-1 and mean top-5 over all evaluated subnets,
    plus a textual log of per-subnet top-1 accuracies.
    """
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module
    dynamic_net.eval()

    # Fill in defaults from the super-network / data provider.
    if image_size_list is None:
        image_size_list = val2list(run_manager.run_config.data_provider.image_size, 1)
    if ks_list is None:
        ks_list = dynamic_net.ks_list
    if expand_ratio_list is None:
        expand_ratio_list = dynamic_net.expand_ratio_list
    if depth_list is None:
        depth_list = dynamic_net.depth_list
    if width_mult_list is None:
        # Networks without elastic width expose no `width_mult_list` attribute.
        if 'width_mult_list' in dynamic_net.__dict__:
            width_mult_list = list(range(len(dynamic_net.width_mult_list)))
        else:
            width_mult_list = [0]

    # Cartesian product of all elastic dimensions, each paired with a display name.
    subnet_settings = [
        [
            {'image_size': img_size, 'd': d, 'e': e, 'ks': k, 'w': w},
            'R%s-D%s-E%s-K%s-W%s' % (img_size, d, e, k, w),
        ]
        for d in depth_list
        for e in expand_ratio_list
        for k in ks_list
        for w in width_mult_list
        for img_size in image_size_list
    ]
    if additional_setting is not None:
        subnet_settings += additional_setting

    losses_of_subnets, top1_of_subnets, top5_of_subnets = [], [], []
    valid_log = ''
    for setting, name in subnet_settings:
        run_manager.write_log('-' * 30 + ' Validate %s ' % name + '-' * 30,
                              'train', should_print=False)
        run_manager.run_config.data_provider.assign_active_img_size(
            setting.pop('image_size'))
        dynamic_net.set_active_subnet(**setting)
        run_manager.write_log(dynamic_net.module_str, 'train', should_print=False)

        # Recalibrate BN running statistics for the newly activated subnet.
        run_manager.reset_running_statistics(dynamic_net)
        loss, (top1, top5) = run_manager.validate(
            epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net)
        losses_of_subnets.append(loss)
        top1_of_subnets.append(top1)
        top5_of_subnets.append(top5)
        valid_log += '%s (%.3f), ' % (name, top1)

    return list_mean(losses_of_subnets), list_mean(top1_of_subnets), \
        list_mean(top5_of_subnets), valid_log
def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
    """Train the dynamic super-network for one epoch.

    For every mini-batch, `args.dynamic_batch_size` subnets are sampled (with a
    deterministic seed) and their gradients accumulated before a single optimizer
    step.  Returns ``(mean loss, metric values)``.
    """
    dynamic_net = run_manager.network
    distributed = isinstance(run_manager, DistributedRunManager)

    dynamic_net.train()  # switch to train mode
    if distributed:
        run_manager.run_config.train_loader.sampler.set_epoch(epoch)
    MyRandomResizedCrop.EPOCH = epoch

    nBatch = len(run_manager.run_config.train_loader)
    data_time = AverageMeter()
    losses = DistributedMetric('train_loss') if distributed else AverageMeter()
    metric_dict = run_manager.get_metric_dict()

    with tqdm(total=nBatch,
              desc='Train Epoch #{}'.format(epoch + 1),
              disable=distributed and not run_manager.is_root) as t:
        end = time.time()
        for i, (images, labels) in enumerate(run_manager.run_config.train_loader):
            MyRandomResizedCrop.BATCH = i
            data_time.update(time.time() - end)

            # Learning-rate schedule: warm-up ramp first, then the regular schedule.
            if epoch < warmup_epochs:
                new_lr = run_manager.run_config.warmup_adjust_learning_rate(
                    run_manager.optimizer, warmup_epochs * nBatch, nBatch,
                    epoch, i, warmup_lr)
            else:
                new_lr = run_manager.run_config.adjust_learning_rate(
                    run_manager.optimizer, epoch - warmup_epochs, i, nBatch)

            images, labels = images.cuda(), labels.cuda()
            target = labels

            # Pre-compute the teacher's soft targets once per mini-batch.
            if args.kd_ratio > 0:
                args.teacher_model.train()
                with torch.no_grad():
                    soft_logits = args.teacher_model(images).detach()
                    soft_label = F.softmax(soft_logits, dim=1)

            dynamic_net.zero_grad()  # clean gradients

            loss_of_subnets = []
            subnet_str = ''
            for sub_idx in range(args.dynamic_batch_size):
                # Seed deterministically so subnet sampling is reproducible.
                subnet_seed = int('%d%.3d%.3d' % (epoch * nBatch + i, sub_idx, 0))
                random.seed(subnet_seed)
                subnet_settings = dynamic_net.sample_active_subnet()
                subnet_str += '%d: ' % sub_idx + ','.join([
                    '%s_%s' % (key, '%.1f' % subset_mean(val, 0)
                               if isinstance(val, list) else val)
                    for key, val in subnet_settings.items()
                ]) + ' || '

                output = run_manager.net(images)
                if args.kd_ratio == 0:
                    loss = run_manager.train_criterion(output, labels)
                    loss_type = 'ce'
                else:
                    if args.kd_type == 'ce':
                        kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    loss = args.kd_ratio * kd_loss + run_manager.train_criterion(
                        output, labels)
                    loss_type = '%.1fkd-%s & ce' % (args.kd_ratio, args.kd_type)

                # Record per-subnet loss and accuracy.
                loss_of_subnets.append(loss)
                run_manager.update_metric(metric_dict, output, target)
                # Accumulate gradients across all sampled subnets.
                loss.backward()
            run_manager.optimizer.step()

            losses.update(list_mean(loss_of_subnets), images.size(0))
            t.set_postfix({
                'loss': losses.avg.item(),
                **run_manager.get_metric_vals(metric_dict, return_dict=True),
                'R': images.size(2),
                'lr': new_lr,
                'loss_type': loss_type,
                'seed': str(subnet_seed),
                'str': subnet_str,
                'data_time': data_time.avg,
            })
            t.update(1)
            end = time.time()
    return losses.avg.item(), run_manager.get_metric_vals(metric_dict)
'Group %d: %d params with wd %f' % (i + 1, len(param_group['params']), param_group['weight_decay']), 'grad_params', True, 'a') for name, param in net.named_parameters(): if param.requires_grad: run_manager.write_log('%s: %s' % (name, list(param.data.size())), 'grad_params', False, 'a') run_manager.save_config() if args.resume: run_manager.load_model() else: init_path = '%s/init' % args.path if os.path.isfile(init_path): checkpoint = torch.load(init_path, map_location='cpu') if 'state_dict' in checkpoint: checkpoint = checkpoint['state_dict'] run_manager.network.load_state_dict(checkpoint) # train args.teacher_model = None run_manager.train(args) # test img_size, loss, acc1, acc5 = run_manager.validate_all_resolution( is_test=True) log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f\t' % ( list_mean(loss), list_mean(acc1), list_mean(acc5)) for i_s, v_a in zip(img_size, acc1): log += '(%d, %.3f), ' % (i_s, v_a) run_manager.write_log(log, prefix='test')