Esempio n. 1
0
 def before_train(self, logs=None):
     """Quantize the current model and measure it before training starts.

     The weight/activation bit-width lists are read from a deep copy of the
     model description and handed to ``Quantizer``; the quantized model
     replaces ``self.trainer.model``. On the torch backend the model is moved
     with ``.cuda()`` and a new optimizer and LR scheduler are built for it.
     FLOPs, parameter count and forward latency are then measured on a dummy
     ``[1, 3, 32, 32]`` input, stored on ``self`` (``flops_count``,
     ``params_count``, ``latency_count``) and logged, and finally
     ``self.validate()`` is called.
     """
     trainer = self.trainer
     self.config = trainer.config
     code = copy.deepcopy(trainer.model.desc)
     base_model = trainer.model
     logging.info('current code: %s, %s', code.nbit_w_list,
                  code.nbit_a_list)
     quantizer = Quantizer(base_model, code.nbit_w_list, code.nbit_a_list)
     quant_model = quantizer()
     trainer.model = quant_model

     shape = [1, 3, 32, 32]
     # Backends not handled below receive the raw shape list unchanged.
     probe, session_cfg = shape, None
     if vega.is_torch_backend():
         quant_model = quant_model.cuda()
         trainer.optimizer = Optimizer()(
             model=trainer.model, distributed=trainer.distributed)
         trainer.lr_scheduler = LrScheduler()(trainer.optimizer)
         # NOTE: torch.FloatTensor(*shape) holds uninitialized values.
         probe = torch.FloatTensor(*shape).cuda()
     elif vega.is_tf_backend():
         # Clear the default TF graph before creating the probe tensor.
         tf.compat.v1.reset_default_graph()
         probe = tf.random.uniform(shape, dtype=tf.float32)
         session_cfg = trainer._init_session_config()

     hooks = quantizer.custom_hooks()
     self.flops_count, self.params_count = calc_model_flops_params(
         quant_model, probe, custom_hooks=hooks)
     self.latency_count = calc_forward_latency(quant_model, probe, session_cfg)
     # Display scaling: flops x1e-6 ("M"), params x1e-3 ("K"),
     # latency x1000 ("ms", so presumably measured in seconds).
     logging.info("after quant model glops=%sM, params=%sK, latency=%sms",
                  self.flops_count * 1e-6, self.params_count * 1e-3,
                  self.latency_count * 1000)
     self.validate()
Esempio n. 2
0
 def before_train(self, logs=None):
     """Measure the current (pruned) model, then install a fresh model.

     Records the trainer config, the configured device and the current
     model description (``base_net_desc``). FLOPs, parameter count and
     forward latency of the current model are measured on a dummy
     batch-of-one 32x32 input built for the active backend, stored on
     ``self`` and logged. The trainer's model is then replaced by
     ``self._generate_init_model()``; on torch a new optimizer and LR
     scheduler are created for it.
     """
     trainer = self.trainer
     self.config = trainer.config
     self.device = trainer.config.device
     self.base_net_desc = trainer.model.desc
     on_torch = vega.is_torch_backend()
     session_cfg = None
     if on_torch:
         probe = torch.FloatTensor(1, 3, 32, 32).to(self.device)
     elif vega.is_tf_backend():
         probe = tf.random.uniform([1, 3, 32, 32], dtype=tf.float32)
         session_cfg = trainer._init_session_config()
     elif vega.is_ms_backend():
         noise = np.random.randn(1, 3, 32, 32).astype(np.float32)
         probe = mindspore.Tensor(noise)
     measured = trainer.model
     self.flops_count, self.params_count = calc_model_flops_params(
         measured, probe)
     self.latency_count = calc_forward_latency(measured, probe, session_cfg)
     # Display scaling: flops x1e-6 ("M"), params x1e-3 ("K"),
     # latency x1000 ("ms").
     logging.info("after prune model glops=%sM, params=%sK, latency=%sms",
                  self.flops_count * 1e-6, self.params_count * 1e-3,
                  self.latency_count * 1000)
     trainer.model = self._generate_init_model()
     if on_torch:
         trainer.optimizer = Optimizer()(
             model=trainer.model, distributed=trainer.distributed)
         trainer.lr_scheduler = LrScheduler()(trainer.optimizer)
