def main(hparams):

    # Set up some stuff accoring to hparams
    hparams.n_input = np.prod(hparams.image_shape)
    utils.set_num_measurements(hparams)
    utils.print_hparams(hparams)

    # get inputs
    data_dict = model_input(hparams)

    estimator = utils.get_estimator(hparams, hparams.model_types[0])
    print(estimator)
    hparams.checkpoint_dir = utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    h_hats_dict = {model_type: {} for model_type in hparams.model_types}
    for key, x in data_dict.iteritems():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all([
                os.path.isfile(save_path) for save_path in save_paths.values()
            ])
            if is_saved:
                continue

        # Get Rx data
        Rx = data_dict[key]['Rx_data']
        Tx = data_dict[key]['Tx_data']
        H = data_dict[key]['H_data']
        Pilot_Rx = utils.get_pilot(Rx)
        print('Pilot_shape', Pilot_Rx.shape)
        Pilot_Rx = Pilot_Rx[0::2] + Pilot_Rx[1::2] * 1j
        Pilot_Tx = utils.get_pilot(Tx)
        Pilot_Tx = Pilot_Tx[0::2] + Pilot_Tx[1::2] * 1j
        Pilot_complex = Pilot_Rx / Pilot_Tx
        Pilot = np.empty((Pilot_complex.size * 2, ), dtype=Pilot_Rx.dtype)
        Pilot[0::2] = np.real(Pilot_complex)
        Pilot[1::2] = np.imag(Pilot_complex)

        Pilot = np.reshape(Pilot, [1, -1]) / 2.5
        # Construct estimates using each estimator
        h_hat = estimator(Tx, Rx, Pilot, hparams)

        # Compute and store measurement and l2 loss
        #        measurement_losses['dcgan'][key] = utils.get_measurement_loss(h_hat, Tx, Rx)
        #        l2_losses['dcgan'][key] = utils.get_l2_loss(h_hat, H)

        print "Processed upto image {0} / {1}".format(key + 1, len(data_dict))

        # Checkpointing
        if (hparams.save_images) and ((key + 1) % hparams.checkpoint_iter
                                      == 0):
            # utils.checkpoint(key,h_hat, measurement_losses, l2_losses, save_image, hparams)
            utils.save_channel_image(key + 1, h_hat, hparams)
            utils.save_channel_mat(key + 1, h_hat, hparams)

            print '\nProcessed and saved first ', key + 1, 'channels\n'
Ejemplo n.º 2
0
def main(hparams):
    """Run the 'vae' estimator on every sample and record its losses.

    For every sample returned by ``model_input``, the estimate
    ``estimator(Tx, Rx, hparams)`` is stored along with two losses:

    - its measurement loss (``utils.get_measurement_loss``);
    - its l2 loss against ``H_data`` (``utils.get_l2_loss``).

    Samples already saved by all estimators are skipped unless
    ``hparams.not_lazy`` is set.  When ``hparams.save_images`` is set,
    ``utils.checkpoint`` is called every ``hparams.checkpoint_iter`` samples.

    Args:
        hparams: hyperparameter object; ``n_input`` is written onto it.
    """
    # Set up some stuff according to hparams
    hparams.n_input = np.prod(hparams.image_shape)
    utils.set_num_measurements(hparams)
    utils.print_hparams(hparams)

    # get inputs
    data_dict = model_input(hparams)

    estimator = utils.get_estimator(hparams, 'vae')
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    # NOTE(review): indexed with 'vae' below, so 'vae' must be one of
    # hparams.model_types.
    h_hats_dict = {model_type: {} for model_type in hparams.model_types}
    for key, sample in data_dict.items():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            if all(os.path.isfile(path) for path in save_paths.values()):
                continue

        # Get Rx / Tx data.  Tx must come from the transmitted data; it was
        # previously read from 'Rx_data' by mistake.
        Rx = sample['Rx_data']
        Tx = sample['Tx_data']
        H = sample['H_data']

        # Construct estimates using each estimator
        h_hat = estimator(Tx, Rx, hparams)

        # Save the estimate
        h_hats_dict['vae'][key] = h_hat

        # Compute and store measurement and l2 loss
        measurement_losses['vae'][key] = utils.get_measurement_loss(
            h_hat, Tx, Rx)
        l2_losses['vae'][key] = utils.get_l2_loss(h_hat, H)

        print('Processed upto image {0} / {1}'.format(key + 1, len(data_dict)))

        # Checkpointing
        if hparams.save_images and (key + 1) % hparams.checkpoint_iter == 0:
            utils.checkpoint(key, h_hat, measurement_losses, l2_losses,
                             save_image, hparams)
            print('\nProcessed and saved first {0} channels\n'.format(key + 1))
Ejemplo n.º 3
0
def process_or_load_hparams(out_dir, default_hparams, hparams_path):
    """Build the final hyperparameters, persist them to ``out_dir`` and print them.

    ``default_hparams`` is first overridden by the file at ``hparams_path``
    (via ``utils.maybe_parse_standard_hparams``).  The result is extended by
    ``process_hparams``, checked by ``check_hparams``, saved with
    ``utils.save_hparams`` and printed.

    Returns:
        The processed hyperparameters.
    """
    merged = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)
    final_hparams = process_hparams(merged)
    check_hparams(final_hparams)

    utils.save_hparams(out_dir, final_hparams)

    print("Print hyperparameters:")
    utils.print_hparams(final_hparams)
    return final_hparams
Ejemplo n.º 4
0
Archivo: main.py Proyecto: efikarra/vae
def create_or_load_hparams(out_dir, default_hparams, flags):
    """Create hparams or load hparams from out_dir.

    If ``utils.load_hparams(out_dir)`` returns something truthy, it is
    reconciled with ``default_hparams`` and ``flags`` via
    ``utils.ensure_compatible_hparams``.  Otherwise:

    - start from ``default_hparams``, overridden by ``flags.hparams_path``;
    - add an ``x_dim`` entry equal to ``img_width * img_height``.

    The result is saved back to ``out_dir``, printed and returned.
    """
    loaded = utils.load_hparams(out_dir)
    if loaded:
        hparams = utils.ensure_compatible_hparams(loaded, default_hparams,
                                                  flags)
    else:
        hparams = utils.maybe_parse_standard_hparams(default_hparams,
                                                     flags.hparams_path)
        hparams.add_hparam("x_dim", hparams.img_width * hparams.img_height)

    utils.save_hparams(out_dir, hparams)
    utils.print_hparams(hparams)
    return hparams
Ejemplo n.º 5
0
def create_or_load_hparams(out_dir, default_hparams, flags):
    """Load hparams stored in ``out_dir`` or build fresh ones, then save and print them.

    When ``out_dir`` holds no hparams, ``default_hparams`` is overridden by
    ``flags.hparams_path`` and extended with ``extend_hparams``.  Otherwise
    the stored hparams are reconciled with the defaults and the command-line
    flags by ``utils.ensure_compatible_hparams``.

    Returns:
        The hparams that were saved to ``out_dir``.
    """
    stored = utils.load_hparams(out_dir)
    if stored:
        # NOTE(review): which side wins on a conflict (stored values vs.
        # command-line values) is decided inside utils.ensure_compatible_hparams.
        hparams = utils.ensure_compatible_hparams(stored, default_hparams,
                                                  flags)
    else:
        hparams = extend_hparams(
            utils.maybe_parse_standard_hparams(default_hparams,
                                               flags.hparams_path))

    utils.save_hparams(out_dir, hparams)

    print("Print hyperparameters:")
    utils.print_hparams(hparams)
    return hparams
Ejemplo n.º 6
0
def create_or_load_hparams(
    out_dir, default_hparams, hparams_path, save_hparams=True):
  """Create hparams or load hparams from out_dir.

  Hparams stored in ``out_dir`` are reconciled with the defaults by
  ``ensure_compatible_hparams``.  Without stored hparams, ``default_hparams``
  is overridden by ``hparams_path``.  Either way the result goes through
  ``extend_hparams``.

  When ``save_hparams`` is true, the result is written to ``out_dir`` and
  then to ``best_<metric>_dir`` for each metric in ``hparams.metrics``.
  The hparams are printed and returned.
  """
  stored = utils.load_hparams(out_dir)
  if stored:
    base = ensure_compatible_hparams(stored, default_hparams, hparams_path)
  else:
    base = utils.maybe_parse_standard_hparams(default_hparams, hparams_path)
  hparams = extend_hparams(base)

  if save_hparams:
    utils.save_hparams(out_dir, hparams)
    for metric in hparams.metrics:
      best_dir = getattr(hparams, "best_" + metric + "_dir")
      utils.save_hparams(best_dir, hparams)

  utils.print_hparams(hparams)
  return hparams
