Example #1
def packing_job_ext_info(job_list_DO):
    """
    Pack additional information about the job into job_list_DO (JobListDO).
    """
    ext_info = sqllite_agent.execute(ScrapydJobExtInfoSQLSet.SELECT_BY_ID,
                                     (job_list_DO.job_id, ))
    if ext_info is None or len(ext_info) <= 0:
        return
    ext_info = ext_info[0]
    job_list_DO.args = ext_info[1]
    job_list_DO.priority = ext_info[2]
    job_list_DO.creation_time = ext_info[3]
    job_list_DO.logs_name = str_to_list(ext_info[4], ',')
    job_list_DO.logs_url = str_to_list(ext_info[5], ',')
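Every example on this page relies on a str_to_list helper to turn delimited strings into typed lists. The project-specific implementations are not reproduced here; below is a minimal sketch, assuming the helper splits on the given delimiter, strips whitespace, drops empty fragments, and applies an optional per-element converter such as int or float:

def str_to_list(string, delimiter, converter=str):
    # Minimal sketch (assumption, not the projects' actual implementation):
    # split on the delimiter, strip whitespace, drop empty fragments,
    # and convert each element with the given converter.
    if string is None:
        return []
    return [converter(part.strip())
            for part in string.split(delimiter)
            if part.strip()]

# Hypothetical usage mirroring the calls in the examples:
# str_to_list('0.01,0.001', ',', float)  -> [0.01, 0.001]
# str_to_list('30,30', ',', int)         -> [30, 30]
# str_to_list('log_a,log_b', ',')        -> ['log_a', 'log_b']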
Example #2
def main(config):
    # init logger
    classes = {
        'cifar10': 10,
        'cifar100': 100,
        'mnist': 10,
        'tiny_imagenet': 200,
        'imagenet': 1000
    }
    logger, writer = init_logger(config)

    # build model
    model = models.__dict__[config.network]()
    mb = ModelBase(config.network, config.depth, config.dataset, model)
    mb.cuda()

    # preprocessing
    # ====================================== fetch configs ======================================
    ckpt_path = config.checkpoint_dir
    num_iterations = config.iterations
    target_ratio = config.target_ratio
    normalize = config.normalize
    # ====================================== fetch exception ======================================
    exception = get_exception_layers(mb.model,
                                     str_to_list(config.exception, ',', int))
    logger.info('Exception: ')

    for idx, m in enumerate(exception):
        logger.info('  (%d) %s' % (idx, m))

    # ====================================== fetch training schemes ======================================
    ratio = 1 - (1 - target_ratio)**(1.0 / num_iterations)
    learning_rates = str_to_list(config.learning_rate, ',', float)
    weight_decays = str_to_list(config.weight_decay, ',', float)
    training_epochs = str_to_list(config.epoch, ',', int)
    logger.info(
        'Normalize: %s, Total iteration: %d, Target ratio: %.2f, Iter ratio %.4f.'
        % (normalize, num_iterations, target_ratio, ratio))
    logger.info('Basic Settings: ')
    for idx in range(len(learning_rates)):
        logger.info('  %d: LR: %.5f, WD: %.5f, Epochs: %d' %
                    (idx, learning_rates[idx], weight_decays[idx],
                     training_epochs[idx]))

    # ====================================== get dataloader ======================================
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        config.traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    trainloader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=250,
                                              shuffle=True,
                                              num_workers=16,
                                              pin_memory=True,
                                              sampler=None)

    # ====================================== start pruning ======================================

    for iteration in range(num_iterations):
        logger.info(
            '** Target ratio: %.4f, iter ratio: %.4f, iteration: %d/%d.' %
            (target_ratio, ratio, iteration, num_iterations))

        assert num_iterations == 1
        print("=> Applying weight initialization.")
        mb.model.apply(weights_init)

        masks = GraSP(mb.model,
                      ratio,
                      trainloader,
                      'cuda',
                      num_classes=classes[config.dataset],
                      samples_per_class=config.samples_per_class,
                      num_iters=config.get('num_iters', 1))

        # ========== register mask ==================
        mb.masks = masks
        # ========== save pruned network ============
        logger.info('Saving..')
        state = {
            'net': mb.model,
            'acc': -1,
            'epoch': -1,
            'args': config,
            'mask': mb.masks,
            'ratio': mb.get_ratio_at_each_layer()
        }
        path = os.path.join(
            ckpt_path, 'prune_%s_%s%s_r%s_it%d.pth.tar' %
            (config.dataset, config.network, config.depth, config.target_ratio,
             iteration))
        torch.save(state, path)

        # ========== print pruning details ============
        logger.info('**[%d] Mask and training setting: ' % iteration)
        print_mask_information(mb, logger)
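Example #2 derives the per-iteration pruning ratio from the overall target as ratio = 1 - (1 - target_ratio)**(1.0 / num_iterations), so that removing the same fraction in every round compounds to the overall target. A quick check with assumed values (the example itself asserts num_iterations == 1, in which case ratio equals target_ratio):

target_ratio = 0.9     # assumed overall target: remove 90% of the weights
num_iterations = 3     # assumed number of pruning rounds
ratio = 1 - (1 - target_ratio) ** (1.0 / num_iterations)
# ratio ~= 0.5358; keeping (1 - ratio) of the weights in each of the three
# rounds leaves (1 - ratio) ** 3 = 1 - target_ratio = 0.1 of the network.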
Example #3
def main(config, args):
    # init logger
    classes = {
        'cifar10': 10,
        'cifar100': 100,
        'mnist': 10,
        'tiny_imagenet': 200
    }
    logger, writer = init_logger(config, args)
    best_acc_vec = []
    test_acc_vec_vec = []

    for n_runs in range(1):
        if args.sigma_w2 is not None and n_runs != 0:
            break

        # build model
        model = get_network(config.network,
                            config.depth,
                            config.dataset,
                            use_bn=config.get('use_bn', args.bn),
                            scaled=args.scaled_init,
                            act=args.act)
        mask = None
        mb = ModelBase(config.network, config.depth, config.dataset, model)
        mb.cuda()
        if mask is not None:
            mb.register_mask(mask)
            ratio_vec_ = print_mask_information(mb, logger)

        # preprocessing
        # ====================================== get dataloader ======================================
        trainloader, testloader = get_dataloader(config.dataset,
                                                 config.batch_size, 256, 4)
        # ====================================== fetch configs ======================================
        ckpt_path = config.checkpoint_dir
        num_iterations = config.iterations
        if args.target_ratio is None:
            target_ratio = config.target_ratio
        else:
            target_ratio = args.target_ratio

        normalize = config.normalize
        # ====================================== fetch exception ======================================
        exception = get_exception_layers(
            mb.model, str_to_list(config.exception, ',', int))
        logger.info('Exception: ')

        for idx, m in enumerate(exception):
            logger.info('  (%d) %s' % (idx, m))

        # ====================================== fetch training schemes ======================================
        ratio = 1 - (1 - target_ratio)**(1.0 / num_iterations)
        learning_rates = str_to_list(config.learning_rate, ',', float)
        weight_decays = str_to_list(config.weight_decay, ',', float)
        training_epochs = str_to_list(config.epoch, ',', int)
        logger.info(
            'Normalize: %s, Total iteration: %d, Target ratio: %.2f, Iter ratio %.4f.'
            % (normalize, num_iterations, target_ratio, ratio))
        logger.info('Basic Settings: ')
        for idx in range(len(learning_rates)):
            logger.info('  %d: LR: %.5f, WD: %.5f, Epochs: %d' %
                        (idx, learning_rates[idx], weight_decays[idx],
                         training_epochs[idx]))

        # ====================================== start pruning ======================================
        iteration = 0
        for _ in range(1):
            logger.info(
                '** Target ratio: %.4f, iter ratio: %.4f, iteration: %d/%d.' %
                (target_ratio, ratio, 1, num_iterations))

            # mb.model.apply(weights_init)
            print('#' * 40)
            print('USING {} INIT SCHEME'.format(args.init))
            print('#' * 40)
            if args.init == 'kaiming_xavier':
                mb.model.apply(weights_init_kaiming_xavier)
            elif args.init == 'kaiming':
                if args.act == 'relu' or args.act == 'elu':
                    mb.model.apply(weights_init_kaiming_relu)
                elif args.act == 'tanh':
                    mb.model.apply(weights_init_kaiming_tanh)
            elif args.init == 'xavier':
                mb.model.apply(weights_init_xavier)
            elif args.init == 'EOC':
                mb.model.apply(weights_init_EOC)
            elif args.init == 'ordered':

                def weights_init_ord(m):
                    if isinstance(m, nn.Conv2d):
                        ord_weights(m.weight, sigma_w2=args.sigma_w2)
                        if m.bias is not None:
                            ord_bias(m.bias)
                    elif isinstance(m, nn.Linear):
                        ord_weights(m.weight, sigma_w2=args.sigma_w2)
                        if m.bias is not None:
                            ord_bias(m.bias)
                    elif isinstance(m, nn.BatchNorm2d):
                        # Note that BN's running_var/mean are
                        # already initialized to 1 and 0 respectively.
                        if m.weight is not None:
                            m.weight.data.fill_(1.0)
                        if m.bias is not None:
                            m.bias.data.zero_()

                mb.model.apply(weights_init_ord)
            else:
                raise NotImplementedError

            print("=> Applying weight initialization(%s)." %
                  config.get('init_method', 'kaiming'))
            print("Iteration of: %d/%d" % (iteration, num_iterations))

            if config.pruner == 'SNIP':
                print('=> Using SNIP')
                masks, scaled_masks = SNIP(
                    mb.model,
                    ratio,
                    trainloader,
                    'cuda',
                    num_classes=classes[config.dataset],
                    samples_per_class=config.samples_per_class,
                    num_iters=config.get('num_iters', 1),
                    scaled_init=args.scaled_init)
            elif config.pruner == 'GraSP':
                print('=> Using GraSP')
                masks, scaled_masks = GraSP(
                    mb.model,
                    ratio,
                    trainloader,
                    'cuda',
                    num_classes=classes[config.dataset],
                    samples_per_class=config.samples_per_class,
                    num_iters=config.get('num_iters', 1),
                    scaled_init=args.scaled_init)
            else:
                raise NotImplementedError(
                    'Unknown pruner: {}'.format(config.pruner))
            iteration = 0

            ################################################################################
            _masks = None
            _masks_scaled = None
            if not args.bn:
                # build model that has the same weights as the pruned network but with BN now !
                model2 = get_network(config.network,
                                     config.depth,
                                     config.dataset,
                                     use_bn=config.get('use_bn', True),
                                     scaled=args.scaled_init,
                                     act=args.act)
                weights_temp = []
                for layer_old in mb.model.modules():
                    if isinstance(layer_old, nn.Conv2d) or isinstance(
                            layer_old, nn.Linear):
                        weights_temp.append(layer_old.weight)
                idx = 0
                for layer_new in model2.modules():
                    if isinstance(layer_new, nn.Conv2d) or isinstance(
                            layer_new, nn.Linear):
                        layer_new.weight.data = weights_temp[idx]
                        idx += 1

                # Creating a base model with BN included now
                mb = ModelBase(config.network, config.depth, config.dataset,
                               model2)
                mb.cuda()

                _masks = dict()
                _masks_scaled = dict()
                layer_keys_new = []
                for layer in (mb.model.modules()):
                    if isinstance(layer, nn.Conv2d) or isinstance(
                            layer, nn.Linear):
                        layer_keys_new.append(layer)

                for new_keys, old_keys in zip(layer_keys_new, masks.keys()):
                    _masks[new_keys] = masks[old_keys]
                    if args.scaled_init:
                        _masks_scaled[new_keys] = scaled_masks[old_keys]
            ################################################################################

            if _masks is None:
                _masks = masks
                _masks_scaled = scaled_masks

            # ========== register mask ==================
            mb.register_mask(_masks)

            ## ========== debugging ==================

            if args.scaled_init:
                if config.network == 'vgg':
                    print('scaling VGG')
                    mb.scaling_weights(_masks_scaled)

            # ========== save pruned network ============
            logger.info('Saving..')
            state = {
                'net': mb.model,
                'acc': -1,
                'epoch': -1,
                'args': config,
                'mask': mb.masks,
                'ratio': mb.get_ratio_at_each_layer()
            }
            path = os.path.join(
                ckpt_path, 'prune_%s_%s%s_r%s_it%d.pth.tar' %
                (config.dataset, config.network, config.depth, target_ratio,
                 iteration))
            torch.save(state, path)

            # ========== print pruning details ============
            logger.info('**[%d] Mask and training setting: ' % iteration)
            ratio_vec_ = print_mask_information(mb, logger)
            logger.info('  LR: %.5f, WD: %.5f, Epochs: %d' %
                        (learning_rates[iteration], weight_decays[iteration],
                         training_epochs[iteration]))

            results_path = config.summary_dir + args.init + '_sp' + str(
                args.target_ratio).replace('.', '_')
            if args.scaled_init:
                results_path += '_scaled'
            if args.bn:
                results_path += '_bn'

            if args.sigma_w2 is not None and args.init == 'ordered':
                results_path += '_sgw2{}'.format(args.sigma_w2).replace(
                    '.', '_')

            results_path += '_' + args.act + '_' + str(config.depth)
            print('saving the ratios')
            print(results_path)
            if not os.path.isdir(results_path): os.mkdir(results_path)
            np.save(results_path + '/ratios_pruned{}'.format(args.seed_tiny),
                    np.array(ratio_vec_))

            # if args.sigma_w2 != None:
            # 	break
            # ========== finetuning =======================
            best_acc, test_acc_vec = train_once(
                mb=mb,
                net=mb.model,
                trainloader=trainloader,
                testloader=testloader,
                writer=writer,
                config=config,
                ckpt_path=ckpt_path,
                learning_rate=learning_rates[iteration],
                weight_decay=weight_decays[iteration],
                num_epochs=training_epochs[iteration],
                iteration=iteration,
                logger=logger,
                args=args)

            best_acc_vec.append(best_acc)
            test_acc_vec_vec.append(test_acc_vec)

            np.save(results_path + '/best_acc{}'.format(args.seed_tiny),
                    np.array(best_acc_vec))
            np.save(results_path + '/test_acc{}'.format(args.seed_tiny),
                    np.array(test_acc_vec_vec))
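When Example #3 runs without BatchNorm (args.bn is false), it copies the pruned weights into an equivalent BN-enabled network and re-keys the SNIP/GraSP masks by pairing the prunable layers of the two models in traversal order. The same pattern in isolation, as a hedged sketch (remap_masks_by_layer_order is a hypothetical helper, not part of the project):

import torch.nn as nn

def remap_masks_by_layer_order(new_model, masks):
    # Assumption: new_model contains the same prunable (Conv2d/Linear) layers,
    # in the same order, as the model whose layers key the masks dict.
    prunable = (nn.Conv2d, nn.Linear)
    new_layers = [m for m in new_model.modules() if isinstance(m, prunable)]
    return {new_layer: masks[old_layer]
            for new_layer, old_layer in zip(new_layers, masks.keys())}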
Example #4
def main(config):
    stats = {}
    device = 'cuda'
    criterion = torch.nn.CrossEntropyLoss()

    # config = init_config() if config is None else config
    logger, writer = init_summary_writer(config)
    trainloader, testloader = init_dataloader(config)
    net, bottleneck_net = init_network(config, logger, device)
    pruner = init_pruner(net, bottleneck_net, config, writer, logger)

    # start pruning
    epochs = str_to_list(config.epoch, ',', int)
    learning_rates = str_to_list(config.learning_rate, ',', float)
    weight_decays = str_to_list(config.weight_decay, ',', float)
    ratios = str_to_list(config.ratio, ',', float)

    fisher_type = config.fisher_type  # empirical|true
    fisher_mode = config.fisher_mode  # eigen|full|diagonal
    normalize = config.normalize
    prune_mode = config.prune_mode  # one-pass | iterative
    fix_rotation = config.get('fix_rotation', True)

    assert (len(epochs) == len(learning_rates)
            and len(learning_rates) == len(weight_decays)
            and len(weight_decays) == len(ratios))

    total_parameters = count_parameters(net.train())
    for it in range(len(epochs)):
        epoch = epochs[it]
        lr = learning_rates[it]
        wd = weight_decays[it]
        ratio = ratios[it]
        logger.info('-' * 120)
        logger.info('** [%d], Ratio: %.2f, epoch: %d, lr: %.4f, wd: %.4f' %
                    (it, ratio, epoch, lr, wd))
        logger.info(
            'Reinit: %s, Fisher_mode: %s, fisher_type: %s, normalize: %s, fix_rotation: %s.'
            % (config.re_init, fisher_mode, fisher_type, normalize,
               fix_rotation))
        pruner.fix_rotation = fix_rotation

        # conduct pruning
        cfg = pruner.make_pruned_model(trainloader,
                                       criterion=criterion,
                                       device=device,
                                       fisher_type=fisher_type,
                                       prune_ratio=ratio,
                                       normalize=normalize,
                                       re_init=config.re_init)

        # for tracking the best accuracy
        compression_ratio, unfair_ratio, all_numel, rotation_numel = compute_ratio(
            pruner.model, total_parameters, fix_rotation, logger)
        if config.dataset == 'tiny_imagenet':
            total_flops, rotation_flops = print_model_param_flops(pruner.model,
                                                                  64,
                                                                  cuda=True)
        else:
            total_flops, rotation_flops = print_model_param_flops(pruner.model,
                                                                  32,
                                                                  cuda=True)
        train_loss_pruned, train_acc_pruned = pruner.test_model(
            trainloader, criterion, device)
        test_loss_pruned, test_acc_pruned = pruner.test_model(
            testloader, criterion, device)

        # write results
        logger.info('Before: Accuracy: %.2f%%(train), %.2f%%(test).' %
                    (train_acc_pruned, test_acc_pruned))
        logger.info('        Loss:     %.2f  (train), %.2f  (test).' %
                    (train_loss_pruned, test_loss_pruned))

        test_loss_finetuned, test_acc_finetuned = pruner.fine_tune_model(
            trainloader=trainloader,
            testloader=testloader,
            criterion=criterion,
            optim=optim,
            learning_rate=lr,
            weight_decay=wd,
            nepochs=epoch)
        train_loss_finetuned, train_acc_finetuned = pruner.test_model(
            trainloader, criterion, device)
        logger.info('After:  Accuracy: %.2f%%(train), %.2f%%(test).' %
                    (train_acc_finetuned, test_acc_finetuned))
        logger.info('        Loss:     %.2f  (train), %.2f  (test).' %
                    (train_loss_finetuned, test_loss_finetuned))
        # save model

        stat = {
            'total_flops': total_flops,
            'rotation_flops': rotation_flops,
            'it': it,
            'prune_ratio': ratio,
            'cr': compression_ratio,
            'unfair_cr': unfair_ratio,
            'all_params': all_numel,
            'rotation_params': rotation_numel,
            'prune/train_loss': train_loss_pruned,
            'prune/train_acc': train_acc_pruned,
            'prune/test_loss': test_loss_pruned,
            'prune/test_acc': test_acc_pruned,
            'finetune/train_loss': train_loss_finetuned,
            'finetune/test_loss': test_loss_finetuned,
            'finetune/train_acc': train_acc_finetuned,
            'finetune/test_acc': test_acc_finetuned
        }
        save_model(config, it, pruner, cfg, stat)

        stats[it] = stat

        if prune_mode == 'one_pass':
            del net
            del pruner
            net, bottleneck_net = init_network(config, logger, device)
            pruner = init_pruner(net, bottleneck_net, config, writer, logger)
            pruner.iter = it
        with open(os.path.join(config.summary_dir, 'stats.json'), 'w') as f:
            json.dump(stats, f)
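Example #4 drives a multi-stage prune/finetune schedule from comma-separated config strings; the assert requires that epoch, learning_rate, weight_decay and ratio expand to lists of equal length, so that each index selects one stage. A small illustration with made-up config values (not taken from any of the projects):

config_epoch = '30,30'
config_learning_rate = '0.01,0.001'
config_weight_decay = '5e-4,5e-4'
config_ratio = '0.5,0.75'

epochs = str_to_list(config_epoch, ',', int)                    # [30, 30]
learning_rates = str_to_list(config_learning_rate, ',', float)  # [0.01, 0.001]
weight_decays = str_to_list(config_weight_decay, ',', float)    # [0.0005, 0.0005]
ratios = str_to_list(config_ratio, ',', float)                  # [0.5, 0.75]

assert len(epochs) == len(learning_rates) == len(weight_decays) == len(ratios)
# Stage 0: prune to ratio 0.5, then finetune for 30 epochs at lr 0.01.
# Stage 1: prune to ratio 0.75, then finetune for 30 epochs at lr 0.001.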
Example #5
def main(config):
    # init logger
    classes = {
        'cifar10': 10,
        'cifar100': 100,
        'mnist': 10,
        'tiny_imagenet': 200
    }
    logger, writer = init_logger(config)

    # build model
    model = get_network(config.network, config.depth, config.dataset, use_bn=config.get('use_bn', True))
    mask = None
    mb = ModelBase(config.network, config.depth, config.dataset, model)
    mb.cuda()
    if mask is not None:
        mb.register_mask(mask)
        print_mask_information(mb, logger)

    # preprocessing
    # ====================================== get dataloader ======================================
    trainloader, testloader = get_dataloader(config.dataset, config.batch_size, 256, 4, root='/home/wzn/PycharmProjects/GraSP/data')
    # ====================================== fetch configs ======================================
    ckpt_path = config.checkpoint_dir
    num_iterations = config.iterations
    target_ratio = config.target_ratio
    normalize = config.normalize
    # ====================================== fetch exception ======================================
    # exception = get_exception_layers(mb.model, str_to_list(config.exception, ',', int))
    # logger.info('Exception: ')
    #
    # for idx, m in enumerate(exception):
    #     logger.info('  (%d) %s' % (idx, m))

    # ====================================== fetch training schemes ======================================
    ratio = 1 - (1 - target_ratio) ** (1.0 / num_iterations)
    learning_rates = str_to_list(config.learning_rate, ',', float)
    weight_decays = str_to_list(config.weight_decay, ',', float)
    training_epochs = str_to_list(config.epoch, ',', int)
    logger.info('Normalize: %s, Total iteration: %d, Target ratio: %.2f, Iter ratio %.4f.' %
                (normalize, num_iterations, target_ratio, ratio))
    logger.info('Basic Settings: ')
    for idx in range(len(learning_rates)):
        logger.info('  %d: LR: %.5f, WD: %.5f, Epochs: %d' % (idx,
                                                              learning_rates[idx],
                                                              weight_decays[idx],
                                                              training_epochs[idx]))

    # ====================================== start pruning ======================================
    iteration = 0
    for _ in range(1):
        # logger.info('** Target ratio: %.4f, iter ratio: %.4f, iteration: %d/%d.' % (target_ratio,
        #                                                                             ratio,
        #                                                                             1,
        #                                                                             num_iterations))

        mb.model.apply(weights_init)
        print("=> Applying weight initialization(%s)." % config.get('init_method', 'kaiming'))


        # print("Iteration of: %d/%d" % (iteration, num_iterations))
        # masks = GraSP(mb.model, ratio, trainloader, 'cuda',
        #               num_classes=classes[config.dataset],
        #               samples_per_class=config.samples_per_class,
        #               num_iters=config.get('num_iters', 1))
        # iteration = 0
        # print('=> Using GraSP')
        # # ========== register mask ==================
        # mb.register_mask(masks)
        # # ========== save pruned network ============
        # logger.info('Saving..')
        # state = {
        #     'net': mb.model,
        #     'acc': -1,
        #     'epoch': -1,
        #     'args': config,
        #     'mask': mb.masks,
        #     'ratio': mb.get_ratio_at_each_layer()
        # }
        # path = os.path.join(ckpt_path, 'prune_%s_%s%s_r%s_it%d.pth.tar' % (config.dataset,
        #                                                                    config.network,
        #                                                                    config.depth,
        #                                                                    config.target_ratio,
        #                                                                    iteration))
        # torch.save(state, path)

        # # ========== print pruning details ============
        # logger.info('**[%d] Mask and training setting: ' % iteration)
        # print_mask_information(mb, logger)
        # logger.info('  LR: %.5f, WD: %.5f, Epochs: %d' %
        #             (learning_rates[iteration], weight_decays[iteration], training_epochs[iteration]))

        # ========== finetuning =======================
        train_once(mb=mb,
                   net=mb.model,
                   trainloader=trainloader,
                   testloader=testloader,
                   writer=writer,
                   config=config,
                   ckpt_path=ckpt_path,
                   learning_rate=learning_rates[iteration],
                   weight_decay=weight_decays[iteration],
                   num_epochs=training_epochs[iteration],
                   iteration=iteration,
                   logger=logger)
Example #6
def main(config):
    stats = {}
    if config.data_distributed:
        torch.distributed.init_process_group(backend="nccl")
    device = torch.device('cuda:0')  # 'cuda:0,1' is not a valid device string; use the primary GPU
    criterion = torch.nn.CrossEntropyLoss()

    logger, writer = init_summary_writer(config)
    trainloader, testloader = init_dataloader(config)

    if config.data_distributed:
        trainset, testset = trainloader, testloader
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(testset)
        trainloader = torch.utils.data.DataLoader(
            trainset,
            config.batch_size,
            False,
            num_workers=config.num_workers,
            pin_memory=True,
            drop_last=True,
            sampler=train_sampler)
        testloader = torch.utils.data.DataLoader(
            testset,
            config.batch_size,
            False,
            num_workers=config.num_workers,
            pin_memory=True,
            drop_last=True,
            sampler=test_sampler)

    hess_data = []
    if config.dataset == 'imagenet':
        hessianloader = get_hessianloader(config.dataset, 64)
        num_batch = config.hessian_batch_size // 64
        i = 0
        for data, label in hessianloader:
            i += 1
            hess_data.append((data, label))
            if i == num_batch:
                break
    else:
        hessianloader = get_hessianloader(config.dataset,
                                          config.hessian_batch_size)
        for data, label in hessianloader:
            hess_data = (data, label)

    net, bottleneck_net = init_network(config, logger, device,
                                       config.dataset == "imagenet")

    pruner = init_pruner(net, bottleneck_net, config, writer, logger)

    # total FLOPs calculation
    if config.dataset == 'tiny_imagenet':
        total_flops, _ = compute_model_param_flops(pruner.model, 64, cuda=True)
    elif config.dataset == 'imagenet':
        total_flops, _ = compute_model_param_flops(pruner.model,
                                                   224,
                                                   cuda=True)
    else:
        total_flops, _ = compute_model_param_flops(pruner.model, 32, cuda=True)

    # start pruning
    epochs = str_to_list(config.epoch, ',', int)
    learning_rates = str_to_list(config.learning_rate, ',', float)
    weight_decays = str_to_list(config.weight_decay, ',', float)
    ratios = str_to_list(config.ratio, ',', float)

    fisher_type = config.fisher_type  # empirical|true
    fisher_mode = config.fisher_mode  # eigen|full|diagonal
    normalize = config.normalize
    prune_mode = config.prune_mode  # one-pass | iterative
    fix_rotation = config.get('fix_rotation', True)

    assert (len(epochs) == len(learning_rates)
            and len(learning_rates) == len(weight_decays)
            and len(weight_decays) == len(ratios))

    total_parameters = count_parameters(net)
    for it in range(len(epochs)):
        epoch = epochs[it]
        lr = learning_rates[it]
        wd = weight_decays[it]
        ratio = ratios[it]
        logger.info('-' * 120)
        logger.info('** [%d], Ratio: %.2f, epoch: %d, lr: %.4f, wd: %.4f' %
                    (it, ratio, epoch, lr, wd))
        logger.info(
            'Reinit: %s, Fisher_mode: %s, fisher_type: %s, normalize: %s, fix_rotation: %s.'
            % (config.re_init, fisher_mode, fisher_type, normalize,
               fix_rotation))
        pruner.fix_rotation = fix_rotation

        # test pretrained model
        if config.init_test:
            train_loss_pruned, train_acc_pruned, top5_acc = pruner.test_model(
                trainloader, criterion, device)
            test_loss_pruned, test_acc_pruned, top5_acc = pruner.test_model(
                testloader, criterion, device)
            logger.info('Pretrain: Accuracy: %.2f%%(train), %.2f%%(test).' %
                        (train_acc_pruned, test_acc_pruned))
            logger.info('          Loss:     %.2f  (train), %.2f  (test).' %
                        (train_loss_pruned, test_loss_pruned))

        # conduct pruning
        if 'hessian' not in config.fisher_mode:
            cfg = pruner.make_pruned_model(trainloader,
                                           criterion=criterion,
                                           device=device,
                                           fisher_type=fisher_type,
                                           prune_ratio=ratio,
                                           normalize=normalize,
                                           re_init=config.re_init)
        else:
            cfg = pruner.make_pruned_model(hess_data,
                                           criterion=criterion,
                                           device=device,
                                           fisher_type=fisher_type,
                                           prune_ratio=ratio,
                                           normalize=normalize,
                                           re_init=config.re_init,
                                           n_v=config.nv)
        print(pruner.model)
        # for tracking the best accuracy
        compression_ratio, unfair_ratio, all_numel, rotation_numel = compute_ratio(
            pruner.model, total_parameters, fix_rotation, logger)
        if config.dataset == 'tiny_imagenet':
            remained_flops, rotation_flops = compute_model_param_flops(
                pruner.model, 64, cuda=True)
            logger.info(
                '  + Remained FLOPs: %.4fG(%.2f%%), Total FLOPs: %.4fG' %
                (remained_flops / 1e9, 100. * remained_flops / total_flops,
                 total_flops / 1e9))
        elif config.dataset == 'imagenet':
            remained_flops, rotation_flops = compute_model_param_flops(
                pruner.model, 224, cuda=True)
            logger.info(
                '  + Remained FLOPs: %.4fG(%.2f%%), Total FLOPs: %.4fG' %
                (remained_flops / 1e9, 100. * remained_flops / total_flops,
                 total_flops / 1e9))
        else:
            remained_flops, rotation_flops = compute_model_param_flops(
                pruner.model, 32, cuda=True)
            logger.info(
                '  + Remained FLOPs: %.4fG(%.2f%%), Total FLOPs: %.4fG' %
                (remained_flops / 1e9, 100. * remained_flops / total_flops,
                 total_flops / 1e9))

        logger.info(f"Total Flops: {remained_flops}")

        test_loss_pruned, test_acc_pruned, top5_acc = pruner.test_model(
            testloader, criterion, device)
        if config.dataset != 'imagenet':
            time_loader = get_hessianloader(config.dataset, 1)
            run_time = pruner.speed_model(time_loader, criterion, device)
            logger.info(f"Total Run Time: {run_time}")

        test_loss_finetuned, test_acc_finetuned = pruner.fine_tune_model(
            trainloader=trainloader,
            testloader=testloader,
            criterion=criterion,
            optim=optim,
            learning_rate=lr,
            weight_decay=wd,
            nepochs=epoch)
        train_loss_finetuned, train_acc_finetuned, top5_acc = pruner.test_model(
            trainloader, criterion, device)
        logger.info(
            f'After {config.dataset, config.network, config.depth}:  Accuracy: %.2f%%(train), %.2f%%.'
            % (train_acc_finetuned, test_acc_finetuned))
        logger.info('        Loss:     %.2f  (train), %.2f  .' %
                    (train_loss_finetuned, test_loss_finetuned))

        stat = {
            'total_flops': total_flops,
            'rotation_flops': rotation_flops,
            'flops_remained': float(100. * remained_flops / total_flops),
            'it': it,
            'prune_ratio': ratio,
            'cr': compression_ratio,
            'unfair_cr': unfair_ratio,
            'all_params': all_numel,
            'rotation_params': rotation_numel,
            'prune/test_loss': test_loss_pruned,
            'prune/test_acc': test_acc_pruned,
            'finetune/train_loss': train_loss_finetuned,
            'finetune/test_loss': test_loss_finetuned,
            'finetune/train_acc': train_acc_finetuned,
            'finetune/test_acc': test_acc_finetuned
        }

        print('saving checkpoint')
        save_model(config, it, pruner, cfg, stat)

        stats[it] = stat

        if prune_mode == 'one_pass':
            print('one_pass')
            del net
            del pruner
            net, bottleneck_net = init_network(config, logger, device,
                                               config.dataset == "imagenet")
            pruner = init_pruner(net, bottleneck_net, config, writer, logger)
            pruner.iter = it
        with open(
                os.path.join(config.saving_log, f'stats_{running_time}.json'),
                'w') as f:
            json.dump(stats, f)
        if prune_mode != 'one_pass':
            with open(os.path.join(config.saving_log, f'stats{it}.json'),
                      'w') as f:
                json.dump(stats, f)