def before_train(self, logs=None): """Be called before the train process.""" self.config = self.trainer.config model_code = copy.deepcopy(self.trainer.model.desc) model = self.trainer.model logging.info('current code: %s, %s', model_code.nbit_w_list, model_code.nbit_a_list) quantizer = Quantizer(model, model_code.nbit_w_list, model_code.nbit_a_list) model = quantizer() self.trainer.model = model count_input = [1, 3, 32, 32] sess_config = None if vega.is_torch_backend(): model = model.cuda() self.trainer.optimizer = Optimizer()( model=self.trainer.model, distributed=self.trainer.distributed) self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer) count_input = torch.FloatTensor(*count_input).cuda() elif vega.is_tf_backend(): tf.compat.v1.reset_default_graph() count_input = tf.random.uniform(count_input, dtype=tf.float32) sess_config = self.trainer._init_session_config() self.flops_count, self.params_count = calc_model_flops_params( model, count_input, custom_hooks=quantizer.custom_hooks()) self.latency_count = calc_forward_latency(model, count_input, sess_config) logging.info("after quant model glops=%sM, params=%sK, latency=%sms", self.flops_count * 1e-6, self.params_count * 1e-3, self.latency_count * 1000) self.validate()
def before_train(self, logs=None): """Be called before the train process.""" self.config = self.trainer.config self.device = self.trainer.config.device self.base_net_desc = self.trainer.model.desc sess_config = None if vega.is_torch_backend(): count_input = torch.FloatTensor(1, 3, 32, 32).to(self.device) elif vega.is_tf_backend(): count_input = tf.random.uniform([1, 3, 32, 32], dtype=tf.float32) sess_config = self.trainer._init_session_config() elif vega.is_ms_backend(): count_input = mindspore.Tensor( np.random.randn(1, 3, 32, 32).astype(np.float32)) self.flops_count, self.params_count = calc_model_flops_params( self.trainer.model, count_input) self.latency_count = calc_forward_latency(self.trainer.model, count_input, sess_config) logging.info("after prune model glops=%sM, params=%sK, latency=%sms", self.flops_count * 1e-6, self.params_count * 1e-3, self.latency_count * 1000) self.trainer.model = self._generate_init_model() if vega.is_torch_backend(): self.trainer.optimizer = Optimizer()( model=self.trainer.model, distributed=self.trainer.distributed) self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
def update_flops_params(self, epoch=None, logs=None):
    """Calculate flops and params."""
    self.model = self.trainer.model
    try:
        if self.flops is None:
            flops_count, params_count = calc_model_flops_params(self.model, self.input)
            # convert to GFLOPs and K params
            self.flops, self.params = flops_count * 1e-9, params_count * 1e-3
        if self.latency is None:
            sess_config = self.trainer._init_session_config() if zeus.is_tf_backend() else None
            # latency in milliseconds
            self.latency = calc_forward_latency(self.model, self.input, sess_config) * 1000
        summary_perfs = logs.get('summary_perfs', {})
        if epoch:
            summary_perfs.update({'flops': self.flops, 'params': self.params,
                                  'latency': self.latency, 'epoch': epoch})
        else:
            summary_perfs.update({'flops': self.flops, 'params': self.params,
                                  'latency': self.latency})
        logs.update({'summary_perfs': summary_perfs})
    except Exception as ex:
        logging.warning("model statistics failed, ex=%s", ex)
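# --- Reference sketch (not from the source) ----------------------------------
# Shape of the logs dict that update_flops_params fills in. The numbers are
# illustrative only; the units follow the scaling used above.
logs = {
    'summary_perfs': {
        'flops': 0.52,     # GFLOPs   (flops_count * 1e-9)
        'params': 11.2,    # K params (params_count * 1e-3)
        'latency': 3.8,    # ms       (seconds * 1000)
        'epoch': 5,        # present only when an epoch index is passed in
    }
}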
def before_train(self, logs=None): """Be called before the training process.""" self.config = self.trainer.config count_input = torch.FloatTensor(1, 3, 1024, 1024).cuda() flops_count, params_count = calc_model_flops_params( self.trainer.model, count_input) self.flops_count, self.params_count = flops_count * 1e-9, params_count * 1e-3 logger.info("Flops: {:.2f} G, Params: {:.1f} K".format( self.flops_count, self.params_count))
def is_filtered(self, desc=None):
    """Filter function of Flops and Params."""
    if self.flops_range is None and self.params_range is None:
        return False
    model, count_input = self.get_model_input(desc)
    flops, params = calc_model_flops_params(model, count_input)
    flops, params = flops * 1e-9, params * 1e-3
    if self.flops_range is not None:
        if flops < self.flops_range[0] or flops > self.flops_range[1]:
            return True
    if self.params_range is not None:
        if params < self.params_range[0] or params > self.params_range[1]:
            return True
    return False
def before_train(self, logs=None): """Be called before the training process.""" self.config = self.trainer.config if vega.is_torch_backend(): count_input = torch.FloatTensor(1, 3, 192, 192).cuda() elif vega.is_tf_backend(): tf.compat.v1.reset_default_graph() count_input = tf.random.uniform([1, 192, 192, 3], dtype=tf.float32) elif vega.is_ms_backend(): count_input = mindspore.Tensor( np.random.randn(1, 3, 192, 192).astype(np.float32)) flops_count, params_count = calc_model_flops_params( self.trainer.model, count_input) self.flops_count, self.params_count = flops_count * 1e-9, params_count * 1e-3 logger.info("Flops: {:.2f} G, Params: {:.1f} K".format( self.flops_count, self.params_count))
def is_filtered(self, desc=None):
    """Filter function of Flops and Params."""
    if self.flops_range is None and self.params_range is None:
        return False
    model, count_input = self.get_model_input(desc)
    flops, params = calc_model_flops_params(model, count_input)
    flops, params = flops * 1e-9, params * 1e-3
    if self.flops_range is not None:
        if flops < self.flops_range[0] or flops > self.flops_range[1]:
            logger.info("The flops are out of range. Skip this network.")
            return True
    if self.params_range is not None:
        if params < self.params_range[0] or params > self.params_range[1]:
            logger.info("The parameters are out of range. Skip this network.")
            return True
    return False
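# --- Usage sketch (not from the source) ---------------------------------------
# The range check above in plain numbers; the units follow the scaling used
# there (flops in GFLOPs, params in thousands). All values are illustrative.
flops_range, params_range = [0.0, 0.6], [0.0, 1e3]
flops, params = 0.52, 11.2           # e.g. 0.52 GFLOPs, 11.2 K parameters
filtered = (flops < flops_range[0] or flops > flops_range[1]
            or params < params_range[0] or params > params_range[1])
print(filtered)                      # False -> this candidate is kept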
def update_flops_params(self, epoch=None, logs=None):
    """Calculate flops and params."""
    self.model = self.trainer.model
    try:
        if self.flops is None:
            flops_count, params_count = calc_model_flops_params(self.model, self.input)
            # convert to GFLOPs and K params
            self.flops, self.params = flops_count * 1e-9, params_count * 1e-3
        summary_perfs = logs.get('summary_perfs', {})
        if epoch:
            summary_perfs.update({'flops': self.flops, 'params': self.params, 'epoch': epoch})
        else:
            summary_perfs.update({'flops': self.flops, 'params': self.params})
        logs.update({'summary_perfs': summary_perfs})
        logging.info("flops: {}, params: {}".format(self.flops, self.params))
    except Exception as ex:
        logging.warning("model statistics failed, ex=%s", ex)
def before_train(self, logs=None): """Be called before the train process.""" self.config = self.trainer.config self.device = self.trainer.config.device self.base_net_desc = self.trainer.model.desc if vega.is_torch_backend(): count_input = torch.FloatTensor(1, 3, 32, 32).to(self.device) elif vega.is_tf_backend(): count_input = tf.random.uniform([1, 32, 32, 3], dtype=tf.float32) elif vega.is_ms_backend(): count_input = mindspore.Tensor( np.random.randn(1, 3, 32, 32).astype(np.float32)) self.flops_count, self.params_count = calc_model_flops_params( self.trainer.model, count_input) print( f"flops:{self.flops_count}, model size:{self.params_count*4/1024**2} MB" ) self.trainer.model = self._generate_init_model()