def main(args):
    ############    init config ################
    #################### init logger ###################################
    log_dir = './search_exp/{}/{}/{}_{}'.format(
        args.model, args.dataset, time.strftime('%Y%m%d-%H%M%S'), args.note)

    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Search'.format(args.model))
    args.save_path = log_dir
    args.save_tbx_log = args.save_path + '/tbx_log'
    writer = SummaryWriter(args.save_tbx_log)
    ##################### init device #################################
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.multi_gpu = args.gpus > 1 and torch.cuda.is_available()
    args.device = torch.device('cuda:0' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.enabled = True
        cudnn.benchmark = True
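        # cudnn.benchmark autotunes convolution algorithms for fixed input
        # sizes; it speeds training up but weakens exact reproducibility even
        # with the manual seeds set above.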
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)

    ####################### init dataset ###########################################
    logger.info("Dataset for search is {}".format(args.dataset))
    train_dataset = datasets_dict[args.dataset](args,
                                                args.dataset_root,
                                                split='train')
    val_dataset = datasets_dict[args.dataset](args,
                                              args.dataset_root,
                                              split='valid')
    # train_dataset=datasets_dict[args.dataset](args,split='train')
    # val_dataset=datasets_dict[args.dataset](args,split='valid')
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
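    # indices[:split] feed the weight-update loader (train_queue) below;
    # indices[split:] feed the architecture-update loader (val_queue).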
    # init loss
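    # Note: nn.BCELoss expects probabilities (sigmoid already applied),
    # while nn.BCEWithLogitsLoss takes raw logits.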
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("load criterion to gpu !")
    criterion = criterion.to(args.device)
    ######################## init model ############################################
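    # One boolean switch per (mixed op, candidate op); True means the op is
    # still a candidate. nums_mixop counts the mixed ops per cell: with
    # meta_node_num=4 that is 2+3+4+5 = 14.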
    nums_mixop = sum(2 + i for i in range(args.meta_node_num))
    switches_normal = [[True] * len(CellPos) for _ in range(nums_mixop)]
    switches_down = [[True] * len(CellLinkDownPos) for _ in range(nums_mixop)]
    switches_up = [[True] * len(CellLinkUpPos) for _ in range(nums_mixop)]
    # per-stage pruning schedule (ops dropped per mixed op at each stage end):
    # down ops:   6 --> 4 --> 1
    drop_op_down = [2, 3]
    # up ops:     4 --> 2 --> 1
    drop_op_up = [2, 1]
    # normal ops: 7 --> 4 --> 1
    drop_op_normal = [3, 3]
    # stage 0 and stage 1 both search and prune; final training runs separately
    original_train_batch = args.train_batch
    original_val_batch = args.val_batch
    for sp in range(2):
        # build dataloader
        # model ,numclass=1,im_ch=3,init_channel=16,intermediate_nodes=4,layers=9
        if sp == 0:
            args.model = "UnetLayer7"
            args.layers = 7
            sp_train_batch = original_train_batch
            sp_val_batch = original_val_batch
            sp_epoch = args.epochs
            sp_lr = args.lr
        else:
            #args.model = "UnetLayer9"
            # 在算力平台上面UnetLayer9就是UnetLayer9_v2
            args.model = "UnetLayer9"
            args.layers = 9
            sp_train_batch = original_train_batch
            sp_val_batch = original_val_batch
            sp_lr = args.lr
            sp_epoch = args.epochs

        train_queue = data.DataLoader(
            train_dataset,
            batch_size=sp_train_batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[:split]),
            pin_memory=True,
            num_workers=args.num_workers)
        val_queue = data.DataLoader(
            train_dataset,
            batch_size=sp_train_batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[split:num_train]),
            pin_memory=True,
            num_workers=args.num_workers)
        test_dataloader = data.DataLoader(val_dataset,
                                          batch_size=sp_val_batch,
                                          pin_memory=True,
                                          num_workers=args.num_workers)
        logger.info(
            "stage:{} model:{} epoch:{} lr:{} train_batch:{} val_batch:{}".
            format(sp, args.model, sp_epoch, sp_lr, sp_train_batch,
                   sp_val_batch))

        model = get_models(args, switches_normal, switches_down, switches_up)
        save_model_path = os.path.join(args.save_path,
                                       "stage_{}_model".format(sp))
        if not os.path.exists(save_model_path):
            os.mkdir(save_model_path)
        if args.multi_gpu:
            logger.info('use: %d gpus', args.gpus)
            model = nn.DataParallel(model)
        model = model.to(args.device)
        logger.info('param size = %fMB', calc_parameters_count(model))
        # init optimizer for arch parameters and weight parameters
        # final stage, just train the network parameters
        optimizer_arch = torch.optim.Adam(model.arch_parameters(),
                                          lr=args.arch_lr,
                                          betas=(0.5, 0.999),
                                          weight_decay=args.arch_weight_decay)
        optimizer_weight = torch.optim.SGD(model.weight_parameters(),
                                           lr=sp_lr,
                                           weight_decay=args.weight_decay,
                                           momentum=args.momentum)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer_weight, sp_epoch, eta_min=args.lr_min)
        #################################### train and val ########################
        max_value = 0
        for epoch in range(0, sp_epoch):
            # lr=adjust_learning_rate(args,optimizer,epoch)
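            # Note: since PyTorch 1.1 the convention is to call scheduler.step()
            # after the epoch's optimizer steps; stepping it here at the top of
            # the epoch means the initial learning rate is never actually used.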
            scheduler.step()
            logger.info('Epoch: %d lr %e', epoch, scheduler.get_lr()[0])
            # train
            if epoch < args.arch_after:
                weight_loss_avg, arch_loss_avg, mr, ms, mp, mf, mjc, md, macc = train(
                    args,
                    train_queue,
                    val_queue,
                    model,
                    criterion,
                    optimizer_weight,
                    optimizer_arch,
                    train_arch=False)
            else:
                weight_loss_avg, arch_loss_avg, mr, ms, mp, mf, mjc, md, macc = train(
                    args,
                    train_queue,
                    val_queue,
                    model,
                    criterion,
                    optimizer_weight,
                    optimizer_arch,
                    train_arch=True)
            logger.info("Epoch:{} WeightLoss:{:.3f}  ArchLoss:{:.3f}".format(
                epoch, weight_loss_avg, arch_loss_avg))
            logger.info("         Acc:{:.3f}   Dice:{:.3f}  Jc:{:.3f}".format(
                macc, md, mjc))
            # write
            writer.add_scalar('Train/W_loss', weight_loss_avg, epoch)
            writer.add_scalar('Train/A_loss', arch_loss_avg, epoch)
            writer.add_scalar('Train/Dice', md, epoch)
            # infer
            if (epoch + 1) % args.infer_epoch == 0:
                genotype = model.genotype()
                logger.info('genotype = %s', genotype)
                val_loss, (vmr, vms, vmp, vmf, vmjc, vmd,
                           vmacc) = infer(args, model, val_queue, criterion)
                logger.info(
                    "ValLoss:{:.3f} ValAcc:{:.3f}  ValDice:{:.3f} ValJc:{:.3f}"
                    .format(val_loss, vmacc, vmd, vmjc))
                writer.add_scalar('Val/loss', val_loss, epoch)

                is_best = vmjc >= max_value
                max_value = max(max_value, vmjc)
                state = {
                    'epoch': epoch,
                    'optimizer_arch': optimizer_arch.state_dict(),
                    'optimizer_weight': optimizer_weight.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'state_dict': model.state_dict(),
                    'alphas_dict': model.alphas_dict(),
                }
                logger.info("epoch:{} best:{} max_value:{}".format(
                    epoch, is_best, max_value))
                torch.save(
                    state,
                    os.path.join(save_model_path, "checkpoint.pth.tar"))
                if is_best:
                    torch.save(
                        state,
                        os.path.join(save_model_path, "model_best.pth.tar"))

        # stage end: prune the candidate op set before building the next stage
        weight_down = F.softmax(model.arch_parameters()[0],
                                dim=-1).data.cpu().numpy()
        weight_up = F.softmax(model.arch_parameters()[1],
                              dim=-1).data.cpu().numpy()
        weight_normal = F.softmax(model.arch_parameters()[2],
                                  dim=-1).data.cpu().numpy()
        weight_network = F.softmax(model.arch_parameters()[3],
                                   dim=-1).data.cpu().numpy()
        logger.info("alphas_down: \n{}".format(weight_down))
        logger.info("alphas_up: \n{}".format(weight_up))
        logger.info("alphas_normal: \n{}".format(weight_normal))
        logger.info("alphas_network: \n{}".format(weight_network))

        genotype = model.genotype()
        logger.info('Stage:{} \n  Genotype: {}'.format(sp, genotype))
        logger.info(
            '------Stage {} end ! Then  Dropping Paths------'.format(sp))
        # op-set sizes: CellLinkDownPos = 6, CellLinkUpPos = 4, CellPos = 7
        # drop the lowest-weight candidates and update the switch matrices
        if sp == 0:
            switches_down = update_switches(weight_down.copy(),
                                            switches_down.copy(),
                                            CellLinkDownPos, drop_op_down[sp])
            switches_up = update_switches(weight_up.copy(), switches_up.copy(),
                                          CellLinkUpPos, drop_op_up[sp])
            switches_normal = update_switches(weight_normal.copy(),
                                              switches_normal.copy(), CellPos,
                                              drop_op_normal[sp])
            logger.info('switches_down = %s', switches_down)
            logger.info('switches_up = %s', switches_up)
            logger.info('switches_normal = %s', switches_normal)
            logging_switches(logger, switches_down, CellLinkDownPos)
            logging_switches(logger, switches_up, CellLinkUpPos)
            logging_switches(logger, switches_normal, CellPos)
        else:
            # sp == 1 is the final stage: model.genotype() already records the
            # final architecture, so keep only the single best op in each of
            # the 14 mixed ops for logging.
            switches_down = update_switches(weight_down.copy(),
                                            switches_down.copy(),
                                            CellLinkDownPos, drop_op_down[sp])
            switches_up = update_switches(weight_up.copy(), switches_up.copy(),
                                          CellLinkUpPos, drop_op_up[sp])
            switches_normal = update_switches_nozero(weight_normal.copy(),
                                                     switches_normal.copy(),
                                                     CellPos,
                                                     drop_op_normal[sp])
            logger.info('switches_down = %s', switches_down)
            logger.info('switches_up = %s', switches_up)
            logger.info('switches_normal = %s', switches_normal)
            logging_switches(logger, switches_down, CellLinkDownPos)
            logging_switches(logger, switches_up, CellLinkUpPos)
            logging_switches(logger, switches_normal, CellPos)
    writer.close()
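
A note on the helper used above: the repository's update_switches is not shown here, so the following is a minimal sketch of a compatible implementation. The signature, the drop-the-weakest behaviour, and the assumption that each row of the weight matrix covers only the still-enabled ops are inferred from the call sites, not taken from the source.

import numpy as np

def update_switches(weights, switches, op_names, num_to_drop):
    # For each mixed op, disable the num_to_drop candidates whose softmaxed
    # architecture weights are smallest (assumed behaviour, see note above).
    # op_names is unused in this sketch; it is kept to match the call sites.
    for row, switch in zip(weights, switches):
        enabled = [i for i, on in enumerate(switch) if on]
        order = np.argsort(row)          # ascending: weakest candidates first
        for k in order[:num_to_drop]:
            switch[enabled[k]] = False   # prune the weakest candidates
    return switches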
Example #2
def main(args):
    ############    init config ################
    model_name = args.model
    assert model_name in models_dict, "The requested model does not exist!"
    print('Using model: {}'.format(model_name))

    #################### init logger ###################################
    log_dir = './logs/' + args.model + '_' + args.note + '/{}'.format(time.strftime('%Y%m%d-%H%M%S'))
    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Train'.format(args.model))
    # setting
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    args.save_path = log_dir
    args.save_tbx_log = args.save_path + '/tbx_log'
    writer = SummaryWriter(args.save_tbx_log)
    ##################### init device #################################
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.benchmark = True
    ####################### init dataset ###########################################
    train_loader=get_dataloder(args,split_flag="train")
    val_loader=get_dataloder(args,split_flag="valid")
    ######################## init model ############################################
    # model
    logger.info("Model Dict has keys: \n {}".format(models_dict.keys()))
    model = get_models(args)
    if torch.cuda.device_count() > 1 and args.use_cuda:
        logger.info('use: %d gpus', torch.cuda.device_count())
        model = nn.DataParallel(model)
    logger.info('param size = %fMB', calc_parameters_count(model))
    # init loss
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("Loading model and criterion to GPU")
        model = model.to(args.device)
        criterion = criterion.to(args.device)
    # init optimizer
    if args.model_optimizer=="sgd":
        #torch.optim.SGD(parametetrs,lr=args.lr,weight_decay=args.weight_decay,momentum=args.momentum)
        optimizer=torch.optim.SGD(model.parameters(),lr=args.lr,weight_decay=args.weight_decay,momentum=args.momentum)
    else:
        optimizer=torch.optim.Adam(model.parameters(),args.lr,[args.beta1, args.beta2],
                                   weight_decay=args.weight_decay)

    # init scheduler: cosine annealing (StepLR alternative commented below)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch)
    # scheduler=torch.optim.lr_scheduler.StepLR(optimizer=optimizer,step_size=30,gamma=0.1,last_epoch=-1)
    ############################### check resume #########################
    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=args.device)
            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            model.load_state_dict(checkpoint['state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(args.resume))

    #################################### train and val ########################
    max_value = 0
    for epoch in range(start_epoch, args.epoch):
        # lr=adjust_learning_rate(args,optimizer,epoch)
        scheduler.step()
        logger.info('Epoch: %d lr %e', epoch, scheduler.get_lr()[0])
        # train
        mr, ms, mp, mf, mjc, md, macc, mean_loss = train(args, model, criterion, train_loader,
                                                         optimizer, epoch, logger)
        # write
        writer.add_scalar('Train/Loss', mean_loss, epoch)
        writer.add_scalar('Train/mAcc', macc, epoch)
        writer.add_scalar('Train/Recall', mr, epoch)
        writer.add_scalar('Train/Specifi', ms, epoch)
        writer.add_scalar('Train/Precision', mp, epoch)
        writer.add_scalar('Train/F1', mf, epoch)
        writer.add_scalar('Train/Jc', mjc, epoch)
        writer.add_scalar('Train/Dice', md, epoch)

        # val
        vmr, vms, vmp, vmf, vmjc, vmd, vmacc, vmean_loss = val(args, model, criterion, val_loader, epoch, logger)

        writer.add_scalar('Val/Loss', vmean_loss, epoch)
        writer.add_scalar('Val/mAcc', vmacc, epoch)
        writer.add_scalar('Val/Recall', vmr, epoch)
        writer.add_scalar('Val/Specifi', vms, epoch)
        writer.add_scalar('Val/Precision', vmp, epoch)
        writer.add_scalar('Val/F1', vmf, epoch)
        writer.add_scalar('Val/Jc', vmjc, epoch)
        writer.add_scalar('Val/Dice', vmd, epoch)

        is_best = vmjc >= max_value
        max_value = max(max_value, vmjc)
        state = {
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'state_dict': model.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        logger.info("epoch:{} best:{} max_value:{}".format(epoch,is_best,max_value))
        if not is_best:
            torch.save(state,os.path.join(args.save_path,"checkpoint.pth.tar"))
        else:
            torch.save(state,os.path.join(args.save_path,"checkpoint.pth.tar"))
            torch.save(state,os.path.join(args.save_path,"model_best.pth.tar"))

    writer.close()
Example #3
def main(args):
    #################### init logger ###################################
    log_dir = './logs/' + '{}'.format(args.dataset) + '/{}_{}_{}'.format(args.model, args.note, time.strftime('%Y%m%d-%H%M%S'))

    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Train'.format(args.model))
    # setting
    args.save_path = log_dir
    args.save_tbx_log = args.save_path + '/tbx_log'
    writer = SummaryWriter(args.save_tbx_log)
    ##################### init device #################################
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.benchmark = True
    ####################### init dataset ###########################################
    train_loader = get_dataloder(args, split_flag="train")
    val_loader = get_dataloder(args, split_flag="valid")

    ############init model ###########################
    if  args.model == "layer7_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'layer7_double_deep'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=7,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )



    elif args.model == "stage1_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_double_deep'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    elif args.model == "stage1_nodouble_deep":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_deep'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    elif args.model == "stage1_nodouble_deep_slim":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_deep'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPruneSlim(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )


    elif args.model == "alpha1_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha1_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    elif args.model == "alpha0_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    # ISIC transfer
    elif args.model == "stage1_layer9_110epoch_double_deep_final":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )


    # keep only the normal cell ops
    elif args.model == "dd_normal":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPruneNormal(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    # normal+down
    elif args.model == "dd_normaldown":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPruneNormalDown(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    # normal+up 
    elif args.model == "dd_normalup":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPruneNormalUp(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    # normal+up+down
    elif args.model == "alpha0_5_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    # ablation study of channel doubling and deep supervision
    elif args.model == "alpha0_5_stage1_double_nodeep_ep80":
        args.deepsupervision = False
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_nodeep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    elif args.model == "alpha0_5_stage1_nodouble_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'alpha0_5_stage1_nodouble_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    elif args.model == "alpha0_5_stage1_nodouble_nodeep_ep80":
        args.deepsupervision = False
        args.double_down_channel = False
        args.genotype_name = 'alpha0_5_stage1_nodouble_nodeep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(
            genotype=genotype,
            input_c=args.in_channels,
            c=args.init_channels,
            num_classes=args.nclass,
            meta_node_num=args.middle_nodes,
            layers=args.layers,
            dp=args.dropout_prob,
            use_sharing=args.use_sharing,
            double_down_channel=args.double_down_channel,
            aux=args.aux
        )

    if torch.cuda.device_count() > 1 and args.use_cuda:
        logger.info('use: %d gpus', torch.cuda.device_count())
        model = nn.DataParallel(model)

    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    logger.info(genotype)
    logger.info('param size = %fMB', calc_parameters_count(model))
    # init loss
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("load model and criterion to gpu !")
    model = model.to(args.device)
    criterion = criterion.to(args.device)
    # init optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    # init scheduler: cosine annealing (StepLR alternative commented below)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch)
    # scheduler=torch.optim.lr_scheduler.StepLR(optimizer=optimizer,step_size=30,gamma=0.1,last_epoch=-1)

    ############################### check resume #########################
    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=args.device)
            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            model.load_state_dict(checkpoint['state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(args.resume))

    #################################### train and val ########################
    max_value = 0
    for epoch in range(start_epoch, args.epoch):
        # lr=adjust_learning_rate(args,optimizer,epoch)
        scheduler.step()
        logger.info('Epoch: %d lr %e', epoch, scheduler.get_lr()[0])
        # train
        if args.deepsupervision:
            mean_loss, value1, value2 = train(args, model, criterion, train_loader, optimizer)
            mr, ms, mp, mf, mjc, md, macc = value1
            mmr, mms, mmp, mmf, mmjc, mmd, mmacc = value2
            logger.info(
                "Epoch:{} Train_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(epoch, mean_loss, macc, md, mjc))
            logger.info("                        dmAcc:{:.3f} dmDice:{:.3f} dmJc:{:.3f}".format(mmacc, mmd, mmjc))
            writer.add_scalar('Train/dmAcc', mmacc, epoch)
            writer.add_scalar('Train/dRecall', mmr, epoch)
            writer.add_scalar('Train/dSpecifi', mms, epoch)
            writer.add_scalar('Train/dPrecision', mmp, epoch)
            writer.add_scalar('Train/dF1', mmf, epoch)
            writer.add_scalar('Train/dJc', mmjc, epoch)
            writer.add_scalar('Train/dDice', mmd, epoch)
        else:
            mean_loss, value1 = train(args, model, criterion, train_loader,
                                      optimizer)
            mr, ms, mp, mf, mjc, md, macc = value1
            logger.info(
                "Epoch:{} Train_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(epoch, mean_loss, macc, md, mjc))
        # write
        writer.add_scalar('Train/Loss', mean_loss, epoch)
        writer.add_scalar('Train/mAcc', macc, epoch)
        writer.add_scalar('Train/Recall', mr, epoch)
        writer.add_scalar('Train/Specifi', ms, epoch)
        writer.add_scalar('Train/Precision', mp, epoch)
        writer.add_scalar('Train/F1', mf, epoch)
        writer.add_scalar('Train/Jc', mjc, epoch)
        writer.add_scalar('Train/Dice', md, epoch)

        # val
        if args.deepsupervision:
            vmean_loss, valuev1, valuev2 = infer(args, model, criterion, val_loader)
            vmr, vms, vmp, vmf, vmjc, vmd, vmacc = valuev1
            mvmr, mvms, mvmp, mvmf, mvmjc, mvmd, mvmacc = valuev2
            logger.info(
                "Epoch:{} Val_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(epoch, vmean_loss, vmacc, vmd, vmjc))
            logger.info("                        dmAcc:{:.3f} dmDice:{:.3f} dmJc:{:.3f}".format(mvmacc, mvmd, mvmjc))
            writer.add_scalar('Val/dmAcc', mvmacc, epoch)
            writer.add_scalar('Val/dRecall', mvmr, epoch)
            writer.add_scalar('Val/dSpecifi', mvms, epoch)
            writer.add_scalar('Val/dPrecision', mvmp, epoch)
            writer.add_scalar('Val/dF1', mvmf, epoch)
            writer.add_scalar('Val/dJc', mvmjc, epoch)
            writer.add_scalar('Val/dDice', mvmd, epoch)
        else:
            vmean_loss, valuev1 = infer(args, model, criterion, val_loader)
            vmr, vms, vmp, vmf, vmjc, vmd, vmacc = valuev1
            logger.info(
                "Epoch:{} Val_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".format(epoch, vmean_loss, vmacc, vmd, vmjc))

        is_best = vmjc >= max_value
        max_value = max(max_value, vmjc)
        writer.add_scalar('Val/Loss', vmean_loss, epoch)
        writer.add_scalar('Val/mAcc', vmacc, epoch)
        writer.add_scalar('Val/Recall', vmr, epoch)
        writer.add_scalar('Val/Specifi', vms, epoch)
        writer.add_scalar('Val/Precision', vmp, epoch)
        writer.add_scalar('Val/F1', vmf, epoch)
        writer.add_scalar('Val/Jc', vmjc, epoch)
        writer.add_scalar('Val/Dice', vmd, epoch)

        state = {
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'state_dict': model.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        logger.info("epoch:{} best:{} max_value:{}".format(epoch,is_best,max_value))
        if not is_best:
            torch.save(state,os.path.join(args.save_path,"checkpoint.pth.tar"))
        else:
            torch.save(state,os.path.join(args.save_path,"checkpoint.pth.tar"))
            torch.save(state,os.path.join(args.save_path,"model_best.pth.tar"))

    writer.close()
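
The long if/elif chain above repeats one constructor with only a genotype name, two flags, and occasionally a builder class changing. As a sketch (the MODEL_CONFIGS table and build_model helper below are illustrative, not part of the repository), the same dispatch can be table-driven, one line per configuration:

# model name -> (deepsupervision, double_down_channel, genotype_name, builder)
MODEL_CONFIGS = {
    'stage1_double_deep':   (True, True,  'stage1_double_deep', BuildNasUnetPrune),
    'stage1_nodouble_deep': (True, False, 'stage1_deep', BuildNasUnetPrune),
    'dd_normal':            (True, True,  'alpha0_5_stage1_double_deep_ep80', BuildNasUnetPruneNormal),
    # ... remaining entries follow the same pattern
}

def build_model(args):
    deep, double, geno_name, builder = MODEL_CONFIGS[args.model]
    args.deepsupervision = deep
    args.double_down_channel = double
    args.genotype_name = geno_name
    genotype = getattr(genotypes, geno_name)  # equivalent to the eval() above
    return builder(
        genotype=genotype,
        input_c=args.in_channels,
        c=args.init_channels,
        num_classes=args.nclass,
        meta_node_num=args.middle_nodes,
        layers=args.layers,
        dp=args.dropout_prob,
        use_sharing=args.use_sharing,
        double_down_channel=args.double_down_channel,
        aux=args.aux,
    )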
Example #4
def main(args):
    #################### init logger ###################################
    args.model = 'unet'
    model_weight_path = '../logs/isic2018/unet_ep300/20200402-135108/model_best.pth.tar'
    model = get_models(args)
    model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])

    log_dir = './models/' + args.model + '_prune_' + args.note
    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-L1Prune'.format(args.model))
    # setting
    args.save_path = log_dir
    args.save_tbx_log = args.save_path + '/tbx_log'
    writer = SummaryWriter(args.save_tbx_log)

    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)

    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.benchmark = True

    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)

    train_loader=get_dataloder(args,split_flag="train")
    val_loader=get_dataloder(args,split_flag="valid")


    # init loss
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("load model and criterion to gpu !")
    model = model.to(args.device)
    criterion = criterion.to(args.device)

    logger.info("Original trained model performance test: ")
    infer(args, model, criterion, val_loader, logger)

    # Pruning configuration, following 'Pruning Filters for Efficient ConvNets'
    # (Li et al., ICLR 2017).
    configure_list = [{
        'sparsity': 0.5,
        'op_types': ['Conv2d'],
        'op_names': ['Conv1.conv.0','Conv1.conv.3','Conv2.conv.0','Conv2.conv.3','Conv3.conv.0','Conv3.conv.3',
                     'Conv4.conv.0','Conv4.conv.3','Conv5.conv.0','Conv5.conv.3',
                     'Up5.up.1','Up_conv5.conv.0','Up_conv5.conv.3',
                     'Up4.up.1','Up_conv4.conv.0','Up_conv4.conv.3',
                     'Up3.up.1','Up_conv3.conv.0','Up_conv3.conv.3',
                     'Up2.up.1','Up_conv2.conv.0','Up_conv2.conv.3',
                     ]}
    ]
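    # op_names enumerates the UNet's encoder convs (Conv1-Conv5) and decoder
    # convs (Up*/Up_conv*); each listed Conv2d is pruned to 50% filter sparsity.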
    # Prune model and test accuracy without fine tuning.
    logger.info('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
    pruner = L1FilterPruner(model, configure_list)

    # change the forward func (mul pruning mask )
    model = pruner.compress()

    # test performance without finetuning
    logger.info("Pruning trained model performance test: ")
    infer(args, model, criterion, val_loader, logger)

    # Fine tune the pruned model for 40 epochs and test accuracy
    logger.info('=' * 10 + 'Fine tuning' + '=' * 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                weight_decay=args.weight_decay, momentum=args.momentum)
    # init scheduler: cosine annealing
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch)

    max_value = 0
    for epoch in range(0, args.epoch):
        # lr=adjust_learning_rate(args,optimizer,epoch)
        scheduler.step()
        logger.info('Epoch: %d lr %e', epoch, scheduler.get_lr()[0])
        # update mask
        pruner.update_epoch(epoch)
        # train
        train(args, model, criterion, train_loader, optimizer, epoch, logger)
        # val
        vmr, vms, vmp, vmf, vmjc, vmd, vmacc, vloss = infer(args, model, criterion, val_loader, logger)

        writer.add_scalar('Val/Loss', vloss, epoch)
        writer.add_scalar('Val/mAcc', vmacc, epoch)
        writer.add_scalar('Val/Recall', vmr, epoch)
        writer.add_scalar('Val/Specifi', vms, epoch)
        writer.add_scalar('Val/Precision', vmp, epoch)
        writer.add_scalar('Val/F1', vmf, epoch)
        writer.add_scalar('Val/Jc', vmjc, epoch)
        writer.add_scalar('Val/Dice', vmd, epoch)

        is_best = vmjc >= max_value
        max_value = max(max_value, vmjc)
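        # NNI's export_model writes the weights with the pruning masks already
        # applied, and the binary masks to a separate file (NNI v1.x API).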
        if is_best:
            pruner.export_model(model_path=os.path.join(args.save_path, "best_prune_unet.pth"),
                                mask_path=os.path.join(args.save_path, 'mask_prune_indexs.pth'))
        state = {
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'state_dict': model.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        logger.info("epoch:{} best:{} max_value:{}".format(epoch, is_best, max_value))
        torch.save(state, os.path.join(args.save_path, "checkpoint.pth.tar"))
    writer.close()

    # test the best_prune_unet.pth
    args.model='unet'
    model_weight_path=os.path.join(args.save_path,"best_prune_unet.pth")
    model=get_models(args)
    model.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    model = model.to(args.device)
    logger.info("Final saved pruned  model performance test: ")
    infer(args, model, criterion, val_loader, logger)
Example #5
        print('Performing data augmentation...')
        augmentation(args.root_path, mode='train')
        augmentation(args.root_path, mode='val')
        augmentation(args.root_path, mode='test')
        exit()

    # Configure loss function type and distance option
    if args.loss == 'bce' and args.option == 'binary':
        loss_fn = ShapeBCELoss()
    elif args.loss == 'bce' and args.option == 'multi':
        # loss_fn = ShapeBCELoss() if args.option == 'multi' else FPNLoss()
        loss_fn = ShapeBCELoss()
    elif args.loss == 'jaccard':
        loss_fn = IoULoss()
    elif args.loss == 'dice':
        loss_fn = SoftDiceLoss()
    elif args.loss == 'boundary':
        alpha = 1.0  # a: (a * Region-based loss + (1-a) * boundary loss)
        loss_fn = SurfaceLoss(alpha=alpha, dice=args.region_option)
    else:
        raise NotImplementedError(
            'Loss function {0} not recognized'.format(args.loss))

    # Configure network architecture option
    c_in = 1
    if args.option == 'binary':
        c_out = 1
    elif args.option == 'multi':
        c_out = 3
    else:
        raise NotImplementedError(
            'Option {0} not recognized'.format(args.option))
Example #6
def main(args):
    #################### init logger ###################################
    log_dir = './eval' + '/{}'.format(args.dataset) + '/{}'.format(args.model)
    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Eval'.format(args.model))
    # setting
    args.save_path = log_dir
    args.save_images = os.path.join(args.save_path, "images")
    if not os.path.exists(args.save_images):
        os.mkdir(args.save_images)
    ##################### init device #################################
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.benchmark = True
    ####################### init dataset ###########################################
    val_loader = get_dataloder(args, split_flag="valid")

    ######################## init model ############################################
    if args.model == "layer7_double_deep_ep1600_8lr4e-3":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'layer7_double_deep'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=9,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)
        args.model_path = './logs/cvc/layer7_double_deep_ep1600_8lr4e-3/model_best.pth.tar'
        model.load_state_dict(
            torch.load(args.model_path, map_location='cpu')['state_dict'])

    elif args.model == "alpha0_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_stage1_double_deep_ep200'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)
        args.model_path = './logs/cvc/alpha0_8lr4e-3/model_best.pth.tar'
        state_dict = torch.load(args.model_path,
                                map_location='cpu')['state_dict']
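        # remove_module strips the 'module.' prefix that nn.DataParallel adds
        # to parameter names, so the weights load into a plain model.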
        state_dict = remove_module(state_dict)
        model.load_state_dict(state_dict)

    elif args.model == "alpha0_5_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)
        args.model_path = './logs/cvc/alpha0_5_8lr4e-3/model_best.pth.tar'
        state_dict = torch.load(args.model_path,
                                map_location='cpu')['state_dict']
        state_dict = remove_module(state_dict)
        model.load_state_dict(state_dict)

    elif args.model == "alpha1_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)
        args.model_path = './logs/cvc/alpha1_8lr4e-3/model_best.pth.tar'
        state_dict = torch.load(args.model_path,
                                map_location='cpu')['state_dict']
        state_dict = remove_module(state_dict)
        model.load_state_dict(state_dict)

    else:
        raise NotImplementedError()

    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    logger.info(genotype)
    logger.info('param size = %fMB', calc_parameters_count(model))
    # init loss
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("load model and criterion to gpu !")
    model = model.to(args.device)
    criterion = criterion.to(args.device)

    infer(args, model, criterion, val_loader, logger, args.save_images)
Example #7
def main(args):
    #################### init logger ###################################
    # args.model_list=["unet","unet++_deep","unet++_nodeep",'attention_unet_v1','multires_unet','r2unet_t3',
    #                  'unet_ep800dice','unet++_deep_ep800dice','unet++_nodeep_ep800dice','attention_unet_v1_ep800dice','multires_unet_ep800dice'
    #                  ]
    args.model_list = ['unet', 'unet++', "attention_unet", "multires_unet"]

    for model_name in args.model_list:
        if model_name == 'unet':
            args.model = 'unet'
            model_weight_path = './logs/chaos/unet_ep150_v2/20200403-134703/checkpoint.pth.tar'
            model = get_models(args)
            model.load_state_dict(
                torch.load(model_weight_path,
                           map_location='cpu')['state_dict'])
        elif model_name == 'unet++':
            args.model = 'unet++'
            args.deepsupervision = False
            model_weight_path = './logs/chaos/unet++_ep150_v2/20200403-135401/checkpoint.pth.tar'
            model = get_models(args)
            model.load_state_dict(
                torch.load(model_weight_path,
                           map_location='cpu')['state_dict'])

        # elif model_name == 'unet++_deep':
        #     args.model = 'unet++'
        #     args.deepsupervision = True
        #     model_weight_path = './logs/unet++_deep_ep1600/cvc/20200312-143345/model_best.pth.tar'
        #     model = get_models(args)
        #     model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])

        elif model_name == 'attention_unet':
            args.model = 'attention_unet_v1'
            args.deepsupervision = False
            model_weight_path = './logs/chaos/attention_unet_v1_ep150_v2/20200403-135445/checkpoint.pth.tar'
            model = get_models(args)
            model.load_state_dict(
                torch.load(model_weight_path,
                           map_location='cpu')['state_dict'])

        elif model_name == 'multires_unet':
            args.model = 'multires_unet'
            args.deepsupervision = False
            model_weight_path = './logs/chaos/multires_unet_ep150_v2/20200403-135549/checkpoint.pth.tar'
            model = get_models(args)
            model.load_state_dict(
                torch.load(model_weight_path,
                           map_location='cpu')['state_dict'])

        else:
            raise NotImplementedError()

        assert os.path.exists(args.save)
        args.model_save_path = os.path.join(args.save, model_name)
        logger = get_logger(args.model_save_path)
        args.save_images = os.path.join(args.model_save_path, "images")
        if not os.path.exists(args.save_images):
            os.mkdir(args.save_images)
        if args.manualSeed is None:
            args.manualSeed = random.randint(1, 10000)
        np.random.seed(args.manualSeed)
        torch.manual_seed(args.manualSeed)
        args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
        args.device = torch.device('cuda' if args.use_cuda else 'cpu')
        if args.use_cuda:
            torch.cuda.manual_seed(args.manualSeed)
            cudnn.benchmark = True
        val_loader = get_dataloder(args, split_flag="valid")
        setting = {k: v for k, v in args._get_kwargs()}
        logger.info(setting)
        logger.info('param size = %fMB', calc_parameters_count(model))

        # init loss
        if args.loss == 'bce':
            criterion = nn.BCELoss()
        elif args.loss == 'bcelog':
            criterion = nn.BCEWithLogitsLoss()
        elif args.loss == "dice":
            criterion = DiceLoss()
        elif args.loss == "softdice":
            criterion = SoftDiceLoss()
        elif args.loss == 'bcedice':
            criterion = BCEDiceLoss()
        else:
            criterion = nn.CrossEntropyLoss()
        if args.use_cuda:
            logger.info("load model and criterion to gpu !")

        model = model.to(args.device)
        criterion = criterion.to(args.device)
        infer(args, model, criterion, val_loader, logger, args.save_images)
Example #8
def main(args):
    #################### init logger ###################################
    args.model_list=["unet","unet++",'attention_unet_v1','multires_unet','r2unet_t3']


    for model_name in args.model_list:
        if model_name == 'unet':
            args.model = 'unet'
            model_weight_path = './logs/unet_ep1600/cvc/20200312-143050/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])
        elif model_name == 'unet++':
            args.model = 'unet++'
            args.deepsupervision = False
            model_weight_path = './logs/unet++_ep1600/cvc/20200312-143358/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])

        elif model_name == 'attention_unet_v1':
            args.model = 'attention_unet_v1'
            model_weight_path = './logs/attention_unet_v1_ep1600/cvc/20200312-143413/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])

        elif model_name == 'multires_unet':
            args.model = 'multires_unet'
            model_weight_path = './logs/multires_unet_ep1600_t2/20200322-194117/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])

        # change bn relu order
        elif model_name == 'multires_unet_align':
            args.model = 'multires_unet'
            model_weight_path = './logs/multires_unet_ep1600_chbnrelu/20200327-184457/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])


        elif model_name == 'r2unet_t3':
            args.model = 'r2unet'
            args.time_step = 3
            model_weight_path = './logs/r2unet_ep1600_t2/20200324-032815/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])


        elif model_name == 'unet_ep800dice':
            args.model = 'unet'
            model_weight_path = './logs/unet_ep800_bcedice/cvc/20200315-043021/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])

        elif model_name == 'unet++_nodeep_ep800dice':
            args.model = 'unet++'
            args.deepsupervision = False
            model_weight_path = './logs/unet++_ep800_bcedice/cvc/20200315-043214/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])
        elif model_name == 'unet++_deep_ep800dice':
            args.model = 'unet++'
            args.deepsupervision = True
            model_weight_path = './logs/unet++_deep_ep800_bcedice/cvc/20200315-043134/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])

        elif model_name == 'attention_unet_v1_ep800dice':
            args.model = 'attention_unet_v1'
            args.deepsupervision = False
            model_weight_path = './logs/attention_unet_v1_ep800_bcedice/cvc/20200315-043300/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])

        elif model_name == 'multires_unet_ep800dice':
            args.model = 'multires_unet'
            args.deepsupervision = False
            model_weight_path = './logs/multires_unet_ep800_bcedice/cvc/20200312-173031/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])

        else:
            raise NotImplementedError()


        assert os.path.exists(args.save)
        args.model_save_path = os.path.join(args.save, model_name)
        logger = get_logger(args.model_save_path)
        args.save_images = os.path.join(args.model_save_path, "images")
        if not os.path.exists(args.save_images):
            os.mkdir(args.save_images)
        if args.manualSeed is None:
            args.manualSeed = random.randint(1, 10000)
        np.random.seed(args.manualSeed)
        torch.manual_seed(args.manualSeed)
        args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
        args.device = torch.device('cuda' if args.use_cuda else 'cpu')
        if args.use_cuda:
            torch.cuda.manual_seed(args.manualSeed)
            cudnn.benchmark = True
        val_loader = get_dataloder(args, split_flag="valid")
        setting = {k: v for k, v in args._get_kwargs()}
        logger.info(setting)
        logger.info('param size = %fMB', calc_parameters_count(model))


        # init loss
        if args.loss == 'bce':
            criterion = nn.BCELoss()
        elif args.loss == 'bcelog':
            criterion = nn.BCEWithLogitsLoss()
        elif args.loss == "dice":
            criterion = DiceLoss()
        elif args.loss == "softdice":
            criterion = SoftDiceLoss()
        elif args.loss == 'bcedice':
            criterion = BCEDiceLoss()
        else:
            criterion = nn.CrossEntropyLoss()
        if args.use_cuda:
            logger.info("load model and criterion to gpu !")

        model = model.to(args.device)
        criterion = criterion.to(args.device)
        infer(args, model, criterion, val_loader, logger, args.save_images)
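
# --- Sketch (not part of the original example): the if/elif ladder above
# varies only in model key, deep-supervision flag, and checkpoint path, so a
# lookup table could drive it. Registry entries are placeholders, not real
# checkpoints.
def load_pretrained(args, model_name, registry):
    """Set args from a registry entry, build the model, and load its weights."""
    entry = registry[model_name]  # a missing key plays the NotImplementedError role
    args.model = entry['model']
    args.deepsupervision = entry.get('deepsupervision', False)
    model = get_models(args)
    state = torch.load(entry['weights'], map_location='cpu')['state_dict']
    model.load_state_dict(state)
    return model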
Example #9
def main(args):
    #################### init logger ###################################
    #args.model_list=["unet","unet++_deep",'attention_unet_v1','multires_unet', 'r2unet_t3']

    args.model_list = [
        "unet", "unet++_deep", 'unet++_nodeep', "attention_unet_v1",
        "multires_unet", "r2unet"
    ]

    for model_name in args.model_list:
        # if model_name=='unet':
        #     args.model='unet'
        #     model_weight_path='./logs/isic/logs_coslr/unet/isic2018/20200229-035150/checkpoint.pth.tar'
        #     model=get_models(args)
        #     model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])
        # elif model_name=='unet++_deep':
        #     args.model='unet++'
        #     args.deepsupervision=True
        #     model_weight_path='./logs/isic/logs_coslr/unet++/isic2018/20200229-035514/checkpoint.pth.tar'
        #     model=get_models(args)
        #     model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])
        # elif model_name == 'unet++_nodeep':
        #     args.model = 'unet++'
        #     args.deepsupervision = False
        #     model_weight_path = '/checkpoint.pth.tar'
        #     model = get_models(args)
        #     model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])
        # elif model_name == 'attention_unet_v1':
        #     args.model = 'attention_unet_v1'
        #     model_weight_path = './logs/isic/logs_coslr/attention_unet_v1/isic2018/20200302-190718/checkpoint.pth.tar'
        #     args.deepsupervision=False
        #     model = get_models(args)
        #     model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])
        #
        # elif model_name == 'multires_unet':
        #     args.model = 'multires_unet'
        #     model_weight_path = './logs/isic/logs_coslr/multires_unet/isic2018/20200229-035734/checkpoint.pth.tar'
        #     model = get_models(args)
        #     model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])
        #
        # elif model_name == 'r2unet_t3':
        #     args.model = 'r2unet'
        #     args.time_step=3
        #     model_weight_path = './logs/isic/logs_coslr/r2unet/isic2018/20200302-190808/checkpoint.pth.tar'
        #     model = get_models(args)
        #     model.load_state_dict(torch.load(model_weight_path, map_location='cpu')['state_dict'])

        # ep300   baseline
        if model_name == 'unet':
            args.model = 'unet'
            model_weight_path = './logs/isic2018/unet_ep300/20200402-135108/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(
                torch.load(model_weight_path,
                           map_location='cpu')['state_dict'])
        elif model_name == 'unet++_deep':
            args.model = 'unet++'
            args.deepsupervision = True
            model_weight_path = './logs/isic2018/unet++_ep300_deep/20200402-135243/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(
                torch.load(model_weight_path,
                           map_location='cpu')['state_dict'])
        elif model_name == 'unet++_nodeep':
            args.model = 'unet++'
            args.deepsupervision = False
            model_weight_path = './logs/isic2018/unet++_ep300/20200402-135317/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(
                torch.load(model_weight_path,
                           map_location='cpu')['state_dict'])

        elif model_name == 'attention_unet_v1':
            args.model = 'attention_unet_v1'
            args.deepsupervision = False
            model_weight_path = './logs/isic2018/attention_unet_v1_ep300/20200413-160808/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(
                torch.load(model_weight_path,
                           map_location='cpu')['state_dict'])

        elif model_name == 'multires_unet':
            args.model = 'multires_unet'
            args.deepsupervision = False
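            # NOTE: the path below repeats the attention_unet_v1 checkpoint
            # above; it looks like a copy-paste placeholder rather than a
            # trained multires_unet checkpoint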
            model_weight_path = './logs/isic2018/attention_unet_v1_ep300/20200413-160808/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(
                torch.load(model_weight_path,
                           map_location='cpu')['state_dict'])
        elif model_name == 'r2unet':
            args.model = 'r2unet'
            args.deepsupervision = False
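            # NOTE: the path below also repeats the attention_unet_v1
            # checkpoint; likely a copy-paste placeholder rather than a
            # trained r2unet checkpoint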
            model_weight_path = './logs/isic2018/attention_unet_v1_ep300/20200413-160808/model_best.pth.tar'
            model = get_models(args)
            model.load_state_dict(
                torch.load(model_weight_path,
                           map_location='cpu')['state_dict'])

        else:
            raise NotImplementedError()

        assert os.path.exists(args.save)
        args.model_save_path = os.path.join(args.save, model_name)
        logger = get_logger(args.model_save_path)
        args.save_images = os.path.join(args.model_save_path, "images")
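        # os.mkdir assumes args.model_save_path already exists (presumably
        # created by get_logger); os.makedirs(args.save_images, exist_ok=True)
        # would be more forgiving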
        if not os.path.exists(args.save_images):
            os.mkdir(args.save_images)

        if args.manualSeed is None:
            args.manualSeed = random.randint(1, 10000)
        np.random.seed(args.manualSeed)
        torch.manual_seed(args.manualSeed)
        args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
        args.device = torch.device('cuda' if args.use_cuda else 'cpu')
        if args.use_cuda:
            torch.cuda.manual_seed(args.manualSeed)
            cudnn.benchmark = True

        val_loader = get_dataloder(args, split_flag="valid")

        setting = {k: v for k, v in args._get_kwargs()}
        logger.info(setting)
        logger.info('param size = %fMB', calc_parameters_count(model))

        # init loss
        if args.loss == 'bce':
            criterion = nn.BCELoss()
        elif args.loss == 'bcelog':
            criterion = nn.BCEWithLogitsLoss()
        elif args.loss == "dice":
            criterion = DiceLoss()
        elif args.loss == "softdice":
            criterion = SoftDiceLoss()
        elif args.loss == 'bcedice':
            criterion = BCEDiceLoss()
        else:
            criterion = nn.CrossEntropyLoss()
        if args.use_cuda:
            logger.info("load model and criterion to gpu !")

        model = model.to(args.device)
        criterion = criterion.to(args.device)
        infer(args, model, criterion, val_loader, logger, args.save_images)
Example #10
def main(args):

    args.model_list = [
        'double_deep', 'double', 'nodouble', 'nodouble_deep', 'slim_dd',
        'slim_double', 'slim_nodouble', 'slim_nodouble_deep'
    ]
    #args.model_list=["slim_nodouble_deep_init32"]
    for model_name in args.model_list:
        print(model_name)
        if model_name == "double_deep":
            args.deepsupervision = True
            args.double_down_channel = True
            args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
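            # eval() resolves the genotype by name; getattr(genotypes,
            # args.genotype_name) would do the same without eval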
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/prune_20200313-063406_32_32_ep300_double_deep/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'double':
            args.deepsupervision = False
            args.double_down_channel = True
            args.genotype_name = 'stage1_layer9_110epoch_double_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/prune_20200313-063428_32_32_ep300_double/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'nodouble':
            args.deepsupervision = False
            args.double_down_channel = False
            args.genotype_name = 'stage1_layer9_110epoch_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/prune_20200316-141125_nodouble_32_ep300/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'nodouble_deep':
            args.deepsupervision = True
            args.double_down_channel = False
            args.genotype_name = 'stage1_layer9_110epoch_deep_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/prune_20200316-141242_nodouble_32_ep300_deep/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        if model_name == "slim_dd":
            args.deepsupervision = True
            args.double_down_channel = True
            args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = net_dd(genotype=genotype,
                           input_c=args.in_channels,
                           c=args.init_channels,
                           num_classes=args.nclass,
                           meta_node_num=args.middle_nodes,
                           layers=args.layers,
                           dp=args.dropout_prob,
                           use_sharing=args.use_sharing,
                           double_down_channel=args.double_down_channel,
                           aux=args.aux)
            args.model_path = './logs/isic2018/dd_20200319-170442_ep300/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'slim_double':
            args.deepsupervision = False
            args.double_down_channel = True
            args.genotype_name = 'stage1_layer9_110epoch_double_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = net_double(genotype=genotype,
                               input_c=args.in_channels,
                               c=args.init_channels,
                               num_classes=args.nclass,
                               meta_node_num=args.middle_nodes,
                               layers=args.layers,
                               dp=args.dropout_prob,
                               use_sharing=args.use_sharing,
                               double_down_channel=args.double_down_channel,
                               aux=args.aux)
            args.model_path = './logs/isic2018/double_20200319-170621_ep300/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'slim_nodouble':
            args.deepsupervision = False
            args.double_down_channel = False
            args.genotype_name = 'stage1_layer9_110epoch_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = net_nodouble(genotype=genotype,
                                 input_c=args.in_channels,
                                 c=args.init_channels,
                                 num_classes=args.nclass,
                                 meta_node_num=args.middle_nodes,
                                 layers=args.layers,
                                 dp=args.dropout_prob,
                                 use_sharing=args.use_sharing,
                                 double_down_channel=args.double_down_channel,
                                 aux=args.aux)
            args.model_path = './logs/isic2018/nodouble_20200319-210910_ep300/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'slim_nodouble_deep':
            args.deepsupervision = True
            args.double_down_channel = False
            args.genotype_name = 'stage1_layer9_110epoch_deep_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = net_nodouble_deep(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/nodouble_deep_20200319-210600_ep300/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == 'slim_nodouble_deep_init32':
            args.deepsupervision = True
            args.double_down_channel = False
            args.genotype_name = 'stage1_layer9_110epoch_deep_final'
            genotype = eval('genotypes.%s' % args.genotype_name)
            model = net_nodouble_deep(
                genotype=genotype,
                input_c=args.in_channels,
                c=32,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/nodouble_deep_ep300_init32/model_best.pth.tar'
            model.load_state_dict(
                torch.load(args.model_path, map_location='cpu')['state_dict'])

        #################### init logger ###################################
        log_dir = './eval' + '/{}'.format(
            args.dataset) + '/{}'.format(model_name)
        logger = get_logger(log_dir)
        print('RUNDIR: {}'.format(log_dir))
        logger.info('{}-Eval'.format(model_name))
        # setting
        args.save_path = log_dir
        args.save_images = os.path.join(args.save_path, "images")
        if not os.path.exists(args.save_images):
            os.mkdir(args.save_images)
        ##################### init device #################################
        if args.manualSeed is None:
            args.manualSeed = random.randint(1, 10000)
        np.random.seed(args.manualSeed)
        torch.manual_seed(args.manualSeed)
        args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
        args.device = torch.device('cuda' if args.use_cuda else 'cpu')
        if args.use_cuda:
            torch.cuda.manual_seed(args.manualSeed)
            cudnn.benchmark = True
        ####################### init dataset ###########################################
        # sorted validation dataset
        val_loader = get_dataloder(args, split_flag="valid")
        setting = {k: v for k, v in args._get_kwargs()}
        logger.info(setting)
        logger.info(genotype)
        logger.info('param size = %fMB', calc_parameters_count(model))

        # init loss
        if args.loss == 'bce':
            criterion = nn.BCELoss()
        elif args.loss == 'bcelog':
            criterion = nn.BCEWithLogitsLoss()
        elif args.loss == "dice":
            criterion = DiceLoss()
        elif args.loss == "softdice":
            criterion = SoftDiceLoss()
        elif args.loss == 'bcedice':
            criterion = BCEDiceLoss()
        else:
            criterion = nn.CrossEntropyLoss()
        if args.use_cuda:
            logger.info("load model and criterion to gpu !")
        model = model.to(args.device)
        criterion = criterion.to(args.device)
        # evaluate, mirroring the preceding examples (the call appears to have
        # been truncated in the source listing)
        infer(args, model, criterion, val_loader, logger, args.save_images)
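
# --- Sketch (not part of the original example): each branch above differs
# only in builder class, genotype name, and two flags, so a table can replace
# the chain. Builder names reuse those already used above; entries are
# illustrative, not exhaustive.
SLIM_TABLE = {
    'slim_dd': (net_dd, 'stage1_layer9_110epoch_double_deep_final', True, True),
    'slim_double': (net_double, 'stage1_layer9_110epoch_double_final', False, True),
}

def build_slim(args, model_name):
    builder, genotype_name, deep, double = SLIM_TABLE[model_name]
    args.deepsupervision, args.double_down_channel = deep, double
    genotype = getattr(genotypes, genotype_name)  # avoids eval()
    return builder(genotype=genotype,
                   input_c=args.in_channels,
                   c=args.init_channels,
                   num_classes=args.nclass,
                   meta_node_num=args.middle_nodes,
                   layers=args.layers,
                   dp=args.dropout_prob,
                   use_sharing=args.use_sharing,
                   double_down_channel=args.double_down_channel,
                   aux=args.aux)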
Example #11
def main(args):

    #################### init logger ###################################
    log_dir = './logs/' + '{}'.format(args.dataset) + '/{}_{}_{}'.format(
        args.model, args.note, time.strftime('%Y%m%d-%H%M%S'))

    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Train'.format(args.model))
    # setting
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    args.save_path = log_dir
    args.save_tbx_log = args.save_path + '/tbx_log'
    writer = SummaryWriter(args.save_tbx_log)
    ##################### init device #################################
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.benchmark = True
    ####################### init dataset ###########################################
    train_loader = get_dataloder(args, split_flag="train")
    val_loader = get_dataloder(args, split_flag="valid")
    ######################## init model ############################################
    if args.model == "nodouble_deep_init32_ep100":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'nodouble_deep_init32_ep100'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=32,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=9,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "nodouble_deep_isic":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_deep_final'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "nodouble_deep_drop02_layer7end":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'nodouble_deep_drop02_layer7end'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "stage1_nodouble_deep_ep36":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_nodouble_deep_ep36'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "stage1_nodouble_deep_ep63":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_nodouble_deep_ep63'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)
    elif args.model == "stage1_nodouble_deep_ep83":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_nodouble_deep_ep83'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha1_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha1_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_5_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    # isic trans
    elif args.model == "stage1_layer9_110epoch_double_deep_final":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    #chaos
    elif args.model == "stage0_double_deep_ep80_newim":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage0_double_deep_ep80_newim'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_double_deep_ep80'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "stage1_double_deep_ep80_ts":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_double_deep_ep80_ts'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    # cvc trans
    elif args.model == "layer7_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'layer7_double_deep'
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    if torch.cuda.device_count() > 1 and args.use_cuda:
        logger.info('use: %d gpus', torch.cuda.device_count())
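        # DataParallel replicates the model and splits each input batch across
        # the visible GPUs; model.to(args.device) later places the wrapped
        # module on the primary device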
        model = nn.DataParallel(model)
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    logger.info(genotype)
    logger.info('param size = %fMB', calc_parameters_count(model))

    # init loss
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    elif args.loss == 'multibcedice':
        criterion = MultiClassEntropyDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("load model and criterion to gpu !")
        model = model.to(args.device)
        criterion = criterion.to(args.device)
    # init optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum)
    # init scheduler: cosine annealing over args.epoch epochs (a StepLR
    # alternative is kept below)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epoch)
    # scheduler=torch.optim.lr_scheduler.StepLR(optimizer=optimizer,step_size=30,gamma=0.1,last_epoch=-1)
    ############################### check resume #########################
    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume, map_location=args.device)
            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            model.load_state_dict(checkpoint['state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(
                args.resume))

    #################################### train and val ########################

    max_value = 0
    for epoch in range(start_epoch, args.epoch):
        # lr=adjust_learning_rate(args,optimizer,epoch)
        # logger.info('Epoch: %d lr %e', epoch, scheduler.get_lr()[0])
        # train
        total_loss = train(args, model, criterion, train_loader, optimizer,
                           epoch, logger)
        # step the LR scheduler once per epoch, after the optimizer updates
        # (stepping before training skips the initial LR since PyTorch 1.1)
        scheduler.step()
        # write
        writer.add_scalar('Train/total_loss', total_loss, epoch)
        # val
        tloss, md = val(args, model, criterion, val_loader, epoch, logger)
        writer.add_scalar('Val/total_loss', tloss, epoch)

        is_best = md >= max_value
        max_value = max(max_value, md)
        state = {
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'state_dict': model.state_dict(),
            # save the scheduler's own state (the original mistakenly stored
            # model.state_dict() here, which breaks the resume path above)
            'scheduler': scheduler.state_dict(),
        }
        logger.info("epoch:{} best:{} max_value:{}".format(
            epoch, is_best, max_value))
        # always refresh the rolling checkpoint; snapshot the best separately
        torch.save(state, os.path.join(args.save_path, "checkpoint.pth.tar"))
        if is_best:
            torch.save(state, os.path.join(args.save_path,
                                           "model_best.pth.tar"))
    writer.close()
Example #12
def main(args):

    #################### init logger ###################################
    log_dir = './logs/' + '{}'.format(args.dataset) + '/{}_{}_{}'.format(
        args.model, time.strftime('%Y%m%d-%H%M%S'), args.note)
    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Train'.format(args.model))
    # setting
    args.save_path = log_dir
    args.save_tbx_log = args.save_path + '/tbx_log'
    writer = SummaryWriter(args.save_tbx_log)
    ##################### init device #################################
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.benchmark = True
    ####################### init dataset ###########################################
    train_loader = get_dataloder(args, split_flag="train")
    val_loader = get_dataloder(args, split_flag="valid")
    ######################## init model ############################################
    # model
    # get the network parameters
    if args.model == "alpha_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
        args.alphas_model = './search_exp/Nas_Search_Unet/isic2018/deepsupervision/stage_1_model/checkpoint.pth.tar'
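        # load the searched network-level alpha weights, freeze them, and
        # softmax-normalize them so cell outputs can be fused as a weighted sum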
        model_alphas = torch.load(
            args.alphas_model,
            map_location=args.device)['alphas_dict']['alphas_network']
        model_alphas.requires_grad = False
        model_alphas = F.softmax(model_alphas, dim=-1)
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnet(genotype=genotype,
                             input_c=args.in_channels,
                             c=args.init_channels,
                             num_classes=args.nclass,
                             meta_node_num=args.middle_nodes,
                             layers=args.layers,
                             dp=args.dropout_prob,
                             use_sharing=args.use_sharing,
                             double_down_channel=args.double_down_channel,
                             aux=args.aux)

    elif args.model == "alpha_double":
        args.deepsupervision = False
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_final'
        args.alphas_model = './search_exp/Nas_Search_Unet/isic2018/nodeepsupervision/stage_1_model/checkpoint.pth.tar'
        model_alphas = torch.load(
            args.alphas_model,
            map_location=args.device)['alphas_dict']['alphas_network']
        model_alphas.requires_grad = False
        model_alphas = F.softmax(model_alphas, dim=-1)
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnet(genotype=genotype,
                             input_c=args.in_channels,
                             c=args.init_channels,
                             num_classes=args.nclass,
                             meta_node_num=args.middle_nodes,
                             layers=args.layers,
                             dp=args.dropout_prob,
                             use_sharing=args.use_sharing,
                             double_down_channel=args.double_down_channel,
                             aux=args.aux)

    elif args.model == "alpha_nodouble":
        args.deepsupervision = False
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_final'
        args.alphas_model = './search_exp/Nas_Search_Unet/isic2018/nodouble/stage_1_model/checkpoint.pth.tar'
        model_alphas = torch.load(
            args.alphas_model,
            map_location=args.device)['alphas_dict']['alphas_network']
        model_alphas.requires_grad = False
        model_alphas = F.softmax(model_alphas, dim=-1)
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnet(genotype=genotype,
                             input_c=args.in_channels,
                             c=args.init_channels,
                             num_classes=args.nclass,
                             meta_node_num=args.middle_nodes,
                             layers=args.layers,
                             dp=args.dropout_prob,
                             use_sharing=args.use_sharing,
                             double_down_channel=args.double_down_channel,
                             aux=args.aux)

    elif args.model == "alpha_nodouble_deep":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_deep_final'
        args.alphas_model = './search_exp/Nas_Search_Unet/isic2018/nodouble_deep/stage_1_model/checkpoint.pth.tar'
        model_alphas = torch.load(
            args.alphas_model,
            map_location=args.device)['alphas_dict']['alphas_network']
        model_alphas.requires_grad = False
        model_alphas = F.softmax(model_alphas, dim=-1)
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnet(genotype=genotype,
                             input_c=args.in_channels,
                             c=args.init_channels,
                             num_classes=args.nclass,
                             meta_node_num=args.middle_nodes,
                             layers=args.layers,
                             dp=args.dropout_prob,
                             use_sharing=args.use_sharing,
                             double_down_channel=args.double_down_channel,
                             aux=args.aux)

    elif args.model == "double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_deep_final'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "double":
        args.deepsupervision = False
        args.double_down_channel = True
        args.genotype_name = 'stage1_layer9_110epoch_double_final'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "nodouble":
        args.deepsupervision = False
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_final'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "nodouble_deep":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'stage1_layer9_110epoch_deep_final'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha1_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha1_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_5_stage1_double_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_5_stage1_double_nodeep_ep80":
        args.deepsupervision = False
        args.double_down_channel = True
        args.genotype_name = 'alpha0_5_stage1_double_nodeep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_5_stage1_nodouble_deep_ep80":
        args.deepsupervision = True
        args.double_down_channel = False
        args.genotype_name = 'alpha0_5_stage1_nodouble_deep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    elif args.model == "alpha0_5_stage1_nodouble_nodeep_ep80":
        args.deepsupervision = False
        args.double_down_channel = False
        args.genotype_name = 'alpha0_5_stage1_nodouble_nodeep_ep80'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    # cvc trans
    elif args.model == "layer7_double_deep":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'layer7_double_deep'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    # chaos trans
    elif args.model == "stage0_double_deep_ep80_newim":
        args.deepsupervision = True
        args.double_down_channel = True
        args.genotype_name = 'stage0_double_deep_ep80_newim'
        model_alphas = None
        genotype = eval('genotypes.%s' % args.genotype_name)
        model = BuildNasUnetPrune(genotype=genotype,
                                  input_c=args.in_channels,
                                  c=args.init_channels,
                                  num_classes=args.nclass,
                                  meta_node_num=args.middle_nodes,
                                  layers=args.layers,
                                  dp=args.dropout_prob,
                                  use_sharing=args.use_sharing,
                                  double_down_channel=args.double_down_channel,
                                  aux=args.aux)

    if torch.cuda.device_count() > 1 and args.use_cuda:
        logger.info('use: %d gpus', torch.cuda.device_count())
        model = nn.DataParallel(model)

    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    logger.info(genotype)
    logger.info(model_alphas)
    flop, param = get_model_complexity_info(model, (3, 256, 256),
                                            as_strings=True,
                                            print_per_layer_stat=False)
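    # note: complexity tools with this API typically report multiply-accumulate
    # counts (e.g. "GMac") as strings, so the "GFLOPs" label below is approximate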
    print("GFLOPs: {}".format(flop))
    print("Params: {}".format(param))
    # init loss
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("load model and criterion to gpu !")
    model = model.to(args.device)
    criterion = criterion.to(args.device)
    # init optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum)
    # init scheduler: cosine annealing over args.epoch epochs (a StepLR
    # alternative is kept below)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epoch)
    #scheduler=torch.optim.lr_scheduler.StepLR(optimizer=optimizer,step_size=30,gamma=0.1,last_epoch=-1)

    ############################### check resume #########################
    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume, map_location=args.device)
            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            model.load_state_dict(checkpoint['state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(
                args.resume))

    #################################### train and val ########################
    max_value = 0
    for epoch in range(start_epoch, args.epoch):
        # lr=adjust_learning_rate(args,optimizer,epoch)
        # train
        if args.deepsupervision:
            mean_loss, value1, value2 = train(args, model_alphas, model,
                                              criterion, train_loader,
                                              optimizer)
            mr, ms, mp, mf, mjc, md, macc = value1
            # value2 holds the deep-supervision head's metrics; assuming it
            # mirrors value1's tuple, its Dice score is the sixth entry.
            mmr, mms, mmp, mmf, mmjc, mmd, mmacc = value2
            logger.info(
                "Epoch:{} Train_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".
                format(epoch, mean_loss, macc, md, mjc))
            writer.add_scalar('Train/dDice', mmd, epoch)
        else:
            mean_loss, value1 = train(args, model_alphas, model, criterion,
                                      train_loader, optimizer)
            mr, ms, mp, mf, mjc, md, macc = value1
            logger.info(
                "Epoch:{} Train_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".
                format(epoch, mean_loss, macc, md, mjc))
        # write
        writer.add_scalar('Train/Loss', mean_loss, epoch)

        # val
        if args.deepsupervision:
            vmean_loss, valuev1, valuev2 = infer(args, model_alphas, model,
                                                 criterion, val_loader)
            vmr, vms, vmp, vmf, vmjc, vmd, vmacc = valuev1
            logger.info(
                "Epoch:{} Val_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".
                format(epoch, vmean_loss, vmacc, vmd, vmjc))

        else:
            vmean_loss, valuev1 = infer(args, model_alphas, model, criterion,
                                        val_loader)
            vmr, vms, vmp, vmf, vmjc, vmd, vmacc = valuev1
            logger.info(
                "Epoch:{} Val_Loss:{:.3f} Acc:{:.3f} Dice:{:.3f} Jc:{:.3f}".
                format(epoch, vmean_loss, vmacc, vmd, vmjc))

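        # the validation Jaccard index decides which checkpoint becomes model_best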
        is_best = vmjc >= max_value
        max_value = max(max_value, vmjc)
        writer.add_scalar('Val/Loss', vmean_loss, epoch)

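        # snapshot everything the resume branch above reads back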
        state = {
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'state_dict': model.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        logger.info("epoch:{} best:{} max_value:{}".format(
            epoch, is_best, max_value))
        torch.save(state, os.path.join(args.save_path, "checkpoint.pth.tar"))
        if is_best:
            torch.save(state, os.path.join(args.save_path,
                                           "model_best.pth.tar"))

    writer.close()
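The loop above logs Dice and Jaccard every epoch, but the metric computation lives inside train()/infer(), which are not part of this listing. The sketch below shows one common way to compute those two scores for binary segmentation; the helper name, the sigmoid, and the 0.5 threshold are illustrative assumptions, not this repository's actual implementation.

import torch

def dice_and_jaccard(logits, target, eps=1e-7):
    # logits: raw model output; target: binary mask of the same shape.
    # Illustrative only -- the project's train()/infer() may differ.
    pred = (torch.sigmoid(logits) > 0.5).float()
    inter = (pred * target).sum()
    total = pred.sum() + target.sum()
    dice = (2 * inter + eps) / (total + eps)         # 2|A∩B| / (|A|+|B|)
    jaccard = (inter + eps) / (total - inter + eps)  # |A∩B| / |A∪B|
    return dice.item(), jaccard.item()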
Example #13
def main(args):

    #args.model_list=['alpha0_double_deep_0.01','alpha0_5_double_deep_0.01','alpha1_double_deep_0.01','nodouble_deep','slim_dd','slim_double','slim_nodouble','slim_nodouble_deep']
    #args.model_list=["double_deep","nodouble_deep","slim_nodouble"]
    #args.model_list=["slim_nodouble_deep_init32"]
    #args.model_list=["slim_nodouble_deep_init48"]
    args.model_list = [
        'alpha0_double_deep_0.01', 'alpha0_5_double_deep_0.01',
        'alpha1_double_deep_0.01'
    ]

    for model_name in args.model_list:
        if model_name == "alpha0_double_deep_0.01":
            args.deepsupervision = True
            args.double_down_channel = True
            args.genotype_name = 'alpha0_stage1_double_deep_ep200'
            genotype = getattr(genotypes, args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/alpha0_double_deep_0.01/model_best.pth.tar'
            # kwargs = {'map_location': lambda storage, loc: storage.cuda(0)}
            # state_dict = torch.load(args.model_path, **kwargs)
            # # create new OrderedDict that does not contain `module.`
            # model.load_state_dict(state_dict)

            state_dict = torch.load(args.model_path,
                                    map_location='cpu')['state_dict']
            state_dict = remove_module(state_dict)
            model.load_state_dict(state_dict)

        elif model_name == "alpha0_5_double_deep_0.01":
            args.deepsupervision = True
            args.double_down_channel = True
            args.genotype_name = 'alpha0_5_stage1_double_deep_ep80'
            genotype = getattr(genotypes, args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/alpha0_5_double_deep_0.01/model_best.pth.tar'
            state_dict = torch.load(args.model_path,
                                    map_location='cpu')['state_dict']
            state_dict = remove_module(state_dict)
            model.load_state_dict(state_dict)
            #model.load_state_dict(torch.load(args.model_path, map_location='cpu')['state_dict'])

        elif model_name == "alpha1_double_deep_0.01":
            args.deepsupervision = True
            args.double_down_channel = True
            args.genotype_name = 'alpha1_stage1_double_deep_ep200'
            genotype = getattr(genotypes, args.genotype_name)
            model = BuildNasUnetPrune(
                genotype=genotype,
                input_c=args.in_channels,
                c=args.init_channels,
                num_classes=args.nclass,
                meta_node_num=args.middle_nodes,
                layers=args.layers,
                dp=args.dropout_prob,
                use_sharing=args.use_sharing,
                double_down_channel=args.double_down_channel,
                aux=args.aux)
            args.model_path = './logs/isic2018/alpha1_double_deep_0.01/model_best.pth.tar'
            state_dict = torch.load(args.model_path,
                                    map_location='cpu')['state_dict']
            state_dict = remove_module(state_dict)
            model.load_state_dict(state_dict)

        #################### init logger ###################################
        log_dir = './eval/{}/{}'.format(args.dataset, model_name)
        logger = get_logger(log_dir)
        print('RUNDIR: {}'.format(log_dir))
        logger.info('{}-Eval'.format(model_name))
        # setting
        args.save_path = log_dir
        args.save_images = os.path.join(args.save_path, "images")
        if not os.path.exists(args.save_images):
            os.makedirs(args.save_images)
        ##################### init device #################################
        if args.manualSeed is None:
            args.manualSeed = random.randint(1, 10000)
        np.random.seed(args.manualSeed)
        torch.manual_seed(args.manualSeed)
        args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
        args.device = torch.device('cuda' if args.use_cuda else 'cpu')
        if args.use_cuda:
            torch.cuda.manual_seed(args.manualSeed)
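            # cudnn.benchmark lets cuDNN auto-tune convolution algorithms, which
            # pays off when input sizes stay fixed across batches.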
            cudnn.benchmark = True
        ####################### init dataset ###########################################
        # sorted valid datasets
        val_loader = get_dataloder(args, split_flag="valid")
        setting = {k: v for k, v in args._get_kwargs()}
        logger.info(setting)
        logger.info(genotype)
        logger.info('param size = %fMB', calc_parameters_count(model))

        # init loss
        if args.loss == 'bce':
            criterion = nn.BCELoss()
        elif args.loss == 'bcelog':
            criterion = nn.BCEWithLogitsLoss()
        elif args.loss == "dice":
            criterion = DiceLoss()
        elif args.loss == "softdice":
            criterion = SoftDiceLoss()
        elif args.loss == 'bcedice':
            criterion = BCEDiceLoss()
        else:
            criterion = nn.CrossEntropyLoss()
        if args.use_cuda:
            logger.info("load model and criterion to gpu !")
        model = model.to(args.device)
        criterion = criterion.to(args.device)
        infer(args, model, criterion, val_loader, logger, args.save_images)
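
Each evaluation branch above calls remove_module() on the loaded state dict, but that helper is defined elsewhere in the repository. A minimal sketch of what it presumably does -- stripping the 'module.' prefix that nn.DataParallel prepends to parameter names -- follows; treat it as a reconstruction, not the project's exact code.

from collections import OrderedDict

def remove_module(state_dict):
    # Drop a leading 'module.' (added by nn.DataParallel) from every key so
    # the weights fit an unwrapped model. Hypothetical reconstruction.
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        cleaned[key[7:] if key.startswith('module.') else key] = value
    return cleaned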