Example #1
    def forward(self, input):
        # x = self.DownBlock(input)
        x = checkpoint_sequential(self.DownBlock, 8, input)

        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap_weight = list(self.gap_fc.parameters())[0]
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp_weight = list(self.gmp_fc.parameters())[0]
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
            x_ = checkpoint_sequential(self.FC, 2, x_.view(x_.shape[0], -1))
            # x_ = self.FC(x_.view(x_.shape[0], -1))
        else:
            # x_ = self.FC(x.view(x.shape[0], -1))
            x_ = checkpoint_sequential(self.FC, 2, x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_)

        for i in range(self.n_blocks):
            x = getattr(self, 'UpBlock1_' + str(i + 1))(x, gamma, beta)
        # out = self.UpBlock2(x)
        out = checkpoint_sequential(self.UpBlock2, 8, x)

        return out, cam_logit, heatmap
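
The pattern above treats checkpoint_sequential(seq, segments, x) as a drop-in replacement for seq(x): only segment-boundary activations are stored, and the rest are recomputed during backward. A minimal self-contained sketch of that substitution, with a toy model and segment count chosen purely for illustration:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    # Toy stand-in for a stage like self.DownBlock above.
    seq = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU())
    x = torch.randn(4, 16, requires_grad=True)  # grad must flow into the checkpointed input

    y_plain = seq(x)                           # keeps every intermediate activation
    y_ckpt = checkpoint_sequential(seq, 2, x)  # keeps segment boundaries only
    assert torch.allclose(y_plain, y_ckpt)     # forward results are identical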
Example #2
    def forward(self, x, training):
        # x.requires_grad_(True)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.pooling:
            x = self.maxpool(x)

        x2 = x.clone()
        if self.checkpoint and training:
            x = checkpoint_sequential(self.layer1, 3, x)
            x = checkpoint_sequential(self.layer2, 3, x)
            x1 = x.clone()

            x = checkpoint_sequential(self.layer3, 3, x)
            x = checkpoint_sequential(self.layer4, 3, x)
        else:
            x = self.layer1(x)
            x = self.layer2(x)
            x1 = x.clone()

            x = self.layer3(x)
            x = self.layer4(x)

        return x, x1, x2
Example #3
    def forward(self, x):
        layer_out = []
        for layer in self.encoder.features[:3]:
            x = checkpoint_sequential(layer, len(list(layer)), x)
            # x = layer(x)
            layer_out.append(x)

        layer_out[0] = self.feature_avg_pool(layer_out[0])
        layer_out[1] = self.feature_avg_pool(layer_out[1])
        layer_out = torch.cat(layer_out, dim=1)

        pooled_features = []
        for layer in self.encoder.features[3:]:
            x = checkpoint_sequential(layer, len(list(layer)), x)
            # x = layer(x)
            pooled_features.append(x)

        x = self.feature_pooling(torch.cat(pooled_features, dim=1))
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)  # F.upsample is deprecated

        x = self.up_4x_conv(x)
        # concatenate input features

        layer_out = self.feature_4x_conv(layer_out)
        x = torch.cat([layer_out, x], dim=1)
        x = self.smooth_feature_4x_conv(x)

        x = self.out_conv(x)
        return x
Example #4
    def shared_step(self, batch, batch_idx):
        img1, img2 = batch

        # ENCODE
        # encode -> representations
        # (b, 3, 32, 32) -> (b, 2048, 2, 2)
        if self.hparams.save_mem:
            # torch.autograd.Variable is deprecated; requires_grad_() is the
            # modern way to let gradients flow through the checkpointed encoder.
            img1 = img1.requires_grad_()
            img2 = img2.requires_grad_()
            h1 = checkpoint_sequential(self.encoder, 9, img1)
            h2 = checkpoint_sequential(self.encoder, 9, img2)
        else:
            h1 = self.encoder(img1)
            h2 = self.encoder(img2)

        # the bolts resnets return a list of feature maps
        if isinstance(h1, list):
            h1 = h1[-1]
            h2 = h2[-1]

        # PROJECT
        # img -> E -> h -> || -> z
        # (b, 2048, 2, 2) -> (b, 128)
        z1 = self.projection(h1)
        z2 = self.projection(h2)

        loss, sim = self.nt_xent_loss(z1, z2, self.hparams.loss_temperature)
        acc = self.cale_acc(sim)

        return loss, acc
Example #5
    def test_checkpoint(self):
        model = nn.Sequential(
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.ReLU(),
            nn.Linear(20, 5),
            nn.ReLU()
        )

        x = torch.randn(1, 100, requires_grad=True)

        # not checkpointed
        out = model(x)
        out_not_checkpointed = out.data.clone()
        model.zero_grad()
        out.sum().backward()
        grad_not_checkpointed = {}
        for name, param in model.named_parameters():
            grad_not_checkpointed[name] = param.grad.data.clone()
        input_grad = x.grad.data.clone()

        # checkpointed model by passing list of modules
        chunks = 2
        modules = list(model.children())
        input_var = x.detach()
        input_var.requires_grad = True
        # pass list of modules to checkpoint
        out = checkpoint_sequential(modules, chunks, input_var)
        out_checkpointed = out.data.clone()
        model.zero_grad()
        out.sum().backward()
        grad_checkpointed = {}
        for name, param in model.named_parameters():
            grad_checkpointed[name] = param.grad.data.clone()
        checkpoint_input_grad = input_var.grad.data.clone()
        # compare the output, input and parameters gradients
        self.assertEqual(out_checkpointed, out_not_checkpointed)
        self.assertEqual(input_grad, checkpoint_input_grad)
        for name in grad_checkpointed:
            self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])

        # checkpointed by passing sequential directly
        input_var1 = x.detach()
        input_var1.requires_grad = True
        # pass the sequential itself
        out = checkpoint_sequential(model, 2, input_var1)
        out_checkpointed = out.data.clone()
        model.zero_grad()
        out.sum().backward()
        grad_checkpointed = {}
        for name, param in model.named_parameters():
            grad_checkpointed[name] = param.grad.data.clone()
        checkpoint_input_grad = input_var1.grad.data.clone()
        # compare the output, input and parameters gradients
        self.assertEqual(out_checkpointed, out_not_checkpointed)
        self.assertEqual(input_grad, checkpoint_input_grad)
        for name in grad_checkpointed:
            self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
Example #7
    def forward(self, input1, input2):
        segments = 2
        if self.upsampling is True:
            output2 = checkpoint_sequential(self.up_modules, segments, input2)
        else:
            output2 = self.up(input2)
        output1 = nn.functional.interpolate(input1, output2.size()[2:], mode='bilinear', align_corners=True)
        return checkpoint_sequential(self.conv_modules, segments, torch.cat([output1, output2], 1))
Example #8
    def test_checkpoint_sequential_deprecated_no_args(self):
        class Noop(nn.Module):
            def forward(self):
                pass

        model = nn.Sequential(Noop())

        with self.assertRaises(TypeError):
            checkpoint_sequential(model, 1)
Example #9
    def test_checkpoint_sequential_deprecated_multiple_args(self):
        class Two(nn.Module):
            def forward(self, a, b):
                return a, b

        model = nn.Sequential(Two())
        a = torch.randn(1, 100, requires_grad=True)
        b = torch.randn(1, 100, requires_grad=True)

        with self.assertRaises(TypeError):
            checkpoint_sequential(model, 1, a, b)
Example #10
    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = checkpoint_sequential(self.layer1, 10, x)
            x = checkpoint_sequential(self.layer2, 10, x)
            x = checkpoint_sequential(self.layer3, 10, x)
            x = checkpoint_sequential(self.layer4, 10, x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x
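
The interplay above, checkpointed segments inside an autocast region with the final fc forced back to fp32, can be reproduced stand-alone. The sketch below uses illustrative layer sizes and assumes a CUDA device, since torch.cuda.amp.autocast targets GPU tensors:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    body = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).cuda()
    head = nn.Linear(64, 10).cuda()
    x = torch.randn(8, 64, device='cuda', requires_grad=True)

    with torch.cuda.amp.autocast():
        h = checkpoint_sequential(body, 2, x)  # recomputation replays under autocast
    out = head(h.float())  # cast back to fp32 before the head, as in the example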
Example #11
    def test_checkpoint_trigger(self):

        class Net(nn.Module):

            def __init__(self):
                super(Net, self).__init__()
                self.counter = 0

            def forward(self, input_var):
                self.counter += 1
                return input_var

        # checkpointed: with 2 segments over 10 modules, the checkpointed
        # segments are recomputed during backward, while the final segment
        # runs without checkpointing.
        modules = [Net() for _ in range(10)]
        for m in modules:
            self.assertEqual(m.counter, 0)
        input_var = torch.randn(3, 4, requires_grad=True)
        out = checkpoint_sequential(modules, 2, input_var)
        for m in modules:
            self.assertEqual(m.counter, 1)  # every module ran once in forward
        out.sum().backward()
        for m in modules[:(len(modules) // 2)]:
            self.assertEqual(m.counter, 2)  # first half recomputed in backward
        for m in modules[(len(modules) // 2):]:
            self.assertEqual(m.counter, 1)  # last segment was not re-run
Example #12
    def forward(self, x):
        # out = self.features(x)
        out = checkpoint_sequential(self.features, 2, x)
        out = out.view(out.size(0), -1)
        out = F.dropout(out, p=0.5, training=self.training)
        out = self.classifier(out)
        return out
Example #14
    def forward(self, input_features):
        if self.memory_efficient:
            output = cp.checkpoint_sequential(self.main_net, 4, input_features)
        else:
            output = self.main_net(input_features)
        output = F.adaptive_avg_pool2d(output, (1, 1)).view(output.size(0), -1)
        output = self.classifier(output)
        return output
Example #15
    def forward(self, input_var, chunks=3):
        # [0] grabs the first registered submodule (a Sequential) to checkpoint
        modules = [module for k, module in self._modules.items()][0]
        # print(modules)
        input_var = checkpoint_sequential(modules, chunks, input_var)
        # print(input_var.shape)
        input_var = input_var.view(input_var.size(0) // 5, 5, -1).mean(1)
        input_var = self.fc(input_var)
        return input_var
Example #16
    def forward(self, x):
        if self.is_memory_efficient:
            x = cp.checkpoint_sequential(self.classifier, segments=1, input=x)
        else:
            x = self.classifier(x)
        logits = x.view(x.size(0), self.num_classes)
        probas = F.softmax(logits, dim=1)
        return logits, probas
Example #17
    def train_one_epoch(self):
        """
        Train network in one epoch
        """
        print('Training......')

        # set mode train
        self.network.train()

        # prepare data
        train_loss = 0
        train_loader = DataLoader(self.datasets['train'],
                                  batch_size=self.params.train_batch,
                                  shuffle=self.params.shuffle,
                                  drop_last=True,
                                  num_workers=self.params.dataloader_workers)
        train_size = len(self.datasets['train'])
        if train_size % self.params.train_batch != 0:
            total_batch = train_size // self.params.train_batch + 1
        else:
            total_batch = train_size // self.params.train_batch

        # train through dataset
        for batch_idx, batch in enumerate(train_loader):
            self.pb.click(batch_idx, total_batch)
            image, label = batch['image'], batch['label']
            image_cuda, label_cuda = image.cuda(), label.cuda()

            # checkpoint split
            if self.params.should_split:
                image_cuda.requires_grad_()
                out = checkpoint_sequential(self.network, self.params.split,
                                            image_cuda)
            else:
                out = self.network(image_cuda)
            loss = self.loss_fn(out, label_cuda)

            # optimize
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

            # accumulate
            train_loss += loss.item()

            # record first loss
            if self.train_loss == []:
                self.train_loss.append(train_loss)
                self.summary_writer.add_scalar('loss/train_loss', train_loss,
                                               0)

        self.pb.close()
        train_loss /= total_batch
        self.train_loss.append(train_loss)

        # add to summary
        self.summary_writer.add_scalar('loss/train_loss', train_loss,
                                       self.epoch)
Example #18
    def _head_to_tail(self, pool5):
        if cfg.MIX_LOCATION != 0:
            cfg.layer4 = True
        num_segments = 3
        fc7 = checkpoint_sequential(self.resnet.layer4, num_segments, pool5)
        fc7 = fc7.mean(3).mean(2)  # average pooling after layer4
        # fc7 = self.resnet.layer4(pool5).mean(3).mean(2)
        cfg.layer4 = False
        return fc7
Example #19
    def forward(self, x):
        x1 = self.conv_1(x)
        x2 = self.conv_2(x)
        x3 = self.conv_3(x)
        x4 = self.conv_4(x)

        y = torch.cat([x1, x2, x3, x4], dim=1)
        y = checkpoint_sequential(self.conv_1x1, 1, y)
        return y
Example #20
    def forward(self, x, chunks=None):
        # [0] grabs the first registered submodule; note that chunks must be
        # given as an int, since checkpoint_sequential cannot take None
        modules = [module for k, module in self._modules.items()][0]
        input_var = x.detach()
        input_var.requires_grad = True
        input_var = checkpoint_sequential(modules, chunks, input_var)
        input_var = F.relu(input_var, inplace=True)
        input_var = F.avg_pool2d(input_var, kernel_size=7, stride=1).view(input_var.size(0), -1)
        input_var = self.classifier(input_var)
        return input_var
Example #21
    def forward(self, x):
        if self.save_grad:
            x = checkpoint_sequential(self.model, 9, x)
        else:
            x = self.model(x)

        x = x.view(-1, 1024)
        x = self.fc(x)
        return x
Example #22
    def forward(self,
                x: torch.Tensor,
                use_checkpoint: bool = False) -> torch.Tensor:
        if use_checkpoint:
            if not x.requires_grad:
                x = x.clone().requires_grad_()
            return checkpoint_sequential(self.d_blocks, len(self.d_blocks), x)
        else:
            return self.d_blocks(x)
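
The x.clone().requires_grad_() guard above matters because the (historically default) reentrant checkpoint implementation only sees the tensors passed to it, not the module parameters used inside; if no input requires grad, the checkpointed segment is silently cut out of the graph. A sketch of the failure mode being guarded against, with illustrative sizes:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    blocks = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
    x = torch.randn(2, 4)  # e.g. raw data: requires_grad=False

    out = checkpoint_sequential(blocks, 2, x)  # PyTorch warns: no input requires grad
    out.sum().backward()
    print(blocks[0].weight.grad)               # None: the checkpointed segment was cut

    blocks.zero_grad()
    x = x.clone().requires_grad_()             # the guard used in the example
    out = checkpoint_sequential(blocks, 2, x)
    out.sum().backward()
    print(blocks[0].weight.grad is not None)   # True: gradients flow everywhere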
Example #23
    def Test(self):
        """
        Test network on test set
        """
        print('Testing:')
        # set mode eval
        torch.cuda.empty_cache()
        self.network.eval()

        # prepare test data
        test_loader = DataLoader(self.datasets['test'],
                                 batch_size=self.params.test_batch,
                                 shuffle=False,
                                 num_workers=self.params.dataloader_workers)
        test_size = len(self.datasets['test'])
        if test_size % self.params.test_batch != 0:
            total_batch = test_size // self.params.test_batch + 1
        else:
            total_batch = test_size // self.params.test_batch

        # test for one epoch
        for batch_idx, batch in enumerate(test_loader):
            self.pb.click(batch_idx, total_batch)
            image, label, name = batch['image'], batch['label'], batch[
                'label_name']
            image_cuda, label_cuda = image.cuda(), label.cuda()
            if self.params.should_split:
                image_cuda.requires_grad_()
                out = checkpoint_sequential(self.network, self.params.split,
                                            image_cuda)
            else:
                out = self.network(image_cuda)

            for i in range(self.params.test_batch):
                idx = batch_idx * self.params.test_batch + i
                id_map = logits2trainId(out[i, ...])
                color_map = trainId2color(self.params.logdir,
                                          id_map,
                                          name=name[i],
                                          save=False)
                #trainId2LabelId(self.params.logdir, id_map, name=name[i])
                image_orig = image[i].numpy().transpose(1, 2, 0)
                image_orig = image_orig * 255
                image_orig = image_orig.astype(np.uint8)

                image_orig = cv2.cvtColor(image_orig,
                                          cv2.COLOR_BGR2RGB).transpose(
                                              (2, 0, 1))
                color_map = cv2.cvtColor(color_map,
                                         cv2.COLOR_BGR2RGB).transpose(
                                             (2, 0, 1))

                self.summary_writer.add_image('test/img_%d/orig' % idx,
                                              image_orig, idx)
                self.summary_writer.add_image('test/img_%d/seg' % idx,
                                              color_map, idx)
Example #24
    def forward(self, x):
        # Use gradient checkpointing. segments must not be too large (the
        # workable value depends on the model); otherwise checkpoint_sequential
        # fails in "for start in range(0, segment_size * (segments - 1),
        # segment_size)" with: ValueError: range() arg 3 must not be zero.
        segments = 10
        out = x.detach()
        out.requires_grad = True
        out = checkpoint.checkpoint_sequential(self.cnn, segments, out)
        out = out.view(out.size()[0], -1)  # flatten
        return self.fc(out)
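
The ValueError quoted in the comment above comes from how checkpoint_sequential splits the work: it computes segment_size = len(functions) // segments, which hits zero once segments exceeds the number of modules, at least in the PyTorch versions matching the quoted traceback. A quick illustration with a throwaway model:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    cnn = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))  # 3 modules
    x = torch.randn(2, 4, requires_grad=True)

    y = checkpoint_sequential(cnn, 3, x)  # fine: one module per segment
    try:
        checkpoint_sequential(cnn, 4, x)  # segments > len(cnn)
    except ValueError as err:
        print(err)  # range() arg 3 must not be zero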
Example #25
    def forward_features(self, x):
        x = self.effnet_model.conv_stem(x)
        x = self.effnet_model.bn1(x)
        x = self.effnet_model.act1(x)
        x = checkpoint_sequential(self.effnet_model.blocks,
                                  self.checkpoint_nchunks, x)
        x = self.effnet_model.conv_head(x)
        x = self.effnet_model.bn2(x)
        x = self.effnet_model.act2(x)
        return x
Example #26
def auto_grad_checkpoint(layer, x, chunks=3):
    use_grad_checkpoint = getattr(auto_grad_checkpoint, 'use_grad_checkpoint',
                                  False)
    chunks = getattr(auto_grad_checkpoint, 'chunks',
                     chunks)  # for globally set chunks
    chunks = min(len(layer), chunks)
    need_grad = next(layer.parameters()).requires_grad
    if use_grad_checkpoint and need_grad and chunks > 0:
        return checkpoint_sequential(layer, chunks, x)
    return layer(x)
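
A usage sketch for this helper: since it reads use_grad_checkpoint and chunks off the function object itself, checkpointing can be toggled globally without threading a flag through every call site. The model and sizes below are illustrative:

    import torch
    import torch.nn as nn

    layer = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    x = torch.randn(2, 8, requires_grad=True)

    y = auto_grad_checkpoint(layer, x)               # plain forward by default

    auto_grad_checkpoint.use_grad_checkpoint = True  # enable globally
    auto_grad_checkpoint.chunks = 2                  # optional global override
    y = auto_grad_checkpoint(layer, x)               # now runs checkpointed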
Example #27
    def val_one_epoch(self):
        """
        Validate network in one epoch every m training epochs,
            m is defined in params.val_every
        """
        # TODO: add IoU compute function
        print('Validating:')

        # set mode eval
        self.network.eval()

        # prepare data
        val_loss = 0
        val_loader = DataLoader(self.datasets['val'],
                                batch_size=self.params.val_batch,
                                shuffle=self.params.shuffle,
                                drop_last=True,
                                num_workers=self.params.dataloader_workers)
        val_size = len(self.datasets['val'])
        if val_size % self.params.val_batch != 0:
            total_batch = val_size // self.params.val_batch + 1
        else:
            total_batch = val_size // self.params.val_batch

        # validate through dataset
        for batch_idx, batch in enumerate(val_loader):
            self.pb.click(batch_idx, total_batch)
            image, label = batch['image'], batch['label']
            image_cuda, label_cuda = image.cuda(), label.cuda()

            # checkpoint split
            if self.params.should_split:
                image_cuda.requires_grad_()
                out = checkpoint_sequential(self.network, self.params.split,
                                            image_cuda)
            else:
                out = self.network(image_cuda)

            loss = self.loss_fn(out, label_cuda)

            val_loss += loss.item()

            # record first loss
            if self.val_loss == []:
                self.val_loss.append(val_loss)
                self.summary_writer.add_scalar('loss/val_loss', val_loss, 0)

        self.pb.close()
        val_loss /= total_batch
        self.val_loss.append(val_loss)

        # add to summary
        self.summary_writer.add_scalar('loss/val_loss', val_loss, self.epoch)
Example #28
    def forward(self, x):
        x = self.input_layer(x)
        if self.use_checkpoint:
            # x = checkpoint_sequential(self.input_layer, 2, x)
            x = checkpoint_sequential(self.body, self.chunks, x)
            # x = checkpoint_sequential(self.output_layer, 2, x)
        else:
            # x = self.input_layer(x)
            x = self.body(x)
            # x = self.output_layer(x)
        x = self.output_layer(x)
        return l2_norm(x)
Example #29
    def test_checkpoint_sequential_deprecated_no_args(self):
        class Noop(nn.Module):
            def forward(self):
                pass

        model = nn.Sequential(Noop())

        self.assertWarnsRegex(
            lambda: checkpoint_sequential(model, 1),
            'deprecated',
            'checkpoint_sequential with no args should be deprecated',
        )
Example #30
    def forward(self, input):
        embed = self.embedding(input)
        input_var = embed.detach()
        input_var.requires_grad = True
        out = checkpoint(self.bilstm, input_var)
        out1 = out[0].permute(1, 2, 0)
        pooling = kmax_pooling(out1, 2, self.kmax_pooling)  # batch * hidden * kmax
        flatten = pooling.view(pooling.size(0), -1)
        out2 = checkpoint_sequential(self.fc, 4, flatten)
        logits = self.linear(out2)
        return logits
Example #31
    def forward(self, input_t):
        """
        Input has shape (batch_size, 1, audio_length).
        Output has shape (batch_size, audio_length).
        """
        batch_size = input_t.shape[0]
        assert input_t.shape[1] == 1
        audio_length = input_t.shape[2]
        # Use gradient checkpointing to save GPU memory.
        modules = [m for m in self.convs._modules.values()]
        conv_t = checkpoint_sequential(modules, GRAD_CHECKPOINT_SEGMENTS, input_t)
        conv_t = conv_t.squeeze(dim=1)
        return self.tanh(conv_t)
Example #32
    def forward(self, x):
        if self.checkpoint is True:
            # modules = [module for k, module in self._modules.items()][0]
            input = x.detach()
            input.requires_grad = True
            # input = self.first_layer(input)
            input = checkpoint_sequential(self.feature_layers, 2, input)
        else:
            input = self.feature_layers(x)  # fix: 'input' was not yet defined on this branch

        input = input.view(input.size(0), -1)
        input = self.classifier(input)
        return input
Example #33
    def test_checkpoint_valid(self):
        model = nn.Sequential(
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.ReLU(),
            nn.Linear(20, 5),
            nn.ReLU()
        )

        input_var = torch.randn(1, 100, requires_grad=True)

        # checkpointed
        chunks = 2
        modules = list(model.children())
        out = checkpoint_sequential(modules, chunks, input_var)
        with self.assertRaisesRegex(RuntimeError, "Checkpointing is not compatible"):
            torch.autograd.grad(
                outputs=[out], grad_outputs=[torch.ones(1, 5)], inputs=[input_var], create_graph=True
            )
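
The incompatibility asserted here is specific to the reentrant implementation. In newer PyTorch releases checkpoint_sequential accepts use_reentrant=False (availability is version-dependent), and the non-reentrant variant does compose with torch.autograd.grad. A hedged sketch:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 5))
    x = torch.randn(1, 100, requires_grad=True)

    # Non-reentrant checkpointing (newer PyTorch only).
    out = checkpoint_sequential(model, 2, x, use_reentrant=False)
    (grad_x,) = torch.autograd.grad(out.sum(), x)  # no RuntimeError here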