Пример #1
0
 def __init__(self,
              in_channels,
              out_channels,
              kernel_size,
              stride=1,
              pad_mode='same',
              padding=0,
              dilation=1,
              group=1,
              has_bias=False,
              weight_init='normal',
              bias_init='zeros',
              batchnorm=None,
              activation=None):
     """Build the convolution plus the optional batchnorm / activation stages.

     The convolution arguments are forwarded positionally to ``conv.Conv2d``.

     ``batchnorm`` may be ``None`` (stored as-is), ``True`` (a new
     ``BatchNorm2d(out_channels)`` is created) or any other value, which is
     passed to ``validator.check_isinstance`` against ``BatchNorm2d`` and
     then stored unchanged. ``activation`` is handed to ``get_activation``.
     ``has_bn`` / ``has_act`` record whether the respective argument was
     something other than ``None``.
     """
     super(Conv2dBnAct, self).__init__()
     conv_args = (in_channels, out_channels, kernel_size, stride, pad_mode,
                  padding, dilation, group, has_bias, weight_init, bias_init)
     self.conv = conv.Conv2d(*conv_args)

     # Flags only reflect "was something passed", not what was passed.
     self.has_act = activation is not None
     self.has_bn = batchnorm is not None

     # Resolve the normalization layer first, then bind it exactly once.
     bn_layer = batchnorm
     if batchnorm is True:
         bn_layer = BatchNorm2d(out_channels)
     elif batchnorm is not None:
         validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d, ))
     self.batchnorm = bn_layer

     self.activation = get_activation(activation)
Пример #2
0
    def __init__(self,
                 save_checkpoint_steps=1,
                 save_checkpoint_seconds=0,
                 keep_checkpoint_max=5,
                 keep_checkpoint_per_n_minutes=0,
                 integrated_save=True,
                 async_save=False,
                 saved_network=None,
                 enc_key=None,
                 enc_mode='AES-GCM'):
        """Validate and store the checkpoint-saving policy.

        Args:
            save_checkpoint_steps: None, or a value accepted by
                ``Validator.check_non_negative_int``. Default: 1.
            save_checkpoint_seconds: Same validation as above. Default: 0.
            keep_checkpoint_max: Same validation as above. Default: 5.
            keep_checkpoint_per_n_minutes: Same validation as above.
                Default: 0.
            integrated_save: Checked with ``Validator.check_bool``.
            async_save: Checked with ``Validator.check_bool``.
            saved_network: None or an ``nn.Cell``.
            enc_key: None or ``bytes``.
            enc_mode: A ``str``. Default: 'AES-GCM'.

        Raises:
            TypeError: If ``saved_network`` is neither None nor ``nn.Cell``.
            ValueError: If all four step/second/max/minute settings are
                None or 0.

        Stored values follow two precedence rules: a positive
        ``save_checkpoint_steps`` causes the seconds setting to be stored as
        None, and a positive ``keep_checkpoint_max`` causes the per-n-minutes
        setting to be stored as None. If neither retention setting is usable,
        ``keep_checkpoint_max`` falls back to 1.
        """

        def _non_negative_or_none(value):
            # None means "not configured" and bypasses the integer check.
            if value is None:
                return None
            return Validator.check_non_negative_int(value)

        save_checkpoint_steps = _non_negative_or_none(save_checkpoint_steps)
        save_checkpoint_seconds = _non_negative_or_none(save_checkpoint_seconds)
        keep_checkpoint_max = _non_negative_or_none(keep_checkpoint_max)
        keep_checkpoint_per_n_minutes = _non_negative_or_none(
            keep_checkpoint_per_n_minutes)

        if not (saved_network is None or isinstance(saved_network, nn.Cell)):
            raise TypeError(
                f"The type of saved_network must be None or Cell, but got {str(type(saved_network))}."
            )

        # At least one of the four scheduling/retention knobs must be set.
        if not (save_checkpoint_steps or save_checkpoint_seconds
                or keep_checkpoint_max or keep_checkpoint_per_n_minutes):
            raise ValueError("The input_param can't be all None or 0")

        # Step-based saving takes precedence over time-based saving.
        if save_checkpoint_steps and save_checkpoint_steps > 0:
            save_checkpoint_seconds = None
        self._save_checkpoint_steps = save_checkpoint_steps
        self._save_checkpoint_seconds = save_checkpoint_seconds

        # Count-based retention takes precedence over time-based retention;
        # with neither configured, keep a single checkpoint.
        if keep_checkpoint_max and keep_checkpoint_max > 0:
            keep_checkpoint_per_n_minutes = None
        elif not keep_checkpoint_per_n_minutes:
            keep_checkpoint_max = 1
        self._keep_checkpoint_max = keep_checkpoint_max
        self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes

        self._integrated_save = Validator.check_bool(integrated_save)
        self._async_save = Validator.check_bool(async_save)
        self._saved_network = saved_network
        self._enc_key = Validator.check_isinstance(
            'enc_key', enc_key, (type(None), bytes))
        self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
Пример #3
0
def _check_predict_strategy(predict_strategy):
    """Check predict strategy."""
    def _check_int_list(arg):
        if not isinstance(arg, list):
            return False
        for item in arg:
            if not isinstance(item, int):
                return False
        return True

    if predict_strategy is None:
        return True

    predict_strategy = Validator.check_isinstance("predict_strategy",
                                                  predict_strategy, dict)
    for key in predict_strategy.keys():
        if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
                or len(predict_strategy[key]) < 4:
            return False
        dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[
            key][:4]
        if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
                not (_check_int_list(param_split_shape) or not param_split_shape) or \
                not (isinstance(field_size, int) and field_size == 0):
            return False
    return True
