Пример #1
0
def main():
    """Training entry point.

    Derives the global batch size and per-loader batch sizes from the device
    topology, either the distributed world or the locally visible CUDA
    devices. It then scales the learning rate by
    ``batch_size / base_total_batch`` and computes the number of steps per
    epoch.

    On the root rank only, it creates a timestamped experiment directory and
    configures logging. Finally it seeds the RNGs and runs
    ``train_val_test()`` inside a ``SummaryWriterManager`` context.

    Side effects: mutates ``FLAGS`` and sets the module-level
    ``is_root_rank``.
    """
    global is_root_rank

    if not FLAGS.use_distributed:
        num_devices = torch.cuda.device_count()
        FLAGS.batch_size = num_devices * FLAGS.per_gpu_batch_size
        # A single loader feeds every local device, so it yields the full
        # global batch.
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = (
                FLAGS.bn_calibration_per_gpu_batch_size * num_devices)
        is_root_rank = True
    else:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        # Every process owns its own loader and yields a per-GPU batch.
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = (
                FLAGS.bn_calibration_per_gpu_batch_size)
        # Split the configured worker budget among the local processes.
        FLAGS.data_loader_workers = round(FLAGS.data_loader_workers /
                                          udist.get_local_size())
        is_root_rank = udist.is_master()

    # Scale the LR proportionally to the global batch size.
    FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    # The last partial batch is kept, so round up. Rounding down would
    # undercount steps; per the original note, that makes the LR go negative.
    FLAGS._steps_per_epoch = int(
        np.ceil(NUM_IMAGENET_TRAIN / FLAGS.batch_size))

    if is_root_rank:
        stamp = time.strftime("%Y%m%d-%H%M%S")
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir, stamp)
        excluded_dirs = [
            'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak'
        ]
        create_exp_dir(FLAGS.log_dir,
                       FLAGS.config_path,
                       blacklist_dirs=excluded_dirs)
        setup_logging(FLAGS.log_dir)
        for var_name, expansion in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(var_name, expansion))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with SummaryWriterManager():
        train_val_test()
Пример #2
0
def main():
    """Evaluation-only entry point.

    Forces ``FLAGS.test_only`` and initializes distributed state via
    ``mc.setup_distributed()``. On the master rank it creates a log directory
    suffixed ``-eval`` and dumps the expanded env vars and the flags. It then
    seeds the RNGs and runs ``val()`` inside a summary-writer context.
    """
    FLAGS.test_only = True
    mc.setup_distributed()

    if udist.is_master():
        run_stamp = time.strftime("%Y%m%d-%H%M%S-eval")
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir, run_stamp)
        setup_logging(FLAGS.log_dir)
        for name, expanded in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(name, expanded))
        logging.info(FLAGS)

    seed = FLAGS.get('random_seed', 0)
    set_random_seed(seed)
    with mc.SummaryWriterManager():
        val()
Пример #3
0
def main():
    """Training entry point for classification and segmentation datasets.

    Chooses the training-set size from ``FLAGS.dataset`` and passes it to
    ``mc.setup_distributed``. For segmentation runs with ``FLAGS.net_params``
    set, it overrides the model architecture kwargs parsed from that string.

    On the master rank it creates a timestamped experiment directory and sets
    up logging. Finally it seeds the RNGs and runs ``train_val_test()``.
    """
    # Training-set size per dataset. Unlisted datasets fall back to the
    # ImageNet value used by the original code.
    num_train_by_dataset = {
        'cityscapes': 2975,
        'ade20k': 20210,
        'coco': 149813,
    }
    num_train = num_train_by_dataset.get(FLAGS.dataset, 1281167)
    mc.setup_distributed(num_train)

    if FLAGS.net_params and FLAGS.model_kwparams.task == 'segmentation':
        input_channels, inverted_residual_setting, last_channel = (
            parse_net_params(FLAGS.net_params))
        FLAGS.model_kwparams.input_channel = input_channels
        FLAGS.model_kwparams.inverted_residual_setting = inverted_residual_setting
        FLAGS.model_kwparams.last_channel = last_channel

    if udist.is_master():
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir,
                                       time.strftime("%Y%m%d-%H%M%S"))
        # yapf: disable
        create_exp_dir(FLAGS.log_dir, FLAGS.config_path, blacklist_dirs=[
            'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak', 'output'])
        # yapf: enable
        setup_logging(FLAGS.log_dir)
        for k, v in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(k, v))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with mc.SummaryWriterManager():
        train_val_test()


