def forward(self, input):
    """Run the generator and return ``(out, cam_logit, heatmap)``.

    The down-sampling and final up-sampling stacks are executed through
    ``checkpoint_sequential`` with 8 segments, the ``FC`` stack with 2.
    ``cam_logit`` concatenates the logits of the average- and max-pooled
    branches; ``heatmap`` is the channel-wise sum of the fused feature map.
    """
    functional = torch.nn.functional

    feat = checkpoint_sequential(self.DownBlock, 8, input)
    batch = feat.shape[0]

    # Average-pooled branch: logit plus features re-weighted by the first
    # parameter of ``gap_fc``.
    gap_logit = self.gap_fc(functional.adaptive_avg_pool2d(feat, 1).view(batch, -1))
    gap_weight = list(self.gap_fc.parameters())[0]
    gap = feat * gap_weight.unsqueeze(2).unsqueeze(3)

    # Max-pooled branch, same recipe with ``gmp_fc``.
    gmp_logit = self.gmp_fc(functional.adaptive_max_pool2d(feat, 1).view(batch, -1))
    gmp_weight = list(self.gmp_fc.parameters())[0]
    gmp = feat * gmp_weight.unsqueeze(2).unsqueeze(3)

    cam_logit = torch.cat([gap_logit, gmp_logit], 1)
    x = self.relu(self.conv1x1(torch.cat([gap, gmp], 1)))
    heatmap = torch.sum(x, dim=1, keepdim=True)

    # In light mode the FC stack sees the spatially pooled map, otherwise the
    # full flattened map.
    fc_source = functional.adaptive_avg_pool2d(x, 1) if self.light else x
    x_ = checkpoint_sequential(self.FC, 2, fc_source.view(fc_source.shape[0], -1))

    gamma, beta = self.gamma(x_), self.beta(x_)
    for block_no in range(1, self.n_blocks + 1):
        x = getattr(self, 'UpBlock1_' + str(block_no))(x, gamma, beta)

    out = checkpoint_sequential(self.UpBlock2, 8, x)
    return out, cam_logit, heatmap
def forward(self, x, training):
    """Run the backbone and return ``(x, x1, x2)``.

    ``x2`` is a copy of the stem output, ``x1`` a copy of the output of
    ``layer2`` and ``x`` the output of ``layer4``. When ``self.checkpoint``
    and ``training`` are both truthy every stage is run through
    ``checkpoint_sequential`` with 3 segments.
    """
    use_checkpoint = self.checkpoint and training

    def run_stage(stage, tensor):
        # Same stage, with or without activation checkpointing.
        if use_checkpoint:
            return checkpoint_sequential(stage, 3, tensor)
        return stage(tensor)

    x = self.relu(self.bn1(self.conv1(x)))
    if self.pooling:
        x = self.maxpool(x)
    x2 = x.clone()

    x = run_stage(self.layer1, x)
    x = run_stage(self.layer2, x)
    x1 = x.clone()
    x = run_stage(self.layer3, x)
    x = run_stage(self.layer4, x)
    return x, x1, x2
def forward(self, x):
    """Fuse early encoder features with up-sampled late features.

    The first three entries of ``self.encoder.features`` are run in order and
    their outputs kept (the first two are passed through
    ``feature_avg_pool``) and concatenated on dim 1. The remaining entries
    are run, their outputs concatenated and pooled, up-sampled by 2 and
    merged with the early features before the output convolutions.
    Every encoder entry is executed via ``checkpoint_sequential`` with one
    segment per sub-module.
    """
    early_outputs = []
    for stage in self.encoder.features[:3]:
        x = checkpoint_sequential(stage, len(list(stage)), x)
        early_outputs.append(x)

    early_outputs[0] = self.feature_avg_pool(early_outputs[0])
    early_outputs[1] = self.feature_avg_pool(early_outputs[1])
    early = torch.cat(early_outputs, dim=1)

    late_outputs = []
    for stage in self.encoder.features[3:]:
        x = checkpoint_sequential(stage, len(list(stage)), x)
        late_outputs.append(x)

    x = self.feature_pooling(torch.cat(late_outputs, dim=1))
    # F.upsample is deprecated; F.interpolate is its drop-in replacement.
    x = F.interpolate(x, scale_factor=2, mode='bilinear')
    x = self.up_4x_conv(x)

    # Concatenate the early (input-side) features with the up-sampled ones.
    early = self.feature_4x_conv(early)
    x = torch.cat([early, x], dim=1)

    x = self.smooth_feature_4x_conv(x)
    x = self.out_conv(x)
    return x
