Example #1
def generate_arrays(phase):
    dataset = data.DataBowl3Detector(data_dir, config, phase=phase)
    n_samples = dataset.__len__()
    while True:
        for i in range(n_samples):
            x, y, _ = dataset.__getitem__(i)
            x = np.expand_dims(x, axis=-1)
            y = np.expand_dims(y, axis=0)
            yield (x, y)
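
# Minimal usage sketch for the generator above (a sketch only, assuming data_dir
# and config are defined as in this snippet): pull a few (x, y) pairs and
# inspect their shapes.
#
#     gen = generate_arrays('train')
#     for _ in range(3):
#         x, y = next(gen)
#         print(x.shape, y.shape)
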
def main():
    global args
    args = parser.parse_args()

    torch.manual_seed(0)
    torch.cuda.set_device(0)

    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir

    if args.resume:
        checkpoint = torch.load(args.resume)
        if start_epoch == 0:
            start_epoch = checkpoint['epoch'] + 1
        if not save_dir:
            save_dir = checkpoint['save_dir']
        else:
            save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu
    print("arg", args.gpu)
    print("num_gpu", n_gpu)

    net = net.cuda()
    loss = loss.cuda()
    cudnn.benchmark = True
    net = DataParallel(net)
    datadir = config_training['preprocess_result_path']
    print("datadir", datadir)
    print("anchor", config['anchors'])
    print("pad_val", config['pad_value'])
    print("th_pos_train", config['th_pos_train'])

    if args.test == 1:
        margin = 32
        sidelen = 144
        print("args.test True")
        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])
        dataset = data.DataBowl3Detector(datadir,
                                         'val9.npy',
                                         config,
                                         phase='test',
                                         split_comber=split_comber)
        test_loader = DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 collate_fn=data.collate,
                                 pin_memory=False)

        test(test_loader, net, get_pbb, save_dir, config, sidelen)
        return

    # net = DataParallel(net)

    train_dataset = data.DataBowl3Detector(datadir,
                                           'train_luna_9.npy',
                                           config,
                                           phase='train')
    print("len train_dataset", train_dataset.__len__())
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    val_dataset = data.DataBowl3Detector(datadir,
                                         'val9.npy',
                                         config,
                                         phase='val')
    print("len val_dataset", val_dataset.__len__())

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    margin = 32
    sidelen = 144

    split_comber = SplitComb(sidelen, config['max_stride'], config['stride'],
                             margin, config['pad_value'])
    test_dataset = data.DataBowl3Detector(datadir,
                                          'val9.npy',
                                          config,
                                          phase='test',
                                          split_comber=split_comber)

    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.workers,
                             collate_fn=data.collate,
                             pin_memory=False)

    print("lr", args.lr)
    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 0.5:
            lr = args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.1 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr
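
    # Note: the same step schedule could be expressed with PyTorch's built-in
    # MultiStepLR (a sketch, assuming `optimizer` and `args.epochs` as above):
    #
    #     from torch.optim import lr_scheduler
    #     scheduler = lr_scheduler.MultiStepLR(
    #         optimizer,
    #         milestones=[int(args.epochs * 0.5), int(args.epochs * 0.8)],
    #         gamma=0.1)
    #
    # with scheduler.step() called once per epoch in place of the manual
    # get_lr(epoch) assignment that train() presumably performs.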

    best_val_loss = 100
    best_test_loss = 0

    for epoch in range(start_epoch, args.epochs + 1):
        print("epoch", epoch)
        train(train_loader, net, loss, epoch, optimizer, get_lr,
              args.save_freq, save_dir)
        best_val_loss = validate(val_loader, net, loss, best_val_loss, epoch,
                                 save_dir)
        if ((epoch > 150) and ((epoch + 1) % 10) == 0):
            best_test_loss = test_training(test_loader, net, get_pbb, save_dir,
                                           config, sidelen, best_test_loss,
                                           epoch, n_gpu)

        if ((epoch > 300) and ((epoch + 1) % 100) == 0):
            num_neg = train_dataset.get_neg_num_neg() + 800
            train_dataset.set_neg_num_neg(num_neg)
Example #3
            plt.subplot(xx, xx, i + 1).imshow(array_normalied(array[i, :, :]),
                                              color)
    elif axis == 1:
        for i in range(len):
            plt.subplot(xx, xx, i + 1).imshow(array_normalied(array[:, i, :]),
                                              color)
    elif axis == 2:
        for i in range(len):
            plt.subplot(xx, xx, i + 1).imshow(array_normalied(array[:, :, i]),
                                              color)
    else:
        print("axis=0/1/2")


# load data
dataset_train = data.DataBowl3Detector(config, process='all')
for i in range(20000, 20100):
    x, y, z = dataset_train[i]

# check dataset
n = 100  # 3750
for i in range(10):
    bbox = dataset_train.bboxes[n + 8 * i]
    filename = dataset_train.filenames[int(bbox[0])]
    print(i, filename)
    x, y, z = dataset_train[n + i]
    coord = np.where(y[:, 0, ...] >= 0.5)
    if len(coord[0]) > 0:
        show_image(y[coord[0][0], 0, coord[1][0] - 1:coord[1][0] + 2, ...])
        show_image(x[0, 4 * coord[1][0] - 2:4 * coord[1][0] + 2, ...])
def main():
    global args
    args = parser.parse_args()
    config_training = import_module(args.config)
    config_training = config_training.config
    # from config_training import config as config_training
    torch.manual_seed(0)
    torch.cuda.set_device(0)

    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir
    
#    args.resume = True
    if args.resume:
        checkpoint = torch.load(args.resume)
        # if start_epoch == 0:
        #     start_epoch = checkpoint['epoch'] + 1
        # if not save_dir:
        #     save_dir = checkpoint['save_dir']
        # else:
        #     save_dir = os.path.join('results',save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    # else:
    if start_epoch == 0:
        start_epoch = 1
    if not save_dir:
        exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
        save_dir = os.path.join('results', args.model + '-' + exp_id)
    else:
        save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu
    net = net.cuda()
    loss = loss.cuda()
    cudnn.benchmark = False                     #True
    net = DataParallel(net)
    traindatadir = config_training['train_preprocess_result_path']
    valdatadir = config_training['val_preprocess_result_path']
    testdatadir = config_training['test_preprocess_result_path']
    trainfilelist = []
    # with open("/home/mpadmana/anaconda3/envs/DeepLung_original/luna_patient_names/luna_train_list.pkl", 'rb') as f:
    #     trainfilelist = pickle.load(f)
    with open("/home/mpadmana/anaconda3/envs/DeepLung_original/methodist_patient_names/methodist_train.pkl", 'rb') as f:
        trainfilelist = pickle.load(f)

    valfilelist = []
    # with open("/home/mpadmana/anaconda3/envs/DeepLung_original/luna_patient_names/luna_val_list.pkl", 'rb') as f:
    #     valfilelist = pickle.load(f)
    with open("/home/mpadmana/anaconda3/envs/DeepLung_original/methodist_patient_names/methodist_val.pkl", 'rb') as f:
        valfilelist = pickle.load(f)

    testfilelist = []
    # with open("/home/mpadmana/anaconda3/envs/DeepLung_original/luna_patient_names/luna_test_list.pkl", 'rb') as f:
    #     testfilelist = pickle.load(f)
    with open("/home/mpadmana/anaconda3/envs/DeepLung_original/methodist_patient_names/methodist_test.pkl", 'rb') as f:
        testfilelist = pickle.load(f)
    testfilelist=['download20180608140526download20180608140500001_1_3_12_30000018060618494775800001943']
    if args.test == 1:

        margin = 32
        sidelen = 144
        import data
        split_comber = SplitComb(sidelen, config['max_stride'], config['stride'],
                                 margin, config['pad_value'])
        dataset = data.DataBowl3Detector(
            testdatadir,
            testfilelist,
            config,
            phase='test',
            split_comber=split_comber)
        test_loader = DataLoader(
            dataset,
            batch_size = 1,
            shuffle = False,
            num_workers = 0,
            collate_fn = data.collate,
            pin_memory=False)

        for i, (data, target, coord, nzhw) in enumerate(test_loader): # check data consistency
            if i >= len(testfilelist)/args.batch_size:
                break
        
        test(test_loader, net, get_pbb, save_dir,config)

        return
    #net = DataParallel(net)
    from detector import data
    print(len(trainfilelist))
    dataset = data.DataBowl3Detector(
        traindatadir,
        trainfilelist,
        config,
        phase = 'train')
    train_loader = DataLoader(
        dataset,
        batch_size = args.batch_size,
        shuffle = True,
        num_workers = 0,
        pin_memory=True)

    dataset = data.DataBowl3Detector(
        valdatadir,
        valfilelist,
        config,
        phase = 'val')
    val_loader = DataLoader(
        dataset,
        batch_size = args.batch_size,
        shuffle = False,
        num_workers = 0,
        pin_memory=True)

    for i, (data, target, coord) in enumerate(train_loader): # check data consistency
        if i >= len(trainfilelist)/args.batch_size:
            break

    for i, (data, target, coord) in enumerate(val_loader): # check data consistency
        if i >= len(valfilelist)/args.batch_size:
            break

    optimizer = torch.optim.SGD(
        net.parameters(),
        args.lr,
        momentum = 0.9,
        weight_decay = args.weight_decay)
    
    def get_lr(epoch):
        if epoch <= args.epochs * 1/3: #0.5:
            lr = args.lr
        elif epoch <= args.epochs * 2/3: #0.8:
            lr = 0.1 * args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.05 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr
    

    for epoch in range(start_epoch, args.epochs + 1):
        train(train_loader, net, loss, epoch, optimizer, get_lr, args.save_freq, save_dir)
        validate(val_loader, net, loss)
Example #5
def main():
    global args
    args = parser.parse_args()

    torch.manual_seed(0)
    torch.cuda.set_device(0)

    print("import module ")
    model = import_module(args.model)
    print("get module")
    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir
    if args.resume:
        checkpoint = torch.load(args.resume)
        if start_epoch == 0:
            start_epoch = checkpoint['epoch'] + 1
        if not save_dir:
            save_dir = checkpoint['save_dir']
        else:
            save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
        print("save dir ", save_dir)
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
    print("num of gpu", args.gpu)
    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu
    print("get net")
    net = net.cuda()
    print("get loss")
    loss = loss.cuda()
    cudnn.benchmark = True
    print("data parallel")
    net = DataParallel(net)
    datadir = config_training['preprocess_result_path']

    if args.test == 1:
        print("testing")
        margin = 32
        sidelen = 112

        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])

        print("load data")
        dataset = data.DataBowl3Detector(datadir,
                                         'full.npy',
                                         config,
                                         phase='test',
                                         split_comber=split_comber)
        test_loader = DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 collate_fn=data.collate,
                                 pin_memory=False)

        test(test_loader, net, get_pbb, save_dir, config)
        return

    #net = DataParallel(net)
    dataset = data.DataBowl3Detector(datadir,
                                     'kaggleluna_full.npy',
                                     config,
                                     phase='train')
    train_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    dataset = data.DataBowl3Detector(datadir,
                                     'valsplit.npy',
                                     config,
                                     phase='val')
    val_loader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 0.5:
            lr = args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.1 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr

    for epoch in range(start_epoch, args.epochs + 1):
        train(train_loader, net, loss, epoch, optimizer, get_lr,
              args.save_freq, save_dir)
        validate(val_loader, net, loss)
Example #6
def main():
    global args
    args = parser.parse_args()

    torch.manual_seed(0)
    torch.cuda.set_device(0)

    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir

    if args.resume:
        checkpoint = torch.load(args.resume)
        if start_epoch == 0:
            start_epoch = checkpoint['epoch'] + 1
        if not save_dir:
            save_dir = checkpoint['save_dir']
        else:
            save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu
    net = net.cuda()
    loss = loss.cuda()
    cudnn.benchmark = False  #True
    net = DataParallel(net)
    traindatadir = config_training['train_preprocess_result_path']
    valdatadir = config_training['val_preprocess_result_path']
    testdatadir = config_training['test_preprocess_result_path']
    trainfilelist = []
    for f in os.listdir(config_training['train_data_path']):
        if f.endswith('.mhd') and f[:-4] not in config_training['black_list']:
            trainfilelist.append(f[:-4])
    valfilelist = []
    for f in os.listdir(config_training['val_data_path']):
        if f.endswith('.mhd') and f[:-4] not in config_training['black_list']:
            valfilelist.append(f[:-4])
    testfilelist = []
    for f in os.listdir(config_training['test_data_path']):
        if f.endswith('.mhd') and f[:-4] not in config_training['black_list']:
            testfilelist.append(f[:-4])

    if args.test == 1:
        margin = 32
        sidelen = 144
        import data
        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])
        dataset = data.DataBowl3Detector(testdatadir,
                                         testfilelist,
                                         config,
                                         phase='test',
                                         split_comber=split_comber)
        test_loader = DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 collate_fn=data.collate,
                                 pin_memory=False)

        for i, (data, target, coord,
                nzhw) in enumerate(test_loader):  # check data consistency
            if i >= len(testfilelist) / args.batch_size:
                break

        test(test_loader, net, get_pbb, save_dir, config)
        return
    #net = DataParallel(net)
    import data
    dataset = data.DataBowl3Detector(traindatadir,
                                     trainfilelist,
                                     config,
                                     phase='train')
    train_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    dataset = data.DataBowl3Detector(valdatadir,
                                     valfilelist,
                                     config,
                                     phase='val')
    val_loader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    for i, (data, target,
            coord) in enumerate(train_loader):  # check data consistency
        if i >= len(trainfilelist) / args.batch_size:
            break

    for i, (data, target,
            coord) in enumerate(val_loader):  # check data consistency
        if i >= len(valfilelist) / args.batch_size:
            break

    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 1 / 3:  #0.5:
            lr = args.lr
        elif epoch <= args.epochs * 2 / 3:  #0.8:
            lr = 0.1 * args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.05 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr

    for epoch in range(start_epoch, args.epochs + 1):
        train(train_loader, net, loss, epoch, optimizer, get_lr,
              args.save_freq, save_dir)
        validate(val_loader, net, loss)
Example #7
def main():
    global args
    start = time.time()
    print('start!')
    args = parser.parse_args()
    config_training = import_module(args.config)
    config_training = config_training.config
    # from config_training import config as config_training
    torch.manual_seed(0)
    # torch.cuda.set_device(0)
    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch  #0

    save_dir = args.save_dir  #res18/retrft960
    # print('args.resume!',args.resume, time() - start)
    if args.resume:
        checkpoint = torch.load(args.resume)
        print('args.resume', args.resume)
        # if start_epoch == 0:
        #     start_epoch = checkpoint['epoch'] + 1
        # if not save_dir:
        #     save_dir = checkpoint['save_dir']
        # else:
        #     save_dir = os.path.join('results',save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    # else:
    # print('start_epoch',start_epoch, time() - start)
    if start_epoch == 0:
        start_epoch = 1
    if not save_dir:
        exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
        save_dir = os.path.join('results', args.model + '-' + exp_id)
    else:
        save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    # print('args.test',args.test, time() - start)
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        # print('pyfiles', time() - start)
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
            # print('pyfiles1', time() - start)
    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu

    # net = net.cuda()

    loss = loss.cuda()
    device = 'cuda'
    net = net.to(device)
    cudnn.benchmark = True  #False
    net = DataParallel(net).cuda()
    # print('net0', time.time() - start)
    traindatadir = config_training[
        'train_preprocess_result_path']  #'/home/zhaojie/zhaojie/Lung/data/luna16/LUNA16PROPOCESSPATH/'
    valdatadir = config_training[
        'val_preprocess_result_path']  #'/home/zhaojie/zhaojie/Lung/data/luna16/LUNA16PROPOCESSPATH/'
    testdatadir = config_training[
        'test_preprocess_result_path']  #'/home/zhaojie/zhaojie/Lung/data/luna16/LUNA16PROPOCESSPATH/'
    trainfilelist = []
    # print('data_path',config_training['train_data_path'])
    for folder in config_training['train_data_path']:
        print('folder', folder)
        for f in os.listdir(folder):
            if f.endswith(
                    '.mhd') and f[:-4] not in config_training['black_list']:
                trainfilelist.append(folder.split('/')[-2] + '/' + f[:-4])
    valfilelist = []
    for folder in config_training['val_data_path']:
        for f in os.listdir(folder):
            if f.endswith(
                    '.mhd') and f[:-4] not in config_training['black_list']:
                valfilelist.append(folder.split('/')[-2] + '/' + f[:-4])
    testfilelist = []
    for folder in config_training['test_data_path']:
        for f in os.listdir(folder):
            if f.endswith(
                    '.mhd') and f[:-4] not in config_training['black_list']:
                testfilelist.append(folder.split('/')[-2] + '/' + f[:-4])

    if args.test == 1:
        print('--------test-------------')
        print('len(testfilelist)', len(testfilelist))

        print('batch_size', args.batch_size)
        # margin = 32
        margin = 16
        # sidelen = 144
        sidelen = 128
        # sidelen = 208
        import data
        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])
        dataset = data.DataBowl3Detector(testdatadir,
                                         testfilelist,
                                         config,
                                         phase='test',
                                         split_comber=split_comber)
        test_loader = DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 collate_fn=data.collate,
                                 pin_memory=False)
        iter1, iter2, iter3, iter4 = next(iter(test_loader))
        # print("sample: ", len(iter1))
        # print("lable: ", iter2.size())
        # print("coord: ", iter3.size())
        for i, (data, target, coord,
                nzhw) in enumerate(test_loader):  # check data consistency
            if i >= len(testfilelist) // args.batch_size:
                break

        test(test_loader, net, get_pbb, save_dir, config)
        return

    import data
    print('len(trainfilelist)', len(trainfilelist))
    # print('trainfilelist',trainfilelist)
    print('batch_size', args.batch_size)
    dataset = data.DataBowl3Detector(traindatadir,
                                     trainfilelist,
                                     config,
                                     phase='train')
    # print('train_loader')
    train_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    dataset = data.DataBowl3Detector(valdatadir,
                                     valfilelist,
                                     config,
                                     phase='val')
    val_loader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)
    iter1, iter2, iter3 = next(iter(train_loader))
    # print("sample: ", iter1.size())
    # print("lable: ", iter2.size())
    # print("coord: ", iter3.size())
    for i, (data, target,
            coord) in enumerate(train_loader):  # check data consistency
        if i >= len(trainfilelist) / args.batch_size:
            break

    for i, (data, target,
            coord) in enumerate(val_loader):  # check data consistency
        if i >= len(valfilelist) / args.batch_size:
            break

    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 1 / 3:  #0.5:
            lr = args.lr
        elif epoch <= args.epochs * 2 / 3:  #0.8:
            lr = 0.1 * args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.05 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr

    for epoch in range(start_epoch, args.epochs + 1):
        train(train_loader, net, loss, epoch, optimizer, get_lr,
              args.save_freq, save_dir)
        validate(val_loader, net, loss)
Example #8
File: layers.py  Project: leurekay/myLung
    #loss_cls=tf.cast(loss_cls,tf.float32)
    loss = tf.add(loss_cls, loss_reg)
    return loss


if __name__ == '__main__':
    model = n_net()

    model.summary()

    plot_model(model, to_file='images/model3d.png', show_shapes=True)

    import data

    data_dir = '/data/lungCT/luna/temp/luna_npy'
    dataset = data.DataBowl3Detector(data_dir, data.config)
    patch, label, coord = dataset.__getitem__(22)

    y_true = tf.constant(label)

    a = myloss(y_true, y_true)

    #    hard=hard_mining(a,a,4)

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    aa = sess.run(a)

#    hh=sess.run(hard)
Example #9
def main():
    global args
    args = parser.parse_args()

    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.set_device(0)

    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir

    if args.resume:
        checkpoint = torch.load(args.resume)
        if start_epoch == 0:
            start_epoch = checkpoint['epoch'] + 1
        if not save_dir:
            save_dir = checkpoint['save_dir']
        else:
            save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')

    # if training, save files to know how training was done
    if args.test != 1:
        sys.stdout = Logger(logfile)
        # sys.stdout = logging.getLogger(logfile)
        print(sys.argv)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
        shutil.copy('config_training.py', save_dir)

    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu
    net = net.cuda()
    loss = loss.cuda()
    cudnn.benchmark = True
    net = DataParallel(net)
    datadir = config_training[
        'preprocess_result_path'] if args.data is None else args.data

    if args.test == 1:
        margin = 32
        sidelen = 144

        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])

        test_set_file = args.test_filename

        dataset = data.DataBowl3Detector(datadir,
                                         test_set_file,
                                         config,
                                         phase='test',
                                         split_comber=split_comber)
        test_loader = DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 collate_fn=data.collate,
                                 pin_memory=False)

        test(test_loader, net, get_pbb, save_dir, config, args.test_set)
        return

    #net = DataParallel(net)

    dataset = data.DataBowl3Detector(datadir,
                                     args.train_filename,
                                     config,
                                     phase='train')
    train_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    dataset = data.DataBowl3Detector(datadir,
                                     args.val_filename,
                                     config,
                                     phase='val')
    val_loader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    if args.optim == 'adam':
        optimizer = torch.optim.Adam(net.parameters())
    elif args.optim == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(),
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 0.5:
            lr = args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.1 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr

    for epoch in range(start_epoch, args.epochs + 1):
        train(train_loader, net, loss, epoch, optimizer, get_lr,
              args.save_freq, save_dir)
        validate(val_loader, net, loss)
Example #10
def main():
    global args
    args = parser.parse_args()

    torch.manual_seed(0)
    torch.cuda.set_device(0)

    model = import_module(args.model)
    config, net, loss = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir

    if args.resume:
        checkpoint = torch.load(args.resume)
        #if start_epoch == 0:
        #    start_epoch = checkpoint['epoch'] + 1
        #if not save_dir:
        #    save_dir = checkpoint['save_dir']
        #else:
        save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu
    print("arg", args.gpu)
    print("num_gpu", n_gpu)
    net = net.cuda()
    loss = loss.cuda()
    cudnn.benchmark = True
    net = DataParallel(net)
    datadir = config_training['preprocess_result_path']

    print("datadir", datadir)
    print("pad_val", config['pad_value'])
    print("aug type", config['augtype'])

    dataset = data.DataBowl3Detector(datadir,
                                     'train_luna_9.npy',
                                     config,
                                     phase='train')
    print("len train_dataset", dataset.__len__())
    train_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    dataset = data.DataBowl3Detector(datadir, 'val9.npy', config, phase='val')
    print("len val_dataset", dataset.__len__())

    val_loader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 0.5:
            lr = args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.1 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr

    best_val_loss = 100
    best_mal_loss = 100
    for epoch in range(start_epoch, args.epochs + 1):
        print("epoch", epoch)
        train(train_loader, net, loss, epoch, optimizer, get_lr,
              args.save_freq, save_dir)
        best_val_loss, best_mal_loss = validate(val_loader, net, loss,
                                                best_val_loss, best_mal_loss,
                                                epoch, save_dir)
Example #11
def main():
    global args
    args = parser.parse_args()

    torch.manual_seed(0)
    torch.cuda.set_device(0)

    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    model_path = "/data/wzeng/DSB_3/training/detector/results/res18_extend/048.ckpt"
    net.load_state_dict(torch.load(model_path)['state_dict'])
    print('loading model from ' + model_path)

    #model_dict = net.state_dict()
    #pretrained_dict = torch.load("/data/wzeng/DSB_3/training/detector/results/res18_extend/081.ckpt")['state_dict']
    #pretrained_dict = {k: v for k, v in model_dict.items() if k in pretrained_dict}
    #model_dict.update(pretrained_dict)
    #net.load_state_dict(model_dict)

    #model_path = "/data/wzeng/DSB_3/training/detector/results/res18_extend/081.ckpt"
    #print('loading model from ' + model_path)

    start_epoch = args.start_epoch
    save_dir = args.save_dir

    if args.resume:
        #         print('start resume')
        print('loading model from ' + args.resume)
        checkpoint = torch.load(args.resume)
        if start_epoch == 0:
            start_epoch = checkpoint['epoch'] + 1
        if not save_dir:
            save_dir = checkpoint['save_dir']
        else:
            save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
        # print('resume end')
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu
    net = net.cuda()
    loss = loss.cuda()
    cudnn.benchmark = True
    net = DataParallel(net)
    datadir = config_training['preprocess_result_path']

    if args.test == 1:
        margin = 16
        sidelen = 128
        #margin = 32
        #sidelen = 144
        #         print('dataloader....')
        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])
        dataset = data.DataBowl3Detector(datadir,
                                         'new_test.npy',
                                         config,
                                         phase='test',
                                         split_comber=split_comber)
        test_loader = DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 collate_fn=data.collate,
                                 pin_memory=False)
        #         print('start testing.....')
        test(test_loader, net, get_pbb, save_dir, config)
        return

    #net = DataParallel(net)

    dataset = data.DataBowl3Detector(datadir,
                                     'train_val.npy',
                                     config,
                                     phase='train')
    train_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    dataset = data.DataBowl3Detector(datadir,
                                     'new_test.npy',
                                     config,
                                     phase='val')
    val_loader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    #optimizer = torch.optim.Adam(
    #    net.parameters(),
    #    args.lr,
    #    weight_decay = args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 0.5:
            lr = args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.1 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr

    for epoch in range(start_epoch, args.epochs + 1):
        train(train_loader, net, loss, epoch, optimizer, get_lr,
              args.save_freq, save_dir)
        validate(val_loader, net, loss)
Example #12
def main():
    global args
    args = parser.parse_args()
    print(args.config)
    config_training = import_module(args.config)
    config_training = config_training.config
    # from config_training import config as config_training
    torch.manual_seed(0)
    torch.cuda.set_device(0)

    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir

    model2 = torch.nn.Sequential()
    model2.add_module('linear', torch.nn.Linear(3, 6, bias=True))
    model2.linear.weight = torch.nn.Parameter(torch.randn(6, 3))
    model2.linear.bias = torch.nn.Parameter(torch.randn(6))
    loss2 = torch.nn.CrossEntropyLoss()
    optimizer2 = optim.SGD(model2.parameters(), lr=args.lr,
                           momentum=0.9)  #, weight_decay=args.weight_decay)

    if args.resume:
        print('resume from ', args.resume)
        checkpoint = torch.load(args.resume)
        # if start_epoch == 0:
        #     start_epoch = checkpoint['epoch'] + 1
        # if not save_dir:
        #     save_dir = checkpoint['save_dir']
        # else:
        #     save_dir = os.path.join('results',save_dir)
        # print(checkpoint.keys())
        net.load_state_dict(checkpoint['state_dict'])
        if start_epoch != 0:
            model2.load_state_dict(checkpoint['state_dict2'])
    # else:
    if start_epoch == 0:
        start_epoch = 1
    if not save_dir:
        exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
        save_dir = os.path.join('results', args.model + '-' + exp_id)
    else:
        save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu
    net = net.cuda()
    loss = loss.cuda()
    cudnn.benchmark = False  #True
    net = DataParallel(net)
    traindatadir = config_training['train_preprocess_result_path']
    valdatadir = config_training['val_preprocess_result_path']
    testdatadir = config_training['test_preprocess_result_path']
    trainfilelist = []
    print(config_training['train_data_path'])
    for folder in config_training['train_data_path']:
        print(folder)
        for f in os.listdir(folder):
            if f.endswith(
                    '.mhd') and f[:-4] not in config_training['black_list']:
                if f[:-4] not in fnamedct:
                    trainfilelist.append(folder.split('/')[-2] + '/' + f[:-4])
                else:
                    trainfilelist.append(
                        folder.split('/')[-2] + '/' + fnamedct[f[:-4]])
    valfilelist = []
    for folder in config_training['val_data_path']:
        for f in os.listdir(folder):
            if f.endswith(
                    '.mhd') and f[:-4] not in config_training['black_list']:
                if f[:-4] not in fnamedct:
                    valfilelist.append(folder.split('/')[-2] + '/' + f[:-4])
                else:
                    valfilelist.append(
                        folder.split('/')[-2] + '/' + fnamedct[f[:-4]])
    testfilelist = []
    for folder in config_training['test_data_path']:
        for f in os.listdir(folder):
            if f.endswith(
                    '.mhd') and f[:-4] not in config_training['black_list']:
                if f[:-4] not in fnamedct:
                    testfilelist.append(folder.split('/')[-2] + '/' + f[:-4])
                else:
                    testfilelist.append(
                        folder.split('/')[-2] + '/' + fnamedct[f[:-4]])
    if args.test == 1:
        margin = 32
        sidelen = 144
        import data
        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])
        dataset = data.DataBowl3Detector(testdatadir,
                                         testfilelist,
                                         config,
                                         phase='test',
                                         split_comber=split_comber)
        test_loader = DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 collate_fn=data.collate,
                                 pin_memory=False)

        for i, (data, target, coord,
                nzhw) in enumerate(test_loader):  # check data consistency
            if i >= len(testfilelist) / args.batch_size:
                break

        test(test_loader, net, get_pbb, save_dir, config)
        return
    #net = DataParallel(net)
    import data
    print(len(trainfilelist))
    dataset = data.DataBowl3Detector(traindatadir,
                                     trainfilelist,
                                     config,
                                     phase='train')
    train_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    dataset = data.DataBowl3Detector(valdatadir,
                                     valfilelist,
                                     config,
                                     phase='val')
    val_loader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)
    # load weak data
    # weakdata = pd.read_csv(config_training['weaktrain_annos_path'], names=['fname', 'position', 'centerslice'])
    # weakfilename = weakdata['fname'].tolist()[1:]
    # weakfilename = list(set(weakfilename))
    # print('#weakdata', len(weakfilename))
    for i, (data, target,
            coord) in enumerate(train_loader):  # check data consistency
        if i >= len(trainfilelist) / args.batch_size:
            break

    for i, (data, target,
            coord) in enumerate(val_loader):  # check data consistency
        if i >= len(valfilelist) / args.batch_size:
            break
    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)
    npars = 0
    for par in net.parameters():
        curnpar = 1
        for s in par.size():
            curnpar *= s
        npars += curnpar
    print('network size', npars)

    def get_lr(epoch):
        if epoch <= args.epochs * 1 / 2:  #0.5:
            lr = args.lr
        elif epoch <= args.epochs * 3 / 4:  #0.8:
            lr = 0.5 * args.lr
        # elif epoch <= args.epochs * 0.8:
        #     lr = 0.05 * args.lr
        else:
            lr = 0.1 * args.lr
        return lr

    for epoch in range(start_epoch, args.epochs + 1):
        # if epoch % 10 == 0:
        import data
        margin = 32
        sidelen = 144
        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])
        dataset = weakdatav2.DataBowl3Detector(
            config_training['weaktrain_data_path'],
            weakdct.keys(),
            config,
            phase='test',
            split_comber=split_comber)
        weaktest_loader = DataLoader(dataset,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=args.workers,
                                     collate_fn=data.collate,
                                     pin_memory=False)
        print(len(weaktest_loader))
        for i, (data, target, coord,
                nzhw) in enumerate(weaktest_loader):  # check data consistency
            if i >= len(testfilelist) / args.batch_size:
                break
        srslst, cdxlst, cdylst, cdzlst, dimlst, prblst, poslst, lwzlst, upzlst = weaktest(
            weaktest_loader, model2, net, get_pbb, save_dir, config, epoch)
        config['ep'] = epoch
        config['save_dir'] = save_dir
        with open(
                config['save_dir'] + 'weakinferep' + str(config['ep']) + 'srs',
                'wb') as fp:
            pickle.dump(srslst, fp)
        with open(
                config['save_dir'] + 'weakinferep' + str(config['ep']) + 'cdx',
                'wb') as fp:
            pickle.dump(cdxlst, fp)
        with open(
                config['save_dir'] + 'weakinferep' + str(config['ep']) + 'cdy',
                'wb') as fp:
            pickle.dump(cdylst, fp)
        with open(
                config['save_dir'] + 'weakinferep' + str(config['ep']) + 'cdz',
                'wb') as fp:
            pickle.dump(cdzlst, fp)
        with open(
                config['save_dir'] + 'weakinferep' + str(config['ep']) + 'dim',
                'wb') as fp:
            pickle.dump(dimlst, fp)
        with open(
                config['save_dir'] + 'weakinferep' + str(config['ep']) + 'prb',
                'wb') as fp:
            pickle.dump(prblst, fp)
        with open(
                config['save_dir'] + 'weakinferep' + str(config['ep']) + 'pos',
                'wb') as fp:
            pickle.dump(poslst, fp)
        pdfrm = pd.read_csv(
            config['save_dir'] + 'weakinferep' + str(config['ep']) + '.csv',
            names=[
                'seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter',
                'probability', 'position'
            ])
        if srslst:
            fnmlst = srslst  #pdfrm['seriesuid'].tolist()[1:]
        dataset = weakdatav2.DataBowl3Detector(
            config_training['weaktrain_data_path'],
            list(set(fnmlst)),
            config,
            phase='train',
            fnmlst=srslst,
            cdxlst=cdxlst,
            cdylst=cdylst,
            cdzlst=cdzlst,
            dimlst=dimlst,
            prblst=prblst,
            poslst=poslst,
            lwzlst=lwzlst,
            upzlst=upzlst)
        weaktrain_loader = DataLoader(dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers,
                                      pin_memory=True)
        print(len(weaktrain_loader))
        for i, (data, target, coord, prob, pos,
                feat) in enumerate(weaktrain_loader):  # check data consistency
            # print(data.size(), target.size(), coord.size(), prob.size(), pos.size(), feat.size())
            if i >= len(trainfilelist) / args.batch_size:
                break
        weaktrain(weaktrain_loader, model2, loss2, optimizer2, net, loss,
                  epoch, optimizer, get_lr,
                  save_dir)  #, args.save_freq, save_dir)
        train(train_loader, net, loss, epoch, optimizer, get_lr,
              args.save_freq, save_dir)
        validate(val_loader, net, loss)
Example #13
def main():
    global args
    args = parser.parse_args()

    torch.manual_seed(0)
    torch.cuda.set_device(0)

    model = import_module(args.model)
    config, net, loss, get_pbb = model.get_model()
    start_epoch = args.start_epoch
    save_dir = args.save_dir

    if args.resume:
        checkpoint = torch.load(args.resume)
        if start_epoch == 0:
            start_epoch = checkpoint['epoch'] + 1
        if not save_dir:
            save_dir = checkpoint['save_dir']
        else:
            save_dir = os.path.join('results', save_dir)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        if start_epoch == 0:
            start_epoch = 1
        if not save_dir:
            exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
            save_dir = os.path.join('results', args.model + '-' + exp_id)
        else:
            save_dir = os.path.join('results', save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logfile = os.path.join(save_dir, 'log')
    if args.test != 1:
        sys.stdout = Logger(logfile)
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(save_dir, f))
    n_gpu = setgpu(args.gpu)
    args.n_gpu = n_gpu
    net = net.cuda()
    loss = loss.cuda()
    cudnn.benchmark = True
    net = DataParallel(net)
    datadir = config_training['preprocess_path']

    if args.test == 1:
        margin = 32
        sidelen = 144

        split_comber = SplitComb(sidelen, config['max_stride'],
                                 config['stride'], margin, config['pad_value'])
        # The test step is commented out for now: validation is already done, and the original demo (main.py in the repo root) is itself a prediction run
        # dataset = data.DataBowl3Detector(
        #     datadir,
        #     'full.npy',
        #     config,
        #     phase='test',
        #     split_comber=split_comber)
        # test_loader = DataLoader(
        #     dataset,
        #     batch_size = 1,
        #     shuffle = False,
        #     num_workers = args.workers,
        #     collate_fn = data.collate,
        #     pin_memory=False)
        #
        # test(test_loader, net, get_pbb, save_dir,config)
        # return

    #net = DataParallel(net)

    dataset = data.DataBowl3Detector(
        datadir,
        config_training['preprocess_path'],  # fix 
        config,
        phase='train')
    train_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    # dataset = data.DataBowl3Detector(
    #     datadir,
    #     'valsplit.npy',
    #     config,
    #     phase = 'val')
    # val_loader = DataLoader(
    #     dataset,
    #     batch_size = args.batch_size,
    #     shuffle = False,
    #     num_workers = args.workers,
    #     pin_memory=True)

    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    def get_lr(epoch):
        if epoch <= args.epochs * 0.5:
            lr = args.lr
        elif epoch <= args.epochs * 0.8:
            lr = 0.1 * args.lr
        else:
            lr = 0.01 * args.lr
        return lr

    for epoch in range(start_epoch, args.epochs + 1):
        # TypeError: only size-1 arrays can be converted to Python scalars (argument-passing error)
        # IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number (PyTorch version issue)
        train(train_loader, net, loss, epoch, optimizer, get_lr,
              args.save_freq, save_dir)
Example #14
#load model
if os.path.exists(SAVED_MODEL):
    print(
        "*************************\n restore model\n*************************")
    model = load_model(SAVED_MODEL)
else:
    model = n_net()

adam = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
model.compile(
    optimizer=adam,
    loss=myloss,
)

# number of samples for train and val
train_dataset = data.DataBowl3Detector(data_dir, config, phase='train')
train_samples = train_dataset.__len__()
val_dataset = data.DataBowl3Detector(data_dir, config, phase='val')
val_samples = val_dataset.__len__()


# callback: save the model named by (time, train_loss, val_loss)
class EpochSave(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs={}):
        self.losses = []

    def on_epoch_end(self, epoch, logs={}):
        time_now = int(time.time())
        train_loss = logs.get('loss')
        val_loss = logs.get('val_loss')
        self.losses.append([train_loss, val_loss])
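
# Sketch only: how the pieces above could be wired together for training.
# train_generator / val_generator are assumed to yield (x, y) batches built on
# data.DataBowl3Detector (e.g. as in Example #1), and BATCH_SIZE / EPOCHS are
# hypothetical constants, not part of the original snippet.
#
#     model.fit_generator(train_generator,
#                         steps_per_epoch=train_samples // BATCH_SIZE,
#                         epochs=EPOCHS,
#                         validation_data=val_generator,
#                         validation_steps=val_samples // BATCH_SIZE,
#                         callbacks=[EpochSave()])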