Ejemplo n.º 7
0
def main(hparams):
    """Reconstruct images batch-wise with an outer gradient loop around the DCGAN estimator.

    Images from ``model_input`` are grouped into batches of
    ``hparams.batch_size``.  Each batch is processed as follows:

    1. Form the linear measurements ``y = x A``, with
       ``A = utils.get_outer_A(hparams)``.
    2. Repeat ``hparams.max_outer_iter`` times: take a gradient step on
       ``0.5 * ||y - x A||^2`` (step ``hparams.outer_learning_rate``), then
       run the 'dcgan' estimator; its output is the next iterate.
    3. Store each image's estimate, measurement loss and l2 loss.

    Results are checkpointed every ``hparams.checkpoint_iter`` images and at
    the end.  Images that do not fill a final batch are not processed; a
    warning is printed.
    """
    # Set up some stuff according to hparams
    hparams.n_input = np.prod(hparams.image_shape)
    maxiter = hparams.max_outer_iter
    utils.print_hparams(hparams)

    # get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    estimator = estimators['dcgan']
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    x_hats_dict = {'dcgan': {}}
    x_batch_dict = {}
    for key, x in xs_dict.items():
        if hparams.lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            if all(os.path.isfile(path) for path in save_paths.values()):
                continue

        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Reshape input: one flattened image per row
        x_batch = np.concatenate(
            [img.reshape(1, hparams.n_input) for img in x_batch_dict.values()])

        # Construct measurements
        A_outer = utils.get_outer_A(hparams)
        y_batch_outer = np.matmul(x_batch, A_outer)

        x_main_batch = 0.0 * x_batch
        # NOTE(review): latent size 100 is hard-coded.
        z_opt_batch = np.random.randn(hparams.batch_size, 100)
        # With zero outer iterations the estimate is the initial iterate
        # (otherwise x_hat_batch would be unbound below).
        x_hat_batch = x_main_batch
        for _ in range(maxiter):
            residual = y_batch_outer - np.matmul(x_main_batch, A_outer)
            x_est_batch = x_main_batch + hparams.outer_learning_rate * np.matmul(
                residual, A_outer.T)
            x_hat_batch, z_opt_batch = estimator(x_est_batch, z_opt_batch,
                                                 hparams)
            x_main_batch = x_hat_batch

        # Row i of every batch array belongs to the i-th key of x_batch_dict.
        for i, batch_key in enumerate(x_batch_dict.keys()):
            x_hat = x_hat_batch[i]

            # Save the estimate
            x_hats_dict['dcgan'][batch_key] = x_hat

            # Compute and store measurement and l2 loss
            measurement_losses['dcgan'][batch_key] = utils.get_measurement_loss(
                x_hat, A_outer, y_batch_outer[i])
            l2_losses['dcgan'][batch_key] = utils.get_l2_loss(
                x_hat, xs_dict[batch_key])
        print('Processed upto image {0} / {1}'.format(key + 1, len(xs_dict)))

        # Checkpointing
        if hparams.save_images and (key + 1) % hparams.checkpoint_iter == 0:
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                             save_image, hparams)
            print('\nProcessed and saved first {0} images\n'.format(key + 1))

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                         save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            # list() so np.mean also works on Python 3 dict views
            mean_m_loss = np.mean(list(measurement_losses[model_type].values()))
            mean_l2_loss = np.mean(list(l2_losses[model_type].values()))
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if x_batch_dict:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
Ejemplo n.º 8
0
    def train(self,
              train_writer_path,
              log_f,
              restore_folder=None,
              restore_name=None):
        """Run the training loop with periodic validation and checkpointing.

        Does nothing unless ``self.is_training`` is set.  Otherwise:

        1. Initialize all variables and optionally restore
           ``restore_folder/restore_name``.
        2. Run ``cfg.epochs`` epochs of ``cfg.each_steps`` training steps.
        3. Every ``cfg.val_frq`` epochs, run ``cfg.val_steps`` validation
           batches and log the mean loss, accuracy and edit distance
           (``utils.normal_leven``).
        4. Every ``cfg.save_frq`` epochs, save a checkpoint under
           ``train_writer_path``.

        Exhausted input iterators are rewound and the step is retried.  Any
        other ``Exception`` is logged with its traceback and ends training.
        The ``finally`` clause writes ``end_checkPoint.model`` (when
        ``train_writer_path`` is set) and returns.  Returning from
        ``finally`` also suppresses exceptions not caught above, such as
        ``KeyboardInterrupt``.

        Args:
            train_writer_path: directory for summaries and checkpoints; a
                falsy value disables both.
            log_f: log target passed to ``utils.print_out``.
            restore_folder: optional directory of a checkpoint to restore.
            restore_name: checkpoint name inside ``restore_folder``.

        Returns:
            The number of the last epoch started, or ``None`` when
            ``self.is_training`` is false.
        """
        if not self.is_training:
            return
        sess = self.sess
        saver = self.saver
        if train_writer_path:
            train_writer = tf.summary.FileWriter(train_writer_path, sess.graph)
        sess.run(self.init)
        if restore_folder:
            restore_path = os.path.join(restore_folder, restore_name)
            saver.restore(sess, restore_path)
            utils.print_out('load from %s' % restore_path, log_f)
        global_step = 0
        epoch = 0
        utils.print_hparams(cfg, f=log_f)
        sess.run(self.init_iter_train)  # initialize the training input iterator
        utils.print_out('init_iter_train', log_f)
        sess.run(self.init_iter_vaild)  # initialize the validation input iterator
        utils.print_out('init_iter_vaild', log_f)
        learning_rate = 0
        loss = 0
        try:
            utils.print_out("start at %s" % get_local_time(), log_f)
            while epoch < cfg.epochs:
                epoch += 1
                i_steps = 0

                while i_steps < cfg.each_steps:
                    try:
                        _, loss, global_step, learning_rate, summary = \
                            sess.run([self.train_op, self.loss_train, self.global_step,
                                      self.learning_rate, self.train_summary])
                        i_steps += 1
                    except tf.errors.OutOfRangeError:
                        # Training data exhausted: rewind and retry.  Skip the
                        # logging below, which would repeat the previous step.
                        sess.run(self.init_iter_train)
                        continue
                    if global_step % cfg.print_frq == 0:
                        utils.print_out(
                            'epoch %d, step %d, gloSp %d, lr %.4f, loss %.4f' %
                            (epoch, i_steps, global_step, learning_rate, loss),
                            log_f)
                    if global_step % cfg.summary_frq == 0 and train_writer_path:
                        # Reuse the summary fetched with this step.  A separate
                        # sess.run(self.train_summary) would pull an extra batch
                        # and could raise OutOfRangeError outside the handler
                        # above, which would end training.
                        train_writer.add_summary(summary,
                                                 global_step=global_step)

                if epoch % cfg.val_frq == 0:

                    val_count = 0
                    val_loss = 0
                    val_acc = 0
                    val_edit_dist = 0
                    true_sample_words = ['']
                    pred_sample_words = ['']
                    while val_count < cfg.val_steps:
                        try:
                            i_loss, i_acc, true_sample_words, pred_sample_words, summary = \
                                sess.run([self.loss_vaild, self.accuracy_vaild, self.train_lookUpTgt_vaild,
                                          self.infer_lookUpTgt_vaild, self.vaild_summary])
                            val_count += 1
                        except tf.errors.OutOfRangeError:
                            # Rewind and retry without re-accumulating the
                            # previous batch's metrics.
                            sess.run(self.init_iter_vaild)
                            continue

                        val_loss += i_loss
                        val_acc += i_acc
                        batch_edit_dists = [
                            utils.normal_leven(t, p)
                            for t, p in zip(true_sample_words, pred_sample_words)
                        ]
                        if batch_edit_dists:
                            val_edit_dist += sum(batch_edit_dists) / float(
                                len(batch_edit_dists))

                    if val_count:
                        val_acc /= val_count
                        val_loss /= val_count
                        val_edit_dist /= val_count

                    styleTime = time.strftime("%Y_%m_%d_%H_%M_%S",
                                              time.localtime(int(time.time())))

                    utils.print_out(
                        '%s ### val loss %.4f, acc %.4f, edit_dist %.4f' %
                        (styleTime, val_loss, val_acc, val_edit_dist), log_f)
                    if train_writer_path and val_count:
                        train_writer.add_summary(summary,
                                                 global_step=global_step)
                    test_show_size = min(cfg.test_show_size,
                                         len(true_sample_words))
                    for i in range(test_show_size):
                        str_tr = ''.join(true_sample_words[i])
                        str_pd = ''.join(pred_sample_words[i])
                        utils.print_out("   ## true: %s" % (str_tr), log_f)
                        utils.print_out("      pred: %s" % (str_pd), log_f)
                if epoch % cfg.save_frq == 0 and train_writer_path:
                    checkPoint_path = os.path.join(train_writer_path,
                                                   "checkPoint.model")
                    saver.save(sess, checkPoint_path, global_step=global_step)
                    utils.print_out(
                        "   global step %d, check point save to %s-%d" %
                        (global_step, checkPoint_path, global_step), log_f)

        except Exception as e:
            utils.print_out(
                "!!!!  Interrupt ## end training, global step %d" %
                (global_step), log_f)
            if len(e.args) > 0:
                utils.print_out("An error occurred. {}".format(e.args[-1]),
                                log_f)
            traceback.print_exc()

        finally:
            if train_writer_path:
                checkPoint_path = os.path.join(train_writer_path,
                                               "end_checkPoint.model")
                saver.save(sess, checkPoint_path, global_step=global_step)
                utils.print_out(
                    "   end training, global step %d, check point save to %s-%d"
                    % (global_step, checkPoint_path, global_step), log_f)
            utils.print_out("end at %s" % get_local_time(), log_f)
            return epoch
