def prune_convWeights(model_path, class_num, percent):
    """Zero out low-importance channels using a global BatchNorm threshold.

    All ``nn.BatchNorm2d`` scaling factors (``weight``) in the model are pooled
    and sorted by absolute value. The value at position ``int(total * percent)``
    becomes a single global threshold. Every BN channel whose ``|weight|`` does
    not exceed it has its BN weight and bias multiplied by 0 in place. The model
    is NOT compacted here; see ``get_newmodel`` for that.

    Args:
        model_path: path of a ``state_dict`` file loadable with ``torch.load``.
        class_num: number of output classes, forwarded to the model constructor.
        percent: fraction in [0, 1) of BN channels to prune (approximate, since
            ties at the threshold are pruned too).

    Returns:
        ``(model, cfg, cfg_mask)``:
        - ``model``: the loaded model with pruned BN channels zeroed.
        - ``cfg``: surviving channel count per BN layer in ``model.modules()``
          order, with ``'A'`` inserted for each ``nn.AvgPool2d``.
        - ``cfg_mask``: one float 0/1 mask tensor per BN layer.

    Raises:
        ValueError: unknown ``model_arch`` in the config, ``percent`` outside
            [0, 1), or a model without any ``nn.BatchNorm2d`` layer.
    """
    arch = config.get('Parameters', 'model_arch').lower()
    if arch == 'basic':
        model = model4prune.CNNModelBasic(class_num)
    elif arch == 'poolrevised':
        model = model4prune.CNNModelPoolingRevised(class_num)
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError('unsupported model_arch: {!r}'.format(arch))
    if not 0.0 <= percent < 1.0:
        raise ValueError('percent must be in [0, 1), got {}'.format(percent))

    print('loading model from {}...'.format(model_path))
    # The freshly built model lives on the CPU; map the checkpoint there too so
    # a GPU-saved state_dict also loads on CPU-only machines.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    bn_layers = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    if not bn_layers:
        raise ValueError('model has no BatchNorm2d layers to prune')
    # Pool every |gamma| of every BN layer into one flat vector.
    bn = torch.cat([m.weight.data.abs().view(-1) for m in bn_layers])
    total = bn.shape[0]

    sorted_gammas, _ = torch.sort(bn)
    # percent < 1 guarantees this index is at most total - 1.
    thre_index = int(total * percent)
    thre = sorted_gammas[thre_index]

    pruned = 0
    cfg = []
    cfg_mask = []
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            mask = m.weight.data.abs().gt(thre).float()
            kept = int(torch.sum(mask))
            pruned += mask.shape[0] - kept
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            cfg.append(kept)
            cfg_mask.append(mask.clone())
            print(
                'layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'
                .format(k, mask.shape[0], kept))
        elif isinstance(m, nn.AvgPool2d):
            # Marker consumed by the model constructors when rebuilding from cfg.
            cfg.append('A')

    pruned_ratio = pruned / total
    print('Pre-processing Successful! pruned ratio:{}'.format(pruned_ratio))
    return model, cfg, cfg_mask