Esempio n. 3
0
 def update_flops_params(self, epoch=None, logs=None):
     """Compute model FLOPs/params/latency once and publish them to ``logs``.

     The statistics are cached on ``self`` and only recomputed while the
     corresponding attribute is None:

     * ``self.flops``  = raw FLOPs * 1e-9
     * ``self.params`` = raw parameter count * 1e-3
     * ``self.latency`` = forward latency * 1000

     Args:
         epoch: Optional epoch index. When given (including 0) it is added to
             the summary under ``'epoch'``.
         logs: Optional dict. ``logs['summary_perfs']`` is created or updated
             with the statistics. When None, nothing is published.

     Any exception raised while computing or publishing is logged as a
     warning instead of being propagated (best-effort statistics).
     """
     self.model = self.trainer.model
     try:
         if self.flops is None:
             flops_count, params_count = calc_model_flops_params(
                 self.model, self.input)
             self.flops, self.params = flops_count * 1e-9, params_count * 1e-3
         if self.latency is None:
             sess_config = self.trainer._init_session_config(
             ) if zeus.is_tf_backend() else None
             self.latency = calc_forward_latency(self.model, self.input,
                                                 sess_config) * 1000
         if logs is not None:
             summary_perfs = logs.get('summary_perfs', {})
             summary_perfs.update({
                 'flops': self.flops,
                 'params': self.params,
                 'latency': self.latency,
             })
             # Compare against None: epoch 0 is falsy but is a real epoch.
             if epoch is not None:
                 summary_perfs['epoch'] = epoch
             logs.update({'summary_perfs': summary_perfs})
     except Exception as ex:
         logging.warning("model statics failed, ex=%s", ex)
 def before_train(self, logs=None):
     """Measure and log the model's FLOPs and parameter count.

     A dummy ``(1, 3, 1024, 1024)`` float tensor (uninitialized values, moved
     with ``.cuda()``) is fed to ``calc_model_flops_params``. The results are
     stored on ``self`` scaled by 1e-9 (``flops_count``, logged as "G") and
     1e-3 (``params_count``, logged as "K").
     """
     trainer = self.trainer
     self.config = trainer.config
     dummy = torch.FloatTensor(1, 3, 1024, 1024).cuda()
     raw_flops, raw_params = calc_model_flops_params(trainer.model, dummy)
     self.flops_count = raw_flops * 1e-9
     self.params_count = raw_params * 1e-3
     summary = "Flops: {:.2f} G, Params: {:.1f} K".format(
         self.flops_count, self.params_count)
     logger.info(summary)
 def is_filtered(self, desc=None):
     """Return True when the network for ``desc`` falls outside a limit.

     Limits are the optional ``(low, high)`` pairs ``self.flops_range``
     (compared against FLOPs * 1e-9) and ``self.params_range`` (compared
     against params * 1e-3). With neither limit set the model is not even
     built and False is returned.
     """
     if self.flops_range is None and self.params_range is None:
         return False
     model, count_input = self.get_model_input(desc)
     raw_flops, raw_params = calc_model_flops_params(model, count_input)
     checks = (
         (raw_flops * 1e-9, self.flops_range),
         (raw_params * 1e-3, self.params_range),
     )
     for value, bounds in checks:
         if bounds is None:
             continue
         # Explicit < / > (not a chained <=) keeps NaN values unfiltered.
         if value < bounds[0] or value > bounds[1]:
             return True
     return False
Esempio n. 6
0
 def before_train(self, logs=None):
     """Measure and log the model's FLOPs and parameter count.

     A dummy single-sample 192x192 input is built for the active backend.
     It is shaped ``(1, 3, 192, 192)`` for torch (uninitialized values, moved
     with ``.cuda()``) and MindSpore. It is shaped ``(1, 192, 192, 3)`` for
     TensorFlow, after the default graph is reset. The results are stored
     on ``self`` scaled by 1e-9 (``flops_count``, "G") and 1e-3
     (``params_count``, "K").
     """
     trainer = self.trainer
     self.config = trainer.config
     if vega.is_torch_backend():
         sample = torch.FloatTensor(1, 3, 192, 192).cuda()
     elif vega.is_tf_backend():
         tf.compat.v1.reset_default_graph()
         sample = tf.random.uniform([1, 192, 192, 3], dtype=tf.float32)
     elif vega.is_ms_backend():
         noise = np.random.randn(1, 3, 192, 192).astype(np.float32)
         sample = mindspore.Tensor(noise)
     raw_flops, raw_params = calc_model_flops_params(trainer.model, sample)
     self.flops_count = raw_flops * 1e-9
     self.params_count = raw_params * 1e-3
     summary = "Flops: {:.2f} G, Params: {:.1f} K".format(
         self.flops_count, self.params_count)
     logger.info(summary)
