def update(self, labels, preds, masks=None):  # pylint: disable=arguments-differ
    """Updates the internal evaluation result.

    Parameters
    ----------
    labels : list of `NDArray`
        The labels of the data with class indices as values, one per sample.
    preds : list of `NDArray`
        Prediction values for samples. Each prediction value can either be the class index,
        or a vector of likelihoods for all classes.
    masks : list of `NDArray` or None, optional
        Masks for samples, with the same shape as `labels`. The value of each element must
        be either 1 or 0. If None, all samples are considered valid.

    Raises
    ------
    ValueError
        If `masks` is given but does not contain exactly one entry per element of
        `labels`. Errors raised by ``check_label_shapes`` for mismatched
        label/prediction/mask sizes are propagated unchanged.
    """
    labels, preds = check_label_shapes(labels, preds, True)
    if masks is None:
        masks = [None] * len(labels)
    elif len(masks) != len(labels):
        # zip() below would otherwise silently drop the samples without a mask,
        # so they would never be counted and the metric would under-report.
        raise ValueError('Number of masks {} does not match number of labels {}'
                         .format(len(masks), len(labels)))

    for label, pred_label, mask in zip(labels, preds, masks):
        if pred_label.shape != label.shape:
            # Predictions are per-class scores: reduce them to a class index.
            # TODO(haibin) topk does not support fp16. Issue tracked at:
            # https://github.com/apache/incubator-mxnet/issues/14125
            # topk is used because argmax is slow:
            # https://github.com/apache/incubator-mxnet/issues/11061
            pred_label = ndarray.topk(pred_label.astype('float32', copy=False),
                                      k=1, ret_typ='indices', axis=self.axis)

        # flatten before checking shapes to avoid a shape mismatch
        pred_label = pred_label.astype('int32', copy=False).reshape((-1,))
        label = label.astype('int32', copy=False).reshape((-1,))
        check_label_shapes(label, pred_label)

        if mask is not None:
            mask = mask.astype('int32', copy=False).reshape((-1,))
            check_label_shapes(label, mask)
            # Multiplying by the 0/1 mask zeroes out hits at masked positions;
            # only unmasked positions count toward the instance total.
            num_correct = ((pred_label == label) * mask).sum().asscalar()
            num_inst = mask.sum().asscalar()
        else:
            num_correct = (pred_label == label).sum().asscalar()
            num_inst = len(label)

        # Accumulate into both the resettable local and the global counters.
        self.sum_metric += num_correct
        self.global_sum_metric += num_correct
        self.num_inst += num_inst
        self.global_num_inst += num_inst
def topk(input, k, dim, descending=True):
    """Return the top-``k`` values of ``input`` along dimension ``dim``.

    Adapter that maps a ``(k, dim, descending)`` calling convention onto
    ``nd.topk``.

    Parameters
    ----------
    input : NDArray
        Array to select from (presumably an ``nd`` NDArray; confirm at call sites).
    k : int
        Number of elements to select.
    dim : int
        Axis to select along; forwarded to ``nd.topk`` as ``axis``.
    descending : bool, optional
        Forwarded as ``is_ascend=not descending``. Defaults to True.

    Returns
    -------
    NDArray
        Only the selected values (``ret_typ='value'``); indices are not returned.
    """
    ascending = not descending
    return nd.topk(input, k=k, axis=dim, ret_typ='value', is_ascend=ascending)
def update_masks(self, index, weight):
    """Updates the masks for sparse training.

    A mask is (re)computed only after ``self.masks_updated`` has been reset to
    False. That happens on an epoch transition when the new epoch is 1, or when
    it equals ``pruning_switch_epoch[0] + 1``; in the latter case the head of the
    switch schedule and of the active sparsity/threshold lists is popped first.
    A call with ``index == 0`` sets ``masks_updated`` back to True. Assuming
    parameters are visited from index 0 upward on every step (TODO confirm
    against the caller), every parameter therefore gets a fresh mask during the
    first step of such an epoch. Once the switch schedule is exhausted, the
    current parameters stay in effect.

    Parameters
    ----------
    index : int
        The index for weight.
    weight : NDArray
        The weight matrix. 1-D weights use the bias sparsity/threshold.

    Returns
    -------
    boolean
        If the masks were changed, i.e. ``self.masks[index]`` was recomputed
        by this call.
    """
    def _log_schedule_head():
        # Log the schedule entries now in effect. `log` is a message prefix
        # defined outside this block (not visible here).
        if self.weight_sparsity is not None:
            logging.info(log + 'bias-sparsity={}, weight-sparsity={}'.format(
                self.bias_sparsity[0], self.weight_sparsity[0]))
        else:
            logging.info(log + 'bias-threshold={}, weight-threshold={}'.format(
                self.bias_threshold[0], self.weight_threshold[0]))

    # determine number of updates without actually updating the count
    num_update = self._index_update_count.get(index, self.begin_num_update) + 1
    num_update = max(num_update, self.num_update)

    # calculate the (1-based) epoch this update falls in
    epoch = int((num_update - 1) / self.batches_per_epoch) + 1

    # determine if masks need to be updated, and get corresponding parameters
    if index == 0:
        self.masks_updated = True
    if self.epoch != epoch:
        self.epoch = epoch
        if epoch == 1:
            self.masks_updated = False
            _log_schedule_head()
        # The emptiness check stops an exhausted schedule from raising
        # IndexError on the next epoch change; the last parameters are kept.
        if self.pruning_switch_epoch and self.pruning_switch_epoch[0] + 1 == epoch:
            self.masks_updated = False
            self.pruning_switch_epoch.pop(0)
            if self.weight_sparsity is not None:
                self.weight_sparsity.pop(0)
                self.bias_sparsity.pop(0)
            else:
                self.weight_threshold.pop(0)
                self.bias_threshold.pop(0)
            _log_schedule_head()

    # update masks if needed
    if not self.masks_updated:
        # initialize masks: in epoch 1 a new slot is appended for this weight.
        # NOTE(review): relies on index == len(self.masks) - 1 after the append,
        # i.e. parameters first visited in index order; confirm with the caller.
        if epoch == 1:
            self.masks.append(None)
        is_bias = len(weight.shape) == 1
        if self.weight_sparsity is not None:
            # percentages are given: keep `number_unpruned` entries of |weight|
            # selected by topk's mask mode (presumably MXNet ndarray.topk).
            sparsity = self.bias_sparsity[0] if is_bias else self.weight_sparsity[0]
            number_unpruned = int((100.0 - sparsity) * weight.size / 100.0)
            self.masks[index] = topk(NDabs(weight), axis=None, ret_typ='mask',
                                     k=number_unpruned)
        else:
            # thresholds are given: keep entries whose magnitude reaches it
            threshold = self.bias_threshold[0] if is_bias else self.weight_threshold[0]
            self.masks[index] = NDabs(weight) >= threshold

    return not self.masks_updated
