Example #1
def train():
    ex = wandb.init(project="PQRST-segmentation")
    ex.config.setdefaults(wandb_config)

    logging.basicConfig(level=logging.INFO,
                        format="%(levelname)s: %(message)s")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    net = UNet(in_ch=1, out_ch=4)
    net.to(device)

    try:
        train_model(net=net,
                    device=device,
                    batch_size=wandb.config.batch_size,
                    lr=wandb.config.lr,
                    epochs=wandb.config.epochs)
    except KeyboardInterrupt:
        try:
            save = input("save?(y/n)")
            if save == "y":
                torch.save(net.state_dict(), 'net_params.pkl')
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Example #2
def lr_find(model: UNet,
            data_loader,
            optimizer: Optimizer,
            criterion,
            use_gpu,
            min_lr=0.0001,
            max_lr=0.1):
    # Snapshot the model and optimizer states so they can be restored after the
    # sweep. copy.deepcopy (requires `import copy`) is needed because
    # state_dict() returns references to the live tensors.
    model_state = copy.deepcopy(model.state_dict())
    optimizer_state = copy.deepcopy(optimizer.state_dict())

    losses = []
    lrs = []
    scheduler = CyclicExpLR(optimizer,
                            min_lr,
                            max_lr,
                            step_size_up=100,
                            mode='triangular',
                            cycle_momentum=True)
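    # Sweep the learning rate upward from min_lr towards max_lr, recording the
    # loss at every step; the loop stops as soon as the LR starts to fall again.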
    model.train()
    for i, (data, target, class_ids) in enumerate(data_loader):
        if use_gpu:
            data = data.cuda()
            target = target.cuda()

        optimizer.zero_grad()
        output_raw = model(data)
        # This step is specific for this project
        output = torch.zeros(output_raw.shape[0], 1, output_raw.shape[2],
                             output_raw.shape[3])

        if use_gpu:
            output = output.cuda()

        # This step is specific for this project
        for idx, (raw_o, class_id) in enumerate(zip(output_raw, class_ids)):
            output[idx] = raw_o[class_id - 1]

        loss = criterion(output, target)
        loss.backward()
        current_lr = optimizer.param_groups[0]['lr']
        # Stop if lr stopped increasing
        if len(lrs) > 0 and current_lr < lrs[-1]:
            break
        lrs.append(current_lr)
        losses.append(loss.item())
        optimizer.step()
        scheduler.step()

    # Plot in log scale
    plt.plot(lrs, losses)
    plt.xscale('log')

    plt.show()

    model.load_state_dict(model_state)
    optimizer.load_state_dict(optimizer_state)
Example #3
def main():
    train_transform = A.Compose([
        A.Resize(height=config.IMAGE_HEIGHT, width=config.IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(mean=[0.0, 0.0, 0.0],
                    std=[1.0, 1.0, 1.0],
                    max_pixel_value=255.0),
        ToTensorV2(),
    ])

    val_transform = A.Compose([
        A.Resize(height=config.IMAGE_HEIGHT, width=config.IMAGE_WIDTH),
        A.Normalize(mean=[0.0, 0.0, 0.0],
                    std=[1.0, 1.0, 1.0],
                    max_pixel_value=255.0),
        ToTensorV2(),
    ])

    model = UNet(in_channels=3, out_channels=1).to(config.DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        config.TRAIN_IMAGE_DIR,
        config.TRAIN_MASK_DIR,
        config.VAL_IMG_DIR,
        config.VAL_MASK_DIR,
        config.BATCH_SIZE,
        train_transform,
        val_transform,
    )

    if config.LOAD_MODEL:
        load_checkpoint(torch.load('my_checkpoint.pth.tar'), model)

    scaler = torch.cuda.amp.GradScaler()
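    # GradScaler enables mixed-precision training; train_fn presumably uses it
    # to scale the loss before the backward pass.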
    for epoch in range(config.NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check acc
        check_accuracy(val_loader, model, device=config.DEVICE)

        # print some examples to a folder
        save_predictions_as_imgs(val_loader,
                                 model,
                                 folder='saved_images',
                                 device=config.DEVICE)
Example #4
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--file_paths', default="data/files.txt")
    parser.add_argument('--landmark_paths', default="data/landmarks.txt")
    parser.add_argument('--landmark', type=int, default=0)
    parser.add_argument('--save_path')
    parser.add_argument('--num_epochs', type=int, default=int(1e9))
    parser.add_argument('--log_freq', type=int, default=100)
    parser.add_argument('--separator', default=",")
    parser.add_argument('--batch_size', type=int, default=8)
    args = parser.parse_args()

    file_paths = args.file_paths
    landmark_paths = args.landmark_paths
    landmark_wanted = args.landmark
    num_epochs = args.num_epochs
    log_freq = args.log_freq
    save_path = args.save_path

    x, y = get_data(file_paths,
                    landmark_paths,
                    landmark_wanted,
                    separator=args.separator)
    print(f"Got {len(x)} images with {len(y)} landmarks")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("device", device)

    dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y))
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    unet = UNet(in_dim=1, out_dim=6, num_filters=4)
    criterion = torch.nn.CrossEntropyLoss(weight=get_weigths(y).to(device))
    optimizer = optim.SGD(unet.parameters(), lr=0.001, momentum=0.9)

    unet.to(device)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader):
            inputs, labels = data
            # move the batch to the same device as the model
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = unet(inputs)
            loss = criterion(outputs, labels.long())
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f"[{epoch+1}/{num_epochs}] loss: {running_loss}")
        if epoch % log_freq == log_freq - 1:
            if save_path is not None:
                torch.save(unet.state_dict(),
                           os.path.join(save_path, f"unet-{epoch}.pt"))
Example #5
File: train.py Project: Onojimi/try
    net = UNet(input_channels=3, nclasses=1)
    writer = SummaryWriter(log_dir='../../log/sn1', comment='unet')
    #     net.cuda()
    #     import pdb
    #     from torchsummary import summary
    #     summary(net, (3,1000,1000))
    #     pdb.set_trace()

    if args.gpu:
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
        net.cuda()

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  gpu=args.gpu,
                  writer=writer,
                  load=args.load)

        torch.save(net.state_dict(), 'model_fin.pth')

    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'interrupt.pth')
        print('saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Example #6
        if LOSS_NUM == 3 or LOSS_NUM == 4:
            wts = torch.from_numpy(wts).float().cuda()

        data = torch.from_numpy(data).float()
        target = torch.from_numpy(target).float()

        if train_on_gpu:
            data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        if LOSS_NUM in [5, 6]:
            output, _ = model(data)
        else:
            output = model(data)

        if LOSS_NUM == 3 or LOSS_NUM == 4:
            loss = criterion(output, target, wts)
        else:
            loss = criterion(output, target)

        loss.backward()

        optimizer.step()
        train_loss += loss.item() * data.size(0)

    train_loss = train_loss / numpts
    print("Epoch " + str(epoch) + " loss = " + str(train_loss))
    print("-------------")

torch.save(model.state_dict(), "checkpoint/model-" + model_string + ".ckpt")
Example #7
class Instructor:
    ''' Model training and evaluation '''
    def __init__(self, opt):
        self.opt = opt
        if opt.inference:
            self.testset = TestImageDataset(fdir=opt.impaths['test'],
                                            imsize=opt.imsize)
        else:
            self.trainset = ImageDataset(fdir=opt.impaths['train'],
                                         bdir=opt.impaths['btrain'],
                                         imsize=opt.imsize,
                                         mode='train',
                                         aug_prob=opt.aug_prob,
                                         prefetch=opt.prefetch)
            self.valset = ImageDataset(fdir=opt.impaths['val'],
                                       bdir=opt.impaths['bval'],
                                       imsize=opt.imsize,
                                       mode='val',
                                       aug_prob=opt.aug_prob,
                                       prefetch=opt.prefetch)
        self.model = UNet(n_channels=3,
                          n_classes=1,
                          bilinear=self.opt.use_bilinear)
        if opt.checkpoint:
            self.model.load_state_dict(
                torch.load('./state_dict/{:s}'.format(opt.checkpoint),
                           map_location=self.opt.device))
            print('checkpoint {:s} has been loaded'.format(opt.checkpoint))
        if opt.multi_gpu == 'on':
            self.model = torch.nn.DataParallel(self.model)
        self.model = self.model.to(opt.device)
        self._print_args()

    def _print_args(self):
        n_trainable_params, n_nontrainable_params = 0, 0
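        # count trainable vs. frozen parameters by multiplying out each tensor's shape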
        for p in self.model.parameters():
            n_params = torch.prod(torch.tensor(p.shape))
            if p.requires_grad:
                n_trainable_params += n_params
            else:
                n_nontrainable_params += n_params
        self.info = 'n_trainable_params: {0}, n_nontrainable_params: {1}\n'.format(
            n_trainable_params, n_nontrainable_params)
        self.info += 'training arguments:\n' + '\n'.join([
            '>>> {0}: {1}'.format(arg, getattr(self.opt, arg))
            for arg in vars(self.opt)
        ])
        if self.opt.device.type == 'cuda':
            print('cuda memory allocated:',
                  torch.cuda.memory_allocated(self.opt.device.index))
        print(self.info)

    def _reset_records(self):
        self.records = {
            'best_epoch': 0,
            'best_dice': 0,
            'train_loss': list(),
            'val_loss': list(),
            'val_dice': list(),
            'checkpoints': list()
        }

    def _update_records(self, epoch, train_loss, val_loss, val_dice):
        if val_dice > self.records['best_dice']:
            path = './state_dict/{:s}_dice{:.4f}_temp{:s}.pt'.format(
                self.opt.model_name, val_dice,
                str(time.time())[-6:])
            if self.opt.multi_gpu == 'on':
                torch.save(self.model.module.state_dict(), path)
            else:
                torch.save(self.model.state_dict(), path)
            self.records['best_epoch'] = epoch
            self.records['best_dice'] = val_dice
            self.records['checkpoints'].append(path)
        self.records['train_loss'].append(train_loss)
        self.records['val_loss'].append(val_loss)
        self.records['val_dice'].append(val_dice)

    def _draw_records(self):
        timestamp = str(int(time.time()))
        print('best epoch: {:d}'.format(self.records['best_epoch']))
        print('best train loss: {:.4f}, best val loss: {:.4f}'.format(
            min(self.records['train_loss']), min(self.records['val_loss'])))
        print('best val dice {:.4f}'.format(self.records['best_dice']))
        os.rename(
            self.records['checkpoints'][-1],
            './state_dict/{:s}_dice{:.4f}_save{:s}.pt'.format(
                self.opt.model_name, self.records['best_dice'], timestamp))
        for path in self.records['checkpoints'][0:-1]:
            os.remove(path)
        # Draw figures
        plt.figure()
        trainloss, = plt.plot(self.records['train_loss'])
        valloss, = plt.plot(self.records['val_loss'])
        plt.legend([trainloss, valloss], ['train', 'val'], loc='upper right')
        plt.title('{:s} loss curve'.format(timestamp))
        plt.savefig('./figs/{:s}_loss.png'.format(timestamp),
                    format='png',
                    transparent=True,
                    dpi=300)
        plt.figure()
        valdice, = plt.plot(self.records['val_dice'])
        plt.title('{:s} dice curve'.format(timestamp))
        plt.savefig('./figs/{:s}_dice.png'.format(timestamp),
                    format='png',
                    transparent=True,
                    dpi=300)
        # Save report
        report = '\t'.join(
            ['val_dice', 'train_loss', 'val_loss', 'best_epoch', 'timestamp'])
        report += "\n{:.4f}\t{:.4f}\t{:.4f}\t{:d}\t{:s}\n{:s}".format(
            self.records['best_dice'], min(self.records['train_loss']),
            min(self.records['val_loss']), self.records['best_epoch'],
            timestamp, self.info)
        with open('./logs/{:s}_log.txt'.format(timestamp), 'w') as f:
            f.write(report)
        print('report saved:', './logs/{:s}_log.txt'.format(timestamp))

    def _train(self, train_dataloader, criterion, optimizer):
        self.model.train()
        train_loss, n_total, n_batch = 0, 0, len(train_dataloader)
        for i_batch, sample_batched in enumerate(train_dataloader):
            inputs, target = sample_batched[0].to(
                self.opt.device), sample_batched[1].to(self.opt.device)
            predict = self.model(inputs)

            optimizer.zero_grad()
            loss = criterion(predict, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * len(sample_batched)
            n_total += len(sample_batched)

            ratio = int((i_batch + 1) * 50 / n_batch)
            sys.stdout.write("\r[" + ">" * ratio + " " * (50 - ratio) +
                             "] {}/{} {:.2f}%".format(i_batch + 1, n_batch,
                                                      (i_batch + 1) * 100 /
                                                      n_batch))
            sys.stdout.flush()
        print()
        return train_loss / n_total

    def _evaluation(self, val_dataloader, criterion):
        self.model.eval()
        val_loss, val_dice, n_total = 0, 0, 0
        with torch.no_grad():
            for sample_batched in val_dataloader:
                inputs, target = sample_batched[0].to(
                    self.opt.device), sample_batched[1].to(self.opt.device)
                predict = self.model(inputs)
                loss = criterion(predict, target)
                dice = dice_coeff(predict, target)
                val_loss += loss.item() * len(sample_batched)
                val_dice += dice.item() * len(sample_batched)
                n_total += len(sample_batched)
        return val_loss / n_total, val_dice / n_total

    def run(self):
        _params = filter(lambda p: p.requires_grad, self.model.parameters())
        optimizer = torch.optim.Adam(_params,
                                     lr=self.opt.lr,
                                     weight_decay=self.opt.l2reg)
        criterion = BCELoss2d()
        train_dataloader = DataLoader(dataset=self.trainset,
                                      batch_size=self.opt.batch_size,
                                      shuffle=True)
        val_dataloader = DataLoader(dataset=self.valset,
                                    batch_size=self.opt.batch_size,
                                    shuffle=False)
        self._reset_records()
        for epoch in range(self.opt.num_epoch):
            train_loss = self._train(train_dataloader, criterion, optimizer)
            val_loss, val_dice = self._evaluation(val_dataloader, criterion)
            self._update_records(epoch, train_loss, val_loss, val_dice)
            print(
                '{:d}/{:d} > train loss: {:.4f}, val loss: {:.4f}, val dice: {:.4f}'
                .format(epoch + 1, self.opt.num_epoch, train_loss, val_loss,
                        val_dice))
        self._draw_records()

    def inference(self):
        test_dataloader = DataLoader(dataset=self.testset,
                                     batch_size=1,
                                     shuffle=False)
        n_batch = len(test_dataloader)
        with torch.no_grad():
            for i_batch, sample_batched in enumerate(test_dataloader):
                index, inputs = sample_batched[0], sample_batched[1].to(
                    self.opt.device)
                predict = self.model(inputs)
                self.testset.save_img(index.item(), predict, self.opt.use_crf)
                ratio = int((i_batch + 1) * 50 / n_batch)
                sys.stdout.write(
                    "\r[" + ">" * ratio + " " * (50 - ratio) +
                    "] {}/{} {:.2f}%".format(i_batch + 1, n_batch,
                                             (i_batch + 1) * 100 / n_batch))
                sys.stdout.flush()
        print()
Example #8
def train(cfg_path, device='cuda'):
    if cfg_path is not None:
        cfg.merge_from_file(cfg_path)
    cfg.freeze()

    if not os.path.isdir(cfg.LOG_DIR):
        os.makedirs(cfg.LOG_DIR)
    if not os.path.isdir(cfg.SAVE_DIR):
        os.makedirs(cfg.SAVE_DIR)

    model = UNet(cfg.NUM_CHANNELS, cfg.NUM_CLASSES)
    model.to(device)

    train_data_loader = build_data_loader(cfg, 'train')
    if cfg.VAL:
        val_data_loader = build_data_loader(cfg, 'val')
    else:
        val_data_loader = None

    optimizer = build_optimizer(cfg, model)
    lr_scheduler = build_lr_scheduler(cfg, optimizer)
    criterion = get_loss_func(cfg)
    writer = SummaryWriter(cfg.LOG_DIR)

    iter_counter = 0
    loss_meter = AverageMeter()
    val_loss_meter = AverageMeter()
    min_val_loss = 1e10

    print('Training Start')
    for epoch in range(cfg.SOLVER.MAX_EPOCH):
        print('Epoch {}/{}'.format(epoch + 1, cfg.SOLVER.MAX_EPOCH))
        if lr_scheduler is not None:
            lr_scheduler.step(epoch)
        for data in train_data_loader:
            iter_counter += 1

            imgs, annots = data
            imgs = imgs.to(device)
            annots = annots.to(device)

            y = model(imgs)
            optimizer.zero_grad()
            loss = criterion(y, annots)
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())

            if iter_counter % 10 == 0:
                writer.add_scalars('loss', {'train': loss_meter.avg},
                                   iter_counter)
                loss_meter.reset()
            if lr_scheduler is not None:
                writer.add_scalar('learning rate',
                                  optimizer.param_groups[0]['lr'],
                                  iter_counter)
            save_as_checkpoint(model, optimizer,
                               os.path.join(cfg.SAVE_DIR, 'checkpoint.pth'),
                               epoch, iter_counter)

        # Skip validation when cfg.VAL is False
        if val_data_loader is None:
            continue

        val_loss_meter.reset()
        for data in val_data_loader:
            with torch.no_grad():
                imgs, annots = data
                imgs = imgs.to(device)
                annots = annots.to(device)

                y = model(imgs)
                loss = criterion(y, annots)
                val_loss_meter.update(loss.item())
        if val_loss_meter.avg < min_val_loss:
            min_val_loss = val_loss_meter.avg
            writer.add_scalars('loss', {'val': val_loss_meter.avg},
                               iter_counter)
            # save model if validation loss is minimum
            torch.save(model.state_dict(),
                       os.path.join(cfg.SAVE_DIR, 'min_val_loss.pth'))
Example #9
            loss = criterion(output_v, label_v)

            val_loss += loss.item() / valSize

    loss_track.append((train_loss, loss_G, loss_D, val_loss))
    torch.save(loss_track, 'checkpoint_GAN/loss.pth')

    print(
        '[{:4d}/{}], tr_ls: {:.5f}, G_ls: {:.5f}, D_ls: {:.5f}, te_ls: {:.5f}'.
        format(epoch + 1, epoch_num, train_loss, loss_G, loss_D, val_loss))

    if epoch % 50 == 0:
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict_G': G.state_dict(),
                'state_dict_D': D.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
            }, 'checkpoint_GAN/model.pth')

    if epoch % 5 == 0:
        img_out, img_gt = [], []
        sampleNum = 4
        for i in np.random.choice(32, sampleNum, replace=False):
            img_out.append(
                TF.to_tensor(onehot2img(output[i].detach().to('cpu').numpy())))
            img_gt.append(
                TF.to_tensor(label2img(label[i].detach().to('cpu').numpy())))
        for i in np.random.choice(17, sampleNum, replace=False):
            img_out.append(
Example #10
            img = (data[0].transpose(0, 1).transpose(1,
                                                     2).detach().cpu().numpy())
            output_mask = np.abs(
                create_boolean_mask(output[0][0].cpu().detach().numpy()) *
                (-1))
            target_mask = target[0][0].cpu().numpy().astype(bool)

            img[output_mask == 1, 0] = 1
            img[target_mask == 1, 1] = 1
            # overlapping regions look yellow
            plt.imshow(img)
            plt.savefig(f'training_process/training_{epoch}.png')
            print('Training Epoch: {} - Loss: {:.6f}'.format(
                epoch + 1, total_training_loss / len(df_train)))
            torch.save(model.state_dict(), 'model.pth')

        # Validation Loop
        model.eval()
        total_validation_loss = 0.0
        for i, (data, target, class_ids) in enumerate(validation_loader):
            if use_gpu:
                data = data.cuda()
                target = target.cuda()

            optimizer.zero_grad()
            output = model.predict(data, use_gpu, class_ids)

            loss = criterion(output, target)
            total_validation_loss += loss.item()
Example #11
            val_loss += loss.item() / valSize

    scheduler.step()
    loss_track.append((train_loss, val_loss))
    torch.save(loss_track, 'checkpoint/loss.pth')

    print(
        '[{:4d}/{}] lr: {:.5f}, train_loss: {:.5f}, test_loss: {:.5f}'.format(
            epoch + 1, epoch_num, optimizer.param_groups[0]['lr'], train_loss,
            val_loss))

    if epoch % 50 == 0:
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, 'checkpoint/model.pth')

    if epoch % 5 == 0:
        img_out, img_gt = [], []
        sampleNum = 4
        for i in np.random.choice(32, sampleNum, replace=False):
            img_out.append(
                TF.to_tensor(onehot2img(output[i].detach().to('cpu').numpy())))
            img_gt.append(
                TF.to_tensor(label2img(label[i].detach().to('cpu').numpy())))
        for i in np.random.choice(17, sampleNum, replace=False):
            img_out.append(
                TF.to_tensor(onehot2img(
Example #12
class TrainModel():
    def __init__(self,
                 cur_dir,
                 suffix='.tif',
                 cuda=True,
                 testBatchSize=4,
                 batchSize=4,
                 nEpochs=200,
                 lr=0.01,
                 threads=4,
                 seed=123,
                 size=256,
                 input_transform=True,
                 target_transform=True):
        #        super(TrainModel, self).__init__()

        self.data_dir = cur_dir + '/data/'
        self.suffix = suffix
        """
        training parameters are set here

        """
        self.colordim = 1
        self.cuda = cuda
        if self.cuda and not torch.cuda.is_available():
            raise Exception("No GPU found, please run without --cuda")
        self.testBatchSize = testBatchSize
        self.batchSize = batchSize
        self.nEpochs = nEpochs
        self.lr = lr
        self.threads = threads
        self.seed = seed
        self.size = size

        self.input_transform = input_transform
        self.target_transform = target_transform
        self.__check_dir = cur_dir + '/checkpoint'
        if not exists(self.__check_dir):
            os.mkdir(self.__check_dir)
        self.__epoch_dir = cur_dir + '/epoch'
        if not exists(self.__epoch_dir):
            os.mkdir(self.__epoch_dir)
        """
        initialize the model
        """

        if self.cuda:
            self.unet = UNet(self.colordim).cuda()
            self.criterion = nn.MSELoss().cuda()
        else:
            self.unet = UNet(self.colordim)
            self.criterion = nn.MSELoss()

        self.optimizer = optim.SGD(self.unet.parameters(),
                                   lr=self.lr,
                                   momentum=0.9,
                                   weight_decay=0.0001)

    def __dir_exist(self, cur_dir):
        if not exists(cur_dir):
            sys.exit(cur_dir + ' does not exist...')
        return cur_dir
        """
        need to be completed later
        add some other functions such as function for checking if there are
        train and test subfolders in the directory
        """

    def __get_training_set(self):
        root_dir = self.__dir_exist(self.data_dir)
        train_dir = self.__dir_exist(join(root_dir, "train"))
        return DatasetFromFolder(train_dir,
                                 colordim=self.colordim,
                                 size=self.size,
                                 _input_transform=self.input_transform,
                                 _target_transform=self.target_transform,
                                 suffix=self.suffix)

    def __get_test_set(self):
        root_dir = self.__dir_exist(self.data_dir)
        test_dir = self.__dir_exist(join(root_dir, "test"))
        return DatasetFromFolder(test_dir,
                                 colordim=self.colordim,
                                 size=self.size,
                                 _input_transform=self.input_transform,
                                 _target_transform=self.target_transform,
                                 suffix=self.suffix)

    def __get_ready_for_data(self):

        self.train_set = self.__get_training_set()
        self.test_set = self.__get_test_set()
        self.training_data_loader = DataLoader(dataset=self.train_set,
                                               num_workers=self.threads,
                                               batch_size=self.batchSize,
                                               shuffle=True)
        self.testing_data_loader = DataLoader(dataset=self.test_set,
                                              num_workers=self.threads,
                                              batch_size=self.testBatchSize,
                                              shuffle=False)

    def __train(self, epoch):

        epoch_loss = 0

        for iteration, (batch_x,
                        batch_y) in enumerate(self.training_data_loader):

            input = Variable(batch_x)
            target = Variable(batch_y)
            if self.cuda:
                input = input.cuda()
                target = target.cuda()

            self.optimizer.zero_grad()
            input = self.unet(input)
            loss = self.criterion(input, target)
            epoch_loss += loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if iteration % 50 == 0:
                print("===> Epoch[{}]({}/{}) : Loss: {:.4f}".format(
                    epoch, iteration, len(self.training_data_loader),
                    loss.item()))

        result1 = input.cuda()
        imgout = torch.cat([target, result1], 2)
        torchvision.utils.save_image(
            imgout.data, self.__epoch_dir + '/' + str(epoch) + self.suffix)
        print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(
            epoch, epoch_loss / len(self.training_data_loader)))

        return epoch_loss / len(self.training_data_loader)

    def __test(self):

        totalloss = 0
        for batch in self.testing_data_loader:

            # volatile=True was removed from PyTorch; torch.no_grad() is the
            # modern way to disable gradient tracking during evaluation
            with torch.no_grad():
                input = batch[0]
                target = batch[1]

                if self.cuda:
                    input = input.cuda()
                    target = target.cuda()
                prediction = self.unet(input)
                loss = self.criterion(prediction, target)
                totalloss += loss.item()

        print("===> Avg. test loss: {:,.4f} dB".format(
            totalloss / len(self.testing_data_loader)))

    def __checkpoint(self, epoch):

        model_out_path = (self.__check_dir +
                          "/model_epoch_{}.pth").format(epoch)
        torch.save(self.unet.state_dict(), model_out_path)
#        self.__print("Checkpoint saved to {}.".format(model_out_path))

    def run(self):

        if self.cuda:
            torch.cuda.manual_seed(self.seed)
        else:
            torch.manual_seed(self.seed)

        print("===> Loading data")
        self.__get_ready_for_data()

        print("===> Building unet")
        print("===> Training unet")

        for epoch in range(1, self.nEpochs + 1):

            avg_loss = self.__train(epoch)

            if epoch % 20 == 0:
                self.__checkpoint(epoch)
                self.__test()

        # save a final checkpoint once training is finished
        self.__checkpoint(epoch)

    def use_model_on_one_image(self, image_path, model_path, save_path):
        """
        use an existed model on one image
        """
        if self.cuda:
            self.unet.load_state_dict(torch.load(model_path))
        else:
            self.unet.load_state_dict(
                torch.load(model_path,
                           map_location=lambda storage, loc: storage))

        ori_image = Image.open(image_path).convert('L')
        transform = ToTensor()

        input = transform(ori_image)
        if self.cuda:
            input = Variable(input.cuda())
        else:
            input = Variable(input)
        # add a batch dimension: (1, H, W) -> (1, 1, H, W)
        input = torch.unsqueeze(input, 0)

        output = self.unet(input)

        if self.cuda:
            output = output.cuda()

        result = torch.cat([input.data, output.data], 0)

        torchvision.utils.save_image(result, save_path)
Example #13
File: train.py Project: krsrv/UNet
def train(epochs=10, lr=0.001, n_class=1, in_channel=1, loss_fn='BCE', display=False, save=False, \
  load=False, directory='../Data/train/', img_size=None, data_size=None, load_file=None, save_file=None):
    #if torch.cuda.is_available():
    #  torch.cuda.set_device(1)
    # Dataset
    dataset = get_dataset(directory, img_size, data_size)

    #optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum = momentum, weight_decay = decay)
    print("Epochs:\t{}\nLearning Rate:\t{}\nOutput classes:\t{}\nInput channels:\t{}\n\
Loss function:\t{}\nImage cropping size:\t{}\nDataset size:\t{}\n"                                                                    .format(epochs, lr, n_class, \
  in_channel, loss_fn, img_size, data_size))

    # Neural network model
    model = UNet(n_class,
                 in_channel).cuda() if torch.cuda.is_available() else UNet(
                     n_class, in_channel)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_log = []

    if load:
        get_checkpoint(model, optimizer, loss_log, load_file)

    criterion = torch.nn.BCELoss()
    if loss_fn == 'CE':
        weights = torch.Tensor([10, 90])
        if torch.cuda.is_available():
            weights = weights.cuda()
        criterion = torch.nn.CrossEntropyLoss(weight=weights)

    for epoch in range(epochs):
        #print("Starting Epoch #{}".format(epoch))

        train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
        epoch_loss = 0

        for i, images in enumerate(train_loader):
            # get the inputs
            image, label = images['image'], images['label']

            # zero the parameter gradients
            optimizer.zero_grad()

            ## Run the forward pass (move the batch to the GPU when available)
            if torch.cuda.is_available():
                image, label = image.cuda(), label.cuda()
            outputs = model(image)

            if display:
                T.ToPILImage()(outputs[0].float()).show()

            if loss_fn == 'CE':
                label = label.squeeze(1).long()
            elif loss_fn == 'BCE':
                label = label.float()

            loss = criterion(outputs, label)
            loss.backward()

            epoch_loss = epoch_loss + loss.item()

            optimizer.step()

            #if i % 10 == 0 :
            #  print("Epoch #{} Batch #{} Loss: {}".format(epoch,i,loss.item()))
        loss_log.append(epoch_loss)

        #print("Epoch",epoch," finished. Loss :",loss.item())
        print(epoch, loss.item())
        epoch_loss = 0
    if save:
        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_log': loss_log,
            }, save_file)
    print(loss_log)
    #T.ToPILImage()(outputs[0].float()).show()

    if display:
        testloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
        dataiter = iter(testloader)

        testimg = next(dataiter)
        img, lbl = testimg['image'], testimg['label']
        trained = model(img)
        thresholded = (trained > torch.tensor([0.5]))
        T.ToPILImage()(img[0]).show()
        T.ToPILImage()(lbl.float()).show()
        T.ToPILImage()((trained[0]).float()).show()
        T.ToPILImage()((thresholded[0]).float()).show()

        matching = (thresholded[0].long() == lbl.long()).sum()
        accuracy = float(matching) / lbl.numel()
        print("matching {}, total {}, accuracy {}".format(matching, lbl.numel(),\
        accuracy))
Example #14
def train():
    """Main training process."""
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    if opt.data_type == 'bacteria':
        train_h5pth = opt.h5_path + 'train.h5'
        val_h5pth = opt.h5_path + 'valid.h5'
    elif opt.data_type == 'cell':
        train_h5pth = opt.cell_h5_path + 'train.h5'
        val_h5pth = opt.cell_h5_path + 'valid.h5'

    train_data = FluoData(train_h5pth,
                          opt.data_type,
                          color=opt.color,
                          horizontal_flip=1.0 * opt.h_flip,
                          vertical_flip=1.0 * opt.v_flip)
    train_dataloader = DataLoader(train_data, batch_size=opt.batch_size)
    val_data = FluoData(val_h5pth,
                        opt.data_type,
                        color=opt.color,
                        horizontal_flip=0,
                        vertical_flip=0)
    val_dataloader = DataLoader(val_data, batch_size=opt.batch_size)

    if opt.model.find("UNet") != -1:
        model = UNet(input_filters=3, filters=opt.unet_filters,
                     N=opt.conv).to(device)
    elif opt.model == "FCRN_A":
        model = FCRN_A(input_filters=3, filters=opt.unet_filters,
                       N=opt.conv).to(device)
    model = torch.nn.DataParallel(model)
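    # wrap in DataParallel so multi-GPU machines split each batch across devices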

    # if os.path.exists('{}.pth'.format(opt.model)):
    #     model.load_state_dict(torch.load('{}.pth'.format(opt.model)))

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opt.learning_rate,
                                momentum=0.9,
                                weight_decay=1e-5)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=20,
                                                   gamma=0.1)
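    # decay the learning rate by a factor of 10 every 20 epochs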

    # if plot flag is on, create a live plot (to be updated by Looper)
    if opt.plot:
        plt.ion()
        fig, plots = plt.subplots(nrows=2, ncols=2)
    else:
        plots = [None] * 2

    best_result = float('inf')
    best_epoch = 0

    train_looper = Looper(model, device, criterion, optimizer,
                          train_dataloader, len(train_data), plots[0])
    valid_looper = Looper(model,
                          device,
                          criterion,
                          optimizer,
                          val_dataloader,
                          len(val_data),
                          plots[1],
                          validation=True)

    for epoch in range(opt.epochs):

        print("======= epoch {} =======".format(epoch))

        ###############################################
        ########         Training Phase        ########
        ###############################################
        train_looper.run()
        lr_scheduler.step()

        ###############################################
        ########       Validation Phase        ########
        ###############################################
        with torch.no_grad():
            result = valid_looper.run()

        if result < best_result:
            best_result = result
            best_epoch = epoch
            torch.save(model.state_dict(), '{}.pth'.format(opt.model))

            print(f"\nNew best result: {best_result}")

    print("[Training done] Best epoch: {}".format(best_epoch))
    print("[Training done] Best result: {}".format(best_result))
Example #15
                if phase == 'valid':
                    loss_valid.append(loss.item())  # store the validation loss
                    y_pred_np = y_pred.detach().cpu().numpy()  # predictions as numpy
                    # extend adds each slice of the batch to the list as a separate entry
                    validation_pred.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])
                    y_true_np = y_true.detach().cpu().numpy()
                    validation_true.extend([y_true_np[s] for s in range(y_true_np.shape[0])])
                if phase == 'train':
                    loss_train.append(loss.item())
                    loss.backward()
                    optimizer.step()
        if phase == 'train':
            log_loss_summary(loss_train, epoch)
            loss_train = []

        if phase == 'valid':
            log_loss_summary(loss_valid, epoch, prefix='val_')
            mean_dsc = np.mean(
                dsc_per_volume(
                    validation_pred,
                    validation_true,
                    valid_loader.dataset.patient_slice_index,
                )
            )
            log_scalar_summary("val_dsc", mean_dsc, epoch)
            if mean_dsc > best_validation_dsc:
                best_validation_dsc = mean_dsc
                torch.save(unet.state_dict(), os.path.join('./', "unet.pt"))
            loss_valid = []
print("\nBest validation mean DSC: {:4f}\n".format(best_validation_dsc))
Example #16
    val_loader = DataLoader(val_ds,
                            batch_size=batchsize,
                            num_workers=2,
                            pin_memory=pin_memory,
                            shuffle=True)

    #Get model
    model = UNet(in_channels=3, out_channels=1)

    #Loss function and optimizer
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        train_fn(model, train_loader, loss_fn, optimizer, device="cuda")
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        #test
        train_loss, train_acc, train_dice_score = test_fn(
            model, train_loader, loss_fn, device)
        val_loss, val_acc, val_dice_score = test_fn(model, val_loader, loss_fn,
                                                    device)
        print(
            "Epoch : {}/{}, Train Loss :{:.2f}, Val Loss :{:.2f}, Train acc :{:.2f}, Val acc :{:.2f}, Train dice score :{:.2f}, Val dice score :{:.2f}"
            .format(epoch + 1, epochs, train_loss, val_loss, train_acc,
                    val_acc, train_dice_score, val_dice_score))
Example #17
        true_masks = batch['mask']
        # print(true_masks.size())
        imgs = imgs.to(device=device, dtype=torch.float32)
        mask_type = torch.float32 if net.n_classes == 1 else torch.long
        true_masks = true_masks.to(device=device, dtype=mask_type)

        masks_pred = net(imgs)
        # masks_pred = masks_pred.to("cpu", torch.double)
        # print(masks_pred.size())

        loss = criterion(masks_pred, true_masks)
        # loss = criterion_(masks_pred, true_masks)
        # print("###########")
        # print(loss.item())
        epoch_loss += loss.item()
        # writer.add_scalar('Loss/train', loss.item(), global_step)

        print("epoch : %d, batch : %5d, loss : %.5f" %
              (epoch, (global_step / batch_size), loss.item()))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        global_step += 1
        # run validation every 1000 steps
        if global_step % 1000 == 0 and global_step > 0:
            val_pre = eval_net(net, val_loader, device)
            print("val loss : %.5f" % val_pre)

    if epoch % 10 == 0 and epoch > 0:
        torch.save(net.state_dict(), dir_checkpoint + f'epoch_%d.pth' % epoch)
Example #18
def train():
    if not os.path.exists('train_model/'):
        os.makedirs('train_model/')
    if not os.path.exists('result/'):
        os.makedirs('result/')

    train_data, dev_data, word2id, id2word, char2id, opts = load_data(
        vars(args))
    model = UNet(opts)

    if args.use_cuda:
        model = model.cuda()

    dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)

    if args.eval:
        print('load model...')
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        model.Evaluate(dev_batches,
                       args.data_path + 'dev_eval.json',
                       answer_file='result/' + args.model_dir.split('/')[-1] +
                       '.answers',
                       drop_file=args.data_path + 'drop.json',
                       dev=args.data_path + 'dev-v2.0.json')
        exit()

    if args.load_model:
        print('load model...')
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        _, F1 = model.Evaluate(dev_batches,
                               args.data_path + 'dev_eval.json',
                               answer_file='result/' +
                               args.model_dir.split('/')[-1] + '.answers',
                               drop_file=args.data_path + 'drop.json',
                               dev=args.data_path + 'dev-v2.0.json')
        best_score = F1
        with open(args.model_dir + '_f1_scores.pkl', 'rb') as f:
            f1_scores = pkl.load(f)
        with open(args.model_dir + '_em_scores.pkl', 'rb') as f:
            em_scores = pkl.load(f)
    else:
        best_score = 0.0
        f1_scores = []
        em_scores = []

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adamax(parameters, lr=args.lrate)

    lrate = args.lrate

    for epoch in range(1, args.epochs + 1):
        train_batches = get_batches(train_data, args.batch_size)
        dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)
        total_size = len(train_data) // args.batch_size

        model.train()
        for i, train_batch in enumerate(train_batches):
            loss = model(train_batch)
            model.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters, opts['grad_clipping'])
            optimizer.step()
            model.reset_parameters()

            if i % 100 == 0:
                print(
                    'Epoch = %d, step = %d / %d, loss = %.5f, lrate = %.5f best_score = %.3f'
                    % (epoch, i, total_size, model.train_loss.value, lrate,
                       best_score))
                sys.stdout.flush()

        model.eval()
        exact_match_score, F1 = model.Evaluate(
            dev_batches,
            args.data_path + 'dev_eval.json',
            answer_file='result/' + args.model_dir.split('/')[-1] + '.answers',
            drop_file=args.data_path + 'drop.json',
            dev=args.data_path + 'dev-v2.0.json')
        f1_scores.append(F1)
        em_scores.append(exact_match_score)
        with open(args.model_dir + '_f1_scores.pkl', 'wb') as f:
            pkl.dump(f1_scores, f)
        with open(args.model_dir + '_em_scores.pkl', 'wb') as f:
            pkl.dump(em_scores, f)

        if best_score < F1:
            best_score = F1
            print('saving %s ...' % args.model_dir)
            torch.save(model.state_dict(), args.model_dir)
        if epoch > 0 and epoch % args.decay_period == 0:
            lrate *= args.decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lrate
Example #19
def train(args):
    """
    Train UNet from datasets
    """

    # dataset
    print('Reading dataset from {}...'.format(args.dataset_path))
    train_dataset = SSDataset(dataset_path=args.dataset_path, is_train=True)
    val_dataset = SSDataset(dataset_path=args.dataset_path, is_train=False)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False)

    # mask
    with open(args.mask_json_path, 'w', encoding='utf-8') as mask:
        colors = SSDataset.all_colors
        mask.write(json.dumps(colors))
        print('Mask colors list has been saved in {}'.format(
            args.mask_json_path))

    # model
    net = UNet(in_channels=3, out_channels=5)
    if args.cuda:
        net = net.cuda()

    # setting
    lr = args.lr  # 1e-3
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = loss_fn

    # run
    train_losses = []
    val_losses = []
    print('Start training...')
    for epoch_idx in range(args.epochs):
        # train
        net.train()
        train_loss = 0
        for batch_idx, batch_data in enumerate(train_dataloader):
            xs, ys = batch_data
            if args.cuda:
                xs = xs.cuda()
                ys = ys.cuda()
            ys_pred = net(xs)
            loss = criterion(ys_pred, ys)
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # val
        net.eval()
        val_loss = 0
        for batch_idx, batch_data in enumerate(val_dataloader):
            xs, ys = batch_data
            if args.cuda:
                xs = xs.cuda()
                ys = ys.cuda()
            ys_pred = net(xs)
            loss = loss_fn(ys_pred, ys)
            val_loss += loss.item()

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print('Epoch: {}, Train total loss: {}, Val total loss: {}'.format(
            epoch_idx + 1, train_loss, val_loss))

        # save
        if (epoch_idx + 1) % args.save_epoch == 0:
            checkpoint_path = os.path.join(
                args.checkpoint_path,
                'checkpoint_{}.pth'.format(epoch_idx + 1))
            torch.save(net.state_dict(), checkpoint_path)
            print('Saved Checkpoint at Epoch {} to {}'.format(
                epoch_idx + 1, checkpoint_path))

    # summary
    if args.do_save_summary:
        epoch_range = list(range(1, args.epochs + 1))
        plt.plot(epoch_range, train_losses, 'r', label='Train loss')
        plt.plot(epoch_range, val_losses, 'g', label='Val loss')
        plt.legend()
        plt.savefig(args.summary_image)
        print('Summary images have been saved in {}'.format(
            args.summary_image))

    # save
    net.eval()
    torch.save(net.state_dict(), args.model_state_dict)
    print('Saved state_dict in {}'.format(args.model_state_dict))
Example #20
        net.eval()
        epoch_loss = []
        for img, mask in val_data_loader:
            inputs = img.to(device)
            labels = mask.to(device)

            outputs = net(inputs)
            loss = criterion(outputs, labels)

            epoch_loss += [loss.item()]

        mean_loss = np.mean(epoch_loss)
        print(f"val epoch: {epoch}/{num_epoch}, loss: {mean_loss:.3f}")

        # Save best model
        if mean_loss < best_loss:
            best_loss = mean_loss

            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': best_loss
                }, 'best_unet.pt')

            print(f"Save ckpt | loss: {best_loss}")

# Save model
# torch.save(net.state_dict(), 'unet.pt')
Example #21
        loss = criterion(pred, target)

        # the 2nd pass
        input_t = warp(input, flow)
        input_t_pred = model(input_t)
        pred_t = warp(pred, flow)

        loss_t = criterion(input_t_pred, pred_t)
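        # loss_t enforces consistency: the prediction on the warped input should match the warped prediction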
        total_loss = loss + loss_t * args.weight

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        logger.add_scalar('Train/Loss', loss.item(), iters)
        logger.add_scalar('Train/Loss_t', loss_t.item(), iters)
        iters += 1

        if (i + 1) % 10 == 0:
            print('Train Epoch: {0} [{1}/{2}]\t'
                  'l1Loss={Loss1:.8f} '
                  'conLoss={Loss2:.8f} '.format(epoch,
                                                i + 1,
                                                len(train_loader),
                                                Loss1=loss.item(),
                                                Loss2=loss_t.item()))

    save_checkpoint(model.state_dict(), epoch, log_dir)
    print()

logger.close()
Example #22
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', metavar='bs', type=int, default=2)
    parser.add_argument('--path', type=str, default='../../data')
    parser.add_argument('--results', type=str, default='../../results/model')
    parser.add_argument('--nw', type=int, default=0)
    parser.add_argument('--max_images', type=int, default=None)
    parser.add_argument('--val_size', type=int, default=None)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--lr_decay', type=float, default=0.99997)
    parser.add_argument('--kernel_lvl', type=float, default=1)
    parser.add_argument('--noise_lvl', type=float, default=1)
    parser.add_argument('--motion_blur', type=bool, default=False)
    parser.add_argument('--homo_align', type=bool, default=False)
    parser.add_argument('--resume', type=bool, default=False)

    args = parser.parse_args()

    print()
    print(args)
    print()

    if not os.path.isdir(args.results): os.makedirs(args.results)

    PATH = args.results
    if not args.resume:
        f = open(PATH + "/param.txt", "a+")
        f.write(str(args))
        f.close()

    writer = SummaryWriter(PATH + '/runs')

    # CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else "cpu")

    # Parameters
    params = {'batch_size': args.bs, 'shuffle': True, 'num_workers': args.nw}

    # Generators
    print('Initializing training set')
    training_set = Dataset(args.path + '/train/', args.max_images,
                           args.kernel_lvl, args.noise_lvl, args.motion_blur,
                           args.homo_align)
    training_generator = data.DataLoader(training_set, **params)

    print('Initializing validation set')
    validation_set = Dataset(args.path + '/test/', args.val_size,
                             args.kernel_lvl, args.noise_lvl, args.motion_blur,
                             args.homo_align)

    validation_generator = data.DataLoader(validation_set, **params)

    # Model
    model = UNet(in_channel=3, out_channel=3)
    if args.resume:
        models_path = get_newest_model(PATH)
        print('loading model from ', models_path)
        model.load_state_dict(torch.load(models_path))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    model.to(device)

    # Loss + optimizer
    criterion = BurstLoss()
    optimizer = RAdam(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=8 // args.bs, gamma=args.lr_decay)
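    # the scheduler is stepped once per mini-batch below, so the lr_decay factor
    # is applied every (8 // batch_size) iterations rather than per epoch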
    if args.resume:
        n_iter = np.loadtxt(PATH + '/train.txt', delimiter=',')[:, 0][-1]
    else:
        n_iter = 0

    # Loop over epochs
    for epoch in range(args.epochs):
        train_loss = 0.0

        # Training
        model.train()
        for i, (X_batch, y_labels) in enumerate(training_generator):
            # Alter the burst length for each mini batch

            burst_length = np.random.randint(2, 9)
            X_batch = X_batch[:, :burst_length, :, :, :]

            # Transfer to GPU
            X_batch, y_labels = X_batch.to(device).type(
                torch.float), y_labels.to(device).type(torch.float)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            pred = model(X_batch)
            loss = criterion(pred, y_labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.detach().cpu().numpy()
            writer.add_scalar('training_loss', loss.item(), n_iter)

            if i % 100 == 0 and i > 0:
                loss_printable = str(np.round(train_loss, 2))

                f = open(PATH + "/train.txt", "a+")
                f.write(str(n_iter) + "," + loss_printable + "\n")
                f.close()

                print("training loss ", loss_printable)

                train_loss = 0.0

            if i % 1000 == 0:
                if torch.cuda.device_count() > 1:
                    torch.save(
                        model.module.state_dict(),
                        os.path.join(PATH,
                                     'model_' + str(int(n_iter)) + '.pt'))
                else:
                    torch.save(
                        model.state_dict(),
                        os.path.join(PATH,
                                     'model_' + str(int(n_iter)) + '.pt'))

            if i % 1000 == 0:
                # Validation
                val_loss = 0.0
                with torch.set_grad_enabled(False):
                    model.eval()
                    for v, (X_batch,
                            y_labels) in enumerate(validation_generator):
                        # Alter the burst length for each mini batch

                        burst_length = np.random.randint(2, 9)
                        X_batch = X_batch[:, :burst_length, :, :, :]

                        # Transfer to GPU
                        X_batch, y_labels = X_batch.to(device).type(
                            torch.float), y_labels.to(device).type(torch.float)

                        # forward + backward + optimize
                        pred = model(X_batch)
                        loss = criterion(pred, y_labels)

                        val_loss += loss.detach().cpu().numpy()

                        if v < 5:
                            im = make_im(pred, X_batch, y_labels)
                            writer.add_image('image_' + str(v), im, n_iter)

                    writer.add_scalar('validation_loss', val_loss, n_iter)

                    loss_printable = str(np.round(val_loss, 2))
                    print('validation loss ', loss_printable)

                    with open(PATH + "/eval.txt", "a+") as f:
                        f.write(str(n_iter) + "," + loss_printable + "\n")

                # validation switched the network to eval mode; restore training mode
                model.train()

            n_iter += args.bs
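
# NOTE: get_newest_model() is called at the top of this example but not defined in
# the snippet. A minimal sketch of what it might look like, assuming checkpoints are
# written to PATH as 'model_<iter>.pt' (as the torch.save calls above suggest):
import glob
import os


def get_newest_model(path):
    # pick the checkpoint whose filename carries the largest iteration number
    checkpoints = glob.glob(os.path.join(path, 'model_*.pt'))
    if not checkpoints:
        raise FileNotFoundError('no checkpoints found in ' + path)
    return max(checkpoints,
               key=lambda p: int(os.path.splitext(os.path.basename(p))[0].split('_')[-1]))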
Exemplo n.º 23
0
            plot3 = fig.add_subplot(gs[0, 2])
            # plot4 = fig.add_subplot(gs[1, 1])
            # plot5 = fig.add_subplot(gs[1, 2])

            plot1.axis('off')
            plot2.axis('off')
            plot3.axis('off')
            # plot4.axis('off')
            # plot5.axis('off')
            if n_count % 200 == 0:
                plot1.imshow(batch_x[0].cpu().squeeze(0), cmap='rainbow')
                plot2.imshow(batch_y[0,:1].cpu().squeeze(0), cmap='rainbow')
                # plot3.imshow((batch_x-batch_y)[0:].cpu().squeeze(0), cmap='gray')
                plot3.imshow(output[0].cpu().detach().numpy().squeeze(0), cmap='rainbow')
                # plot5.imshow(batch_y[0:1].cpu().squeeze(0)-output[0].cpu().detach().numpy().squeeze(0), cmap='gray')
                plt.show()


            epoch_loss += loss_chk.item()
            loss.backward()
            optimizer.step()
            if n_count % 10 == 0:
                print('%4d %4d / %4d loss = %2.8f loss_mse = %2.8f' % (
                epoch + 1, n_count, xs.size(0) // batch_size, loss.item() / batch_size, loss_chk.item() / batch_size))
        elapsed_time = time.time() - start_time

        log('epoch = %4d , loss = %4.4f , time = %4.2f s' % (epoch + 1, epoch_loss / n_count, elapsed_time))
        np.savetxt('train_result.txt', np.hstack((epoch + 1, epoch_loss / n_count, elapsed_time)), fmt='%2.4f')
        torch.save(model.state_dict(), os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))
        # torch.save(model, os.path.join(save_dir, 'model_%03d.pth' % (epoch + 1)))
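
# log() is used above but not defined in this fragment. A minimal sketch, assuming it
# simply timestamps the message, prints it, and appends it to a plain-text log file
# (the file name here is an assumption):
import datetime


def log(message, log_file='train_log.txt'):
    line = '%s  %s' % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), message)
    print(line)
    with open(log_file, 'a') as f:
        f.write(line + '\n')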
Exemplo n.º 24
0
    def UNet_train(self):
                
        model = UNet(in_ch = args.in_ch, out_ch = args.out_ch, kernel_size = args.kernel_size).to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr = args.lr, momentum=0.99)
        criterion = nn.CrossEntropyLoss()
        
        iters = np.ceil(self.train_imgs.size(0)/args.batch_size).astype(int)
        print("\nSteps per epoch =  {}\n".format(iters))
        best_acc = 0
        test_imgs = self.test_imgs
        test_labels = self.test_labels
        
        print("="*70 +"\n\t\t\t Training Network\n"+ "="*70)
        start_time = time.time()  # renamed so the batch-slice index below does not clobber the timer
        for epoch in range(args.epochs):
            print(epoch)
            train_loss = []
            
            # Shuffling the data
            permute_idxs = np.random.permutation(len(self.train_labels))
            train_imgs = self.train_imgs[permute_idxs]
            train_labels = self.train_labels[permute_idxs]
            for step in range(iters):
                start = step*args.batch_size
                stop = (step+1)*args.batch_size
                
                # Get batches (moved to the same device as the model)
                train_batch_imgs = train_imgs[start:stop].float().to(device)
                train_batch_labels = train_labels[start:stop].long().to(device)

                # Get predictions
                optimizer.zero_grad()
                out = model(train_batch_imgs)
                
                # Calculate loss: flatten logits from (N, C, H, W) to (N*H*W, C) and labels
                # to (N*H*W) for CrossEntropyLoss (Tensor.resize is deprecated and, without a
                # permute, would also scramble the channel layout)
                out = out.permute(0, 2, 3, 1).reshape(-1, args.out_ch)
                train_batch_labels = train_batch_labels.reshape(-1)
                loss = criterion(out, train_batch_labels)

                # Backprop
                loss.backward()
                optimizer.step()

                train_loss.append(loss.item())
                avg_train_loss = round(np.mean(train_loss),4)
                preds = torch.max(out.data, 1)[1]
                correct = preds.long().eq(train_batch_labels.long()).cpu().sum().item()
                train_acc = correct / train_batch_labels.numel()  # pixel accuracy for this batch

                writer.add_scalar('Train/Loss', avg_train_loss, epoch+1)
                writer.add_scalar('Train/Accuracy', train_acc, epoch+1)
                for name, param in model.named_parameters():
                    if not param.requires_grad:
                        continue
                    writer.add_histogram('epochs/'+name, param.data.view(-1), global_step = epoch+1)
                
            if epoch % args.eval_every == 0:
                avg_test_loss, test_acc = self.get_val_results(test_imgs, test_labels, model)
                writer.add_scalar('Test/Loss', avg_test_loss, epoch+1)
                writer.add_scalar('Test/Accuracy', test_acc, epoch+1)
                if test_acc > best_acc:
                    best_acc = test_acc
                    print("\nNew High Score! Saving model...\n")
                    torch.save(model.state_dict(), self.model_path+"/model.pickle")

                end = time.time()
                h, m, s = calc_elapsed_time(start_time, end)
                print("\nEpoch: {}/{},  Train_loss = {:.4f},  Train_acc = {:.4f},   Val_loss = {:.4f},    Val_acc = {:.4f}"
                      .format(epoch+1, args.epochs, avg_train_loss, train_acc, avg_test_loss, test_acc))

        print("\n"+"="*50 + "\n\t Training Done \n")
        print("\nBest Val accuracy = ", best_acc)
Exemplo n.º 25
0
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='input batch size for training (default: 100)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=67,
                        metavar='N',
                        help='input batch size for testing (default: 67)')
    parser.add_argument('--epochs',
                        type=int,
                        default=80,
                        metavar='N',
                        help='number of epochs to train (default: 80)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.3,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.3)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=13,
                        metavar='S',
                        help='random seed (default: 13)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=5,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--output_nc',
                        type=int,
                        default=1,
                        metavar='N',
                        help='output channels')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}

    #Dataset making

    img_dir = '/n/holyscratch01/wadduwage_lab/temp20200620/20-Jun-2020/beads_tr_data_5sls_20-Jun-2020.h5'
    train_dataset = HDF5Dataset(img_dir=img_dir, isTrain=True)
    test_dataset = HDF5Dataset(img_dir=img_dir, isTrain=False)

    # Data Loading #
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=0,
                              drop_last=True)

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=0,
                             drop_last=True)

    model = UNet(n_classes=args.output_nc).to(device)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=1e-6)
    scheduler = StepLR(optimizer, step_size=20, gamma=args.gamma)
    criterion = torch.nn.SmoothL1Loss()

    Best_ACC = 0
    Best_Epoch = 1

    for epoch in range(1, args.epochs + 1):
        tloss = train(args, model, device, train_loader, optimizer, epoch,
                      criterion)
        vloss = test(args, model, device, test_loader, criterion)
        print("epoch:%.1f" % epoch, "Train_loss:%.4f" % tloss,
              "Val_loss:%.4f" % vloss)
        scheduler.step()
        os.makedirs(model_path, exist_ok=True)
        torch.save(model.state_dict(),
                   model_path + "/fcn_deep_" + str(epoch) + ".pth")
Exemplo n.º 26
0
    (options, args) = parser.parse_args()
    return options


if __name__ == '__main__':
    args = get_args()

    net = UNet(n_channels=3, n_classes=3)

    if args.load:
        net.load_state_dict(torch.load(args.load))
        print('Model loaded from {}'.format(args.load))

    if args.gpu:
        net.cuda()
        cudnn.benchmark = True  # faster convolutions, but more memory

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  gpu=args.gpu)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        print('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
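
# Only the tail of get_args() appears above. A minimal sketch of the full function,
# assuming optparse (parse_args() returns an (options, args) pair) and the options
# used below; flag letters and defaults are illustrative:
from optparse import OptionParser


def get_args():
    parser = OptionParser()
    parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
                      help='number of training epochs')
    parser.add_option('-b', '--batch-size', dest='batchsize', default=10, type='int',
                      help='batch size')
    parser.add_option('-l', '--learning-rate', dest='lr', default=0.1, type='float',
                      help='learning rate')
    parser.add_option('-g', '--gpu', action='store_true', dest='gpu', default=False,
                      help='use CUDA')
    parser.add_option('-c', '--load', dest='load', default=False,
                      help='load a model checkpoint file')
    (options, args) = parser.parse_args()
    return options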
Exemplo n.º 27
0
for i, name in enumerate(nms):
    catIds = coco.getCatIds(catNms=name)
    imgIds = coco.getImgIds(catIds=catIds)
    catcount.append(len(imgIds))

indices = np.flip(np.argsort(catcount)[-global_classes:])

topk_catnames = []
for i in range(global_classes):
    topk_catnames.append(nms[indices[i]])

vgg16 = models.vgg16_bn(pretrained=True)
model = UNet()

dctvgg = vgg16.state_dict()
dct = model.state_dict()

dct['inc.conv.conv.0.weight'].data.copy_(dctvgg['features.0.weight'])  #
dct['inc.conv.conv.0.bias'].data.copy_(dctvgg['features.0.bias'])
dct['inc.conv.conv.1.weight'].data.copy_(dctvgg['features.1.weight'])  #
dct['inc.conv.conv.1.bias'].data.copy_(dctvgg['features.1.bias'])
dct['inc.conv.conv.1.running_mean'].data.copy_(
    dctvgg['features.1.running_mean'])  #
dct['inc.conv.conv.1.running_var'].data.copy_(dctvgg['features.1.running_var'])

dct['inc.conv.conv.3.weight'].data.copy_(dctvgg['features.3.weight'])
dct['inc.conv.conv.3.bias'].data.copy_(dctvgg['features.3.bias'])
dct['inc.conv.conv.4.weight'].data.copy_(dctvgg['features.4.weight'])  #
dct['inc.conv.conv.4.bias'].data.copy_(dctvgg['features.4.bias'])
dct['inc.conv.conv.4.running_mean'].data.copy_(
    dctvgg['features.4.running_mean'])  #
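
# The parameter-by-parameter copies above could be written more compactly as a loop.
# A sketch with a hypothetical helper, assuming the 'inc.conv.conv.<i>' <->
# 'features.<i>' index correspondence shown above holds for the layers being copied:
def copy_vgg_block(dct, dctvgg, unet_prefix, vgg_indices):
    # copy weight/bias and, for BatchNorm layers, the running statistics
    for i in vgg_indices:
        for suffix in ('weight', 'bias', 'running_mean', 'running_var'):
            src = 'features.{}.{}'.format(i, suffix)
            dst = '{}.{}.{}'.format(unet_prefix, i, suffix)
            if src in dctvgg and dst in dct:
                dct[dst].data.copy_(dctvgg[src])


copy_vgg_block(dct, dctvgg, 'inc.conv.conv', vgg_indices=(0, 1, 3, 4))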
            # forward pass (calling the module directly invokes forward())
            output = model(_data)
            pred = output.argmax(1, keepdim=True)

            # convert to numpy
            _target = _target.cpu().numpy()
            pred = pred.cpu().numpy()

            # calculate IoU score
            iou_score += calculate_iou(pred, _target)

    # calculate average IoU score and print the score
    iou_score /= len(test_loader.dataset)
    
    print('{} set: \t{:.4f}'.format(dataset, iou_score))

# test the models
print("\t\tAverage IoU Score")
for (dataset, model) in models.items():
    test(dataset, model, device, TestDatasetLoader)

# test("DRIVE", model_drive, device, TestDatasetLoader)
# test("STARE", model_stare, device, TestDatasetLoader)
# test("FEDERATED LEARNING", model_fr, device, TestDatasetLoader)

# save models
print("Saving models...")
for (dataset, model) in models.items():
    PATH = os.path.sep.join(["models", "model_" + dataset + ".pt"])
    torch.save(model.state_dict(), PATH)
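
# calculate_iou() is used in the evaluation loop above but not defined in this snippet.
# A minimal sketch for binary segmentation, assuming pred and target are numpy arrays
# of class indices and that class 1 is the foreground:
import numpy as np


def calculate_iou(pred, target, cls=1, eps=1e-7):
    pred_mask = (pred == cls)
    target_mask = (target == cls)
    intersection = np.logical_and(pred_mask, target_mask).sum()
    union = np.logical_or(pred_mask, target_mask).sum()
    return float(intersection) / (union + eps)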
Exemplo n.º 29
0
def main(args):
    writer = SummaryWriter(os.path.join('./logs'))
    # torch.backends.cudnn.benchmark = True
    if not os.path.isdir(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('[MODEL] CUDA DEVICE : {}'.format(device))

    # TODO DEFINE TRAIN AND TEST TRANSFORMS
    train_tf = None
    test_tf = None

    # Channel wise mean calculated on adobe240-fps training dataset
    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    test_valid = 'validation' if args.valid else 'test'
    train_data = BlurDataset(os.path.join(args.dataset_root, 'train'),
                             seq_len=args.sequence_length,
                             tau=args.num_frame_blur,
                             delta=5,
                             transform=train_tf)
    test_data = BlurDataset(os.path.join(args.dataset_root, test_valid),
                            seq_len=args.sequence_length,
                            tau=args.num_frame_blur,
                            delta=5,
                            transform=train_tf)

    train_loader = DataLoader(train_data,
                              batch_size=args.train_batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_data,
                             batch_size=args.test_batch_size,
                             shuffle=False)

    # TODO IMPORT YOUR CUSTOM MODEL
    model = UNet(3, 3, device, decode_mode=args.decode_mode)

    if args.checkpoint:
        store_dict = torch.load(args.checkpoint)
        try:
            model.load_state_dict(store_dict['state_dict'])
        except KeyError:
            model.load_state_dict(store_dict)

    if args.train_continue:
        store_dict = torch.load(args.checkpoint)
        model.load_state_dict(store_dict['state_dict'])

    else:
        store_dict = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}

    model.to(device)
    model.train(True)

    # model = nn.DataParallel(model)

    # TODO DEFINE MORE CRITERIA
    # input(True if device == torch.device('cuda:0') else False)
    criterion = {
        'MSE': nn.MSELoss(),
        'L1': nn.L1Loss(),
        # 'Perceptual': PerceptualLoss(model='net-lin', net='vgg', dataparallel=True,
        #                              use_gpu=True if device == torch.device('cuda:0') else False)
    }

    criterion_w = {'MSE': 1.0, 'L1': 10.0, 'Perceptual': 10.0}

    # Define optimizers
    # optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9,weight_decay=5e-4)
    optimizer = optim.Adam(model.parameters(), lr=args.init_learning_rate)

    # Define lr scheduler
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.milestones,
                                               gamma=0.1)

    # best_acc = 0.0
    # start = time.time()
    cLoss = store_dict['loss']
    valLoss = store_dict['valLoss']
    valPSNR = store_dict['valPSNR']
    checkpoint_counter = 0

    loss_tracker = {}
    loss_tracker_test = {}

    psnr_old = 0.0
    dssim_old = 0.0

    for epoch in range(1, 10 *
                       args.epochs):  # loop over the dataset multiple times

        # Append and reset
        cLoss.append([])
        valLoss.append([])
        valPSNR.append([])
        running_loss = 0

        # Increment scheduler count (note: since PyTorch 1.1 the recommended order is to
        # call scheduler.step() after the optimizer updates for the epoch)
        scheduler.step()

        tqdm_loader = tqdm(range(len(train_loader)), ncols=150)

        loss = 0.0
        psnr_ = 0.0
        dssim_ = 0.0

        loss_tracker = {}
        for loss_fn in criterion.keys():
            loss_tracker[loss_fn] = 0.0

        # Train
        model.train(True)
        total_steps = 0.01       # tiny non-zero init so the running averages below never divide by zero
        total_steps_test = 0.01
        '''for train_idx, data in enumerate(train_loader, 1):
            loss = 0.0
            blur_data, sharpe_data = data
            #import pdb; pdb.set_trace()
            # input(sharpe_data.shape)
            #import pdb; pdb.set_trace()
            interp_idx = int(math.ceil((args.num_frame_blur/2) - 0.49))
            #input(interp_idx)
            if args.decode_mode == 'interp':
                sharpe_data = sharpe_data[:, :, 1::2, :, :]
            elif args.decode_mode == 'deblur':
                sharpe_data = sharpe_data[:, :, 0::2, :, :]
            else:
                #print('\nBoth\n')
                sharpe_data = sharpe_data

            #print(sharpe_data.shape)
            #input(blur_data.shape)
            blur_data = blur_data.to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            try:
                sharpe_data = sharpe_data.squeeze().to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            except:
                sharpe_data = sharpe_data.squeeze(3).to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)

            # clear gradient
            optimizer.zero_grad()

            # forward pass
            sharpe_out = model(blur_data)
            # import pdb; pdb.set_trace()
            # input(sharpe_out.shape)

            # compute losses
            # import pdb;
            # pdb.set_trace()
            sharpe_out = sharpe_out.permute(0, 2, 1, 3, 4)
            B, C, S, Fx, Fy = sharpe_out.shape
            for loss_fn in criterion.keys():
                loss_tmp = 0.0

                if loss_fn == 'Perceptual':
                    for bidx in range(B):
                        loss_tmp += criterion_w[loss_fn] * \
                                   criterion[loss_fn](sharpe_out[bidx].permute(1, 0, 2, 3),
                                                      sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
                    # loss_tmp /= B
                else:
                    loss_tmp = criterion_w[loss_fn] * \
                               criterion[loss_fn](sharpe_out, sharpe_data)


                # try:
                # import pdb; pdb.set_trace()
                loss += loss_tmp # if
                # except :
                try:
                    loss_tracker[loss_fn] += loss_tmp.item()
                except KeyError:
                    loss_tracker[loss_fn] = loss_tmp.item()

            # Backpropagate
            loss.backward()
            optimizer.step()

            # statistics
            # import pdb; pdb.set_trace()
            sharpe_out = sharpe_out.detach().cpu().numpy()
            sharpe_data = sharpe_data.cpu().numpy()
            for sidx in range(S):
                for bidx in range(B):
                    psnr_ += psnr(sharpe_out[bidx, :, sidx, :, :], sharpe_data[bidx, :, sidx, :, :]) #, peak=1.0)
                    """dssim_ += dssim(np.moveaxis(sharpe_out[bidx, :, sidx, :, :], 0, 2),
                                    np.moveaxis(sharpe_data[bidx, :, sidx, :, :], 0, 2)
                                    )"""

            """sharpe_out = sharpe_out.reshape(-1,3, sx, sy).detach().cpu().numpy()
            sharpe_data = sharpe_data.reshape(-1, 3, sx, sy).cpu().numpy()
            for idx in range(sharpe_out.shape[0]):
                # import pdb; pdb.set_trace()
                psnr_ += psnr(sharpe_data[idx], sharpe_out[idx])
                dssim_ += dssim(np.swapaxes(sharpe_data[idx], 2, 0), np.swapaxes(sharpe_out[idx], 2, 0))"""

            # psnr_ /= sharpe_out.shape[0]
            # dssim_ /= sharpe_out.shape[0]
            running_loss += loss.item()
            loss_str = ''
            total_steps += B*S
            for key in loss_tracker.keys():
               loss_str += ' {0} : {1:6.4f} '.format(key, 1.0*loss_tracker[key] / total_steps)

            # set display info
            if train_idx % 5 == 0:
                tqdm_loader.set_description(('\r[Training] [Ep {0:6d}] loss: {1:6.4f} PSNR: {2:6.4f} SSIM: {3:6.4f} '.format
                                    (epoch, running_loss / total_steps,
                                     psnr_ / total_steps,
                                     dssim_ / total_steps) + loss_str
                                    ))

                tqdm_loader.update(5)
        tqdm_loader.close()'''

        # Validation
        running_loss_test = 0.0
        psnr_test = 0.0
        dssim_test = 0.0
        # print('len', len(test_loader))
        tqdm_loader_test = tqdm(range(len(test_loader)), ncols=150)
        # import pdb; pdb.set_trace()

        loss_tracker_test = {}
        for loss_fn in criterion.keys():
            loss_tracker_test[loss_fn] = 0.0

        with torch.no_grad():
            model.eval()
            total_steps_test = 0.0

            for test_idx, data in enumerate(test_loader, 1):
                loss = 0.0
                blur_data, sharpe_data = data
                interp_idx = int(math.ceil((args.num_frame_blur / 2) - 0.49))
                # input(interp_idx)
                if args.decode_mode == 'interp':
                    sharpe_data = sharpe_data[:, :, 1::2, :, :]
                elif args.decode_mode == 'deblur':
                    sharpe_data = sharpe_data[:, :, 0::2, :, :]
                else:
                    # print('\nBoth\n')
                    sharpe_data = sharpe_data

                # print(sharpe_data.shape)
                # input(blur_data.shape)
                blur_data = blur_data.to(device)[:, :, :, :352, :].permute(
                    0, 1, 2, 4, 3)
                try:
                    sharpe_data = sharpe_data.squeeze().to(
                        device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
                except:
                    sharpe_data = sharpe_data.squeeze(3).to(
                        device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)

                # clear gradient
                optimizer.zero_grad()

                # forward pass
                sharpe_out = model(blur_data)
                # import pdb; pdb.set_trace()
                # input(sharpe_out.shape)

                # compute losses
                sharpe_out = sharpe_out.permute(0, 2, 1, 3, 4)
                B, C, S, Fx, Fy = sharpe_out.shape
                for loss_fn in criterion.keys():
                    loss_tmp = 0.0
                    if loss_fn == 'Perceptual':
                        for bidx in range(B):
                            loss_tmp += criterion_w[loss_fn] * \
                                        criterion[loss_fn](sharpe_out[bidx].permute(1, 0, 2, 3),
                                                           sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
                        # loss_tmp /= B
                    else:
                        loss_tmp = criterion_w[loss_fn] * \
                                   criterion[loss_fn](sharpe_out, sharpe_data)
                    loss += loss_tmp
                    try:
                        loss_tracker_test[loss_fn] += loss_tmp.item()
                    except KeyError:
                        loss_tracker_test[loss_fn] = loss_tmp.item()

                if ((test_idx % args.progress_iter) == args.progress_iter - 1):
                    itr = test_idx + epoch * len(test_loader)
                    # itr_train
                    writer.add_scalars(
                        'Loss', {
                            'trainLoss': running_loss / total_steps,
                            'validationLoss':
                            running_loss_test / total_steps_test
                        }, itr)
                    writer.add_scalar('Train PSNR', psnr_ / total_steps, itr)
                    writer.add_scalar('Test PSNR',
                                      psnr_test / total_steps_test, itr)
                    # import pdb; pdb.set_trace()
                    # writer.add_image('Validation', sharpe_out.permute(0, 2, 3, 1), itr)

                # statistics
                sharpe_out = sharpe_out.detach().cpu().numpy()
                sharpe_data = sharpe_data.cpu().numpy()
                for sidx in range(S):
                    for bidx in range(B):
                        psnr_test += psnr(
                            sharpe_out[bidx, :, sidx, :, :],
                            sharpe_data[bidx, :, sidx, :, :])  #, peak=1.0)
                        dssim_test += dssim(
                            np.moveaxis(sharpe_out[bidx, :, sidx, :, :], 0, 2),
                            np.moveaxis(sharpe_data[bidx, :, sidx, :, :], 0,
                                        2))  #,range=1.0  )

                running_loss_test += loss.item()
                total_steps_test += B * S
                loss_str = ''
                for key in loss_tracker.keys():
                    loss_str += ' {0} : {1:6.4f} '.format(
                        key, 1.0 * loss_tracker_test[key] / total_steps_test)

                # set display info

                tqdm_loader_test.set_description((
                    '\r[Test    ] [Ep {0:6d}] loss: {1:6.4f} PSNR: {2:6.4f} SSIM: {3:6.4f} '
                    .format(epoch, running_loss_test / total_steps_test,
                            psnr_test / total_steps_test,
                            dssim_test / total_steps_test) + loss_str))
                tqdm_loader_test.update(1)
            tqdm_loader_test.close()

        # save model
        if psnr_old < (psnr_test / total_steps_test):
            if epoch != 1:
                os.remove(
                    os.path.join(
                        args.checkpoint_dir,
                        'epoch-{}-test-psnr-{}-ssim-{}.ckpt'.format(
                            epoch_old,
                            str(round(psnr_old, 4)).replace('.', 'pt'),
                            str(round(dssim_old, 4)).replace('.', 'pt'))))
            epoch_old = epoch
            psnr_old = psnr_test / total_steps_test
            dssim_old = dssim_test / total_steps_test

            checkpoint_dict = {
                'epoch': epoch_old,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'train_psnr': psnr_ / total_steps,
                'train_dssim': dssim_ / total_steps,
                'train_mse': loss_tracker['MSE'] / total_steps,
                'train_l1': loss_tracker['L1'] / total_steps,
                # 'train_percp': loss_tracker['Perceptual'] / total_steps,
                'test_psnr': psnr_old,
                'test_dssim': dssim_old,
                'test_mse': loss_tracker_test['MSE'] / total_steps_test,
                'test_l1': loss_tracker_test['L1'] / total_steps_test,
                # 'test_percp': loss_tracker_test['Perceptual'] / total_steps_test,
            }

            torch.save(
                checkpoint_dict,
                os.path.join(
                    args.checkpoint_dir,
                    'epoch-{}-test-psnr-{}-ssim-{}.ckpt'.format(
                        epoch_old,
                        str(round(psnr_old, 4)).replace('.', 'pt'),
                        str(round(dssim_old, 4)).replace('.', 'pt'))))

        # if epoch % args.checkpoint_epoch == 0:
        #    torch.save(model.state_dict(),args.checkpoint_dir + str(int(epoch/100))+".ckpt")

    return None
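
# psnr() and dssim() are used above but not defined in this snippet. A minimal sketch
# of the PSNR side, assuming both inputs are float numpy arrays scaled to [0, 1]
# (peak value 1.0):
import numpy as np


def psnr(img, ref, peak=1.0, eps=1e-12):
    mse = np.mean((img.astype(np.float64) - ref.astype(np.float64)) ** 2)
    return 10.0 * np.log10((peak ** 2) / (mse + eps))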
Exemplo n.º 30
0
        
        
        # summary.add_scalar('dice', val_dice, epoch_idx)

        # if (batch_idx+1)%(args.log_interval) == 0 :
        print("Epoch {}/{}  loss {:2.4f}  Dice {:2.4f}  val-loss {:2.4f}  val-Dice {:2.4f}".
              format(epoch_idx, args.epochs,
                     losses/(batch_idx+1), dices/(batch_idx+1),
                     val_losses/(vbatch_idx+1), val_dices/(vbatch_idx+1)))
            
        if max_score <= val_dices/(vbatch_idx+1) :
            max_score = val_dices/(vbatch_idx+1)
            print("  max score : %f" %(max_score))
            torch.save(unet.state_dict(), dir_checkpoint+'unet_2000.pth')              

        
        summary.add_scalars('Loss', {'train_loss' : losses/(batch_idx+1), 'val_loss' : val_losses/(vbatch_idx+1)}, epoch_idx+offset)
        summary.add_scalars('Dice', {'train_dice' : dices/(batch_idx+1), 'val_dice' : val_dices/(vbatch_idx+1)}, epoch_idx+offset)
        
        
        # # if (batch_idx+1)%(args.log_interval) == 0 : 
    
        
        # print("Epoch {}/{}  Batch {}/{}  loss {:2.4f}  Dice {:2.4f} ". 
        #       format(epoch_idx, args.epochs, (batch_idx+1), len(train_dataloader), 
        #              losses/(batch_idx+1), dices/(batch_idx+1)))
        
        # if max_score <= dices/(batch_idx+1) :
        #     max_score = dices/(batch_idx+1)
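
# The dices / val_dices totals above imply a per-batch Dice helper that is not shown
# in this fragment. A minimal sketch with a hypothetical name, assuming sigmoid outputs
# and binary target masks given as float tensors of the same shape:
import torch


def dice_coeff(pred, target, threshold=0.5, eps=1e-7):
    pred_bin = (pred > threshold).float()
    intersection = (pred_bin * target).sum()
    return (2.0 * intersection + eps) / (pred_bin.sum() + target.sum() + eps)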