def prune_convWeights(model_path, class_num, percent):
    """Zero out the smallest `percent` fraction of BatchNorm scaling factors
    (network-slimming style channel pruning) and return the masked model, the
    pruned layer widths (cfg) and the per-layer channel masks (cfg_mask)."""
    if config.get('Parameters', 'model_arch').lower() == 'basic':
        model = model4prune.CNNModelBasic(class_num)
    elif config.get('Parameters', 'model_arch').lower() == 'poolrevised':
        model = model4prune.CNNModelPoolingRevised(class_num)
    print('loading model from {}...'.format(model_path))
    model.load_state_dict(torch.load(model_path))
    # print (model)

    total = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]

    bn = torch.zeros(total)
    index = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index + size)] = m.weight.data.abs().clone()
            index += size

    # global threshold: the |gamma| value at the `percent` quantile over all BN channels
    y, _ = torch.sort(bn)
    thre_index = int(total * percent)
    thre = y[thre_index]

    pruned = 0
    cfg = []
    cfg_mask = []
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(thre).float()
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            cfg.append(int(torch.sum(mask)))
            cfg_mask.append(mask.clone())
            print(
                'layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'
                .format(k, mask.shape[0], int(torch.sum(mask))))
        elif isinstance(m, nn.AvgPool2d):
            cfg.append('A')

    pruned_ratio = float(pruned) / total

    print('Pre-processing successful! Pruned ratio: {:.4f}'.format(pruned_ratio))
    return model, cfg, cfg_mask
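# A minimal, hedged illustration (not part of the original pipeline) of the
# pruning criterion used in prune_convWeights: sort the absolute BatchNorm
# scaling factors (gamma) of all channels, take the value at the `percent`
# quantile as the global threshold, and mask out channels whose |gamma| does
# not exceed it. The helper name and numbers below are made-up toy values.
def _prune_threshold_toy_example():
    gammas = torch.tensor([0.9, 0.05, 0.4, 0.01, 0.7, 0.3])  # |gamma| of 6 channels
    sorted_g, _ = torch.sort(gammas)
    thre = sorted_g[int(gammas.numel() * 0.5)]  # prune 50% -> thre = 0.4
    mask = gammas.gt(thre).float()              # tensor([1., 0., 0., 0., 1., 0.])
    return int(mask.sum())                      # 2 channels survive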
def get_newmodel(model, class_num, cfg, cfg_mask):
    """Build the slimmer architecture described by cfg and copy the surviving
    BatchNorm / Conv2d / first-Linear weights from the masked model into it."""
    if config.get('Parameters', 'model_arch').lower() == 'basic':
        newmodel = model4prune.CNNModelBasic(class_num, cfg)
    elif config.get('Parameters', 'model_arch').lower() == 'poolrevised':
        newmodel = model4prune.CNNModelPoolingRevised(class_num, cfg)
    # print (newmodel)
    mask_idx = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[mask_idx]
    reach_Linear = False
    for m0, m1 in zip(model.modules(), newmodel.modules()):
        if reach_Linear:
            # pruning only done on CNN layer + the first FC layer
            break
        if isinstance(m0, nn.BatchNorm2d):
            print('copying bn weight...')
            # flatten() keeps the index array 1-D even if only one channel survives
            idx1 = np.argwhere(end_mask.numpy()).flatten()
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            # carry over the running statistics so eval-mode BN stays consistent
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()

            mask_idx += 1
            start_mask = end_mask.clone()
            if mask_idx < len(cfg_mask):
                end_mask = cfg_mask[mask_idx]

        elif isinstance(m0, nn.Conv2d):
            print('copying conv2d weight...')
            idx0 = np.argwhere(start_mask.numpy()).flatten()
            idx1 = np.argwhere(end_mask.numpy()).flatten()
            # keep the surviving input channels first, then the surviving output channels
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()

        elif isinstance(m0, nn.Linear):
            print('copying linear weight...')
            # this should be the first linear after the last bn layer
            idx0 = np.argwhere(start_mask.numpy()).flatten()
            m1.weight.data = m0.weight.data[:, idx0.tolist()].clone()
            m1.bias.data = m0.bias.data.clone()
            reach_Linear = True
    print('model pruning finished!')
    return newmodel
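# Hedged sketch of how the two functions above are typically chained together.
# The function name, argument names, and the pickling of cfg are assumptions
# (not part of the original code); finetune() below expects exactly such a
# pickled cfg at cfg_path plus a state_dict it can rebuild the slim model from.
def prune_and_save(model_path, class_num, percent, pruned_model_path, cfg_path):
    model, cfg, cfg_mask = prune_convWeights(model_path, class_num, percent)
    newmodel = get_newmodel(model, class_num, cfg, cfg_mask)
    with open(cfg_path, 'wb') as f:
        pkl.dump(cfg, f)                                   # finetune() reloads this via pkl.load
    torch.save(newmodel.state_dict(), pruned_model_path)   # weights of the slimmer model
    return newmodel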
def finetune(traindatadir,
             valdatadir,
             traindatacsv,
             valdatacsv,
             device,
             model_path,
             cfg_path,
             save_model_filename,
             from_scratch=True):
    """Fine-tune the pruned architecture rebuilt from the pickled cfg at cfg_path;
    when from_scratch is False, warm-start from the pruned weights at model_path."""
    EPOCH = 100
    printout_steps = 50
    eval_steps = 200
    lr_steps = 300
    lr = 5e-4
    t_max = 200
    eta_min = 3e-6

    print('initialize dataset...')
    voiceDataset = FATDataset(traindatadir,
                              valdatadir,
                              traindatacsv,
                              valdatacsv,
                              batch_size=8)
    # print (voiceDataset.get_class_num())
    print('create model ... ')

    # cnnmodel = models.CNNModelv2(voiceDataset.get_class_num()).to(device)
    print('cfg_path', cfg_path)
    cfg = pkl.load(open(cfg_path, 'rb'))
    if config.get('Parameters', 'model_arch').lower() == 'basic':
        cnnmodel = model4prune.CNNModelBasic(voiceDataset.get_class_num(),
                                             cfg).to(device)
    elif config.get('Parameters', 'model_arch').lower() == 'poolrevised':
        cnnmodel = model4prune.CNNModelPoolingRevised(
            voiceDataset.get_class_num(), cfg).to(device)
    print(cnnmodel)
    if not from_scratch:
        print('loading model from {}...'.format(model_path))
        cnnmodel.load_state_dict(torch.load(model_path))
    optimizer = torch.optim.Adam(cnnmodel.parameters(), lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=t_max, eta_min=eta_min)
    criterion = nn.BCEWithLogitsLoss()
    bestlwlrap = -1
    for e in range(EPOCH):
        voiceDataset.shuffle_trainingdata()
        num_step_per_epoch = voiceDataset.get_numof_batch(istraindata=True)
        # print (num_step_per_epoch)
        for bidx in tqdm(range(num_step_per_epoch)):
            # print ('get fingerprint...')
            batch_data, samplenumbatch, label_batch = voiceDataset.get_data(
                bidx, True)  #[M, 128, duration, 3]
            # print ('fingerprint got...')
            if batch_data.shape[0] <= 1:
                continue
            # print (batch_data.shape,label_batch.shape)
            bx = batch_data.to(device)

            output = cnnmodel(bx)
            output = oneSampleOutput(output, samplenumbatch).to(device)
            # print (output.shape,label_batch.shape)
            by = label_batch.to(device)
            #             by = autograd.Variable(label_batch,requires_grad = True)
            loss = criterion(output, by)
            #             loss = autograd.Variable(loss, requires_grad = True)
            if bidx % printout_steps == 0:
                msg = '[TRAINING] Epoch:{}, step:{}/{}, loss:{}'.format(
                    e, bidx, num_step_per_epoch, loss)
                print(msg)
                logger.info(msg)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (bidx + 1) % eval_steps == 0:
                # doing validation
                cnnmodel.eval()
                val_batches_num = voiceDataset.get_numof_batch(False)
                val_preds = np.array([]).reshape(0,
                                                 voiceDataset.get_class_num())
                val_labels = np.array([]).reshape(0,
                                                  voiceDataset.get_class_num())
                val_loss = 0.
                for vbidx in tqdm(range(val_batches_num)):
                    # print ('generating validation fingerprint ... ')
                    val_data, val_samplenumbatch, val_label = voiceDataset.get_data(
                        vbidx, False)
                    # print ('val_data shape:',val_data.shape)
                    pred = oneSampleOutput(
                        cnnmodel(val_data.to(device)).detach(),
                        val_samplenumbatch).to(device)
                    val_preds = np.vstack((val_preds, pred.cpu().numpy()))
                    # print (pred.shape)
                    # print (criterion(pred,val_label.to(device)))
                    val_loss += criterion(
                        pred, val_label.to(device)).item() / val_label.shape[0]
                    val_labels = np.vstack(
                        (val_labels, val_label.cpu().numpy()))
                score, weight = utils.calculate_per_class_lwlrap(
                    val_labels, val_preds)
                lwlrap = (score * weight).sum()
                msg = '[VALIDATION] Epoch:{}, step:{}/{}, loss:{}, lwlrap:{}'.format(
                    e, bidx, num_step_per_epoch, val_loss, lwlrap)
                print(msg)
                logger.info(msg)
                if lwlrap > bestlwlrap or bidx == num_step_per_epoch - 1:
                    # track the best score so far; also checkpoint at the end of the epoch
                    bestlwlrap = max(bestlwlrap, lwlrap)
                    # save model
                    save_model_path = os.path.join(checkpoint_dir,
                                                   save_model_filename)
                    torch.save(cnnmodel.state_dict(), save_model_path)
                    msg = 'save model to: {}'.format(save_model_path)
                    print(msg)
                    logger.info(msg)

                cnnmodel.train()
            if bidx % lr_steps == 0:
                scheduler.step()
def test(traindatadir,
         testdatadir,
         traindatacsv,
         testdatacsv,
         device,
         model_path=''):
    """
    test data is loaded in same format as validation data
    """
    print('initialize dataset...')
    voiceDataset = FATDataset(traindatadir,
                              testdatadir,
                              traindatacsv,
                              testdatacsv,
                              batch_size=8)
    # print (voiceDataset.get_class_num())
    print('create model ... ')

    # cnnmodel = models.CNNModelv2(voiceDataset.get_class_num()).to(device)
    # default to the basic architecture; overridden below by the configured arch
    cnnmodel = model4prune.CNNModelBasic(
        voiceDataset.get_class_num()).to(device)
    if config.get('Parameters', 'model_arch').lower() == 'basic':
        cnnmodel = model4prune.CNNModelBasic(
            voiceDataset.get_class_num()).to(device)
    elif config.get('Parameters', 'model_arch').lower() == 'poolrevised':
        cnnmodel = model4prune.CNNModelPoolingRevised(
            voiceDataset.get_class_num()).to(device)
    #loading trained model
    print('loading model from {}...'.format(model_path))
    cnnmodel.load_state_dict(torch.load(model_path))

    # testing
    cnnmodel.eval()
    test_batches_num = voiceDataset.get_numof_batch(False)
    test_preds = np.array([]).reshape(0, voiceDataset.get_class_num())
    test_labels = np.array([]).reshape(0, voiceDataset.get_class_num())

    # calculate flops and params
    flop_test_data, _, __ = voiceDataset.get_data(0, False)
    flop_test_data = flop_test_data[0:1, :, :, :].to(device)
    flops, params = profile(cnnmodel, inputs=(flop_test_data, ))

    eval_start = timeit.default_timer()
    for tbidx in tqdm(range(test_batches_num)):
        # print ('generating validation fingerprint ... ')
        test_data, test_samplenumbatch, test_label = voiceDataset.get_data(
            tbidx, False)

        pred = oneSampleOutput(
            cnnmodel(test_data.to(device)).detach(),
            test_samplenumbatch).to(device)
        test_preds = np.vstack((test_preds, pred.cpu().numpy()))

        test_labels = np.vstack((test_labels, test_label.cpu().numpy()))
    eval_stop = timeit.default_timer()

    score, weight = utils.calculate_per_class_lwlrap(test_labels, test_preds)
    lwlrap = (score * weight).sum()
    msg = '[TESTING]  lwlrap:{}, flops:{}, params:{}, running time:{}'.format(
        lwlrap, flops, params, eval_stop - eval_start)

    print(msg)
    logger.info(msg)
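# Hedged usage sketch: the directory and file names below are placeholders,
# not from the original code. finetune() rebuilds the slim model from cfg_path,
# so a pickled cfg and a pruned state_dict (e.g. produced by a driver like
# prune_and_save above) are assumed to exist already.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# finetune('data/train', 'data/val', 'train.csv', 'val.csv', device,
#          model_path='checkpoints/pruned.pth', cfg_path='checkpoints/cfg.pkl',
#          save_model_filename='finetuned.pth', from_scratch=False)
# test('data/train', 'data/test', 'train.csv', 'test.csv', device,
#      model_path=os.path.join(checkpoint_dir, 'finetuned.pth'))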