def get_newmodel(model, class_num, cfg, cfg_mask):
    """Build a compact model from ``cfg`` and copy the surviving weights into it.

    Walks ``model.modules()`` and the new model's modules in lockstep. It
    therefore assumes both models register their layers in the same order, and
    that a Conv2d is followed by the BatchNorm2d whose mask defines its output
    channels.

    Args:
        model: the masked source model (e.g. from ``prune_convWeights``).
        class_num: number of output classes, forwarded to the constructor.
        cfg: per-layer channel config (ints and ``'A'`` markers).
        cfg_mask: one 0/1 mask tensor per BatchNorm2d layer, in module order.

    Returns:
        The new, smaller model with the kept weights copied in.

    Raises:
        ValueError: unknown ``model_arch`` in the config, or empty ``cfg_mask``.
    """
    arch = config.get('Parameters', 'model_arch').lower()
    if arch == 'basic':
        newmodel = model4prune.CNNModelBasic(class_num, cfg)
    elif arch == 'poolrevised':
        newmodel = model4prune.CNNModelPoolingRevised(class_num, cfg)
    else:
        raise ValueError('unsupported model_arch: {!r}'.format(arch))
    if len(cfg_mask) == 0:
        raise ValueError('cfg_mask is empty; nothing to copy')

    def kept(mask):
        # Indices of surviving channels as a 1-D LongTensor. Unlike
        # np.squeeze(np.argwhere(...)), this stays 1-D when exactly one
        # channel survives, so indexing never drops a dimension.
        return torch.nonzero(mask).flatten()

    def copy_unchanged(src, dst):
        # Copy src's own parameters/buffers into dst wherever name and shape
        # match. Used for layers past the pruned region.
        dst_tensors = dict(dst.named_parameters(recurse=False))
        dst_tensors.update(dst.named_buffers(recurse=False))
        src_tensors = (list(src.named_parameters(recurse=False)) +
                       list(src.named_buffers(recurse=False)))
        for name, t0 in src_tensors:
            t1 = dst_tensors.get(name)
            if t1 is not None and t0.shape == t1.shape:
                t1.data.copy_(t0.data)

    # The input mask must match the first conv's input channels; this used to
    # be hard-coded to 3.
    first_conv = next(
        (m for m in model.modules() if isinstance(m, nn.Conv2d)), None)
    in_channels = first_conv.in_channels if first_conv is not None else 3
    start_mask = torch.ones(in_channels)

    mask_idx = 0
    end_mask = cfg_mask[mask_idx]
    reach_Linear = False
    for m0, m1 in zip(model.modules(), newmodel.modules()):
        if reach_Linear:
            # Pruning only affects CNN layers + the first FC layer's inputs;
            # later layers keep their trained weights instead of fresh init.
            copy_unchanged(m0, m1)
            continue
        if isinstance(m0, nn.BatchNorm2d):
            print('copying bn weight...')
            idx1 = kept(end_mask)
            m1.weight.data = m0.weight.data[idx1].clone()
            m1.bias.data = m0.bias.data[idx1].clone()
            # Running statistics must follow the kept channels too; otherwise
            # the compact model normalises with default stats in eval mode.
            if m0.running_mean is not None and m0.running_var is not None:
                m1.running_mean = m0.running_mean[idx1].clone()
                m1.running_var = m0.running_var[idx1].clone()
            mask_idx += 1
            start_mask = end_mask.clone()
            if mask_idx < len(cfg_mask):
                end_mask = cfg_mask[mask_idx]
        elif isinstance(m0, nn.Conv2d):
            print('copying conv2d weight...')
            idx0 = kept(start_mask)
            idx1 = kept(end_mask)
            w1 = m0.weight.data[:, idx0, :, :]
            m1.weight.data = w1[idx1, :, :, :].clone()
            if m0.bias is not None and m1.bias is not None:
                m1.bias.data = m0.bias.data[idx1].clone()
        elif isinstance(m0, nn.Linear):
            print('copying linear weight...')
            # This should be the first linear after the last bn layer.
            # NOTE(review): assumes its in_features map 1:1 onto the last BN's
            # channels (e.g. global pooling before the FC) — confirm in model4prune.
            idx0 = kept(start_mask)
            m1.weight.data = m0.weight.data[:, idx0].clone()
            if m0.bias is not None and m1.bias is not None:
                m1.bias.data = m0.bias.data.clone()
            reach_Linear = True
    print('model pruning finished!')
    return newmodel