Пример #4
0
 def __init__(self,
              in_channels,
              out_channels,
              weight_init='normal',
              bias_init='zeros',
              has_bias=True,
              batchnorm=None,
              activation=None):
     """Build the dense layer plus the optional batchnorm / activation stages.

     The first five arguments are forwarded positionally to ``basic.Dense``.

     ``batchnorm`` may be ``None`` (stored as-is), ``True`` (a new
     ``BatchNorm2d(out_channels)`` is created) or any other value, which is
     passed to ``validator.check_isinstance`` against ``BatchNorm2d`` and
     then stored unchanged. ``activation`` is handed to ``get_activation``.
     ``has_bn`` / ``has_act`` record whether the respective argument was
     something other than ``None``.
     """
     super(DenseBnAct, self).__init__()
     self.dense = basic.Dense(in_channels, out_channels, weight_init,
                              bias_init, has_bias)
     self.has_bn = batchnorm is not None
     self.has_act = activation is not None
     # Always bind the attribute so it exists for every accepted input:
     # without this, passing None or a ready-made BatchNorm2d instance left
     # ``self.batchnorm`` undefined even though ``has_bn`` could be True.
     self.batchnorm = batchnorm
     if batchnorm is True:
         self.batchnorm = BatchNorm2d(out_channels)
     elif batchnorm is not None:
         validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d, ))
     self.activation = get_activation(activation)
Пример #5
0
def query_quant_layers(network):
    r"""
    Print the quantization data type of each fake-quant observer cell in a network.

    Every cell yielded by ``network.cells_and_names()`` that is an instance of
    ``nn.FakeQuantWithMinMaxObserver`` is printed on its own line as the cell
    name followed by its ``quant_dtype``.

    Note:
        In graph mode the layers are queried before graph-compile
        optimization, so some redundant quantized layers that do not exist in
        actual execution may be listed.

    Args:
        network (Cell): Network to inspect.

    Returns:
        None
    """
    network = Validator.check_isinstance("network", network, nn.Cell)
    row_format = "{0:60}\t{1:10}"
    for entry in network.cells_and_names():
        name, cell = entry[0], entry[1]
        if not isinstance(cell, nn.FakeQuantWithMinMaxObserver):
            continue
        print(row_format.format(name, cell.quant_dtype))
Пример #6
0
def load_distributed_checkpoint(network,
                                checkpoint_filenames,
                                predict_strategy=None):
    """
    Load sliced checkpoint files into a network for distributed prediction.

    The training strategy is read from the file configured under the
    ``strategy_ckpt_load_file`` auto-parallel context entry. For every network
    parameter that appears in the inferred rank list, the required slices are
    loaded from the per-rank checkpoint files and, unless the rank list marks
    the parameter to skip it, merged and re-split before being loaded into
    the network.

    Args:
        network (Cell): Network for distributed prediction.
        checkpoint_filenames (list[str]): Checkpoint file names in order of rank id.
        predict_strategy (dict): Strategy of the prediction process, whose key is the parameter name and whose
            value is a list or tuple whose first four elements are
            [dev_matrix, tensor_map, param_split_shape, field]. If None, the prediction process uses a
            single device. Default: None.

    Raises:
        TypeError: The type of inputs do not match the requirements.
        ValueError: A checkpoint file is invalid, ``predict_strategy`` is malformed, the number of
            checkpoint files differs from the training device count, or loading into the net failed.
    """
    network = Validator.check_isinstance("network", network, nn.Cell)

    # Each entry must be an existing, non-empty file whose name ends in ".ckpt".
    for index, filename in enumerate(checkpoint_filenames):
        is_valid_ckpt = isinstance(filename, str) and os.path.exists(filename) \
            and filename.endswith(".ckpt") and os.path.getsize(filename) != 0
        if not is_valid_ckpt:
            raise ValueError(
                f"Please make sure that the {filename} at index {index} is a valid checkpoint file."
            )

    strategy_ok = _check_predict_strategy(predict_strategy)
    if not strategy_ok:
        raise ValueError(
            f"Please make sure that the key of predict_strategy is str, "
            f"and the value is a list or a tuple that the first four elements are "
            f"dev_matrix (list[int]), tensor_map (list[int]), "
            f"param_split_shape (list[int]) and field_size (zero).")

    strategy_path = context.get_auto_parallel_context("strategy_ckpt_load_file")
    train_strategy = _convert_to_list(build_searched_strategy(strategy_path))

    # Device count = product of the first element of the first parameter's strategy entry.
    first_key = list(train_strategy.keys())[0]
    train_dev_count = 1
    for dim in train_strategy[first_key][0]:
        train_dev_count *= dim
    if train_dev_count != len(checkpoint_filenames):
        raise ValueError(
            f"The length of checkpoint_filenames should be equal to the device count of training process. "
            f"The length is {len(checkpoint_filenames)} but the device count is {train_dev_count}."
        )

    rank_list = _infer_rank_list(train_strategy, predict_strategy)

    param_dict = {}
    for _, param in network.parameters_and_names():
        # Parameters absent from the rank list are left untouched.
        if param.name not in rank_list:
            continue

        rank_entry = rank_list[param.name]
        param_rank = rank_entry[0]
        skip_merge_split = rank_entry[1]
        sliced_params = [
            _load_single_param(checkpoint_filenames[rank], param.name)
            for rank in param_rank
        ]

        if skip_merge_split:
            param_dict[param.name] = sliced_params[0]
            continue

        unique_strategy = _remove_repeated_slices(train_strategy[param.name])
        unique_layout = _convert_to_layout(param.name, unique_strategy)
        param_dict[param.name] = _merge_and_split(sliced_params,
                                                  unique_layout,
                                                  predict_strategy)

    load_param_into_net(network, param_dict)