Example #1
    def _add_hook_and_collect(self, model: nn.Module, inputs, custom_ops=None, verbose=True):
        handler_collection = {}
        types_collection = set()
        if custom_ops is None:
            custom_ops = {}

        def add_hooks(m: nn.Module):

            m_type = type(m)

            fn = None
            if m_type in custom_ops:  # if a rule exists in both maps, custom_ops takes precedence.
                fn = custom_ops[m_type]
                if m_type not in types_collection and verbose:
                    print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
            elif m_type in register_hooks:
                fn = register_hooks[m_type]
                if m_type not in types_collection and verbose:
                    print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
            else:
                if m_type not in types_collection and verbose:
                    prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." % m_type)

            
            def count_parameters(m, x, y):
                # forward hook: record the module's parameter count in its buffer
                total_params = 0
                for p in m.parameters():
                    total_params += torch.DoubleTensor([p.numel()])
                m.total_params[0] = total_params

            if fn is not None: 
                m.register_buffer('total_ops', torch.zeros(1, dtype=torch.float64))
                m.register_buffer('total_params', torch.zeros(1, dtype=torch.float64))
                
                handler_collection[m] = (m.register_forward_hook(fn), m.register_forward_hook(count_parameters))
            types_collection.add(m_type)

        prev_training_status = model.training

        model.eval()
        model.apply(add_hooks)

        with torch.no_grad():
            model(*inputs)

        # collect FLOPs and params from the hooked modules
        for i, m in enumerate(model.modules()):
            if i in self.all_idx:
                self.params_dict[i] = m.total_params.item()
                self.flops_dict[i] = m.total_ops.item()
                self.params_list.append(m.total_params.item())
                self.flops_list.append(m.total_ops.item())

        model.train(prev_training_status)
        for m, (op_handler, params_handler) in handler_collection.items():
            op_handler.remove()
            params_handler.remove()
            m._buffers.pop("total_ops")
            m._buffers.pop("total_params")
Example #2
def set_active_group(module: nn.Module, group):
    """Scan all submodules, passing a distributed group to all those that implement `set_group`"""
    def _set_group(m):
        if hasattr(m, "set_group"):
            m.set_group(group)

    module.apply(_set_group)
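
A minimal usage sketch; GroupAwareLinear is a hypothetical submodule that opts into the set_group protocol (the snippet only requires that the method exists):

import torch.nn as nn

class GroupAwareLinear(nn.Linear):
    def set_group(self, group):
        # remember the distributed process group for later collectives
        self.group = group

model = nn.Sequential(nn.Linear(8, 8), GroupAwareLinear(8, 4))
set_active_group(model, group=None)  # module.apply reaches the GroupAwareLinear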
Example #3
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [
                *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                "in_proj_bias",
                "bias_k",
                "bias_v",
            ]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)
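
A quick usage check (a sketch; the conversion is in place, so dtypes can be inspected right after the call):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Flatten(), nn.Linear(8, 2))
convert_weights(model)
assert model[0].weight.dtype == torch.float16
assert model[2].weight.dtype == torch.float16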
Example #4
def convert_weights(model: nn.Module):
    '''Convert applicable model parameters to fp16'''

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [
                *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
                'in_proj_bias',
                'bias_k',
                'bias_v',
            ]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ['text_projection', 'proj']:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)
Example #5
def init_weights(net: nn.Module, init_type: str = 'xavier_uniform', init_gain: float = 1.) -> None:

    def init_func(m: nn.Module):
        name = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in name or 'Linear' in name):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier_normal':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'xavier_uniform':
                init.xavier_uniform_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    net.apply(init_func)
    logger.info(f'Weights have been initialized with (type={init_type}, gain={init_gain}).')
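
Typical usage (a sketch; init_weights above logs through a module-level logger, so one is configured here):

import logging
import torch.nn as nn

logger = logging.getLogger(__name__)

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Conv2d(16, 10, 1))
init_weights(net, init_type='kaiming')  # Conv weights: kaiming_normal_; BatchNorm2d: N(1.0, init_gain)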
Example #6
def prepare_model_for_grad_receptive_field(model: nn.Module) -> nn.Module:
    # See:
    # https://github.com/rogertrullo/Receptive-Field-in-Pytorch/blob/master/Receptive_Field.ipynb

    # make copy
    model = deepcopy(model)
    model.apply(modify_for_grad_receptive_field)
    return model
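
modify_for_grad_receptive_field is not defined in this snippet; one plausible implementation, following the trick in the linked notebook (an assumption), makes the forward pass uniform so the nonzero region of input.grad traces the receptive field:

import torch.nn as nn

def modify_for_grad_receptive_field(m: nn.Module):
    # assumption: make the forward pass "transparent" so gradients spread uniformly
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.constant_(m.weight, 0.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
        m.eval()  # freeze running statistics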
Example #7
File: utils.py  Project: wiwi/ssl-suite
@contextmanager  # requires: from contextlib import contextmanager
def disable_bn_stats(model: nn.Module):
    def f(m: nn.Module):
        if hasattr(m, 'track_running_stats'):
            m.track_running_stats ^= True

    model.apply(f)
    yield
    model.apply(f)
Example #8
@contextmanager  # requires: from contextlib import contextmanager
def disable_batchnorm_tracking(model: Module):
    """adapted from: https://github.com/lyakaap/VAT-pytorch/blob/master/vat.py"""
    def switch_attr(m):
        if hasattr(m, 'track_running_stats'):
            m.track_running_stats ^= True

    model.apply(switch_attr)
    yield
    model.apply(switch_attr)
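
With the @contextmanager decorator restored, both helpers are used the same way; the second apply() flips the flags back on exit:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
with disable_batchnorm_tracking(model):
    _ = model(torch.randn(8, 4))  # BN running stats are not updated here
# tracking is re-enabled once the block exits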
Example #9
def initialize_weights(
    model: nn.Module,
    initialization_function: Callable,
    pre_initialization: Optional[Callable] = None,
    is_admin=False,
    **kwargs,
) -> None:
    pre_init_outputs = {}
    if pre_initialization is not None:
        pre_init_outputs = pre_initialization(model)  # must be dict
    initialization_function_with_args = partial(initialization_function,
                                                **kwargs, **pre_init_outputs)
    model.apply(initialization_function_with_args)
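
A sketch of compatible callables (hypothetical names): model.apply passes each submodule as the first positional argument, so initialization_function must accept a module followed by the keywords merged from **kwargs and the pre-initialization outputs.

import math
import torch.nn as nn

def count_linears(model: nn.Module) -> dict:
    # hypothetical pre_initialization: its dict output is forwarded as kwargs
    return {"num_linears": sum(isinstance(m, nn.Linear) for m in model.modules())}

def scaled_init(m: nn.Module, num_linears: int = 1, gain: float = 1.0):
    # hypothetical initialization_function: module first, then keyword args
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=gain / math.sqrt(num_linears))

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
initialize_weights(model, scaled_init, pre_initialization=count_linears, gain=2.0)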
Example #10
    def __call__(self, module: nn.Module) -> None:
        def init(m):
            if self.wholemodule:
                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
                                  self.bias)
            else:
                layername = m.__class__.__name__
                for layer_ in self.layer:
                    if layername == layer_:
                        trunc_normal_init(m, self.mean, self.std, self.a,
                                          self.b, self.bias)

        module.apply(init)
Example #11
    def __call__(self, module: nn.Module) -> None:
        def init(m):
            if self.wholemodule:
                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
                                  self.bias)
            else:
                layername = m.__class__.__name__
                basesname = _get_bases_name(m)
                if len(set(self.layer) & set([layername] + basesname)):
                    trunc_normal_init(m, self.mean, self.std, self.a, self.b,
                                      self.bias)

        module.apply(init)
Example #12
def prune_model(
    model: Module,
    pruning_fn: Callable,
    keys_to_prune: List[str],
    amount: Union[float, int],
    layers_to_prune: Optional[List[str]] = None,
    reinitialize_after_pruning: Optional[bool] = False,
) -> None:
    """
    Prune model function can be used for pruning certain
    tensors in model layers.

    Raises:
        AttributeError: If layers_to_prune is not None, but there is
                no layers with specified name.
        Exception: If no layers have specified keys.

    Args:
        model: Model to be pruned.
        pruning_fn: Pruning function with API same as in
            torch.nn.utils.pruning.
            pruning_fn(module, name, amount).
        keys_to_prune: list of strings. Determines
            which tensor in modules will be pruned.
        amount: quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and
            represent the fraction of parameters to prune.
            If int, it represents the absolute number
            of parameters to prune.
        layers_to_prune: list of strings - module names to be pruned.
            If None provided then will try to prune every module in
            model.
        reinitialize_after_pruning: if True then will reinitialize model
                after pruning. (Lottery Ticket Hypothesis check e.g.)
    """
    pruned_modules = 0
    for name, module in model.named_modules():
        try:
            if layers_to_prune is None or name in layers_to_prune:
                for key in keys_to_prune:
                    pruning_fn(module, name=key, amount=amount)
                pruned_modules += 1
        except AttributeError as e:
            if layers_to_prune is not None:
                raise e

    if pruned_modules == 0:
        raise Exception(f"There is no {keys_to_prune} key in your model")
    if reinitialize_after_pruning:
        model.apply(reset_weights_if_possible)
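
Typical usage with the built-in magnitude pruning from torch.nn.utils.prune (a sketch; the model and its module names are hypothetical):

import torch.nn as nn
from torch.nn.utils import prune

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.fc = nn.Linear(8, 2)

model = Net()
prune_model(
    model,
    pruning_fn=prune.l1_unstructured,  # matches pruning_fn(module, name, amount)
    keys_to_prune=["weight"],
    amount=0.5,                        # prune 50% of each matched tensor
    layers_to_prune=["conv1", "fc"],
)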
Example #13
    def export_onnx(
            cls,
            module: Module,
            input_shape: Tuple[int, ...],
            export_path: str,
            input_t: Optional[Union[Tensor, QuantTensor]] = None,
            **kwargs):
        """
        * input_shape : tuple describing the shape of network input e.g. (1, 1, 28, 28)
        * export_path : ONNX filename to export to
        * input_t : if specified, do an initial forward pass with this value. this
                    may be necessary for QuantTensor caching.
        * torch_onnx_kwargs : will be passed as kwargs to torch.onnx.export
        """

        def set_export_handler(m: Module):
            if hasattr(m, 'export_handler') and m.export_handler is None:
                handler = cls.handler_from_module(m)
                m.export_handler = handler()

        if onnx is None or opt is None:
            raise ModuleNotFoundError("Installation of ONNX is required.")

        cls.solve_keep_initializers_as_inputs(kwargs)
        cls.solve_enable_onnx_checker(kwargs)

        with torch.no_grad():
            module = module.eval()
            module.apply(set_export_handler)
            if input_t is None:
                input_t = torch.empty(input_shape, dtype=torch.float)
            # do a forward pass with the dummy input to e.g. store input/output shapes
            cls.cache_inp_out(module, input_t)
            # override any given input_t to make sure it's a standard PyTorch tensor
            input_t = torch.empty(input_shape, dtype=torch.float)
            # enable export mode, this triggers collecting export values into handlers
            module.apply(lambda m: _set_export_mode(m, enabled=True))
            # temporarily disable input caching to avoid collecting empty debug values
            module.apply(lambda m: _override_inp_caching_mode(m, enabled=False))
            # perform export pass
            torch.onnx.export(module, input_t, export_path, **kwargs)
            # restore the model to previous properties
            module.apply(lambda m: _restore_inp_caching_mode(m))
            module.apply(lambda m: _set_export_mode(m, enabled=False))
            # do some cleanup on the exported ONNX model
            model = onnx.load(export_path)
            model = opt.optimize(model, cls.onnx_passes)
            model = cls.apply_model_transforms(model)
            onnx.save(model, export_path)
Example #14
    def __call__(self, module: nn.Module) -> None:
        def init(m):
            if self.wholemodule:
                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
                                  self.bias)
            else:
                layername = m.__class__.__name__
                basesname = _get_bases_name(m)
                if len(set(self.layer) & set([layername] + basesname)):
                    trunc_normal_init(m, self.mean, self.std, self.a, self.b,
                                      self.bias)

        module.apply(init)
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())
Example #15
 def _apply_weight_init(self, init_method: Union[str, FunctionType],
                        proto_model: Module):
     init_name = "No"
     if init_method:
         if init_method == 'kaiming':
             self.init_fc = getattr(init, "kaiming_normal_")
             init_name = init_method
         elif init_method == "xavier":
             self.init_fc = getattr(init, "xavier_normal_")
             init_name = init_method
         else:
             # init_method may be a callable or the name of a torch.nn.init function
             if callable(init_method):
                 self.init_fc = init_method
                 init_name = init_method.__name__
             else:
                 self.init_fc = getattr(init, init_method)
                 init_name = init_method
         proto_model.apply(self._weight_init)
     return init_name
Example #16
def our_train(model: nn.Module, labels: list, optimizer: torch.optim.Optimizer,
              data_loader: torch.utils.data.DataLoader, ewcs: list, lam: float,
              gpu: torch.device, cut_idx, if_freeze):

    # TODO: the freeze decision should also depend on a loss check; if true: freeze
    # --------------------- freeze
    if if_freeze == 1:
        for idx, param in enumerate(model.parameters()):
            if idx >= cut_idx:
                continue
            param.requires_grad = False

        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad,
                   model.parameters()))  # no need to add: lr=0.1?
    #----------------------
    model.train()
    model.apply(set_bn_eval)  # freeze BN layers and their running statistics
    epoch_loss = 0
    for data, target in data_loader:
        data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
        optimizer.zero_grad()
        output = model(data)
        for idx in range(output.size(1)):
            if idx not in labels:
                output[range(len(output)), idx] = 0
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        # print('loss:', loss.item())
        for ewc in ewcs:
            loss += (lam / 2) * ewc.penalty(model)
            # print('ewc loss:', loss.item())
        epoch_loss += loss.item()

        loss.backward()
        optimizer.step()
    # ----------------------------- unfreeze
    if if_freeze == 1:
        for idx, param in enumerate(model.parameters()):
            if idx >= cut_idx:
                continue
            param.requires_grad = True

        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad,
                   model.parameters()))  # no need to add: lr=0.1?
    #-------------------------------
    return epoch_loss / len(data_loader)
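
set_bn_eval is not defined in this snippet; a common implementation, consistent with the "freeze BN" comment (an assumption), is:

import torch.nn as nn

def set_bn_eval(m: nn.Module):
    # put every batch-norm layer into eval mode so its running stats are frozen
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.eval()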
Example #17
    def to_dependentmodule(cls, module: Module, recurse=True):
        r"""
        Transform a module and all its submodules into dependent modules.

        Args:
            module: module to be transformed.
            recurse: if True, all submodules are transformed into dependent modules recursively.

        Returns:
            DependentModule: a dependent module
        """
        if not recurse:
            module = cls._make_subclass(module)
        else:
            module.apply(lambda x: cls.to_dependentmodule(x, recurse=False))
        return module
Example #18
def normal_train(model: nn.Module, labels: list, optimizer: torch.optim.Optimizer, data_loader: torch.utils.data.DataLoader, gpu: torch.device):
    model.train()
    model.apply(set_bn_eval)  # freeze BN layers and their running statistics
    epoch_loss = 0
    for data, target in data_loader:
        data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
        optimizer.zero_grad()
        output = model(data)
        for idx in range(output.size(1)):
            if idx not in labels:
                output[range(len(output)), idx] = 0
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)
Example #19
def init_weights(net: nn.Module,
                 init_type: str = "normal",
                 init_gain: float = 0.02) -> None:
    """Initialize network weights.

    We use 'normal' in the original pix2pix and CycleGAN paper.
    But xavier and kaiming might work better for some applications.
    Feel free to try yourself.

    Args:
        net: Network to be initialized
        init_type: Name of an initialization method:
            normal | xavier | kaiming | orthogonal
        init_gain: Scaling factor for normal, xavier and orthogonal.
    """
    def init_func(m: nn.Module) -> None:  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, "weight") and (classname.find("Conv") != -1
                                     or classname.find("Linear") != -1):
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" %
                    init_type)
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif (
                classname.find("BatchNorm2d") != -1
        ):  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print("initialize network with %s" % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
Example #20
def bn_update(loader: DataLoader, model: nn.Module):
    """
        BatchNorm buffers update (if any).
        Performs 1 epochs to estimate buffers average using train dataset.
        :param loader: train dataset loader for buffers average estimation.
        :param model: model being update
        :return: None
    """
    if not check_bn(model):
        return

    assert loader.drop_last

    model.train()
    model.apply(reset_bn)

    for batch in tqdm(loader, desc="AdaBN"):
        batch = any2device(batch, device="cuda")
        model(**batch)

    model.apply(fix_bn)
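
check_bn, reset_bn and fix_bn are not shown here; one plausible set of helpers in the spirit of the SWA batch-norm update (an assumption) is:

import torch.nn as nn

def check_bn(model: nn.Module) -> bool:
    return any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in model.modules())

def reset_bn(m: nn.Module):
    # restart the running-average estimation from scratch
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.reset_running_stats()
        m.momentum = None  # None => cumulative moving average over the pass

def fix_bn(m: nn.Module):
    # restore the default exponential moving average afterwards
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.momentum = 0.1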
Example #21
 def jit_inference_trace(cls, module: Module, input_t: Union[Tensor, QuantTensor]):
     with torch.no_grad():
         training_state = module.training
         module = module.eval()
         module.apply(cls.set_export_handler)
         # do a forward pass with the input to e.g. store input/output shapes
         cls._cache_inp_out(module, input_t)
         # unpack quant tensor
         if isinstance(input_t, QuantTensor):
             input_t = input_t.value
         # enable export mode, this triggers collecting export values into handlers
         module.apply(lambda m: cls.set_export_mode(m, enabled=True))
         # force requires_grad to False to let the wrapped model lambda go through tracing
         requires_grad_backup_dict = _force_requires_grad_false(module)
         with ExitStack() as stack:
             for mgr in cls._trace_patches():
                 stack.enter_context(mgr)
             # wrapping with a lambda forces inlining during tracing,
             # converts everything to const and removes unused params/buffers
             traced_model = torch.jit.trace(_JitTraceExportWrapper(module), input_t)
         # Hack to clone the function, otherwise restoring requires_grad
         # on module will break traced_model
         with BytesIO() as tmp:
             torch.jit.save(traced_model, tmp)
             tmp.seek(0)
             traced_model = torch.jit.load(tmp)
         _restore_requires_grad(module, requires_grad_backup_dict)
         module.apply(lambda m: cls.set_export_mode(m, enabled=False))
         module.train(training_state)
         return traced_model
Example #22
def model_train(n: nn.Module, dl: utils.data.DataLoader, *, device=DEVICE):
    n.to(device)
    n.train()
    n.apply(init_weights)

    save_dir = f"./build/{datetime.now().isoformat(timespec='seconds')}/"

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(n.parameters(), lr=0.001, momentum=0.9)

    epochs = 150
    for epoch in range(1, epochs + 1):

        running_loss = 0.0
        for i, (x, y) in enumerate(dl, 1):
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            y_preds = n(x)
            loss = criterion(y_preds, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 200 == 0:
                print("[{:2}, {:5}, {:3.0f}%] loss: {:5.2f}".format(
                    epoch,
                    i,
                    100.0 * (i / len(dl) + epoch - 1) / epochs,
                    running_loss,
                ))
                running_loss = 0.0

        if epoch % 5 == 0:
            os.makedirs(save_dir, exist_ok=True)  # os.mkdir would fail while "./build" is missing

            save_file = os.path.join(save_dir, f"epoch-{epoch:03d}.pt")
            t.save(n.state_dict(), save_file)
Example #23
def ewc_train(model: nn.Module, labels: list, optimizer: torch.optim.Optimizer, data_loader: torch.utils.data.DataLoader, ewcs: list, lam: float, gpu: torch.device):
    model.train()
    model.apply(set_bn_eval)  # freeze BN layers and their running statistics
    epoch_loss = 0
    for data, target in data_loader:
        data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
        optimizer.zero_grad()
        output = model(data)
        # for idx in range(output.size(1)):
        #     if idx not in labels:
        #         output[range(len(output)), idx] = 0
        # criterion = nn.CrossEntropyLoss()
        # loss = criterion(output, target) 
        loss = myloss(output, target, labels)        
        # print('loss:', loss.item())
        for ewc in ewcs:
            loss += (lam / 2) * ewc.penalty(model)
            # print('ewc loss:', loss.item())
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)
Example #24
def profile(model: nn.Module, input_size: tuple):
    handler_collection = []

    def add_hooks(m):
        if len(list(m.children())) > 0:
            return

        m.register_buffer('total_ops', torch.zeros(1))
        m.register_buffer('total_params', torch.zeros(1))

        for p in m.parameters():
            m.total_params += torch.Tensor([p.numel()])

        if type(m) in register_hooks:
            fn = register_hooks[type(m)]
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)

    training = model.training  # remember the original mode
    model.eval().to('cpu')
    model.apply(add_hooks)
    with torch.no_grad():
        model(torch.zeros(input_size).to('cpu'))

    model_ops = 0
    model_params = 0
    for m in model.modules():
        if len(list(m.children())) > 0:  # skip for non-leaf module
            continue
        model_ops += m.total_ops
        model_params += m.total_params

    model_ops = model_ops.item()
    model_params = model_params.item()

    model.train(training).to('cpu')  # restore the original mode
    for handler in handler_collection:
        handler.remove()

    return int(model_ops), int(model_params)
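
Usage is a one-liner; input_size includes the batch dimension (only modules with an entry in register_hooks contribute ops):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
macs, params = profile(model, input_size=(1, 3, 224, 224))
print(f"{macs / 1e6:.1f} MMacs, {params} params")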
Example #25
File: test_cnn.py  Project: zkcr0000/cs224n
def initialize_layers(model: nn.Module):
    """
    Reference: https://pytorch.org/docs/stable/nn.html#torch.nn.Module.apply

    @param model: model whose layer weights are initialized for test cases.
    """
    def init_weights(m):
        if type(m) == nn.Linear:
            m.weight.data.fill_(0.5)
            if m.bias is not None:
                m.bias.data.fill_(0.5)
        elif type(m) == nn.Conv1d:
            m.weight.data.fill_(0.5)
            if m.bias is not None:
                m.bias.data.fill_(0.5)
        elif type(m) == nn.Embedding:
            m.weight.data.fill_(0.15)
        elif type(m) == nn.Dropout:
            m.p = DROPOUT_RATE  # set the dropout probability in place

    model.apply(init_weights)
    return model
Example #26
def inspect(
    model: nn.Module,
    input_size: Union[InputShape, List[InputShape]],
    input_dtype: Type[torch.Tensor] = torch.FloatTensor,
    input_initializer: Callable[..., torch.Tensor] = torch.rand,
    batch_size: int = 2,
) -> List[LayerInfo]:
    hook = _ModuleHook(batch_size)
    handles: List[RemovableHandle] = []

    def register_hook(module: nn.Module) -> None:
        if should_attach_hook(model, module):
            h: RemovableHandle = module.register_forward_hook(hook.hook)
            handles.append(h)

    # normalize a single input shape into a list of shapes
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # make fake inputs (batch_size defaults to 2 so batchnorm works)
    x = [
        input_initializer(batch_size, *in_size).type(input_dtype)  # type: ignore
        for in_size in input_size
    ]
    # attach hooks to each applicable layer
    model.apply(register_hook)

    # forward pass
    try:
        model(*x)
    finally:
        # cleanup all attached hooks, to move model to original state
        for h in handles:
            h.remove()  # type: ignore

    return hook.layer_list
Example #27
 def jit_trace(cls, module: Module, input_t: Union[Tensor, QuantTensor]):
     with torch.no_grad():
         module = module.eval()
         module.apply(cls.set_export_handler)
         # do a forward pass with the dummy input to e.g. store input/output shapes
         cls.cache_inp_out(module, input_t)
         # override any given input_t to make sure it's a standard PyTorch tensor
         input_shape = input_t.shape if isinstance(input_t, Tensor) else input_t.value.shape
         input_t = torch.empty(input_shape, dtype=torch.float)
         # enable export mode, this triggers collecting export values into handlers
         module.apply(lambda m: _set_export_mode(m, enabled=True))
         traced_model = jit_trace_patched(module, input_t)
         module.apply(lambda m: _set_export_mode(m, enabled=False))
         return traced_model
Example #28
def init_weights(net: nn.Module):
    net.apply(weights_init_normal)
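
weights_init_normal is external to this snippet; a common definition in the style of the pix2pix/CycleGAN initializers (an assumption) is:

import torch.nn as nn

def weights_init_normal(m: nn.Module):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1 and hasattr(m, "weight"):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)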
Example #29
    def export_onnx(cls,
                    module: Module,
                    input_shape: Optional[Tuple[int, ...]] = None,
                    export_path: Optional[str] = None,
                    input_t: Optional[Union[Tensor, QuantTensor]] = None,
                    **kwargs):
        """
        * input_shape : tuple describing the shape of network input e.g. (1, 1, 28, 28)
        * export_path : ONNX filename to export to
        * input_t : if specified, do an initial forward pass with this value. this
                    may be necessary for QuantTensor caching.
        * torch_onnx_kwargs : will be passed as kwargs to torch.onnx.export
        """

        if onnx is None or opt is None:
            raise ModuleNotFoundError("Installation of ONNX is required.")
        if input_shape is None and input_t is None:
            raise RuntimeError(
                "Export requires to pass in either input_shape or input_t")
        if input_shape is not None and input_t is not None:
            raise RuntimeError(
                "Export accepts either an input shape or an input tensor, not both"
            )

        cls.solve_keep_initializers_as_inputs(kwargs)
        cls.solve_enable_onnx_checker(kwargs)

        with torch.no_grad():
            with ExportContext(cls):
                training_state = module.training
                module = module.eval()
                module.apply(cls.set_export_handler)
                if input_t is None:
                    input_t = torch.empty(input_shape, dtype=torch.float)
                # do a forward pass with the dummy input to e.g. store input/output shapes
                cls._cache_inp_out(module, input_t)
                # Dequantize QuantTensor, if any
                if isinstance(input_t, QuantTensor):
                    input_t = input_t.value
                # enable export mode, this triggers collecting export values into handlers
                module.apply(lambda m: cls.set_export_mode(m, enabled=True))
                # temporarily disable input caching to avoid collecting empty debug values
                module.apply(
                    lambda m: _override_inp_caching_mode(m, enabled=False))
                # perform export pass
                with ExitStack() as stack:
                    for mgr in cls._trace_patches():
                        stack.enter_context(mgr)
                    if export_path is not None:
                        torch.onnx.export(module, input_t, export_path,
                                          **kwargs)
                    else:
                        model_bytes = BytesIO()
                        torch.onnx.export(module, input_t, model_bytes,
                                          **kwargs)
                # restore the model to previous properties
                module.apply(lambda m: _restore_inp_caching_mode(m))
                module.apply(lambda m: cls.set_export_mode(m, enabled=False))
                module.train(training_state)
            # do some cleanup on the exported ONNX model
            if export_path is not None:
                model = onnx.load(export_path)
            else:
                model = onnx.ModelProto.FromString(model_bytes.getvalue())
            model = opt.optimize(model, cls.onnx_passes)
            model = cls.apply_model_transforms(model)
            if export_path is not None:
                onnx.save(model, export_path)
            return model
Example #30
def our_train(model: nn.Module, labels: list, optimizer: torch.optim.Optimizer, data_loader: torch.utils.data.DataLoader, ewcs: list, lam: float, gpu: torch.device, cut_idx, if_freeze):

    # TODO: the freeze decision should also depend on a loss check; if true: freeze
    # --------------------- freeze
    # if if_freeze == 1 :
    #     for idx, param in enumerate(model.parameters()):
    #         if idx >= cut_idx:
    #             continue
    #         param.requires_grad = False

    #     optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args_lr) # no need to add: lr=0.1?
    #----------------------
    model.train()
    model.apply(set_bn_eval)  # freeze BN layers and their running statistics
    epoch_loss = 0
    for data, target in data_loader:
        data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
        optimizer.zero_grad()
        output = model(data)
        for idx in range(output.size(1)):
            if idx not in labels:
                output[range(len(output)), idx] = 0
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target) 
        # print('loss:', loss.item())
        for ewc in ewcs:
            loss += (lam / 2) * ewc.penalty(model)
            # print('ewc loss:', loss.item())
        epoch_loss += loss.item()

        loss.backward()
        countskip = 0
        countall = 0
        # ------ depending on if_freeze, decide whether to freeze the server layers ------
        if if_freeze == 1:
            # ---------------- hand-written optimizer step ----------------
            for group in optimizer.param_groups:
                for idx, p in enumerate(group['params']):
                    countall += 1
                    if idx >= cut_idx:  # freeze the server part, i.e. skip params cut_idx ~ end
                        countskip += 1
                        #print('skip_server_layer')
                        continue                    
                    if p.grad is None:
                        continue
                    d_p = p.grad
                    #p.add_(d_p, alpha=-group['lr'])
                    p.data = p.data - d_p*group['lr']
            print("countskip:",countskip,"countall:",countall)
        else:
            optimizer.step()
        #----------------------------------------------------
        # optimizer.step()  # optimizer.param_groups : 'params' : .grad ==> gradients
                            # (see line 91)
                            # for n, p in model.named_parameters():
                            #     p.grad.data ==> gradient data of the current layer?
    # ----------------------------- unfreeze
    # if if_freeze == 1 :    
    #     for idx, param in enumerate(model.parameters()):
    #         if idx >= cut_idx:
    #             continue
    #         param.requires_grad = True

    #     optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args_lr) # no need to add: lr=0.1?
    #-------------------------------
    return epoch_loss / len(data_loader)