def parse_net_params(net_params):
    """Parse an architecture string of seven ``'-'``-separated fields.

    Format: ``tag-<input_channels>-<b1>-<b2>-<b3>-<b4>-<last_channel>``.
    ``tag`` is ignored and ``last_channel`` is a single int. All other
    fields are ``'_'``-separated int lists.

    Each block ``[n, x, v2, ..., vk]`` expands into ``n`` entries of the form
    ``[x, head, tail]``. The values after index 1 are split so that ``tail``
    holds the last ``max((len(block) - 2) // 2, 0)`` values and ``head``
    holds the rest. Using an explicit split index avoids the ``-0`` slicing
    pitfall: ``lst[-0:]`` is the whole list, not an empty one.

    Args:
        net_params: the architecture string.

    Returns:
        ``(input_channels, inverted_residual_setting, last_channel)``.

    Raises:
        ValueError: if the string does not have exactly seven fields or a
            field is not an integer.
    """
    (_tag, input_channels, block1, block2, block3, block4,
     last_channel) = net_params.split('-')

    def to_ints(field):
        return [int(value) for value in field.split('_')]

    inverted_residual_setting = []
    for item in (to_ints(b) for b in (block1, block2, block3, block4)):
        tail_len = max((len(item) - 2) // 2, 0)
        split = len(item) - tail_len
        for _ in range(item[0]):
            # Slice per repeat so that every entry owns fresh lists.
            inverted_residual_setting.append(
                [item[1], item[2:split], item[split:]])

    return to_ints(input_channels), inverted_residual_setting, int(last_channel)
Пример #4
0
def main():
    """ImageNet training entry point.

    Initializes distributed state for the ImageNet training-set size. On the
    master rank it creates a timestamped experiment directory and sets up
    logging. It then seeds the RNGs and runs ``train_val_test()`` inside a
    summary-writer context.
    """
    num_train_images = 1281167
    mc.setup_distributed(num_train_images)

    if udist.is_master():
        run_dir = '{}/{}'.format(FLAGS.log_dir, time.strftime("%Y%m%d-%H%M%S"))
        FLAGS.log_dir = run_dir
        skipped = ['exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak']
        create_exp_dir(FLAGS.log_dir, FLAGS.config_path, blacklist_dirs=skipped)
        setup_logging(FLAGS.log_dir)
        for env_key, env_val in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(env_key, env_val))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with mc.SummaryWriterManager():
        train_val_test()
Пример #5
0
def get_model():
    """Instantiate ``FLAGS.model`` and wrap it for multi-GPU execution.

    Returns:
        ``(model, model_wrapper)``. ``model_wrapper`` is
        ``udist.AllReduceDistributedDataParallel`` when
        ``FLAGS.use_distributed`` is set, otherwise ``torch.nn.DataParallel``.

    Raises:
        ValueError: if ``reset_parameters`` is enabled and
            ``reset_param_method`` is neither None, ``'slimmable'`` nor
            ``'mnas'``.
    """
    model = importlib.import_module(FLAGS.model).Model(
        **FLAGS.model_kwparams, input_size=FLAGS.image_size)

    if FLAGS.reset_parameters:
        method = FLAGS.get('reset_param_method', None)
        # None keeps the model's own initialization untouched.
        if method is not None:
            if method == 'slimmable':
                initializer = mb.init_weights_slimmable
            elif method == 'mnas':
                initializer = mb.init_weights_mnas
            else:
                raise ValueError('Unknown init method: {}'.format(method))
            model.apply(initializer)
        logging.info('Init model by: {}'.format(method))

    if not FLAGS.use_distributed:
        return model, torch.nn.DataParallel(model).cuda()
    return model, udist.AllReduceDistributedDataParallel(model.cuda())