def shared_step(self, batch, batch_idx):
    """Encode and project both views of ``batch`` and return ``(loss, acc)``.

    Args:
        batch: a pair ``(img1, img2)``.
        batch_idx: unused here.

    When ``self.hparams.save_mem`` is set the encoder is run through
    ``checkpoint_sequential`` with 9 segments; the inputs are first turned
    into fresh leaves that require grad so the checkpointed segments receive
    gradients.
    """
    img1, img2 = batch

    # ENCODE
    # encode -> representations
    # (b, 3, 32, 32) -> (b, 2048, 2, 2)  # shapes from the original note — TODO confirm
    if self.hparams.save_mem:
        # ``torch.autograd.Variable`` is deprecated; detach() + requires_grad_()
        # yields the same kind of new leaf without touching the caller's tensors.
        img1 = img1.detach().requires_grad_(True)
        img2 = img2.detach().requires_grad_(True)
        h1 = checkpoint_sequential(self.encoder, 9, img1)
        h2 = checkpoint_sequential(self.encoder, 9, img2)
    else:
        h1 = self.encoder(img1)
        h2 = self.encoder(img2)

    # Some encoders return a list of feature maps; keep the last one.
    if isinstance(h1, list):
        h1 = h1[-1]
        h2 = h2[-1]

    # PROJECT
    # img -> E -> h -> || -> z
    z1 = self.projection(h1)
    z2 = self.projection(h2)

    loss, sim = self.nt_xent_loss(z1, z2, self.hparams.loss_temperature)
    acc = self.cale_acc(sim)

    return loss, acc
def test_checkpoint(self):
    """Checkpointed runs must match the plain run in output and gradients.

    The model is run once normally, once through ``checkpoint_sequential``
    given a list of its children, and once given the ``nn.Sequential``
    itself; outputs, input gradients and parameter gradients are compared.
    """
    model = nn.Sequential(
        nn.Linear(100, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 5),
        nn.ReLU()
    )
    x = torch.randn(1, 100, requires_grad=True)

    def run(forward_fn, inp):
        # One forward/backward pass; snapshot output and all gradients.
        out = forward_fn(inp)
        out_value = out.data.clone()
        model.zero_grad()
        out.sum().backward()
        param_grads = {
            name: param.grad.data.clone()
            for name, param in model.named_parameters()
        }
        return out_value, inp.grad.data.clone(), param_grads

    def fresh_leaf():
        # New leaf sharing x's data so its .grad starts empty.
        leaf = x.detach()
        leaf.requires_grad = True
        return leaf

    # not checkpointed
    out_not_checkpointed, input_grad, grad_not_checkpointed = run(model, x)

    def compare(result):
        out_checkpointed, checkpoint_input_grad, grad_checkpointed = result
        self.assertEqual(out_checkpointed, out_not_checkpointed)
        self.assertEqual(input_grad, checkpoint_input_grad)
        for name in grad_checkpointed:
            self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])

    # checkpointed model by passing list of modules
    chunks = 2
    modules = list(model.children())
    compare(run(lambda t: checkpoint_sequential(modules, chunks, t), fresh_leaf()))

    # checkpointed by passing sequential directly
    compare(run(lambda t: checkpoint_sequential(model, 2, t), fresh_leaf()))
