Example #1
0
import torch
import torch.nn as nn
import numpy as np
from torch import optim
from model import UNet
from torchvision import models

# Build the model, load compatible weights from a checkpoint, and freeze
# the first two conv blocks so only the remaining layers are trained.
model = UNet(3, 1)
model = nn.DataParallel(model)

# Keep only checkpoint entries whose keys exist in the current model.
# DataParallel prefixes every key with 'module.', so the checkpoint is
# presumably saved from a DataParallel-wrapped model too — TODO confirm.
model_dict = model.state_dict()
pretrained_dict = torch.load('CP50.pth')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

# Freeze conv1/conv2 and record the names of trainable parameters.
# startswith with a trailing dot avoids false prefix matches: the original
# `k[:12] == 'module.conv1'` would also freeze e.g. 'module.conv10.weight'.
train_params = []
for k, v in model.named_parameters():
    if k.startswith(('module.conv1.', 'module.conv2.')):
        v.requires_grad = False
    else:
        train_params.append(k)

# Sanity check: frozen parameters should print requires_grad == False.
for k, v in model.named_parameters():
    if k not in train_params:
        print(v.requires_grad)
Example #2
0
    def UNet_train(self):
        """Train a UNet segmentation model on ``self.train_imgs`` with SGD.

        Shuffles the training set each epoch, logs loss/accuracy and weight
        histograms to the module-level ``writer``, evaluates every
        ``args.eval_every`` epochs via ``self.get_val_results``, and saves
        the best model (by validation accuracy) to ``self.model_path``.

        Relies on module-level ``args``, ``device``, ``writer``, ``time``,
        and ``calc_elapsed_time`` defined elsewhere in the file.
        """
        model = UNet(in_ch = args.in_ch, out_ch = args.out_ch, kernel_size = args.kernel_size).to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr = args.lr, momentum=0.99)
        criterion = nn.CrossEntropyLoss()

        iters = np.ceil(self.train_imgs.size(0)/args.batch_size).astype(int)
        print("\nSteps per epoch =  {}\n".format(iters))
        best_acc = 0
        test_imgs = self.test_imgs
        test_labels = self.test_labels

        print("="*70 +"\n\t\t\t Training Network\n"+ "="*70)
        # Wall-clock start of training. BUG FIX: the original reused the name
        # `start` for the batch slice offset inside the step loop, so the
        # elapsed-time report measured from a batch index instead of a
        # timestamp. The slice bounds are named lo/hi below.
        train_start = time.time()
        for epoch in range(args.epochs):
            print(epoch)
            train_loss = []

            # Shuffle images and labels with the same permutation each epoch.
            permute_idxs = np.random.permutation(len(self.train_labels))
            train_imgs = self.train_imgs[permute_idxs]
            train_labels = self.train_labels[permute_idxs]
            for step in range(iters):
                lo = step*args.batch_size
                hi = (step+1)*args.batch_size

                # Get batch (the last batch may be smaller than batch_size).
                train_batch_imgs = train_imgs[lo:hi].float()
                train_batch_labels = train_labels[lo:hi].long()

                # Forward pass.
                optimizer.zero_grad()
                out = model(train_batch_imgs)

                # Flatten predictions to (pixels, classes) and labels to
                # (pixels,) for CrossEntropyLoss. `reshape` replaces the
                # deprecated `Tensor.resize`, which can expose uninitialized
                # storage when the element counts disagree.
                out = out.reshape(train_batch_imgs.size(0)*args.out_height*args.out_breadth, args.out_ch)
                train_batch_labels = train_batch_labels.reshape(train_batch_labels.size(0)*args.out_height*args.out_breadth)
                loss = criterion(out, train_batch_labels)

                # Backprop.
                loss.backward()
                optimizer.step()

                train_loss.append(loss.item())
                avg_train_loss = round(np.mean(train_loss),4)
                preds = torch.max(out.detach(),1)[1]
                correct = preds.long().eq(train_batch_labels.long()).cpu().sum().item()
                # Per-batch pixel accuracy. BUG FIX: the original divided one
                # batch's correct count by iters*H*W (the whole epoch's pixel
                # count), understating accuracy by a factor of `iters`.
                train_acc = correct/train_batch_labels.numel()

                writer.add_scalar('Train/Loss', avg_train_loss, epoch+1)
                writer.add_scalar('Train/Accuracy', train_acc, epoch+1)
                for name, param in model.named_parameters():
                    if not param.requires_grad:
                        continue
                    writer.add_histogram('epochs/'+name, param.data.view(-1), global_step = epoch+1)

            if epoch % args.eval_every == 0:
                avg_test_loss, test_acc = self.get_val_results(test_imgs, test_labels, model)
                writer.add_scalar('Test/Loss', avg_test_loss, epoch+1)
                writer.add_scalar('Test/Accuracy', test_acc, epoch+1)
                if test_acc > best_acc:
                    best_acc = test_acc
                    print("\nNew High Score! Saving model...\n")
                    torch.save(model.state_dict(), self.model_path+"/model.pickle")

                end = time.time()
                # Elapsed time since training began (not since the last batch).
                h,m,s = calc_elapsed_time(train_start, end)
                print("\nEpoch: {}/{},  Train_loss = {:.4f},  Train_acc = {:.4f},   Val_loss = {:.4f},    Val_acc = {:.4f}"
                      .format(epoch+1, args.epochs, avg_train_loss, train_acc, avg_test_loss, test_acc))

        print("\n"+"="*50 + "\n\t Training Done \n")
        print("\nBest Val accuracy = ", best_acc)
Example #3
0
         # Restore a previously saved checkpoint. NOTE(review): this chunk
         # starts inside a resume branch whose `if` header is above the
         # visible region; `args`, `model`, and `start_epoch` are defined
         # earlier in the file.
         print("=>  load checkpoint : {}".format(args.resume))
         checkpoint = torch.load(args.resume)
         model.load_state_dict(checkpoint["state_dict"])
         print("start epoch:{}, loss:{}".format(checkpoint["epoch"],
                                                checkpoint["loss"]))
         # Continue epoch numbering from where the checkpoint stopped.
         start_epoch += checkpoint["epoch"]
 if args.cuda:
     model = model.cuda()
 if args.mGPU:
     # Wrap for multi-GPU data parallelism (after moving to CUDA above).
     model = DataParallel(model)
     #cudnn.benchmark = True
 #optimizer = optim.Adagrad(model.parameters(), args.lr)
 # Two parameter groups: parameters whose name ends in 'bias' get twice the
 # base learning rate and no weight decay; all other parameters get the base
 # lr with 5e-4 weight decay. Presumably the classic SSD/VGG-style bias
 # treatment — TODO confirm intent. Note `name[-4:] == 'bias'` is a plain
 # suffix check on the dotted parameter name.
 optimizer = optim.SGD([{
     'params': [
         param
         for name, param in model.named_parameters() if name[-4:] == 'bias'
     ],
     'lr':
     2 * args.lr
 }, {
     'params': [
         param
         for name, param in model.named_parameters() if name[-4:] != 'bias'
     ],
     'lr':
     args.lr,
     'weight_decay':
     5e-4
 }],
                       momentum=0.9)
 model.train()