Example #1
    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True,
                 dp_process_group=None,
                 partition_size=None,
                 mpu=None,
                 all_gather_partitions=True,
                 allgather_size=500000000,
                 clip_grad=0.0):

        if dp_process_group is not None and partition_size is not None:
            raise ValueError("Cannot specify both dp_process_group "
                             "and partition size")

        if dp_process_group is None:
            dp_process_group = _initialize_parameter_parallel_groups(
                partition_size)

        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")
        self.optimizer = init_optimizer

        self.verbose = verbose
        self.dp_process_group = dp_process_group

        # TODO: automatically turn off if #params > some_limit
        self.all_gather_partitions = all_gather_partitions
        self.allgather_size = allgather_size

        # param flattened by groups
        self.fp16_groups = []
        self.fp16_groups_flat = []

        #param partitioned by data parallel degree
        #this will contain a list of equal sized tensors
        #each of which will be updated by a different process
        self.parallel_partitioned_fp16_groups = []

        #a single 32-bit partition of the parallel partitioned parameters
        #that this process will update
        self.single_partition_of_fp32_groups = []

        #param partition info

        #These are the parameters in each group that will not be updated by this process directly
        self.params_not_in_partition = []

        #These are the parameters that will be updated by this process directly
        self.params_in_partition = []

        #Offset from the first parameter in self.params_in_partition
        #the parameter boundaries may not align with partition boundaries
        #so we need to keep track of the offset
        self.first_offset = []

        #number of elements per partition in each group
        self.partition_size = []

        partition_id = dist.get_rank(group=self.dp_process_group)

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to the list before we modify it
            self.fp16_groups.append(param_group['params'])

            self.fp16_groups_flat.append(
                flatten_dense_tensors_aligned(
                    self.fp16_groups[i],
                    dist.get_world_size(group=self.dp_process_group),
                    self.dp_process_group))

            # set model fp16 weight to slices of flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

            #divide the flat weights into near-equal partitions, one per data-parallel rank
            #each process will compute on a different part of the partition
            data_parallel_partitions = self.get_data_parallel_partitions(
                self.fp16_groups_flat[i])
            self.parallel_partitioned_fp16_groups.append(
                data_parallel_partitions)

            # a partition of the fp32 master weights that will be updated by this process
            self.single_partition_of_fp32_groups.append(
                self.parallel_partitioned_fp16_groups[i]
                [partition_id].clone().float().detach())

            # modify the optimizer to use the flat master weights
            self.single_partition_of_fp32_groups[
                i].requires_grad = True  # keep this in case internal optimizer uses it
            param_group['params'] = [self.single_partition_of_fp32_groups[i]]

            partition_size = len(
                self.fp16_groups_flat[i]) / dist.get_world_size(
                    group=self.dp_process_group)
            params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
                self.fp16_groups[i], partition_size, partition_id)

            self.partition_size.append(partition_size)
            self.params_in_partition.append(params_in_partition)
            self.params_not_in_partition.append(params_not_in_partition)
            self.first_offset.append(first_offset)

        # we may have a way of fusing the dynamic loss scale; not supported for now
        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            if dynamic_loss_args is None:
                self.loss_scaler = DynamicLossScaler()
            else:
                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
        else:
            self.dynamic_loss_scale = False
            self.loss_scaler = LossScaler(scale=static_loss_scale)
            self.cur_iter = 0

        self.mpu = mpu
        self.clip_grad = clip_grad

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
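A minimal usage sketch for the constructor above (assumptions: `torch.distributed` and CUDA are already initialized, the `model`, batch, and Adam base optimizer are hypothetical stand-ins, and the wrapper is constructed under the class name `FP16_DeepSpeedZeroOptimizer` shown in the source further below):

import torch

# Hypothetical fp16 model and base optimizer (not part of the source above).
model = torch.nn.Linear(1024, 1024).cuda().half()
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Wrap the base optimizer; when dp_process_group is not supplied, one is
# created from partition_size inside the constructor.
optimizer = FP16_DeepSpeedZeroOptimizer(base_optimizer,
                                        dynamic_loss_scale=True,
                                        clip_grad=1.0)

inputs = torch.randn(8, 1024).cuda().half()
loss = model(inputs).float().sum()

optimizer.zero_grad()
optimizer.backward(loss)     # scales the loss before calling backward
overflow = optimizer.step()  # skips the update if gradients overflowed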
Example #2
    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True,
                 dp_process_group=None,
                 partition_size=None,
                 mpu=None,
                 all_gather_partitions=True,
                 allgather_size=500000000,
                 clip_grad=0.0,
                 max_elements_per_comm=5e8):

        if dp_process_group is not None and partition_size is not None:
            raise ValueError("Cannot specify both dp_process_group "
                             "and partition size")

        if dp_process_group is None:
            dp_process_group = _initialize_parameter_parallel_groups(
                partition_size)

        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")
        self.optimizer = init_optimizer

        self.verbose = verbose
        self.dp_process_group = dp_process_group

        # TODO: automatically turn off if #params > some_limit
        self.all_gather_partitions = all_gather_partitions
        self.allgather_size = allgather_size

        self.max_elements_per_comm = max_elements_per_comm
        print("max_elements_per_comm={}".format(max_elements_per_comm))

        # param flattened by groups
        self.fp16_groups = []
        self.fp16_groups_flat = []

        # Setup bookkeeping data structures depending on partitioning type

        # parallel_sub_partitioned_fp16_groups[group-idx] -> [rank-ids] -> [comm-ids]
        self.parallel_sub_partitioned_fp16_groups = []
        # same underlying data as above but viewed as: [groups] -> [comm-ids] -> [rank-ids]
        self.parallel_comm_sub_partitioned_fp16_groups = []

        # 32-bit sub-partitions of the parallel partitioned parameters
        # that this process will update
        self.local_sub_partitions_of_fp32_groups = []

        # param partition info

        # parameters in each group that will not be updated by this process directly
        self.params_not_local = []

        # parameters that will be updated by this process directly
        self.params_in_rank_sub_partitions = []

        # parameter offsets for parameters in sub-partitions. Parameter
        # boundaries may not align with sub-partition boundaries
        # so we need to keep track of the offsets
        self.params_in_rank_sub_partitions_offsets = []

        # number of elements per sub-partition in each group
        self.sub_partition_sizes = []

        # number of communication intervals for each group
        self.num_comm_intervals_per_group = []

        local_rank = dist.get_rank(group=self.dp_process_group)

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to the list before we modify it
            self.fp16_groups.append(param_group['params'])

            # flattens all tensors into a single 1-D tensor aligned with the sub-partition size for later division
            # RS: create aligned sub-partitions
            self.fp16_groups_flat.append(
                flatten_dense_tensors_sub_partition_aligned(
                    tensor_list=self.fp16_groups[i],
                    dp=dist.get_world_size(group=self.dp_process_group),
                    max_elements_per_comm=self.max_elements_per_comm,
                    pg=self.dp_process_group))

            # TODO: I don't think this does anything?
            # set model fp16 weight to slices of flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

            # divide the flat weights into near-equal partitions, one per data-parallel rank
            # each process will compute on a different part of the partition
            # RS: split into two layer list -> [comm-id] -> [sub-partitions per rank]
            comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
                self.get_data_parallel_sub_partitions(
                    tensor=self.fp16_groups_flat[i],
                    max_elements_per_comm=self.max_elements_per_comm,
                    world_size=dist.get_world_size(
                        group=self.dp_process_group),
                    dp_process_group=self.dp_process_group
                )
            self.parallel_comm_sub_partitioned_fp16_groups.append(
                comm_partitions)  # comm -> rank
            self.parallel_sub_partitioned_fp16_groups.append(
                dp_sub_partitions)  # rank -> comm
            self.sub_partition_sizes.append(sub_partition_size)
            self.num_comm_intervals_per_group.append(num_comm_intervals)
            # data_parallel_partitions = self.get_data_parallel_partitions(self.fp16_groups_flat[i])
            # self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)

            # a partition of the fp32 master weights that will be updated by this process
            # RS: store/detach/cast our local sub-partitions
            local_sub_partitions = []
            for sub_partition in self.parallel_sub_partitioned_fp16_groups[i][
                    local_rank]:
                fp32_sub_partition = sub_partition.clone().float().detach()
                fp32_sub_partition.requires_grad = True
                local_sub_partitions.append(fp32_sub_partition)
            self.local_sub_partitions_of_fp32_groups.append(
                local_sub_partitions)

            # modify the optimizer to use the flat master weights
            # self.single_partition_of_fp32_groups[i].requires_grad = True # keep this in case internal optimizer uses it
            param_group['params'] = self.local_sub_partitions_of_fp32_groups[i]

            # RS: divide up the sub-partitions and keep track of offsets for each param
            # partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
            params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, \
            params_not_local = self.get_all_sub_partition_info(
                tensor_list=self.fp16_groups[i],
                all_element_intervals=element_intervals,
                local_rank=local_rank,
                world_size=dist.get_world_size(group=self.dp_process_group)
            )

            self.params_in_rank_sub_partitions.append(
                params_in_rank_sub_partition)
            self.params_not_local.append(params_not_local)
            self.params_in_rank_sub_partitions_offsets.append(
                params_in_rank_sub_partitions_offsets)

        # we may have a way of fusing the dynamic loss scale; not supported for now
        if dynamic_loss_scale:
            if dynamic_loss_args is None:
                self.loss_scaler = DynamicLossScaler()
            else:
                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)

            self.dynamic_loss_scale = True

        else:
            self.dynamic_loss_scale = False
            self.loss_scaler = LossScaler(scale=static_loss_scale)
            self.cur_iter = 0

        self.mpu = mpu
        self.clip_grad = clip_grad

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups,
                                              mpu=self.mpu,
                                              zero_reduce_scatter=True)
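An illustrative reading of the `max_elements_per_comm` bookkeeping above (assumption: this arithmetic is only a sketch of the intent, not the exact logic of `flatten_dense_tensors_sub_partition_aligned` or `get_data_parallel_sub_partitions`; the sizes are made up):

import math

# Made-up sizes for illustration only.
group_elements = int(1.2e9)       # elements in one flattened parameter group
dp_world_size = 4                 # data-parallel ranks
max_elements_per_comm = int(5e8)  # cap on elements per communication call

# Roughly, the group is covered by this many communication intervals ...
num_comm_intervals = math.ceil(group_elements / max_elements_per_comm)                 # 3
# ... and each interval is split across ranks into sub-partitions of about:
sub_partition_size = math.ceil(group_elements / (num_comm_intervals * dp_world_size))  # 1e8

# Each rank then keeps num_comm_intervals fp32 sub-partitions for this group,
# mirroring parallel_sub_partitioned_fp16_groups[group][rank][comm].
print(num_comm_intervals, sub_partition_size)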
class FP16_DeepSpeedZeroOptimizer(object):
    """
    DeepSpeedZeroOptimizer designed to reduce the memory footprint
    required for training large deep learning models.

    For more details please see ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
    https://arxiv.org/abs/1910.02054

    For usage examples, refer to TODO: DeepSpeed V2 Tutorial

    """
    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True,
                 dp_process_group=None,
                 partition_size=None,
                 mpu=None,
                 all_gather_partitions=True,
                 allgather_size=500000000,
                 clip_grad=0.0):

        if dp_process_group is not None and partition_size is not None:
            raise ValueError("Cannot specify both dp_process_group "
                             "and partition size")

        if dp_process_group is None:
            dp_process_group = _initialize_parameter_parallel_groups(
                partition_size)

        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")
        self.optimizer = init_optimizer

        self.verbose = verbose
        self.dp_process_group = dp_process_group

        # TODO: automatically turn off if #params > some_limit
        self.all_gather_partitions = all_gather_partitions
        self.allgather_size = allgather_size

        # param flattened by groups
        self.fp16_groups = []
        self.fp16_groups_flat = []

        #param partitioned by data parallel degree
        #this will contain a list of equal sized tensors
        #each of which will be updated by a different process
        self.parallel_partitioned_fp16_groups = []

        #a single 32-bit partition of the parallel partitioned parameters
        #that this process will update
        self.single_partition_of_fp32_groups = []

        #param partition info

        #These are the parameters in each group that will not be updated by this process directly
        self.params_not_in_partition = []

        #These are the parameters that will be updated by this process directly
        self.params_in_partition = []

        #Offset from the first parameter in self.params_in_partition
        #the parameter boundaries may not align with partition boundaries
        #so we need to keep track of the offset
        self.first_offset = []

        #number of elements per partition in each group
        self.partition_size = []

        partition_id = dist.get_rank(group=self.dp_process_group)

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to the list before we modify it
            self.fp16_groups.append(param_group['params'])

            self.fp16_groups_flat.append(
                flatten_dense_tensors_aligned(
                    self.fp16_groups[i],
                    dist.get_world_size(group=self.dp_process_group),
                    self.dp_process_group))

            # set model fp16 weight to slices of flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

            #divide the flat weights into near-equal partitions, one per data-parallel rank
            #each process will compute on a different part of the partition
            data_parallel_partitions = self.get_data_parallel_partitions(
                self.fp16_groups_flat[i])
            self.parallel_partitioned_fp16_groups.append(
                data_parallel_partitions)

            # a partition of the fp32 master weights that will be updated by this process
            self.single_partition_of_fp32_groups.append(
                self.parallel_partitioned_fp16_groups[i]
                [partition_id].clone().float().detach())

            # modify the optimizer to use the flat master weights
            self.single_partition_of_fp32_groups[
                i].requires_grad = True  # keep this in case internal optimizer uses it
            param_group['params'] = [self.single_partition_of_fp32_groups[i]]

            partition_size = len(
                self.fp16_groups_flat[i]) / dist.get_world_size(
                    group=self.dp_process_group)
            params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
                self.fp16_groups[i], partition_size, partition_id)

            self.partition_size.append(partition_size)
            self.params_in_partition.append(params_in_partition)
            self.params_not_in_partition.append(params_not_in_partition)
            self.first_offset.append(first_offset)

        # we may have a way of fusing the dynamic loss scale; not supported for now
        if dynamic_loss_scale:
            if dynamic_loss_args is None:
                self.loss_scaler = DynamicLossScaler()
            else:
                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)

            self.dynamic_loss_scale = True

        else:
            self.dynamic_loss_scale = False
            self.loss_scaler = LossScaler(scale=static_loss_scale)
            self.cur_iter = 0

        self.mpu = mpu
        self.clip_grad = clip_grad

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)

    #views the tensor as multiple partitions and returns
    #those partitions
    def get_data_parallel_partitions(self, tensor):
        partitions = []

        dp = dist.get_world_size(group=self.dp_process_group)
        total_num_elements = tensor.numel()

        base_size = total_num_elements // dp
        remaining = total_num_elements % dp

        start = 0
        for rank_id in range(dp):
            partition_size = base_size
            if rank_id < remaining:
                partition_size = partition_size + 1
            partitions.append(tensor.narrow(0, start, partition_size))
            start = start + partition_size
        return partitions
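    # Worked example (illustrative): with dp = 3 and a flat tensor of 10
    # elements, base_size = 3 and remaining = 1, so the partition sizes are
    # [4, 3, 3]; rank 0 views elements [0:4), rank 1 views [4:7) and rank 2
    # views [7:10). The partitions are narrow views that share storage with
    # the flat tensor.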

    def get_partition_info(self, tensor_list, partition_size, partition_id):
        params_in_partition = []
        params_not_in_partition = []

        start_index = partition_size * partition_id
        end_index = partition_size * (partition_id + 1)

        current_index = 0
        first_offset = 0

        for tensor in tensor_list:

            tensor_size = tensor.numel()

            if (current_index >= start_index and current_index < end_index):
                params_in_partition.append(tensor)

            elif start_index > current_index and start_index < (current_index +
                                                                tensor_size):
                params_in_partition.append(tensor)

                assert (
                    first_offset == 0
                ), "first_offset can be set at most once, and only for the first tensor that straddles the partition boundary"
                first_offset = start_index - current_index

            else:
                params_not_in_partition.append(tensor)

            current_index = current_index + tensor_size

        return params_in_partition, params_not_in_partition, first_offset
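    # Worked example (illustrative): with tensors of sizes [4, 4, 4],
    # partition_size = 6 and partition_id = 1 (owning indices [6, 12)), the
    # first tensor lies entirely before the partition, the second tensor
    # (indices [4, 8)) straddles the boundary and is included with
    # first_offset = 6 - 4 = 2, and the third tensor (indices [8, 12)) falls
    # entirely inside the partition.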

    def zero_grad(self, set_grads_to_None=True):
        """
        Zero FP16 parameter grads.
        """
        # FP32 grad should never exist.
        # For speed, set model fp16 grad to None by default
        for group in self.fp16_groups:
            for p in group:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    #creates a flat fused tensor from the tensor list starting at the first_offset
    #in the first tensor of the list. If there are not enough elements in the tensor
    #list then the flat tensor will be padded with zeros
    def get_flat_partition(self,
                           tensor_list,
                           first_offset,
                           partition_size,
                           dtype=None):
        flat_tensor_list = []
        current_size = 0

        if dtype is None:
            dtype = tensor_list[0].dtype

        for i, tensor in enumerate(tensor_list):
            if tensor.grad is None:
                tensor.grad = torch.zeros(tensor.size(),
                                          dtype=tensor.dtype,
                                          device=tensor.device)
            tensor = tensor.grad
            num_elements = tensor.numel()
            tensor_offset = 0

            #we need to offset to get to the right element
            if i == 0 and first_offset > 0:
                tensor_offset = first_offset
                num_elements = num_elements - tensor_offset

            #we don't need all elements of the tensor
            if num_elements > (partition_size - current_size):
                num_elements = partition_size - current_size

            #we need a narrow view of the tensor based on the tensor offset and number of elements that
            #we need from this tensor
            if tensor_offset > 0 or num_elements < tensor.numel():
                flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
                    0, int(tensor_offset), int(num_elements)).to(dtype))
            else:
                flat_tensor_list.append(tensor.to(dtype))

            current_size = current_size + num_elements

        #this means it's the last partition and does not align with the dp boundary; we need to pad before flattening
        if current_size < partition_size:
            flat_tensor_list.append(
                torch.zeros(int(partition_size - current_size),
                            dtype=dtype,
                            device=tensor_list[0].device))

        return _flatten_dense_tensors(flat_tensor_list)
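    # Worked example (illustrative): with two gradients of 5 elements each,
    # first_offset = 2 and partition_size = 6, the flat partition takes
    # elements [2:5) of the first gradient and [0:3) of the second. If
    # partition_size were 10 instead, all remaining elements would be taken
    # and 2 trailing zeros would be appended before flattening.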

    def free_grad_in_param_list(self, param_list):
        for p in param_list:
            p.grad = None

    def see_memory_usage(self):
        print("Memory Allocated ",
              torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
              "GigaBytes")
        print("Max Memory Allocated ",
              torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
              "GigaBytes")
        print("Cache Allocated ",
              torch.cuda.memory_cached() / (1024 * 1024 * 1024), "GigaBytes")
        print("Max cache Allocated ",
              torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
              "GigaBytes")

    def print_first_n(self, caption, tensor, n=10):
        if tensor is not None:
            print(
                caption,
                tensor.data.contiguous().view(-1).narrow(0, 0,
                                                         n).cpu().numpy())
        else:
            print(caption, None)

    def step(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute the norm for all groups so we know if there is an overflow

        self.overflow = self.overflow_checker.check()

        prev_scale = self.loss_scale
        self._update_scale(self.overflow)
        if self.overflow:
            self.zero_grad()
            if self.verbose:
                print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                      "scale: {}, reducing to {}".format(
                          prev_scale, self.loss_scale))
            return self.overflow

        norm_groups = []
        single_partition_grad_groups = []

        partition_id = dist.get_rank(group=self.dp_process_group)
        for i, group in enumerate(self.fp16_groups):

            norm_groups.append(get_grad_norm(group, mpu=self.mpu))

            #free gradients for all the parameters that are not updated by this process
            self.free_grad_in_param_list(self.params_not_in_partition[i])

            #create flat gradients for the parameters updated by this process
            single_grad_partition = self.get_flat_partition(
                self.params_in_partition[i],
                self.first_offset[i],
                self.partition_size[i],
                dtype=self.single_partition_of_fp32_groups[i].dtype)

            self.single_partition_of_fp32_groups[
                i].grad = single_grad_partition

            #release all the gradients since we have already created the necessary copy in dp_grad_partition
            self.free_grad_in_param_list(self.params_in_partition[i])

            single_partition_grad_groups.append(single_grad_partition)

        self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)

        self.optimizer.step()

        #get rid of the fp32 gradients. Not needed anymore
        for group in self.single_partition_of_fp32_groups:
            group.grad = None

        # copy the updated fp32 master partitions back into their fp16 partitions
        for fp16_partitions, fp32_partition in zip(
                self.parallel_partitioned_fp16_groups,
                self.single_partition_of_fp32_groups):
            fp16_partitions[partition_id].data.copy_(fp32_partition.data)

        dp_world_size = dist.get_world_size(group=self.dp_process_group)
        #gather the updated weights from everyone
        for _, partitioned_params in enumerate(
                self.parallel_partitioned_fp16_groups):
            if self.all_gather_partitions:
                # controllable memory-time tradeoff
                num_shards = max(
                    1, partitioned_params[partition_id].numel() *
                    dp_world_size // self.allgather_size)
                shard_size = partitioned_params[partition_id].numel(
                ) // num_shards
                num_elements = shard_size
                for shard_id in range(num_shards + 1):
                    if shard_id == num_shards:
                        if shard_size * num_shards >= partitioned_params[
                                partition_id].numel():
                            break
                        else:
                            num_elements = partitioned_params[
                                partition_id].numel() - shard_id * shard_size
                    shard_list = []
                    for dp_id in range(dp_world_size):
                        curr_shard = partitioned_params[dp_id].narrow(
                            0, shard_id * shard_size, num_elements)
                        shard_list.append(curr_shard)
                    dist.all_gather(shard_list,
                                    shard_list[partition_id],
                                    group=self.dp_process_group)
            else:
                #this should require less memory but should be faster
                for src, partitioned_param in enumerate(partitioned_params):
                    global_src = _get_global_rank(self.dp_process_group, src)
                    dist.broadcast(partitioned_param,
                                   global_src,
                                   group=self.dp_process_group)

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        return self.overflow
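    # Illustrative sharding arithmetic for the all_gather path above: with a
    # partition of 400M elements, dp_world_size = 4 and allgather_size = 5e8,
    # num_shards = max(1, 400e6 * 4 // 5e8) = 3 and shard_size is roughly
    # 133M elements, so each partition is gathered in 3 shards plus a small
    # remainder shard -- the memory/time tradeoff mentioned in the comment.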

    def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
        total_norm = 0.0
        for norm in norm_groups:
            total_norm += norm**2.0
        total_norm = math.sqrt(total_norm)

        # compute combined scale factor for this group
        combined_scale = self.loss_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.loss_scale

        for grad in grad_groups_flat:
            grad.data.mul_(1. / combined_scale)
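    # Worked example (illustrative): with loss_scale = 1024, clip_grad = 1.0
    # and a (scaled) total_norm of 2048, the unscaled norm is 2.0, so
    # clip ~= 2.0 > 1 and combined_scale = 2.0 * 1024 = 2048; multiplying the
    # gradients by 1/2048 unscales and clips them in a single pass.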

    def backward(self, loss, retain_graph=False):
        self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

    def _update_scale(self, has_overflow=False):
        self.loss_scaler.update_scale(has_overflow)

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        return self.loss_scaler.loss_scale

    def _set_loss_scale(self, value):
        self.loss_scaler.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)
    cur_scale = property(_get_loss_scale, _set_loss_scale)

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.
        Example::
            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['loss_scaler'] = self.loss_scaler
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['overflow'] = self.overflow
        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        state_dict[
            'single_partition_of_fp32_groups'] = self.single_partition_of_fp32_groups
        return state_dict

    def load_state_dict(self, state_dict, load_optimizer_states=True):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.
        Example::
            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.loss_scaler = state_dict['loss_scaler']
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.overflow = state_dict['overflow']
        if load_optimizer_states:
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])

        for current, saved in zip(
                self.single_partition_of_fp32_groups,
                state_dict['single_partition_of_fp32_groups']):
            current.data.copy_(saved.data)