Example #1
    def testAdjustEmaRate(self):
        for num_repeat in [1, 5, 9]:
            for momentum in [0.25, 0.9999]:
                name = 'v'
                values = torch.randn(5)
                values_long = values.repeat(num_repeat, 1).permute(
                    (1, 0)).contiguous().view(-1)

                ema = optim.ExponentialMovingAverage(momentum)
                ema.register(name, values[0])
                for v in values:
                    ema(name, v)
                lhs = ema.average(name)

                adjusted = optim.ExponentialMovingAverage.adjust_momentum(
                    momentum, num_repeat)
                ema = optim.ExponentialMovingAverage(adjusted)
                ema.register(name, values[0])
                for v in values_long:
                    ema(name, v)
                rhs = ema.average(name)

                assertAllClose(lhs, rhs)
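
The equivalence this test encodes hints at what `adjust_momentum` computes: `num_repeat` updates with the adjusted momentum must match a single update with the original one. Since n identical EMA steps with momentum m' shrink the old shadow by m'**n, the closed form is plausibly m' = m**(1/n). A minimal sketch of that reading (the formula is inferred from the test, not taken from the library source):

    def adjust_momentum_sketch(momentum, num_repeat):
        # Inferred closed form: num_repeat steps at m' == one step at m,
        # so m' ** num_repeat == m  =>  m' = m ** (1 / num_repeat).
        return momentum ** (1.0 / num_repeat)

    def ema_step(shadow, value, m):
        # One EMA update: keep m of the old shadow, blend in (1 - m) of value.
        return m * shadow + (1 - m) * value

    shadow, value, m, n = 2.0, 5.0, 0.25, 3
    once = ema_step(shadow, value, m)
    repeated = shadow
    for _ in range(n):
        repeated = ema_step(repeated, value, adjust_momentum_sketch(m, n))
    print(abs(once - repeated) < 1e-9)  # True: the two schedules agree
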
Example #2
 def testThrowValueError(self):
     ema = optim.ExponentialMovingAverage(0.25)
     try:
         ema.register('test', torch.tensor(5, dtype=torch.int))
     except ValueError:
         pass
     else:
         raise AssertionError('Should throw ValueError')
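
A plausible reason the registration fails: EMA updates blend the old shadow and the new value with fractional weights, so the shadow buffer must be floating point, and `register` presumably rejects anything else. The dtype check itself is ordinary torch API:

    import torch

    v = torch.tensor(5, dtype=torch.int)
    print(v.is_floating_point())  # False: a fractional blend cannot be
                                  # stored back into an integer tensor
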
Example #3
 def testDevice(self):
     name = 'zeros'
     ema = optim.ExponentialMovingAverage(0.25)
     ema.register(name, torch.zeros(5))
     for device in ['cpu', 'cuda:0']:
         device = torch.device(device)
         ema = ema.to(device)
         cur_device = ema._shadow[name].device
         assert cur_device == device
Example #4
 def testAverageVariablesUpdateNumUpdates_Vector(self):
     ema = optim.ExponentialMovingAverage(0.25)
     name = 'tens'
     tens = _Repeat(10.0, dim=5)
     var = torch.tensor(tens)
     ema.register(name, var, zero_init=False)
     for num_updates in range(2):
         var.add_(1)
         ema(name, var, num_updates=num_updates)
     expected = _Repeat(
         (10 * 0.1 + 11 * 0.9) * 2.0 / 11.0 + 12 * 9.0 / 11.0, dim=5)
     assertAllClose(expected, ema.average(name))
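
The `expected` constant can be reproduced step by step if the effective decay follows the TF-style schedule min(momentum, (1 + num_updates) / (10 + num_updates)); that schedule is inferred from the arithmetic here, not confirmed against the library source. At num_updates=0 the decay is min(0.25, 1/10) = 0.1, at num_updates=1 it is min(0.25, 2/11) ≈ 0.1818:

    # Reproduce `expected`, assuming the inferred decay schedule above.
    momentum, shadow = 0.25, 10.0
    for t, value in enumerate([11.0, 12.0]):
        decay = min(momentum, (1 + t) / (10 + t))
        shadow = decay * shadow + (1 - decay) * value
    print(shadow)  # 11.8, the same value as the test's `expected`
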
Example #5
    def testSaveLoad(self):
        ema = optim.ExponentialMovingAverage(0.25)
        name = 'tens'
        tens = _Repeat(10.0, dim=5)
        var = torch.tensor(tens)
        ema.register(name, var, zero_init=False)
        state_dict = ema.state_dict()
        for key in ['info', 'shadow', 'param']:
            assert key in state_dict
        assert 'tens' in state_dict['shadow']
        assertAllClose(state_dict['shadow']['tens'], var)

        ema.load_state_dict(state_dict)
        state_dict['param']['momentum'] = 0.5
        self.assertWarns(RuntimeWarning,
                         lambda: ema.load_state_dict(state_dict))
Example #6
 def testCompress(self):
     ema = optim.ExponentialMovingAverage(0.25)
     ema.register('var_prune', torch.arange(5).float())
     ema.register('var_keep', torch.arange(5, 10).float())
     ema('var_prune', torch.arange(5).float())
     info = {
         'var_old_name': 'var_prune',
         'var_new_name': 'var_new',
         'var_new': torch.randn(3),
         'mask': torch.tensor([False, True, False, True, True]),
         'mask_hook': lambda lhs, rhs, mask: lhs.data.copy_(rhs.data[mask])
     }
     ema.compress_mask(info, verbose=False)
     self.assertTrue(info['var_new_name'] in ema._shadow)
     self.assertTrue(info['var_new_name'] in ema._info)
     self.assertTrue(info['var_old_name'] not in ema._shadow)
     self.assertTrue(info['var_old_name'] not in ema._info)
     self.assertEqual(ema._info[info['var_new_name']]['num_updates'], 1)
     assertAllClose(ema.average(info['var_new_name']), [1, 3, 4])
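
The final assertion is plain mask arithmetic: `var_prune` is registered with shadow [0, 1, 2, 3, 4], the single update with identical values leaves it unchanged, and the `mask_hook` copies only the masked entries into the new, smaller shadow:

    import torch

    shadow = torch.arange(5).float()  # shadow of 'var_prune' before compression
    mask = torch.tensor([False, True, False, True, True])
    print(shadow[mask])               # tensor([1., 3., 4.]), as asserted above
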
Example #7
def setup_ema(model):
    """Setup EMA for model's weights."""
    from utils import optim

    ema = None
    if FLAGS.moving_average_decay > 0.0:
        if FLAGS.moving_average_decay_adjust:
            moving_average_decay = \
                optim.ExponentialMovingAverage.adjust_momentum(
                    FLAGS.moving_average_decay,
                    FLAGS.moving_average_decay_base_batch / FLAGS.batch_size)
        else:
            moving_average_decay = FLAGS.moving_average_decay
        logging.info('Moving average for model parameters: {}'.format(
            moving_average_decay))
        ema = optim.ExponentialMovingAverage(moving_average_decay)
        for name, param in model.named_parameters():
            ema.register(name, param)
        # We maintain an EMA for the batch norm running mean and variance as well.
        for name, buffer in model.named_buffers():
            if 'running_var' in name or 'running_mean' in name:
                ema.register(name, buffer)
    return ema
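
A hedged usage sketch of the batch-size adjustment in `setup_ema` (the flag values below are made up for illustration): a different batch size changes the number of optimizer steps per epoch, and `adjust_momentum` rescales the per-step decay so that the averaging horizon measured in epochs stays roughly the same.

    from utils import optim  # same import as in setup_ema above

    # Hypothetical values, for illustration only.
    moving_average_decay = 0.9999    # decay tuned for the base batch size
    base_batch, batch_size = 256, 512

    # Twice the batch size means half as many steps per epoch, so the
    # per-step momentum is rescaled accordingly.
    adjusted = optim.ExponentialMovingAverage.adjust_momentum(
        moving_average_decay, base_batch / batch_size)
    ema = optim.ExponentialMovingAverage(adjusted)
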
Example #8
 def testAverageVariablesNumUpdates_Vector_Debias(self):
     # With num_updates 1, the decay applied is 0.1818
     ema = optim.ExponentialMovingAverage(0.25, zero_debias=True)
     self._CheckDecay(ema, actual_decay=0.181818, dim=5, num_updates=1)
Example #9
 def testAverageVariablesNoNumUpdates_Vector_Debias(self):
     ema = optim.ExponentialMovingAverage(0.25, zero_debias=True)
     self._CheckDecay(ema, actual_decay=0.25, dim=5)
Example #10
 def testAverageVariablesNumUpdates_Vector(self):
     ema = optim.ExponentialMovingAverage(0.25)
     self._CheckDecay(ema, actual_decay=0.181818, dim=5, num_updates=1)
Example #11
 def testAverageVariablesNumUpdates_Scalar(self):
     # With num_updates 1, the decay applied is 0.1818
     ema = optim.ExponentialMovingAverage(0.25)
     self._CheckDecay(ema, actual_decay=0.181818, dim=1, num_updates=1)
Example #12
 def testAverageVariablesNoNumUpdates_Scalar(self):
     ema = optim.ExponentialMovingAverage(0.25)
     self._CheckDecay(ema, actual_decay=0.25, dim=1)
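
Taken together, Examples 8 to 12 pin down the two decay modes: passing num_updates=1 caps the effective decay at (1 + 1) / (10 + 1) ≈ 0.1818 regardless of zero_debias, while omitting num_updates applies the raw momentum 0.25 directly. zero_debias=True therefore appears to change only the initialization (a zero-initialized shadow with bias correction, in the style of Adam), not the steady-state decay these tests verify.
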
Example #13
def train_val_test():
    """Train and val."""
    torch.backends.cudnn.benchmark = True

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS.label_smoothing,
        reduction='none').cuda()
    # TODO: compute the loss on all GPUs instead of only `cuda:0` in the
    # non-distributed case

    ema = None
    if FLAGS.moving_average_decay > 0.0:
        if FLAGS.moving_average_decay_adjust:
            moving_average_decay = optim.ExponentialMovingAverage.adjust_momentum(
                FLAGS.moving_average_decay,
                FLAGS.moving_average_decay_base_batch / FLAGS.batch_size)
        else:
            moving_average_decay = FLAGS.moving_average_decay
        logging.info('Moving average for model parameters: {}'.format(
            moving_average_decay))
        ema = optim.ExponentialMovingAverage(moving_average_decay)
        for name, param in model.named_parameters():
            ema.register(name, param)
        # We maintain an EMA for the batch norm running mean and variance as well.
        for name, buffer in model.named_buffers():
            if 'running_var' in name or 'running_mean' in name:
                ema.register(name, buffer)

    if FLAGS.get('log_graph_only', False):
        if is_root_rank:
            _input = torch.zeros(1, 3, FLAGS.image_size,
                                 FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            summary_writer.add_graph(model_wrapper, (_input, ), verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                logging.info('remap {} to {}'.format(key_old, key_new))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(os.path.join(FLAGS.resume,
                                             'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if is_root_rank:
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # last_epoch = lr_scheduler.last_epoch
        last_epoch = -1
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and is_root_rank:
        logging.info(model_wrapper)
    assert FLAGS.profiling, '`m.macs` is used for calculating the penalty'
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)

    # data
    (train_transforms, val_transforms,
     test_transforms) = dataflow.data_transforms(FLAGS)
    (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                      val_transforms,
                                                      test_transforms, FLAGS)
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    # get bn's weights
    FLAGS._bn_to_prune = prune.get_bn_to_prune(model, FLAGS.prune_params)
    rho_scheduler = prune.get_rho_scheduler(FLAGS.prune_params,
                                            FLAGS._steps_per_epoch)

    if FLAGS.test_only and (test_loader is not None):
        if is_root_rank:
            logging.info('Start testing.')
        test_meters = get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test')
        return

    # parameters were already broadcast by AllReduceDistributedDataParallel;
    # optimizer state matches across ranks (same checkpoint or same
    # initialization)

    if is_root_rank:
        logging.info('Start training.')

    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        results = run_one_epoch(epoch,
                                train_loader,
                                model_wrapper,
                                criterion_smooth,
                                optimizer,
                                lr_scheduler,
                                ema,
                                rho_scheduler,
                                train_meters,
                                phase='train')

        # val
        results, model_eval_wrapper = validate(epoch, calib_loader, val_loader,
                                               criterion, val_meters,
                                               model_wrapper, ema, 'val')

        if FLAGS.prune_params['method'] is not None:
            prune_threshold = FLAGS.model_shrink_threshold
            masks = prune.cal_mask_network_slimming_by_threshold(
                get_prune_weights(model_eval_wrapper), prune_threshold)
            FLAGS._bn_to_prune.add_info_list('mask', masks)
            flops_pruned, infos = prune.cal_pruned_flops(FLAGS._bn_to_prune)
            log_pruned_info(unwrap_model(model_eval_wrapper), flops_pruned,
                            infos, prune_threshold)
            if flops_pruned >= FLAGS.model_shrink_delta_flops \
                    or epoch == FLAGS.num_epochs - 1:
                ema_only = (epoch == FLAGS.num_epochs - 1)
                shrink_model(model_wrapper, ema, optimizer, FLAGS._bn_to_prune,
                             prune_threshold, ema_only)
        model_kwparams = mb.output_network(unwrap_model(model_wrapper))

        if results['top1_error'] < best_val:
            best_val = results['top1_error']

            if is_root_rank:
                save_status(model_wrapper, model_kwparams, optimizer, ema,
                            epoch, best_val, (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'best_model'))
                logging.info(
                    'New best validation top1 error: {:.4f}'.format(best_val))

        if is_root_rank:
            # save latest checkpoint
            save_status(model_wrapper, model_kwparams, optimizer, ema, epoch,
                        best_val, (train_meters, val_meters),
                        os.path.join(FLAGS.log_dir, 'latest_checkpoint'))

        # NOTE: per the scheduler code, step() should be called after
        # train/val; a stepwise scheduler is used here instead
        # lr_scheduler.step()
    return