Example 1
    def synchronize(self):
        for h in self.handles:
            handle, names, tensors, comm_tensors = h
            hvd.synchronize(handle)
            #name = 'merged_tensor_comm_'+','.join(names)
            name = ','.join(names)

            offset = 0
            buf = self.merged_tensors[name]
            if self.fp16:
                buf = buf.float()
            for i, t in enumerate(tensors):
                numel = comm_tensors[i].numel()
                comm_tensor = buf.data[offset:offset + numel]

                if self.symmetric:
                    lower_indices = torch.tril_indices(t.shape[0],
                                                       t.shape[1],
                                                       device=t.device)
                    upper_indices = torch.triu_indices(t.shape[0],
                                                       t.shape[1],
                                                       device=t.device)
                    t[upper_indices[0], upper_indices[1]] = comm_tensor.view(
                        comm_tensors[i].shape)
                    t[lower_indices[0],
                      lower_indices[1]] = t.t()[lower_indices[0],
                                                lower_indices[1]]
                else:
                    t.copy_(comm_tensor.view(t.shape))
                offset += numel
        self.handles.clear()
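The loop above drains handles queued by an earlier launch step. A minimal sketch of that producer side, assuming (this is not shown in the snippet) that each group's tensors are packed into one merged buffer and allreduced in place; fp16 casting is omitted for brevity:

import torch
import horovod.torch as hvd

def allreduce_group_async(self, names, tensors, comm_tensors):
    # Hypothetical launch-side counterpart to synchronize() above: the
    # tuple layout (handle, names, tensors, comm_tensors) mirrors what
    # synchronize() unpacks; the packing layout is an assumption.
    name = ','.join(names)
    if name not in self.merged_tensors:
        numel = sum(t.numel() for t in comm_tensors)
        self.merged_tensors[name] = comm_tensors[0].new_zeros(numel)
    buf = self.merged_tensors[name]
    offset = 0
    for t in comm_tensors:
        n = t.numel()
        buf.data[offset:offset + n].copy_(t.reshape(-1))  # pack
        offset += n
    handle = hvd.allreduce_async_(buf, name=name)
    self.handles.append((handle, names, tensors, comm_tensors))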
Example 2
 def synchronize(self):
     for h in self.handles:
         hvd.synchronize(h)
     self.handles.clear()
     if self.merge:
         self._tensor_group.pull_alltensors()
         self._tensor_group.clear_group_flags()
Example 3
 def synchronize(self):
     for h in self.handles:
         hvd.synchronize(h)
     if self.merge:
         self._tensor_group.pull_alltensors()
         self._tensor_group.clear_group_flags()
     for name in self._name_tensors:
         tensor, comm_tensor = self._name_tensors[name]
         if self.symmetric:
             if self.fp16:
                 comm_tensor = comm_tensor.float()
             lower_indices = torch.tril_indices(tensor.shape[0],
                                                tensor.shape[1],
                                                device=tensor.device)
             upper_indices = torch.triu_indices(tensor.shape[0],
                                                tensor.shape[1],
                                                device=tensor.device)
             tensor[upper_indices[0], upper_indices[1]] = comm_tensor
             tensor[lower_indices[0],
                    lower_indices[1]] = tensor.t()[lower_indices[0],
                                                   lower_indices[1]]
         else:
             if self.fp16:
                 comm_tensor = comm_tensor.float()
             tensor.copy_(comm_tensor)
         if self.op == hvd.Average:
             tensor.div_(hvd.size())
     self._name_tensors.clear()
     self.handles.clear()
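Both synchronize() variants above rebuild a symmetric matrix from just its upper triangle. A self-contained toy check of that round trip (plain PyTorch, no Horovod required):

import torch

n = 4
a = torch.randn(n, n)
sym = a @ a.t()                        # a symmetric matrix

upper = torch.triu_indices(n, n)
comm = sym[upper[0], upper[1]]         # only the triu entries "travel"

out = torch.zeros(n, n)
out[upper[0], upper[1]] = comm         # restore the upper triangle
lower = torch.tril_indices(n, n)
out[lower[0], lower[1]] = out.t()[lower[0], lower[1]]  # mirror it down

assert torch.allclose(out, sym)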
Example 4
    def _broadcast_eigendecomp(self):
        """Broadcasts the eigendecompositions for all layers

        Note: we use `op=hvd.Sum` to simulate an allgather. Each rank will
        either compute the eigendecomposition for a factor or just return
        zeros, so we sum instead of averaging.
        """
        handles = []
        rank = hvd.rank()

        for i, m in enumerate(self.modules):
            rank_a = self.m_dA_ranks[m]
            rank_g = self.m_dG_ranks[m]
            name = self.module_names[i]

            h = hvd.broadcast_async_(self.m_QA[m], rank_a, name=name + 'mQA')
            handles.append(h)
            h = hvd.broadcast_async_(self.m_dA[m], rank_a, name=name + 'mdA')
            handles.append(h)
            h = hvd.broadcast_async_(self.m_QG[m], rank_g, name=name + 'mQG')
            handles.append(h)
            h = hvd.broadcast_async_(self.m_dG[m], rank_g, name=name + 'mdG')
            handles.append(h)

        for handle in handles:
            hvd.synchronize(handle)
Example 5
    def _allreduce_factors(self):
        """Allreduce the factors for all layers"""
        handles = []

        for m in self.modules:
            handles.append(hvd.allreduce_async_(self.m_A[m].data, op=hvd.Average))
            handles.append(hvd.allreduce_async_(self.m_G[m].data, op=hvd.Average))

        for handle in handles:
            hvd.synchronize(handle)
Example 6
def test_allgather():
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    tensor = torch.rand(10).float().cuda()
    print('rank: ', rank, ', tensor: ', tensor)
    #handle = hvd.allgather_async(tensor)
    #tensor = hvd.synchronize(handle)
    handle = hvd.broadcast_async(tensor, 0)
    hvd.synchronize(handle)
    print('---------')
    print('rank: ', rank, ', tensor: ', tensor)
Example 7
    def _broadcast_precon_grads(self):
        handles = []

        for i, m in enumerate(self.modules):
            rank_a, rank_g = self.module_ranks[m]
            assert rank_a == rank_g
            name = self.module_names[i]
            v = self.m_precon_grad[m]

            h = hvd.broadcast_async_(v, rank_a, name=name + 'preconGrad')
            handles.append(h)

        for handle in handles:
            hvd.synchronize(handle)
Example 8
    def _broadcast_inverse_factors(self):
        handles = []

        for i, m in enumerate(self.modules):
            rank_a, rank_g = self.module_ranks[m]
            name = self.module_names[i]

            h = hvd.broadcast_async_(self.m_inv_A[m], rank_a, name=name+'inverseA')
            handles.append(h)
            h = hvd.broadcast_async_(self.m_inv_G[m], rank_g, name=name+'inverseG')
            handles.append(h)
    
        for handle in handles:
            hvd.synchronize(handle)
Example 9
    def _reduce_factors(self, eigen_ranks):
        """Allreduce the factors for all layers"""
        handles = []

        for m in self.modules:
            name = self.module_name_map[m]
            ranks_a, ranks_g = eigen_ranks[m]
            rank_a = ranks_a[0]
            rank_g = ranks_g[0]
            handles.append(hvd.allreduce_async_(self.m_A[m].data, op=hvd.Average))
            handles.append(hvd.allreduce_async_(self.m_G[m].data, op=hvd.Average))

        for handle in handles:
            hvd.synchronize(handle)
Example 10
    def _allreduce_eigendecomp(self):
        """Allreduce the eigendecompositions for all layers

        Note: we use `op=hvd.Sum` to simulate an allgather. Each rank will
        either compute the eigendecomposition for a factor or just return
        zeros, so we sum instead of averaging.
        """
        handles = []

        for m in self.modules:
            handles.append(hvd.allreduce_async_(self.m_QA[m].data, op=hvd.Sum))
            handles.append(hvd.allreduce_async_(self.m_QG[m].data, op=hvd.Sum))

        for handle in handles:
            hvd.synchronize(handle)
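The docstring's note, summing to simulate an allgather, works because only the owning rank contributes nonzero data. A hedged sketch of that identity in isolation (`share_factor` is an illustrative name, not part of the source):

import horovod.torch as hvd

def share_factor(factor, owner):
    # non-owners zero their copy, so hvd.Sum leaves every rank holding
    # exactly the owner's values -- the same result as a broadcast
    if hvd.rank() != owner:
        factor.zero_()
    return hvd.allreduce(factor, op=hvd.Sum)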
Example 11
    def fsp_matrix_transfer(self):
        '''
        Obtain the feature maps of the bottlenecks (h*w*m), reshape them to
        (hw*m), then compute the matrix product (m*n); allgather the result
        and apply an L2 loss to it.
        :return: the FSP transfer loss
        '''
        handles = []
        matrix_group = []
        for key in self.activation:
            if 'in' in key:
                fm_in = self.activation[key]
            if 'out' in key:
                fm_out = self.activation[key]

                fm_in = fm_in.view(fm_in.shape[0], fm_in.shape[1], -1)
                fm_out = fm_out.view(fm_out.shape[0], fm_out.shape[1], -1)
                fm_out = torch.transpose(fm_out, 1, 2)
                fsp_matrix = torch.bmm(fm_in, fm_out) / fm_in.shape[-1]
                matrix_group.append(fsp_matrix)
                fsp_matrix = fsp_matrix.unsqueeze(0)
                handle = hvd.allgather_async(fsp_matrix, key)
                handles.append(handle)

        fsp_loss = 0
        for idx, handle in enumerate(handles):
            rec_fsp = hvd.synchronize(handle)
            for i in range(0, hvd.size()):
                if i != self.task_id:
                    fsp_loss += self.norm_loss(matrix_group[idx], rec_fsp[i])
        fsp_loss /= (hvd.size() - 1)
        self.log_dict['transfer_count'] += 1
        return fsp_loss
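fsp_matrix_transfer() reads feature maps out of self.activation; a hedged sketch of how that dict might be populated with forward hooks. The key suffixes match the 'in'/'out' checks above; the bottleneck attribute names (conv1, conv3) are assumptions:

def register_fsp_hooks(self, bottlenecks):
    self.activation = {}

    def save_to(key):
        def hook(module, inp, out):
            self.activation[key] = out.detach()
        return hook

    for i, block in enumerate(bottlenecks):
        # first and last conv of each bottleneck, per the h*w*m docstring
        block.conv1.register_forward_hook(save_to('block%d_in' % i))
        block.conv3.register_forward_hook(save_to('block%d_out' % i))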
Example 12
 def delayedupdate(self, val):
     if self.handle is None:
         self.sum += 0
     else:
         self.sum += hvd.synchronize(self.handle)
     self.handle = hvd.allreduce_async_(val.detach().cpu(), name=self.name)
     self.n += 1
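The delayed pattern above overlaps the metric allreduce with the next iteration: the allreduce launched at step k is synchronized lazily at step k+1, so one final drain is needed. A hedged usage sketch ("DelayedMeter" is an assumed class name wrapping delayedupdate; loader and train_step are assumed to exist):

import horovod.torch as hvd

meter = DelayedMeter(name='train_loss')
for batch in loader:                       # assumed data loader
    loss = train_step(batch)               # assumed training step
    meter.delayedupdate(loss)
if meter.handle is not None:               # drain the last handle
    meter.sum += hvd.synchronize(meter.handle)
print('mean loss:', float(meter.sum) / max(meter.n, 1))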
Example 13
 def forward(ctx, tensor, name):
     ctx.dim = tensor.shape[0]
     # we try to put all sync ops in forward pass
     ctx.all_dims = hvd.allgather(
         torch.tensor([ctx.dim], device=tensor.device)).view(hvd.size())
     handle = hvd.allgather_async(tensor, name)
     return hvd.synchronize(handle)
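The forward() above appears to belong to a torch.autograd.Function; a hedged completion with a matching backward (the class name and the backward body are assumptions, following the usual allgather gradient: sum gradients across ranks, then keep this rank's rows):

import torch
import horovod.torch as hvd

class AllGatherFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, name):
        ctx.dim = tensor.shape[0]
        # all sync ops stay in the forward pass, as in the snippet above
        ctx.all_dims = hvd.allgather(
            torch.tensor([ctx.dim], device=tensor.device)).view(hvd.size())
        handle = hvd.allgather_async(tensor, name)
        return hvd.synchronize(handle)

    @staticmethod
    def backward(ctx, grad_output):
        # every rank holds grad w.r.t. the full gathered tensor; sum the
        # contributions, then slice out the rows this rank provided
        grad = hvd.allreduce(grad_output, op=hvd.Sum)
        offset = int(ctx.all_dims[:hvd.rank()].sum())
        return grad[offset:offset + ctx.dim], None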
Example 14
    def attention_transfer(self):
        def at(x):
            return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))

        handles = []
        att_group = []
        for key in self.activation:
            at_out = at(self.activation[key])
            att_group.append(at_out)
            at_numpy = at_out.data.unsqueeze(0)
            handle = hvd.allgather_async(at_numpy, key)
            handles.append(handle)
            # self.norm_loss

        att_loss = 0
        for idx, handle in enumerate(handles):
            rec_att = hvd.synchronize(handle)
            # att_loss += self.norm_loss(att_group[idx], rec_att.mean(0).cuda(self.device))
            for i in range(0, hvd.size()):
                if i != self.task_id:
                    att_loss += self.norm_loss(att_group[idx],
                                               rec_att[i].cuda(self.device))
        att_loss /= (hvd.size() - 1)
        self.log_dict['transfer_count'] += 1
        return att_loss
Example 15
    def _broadcast_sparse_inv(self):
        handles = []
        rank = hvd.rank()

        for i, m in enumerate(self.modules):
            rank_a = self.m_dA_ranks[m]
            rank_g = self.m_dG_ranks[m]
            name = self.module_names[i]

            h = hvd.broadcast_async_(self.m_QA[m], rank_a, name=name + 'mQA')
            handles.append(h)
            h = hvd.broadcast_async_(self.m_QG[m], rank_g, name=name + 'mQG')
            handles.append(h)

        for handle in handles:
            hvd.synchronize(handle)
Example 16
    def allgather_sync(self, tensors, ranks):
        nworkers = hvd.size()
        rank = hvd.rank()
        start = 0
        sub_ranks = ranks[start:start+nworkers]
        sub_tensors = tensors[start:start+nworkers]
        while len(sub_ranks) > 0:
            #print('len(sub_ranks): ', len(sub_ranks))
            #print('len(sub_tensors): ', len(sub_tensors))
            try:
                idx = sub_ranks.index(rank)
            except ValueError:
                idx = -1
            if idx < 0:
                # this rank owns nothing in the current group; it
                # contributes an empty tensor to the allgather
                tensor = sub_tensors[0].new(0)
            else:
                tensor = sub_tensors[idx]
            handle = hvd.allgather_async(tensor.view(-1))
            sync_tensors = hvd.synchronize(handle)
            # only ranks that contributed to this group unpack the result
            if idx >= 0:
                offset = 0
                for r in sub_ranks:
                    original_t = sub_tensors[r]
                    numel = original_t.numel()
                    t = sync_tensors[offset:offset + numel]
                    original_t.copy_(t.view(original_t.shape))
                    offset += numel

            start += nworkers
            sub_ranks = ranks[start:start+nworkers]
            sub_tensors = tensors[start:start+nworkers]
Example 17
 def forward(self, handle):
     """
     Arguments:
         handle:
             Handle returned by an `AsyncAllReduce`, `AsyncAllGather` or
             `AsyncBroadcast`, which will be used to retrieve a `torch.Tensor`.
     """
     return hvd.synchronize(handle)
Example 18
    def _allgather_factors(self):
        """Allgather the factors for all layers"""
        handles = []

        def _get_value_and_idx(sparse_tensor):
            tensor = sparse_tensor.data.view(-1)
            one_indexes = tensor != 0
            indexes = one_indexes.nonzero().data.squeeze().view(-1)
            values = tensor.data[indexes]
            return values, indexes.int()

        for i, m in enumerate(self.modules):
            module_name = self.module_names[i]

            A_values, A_indexes = _get_value_and_idx(self.m_A[m].data)
            A_value_name = module_name + '_A_value'
            A_idx_name = module_name + '_A_idx'
            h_value = allgather_async(A_values, A_value_name)
            h_idx = allgather_async(A_indexes, A_idx_name)

            G_values, G_indexes = _get_value_and_idx(self.m_G[m].data)
            G_value_name = module_name + '_G_value'
            G_idx_name = module_name + '_G_idx'
            h_value_G = allgather_async(G_values, G_value_name)
            h_idx_G = allgather_async(G_indexes, G_idx_name)
            handles.append((h_value, h_idx, h_value_G, h_idx_G))

        for i, handle in enumerate(handles):
            module_name = self.module_names[i]
            module = self.modules[i]
            m_A = self.m_A[module].view(-1)
            m_A.fill_(0.0)
            m_G = self.m_G[module].view(-1)
            m_G.fill_(0.0)

            h_value_A, h_idx_A, h_value_G, h_idx_G = handle
            A_values = hvd.synchronize(h_value_A)
            A_indexes = hvd.synchronize(h_idx_A).long()
            m_A.scatter_add_(0, A_indexes, A_values)
            m_A.div_(hvd.size())

            G_values = hvd.synchronize(h_value_G)
            G_indexes = hvd.synchronize(h_idx_G).long()
            m_G.scatter_add_(0, G_indexes, G_values)
            m_G.div_(hvd.size())
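A self-contained toy of the aggregation step above: gathered (values, indexes) pairs from all workers are scatter-added into the flat factor, then divided by the worker count to average:

import torch

flat = torch.zeros(6)
values = torch.tensor([1.0, 2.0, 3.0, 4.0])  # worker0's pair ++ worker1's
indexes = torch.tensor([0, 3, 3, 5])
flat.scatter_add_(0, indexes, values)        # overlapping index 3 sums
flat.div_(2)                                 # hvd.size() == 2
print(flat)  # tensor([0.5000, 0.0000, 0.0000, 2.5000, 0.0000, 2.0000])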
Example 19
def allreduce_parameters(params):
    if isinstance(params, dict):
        params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run asynchronous allreduces.
    handles = []
    for name, p in params:
        handle = hvd.allreduce_async_(p, average=True, name=name)
        handles.append(handle)

    # Wait for completion.
    for handle in handles:
        hvd.synchronize(handle)
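A brief usage note, by analogy with hvd.broadcast_parameters; the call sites below are hypothetical (note that named_parameters() returns a generator and must be wrapped in list() to pass the type check above):

allreduce_parameters(model.state_dict())
allreduce_parameters(list(model.named_parameters()))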
Example 20
    def test_horovod_allreduce_async_fused(self):
        """Test that the allreduce correctly sums 1D, 2D, 3D tensors
        with Tensor Fusion."""
        hvd.init()
        size = hvd.size()
        dtypes = [
            torch.IntTensor, torch.LongTensor, torch.FloatTensor,
            torch.DoubleTensor
        ]
        if _fp16_supported:
            dtypes += [torch.HalfTensor]
        if torch.cuda.is_available():
            dtypes += [
                torch.cuda.IntTensor, torch.cuda.LongTensor,
                torch.cuda.FloatTensor, torch.cuda.DoubleTensor
            ]
            if _fp16_supported:
                dtypes += [torch.cuda.HalfTensor]
        dims = [1, 2, 3]
        tests = []
        is_hvd_poll_false_once = False
        for dtype, dim in itertools.product(dtypes, dims):
            torch.manual_seed(1234)
            tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
            tensor = tensor.type(dtype)
            handle = hvd.allreduce_async(tensor, average=False)
            if not hvd.poll(handle):
                is_hvd_poll_false_once = True
            tensor, = self.convert_cpu_fp16_to_fp32(tensor)
            multiplied = tensor * size
            tests.append((dtype, multiplied, handle))

        # Make sure it's an asynchronous operation.
        assert is_hvd_poll_false_once, 'hvd.poll() always returns True, not an async op?'

        for dtype, multiplied, handle in tests:
            summed = hvd.synchronize(handle)
            summed, = self.convert_cpu_fp16_to_fp32(summed)
            max_difference = summed.sub(multiplied).max()

            # Threshold for floating point equality depends on number of
            # ranks, since we're comparing against precise multiplication.
            if size <= 3 or dtype in [
                    torch.IntTensor, torch.LongTensor, torch.cuda.IntTensor,
                    torch.cuda.LongTensor
            ]:
                threshold = 0
            elif size < 10:
                threshold = 1e-4
            elif size < 15:
                threshold = 5e-4
            else:
                break

            assert max_difference <= threshold, 'hvd.allreduce produces incorrect results'
Example 21
 def wait_receive(self, handles, ctx):
     tensors_decompressed = []
     for ranki in handles:
         tensors_compressed = [synchronize(h) for h in ranki]
         tensor_decompressed = self.compressor.decompress(
             tensors_compressed, ctx)
         tensors_decompressed.append(tensor_decompressed)
     tensor_aggregated = self.compressor.aggregate(tensors_decompressed)
     return (
         tensor_aggregated /
         self.world_size) if self.compressor.average else tensor_aggregated
Example 22
def maybe_allreduce_grads(model):
    if hvd.size() > 1:
        tstart_reduce = time.time()
        named_parameters = list(
            sorted(model.named_parameters(), key=lambda a: a[0]))
        grad_handles = []
        for name, p in named_parameters:
            if p.requires_grad:
                if p.grad is None:
                    p.grad = torch.zeros_like(p)
                with torch.no_grad():
                    grad_handles.append(hvd.allreduce_async_(p.grad,
                                                             name=name))
        for handle in grad_handles:
            hvd.synchronize(handle)
        tlogger.record_tabular("TimeElapsedAllReduce",
                               time.time() - tstart_reduce)
        if time.time() - tstart_reduce > 5:
            import socket
            tlogger.info(
                "Allreduce took more than 5 seconds for node {} (rank {})".
                format(socket.gethostname(), hvd.rank()))
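A hedged sketch of where maybe_allreduce_grads() sits in a training step, as an alternative to wrapping the optimizer in hvd.DistributedOptimizer (model, optimizer, and loss are assumed to exist):

optimizer.zero_grad()
loss.backward()
maybe_allreduce_grads(model)  # averages p.grad across ranks in place
optimizer.step()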
Example 23
    def wait_receive(self, result, ctx):
        handles, tensor_sizes = result
        tensors_ag = []
        for handle, sizes in zip(handles, tensor_sizes):
            gathered = synchronize(handle)
            tensors_ag.append(gathered.split(sizes))

        list_tensor_decompressed = []
        for tensor_compressed in zip(*tensors_ag):
            tensor_decompressed = self.compressor.decompress(tensor_compressed, ctx)
            list_tensor_decompressed.append(tensor_decompressed)

        tensors_aggregated = self.compressor.aggregate(list_tensor_decompressed)
        return (tensors_aggregated / self.world_size) if self.compressor.average else tensors_aggregated
Example 24
    def weights_transfer(self):
        # transfer model weights
        weights = copy.deepcopy(self.network.state_dict())
        handles = []
        for name in weights:
            # TODO: need to consider bias
            if 'weight' in name:
                # print(self.task_id, 'send', name)
                handle = hvd.allgather_async(weights[name], name)
                handles.append(handle)

        hidx = 0
        for name, param in self.network.named_parameters():
            if 'weight' in name:
                # print(self.task_id, 'rec', name)
                rec_weights = hvd.synchronize(handles[hidx])
                hidx += 1
                # print(rec_weights.shape)

                n_num = param.shape[0]
                rec_weights = list(torch.split(rec_weights, n_num, 0))

                del rec_weights[self.task_id]

                # TODO weights cat in the first dim, 2*[64,3]--> [128,3]
                # logging.info(type(rec_weights), rec_weights.shape)
                # calculate IOM of each filter
                im_list = []
                for i in range(param.shape[0]):
                    im_list.append(
                        torch.sum(torch.abs(param[i])).data.cpu().numpy())
                im_list = np.array(im_list)
                # print('minimal weight sum is {} size {}'.format(im_list.min(), im_list.shape[0]))

                for i, im in enumerate(im_list):
                    prob = 1 - stats.norm(0, 2).cdf(im)
                    if np.random.rand() < prob:
                        random_sender = np.random.randint(0, len(rec_weights))
                        new_param = rec_weights[random_sender].clone()
                        # random pic
                        random_filter = np.random.randint(
                            0, new_param.shape[0])
                        # TODO give larger weights more chance
                        weights[name][i] = new_param[random_filter]
                        self.log_dict['transfer_count'] += 1
            # self.network.state_dict()[name].copy_(param.clone())
            # TODO: maybe modify the optimizer
        self.network.load_state_dict(weights)
        hvd.allreduce(torch.zeros(1), name='Barrier')
Example 25
def distributed_matmul_tn(left: Tensor, right: Tensor) -> Tensor:
    """
    Multiply two sequence tensors to obtain the result of :math:`A^{T} B`.

    Left and right inputs can be N-dimensional tensors, where the first one
    must be of size :math:`* \times \frac{T}{N} \times T` and the second one
    of size :math:`* \times \frac{T}{N} \times D`, where :math:`T` is the
    total length, :math:`N` is the total number of processes available and
    :math:`D` the dimension of the sequence. The result of this function is a
    tensor of size :math:`* \times \frac{T}{N} \times D` that contains the
    chunk of the result belonging to each process.

    Inputs
    ------
    left: Tensor
        :math:`A` in :math:`A^T B`, must be of size
        :math:`* \times \frac{T}{N} \times T`
    right: Tensor
        :math:`B` in :math:`A^T B`, must be of size
        :math:`* \times \frac{T}{N} \times D`

    Returns
    -------
    result: Tensor
        For each process, this function computes the corresponding segment
        of the operation :math:`A^T B`, of size
        :math:`* \times \frac{T}{N} \times D`
    """
    cols = left.size(-1)
    world_size = get_world_size()
    rank = get_rank()

    split_size = cols // world_size
    splits = left.split(split_size, -1)
    rank_block = None

    synchronize()
    for r in range(world_size):
        rank_split = splits[r]
        rank_multiplication = torch.matmul(rank_split.transpose(-1, -2), right)
        handle = hvd.allreduce_async(rank_multiplication,
                                     name=f'matmul_tn_{r}',
                                     op=hvd.Sum)
        if r == rank:
            rank_block = hvd.synchronize(handle)
    return rank_block.contiguous()
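A single-process sanity check of the block identity behind distributed_matmul_tn: with the sequence split row-wise across N workers, :math:`A^T B = \sum_r A_r^T B_r`, which is what the per-block hvd.Sum allreduce computes:

import torch

T, D, N = 8, 5, 4
A = torch.randn(T, T)           # full left operand
B = torch.randn(T, D)           # full right operand
full = A.t() @ B

acc = torch.zeros(T, D)
for r in range(N):              # worker r holds rows r*T/N .. (r+1)*T/N
    A_r = A[r * T // N:(r + 1) * T // N]
    B_r = B[r * T // N:(r + 1) * T // N]
    acc += A_r.t() @ B_r        # summed by the allreduce in the real code
assert torch.allclose(full, acc, atol=1e-5)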
Example 26
    def wait_receive(self, handles, ctx):
        tensors_compressed = []
        for h in handles:
            tensor_compressed = synchronize(h)
            tensors_compressed.append(tensor_compressed.chunk(self.world_size))

        tensors_decompressed = []
        if len(tensors_compressed) == 1:
            for tensor in tensors_compressed[0]:
                tensors_decompressed.append(
                    self.compressor.decompress([tensor], ctx))
        elif len(tensors_compressed) == 2:
            for tensor, meta in zip(tensors_compressed[0],
                                    tensors_compressed[1]):
                tensors_decompressed.append(
                    self.compressor.decompress((tensor, meta), ctx))

        tensors_decompressed = self.memory.aggregate(tensors_decompressed)
        return tensors_decompressed
Example 27
    def test_horovod_allreduce_async_fused(self):
        """Test that the allreduce correctly sums 1D, 2D, 3D tensors
        with Tensor Fusion."""
        hvd.init()
        size = hvd.size()
        dtypes = [torch.IntTensor, torch.LongTensor,
                  torch.FloatTensor, torch.DoubleTensor]
        if torch.cuda.is_available():
            dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor,
                       torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
        dims = [1, 2, 3]
        tests = []
        is_hvd_poll_false_once = False
        for dtype, dim in itertools.product(dtypes, dims):
            torch.manual_seed(1234)
            tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
            tensor = tensor.type(dtype)
            handle = hvd.allreduce_async(tensor, average=False)
            if not hvd.poll(handle):
                is_hvd_poll_false_once = True
            multiplied = tensor * size
            tests.append((dtype, multiplied, handle))

        # Make sure it's an asynchronous operation.
        assert is_hvd_poll_false_once, 'hvd.poll() always returns True, not an async op?'

        for dtype, multiplied, handle in tests:
            summed = hvd.synchronize(handle)
            max_difference = summed.sub(multiplied).max()

            # Threshold for floating point equality depends on number of
            # ranks, since we're comparing against precise multiplication.
            if size <= 3 or dtype in [torch.IntTensor, torch.LongTensor,
                                      torch.cuda.IntTensor, torch.cuda.LongTensor]:
                threshold = 0
            elif size < 10:
                threshold = 1e-4
            elif size < 15:
                threshold = 5e-4
            else:
                break

            assert max_difference <= threshold, 'hvd.allreduce produces incorrect results'
Example 28
    def _allgather_factors(self):
        """Allgather the factors for all layers"""
        handles = []

        def _get_value_and_idx(sparse_tensor):
            tensor = sparse_tensor.data.view(-1)
            one_indexes = tensor != 0.0
            indexes = one_indexes.nonzero().data.squeeze().view(-1)
            values = tensor.data[indexes]
            return values, indexes.int()

        for i, m in enumerate(self.modules):
            module_name = self.module_names[i]

            A_values, A_indexes = self.m_sparseA[
                m]  #_get_value_and_idx(self.m_A[m].data)
            if A_values.numel() == 0:
                continue
            A_value_name = module_name + '_A_value'
            A_idx_name = module_name + '_A_idx'
            #h_value = hvd.allgather_async(A_values, A_value_name)
            #h_idx = hvd.allgather_async(A_indexes, A_idx_name)
            h_value = hvd.allgather_async(A_values)
            h_idx = hvd.allgather_async(A_indexes)

            G_values, G_indexes = self.m_sparseG[
                m]  #_get_value_and_idx(self.m_G[m].data)
            G_value_name = module_name + '_G_value'
            G_idx_name = module_name + '_G_idx'
            #h_value_G = hvd.allgather_async(G_values, G_value_name)
            #h_idx_G = hvd.allgather_async(G_indexes, G_idx_name)
            if G_values is not None and G_values.numel() > 0:
                h_value_G = hvd.allgather_async(G_values)
                h_idx_G = hvd.allgather_async(G_indexes)
                handles.append((m, h_value, h_idx, h_value_G, h_idx_G))

        num_of_workers = hvd.size()

        def _decompress(values, indices, output):
            numel = indices.numel()
            real_num_values = numel // num_of_workers
            for i in range(num_of_workers):
                tmp_values = values.data[i * real_num_values:(i + 1) *
                                         real_num_values]
                tmp_indices = indices.data[i * real_num_values:(i + 1) *
                                           real_num_values]
                output[tmp_indices] += tmp_values

        for handle in handles:
            # the module travels with its handles: modules can be skipped
            # above (empty sparse factors), so indexing self.modules by
            # position would misalign
            module, h_value_A, h_idx_A, h_value_G, h_idx_G = handle
            m_A = self.m_A[module].view(-1)
            m_A.fill_(0.0)
            m_G = self.m_G[module].view(-1)
            m_G.fill_(0.0)
            A_values = hvd.synchronize(h_value_A)
            A_indexes = hvd.synchronize(h_idx_A).long()
            _decompress(A_values, A_indexes, m_A)
            #print(A_indexes[0])
            #print(A_values[0])
            #m_A.scatter_add_(0, A_indexes, A_values)
            m_A.div_(hvd.size())

            G_values = hvd.synchronize(h_value_G)
            G_indexes = hvd.synchronize(h_idx_G).long()
            #print('G_I: ', G_indexes[0])
            #print('G_V: ', G_values[0])
            #m_G.scatter_add_(0, G_indexes, G_values)
            _decompress(G_values, G_indexes, m_G)
            m_G.div_(hvd.size())
Example 29
def upsnet_train():

    if is_master:
        logger.info('training config:{}\n'.format(pprint.pformat(config)))
    gpus = [torch.device('cuda', int(_)) for _ in config.gpus.split(',')]
    num_replica = hvd.size() if config.train.use_horovod else len(gpus)
    num_gpus = 1 if config.train.use_horovod else len(gpus)

    # create models
    train_model = eval(config.symbol)().cuda()
        
    # create optimizer
    params_lr = train_model.get_params_lr()
    # we use custom optimizer and pass lr=1 to support different lr for different weights
    optimizer = SGD(params_lr, lr=1, momentum=config.train.momentum, weight_decay=config.train.wd)
    if config.train.use_horovod:
        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=train_model.named_parameters())
    optimizer.zero_grad()

    # create data loader
    train_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.image_set.split('+'), flip=config.train.flip, result_path=final_output_path)
    val_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.test_image_set.split('+'), flip=False, result_path=final_output_path, phase='val')
    if config.train.use_horovod:
        train_sampler = distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        val_sampler = distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size, sampler=train_sampler, num_workers=num_gpus * 4, drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size, sampler=val_sampler, num_workers=num_gpus * 4, drop_last=False, collate_fn=val_dataset.collate)
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=config.train.shuffle, num_workers=num_gpus * 4 if not config.debug_mode else num_gpus * 4, drop_last=False, collate_fn=train_dataset.collate)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size, shuffle=False, num_workers=num_gpus * 4 if not config.debug_mode else num_gpus * 4, drop_last=False, collate_fn=val_dataset.collate)

    # preparing
    curr_iter = config.train.begin_iteration
    batch_end_callback = [Speedometer(num_replica * config.train.batch_size, config.train.display_iter)]
    metrics = []
    metrics_name = []
    if config.network.has_rpn:
        metrics.extend([AvgMetric(name='rpn_cls_loss'), AvgMetric(name='rpn_bbox_loss'),])
        metrics_name.extend(['rpn_cls_loss', 'rpn_bbox_loss'])
    if config.network.has_rcnn:
        metrics.extend([AvgMetric(name='rcnn_accuracy'), AvgMetric(name='cls_loss'), AvgMetric(name='bbox_loss'),])
        metrics_name.extend(['rcnn_accuracy', 'cls_loss', 'bbox_loss'])
    if config.network.has_mask_head:
        metrics.extend([AvgMetric(name='mask_loss'), ])
        metrics_name.extend(['mask_loss'])
    if config.network.has_fcn_head:
        metrics.extend([AvgMetric(name='fcn_loss'), ])
        metrics_name.extend(['fcn_loss'])
        if config.train.fcn_with_roi_loss:
            metrics.extend([AvgMetric(name='fcn_roi_loss'), ])
            metrics_name.extend(['fcn_roi_loss'])
    if config.network.has_panoptic_head:
        metrics.extend([AvgMetric(name='panoptic_accuracy'), AvgMetric(name='panoptic_loss'), ])
        metrics_name.extend(['panoptic_accuracy', 'panoptic_loss'])

    if config.train.resume:
        train_model.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth')), resume=True)
        optimizer.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth')))
        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)
    else:
        if is_master:
            train_model.load_state_dict(torch.load(config.network.pretrained))

        if config.train.use_horovod:
            hvd.broadcast_parameters(train_model.state_dict(), root_rank=0)

    if not config.train.use_horovod:
        train_model = DataParallel(train_model, device_ids=[int(_) for _ in config.gpus.split(',')]).to(gpus[0])

    if is_master:
        batch_end_callback[0](0, 0)

    train_model.eval()

    # start training
    while curr_iter < config.train.max_iteration:
        if config.train.use_horovod:
            train_sampler.set_epoch(curr_iter)

            if config.network.use_syncbn:
                train_model.train()
                if config.network.backbone_freeze_at > 0:
                    train_model.freeze_backbone(config.network.backbone_freeze_at)
                if config.network.backbone_fix_bn:
                    train_model.resnet_backbone.eval()


            for inner_iter, batch in enumerate(train_loader):
                data, label, _ = batch
                for k, v in data.items():
                    data[k] = v if not torch.is_tensor(v) else v.cuda()
                for k, v in label.items():
                    label[k] = v if not torch.is_tensor(v) else v.cuda()

                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                output = train_model(data, label)
                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean() * config.train.bbox_loss_weight
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)

                losses = []
                losses.append(allreduce_async(loss, name='train_total_loss'))
                for l in metrics_name:
                    losses.append(allreduce_async(output[l].mean(), name=l))

                loss = hvd.synchronize(losses[0]).item()
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = hvd.synchronize(losses[i + 1]).item()
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                        metric.update(_, _, loss)
                curr_iter += 1


                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        if 'momentum_buffer' in optimizer.state_dict()['state'][k]:
                            optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)

                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.state.pth'))
        else:
            inner_iter = 0
            train_iterator = iter(train_loader)
            while inner_iter + num_gpus <= len(train_loader):
                batch = []
                for gpu_id in gpus:
                    data, label, _ = next(train_iterator)
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                    batch.append((data, label))
                    inner_iter += 1
                lr = adjust_learning_rate(optimizer, curr_iter, config)
                optimizer.zero_grad()
                if config.train.use_horovod:
                    output = train_model(data, label)
                else:
                    output = train_model(*batch)

                loss = 0
                if config.network.has_rpn:
                    loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean()
                if config.network.has_rcnn:
                    loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean()
                if config.network.has_mask_head:
                    loss = loss + output['mask_loss'].mean()
                if config.network.has_fcn_head:
                    loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight
                    if config.train.fcn_with_roi_loss:
                        loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2
                if config.network.has_panoptic_head:
                    loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight
                loss.backward()
                optimizer.step(lr)
                
                losses = []
                losses.append(loss.item())
                for l in metrics_name:
                    losses.append(output[l].mean().item())

                loss = losses[0]
                if is_master:
                    writer.add_scalar('train_total_loss', loss, curr_iter)
                for i, (metric, l) in enumerate(zip(metrics, metrics_name)):
                    loss = losses[i + 1]
                    if is_master:
                        writer.add_scalar('train_' + l, loss, curr_iter)
                        metric.update(_, _, loss)
                curr_iter += 1

                if curr_iter in config.train.decay_iteration:
                    if is_master:
                        logger.info('decay momentum buffer')
                    for k in optimizer.state_dict()['state'].keys():
                        optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10)

                if is_master:
                    if curr_iter % config.train.display_iter == 0:
                        for callback in batch_end_callback:
                            callback(curr_iter, metrics)


                    if curr_iter % config.train.snapshot_step == 0:
                        logger.info('taking snapshot ...')
                        torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.pth'))
                        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.state.pth'))

            while True:
                try:
                    next(train_iterator)
                except StopIteration:
                    break

        for metric in metrics:
            metric.reset()

        if config.train.eval_data:
            train_model.eval()

            if config.train.use_horovod:
                for inner_iter, batch in enumerate(val_loader):
                    data, label, _ = batch
                    for k, v in data.items():
                        data[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)
                    for k, v in label.items():
                        label[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True)

                    with torch.no_grad():
                        output = train_model(data, label)

                    for metric, l in zip(metrics, metrics_name):
                        loss = hvd.allreduce(output[l].mean()).item()
                        if is_master:
                            metric.update(_, _, loss)

            else:
                inner_iter = 0
                val_iterator = iter(val_loader)
                while inner_iter + len(gpus) <= len(val_loader):
                    batch = []
                    for gpu_id in gpus:
                        data, label, _ = next(val_iterator)
                        for k, v in data.items():
                            data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        for k, v in label.items():
                            label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True)
                        batch.append((data, label))
                        inner_iter += 1

                    with torch.no_grad():
                        if config.train.use_horovod:
                            output = train_model(data, label)
                        else:
                            output = train_model(*batch)

                    losses = []
                    for l in metrics_name:
                        losses.append(allreduce_async(output[l].mean(), name=l) if config.train.use_horovod else output[l].mean().item())

                    for metric, loss in zip(metrics, losses):
                        loss = hvd.synchronize(loss).item() if config.train.use_horovod else loss
                        if is_master:
                            metric.update(_, _, loss)

                while True:
                    try:
                        next(val_iterator)
                    except StopIteration:
                        break

            s = 'Batch [%d]\t Epoch[%d]\t' % (curr_iter, curr_iter // len(train_loader))

            for metric in metrics:
                m, v = metric.get()
                s += 'Val-%s=%f,\t' % (m, v)
                if is_master:
                    writer.add_scalar('val_' + m, v, curr_iter)
                    metric.reset()
            if is_master:
                logger.info(s)

    if is_master and config.train.use_horovod:
        logger.info('taking snapshot ...')
        torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
    elif not config.train.use_horovod:
        logger.info('taking snapshot ...')
        torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
Example 30
    def step(self, closure=None, epoch=None):
        """Perform one K-FAC step

        Note:
        - this function should always be called before `optimizer.step()`
        - gradients must be averaged across ranks before calling `step()`

        Args:
          closure: for compatibility with the base optimizer class.
              `closure` is ignored by KFAC
          epoch (int, optional): epoch to use for determining when to end
              the `diag_warmup` period. `epoch` is not necessary if not using
              `diag_warmup`
        """

        # Update params, used for compatibility with `KFACParamScheduler`
        group = self.param_groups[0]
        self.lr = group['lr']
        self.damping = group['damping']
        self.fac_update_freq = group['fac_update_freq']
        self.kfac_update_freq = group['kfac_update_freq']

        updates = {}
        handles = []

        if epoch is None:
            if self.diag_warmup > 0:
                print("WARNING: diag_warmup > 0 but epoch was not passed to "
                      "KFAC.step(). Defaulting to no diag_warmup")
            diag_blocks = self.diag_blocks
        else:
            diag_blocks = self.diag_blocks if epoch >= self.diag_warmup else 1

        if hvd.size() > 1 and self.steps % self.fac_update_freq == 0:
            self.fw_merged_comm.synchronize()
            self.bw_merged_comm.synchronize()

            #for handle in self.fw_factor_handles:
            #    hvd.synchronize(handle)
            #self.fw_factor_handles.clear()
            #for handle in self.bw_factor_handles:
            #    hvd.synchronize(handle)
            #self.bw_factor_handles.clear()

        # if we are switching from no diag approx to approx, we need to clear
        # off-block-diagonal elements
        if not self.have_cleared_Q and \
                epoch == self.diag_warmup and \
                self.steps % self.kfac_update_freq == 0:
            self._clear_eigen()
            self.have_cleared_Q = True

        if self.steps % self.kfac_update_freq == 0:
            # reset rank iter so each device gets the same layers
            # to compute, taking advantage of caching
            self.rank_iter.reset()
            handles = []

            #eigen_ranks = self._generate_eigen_ranks(epoch)
            eigen_ranks = self._generate_eigen_ranks_uniform(epoch)
            #eigen_ranks = self._generate_eigen_ranks_naive(epoch)

            for module in self.modules:
                ranks_a, ranks_g = eigen_ranks[module]
                self.m_dA_ranks[module] = ranks_a[0]
                self.m_dG_ranks[module] = ranks_g[0]
                rank_a = ranks_a[0]
                rank_g = ranks_g[0]

                self._update_eigen_A(module, ranks_a)
                h1 = hvd.broadcast_async_(self.m_QA[module], rank_a)
                h2 = hvd.broadcast_async_(self.m_dA[module], rank_a)
                self._update_eigen_G(module, ranks_g)
                h3 = hvd.broadcast_async_(self.m_QG[module], rank_g)
                h4 = hvd.broadcast_async_(self.m_dG[module], rank_g)
                handles.append((h1, h2, h3, h4))

            if hvd.size() > 1:
                #for handle in handles:
                #    hvd.synchronize(handle)
                #self._allreduce_eigendecomp()
                #self._broadcast_eigendecomp()
                pass

        for i, module in enumerate(self.modules):
            if hvd.size() > 1 and len(handles) > 0:
                h1, h2, h3, h4 = handles[i]
                hvd.synchronize(h1)
                hvd.synchronize(h2)
                hvd.synchronize(h3)
                hvd.synchronize(h4)

            grad = self._get_grad(module)
            precon_grad = self._get_preconditioned_grad(module, grad)
            updates[module] = precon_grad

        #self._update_scale_grad(updates)

        self.steps += 1
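A hedged sketch of the call order the step() docstring prescribes; `preconditioner` and `optimizer` are assumed names, with optimizer.synchronize()/skip_synchronize() from hvd.DistributedOptimizer doing the gradient averaging:

loss.backward()
optimizer.synchronize()               # average gradients across ranks
preconditioner.step(epoch=epoch)      # the KFAC step() above
with optimizer.skip_synchronize():
    optimizer.step()                  # apply the preconditioned update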
Example 31
def barrier():
    # `sync_tensor` is assumed to be a small module-level CUDA tensor,
    # e.g. torch.zeros(1).cuda(), reused for every barrier call
    torch.cuda.synchronize()
    handle = hvd.broadcast_async_(sync_tensor, root_rank=0)
    hvd.synchronize(handle)
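For comparison, the weights_transfer() example above (Example 24) achieves the same barrier effect without a dedicated sync tensor, via a dummy blocking allreduce:

import torch
import horovod.torch as hvd

def allreduce_barrier():
    hvd.allreduce(torch.zeros(1), name='Barrier')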