def test_checkpoint(self):
    """Compare a plain run against two checkpointed runs of the same model.

    ``checkpoint_sequential`` is exercised with a list of modules and with
    the ``nn.Sequential`` itself; in both cases the output, the input
    gradient and every parameter gradient must equal the plain run.
    """
    model = nn.Sequential(
        nn.Linear(100, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 5),
        nn.ReLU()
    )
    x = torch.randn(1, 100, requires_grad=True)

    def snapshot_grads():
        grads = {}
        for name, param in model.named_parameters():
            grads[name] = param.grad.data.clone()
        return grads

    # Reference: not checkpointed.
    reference = model(x)
    out_not_checkpointed = reference.data.clone()
    model.zero_grad()
    reference.sum().backward()
    grad_not_checkpointed = snapshot_grads()
    input_grad = x.grad.data.clone()

    # First a list of modules, then the sequential itself; 2 segments each.
    for functions in (list(model.children()), model):
        input_var = x.detach()
        input_var.requires_grad = True

        result = checkpoint_sequential(functions, 2, input_var)
        out_checkpointed = result.data.clone()
        model.zero_grad()
        result.sum().backward()
        grad_checkpointed = snapshot_grads()
        checkpoint_input_grad = input_var.grad.data.clone()

        # compare the output, input and parameters gradients
        self.assertEqual(out_checkpointed, out_not_checkpointed)
        self.assertEqual(input_grad, checkpoint_input_grad)
        for name in grad_checkpointed:
            self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
def forward(self, input1, input2):
    """Merge ``input1`` with the (optionally up-sampled) ``input2``.

    ``input2`` goes through ``up_modules`` (checkpointed, 2 segments) when
    ``self.upsampling is True``, otherwise through ``self.up``. ``input1`` is
    bilinearly resized to the resulting spatial size, both are concatenated
    on dim 1 and passed through ``conv_modules`` (checkpointed, 2 segments).
    """
    n_segments = 2

    if self.upsampling is True:
        upsampled = checkpoint_sequential(self.up_modules, n_segments, input2)
    else:
        upsampled = self.up(input2)

    target_size = upsampled.size()[2:]
    resized = nn.functional.interpolate(
        input1, target_size, mode='bilinear', align_corners=True
    )
    merged = torch.cat([resized, upsampled], 1)
    return checkpoint_sequential(self.conv_modules, n_segments, merged)
def test_checkpoint_sequential_deprecated_no_args(self):
    """Calling ``checkpoint_sequential`` without an input must raise TypeError."""

    class Noop(nn.Module):
        # Module whose forward takes no input at all.
        def forward(self):
            pass

    wrapped = nn.Sequential(Noop())
    n_segments = 1
    with self.assertRaises(TypeError):
        checkpoint_sequential(wrapped, n_segments)
def test_checkpoint_sequential_deprecated_multiple_args(self):
    """Passing two inputs to ``checkpoint_sequential`` must raise TypeError."""

    class Two(nn.Module):
        # Module whose forward takes and returns two values.
        def forward(self, a, b):
            return a, b

    wrapped = nn.Sequential(Two())
    first = torch.randn(1, 100, requires_grad=True)
    second = torch.randn(1, 100, requires_grad=True)
    n_segments = 1

    with self.assertRaises(TypeError):
        checkpoint_sequential(wrapped, n_segments, first, second)
def forward(self, x):
    """Compute features for ``x``.

    The stem, the four checkpointed stages (10 segments each), ``bn2``,
    flatten and dropout run inside ``torch.cuda.amp.autocast(self.fp16)``.
    The final ``fc`` and ``features`` layers run outside it, on a float32
    copy when ``self.fp16`` is set.
    """
    with torch.cuda.amp.autocast(self.fp16):
        x = self.prelu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = checkpoint_sequential(stage, 10, x)
        x = self.bn2(x)
        x = self.dropout(torch.flatten(x, 1))

    fc_input = x.float() if self.fp16 else x
    return self.features(self.fc(fc_input))
def test_checkpoint_trigger(self):
    """Count forward calls to see which modules are re-run during backward.

    Ten counting modules are split into 2 segments. After the forward pass
    each has run once; after backward the first half is expected to have run
    twice and the second half still once.
    """

    class Net(nn.Module):

        def __init__(self):
            super(Net, self).__init__()
            self.counter = 0

        def forward(self, input_var):
            self.counter += 1
            return input_var

    def expect_count(mods, expected):
        for mod in mods:
            self.assertEqual(mod.counter, expected)

    # checkpointed
    modules = [Net() for _ in range(10)]
    half = len(modules) // 2
    expect_count(modules, 0)

    input_var = torch.randn(3, 4, requires_grad=True)
    out = checkpoint_sequential(modules, 2, input_var)
    expect_count(modules, 1)

    out.sum().backward()
    expect_count(modules[:half], 2)
    expect_count(modules[half:], 1)
