Exemplo n.º 1
0
def train():
    """Train the descriptor network for 3D shape completion.

    All configuration is read from the module-level ``FLAGS``. Training
    shapes are loaded with ``data_io.getObj`` and each one is corrupted with
    ``get_incomplete_data``, which yields an incomplete copy and a mask.
    Every epoch restarts from the incomplete copies and, per mini-batch:

    1. runs the ``langevin_dynamics`` op on the incomplete batch and blends
       the result as ``sample * (1 - mask) + input * mask``, so that (for a
       binary mask) voxels with mask 1 keep their input values, then
    2. takes one Adam step on
       ``mean(descriptor(syn)) - mean(descriptor(obs))``.

    Every ``FLAGS.log_step`` epochs the recovered shapes, a checkpoint and
    loss/error plots are written below ``FLAGS.output_dir/FLAGS.category``.
    Returns nothing.
    """
    cube_len = FLAGS.cube_len
    output_dir = os.path.join(FLAGS.output_dir, FLAGS.category)
    checkpoint_dir = os.path.join(output_dir, 'checkpoints')
    synthesis_dir = os.path.join(output_dir, 'recovery')
    log_dir = os.path.join(output_dir, 'log')

    obs = tf.placeholder(tf.float32, [None, cube_len, cube_len, cube_len, 1],
                         name='obs_data')
    syn = tf.placeholder(tf.float32, [None, cube_len, cube_len, cube_len, 1],
                         name='syn_data')

    # The second call reuses the variables created by the first one.
    obs_res = descriptor(obs, reuse=False)
    syn_res = descriptor(syn, reuse=True)

    # Squared difference of the batch-mean voxel grids. This is a per-voxel
    # tensor; it is reduced to a scalar with np.mean only when reported.
    recon_err = tf.square(
        tf.reduce_mean(syn, axis=0) - tf.reduce_mean(obs, axis=0))
    des_loss = tf.subtract(tf.reduce_mean(syn_res, axis=0),
                           tf.reduce_mean(obs_res, axis=0))

    syn_langevin = langevin_dynamics(syn)

    train_data = data_io.getObj(FLAGS.data_path,
                                FLAGS.category,
                                train=True,
                                cube_len=cube_len,
                                num_voxels=FLAGS.train_size,
                                low_bound=0,
                                up_bound=1)
    num_voxels = len(train_data)

    # Build the corrupted inputs and their masks, one shape at a time.
    incomplete_data = np.zeros(train_data.shape)
    masks = np.zeros(train_data.shape)
    for i, voxels in enumerate(train_data):
        incomplete_data[i], masks[i] = get_incomplete_data(voxels)

    # Append the single channel axis expected by the placeholders.
    train_data = train_data[..., np.newaxis]
    incomplete_data = incomplete_data[..., np.newaxis]
    masks = masks[..., np.newaxis]

    # Nothing above creates output_dir, and it is written to right away.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    data_io.saveVoxelsToMat(train_data,
                            "%s/observed_data.mat" % output_dir,
                            cmin=0,
                            cmax=1)
    data_io.saveVoxelsToMat(incomplete_data,
                            "%s/incomplete_data.mat" % output_dir,
                            cmin=0,
                            cmax=1)

    # Center the data; the mean is added back when recoveries are saved.
    voxel_mean = train_data.mean()
    train_data = train_data - voxel_mean
    incomplete_data = incomplete_data - voxel_mean

    # float() keeps this a true ceiling under Python 2, where int / int
    # floors first and the last partial batch would be silently dropped.
    num_batches = int(math.ceil(num_voxels / float(FLAGS.batch_size)))

    # Descriptor variables are the trainable ones whose names start with 'des'.
    des_vars = [
        var for var in tf.trainable_variables() if var.name.startswith('des')
    ]

    des_optim = tf.train.AdamOptimizer(FLAGS.d_lr, beta1=FLAGS.beta1)
    des_grads_vars = des_optim.compute_gradients(des_loss, var_list=des_vars)
    # Mean absolute gradient of each '/w' variable; used for logging only.
    des_grads = [
        tf.reduce_mean(tf.abs(grad)) for (grad, var) in des_grads_vars
        if '/w' in var.name
    ]
    apply_d_grads = des_optim.apply_gradients(des_grads_vars)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(max_to_keep=50)

        # Overwritten batch by batch each epoch with the latest recoveries.
        recover_voxels = np.random.randn(num_voxels, cube_len, cube_len,
                                         cube_len, 1)

        des_loss_epoch = []
        recon_err_epoch = []
        plt.ion()

        for epoch in range(FLAGS.num_epochs):
            d_grad_vec = []
            des_loss_vec = []
            recon_err_vec = []

            # Every epoch restarts sampling from the incomplete shapes.
            init_data = incomplete_data.copy()
            start_time = time.time()
            for i in range(num_batches):
                indices = slice(i * FLAGS.batch_size,
                                min(num_voxels, (i + 1) * FLAGS.batch_size))
                obs_data = train_data[indices]
                syn_data = init_data[indices]
                data_mask = masks[indices]

                # Langevin sampling, then paste the kept input voxels back.
                sample = sess.run(syn_langevin, feed_dict={syn: syn_data})
                syn_data = sample * (1 - data_mask) + syn_data * data_mask

                # Learn the descriptor (D) net.
                d_grad, d_loss, _ = sess.run(
                    [des_grads, des_loss, apply_d_grads],
                    feed_dict={obs: obs_data, syn: syn_data})

                d_grad_vec.append(d_grad)
                des_loss_vec.append(d_loss)

                mse = sess.run(recon_err,
                               feed_dict={obs: obs_data, syn: syn_data})
                recon_err_vec.append(mse)
                recover_voxels[indices] = syn_data

            end_time = time.time()
            d_grad_mean = float(np.mean(d_grad_vec))
            des_loss_mean = float(np.mean(des_loss_vec))
            recon_err_mean = float(np.mean(recon_err_vec))
            des_loss_epoch.append(des_loss_mean)
            recon_err_epoch.append(recon_err_mean)
            print(
                'Epoch #%d, descriptor loss: %.4f, descriptor SSD weight: %.4f, Avg MSE: %4.4f, time: %.2fs'
                % (epoch, des_loss_mean, d_grad_mean, recon_err_mean,
                   end_time - start_time))

            if epoch % FLAGS.log_step == 0:
                if not os.path.exists(synthesis_dir):
                    os.makedirs(synthesis_dir)
                data_io.saveVoxelsToMat(
                    recover_voxels + voxel_mean,
                    "%s/sample%04d.mat" % (synthesis_dir, epoch))

                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                saver.save(sess,
                           "%s/%s" % (checkpoint_dir, 'model.ckpt'),
                           global_step=epoch)

                if not os.path.exists(log_dir):
                    os.makedirs(log_dir)
                plt.figure(1)
                data_io.draw_graph(plt, des_loss_epoch, 'des_loss', log_dir,
                                   'r')
                plt.figure(2)
                data_io.draw_graph(plt, recon_err_epoch, 'recon_error',
                                   log_dir, 'b')
