Example No. 1
def sample_model(config, model):
    num_sample = config.arch.num_model_sample
    root_dir = os.path.join(config.save_path, 'model_sample')

    if dist.is_master():
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)

    for i in range(num_sample + 1):
        cur_flops = config.arch.target_flops * 10
        while cur_flops > config.arch.target_flops * 1.01 or cur_flops < config.arch.target_flops * 0.99:
            model.module.direct_sampling()
            cur_flops = calc_model_flops(model, config.dataset.input_size)

        if dist.is_master():
            if i == num_sample:
                sample_dir = os.path.join(root_dir, 'expected_ch')
                model.module.expected_sampling()
            else:
                sample_dir = os.path.join(root_dir, 'sample_{}'.format(i + 1))
            if not os.path.exists(sample_dir):
                os.makedirs(sample_dir)

            save_path = os.path.join(sample_dir, 'sample.npy')
            if config.model.type.find('MobileNetV2') > -1:
                dump_mbv2_setting(model.module, save_path)
            elif config.model.type.find('ResNet') > -1:
                dump_resnet_setting(model.module, save_path)
            else:
                raise ValueError('Unknown model: {}'.format(config.model.type))

    dist.barrier()
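
The inner while loop is rejection sampling: architectures are redrawn until their FLOPs land within ±1% of the target. A minimal, self-contained sketch of that pattern, where sample_fn and flops_fn are hypothetical stand-ins for model.module.direct_sampling and calc_model_flops:

def sample_within_flops(sample_fn, flops_fn, target_flops, tol=0.01, max_tries=1000):
    # Redraw until the measured FLOPs fall inside [target*(1-tol), target*(1+tol)].
    for _ in range(max_tries):
        sample_fn()
        flops = flops_fn()
        if target_flops * (1 - tol) <= flops <= target_flops * (1 + tol):
            return flops
    raise RuntimeError('no sample within {:.0%} of the target FLOPs '
                       'after {} tries'.format(tol, max_tries))
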
Example No. 2
def main():
    args = tools.get_args(parser)
    config = tools.get_config(args)
    tools.init(config)
    tb_logger, logger = tools.get_logger(config)
    tools.check_dist_init(config, logger)

    checkpoint = tools.get_checkpoint(config)
    runner = tools.get_model(config, checkpoint)
    loaders = tools.get_data_loader(config)

    if dist.is_master():
        logger.info(config)

    if args.mode == 'train':
        train(config, runner, loaders, checkpoint, tb_logger)
    elif args.mode == 'evaluate':
        evaluate(runner, loaders)
    elif args.mode == 'calc_flops':
        if dist.is_master():
            flops = tools.get_model_flops(config, runner.get_model())
            logger.info('flops: {}'.format(flops))
    elif args.mode == 'calc_params':
        if dist.is_master():
            params = tools.get_model_parameters(runner.get_model())
            logger.info('params: {}'.format(params))
    else:
        assert checkpoint is not None
        from models.dmcp.utils import sample_model
        sample_model(config, runner.get_model())

    if dist.is_master():
        logger.info('Done')
Example No. 3
def get_logger(config, name='global_logger'):
    save_dir = config.model.type + '_'
    if config.get('arch', False):
        save_dir += str(config.arch.target_flops) + '_'
    save_dir = time.strftime(save_dir + '%m%d%H')
    save_dir = os.path.join(config.save_path, save_dir)

    if dist.is_master():
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
    else:
        while not os.path.exists(save_dir):
            time.sleep(1)
    config.save_path = save_dir

    events_dir = config.save_path + '/events'
    if dist.is_master():
        if not os.path.exists(events_dir):
            os.makedirs(events_dir)
    else:
        while not os.path.exists(events_dir):
            time.sleep(1)

    tb_logger = SummaryWriter(config.save_path + '/events')
    logger = logging.getLogger(name)
    formatter = logging.Formatter('[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s')
    fh = logging.FileHandler(config.save_path + '/log.txt')
    fh.setFormatter(formatter)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.setLevel(logging.INFO)
    logger.addHandler(fh)
    logger.addHandler(sh)

    return tb_logger, logger
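
Non-master ranks above poll the filesystem with time.sleep until the master has created the directory. When a process group is initialized, the same synchronization can be expressed with a barrier; a hedged sketch of that alternative (not the repository's own helper):

import os
import torch.distributed as dist

def make_shared_dir(path):
    # Rank 0 creates the directory; the barrier guarantees it exists on the
    # shared filesystem before any other rank continues.
    if not dist.is_initialized() or dist.get_rank() == 0:
        os.makedirs(path, exist_ok=True)
    if dist.is_initialized():
        dist.barrier()
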
Example No. 4
def profiling(model, use_cuda):
    """Profiling on either gpu or cpu."""
    if udist.is_master():
        logging.info('Start model profiling, use_cuda:{}.'.format(use_cuda))
    model_profiling(model,
                    FLAGS.image_size,
                    FLAGS.image_size,
                    verbose=getattr(FLAGS, 'model_profiling_verbose', True)
                    and udist.is_master())
Example No. 5
def layer_flops_distribution(config, model):
    num_sample = config.arch.num_flops_stats_sample
    repo = {}

    for _ in range(num_sample):
        cur_flops = config.arch.target_flops * 10
        while cur_flops > config.arch.target_flops * 1.05 or cur_flops < config.arch.target_flops * 0.95:
            model.module.direct_sampling()
            cur_flops = calc_model_flops(model, config.dataset.input_size)
            for n, m in model.named_modules():
                if isinstance(m, USModule):
                    if n not in repo.keys():
                        repo[n] = []
                    repo[n].append(m.cur_out_ch)

    if dist.is_master():
        root_dir = os.path.join(config.save_path, 'layer_flops_distribution')
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)
        for n in repo.keys():
            save_path = os.path.join(root_dir, n + '.pdf')
            plt.hist(repo[n], 50, density=True, facecolor='g', alpha=0.75)
            pp = PdfPages(save_path)
            plt.savefig(pp, format='pdf')
            pp.close()
            plt.gcf().clear()
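
For reference, a standalone version of the histogram-to-PDF step used above, written against the documented PdfPages API (data and file name are placeholders):

import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend, as on most training clusters
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

samples = np.random.randn(1000)  # placeholder for repo[n]
with PdfPages('channel_hist.pdf') as pdf:
    plt.hist(samples, 50, density=True, facecolor='g', alpha=0.75)
    pdf.savefig()      # saves the current figure into the PDF
    plt.gcf().clear()
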
Example No. 6
def compute_mean_channel(iteration, config, model):
    mean_chs = []
    offset = []
    num_sample = config.arch.num_flops_stats_sample

    for n, m in model.named_modules():
        if n.find('alpha') > -1:
            mean_ch = 0
            for i in range(num_sample):
                mean_ch += m.direct_sampling()
            mean_chs.append(round(mean_ch / num_sample, 3))
            offset.append(m.channels - mean_chs[-1])

    if dist.is_master():
        ind = np.arange(len(mean_chs))
        width = 0.35

        p1 = plt.bar(ind, mean_chs, width)
        p2 = plt.bar(ind, offset, width, bottom=mean_chs)

        plt.ylabel('#channel')
        plt.xticks(ind, ind)
        plt.yticks(np.arange(0, max([mean_chs[i] + offset[i] for i in range(len(mean_chs))]), 100))
        plt.legend((p1[0], p2[0]), ('expected', 'max'))

        save_folder = os.path.join(config.save_path, 'mean_chs')
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
        pp = PdfPages(os.path.join(save_folder, 'mean_chs_{}.pdf'.format(iteration)))
        plt.savefig(pp, format='pdf')
        pp.close()
        plt.gcf().clear()
Example No. 7
    def _logging(self, tb_logger, epoch_idx, batch_idx, total_batch, meters,
                 cur_lr):
        print_freq = self.config.logging.print_freq
        top1_meter, top5_meter, loss_meter, data_time, batch_time = meters

        if self.cur_step % print_freq == 0 and dist.is_master():
            tb_logger.add_scalar('lr', cur_lr, self.cur_step)
            tb_logger.add_scalar('acc1_train', top1_meter.avg, self.cur_step)
            tb_logger.add_scalar('acc5_train', top5_meter.avg, self.cur_step)
            tb_logger.add_scalar('loss_train', loss_meter.avg, self.cur_step)
            self._info('-' * 80)
            self._info('Epoch: [{0}/{1}]\tIter: [{2}/{3}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                       'LR {lr:.4f}'.format(epoch_idx,
                                            self.config.training.epoch,
                                            batch_idx,
                                            total_batch,
                                            batch_time=batch_time,
                                            data_time=data_time,
                                            lr=cur_lr))
            self._info('Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                       'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                           loss=loss_meter, top1=top1_meter, top5=top5_meter))
Example No. 8
    def _logging(self, tb_logger, epoch_idx, batch_idx, total_batch, meters,
                 cur_lr):
        print_freq = self.config.logging.print_freq
        top1_meter, top5_meter, loss_meter, data_time, batch_time = meters

        if self.cur_step % print_freq == 0 and dist.is_master():
            tb_logger.add_scalar('lr', cur_lr, self.cur_step)
            self._info('-' * 80)
            self._info('Epoch: [{0}/{1}]\tIter: [{2}/{3}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                       'LR {lr:.4f}'.format(epoch_idx,
                                            self.config.training.epoch,
                                            batch_idx,
                                            total_batch,
                                            batch_time=batch_time,
                                            data_time=data_time,
                                            lr=cur_lr))

            titles = ['min_width', 'max_width', 'random_width']
            for idx in range(3):
                tb_logger.add_scalar('loss_train@{}'.format(titles[idx]),
                                     loss_meter[idx].avg, self.cur_step)
                tb_logger.add_scalar('acc1_train@{}'.format(titles[idx]),
                                     top1_meter[idx].avg, self.cur_step)
                tb_logger.add_scalar('acc5_train@{}'.format(titles[idx]),
                                     top5_meter[idx].avg, self.cur_step)
                self._info('{title}\t'
                           'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                           'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                           'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                               title=titles[idx],
                               loss=loss_meter[idx],
                               top1=top1_meter[idx],
                               top5=top5_meter[idx]))
Example No. 9
    def _logging(self, tb_logger, epoch_idx, batch_idx, total_batch, meters,
                 cur_lr):
        cur_lr, cur_arch_lr = cur_lr
        top1_meter, top5_meter, loss_meter, arch_loss_meter, floss_meter, \
            eflops_meter, arch_top1_meter, data_time, batch_time = meters

        super(DMCPRunner, self)._logging(
            tb_logger, epoch_idx, batch_idx, total_batch,
            [top1_meter, top5_meter, loss_meter, data_time, batch_time],
            cur_lr)

        print_freq = self.config.logging.print_freq
        if self.cur_step % print_freq == 0 and dist.is_master() \
                and self.cur_step >= self.config.arch.start_train:
            tb_logger.add_scalar('arc_loss', arch_loss_meter.avg,
                                 self.cur_step)
            tb_logger.add_scalar('flops_loss', floss_meter.avg, self.cur_step)
            tb_logger.add_scalar('eflops', eflops_meter.avg, self.cur_step)
            tb_logger.add_scalar('arc_top1', arch_top1_meter.avg,
                                 self.cur_step)

            self._info(
                'expected_flops {:.2f} flops_loss {:.4f}, arch_task_loss {:.4f}, '
                'arch_top1 {:.2f}, arch_lr {:.4f}'.format(
                    eflops_meter.avg, floss_meter.avg, arch_loss_meter.avg,
                    arch_top1_meter.avg, cur_arch_lr))
            tb_logger.add_scalar('expected_flops', eflops_meter.avg,
                                 self.cur_step)
Example No. 10
    def validate(self, val_loader, tb_logger=None):
        batch_time = AverageMeter(0)
        loss_meter = AverageMeter(0)
        top1_meter = AverageMeter(0)
        top5_meter = AverageMeter(0)

        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        end = time.time()

        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(val_loader):
                x, y = x.cuda(), y.cuda()
                num = x.size(0)

                out = self.model(x)
                loss = criterion(out, y)
                top1, top5 = accuracy(out, y, top_k=(1, 5))

                loss_meter.update(loss.item(), num)
                top1_meter.update(top1.item(), num)
                top5_meter.update(top5.item(), num)

                batch_time.update(time.time() - end)
                end = time.time()

                if batch_idx % self.config.logging.print_freq == 0:
                    self._info(
                        'Test: [{0}/{1}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})'
                        .format(batch_idx,
                                len(val_loader),
                                batch_time=batch_time))

        total_num = torch.tensor([loss_meter.count]).cuda()
        loss_sum = torch.tensor([loss_meter.avg * loss_meter.count]).cuda()
        top1_sum = torch.tensor([top1_meter.avg * top1_meter.count]).cuda()
        top5_sum = torch.tensor([top5_meter.avg * top5_meter.count]).cuda()

        dist.all_reduce(total_num)
        dist.all_reduce(loss_sum)
        dist.all_reduce(top1_sum)
        dist.all_reduce(top5_sum)

        val_loss = loss_sum.item() / total_num.item()
        val_top1 = top1_sum.item() / total_num.item()
        val_top5 = top5_sum.item() / total_num.item()

        self._info(
            'Prec@1 {:.3f}\tPrec@5 {:.3f}\tLoss {:.3f}\ttotal_num={}'.format(
                val_top1, val_top5, val_loss, loss_meter.count))

        if dist.is_master():
            if val_top1 > self.best_top1:
                self.best_top1 = val_top1

            if tb_logger is not None:
                tb_logger.add_scalar('loss_val', val_loss, self.cur_step)
                tb_logger.add_scalar('acc1_val', val_top1, self.cur_step)
                tb_logger.add_scalar('acc5_val', val_top5, self.cur_step)
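
The tensor arithmetic at the end of validate is a standard distributed metric average: every rank contributes its local (sum, count), and all_reduce gives each process the global mean. A compact sketch of just that step, assuming an initialized process group with one GPU per rank:

import torch
import torch.distributed as dist

def global_mean(local_avg, local_count):
    # Pack (local sum, local count), sum across ranks, divide.
    t = torch.tensor([local_avg * local_count, float(local_count)]).cuda()
    dist.all_reduce(t)
    return (t[0] / t[1]).item()
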
Example No. 11
def shrink_model(model_wrapper,
                 ema,
                 optimizer,
                 prune_info,
                 threshold=1e-3,
                 ema_only=False):
    r"""Dynamic network shrinkage to discard dead atomic blocks.

    Args:
        model_wrapper: model to be shrunk.
        ema: An instance of `ExponentialMovingAverage`, could be None.
        optimizer: Global optimizer.
        prune_info: An instance of `PruneInfo`, could be None.
        threshold: A small enough constant.
        ema_only: If `True`, regard an atomic block as dead only when
            `$$\hat{\alpha} \le threshold$$`. Otherwise use both the current
            value and the momentum version.
    """
    model = mc.unwrap_model(model_wrapper)
    for block_name, block in model.get_named_block_list().items(
    ):  # inverted residual blocks
        assert isinstance(block, mb.InvertedResidualChannels)
        masks = [
            bn.weight.detach().abs() > threshold
            for bn in block.get_depthwise_bn()
        ]
        if ema is not None:
            masks_ema = [
                ema.average('{}.{}.weight'.format(
                    block_name, name)).detach().abs() > threshold
                for name in block.get_named_depthwise_bn().keys()
            ]
            if not ema_only:
                masks = [
                    mask0 | mask1 for mask0, mask1 in zip(masks, masks_ema)
                ]
            else:
                masks = masks_ema
        block.compress_by_mask(masks,
                               ema=ema,
                               optimizer=optimizer,
                               prune_info=prune_info,
                               prefix=block_name,
                               verbose=False)

    if optimizer is not None:
        assert set(optimizer.param_groups[0]['params']) == set(
            model.parameters())

    mc.model_profiling(model,
                       FLAGS.image_size,
                       FLAGS.image_size,
                       num_forwards=0,
                       verbose=False)
    if udist.is_master():
        logging.info('Model Shrink to FLOPS: {}'.format(model.n_macs))
        logging.info('Current model: {}'.format(mb.output_network(model)))
Example No. 12
    def wrap(*args, **kw):
        if is_master():
            ts = time.time()
            result = f(*args, **kw)
            te = time.time()
            mprint('func:{!r} took: {:2.4f} sec'.format(f.__name__, te-ts))
        else:
            result = f(*args, **kw)
        return result
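
The snippet above is the inner wrapper of a timing decorator. A self-contained sketch of the full pattern, with is_master and mprint replaced by trivial stand-ins:

import functools
import time

def is_master():   # stand-in: real code checks the distributed rank
    return True

def mprint(msg):   # stand-in for a master-only print helper
    print(msg)

def timed(f):
    @functools.wraps(f)
    def wrap(*args, **kw):
        if is_master():
            ts = time.time()
            result = f(*args, **kw)
            te = time.time()
            mprint('func:{!r} took: {:2.4f} sec'.format(f.__name__, te - ts))
        else:
            result = f(*args, **kw)
        return result
    return wrap

@timed
def add(a, b):
    return a + b
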
Example No. 13
    def run(self, epoch, loader, model, FLAGS):
        model.eval()
        dataset = loader.dataset
        data_iterator = iter(loader)

        results = []
        if udist.is_master():
            prog_bar = mmcv.ProgressBar(len(dataset))
        for batch_idx, input in enumerate(data_iterator):
            imgs = input['img']
            img_metas = input['img_metas'][0].data
            assert len(imgs) == len(img_metas)
            for img_meta in img_metas:
                ori_shapes = [_['ori_shape'] for _ in img_meta]
                assert all(shape == ori_shapes[0] for shape in ori_shapes)
                img_shapes = [_['img_shape'] for _ in img_meta]
                assert all(shape == img_shapes[0] for shape in img_shapes)
                pad_shapes = [_['pad_shape'] for _ in img_meta]
                assert all(shape == pad_shapes[0] for shape in pad_shapes)

            if len(imgs) == 1:
                result = self.simple_test(
                    model,
                    imgs[0].cuda() if FLAGS.single_gpu_test else imgs[0],
                    img_metas[0])
            else:
                result = self.aug_test(model, imgs, img_metas)
            results.extend(result)
            if udist.is_master():
                batch_size = imgs[0].size(0)
                world_size = 1 if FLAGS.single_gpu_test else get_world_size()
                for _ in range(batch_size * world_size):
                    prog_bar.update()
        if not FLAGS.single_gpu_test:
            results = collect_results_cpu(results, len(dataset))
        performance = None
        if udist.is_master():
            performance = dataset.evaluate(results)
        dist.barrier()
        # dist.broadcast(performance, 0)
        return performance
Example No. 14
def val():
    """Validation."""
    torch.backends.cudnn.benchmark = True

    # model
    model, model_wrapper = mc.get_model()
    ema = mc.setup_ema(model)
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    # TODO(meijieru): cal loss on all GPUs instead only `cuda:0` when non
    # distributed

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        model_wrapper.load_state_dict(checkpoint['model'])
        logging.info('Loaded model {}.'.format(FLAGS.pretrained))

    if udist.is_master():
        logging.info(model_wrapper)

    # data
    (train_transforms, val_transforms, test_transforms) = \
        dataflow.data_transforms(FLAGS)
    (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                      val_transforms,
                                                      test_transforms, FLAGS)
    _, calib_loader, _, test_loader = dataflow.data_loader(
        train_set, val_set, test_set, FLAGS)

    if udist.is_master():
        logging.info('Start testing.')
    FLAGS._global_step = 0
    test_meters = mc.get_meters('test')
    validate(0, calib_loader, test_loader, criterion, test_meters,
             model_wrapper, ema, 'test')
    return
Example No. 15
def log_pruned_info(model, flops_pruned, infos, prune_threshold):
    """Log pruning-related information."""
    if udist.is_master():
        logging.info('Flops threshold: {}'.format(prune_threshold))
        for info in infos:
            if FLAGS.prune_params['logging_verbose']:
                logging.info(
                    'layer {}, total channel: {}, pruned channel: {}, flops'
                    ' total: {}, flops pruned: {}, pruned rate: {:.3f}'.format(
                        *info))
            mc.summary_writer.add_scalar(
                'prune_ratio/{}/{}'.format(prune_threshold, info[0]), info[-1],
                FLAGS._global_step)
        logging.info('Pruned model: {}'.format(
            prune.output_searched_network(model, infos, FLAGS.prune_params)))

    flops_remain = model.n_macs - flops_pruned
    if udist.is_master():
        logging.info(
            'Prune threshold: {}, flops pruned: {}, flops remain: {}'.format(
                prune_threshold, flops_pruned, flops_remain))
        mc.summary_writer.add_scalar('prune/flops/{}'.format(prune_threshold),
                                     flops_remain, FLAGS._global_step)
Example No. 16
def main():
    """Entry."""
    # init distributed
    global is_root_rank
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size
        FLAGS.data_loader_workers = round(FLAGS.data_loader_workers /
                                          udist.get_local_size())
        is_root_rank = udist.is_master()
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size * count
        is_root_rank = True
    FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    # NOTE: don't drop last batch, thus must use ceil, otherwise learning rate
    # will be negative
    FLAGS._steps_per_epoch = int(np.ceil(NUM_IMAGENET_TRAIN /
                                         FLAGS.batch_size))

    if is_root_rank:
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir,
                                       time.strftime("%Y%m%d-%H%M%S"))
        create_exp_dir(
            FLAGS.log_dir,
            FLAGS.config_path,
            blacklist_dirs=[
                'exp',
                '.git',
                'pretrained',
                'tmp',
                'deprecated',
                'bak',
            ],
        )
        setup_logging(FLAGS.log_dir)
        for k, v in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(k, v))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with SummaryWriterManager():
        train_val_test()
Example No. 17
def main():
    """Entry."""
    FLAGS.test_only = True
    mc.setup_distributed()
    if udist.is_master():
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir,
                                       time.strftime("%Y%m%d-%H%M%S-eval"))
        setup_logging(FLAGS.log_dir)
        for k, v in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(k, v))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with mc.SummaryWriterManager():
        val()
Example No. 18
def main():
    """Entry."""
    NUM_IMAGENET_TRAIN = 1281167
    if FLAGS.dataset == 'cityscapes':
        NUM_IMAGENET_TRAIN = 2975
    elif FLAGS.dataset == 'ade20k':
        NUM_IMAGENET_TRAIN = 20210
    elif FLAGS.dataset == 'coco':
        NUM_IMAGENET_TRAIN = 149813
    mc.setup_distributed(NUM_IMAGENET_TRAIN)

    if FLAGS.net_params and FLAGS.model_kwparams.task == 'segmentation':
        tag, input_channels, block1, block2, block3, block4, last_channel = FLAGS.net_params.split(
            '-')
        input_channels = [int(item) for item in input_channels.split('_')]
        block1 = [int(item) for item in block1.split('_')]
        block2 = [int(item) for item in block2.split('_')]
        block3 = [int(item) for item in block3.split('_')]
        block4 = [int(item) for item in block4.split('_')]
        last_channel = int(last_channel)

        inverted_residual_setting = []
        for item in [block1, block2, block3, block4]:
            for _ in range(item[0]):
                inverted_residual_setting.append([
                    item[1], item[2:-int(len(item) / 2 - 1)],
                    item[-int(len(item) / 2 - 1):]
                ])

        FLAGS.model_kwparams.input_channel = input_channels
        FLAGS.model_kwparams.inverted_residual_setting = inverted_residual_setting
        FLAGS.model_kwparams.last_channel = last_channel

    if udist.is_master():
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir,
                                       time.strftime("%Y%m%d-%H%M%S"))
        # yapf: disable
        create_exp_dir(FLAGS.log_dir, FLAGS.config_path, blacklist_dirs=[
            'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak', 'output'])
        # yapf: enable
        setup_logging(FLAGS.log_dir)
        for k, v in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(k, v))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with mc.SummaryWriterManager():
        train_val_test()
Example No. 19
def run_one_epoch(epoch,
                  loader,
                  model,
                  criterion,
                  optimizer,
                  lr_scheduler,
                  ema,
                  meters,
                  max_iter=None,
                  phase='train'):
    """Run one epoch."""
    assert phase in ['val', 'test', 'bn_calibration'
                     ], "phase must be in val/test/bn_calibration."
    model.eval()
    if phase == 'bn_calibration':
        model.apply(bn_calibration)

    if FLAGS.use_distributed:
        loader.sampler.set_epoch(epoch)

    data_iterator = iter(loader)
    if FLAGS.use_distributed:
        data_fetcher = dataflow.DataPrefetcher(data_iterator)
    else:
        # TODO(meijieru): prefetch for non distributed
        logging.warning('Not using prefetcher')
        data_fetcher = data_iterator
    for batch_idx, (input, target) in enumerate(data_fetcher):
        # used for bn calibration
        if max_iter is not None:
            assert phase == 'bn_calibration'
            if batch_idx >= max_iter:
                break

        target = target.cuda(non_blocking=True)
        mc.forward_loss(model, criterion, input, target, meters)

    results = mc.reduce_and_flush_meters(meters)
    if udist.is_master():
        logging.info(
            'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase) +
            ', '.join('{}: {:.4f}'.format(k, v) for k, v in results.items()))
        for k, v in results.items():
            mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v,
                                         FLAGS._global_step)
    return results
Example No. 20
    def init_weights(self):
        if udist.is_master():
            logging.info('=> init weights from normal distribution')
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if not self.initial_for_heatmap:
                    nn.init.kaiming_normal_(m.weight,
                                            mode='fan_out',
                                            nonlinearity='relu')
                else:
                    nn.init.normal_(m.weight, std=0.001)
                    for name, _ in m.named_parameters():
                        if name in ['bias']:
                            nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
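
The same initialization logic is often written as a free function and attached with module.apply; an illustrative sketch (not one of the repository's init_weights_* helpers):

import torch.nn as nn

def kaiming_init(m):
    # Conv layers: Kaiming normal; BatchNorm layers: weight 1, bias 0.
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

# usage: net.apply(kaiming_init)
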
Example No. 21
def dump_flops_stats(iteration, config, model):
    sample_flops = []
    num_sample = config.arch.num_flops_stats_sample

    for _ in range(num_sample):
        model.module.direct_sampling()
        cur_flops = calc_model_flops(model, config.dataset.input_size)
        sample_flops.append(cur_flops)

    if dist.is_master():
        save_folder = os.path.join(config.save_path, 'flops_stats')
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
        plt.hist(sample_flops, 50, density=True, facecolor='g', alpha=0.75)
        plt.axvline(x=np.mean(sample_flops), color='r', linestyle='--')
        pp = PdfPages(os.path.join(save_folder, 'flops_stats_{}.pdf'.format(iteration)))
        plt.savefig(pp, format='pdf')
        pp.close()
        plt.gcf().clear()
Example No. 22
    def __init__(self,
                 inp,
                 oup,
                 stride,
                 expand_ratio,
                 kernel_sizes,
                 active_fn=None,
                 batch_norm_kwargs=None,
                 **kwargs):
        def _expand_ratio_to_hiddens(expand_ratio):
            if isinstance(expand_ratio, list):
                assert len(expand_ratio) == len(kernel_sizes)
                expand = True
            elif isinstance(expand_ratio, numbers.Number):
                expand = expand_ratio != 1
                expand_ratio = [expand_ratio for _ in kernel_sizes]
            else:
                raise ValueError(
                    'Unknown expand_ratio type: {}'.format(expand_ratio))
            hidden_dims = [int(round(inp * e)) for e in expand_ratio]
            return hidden_dims, expand

        hidden_dims, expand = _expand_ratio_to_hiddens(expand_ratio)
        if checkpoint_kwparams:
            assert oup == checkpoint_kwparams[0][0]
            if udist.is_master():
                logging.info('loading: {} -> {}, {} -> {}'.format(
                    hidden_dims, checkpoint_kwparams[0][4], kernel_sizes,
                    checkpoint_kwparams[0][3]))
            hidden_dims = checkpoint_kwparams[0][4]
            kernel_sizes = checkpoint_kwparams[0][3]
            checkpoint_kwparams.pop(0)

        super(InvertedResidual,
              self).__init__(inp,
                             oup,
                             stride,
                             hidden_dims,
                             kernel_sizes,
                             expand,
                             active_fn=active_fn,
                             batch_norm_kwargs=batch_norm_kwargs)
        self.expand_ratio = expand_ratio
Example No. 23
def main():
    """Entry."""
    NUM_IMAGENET_TRAIN = 1281167

    mc.setup_distributed(NUM_IMAGENET_TRAIN)
    if udist.is_master():
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir,
                                       time.strftime("%Y%m%d-%H%M%S"))
        # yapf: disable
        create_exp_dir(FLAGS.log_dir, FLAGS.config_path, blacklist_dirs=[
            'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak'])
        # yapf: enable
        setup_logging(FLAGS.log_dir)
        for k, v in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(k, v))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with mc.SummaryWriterManager():
        train_val_test()
Example No. 24
    def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
                         split):
        """Load annotation from directory.

        Args:
            img_dir (str): Path to image directory
            img_suffix (str): Suffix of images.
            ann_dir (str|None): Path to annotation directory.
            seg_map_suffix (str|None): Suffix of segmentation maps.
            split (str|None): Split txt file. If split is specified, only files
                listed in the split will be loaded. Otherwise, all images
                in img_dir/ann_dir will be loaded. Default: None

        Returns:
            list[dict]: All image info of dataset.
        """

        img_infos = []
        if split is not None:
            with open(split) as f:
                for line in f:
                    img_name = line.strip()
                    img_file = osp.join(img_dir, img_name + img_suffix)
                    img_info = dict(filename=img_file)
                    if ann_dir is not None:
                        seg_map = osp.join(ann_dir, img_name + seg_map_suffix)
                        img_info['ann'] = dict(seg_map=seg_map)
                    img_infos.append(img_info)
        else:
            for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
                img_file = osp.join(img_dir, img)
                img_info = dict(filename=img_file)
                if ann_dir is not None:
                    seg_map = osp.join(ann_dir,
                                       img.replace(img_suffix, seg_map_suffix))
                    img_info['ann'] = dict(seg_map=seg_map)
                img_infos.append(img_info)
        if udist.is_master():
            print(f'Loaded {len(img_infos)} images')
        return img_infos
Example No. 25
def get_model():
    """Build and init model with wrapper for parallel."""
    model_lib = importlib.import_module(FLAGS.model)
    model = model_lib.Model(**FLAGS.model_kwparams,
                            input_size=FLAGS.image_size)
    if FLAGS.reset_parameters:
        init_method = FLAGS.get('reset_param_method', None)
        if init_method is None:
            pass  # fall back to model's initialization
        elif init_method == 'slimmable':
            model.apply(mb.init_weights_slimmable)
        elif init_method == 'mnas':
            model.apply(mb.init_weights_mnas)
        else:
            raise ValueError('Unknown init method: {}'.format(init_method))
        if udist.is_master():
            logging.info('Init model by: {}'.format(init_method))
    if FLAGS.use_distributed:
        model_wrapper = udist.AllReduceDistributedDataParallel(model.cuda())
    else:
        model_wrapper = torch.nn.DataParallel(model).cuda()
    return model, model_wrapper
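
get_model builds the network by importing the module named in FLAGS.model and instantiating its Model class. A stripped-down sketch of that dynamic-import step (module path and keyword arguments are placeholders):

import importlib

def build_model(module_path, **model_kwparams):
    # e.g. module_path = 'models.mobilenet_base' (illustrative only)
    model_lib = importlib.import_module(module_path)
    return model_lib.Model(**model_kwparams)
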
Example No. 26
def setup_ema(model):
    """Setup EMA for model's weights."""
    from utils import optim

    ema = None
    if FLAGS.moving_average_decay > 0.0:
        if FLAGS.moving_average_decay_adjust:
            moving_average_decay = \
                optim.ExponentialMovingAverage.adjust_momentum(
                    FLAGS.moving_average_decay,
                    FLAGS.moving_average_decay_base_batch / FLAGS.batch_size)
        else:
            moving_average_decay = FLAGS.moving_average_decay
        if udist.is_master():
            logging.info('Moving average for model parameters: {}'.format(
                moving_average_decay))
        ema = optim.ExponentialMovingAverage(moving_average_decay)
        for name, param in model.named_parameters():
            ema.register(name, param)
        # We maintain a moving average for batch norm running mean and variance as well.
        for name, buffer in model.named_buffers():
            if 'running_var' in name or 'running_mean' in name:
                ema.register(name, buffer)
    return ema
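
The ExponentialMovingAverage object is used through register, a call-style update, and average (see shrink_model and run_one_epoch above). A toy sketch of that interface with the usual update rule, purely for illustration:

import torch

class SimpleEMA:
    """Toy EMA exposing the register / __call__ / average interface."""

    def __init__(self, decay=0.999):
        self.decay = decay
        self.shadow = {}

    def register(self, name, tensor):
        self.shadow[name] = tensor.detach().clone()

    def __call__(self, name, tensor, step=None):
        # shadow <- decay * shadow + (1 - decay) * current value
        self.shadow[name].mul_(self.decay).add_(tensor.detach(),
                                                alpha=1 - self.decay)

    def average(self, name):
        return self.shadow[name]
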
Example No. 27
def run_one_epoch(epoch,
                  loader,
                  model,
                  criterion,
                  optimizer,
                  lr_scheduler,
                  ema,
                  meters,
                  max_iter=None,
                  phase='train'):
    """Run one epoch."""
    assert phase in ['train', 'val', 'test', 'bn_calibration'
                     ], "phase must be in train/val/test/bn_calibration."
    train = phase == 'train'
    if train:
        model.train()
    else:
        model.eval()
    if phase == 'bn_calibration':
        model.apply(bn_calibration)

    if FLAGS.use_distributed:
        loader.sampler.set_epoch(epoch)

    results = None
    data_iterator = iter(loader)
    if FLAGS.use_distributed:
        data_fetcher = dataflow.DataPrefetcher(data_iterator)
    else:
        # TODO(meijieru): prefetch for non distributed
        logging.warning('Not using prefetcher')
        data_fetcher = data_iterator
    for batch_idx, (input, target) in enumerate(data_fetcher):
        # used for bn calibration
        if max_iter is not None:
            assert phase == 'bn_calibration'
            if batch_idx >= max_iter:
                break

        target = target.cuda(non_blocking=True)
        if train:
            optimizer.zero_grad()
            loss = mc.forward_loss(model, criterion, input, target, meters)
            loss_l2 = optim.cal_l2_loss(model, FLAGS.weight_decay,
                                        FLAGS.weight_decay_method)
            loss = loss + loss_l2
            loss.backward()
            if FLAGS.use_distributed:
                udist.allreduce_grads(model)

            if FLAGS._global_step % FLAGS.log_interval == 0:
                results = mc.reduce_and_flush_meters(meters)
                if udist.is_master():
                    logging.info('Epoch {}/{} Iter {}/{} {}: '.format(
                        epoch, FLAGS.num_epochs, batch_idx, len(loader), phase)
                                 + ', '.join('{}: {:.4f}'.format(k, v)
                                             for k, v in results.items()))
                    for k, v in results.items():
                        mc.summary_writer.add_scalar('{}/{}'.format(phase, k),
                                                     v, FLAGS._global_step)
            if udist.is_master(
            ) and FLAGS._global_step % FLAGS.log_interval == 0:
                mc.summary_writer.add_scalar('train/learning_rate',
                                             optimizer.param_groups[0]['lr'],
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar('train/l2_regularize_loss',
                                             extract_item(loss_l2),
                                             FLAGS._global_step)
                mc.summary_writer.add_scalar(
                    'train/current_epoch',
                    FLAGS._global_step / FLAGS._steps_per_epoch,
                    FLAGS._global_step)
                if FLAGS.data_loader_workers > 0:
                    mc.summary_writer.add_scalar(
                        'data/train/prefetch_size',
                        get_data_queue_size(data_iterator), FLAGS._global_step)

            optimizer.step()
            lr_scheduler.step()
            if FLAGS.use_distributed and FLAGS.allreduce_bn:
                udist.allreduce_bn(model)
            FLAGS._global_step += 1

            # NOTE: after the step count update
            if ema is not None:
                model_unwrap = mc.unwrap_model(model)
                ema_names = ema.average_names()
                params = get_params_by_name(model_unwrap, ema_names)
                for name, param in zip(ema_names, params):
                    ema(name, param, FLAGS._global_step)
        else:
            mc.forward_loss(model, criterion, input, target, meters)

    if not train:
        results = mc.reduce_and_flush_meters(meters)
        if udist.is_master():
            logging.info(
                'Epoch {}/{} {}: '.format(epoch, FLAGS.num_epochs, phase) +
                ', '.join('{}: {:.4f}'.format(k, v)
                          for k, v in results.items()))
            for k, v in results.items():
                mc.summary_writer.add_scalar('{}/{}'.format(phase, k), v,
                                             FLAGS._global_step)
    return results
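
The training branch calls udist.allreduce_grads(model) explicitly after backward. What such a helper typically does, sketched with plain torch.distributed (an assumption about its behavior, not the repository's code):

import torch.distributed as dist

def allreduce_grads(model):
    # Sum each parameter's gradient across ranks, then divide by world size
    # so every rank ends up with the same averaged gradient.
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad.data)
            p.grad.data.div_(world_size)
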
Example No. 28
def train_val_test():
    """Train and val."""
    torch.backends.cudnn.benchmark = True

    # model
    model, model_wrapper = mc.get_model()
    ema = mc.setup_ema(model)
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='none').cuda()
    # TODO(meijieru): cal loss on all GPUs instead only `cuda:0` when non
    # distributed

    if FLAGS.get('log_graph_only', False):
        if udist.is_master():
            _input = torch.zeros(1, 3, FLAGS.image_size,
                                 FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            mc.summary_writer.add_graph(model_wrapper, (_input, ),
                                        verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(os.path.join(FLAGS.resume,
                                             'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if udist.is_master():
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # last_epoch = lr_scheduler.last_epoch
        last_epoch = -1
        best_val = 1.
        train_meters = mc.get_meters('train')
        val_meters = mc.get_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and udist.is_master():
        logging.info(model_wrapper)
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=False)

    # data
    (train_transforms, val_transforms,
     test_transforms) = dataflow.data_transforms(FLAGS)
    (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                      val_transforms,
                                                      test_transforms, FLAGS)
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    if FLAGS.test_only and (test_loader is not None):
        if udist.is_master():
            logging.info('Start testing.')
        test_meters = mc.get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test')
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer load same checkpoint/same initialization

    if udist.is_master():
        logging.info('Start training.')

    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        results = run_one_epoch(epoch,
                                train_loader,
                                model_wrapper,
                                criterion_smooth,
                                optimizer,
                                lr_scheduler,
                                ema,
                                train_meters,
                                phase='train')

        # val
        results = validate(epoch, calib_loader, val_loader, criterion,
                           val_meters, model_wrapper, ema, 'val')
        if results['top1_error'] < best_val:
            best_val = results['top1_error']

            if udist.is_master():
                save_status(model_wrapper, optimizer, ema, epoch, best_val,
                            (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'best_model.pt'))
                logging.info(
                    'New best validation top1 error: {:.4f}'.format(best_val))
        if udist.is_master():
            # save latest checkpoint
            save_status(model_wrapper, optimizer, ema, epoch, best_val,
                        (train_meters, val_meters),
                        os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt'))

        wandb.log(
            {
                "Validation Accuracy": 1. - results['top1_error'],
                "Best Validation Accuracy": 1. - best_val
            },
            step=epoch)


        # NOTE(meijieru): from scheduler code, should be called after train/val
        # use stepwise scheduler instead
        # lr_scheduler.step()
    return
Example No. 29
def validate(epoch,
             calib_loader,
             val_loader,
             criterion,
             val_meters,
             model_wrapper,
             ema,
             phase,
             segval=None,
             val_set=None):
    """Calibrate and validate."""
    assert phase in ['test', 'val']
    model_eval_wrapper = mc.get_ema_model(ema, model_wrapper)

    # bn_calibration
    if FLAGS.prune_params['method'] is not None:
        if FLAGS.get('bn_calibration', False):
            if not FLAGS.use_distributed:
                logging.warning(
                    'Only GPU0 is used for calibration when using DataParallel')
            with torch.no_grad():
                _ = run_one_epoch(epoch,
                                  calib_loader,
                                  model_eval_wrapper,
                                  criterion,
                                  None,
                                  None,
                                  None,
                                  None,
                                  val_meters,
                                  max_iter=FLAGS.bn_calibration_steps,
                                  phase='bn_calibration')
            if FLAGS.use_distributed:
                udist.allreduce_bn(model_eval_wrapper)

    # val
    with torch.no_grad():
        if FLAGS.model_kwparams.task == 'segmentation':
            if FLAGS.dataset == 'coco':
                results = 0
                if udist.is_master():
                    results = keypoint_val(val_set, val_loader,
                                           model_eval_wrapper.module,
                                           criterion)
            else:
                assert segval is not None
                results = segval.run(
                    epoch, val_loader, model_eval_wrapper.module
                    if FLAGS.single_gpu_test else model_eval_wrapper, FLAGS)
        else:
            results = run_one_epoch(epoch,
                                    val_loader,
                                    model_eval_wrapper,
                                    criterion,
                                    None,
                                    None,
                                    None,
                                    None,
                                    val_meters,
                                    phase=phase)
    summary_bn(model_eval_wrapper, phase)
    return results, model_eval_wrapper
Example No. 30
def train_val_test():
    """Train and val."""
    torch.backends.cudnn.benchmark = True  # For acceleration

    # model
    model, model_wrapper = mc.get_model()
    ema = mc.setup_ema(model)
    criterion = torch.nn.CrossEntropyLoss(reduction='mean').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='mean').cuda()
    if model.task == 'segmentation':
        criterion = CrossEntropyLoss().cuda()
        criterion_smooth = CrossEntropyLoss().cuda()
    if FLAGS.dataset == 'coco':
        criterion = JointsMSELoss(use_target_weight=True).cuda()
        criterion_smooth = JointsMSELoss(use_target_weight=True).cuda()

    if FLAGS.get('log_graph_only', False):
        if udist.is_master():
            _input = torch.zeros(1, 3, FLAGS.image_size,
                                 FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            if isinstance(model_wrapper,
                          (torch.nn.DataParallel,
                           udist.AllReduceDistributedDataParallel)):
                mc.summary_writer.add_graph(model_wrapper.module, (_input, ),
                                            verbose=True)
            else:
                mc.summary_writer.add_graph(model_wrapper, (_input, ),
                                            verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                if udist.is_master():
                    logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        if udist.is_master():
            logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(os.path.join(FLAGS.resume,
                                             'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper = checkpoint['model'].cuda()
        model = model_wrapper.module
        # model = checkpoint['model'].module
        optimizer = checkpoint['optimizer']
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        # model_wrapper.load_state_dict(checkpoint['model'])
        # optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            # ema.load_state_dict(checkpoint['ema'])
            ema = checkpoint['ema'].cuda()
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer,
                                              FLAGS,
                                              last_epoch=(last_epoch + 1) *
                                              FLAGS._steps_per_epoch)
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if udist.is_master():
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # last_epoch = lr_scheduler.last_epoch
        last_epoch = -1
        best_val = 1.
        if not FLAGS.distill:
            train_meters = mc.get_meters('train', FLAGS.prune_params['method'])
            val_meters = mc.get_meters('val')
        else:
            train_meters = mc.get_distill_meters('train',
                                                 FLAGS.prune_params['method'])
            val_meters = mc.get_distill_meters('val')
        if FLAGS.model_kwparams.task == 'segmentation':
            best_val = 0.
            if not FLAGS.distill:
                train_meters = mc.get_seg_meters('train',
                                                 FLAGS.prune_params['method'])
                val_meters = mc.get_seg_meters('val')
            else:
                train_meters = mc.get_seg_distill_meters(
                    'train', FLAGS.prune_params['method'])
                val_meters = mc.get_seg_distill_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and udist.is_master():
        logging.info(model_wrapper)
    assert FLAGS.profiling, '`m.macs` is used for calculating penalty'
    # if udist.is_master():
    #     model.apply(lambda m: print(m))
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=False)

    if FLAGS.dataset == 'cityscapes':
        (train_set, val_set,
         test_set) = seg_dataflow.cityscapes_datasets(FLAGS)
        segval = SegVal(num_classes=19)
    elif FLAGS.dataset == 'ade20k':
        (train_set, val_set, test_set) = seg_dataflow.ade20k_datasets(FLAGS)
        segval = SegVal(num_classes=150)
    elif FLAGS.dataset == 'coco':
        (train_set, val_set, test_set) = seg_dataflow.coco_datasets(FLAGS)
        # print(len(train_set), len(val_set))  # 149813 104125
        segval = None
    else:
        # data
        (train_transforms, val_transforms,
         test_transforms) = dataflow.data_transforms(FLAGS)
        (train_set, val_set,
         test_set) = dataflow.dataset(train_transforms, val_transforms,
                                      test_transforms, FLAGS)
        segval = None
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    # get bn's weights
    if FLAGS.prune_params.use_transformer:
        FLAGS._bn_to_prune, FLAGS._bn_to_prune_transformer = prune.get_bn_to_prune(
            model, FLAGS.prune_params)
    else:
        FLAGS._bn_to_prune = prune.get_bn_to_prune(model, FLAGS.prune_params)
    rho_scheduler = prune.get_rho_scheduler(FLAGS.prune_params,
                                            FLAGS._steps_per_epoch)

    if FLAGS.test_only and (test_loader is not None):
        if udist.is_master():
            logging.info('Start testing.')
        test_meters = mc.get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test')
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer load same checkpoint/same initialization

    if udist.is_master():
        logging.info('Start training.')

    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        results = run_one_epoch(epoch,
                                train_loader,
                                model_wrapper,
                                criterion_smooth,
                                optimizer,
                                lr_scheduler,
                                ema,
                                rho_scheduler,
                                train_meters,
                                phase='train')

        if (epoch + 1) % FLAGS.eval_interval == 0:
            # val
            results, model_eval_wrapper = validate(epoch, calib_loader,
                                                   val_loader, criterion,
                                                   val_meters, model_wrapper,
                                                   ema, 'val', segval, val_set)

            if FLAGS.prune_params['method'] is not None and FLAGS.prune_params[
                    'bn_prune_filter'] is not None:
                prune_threshold = FLAGS.model_shrink_threshold  # 1e-3
                masks = prune.cal_mask_network_slimming_by_threshold(
                    get_prune_weights(model_eval_wrapper), prune_threshold
                )  # get mask for all bn weights (depth-wise)
                FLAGS._bn_to_prune.add_info_list('mask', masks)
                flops_pruned, infos = prune.cal_pruned_flops(
                    FLAGS._bn_to_prune)
                log_pruned_info(mc.unwrap_model(model_eval_wrapper),
                                flops_pruned, infos, prune_threshold)
                if not FLAGS.distill:
                    if flops_pruned >= FLAGS.model_shrink_delta_flops \
                            or epoch == FLAGS.num_epochs - 1:
                        ema_only = (epoch == FLAGS.num_epochs - 1)
                        shrink_model(model_wrapper, ema, optimizer,
                                     FLAGS._bn_to_prune, prune_threshold,
                                     ema_only)
            model_kwparams = mb.output_network(mc.unwrap_model(model_wrapper))

            if udist.is_master():
                if FLAGS.model_kwparams.task == 'classification' and results[
                        'top1_error'] < best_val:
                    best_val = results['top1_error']
                    logging.info(
                        'New best validation top1 error: {:.4f}'.format(
                            best_val))

                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))

                elif FLAGS.model_kwparams.task == 'segmentation' and FLAGS.dataset != 'coco' and results[
                        'mIoU'] > best_val:
                    best_val = results['mIoU']
                    logging.info('New seg mIoU: {:.4f}'.format(best_val))

                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))
                elif FLAGS.dataset == 'coco' and results > best_val:
                    best_val = results
                    logging.info('New Result: {:.4f}'.format(best_val))
                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))

                # save latest checkpoint
                save_status(model_wrapper, model_kwparams, optimizer, ema,
                            epoch, best_val, (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'latest_checkpoint'))

    return