Example 1
    def _test_scatter(self, tensor):
        # scatter a 4-row tensor across cuda:0 and cuda:1, then check that
        # gradients flow back only to the rows that were used
        x = tensor.detach().requires_grad_()
        result = dp.scatter(x, (0, 1))
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0], x[:2])
        self.assertEqual(result[0].get_device(), 0)
        self.assertEqual(result[1], x[2:])
        self.assertEqual(result[1].get_device(), 1)
        grad = result[0].detach().clone().fill_(2)
        result[0].backward(grad)
        self.assertEqual(x.grad[:2], grad)
        self.assertEqual(x.grad[2:], grad.clone().zero_())
        _assertGradAndGradgradChecks(self, lambda y: dp.scatter(y, (0, 1)), (x,))
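A minimal sketch of what this test exercises, outside the test harness (it assumes two visible CUDA devices and that `dp` refers to torch.nn.parallel, as in the test file):

import torch
import torch.nn.parallel as dp

x = torch.randn(4, 3, requires_grad=True)
chunks = dp.scatter(x, (0, 1))   # split along dim 0: rows 0-1 on cuda:0, rows 2-3 on cuda:1
assert chunks[0].get_device() == 0 and chunks[1].get_device() == 1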
Example 2
def data_parallel(f,
                  input,
                  params,
                  stats,
                  mode,
                  device_ids,
                  output_device=None):
    assert isinstance(device_ids, list)
    if output_device is None:
        output_device = device_ids[0]

    if len(device_ids) == 1:
        return f(input, params, stats, mode)

    # Broadcast.apply returns a flat tuple: parameter i replicated on device j
    # sits at index i + j * len(params), so rebuild one dict per device.
    params_all = Broadcast.apply(device_ids, *params.values())
    params_replicas = [{
        k: params_all[i + j * len(params)]
        for i, k in enumerate(params.keys())
    } for j in range(len(device_ids))]
    # stats (e.g. running batch-norm statistics) are broadcast without autograd
    stats_replicas = [
        dict(zip(stats.keys(), p))
        for p in comm.broadcast_coalesced(list(stats.values()), device_ids)
    ]

    replicas = [
        partial(f, params=p, stats=s, mode=mode)
        for p, s in zip(params_replicas, stats_replicas)
    ]
    inputs = scatter([input], device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
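A minimal usage sketch of this functional form (the imports, the parameter names, and the toy `f` below are assumptions added for illustration; they are not part of the original snippet):

import torch
import torch.nn.functional as F
from functools import partial
from torch.cuda import comm
from torch.nn.parallel import scatter, parallel_apply, gather
from torch.nn.parallel._functions import Broadcast

def f(x, params, stats, mode):
    # hypothetical stateless model: a single linear layer
    return F.linear(x, params['fc.weight'], params['fc.bias'])

params = {'fc.weight': torch.randn(10, 20, device='cuda:0', requires_grad=True),
          'fc.bias': torch.zeros(10, device='cuda:0', requires_grad=True)}
stats = {'bn.running_mean': torch.zeros(10, device='cuda:0')}   # broadcast without gradients

out = data_parallel(f, torch.randn(8, 20, device='cuda:0'), params, stats,
                    mode=True, device_ids=[0, 1])               # out: (8, 10) on cuda:0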
Example 3
def data_parallel(f,
                  input,
                  params,
                  stats,
                  mode,
                  device_ids,
                  output_device=None):
    if output_device is None:
        output_device = device_ids[0]

    if len(device_ids) == 1:
        return f(input, params, stats, mode)

    def replicate(param_dict, g):
        replicas = [{} for d in device_ids]
        for k, v in param_dict.items():
            for i, u in enumerate(g(v)):
                replicas[i][k] = u
        return replicas

    params_replicas = replicate(params, lambda x: Broadcast(device_ids)(x))
    stats_replicas = replicate(stats, lambda x: comm.broadcast(x, device_ids))

    replicas = [
        lambda x, p=p, s=s, mode=mode: f(x, p, s, mode)
        for i, (p, s) in enumerate(zip(params_replicas, stats_replicas))
    ]
    inputs = scatter(input, device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
Example 4
    def esti_variance_step(self):
        source_data, target_data, src_lens, tgt_lens = self.corpus.get_esti_batches()
        whole_batches_variance_list = []

        for source, target, src_len, tgt_len in zip(source_data, target_data, src_lens, tgt_lens):
            for i in range(0, self.corpus.num_of_multi_refs):
                all_inputs = scatter(inputs=(source, target[i], src_len, tgt_len[i]), target_gpus=self.device_idxs,
                                     dim=0)
                num_of_device = len(all_inputs)

                args = list((self.replicas[i], all_inputs[i]) for i in range(0, num_of_device))

                all_threads = list(Thread(target=_esti_variance_worker, args=list(args[i]) + [self.queue]) for i in
                                   range(0, num_of_device))
                for t in all_threads:
                    t.start()
                for t in all_threads:
                    t.join()

                all_results = list(self.queue.get() for _ in range(0, num_of_device))
                cur_batch_variance = numpy.mean(all_results)
                whole_batches_variance_list.append(cur_batch_variance)

        total_variance = torch.FloatTensor(whole_batches_variance_list)
        total_average_variance = total_variance.mean()

        return total_average_variance
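The scatter-then-thread pattern above (also used in the training and evaluation steps further down) reduces to the following self-contained sketch; `_worker`, the replica placeholders, and the toy inputs are made up for illustration:

from queue import Queue
from threading import Thread

def _worker(replica, inp, out_queue):
    # stand-in for the real per-device forward pass
    out_queue.put(sum(inp))

queue = Queue()
replicas = ['replica-on-gpu0', 'replica-on-gpu1']   # stand-ins for per-device models
scattered = [[1, 2, 3], [4, 5, 6]]                  # stand-ins for scattered batch chunks

threads = [Thread(target=_worker, args=(r, x, queue)) for r, x in zip(replicas, scattered)]
for t in threads:
    t.start()
for t in threads:
    t.join()
results = [queue.get() for _ in threads]            # one result per device, order not guaranteed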
Example 5
    def train_step(self, batch):
        report_idx = self.processed_steps % self.report_every_steps

        self.src_num_pad_tokens[report_idx] = int(batch[0].numel())
        self.tgt_num_pad_tokens[report_idx] = int(batch[1].numel())
        self.src_tokens[report_idx] = int(batch[2].sum())
        self.tgt_tokens[report_idx] = int(batch[3].sum())
        self.num_examples[report_idx] = int(batch[0].size(0))

        all_inputs = scatter(inputs=batch, target_gpus=self.device_idxs, dim=0)
        num_of_device = len(all_inputs)
        factor = 1.0 / self.num_examples[report_idx]
        args = list((self.replicas[i], all_inputs[i], factor, self.queue) for i in range(0, num_of_device))

        all_threads = list(Thread(target=_train_worker, args=list(args[i])) for i in range(0, num_of_device))
        for t in all_threads:
            t.start()
        for t in all_threads:
            t.join()

        all_results = list(self.queue.get() for _ in range(0, num_of_device))

        self.acc_report[report_idx] += sum(x[0] for x in all_results) / self.tgt_tokens[report_idx]
        self.loss_report[report_idx] += sum(x[1] for x in all_results) / self.num_examples[report_idx]
        self.update_decay_steps[report_idx] += 1

        return
Example 6
def data_parallel(f, input, params, stats, mode, device_ids, output_device=None):
    if output_device is None:
        output_device = device_ids[0]

    if len(device_ids) == 1: # only 1 device 
        return f(input, params, stats, mode)
    
    # function inside data_parallel 
    def replicate(param_dict, g):
        replicas = [{} for d in device_ids]  # replicas, list of n_devices dict
        for k, v in param_dict.items():  # v is parameter
            for i, u in enumerate(g(v)):
                replicas[i][k] = u
        return replicas
    
    # broadcast parameters 
    params_replicas = replicate(params, lambda x: Broadcast(device_ids)(x))
    # broadcast stats 
    stats_replicas = replicate(stats, lambda x: comm.broadcast(x, device_ids))

    replicas = [lambda x,p=p,s=s,mode=mode: f(x,p,s,mode)
            for i,(p,s) in enumerate(zip(params_replicas, stats_replicas))]

    inputs = scatter(input, device_ids)

    outputs = parallel_apply(replicas, inputs)

    return gather(outputs, output_device)
Example 7
	def forward(self, *inputs, **kwargs):
		inputs = scatter(inputs, self.device_ids, dim=0)
		kwargs = scatter(kwargs, self.device_ids, dim=0)
		replicas = replicate(self.network, self.device_ids[:len(inputs)])
		outputs = parallel_apply(replicas, inputs, kwargs)
		outputs = list(zip(*outputs))

		res = []
		for i in range(len(outputs)):
			buf = []
			for j in range(len(outputs[i])):
				if isinstance(outputs[i][j], int):
					if outputs[i][j]<0:
						buf.append(outputs[i][j])
				else:
					buf.append(outputs[i][j].to(self.device_ids[0]))
			res.append(buf)
		return res
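A minimal sketch of the kind of wrapper this forward appears to belong to; the class name and constructor below are assumptions made for illustration, only the forward logic above comes from the original:

import torch.nn as nn

class MultiGPUWrapper(nn.Module):
    """Hypothetical container: replicates self.network over device_ids,
    scatters inputs/kwargs, and collects every output per replica
    (tensors moved to the first device, negative ints kept as flags)."""

    def __init__(self, network, device_ids):
        super().__init__()
        self.network = network
        self.device_ids = device_ids

    # forward(self, *inputs, **kwargs) as defined in the example above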
Example 8
    def _test_scatter(self, x):
        if not TEST_MULTIGPU:
            raise unittest.SkipTest("Only one GPU detected")
        x = Variable(x)
        result = dp.scatter(x, (0, 1))
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0], x[:2])
        self.assertEqual(result[0].get_device(), 0)
        self.assertEqual(result[1], x[2:])
        self.assertEqual(result[1].get_device(), 1)
Example 9
    def replicate_module(self, module: torch.nn.Module,
                         devices: List[int]) -> List[torch.nn.Module]:
        assert self.n_mask_samples % len(devices) == 0
        copies = replicate(module, devices)

        def walk(module: torch.nn.Module, copy: torch.nn.Module):
            module_map = {id(module): copy}
            for name, ref in module._modules.items():
                module_map.update(walk(ref, getattr(copy, name)))

            return module_map

        devices = [_get_device_index(d) for d in devices]

        # Copy the custom parameters
        all_params = [p.get() for p in self.pointer_values]

        if (not self.masking_enabled) or (not self.training):
            scattered = _broadcast_coalesced_reshape(all_params, devices)
        else:
            # This case is more complicated, because there might be non-masked
            # parameters which have to be handled in the usual way
            masked_indices = [
                i for i, n in enumerate(self.param_names) if self.is_masked(n)
            ]
            simple_indices = [
                i for i, n in enumerate(self.param_names)
                if not self.is_masked(n)
            ]

            masked_params = scatter([all_params[i] for i in masked_indices],
                                    devices)
            simple_params = _broadcast_coalesced_reshape(
                [all_params[i] for i in simple_indices], devices)

            scattered = [[None] * len(all_params) for _ in devices]
            for d in range(len(devices)):
                for mi, mp in zip(masked_indices, masked_params[d]):
                    scattered[d][mi] = mp

                for si, sp in zip(simple_indices, simple_params[d]):
                    scattered[d][si] = sp

        for i, c in enumerate(copies):
            device_map = walk(module, c)
            for j, p in enumerate(self.pointer_values):
                setattr(device_map[id(p.parent)], p.name, scattered[i][j])

            self.update_rnn_params(c)

        return copies
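A toy illustration of how the masked and plain parameters are merged back into one per-device list in the else-branch above (the indices and string placeholders are made up):

param_names = ['w0', 'w1', 'w2', 'w3']              # suppose w1 and w3 are masked
masked_indices, simple_indices = [1, 3], [0, 2]
n_dev = 2
masked_params = [['m1@d0', 'm3@d0'], ['m1@d1', 'm3@d1']]   # scattered per device
simple_params = [['s0@d0', 's2@d0'], ['s0@d1', 's2@d1']]   # broadcast per device

scattered = [[None] * len(param_names) for _ in range(n_dev)]
for d in range(n_dev):
    for mi, mp in zip(masked_indices, masked_params[d]):
        scattered[d][mi] = mp
    for si, sp in zip(simple_indices, simple_params[d]):
        scattered[d][si] = sp
assert scattered[0] == ['s0@d0', 'm1@d0', 's2@d0', 'm3@d0']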
Example 10
def data_parallel(f, input, params, mode, device_ids, output_device=None):
    assert isinstance(device_ids, list)
    if output_device is None:
        output_device = device_ids[0]

    if len(device_ids) == 1:
        return f(input, params, mode)

    params_all = Broadcast.apply(device_ids, *params.values())
    params_replicas = [{k: params_all[i + j*len(params)] for i, k in enumerate(params.keys())}
                       for j in range(len(device_ids))]

    replicas = [partial(f, params=p, mode=mode)
                for p in params_replicas]
    inputs = scatter([input], device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
Example 11
def data_parallel(f, input, params, mode, device_ids, output_device=None):
    device_ids = list(device_ids)
    if output_device is None:
        output_device = device_ids[0]

    if len(device_ids) == 1:
        return f(input, params, mode)

    params_all = Broadcast.apply(device_ids, *params.values())
    params_replicas = [{
        k: params_all[i + j * len(params)]
        for i, k in enumerate(params.keys())
    } for j in range(len(device_ids))]

    replicas = [partial(f, params=p, mode=mode) for p in params_replicas]
    inputs = scatter([input], device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
Example 12
    def scatter(self, inputs, kwargs, device_ids):
        from torch.nn.parallel import scatter

        def chunk_it(seq, num, devices):
            assert isinstance(num, (int, list, tuple))
            if isinstance(num, int):
                chunk_sizes = [len(seq) / float(num)] * num
            else:
                chunk_sizes = map(int, num)
            out = []
            last = 0.0
            for size, device in zip(chunk_sizes, devices):
                out.append(
                    torch.tensor(seq[int(last):int(last + size)],
                                 device=device))
                last += size
            return out

        devices = [torch.device('cuda:' + str(i)) for i in device_ids]
        nums_atoms_ckd = chunk_it(inputs[3], len(device_ids), devices)
        chunk_sizes = [sum(num_atoms) for num_atoms in nums_atoms_ckd]
        gs_charge_ckd, atom_type_ckd, pos_ckd = [
            chunk_it(i, chunk_sizes, devices) for i in inputs[:3]
        ]
        inputs = list(
            zip(gs_charge_ckd, atom_type_ckd, pos_ckd, nums_atoms_ckd))
        kwargs = scatter(kwargs, device_ids, self.dim) if kwargs else []

        if len(inputs) < len(kwargs):
            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
        elif len(kwargs) < len(inputs):
            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
        inputs = tuple(inputs)
        kwargs = tuple(kwargs)

        return inputs, kwargs
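For intuition, the two-stage split above works roughly like this (toy numbers, CPU only; the real code additionally moves each chunk to its own CUDA device):

nums_atoms = [3, 5, 2, 4]                        # atoms per molecule in the batch
n_dev = 2
per_dev = len(nums_atoms) // n_dev               # molecules per device
mol_chunks = [nums_atoms[i * per_dev:(i + 1) * per_dev] for i in range(n_dev)]
atom_chunk_sizes = [sum(c) for c in mol_chunks]  # per-atom arrays are then split by these sizes
print(mol_chunks, atom_chunk_sizes)              # [[3, 5], [2, 4]] [8, 6]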
Example 13
    def forward(self, x, label, **kwargs):
        if self.gpus is None:
            # cpu mode, normal fc layer
            x = classify(x, self.weight, label, simple_output=True, **kwargs)
            with torch.no_grad():
                acc = accuracy(x, label)
            x = F.log_softmax(x, dim=1)
            label = label.unsqueeze(-1)
            loss = torch.gather(x, 1, label)
            loss = -loss.mean()
            return loss, acc
        else:
            weight_scattered = (w.to(i)
                                for w, i in zip(self.weights, self.gpus))
            feat_copies = [x.to(i) for i in self.gpus]
            labels_scattered = []
            for i in range(len(self.weights)):
                labels_new = label.clone()
                labels_new[(labels_new >= self.weight_idx[i + 1]) |
                           (labels_new < self.weight_idx[i])] = -1
                labels_new = labels_new - self.weight_idx[i]
                labels_scattered.append(labels_new)
            kwargs_scattered = scatter(kwargs, self.gpus)
            input_scattered = list(
                zip(feat_copies, weight_scattered, labels_scattered))
            modules = [classify] * len(self.weights)
            results_scattered = parallel_apply(modules, input_scattered,
                                               kwargs_scattered, self.gpus)

            logits = [i[0] for i in results_scattered]
            xexps = [i[1] for i in results_scattered]
            sums = [i[2] for i in results_scattered]
            argmaxs = [i[3] for i in results_scattered]
            maxs = [i[4] for i in results_scattered]

            sums = gather(sums, 0, dim=1)
            sums = sums.sum(dim=1, keepdim=True)
            sums_scattered = [sums.to(i) for i in self.gpus]
            loss_input_scattered = list(
                zip(logits, xexps, labels_scattered, sums_scattered))
            loss_results_scattered = parallel_apply(
                [nllDistributed] * len(self.gpus), loss_input_scattered, None,
                self.gpus)
            loss_results_scattered = [i.sum() for i in loss_results_scattered]

            loss_results_scattered = [i.to(0) for i in loss_results_scattered]
            loss = sum(loss_results_scattered)
            loss = loss / x.shape[0]

            for i in range(len(argmaxs)):
                argmaxs[i] = argmaxs[i] + self.weight_idx[i]
            maxs = [i.to(0) for i in maxs]
            maxs = torch.stack(maxs, dim=1)

            _, max_split = torch.max(maxs, dim=1)
            idx = torch.arange(0, maxs.size(0), dtype=torch.long)
            argmaxs = [i.to(0) for i in argmaxs]
            argmaxs = torch.stack(argmaxs, dim=1)
            predicted = argmaxs[idx, max_split]

            total = label.size(0)
            predicted = predicted.cpu()
            label = label.cpu()
            correct = (predicted == label).sum().item()
            acc = correct / total

            return loss, acc
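For reference, a minimal call of the `gather(..., dim=1)` step used above to combine the per-GPU partial sums (two CUDA devices assumed; the tensors are made up):

import torch
from torch.nn.parallel import gather

a = torch.ones(4, 1, device='cuda:0')
b = torch.zeros(4, 1, device='cuda:1')
merged = gather([a, b], 0, dim=1)   # concatenated to shape (4, 2) on device 0
total = merged.sum(dim=1, keepdim=True)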
Example 14
    def eval_step(self):
        acc = numpy.zeros(self.corpus.num_of_multi_refs)
        loss = numpy.zeros(self.corpus.num_of_multi_refs)

        source_data, target_data, src_lens, tgt_lens = self.corpus.get_valid_batches()
        source_data_translation, target_data_translation = \
            self.corpus.get_valid_batches_for_translation()

        num_of_batches = len(source_data_translation)
        num_of_examples = len(self.corpus.corpus_source_valid_numerate)

        for source, target, src_len, tgt_len in zip(source_data, target_data, src_lens, tgt_lens):

            for i in range(0, self.corpus.num_of_multi_refs):
                all_inputs = scatter(inputs=(source, target[i], src_len, tgt_len[i]), target_gpus=self.device_idxs,
                                     dim=0)
                num_of_device = len(all_inputs)
                args = list((self.replicas[i], all_inputs[i]) for i in range(0, num_of_device))

                all_threads = list(Thread(target=_eval_worker, args=list(args[i]) + [self.queue]) for i in
                                   range(0, num_of_device))
                for t in all_threads:
                    t.start()
                for t in all_threads:
                    t.join()

                all_results = list(self.queue.get() for _ in range(0, num_of_device))

                acc[i] += sum(x[0] for x in all_results)
                loss[i] += sum(x[1] for x in all_results)

        acc /= num_of_examples
        loss /= num_of_examples

        translation_results = []
        device_idx = self.device_idxs[self.processed_steps // self.eval_every_steps % self.num_of_devices]
        model = self.replicas[device_idx]

        for idx, source in enumerate(source_data_translation):
            print('\rTranslating batch %d/%d ... ' % (idx + 1, num_of_batches), sep=' ', end='')
            translated = model.infer_step(source.to(device_idx))
            translation_results += translated
        print('done.')
        translation_results = list(list(self.corpus.tgt_idx2word[x] for x in line)
                                   for line in translation_results)
        target_data_translation = list(list(list(self.corpus.tgt_idx2word[x] for x in line)
                                            for line in gt_ref)
                                       for gt_ref in target_data_translation)

        if self.corpus.bpe_tgt:
            translation_results = list(self.corpus.byte_pair_handler_tgt.subwords2words(line)
                                       for line in translation_results)
            target_data_translation = list(list(self.corpus.byte_pair_handler_tgt.subwords2words(line)
                                                for line in gt_ref)
                                           for gt_ref in target_data_translation)

        bleu_score = self.bleu.bleu(translation_results, target_data_translation)

        print('BLEU score: %5.2f' % (bleu_score * 100))

        if self.tgt_character_level:
            r = re.compile(r'((?:(?:[a-zA-Z0-9])+[\-\+\=!@#\$%\^&\*\(\);\:\'\"\[\]{},\.<>\/\?\|`~]*)+|[^a-zA-Z0-9])')
            print('')
            print('For character-level:')

            translation_results = list(' '.join(sum(list(r.findall(x) for x in line), list())).split() for line in translation_results)
            target_data_translation = list(list(' '.join(sum(list(r.findall(x) for x in line), list())).split()
                                           for line in gt_ref) for gt_ref in target_data_translation)
            bleu_score = self.bleu.bleu(translation_results, target_data_translation)
            print('BLEU score: %5.2f' % (bleu_score * 100))

        self.stats.valid_record(acc, loss, bleu_score)

        del source_data, target_data, src_lens, tgt_lens, source_data_translation, target_data_translation

        for i in range(0, self.corpus.num_of_multi_refs):
            output = ('Step %6d valid, ref%1d acc: %5.2f loss: %5.2f bleu: %f'
                      % (self.processed_steps, i, acc[i] * 100, loss[i], bleu_score))
            print(output)
            self.stats.log_to_file(output)

        print('*' * 80)
        self.stats.log_to_file('*' * 80)
        print('Model performances (%s): ' % self.eval_type)

        if self.eval_type == 'acc':
            sorted_results = sorted(self.stats.valid_acc.items(), key=lambda d: d[1], reverse=True)
            temp_acc = float(acc.mean())
            if self.best_acc < temp_acc:
                print('Best acc: %f -> %f at step %d -> %d' % (self.best_acc, temp_acc,
                                                               self.best_step, self.processed_steps))
                self.best_acc = temp_acc
                self.best_step = self.processed_steps
            else:
                print('Best acc: %f at step %d' % (self.best_acc, self.best_step))

        elif self.eval_type == 'xent':
            sorted_results = sorted(self.stats.valid_loss.items(), key=lambda d: d[1])
            temp_loss = float(loss.mean())
            if self.best_loss > temp_loss:
                print('Best loss: %f -> %f at step %d -> %d' % (self.best_loss, temp_loss,
                                                                self.best_step, self.processed_steps))
                self.best_loss = temp_loss
                self.best_step = self.processed_steps
            else:
                print('Best loss: %f at step %d' % (self.best_loss, self.best_step))

        else:
            sorted_results = sorted(self.stats.valid_bleu.items(), key=lambda d: d[1], reverse=True)
            temp_bleu = float(bleu_score)
            if self.best_bleu < temp_bleu:
                print('Best bleu: %f -> %f at step %d -> %d' % (self.best_bleu, temp_bleu,
                                                                self.best_step, self.processed_steps))
                self.best_bleu = temp_bleu
                self.best_step = self.processed_steps
            else:
                print('Best bleu: %f at step %d' % (self.best_bleu, self.best_step))

        if self.max_save_models > 0:
            for (step_temp, value_temp) in sorted_results[:self.max_save_models]:
                print('%6d\t%8f' % (step_temp, value_temp))

            for (step_temp, _) in sorted_results[self.max_save_models:]:
                path = self.stats.fold_name + '/' + str(step_temp) + '.pt'
                if os.path.isfile(path):
                    os.remove(path)
                    print('Remove %d.pt' % step_temp)

        print('*' * 80)

        return