Ejemplo n.º 9
0
def main(hparams):
    """Train a VAE on Omniglot, periodically saving image samples and checkpoints.

    Builds the encoder/generator from ``model_def`` and samples ``z`` with
    the reparameterization ``z = mean + sigma * eps``.  Trains with Adam on
    ``model_def.get_loss`` and resumes from the epoch returned by
    ``utils.try_restore``.

    Every 100 batches it writes original, reconstructed and freshly sampled
    images to ``hparams.sample_dir``.  The images are reshaped to 28x28 and
    tiled 10x10, which presumably assumes ``batch_size == 100``.
    Checkpoints go to ``hparams.ckpt_dir`` every ``hparams.ckpt_epoch``
    epochs and once at the end.
    """
    # Set up some stuff according to hparams
    utils.set_up_dir(hparams.ckpt_dir)
    utils.set_up_dir(hparams.sample_dir)
    utils.print_hparams(hparams)

    # encode
    x_ph = tf.placeholder(tf.float32, [None, hparams.n_input], name='x_ph')
    z_mean, z_log_sigma_sq = model_def.encoder(hparams, x_ph, 'enc',
                                               reuse=False)

    # sample: z = mean + sigma * eps with eps ~ N(0, 1)
    eps = tf.random_normal((hparams.batch_size, hparams.n_z), 0, 1,
                           dtype=tf.float32)
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))
    z = z_mean + z_sigma * eps

    # reconstruct
    logits, x_reconstr_mean = model_def.generator(hparams, z, 'gen',
                                                  reuse=False)

    # generator sampler (reuses the 'gen' scope); the placeholder was
    # previously also named 'x_ph', which TF silently renamed.
    z_ph = tf.placeholder(tf.float32, [None, hparams.n_z], name='z_ph')
    _, x_sample = model_def.generator(hparams, z_ph, 'gen', reuse=True)

    # define loss and update op
    total_loss = model_def.get_loss(x_ph, logits, z_mean, z_log_sigma_sq)
    opt = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
    update_op = opt.minimize(total_loss)

    # Sanity checks
    for var in tf.global_variables():
        print(var.op.name)
    print('')

    # Get a new session
    sess = tf.Session()

    # Model checkpointing setup
    model_saver = tf.train.Saver()

    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # Attempt to restore variables from checkpoint
    start_epoch = utils.try_restore(hparams, sess, model_saver)

    # Get data iterator
    iterator = data_input.omniglot_data_iterator()

    # Loop invariants
    save_path = os.path.join(hparams.ckpt_dir, 'omniglot_vae_model')
    num_batches = hparams.num_samples // hparams.batch_size

    # Training
    for epoch in range(start_epoch + 1, hparams.training_epochs):
        avg_loss = 0.0
        for batch_num, (x_batch_val, _) in enumerate(
                iterator(hparams, num_batches), 1):
            feed_dict = {x_ph: x_batch_val}
            _, loss_val = sess.run([update_op, total_loss],
                                   feed_dict=feed_dict)
            avg_loss += loss_val / hparams.num_samples * hparams.batch_size

            if batch_num % 100 == 0:
                x_reconstr_mean_val = sess.run(x_reconstr_mean,
                                               feed_dict={x_ph: x_batch_val})

                z_val = np.random.randn(hparams.batch_size, hparams.n_z)
                x_sample_val = sess.run(x_sample, feed_dict={z_ph: z_val})

                utils.save_images(
                    np.reshape(x_reconstr_mean_val, [-1, 28, 28]), [10, 10],
                    '{}/reconstr_{:02d}_{:04d}.png'.format(
                        hparams.sample_dir, epoch, batch_num))
                utils.save_images(
                    np.reshape(x_batch_val, [-1, 28, 28]), [10, 10],
                    '{}/orig_{:02d}_{:04d}.png'.format(hparams.sample_dir,
                                                       epoch, batch_num))
                utils.save_images(
                    np.reshape(x_sample_val, [-1, 28, 28]), [10, 10],
                    '{}/sampled_{:02d}_{:04d}.png'.format(
                        hparams.sample_dir, epoch, batch_num))

        if epoch % hparams.summary_epoch == 0:
            print('Epoch: {0:04d} Avg loss = {1:.9f}'.format(epoch, avg_loss))

        if epoch % hparams.ckpt_epoch == 0:
            model_saver.save(sess, save_path, global_step=epoch)

    model_saver.save(sess, save_path, global_step=hparams.training_epochs - 1)
    sess.close()
def main(hparams):
    """Reconstruct images from noisy linear measurements with every configured estimator.

    Images from ``model_input`` are processed in batches of
    ``hparams.batch_size``.  Measurements come from
    ``utils.get_measurements`` with ``A = utils.get_A(hparams)`` and a single
    noise draw ``hparams.noise_std * standard_t(df=2)``, reused for every
    batch.

    For each model type in ``hparams.model_types`` the following are stored
    per image:

    - the estimate;
    - its measurement loss;
    - its l2 loss;
    - its LPIPS score (``utils.get_lpips_score``);
    - its latent ``z``.

    Results are checkpointed every ``hparams.checkpoint_iter`` images, after
    which the in-memory estimates are cleared, and once more at the end.
    Images that do not fill a final batch are not processed.
    """
    # set up perceptual loss
    device = 'cuda:0'
    percept = PerceptualLoss(
            model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    utils.print_hparams(hparams)

    # get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses, lpips_scores, z_hats = utils.load_checkpoints(hparams)

    x_hats_dict = {model_type : {} for model_type in hparams.model_types}
    x_batch_dict = {}

    A = utils.get_A(hparams)
    # NOTE: one noise realization (Student-t, 2 degrees of freedom) is shared
    # by every batch.
    noise_batch = hparams.noise_std * np.random.standard_t(2, size=(hparams.batch_size, hparams.num_measurements))

    for key, x in xs_dict.items():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            if all(os.path.isfile(path) for path in save_paths.values()):
                continue

        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Reshape input: one flattened image per row
        x_batch = np.concatenate(
            [img.reshape(1, hparams.n_input) for img in x_batch_dict.values()])

        # Construct noisy measurements
        y_batch = utils.get_measurements(x_batch, A, noise_batch, hparams)

        # Construct estimates using each estimator
        for model_type in hparams.model_types:
            estimator = estimators[model_type]
            x_hat_batch, z_hat_batch, m_loss_batch = estimator(A, y_batch, hparams)

            # Row i of every batch output belongs to the i-th key of
            # x_batch_dict.  The estimator never sees image keys, so its
            # per-sample results are indexed by position i, not by key.
            for i, batch_key in enumerate(x_batch_dict.keys()):
                x_true = xs_dict[batch_key]
                x_hat = x_hat_batch[i]

                # Save the estimate
                x_hats_dict[model_type][batch_key] = x_hat

                # Compute and store measurement loss, l2 loss, LPIPS and z
                measurement_losses[model_type][batch_key] = m_loss_batch[i]
                l2_losses[model_type][batch_key] = utils.get_l2_loss(x_hat, x_true)
                lpips_scores[model_type][batch_key] = utils.get_lpips_score(percept, x_hat, x_true, hparams.image_shape)
                z_hats[model_type][batch_key] = z_hat_batch[i]

        print('Processed upto image {0} / {1}'.format(key+1, len(xs_dict)))

        # Checkpointing
        if (hparams.save_images) and ((key+1) % hparams.checkpoint_iter == 0):
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, lpips_scores, z_hats, save_image, hparams)
            x_hats_dict = {model_type : {} for model_type in hparams.model_types}
            print('\nProcessed and saved first ', key+1, 'images\n')

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, lpips_scores, z_hats, save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            mean_m_loss = np.mean(list(measurement_losses[model_type].values()))
            mean_l2_loss = np.mean(list(l2_losses[model_type].values()))
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if x_batch_dict:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
Ejemplo n.º 11
0
        'measurement-type',
        'noise-std',
        'num-measurements',
        'model-types',
        'zprior_weight',
        'dloss1_weight',
        'lmbd',
        'max-update-iter',
        'num-random-restarts',
    ]

    # Register one multi-valued CLI option per grid-searched hyperparameter,
    # e.g. `--noise-std 0.1 0.01`.  Values are kept as strings (type=str) and
    # default to the single value '0'.
    for hparam in HPARAM_NAMES_FOR_GRID_SEARCH:
        PARSER.add_argument('--' + hparam,
                            metavar=hparam + '-val',
                            type=str,
                            nargs='+',
                            default=['0'],
                            help='Values of ' + hparam)

    # Output location for the generated scripts.
    PARSER.add_argument(
        '--scripts-base-dir',
        type=str,
        default='../scripts/',
        help='Base directory to save scripts: Absolute path or relative to src'
    )

    # Parse the command line, echo the settings, and generate the scripts for
    # the requested grid.
    HPARAMS = PARSER.parse_args()
    utils.print_hparams(HPARAMS)

    create_scripts(HPARAMS, HPARAM_NAMES_FOR_GRID_SEARCH)
