import torch
import torch.nn as nn
import numpy as np
from torch import optim
from model import UNet
from torchvision import models

# Build the network and wrap it for multi-GPU training. Note that
# nn.DataParallel prefixes every parameter name with "module.", which the
# freezing logic below relies on.
model = UNet(3, 1)
model = nn.DataParallel(model)

# Partial (transfer) loading: keep only checkpoint entries whose key AND
# shape match the current model, so load_state_dict cannot fail on renamed
# or resized layers. map_location='cpu' makes this work on CPU-only hosts
# even if the checkpoint was saved from a GPU run.
model_dict = model.state_dict()
pretrained_dict = torch.load('CP50.pth', map_location='cpu')
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.size() == model_dict[k].size()
}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

# Freeze the first two conv blocks; everything else stays trainable.
# startswith on full prefixes replaces the fragile fixed-width slice k[:12].
FROZEN_PREFIXES = ('module.conv1', 'module.conv2')
train_params = []
for name, param in model.named_parameters():
    if name.startswith(FROZEN_PREFIXES):
        param.requires_grad = False
    else:
        train_params.append(name)

# Sanity check: every parameter excluded from train_params must be frozen.
for name, param in model.named_parameters():
    if name not in train_params:
        print(name, param.requires_grad)
def UNet_train(self):
    """Train a UNet segmentation model on self.train_imgs / self.train_labels.

    Logs per-epoch train/val loss and accuracy to the module-level `writer`
    (TensorBoard) and checkpoints the best model (by validation accuracy)
    to ``self.model_path + "/model.pickle"``.

    Relies on module-level globals: ``args``, ``device``, ``writer``,
    ``time``, ``calc_elapsed_time`` — TODO(review): confirm all are in scope
    at the call site.
    """
    model = UNet(in_ch=args.in_ch, out_ch=args.out_ch,
                 kernel_size=args.kernel_size).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.99)
    criterion = nn.CrossEntropyLoss()

    iters = np.ceil(self.train_imgs.size(0) / args.batch_size).astype(int)
    print("\nSteps per epoch = {}\n".format(iters))

    best_acc = 0
    test_imgs = self.test_imgs
    test_labels = self.test_labels

    print("="*70 + "\n\t\t\t Training Network\n" + "="*70)
    # Wall-clock timer for the whole run. Kept under its own name so the
    # batch-slicing indices below cannot shadow it (the original code reused
    # `start` for the batch offset, which corrupted the elapsed-time report).
    run_start = time.time()

    for epoch in range(args.epochs):
        print(epoch)
        train_loss = []

        # Shuffle images and labels with the same permutation each epoch.
        permute_idxs = np.random.permutation(len(self.train_labels))
        train_imgs = self.train_imgs[permute_idxs]
        train_labels = self.train_labels[permute_idxs]

        for step in range(iters):
            lo = step * args.batch_size
            hi = (step + 1) * args.batch_size

            # Slice out one mini-batch (last batch may be smaller).
            train_batch_imgs = train_imgs[lo:hi].float()
            train_batch_labels = train_labels[lo:hi].long()

            optimizer.zero_grad()
            out = model(train_batch_imgs)

            # Flatten predictions to (N*H*W, C) and labels to (N*H*W,) for
            # CrossEntropyLoss. `reshape` replaces the deprecated
            # Tensor.resize, which could silently reinterpret memory.
            out = out.reshape(
                train_batch_imgs.size(0) * args.out_height * args.out_breadth,
                args.out_ch)
            train_batch_labels = train_batch_labels.reshape(
                train_batch_labels.size(0) * args.out_height * args.out_breadth)

            loss = criterion(out, train_batch_labels)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        avg_train_loss = round(np.mean(train_loss), 4)

        # Pixel accuracy on the LAST batch of the epoch only (cheap proxy).
        # Denominator is the actual number of pixels scored — dividing a
        # single-batch count by the whole epoch (as before) under-reported
        # accuracy by a factor of `iters`.
        preds = torch.max(out.data, 1)[1]
        correct = preds.long().eq(train_batch_labels.long()).cpu().sum().item()
        train_acc = correct / train_batch_labels.numel()

        writer.add_scalar('Train/Loss', avg_train_loss, epoch + 1)
        writer.add_scalar('Train/Accuracy', train_acc, epoch + 1)

        # Histogram only the parameters that are actually being trained.
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            writer.add_histogram('epochs/' + name, param.data.view(-1),
                                 global_step=epoch + 1)

        if epoch % args.eval_every == 0:
            avg_test_loss, test_acc = self.get_val_results(
                test_imgs, test_labels, model)
            writer.add_scalar('Test/Loss', avg_test_loss, epoch + 1)
            writer.add_scalar('Test/Accuracy', test_acc, epoch + 1)
            if test_acc > best_acc:
                best_acc = test_acc
                print("\nNew High Score! Saving model...\n")
                torch.save(model.state_dict(),
                           self.model_path + "/model.pickle")

        end = time.time()
        h, m, s = calc_elapsed_time(run_start, end)
        print("\nEpoch: {}/{}, Train_loss = {:.4f}, Train_acc = {:.4f}, "
              "Val_loss = {:.4f}, Val_acc = {:.4f}"
              .format(epoch + 1, args.epochs, avg_train_loss, train_acc,
                      avg_test_loss, test_acc))

    print("\n" + "="*50 + "\n\t Training Done \n")
    print("\nBest Val accuracy = ", best_acc)
# Resume training from a saved checkpoint and build the SGD optimizer.
print("=> load checkpoint : {}".format(args.resume))
# map_location='cpu' lets a GPU-saved checkpoint be resumed on a CPU-only
# host; the model is moved to CUDA afterwards when args.cuda is set.
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint["state_dict"])
print("start epoch:{}, loss:{}".format(checkpoint["epoch"], checkpoint["loss"]))
start_epoch += checkpoint["epoch"]

if args.cuda:
    model = model.cuda()
    if args.mGPU:
        model = DataParallel(model)

# Split parameters once: biases get 2x learning rate and no weight decay,
# everything else gets the base LR plus L2 regularization (common trick for
# fully-convolutional nets). One pass over named_parameters instead of two.
bias_params, weight_params = [], []
for name, param in model.named_parameters():
    if name.endswith('bias'):
        bias_params.append(param)
    else:
        weight_params.append(param)

optimizer = optim.SGD(
    [
        {'params': bias_params, 'lr': 2 * args.lr},
        {'params': weight_params, 'lr': args.lr, 'weight_decay': 5e-4},
    ],
    momentum=0.9)

model.train()