Beispiel #1
0
    in_channels, out_channels, final_sigmoid, f_maps = 1, args.n_class, False, 64  # f_map=64

    if args.enable_ssc_unet:
      model = UNet3DSSC()
    else:
      model = UNet3D(in_channels,
                     out_channels,
                     final_sigmoid,
                     f_maps=f_maps,
                     layer_order='cbr',
                     num_groups=8,
                     number_of_fmaps = args.number_of_fmaps,
                     enable_deepmodel_pooling=args.enable_deepmodel_pooling,
                     width=args.width,
                     res_type=args.res_type)
    weight_init(model, mode=args.weight_init)

  # ==== Optimizer
  if args.optimizer == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.w_decay)
  else:
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.w_decay, nesterov=True)

  # Components from SSC
  criterion = nn.CrossEntropyLoss(ignore_index=255)  # ignore empty space marked as 255

  if args.enable_cuda:
    model = model.cuda()

  # Tensorboard
Beispiel #2
0
        if opt.model == 'i3d':
            new_model = refine_model_I3D(model, neuron_mask_clean)
        else:
            new_model = refine_model_classification(
                model,
                neuron_mask_clean,
                opt.model,
                network_connection_dict=network_connection_dict,
                enable_raw_grad=opt.enable_raw_grad)
        del model
        model = new_model.cpu()

        # print('New model:', model)

        if hasattr(model, 'module'):
            weight_init(model.module, mode=opt.weight_init)
        else:
            weight_init(model, mode=opt.weight_init)

        n_params_refined, n_neurons_refined = do_statistics_model(model)
        print('Statistics, org, params: {}, neurons: {}; refined, params: {} ({:.4f}%), neurons: {} ({:.4f}%)' \
              .format(n_params_org,
                      n_neurons_org,
                      n_params_refined,
                      n_params_refined * 100 / n_params_org,
                      n_neurons_refined,
                      n_neurons_refined * 100 / n_neurons_org))

    # Calculate Flops, Params, Memory of New Model
    model_flops = copy.deepcopy(model)
    flops, params, memory, _ = profile(
Beispiel #3
0
        grad_mode = 'raw' if opt.enable_raw_grad else 'abs'
        if opt.weight_init == 'xn':
            file_path = 'data/stereo/stereo_kernel_hidden_prune_grad_{}.npy'.format(
                grad_mode)
        else:
            file_path = 'data/stereo/stereo_kernel_hidden_prune_grad_init{}_{}.npy' \
              .format(opt.weight_init, grad_mode)
        model = stackhourglass(opt.maxdisp)
    else:
        file_path = None
        model = None

    assert (model is not None) and (file_path is not None) \
           and os.path.exists(file_path), file_path

    weight_init(model, mode=opt.weight_init)
    model = model.cuda()

    # Get resource list
    if opt.dataset == 'brats':
        profile_input = torch.randn(1, 4, opt.spatial_size, opt.spatial_size,
                                    opt.spatial_size).cuda()
    elif opt.dataset == 'shapenet':
        profile_input = torch.randn(1, 1, opt.prune_spatial_size,
                                    opt.prune_spatial_size,
                                    opt.prune_spatial_size).cuda()
    elif opt.dataset == 'ucf101':  # the batch is 8 not 1 for mobilenetv2 in the rebuttal
        profile_input = torch.randn(1, in_channels, opt.sample_duration,
                                    opt.sample_size, opt.sample_size).cuda()
    elif opt.dataset == 'sceneflow':
        profile_input_L = torch.randn(3, 3, 256, 512).cuda()
Beispiel #4
0
def main(args):
    """Build, optionally prune, then train or evaluate the ``stackhourglass`` model.

    Steps, in order:
      1. Create the train/test data loaders from ``args.datapath``.
      2. Build the model, initialise its weights and create an Adam optimizer.
      3. Profile a deep copy of the full model (flops, params, memory).
      4. If any pruning switch is set, run ``pruning``; for neuron/parameter
         pruning the model is replaced by the refined one, re-initialised,
         profiled again and given a fresh optimizer.
      5. Restore model and optimizer state when ``args.loadmodel`` is an
         existing file.
      6. Train (``args.enable_train``), writing one checkpoint per epoch into
         ``args.savemodel``, or otherwise evaluate checkpoints
         (``args.enable_test``).

    Args:
        args: Option namespace. Attributes read here: ``datapath``, ``batch``,
            ``maxdisp``, ``weight_init``, ``lr``, ``resource_list_type``,
            ``statistic_mode``, ``enable_neuron_prune``,
            ``enable_hidden_layer_prune``, ``enable_param_prune``,
            ``prune_spatial_size``, ``enable_raw_grad``, ``loadmodel``,
            ``cuda``, ``enable_train``, ``enable_test``, ``epochs``,
            ``savemodel`` and ``valid_min_epoch``. When pruning is enabled,
            ``args.spatial_size`` and ``args.dim`` are overwritten with
            ``args.prune_spatial_size``.

    Returns:
        None.

    Raises:
        AssertionError: If the gradient file is missing while
            ``args.batch != 1``, or if ``pruning`` does not report status 0.

    Note:
        The profiling inputs, the profiled model copies and the refined model
        are moved to CUDA regardless of ``args.cuda``.
    """
    # Dataloader
    all_left_img, all_right_img, all_left_disp, test_left_img, \
      test_right_img, test_left_disp = lt.dataloader(args.datapath)

    # Training loader: batch size is fixed at 12; only the number of workers
    # depends on args.batch.
    TrainImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(all_left_img, all_right_img, all_left_disp, True),
        batch_size=12,
        shuffle=True,
        num_workers=args.batch * 2,
        drop_last=False)

    # Test loader: batch size is fixed at 4.
    TestImgLoader = torch.utils.data.DataLoader(
        DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False),
        batch_size=4,
        shuffle=False,
        num_workers=8,
        drop_last=False)

    def _print_validation(epoch_id, valid_acc, total_test_loss):
        # Shared by the train and test branches so both report the same
        # metrics in exactly the same format.
        print('Epoch: {}, acc1: {:.4f}, acc2: {:.4f}, acc3: {:.4f}, acc5: {:.4f}, epe: {:.4f}; test loss: {:.4f}'
              .format(epoch_id, valid_acc[0].avg, valid_acc[1].avg,
                      valid_acc[2].avg, valid_acc[3].avg, valid_acc[4].avg,
                      total_test_loss / len(TestImgLoader)))

    # Model
    model = stackhourglass(args.maxdisp)
    weight_init(model, mode=args.weight_init)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))

    # =========================================================
    # Calculate Profile of Full Model (on a copy, so `model` itself is untouched)
    model_full = copy.deepcopy(model)
    profile_input_L = torch.randn(3, 3, 256, 512).cuda()
    profile_input_R = torch.randn(3, 3, 256, 512).cuda()

    flops_full, params_full, memory_full, resource_list = \
      profile(model_full.cuda(),
              inputs=(profile_input_L, profile_input_R),
              verbose=False,
              resource_list_type=args.resource_list_type,
              mode=args.statistic_mode)

    del model_full
    # NOTE(review): the `* 4 / 1024 ** 2` conversion presumably assumes 4-byte
    # elements -- confirm against `profile`.
    print('Full model, flops: {:.4f}G, params: {:.4f}MB, memory: {:.4f}MB'
          .format(flops_full / 1e9,
                  params_full * 4 / (1024 ** 2),
                  memory_full * 4 / (1024 ** 2)))

    # Prune including kernels and hidden layers
    if args.enable_neuron_prune or args.enable_hidden_layer_prune or args.enable_param_prune:
        args.spatial_size = args.prune_spatial_size
        args.dim = args.prune_spatial_size
        grad_mode = 'raw' if args.enable_raw_grad else 'abs'

        if args.weight_init == 'xn':
            file_path = 'data/stereo/stereo_kernel_hidden_prune_grad_{}.npy'.format(
                grad_mode)
        else:
            file_path = 'data/stereo/stereo_kernel_hidden_prune_grad_init{}_{}.npy' \
              .format(args.weight_init, grad_mode)

        # Without an existing gradient file, args.batch must be 1.
        assert os.path.exists(file_path) or args.batch == 1, \
            'args.batch must be 1 when {} does not exist'.format(file_path)

        outputs = pruning(file_path,
                          model,
                          TrainImgLoader,
                          None,
                          args,
                          enable_3dunet=False,
                          enable_hidden_sum=False,
                          width=None,
                          resource_list=resource_list,
                          network_name='psm')
        assert outputs[0] == 0

        # outputs[2] (the hidden-layer mask) is not needed in this function.
        valid_neuron_list_clean = outputs[1]

        if args.enable_neuron_prune or args.enable_param_prune:
            n_params_org, n_neurons_org = do_statistics_model(model)
            new_model = refine_model_PSM(model, valid_neuron_list_clean)

            if False:  # enable_dump_neuron_per_layer (manual debug switch)
                dump_neuron_per_layer(copy.deepcopy(model),
                                      copy.deepcopy(new_model))

            # Swap in the refined model and re-initialise its weights on CPU
            # before moving it to the GPU.
            del model
            model = new_model.cpu()
            weight_init(model, mode=args.weight_init)
            model = model.cuda()
            n_params_refined, n_neurons_refined = do_statistics_model(model)
            print('Statistics, org, params: {}, neurons: {}; refined, '
                  'params: {} ({:.4f}%), neurons: {} ({:.4f}%)'
                  .format(n_params_org,
                          n_neurons_org,
                          n_params_refined,
                          n_params_refined * 100 / n_params_org,
                          n_neurons_refined,
                          n_neurons_refined * 100 / n_neurons_org))

            # Calculate Flops, Params, Memory of New Model
            model_flops = copy.deepcopy(model)

            # display_model_structure(model_flops)

            flops, params, memory, _ = profile(
                model_flops.cuda(),
                inputs=(profile_input_L, profile_input_R),
                verbose=False,
                resource_list_type=args.resource_list_type,
                mode=args.statistic_mode)
            print('New model, flops: {:.4f}G, params: {:.4f}MB, memory: {:.4f}MB'
                  .format(flops / 1e9,
                          params * 4 / (1024 ** 2),
                          memory * 4 / (1024 ** 2)))
            del model_flops, profile_input_L, profile_input_R

            # The old optimizer still references the parameters of the
            # discarded model, so build a new one for the refined model.
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   betas=(0.9, 0.999))
    # =========================================================

    # print(model)
    if False:  # manual debug switch
        num_layer_2D, num_layer_3D, num_neuron_2D, num_neuron_3D = check_neuron_ratio(
            model)
        print('Num layer 2D/3D:', num_layer_2D, num_layer_3D,
              ', num neuron 2D/3D:', num_neuron_2D, num_neuron_3D)

    print(args)
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    # Resume
    resumed = False
    if args.loadmodel is not None and os.path.isfile(args.loadmodel):
        # Load onto CPU first; the model is moved to the GPU later if requested.
        state_dict = torch.load(args.loadmodel,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict['model'])
        optimizer = load_optimizer_state_dict(state_dict['optimizer'],
                                              optimizer,
                                              enable_cuda=args.cuda)
        resumed = True

    if args.enable_train:
        if args.cuda:
            model = nn.DataParallel(model)
            model.cuda()

        # Train
        train_duration = 0
        # Continue from the epoch encoded in the checkpoint name
        # ('..._<epoch>.<ext>'), but only when that checkpoint was actually
        # loaded above; otherwise training would skip epochs with untrained
        # weights (or fail to parse a non-checkpoint path).
        if resumed:
            epoch_start = int(args.loadmodel.split('.')[-2].split('_')[-1])
        else:
            epoch_start = 0
        train_len = len(TrainImgLoader)

        for epoch in range(epoch_start + 1, args.epochs + 1):
            print('This is %d-th epoch' % (epoch))
            start_full_time = time.time()
            # Seed with the epoch index so each epoch is reproducible.
            torch.manual_seed(epoch)

            if args.cuda:
                torch.cuda.manual_seed(epoch)

            total_train_loss = 0
            # adjust_learning_rate(optimizer, epoch)

            for batch_idx, (imgL_crop, imgR_crop, disp_crop_L,
                            _) in enumerate(TrainImgLoader):
                start_time = time.time()
                loss = train(imgL_crop, imgR_crop, disp_crop_L, model,
                             optimizer)
                print('Iter %d/%d training loss = %.3f , time = %.2f'
                      % (batch_idx, train_len, loss, time.time() - start_time))
                total_train_loss += loss.item()

            train_duration += time.time() - start_full_time
            mean_train_loss = total_train_loss / train_len
            print('Epoch %d total training loss = %.3f' %
                  (epoch, mean_train_loss))

            # SAVE
            if not os.path.exists(args.savemodel):
                os.makedirs(args.savemodel)

            savefilename = os.path.join(args.savemodel,
                                        'checkpoint_' + str(epoch) + '.tar')
            model_state_dict = get_model_state_dict(model)
            torch.save(
                {
                    'epoch': epoch,
                    'model': model_state_dict,
                    'train_loss': mean_train_loss,
                    'optimizer': optimizer.state_dict()
                }, savefilename)

            # Valid
            if epoch >= args.valid_min_epoch:
                valid_acc, total_test_loss = validate(TestImgLoader, model)
                _print_validation(epoch, valid_acc, total_test_loss)

        print('Full training time = %.2f HR' % (train_duration / 3600))
    elif args.enable_test:
        # args.loadmodel may be a directory of checkpoints, a single
        # checkpoint file, or unset; os.path.isdir/isfile reject None, so
        # handle that case explicitly.
        if args.loadmodel is None:
            model_paths = []
        elif os.path.isdir(args.loadmodel):
            model_paths = glob.glob('{}/checkpoint_*.tar'.format(
                args.loadmodel))
        elif os.path.isfile(args.loadmodel):
            model_paths = [args.loadmodel]
        else:
            model_paths = []

        if not model_paths:
            print('No checkpoint found to evaluate: {}'.format(args.loadmodel))

        # Evaluate in ascending epoch order ('..._<epoch>.tar').
        model_paths.sort(key=lambda x: int(x.split('_')[-1].split('.tar')[0]))
        # print(model_paths)

        for model_path in model_paths:
            print(model_path)
            assert os.path.exists(model_path)
            current_epoch = int(model_path.split('_')[-1].split('.tar')[0])
            # Checkpoints from the first 10 epochs are not evaluated.
            if current_epoch <= 10:
                continue
            state_dict = torch.load(model_path,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(state_dict['model'])
            if args.cuda:
                model.cuda()
            valid_acc, total_test_loss = validate(TestImgLoader, model)
            _print_validation(current_epoch, valid_acc, total_test_loss)