Example #1
    def testCompressWithoutState(self):
        parameters, info = self._construct_info()
        optimizer = RMSprop(parameters)
        optimizer.compress_mask(info, verbose=False)
        self.assertFalse(check_tensor_in(info['var_new'], optimizer.state))
        self.assertTrue(
            check_tensor_in(info['var_new'],
                            optimizer.param_groups[0]['params']))
        self.assertFalse(check_tensor_in(parameters[0], optimizer.state))
        self.assertFalse(
            check_tensor_in(parameters[0],
                            optimizer.param_groups[0]['params']))
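The check_tensor_in helper used above comes from the project's test utilities and is not shown in this listing. A minimal sketch of the behaviour the assertions rely on, membership by object identity over either the optimizer's state dict or a parameter list, could look like the following; the name and call signature are taken from the test, the body is an assumption:

def check_tensor_in(tensor, collection):
    # Membership by object identity: comparing tensors with == is
    # elementwise, so a plain `in` check is not a reliable substitute.
    items = collection.keys() if hasattr(collection, 'keys') else collection
    return any(t is tensor for t in items)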
Example #2
    def testCompressWithEmaOptimizerPruneinfo(self):
        from utils.optim import ExponentialMovingAverage
        from utils.rmsprop import RMSprop
        from utils.prune import PruneInfo

        import sys
        import logging
        logging.basicConfig(stream=sys.stdout,
                            level=logging.INFO,
                            datefmt='%m/%d %I:%M:%S %p')

        def create_ema(m):
            ema = ExponentialMovingAverage(0.25)
            for name, param in m.named_parameters():
                ema.register(name, param)
            for name, param in m.named_buffers():
                if 'running_var' in name or 'running_mean' in name:
                    ema.register(name, param)
            return ema

        def get_weight_name(index):
            return 'ops.{}.1.1.weight'.format(index)

        inp, oup, expand = 3, 5, True
        m = mb.InvertedResidualChannels(inp, oup, 1, [2, 4], [3, 5], expand,
                                        torch.nn.ReLU)
        num_var = len(list(m.parameters()))
        optimizer = RMSprop(m.parameters())
        ema = create_ema(m)
        prune_info = PruneInfo([get_weight_name(i) for i in range(len(m.ops))],
                               [1, 2])
        masks = [
            torch.tensor([False, False]),
            torch.tensor([True, True, True, False])
        ]
        m.compress_by_mask(masks,
                           ema=ema,
                           optimizer=optimizer,
                           prune_info=prune_info,
                           verbose=False)
        ema2 = create_ema(m)
        self.assertLess(len(list(m.parameters())), num_var)
        self.assertEqual(set(m.parameters()),
                         set(optimizer.param_groups[0]['params']))
        self.assertEqual(set(ema.average_names()), set(ema2.average_names()))
        self.assertListEqual(prune_info.weight, [get_weight_name(0)])
        self.assertEqual(len(prune_info.weight), 1)
        self.assertListEqual(prune_info.penalty, [2])
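The boolean masks in the test above mark which output channels survive compression. As a rough illustration using plain PyTorch indexing (not the project's compress_by_mask), keeping only the channels where the mask is True shrinks the first dimension of a convolution weight:

import torch

weight = torch.randn(4, 3, 3, 3)                # 4 output channels
mask = torch.tensor([True, True, True, False])  # drop the last channel
pruned = weight[mask]                           # shape becomes (3, 3, 3, 3)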
Example #3
File: optim.py  Project: dingmyu/HR-NAS
def get_optimizer(model, FLAGS):
    """Get optimizer."""
    if FLAGS.prune_params['method'] is not None:
        weight_decay = 0
    else:
        weight_decay = FLAGS.weight_decay
    if FLAGS.optimizer == 'sgd':
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            nesterov=FLAGS.nesterov,
            weight_decay=weight_decay
        )  # set weight decay only on convs and fcs manually.
    elif FLAGS.optimizer == 'rmsprop':
        optimizer = RMSprop(model.parameters(),
                            lr=FLAGS.lr,
                            alpha=FLAGS.alpha,
                            momentum=FLAGS.momentum,
                            eps=FLAGS.epsilon,
                            eps_inside_sqrt=FLAGS.eps_inside_sqrt,
                            weight_decay=weight_decay)
    elif FLAGS.optimizer == 'adamw':
        optimizer = AdamW(model.parameters(),
                          lr=FLAGS.lr,
                          eps=FLAGS.epsilon,
                          weight_decay=weight_decay)
    elif FLAGS.optimizer == 'adam':
        optimizer = AdamW(model.parameters(),
                          lr=FLAGS.lr,
                          eps=FLAGS.epsilon,
                          weight_decay=weight_decay)
    else:
        try:
            optimizer_lib = importlib.import_module(FLAGS.optimizer)
            return optimizer_lib.get_optimizer(model)
        except ImportError:
            raise NotImplementedError(
                'Optimizer {} is not yet implemented.'.format(FLAGS.optimizer))
    return optimizer
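A hypothetical usage sketch for the function above, with a SimpleNamespace standing in for the FLAGS object; the field names are taken from the snippet, the values are placeholders:

from types import SimpleNamespace
import torch

flags = SimpleNamespace(optimizer='sgd', lr=0.05, momentum=0.9,
                        nesterov=True, weight_decay=1e-4,
                        prune_params={'method': None})
model = torch.nn.Linear(8, 4)
optimizer = get_optimizer(model, flags)  # resolves to the torch.optim.SGD branch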
Example #4
def get_optimizer(model, FLAGS):
    """Get optimizer."""
    if FLAGS.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=FLAGS.lr,
                                    momentum=FLAGS.momentum,
                                    nesterov=FLAGS.nesterov,
                                    weight_decay=0)
    elif FLAGS.optimizer == 'rmsprop':
        optimizer = RMSprop(model.parameters(),
                            lr=FLAGS.lr,
                            alpha=FLAGS.alpha,
                            momentum=FLAGS.momentum,
                            eps=FLAGS.epsilon,
                            eps_inside_sqrt=FLAGS.eps_inside_sqrt,
                            weight_decay=0)
    else:
        try:
            optimizer_lib = importlib.import_module(FLAGS.optimizer)
            return optimizer_lib.get_optimizer(model)
        except ImportError:
            raise NotImplementedError(
                'Optimizer {} is not yet implemented.'.format(FLAGS.optimizer))
    return optimizer
Example #5
    def testCompressUpdate(self):
        params, info = self._construct_info()

        params0 = copy.deepcopy(params)
        apply_gradients([p.grad for p in params], params0)
        optimizer = RMSprop(params0, lr=0.1, momentum=0.5)
        optimizer.step()

        params1 = copy.deepcopy(params)
        apply_gradients([p.grad for p in params], params1)
        optimizer1 = RMSprop(params1, lr=0.1, momentum=0.5)
        optimizer1.step()

        assertAllClose(params0[1], params1[1])
        assertAllClose(params0[2], params1[2])
        assertAllClose(params0[0], params1[0])

        info['var_old'] = params1[0]
        optimizer1.compress_mask(info, verbose=True)
        optimizer1.compress_drop({'var_old': params1[2], 'type': 'variable'})
        info['mask_hook'](info['var_new'], info['var_old'], info['mask'])
        params1[0] = info['var_new']
        params1[0].grad = params0[0].grad.data[info['mask']]

        optimizer1.step()  # params1[2] not updated
        assertAllClose(params0[2], params1[2])

        optimizer.step()
        assertAllClose(params0[1], params1[1])
        assertAllClose(params0[0][info['mask']], params1[0])
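The test above reads var_old, var_new, mask, and mask_hook out of the info dict built by _construct_info, which is not shown in this listing. A rough sketch of a dict with the same keys (an assumption, not the project's implementation) might be assembled like this:

import torch

def make_info(var_old, mask):
    # Keep only the rows selected by the boolean mask along dim 0.
    var_new = torch.nn.Parameter(var_old.data[mask].clone())

    def mask_hook(new, old, m):
        # Refill the compressed variable from the masked rows of the old one.
        new.data.copy_(old.data[m])

    return {'var_old': var_old, 'var_new': var_new,
            'mask': mask, 'mask_hook': mask_hook}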