Example #1
def vae_estimator(hparams):

    # Get a session
    sess = tf.Session()

    # Set up placeholders
    #A = tf.placeholder(tf.float32, shape=(hparams.batch_size, 100), name='A')
    y_batch = tf.placeholder(tf.float32, shape=(hparams.batch_size, hparams.n_input), name='y_batch')

    # Create the generator
    # TODO: Move z_batch definition here
    z_batch, x_hat_batch, restore_path, restore_dict = mnist_model_def.vae_gen(hparams)

    # measure the estimate
    y_hat_batch = tf.identity(x_hat_batch, name='y2_batch')

    # define all losses
    m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
    m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
    zp_loss_batch = tf.reduce_sum(z_batch**2, 1)

    # define total loss
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
                     + hparams.mloss2_weight * m_loss2_batch \
                     + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Compute means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)

    # Set up gradient descent
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    opt = utils.get_optimizer(learning_rate, hparams)
    update_op = opt.minimize(total_loss, var_list=var_list, global_step=global_step, name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    def estimator(y_batch_val, z_batch_val, hparams):
        """Function that returns the estimated image"""
        best_keeper = utils.BestKeeper(hparams)
        assign_z_opt_op = z_batch.assign(z_batch_val)

        feed_dict = {y_batch: y_batch_val}

        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            sess.run(assign_z_opt_op)
            for j in range(hparams.max_update_iter):
                _, lr_val, total_loss_val, \
                m_loss1_val, \
                m_loss2_val, \
                zp_loss_val = sess.run([update_op, learning_rate, total_loss,
                                        m_loss1,
                                        m_loss2,
                                        zp_loss], feed_dict=feed_dict)
                logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'
                print(logging_format.format(i, j, lr_val, total_loss_val,
                                            m_loss1_val,
                                            m_loss2_val,
                                            zp_loss_val))

            x_hat_batch_val, z_batch_val, total_loss_batch_val = sess.run([x_hat_batch, z_batch, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_hat_batch_val, z_batch_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
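
For context, a minimal sketch of driving the factory above: the graph is built and the VAE weights restored once, and the returned closure is then called per measurement batch. The `hparams` object and its latent size `n_z` are assumptions based on the attributes the code reads.

# Hypothetical driver for the factory above. Assumes an `hparams` object that
# carries the fields read inside vae_estimator (batch_size, n_input, the loss
# weights, num_random_restarts, max_update_iter) plus a latent size n_z.
import numpy as np

estimator = vae_estimator(hparams)  # builds the graph and restores weights once
y_batch_val = np.random.rand(hparams.batch_size, hparams.n_input)  # stand-in measurements
z_batch_val = np.random.randn(hparams.batch_size, hparams.n_z)     # latent initialization
x_hat_batch_val = estimator(y_batch_val, z_batch_val, hparams)     # best estimate over restarts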
Example #2
def vae_estimator(hparams):

    # Get a session
    tf.reset_default_graph()
    g1 = tf.Graph()
    with g1.as_default() as g:
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Set up placeholders
        A = tf.placeholder(tf.float32,
                           shape=(hparams.n_input, hparams.num_measurements),
                           name='A')
        y_batch = tf.placeholder(tf.float32,
                                 shape=(hparams.batch_size,
                                        hparams.num_measurements),
                                 name='y_batch')

        # Create the generator
        z_batch, x_hat_batch, _, restore_path, restore_dict = construct_gen(
            hparams, vae_model_def, 'gen')

        # measure the estimate
        if hparams.measurement_type == 'project':
            y_hat_batch = tf.identity(x_hat_batch, name='y_hat_batch')
        else:
            y_hat_batch = tf.matmul(x_hat_batch, A, name='y_hat_batch')

        # define all losses
        m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
        m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
        zp_loss_batch = tf.reduce_sum(z_batch**2, 1)

        # define total loss
        total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
                         + hparams.mloss2_weight * m_loss2_batch \
                         + hparams.zprior_weight * zp_loss_batch
        total_loss = tf.reduce_mean(total_loss_batch)

        # Compute means for logging
        m_loss1 = tf.reduce_mean(m_loss1_batch)
        m_loss2 = tf.reduce_mean(m_loss2_batch)
        zp_loss = tf.reduce_mean(zp_loss_batch)

        # Set up gradient descent
        var_list = [z_batch]
        global_step = tf.Variable(0, trainable=False, name='global_step')
        learning_rate = utils.get_learning_rate(global_step, hparams)
        opt = utils.get_optimizer(learning_rate, hparams)
        update_op = opt.minimize(total_loss,
                                 var_list=var_list,
                                 global_step=global_step,
                                 name='update_op')
        opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

        # Initialize and restore model parameters
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        restorer = tf.train.Saver(var_list=restore_dict)
        restorer.restore(sess, restore_path)

        def estimator(A_val, y_batch_val, hparams):
            """Function that returns the estimated image"""
            best_keeper = utils.BestKeeper(hparams)
            if hparams.measurement_type == 'project':
                feed_dict = {y_batch: y_batch_val}
            else:
                feed_dict = {A: A_val, y_batch: y_batch_val}
            for i in range(hparams.num_random_restarts):
                sess.run(opt_reinit_op)
                for j in range(hparams.max_update_iter):
                    _, lr_val, total_loss_val, \
                    m_loss1_val, \
                    m_loss2_val, \
                    zp_loss_val = sess.run([update_op, learning_rate, total_loss,
                                            m_loss1,
                                            m_loss2,
                                            zp_loss], feed_dict=feed_dict)

                x_hat_batch_val, total_loss_batch_val = sess.run(
                    [x_hat_batch, total_loss_batch], feed_dict=feed_dict)
                best_keeper.report(x_hat_batch_val, total_loss_batch_val)
            return best_keeper.get_best()

        return estimator
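
A minimal sketch of producing the inputs this closure expects in the non-'project' case. Only the placeholder shapes are fixed by the code above; the Gaussian measurement matrix and its scaling are assumptions.

# Hypothetical measurement setup: y = x A, with A of shape
# (n_input, num_measurements) and x_batch of shape (batch_size, n_input),
# matching the placeholders defined in vae_estimator above.
import numpy as np

A_val = np.random.randn(hparams.n_input, hparams.num_measurements) / np.sqrt(hparams.num_measurements)
x_batch_val = np.random.rand(hparams.batch_size, hparams.n_input)  # stand-in images
y_batch_val = np.dot(x_batch_val, A_val)

estimator = vae_estimator(hparams)
x_hat_batch_val = estimator(A_val, y_batch_val, hparams)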
Example #3
def dcgan_l1_estimator(hparams, model_type):
    # pylint: disable = C0326

    tf.reset_default_graph()
    g1 = tf.Graph()
    with g1.as_default() as g:
        # Set up placeholders
        A = tf.placeholder(tf.float32,
                           shape=(hparams.n_input, hparams.num_measurements),
                           name='A')
        y_batch = tf.placeholder(tf.float32,
                                 shape=(hparams.batch_size,
                                        hparams.num_measurements),
                                 name='y_batch')

        # Create the generator
        z_batch = tf.Variable(tf.random_normal([hparams.batch_size, 100]),
                              name='z_batch')
        x_hat_batch, restore_dict_gen, restore_path_gen = dcgan_gen(
            z_batch, hparams)

        # Create the discriminator
        prob, restore_dict_discrim, restore_path_discrim = dcgan_discrim(
            x_hat_batch, hparams)
        nu_estim = tf.get_variable("x_estim",
                                   dtype=tf.float32,
                                   shape=x_hat_batch.get_shape(),
                                   initializer=tf.constant_initializer(0))
        x_estim = nu_estim + x_hat_batch

        # measure the estimate
        if hparams.measurement_type == 'project':
            y_hat_batch = tf.identity(x_estim, name='y2_batch')
        else:
            y_hat_batch = tf.matmul(x_estim, A, name='y2_batch')

        # define all losses
        m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
        m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
        zp_loss_batch = tf.reduce_sum(z_batch**2, 1)
        d_loss1_batch = -tf.log(prob)
        d_loss2_batch = tf.log(1 - prob)

        if model_type == 'dcgan_l1':
            l1_loss = tf.reduce_sum(tf.abs(nu_estim), 1)
        elif model_type == 'dcgan_l1_wavelet':
            W = wavelet_basis()
            Winv = np.linalg.inv(W)
            l1_loss = tf.reduce_sum(
                tf.abs(tf.matmul(nu_estim, tf.constant(Winv,
                                                       dtype=tf.float32))), 1)
        elif model_type == 'dcgan_l1_dct':
            dct_proj = np.reshape(
                np.array([
                    dct2(np.eye(64)) for itr in range(hparams.batch_size * 3)
                ]), [hparams.batch_size, 3, 64, 64])
            nu_re = tf.transpose(tf.reshape(nu_estim, (-1, 64, 64, 3)),
                                 [0, 3, 1, 2])
            l1_loss = tf.reduce_sum(
                tf.abs(
                    tf.matmul(nu_re, tf.constant(dct_proj, dtype=tf.float32))),
                [1, 2, 3])
        else:
            raise ValueError('Unknown model_type: {}'.format(model_type))

        # define total loss
        total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
                         + hparams.mloss2_weight * m_loss2_batch \
                         + hparams.zprior_weight * zp_loss_batch \
                         + hparams.dloss1_weight * d_loss1_batch \
                         + hparams.dloss2_weight * d_loss2_batch \
                         + hparams.sparse_gen_weight * l1_loss
        total_loss = tf.reduce_mean(total_loss_batch)

        # Compute means for logging
        m_loss1 = tf.reduce_mean(m_loss1_batch)
        m_loss2 = tf.reduce_mean(m_loss2_batch)
        zp_loss = tf.reduce_mean(zp_loss_batch)
        d_loss1 = tf.reduce_mean(d_loss1_batch)
        d_loss2 = tf.reduce_mean(d_loss2_batch)

        # Set up gradient descent z_batch,
        var_list = [nu_estim, z_batch]
        global_step = tf.Variable(0, trainable=False, name='global_step')
        learning_rate = utils.get_learning_rate(global_step, hparams)
        with tf.variable_scope(tf.get_variable_scope(), reuse=False):
            opt = utils.get_optimizer(learning_rate, hparams)
            update_op = opt.minimize(total_loss,
                                     var_list=var_list,
                                     global_step=global_step,
                                     name='update_op')
            update_init_op = opt.minimize(total_loss,
                                          var_list=[z_batch],
                                          name='update_init_op')
            nu_estim_clip = nu_estim.assign(
                tf.maximum(tf.minimum(1.0 - x_hat_batch, nu_estim),
                           -1.0 - x_hat_batch))

        opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

        # Initialize and restore model parameters
        init_op = tf.global_variables_initializer()

    # Get a session
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(graph=g1, config=tf.ConfigProto(gpu_options=gpu_options,
                                                      allow_soft_placement=True))

    sess.run(init_op)
    restorer_gen = tf.train.Saver(var_list=restore_dict_gen)
    restorer_discrim = tf.train.Saver(var_list=restore_dict_discrim)
    restorer_gen.restore(sess, restore_path_gen)
    restorer_discrim.restore(sess, restore_path_discrim)

    def estimator(A_val, y_batch_val, hparams):
        """Function that returns the estimated image"""
        best_keeper = utils.BestKeeper(hparams)

        if hparams.measurement_type == 'project':
            feed_dict = {y_batch: y_batch_val}
        else:
            feed_dict = {A: A_val, y_batch: y_batch_val}

        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            if hparams.max_update_iter > 250:
                init_itr_no = 250
            else:
                init_itr_no = 0

            for j in range(init_itr_no):
                sess.run([update_init_op], feed_dict=feed_dict)
                x_estim_val, total_loss_batch_val = sess.run(
                    [x_estim, total_loss_batch], feed_dict=feed_dict)
                best_keeper.report(x_estim_val, total_loss_batch_val)

            for j in range(int(hparams.max_update_iter - init_itr_no)):
                _, lr_val, total_loss_val, \
                m_loss1_val, \
                m_loss2_val, \
                zp_loss_val, \
                d_loss1_val, \
                d_loss2_val = sess.run([update_op, learning_rate, total_loss,
                                        m_loss1,
                                        m_loss2,
                                        zp_loss,
                                        d_loss1,
                                        d_loss2], feed_dict=feed_dict)
                sess.run(nu_estim_clip)

            x_estim_val, total_loss_batch_val = sess.run(
                [x_estim, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_estim_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
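
The `nu_estim_clip` op above projects the sparse deviation so that `x_estim = x_hat_batch + nu_estim` stays in the DCGAN output range [-1, 1]. A small NumPy sketch of the same projection, with illustrative shapes:

# NumPy rendering of the nu_estim clipping above: project the sparse deviation
# nu so that x_hat + nu stays elementwise inside [-1, 1].
import numpy as np

def clip_deviation(nu, x_hat):
    return np.maximum(np.minimum(1.0 - x_hat, nu), -1.0 - x_hat)

x_hat = np.random.uniform(-1.0, 1.0, size=(4, 8))  # stand-in generator output
nu = np.random.randn(4, 8)                         # stand-in sparse deviation
assert np.all(np.abs(x_hat + clip_deviation(nu, x_hat)) <= 1.0 + 1e-9)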
Example #4
def dcgan_estimator(hparams):
    sess = tf.Session()
    y_batch = tf.placeholder(tf.float32,
                             shape=(hparams.batch_size, hparams.n_input),
                             name='y_batch')
    z_batch = tf.Variable(tf.random_normal([hparams.batch_size, 100]),
                          name='z_batch')
    x_hat_batch, restore_dict_gen, restore_path_gen = celebA_model_def.dcgan_gen(
        z_batch, hparams)
    prob, restore_dict_discrim, restore_path_discrim = celebA_model_def.dcgan_discrim(
        x_hat_batch, hparams)
    y_hat_batch = tf.identity(x_hat_batch, name='y2_batch')
    m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
    m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
    zp_loss_batch = tf.reduce_sum(z_batch**2, 1)
    d_loss1_batch = -tf.log(prob)
    d_loss2_batch = tf.log(1 - prob)
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)
    d_loss1 = tf.reduce_mean(d_loss1_batch)
    d_loss2 = tf.reduce_mean(d_loss2_batch)
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
        + hparams.mloss2_weight * m_loss2_batch \
        + hparams.zprior_weight * zp_loss_batch \
        + hparams.dloss1_weight * d_loss1_batch \
        + hparams.dloss2_weight * d_loss2_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        opt = utils.get_optimizer(learning_rate, hparams)
        update_op = opt.minimize(total_loss,
                                 var_list=var_list,
                                 global_step=global_step,
                                 name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer_gen = tf.train.Saver(var_list=restore_dict_gen)
    restorer_discrim = tf.train.Saver(var_list=restore_dict_discrim)
    restorer_gen.restore(sess, restore_path_gen)
    restorer_discrim.restore(sess, restore_path_discrim)

    def estimator(y_batch_val, z_batch_val, hparams):
        """Function that returns the estimated image"""
        best_keeper = utils.BestKeeper(hparams)
        assign_z_opt_op = z_batch.assign(z_batch_val)

        feed_dict = {y_batch: y_batch_val}

        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            sess.run(assign_z_opt_op)
            for j in range(hparams.max_update_iter):

                _, lr_val, total_loss_val, \
                    m_loss1_val, \
                    m_loss2_val, \
                    zp_loss_val, \
                    d_loss1_val, \
                    d_loss2_val = sess.run([update_op, learning_rate, total_loss,
                                            m_loss1,
                                            m_loss2,
                                            zp_loss,
                                            d_loss1,
                                            d_loss2], feed_dict=feed_dict)
                logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {} d_loss1 {} d_loss2 {}'
                print(logging_format.format(i, j, lr_val, total_loss_val,
                                            m_loss1_val, m_loss2_val,
                                            zp_loss_val, d_loss1_val,
                                            d_loss2_val))

            x_hat_batch_val, z_batch_val, total_loss_batch_val = sess.run(
                [x_hat_batch, z_batch, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_hat_batch_val, z_batch_val,
                               total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
Example #5
def vae_estimator(hparams):
    # Get a session
    sess = tf.Session()

    # Set up placeholders
    A = tf.placeholder(tf.float32,
                       shape=(hparams.n_input, hparams.num_measurements),
                       name='A')
    y_batch = tf.placeholder(tf.float32,
                             shape=(hparams.batch_size,
                                    hparams.num_measurements),
                             name='y_batch')

    # Create the generator
    # TODO: Move z_batch definition here
    z_batch, x_hat_batch, restore_path, restore_dict, _ = mnist_model_def.vae_gen(
        hparams)

    # measure the estimate
    if hparams.measurement_type == 'project':
        y_hat_batch = tf.identity(x_hat_batch, name='y_hat_batch')
    else:
        y_hat_batch = tf.matmul(x_hat_batch, A, name='y_hat_batch')

    # define all losses
    m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
    m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)

    # zp_loss_batch = tf.reduce_sum(z_batch**2, 1)
    if hparams.stdv > 0:
        norm_val = 1 / (hparams.stdv**2)
    else:
        norm_val = 1e+20

    zp_loss_batch = tf.reduce_sum(
        (z_batch - tf.ones(tf.shape(z_batch)) * hparams.mean)**2 * norm_val,
        1)  # added normalization

    # define total loss

    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
                     + hparams.mloss2_weight * m_loss2_batch \
                     + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Compute means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)

    # Set up gradient descent
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    opt = utils.get_optimizer(learning_rate, hparams)
    update_op = opt.minimize(total_loss,
                             var_list=var_list,
                             global_step=global_step,
                             name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    def estimator(A_val, y_batch_val, hparams):
        """Function that returns the estimated image"""
        best_keeper = utils.BestKeeper(hparams)
        if hparams.measurement_type == 'project':
            # if y_batch_val.shape[0] != hparams.batch_size:
            #     y_batch_val_tmp = np.zeros((hparams.batch_size, hparams.num_measurements))
            #     y_batch_val_tmp[:y_batch_val.shape[0], :] = y_batch_val
            #     y_batch_val = y_batch_val_tmp
            #     print('Smaller INPUT NUMBER')  # or change hparams on the fly
            feed_dict = {y_batch: y_batch_val}
        else:
            feed_dict = {A: A_val, y_batch: y_batch_val}
        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            for j in range(hparams.max_update_iter):
                _, lr_val, total_loss_val, \
                m_loss1_val, \
                m_loss2_val, \
                zp_loss_val = sess.run([update_op, learning_rate, total_loss,
                                        m_loss1,
                                        m_loss2,
                                        zp_loss], feed_dict=feed_dict)
                logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'
                print(
                    logging_format.format(i, j, lr_val, total_loss_val,
                                          m_loss1_val, m_loss2_val,
                                          zp_loss_val))
                # print('n_z is {}'.format(hparams.n_z))
                if total_loss_val == m_loss2_val and zp_loss_val > 0 and hparams.zprior_weight > 0:
                    raise ValueError('zp_loss is missing from total_loss despite a nonzero zprior_weight')

                if hparams.gif and ((j % hparams.gif_iter) == 0):
                    images = sess.run(x_hat_batch, feed_dict=feed_dict)
                    for im_num, image in enumerate(images):
                        save_dir = '{0}/{1}/'.format(hparams.gif_dir, im_num)
                        utils.set_up_dir(save_dir)
                        save_path = save_dir + '{0}.png'.format(j)
                        image = image.reshape(hparams.image_shape)
                        save_image(image, save_path)

            x_hat_batch_val, total_loss_batch_val = sess.run(
                [x_hat_batch, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_hat_batch_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
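
The normalized prior term in this example appears to be, up to a z-independent constant and a factor of 2, the negative log-density of `z_batch` under N(mean, stdv² I). A quick NumPy sanity check of that reading:

# Sanity check (NumPy): zp_loss equals -2 * log N(z; mean, stdv^2 I) up to the
# z-independent normalizing constant, which is dropped in the graph above.
import numpy as np

stdv, mean = 0.5, 0.0
z = np.random.randn(4, 20) * stdv + mean
zp_loss = np.sum((z - mean)**2 / stdv**2, axis=1)
log_density_unnorm = -0.5 * np.sum((z - mean)**2 / stdv**2, axis=1)  # + const
assert np.allclose(zp_loss, -2.0 * log_density_unnorm)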
Example #6
def stage_i(A_val, y_batch_val, hparams, hid_i, init_obj, early_stop, bs, optim, recovered=False):
    model_def = globals()['model_def']
    m_loss1_batch_dict = {}
    m_loss2_batch_dict = {}
    zp_loss_batch_dict = {}
    total_loss_dict = {}
    x_hat_batch_dict = {}
    model_selection = ModelSelect(hparams)
    hid_i = int(hid_i)
    # print('Matrix norm is {}'.format(np.linalg.norm(A_val)))
    # hparams.eps = hparams.eps * np.linalg.norm(A_val)

    # Get a session
    sess = tf.Session()

    # Set up placeholders
    A = tf.placeholder(tf.float32, shape=(hparams.n_input, hparams.num_measurements), name='A')

    y_batch = tf.placeholder(tf.float32, shape=(hparams.batch_size, hparams.num_measurements), name='y_batch')
    # Create the generator
    model_hparams = model_def.Hparams()
    model_hparams.n_z = hparams.n_z
    model_hparams.stdv = hparams.stdv
    model_hparams.mean = hparams.mean
    model_hparams.grid = copy.deepcopy(hparams.grid)
    model_selection.setup_dim(hid_i, model_hparams)

    if not hparams.model_types[0] == 'vae-flex-alt' and 'alt' in hparams.model_types[0]:
        model_def.ignore_grid = next((j for j in model_selection.dim_list if j >= hid_i), None)

    # set up the initialization
    print('The initialization is: {}'.format(init_obj.mode))
    if init_obj.mode == 'random':
        z_batch = model_def.get_z_var(model_hparams, hparams.batch_size, hid_i)
    elif init_obj.mode in ['previous-and-random', 'only-previous']:
        z_batch = model_def.get_z_var(model_hparams, hparams.batch_size, hid_i)
        init_op_par = tf.assign(z_batch, truncate_val(model_hparams, hparams, hid_i, init_obj, stdv=0))
    else:
        z_batch = truncate_val(model_hparams, hparams, hid_i, init_obj, stdv=0.1)
    _, x_hat_batch, _ = model_def.generator_i(model_hparams, z_batch, 'gen', hparams.bol, hid_i, relative=False)
    x_hat_batch_dict[hid_i] = x_hat_batch


    # measure the estimate
    if hparams.measurement_type == 'project':
        y_hat_batch = tf.identity(x_hat_batch, name='y_hat_batch')
    else:
        y_hat_batch = tf.matmul(x_hat_batch, A, name='y_hat_batch')

    # define all losses
    m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
    m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
    
    if hparams.stdv > 0:
        norm_val = 1 / (hparams.stdv**2)
    else:
        norm_val = 1e+20

    zp_loss_batch = tf.reduce_sum((z_batch - tf.ones(tf.shape(z_batch)) * hparams.mean)**2 * norm_val, 1)  # added normalization

    # define total loss
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
                     + hparams.mloss2_weight * m_loss2_batch \
                     + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)
    total_loss_dict[hid_i] = total_loss
    
    # Compute means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)
    
    m_loss1_batch_dict[hid_i] = m_loss1
    m_loss2_batch_dict[hid_i] = m_loss2
    zp_loss_batch_dict[hid_i] = zp_loss

    # Set up gradient descent
    var_list = [z_batch]
    if recovered:
        global_step = tf.Variable(hparams.optim.global_step, trainable=False, name='global_step')
    else:
        global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    opt = utils.get_optimizer(learning_rate, hparams)
    update_op = opt.minimize(total_loss, var_list=var_list, global_step=global_step, name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # restore the setting
    if 'alt' in hparams.model_types[0]:
        factor = 1
    else:
        factor = len(hparams.grid)
    # mutates the shared module object (call by reference); necessary since a
    # call of generator_i might change the batch size.
    model_def.batch_size = hparams.batch_size * factor
    model_selection.restore(sess, hid_i)

    if recovered:
        best_keeper = hparams.optim.best_keeper
    else:
        best_keeper = utils.BestKeeper(hparams, logg_z=True)
    if hparams.measurement_type == 'project':
        feed_dict = {y_batch: y_batch_val}
    else:
        feed_dict = {A: A_val, y_batch: y_batch_val}
    flag = False
    for i in range(init_obj.num_random_restarts):
        if recovered and i <= hparams.optim.i:  # loses the optimizer's state; a Keras implementation might handle this better
            if i < hparams.optim.i:
                continue
            else:
                sess.run(utils.get_opt_reinit_op(opt, [], global_step))
                sess.run(tf.assign(z_batch, hparams.optim.z_batch))
        else:
            sess.run(opt_reinit_op)
            if i < 1 and init_obj.mode in ['previous-and-random', 'only-previous']:
                print('Using previous outcome as starting point')
                sess.run(init_op_par)
        for j in range(hparams.max_update_iter):
            if recovered and j < hparams.optim.j:
                continue
            _, lr_val, total_loss_val, \
            m_loss1_val, \
            m_loss2_val, \
            zp_loss_val = sess.run([update_op, learning_rate, total_loss,
                                    m_loss1,
                                    m_loss2,
                                    zp_loss], feed_dict=feed_dict)         

            if hparams.gif and ((j % hparams.gif_iter) == 0):
                images = sess.run(x_hat_batch, feed_dict=feed_dict)
                for im_num, image in enumerate(images):
                    save_dir = '{0}/{1}/{2}/'.format(hparams.gif_dir, hid_i,im_num)
                    utils.set_up_dir(save_dir)
                    save_path = save_dir + '{0}.png'.format(j)
                    image = image.reshape(hparams.image_shape)
                    save_image(image, save_path)
            if j % 100 == 0 and early_stop:
                x_hat_batch_val = sess.run(x_hat_batch, feed_dict=feed_dict)
                if check_tolerance(hparams, A_val, x_hat_batch_val, y_batch_val)[1]:
                    flag = True
                    print('Early stopping')
                    break
            if j % 25 == 0:  # log only every 25 iterations, not every turn
                logging_format = 'hid {} rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'
                print(logging_format.format(hid_i, i, j, lr_val, total_loss_val,
                                            m_loss1_val,
                                            m_loss2_val,
                                            zp_loss_val))
            if j % 100 == 0:
                x_hat_batch_val, total_loss_batch_val, z_batch_val = sess.run([x_hat_batch, total_loss_batch, z_batch], feed_dict=feed_dict)
                best_keeper.report(x_hat_batch_val, total_loss_batch_val, z_val=z_batch_val)
                optim.global_step = sess.run(global_step)
                optim.A = A_val
                optim.y_batch = y_batch_val
                optim.i = i
                optim.j = j
                optim.z_batch = z_batch_val
                optim.best_keeper = best_keeper
                optim.bs = bs
                optim.init_obj = init_obj
                utils.save_to_pickle(optim, utils.get_checkpoint_dir(hparams, hparams.model_types[0]) + 'tmp/optim.pkl')
                print('Checkpoint of optimization created')

        hparams.optim.j = 0
        x_hat_batch_val, total_loss_batch_val, z_batch_val = sess.run([x_hat_batch, total_loss_batch, z_batch], feed_dict=feed_dict)
        best_keeper.report(x_hat_batch_val, total_loss_batch_val, z_val=z_batch_val)
        if flag:
            break
    tf.reset_default_graph()
    return best_keeper.get_best()
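
Because `stage_i` pickles its optimization state every 100 iterations, an interrupted run can be resumed. A hedged sketch of the resume path; the loading code is an assumption, while the attribute names match those saved above:

# Hypothetical resume driver: reload the snapshot pickled by stage_i and
# re-enter with recovered=True, so the restart/iteration loops skip ahead to
# (optim.i, optim.j) and reuse the saved z_batch and best_keeper.
# `hid_i` is the same hidden-dimension index used in the interrupted call.
import pickle

ckpt = utils.get_checkpoint_dir(hparams, hparams.model_types[0]) + 'tmp/optim.pkl'
with open(ckpt, 'rb') as f:
    hparams.optim = pickle.load(f)

result = stage_i(hparams.optim.A, hparams.optim.y_batch, hparams, hid_i,
                 hparams.optim.init_obj, early_stop=True,
                 bs=hparams.optim.bs, optim=hparams.optim, recovered=True)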
Example #7
def dcgan_estimator(hparams):
    # pylint: disable = C0326

    # Get a session
    sess = tf.Session()

    # Set up placeholders
    A = tf.placeholder(tf.float32,
                       shape=(hparams.n_input, hparams.num_measurements),
                       name='A')
    y_batch = tf.placeholder(tf.float32,
                             shape=(hparams.batch_size,
                                    hparams.num_measurements),
                             name='y_batch')

    # Create the generator
    z_batch = tf.Variable(tf.random_normal([hparams.batch_size, 100]),
                          name='z_batch')
    x_hat_batch, restore_dict_gen, restore_path_gen = celebA_model_def.dcgan_gen(
        z_batch, hparams)

    # Create the discriminator
    prob, restore_dict_discrim, restore_path_discrim = celebA_model_def.dcgan_discrim(
        x_hat_batch, hparams)

    # measure the estimate
    if hparams.measurement_type == 'project':
        y_hat_batch = tf.identity(x_hat_batch, name='y2_batch')
    else:
        measurement_is_sparse = (hparams.measurement_type
                                 in ['inpaint', 'superres'])
        y_hat_batch = tf.matmul(x_hat_batch,
                                A,
                                b_is_sparse=measurement_is_sparse,
                                name='y2_batch')

    # define all losses
    m_loss1_batch = tf.reduce_mean(tf.abs(y_batch - y_hat_batch), 1)
    m_loss2_batch = tf.reduce_mean((y_batch - y_hat_batch)**2, 1)
    zp_loss_batch = tf.reduce_sum(z_batch**2, 1)
    d_loss1_batch = -tf.log(prob)
    d_loss2_batch = tf.log(1 - prob)

    # define total loss
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
                     + hparams.mloss2_weight * m_loss2_batch \
                     + hparams.zprior_weight * zp_loss_batch \
                     + hparams.dloss1_weight * d_loss1_batch \
                     + hparams.dloss2_weight * d_loss2_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Compute means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)
    d_loss1 = tf.reduce_mean(d_loss1_batch)
    d_loss2 = tf.reduce_mean(d_loss2_batch)

    # Set up gradient descent
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        opt = utils.get_optimizer(learning_rate, hparams)
        update_op = opt.minimize(total_loss,
                                 var_list=var_list,
                                 global_step=global_step,
                                 name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer_gen = tf.train.Saver(var_list=restore_dict_gen)
    restorer_discrim = tf.train.Saver(var_list=restore_dict_discrim)
    restorer_gen.restore(sess, restore_path_gen)
    restorer_discrim.restore(sess, restore_path_discrim)

    def estimator(A_val, y_batch_val, hparams):
        """Function that returns the estimated image"""
        best_keeper = utils.BestKeeper(hparams)

        if hparams.measurement_type == 'project':
            feed_dict = {y_batch: y_batch_val}
        else:
            feed_dict = {A: A_val, y_batch: y_batch_val}

        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            for j in range(hparams.max_update_iter):
                if hparams.gif and ((j % hparams.gif_iter) == 0):
                    images = sess.run(x_hat_batch, feed_dict=feed_dict)
                    for im_num, image in enumerate(images):
                        save_dir = '{0}/{1}/'.format(hparams.gif_dir, im_num)
                        utils.set_up_dir(save_dir)
                        save_path = save_dir + '{0}.png'.format(j)
                        image = image.reshape(hparams.image_shape)
                        save_image(image, save_path)

                _, lr_val, total_loss_val, \
                m_loss1_val, \
                m_loss2_val, \
                zp_loss_val, \
                d_loss1_val, \
                d_loss2_val = sess.run([update_op, learning_rate, total_loss,
                                        m_loss1,
                                        m_loss2,
                                        zp_loss,
                                        d_loss1,
                                        d_loss2], feed_dict=feed_dict)
                logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {} d_loss1 {} d_loss2 {}'
                print(logging_format.format(i, j, lr_val, total_loss_val,
                                            m_loss1_val, m_loss2_val,
                                            zp_loss_val, d_loss1_val,
                                            d_loss2_val))

            x_hat_batch_val, total_loss_batch_val = sess.run(
                [x_hat_batch, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(x_hat_batch_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
def pggan_estimator(hparams):
    # pylint: disable = C0326

    # Get a session
    sess = tf.Session()

    # Set up placeholders
    Tx = tf.placeholder(tf.float32, shape=hparams.modSignal_shape, name='Tx')
    Rx = tf.placeholder(tf.float32, shape=hparams.modSignal_shape, name='Rx')
    Pilot = tf.placeholder(tf.float32,
                           shape=[hparams.batch_size, hparams.pilot_dim],
                           name='Pilot')

    # Create the generator
    z_batch = tf.Variable(tf.random.normal([hparams.batch_size,
                                            hparams.z_dim]),
                          name='z_batch')
    H_hat, restore_dict_gen, restore_path_gen = channel_model_def.pggan_gen(
        z_batch, Pilot, hparams)

    # measure the estimate
    print('H_hat:', H_hat.shape)
    print('Tx:', Tx.shape)
    Rx_hat = utils.calRx(H_hat, Tx, hparams)
    '''
    if hparams.measurement_type == 'project':
        y_hat_batch = tf.identity(x_hat_batch, name='y2_batch')
    elif hparams.measurement_type == 'pilot':
        Rx_hat = utils.calRx(H_hat,Tx,hparams)
        # Rx_hat = utils.multiComplex(H_hat,Tx);
        # Rx_hat = tf.multiply(H_hat, Tx, name='y_hat')  # TODO complex mult
    else:
        measurement_is_sparse = (hparams.measurement_type in ['inpaint', 'superres'])
        y_hat_batch = tf.matmul(x_hat_batch, A, b_is_sparse=measurement_is_sparse, name='y2_batch')
    '''

    # define all losses
    if hparams.measurement_type == 'pilot':
        # only pilot loss
        m_loss1_batch = tf.abs(
            utils.get_tf_pilot(Rx) - utils.get_tf_pilot(Rx_hat))
        m_loss2_batch = (utils.get_tf_pilot(Rx) -
                         utils.get_tf_pilot(Rx_hat))**2
        zp_loss_batch = tf.reduce_sum(z_batch**2, 1)
    else:
        m_loss1_batch = tf.reduce_mean(tf.abs(Rx - Rx_hat), 1)
        m_loss2_batch = tf.reduce_mean((Rx - Rx_hat)**2, 1)
        zp_loss_batch = tf.reduce_sum(z_batch**2, 1)

    # define total loss
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
                     + hparams.mloss2_weight * m_loss2_batch \
                     + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Compute means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)

    # Set up gradient descent
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        opt = utils.get_optimizer(learning_rate, hparams)
        update_op = opt.minimize(total_loss,
                                 var_list=var_list,
                                 global_step=global_step,
                                 name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer_gen = tf.train.Saver(var_list=restore_dict_gen)
    restorer_gen.restore(sess, restore_path_gen)

    def estimator(Tx_val, Rx_val, Pilot_val, hparams):
        """Function that returns the estimated image"""
        best_keeper = utils.BestKeeper(hparams)

        if hparams.measurement_type == 'project':
            # this graph defines no y_batch placeholder, so 'project' cannot be fed here
            raise NotImplementedError("measurement_type 'project' is not supported by pggan_estimator")
        else:
            feed_dict = {Tx: Tx_val, Rx: Rx_val, Pilot: Pilot_val}

        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            for j in range(hparams.max_update_iter):
                if hparams.gif and ((j % hparams.gif_iter) == 0):
                    images = sess.run(H_hat, feed_dict=feed_dict)  # snapshot the current channel estimate
                    for im_num, image in enumerate(images):
                        save_dir = '{0}/{1}/'.format(hparams.gif_dir, im_num)
                        utils.set_up_dir(save_dir)
                        save_path = save_dir + '{0}.png'.format(j)
                        image = image.reshape(hparams.image_shape)
                        save_image(image, save_path)

                _, lr_val, total_loss_val, \
                m_loss1_val, \
                m_loss2_val, \
                zp_loss_val = sess.run([update_op, learning_rate, total_loss,
                                        m_loss1,
                                        m_loss2,
                                        zp_loss], feed_dict=feed_dict)
                logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'
                print(logging_format.format(i, j, lr_val, total_loss_val,
                                            m_loss1_val, m_loss2_val,
                                            zp_loss_val))

            H_hat_val, total_loss_batch_val = sess.run([H_hat, total_loss_batch],
                                                       feed_dict=feed_dict)
            best_keeper.report(H_hat_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator
def vae_estimator(hparams):

    # Get a session
    sess = tf.Session()

    # Set up placeholders
    Tx = tf.placeholder(tf.float32, shape=hparams.image_shape, name='Tx')
    Rx = tf.placeholder(tf.float32, shape=hparams.image_shape, name='Rx')

    # Create the generator
    # TODO: Move z_batch definition here
    z_batch, H_hat, restore_path, restore_dict = channel_model_def.vae_gen(
        hparams)

    # measure the estimate
    if hparams.measurement_type == 'project':
        Rx_hat = tf.identity(H_hat, name='y_hat_batch')
    elif hparams.measurement_type == 'pilot':
        Rx_hat = utils.multiComplex(H_hat, Tx)
        # Rx_hat = tf.multiply(H_hat, Tx, name='y_hat')  # TODO complex mult
    else:
        Rx_hat = tf.multiply(H_hat, Tx, name='y_hat')

    # define all losses
    m_loss1_batch = tf.reduce_mean(tf.reduce_mean(tf.abs(Rx - Rx_hat), 1), 0)
    m_loss2_batch = tf.reduce_mean(tf.reduce_mean((Rx - Rx_hat)**2, 1), 0)
    zp_loss_batch = tf.reduce_sum(z_batch**2, 1)

    # define total loss
    total_loss_batch = hparams.mloss1_weight * m_loss1_batch \
                     + hparams.mloss2_weight * m_loss2_batch \
                     + hparams.zprior_weight * zp_loss_batch
    total_loss = tf.reduce_mean(total_loss_batch)

    # Compute means for logging
    m_loss1 = tf.reduce_mean(m_loss1_batch)
    m_loss2 = tf.reduce_mean(m_loss2_batch)
    zp_loss = tf.reduce_mean(zp_loss_batch)

    # Set up gradient descent
    var_list = [z_batch]
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = utils.get_learning_rate(global_step, hparams)
    opt = utils.get_optimizer(learning_rate, hparams)
    update_op = opt.minimize(total_loss,
                             var_list=var_list,
                             global_step=global_step,
                             name='update_op')
    opt_reinit_op = utils.get_opt_reinit_op(opt, var_list, global_step)

    # Initialize and restore model parameters
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    restorer = tf.train.Saver(var_list=restore_dict)
    restorer.restore(sess, restore_path)

    def estimator(Tx_val, Rx_val, hparams):
        """Function that returns the estimated image"""
        best_keeper = utils.BestKeeper(hparams)
        if hparams.measurement_type == 'project':
            feed_dict = {Rx: Rx_val}
        else:
            feed_dict = {Tx: Tx_val, Rx: Rx_val}
        for i in range(hparams.num_random_restarts):
            sess.run(opt_reinit_op)
            for j in range(hparams.max_update_iter):
                _, lr_val, total_loss_val, \
                m_loss1_val, \
                m_loss2_val, \
                zp_loss_val = sess.run([update_op, learning_rate, total_loss,
                                        m_loss1,
                                        m_loss2,
                                        zp_loss], feed_dict=feed_dict)
                logging_format = 'rr {} iter {} lr {} total_loss {} m_loss1 {} m_loss2 {} zp_loss {}'
                print(logging_format.format(i, j, lr_val, total_loss_val,
                                            m_loss1_val, m_loss2_val,
                                            zp_loss_val))

            H_hat_val, total_loss_batch_val = sess.run(
                [H_hat, total_loss_batch], feed_dict=feed_dict)
            best_keeper.report(H_hat_val, total_loss_batch_val)
        return best_keeper.get_best()

    return estimator