Example #1
def main():
    # load the config
    config = parse_train_config()
    # load the model
    model = Unet3d(in_channels=config.in_channels,
                   out_channels=config.out_channels,
                   interpolate=config.interpolate,
                   concatenate=config.concatenate,
                   norm_type=config.norm_type,
                   init_channels=config.init_channels,
                   scale_factor=(2, 2, 2))

    if config.init_weight:
        model.apply(init_weight)

    # get the device to train on
    gpu_all = tuple(config.gpu_index)
    gpu_main = gpu_all[0]

    device = torch.device(
        'cuda:' + str(gpu_main) if torch.cuda.is_available() else 'cpu')
    model = nn.DataParallel(model, device_ids=gpu_all)
    model.to(device)

    # load data
    phase = 'train'
    train_dataset = Hdf5Dataset(config.data_path, phase,
                                config.train_sub_index)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=config.train_batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)
    val_dataset = Hdf5Dataset(config.data_path, phase, config.val_sub_index)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=config.val_batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)

    # define accuracy
    accuracy_criterion = DiceAccuracy()

    # define loss
    if config.loss_weight is None:
        loss_criterion = DiceLoss()
    else:
        loss_criterion = DiceLoss(weight=config.loss_weight)

    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config.learning_rate,
                                 weight_decay=config.weight_decay)

    trainer = Trainer(config, model, device, train_loader, val_loader,
                      accuracy_criterion, loss_criterion, optimizer)
    trainer.main()
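All of the snippets on this page construct DiceLoss() and call it on predictions and targets, but none of them show the loss itself. For reference, a minimal soft Dice loss for binary segmentation could look like the sketch below; the smooth term, the sigmoid on raw logits, and the class name SimpleDiceLoss are assumptions of this sketch, and the DiceLoss implementations used in the examples may differ (some accept class weights, OHEM ratios, or a with_logits flag).

import torch
import torch.nn as nn


class SimpleDiceLoss(nn.Module):
    """Minimal soft Dice loss for binary segmentation (illustrative sketch)."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, target):
        # logits: raw network outputs of shape (N, 1, H, W); target: binary mask of the same shape
        probs = torch.sigmoid(logits)
        probs = probs.reshape(probs.size(0), -1)
        target = target.reshape(target.size(0), -1).float()
        intersection = (probs * target).sum(dim=1)
        union = probs.sum(dim=1) + target.sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        # per-sample Dice is averaged and inverted so that lower is better
        return 1.0 - dice.mean()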
Example #2
def train_net(args):
    cropsize = [cfgs.crop_height, cfgs.crop_width]
    # dataset_train = CityScapes(cfgs.data_dir, cropsize=cropsize, mode='train')
    dataset_train = ContextVoc(cfgs.train_file, cropsize=cropsize, mode='train')
    dataloader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True
    )
    # dataset_val = CityScapes(cfgs.data_dir,  mode='val')
    dataset_val = ContextVoc(cfgs.val_file, mode='val')
    dataloader_val = DataLoader(
        dataset_val,
        batch_size=1,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True
    )
    # build net
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    if torch.cuda.is_available() and args.use_gpu:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    net = BiSeNet(cfgs.num_classes, cfgs.netname).to(device)
    # net = BiSeNet(cfgs.num_classes).to(device)
    if args.mulgpu:
        net = torch.nn.DataParallel(net)
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        load_dict = torch.load(args.pretrained_model_path, map_location=device)
        dict_new = renamedict(net.module.state_dict(), load_dict)
        net.module.load_state_dict(dict_new, strict=False)
        # net.load_state_dict(torch.load(args.pretrained_model_path))
        print('Done!')
    net.train()
    # build optimizer
    if args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(net.parameters(), args.learning_rate)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(), args.learning_rate, momentum=0.9, weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), args.learning_rate)
    else:  
        print('unsupported optimizer\n')
        optimizer = None
    #build loss
    if args.losstype == 'dice':
        criterion = DiceLoss()
    elif args.losstype == 'crossentropy':
        criterion = torch.nn.CrossEntropyLoss()
    elif args.losstype == 'ohem':
        score_thres = 0.7
        n_min = args.batch_size * cfgs.crop_height * cfgs.crop_width //16
        criterion = OhemCELoss(thresh=score_thres, n_min=n_min)
    elif args.losstype == 'focal':
        criterion = SoftmaxFocalLoss()
    return net, optimizer, criterion, dataloader_train, dataloader_val
Example #3
def train(model, train_loader, device, optimizer):
    model.train()
    steps = len(train_loader)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, steps)
    train_loss = 0.0
    dsc_loss = DiceLoss()

    progress_bar = st.sidebar.progress(0)

    for i, data in enumerate(train_loader):
        x, y = data

        optimizer.zero_grad()
        y_pred = model(x.to(device))
        loss = dsc_loss(y_pred, y.to(device))
        #train_loss_list.append(loss.item())
        #train_loss_detail.line_chart(np.array(train_loss_list))
        progress_bar.progress((i+1)/len(train_loader))

        loss.backward()
        optimizer.step()
        scheduler.step()

        train_loss += loss.item()
    return model, train_loss/len(train_loader), optimizer
Example #4
def train(args, model, optimizer, dataloader_train, dataloader_val):
    writer = SummaryWriter(
        comment='{}_{}'.format(args.optimizer, args.context_path))
    if args.loss == 'dice':
        loss_func = DiceLoss()
    elif args.loss == 'crossentropy':
        loss_func = torch.nn.CrossEntropyLoss(ignore_index=255)
    max_miou = 0
    step = 0
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(args.num_epochs):
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        model.train()
        tq = tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        for i, (data, label) in enumerate(dataloader_train):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda().long()
            with torch.cuda.amp.autocast():
                output, output_sup1, output_sup2 = model(data)
                loss1 = loss_func(output, label)
                loss2 = loss_func(output_sup1, label)
                loss3 = loss_func(output_sup2, label)
                loss = loss1 + loss2 + loss3
                tq.update(args.batch_size)
                tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            step += 1
            writer.add_scalar('loss_step', loss, step)
            loss_record.append(loss.item())
            scaler.update()
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('epoch/loss_epoch_train', float(loss_train_mean),
                          epoch)
        print('loss for train : %f' % (loss_train_mean))
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(model.state_dict(),
                       os.path.join(args.save_model_path, 'model.pth'))

        if epoch % args.validation_step == 0 and epoch != 0:
            precision, miou = val(args, model, dataloader_val)
            if miou > max_miou:
                max_miou = miou
                #import os
                os.makedirs(args.save_model_path, exist_ok=True)
                torch.save(
                    model.state_dict(),
                    os.path.join(args.save_model_path, 'best_dice_loss.pth'))
            writer.add_scalar('epoch/precision_val', precision, epoch)
            writer.add_scalar('epoch/miou val', miou, epoch)
Example #5
def main():
    args = parser.parse_args()
    
    dataset = SyntheticCellDataset(args.img_dir, args.mask_dir)
    
    indices = torch.randperm(len(dataset)).tolist()
    sr = int(args.split_ratio * len(dataset))
    train_set = torch.utils.data.Subset(dataset, indices[:-sr])
    val_set = torch.utils.data.Subset(dataset, indices[-sr:])
    
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, pin_memory=True)
    
    device = torch.device("cpu" if not args.use_cuda else "cuda:0")
    
    model = UNet()
    model.to(device)
    
    dsc_loss = DiceLoss()
    
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    
    val_overall = 1000
    for epoch in range(args.N_epoch):
        model, train_loss, optimizer = train(model, train_loader, device, optimizer)
        val_loss = validate(model, val_loader, device)
        
        if val_loss < val_overall:
            save_checkpoint(args.model_save_dir + '/epoch_'+str(epoch+1), model, train_loss, val_loss, epoch)
            val_overall = val_loss
            
        print('[{}/{}] train loss: {} val loss: {}'.format(epoch+1, args.N_epoch, train_loss, val_loss))
    print('Training completed')
Example #6
 def __init__(self, *args, lr=1e-3, **kwargs):
     super().__init__(*args, **kwargs)
     self.crit = DiceLoss()
     # self.crit = nn.BCEWithLogitsLoss()
     self.accuracy = pl.metrics.Accuracy()
     self.f1 = pl.metrics.F1()
     self.lr = lr
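Example #6 only shows the __init__ of a PyTorch Lightning module. A training_step and optimizer hook that could accompany it are sketched below; the SegModule name, the placeholder network, and the reuse of the SimpleDiceLoss sketch given after Example #1 are assumptions, since the rest of the original class is not shown.

import torch
import torch.nn as nn
import pytorch_lightning as pl


class SegModule(pl.LightningModule):
    """Illustrative Lightning module built around a Dice criterion (sketch only)."""

    def __init__(self, lr=1e-3):
        super().__init__()
        self.net = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # placeholder segmentation head
        self.crit = SimpleDiceLoss()  # stands in for the DiceLoss() used in the examples
        self.lr = lr

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        # batch is assumed to be an (image, mask) pair from the DataLoader
        x, y = batch
        loss = self.crit(self(x), y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)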
Example #7
 def get_crit(loss_type):
     if loss_type == 'dice':
         return DiceLoss()
     elif loss_type == 'bce':
         return nn.BCEWithLogitsLoss()
     else:
         raise ValueError(
             f'Unsupported loss: {loss_type}. Choose one of [dice, bce].')
Example #8
def test_plot(path="data", num_epochs=1, start=0, end=0):
    criterion = DiceLoss()
    liver_dataset = LiverDataset1(path,
                                  transform=x_transforms,
                                  target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=0)
    test_model(criterion, dataloaders, num_epochs)
Example #9
    def compute_loss(self, start_logits, end_logits, start_labels, end_labels, label_mask):
        """compute loss on squad task."""
        if len(start_labels.size()) > 1:
            start_labels = start_labels.squeeze(-1)
        if len(end_labels.size()) > 1:
            end_labels = end_labels.squeeze(-1)

        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        batch_size, ignored_index = start_logits.shape # ignored_index: seq_len
        start_labels.clamp_(0, ignored_index)
        end_labels.clamp_(0, ignored_index)

        if self.loss_type != "ce":
            # start_labels/end_labels: position index of answer starts/ends among the document.
            # F.one_hot will map the postion index to a sequence of 0, 1 labels.
            start_labels = F.one_hot(start_labels, num_classes=ignored_index)
            end_labels = F.one_hot(end_labels, num_classes=ignored_index)

        if self.loss_type == "ce":
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_labels)
            end_loss = loss_fct(end_logits, end_labels)
        elif self.loss_type == "bce":
            start_loss = F.binary_cross_entropy_with_logits(start_logits.view(-1), start_labels.view(-1).float(), reduction="none")
            end_loss = F.binary_cross_entropy_with_logits(end_logits.view(-1), end_labels.view(-1).float(), reduction="none")

            start_loss = (start_loss * label_mask.view(-1)).sum() / label_mask.sum()
            end_loss = (end_loss * label_mask.view(-1)).sum() / label_mask.sum()
        elif self.loss_type == "focal":
            loss_fct = FocalLoss(gamma=self.args.focal_gamma, reduction="none")
            start_loss = loss_fct(FocalLoss.convert_binary_pred_to_two_dimension(start_logits.view(-1)),
                                         start_labels.view(-1))
            end_loss = loss_fct(FocalLoss.convert_binary_pred_to_two_dimension(end_logits.view(-1)),
                                       end_labels.view(-1))
            start_loss = (start_loss * label_mask.view(-1)).sum() / label_mask.sum()
            end_loss = (end_loss * label_mask.view(-1)).sum() / label_mask.sum()

        elif self.loss_type in ["dice", "adaptive_dice"]:
            loss_fct = DiceLoss(with_logits=True, smooth=self.args.dice_smooth, ohem_ratio=self.args.dice_ohem,
                                      alpha=self.args.dice_alpha, square_denominator=self.args.dice_square)
            # add to test
            # start_logits, end_logits = start_logits.view(batch_size, -1), end_logits.view(batch_size, -1)
            # start_labels, end_labels = start_labels.view(batch_size, -1), end_labels.view(batch_size, -1)
            start_logits, end_logits = start_logits.view(-1, 1), end_logits.view(-1, 1)
            start_labels, end_labels = start_labels.view(-1, 1), end_labels.view(-1, 1)
            # label_mask = label_mask.view(batch_size, -1)
            label_mask = label_mask.view(-1, 1)
            start_loss = loss_fct(start_logits, start_labels, mask=label_mask)
            end_loss = loss_fct(end_logits, end_labels, mask=label_mask)
        else:
            raise ValueError("This type of loss func donot exists.")

        total_loss = (start_loss + end_loss) / 2

        return total_loss, start_loss, end_loss
Example #10
    def train(self):

        self.model.train()
        tbar = tqdm(self.train_queue)
        for step, (input, target) in enumerate(tbar):

            input = input.to(device=self.device, dtype=torch.float32)
            target = target.to(device=self.device, dtype=torch.float32)

            predicts = self.model(input)
            predicts_prob = torch.sigmoid(predicts)
            self.dice = DiceLoss()
            self.loss = (.75 * self.criterion(predicts_prob, target) +
                         .25 * self.dice(
                             (predicts_prob > 0.5).float(), target))

            self.train_loss_meter.update(self.loss.item(), input.size(0))

            self.model_optimizer.zero_grad()
            self.loss.backward()
            self.model_optimizer.step()

            ###########CAL METRIC############
            SE, SPE, ACC, DICE = metrics(predicts_prob, target)

            self.train_accuracy.update(ACC, input.size(0))
            self.train_sensitivity.update(SE, input.size(0))
            self.train_specificity.update(SPE, input.size(0))
            self.tr_dice.update(DICE, input.size(0))
            #################################

            tbar.set_description(
                'loss: %.4f; dice: %.4f' %
                (self.train_loss_meter.mloss, self.tr_dice.mloss))

            self.writer.add_images('Train/Images', input, self.epoch)
            self.writer.add_images('Train/Masks/True', target, self.epoch)
            self.writer.add_images('Train/Masks/pred',
                                   (predicts_prob > .5).float(), self.epoch)

        self.writer.add_scalar('Train/loss', self.train_loss_meter.mloss,
                               self.epoch)
        self.writer.add_scalar('Train/Acc', self.train_accuracy.mloss,
                               self.epoch)
        self.writer.add_scalar('Train/Sen', self.train_sensitivity.mloss,
                               self.epoch)
        self.writer.add_scalar('Train/Spe', self.train_specificity.mloss,
                               self.epoch)
        self.writer.add_scalar('Train/Dice', self.tr_dice.mloss, self.epoch)
Example #11
def train():
    print('start training ...........')
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    model = Model().to(device)
    batch_size = 2
    num_epochs = 100
    learning_rate = 0.1

    train_loader, val_loader = get_loader(batch_size=batch_size, shuffle=True)

    optimizer = optim.SGD(model.parameters(),
                          lr=learning_rate,
                          momentum=0.9,
                          nesterov=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    criterion = DiceLoss(smooth=1.)

    train_losses, val_losses = [], []
    for epoch in range(num_epochs):
        train_epoch_loss = fit(epoch,
                               model,
                               optimizer,
                               criterion,
                               device,
                               train_loader,
                               phase='training')
        val_epoch_loss = fit(epoch,
                             model,
                             optimizer,
                             criterion,
                             device,
                             val_loader,
                             phase='validation')
        print('-----------------------------------------')

        if epoch == 0 or val_epoch_loss <= np.min(val_losses):
            torch.save(model.state_dict(), 'output/weight.pth')

        train_losses.append(train_epoch_loss)
        val_losses.append(val_epoch_loss)

        write_figures('output', train_losses, val_losses)
        write_log('output', epoch, train_epoch_loss, val_epoch_loss)

        scheduler.step(val_epoch_loss)
Example #12
 def __init__(self, config, ntoken, ntag, vectors):
     super(BiLSTM_CRF_DAE, self).__init__()
     self.config = config
     self.vocab_size = ntoken
     self.batch_size = config.batch_size
     self.dropout = config.dropout
     self.drop = nn.Dropout(self.dropout)
     self.embedding = nn.Embedding(ntoken, config.embedding_size)
     if config.is_vector:
         self.embedding = nn.Embedding.from_pretrained(vectors,
                                                       freeze=False)
     self.lstm = nn.LSTM(input_size=config.embedding_size,
                         hidden_size=config.bi_lstm_hidden // 2,
                         num_layers=config.num_layers,
                         bidirectional=True)
     self.linner = nn.Linear(config.bi_lstm_hidden, ntag)
     self.lm_decoder = nn.Linear(config.bi_lstm_hidden, self.vocab_size)
     self.dice_loss = DiceLoss()
     self.criterion = nn.CrossEntropyLoss()
     self.crflayer = CRF(ntag)
Example #13
 def __init__(self, config, ntoken, ntag, vectors):
     super(TransformerEncoderModel, self).__init__()
     self.ntoken = ntoken
     self.config = config
     self.src_mask = None
     self.vectors = vectors
     self.embedding_size = config.embedding_size
     self.embedding = nn.Embedding(ntoken, config.embedding_size)
     self.pos_encoder = PositionalEncoding(config.embedding_size,
                                           config.dropout)
     encoder_layers = TransformerEncoderLayer(config.embedding_size,
                                              config.nhead, config.nhid,
                                              config.dropout)
     self.lstm = nn.LSTM(input_size=config.embedding_size,
                         hidden_size=config.bi_lstm_hidden // 2,
                         num_layers=1,
                         bidirectional=True)
     self.att_weight = nn.Parameter(
         torch.randn(config.bi_lstm_hidden, config.batch_size,
                     config.bi_lstm_hidden))
     self.transformer_encoder = TransformerEncoder(encoder_layers,
                                                   config.nlayers)
     if config.is_pretrained_model:
         # with torch.no_grad():
         config_bert = BertConfig.from_pretrained(config.pretrained_config)
         model = BertModel.from_pretrained(config.pretrained_model,
                                           config=config_bert)
         self.embedding = model
         for name, param in model.named_parameters():
             param.requires_grad = True
     elif config.is_vector:
         self.embedding = nn.Embedding.from_pretrained(vectors,
                                                       freeze=False)
     self.embedding.weight.requires_grad = True
     self.emsize = config.embedding_size
     self.linner = nn.Linear(config.bi_lstm_hidden, ntag)
     self.init_weights()
     self.crflayer = CRF(ntag)
     self.dice_loss = DiceLoss()
     self.criterion = nn.CrossEntropyLoss()
     self.lm_decoder = nn.Linear(self.config.bi_lstm_hidden, ntoken)
Example #14
def main():

    dsc_loss = DiceLoss()

    model = UNet(in_channels=21, out_channels=4)
    model.cuda()
    print(summary(model, input_size=(21, 256, 256)))
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = optim.lr_scheduler.CyclicLR(optimizer,
                                            base_lr=0.0001,
                                            max_lr=0.01)
    loader_train, loader_eval = datasets()

    save_model_path = '../checkpoints/baseline.pk'
    load_model_path = '../checkpoints/baseline-056.pk'
    model.load_state_dict(torch.load(load_model_path))

    for e in range(1000):
        model = train_epoch(model, loader_train, optimizer, dsc_loss)
        model = eval_epoch(model, loader_eval, dsc_loss)
        scheduler.step()
        torch.save(model.state_dict(), save_model_path)
        print('begin epoch', e, 'saving to', save_model_path)
Example #15
def train():

    model = UNet_2d(8, 1, 1).to(
        device)  # conv_channels=8, input_channels, classes, slices
    print(model)

    criterion = [DiceLoss(), torch.nn.BCELoss()]

    optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=1e-8)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               mode='min',
                                               factor=0.1,
                                               patience=5,
                                               cooldown=0,
                                               min_lr=1e-8)

    dataset = SpleenDataset(transform=x_transform,
                            target_transform=y_transform)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=4)
    train_model(model, criterion, optimizer, dataloader, args.num_epochs)
Example #16
    def __init__(self, filePathTrain):
        # Hyperparameters
        self.batchSize = 1
        self.numEpochs = 10
        self.learningRate = 0.001
        self.validPercent = 0.1
        self.trainShuffle = True
        self.testShuffle = False
        self.momentum = 0.99
        self.imageDim = 128

        # Variables
        self.imageDirectory = filePathTrain
        self.labelDirectory = filePathTrain
        self.numChannels = 3
        self.numClasses = 1

        # Device configuration
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Load dataset
        self.trainLoader = self.getTrainingLoader()
        #self.testLoader  = self.getTestLoader()

        # Setup model
        self.model = UNet(n_channels=self.numChannels,
                          n_classes=self.numClasses,
                          bilinear=True).to(self.device)
        #self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=self.learningRate, weight_decay=self.weightDecay, momentum=self.momentum)
        #self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learningRate)
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.learningRate,
                                         momentum=self.momentum)
        #self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min' if self.numClasses > 1 else 'max', patience=2)
        self.criterion = DiceLoss()
Example #17
def main(train_args, model):
    print(train_args)

    net = model.cuda()
    net.train()

    criterion = DiceLoss()

    optimizer_adam = optim.Adam(model.parameters(),
                                lr=train_args['lr'],
                                weight_decay=train_args['weight_decay'],
                                )

    if len(train_args['snapshot']) == 0:
        curr_epoch = 0
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(opj(savedir_nets2, train_args['snapshot'] + '.pth')))
        optimizer_adam.load_state_dict(torch.load(opj(savedir_nets2, train_args['snapshot'] + '_opt.pth')))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    scheduler_adam = StepLR(optimizer=optimizer_adam, step_size=1, gamma=train_args['lr_adapt'])

    inputs, labels = preproc_train2.get_trainset()

    for epoch in range(curr_epoch, train_args['epochs']):
        train(inputs[:], labels[:], net, criterion=criterion, optimizer=optimizer_adam, epoch=epoch, train_args=train_args)
        validate(inputs[:], labels[:], net, criterion, optimizer_adam, epoch, train_args)
        scheduler_adam.step()

    return 0
Example #18
    def __init__(self, args, model: torch.nn.Module, train_dataset: Dataset, test_dataset: Dataset, utils):
        self.utils = utils
        self.args = args
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        self.batch_size = self.args.batch_size
        self.img_size = self.args.img_size

        self.model = model.to(self.device)

        os.makedirs(os.path.join(self.args.ckpt_dir, self.model.name), exist_ok=True)
        os.makedirs(self.args.save_gen_images_dir, exist_ok=True)

        ''' optimizer '''

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)

        '''dataset and dataloader'''
        self.train_dataset = train_dataset
        weights = self.utils.make_weights_for_balanced_classes(self.train_dataset.imgs, len(self.train_dataset.classes))
        weights = torch.DoubleTensor(weights)
        sampler = WeightedRandomSampler(weights, len(weights))

        self.train_dataloader = DataLoader(self.train_dataset, self.batch_size,
                                           num_workers=args.num_worker, sampler=sampler,
                                           pin_memory=True)

        self.test_dataset = test_dataset
        self.test_dataloader = DataLoader(self.test_dataset, self.batch_size, num_workers=args.num_worker,
                                          pin_memory=True)

        '''loss function'''
        self.criterion = DiceLoss().to(self.device)

        '''scheduler'''
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=3)
Example #19
def train_main(data_folder, in_channels, out_channels, learning_rate, no_epochs):
    """
    Train module
    :param data_folder: data folder
    :param in_channels: the input channel of input images
    :param out_channels: the final output channel
    :param learning_rate: set learning rate for training
    :param no_epochs: number of epochs to train model
    :return: None
    """
    #print("Entro a train_main")
    model = UnetModel(in_channels=in_channels, out_channels=out_channels)
    #print("Acabo el modelo")
    optim = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    criterion = DiceLoss()
    #print("Entrando a trainer")
    trainer = Trainer(data_dir=data_folder, net=model, optimizer=optim, criterion=criterion, no_epochs=no_epochs)
    
    trainer.train(data_paths_loader=get_data_paths, dataset_loader=data_gen, batch_data_loader=batch_data_gen)
    #model_json = model.to_json()
    #with open("modelu.json", "w") as json_file:
    #     json_file.write(model_json)
    #model.save_weights("model_inicial")
    print("NOT Saved model to disk")
Example #20
def train_val(config):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers)
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    if config.model_type not in [
            'UNet', 'R2UNet', 'AUNet', 'R2AUNet', 'SEUNet', 'SEUNet++',
            'UNet++', 'DAUNet', 'DANet', 'AUNetR', 'RendDANet', "BASNet"
    ]:
        print('ERROR!! model_type should be one of the supported models')
        print('Got model_type %s' % config.model_type)
        return
    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "AUNet":
        model = AUNet()
    elif config.model_type == "R2UNet":
        model = R2UNet()
    elif config.model_type == "SEUNet":
        model = SEUNet(useCSE=False, useSSE=False, useCSSE=True)
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101', nclass=1)
    elif config.model_type == "AUNetR":
        model = AUNet_R16(n_classes=1, learned_bilinear=True)
    elif config.model_type == "RendDANet":
        model = RendDANet(backbone='resnet101', nclass=1)
    elif config.model_type == "BASNet":
        model = BASNet(n_channels=3, n_classes=1)
    else:
        model = UNet()

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device, dtype=torch.float)

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-6,
                        momentum=0.9)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    if config.loss == "dice":
        criterion = DiceLoss()
    elif config.loss == "bce":
        criterion = nn.BCELoss()
    elif config.loss == "bas":
        criterion = BasLoss()
    else:
        criterion = MixLoss()

    scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
    global_step = 0
    best_dice = 0.0
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img') as train_pbar:
            model.train()
            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float)
                mask = mask.to(device, dtype=torch.float)
                d0, d1, d2, d3, d4, d5, d6, d7 = model(image)
                loss = criterion(d0, d1, d2, d3, d4, d5, d6, d7, mask)
                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1

                # if global_step % 100 == 0:
                #     writer.add_images('masks/true', mask, global_step)
                #     writer.add_images('masks/pred', d0 > 0.5, global_step)
            scheduler.step()
        epoch_dice = 0.0
        epoch_acc = 0.0
        epoch_sen = 0.0
        epoch_spe = 0.0
        epoch_pre = 0.0
        current_num = 0
        with tqdm(total=config.num_val,
                  desc="Epoch %d / %d validation round" %
                  (epoch + 1, config.num_epochs),
                  unit='img') as val_pbar:
            model.eval()
            locker = 0
            for image, mask in val_loader:
                current_num += image.shape[0]
                image = image.to(device, dtype=torch.float)
                mask = mask.to(device, dtype=torch.float)
                d0, d1, d2, d3, d4, d5, d6, d7 = model(image)
                batch_dice = dice_coeff(mask, d0).item()
                epoch_dice += batch_dice * image.shape[0]
                epoch_acc += get_accuracy(pred=d0, true=mask) * image.shape[0]
                epoch_sen += get_sensitivity(pred=d0,
                                             true=mask) * image.shape[0]
                epoch_spe += get_specificity(pred=d0,
                                             true=mask) * image.shape[0]
                epoch_pre += get_precision(pred=d0, true=mask) * image.shape[0]
                if locker == 200:
                    writer.add_images('masks/true', mask, epoch + 1)
                    writer.add_images('masks/pred', d0 > 0.5, epoch + 1)
                val_pbar.set_postfix(**{'dice (batch)': batch_dice})
                val_pbar.update(image.shape[0])
                locker += 1
            epoch_dice /= float(current_num)
            epoch_acc /= float(current_num)
            epoch_sen /= float(current_num)
            epoch_spe /= float(current_num)
            epoch_pre /= float(current_num)
            epoch_f1 = get_F1(SE=epoch_sen, PR=epoch_pre)
            if epoch_dice > best_dice:
                best_dice = epoch_dice
                writer.add_scalar('Best Dice/test', best_dice, epoch + 1)
                torch.save(
                    model, config.result_path + "/%s_%s_%d.pth" %
                    (config.model_type, str(epoch_dice), epoch + 1))
            logging.info('Validation Dice Coeff: {}'.format(epoch_dice))
            print("epoch dice: " + str(epoch_dice))
            writer.add_scalar('Dice/test', epoch_dice, epoch + 1)
            writer.add_scalar('Acc/test', epoch_acc, epoch + 1)
            writer.add_scalar('Sen/test', epoch_sen, epoch + 1)
            writer.add_scalar('Spe/test', epoch_spe, epoch + 1)
            writer.add_scalar('Pre/test', epoch_pre, epoch + 1)
            writer.add_scalar('F1/test', epoch_f1, epoch + 1)

    writer.close()
    print("Training finished")
Example #21
def main(args):
    makedirs(args)
    snapshotargs(args)
    device = torch.device("cpu" if not torch.cuda.is_available() else args.device)

    loader_train, loader_valid = data_loaders(args)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
    # unet.apply(weights_init)
    unet.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=args.lr, weight_decay=1e-3)
    # optimizer = optim.Adam(unet.parameters(), lr=args.lr)

    logger = Logger(args.logs)
    loss_train = []
    loss_valid = []

    log_train = []
    log_valid = []

    validation_pred = []
    validation_true = []
    step = 0

    for epoch in tqdm(range(args.epochs), total=args.epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                unet.train()
            else:
                unet.eval()


            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1
                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)
                    loss = dsc_loss(y_pred, y_true)
                    print(loss)

                    # if phase == "valid":
                    if phase == "train":
                        loss_valid.append(loss.item())
                        y_pred_np = y_pred.detach().cpu().numpy()
                        validation_pred.extend(
                            [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                        )
                        y_true_np = y_true.detach().cpu().numpy()
                        validation_true.extend(
                            [y_true_np[s] for s in range(y_true_np.shape[0])]
                        )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()


            if phase == "valid":
                dsc, label_dsc = dsc_per_volume(
                          validation_pred,
                          validation_true,
                          # loader_valid.dataset.patient_slice_index,
                          loader_train.dataset.patient_slice_index,
                          )
                mean_dsc = np.mean(dsc)
                print(mean_dsc)
                print(np.array(label_dsc).mean(axis=0))

                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    best_label_dsc = label_dsc
                    torch.save(unet.state_dict(), os.path.join(args.weights, "unet.pt"))
                    opt = epoch
                log_valid.append(np.mean(loss_valid))
                loss_valid = []
                validation_pred = []
                validation_true = []
        log_train.append(np.mean(loss_train))
        loss_train=[]

    plt.plot(log_valid)
    plt.plot(log_train)
    plt.savefig("Test")
    print("Best validation mean DSC: {:4f}".format(best_validation_dsc))
    print(opt)
Example #22
def main():
    # load data
    print('\nloading the dataset ...')
    assert opt.dataset == "ISIC2016" or opt.dataset == "ISIC2017"
    if opt.dataset == "ISIC2016":
        num_aug = 5
        normalize = Normalize((0.7012, 0.5517, 0.4875),
                              (0.0942, 0.1331, 0.1521))
    elif opt.dataset == "ISIC2017":
        num_aug = 2
        normalize = Normalize((0.6820, 0.5312, 0.4736),
                              (0.0840, 0.1140, 0.1282))
    if opt.over_sample:
        print('data is offline oversampled ...')
        train_file = 'train_oversample.csv'
    else:
        print('no offline oversampling ...')
        train_file = 'train.csv'
    im_size = 224
    transform_train = torch_transforms.Compose([
        RatioCenterCrop(0.8),
        Resize((256, 256)),
        RandomCrop((224, 224)),
        RandomRotate(),
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        ToTensor(), normalize
    ])
    transform_val = torch_transforms.Compose([
        RatioCenterCrop(0.8),
        Resize((256, 256)),
        CenterCrop((224, 224)),
        ToTensor(), normalize
    ])
    trainset = ISIC(csv_file=train_file, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=8,
        worker_init_fn=_worker_init_fn_(),
        drop_last=True)
    valset = ISIC(csv_file='val.csv', transform=transform_val)
    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=64,
                                            shuffle=False,
                                            num_workers=8)
    print('done\n')

    # load models
    print('\nloading the model ...')

    if not opt.no_attention:
        print('turn on attention ...')
        if opt.normalize_attn:
            print('use softmax for attention map ...')
        else:
            print('use sigmoid for attention map ...')
    else:
        print('turn off attention ...')

    net = AttnVGG(num_classes=2,
                  attention=not opt.no_attention,
                  normalize_attn=opt.normalize_attn)
    dice = DiceLoss()
    if opt.focal_loss:
        print('use focal loss ...')
        criterion = FocalLoss(gama=2., size_average=True, weight=None)
    else:
        print('use cross entropy loss ...')
        criterion = nn.CrossEntropyLoss()
    print('done\n')

    # move to GPU
    print('\nmoving models to GPU ...')
    model = nn.DataParallel(net, device_ids=device_ids).to(device)
    criterion.to(device)
    dice.to(device)
    print('done\n')

    # optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.lr,
                          momentum=0.9,
                          weight_decay=1e-4,
                          nesterov=True)
    lr_lambda = lambda epoch: np.power(0.1, epoch // 10)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    # training
    print('\nstart training ...\n')
    step = 0
    EMA_accuracy = 0
    AUC_val = 0
    writer = SummaryWriter(opt.outf)
    if opt.log_images:
        data_iter = iter(valloader)
        fixed_batch = next(data_iter)
        fixed_batch = fixed_batch['image'][0:16, :, :, :].to(device)
    for epoch in range(opt.epochs):
        torch.cuda.empty_cache()
        # adjust learning rate
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('train/learning_rate', current_lr, epoch)
        print("\nepoch %d learning rate %f\n" % (epoch + 1, current_lr))
        # run for one epoch
        for aug in range(num_aug):
            for i, data in enumerate(trainloader, 0):
                # warm up
                model.train()
                model.zero_grad()
                optimizer.zero_grad()
                inputs, seg, labels = data['image'], data['image_seg'], data[
                    'label']
                seg = seg[:, -1:, :, :]
                seg_1 = F.adaptive_avg_pool2d(seg,
                                              im_size // opt.base_up_factor)
                seg_2 = F.adaptive_avg_pool2d(
                    seg, im_size // opt.base_up_factor // 2)
                inputs, seg_1, seg_2, labels = inputs.to(device), seg_1.to(
                    device), seg_2.to(device), labels.to(device)
                # forward
                pred, a1, a2 = model(inputs)
                # backward
                loss_c = criterion(pred, labels)
                loss_seg1 = dice(a1, seg_1)
                loss_seg2 = dice(a2, seg_2)
                loss = loss_c + 0.001 * loss_seg1 + 0.01 * loss_seg2
                loss.backward()
                optimizer.step()
                # display results
                if i % 10 == 0:
                    model.eval()
                    pred, __, __ = model(inputs)
                    predict = torch.argmax(pred, 1)
                    total = labels.size(0)
                    correct = torch.eq(predict, labels).sum().double().item()
                    accuracy = correct / total
                    EMA_accuracy = 0.9 * EMA_accuracy + 0.1 * accuracy
                    writer.add_scalar('train/loss_c', loss_c.item(), step)
                    writer.add_scalar('train/loss_seg1', loss_seg1.item(),
                                      step)
                    writer.add_scalar('train/loss_seg2', loss_seg2.item(),
                                      step)
                    writer.add_scalar('train/accuracy', accuracy, step)
                    writer.add_scalar('train/EMA_accuracy', EMA_accuracy, step)
                    print(
                        "[epoch %d][aug %d/%d][iter %d/%d] loss_c %.4f loss_seg1 %.4f loss_seg2 %.4f accuracy %.2f%% EMA %.2f%%"
                        %
                        (epoch + 1, aug + 1, num_aug, i + 1, len(trainloader),
                         loss.item(), loss_seg1.item(), loss_seg2.item(),
                         (100 * accuracy), (100 * EMA_accuracy)))
                step += 1
        # the end of each epoch - validation results
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            with open('val_results.csv', 'wt', newline='') as csv_file:
                csv_writer = csv.writer(csv_file, delimiter=',')
                for i, data in enumerate(valloader, 0):
                    images_val, labels_val = data['image'], data['label']
                    images_val, labels_val = images_val.to(
                        device), labels_val.to(device)
                    pred_val, __, __ = model(images_val)
                    predict = torch.argmax(pred_val, 1)
                    total += labels_val.size(0)
                    correct += torch.eq(predict,
                                        labels_val).sum().double().item()
                    # record predictions
                    responses = F.softmax(pred_val,
                                          dim=1).squeeze().cpu().numpy()
                    responses = [
                        responses[i] for i in range(responses.shape[0])
                    ]
                    csv_writer.writerows(responses)
            AP, AUC, precision_mean, precision_mel, recall_mean, recall_mel = compute_metrics(
                'val_results.csv', 'val.csv')
            # save checkpoints
            print('\nsaving checkpoints ...\n')
            checkpoint = {
                'state_dict': model.module.state_dict(),
                'opt_state_dict': optimizer.state_dict(),
            }
            torch.save(checkpoint,
                       os.path.join(opt.outf, 'checkpoint_latest.pth'))
            if AUC > AUC_val:  # save optimal validation model
                torch.save(checkpoint, os.path.join(opt.outf,
                                                    'checkpoint.pth'))
                AUC_val = AUC
            # log scalars
            writer.add_scalar('val/accuracy', correct / total, epoch)
            writer.add_scalar('val/mean_precision', precision_mean, epoch)
            writer.add_scalar('val/mean_recall', recall_mean, epoch)
            writer.add_scalar('val/precision_mel', precision_mel, epoch)
            writer.add_scalar('val/recall_mel', recall_mel, epoch)
            writer.add_scalar('val/AP', AP, epoch)
            writer.add_scalar('val/AUC', AUC, epoch)
            print("\n[epoch %d] val result: accuracy %.2f%%" %
                  (epoch + 1, 100 * correct / total))
            print(
                "\nmean precision %.2f%% mean recall %.2f%% \nprecision for mel %.2f%% recall for mel %.2f%%"
                % (100 * precision_mean, 100 * recall_mean,
                   100 * precision_mel, 100 * recall_mel))
            print("\nAP %.4f AUC %.4f\n optimal AUC: %.4f" %
                  (AP, AUC, AUC_val))
            # log images
            if opt.log_images:
                print('\nlog images ...\n')
                I_train = utils.make_grid(inputs[0:16, :, :, :],
                                          nrow=4,
                                          normalize=True,
                                          scale_each=True)
                I_seg_1 = utils.make_grid(seg_1[0:16, :, :, :],
                                          nrow=4,
                                          normalize=True,
                                          scale_each=True)
                I_seg_2 = utils.make_grid(seg_2[0:16, :, :, :],
                                          nrow=4,
                                          normalize=True,
                                          scale_each=True)
                writer.add_image('train/image', I_train, epoch)
                writer.add_image('train/seg1', I_seg_1, epoch)
                writer.add_image('train/seg2', I_seg_2, epoch)
                if epoch == 0:
                    I_val = utils.make_grid(fixed_batch,
                                            nrow=4,
                                            normalize=True,
                                            scale_each=True)
                    writer.add_image('val/image', I_val, epoch)
            if opt.log_images and (not opt.no_attention):
                print('\nlog attention maps ...\n')
                # training data
                __, a1, a2 = model(inputs[0:16, :, :, :])
                if a1 is not None:
                    attn1 = visualize_attn(I_train,
                                           a1,
                                           up_factor=opt.base_up_factor,
                                           nrow=4)
                    writer.add_image('train/attention_map_1', attn1, epoch)
                if a2 is not None:
                    attn2 = visualize_attn(I_train,
                                           a2,
                                           up_factor=2 * opt.base_up_factor,
                                           nrow=4)
                    writer.add_image('train/attention_map_2', attn2, epoch)
                # val data
                __, a1, a2 = model(fixed_batch)
                if a1 is not None:
                    attn1 = visualize_attn(I_val,
                                           a1,
                                           up_factor=opt.base_up_factor,
                                           nrow=4)
                    writer.add_image('val/attention_map_1', attn1, epoch)
                if a2 is not None:
                    attn2 = visualize_attn(I_val,
                                           a2,
                                           up_factor=2 * opt.base_up_factor,
                                           nrow=4)
                    writer.add_image('val/attention_map_2', attn2, epoch)
Example #23
# optimizer = torch.optim.RMSprop(params, lr=config.lr, alpha = 0.95)
# optimizer = RAdam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)
# optimizer = PlainRAdam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)
if os.path.exists(config.init_optimizer):
    ckpt = torch.load(config.init_optimizer)
    optimizer.load_state_dict(ckpt['optimizer'])

# lr_scheduler
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.3)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.num_epochs*len(data_loader))
lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=100, 
                                      total_epoch=min(1000, len(data_loader)-1), 
                                      after_scheduler=scheduler_cosine)

# loss function
criterion = DiceLoss()
# criterion = Weight_Soft_Dice_Loss(weight=[0.1, 0.9])
# criterion = BCELoss()
# criterion = MixedLoss(10.0, 2.0)
# criterion = Weight_BCELoss(weight_pos=0.25, weight_neg=0.75)
# criterion = Lovasz_Loss(margin=[1, 5])

print('start training...')
train_start = time.time()
for epoch in range(config.num_epochs):
    epoch_start = time.time()
    model_ft, optimizer = train_one_epoch(model_ft, data_loader, criterion, 
                                          optimizer, lr_scheduler=lr_scheduler, device=device, 
                                          epoch=epoch, vis=vis)
    do_valid(model_ft, dataloader_val, criterion, epoch, device, vis=vis)
    print('Epoch time: {:.3f}min\n'.format((time.time()-epoch_start)/60))
Example #24
def main():
    args = parser.parse_args()
    save_path = 'Trainid_' + args.id
    writer = SummaryWriter(log_dir='runs/' + args.tag + str(time.time()))
    if not os.path.isdir(save_path):
        os.mkdir(save_path)
        os.mkdir(save_path + '/Checkpoint')

    train_dataset_path = 'data/train'
    val_dataset_path = 'data/valid'
    train_transform = transforms.Compose([ToTensor()])
    val_transform = transforms.Compose([ToTensor()])
    train_dataset = TrainDataset(path=train_dataset_path,
                                 transform=train_transform)
    val_dataset = TrainDataset(path=val_dataset_path, transform=val_transform)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=4)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                pin_memory=True,
                                num_workers=4)

    size_train = len(train_dataloader)
    size_val = len(val_dataloader)
    print('Number of Training Images: {}'.format(size_train))
    print('Number of Validation Images: {}'.format(size_val))
    start_epoch = 0
    model = Res(n_ch=4, n_classes=9)
    class_weights = torch.Tensor([1, 1, 1, 1, 1, 1, 1, 1, 0]).cuda()
    criterion = DiceLoss()
    criterion1 = torch.nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    if args.gpu:
        model = model.cuda()
        criterion = criterion.cuda()
        criterion1 = criterion1.cuda()

    if args.resume is not None:
        weight_path = sorted(os.listdir(save_path + '/Checkpoint/'),
                             key=lambda x: float(x[:-8]))[0]
        checkpoint = torch.load(save_path + '/Checkpoint/' + weight_path)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('Loaded Checkpoint of Epoch: {}'.format(args.resume))

    for epoch in range(start_epoch, int(args.epoch) + start_epoch):
        adjust_learning_rate(optimizer, epoch)
        train(model, train_dataloader, criterion, criterion1, optimizer, epoch,
              writer, size_train)
        print('')
        val_loss = val(model, val_dataloader, criterion, criterion1, epoch,
                       writer, size_val)
        print('')
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            filename=save_path + '/Checkpoint/' + str(val_loss) + '.pth.tar')
    writer.export_scalars_to_json(save_path + '/log.json')
Example #25
    def __init__(self, configs):
        self.batch_size = configs.get("batch_size", 16)
        self.epochs = configs.get("epochs", 100)
        self.lr = configs.get("lr", 0.0001)

        device_args = configs.get("device", "cuda")
        self.device = torch.device(
            "cpu" if not torch.cuda.is_available() else device_args)

        self.workers = configs.get("workers", 4)

        self.vis_images = configs.get("vis_images", 200)
        self.vis_freq = configs.get("vis_freq", 10)

        self.weights = configs.get("weights", "./weights")
        if not os.path.exists(self.weights):
            os.mkdir(self.weights)

        self.logs = configs.get("logs", "./logs")
        if not os.path.exists(self.logs):
            os.mkdir(self.logs)

        self.images_path = configs.get("images_path", "./data")

        self.is_resize = configs.get("is_resize", False)
        self.image_short_side = configs.get("image_short_side", 256)

        self.is_padding = configs.get("is_padding", False)

        is_multi_gpu = configs.get("DateParallel", False)

        pre_train = configs.get("pre_train", False)
        model_path = configs.get("model_path", './weights/unet_idcard_adam.pth')

        # self.image_size = configs.get("image_size", "256")
        # self.aug_scale = configs.get("aug_scale", "0.05")
        # self.aug_angle = configs.get("aug_angle", "15")

        self.step = 0

        self.dsc_loss = DiceLoss()
        self.model = UNet(in_channels=Dataset.in_channels,
                          out_channels=Dataset.out_channels)
        if pre_train:
            self.model.load_state_dict(torch.load(model_path,
                                                  map_location=self.device),
                                       strict=False)

        if is_multi_gpu:
            self.model = nn.DataParallel(self.model)

        self.model.to(self.device)

        self.best_validation_dsc = 0.0

        self.loader_train, self.loader_valid = self.data_loaders()

        self.params = [p for p in self.model.parameters() if p.requires_grad]

        self.optimizer = optim.Adam(self.params,
                                    lr=self.lr,
                                    weight_decay=0.0005)
        # self.optimizer = torch.optim.SGD(self.params, lr=self.lr, momentum=0.9, weight_decay=0.0005)
        self.scheduler = lr_scheduler.LR_Scheduler_Head(
            'poly', self.lr, self.epochs, len(self.loader_train))
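The constructor above pulls every setting out of a plain configs dict. A hypothetical invocation (key names taken from the code above, values chosen only for illustration, and Trainer used as a stand-in name for the surrounding class) might look like:

configs = {
    "batch_size": 16,
    "epochs": 100,
    "lr": 1e-4,
    "device": "cuda",
    "workers": 4,
    "weights": "./weights",
    "logs": "./logs",
    "images_path": "./data",
    "is_resize": True,
    "image_short_side": 256,
    "is_padding": False,
    "DataParallel": False,  # set True to wrap the model in nn.DataParallel
    "pre_train": False,
}
trainer = Trainer(configs)  # hypothetical wrapper class exposing the __init__ above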
Beispiel #26
0
        weight = Variable(weight.cuda())

    else:
        weight = args.weight  # weight is None

    print("weight: {}".format(weight))

    # criterion
    if args.criterion == 'nll':
        criterion = nn.NLLLoss(weight=weight)
    elif args.criterion == 'ce':
        criterion = nn.CrossEntropyLoss(weight=weight)
    elif args.criterion == 'dice':
        criterion = DiceLoss(weight=weight,
                             ignore_index=None,
                             weight_type=args.weight_type,
                             cal_zerogt=args.cal_zerogt)

    elif args.criterion == 'gdl_inv_square':
        criterion = GeneralizedDiceLoss(weight=weight,
                                        ignore_index=None,
                                        weight_type='inv_square',
                                        alpha=args.alpha)
    elif args.criterion == 'gdl_others_one_gt':
        criterion = GeneralizedDiceLoss(weight=weight,
                                        ignore_index=None,
                                        weight_type='others_one_gt',
                                        alpha=args.alpha)
    elif args.criterion == 'gdl_others_one_pred':
        criterion = GeneralizedDiceLoss(weight=weight,
                                        ignore_index=None,
                                        weight_type='others_one_pred',
                                        alpha=args.alpha)
Beispiel #27
0
def main(args):
    makedirs(args)
    snapshotargs(args)
    device = torch.device("cpu" if not torch.cuda.is_available() else args.device)

    loader_train, loader_valid = data_loaders(args)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
    unet.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=args.lr)
    print("Learning rate = ", args.lr)                     #AP knowing lr
    print("Batch-size = ", args.batch_size)  # AP knowing batch-size
    print("Number of visualization images to save in log file = ", args.vis_images)  # AP knowing batch-size

    logger = Logger(args.logs)
    loss_train = []
    loss_valid = []

    step = 0

    for epoch in tqdm(range(args.epochs), total=args.epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                unet.train()
            else:
                unet.eval()

            validation_pred = []
            validation_true = []

            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1

                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)

                    loss = dsc_loss(y_pred, y_true)

                    if phase == "valid":
                        loss_valid.append(loss.item())
                        y_pred_np = y_pred.detach().cpu().numpy()
                        validation_pred.extend(
                            [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                        )
                        y_true_np = y_true.detach().cpu().numpy()
                        validation_true.extend(
                            [y_true_np[s] for s in range(y_true_np.shape[0])]
                        )
                        if (epoch % args.vis_freq == 0) or (epoch == args.epochs - 1):
                            if i * args.batch_size < args.vis_images:
                                tag = "image/{}".format(i)
                                num_images = args.vis_images - i * args.batch_size
                                logger.image_list_summary(
                                    tag,
                                    log_images(x, y_true, y_pred)[:num_images],
                                    step,
                                )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()

                if phase == "train" and (step + 1) % 10 == 0:
                    log_loss_summary(logger, loss_train, step)
                    loss_train = []

            if phase == "valid":
                log_loss_summary(logger, loss_valid, step, prefix="val_")
                mean_dsc = np.mean(
                    dsc_per_volume(
                        validation_pred,
                        validation_true,
                        loader_valid.dataset.patient_slice_index,
                    )
                )
                logger.scalar_summary("val_dsc", mean_dsc, step)
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    torch.save(unet.state_dict(), os.path.join(args.weights, "unet.pt"))
                loss_valid = []

    print("Best validation mean DSC: {:4f}".format(best_validation_dsc))
Beispiel #28
0
data_dir = './data'
train_csv_path = os.path.join(data_dir, 'train.csv')
test_csv_path = os.path.join(data_dir, 'test.csv')

train_images_dir = os.path.join(data_dir, 'stage_1_train_images/')
test_images_dir = os.path.join(data_dir, 'stage_1_test_images/')

train_df, train_loader, dev_pids, dev_loader, dev_dataset_for_predict, dev_loader_for_predict, test_loader, test_df, test_pids, boxes_by_pid_dict, min_box_area = load_data(
    train_csv_path, test_csv_path, train_images_dir, test_images_dir,
    batch_size, validation_prop, rescale_factor)
min_box_area = int(round(min_box_area / float(rescale_factor**2)))

# model = torch.nn.DataParallel(LeakyUNET().cuda(), device_ids=[0, 1, 2, 3, 4, 5, 6, 7])
model = torch.nn.DataParallel(LeakyUNET().cuda(), device_ids=[0, 1, 2, 3])

loss_fn = DiceLoss().cuda()

init_learning_rate = 0.5

num_epochs = 1 if debug else 5
num_train_steps = 5 if debug else len(train_loader)
num_dev_steps = 5 if debug else len(dev_loader)

img_dim = int(round(original_dim / rescale_factor))

print("Training for {} epochs".format(num_epochs))
histories, best_models = train_and_evaluate(model,
                                            train_loader,
                                            dev_loader,
                                            init_learning_rate,
                                            loss_fn,
Beispiel #29
0
def train(args, model, optimizer, dataloader_train, dataloader_val):
    # TensorBoard writer that logs training progress
    writer = SummaryWriter(
        comment='{}_{}'.format(args.optimizer, args.context_path))
    # set up the loss
    if args.loss == 'dice':
        # DiceLoss is the class defined in the repository's loss.py
        loss_func = DiceLoss()
    elif args.loss == 'crossentropy':
        loss_func = torch.nn.CrossEntropyLoss(ignore_index=255)
    # initialize the counters
    max_miou = 0
    step = 0
    # start training
    for epoch in range(args.num_epochs):
        # update the learning rate with polynomial decay
        lr = poly_lr_scheduler(optimizer,
                               args.learning_rate,
                               iter=epoch,
                               max_iter=args.num_epochs)
        # switch the model to training mode
        model.train()
        # progress bar over the training set
        tq = tqdm(total=len(dataloader_train) * args.batch_size)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        # list of per-batch losses for this epoch:
        loss_record = []

        # iterate over the training mini-batches
        for i, (data, label) in enumerate(dataloader_train):

            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda().long()

            # The model returns:
            # - the final output after the Feature Fusion Module (FFM)
            # - the 16x-downsampled context-path output, refined by its ARM
            # - the 32x-downsampled context-path output, refined by its ARM
            output, output_sup1, output_sup2 = model(data)

            # Compute the losses
            # Principal loss function (l_p in the paper):
            loss1 = loss_func(output, label)
            # Auxiliary loss functions (l_i, for i = 2, 3 in the paper):
            loss2 = loss_func(output_sup1, label)
            loss3 = loss_func(output_sup2, label)

            # alpha = 1, so Equation 2 reduces to a plain sum of the three losses:
            loss = loss1 + loss2 + loss3

            # update the progress bar
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            '''
            zero_grad clears old gradients from the last step (otherwise you’d just accumulate the gradients from all loss.backward() calls).
            loss.backward() computes the derivative of the loss w.r.t. the parameters (or anything requiring gradients) using backpropagation.
            opt.step() causes the optimizer to take a step based on the gradients of the parameters.
            '''
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # increment the global step counter
            step += 1
            # log the per-step loss for the TensorBoard plot
            writer.add_scalar('loss_step', loss, step)
            loss_record.append(loss.item())
        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('epoch/loss_epoch_train', float(loss_train_mean),
                          epoch)
        print('loss for train : %f' % (loss_train_mean))

        # save the model trained so far
        if epoch % args.checkpoint_step == 0 and epoch != 0:
            import os
            if not os.path.isdir(args.save_model_path):
                os.mkdir(args.save_model_path)
            torch.save(model.state_dict(),
                       os.path.join(args.save_model_path, 'model.pth'))

        # run validation every args.validation_step epochs
        if epoch % args.validation_step == 0 and epoch != 0:

            # call val(), which returns the evaluation metrics
            precision, miou = val(args, model, dataloader_val)

            # track the best mIoU and save the corresponding best model
            if miou > max_miou:
                max_miou = miou
                import os
                os.makedirs(args.save_model_path, exist_ok=True)
                torch.save(
                    model.state_dict(),
                    os.path.join(args.save_model_path, 'best_dice_loss.pth'))

            writer.add_scalar('epoch/precision_val', precision, epoch)
            writer.add_scalar('epoch/miou_val', miou, epoch)
    # close the writer so all pending events are flushed to disk
    writer.close()
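poly_lr_scheduler is defined elsewhere in that repository. A minimal sketch consistent with the call above (polynomial decay of the learning rate over epochs; the 0.9 power is an assumption) could be:

def poly_lr_scheduler(optimizer, init_lr, iter, max_iter, power=0.9):
    # lr = init_lr * (1 - iter / max_iter) ** power, applied to every parameter group.
    lr = init_lr * (1 - iter / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr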
Beispiel #30
0
def display_dataset_details(dataset, train_set, val_set):
    x = 'Total Images: ' + str(len(dataset))
    y = 'Total Training Images: ' + str(len(train_set))
    z = 'Total Validation Images: ' + str(len(val_set))

    return x, y, z

if add_selectbox == 'Training':
    st.header(add_selectbox)
    st.markdown('Device Detected : '+str(device))
    st.write('Select Training Parameters')
    epochs = st.number_input('Epochs', min_value=1, value=2)
    lr = st.number_input('Learning Rate', min_value=0.0001, max_value=None, value=0.0010, step=0.001, format='%f')

    if st.button('Load Data'):
        dsc_loss = DiceLoss()
        dataset = SyntheticCellDataset('dataset')
        indices = torch.randperm(len(dataset)).tolist()
        sr = int(0.2 * len(dataset))
        train_set = torch.utils.data.Subset(dataset, indices[:-sr])
        val_set = torch.utils.data.Subset(dataset, indices[-sr:])
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=2, shuffle=True, pin_memory=True)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=2, shuffle=False, pin_memory=True)
        st.write('Data Loaded')
        x,y,z = display_dataset_details(dataset, train_set, val_set)
        st.write(x)
        st.write(y)
        st.write(z)
        #if st.button('Start Training'):
        model = UNet()
        model.to(device)