def forward(self, x, size):
    """Trim ``x`` along its trailing dimension(s).

    Args:
        x: input tensor of shape b,c,(f,)t
        size: size to trim at dims (f,)t (expanded via ``to_list``)

    Returns:
        x shortened by size in the last (two) dimension(s)
    """
    num_dims = x.dim()
    assert num_dims in [3, 4], x.shape
    sides = to_list(self.side, num_dims - 2)
    sizes = to_list(size, num_dims - 2)

    index = [slice(None) for _ in range(num_dims)]
    # The first two dims (b, c) are never trimmed, hence start=2.
    for dim, (cur_side, cur_size) in enumerate(zip(sides, sizes), start=2):
        if cur_side is None or cur_size < 1:
            # Nothing to trim in this dim; negative sizes are rejected.
            assert cur_size == 0, sizes
        elif cur_side == 'front':
            index[dim] = slice(cur_size, x.shape[dim])
        elif cur_side == 'both':
            # if size is odd: end is trimmed more than front
            index[dim] = slice(cur_size // 2, -math.ceil(cur_size / 2))
        elif cur_side == 'end':
            index[dim] = slice(0, -cur_size)
        else:
            raise ValueError
    return x[tuple(index)]
def compute_conv_output_shape(input_shape, out_channels, kernel_size,
                              dilation, stride, pad_type, transpose=False):
    """Compute the output shape of a (transposed) convolution.

    The first entry of ``input_shape`` is copied, the second is replaced by
    ``out_channels`` and every remaining entry is computed per dimension by
    ``_compute_transpose_out_size`` (if ``transpose``) or
    ``_compute_conv_out_size``.

    Returns:
        The output shape as an ``np.int64`` array. All entries are asserted
        to be positive.
    """
    input_shape = np.array(input_shape)
    result = np.zeros_like(input_shape)
    result[0] = input_shape[0]
    result[1] = out_channels

    num_spatial = len(input_shape) - 2
    kernel_size = to_list(kernel_size, num_spatial)
    dilation = to_list(dilation, num_spatial)
    stride = to_list(stride, num_spatial)
    pad_type = to_list(pad_type, num_spatial)

    for d, kernel in enumerate(kernel_size):
        if transpose:
            out_size = _compute_transpose_out_size(
                input_shape[2 + d], kernel, dilation[d], stride[d],
                pad_type[d])
        else:
            out_size = _compute_conv_out_size(
                input_shape[2 + d], kernel, dilation[d], stride[d],
                pad_type[d])
        result[2 + d] = out_size

    assert all(result > 0), result
    return result.astype(np.int64)
def forward(self, x, size):
    """Pad ``x`` along its trailing dimension(s).

    Args:
        x: input tensor of shape b,c,(f,)t
        size: size to pad to dims (f,)t (expanded via ``to_list``)

    Returns:
        x padded by size in the last (two) dimension(s); ``x`` itself is
        returned when all sizes are zero.
    """
    num_dims = x.dim()
    assert num_dims in [3, 4], x.shape
    sides = to_list(self.side, num_dims - 2)
    sizes = to_list(size, num_dims - 2)
    if not any(np.array(sizes)):
        return x

    # F.pad expects (front, end) pairs starting with the LAST dimension.
    pad = []
    for cur_side, cur_size in list(zip(sides, sizes))[::-1]:
        if cur_side is None or cur_size < 1:
            assert cur_size == 0, sizes
            pad += [0, 0]
        elif cur_side == 'front':
            pad += [cur_size, 0]
        elif cur_side == 'both':
            # if size is odd: end is padded more than front
            pad += [cur_size // 2, math.ceil(cur_size / 2)]
        elif cur_side == 'end':
            pad += [0, cur_size]
        else:
            raise ValueError(f'pad side {cur_side} unknown')

    return F.pad(x, pad, mode=self.mode)
def trim_padding(self, x):
    """Counterpart to pad_or_trim used in transposed convolutions.

    The per-dimension front/end sizes come from ``compute_pad_size``; the
    end size is reduced by ``stride - 1`` and clipped at zero before
    trimming.

    Args:
        x: input tensor of shape b,c,(f,)t

    Returns:
        ``x`` with front and end trimmed where the respective size is
        non-zero.
    """
    assert self.is_transpose()
    num_dims = 1 + self.is_2d()
    pad_sizes = [
        compute_pad_size(k, d, s, t)
        for k, d, s, t in zip(
            to_list(self.kernel_size, num_dims),
            to_list(self.dilation, num_dims),
            to_list(self.stride, num_dims),
            to_list(self.pad_type, num_dims),
        )
    ]
    front_pad = tuple(front for front, _ in pad_sizes)
    end_pad = tuple(end for _, end in pad_sizes)
    end_pad = np.maximum(np.array(end_pad) - np.array(self.stride) + 1, 0)

    if any(front_pad):
        x = Trim(side='front')(x, size=front_pad)
    if any(end_pad):
        x = Trim(side='end')(x, size=end_pad)
    return x
def __init__(self, length: int = -1, shift: int = None,
             include_keys: Union[str, list, tuple] = None,
             exclude_keys: Union[str, list, tuple] = None,
             copy_keys: Union[str, bool, list, tuple] = True,
             axis: Union[int, list, tuple, dict] = -1,
             anchor: Union[int, str] = 'left', mode: str = 'constant',
             padding: bool = False, flatten_separator: str = '.'):
    """Store and validate the segmentation options.

    Args:
        length: stored as ``self.length``.
        shift: stored as ``self.shift``; defaults to ``length`` and must
            not exceed it.
        include_keys: converted with ``to_list`` into ``self.include``
            (``None`` stays ``None``).
        exclude_keys: converted with ``to_list`` into ``self.exclude``
            (``None`` becomes an empty list).
        copy_keys: converted with ``to_list``; every entry has to be a
            ``str`` or ``bool``.
        axis: an int, a dict whose keys equal ``include_keys``, or a
            list/tuple with one entry per include key.
        anchor: an int or str, stored as ``self.anchor``.
        mode: has to be contained in ``possible_segment_modes``.
        padding: only allowed together with ``anchor`` ``0`` or ``'left'``.
        flatten_separator: stored as ``self.flatten_separator``.

    Raises:
        TypeError: if ``axis`` is neither int, dict, list nor tuple.
    """
    self.include = None if include_keys is None else to_list(include_keys)
    self.exclude = [] if exclude_keys is None else to_list(exclude_keys)
    self.length = length

    if isinstance(axis, (dict, int)):
        self.axis = axis
        if isinstance(axis, dict):
            assert self.include is not None, (self.axis, self.include)
            assert set(axis.keys()) == set(self.include), (
                axis.keys(), self.include
            )
    elif isinstance(axis, (tuple, list)):
        self.axis = to_list(axis)
        assert self.include is not None, (self.axis, self.include)
        # Compare against the normalized key list: a single str passed as
        # include_keys must count as ONE key, not as len(str) keys.
        assert len(axis) == len(self.include), (
            'If axis are specified as list it has to have the same length '
            'as include_keys', axis, include_keys
        )
    else:
        raise TypeError('Unknown type for axis', axis)

    if shift is None:
        shift = length
    # If there is a use case for shift > length, open a pull request and
    # remove this assert.
    assert shift <= length, (shift, length)
    self.shift = shift

    assert isinstance(anchor, (str, int)), anchor
    self.anchor = anchor

    self.copy_keys = to_list(copy_keys)
    assert all([isinstance(key, (bool, str)) for key in self.copy_keys]), (
        'All keys in copy_keys have to be str, or copy key has to be one '
        'boolean', copy_keys
    )

    assert mode in possible_segment_modes, (
        'length_mode has to be one of', possible_segment_modes,
        'but is', mode
    )
    self.mode = mode

    if padding:
        # No padding is implemented for the beginning of a signal
        assert anchor in [0, 'left'], (padding, anchor)
    self.padding = padding
    self.flatten_separator = flatten_separator
def get_out_lengths(self, in_lengths):
    """Propagate sequence lengths through all conv (and pooling) layers.

    After each conv's ``get_out_lengths``, an active pooling layer divides
    the lengths by the last pool size and rounds down if the last pad side
    of that layer is ``None``, up otherwise.

    Raises:
        NotImplementedError: if pooling is active in a transposed network.
    """
    lengths = in_lengths
    for layer_idx, conv in enumerate(self.convs):
        lengths = conv.get_out_lengths(lengths)
        if self.pool_types[layer_idx] is None:
            continue
        if self.is_transpose():
            raise NotImplementedError
        lengths = lengths / to_list(self.pool_sizes[layer_idx])[-1]
        last_pad_side = to_list(self.pad_sides[layer_idx])[-1]
        round_fn = np.floor if last_pad_side is None else np.ceil
        lengths = round_fn(lengths)
    return lengths
def forward(self, *tensors, seq_len=None, seq_axes=-1):
    """Randomly cut off the front and end of the sequences during training.

    Args:
        tensors: features (BxFxT); all tensors are cut by the same amounts
        seq_len: sequence lengths of the batch entries. If ``None``, the
            maximum cutoff is derived from the full length ``T`` and
            ``None`` is returned in its place.
        seq_axes: sequence axis of each tensor (expanded via ``to_list``)

    Returns:
        ``(*tensors, seq_len)``. Outside of training the inputs are passed
        through unchanged.
    """
    if self.training:
        seq_axes = to_list(seq_axes, len(tensors))
        T = tensors[0].shape[seq_axes[0]]
        shortest = T if seq_len is None else min(seq_len)
        max_cutoff = int(self.max_cutoff_rate * shortest)
        cutoff_front = int(np.random.rand() * (max_cutoff + 1))
        cutoff_end = int(np.random.rand() * (max_cutoff + 1))
        # Number of frames that remain after removing both cutoffs.
        remaining = T - (cutoff_front + cutoff_end)
        if seq_len is not None:
            # np.int was removed from NumPy; the builtin int is equivalent.
            seq_len = np.minimum(
                np.array(seq_len) - cutoff_front, remaining
            ).astype(int)
        # narrow(dim, start, length): the third argument is a LENGTH, not
        # an end index, so both cutoffs have to be subtracted from T.
        tensors = [
            tensor.narrow(seq_axis, cutoff_front, remaining)
            for tensor, seq_axis in zip(tensors, seq_axes)
        ]
    return (*tensors, seq_len)
def review(self, inputs, outputs):
    """Compute the BCE loss and collect metrics for the summary.

    Returns:
        dict with the mean loss, the scalars ``labels`` and
        ``label_ranked_precisions`` plus true/false positive and false
        negative counts per decision boundary, and the first three feature
        entries as images.
    """
    # compute loss
    targets = inputs[self.label_key]
    if outputs.dim() == 3:  # (B, T, K)
        if targets.dim() == 2:  # (B, K)
            targets = targets.unsqueeze(1).expand(outputs.shape)
        # Merge the two leading dims so that each row is one prediction.
        outputs = outputs.contiguous().view((-1, outputs.shape[-1]))
        targets = targets.contiguous().view((-1, targets.shape[-1]))
    criterion = nn.BCELoss(reduction='none')
    bce = criterion(outputs, targets).sum(-1)

    # create review including metrics and visualizations
    labels, label_ranked_precisions = positive_class_precisions(
        targets.cpu().data.numpy(), outputs.cpu().data.numpy())
    scalars = {
        'labels': labels,
        'label_ranked_precisions': label_ranked_precisions,
    }
    result = {
        'loss': bce.mean(),
        'scalars': scalars,
        'images': {'features': inputs[self.feature_key][:3]},
    }

    for boundary in to_list(self.decision_boundary):
        decision = (outputs.detach() > boundary).float()
        missed = 1. - decision
        scalars[f'true_pos_{boundary}'] = (decision * targets).sum()
        scalars[f'false_pos_{boundary}'] = (decision * (1. - targets)).sum()
        scalars[f'false_neg_{boundary}'] = (missed * targets).sum()
    return result
def initialize_labels(
        self, labels=None, dataset=None, dataset_name=None, verbose=False
):
    """Set up ``label_mapping`` and ``inverse_label_mapping``.

    If ``self.storage_dir`` is set and the label file already exists, the
    labels are restored from it (and asserted to equal ``labels`` if
    given). Otherwise the given ``labels`` are used, or they are collected
    from ``dataset`` and sorted, and written to the label file when a
    storage dir is set.
    """
    if dataset_name is None:
        filename = f"{self.label_key}.json"
    else:
        filename = f"{self.label_key}_{dataset_name}.json"

    if self.storage_dir is None:
        filepath = None
    else:
        filepath = (self.storage_dir / filename).expanduser().absolute()

    if filepath and Path(filepath).exists():
        with filepath.open() as fid:
            stored_labels = json.load(fid)
        if verbose:
            print(f'Restored labels from {filepath}')
        if labels is not None:
            assert stored_labels == labels
        labels = stored_labels
    else:
        if labels is None:
            collected = set()
            for example in dataset:
                collected.update(to_list(example[self.label_key]))
            labels = sorted(collected)
        if filepath:
            with filepath.open('w') as fid:
                json.dump(labels, fid, indent=4)
            if verbose:
                print(f'Saved labels to {filepath}')

    mapping = {}
    for index, label in enumerate(labels):
        mapping[label] = index
    self.label_mapping = mapping
    self.inverse_label_mapping = {
        index: label for label, index in mapping.items()
    }
def _count_labels(self, raw_datasets, label_key, label_counts=None, reps=1):
    """Count label occurrences in a (nested list of) dataset(s).

    Args:
        raw_datasets: either a dataset yielding examples or a list of
            ``(dataset, repetitions)`` pairs, which is processed
            recursively.
        label_key: key of the labels within an example.
        label_counts: running counts; created when ``None``.
        reps: weight added per occurrence of a label.

    Returns:
        ``(label_counts, labels)`` where ``labels`` mirrors the nesting of
        ``raw_datasets`` and holds the sorted unique labels per example.
    """
    if label_counts is None:
        label_counts = defaultdict(lambda: 0)

    collected = []
    if isinstance(raw_datasets, list):
        for dataset, dataset_reps in raw_datasets:
            label_counts, nested = self._count_labels(
                dataset, label_key,
                label_counts=label_counts, reps=dataset_reps * reps)
            collected.append(nested)
        return label_counts, collected

    for example in raw_datasets:
        unique_labels = sorted(set(to_list(example[label_key])))
        collected.append(unique_labels)
        for label in unique_labels:
            label_counts[label] += reps
    return label_counts, collected
def trim_padded_or_pad_trimmed(self, y, out_shape=None):
    """Bring the output of a transposed convolution to ``out_shape``.

    Dimensions that are longer than requested are trimmed, shorter ones
    are zero padded ('constant' mode).

    Args:
        y: output tensor; its first two dims have to match ``out_shape``
        out_shape: desired shape (any sequence type). If ``None``, ``y`` is
            returned unchanged.

    Returns:
        ``y`` with its trailing dims adjusted to ``out_shape[2:]``.

    Raises:
        NotImplementedError: if ``out_shape`` is ``None`` although a pad
            side is set.
    """
    assert self.is_transpose()
    if out_shape is not None:
        # tuple() makes the comparison work for lists and arrays as well.
        assert tuple(y.shape[:2]) == tuple(out_shape)[:2], (
            y.shape, out_shape)
        size = np.array(y.shape[2:]) - np.array(out_shape[2:])
        pad_side = [
            # if no padding has been used both sides have been trimmed
            'both' if side is None else side
            for side in to_list(self.pad_side)
        ]
        # Clip so that each op only sees the dims it has to change; mixed
        # signs would otherwise hand negative sizes to Trim/Pad.
        if any(size > 0):
            y = Trim(side=pad_side)(y, size=np.maximum(size, 0))
        if any(size < 0):
            y = Pad(side=pad_side, mode='constant')(
                y, size=np.maximum(-size, 0))
    elif any([side is not None for side in to_list(self.pad_side)]):
        raise NotImplementedError
    return y
def get_out_lengths(self, in_lengths):
    """Compute the sequence lengths after this convolution.

    If the last pad side is ``None`` (no padding)::

        L_out = ceil((L_in - dilation * (kernel_size - 1)) / stride)

    otherwise::

        L_out = ceil(L_in / stride)

    where dilation, kernel_size and stride are the values of the last
    dimension.

    Args:
        in_lengths: one-dimensional sequence of input lengths

    Returns:
        ``np.int64`` array of output lengths.

    Raises:
        NotImplementedError: for transposed convolutions.
    """
    out_lengths = np.array(in_lengths)
    assert out_lengths.ndim == 1, out_lengths.ndim
    if self.is_transpose():
        raise NotImplementedError
    if to_list(self.pad_side)[-1] is None:
        # Without padding the receptive field shortens the sequence.
        out_lengths = out_lengths - (
            to_list(self.dilation)[-1]
            * (to_list(self.kernel_size)[-1] - 1)
        )
    out_lengths = np.ceil(out_lengths / to_list(self.stride)[-1])
    return out_lengths.astype(np.int64)
def forward(self, x, size):
    """Pad the trailing dimension(s) of ``x`` by ``size``.

    ``self.side`` and ``size`` are expanded to one entry per trailing dim.
    Dims with side ``None`` or a size below one are left unpadded.

    Raises:
        ValueError: for an unknown pad side.
    """
    num_trailing = x.dim() - 2
    sides = to_list(self.side, num_trailing)
    sizes = to_list(size, num_trailing)

    # F.pad expects (front, end) pairs starting with the LAST dimension.
    amounts = []
    for cur_side, cur_size in reversed(list(zip(sides, sizes))):
        if cur_side is None or cur_size < 1:
            front, end = 0, 0
        elif cur_side == 'front':
            front, end = cur_size, 0
        elif cur_side == 'both':
            # odd sizes put the extra element at the end
            front, end = cur_size // 2, math.ceil(cur_size / 2)
        elif cur_side == 'end':
            front, end = 0, cur_size
        else:
            raise ValueError(f'pad side {cur_side} unknown')
        amounts.extend([front, end])

    return F.pad(x, tuple(amounts), mode=self.mode)
def forward(self, x, size):
    """Trim the trailing dimension(s) of ``x`` by ``size``.

    ``self.side`` and ``size`` are expanded to one entry per trailing dim.
    Dims with side ``None`` or a size below one are left untouched.

    Raises:
        ValueError: for an unknown side.
    """
    num_dims = x.dim()
    sides = to_list(self.side, num_dims - 2)
    sizes = to_list(size, num_dims - 2)

    index = [slice(None) for _ in range(num_dims)]
    # The first two dims are never trimmed, hence start=2.
    for dim, (cur_side, cur_size) in enumerate(zip(sides, sizes), start=2):
        if cur_side is None or cur_size < 1:
            continue
        if cur_side == 'front':
            index[dim] = slice(cur_size, x.shape[dim])
        elif cur_side == 'both':
            # odd sizes remove the extra element at the end
            index[dim] = slice(cur_size // 2, -math.ceil(cur_size / 2))
        elif cur_side == 'end':
            index[dim] = slice(0, -cur_size)
        else:
            raise ValueError
    return x[tuple(index)]
def trim_padded_or_pad_trimmed(self, y, out_shape=None):
    """Bring the output of a transposed convolution to ``out_shape``.

    Dimensions that are longer than requested are trimmed, shorter ones
    are zero padded ('constant' mode).

    Args:
        y: output tensor; its first two dims have to match ``out_shape``
        out_shape: desired shape. If ``None``, ``y`` is returned unchanged.

    Returns:
        ``y`` with its trailing dims adjusted to ``out_shape[2:]``.

    Raises:
        NotImplementedError: if ``out_shape`` is ``None`` although a pad
            side is set (trimming the minimal padding that could have
            occurred is not implemented).
    """
    assert self.is_transpose()
    if out_shape is not None:
        assert y.shape[:2] == tuple(out_shape)[:2], (y.shape, out_shape)
        size = np.array(y.shape[2:]) - np.array(out_shape[2:])
        pad_side = [
            # if no padding has been used both sides have been trimmed
            'both' if side is None else side
            for side in to_list(self.pad_side)
        ]
        # Clip so that each op only sees the dims it has to change; mixed
        # signs would otherwise hand negative sizes to Trim/Pad.
        if any(size > 0):
            y = Trim(side=pad_side)(y, size=np.maximum(size, 0))
        if any(size < 0):
            y = Pad(side=pad_side, mode='constant')(
                y, size=np.maximum(-size, 0))
    elif any([side is not None for side in to_list(self.pad_side)]):
        raise NotImplementedError
    return y
def _get_labels(self, example):
    """Return the labels of ``example`` as a list.

    With ``self.multi_hot_encoded_labels`` the labels are expected as a
    one- or two-dimensional array; a two-dimensional array is summed over
    its last axis and the indices of positive entries are returned.
    """
    raw = example[self.label_key]
    if self.multi_hot_encoded_labels:
        num_dims = raw.ndim
        assert num_dims >= 1, raw.shape
        if num_dims > 1:
            assert num_dims == 2, raw.shape
            raw = raw.sum(-1)
        raw = np.argwhere(raw > 0).flatten()
    return to_list(raw.tolist() if isinstance(raw, np.ndarray) else raw)
def compute_conv_output_sequence_lengths(input_sequence_lengths, kernel_size,
                                         dilation, pad_type, stride,
                                         transpose=False):
    """Compute sequence lengths after a (transposed) convolution.

    Only the last entry of kernel_size, dilation, stride and pad_type is
    used. The result of ``_compute_transpose_out_size`` (if ``transpose``)
    or ``_compute_conv_out_size`` is asserted to be positive and returned
    as ``np.int64``.
    """
    last_kernel = to_list(kernel_size)[-1]
    last_dilation = to_list(dilation)[-1]
    last_stride = to_list(stride)[-1]
    last_pad_type = to_list(pad_type)[-1]

    if transpose:
        seq_len_out = _compute_transpose_out_size(
            input_sequence_lengths, last_kernel, last_dilation, last_stride,
            last_pad_type)
    else:
        seq_len_out = _compute_conv_out_size(
            input_sequence_lengths, last_kernel, last_dilation, last_stride,
            last_pad_type)

    assert all(seq_len_out > 0), seq_len_out
    return seq_len_out.astype(np.int64)
def pad(self, x):
    """Add padding to the front and end of ``x``.

    The front/end sizes are obtained from ``compute_pad_size`` using the
    kernel size, dilation, stride and pad type of this layer.

    Args:
        x: input tensor of shape b,c,(f,)t

    Returns:
        ``x`` padded at the front and/or end where the size is positive.
    """
    pad_sizes = [
        compute_pad_size(k, d, s, t)
        for k, d, s, t in zip(
            to_list(self.kernel_size, 1),
            to_list(self.dilation, 1),
            to_list(self.stride, 1),
            to_list(self.pad_type, 1),
        )
    ]
    front_pad = tuple(front for front, _ in pad_sizes)
    end_pad = tuple(end for _, end in pad_sizes)

    if any(np.array(front_pad) > 0):
        x = Pad(side='front')(x, size=front_pad)
    if any(np.array(end_pad) > 0):
        x = Pad(side='end')(x, size=end_pad)
    return x
def pad_or_trim(self, x):
    """Pad dims that have a pad side and trim the remaining dims.

    Padded dims receive ``dilation * (kernel_size - 1)`` minus the
    remainder of ``(length - 1) % stride``. Dims without a pad side are
    trimmed on both sides by ``(length - kernel_size) % stride``.

    Args:
        x: input tensor; ``x.shape[2:]`` are the dims to be adjusted

    Returns:
        the padded and/or trimmed tensor
    """
    pad_dims = [side is not None for side in to_list(self.pad_side)]
    if any(pad_dims):
        size = (
            np.array(self.dilation) * (np.array(self.kernel_size) - 1)
            - ((np.array(x.shape[2:]) - 1) % np.array(self.stride))
        ).tolist()
        x = Pad(side=self.pad_side)(x, size=size)
    if not all(pad_dims):
        size = (
            (np.array(x.shape[2:]) - np.array(self.kernel_size))
            % np.array(self.stride)
        ).tolist()
        # Pass a concrete list: a generator could only be consumed once.
        trim_side = ['both' if not pad_dim else None for pad_dim in pad_dims]
        x = Trim(side=trim_side)(x, size)
    return x
def get_out_shape(self, in_shape):
    """Propagate ``in_shape`` through all conv (and pooling) layers.

    After each conv the channel entry is replaced by the next entry of
    ``self.layer_in_channels``. Active pooling divides the trailing dims by
    the pool sizes, rounding down for dims whose pad side is ``None`` and
    up otherwise.

    Raises:
        NotImplementedError: if pooling is active in a transposed network.
    """
    assert in_shape[1] == self.in_channels, (in_shape[1], self.in_channels)
    shape = in_shape
    for layer_idx, conv in enumerate(self.convs):
        shape = conv.get_out_shape(shape)
        shape[1] = self.layer_in_channels[layer_idx + 1]
        if self.pool_types[layer_idx] is None:
            continue
        if self.is_transpose():
            raise NotImplementedError
        pooled = shape[2:] / np.array(self.pool_sizes[layer_idx])
        round_down = [
            side is None for side in to_list(self.pad_sides[layer_idx])
        ]
        shape[2:] = np.where(round_down, np.floor(pooled), np.ceil(pooled))
    return shape
def get_out_shape(self, in_shape):
    """Compute the output shape of this convolution for ``in_shape``.

    The channel entry becomes ``self.out_channels``. Dims without padding
    are first reduced by ``dilation * (kernel_size - 1)``; afterwards all
    trailing dims are divided by the stride and rounded up.

    Returns:
        the output shape as ``np.int64`` array

    Raises:
        NotImplementedError: for transposed convolutions.
    """
    shape = np.array(in_shape)
    assert len(shape) == 3 + self.is_2d(), (len(shape), self.is_2d())
    assert in_shape[1] == self.in_channels, (in_shape[1], self.in_channels)
    shape[1] = self.out_channels
    if self.is_transpose():
        raise NotImplementedError

    reduction = np.array(self.dilation) * (np.array(self.kernel_size) - 1)
    unpadded = [side is None for side in to_list(self.pad_side)]
    shape[2:] = np.where(unpadded, shape[2:] - reduction, shape[2:])
    shape[2:] = np.ceil(shape[2:] / np.array(self.stride))
    return shape.astype(np.int64)
def get_shapes(self, in_shape):
    """Return the input shape followed by the output shape of each layer.

    The per-layer shape is obtained as in ``get_out_shape``: the conv's
    output shape with the channel entry replaced by the next entry of
    ``self.layer_in_channels`` and, if pooling is active, trailing dims
    divided by the pool sizes (rounded down where the pad side is ``None``,
    up otherwise).

    Raises:
        NotImplementedError: if pooling is active in a transposed network.
    """
    assert in_shape[1] == self.in_channels, (in_shape[1], self.in_channels)
    all_shapes = [in_shape]
    shape = in_shape
    for layer_idx, conv in enumerate(self.convs):
        shape = conv.get_out_shape(shape)
        # has to be adjusted with dense skip connections
        shape[1] = self.layer_in_channels[layer_idx + 1]
        if self.pool_types[layer_idx] is not None:
            if self.is_transpose():
                raise NotImplementedError
            pooled = shape[2:] / np.array(self.pool_sizes[layer_idx])
            round_down = [
                side is None for side in to_list(self.pad_sides[layer_idx])
            ]
            shape[2:] = np.where(
                round_down, np.floor(pooled), np.ceil(pooled))
        all_shapes.append(shape)
    return all_shapes
def modify_summary(self, summary):
    """Turn the collected raw scalars into metrics and images into grids.

    * ``labels`` and ``label_ranked_precisions`` are replaced by ``lwlrap``.
    * The true/false positive and false negative counts of each decision
      boundary are replaced by precision, recall and f1.
    * Each image is reduced to a single channel and converted into a grid
      with ``make_grid``.
    """
    scalars = summary['scalars']

    # compute lwlrap
    if 'labels' in scalars and 'label_ranked_precisions' in scalars:
        labels = scalars.pop('labels')
        label_ranked_precisions = scalars.pop('label_ranked_precisions')
        scalars['lwlrap'] = lwlrap_from_precisions(
            label_ranked_precisions, labels)[0]

    # compute precision, recall and fscore for each decision boundary
    for boundary in to_list(self.decision_boundary):
        count_keys = [
            f'true_pos_{boundary}',
            f'false_pos_{boundary}',
            f'false_neg_{boundary}',
        ]
        if not all([key in scalars for key in count_keys]):
            continue
        tp, fp, fn = [np.sum(scalars.pop(key)) for key in count_keys]
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        scalars[f'precision_{boundary}'] = precision
        scalars[f'recall_{boundary}'] = recall
        scalars[f'f1_{boundary}'] = (
            2 * (precision * recall) / (precision + recall))

    summary = super().modify_summary(summary)

    images = summary['images']
    for key, image in images.items():
        if image.dim() == 4 and image.shape[1] > 1:
            # keep only the first channel
            image = image[:, 0]
        if image.dim() == 3:
            image = image.unsqueeze(1)
        images[key] = make_grid(
            image.flip(2), normalize=True, scale_each=False, nrow=1)
    return summary
def finalize_dogmatic_config(cls, config):
    """Fill ``config`` with default factories and derived input sizes.

    Sets the factories for ``cnn_2d``, ``cnn_1d``, ``enc`` and ``fcn`` and
    propagates ``config['input_size']`` through the sub-configs that are
    not ``None``. ``config['pool']`` is set to ``None``.

    NOTE(review): this presumably relies on ``config`` being a special
    config object whose entries can be overridden/extended on assignment -
    confirm against the configurable framework in use.
    """
    config['cnn_2d'] = {'factory': CNN2d}
    config['cnn_1d'] = {'factory': CNN1d}
    config['enc'] = {'factory': nn.GRU}
    config['fcn'] = {'factory': fully_connected_stack}
    input_size = config['input_size']
    if config['cnn_2d'] is not None and input_size is not None:
        # Instantiate the 2d CNN to derive the input size of the next stage
        # from its output shape and its number of output channels.
        cnn_2d = config['cnn_2d']['factory'].from_config(config['cnn_2d'])
        # NOTE(review): 1000 looks like a dummy sequence length - confirm.
        output_size = cnn_2d.get_out_shape((input_size, 1000))[0]
        out_channels = cnn_2d.out_channels \
            if cnn_2d.out_channels is not None \
            else cnn_2d.hidden_channels[-1]
        input_size = out_channels * output_size
    if config['cnn_1d'] is not None and input_size is not None:
        config['cnn_1d']['in_channels'] = input_size
        # The 1d CNN's output channels become the encoder's input size.
        input_size = config['cnn_1d']['out_channels'] \
            if config['cnn_1d']['out_channels'] is not None \
            else to_list(config['cnn_1d']['hidden_channels'])[-1]
    if config['enc'] is not None:
        if config['enc']['factory'] == nn.GRU:
            config['enc'].update({
                'num_layers': 1,
                'bias': True,
                'batch_first': True,
                'dropout': 0.,
                'bidirectional': False
            })
        elif config['enc']['factory'] == Transformer:
            # NOTE(review): reads config['cnn'] although only 'cnn_2d' and
            # 'cnn_1d' are set in this function - confirm the key exists.
            config['enc']['norm'] = config['cnn']['norm']
        if input_size is not None:
            config['enc']['input_size'] = input_size
    if config['fcn'] is not None:
        # NOTE(review): accesses config['enc'] without a None check.
        config['fcn']['input_size'] = config['enc']['hidden_size']
    config['pool'] = None
def __call__(self, example):
    """Return True iff none of ``self.names`` occurs in ``example[self.key]``."""
    for name in self.names:
        if name in to_list(example[self.key]):
            return False
    return True
def __init__(self, key, names):
    """Store the example key and the names (converted with ``to_list``).

    Args:
        key: stored as ``self.key``
        names: a single name or several names, stored as ``self.names``
    """
    names_list = to_list(names)
    self.key, self.names = key, names_list
def fragment_signal(*signals, axis, step, fragment_length,
                    onset_mode='center'):
    """Cut one or more signals into fragments of fixed length.

    Args:
        signals: arrays to be fragmented
        axis: axis to fragment along (one value or one per signal)
        step: hop between fragment starts (one value or one per signal)
        fragment_length: fragment length (one value or one per signal)
        onset_mode: where the first fragment starts.
            'front': at index 0.
            'end': such that the unused remainder lies at the front.
            'center': half of the 'end' offset.
            'random': random offset that still yields at least one
            fragment where possible.

    Returns:
        A list of fragments if a single signal is given, otherwise a tuple
        with one list of fragments per signal (all of equal length).

    Raises:
        ValueError: if ``onset_mode`` is unknown.

    >>> signal = np.arange(20).reshape((2, 10))
    >>> [f.tolist() for f in fragment_signal(signal, axis=1, step=4, fragment_length=4, onset_mode='front')]
    [[[0, 1, 2, 3], [10, 11, 12, 13]], [[4, 5, 6, 7], [14, 15, 16, 17]]]
    >>> [f.tolist() for f in fragment_signal(signal, axis=1, step=4, fragment_length=4, onset_mode='center')]
    [[[1, 2, 3, 4], [11, 12, 13, 14]], [[5, 6, 7, 8], [15, 16, 17, 18]]]
    >>> [f.tolist() for f in fragment_signal(signal, axis=1, step=4, fragment_length=4, onset_mode='end')]
    [[[2, 3, 4, 5], [12, 13, 14, 15]], [[6, 7, 8, 9], [16, 17, 18, 19]]]
    """
    # Fail early with a clear error instead of an unbound ``start`` below.
    if onset_mode not in ('front', 'random', 'center', 'end'):
        raise ValueError(
            f'Unknown onset_mode {onset_mode!r}, expected one of '
            f"'front', 'random', 'center', 'end'."
        )

    axis = to_list(axis, len(signals))
    step = to_list(step, len(signals))
    fragment_length = to_list(fragment_length, len(signals))

    # ``start`` is the onset expressed as a fraction of the step size so
    # that it can be shared by signals with different resolutions.
    if onset_mode == 'front':
        start = 0.
    elif onset_mode == 'random':
        # find max start such that at least one segment is obtained
        max_start = 1.
        for x, ax, stp, frag_len in zip(
                signals, axis, step, fragment_length):
            max_start = max(
                min(max_start, (x.shape[ax] - frag_len) / stp), 0.)
        start = np.random.rand()
        start *= max_start
    else:  # 'center' or 'end'
        start = 1.
        for x, ax, stp, frag_len in zip(
                signals, axis, step, fragment_length):
            tail = (x.shape[ax] - frag_len) % stp
            start = min(start, tail / stp)
        if onset_mode == 'center':
            start = start / 2

    # adjust start to match an integer index for all keys
    for _, stp in zip(signals, step):
        start = int(start * stp) / stp

    fragmented_signals = []
    for x, ax, stp, frag_len in zip(signals, axis, step, fragment_length):
        assert ax < x.ndim, (ax, x.ndim)
        start_idx = round(start * stp)
        assert abs(start_idx - start * stp) < 1e-6, (start_idx, start * stp)
        fragments = []
        for idx in np.arange(start_idx, x.shape[ax] - frag_len + 1, stp):
            slc = [slice(None)] * x.ndim
            slc[ax] = slice(int(idx), int(idx + frag_len))
            fragments.append(x[tuple(slc)])
        fragmented_signals.append(fragments)

    if len(signals) == 1:
        return fragmented_signals[0]
    assert len(set([len(sig) for sig in fragmented_signals])) == 1, (
        [sig.shape for sig in signals], step, fragment_length,
        [len(sig) for sig in fragmented_signals]
    )
    return (*fragmented_signals, )
def fragment_signal(*signals, axis, step, max_length, min_length=1,
                    random_start=False):
    """Cut one or more signals into fragments.

    Args:
        signals: arrays to be fragmented
        axis: axis to fragment along (one value or one per signal)
        step: hop between fragment starts (one value or one per signal)
        max_length: length of a full fragment; fragments at the end of a
            signal may be shorter
        min_length: a fragment is only started if at least ``min_length``
            entries are left
        random_start: if True, the first fragment starts at a random offset
            that still yields at least one full fragment where possible

    Returns:
        A list of fragments if a single signal is given, otherwise a tuple
        with one list of fragments per signal (all of equal length).

    >>> signal = np.arange(20).reshape((2, 10))
    >>> [f.tolist() for f in fragment_signal(signal, axis=1, step=4, max_length=4)]
    [[[0, 1, 2, 3], [10, 11, 12, 13]], [[4, 5, 6, 7], [14, 15, 16, 17]], [[8, 9], [18, 19]]]
    >>> [f.tolist() for f in fragment_signal(signal, axis=1, step=4, max_length=4, min_length=4)]
    [[[0, 1, 2, 3], [10, 11, 12, 13]], [[4, 5, 6, 7], [14, 15, 16, 17]]]
    """
    axis = to_list(axis, len(signals))
    step = to_list(step, len(signals))
    max_length = to_list(max_length, len(signals))
    min_length = to_list(min_length, len(signals))

    # ``start`` is the onset expressed as a fraction of the step size so
    # that it can be shared by signals with different resolutions.
    if random_start:
        start = np.random.rand()
        # find max start such that at least one segment is obtained
        max_start = 1.
        for x, ax, stp, max_len in zip(signals, axis, step, max_length):
            max_start = max(
                min(max_start, (x.shape[ax] - max_len) / stp), 0.)
        start *= max_start
        # adjust start to match an integer index for all keys
        for _, stp in zip(signals, step):
            start = int(start * stp) / stp
    else:
        start = 0.

    fragmented_signals = []
    for x, ax, stp, min_len, max_len in zip(
            signals, axis, step, min_length, max_length):
        assert ax < x.ndim, (ax, x.ndim)
        assert max_len >= min_len
        start_idx = round(start * stp)
        assert abs(start_idx - start * stp) < 1e-6, (start_idx, start * stp)
        fragments = []
        for idx in np.arange(start_idx, x.shape[ax] - min_len + 1, stp):
            slc = [slice(None)] * x.ndim
            slc[ax] = slice(int(idx), int(idx + max_len))
            fragments.append(x[tuple(slc)])
        fragmented_signals.append(fragments)

    if len(signals) == 1:
        # Return the fragments, not the untouched input signal.
        return fragmented_signals[0]
    assert len(set([len(sig) for sig in fragmented_signals])) == 1, (
        [sig.shape for sig in signals],
        [len(sig) for sig in fragmented_signals]
    )
    return (*fragmented_signals, )
def get_transpose_config(cls, config, transpose_config=None):
    """Derive the config of the transposed counterpart of ``config``.

    The factory is swapped (CNN1d <-> CNNTranspose1d, CNN2d <->
    CNNTranspose2d), the channel list is reversed, residual and dense
    connections are mirrored and list-valued layer options are reversed.

    Args:
        config: config whose ``'factory'`` has to equal ``cls``
        transpose_config: optional dict that is updated in place

    Returns:
        the (updated) ``transpose_config``
    """
    assert config['factory'] == cls
    if transpose_config is None:
        transpose_config = dict()

    if config['factory'] == CNN1d:
        transpose_config['factory'] = CNNTranspose1d
    if config['factory'] == CNNTranspose1d:
        transpose_config['factory'] = CNN1d
    if config['factory'] == CNN2d:
        transpose_config['factory'] = CNNTranspose2d
    if config['factory'] == CNNTranspose2d:
        transpose_config['factory'] = CNN2d

    channels = [config['in_channels']] + config['out_channels']
    num_layers = len(config['out_channels'])

    def per_layer_connections(skip_connections):
        # One entry per layer: None, a single index or a list of indices.
        # The length check has to look at the layer's own list; checking
        # the number of dict keys would drop connections.
        return [
            None if i not in skip_connections
            else skip_connections[i][0] if len(skip_connections[i]) == 1
            else skip_connections[i]
            for i in range(num_layers)
        ]

    if 'residual_connections' in config.keys() \
            and config['residual_connections'] is not None:
        skip_connections = defaultdict(list)
        for src_idx, dst_indices in enumerate(
                to_list(config['residual_connections'], num_layers)):
            for dst_idx in to_list(dst_indices):
                if dst_idx is not None:
                    # Mirror the connection: destination becomes source.
                    skip_connections[num_layers - dst_idx].append(
                        num_layers - src_idx)
        transpose_config['residual_connections'] = per_layer_connections(
            skip_connections)

    if 'dense_connections' in config.keys() \
            and config['dense_connections'] is not None:
        skip_connections = defaultdict(list)
        for src_idx, dst_indices in enumerate(
                to_list(config['dense_connections'], num_layers)):
            for dst_idx in to_list(dst_indices):
                if dst_idx is not None:
                    skip_connections[num_layers - dst_idx].append(
                        num_layers - src_idx)
                    # Dense connections change the channel counts.
                    if cls.is_transpose():
                        channels[src_idx] -= channels[dst_idx]
                    else:
                        channels[dst_idx] += channels[src_idx]
        transpose_config['dense_connections'] = per_layer_connections(
            skip_connections)

    transpose_config['in_channels'] = channels[-1]
    transpose_config['out_channels'] = channels[:-1][::-1]

    # NOTE(review): 'pool_size' is singular while the other per-layer keys
    # are plural - confirm this matches the config key actually in use.
    for kw in [
        'kernel_size', 'pad_sides', 'dilation', 'stride', 'pool_types',
        'pool_size'
    ]:
        if kw not in config.keys():
            continue
        if isinstance(config[kw], list):
            # Per-layer lists are applied in reverse layer order.
            transpose_config[kw] = config[kw][::-1]
        else:
            transpose_config[kw] = config[kw]
    return transpose_config
def forward(self, x, seq_len=None, out_shapes=None, out_lengths=None, pool_indices=None):
    """Run ``x`` through all conv layers including skips and (un)pooling.

    Args:
        x: input tensor
        seq_len: sequence lengths, passed to and updated by each layer
        out_shapes: must be ``None`` unless this is a transposed network;
            passed to the convs as ``out_shape`` in reversed order
        out_lengths: must be ``None`` unless this is a transposed network;
            passed to the convs as ``out_lengths`` in reversed order
        pool_indices: must be ``None`` unless this is a transposed network;
            passed to ``maybe_unpool`` in reversed order

    Returns:
        ``(x, seq_len)`` or, if ``self.return_pool_data`` is set,
        ``(x, seq_len, shapes, lengths, pool_indices)`` where ``shapes`` and
        ``lengths`` hold the input shape/lengths of each layer.
    """
    if not self.is_transpose():
        assert out_shapes is None, out_shapes
        assert out_lengths is None, out_lengths
        assert pool_indices is None, pool_indices.shape
    # One entry per layer, in reversed order (copies, so that the caller's
    # objects are not modified by the assignments below).
    shapes = to_list(copy(out_shapes), self.num_layers)[::-1]
    lengths = to_list(copy(out_lengths), self.num_layers)[::-1]
    pool_indices = to_list(copy(pool_indices), self.num_layers)[::-1]
    # Signals waiting to be merged, keyed by the destination layer index.
    residual_skip_signals = defaultdict(list)
    dense_skip_signals = defaultdict(list)
    for i, conv in enumerate(self.convs):
        x, seq_len = self.maybe_unpool(
            x,
            pool_type=self.pool_types[i],
            pool_size=self.pool_sizes[i],
            seq_len=seq_len,
            pool_indices=pool_indices[i],
        )
        if self.residual_connections[i] is not None:
            for dst_idx in self.residual_connections[i]:
                residual_skip_signals[dst_idx].append((i, x))
        if self.dense_connections[i] is not None:
            for dst_idx in sorted(self.dense_connections[i]):
                if self.is_transpose():
                    # Split off the channels belonging to the skip path.
                    x, x_skip = torch.split(x, [
                        self.layer_in_channels[i],
                        self.out_channels[dst_idx - 1]
                    ], dim=1)
                    dense_skip_signals[dst_idx].append((i, x_skip))
                else:
                    dense_skip_signals[dst_idx].append((i, x))
        in_shape = x.shape
        in_lengths = seq_len
        x, seq_len = conv(x, seq_len=seq_len, out_shape=shapes[i], out_lengths=lengths[i])
        # Replace the consumed target shape/lengths by this layer's input
        # shape/lengths; they are returned with return_pool_data.
        shapes[i] = in_shape
        lengths[i] = in_lengths
        for src_idx, x_ in dense_skip_signals[i + 1]:
            # Resize the skip signal to the current trailing dims.
            x_ = F.interpolate(x_, size=x.shape[2:])
            if self.is_transpose():
                x = x + x_
            else:
                x = torch.cat((x, x_), dim=1)
        for src_idx, x_ in residual_skip_signals[i + 1]:
            x_ = F.interpolate(x_, size=x.shape[2:])
            # Optional conv on the residual path, looked up by 'src->dst'.
            if f'{src_idx}->{i+1}' in self.residual_convs:
                x_, _ = self.residual_convs[f'{src_idx}->{i + 1}'](x_)
            x = x + x_
        x, seq_len, pool_indices[i] = self.maybe_pool(
            x,
            pool_type=self.pool_types[i],
            pool_size=self.pool_sizes[i],
            pad_side=self.pad_sides[i],
            seq_len=seq_len)
    if self.return_pool_data:
        return x, seq_len, shapes, lengths, pool_indices
    return x, seq_len