예제 #1
0
def predict_patch(img, patch_size, patch_step, nb_filters, nb_conv, batch_size,
                  wpath, spath):
    """
    Apply a trained patch-based CNN to every projection of a stack and save
    each reconstructed prediction as a float32 TIFF.

    Parameters
    ----------
    img : array
        Stack of projections, unpacked as ``(pn, iy, ix)`` after being passed
        through ``nor_data``; each ``img[i]`` is processed independently.
    patch_size : int
        Edge length of the square patches fed to the network.
    patch_step : int
        Step between patches, forwarded to ``extract_3d`` and
        ``reconstruct_patches``.
    nb_filters : int
        Forwarded to ``model_test`` when building the network.
    nb_conv : int
        Forwarded to ``model_test`` when building the network.
    batch_size : int
        Batch size passed to the model's ``predict``.
    wpath : str
        Path of the weights file loaded into the network.
    spath : str
        Output prefix; projection ``i`` is written to ``spath + 'prj-' + str(i)``
        (plain string concatenation, so include a trailing separator if
        ``spath`` is a directory).

    Returns
    -------
    None
        Predictions are only written to disk via ``dxchange.write_tiff``.
    """
    patch_shape = (patch_size, patch_size)
    stack = nor_data(img)
    n_proj, height, width = stack.shape

    net = model_test(patch_size, patch_size, nb_filters, nb_conv)
    net.load_weights(wpath)

    for idx in range(n_proj):
        print('Processing the %s th projection' % idx)
        t0 = time.time()

        # Cut the projection into patches and add a trailing channel axis
        # for the network input.
        patches = extract_3d(stack[idx], patch_shape, patch_step)
        n_patches = len(patches)
        net_in = np.reshape(patches, (n_patches, patch_size, patch_size, 1))

        # Predict, drop the channel axis again, and stitch the patches back
        # into a full (iy, ix) frame.
        net_out = net.predict(net_in, batch_size=batch_size)
        net_out = np.reshape(net_out, (n_patches, patch_size, patch_size))
        frame = reconstruct_patches(net_out, (height, width), patch_step)

        out_name = spath + 'prj-' + str(idx)
        dxchange.write_tiff(frame, out_name, dtype='float32')
        print('The prediction runs for %s seconds' % (time.time() - t0))
예제 #2
0
def train_patch(img_x, img_y, patch_size, patch_step, nb_filters, nb_conv,
                batch_size, nb_epoch):
    """
    Train a patch-based CNN that maps patches of ``img_x`` to the
    corresponding patches of ``img_y``.

    Parameters
    ----------
    img_x : array
        Input images; normalized with ``nor_data`` before patch extraction.
    img_y : array
        Target images; normalized with ``nor_data``. Presumably the same
        shape as ``img_x`` so that patches pair up one-to-one -- TODO confirm.
    patch_size : int
        Edge length of the square patches.
    patch_step : int
        Step between patches, forwarded to ``extract_3d``.
    nb_filters : int
        Forwarded to ``model_test`` when building the network.
    nb_conv : int
        Forwarded to ``model_test`` when building the network.
    batch_size : int
        Batch size passed to ``fit``.
    nb_epoch : int
        Number of training epochs (passed to ``fit`` as ``epochs``).

    Returns
    -------
    mdl
        The model built by ``model_test``, after training.

    Raises
    ------
    ValueError
        If ``img_x`` and ``img_y`` yield different numbers of patches.
    """
    patch_shape = (patch_size, patch_size)
    img_x = nor_data(img_x)
    img_y = nor_data(img_y)

    train_x = extract_3d(img_x, patch_shape, patch_step)
    # Python 3 print (the original used Python 2 print-statement syntax).
    print(train_x.shape)
    train_y = extract_3d(img_y, patch_shape, patch_step)
    if len(train_x) != len(train_y):
        raise ValueError(
            'img_x and img_y produced different patch counts: %d vs %d'
            % (len(train_x), len(train_y)))

    # Add a trailing channel axis expected by the network.
    train_x = np.reshape(train_x, (len(train_x), patch_size, patch_size, 1))
    train_y = np.reshape(train_y, (len(train_y), patch_size, patch_size, 1))

    mdl = model_test(patch_size, patch_size, nb_filters, nb_conv)
    # summary() prints the report itself and returns None; wrapping it in
    # print() would emit a stray "None" line.
    mdl.summary()
    mdl.fit(train_x,
            train_y,
            batch_size=batch_size,
            epochs=nb_epoch,
            shuffle=True)
    return mdl
