Example #1
    def get_broadcast_vars_and_param_usage(self, block):
        broadcast_vars = set()
        fp16_params = set()
        fp16_to_fp32 = {}

        # Count how many non-optimizer ops read each parameter directly.
        param_usage = {x: 0 for x in self.param_names}
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if input_name in self.param_names:
                    param_usage[input_name] += 1

        # Each fp16 cast of a parameter yields a cast output that is broadcast
        # in place of the fp32 master, so that read no longer counts.
        for op in block.ops:
            if not _is_param_fp16_cast_op(block, op, self.param_names):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            fp16_params.add(output_name)
            fp16_to_fp32[output_name] = input_name
            param_usage[input_name] -= 1
            self.param_to_rank[output_name] = self.param_to_rank[input_name]

        # Parameters that are still read outside of their fp16 casts must be
        # broadcast themselves.
        for param, usage in param_usage.items():
            if usage > 0:
                broadcast_vars.add(param)
        return broadcast_vars, param_usage
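To see what the bookkeeping above computes without building a Paddle program, here is a toy, self-contained sketch of the same counting logic. The op dicts, the "cast" convention and the parameter names are made-up stand-ins, not the real `block.ops` API:

# Toy sketch: ops are plain dicts; a "cast" op turns a parameter into its
# fp16 copy, every other op just reads its inputs.
def toy_broadcast_vars(ops, param_names):
    param_usage = {p: 0 for p in param_names}
    broadcast_vars = set()

    # 1) Count every direct read of a parameter.
    for op in ops:
        for name in op["inputs"]:
            if name in param_names:
                param_usage[name] += 1

    # 2) The output of a parameter's fp16 cast is broadcast instead of the
    #    fp32 master, so that read no longer counts.
    for op in ops:
        if op["type"] == "cast" and op["inputs"][0] in param_names:
            broadcast_vars.add(op["outputs"][0])
            param_usage[op["inputs"][0]] -= 1

    # 3) Parameters still read outside of casts are broadcast directly.
    for param, usage in param_usage.items():
        if usage > 0:
            broadcast_vars.add(param)
    return broadcast_vars, param_usage


ops = [
    {"type": "cast", "inputs": ["w0"], "outputs": ["w0.cast_fp16"]},
    {"type": "matmul", "inputs": ["x", "w0.cast_fp16"], "outputs": ["y"]},
    {"type": "elementwise_add", "inputs": ["y", "b0"], "outputs": ["z"]},
]
print(toy_broadcast_vars(ops, {"w0", "b0"}))
# -> ({'w0.cast_fp16', 'b0'}, {'w0': 0, 'b0': 1})   (set order may vary)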
Example #2
    def find_broadcast_params(self, block):
        broadcast_vars = set()
        fp16_params = set()
        fp16_to_fp32 = {}

        param_usage = {x: 0 for x in self.global_params}
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if input_name in self.global_params:
                    param_usage[input_name] += 1

        for op in block.ops:
            if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            fp16_params.add(output_name)
            fp16_to_fp32[output_name] = input_name
            param_usage[input_name] -= 1
            self.global_param2device[output_name] = self.global_param2device[
                input_name]

        for param, usage in param_usage.items():
            if usage > 0:
                broadcast_vars.add(param)
        return broadcast_vars
Example #3
    def _shard_optimizer_ops_and_states(self, main_block, startup_block):

        # Optimizer-state variables owned by other ranks; they are removed
        # from both blocks below.
        should_remove_optimizer_states = []
        # Optimizer ops sit at the tail of the main block, so walk it in
        # reverse and stop at the first non-optimizer op.
        for idx, op in reversed(list(enumerate(main_block.ops))):
            if not is_optimizer_op(op):
                break

            if op.type in _supported_optimizer_type:
                assert "Param" in op.input_names
                assert len(op.input("Param")) == 1
                param_name = op.input("Param")[0]
                if not self._is_parameter_in_local_shard(param_name):
                    should_remove_optimizer_states.extend([
                        varname for varname in op.output_arg_names
                        if varname != param_name
                    ])
                    main_block._remove_op(idx, sync=False)

        for idx, op in reversed(list(enumerate(startup_block.ops))):
            if (len(op.output_arg_names) == 1 and
                    op.output_arg_names[0] in should_remove_optimizer_states):
                startup_block._remove_op(idx, sync=False)

        for varname in should_remove_optimizer_states:
            if main_block.has_var(varname):
                main_block._remove_var(varname, sync=False)
            if startup_block.has_var(varname):
                startup_block._remove_var(varname, sync=False)

        main_block._sync_with_cpp()
        startup_block._sync_with_cpp()
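Both removal loops above rely on the same pattern: enumerate the ops, walk them in reverse, and delete by index. Iterating backwards keeps every not-yet-visited index valid after each removal. A minimal stand-alone illustration with a plain list instead of a Paddle block (all names are invented):

ops = ["forward", "backward", "adam_w1", "adam_w2"]

# Deleting from the back means earlier indices are unaffected by each removal,
# so `idx` keeps pointing at the op it was enumerated with.
for idx, op in reversed(list(enumerate(ops))):
    if not op.startswith("adam"):
        break                 # optimizer ops sit at the tail; stop at the rest
    if op == "adam_w2":       # pretend w2 is not in the local shard
        del ops[idx]

print(ops)  # ['forward', 'backward', 'adam_w1']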
Example #4
def _is_param_fp16_cast_op(block, op, params):
    # True iff `op` is a non-optimizer cast op whose input is one of `params`.
    if is_optimizer_op(op):
        return False
    if not _is_desired_cast_op(block, op):
        return False
    input_name = op.desc.input_arg_names()[0]
    if input_name not in params:
        return False
    return True
Example #5
def is_fp32_cast_op(block, op):
    # An optimizer-side cast whose single input is fp16 and whose single
    # output is fp32 (i.e. an fp16 -> fp32 promotion inside the optimizer).
    if op.type != "cast":
        return False
    if not is_optimizer_op(op):
        return False
    assert len(op.desc.input_arg_names()) == 1
    assert len(op.desc.output_arg_names()) == 1
    input_name = op.desc.input_arg_names()[0]
    output_name = op.desc.output_arg_names()[0]
    input_var = block.var(input_name)
    output_var = block.var(output_name)
    if input_var.dtype != core.VarDesc.VarType.FP16 or \
            output_var.dtype != core.VarDesc.VarType.FP32:
        return False
    return True
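The dtype comparison is what distinguishes this cast from the parameter fp16 cast in the previous example. A toy version of just that check, with a namedtuple standing in for the block's variables and plain strings standing in for `core.VarDesc.VarType` values:

from collections import namedtuple

Var = namedtuple("Var", "name dtype")
FP16, FP32 = "float16", "float32"   # stand-ins for core.VarDesc.VarType values


def toy_is_fp32_cast(cast_input, cast_output):
    # Keep only casts that promote an fp16 value back to fp32.
    return cast_input.dtype == FP16 and cast_output.dtype == FP32


print(toy_is_fp32_cast(Var("loss@fp16", FP16), Var("loss", FP32)))   # True
print(toy_is_fp32_cast(Var("w0", FP32), Var("w0.cast_fp16", FP16)))  # False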
Example #6
def fuse_opt_broadcast_param_ops(block,
                                 ring_id,
                                 shard,
                                 op_role=OpRole.Optimize,
                                 strategy=None):
    """
    fuse optimizer sharding broadcast param ops
    """
    if strategy is None or not strategy.fuse_all_reduce_ops:
        return

    fuse_size = strategy.fuse_grad_size_in_MB

    nranks = shard.worker_num
    device_to_vars = [[] for _ in range(nranks)]

    # Collect the trailing per-parameter c_broadcast ops, remember which root
    # rank each variable belongs to, and remove them; fused broadcasts are
    # re-inserted below.
    for idx, op in reversed(list(enumerate(block.ops))):
        if not is_optimizer_op(op) or op.type != 'c_broadcast':
            break
        var = op.input_arg_names[0]
        root_id = op.attr('root')
        device_to_vars[root_id].insert(0, var)
        block._remove_op(idx, sync=False)

    # `idx` now points at the last op that is not a parameter broadcast, so
    # the fused broadcasts go right after it.
    insert_idx = idx + 1
    for root_id, vars_name in enumerate(device_to_vars):
        vars_name = FuseHelper.sort_vars_by_dtype(block, vars_name)
        groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size)

        fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
            block, insert_idx, groups, op_role, prefix="Param")

        for fused_var in fused_vars:
            block._insert_op_without_sync(insert_idx + insert_num,
                                          type='c_broadcast',
                                          inputs={'X': fused_var},
                                          outputs={'Out': fused_var},
                                          attrs={
                                              'ring_id': ring_id,
                                              'root': root_id,
                                              'use_calc_stream': True,
                                              OP_ROLE_KEY: op_role
                                          })

    block._sync_with_cpp()
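The effectiveness of the fusion comes from grouping variables under the `fuse_grad_size_in_MB` budget before coalescing. The real grouping is done by `FuseHelper.get_fused_groups` from variable shapes and dtypes; the sketch below only shows the greedy bucketing idea, with made-up sizes:

def toy_get_fused_groups(var_sizes_mb, fuse_size_mb):
    """Greedily pack variables into groups of at most fuse_size_mb each."""
    groups, current, current_size = [], [], 0.0
    for name, size in var_sizes_mb:
        if current and current_size + size > fuse_size_mb:
            groups.append(current)
            current, current_size = [], 0.0
        current.append(name)
        current_size += size
    if current:
        groups.append(current)
    return groups


sizes = [("w0", 12.0), ("w1", 20.0), ("b0", 0.1), ("w2", 31.9)]
print(toy_get_fused_groups(sizes, fuse_size_mb=32))
# [['w0', 'w1'], ['b0', 'w2']]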
Example #7
    def _shard_parameter(self, main_block, startup_block):

        if self.stage < 3:
            return

        dp_ring_ids = [group.id for group in self.dp_groups]
        for sharding_info in self.sharding_infos:
            need_broadcast_vars, param_usage = sharding_info.get_broadcast_vars_and_param_usage(
                main_block)
            not_used_param_names = []
            for param_name in param_usage:
                if param_usage[param_name] == 0 and sharding_info.get_var_rank(
                        param_name) != sharding_info.local_rank:
                    not_used_param_names.append(param_name)

            # For every non-optimizer use of a broadcast variable: the owning
            # rank broadcasts the parameter itself, other ranks receive into a
            # freshly created "@BroadCast" buffer and the op input is renamed.
            for idx, op in reversed(list(enumerate(main_block.ops))):
                if is_optimizer_op(op):
                    continue

                for input_name in op.desc.input_arg_names():
                    if op.type == "cast":
                        continue
                    if input_name not in need_broadcast_vars:
                        continue
                    root_rank = sharding_info.get_var_rank(input_name)
                    if root_rank == sharding_info.local_rank:
                        broadcast_varname = input_name
                    else:
                        broadcast_varname = unique_name.generate(input_name +
                                                                 "@BroadCast")
                        input_var = main_block.var(input_name)
                        new_var = main_block.create_var(name=broadcast_varname,
                                                        shape=input_var.shape,
                                                        dtype=input_var.dtype,
                                                        persistable=False)
                        ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                            input_var)
                        out_var_dist_attr = set_var_dist_attr(
                            self._dist_context, new_var,
                            ref_dist_attr.dims_mapping,
                            ref_dist_attr.process_mesh)
                        op._rename_input(input_name, broadcast_varname)

                    _insert_init_and_broadcast_op(main_block, idx,
                                                  broadcast_varname,
                                                  sharding_info.local_rank,
                                                  root_rank,
                                                  sharding_info.group.id,
                                                  op.attr('op_role'),
                                                  self._dist_context)

            # Remove fp16 casts of parameters that are neither used locally
            # nor owned by this rank, together with their cast outputs.
            for idx, op in reversed(list(enumerate(main_block.ops))):
                if op.type != "cast":
                    continue
                input_name = op.input_arg_names[0]
                output_name = op.output_arg_names[0]
                if input_name in not_used_param_names:
                    main_block._remove_op(idx, sync=False)
                    main_block._remove_var(output_name, sync=False)

            for idx, op in reversed(list(enumerate(startup_block.ops))):
                assert len(op.output_arg_names) == 1
                output_name = op.output_arg_names[0]

                if op.type == "c_broadcast" and op.attr(
                        "ring_id") in dp_ring_ids:
                    if self.outer_dp_group and sharding_info.get_var_rank(
                            output_name) == sharding_info.local_rank:
                        op._set_attr("ring_id", self.outer_dp_group.id)
                    else:
                        startup_block._remove_op(idx, sync=False)
                    continue

                if op.type != "c_broadcast" and output_name in param_usage and sharding_info.get_var_rank(
                        output_name) != sharding_info.local_rank:
                    startup_block._remove_op(idx, sync=False)

            # Finally, drop the fp32 parameters this rank does not own from
            # both blocks.
            for param_name in param_usage:
                if sharding_info.get_var_rank(
                        param_name) != sharding_info.local_rank:
                    main_block._remove_var(param_name, sync=False)
                    startup_block._remove_var(param_name, sync=False)

        main_block._sync_with_cpp()
        startup_block._sync_with_cpp()
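Stripped of the Paddle plumbing, the core of the loop over `main_block.ops` is a per-parameter ownership decision: the owning rank broadcasts its own variable, every other rank reads from a temporary buffer that the inserted broadcast fills. A toy version of that branch, with a hand-written rank table instead of `sharding_info` (all names are illustrative):

def toy_plan_broadcasts(used_params, param_to_rank, local_rank):
    """Map each used parameter to (variable to read, needs a temp buffer)."""
    plan = {}
    for param in used_params:
        if param_to_rank[param] == local_rank:
            # This rank owns the parameter: broadcast the variable itself.
            plan[param] = (param, False)
        else:
            # Another rank owns it: read from a temporary "@BroadCast" buffer
            # that the inserted broadcast op fills before the first use.
            plan[param] = (param + "@BroadCast", True)
    return plan


param_to_rank = {"w0": 0, "w1": 1, "b0": 0}
print(toy_plan_broadcasts(["w0", "w1", "b0"], param_to_rank, local_rank=1))
# {'w0': ('w0@BroadCast', True), 'w1': ('w1', False), 'b0': ('b0@BroadCast', True)}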