Example No. 1
0
def Train():
    """Train the CXRNet classifier, using a frozen, pre-trained UNet to
    produce a lung-segmentation mask that is fed to CXRNet alongside the image.

    Reads `config` / `args` from module scope; saves the best checkpoint
    (by mean validation AUROC) to config['CKPT_PATH'] + 'best_model_CXRNet.pkl'.
    Returns None.
    """
    print('********************load data********************')
    dataloader_train, dataloader_val = get_train_val_dataloader(
        batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model = CXRNet(num_classes=N_CLASSES,
                       is_pre_trained=True).cuda()  #initialize model
        optimizer_model = optim.Adam(model.parameters(),
                                     lr=1e-3,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-5)
        # NOTE(review): gamma=1 means the learning rate never actually decays
        # at each step_size boundary — confirm this is intentional.
        lr_scheduler_model = lr_scheduler.StepLR(optimizer_model,
                                                 step_size=10,
                                                 gamma=1)
        torch.backends.cudnn.benchmark = True  # improve train speed slightly
        bce_criterion = nn.BCELoss()  #define binary cross-entropy loss

        # Frozen mask generator: load the best UNet weights if available.
        model_unet = UNet(n_channels=3, n_classes=1).cuda()  #initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet.load_state_dict(checkpoint)  #strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet.eval()  # inference only — never trained here
    else:
        print('No required model')
        return  #over
    print('********************load model succeed!********************')

    print('********************begin training!********************')
    AUROC_best = 0.50
    for epoch in range(config['MAX_EPOCHS']):
        since = time.time()
        print('Epoch {}/{}'.format(epoch + 1, config['MAX_EPOCHS']))
        print('-' * 10)
        train_loss = []
        model.train()  #set model to training mode
        with torch.autograd.enable_grad():
            for batch_idx, (image, label) in enumerate(dataloader_train):
                optimizer_model.zero_grad()
                var_image = image.cuda()
                var_label = label.cuda()
                # The UNet is frozen: run it without autograd so backward()
                # does not waste time/memory building a graph through it.
                with torch.no_grad():
                    var_mask = model_unet(var_image)
                var_output = model(var_image, var_mask)  #forward
                loss_tensor = bce_criterion(var_output, var_label)
                loss_tensor.backward()
                optimizer_model.step()
                train_loss.append(loss_tensor.item())
                sys.stdout.write(
                    '\r Epoch: {} / Step: {} : train BCE loss = {}'.format(
                        epoch + 1, batch_idx + 1,
                        float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
        lr_scheduler_model.step()  #about lr and gamma
        print("\r Epoch: %5d train loss = %.6f" %
              (epoch + 1, np.mean(train_loss)))

        model.eval()  #turn to test mode
        val_loss = []
        gt = torch.FloatTensor().cuda()
        pred = torch.FloatTensor().cuda()
        with torch.no_grad():
            for batch_idx, (image, label) in enumerate(dataloader_val):
                gt = torch.cat((gt, label.cuda()), 0)
                var_image = image.cuda()
                var_label = label.cuda()
                var_mask = model_unet(var_image)
                var_output = model(var_image, var_mask)  #forward
                loss_tensor = bce_criterion(var_output, var_label)
                pred = torch.cat((pred, var_output.data), 0)
                val_loss.append(loss_tensor.item())
                sys.stdout.write(
                    '\r Epoch: {} / Step: {} : validation loss = {}'.format(
                        epoch + 1, batch_idx + 1,
                        float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
        #evaluation: mean AUROC across classes
        AUROCs_avg = np.array(compute_AUCs(gt, pred)).mean()
        logger.info(
            "\r Epoch: %5d validation loss = %.6f, Validation AUROC image=%.4f"
            % (epoch + 1, np.mean(val_loss), AUROCs_avg))

        #save checkpoint whenever validation AUROC improves
        if AUROC_best < AUROCs_avg:
            AUROC_best = AUROCs_avg
            torch.save(
                model.state_dict(), config['CKPT_PATH'] +
                'best_model_CXRNet.pkl')  #Saving torch.nn.DataParallel Models
            print(' Epoch: {} model has been already save!'.format(epoch + 1))

        time_elapsed = time.time() - since
        print('Training epoch: {} completed in {:.0f}m {:.0f}s'.format(
            epoch + 1, time_elapsed // 60, time_elapsed % 60))
Example No. 2
0
def Train():
    """Train CXRNet on either the NIH-CXR or Vin-CXR dataset (chosen by
    `args.dataset`), resuming from a previous best checkpoint if one exists.

    Uses nn.DataParallel for multi-GPU training; saves the best checkpoint
    (by AUROC) to config['CKPT_PATH'] + '<model>_<dataset>_best.pkl'.
    Reads `config` / `args` from module scope. Returns None.
    """
    print('********************load data********************')
    if args.dataset == 'NIHCXR':
        dataloader_train = get_train_dataloader_NIH(
            batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
        dataloader_val = get_test_dataloader_NIH(
            batch_size=config['BATCH_SIZE'], shuffle=False, num_workers=8)
    elif args.dataset == 'VinCXR':
        dataloader_train = get_train_dataloader_VIN(
            batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
        dataloader_val = get_val_dataloader_VIN(
            batch_size=config['BATCH_SIZE'], shuffle=False, num_workers=8)
    else:
        print('No required dataset')
        return
    print('********************load data succeed!********************')

    print('********************load model********************')
    if args.model == 'CXRNet' and args.dataset == 'NIHCXR':
        N_CLASSES = len(CLASS_NAMES_NIH)
        model = CXRNet(num_classes=N_CLASSES,
                       is_pre_trained=True)  #initialize model
        CKPT_PATH = config[
            'CKPT_PATH'] + args.model + '_' + args.dataset + '_best.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model.load_state_dict(checkpoint)  #strict=False
            print(
                "=> Loaded well-trained CXRNet model checkpoint of NIH-CXR dataset: "
                + CKPT_PATH)
    elif args.model == 'CXRNet' and args.dataset == 'VinCXR':
        N_CLASSES = len(CLASS_NAMES_Vin)
        model = CXRNet(num_classes=N_CLASSES,
                       is_pre_trained=True)  #initialize model
        CKPT_PATH = config[
            'CKPT_PATH'] + args.model + '_' + args.dataset + '_best.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model.load_state_dict(checkpoint)  #strict=False
            # Fixed: previously this message wrongly said "NIH-CXR dataset".
            print(
                "=> Loaded well-trained CXRNet model checkpoint of Vin-CXR dataset: "
                + CKPT_PATH)
    else:
        print('No required model')
        return  #over
    model = nn.DataParallel(
        model).cuda()  # make model available multi GPU cores training
    optimizer_model = optim.Adam(model.parameters(),
                                 lr=1e-3,
                                 betas=(0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=1e-5)
    # NOTE(review): gamma=1 means the learning rate never actually decays
    # at each step_size boundary — confirm this is intentional.
    lr_scheduler_model = lr_scheduler.StepLR(optimizer_model,
                                             step_size=10,
                                             gamma=1)
    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    bce_criterion = nn.BCELoss()  #define binary cross-entropy loss
    print('********************load model succeed!********************')

    print('********************begin training!********************')
    AUROC_best = 0.50
    for epoch in range(config['MAX_EPOCHS']):
        since = time.time()
        print('Epoch {}/{}'.format(epoch + 1, config['MAX_EPOCHS']))
        print('-' * 10)
        model.train()  #set model to training mode
        train_loss = []
        with torch.autograd.enable_grad():
            for batch_idx, (image, label, box) in enumerate(dataloader_train):
                var_image = image.cuda()
                var_label = label.cuda()

                optimizer_model.zero_grad()
                _, var_output = model(var_image)  #forward
                loss_tensor = bce_criterion(var_output, var_label)
                loss_tensor.backward()
                optimizer_model.step()  #update parameters

                sys.stdout.write(
                    '\r Epoch: {} / Step: {} : train loss = {}'.format(
                        epoch + 1, batch_idx + 1,
                        float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
                train_loss.append(loss_tensor.item())
        lr_scheduler_model.step()  #about lr and gamma
        print("\r Epoch: %5d train loss = %.6f" %
              (epoch + 1, np.mean(train_loss)))

        model.eval()  #turn to test mode
        val_loss = []
        gt = torch.FloatTensor().cuda()
        pred = torch.FloatTensor().cuda()
        with torch.no_grad():
            for batch_idx, (image, label, box) in enumerate(dataloader_val):
                var_image = image.cuda()
                var_label = label.cuda()
                _, var_output = model(var_image)  #forward
                loss_tensor = bce_criterion(var_output, var_label)
                sys.stdout.write(
                    '\r Epoch: {} / Step: {} : validation loss = {}'.format(
                        epoch + 1, batch_idx + 1,
                        float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
                val_loss.append(loss_tensor.item())
                gt = torch.cat((gt, label.cuda()), 0)
                pred = torch.cat((pred, var_output.data), 0)
        AUROCs = compute_AUCs(gt, pred, N_CLASSES)
        AUROC_avg = np.array(AUROCs).mean()
        logger.info(
            "\r Epoch: %5d validation loss = %.6f, Validation AUROC = %.4f" %
            (epoch + 1, np.mean(val_loss), AUROC_avg))

        # NOTE(review): model selection below uses Test() — presumably the
        # held-out TEST set — rather than the validation AUROC computed above.
        # Selecting checkpoints on the test set leaks test information into
        # training; confirm this is intended before relying on reported scores.
        AUROC_avg = Test()
        if AUROC_best < AUROC_avg:
            AUROC_best = AUROC_avg
            CKPT_PATH = config[
                'CKPT_PATH'] + args.model + '_' + args.dataset + '_best.pkl'
            torch.save(model.module.state_dict(),
                       CKPT_PATH)  #Saving torch.nn.DataParallel Models
            print(' Epoch: {} model has been already save!'.format(epoch + 1))

        time_elapsed = time.time() - since
        print('Training epoch: {} completed in {:.0f}m {:.0f}s'.format(
            epoch + 1, time_elapsed // 60, time_elapsed % 60))