Esempio n. 1
0
    def update(self, labels, preds, masks=None):
        # pylint: disable=arguments-differ
        """Accumulate (optionally masked) accuracy counts for one batch.

        Parameters
        ----------
        labels : list of `NDArray`
            Ground-truth class indices, one entry per sample.
        preds : list of `NDArray`
            Predictions. An entry whose shape differs from its label's shape is
            reduced to a class index with a top-1 selection along
            ``self.axis``; an entry with the same shape is used as-is.
        masks : list of `NDArray` or None, optional
            Per-sample 0/1 masks shaped like `labels`. With ``None`` every
            sample counts.
        """
        labels, preds = check_label_shapes(labels, preds, True)
        if masks is None:
            masks = [None] * len(labels)

        for truth, guess, sample_mask in zip(labels, preds, masks):
            if guess.shape != truth.shape:
                # topk stands in for argmax, which is slow
                # (https://github.com/apache/incubator-mxnet/issues/11061).
                # TODO(haibin): topk lacks fp16 support, hence the float32 cast
                # (https://github.com/apache/incubator-mxnet/issues/14125).
                as_float = guess.astype('float32', copy=False)
                guess = ndarray.topk(as_float, k=1, ret_typ='indices',
                                     axis=self.axis)

            # Flatten both sides before the shape check so that inputs with
            # different layouts but the same element count line up.
            guess = guess.astype('int32', copy=False).reshape((-1, ))
            truth = truth.astype('int32', copy=False).reshape((-1, ))
            check_label_shapes(truth, guess)

            if sample_mask is None:
                correct = (guess == truth).sum().asscalar()
                count = len(truth)
            else:
                sample_mask = sample_mask.astype('int32', copy=False).reshape((-1, ))
                check_label_shapes(truth, sample_mask)
                # Masked-out samples contribute neither hits nor instances.
                correct = ((guess == truth) * sample_mask).sum().asscalar()
                count = sample_mask.sum().asscalar()

            self.sum_metric += correct
            self.global_sum_metric += correct
            self.num_inst += count
            self.global_num_inst += count
Esempio n. 2
0
def topk(input, k, dim, descending=True):
    """Select ``k`` values of ``input`` along ``dim`` via ``nd.topk``.

    Adapter mapping a ``(input, k, dim, descending)`` call onto ``nd.topk``.
    Only the selected values are returned (``ret_typ='value'``). The default
    ``descending=True`` maps to ``is_ascend=False``, which presumably selects
    the k largest entries.
    """
    ascending = not descending
    return nd.topk(input, k=k, axis=dim, is_ascend=ascending, ret_typ='value')
