Example #1
0
def load_distributed_checkpoint(network,
                                checkpoint_filenames,
                                predict_strategy=None):
    """
    Load checkpoint into net for distributed prediction.

    For every parameter of ``network`` that appears in the rank list inferred from the
    training strategy (read from the ``strategy_ckpt_load_file`` auto-parallel context) and
    ``predict_strategy``, the slices are loaded from the checkpoint files of the listed ranks.
    They are then either used as-is or merged and re-split for the prediction layout, and
    the results are loaded into ``network`` with ``load_param_into_net``.

    Args:
        network (Cell): Network for distributed prediction.
        checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
        predict_strategy (dict): Strategy of prediction process, whose key is parameter name, and value is a list or
            a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
            it means that the prediction process just uses single device. Default: None.

    Raises:
        TypeError: The type of inputs do not match the requirements.
        ValueError: Failed to load checkpoint into net. This happens when a checkpoint file is
            missing, empty or not a ".ckpt" file, when predict_strategy is malformed, when the
            training strategy contains no entries, or when the number of checkpoint files differs
            from the training device count.
    """
    network = Validator.check_isinstance("network", network, nn.Cell)

    # Each entry must be an existing, non-empty ".ckpt" file. isinstance is tested first so
    # the os.path calls never receive a non-str value.
    for index, filename in enumerate(checkpoint_filenames):
        if not isinstance(filename, str) or not os.path.exists(filename) \
                or not filename.endswith(".ckpt") or os.path.getsize(filename) == 0:
            raise ValueError(
                f"Please make sure that the {filename} at index {index} is a valid checkpoint file."
            )

    if not _check_predict_strategy(predict_strategy):
        raise ValueError(
            f"Please make sure that the key of predict_strategy is str, "
            f"and the value is a list or a tuple that the first four elements are "
            f"dev_matrix (list[int]), tensor_map (list[int]), "
            f"param_split_shape (list[int]) and field_size (zero).")

    train_strategy_filename = context.get_auto_parallel_context(
        "strategy_ckpt_load_file")
    _train_strategy = build_searched_strategy(train_strategy_filename)
    train_strategy = _convert_to_list(_train_strategy)
    # Guard against an empty strategy. Without this check, the device-count lookup below
    # would fail with an unhelpful IndexError.
    if not train_strategy:
        raise ValueError(
            f"The training strategy loaded from {train_strategy_filename} does not contain any "
            f"parameter strategy."
        )

    # The product of the first entry's dev_matrix (element 0 of the strategy value) is used
    # as the training device count.
    # NOTE(review): this assumes every parameter shares the same device matrix; confirm.
    train_dev_count = 1
    for dim in next(iter(train_strategy.values()))[0]:
        train_dev_count *= dim
    if train_dev_count != len(checkpoint_filenames):
        raise ValueError(
            f"The length of checkpoint_filenames should be equal to the device count of training process. "
            f"The length is {len(checkpoint_filenames)} but the device count is {train_dev_count}."
        )

    # Maps parameter name -> (ranks whose checkpoints hold the needed slices, skip_merge_split flag).
    rank_list = _infer_rank_list(train_strategy, predict_strategy)

    param_dict = {}
    for _, param in network.parameters_and_names():
        if param.name not in rank_list:
            continue
        param_rank = rank_list[param.name][0]
        skip_merge_split = rank_list[param.name][1]
        sliced_params = [_load_single_param(checkpoint_filenames[rank], param.name)
                         for rank in param_rank]
        if skip_merge_split:
            # No re-layout is required: the first loaded slice is used as-is.
            split_param = sliced_params[0]
        else:
            # Drop repeated slices from the training layout, then merge the loaded slices
            # and re-split them according to predict_strategy.
            param_unique_strategy = _remove_repeated_slices(
                train_strategy[param.name])
            _param_unique_strategy = _convert_to_layout(
                param.name, param_unique_strategy)
            split_param = _merge_and_split(sliced_params,
                                           _param_unique_strategy,
                                           predict_strategy)
        param_dict[param.name] = split_param

    load_param_into_net(network, param_dict)
def test_infer_rank_list1():
    """With no predict strategy, ranks 0..7 are required and merge/split is not skipped."""
    strategy_for_train = {'weight': [[4, 8], [-1, 0]]}
    inferred = _infer_rank_list(strategy_for_train, None)["weight"]
    ranks, skip_flag = inferred[0], inferred[1]
    assert list(ranks) == list(range(8))
    assert skip_flag is False
def test_infer_rank_list3():
    """Identical train and predict strategies need only rank 0 and skip merge/split."""
    strategy_for_train = {'weight': [[4, 8], [-1, 0]]}
    strategy_for_predict = {'weight': [[4, 8], [-1, 0]]}
    assert _infer_rank_list(strategy_for_train, strategy_for_predict) == {'weight': ([0], True)}
def test_infer_rank_list5():
    """A fully replicated training weight needs only rank 0 but still requires merge/split."""
    strategy_for_train = {'weight': [[8], [-1, -1]]}
    strategy_for_predict = {'weight': [[2, 2], [1, 0]]}
    assert _infer_rank_list(strategy_for_train, strategy_for_predict) == {'weight': ([0], False)}