Exemplo n.º 2
0
def test():
    """Complete held-out incomplete shapes with a trained descriptor.

    Restores the checkpoint ``FLAGS.ckpt`` and loads ``incomplete_test.mat``
    and ``masks.mat`` from ``FLAGS.incomp_data_path/FLAGS.category``. For
    each batch it runs the ``langevin_dynamics`` op and blends the result as
    ``sample * (1 - mask) + input * mask``, so that (for a binary mask)
    voxels with mask 1 keep their input values. The completed shapes are
    saved to ``FLAGS.output_dir/FLAGS.category/test/recovery.mat``.

    The training set is loaded only to compute ``voxel_mean``, which is
    subtracted before sampling and added back before saving.
    """
    assert FLAGS.ckpt is not None, 'no model provided.'
    cube_len = FLAGS.cube_len
    incomp_dir = os.path.join(FLAGS.incomp_data_path, FLAGS.category)
    test_dir = os.path.join(FLAGS.output_dir, FLAGS.category, 'test')

    syn = tf.placeholder(tf.float32, [None, cube_len, cube_len, cube_len, 1],
                         name='syn_data')
    # syn_res is not used below; presumably the call is kept because it
    # creates the descriptor variables that the Saver restores.
    syn_res = descriptor(syn, reuse=False)
    syn_langevin = langevin_dynamics(syn)

    train_data = data_io.getObj(FLAGS.data_path,
                                FLAGS.category,
                                train=True,
                                cube_len=cube_len,
                                num_voxels=FLAGS.train_size,
                                low_bound=0,
                                up_bound=1)

    incomplete_data = data_io.getVoxelsFromMat('%s/incomplete_test.mat' %
                                               incomp_dir,
                                               data_name='voxels')
    masks = np.array(io.loadmat(('%s/masks.mat' % incomp_dir))['masks'],
                     dtype=np.float32)

    sample_size = len(incomplete_data)

    # Append the single channel axis expected by the placeholder.
    masks = masks[..., np.newaxis]
    incomplete_data = incomplete_data[..., np.newaxis]
    voxel_mean = train_data.mean()
    incomplete_data = incomplete_data - voxel_mean
    # float() keeps this a true ceiling under Python 2, where int / int
    # floors first and the last partial batch would be silently dropped.
    num_batches = int(math.ceil(sample_size / float(FLAGS.batch_size)))

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        print('Loading checkpoint {}.'.format(FLAGS.ckpt))
        saver.restore(sess, FLAGS.ckpt)

        init_data = incomplete_data.copy()
        # Every row is overwritten by the batch loop below.
        sample_voxels = np.random.randn(sample_size, cube_len, cube_len,
                                        cube_len, 1)

        for i in range(num_batches):
            indices = slice(i * FLAGS.batch_size,
                            min(sample_size, (i + 1) * FLAGS.batch_size))
            syn_data = init_data[indices]
            data_mask = masks[indices]

            # Langevin sampling, then paste the kept input voxels back.
            sample = sess.run(syn_langevin, feed_dict={syn: syn_data})

            sample_voxels[indices] = sample * (
                1 - data_mask) + syn_data * data_mask

        if not os.path.exists(test_dir):
            os.makedirs(test_dir)
        data_io.saveVoxelsToMat(sample_voxels + voxel_mean,
                                "%s/recovery.mat" % test_dir,
                                cmin=0,
                                cmax=1)
