示例#1
0
    def __call__(self, example, random_onset=False):
        """Split the configured entries of ``example`` into fragments.

        Every key in ``self.fragment_steps`` is cut along ``self.axis`` into
        windows of ``self.fragment_lengths[key]`` samples, taken every
        ``self.fragment_steps[key]`` samples. All other entries (restricted to
        ``self.copy_keys`` when it is not None) are deep-copied into every
        fragment.

        Args:
            example: Nested mapping. ``example[key]`` for a fragmented key may
                itself be nested; each leaf is fragmented via ``nested_op``.
                Leaves are indexed with ``.shape`` / ``.take`` — presumably
                numpy arrays (TODO confirm for other array types).
            random_onset: If True, a random relative onset in [0, 1) is
                drawn and snapped successively onto the step grid of each
                fragmented key. Each key then drops
                ``int(start * fragment_step)`` leading samples (fewer than one
                step) before fragmenting.

        Returns:
            List with one re-nested (``deflatten``-ed) example per fragment.

        Raises:
            AssertionError: If the fragmented keys yield differing numbers of
                fragments.
        """
        copies = flatten(
            {key: example[key] for key in self.copy_keys}
            if self.copy_keys is not None else example
        )

        if random_onset:
            start = np.random.rand()
            for fragment_step in self.fragment_steps.values():
                # Keep the result a float in [0, 1): truncating it with int()
                # would always give 0 and disable the random onset.
                start = int(start * fragment_step) / fragment_step
        else:
            start = 0.

        def _fragment(key, x):
            """Cut the leaf ``x`` of ``key`` into fragments along ``self.axis``."""
            fragment_step = self.fragment_steps[key]
            fragment_length = self.fragment_lengths[key]
            onset = int(start * fragment_step)
            if onset > 0:
                slc = [slice(None)] * len(x.shape)
                slc[self.axis] = slice(onset, x.shape[self.axis])
                # Index with a tuple: a list of slices is treated as fancy
                # indexing and rejected by recent numpy versions.
                x = x[tuple(slc)]

            end_index = x.shape[self.axis]
            if self.drop_last:
                # Only allow start positions that leave room for a full window.
                end_index -= (fragment_length - 1)
            fragments = []
            for begin in np.arange(0, end_index, fragment_step):
                if fragment_length == 1 and self.squeeze:
                    # Length-1 windows collapse the fragment axis.
                    fragments.append(x.take(begin, axis=self.axis))
                else:
                    slc = [slice(None)] * len(x.shape)
                    slc[self.axis] = slice(
                        int(begin), int(begin) + int(fragment_length)
                    )
                    fragments.append(x[tuple(slc)])
            return fragments

        features = flatten({
            key: nested_op(lambda x: _fragment(key, x), example[key])
            for key in self.fragment_steps.keys()
        })
        num_fragments = np.array([len(features[key]) for key in features])
        assert all(num_fragments == num_fragments[0]), (list(features.keys()), num_fragments)

        fragments = []
        for i in range(int(num_fragments[0])):
            new_example = deepcopy(copies)
            for key in features:
                new_example[key] = features[key][i]
            fragments.append(deflatten(new_example))
        return fragments
示例#2
0
 def pre_step(self, trainer):
     """Fold the current module weights into the stochastic weight average.

     When ``self.trigger`` fires for the trainer's iteration/epoch (and the
     iteration is not 0), the ``state_dict`` of ``self.get_module(trainer)`` is
     merged into ``self.swa_module`` as a running mean:
     ``avg = (1 - 1/count) * avg + (1/count) * new``. This is an exact mean
     only if ``self.count`` is 0 while ``self.swa_module`` is None (initialised
     outside this view; confirm).

     Args:
         trainer: Object exposing ``iteration`` and ``epoch``; it is also
             passed to ``self.get_module``.
     """
     from copy import deepcopy

     if self.trigger(iteration=trainer.iteration, epoch=trainer.epoch) \
             and trainer.iteration != 0:
         print('SWA')
         module = self.get_module(trainer)
         self.count += 1
         if self.swa_module is None:
             # state_dict() returns references to the live parameter/buffer
             # storage. Copy it, or in-place optimizer updates would
             # silently overwrite this first snapshot.
             self.swa_module = deepcopy(module.state_dict())
         else:
             r = 1 / self.count
             # Produces new tensors, so the average no longer aliases the model.
             self.swa_module = nested_op(
                 lambda x, y: (1 - r) * x.to(y.device) + r * y,
                 self.swa_module, module.state_dict())
示例#3
0
 def __call__(self, example):
     """Collate a batch by applying ``self.collate`` leaf-wise.

     The items of ``example`` are unpacked as separate structures for
     ``nested_op``, so ``self.collate`` is called with the corresponding
     leaves of all items together. ``sequence_type=()`` is forwarded to
     ``nested_op`` — presumably so that lists/tuples are treated as leaves
     rather than traversed (confirm against ``nested_op``).
     """
     return nested_op(self.collate, *example, sequence_type=())
示例#4
0
 def __call__(self, example):
     """Concatenate ``example`` with ``self.concatenate``.

     - ``dict``: ``self.concatenate`` is applied via ``nested_op`` with
       ``sequence_type=()``.
     - ``list`` / ``tuple``: ``self.concatenate`` receives the whole sequence.
     - anything else: returned unchanged.
     """
     if isinstance(example, dict):
         return nested_op(self.concatenate, example, sequence_type=())
     if isinstance(example, (list, tuple)):
         return self.concatenate(example)
     return example