def forward(self, x):
    """Classify ``x``: checkpointed features, flatten, dropout, classifier.

    ``self.features`` is run through ``checkpoint_sequential`` with 2
    segments; dropout (p=0.5) is only active when ``self.training`` is set.
    """
    feats = checkpoint_sequential(self.features, 2, x)
    flat = feats.view(feats.size(0), -1)
    dropped = F.dropout(flat, p=0.5, training=self.training)
    return self.classifier(dropped)
def test_checkpoint_trigger(self):
    """Verify how often each module's forward runs under checkpointing.

    With 10 modules in 2 segments, every module runs once in the forward
    pass; the backward pass is expected to re-run only the first 5.
    """

    class Net(nn.Module):

        def __init__(self):
            super(Net, self).__init__()
            self.counter = 0

        def forward(self, input_var):
            self.counter += 1
            return input_var

    # checkpointed
    modules = [Net() for _ in range(10)]
    for module in modules:
        self.assertEqual(module.counter, 0)

    input_var = torch.randn(3, 4, requires_grad=True)
    out = checkpoint_sequential(modules, 2, input_var)
    for module in modules:
        self.assertEqual(module.counter, 1)

    out.sum().backward()
    split = len(modules) // 2
    # First segment re-ran during backward, last segment did not.
    expected_counts = [2] * split + [1] * (len(modules) - split)
    for module, expected in zip(modules, expected_counts):
        self.assertEqual(module.counter, expected)
def forward(self, input_features):
    """Run ``main_net``, global-average-pool to 1x1 and classify.

    ``main_net`` is executed through ``cp.checkpoint_sequential`` with 4
    segments when ``self.memory_efficient`` is set, directly otherwise.
    """
    if self.memory_efficient:
        feats = cp.checkpoint_sequential(self.main_net, 4, input_features)
    else:
        feats = self.main_net(input_features)

    pooled = F.adaptive_avg_pool2d(feats, (1, 1))
    flat = pooled.view(feats.size(0), -1)
    return self.classifier(flat)
def forward(self, input_var, chunks=3):
    """Run the first registered sub-module with checkpointing, then ``fc``.

    The first entry of ``self._modules`` is executed through
    ``checkpoint_sequential`` with ``chunks`` segments. Its output is then
    regrouped as ``(N // 5, 5, -1)`` and averaged over the group of 5 before
    the final ``fc`` layer.
    """
    first_module = list(self._modules.values())[0]
    feats = checkpoint_sequential(first_module, chunks, input_var)

    # Average every 5 consecutive rows into one.
    n_groups = feats.size(0) // 5
    averaged = feats.view(n_groups, 5, -1).mean(1)
    return self.fc(averaged)
def forward(self, x):
    """Return ``(logits, probas)`` for ``x``.

    The classifier runs via ``cp.checkpoint_sequential`` (one segment) when
    ``self.is_memory_efficient`` is set, directly otherwise. Its output is
    reshaped to ``(batch, num_classes)``; ``probas`` is the softmax of the
    logits over dim 1.
    """
    # NOTE(review): a single segment may mean nothing is actually
    # checkpointed — confirm against the checkpoint_sequential docs.
    scores = (
        cp.checkpoint_sequential(self.classifier, segments=1, input=x)
        if self.is_memory_efficient
        else self.classifier(x)
    )
    logits = scores.view(scores.size(0), self.num_classes)
    return logits, F.softmax(logits, dim=1)
