Example 1
## Initialize theta, which stores the DnCNN weights and biases for each unrolled layer
n_layers_trained = n_DAMP_layers
theta = [None] * n_layers_trained
for iter in range(n_layers_trained):
    with tf.variable_scope("Iter" + str(iter)):
        theta_thisIter = LDAMP.init_vars_DnCNN(init_mu, init_sigma)
    theta[iter] = theta_thisIter

## Construct model
y_measured = LDAMP.GenerateNoisyCSData_handles(x_true, A_handle, sigma_w,
                                               A_val_tf)
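
# The call above, in effect, applies the forward measurement handle to x_true and
# adds white Gaussian noise with standard deviation sigma_w. LDAMP passes the
# measurement operator around as "handles" (functions that apply the operator),
# so the same graph works for matrix-based and matrix-free measurements. Below is
# a minimal illustrative sketch of a dense-matrix pair of handles; the names and
# the (operator values, signal) argument order are assumptions, not the library's
# own definitions.
def A_handle_example(A_vals_tf, x):
    # Forward measurement: y = A x
    return tf.matmul(A_vals_tf, x)


def At_handle_example(A_vals_tf, z):
    # Adjoint of the measurement operator: A^T z
    return tf.matmul(A_vals_tf, z, transpose_a=True)
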
if alg == 'DAMP':
    (x_hat, MSE_history, NMSE_history, PSNR_history, r, rvar,
     dxdr) = LDAMP.LDAMP(y_measured,
                         A_handle,
                         At_handle,
                         A_val_tf,
                         theta,
                         x_true,
                         tie=tie_weights)
elif alg == 'DIT':
    (x_hat, MSE_history, NMSE_history,
     PSNR_history) = LDAMP.LDIT(y_measured,
                                A_handle,
                                At_handle,
                                A_val_tf,
                                theta,
                                x_true,
                                tie=tie_weights)
else:
    raise ValueError('alg was not a supported option')
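
# A hedged usage sketch: running the graph built above in a TensorFlow 1.x
# session. It assumes x_true and A_val_tf are placeholders and that x_test /
# A_val are NumPy arrays holding a vectorized test image and the sampled
# measurement matrix (illustrative names, not defined in this excerpt). In
# practice the trained weights would be restored with a tf.train.Saver instead
# of the random initialization shown here.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    x_rec, psnr_per_layer = sess.run([x_hat, PSNR_history],
                                     feed_dict={x_true: x_test,
                                                A_val_tf: A_val})
    print("PSNR after each unrolled layer:", psnr_per_layer)
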
Example 2
        n_layers_trained = 1
    else:
        n_layers_trained = n_DAMP_layers
    theta = [None] * n_layers_trained
    for iter in range(n_layers_trained):
        with tf.variable_scope("Iter" + str(iter)):
            theta_thisIter = LDAMP.init_vars_DnCNN(init_mu, init_sigma)
        theta[iter] = theta_thisIter

    ## Construct the measurement model and handles/placeholders
    [A_handle, At_handle, A_val, A_val_tf] = LDAMP.GenerateMeasurementOperators(measurement_mode)
    y_measured = LDAMP.GenerateNoisyCSData_handles(x_true, A_handle, sigma_w, A_val_tf)

    ## Construct the reconstruction model
    if alg == 'DAMP':
        (x_hat, MSE_history, NMSE_history, PSNR_history, r_final, rvar_final,
         div_overN) = LDAMP.LDAMP(y_measured, A_handle, At_handle, A_val_tf,
                                  theta, x_true, tie=tie_weights,
                                  training=training_tf,
                                  LayerbyLayer=LayerbyLayer)
    elif alg == 'DIT':
        (x_hat, MSE_history, NMSE_history,
         PSNR_history) = LDAMP.LDIT(y_measured, A_handle, At_handle, A_val_tf,
                                    theta, x_true, tie=tie_weights,
                                    training=training_tf,
                                    LayerbyLayer=LayerbyLayer)
    else:
        raise ValueError('alg was not a supported option')

    ## Define loss and determine which variables to train
    nfp = np.float32(height_img * width_img)  # number of pixels per image
    if loss_func == 'SURE':
        assert alg == 'DAMP', "Only LDAMP supports training with SURE"
        cost = LDAMP.MCSURE_loss(x_hat, div_overN, r_final, tf.sqrt(rvar_final))
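        # A hedged sketch of what a Monte-Carlo SURE objective evaluates for a
        # denoiser output x_hat computed from the noisy pseudo-data r_final with
        # effective noise variance rvar_final (this is standard SURE; whether
        # LDAMP.MCSURE_loss matches this form exactly is an assumption):
        #     SURE ~= ||x_hat - r_final||^2 / nfp - rvar_final
        #             + 2 * rvar_final * div_overN
        # The key property is that this estimates the reconstruction MSE without
        # ever using x_true, so the network can be trained from measurements alone.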
    elif loss_func == 'GSURE':
        assert alg == 'DAMP', "Only LDAMP currently supports training with GSURE"
        temp0 = tf.matmul(A_val_tf, A_val_tf, transpose_b=True)
        temp1 = tf.matrix_inverse(temp0)
        pinv_A = tf.matmul(A_val_tf, temp1, transpose_a=True)
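        # pinv_A above is the Moore-Penrose pseudo-inverse A^+ = A^T (A A^T)^-1,
        # which exists as written only when A has full row rank. A generalized
        # SURE (GSURE) objective typically uses it to work in the row space of A,
        # e.g. (a hedged sketch, not necessarily the exact form used here):
        #     P_range = tf.matmul(pinv_A, A_val_tf)    # projector A^+ A
        #     x_ML    = tf.matmul(pinv_A, y_measured)  # least-squares estimate A^+ y
        #     GSURE ~= ||P_range x_hat||^2 / nfp - 2 <x_hat, x_ML> / nfp + (divergence term)
        # so the loss compares x_hat with the data only through quantities that
        # the measurements actually constrain.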