예제 #1
0
    def forward(self, *inputs, **kwargs):
        """Scatter *inputs* across the configured devices, run the wrapped
        module on each shard in parallel, and optionally gather the results.

        When ``gather_output`` is False the per-device outputs are returned
        as-is; the single-device fast paths wrap the lone output in a list
        so callers always receive a sequence in that mode."""

        # Fast path: nothing to scatter with zero or one configured device.
        if (not self.device_ids) or (len(self.device_ids) == 1):
            out = self.module(*inputs, **kwargs)
            return out if self.gather_output else [out]
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        inputs = clean_list(inputs)
        ngpu = len(inputs)
        # Track the widest gradient spread seen while training.
        if self.training and ngpu > self.ngradev:
            self.ngradev = ngpu
        if ngpu == 1:
            # Single shard after scattering: run it on one worker directly.
            worker = self.nets[0] if self.nets is not None else self.module
            out = worker(*inputs[0], **kwargs[0])
            return out if self.gather_output else [out]
        devices = self.device_ids[:ngpu]
        if self.nets is None:
            replicas = self.replicate(self.module, devices)
        else:
            replicas = self.nets[:ngpu]
        outputs = parallel_apply(replicas, inputs, devices, kwargs)
        if self.gather_output:
            return self.gather(outputs, self.output_device)
        # Multi-output modules: transpose per-device tuples into
        # tuples of per-device outputs.
        if isinstance(outputs[0], tuple):
            return tuple(zip(*outputs))
        return outputs
예제 #2
0
    def train_decode(self, *inputs, **kwargs):
        """Scatter *inputs* and run the wrapped module's ``train_decode`` on
        each device in parallel, padding and gathering the outputs when
        ``gather_output`` is set (otherwise single-device results are wrapped
        in a list)."""

        if not self.device_ids:
            out = self.module.train_decode(*inputs, **kwargs)
            return out if self.gather_output else [out]
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        inputs = clean_list(inputs)
        ngpu = len(inputs)
        # Single configured device, or scattering produced one shard.
        if (len(self.device_ids) == 1) or (ngpu == 1):
            out = self.module.train_decode(*inputs[0], **kwargs[0])
            return out if self.gather_output else [out]
        devices = self.device_ids[:ngpu]
        if self.nets is None:
            replicas = self.replicate(self.module, devices)
        else:
            replicas = self.nets[:ngpu]
        outputs = parallel_apply_train_decode(replicas, inputs, devices,
                                              kwargs)
        if self.gather_output:
            # Decoded sequences may differ in length across devices;
            # pad before gathering onto one device.
            return self.gather(pad_tensors(outputs), self.output_device)
        return outputs
예제 #3
0
def handle(srcfl, rsf, rslangf, vsize=65532):
    """Build a token vocabulary and a language-tag vocabulary.

    Files before an optional ``"--target"`` marker in *srcfl* are treated as
    tagged source files: the first token of every non-empty line is counted
    into the language-tag vocabulary, the remaining tokens into the token
    vocabulary. Files after the marker contribute all their tokens to the
    token vocabulary. The token vocabulary is saved to *rsf* truncated to
    *vsize* entries; the tag vocabulary is saved to *rslangf* in full."""

    vocab = {}
    lang_vocab = {}

    # Index of the "--target" marker; defaults past the end so the second
    # loop is a no-op when no marker is present.
    split_at = len(srcfl)
    for findex, srcf in enumerate(srcfl):
        if srcf == "--target":
            split_at = findex
            break
        with open(srcf, "rb") as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                tokens = clean_list(stripped.decode("utf-8").split())
                for tok in tokens[1:]:
                    vocab[tok] = vocab.get(tok, 0) + 1
                tag = tokens[0]
                lang_vocab[tag] = lang_vocab.get(tag, 0) + 1

    # Files after the marker: every token goes into the token vocabulary.
    for srcf in srcfl[split_at + 1:]:
        with open(srcf, "rb") as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                for tok in clean_list_iter(stripped.decode("utf-8").split()):
                    vocab[tok] = vocab.get(tok, 0) + 1

    save_vocab(vocab, rsf, omit_vsize=vsize)
    save_vocab(lang_vocab, rslangf, omit_vsize=False)
예제 #4
0
파일: copy.py 프로젝트: lgstd/transformer
def handle(srcfl, tgtfl):
    """Count source tokens, target tokens and per-line shared (copied)
    tokens across paired files, then print the three totals."""

    nsrc = ntgt = ncopy = 0
    for srcf, tgtf in zip(srcfl, tgtfl):
        with open(srcf, "rb") as fsrc, open(tgtf, "rb") as ftgt:
            for srcl, tgtl in zip(fsrc, ftgt):
                srcl, tgtl = srcl.strip(), tgtl.strip()
                # Skip only when BOTH sides of the pair are empty.
                if not (srcl or tgtl):
                    continue
                src_tokens = clean_list(srcl.decode("utf-8").split())
                tgt_tokens = clean_list(tgtl.decode("utf-8").split())
                nsrc += len(src_tokens)
                ntgt += len(tgt_tokens)
                # Distinct tokens appearing on both sides of the line pair.
                ncopy += len(set(src_tokens) & set(tgt_tokens))

    print("src, tgt, copy: %d, %d, %d" % (
        nsrc,
        ntgt,
        ncopy,
    ))
