def forward(cls, ctx, x, weight, bias, running_mean, running_var,
            extra, training=True, momentum=0.1, eps=1e-05,
            activation=ACT_LEAKY_RELU, slope=0.01):
    # Save context
    cls._parse_extra(ctx, extra)
    ctx.training = training
    ctx.momentum = momentum
    ctx.eps = eps
    ctx.activation = activation
    ctx.slope = slope
    ctx.affine = weight is not None and bias is not None

    # Prepare inputs
    count = _count_samples(x) * (ctx.master_queue.maxsize + 1)
    x = x.contiguous()
    weight = weight.contiguous() if ctx.affine else x.new_empty(0)
    bias = bias.contiguous() if ctx.affine else x.new_empty(0)

    if ctx.training:
        mean, var = _backend.mean_var(x)

        if ctx.is_master:
            means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)]
            for _ in range(ctx.master_queue.maxsize):
                mean_w, var_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                means.append(mean_w.unsqueeze(0))
                vars.append(var_w.unsqueeze(0))

            means = comm.gather(means)
            vars = comm.gather(vars)

            mean = means.mean(0)
            var = (vars + (mean - means) ** 2).mean(0)

            tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            ctx.master_queue.put((mean, var))
            mean, var = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        # Update running stats
        running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
        running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))

        # Mark in-place modified tensors
        ctx.mark_dirty(x, running_mean, running_var)
    else:
        mean, var = running_mean.contiguous(), running_var.contiguous()
        ctx.mark_dirty(x)

    # BN forward + activation
    _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
    _act_forward(ctx, x)

    # Output
    ctx.var = var
    ctx.save_for_backward(x, var, weight, bias)
    return x
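# Hedged sketch (not part of the function above): the per-GPU statistics merge used
# there follows the law of total variance. With equal sample counts per GPU, the
# global mean is the mean of per-GPU means and the global biased variance is
# E[var_i + (mean - mean_i)^2] over GPUs. The values below are illustrative only.
import torch

per_gpu_means = torch.tensor([[0.1, 0.2], [0.3, 0.0]])   # (num_gpus, C)
per_gpu_vars = torch.tensor([[1.0, 0.5], [0.8, 0.7]])    # biased per-GPU variances
mean = per_gpu_means.mean(0)
var = (per_gpu_vars + (mean - per_gpu_means) ** 2).mean(0)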
def forward(self, *inputs):
    if not all(input.is_cuda for input in inputs):
        raise TypeError('Broadcast function not implemented for CPU tensors')
    if len(inputs) == 0:
        return tuple()
    self.num_inputs = len(inputs)
    self.input_device = inputs[0].get_device()
    outputs = comm.broadcast_coalesced(inputs, self.target_gpus)
    return tuple([t for tensors in outputs for t in tensors])
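# Hedged sketch of what comm.broadcast_coalesced returns, assuming at least two
# CUDA devices are available; the tensor shapes and device ids are illustrative.
import torch
from torch.cuda import comm

if torch.cuda.device_count() >= 2:
    tensors = [torch.randn(4, device="cuda:0"), torch.arange(3, device="cuda:0")]
    copies = comm.broadcast_coalesced(tensors, (0, 1))
    # copies[d] holds the replicas on device d, in the same order as `tensors`
    assert copies[1][0].get_device() == 1
    # flatten to the tuple layout returned by the Broadcast function above
    flat = tuple(t for per_device in copies for t in per_device)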
def _sync_params(self):
    params = [p.data for p in self.module.parameters()]
    result = broadcast_coalesced(params, self.device_ids, self.broadcast_bucket_size)
    for tensors, module in zip(result[1:], self._module_copies[1:]):
        for tensor, param in zip(tensors, module.parameters()):
            param.data.set_(tensor)

    # cross-node buffer sync
    buffers = list(self.module._all_buffers())
    flat_buffers = _flatten_tensors(buffers)
    dist.broadcast(flat_buffers, 0)
    for buf, synced in zip(buffers, _unflatten_tensors(flat_buffers, buffers)):
        buf.copy_(synced)

    # intra-node buffer sync
    result = broadcast_coalesced(buffers, self.device_ids, self.broadcast_bucket_size)
    for tensors, module in zip(result[1:], self._module_copies[1:]):
        for tensor, buf in zip(tensors, module._all_buffers()):
            buf.set_(tensor)
def _sync_params(self):
    if len(self.device_ids) > 1:
        # intra-node parameter sync
        params = [p.data for p in self.module.parameters()]
        result = broadcast_coalesced(params, self.device_ids, self.broadcast_bucket_size)
        for tensors, module in zip(result[1:], self._module_copies[1:]):
            for tensor, param in zip(tensors, module.parameters()):
                param.data.set_(tensor)

    # module buffer sync
    if self.broadcast_buffers:
        buffers = list(self.module._all_buffers())
        if len(buffers) > 0:
            # cross-node buffer sync
            self._dist_broadcast_coalesced(buffers, self.broadcast_bucket_size)

            if len(self.device_ids) > 1:
                # intra-node buffer sync
                result = broadcast_coalesced(buffers, self.device_ids, self.broadcast_bucket_size)
                for tensors, module in zip(result[1:], self._module_copies[1:]):
                    for tensor, buf in zip(tensors, module._all_buffers()):
                        buf.set_(tensor)
def forward(ctx, target_gpus, *inputs):
    if not all(input.is_cuda for input in inputs):
        raise TypeError('Broadcast function not implemented for CPU tensors')
    ctx.target_gpus = target_gpus
    if len(inputs) == 0:
        return tuple()
    ctx.num_inputs = len(inputs)
    ctx.input_device = inputs[0].get_device()
    outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
    non_differentiables = []
    for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
        if not input_requires_grad:
            for output in outputs:
                non_differentiables.append(output[idx])
    ctx.mark_non_differentiable(*non_differentiables)
    return tuple([t for tensors in outputs for t in tensors])
def data_parallel(f, input, params, stats, mode, device_ids, output_device=None):
    assert isinstance(device_ids, list)
    if output_device is None:
        output_device = device_ids[0]

    if len(device_ids) == 1:
        return f(input, params, stats, mode)

    params_all = Broadcast.apply(device_ids, *params.values())
    params_replicas = [{k: params_all[i + j * len(params)] for i, k in enumerate(params.keys())}
                       for j in range(len(device_ids))]
    stats_replicas = [dict(zip(stats.keys(), p))
                      for p in comm.broadcast_coalesced(list(stats.values()), device_ids)]

    replicas = [partial(f, params=p, stats=s, mode=mode)
                for p, s in zip(params_replicas, stats_replicas)]
    inputs = scatter([input], device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
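# Hedged sketch of the same scatter -> parallel_apply -> gather pattern using the
# public torch.nn.parallel helpers; `run_data_parallel` is an illustrative name and
# takes an nn.Module rather than the functional (f, params, stats) convention above.
import torch
from torch.nn.parallel import replicate as replicate_module
from torch.nn.parallel import scatter, parallel_apply, gather

def run_data_parallel(module, batch, device_ids):
    if len(device_ids) == 1:
        return module(batch.to(torch.device("cuda", device_ids[0])))
    replicas = replicate_module(module, device_ids)
    inputs = scatter(batch, device_ids)                 # split the batch along dim 0
    outputs = parallel_apply(replicas[:len(inputs)], inputs)
    return gather(outputs, device_ids[0])               # concatenate on the output device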
def backward(ctx, dz):
    x, _ex, _exs, gamma, beta = ctx.saved_tensors
    dz = dz.contiguous()

    # BN backward
    if dz.is_cuda:
        dx, _dex, _dexs, dgamma, dbeta = lib.gpu.batchnorm_backward(
            dz, x, _ex, _exs, gamma, beta, ctx.eps)
    else:
        raise NotImplementedError

    if ctx.training:
        if ctx.sync:
            if ctx.is_master:
                _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
                for _ in range(ctx.master_queue.maxsize):
                    _dex_w, _dexs_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    _dex.append(_dex_w.unsqueeze(0))
                    _dexs.append(_dexs_w.unsqueeze(0))

                _dex = comm.gather(_dex).mean(0)
                _dexs = comm.gather(_dexs).mean(0)

                tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((_dex, _dexs))
                _dex, _dexs = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

        if x.is_cuda:
            dx_ = lib.gpu.expectation_backward(x, _dex, _dexs)
        else:
            raise NotImplementedError
        dx = dx + dx_

    return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None
def update_replicas(self, parallel=False):
    params = [para.data for para in filter_para_grad(self.module.parameters())]

    if len(params) > 0:
        param_copies = tuple([
            t for tensors in comm.broadcast_coalesced(params, self.device_ids)
            for t in tensors
        ])

        # currently, pytorch broadcast binds parameters between self.nets[0] and self.module,
        # so the following line ensures correctness but is less efficient
        #for module, param_copy in zip(self.nets, [param_copies[i:i + len(params)] for i in range(0, len(param_copies), len(params))]):
        for module, param_copy in zip(self.nets[1:], [
                param_copies[i:i + len(params)]
                for i in range(len(params), len(param_copies), len(params))
        ]):
            for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
                mp.data, mp.grad = para, None

    self.ngradev = 0
def forward_nonsparse(ctx, target_gpus, *inputs, context_update=True):
    if not all(input.is_cuda for input in inputs):
        raise TypeError('Broadcast function not implemented for CPU tensors')
    target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
    if context_update:
        ctx.target_gpus = target_gpus
    if len(inputs) == 0:
        return tuple()
    if context_update:
        ctx.num_inputs = len(inputs)
        ctx.input_device = inputs[0].get_device()
    outputs = comm.broadcast_coalesced(inputs, target_gpus)
    non_differentiables = []
    if context_update:
        for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
            if not input_requires_grad:
                for output in outputs:
                    non_differentiables.append(output[idx])
        ctx.mark_non_differentiable(*non_differentiables)
    return tuple([t for tensors in outputs for t in tensors])
def backward(ctx, dz):
    z, _ex, _exs, gamma, beta = ctx.saved_tensors
    dz = dz.contiguous()

    # Undo activation
    _act_backward(ctx, z, dz)

    # BN backward
    dx, _dex, _dexs, dgamma, dbeta = \
        encoding_ext.batchnorm_inp_backward(dz, z, _ex, _exs, gamma, beta, ctx.eps)

    if ctx.training:
        if ctx.sync:
            if ctx.is_master:
                _dex, _dexs = [_dex.unsqueeze(0)], [_dexs.unsqueeze(0)]
                for _ in range(ctx.master_queue.maxsize):
                    _dex_w, _dexs_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    _dex.append(_dex_w.unsqueeze(0))
                    _dexs.append(_dexs_w.unsqueeze(0))

                _dex = comm.gather(_dex).mean(0)
                _dexs = comm.gather(_dexs).mean(0)

                tensors = comm.broadcast_coalesced((_dex, _dexs), [_dex.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((_dex, _dexs))
                _dex, _dexs = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

        encoding_ext.expectation_inp_backward(dx, z, _dex, _dexs, _ex, _exs, gamma, beta, ctx.eps)

    return dx, dgamma, dbeta, None, None, None, None, None, None, None, None, None
def backward(ctx, dz):
    z, var, weight, bias = ctx.saved_tensors
    dz = dz.contiguous()

    # Undo activation
    _act_backward(ctx, z, dz)

    if ctx.training:
        edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)

        if ctx.is_master:
            edzs, eydzs = [edz], [eydz]
            for _ in range(len(ctx.worker_queues)):
                edz_w, eydz_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                edzs.append(edz_w)
                eydzs.append(eydz_w)

            edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
            eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)

            tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            ctx.master_queue.put((edz, eydz))
            edz, eydz = ctx.worker_queue.get()
            ctx.worker_queue.task_done()
    else:
        edz = dz.new_zeros(dz.size(1))
        eydz = dz.new_zeros(dz.size(1))

    dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
    dweight = dweight if ctx.affine else None
    dbias = dbias if ctx.affine else None

    return dx, dweight, dbias, None, None, None, None, None, None, None, None
def test_broadcast_coalesced(self):
    numel = 5
    num_bytes = numel * 8
    tensors = [
        torch.randn(numel).long().cuda(),
        torch.randn(numel).cuda(),
        torch.randn(numel).long().cuda(),
        torch.randn(numel).long().cuda(),
        torch.randn(numel * 2).int().cuda(),  # int is 2x shorter
        torch.randn(numel).cuda(),
    ]
    b_tensors = [comm.broadcast(t, (0, 1)) for t in tensors]
    for (_, bt), t in zip(b_tensors, tensors):
        self.assertEqual(bt.get_device(), 1)
        self.assertEqual(bt, t)
        self.assertIsInstance(bt, type(t))

    bc_tensors = comm.broadcast_coalesced(tensors, (0, 1), buffer_size=num_bytes * 5 // 2)
    bc_tensors_t = list(zip(*bc_tensors))
    self.assertEqual(b_tensors, bc_tensors_t)
    for (_, bt), (_, bct) in zip(b_tensors, bc_tensors_t):
        self.assertEqual(bt.get_device(), bct.get_device())
        self.assertIsInstance(bct, type(bt))
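# Hedged sketch of the property the test above checks: coalescing is only a
# transport optimization, so per-tensor broadcast and broadcast_coalesced must
# produce equal copies even with a tiny buffer_size (assumes >= 2 CUDA devices).
import torch
from torch.cuda import comm

if torch.cuda.device_count() >= 2:
    ts = [torch.randn(5).cuda(), torch.randn(5).long().cuda()]
    per_tensor = [comm.broadcast(t, (0, 1)) for t in ts]
    coalesced = comm.broadcast_coalesced(ts, (0, 1), buffer_size=16)
    for (_, bt), (_, bct) in zip(per_tensor, zip(*coalesced)):
        assert torch.equal(bt, bct)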
def update_replicas(self):
    if self.optm_splt is None:
        if self.is_contiguous_parameters:
            params = [para.data for para in get_all_contiguous_parameters_m(self.module)]
            param_copies = comm.broadcast_coalesced(params, self.device_ids)
            with torch.no_grad():
                for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                    for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy):
                        mp.data.copy_(para)
                        mp.grad.zero_()
        else:
            params = [para.data for para in filter_para_grad(self.module.parameters())]
            param_copies = comm.broadcast_coalesced(params, self.device_ids)
            # currently, pytorch broadcast binds parameters between self.nets[0] and self.module,
            # so the following line ensures correctness but is less efficient
            #for module, param_copy in zip(self.nets, param_copies):
            for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
                    mp.data, mp.grad = para, None
    else:
        for i, (net, (lind, rind,),) in enumerate(zip(self.nets, self.optm_splt)):
            _dev_params = [
                para.data for para in get_contiguous_parameters_m(net, index=i)
            ] if self.is_contiguous_parameters else [
                para.data for para in range_parameter_iter(net, lind, rind, func=filter_para_grad_iter)
            ]
            if i > 0:
                _devices = self.device_ids[:]
                _devices.insert(0, _devices.pop(i))
            else:
                _devices = self.device_ids
            _dev_param_copies = comm.broadcast_coalesced(_dev_params, _devices)
            if i > 0:
                _dev_param_copies.insert(i, _dev_param_copies.pop(0))
                for pc, _dpc in zip(param_copies, _dev_param_copies):
                    pc.extend(_dpc)
            else:
                param_copies = _dev_param_copies
        if self.is_contiguous_parameters:
            with torch.no_grad():
                for module, param_copy in zip(self.nets, param_copies):
                    for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy):
                        mp.data.copy_(para)
                        mp.grad.zero_()
        else:
            for module, param_copy in zip(self.nets, param_copies):
                for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
                    mp.data, mp.grad = para, None
    self.ngradev = 0
def backward(ctx, dz):
    z, weight, bias, running_mean, running_var = ctx.saved_tensors
    dz = dz.contiguous()

    # Undo activation
    _act_backward(ctx, z, dz)

    if ctx.needs_input_grad[0]:
        dx = dz.new().resize_as_(dz)
    else:
        dx = None
    if ctx.needs_input_grad[1]:
        dweight = dz.new().resize_as_(running_mean).zero_()
    else:
        dweight = None
    if ctx.needs_input_grad[2]:
        dbias = dz.new().resize_as_(running_mean).zero_()
    else:
        dbias = None

    if ctx.training:
        edz = dz.new().resize_as_(running_mean)
        eydz = dz.new().resize_as_(running_mean)
        _check_contiguous(z, dz, weight, bias, edz, eydz)
        _check(_ext.bn_edz_eydz_cuda,
               z, dz,
               weight if weight is not None else dz.new(),
               bias if bias is not None else dz.new(),
               edz, eydz, ctx.eps)

        if ctx.is_master:
            edzs, eydzs = [edz], [eydz]
            for _ in range(len(ctx.worker_queues)):
                edz_w, eydz_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                edzs.append(edz_w)
                eydzs.append(eydz_w)

            edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
            eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)

            tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            ctx.master_queue.put((edz, eydz))
            edz, eydz = ctx.worker_queue.get()
            ctx.worker_queue.task_done()
    else:
        edz = dz.new().resize_as_(running_mean).zero_()
        eydz = dz.new().resize_as_(running_mean).zero_()

    _check_contiguous(dz, z, ctx.var, weight, bias, edz, eydz, dx, dweight, dbias)
    _check(_ext.bn_backard_cuda,
           dz, z, ctx.var,
           weight if weight is not None else dz.new(),
           bias if bias is not None else dz.new(),
           edz, eydz,
           dx if dx is not None else dz.new(),
           dweight if dweight is not None else dz.new(),
           dbias if dbias is not None else dz.new(),
           ctx.eps)

    del ctx.var

    return dx, dweight, dbias, None, None, None, None, None, None, None, None
def forward(ctx, x, weight, bias, running_mean, running_var, extra,
            compute_stats=True, momentum=0.1, eps=1e-05):
    def _parse_extra(ctx, extra):
        ctx.is_master = extra["is_master"]
        if ctx.is_master:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queue = extra["worker_queue"]

    # Save context
    if extra is not None:
        _parse_extra(ctx, extra)
    ctx.compute_stats = compute_stats
    ctx.momentum = momentum
    ctx.eps = eps
    ctx.affine = weight is not None and bias is not None

    if ctx.compute_stats:
        N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
        assert N > 1
        # 1. compute sum(x) and sum(x^2)
        xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach())
        if ctx.is_master:
            xsums, xsqsums = [xsum], [xsqsum]
            # master: gather all sum(x) and sum(x^2) from slaves
            for _ in range(ctx.master_queue.maxsize):
                xsum_w, xsqsum_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                xsums.append(xsum_w)
                xsqsums.append(xsqsum_w)
            xsum = comm.reduce_add(xsums)
            xsqsum = comm.reduce_add(xsqsums)
            mean = xsum / N
            sumvar = xsqsum - xsum * mean
            var = sumvar / N
            uvar = sumvar / (N - 1)
            # master: broadcast global mean, variance to all slaves
            tensors = comm.broadcast_coalesced((mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            # slave: send sum(x) and sum(x^2) to master
            ctx.master_queue.put((xsum, xsqsum))
            # slave: get global mean and variance
            mean, uvar, var = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        # Update running stats
        running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
        running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
        ctx.N = N
        ctx.save_for_backward(x, weight, bias, mean, var)
    else:
        mean, var = running_mean, running_var

    # do batch norm forward
    z = _backend.syncbn_forward(x, weight, bias, mean, var, ctx.affine, ctx.eps)
    return z
def backward(ctx, *inputs):
    inputs = [i.data for i in inputs]
    inputs = [inputs[i:i + ctx.num_inputs]
              for i in range(0, len(inputs), ctx.num_inputs)]
    results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
    outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
    return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
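# Hedged sketch (assumes >= 2 CUDA devices): reduce_add_coalesced is the inverse
# step used in the backward above -- it sums the per-device gradient lists onto
# one destination device before they are re-broadcast.
import torch
from torch.cuda import comm

if torch.cuda.device_count() >= 2:
    grads_dev0 = [torch.ones(4, device="cuda:0"), torch.ones(2, device="cuda:0")]
    grads_dev1 = [torch.ones(4, device="cuda:1"), torch.ones(2, device="cuda:1")]
    summed = comm.reduce_add_coalesced([grads_dev0, grads_dev1], destination=0)
    # summed[i] == grads_dev0[i] + grads_dev1[i], located on cuda:0
    assert torch.equal(summed[0], torch.full((4,), 2.0, device="cuda:0"))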
def forward(cls, ctx, x, weight, bias, running_mean, running_var,
            extra, training=True, momentum=0.1, eps=1e-05,
            activation=ACT_LEAKY_RELU, slope=0.01):
    # Save context
    cls._parse_extra(ctx, extra)
    ctx.training = training
    ctx.momentum = momentum
    ctx.eps = eps
    ctx.activation = activation
    ctx.slope = slope

    n = _count_samples(x) * (ctx.master_queue.maxsize + 1)

    if ctx.training:
        mean = x.new().resize_(1, running_mean.size(0))
        var = x.new().resize_(1, running_var.size(0))
        _check_contiguous(x, mean, var)
        _check(_ext.bn_mean_var_cuda, x, mean, var)

        if ctx.is_master:
            means, vars = [mean], [var]
            for _ in range(ctx.master_queue.maxsize):
                mean_w, var_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                means.append(mean_w)
                vars.append(var_w)

            means = comm.gather(means)
            vars = comm.gather(vars)

            mean = means.mean(0)
            var = (vars + (mean - means) ** 2).mean(0)

            tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            ctx.master_queue.put((mean, var))
            mean, var = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        # Update running stats
        running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
        running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * n / (n - 1))
    else:
        mean, var = running_mean, running_var

    _check_contiguous(x, mean, var, weight, bias)
    _check(_ext.bn_forward_cuda,
           x, mean, var,
           weight if weight is not None else x.new(),
           bias if bias is not None else x.new(),
           x, x, ctx.eps)

    # Activation
    _act_forward(ctx, x)

    # Output
    ctx.var = var
    ctx.save_for_backward(x, weight, bias, running_mean, running_var)
    ctx.mark_dirty(x)
    return x
def backward(ctx, dz):
    z, weight, bias, running_mean, running_var = ctx.saved_tensors
    dz = dz.contiguous()

    # Undo activation
    _act_backward(ctx, z, dz)

    if ctx.needs_input_grad[0]:
        dx = dz.new().resize_as_(dz)
    else:
        dx = None
    if ctx.needs_input_grad[1]:
        dweight = dz.new().resize_as_(running_mean).zero_()
    else:
        dweight = None
    if ctx.needs_input_grad[2]:
        dbias = dz.new().resize_as_(running_mean).zero_()
    else:
        dbias = None

    if ctx.training:
        edz = dz.new().resize_as_(running_mean)
        eydz = dz.new().resize_as_(running_mean)
        _check_contiguous(z, dz, weight, bias, edz, eydz)
        _check(_ext.bn_edz_eydz_cuda,
               z, dz,
               weight if weight is not None else dz.new(),
               bias if bias is not None else dz.new(),
               edz, eydz, ctx.eps)

        if ctx.is_master:
            edzs, eydzs = [edz], [eydz]
            for _ in range(len(ctx.worker_queues)):
                edz_w, eydz_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                edzs.append(edz_w)
                eydzs.append(eydz_w)

            edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
            eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)

            tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            ctx.master_queue.put((edz, eydz))
            edz, eydz = ctx.worker_queue.get()
            ctx.worker_queue.task_done()
    else:
        edz = dz.new().resize_as_(running_mean).zero_()
        eydz = dz.new().resize_as_(running_mean).zero_()

    _check_contiguous(dz, z, ctx.var, weight, bias, edz, eydz, dx, dweight, dbias)
    _check(_ext.bn_backard_cuda,
           dz, z, ctx.var,
           weight if weight is not None else dz.new(),
           bias if bias is not None else dz.new(),
           edz, eydz,
           dx if dx is not None else dz.new(),
           dweight if dweight is not None else dz.new(),
           dbias if dbias is not None else dz.new(),
           ctx.eps)

    del ctx.var

    return dx, dweight, dbias, None, None, None, None, None, None, None, None
def forward(cls, ctx, x, weight, bias, running_mean, running_var,
            extra, training=True, momentum=0.1, eps=1e-05,
            activation=ACT_LEAKY_RELU, slope=0.01):
    # Save context
    cls._parse_extra(ctx, extra)
    ctx.training = training
    ctx.momentum = momentum
    ctx.eps = eps
    ctx.activation = activation
    ctx.slope = slope

    n = _count_samples(x) * (ctx.master_queue.maxsize + 1)

    if ctx.training:
        mean = x.new().resize_(1, running_mean.size(0))
        var = x.new().resize_(1, running_var.size(0))
        _check_contiguous(x, mean, var)
        _check(_ext.bn_mean_var_cuda, x, mean, var)

        if ctx.is_master:
            means, vars = [mean], [var]
            for _ in range(ctx.master_queue.maxsize):
                mean_w, var_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                means.append(mean_w)
                vars.append(var_w)

            means = comm.gather(means)
            vars = comm.gather(vars)

            mean = means.mean(0)
            var = (vars + (mean - means) ** 2).mean(0)

            tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            ctx.master_queue.put((mean, var))
            mean, var = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        # Update running stats
        running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
        running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * n / (n - 1))
    else:
        mean, var = running_mean, running_var

    _check_contiguous(x, mean, var, weight, bias)
    _check(_ext.bn_forward_cuda,
           x, mean, var,
           weight if weight is not None else x.new(),
           bias if bias is not None else x.new(),
           x, x, ctx.eps)

    # Activation
    _act_forward(ctx, x)

    # Output
    ctx.var = var
    ctx.save_for_backward(x, weight, bias, running_mean, running_var)
    ctx.mark_dirty(x)
    return x
def replicate(network, devices, no_gradient=False):
    def clear_gradient(para):
        para.grad = None
        return para

    num_replicas = len(devices)

    params = [clear_gradient(para) for para in network.parameters()] \
        if no_gradient else list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = comm.broadcast_coalesced(params, devices)

    buffers = list(network.buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules", "forward", "_c"}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if isinstance(module, ScriptModule):
                # we have to initialize ScriptModule properly so that
                # it works with pybind11
                replica = ScriptModule()

                attribute_names = set(entry[0] for entry in module._c._get_attributes())

                keys = set(module.__dict__.keys()) - scriptmodule_skip_attr - attribute_names
                for key in keys:
                    if not isinstance(module.__dict__[key], ScriptMethod):
                        replica.__dict__[key] = module.__dict__[key]
                for name, the_type, value in module._c._get_attributes():
                    if name not in module._buffers.keys():
                        replica._c._register_attribute(name, the_type, value)
            else:
                replica = module.__new__(type(module))
                replica.__dict__ = module.__dict__.copy()
                replica._parameters = replica._parameters.copy()
                replica._buffers = replica._buffers.copy()
                replica._modules = replica._modules.copy()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx].requires_grad_(param.requires_grad)
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    for j in range(num_replicas):
        for i, module in enumerate(modules):
            if isinstance(module, ScriptModule):
                replica = module_copies[j][i]
                for method_name in module._c._method_names():
                    replica._c.clone_method(module._c, method_name)

    return [module_copies[j][0] for j in range(num_replicas)]
def replicate(network, devices, detach=False):
    from ._functions import Broadcast

    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    devices = list(map(lambda x: _get_device_index(x, True), devices))
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = Broadcast.apply(devices, *params)
    if len(params) > 0:
        param_copies = [param_copies[i:i + len(params)]
                        for i in range(0, len(param_copies), len(params))]

    buffers = list(network.buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules"}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if _is_script_module(module):
                # we have to initialize ScriptModule properly so that
                # it works with pybind11
                replica = _init_script_module()
                keys = set(module.__dict__.keys()) - scriptmodule_skip_attr
                for key in keys:
                    replica.__dict__[key] = module.__dict__[key]
            else:
                replica = module.__new__(type(module))
                replica.__dict__ = module.__dict__.copy()
                replica._parameters = replica._parameters.copy()
                replica._buffers = replica._buffers.copy()
                replica._modules = replica._modules.copy()

            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx].detach() \
                        if detach else param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    for j in range(num_replicas):
        _copy_scriptmodule_methods(modules, module_copies[j], module_indices)

    return [module_copies[j][0] for j in range(num_replicas)]
def replicate_model(network, devices, no_gradient=False):
    def clear_gradient(para):
        if no_gradient:
            para.grad = None
        return para

    num_replicas = len(devices)

    params = [clear_gradient(para) for para in network.parameters()] \
        if no_gradient else list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = tuple([
        t for tensors in comm.broadcast_coalesced(params, devices) for t in tensors
    ])
    if len(params) > 0:
        param_copies = [param_copies[i:i + len(params)]
                        for i in range(0, len(param_copies), len(params))]

    buffers = list(network.buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module.__new__(type(module))
            replica.__dict__ = module.__dict__.copy()
            replica._parameters = replica._parameters.copy()
            replica._buffers = replica._buffers.copy()
            replica._modules = replica._modules.copy()
            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx].requires_grad_(param.requires_grad)
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    return [module_copies[j][0] for j in range(num_replicas)]
def replicate(network, devices):
    from ._functions import Broadcast

    devices = tuple(devices)
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = Broadcast(devices)(*params)
    if len(params) > 0:
        param_copies = [param_copies[i:i + len(params)]
                        for i in range(0, len(param_copies), len(params))]

    buffers = list(network._all_buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module.__new__(type(module))
            replica.__dict__ = module.__dict__.copy()
            replica._parameters = replica._parameters.copy()
            replica._buffers = replica._buffers.copy()
            replica._modules = replica._modules.copy()
            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    return [module_copies[j][0] for j in range(num_replicas)]
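# Hedged usage sketch (assumes >= 2 CUDA devices): the public
# torch.nn.parallel.replicate wraps the same broadcast-then-rewire logic shown in
# the replicate variants above.
import torch
import torch.nn as nn
from torch.nn.parallel import replicate as replicate_module

if torch.cuda.device_count() >= 2:
    net = nn.Linear(8, 4).cuda(0)
    replicas = replicate_module(net, [0, 1])
    # replicas[1] lives on cuda:1; its parameters are broadcast copies of net's
    y = replicas[1](torch.randn(2, 8, device="cuda:1"))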
def forward(ctx, x, gamma, beta, running_mean, running_var, extra,
            training=True, momentum=0.1, eps=1e-5, sync=True):
    def parse_extra(ctx, extra):
        ctx.is_master = extra["is_master"]
        if ctx.is_master:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queue = extra["worker_queue"]

    parse_extra(ctx, extra)
    ctx.training = training
    ctx.momentum = momentum
    ctx.eps = eps
    ctx.sync = sync

    if ctx.training:
        ex, exs = syncbn_gpu.batch_norm_collect_statistics(x)

        if ctx.sync:
            if ctx.is_master:
                ex, exs = [ex.unsqueeze(0)], [exs.unsqueeze(0)]
                for _ in range(ctx.master_queue.maxsize):
                    ex_w, exs_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    ex.append(ex_w.unsqueeze(0))
                    exs.append(exs_w.unsqueeze(0))

                ex = comm.gather(ex).mean(0)
                exs = comm.gather(exs).mean(0)

                tensors = comm.broadcast_coalesced((ex, exs), [ex.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((ex, exs))
                ex, exs = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

        var = exs - ex ** 2
        running_mean.mul_(1 - ctx.momentum).add_(ctx.momentum * ex)
        running_var.mul_(1 - ctx.momentum).add_(ctx.momentum * var)
        ctx.mark_dirty(running_mean, running_var)
    else:
        # eval mode: reconstruct E[x] and E[x^2] from the running statistics
        # (the original snippet left ex/exs undefined on this path)
        ex = running_mean
        exs = running_var + running_mean ** 2

    y = syncbn_gpu.batch_norm_transform_input(x, gamma, beta, ex, exs, ctx.eps)

    ctx.save_for_backward(x, ex, exs, gamma, beta)
    return y
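# Hedged sketch of how the `extra` dict consumed by parse_extra above is usually
# assembled by the owning sync-BN module; the class and method names here are
# illustrative, not the API of any specific library.
from queue import Queue

class SyncContext:
    def __init__(self, device_ids):
        self.device_ids = list(device_ids)
        # the master collects one message from every worker per forward/backward
        self.master_queue = Queue(maxsize=len(self.device_ids) - 1)
        self.worker_queues = [Queue(maxsize=1) for _ in self.device_ids[1:]]

    def extra_for(self, device_id):
        if device_id == self.device_ids[0]:
            return {"is_master": True,
                    "master_queue": self.master_queue,
                    "worker_queues": self.worker_queues,
                    "worker_ids": self.device_ids[1:]}
        idx = self.device_ids.index(device_id) - 1
        return {"is_master": False,
                "master_queue": self.master_queue,
                "worker_queue": self.worker_queues[idx]}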
def replicate(network, devices, detach=False):
    from ._functions import Broadcast

    devices = tuple(devices)
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = Broadcast.apply(devices, *params)
    if len(params) > 0:
        param_copies = [param_copies[i:i + len(params)]
                        for i in range(0, len(param_copies), len(params))]

    buffers = list(network.buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module.__new__(type(module))
            replica.__dict__ = module.__dict__.copy()
            replica._parameters = replica._parameters.copy()
            replica._buffers = replica._buffers.copy()
            replica._modules = replica._modules.copy()
            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx].detach() \
                        if detach else param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    return [module_copies[j][0] for j in range(num_replicas)]
def forward(cls, ctx, x, weight, bias, running_mean, running_var,
            extra, training=True, momentum=0.1, eps=1e-05,
            activation=ACT_LEAKY_RELU, slope=0.01):
    # Save context
    cls._parse_extra(ctx, extra)
    ctx.training = training
    ctx.momentum = momentum
    ctx.eps = eps
    ctx.activation = activation
    ctx.slope = slope
    ctx.affine = weight is not None and bias is not None

    # Prepare inputs
    count = _count_samples(x) * (ctx.master_queue.maxsize + 1)
    x = x.contiguous()
    weight = weight.contiguous() if ctx.affine else x.new_empty(0)
    bias = bias.contiguous() if ctx.affine else x.new_empty(0)

    if ctx.training:
        mean, var = _backend.mean_var(x)

        if ctx.is_master:
            means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)]
            for _ in range(ctx.master_queue.maxsize):
                mean_w, var_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                means.append(mean_w.unsqueeze(0))
                vars.append(var_w.unsqueeze(0))

            means = comm.gather(means)
            vars = comm.gather(vars)

            mean = means.mean(0)
            var = (vars + (mean - means) ** 2).mean(0)

            tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            ctx.master_queue.put((mean, var))
            mean, var = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        # Update running stats
        running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
        running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))

        # Mark in-place modified tensors
        ctx.mark_dirty(x, running_mean, running_var)
    else:
        mean, var = running_mean.contiguous(), running_var.contiguous()
        ctx.mark_dirty(x)

    # BN forward + activation
    _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
    _act_forward(ctx, x)

    # Output
    ctx.var = var
    ctx.save_for_backward(x, var, weight, bias)
    return x
def update_replicas_para(self):
    if self.optm_splt is None:
        if self.is_contiguous_parameters:
            params = [para.data for para in get_all_contiguous_parameters_m(self.module)]
            param_copies = comm.broadcast_coalesced(params, self.device_ids)
            with torch.no_grad():
                for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                    for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy):
                        mp.data.copy_(para)
        else:
            params = [para.data for para in filter_para_grad(self.module.parameters())]
            param_copies = comm.broadcast_coalesced(params, self.device_ids)
            for module, param_copy in zip(self.nets[1:], param_copies[1:]):
                for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
                    mp.data = para
    else:
        for i, (net, (lind, rind,),) in enumerate(zip(self.nets, self.optm_splt)):
            _dev_params = [
                para.data for para in get_contiguous_parameters_m(net, index=i)
            ] if self.is_contiguous_parameters else [
                para.data for para in range_parameter_iter(net, lind, rind, func=filter_para_grad_iter)
            ]
            if i > 0:
                _devices = self.device_ids[:]
                _devices.insert(0, _devices.pop(i))
            else:
                _devices = self.device_ids
            _dev_param_copies = comm.broadcast_coalesced(_dev_params, _devices)
            if i > 0:
                _dev_param_copies.insert(i, _dev_param_copies.pop(0))
                for pc, _dpc in zip(param_copies, _dev_param_copies):
                    pc.extend(_dpc)
            else:
                param_copies = _dev_param_copies
        if self.is_contiguous_parameters:
            with torch.no_grad():
                for module, param_copy in zip(self.nets, param_copies):
                    for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy):
                        mp.data.copy_(para)
        else:
            for module, param_copy in zip(self.nets, param_copies):
                for mp, para in zip(filter_para_grad(module.parameters()), param_copy):
                    mp.data = para
    self.ngradev = 0
def replicate(network, devices):
    devices = tuple(devices)
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = broadcast_coalesced(params, devices)

    buffers = list(network._all_buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = broadcast_coalesced(buffers, devices)

    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module.__new__(type(module))
            replica.__dict__ = module.__dict__.copy()
            replica._parameters = replica._parameters.copy()
            replica._buffers = replica._buffers.copy()
            replica._modules = replica._modules.copy()
            module_copies[j].append(replica)

    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx]
                    replica._parameters[key].requires_grad = param.requires_grad
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                buffer_idx = buffer_indices[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    return [module_copies[j][0] for j in range(num_replicas)]
def forward(cls, ctx, x, gamma, beta, running_mean, running_var, extra,
            sync=True, training=True, momentum=0.1, eps=1e-05,
            activation="none", slope=0.01):
    # save context
    cls._parse_extra(ctx, extra)
    ctx.sync = sync
    ctx.training = training
    ctx.momentum = momentum
    ctx.eps = eps
    ctx.activation = activation
    ctx.slope = slope
    assert activation == 'none'

    # contiguous inputs
    x = x.contiguous()
    gamma = gamma.contiguous()
    beta = beta.contiguous()

    if ctx.training:
        _ex, _exs = _C.expectation_forward(x)

        if ctx.sync:
            if ctx.is_master:
                _ex, _exs = [_ex.unsqueeze(0)], [_exs.unsqueeze(0)]
                for _ in range(ctx.master_queue.maxsize):
                    _ex_w, _exs_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    _ex.append(_ex_w.unsqueeze(0))
                    _exs.append(_exs_w.unsqueeze(0))

                _ex = comm.gather(_ex).mean(0)
                _exs = comm.gather(_exs).mean(0)

                tensors = comm.broadcast_coalesced((_ex, _exs), [_ex.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                ctx.master_queue.put((_ex, _exs))
                _ex, _exs = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

        # Update running stats
        _var = _exs - _ex ** 2
        running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * _ex)
        running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * _var)

        # Mark in-place modified tensors
        ctx.mark_dirty(running_mean, running_var)
    else:
        _ex, _var = running_mean.contiguous(), running_var.contiguous()
        _exs = _var + _ex ** 2

    # BN forward
    y = _C.batchnorm_forward(x, _ex, _exs, gamma, beta, ctx.eps)

    # Output
    ctx.save_for_backward(x, _ex, _exs, gamma, beta)
    return y
def forward(cls, ctx, x, weight, bias, running_mean, running_var, extra,
            compute_stats=True, momentum=0.1, eps=1e-05):
    # Save context
    if extra is not None:
        cls._parse_extra(ctx, extra)
    ctx.compute_stats = compute_stats
    ctx.momentum = momentum
    ctx.eps = eps

    if ctx.compute_stats:
        N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
        assert N > 1
        num_features = running_mean.size(0)
        # 1. compute sum(x) and sum(x^2)
        xsum = x.new().resize_(num_features)
        xsqsum = x.new().resize_(num_features)
        _check_contiguous(x, xsum, xsqsum)
        _lib_bn.syncbn_sum_sqsum_cuda(x.detach(), xsum, xsqsum)
        if ctx.is_master:
            xsums, xsqsums = [xsum], [xsqsum]
            # master: gather all sum(x) and sum(x^2) from slaves
            for _ in range(ctx.master_queue.maxsize):
                xsum_w, xsqsum_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                xsums.append(xsum_w)
                xsqsums.append(xsqsum_w)
            xsum = comm.reduce_add(xsums)
            xsqsum = comm.reduce_add(xsqsums)
            mean = xsum / N
            sumvar = xsqsum - xsum * mean
            var = sumvar / N
            uvar = sumvar / (N - 1)
            # master: broadcast global mean, variance to all slaves
            tensors = comm.broadcast_coalesced((mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            # slave: send sum(x) and sum(x^2) to master
            ctx.master_queue.put((xsum, xsqsum))
            # slave: get global mean and variance
            mean, uvar, var = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        # Update running stats
        running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
        running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
        ctx.N = N
        ctx.save_for_backward(x, weight, bias, mean, var)
    else:
        mean, var = running_mean, running_var

    output = x.new().resize_as_(x)
    _check_contiguous(output, x, mean, var, weight, bias)
    # do batch norm forward
    _lib_bn.syncbn_forward_cuda(output, x,
                                weight if weight is not None else x.new(),
                                bias if bias is not None else x.new(),
                                mean, var, ctx.eps)
    return output
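# Hedged sketch of the sum / sum-of-squares arithmetic used above, on plain CPU
# tensors: mean = xsum / N, sum of squared deviations = xsqsum - xsum * mean,
# the biased variance divides by N and the unbiased running estimate by N - 1.
import torch

x = torch.randn(4, 3, 8, 8)                      # (N, C, H, W); stats are per channel
N = x.numel() // x.size(1)
x_flat = x.transpose(0, 1).reshape(x.size(1), -1)
xsum = x_flat.sum(1)
xsqsum = (x_flat * x_flat).sum(1)
mean = xsum / N
sumvar = xsqsum - xsum * mean                    # = sum((x - mean)^2) per channel
var, uvar = sumvar / N, sumvar / (N - 1)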