Ejemplo n.º 12
0
    # Select the hyperparameter set for the requested model.
    # NOTE(review): if model_name matches none of these branches, `hparam` is
    # left unassigned here (unless set earlier, outside this view), and the
    # print_hparams call below fails with a NameError/UnboundLocalError.
    if (model_name == "NFFM"):
        hparam = NFFM_params
    elif (model_name == "DeepFM"):
        hparam = DeepFM_params
    elif (model_name == "FFM"):
        hparam = FFM_params
    elif (model_name == "FM"):
        hparam = FM_params
    elif (model_name == "DCN"):
        hparam = DCN_params
    elif (model_name == "XDeepFM"):
        hparam = XDeepFM_params
    elif (model_name == "AFM"):
        hparam = AFM_params

    utils.print_hparams(hparam)  # print the model hyperparameters

    model = ctrNet.build_model(hparam)  # build the model

    print("Start " + model_name)
    model.train(train_data=(train_df[features], train_df['label']),
                dev_data=(dev_df[features], dev_df['label']))

    # NOTE(review): the *test* split is passed through the `dev_data` keyword;
    # presumably that is just the name of infer()'s parameter — confirm.
    preds = model.infer(dev_data=(test_df[features], test_df['label']))

    # Labels are shifted by +1 and pos_label=2 is used, so rows whose original
    # label equals 1 are treated as positives.
    fpr, tpr, thresholds = metrics.roc_curve(test_df['label'] + 1,
                                             preds,
                                             pos_label=2)
    auc = metrics.auc(fpr, tpr)
    print("last model auc {}".format(auc))
    print(model_name + " Done")
Ejemplo n.º 13
0
def main(hparams):
    """Train the VAE built by ``model_def.model`` and save samples and checkpoints.

    The model produces one reconstruction per entry of ``hparams.grid``.  A
    separate sampler is built for each grid entry with
    ``model_def.generator_i``, fed from a latent of size ``hparams.grid[-1]``.

    Training minimizes the sum of the returned loss terms with Adam and
    resumes from the epoch returned by ``utils.try_restore``.  Every 100
    batches, original, reconstructed and sampled images are written to
    ``hparams.sample_dir`` (reshaped to 28x28, tiled 10x10).  Checkpoints go
    to ``hparams.ckpt_dir`` every ``hparams.ckpt_epoch`` epochs and once at
    the end.
    """
    # Set up some stuff according to hparams
    utils.set_up_dir(hparams.ckpt_dir)
    utils.set_up_dir(hparams.sample_dir)
    utils.print_hparams(hparams)

    # encode / reconstruct; model_def.model also returns the loss terms
    x_ph = tf.placeholder(tf.float32, [None, hparams.n_input], name='x_ph')
    _, x_reconstr_mean, _, loss_list = model_def.model(
        hparams, x_ph, ['enc', 'gen'], [False, False])
    total_loss = tf.add_n(loss_list, name='total_loss')

    # generator samplers: one per grid entry, reusing the 'gen' scope.  The
    # placeholder was previously also named 'x_ph', which TF silently renamed.
    z_ph = tf.placeholder(tf.float32, [None, hparams.grid[-1]], name='z_ph')
    x_sample = []
    for i in range(len(hparams.grid)):
        _, x_sample_tmp, _ = model_def.generator_i(
            hparams, model_def.slicer_dec(hparams, i, None, z_ph), 'gen', True, i)
        x_sample.append(x_sample_tmp)

    # define update op
    opt = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
    update_op = opt.minimize(total_loss)

    # Sanity checks
    for var in tf.global_variables():
        print(var.op.name)
    print('')

    # Get a new session
    sess = tf.Session()

    # Model checkpointing setup
    model_saver = tf.train.Saver()

    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # Attempt to restore variables from checkpoint
    start_epoch = utils.try_restore(hparams, sess, model_saver)

    # Get data iterator
    iterator = data_input.mnist_data_iteratior(dataset=hparams.dataset)

    # Loop invariants.  The checkpoint name is e.g.
    # '<dataset>_vae_model_flex_hid' + '10_20' for grid [10, 20].
    model_name = '{}_vae_model_flex_hid{}'.format(
        hparams.dataset, '_'.join(map(str, hparams.grid)))
    save_path = os.path.join(hparams.ckpt_dir, model_name)
    num_batches = hparams.num_samples // hparams.batch_size

    # Training
    for epoch in range(start_epoch + 1, hparams.training_epochs):
        avg_loss = 0.0
        for batch_num, (x_batch_val, _) in enumerate(
                iterator(hparams, num_batches), 1):
            feed_dict = {x_ph: x_batch_val}
            _, loss_val = sess.run([update_op, total_loss], feed_dict=feed_dict)
            avg_loss += loss_val / hparams.num_samples * hparams.batch_size

            if batch_num % 100 == 0:
                x_reconstr_mean_val = sess.run(x_reconstr_mean, feed_dict={x_ph: x_batch_val})

                z_val = np.random.randn(hparams.batch_size, hparams.grid[-1])
                x_sample_val = sess.run(x_sample, feed_dict={z_ph: z_val})

                utils.save_images(np.reshape(x_batch_val, [-1, 28, 28]),
                                  [10, 10],
                                  '{}/orig_{:02d}_{:04d}.png'.format(hparams.sample_dir, epoch, batch_num))
                for i, grid_size in enumerate(hparams.grid):
                    utils.save_images(np.reshape(x_reconstr_mean_val[i], [-1, 28, 28]),
                                      [10, 10],
                                      '{}/reconstr_{}_{:02d}_{:04d}.png'.format(hparams.sample_dir, grid_size, epoch, batch_num))

                    utils.save_images(np.reshape(x_sample_val[i], [-1, 28, 28]),
                                      [10, 10],
                                      '{}/sampled_{}_{:02d}_{:04d}.png'.format(hparams.sample_dir, grid_size, epoch, batch_num))

        if epoch % hparams.summary_epoch == 0:
            print("Epoch:", '%04d' % (epoch), 'Avg loss = {:.9f}'.format(avg_loss))

        if epoch % hparams.ckpt_epoch == 0:
            model_saver.save(sess, save_path, global_step=epoch)

    model_saver.save(sess, save_path, global_step=hparams.training_epochs-1)