def _finetune_validate(cnnmodel, voiceDataset, criterion, device):
    """Run one pass over the validation split; return ``(val_loss, lwlrap)``.

    The model is switched to eval mode for the pass and always returned to
    train mode afterwards. ``val_loss`` sums, over batches, the criterion value
    divided by the batch's label count (kept as in the original code).
    """
    class_num = voiceDataset.get_class_num()
    # Seed with empty (0, C) arrays so zero batches still yield valid shapes.
    pred_chunks = [np.array([]).reshape(0, class_num)]
    label_chunks = [np.array([]).reshape(0, class_num)]
    val_loss = 0.
    cnnmodel.eval()
    try:
        # No autograd graph is needed for evaluation.
        with torch.no_grad():
            for vbidx in tqdm(range(voiceDataset.get_numof_batch(False))):
                val_data, val_samplenumbatch, val_label = voiceDataset.get_data(
                    vbidx, False)
                pred = oneSampleOutput(cnnmodel(val_data.to(device)),
                                       val_samplenumbatch).to(device)
                pred_chunks.append(pred.cpu().numpy())
                # NOTE(review): BCEWithLogitsLoss already averages by default;
                # the extra division by batch size is preserved as-is.
                val_loss += criterion(
                    pred, val_label.to(device)).item() / val_label.shape[0]
                label_chunks.append(val_label.cpu().numpy())
    finally:
        cnnmodel.train()
    # Stack once instead of re-stacking every batch (avoids quadratic copying).
    val_preds = np.vstack(pred_chunks)
    val_labels = np.vstack(label_chunks)
    score, weight = utils.calculate_per_class_lwlrap(val_labels, val_preds)
    return val_loss, (score * weight).sum()


def finetune(traindatadir, valdatadir, traindatacsv, valdatacsv, device,
             model_path, cfg_path, save_model_filename, frorm_scratch=True):
    """Train (or fine-tune) a pruned model described by a pickled ``cfg``.

    Args:
        traindatadir, valdatadir, traindatacsv, valdatacsv: dataset locations
            passed to ``FATDataset``.
        device: torch device for the model and batches.
        model_path: state_dict to start from when ``frorm_scratch`` is False.
        cfg_path: pickled channel config used to build the pruned model.
        save_model_filename: checkpoint filename inside ``checkpoint_dir``.
        frorm_scratch: if True, train from random init (parameter name typo is
            kept for backward compatibility).

    Raises:
        ValueError: unknown ``model_arch`` in the config.
    """
    EPOCH = 100
    printout_steps = 50
    eval_steps = 200
    lr_steps = 300
    lr = 5e-4
    t_max = 200
    eta_min = 3e-6

    print('initialize dataset...')
    voiceDataset = FATDataset(traindatadir, valdatadir, traindatacsv,
                              valdatacsv, batch_size=8)
    print('create model ... ')
    print('cfg_path', cfg_path)
    # Close the file deterministically. NOTE: pickle can execute arbitrary
    # code, so only load cfg files you produced yourself.
    with open(cfg_path, 'rb') as cfg_file:
        cfg = pkl.load(cfg_file)
    arch = config.get('Parameters', 'model_arch').lower()
    if arch == 'basic':
        cnnmodel = model4prune.CNNModelBasic(voiceDataset.get_class_num(),
                                             cfg).to(device)
    elif arch == 'poolrevised':
        cnnmodel = model4prune.CNNModelPoolingRevised(
            voiceDataset.get_class_num(), cfg).to(device)
    else:
        raise ValueError('unsupported model_arch: {!r}'.format(arch))
    print(cnnmodel)
    if not frorm_scratch:
        print('loading model from {}...'.format(model_path))
        # map_location lets checkpoints saved on another device load here.
        cnnmodel.load_state_dict(torch.load(model_path, map_location=device))

    optimizer = torch.optim.Adam(cnnmodel.parameters(), lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=t_max, eta_min=eta_min)
    criterion = nn.BCEWithLogitsLoss()
    bestlwlrap = -1
    for e in range(EPOCH):
        voiceDataset.shuffle_trainingdata()
        num_step_per_epoch = voiceDataset.get_numof_batch(istraindata=True)
        for bidx in tqdm(range(num_step_per_epoch)):
            batch_data, samplenumbatch, label_batch = voiceDataset.get_data(
                bidx, True)
            # Skip batches of size <= 1 (kept from the original logic).
            if batch_data.shape[0] <= 1:
                continue
            bx = batch_data.to(device)
            output = cnnmodel(bx)
            output = oneSampleOutput(output, samplenumbatch).to(device)
            by = label_batch.to(device)
            loss = criterion(output, by)
            if bidx % printout_steps == 0:
                msg = '[TRAINING] Epoch:{}, step:{}/{}, loss:{}'.format(
                    e, bidx, num_step_per_epoch, loss)
                print(msg)
                logger.info(msg)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (bidx + 1) % eval_steps == 0:
                val_loss, lwlrap = _finetune_validate(cnnmodel, voiceDataset,
                                                      criterion, device)
                msg = '[VALIDATION] Epoch:{}, step:{}:/{}, loss:{}, lwlrap:{}'.format(
                    e, bidx, num_step_per_epoch, val_loss, lwlrap)
                print(msg)
                logger.info(msg)
                # NOTE(review): the last-step clause overwrites the best
                # checkpoint (same filename) even when lwlrap got worse.
                if lwlrap > bestlwlrap or bidx == num_step_per_epoch - 1:
                    bestlwlrap = lwlrap
                    save_model_path = os.path.join(checkpoint_dir,
                                                   save_model_filename)
                    torch.save(cnnmodel.state_dict(), save_model_path)
                    msg = 'save model to: {}'.format(save_model_path)
                    print(msg)
                    logger.info(msg)
            if bidx % lr_steps == 0:
                scheduler.step()
