def _test_scatter(self, tensor):
    """Scatter *tensor* across devices 0 and 1; verify values, placement, and gradients."""
    leaf = tensor.detach().requires_grad_()
    chunks = dp.scatter(leaf, (0, 1))
    # One chunk per target device, split along dim 0.
    self.assertEqual(len(chunks), 2)
    for dev, (chunk, expected) in enumerate(zip(chunks, (leaf[:2], leaf[2:]))):
        self.assertEqual(chunk, expected)
        self.assertEqual(chunk.get_device(), dev)
    # Backprop through the first chunk only: its slice of the leaf receives
    # the upstream gradient, the remainder stays zero.
    upstream = chunks[0].detach().clone().fill_(2)
    chunks[0].backward(upstream)
    self.assertEqual(leaf.grad[:2], upstream)
    self.assertEqual(leaf.grad[2:], upstream.clone().zero_())
    _assertGradAndGradgradChecks(self, lambda y: dp.scatter(y, (0, 1)), (leaf,))
def data_parallel(f, input, params, stats, mode, device_ids, output_device=None):
    """Run ``f`` with *input* split across *device_ids*; gather on *output_device*.

    ``params`` and ``stats`` are name->tensor dicts that are broadcast to every
    device before the parallel apply.
    """
    assert isinstance(device_ids, list)
    if output_device is None:
        output_device = device_ids[0]
    n_devices = len(device_ids)
    # Single device: skip the broadcast/scatter machinery entirely.
    if n_devices == 1:
        return f(input, params, stats, mode)
    # Broadcast.apply returns a flat tuple laid out device-major:
    # all params for device 0, then all params for device 1, ...
    flat_params = Broadcast.apply(device_ids, *params.values())
    n_params = len(params)
    key_list = list(params.keys())
    params_replicas = []
    for dev in range(n_devices):
        offset = dev * n_params
        params_replicas.append({k: flat_params[offset + idx] for idx, k in enumerate(key_list)})
    stats_replicas = [dict(zip(stats.keys(), copies))
                      for copies in comm.broadcast_coalesced(list(stats.values()), device_ids)]
    workers = [partial(f, params=p, stats=s, mode=mode)
               for p, s in zip(params_replicas, stats_replicas)]
    scattered = scatter([input], device_ids)
    return gather(parallel_apply(workers, scattered), output_device)
