# --- Inference / timing script fragment (flat script; k0, csm, niter,
# learnedSVT and weight_file are defined earlier in the file, outside this view).
# Reshape the single-example inputs into the 5-D batched layout the network expects.
# NOTE(review): per the transpose comment below, k0 arrives as
# (nx, ny, nt, nc) before the batch axis is added — TODO confirm against the loader.
csm = tf.expand_dims(csm, 2)  # insert a singleton axis (presumably the nt axis) into csm
k0 = tf.expand_dims(k0, 0)  # batch
csm = tf.expand_dims(csm, 0)  # batch
k0 = tf.transpose(
    k0, [0, 4, 3, 1, 2])  # nb, nx, ny, nt, nc -> nb, nc, nt, nx, ny
csm = tf.transpose(csm, [0, 4, 3, 1, 2])  # same layout change for the coil maps
# k0 = k0[:,:,0:18,:,:]
# csm = csm[:,:,0:18,:,:]
# Sampling mask recovered from the measured k-space: 1 where a sample
# was acquired (nonzero), 0 elsewhere; cast to complex for multiplication.
mask = tf.cast(tf.abs(k0) > 0, tf.complex64)
# initialize network
net = SLR_Net(mask, niter, learnedSVT)
net.load_weights(weight_file)  # weight_file is set elsewhere in the file
# Iterate over epochs.
# forward
# with tf.GradientTape() as tape:
t0 = time.time()
recon, X_SYM = net(k0, csm)  # forward pass only (no gradients); wall-clock timed via t0/t1
t1 = time.time()
recon_abs = tf.abs(recon)  # magnitude image for display / metric computation
# loss_total = mse(LSrecon, LplusS_label)
# --- Training setup script fragment (flat script; mask, mode, dataset_name,
# batch_size, net_name, niter, learnedSVT and learning_rate come from earlier
# in the file, outside this view).
# Bring the 2-D sampling mask into the 5-D (nb, nc, nt, nx, ny) data layout
# with broadcast singleton axes.
# NOTE(review): after the transpose, singleton axes are inserted at positions
# 0, 1 and 3 — verify this matches the k-space layout used in training.
mask = np.transpose(mask, [1, 0])
mask = np.reshape(mask, [1, 1, mask.shape[0], 1, mask.shape[1]])
mask = tf.cast(tf.constant(mask), tf.complex64)
# prepare dataset
dataset = get_dataset(mode, dataset_name, batch_size, shuffle=True, full=True)
#dataset = get_dataset('test', dataset_name, batch_size, shuffle=True, full=True)
tf.print('dataset loaded.')
# initialize network
# NOTE(review): indentation reconstructed from collapsed source — the
# checkpoint load is assumed to belong to the SLRNET branch; if net_name is
# anything else, `net` is never defined and later uses will raise NameError.
if net_name == 'SLRNET':
    net = SLR_Net(mask, niter, learnedSVT)
    # Hard-coded pretrained checkpoint used to warm-start training.
    weight_file = 'models/stable/2020-10-23T12-09-22SLRNET_DYNAMIC_V2_MULTICOIL8/epoch-50/ckpt'
    net.load_weights(weight_file)
tf.print('network initialized.')
learning_rate_org = learning_rate
learning_rate_decay = 0.95  # per-epoch decay factor (applied elsewhere, not visible here)
optimizer = tf.optimizers.Adam(learning_rate_org)
# Iterate over epochs.
total_step = 0  # global step counter across epochs
param_num = 0  # trainable-parameter count, filled in later
loss = 0  # running loss accumulator