Code Example #1
File: inference_cvc.py    Project: lswzjuer/NAS-WDAN
def main(args):
    # 0.762
    # args.model='unet'
    # model1=get_models(args)
    # model1.load_state_dict(torch.load(r'E:\segmentation\Image_Segmentation\logs\cvc_logs\unet_ep1600\cvc\20200312-143050\model_best.pth.tar',map_location='cpu')['state_dict'])

    # # 0.766/0.773
    # args.model='unet++'
    # model2=get_models(args)
    # model2.load_state_dict(torch.load(r'E:\segmentation\Image_Segmentation\logs\cvc_logs\unet++_nodeep_ep800\cvc\no_deep\model_best.pth.tar',map_location='cpu')['state_dict'])
    #
    # # mutilres 0.695
    # args.model='multires_unet'
    # model3=get_models(args)
    # model3.load_state_dict(torch.load(r'E:\segmentation\Image_Segmentation\logs\cvc_logs\multires_unet_800\cvc\20200310-172036\checkpoint.pth.tar',map_location='cpu')['state_dict'])
    #
    #
    # attention_unet 0.778
    args.model = 'attention_unet_v1'
    model4 = get_models(args)
    model4.load_state_dict(
        torch.load(
            r'E:\segmentation\Image_Segmentation\logs\cvc_logs\attention_unet_v1_ep1600\cvc\20200312-143413\model_best.pth.tar',
            map_location='cpu')['state_dict'])

    genotype = eval('genotypes.%s' % 'layer7_double_deep')
    #BuildNasUnetPrune
    model5 = BuildNasUnetPrune(
        genotype=genotype,
        input_c=3,
        c=16,
        num_classes=1,
        meta_node_num=4,
        layers=9,
        dp=0,
        use_sharing=True,
        double_down_channel=True,
        aux=True,
    )
    model5.load_state_dict(
        torch.load(
            r'E:\segmentation\Image_Segmentation\nas_search_unet\logs\cvc\layer7_double_deep_ep1600_20200320-200539\model_best.pth.tar',
            map_location='cpu')['state_dict'])
    models_list = [model4, model5]
    inference_isic(models_list, args.image, args.mask)
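
inference_isic is defined elsewhere in the project and receives the list of models built above. A minimal sketch of how two binary-segmentation models could be ensembled on a single image by averaging their sigmoid outputs, assuming each model returns a one-channel logit map (or a list of deep-supervision heads ending in that map); the function name ensemble_predict is illustrative, not the project's API:

import torch

def ensemble_predict(models, image, threshold=0.5):
    """Average the sigmoid outputs of several models and threshold the mean.

    image is expected to be a (1, 3, H, W) float tensor; each model is
    assumed to return a (1, 1, H, W) logit map, or a list/tuple of
    deep-supervision outputs whose last element is that map.
    """
    probs = []
    with torch.no_grad():
        for model in models:
            model.eval()
            out = model(image)
            if isinstance(out, (list, tuple)):  # deep-supervision heads
                out = out[-1]
            probs.append(torch.sigmoid(out))
    mean_prob = torch.stack(probs, dim=0).mean(dim=0)
    return (mean_prob > threshold).float()
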
Code Example #2
def main(args):
    model1 = get_models(args)
    ckpt1 = torch.load(args.model_weight1, map_location='cpu')
    model1.load_state_dict(ckpt1['state_dict'])
    # inference_isic(model,args.image,args.mask)

    ckpt2 = torch.load(args.model_weight2, map_location='cpu')
    genotype = eval('genotypes.%s' %
                    'stage1_layer9_110epoch_double_deep_final')
    #BuildNasUnetPrune
    model2 = BuildNasUnetPrune(
        genotype=genotype,
        input_c=3,
        c=16,
        num_classes=1,
        meta_node_num=4,
        layers=9,
        dp=0,
        use_sharing=True,
        double_down_channel=True,
        aux=True,
    )
    model2.load_state_dict(ckpt2['state_dict'])
    inference_isic(model1, model2, args.image, args.mask)
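
Each of these main(args) entry points expects an argparse namespace. A minimal sketch of a CLI that could drive Code Example #2; the option names mirror the attributes the snippet reads (model_weight1, model_weight2, image, mask) but are assumptions, not the project's actual parser:

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Run two-model inference')
    parser.add_argument('--model_weight1', type=str, required=True,
                        help='checkpoint (.pth.tar) for the baseline model')
    parser.add_argument('--model_weight2', type=str, required=True,
                        help='checkpoint (.pth.tar) for the pruned NAS-UNet')
    parser.add_argument('--image', type=str, required=True,
                        help='path to the input image (or image folder)')
    parser.add_argument('--mask', type=str, required=True,
                        help='path to the ground-truth mask')
    return parser.parse_args()

if __name__ == '__main__':
    main(parse_args())
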
Code Example #3
def main(args):
    #################### init logger ###################################
    log_dir = './logs/' + '{}'.format(args.dataset) + '/{}_{}_{}'.format(args.model,args.note,time.strftime('%Y%m%d-%H%M%S'))


    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Train'.format(args.model))
    # setting
    args.save_path = log_dir
    args.save_tbx_log = args.save_path + '/tbx_log'
    writer = SummaryWriter(args.save_tbx_log)
    ##################### init device #################################
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.benchmark = True
    ####################### init dataset ###########################################
    train_loader = get_dataloder(args, split_flag="train")
    val_loader = get_dataloder(args, split_flag="valid")

    ############init model ###########################
    if  args.model == "layer7_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'layer7_double_deep'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=7,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )



    elif args.model == "stage1_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_double_deep'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    elif args.model == "stage1_nodouble_deep":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_deep'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    elif args.model == "stage1_nodouble_deep_slim":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_deep'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPruneSlim(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )


    elif args.model == "alpha1_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha1_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    elif args.model == "alpha0_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    #isic trans
    elif args.model == "stage1_layer9_110epoch_double_deep_final":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )


    # keep only the normal cell
    elif args.model == "dd_normal":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPruneNormal(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    # normal+down
    elif args.model == "dd_normaldown":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPruneNormalDown(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    # normal+up 
    elif args.model == "dd_normalup":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPruneNormalUp(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    # normal+up+down
    elif args.model == "alpha0_5_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    # ablation study of channel doubling and deep supervision
    elif args.model == "alpha0_5_stage1_double_nodeep_ep80":
        args.deepsupervision = False
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_nodeep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    elif args.model == "alpha0_5_stage1_nodouble_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'alpha0_5_stage1_nodouble_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    elif args.model == "alpha0_5_stage1_nodouble_nodeep_ep80":
        args.deepsupervision = False
        args.double_down_channel = False
        args.genotype_name = 'alpha0_5_stage1_nodouble_nodeep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    if torch.cuda.device_count() > 1 and args.use_cuda:
        logger.info('use: %d gpus', torch.cuda.device_count())
        model = nn.DataParallel(model)

    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    logger.info(genotype)
    logger.info('param size = %fMB', calc_parameters_count(model))
    # init loss
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("load model and criterion to gpu !")
    model = model.to(args.device)
    criterion = criterion.to(args.device)
    # init optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    # init schedulers  Steplr
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch)
    # scheduler=torch.optim.lr_scheduler.StepLR(optimizer=optimizer,step_size=30,gamma=0.1,last_epoch=-1)

    ############################### check resume #########################
    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=args.device)
            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            model.load_state_dict(checkpoint['state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(args.resume))

    #################################### train and val ########################
    max_value = 0
    for epoch in range(start_epoch, args.epoch):
        # lr=adjust_learning_rate(args,optimizer,epoch)
        scheduler.step()
        logger.info('Epoch: %d lr %e', epoch, scheduler.get_lr()[0])
        # train
        if args.deepsupervision:
            mean_loss, value1, value2 = train(args, model, criterion, train_loader, optimizer)
            mr, ms, mp, mf, mjc, md, macc = value1
            mmr, mms, mmp, mmf, mmjc, mmd, mmacc = value2
            logger.info(
                "Epoch:{} Train_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(epoch, mean_loss, macc, md, mjc))
            logger.info("                        dmAcc:{:.3f} dmDice:{:.3f} dmJc:{:.3f}".format(mmacc, mmd, mmjc))
            writer.add_scalar('Train/dmAcc', mmacc, epoch)
            writer.add_scalar('Train/dRecall', mmr, epoch)
            writer.add_scalar('Train/dSpecifi', mms, epoch)
            writer.add_scalar('Train/dPrecision', mmp, epoch)
            writer.add_scalar('Train/dF1', mmf, epoch)
            writer.add_scalar('Train/dJc', mmjc, epoch)
            writer.add_scalar('Train/dDice', mmd, epoch)
        else:
            mean_loss, value1 = train(args, model, criterion, train_loader,
                                      optimizer)
            mr, ms, mp, mf, mjc, md, macc = value1
            logger.info(
                "Epoch:{} Train_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(epoch, mean_loss, macc, md, mjc))
        # write
        writer.add_scalar('Train/Loss', mean_loss, epoch)
        writer.add_scalar('Train/mAcc', macc, epoch)
        writer.add_scalar('Train/Recall', mr, epoch)
        writer.add_scalar('Train/Specifi', ms, epoch)
        writer.add_scalar('Train/Precision', mp, epoch)
        writer.add_scalar('Train/F1', mf, epoch)
        writer.add_scalar('Train/Jc', mjc, epoch)
        writer.add_scalar('Train/Dice', md, epoch)

        # val
        if args.deepsupervision:
            vmean_loss, valuev1, valuev2 = infer(args, model, criterion, val_loader)
            vmr, vms, vmp, vmf, vmjc, vmd, vmacc = valuev1
            mvmr, mvms, mvmp, mvmf, mvmjc, mvmd, mvmacc = valuev2
            logger.info(
                "Epoch:{} Val_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(epoch, vmean_loss, vmacc, vmd, vmjc))
            logger.info("                        dmAcc:{:.3f} dmDice:{:.3f} dmJc:{:.3f}".format(mvmacc, mvmd, mvmjc))
            writer.add_scalar('Val/mAcc', mvmacc, epoch)
            writer.add_scalar('Val/Recall', mvmr, epoch)
            writer.add_scalar('Val/Specifi', mvms, epoch)
            writer.add_scalar('Val/Precision', mvmp, epoch)
            writer.add_scalar('Val/F1', mvmf, epoch)
            writer.add_scalar('Val/Jc', mvmjc, epoch)
            writer.add_scalar('Val/Dice', mvmd, epoch)
        else:
            vmean_loss, valuev1 = infer(args, model, criterion, val_loader)
            vmr, vms, vmp, vmf, vmjc, vmd, vmacc = valuev1
            logger.info(
                "Epoch:{} Val_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(epoch, vmean_loss, vmacc, vmd, vmjc))

        is_best = vmjc >= max_value
        max_value = max(max_value, vmjc)
        writer.add_scalar('Val/Loss', vmean_loss, epoch)
        writer.add_scalar('Val/mAcc', vmacc, epoch)
        writer.add_scalar('Val/Recall', vmr, epoch)
        writer.add_scalar('Val/Specifi', vms, epoch)
        writer.add_scalar('Val/Precision', vmp, epoch)
        writer.add_scalar('Val/F1', vmf, epoch)
        writer.add_scalar('Val/Jc', vmjc, epoch)
        writer.add_scalar('Val/Dice', vmd, epoch)

        state = {
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'state_dict': model.state_dict(),
            'scheduler': scheduler.state_dict(),  # was model.state_dict(); the resume branch loads this into the scheduler
        }
        logger.info("epoch:{} best:{} max_value:{}".format(epoch,is_best,max_value))
        # always refresh the latest checkpoint; keep a separate copy of the best model
        torch.save(state, os.path.join(args.save_path, "checkpoint.pth.tar"))
        if is_best:
            torch.save(state, os.path.join(args.save_path, "model_best.pth.tar"))

    writer.close()
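
The resume branch loads checkpoint['scheduler'] into the scheduler, so the save side has to store scheduler.state_dict() under that key. A minimal helper that keeps the save and resume sides symmetric; the function names are illustrative, not part of the project:

import os
import torch

def save_checkpoint(save_path, epoch, model, optimizer, scheduler, is_best):
    """Always write checkpoint.pth.tar; additionally copy the best model."""
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),  # scheduler state, not the model's
    }
    torch.save(state, os.path.join(save_path, "checkpoint.pth.tar"))
    if is_best:
        torch.save(state, os.path.join(save_path, "model_best.pth.tar"))

def load_checkpoint(path, model, optimizer, scheduler, device='cpu'):
    """Restore all three state dicts and return the epoch to resume from."""
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    return checkpoint['epoch']
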
Code Example #4
def main(args):
    #################### init logger ###################################
    log_dir = './eval' + '/{}'.format(args.dataset) + '/{}'.format(args.model)
    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Eval'.format(args.model))
    # setting
    args.save_path = log_dir
    args.save_images = os.path.join(args.save_path, "images")
    if not os.path.exists(args.save_images):
        os.mkdir(args.save_images)
    ##################### init device #################################
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.benchmark = True
    ####################### init dataset ###########################################
    val_loader = get_dataloder(args, split_flag="valid")

    ######################## init model ############################################
    if args.model == "layer7_double_deep_ep1600_8lr4e-3":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'layer7_double_deep'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=9,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)
        args.model_path = './logs/cvc/layer7_double_deep_ep1600_8lr4e-3/model_best.pth.tar'
        model.load_state_dict(
            torch.load(args.model_path, map_location='cpu')['state_dict'])

    elif args.model == "alpha0_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_stage1_double_deep_ep200'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)
        args.model_path = './logs/cvc/alpha0_8lr4e-3/model_best.pth.tar'
        state_dict = torch.load(args.model_path,
                                map_location='cpu')['state_dict']
        state_dict = remove_module(state_dict)
        model.load_state_dict(state_dict)

    elif args.model == "alpha0_5_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)
        args.model_path = './logs/cvc/alpha0_5_8lr4e-3/model_best.pth.tar'
        state_dict = torch.load(args.model_path,
                                map_location='cpu')['state_dict']
        state_dict = remove_module(state_dict)
        model.load_state_dict(state_dict)

    elif args.model == "alpha1_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)
        args.model_path = './logs/cvc/alpha1_8lr4e-3/model_best.pth.tar'
        state_dict = torch.load(args.model_path,
                                map_location='cpu')['state_dict']
        state_dict = remove_module(state_dict)
        model.load_state_dict(state_dict)

    else:
        raise NotImplementedError()

    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    logger.info(genotype)
    logger.info('param size = %fMB', calc_parameters_count(model))
    # init loss
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("load model and criterion to gpu !")
    model = model.to(args.device)
    criterion = criterion.to(args.device)

    infer(args, model, criterion, val_loader, logger, args.save_images)
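
Every branch above resolves a genotype with eval('genotypes.%s' % name). getattr performs the same attribute lookup without evaluating an arbitrary string; a small sketch, assuming genotypes is the module these scripts already import (get_genotype is an illustrative name):

import genotypes

def get_genotype(name):
    """Look up a named genotype from the genotypes module without eval()."""
    try:
        return getattr(genotypes, name)
    except AttributeError:
        raise ValueError("unknown genotype: {}".format(name))

# usage: genotype = get_genotype(args.genotype_name)
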
Code Example #5
def main(args):

    args.model_list = [
        'double_deep', 'double', 'nodouble', 'nodouble_deep', 'slim_dd',
        'slim_double', 'slim_nodouble', 'slim_nodouble_deep'
    ]
    #args.model_list=["slim_nodouble_deep_init32"]
    for model_name in args.model_list:
        print(model_name)
        if model_name == "double_deep":
            args.deepsupervision = True
            args.double_down_channel = True
            args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/prune_20200313-063406_32_32_ep300_double_deep/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'double':
            args.deepsupervision = False
            args.double_down_channel = True
            args.genotype_name = 'stage1_layer9_110epoch_double_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/prune_20200313-063428_32_32_ep300_double/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'nodouble':
            args.deepsupervision = False
            args.double_down_channel = False
            args.genotype_name = 'stage1_layer9_110epoch_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/prune_20200316-141125_nodouble_32_ep300/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'nodouble_deep':
            args.deepsupervision = True
            args.double_down_channel = False
            args.genotype_name = 'stage1_layer9_110epoch_deep_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/prune_20200316-141242_nodouble_32_ep300_deep/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == "slim_dd":
            args.deepsupervision = True
            args.double_down_channel = True
            args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = net_dd(genotype=genotype,
                           input_c=args.in_channels,
                           c=args.init_channels,
                           num_classes=args.nclass,
                           meta_node_num=args.middle_nodes,
                           layers=args.layers,
                           dp=args.dropout_prob,
                           use_sharing=args.use_sharing,
                           double_down_channel=args.double_down_channel,
                           aux=args.aux)
            args.model_path = './logs/isic2018/dd_20200319-170442_ep300/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'slim_double':
            args.deepsupervision = False
            args.double_down_channel = True
            args.genotype_name = 'stage1_layer9_110epoch_double_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = net_double(genotype=genotype,
                               input_c=args.in_channels,
                               c=args.init_channels,
                               num_classes=args.nclass,
                               meta_node_num=args.middle_nodes,
                               layers=args.layers,
                               dp=args.dropout_prob,
                               use_sharing=args.use_sharing,
                               double_down_channel=args.double_down_channel,
                               aux=args.aux)
            args.model_path = './logs/isic2018/double_20200319-170621_ep300/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'slim_nodouble':
            args.deepsupervision = False
            args.double_down_channel = False
            args.genotype_name = 'stage1_layer9_110epoch_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = net_nodouble(genotype=genotype,
                                 input_c=args.in_channels,
                                 c=args.init_channels,
                                 num_classes=args.nclass,
                                 meta_node_num=args.middle_nodes,
                                 layers=args.layers,
                                 dp=args.dropout_prob,
                                 use_sharing=args.use_sharing,
                                 double_down_channel=args.double_down_channel,
                                 aux=args.aux)
            args.model_path = './logs/isic2018/nodouble_20200319-210910_ep300/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'slim_nodouble_deep':
            args.deepsupervision = True
            args.double_down_channel = False
            args.genotype_name = 'stage1_layer9_110epoch_deep_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = net_nodouble_deep(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/nodouble_deep_20200319-210600_ep300/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'slim_nodouble_deep_init32':
            args.deepsupervision = True
            args.double_down_channel = False
            args.genotype_name = 'stage1_layer9_110epoch_deep_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = net_nodouble_deep(
                genotype=genotype,
                input_c=args.in_channels,
                c=32,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/nodouble_deep_ep300_init32/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        #################### init logger ###################################
        log_dir = './eval' + '/{}'.format(
            args.dataset) + '/{}'.format(model_name)
        ##################### init model ########################################
        logger = get_logger(log_dir)
        print('RUNDIR: {}'.format(log_dir))
        logger.info('{}-Eval'.format(model_name))
        # setting
        args.save_path = log_dir
        args.save_images = os.path.join(args.save_path, "images")
        if not os.path.exists(args.save_images):
            os.mkdir(args.save_images)
        ##################### init device #################################
        if args.manualSeed is None:
            args.manualSeed = random.randint(1, 10000)
        np.random.seed(args.manualSeed)
        torch.manual_seed(args.manualSeed)
        args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
        args.device = torch.device('cuda' if args.use_cuda else 'cpu')
        if args.use_cuda:
            torch.cuda.manual_seed(args.manualSeed)
            cudnn.benchmark = True
        ####################### init dataset ###########################################
        # sorted vaild datasets
        val_loader = get_dataloder(args, split_flag="valid")
        setting = {k: v for k, v in args._get_kwargs()}
        logger.info(setting)
        logger.info(genotype)
        logger.info('param size = %fMB', calc_parameters_count(model))

        # init loss
        if args.loss == 'bce':
            criterion = nn.BCELoss()
        elif args.loss == 'bcelog':
            criterion = nn.BCEWithLogitsLoss()
        elif args.loss == "dice":
            criterion = DiceLoss()
        elif args.loss == "softdice":
            criterion = SoftDiceLoss()
        elif args.loss == 'bcedice':
            criterion = BCEDiceLoss()
        else:
            criterion = nn.CrossEntropyLoss()
        if args.use_cuda:
            logger.info("load model and criterion to gpu !")
        model = model.to(args.device)
        criterion = criterion.to(args.device)
Code Example #6
File: retrain_chao.py    Project: lswzjuer/NAS-WDAN
def main(args):

    #################### init logger ###################################
    log_dir = './logs/' + '{}'.format(args.dataset) + '/{}_{}_{}'.format(
        args.model, args.note, time.strftime('%Y%m%d-%H%M%S'))

    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Train'.format(args.model))
    # setting
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    args.save_path = log_dir
    args.save_tbx_log = args.save_path + '/tbx_log'
    writer = SummaryWriter(args.save_tbx_log)
    ##################### init device #################################
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.benchmark = True
    ####################### init dataset ###########################################
    train_loader = get_dataloder(args, split_flag="train")
    val_loader = get_dataloder(args, split_flag="valid")
    ######################## init model ############################################
    # model
    ############init model ###########################
    if args.model == "nodouble_deep_init32_ep100":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'nodouble_deep_init32_ep100'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=32,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=9,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "nodouble_deep_isic":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_deep_final'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "nodouble_deep_drop02_layer7end":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'nodouble_deep_drop02_layer7end'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "stage1_nodouble_deep_ep36":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_nodouble_deep_ep36'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "stage1_nodouble_deep_ep63":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_nodouble_deep_ep63'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)
    elif args.model == "stage1_nodouble_deep_ep83":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_nodouble_deep_ep83'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha1_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha1_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_5_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    # isic trans
    elif args.model == "stage1_layer9_110epoch_double_deep_final":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    #chaos
    elif args.model == "stage0_double_deep_ep80_newim":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage0_double_deep_ep80_newim'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "stage1_double_deep_ep80_ts":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_double_deep_ep80_ts'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    # cvc trans
    elif args.model == "layer7_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'layer7_double_deep'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    if torch.cuda.device_count() > 1 and args.use_cuda:
        logger.info('use: %d gpus', torch.cuda.device_count())
        model = nn.DataParallel(model)
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    logger.info(genotype)
    logger.info('param size = %fMB', calc_parameters_count(model))

    # init loss
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    elif args.loss == 'multibcedice':
        criterion = MultiClassEntropyDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("load model and criterion to gpu !")
        model = model.to(args.device)
        criterion = criterion.to(args.device)
    # init optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum)
    # init schedulers  Steplr
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epoch)
    # scheduler=torch.optim.lr_scheduler.StepLR(optimizer=optimizer,step_size=30,gamma=0.1,last_epoch=-1)
    ############################### check resume #########################
    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume, map_location=args.device)
            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            model.load_state_dict(checkpoint['state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(
                args.resume))

    #################################### train and val ########################

    max_value = 0
    for epoch in range(start_epoch, args.epoch):
        # lr=adjust_learning_rate(args,optimizer,epoch)
        scheduler.step()
        # logger.info('Epoch: %d lr %e', epoch, scheduler.get_lr()[0])
        # train
        total_loss = train(args, model, criterion, train_loader, optimizer,
                           epoch, logger)
        # write
        writer.add_scalar('Train/total_loss', total_loss, epoch)
        # val
        tloss, md = val(args, model, criterion, val_loader, epoch, logger)
        writer.add_scalar('Val/total_loss', tloss, epoch)

        is_best = md >= max_value
        max_value = max(max_value, md)
        state = {
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'state_dict': model.state_dict(),
            'scheduler': scheduler.state_dict(),  # was model.state_dict(); the resume branch loads this into the scheduler
        }
        logger.info("epoch:{} best:{} max_value:{}".format(
            epoch, is_best, max_value))
        # always refresh the latest checkpoint; keep a separate copy of the best model
        torch.save(state, os.path.join(args.save_path, "checkpoint.pth.tar"))
        if is_best:
            torch.save(state, os.path.join(args.save_path, "model_best.pth.tar"))
    writer.close()
Code Example #7
def main(args):

    args.device = torch.device('cuda')
    args.dataset = 'isic2018'
    args.train_batch = 2
    args.val_batch = 2
    args.num_workers = 2
    args.crop_size = 256
    args.base_size = 256
    train_loader = get_dataloder(args, split_flag="train")
    #args.model = 'nodouble_deep'
    ######################## init model ############################################
    # model
    # get the network parameters
    if args.model == 'dd':
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )
        print('param size = %fMB' % calc_parameters_count(model))

    elif args.model == 'double':
        args.deepsupervision = False
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_final'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )
        print('param size = %fMB' % calc_parameters_count(model))

    elif args.model == 'nodouble':
        args.deepsupervision = False
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_final'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )
        print('param size = %fMB' % calc_parameters_count(model))

    elif args.model == 'nodouble_deep':
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_deep_final'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )
        print('param size = %fMB' % calc_parameters_count(model))

    else:
        raise NotImplementedError()
    # renamed from `time` to avoid shadowing the time module
    infer_time = test_time(args, model, train_loader)
    print("Inference time: {}".format(infer_time))
Code example #8
File: retrain_isic.py  Project: lswzjuer/NAS-WDAN
def main(args):

    #################### init logger ###################################
    log_dir = './logs/' + '{}'.format(args.dataset) + '/{}_{}_{}'.format(
        args.model, time.strftime('%Y%m%d-%H%M%S'), args.note)
    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Train'.format(args.model))
    # setting
    args.save_path = log_dir
    args.save_tbx_log = args.save_path + '/tbx_log'
    writer = SummaryWriter(args.save_tbx_log)
    ##################### init device #################################
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.benchmark = True
    ####################### init dataset ###########################################
    train_loader = get_dataloder(args, split_flag="train")
    val_loader = get_dataloder(args, split_flag="valid")
    ######################## init model ############################################
    # model
    # get the network parameters
    if args.model == "alpha_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
        args.alphas_model = './search_exp/Nas_Search_Unet/isic2018/deepsupervision/stage_1_model/checkpoint.pth.tar'
        model_alphas = torch.load(
            args.alphas_model,
            map_location=args.device)['alphas_dict']['alphas_network']
        model_alphas.requires_grad = False
        model_alphas = F.softmax(model_alphas, dim=-1)
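        # network-level architecture weights from the stage-1 search checkpoint, frozen and
        # softmax-normalized; they are passed to train()/infer() together with the model below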
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnet(genotype=genotype,
                             input_c=args.in_channels,
                             c=args.init_channels,
                             num_classes=args.nclass,
                             meta_node_num=args.middle_nodes,
                             layers=args.layers,
                             dp=args.dropout_prob,
                             use_sharing=args.use_sharing,
                             double_down_channel=args.double_down_channel,
                             aux=args.aux)

    elif args.model == "alpha_double":
        args.deepsupervision = False
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_final'
        args.alphas_model = './search_exp/Nas_Search_Unet/isic2018/nodeepsupervision/stage_1_model/checkpoint.pth.tar'
        model_alphas = torch.load(
            args.alphas_model,
            map_location=args.device)['alphas_dict']['alphas_network']
        model_alphas.requires_grad = False
        model_alphas = F.softmax(model_alphas, dim=-1)
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnet(genotype=genotype,
                             input_c=args.in_channels,
                             c=args.init_channels,
                             num_classes=args.nclass,
                             meta_node_num=args.middle_nodes,
                             layers=args.layers,
                             dp=args.dropout_prob,
                             use_sharing=args.use_sharing,
                             double_down_channel=args.double_down_channel,
                             aux=args.aux)

    elif args.model == "alpha_nodouble":
        args.deepsupervision = False
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_final'
        args.alphas_model = './search_exp/Nas_Search_Unet/isic2018/nodouble/stage_1_model/checkpoint.pth.tar'
        model_alphas = torch.load(
            args.alphas_model,
            map_location=args.device)['alphas_dict']['alphas_network']
        model_alphas.requires_grad = False
        model_alphas = F.softmax(model_alphas, dim=-1)
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnet(genotype=genotype,
                             input_c=args.in_channels,
                             c=args.init_channels,
                             num_classes=args.nclass,
                             meta_node_num=args.middle_nodes,
                             layers=args.layers,
                             dp=args.dropout_prob,
                             use_sharing=args.use_sharing,
                             double_down_channel=args.double_down_channel,
                             aux=args.aux)

    elif args.model == "alpha_nodouble_deep":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_deep_final'
        args.alphas_model = './search_exp/Nas_Search_Unet/isic2018/nodouble_deep/stage_1_model/checkpoint.pth.tar'
        model_alphas = torch.load(
            args.alphas_model,
            map_location=args.device)['alphas_dict']['alphas_network']
        model_alphas.requires_grad = False
        model_alphas = F.softmax(model_alphas, dim=-1)
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnet(genotype=genotype,
                             input_c=args.in_channels,
                             c=args.init_channels,
                             num_classes=args.nclass,
                             meta_node_num=args.middle_nodes,
                             layers=args.layers,
                             dp=args.dropout_prob,
                             use_sharing=args.use_sharing,
                             double_down_channel=args.double_down_channel,
                             aux=args.aux)

    elif args.model == "double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "double":
        args.deepsupervision = False
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_final'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "nodouble":
        args.deepsupervision = False
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_final'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "nodouble_deep":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_deep_final'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha1_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha1_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_5_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_5_stage1_double_nodeep_ep80":
        args.deepsupervision = False
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_nodeep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_5_stage1_nodouble_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'alpha0_5_stage1_nodouble_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_5_stage1_nodouble_nodeep_ep80":
        args.deepsupervision = False
        args.double_down_channel = False
        args.genotype_name = 'alpha0_5_stage1_nodouble_nodeep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    # cvc trans
    elif args.model == "layer7_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'layer7_double_deep'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    # chaos trans
    elif args.model == "stage0_double_deep_ep80_newim":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage0_double_deep_ep80_newim'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    else:
        raise NotImplementedError("unknown model: {}".format(args.model))

    if torch.cuda.device_count() > 1 and args.use_cuda:
        logger.info('use: %d gpus', torch.cuda.device_count())
        model = nn.DataParallel(model)
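        # note: DataParallel prefixes parameter names with 'module.', which is why the
        # evaluation script below strips it with remove_module() before loading checkpoints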

    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    logger.info(genotype)
    logger.info(model_alphas)
    flop, param = get_model_complexity_info(model, (3, 256, 256),
                                            as_strings=True,
                                            print_per_layer_stat=False)
    print("GFLOPs: {}".format(flop))
    print("Params: {}".format(param))
    # init loss
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("load model and criterion to gpu !")
    model = model.to(args.device)
    criterion = criterion.to(args.device)
    # init optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum)
    # init schedulers  Steplr
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epoch)
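    # cosine annealing decays the learning rate from args.lr towards zero over args.epoch epochs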
    #scheduler=torch.optim.lr_scheduler.StepLR(optimizer=optimizer,step_size=30,gamma=0.1,last_epoch=-1)

    ############################### check resume #########################
    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume, map_location=args.device)
            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            model.load_state_dict(checkpoint['state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(
                args.resume))

    #################################### train and val ########################
    max_value = 0
    for epoch in range(start_epoch, args.epoch):
        # lr=adjust_learning_rate(args,optimizer,epoch)
        scheduler.step()
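        # advance the LR schedule once per epoch (since PyTorch 1.1 the recommended order
        # is to call scheduler.step() after the epoch's optimizer updates)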
        # train
        if args.deepsupervision:
            mean_loss, value1, value2 = train(args, model_alphas, model,
                                              criterion, train_loader,
                                              optimizer)
            mr, ms, mp, mf, mjc, md, macc = value1
            logger.info(
                "Epoch:{} Train_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".
                format(epoch, mean_loss, macc, md, mjc))
            writer.add_scalar('Train/dDice', md, epoch)
        else:
            mean_loss, value1 = train(args, model_alphas, model, criterion,
                                      train_loader, optimizer)
            mr, ms, mp, mf, mjc, md, macc = value1
            logger.info(
                "Epoch:{} Train_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".
                format(epoch, mean_loss, macc, md, mjc))
        # write
        writer.add_scalar('Train/Loss', mean_loss, epoch)

        # val
        if args.deepsupervision:
            vmean_loss, valuev1, valuev2 = infer(args, model_alphas, model,
                                                 criterion, val_loader)
            vmr, vms, vmp, vmf, vmjc, vmd, vmacc = valuev1
            logger.info(
                "Epoch:{} Val_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".
                format(epoch, vmean_loss, vmacc, vmd, vmjc))

        else:
            vmean_loss, valuev1 = infer(args, model_alphas, model, criterion,
                                        val_loader)
            vmr, vms, vmp, vmf, vmjc, vmd, vmacc = valuev1
            logger.info(
                "Epoch:{} Val_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".
                format(epoch, vmean_loss, vmacc, vmd, vmjc))

        # model selection is driven by the validation mean Jaccard index (vmjc)
        is_best = vmjc >= max_value
        max_value = max(max_value, vmjc)
        writer.add_scalar('Val/Loss', vmean_loss, epoch)

        state = {
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'state_dict': model.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        logger.info("epoch:{} best:{} max_value:{}".format(
            epoch, is_best, max_value))
        # refresh the rolling checkpoint every epoch and snapshot the best model separately
        torch.save(state, os.path.join(args.save_path, "checkpoint.pth.tar"))
        if is_best:
            torch.save(state, os.path.join(args.save_path, "model_best.pth.tar"))

    writer.close()
Code example #9
def main(args):

    #args.model_list=['alpha0_double_deep_0.01','alpha0_5_double_deep_0.01','alpha1_double_deep_0.01','nodouble_deep','slim_dd','slim_double','slim_nodouble','slim_nodouble_deep']
    #args.model_list=["double_deep","nodouble_deep","slim_nodouble"]
    #args.model_list=["slim_nodouble_deep_init32"]
    #args.model_list=["slim_nodouble_deep_init48"]
    args.model_list = [
        'alpha0_double_deep_0.01', 'alpha0_5_double_deep_0.01',
        'alpha1_double_deep_0.01'
    ]
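    # only the three alpha-weighted double_deep variants are evaluated in this run;
    # the alternative model lists above are kept commented out for reference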

    for model_name in args.model_list:
        if model_name == "alpha0_double_deep_0.01":
            args.deepsupervision = True
            args.double_down_channel = True
            args.genotype_name = 'alpha0_stage1_double_deep_ep200'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/alpha0_double_deep_0.01/model_best.pth.tar'
            # kwargs = {'map_location': lambda storage, loc: storage.cuda(0)}
            # state_dict = torch.load(args.model_path, **kwargs)
            # # create new OrderedDict that does not contain `module.`
            # model.load_state_dict(state_dict)

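            # the checkpoint was presumably saved from a DataParallel model, so strip the
            # 'module.' prefix from parameter names before loading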
            state_dict = torch.load(args.model_path,
                                    map_location='cpu')['state_dict']
            state_dict = remove_module(state_dict)
            model.load_state_dict(state_dict)

        elif model_name == "alpha0_5_double_deep_0.01":
            args.deepsupervision = True
            args.double_down_channel = True
            args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/alpha0_5_double_deep_0.01/model_best.pth.tar'
            state_dict = torch.load(args.model_path,
                                    map_location='cpu')['state_dict']
            state_dict = remove_module(state_dict)
            model.load_state_dict(state_dict)
            #model.load_state_dict(torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == "alpha1_double_deep_0.01":
            args.deepsupervision = True
            args.double_down_channel = True
            args.genotype_name = 'alpha1_stage1_double_deep_ep200'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/alpha1_double_deep_0.01/model_best.pth.tar'
            state_dict = torch.load(args.model_path,
                                    map_location='cpu')['state_dict']
            state_dict = remove_module(state_dict)
            model.load_state_dict(state_dict)

            #model.load_state_dict(torch.load(args.model_path, map_location='cpu')['state_dict'])

        #################### init logger ###################################
        log_dir = './eval' + '/{}'.format(
            args.dataset) + '/{}'.format(model_name)
        ##################### init model ########################################
        logger = get_logger(log_dir)
        print('RUNDIR: {}'.format(log_dir))
        logger.info('{}-Eval'.format(model_name))
        # setting
        args.save_path = log_dir
        args.save_images = os.path.join(args.save_path, "images")
        if not os.path.exists(args.save_images):
            os.makedirs(args.save_images)
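        # infer() below presumably writes its predicted masks into this directory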
        ##################### init device #################################
        if args.manualSeed is None:
            args.manualSeed = random.randint(1, 10000)
        np.random.seed(args.manualSeed)
        torch.manual_seed(args.manualSeed)
        args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
        args.device = torch.device('cuda' if args.use_cuda else 'cpu')
        if args.use_cuda:
            torch.cuda.manual_seed(args.manualSeed)
            cudnn.benchmark = True
        ####################### init dataset ###########################################
        # sorted vaild datasets
        val_loader = get_dataloder(args, split_flag="valid")
        setting = {k: v for k, v in args._get_kwargs()}
        logger.info(setting)
        logger.info(genotype)
        logger.info('param size = %fMB', calc_parameters_count(model))

        # init loss
        if args.loss == 'bce':
            criterion = nn.BCELoss()
        elif args.loss == 'bcelog':
            criterion = nn.BCEWithLogitsLoss()
        elif args.loss == "dice":
            criterion = DiceLoss()
        elif args.loss == "softdice":
            criterion = SoftDiceLoss()
        elif args.loss == 'bcedice':
            criterion = BCEDiceLoss()
        else:
            criterion = nn.CrossEntropyLoss()
        if args.use_cuda:
            logger.info("load model and criterion to gpu !")
        model = model.to(args.device)
        criterion = criterion.to(args.device)
        infer(args, model, criterion, val_loader, logger, args.save_images)