def train_one_epoch(self):
    """
    Train network in one epoch.

    Iterates once over ``self.datasets['train']``, optimising
    ``self.network`` with ``self.loss_fn`` / ``self.opt``. The mean batch
    loss is appended to ``self.train_loss`` and written to the summary
    writer under ``loss/train_loss`` at step ``self.epoch``. On the very
    first call the first batch's loss is additionally recorded at step 0.
    """
    print('Training......')

    # set mode train
    self.network.train()

    # prepare data
    train_loss = 0
    train_loader = DataLoader(self.datasets['train'],
                              batch_size=self.params.train_batch,
                              shuffle=self.params.shuffle,
                              drop_last=True,
                              num_workers=self.params.dataloader_workers)
    # drop_last=True discards a trailing partial batch, so the loader length
    # (not a rounded-up dataset_size / batch_size) is the number of batches
    # actually seen; using it keeps the progress total and the mean correct.
    total_batch = len(train_loader)

    # train through dataset
    for batch_idx, batch in enumerate(train_loader):
        self.pb.click(batch_idx, total_batch)
        image, label = batch['image'], batch['label']
        image_cuda, label_cuda = image.cuda(), label.cuda()

        # checkpoint split
        if self.params.should_split:
            # The input must require grad for the checkpointed segments.
            image_cuda.requires_grad_()
            out = checkpoint_sequential(self.network, self.params.split, image_cuda)
        else:
            out = self.network(image_cuda)
        loss = self.loss_fn(out, label_cuda)

        # optimize
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        # accumulate
        train_loss += loss.item()

        # record first loss (only while the history is still empty)
        # NOTE(review): placement inside the loop is reconstructed from a
        # flattened source — confirm.
        if not self.train_loss:
            self.train_loss.append(train_loss)
            self.summary_writer.add_scalar('loss/train_loss', train_loss, 0)

    self.pb.close()
    # Guard against a dataset smaller than one batch (zero iterations).
    if total_batch:
        train_loss /= total_batch
    self.train_loss.append(train_loss)

    # add to summary
    self.summary_writer.add_scalar('loss/train_loss', train_loss, self.epoch)
def _head_to_tail(self, pool5):
    """Run ``resnet.layer4`` on ``pool5`` and average over dims 3 and 2.

    ``cfg.layer4`` is switched on for the duration of the call when
    ``cfg.MIX_LOCATION != 0`` and is always reset to ``False`` afterwards,
    including when the forward pass raises.
    """
    if cfg.MIX_LOCATION != 0:
        cfg.layer4 = True

    num_segments = 3
    try:
        # NOTE(review): checkpointed segments are presumably re-run during
        # backward, after cfg.layer4 has been reset — confirm nothing inside
        # layer4 depends on the flag at that point.
        fc7 = checkpoint_sequential(self.resnet.layer4, num_segments, pool5)
        fc7 = fc7.mean(3).mean(2)  # average pooling after layer4
    finally:
        # Never leave the global flag set if the forward pass fails.
        cfg.layer4 = False
    return fc7
def forward(self, x):
    """Apply four parallel branches to ``x``, concatenate and fuse.

    The outputs of ``conv_1`` .. ``conv_4`` are concatenated on dim 1 and
    passed through ``conv_1x1`` via ``checkpoint_sequential`` (one segment).
    """
    branches = (self.conv_1, self.conv_2, self.conv_3, self.conv_4)
    stacked = torch.cat([branch(x) for branch in branches], dim=1)
    # NOTE(review): a single segment may mean nothing is actually
    # checkpointed — confirm against the checkpoint_sequential docs.
    return checkpoint_sequential(self.conv_1x1, 1, stacked)
def forward(self, x, chunks=None):
    """Run the first registered sub-module, then ReLU, 7x7 avg-pool, classifier.

    Args:
        x: input tensor; it is detached and re-marked as requiring grad
            before being fed to the feature module.
        chunks: number of segments for ``checkpoint_sequential``. When
            ``None`` (the default) the feature module is called directly
            without checkpointing; previously the default crashed inside
            ``checkpoint_sequential``.

    Returns:
        The output of ``self.classifier``.
    """
    modules = list(self._modules.values())[0]

    # Fresh leaf that requires grad, as needed by the checkpointed segments.
    input_var = x.detach()
    input_var.requires_grad = True

    if chunks is None:
        input_var = modules(input_var)
    else:
        input_var = checkpoint_sequential(modules, chunks, input_var)

    input_var = F.relu(input_var, inplace=True)
    input_var = F.avg_pool2d(input_var, kernel_size=7, stride=1).view(input_var.size(0), -1)
    input_var = self.classifier(input_var)
    return input_var