Exemplo n.º 3
0
def test():
    """Super-resolve held-out voxel shapes with a trained descriptor.

    Loads the test shapes and derives low-resolution versions with
    ``downsample``. It brings those back to full size with ``upsample``;
    both are saved to the test directory. It then restores ``FLAGS.ckpt``
    and runs the ``langevin_dynamics`` op on the upsampled shapes. The
    results are saved to ``FLAGS.output_dir/FLAGS.category/test/samples.mat``.

    The training set is loaded only to compute ``voxel_mean``, which is
    subtracted before sampling and added back before saving.
    """
    assert FLAGS.ckpt is not None, 'no model provided.'
    cube_len = FLAGS.cube_len
    scale = FLAGS.scale
    batch_size = FLAGS.batch_size

    test_dir = os.path.join(FLAGS.output_dir, FLAGS.category, 'test')

    lr_size = cube_len // scale
    obs = tf.placeholder(tf.float32, [None, cube_len, cube_len, cube_len, 1],
                         name='obs_data')
    syn = tf.placeholder(tf.float32, [None, cube_len, cube_len, cube_len, 1],
                         name='syn_data')
    low_res = tf.placeholder(tf.float32, [None, lr_size, lr_size, lr_size, 1],
                             name='low_res')

    down_syn = downsample(obs, scale)
    up_syn = upsample(low_res, scale)

    # syn_res is not used below; presumably the call is kept because it
    # creates the descriptor variables that the Saver restores.
    syn_res = descriptor(syn, reuse=False)
    syn_langevin = langevin_dynamics(syn)
    # sr_res is not used in this function.
    sr_res = obs + syn - avg_pool(syn, scale)

    train_data = data_io.getObj(FLAGS.data_path,
                                FLAGS.category,
                                train=True,
                                cube_len=cube_len,
                                num_voxels=FLAGS.train_size)
    test_data = data_io.getObj(FLAGS.data_path,
                               FLAGS.category,
                               train=False,
                               cube_len=cube_len,
                               num_voxels=FLAGS.test_size)

    if not os.path.exists(test_dir):
        os.makedirs(test_dir)
    data_io.saveVoxelsToMat(test_data,
                            "%s/observed_data.mat" % test_dir,
                            cmin=0,
                            cmax=1)
    sample_size = len(test_data)
    # float() keeps this a true ceiling under Python 2, where int / int
    # floors first and the last partial batch would be silently dropped.
    num_batches = int(math.ceil(sample_size / float(batch_size)))

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        print('Generating low resolution data')
        test_data = test_data[..., np.newaxis]
        up_samples = np.zeros(test_data.shape)
        down_samples = np.zeros(
            shape=[sample_size, lr_size, lr_size, lr_size, 1])
        for i in range(num_batches):
            indices = slice(i * batch_size,
                            min(sample_size, (i + 1) * batch_size))
            obs_data = test_data[indices]
            ds = sess.run(down_syn, feed_dict={obs: obs_data})
            us = sess.run(up_syn, feed_dict={low_res: ds})
            down_samples[indices] = ds
            up_samples[indices] = us

        data_io.saveVoxelsToMat(down_samples,
                                "%s/down_sample.mat" % test_dir,
                                cmin=0,
                                cmax=1)
        data_io.saveVoxelsToMat(up_samples,
                                "%s/up_sample.mat" % test_dir,
                                cmin=0,
                                cmax=1)

        voxel_mean = train_data.mean()
        up_samples = up_samples - voxel_mean

        # Function-call form: the Python 2 print statement is a
        # SyntaxError on Python 3 and was the only one in the file.
        print('Loading checkpoint {}.'.format(FLAGS.ckpt))
        saver.restore(sess, FLAGS.ckpt)

        init_data = up_samples.copy()
        # Every row is overwritten by the batch loop below.
        sample_voxels = np.random.randn(sample_size, cube_len, cube_len,
                                        cube_len, 1)

        for i in range(num_batches):
            indices = slice(i * batch_size,
                            min(sample_size, (i + 1) * batch_size))
            us_data = init_data[indices]

            # Langevin sampling starting from the upsampled shapes.
            y1 = sess.run(syn_langevin, feed_dict={syn: us_data})

            sample_voxels[indices] = y1

        data_io.saveVoxelsToMat(sample_voxels + voxel_mean,
                                "%s/samples.mat" % test_dir,
                                cmin=0,
                                cmax=1)
