Example #1
 from torch.nn.parallel._functions import Gather

 def concat_normal(distrib_list, target_device, dim=0, unsqueeze=False):
     # Merge per-replica Normal distributions by gathering their means and
     # stddevs onto `target_device` along `dim`.
     # NOTE: target_device, dim and unsqueeze are assumed parameters; the
     # excerpt originally used them as free variables from an enclosing scope.
     means = [d.mean for d in distrib_list]
     stds = [d.stddev for d in distrib_list]
     if unsqueeze:
         means = [m.unsqueeze(dim) for m in means]
         stds = [s.unsqueeze(dim) for s in stds]
     means = Gather.apply(target_device, dim, *tuple(means))
     stds = Gather.apply(target_device, dim, *tuple(stds))
     return type(distrib_list[0])(means, stds)
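
A hypothetical usage sketch for the parameterized helper above (the replica distributions, device index, and shapes are illustrative; Gather only accepts CUDA tensors, so this assumes a CUDA build):

import torch
from torch.distributions import Normal

if torch.cuda.is_available():
    # Two per-replica Normal outputs, both placed on the default CUDA device.
    replicas = [
        Normal(torch.zeros(4, device='cuda'), torch.ones(4, device='cuda')),
        Normal(torch.ones(4, device='cuda'), torch.ones(4, device='cuda')),
    ]
    merged = concat_normal(replicas, target_device=0, dim=0, unsqueeze=False)
    print(merged.mean.shape)  # torch.Size([8]): means concatenated along dim 0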
Example #2
 def gather_map(outputs):
     out = outputs[0]
     if isinstance(out, Variable):
         return Gather.apply(target_device, dim, *outputs)
     if isinstance(out, dict):
         return dict([(k,
                       Gather.apply(target_device, dim,
                                    *[each[k] for each in outputs]))
                      for k in out.keys()])
     if out is None:
         return None
     return type(out)(map(gather_map, zip(*outputs)))
Example #3
        def gather_map(outputs):
            out = outputs[0]

            if isinstance(out, ModelOutput):
                names = out.to_dict().keys()
                none_names = [
                    n for n in out.__dataclass_fields__.keys()
                    if n not in names
                ]
                attrs = {
                    k: gather_map([getattr(o, k) for o in outputs])
                    for k in names
                }
                attrs = {**attrs, **{k: None for k in none_names}}
                return type(out)(**attrs)

            if isinstance(out, torch.Tensor):
                return Gather.apply(target_device, dim, *outputs)
            if out is None:
                return None

            if isinstance(out, dict):
                if not all((len(out) == len(d) for d in outputs)):
                    raise ValueError(
                        "All dicts must have the same number of keys")
                return type(out)(
                    ((k, gather_map([d[k] for d in outputs])) for k in out))

            return type(out)(map(gather_map, zip(*outputs)))
Example #4
 def _stack_raw(self, values, out, maybe_cuda):
     if self.mode is VarLengthCollateV3Mode.GATHER and maybe_cuda:
         if values[0].dim() == 0:
             values = [o.unsqueeze(0) for o in values]
         return Gather.apply(self.gather_device, self.gather_dim, *values)
     else:
         return torch.stack(values, 0, out=out)
Example #5
 def gather_map(outputs_):
     out = outputs_[0]
     if isinstance(out, torch.Tensor):
         # if all(t.dim() == 0 for t in outputs_) and dim == 0:
         #     # unsqueeze warnings will trigger
         #     import xdev
         #     xdev.embed()
         return OrigGather.apply(target_device, dim, *outputs_)
     if isinstance(out, BatchContainer):
         newdata = [d for dc in outputs_ for d in dc.data]
         if not out.cpu_only:
             import netharn as nh
             target_xpu = nh.XPU(target_device)
             newdata = target_xpu.move(newdata)
         return newdata
     if out is None:
         return None
     if isinstance(out, dict):
         out0_keys = set(out.keys())
         output_keys = [set(d.keys()) for d in outputs_]
         if not all(out0_keys == k for k in output_keys):
             problem_keys = (set.union(*output_keys) -
                             set.intersection(*output_keys))
             raise ValueError('All dicts must have the same keys. '
                              'problem_keys={}'.format(problem_keys))
         return type(out)(
             ((k, gather_map([d[k] for d in outputs_])) for k in out))
     return type(out)(map(gather_map, zip(*outputs_)))
Example #6
    def gather_map(outputs):
        # original + modifications of pytorch `gather_map` function
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None

        #
        # Torch 1.3 modification
        #
        if isinstance(out, Output):
            # TODO need to be extensible! create a trw.train.output.gather function
            # TODO merge metrics too!
            outputs_t = [o.output for o in outputs]
            out.output = gather_map(outputs_t)

            if hasattr(out, 'output_truth'):
                output_truth = [o.output_truth for o in outputs]
                out.output_truth = gather_map(output_truth)
            return out
        #
        # end modification
        #

        if isinstance(out, dict):
            if not all((len(out) == len(d) for d in outputs)):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)(
                ((k, gather_map([d[k] for d in outputs])) for k in out))
        return type(out)(map(gather_map, zip(*outputs)))
Example #7
 def gather_map(outputs):
     out = outputs[0]
     if isinstance(out, torch.Tensor):
         return Gather.apply(target_device, dim, *outputs)
     if isinstance(out, ComplexTensor):
         return ComplexTensor(
             Gather.apply(target_device, dim, *[o.real for o in outputs]),
             Gather.apply(target_device, dim, *[o.imag for o in outputs]))
     if out is None:
         return None
     if isinstance(out, dict):
         if not all((len(out) == len(d) for d in outputs)):
             raise ValueError('All dicts must have the same number of keys')
         return type(out)(
             ((k, gather_map([d[k] for d in outputs])) for k in out))
     return type(out)(map(gather_map, zip(*outputs)))
Example #8
def gather_res(outputs, target_device, dim=0):
    """
    Assuming the signatures are the same across results!
    """
    out = outputs[0]
    args = {field: Gather.apply(target_device, dim, *[getattr(o, field) for o in outputs])
            for field, v in out.__dict__.items() if v is not None}
    return type(out)(**args)
Example #9
 def gather_map(outputs):
     out = outputs[0]
     if isinstance(out, Variable):
         return Gather.apply(target_device, dim, *outputs)
     if out is None:
         return None
     if isinstance(out, dict):
         # Patch to support dictionaries
         value_iter = (item.values() for item in outputs)
         return dict(zip(out, map(gather_map, zip(*value_iter))))
     return type(out)(map(gather_map, zip(*outputs)))
Example #10
 def gather_map(outputs):
     out = outputs[0]
     if torch.is_tensor(out):
         # MJY(20180330) HACK:: force nr_dims > 0
         if out.dim() == 0:
             outputs = [o.unsqueeze(0) for o in outputs]
         return Gather.apply(target_device, dim, *outputs)
     elif out is None:
         return None
     elif isinstance(out, collections.Mapping):
         return {k: gather_map([o[k] for o in outputs]) for k in out}
     elif isinstance(out, collections.Sequence):
         return type(out)(map(gather_map, zip(*outputs)))
Example #11
 def gather_map(outputs):
     out = outputs[0]
     if isinstance(out, Variable):
         # MJY(20180330) HACK:: force nr_dims > 0
         if out.dim() == 0:
             outputs = [o.unsqueeze(0) for o in outputs]
         return Gather.apply(target_device, dim, *outputs)
     elif out is None:
         return None
     elif isinstance(out, collections.Mapping):
         return {k: gather_map([o[k] for o in outputs]) for k in out}
     elif isinstance(out, collections.Sequence):
         return type(out)(map(gather_map, zip(*outputs)))
Example #12
    def gather_map(outputs):
        # An error in any GPU is an error for the entire batch-- throw it all out
        if any(['error' in x and x['error'] for x in outputs]):
            return {'error': True}

        out = outputs[0]
        loss = Gather.apply(target_device, dim,
                            *[x['losses']['loss'] for x in outputs])

        return_obj = {'losses': {'loss': loss}}

        if 'other_metrics' in out:
            other_metrics = {}

            for metric in list(out['other_metrics'].keys()):
                other_metrics[metric] = Gather.apply(
                    target_device, dim,
                    *[x['other_metrics'][metric] for x in outputs])

            return_obj['other_metrics'] = other_metrics

        return return_obj
Example #13
 def gather_map(outputs):
     out = outputs[0]
     if isinstance(out, torch.Tensor):
         return Gather.apply(target_device, dim, *outputs)
     if out is None:
         return None
     if isinstance(out, dict):
         if not all((len(out) == len(d) for d in outputs)):
             raise ValueError('All dicts must have the same number of keys')
         return type(out)(
             ((k, gather_map([d[k] for d in outputs])) for k in out))
     if isinstance(out, PackedSequence):
         return packed_sequence_gather(outputs, target_device)
     return type(out)(map(gather_map, zip(*outputs)))
Example #14
 def gather_map(outputs):
     out = outputs[0]
     if isinstance(out, Variable) or torch.is_tensor(out):
         if out.dim() == 0:
             outputs = [o.unsqueeze(0) for o in outputs]
         return Gather.apply(target_device, dim, *outputs)
     elif out is None:
         return None
     elif isinstance(out, collections.Mapping):
         return {k: gather_map([o[k] for o in outputs]) for k in out}
     elif isinstance(out, six.string_types):
         return outputs
     elif isinstance(out, collections.Sequence):
         return type(out)(map(gather_map, zip(*outputs)))
     return outputs
Example #15
    def gather_map(outputs):
        if isinstance(outputs, Variable):
            if target_device == -1:
                return outputs.cpu()
            return outputs.cuda(target_device)

        out = outputs[0]
        if isinstance(out, Variable):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None

        if isinstance(out, ScatterList):
            return tuple(map(gather_map, itertools.chain(*outputs)))

        return type(out)(map(gather_map, zip(*outputs)))
Example #16
 def gather_map(outputs):
     out = outputs[0]
     if isinstance(out, torch.Tensor):
         return Gather.apply(target_device, dim, *outputs)
     if out is None:
         return None
     if isinstance(out, dict):
         if not all((len(out) == len(d) for d in outputs)):
             raise ValueError('All dicts must have the same number of keys')
         return type(out)(
             ((k, gather_map([d[k] for d in outputs])) for k in out))
     if isinstance(out, torch.distributions.Distribution):
         return concat_distrib(outputs,
                               target_device,
                               dim=0,
                               unsqueeze=False)
     return type(out)(map(gather_map, zip(*outputs)))
Example #17
        def gather_map(outputs):
            elem = outputs[0]
            elem_type = type(elem)

            if isinstance(elem, torch.Tensor):
                return Gather.apply(self.output_device, self.dim, *outputs)

            if elem is None:
                return None

            if isinstance(elem, Mapping):
                if not all((len(elem) == len(d) for d in outputs)):
                    raise ValueError('All dicts must have the same number of keys')
                return elem_type(((k, gather_map([d[k] for d in outputs]))
                                  for k in elem))

            if isinstance(elem, Iterable) and not isinstance(elem, str):
                return elem_type(map(gather_map, zip(*outputs)))

            return outputs
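
All of the gather_map variants above are closures over a target_device and dim that are bound by an enclosing gather helper. For context, a minimal sketch of that wrapper, modeled on torch.nn.parallel.scatter_gather.gather in upstream PyTorch, looks roughly like this:

import torch
from torch.nn.parallel._functions import Gather


def gather(outputs, target_device, dim=0):
    """Recursively gather per-replica outputs onto target_device."""
    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in outputs):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
        return type(out)(map(gather_map, zip(*outputs)))

    # gather_map is recursive and captures itself; dropping the reference once
    # gathering is done avoids keeping a reference cycle alive, as upstream does.
    try:
        res = gather_map(outputs)
    finally:
        gather_map = None
    return res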