Ejemplo n.º 14
0
    # Log to <log_dir>/train.log and to stdout simultaneously.
    log_path = os.path.join(log_dir, 'train.log')

    # NOTE: basicConfig has no effect if the root logger already has handlers.
    logging.basicConfig(level=logging.DEBUG,  # root logger level (applies to both handlers)
                        format='%(asctime)s - %(levelname)s: %(message)s',  # log format
                        handlers=[
                            logging.FileHandler(log_path),
                            logging.StreamHandler(sys.stdout)]
                        )

    # NOTE(review): both loaders come from the same argument-less call —
    # confirm that evaluation really uses a different split than training.
    training_data = get_data_loader()
    evaluation_data = get_data_loader()

    logging.info(str(model))

    # NOTE(review): this logs the *return value* of print_hparams; if it only
    # prints and returns None, the log line is just "None".
    logging.info(str(print_hparams(hp)))

    logging.info('Data loaded!')
    logging.info('Data size: ' + str(len(training_data)))

    # Counts only parameters with requires_grad set (trainable parameters).
    logging.info('Total Model parameters: ' + f'{sum(p.numel() for p in model.parameters() if p.requires_grad):,}')

    # Starting epoch: parsed from a checkpoint file name such as '12.pt' when
    # hp.restore_path is set, otherwise 1.
    # NOTE(review): training restarts *at* that epoch rather than the next one —
    # confirm whether the file name denotes a completed epoch.
    epoch = int(os.path.basename(hp.restore_path).replace('.pt', '')) if hp.restore_path else 1

    if hp.mode == 'train':
        while epoch < hp.training_epochs + 1:
            epoch_start_time = time.time()
            train(training_data)
            scheduler.step()
            eval(evaluation_data)
            if epoch % hp.checkpoint_save_interval == 0:
Ejemplo n.º 15
0
def main(hparams):
    """Recover images from 1-bit (sign) measurements with the 'dcgan' estimator.

    Images returned by ``model_input(hparams)`` are grouped into batches of
    ``hparams.batch_size``. For each full batch the measurements are
    ``sign(x_batch @ A_outer)`` with ``A_outer = utils.get_outer_A(hparams)``.
    The reconstruction starts at zero and is refined for
    ``hparams.max_outer_iter`` outer iterations. Each iteration first takes a
    step of size ``hparams.outer_learning_rate`` along
    ``(y - sign(x @ A_outer)) @ A_outer.T``. It then passes the result and the
    current latent batch through ``estimators['dcgan']``.

    Per-image estimates and measurement/l2 losses are stored under 'dcgan'.
    They are checkpointed whenever ``key + 1`` is a multiple of
    ``hparams.checkpoint_iter``, and again at the end when
    ``hparams.save_images`` is set. They are summarised when
    ``hparams.print_stats`` is set. Images that do not fill a final batch are
    not processed, and a message reports how many.

    Args:
        hparams: hyper-parameter object. ``hparams.n_input`` is written here;
            the other attributes used below are only read.
    """
    hparams.n_input = np.prod(hparams.image_shape)
    maxiter = hparams.max_outer_iter
    utils.print_hparams(hparams)
    xs_dict = model_input(hparams)
    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)
    x_hats_dict = {'dcgan': {}}
    x_batch_dict = {}
    for key, x in xs_dict.iteritems():
        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue
        # Stack the batch into a (batch_size, n_input) matrix, one image per row.
        x_coll = [
            img.reshape(1, hparams.n_input)
            for _, img in x_batch_dict.iteritems()
        ]
        x_batch = np.concatenate(x_coll)
        A_outer = utils.get_outer_A(hparams)
        # 1bitify: keep only the sign of each linear measurement.
        y_batch_outer = np.sign(np.matmul(x_batch, A_outer))

        # The same estimator is used for every outer iteration.
        estimator = estimators['dcgan']
        x_main_batch = 0.0 * x_batch
        # NOTE(review): latent size 100 is hard-coded; confirm it matches the
        # dcgan generator's input dimension.
        z_opt_batch = np.random.randn(hparams.batch_size, 100)
        for _ in range(maxiter):
            residual = y_batch_outer - np.sign(np.matmul(x_main_batch, A_outer))
            x_est_batch = x_main_batch + hparams.outer_learning_rate * np.matmul(
                residual, A_outer.T)
            x_main_batch, z_opt_batch = estimator(x_est_batch, z_opt_batch,
                                                  hparams)
        # Read from x_main_batch so the result is defined even when maxiter == 0.
        x_hat_batch = x_main_batch

        for i, key in enumerate(x_batch_dict.keys()):
            x = xs_dict[key]
            y = y_batch_outer[i]
            x_hat = x_hat_batch[i]
            x_hats_dict['dcgan'][key] = x_hat
            measurement_losses['dcgan'][key] = utils.get_measurement_loss(
                x_hat, A_outer, y)
            l2_losses['dcgan'][key] = utils.get_l2_loss(x_hat, x)
        print('Processed upto image {0} / {1}'.format(key + 1, len(xs_dict)))
        if hparams.save_images and (key + 1) % hparams.checkpoint_iter == 0:
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                             save_image, hparams)
            print('\nProcessed and saved first {0} images\n'.format(key + 1))

        x_batch_dict = {}

    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                         save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            mean_m_loss = np.mean(list(measurement_losses[model_type].values()))
            mean_l2_loss = np.mean(list(l2_losses[model_type].values()))
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if len(x_batch_dict) > 0:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(
            len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
def main(hparams):
    """Reconstruct images from measurements with every estimator in hparams.model_types.

    Inputs come from ``model_input(hparams)`` as an image dict and an optional
    label dict. Images are collected into batches of ``hparams.batch_size``.
    For each full batch, ``utils.load_meas`` provides the measurement matrix,
    the noise and the measurements. The next step depends on
    ``hparams.measurement_type``:

    * ``'sample_distribution'``: ``plot_distribution`` is called on the batch;
    * ``'autoencoder'``: ``plot_reconstruction`` is called on the batch;
    * otherwise each estimator reconstructs the batch. Per-image measurement
      and l2 losses are then recorded in the ``utils.SaveHandler``, together
      with optional classifier/EMD losses and the distance ``||x_hat - x||``.

    When ``hparams.save_images`` is set, results are checkpointed and summary
    statistics are printed. Images that do not fill a final batch are
    reported as unprocessed.

    Args:
        hparams: hyper-parameter object. This function assigns ``n_input``,
            ``bol``, ``x``, ``optim`` and ``recovered`` on it, plus
            ``key_field`` for ``'dict-input'``.
    """
    # Set up some stuff according to hparams
    hparams.n_input = np.prod(hparams.image_shape)
    utils.set_num_measurements(hparams)
    utils.print_hparams(hparams)

    hparams.bol = False
    # get inputs
    if hparams.input_type == 'dict-input':
        # Load 'candidates.npy' from the checkpoint directory of the matching
        # 'full-input' / 'project' / zprior_weight=0 configuration.
        hparams_load_key = copy.copy(hparams)
        hparams_load_key.input_type = 'full-input'
        hparams_load_key.measurement_type = 'project'
        hparams_load_key.zprior_weight = 0.0
        hparams.key_field = np.load(utils.get_checkpoint_dir(hparams_load_key, hparams.model_types[0])+'candidates.npy').item()
        print(hparams.measurement_type)
    xs_dict, label_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    sh = utils.SaveHandler()
    sh.load_or_init_all(hparams.save_images, hparams.model_types,
                        sh.get_pkl_filepaths(hparams, use_all=True))
    if label_dict is None:
        print('No labels exist.')
        # NOTE(review): the losses below are stored in `sh.class_losses`;
        # confirm that deleting `class_loss` (singular) is intended.
        del sh.class_loss

    if hparams.input_type == 'gen-span':
        np.save(utils.get_checkpoint_dir(hparams, hparams.model_types[0])+'z.npy', hparams.z_from_gen)
        np.save(utils.get_checkpoint_dir(hparams, hparams.model_types[0])+'images.npy', hparams.images_mat)

    x_hats_dict = {model_type: {} for model_type in hparams.model_types}
    x_batch_dict = {}
    x_batch = []
    x_hat_batch = []
    # ||x_hat - x|| for each estimator and each processed image key.
    image_distances = {model_type: {} for model_type in hparams.model_types}
    hparams.x = []  # TO REMOVE
    for key, x in xs_dict.iteritems():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all([os.path.isfile(save_path) for save_path in save_paths.values()])
            if is_saved:
                continue

        x_batch_dict[key] = x
        hparams.x.append(x)  # TO REMOVE
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Reshape input into a (batch_size, n_input) matrix
        x_batch_list = [img.reshape(1, hparams.n_input) for _, img in x_batch_dict.iteritems()]
        x_batch = np.concatenate(x_batch_list)

        # Reuse a previous optimisation result only if it was run on this exact batch.
        recovered, optim = utils.load_if_optimized(hparams)
        if recovered and np.linalg.norm(optim.x_orig - x_batch) < 1e-10:
            hparams.optim = optim
            hparams.recovered = True
        else:
            hparams.recovered = False
            optim.x_orig = x_batch
            hparams.optim = optim

        # Construct noise and measurements
        A, noise_batch, y_batch, c_val = utils.load_meas(hparams, sh, x_batch, xs_dict)
        hparams.optim.noise_batch = noise_batch
        if c_val:
            # NOTE(review): the batch is not cleared on this path, so the next
            # image is appended to the current batch; confirm this is intended.
            continue

        if hparams.measurement_type == 'sample_distribution':
            plot_distribution(hparams, x_batch)
        elif hparams.measurement_type == 'autoencoder':
            plot_reconstruction(hparams, x_batch)
        else:
            # Construct estimates using each estimator
            for model_type in hparams.model_types:
                estimator = estimators[model_type]
                start = time.time()

                tmp = estimator(A, y_batch, hparams)
                # Estimators may return (x_hat_batch, z_rec) or just x_hat_batch.
                if isinstance(tmp, tuple):
                    x_hat_batch = tmp[0]
                    sh.z_rec = tmp[1]
                else:
                    x_hat_batch = tmp
                    del sh.z_rec
                duration = time.time() - start
                print('The calculation needed {} time'.format(datetime.timedelta(seconds=duration)))
                np.save(utils.get_checkpoint_dir(hparams, model_type)+'elapsed_time', duration)

                for i, key in enumerate(x_batch_dict.keys()):
                    x = xs_dict[key]
                    y = y_batch[i]
                    x_hat = x_hat_batch[i]

                    # Save the estimate
                    x_hats_dict[model_type][key] = x_hat

                    # Compute and store measurement and l2 loss
                    sh.measurement_losses[model_type][key] = utils.get_measurement_loss(x_hat, A, y)
                    sh.l2_losses[model_type][key] = utils.get_l2_loss(x_hat, x)
                    if hparams.class_bol and label_dict is not None:
                        try:
                            sh.class_losses[model_type][key] = utils.get_classifier_loss(hparams, x_hat, label_dict[key])
                        except Exception:
                            # Best effort: record NaN and keep going.
                            sh.class_losses[model_type][key] = np.nan
                            warnings.warn('Class loss unsuccessfull, most likely due to corrupted memory. Simply retry.')
                    if hparams.emd_bol:
                        try:
                            _, sh.emd_losses[model_type][key] = utils.get_emd_loss(x_hat, x)
                            if 'nonneg' not in hparams.tv_or_lasso_mode and 'pca' in model_type:
                                warnings.warn('EMD requires nonnegative images, for safety insert nonneg into tv_or_lasso_mode')
                        except ValueError:
                            warnings.warn('EMD calculation unsuccesfull (most likely due to negative images)')
                    # Keyed by image so later batches and other estimators
                    # do not overwrite each other's distances.
                    image_distances[model_type][key] = np.linalg.norm(x_hat - x)
            print('Processed upto image {0} / {1}'.format(key + 1, len(xs_dict)))
            sh.x_orig = x_batch
            sh.x_rec = x_hat_batch
            sh.noise = noise_batch

            # Checkpointing
            if hparams.save_images and (key + 1) % hparams.checkpoint_iter == 0:
                utils.checkpoint(x_hats_dict, save_image, sh, hparams)
                x_hats_dict = {model_type: {} for model_type in hparams.model_types}
                print('\nProcessed and saved first {0} images\n'.format(key + 1))
        # Every branch above has consumed this batch; start collecting a new one.
        x_batch_dict = {}

    if 'wavelet' in hparams.model_types[0]:
        print(np.abs(sh.x_rec))
        print('The average sparsity is {}'.format(np.sum(np.abs(sh.x_rec) >= 0.0001)/float(hparams.batch_size)))

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, save_image, sh, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))
        if hparams.dataset in ['mnist', 'fashion-mnist']:
            if np.array(x_batch).size:
                utilsM.save_images(np.reshape(x_batch, [-1, 28, 28]),
                                   [8, 8], utils.get_checkpoint_dir(hparams, hparams.model_types[0])+'original.png')
            if np.array(x_hat_batch).size:
                utilsM.save_images(np.reshape(x_hat_batch, [-1, 28, 28]),
                                   [8, 8], utils.get_checkpoint_dir(hparams, hparams.model_types[0])+'reconstruction.png')

        for model_type in hparams.model_types:
            mean_m_loss = np.mean(list(sh.measurement_losses[model_type].values()))
            mean_l2_loss = np.mean(list(sh.l2_losses[model_type].values()))
            if hparams.emd_bol:
                mean_emd_loss = np.mean(list(sh.emd_losses[model_type].values()))
            if label_dict is not None:
                mean_class_loss = np.mean(list(sh.class_losses[model_type].values()))
                print('mean class loss = {0}'.format(mean_class_loss))
            # Mean of ||x_hat_i - x_i|| over the images this estimator processed.
            mean_norm_loss = np.mean(list(image_distances[model_type].values()))
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))
            if hparams.emd_bol:
                print('mean emd loss = {0}'.format(mean_emd_loss))
            print('mean distance = {0}'.format(mean_norm_loss))
            first_image = next(iter(xs_dict.values()))
            print('mean distance pixelwise = {0}'.format(mean_norm_loss/len(first_image)))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if len(x_batch_dict) > 0:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
