Example #1
def main(_):
    # FILESYSTEM SETUP ------------------------------------------------------------
    assert FLAGS.data_dir, "Must specify data location!"
    assert FLAGS.log_dir, "Must specify experiment to log to!"
    assert FLAGS.exp_tag, "Must specify experiment tag subfolder of log_dir %s" % FLAGS.log_dir
    assert FLAGS.cnn_ckpt, "Must specify where to load CNN checkpoint from!"
    assert FLAGS.variant, "Must specify ShapeWorld variant!"

    # Build saving folders
    save_root = FLAGS.log_dir + os.sep + FLAGS.exp_tag
    train_path = save_root + os.sep + "train"
    eval_path = save_root + os.sep + "eval"
    test_path = save_root + os.sep + "test"

    if not tf.gfile.IsDirectory(train_path):
        tf.gfile.MakeDirs(train_path)
        tf.gfile.MakeDirs(eval_path)
        tf.gfile.MakeDirs(test_path)

        tf.logging.info("Creating training directory: %s", train_path)
        tf.logging.info("Creating eval directory: %s", eval_path)
        tf.logging.info("Creating eval directory: %s", test_path)
    else:
        tf.logging.info("Using training directory: %s", train_path)
        tf.logging.info("Using eval directory: %s", eval_path)

    # Sanity check
    tf.reset_default_graph()
    tf.logging.info("Clean graph reset...")

    try:
        dataset = Dataset.create(dtype=FLAGS.dtype,
                                 name=FLAGS.name,
                                 variant=FLAGS.variant,
                                 config=FLAGS.data_dir)
        dataset.pixel_noise_stddev = 0.1
    except Exception:
        raise ValueError(
            "variant=%s did not point to a valid Shapeworld dataset" %
            FLAGS.variant)

    # Get parsing and parameter feats
    params = Config(mode="train", sw_specification=dataset.specification())
    params.cnn_checkpoint = FLAGS.cnn_ckpt
    params.batch_size = FLAGS.batch_size

    # MODEL SETUP ------------------------------------------------------------
    g = tf.Graph()
    with g.as_default():
        parser = FullSequenceBatchParser(
            src_vocab=dataset.vocabularies['language'])
        params.vocab_size = len(parser.tgt_vocab)

        batch = tf_util.batch_records(dataset,
                                      mode="train",
                                      batch_size=params.batch_size)
        model = CaptioningModel(config=params, batch_parser=parser)

        if FLAGS.glove_dir:
            tf.logging.info("Loading GloVe Embeddings...")
            gl = GloveLoader(vocab=parser.tgt_vocab,
                             glove_dir=FLAGS.glove_dir,
                             dims=FLAGS.glove_dim,
                             load_new=False)
            glove_initials = gl.get_embeddings_matrix()
            tf.logging.info("Building model with GloVe initialisation...")
            model.build_model(batch, embedding_init=glove_initials)
        else:
            tf.logging.info("Building model without GloVe initialisation...")
            model.build_model(batch)
        tf.logging.info("Network built...")

        # TRAINING OPERATION SETUP ------------------------------------------------------------
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            train_op = tf.contrib.layers.optimize_loss(
                loss=model.batch_loss,
                global_step=model.global_step,
                learning_rate=params.initial_learning_rate,
                optimizer=params.optimizer,
                clip_gradients=params.clip_gradients,
            )

        logging_saver = tf.train.Saver(
            max_to_keep=params.max_checkpoints_to_keep)
        summary_op = tf.summary.merge_all()

    train_writer = tf.summary.FileWriter(logdir=train_path, graph=g)

    tf.logging.info('###' * 20)
    tf.logging.info("Beginning shape2seq network training for %d steps" %
                    params.num_total_steps)

    with tf.Session(graph=g,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        tf.logging.info("### Trainable Variables")
        for var in tf.trainable_variables():
            print("-> %s" % var.op.name)

        coordinator = tf.train.Coordinator()
        queue_threads = tf.train.start_queue_runners(sess=sess,
                                                     coord=coordinator)

        # Initialise everything
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

        tf.logging.info("Restoring CNN...")
        model.init_fn(sess)

        start_train_time = time.time()

        # Loss accumulator and logging points at [25%, 50%, 75%, 100%] of each epoch
        logging_loss = []
        logging_points = np.linspace(0,
                                     params.num_steps_per_epoch,
                                     4,
                                     endpoint=False,
                                     dtype=np.int32)
        logging_points = params.num_steps_per_epoch - logging_points[::-1]

        for c_epoch in range(0, params.num_epochs):
            tf.logging.info("Running epoch %d" % c_epoch)
            for c_step in trange(params.num_steps_per_epoch * c_epoch,
                                 params.num_steps_per_epoch * (c_epoch + 1)):
                # Steps completed so far within the current epoch
                epoch_step = c_step - c_epoch * params.num_steps_per_epoch + 1
                if epoch_step in logging_points:
                    _, loss_, summaries = sess.run(
                        fetches=[train_op, model.batch_loss, summary_op])

                    loss_ = logging_loss + [loss_]
                    logging_loss = []

                    avg_loss = np.mean(loss_).squeeze()
                    new_summ = tf.Summary()
                    new_summ.value.add(tag="train/avg_loss",
                                       simple_value=avg_loss)
                    train_writer.add_summary(
                        new_summ, tf.train.global_step(sess,
                                                       model.global_step))
                    train_writer.add_summary(
                        summaries,
                        tf.train.global_step(sess, model.global_step))
                    train_writer.flush()

                    tf.logging.info(
                        " -> Average loss step %d, for last %d steps: %.5f" %
                        (c_step, len(loss_), avg_loss))

                # Run without summaries
                else:
                    _, loss_ = sess.run(fetches=[train_op, model.batch_loss])
                    logging_loss.append(loss_)

            logging_saver.save(sess=sess,
                               save_path=train_path + os.sep + "model",
                               global_step=tf.train.global_step(
                                   sess, model.global_step))

        coordinator.request_stop()
        coordinator.join(threads=queue_threads)

        end_time = time.time() - start_train_time
        tf.logging.info('Training complete in %.2f-secs/%.2f-mins/%.2f-hours',
                        end_time, end_time / 60, end_time / (60 * 60))
Example #2
                        filehandle.write(',validation ' + name)
                filehandle.write('\n')
    iteration_end = iteration_start + args.iterations - 1

    with Model(name=args.model,
               learning_rate=parameters.pop('learning_rate', 1e-3),
               weight_decay=parameters.pop('weight_decay', None),
               clip_gradients=parameters.pop('clip_gradients', None),
               model_directory=args.model_dir,
               summary_directory=args.summary_dir) as model:
        dropout = parameters.pop('dropout_rate', None)

        module = import_module('models.{}.{}'.format(args.type, args.model))
        if args.tf_records:
            inputs = tf_util.batch_records(dataset=dataset,
                                           mode='train',
                                           batch_size=args.batch_size)
            module.model(model=model,
                         inputs=inputs,
                         dataset_parameters=dataset_parameters,
                         **parameters)
        else:
            module.model(
                model=model,
                inputs=dict(),
                dataset_parameters=dataset_parameters,
                **parameters
            )  # no input tensors, hence None for placeholder creation
        model.finalize(restore=args.restore)

        if args.verbosity >= 1:
Example #3
                        filehandle.write(',validation ' + name)
                filehandle.write('\n')
    iteration_end = iteration_start + args.iterations - 1

    with Model(name=args.model,
               learning_rate=parameters.pop('learning_rate'),
               weight_decay=parameters.pop('weight_decay', 0.0),
               model_directory=args.model_dir,
               summary_directory=args.summary_dir) as model:
        dropout = parameters.pop('dropout_rate', 0.0)

        module = import_module('models.{}.{}'.format(args.type, args.model))
        if args.tf_records:
            module.model(model=model,
                         inputs=tf_util.batch_records(
                             dataset=dataset,
                             batch_size=args.batch_size,
                             noise_range=args.pixel_noise),
                         **parameters)
        else:
            module.model(
                model=model, inputs=dict(), **parameters
            )  # no input tensors, hence None for placeholder creation
        model.finalize(restore=args.restore)

        if args.verbosity >= 1:
            sys.stdout.write('         parameters: {:,}\n'.format(
                model.num_parameters))
            sys.stdout.write('         bytes: {:,}\n'.format(model.num_bytes))
            sys.stdout.write('{} train model...\n'.format(
                datetime.now().strftime('%H:%M:%S')))
            sys.stdout.write('         0%  {}/{}  '.format(
Example #4
def main(_):
    # FILESYSTEM SETUP ------------------------------------------------------------
    assert FLAGS.data_dir, "Must specify data location!"
    assert FLAGS.log_dir, "Must specify experiment to log to!"
    assert FLAGS.exp_tag, "Must specify experiment tag subfolder of log_dir %s" % FLAGS.log_dir
    assert FLAGS.parse_type, "Must specify parse type!"

    # Folder setup for saving summaries and loading checkpoints
    save_root = FLAGS.log_dir + os.sep + FLAGS.exp_tag
    test_path = save_root + os.sep + "test"
    if not tf.gfile.IsDirectory(test_path):
        tf.gfile.MakeDirs(test_path)

    train_path = FLAGS.log_dir + os.sep + FLAGS.exp_tag + os.sep + "train"

    model_ckpt = tf.train.latest_checkpoint(
        train_path)  # Get checkpoint to load
    tf.logging.info("Loading checkpoint %s", model_ckpt)
    assert model_ckpt, "Checkpoints could not be loaded, check that train_path %s exists" % train_path

    # Sanity check graph reset
    tf.reset_default_graph()
    tf.logging.info("Clean graph reset...")

    # try:
    dataset = Dataset.create(dtype=FLAGS.dtype,
                             name=FLAGS.name,
                             config=FLAGS.data_dir)
    dataset.pixel_noise_stddev = 0.1
    dataset.random_sampling = False
    # except Exception:
    #     raise ValueError("config=%s did not point to a valid Shapeworld dataset" % FLAGS.data_dir)

    # Get parsing and parameter feats
    params = Config(mode="test", sw_specification=dataset.specification())

    # Parse decoding arg from CLI
    params.decode_type = FLAGS.decode_type
    assert params.decode_type in ['greedy', 'sample', 'beam']

    # MODEL SETUP ------------------------------------------------------------
    g = tf.Graph()
    with g.as_default():
        parser = SimpleBatchParser(src_vocab=dataset.vocabularies['language'],
                                   batch_type=FLAGS.parse_type)
        vocab, rev_vocab = parser.get_vocab()
        params.vocab_size = len(parser.tgt_vocab)

        batch = tf_util.batch_records(dataset,
                                      mode=FLAGS.data_partition,
                                      batch_size=params.batch_size)
        model = CaptioningModel(config=params, batch_parser=parser)
        model.build_model(batch)

        restore_model = tf.train.Saver()

        tf.logging.info("Network built...")

    # TESTING SETUP ------------------------------------------------------------

    if FLAGS.num_imgs < 1:
        num_imgs = params.instances_per_shard * params.num_shards
    else:
        num_imgs = FLAGS.num_imgs
    tf.logging.info("Running test for %d images", num_imgs)

    test_writer = tf.summary.FileWriter(logdir=test_path, graph=g)

    with tf.Session(graph=g,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # Launch data loading queues
        coordinator = tf.train.Coordinator()
        queue_threads = tf.train.start_queue_runners(sess=sess,
                                                     coord=coordinator)

        # Model restoration
        restore_model.restore(sess, model_ckpt)
        tf.logging.info("Model restored!")

        # Trained model does not need initialisation. Init the vocab conversion tables
        sess.run([tf.tables_initializer()])

        #  Freeze graph
        sess.graph.finalize()

        # Get global step
        global_step = tf.train.global_step(sess, model.global_step)
        tf.logging.info("Successfully loaded %s at global step = %d.",
                        os.path.basename(model_ckpt), global_step)

        start_test_time = time.time()
        corrects = []
        incorrects = []  # For correctly formed, but wrong captions
        misses = []  # For incorrectly formed captions
        perplexities = []

        for b_idx in range(num_imgs):
            # idx_batch = dataset.generate(n=params.batch_size, mode=FLAGS.data_partition, include_model=True)

            reference_caps, inf_decoder_outputs, batch_perplexity = sess.run(
                fetches=[
                    model.reference_captions, model.inf_decoder_output,
                    model.batch_perplexity
                ],
                feed_dict={model.phase: 0})

            ref_cap = reference_caps.squeeze()
            inf_cap = inf_decoder_outputs.sample_id.squeeze()
            perplexities.append(batch_perplexity)

            if ref_cap.ndim > 0 and inf_cap.ndim > 0:
                print("%d REF -> %s | INF -> %s" % (b_idx, " ".join(
                    rev_vocab[r]
                    for r in ref_cap), " ".join(rev_vocab[r]
                                                for r in inf_cap)))

                # Strip <S>, </S> and any irrelevant tokens and convert to sets for order-insensitive comparison
                ref_cap = set([
                    tok for tok in ref_cap
                    if int(tok) not in parser.token_filter
                ])
                inf_cap = set([
                    tok for tok in inf_cap
                    if int(tok) not in parser.token_filter
                ])

                if np.all([i in ref_cap for i in inf_cap]):
                    corrects.append(1)
                else:
                    incorrects.append((ref_cap, inf_cap))
            else:
                print("Skipping %d as inf_cap %s is malformed" %
                      (b_idx, inf_cap))
                misses.append(1)

        # Overall scores for checkpoint
        avg_acc = np.mean(corrects).squeeze()
        std_acc = np.std(corrects).squeeze()
        print("Accuracy: %s -> %.5f ± %.5f | Misses: %d " %
              (FLAGS.parse_type, avg_acc, std_acc, len(misses)))

        avg_perplexity = np.mean(perplexities).squeeze()
        std_perplexity = np.std(perplexities).squeeze()
        print("------------")
        print("PERPLEXITY -> %.5f +- %.5f" % (avg_perplexity, std_perplexity))

        new_summ = tf.Summary()
        new_summ.value.add(tag="%s/avg_acc_%s" %
                           (FLAGS.data_partition, FLAGS.name),
                           simple_value=avg_acc)

        new_summ.value.add(tag="%s/std_acc_%s" %
                           (FLAGS.data_partition, FLAGS.name),
                           simple_value=std_acc)
        new_summ.value.add(tag="%s/perplexity_avg_%s" %
                           (FLAGS.data_partition, FLAGS.name),
                           simple_value=avg_perplexity)
        new_summ.value.add(tag="%s/perplexity_std_%s" %
                           (FLAGS.data_partition, FLAGS.name),
                           simple_value=std_perplexity)

        test_writer.add_summary(new_summ,
                                tf.train.global_step(sess, model.global_step))
        test_writer.flush()

        coordinator.request_stop()
        coordinator.join(threads=queue_threads)

        end_time = time.time() - start_test_time
        tf.logging.info('Testing complete in %.2f-secs/%.2f-mins/%.2f-hours',
                        end_time, end_time / 60, end_time / (60 * 60))
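As a companion to the scoring loop in the last example, here is a minimal pure-Python sketch of the order-insensitive, set-based caption check; the token ids, captions and filter set below are invented for illustration and are not taken from the real ShapeWorld vocabularies.

# Made-up token ids for illustration: 0/1/2 stand in for <S>, </S> and padding.
token_filter = {0, 1, 2}
ref_cap = [1, 7, 4, 9, 2]   # reference caption token ids
inf_cap = [1, 9, 7, 4, 2]   # inferred caption: same tokens, different order

# Strip control tokens and compare as sets, mirroring the scoring loop above.
ref_set = {tok for tok in ref_cap if tok not in token_filter}
inf_set = {tok for tok in inf_cap if tok not in token_filter}

# A caption counts as correct when every inferred content token also appears in
# the reference caption (a subset test, as in the loop above).
correct = all(tok in ref_set for tok in inf_set)
print(correct)  # -> True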