def _add_hook_and_collect(self, model: nn.Module, inputs, custom_ops=None, verbose=True):
    """Count per-module params/ops with forward hooks and record them on ``self``.

    Every submodule of ``model`` whose type has a counting rule in
    ``custom_ops`` (checked first) or ``register_hooks`` gets two temporary
    float64 buffers (``total_ops``, ``total_params``) and two forward hooks:
    the rule itself and a parameter counter. One forward pass is run on
    ``inputs`` in eval mode under ``torch.no_grad()``. Then, for each module
    index of ``self.model.modules()`` that is in ``self.all_idx``, the counts
    are stored in ``self.params_dict`` / ``self.flops_dict`` and appended to
    ``self.params_list`` / ``self.flops_list``.

    Modules without a rule are recorded as zero params and zero ops, as the
    warning printed for them states.

    Args:
        model: module to hook and run.
        inputs: positional arguments for ``model`` (unpacked with ``*``).
        custom_ops: optional ``{module_type: hook_fn}`` taking precedence over
            ``register_hooks``.
        verbose: print which rule is used for each module type.

    The training flag is restored and all hooks and temporary buffers are
    removed, even if the forward pass raises.
    """
    handler_collection = {}
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}

    def count_parameters(m, x, y):
        # Parameter count of ``m`` (recursive, via ``m.parameters()``) written
        # into its float64 ``total_params`` buffer on each forward call.
        m.total_params[0] = float(sum(p.numel() for p in m.parameters()))

    def add_hooks(m: nn.Module):
        m_type = type(m)
        fn = None
        if m_type in custom_ops:
            # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
        else:
            if m_type not in types_collection and verbose:
                prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." % m_type)

        if fn is not None:
            m.register_buffer('total_ops', torch.zeros(1, dtype=torch.float64))
            m.register_buffer('total_params', torch.zeros(1, dtype=torch.float64))
            handler_collection[m] = (m.register_forward_hook(fn),
                                     m.register_forward_hook(count_parameters))
        types_collection.add(m_type)

    prev_training_status = model.training
    model.eval()
    try:
        model.apply(add_hooks)
        with torch.no_grad():
            model(*inputs)

        # collecting flops and params
        # NOTE(review): hooks are installed on ``model`` but results are read
        # from ``self.model``; this assumes both share the same module tree.
        for i, m in enumerate(self.model.modules()):
            if i not in self.all_idx:
                continue
            if m in handler_collection:
                params = m.total_params.item()
                flops = m.total_ops.item()
            else:
                # No counting rule: zero Macs and zero Params, as warned above.
                params = flops = 0.0
            self.params_dict[i] = params
            self.flops_dict[i] = flops
            self.params_list.append(params)
            self.flops_list.append(flops)
    finally:
        model.train(prev_training_status)
        for m, (op_handler, params_handler) in handler_collection.items():
            op_handler.remove()
            params_handler.remove()
            m._buffers.pop("total_ops", None)
            m._buffers.pop("total_params", None)
