# Model
in_channels, out_channels, final_sigmoid, f_maps = 1, args.n_class, False, 64  # f_maps=64
if args.enable_ssc_unet:
    model = UNet3DSSC()
else:
    model = UNet3D(in_channels, out_channels, final_sigmoid,
                   f_maps=f_maps, layer_order='cbr', num_groups=8,
                   number_of_fmaps=args.number_of_fmaps,
                   enable_deepmodel_pooling=args.enable_deepmodel_pooling,
                   width=args.width, res_type=args.res_type)
weight_init(model, mode=args.weight_init)

# ==== Optimizer
if args.optimizer == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           weight_decay=args.w_decay)
else:
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.w_decay, nesterov=True)

# Components from SSC
criterion = nn.CrossEntropyLoss(ignore_index=255)  # ignore empty space marked as 255

if args.enable_cuda:
    model = model.cuda()

# Tensorboard
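# NOTE: `weight_init` is called throughout these excerpts but not defined in
# them. Below is a minimal sketch of what a mode-dispatched initializer might
# look like; the mode names ('xn' for Xavier-normal, 'xu' for Xavier-uniform)
# are assumptions inferred from the 'xn' check elsewhere in this code, not the
# repository's actual implementation.
def weight_init_sketch(model, mode='xn'):
    import torch.nn as nn
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.Linear)):
            if mode == 'xn':
                nn.init.xavier_normal_(m.weight)
            elif mode == 'xu':
                nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)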
if opt.model == 'i3d':
    new_model = refine_model_I3D(model, neuron_mask_clean)
else:
    new_model = refine_model_classification(
        model, neuron_mask_clean, opt.model,
        network_connection_dict=network_connection_dict,
        enable_raw_grad=opt.enable_raw_grad)
del model
model = new_model.cpu()
# print('New model:', model)
if hasattr(model, 'module'):
    weight_init(model.module, mode=opt.weight_init)
else:
    weight_init(model, mode=opt.weight_init)
n_params_refined, n_neurons_refined = do_statistics_model(model)
print('Statistics, org, params: {}, neurons: {}; refined, '
      'params: {} ({:.4f}%), neurons: {} ({:.4f}%)'
      .format(n_params_org, n_neurons_org,
              n_params_refined, n_params_refined * 100 / n_params_org,
              n_neurons_refined, n_neurons_refined * 100 / n_neurons_org))

# Calculate Flops, Params, Memory of New Model
model_flops = copy.deepcopy(model)
flops, params, memory, _ = profile(
    # argument list assumed to match the profile(...) call in main() below
    model_flops.cuda(), inputs=(profile_input,), verbose=False,
    resource_list_type=opt.resource_list_type, mode=opt.statistic_mode)
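# NOTE: `do_statistics_model` is not shown in this excerpt. Judging from the
# print statement above, it is assumed to return a (param count, neuron count)
# pair, where "neurons" are output channels. A plausible sketch, not the
# repository's code:
def do_statistics_model_sketch(model):
    import torch.nn as nn
    n_params = sum(p.nelement() for p in model.parameters())
    n_neurons = sum(m.out_channels for m in model.modules()
                    if isinstance(m, (nn.Conv2d, nn.Conv3d)))
    return n_params, n_neurons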
    grad_mode = 'raw' if opt.enable_raw_grad else 'abs'
    if opt.weight_init == 'xn':
        file_path = 'data/stereo/stereo_kernel_hidden_prune_grad_{}.npy'.format(
            grad_mode)
    else:
        file_path = 'data/stereo/stereo_kernel_hidden_prune_grad_init{}_{}.npy' \
            .format(opt.weight_init, grad_mode)
    model = stackhourglass(opt.maxdisp)
else:
    file_path = None
    model = None
assert (model is not None) and (file_path is not None) \
    and os.path.exists(file_path), file_path
weight_init(model, mode=opt.weight_init)
model = model.cuda()

# Get resource list
if opt.dataset == 'brats':
    profile_input = torch.randn(1, 4, opt.spatial_size, opt.spatial_size,
                                opt.spatial_size).cuda()
elif opt.dataset == 'shapenet':
    profile_input = torch.randn(1, 1, opt.prune_spatial_size,
                                opt.prune_spatial_size,
                                opt.prune_spatial_size).cuda()
elif opt.dataset == 'ucf101':
    # the batch is 8, not 1, for mobilenetv2 in the rebuttal
    profile_input = torch.randn(1, in_channels, opt.sample_duration,
                                opt.sample_size, opt.sample_size).cuda()
elif opt.dataset == 'sceneflow':
    profile_input_L = torch.randn(3, 3, 256, 512).cuda()
    # right-image input assumed to mirror the left, as in main() below
    profile_input_R = torch.randn(3, 3, 256, 512).cuda()
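# Sketch of how these dataset-specific inputs feed profile(): the stereo case
# appears verbatim in main() below; the single-input form is an assumption
# based on the same API.
if opt.dataset == 'sceneflow':
    profile_inputs = (profile_input_L, profile_input_R)
else:
    profile_inputs = (profile_input,)
flops, params, memory, resource_list = profile(
    model, inputs=profile_inputs, verbose=False,
    resource_list_type=opt.resource_list_type, mode=opt.statistic_mode)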
def main(args):
    # Dataloader
    all_left_img, all_right_img, all_left_disp, test_left_img, \
        test_right_img, test_left_disp = lt.dataloader(args.datapath)

    # batch_size=12
    TrainImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(all_left_img, all_right_img, all_left_disp, True),
        batch_size=12, shuffle=True, num_workers=args.batch * 2,
        drop_last=False)

    # batch_size=4
    TestImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False),
        batch_size=4, shuffle=False, num_workers=8, drop_last=False)

    # Model
    model = stackhourglass(args.maxdisp)
    weight_init(model, mode=args.weight_init)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))

    # =========================================================
    # Calculate Profile of Full Model
    model_full = copy.deepcopy(model)
    profile_input_L = torch.randn(3, 3, 256, 512).cuda()
    profile_input_R = torch.randn(3, 3, 256, 512).cuda()
    flops_full, params_full, memory_full, resource_list = \
        profile(model_full.cuda(), inputs=(profile_input_L, profile_input_R),
                verbose=False, resource_list_type=args.resource_list_type,
                mode=args.statistic_mode)
    del model_full
    print('Full model, flops: {:.4f}G, params: {:.4f}MB, memory: {:.4f}MB'
          .format(flops_full / 1e9, params_full * 4 / (1024 ** 2),
                  memory_full * 4 / (1024 ** 2)))

    # Prune, including kernels and hidden layers
    if args.enable_neuron_prune or args.enable_hidden_layer_prune \
            or args.enable_param_prune:
        args.spatial_size = args.prune_spatial_size
        args.dim = args.prune_spatial_size
        grad_mode = 'raw' if args.enable_raw_grad else 'abs'
        if args.weight_init == 'xn':
            file_path = 'data/stereo/stereo_kernel_hidden_prune_grad_{}.npy'.format(
                grad_mode)
        else:
            file_path = 'data/stereo/stereo_kernel_hidden_prune_grad_init{}_{}.npy' \
                .format(args.weight_init, grad_mode)
        # gradients must be gathered with batch 1 when the cache file is absent
        assert os.path.exists(file_path) or args.batch == 1

        outputs = pruning(file_path, model, TrainImgLoader, None, args,
                          enable_3dunet=False, enable_hidden_sum=False,
                          width=None, resource_list=resource_list,
                          network_name='psm')
        assert outputs[0] == 0
        # neuron_mask_clean, hidden_mask = outputs[1], outputs[2]
        valid_neuron_list_clean, hidden_mask = outputs[1], outputs[2]
        if args.enable_neuron_prune or args.enable_param_prune:
            n_params_org, n_neurons_org = do_statistics_model(model)
            new_model = refine_model_PSM(model, valid_neuron_list_clean)
            if False:  # enable_dump_neuron_per_layer
                dump_neuron_per_layer(copy.deepcopy(model),
                                      copy.deepcopy(new_model))
            del model
            model = new_model.cpu()
            weight_init(model, mode=args.weight_init)
            model = model.cuda()
            n_params_refined, n_neurons_refined = do_statistics_model(model)
            print('Statistics, org, params: {}, neurons: {}; refined, '
                  'params: {} ({:.4f}%), neurons: {} ({:.4f}%)'
                  .format(n_params_org, n_neurons_org,
                          n_params_refined,
                          n_params_refined * 100 / n_params_org,
                          n_neurons_refined,
                          n_neurons_refined * 100 / n_neurons_org))

            # Calculate Flops, Params, Memory of New Model
            model_flops = copy.deepcopy(model)
            # display_model_structure(model_flops)
            flops, params, memory, _ = profile(
                model_flops.cuda(),
                inputs=(profile_input_L, profile_input_R), verbose=False,
                resource_list_type=args.resource_list_type,
                mode=args.statistic_mode)
            print('New model, flops: {:.4f}G, params: {:.4f}MB, memory: {:.4f}MB'
                  .format(flops / 1e9, params * 4 / (1024 ** 2),
                          memory * 4 / (1024 ** 2)))
            del model_flops, profile_input_L, profile_input_R

            # Rebuild the optimizer for the new model
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   betas=(0.9, 0.999))
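    # The MB figures printed above assume float32 storage: `params` and
    # `memory` are element counts, so each element costs 4 bytes. A tiny
    # helper making that conversion explicit (hypothetical, not in the repo):
    def count_to_mb(n_elements, bytes_per_element=4):
        return n_elements * bytes_per_element / (1024 ** 2)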
    # =========================================================
    # print(model)
    if False:
        num_layer_2D, num_layer_3D, num_neuron_2D, num_neuron_3D = \
            check_neuron_ratio(model)
        print('Num layer 2D/3D:', num_layer_2D, num_layer_3D,
              ', num neuron 2D/3D:', num_neuron_2D, num_neuron_3D)

    print(args)
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # Resume
    if args.loadmodel is not None and os.path.isfile(args.loadmodel):
        state_dict = torch.load(args.loadmodel,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict['model'])
        optimizer = load_optimizer_state_dict(state_dict['optimizer'],
                                              optimizer,
                                              enable_cuda=args.cuda)

    if args.enable_train:
        if args.cuda:
            model = nn.DataParallel(model)
            model.cuda()

        # Train
        train_duration = 0
        epoch_start = int(args.loadmodel.split('.')[-2].split('_')[-1]) \
            if (args.loadmodel is not None) else 0
        train_len = len(TrainImgLoader)
        for epoch in range(epoch_start + 1, args.epochs + 1):
            print('This is %d-th epoch' % (epoch))
            start_full_time = time.time()
            torch.manual_seed(epoch)
            if args.cuda:
                torch.cuda.manual_seed(epoch)
            total_train_loss = 0
            # adjust_learning_rate(optimizer, epoch)

            for batch_idx, (imgL_crop, imgR_crop, disp_crop_L, _) \
                    in enumerate(TrainImgLoader):
                start_time = time.time()
                loss = train(imgL_crop, imgR_crop, disp_crop_L, model,
                             optimizer)
                print('Iter %d/%d training loss = %.3f , time = %.2f'
                      % (batch_idx, train_len, loss,
                         time.time() - start_time))
                total_train_loss += loss.item()
            train_duration += time.time() - start_full_time
            print('Epoch %d total training loss = %.3f'
                  % (epoch, total_train_loss / len(TrainImgLoader)))

            # SAVE
            if not os.path.exists(args.savemodel):
                os.makedirs(args.savemodel)
            savefilename = os.path.join(args.savemodel,
                                        'checkpoint_' + str(epoch) + '.tar')
            model_state_dict = get_model_state_dict(model)
            torch.save(
                {
                    'epoch': epoch,
                    'model': model_state_dict,
                    'train_loss': total_train_loss / len(TrainImgLoader),
                    'optimizer': optimizer.state_dict()
                }, savefilename)

            # Valid
            if epoch >= args.valid_min_epoch:
                valid_acc, total_test_loss = validate(TestImgLoader, model)
                print('Epoch: {}, acc1: {:.4f}, acc2: {:.4f}, acc3: {:.4f}, '
                      'acc5: {:.4f}, epe: {:.4f}; test loss: {:.4f}'
                      .format(epoch, valid_acc[0].avg, valid_acc[1].avg,
                              valid_acc[2].avg, valid_acc[3].avg,
                              valid_acc[4].avg,
                              total_test_loss / len(TestImgLoader)))

        print('Full training time = %.2f HR' % (train_duration / 3600))

    elif args.enable_test:
        if os.path.isdir(args.loadmodel):
            model_paths = glob.glob('{}/checkpoint_*.tar'.format(
                args.loadmodel))
        elif os.path.isfile(args.loadmodel):
            model_paths = [args.loadmodel]
        else:
            model_paths = []
        model_paths.sort(key=lambda x: int(x.split('_')[-1].split('.tar')[0]))
        # print(model_paths)
        for model_path in model_paths:
            print(model_path)
            assert os.path.exists(model_path)
            current_epoch = int(model_path.split('_')[-1].split('.tar')[0])
            if current_epoch <= 10:
                continue
            state_dict = torch.load(model_path,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(state_dict['model'])
            if args.cuda:
                model.cuda()
            valid_acc, total_test_loss = validate(TestImgLoader, model)
            print('Epoch: {}, acc1: {:.4f}, acc2: {:.4f}, acc3: {:.4f}, '
                  'acc5: {:.4f}, epe: {:.4f}; test loss: {:.4f}'
                  .format(current_epoch, valid_acc[0].avg, valid_acc[1].avg,
                          valid_acc[2].avg, valid_acc[3].avg,
                          valid_acc[4].avg,
                          total_test_loss / len(TestImgLoader)))
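# NOTE: `get_model_state_dict` and `load_optimizer_state_dict` are helpers not
# shown in this excerpt. Plausible sketches under two assumptions: the former
# unwraps nn.DataParallel so checkpoints are saved without the 'module.'
# prefix, and the latter moves optimizer state tensors back onto the GPU
# after a CPU-mapped load.
def get_model_state_dict_sketch(model):
    return model.module.state_dict() if hasattr(model, 'module') \
        else model.state_dict()

def load_optimizer_state_dict_sketch(saved_state, optimizer, enable_cuda):
    import torch
    optimizer.load_state_dict(saved_state)
    if enable_cuda:
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()
    return optimizer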