def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None):
    """
    Load checkpoint into net for distributed prediction.

    Args:
        network (Cell): Network for distributed prediction.
        checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
        predict_strategy (dict): Strategy of prediction process, whose key is parameter
            name, and value is a list or a tuple that the first four elements are
            [dev_matrix, tensor_map, param_split_shape, field]. If None, it means that
            the prediction process just uses a single device. Default: None.

    Raises:
        TypeError: The type of inputs do not match the requirements.
        ValueError: Failed to load checkpoint into net.
    """
    network = Validator.check_isinstance("network", network, nn.Cell)

    # Validate every checkpoint path up front: each entry must be an existing,
    # non-empty ".ckpt" file, otherwise fail fast with the offending index.
    for index, filename in enumerate(checkpoint_filenames):
        if not isinstance(filename, str) or not os.path.exists(filename) \
                or not filename.endswith(".ckpt") or os.path.getsize(filename) == 0:
            raise ValueError(
                f"Please make sure that the checkpoint_filenames at index {index} "
                f"is a valid checkpoint file.")

    if not _check_predict_strategy(predict_strategy):
        raise ValueError(
            f"Please make sure that the key of predict_strategy is str, "
            f"and the value is a list or a tuple that the first four elements are "
            f"dev_matrix (list[int]), tensor_map (list[int]), "
            f"param_split_shape (list[int]) and field_size (zero).")

    # Rebuild the training-time parallel strategy from the strategy file that was
    # saved during training.
    train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
    _train_strategy = build_searched_strategy(train_strategy_filename)
    train_strategy = _convert_to_list(_train_strategy)

    # The product of the first parameter's device matrix gives the number of
    # devices used in training; one checkpoint file per training rank is required.
    train_dev_count = 1
    for dim in train_strategy[list(train_strategy.keys())[0]][0]:
        train_dev_count *= dim
    if train_dev_count != len(checkpoint_filenames):
        raise ValueError(
            f"The length of checkpoint_filenames should be equal to the device count of training process. "
            f"The length is {len(checkpoint_filenames)} but the device count is {train_dev_count}.")

    # For every parameter, rank_list maps its name to (ranks holding a slice,
    # whether merge/split can be skipped because layouts already match).
    rank_list = _infer_rank_list(train_strategy, predict_strategy)

    param_dict = {}
    for _, param in network.parameters_and_names():
        sliced_params = []
        if param.name not in rank_list.keys():
            continue
        param_rank = rank_list[param.name][0]
        skip_merge_split = rank_list[param.name][1]
        # Gather this parameter's slices from each contributing rank's checkpoint.
        for rank in param_rank:
            sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
            sliced_params.append(sliced_param)
        if skip_merge_split:
            # Train/predict layouts agree: the single slice can be used as-is.
            split_param = sliced_params[0]
        else:
            # Otherwise merge the training slices and re-split them according to
            # the prediction strategy (repeated slices are dropped first).
            param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
            _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
            split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
        param_dict[param.name] = split_param

    load_param_into_net(network, param_dict)
def test_infer_rank_list1():
    """No predict strategy (single-device prediction): the weight must be
    reconstructed from all 8 training ranks, so merging cannot be skipped."""
    training_layout = {'weight': [[4, 8], [-1, 0]]}
    ranks, skip_merge = _infer_rank_list(training_layout, None)["weight"]
    assert list(ranks) == list(range(8))
    assert skip_merge is False
def test_infer_rank_list3():
    """Identical train and predict layouts: only rank 0's slice is needed and
    the merge/split step is skipped."""
    training_layout = {'weight': [[4, 8], [-1, 0]]}
    prediction_layout = {'weight': [[4, 8], [-1, 0]]}
    assert _infer_rank_list(training_layout, prediction_layout) == {'weight': ([0], True)}
def test_infer_rank_list5():
    """Unsplit training layout re-split for prediction: rank 0 suffices, but a
    real merge/split is still required because the layouts differ."""
    training_layout = {'weight': [[8], [-1, -1]]}
    prediction_layout = {'weight': [[2, 2], [1, 0]]}
    assert _infer_rank_list(training_layout, prediction_layout) == {'weight': ([0], False)}