예제 #1
0
        # Insert a size-1 axis at position 2 of the coil sensitivity maps.
        # NOTE(review): presumably this gives csm a singleton time axis so that it
        # matches k0's (nx, ny, nt, nc) layout used by the transpose below --
        # confirm against the lines preceding this snippet.
        csm = tf.expand_dims(csm, 2)

        # Prepend a size-1 batch axis to both the k-space data and the coil maps.
        k0 = tf.expand_dims(k0, 0)  # batch
        csm = tf.expand_dims(csm, 0)  # batch

        # Reorder to channel-first: (nb, nx, ny, nt, nc) -> (nb, nc, nt, nx, ny).
        # The same permutation is applied to csm, which assumes csm has the same
        # 5-D axis order as k0 at this point -- TODO confirm.
        k0 = tf.transpose(
            k0, [0, 4, 3, 1, 2])  # nb, nx, ny, nt, nc -> nb, nc, nt, nx, ny
        csm = tf.transpose(csm, [0, 4, 3, 1, 2])

        # Disabled: keep only the first 18 entries along axis 2 (nt after the
        # transpose above).
        # k0 = k0[:,:,0:18,:,:]
        # csm = csm[:,:,0:18,:,:]

        # Mask derived from the data itself: 1+0j wherever |k0| > 0, 0+0j elsewhere,
        # cast to complex64 (presumably marks the acquired k-space samples).
        mask = tf.cast(tf.abs(k0) > 0, tf.complex64)

        # initialize network
        # Build SLR_Net from the derived mask plus `niter` / `learnedSVT` (defined
        # outside this snippet) and restore the weights stored at `weight_file`.
        net = SLR_Net(mask, niter, learnedSVT)
        net.load_weights(weight_file)

        # Forward pass. There is no epoch loop here, and the GradientTape below is
        # commented out, so the visible code records no gradients.

        # with tf.GradientTape() as tape:

        # Wall-clock timing around the network call.
        # NOTE(review): TensorFlow may dispatch device work asynchronously, and a
        # first call may include tracing overhead, so t1 - t0 may not reflect pure
        # compute time -- confirm before reporting it.
        t0 = time.time()
        recon, X_SYM = net(k0, csm)
        t1 = time.time()

        # Magnitude of the complex reconstruction.
        recon_abs = tf.abs(recon)

        # Disabled loss computation:
        # loss_total = mse(LSrecon, LplusS_label)
예제 #2
0
    # Swap the two axes of the 2-D mask, then reshape it to 5-D as
    # (1, 1, d0, 1, d1), where d0/d1 are the transposed mask's dimensions, and
    # convert it to a complex64 tensor.
    # NOTE(review): presumably shaped to broadcast against k-space laid out as
    # (nb, nc, nt, nx, ny) -- confirm which physical axes d0 and d1 map to.
    mask = np.transpose(mask, [1, 0])
    mask = np.reshape(mask, [1, 1, mask.shape[0], 1, mask.shape[1]])
    mask = tf.cast(tf.constant(mask), tf.complex64)

    # prepare dataset
    # Shuffled, full dataset for the requested `mode` split.
    dataset = get_dataset(mode,
                          dataset_name,
                          batch_size,
                          shuffle=True,
                          full=True)
    # Disabled alternative: always load the 'test' split.
    #dataset = get_dataset('test', dataset_name, batch_size, shuffle=True, full=True)
    tf.print('dataset loaded.')

    # initialize network
    # NOTE(review): `net` is only assigned when net_name == 'SLRNET'. Unless it is
    # bound earlier in the enclosing function (not visible here), any other value
    # makes net.load_weights below raise NameError (UnboundLocalError inside a
    # function). Consider an explicit error for unsupported net_name values.
    if net_name == 'SLRNET':
        net = SLR_Net(mask, niter, learnedSVT)

    # NOTE(review): the checkpoint path is hard-coded, so this always loads the
    # same epoch-50 checkpoint; consider making it a parameter.
    weight_file = 'models/stable/2020-10-23T12-09-22SLRNET_DYNAMIC_V2_MULTICOIL8/epoch-50/ckpt'
    net.load_weights(weight_file)

    tf.print('network initialized.')

    # Keep the initial learning rate; 0.95 is the decay factor.
    # NOTE(review): learning_rate_decay is not used in the visible lines --
    # presumably applied per epoch in the loop that follows; confirm.
    learning_rate_org = learning_rate
    learning_rate_decay = 0.95

    # Adam optimizer created with the initial learning rate.
    optimizer = tf.optimizers.Adam(learning_rate_org)

    # Iterate over epochs.
    # Counters/accumulators initialised to 0; presumably updated by the training
    # loop that follows this snippet (not shown).
    total_step = 0
    param_num = 0
    loss = 0