def before_train(self, logs=None):
    """Be called before the train process.

    Quantize the trainer's model according to the bit-width lists in its
    description, rebuild the optimizer/lr-scheduler against the quantized
    parameters (torch backend), then measure and log flops/params/latency
    and run a validation pass.

    :param logs: callback logs dict (unused here), defaults to None
    """
    self.config = self.trainer.config
    # Deep-copy the desc so reading the bit lists can't mutate the trainer's copy.
    model_code = copy.deepcopy(self.trainer.model.desc)
    model = self.trainer.model
    logging.info('current code: %s, %s', model_code.nbit_w_list, model_code.nbit_a_list)
    quantizer = Quantizer(model, model_code.nbit_w_list, model_code.nbit_a_list)
    model = quantizer()
    self.trainer.model = model
    # NCHW dummy-input shape used for flops/latency counting (CIFAR-like 32x32).
    count_input = [1, 3, 32, 32]
    sess_config = None
    if vega.is_torch_backend():
        model = model.cuda()
        # Optimizer and scheduler must be recreated so they track the
        # quantized model's parameters, not the original ones.
        self.trainer.optimizer = Optimizer()(
            model=self.trainer.model, distributed=self.trainer.distributed)
        self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
        count_input = torch.FloatTensor(*count_input).cuda()
    elif vega.is_tf_backend():
        tf.compat.v1.reset_default_graph()
        count_input = tf.random.uniform(count_input, dtype=tf.float32)
        sess_config = self.trainer._init_session_config()
    self.flops_count, self.params_count = calc_model_flops_params(
        model, count_input, custom_hooks=quantizer.custom_hooks())
    self.latency_count = calc_forward_latency(model, count_input, sess_config)
    # Fixed log-string typo: "glops" -> "flops" (value is flops in MFLOPs).
    logging.info("after quant model flops=%sM, params=%sK, latency=%sms",
                 self.flops_count * 1e-6, self.params_count * 1e-3,
                 self.latency_count * 1000)
    self.validate()
def before_train(self, logs=None):
    """Be called before the train process.

    Measure flops/params/latency of the sampled (pruned) model on a dummy
    1x3x32x32 input, then generate the initialized model and rebuild the
    optimizer/lr-scheduler for the torch backend.

    :param logs: callback logs dict (unused here), defaults to None
    """
    self.config = self.trainer.config
    self.device = self.trainer.config.device
    self.base_net_desc = self.trainer.model.desc
    sess_config = None
    if vega.is_torch_backend():
        count_input = torch.FloatTensor(1, 3, 32, 32).to(self.device)
    elif vega.is_tf_backend():
        count_input = tf.random.uniform([1, 3, 32, 32], dtype=tf.float32)
        sess_config = self.trainer._init_session_config()
    elif vega.is_ms_backend():
        count_input = mindspore.Tensor(
            np.random.randn(1, 3, 32, 32).astype(np.float32))
    self.flops_count, self.params_count = calc_model_flops_params(
        self.trainer.model, count_input)
    self.latency_count = calc_forward_latency(
        self.trainer.model, count_input, sess_config)
    # Fixed log-string typo: "glops" -> "flops" (value is flops in MFLOPs).
    logging.info("after prune model flops=%sM, params=%sK, latency=%sms",
                 self.flops_count * 1e-6, self.params_count * 1e-3,
                 self.latency_count * 1000)
    self.trainer.model = self._generate_init_model()
    if vega.is_torch_backend():
        # Rebuild optimizer/scheduler so they track the new model's parameters.
        self.trainer.optimizer = Optimizer()(
            model=self.trainer.model, distributed=self.trainer.distributed)
        self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
def update_flops_params(self, epoch=None, logs=None):
    """Calculate flops and params."""
    self.model = self.trainer.model
    try:
        # Compute each statistic only once; later calls reuse the cached value.
        if self.flops is None:
            raw_flops, raw_params = calc_model_flops_params(self.model, self.input)
            self.flops = raw_flops * 1e-9
            self.params = raw_params * 1e-3
        if self.latency is None:
            cfg = self.trainer._init_session_config() if zeus.is_tf_backend() else None
            self.latency = calc_forward_latency(self.model, self.input, cfg) * 1000
        summary_perfs = logs.get('summary_perfs', {})
        perf_entry = {'flops': self.flops, 'params': self.params, 'latency': self.latency}
        if epoch:
            perf_entry['epoch'] = epoch
        summary_perfs.update(perf_entry)
        logs.update({'summary_perfs': summary_perfs})
    except Exception as ex:
        logging.warning("model statics failed, ex=%s", ex)
def is_filtered(self, desc=None):
    """Filter function of latency."""
    # No latency constraint configured -> nothing is ever filtered out.
    if self.max_latency is None:
        return False
    model, count_input = self.get_model_input(desc)
    trainer = ClassFactory.get_cls(ClassType.TRAINER)(model_desc=desc)
    cfg = trainer._init_session_config() if zeus.is_tf_backend() else None
    measured = calc_forward_latency(model, count_input, cfg)
    logging.info('Sampled model\'s latency: {}ms'.format(measured * 1000))
    # Reject the sample when it exceeds the configured latency budget.
    return measured > self.max_latency
def update_latency(self, epoch=None, logs=None):
    """Calculate latency."""
    self.model = self.trainer.model
    try:
        summary_perfs = logs.get('summary_perfs', {})
        # Measure only once; subsequent calls reuse the cached latency.
        if self.latency is None:
            cfg = self.trainer._init_session_config() if zeus.is_tf_backend() else None
            self.latency = calc_forward_latency(self.model, self.input, cfg) * 1000
        perf_entry = {'latency': self.latency}
        if epoch:
            perf_entry['epoch'] = epoch
        summary_perfs.update(perf_entry)
        logs.update({'summary_perfs': summary_perfs})
        logging.info("flops: {} , params:{}, latency:{}".format(
            self.flops, self.params, self.latency))
    except Exception as ex:
        logging.warning("model statics failed, ex=%s", ex)