def main():
    """Report the size of the test split, then launch a training run.

    The run uses fixed hyper-parameters and starts from scratch (no
    checkpoint is passed to the trainer). The launch timestamp has the form
    DDMMYYYY-HHMMSS and is forwarded to the trainer as the run identifier.
    """
    # Load the test split only to print how many samples it contains.
    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR, image_list_file=IMAGE_LIST_TEST)
    print("The length of test data is ", len(test_dataset))

    # Build the launch identifier: date first, then wall-clock time.
    clock_part = time.strftime("%H%M%S")
    date_part = time.strftime("%d%m%Y")
    launch_tag = "-".join((date_part, clock_part))

    # Image preprocessing sizes.
    resize_to = 256
    crop_to = 224

    # Model / optimisation settings.
    use_pretrained = True
    num_classes = 156
    batch_size = 16
    num_epochs = 100

    ChexnetTrainer.train(
        DATA_DIR,
        IMAGE_LIST_TRAIN,
        IMAGE_LIST_VAL,
        resize_to,
        crop_to,
        use_pretrained,
        num_classes,
        batch_size,
        num_epochs,
        launch_tag,
        None,  # checkpoint: start a fresh run
    )
def train(dataDir, imageListFileTrain, imageListFileVal, transResize, transCrop, isTrained, classCount, batchSize,
          epochSize, launchTimestamp, checkpoint):
    """Train the network, saving a checkpoint whenever the validation loss improves.

    Args:
        dataDir: path to the data dir.
        imageListFileTrain: path to the image list file used for training.
        imageListFileVal: path to the image list file used for validation.
        transResize: size the image is resized to before cropping.
        transCrop: size of the cropped image fed to the network.
        isTrained: if True, uses the pre-trained version of the network
            (pre-trained on ImageNet); forwarded to ``CheXNet``.
        classCount: number of output classes.
        batchSize: batch size used by both data loaders.
        epochSize: number of epochs to run.
        launchTimestamp: date/time string, used to give the checkpoint file
            a unique name (``m-<launchTimestamp>.pth.tar``).
        checkpoint: if not None, passed to ``torch.load``; the model and
            optimizer state dicts are restored from it before training.
            Only those two entries are read back here, so the epoch counter
            and best loss restart from their initial values.
    """
    # SETTINGS
    # ^^^^^^^^

    # initialize and load the model
    # ---------------
    print("train begins!=======================")
    model = CheXNet(classCount, isTrained).cuda()
    model = torch.nn.DataParallel(model).cuda()

    # data transforms
    # ---------------
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize(transResize),
        transforms.RandomResizedCrop(transCrop),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # datasets
    # ---------------
    datasetTrain = ChestXrayDataSet(data_dir=dataDir, image_list_file=imageListFileTrain, transform=transform)
    # NOTE(review): validation shares the random crop/flip pipeline with
    # training, so the validation loss is not deterministic -- confirm this
    # is intended.
    datasetVal = ChestXrayDataSet(data_dir=dataDir, image_list_file=imageListFileVal, transform=transform)
    print(len(datasetTrain))

    # NOTE(review): the training loader does not shuffle -- confirm this is
    # intended.
    dataLoaderTrain = DataLoader(dataset=datasetTrain, batch_size=batchSize, shuffle=False,
                                 num_workers=8, pin_memory=True)
    dataLoaderVal = DataLoader(dataset=datasetVal, batch_size=batchSize, shuffle=False,
                               num_workers=8, pin_memory=True)

    # optimizer and scheduler
    # ---------------
    optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999),
                           eps=1e-08, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=5, mode='min')

    # loss
    # ---------------
    # The default reduction averages over the batch, which is what the
    # deprecated ``size_average=True`` argument used to request.
    loss = torch.nn.BCELoss()

    # Load checkpoint
    # ---------------
    if checkpoint is not None:
        modelCheckpoint = torch.load(checkpoint)
        model.load_state_dict(modelCheckpoint['state_dict'])
        optimizer.load_state_dict(modelCheckpoint['optimizer'])

    # Train
    # ^^^^^
    # Start from +inf so the first finite validation loss always counts as
    # an improvement, whatever its magnitude.
    lossMin = float('inf')

    for epochIdx in range(epochSize):
        print("EpochIdx: ################ ", epochIdx)

        ChexnetTrainer.epochTrain(model, dataLoaderTrain, optimizer, scheduler, classCount, loss)
        lossVal = ChexnetTrainer.epochVal(model, dataLoaderVal, optimizer, scheduler, classCount, loss)

        # Single strftime call so the date and time parts come from the same
        # instant.
        timestampEND = time.strftime("%d%m%Y-%H%M%S")

        # ReduceLROnPlateau lowers the learning rate when lossVal stops
        # decreasing (mode='min').
        scheduler.step(lossVal)

        if lossVal < lossMin:
            lossMin = lossVal
            torch.save(
                {
                    'epoch': epochIdx + 1,
                    'state_dict': model.state_dict(),
                    'best_loss': lossMin,
                    'optimizer': optimizer.state_dict()
                },
                'm-' + launchTimestamp + '.pth.tar')
            status = 'save'
        else:
            status = '----'

        print('Epoch [' + str(epochIdx + 1) + '] [' + status + '] [' + timestampEND + '] loss= ' + str(lossVal))