Example #1
                 ylim=(-0.5, 0.5))

    # Train
    model = getattr(models, args.model)(noise_dim, 2, lambda0, lambda1)
    if args.mode == 'train':
        timer = ElapsedTimer()
        model.train(data_obj,
                    func_obj,
                    val_scale,
                    batch_size=batch_size,
                    train_steps=train_steps,
                    disc_lr=disc_lr,
                    gen_lr=gen_lr,
                    save_interval=save_interval,
                    save_dir=save_dir)
        elapsed_time = timer.elapsed_time()
        runtime_mesg = 'Wall clock time for training: %s' % elapsed_time
        print(runtime_mesg)
    else:
        model.restore(save_dir=save_dir)

    print(
        '##########################################################################'
    )
    print('Plotting generated samples ...')

    # Plot generated samples
    n = 1000
    gen_data = model.synthesize(n)
    visualize_2d(data[:n],
                 func=func,
Example #2
File: main.py  Project: KUR-creative/unet
def main(experiment_yml_path):
    with open(experiment_yml_path, 'r') as f:
        settings = yaml.safe_load(f)  # yaml.load() without a Loader is an error in PyYAML >= 6
    experiment_name, _ = os.path.splitext(
        os.path.basename(experiment_yml_path))
    print('->', experiment_name)
    for k, v in settings.items():
        print(k, '=', v)
    #----------------------- experiment settings ------------------------
    IMG_SIZE = settings['IMG_SIZE']
    BATCH_SIZE = settings['BATCH_SIZE']
    NUM_EPOCHS = settings['NUM_EPOCHS']

    data_augmentation = settings['data_augmentation']  # string

    dataset_dir = settings['dataset_dir']
    save_model_path = settings['save_model_path']  ## NOTE
    history_path = settings['history_path']  ## NOTE

    eval_result_dirpath = os.path.join(settings['eval_result_parent_dir'],
                                       experiment_name)
    # optional settings
    sqr_crop_dataset = settings.get('sqr_crop_dataset')
    kernel_init = settings.get('kernel_init')
    num_maxpool = settings.get('num_maxpool')
    num_filters = settings.get('num_filters')
    overlap_factor = settings.get('overlap_factor')
    #loaded_model = save_model_path ## NOTE
    loaded_model = None
    #--------------------------------------------------------------------

    #--------------------------------------------------------------------
    train_dir = os.path.join(dataset_dir, 'train')
    valid_dir = os.path.join(dataset_dir, 'valid')
    test_dir = os.path.join(dataset_dir, 'test')

    output_dir = os.path.join(dataset_dir, 'output')
    origin_dir = os.path.join(output_dir, 'image')
    answer_dir = os.path.join(output_dir, 'label')
    result_dir = os.path.join(output_dir, 'result')
    #--------------------------------------------------------------------

    #-------------------- ready to generate batch -----------------------
    train_imgs = list(load_imgs(os.path.join(train_dir, 'image')))
    train_masks = list(load_imgs(os.path.join(train_dir, 'label')))
    valid_imgs = list(load_imgs(os.path.join(valid_dir, 'image')))
    valid_masks = list(load_imgs(os.path.join(valid_dir, 'label')))
    test_imgs = list(load_imgs(os.path.join(test_dir, 'image')))
    test_masks = list(load_imgs(os.path.join(test_dir, 'label')))

    if overlap_factor is None: overlap_factor = 2
    #calc mean h,w of dataset
    tr_h, tr_w = sum(map(lambda img: np.array(img.shape[:2]),
                         train_imgs)) / len(train_imgs)
    vl_h, vl_w = sum(map(lambda img: np.array(img.shape[:2]),
                         valid_imgs)) / len(valid_imgs)
    te_h, te_w = sum(map(lambda img: np.array(img.shape[:2]),
                         test_imgs)) / len(test_imgs)
    #print(tr_h,tr_w, '|', vl_h,vl_w, '|', te_h,te_w)
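    # num_sample is roughly how many IMG_SIZE x IMG_SIZE crops fit into an average
    # image of the split, scaled by overlap_factor. steps_per_epoch below is then
    # ceil(len(imgs) / BATCH_SIZE) * num_sample (assuming modulo_ceil rounds its
    # first argument up to a multiple of the second), i.e. each image is sampled
    # about num_sample times per epoch.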
    train_num_sample = int(
        (tr_h / IMG_SIZE) * (tr_w / IMG_SIZE) * overlap_factor)
    valid_num_sample = int(
        (vl_h / IMG_SIZE) * (vl_w / IMG_SIZE) * overlap_factor)
    test_num_sample = int(
        (te_h / IMG_SIZE) * (te_w / IMG_SIZE) * overlap_factor)
    #print(train_num_sample,valid_num_sample,test_num_sample)
    train_steps_per_epoch = modulo_ceil(
        len(train_imgs), BATCH_SIZE) // BATCH_SIZE * train_num_sample
    valid_steps_per_epoch = modulo_ceil(
        len(valid_imgs), BATCH_SIZE) // BATCH_SIZE * valid_num_sample
    test_steps_per_epoch = modulo_ceil(
        len(test_imgs), BATCH_SIZE) // BATCH_SIZE * test_num_sample
    print('# train images =', len(train_imgs), '| train steps/epoch =',
          train_steps_per_epoch)
    print('# valid images =', len(valid_imgs), '| valid steps/epoch =',
          valid_steps_per_epoch)
    print('#  test images =', len(test_imgs), '|  test steps/epoch =',
          test_steps_per_epoch)

    if data_augmentation == 'bioseg':
        aug = augmenter(BATCH_SIZE,
                        IMG_SIZE,
                        1,
                        crop_before_augs=[
                            iaa.Fliplr(0.5),
                            iaa.Flipud(0.5),
                            iaa.Affine(rotate=(-180, 180), mode='reflect'),
                        ],
                        crop_after_augs=[
                            iaa.ElasticTransformation(alpha=(100, 200),
                                                      sigma=14,
                                                      mode='reflect'),
                        ])
    elif data_augmentation == 'manga_gb':
        aug = augmenter(BATCH_SIZE,
                        IMG_SIZE,
                        1,
                        crop_before_augs=[
                            iaa.Affine(rotate=(-3, 3),
                                       shear=(-3, 3),
                                       scale={
                                           'x': (0.8, 1.5),
                                           'y': (0.8, 1.5)
                                       },
                                       mode='reflect'),
                        ])
    elif data_augmentation == 'no_aug':
        aug = augmenter(BATCH_SIZE, IMG_SIZE, 1)

    if sqr_crop_dataset:
        aug = None

    my_gen = batch_gen(train_imgs, train_masks, BATCH_SIZE, aug)
    valid_gen = batch_gen(valid_imgs, valid_masks, BATCH_SIZE, aug)
    test_gen = batch_gen(test_imgs, test_masks, BATCH_SIZE, aug)
    #--------------------------------------------------------------------
    '''
    # DEBUG
    for ims,mas in my_gen:
        for im,ma in zip(ims,mas):
            cv2.imshow('i',im)
            cv2.imshow('m',ma); cv2.waitKey(0)
    '''
    #---------------------------- train model ---------------------------
    if kernel_init is None: kernel_init = 'he_normal'
    if num_maxpool is None: num_maxpool = 4
    if num_filters is None: num_filters = 64

    LEARNING_RATE = 1.0
    model = unet(pretrained_weights=loaded_model,
                 input_size=(IMG_SIZE, IMG_SIZE, 1),
                 kernel_init=kernel_init,
                 num_filters=num_filters,
                 num_maxpool=num_maxpool,
                 lr=LEARNING_RATE)

    model_checkpoint = ModelCheckpoint(save_model_path,
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True)
    train_timer = ElapsedTimer(experiment_yml_path + ' training')
    history = model.fit_generator(my_gen,
                                  epochs=NUM_EPOCHS,
                                  steps_per_epoch=train_steps_per_epoch,
                                  validation_steps=valid_steps_per_epoch,
                                  validation_data=valid_gen,
                                  callbacks=[model_checkpoint])
    train_time_str = train_timer.elapsed_time()
    #--------------------------------------------------------------------

    #--------------------------- save results ---------------------------
    origins = list(load_imgs(origin_dir))
    answers = list(load_imgs(answer_dir))
    assert len(origins) == len(answers)

    num_imgs = len(origins)

    if not sqr_crop_dataset:
        aug_det = augmenter(num_imgs, IMG_SIZE,
                            1).to_deterministic()  # no augmentation!
        origins = aug_det.augment_images(origins)
        answers = aug_det.augment_images(answers)

    predictions = model.predict_generator(
        (img.reshape(1, IMG_SIZE, IMG_SIZE, 1) for img in origins),
        num_imgs,
        verbose=1)
    evaluator.save_img_tuples(zip(origins, answers, predictions), result_dir)

    test_metrics = model.evaluate_generator(test_gen,
                                            steps=test_steps_per_epoch)
    K.clear_session()
    #print(model.metrics_names)
    #print(test_metrics)
    print('test set: loss =', test_metrics[0], '| IoU =', test_metrics[1])
    #--------------------------------------------------------------------

    #------------------- evaluation and save results --------------------
    with open(history_path, 'w') as f:
        f.write(
            yaml.dump(
                dict(
                    loss=list(map(np.asscalar, history.history['loss'])),
                    acc=list(map(np.asscalar, history.history['mean_iou'])),
                    val_loss=list(map(np.asscalar,
                                      history.history['val_loss'])),
                    val_acc=list(
                        map(np.asscalar, history.history['val_mean_iou'])),
                    test_loss=np.asscalar(test_metrics[0]),
                    test_acc=np.asscalar(test_metrics[1]),
                    train_time=train_time_str,
                )))

    modulo = 2**num_maxpool
    evaluator.eval_and_save_result(
        dataset_dir,
        save_model_path,
        eval_result_dirpath,
        files_2b_copied=[history_path, experiment_yml_path],
        num_filters=num_filters,
        num_maxpool=num_maxpool,
        modulo=modulo)
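
main() above reads only a fixed set of keys from the experiment YAML. As a hedged illustration (key names taken from the settings[...] lookups above; file names and values are purely hypothetical), such a file could be generated like this:

import yaml

example_settings = {
    'IMG_SIZE': 256,
    'BATCH_SIZE': 4,
    'NUM_EPOCHS': 10,
    'data_augmentation': 'no_aug',  # or 'bioseg' / 'manga_gb'
    'dataset_dir': './dataset',
    'save_model_path': './models/unet_example.h5',
    'history_path': './unet_example_history.yml',
    'eval_result_parent_dir': './eval_results',
    # Optional keys read with settings.get() and defaulted inside main():
    # sqr_crop_dataset, kernel_init, num_maxpool, num_filters, overlap_factor
}
with open('example_experiment.yml', 'w') as f:
    yaml.dump(example_settings, f)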
Example #3
if is_trainable:
    # Read the pickle file
    Data_A = read_pickle('./Data/Data_Train/Data_Left_train.pkl')
    Data_B = read_pickle('./Data/Data_Train/Data_Right_train.pkl')
    print("Data A/B: ", Data_A.shape, Data_B.shape)
    # Initialize the model
    assert Data_A.shape == Data_B.shape
    if len(Data_A.shape) == 4 and len(Data_B.shape) == 4:
        img_shape = (Data_A.shape[1], Data_A.shape[2], Data_A.shape[3])
        banis = BANIS(img_shape)
    else:
        print("The shape of input dataset don't match!!!")
    # Train the model and record the runtime
    timer = ElapsedTimer()
    banis.train(Data_A,
                Data_B,
                EPOCHS=n_epochs,
                BATCH_SIZE=128,
                WARMUP_STEP=n_step,
                NUM_IMG=5)
    timer.elapsed_time()
else:
    # Plotting the sampling images
    A_gen_list = np.load("./A_gen_baait.npy")
    plot_samples(A_gen_list, name='Agen')
    B_gen_list = np.load("./B_gen_baait.npy")
    plot_samples(B_gen_list, name='Bgen')
    AB_rec_list = np.load("./AB_rec_baait.npy")
    plot_samples(AB_rec_list, name='ABrec')
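
Every example on this page follows the same timing pattern: construct an ElapsedTimer (optionally with a label) before the long-running work, then call elapsed_time() afterwards to get a printable duration string. The class below is only a minimal sketch of that inferred interface, not any of the projects' actual implementations:

import time

class ElapsedTimer:
    # Minimal sketch: wall-clock timer with an optional label
    # (Example #5 passes 'Total Training', Example #1 passes nothing).
    def __init__(self, label=''):
        self.label = label
        self.start = time.time()

    def elapsed_time(self):
        # Return (and print) the elapsed wall-clock time as a string; callers
        # in these examples print it, write it to a log file, or mail it.
        msg = '%s: %.1f sec elapsed' % (self.label or 'elapsed time',
                                        time.time() - self.start)
        print(msg)
        return msg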
Example #4
File: main.py  Project: KUR-creative/unet


if __name__ == '__main__':
    with open('experiment_log', 'w') as log:
        for experiment_path in human_sorted(file_paths(sys.argv[1])):
            try:
                timer = ElapsedTimer(experiment_path)
                main(experiment_path)
                log.write(timer.elapsed_time())
            except AssertionError as error:
                print(str(error))
                log.write(str(error))
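
Assuming file_paths yields the files under the directory given on the command line, the driver above would be invoked as, for example, python main.py experiments/ ; every YAML file found there is run as one experiment, and either its elapsed-time string or the AssertionError that aborted it is appended to experiment_log.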
Example #5
def train(DATASET_NAME,
          NUM_EPOCH,
          Tc,
          Td,
          SAVE_INTERVAL,
          MAILING_ENABLED,
          learned_data_ratio,
          now_epoch=0,
          Cmodel=None,
          Dmodel=None,
          CDmodel=None):
    if (Cmodel is None) and (Dmodel is None) and (CDmodel is None):
        Cmodel, Dmodel, CDmodel = init_models()
    ''' model sanity checking
    from keras.utils import plot_model
    Cmodel.summary(); Dmodel.summary(); CDmodel.summary();
    plot_model(Cmodel, to_file='Cmodel.png', show_shapes=True)
    plot_model(Dmodel, to_file='Dmodel.png', show_shapes=True)
    plot_model(CDmodel, to_file='CDmodel.png', show_shapes=True)
    '''
    data_file = h5py.File(DATASET_NAME, 'r')
    #-------------------------------------------------------------------------------
    data_arr = data_file['images'][:]  # already preprocessed, float32.
    mean_pixel_value = data_file['mean_pixel_value'][()]  # value is float
    learned_arr_len = int(data_arr.shape[0] * learned_data_ratio)
    learned_arr_len -= learned_arr_len % BATCH_SIZE  # never use remainders: drop the partial final batch
    print('data_arr shape: ', data_arr.shape)
    print('length of data to learn: ', learned_arr_len)

    timer = ElapsedTimer('Total Training')
    #-------------------------------------------------------------------------------
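    # Phase schedule implemented by the loop below: epochs [0, Tc) train the
    # completion model C with MSE only; epochs [Tc, Tc+Td) train only the
    # discriminator D; from epoch Tc+Td onward D keeps training and C is also
    # updated through the joint MSE + GAN objective via CDmodel.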
    for epoch in range(now_epoch, NUM_EPOCH):
        #epoch_timer = ElapsedTimer()
        #--------------------------------------------------------------------------
        for batch in gen_batch(data_arr, BATCH_SIZE, IMG_SHAPE, LD_CROP_SIZE,
                               HOLE_MIN_LEN, HOLE_MAX_LEN, mean_pixel_value,
                               learned_arr_len):
            if epoch < Tc:
                mse_loss = trainC(Cmodel, batch, epoch)
            else:
                bce_d_loss = trainD(Cmodel, Dmodel, batch, epoch)
                if epoch >= Tc + Td:
                    joint_loss, mse, gan = trainC_in(CDmodel, batch, epoch)
        #--------------------------------------------------------------------------
        if epoch < Tc:
            print('epoch {}: [C mse loss: {}]'.format(epoch, mse_loss),
                  flush=True)  #, end='')
        else:
            if epoch >= Tc + Td:
                print('epoch {}: [joint loss: {} | mse loss: {}, gan loss: {}]'\
                       .format(epoch, joint_loss, mse, gan), flush=True)#, end='')
            else:
                print('epoch {}: [D bce loss: {}]'.format(epoch, bce_d_loss),
                      flush=True)  #, end='')
        #epoch_timer.elapsed_time()
        save(Cmodel, Dmodel, batch, SAVE_INTERVAL, epoch, NUM_EPOCH, 'output')
    #-------------------------------------------------------------------------------
    time_str = timer.elapsed_time()
    data_file.close()

    if MAILING_ENABLED:
        import mailing
        mailing.send_mail_to_kur(time_str)
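
A hedged sketch of calling train() directly; the hyper-parameter values are purely illustrative, and module-level constants such as BATCH_SIZE, IMG_SHAPE, LD_CROP_SIZE, HOLE_MIN_LEN and HOLE_MAX_LEN are assumed to be defined elsewhere in the file. The only dataset requirement visible above is an HDF5 file with an 'images' array and a scalar 'mean_pixel_value':

if __name__ == '__main__':
    train('faces_train.h5',      # hypothetical HDF5 dataset
          NUM_EPOCH=500,
          Tc=90,                 # completion-only epochs
          Td=10,                 # discriminator-only epochs before joint training
          SAVE_INTERVAL=10,
          MAILING_ENABLED=False,
          learned_data_ratio=1.0)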