예제 #1
0
def get_validation_data():
    """Load the dev split and group commands by their natural-language key.

    Returns:
        A two-element list ``[keys, value_lists]``: ``keys`` is a tuple of the
        distinct natural-language strings, ``value_lists`` a tuple of the
        command lists grouped under each key (transposed dict items).
    """
    dev_cm = read_data('dev_cm.txt')
    dev_nl = read_data('dev_nl.txt')

    # Group every command under its natural-language description.
    # setdefault avoids the double key lookup of `if y not in d.keys()`.
    dev_dict = {}
    for cm, nl in zip(dev_cm, dev_nl):
        dev_dict.setdefault(nl, []).append(cm)

    # zip(*items) transposes [(k, v), ...] into ([k, ...], [v, ...]).
    dev_dataset = list(zip(*dev_dict.items()))
    #dev_dataset = [x[:cfg('val_n')] for x in dev_dataset]
    return dev_dataset
예제 #2
0
def preprocess():
    """Build one concatenated context file per data split.

    For each of the 'dev', 'train' and 'dirty' splits, reads the command and
    natural-language files, renders each (nl, cm) pair through ``context`` and
    joins them into a single string terminated by the configured EOS token,
    then saves it as ``<split>.txt``. Missing splits are skipped with a warning.
    """
    for name in ('dev', 'train', 'dirty'):
        try:
            cm = read_data(name + '_cm.txt')
            nl = read_data(name + '_nl.txt')
        # Narrowed from a bare `except:` — only file-access errors should be
        # treated as "split not present"; anything else must surface.
        except OSError:
            print(f"[WARNING]: {name} data not found")
            continue

        # Original had a duplicated `al = al = ...` assignment; one is enough.
        text = "".join(context(pair) for pair in zip(nl, cm))
        if not text.endswith(cfg('eos')):
            text += cfg('eos')

        save_data(name + ".txt", text)
def main(_):
  """Prepare the output directory, build the TF graph, train, then evaluate.

  Side effects: creates (or resets) FLAGS.output_dir, prints training
  progress, and writes checkpoints via SingularMonitoredSession.
  """
  print("-" * 80)
  if not os.path.isdir(FLAGS.output_dir):
    print("Path {0} does not exist. Creating.".format(FLAGS.output_dir))
    os.makedirs(FLAGS.output_dir)
  elif FLAGS.reset_output_dir:
    print("Path {0} exists. Remove and remake.".format(FLAGS.output_dir))
    shutil.rmtree(FLAGS.output_dir)
    os.makedirs(FLAGS.output_dir)

  print_user_flags()

  hparams = Hparams()
  images, labels = read_data(FLAGS.data_path)

  graph = tf.Graph()
  with graph.as_default():
    ops = get_ops(images, labels)

    # Count the model's trainable parameters (value currently unused, but the
    # call is kept to preserve the original behavior).
    num_params = count_model_params(tf.trainable_variables())

    print("-" * 80)
    print("Starting session")
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.train.SingularMonitoredSession(
      config=sess_config, checkpoint_dir=FLAGS.output_dir) as sess:

        # Training loop: one train_op per step, progress every log_every steps.
        print("-" * 80)
        print("Starting training")
        for step in range(1, hparams.train_steps + 1):
          sess.run(ops["train_op"])
          if step % FLAGS.log_every != 0:
            continue
          global_step, train_loss, valid_acc = sess.run([
            ops["global_step"],
            ops["train_loss"],
            ops["valid_acc"],
          ])
          pieces = [
            "step={0:<6d}".format(step),
            " loss={0:<5.2f}".format(train_loss),
            " val_acc={0:<3d}/{1:<3d}".format(
              valid_acc, hparams.eval_batch_size),
          ]
          print("".join(pieces))
          sys.stdout.flush()

        # Final evaluation: accumulate correct counts over the full test set.
        print("-" * 80)
        print("Training done. Eval on TEST set")
        num_corrects = 0
        for _ in range(10000 // hparams.eval_batch_size):
          num_corrects += sess.run(ops["test_acc"])
        print("test_accuracy: {0:>5d}/10000".format(num_corrects))
예제 #4
0
def main():
    """Train a matrix-factorization model with SGD on the MovieLens u.data file.

    Warm-starts from previously saved factor matrices ``p.npy``/``q.npy`` when
    both files exist, then saves the newly trained factors back to disk.
    """
    data_path = os.path.join("data", "u.data")

    x, omega = data_utils.read_data(data_path, M=M, N=N, offset=OFFSET)

    # ndarray.dump (used below) writes a *pickled* array, and np.load refuses
    # pickled data by default (allow_pickle=False since NumPy 1.16.3) — it
    # would raise ValueError, which `except FileNotFoundError` does not catch,
    # so the warm-start crashed instead of being skipped. allow_pickle=True
    # lets the saved factors load again.
    try:
        p = np.load("p.npy", allow_pickle=True)
        q = np.load("q.npy", allow_pickle=True)
        trained_pq = (p, q)
    except FileNotFoundError:
        trained_pq = None

    ## changing k would require starting with different p and q
    ## delete those files before continuing
    p, q = sgd.stochastic_gd(x=x,
                             omega=omega,
                             lam=0.,
                             k=10,
                             trained_pq=trained_pq,
                             batch_size=943,
                             learning_rate=1.)

    # NOTE: ndarray.dump pickles the array to disk; see allow_pickle above.
    p.dump("p.npy")
    q.dump("q.npy")
예제 #5
0
파일: main.py 프로젝트: p328188467/tfmnist
def main(_):
    """Load image data, build an initial population, and train it."""
    print("-" * 80)
    # NOTE(review): "fashion-minst" looks like a typo for "fashion-mnist",
    # but the string is kept as-is since it must match the on-disk directory.
    images, labels = read_data("./fashion-minst/")
    epoches = []
    population = create_popu(None, images, labels)
    train_child(population)