def run_check_net(train_dl, val_dl, multi_gpu=[0, 1]):
    set_logger(LOG_PATH)
    logging.info('\n\n')
    #---
    if MODEL == 'UNetResNet34':
        net = UNetResNet34(debug=False).cuda(device=device)
    #elif MODEL == 'RESNET18':
    #    net = AtlasResNet18(debug=False).cuda(device=device)
    else:
        raise ValueError('Unsupported MODEL: %s' % MODEL)

    #     for param in net.named_parameters():
    #         if param[0][:8] in ['decoder5']:#'decoder5', 'decoder4', 'decoder3', 'decoder2'
    #             param[1].requires_grad = False

    # dummy SGD to see if it can converge ...
    #optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
    #                  lr=LearningRate, momentum=0.9, weight_decay=0.0001)
    #optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.045)#LearningRate
    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
    #                                                       factor=0.5, patience=4,#4 resnet34
    #                                                       verbose=False, threshold=0.0001,
    #                                                       threshold_mode='rel', cooldown=0,
    #                                                       min_lr=0, eps=1e-08)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.9, last_epoch=-1)
    train_params = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = torch.optim.SGD(train_params,
                                momentum=0.9,
                                weight_decay=0.0001,
                                lr=LearningRate)
    scheduler = LR_Scheduler(
        'poly', LearningRate, NUM_EPOCHS,
        len(train_dl))  #lr_scheduler=['poly', 'step', 'cos']
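    # 'poly' decays the LR every batch; a common form is
    # base_lr * (1 - iter / max_iter) ** 0.9, though the exact schedule depends
    # on this LR_Scheduler implementation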

    if warm_start:
        logging.info('warm_start: ' + last_checkpoint_path)
        net, _ = load_checkpoint(last_checkpoint_path, net)

    # using multi GPU
    if multi_gpu is not None:
        net = nn.DataParallel(net, device_ids=multi_gpu)
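        # once wrapped in DataParallel, the custom methods (set_mode, criterion,
        # metric) live on net.module rather than on net itself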

    #use sync_batchnorm
    #net = convert_model(net)

    diff = 0
    best_val_metric = -0.1
    optimizer.zero_grad()

    #seed = get_seed()
    #seed = SEED
    #logging.info('aug seed: '+str(seed))
    #ia.imgaug.seed(seed)
    #np.random.seed(seed)

    for i_epoch in range(NUM_EPOCHS):
        t0 = time.time()
        # iterate through trainset
        if multi_gpu is not None:
            net.module.set_mode('train')
        else:
            net.set_mode('train')
        train_loss_list = []
        #train_metric_list
        #logit_list, truth_list = [], []
        for i, (images, masks) in enumerate(train_dl):
            ## adjust learning rate
            scheduler(optimizer, i, i_epoch, best_val_metric)

            input_data = images.to(device=device, dtype=torch.float)
            # truth: 1 where a mask is non-empty (any non-zero pixel), 0 otherwise
            truth = (masks.reshape(masks.size(0), masks.size(1), -1)
                     .sum(dim=2) != 0).to(device=device, dtype=torch.float)
            logit = net(input_data)

            #logit_list.append(logit)
            #truth_list.append(truth)

            if multi_gpu is not None:
                _train_loss = net.module.criterion(logit, truth)
                #_train_metric = net.module.metric(logit, truth)#device='gpu'
            else:
                _train_loss = net.criterion(logit, truth)
                #_train_metric = net.metric(logit, truth)#device='gpu'
            train_loss_list.append(_train_loss.item())
            #train_metric_list.append(_train_metric.item())#.detach()

            # gradient accumulation (acc_step=1 means the optimizer steps every batch)
            acc_step = 1
            _train_loss = _train_loss / acc_step
            _train_loss.backward()
            if (i + 1) % acc_step == 0:
                optimizer.step()
                optimizer.zero_grad()

        train_loss = np.mean(train_loss_list)
        #train_metric = np.mean(train_metric_list)
        #         if multi_gpu is not None:
        #             train_metric, train_tn, train_fp, train_fn, train_tp, train_auc, train_pos_percent = net.module.metric(torch.cat(logit_list, dim=0), torch.cat(truth_list, dim=0))
        #         else:
        #             train_metric, train_tn, train_fp, train_fn, train_tp, train_auc, train_pos_percent = net.metric(torch.cat(logit_list, dim=0), torch.cat(truth_list, dim=0))

        # compute valid loss & metrics (concatenate the whole valid set, then
        # compute loss/metrics on it in one pass)
        if multi_gpu is not None:
            net.module.set_mode('valid')
        else:
            net.set_mode('valid')
        with torch.no_grad():
            #             val_loss_list, val_metric_list = [], []
            #             for i, (image, masks) in enumerate(val_dl):
            #                 input_data = image.to(device=device, dtype=torch.float)
            #                 truth = masks.to(device=device, dtype=torch.float)
            #                 logit = net(input_data)

            #                 if multi_gpu is not None:
            #                     _val_loss  = net.module.criterion(logit, truth)
            #                     _val_metric  = net.module.metric(logit, truth)#device='gpu'
            #                 else:
            #                     _val_loss  = net.criterion(logit, truth)
            #                     _val_metric  = net.metric(logit, truth)#device='gpu'
            #                 val_loss_list.append(_val_loss.item())
            #                 val_metric_list.append(_val_metric.item())#.detach()

            #             val_loss = np.mean(val_loss_list)
            #             val_metric = np.mean(val_metric_list)

            logit_valid, truth_valid = None, None
            for j, (images, masks) in enumerate(val_dl):
                input_data = images.to(device=device, dtype=torch.float)
                # truth: 1 where a mask is non-empty (any non-zero pixel), 0 otherwise
                truth = (masks.reshape(masks.size(0), masks.size(1), -1)
                         .sum(dim=2) != 0).to(device=device, dtype=torch.float)
                logit = net(input_data)

                if logit_valid is None:
                    logit_valid = logit
                    truth_valid = truth
                else:
                    logit_valid = torch.cat((logit_valid, logit), dim=0)
                    truth_valid = torch.cat((truth_valid, truth), dim=0)
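            # logits/targets are accumulated over the whole validation set, so the
            # loss and metric below are computed globally rather than averaged per batch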
            if multi_gpu is not None:
                val_loss = net.module.criterion(logit_valid, truth_valid)
                _, val_metric, val_tn, val_fp, val_fn, val_tp, val_pos_percent = net.module.metric(
                    logit_valid, truth_valid)
            else:
                val_loss = net.criterion(logit_valid, truth_valid)
                _, val_metric, val_tn, val_fp, val_fn, val_tp, val_pos_percent = net.metric(
                    logit_valid, truth_valid)

        # Adjust learning_rate
        #scheduler.step(val_metric)
        #
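        # only track the best model / trigger early stopping after a 30-epoch warm-up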
        if i_epoch >= 30:
            if val_metric > best_val_metric:
                best_val_metric = val_metric
                is_best = True
                diff = 0
            else:
                is_best = False
                diff += 1
                if diff > early_stopping_round:
                    logging.info(
                        'Early Stopping: val_metric has not improved for %d rounds'
                        % early_stopping_round)
                    #print('Early Stopping: val_iou does not increase %d rounds'%early_stopping_round)
                    break
        else:
            is_best = False

        #save checkpoint
        checkpoint_dict = \
        {
            'epoch': i_epoch,
            'state_dict': net.module.state_dict() if multi_gpu is not None else net.state_dict(),
            'optim_dict' : optimizer.state_dict(),
            'metrics': {'train_loss': train_loss, 'val_loss': val_loss,
                        'val_metric': val_metric}
        }
        save_checkpoint(checkpoint_dict,
                        is_best=is_best,
                        checkpoint=checkpoint_path)

        #if i_epoch%20==0:
        if i_epoch > -1:
            logging.info(
                '[EPOCH %05d]train_loss: %0.5f; val_loss, val_metric: %0.5f, %0.5f'
                % (i_epoch, train_loss.item(), val_loss.item(), val_metric))
            logging.info('val_pos_percent: %.3f' % (val_pos_percent))
            logging.info('val (tn, fp, fn, tp): %d, %d, %d, %d' %
                         (val_tn, val_fp, val_fn, val_tp))
            logging.info('time elapsed: %0.1f min' % ((time.time() - t0) / 60))
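
# Usage sketch (not part of the original source): assumes train/val DataLoaders
# are built elsewhere from the dataset, e.g.
#   train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
#   val_dl   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
#   run_check_net(train_dl, val_dl, multi_gpu=[0, 1])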
# Example #2
def run_check_net(train_dl,
                  val_dl,
                  multi_gpu=[0, 1],
                  nonempty_only_loss=False):
    set_logger(LOG_PATH)
    logging.info('\n\n')
    #---
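    # MODEL is assumed to be of the form '<encoder>_<decoder>'; the name is split
    # on '_' to select the encoder and decoder for SegmentationModule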
    enc, dec = MODEL.split('_')[0], MODEL.split('_')[1]
    net = SegmentationModule(net_enc=enc, net_dec=dec).cuda(device=device)

    #     for param in net.named_parameters():
    #         if param[0][:8] in ['decoder5']:#'decoder5', 'decoder4', 'decoder3', 'decoder2'
    #             param[1].requires_grad = False

    #     train_params = [{'params': net.get_1x_lr_params(), 'lr': LearningRate},
    #                     {'params': net.get_10x_lr_params(), 'lr': LearningRate * 10}]#for resnet backbone
    train_params = filter(lambda p: p.requires_grad, net.parameters())
    # dummy sgd to see if it can converge ...
    #optimizer = torch.optim.SGD(train_params,
    #                  lr=LearningRate, momentum=0.9, weight_decay=0.0001)
    #optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.045)#LearningRate
    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
    #                                                       factor=0.5, patience=4,#4 resnet34
    #                                                       verbose=False, threshold=0.0001,
    #                                                       threshold_mode='rel', cooldown=0,
    #                                                       min_lr=0, eps=1e-08)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.9, last_epoch=-1)

    #for deeplabv3plus customized
    optimizer = torch.optim.SGD(train_params,
                                momentum=0.9,
                                weight_decay=0.0001,
                                lr=LearningRate)
    scheduler = LR_Scheduler(
        'poly', LearningRate, NUM_EPOCHS,
        len(train_dl))  #lr_scheduler=['poly', 'step', 'cos']

    if warm_start:
        logging.info('warm_start: ' + last_checkpoint_path)
        net, _ = load_checkpoint(last_checkpoint_path, net)

    # using multi GPU
    if multi_gpu is not None:
        net = nn.DataParallel(net, device_ids=multi_gpu)

    #use sync_batchnorm
    #net = convert_model(net)

    diff = 0
    best_val_metric = -0.1
    optimizer.zero_grad()

    #seed = get_seed()
    #seed = SEED
    #logging.info('aug seed: '+str(seed))
    #ia.imgaug.seed(seed)
    #np.random.seed(seed)

    for i_epoch in range(NUM_EPOCHS):
        ### adjust learning rate
        #scheduler.step(epoch=i_epoch)
        #print('lr: %f'%scheduler.get_lr()[0])

        t0 = time.time()
        # iterate through trainset
        if multi_gpu is not None:
            net.module.set_mode('train')
        else:
            net.set_mode('train')
        train_loss_list, train_metric_list = [], []
        #for seed in [1]:#[1, SEED]:#augment raw data with a duplicate one (augmented)
        #seed = get_seed()
        #np.random.seed(seed)
        #ia.imgaug.seed(i//10)
        for i, (image, masks) in enumerate(train_dl):
            ## adjust learning rate
            scheduler(optimizer, i, i_epoch, best_val_metric)

            input_data = image.to(device=device, dtype=torch.float)
            truth = masks.to(device=device, dtype=torch.float)
            #set_trace()
            logit, logit_clf = net(input_data)  #[:, :3, :, :]
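            # the model returns two heads: per-pixel segmentation logits and an
            # image-level logit (logit_clf), presumably an empty/non-empty mask
            # classifier that the criterion can combine with the segmentation loss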

            if multi_gpu is not None:
                _train_loss = net.module.criterion(logit, truth,
                                                   nonempty_only_loss,
                                                   logit_clf)
                _train_metric = net.module.metric(logit, truth,
                                                  nonempty_only_loss,
                                                  logit_clf)  #device='gpu'
            else:
                _train_loss = net.criterion(logit, truth, nonempty_only_loss,
                                            logit_clf)
                _train_metric = net.metric(logit, truth, nonempty_only_loss,
                                           logit_clf)  #device='gpu'
            train_loss_list.append(_train_loss.item())
            train_metric_list.append(_train_metric.item())  #.detach()

            # gradient accumulation (acc_step=1 means the optimizer steps every batch)
            acc_step = 1
            _train_loss = _train_loss / acc_step
            _train_loss.backward()
            if (i + 1) % acc_step == 0:
                optimizer.step()
                optimizer.zero_grad()

        train_loss = np.mean(train_loss_list)
        train_metric = np.mean(train_metric_list)

        # compute valid loss & metrics (averaged over validation batches)
        if multi_gpu is not None:
            net.module.set_mode('valid')
        else:
            net.set_mode('valid')
        with torch.no_grad():
            val_loss_list, val_metric_list = [], []
            for i, (image, masks) in enumerate(val_dl):
                input_data = image.to(device=device, dtype=torch.float)
                truth = masks.to(device=device, dtype=torch.float)
                logit, logit_clf = net(input_data)

                if multi_gpu is not None:
                    _val_loss = net.module.criterion(logit, truth,
                                                     nonempty_only_loss,
                                                     logit_clf)
                    _val_metric = net.module.metric(logit, truth,
                                                    nonempty_only_loss,
                                                    logit_clf)  #device='gpu'
                else:
                    _val_loss = net.criterion(logit, truth, nonempty_only_loss,
                                              logit_clf)
                    _val_metric = net.metric(logit, truth, nonempty_only_loss,
                                             logit_clf)  #device='gpu'
                val_loss_list.append(_val_loss.item())
                val_metric_list.append(_val_metric.item())  #.detach()

            val_loss = np.mean(val_loss_list)
            val_metric = np.mean(val_metric_list)
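            # note: unlike the first run_check_net above, validation loss/metric here
            # are per-batch averages rather than a single pass over the concatenated set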

#             logit_valid, truth_valid = None, None
#             for j, (image, masks) in enumerate(val_dl):
#                 input_data = image.to(device=device, dtype=torch.float)
#                 logit = net(input_data).cpu().float()
#                 truth = masks.cpu().float()
#                 if logit_valid is None:
#                     logit_valid = logit
#                     truth_valid = truth
#                 else:
#                     logit_valid = torch.cat((logit_valid, logit), dim=0)
#                     truth_valid = torch.cat((truth_valid, truth), dim=0)
#             if multi_gpu is not None:
#                 val_loss = net.module.criterion(logit_valid, truth_valid)
#                 val_metric = net.module.metric(logit_valid, truth_valid)
#             else:
#                 val_loss = net.criterion(logit_valid, truth_valid)
#                 val_metric = net.metric(logit_valid, truth_valid)

        # Adjust learning_rate
        #scheduler.step(val_metric)

        # training at 1024 resolution is harder and sometimes stops too early, so
        # force a minimum number of warm-up epochs before early stopping kicks in
        if i_epoch >= 10:  #-1
            if val_metric > best_val_metric:
                best_val_metric = val_metric
                is_best = True
                diff = 0
            else:
                is_best = False
                diff += 1
                if diff > early_stopping_round:
                    logging.info(
                        'Early Stopping: val_metric has not improved for %d rounds'
                        % early_stopping_round)
                    #print('Early Stopping: val_iou does not increase %d rounds'%early_stopping_round)
                    break
        else:
            is_best = False

        #save checkpoint
        checkpoint_dict = \
        {
            'epoch': i_epoch,
            'state_dict': net.module.state_dict() if multi_gpu is not None else net.state_dict(),
            'optim_dict' : optimizer.state_dict(),
            'metrics': {'train_loss': train_loss, 'val_loss': val_loss,
                        'train_metric': train_metric, 'val_metric': val_metric}
        }
        save_checkpoint(checkpoint_dict,
                        is_best=is_best,
                        checkpoint=checkpoint_path)

        #if i_epoch%20==0:
        if i_epoch > -1:
            logging.info(
                '[EPOCH %05d]train_loss, train_metric: %0.5f, %0.5f; val_loss, val_metric: %0.5f, %0.5f; time elapsed: %0.1f min'
                % (i_epoch, train_loss.item(), train_metric.item(),
                   val_loss.item(), val_metric.item(),
                   (time.time() - t0) / 60))