コード例 #1
0
def inference_stack(x_fn, in_depth, generator, output_folder, experiment_name):
    """Run ``generator`` over every sliding window of ``in_depth`` slices and save PNGs.

    The ``'images'`` dataset is read from the HDF5 file ``x_fn``. For every
    start index ``s_idx`` the slices ``[s_idx, s_idx + in_depth)`` along
    axis 0 are taken, axis 0 is moved to the end (so the window depth
    becomes the last axis), a leading batch axis is added, and the result
    is cast to float32 and passed to ``generator.predict``. Channel 0 of
    the first prediction is written to
    ``<output_folder>/<experiment_name>/<s_idx>.png`` via ``save2img``.

    Args:
        x_fn: path of the HDF5 file holding an ``'images'`` dataset whose
            per-window slice is 3-D (the ``(1, 2, 0)`` transpose requires it).
        in_depth: number of consecutive slices per input window.
        generator: object exposing ``predict`` (presumably a Keras model —
            TODO confirm).
        output_folder: base directory for results.
        experiment_name: subdirectory of ``output_folder`` written into; it
            is created if missing.
    """
    output = output_folder + "/" + experiment_name
    # Create the destination once rather than re-checking on every window.
    os.makedirs(output, exist_ok=True)

    # Keep the file open for the whole loop (datasets are read on slicing)
    # and guarantee it is closed afterwards.
    with h5py.File(x_fn, 'r') as h5_file:
        X = h5_file['images']
        stack_z_len = X.shape[0]

        # "+ 1" so the last full window [stack_z_len - in_depth, stack_z_len)
        # is processed too; the slice below is valid for that start index.
        for s_idx in range(stack_z_len - in_depth + 1):
            window = np.transpose(X[s_idx:(s_idx + in_depth)], (1, 2, 0))
            # astype returns a new array, so the cast result must be kept.
            batch_X = np.expand_dims(window, axis=0).astype(np.float32)

            pred_img = generator.predict(batch_X)

            save2img(pred_img[0, :, :, 0], output + "/" + '%s.png' % (s_idx))
