Ejemplo n.º 1
0
def getPrunedNet(params_scores, orig_net, network_name, prune_factor=0.5):
    # prune_th = prune_factor
    prune_th, _ = torch.kthvalue(params_scores, int(np.ceil(prune_factor*params_scores.numel())))
    print("The thershold is {}".format(prune_th))
    params_mask = (params_scores > prune_th).detach().cpu().numpy()

    info = networks_data.get(network_name)

    ### build a new network based on params_mask ###
    new_net = info.get('pruned_network')(params_mask)

    new_dict = new_net.state_dict()
    pretrained_dict = orig_net.state_dict()  

    updatePrunedNet = info.get('update_pruned_net_func')

    new_dict = updatePrunedNet(new_dict, pretrained_dict, params_mask)

    # 1. filter out unnecessary keys
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    # model_dict.update(pretrained_dict) 
    # 3. load the new state dict
    new_net.load_state_dict(new_dict)

    return new_net
def createNameGraph(network_name):
    network_name_graph = nx.DiGraph()

    info = networks_data.get(network_name)
    CG_func = info.get('create_graph_func')
    network_name_graph = CG_func(network_name_graph)

    return network_name_graph
def main():
    # set parameters
    network_name = 'WRN_40_2'

    # get the desired path to save
    graph_path = networks_data.get(network_name).get('graph_path')

    # get the graph that represents the model
    net_graph = convertModelToGraph(network_name)

    # save it
    saveGraph(net_graph, graph_path)
Ejemplo n.º 4
0
def main():
    # set parameters
    num_epochs = 200
    train_bs = 64
    test_bs = 32
    dataset_name = 'cifar10'
    network_name = 'WRN_40_2'
    save_path = './CIFAR10_models/WRN_40_2.pt'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    checkpointPath = './CIFAR10_models/checkpoints/CP_epoch_{}_acc_{}.pt'

    # define the datasets
    data_train = datasets_train.get(dataset_name)
    data_test = datasets_test.get(dataset_name)

    # define the data loaders
    data_train_loader = DataLoader(data_train,
                                   batch_size=train_bs,
                                   shuffle=True,
                                   num_workers=8)
    data_test_loader = DataLoader(data_test, batch_size=test_bs, num_workers=8)

    # initialize the network and what else is necessary for training
    net = networks_data.get(network_name).get('network')().to(device)
    criterion = nn.CrossEntropyLoss()
    # optimizer = optim.Adam(net.parameters(), lr=2e-3)
    optimizer = optim.SGD(net.parameters(),
                          lr=0.1,
                          momentum=0.9,
                          nesterov=True,
                          weight_decay=5e-4)
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(elem * num_epochs) for elem in [0.3, 0.6, 0.8]],
        gamma=0.2)

    # train and evaluate the network
    best_test_set_accuracy = 0
    for e in range(1, num_epochs + 1):
        train(e, net, criterion, optimizer, data_train_loader, len(data_train),
              device)
        scheduler.step()
        accuracy = test(net, criterion, data_test_loader, len(data_test),
                        device)
        if accuracy > best_test_set_accuracy:
            best_test_set_accuracy = accuracy
            torch.save(net.state_dict(), checkpointPath.format(e, accuracy))

    # save the state dict and the accuracy
    torch.save(net.state_dict(), save_path)
    acc_path = save_path.replace('.pt', '_accuracy.pt')
    torch.save(accuracy, acc_path)
def convertModelToGraph(network_name):
    info = networks_data.get(network_name)

    network = info.get('network')()
    net_sd_path = info.get('sd_path')

    # create the network's graph where each node is a layer in the model's state dict
    net_name_graph = createNameGraph(network_name)
    # get the model state dict
    net_sd = torch.load(net_sd_path, map_location=torch.device('cpu'))
    # create the model's graph that the GNN will process
    network_graph = buildGraph(net_sd, net_name_graph)

    return network_graph