Пример #6
0
def validate(epoch, calib_loader, val_loader, criterion, val_meters,
             model_wrapper, ema, phase):
    """Optionally recalibrate BatchNorm statistics, then run one eval pass.

    Args:
        epoch: current epoch index, forwarded to ``run_one_epoch``.
        calib_loader: loader used for BN calibration when
            ``FLAGS.bn_calibration`` is on.
        val_loader: loader evaluated in ``phase``.
        criterion: loss used during calibration and evaluation.
        val_meters: meters updated by ``run_one_epoch``.
        model_wrapper: the (wrapped) training model.
        ema: EMA helper passed to ``get_ema_model``; may be None.
        phase: ``'val'`` or ``'test'``.

    Returns:
        ``(results, model_eval_wrapper)``: the ``run_one_epoch`` output on
        ``val_loader`` and the model that was evaluated.
    """
    assert phase in ['test', 'val']
    eval_model = get_ema_model(ema, model_wrapper)

    if FLAGS.get('bn_calibration', False):
        if not FLAGS.use_distributed:
            logging.warning(
                'Only GPU0 is used when calibration when use DataParallel')
        with torch.no_grad():
            run_one_epoch(epoch, calib_loader, eval_model, criterion,
                          None, None, None, None, val_meters,
                          max_iter=FLAGS.bn_calibration_steps,
                          phase='bn_calibration')
        if FLAGS.use_distributed:
            # Synchronize the recalibrated BN state across ranks.
            udist.allreduce_bn(eval_model)

    with torch.no_grad():
        results = run_one_epoch(epoch, val_loader, eval_model, criterion,
                                None, None, None, None, val_meters,
                                phase=phase)
    summary_bn(eval_model, phase)
    return results, eval_model
