def dist_collect(x):
    """Gather ``x`` from every rank and concatenate the pieces along dim 0.

    Args:
        x: local tensor of shape (mini_batch, ...).

    Returns:
        Contiguous tensor of shape (mini_batch * world_size, ...), built by
        concatenating the list returned by ``functional.all_gather``.
    """
    local = x.contiguous()
    # zeros_like already inherits device/dtype and, for a contiguous input,
    # yields a contiguous buffer, so each receive slot is ready as-is.
    receive_slots = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    gathered = functional.all_gather(receive_slots, local)
    stacked = torch.cat(gathered, dim=0)
    return stacked.contiguous()
def test_all_gather():
    """Manual check of gradients flowing through ``distops.all_gather``.

    Every rank sends ``(rank + 1) * x`` (with ``x = 3``), sums what it
    receives, scales by ``rank + 1`` and backpropagates, printing the gradient
    obtained through the collective. Rank 0 then replays the same computation
    for all ranks inside one process (no communication) and prints those
    gradients, presumably so the two results can be compared by eye.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    if rank == 0:
        print("ALL GATHER TEST\n")
    dist.barrier()

    # Distributed pass: gradient computed through the collective.
    x = torch.tensor(3., requires_grad=True)
    y = (rank + 1) * x
    print(rank, "Sending y:", y)
    received = distops.all_gather(
        list(torch.zeros(world_size)), y, next_backprop=None, inplace=True
    )
    print(rank, "Received tensor:", received)
    loss = torch.sum(torch.stack(received)) * (rank + 1)
    loss.backward()
    print(rank, "Gradient with MPI:", x.grad)

    dist.barrier()
    if rank != 0:
        return

    # Single-process reference: simulate every rank locally on rank 0.
    print()
    inputs = [torch.tensor(3., requires_grad=True) for _ in range(world_size)]
    sent = [(r + 1) * inp for r, inp in enumerate(inputs)]
    # Each simulated rank "receives" its own copy of everything that was sent.
    inbox = [[torch.clone(msg) for msg in sent] for _ in range(world_size)]
    sums = [torch.sum(torch.stack(msgs)) for msgs in inbox]
    losses = [(r + 1) * total for r, total in enumerate(sums)]

    # The losses share the graph through `sent`, so all but the last backward
    # must keep the graph alive for the remaining calls.
    *leading, last = losses
    for partial in leading:
        partial.backward(retain_graph=True)
    last.backward()

    for r, inp in enumerate(inputs):
        print(r, "Gradient in single process:", inp.grad)
    print('-' * 50)
def get_similarity_matrix(outputs, chunk=2, multi_gpu=False):
    '''
    Compute the (unnormalised) dot-product similarity matrix of ``outputs``.

    - outputs: (B', d) tensor for B' = B * chunk
    - chunk: number of equal pieces ``outputs`` is split into before gathering
      (only used when ``multi_gpu`` is True)
    - multi_gpu: if True, each chunk is all-gathered across ranks and the
      gathered chunks are concatenated in chunk order, so the matrix covers
      every rank's features
    - returns: (N, N) tensor of pairwise dot products, where N = B' locally or
      B' * world_size when gathered
    '''
    if multi_gpu:
        gathered_chunks = [
            torch.cat(distops.all_gather(
                [torch.empty_like(piece) for _ in range(dist.get_world_size())],
                piece,
            ))
            for piece in outputs.chunk(chunk)
        ]
        outputs = torch.cat(gathered_chunks)

    # (N, d) @ (d, N) -> (N, N)
    return torch.mm(outputs, outputs.t())
def Supervised_NT_xent(sim_matrix, labels, temperature=0.5, chunk=2, eps=1e-8, multi_gpu=False):
    '''
    Compute the supervised NT-Xent loss from a similarity matrix.

    Rows/columns of ``sim_matrix`` are matched against ``labels`` repeated
    twice, i.e. the code assumes a layout of [view 1 of every sample,
    view 2 of every sample].

    - sim_matrix: (B', B') tensor for B' = B * chunk (first 2B are pos samples)
    - labels: per-sample labels; after the optional all-gather and the 2x
      repeat they must number exactly B'
    - temperature: divisor applied to the similarities before exponentiation
    - chunk: number of chunks B' consists of; only used to derive B
    - eps: small constant added inside the log and to the denominators
    - multi_gpu: if True, labels are all-gathered across ranks first (assumed
      to match a sim_matrix built from gathered features - TODO confirm)
    - returns: scalar loss tensor

    Raises ValueError if the repeated label count does not equal B'.

    NOTE(review): the positive mask includes each sample's match with itself.
    That diagonal entry of the loss matrix is -log(eps) (a constant with no
    gradient), so it offsets the loss value and scales the weight of true
    positives by |P| / (|P| + 1). Also, the final division by 2 * B equals B'
    only when chunk == 2. Both are kept as-is to preserve loss values.
    '''
    device = sim_matrix.device

    if multi_gpu:
        gather_t = [torch.empty_like(labels) for _ in range(dist.get_world_size())]
        labels = torch.cat(distops.all_gather(gather_t, labels))
    # One label per view: the two views of the batch share the same labels.
    labels = labels.repeat(2)

    n_rows = sim_matrix.size(0)
    if labels.numel() != n_rows:
        raise ValueError(
            f"Supervised_NT_xent: expected {n_rows} labels after repeating "
            f"twice, got {labels.numel()} (is chunk == 2?)"
        )

    # Subtract the row max for numerical stability. This is a per-row shift,
    # so it cancels in the softmax ratio below, up to eps.
    logits_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
    sim_matrix = sim_matrix - logits_max.detach()

    B = n_rows // chunk  # B = B' / chunk

    # Build the identity directly on the target device rather than creating it
    # on the host and copying it over on every call.
    eye = torch.eye(B * chunk, device=device)  # (B', B')
    sim_matrix = torch.exp(sim_matrix / temperature) * (1 - eye)  # remove diagonal

    denom = torch.sum(sim_matrix, dim=1, keepdim=True)
    sim_matrix = -torch.log(sim_matrix / (denom + eps) + eps)  # loss matrix

    labels = labels.contiguous().view(-1, 1)
    # Same-label indicator, row-normalised so each anchor averages its positives.
    mask = torch.eq(labels, labels.t()).float().to(device)
    mask = mask / (mask.sum(dim=1, keepdim=True) + eps)

    loss = torch.sum(mask * sim_matrix) / (2 * B)

    return loss
def _gather_across_ranks(tensor):
    '''All-gather ``tensor`` from every rank via distops and concatenate along dim 0.'''
    buffers = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    return torch.cat(distops.all_gather(buffers, tensor))


def pairwise_similarity(outputs, temperature=0.5, multi_gpu=False, adv_type='None'):
    '''
    Compute the temperature-scaled pairwise cosine similarity matrix.

    input:
      - outputs: (N, d) aggregated outputs. When gathering, it is split by row
        into equal views (presumably [view 1; view 2] for adv_type 'None' and
        [view 1; view 2; view 3] for 'Rep' - TODO confirm with callers)
      - temperature: similarities are multiplied by 1 / temperature
      - multi_gpu: if True and adv_type is 'None' or 'Rep', each view is
        all-gathered across ranks and the gathered views are concatenated in
        view order; for any other adv_type no gathering happens
      - adv_type: 'None' (2 views), 'Rep' (3 views), or anything else
    return:
      - (M, M) similarity matrix, M = number of rows after optional gathering
      - the (possibly gathered) outputs tensor

    Raises ValueError when multi_gpu is set and adv_type contains 'Rep' but is
    not exactly 'Rep' (previously this crashed with an unbound variable).

    NOTE(review): the right-hand factor of the product is detached, so
    gradients flow only through the left operand - confirm this is intended.
    '''
    if multi_gpu and adv_type == 'None':
        half = outputs.shape[0] // 2
        # The second view keeps any leftover row when the size is odd.
        outputs = torch.cat((
            _gather_across_ranks(outputs[:half]),
            _gather_across_ranks(outputs[half:]),
        ))
    elif multi_gpu and 'Rep' in adv_type:
        if adv_type != 'Rep':
            raise ValueError(
                f"pairwise_similarity: unsupported adv_type {adv_type!r}; "
                "only 'Rep' is handled among the 'Rep' variants"
            )
        n_views = 3
        per_view = outputs.shape[0] // n_views
        # Views are gathered in order so every rank issues the same sequence
        # of collectives; rows beyond n_views * per_view are dropped.
        outputs = torch.cat(tuple(
            _gather_across_ranks(outputs[v * per_view:(v + 1) * per_view])
            for v in range(n_views)
        ))

    B = outputs.shape[0]
    outputs_norm = outputs / (outputs.norm(dim=1).view(B, 1) + 1e-8)
    similarity_matrix = (1. / temperature) * torch.mm(
        outputs_norm, outputs_norm.transpose(0, 1).detach())

    return similarity_matrix, outputs