def __call__(self, example, random_onset=False):
    """Split the configured entries of ``example`` into fragments.

    Every key in ``self.fragment_steps`` is cut along ``self.axis`` into
    windows of ``self.fragment_lengths[key]`` elements, one window every
    ``self.fragment_steps[key]`` elements. The entries selected by
    ``self.copy_keys`` (or the whole example when ``copy_keys`` is None)
    are deep-copied into every fragment. The fragmented entries then
    overwrite the matching keys of that copy.

    Args:
        example: Nested example. The values under the fragment keys are
            assumed to be array-like, with ``.shape``, ``.take(i, axis=...)``
            and tuple-of-slices indexing (e.g. numpy arrays). This is inferred
            from the operations used below; TODO confirm against callers.
        random_onset: If True, draw a random onset and drop that many leading
            elements (always fewer than one step) before fragmenting.

    Returns:
        A list of fragment examples, each re-nested with ``deflatten``.

    Raises:
        AssertionError: If the fragmented keys yield different numbers of
            fragments.
    """
    if self.copy_keys is None:
        copies = flatten(example)
    else:
        copies = flatten({key: example[key] for key in self.copy_keys})

    if random_onset:
        # ``start`` is a fraction of one step, in [0, 1).
        start = np.random.rand()
        # Snap the onset onto each key's step grid in turn. The former
        # ``int(int(start * step) / step)`` always truncated to 0, which
        # silently disabled the random onset.
        for fragment_step in self.fragment_steps.values():
            start = int(start * fragment_step) / fragment_step
    else:
        start = 0.

    def split(key, x):
        """Cut array ``x`` (stored under ``key``) into fragments along ``self.axis``."""
        fragment_step = self.fragment_steps[key]
        fragment_length = self.fragment_lengths[key]
        onset = int(start * fragment_step)
        if onset > 0:
            slc = [slice(None)] * len(x.shape)
            slc[self.axis] = slice(onset, x.shape[self.axis])
            # Index with a tuple: numpy >= 1.23 rejects a list of slices.
            x = x[tuple(slc)]
        end_index = x.shape[self.axis]
        if self.drop_last:
            # Only start windows that fit completely into x.
            end_index -= (fragment_length - 1)
        pieces = []
        for offset in np.arange(0, end_index, fragment_step):
            if fragment_length == 1 and self.squeeze:
                # Length-1 window: drop the fragment axis entirely.
                pieces.append(x.take(offset, axis=self.axis))
            else:
                slc = [slice(None)] * len(x.shape)
                slc[self.axis] = slice(
                    int(offset), int(offset) + int(fragment_length)
                )
                pieces.append(x[tuple(slc)])
        return pieces

    features = flatten({
        key: nested_op(lambda x: split(key, x), example[key])
        for key in self.fragment_steps
    })
    num_fragments = np.array([len(values) for values in features.values()])
    assert all(num_fragments == num_fragments[0]), (list(features.keys()), num_fragments)

    fragments = []
    for i in range(int(num_fragments[0])):
        fragment = deepcopy(copies)
        for key, values in features.items():
            fragment[key] = values[i]
        fragments.append(deflatten(fragment))
    return fragments
def pre_step(self, trainer):
    """Fold the current module weights into the running SWA average.

    Acts only when ``self.trigger`` fires for the trainer's current
    iteration and epoch, and never at iteration 0. ``self.count`` counts the
    snapshots taken so far. ``self.swa_module`` holds the running mean of
    their ``state_dict`` values.

    Args:
        trainer: Object exposing ``iteration`` and ``epoch``. It is also
            passed to ``self.get_module`` to obtain the module to average.
    """
    # Evaluate the trigger first (as before), then skip iteration 0.
    if not (self.trigger(iteration=trainer.iteration, epoch=trainer.epoch)
            and trainer.iteration != 0):
        return
    print('SWA')
    module = self.get_module(trainer)
    self.count += 1
    if self.swa_module is None:
        # state_dict() returns tensors that share storage with the live
        # parameters. Copy them, otherwise later optimizer steps overwrite
        # this first snapshot in place.
        self.swa_module = deepcopy(module.state_dict())
    else:
        r = 1 / self.count
        # Incremental mean: avg_n = (1 - 1/n) * avg_{n-1} + (1/n) * x_n.
        self.swa_module = nested_op(
            lambda x, y: (1 - r) * x.to(y.device) + r * y,
            self.swa_module, module.state_dict())
def __call__(self, example):
    """Collate a sequence of examples into a single batch structure.

    Each element of ``example`` is passed as a separate argument to
    ``nested_op``, which applies ``self.collate`` to the corresponding
    leaves. ``sequence_type=()`` presumably stops ``nested_op`` from
    recursing into lists/tuples, so those reach ``collate`` whole; confirm
    against ``nested_op``.
    """
    return nested_op(self.collate, *example, sequence_type=())
def __call__(self, example):
    """Concatenate ``example`` with ``self.concatenate``.

    - list/tuple: the sequence itself is handed to ``self.concatenate``.
    - dict: ``self.concatenate`` is mapped over the nested values via
      ``nested_op``, with ``sequence_type=()``.
    - anything else: returned unchanged.
    """
    # A type cannot derive from both dict and list/tuple, so checking the
    # sequence case first is equivalent to checking dict first.
    if isinstance(example, (list, tuple)):
        return self.concatenate(example)
    if isinstance(example, dict):
        return nested_op(self.concatenate, example, sequence_type=())
    return example