예제 #5
0
def handle(srcfl, tgtfl, r=0.4):
	"""Filter parallel corpora by lexical diversity.

	Reads the files in *srcfl* in lockstep and copies a line group to the
	corresponding files in *tgtfl* only when every line is non-empty and
	every line's unique-token ratio (|set(tokens)| / |tokens|) exceeds *r*."""

	newline = "\n".encode("utf-8")
	with FileList(srcfl, "rb") as readers, FileList(tgtfl, "wb") as writers:
		for raw_group in zip(*readers):
			stripped = [raw.strip() for raw in raw_group]
			# Drop the whole group if any line is empty.
			if not all(stripped):
				continue
			token_group = [clean_list(line.decode("utf-8").split()) for line in stripped]
			ratios = [float(len(set(tokens))) / float(len(tokens)) for tokens in token_group]
			if all_gt(ratios, r):
				for tokens, wf in zip(token_group, writers):
					wf.write(" ".join(tokens).encode("utf-8"))
					wf.write(newline)
예제 #6
0
	def forward(self, inputs, *targets, **kwargs):
		"""Apply the wrapped criterion in parallel.

		*inputs* is expected to be already scattered (one entry per device);
		only *targets* and *kwargs* are scattered here. The per-device losses
		are gathered onto ``self.output_device``."""
		if not self.device_ids:
			return self.module(inputs[0], *targets, **kwargs)
		targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
		targets = clean_list(targets)
		ngpu = len(targets)
		if ngpu == 1:
			# One shard: run on a single worker, no parallel dispatch.
			worker = self.nets[0] if self.nets is not None else self.module
			return worker(inputs[0], *targets[0], **kwargs[0])
		devices = self.device_ids[:ngpu]
		if self.nets is None:
			replicas = self.replicate(self.module, devices)
		else:
			replicas = self.nets[:ngpu]
		outputs = criterion_parallel_apply(replicas, inputs, targets, devices, kwargs)

		return self.gather(outputs, self.output_device)
예제 #7
0
def handle(srcf, rsf, rslangf, vsize=65532):
	"""Build vocabularies from a single tagged source file.

	For every non-empty line in *srcf*, the first token is counted into the
	language-tag vocabulary and the remaining tokens into the token
	vocabulary. The token vocabulary is saved to *rsf* truncated to *vsize*
	entries; the tag vocabulary is saved to *rslangf* in full."""

	vocab = {}
	lang_vocab = {}

	with open(srcf, "rb") as f:
		for line in f:
			stripped = line.strip()
			if not stripped:
				continue
			tokens = clean_list(stripped.decode("utf-8").split())
			for tok in tokens[1:]:
				vocab[tok] = vocab.get(tok, 0) + 1
			tag = tokens[0]
			lang_vocab[tag] = lang_vocab.get(tag, 0) + 1

	save_vocab(vocab, rsf, omit_vsize=vsize)
	save_vocab(lang_vocab, rslangf, omit_vsize=False)
예제 #8
0
def doc_reader(fname):
	"""Yield ``(document, max_tokens)`` pairs from a blank-line-delimited file.

	A document is the list of tokenized non-empty lines accumulated since the
	previous blank line; ``max_tokens`` is the longest line's token count.
	NOTE: a blank line always triggers a yield, so consecutive blank lines
	produce empty ``([], 0)`` documents — preserved from the original."""

	with open(fname, "rb") as frd:
		doc = []
		longest = 0
		for raw in frd:
			stripped = raw.strip()
			if stripped:
				tokens = clean_list(stripped.decode("utf-8").split())
				longest = max(longest, len(tokens))
				doc.append(tokens)
			else:
				yield doc, longest
				doc = []
				longest = 0
		# Flush a trailing document not followed by a blank line.
		if doc:
			yield doc, longest
예제 #9
0
    def forward(self, *inputs, **kwargs):
        """Scatter *inputs* over the configured devices, run module replicas
        in parallel, and gather the outputs when ``gather_output`` is set
        (single-device results are wrapped in a list otherwise)."""

        # Fast path: zero or one configured device needs no scattering.
        if (not self.device_ids) or (len(self.device_ids) == 1):
            out = self.module(*inputs, **kwargs)
            return out if self.gather_output else [out]
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        inputs = clean_list(inputs)
        ngpu = len(inputs)
        if ngpu == 1:
            out = self.module(*inputs[0], **kwargs[0])
            return out if self.gather_output else [out]
        devices = self.device_ids[:ngpu]
        if self.nets is None:
            replicas = self.replicate(self.module, devices)
        else:
            replicas = self.nets[:ngpu]
        outputs = parallel_apply(replicas, inputs, devices, kwargs)
        # For modules with multiple outputs, transpose here before returning:
        #   if isinstance(outputs[0], tuple): outputs = tuple(zip(*outputs))
        if self.gather_output:
            return self.gather(outputs, self.output_device)
        return outputs
예제 #10
0
	def process(self, input):
		"""Split *input* on whitespace, drop empty tokens via clean_list,
		and return the handler's detokenization of the result."""
		tokens = clean_list(input.split())
		return self.handler.detokenize(tokens)