def __init__(self, args, loader, my_model, my_loss, ckp=None, writer=None,
             converging=False, model_teacher=None):
    self.args = args
    self.ckp = ckp
    self.loader_train = loader.loader_train
    self.loader_test = loader.loader_test
    self.model = my_model
    self.model_teacher = model_teacher
    self.loss = my_loss
    self.writer = writer
    # MSE loss for intermediate-feature knowledge distillation.
    self.loss_mse = nn.MSELoss() if self.args.distillation_inter == 'kd' else None

    # Input resolution depends on the training dataset.
    if args.data_train.find('CIFAR') >= 0:
        self.input_dim = (3, 32, 32)
    elif args.data_train.find('Tiny_ImageNet') >= 0:
        self.input_dim = (3, 64, 64)
    else:
        self.input_dim = (3, 224, 224)

    set_output_dimension(self.model.get_model(), self.input_dim)
    self.flops = get_flops(self.model.get_model())
    self.flops_prune = self.flops  # at initialization, no pruning is conducted
    self.flops_compression_ratio = self.flops_prune / self.flops
    self.params = get_parameters(self.model.get_model())
    self.params_prune = self.params
    self.params_compression_ratio = self.params_prune / self.params
    self.flops_ratio_log = []
    self.params_ratio_log = []
    self.converging = converging

    self.ckp.write_log(
        '\nThe computational complexity and number of parameters of the current network are as follows.'
        '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]'.format(
            self.flops / 10. ** 9, self.params / 10. ** 3))
    self.flops_another = get_model_flops(self.model.get_model(), self.input_dim, False)
    self.ckp.write_log(
        'Flops: {:.4f} [G] calculated by the original counter.\nMake sure that the two calculated '
        'Flops are the same.\n'.format(self.flops_another / 10. ** 9))

    self.optimizer = utility.make_optimizer_dhp(args, self.model, ckp=ckp, converging=converging)
    self.scheduler = utility.make_scheduler_dhp(args, self.optimizer, args.decay.split('+')[0],
                                                converging=converging)
    self.device = torch.device('cpu' if args.cpu else 'cuda')

    if args.model.find('INQ') >= 0:
        self.inq_steps = args.inq_steps
    else:
        self.inq_steps = None
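# Hedged sketch (not part of the original code base): how the Flops/params
# bookkeeping initialized above is expected to evolve once pruning removes
# channels. The helper name is hypothetical; it only re-runs the same counters
# and appends the resulting compression ratios to the logs.
def _log_compression_ratios(trainer):
    set_output_dimension(trainer.model.get_model(), trainer.input_dim)
    trainer.flops_prune = get_flops(trainer.model.get_model())
    trainer.params_prune = get_parameters(trainer.model.get_model())
    trainer.flops_compression_ratio = trainer.flops_prune / trainer.flops
    trainer.params_compression_ratio = trainer.params_prune / trainer.params
    trainer.flops_ratio_log.append(trainer.flops_compression_ratio)
    trainer.params_ratio_log.append(trainer.params_compression_ratio)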
def __init__(self, args, loader, my_model, my_loss, ckp, writer=None, converging=False):
    self.args = args
    self.scale = args.scale
    self.ckp = ckp
    self.loader_train = loader.loader_train
    self.loader_test = loader.loader_test
    self.model = my_model
    self.loss = my_loss
    self.writer = writer

    self.optimizer = utility.make_optimizer_dhp(args, self.model, ckp, converging=converging)
    self.scheduler = utility.make_scheduler_dhp(
        args, self.optimizer, int(args.lr_decay_step.split('+')[0]), converging=converging)

    # Gray-scale input for the denoising networks (UNet, DnCNN), RGB otherwise.
    if self.args.model.lower().find('unet') >= 0 or self.args.model.lower().find('dncnn') >= 0:
        self.input_dim = (1, args.input_dim, args.input_dim)
    else:
        self.input_dim = (3, args.input_dim, args.input_dim)

    set_output_dimension(self.model.get_model(), self.input_dim)
    self.flops = get_flops(self.model.get_model())
    self.flops_prune = self.flops  # at initialization, no pruning is conducted
    self.flops_compression_ratio = self.flops_prune / self.flops
    self.params = get_parameters(self.model.get_model())
    self.params_prune = self.params
    self.params_compression_ratio = self.params_prune / self.params
    self.flops_ratio_log = []
    self.params_ratio_log = []
    self.converging = converging

    self.ckp.write_log(
        '\nThe computational complexity and number of parameters of the current network are as follows.'
        '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]'.format(
            self.flops / 10. ** 9, self.params / 10. ** 3))
    self.flops_another = get_model_flops(self.model.get_model(), self.input_dim, False)
    self.ckp.write_log(
        'Flops: {:.4f} [G] calculated by the original counter.\nMake sure that the two calculated '
        'Flops are the same.\n'.format(self.flops_another / 10. ** 9))

    self.error_last = 1e8
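# Hedged sketch of the stopping test implied by the ratio bookkeeping above:
# during searching, the run can be considered converged once the measured
# Flops compression ratio falls within a small tolerance of the target.
# `args.ratio` and `stop_limit` are illustrative assumptions, not fields
# guaranteed to exist in this code base.
def _searching_converged(trainer, stop_limit=0.02):
    target_ratio = trainer.args.ratio  # assumed: desired Flops compression ratio
    return abs(trainer.flops_compression_ratio - target_ratio) <= stop_limit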
def __init__(self, args, loader, my_model, my_loss, ckp, writer=None):
    self.args = args
    self.ckp = ckp
    self.loader_train = loader.loader_train
    self.loader_test = loader.loader_test
    self.model = my_model
    self.loss = my_loss
    self.writer = writer

    # Input resolution depends on the training dataset.
    if args.data_train.find('CIFAR') >= 0:
        self.input_dim = (3, 32, 32)
    elif args.data_train.find('Tiny_ImageNet') >= 0:
        self.input_dim = (3, 64, 64)
    else:
        self.input_dim = (3, 224, 224)

    set_output_dimension(self.model.get_model(), self.input_dim)
    self.flops = get_flops(self.model.get_model())
    self.params = get_parameters(self.model.get_model())
    self.ckp.write_log(
        '\nThe computational complexity and number of parameters of the current network are as follows.'
        '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]'.format(
            self.flops / 10. ** 9, self.params / 10. ** 3))
    self.flops_another = get_model_flops(self.model.get_model(), self.input_dim, False)
    self.ckp.write_log(
        'Flops: {:.4f} [G] calculated by the original counter.\nMake sure that the two calculated '
        'Flops are the same.\n'.format(self.flops_another / 10. ** 9))

    self.optimizer = utility.make_optimizer(args, self.model, ckp=ckp)
    self.scheduler = utility.make_scheduler(args, self.optimizer)
    self.device = torch.device('cpu' if args.cpu else 'cuda')

    if args.model.find('INQ') >= 0:
        self.inq_steps = args.inq_steps
    else:
        self.inq_steps = None
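# Hedged refactoring sketch: the dataset -> input-shape rule above recurs in
# several constructors in this repository; a shared helper (hypothetical name)
# would keep the variants in sync.
def _infer_input_dim(data_train):
    if data_train.find('CIFAR') >= 0:
        return (3, 32, 32)      # CIFAR-10/100: 32x32 RGB images
    elif data_train.find('Tiny_ImageNet') >= 0:
        return (3, 64, 64)      # Tiny-ImageNet: 64x64 RGB images
    return (3, 224, 224)        # default: ImageNet-sized RGB images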
# ==================================================================================================================
# Step 3: Pruning -> prune the derived sparse model and prepare the trainer instance for finetuning or testing
# ==================================================================================================================
t.reset_after_optimization()
if args.print_model:
    print(t.model.get_model())
    print(t.model.get_model(), file=checkpoint.log_file)

# ==================================================================================================================
# Step 4: Continued training / Testing -> continue to train the pruned model to reach a higher accuracy.
# ==================================================================================================================
while not t.terminate():
    t.train()
    t.test()

set_output_dimension(model.get_model(), t.input_dim)
flops = get_flops(model.get_model())
params = get_parameters(model.get_model())
print('\nThe computational complexity and number of parameters of the current network are as follows.'
      '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]\n'.format(flops / 10. ** 9, params / 10. ** 3))

if args.summary:
    t.writer.close()
if args.print_model:
    print(t.model.get_model())
    print(t.model.get_model(), file=checkpoint.log_file)
checkpoint.done()
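# Hedged sketch of the driver implied by Steps 3-4 above, wrapped into a single
# function; `trainer` stands for the Trainer instance `t`, and Steps 1-2
# (argument parsing, data/model/loss construction, and the differentiable
# search) happen outside this excerpt. The function name is hypothetical.
def run_pruning_and_finetuning(trainer, checkpoint):
    trainer.reset_after_optimization()   # Step 3: derive the pruned network
    while not trainer.terminate():       # Step 4: finetune, then evaluate
        trainer.train()
        trainer.test()
    checkpoint.done()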
# Step 3: Pruning -> prune the derived sparse model and prepare the trainer instance for finetuning or testing
# ==================================================================================================================
t.reset_after_searching()
if args.print_model:
    print(t.model.get_model())
    print(t.model.get_model(), file=checkpoint.log_file)

# ==================================================================================================================
# Step 4: Finetuning/Testing -> finetune the pruned model to reach a higher accuracy.
# ==================================================================================================================
while not t.terminate():
    t.train()
    t.test()

set_output_dimension(network_model.get_model(), t.input_dim)
flops = get_flops(network_model.get_model())
params = get_parameters(network_model.get_model())
print('\nThe computational complexity and number of parameters of the current network are as follows.'
      '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]\n'.format(flops / 10. ** 9, params / 10. ** 3))

if args.summary:
    t.writer.close()
if args.print_model:
    print(t.model.get_model())
    print(t.model.get_model(), file=checkpoint.log_file)
checkpoint.done()
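# Hedged refactoring sketch: the complexity report above is repeated almost
# verbatim across these scripts; a shared helper (hypothetical name) would
# package the counter calls in one place.
def report_complexity(model, input_dim):
    set_output_dimension(model.get_model(), input_dim)
    flops = get_flops(model.get_model())
    params = get_parameters(model.get_model())
    print('\nFlops: {:.4f} [G]\tParams: {:.2f} [k]\n'.format(
        flops / 10. ** 9, params / 10. ** 3))
    return flops, params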
# model
# if args.decomp_type == 'gsvd' or args.decomp_type == 'svd-mse' or args.comp_rule.find('f-norm') >= 0 \
#         or args.model.lower().find('prune_resnet56') >= 0:
#     my_model = Model(args, checkpoint, loader.loader_train)
# else:
my_model = Model(args, checkpoint)

# Input resolution depends on the training dataset.
if args.data_train.find('CIFAR') >= 0:
    input_dim = (3, 32, 32)
elif args.data_train.find('Tiny_ImageNet') >= 0:
    input_dim = (3, 64, 64)
else:
    input_dim = (3, 224, 224)

set_output_dimension(my_model.get_model(), input_dim)
flops = get_flops(my_model.get_model())
params = get_parameters(my_model.get_model())
print('\nThe computational complexity and number of parameters of the current network are as follows.'
      '\nFlops: {:.4f} [G]\tParams: {:.2f} [k]'.format(flops / 10. ** 9, params / 10. ** 3))
flops_another = get_model_flops(my_model.get_model(), input_dim, False)
print('Flops: {:.4f} [G] calculated by the original counter.\nMake sure that the two calculated '
      'Flops are the same.\n'.format(flops_another / 10. ** 9))

# data loader
loader = Data(args)

# loss function
loss = Loss(args, checkpoint)

# writer
writer = SummaryWriter(os.path.join(args.dir_save, args.save),
                       comment='optimization') if args.summary else None

# trainer
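# Hedged sketch: the log above asks the reader to verify that the two Flops
# counters agree; this check automates the comparison. The function name and
# the relative tolerance are illustrative assumptions.
def check_flops_consistency(my_model, input_dim, rel_tol=1e-3):
    flops = get_flops(my_model.get_model())
    flops_another = get_model_flops(my_model.get_model(), input_dim, False)
    if abs(flops - flops_another) > rel_tol * flops:
        raise RuntimeError('Flops counters disagree: {:.4f} [G] vs {:.4f} [G]'.format(
            flops / 10. ** 9, flops_another / 10. ** 9))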