Пример #1
0
 def testCompressWithoutState(self):
     """compress_mask on a freshly built optimizer only rewires the params.

     No ``step()`` is taken here, so the optimizer holds no per-parameter
     state: after compression neither ``info['var_new']`` nor ``params[0]``
     may appear in ``optimizer.state``, while ``info['var_new']`` must be
     present in, and ``params[0]`` absent from, the first param group.
     """
     params, compress_info = self._construct_info()
     opt = RMSprop(params)
     opt.compress_mask(compress_info, verbose=False)

     new_var = compress_info['var_new']
     first_param = params[0]
     # compress_mask edits this list in place, so one lookup suffices.
     group_params = opt.param_groups[0]['params']

     self.assertFalse(check_tensor_in(new_var, opt.state))
     self.assertTrue(check_tensor_in(new_var, group_params))
     self.assertFalse(check_tensor_in(first_param, opt.state))
     self.assertFalse(check_tensor_in(first_param, group_params))
Пример #2
0
    def compress_mask(self, info, verbose=False):
        """Adjust parameters values by masks for dynamic network shrinkage.

        Replaces ``info['var_old']`` with ``info['var_new']`` in the first
        param group that contains it. If the optimizer holds non-empty state
        for the old tensor, that state is removed and a state entry for the
        new tensor is built: ``'step'`` is copied and every known
        per-parameter buffer is rebuilt at the new tensor's shape via
        ``info['mask_hook']``.

        Args:
            info: dict providing ``'var_old'``, ``'var_new'``, ``'mask_hook'``
                and ``'mask'``; ``'var_old_name'`` / ``'var_new_name'`` are
                read for logging and for the failure message.
                ``mask_hook(dst, src, mask)`` is presumably expected to fill
                ``dst`` in place from the kept entries of ``src`` — confirm
                against the hook implementation.
            verbose: if True, log the old -> new parameter names.

        Raises:
            AssertionError: if ``var_old`` is not in any param group.
        """
        var_old = info['var_old']
        var_new = info['var_new']
        mask_hook = info['mask_hook']
        mask = info['mask']
        if verbose:
            logging.info('RMSProp compress: {} -> {}'.format(
                info['var_old_name'], info['var_new_name']))

        # Per-parameter buffers to carry over. The first three are Adam-style
        # names; the last three are the buffers torch.optim.RMSprop keeps.
        # Keys missing from ``state`` are skipped, so listing both is harmless
        # and avoids silently dropping RMSprop buffers.
        # NOTE(review): confirm this set against this class's step().
        buffer_keys = ('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq',
                       'square_avg', 'momentum_buffer', 'grad_avg')

        found = False
        for group in self.param_groups:
            index = index_tensor_in(var_old,
                                    group['params'],
                                    raise_error=False)
            found = index is not None
            if found:
                if check_tensor_in(var_old, self.state):
                    state = self.state.pop(var_old)
                    if len(state) != 0:  # generate new state
                        new_state = {'step': state['step']}
                        for key in buffer_keys:
                            if key not in state:
                                continue
                            new_buffer = torch.zeros_like(
                                var_new.data, device=var_old.device)
                            mask_hook(new_buffer, state[key], mask)
                            # ``Tensor.to`` returns a new tensor rather than
                            # moving in place; keep its result so the buffer
                            # really lands on the old buffer's device.
                            new_buffer = new_buffer.to(state[key].device)
                            new_state[key] = new_buffer
                        self.state[var_new] = new_state

                # update group: swap the old tensor for the new one
                del group['params'][index]
                group['params'].append(var_new)
                break
        assert found, 'Var: {} not in RMSProp'.format(info['var_old_name'])
Пример #3
0
    def compress_drop(self, info, verbose=False):
        """Remove unused parameters for dynamic network shrinkage.

        Deletes ``info['var_old']`` from the first param group that contains
        it and discards any optimizer state stored for it.

        Args:
            info: dict providing ``'var_old'``, ``'type'`` (must be
                ``'variable'``) and ``'var_old_name'`` (used for logging and
                for the failure message).
            verbose: if True, log the name of the dropped parameter.

        Raises:
            AssertionError: if ``info['type']`` is not ``'variable'`` or the
                tensor is not in any param group.
        """
        var_old = info['var_old']
        if verbose:
            logging.info('RMSProp drop: {}'.format(info['var_old_name']))

        assert info['type'] == 'variable'
        found = False
        for group in self.param_groups:
            index = index_tensor_in(var_old,
                                    group['params'],
                                    raise_error=False)
            found = index is not None
            if found:
                if check_tensor_in(var_old, self.state):
                    self.state.pop(var_old)
                del group['params'][index]
                # Stop at the first match (PyTorch does not allow a parameter
                # in more than one group). Without this, later groups reset
                # ``found`` to False and the assertion below fails whenever
                # the tensor is not in the last group.
                break
        assert found, 'Var: {} not in RMSProp'.format(info['var_old_name'])