def update_masks(self, index, weight):
    """Updates the masks for sparse training.

    A mask is (re)computed only after ``self.masks_updated`` has been reset to
    False. That happens on an epoch transition when the new epoch is 1, or when
    it equals ``pruning_switch_epoch[0] + 1``; in the latter case the head of the
    switch schedule and of the active sparsity/threshold lists is popped first.
    A call with ``index == 0`` sets ``masks_updated`` back to True. Assuming
    parameters are visited from index 0 upward on every step (TODO confirm
    against the caller), every parameter therefore gets a fresh mask during the
    first step of such an epoch. Once the switch schedule is exhausted, the
    current parameters stay in effect.

    Parameters
    ----------
    index : int
        The index for weight.
    weight : NDArray
        The weight matrix. 1-D weights use the bias sparsity/threshold.

    Returns
    -------
    boolean
        If the masks were changed, i.e. ``self.masks[index]`` was recomputed
        by this call.
    """
    def _log_schedule_head():
        # Log the schedule entries now in effect. `log` is a message prefix
        # defined outside this block (not visible here).
        if self.weight_sparsity is not None:
            logging.info(log + 'bias-sparsity={}, weight-sparsity={}'.format(
                self.bias_sparsity[0], self.weight_sparsity[0]))
        else:
            logging.info(log + 'bias-threshold={}, weight-threshold={}'.format(
                self.bias_threshold[0], self.weight_threshold[0]))

    # determine number of updates without actually updating the count
    num_update = self._index_update_count.get(index, self.begin_num_update) + 1
    num_update = max(num_update, self.num_update)

    # calculate the (1-based) epoch this update falls in
    epoch = int((num_update - 1) / self.batches_per_epoch) + 1

    # determine if masks need to be updated, and get corresponding parameters
    if index == 0:
        self.masks_updated = True
    if self.epoch != epoch:
        self.epoch = epoch
        if epoch == 1:
            self.masks_updated = False
            _log_schedule_head()
        # The emptiness check stops an exhausted schedule from raising
        # IndexError on the next epoch change; the last parameters are kept.
        if self.pruning_switch_epoch and self.pruning_switch_epoch[0] + 1 == epoch:
            self.masks_updated = False
            self.pruning_switch_epoch.pop(0)
            if self.weight_sparsity is not None:
                self.weight_sparsity.pop(0)
                self.bias_sparsity.pop(0)
            else:
                self.weight_threshold.pop(0)
                self.bias_threshold.pop(0)
            _log_schedule_head()

    # update masks if needed
    if not self.masks_updated:
        # initialize masks: in epoch 1 a new slot is appended for this weight.
        # NOTE(review): relies on index == len(self.masks) - 1 after the append,
        # i.e. parameters first visited in index order; confirm with the caller.
        if epoch == 1:
            self.masks.append(None)
        is_bias = len(weight.shape) == 1
        if self.weight_sparsity is not None:
            # percentages are given: keep `number_unpruned` entries of |weight|
            # selected by topk's mask mode (presumably MXNet ndarray.topk).
            sparsity = self.bias_sparsity[0] if is_bias else self.weight_sparsity[0]
            number_unpruned = int((100.0 - sparsity) * weight.size / 100.0)
            self.masks[index] = topk(NDabs(weight), axis=None, ret_typ='mask',
                                     k=number_unpruned)
        else:
            # thresholds are given: keep entries whose magnitude reaches it
            threshold = self.bias_threshold[0] if is_bias else self.weight_threshold[0]
            self.masks[index] = NDabs(weight) >= threshold

    return not self.masks_updated