コード例 #2
0
                                      step=current_it)
                    tf.summary.scalar('discriminator_loss_real',
                                      tf.reduce_mean(disc_real_o),
                                      step=current_it)
                    tf.summary.scalar('discriminator_loss_fake',
                                      tf.reduce_mean(disc_fake_o),
                                      step=current_it)

                if (current_it) % (save_every // gene_iters) == 0:
                    tf.print('current_iteration:',
                             current_it,
                             output_stream=sys.stdout)
                    tf.print('dataset_path:',
                             dataset_path,
                             output_stream=sys.stdout)
                    tf.print('current_slice:', idx, output_stream=sys.stdout)

                    pred_img, current_sinogram = generator.predict(
                        inpainted_sinogram_masked)
                    save2img(pred_img[0, :, :, 0],
                             '%s/it%05d.png' % (itr_out_dir, current_it))

                    generator.save("%s/gen-it%05d.h5" % (itr_out_dir, current_it), \
                                include_optimizer=True)

                    discriminator.save("%s/disc-it%05d.h5" % (itr_out_dir, current_it), \
                                include_optimizer=True)

                sys.stdout.flush()
                current_it += 1
コード例 #3
0
def main(args):
    """Train a ``unet()`` denoiser with a masked MSE loss and checkpoint it.

    Each of the ``args.maxep + 1`` iterations does the following:
        * pulls one minibatch from ``mb_data_iter``;
        * masks it with ``Masker(width=4, mode='zero')``;
        * takes one Adam step (lr 1e-3) on the MSE between the prediction
          and the clean input, both multiplied by the mask;
        * prints a progress line.

    Every 100 iterations the model is switched to eval mode and:
        * ``masker.infer_full_image`` and a plain forward pass are run on a
          fixed validation batch; that batch is fetched once, at iteration 0;
        * both outputs are saved as PNGs;
        * the model ``state_dict`` is written to ``itr_out_dir``.

    Relies on module-level names defined elsewhere: ``unet``, ``model_init``,
    ``Masker``, ``torch_devs``, ``mb_data_iter``, ``get1batch4test``,
    ``save2img`` and ``itr_out_dir``.

    Args:
        args: parsed options; reads ``args.maxep`` (last iteration index)
            and ``args.nvar`` (forwarded to ``get1batch4test``).
    """
    model = unet()
    # Alternative backbone: DnCNN(1, num_of_layers=8)
    model.apply(model_init)  # initialise weights and biases in place

    masker = Masker(width=4, mode='zero')

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        model = model.to(torch_devs)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)

    for epoch in range(args.maxep + 1):
        time_it_st = time.time()
        # The second element (the target) is not used by the masked loss.
        X_mb, _ = mb_data_iter.next()

        mdl_input, mask = masker.mask(X_mb, epoch)

        model.train()  # back to training mode (eval() is set at checkpoints)
        optimizer.zero_grad()
        pred = model(mdl_input)
        # The loss only counts pixels selected by the mask.
        loss = criterion(pred * mask, X_mb * mask)
        loss.backward()
        optimizer.step()

        itr_prints = '[Info] @ %.1f Epoch: %05d, loss: %.4f, elapse: %.2fs/itr' % (
            time.time(), epoch, loss.item(), (time.time() - time_it_st), )
        print(itr_prints)

        if epoch % 100 == 0:
            if epoch == 0:
                # Fixed validation batch, fetched once and reused at every checkpoint.
                val_ns, val_gt = get1batch4test(sidx=range(1),
                                                psz=2048,
                                                nvar=args.nvar)
                save2img(val_ns[0, 0].cpu().numpy(),
                         '%s/ns.png' % (itr_out_dir))
                save2img(val_gt[0, 0].cpu().numpy(),
                         '%s/gt.png' % (itr_out_dir))

            model.eval()  # inference mode for the validation passes
            with torch.no_grad():
                mdn = masker.infer_full_image(val_ns, model)
                ddn = model(val_ns)

            save2img(mdn[0, 0].cpu().numpy(),
                     '%s/mdn-it%05d.png' % (itr_out_dir, epoch))
            save2img(ddn[0, 0].cpu().numpy(),
                     '%s/ddn-it%05d.png' % (itr_out_dir, epoch))

            # Save the underlying module's weights so the checkpoint loads into
            # an unwrapped model; check the actual wrapper, not the GPU count.
            net = model.module if isinstance(model, torch.nn.DataParallel) else model
            torch.save(net.state_dict(),
                       "%s/mdl-it%05d.pth" % (itr_out_dir, epoch))
        sys.stdout.flush()
