def __init__(self,
                 args,
                 loader,
                 my_model,
                 my_loss,
                 ckp=None,
                 writer=None,
                 converging=False,
                 model_teacher=None):
        self.args = args
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.model_teacher = model_teacher
        self.loss = my_loss
        self.writer = writer
        self.loss_mse = (nn.MSELoss()
                         if self.args.distillation_inter == 'kd' else None)
        if args.data_train.find('CIFAR') >= 0:
            self.input_dim = (3, 32, 32)
        elif args.data_train.find('Tiny_ImageNet') >= 0:
            self.input_dim = (3, 64, 64)
        else:
            self.input_dim = (3, 224, 224)
        set_output_dimension(self.model.get_model(), self.input_dim)
        self.flops = get_flops(self.model.get_model())
        self.flops_prune = self.flops  # at initialization, no pruning is conducted.
        self.flops_compression_ratio = self.flops_prune / self.flops
        self.params = get_parameters(self.model.get_model())
        self.params_prune = self.params
        self.params_compression_ratio = self.params_prune / self.params
        self.flops_ratio_log = []
        self.params_ratio_log = []
        self.converging = converging
        self.ckp.write_log(
            '\nThe computational complexity and number of parameters of the current network are as follows.'
            '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]'.format(
                self.flops / 10.**9, self.params / 10.**3))
        self.flops_another = get_model_flops(self.model.get_model(),
                                             self.input_dim, False)
        self.ckp.write_log(
            'Flops: {:.4f} [G] calculated by the original counter.\nMake sure that the two calculated '
            'Flops are the same.\n'.format(self.flops_another / 10.**9))
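        # A hedged addition (not in the original code): instead of relying on a
        # manual check of the log, warn automatically when the two counters
        # disagree beyond floating-point noise.
        if abs(self.flops - self.flops_another) > 1e-6 * max(self.flops, 1.0):
            self.ckp.write_log(
                'Warning: the two Flops counters disagree: {:.4f} vs. {:.4f} [G].'.format(
                    self.flops / 10.**9, self.flops_another / 10.**9))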

        self.optimizer = utility.make_optimizer_dhp(args,
                                                    self.model,
                                                    ckp=ckp,
                                                    converging=converging)
        self.scheduler = utility.make_scheduler_dhp(args,
                                                    self.optimizer,
                                                    args.decay.split('+')[0],
                                                    converging=converging)
        self.device = torch.device('cpu' if args.cpu else 'cuda')

        if args.model.find('INQ') >= 0:
            self.inq_steps = args.inq_steps
        else:
            self.inq_steps = None
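
For reference, a minimal self-contained sketch of what the complexity report above computes, counting convolution multiply-accumulates with plain PyTorch forward hooks (the repository's get_flops/get_parameters helpers are more general; this function and its name are purely illustrative):

import torch
import torch.nn as nn

def report_complexity(model: nn.Module, input_dim=(3, 32, 32)):
    """Count parameters and conv MACs for one dummy forward pass."""
    params = sum(p.numel() for p in model.parameters())
    macs = [0]

    def conv_hook(module, inputs, output):
        # MACs per output element: (in_channels / groups) * kernel area
        kh, kw = module.kernel_size
        macs[0] += output.numel() * (module.in_channels // module.groups) * kh * kw

    hooks = [m.register_forward_hook(conv_hook)
             for m in model.modules() if isinstance(m, nn.Conv2d)]
    model.eval()
    with torch.no_grad():
        model(torch.zeros(1, *input_dim))
    for h in hooks:
        h.remove()
    print('Flops: {:.4f} [G]\tParams: {:.2f} [k]'.format(
        macs[0] / 10.**9, params / 10.**3))
    return macs[0], params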
Example #2
    def __init__(self,
                 args,
                 loader,
                 my_model,
                 my_loss,
                 ckp,
                 writer=None,
                 converging=False):
        self.args = args
        self.scale = args.scale
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.writer = writer
        self.optimizer = utility.make_optimizer_dhp(args,
                                                    self.model,
                                                    ckp,
                                                    converging=converging)
        self.scheduler = utility.make_scheduler_dhp(
            args,
            self.optimizer,
            int(args.lr_decay_step.split('+')[0]),
            converging=converging)
        # U-Net and DnCNN take single-channel (grayscale) input here; every
        # other model is assumed to take RGB input.
        model_name = self.args.model.lower()
        if model_name.find('unet') >= 0 or model_name.find('dncnn') >= 0:
            self.input_dim = (1, args.input_dim, args.input_dim)
        else:
            self.input_dim = (3, args.input_dim, args.input_dim)
        set_output_dimension(self.model.get_model(), self.input_dim)
        self.flops = get_flops(self.model.get_model())
        self.flops_prune = self.flops  # at initialization, no pruning is conducted.
        self.flops_compression_ratio = self.flops_prune / self.flops
        self.params = get_parameters(self.model.get_model())
        self.params_prune = self.params
        self.params_compression_ratio = self.params_prune / self.params
        self.flops_ratio_log = []
        self.params_ratio_log = []
        self.converging = converging
        self.ckp.write_log(
            '\nThe computational complexity and number of parameters of the current network are as follows.'
            '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]'.format(
                self.flops / 10.**9, self.params / 10.**3))
        self.flops_another = get_model_flops(self.model.get_model(),
                                             self.input_dim, False)
        self.ckp.write_log(
            'Flops: {:.4f} [G] calculated by the original counter.\nMake sure that the two calculated '
            'Flops are the same.\n'.format(self.flops_another / 10.**9))

        self.error_last = 1e8
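
The '+'-separated decay convention passed to make_scheduler_dhp (args.decay.split('+')[0] in Example #1, int(args.lr_decay_step.split('+')[0]) here) is easy to illustrate in isolation; the schedule string below is a hypothetical value, not one taken from the repository:

decay = '150+225+300'                        # hypothetical lr_decay_step value
milestones = [int(s) for s in decay.split('+')]
print(milestones)                            # [150, 225, 300]
print(milestones[0])                         # 150 -> the milestone handed to the scheduler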
Example #3
    def __init__(self, args, loader, my_model, my_loss, ckp, writer=None):
        self.args = args
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.writer = writer

        if args.data_train.find('CIFAR') >= 0:
            self.input_dim = (3, 32, 32)
        elif args.data_train.find('Tiny_ImageNet') >= 0:
            self.input_dim = (3, 64, 64)
        else:
            self.input_dim = (3, 224, 224)
        set_output_dimension(self.model.get_model(), self.input_dim)
        self.flops = get_flops(self.model.get_model())
        self.params = get_parameters(self.model.get_model())
        self.ckp.write_log(
            '\nThe computational complexity and number of parameters of the current network are as follows.'
            '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]'.format(
                self.flops / 10.**9, self.params / 10.**3))
        self.flops_another = get_model_flops(self.model.get_model(),
                                             self.input_dim, False)
        self.ckp.write_log(
            'Flops: {:.4f} [G] calculated by the original counter.\nMake sure that the two calculated '
            'Flops are the same.\n'.format(self.flops_another / 10.**9))

        self.optimizer = utility.make_optimizer(args, self.model, ckp=ckp)
        self.scheduler = utility.make_scheduler(args, self.optimizer)
        self.device = torch.device('cpu' if args.cpu else 'cuda')

        if args.model.find('INQ') >= 0:
            self.inq_steps = args.inq_steps
        else:
            self.inq_steps = None
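
The dataset-name branching repeated in Examples #1 and #3 could be factored into a small helper; a sketch (the function name is ours, the mapping is taken verbatim from the examples above):

def infer_input_dim(data_train: str):
    """Map the training-set name to the dummy input size used for complexity analysis."""
    if data_train.find('CIFAR') >= 0:
        return (3, 32, 32)
    elif data_train.find('Tiny_ImageNet') >= 0:
        return (3, 64, 64)
    return (3, 224, 224)  # ImageNet-style default

assert infer_input_dim('CIFAR100') == (3, 32, 32)
assert infer_input_dim('ImageNet') == (3, 224, 224)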
Example #4
    # ==================================================================================================================
    # Step 3: Pruning -> prune the derived sparse model and prepare the trainer instance for finetuning or testing
    # ==================================================================================================================
    t.reset_after_optimization()
    if args.print_model:
        print(t.model.get_model())
        print(t.model.get_model(), file=checkpoint.log_file)

    # ==================================================================================================================
    # Step 4: Continue the training / Testing -> continue training the pruned model to recover a higher accuracy.
    # ==================================================================================================================
    while not t.terminate():
        t.train()
        t.test()

    set_output_dimension(model.get_model(), t.input_dim)
    flops = get_flops(model.get_model())
    params = get_parameters(model.get_model())
    print(
        '\nThe computational complexity and number of parameters of the current network are as follows.'
        '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]\n'.format(
            flops / 10.**9, params / 10.**3))

    if args.summary:
        t.writer.close()
    if args.print_model:
        print(t.model.get_model())
        print(t.model.get_model(), file=checkpoint.log_file)
    # for m in t.model.parameters():
    #     print(m.shape)
    checkpoint.done()
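
The Step 4 loop in Examples #4 and #5 relies on the trainer's terminate() returning True once the epoch budget is spent. A toy stand-in showing the assumed contract (MiniTrainer is purely illustrative and not part of the repository):

class MiniTrainer:
    def __init__(self, epochs):
        self.epochs, self.epoch = epochs, 0

    def train(self):
        self.epoch += 1          # one epoch of finetuning would run here

    def test(self):
        pass                     # evaluation after every epoch

    def terminate(self):
        return self.epoch >= self.epochs

t = MiniTrainer(epochs=3)
while not t.terminate():         # the same driving loop as in Step 4
    t.train()
    t.test()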
Example #5
    # ==================================================================================================================
    # Step 3: Pruning -> prune the derived sparse model and prepare the trainer instance for finetuning or testing
    # ==================================================================================================================

    t.reset_after_searching()
    if args.print_model:
        print(t.model.get_model())
        print(t.model.get_model(), file=checkpoint.log_file)

    # ==================================================================================================================
    # Step 4: Finetuning/Testing -> finetune the pruned model to recover a higher accuracy.
    # ==================================================================================================================

    while not t.terminate():
        t.train()
        t.test()

    set_output_dimension(network_model.get_model(), t.input_dim)
    flops = get_flops(network_model.get_model())
    params = get_parameters(network_model.get_model())
    print(
        '\nThe computational complexity and number of parameters of the current network are as follows.'
        '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]\n'.format(
            flops / 10.**9, params / 10.**3))

    if args.summary:
        t.writer.close()
    if args.print_model:
        print(t.model.get_model())
        print(t.model.get_model(), file=checkpoint.log_file)
    checkpoint.done()
Example #6
    # model
    # if args.decomp_type == 'gsvd' or args.decomp_type == 'svd-mse' or args.comp_rule.find('f-norm') >= 0 \
    #         or args.model.lower().find('prune_resnet56') >= 0:
    #     my_model = Model(args, checkpoint, loader.loader_train)
    # else:
    my_model = Model(args, checkpoint)

    if args.data_train.find('CIFAR') >= 0:
        input_dim = (3, 32, 32)
    elif args.data_train.find('Tiny_ImageNet') >= 0:
        input_dim = (3, 64, 64)
    else:
        input_dim = (3, 224, 224)

    set_output_dimension(my_model.get_model(), input_dim)
    flops = get_flops(my_model.get_model())
    params = get_parameters(my_model.get_model())
    print('\nThe computational complexity and number of parameters of the current network are as follows.'
          '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]'.format(flops / 10.**9, params / 10.**3))
    flops_another = get_model_flops(my_model.get_model(), input_dim, False)
    print('Flops: {:.4f} [G] calculated by the original counter.\nMake sure that the two calculated '
          'Flops are the same.\n'.format(flops_another / 10.**9))

    # data loader
    loader = Data(args)
    # loss function
    loss = Loss(args, checkpoint)
    # writer
    writer = SummaryWriter(os.path.join(args.dir_save, args.save), comment='optimization') if args.summary else None
    # trainer