예제 #3
0
파일: transform.py 프로젝트: ZichaoDi/Tao
def train(img_x, img_y, patch_size, patch_step, dim_img, nb_filters, nb_conv, batch_size, nb_epoch, x_test, y_test):
    """
    Train a ``mirror_model`` network on patches of ``img_x`` -> ``img_y``,
    validating on patches of ``x_test`` -> ``y_test``.

    Parameters
    ----------
    img_x : array
        Training inputs. NOTE(review): not passed through ``nor_data``
        (only the targets are) -- confirm this is intended.
    img_y : array
        Training targets; normalized with ``nor_data``.
    patch_size
        Forwarded to ``extract_patches``.
    patch_step
        Currently unused: ``extract_patches`` is always called with a
        hard-coded third argument of ``1``.
    dim_img : int
        Spatial size each patch is reshaped to, as ``(1, dim_img, dim_img)``
        (channel-first); also forwarded to ``mirror_model``.
    nb_filters : int
        Forwarded to ``mirror_model``.
    nb_conv : int
        Forwarded to ``mirror_model``.
    batch_size : int
        Batch size passed to ``fit``.
    nb_epoch : int
        Number of training epochs (passed to ``fit`` as ``epochs``).
    x_test : array
        Validation inputs (not normalized, mirroring ``img_x``).
    y_test : array
        Validation targets; normalized with ``nor_data``.

    Returns
    -------
    mdl
        The model built by ``mirror_model``, after training.

    Raises
    ------
    ValueError
        If the training inputs and targets yield different patch counts.
    """
    img_y = nor_data(img_y)
    img_input = extract_patches(img_x, patch_size, 1)
    img_output = extract_patches(img_y, patch_size, 1)
    if len(img_input) != len(img_output):
        raise ValueError(
            'img_x and img_y produced different patch counts: %d vs %d'
            % (len(img_input), len(img_output)))
    # Each array is reshaped with its own length (the original reused
    # len(img_input) for img_output).
    img_input = np.reshape(img_input, (len(img_input), 1, dim_img, dim_img))
    img_output = np.reshape(img_output, (len(img_output), 1, dim_img, dim_img))

    test_y = nor_data(y_test)
    test_x = extract_patches(x_test, patch_size, 1)
    test_y = extract_patches(test_y, patch_size, 1)
    test_x = np.reshape(test_x, (len(test_x), 1, dim_img, dim_img))
    test_y = np.reshape(test_y, (len(test_y), 1, dim_img, dim_img))

    mdl = mirror_model(dim_img, nb_filters, nb_conv)
    # summary() prints the report itself and returns None.
    mdl.summary()
    # `epochs` (not the legacy Keras-1 `nb_epoch`) matches the other
    # training functions in this module.
    mdl.fit(img_input, img_output, batch_size=batch_size, epochs=nb_epoch,
            validation_data=(test_x, test_y))
    return mdl
예제 #4
0
def train_full(img_x, img_y, nb_filters, nb_conv, batch_size, nb_epoch):
    """
    Train a full-image CNN mapping each image of ``img_x`` to the
    corresponding image of ``img_y``.

    Parameters
    ----------
    img_x : array
        Input stack with shape ``(pn, iy, ix)``; normalized with ``nor_data``.
    img_y : array
        Target stack; normalized with ``nor_data`` and reshaped to
        ``(pn, iy, ix, 1)``, so it must hold ``pn * iy * ix`` elements.
    nb_filters : int
        Forwarded to ``model`` when building the network.
    nb_conv : int
        Forwarded to ``model`` when building the network.
    batch_size : int
        Batch size passed to ``fit``.
    nb_epoch : int
        Number of training epochs (passed to ``fit`` as ``epochs``).

    Returns
    -------
    mdl
        The model built by ``model(iy, ix, ...)``, after training.
    """
    pn, iy, ix = img_x.shape

    img_x = nor_data(img_x)
    img_y = nor_data(img_y)

    # Add a trailing channel axis expected by the network.
    train_x = np.reshape(img_x, (pn, iy, ix, 1))
    train_y = np.reshape(img_y, (pn, iy, ix, 1))

    mdl = model(iy, ix, nb_filters, nb_conv)
    # summary() prints the report itself and returns None; wrapping it in
    # print() would emit a stray "None" line.
    mdl.summary()
    mdl.fit(train_x,
            train_y,
            batch_size=batch_size,
            epochs=nb_epoch,
            shuffle=True)
    return mdl