Example #1
    def model_session(self):
        """
        load a model with a tf session
        :return: a tf session
        """
        saver = tf.train.Saver(tf.all_variables())
        init = tf.initialize_all_variables()
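        # NOTE: tf.all_variables() and tf.initialize_all_variables() above are the
        # pre-1.0 names for tf.global_variables() and tf.global_variables_initializer().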

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.per_process_gpu_memory_fraction)
        config = tf.ConfigProto(gpu_options=gpu_options,
                                log_device_placement=True,
                                allow_soft_placement=True)
        config.gpu_options.allow_growth = True

        sess = tf.Session(config=config)
        sess.run(init)

        # Resume from the latest checkpoint, if any
        latest = util.latest_checkpoint(self.model_weights_dir)
        if not latest:
            print("No checkpoint to continue from in", self.model_weights_dir)
            sys.exit(1)
        saver.restore(sess, str(latest))
        print("model loaded", latest)

        return sess
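A minimal sketch of how the method above might be used, assuming it is defined on a hypothetical ModelLoader class that provides model_weights_dir and per_process_gpu_memory_fraction, and that the restored graph exposes the tensor names used below (all names and shapes here are illustrative assumptions, not part of the original snippet):

import numpy as np

loader = ModelLoader()             # hypothetical class that defines model_session()
sess = loader.model_session()      # builds the session and restores the latest checkpoint
try:
    # Tensor names are placeholders for whatever the restored graph actually defines.
    input_ph = sess.graph.get_tensor_by_name("input:0")
    logits = sess.graph.get_tensor_by_name("logits:0")
    batch = np.zeros((1, 28, 28, 1), dtype=np.float32)  # dummy input; the shape is a guess
    predictions = sess.run(logits, feed_dict={input_ph: batch})
finally:
    sess.close()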
Example #2
        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
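        # Keep at most FLAGS.num_checkpoints checkpoint files on disk.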
        saver = tf.train.Saver(tf.global_variables(),
                               max_to_keep=FLAGS.num_checkpoints)

        # Write vocabulary and index2label
        vocab_processor.save(os.path.join(out_dir, "vocab"))
        pickle.dump(index2label,
                    open(os.path.join(out_dir, 'index2label.pk'), 'wb'))

        # Resume from the latest checkpoint, or initialize all variables to train from scratch
        if FLAGS.resume:
            latest = util.latest_checkpoint(FLAGS.CHECKPOINT_DIR)
            if not latest:
                print("No checkpoint to continue from in", FLAGS.CHECKPOINT_DIR)
                sys.exit(1)
            print("resume training", latest)
            saver.restore(sess, str(latest))
        else:
            sess.run(tf.global_variables_initializer())

        def train_step(x_batch, y_batch):
            """
            A single training step
            """
            feed_dict = {
                vdcnn.input_x: x_batch,
                vdcnn.input_y: y_batch,
Example #3
def evaluate():
    # Prepare graph and configuration
    data_files = list(filter(lambda x: not gfile.IsDirectory(x), [os.path.join(FLAGS.data_dir, x) for x in gfile.ListDirectory(FLAGS.data_dir)]))
    filename_queue = tf.train.string_input_producer(data_files, num_epochs=1)
    model, saver = create_graph(filename_queue, FLAGS.ps_count)
    init_op = tf.global_variables_initializer()
    local_variables_init_op = tf.local_variables_initializer()
    sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_device_placement)
    output_path = os.path.join(FLAGS.output_dir, FLAGS.output_name)

    with tf.Session(config=sess_config) as sess, open(output_path, "w") as filew:
        sess.run([init_op, local_variables_init_op])

        # Restore checkpoint
        checkpoint_path = latest_checkpoint(FLAGS.model_dir)
        if checkpoint_path:
            print("Restoring model from", checkpoint_path)
            try:
                saver.restore(sess, checkpoint_path)
            except tf.errors.NotFoundError as e:
                print(datetime.now(), "Not all variables found in checkpoint, ignoring:", e)
            print_model(sess)
        else:
            raise IOError("No model files found")

        # Populate the queues
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Start to validate
        local_step = 0
        graph_comptime = 0.0
        evaluation_start = time.time()
        try:
            while not coord.should_stop():
                # Execute ops
                begin = time.time()
                # _loss, _pred, _labels, _weights, final, term_embeddings, layer1_output, layer3_input1, layer3_input2, pos, layer2_output, layer2_input, embedding_out = sess.run(
                #     [
                #         model.loss_fn,
                #         model.prediction,
                #         model.labels,
                #         model.weights,
                #         model.final,
                #         model.term_embeddings,
                #         model.layer1_output,
                #         model.layer3_input1,
                #         model.layer3_input2,
                #         model.pos,
                #         model.layer2_output,
                #         model.layer2_input,
                #         model.embedding_out,
                #     ])
                _pred, _labels, _weights = sess.run([model.prediction, model.labels, model.weights])
                end = time.time()

                # Update counters
                local_step += 1
                graph_comptime += end - begin
                if local_step % FLAGS.report_progress_step == 0:
                    print("local_step: %d, mean_step_time: %.3f, validation_time: %.2f" % (local_step, graph_comptime/local_step, end-evaluation_start))
                    #print("Pred:", _pred, "Loss:", _loss)
                    #print("layer3_input1:", layer3_input1)
                    #print("embedding_out:", embedding_out)
                    #print("term embedding:", term_embeddings)
                    #print("layer2_input:", layer2_input)
                    #print("layer1_output:", layer1_output)
                    #print("layer2_output:", layer2_output)
                    #print("layer3_input2:", layer3_input2)
                    #print("final:", final)
                    #print("pos:", pos)
                    #print(_loss)
                    #print(_pred)

                # Output result
                for i in range(len(_pred)):
                    filew.write(str(_pred[i][0]) + '\t' + str(_labels[i][0]) + '\t' + str(_weights[i][0]) + '\n')
        except tf.errors.OutOfRangeError:
            print("Reached end of data")

        # Stop all threads
        coord.request_stop()
        coord.join(threads)
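For reference, the evaluation function above reads several command-line flags; a hedged sketch of how they could be declared with tf.app.flags is shown below. The flag names are taken from the code above, but the defaults and help strings are illustrative assumptions rather than values from the original project.

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string("data_dir", "./data", "Directory containing the evaluation input files")
flags.DEFINE_string("model_dir", "./model", "Directory holding the checkpoint to restore")
flags.DEFINE_string("output_dir", "./output", "Directory for the prediction file")
flags.DEFINE_string("output_name", "predictions.tsv", "File name for the predictions")
flags.DEFINE_integer("ps_count", 1, "Parameter-server count used when building the graph")
flags.DEFINE_integer("report_progress_step", 100, "Report timing every N local steps")
flags.DEFINE_boolean("log_device_placement", False, "Log op placement on devices")
FLAGS = flags.FLAGS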
Example #4
def train(config):
  """Train a Variational Autoencoder or deep latent Gaussian model on MNIST."""
  cfg = config
  logger = logging.getLogger()
  t0 = time.time()
  logdir_name = util.list_to_str(
      ['dlgm', cfg['p/n_layers'], 'layer', 'w_stddev',
       cfg['p_net/init_w_stddev'], cfg['inference'],
       'q_init_stddev', cfg['q/init_stddev'], 'lr', cfg['optim/learning_rate']
       ])
  if cfg['inference'] == 'proximity':
    logdir_name += '_' + util.list_to_str(
        [cfg['c/proximity_statistic'], 'decay_rate', cfg['c/decay_rate'],
         'decay_steps', cfg['c/decay_steps'], 'lag', cfg['c/lag'],
         cfg['c/decay'], cfg['c/magnitude']])
  cfg['log/dir'] = util.make_logdir(cfg, logdir_name)
  util.log_to_file(os.path.join(cfg['log/dir'], 'train.log'))
  logger.info(cfg)
  np.random.seed(433423)
  tf.set_random_seed(435354)
  sess = tf.InteractiveSession()
  data_iterator, _, _ = util.provide_data(cfg['train_data'])

  def get_feed_iterator():
    while True:
      yield {input_data: next(data_iterator)}
  feed_iterator = get_feed_iterator()
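  # Note: get_feed_iterator closes over input_data, which is only bound on the
  # next line; the closure is evaluated lazily, when the iterator is advanced.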
  input_data = tf.placeholder(tf.float32, [cfg['batch_size'], 28, 28, 1])
  tf.summary.image('data', input_data)
  model = models.DeepLatentGaussianModel(cfg)
  variational = models.DeepLatentGaussianVariational(cfg)
  if cfg['inference'] == 'vanilla':
    inference_fn = inferences.VariationalInference
  elif cfg['inference'] == 'proximity':
    inference_fn = inferences.ProximityVariationalInference
  inference = inference_fn(sess, cfg, model, variational, input_data)
  inference.build_train_op()
  # prior_predictive = stats.build_prior_predictive(model)
  posterior_predictive = stats.build_posterior_predictive(
      cfg, model, variational, input_data)
  inference.build_summary_op()

  ckpt = util.latest_checkpoint(cfg)
  if ckpt is not None:
    inference.saver.restore(sess, ckpt)
  else:
    inference.initialize(next(feed_iterator))

  if not cfg['eval_only']:
    for py_step in range(cfg['n_iterations']):
      feed_dict = next(feed_iterator)
      if py_step == 0:
        inference.initialize(feed_dict)
      if cfg['inference'] == 'proximity' and cfg['c/lag'] != 'moving_average':
        feed_dict.update(
            inference.constraint_feed_dict(py_step, feed_iterator))
      if py_step % cfg['print_every'] == 0:
        logger.info(inference.log_stats(feed_dict))
        #util.save_prior_posterior_predictives(
        #    cfg, sess, inference, prior_predictive,
        #    posterior_predictive, feed_dict, feed_dict[input_data])
      sess.run(inference.train_op, feed_dict)
    print(tf.train.latest_checkpoint(cfg['log/dir']))

  # evaluation
  if cfg['eval_only']:
    valid_iterator, np_valid_data_mean, _ = util.provide_data(
        cfg['valid_data'])

    def create_iterator():
      while True:
        yield {input_data: next(valid_iterator)}
    valid_feed_iterator = create_iterator()
    np_l = 0.
    np_log_x = 0.
    for i in range(cfg['valid_data/n_examples'] // cfg['valid_data/batch_size']):
      feed_dict = next(valid_feed_iterator)
      tmp_np_log_x, tmp_np_l = sess.run(
          [inference.log_p_x_hat, inference.elbo], feed_dict)
      np_log_x += np.sum(tmp_np_log_x)
      np_l += np.mean(tmp_np_l)
    logger.info('Total time: %.3f hours' % ((time.time() - t0) / 60. / 60.))
    valid_elbo = np_l / cfg['valid_data/n_examples']
    valid_log_lik = np_log_x / cfg['valid_data/n_examples']
    txt = ('for %s set -- elbo: %.10f\tlog_likelihood: %.10f' % (
        cfg['valid_data/split'], valid_elbo, valid_log_lik))
    logger.info(txt)
    with open(os.path.join(cfg['log/dir'], 'job.log'), 'w') as f:
      f.write(txt)
    eval_summ = tf.Summary()
    eval_summ.value.add(tag='Valid ELBO', simple_value=valid_elbo)
    eval_summ.value.add(tag='Valid Log Likelihood', simple_value=valid_log_lik)
    inference.summary_writer.add_summary(eval_summ, 0)
    inference.summary_writer.flush()
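The train(config) function above only assumes that cfg behaves like a dict with slash-separated keys. As an illustration of its vanilla-inference, training-only path, a configuration might look like the following sketch; the key names are taken from the lookups in the code, while the values (and any extra keys consumed by util, models and inferences, which are not shown here) are assumptions.

# Illustrative configuration for the vanilla path; values are guesses, not project defaults.
cfg = {
    'inference': 'vanilla',
    'p/n_layers': 2,
    'p_net/init_w_stddev': 0.1,
    'q/init_stddev': 0.1,
    'optim/learning_rate': 1e-3,
    'batch_size': 64,
    'n_iterations': 10000,
    'print_every': 100,
    'eval_only': False,
    'train_data': 'train',   # whatever util.provide_data expects
}
train(cfg)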
Example #5
def train(config):
    """Train sigmoid belief network on MNIST."""
    cfg = config
    logger = logging.getLogger()
    t0 = time.time()
    logdir_name = '_'.join(
        str(s) for s in [
            'sbn_n_layers', cfg['p/n_layers'], 'pi', cfg['p/bernoulli_p'],
            'geom_mean', cfg['optim/geometric_mean'], 'w_eps', cfg['p/w_eps'],
            cfg['optim/learning_rate'], 'learn_prior', cfg['p/learn_prior'],
            cfg['inference']
        ])
    if cfg['inference'] == 'proximity' or cfg['optim/deterministic_annealing']:
        if cfg['optim/deterministic_annealing']:
            logdir_name += '_DA_'
        logdir_name += '_' + '_'.join(
            str(s) for s in [
                cfg['c/proximity_statistic'], 'decay', cfg['c/decay'],
                'decay_rate', cfg['c/decay_rate'], cfg['c/decay_steps'], 'lag',
                cfg['c/lag'], cfg['moving_average/decay'], 'k',
                cfg['c/magnitude']
            ])
    cfg['log/dir'] = util.make_logdir(cfg, logdir_name)
    util.log_to_file(os.path.join(cfg['log/dir'], 'train.log'))
    logger.info(cfg)
    np.random.seed(433423)
    tf.set_random_seed(435354)
    sess = tf.InteractiveSession()
    data_iterator, np_data_mean, _ = util.provide_data(cfg['train_data'])
    input_data = tf.placeholder(cfg['dtype'],
                                [cfg['batch_size']] + cfg['train_data/shape'])
    data_mean = tf.placeholder(cfg['dtype'], cfg['train_data/shape'])
    tf.summary.image('data', input_data)
    data = {'input_data': input_data, 'data_mean': data_mean}

    def create_iterator():
        while True:
            yield {input_data: next(data_iterator), data_mean: np_data_mean}

    feed_iterator = create_iterator()
    model = models.SigmoidBeliefNetworkModel(cfg)
    variational = models.SigmoidBeliefNetworkVariational(cfg)
    if cfg['inference'] == 'vanilla':
        inference_fn = inferences.VariationalInference
    elif cfg['inference'] == 'proximity':
        inference_fn = inferences.ProximityVariationalInference
    inference = inference_fn(sess, cfg, model, variational, data)
    inference.build_train_op()
    prior_predictive = stats.build_prior_predictive(model)
    posterior_predictive = stats.build_posterior_predictive(
        cfg, model, variational, data)
    inference.build_summary_op()
    ckpt = util.latest_checkpoint(cfg)
    if ckpt is not None:
        inference.saver.restore(sess, ckpt)
    else:
        inference.initialize(next(feed_iterator))

    # train
    if not cfg['eval_only']:
        first_feed_dict = next(feed_iterator)
        for py_step in range(cfg['optim/n_iterations']):
            feed_dict = next(feed_iterator)
            if py_step % cfg['print_every'] == 0:
                logger.info(inference.log_stats(feed_dict))
            sess.run(inference.train_op, feed_dict)
        util.save_prior_posterior_predictives(cfg, sess, inference,
                                              prior_predictive,
                                              posterior_predictive,
                                              first_feed_dict,
                                              first_feed_dict[input_data])

    # evaluation
    if cfg['eval_only']:
        valid_data_iterator, np_valid_data_mean, _ = util.provide_data(
            cfg['valid_data'])

        def create_iterator():
            while True:
                yield {
                    input_data: next(valid_data_iterator),
                    data_mean: np_valid_data_mean
                }

        valid_feed_iterator = create_iterator()

        np_l = 0.
        np_log_x = 0.
        assert cfg['valid_data/batch_size'] == 1
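        # With one example per batch, summing per-batch means below equals summing
        # per-example values, so dividing by n_examples gives the mean over examples.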
        for i in range(cfg['valid_data/n_examples'] //
                       cfg['valid_data/batch_size']):
            feed_dict = next(valid_feed_iterator)
            tmp_np_log_x, tmp_np_l = sess.run(
                [inference.log_p_x_hat, inference.elbo], feed_dict)
            np_log_x += np.sum(tmp_np_log_x)
            np_l += np.mean(tmp_np_l)
        logger.info('Total time: %.3f hours' %
                    ((time.time() - t0) / 60. / 60.))
        txt = ('for %s set -- elbo: %.10f\tlog_likelihood: %.10f' %
               (cfg['valid_data/split'], np_l / cfg['valid_data/n_examples'],
                np_log_x / cfg['valid_data/n_examples']))
        logger.info(txt)
        with open(os.path.join(cfg['log/dir'], 'job.log'), 'w') as f:
            f.write(txt)
    print(tf.train.latest_checkpoint(cfg['log/dir']))