Пример #7
0
def train_val_test():
    """Train and validate the model, or run a single test pass.

    Steps:
      1. Build the model, EMA and losses.
      2. Optionally only log the model graph and return.
      3. Load ``FLAGS.pretrained`` weights and/or resume from
         ``FLAGS.resume/latest_checkpoint.pt``.
      4. Build the data loaders.
      5. Either evaluate once on the test set (``FLAGS.test_only``), or train
         for the remaining epochs, validating and checkpointing after each
         one.
    """
    torch.backends.cudnn.benchmark = True

    # model
    model, model_wrapper = mc.get_model()
    ema = mc.setup_ema(model)
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='none').cuda()
    # TODO(meijieru): cal loss on all GPUs instead only `cuda:0` when non
    # distributed

    if FLAGS.get('log_graph_only', False):
        if udist.is_master():
            _input = torch.zeros(1, 3, FLAGS.image_size,
                                 FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            mc.summary_writer.add_graph(model_wrapper, (_input, ),
                                        verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        # External checkpoints (raw state dicts) carry no EMA state. Restore
        # EMA only when it is present instead of failing with KeyError.
        if ema:
            if isinstance(checkpoint, dict) and 'ema' in checkpoint:
                ema.load_state_dict(checkpoint['ema'])
                ema.to(get_device(model))
            else:
                logging.warning('No EMA state in {}; EMA not restored.'.format(
                    FLAGS.pretrained))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            # Positional remap: the i-th checkpoint key feeds the i-th model key.
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(os.path.join(FLAGS.resume,
                                             'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # The scheduler is stepped per iteration, so it counts steps, not
        # epochs.
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if udist.is_master():
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        last_epoch = -1
        best_val = 1.
        train_meters = mc.get_meters('train')
        val_meters = mc.get_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and udist.is_master():
        logging.info(model_wrapper)
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=False)

    # data
    (train_transforms, val_transforms,
     test_transforms) = dataflow.data_transforms(FLAGS)
    (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                      val_transforms,
                                                      test_transforms, FLAGS)
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    if FLAGS.test_only and (test_loader is not None):
        if udist.is_master():
            logging.info('Start testing.')
        test_meters = mc.get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test')
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer load same checkpoint/same initialization

    if udist.is_master():
        logging.info('Start training.')

    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        run_one_epoch(epoch,
                      train_loader,
                      model_wrapper,
                      criterion_smooth,
                      optimizer,
                      lr_scheduler,
                      ema,
                      train_meters,
                      phase='train')

        # val
        val_output = validate(epoch, calib_loader, val_loader, criterion,
                              val_meters, model_wrapper, ema, 'val')
        # The `validate` variants in this codebase return
        # `(results, model_eval_wrapper)`. Keep only the metrics, but accept
        # a bare results mapping as well.
        results = (val_output[0]
                   if isinstance(val_output, tuple) else val_output)
        if results['top1_error'] < best_val:
            best_val = results['top1_error']

            if udist.is_master():
                save_status(model_wrapper, optimizer, ema, epoch, best_val,
                            (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'best_model.pt'))
                logging.info(
                    'New best validation top1 error: {:.4f}'.format(best_val))
        if udist.is_master():
            # save latest checkpoint
            save_status(model_wrapper, optimizer, ema, epoch, best_val,
                        (train_meters, val_meters),
                        os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt'))
            # Report from the master rank only, like the rest of the logging.
            wandb.log(
                {
                    "Validation Accuracy": 1. - results['top1_error'],
                    "Best Validation Accuracy": 1. - best_val
                },
                step=epoch)

        # NOTE(meijieru): from scheduler code, should be called after
        # train/val; use stepwise scheduler instead
        # lr_scheduler.step()
    return
Пример #8
0
def validate(epoch,
             calib_loader,
             val_loader,
             criterion,
             val_meters,
             model_wrapper,
             ema,
             phase,
             segval=None,
             val_set=None):
    """Optionally recalibrate BatchNorm, then evaluate for the current task.

    BN calibration runs only when a pruning method is configured
    (``FLAGS.prune_params['method']``) and ``FLAGS.bn_calibration`` is on.

    Evaluation path:
      * classification: ``run_one_epoch`` over ``val_loader``.
      * segmentation on coco: ``keypoint_val`` on the master rank only;
        other ranks get ``0``.
      * other segmentation: ``segval.run`` (``segval`` must be given). It
        uses the unwrapped ``.module`` when ``FLAGS.single_gpu_test`` is set.

    Returns:
        ``(results, model_eval_wrapper)``.
    """
    assert phase in ['test', 'val']
    eval_model = mc.get_ema_model(ema, model_wrapper)

    needs_calibration = (FLAGS.prune_params['method'] is not None
                         and FLAGS.get('bn_calibration', False))
    if needs_calibration:
        if not FLAGS.use_distributed:
            logging.warning(
                'Only GPU0 is used when calibration when use DataParallel')
        with torch.no_grad():
            run_one_epoch(epoch, calib_loader, eval_model, criterion,
                          None, None, None, None, val_meters,
                          max_iter=FLAGS.bn_calibration_steps,
                          phase='bn_calibration')
        if FLAGS.use_distributed:
            # Synchronize the recalibrated BN state across ranks.
            udist.allreduce_bn(eval_model)

    with torch.no_grad():
        if FLAGS.model_kwparams.task != 'segmentation':
            results = run_one_epoch(epoch, val_loader, eval_model, criterion,
                                    None, None, None, None, val_meters,
                                    phase=phase)
        elif FLAGS.dataset == 'coco':
            # Keypoint evaluation happens on the master rank only.
            results = (keypoint_val(val_set, val_loader, eval_model.module,
                                    criterion)
                       if udist.is_master() else 0)
        else:
            assert segval is not None
            seg_model = (eval_model.module
                         if FLAGS.single_gpu_test else eval_model)
            results = segval.run(epoch, val_loader, seg_model, FLAGS)
    summary_bn(eval_model, phase)
    return results, eval_model
Пример #9
0
def train_val_test():
    """Train and validate (with optional pruning/shrinking), or run one test pass.

    Steps:
      1. Build the model, EMA and task-specific losses.
      2. Optionally only log the graph.
      3. Load pretrained weights and/or resume.
      4. Build the dataset-specific loaders and the BN-pruning bookkeeping.
      5. Either test once (``FLAGS.test_only``) or train.

    During training, every ``FLAGS.eval_interval`` epochs it validates,
    optionally shrinks the model, and saves best/latest checkpoints on the
    master rank.
    """
    torch.backends.cudnn.benchmark = True  # For acceleration

    # model
    model, model_wrapper = mc.get_model()
    ema = mc.setup_ema(model)
    criterion = torch.nn.CrossEntropyLoss(reduction='mean').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='mean').cuda()
    if model.task == 'segmentation':
        criterion = CrossEntropyLoss().cuda()
        criterion_smooth = CrossEntropyLoss().cuda()
    if FLAGS.dataset == 'coco':
        criterion = JointsMSELoss(use_target_weight=True).cuda()
        criterion_smooth = JointsMSELoss(use_target_weight=True).cuda()

    if FLAGS.get('log_graph_only', False):
        if udist.is_master():
            _input = torch.zeros(1, 3, FLAGS.image_size,
                                 FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            if isinstance(model_wrapper,
                          (torch.nn.DataParallel,
                           udist.AllReduceDistributedDataParallel)):
                mc.summary_writer.add_graph(model_wrapper.module, (_input, ),
                                            verbose=True)
            else:
                mc.summary_writer.add_graph(model_wrapper, (_input, ),
                                            verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        # External checkpoints (raw state dicts) carry no EMA state. Restore
        # EMA only when it is present instead of failing with KeyError.
        if ema:
            if isinstance(checkpoint, dict) and 'ema' in checkpoint:
                ema.load_state_dict(checkpoint['ema'])
                ema.to(get_device(model))
            elif udist.is_master():
                logging.warning('No EMA state in {}; EMA not restored.'.format(
                    FLAGS.pretrained))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            # Positional remap: the i-th checkpoint key feeds the i-th model key.
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                if udist.is_master():
                    logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        if udist.is_master():
            logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        # The resume checkpoint stores whole pickled objects (model wrapper,
        # optimizer, EMA) rather than state dicts, because the model may have
        # been shrunk. torch.load unpickles them: trusted files only.
        checkpoint = torch.load(os.path.join(FLAGS.resume,
                                             'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper = checkpoint['model'].cuda()
        model = model_wrapper.module
        optimizer = checkpoint['optimizer']
        # Move optimizer state (e.g. momentum buffers) back onto the GPU.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        if ema:
            ema = checkpoint['ema'].cuda()
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer,
                                              FLAGS,
                                              last_epoch=(last_epoch + 1) *
                                              FLAGS._steps_per_epoch)
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if udist.is_master():
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        last_epoch = -1
        # Classification tracks top1 error (lower is better) starting at 1.
        best_val = 1.
        if not FLAGS.distill:
            train_meters = mc.get_meters('train', FLAGS.prune_params['method'])
            val_meters = mc.get_meters('val')
        else:
            train_meters = mc.get_distill_meters('train',
                                                 FLAGS.prune_params['method'])
            val_meters = mc.get_distill_meters('val')
        if FLAGS.model_kwparams.task == 'segmentation':
            # Segmentation tracks a higher-is-better score starting at 0.
            best_val = 0.
            if not FLAGS.distill:
                train_meters = mc.get_seg_meters('train',
                                                 FLAGS.prune_params['method'])
                val_meters = mc.get_seg_meters('val')
            else:
                train_meters = mc.get_seg_distill_meters(
                    'train', FLAGS.prune_params['method'])
                val_meters = mc.get_seg_distill_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and udist.is_master():
        logging.info(model_wrapper)
    assert FLAGS.profiling, '`m.macs` is used for calculating penalty'
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=False)

    if FLAGS.dataset == 'cityscapes':
        (train_set, val_set,
         test_set) = seg_dataflow.cityscapes_datasets(FLAGS)
        segval = SegVal(num_classes=19)
    elif FLAGS.dataset == 'ade20k':
        (train_set, val_set, test_set) = seg_dataflow.ade20k_datasets(FLAGS)
        segval = SegVal(num_classes=150)
    elif FLAGS.dataset == 'coco':
        (train_set, val_set, test_set) = seg_dataflow.coco_datasets(FLAGS)
        segval = None
    else:
        # data
        (train_transforms, val_transforms,
         test_transforms) = dataflow.data_transforms(FLAGS)
        (train_set, val_set,
         test_set) = dataflow.dataset(train_transforms, val_transforms,
                                      test_transforms, FLAGS)
        segval = None
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    # get bn's weights
    if FLAGS.prune_params.use_transformer:
        FLAGS._bn_to_prune, FLAGS._bn_to_prune_transformer = prune.get_bn_to_prune(
            model, FLAGS.prune_params)
    else:
        FLAGS._bn_to_prune = prune.get_bn_to_prune(model, FLAGS.prune_params)
    rho_scheduler = prune.get_rho_scheduler(FLAGS.prune_params,
                                            FLAGS._steps_per_epoch)

    if FLAGS.test_only and (test_loader is not None):
        if udist.is_master():
            logging.info('Start testing.')
        test_meters = mc.get_meters('test')
        # Forward `segval` and the dataset that matches `test_loader`.
        # Without them, segmentation trips `assert segval is not None` and
        # coco keypoint evaluation receives no dataset.
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test', segval, test_set)
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer load same checkpoint/same initialization

    if udist.is_master():
        logging.info('Start training.')

    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        run_one_epoch(epoch,
                      train_loader,
                      model_wrapper,
                      criterion_smooth,
                      optimizer,
                      lr_scheduler,
                      ema,
                      rho_scheduler,
                      train_meters,
                      phase='train')

        if (epoch + 1) % FLAGS.eval_interval == 0:
            # val
            results, model_eval_wrapper = validate(epoch, calib_loader,
                                                   val_loader, criterion,
                                                   val_meters, model_wrapper,
                                                   ema, 'val', segval, val_set)

            if FLAGS.prune_params['method'] is not None and FLAGS.prune_params[
                    'bn_prune_filter'] is not None:
                prune_threshold = FLAGS.model_shrink_threshold  # 1e-3
                # Mask BN channels whose weight falls below the threshold.
                masks = prune.cal_mask_network_slimming_by_threshold(
                    get_prune_weights(model_eval_wrapper), prune_threshold)
                FLAGS._bn_to_prune.add_info_list('mask', masks)
                flops_pruned, infos = prune.cal_pruned_flops(
                    FLAGS._bn_to_prune)
                log_pruned_info(mc.unwrap_model(model_eval_wrapper),
                                flops_pruned, infos, prune_threshold)
                if not FLAGS.distill:
                    if flops_pruned >= FLAGS.model_shrink_delta_flops \
                            or epoch == FLAGS.num_epochs - 1:
                        ema_only = (epoch == FLAGS.num_epochs - 1)
                        shrink_model(model_wrapper, ema, optimizer,
                                     FLAGS._bn_to_prune, prune_threshold,
                                     ema_only)
            model_kwparams = mb.output_network(mc.unwrap_model(model_wrapper))

            if udist.is_master():
                if FLAGS.model_kwparams.task == 'classification' and results[
                        'top1_error'] < best_val:
                    best_val = results['top1_error']
                    logging.info(
                        'New best validation top1 error: {:.4f}'.format(
                            best_val))

                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))

                elif FLAGS.model_kwparams.task == 'segmentation' and FLAGS.dataset != 'coco' and results[
                        'mIoU'] > best_val:
                    best_val = results['mIoU']
                    logging.info('New seg mIoU: {:.4f}'.format(best_val))

                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))
                elif FLAGS.dataset == 'coco' and results > best_val:
                    best_val = results
                    logging.info('New Result: {:.4f}'.format(best_val))
                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))

                # save latest checkpoint
                save_status(model_wrapper, model_kwparams, optimizer, ema,
                            epoch, best_val, (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'latest_checkpoint'))

    return