Ejemplo n.º 17
0
def main(hparams):
    """Recover images from 1-bit (sign) measurements with the 'vae' estimator.

    Images from ``model_input(hparams)`` are grouped into batches of
    ``hparams.batch_size``. Each full batch is measured as
    ``sign(x_batch @ A_outer)`` with ``A_outer = utils.get_outer_A(hparams)``.
    The estimate starts at zero and is refined for ``hparams.max_outer_iter``
    outer iterations. Each iteration takes a step along
    ``(y - sign(x @ A_outer)) @ A_outer.T`` and then projects it with
    ``estimators['vae']``.

    Per-image estimates and measurement/l2 losses are stored under 'vae',
    checkpointed when ``hparams.save_images`` is set, and summarised when
    ``hparams.print_stats`` is set. Images that do not fill a final batch are
    reported as unprocessed.

    Args:
        hparams: hyper-parameter object; ``n_input`` and ``model_type`` are
            written here, the other attributes are read.
    """
    hparams.n_input = np.prod(hparams.image_shape)
    hparams.model_type = 'vae'
    maxiter = hparams.max_outer_iter
    utils.print_hparams(hparams)
    xs_dict = model_input(hparams)  # returns the images
    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    x_hats_dict = {'vae': {}}
    x_batch_dict = {}

    for key, x in xs_dict.iteritems():
        print(key)
        x_batch_dict[key] = x  # collect images until a full batch is available
        if len(x_batch_dict) < hparams.batch_size:
            continue
        # Stack the batch into a (batch_size, n_input) matrix, one image per row.
        x_coll = [
            img.reshape(1, hparams.n_input)
            for _, img in x_batch_dict.iteritems()
        ]
        x_batch = np.concatenate(x_coll)

        A_outer = utils.get_outer_A(hparams)  # measurement matrix

        # Not used below, but drawing it keeps the global NumPy RNG stream
        # (and therefore the z_opt_batch initialisation) unchanged.
        noise_batch = hparams.noise_std * np.random.randn(
            hparams.batch_size, 100)

        # Keep only the sign of each linear measurement (1-bit quantisation).
        y_batch_outer = np.sign(np.matmul(x_batch, A_outer))

        estimator = estimators['vae']
        x_main_batch = 0.0 * x_batch
        # Initial latent codes; the latent size 20 is hard-coded here.
        z_opt_batch = np.random.randn(hparams.batch_size, 20)

        for _ in range(maxiter):
            # Step along (y - sign(x A)) A^T, then project with the VAE estimator.
            residual = y_batch_outer - np.sign(np.matmul(x_main_batch, A_outer))
            x_est_batch = x_main_batch + hparams.outer_learning_rate * np.matmul(
                residual, A_outer.T)
            x_main_batch, z_opt_batch = estimator(x_est_batch, z_opt_batch,
                                                  hparams)
        # Read from x_main_batch so the result is defined even when maxiter == 0.
        x_hat_batch = x_main_batch

        # Frobenius norm of the batch error divided by the per-image input
        # size (784 for 28x28 images).
        dist = np.linalg.norm(x_batch - x_hat_batch) / hparams.n_input
        print('batch reconstruction distance = {0}'.format(dist))

        for i, key in enumerate(x_batch_dict.keys()):
            x = xs_dict[key]
            y = y_batch_outer[i]
            x_hat = x_hat_batch[i]

            # Save the estimate
            x_hats_dict['vae'][key] = x_hat

            # Compute and store measurement and l2 loss
            measurement_losses['vae'][key] = utils.get_measurement_loss(
                x_hat, A_outer, y)
            l2_losses['vae'][key] = utils.get_l2_loss(x_hat, x)
        print('Processed upto image {0} / {1}'.format(key + 1, len(xs_dict)))

        # Checkpointing
        if hparams.save_images and (key + 1) % hparams.checkpoint_iter == 0:
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                             save_image, hparams)
            print('\nProcessed and saved first {0} images\n'.format(key + 1))

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                         save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            mean_m_loss = np.mean(list(measurement_losses[model_type].values()))
            mean_l2_loss = np.mean(list(l2_losses[model_type].values()))
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if len(x_batch_dict) > 0:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(
            len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')
Ejemplo n.º 18
0
        cfg.iter_routing=iter_routing

        if not debug:
            t = str(int(time.time()))
            out_dir = os.path.join(cfg.out_dir, "log_%s" % t)

            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            log_file = os.path.join(out_dir, "log_%s" % t)
            log_f = tf.gfile.GFile(log_file, mode="a")
            utils.print_out("# log_file=%s" % log_file, log_f)
        else:
            log_f = None
            out_dir = None

        utils.print_hparams(cfg, f=log_f)
        train(do_k_fold, out_dir, log_f)
        utils.print_out('END grid %d' % i)

else:
    if not debug:
        t = str(int(time.time()))
        out_dir = os.path.join(cfg.out_dir, "log_%s" % t)

        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        log_file = os.path.join(out_dir, "log_%s" % t)
        log_f = tf.gfile.GFile(log_file, mode="a")
        utils.print_out("# log_file=%s" % log_file, log_f)
    else:
        log_f = None
Ejemplo n.º 19
0
            'ratio_click_of_consumptionAbility_in_aid',
            'ratio_click_of_age_in_advertiserId',
            'ratio_click_of_productType_in_uid',
            'ratio_click_of_productType_in_consumptionAbility'
        ],
        mutil_features=[
            'interest1', 'interest2', 'interest3', 'interest4', 'interest5',
            'kw1', 'kw2', 'kw3', 'topic1', 'topic2', 'topic3', 'appIdAction',
            'appIdInstall', 'marriageStatus', 'ct', 'os'
        ],
    )