def forward(self, x):
    """Run ``self.model``, reshape to ``(-1, 1024)`` and apply ``self.fc``.

    ``self.model`` is executed through ``checkpoint_sequential`` with 9
    segments when ``self.save_grad`` is set, directly otherwise.
    """
    feats = checkpoint_sequential(self.model, 9, x) if self.save_grad else self.model(x)
    return self.fc(feats.view(-1, 1024))
def forward(self, x: torch.Tensor, use_checkpoint: bool = False) -> torch.Tensor:
    """Run ``self.d_blocks`` on ``x``, optionally with checkpointing.

    With ``use_checkpoint`` every block becomes its own segment. If ``x``
    does not require grad, a clone that does is used instead, so the
    caller's tensor is left untouched.
    """
    if not use_checkpoint:
        return self.d_blocks(x)

    if not x.requires_grad:
        x = x.clone().requires_grad_()
    n_segments = len(self.d_blocks)
    return checkpoint_sequential(self.d_blocks, n_segments, x)
def Test(self):
    """
    Test network on test set.

    Runs ``self.network`` over ``self.datasets['test']`` and writes, for
    every image, the original image and its colour-coded segmentation to
    the summary writer under ``test/img_<idx>/orig`` and
    ``test/img_<idx>/seg``.
    """
    print('Testing:')
    # set mode eval
    torch.cuda.empty_cache()
    self.network.eval()

    # prepare test data
    test_loader = DataLoader(self.datasets['test'],
                             batch_size=self.params.test_batch,
                             shuffle=False,
                             num_workers=self.params.dataloader_workers)
    total_batch = len(test_loader)

    # test for one epoch; no gradients are needed for inference
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            self.pb.click(batch_idx, total_batch)
            image, label, name = batch['image'], batch['label'], batch['label_name']
            image_cuda, label_cuda = image.cuda(), label.cuda()
            if self.params.should_split:
                image_cuda.requires_grad_()
                out = checkpoint_sequential(self.network, self.params.split, image_cuda)
            else:
                out = self.network(image_cuda)

            # The last batch may hold fewer than test_batch samples, so
            # iterate over what the network actually returned.
            for i in range(out.shape[0]):
                idx = batch_idx * self.params.test_batch + i
                id_map = logits2trainId(out[i, ...])
                color_map = trainId2color(self.params.logdir, id_map, name=name[i], save=False)
                # CHW -> HWC, scale by 255 and cast to uint8 for OpenCV.
                image_orig = image[i].numpy().transpose(1, 2, 0)
                image_orig = image_orig * 255
                image_orig = image_orig.astype(np.uint8)
                # Swap channel order, then back to CHW for the summary writer.
                image_orig = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB).transpose((2, 0, 1))
                color_map = cv2.cvtColor(color_map, cv2.COLOR_BGR2RGB).transpose((2, 0, 1))
                self.summary_writer.add_image('test/img_%d/orig' % idx, image_orig, idx)
                self.summary_writer.add_image('test/img_%d/seg' % idx, color_map, idx)
def forward(self, x):
    """Run ``self.cnn`` with checkpointing, flatten, then apply ``self.fc``.

    The input is detached and re-marked as requiring grad before it enters
    the checkpointed stack.
    """
    # Use checkpointing.
    # The number of segments must not be too large (the workable value
    # depends on the situation); otherwise checkpoint_sequential fails at
    #   for start in range(0, segment_size * (segments - 1), segment_size):
    #   ValueError: range() arg 3 must not be zero.
    segments = 10

    leaf = x.detach()
    leaf.requires_grad = True
    feats = checkpoint.checkpoint_sequential(self.cnn, segments, leaf)

    flat = feats.view(feats.size()[0], -1)  # flatten
    return self.fc(flat)