Пример #10
0
def train_val_test():
    """Train and validate with BN-based pruning, or run a single test pass.

    Steps:
      1. Build the model, losses and an optional parameter EMA.
      2. Optionally only log the graph.
      3. Load pretrained weights and/or resume.
      4. Build the loaders and the pruning bookkeeping.
      5. Either test once (``FLAGS.test_only``) or train.

    During training, every epoch it validates, may shrink the model, and
    saves best/latest checkpoints on the root rank.
    """
    torch.backends.cudnn.benchmark = True

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='none').cuda()
    # TODO: cal loss on all GPUs instead only `cuda:0` when non
    # distributed

    ema = None
    if FLAGS.moving_average_decay > 0.0:
        if FLAGS.moving_average_decay_adjust:
            # Rescale the decay for a batch size that differs from the base.
            moving_average_decay = optim.ExponentialMovingAverage.adjust_momentum(
                FLAGS.moving_average_decay,
                FLAGS.moving_average_decay_base_batch / FLAGS.batch_size)
        else:
            moving_average_decay = FLAGS.moving_average_decay
        logging.info('Moving average for model parameters: {}'.format(
            moving_average_decay))
        ema = optim.ExponentialMovingAverage(moving_average_decay)
        for name, param in model.named_parameters():
            ema.register(name, param)
        # We maintain mva for batch norm moving mean and variance as well.
        for name, buffer in model.named_buffers():
            if 'running_var' in name or 'running_mean' in name:
                ema.register(name, buffer)

    if FLAGS.get('log_graph_only', False):
        if is_root_rank:
            _input = torch.zeros(1, 3, FLAGS.image_size,
                                 FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            summary_writer.add_graph(model_wrapper, (_input, ), verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        # External checkpoints (raw state dicts) carry no EMA state. Restore
        # EMA only when it is present instead of failing with KeyError.
        if ema:
            if isinstance(checkpoint, dict) and 'ema' in checkpoint:
                ema.load_state_dict(checkpoint['ema'])
                ema.to(get_device(model))
            elif is_root_rank:
                logging.warning('No EMA state in {}; EMA not restored.'.format(
                    FLAGS.pretrained))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            # Positional remap: the i-th checkpoint key feeds the i-th model key.
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                if is_root_rank:
                    logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        if is_root_rank:
            logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(os.path.join(FLAGS.resume,
                                             'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # The scheduler is stepped per iteration, so it counts steps.
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if is_root_rank:
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        last_epoch = -1
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and is_root_rank:
        logging.info(model_wrapper)
    assert FLAGS.profiling, '`m.macs` is used for calculating penalty'
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)

    # data
    (train_transforms, val_transforms,
     test_transforms) = dataflow.data_transforms(FLAGS)
    (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                      val_transforms,
                                                      test_transforms, FLAGS)
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    # get bn's weights
    FLAGS._bn_to_prune = prune.get_bn_to_prune(model, FLAGS.prune_params)
    rho_scheduler = prune.get_rho_scheduler(FLAGS.prune_params,
                                            FLAGS._steps_per_epoch)

    if FLAGS.test_only and (test_loader is not None):
        if is_root_rank:
            logging.info('Start testing.')
        test_meters = get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test')
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer load same checkpoint/same initialization

    if is_root_rank:
        logging.info('Start training.')

    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        run_one_epoch(epoch,
                      train_loader,
                      model_wrapper,
                      criterion_smooth,
                      optimizer,
                      lr_scheduler,
                      ema,
                      rho_scheduler,
                      train_meters,
                      phase='train')

        # val
        results, model_eval_wrapper = validate(epoch, calib_loader, val_loader,
                                               criterion, val_meters,
                                               model_wrapper, ema, 'val')

        if FLAGS.prune_params['method'] is not None:
            prune_threshold = FLAGS.model_shrink_threshold
            # Mask BN channels whose weight falls below the threshold.
            masks = prune.cal_mask_network_slimming_by_threshold(
                get_prune_weights(model_eval_wrapper), prune_threshold)
            FLAGS._bn_to_prune.add_info_list('mask', masks)
            flops_pruned, infos = prune.cal_pruned_flops(FLAGS._bn_to_prune)
            log_pruned_info(unwrap_model(model_eval_wrapper), flops_pruned,
                            infos, prune_threshold)
            if flops_pruned >= FLAGS.model_shrink_delta_flops \
                    or epoch == FLAGS.num_epochs - 1:
                ema_only = (epoch == FLAGS.num_epochs - 1)
                shrink_model(model_wrapper, ema, optimizer, FLAGS._bn_to_prune,
                             prune_threshold, ema_only)
        model_kwparams = mb.output_network(unwrap_model(model_wrapper))

        if results['top1_error'] < best_val:
            best_val = results['top1_error']

            if is_root_rank:
                save_status(model_wrapper, model_kwparams, optimizer, ema,
                            epoch, best_val, (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'best_model'))
                logging.info(
                    'New best validation top1 error: {:.4f}'.format(best_val))

        if is_root_rank:
            # save latest checkpoint
            save_status(model_wrapper, model_kwparams, optimizer, ema, epoch,
                        best_val, (train_meters, val_meters),
                        os.path.join(FLAGS.log_dir, 'latest_checkpoint'))

        # NOTE: from scheduler code, should be called after train/val
        # use stepwise scheduler instead
        # lr_scheduler.step()
    return