def data_parallel(f, input, params, stats, mode, device_ids, output_device=None):
    """Data-parallel evaluation of ``f`` across *device_ids*.

    The *params* and *stats* dicts are broadcast to every device, *input* is
    scattered, the per-device replicas run in parallel, and the outputs are
    gathered on *output_device* (defaults to the first device).
    """
    if output_device is None:
        output_device = device_ids[0]
    if len(device_ids) == 1:
        return f(input, params, stats, mode)

    def replicate(param_dict, g):
        # One dict per device; g(v) yields one copy of v per device.
        replicas = [{} for d in device_ids]
        # BUG FIX: dict.iteritems() exists only on Python 2 and raises
        # AttributeError on Python 3 -- items() works on both.
        for k, v in param_dict.items():
            for i, u in enumerate(g(v)):
                replicas[i][k] = u
        return replicas

    params_replicas = replicate(params, lambda x: Broadcast(device_ids)(x))
    stats_replicas = replicate(stats, lambda x: comm.broadcast(x, device_ids))
    # Default-argument binding (p=p, s=s) freezes each replica's tensors,
    # avoiding the late-binding closure pitfall.
    replicas = [
        lambda x, p=p, s=s, mode=mode: f(x, p, s, mode)
        for i, (p, s) in enumerate(zip(params_replicas, stats_replicas))
    ]
    inputs = scatter(input, device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
def esti_variance_step(self):
    """Estimate the average per-batch variance over all estimation batches.

    For every batch and every reference, the batch is scattered across the
    configured devices, one worker thread per device pushes its result onto
    ``self.queue``, and the mean of those results is recorded as that batch's
    variance.  Returns the mean over all (batch, reference) pairs as a scalar
    ``torch.FloatTensor`` value.
    """
    source_data, target_data, src_lens, tgt_lens = self.corpus.get_esti_batches()
    whole_batchs_variance_list = []
    for source, target, src_len, tgt_len in zip(source_data, target_data, src_lens, tgt_lens):
        # i indexes the multiple references; target/tgt_len are per-reference.
        for i in range(0, self.corpus.num_of_multi_refs):
            # Split this (source, reference) pair along dim 0 across devices.
            all_inputs = scatter(inputs=(source, target[i], src_len, tgt_len[i]),
                                 target_gpus=self.device_idxs, dim=0)
            num_of_device = len(all_inputs)
            # NOTE: the generator's `i` shadows the outer reference index, but
            # generator expressions have their own scope in Python 3, so the
            # outer `i` is unaffected.  Here `i` is the device index.
            args = list((self.replicas[i], all_inputs[i]) for i in range(0, num_of_device))
            all_threads = list(Thread(target=_esti_variance_worker,
                                      args=list(args[i]) + [self.queue])
                               for i in range(0, num_of_device))
            for t in all_threads:
                t.start()
            for t in all_threads:
                t.join()
            # One result per device; ordering on the queue is arbitrary, which
            # is fine because only the mean is used.
            all_results = list(self.queue.get() for _ in range(0, num_of_device))
            cur_batch_variance = numpy.mean(all_results)
            whole_batchs_variance_list.append(cur_batch_variance)
    total_variance = torch.FloatTensor(whole_batchs_variance_list)
    total_average_variance = total_variance.mean()
    return total_average_variance
def train_step(self, batch):
    """Run one data-parallel training step on *batch* and update report stats.

    Assumes batch is a 4-tuple of tensors (source ids, target ids,
    source lengths, target lengths) -- TODO confirm against the caller.
    Results are accumulated into the per-report-slot arrays; workers push
    (accuracy_numerator, loss) pairs onto ``self.queue``.
    """
    # Report statistics are kept in ring-buffer-style arrays indexed by the
    # step position within the current reporting window.
    report_idx = self.processed_steps % self.report_every_steps
    self.src_num_pad_tokens[report_idx] = int(batch[0].numel())
    self.tgt_num_pad_tokens[report_idx] = int(batch[1].numel())
    self.src_tokens[report_idx] = int(batch[2].sum())
    self.tgt_tokens[report_idx] = int(batch[3].sum())
    self.num_examples[report_idx] = int(batch[0].size(0))
    # Split the batch along dim 0 across the configured devices.
    all_inputs = scatter(inputs=batch, target_gpus=self.device_idxs, dim=0)
    num_of_device = len(all_inputs)
    # Per-example loss scaling factor passed to every worker.
    factor = 1.0 / self.num_examples[report_idx]
    args = list((self.replicas[i], all_inputs[i], factor, self.queue)
                for i in range(0, num_of_device))
    all_threads = list(Thread(target=_train_worker, args=list(args[i]))
                       for i in range(0, num_of_device))
    for t in all_threads:
        t.start()
    for t in all_threads:
        t.join()
    # One (acc, loss) result per device; queue order is irrelevant since
    # only sums are taken.
    all_results = list(self.queue.get() for _ in range(0, num_of_device))
    self.acc_report[report_idx] += sum(x[0] for x in all_results) / self.tgt_tokens[report_idx]
    self.loss_report[report_idx] += sum(x[1] for x in all_results) / self.num_examples[report_idx]
    self.update_decay_steps[report_idx] += 1
    return
def data_parallel(f, input, params, stats, mode, device_ids, output_device=None):
    """Evaluate ``f`` data-parallel over *device_ids*.

    Broadcasts *params* and *stats* to every device, scatters *input*,
    applies the per-device replicas in parallel, and gathers the outputs
    on *output_device* (defaults to the first device).
    """
    if output_device is None:
        output_device = device_ids[0]
    if len(device_ids) == 1:
        # only 1 device
        return f(input, params, stats, mode)

    # helper local to data_parallel
    def replicate(param_dict, g):
        # replicas: list of n_devices dicts
        replicas = [{} for d in device_ids]
        # BUG FIX: dict.iteritems() is Python 2 only and raises
        # AttributeError on Python 3 -- items() works everywhere.
        for k, v in param_dict.items():  # v is a parameter tensor
            for i, u in enumerate(g(v)):
                replicas[i][k] = u
        return replicas

    # broadcast parameters
    params_replicas = replicate(params, lambda x: Broadcast(device_ids)(x))
    # broadcast stats
    stats_replicas = replicate(stats, lambda x: comm.broadcast(x, device_ids))

    # Default-argument binding (p=p, s=s) pins each replica's dicts,
    # avoiding the late-binding closure pitfall.
    replicas = [lambda x, p=p, s=s, mode=mode: f(x, p, s, mode)
                for i, (p, s) in enumerate(zip(params_replicas, stats_replicas))]
    inputs = scatter(input, device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
def forward(self, *inputs, **kwargs):
    """Data-parallel forward over ``self.network``; move outputs back to the first device.

    Inputs and kwargs are scattered along dim 0, the network is replicated
    onto as many devices as there are input chunks, and the per-device
    outputs are transposed so ``res[i]`` collects the i-th output of every
    replica.
    """
    inputs = scatter(inputs, self.device_ids, dim=0)
    kwargs = scatter(kwargs, self.device_ids, dim=0)
    # Replicate only onto as many devices as we actually have input chunks.
    replicas = replicate(self.network, self.device_ids[:len(inputs)])
    outputs = parallel_apply(replicas, inputs, kwargs)
    # Transpose: group the k-th output of every replica together.
    outputs = list(zip(*outputs))
    res = []
    for i in range(len(outputs)):
        buf = []
        for j in range(len(outputs[i])):
            if isinstance(outputs[i][j], int):
                # NOTE(review): only negative ints are kept; non-negative
                # ints are silently dropped from the result.  Presumably
                # negative ints are sentinel/error codes -- confirm intent.
                if outputs[i][j]<0:
                    buf.append(outputs[i][j])
            else:
                # Tensors are moved to the primary device.
                buf.append(outputs[i][j].to(self.device_ids[0]))
        res.append(buf)
    return res
def _test_scatter(self, x):
    """Scatter *x* onto GPUs 0 and 1 and check chunk values and placement."""
    if not TEST_MULTIGPU:
        raise unittest.SkipTest("Only one GPU detected")
    wrapped = Variable(x)
    chunks = dp.scatter(wrapped, (0, 1))
    # Expect one chunk per device, split along dim 0.
    self.assertEqual(len(chunks), 2)
    for device, (chunk, want) in enumerate(zip(chunks, (wrapped[:2], wrapped[2:]))):
        self.assertEqual(chunk, want)
        self.assertEqual(chunk.get_device(), device)
def replicate_module(self, module: torch.nn.Module, devices: List[int]) -> List[torch.nn.Module]:
    """Replicate *module* onto *devices*, re-attaching the custom (pointer) parameters.

    Masked parameters are scattered (each device gets its own slice) while
    non-masked ones are broadcast like ordinary replication; the resulting
    tensors are then written back into the corresponding submodules of each
    replica.  Returns the list of replicas, one per device.
    """
    # Mask samples must divide evenly across devices.
    assert self.n_mask_samples % len(devices) == 0
    copies = replicate(module, devices)

    def walk(module: torch.nn.Module, copy: torch.nn.Module):
        # Map id(original submodule) -> corresponding submodule in the copy,
        # recursing over the whole module tree.
        module_map = {id(module): copy}
        for name, ref in module._modules.items():
            module_map.update(walk(ref, getattr(copy, name)))
        return module_map

    devices = [_get_device_index(d) for d in devices]
    # Snapshot the custom parameters managed via pointer indirection.
    all_params = [p.get() for p in self.pointer_values]
    if (not self.masking_enabled) or (not self.training):
        # No masking: plain broadcast of every parameter to every device.
        scattered = _broadcast_coalesced_reshape(all_params, devices)
    else:
        # Mixed case: masked parameters are scattered per device, while the
        # remaining (non-masked) parameters are broadcast the usual way.
        masked_indices = [
            i for i, n in enumerate(self.param_names) if self.is_masked(n)
        ]
        simple_indices = [
            i for i, n in enumerate(self.param_names) if not self.is_masked(n)
        ]
        masked_params = scatter([all_params[i] for i in masked_indices], devices)
        simple_params = _broadcast_coalesced_reshape(
            [all_params[i] for i in simple_indices], devices)
        # Re-interleave both groups back into original parameter order,
        # one full list per device.
        scattered = [[None] * len(all_params) for _ in devices]
        for d in range(len(devices)):
            for mi, mp in zip(masked_indices, masked_params[d]):
                scattered[d][mi] = mp
            for si, sp in zip(simple_indices, simple_params[d]):
                scattered[d][si] = sp
    for i, c in enumerate(copies):
        # Locate each copy's submodule matching the original parameter's
        # parent, and install the per-device tensor under the same name.
        device_map = walk(module, c)
        for j, p in enumerate(self.pointer_values):
            setattr(device_map[id(p.parent)], p.name, scattered[i][j])
        # Presumably flattens RNN weights for cuDNN -- confirm in its definition.
        self.update_rnn_params(c)
    return copies
def data_parallel(f, input, params, mode, device_ids, output_device=None):
    """Run ``f`` over *input* split across *device_ids*; gather on *output_device*."""
    assert isinstance(device_ids, list)
    if output_device is None:
        output_device = device_ids[0]
    num_devices = len(device_ids)
    # Fast path: nothing to replicate for a single device.
    if num_devices == 1:
        return f(input, params, mode)
    # Flat tuple of broadcast parameters, laid out device-major.
    broadcast_flat = Broadcast.apply(device_ids, *params.values())
    key_list = list(params.keys())
    per_device_params = []
    for dev in range(num_devices):
        offset = dev * len(params)
        per_device_params.append(
            {k: broadcast_flat[offset + idx] for idx, k in enumerate(key_list)})
    workers = [partial(f, params=p, mode=mode) for p in per_device_params]
    scattered = scatter([input], device_ids)
    return gather(parallel_apply(workers, scattered), output_device)
def data_parallel(f, input, params, mode, device_ids, output_device=None):
    """Data-parallel apply of ``f``; accepts any iterable of device ids."""
    # Normalize so indexing and len() work on any iterable of ids.
    device_ids = list(device_ids)
    if output_device is None:
        output_device = device_ids[0]
    if len(device_ids) == 1:
        # Single device: call through directly.
        return f(input, params, mode)
    # Broadcast returns a flat tuple: device 0's params first, then device 1's, ...
    flat = Broadcast.apply(device_ids, *params.values())
    stride = len(params)
    names = list(params.keys())
    replica_params = [
        {name: flat[pos + dev * stride] for pos, name in enumerate(names)}
        for dev in range(len(device_ids))
    ]
    bound = [partial(f, params=rp, mode=mode) for rp in replica_params]
    chunks = scatter([input], device_ids)
    results = parallel_apply(bound, chunks)
    return gather(results, output_device)
def scatter(self, inputs, kwargs, device_ids):
    """Scatter molecule-style inputs across *device_ids* with atom-aligned chunks.

    Assumes inputs is (gs_charge, atom_type, pos, num_atoms) -- TODO confirm
    against the caller.  ``num_atoms`` is split evenly by count, then the
    other three sequences are split by the total atom count of each chunk so
    every device receives complete molecules.  Returns (inputs, kwargs) as
    tuples of per-device pieces, padded to equal length.
    """
    from torch.nn.parallel import scatter

    def chunk_it(seq, num, devices):
        # num may be a device count (even float-sized split) or an explicit
        # list/tuple of per-chunk sizes.
        assert isinstance(num, (int, list, tuple))
        if isinstance(num, int):
            # Float sizes + int(last) truncation distribute any remainder
            # across chunks without losing elements.
            chunk_sizes = [
                len(seq) / float(num),
            ] * num
        else:
            # NOTE: map() is a one-shot iterator in Python 3; it is consumed
            # exactly once by the zip below.
            chunk_sizes = map(int, num)
        out = []
        last = 0.0
        for size, device in zip(chunk_sizes, devices):
            out.append(
                torch.tensor(seq[int(last):int(last + size)], device=device))
            last += size
        return out

    devices = [torch.device('cuda:' + str(i)) for i in device_ids]
    # Split the per-example atom counts evenly across devices first.
    nums_atoms_ckd = chunk_it(inputs[3], len(device_ids), devices)
    # Total atoms per device chunk drive the split of the flat atom arrays.
    chunk_sizes = [sum(num_atoms) for num_atoms in nums_atoms_ckd]
    gs_charge_ckd, atom_type_ckd, pos_ckd = [
        chunk_it(i, chunk_sizes, devices) for i in inputs[:3]
    ]
    inputs = list(
        zip(gs_charge_ckd, atom_type_ckd, pos_ckd, nums_atoms_ckd))
    kwargs = scatter(kwargs, device_ids, self.dim) if kwargs else []
    # Pad the shorter of the two so parallel_apply sees matched lengths.
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
def forward(self, x, label, **kwargs):
    """Classification forward with the weight matrix sharded across GPUs.

    CPU path: a plain fully-connected classify + NLL loss.  GPU path: each
    GPU holds one shard of the class weights; partial softmax statistics are
    exchanged so the loss matches a full softmax over all classes.  Returns
    (loss, accuracy).
    """
    if self.gpus is None:
        # cpu mode, normal fc layer
        x = classify(x, self.weight, label, simple_output=True, **kwargs)
        with torch.no_grad():
            acc = accuracy(x, label)
        x = F.log_softmax(x, dim=1)
        label = label.unsqueeze(-1)
        # NLL of the true class: gather the log-prob at each label index.
        loss = torch.gather(x, 1, label)
        loss = -loss.mean()
        return loss, acc
    else:
        # One weight shard per GPU; features are copied to every GPU.
        weight_scattered = (w.to(i) for w, i in zip(self.weights, self.gpus))
        feat_copies = [x.to(i) for i in self.gpus]
        labels_scattered = []
        for i in range(len(self.weights)):
            # Labels outside shard i's class range [weight_idx[i], weight_idx[i+1])
            # are marked -1; in-range labels are rebased to the shard offset.
            labels_new = label.clone()
            labels_new[(labels_new >= self.weight_idx[i + 1]) | (labels_new < self.weight_idx[i])] = -1
            labels_new = labels_new - self.weight_idx[i]
            labels_scattered.append(labels_new)
        kwargs_scattered = scatter(kwargs, self.gpus)
        input_scattered = list(
            zip(feat_copies, weight_scattered, labels_scattered))
        modules = [classify] * len(self.weights)
        # Each shard returns (logits, exp terms, partial softmax sums,
        # argmax within shard, max within shard) -- presumably; confirm
        # against classify's definition.
        results_scattered = parallel_apply(modules, input_scattered,
                                           kwargs_scattered, self.gpus)
        logits = [i[0] for i in results_scattered]
        xexps = [i[1] for i in results_scattered]
        sums = [i[2] for i in results_scattered]
        argmaxs = [i[3] for i in results_scattered]
        maxs = [i[4] for i in results_scattered]
        # Combine the per-shard softmax denominators into the global one.
        sums = gather(sums, 0, dim=1)
        sums = sums.sum(dim=1, keepdim=True)
        sums_scattered = [sums.to(i) for i in self.gpus]
        loss_input_scattered = list(
            zip(logits, xexps, labels_scattered, sums_scattered))
        loss_results_scattered = parallel_apply(
            [nllDistributed] * len(self.gpus), loss_input_scattered, None, self.gpus)
        loss_results_scattered = [i.sum() for i in loss_results_scattered]
        loss_results_scattered = [i.to(0) for i in loss_results_scattered]
        loss = sum(loss_results_scattered)
        # Average over the batch.
        loss = loss / x.shape[0]
        # Convert shard-local argmaxes back to global class indices.
        for i in range(len(argmaxs)):
            argmaxs[i] = argmaxs[i] + self.weight_idx[i]
        maxs = [i.to(0) for i in maxs]
        maxs = torch.stack(maxs, dim=1)
        # Which shard holds the overall max logit for each example.
        _, max_split = torch.max(maxs, dim=1)
        idx = torch.arange(0, maxs.size(0), dtype=torch.long)
        argmaxs = [i.to(0) for i in argmaxs]
        argmaxs = torch.stack(argmaxs, dim=1)
        # Pick the winning shard's argmax per example as the prediction.
        predicted = argmaxs[idx, max_split]
        total = label.size(0)
        predicted = predicted.cpu()
        label = label.cpu()
        correct = (predicted == label).sum().item()
        acc = correct / total
        return loss, acc
def eval_step(self):
    """Run one validation pass: multi-reference acc/loss, translation + BLEU, bookkeeping.

    Phase 1 computes per-reference accuracy/loss with data-parallel worker
    threads.  Phase 2 translates the validation set on one device and scores
    BLEU (optionally at character level).  Phase 3 records the results,
    tracks the best checkpoint by ``self.eval_type``, and prunes old model
    files beyond ``self.max_save_models``.
    """
    acc = numpy.zeros(self.corpus.num_of_multi_refs)
    loss = numpy.zeros(self.corpus.num_of_multi_refs)
    source_data, target_data, src_lens, tgt_lens = self.corpus.get_valid_batches()
    source_data_translation, target_data_translation = \
        self.corpus.get_valid_batches_for_translation()
    num_of_batches = len(source_data_translation)
    num_of_examples = len(self.corpus.corpus_source_valid_numerate)
    # --- Phase 1: per-reference accuracy/loss via device worker threads ---
    for source, target, src_len, tgt_len in zip(source_data, target_data, src_lens, tgt_lens):
        for i in range(0, self.corpus.num_of_multi_refs):
            all_inputs = scatter(inputs=(source, target[i], src_len, tgt_len[i]),
                                 target_gpus=self.device_idxs, dim=0)
            num_of_device = len(all_inputs)
            # NOTE: the generator-local `i` (device index) shadows but does not
            # alter the outer reference index in Python 3.
            args = list((self.replicas[i], all_inputs[i]) for i in range(0, num_of_device))
            all_threads = list(Thread(target=_eval_worker,
                                      args=list(args[i]) + [self.queue])
                               for i in range(0, num_of_device))
            for t in all_threads:
                t.start()
            for t in all_threads:
                t.join()
            all_results = list(self.queue.get() for _ in range(0, num_of_device))
            acc[i] += sum(x[0] for x in all_results)
            loss[i] += sum(x[1] for x in all_results)
    acc /= num_of_examples
    loss /= num_of_examples
    # --- Phase 2: translate on a single device, rotating by eval count ---
    translation_results = []
    device_idx = self.device_idxs[self.processed_steps // self.eval_every_steps % self.num_of_devices]
    model = self.replicas[device_idx]
    for idx, source in enumerate(source_data_translation):
        print('\rTranslating batch %d/%d ... ' % (idx + 1, num_of_batches), sep=' ', end='')
        translated = model.infer_step(source.to(device_idx))
        translation_results += translated
    print('done.')
    # Map token ids back to words for hypotheses and every reference.
    translation_results = list(list(self.corpus.tgt_idx2word[x] for x in line)
                               for line in translation_results)
    target_data_translation = list(list(list(self.corpus.tgt_idx2word[x] for x in line)
                                        for line in gt_ref)
                                   for gt_ref in target_data_translation)
    if self.corpus.bpe_tgt:
        # Undo byte-pair encoding before BLEU scoring.
        translation_results = list(self.corpus.byte_pair_handler_tgt.subwords2words(line)
                                   for line in translation_results)
        target_data_translation = list(list(self.corpus.byte_pair_handler_tgt.subwords2words(line)
                                            for line in gt_ref)
                                       for gt_ref in target_data_translation)
    bleu_score = self.bleu.bleu(translation_results, target_data_translation)
    print('BLEU score: %5.2f' % (bleu_score * 100))
    if self.tgt_character_level:
        # Re-tokenize into alphanumeric runs (with trailing punctuation) or
        # single non-alphanumeric characters, then re-score BLEU.
        r = re.compile(r'((?:(?:[a-zA-Z0-9])+[\-\+\=!@#\$%\^&\*\(\);\:\'\"\[\]{},\.<>\/\?\|`~]*)+|[^a-zA-Z0-9])')
        print('')
        print('For character-level:')
        translation_results = list(' '.join(sum(list(r.findall(x) for x in line), list())).split()
                                   for line in translation_results)
        target_data_translation = list(list(' '.join(sum(list(r.findall(x) for x in line), list())).split()
                                            for line in gt_ref)
                                       for gt_ref in target_data_translation)
        bleu_score = self.bleu.bleu(translation_results, target_data_translation)
        print('BLEU score: %5.2f' % (bleu_score * 100))
    # --- Phase 3: record, report, track best, prune old checkpoints ---
    self.stats.valid_record(acc, loss, bleu_score)
    del source_data, target_data, src_lens, tgt_lens, source_data_translation, target_data_translation
    for i in range(0, self.corpus.num_of_multi_refs):
        # NOTE: str.format here is effectively a no-op wrapper; the string is
        # already fully %-formatted.
        output = str.format('Step %6d valid, ref%1d acc: %5.2f loss: %5.2f bleu: %f'
                            % (self.processed_steps, i, acc[i] * 100, loss[i], bleu_score))
        print(output)
        self.stats.log_to_file(output)
    print('*' * 80)
    self.stats.log_to_file('*' * 80)
    print('Model performances (%s): ' % self.eval_type)
    if self.eval_type == 'acc':
        sorted_results = sorted(self.stats.valid_acc.items(), key=lambda d: d[1], reverse=True)
        temp_acc = float(acc.mean())
        if self.best_acc < temp_acc:
            print('Best acc: %f -> %f at step %d -> %d'
                  % (self.best_acc, temp_acc, self.best_step, self.processed_steps))
            self.best_acc = temp_acc
            self.best_step = self.processed_steps
        else:
            print('Best acc: %f at step %d' % (self.best_acc, self.best_step))
    elif self.eval_type == 'xent':
        # Lower cross-entropy is better: ascending sort, strict improvement check.
        sorted_results = sorted(self.stats.valid_loss.items(), key=lambda d: d[1])
        temp_loss = float(loss.mean())
        if self.best_loss > temp_loss:
            print('Best loss: %f -> %f at step %d -> %d'
                  % (self.best_loss, temp_loss, self.best_step, self.processed_steps))
            self.best_loss = temp_loss
            self.best_step = self.processed_steps
        else:
            print('Best loss: %f at step %d' % (self.best_loss, self.best_step))
    else:
        sorted_results = sorted(self.stats.valid_bleu.items(), key=lambda d: d[1], reverse=True)
        temp_bleu = float(bleu_score)
        if self.best_bleu < temp_bleu:
            print('Best bleu: %f -> %f at step %d -> %d'
                  % (self.best_bleu, temp_bleu, self.best_step, self.processed_steps))
            self.best_bleu = temp_bleu
            self.best_step = self.processed_steps
        else:
            print('Best bleu: %f at step %d' % (self.best_bleu, self.best_step))
    if self.max_save_models > 0:
        # Keep the top-N checkpoints by the chosen metric; delete the rest.
        for (step_temp, value_temp) in sorted_results[:self.max_save_models]:
            print('%6d\t%8f' % (step_temp, value_temp))
        for (step_temp, _) in sorted_results[self.max_save_models:]:
            path = self.stats.fold_name + '/' + str(step_temp) + '.pt'
            if os.path.isfile(self.stats.fold_name + '/' + str(step_temp) + '.pt'):
                os.remove(path)
                print('Remove %d.pt' % step_temp)
    print('*' * 80)
    return