Exemplo n.º 4
0
def train():
    """Train the descriptor network for 3D voxel super-resolution.

    All configuration is read from the module-level ``FLAGS``. Training
    shapes are downsampled by ``FLAGS.scale`` with ``downsample`` and
    brought back to full size with ``upsample``. Each epoch then, per
    mini-batch, runs the ``langevin_dynamics`` op starting from the
    upsampled shapes. It then takes one Adam step on
    ``mean(descriptor(syn)) - mean(descriptor(obs))``.

    Training stops early when the epoch's average reconstruction error
    exceeds 2 or becomes NaN. Every ``FLAGS.log_step`` epochs the samples,
    a checkpoint and loss/error plots are written below
    ``FLAGS.output_dir/FLAGS.category``. Returns nothing.
    """
    cube_len = FLAGS.cube_len
    scale = FLAGS.scale
    batch_size = FLAGS.batch_size

    output_dir = os.path.join(FLAGS.output_dir, FLAGS.category)
    checkpoint_dir = os.path.join(output_dir, 'checkpoints')
    synthesis_dir = os.path.join(output_dir, 'sr_results')
    log_dir = os.path.join(output_dir, 'log')

    lr_size = cube_len // scale
    obs = tf.placeholder(tf.float32, [None, cube_len, cube_len, cube_len, 1],
                         name='obs_data')
    syn = tf.placeholder(tf.float32, [None, cube_len, cube_len, cube_len, 1],
                         name='syn_data')
    low_res = tf.placeholder(tf.float32, [None, lr_size, lr_size, lr_size, 1],
                             name='low_res')

    down_syn = downsample(obs, scale)
    up_syn = upsample(low_res, scale)
    # The second call reuses the variables created by the first one.
    obs_res = descriptor(obs, reuse=False)
    syn_res = descriptor(syn, reuse=True)
    # sr_res is not used in this function.
    sr_res = obs + syn - avg_pool(syn, scale)

    # Squared difference of the batch-mean voxel grids. This is a per-voxel
    # tensor, not a scalar.
    recon_err = tf.square(
        tf.reduce_mean(syn, axis=0) - tf.reduce_mean(obs, axis=0))

    des_loss = tf.subtract(tf.reduce_mean(syn_res, axis=0),
                           tf.reduce_mean(obs_res, axis=0))

    syn_langevin = langevin_dynamics(syn)

    train_data = data_io.getObj(FLAGS.data_path,
                                FLAGS.category,
                                cube_len=cube_len,
                                num_voxels=FLAGS.train_size)
    num_voxels = len(train_data)

    # Append the single channel axis expected by the placeholders.
    train_data = train_data[..., np.newaxis]

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    data_io.saveVoxelsToMat(train_data,
                            "%s/observed_data.mat" % output_dir,
                            cmin=0,
                            cmax=1)

    # float() keeps this a true ceiling under Python 2, where int / int
    # floors first and the last partial batch would be silently dropped.
    num_batches = int(math.ceil(num_voxels / float(batch_size)))

    # Descriptor variables are the trainable ones whose names start with 'des'.
    des_vars = [
        var for var in tf.trainable_variables() if var.name.startswith('des')
    ]

    des_optim = tf.train.AdamOptimizer(FLAGS.d_lr, beta1=FLAGS.beta1)
    des_grads_vars = des_optim.compute_gradients(des_loss, var_list=des_vars)
    # Mean absolute gradient of each '/w' variable. It is built here but
    # not fetched by the training loop.
    des_grads = [
        tf.reduce_mean(tf.abs(grad)) for (grad, var) in des_grads_vars
        if '/w' in var.name
    ]
    apply_d_grads = des_optim.apply_gradients(des_grads_vars)

    saver = tf.train.Saver(max_to_keep=50)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        print('Generating low resolution data')
        up_samples = np.zeros(train_data.shape)
        down_samples = np.zeros(
            shape=[num_voxels, lr_size, lr_size, lr_size, 1])
        for i in range(num_batches):
            indices = slice(i * batch_size,
                            min(num_voxels, (i + 1) * batch_size))
            obs_data = train_data[indices]
            ds = sess.run(down_syn, feed_dict={obs: obs_data})
            us = sess.run(up_syn, feed_dict={low_res: ds})
            down_samples[indices] = ds
            up_samples[indices] = us

        data_io.saveVoxelsToMat(down_samples,
                                "%s/down_sample.mat" % output_dir,
                                cmin=0,
                                cmax=1)
        data_io.saveVoxelsToMat(up_samples,
                                "%s/up_sample.mat" % output_dir,
                                cmin=0,
                                cmax=1)

        # Center the data; the mean is added back when samples are saved.
        voxel_mean = train_data.mean()
        train_data = train_data - voxel_mean
        up_samples = up_samples - voxel_mean

        print('start training')
        # Overwritten batch by batch each epoch with the latest samples.
        sample_voxels = np.random.randn(num_voxels, cube_len, cube_len,
                                        cube_len, 1)

        des_loss_epoch = []
        recon_err_epoch = []
        plt.ion()

        for epoch in range(FLAGS.num_epochs):
            des_loss_vec = []
            recon_err_vec = []

            start_time = time.time()
            for i in range(num_batches):
                indices = slice(i * batch_size,
                                min(num_voxels, (i + 1) * batch_size))
                obs_data = train_data[indices]
                us_data = up_samples[indices]

                # Langevin sampling starting from the upsampled shapes.
                sr = sess.run(syn_langevin, feed_dict={syn: us_data})

                # Learn the descriptor (D) net.
                d_loss, _ = sess.run([des_loss, apply_d_grads],
                                     feed_dict={
                                         obs: obs_data,
                                         syn: sr
                                     })
                mse = sess.run(recon_err, feed_dict={obs: obs_data, syn: sr})
                recon_err_vec.append(mse)
                des_loss_vec.append(d_loss)

                sample_voxels[indices] = sr

            end_time = time.time()
            des_loss_mean = float(np.mean(des_loss_vec))
            recon_err_mean = float(np.mean(recon_err_vec))
            des_loss_epoch.append(des_loss_mean)
            recon_err_epoch.append(recon_err_mean)

            print(
                'Epoch #%d, descriptor loss: %.4f, avg MSE: %4.4f, time:%.2fs'
                %
                (epoch, des_loss_mean, recon_err_mean, end_time - start_time))

            # Divergence guard. The per-batch `mse` is a per-voxel array, so
            # `mse > 2` in a boolean context raised "truth value of an array
            # is ambiguous". The scalar epoch average is checked instead.
            if recon_err_mean > 2 or np.isnan(recon_err_mean):
                print('Reconstruction error diverged (%.4f); stopping.' %
                      recon_err_mean)
                break

            if epoch % FLAGS.log_step == 0:
                if not os.path.exists(synthesis_dir):
                    os.makedirs(synthesis_dir)
                data_io.saveVoxelsToMat(sample_voxels + voxel_mean,
                                        "%s/sample%04d.mat" %
                                        (synthesis_dir, epoch),
                                        cmin=0,
                                        cmax=1)

                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                saver.save(sess,
                           "%s/%s" % (checkpoint_dir, 'model.ckpt'),
                           global_step=epoch)

                if not os.path.exists(log_dir):
                    os.makedirs(log_dir)
                plt.figure(1)
                data_io.draw_graph(plt, des_loss_epoch, 'des_loss', log_dir,
                                   'r')
                plt.figure(2)
                data_io.draw_graph(plt, recon_err_epoch, 'recon_error',
                                   log_dir, 'b')