Ejemplo n.º 6
0
 def __init__(self,
              root,
              network_name,
              isWRN=False,
              graph_path=None,
              transform=None,
              pre_transform=None):
     self.network_name = network_name
     self.isWRN = isWRN
     if graph_path is None:
         self.graph_path = networks_data.get(network_name).get('graph_path')
     else:
         self.graph_path = graph_path
     super(GraphDataset, self).__init__(root, transform, pre_transform)
     self.data, self.slices = torch.load(self.processed_paths[0])
Ejemplo n.º 7
0
def GNN_prune_loss(params_scores,
                   orig_net,
                   orig_net_loss,
                   network_name,
                   val_data,
                   device,
                   reduction=None,
                   prune_factor=0.5,
                   gamma1=1,
                   gamma2=1,
                   gamma3=1e2):
    ### first calculate the sparse term ###
    # sparse_term = gamma2 * torch.log2((1 / (torch.log2(params_scores.numel()/(torch.sum(params_scores) + 1e-6)) + 1e-6)))
    sparse_term = gamma2 * torch.log2(torch.sum(params_scores) + 1e-6)
    # sparse_term_v = 1/sparse_term.item()

    batch_size = 8
    val_data_loader = torch.utils.data.DataLoader(val_data,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=8)
    new_net = networks_data.get(network_name).get('masked_network')(
        params_scores, device).to(device)

    model_dict = new_net.state_dict()
    pretrained_dict = orig_net.state_dict()

    model_dict.update(pretrained_dict)

    new_net.load_state_dict(model_dict)

    ### train the new network a bit ###
    new_net.train()
    # don't train the masks
    for m in new_net.modules():
        if isinstance(m, ElemWiseMultiply):
            m.mask.requires_grad = False

    optimizer = torch.optim.SGD(new_net.parameters(),
                                lr=0.02,
                                momentum=0.9,
                                nesterov=True,
                                weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()

    for e in range(3):
        for i, (images, labels) in enumerate(val_data_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            output = new_net(images)
            loss = criterion(output, labels)

            if i % 100 == 0:
                print('Retrain - Epoch %d, Batch: %d, Loss: %f' %
                      (e, i, loss.detach().cpu().item()))

            loss.backward()
            optimizer.step()

    for m in new_net.modules():
        if isinstance(m, ElemWiseMultiply):
            m.mask.requires_grad = True

    ### calculate the new network's loss ###
    new_net.eval()
    avg_loss = 0.0
    acc_loss = 0.0
    criterion = nn.CrossEntropyLoss()

    num_b = 0
    acc_data_grad = torch.zeros_like(params_scores)

    for i, (images, labels) in enumerate(val_data_loader):
        images, labels = images.to(device), labels.to(device)
        new_net.zero_grad()
        output = new_net(images)
        loss = criterion(output, labels)

        avg_loss += loss.sum()
        # loss = gamma1 * torch.log2(loss/orig_net_loss)
        loss = gamma1 * loss
        acc_loss += loss.sum()
        # loss *= sparse_term_v
        loss.backward()

        # get all the gradients with respect to mask
        indToSum = {}
        for m in new_net.modules():
            if isinstance(m, ElemWiseMultiply):
                nf = m.mask.grad.data.size()[1]
                ind = m.indices_list
                for i in range(nf):
                    s = torch.sum(m.mask.grad.data[:, i, :, :])
                    indToSum[ind[i]] = s
        # add those gradients the rest
        for i in range(len(params_scores)):
            acc_data_grad[i] += indToSum[i]

        num_b += 1
        if num_b == 4000:
            break

    # avg_loss /= len(val_data)
    avg_loss /= (batch_size * num_b)
    acc_loss /= (batch_size * num_b)

    new_net_loss = avg_loss
    data_term = acc_loss
    print("The data_term in the loss is: {}".format(data_term))
    print("The sparse_term in the loss is: {}".format(sparse_term))

    return sparse_term, data_term, acc_data_grad
Ejemplo n.º 8
0
def main():
    # random.seed(0)
    # torch.manual_seed(0)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # np.random.seed(0)
    # torch.cuda.manual_seed(0)

    # set all hyperparameters
    network_name = 'WRN_40_2'
    num_epochs = 35
    batch_size = 1
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    n_retrain_epochs = 40 
    trials = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9]
    lr = 3e-4
    opt = "Adam"
    use_temp = False
    use_steps = False

    # set paths
    checkpointPath = './GNN_model/CIFAR10_checkpoints/CP__num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}__epoch_{}.pt'.format(num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps, '{}')    
    continue_train = False
    checkpointLoadPath = './GNN_model/CIFAR10_checkpoints/CP__num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}__epoch_{}.pt'.format(num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps, '20')

    # get GNN path
    info = networks_data.get(network_name)
    trained_model_path = info.get('trained_GNN_path').replace('.pt', '___num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}.pt'.format(num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps))

    # declare GNN model
    model = GNNPrunningNet(in_channels=6, out_channels=128).to(device)
    if opt == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        # lr = 0.1
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=5e-4)
        scheduler = MultiStepLR(optimizer, milestones=[int(elem*num_epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
    crit = GNN_prune_loss

    # declate TensorBoard writer
    summary_path = '{}-num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}/training'.format(network_name, num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps)
    writer = SummaryWriter(summary_path)


    root            = info.get('root')
    net_graph_path  = info.get('graph_path')
    sd_path         = info.get('sd_path')
    net             = info.get('network')
    orig_net_loss   = info.get('orig_net_loss') 

    isWRN = (network_name == "WRN_40_2")
    train_dataset   = GraphDataset(root, network_name, isWRN, net_graph_path)
    train_loader    = DataLoader(train_dataset, batch_size=batch_size)

    orig_net = net().to(device)
    orig_net.load_state_dict(torch.load(sd_path, map_location=device))

    model.train()

    dataset_name = info.get('dataset_name')
    network_train_data = datasets_train.get(dataset_name)

    print("Start training")

    if continue_train == True:
        cp = torch.load(checkpointLoadPath, map_location=device)
        trained_epochs = cp['epoch'] + 1
        sd = cp['model_state_dict']
        model.load_state_dict(sd)
        op_sd = cp['optimizer_state_dict']
        optimizer.load_state_dict(op_sd)
    else:
        trained_epochs = 0

    loss_all = 0.0
    data_all = 0.0
    sparse_all = 0.0
    if use_temp == True:
        T = 1.0
        if trained_epochs > 0:
            T = np.power(2, np.floor(trained_epochs / int(num_epochs/3)))

    for epoch in range(trained_epochs, num_epochs):
        
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            output = model(data)

            if use_temp == True:
            # Use temperature
                nom = torch.pow((torch.exp(torch.tensor(T, device=device))), output)
                dom = torch.pow((torch.exp(torch.tensor(T, device=device))), output) + torch.pow((torch.exp(torch.tensor(T, device=device))), (1-output))
                output = nom/dom
                # continue as usual

            sparse_term, data_term, data_grad = crit(output, orig_net, orig_net_loss, network_name, network_train_data, device, gamma1=10, gamma2=0.1)

            if use_steps == True:
                if epoch % 3 == 0: # do 2 steps in data direction then 1 in sparsity
                    sparse_term.backward()
                else:
                    output.backward(data_grad)
            else:            
                sparse_term.backward(retain_graph=True)
                output.backward(data_grad)

            data_all += data.num_graphs * data_term.item()
            sparse_all += data.num_graphs * sparse_term.item()
            loss_all += data_all + sparse_all
            optimizer.step()
            
        print("epoch {}. total loss is: {}".format(epoch+1, (data_term.item() + sparse_term.item()) / len(train_dataset)))
        
        if opt != "Adam":
            scheduler.step()

        if use_temp == True:
        # increase temperature 3 times
            if (epoch+1) % int(num_epochs/3) == 0:
                T *= 2

        if epoch % 10 == 9:
            writer.add_scalars('Learning curve', {
            'loss data term': data_all/10,
            'loss sparsity term': sparse_all/10,
            'training loss': loss_all/10
            }, epoch+1)            

            # save checkpoint
            if opt == "Adam":
                torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss_all,
                }, checkpointPath.format(epoch+1))
            else:
                torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss_all,
                'scheduler_state_dict': scheduler.state_dict(),
                }, checkpointPath.format(epoch+1))

            loss_all = 0.0
            data_all = 0.0
            sparse_all = 0.0
            

    torch.save(model.state_dict(), trained_model_path)            

    print("Start evaluating")

    model.load_state_dict(torch.load(trained_model_path, map_location=device))

    model.eval()

    network_val_data = datasets_test.get(dataset_name)
    val_data_loader = torch.utils.data.DataLoader(network_val_data, batch_size=1024, shuffle=False, num_workers=8) 

    for trial, p_factor in enumerate(trials):
        with torch.no_grad():
            for data in train_loader:
                data = data.to(device)

                pred = model(data)

                prunedNet = getPrunedNet(pred, orig_net, network_name, prune_factor=p_factor).to(device)

        # Train the pruned network
        prunedNet.train()

        data_train_loader = torch.utils.data.DataLoader(network_train_data, batch_size=256, shuffle=False, num_workers=8) 
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(prunedNet.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
        scheduler = MultiStepLR(optimizer, milestones=[int(elem*n_retrain_epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)

        for epoch in range (n_retrain_epochs):
            for i, (images, labels) in enumerate(data_train_loader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                output = prunedNet(images)
                loss = criterion(output, labels)

                if i % 30 == 0:
                    print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch+1, i, loss.detach().cpu().item()))

                loss.backward()
                optimizer.step()

            scheduler.step()

        # Evaluate the pruned net
        with torch.no_grad():

            total_correct = 0
            cuda_time = 0.0            
            cpu_time = 0.0

            for i, (images, labels) in enumerate(val_data_loader):
                images, labels = images.to(device), labels.to(device)

                with torch.autograd.profiler.profile(use_cuda=True) as prof:
                    output = prunedNet(images)
                cuda_time += sum([item.cuda_time for item in prof.function_events])
                cpu_time += sum([item.cpu_time for item in prof.function_events])

                pred = output.detach().max(1)[1]
                total_correct += pred.eq(labels.view_as(pred)).sum()

            p_acc = float(total_correct) / len(network_val_data)
            p_num_params = gnp(prunedNet)
            p_cuda_time = cuda_time / len(network_val_data)
            p_cpu_time = cpu_time / len(network_val_data)

            print("The pruned network for prune factor {} accuracy is: {}".format(p_factor, p_acc))
            print("The pruned network number of parameters is: {}".format(p_num_params))
            print("The pruned network cuda time is: {}".format(p_cuda_time))
            print("The pruned network cpu time is: {}".format(p_cpu_time))

        # Evaluate the original net
        with torch.no_grad():

            total_correct = 0
            cuda_time = 0.0            
            cpu_time = 0.0
            
            for i, (images, labels) in enumerate(val_data_loader):
                images, labels = images.to(device), labels.to(device)

                with torch.autograd.profiler.profile(use_cuda=True) as prof:
                    output = orig_net(images)
                cuda_time += sum([item.cuda_time for item in prof.function_events])
                cpu_time += sum([item.cpu_time for item in prof.function_events])

                pred = output.detach().max(1)[1]
                total_correct += pred.eq(labels.view_as(pred)).sum()

            o_acc = float(total_correct) / len(network_val_data)
            o_num_params = gnp(orig_net)
            o_cuda_time = cuda_time / len(network_val_data)
            o_cpu_time = cpu_time / len(network_val_data)

            print("The original network accuracy is: {}".format(o_acc))
            print("The original network number of parameters is: {}".format(o_num_params))
            print("The original network cuda time is: {}".format(o_cuda_time))
            print("The original network cpu time is: {}".format(o_cpu_time))

        writer.add_scalars('Network accuracy', {
            'original': o_acc,
            'pruned': p_acc
            }, 100*p_factor)
        writer.add_scalars('Network number of parameters', {
            'original': o_num_params,
            'pruned': p_num_params
            }, 100*p_factor)
        writer.add_scalars('Network GPU time', {
            'original': o_cuda_time,
            'pruned': p_cuda_time
            }, 100*p_factor)
        writer.add_scalars('Network CPU time', {
            'original': o_cpu_time,
            'pruned': p_cpu_time
            }, 100*p_factor)

    writer.close()