Esempio n. 1
0
 def check_class(line):
     """Check whether the first word of *line* names a known model class.

     Prints a diagnostic message and returns False when *line* contains no
     class name (empty or whitespace-only) or when its first word is not a
     key of ``models_dict``. Returns True otherwise.
     """
     # Split before testing for emptiness: a whitespace-only line splits to
     # an empty list, and indexing it would raise IndexError.
     tokens = line.split()
     if not tokens:
         print("** class name missing **")
         return False
     if tokens[0] not in models_dict:
         print("** class doesn't exist **")
         return False
     return True
Esempio n. 2
0
def main(args):
    """Run the full train/validate loop described by *args*.

    Sets up logging and a TensorBoard writer under ``./logs/``, seeds the
    random number generators, builds the data loaders, model, loss,
    optimizer and cosine-annealing scheduler, optionally resumes from
    ``args.resume``, then trains and validates once per epoch. A checkpoint
    is written after every epoch, and a copy is written as
    ``model_best.pth.tar`` whenever the value logged as ``Val/Jc`` is at
    least as large as the best seen so far in this run.

    Args:
        args: Parsed command-line namespace. Reads ``model``, ``note``,
            ``manualSeed``, ``gpus``, ``loss``, ``model_optimizer``, ``lr``,
            ``weight_decay``, ``momentum``, ``beta1``, ``beta2``, ``epoch``
            and ``resume``. The function also writes ``save_path``,
            ``save_tbx_log``, ``manualSeed`` (when it was None),
            ``use_cuda`` and ``device`` back onto it.

    Raises:
        AssertionError: If ``args.model`` is not a key of ``models_dict``.
        FileNotFoundError: If ``args.resume`` is set but is not a file.
    """
    ############    init config ################
    model_name = args.model
    assert model_name in models_dict.keys(),"The Usage model is not exist !"
    print('Usage model :{}'.format(model_name))

    #################### init logger ###################################
    log_dir = './logs/'+ args.model+'_'+args.note + '/{}'.format(time.strftime('%Y%m%d-%H%M%S'))
    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Train'.format(args.model))
    # Record the full run configuration in the log.
    setting = dict(args._get_kwargs())
    logger.info(setting)
    args.save_path = log_dir
    args.save_tbx_log = args.save_path + '/tbx_log'
    writer = SummaryWriter(args.save_tbx_log)
    # Everything below runs under try/finally so the writer is closed even
    # when setup or training raises.
    try:
        ##################### init device #################################
        if args.manualSeed is None:
            args.manualSeed = random.randint(1, 10000)
        np.random.seed(args.manualSeed)
        torch.manual_seed(args.manualSeed)
        args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
        args.device = torch.device('cuda' if args.use_cuda else 'cpu')
        if args.use_cuda:
            torch.cuda.manual_seed(args.manualSeed)
            cudnn.benchmark = True
        ####################### init dataset ###########################################
        train_loader = get_dataloder(args, split_flag="train")
        val_loader = get_dataloder(args, split_flag="valid")
        ######################## init model ############################################
        logger.info("Model Dict has keys: \n {}".format(models_dict.keys()))
        model = get_models(args)
        if torch.cuda.device_count() > 1 and args.use_cuda:
            logger.info('use: %d gpus', torch.cuda.device_count())
            model = nn.DataParallel(model)
        logger.info('param size = %fMB', calc_parameters_count(model))
        # init loss; any unrecognised name falls back to cross-entropy
        if args.loss == 'bce':
            criterion = nn.BCELoss()
        elif args.loss == 'bcelog':
            criterion = nn.BCEWithLogitsLoss()
        elif args.loss == "dice":
            criterion = DiceLoss()
        elif args.loss == "softdice":
            criterion = SoftDiceLoss()
        elif args.loss == 'bcedice':
            criterion = BCEDiceLoss()
        else:
            criterion = nn.CrossEntropyLoss()
        if args.use_cuda:
            logger.info("load model and criterion to gpu !")
            model = model.to(args.device)
            criterion = criterion.to(args.device)
        # init optimizer; anything other than "sgd" selects Adam
        if args.model_optimizer == "sgd":
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                        weight_decay=args.weight_decay,
                                        momentum=args.momentum)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                         betas=(args.beta1, args.beta2),
                                         weight_decay=args.weight_decay)

        # init scheduler: cosine annealing over the full epoch budget
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch)
        ############################### check resume #########################
        start_epoch = 0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise FileNotFoundError("No checkpoint found at '{}'".format(args.resume))
            logger.info("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=args.device)
            # NOTE(review): the checkpoint stores the epoch that had just
            # finished, so resuming here repeats that epoch -- confirm
            # whether `+ 1` is intended before changing it.
            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            model.load_state_dict(checkpoint['state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler'])

        #################################### train and val ########################
        max_value = 0
        checkpoint_path = os.path.join(args.save_path, "checkpoint.pth.tar")
        best_path = os.path.join(args.save_path, "model_best.pth.tar")
        for epoch in range(start_epoch, args.epoch):
            # NOTE(review): recent PyTorch releases expect scheduler.step()
            # after the epoch's optimizer steps; the original ordering is
            # kept here -- confirm against the targeted PyTorch version.
            scheduler.step()
            # Read the learning rate from the optimizer itself: it is the
            # value actually applied, whereas scheduler.get_lr() is not
            # reliable when called outside step() on newer PyTorch.
            logger.info('Epoch: %d lr %e', epoch, optimizer.param_groups[0]['lr'])
            # train
            mr, ms, mp, mf, mjc, md, macc, mean_loss = train(
                args, model, criterion, train_loader, optimizer, epoch, logger)
            train_scalars = (
                ('Loss', mean_loss), ('mAcc', macc), ('Recall', mr),
                ('Specifi', ms), ('Precision', mp), ('F1', mf),
                ('Jc', mjc), ('Dice', md),
            )
            for tag, value in train_scalars:
                writer.add_scalar('Train/' + tag, value, epoch)

            # val
            vmr, vms, vmp, vmf, vmjc, vmd, vmacc, vmean_loss = val(
                args, model, criterion, val_loader, epoch, logger)
            val_scalars = (
                ('Loss', vmean_loss), ('mAcc', vmacc), ('Recall', vmr),
                ('Specifi', vms), ('Precision', vmp), ('F1', vmf),
                ('Jc', vmjc), ('Dice', vmd),
            )
            for tag, value in val_scalars:
                writer.add_scalar('Val/' + tag, value, epoch)

            # Ties count as "best", so the latest equally-good model wins.
            is_best = bool(vmjc >= max_value)
            max_value = max(max_value, vmjc)
            state = {
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'state_dict': model.state_dict(),
                # Must be the scheduler's own state: the resume path feeds
                # this entry to scheduler.load_state_dict().
                'scheduler': scheduler.state_dict(),
            }
            logger.info("epoch:{} best:{} max_value:{}".format(epoch, is_best, max_value))
            # The rolling checkpoint is always refreshed; the best model is
            # an additional copy.
            torch.save(state, checkpoint_path)
            if is_best:
                torch.save(state, best_path)
    finally:
        writer.close()
Esempio n. 3
0
def adjust_learning_rate(args,optimizer, epoch):
    """Apply a step-wise learning-rate schedule to *optimizer*.

    Starting from ``args.lr``, the rate is multiplied by each entry of
    ``args.gammas`` whose paired entry in ``args.schedule`` has been reached
    (``epoch >= step``), stopping at the first step not yet reached. The
    result is written to every parameter group of *optimizer* and returned.
    """
    new_lr = args.lr
    assert len(args.gammas) == len(args.schedule), "length of gammas and schedule should be equal"
    for factor, milestone in zip(args.gammas, args.schedule):
        # Stop at the first milestone still ahead of the current epoch.
        if epoch < milestone:
            break
        new_lr = new_lr * factor
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr


if __name__ == '__main__':
    models_name=models_dict.keys()
    datasets_name=datasets_dict.keys()
    parser = argparse.ArgumentParser(description='Unet serieas baseline')
    # Add default argument
    parser.add_argument('--model',  type=str, default='unet',choices=models_name,
                        help='Model to train and evaluation')
    parser.add_argument('--note' ,type=str, default='_',
                        help='model note ')
    parser.add_argument('--dataset',type=str, default='cvc',choices=datasets_name,
                        help='Model to train and evaluation')
    parser.add_argument('--base_size', type=int, default=256, help="resize base size")
    parser.add_argument('--crop_size', type=int, default=256, help="crop  size")
    parser.add_argument('--im_channel', type=int, default=3, help="input image channel ")
    parser.add_argument('--class_num', type=int, default=1, help="output feature channel")
    parser.add_argument('--epoch', type=int, default=1600, help="epochs")
    parser.add_argument('--train_batch', type=int, default=8, help="train_batch")