예제 #1
0
    def decompose_conv(self, model, method):
        """Replace ``Conv2d`` children of ``model.features`` with decomposed layers.

        The last two entries of ``model.features`` are never touched.

        Args:
            model: Module with a ``features`` submodule; the entries of
                ``model.features._modules`` are replaced in place.
            method: ``"cp"`` to use ``cp_decomposition_conv_layer`` or
                ``"tucker"`` to use ``tucker_decomposition_conv_layer``.

        Returns:
            The same ``model`` object, modified in place.

        Raises:
            ValueError: If ``method`` is neither ``"cp"`` nor ``"tucker"``.
        """
        if method not in ("cp", "tucker"):
            # Without this check an unknown method left ``decomposed`` unbound
            # and failed with UnboundLocalError at the first Conv2d.
            raise ValueError(
                "method must be 'cp' or 'tucker', got {!r}".format(method)
            )

        feature_modules = model.features._modules
        # Snapshot the keys, dropping the last two, so entries can be
        # reassigned while looping.
        for key in list(feature_modules)[:-2]:
            conv_layer = feature_modules[key]
            if not isinstance(conv_layer, torch.nn.modules.conv.Conv2d):
                continue
            if method == "cp":
                # Heuristic rank: a third of the largest weight dimension.
                # ``weight.shape`` works for CPU and CUDA tensors alike,
                # unlike ``weight.data.numpy()``.
                rank = max(conv_layer.weight.shape) // 3
                decomposed = cp_decomposition_conv_layer(conv_layer, rank)
            else:
                decomposed = tucker_decomposition_conv_layer(conv_layer)
            feature_modules[key] = decomposed

        return model
예제 #2
0
    elif args.decompose:
        model = torch.load("model").cuda()
        model.eval()
        model.cpu()
        N = len(model.features._modules.keys())
        for i, key in enumerate(model.features._modules.keys()):

            if i >= N - 2:
                break
            if isinstance(model.features._modules[key], torch.nn.modules.conv.Conv2d):
                conv_layer = model.features._modules[key]
                if args.cp:
                    rank = max(conv_layer.weight.data.numpy().shape)//3
                    decomposed = cp_decomposition_conv_layer(conv_layer, rank)
                else:
                    decomposed = tucker_decomposition_conv_layer(conv_layer)

                model.features._modules[key] = decomposed

            torch.save(model, 'decomposed_model')


    elif args.fine_tune:
        base_model = torch.load("decomposed_model")
        model = torch.nn.DataParallel(base_model)

        for param in model.parameters():
            param.requires_grad = True

        print(model)
        model.cuda()