예제 #1
0
        print('step_now: {}, step loss: {}'.format(step_now,loss_cpu))
        if step_now % save_step == 0:
            val_loss = 0
            torch.save(model.state_dict(), './model-step%d.pth' % (step_now))
'''

## FP16 !
# Mixed-precision training: the forward pass and loss run under ``autocast``
# and the loss is scaled by ``scaler`` before backward; a checkpoint is
# written every ``save_step`` global steps.
steps_this_run = 0  # iterations executed by this loop (used for the timing below)
for i in range(num_epoches):
    for sample in train_dataloader:
        step_now += 1
        steps_this_run += 1
        img = sample["image"]
        mask = sample["mask"]
        optimizer.zero_grad()
        with autocast():
            pred = model(img.to(device))
            loss = criterion(pred, mask.to(device))
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        loss_cpu = loss.item()
        print('step_now: {}, step loss: {}'.format(step_now, loss_cpu))
        total_loss += loss_cpu
        if step_now % save_step == 0:
            val_loss = 0
            torch.save(model.state_dict(), './model-step%d.pth' % (step_now))

# CUDA kernels run asynchronously; wait for them before reading the clock.
# Guarded so the script also finishes on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.synchronize()
end = time.time()
# Average wall-clock time over the steps this loop actually ran (a fixed
# ``num_epoches * 2`` would only be right for exactly 2 batches per epoch).
print("training time per step(ms): ", (end - start) * 1000 / max(steps_this_run, 1))
torch.save(model.state_dict(), './model-all.pth')
예제 #2
0
def main():
    """Train and evaluate a UNet on the miccaiSeg dataset.

    Parses the command line into the module-level ``args`` global, builds the
    train/test dataloaders, optionally resumes model and optimizer state from
    ``args.resume``, then for every epoch in ``[args.start_epoch, args.epochs)``
    trains, validates, prints IoU / precision / recall / F1 from the evaluator
    and saves ``checkpoint_<epoch>.tar`` in ``args.save_dir``.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    # The flag may arrive as the string 'True'/'False'; normalise it to a bool.
    if args.saveTest == 'True':
        args.saveTest = True
    elif args.saveTest == 'False':
        args.saveTest = False

    # Create the checkpoint directory if needed (no-op when it already exists).
    os.makedirs(args.save_dir, exist_ok=True)

    cudnn.benchmark = True

    # NOTE: per-channel normalisation is currently disabled. The statistics
    # used previously were mean [0.295, 0.204, 0.197] and
    # std [0.221, 0.188, 0.182].
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            # TenCrop returns 10 crops; stack them into one (10, C, H, W) tensor.
            transforms.TenCrop(args.resizedImageSize),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
        ]),
        'test': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.ToTensor(),
        ]),
    }

    # Data Loading
    data_dir = 'datasets/miccaiSegRefined'
    # json path for class definitions
    json_path = 'datasets/miccaiSegClasses.json'

    image_datasets = {x: miccaiSegDataset(os.path.join(data_dir, x), data_transforms[x],
                                          json_path) for x in ['train', 'test']}

    # Only the training split is shuffled; evaluation order stays deterministic.
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=args.batchSize,
                                                  shuffle=(x == 'train'),
                                                  num_workers=args.workers)
                   for x in ['train', 'test']}

    # Get the dictionary for the id and RGB value pairs for the dataset
    classes = image_datasets['train'].classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    model = UNet(num_classes)
    criterion = nn.CrossEntropyLoss()

    # Move to the GPU *before* building the optimizer and restoring its state:
    # PyTorch recommends constructing optimizers after .cuda(), and
    # Optimizer.load_state_dict casts restored state to the params' device.
    if use_gpu:
        model.cuda()
        criterion.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # Guard against an unset resume path: os.path.isfile(None) raises TypeError.
    if args.resume and os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        # Load onto the CPU; load_state_dict copies into the params' device.
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        # Restore Adam's running moments as well (saved by the loop below).
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Halve the LR every 10 scheduler steps. NOTE: the scheduler state is not
    # part of the saved checkpoint, so it restarts from scratch on resume.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # Initialize an evaluation Object
    evaluator = utils.Evaluate(key, use_gpu)

    for epoch in range(args.start_epoch, args.epochs):
        # Train for one epoch
        print('>>>>>>>>>>>>>>>>>>>>>>>Training<<<<<<<<<<<<<<<<<<<<<<<')
        train(dataloaders['train'], model, criterion, optimizer, scheduler, epoch, key)

        # Evaluate on validation set
        print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
        validate(dataloaders['test'], model, criterion, epoch, key, evaluator)

        # Calculate the metrics accumulated by the evaluator during validate()
        print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
        IoU = evaluator.getIoU()
        print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
        PRF1 = evaluator.getPRF1()
        precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
        print('Mean Precision: {}, Class-wise Precision: {}'.format(torch.mean(precision), precision))
        print('Mean Recall: {}, Class-wise Recall: {}'.format(torch.mean(recall), recall))
        print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))
        evaluator.reset()

        # 'epoch' stores the next epoch to run, matching args.start_epoch on resume.
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=os.path.join(args.save_dir, 'checkpoint_{}.tar'.format(epoch)))
예제 #3
0
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')
                

    writer.close()


if __name__ == '__main__':
    # Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    net = UNet(n_channels=3, n_classes=1, bilinear=True)
    net.to(device=device)
    # cudnn.benchmark = True  # faster convolutions, but more memory

    train_kwargs = dict(
        epochs=5,
        batch_size=1,
        lr=0.001,
        device=device,
        val_percent=10.0 / 100,
    )
    try:
        train_net(net=net, **train_kwargs)
    except KeyboardInterrupt:
        # Persist the weights learned so far before exiting on Ctrl-C.
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            # os._exit terminates immediately, skipping normal interpreter cleanup.
            os._exit(0)
예제 #4
0
# resume
# Either restore a full training checkpoint (model, optimizer, best IoU) or,
# when resuming is disabled / the file is missing, warm-start the model from
# pretrained UNet weights. The flag is tested first so the path is only
# inspected when resuming was actually requested.
if resume_flag and osp.isfile(resume_path):
    checkpoint = torch.load(resume_path)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    best_iou = checkpoint['best_iou']
    # Restoring the scheduler state and start epoch is currently disabled:
    # scheduler.load_state_dict(checkpoint["scheduler_state"])
    # start_epoch = checkpoint["epoch"]
    print(
        "=====>",
        "Loaded checkpoint '{}' (iter {})".format(resume_path,
                                                  checkpoint["epoch"]))
else:
    print("=====>", "No checkpoint found at '{}'".format(resume_path))
    print("load unet weight and bias")
    model_dict = model.state_dict()
    # Load onto the CPU; load_state_dict below copies into the model's device.
    pretrained_dict = torch.load("/home/cv_xfwang/Pytorch-UNet/MODEL.pth",
                                 map_location='cpu')
    # Keep only tensors whose name AND shape match the current model, so a
    # layer with a different size (e.g. another number of classes) keeps its
    # initial values instead of making load_state_dict raise.
    compatible = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    skipped = [k for k in pretrained_dict if k not in compatible]
    if skipped:
        print("skipped {} incompatible pretrained tensors: {}".format(
            len(skipped), skipped))
    pretrained_dict = compatible
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)


# Training
def train(epoch):
    print("Epoch: ", epoch)
    model.train()
    total_loss = 0
    # for index, (img, mask) in tqdm(enumerate(train_loader)):
예제 #5
0
class BaseModel:
    """Segmentation training/evaluation wrapper driven by an ``args`` namespace.

    ``__init__`` builds the network (``args.model_name``), the loss
    (``args.loss``) and an SGD optimizer; ``init_model`` optionally loads
    pretrained weights and moves/wraps the network; ``train_val`` runs the
    train + validation loop, logging through ``write_scalars`` and saving
    weights via ``save_network``.
    """

    # Class-level defaults kept for backward compatibility. ``__init__`` gives
    # every instance its own dicts so metric buffers are never shared (and
    # mutated) across instances.
    losses = {'train': [], 'val': []}
    acces = {'train': [], 'val': []}
    scores = {'train': [], 'val': []}
    pred = {'train': [], 'val': []}
    true = {'train': [], 'val': []}

    def __init__(self, args):
        """Build network, loss and optimizer from ``args``.

        Raises:
            RuntimeError: ``args.model_name`` or ``args.loss`` is unknown.
            NotImplementedError: multi-scale ``deeplab50_v2`` was requested.
        """
        self.args = args
        # Per-instance metric buffers, keyed by phase.
        self.losses = {'train': [], 'val': []}
        self.acces = {'train': [], 'val': []}
        self.scores = {'train': [], 'val': []}
        self.pred = {'train': [], 'val': []}
        self.true = {'train': [], 'val': []}

        print(args.model_name)
        self.net = self._build_net(args)

        # Used to resize network outputs to ``args.size`` when they differ.
        self.interp = nn.Upsample(size=args.size, mode='bilinear')

        self.iterations = args.epochs
        self.lr_current = args.lr
        self.cuda = args.cuda
        self.phase = args.phase
        self.lr_policy = args.lr_policy
        self.cyclic_m = args.cyclic_m
        if self.lr_policy == 'cyclic':
            print('using cyclic')
            assert self.iterations % self.cyclic_m == 0
        self.criterion = self._build_criterion(args)

        if 'deeplab' in args.model_name:
            # Two param groups: group 1 (get_10x_lr_params) uses 10x the base LR.
            self.optimizer = optim.SGD(
                [{
                    'params': get_1x_lr_params_NOscale(self.net),
                    'lr': args.lr
                }, {
                    'params': get_10x_lr_params(self.net),
                    'lr': 10 * args.lr
                }],
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)
        else:
            self.optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                              self.net.parameters()),
                                       lr=args.lr,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay)
        self.iters = 0
        self.best_val = 0.0
        self.count = 0

    @staticmethod
    def _build_net(args):
        """Instantiate the network named by ``args.model_name``."""
        name = args.model_name
        if name == 'UNet':
            net = UNet(args.in_channels, args.num_classes)
            net.apply(weights_init)
            return net
        if name == 'UNetResNet34':
            return UNetResNet34(args.num_classes, dropout_2d=0.2)
        if name == 'UNetResNet152':
            return UNetResNet152(args.num_classes, dropout_2d=0.2)
        if name == 'UNet11':
            return UNet11(args.num_classes, pretrained=True)
        if name == 'UNetVGG16':
            return UNetVGG16(args.num_classes,
                             pretrained=True,
                             dropout_2d=0.0,
                             is_deconv=True)
        if name == 'deeplab50_v2':
            if args.ms:
                raise NotImplementedError(
                    'multi-scale deeplab50_v2 is not implemented')
            return deeplab50_v2(args.num_classes, pretrained=args.pretrained)
        if name == 'deeplab_v2':
            if args.ms:
                return ms_deeplab_v2(args.num_classes,
                                     pretrained=args.pretrained,
                                     scales=args.ms_scales)
            return deeplab_v2(args.num_classes, pretrained=args.pretrained)
        if name == 'deeplab_v3':
            if args.ms:
                return ms_deeplab_v3(args.num_classes,
                                     out_stride=args.out_stride,
                                     pretrained=args.pretrained,
                                     scales=args.ms_scales)
            return deeplab_v3(args.num_classes,
                              out_stride=args.out_stride,
                              pretrained=args.pretrained)
        if name == 'deeplab_v3_plus':
            if args.ms:
                return ms_deeplab_v3_plus(args.num_classes,
                                          out_stride=args.out_stride,
                                          pretrained=args.pretrained,
                                          scales=args.ms_scales)
            return deeplab_v3_plus(args.num_classes,
                                   out_stride=args.out_stride,
                                   pretrained=args.pretrained)
        raise RuntimeError('unknown model_name: {}'.format(name))

    @staticmethod
    def _build_criterion(args):
        """Instantiate the loss named by ``args.loss``."""
        if args.loss == 'CELoss':
            # reduction='mean' is the non-deprecated spelling of size_average=True.
            return nn.CrossEntropyLoss(reduction='mean')
        if args.loss == 'DiceLoss':
            return DiceLoss(num_classes=args.num_classes)
        if args.loss == 'MixLoss':
            return MixLoss(args.num_classes, weights=args.loss_weights)
        if args.loss == 'LovaszLoss':
            return LovaszSoftmax(per_image=args.loss_per_img)
        if args.loss == 'FocalLoss':
            return FocalLoss(args.num_classes, alpha=None, gamma=2)
        raise RuntimeError('must define loss')

    @staticmethod
    def _is_transferable_key(key):
        """Return True if a pretrained state-dict ``key`` should be copied.

        Keys whose first dotted component is ``layer5`` or ``decoder``
        (e.g. ``layer5.conv2d_list.3.weight``) are skipped.
        """
        return key.split('.')[0] not in ('layer5', 'decoder')

    def init_model(self):
        """Optionally load pretrained weights, then wrap and move the network.

        With ``args.resume_model`` set, every saved tensor accepted by
        ``_is_transferable_key`` overrides the freshly initialised value (in
        ``self.net.Scale`` for multi-scale models). Afterwards the network is
        wrapped in ``nn.DataParallel`` if ``args.mGPUs`` and moved to CUDA if
        ``args.cuda``.
        """
        if self.args.resume_model:
            saved_state_dict = torch.load(
                self.args.resume_model,
                map_location=lambda storage, loc: storage)
            # Multi-scale wrappers keep the weights in their ``Scale`` module.
            target = self.net.Scale if self.args.ms else self.net
            new_params = target.state_dict().copy()
            for key, value in saved_state_dict.items():
                if self._is_transferable_key(key):
                    new_params[key] = value
            target.load_state_dict(new_params)

            print('Resuming training, image net loading {}...'.format(
                self.args.resume_model))

        if self.args.mGPUs:
            self.net = nn.DataParallel(self.net)

        if self.args.cuda:
            self.net = self.net.cuda()
            cudnn.benchmark = True

    def _adjust_learning_rate(self, epoch):
        """Multiply the current LR by ``args.gamma``, clamped to a floor.

        The floor depends on progress through ``self.iterations``: 1e-4 before
        50%, 1e-5 before 85%, 1e-6 afterwards. Param group 0 gets the new LR;
        a second group (present for deeplab models) gets 10x that value.
        """
        if epoch < int(self.iterations * 0.5):
            floor = 1e-4
        elif epoch < int(self.iterations * 0.85):
            floor = 1e-5
        else:
            floor = 1e-6
        self.lr_current = max(self.lr_current * self.args.gamma, floor)
        groups = self.optimizer.param_groups
        groups[0]['lr'] = self.lr_current
        # Non-deeplab optimizers have a single param group.
        if len(groups) > 1:
            groups[1]['lr'] = self.lr_current * 10

    def _net_to_save(self):
        """Return the module whose weights are checkpointed.

        Unwraps ``nn.DataParallel`` (``.module``) when ``args.mGPUs`` is set
        and selects the ``Scale`` sub-network for multi-scale models.
        """
        net = self.net.module if self.args.mGPUs else self.net
        return net.Scale if self.args.ms else net

    def save_network(self, net, net_name, epoch, label=''):
        """Save ``net.state_dict()`` to ``<save_folder>/<exp_name>/<epoch>_<net_name>_<label>.pth``.

        The experiment directory is created if it does not exist yet.
        """
        save_fname = '%s_%s_%s.pth' % (epoch, net_name, label)
        save_dir = os.path.join(self.args.save_folder, self.args.exp_name)
        os.makedirs(save_dir, exist_ok=True)
        torch.save(net.state_dict(), os.path.join(save_dir, save_fname))

    def load_weights(self, net, base_file):
        """Load a ``.pth``/``.pkl`` state dict from ``base_file`` into ``net``.

        Any other extension is rejected with a message and nothing is loaded.
        """
        _, ext = os.path.splitext(base_file)
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            net.load_state_dict(
                torch.load(base_file,
                           map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')

    def load_trained_model(self):
        """Load ``args.trained_model`` from the experiment folder into the network."""
        path = os.path.join(self.args.save_folder, self.args.exp_name,
                            self.args.trained_model)
        print('eval cls, image net loading {}...'.format(path))
        if self.args.ms:
            self.load_weights(self.net.Scale, path)
        else:
            self.load_weights(self.net, path)

    def _resize_like(self, out, ref, dim=2):
        """Apply ``self.interp`` to ``out`` when ``out.size(dim) != ref.size(dim)``."""
        if out.size(dim) != ref.size(dim):
            out = self.interp(out)
        return out

    def _merge_tta(self, out):
        """Average TTA predictions stacked along the batch axis.

        ``out`` is laid out as [num_tta * bs, C, H, W]; it is viewed as
        [num_tta, bs, C, H, W], passed through ``detta_score`` (presumably
        reverting each augmentation — confirm) and averaged over TTA copies.
        """
        num_tta = len(tta_config)
        out = detta_score(
            out.view(num_tta, -1, self.args.num_classes, out.size(2),
                     out.size(3)))  # [num_tta, bs, nclass, H, W]
        return out.mean(dim=0)  # [bs, nclass, H, W]

    def eval(self, dataloader):
        """Predict class-1 probability maps (resized to 101x101) for ``dataloader``.

        For list outputs (multi-scale models) the last element is used.
        Runs without autograd. Returns a numpy array with one map per sample.
        """
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        output = []

        with torch.no_grad():
            for image in tqdm(dataloader):
                if self.cuda:
                    image = image.cuda()

                out = self.net(image)
                if isinstance(out, list):
                    out = out[-1]
                out = self._resize_like(out, image)
                if self.args.use_tta:
                    out = self._merge_tta(out)
                out = F.softmax(out, dim=1)
                output.extend(
                    resize(pred[1].cpu().numpy(), (101, 101)) for pred in out)
        return np.array(output)

    def _sum_eval_outputs(self, dataloaders):
        """Sum ``self.eval`` outputs over several (e.g. TTA) dataloaders.

        The accumulator takes the shape of the first ``eval`` result, so it
        matches whatever per-sample array ``eval`` returns.

        Raises:
            ValueError: ``dataloaders`` is empty.
        """
        if not dataloaders:
            raise ValueError('dataloaders must not be empty')
        results = None
        for dataloader in dataloaders:
            output = self.eval(dataloader).astype(np.float64)
            results = output if results is None else results + output
        return results

    def tta(self, dataloaders):
        """Return ``argmax`` over axis 1 of the summed ``eval`` outputs.

        NOTE(review): ``eval`` returns per-sample 101x101 maps, so axis 1 is a
        spatial axis here rather than a class axis — confirm the intended
        reduction.
        """
        return np.argmax(self._sum_eval_outputs(dataloaders), 1)

    def tta_output(self, dataloaders):
        """Return the element-wise sum of ``eval`` outputs over ``dataloaders``."""
        return self._sum_eval_outputs(dataloaders)

    def test_val(self, dataloader):
        """Evaluate on a labelled loader and print mIoU over a threshold sweep.

        Thresholds run from 0.05 to 0.50 in steps of 0.01. Returns
        ``(predict, true)``: lists of 101x101 class-1 probability maps and
        the corresponding label arrays.
        """
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        predict = []
        targets = []

        with torch.no_grad():
            for image, mask in tqdm(dataloader):
                if self.cuda:
                    image, mask = image.cuda(), mask.cuda()
                label_image = mask

                out = self.net(image)
                if isinstance(out, list):
                    out = self._resize_like(out[-1], label_image)
                else:
                    out = self._resize_like(out, image)
                if self.args.use_tta:
                    out = self._merge_tta(out)
                out = F.softmax(out, dim=1)
                if self.args.aug == 'heng':
                    # Crop the 202x202 centre (offset 11 on each side).
                    out = out[:, :, 11:11 + 202, 11:11 + 202]
                predict.extend(
                    resize(pred[1].cpu().numpy(), (101, 101)) for pred in out)
                targets.extend(label_image.cpu().numpy())

        probs = np.array(predict)
        true_all = np.array(targets).astype(int)  # np.int was an alias of int
        for t in np.arange(0.05, 0.51, 0.01):
            pred_all = probs > t
            mean_iou, iou_t = mIoU(true_all, pred_all)
            print('threshold : {:.4f}'.format(t))
            print('mean IoU : {:.4f}, IoU threshold : {:.4f}'.format(
                mean_iou, iou_t))

        return predict, targets

    def run_epoch(self, dataloader, writer, epoch, train=True, metrics=True):
        """Run one pass over ``dataloader`` in train or validation mode.

        Training updates the weights (and, unless ``lr_policy == 'step'``,
        the per-iteration LR). With ``metrics`` set, loss / accuracy / mIoU /
        LR are written via ``write_scalars`` and printed, the phase buffers
        are cleared, and in validation mode the best model is saved and the
        plateau-based step LR decay is applied.
        """
        if train:
            self.net.train()
            flag = 'train'
        else:
            self.net.eval()
            flag = 'val'
        t2 = time.time()
        for image, mask in dataloader:
            if train and self.lr_policy != 'step':
                adjust_learning_rate(self.args.lr, self.optimizer, self.iters,
                                     self.iterations * len(dataloader), 0.9,
                                     self.cyclic_m, self.lr_policy)
                self.iters += 1

            if self.cuda:
                image, mask = image.cuda(), mask.cuda()
            label_image = mask

            # Record the autograd graph only while training.
            with torch.set_grad_enabled(train):
                out = self.net(image)
                if isinstance(out, list):
                    # Deep supervision: sum the loss over every output scale;
                    # the last scale is the one used for metrics.
                    loss = 0.0
                    for out_scale in out:
                        out_scale = self._resize_like(out_scale, label_image)
                        loss += self.criterion(out_scale, label_image)
                    out = out_scale
                else:
                    out = self._resize_like(out, label_image, dim=-1)
                    loss = self.criterion(out, label_image)

            label_image_np = label_image.cpu().numpy()
            sig_out_np = out.detach().cpu().numpy()
            acc = accuracy(label_image_np, np.argmax(sig_out_np, 1))

            self.pred[flag].extend(sig_out_np)
            self.true[flag].extend(label_image_np)
            self.losses[flag].append(loss.item())
            self.acces[flag].append(acc)

            if train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        if not metrics:
            return

        n = len(self.losses[flag])
        loss = sum(self.losses[flag]) / n
        write_scalars(writer, [loss], ['loss'], epoch, tag=flag + '_loss')

        all_acc = sum(self.acces[flag]) / n
        write_scalars(writer, [all_acc], ['all_acc'], epoch, tag=flag + '_acc')

        pred_all = np.argmax(np.array(self.pred[flag]), 1)
        true_all = np.array(self.true[flag]).astype(int)
        mean_iou, iou_t = mIoU(true_all, pred_all)
        write_scalars(writer, [mean_iou, iou_t], ['mIoU', 'mIoU_threshold'],
                      epoch, tag=flag + '_IoU')

        lr = self.optimizer.param_groups[0]['lr']
        write_scalars(writer, [lr], ['learning_rate'], epoch, tag=flag + '_lr')

        print(
            '{} loss: {:.4f} | acc: {:.4f} | mIoU: {:.4f} | mIoU_threshold: {:.4f} |  n_iter: {} |  learning_rate: {} | time: {:.2f}'
            .format(flag, loss, all_acc, mean_iou, iou_t, epoch, lr,
                    time.time() - t2))

        for buffers in (self.losses, self.pred, self.true, self.acces,
                        self.scores):
            buffers[flag] = []

        if train:
            return

        if iou_t >= self.best_val:
            self.save_network(self._net_to_save(),
                              self.args.model_name,
                              epoch=epoch,
                              label='best')
            print(
                'val improve from {:.4f} to {:.4f} saving in best val_iteration {}'
                .format(self.best_val, iou_t, epoch))
            self.best_val = iou_t
            self.count = 0

        # 'step' policy: decay the LR after 10 validations that trail the
        # best score by more than 0.003.
        if self.lr_policy == 'step':
            if (self.best_val - iou_t > 0.003) and (self.count < 10):
                self.count += 1
            if self.count >= 10:
                self._adjust_learning_rate(epoch)
                self.count = 0

    def train_val(self, dataloader_train, dataloader_val, writer):
        """Alternate training and validation epochs for ``self.iterations`` epochs.

        With the cyclic LR policy, ``best_val`` is reset at the start of each
        cycle. Weights are saved every ``args.save_freq`` epochs.
        """
        val_epoch = 0
        for epoch in range(self.iterations):
            if self.lr_policy == 'cyclic':
                cycle_len = int(self.iterations / self.cyclic_m)
                if epoch % cycle_len == 0:
                    print('-------start cycle {}------------'.format(
                        epoch // cycle_len))
                    self.best_val = 0.0
            self.run_epoch(dataloader_train,
                           writer,
                           epoch,
                           train=True,
                           metrics=True)
            self.run_epoch(dataloader_val,
                           writer,
                           val_epoch,
                           train=False,
                           metrics=True)
            val_epoch += 1
            if (epoch + 1) % self.args.save_freq == 0:
                self.save_network(
                    self._net_to_save(),
                    self.args.model_name,
                    epoch=val_epoch,
                )
                print('saving in val_iteration {}'.format(val_epoch))
예제 #6
0
def train(train_sources, eval_source, epochs=3):
    """Train a UNet segmentator with a vendor discriminator and adversarial alignment.

    Data is read from the directory given as ``sys.argv[1]``. ``train_sources``
    are the annotated domains used for supervised segmentation. The training split
    of ``eval_source`` is added only to the data used for the discriminator and
    adversarial steps.

    Training is staged by epoch:

    * ``epoch < 50``: segmentation step only.
    * ``50 <= epoch < 100``: discriminator step only (no optimizer updates the
      segmentator).
    * ``epoch >= 100``: segmentation, discriminator and adversarial encoder steps.
      At epoch 100 the segmentator lr is raised to 1e-3 and the discriminator lr
      lowered to 1e-4.

    After every epoch, train/dev losses and Dice scores are recorded. At the end:

    * the curves are plotted to ``losses.png`` and ``dices.png``;
    * the segmentator weights are saved under ``model/weights/`` next to this file;
    * test Dice on the annotated and unannotated domains is printed.

    Args:
        train_sources: domain identifiers passed to ``DataReader`` for the
            annotated (supervised) data.
        eval_source: a single domain identifier used as the unannotated domain.
        epochs: number of training epochs. Defaults to 3, the previously
            hard-coded value. Note that the discriminator and adversarial phases
            only begin at epochs 50 and 100.
    """
    path = sys.argv[1]
    dr = DataReader(path, train_sources)
    dr.read()
    print(len(dr.train.x))

    batch_size = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset_s_train = MultiDomainDataset(dr.train.x, dr.train.y, dr.train.vendor, device, DomainAugmentation())
    dataset_s_dev = MultiDomainDataset(dr.dev.x, dr.dev.y, dr.dev.vendor, device)
    dataset_s_test = MultiDomainDataset(dr.test.x, dr.test.y, dr.test.vendor, device)
    loader_s_train = DataLoader(dataset_s_train, batch_size, shuffle=True)

    dr_eval = DataReader(path, [eval_source])
    dr_eval.read()

    dataset_eval_dev = MultiDomainDataset(dr_eval.dev.x, dr_eval.dev.y, dr_eval.dev.vendor, device)
    dataset_eval_test = MultiDomainDataset(dr_eval.test.x, dr_eval.test.y, dr_eval.test.vendor, device)

    # Domain-adaptation data: annotated + unannotated-domain training images,
    # used with their vendor labels by the discriminator/adversarial steps.
    dataset_da_train = MultiDomainDataset(
        dr.train.x + dr_eval.train.x,
        dr.train.y + dr_eval.train.y,
        dr.train.vendor + dr_eval.train.vendor,
        device,
        DomainAugmentation(),
    )
    loader_da_train = DataLoader(dataset_da_train, batch_size, shuffle=True)

    segmentator = UNet()
    # NOTE(review): the discriminator has len(train_sources) output classes, but
    # loader_da_train also yields eval_source's vendor labels -- confirm those
    # labels stay within range for CrossEntropyLoss.
    discriminator = Discriminator(n_domains=len(train_sources))
    discriminator.to(device)
    segmentator.to(device)

    sigmoid = nn.Sigmoid()
    selector = Selector()

    s_criterion = nn.BCELoss()
    d_criterion = nn.CrossEntropyLoss()
    s_optimizer = optim.AdamW(segmentator.parameters(), lr=0.0001, weight_decay=0.01)
    d_optimizer = optim.AdamW(discriminator.parameters(), lr=0.001, weight_decay=0.01)
    # Adversarial updates only touch the encoder parameters.
    a_optimizer = optim.AdamW(segmentator.encoder.parameters(), lr=0.001, weight_decay=0.01)
    lmbd = 1 / 150  # adversarial loss weight; grows by 1/150 after every adversarial step
    s_train_losses = []
    s_dev_losses = []
    d_train_losses = []
    eval_domain_losses = []
    train_dices = []
    dev_dices = []
    eval_dices = []
    da_loader_iter = iter(loader_da_train)
    for epoch in tqdm(range(epochs)):
        if epoch == 100:
            # Writing optimizer.defaults does not affect existing parameter groups;
            # the learning rate actually used lives in param_groups.
            for group in s_optimizer.param_groups:
                group['lr'] = 0.001
            for group in d_optimizer.param_groups:
                group['lr'] = 0.0001
        s_train_loss = 0.0
        d_train_loss = 0.0
        for sample in loader_s_train:
            img = sample['image']
            target_mask = sample['target']

            # Cycle the domain-adaptation loader independently of the supervised one.
            da_sample = next(da_loader_iter, None)
            if da_sample is None:
                da_loader_iter = iter(loader_da_train)
                da_sample = next(da_loader_iter, None)

            if epoch < 50 or epoch >= 100:
                # Training step of segmentator
                predicted_activations, _ = segmentator(img)
                predicted_mask = sigmoid(predicted_activations)
                s_loss = s_criterion(predicted_mask, target_mask)
                s_optimizer.zero_grad()
                s_loss.backward()
                s_optimizer.step()
                s_train_loss += s_loss.item()

            if epoch >= 50:
                # Training step of discriminator: segmentator outputs are fixed
                # inputs here, so no autograd graph is built through it.
                with torch.no_grad():
                    predicted_activations, inner_repr = segmentator(da_sample['image'])
                predicted_vendor = discriminator(predicted_activations, inner_repr)
                d_loss = d_criterion(predicted_vendor, da_sample['vendor'])
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()
                d_train_loss += d_loss.item()

            if epoch >= 100:
                # Adversarial step: update the encoder to *increase* the
                # discriminator's loss (note the negative sign).
                predicted_activations, inner_repr = segmentator(da_sample['image'])
                predicted_vendor = discriminator(predicted_activations, inner_repr)
                a_loss = -lmbd * d_criterion(predicted_vendor, da_sample['vendor'])
                a_optimizer.zero_grad()
                a_loss.backward()
                a_optimizer.step()
                lmbd += 1 / 150
        inference_model = nn.Sequential(segmentator, selector, sigmoid)
        inference_model.to(device)
        inference_model.eval()
        d_train_losses.append(d_train_loss / len(loader_s_train))
        s_train_losses.append(s_train_loss / len(loader_s_train))
        s_dev_losses.append(calculate_loss(dataset_s_dev, inference_model, s_criterion, batch_size))
        eval_domain_losses.append(calculate_loss(dataset_eval_dev, inference_model, s_criterion, batch_size))

        train_dices.append(calculate_dice(inference_model, dataset_s_train))
        dev_dices.append(calculate_dice(inference_model, dataset_s_dev))
        eval_dices.append(calculate_dice(inference_model, dataset_eval_dev))

        # eval() above switched the segmentator to inference mode; switch back.
        segmentator.train()

    date_time = datetime.now().strftime("%m%d%Y_%H%M%S")
    weights_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "model", "weights")
    os.makedirs(weights_dir, exist_ok=True)  # torch.save does not create missing directories
    model_path = os.path.join(weights_dir, "segmentator" + date_time + ".pth")
    torch.save(segmentator.state_dict(), model_path)

    util.plot_data([(s_train_losses, 'train_losses'), (s_dev_losses, 'dev_losses'), (d_train_losses, 'discriminator_losses'),
                    (eval_domain_losses, 'eval_domain_losses')],
                   'losses.png')
    util.plot_dice([(train_dices, 'train_dice'), (dev_dices, 'dev_dice'), (eval_dices, 'eval_dice')],
                   'dices.png')

    inference_model = nn.Sequential(segmentator, selector, sigmoid)
    inference_model.to(device)
    inference_model.eval()

    print('Dice on annotated: ', calculate_dice(inference_model, dataset_s_test))
    print('Dice on unannotated: ', calculate_dice(inference_model, dataset_eval_test))
예제 #7
0
    # Script entry: pick a device, build class weights, then train a UNet.
    # NOTE(review): `model_path`, `utils`, `UNet` and `train_net` come from outside
    # this fragment -- confirm they are defined at module level.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    n_classes = utils.params.n_classes
    n_channels = utils.params.n_channels
    print(
        "Number of input channels = {}. Number of classes = {}.".format(
            n_channels, n_classes
        )
    )
    # Per-class loss weights; tuned values exist only for 2- and 3-class setups.
    # `np.float` was removed in NumPy 1.24; builtin `float` yields the same float64 dtype.
    if n_classes == 3:
        class_weights = np.array([0.3, 0.5, 1], dtype=float)
    elif n_classes == 2:
        class_weights = np.array([0.047619, 1], dtype=float)
    else:
        # Previously fell through to a NameError on `class_weights`.
        raise ValueError(
            "No class weights defined for n_classes = {}".format(n_classes)
        )
    print("Current class weights = {}".format(class_weights))
    assert (
        len(class_weights) == n_classes
    ), "Should be a 1D Tensor assigning weight to each of the classes. Lenght of the weights-vector should be equal to the number of classes"

    net = UNet(n_channels=n_channels, n_classes=n_classes)
    net.to(device=device)

    try:
        print("Training starting...")
        train_net(net, n_channels, n_classes, class_weights)
        print("Training done")
    except KeyboardInterrupt:
        # Save a checkpoint on Ctrl-C, then exit. os._exit is the fallback if
        # SystemExit is intercepted.
        torch.save(net.state_dict(), model_path + "INTERRUPTED.pth")
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)