Example #1
 def before_train(self, logs=None):
     """Be called before the train process."""
     self.config = self.trainer.config
     model_code = copy.deepcopy(self.trainer.model.desc)
     model = self.trainer.model
     logging.info('current code: %s, %s', model_code.nbit_w_list,
                  model_code.nbit_a_list)
     quantizer = Quantizer(model, model_code.nbit_w_list,
                           model_code.nbit_a_list)
     model = quantizer()
     self.trainer.model = model
     count_input = [1, 3, 32, 32]
     sess_config = None
     if vega.is_torch_backend():
         model = model.cuda()
         self.trainer.optimizer = Optimizer()(
             model=self.trainer.model, distributed=self.trainer.distributed)
         self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
         count_input = torch.FloatTensor(*count_input).cuda()
     elif vega.is_tf_backend():
         tf.compat.v1.reset_default_graph()
         count_input = tf.random.uniform(count_input, dtype=tf.float32)
         sess_config = self.trainer._init_session_config()
     self.flops_count, self.params_count = calc_model_flops_params(
         model, count_input, custom_hooks=quantizer.custom_hooks())
     self.latency_count = calc_forward_latency(model, count_input,
                                               sess_config)
     logging.info("after quant model glops=%sM, params=%sK, latency=%sms",
                  self.flops_count * 1e-6, self.params_count * 1e-3,
                  self.latency_count * 1000)
     self.validate()
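
The latency figure above comes from calc_forward_latency. As a rough, self-contained illustration of what such a measurement involves on the PyTorch backend, the sketch below averages repeated forward passes; the helper name measure_forward_latency and the warm-up/repeat counts are assumptions for illustration, not framework API.

 import time

 import torch

 def measure_forward_latency(model, count_input, repeats=100, warmup=10):
     """Hypothetical helper: average seconds per forward pass."""
     model.eval()
     with torch.no_grad():
         for _ in range(warmup):
             model(count_input)            # warm up caches and CUDA kernels
         if count_input.is_cuda:
             torch.cuda.synchronize()      # wait for queued GPU work to finish
         start = time.time()
         for _ in range(repeats):
             model(count_input)
         if count_input.is_cuda:
             torch.cuda.synchronize()
     return (time.time() - start) / repeats

Multiplying the returned value by 1000 gives the millisecond figure logged above.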
Example #2
 def before_train(self, logs=None):
     """Be called before the train process."""
     self.config = self.trainer.config
     self.device = self.trainer.config.device
     self.base_net_desc = self.trainer.model.desc
     sess_config = None
     if vega.is_torch_backend():
         count_input = torch.FloatTensor(1, 3, 32, 32).to(self.device)
     elif vega.is_tf_backend():
         count_input = tf.random.uniform([1, 3, 32, 32], dtype=tf.float32)
         sess_config = self.trainer._init_session_config()
     elif vega.is_ms_backend():
         count_input = mindspore.Tensor(
             np.random.randn(1, 3, 32, 32).astype(np.float32))
     self.flops_count, self.params_count = calc_model_flops_params(
         self.trainer.model, count_input)
     self.latency_count = calc_forward_latency(self.trainer.model,
                                               count_input, sess_config)
     logging.info("after prune model glops=%sM, params=%sK, latency=%sms",
                  self.flops_count * 1e-6, self.params_count * 1e-3,
                  self.latency_count * 1000)
     self.trainer.model = self._generate_init_model()
     if vega.is_torch_backend():
         self.trainer.optimizer = Optimizer()(
             model=self.trainer.model, distributed=self.trainer.distributed)
         self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
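
The FLOPs and parameter counts above come from calc_model_flops_params. A minimal standalone sketch of the same idea for the PyTorch backend, counting multiply-accumulates with forward hooks on Conv2d and Linear layers only, is shown below; the name count_flops_params is an assumption and the result is a rough estimate rather than an exhaustive count.

 import torch
 import torch.nn as nn

 def count_flops_params(model, count_input):
     """Hypothetical sketch: rough MAC and parameter counts via forward hooks."""
     flops = [0]

     def conv_hook(module, inputs, output):
         # output elements * kernel area * input channels per group
         kernel_ops = (module.kernel_size[0] * module.kernel_size[1]
                       * module.in_channels // module.groups)
         flops[0] += output.numel() * kernel_ops

     def linear_hook(module, inputs, output):
         flops[0] += output.numel() * module.in_features

     handles = []
     for m in model.modules():
         if isinstance(m, nn.Conv2d):
             handles.append(m.register_forward_hook(conv_hook))
         elif isinstance(m, nn.Linear):
             handles.append(m.register_forward_hook(linear_hook))
     with torch.no_grad():
         model(count_input)
     for handle in handles:
         handle.remove()
     params = sum(p.numel() for p in model.parameters())
     return flops[0], params

With the [1, 3, 32, 32] input used above, scaling the first value by 1e-6 and the second by 1e-3 reproduces the M/K units in the log message.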
Example #3
 def update_flops_params(self, epoch=None, logs=None):
     """Calculate flops and params."""
     self.model = self.trainer.model
     try:
         if self.flops is None:
             flops_count, params_count = calc_model_flops_params(
                 self.model, self.input)
             self.flops, self.params = flops_count * 1e-9, params_count * 1e-3
         if self.latency is None:
             sess_config = (self.trainer._init_session_config()
                            if zeus.is_tf_backend() else None)
             self.latency = calc_forward_latency(self.model, self.input,
                                                 sess_config) * 1000
         summary_perfs = logs.get('summary_perfs', {})
         if epoch:
             summary_perfs.update({
                 'flops': self.flops,
                 'params': self.params,
                 'latency': self.latency,
                 'epoch': epoch
             })
         else:
             summary_perfs.update({
                 'flops': self.flops,
                 'params': self.params,
                 'latency': self.latency
             })
         logs.update({'summary_perfs': summary_perfs})
     except Exception as ex:
         logging.warning("model statics failed, ex=%s", ex)
Example #4
 def is_filtered(self, desc=None):
     """Filter function of latency."""
     if self.max_latency is None:
         return False
     model, count_input = self.get_model_input(desc)
     trainer = ClassFactory.get_cls(ClassType.TRAINER)(model_desc=desc)
     sess_config = (trainer._init_session_config()
                    if zeus.is_tf_backend() else None)
     latency = calc_forward_latency(model, count_input, sess_config)
     logging.info('Sampled model\'s latency: {}ms'.format(latency * 1000))
     return latency > self.max_latency
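
The filter reduces to a single budget comparison. If FLOPs or parameter budgets were also needed, the same pattern generalizes; below is a self-contained sketch with hypothetical names (exceeds_budget, max_flops and max_params are assumptions, not part of the framework).

 def exceeds_budget(flops, params, latency,
                    max_flops=None, max_params=None, max_latency=None):
     """Hypothetical helper: True if any measured cost exceeds its budget."""
     checks = ((flops, max_flops), (params, max_params), (latency, max_latency))
     return any(limit is not None and value > limit for value, limit in checks)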
Example #5
 def update_latency(self, epoch=None, logs=None):
     """Calculate latency."""
     self.model = self.trainer.model
     try:
         summary_perfs = logs.get('summary_perfs', {})
         if self.latency is None:
             sess_config = self.trainer._init_session_config() if zeus.is_tf_backend() else None
             self.latency = calc_forward_latency(self.model, self.input, sess_config) * 1000
         if epoch:
             summary_perfs.update({'latency': self.latency, 'epoch': epoch})
         else:
             summary_perfs.update({'latency': self.latency})
         logs.update({'summary_perfs': summary_perfs})
         logging.info("flops: {} , params:{}, latency:{}".format(self.flops, self.params, self.latency))
     except Exception as ex:
         logging.warning("model statics failed, ex=%s", ex)