コード例 #4
0
            disc_loss = discriminator_loss(disc_real_o, disc_fake_o)

            disc_gradients = disc_tape.gradient(
                disc_loss, discriminator.trainable_variables)
            disc_optimizer.apply_gradients(
                zip(disc_gradients, discriminator.trainable_variables))

    print('%s; dloss: %.2f (r%.3f, f%.3f), disc_elapse: %.2fs/itr, gan_elapse: %.2fs/itr' % (itr_prints_gen,\
          disc_loss, disc_real_o.numpy().mean(), disc_fake_o.numpy().mean(), \
          (time.time() - time_dit_st)/disc_iters, time.time()-time_git_st))

    if epoch % (200 // gene_iters) == 0:
        X222, y222 = get1batch4test(x_fn=args.xtest,
                                    y_fn=args.ytest,
                                    in_depth=in_depth)
        pred_img = generator.predict(X222[:1])

        save2img(pred_img[0, :, :, 0], '%s/it%05d.png' % (itr_out_dir, epoch))
        if epoch == 0:
            save2img(y222[0, :, :, 0], '%s/gtruth.png' % (itr_out_dir))
            save2img(X222[0, :, :, in_depth // 2],
                     '%s/noisy.png' % (itr_out_dir))

        generator.save("%s/%s-it%05d.h5" % (itr_out_dir, args.expName, epoch), \
                       include_optimizer=False)

        # discriminator.save("%s/disc-it%05d.h5" % (itr_out_dir, epoch), \
        #                include_optimizer=False)

    sys.stdout.flush()
コード例 #5
0
            disc_fake_o = discriminator(gen_imgs, training=True)

            disc_loss = discriminator_loss(disc_real_o, disc_fake_o)

            disc_gradients = disc_tape.gradient(
                disc_loss, discriminator.trainable_variables)
            disc_optimizer.apply_gradients(
                zip(disc_gradients, discriminator.trainable_variables))

    print('%s; dloss: %.2f (r%.3f, f%.3f), disc_elapse: %.2fs/itr, gan_elapse: %.2fs/itr' % (itr_prints_gen,\
          disc_loss, disc_real_o.numpy().mean(), disc_fake_o.numpy().mean(), \
          (time.time() - time_dit_st)/args.itd, time.time()-time_git_st))

    if epoch % (500 // args.itg) == 0:
        X222, y222 = get1batch4test(dsfn=args.dsfn, in_depth=args.depth)
        pred_img = generator.predict(X222[:1])

        save2img(pred_img[0, :, :, 0], '%s/it%05d.png' % (itr_out_dir, epoch))
        if epoch == 0:
            save2img(y222[0, :, :, 0], '%s/gt.png' % (itr_out_dir))
            save2img(X222[0, :, :, args.depth // 2],
                     '%s/ns.png' % (itr_out_dir))

        generator.save("%s/%s-it%05d.h5" % (itr_out_dir, args.expName, epoch), \
                       include_optimizer=False)

        # discriminator.save("%s/disc-it%05d.h5" % (itr_out_dir, epoch), \
        #                include_optimizer=False)

    sys.stdout.flush()
コード例 #6
0
            disc_gradients = disc_tape.gradient(
                disc_loss, discriminator.trainable_variables)
            disc_optimizer.apply_gradients(
                zip(disc_gradients, discriminator.trainable_variables))

    print('%s; dloss: %.8f (r%.8f, f%.8f), disc_elapse: %.8fs/itr, gan_elapse: %.8fs/itr' % (itr_prints_gen,\
          disc_loss, disc_real_o.numpy().mean(), disc_fake_o.numpy().mean(), \
          (time.time() - time_dit_st)/disc_iters, time.time()-time_git_st))

    if epoch % (200 // gene_iters) == 0:
        X222, y222 = get1batch4test(x_fn=args.xtest,
                                    y_fn=args.ytest,
                                    in_depth=in_depth)
        pred_img = generator.predict(X222[:1])
        save2img(pred_img[0, :, :, 0],
                 '%s/it%05d.png' % (itr_out_dir + '/iteration_results', epoch))
        save2img(pred_img[0, :, :, 0], '%s/last_iteration.png' % (itr_out_dir))
        if epoch == 0:
            save2img(y222[0, :, :, 0], '%s/gtruth.png' % (itr_out_dir))
            save2img(X222[0, :, :, in_depth // 2],
                     '%s/noisy.png' % (itr_out_dir))

        generator.save("%s/%s-last-model.h5" % (itr_out_dir, args.expName),
                       include_optimizer=False)
        generator.save(
            "%s/%s-it%05d.h5" %
            (itr_out_dir + '/iteration_models', args.expName, epoch),
            include_optimizer=False)

        # discriminator.save("%s/disc-it%05d.h5" % (itr_out_dir, epoch), \
        #                include_optimizer=False)
コード例 #7
0
                tf.summary.scalar('discriminator_loss_real',
                                  tf.reduce_mean(disc_real_o),
                                  step=epoch)
                tf.summary.scalar('discriminator_loss_fake',
                                  tf.reduce_mean(disc_fake_o),
                                  step=epoch)

            if (epoch % save_every) == 0:

                tf.print('Iteration:', epoch, output_stream=sys.stdout)

                recon_imgs_final = fc_inverse_radon.predict(gen_prediction)

                current_sinogram, _ = generator.predict(
                    inpainted_sinogram_masked)
                save2img(recon_imgs_final[0, :, :, 0],
                         '%s/recon_it_%05d.png' % (itr_out_dir, epoch))
                save2img(current_sinogram[0, :, :, 0],
                         '%s/sinogram_%05d.png' % (itr_out_dir, epoch))

            sys.stdout.flush()
            epoch += 1

        recon_imgs_final = fc_inverse_radon.predict(gen_prediction)
        current_sinogram, _ = generator.predict(inpainted_sinogram_masked)
        save2img(recon_imgs_final[0, :, :, 0],
                 '%s/final_recon_%05d.png' % (final_recon_out_dir, idx))
        save2img(current_sinogram[0, :, :, 0],
                 '%s/sinogram_%05d.png' % (final_recon_out_dir, idx))
        save2img(inpainted_sinogram_masked.numpy()[0, :, :, 0],
                 '%s/original_sinogram_%05d.png' % (final_recon_out_dir, idx))