Esempio n. 3
0
    def update_masks(self, index, weight):
        """Recompute the pruning mask for one weight when the schedule requires it.

        The current epoch is derived from the update count this call would
        reach for ``index``; the count itself is not incremented. On an epoch
        change:

        * at epoch 1, the initial settings are logged;
        * one epoch after ``self.pruning_switch_epoch[0]``, the head entries of
          the switch-epoch list and of the active sparsity (or threshold)
          lists are dropped and the new settings are logged.

        Either event clears ``self.masks_updated``, which causes masks to be
        rebuilt.

        Parameters
        ----------
        index : int
            The index for weight.
        weight : NDArray
            The weight matrix. 1-D weights use the bias settings; every other
            shape uses the weight settings.

        Returns
        -------
        boolean
            True if ``self.masks[index]`` was recomputed by this call.
        """
        # Update number this call corresponds to, without incrementing the
        # per-index counter.
        if index not in self._index_update_count:
            num_update = self.begin_num_update
        else:
            num_update = self._index_update_count[index]
        num_update = max(num_update + 1, self.num_update)

        # 1-based epoch. Floor division keeps this exact; true division
        # followed by int() rounds through a float and can be off by one for
        # large update counts.
        epoch = int((num_update - 1) // self.batches_per_epoch) + 1

        # Index 0 starts a new pass over the parameters: assume masks are
        # current unless an epoch change below clears the flag. Later indices
        # of the same pass then see the cleared flag and rebuild as well.
        # NOTE(review): assumes index 0 is visited first in every step — confirm.
        if index == 0:
            self.masks_updated = True
        if self.epoch != epoch:
            self.epoch = epoch

            def _log_active_settings():
                # Log the sparsity/threshold pair now at the head of the schedule.
                if self.weight_sparsity is not None:
                    logging.info(
                        log + 'bias-sparsity={}, weight-sparsity={}'.format(
                            self.bias_sparsity[0], self.weight_sparsity[0]))
                else:
                    logging.info(
                        log + 'bias-threshold={}, weight-threshold={}'.format(
                            self.bias_threshold[0], self.weight_threshold[0]))

            if epoch == 1:
                self.masks_updated = False
                _log_active_settings()
            if self.pruning_switch_epoch[0] + 1 == epoch:
                # Advance the schedule to its next stage.
                self.masks_updated = False
                self.pruning_switch_epoch.pop(0)
                if self.weight_sparsity is not None:
                    self.weight_sparsity.pop(0)
                    self.bias_sparsity.pop(0)
                else:
                    self.weight_threshold.pop(0)
                    self.bias_threshold.pop(0)
                _log_active_settings()

        if not self.masks_updated:
            # During epoch 1 the mask list grows by one slot per call.
            # NOTE(review): relies on len(self.masks) == index here (indices
            # arriving as 0, 1, 2, ...) — confirm with the caller.
            if epoch == 1:
                self.masks.append(None)
            if self.weight_sparsity is not None:
                # Percentage-based pruning: keep the largest-magnitude entries.
                if len(weight.shape) == 1:
                    sparsity = self.bias_sparsity[0]
                else:
                    sparsity = self.weight_sparsity[0]
                number_unpruned = int((100.0 - sparsity) * weight.size / 100.0)
                if number_unpruned > 0:
                    # NOTE(review): `topk` must accept axis/ret_typ keywords
                    # (e.g. mxnet.ndarray.topk); axis=None ranks all entries.
                    self.masks[index] = topk(NDabs(weight),
                                             axis=None,
                                             ret_typ='mask',
                                             k=number_unpruned)
                else:
                    # Nothing survives. Passing k < 1 to mxnet's topk means
                    # "select everything", which would keep every entry, so an
                    # all-zero mask is built instead.
                    self.masks[index] = weight.zeros_like()
            else:
                # Threshold-based pruning: keep entries with |w| >= threshold.
                if len(weight.shape) == 1:
                    threshold = self.bias_threshold[0]
                else:
                    threshold = self.weight_threshold[0]
                self.masks[index] = NDabs(weight) >= threshold

        return not self.masks_updated
Esempio n. 4
0
    def update_masks(self, index, weight):
        """Recompute the pruning mask for one weight when the schedule requires it.

        The current epoch is derived from the update count this call would
        reach for ``index``; the count itself is not incremented. On an epoch
        change:

        * at epoch 1, the initial settings are logged;
        * one epoch after ``self.pruning_switch_epoch[0]``, the head entries of
          the switch-epoch list and of the active sparsity (or threshold)
          lists are dropped and the new settings are logged.

        Either event clears ``self.masks_updated``, which causes masks to be
        rebuilt.

        Parameters
        ----------
        index : int
            The index for weight.
        weight : NDArray
            The weight matrix. 1-D weights use the bias settings; every other
            shape uses the weight settings.

        Returns
        -------
        boolean
            True if ``self.masks[index]`` was recomputed by this call.
        """
        # Update number this call corresponds to, without incrementing the
        # per-index counter.
        if index not in self._index_update_count:
            num_update = self.begin_num_update
        else:
            num_update = self._index_update_count[index]
        num_update = max(num_update + 1, self.num_update)

        # 1-based epoch. Floor division keeps this exact; true division
        # followed by int() rounds through a float and can be off by one for
        # large update counts.
        epoch = int((num_update - 1) // self.batches_per_epoch) + 1

        # Index 0 starts a new pass over the parameters: assume masks are
        # current unless an epoch change below clears the flag. Later indices
        # of the same pass then see the cleared flag and rebuild as well.
        # NOTE(review): assumes index 0 is visited first in every step — confirm.
        if index == 0:
            self.masks_updated = True
        if self.epoch != epoch:
            self.epoch = epoch

            def _log_active_settings():
                # Log the sparsity/threshold pair now at the head of the schedule.
                if self.weight_sparsity is not None:
                    logging.info(log + 'bias-sparsity={}, weight-sparsity={}'.format(
                        self.bias_sparsity[0], self.weight_sparsity[0]))
                else:
                    logging.info(log + 'bias-threshold={}, weight-threshold={}'.format(
                        self.bias_threshold[0], self.weight_threshold[0]))

            if epoch == 1:
                self.masks_updated = False
                _log_active_settings()
            if self.pruning_switch_epoch[0] + 1 == epoch:
                # Advance the schedule to its next stage.
                self.masks_updated = False
                self.pruning_switch_epoch.pop(0)
                if self.weight_sparsity is not None:
                    self.weight_sparsity.pop(0)
                    self.bias_sparsity.pop(0)
                else:
                    self.weight_threshold.pop(0)
                    self.bias_threshold.pop(0)
                _log_active_settings()

        if not self.masks_updated:
            # During epoch 1 the mask list grows by one slot per call.
            # NOTE(review): relies on len(self.masks) == index here (indices
            # arriving as 0, 1, 2, ...) — confirm with the caller.
            if epoch == 1:
                self.masks.append(None)
            if self.weight_sparsity is not None:
                # Percentage-based pruning: keep the largest-magnitude entries.
                if len(weight.shape) == 1:
                    sparsity = self.bias_sparsity[0]
                else:
                    sparsity = self.weight_sparsity[0]
                number_unpruned = int((100.0 - sparsity) * weight.size / 100.0)
                if number_unpruned > 0:
                    # NOTE(review): `topk` must accept axis/ret_typ keywords
                    # (e.g. mxnet.ndarray.topk); axis=None ranks all entries.
                    self.masks[index] = topk(NDabs(weight), axis=None, ret_typ='mask',
                                             k=number_unpruned)
                else:
                    # Nothing survives. Passing k < 1 to mxnet's topk means
                    # "select everything", which would keep every entry, so an
                    # all-zero mask is built instead.
                    self.masks[index] = weight.zeros_like()
            else:
                # Threshold-based pruning: keep entries with |w| >= threshold.
                if len(weight.shape) == 1:
                    threshold = self.bias_threshold[0]
                else:
                    threshold = self.weight_threshold[0]
                self.masks[index] = NDabs(weight) >= threshold

        return not self.masks_updated