# Beispiel #1 (Example #1)
# 0
def prune_least_important_features(prune_count):
    """Reload the trained model and prune the ``prune_count`` features with
    the lowest |normalized cross-correlation| from every extendable layer.

    Relies on module-level globals: ``file`` (checkpoint path), ``trainer``
    (training/testing harness) and ``orig_test_loss`` (baseline loss).

    Args:
        prune_count: number of features to remove per extendable layer.
    """
    # Start from a fresh copy of the trained weights so earlier pruning
    # experiments do not accumulate.
    model = FC_Net(layer_sizes=[784, 32, 32, 32, 32, 10])
    state_dict = torch.load(file)
    model.load_state_dict(state_dict)
    trainer.model = model

    print("\n prune the ", prune_count,
          " least important features for all layers at once\n")
    for name, module in model.named_modules():
        # Only extendable conv layers expose an importance measure.
        if not isinstance(module, Conv1dExtendable):
            continue
        ncc = module.normalized_cross_correlation()

        for _ in range(prune_count):
            # Re-evaluate after each prune: removing a feature changes the
            # correlations of the remaining ones (prune_feature returns the
            # updated ncc vector).
            (min_val, min_idx) = torch.min(torch.abs(ncc), 0)
            print("in ", name, ", prune feature number ", min_idx[0],
                  " with ncc ", min_val[0])
            ncc = module.prune_feature(feature_number=min_idx[0])

    test_loss, correct = trainer.test()
    print("change of test loss: ", (test_loss / orig_test_loss))
    print("")
# Beispiel #2 (Example #2)
# 0
def pruning_sweep_most_important(prune_count):
    """Prune the single most important feature of each extendable layer,
    repeated ``prune_count`` times, without reloading the model in between.

    Relies on module-level globals: ``file`` (checkpoint path), ``trainer``
    (training/testing harness) and ``orig_test_loss`` (baseline loss).

    Args:
        prune_count: number of pruning sweeps to perform.
    """
    model = FC_Net(layer_sizes=[784, 32, 32, 32, 32, 10])
    state_dict = torch.load(file)
    model.load_state_dict(state_dict)
    trainer.model = model

    print(
        "Prune the most important feature of each layer a given number of times without reloading the model"
    )
    # Renamed from `i`: the original inner loop reused `i`, shadowing the
    # sweep counter.
    for sweep in range(prune_count):
        print("reduce most important feature (", sweep, ")")
        for name, module in model.named_modules():
            if not isinstance(module, Conv1dExtendable):
                continue
            ncc = module.normalized_cross_correlation()

            # torch.max replaces the original full sort that was used only
            # to take the first element; the originally-computed max result
            # was never used.
            (max_val, max_idx) = torch.max(torch.abs(ncc), 0)
            print("in ", name, ", prune feature number ", max_idx,
                  " with ncc ", max_val)
            module.prune_feature(feature_number=max_idx)

        # trainer.test() returns (loss, correct) at its other call sites in
        # this file — unpack consistently so the division below gets a number.
        test_loss, correct = trainer.test()
        print("change of test loss: ", (test_loss / orig_test_loss))
        print("")
def sort_importance_by_testing(module_name):
    """Empirically rank one module's features: prune each feature in
    isolation, measure test loss, restore the model, and return the feature
    indices sorted by resulting loss (ascending).

    Relies on module-level globals: ``file`` (checkpoint path) and
    ``trainer`` (training/testing harness).

    Args:
        module_name: attribute name of the module inside ``trainer.model.seq``.

    Returns:
        Tensor of feature indices sorted by importance (lowest loss first).
    """
    module = getattr(trainer.model.seq, module_name)
    test_results = []
    for feature_number in range(module.out_channels):
        # Re-fetch from the *current* model — it is replaced by a fresh
        # reload at the end of every iteration.
        module = getattr(trainer.model.seq, module_name)
        module.prune_feature(feature_number)
        # trainer.test() returns (loss, correct) at its other call sites in
        # this file; keep only the loss so the FloatTensor below is 1-D.
        test_loss, correct = trainer.test()
        test_results.append(test_loss)

        # Restore the unpruned model before probing the next feature.
        model = FC_Net(layer_sizes=[784, 32, 32, 32, 32, 10])
        state_dict = torch.load(file)
        model.load_state_dict(state_dict)
        trainer.model = model

    tr = torch.FloatTensor(test_results)
    (sorted_val, sorted_idx) = torch.sort(torch.abs(tr), dim=0, descending=False)
    print("indices of the features sorted by importance:")
    print(sorted_idx)
    return sorted_idx