def test(traindatadir, testdatadir, traindatacsv, testdatacsv, device,
         model_path='', cfg_path=None):
    """Evaluate a trained model on the test split.

    Test data is loaded in the same format as validation data. Reports lwlrap,
    FLOPs/params (from ``profile`` on a single example) and inference time.

    Args:
        traindatadir, testdatadir, traindatacsv, testdatacsv: dataset locations
            passed to ``FATDataset``.
        device: torch device for the model and batches.
        model_path: state_dict to evaluate.
        cfg_path: optional pickled channel config. When given, the model is
            built with it so pruned checkpoints can be evaluated. When omitted,
            the full (unpruned) architecture is built, as before.
    """
    print('initialize dataset...')
    voiceDataset = FATDataset(traindatadir, testdatadir, traindatacsv,
                              testdatacsv, batch_size=8)
    class_num = voiceDataset.get_class_num()
    print('create model ... ')
    arch = config.get('Parameters', 'model_arch').lower()
    # Any arch other than 'poolrevised' falls back to the basic model; this
    # preserves the original default without building the model twice.
    if arch == 'poolrevised':
        model_cls = model4prune.CNNModelPoolingRevised
    else:
        model_cls = model4prune.CNNModelBasic
    if cfg_path is None:
        cnnmodel = model_cls(class_num).to(device)
    else:
        # NOTE: pickle can execute arbitrary code; only load trusted cfg files.
        with open(cfg_path, 'rb') as cfg_file:
            cfg = pkl.load(cfg_file)
        cnnmodel = model_cls(class_num, cfg).to(device)

    # loading trained model
    print('loading model from {}...'.format(model_path))
    cnnmodel.load_state_dict(torch.load(model_path, map_location=device))

    # testing
    cnnmodel.eval()
    test_batches_num = voiceDataset.get_numof_batch(False)
    # Seed with empty (0, C) arrays so zero batches still yield valid shapes.
    pred_chunks = [np.array([]).reshape(0, class_num)]
    label_chunks = [np.array([]).reshape(0, class_num)]

    # calculate flops and params on a single example
    flop_test_data, _, _ = voiceDataset.get_data(0, False)
    flop_test_data = flop_test_data[0:1, :, :, :].to(device)
    flops, params = profile(cnnmodel, inputs=(flop_test_data, ))

    eval_start = timeit.default_timer()
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for tbidx in tqdm(range(test_batches_num)):
            test_data, test_samplenumbatch, test_label = voiceDataset.get_data(
                tbidx, False)
            pred = oneSampleOutput(cnnmodel(test_data.to(device)),
                                   test_samplenumbatch).to(device)
            pred_chunks.append(pred.cpu().numpy())
            label_chunks.append(test_label.cpu().numpy())
    eval_stop = timeit.default_timer()

    # Stack once instead of re-stacking every batch (avoids quadratic copying).
    test_preds = np.vstack(pred_chunks)
    test_labels = np.vstack(label_chunks)
    score, weight = utils.calculate_per_class_lwlrap(test_labels, test_preds)
    lwlrap = (score * weight).sum()
    msg = '[TESTING] lwlrap:{}, flops:{}, params:{}, running time:{}'.format(
        lwlrap, flops, params, eval_stop - eval_start)
    print(msg)
    logger.info(msg)