Exemplo n.º 5
0
def main(_):
    """Train a ``DescriptorNet3D`` on one voxel category.

    The argument is unused (presumably argv from an app runner such as
    ``tf.app.run`` — confirm). All configuration comes from the module-level
    ``FLAGS``, and ``FLAGS.num_batches`` is set here as a side effect.

    Warning: the ``log``, ``synthesis`` and ``checkpoints`` directories
    under ``FLAGS.output_dir/FLAGS.category`` are deleted and recreated, so
    outputs of a previous run there are lost.

    Each epoch, per mini-batch, the function does three things:

    1. It synthesizes samples with ``net.langevin_descriptor_noise`` for the
       first 100 epochs, then with ``net.langevin_descriptor``.
    2. It runs the descriptor gradient/loss update ops.
    3. It updates the reconstruction-error metric.

    ``net.apply_d_grads`` runs once per epoch, after ``net.reset_grads`` at
    the start. Judging by the op names, gradients are accumulated over the
    epoch and applied once.
    """
    RANDOM_SEED = 66
    np.random.seed(RANDOM_SEED)

    output_dir = os.path.join(FLAGS.output_dir, FLAGS.category)
    sample_dir = os.path.join(output_dir, 'synthesis')
    log_dir = os.path.join(output_dir, 'log')
    model_dir = os.path.join(output_dir, 'checkpoints')

    # Start from clean output directories (destroys previous results).
    for directory in (log_dir, sample_dir, model_dir):
        if tf.gfile.Exists(directory):
            tf.gfile.DeleteRecursively(directory)
        tf.gfile.MakeDirs(directory)

    # Prepare training data
    train_data = data_io.getObj(FLAGS.data_path,
                                FLAGS.category,
                                cube_len=FLAGS.cube_len,
                                num_voxels=FLAGS.train_size,
                                low_bound=0,
                                up_bound=1)

    data_io.saveVoxelsToMat(train_data,
                            "%s/observed_data.mat" % output_dir,
                            cmin=0,
                            cmax=1)

    # Preprocess training data: center it and add a channel axis.
    voxel_mean = train_data.mean()
    train_data = train_data - voxel_mean
    train_data = train_data[..., np.newaxis]

    # float() keeps this a true ceiling under Python 2, where int / int
    # floors first and the last partial batch would be silently dropped.
    FLAGS.num_batches = int(
        math.ceil(len(train_data) / float(FLAGS.batch_size)))

    print('Reading voxel data {}, shape: {}'.format(FLAGS.category,
                                                    train_data.shape))
    print('min: %.4f\tmax: %.4f\tmean: %.4f' %
          (train_data.min(), train_data.max(), voxel_mean))

    # create and build model
    net = DescriptorNet3D(FLAGS)
    net.build_model()

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # One slot of FLAGS.sample_batch samples per training batch; the
        # slots persist across epochs and seed the next epoch's sampling.
        sample_size = FLAGS.sample_batch * FLAGS.num_batches
        sample_voxels = np.random.randn(sample_size, FLAGS.cube_len,
                                        FLAGS.cube_len, FLAGS.cube_len, 1)

        saver = tf.train.Saver(max_to_keep=50)

        writer = tf.summary.FileWriter(log_dir, sess.graph)

        for epoch in range(FLAGS.num_epochs):
            d_grad_acc = []

            start_time = time.time()

            sess.run(net.reset_grads)
            for i in range(FLAGS.num_batches):

                obs_data = train_data[i * FLAGS.batch_size:
                                      min(len(train_data),
                                          (i + 1) * FLAGS.batch_size)]
                syn_data = sample_voxels[i * FLAGS.sample_batch:(i + 1) *
                                         FLAGS.sample_batch]

                # generate synthesized images
                if epoch < 100:
                    syn = sess.run(net.langevin_descriptor_noise,
                                   feed_dict={net.syn: syn_data})

                else:
                    syn = sess.run(net.langevin_descriptor,
                                   feed_dict={net.syn: syn_data})

                # learn D net
                d_grad = sess.run([
                    net.des_grads, net.des_loss_update, net.update_d_grads,
                    net.sample_loss_update
                ],
                                  feed_dict={
                                      net.obs: obs_data,
                                      net.syn: syn
                                  })[0]

                d_grad_acc.append(d_grad)

                # Compute L2 distance
                sess.run(net.recon_err_update,
                         feed_dict={
                             net.obs: obs_data,
                             net.syn: syn
                         })

                sample_voxels[i * FLAGS.sample_batch:(i + 1) *
                              FLAGS.sample_batch] = syn

            sess.run(net.apply_d_grads)
            [des_loss_avg, sample_loss_avg, mse, summary] = sess.run([
                net.des_loss_mean, net.sample_loss_mean, net.recon_err_mean,
                net.summary_op
            ])
            end_time = time.time()

            print(
                'Epoch #%d, descriptor loss: %.4f, descriptor SSD weight: %.4f, sample loss: %.4f, Avg MSE: %4.4f, time: %.2fs'
                % (epoch, des_loss_avg, float(np.mean(d_grad_acc)),
                   sample_loss_avg, mse, end_time - start_time))
            writer.add_summary(summary, epoch)

            if epoch % FLAGS.log_step == 0:
                if not os.path.exists(sample_dir):
                    os.makedirs(sample_dir)
                data_io.saveVoxelsToMat(sample_voxels + voxel_mean,
                                        "%s/sample%04d.mat" %
                                        (sample_dir, epoch),
                                        cmin=0,
                                        cmax=1)

                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                saver.save(sess,
                           "%s/%s" % (model_dir, 'net.ckpt'),
                           global_step=epoch)