def sample_model(config, model):
    num_sample = config.arch.num_model_sample
    root_dir = os.path.join(config.save_path, 'model_sample')
    if dist.is_master():
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)

    for i in range(num_sample + 1):
        cur_flops = config.arch.target_flops * 10
        while cur_flops > config.arch.target_flops * 1.01 \
                or cur_flops < config.arch.target_flops * 0.99:
            model.module.direct_sampling()
            cur_flops = calc_model_flops(model, config.dataset.input_size)

        if dist.is_master():
            if i == num_sample:
                sample_dir = os.path.join(root_dir, 'expected_ch')
                model.module.expected_sampling()
            else:
                sample_dir = os.path.join(root_dir, 'sample_{}'.format(i + 1))
            if not os.path.exists(sample_dir):
                os.makedirs(sample_dir)

            save_path = os.path.join(sample_dir, 'sample.npy')
            if config.model.type.find('MobileNetV2') > -1:
                dump_mbv2_setting(model.module, save_path)
            elif config.model.type.find('ResNet') > -1:
                dump_resnet_setting(model.module, save_path)
            else:
                raise ValueError('Unknown model: {}'.format(config.model.type))

        dist.barrier()

def main():
    args = tools.get_args(parser)
    config = tools.get_config(args)
    tools.init(config)
    tb_logger, logger = tools.get_logger(config)
    tools.check_dist_init(config, logger)

    checkpoint = tools.get_checkpoint(config)
    runner = tools.get_model(config, checkpoint)
    loaders = tools.get_data_loader(config)

    if dist.is_master():
        logger.info(config)

    if args.mode == 'train':
        train(config, runner, loaders, checkpoint, tb_logger)
    elif args.mode == 'evaluate':
        evaluate(runner, loaders)
    elif args.mode == 'calc_flops':
        if dist.is_master():
            flops = tools.get_model_flops(config, runner.get_model())
            logger.info('flops: {}'.format(flops))
    elif args.mode == 'calc_params':
        if dist.is_master():
            params = tools.get_model_parameters(runner.get_model())
            logger.info('params: {}'.format(params))
    else:
        assert checkpoint is not None
        from models.dmcp.utils import sample_model
        sample_model(config, runner.get_model())

    if dist.is_master():
        logger.info('Done')

def get_logger(config, name='global_logger'):
    save_dir = config.model.type + '_'
    if config.get('arch', False):
        save_dir += str(config.arch.target_flops) + '_'
    save_dir = time.strftime(save_dir + '%m%d%H')
    save_dir = os.path.join(config.save_path, save_dir)

    if dist.is_master():
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
    else:
        while not os.path.exists(save_dir):
            time.sleep(1)
    config.save_path = save_dir

    events_dir = config.save_path + '/events'
    if dist.is_master():
        if not os.path.exists(events_dir):
            os.makedirs(events_dir)
    else:
        while not os.path.exists(events_dir):
            time.sleep(1)

    tb_logger = SummaryWriter(config.save_path + '/events')
    logger = logging.getLogger(name)
    formatter = logging.Formatter(
        '[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s')
    fh = logging.FileHandler(config.save_path + '/log.txt')
    fh.setFormatter(formatter)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.setLevel(logging.INFO)
    logger.addHandler(fh)
    logger.addHandler(sh)
    return tb_logger, logger

def profiling(model, use_cuda):
    """Profiling on either gpu or cpu."""
    if udist.is_master():
        logging.info('Start model profiling, use_cuda:{}.'.format(use_cuda))
    model_profiling(model,
                    FLAGS.image_size,
                    FLAGS.image_size,
                    verbose=getattr(FLAGS, 'model_profiling_verbose', True)
                    and udist.is_master())

def layer_flops_distribution(config, model):
    num_sample = config.arch.num_flops_stats_sample
    repo = {}

    for _ in range(num_sample):
        cur_flops = config.arch.target_flops * 10
        while cur_flops > config.arch.target_flops * 1.05 \
                or cur_flops < config.arch.target_flops * 0.95:
            model.module.direct_sampling()
            cur_flops = calc_model_flops(model, config.dataset.input_size)

        for n, m in model.named_modules():
            if isinstance(m, USModule):
                if n not in repo.keys():
                    repo[n] = []
                repo[n].append(m.cur_out_ch)

    if dist.is_master():
        root_dir = os.path.join(config.save_path, 'layer_flops_distribution')
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)

        for n in repo.keys():
            save_path = os.path.join(root_dir, n + '.pdf')
            plt.hist(repo[n], 50, density=True, facecolor='g', alpha=0.75)
            pp = PdfPages(save_path)
            plt.savefig(pp, format='pdf')
            pp.close()
            plt.gcf().clear()

def compute_mean_channel(iteration, config, model):
    mean_chs = []
    offset = []
    num_sample = config.arch.num_flops_stats_sample

    for n, m in model.named_modules():
        if n.find('alpha') > -1:
            mean_ch = 0
            for i in range(num_sample):
                mean_ch += m.direct_sampling()
            mean_chs.append(round(mean_ch / num_sample, 3))
            offset.append(m.channels - mean_chs[-1])

    if dist.is_master():
        ind = np.arange(len(mean_chs))
        width = 0.35
        p1 = plt.bar(ind, mean_chs, width)
        p2 = plt.bar(ind, offset, width, bottom=mean_chs)
        plt.ylabel('#channel')
        plt.xticks(ind, ind)
        plt.yticks(np.arange(
            0, max([mean_chs[i] + offset[i] for i in range(len(mean_chs))]), 100))
        plt.legend((p1[0], p2[0]), ('expected', 'max'))

        save_folder = os.path.join(config.save_path, 'mean_chs')
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
        pp = PdfPages(os.path.join(save_folder,
                                   'mean_chs_{}.pdf'.format(iteration)))
        plt.savefig(pp, format='pdf')
        pp.close()
        plt.gcf().clear()

def _logging(self, tb_logger, epoch_idx, batch_idx, total_batch, meters, cur_lr):
    print_freq = self.config.logging.print_freq
    top1_meter, top5_meter, loss_meter, data_time, batch_time = meters

    if self.cur_step % print_freq == 0 and dist.is_master():
        tb_logger.add_scalar('lr', cur_lr, self.cur_step)
        tb_logger.add_scalar('acc1_train', top1_meter.avg, self.cur_step)
        tb_logger.add_scalar('acc5_train', top5_meter.avg, self.cur_step)
        tb_logger.add_scalar('loss_train', loss_meter.avg, self.cur_step)

        self._info('-' * 80)
        self._info('Epoch: [{0}/{1}]\tIter: [{2}/{3}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                   'LR {lr:.4f}'.format(
                       epoch_idx, self.config.training.epoch, batch_idx,
                       total_batch, batch_time=batch_time,
                       data_time=data_time, lr=cur_lr))
        self._info('Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                       loss=loss_meter, top1=top1_meter, top5=top5_meter))

def _logging(self, tb_logger, epoch_idx, batch_idx, total_batch, meters, cur_lr):
    print_freq = self.config.logging.print_freq
    top1_meter, top5_meter, loss_meter, data_time, batch_time = meters

    if self.cur_step % print_freq == 0 and dist.is_master():
        tb_logger.add_scalar('lr', cur_lr, self.cur_step)

        self._info('-' * 80)
        self._info('Epoch: [{0}/{1}]\tIter: [{2}/{3}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                   'LR {lr:.4f}'.format(
                       epoch_idx, self.config.training.epoch, batch_idx,
                       total_batch, batch_time=batch_time,
                       data_time=data_time, lr=cur_lr))

        titles = ['min_width', 'max_width', 'random_width']
        for idx in range(3):
            tb_logger.add_scalar('loss_train@{}'.format(titles[idx]),
                                 loss_meter[idx].avg, self.cur_step)
            tb_logger.add_scalar('acc1_train@{}'.format(titles[idx]),
                                 top1_meter[idx].avg, self.cur_step)
            tb_logger.add_scalar('acc5_train@{}'.format(titles[idx]),
                                 top5_meter[idx].avg, self.cur_step)
            self._info('{title}\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                       'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                           title=titles[idx], loss=loss_meter[idx],
                           top1=top1_meter[idx], top5=top5_meter[idx]))

def _logging(self, tb_logger, epoch_idx, batch_idx, total_batch, meters, cur_lr):
    cur_lr, cur_arch_lr = cur_lr
    top1_meter, top5_meter, loss_meter, arch_loss_meter, floss_meter, \
        eflops_meter, arch_top1_meter, data_time, batch_time = meters

    super(DMCPRunner, self)._logging(
        tb_logger, epoch_idx, batch_idx, total_batch,
        [top1_meter, top5_meter, loss_meter, data_time, batch_time], cur_lr)

    print_freq = self.config.logging.print_freq
    if self.cur_step % print_freq == 0 and dist.is_master() \
            and self.cur_step >= self.config.arch.start_train:
        tb_logger.add_scalar('arc_loss', arch_loss_meter.avg, self.cur_step)
        tb_logger.add_scalar('flops_loss', floss_meter.avg, self.cur_step)
        tb_logger.add_scalar('eflops', eflops_meter.avg, self.cur_step)
        tb_logger.add_scalar('arc_top1', arch_top1_meter.avg, self.cur_step)

        self._info(
            'expected_flops {:.2f} flops_loss {:.4f}, arch_task_loss {:.4f}, '
            'arch_top1 {:.2f}, arch_lr {:.4f}'.format(
                eflops_meter.avg, floss_meter.avg, arch_loss_meter.avg,
                arch_top1_meter.avg, cur_arch_lr))
        tb_logger.add_scalar('expected_flops', eflops_meter.avg, self.cur_step)

def validate(self, val_loader, tb_logger=None):
    batch_time = AverageMeter(0)
    loss_meter = AverageMeter(0)
    top1_meter = AverageMeter(0)
    top5_meter = AverageMeter(0)

    self.model.eval()
    criterion = nn.CrossEntropyLoss()
    end = time.time()

    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(val_loader):
            x, y = x.cuda(), y.cuda()
            num = x.size(0)

            out = self.model(x)
            loss = criterion(out, y)
            top1, top5 = accuracy(out, y, top_k=(1, 5))

            loss_meter.update(loss.item(), num)
            top1_meter.update(top1.item(), num)
            top5_meter.update(top5.item(), num)

            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % self.config.logging.print_freq == 0:
                self._info(
                    'Test: [{0}/{1}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})'
                    .format(batch_idx, len(val_loader), batch_time=batch_time))

    total_num = torch.tensor([loss_meter.count]).cuda()
    loss_sum = torch.tensor([loss_meter.avg * loss_meter.count]).cuda()
    top1_sum = torch.tensor([top1_meter.avg * top1_meter.count]).cuda()
    top5_sum = torch.tensor([top5_meter.avg * top5_meter.count]).cuda()
    dist.all_reduce(total_num)
    dist.all_reduce(loss_sum)
    dist.all_reduce(top1_sum)
    dist.all_reduce(top5_sum)

    val_loss = loss_sum.item() / total_num.item()
    val_top1 = top1_sum.item() / total_num.item()
    val_top5 = top5_sum.item() / total_num.item()

    self._info('Prec@1 {:.3f}\tPrec@5 {:.3f}\tLoss {:.3f}\ttotal_num={}'.format(
        val_top1, val_top5, val_loss, loss_meter.count))

    if dist.is_master():
        if val_top1 > self.best_top1:
            self.best_top1 = val_top1

        if tb_logger is not None:
            tb_logger.add_scalar('loss_val', val_loss, self.cur_step)
            tb_logger.add_scalar('acc1_val', val_top1, self.cur_step)
            tb_logger.add_scalar('acc5_val', val_top5, self.cur_step)

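# The validation above averages meters across ranks by all-reducing weighted
# sums (avg * count) and counts, then dividing, rather than averaging the
# per-rank averages directly. A minimal standalone sketch of the same
# arithmetic without a process group; the per-rank values are hypothetical.
per_rank_stats = [(71.2, 1280), (70.5, 1280), (69.8, 1152)]  # (top1_avg, num_samples)

total_count = sum(count for _, count in per_rank_stats)
total_sum = sum(avg * count for avg, count in per_rank_stats)
global_top1 = total_sum / total_count
print('global Prec@1: {:.3f} over {} samples'.format(global_top1, total_count))
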
def shrink_model(model_wrapper,
                 ema,
                 optimizer,
                 prune_info,
                 threshold=1e-3,
                 ema_only=False):
    r"""Dynamic network shrinkage to discard dead atomic blocks.

    Args:
        model_wrapper: model to be shrunk.
        ema: An instance of `ExponentialMovingAverage`, could be None.
        optimizer: Global optimizer.
        prune_info: An instance of `PruneInfo`, could be None.
        threshold: A small enough constant.
        ema_only: If `True`, regard an atomic block as dead only when
            `$$\hat{alpha} \le threshold$$`. Otherwise use both the current
            value and the momentum version.
    """
    model = mc.unwrap_model(model_wrapper)
    for block_name, block in model.get_named_block_list().items():
        # inverted residual blocks
        assert isinstance(block, mb.InvertedResidualChannels)
        masks = [
            bn.weight.detach().abs() > threshold
            for bn in block.get_depthwise_bn()
        ]
        if ema is not None:
            masks_ema = [
                ema.average('{}.{}.weight'.format(
                    block_name, name)).detach().abs() > threshold
                for name in block.get_named_depthwise_bn().keys()
            ]
            if not ema_only:
                masks = [
                    mask0 | mask1 for mask0, mask1 in zip(masks, masks_ema)
                ]
            else:
                masks = masks_ema
        block.compress_by_mask(masks,
                               ema=ema,
                               optimizer=optimizer,
                               prune_info=prune_info,
                               prefix=block_name,
                               verbose=False)

    if optimizer is not None:
        assert set(optimizer.param_groups[0]['params']) == set(
            model.parameters())

    mc.model_profiling(model,
                       FLAGS.image_size,
                       FLAGS.image_size,
                       num_forwards=0,
                       verbose=False)
    if udist.is_master():
        logging.info('Model Shrink to FLOPS: {}'.format(model.n_macs))
        logging.info('Current model: {}'.format(mb.output_network(model)))

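# Standalone sketch of the mask rule used in `shrink_model` above: a channel
# is kept when the absolute value of its BatchNorm scale exceeds the
# threshold. The tensor values below are made up for illustration only.
import torch

bn_weight = torch.tensor([0.2000, 0.0004, -0.0500, 0.0009, 0.3000])
threshold = 1e-3

mask = bn_weight.detach().abs() > threshold
print(mask)                    # tensor([ True, False,  True, False,  True])
print(int(mask.sum().item()))  # 3 channels survive out of 5
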
def wrap(*args, **kw):
    if is_master():
        ts = time.time()
        result = f(*args, **kw)
        te = time.time()
        mprint('func:{!r} took: {:2.4f} sec'.format(f.__name__, te - ts))
    else:
        result = f(*args, **kw)
    return result

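# `wrap` above closes over `f`, so it is the inner function of a timing
# decorator. A minimal self-contained sketch of such a decorator is shown
# below; the outer name `timing` and the `is_master`/`mprint` stand-ins are
# assumptions, not the repo's actual definitions.
import functools
import time


def is_master():  # stand-in for the repo's rank check
    return True


def mprint(msg):  # stand-in for the repo's master-only print
    print(msg)


def timing(f):
    @functools.wraps(f)
    def wrap(*args, **kw):
        if is_master():
            ts = time.time()
            result = f(*args, **kw)
            te = time.time()
            mprint('func:{!r} took: {:2.4f} sec'.format(f.__name__, te - ts))
        else:
            result = f(*args, **kw)
        return result
    return wrap


@timing
def busy_work(n):
    return sum(range(n))


busy_work(10**6)
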
def run(self, epoch, loader, model, FLAGS):
    model.eval()
    dataset = loader.dataset
    data_iterator = iter(loader)
    results = []

    if udist.is_master():
        prog_bar = mmcv.ProgressBar(len(dataset))

    for batch_idx, input in enumerate(data_iterator):
        imgs = input['img']
        img_metas = input['img_metas'][0].data
        assert len(imgs) == len(img_metas)
        for img_meta in img_metas:
            ori_shapes = [_['ori_shape'] for _ in img_meta]
            assert all(shape == ori_shapes[0] for shape in ori_shapes)
            img_shapes = [_['img_shape'] for _ in img_meta]
            assert all(shape == img_shapes[0] for shape in img_shapes)
            pad_shapes = [_['pad_shape'] for _ in img_meta]
            assert all(shape == pad_shapes[0] for shape in pad_shapes)

        if len(imgs) == 1:
            result = self.simple_test(
                model,
                imgs[0].cuda() if FLAGS.single_gpu_test else imgs[0],
                img_metas[0])
        else:
            result = self.aug_test(model, imgs, img_metas)
        results.extend(result)

        if udist.is_master():
            batch_size = imgs[0].size(0)
            world_size = 1 if FLAGS.single_gpu_test else get_world_size()
            for _ in range(batch_size * world_size):
                prog_bar.update()

    if not FLAGS.single_gpu_test:
        results = collect_results_cpu(results, len(dataset))

    performance = None
    if udist.is_master():
        performance = dataset.evaluate(results)
    dist.barrier()
    # dist.broadcast(performance, 0)
    return performance

def val(): """Validation.""" torch.backends.cudnn.benchmark = True # model model, model_wrapper = mc.get_model() ema = mc.setup_ema(model) criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda() # TODO(meijieru): cal loss on all GPUs instead only `cuda:0` when non # distributed # check pretrained if FLAGS.pretrained: checkpoint = torch.load(FLAGS.pretrained, map_location=lambda storage, loc: storage) if ema: ema.load_state_dict(checkpoint['ema']) ema.to(get_device(model)) model_wrapper.load_state_dict(checkpoint['model']) logging.info('Loaded model {}.'.format(FLAGS.pretrained)) if udist.is_master(): logging.info(model_wrapper) # data (train_transforms, val_transforms, test_transforms) = \ dataflow.data_transforms(FLAGS) (train_set, val_set, test_set) = dataflow.dataset(train_transforms, val_transforms, test_transforms, FLAGS) _, calib_loader, _, test_loader = dataflow.data_loader( train_set, val_set, test_set, FLAGS) if udist.is_master(): logging.info('Start testing.') FLAGS._global_step = 0 test_meters = mc.get_meters('test') validate(0, calib_loader, test_loader, criterion, test_meters, model_wrapper, ema, 'test') return
def log_pruned_info(model, flops_pruned, infos, prune_threshold):
    """Log pruning-related information."""
    if udist.is_master():
        logging.info('Flops threshold: {}'.format(prune_threshold))
        for info in infos:
            if FLAGS.prune_params['logging_verbose']:
                logging.info(
                    'layer {}, total channel: {}, pruned channel: {}, flops'
                    ' total: {}, flops pruned: {}, pruned rate: {:.3f}'.format(
                        *info))
            mc.summary_writer.add_scalar(
                'prune_ratio/{}/{}'.format(prune_threshold, info[0]), info[-1],
                FLAGS._global_step)
        logging.info('Pruned model: {}'.format(
            prune.output_searched_network(model, infos, FLAGS.prune_params)))

    flops_remain = model.n_macs - flops_pruned
    if udist.is_master():
        logging.info(
            'Prune threshold: {}, flops pruned: {}, flops remain: {}'.format(
                prune_threshold, flops_pruned, flops_remain))
        mc.summary_writer.add_scalar('prune/flops/{}'.format(prune_threshold),
                                     flops_remain, FLAGS._global_step)

def main(): """Entry.""" # init distributed global is_root_rank if FLAGS.use_distributed: udist.init_dist() FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size if FLAGS.bn_calibration: FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size FLAGS.data_loader_workers = round(FLAGS.data_loader_workers / udist.get_local_size()) is_root_rank = udist.is_master() else: count = torch.cuda.device_count() FLAGS.batch_size = count * FLAGS.per_gpu_batch_size FLAGS._loader_batch_size = FLAGS.batch_size if FLAGS.bn_calibration: FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size * count is_root_rank = True FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch) # NOTE: don't drop last batch, thus must use ceil, otherwise learning rate # will be negative FLAGS._steps_per_epoch = int(np.ceil(NUM_IMAGENET_TRAIN / FLAGS.batch_size)) if is_root_rank: FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir, time.strftime("%Y%m%d-%H%M%S")) create_exp_dir( FLAGS.log_dir, FLAGS.config_path, blacklist_dirs=[ 'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak', ], ) setup_logging(FLAGS.log_dir) for k, v in _ENV_EXPAND.items(): logging.info('Env var expand: {} to {}'.format(k, v)) logging.info(FLAGS) set_random_seed(FLAGS.get('random_seed', 0)) with SummaryWriterManager(): train_val_test()
def main(): """Entry.""" FLAGS.test_only = True mc.setup_distributed() if udist.is_master(): FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir, time.strftime("%Y%m%d-%H%M%S-eval")) setup_logging(FLAGS.log_dir) for k, v in _ENV_EXPAND.items(): logging.info('Env var expand: {} to {}'.format(k, v)) logging.info(FLAGS) set_random_seed(FLAGS.get('random_seed', 0)) with mc.SummaryWriterManager(): val()
def main(): """Entry.""" NUM_IMAGENET_TRAIN = 1281167 if FLAGS.dataset == 'cityscapes': NUM_IMAGENET_TRAIN = 2975 elif FLAGS.dataset == 'ade20k': NUM_IMAGENET_TRAIN = 20210 elif FLAGS.dataset == 'coco': NUM_IMAGENET_TRAIN = 149813 mc.setup_distributed(NUM_IMAGENET_TRAIN) if FLAGS.net_params and FLAGS.model_kwparams.task == 'segmentation': tag, input_channels, block1, block2, block3, block4, last_channel = FLAGS.net_params.split( '-') input_channels = [int(item) for item in input_channels.split('_')] block1 = [int(item) for item in block1.split('_')] block2 = [int(item) for item in block2.split('_')] block3 = [int(item) for item in block3.split('_')] block4 = [int(item) for item in block4.split('_')] last_channel = int(last_channel) inverted_residual_setting = [] for item in [block1, block2, block3, block4]: for _ in range(item[0]): inverted_residual_setting.append([ item[1], item[2:-int(len(item) / 2 - 1)], item[-int(len(item) / 2 - 1):] ]) FLAGS.model_kwparams.input_channel = input_channels FLAGS.model_kwparams.inverted_residual_setting = inverted_residual_setting FLAGS.model_kwparams.last_channel = last_channel if udist.is_master(): FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir, time.strftime("%Y%m%d-%H%M%S")) # yapf: disable create_exp_dir(FLAGS.log_dir, FLAGS.config_path, blacklist_dirs=[ 'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak', 'output']) # yapf: enable setup_logging(FLAGS.log_dir) for k, v in _ENV_EXPAND.items(): logging.info('Env var expand: {} to {}'.format(k, v)) logging.info(FLAGS) set_random_seed(FLAGS.get('random_seed', 0)) with mc.SummaryWriterManager(): train_val_test()
def run_one_epoch(epoch, loader, model, criterion, optimizer, lr_scheduler, ema, meters, max_iter=None, phase='train'): """Run one epoch.""" assert phase in ['val', 'test', 'bn_calibration' ], "phase not be in val/test/bn_calibration." model.eval() if phase == 'bn_calibration': model.apply(bn_calibration) if FLAGS.use_distributed: loader.sampler.set_epoch(epoch) data_iterator = iter(loader) if FLAGS.use_distributed: data_fetcher = dataflow.DataPrefetcher(data_iterator) else: # TODO(meijieru): prefetch for non distributed logging.warning('Not use prefetcher') data_fetcher = data_iterator for batch_idx, (input, target) in enumerate(data_fetcher): # used for bn calibration if max_iter is not None: assert phase == 'bn_calibration' if batch_idx >= max_iter: break target = target.cuda(non_blocking=True) mc.forward_loss(model, criterion, input, target, meters) results = mc.reduce_and_flush_meters(meters) if udist.is_master(): logging.info( 'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase) + ', '.join('{}: {:.4f}'.format(k, v) for k, v in results.items())) for k, v in results.items(): mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v, FLAGS._global_step) return results
def init_weights(self):
    if udist.is_master():
        logging.info('=> init weights from normal distribution')
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            if not self.initial_for_heatmap:
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            else:
                nn.init.normal_(m.weight, std=0.001)
            for name, _ in m.named_parameters():
                if name in ['bias']:
                    nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

def dump_flops_stats(iteration, config, model):
    sample_flops = []
    num_sample = config.arch.num_flops_stats_sample
    for _ in range(num_sample):
        model.module.direct_sampling()
        cur_flops = calc_model_flops(model, config.dataset.input_size)
        sample_flops.append(cur_flops)

    if dist.is_master():
        save_folder = os.path.join(config.save_path, 'flops_stats')
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        plt.hist(sample_flops, 50, density=True, facecolor='g', alpha=0.75)
        plt.axvline(x=np.mean(sample_flops), color='r', linestyle='--')
        pp = PdfPages(os.path.join(save_folder,
                                   'flops_stats_{}.pdf'.format(iteration)))
        plt.savefig(pp, format='pdf')
        pp.close()
        plt.gcf().clear()

def __init__(self,
             inp,
             oup,
             stride,
             expand_ratio,
             kernel_sizes,
             active_fn=None,
             batch_norm_kwargs=None,
             **kwargs):

    def _expand_ratio_to_hiddens(expand_ratio):
        if isinstance(expand_ratio, list):
            assert len(expand_ratio) == len(kernel_sizes)
            expand = True
        elif isinstance(expand_ratio, numbers.Number):
            expand = expand_ratio != 1
            expand_ratio = [expand_ratio for _ in kernel_sizes]
        else:
            raise ValueError(
                'Unknown expand_ratio type: {}'.format(expand_ratio))
        hidden_dims = [int(round(inp * e)) for e in expand_ratio]
        return hidden_dims, expand

    hidden_dims, expand = _expand_ratio_to_hiddens(expand_ratio)
    if checkpoint_kwparams:
        assert oup == checkpoint_kwparams[0][0]
        if udist.is_master():
            logging.info('loading: {} -> {}, {} -> {}'.format(
                hidden_dims, checkpoint_kwparams[0][4], kernel_sizes,
                checkpoint_kwparams[0][3]))
        hidden_dims = checkpoint_kwparams[0][4]
        kernel_sizes = checkpoint_kwparams[0][3]
        checkpoint_kwparams.pop(0)

    super(InvertedResidual, self).__init__(inp,
                                           oup,
                                           stride,
                                           hidden_dims,
                                           kernel_sizes,
                                           expand,
                                           active_fn=active_fn,
                                           batch_norm_kwargs=batch_norm_kwargs)
    self.expand_ratio = expand_ratio

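# Standalone rewrite of `_expand_ratio_to_hiddens` above to show its two
# accepted input forms. The function name and the example values here are
# for illustration only.
import numbers


def expand_ratio_to_hiddens(inp, expand_ratio, kernel_sizes):
    if isinstance(expand_ratio, list):
        assert len(expand_ratio) == len(kernel_sizes)
        expand = True
    elif isinstance(expand_ratio, numbers.Number):
        expand = expand_ratio != 1
        expand_ratio = [expand_ratio for _ in kernel_sizes]
    else:
        raise ValueError('Unknown expand_ratio type: {}'.format(expand_ratio))
    hidden_dims = [int(round(inp * e)) for e in expand_ratio]
    return hidden_dims, expand


print(expand_ratio_to_hiddens(32, 6, [3, 5, 7]))    # ([192, 192, 192], True)
print(expand_ratio_to_hiddens(32, [3, 6], [3, 5]))  # ([96, 192], True)
print(expand_ratio_to_hiddens(16, 1, [3]))          # ([16], False)
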
def main(): """Entry.""" NUM_IMAGENET_TRAIN = 1281167 mc.setup_distributed(NUM_IMAGENET_TRAIN) if udist.is_master(): FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir, time.strftime("%Y%m%d-%H%M%S")) # yapf: disable create_exp_dir(FLAGS.log_dir, FLAGS.config_path, blacklist_dirs=[ 'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak']) # yapf: enable setup_logging(FLAGS.log_dir) for k, v in _ENV_EXPAND.items(): logging.info('Env var expand: {} to {}'.format(k, v)) logging.info(FLAGS) set_random_seed(FLAGS.get('random_seed', 0)) with mc.SummaryWriterManager(): train_val_test()
def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
                     split):
    """Load annotation from directory.

    Args:
        img_dir (str): Path to image directory.
        img_suffix (str): Suffix of images.
        ann_dir (str|None): Path to annotation directory.
        seg_map_suffix (str|None): Suffix of segmentation maps.
        split (str|None): Split txt file. If split is specified, only files
            listed in the split will be loaded. Otherwise, all images in
            img_dir/ann_dir will be loaded. Default: None

    Returns:
        list[dict]: All image info of dataset.
    """
    img_infos = []
    if split is not None:
        with open(split) as f:
            for line in f:
                img_name = line.strip()
                img_file = osp.join(img_dir, img_name + img_suffix)
                img_info = dict(filename=img_file)
                if ann_dir is not None:
                    seg_map = osp.join(ann_dir, img_name + seg_map_suffix)
                    img_info['ann'] = dict(seg_map=seg_map)
                img_infos.append(img_info)
    else:
        for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
            img_file = osp.join(img_dir, img)
            img_info = dict(filename=img_file)
            if ann_dir is not None:
                seg_map = osp.join(ann_dir,
                                   img.replace(img_suffix, seg_map_suffix))
                img_info['ann'] = dict(seg_map=seg_map)
            img_infos.append(img_info)

    if udist.is_master():
        print(f'Loaded {len(img_infos)} images')
    return img_infos

def get_model():
    """Build and init model with wrapper for parallel."""
    model_lib = importlib.import_module(FLAGS.model)
    model = model_lib.Model(**FLAGS.model_kwparams, input_size=FLAGS.image_size)
    if FLAGS.reset_parameters:
        init_method = FLAGS.get('reset_param_method', None)
        if init_method is None:
            pass  # fall back to model's initialization
        elif init_method == 'slimmable':
            model.apply(mb.init_weights_slimmable)
        elif init_method == 'mnas':
            model.apply(mb.init_weights_mnas)
        else:
            raise ValueError('Unknown init method: {}'.format(init_method))
        if udist.is_master():
            logging.info('Init model by: {}'.format(init_method))
    if FLAGS.use_distributed:
        model_wrapper = udist.AllReduceDistributedDataParallel(model.cuda())
    else:
        model_wrapper = torch.nn.DataParallel(model).cuda()
    return model, model_wrapper

def setup_ema(model):
    """Setup EMA for model's weights."""
    from utils import optim
    ema = None
    if FLAGS.moving_average_decay > 0.0:
        if FLAGS.moving_average_decay_adjust:
            moving_average_decay = \
                optim.ExponentialMovingAverage.adjust_momentum(
                    FLAGS.moving_average_decay,
                    FLAGS.moving_average_decay_base_batch / FLAGS.batch_size)
        else:
            moving_average_decay = FLAGS.moving_average_decay
        if udist.is_master():
            logging.info('Moving average for model parameters: {}'.format(
                moving_average_decay))
        ema = optim.ExponentialMovingAverage(moving_average_decay)
        for name, param in model.named_parameters():
            ema.register(name, param)
        # We maintain mva for batch norm moving mean and variance as well.
        for name, buffer in model.named_buffers():
            if 'running_var' in name or 'running_mean' in name:
                ema.register(name, buffer)
    return ema

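# Minimal sketch of what an EMA helper like the one registered above
# typically does; this is NOT the repo's `utils.optim.ExponentialMovingAverage`,
# only the standard update rule shadow <- decay * shadow + (1 - decay) * value.
import torch


class SimpleEMA:
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, name, tensor):
        self.shadow[name] = tensor.detach().clone()

    def __call__(self, name, tensor, step=None):
        self.shadow[name].mul_(self.decay).add_(tensor.detach(),
                                                alpha=1 - self.decay)
        return self.shadow[name]


ema = SimpleEMA(decay=0.999)
w = torch.zeros(3)
ema.register('w', w)
w += 1.0
print(ema('w', w))  # tensor([0.0010, 0.0010, 0.0010])
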
def run_one_epoch(epoch, loader, model, criterion, optimizer, lr_scheduler, ema, meters, max_iter=None, phase='train'): """Run one epoch.""" assert phase in ['train', 'val', 'test', 'bn_calibration' ], "phase not be in train/val/test/bn_calibration." train = phase == 'train' if train: model.train() else: model.eval() if phase == 'bn_calibration': model.apply(bn_calibration) if FLAGS.use_distributed: loader.sampler.set_epoch(epoch) results = None data_iterator = iter(loader) if FLAGS.use_distributed: data_fetcher = dataflow.DataPrefetcher(data_iterator) else: # TODO(meijieru): prefetch for non distributed logging.warning('Not use prefetcher') data_fetcher = data_iterator for batch_idx, (input, target) in enumerate(data_fetcher): # used for bn calibration if max_iter is not None: assert phase == 'bn_calibration' if batch_idx >= max_iter: break target = target.cuda(non_blocking=True) if train: optimizer.zero_grad() loss = mc.forward_loss(model, criterion, input, target, meters) loss_l2 = optim.cal_l2_loss(model, FLAGS.weight_decay, FLAGS.weight_decay_method) loss = loss + loss_l2 loss.backward() if FLAGS.use_distributed: udist.allreduce_grads(model) if FLAGS._global_step % FLAGS.log_interval == 0: results = mc.reduce_and_flush_meters(meters) if udist.is_master(): logging.info('Epoch {}/{} Iter {}/{} {}: '.format( epoch, FLAGS.num_epochs, batch_idx, len(loader), phase) + ', '.join('{}: {:.4f}'.format(k, v) for k, v in results.items())) for k, v in results.items(): mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v, FLAGS._global_step) if udist.is_master( ) and FLAGS._global_step % FLAGS.log_interval == 0: mc.summary_writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], FLAGS._global_step) mc.summary_writer.add_scalar('train/l2_regularize_loss', extract_item(loss_l2), FLAGS._global_step) mc.summary_writer.add_scalar( 'train/current_epoch', FLAGS._global_step / FLAGS._steps_per_epoch, FLAGS._global_step) if FLAGS.data_loader_workers > 0: mc.summary_writer.add_scalar( 'data/train/prefetch_size', get_data_queue_size(data_iterator), FLAGS._global_step) optimizer.step() lr_scheduler.step() if FLAGS.use_distributed and FLAGS.allreduce_bn: udist.allreduce_bn(model) FLAGS._global_step += 1 # NOTE: after steps count upate if ema is not None: model_unwrap = mc.unwrap_model(model) ema_names = ema.average_names() params = get_params_by_name(model_unwrap, ema_names) for name, param in zip(ema_names, params): ema(name, param, FLAGS._global_step) else: mc.forward_loss(model, criterion, input, target, meters) if not train: results = mc.reduce_and_flush_meters(meters) if udist.is_master(): logging.info( 'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase) + ', '.join('{}: {:.4f}'.format(k, v) for k, v in results.items())) for k, v in results.items(): mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v, FLAGS._global_step) return results
def train_val_test(): """Train and val.""" torch.backends.cudnn.benchmark = True # model model, model_wrapper = mc.get_model() ema = mc.setup_ema(model) criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda() criterion_smooth = optim.CrossEntropyLabelSmooth( FLAGS.model_kwparams['num_classes'], FLAGS['label_smoothing'], reduction='none').cuda() # TODO(meijieru): cal loss on all GPUs instead only `cuda:0` when non # distributed if FLAGS.get('log_graph_only', False): if udist.is_master(): _input = torch.zeros(1, 3, FLAGS.image_size, FLAGS.image_size).cuda() _input = _input.requires_grad_(True) mc.summary_writer.add_graph(model_wrapper, (_input, ), verbose=True) return # check pretrained if FLAGS.pretrained: checkpoint = torch.load(FLAGS.pretrained, map_location=lambda storage, loc: storage) if ema: ema.load_state_dict(checkpoint['ema']) ema.to(get_device(model)) # update keys from external models if isinstance(checkpoint, dict) and 'model' in checkpoint: checkpoint = checkpoint['model'] if (hasattr(FLAGS, 'pretrained_model_remap_keys') and FLAGS.pretrained_model_remap_keys): new_checkpoint = {} new_keys = list(model_wrapper.state_dict().keys()) old_keys = list(checkpoint.keys()) for key_new, key_old in zip(new_keys, old_keys): new_checkpoint[key_new] = checkpoint[key_old] logging.info('remap {} to {}'.format(key_new, key_old)) checkpoint = new_checkpoint model_wrapper.load_state_dict(checkpoint) logging.info('Loaded model {}.'.format(FLAGS.pretrained)) optimizer = optim.get_optimizer(model_wrapper, FLAGS) # check resume training if FLAGS.resume: checkpoint = torch.load(os.path.join(FLAGS.resume, 'latest_checkpoint.pt'), map_location=lambda storage, loc: storage) model_wrapper.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) if ema: ema.load_state_dict(checkpoint['ema']) ema.to(get_device(model)) last_epoch = checkpoint['last_epoch'] lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS) lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch best_val = extract_item(checkpoint['best_val']) train_meters, val_meters = checkpoint['meters'] FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch if udist.is_master(): logging.info('Loaded checkpoint {} at epoch {}.'.format( FLAGS.resume, last_epoch)) else: lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS) # last_epoch = lr_scheduler.last_epoch last_epoch = -1 best_val = 1. 
train_meters = mc.get_meters('train') val_meters = mc.get_meters('val') FLAGS._global_step = 0 if not FLAGS.resume and udist.is_master(): logging.info(model_wrapper) if FLAGS.profiling: if 'gpu' in FLAGS.profiling: mc.profiling(model, use_cuda=True) if 'cpu' in FLAGS.profiling: mc.profiling(model, use_cuda=False) # data (train_transforms, val_transforms, test_transforms) = dataflow.data_transforms(FLAGS) (train_set, val_set, test_set) = dataflow.dataset(train_transforms, val_transforms, test_transforms, FLAGS) (train_loader, calib_loader, val_loader, test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS) if FLAGS.test_only and (test_loader is not None): if udist.is_master(): logging.info('Start testing.') test_meters = mc.get_meters('test') validate(last_epoch, calib_loader, test_loader, criterion, test_meters, model_wrapper, ema, 'test') return # already broadcast by AllReduceDistributedDataParallel # optimizer load same checkpoint/same initialization if udist.is_master(): logging.info('Start training.') for epoch in range(last_epoch + 1, FLAGS.num_epochs): # train results = run_one_epoch(epoch, train_loader, model_wrapper, criterion_smooth, optimizer, lr_scheduler, ema, train_meters, phase='train') # val results = validate(epoch, calib_loader, val_loader, criterion, val_meters, model_wrapper, ema, 'val') if results['top1_error'] < best_val: best_val = results['top1_error'] if udist.is_master(): save_status(model_wrapper, optimizer, ema, epoch, best_val, (train_meters, val_meters), os.path.join(FLAGS.log_dir, 'best_model.pt')) logging.info( 'New best validation top1 error: {:.4f}'.format(best_val)) if udist.is_master(): # save latest checkpoint save_status(model_wrapper, optimizer, ema, epoch, best_val, (train_meters, val_meters), os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt')) wandb.log( { "Validation Accuracy": 1. - results['top1_error'], "Best Validation Accuracy": 1. - best_val }, step=epoch) # NOTE(meijieru): from scheduler code, should be called after train/val # use stepwise scheduler instead # lr_scheduler.step() return
def validate(epoch,
             calib_loader,
             val_loader,
             criterion,
             val_meters,
             model_wrapper,
             ema,
             phase,
             segval=None,
             val_set=None):
    """Calibrate and validate."""
    assert phase in ['test', 'val']
    model_eval_wrapper = mc.get_ema_model(ema, model_wrapper)

    # bn_calibration
    if FLAGS.prune_params['method'] is not None:
        if FLAGS.get('bn_calibration', False):
            if not FLAGS.use_distributed:
                logging.warning(
                    'Only GPU0 is used for calibration when using DataParallel')
            with torch.no_grad():
                _ = run_one_epoch(epoch,
                                  calib_loader,
                                  model_eval_wrapper,
                                  criterion,
                                  None,
                                  None,
                                  None,
                                  None,
                                  val_meters,
                                  max_iter=FLAGS.bn_calibration_steps,
                                  phase='bn_calibration')
            if FLAGS.use_distributed:
                udist.allreduce_bn(model_eval_wrapper)

    # val
    with torch.no_grad():
        if FLAGS.model_kwparams.task == 'segmentation':
            if FLAGS.dataset == 'coco':
                results = 0
                if udist.is_master():
                    results = keypoint_val(val_set, val_loader,
                                           model_eval_wrapper.module,
                                           criterion)
            else:
                assert segval is not None
                results = segval.run(
                    epoch, val_loader,
                    model_eval_wrapper.module
                    if FLAGS.single_gpu_test else model_eval_wrapper, FLAGS)
        else:
            results = run_one_epoch(epoch,
                                    val_loader,
                                    model_eval_wrapper,
                                    criterion,
                                    None,
                                    None,
                                    None,
                                    None,
                                    val_meters,
                                    phase=phase)
    summary_bn(model_eval_wrapper, phase)
    return results, model_eval_wrapper

def train_val_test(): """Train and val.""" torch.backends.cudnn.benchmark = True # For acceleration # model model, model_wrapper = mc.get_model() ema = mc.setup_ema(model) criterion = torch.nn.CrossEntropyLoss(reduction='mean').cuda() criterion_smooth = optim.CrossEntropyLabelSmooth( FLAGS.model_kwparams['num_classes'], FLAGS['label_smoothing'], reduction='mean').cuda() if model.task == 'segmentation': criterion = CrossEntropyLoss().cuda() criterion_smooth = CrossEntropyLoss().cuda() if FLAGS.dataset == 'coco': criterion = JointsMSELoss(use_target_weight=True).cuda() criterion_smooth = JointsMSELoss(use_target_weight=True).cuda() if FLAGS.get('log_graph_only', False): if udist.is_master(): _input = torch.zeros(1, 3, FLAGS.image_size, FLAGS.image_size).cuda() _input = _input.requires_grad_(True) if isinstance(model_wrapper, (torch.nn.DataParallel, udist.AllReduceDistributedDataParallel)): mc.summary_writer.add_graph(model_wrapper.module, (_input, ), verbose=True) else: mc.summary_writer.add_graph(model_wrapper, (_input, ), verbose=True) return # check pretrained if FLAGS.pretrained: checkpoint = torch.load(FLAGS.pretrained, map_location=lambda storage, loc: storage) if ema: ema.load_state_dict(checkpoint['ema']) ema.to(get_device(model)) # update keys from external models if isinstance(checkpoint, dict) and 'model' in checkpoint: checkpoint = checkpoint['model'] if (hasattr(FLAGS, 'pretrained_model_remap_keys') and FLAGS.pretrained_model_remap_keys): new_checkpoint = {} new_keys = list(model_wrapper.state_dict().keys()) old_keys = list(checkpoint.keys()) for key_new, key_old in zip(new_keys, old_keys): new_checkpoint[key_new] = checkpoint[key_old] if udist.is_master(): logging.info('remap {} to {}'.format(key_new, key_old)) checkpoint = new_checkpoint model_wrapper.load_state_dict(checkpoint) if udist.is_master(): logging.info('Loaded model {}.'.format(FLAGS.pretrained)) optimizer = optim.get_optimizer(model_wrapper, FLAGS) # check resume training if FLAGS.resume: checkpoint = torch.load(os.path.join(FLAGS.resume, 'latest_checkpoint.pt'), map_location=lambda storage, loc: storage) model_wrapper = checkpoint['model'].cuda() model = model_wrapper.module # model = checkpoint['model'].module optimizer = checkpoint['optimizer'] for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda() # model_wrapper.load_state_dict(checkpoint['model']) # optimizer.load_state_dict(checkpoint['optimizer']) if ema: # ema.load_state_dict(checkpoint['ema']) ema = checkpoint['ema'].cuda() ema.to(get_device(model)) last_epoch = checkpoint['last_epoch'] lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS, last_epoch=(last_epoch + 1) * FLAGS._steps_per_epoch) lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch best_val = extract_item(checkpoint['best_val']) train_meters, val_meters = checkpoint['meters'] FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch if udist.is_master(): logging.info('Loaded checkpoint {} at epoch {}.'.format( FLAGS.resume, last_epoch)) else: lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS) # last_epoch = lr_scheduler.last_epoch last_epoch = -1 best_val = 1. if not FLAGS.distill: train_meters = mc.get_meters('train', FLAGS.prune_params['method']) val_meters = mc.get_meters('val') else: train_meters = mc.get_distill_meters('train', FLAGS.prune_params['method']) val_meters = mc.get_distill_meters('val') if FLAGS.model_kwparams.task == 'segmentation': best_val = 0. 
if not FLAGS.distill: train_meters = mc.get_seg_meters('train', FLAGS.prune_params['method']) val_meters = mc.get_seg_meters('val') else: train_meters = mc.get_seg_distill_meters( 'train', FLAGS.prune_params['method']) val_meters = mc.get_seg_distill_meters('val') FLAGS._global_step = 0 if not FLAGS.resume and udist.is_master(): logging.info(model_wrapper) assert FLAGS.profiling, '`m.macs` is used for calculating penalty' # if udist.is_master(): # model.apply(lambda m: print(m)) if FLAGS.profiling: if 'gpu' in FLAGS.profiling: mc.profiling(model, use_cuda=True) if 'cpu' in FLAGS.profiling: mc.profiling(model, use_cuda=False) if FLAGS.dataset == 'cityscapes': (train_set, val_set, test_set) = seg_dataflow.cityscapes_datasets(FLAGS) segval = SegVal(num_classes=19) elif FLAGS.dataset == 'ade20k': (train_set, val_set, test_set) = seg_dataflow.ade20k_datasets(FLAGS) segval = SegVal(num_classes=150) elif FLAGS.dataset == 'coco': (train_set, val_set, test_set) = seg_dataflow.coco_datasets(FLAGS) # print(len(train_set), len(val_set)) # 149813 104125 segval = None else: # data (train_transforms, val_transforms, test_transforms) = dataflow.data_transforms(FLAGS) (train_set, val_set, test_set) = dataflow.dataset(train_transforms, val_transforms, test_transforms, FLAGS) segval = None (train_loader, calib_loader, val_loader, test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS) # get bn's weights if FLAGS.prune_params.use_transformer: FLAGS._bn_to_prune, FLAGS._bn_to_prune_transformer = prune.get_bn_to_prune( model, FLAGS.prune_params) else: FLAGS._bn_to_prune = prune.get_bn_to_prune(model, FLAGS.prune_params) rho_scheduler = prune.get_rho_scheduler(FLAGS.prune_params, FLAGS._steps_per_epoch) if FLAGS.test_only and (test_loader is not None): if udist.is_master(): logging.info('Start testing.') test_meters = mc.get_meters('test') validate(last_epoch, calib_loader, test_loader, criterion, test_meters, model_wrapper, ema, 'test') return # already broadcast by AllReduceDistributedDataParallel # optimizer load same checkpoint/same initialization if udist.is_master(): logging.info('Start training.') for epoch in range(last_epoch + 1, FLAGS.num_epochs): # train results = run_one_epoch(epoch, train_loader, model_wrapper, criterion_smooth, optimizer, lr_scheduler, ema, rho_scheduler, train_meters, phase='train') if (epoch + 1) % FLAGS.eval_interval == 0: # val results, model_eval_wrapper = validate(epoch, calib_loader, val_loader, criterion, val_meters, model_wrapper, ema, 'val', segval, val_set) if FLAGS.prune_params['method'] is not None and FLAGS.prune_params[ 'bn_prune_filter'] is not None: prune_threshold = FLAGS.model_shrink_threshold # 1e-3 masks = prune.cal_mask_network_slimming_by_threshold( get_prune_weights(model_eval_wrapper), prune_threshold ) # get mask for all bn weights (depth-wise) FLAGS._bn_to_prune.add_info_list('mask', masks) flops_pruned, infos = prune.cal_pruned_flops( FLAGS._bn_to_prune) log_pruned_info(mc.unwrap_model(model_eval_wrapper), flops_pruned, infos, prune_threshold) if not FLAGS.distill: if flops_pruned >= FLAGS.model_shrink_delta_flops \ or epoch == FLAGS.num_epochs - 1: ema_only = (epoch == FLAGS.num_epochs - 1) shrink_model(model_wrapper, ema, optimizer, FLAGS._bn_to_prune, prune_threshold, ema_only) model_kwparams = mb.output_network(mc.unwrap_model(model_wrapper)) if udist.is_master(): if FLAGS.model_kwparams.task == 'classification' and results[ 'top1_error'] < best_val: best_val = results['top1_error'] logging.info( 'New best validation top1 error: 
{:.4f}'.format( best_val)) save_status(model_wrapper, model_kwparams, optimizer, ema, epoch, best_val, (train_meters, val_meters), os.path.join(FLAGS.log_dir, 'best_model')) elif FLAGS.model_kwparams.task == 'segmentation' and FLAGS.dataset != 'coco' and results[ 'mIoU'] > best_val: best_val = results['mIoU'] logging.info('New seg mIoU: {:.4f}'.format(best_val)) save_status(model_wrapper, model_kwparams, optimizer, ema, epoch, best_val, (train_meters, val_meters), os.path.join(FLAGS.log_dir, 'best_model')) elif FLAGS.dataset == 'coco' and results > best_val: best_val = results logging.info('New Result: {:.4f}'.format(best_val)) save_status(model_wrapper, model_kwparams, optimizer, ema, epoch, best_val, (train_meters, val_meters), os.path.join(FLAGS.log_dir, 'best_model')) # save latest checkpoint save_status(model_wrapper, model_kwparams, optimizer, ema, epoch, best_val, (train_meters, val_meters), os.path.join(FLAGS.log_dir, 'latest_checkpoint')) return