import torch
import os
from models import FC_Net
from training import OptimizerMNIST
from expanding_modules import Conv1dExtendable
from logger import Logger

# Train a small expandable 784-4-4-10 network with a fixed expansion rate
# and save the resulting weights under ./experiments/FixedExpansionRate/.
model = FC_Net(layer_sizes=[784, 4, 4, 10])
trainer = OptimizerMNIST(model,
                         epochs=20,
                         expand_interval=400,
                         log_interval=400,
                         expand_threshold=0.05,
                         prune_threshold=0.0,
                         expand_rate=8,
                         lr=0.0001,
                         weight_decay=0)

# Experiment name encodes the key hyper-parameters of this run.
name = str(model.layer_count) + "_layers_" + "_extend_" + str(
    trainer.expand_rate) + "_interval_" + str(
        trainer.expand_interval) + "_lr_" + str(trainer.lr) + "_min"
folder = "./experiments/FixedExpansionRate/" + name

# Optional logging — enable by uncommenting:
#logger=Logger(folder)
#trainer.logger = logger

trainer.train()
# torch.save does not create intermediate directories; without this the
# save below raises FileNotFoundError on a fresh checkout.
os.makedirs(folder, exist_ok=True)
torch.save(model.state_dict(),
           folder + "/model_trained_for_" + str(trainer.epochs) + "_epochs")
# Beispiel #5 (Example #5)
# 0
# Load (or train and cache) the baseline 784-4x32-10 model, then record the
# baseline test loss used as the reference point by the pruning experiments.
import torch
import os
from models import FC_Net
from training import MNIST_Optimizer
from expanding_modules import Conv1dExtendable

model = FC_Net(layer_sizes=[784, 32, 32, 32, 32, 10])
trainer = MNIST_Optimizer(model, epochs=10)

# Checkpoint path; reused by the pruning helpers that reload the model.
file = "trained_model_784_4x32_10"

# Reuse the cached checkpoint when present; otherwise train from scratch
# and save it for subsequent runs.
if os.path.exists(file):
    print("load model")
    state_dict = torch.load(file)
    model.load_state_dict(state_dict)
else:
    print("save model")
    trainer.train()
    torch.save(model.state_dict(), file)

print("initial test run:")
# NOTE(review): trainer.test() is unpacked as (loss, correct) here — confirm
# this matches MNIST_Optimizer.test()'s return signature.
orig_test_loss, correct = trainer.test()

# Disabled experiment variants kept for reference:
# trainer.extend_threshold = 10
# trainer.prune_threshold = 0.12
# trainer.extend_and_prune(0)

# prune_count = 20
#
# print("\n prune the ", prune_count, " most important features seperately\n ")
# for name, module in model.named_modules():
# Load (or train and cache) the baseline model, run a sanity-check forward
# pass on a random image, then compute the normalized cross-correlation for
# every extendable layer.
import torch
import os
from models import FC_Net, Conv_Net
from training import MNIST_Optimizer
from torch.autograd import Variable
from expanding_modules import Conv1dExtendable, Conv2dExtendable

model = FC_Net(layer_sizes=[784, 32, 32, 32, 32, 10])
#model = Conv_Net()
trainer = MNIST_Optimizer(model, epochs=10)

# Checkpoint path shared with the other example scripts.
file = "trained_model_784_4x32_10"

# Reuse the cached checkpoint when present; otherwise train and save it.
if os.path.exists(file):
    print("load model")
    state_dict = torch.load(file)
    model.load_state_dict(state_dict)
else:
    print("save model")
    trainer.train()
    torch.save(model.state_dict(), file)

print("initial test run:")
# NOTE(review): trainer.test() is unpacked as (loss, correct) — confirm
# against MNIST_Optimizer.test().
orig_test_loss, correct = trainer.test()

# Sanity-check forward pass on a random 28x28 image.
# NOTE(review): pre-0.4 torch style (Variable); presumably FC_Net flattens
# the 28x28 input to 784 internally — verify.
test_img = Variable(torch.rand(28, 28))
print(model(test_img))

# Compute ncc for every extendable layer; the return value is discarded,
# so presumably the call updates state on the module itself — verify.
for name, module in model.named_modules():
    if type(module) is Conv1dExtendable or type(module) is Conv2dExtendable:
        ncc = module.normalized_cross_correlation()