Example #1
    def forward(self, x):
        def _inner_forward(x):
            x = [s.chunk(2, dim=1) for s in x]
            x1 = [s[0] for s in x]
            x2 = [s[1] for s in x]

            x2 = self.cross_resolution_weighting(x2)
            x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
            x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]

            out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
            out = [channel_shuffle(s, 2) for s in out]

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
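Most of the snippets in this collection follow the same pattern as Example #1: wrap the body of forward in a closure and route it through torch.utils.checkpoint only when checkpointing is enabled and the input actually requires gradients, so inference runs without the extra recompute. The following is a minimal, self-contained sketch of that pattern using a hypothetical toy block, not a module from any of the examples:

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class CheckpointedBlock(nn.Module):
    """Toy module illustrating the with_cp pattern (hypothetical example)."""

    def __init__(self, channels, with_cp=True):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.relu = nn.ReLU()
        self.with_cp = with_cp

    def forward(self, x):
        def _inner_forward(x):
            # Activations inside this closure are freed after the forward
            # pass and recomputed during backward when checkpointing is on.
            return self.conv2(self.relu(self.conv1(x)))

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        return self.relu(out)


x = torch.randn(2, 8, 32, 32, requires_grad=True)
CheckpointedBlock(8)(x).sum().backward()  # conv activations recomputed here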
Example #2
    def forward(self, x):
        def _inner_forward(x):
            x = self.conv1(x)
            x1, x2 = x.chunk(2, dim=1)

            x2 = self.expand_conv(x2)
            x2 = self.depthwise_conv(x2)
            x2 = self.linear_conv(x2)

            out = torch.cat((self.branch1(x1), x2), dim=1)

            out = channel_shuffle(out, 2)

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
Example #3
    def forward(self, x, edge_index):

        start, end = edge_index

        x = self.input_network(x)

        # Loop over iterations of edge and node networks
        for i in range(self.hparams["n_graph_iters"]):

            x_initial = x

            x = checkpoint(self.run_inner_loop, (x, start, end))

            # Residual connection
            x = x_initial + x

        #         print("5:", torch.cuda.max_memory_allocated() / 1024**3)
        #         torch.cuda.reset_peak_memory_stats()

        edge_inputs = torch.cat([x[start], x[end]], dim=1)
        return self.edge_network(edge_inputs)
Example #4
 def forward(self, tensor, **kwargs):
     batch_size, input_len, channels = tensor.shape
     assert not (self.decoder_mode and "embeddings" not in kwargs), "Embeddings must be supplied if decoding"
     assert not ("embeddings" in kwargs and (kwargs["embeddings"].shape[0], kwargs["embeddings"].shape[1], kwargs["embeddings"].shape[2]) != (batch_size, input_len, channels)), "Embeddings size must be the same as the input tensor"
     head_outputs = []
     for index, head in enumerate(self.heads):
         Q = self.to_q[index](tensor)
         K = self.to_k[index](tensor) if not self.decoder_mode else self.to_k[index](kwargs["embeddings"])
         V = self.to_v[index](tensor) if not self.decoder_mode else self.to_v[index](kwargs["embeddings"])
         if self.checkpoint_level == "C2":
             head_outputs.append(checkpoint(head,Q,K,V))
         else:
             head_outputs.append(head(Q,K,V,**kwargs))
     out = torch.cat(head_outputs, dim=-1)
     if self.w_o_intermediate_dim is None:
         out = self.w_o(out)
     else:
         out = self.w_o_1(out)
         out = self.w_o_2(out)
     out = self.mh_dropout(out)
     return out
Example #5
    def forward(self, x):
        def _func_factory(conv, bn, relu, has_bn, has_relu):
            def func(x):
                x = conv(x)
                if has_bn:
                    x = bn(x)
                if has_relu:
                    x = relu(x)
                return x

            return func

        func = _func_factory(self.conv, self.bn, self.relu, self.has_bn,
                             self.has_relu)

        if self.efficient:
            x = checkpoint(func, x)
        else:
            x = func(x)

        return x
Example #6
def checkpoint_sequential(functions, segments, *inputs):
    def run_function(start, end, functions):
        def forward(*inputs):
            for j in range(start, end + 1):
                inputs = functions[j](*inputs)
            return inputs

        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        inputs = checkpoint(run_function(start, end, functions), *inputs)
        if not isinstance(inputs, tuple):
            inputs = (inputs, )
    return run_function(end + 1, len(functions) - 1, functions)(*inputs)
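Example #6 is essentially an older, hand-rolled version of torch.utils.checkpoint.checkpoint_sequential, which splits a sequence of modules into segments and recomputes each segment during backward. A minimal usage sketch of the stock helper (the model and segment count are arbitrary choices for illustration):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Arbitrary deep stack of sub-blocks, purely for illustration.
model = nn.Sequential(
    *[nn.Sequential(nn.Linear(128, 128), nn.ReLU()) for _ in range(8)])

x = torch.randn(4, 128, requires_grad=True)
# Split the 8 sub-blocks into 2 checkpointed segments: only the segment
# boundaries keep their activations, the rest are recomputed in backward.
out = checkpoint_sequential(model, 2, x)
out.sum().backward()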
Example #7
    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        if self.memory_efficient and any(prev_feature.requires_grad
                                         for prev_feature in prev_features):
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)

        if hasattr(self, 'concrete_dropout'):
            new_features = self.concrete_dropout(
                self.relu2(self.norm2(bottleneck_output)))
        else:
            new_features = self.conv2(self.relu2(
                self.norm2(bottleneck_output)))

            if self.drop_rate > 0:
                new_features = F.dropout(new_features,
                                         p=self.drop_rate,
                                         training=(self.training))

        return new_features
Example #8
 def forward(self, *prev_features):
     bn_function = _bn_function_factory(
         self.norm1,
         self.relu1,
         self.conv1,
         self.index,
         self.mode
         )
     if any(prev_feature.requires_grad for prev_feature in prev_features):
         # Do not store intermediate activations here;
         # recompute them in the backward pass instead
         # (trade-off between memory and computation).
         bottleneck_output = cp.checkpoint(bn_function, *prev_features)
     else:
         bottleneck_output = bn_function(*prev_features)
     new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
     # new_features has g channels
     if self.drop_rate > 0:
         new_features = F.dropout(new_features, p=self.drop_rate,
                                  training=self.training)
     return new_features
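Examples #7, #8, #20 and #24 call a _bn_function_factory helper that is not shown here. In torchvision-style DenseNet code it is a small closure that concatenates the incoming feature maps and applies norm, relu and conv, so the whole bottleneck can be handed to cp.checkpoint as a single callable. A sketch along those lines (Example #8 additionally threads index/mode arguments through its own variant, which is not reproduced here):

import torch


def _bn_function_factory(norm, relu, conv):
    # Bundle concat + norm + relu + conv into one checkpointable callable,
    # in the style of the torchvision DenseNet helper assumed above.
    def bn_function(*inputs):
        # Dense layers receive the features of all preceding layers;
        # concatenating inside the checkpointed function keeps the large
        # concatenated tensor out of the stored activations.
        concated_features = torch.cat(inputs, 1)
        return conv(relu(norm(concated_features)))

    return bn_function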
Example #9
def CTMRGTSVD(T, chi, tsvd_extra, max_iter, thresh=None, use_checkpoint=False):
    # T(up, left, down, right)

    threshold = 1E-8 if T.dtype is torch.float64 else 1E-6  # ctmrg convergence threshold
    if thresh is not None:
        threshold = thresh

    # C(down, right), E(up,right,down)
    C = T.sum((0, 1))
    E = T.sum(1).permute(0, 2, 1)

    truncation_error = 0.0
    sold = torch.zeros(chi, dtype=T.dtype, device=T.device)
    diff = 1E1
    bvar = 1E1
    bvar3 = 1E1
    for n in range(max_iter):
        tensors = C, E, T, torch.tensor(chi), torch.tensor(tsvd_extra)
        if use_checkpoint:  # use checkpoint to save memory
            C, E, s, error = checkpoint(renormalize, *tensors)
        else:
            C, E, s, error = renormalize(*tensors)

        Enorm = E.norm()
        E = E / Enorm
        truncation_error += error.item()
        if (s.numel() == sold.numel()):
            diff = (s - sold).norm().item()
            bvar = boundaryVariance(C, E)
            #bvar3 = boundaryVariance3(T, C, E)
            #print( s, sold )
        #print( 'n: %d, Enorm: %g, error: %e, diff: %e' % (n, Enorm, error.item(), diff) )
        #if (diff < threshold):
        if (bvar < threshold):
            break
        sold = s
    print('ctmrg converged at iteration %d to %.5e, bvar2 %.5e, bvar3 %.5e, '
          'truncation error: %.5f' % (n, diff, bvar, bvar3, truncation_error))

    return C, E
Example #10
    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)
            out = self.dropblock1(out)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)
            out = self.dropblock2(out)

            if self.with_gen_attention:
                out = self.gen_attention_block(out)

            out = self.conv3(out)
            out = self.norm3(out)
            out = self.dropblock3(out)

            if self.with_gcb:
                out = self.context_block(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += self.dropblockskip(identity)

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
Example #11
    def forward(self, prev_output_tokens, encoder_out=None):
        # embed positions
        positions = self.embed_positions(
            prev_output_tokens, ) if self.embed_positions is not None else None

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        # The tensor needs to be copied after transposing because
        # fused dropout is not capable of handling strided data
        x = x.transpose(0, 1)

        # decoder layers
        for layer in self.layers:
            x = checkpoint(
                layer.forward,
                x,
                encoder_out['encoder_out']
                if encoder_out is not None else None,
                encoder_out['encoder_padding_mask']
                if encoder_out is not None else None,
            )

        x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                x = F.linear(x, self.embed_tokens.weight)
            else:
                x = F.linear(x, self.embed_out)

        return x
Example #12
    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if not self.with_dcn:
                out = self.conv2(out)
            elif self.with_modulated_dcn:
                offset_mask = self.conv2_offset(x)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, -9:, :, :].sigmoid()
                out = self.conv2(out, offset, mask)
            else:
                offset = self.conv2_offset(x)
                out = self.conv2(out, offset)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
Example #13
    def forward(self, x, H, W):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
Example #14
    def forward(self, hidden_states, attention_mask, chunks=None):
        all_hidden_states = ()
        all_attentions = ()

        if chunks is not None:
            assert isinstance(chunks, int)
            chunk_size = (len(self.layer) + chunks - 1) // chunks
            for start in range(0, len(self.layer), chunk_size):
                outputs = checkpoint(self.run_function(start, chunk_size),
                                     hidden_states, attention_mask)
                if self.output_hidden_states:
                    all_hidden_states = all_hidden_states + outputs[1]
                if self.output_attentions:
                    all_attentions = all_attentions + outputs[-1]
                hidden_states = outputs[0]
        else:
            for i, layer_module in enumerate(self.layer):
                if self.output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states, )

                layer_outputs = layer_module(hidden_states, attention_mask)
                hidden_states = layer_outputs[0]
                if self.output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1], )

            # Add last layer
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            # Add trailing layer norm
            hidden_states = self.final_ln(hidden_states)

            outputs = (hidden_states, )

            if self.output_hidden_states:
                outputs = outputs + (all_hidden_states, )
            if self.output_attentions:
                outputs = outputs + (all_attentions, )

        return outputs  # outputs, (hidden states), (attentions)
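Example #14 relies on a self.run_function(start, chunk_size) helper that is not shown; it has to return a closure running a contiguous slice of self.layer so that the whole chunk can be checkpointed at once, and its return structure has to match the outputs[0], outputs[1] and outputs[-1] indexing above. A hedged sketch of such a helper (names follow the example; the exact packing of hidden states and attentions is an assumption):

    def run_function(self, start, chunk_size):
        # Return a closure over self.layer[start:start + chunk_size] so that
        # checkpoint() recomputes the whole chunk during the backward pass.
        def custom_forward(hidden_states, attention_mask):
            all_hidden_states = ()
            all_attentions = ()
            for layer_module in self.layer[start:start + chunk_size]:
                if self.output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states, )
                layer_outputs = layer_module(hidden_states, attention_mask)
                hidden_states = layer_outputs[0]
                if self.output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1], )
            outputs = (hidden_states, )
            if self.output_hidden_states:
                outputs = outputs + (all_hidden_states, )
            if self.output_attentions:
                outputs = outputs + (all_attentions, )
            return outputs

        return custom_forward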
Example #15
    def forward(self, x):
        """forward"""
        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.avd and self.avd_first:
                out = self.avd_layer(out)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            if self.avd and not self.avd_first:
                out = self.avd_layer(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        if self.nonlocal_block is not None:
            out = self.nonlocal_block(out)

        return out
Example #16
    def forward(self, x):

        def _inner_forward(xx):
            x, traj_src = xx
            identity = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)

            if self.if_inflate:
                if self.with_trajectory:
                    assert traj_src is not None
                    out = self.conv2_t(out, traj_src[0])
                else:
                    out = self.conv2_t(out)
                out = self.bn2_t(out)

            out = self.conv3(out)
            out = self.bn3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out, traj_src[1:]

        if self.with_cp and x.requires_grad:
            out, traj_remains = cp.checkpoint(_inner_forward, x)
        else:
            out, traj_remains = _inner_forward(x)

        out = self.relu(out)

        return out, traj_remains
Example #17
    def forward(self, x):
        """Forward function"""
        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
Example #18
    def test_checkpoint_non_tensor_inputs_outputs(self):
        def foo(t1, t2, scale, t3):
            t4 = t1 + t2 * t3
            t5 = t1 * t2 + t3
            t4 *= scale
            t5 *= scale
            return scale, t4, None, True, t5, "bar", t1

        t1 = torch.rand(10, requires_grad=True)
        t2 = torch.rand(10, requires_grad=True)
        t3 = torch.rand(10)
        scale = random.randint(0, 10)
        res = checkpoint(foo, t1, t2, scale, t3)
        self.assertEqual(scale, res[0])
        self.assertEqual((t1 + t2 * t3) * scale, res[1])
        self.assertEqual(None, res[2])
        self.assertEqual(True, res[3])
        self.assertEqual((t1 * t2 + t3) * scale, res[4])
        self.assertEqual("bar", res[5])
        self.assertEqual(t1, res[6])

        # Validate running backward.
        res[1].sum().backward(retain_graph=True)
        res[4].sum().backward(retain_graph=True)
        res[6].sum().backward()
        with self.assertRaisesRegex(
                RuntimeError,
                "Trying to backward through the graph a second time"):
            res[6].sum().backward()
        t1_grad = t1.grad
        t2_grad = t2.grad

        # Reset grads, run without checkpoint and validate we receive same grads.
        t1.grad = None
        t2.grad = None
        res = foo(t1, t2, scale, t3)
        torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
        self.assertEqual(t1.grad, t1_grad)
        self.assertEqual(t2.grad, t2_grad)
Example #19
    def forward_features(self, x):
        B = x.shape[0]
        pixel_embed = self.pixel_embed(x, self.pixel_pos)

        patch_embed = self.norm2_proj(
            self.proj(
                self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
        patch_embed = torch.cat(
            (self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
        patch_embed = patch_embed + self.patch_pos
        patch_embed = self.pos_drop(patch_embed)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for blk in self.blocks:
                pixel_embed, patch_embed = checkpoint(blk, pixel_embed,
                                                      patch_embed)
        else:
            for blk in self.blocks:
                pixel_embed, patch_embed = blk(pixel_embed, patch_embed)

        patch_embed = self.norm(patch_embed)
        return patch_embed
Example #20
    def forward(self, atom_features_list, bond_info, cond_features=None):
        """
        Args:
            atom_features_list (list[torch.Tensor]): Input features from previous dense layers for each node
                                                     size=[num_nodes, num_node_features]
            bond_info (torch.Tensor): Bond type information packed into a single matrix
                                      type: torch.long, shape: [-1, 3], where 3 = begin_ids + end_ids + bond_type
            cond_features (torch.Tensor or None): Input conditional features should be None if self.conditional is False

        Returns:
            torch.Tensor: Output feature for each node, size=[num_nodes, hidden_sizes[-1]]
        """
        bn_fn = _bn_function_factory(self.bottlenec)
        if self.efficient and all([
                atom_features_i.requires_grad
                for atom_features_i in atom_features_list
        ]):
            atom_features = cp.checkpoint(bn_fn, *atom_features_list)
        else:
            atom_features = bn_fn(*atom_features_list)

        return self.conv(atom_features, bond_info, cond_features)
Example #21
    def forward(self, x):
        """Defines the computation performed at every call."""

        def _inner_forward(x):
            """Forward wrapper for utilizing checkpoint."""
            identity = x

            out = self.conv1(x)
            out = self.conv2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out = out + identity
            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        out = self.relu(out)
        return out
Example #22
    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.

        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert x.dim() == 4, \
            f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        # import pdb
        # pdb.set_trace()
        # cnt = 0
        if self.checkpoint_grad_num > 0:
            # modules = [module for k, module in self.stem._modules.items()]
            # x = checkpoint.checkpoint_sequential(modules,1,x)
            x = checkpoint.checkpoint(self.custom(self.stem), x)
        else:
            x = self.stem(x)
        # cnt += 1
        if "stem" in self._out_features:
            outputs["stem"] = x
        for stage, name in self.stages_and_names:
            if self.checkpoint_grad_num > 0:  # and cnt>2:
                modules = [module for k, module in stage._modules.items()]
                x = checkpoint.checkpoint_sequential(modules, 1, x)
            else:
                x = stage(x)
            if name in self._out_features:
                outputs[name] = x
            # cnt += 1
        if self.num_classes is not None:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.linear(x)
            if "linear" in self._out_features:
                outputs["linear"] = x
        return outputs
Example #23
    def forward(self, X, mask):

        if self.attn_type.startswith(
                "longformer") or self.attn_type.startswith("reformer"):
            with torch.cuda.amp.autocast(enabled=False):
                attn_out = self.attn(X.float(), mask.float())
        else:
            Q = self.split_heads(self.W_q(X))
            K = self.split_heads(self.W_k(X))
            V = self.split_heads(self.W_v(X))
            with torch.cuda.amp.autocast(enabled=False):
                if self.grad_checkpointing:
                    attn_out = checkpoint(self.attn, Q.float(), K.float(),
                                          V.float(), mask.float())
                else:
                    attn_out = self.attn(Q.float(), K.float(), V.float(),
                                         mask.float())
            attn_out = self.combine_heads(attn_out)

        out = self.ff(attn_out)

        return out
Example #24
    def forward(self, *prev_features):  # concatenate magic
        bottleneck_function = _bn_function_factory(self.norm1, self.relu1,
                                                   self.conv1)
        if any(prev_feature.requires_grad for prev_feature in prev_features):
            bottleneck_output = cp.checkpoint(bottleneck_function,
                                              *prev_features)
        else:
            bottleneck_output = bottleneck_function(*prev_features)

        # if self.drop_rate > 0:
        #     new_features = F.dropout(bottleneck_output, p=self.drop_rate, training=self.training)
        # else:
        #     new_features = bottleneck_output

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))

        if self.drop_rate > 0:
            new_features = F.dropout(new_features,
                                     p=self.drop_rate,
                                     training=self.training)

        return new_features
Example #25
    def forward(self, x, cond, decode_step, decode_idx):
        h = self.pre_attn_norm(x, cond)
        if self.training:
            h = checkpoint(self.attn, h, h, h, decode_step, decode_idx)
        else:
            h = self.attn(h, h, h, decode_step, decode_idx)
        h = self.post_attn_dp(h)
        x = x + h

        if self.use_frame_cond:
            h = self.pre_enc_norm(x, cond)
            h = self.enc_attn(h, cond['frame_cond'], cond['frame_cond'],
                              decode_step, decode_idx)
            h = self.post_enc_dp(h)
            x = x + h

        h = self.pre_fc_norm(x, cond)
        h = self.fc_block(h)
        h = self.post_fc_dp(h)
        x = x + h

        return x
Example #26
    def forward(self, x):
        def _inner_forward(x):
            residual = x
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
            out = self.conv3(out)
            out = self.bn3(out)
            if (self.downsample is not None):
                residual = self.downsample(x)
            out += residual
            return out

        if (self.with_cp and x.requires_grad):
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        out = self.relu(out)
        return out
Example #27
    def forward(self, x, H, W):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
            x_down = self.downsample(x_reshaped)
            x_down = x_down.flatten(2).transpose(1, 2)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
Example #28
    def forward(self, x):
        def _inner_forward(_x):
            _identity = _x

            _out = self.conv1(_x)
            _out = self.bn1(_out)
            _out = self.relu(_out)

            _out = self.conv2(_out)
            _out = self.bn2(_out)
            _out = self.relu(_out)

            _out = self.conv3(_out)
            _out = self.bn3(_out)

            if self.downsample is not None:
                _identity = self.downsample(_x)

            if self.dropout is not None:
                _out = self.dropout(_out)

            _out += _identity

            return _out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        if self.se_block is not None:
            out = self.se_block(out)

        if self.nonlocal_block is not None:
            out = self.nonlocal_block(out)

        out = self.relu(out)

        return out
Example #29
    def forward(self, x):

        def _inner_forward(x):
            identity = x
            out = self.conv1(x)
            out = self.conv2(out)
            out = self.conv3(out)

            if self.downsample is not None:
                identity = self.downsample(x)
            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
Example #30
    def multi_step_forward(self, moving, target, compute_loss=True):
        """
        multi-step forward, where A_t is composed of A_update and A_last

        :param moving: the moving image
        :param target: the target image
        :param compute_loss: if true, compute the loss
        :return: warped image (intensity[-1,1]), transformation map (coord [-1,1]), affine param
        """

        output = None
        moving_cp = moving
        affine_param = None
        affine_param_last = None
        affine_map = None
        bilinear = [Bilinear(self.zero_boundary) for i in range(self.step)]
        self.loss = 0.
        for i in range(self.step):
            #affine_param = self.affine_gen(moving, target)
            if i == 0:
                affine_param = self.affine_gen(moving, target)
            else:
                affine_param = checkpoint(self.affine_gen, moving, target)
            if i > 0:
                affine_param = self.update_affine_param(
                    affine_param, affine_param_last)
            affine_param_last = affine_param
            affine_map = self.gen_affine_map(affine_param)
            output = bilinear[i](moving_cp, affine_map)
            moving = output
            self.affine_param = affine_param
            if compute_loss and (i == self.step - 1
                                 or self.acc_multi_step_loss):
                self.loss += self.compute_overall_loss(self.extern_loss,
                                                       output, target)
        if compute_loss and self.acc_multi_step_loss:
            self.loss = self.loss / self.step
        return output, affine_map, affine_param
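All of the examples above rely on the default reentrant implementation of torch.utils.checkpoint. Recent PyTorch releases (roughly 1.11 and later; treat the exact version as an assumption) also expose a non-reentrant mode via use_reentrant=False, which the current documentation recommends. A minimal sketch:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# Non-reentrant checkpointing: supports a wider range of input/output types
# and composes better with features such as autocast and DDP.
out = checkpoint(layer, x, use_reentrant=False)
out.sum().backward()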