def testCompressWithoutState(self):
    """compress_mask on an optimizer with no state only swaps the param.

    The optimizer is built and compressed without ever stepping. The new
    variable must appear in the first param group with no state entry.
    ``parameters[0]`` must be absent from both the state and the group
    (presumably it is ``info['var_old']`` -- confirm in _construct_info).
    """
    params, compress_info = self._construct_info()
    opt = RMSprop(params)
    opt.compress_mask(compress_info, verbose=False)

    first_group = opt.param_groups[0]['params']
    new_var = compress_info['var_new']
    old_var = params[0]

    # The replacement is registered in the group but has no optimizer state.
    self.assertFalse(check_tensor_in(new_var, opt.state))
    self.assertTrue(check_tensor_in(new_var, first_group))
    # The original parameter is gone from both the state and the group.
    self.assertFalse(check_tensor_in(old_var, opt.state))
    self.assertFalse(check_tensor_in(old_var, first_group))
def compress_mask(self, info, verbose=False):
    """Adjust parameters values by masks for dynamic network shrinkage.

    Replaces ``info['var_old']`` with ``info['var_new']`` in the first param
    group that contains it. If the old parameter has non-empty optimizer
    state, new state is built for ``var_new``:

    * Each tensor entry shaped like ``var_old`` is passed through
      ``mask_hook`` into a fresh zero tensor shaped like ``var_new``.
      For RMSprop these are per-element buffers such as ``square_avg``,
      ``momentum_buffer`` or ``grad_avg``.
    * Every other entry (e.g. ``step``) is carried over unchanged.

    Args:
        info: dict read for the keys ``'var_old'``, ``'var_new'``,
            ``'mask_hook'``, ``'mask'``, ``'var_old_name'`` and
            ``'var_new_name'``. ``mask_hook`` is called as
            ``mask_hook(dst, src, mask)``. Presumably it writes the kept
            entries of ``src`` into ``dst`` -- confirm against the caller.
        verbose: if True, log the old -> new variable names.

    Raises:
        AssertionError: if ``var_old`` is not found in any param group.
    """
    var_old = info['var_old']
    var_new = info['var_new']
    mask_hook = info['mask_hook']
    mask = info['mask']
    if verbose:
        logging.info('RMSProp compress: {} -> {}'.format(
            info['var_old_name'], info['var_new_name']))

    found = False
    for group in self.param_groups:
        index = index_tensor_in(var_old, group['params'], raise_error=False)
        found = index is not None
        if not found:
            continue
        if check_tensor_in(var_old, self.state):
            state = self.state.pop(var_old)
            if len(state) != 0:  # generate new state
                new_state = {}
                for key, value in state.items():
                    is_per_element = (key != 'step'
                                      and torch.is_tensor(value)
                                      and value.shape == var_old.shape)
                    if is_per_element:
                        # Don't hard-code buffer names: the previous list
                        # held Adam's keys and never matched RMSprop state.
                        buf = torch.zeros_like(var_new.data,
                                               device=var_old.device)
                        mask_hook(buf, value, mask)
                        # Tensor.to is not in-place; keep the returned tensor.
                        new_state[key] = buf.to(value.device)
                    else:
                        # Step counter and other non-per-element entries.
                        new_state[key] = value
                self.state[var_new] = new_state
        # update group
        del group['params'][index]
        group['params'].append(var_new)
        break
    assert found, 'Var: {} not in RMSProp'.format(info['var_old_name'])
def compress_drop(self, info, verbose=False):
    """Remove unused parameters for dynamic network shrinkage.

    Deletes ``info['var_old']`` from the first param group that contains it.
    Any optimizer state stored for it is discarded.

    Args:
        info: dict read for the keys ``'var_old'``, ``'var_old_name'`` and
            ``'type'``. ``'type'`` must equal ``'variable'``.
        verbose: if True, log the dropped variable name.

    Raises:
        AssertionError: if ``info['type']`` is not ``'variable'``, or if
            ``var_old`` is not found in any param group.
    """
    var_old = info['var_old']
    if verbose:
        logging.info('RMSProp drop: {}'.format(info['var_old_name']))
    assert info['type'] == 'variable'
    found = False
    for group in self.param_groups:
        index = index_tensor_in(var_old, group['params'], raise_error=False)
        found = index is not None
        if found:
            if check_tensor_in(var_old, self.state):
                self.state.pop(var_old)
            del group['params'][index]
            # Stop here. Continuing would reassign ``found`` from later
            # groups that lack var_old and wrongly trip the assert below.
            break
    assert found, 'Var: {} not in RMSProp'.format(info['var_old_name'])