Esempio n. 7
0
 def is_filtered(self, desc=None):
     """Return True (and log why) when the network for ``desc`` is out of range.

     Limits are the optional ``(low, high)`` pairs ``self.flops_range``
     (compared against FLOPs * 1e-9) and ``self.params_range`` (compared
     against params * 1e-3). The FLOPs limit is checked first. With neither
     limit set the model is not built and False is returned.
     """
     if self.flops_range is None and self.params_range is None:
         return False
     model, count_input = self.get_model_input(desc)
     raw_flops, raw_params = calc_model_flops_params(model, count_input)
     checks = (
         (raw_flops * 1e-9, self.flops_range,
          "The flops is out of range. Skip this network."),
         (raw_params * 1e-3, self.params_range,
          "The parameters is out of range. Skip this network."),
     )
     for value, bounds, reason in checks:
         if bounds is None:
             continue
         # Explicit < / > (not a chained <=) keeps NaN values unfiltered.
         if value < bounds[0] or value > bounds[1]:
             logger.info(reason)
             return True
     return False
Esempio n. 8
0
 def update_flops_params(self, epoch=None, logs=None):
     """Compute model FLOPs/params once, publish them to ``logs`` and log them.

     The statistics are cached on ``self`` and only computed while
     ``self.flops`` is None:

     * ``self.flops``  = raw FLOPs * 1e-9
     * ``self.params`` = raw parameter count * 1e-3

     Args:
         epoch: Optional epoch index. When given (including 0) it is added to
             the summary under ``'epoch'``.
         logs: Optional dict. ``logs['summary_perfs']`` is created or updated
             with the statistics. When None, nothing is published, but the
             values are still logged.

     Any exception raised while computing or publishing is logged as a
     warning instead of being propagated (best-effort statistics).
     """
     self.model = self.trainer.model
     try:
         if self.flops is None:
             flops_count, params_count = calc_model_flops_params(self.model,
                                                                 self.input)
             self.flops, self.params = flops_count * 1e-9, params_count * 1e-3
         if logs is not None:
             summary_perfs = logs.get('summary_perfs', {})
             summary_perfs.update({'flops': self.flops, 'params': self.params})
             # Compare against None: epoch 0 is falsy but is a real epoch.
             if epoch is not None:
                 summary_perfs['epoch'] = epoch
             logs.update({'summary_perfs': summary_perfs})
         # Lazy %-args: formatting is deferred to the logging framework.
         logging.info("flops: %s , params:%s", self.flops, self.params)
     except Exception as ex:
         logging.warning("model statics failed, ex=%s", ex)
Esempio n. 9
0
 def before_train(self, logs=None):
     """Measure the current model, then replace it with a freshly generated one.

     Records the trainer config, the configured device and the current model
     description (``base_net_desc``). FLOPs and parameter count of the
     current model are computed on a dummy batch-of-one 32x32 input for the
     active backend, stored unscaled on ``self`` and printed. The trainer's
     model is then set to ``self._generate_init_model()``.
     """
     self.config = self.trainer.config
     self.device = self.trainer.config.device
     self.base_net_desc = self.trainer.model.desc
     # Dummy input per backend: torch and MindSpore use (1, 3, 32, 32), while
     # TensorFlow uses (1, 32, 32, 3).
     # NOTE(review): torch.FloatTensor(...) holds uninitialized values; fine
     # for counting, but the contents are arbitrary.
     # NOTE(review): if no backend matches, count_input is unbound and the
     # call below raises UnboundLocalError.
     if vega.is_torch_backend():
         count_input = torch.FloatTensor(1, 3, 32, 32).to(self.device)
     elif vega.is_tf_backend():
         count_input = tf.random.uniform([1, 32, 32, 3], dtype=tf.float32)
     elif vega.is_ms_backend():
         count_input = mindspore.Tensor(
             np.random.randn(1, 3, 32, 32).astype(np.float32))
     # Counts are stored exactly as returned (no 1e-9 / 1e-3 scaling).
     self.flops_count, self.params_count = calc_model_flops_params(
         self.trainer.model, count_input)
     # Size estimate assumes 4 bytes per parameter (presumably float32);
     # dividing by 1024**2 converts bytes to MiB, printed as "MB".
     print(
         f"flops:{self.flops_count}, model size:{self.params_count*4/1024**2} MB"
     )
     self.trainer.model = self._generate_init_model()