def forward_features(self, x):
    """Compute backbone features for ``x``.

    Runs the wrapped model's stem (``conv_stem``, ``bn1``, ``act1``), its
    ``blocks`` through ``checkpoint_sequential`` with
    ``self.checkpoint_nchunks`` segments, then the head (``conv_head``,
    ``bn2``, ``act2``).
    """
    net = self.effnet_model

    stem_out = net.act1(net.bn1(net.conv_stem(x)))
    body_out = checkpoint_sequential(net.blocks, self.checkpoint_nchunks, stem_out)
    return net.act2(net.bn2(net.conv_head(body_out)))
def auto_grad_checkpoint(layer, x, chunks=3):
    """Run ``layer`` on ``x``, using gradient checkpointing when enabled.

    Checkpointing is controlled through attributes set on this function:
    ``auto_grad_checkpoint.use_grad_checkpoint`` (default ``False``) turns
    it on, and ``auto_grad_checkpoint.chunks``, when present, overrides the
    ``chunks`` argument globally.

    Args:
        layer: a sized, callable module container (``len(layer)`` and
            ``layer.parameters()`` are used).
        x: input passed to ``layer``.
        chunks: requested number of segments, capped at ``len(layer)``.

    Returns:
        ``checkpoint_sequential(layer, chunks, x)`` when checkpointing is
        enabled, the first parameter of ``layer`` requires grad and at least
        one segment remains; ``layer(x)`` otherwise.
    """
    use_grad_checkpoint = getattr(auto_grad_checkpoint, 'use_grad_checkpoint', False)
    chunks = getattr(auto_grad_checkpoint, 'chunks', chunks)  # for globally set chunks
    chunks = min(len(layer), chunks)

    # A layer without parameters has nothing to train; treat it as not
    # needing grad instead of letting next() raise StopIteration.
    first_param = next(layer.parameters(), None)
    need_grad = first_param is not None and first_param.requires_grad

    if use_grad_checkpoint and need_grad and chunks > 0:
        return checkpoint_sequential(layer, chunks, x)
    return layer(x)
def val_one_epoch(self):
    """
    Validate network in one epoch every m training epochs,
    m is defined in params.val_every.

    Iterates once over ``self.datasets['val']`` in eval mode. The mean batch
    loss is appended to ``self.val_loss`` and written to the summary writer
    under ``loss/val_loss`` at step ``self.epoch``. On the very first call
    the first batch's loss is additionally recorded at step 0.
    """
    # TODO: add IoU compute function
    print('Validating:')

    # set mode eval
    self.network.eval()

    # prepare data
    val_loss = 0
    val_loader = DataLoader(self.datasets['val'],
                            batch_size=self.params.val_batch,
                            shuffle=self.params.shuffle,
                            drop_last=True,
                            num_workers=self.params.dataloader_workers)
    # drop_last=True discards a trailing partial batch, so the loader length
    # (not a rounded-up dataset_size / batch_size) is the number of batches
    # actually seen; using it keeps the progress total and the mean correct.
    total_batch = len(val_loader)

    # validate through dataset
    for batch_idx, batch in enumerate(val_loader):
        self.pb.click(batch_idx, total_batch)
        image, label = batch['image'], batch['label']
        image_cuda, label_cuda = image.cuda(), label.cuda()

        # checkpoint split
        if self.params.should_split:
            image_cuda.requires_grad_()
            out = checkpoint_sequential(self.network, self.params.split, image_cuda)
        else:
            out = self.network(image_cuda)

        loss = self.loss_fn(out, label_cuda)

        val_loss += loss.item()

        # record first loss (only while the history is still empty)
        # NOTE(review): placement inside the loop is reconstructed from a
        # flattened source — confirm.
        if not self.val_loss:
            self.val_loss.append(val_loss)
            self.summary_writer.add_scalar('loss/val_loss', val_loss, 0)

    self.pb.close()
    # Guard against a dataset smaller than one batch (zero iterations).
    if total_batch:
        val_loss /= total_batch
    self.val_loss.append(val_loss)

    # add to summary
    self.summary_writer.add_scalar('loss/val_loss', val_loss, self.epoch)
