def decompose_conv(self, model, method):
    """Replace the ``Conv2d`` children of ``model.features`` with decomposed layers.

    Only the first ``N - 2`` children of ``model.features`` are candidates
    (``N`` = number of children). The last two entries are never inspected.
    The model is modified in place and also returned.

    Args:
        self: Unused by this function.
        model: Object exposing a ``features`` module (e.g. an
            ``nn.Sequential``). Its ``_modules`` mapping is rewritten in place.
        method: ``"cp"`` replaces each conv via ``cp_decomposition_conv_layer``.
            It uses a rank of one third (floor) of the largest dimension of
            the layer's weight tensor. ``"tucker"`` replaces each conv via
            ``tucker_decomposition_conv_layer``.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        ValueError: If ``method`` is neither ``"cp"`` nor ``"tucker"``.
    """
    # Validate up front. Previously an unknown method left ``decomposed``
    # unbound and crashed with UnboundLocalError at the first conv layer.
    if method not in ("cp", "tucker"):
        raise ValueError(
            f"Unknown decomposition method {method!r}; expected 'cp' or 'tucker'"
        )

    feature_modules = model.features._modules
    # Snapshot the keys so that writing replacements back into the mapping
    # cannot interfere with the iteration.
    keys = list(feature_modules.keys())
    # Only the first N - 2 children are eligible; the final two are skipped.
    for key in keys[: max(len(keys) - 2, 0)]:
        layer = feature_modules[key]
        if not isinstance(layer, torch.nn.Conv2d):
            continue
        if method == "cp":
            # Read the shape straight from the tensor. ``weight.data.numpy()``
            # fails for layers on a CUDA device and needlessly copies data.
            rank = max(layer.weight.shape) // 3
            decomposed = cp_decomposition_conv_layer(layer, rank)
        else:
            decomposed = tucker_decomposition_conv_layer(layer)
        feature_modules[key] = decomposed
    return model
elif args.decompose: model = torch.load("model").cuda() model.eval() model.cpu() N = len(model.features._modules.keys()) for i, key in enumerate(model.features._modules.keys()): if i >= N - 2: break if isinstance(model.features._modules[key], torch.nn.modules.conv.Conv2d): conv_layer = model.features._modules[key] if args.cp: rank = max(conv_layer.weight.data.numpy().shape)//3 decomposed = cp_decomposition_conv_layer(conv_layer, rank) else: decomposed = tucker_decomposition_conv_layer(conv_layer) model.features._modules[key] = decomposed torch.save(model, 'decomposed_model') elif args.fine_tune: base_model = torch.load("decomposed_model") model = torch.nn.DataParallel(base_model) for param in model.parameters(): param.requires_grad = True print(model) model.cuda()