# Build the hyper-parameter object, set the model directory and print it.
# NOTE(review): print_hparams runs before hparams.aid (and anything assigned
# after it) is set, so presumably those fields are not included in the printout.
hparams = create_hparams()
hparams.path = './model/'
utils.print_hparams(hparams)

# Feature-name group stored as hparams.aid. Judging by the names these are
# mostly ad-side fields plus derived cvr_/ratio_ statistics, some of which
# also involve uid — TODO confirm how the model consumes this grouping.
hparams.aid = [
    'aid', 'advertiserId', 'campaignId', 'creativeId', 'creativeSize',
    'adCategoryId', 'productId', 'productType',
    'cvr_of_creativeId_and_onehot2', 'cvr_of_creativeId_and_onehot9',
    'cvr_of_creativeId_and_onehot16', 'cvr_of_creativeId_and_onehot10',
    'cvr_of_creativeId_and_onehot15', 'cvr_of_creativeId_and_onehot14',
    'cvr_of_creativeId_and_onehot13', 'cvr_of_creativeId_and_onehot18',
    'cvr_of_aid_and_age', 'cvr_of_creativeSize', 'cvr_of_uid_and_adCategoryId',
    'cvr_of_uid_and_creativeSize', 'ratio_click_of_productType_in_uid'
]
hparams.user = [
    'age', 'gender', 'education', 'consumptionAbility', 'LBS', 'carrier',
    'house', 'interest1', 'interest2', 'interest3', 'interest4', 'interest5',
    'kw1', 'kw2', 'kw3', 'topic1', 'topic2', 'topic3', 'appIdAction',
Ejemplo n.º 20
0
def main(hparams):
    """Reconstruct batches of images from noisy measurements.

    Every estimator named in ``hparams.model_types`` is applied to each full
    batch of ``hparams.batch_size`` images from ``model_input(hparams)``.
    The measurements are ``x + noise`` when ``hparams.measurement_type`` is
    ``'project'`` and ``x @ A + noise`` otherwise. Here ``A`` comes from
    ``utils.get_A`` and the noise is standard-normal scaled by
    ``hparams.noise_std``. Per-image estimates, measurement losses and l2
    losses are collected, checkpointed when requested, and summarised at
    the end.

    When ``hparams.not_lazy`` is false, images whose outputs already exist
    for every estimator are skipped. Images that do not fill a final batch
    are left unprocessed and reported.
    """
    # Derived sizes, then log the configuration.
    hparams.n_input = np.prod(hparams.image_shape)
    utils.set_num_measurements(hparams)
    utils.print_hparams(hparams)

    images = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    def fresh_estimates():
        # One empty {image key: estimate} mapping per estimator.
        return dict((name, {}) for name in hparams.model_types)

    def already_saved(image_key):
        # True when every estimator's output file for this image exists.
        paths = utils.get_save_paths(hparams, image_key)
        return all([os.path.isfile(p) for p in paths.values()])

    estimates = fresh_estimates()
    pending = {}
    for key, x in images.iteritems():
        if not hparams.not_lazy and already_saved(key):
            continue

        pending[key] = x
        if len(pending) < hparams.batch_size:
            continue

        # One flattened image per row.
        x_batch = np.concatenate(
            [img.reshape(1, hparams.n_input) for img in pending.itervalues()])

        A = utils.get_A(hparams)
        noise_batch = hparams.noise_std * np.random.randn(
            hparams.batch_size, hparams.num_measurements)
        clean = x_batch if hparams.measurement_type == 'project' else np.matmul(x_batch, A)
        y_batch = clean + noise_batch

        batch_keys = pending.keys()
        for model_type in hparams.model_types:
            x_hat_batch = estimators[model_type](A, y_batch, hparams)
            per_model = estimates[model_type]
            for row, key in enumerate(batch_keys):
                original = images[key]
                measured = y_batch[row]
                x_hat = x_hat_batch[row]
                per_model[key] = x_hat
                measurement_losses[model_type][key] = utils.get_measurement_loss(
                    x_hat, A, measured)
                l2_losses[model_type][key] = utils.get_l2_loss(x_hat, original)

        print('Processed upto image {0} / {1}'.format(key + 1, len(images)))

        checkpoint_due = (key + 1) % hparams.checkpoint_iter == 0
        if hparams.save_images and checkpoint_due:
            utils.checkpoint(estimates, measurement_losses, l2_losses,
                             save_image, hparams)
            estimates = fresh_estimates()
            print('\nProcessed and saved first ', key + 1, 'images\n')

        pending = {}

    # Final checkpoint.
    if hparams.save_images:
        utils.checkpoint(estimates, measurement_losses, l2_losses,
                         save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(images)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            print(model_type)
            mean_m_loss = np.mean(measurement_losses[model_type].values())
            mean_l2_loss = np.mean(l2_losses[model_type].values())
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    if hparams.image_matrix > 0:
        utils.image_matrix(images, estimates, view_image, hparams)

    # Anything still pending never filled a batch.
    leftover = len(pending)
    if leftover > 0:
        print(
            '\nDid NOT process last {} images because they did not fill up the last batch.'
            .format(leftover))
        print('Consider rerunning lazily with a smaller batch size.')
Ejemplo n.º 21
0
def main(hparams):
    """Train a VAE on channel data and save checkpoints.

    Builds the encoder/generator graph from ``model_def`` and a sampler
    that feeds latent codes through ``z_ph``. It then trains with Adam on
    the ``'H_data'`` field produced by
    ``data_input.channel_data_iteratior(hparams)``. The value returned by
    ``utils.try_restore`` is used as the last completed epoch, so training
    resumes at the following epoch.

    The average loss is printed every ``hparams.summary_epoch`` epochs, and a
    checkpoint is written every ``hparams.ckpt_epoch`` epochs. A final
    checkpoint with global step ``hparams.training_epochs - 1`` is always
    written. The session is closed when the function returns.
    """
    # Set up some stuff according to hparams
    utils.set_up_dir(hparams.ckpt_dir)
    utils.set_up_dir(hparams.sample_dir)
    utils.print_hparams(hparams)

    # encode
    x_ph = tf.placeholder(tf.float32, [None, hparams.n_input], name='x_ph')
    z_mean, z_log_sigma_sq = model_def.encoder(hparams,
                                               x_ph,
                                               'enc',
                                               reuse=False)

    # sample: z = mean + sigma * eps with eps ~ N(0, 1)
    eps = tf.random_normal((hparams.batch_size, hparams.n_z),
                           0,
                           1,
                           dtype=tf.float32)
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))
    z = z_mean + z_sigma * eps

    # reconstruct
    logits, x_reconstr_mean = model_def.generator(hparams,
                                                  z,
                                                  'gen',
                                                  reuse=False)

    # generator sampler (shares the 'gen' weights); its own placeholder name
    # avoids clashing with x_ph.
    z_ph = tf.placeholder(tf.float32, [None, hparams.n_z], name='z_ph')
    _, x_sample = model_def.generator(hparams, z_ph, 'gen', reuse=True)

    # define loss and update op
    total_loss = model_def.get_loss(x_ph, logits, z_mean, z_log_sigma_sq)
    opt = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
    update_op = opt.minimize(total_loss)

    # Sanity checks
    for var in tf.global_variables():
        print(var.op.name)
    print('')

    save_path = os.path.join(hparams.ckpt_dir, 'channel_vae_model')
    num_batches = hparams.num_samples // hparams.batch_size

    # The context manager closes the session (and frees its resources) on exit.
    with tf.Session() as sess:
        # Model checkpointing setup
        model_saver = tf.train.Saver()

        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # Attempt to restore variables from checkpoint
        start_epoch = utils.try_restore(hparams, sess, model_saver)

        # Get data iterator
        # NOTE(review): no iterator initializer is run here and a break on
        # OutOfRangeError is not followed by a reset, so unless the dataset
        # repeats, later epochs may see no data — confirm.
        iterator = data_input.channel_data_iteratior(hparams)
        next_element = iterator.get_next()

        # Training
        for epoch in range(start_epoch + 1, hparams.training_epochs):
            avg_loss = 0.0
            for _ in range(num_batches):
                try:
                    x_batch_val = np.squeeze(sess.run(next_element['H_data']))
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break
                feed_dict = {x_ph: x_batch_val}
                loss_val = sess.run([update_op, total_loss],
                                    feed_dict=feed_dict)[1]
                # Running average assuming num_batches full batches per epoch.
                avg_loss += loss_val / hparams.num_samples * hparams.batch_size

            if epoch % hparams.summary_epoch == 0:
                print('Epoch: {:04d} Avg loss = {:.9f}'.format(epoch, avg_loss))

            if epoch % hparams.ckpt_epoch == 0:
                model_saver.save(sess, save_path, global_step=epoch)

        model_saver.save(sess, save_path,
                         global_step=hparams.training_epochs - 1)