def set_active_group(module: nn.Module, group):
    """Hand the distributed ``group`` to every submodule that accepts one.

    ``module.apply`` visits ``module`` and all of its descendants; each one
    that has a ``set_group`` attribute is called as ``set_group(group)``.
    """

    def _visit(submodule):
        if not hasattr(submodule, "set_group"):
            return
        submodule.set_group(group)

    module.apply(_visit)
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16, in place.

    Converted via ``.data = .data.half()``:
      * weight and bias of ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Linear``
        (subclasses included);
      * ``in/q/k/v_proj_weight``, ``in_proj_bias``, ``bias_k`` and ``bias_v``
        of ``nn.MultiheadAttention`` when they are not None;
      * ``text_projection`` / ``proj`` attributes holding a tensor.

    A ``proj``/``text_projection`` attribute that is a submodule (for example
    ``self.proj = nn.Linear(...)``) is skipped here because it has no
    ``.data``. ``model.apply`` visits that submodule on its own.
    """

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [
                *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                "in_proj_bias",
                "bias_k",
                "bias_v",
            ]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            attr = getattr(l, name, None)
            # Only tensors/parameters carry ``.data``; submodules are converted
            # when ``apply`` reaches them.
            if attr is not None and not isinstance(attr, nn.Module):
                attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)
def convert_weights(model: nn.Module):
    '''Convert applicable model parameters to fp16, in place.

    Converted via ``.data = .data.half()``:
      * weight and bias of ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Linear``
        (subclasses included);
      * ``in/q/k/v_proj_weight``, ``in_proj_bias``, ``bias_k`` and ``bias_v``
        of ``nn.MultiheadAttention`` when they are not None;
      * ``text_projection`` / ``proj`` attributes holding a tensor.

    A ``proj``/``text_projection`` attribute that is a submodule is skipped
    here because it has no ``.data``. ``model.apply`` converts it when it
    visits that submodule.
    '''

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [
                *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
                'in_proj_bias',
                'bias_k',
                'bias_v',
            ]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ['text_projection', 'proj']:
            attr = getattr(l, name, None)
            # Submodules have no ``.data``; they are handled when visited.
            if attr is not None and not isinstance(attr, nn.Module):
                attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)
def init_weights(net: nn.Module, init_type: str = 'xavier_uniform', init_gain: float = 1.) -> None:
    """Initialize the weights of ``net`` in place and log the settings used.

    Modules with a ``weight`` attribute whose class name contains ``Conv`` or
    ``Linear`` get their weight re-drawn according to ``init_type`` and their
    bias (if not None) set to 0. Modules whose class name contains
    ``BatchNorm2d`` get weight ~ N(1, init_gain) and bias = 0 when those
    parameters exist. ``affine=False`` layers (weight/bias None) are skipped.

    Args:
        net: network to initialize.
        init_type: ``normal`` | ``xavier_normal`` | ``xavier_uniform`` |
            ``kaiming`` | ``orthogonal``. ``kaiming`` ignores ``init_gain``.
        init_gain: std (normal) or gain (xavier/orthogonal).

    Raises:
        NotImplementedError: unknown ``init_type``, raised when the first
            Conv/Linear module is visited.
    """

    def init_func(m: nn.Module):
        name = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in name or 'Linear' in name):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier_normal':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'xavier_uniform':
                init.xavier_uniform_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
            # BatchNorm2d(affine=False) stores None for weight and bias.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

    net.apply(init_func)
    logger.info(f'Weights have been initialized with (type={init_type}, gain={init_gain}).')
def prepare_model_for_grad_receptive_field(model: nn.Module) -> nn.Module:
    """Return a modified deep copy of ``model``; the input model is not touched.

    ``modify_for_grad_receptive_field`` is applied to the copy and to every
    one of its submodules.

    See:
    https://github.com/rogertrullo/Receptive-Field-in-Pytorch/blob/master/Receptive_Field.ipynb
    """
    copied = deepcopy(model)
    copied.apply(modify_for_grad_receptive_field)
    return copied
def disable_bn_stats(model: nn.Module):
    """Generator that disables ``track_running_stats`` while it is suspended at ``yield``.

    Presumably used through ``contextlib.contextmanager``; the decorator, if
    any, is not visible here. Every submodule with a ``track_running_stats``
    attribute has it set to ``False`` before the ``yield``. Each module's
    original value is restored afterwards, including when an exception is
    thrown into the generator. Unlike toggling with ``^= True``, modules that
    already had tracking off are not switched on.
    """
    saved = []

    def f(m: nn.Module):
        if hasattr(m, 'track_running_stats'):
            saved.append((m, m.track_running_stats))
            m.track_running_stats = False

    model.apply(f)
    try:
        yield
    finally:
        # Reverse order so a module visited twice ends with its first-seen value.
        for m, value in reversed(saved):
            m.track_running_stats = value
def disable_batchnorm_tracking(model: Module) -> None:
    """Generator that disables ``track_running_stats`` while it is suspended at ``yield``.

    Adapted from: https://github.com/lyakaap/VAT-pytorch/blob/master/vat.py

    Presumably used through ``contextlib.contextmanager``; the decorator, if
    any, is not visible here. Every submodule with a ``track_running_stats``
    attribute has it set to ``False`` before the ``yield``, and the original
    values are restored afterwards, also when an exception is thrown into the
    generator. The previous ``^= True`` toggle would have turned tracking on
    for modules that had it off.
    """
    saved = []

    def switch_attr(m):
        if hasattr(m, 'track_running_stats'):
            saved.append((m, m.track_running_stats))
            m.track_running_stats = False

    model.apply(switch_attr)
    try:
        yield
    finally:
        # Reverse order so a module visited twice ends with its first-seen value.
        for m, value in reversed(saved):
            m.track_running_stats = value
def initialize_weights(
    model: nn.Module,
    initialization_function: Callable,
    pre_initialization: Optional[Callable] = None,
    is_admin=False,
    **kwargs,
) -> None:
    """Run ``initialization_function`` on ``model`` and every submodule.

    If given, ``pre_initialization(model)`` is called once and must return a
    dict. Its entries, together with ``kwargs``, are bound as keyword
    arguments of ``initialization_function`` (duplicate keys raise
    ``TypeError``). ``is_admin`` is accepted but unused here.
    """
    extra_kwargs = {} if pre_initialization is None else pre_initialization(model)
    bound_init = partial(initialization_function, **kwargs, **extra_kwargs)
    model.apply(bound_init)
def __call__(self, module: nn.Module) -> None:
    """Apply ``trunc_normal_init`` to ``module`` and its submodules.

    If ``self.wholemodule`` is set, every module visited by ``module.apply``
    is initialized. Otherwise a module is initialized once for each entry of
    ``self.layer`` that equals its class name.
    """

    def _init_one(m):
        if self.wholemodule:
            repeats = 1
        else:
            class_name = m.__class__.__name__
            repeats = sum(1 for wanted in self.layer if class_name == wanted)
        for _ in range(repeats):
            trunc_normal_init(m, self.mean, self.std, self.a, self.b, self.bias)

    module.apply(_init_one)
def __call__(self, module: nn.Module) -> None:
    """Apply ``trunc_normal_init`` to ``module`` and its submodules.

    If ``self.wholemodule`` is set, every visited module is initialized.
    Otherwise only modules whose class name, or one of the names returned by
    ``_get_bases_name``, appears in ``self.layer``.
    """

    def _init_one(m):
        if not self.wholemodule:
            candidate_names = [m.__class__.__name__] + _get_bases_name(m)
            if set(self.layer).isdisjoint(candidate_names):
                return
        trunc_normal_init(m, self.mean, self.std, self.a, self.b, self.bias)

    module.apply(_init_one)
def prune_model(
    model: Module,
    pruning_fn: Callable,
    keys_to_prune: List[str],
    amount: Union[float, int],
    layers_to_prune: Optional[List[str]] = None,
    reinitialize_after_pruning: Optional[bool] = False,
) -> None:
    """Prune selected tensors of ``model``'s submodules with ``pruning_fn``.

    For every selected submodule, ``pruning_fn(module, name=key, amount=amount)``
    is called for each key in ``keys_to_prune``. A submodule counts as pruned
    only when every key succeeded.

    Args:
        model: Model to be pruned.
        pruning_fn: Pruning function with the same API as in
            ``torch.nn.utils.prune``: ``pruning_fn(module, name, amount)``.
        keys_to_prune: names of the tensors to prune in each module
            (e.g. ``["weight"]``).
        amount: quantity of parameters to prune. If float, should be between
            0.0 and 1.0 and represent the fraction of parameters to prune.
            If int, it represents the absolute number of parameters to prune.
        layers_to_prune: module names (as in ``model.named_modules()``) to be
            pruned. If None, every module is tried and modules for which
            ``pruning_fn`` raises ``AttributeError`` are skipped. Names that
            match no module are ignored.
        reinitialize_after_pruning: if True, apply ``reset_weights_if_possible``
            to the model after pruning (e.g. for a Lottery Ticket Hypothesis
            check).

    Raises:
        AttributeError: if ``layers_to_prune`` is given and ``pruning_fn``
            raises it for one of those modules (e.g. the module lacks a key).
        Exception: if no module was pruned at all.
    """
    pruned_modules = 0
    for name, module in model.named_modules():
        if layers_to_prune is not None and name not in layers_to_prune:
            continue
        try:
            for key in keys_to_prune:
                pruning_fn(module, name=key, amount=amount)
        except AttributeError:
            # Explicitly requested layers must have the keys; otherwise skip.
            if layers_to_prune is not None:
                raise
            continue
        pruned_modules += 1

    if pruned_modules == 0:
        raise Exception(f"There is no {keys_to_prune} key in your model")
    if reinitialize_after_pruning:
        model.apply(reset_weights_if_possible)
def export_onnx(
        cls,
        module: Module,
        input_shape: Tuple[int, ...],
        export_path: str,
        input_t: Optional[Union[Tensor, QuantTensor]] = None,
        **kwargs):
    """Export ``module`` to ONNX at ``export_path``, then optimize and rewrite the file.

    * input_shape : tuple describing the shape of network input e.g. (1, 1, 28, 28)
    * export_path : ONNX filename to export to
    * input_t : if specified, do an initial forward pass with this value. this
      may be necessary for QuantTensor caching.
    * kwargs : passed to ``torch.onnx.export``

    The module's training mode, export mode and input-caching mode are
    restored after the export pass, also when it raises.
    """
    from contextlib import ExitStack  # local import keeps this block self-contained

    def set_export_handler(m: Module):
        if hasattr(m, 'export_handler') and m.export_handler is None:
            handler = cls.handler_from_module(m)
            m.export_handler = handler()

    if onnx is None or opt is None:
        raise ModuleNotFoundError("Installation of ONNX is required.")

    cls.solve_keep_initializers_as_inputs(kwargs)
    cls.solve_enable_onnx_checker(kwargs)

    with torch.no_grad(), ExitStack() as restore:
        # Callbacks run in reverse order: caching mode, export mode, training mode.
        training_state = module.training
        module = module.eval()
        restore.callback(module.train, training_state)
        module.apply(set_export_handler)
        if input_t is None:
            input_t = torch.empty(input_shape, dtype=torch.float)
        # do a forward pass with the dummy input to e.g. store input/output shapes
        cls.cache_inp_out(module, input_t)
        # override any given input_t to make sure it's a standard PyTorch tensor
        input_t = torch.empty(input_shape, dtype=torch.float)
        # enable export mode, this triggers collecting export values into handlers
        module.apply(lambda m: _set_export_mode(m, enabled=True))
        restore.callback(module.apply, lambda m: _set_export_mode(m, enabled=False))
        # temporarily disable input caching to avoid collectives empty debug values
        module.apply(lambda m: _override_inp_caching_mode(m, enabled=False))
        restore.callback(module.apply, lambda m: _restore_inp_caching_mode(m))
        # perform export pass
        torch.onnx.export(module, input_t, export_path, **kwargs)

    # do some cleanup on the exported ONNX model
    model = onnx.load(export_path)
    model = opt.optimize(model, cls.onnx_passes)
    model = cls.apply_model_transforms(model)
    onnx.save(model, export_path)
def __call__(self, module: nn.Module) -> None:
    """Apply ``trunc_normal_init`` to ``module`` and its submodules, then record init info.

    If ``self.wholemodule`` is set, every visited module is initialized.
    Otherwise only modules whose class name, or one of the names from
    ``_get_bases_name``, appears in ``self.layer``. If ``module`` has a
    ``_params_init_info`` attribute, ``update_init_info`` is called with
    ``self._get_init_info()``.
    """

    def _init_one(m):
        if not self.wholemodule:
            candidate_names = [m.__class__.__name__] + _get_bases_name(m)
            if set(self.layer).isdisjoint(candidate_names):
                return
        trunc_normal_init(m, self.mean, self.std, self.a, self.b, self.bias)

    module.apply(_init_one)

    if not hasattr(module, '_params_init_info'):
        return
    update_init_info(module, init_info=self._get_init_info())
def _apply_weight_init(self, init_method: Union[str, FunctionType], proto_model: Module):
    """Choose ``self.init_fc`` from ``init_method`` and apply ``self._weight_init`` to ``proto_model``.

    ``init_method`` may be:
      * falsy (e.g. ``None``): ``self.init_fc`` is left unchanged;
      * ``"kaiming"`` / ``"xavier"``: shorthand for ``init.kaiming_normal_`` /
        ``init.xavier_normal_``;
      * any other string: looked up with ``getattr(init, init_method)``
        (e.g. ``"orthogonal_"``);
      * a callable: used as is.

    ``self._weight_init`` is applied to ``proto_model`` and all of its
    submodules in every case.

    Returns:
        The name of the chosen initializer: the string itself for string
        input, ``init_method.__name__`` for a callable, ``"No"`` otherwise.
    """
    aliases = {"kaiming": "kaiming_normal_", "xavier": "xavier_normal_"}
    init_name = "No"
    if init_method:
        if callable(init_method):
            # Previously ``getattr(init, <function>)`` raised TypeError here.
            self.init_fc = init_method
            init_name = init_method.__name__
        else:
            self.init_fc = getattr(init, aliases.get(init_method, init_method))
            init_name = init_method
    proto_model.apply(self._weight_init)
    return init_name
def our_train(model: nn.Module, labels: list, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader,
              ewcs: list, lam: float, gpu: torch.device, cut_idx, if_freeze):
    """Train ``model`` for one epoch with EWC penalties, optionally freezing early parameters.

    When ``if_freeze == 1``, the parameters at positions ``< cut_idx`` in
    ``model.parameters()`` get ``requires_grad = False`` for this epoch. A new
    ``optim.Adam`` (default hyper-parameters) over the remaining trainable
    parameters replaces the ``optimizer`` argument. Those parameters are set
    back to ``requires_grad = True`` afterwards, also if training raises.

    Logits of classes not in ``labels`` are set to 0 before the cross-entropy
    loss, and ``(lam / 2) * ewc.penalty(model)`` is added for every ``ewc``.

    Returns:
        The mean loss per batch.
    """
    # TODO(review): original note: "a loss-based check is still needed; true: freeze".
    frozen_params = []
    if if_freeze == 1:
        frozen_params = [p for idx, p in enumerate(model.parameters()) if idx < cut_idx]
        for param in frozen_params:
            param.requires_grad = False
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()))  # no need to add: lr=0.1?

    try:
        model.train()
        model.apply(set_bn_eval)  # freeze BatchNorm layers and their statistics
        criterion = nn.CrossEntropyLoss()
        epoch_loss = 0
        for data, target in data_loader:
            data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
            optimizer.zero_grad()
            output = model(data)
            # Zero the logits of classes outside ``labels``.
            for idx in range(output.size(1)):
                if idx not in labels:
                    output[range(len(output)), idx] = 0
            loss = criterion(output, target)
            for ewc in ewcs:
                loss += (lam / 2) * ewc.penalty(model)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
    finally:
        # Unfreeze the parameters frozen above.
        for param in frozen_params:
            param.requires_grad = True

    return epoch_loss / len(data_loader)
def to_dependentmodule(cls, module: Module, recurse=True):
    r"""Transform a module and, by default, all of its submodules into dependent modules.

    Args:
        module: the module to transform.
        recurse: if True, ``cls.to_dependentmodule(sub, recurse=False)`` is
            called for ``module`` and every submodule through ``module.apply``.
            If False, only ``module`` is passed to ``cls._make_subclass``.

    Returns:
        DependentModule: the result of ``cls._make_subclass(module)`` when
        ``recurse`` is False; ``module`` itself otherwise.

    NOTE(review): ``Module.apply`` ignores the callback's return value, so the
    recursive path only works if ``_make_subclass`` converts modules in place.
    Confirm.
    """
    if recurse:
        module.apply(lambda sub: cls.to_dependentmodule(sub, recurse=False))
        return module
    return cls._make_subclass(module)
def normal_train(model: nn.Module, labels: list, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader,
                 gpu: torch.device):
    """Train ``model`` for one epoch with cross-entropy and return the mean batch loss.

    After ``model.train()``, ``set_bn_eval`` is applied to every submodule to
    freeze BatchNorm layers and their statistics. Logits of classes not in
    ``labels`` are set to 0 before the loss.
    """
    model.train()
    model.apply(set_bn_eval)
    loss_fn = nn.CrossEntropyLoss()
    running_total = 0
    for inputs, targets in data_loader:
        inputs, targets = Variable(inputs).cuda(gpu), Variable(targets).cuda(gpu)
        optimizer.zero_grad()
        logits = model(inputs)
        rows = range(len(logits))
        for class_idx in range(logits.size(1)):
            if class_idx not in labels:
                logits[rows, class_idx] = 0
        loss = loss_fn(logits, targets)
        running_total += loss.item()
        loss.backward()
        optimizer.step()
    return running_total / len(data_loader)
def init_weights(net: nn.Module, init_type: str = "normal", init_gain: float = 0.02) -> None:
    """Initialize network weights in place.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and
    kaiming might work better for some applications. Feel free to try yourself.

    Modules with a ``weight`` attribute whose class name contains ``Conv`` or
    ``Linear`` are re-initialized and their bias (if not None) zeroed.
    Modules whose class name contains ``BatchNorm2d`` get weight ~ N(1, init_gain)
    and bias = 0 when those parameters exist (``affine=False`` layers are skipped).

    Args:
        net: Network to be initialized
        init_type: Name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain: Scaling factor for normal, xavier and orthogonal (unused by kaiming).

    Raises:
        NotImplementedError: unknown ``init_type`` (raised on the first Conv/Linear module).
    """

    def init_func(m: nn.Module) -> None:
        classname = m.__class__.__name__
        if hasattr(m, "weight") and ("Conv" in classname or "Linear" in classname):
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type)
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            # BatchNorm weight is not a matrix; only the normal distribution applies.
            # With affine=False, weight and bias are None and are left alone.
            if getattr(m, "weight", None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, "bias", None) is not None:
                init.constant_(m.bias.data, 0.0)

    print("initialize network with %s" % init_type)
    net.apply(init_func)
def bn_update(loader: DataLoader, model: nn.Module):
    """Re-estimate BatchNorm buffers with one pass over ``loader`` (AdaBN).

    Does nothing when ``check_bn(model)`` is falsy. Otherwise the model is put
    in train mode, ``reset_bn`` is applied to every module, and each batch is
    moved to ``"cuda"`` with ``any2device`` and passed as keyword arguments to
    the model. Finally ``fix_bn`` is applied. The model is left in train mode.

    The forward passes run under ``torch.no_grad()``. BatchNorm running
    statistics are updated in train mode regardless of autograd, so recording
    a graph would only waste memory.

    :param loader: train dataset loader for buffer estimation; must have
        ``drop_last=True``.
    :param model: model being updated
    :return: None
    """
    import torch  # local import: the file may only import ``nn`` from torch

    if not check_bn(model):
        return
    assert loader.drop_last, "bn_update requires a DataLoader with drop_last=True"
    model.train()
    model.apply(reset_bn)
    with torch.no_grad():
        for batch in tqdm(loader, desc="AdaBN"):
            batch = any2device(batch, device="cuda")
            model(**batch)
    model.apply(fix_bn)
def jit_inference_trace(cls, module: Module, input_t: Union[Tensor, QuantTensor]):
    """Trace ``module`` in export mode with TorchScript and return an independent copy.

    Steps: switch to eval, install export handlers, cache input/output with a
    forward pass, enable export mode, force ``requires_grad=False`` and trace
    under ``cls._trace_patches()``. The trace is then round-tripped through
    ``torch.jit.save``/``load`` so that restoring ``requires_grad`` does not
    affect it. ``requires_grad``, export mode and training mode are restored
    in that order, also when tracing raises.
    """
    with torch.no_grad(), ExitStack() as restore:
        # Callbacks run in reverse registration order:
        # requires_grad -> export mode off -> training mode.
        training_state = module.training
        module = module.eval()
        restore.callback(module.train, training_state)
        module.apply(cls.set_export_handler)
        # do a forward pass with the input to e.g. store input/output shapes
        cls._cache_inp_out(module, input_t)
        # unpack quant tensor
        if isinstance(input_t, QuantTensor):
            input_t = input_t.value
        # enable export mode, this triggers collecting export values into handlers
        module.apply(lambda m: cls.set_export_mode(m, enabled=True))
        restore.callback(module.apply, lambda m: cls.set_export_mode(m, enabled=False))
        # force requires_grad to False to let the wrapped model lambda go through tracing
        requires_grad_backup_dict = _force_requires_grad_false(module)
        restore.callback(_restore_requires_grad, module, requires_grad_backup_dict)
        with ExitStack() as patches:
            for mgr in cls._trace_patches():
                patches.enter_context(mgr)
            # wrapping with a lambda forces inlining during tracing,
            # converts everything to const and removes unused params/buffers
            traced_model = torch.jit.trace(_JitTraceExportWrapper(module), input_t)
        # Hack to clone the function, otherwise restoring requires_grad
        # on module will break traced_model
        with BytesIO() as tmp:
            torch.jit.save(traced_model, tmp)
            tmp.seek(0)
            traced_model = torch.jit.load(tmp)
    return traced_model
def model_train(n: nn.Module, dl: utils.data.DataLoader, *, device=DEVICE):
    """Train ``n`` on ``dl`` for 150 epochs and checkpoint it every 5 epochs.

    ``init_weights`` is applied to every submodule first. Training uses SGD
    (lr=0.001, momentum=0.9) with cross-entropy. Every 200 batches, the summed
    loss of those batches is printed with overall progress. Every 5 epochs the
    state dict is saved as ``./build/<ISO timestamp>/epoch-NNN.pt``, and
    missing parent directories are created.

    NOTE(review): the ISO timestamp contains ':' which is not valid in
    Windows paths.
    """
    n.to(device)
    n.train()
    n.apply(init_weights)
    save_dir = f"./build/{datetime.now().isoformat(timespec='seconds')}/"
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(n.parameters(), lr=0.001, momentum=0.9)
    epochs = 150
    for epoch in range(1, epochs + 1):
        running_loss = 0.0
        for i, (x, y) in enumerate(dl, 1):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            y_preds = n(x)
            loss = criterion(y_preds, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 200 == 0:
                print("[{:2}, {:5}, {:3.0f}%] loss: {:5.2f}".format(
                    epoch,
                    i,
                    100.0 * (i / len(dl) + epoch - 1) / epochs,
                    running_loss,
                ))
                running_loss = 0.0
        if epoch % 5 == 0:
            # makedirs also creates ./build when it does not exist yet.
            os.makedirs(save_dir, exist_ok=True)
            save_file = os.path.join(save_dir, f"epoch-{epoch:03d}.pt")
            t.save(n.state_dict(), save_file)
def ewc_train(model: nn.Module, labels: list, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader,
              ewcs: list, lam: float, gpu: torch.device):
    """Train ``model`` for one epoch on ``myloss`` plus EWC penalties; return the mean batch loss.

    After ``model.train()``, ``set_bn_eval`` is applied to every submodule to
    freeze BatchNorm layers and their statistics. Each batch loss is
    ``myloss(output, target, labels)`` plus ``(lam / 2) * ewc.penalty(model)``
    for every ``ewc`` in ``ewcs``.
    """
    model.train()
    model.apply(set_bn_eval)
    half_lam = lam / 2
    total = 0
    for batch_x, batch_y in data_loader:
        batch_x, batch_y = Variable(batch_x).cuda(gpu), Variable(batch_y).cuda(gpu)
        optimizer.zero_grad()
        prediction = model(batch_x)
        loss = myloss(prediction, batch_y, labels)
        for ewc in ewcs:
            loss += half_lam * ewc.penalty(model)
        total += loss.item()
        loss.backward()
        optimizer.step()
    return total / len(data_loader)
def profile(model: nn.Module, input_size: tuple):
    """Count ops and parameters over the leaf modules of ``model``.

    Each leaf module gets temporary ``total_ops``/``total_params`` buffers.
    ``total_params`` is filled with its parameter count. Module types found in
    ``register_hooks`` also get that forward hook, which presumably
    accumulates ``total_ops``. One forward pass on ``torch.zeros(input_size)``
    runs on CPU under ``torch.no_grad()``.

    The model is moved to CPU and stays there. Its previous training mode is
    restored, and the hooks and temporary buffers are removed, even if the
    forward pass fails. Leftover buffers would otherwise end up in
    ``state_dict()``.

    Returns:
        ``(int(total_ops), int(total_params))`` summed over leaf modules.
    """
    handler_collection = []
    hooked_modules = []

    def add_hooks(m):
        if len(list(m.children())) > 0:
            return  # only leaf modules are profiled
        m.register_buffer('total_ops', torch.zeros(1))
        m.register_buffer('total_params', torch.zeros(1))
        hooked_modules.append(m)

        for p in m.parameters():
            m.total_params += torch.Tensor([p.numel()])

        if type(m) in register_hooks:
            fn = register_hooks[type(m)]
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)

    was_training = model.training
    model.eval().to('cpu')
    try:
        model.apply(add_hooks)
        with torch.no_grad():
            model(torch.zeros(input_size).to('cpu'))

        model_ops = 0
        model_params = 0
        for m in model.modules():
            if len(list(m.children())) > 0:  # skip for non-leaf module
                continue
            model_ops += m.total_ops
            model_params += m.total_params
        model_ops = model_ops.item()
        model_params = model_params.item()
    finally:
        model.train(was_training)
        for handler in handler_collection:
            handler.remove()
        for m in hooked_modules:
            m._buffers.pop('total_ops', None)
            m._buffers.pop('total_params', None)

    return int(model_ops), int(model_params)
def initialize_layers(model: nn.Module):
    """Fill layer weights with constants (for deterministic test cases) and return ``model``.

    Matching is on the exact type, so subclasses are not affected:
      * ``nn.Linear`` / ``nn.Conv1d``: weight and, if present, bias set to 0.5;
      * ``nn.Embedding``: weight set to 0.15;
      * ``nn.Dropout``: ``nn.Dropout(DROPOUT_RATE)`` is constructed and
        discarded. NOTE(review): this has no effect on the layer; it was
        possibly meant to set ``m.p = DROPOUT_RATE``. Confirm.

    Reference: https://pytorch.org/docs/stable/nn.html#torch.nn.Module.apply
    """

    def _fill_constants(layer):
        layer_type = type(layer)
        if layer_type in (nn.Linear, nn.Conv1d):
            layer.weight.data.fill_(0.5)
            if layer.bias is not None:
                layer.bias.data.fill_(0.5)
        elif layer_type == nn.Embedding:
            layer.weight.data.fill_(0.15)
        elif layer_type == nn.Dropout:
            nn.Dropout(DROPOUT_RATE)

    model.apply(_fill_constants)
    return model
def inspect(
    model: nn.Module,
    input_size: Union[InputShape, List[InputShape]],
    input_dtype: Type[torch.Tensor] = torch.FloatTensor,
    input_initializer: Callable[..., torch.Tensor] = torch.rand,
    batch_size: int = 2,
) -> List[LayerInfo]:
    """Run one forward pass on generated inputs and return the hook's ``layer_list``.

    A forward hook from ``_ModuleHook(batch_size)`` is attached to every
    module for which ``should_attach_hook(model, module)`` is true. Hooks are
    always removed after the forward pass, even if it raises.

    Args:
        input_size: one input shape (tuple, without batch dim) or a list of
            shapes for a multi-input model.
        input_dtype: tensor type each generated input is converted to.
        input_initializer: called as ``input_initializer(batch_size, *shape)``.
        batch_size: leading dimension of every generated input.
    """
    hook = _ModuleHook(batch_size)
    handles: List[RemovableHandle] = []

    def register_hook(module: nn.Module) -> None:
        if not should_attach_hook(model, module):
            return
        handles.append(module.register_forward_hook(hook.hook))

    # A single shape (tuple) is wrapped so single- and multi-input models look alike.
    shapes = [input_size] if isinstance(input_size, tuple) else input_size
    inputs = [
        input_initializer(batch_size, *shape).type(input_dtype)  # type: ignore
        for shape in shapes
    ]

    model.apply(register_hook)
    try:
        model(*inputs)
    finally:
        # detach every hook so the model is back in its original state
        for handle in handles:
            handle.remove()  # type: ignore
    return hook.layer_list
def jit_trace(cls, module: Module, input_t: Union[Tensor, QuantTensor]):
    """Trace ``module`` in export mode with ``jit_trace_patched`` and return the trace.

    A forward pass with ``input_t`` caches input/output information. Tracing
    then uses an empty float tensor of the same shape (``.value.shape`` for a
    ``QuantTensor``). Export mode is disabled again and the previous training
    mode restored, also when tracing raises.
    """
    with torch.no_grad():
        training_state = module.training
        module = module.eval()
        try:
            module.apply(cls.set_export_handler)
            # do a forward pass with the dummy input to e.g. store input/output shapes
            cls.cache_inp_out(module, input_t)
            # override any given input_t to make sure it's a standard PyTorch tensor
            input_shape = input_t.shape if isinstance(input_t, Tensor) else input_t.value.shape
            input_t = torch.empty(input_shape, dtype=torch.float)
            # enable export mode, this triggers collecting export values into handlers
            module.apply(lambda m: _set_export_mode(m, enabled=True))
            try:
                traced_model = jit_trace_patched(module, input_t)
            finally:
                module.apply(lambda m: _set_export_mode(m, enabled=False))
        finally:
            module.train(training_state)
    return traced_model
def init_weights(net: nn.Module):
    """Apply ``weights_init_normal`` to ``net`` and each of its submodules, in place."""
    initializer = weights_init_normal
    net.apply(initializer)
def export_onnx(cls,
                module: Module,
                input_shape: Optional[Tuple[int, ...]] = None,
                export_path: Optional[str] = None,
                input_t: Optional[Union[Tensor, QuantTensor]] = None,
                **kwargs):
    """Export ``module`` to ONNX and return the optimized ``onnx`` model.

    * input_shape : tuple describing the shape of network input e.g. (1, 1, 28, 28)
    * export_path : ONNX filename to export to; if None, the export goes
      through an in-memory buffer and nothing is written to disk
    * input_t : if specified, do an initial forward pass with this value. this
      may be necessary for QuantTensor caching.
    * kwargs : passed to ``torch.onnx.export``

    Exactly one of ``input_shape`` / ``input_t`` must be given; otherwise
    ``RuntimeError`` is raised. Input-caching mode, export mode and training
    mode of ``module`` are restored after the export pass, also when it raises.
    """
    if onnx is None or opt is None:
        raise ModuleNotFoundError("Installation of ONNX is required.")
    if input_shape is None and input_t is None:
        raise RuntimeError(
            "Export requires to pass in either input_shape or input_t")
    if input_shape is not None and input_t is not None:
        raise RuntimeError(
            "Export accepts either an input shape or an input tensor, not both"
        )

    cls.solve_keep_initializers_as_inputs(kwargs)
    cls.solve_enable_onnx_checker(kwargs)

    with torch.no_grad():
        with ExportContext(cls), ExitStack() as restore:
            # Callbacks run in reverse order: caching mode, export mode, training mode.
            training_state = module.training
            module = module.eval()
            restore.callback(module.train, training_state)
            module.apply(cls.set_export_handler)
            if input_t is None:
                input_t = torch.empty(input_shape, dtype=torch.float)
            # do a forward pass with the dummy input to e.g. store input/output shapes
            cls._cache_inp_out(module, input_t)
            # Dequantize QuantTensor, if any
            if isinstance(input_t, QuantTensor):
                input_t = input_t.value
            # enable export mode, this triggers collecting export values into handlers
            module.apply(lambda m: cls.set_export_mode(m, enabled=True))
            restore.callback(module.apply, lambda m: cls.set_export_mode(m, enabled=False))
            # temporarily disable input caching to avoid collectives empty debug values
            module.apply(lambda m: _override_inp_caching_mode(m, enabled=False))
            restore.callback(module.apply, lambda m: _restore_inp_caching_mode(m))
            # perform export pass
            with ExitStack() as stack:
                for mgr in cls._trace_patches():
                    stack.enter_context(mgr)
                if export_path is not None:
                    torch.onnx.export(module, input_t, export_path, **kwargs)
                else:
                    model_bytes = BytesIO()
                    torch.onnx.export(module, input_t, model_bytes, **kwargs)

        # do some cleanup on the exported ONNX model
        if export_path is not None:
            model = onnx.load(export_path)
        else:
            model = onnx.ModelProto.FromString(model_bytes.getvalue())
        model = opt.optimize(model, cls.onnx_passes)
        model = cls.apply_model_transforms(model)
        if export_path is not None:
            onnx.save(model, export_path)
        return model
def our_train(model: nn.Module, labels: list, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader,
              ewcs: list, lam: float, gpu: torch.device, cut_idx, if_freeze):
    """Train ``model`` for one epoch with EWC penalties and return the mean batch loss.

    Logits of classes not in ``labels`` are set to 0 before the cross-entropy
    loss, and ``(lam / 2) * ewc.penalty(model)`` is added for every ``ewc``.
    After ``model.train()``, ``set_bn_eval`` is applied to freeze BatchNorm
    layers and their statistics.

    Parameter update:
      * ``if_freeze != 1``: ``optimizer.step()``.
      * ``if_freeze == 1``: the optimizer's ``step`` is NOT called. Within
        each param group, parameters at index ``>= cut_idx`` (the "server"
        part) are skipped. Every other parameter that has a gradient is
        updated as ``p.data = p.data - p.grad * group['lr']`` (plain SGD, no
        momentum or weight decay). The skipped/total counts are printed every
        batch.

    TODO(review): original note says a loss-based check should still decide
    whether to freeze ("true: freeze").
    """
    model.train()
    model.apply(set_bn_eval)
    epoch_loss = 0
    for data, target in data_loader:
        data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
        optimizer.zero_grad()
        output = model(data)
        rows = range(len(output))
        for class_idx in range(output.size(1)):
            if class_idx not in labels:
                output[rows, class_idx] = 0
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        for ewc in ewcs:
            loss += (lam / 2) * ewc.penalty(model)
        epoch_loss += loss.item()
        loss.backward()

        if if_freeze != 1:
            optimizer.step()
            continue

        # Manual SGD step on the client part only; the server part is frozen.
        countskip = 0
        countall = 0
        for group in optimizer.param_groups:
            lr = group['lr']
            for idx, p in enumerate(group['params']):
                countall += 1
                if idx >= cut_idx:
                    countskip += 1
                    continue
                if p.grad is None:
                    continue
                p.data = p.data - p.grad * lr
        print("countskip:", countskip, "countall:", countall)

    return epoch_loss / len(data_loader)