Example #1
    def test_horovod_broadcast_inplace(self):
        """Test that the broadcast correctly broadcasts 1D, 2D, 3D tensors."""
        hvd.init()
        rank = hvd.rank()
        size = hvd.size()

        # This test does not apply if there is only one worker.
        if size == 1:
            return

        dtypes = [torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
                  torch.IntTensor, torch.LongTensor, torch.FloatTensor, torch.DoubleTensor]
        if torch.cuda.is_available():
            dtypes += [torch.cuda.ByteTensor, torch.cuda.CharTensor, torch.cuda.ShortTensor,
                       torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor,
                       torch.cuda.DoubleTensor]
        dims = [1, 2, 3]
        root_ranks = list(range(size))
        for dtype, dim, root_rank in itertools.product(dtypes, dims, root_ranks):
            tensor = torch.FloatTensor(*([17] * dim)).fill_(1).mul_(rank)
            root_tensor = torch.FloatTensor(*([17] * dim)).fill_(1).mul_(root_rank)
            tensor = tensor.type(dtype)
            root_tensor = root_tensor.type(dtype)
            broadcasted_tensor = hvd.broadcast_(tensor, root_rank)
            assert (tensor == broadcasted_tensor).min() == 1, \
                'hvd.broadcast does not modify source tensor'
            assert (broadcasted_tensor == root_tensor).min() == 1, \
                'hvd.broadcast produces incorrect broadcasted tensor'
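The test above exercises hvd.broadcast_ over many dtypes and shapes. A minimal standalone sketch of the same in-place broadcast, assuming Horovod is installed and the script is launched with horovodrun (e.g. horovodrun -np 2 python script.py):

import torch
import horovod.torch as hvd

hvd.init()
# every rank starts from a different value; after broadcast_ all ranks hold rank 0's data
t = torch.full((17, 17), float(hvd.rank()))
hvd.broadcast_(t, root_rank=0)
assert (t == 0).all(), 'every rank should now hold the tensor from rank 0'
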
    def broadcast_buffer():
        # copy tensors into buffer_t
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset:offset + numel].copy_(t.view(-1))
            offset += numel

        # broadcast
        hvd.broadcast_(buffer_t[:offset], root_rank)

        # copy the broadcast buffer back into the original tensors
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset + numel])
            offset += numel
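
The broadcast_buffer fragment above is a closure: buffer (a list of tensors), buffer_t (a flat staging tensor of the same dtype) and root_rank come from its enclosing scope, which is not shown. A minimal self-contained sketch of that enclosing setup, with the surrounding names assumed rather than taken from the original:

import torch
import horovod.torch as hvd

def broadcast_coalesced(buffer, root_rank):
    """Hypothetical wrapper: broadcast a list of same-dtype tensors through one flat buffer."""
    # flat staging tensor large enough to hold every tensor in the list
    buffer_t = buffer[0].new_zeros(sum(t.numel() for t in buffer))

    # copy tensors into buffer_t
    offset = 0
    for t in buffer:
        numel = t.numel()
        buffer_t[offset:offset + numel].copy_(t.view(-1))
        offset += numel

    # broadcast the flat buffer once instead of once per tensor
    hvd.broadcast_(buffer_t[:offset], root_rank)

    # copy the broadcast buffer back into the original tensors
    offset = 0
    for t in buffer:
        numel = t.numel()
        t.view(-1).copy_(buffer_t[offset:offset + numel])
        offset += numel
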
def any_broadcast(data, root_rank, max_size=4096):
    """broadcast arbitrary data from root_rank to all nodes."""
    if not hasattr(any_broadcast, '_buffer') or \
            max_size != any_broadcast._buffer.numel():
        any_broadcast._buffer = torch.cuda.ByteTensor(max_size)
    buffer_ = any_broadcast._buffer

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
    assert max_size < 255 * 256
    buffer_[0] = enc_size // 255  # this encoding works for max_size < 65k
    buffer_[1] = enc_size % 255
    buffer_[2:enc_size + 2] = torch.ByteTensor(list(enc))

    hvd.broadcast_(buffer_, root_rank)

    size = (255 * buffer_[0].item()) + buffer_[1].item()

    bytes_list = bytes(buffer_[2:size + 2].tolist())
    result = pickle.loads(bytes_list)
    return result
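
A typical use of this helper is shipping small Python objects (a config dict, an epoch-level metric) from one rank to all others. A brief usage sketch, assuming hvd.init() has already been called and CUDA is available (the buffer is a torch.cuda.ByteTensor):

import horovod.torch as hvd

meta = {'epoch': 3, 'best_score': 0.91} if hvd.rank() == 0 else None
meta = any_broadcast(meta, root_rank=0)  # every rank now holds rank 0's dict
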
Example #4
def any_broadcast(data, root_rank):
    """broadcast arbitrary data from root_rank to all nodes."""
    if not hasattr(any_broadcast, '_buffer'):
        # keeps small buffer to avoid re-allocate every call
        any_broadcast._buffer = torch.cuda.ByteTensor(_BUFFER_SIZE)
    try:
        enc = msgpack.dumps(data, use_bin_type=True)
        msgpack_success = True
    except TypeError:
        enc = pickle.dumps(data)
        msgpack_success = False

    max_size = hvd.allgather(torch.tensor([len(enc)]).cuda()).max().item()
    buffer_ = any_broadcast._buffer
    buffer_, enc_byte = _encode(enc, max_size, buffer_)

    hvd.broadcast_(buffer_, root_rank)

    bytes_list, _ = _decode(buffer_, enc_byte)
    if msgpack_success:
        result = msgpack.loads(bytes_list, raw=False)
    else:
        result = pickle.loads(bytes_list)
    return result
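
This variant depends on module-level helpers _BUFFER_SIZE, _encode and _decode that are not shown in the snippet. The sketch below is inferred only from how they are called (the actual implementations may differ): the payload length is written into the first enc_byte bytes of the buffer, followed by the serialized payload.

import math
import torch

_BUFFER_SIZE = 4096  # assumed starting size of the persistent byte buffer

def _encode(enc, max_size, buffer_):
    # bytes needed to store a length up to max_size (big-endian, base 256)
    enc_byte = max(math.ceil(math.log(max_size + 1, 256)), 1)
    if enc_byte + max_size > buffer_.numel():
        buffer_ = buffer_.new_empty(enc_byte + max_size)  # grow if the payload is larger
    remainder = len(enc)
    for i in range(enc_byte):
        base = 256 ** (enc_byte - i - 1)
        buffer_[i] = remainder // base
        remainder %= base
    buffer_[enc_byte:enc_byte + len(enc)] = torch.ByteTensor(list(enc))
    return buffer_, enc_byte

def _decode(buffer_, enc_byte):
    size = 0
    for i in range(enc_byte):
        size = size * 256 + buffer_[i].item()
    bytes_list = bytes(buffer_[enc_byte:enc_byte + size].tolist())
    return bytes_list, enc_byte + size
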
Example #5
    def test_broadcast_state(self):
        hvd.init()

        N, D_in, H, D_out = 64, 100, 10, 10
        x = torch.randn(N, D_in).requires_grad_()
        y = torch.randn(N, D_out).requires_grad_()

        def new_optimizer(cls, opt_params, model):
            p = {
                k: v
                for k, v in opt_params.items()
                if k in inspect.getfullargspec(cls.__init__).args
            }
            return cls(model.parameters(), **p)

        def create_model(opt_class, opt_params):
            model = torch.nn.Sequential(
                torch.nn.Linear(D_in, H),
                torch.nn.ReLU(),
                torch.nn.Linear(H, D_out),
            )

            optimizer = new_optimizer(opt_class, opt_params, model)
            optimizer = hvd.DistributedOptimizer(
                optimizer, named_parameters=model.named_parameters())

            return model, optimizer

        def get_model_param_values(model):
            params = sorted(model.state_dict().items())
            return [(k, v.clone()) for k, v in params]

        def get_optimizer_param_values(optimizer):
            results = []
            state_dict = optimizer.state_dict()
            for group in state_dict['param_groups']:
                for param_id in group['params']:
                    if param_id not in state_dict['state']:
                        continue
                    params = sorted(state_dict['state'][param_id].items())
                    for k, v in params:
                        results.append(
                            (k, v.clone() if torch.is_tensor(v) else v))
            return results

        # L-BFGS is currently unsupported, as are sparse tensors, which are
        # required by SparseAdam optimizer
        optimizers = [
            (subclass.__name__, subclass)
            for subclass in torch.optim.Optimizer.__subclasses__()
            if subclass.__module__.startswith('torch.optim') and subclass !=
            torch.optim.LBFGS and subclass != torch.optim.SparseAdam
        ]
        optimizers.sort()

        opt_params_list = [
            dict(lr=0.2, momentum=0.9, weight_decay=0.1, centered=True),
            dict(lr=0.2)
        ]

        for (opt_name, opt_class), opt_params in itertools.product(
                optimizers, opt_params_list):
            model, optimizer = create_model(opt_class, opt_params)
            y_pred = model(x)
            loss = F.mse_loss(y_pred, y, reduction='sum')
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            model_param_values = get_model_param_values(model)
            for name, model_param_value in model_param_values:
                hvd.broadcast_(model_param_value, root_rank=0)

            opt_param_values_updated = []
            opt_param_values = get_optimizer_param_values(optimizer)
            for name, opt_param_value in opt_param_values:
                is_tensor = torch.is_tensor(opt_param_value)
                if not is_tensor:
                    t = type(opt_param_value)
                    opt_param_value = torch.Tensor([opt_param_value])
                hvd.broadcast_(opt_param_value, root_rank=0)
                if not is_tensor:
                    opt_param_value = t(opt_param_value.cpu().numpy()[0])
                opt_param_values_updated.append((name, opt_param_value))
            opt_param_values = opt_param_values_updated

            if hvd.rank() == 0:
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                _, fname = tempfile.mkstemp('.pt')
                torch.save(state, fname)

            model, optimizer = create_model(opt_class, opt_params)
            if hvd.rank() == 0:
                checkpoint = torch.load(fname)
                model.load_state_dict(checkpoint['model'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                os.remove(fname)

            hvd.broadcast_parameters(model.state_dict(), root_rank=0)
            model_param_value_after = get_model_param_values(model)
            for before, after in zip(model_param_values,
                                     model_param_value_after):
                name, model_param_value = before
                name_after, model_param_value_after = after
                self.assertEqual(name, name_after)
                self.assertEqual(type(model_param_value),
                                 type(model_param_value_after))
                self.assertTrue(
                    (model_param_value == model_param_value_after).all())

            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

            expected_tensors = 4
            if 'momentum' not in opt_params and opt_class == torch.optim.SGD:
                # SGD only maintains state when momentum is specified, otherwise
                # it does not populate the state dict, so it will contain no tensors.
                expected_tensors = 0
            self.assertEqual(len(optimizer.state_dict()['state'].values()),
                             expected_tensors)

            opt_param_values_after = get_optimizer_param_values(optimizer)
            for before, after in zip(opt_param_values, opt_param_values_after):
                name, opt_param_value = before
                name_after, opt_param_value_after = after
                self.assertEqual(name, name_after)
                self.assertEqual(type(opt_param_value),
                                 type(opt_param_value_after))
                if torch.is_tensor(opt_param_value):
                    self.assertTrue(
                        (opt_param_value == opt_param_value_after).all())
                else:
                    self.assertEqual(opt_param_value, opt_param_value_after)
Example #6
    def test_broadcast_state(self):
        hvd.init()

        N, D_in, H, D_out = 64, 100, 10, 10
        x = torch.autograd.Variable(torch.randn(N, D_in), requires_grad=True)
        y = torch.autograd.Variable(torch.randn(N, D_out), requires_grad=False)

        def create_model():
            model = torch.nn.Sequential(
                torch.nn.Linear(D_in, H),
                torch.nn.ReLU(),
                torch.nn.Linear(H, D_out),
            )

            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9)
            optimizer = hvd.DistributedOptimizer(
                optimizer, named_parameters=model.named_parameters())

            return model, optimizer

        def get_model_param_value(model):
            return model.state_dict()['0.weight'].clone()

        def get_optimizer_param_value(optimizer):
            state_dict = optimizer.state_dict()
            param_id = state_dict['param_groups'][0]['params'][0]
            return state_dict['state'][param_id]['momentum_buffer'].clone()

        model, optimizer = create_model()
        y_pred = model(x)
        loss = F.mse_loss(y_pred, y, reduction='sum')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        model_param_value = get_model_param_value(model)
        hvd.broadcast_(model_param_value, root_rank=0)

        opt_param_value = get_optimizer_param_value(optimizer)
        hvd.broadcast_(opt_param_value, root_rank=0)

        if hvd.rank() == 0:
            state = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            _, fname = tempfile.mkstemp('.pt')
            torch.save(state, fname)

        model, optimizer = create_model()
        if hvd.rank() == 0:
            checkpoint = torch.load(fname)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            os.remove(fname)

        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        model_param_value_after = get_model_param_value(model)
        self.assertTrue((model_param_value == model_param_value_after).all())

        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        self.assertEqual(len(optimizer.state_dict()['state'].values()), 4)
        opt_param_value_after = get_optimizer_param_value(optimizer)
        self.assertTrue((opt_param_value == opt_param_value_after).all())
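
The test above follows the usual Horovod restore pattern: rank 0 loads the checkpoint, then parameters and optimizer state are broadcast so every worker resumes from identical state. A condensed sketch of that pattern outside the test harness (the checkpoint path and the model/optimizer passed in are placeholders):

import torch
import horovod.torch as hvd

def restore_from_checkpoint(model, optimizer, path='checkpoint.pt'):
    # rank 0 reads the file; every other rank receives the state via broadcast
    if hvd.rank() == 0:
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
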
    def step(self, closure=None, epoch=None):
        """Perform one K-FAC step

        Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

        Args:
          closure: for compatibility with the base optimizer class.
              `closure` is ignored by KFAC
          epoch (int, optional): epoch to use for determining when to end
              the `diag_warmup` period. `epoch` is not necessary if not using
              `diag_warmup`
        """

        # Update params, used for compatibility with `KFACParamScheduler`
        group = self.param_groups[0]
        self.lr = group['lr']
        self.damping = group['damping']
        self.fac_update_freq = group['fac_update_freq']
        self.kfac_update_freq = group['kfac_update_freq']
        #print('fac_update_freq: ', self.fac_update_freq)
        #print('kfac_update_freq: ', self.kfac_update_freq)

        updates = {}
        handles = []

        if epoch is None:
            if self.diag_warmup > 0:
                print("WARNING: diag_warmup > 0 but epoch was not passed to "
                      "KFAC.step(). Defaulting to no diag_warmup")
            diag_blocks = self.diag_blocks
        else:
            diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

        if hvd.size() > 1 and self.steps % self.fac_update_freq == 0:
            self.fw_merged_comm.synchronize()
            self.bw_merged_comm.synchronize()

        # if we are switching from no diag approx to approx, we need to clear
        # off-block-diagonal elements
        if not self.have_cleared_Q and \
                epoch == self.diag_warmup and \
                self.steps % self.kfac_update_freq == 0:
            self._clear_eigen()
            self.have_cleared_Q = True
        torch.cuda.synchronize()

        if self.steps % self.kfac_update_freq == 0:
            # reset the rank iterator so every device gets the same layers
            # to compute, to take advantage of caching
            self.rank_iter.reset()
            handles = []

            #eigen_ranks = self._generate_eigen_ranks(epoch)
            #eigen_ranks = self._generate_eigen_ranks_uniform(epoch)
            eigen_ranks = self._generate_eigen_ranks_naive(epoch)
            #inverse_As = []
            #A_ranks = []
            #inverse_Gs = []
            #G_ranks = []
            rank_to_tensors = {}

            for module in self.modules:
                ranks_a, ranks_g = eigen_ranks[module]
                self.m_dA_ranks[module] = ranks_a[0]
                self.m_dG_ranks[module] = ranks_g[0]
                rank_a = ranks_a[0]
                rank_g = ranks_g[0]

                name = self.module_name_map[module]
                if not self.exclude_compute_inverse:
                    self._update_inverse_A(module, ranks_a)
                if not self.exclude_communicate_inverse:
                    if hvd.size() > 1 and rank_a >= 0:
                        self.multi_comm.bcast_async_([name + 'mQA'],
                                                     [self.m_QA[module]],
                                                     rank_a)

                if not self.exclude_compute_inverse:
                    self._update_inverse_G(module, ranks_g)
                if not self.exclude_communicate_inverse:
                    if hvd.size() > 1 and rank_g >= 0:
                        self.multi_comm.bcast_async_([name + 'mQG'],
                                                     [self.m_QG[module]],
                                                     rank_g)
            if self.exclude_communicate_inverse and not self.exclude_compute_inverse:
                # should have a barrier
                if hvd.size() > 1:
                    barrier()

        if hvd.size() > 1 and self.steps % self.kfac_update_freq == 0:
            self.multi_comm.synchronize()

        for i, module in enumerate(self.modules):

            grad = self._get_grad(module)
            precon_grad = self._get_preconditioned_grad(module, grad)
            updates[module] = precon_grad

        self._update_scale_grad(updates)

        if self.dynamic_merge and hvd.size() > 1 and \
                self.steps % self.kfac_update_freq == 0:
            if self.steps == 5:
                self.profiling = True
            elif self.steps == 25:
                fw_layerwise_times = torch.tensor(
                    self.fw_profiler.get_results())
                bw_layerwise_times = torch.tensor(
                    self.bw_profiler.get_results())
                hvd.broadcast_(fw_layerwise_times, root_rank=0)
                hvd.broadcast_(bw_layerwise_times, root_rank=0)
                fw_layerwise_times = fw_layerwise_times.numpy()
                bw_layerwise_times = bw_layerwise_times.numpy()
                if hvd.rank() == 0:
                    pass
                    #logger.info('fw_layerwise_times: %s, sum: %f', fw_layerwise_times, np.sum(fw_layerwise_times))
                    #logger.info('bw_layerwise_times: %s, sum: %f', bw_layerwise_times, np.sum(bw_layerwise_times))

                fw_factor_sizes = [
                    self.fw_factor_sizes[m] for m in self.module_names
                ]
                bw_factor_sizes = [
                    self.bw_factor_sizes[m] for m in self.module_names[::-1]
                ]
                self.fw_merged_comm.update_groups(fw_factor_sizes,
                                                  fw_layerwise_times,
                                                  reverse=False)
                self.bw_merged_comm.update_groups(bw_factor_sizes,
                                                  bw_layerwise_times,
                                                  reverse=True)
                self.profiling = False

        self.steps += 1
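
Per the docstring, this step() is meant to run after gradients have been averaged across ranks (for example by hvd.DistributedOptimizer) and before the base optimizer's update. A hedged sketch of one training iteration using that contract; preconditioner stands for whatever class defines the step() above, and all names are illustrative:

def train_step(model, criterion, optimizer, preconditioner, inputs, targets, epoch):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.synchronize()               # finish Horovod's gradient allreduce
    preconditioner.step(epoch=epoch)      # K-FAC preconditioning before the update
    with optimizer.skip_synchronize():    # gradients were already synchronized above
        optimizer.step()
    return loss.item()
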
    def step(self, closure=None, epoch=None):
        """Perform one K-FAC step

        Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

        Args:
          closure: for compatibility with the base optimizer class.
              `closure` is ignored by KFAC
          epoch (int, optional): epoch to use for determining when to end
              the `diag_warmup` period. `epoch` is not necessary if not using
              `diag_warmup`
        """

        # Update params, used for compatibility with `KFACParamScheduler`
        group = self.param_groups[0]
        self.lr = group['lr']
        self.damping = group['damping']
        self.fac_update_freq = group['fac_update_freq']
        self.kfac_update_freq = group['kfac_update_freq']

        updates = {}
        handles = []

        if epoch is None:
            if self.diag_warmup > 0:
                print("WARNING: diag_warmup > 0 but epoch was not passed to "
                      "KFAC.step(). Defaulting to no diag_warmup")
            diag_blocks = self.diag_blocks
        else:
            diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

        if self.steps % self.fac_update_freq == 0:

            if self.eigen_ranks is None:
                if not self.exclude_compute_factor:
                    self._update_A()
                    self._update_G()
                self.eigen_ranks, self.fusion_groups_A, self.fusion_groups_G = self._generate_eigen_ranks_blockpartition_opt(
                    epoch)
                if not self.exclude_communicate_factor:
                    if hvd.size() > 1:
                        self._reduce_factors(self.eigen_ranks)
                self.fw_merged_comm.init_tensor_group(self.reduce_module_names)
                self.bw_merged_comm.init_tensor_group(
                    self.reduce_module_names[::-1])
                if hvd.rank() == 0:
                    print('module_names: ', self.module_names)
                    print('fusion_groups_A: ', self.fusion_groups_A)

                self.fw_merged_comm.update_tensor_fusion(self.fusion_groups_A)
            else:  # starting from the 2nd iteration
                if not self.exclude_communicate_factor:
                    if hvd.size() > 1:
                        self.fw_merged_comm.synchronize()
                        self.bw_merged_comm.synchronize()
                        self.fw_allreduce_comm.synchronize()
                        self.bw_allreduce_comm.synchronize()
            eigen_ranks = self.eigen_ranks

        # if we are switching from no diag approx to approx, we need to clear
        # off-block-diagonal elements
        if not self.have_cleared_Q and \
                epoch == self.diag_warmup and \
                self.steps % self.kfac_update_freq == 0:
            self._clear_eigen()
            self.have_cleared_Q = True

        if self.steps % self.kfac_update_freq == 0:
            # reset the rank iterator so every device gets the same layers
            # to compute, to take advantage of caching
            self.rank_iter.reset()
            handles = []

            merged_name_AGs = [[] for _ in range(hvd.size())]
            merged_tensor_AGs = [[] for _ in range(hvd.size())]
            for i, module in enumerate(self.modules):
                ranks_a, ranks_g = eigen_ranks[module]
                self.m_dA_ranks[module] = ranks_a[0]
                self.m_dG_ranks[module] = ranks_g[0]
                rank_a = ranks_a[0]
                rank_g = ranks_g[0]
                name = self.module_name_map[module]

                if not self.exclude_compute_inverse:
                    self._update_eigen_A(module, ranks_a)

                if not self.exclude_compute_inverse:
                    self._update_eigen_G(module, ranks_g)

                merged_name_AGs[rank_a].append(name + '-A')
                merged_name_AGs[rank_g].append(name + '-G')
                merged_tensor_AGs[rank_a].append(self.m_QA[module])
                merged_tensor_AGs[rank_g].append(self.m_QG[module])

            if not self.exclude_communicate_inverse:
                if hvd.size() > 1:
                    #for rank, names in enumerate(merged_name_AGs):
                    #    merged_names = merged_name_AGs[rank]
                    #    merged_tensors = merged_tensor_AGs[rank]
                    #    self.multi_comm.bcast_async_(merged_names, merged_tensors, rank)
                    self._broadcast_eigendecomp()

            if self.exclude_communicate_inverse and not self.exclude_compute_inverse:
                # should have a barrier
                if hvd.size() > 1:
                    barrier()

        if hvd.size() > 1 and self.steps % self.kfac_update_freq == 0:
            self.multi_comm.synchronize()

        for i, module in enumerate(self.modules):
            grad = self._get_grad(module)
            precon_grad = self._get_preconditioned_grad(module, grad)
            updates[module] = precon_grad

        self._update_scale_grad(updates)

        if hvd.size() > 1 and self.steps % self.kfac_update_freq == 0:
            if self.steps == 5:
                self.profiling = True
            elif self.steps == 25:
                fw_layerwise_times = torch.tensor(
                    self.fw_profiler.get_results())
                bw_layerwise_times = torch.tensor(
                    self.bw_profiler.get_results())
                hvd.broadcast_(fw_layerwise_times, root_rank=0)
                hvd.broadcast_(bw_layerwise_times, root_rank=0)
                fw_layerwise_times = fw_layerwise_times.numpy()
                bw_layerwise_times = bw_layerwise_times.numpy()
                if hvd.rank() == 0:
                    pass
                fw_factor_sizes = [
                    self.fw_factor_sizes[m] for m in self.module_names
                ]
                bw_factor_sizes = [
                    self.bw_factor_sizes[m] for m in self.module_names[::-1]
                ]
                self.fw_merged_comm.update_groups(self.fusion_groups_A,
                                                  fw_factor_sizes,
                                                  fw_layerwise_times,
                                                  reverse=False)
                #self.bw_merged_comm.update_groups(self.fusion_groups_G, bw_factor_sizes, bw_layerwise_times, reverse=False)
                self.profiling = False

        self.steps += 1
Example #9
    def test_broadcast_state(self):
        hvd.init()

        N, D_in, H, D_out = 64, 100, 10, 10
        x = torch.autograd.Variable(torch.randn(N, D_in), requires_grad=True)
        y = torch.autograd.Variable(torch.randn(N, D_out), requires_grad=False)

        def create_model(create_opt):
            model = torch.nn.Sequential(
                torch.nn.Linear(D_in, H),
                torch.nn.ReLU(),
                torch.nn.Linear(H, D_out),
            )

            optimizer = create_opt(model)
            optimizer = hvd.DistributedOptimizer(
                optimizer, named_parameters=model.named_parameters())

            return model, optimizer

        def get_model_param_values(model):
            params = sorted(model.state_dict().items())
            return [(k, v.clone()) for k, v in params]

        def get_optimizer_param_values(optimizer):
            results = []
            state_dict = optimizer.state_dict()
            for group in state_dict['param_groups']:
                for param_id in group['params']:
                    params = sorted(state_dict['state'][param_id].items())
                    for k, v in params:
                        results.append(
                            (k, v.clone() if torch.is_tensor(v) else v))
            return results

        opt_params = dict(lr=0.2, momentum=0.9, weight_decay=0.1, centered=True)

        def new_optimizer(cls):
            p = {
                k: v for k, v in opt_params.items()
                if k in inspect.getfullargspec(cls.__init__).args
            }
            return lambda m: cls(m.parameters(), **p)

        # L-BFGS is currently unsupported, as are sparse tensors, which are
        # required by SparseAdam optimizer
        optimizers = [
            (subclass.__name__, new_optimizer(subclass))
            for subclass in torch.optim.Optimizer.__subclasses__()
            if subclass.__module__.startswith('torch.optim') and
               subclass != torch.optim.LBFGS and
               subclass != torch.optim.SparseAdam
        ]
        optimizers.sort()

        for opt_name, create_opt in optimizers:
            model, optimizer = create_model(create_opt)
            y_pred = model(x)
            loss = F.mse_loss(y_pred, y, reduction='sum')
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            model_param_values = get_model_param_values(model)
            for name, model_param_value in model_param_values:
                hvd.broadcast_(model_param_value, root_rank=0)

            opt_param_values_updated = []
            opt_param_values = get_optimizer_param_values(optimizer)
            for name, opt_param_value in opt_param_values:
                is_tensor = torch.is_tensor(opt_param_value)
                if not is_tensor:
                    t = type(opt_param_value)
                    opt_param_value = torch.Tensor([opt_param_value])
                hvd.broadcast_(opt_param_value, root_rank=0)
                if not is_tensor:
                    opt_param_value = t(opt_param_value.numpy()[0])
                opt_param_values_updated.append((name, opt_param_value))
            opt_param_values = opt_param_values_updated

            if hvd.rank() == 0:
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                _, fname = tempfile.mkstemp('.pt')
                torch.save(state, fname)

            model, optimizer = create_model(create_opt)
            if hvd.rank() == 0:
                checkpoint = torch.load(fname)
                model.load_state_dict(checkpoint['model'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                os.remove(fname)

            hvd.broadcast_parameters(model.state_dict(), root_rank=0)
            model_param_value_after = get_model_param_values(model)
            for before, after in zip(model_param_values,
                                     model_param_value_after):
                name, model_param_value = before
                name_after, model_param_value_after = after
                self.assertEqual(name, name_after)
                self.assertEqual(type(model_param_value),
                                 type(model_param_value_after))
                self.assertTrue(
                    (model_param_value == model_param_value_after).all())

            hvd.broadcast_optimizer_state(optimizer, root_rank=0)
            self.assertEqual(len(optimizer.state_dict()['state'].values()), 4)

            opt_param_values_after = get_optimizer_param_values(optimizer)
            for before, after in zip(opt_param_values, opt_param_values_after):
                name, opt_param_value = before
                name_after, opt_param_value_after = after
                self.assertEqual(name, name_after)
                self.assertEqual(type(opt_param_value),
                                 type(opt_param_value_after))
                if torch.is_tensor(opt_param_value):
                    self.assertTrue(
                        (opt_param_value == opt_param_value_after).all())
                else:
                    self.assertEqual(opt_param_value, opt_param_value_after)