def forward(self, x):

    def _inner_forward(x):
        x = [s.chunk(2, dim=1) for s in x]
        x1 = [s[0] for s in x]
        x2 = [s[1] for s in x]

        x2 = self.cross_resolution_weighting(x2)
        x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
        x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]

        out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
        out = [channel_shuffle(s, 2) for s in out]

        return out

    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)

    return out

def forward(self, x):

    def _inner_forward(x):
        x = self.conv1(x)
        x1, x2 = x.chunk(2, dim=1)

        x2 = self.expand_conv(x2)
        x2 = self.depthwise_conv(x2)
        x2 = self.linear_conv(x2)

        out = torch.cat((self.branch1(x1), x2), dim=1)
        out = channel_shuffle(out, 2)

        return out

    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)

    return out

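# The two blocks above rely on a `channel_shuffle` helper that is defined
# elsewhere. Below is a minimal sketch of the usual ShuffleNet-style
# implementation, included for reference; the exact version used by the
# original code base may differ.
import torch


def channel_shuffle(x, groups):
    """Shuffle channels so information mixes across the `groups` branches."""
    batch_size, num_channels, height, width = x.size()
    assert num_channels % groups == 0, 'num_channels must be divisible by groups'
    channels_per_group = num_channels // groups
    # (B, C, H, W) -> (B, groups, C // groups, H, W)
    x = x.view(batch_size, groups, channels_per_group, height, width)
    # swap the group and per-group channel axes, then flatten back to (B, C, H, W)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batch_size, -1, height, width)
    return x
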
def forward(self, x, edge_index):
    start, end = edge_index

    x = self.input_network(x)

    # Loop over iterations of edge and node networks
    for i in range(self.hparams["n_graph_iters"]):
        x_initial = x
        x = checkpoint(self.run_inner_loop, (x, start, end))
        # Residual connection
        x = x_initial + x
        # print("5:", torch.cuda.max_memory_allocated() / 1024**3)
        # torch.cuda.reset_peak_memory_stats()

    edge_inputs = torch.cat([x[start], x[end]], dim=1)
    return self.edge_network(edge_inputs)

def forward(self, tensor, **kwargs):
    batch_size, input_len, channels = tensor.shape
    assert not (self.decoder_mode and "embeddings" not in kwargs), \
        "Embeddings must be supplied if decoding"
    assert not ("embeddings" in kwargs and
                (kwargs["embeddings"].shape[0], kwargs["embeddings"].shape[1],
                 kwargs["embeddings"].shape[2]) != (batch_size, input_len, channels)), \
        "Embeddings size must be the same as the input tensor"

    head_outputs = []
    for index, head in enumerate(self.heads):
        Q = self.to_q[index](tensor)
        K = self.to_k[index](tensor) if not self.decoder_mode else self.to_k[index](kwargs["embeddings"])
        V = self.to_v[index](tensor) if not self.decoder_mode else self.to_v[index](kwargs["embeddings"])

        if self.checkpoint_level == "C2":
            head_outputs.append(checkpoint(head, Q, K, V))
        else:
            head_outputs.append(head(Q, K, V, **kwargs))

    out = torch.cat(head_outputs, dim=-1)
    if self.w_o_intermediate_dim is None:
        out = self.w_o(out)
    else:
        out = self.w_o_1(out)
        out = self.w_o_2(out)
    out = self.mh_dropout(out)
    return out

def forward(self, x):

    def _func_factory(conv, bn, relu, has_bn, has_relu):
        def func(x):
            x = conv(x)
            if has_bn:
                x = bn(x)
            if has_relu:
                x = relu(x)
            return x
        return func

    func = _func_factory(self.conv, self.bn, self.relu, self.has_bn, self.has_relu)

    if self.efficient:
        x = checkpoint(func, x)
    else:
        x = func(x)

    return x

def checkpoint_sequential(functions, segments, *inputs):

    def run_function(start, end, functions):
        def forward(*inputs):
            for j in range(start, end + 1):
                inputs = functions[j](*inputs)
            return inputs
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        inputs = checkpoint(run_function(start, end, functions), *inputs)
        if not isinstance(inputs, tuple):
            inputs = (inputs, )
    return run_function(end + 1, len(functions) - 1, functions)(*inputs)

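# A minimal usage sketch for the `checkpoint_sequential` helper above. It assumes
# `checkpoint` inside that helper resolves to `torch.utils.checkpoint.checkpoint`;
# the model and sizes below are illustrative only. The input must require
# gradients so that the recomputed segments participate in the backward pass.
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(16, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
)
inp = torch.randn(2, 16, requires_grad=True)
# With segments equal to the number of child modules, every layer becomes its
# own segment and its activations are recomputed during backward.
out = checkpoint_sequential(model, 4, inp)
out.sum().backward()
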
def forward(self, *prev_features):
    bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
    if self.memory_efficient and any(prev_feature.requires_grad
                                     for prev_feature in prev_features):
        bottleneck_output = cp.checkpoint(bn_function, *prev_features)
    else:
        bottleneck_output = bn_function(*prev_features)
    if hasattr(self, 'concrete_dropout'):
        new_features = self.concrete_dropout(
            self.relu2(self.norm2(bottleneck_output)))
    else:
        new_features = self.conv2(self.relu2(
            self.norm2(bottleneck_output)))
    if self.drop_rate > 0:
        new_features = F.dropout(new_features, p=self.drop_rate,
                                 training=self.training)
    return new_features

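# Several of the DenseNet-style layers in this file call `_bn_function_factory`,
# which is defined elsewhere. A minimal sketch of the common three-argument form
# (as in torchvision's memory-efficient DenseNet) is shown below; variants that
# take extra arguments, such as the index/mode version in the next snippet, are
# not covered by this sketch.
import torch


def _bn_function_factory(norm, relu, conv):
    def bn_function(*inputs):
        # Concatenate all previous feature maps along the channel axis and run
        # them through norm -> relu -> 1x1 conv (the DenseNet bottleneck).
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output
    return bn_function
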
def forward(self, *prev_features):
    bn_function = _bn_function_factory(
        self.norm1, self.relu1, self.conv1, self.index, self.mode
    )
    if any(prev_feature.requires_grad for prev_feature in prev_features):
        # Do not store intermediate activations here; recompute them in the
        # backward pass instead (a memory vs. computation trade-off).
        bottleneck_output = cp.checkpoint(bn_function, *prev_features)
    else:
        bottleneck_output = bn_function(*prev_features)
    new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
    # new_features has g channels
    if self.drop_rate > 0:
        new_features = F.dropout(new_features, p=self.drop_rate,
                                 training=self.training)
    return new_features

def CTMRGTSVD(T, chi, tsvd_extra, max_iter, thresh=None, use_checkpoint=False):
    # T(up, left, down, right)
    threshold = 1E-8 if T.dtype is torch.float64 else 1E-6  # ctmrg convergence threshold
    if thresh is not None:
        threshold = thresh

    # C(down, right), E(up, right, down)
    C = T.sum((0, 1))
    E = T.sum(1).permute(0, 2, 1)

    truncation_error = 0.0
    sold = torch.zeros(chi, dtype=T.dtype, device=T.device)
    diff = 1E1
    bvar = 1E1
    bvar3 = 1E1
    for n in range(max_iter):
        tensors = C, E, T, torch.tensor(chi), torch.tensor(tsvd_extra)
        if use_checkpoint:  # use checkpoint to save memory
            C, E, s, error = checkpoint(renormalize, *tensors)
        else:
            C, E, s, error = renormalize(*tensors)

        Enorm = E.norm()
        E = E / Enorm
        truncation_error += error.item()
        if (s.numel() == sold.numel()):
            diff = (s - sold).norm().item()
            bvar = boundaryVariance(C, E)
            # bvar3 = boundaryVariance3(T, C, E)
        # print( s, sold )
        # print( 'n: %d, Enorm: %g, error: %e, diff: %e' % (n, Enorm, error.item(), diff) )
        # if (diff < threshold):
        if (bvar < threshold):
            break
        sold = s

    print('ctmrg converged at iterations %d to %.5e, bvar2 %.5e, bvar3 %.5e, \
truncation error: %.5f' % (n, diff, bvar, bvar3, truncation_error))

    return C, E

def forward(self, x):

    def _inner_forward(x):
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.dropblock1(out)

        out = self.conv2(out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.dropblock2(out)

        if self.with_gen_attention:
            out = self.gen_attention_block(out)

        out = self.conv3(out)
        out = self.norm3(out)
        out = self.dropblock3(out)

        if self.with_gcb:
            out = self.context_block(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += self.dropblockskip(identity)

        return out

    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)

    out = self.relu(out)

    return out

def forward(self, prev_output_tokens, encoder_out=None):
    # embed positions
    positions = self.embed_positions(
        prev_output_tokens,
    ) if self.embed_positions is not None else None

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(prev_output_tokens)

    if positions is not None:
        x += positions
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    # The tensor needs to be copied transposed because
    # fused dropout cannot handle strided data
    x = x.transpose(0, 1)

    # decoder layers
    for layer in self.layers:
        x = checkpoint(
            layer.forward,
            x,
            encoder_out['encoder_out'] if encoder_out is not None else None,
            encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
        )

    x = self.layer_norm(x)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if self.adaptive_softmax is None:
        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x, self.embed_tokens.weight)
        else:
            x = F.linear(x, self.embed_out)

    return x

def forward(self, x):

    def _inner_forward(x):
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        if not self.with_dcn:
            out = self.conv2(out)
        elif self.with_modulated_dcn:
            offset_mask = self.conv2_offset(x)
            offset = offset_mask[:, :18, :, :]
            mask = offset_mask[:, -9:, :, :].sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(x)
            out = self.conv2(out, offset)
        out = self.norm2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.norm3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out

    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)

    out = self.relu(out)

    return out

def forward(self, x, H, W):
    """Forward function.

    Args:
        x: Input feature, tensor size (B, H*W, C).
        H, W: Spatial resolution of the input feature.
    """
    # calculate attention mask for SW-MSA
    Hp = int(np.ceil(H / self.window_size)) * self.window_size
    Wp = int(np.ceil(W / self.window_size)) * self.window_size
    img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    for blk in self.blocks:
        blk.H, blk.W = H, W
        if self.use_checkpoint:
            x = checkpoint.checkpoint(blk, x, attn_mask)
        else:
            x = blk(x, attn_mask)
    if self.downsample is not None:
        x_down = self.downsample(x, H, W)
        Wh, Ww = (H + 1) // 2, (W + 1) // 2
        return x, H, W, x_down, Wh, Ww
    else:
        return x, H, W, x, H, W

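# The layer above uses `window_partition`, which is not defined in this file.
# A sketch of the standard Swin Transformer helper is given below; it is assumed
# to match the version used by the original code base.
def window_partition(x, window_size):
    """Split (B, H, W, C) feature maps into (num_windows*B, window_size, window_size, C) windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
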
def forward(self, hidden_states, attention_mask, chunks=None):
    all_hidden_states = ()
    all_attentions = ()
    if chunks is not None:
        assert isinstance(chunks, int)
        chunk_size = (len(self.layer) + chunks - 1) // chunks
        for start in range(0, len(self.layer), chunk_size):
            outputs = checkpoint(self.run_function(start, chunk_size),
                                 hidden_states, attention_mask)
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + outputs[1]
            if self.output_attentions:
                all_attentions = all_attentions + outputs[-1]
            hidden_states = outputs[0]
    else:
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            layer_outputs = layer_module(hidden_states, attention_mask)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1], )

    # Add last layer
    if self.output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states, )

    # Add trailing layer norm
    hidden_states = self.final_ln(hidden_states)

    outputs = (hidden_states, )
    if self.output_hidden_states:
        outputs = outputs + (all_hidden_states, )
    if self.output_attentions:
        outputs = outputs + (all_attentions, )
    return outputs  # outputs, (hidden states), (attentions)

def forward(self, x): """forward""" def _inner_forward(x): identity = x out = self.conv1(x) out = self.norm1(out) out = self.relu(out) if self.avd and self.avd_first: out = self.avd_layer(out) out = self.conv2(out) out = self.norm2(out) out = self.relu(out) if self.avd and not self.avd_first: out = self.avd_layer(out) out = self.conv3(out) out = self.norm3(out) if self.downsample is not None: identity = self.downsample(x) out += identity return out if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) out = self.relu(out) if self.nonlocal_block is not None: out = self.nonlocal_block(out) return out
def forward(self, x):

    def _inner_forward(xx):
        x, traj_src = xx
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        if self.if_inflate:
            if self.with_trajectory:
                assert traj_src is not None
                out = self.conv2_t(out, traj_src[0])
            else:
                out = self.conv2_t(out)
            out = self.bn2_t(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out, traj_src[1:]

    if self.with_cp and x.requires_grad:
        out, traj_remains = cp.checkpoint(_inner_forward, x)
    else:
        out, traj_remains = _inner_forward(x)

    out = self.relu(out)

    return out, traj_remains

def forward(self, x): """Forward function""" def _inner_forward(x): identity = x out = self.conv1(x) out = self.norm1(out) out = self.relu(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv1_plugin_names) out = self.conv2(out) out = self.norm2(out) out = self.relu(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv2_plugin_names) out = self.conv3(out) out = self.norm3(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv3_plugin_names) if self.downsample is not None: identity = self.downsample(x) out += identity return out if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) out = self.relu(out) return out
def test_checkpoint_non_tensor_inputs_outputs(self):
    def foo(t1, t2, scale, t3):
        t4 = t1 + t2 * t3
        t5 = t1 * t2 + t3
        t4 *= scale
        t5 *= scale
        return scale, t4, None, True, t5, "bar", t1

    t1 = torch.rand(10, requires_grad=True)
    t2 = torch.rand(10, requires_grad=True)
    t3 = torch.rand(10)
    scale = random.randint(0, 10)
    res = checkpoint(foo, t1, t2, scale, t3)
    self.assertEqual(scale, res[0])
    self.assertEqual((t1 + t2 * t3) * scale, res[1])
    self.assertEqual(None, res[2])
    self.assertEqual(True, res[3])
    self.assertEqual((t1 * t2 + t3) * scale, res[4])
    self.assertEqual("bar", res[5])
    self.assertEqual(t1, res[6])

    # Validate running backward.
    res[1].sum().backward(retain_graph=True)
    res[4].sum().backward(retain_graph=True)
    res[6].sum().backward()
    with self.assertRaisesRegex(
            RuntimeError,
            "Trying to backward through the graph a second time"):
        res[6].sum().backward()
    t1_grad = t1.grad
    t2_grad = t2.grad

    # Reset grads, run without checkpoint and validate we receive same grads.
    t1.grad = None
    t2.grad = None
    res = foo(t1, t2, scale, t3)
    torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
    self.assertEqual(t1.grad, t1_grad)
    self.assertEqual(t2.grad, t2_grad)

def forward_features(self, x):
    B = x.shape[0]
    pixel_embed = self.pixel_embed(x, self.pixel_pos)

    patch_embed = self.norm2_proj(
        self.proj(
            self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
    patch_embed = torch.cat(
        (self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
    patch_embed = patch_embed + self.patch_pos
    patch_embed = self.pos_drop(patch_embed)

    if self.grad_checkpointing and not torch.jit.is_scripting():
        for blk in self.blocks:
            pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed)
    else:
        for blk in self.blocks:
            pixel_embed, patch_embed = blk(pixel_embed, patch_embed)

    patch_embed = self.norm(patch_embed)
    return patch_embed

def forward(self, atom_features_list, bond_info, cond_features=None):
    """
    Args:
        atom_features_list (list[torch.Tensor]): Input features from previous dense layers
            for each node, size=[num_nodes, num_node_features]
        bond_info (torch.Tensor): Bond type information packed into a single matrix,
            type: torch.long, shape: [-1, 3], where 3 = begin_ids + end_ids + bond_type
        cond_features (torch.Tensor or None): Input conditional features,
            should be None if self.conditional is False

    Returns:
        torch.Tensor: Output feature for each node, size=[num_nodes, hidden_sizes[-1]]
    """
    bn_fn = _bn_function_factory(self.bottlenec)
    if self.efficient and all([
            atom_features_i.requires_grad
            for atom_features_i in atom_features_list
    ]):
        atom_features = cp.checkpoint(bn_fn, *atom_features_list)
    else:
        atom_features = bn_fn(*atom_features_list)
    return self.conv(atom_features, bond_info, cond_features)

def forward(self, x): """Defines the computation performed at every call.""" def _inner_forward(x): """Forward wrapper for utilizing checkpoint.""" identity = x out = self.conv1(x) out = self.conv2(out) if self.downsample is not None: identity = self.downsample(x) out = out + identity return out if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) out = self.relu(out) return out
def forward(self, x): """ Args: x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. Returns: dict[str->Tensor]: names and the corresponding features """ assert x.dim( ) == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!" outputs = {} # import pdb # pdb.set_trace() # cnt = 0 if self.checkpoint_grad_num > 0: # modules = [module for k, module in self.stem._modules.items()] # x = checkpoint.checkpoint_sequential(modules,1,x) x = checkpoint.checkpoint(self.custom(self.stem), x) else: x = self.stem(x) # cnt += 1 if "stem" in self._out_features: outputs["stem"] = x for stage, name in self.stages_and_names: if self.checkpoint_grad_num > 0: # and cnt>2: modules = [module for k, module in stage._modules.items()] x = checkpoint.checkpoint_sequential(modules, 1, x) else: x = stage(x) if name in self._out_features: outputs[name] = x # cnt += 1 if self.num_classes is not None: x = self.avgpool(x) x = torch.flatten(x, 1) x = self.linear(x) if "linear" in self._out_features: outputs["linear"] = x return outputs
def forward(self, X, mask):
    if self.attn_type.startswith("longformer") or self.attn_type.startswith("reformer"):
        with torch.cuda.amp.autocast(enabled=False):
            attn_out = self.attn(X.float(), mask.float())
    else:
        Q = self.split_heads(self.W_q(X))
        K = self.split_heads(self.W_k(X))
        V = self.split_heads(self.W_v(X))
        with torch.cuda.amp.autocast(enabled=False):
            if self.grad_checkpointing:
                attn_out = checkpoint(self.attn, Q.float(), K.float(), V.float(), mask.float())
            else:
                attn_out = self.attn(Q.float(), K.float(), V.float(), mask.float())
        attn_out = self.combine_heads(attn_out)
    out = self.ff(attn_out)
    return out

def forward(self, *prev_features):
    # concatenation of the previous features happens inside the factory-built bottleneck function
    bottleneck_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
    if any(prev_feature.requires_grad for prev_feature in prev_features):
        bottleneck_output = cp.checkpoint(bottleneck_function, *prev_features)
    else:
        bottleneck_output = bottleneck_function(*prev_features)

    # if self.drop_rate > 0:
    #     new_features = F.dropout(bottleneck_output, p=self.drop_rate, training=self.training)
    # else:
    #     new_features = bottleneck_output
    new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
    if self.drop_rate > 0:
        new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
    return new_features

def forward(self, x, cond, decode_step, decode_idx):
    h = self.pre_attn_norm(x, cond)
    if self.training:
        h = checkpoint(self.attn, h, h, h, decode_step, decode_idx)
    else:
        h = self.attn(h, h, h, decode_step, decode_idx)
    h = self.post_attn_dp(h)
    x = x + h

    if self.use_frame_cond:
        h = self.pre_enc_norm(x, cond)
        h = self.enc_attn(h, cond['frame_cond'], cond['frame_cond'],
                          decode_step, decode_idx)
        h = self.post_enc_dp(h)
        x = x + h

    h = self.pre_fc_norm(x, cond)
    h = self.fc_block(h)
    h = self.post_fc_dp(h)
    x = x + h

    return x

def forward(self, x):

    def _inner_forward(x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out

    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)

    out = self.relu(out)

    return out

def forward(self, x, H, W):
    """Forward function.

    Args:
        x: Input feature, tensor size (B, H*W, C).
        H, W: Spatial resolution of the input feature.
    """
    for blk in self.blocks:
        blk.H, blk.W = H, W
        if self.use_checkpoint:
            x = checkpoint.checkpoint(blk, x)
        else:
            x = blk(x)
    if self.downsample is not None:
        x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
        x_down = self.downsample(x_reshaped)
        x_down = x_down.flatten(2).transpose(1, 2)
        Wh, Ww = (H + 1) // 2, (W + 1) // 2
        return x, H, W, x_down, Wh, Ww
    else:
        return x, H, W, x, H, W

def forward(self, x):

    def _inner_forward(_x):
        _identity = _x

        _out = self.conv1(_x)
        _out = self.bn1(_out)
        _out = self.relu(_out)

        _out = self.conv2(_out)
        _out = self.bn2(_out)
        _out = self.relu(_out)

        _out = self.conv3(_out)
        _out = self.bn3(_out)

        if self.downsample is not None:
            _identity = self.downsample(_x)

        if self.dropout is not None:
            _out = self.dropout(_out)

        _out += _identity

        return _out

    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)

    if self.se_block is not None:
        out = self.se_block(out)

    if self.nonlocal_block is not None:
        out = self.nonlocal_block(out)

    out = self.relu(out)
    return out

def forward(self, x):

    def _inner_forward(x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out

    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)

    out = self.relu(out)

    return out

def multi_step_forward(self, moving, target, compute_loss=True):
    """
    Multi-step forward, where A_t is composed of A_update and A_last.

    :param moving: the moving image
    :param target: the target image
    :param compute_loss: if True, compute the loss
    :return: warped image (intensity [-1, 1]), transformation map (coord [-1, 1]), affine param
    """
    output = None
    moving_cp = moving
    affine_param = None
    affine_param_last = None
    affine_map = None
    bilinear = [Bilinear(self.zero_boundary) for i in range(self.step)]
    self.loss = 0.
    for i in range(self.step):
        # affine_param = self.affine_gen(moving, target)
        if i == 0:
            affine_param = self.affine_gen(moving, target)
        else:
            affine_param = checkpoint(self.affine_gen, moving, target)
        if i > 0:
            affine_param = self.update_affine_param(affine_param, affine_param_last)
        affine_param_last = affine_param
        affine_map = self.gen_affine_map(affine_param)
        output = bilinear[i](moving_cp, affine_map)
        moving = output
        self.affine_param = affine_param
        if compute_loss and (i == self.step - 1 or self.acc_multi_step_loss):
            self.loss += self.compute_overall_loss(self.extern_loss, output, target)
    if compute_loss and self.acc_multi_step_loss:
        self.loss = self.loss / self.step
    return output, affine_map, affine_param