Ejemplo n.º 22
0
def main(hparams):
    """Recover every input in full batches with each configured estimator.

    For each full batch of inputs it does the following. In lazy mode, inputs
    whose outputs are already on disk are skipped.

    * calls ``utils.get_A(hparams)`` once per batch and forms measurements
      ``y = x + noise`` (``measurement_type == 'project'``) or
      ``y = |x A| + noise`` (otherwise);
    * for every model type, runs the estimator ``num_restarts`` times, each
      time from a random latent init and from its negation, and merges the
      candidates with ``utils.get_optimal_x_batch``;
    * stores estimates plus measurement / l2 losses and checkpoints them
      every ``hparams.checkpoint_iter`` images when ``hparams.save_images``.

    Inputs left over in a final, partially filled batch are NOT processed; a
    warning is printed instead.

    Args:
        hparams: hyper-parameter object. Mutated in place: ``n_input`` is set,
            and ``n_z`` (mnist) or ``z_dim`` (celebA) is set to the
            module-level ``latent_dim``.
    """
    # Set up some stuff according to hparams
    hparams.n_input = np.prod(hparams.image_shape)
    utils.set_num_measurements(hparams)
    utils.print_hparams(hparams)

    # The generators read the latent size from differently named hparams.
    if hparams.dataset == 'mnist':
        hparams.n_z = latent_dim
    elif hparams.dataset == 'celebA':
        hparams.z_dim = latent_dim

    # get inputs
    xs_dict = model_input(hparams)

    estimators = utils.get_estimators(hparams)
    utils.setup_checkpointing(hparams)
    measurement_losses, l2_losses = utils.load_checkpoints(hparams)

    x_hats_dict = {model_type: {} for model_type in hparams.model_types}
    x_batch_dict = {}
    for key, x in xs_dict.items():
        if not hparams.not_lazy:
            # If lazy, first check if the image has already been
            # saved before by *all* estimators. If yes, then skip this image.
            save_paths = utils.get_save_paths(hparams, key)
            is_saved = all(os.path.isfile(save_path)
                           for save_path in save_paths.values())
            if is_saved:
                continue

        x_batch_dict[key] = x
        if len(x_batch_dict) < hparams.batch_size:
            continue

        # Snapshot the key order once so row i of x_batch / y_batch /
        # x_hat_batch always corresponds to batch_keys[i].
        batch_keys = list(x_batch_dict.keys())
        x_batch = np.concatenate(
            [x_batch_dict[k].reshape(1, hparams.n_input) for k in batch_keys])

        # Construct noise and measurements
        A = utils.get_A(hparams)
        noise_batch = hparams.noise_std * np.random.randn(
            hparams.batch_size, hparams.num_measurements)
        if hparams.measurement_type == 'project':
            y_batch = x_batch + noise_batch
        else:
            # Magnitude-only measurements: y = |x A| + noise.
            y_batch = np.absolute(np.matmul(x_batch, A)) + noise_batch

        # Construct estimates using each estimator
        for model_type in hparams.model_types:
            estimator = estimators[model_type]
            # Deliberately poor placeholder so that the first restart's
            # estimate is preferred by utils.get_optimal_x_batch (presumably it
            # keeps, per sample, the candidate closer to x_batch -- confirm).
            x_main_batch = 10000 * np.ones_like(x_batch)
            for k in range(num_restarts):
                print('Restart # {0}'.format(k + 1))

                # Solve deep pr problem with random initial iterate.
                # Element [0] of the estimator's result is the recovered batch.
                init_iter = np.random.randn(hparams.batch_size, latent_dim)
                x_hat_batch = estimator(A, y_batch, init_iter, hparams)[0]
                x_hat_batch = utils.resolve_ambiguity(
                    x_hat_batch, x_batch, hparams.batch_size)

                # Use reflection of initial iterate
                x_hat_batch2 = estimator(A, y_batch, -1 * init_iter, hparams)[0]
                x_hat_batch2 = utils.resolve_ambiguity(
                    x_hat_batch2, x_batch, hparams.batch_size)

                x_hat_batchnew = utils.get_optimal_x_batch(
                    x_hat_batch, x_hat_batch2, x_batch, hparams.batch_size)
                x_main_batch = utils.get_optimal_x_batch(
                    x_hat_batchnew, x_main_batch, x_batch, hparams.batch_size)

            x_hat_batch = x_main_batch
            if hparams.dataset == 'mnist':
                utils.print_stats(x_hat_batch, x_batch, hparams.batch_size)

            # Distinct loop names: reusing `key`/`x` here would clobber the
            # outer loop variable that the progress/checkpoint logic reads.
            for i, batch_key in enumerate(batch_keys):
                x_true = xs_dict[batch_key]
                y = y_batch[i]
                x_hat = x_hat_batch[i]

                # Save the estimate
                x_hats_dict[model_type][batch_key] = x_hat

                # Compute and store measurement and l2 loss
                measurement_losses[model_type][batch_key] = \
                    utils.get_measurement_loss(x_hat, A, y)
                l2_losses[model_type][batch_key] = \
                    utils.get_l2_loss(x_hat, x_true)
        print('Processed upto image {0} / {1}'.format(key + 1, len(xs_dict)))

        # Checkpointing
        if hparams.save_images and (key + 1) % hparams.checkpoint_iter == 0:
            utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                             save_image, hparams)
            x_hats_dict = {model_type: {} for model_type in hparams.model_types}
            print('\nProcessed and saved first {0} images\n'.format(key + 1))

        x_batch_dict = {}

    # Final checkpoint
    if hparams.save_images:
        utils.checkpoint(x_hats_dict, measurement_losses, l2_losses,
                         save_image, hparams)
        print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict)))

    if hparams.print_stats:
        for model_type in hparams.model_types:
            mean_m_loss = np.mean(list(measurement_losses[model_type].values()))
            mean_l2_loss = np.mean(list(l2_losses[model_type].values()))
            print('mean measurement loss = {0}'.format(mean_m_loss))
            print('mean l2 loss = {0}'.format(mean_l2_loss))

    # NOTE(review): x_hats_dict is reset after every periodic checkpoint, so
    # this only sees estimates produced since the last such reset.
    if hparams.image_matrix > 0:
        utils.image_matrix(xs_dict, x_hats_dict, view_image, hparams)

    # Warn the user that some things were not processed
    if len(x_batch_dict) > 0:
        print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict)))
        print('Consider rerunning lazily with a smaller batch size.')