def forward(self, x):
    """Return the ``l2_norm`` of the network output for ``x``.

    ``input_layer`` and ``output_layer`` always run directly; only ``body``
    is run through ``checkpoint_sequential`` (``self.chunks`` segments) when
    ``self.use_checkpoint`` is set.
    """
    stem_out = self.input_layer(x)

    if self.use_checkpoint:
        body_out = checkpoint_sequential(self.body, self.chunks, stem_out)
    else:
        body_out = self.body(stem_out)

    return l2_norm(self.output_layer(body_out))
def test_checkpoint_sequential_deprecated_no_args(self):
    """Calling ``checkpoint_sequential`` without an input must warn 'deprecated'."""

    class Noop(nn.Module):
        # Module whose forward takes no input at all.
        def forward(self):
            pass

    model = nn.Sequential(Noop())

    def call_without_input():
        return checkpoint_sequential(model, 1)

    self.assertWarnsRegex(
        call_without_input,
        'deprecated',
        'checkpoint_sequential with no args should be deprecated',
    )
def forward(self, input):
    """Embed ``input``, run the checkpointed BiLSTM, k-max pool and classify.

    The embedding output is detached and re-marked as requiring grad before
    it enters the checkpointed ``bilstm``. The first element of the BiLSTM
    result is permuted with ``(1, 2, 0)``, k-max pooled over dim 2,
    flattened, sent through the checkpointed ``fc`` stack (4 segments) and
    finally through ``linear``.
    """
    lstm_input = self.embedding(input).detach()
    lstm_input.requires_grad = True

    lstm_result = checkpoint(self.bilstm, lstm_input)
    sequence = lstm_result[0].permute(1, 2, 0)

    # presumably batch * hidden * kmax — per the original note, TODO confirm
    pooled = kmax_pooling(sequence, 2, self.kmax_pooling)
    flat = pooled.view(pooled.size(0), -1)

    hidden = checkpoint_sequential(self.fc, 4, flat)
    return self.linear(hidden)
def forward(self, input_t):
    """
    Apply the convolution stack and a tanh to ``input_t``.

    Input has shape (batch_size, 1, audio_length,)
    Output has shape (batch_size, audio_length,)
    """
    batch_size = input_t.shape[0]
    assert input_t.shape[1] == 1
    audio_length = input_t.shape[2]

    # Use gradient checkpointing to save GPU memory.
    conv_layers = list(self.convs._modules.values())
    conv_out = checkpoint_sequential(conv_layers, GRAD_CHECKPOINT_SEGMENTS, input_t)

    # Drop dim 1 before the final non-linearity.
    return self.tanh(conv_out.squeeze(dim=1))
def forward(self, x):
    """Run the feature layers (optionally checkpointed) and the classifier.

    When ``self.checkpoint is True`` the input is detached, re-marked as
    requiring grad and run through ``feature_layers`` with
    ``checkpoint_sequential`` (2 segments); otherwise ``feature_layers`` is
    called directly on ``x``. The result is flattened per sample and passed
    to ``self.classifier``.
    """
    if self.checkpoint is True:
        features = x.detach()
        features.requires_grad = True
        features = checkpoint_sequential(self.feature_layers, 2, features)
    else:
        # The original read an unassigned local here instead of ``x``,
        # which raised UnboundLocalError.
        features = self.feature_layers(x)

    features = features.view(features.size(0), -1)
    return self.classifier(features)
def test_checkpoint_valid(self):
    """``torch.autograd.grad`` on a checkpointed output must raise.

    The expected failure is a ``RuntimeError`` whose message matches
    "Checkpointing is not compatible".
    """
    layers = [
        nn.Linear(100, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 5),
        nn.ReLU(),
    ]
    model = nn.Sequential(*layers)

    input_var = torch.randn(1, 100, requires_grad=True)

    # checkpointed
    chunks = 2
    out = checkpoint_sequential(list(model.children()), chunks, input_var)

    with self.assertRaisesRegex(RuntimeError, "Checkpointing is not compatible"):
        torch.autograd.grad(
            outputs=[out],
            grad_outputs=[torch.ones(1, 5)],
            inputs=[input_var],
            create_graph=True
        )