def main(args):
    """Get dataset hyperparameters."""
    assert len(args) == 2 and isinstance(args[1], str)
    dataset_name = args[1]
    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=False,
                                      epochs=cfg.epoch)
    """Set reproduciable random seed"""
    tf.set_random_seed(1234)

    with tf.Graph().as_default():
        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        num_batches_test = int(dataset_size_test / cfg.batch_size)

        batch_x, batch_labels = create_inputs()
        output = net.build_arch(batch_x,
                                coord_add,
                                is_train=False,
                                num_classes=num_classes)
        batch_acc = net.test_accuracy(output, batch_labels)
        saver = tf.train.Saver()

        step = 0

        summaries = []
        summaries.append(tf.summary.scalar('accuracy', batch_acc))
        summary_op = tf.summary.merge(summaries)

        with tf.Session() as sess:
            tf.train.start_queue_runners(sess=sess)
            summary_writer = tf.summary.FileWriter(
                cfg.test_logdir, graph=None)  # graph=sess.graph, huge!

            for epoch in range(cfg.epoch):
                # Checkpoint files here are assumed to be named model.ckpt-<step>;
                # a regex would be needed to match loss values embedded in the file name.
                ckpt = os.path.join(
                    cfg.logdir,
                    'model.ckpt-%d' % (num_batches_per_epoch_train * epoch))
                saver.restore(sess, ckpt)

                accuracy_sum = 0
                for i in range(num_batches_test):
                    batch_acc_v, summary_str = sess.run(
                        [batch_acc, summary_op])
                    print('%d batches tested.' % step)
                    summary_writer.add_summary(summary_str, step)

                    accuracy_sum += batch_acc_v

                    step += 1

                ave_acc = accuracy_sum / num_batches_test
                print('The average accuracy is %f.' % ave_acc)
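Each of these scripts is launched through TF1's app runner, which parses flags and forwards the remaining argv to main. A minimal entry point (a sketch; it assumes the module-level tf import used throughout these examples):

if __name__ == "__main__":
    tf.app.run(main=main)  # forwards parsed sys.argv to main(args)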
Example #2
def main(args):
    tf.set_random_seed(1234)
    dataset_name = args[1]
    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=False,
                                      epochs=cfg.epoch)

    with tf.Graph().as_default():
        num_batches_test = int(dataset_size_test / cfg.batch_size * 0.5)
        batch_x, batch_labels = create_inputs()
        output, pose_out = net.build_arch(batch_x,
                                          coord_add,
                                          is_train=False,
                                          num_classes=num_classes)
        tf.logging.debug(pose_out.get_shape())

        batch_acc = net.test_accuracy(output, batch_labels)
        saver = tf.train.Saver()
        session_config = tf.ConfigProto(
            device_count={'GPU': 0},
            gpu_options={
                'allow_growth': 1,
                # 'per_process_gpu_memory_fraction': 0.1,
                'visible_device_list': '0'
            },
            allow_soft_placement=True)
        with tf.Session(config=session_config) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            model_file = tf.train.latest_checkpoint(ckpt)
            saver.restore(sess, model_file)

            accuracy_sum = 0
            for i in range(num_batches_test):
                batch_acc_v = sess.run([batch_acc])
                accuracy_sum += batch_acc_v[0]
                print(accuracy_sum)

            ave_acc = accuracy_sum / num_batches_test
            print('The average accuracy is %f.' % ave_acc)
def main(args):

    # Set reproducible random seed
    tf.set_random_seed(1234)

    # Directories
    # Get name
    split = FLAGS.load_dir.split('/')
    if split[-1]:
        name = split[-1]
    else:
        name = split[-2]

    # Get parent directory
    split = FLAGS.load_dir.split("/" + name)
    parent_dir = split[0]

    test_dir = '{}/{}/test'.format(parent_dir, name)
    test_summary_dir = test_dir + '/summary'

    # Clear the test log directory
    if FLAGS.reset and os.path.exists(test_dir):
        shutil.rmtree(test_dir)
    if not os.path.exists(test_summary_dir):
        os.makedirs(test_summary_dir)

    # Logger
    conf.setup_logger(logger_dir=test_dir, name="logger_test.txt")
    logger.info("name: " + name)
    logger.info("parent_dir: " + parent_dir)
    logger.info("test_dir: " + test_dir)
    if FLAGS.patch_path:
        logger.info("patch_path: " + FLAGS.patch_path)

    # Load hyperparameters from train run
    conf.load_or_save_hyperparams()

    # Get dataset hyperparameters
    logger.info('Using dataset: {}'.format(FLAGS.dataset))

    # Dataset
    dataset_size_test = conf.get_dataset_size_test(
        FLAGS.dataset
    ) if FLAGS.partition == "test" else conf.get_dataset_size_train(
        FLAGS.dataset)
    num_classes = conf.get_num_classes(FLAGS.dataset)
    create_inputs_test = conf.get_create_inputs(FLAGS.dataset,
                                                mode=FLAGS.partition)

    #----------------------------------------------------------------------------
    # GRAPH - TEST
    #----------------------------------------------------------------------------
    logger.info('BUILD TEST GRAPH')
    g_test = tf.Graph()
    with g_test.as_default():
        # Get global_step
        global_step = tf.train.get_or_create_global_step()

        num_batches_test = int(dataset_size_test / FLAGS.batch_size)

        # Get data
        input_dict = create_inputs_test()
        batch_x = input_dict['image']
        batch_labels = input_dict['label']

        # AG 10/12/2018: Split batch for multi gpu implementation
        # Each split is of size FLAGS.batch_size / FLAGS.num_gpus
        # See: https://github.com/naturomics/CapsNet-
        # Tensorflow/blob/master/dist_version/distributed_train.py
        splits_x = tf.split(axis=0,
                            num_or_size_splits=FLAGS.num_gpus,
                            value=batch_x)
        splits_labels = tf.split(axis=0,
                                 num_or_size_splits=FLAGS.num_gpus,
                                 value=batch_labels)
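        # e.g. FLAGS.batch_size = 128 with FLAGS.num_gpus = 2 (example values)
        # gives two splits of 64 examples each.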

        # Build architecture
        build_arch = conf.get_dataset_architecture(FLAGS.dataset)
        # for baseline
        #build_arch = conf.get_dataset_architecture('baseline')

        #--------------------------------------------------------------------------
        # MULTI GPU - TEST
        #--------------------------------------------------------------------------
        # Calculate the logits for each model tower
        tower_logits = []
        tower_recon_losses = []
        reuse_variables = None
        with tf.device("/cpu:0"):
            scale_min_feed = tf.placeholder(tf.float32,
                                            shape=[],
                                            name="scale_min_feed")
            scale_max_feed = tf.placeholder(tf.float32,
                                            shape=[],
                                            name="scale_max_feed")
        patch_feed = None
        if FLAGS.patch_path:
            patch_feed = tf.placeholder(
                tf.float32,
                shape=batch_x.get_shape().as_list()[-3:],
                name="patch_feed")
        for i in range(FLAGS.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i) as scope:
                    with slim.arg_scope([slim.variable], device='/cpu:0'):
                        logits, recon_losses, patch_node = tower_fn(
                            build_arch,
                            splits_x[i],
                            scale_min_feed,
                            scale_max_feed,
                            patch_feed,
                            scope,
                            num_classes,
                            reuse_variables=reuse_variables,
                            is_train=False)

                    # Don't reuse variable for first GPU, but do reuse for others
                    reuse_variables = True
                    # Keep track of losses and logits across for each tower
                    tower_logits.append(logits)
                    tower_recon_losses.append(recon_losses)
        # Combine logits from all towers
        test_metrics = {}
        if not FLAGS.save_patch:
            test_logits = tf.concat(tower_logits, axis=0)
            test_preds = tf.argmax(test_logits, axis=-1)
            test_recon_losses = tf.concat(tower_recon_losses, axis=0)
            test_metrics = {
                'preds': test_preds,
                'labels': batch_labels,
                'recon_losses': test_recon_losses
            }
        if FLAGS.adv_patch:
            test_metrics['patch'] = patch_node

        # Reset and read operations for streaming metrics go here
        test_reset = {}
        test_read = {}

        # Saver
        saver = tf.train.Saver(max_to_keep=None)

        # Set summary op

        #--------------------------------------------------------------------------
        # SESSION - TEST
        #--------------------------------------------------------------------------
        #sess_test = tf.Session(
        #    config=tf.ConfigProto(allow_soft_placement=True,
        #                          log_device_placement=False),
        #    graph=g_test)
        # Perry: added in for RTX 2070 incompatibility workaround
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        config.gpu_options.allow_growth = True
        sess_test = tf.Session(config=config, graph=g_test)

        #sess_test.run(tf.local_variables_initializer())
        #sess_test.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(test_summary_dir,
                                               graph=sess_test.graph)

        ckpts_to_test = []
        load_dir_checkpoint = os.path.join(FLAGS.load_dir, "train",
                                           "checkpoint")

        # Evaluate the latest ckpt in dir
        if FLAGS.ckpt_name is None:
            latest_ckpt = tf.train.latest_checkpoint(load_dir_checkpoint)
            ckpts_to_test.append(latest_ckpt)

        # Evaluate all ckpts in dir
        elif FLAGS.ckpt_name == "all":
            # Get list of files in directory and sort by date created
            filenames = os.listdir(load_dir_checkpoint)
            regex = re.compile(r'.*\.index')
            filenames = filter(regex.search, filenames)
            data_ckpts = (os.path.join(load_dir_checkpoint, fn)
                          for fn in filenames)
            data_ckpts = ((os.stat(path), path) for path in data_ckpts)

            # regular files, insert creation date
            data_ckpts = ((stat[ST_CTIME], path) for stat, path in data_ckpts
                          if S_ISREG(stat[ST_MODE]))
            data_ckpts = sorted(data_ckpts)
            # remove ".index"
            ckpts_to_test = [path[:-6] for ctime, path in data_ckpts]

        # Evaluate ckpt specified by name
        else:
            ckpt_name = os.path.join(load_dir_checkpoint, FLAGS.ckpt_name)
            ckpts_to_test.append(ckpt_name)

        #--------------------------------------------------------------------------
        # MAIN LOOP
        #--------------------------------------------------------------------------
        # Run testing on checkpoints
        for ckpt in ckpts_to_test:
            saver.restore(sess_test, ckpt)

            if FLAGS.save_patch:
                out = sess_test.run(test_metrics['patch'])
                patch = out
                if patch.shape[-1] == 1:
                    patch = np.squeeze(patch, axis=-1)
                formatted = (patch * 255).astype('uint8')
                img = Image.fromarray(formatted)
                save_dir = os.path.join(FLAGS.storage, 'logs/', FLAGS.dataset,
                                        FLAGS.logdir)
                img.save(
                    os.path.join(FLAGS.load_dir, "test", "saved_patch.png"))
                return

            # Reset accumulators
            sess_test.run(test_reset)
            test_preds_vals = []
            test_labels_vals = []
            test_recon_losses_vals = []
            test_scales = []

            interval = 0.1 if FLAGS.adv_patch else 1
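            # With adv_patch this sweeps scales 0.0, 0.1, ..., 0.9;
            # otherwise np.arange(0, 1, 1) yields the single scale 0.0.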
            for scale in np.arange(0, 1, interval):
                for i in range(num_batches_test):
                    feed_dict = {scale_min_feed: scale, scale_max_feed: scale}
                    if FLAGS.patch_path:
                        patch_dims = patch_feed.get_shape()
                        patch = np.asarray(Image.open(FLAGS.patch_path),
                                           dtype=np.float32)
                        if len(patch.shape) < 3:
                            patch = np.expand_dims(patch, axis=-1)
                        if patch_dims[-1] == 1:
                            patch = np.mean(patch, axis=-1, keepdims=True)
                        patch = patch / 255
                        feed_dict[patch_feed] = patch
                    out = sess_test.run([test_metrics], feed_dict=feed_dict)
                    test_metrics_v = out[0]
                    #ckpt_num = re.split('-', ckpt)[-1]
                    #logger.info('TEST ckpt-{}'.format(ckpt_num)
                    #    + ' bch-{:d}'.format(i)
                    #    )
                    test_preds_vals.append(test_metrics_v['preds'])
                    test_labels_vals.append(test_metrics_v['labels'])
                    test_recon_losses_vals.append(
                        test_metrics_v['recon_losses'])
                    test_scales.append(
                        np.full(test_metrics_v['preds'].shape,
                                fill_value=scale))

        logger.info('writing to csv')
        test_preds_vals = np.concatenate(test_preds_vals)
        test_labels_vals = np.concatenate(test_labels_vals)
        test_recon_losses_vals = np.concatenate(test_recon_losses_vals)
        test_scales = np.concatenate(test_scales)

        data = {
            'predictions': test_preds_vals,
            'labels': test_labels_vals,
            'reconstruction_losses': test_recon_losses_vals,
            'scales': test_scales
        }
        filename = "recon_losses.csv"
        if FLAGS.patch_path:
            filename = re.sub(
                r'[^\w\-_]', '_',
                FLAGS.patch_path) + "_" + FLAGS.partition + ".csv"
        csv_save_path = os.path.join(FLAGS.load_dir, FLAGS.partition, filename)
        pd.DataFrame(data).to_csv(csv_save_path, index=False)
        logger.info('csv saved at ' + csv_save_path)
Example #4
def main(args):
    """Get dataset hyperparameters."""
    assert len(args) == 3 and isinstance(args[1], str) and isinstance(
        args[2], str)
    dataset_name = args[1]
    model_name = args[2]
    """Set reproduciable random seed"""
    tf.set_random_seed(1234)

    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=False,
                                      epochs=cfg.epoch)

    with tf.Graph().as_default():
        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        num_batches_test = 2  # int(dataset_size_test / cfg.batch_size * 0.1)

        batch_x, batch_labels = create_inputs()
        batch_squash = tf.divide(batch_x, 255.)
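        # batch_squash scales pixels to [0, 1]; it is the reconstruction
        # target passed to net.spread_loss below.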
        batch_x_norm = slim.batch_norm(batch_x,
                                       center=False,
                                       is_training=False,
                                       trainable=False)
        output, pose_out = net.build_arch(batch_x_norm,
                                          coord_add,
                                          is_train=False,
                                          num_classes=num_classes)
        tf.logging.debug(pose_out.get_shape())

        batch_acc = net.test_accuracy(output, batch_labels)
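        # Evaluate with the final spread-loss margin m = 0.9, the m_max value
        # the training scripts anneal towards.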
        m_op = tf.constant(0.9)
        loss, spread_loss, mse, recon_img_squash = net.spread_loss(
            output, pose_out, batch_squash, batch_labels, m_op)
        tf.summary.scalar('spread_loss', spread_loss)
        tf.summary.scalar('reconstruction_loss', mse)
        tf.summary.scalar('all_loss', loss)
        data_size = int(batch_x.get_shape()[1])
        recon_img = tf.multiply(
            tf.reshape(recon_img_squash,
                       shape=[cfg.batch_size, data_size, data_size, 1]), 255.)
        orig_img = tf.reshape(batch_x,
                              shape=[cfg.batch_size, data_size, data_size, 1])
        tf.summary.image('orig_image', orig_img)
        tf.summary.image('recon_image', recon_img)
        saver = tf.train.Saver()

        step = 0

        tf.summary.scalar('accuracy', batch_acc)
        summary_op = tf.summary.merge_all()

        with tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            if not os.path.exists(cfg.test_logdir +
                                  '/{}/{}/'.format(model_name, dataset_name)):
                os.makedirs(cfg.test_logdir +
                            '/{}/{}/'.format(model_name, dataset_name))
            summary_writer = tf.summary.FileWriter(
                cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name),
                graph=sess.graph)  # graph=sess.graph, huge!

            files = os.listdir(cfg.logdir +
                               '/{}/{}/'.format(model_name, dataset_name))
            for epoch in range(45, 46):
                # requires a regex to adapt the loss value in the file name here
                ckpt_re = ".ckpt-%d" % (num_batches_per_epoch_train * epoch)
                for __file in files:
                    if __file.endswith(ckpt_re + ".index"):
                        ckpt = os.path.join(
                            cfg.logdir +
                            '/{}/{}/'.format(model_name, dataset_name),
                            __file[:-6])
                #ckpt = os.path.join(cfg.logdir, "model.ckpt-%d" % (num_batches_per_epoch_train * epoch))
                ############ Comment out the line below
                #ckpt = os.path.join(cfg.logdir, "caps/mnist/model-0.3764.ckpt-1718")
                saver.restore(sess, ckpt)

                accuracy_sum = 0
                for i in range(num_batches_test):
                    batch_acc_v, summary_str, orig_image, recon_image = sess.run(
                        [batch_acc, summary_op, orig_img, recon_img])
                    print('%d batches tested.' % step)
                    summary_writer.add_summary(summary_str, step)

                    accuracy_sum += batch_acc_v

                    step += 1
                    # display original/reconstructed images in matplotlib
                    plot_imgs(orig_image, i, 'ori')
                    plot_imgs(recon_image, i, 'rec')

                ave_acc = accuracy_sum / num_batches_test
                print('The average accuracy is %f.' % ave_acc)
Example #5
def main(args):
    """Get dataset hyperparameters."""
    assert len(args) == 3 and isinstance(args[1], str) and isinstance(args[2], str)
    dataset_name = args[1]
    model_name = args[2]
    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(
        dataset_name, is_train=False, epochs=cfg.epoch)

    """Set reproduciable random seed"""
    tf.set_random_seed(1234)

    with tf.Graph().as_default():
        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        num_batches_test = int(dataset_size_test / cfg.batch_size * 0.1)
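        # Note: only ~10% of the test set is evaluated per checkpoint here.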

        batch_x, batch_labels = create_inputs()
        batch_x = slim.batch_norm(batch_x, center=False, is_training=False, trainable=False)
        if model_name == "caps":
            output, _ = net.build_arch(batch_x, coord_add,
                                       is_train=False, num_classes=num_classes)
        elif model_name == "cnn_baseline":
            output = net.build_arch_baseline(batch_x,
                                             is_train=False, num_classes=num_classes)
        else:
            raise "Please select model from 'caps' or 'cnn_baseline' as the secondary argument of eval.py!"
        batch_acc = net.test_accuracy(output, batch_labels)
        saver = tf.train.Saver()

        step = 0

        summaries = []
        summaries.append(tf.summary.scalar('accuracy', batch_acc))
        summary_op = tf.summary.merge(summaries)

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, log_device_placement=False)) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            if not os.path.exists(cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name)):
                os.makedirs(cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name))
            summary_writer = tf.summary.FileWriter(
                cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name), graph=sess.graph)  # graph=sess.graph, huge!

            files = os.listdir(cfg.logdir + '/{}/{}/'.format(model_name, dataset_name))
            for epoch in range(1, cfg.epoch):
                # requires a regex to adapt the loss value in the file name here
                ckpt_re = ".ckpt-%d" % (num_batches_per_epoch_train * epoch)
                for __file in files:
                    if __file.endswith(ckpt_re + ".index"):
                        ckpt = os.path.join(cfg.logdir + '/{}/{}/'.format(model_name, dataset_name), __file[:-6])
                # ckpt = os.path.join(cfg.logdir, "model.ckpt-%d" % (num_batches_per_epoch_train * epoch))
                saver.restore(sess, ckpt)

                accuracy_sum = 0
                for i in range(num_batches_test):
                    batch_acc_v, summary_str = sess.run([batch_acc, summary_op])
                    print('%d batches tested.' % step)
                    summary_writer.add_summary(summary_str, step)

                    accuracy_sum += batch_acc_v

                    step += 1

                ave_acc = accuracy_sum / num_batches_test
                print('The average accuracy is %f.' % ave_acc)

            coord.join(threads)
Example #6
def main(args):
    """Get dataset hyperparameters."""
    assert len(args) == 2 and isinstance(args[1], str)
    dataset_name = args[1]
    logger.info('Using dataset: {}'.format(dataset_name))
    coord_add = get_coord_add(dataset_name)
    num_classes = get_num_classes(dataset_name)

    dataset_size = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=True,
                                      epochs=cfg.epoch)
    test_inputs = get_create_inputs(dataset_name,
                                    is_train=False,
                                    epochs=cfg.epoch)
    """Set reproduciable random seed"""
    tf.set_random_seed(1234)

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        """Get global_step."""
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        """Get batches per epoch."""
        num_batches_per_epoch = int(dataset_size / cfg.batch_size)
        num_batches_test = int(dataset_size_test / cfg.batch_size)
        """Set tf summaries."""
        summaries = []
        valid_sum = []
        """Use exponential decay leanring rate?"""
        # lrn_rate = tf.maximum(tf.train.exponential_decay(
        #     1e-2, global_step, num_batches_per_epoch, 0.8), 1e-5)
        # summaries.append(tf.summary.scalar('learning_rate', lrn_rate))
        opt = tf.train.AdamOptimizer(learning_rate=0.001)
        """Get batch from data queue."""
        train_q = create_inputs()
        test_q = test_inputs()
        use_train_data = tf.placeholder(dtype=tf.bool, shape=())
        batch_x, batch_labels = tf.cond(use_train_data,
                                        true_fn=lambda: train_q,
                                        false_fn=lambda: test_q)
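        # Caveat: train_q and test_q are built outside the tf.cond lambdas, so
        # both input queues are dequeued on every step regardless of the flag.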
        # batch_y = tf.one_hot(batch_labels, depth=10, axis=1, dtype=tf.float32)
        """Define the dataflow graph."""
        m_op = tf.placeholder(dtype=tf.float32, shape=())
        with tf.device('/gpu:0'):
            with slim.arg_scope([slim.variable], device='/cpu:0'):
                norm_batch_x = tf.contrib.layers.batch_norm(batch_x,
                                                            is_training=True)

                # Select network architecture.
                if cfg.network == 'conv':
                    import capsnet_em as net
                    output = net.build_arch(norm_batch_x,
                                            coord_add,
                                            is_train=True,
                                            num_classes=num_classes)
                elif cfg.network == 'fc':
                    import capsnet_fc as net
                    output = net.build_arch(norm_batch_x,
                                            is_train=True,
                                            num_classes=num_classes)
                else:
                    raise ValueError('Invalid network architecture: %s' %
                                     cfg.network)

                # Select loss function.
                if cfg.loss_fn == 'spread':
                    loss = net.spread_loss(output, batch_labels, m_op)
                elif cfg.loss_fn == 'margin':
                    loss = net.margin_loss(output, batch_labels)
                elif cfg.loss_fn == 'cross_en':
                    loss = net.cross_entropy_loss(output, batch_labels)
                else:
                    raise ValueError('Invalid loss function: %s' % cfg.loss_fn)

                acc = net.accuracy(output, batch_labels)
            """Compute gradient."""
            grad = opt.compute_gradients(loss)
            # See: https://stackoverflow.com/questions/40701712/how-to-check-nan-in-gradients-in-tensorflow-when-updating
            grad_check = [tf.check_numerics(g, message='Gradient NaN Found!')
                          for g, _ in grad] + \
                         [tf.check_numerics(loss, message='Loss NaN Found')]
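            # check_numerics raises InvalidArgumentError on NaN/Inf; the main
            # loop below catches that error and skips the offending batch.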
        """Add to summary."""
        summaries.append(tf.summary.scalar('loss', loss))
        summaries.append(tf.summary.scalar('acc', acc))
        valid_sum.append(tf.summary.scalar('val_loss', loss))
        valid_sum.append(tf.summary.scalar('val_acc', acc))
        """Apply graident."""
        with tf.control_dependencies(grad_check):
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = opt.apply_gradients(grad, global_step=global_step)
        """Set Session settings."""
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=cfg.gpu_frac)
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False,
                                                gpu_options=gpu_options))
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        """Set Saver."""
        var_to_save = [
            v for v in tf.global_variables() if 'Adam' not in v.name
        ]  # Don't save redundant Adam beta/gamma
        saver = tf.train.Saver(var_list=var_to_save, max_to_keep=cfg.epoch)
        """Display parameters"""
        total_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in var_to_save
        ]).astype(np.int32)
        train_p = np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]).astype(np.int32)
        logger.info('Total Parameters: {}'.format(total_p))
        logger.info('Trainable Parameters: {}'.format(train_p))

        # read snapshot
        # latest = os.path.join(cfg.logdir, 'model.ckpt-4680')
        # saver.restore(sess, latest)
        """Set summary op."""
        summary_op = tf.summary.merge(summaries)
        valid_sum_op = tf.summary.merge(valid_sum)
        """Start coord & queue."""
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        """Set summary writer"""
        summary_writer = tf.summary.FileWriter(
            cfg.logdir, graph=None)  # graph = sess.graph, huge!
        """Main loop."""
        m_min = 0.2
        m_max = 0.9
        m = m_min
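        # The spread-loss margin m is annealed linearly from m_min to m_max
        # over the first ~60% of epochs (see the epoch-wise update below).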
        for step in range(cfg.epoch * num_batches_per_epoch):
            if (step % num_batches_per_epoch) == 0:
                tic = time.time()
                progbar = tf.keras.utils.Progbar(
                    num_batches_per_epoch, verbose=(1 if cfg.progbar else 0))
            """"TF queue would pop batch until no file"""
            try:
                _, loss_value, acc_value = sess.run([train_op, loss, acc],
                                                    feed_dict={
                                                        use_train_data: True,
                                                        m_op: m
                                                    })
                progbar.update((step % num_batches_per_epoch),
                               values=[('loss', loss_value),
                                       ('acc', acc_value)])
            except KeyboardInterrupt:
                sess.close()
                sys.exit()
            except tf.errors.InvalidArgumentError:
                logger.warning(
                    'Iteration %d contains NaN gradients. Discarding.' % step)
                continue
            """Write to summary."""
            if step % 10 == 0:
                summary_str = sess.run(summary_op,
                                       feed_dict={
                                           use_train_data: True,
                                           m_op: m
                                       })
                summary_writer.add_summary(summary_str, step)
            """Epoch wise linear annealling."""
            if (step % num_batches_per_epoch) == 0:
                if step > 0:
                    m += (m_max - m_min) / (cfg.epoch * 0.6)
                    if m > m_max:
                        m = m_max
                """Save model periodically"""
                ckpt_path = os.path.join(
                    cfg.logdir, 'model-{0:.4f}.ckpt'.format(loss_value))
                saver.save(sess, ckpt_path, global_step=step)

            # Add a new progress bar
            if ((step + 1) % num_batches_per_epoch) == 0:
                toc = time.time()
                val_loss_value, val_acc_value = (0.0, 0.0)
                for i in range(num_batches_test):
                    val_batch = sess.run([loss, acc],
                                         feed_dict={
                                             use_train_data: False,
                                             m_op: m
                                         })
                    val_loss_batch, val_acc_batch = val_batch
                    val_loss_value += val_loss_batch / num_batches_test
                    val_acc_value += val_acc_batch / num_batches_test
                valid_sum_str = sess.run(valid_sum_op,
                                         feed_dict={
                                             use_train_data: False,
                                             m_op: m
                                         })
                summary_writer.add_summary(valid_sum_str, step)
                print('\nEpoch %d/%d in ' %
                      (step // num_batches_per_epoch + 1, cfg.epoch) +
                      '%.1fs' % (toc - tic) + ' - loss: %f' % val_loss_value +
                      ' - acc: %f' % val_acc_value)
        """Join threads"""
        coord.join(threads)
def main(args):
  
  # Set reproducible random seed
  tf.set_random_seed(1234)
  
  # Directories
  # Get name
  split = FLAGS.load_dir.split('/')
  if split[-1]:
    name = split[-1]
  else:
    name = split[-2]
    
  # Get parent directory
  split = FLAGS.load_dir.split("/" + name)
  parent_dir = split[0]

  test_dir = '{}/{}/test'.format(parent_dir, name)
  test_summary_dir = test_dir + '/summary'

  # Clear the test log directory
  if FLAGS.reset and os.path.exists(test_dir):
    shutil.rmtree(test_dir)
  if not os.path.exists(test_summary_dir):
    os.makedirs(test_summary_dir)
  
  # Logger
  conf.setup_logger(logger_dir=test_dir, name="logger_test.txt")
  logger.info("name: " + name)
  logger.info("parent_dir: " + parent_dir)
  logger.info("test_dir: " + test_dir)
  
  # Load hyperparameters from train run
  conf.load_or_save_hyperparams()
  
  # Get dataset hyperparameters
  logger.info('Using dataset: {}'.format(FLAGS.dataset))
  
  # Dataset
  dataset_size_test  = conf.get_dataset_size_test(FLAGS.dataset)
  num_classes        = conf.get_num_classes(FLAGS.dataset)
  create_inputs_test = conf.get_create_inputs(FLAGS.dataset, mode="test")

  
  #----------------------------------------------------------------------------
  # GRAPH - TEST
  #----------------------------------------------------------------------------
  logger.info('BUILD TEST GRAPH')
  g_test = tf.Graph()
  with g_test.as_default():
    # Get global_step
    global_step = tf.train.get_or_create_global_step()

    num_batches_test = int(dataset_size_test / FLAGS.batch_size)

    # Get data
    input_dict = create_inputs_test()
    batch_x = input_dict['image']
    batch_labels = input_dict['label']
    
    # AG 10/12/2018: Split batch for multi gpu implementation
    # Each split is of size FLAGS.batch_size / FLAGS.num_gpus
    # See: https://github.com/naturomics/CapsNet-
    # Tensorflow/blob/master/dist_version/distributed_train.py
    splits_x = tf.split(
        axis=0, 
        num_or_size_splits=FLAGS.num_gpus, 
        value=batch_x)
    splits_labels = tf.split(
        axis=0, 
        num_or_size_splits=FLAGS.num_gpus, 
        value=batch_labels)
    
    # Build architecture
    build_arch = conf.get_dataset_architecture(FLAGS.dataset)
    # for baseline
    #build_arch = conf.get_dataset_architecture('baseline')
    
    #--------------------------------------------------------------------------
    # MULTI GPU - TEST
    #--------------------------------------------------------------------------
    # Calculate the logits for each model tower
    tower_logits = []
    reuse_variables = None
    for i in range(FLAGS.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('tower_%d' % i) as scope:
          with slim.arg_scope([slim.variable], device='/cpu:0'):
            loss, logits = tower_fn(
                build_arch, 
                splits_x[i], 
                splits_labels[i], 
                scope, 
                num_classes, 
                reuse_variables=reuse_variables, 
                is_train=False)

          # Don't reuse variable for first GPU, but do reuse for others
          reuse_variables = True
          
          # Keep track of losses and logits across for each tower
          tower_logits.append(logits)
          
          # Logits histogram for each tower
          tf.summary.histogram("test_logits", logits)
    
    # Combine logits from all towers
    logits = tf.concat(tower_logits, axis=0)
    
    # Calculate metrics
    test_loss = mod.spread_loss(logits, batch_labels)
    test_acc = met.accuracy(logits, batch_labels)
    
    # Prepare predictions and one-hot labels
    test_probs = tf.nn.softmax(logits=logits)
    test_labels_oh = tf.one_hot(batch_labels, num_classes)
    
    # Group metrics together
    # See: https://cs230-stanford.github.io/tensorflow-model.html
    test_metrics = {'loss' : test_loss,
                   'labels' : batch_labels, 
                   'labels_oh' : test_labels_oh,
                   'logits' : logits,
                   'probs' : test_probs,
                   'acc' : test_acc,
                   }
    
    # Reset and read operations for streaming metrics go here
    test_reset = {}
    test_read = {}
    
    tf.summary.scalar("test_loss", test_loss)
    tf.summary.scalar("test_acc", test_acc)
      
    # Saver
    saver = tf.train.Saver(max_to_keep=None)
    
    # Set summary op
    test_summary = tf.summary.merge_all()
    
  
    #--------------------------------------------------------------------------
    # SESSION - TEST
    #--------------------------------------------------------------------------
    #sess_test = tf.Session(
    #    config=tf.ConfigProto(allow_soft_placement=True, 
    #                          log_device_placement=False), 
    #    graph=g_test)
    # Perry: added in for RTX 2070 incompatibility workaround
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    config.gpu_options.allow_growth = True
    sess_test = tf.Session(config=config, graph=g_test)

   
    
    #sess_test.run(tf.local_variables_initializer())
    #sess_test.run(tf.global_variables_initializer())

    summary_writer = tf.summary.FileWriter(
        test_summary_dir, 
        graph=sess_test.graph)


    ckpts_to_test = []
    load_dir_checkpoint = os.path.join(FLAGS.load_dir, "train", "checkpoint")
    
    # Evaluate the latest ckpt in dir
    if FLAGS.ckpt_name is None:
      latest_ckpt = tf.train.latest_checkpoint(load_dir_checkpoint)
      ckpts_to_test.append(latest_ckpt)

    # Evaluate all ckpts in dir  
    elif FLAGS.ckpt_name == "all":
      # Get list of files in directory and sort by date created
      filenames = os.listdir(load_dir_checkpoint)
      regex = re.compile(r'.*\.index')
      filenames = filter(regex.search, filenames)
      data_ckpts = (os.path.join(load_dir_checkpoint, fn) for fn in filenames)
      data_ckpts = ((os.stat(path), path) for path in data_ckpts)

      # regular files, insert creation date
      data_ckpts = ((stat[ST_CTIME], path) for stat, path in data_ckpts 
                    if S_ISREG(stat[ST_MODE]))
      data_ckpts = sorted(data_ckpts)
      # remove ".index"
      ckpts_to_test = [path[:-6] for ctime, path in data_ckpts]
        
    # Evaluate ckpt specified by name
    else:
      ckpt_name = os.path.join(load_dir_checkpoint, FLAGS.ckpt_name)
      ckpts_to_test.append(ckpt_name)    
      
      
    #--------------------------------------------------------------------------
    # MAIN LOOP
    #--------------------------------------------------------------------------
    # Run testing on checkpoints
    for ckpt in ckpts_to_test:
      saver.restore(sess_test, ckpt)
          
      # Reset accumulators
      accuracy_sum = 0
      loss_sum = 0
      sess_test.run(test_reset)

      for i in range(num_batches_test):
        
        test_metrics_v, test_summary_str_v = sess_test.run(
            [test_metrics, test_summary])
        
        # Update
        accuracy_sum += test_metrics_v['acc']
        loss_sum += test_metrics_v['loss']

        ckpt_num = re.split('-', ckpt)[-1]
        logger.info('TEST ckpt-{}'.format(ckpt_num) 
              + ' bch-{:d}'.format(i) 
              + ' cum_acc: {:.2f}%'.format(accuracy_sum/(i+1)*100) 
              + ' cum_loss: {:.4f}'.format(loss_sum/(i+1)) 
               )

      ave_acc = accuracy_sum / num_batches_test
      ave_loss = loss_sum / num_batches_test
  
      logger.info('TEST ckpt-{}'.format(ckpt_num) 
            + ' avg_acc: {:.2f}%'.format(ave_acc*100) 
            + ' avg_loss: {:.4f}'.format(ave_loss))

      logger.info("Write Test Summary")
      summary_test = tf.Summary()
      summary_test.value.add(tag="test_acc", simple_value=ave_acc)
      summary_test.value.add(tag="test_loss", simple_value=ave_loss)
      summary_writer.add_summary(summary_test, ckpt_num)
Example #8
def main(args):
    """Get dataset hyperparameters."""
    assert len(args) == 2 and isinstance(args[1], str)
    dataset_name = args[1]
    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=False,
                                      epochs=cfg.epoch)
    """Set reproduciable random seed"""
    tf.set_random_seed(1234)

    with tf.Graph().as_default():
        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        num_batches_test = int(dataset_size_test / cfg.batch_size * 0.1)

        batch_x, batch_labels = create_inputs()
        batch_x = slim.batch_norm(batch_x,
                                  center=False,
                                  is_training=False,
                                  trainable=False)
        output, _ = net.build_arch(batch_x,
                                   coord_add,
                                   is_train=False,
                                   num_classes=num_classes)
        batch_acc = net.test_accuracy(output, batch_labels)
        saver = tf.train.Saver()

        step = 0

        summaries = []
        summaries.append(tf.summary.scalar('accuracy', batch_acc))
        summary_op = tf.summary.merge(summaries)

        with tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            summary_writer = tf.summary.FileWriter(
                cfg.test_logdir, graph=sess.graph)  # graph=sess.graph, huge!

            files = os.listdir(cfg.logdir)
            for epoch in range(1, cfg.epoch):
                # requires a regex to adapt the loss value in the file name here
                ckpt_re = ".ckpt-%d" % (num_batches_per_epoch_train * epoch)
                for __file in files:
                    if __file.endswith(ckpt_re + ".index"):
                        ckpt = os.path.join(cfg.logdir, __file[:-6])
                # ckpt = os.path.join(cfg.logdir, "model.ckpt-%d" % (num_batches_per_epoch_train * epoch))
                saver.restore(sess, ckpt)

                accuracy_sum = 0
                for i in range(num_batches_test):
                    batch_acc_v, summary_str = sess.run(
                        [batch_acc, summary_op])
                    print('%d batches tested.' % step)
                    summary_writer.add_summary(summary_str, step)
                    print('Batch accuracy: %f.' % batch_acc_v)

                    accuracy_sum += batch_acc_v

                    step += 1

                ave_acc = accuracy_sum / num_batches_test
                print('The average accuracy is %f.' % ave_acc)

            coord.join(threads)
Example #9
def main(args):
    """Get dataset hyperparameters."""
    assert len(args) == 3 and isinstance(args[1], str) and isinstance(
        args[2], str)
    dataset_name = args[1]
    model_name = args[2]
    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(dataset_name,
                                      is_train=False,
                                      epochs=cfg.epoch)
    """Set reproduciable random seed"""
    tf.set_random_seed(1234)

    with tf.Graph().as_default():
        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        num_batches_test = int(dataset_size_test / cfg.batch_size * 0.1)

        batch_x, batch_labels = create_inputs()
        batch_x = slim.batch_norm(batch_x,
                                  center=False,
                                  is_training=False,
                                  trainable=False)
        if model_name == "caps":
            output, _ = net.build_arch(batch_x,
                                       coord_add,
                                       is_train=False,
                                       num_classes=num_classes)
        elif model_name == "cnn_baseline":
            output = net.build_arch_baseline(batch_x,
                                             is_train=False,
                                             num_classes=num_classes)
        else:
            raise "Please select model from 'caps' or 'cnn_baseline' as the secondary argument of eval.py!"
        batch_acc = net.test_accuracy(output, batch_labels)
        saver = tf.train.Saver()

        step = 0

        summaries = []
        summaries.append(tf.summary.scalar('accuracy', batch_acc))
        summary_op = tf.summary.merge(summaries)

        session_config = tf.ConfigProto(
            device_count={'GPU': 0},
            gpu_options={
                'allow_growth': 1,
                # 'per_process_gpu_memory_fraction': 0.1,
                'visible_device_list': '0'
            },
            allow_soft_placement=True)
        with tf.Session(config=session_config) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            if not os.path.exists(cfg.test_logdir +
                                  '/{}/{}/'.format(model_name, dataset_name)):
                os.makedirs(cfg.test_logdir +
                            '/{}/{}/'.format(model_name, dataset_name))
            summary_writer = tf.summary.FileWriter(
                cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name),
                graph=sess.graph)  # graph=sess.graph, huge!

            files = os.listdir(cfg.logdir +
                               '/{}/{}/'.format(model_name, dataset_name))
            for epoch in range(1, cfg.epoch):
                # requires a regex to adapt the loss value in the file name here
                ckpt_re = ".ckpt-%d" % (num_batches_per_epoch_train * epoch)
                for __file in files:
                    if __file.endswith(ckpt_re + ".index"):
                        ckpt = os.path.join(
                            cfg.logdir +
                            '/{}/{}/'.format(model_name, dataset_name),
                            __file[:-6])
                # ckpt = os.path.join(cfg.logdir, "model.ckpt-%d" % (num_batches_per_epoch_train * epoch))
                saver.restore(sess, ckpt)

                accuracy_sum = 0
                for i in range(num_batches_test):
                    batch_acc_v, summary_str = sess.run(
                        [batch_acc, summary_op])
                    print('%d batches tested.' % step)
                    summary_writer.add_summary(summary_str, step)

                    accuracy_sum += batch_acc_v

                    step += 1

                ave_acc = accuracy_sum / num_batches_test
                print('The average accuracy is %f.' % ave_acc)

            coord.join(threads)
def main(args):
    """Get dataset hyperparameters."""
    assert len(args) == 3 and isinstance(args[1], str) and isinstance(args[2], str)
    dataset_name = args[1]
    model_name = args[2]

    """Set reproduciable random seed"""
    tf.set_random_seed(1234)

    coord_add = get_coord_add(dataset_name)
    dataset_size_train = get_dataset_size_train(dataset_name)
    dataset_size_test = get_dataset_size_test(dataset_name)
    num_classes = get_num_classes(dataset_name)
    create_inputs = get_create_inputs(
        dataset_name, is_train=False, epochs=cfg.epoch)

    with tf.Graph().as_default():
        num_batches_per_epoch_train = int(dataset_size_train / cfg.batch_size)
        num_batches_test = 2  # int(dataset_size_test / cfg.batch_size * 0.1)

        batch_x, batch_labels = create_inputs()
        batch_squash = tf.divide(batch_x, 255.)
        batch_x_norm = slim.batch_norm(batch_x, center=False, is_training=False, trainable=False)
        output, pose_out = net.build_arch(batch_x_norm, coord_add,
                                          is_train=False, num_classes=num_classes)
        tf.logging.debug(pose_out.get_shape())

        batch_acc = net.test_accuracy(output, batch_labels)
        m_op = tf.constant(0.9)
        loss, spread_loss, mse, recon_img_squash = net.spread_loss(
            output, pose_out, batch_squash, batch_labels, m_op)
        tf.summary.scalar('spread_loss', spread_loss)
        tf.summary.scalar('reconstruction_loss', mse)
        tf.summary.scalar('all_loss', loss)
        data_size = int(batch_x.get_shape()[1])
        recon_img = tf.multiply(tf.reshape(recon_img_squash, shape=[
                                cfg.batch_size, data_size, data_size, 1]), 255.)
        orig_img = tf.reshape(batch_x, shape=[
            cfg.batch_size, data_size, data_size, 1])
        tf.summary.image('orig_image', orig_img)
        tf.summary.image('recon_image', recon_img)
        saver = tf.train.Saver()

        step = 0

        tf.summary.scalar('accuracy', batch_acc)
        summary_op = tf.summary.merge_all()

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, log_device_placement=False)) as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            if not os.path.exists(cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name)):
                os.makedirs(cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name))
            summary_writer = tf.summary.FileWriter(
                cfg.test_logdir + '/{}/{}/'.format(model_name, dataset_name), graph=sess.graph)  # graph=sess.graph, huge!

            files = os.listdir(cfg.logdir + '/{}/{}/'.format(model_name, dataset_name))
            for epoch in range(14, 15):
                # requires a regex to adapt the loss value in the file name here
                ckpt_re = ".ckpt-%d" % (num_batches_per_epoch_train * epoch)
                for __file in files:
                    if __file.endswith(ckpt_re + ".index"):
                        ckpt = os.path.join(
                            cfg.logdir + '/{}/{}/'.format(model_name, dataset_name), __file[:-6])
                # ckpt = os.path.join(cfg.logdir, "model.ckpt-%d" % (num_batches_per_epoch_train * epoch))
                saver.restore(sess, ckpt)

                accuracy_sum = 0
                for i in range(num_batches_test):
                    batch_acc_v, summary_str, orig_image, recon_image = sess.run(
                        [batch_acc, summary_op, orig_img, recon_img])
                    print('%d batches tested.' % step)
                    summary_writer.add_summary(summary_str, step)

                    accuracy_sum += batch_acc_v

                    step += 1
                    # display original/reconstructed images in matplotlib
                    plot_imgs(orig_image, i, 'ori')
                    plot_imgs(recon_image, i, 'rec')

                ave_acc = accuracy_sum / num_batches_test
                print('The average accuracy is %f.' % ave_acc)

import logging
import daiquiri
daiquiri.setup(level=logging.DEBUG)
logger = daiquiri.getLogger(__name__)
import capsnet_em as net
from confusion_matrix_API import plot_confusion
import matplotlib.pyplot as plt

####################   Edit here   ######################################
import capsnet_em as net
ckpt = 'logdir/caps/asl/'
dataset_name = 'asl'
# How many batches to test
test_dataset_size = cfg.get_dataset_size_test('asl')
batch_size = cfg.batch_size
num_batches_test = 10 * (test_dataset_size // batch_size)
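# Note: 10 * (test_dataset_size // batch_size) makes roughly 10 passes over the test set.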
# num_batches_test = 10  # Too few batches raises an error, since some plotted rows are NaN
recognize_data_dir = '../data/asl_tf'
recognize_labels_txt_keywords = 'labels.txt'
####################   end       ########################################

def convert_number_to_label_name(the_list, labels_maps):
    for idx, each in enumerate(the_list):
        the_list[idx] = labels_maps[str(each)]
    return the_list

def main(args):
    # 1. Set GPU mode
    session_config = cfg.set_gpu()
def main(args):
  """Run training and validation.
  
  1. Build graphs
      1.1 Training graph to run on multiple GPUs
      1.2 Validation graph to run on multiple GPUs
  2. Configure sessions
      2.1 Train
      2.2 Validate
  3. Main loop
      3.1 Train
      3.2 Write summary
      3.3 Save model
      3.4 Validate model
      
  Author:
    Perry Deng
  """
  
  # Set reproducible random seed
  tf.set_random_seed(1234)
    
  # Directories
  train_dir, train_summary_dir = conf.setup_train_directories()
  
  # Logger
  conf.setup_logger(logger_dir=train_dir, name="logger_train.txt")
  
  # Hyperparameters
  conf.load_or_save_hyperparams(train_dir)
  
  # Get dataset hyperparameters
  logger.info('Using dataset: {}'.format(FLAGS.dataset))
  dataset_size_train = conf.get_dataset_size_train(FLAGS.dataset)\
      if not FLAGS.train_on_test else conf.get_dataset_size_test(FLAGS.dataset)
  dataset_size_val = conf.get_dataset_size_validate(FLAGS.dataset)
  build_arch = conf.get_dataset_architecture(FLAGS.dataset)
  num_classes = conf.get_num_classes(FLAGS.dataset)
  create_inputs_train = conf.get_create_inputs(FLAGS.dataset, mode="train_whole")\
      if not FLAGS.train_on_test else conf.get_create_inputs(FLAGS.dataset, mode="train_on_test")
  create_inputs_train_wholeset = conf.get_create_inputs(FLAGS.dataset, mode="train_whole")
  if dataset_size_val > 0:
    create_inputs_val   = conf.get_create_inputs(FLAGS.dataset, mode="validate")

  
  #*****************************************************************************
  # 1. BUILD GRAPHS
  #*****************************************************************************

  #----------------------------------------------------------------------------
  # GRAPH - TRAIN
  #----------------------------------------------------------------------------
  logger.info('BUILD TRAIN GRAPH')
  g_train = tf.Graph()
  with g_train.as_default(), tf.device('/cpu:0'):
    
    # Get global_step
    global_step = tf.train.get_or_create_global_step()

    # Get batches per epoch
    num_batches_per_epoch = int(dataset_size_train / FLAGS.batch_size)

    # In response to a question on OpenReview, Hinton et al. wrote the 
    # following:
    # "We use an exponential decay with learning rate: 3e-3, decay_steps: 20000,     # decay rate: 0.96."
    # https://openreview.net/forum?id=HJWLfGWRb&noteId=ryxTPFDe2X
    lrn_rate = tf.train.exponential_decay(learning_rate=FLAGS.lrn_rate,
                                          global_step=global_step,
                                          decay_steps=20000,
                                          decay_rate=0.96)
    tf.summary.scalar('learning_rate', lrn_rate)
    opt = tf.train.AdamOptimizer(learning_rate=lrn_rate)
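    # Sanity check of the schedule (staircase not set, so decay is smooth):
    # lr(step) = FLAGS.lrn_rate * 0.96 ** (step / 20000); with the quoted
    # 3e-3 base rate, lr(20000) ≈ 2.88e-3 and lr(100000) ≈ 2.45e-3.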

    # Get batch from data queue. Batch size is FLAGS.batch_size, which is then 
    # divided across multiple GPUs
    input_dict = create_inputs_train()
    batch_x = input_dict['image']
    batch_labels = input_dict['label']
    
    # AG 03/10/2018: Split batch for multi gpu implementation
    # Each split is of size FLAGS.batch_size / FLAGS.num_gpus
    # See: https://github.com/naturomics/CapsNet-Tensorflow/blob/master/
    # dist_version/distributed_train.py
    splits_x = tf.split(
        axis=0, 
        num_or_size_splits=FLAGS.num_gpus, 
        value=batch_x)
    splits_labels = tf.split(
        axis=0, 
        num_or_size_splits=FLAGS.num_gpus, 
        value=batch_labels)
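    # Shape check (hypothetical sizes): with FLAGS.batch_size = 24 and
    # FLAGS.num_gpus = 3, each splits_x[i] has shape [8, H, W, C] and each
    # splits_labels[i] has shape [8]; batch_size must be divisible by num_gpus.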

    
    #--------------------------------------------------------------------------
    # MULTI GPU - TRAIN
    #--------------------------------------------------------------------------
    # Calculate the gradients for each model tower
    tower_grads = []
    tower_losses = []
    tower_logits = []
    tower_target_labels = []
    reuse_variables = None
    for i in range(FLAGS.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('tower_%d' % i) as scope:
          logger.info('TOWER %d' % i)
          #with slim.arg_scope([slim.model_variable, slim.variable],
          # device='/cpu:0'):
          with slim.arg_scope([slim.variable], device='/cpu:0'):
            loss, logits, x, patch, target_labels = tower_fn(
                build_arch,
                splits_x[i],
                splits_labels[i],
                scope,
                num_classes,
                reuse_variables=reuse_variables,
                is_train=True)
          
          # Don't reuse variable for first GPU, but do reuse for others
          reuse_variables = True
          
          # Compute gradients for one GPU
          patch_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           "patch_params")
          grads = opt.compute_gradients(loss, var_list=patch_params)
          
          # Keep track of the gradients across all towers.
          tower_grads.append(grads)
          tower_target_labels.append(target_labels)          

          # Keep track of losses and logits for each tower
          tower_logits.append(logits)
          tower_losses.append(loss)
          
          # Loss for each tower
          tf.summary.scalar("loss", loss)
    
    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grads = average_gradients(tower_grads)
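    # average_gradients is defined elsewhere in this repo; a minimal sketch,
    # assuming the classic TF multi-tower pattern, would be:
    #   def average_gradients(tower_grads):
    #       averaged = []
    #       for grads_and_vars in zip(*tower_grads):
    #           g = tf.reduce_mean(tf.stack([g for g, _ in grads_and_vars]), 0)
    #           averaged.append((g, grads_and_vars[0][1]))
    #       return averaged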
    
    # See: https://stackoverflow.com/questions/40701712/how-to-check-nan-in-
    # gradients-in-tensorflow-when-updating
    grad_check = ([tf.check_numerics(g, message='Gradient NaN Found!') 
                      for g, _ in grads if g is not None]
                  + [tf.check_numerics(loss, message='Loss NaN Found')])
    
    # Apply the gradients to adjust the shared variables
    with tf.control_dependencies(grad_check):
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        train_op = opt.apply_gradients(grads, global_step=global_step)
    
    # Calculate mean loss     
    loss = tf.reduce_mean(tower_losses)
    
    # Calculate accuracy
    logits = tf.concat(tower_logits, axis=0)
    target_labels = tf.concat(tower_target_labels, axis=0)
    acc = met.accuracy(logits, target_labels)
    
    # Prepare predictions and one-hot labels
    probs = tf.nn.softmax(logits=logits)
    labels_oh = tf.one_hot(batch_labels, num_classes)
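    # e.g. batch_labels = [2, 0] with num_classes = 3 gives
    # labels_oh = [[0., 0., 1.], [1., 0., 0.]].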
    
    # Group metrics together
    # See: https://cs230-stanford.github.io/tensorflow-model.html
    trn_metrics = {'loss' : loss,
             'labels' : batch_labels, 
             'labels_oh' : labels_oh,
             'logits' : logits,
             'probs' : probs,
             'acc' : acc,
             }
    
    # Reset and read operations for streaming metrics go here
    trn_reset = {}
    trn_read = {}
    
    # Logging
    tf.summary.scalar('batch_loss', loss)
    tf.summary.scalar('batch_success_rate', acc)

    # Set Saver
    # AG 26/09/2018: Save all variables including Adam so that we can continue 
    # training from where we left off
    # max_to_keep=None should keep all checkpoints
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)
    
    # Display number of parameters
    train_params = np.sum([np.prod(v.get_shape().as_list())
              for v in tf.trainable_variables()]).astype(np.int32)
    logger.info('Trainable Parameters: {}'.format(train_params))
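    # Each variable contributes prod(shape) parameters, e.g. a hypothetical
    # conv kernel of shape [5, 5, 1, 32] adds 5 * 5 * 1 * 32 = 800.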
        
    # Set summary op
    trn_summary = tf.summary.merge_all()

  #----------------------------------------------------------------------------
  # GRAPH - TRAINING SET ACCURACY
  #----------------------------------------------------------------------------
  logger.info('BUILD TRAINING SET ACCURACY GRAPH')
  g_trn_acc = tf.Graph()
  with g_trn_acc.as_default():
    # Get global_step
    global_step = tf.train.get_or_create_global_step()

    
    # Get data
    input_dict = create_inputs_train_wholeset()
    batch_x = input_dict['image']
    batch_labels = input_dict['label']
    
    # AG 10/12/2018: Split batch for multi gpu implementation
    # Each split is of size FLAGS.batch_size / FLAGS.num_gpus
    # See: https://github.com/naturomics/CapsNet-
    # Tensorflow/blob/master/dist_version/distributed_train.py
    splits_x = tf.split(
        axis=0, 
        num_or_size_splits=FLAGS.num_gpus, 
        value=batch_x)
    splits_labels = tf.split(
        axis=0, 
        num_or_size_splits=FLAGS.num_gpus, 
        value=batch_labels)
    
    
    #--------------------------------------------------------------------------
    # MULTI GPU - TRAINING SET ACCURACY
    #--------------------------------------------------------------------------
    # Calculate the logits for each model tower
    tower_logits = []
    tower_target_labels = []
    reuse_variables = None
    for i in range(FLAGS.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('tower_%d' % i) as scope:
          with slim.arg_scope([slim.variable], device='/cpu:0'):
            loss, logits, x, patch, target_labels = tower_fn(
                build_arch, 
                splits_x[i], 
                splits_labels[i], 
                scope, 
                num_classes, 
                reuse_variables=reuse_variables, 
                is_train=False)

          # Don't reuse variable for first GPU, but do reuse for others
          reuse_variables = True
          
          # Keep track of logits and target labels for each tower
          tower_logits.append(logits)
          tower_target_labels.append(target_labels)
          # Histogram of logits for each tower
          tf.summary.histogram("train_set_logits", logits)
    
    # Combine logits from all towers
    logits = tf.concat(tower_logits, axis=0)
    target_labels = tf.concat(tower_target_labels, axis=0)
    # Calculate metrics
    train_set_loss = mod.spread_loss(logits, target_labels)
    train_set_acc = met.accuracy(logits, target_labels)
    
    # Prepare predictions and one-hot labels
    train_set_probs = tf.nn.softmax(logits=logits)
    train_set_labels_oh = tf.one_hot(batch_labels, num_classes)
    
    # Group metrics together
    # See: https://cs230-stanford.github.io/tensorflow-model.html
    train_set_metrics = {'loss' : train_set_loss,
                   'labels' : batch_labels, 
                   'labels_oh' : train_set_labels_oh,
                   'logits' : logits,
                   'probs' : train_set_probs,
                   'acc' : train_set_acc,
                   }
    
    # Reset and read operations for streaming metrics go here
    train_set_reset = {}
    train_set_read = {}
    saver = tf.train.Saver(max_to_keep=None)
    
    tf.summary.scalar("train_set_loss", train_set_loss)
    tf.summary.scalar("train_set_success_rate", train_set_acc)
    trn_acc_summary = tf.summary.merge_all()
  
  if dataset_size_val > 0: 
    #----------------------------------------------------------------------------
    # GRAPH - VALIDATION
    #----------------------------------------------------------------------------
    logger.info('BUILD VALIDATION GRAPH')
    g_val = tf.Graph()
    with g_val.as_default():
      # Get global_step
      global_step = tf.train.get_or_create_global_step()

      num_batches_val = int(dataset_size_val / FLAGS.batch_size)
      
      # Get data
      input_dict = create_inputs_val()
      batch_x = input_dict['image']
      batch_labels = input_dict['label']
      
      # AG 10/12/2018: Split batch for multi gpu implementation
      # Each split is of size FLAGS.batch_size / FLAGS.num_gpus
      # See: https://github.com/naturomics/CapsNet-
      # Tensorflow/blob/master/dist_version/distributed_train.py
      splits_x = tf.split(
          axis=0, 
          num_or_size_splits=FLAGS.num_gpus, 
          value=batch_x)
      splits_labels = tf.split(
          axis=0, 
          num_or_size_splits=FLAGS.num_gpus, 
          value=batch_labels)
      
      
      #--------------------------------------------------------------------------
      # MULTI GPU - VALIDATE
      #--------------------------------------------------------------------------
      # Calculate the logits for each model tower
      tower_logits = []
      tower_target_labels = []
      reuse_variables = None
      for i in range(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
          with tf.name_scope('tower_%d' % i) as scope:
            with slim.arg_scope([slim.variable], device='/cpu:0'):
              loss, logits, x, patch, target_labels = tower_fn(
                  build_arch, 
                  splits_x[i], 
                  splits_labels[i], 
                  scope, 
                  num_classes, 
                  reuse_variables=reuse_variables, 
                  is_train=False)

            # Don't reuse variable for first GPU, but do reuse for others
            reuse_variables = True
            
            # Keep track of logits and target labels for each tower
            tower_logits.append(logits)
            tower_target_labels.append(target_labels)
            # Histogram of logits for each tower
            tf.summary.histogram("val_logits", logits)

      # take patch and patched images from last tower
      val_patch = patch
      val_x = x

      # Combine logits from all towers
      logits = tf.concat(tower_logits, axis=0)
      target_labels = tf.concat(tower_target_labels, axis=0)
      # Calculate metrics
      val_loss = mod.spread_loss(logits, target_labels)
      val_acc = met.accuracy(logits, target_labels)
      
      # Prepare predictions and one-hot labels
      val_probs = tf.nn.softmax(logits=logits)
      val_labels_oh = tf.one_hot(batch_labels, num_classes)
      
      # Group metrics together
      # See: https://cs230-stanford.github.io/tensorflow-model.html
      val_metrics = {'loss' : val_loss,
                     'labels' : batch_labels, 
                     'labels_oh' : val_labels_oh,
                     'logits' : logits,
                     'probs' : val_probs,
                     'acc' : val_acc,
                     }
      val_images = {'patch' : val_patch,
                    'x' : val_x} 
      # Reset and read operations for streaming metrics go here
      val_reset = {}
      val_read = {}
      
      tf.summary.scalar("val_loss", val_loss)
      tf.summary.scalar("val_success_rate", val_acc)
        
      # Saver
      saver = tf.train.Saver(max_to_keep=1)
      
      # Set summary op
      val_summary = tf.summary.merge_all()
       
        
  #****************************************************************************
  # 2. SESSIONS
  #****************************************************************************
          
  #----- SESSION TRAIN -----#
  # Session settings
  #sess_train = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
  #                                              log_device_placement=False),
  #                        graph=g_train)

  # Perry: added in for RTX 2070 incompatibility workaround
  config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
  config.gpu_options.allow_growth = True
  sess_train = tf.Session(config=config, graph=g_train)

  # Debugger
  # AG 05/06/2018: Debugging using either command line or TensorBoard
  if FLAGS.debugger is not None:
    # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    sess_train = tf_debug.TensorBoardDebugWrapperSession(sess_train, 
                                                         FLAGS.debugger)
    
  with g_train.as_default():
    sess_train.run([tf.global_variables_initializer(),
                    tf.local_variables_initializer()])
    
    # Restore previous checkpoint
    # AG 26/09/2018: where should this go???
    if FLAGS.load_dir is not None:
      prev_step = load_training(saver, sess_train, FLAGS.load_dir, opt)
    else:
      prev_step = 0

  # Create summary writer, and write the train graph
  summary_writer = tf.summary.FileWriter(train_summary_dir, 
                                         graph=sess_train.graph)


  #----- SESSION TRAIN SET ACCURACY -----#
  #sess_train_acc = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
  #                                                  log_device_placement=False),
  #                            graph=g_trn_acc)

  # Perry: added in for RTX 2070 incompatibility workaround
  config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
  config.gpu_options.allow_growth = True
  sess_train_acc = tf.Session(config=config, graph=g_trn_acc)

  with g_trn_acc.as_default():
    sess_train_acc.run([tf.local_variables_initializer(), 
                        tf.global_variables_initializer()])


  if dataset_size_val > 0:
    #----- SESSION VALIDATION -----#
    #sess_val = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
    #                                            log_device_placement=False),
    #                      graph=g_val)
 
    # Perry: added in for RTX 2070 incompatibility workaround
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    config.gpu_options.allow_growth = True
    sess_val = tf.Session(config=config, graph=g_val)


    with g_val.as_default():
      sess_val.run([tf.local_variables_initializer(), 
                    tf.global_variables_initializer()])


  #****************************************************************************
  # 3. MAIN LOOP
  #****************************************************************************
  SUMMARY_FREQ = 100
  SAVE_MODEL_FREQ = num_batches_per_epoch # 500
  VAL_FREQ = num_batches_per_epoch # 500
  PROFILE_FREQ = 5
  #print("starting main loop") 
  for step in range(prev_step, FLAGS.epoch * num_batches_per_epoch + 1): 
    #print("looping")
  #for step in range(0,3):
    # AG 23/05/2018: limit number of iterations for testing
    # for step in range(100):
    epoch_decimal = step / num_batches_per_epoch
    epoch = int(np.floor(epoch_decimal))
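    # e.g. with num_batches_per_epoch = 500, step = 1250 gives
    # epoch_decimal = 2.5 and epoch = 2.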
    

    # TF queue would pop batch until no file
    try: 
      # TRAIN
      with g_train.as_default():
    
          # With profiling
          if (FLAGS.profile is True) and ((step % PROFILE_FREQ) == 0): 
            logger.info("Train with Profiling")
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
          # Without profiling
          else:
            run_options = None
            run_metadata = None
          
          # Reset streaming metrics
          if step % (num_batches_per_epoch // 4) == 1:
            logger.info("Reset streaming metrics")
            sess_train.run([trn_reset])
          
          # MAIN RUN
          tic = time.time()
          train_op_v, trn_metrics_v, trn_summary_v = sess_train.run(
              [train_op, trn_metrics, trn_summary], 
              options=run_options, 
              run_metadata=run_metadata)
          toc = time.time()
          
          # Read streaming metrics
          trn_read_v = sess_train.run(trn_read)
          
          # Write summary for profiling
          if run_options is not None: 
            summary_writer.add_run_metadata(
                run_metadata, 'epoch{:f}'.format(epoch_decimal))
          
          # Logging
          #logger.info('TRN'
          #      + ' e-{:d}'.format(epoch)
          #      + ' stp-{:d}'.format(step)
          #      + ' {:.2f}s'.format(toc - tic)
          #      + ' loss: {:.4f}'.format(trn_metrics_v['loss'])
          #      + ' acc: {:.2f}%'.format(trn_metrics_v['acc']*100)
          #       )

    except KeyboardInterrupt:
      sess_train.close()
      if dataset_size_val > 0:
        sess_val.close()
      sys.exit()
      
    except tf.errors.InvalidArgumentError as e:
      logger.warning('%d iteration contains NaN gradients. Discard.' % step)
      logger.error(str(e))
      continue
      
    else:
      # WRITE SUMMARY
      if (step % SUMMARY_FREQ) == 0:
        logger.info("Write Train Summary")
        with g_train.as_default():
          # Summaries from graph
          summary_writer.add_summary(trn_summary_v, step)
          
      # SAVE MODEL
      if (step % SAVE_MODEL_FREQ) == 0:
        logger.info("Save Model")
        with g_train.as_default():
          train_checkpoint_dir = train_dir + '/checkpoint'
          if not os.path.exists(train_checkpoint_dir):
            os.makedirs(train_checkpoint_dir)

          # Save ckpt from train session
          ckpt_path = os.path.join(train_checkpoint_dir, 'model.ckpt' + str(epoch))
          saver.save(sess_train, ckpt_path, global_step=step)
      if (step % VAL_FREQ) == 0:
        # calculate metrics every epoch
        with g_trn_acc.as_default():
          logger.info("Start Train Set Accuracy")
          # Restore ckpt to val session
          latest_ckpt = tf.train.latest_checkpoint(train_checkpoint_dir)
          saver.restore(sess_train_acc, latest_ckpt)
          
          # Reset accumulators
          accuracy_sum = 0
          loss_sum = 0
          sess_train_acc.run(train_set_reset)
          
          for i in range(num_batches_per_epoch):
            train_set_metrics_v, train_set_summary_str_v = sess_train_acc.run(
                [train_set_metrics, trn_acc_summary])
            
            # Update
            accuracy_sum += train_set_metrics_v['acc']
            loss_sum += train_set_metrics_v['loss']
            
            # Read
            train_set_read_v = sess_train_acc.run(train_set_read)
            
            # Get checkpoint number
            ckpt_num = re.split('-', latest_ckpt)[-1]
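            # Checkpoint paths end in '-<global_step>', so a hypothetical
            # '.../model.ckpt3-1500' yields ckpt_num = '1500'.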

          # Average across batches
          ave_acc = accuracy_sum / num_batches_per_epoch
          ave_loss = loss_sum / num_batches_per_epoch
           
          logger.info('TRN ckpt-{}'.format(ckpt_num) 
                      + ' avg_success: {:.2f}%'.format(ave_acc*100) 
                      + ' avg_loss: {:.4f}'.format(ave_loss)
                     )
          
          logger.info("Write Train Summary")
          summary_train = tf.Summary()
          summary_train.value.add(tag="trn_success", simple_value=ave_acc)
          summary_train.value.add(tag="trn_loss", simple_value=ave_loss)
          summary_writer.add_summary(summary_train, epoch)
          

        if dataset_size_val > 0: 
          #----- Validation -----#
          with g_val.as_default():
            logger.info("Start Validation")
            
            # Restore ckpt to val session
            latest_ckpt = tf.train.latest_checkpoint(train_checkpoint_dir)
            saver.restore(sess_val, latest_ckpt)
            
            # Reset accumulators
            accuracy_sum = 0
            loss_sum = 0
            sess_val.run(val_reset)
            
            for i in range(num_batches_val):
              if i == num_batches_val - 1:
                # take a sample of patched images on the last validation batch
                val_metrics_v, val_summary_str_v, val_images_v = sess_val.run(
                    [val_metrics, val_summary, val_images])
                x = val_images_v['x']
                patch = val_images_v['patch']
              else:
                val_metrics_v, val_summary_str_v = sess_val.run(
                    [val_metrics, val_summary])
              # Update
              accuracy_sum += val_metrics_v['acc']
              loss_sum += val_metrics_v['loss']
              
              # Read
              val_read_v = sess_val.run(val_read)
              
              # Get checkpoint number
              ckpt_num = re.split('-', latest_ckpt)[-1]

              # Logging
              #logger.info('VAL ckpt-{}'.format(ckpt_num) 
              #            + ' bch-{:d}'.format(i) 
              #            + ' cum_acc: {:.2f}%'.format(accuracy_sum/(i+1)*100) 
              #            + ' cum_loss: {:.4f}'.format(loss_sum/(i+1))
              #           )
            
            # Average across batches
            ave_acc = accuracy_sum / num_batches_val
            ave_loss = loss_sum / num_batches_val
             
            logger.info('VAL ckpt-{}'.format(ckpt_num) 
                        + ' avg_success: {:.2f}%'.format(ave_acc*100) 
                        + ' avg_loss: {:.4f}'.format(ave_loss)
                       )
            
            logger.info("Write Val Summary")
            summary_val = tf.Summary()
            summary_val.value.add(tag="val_success", simple_value=ave_acc)
            summary_val.value.add(tag="val_loss", simple_value=ave_loss)
            summary_writer.add_summary(summary_val, epoch)
            log_images(summary_writer, "patch", [patch], epoch)
            log_images(summary_writer, "patched_input", x, epoch)
            if patch.shape[-1] == 1:
              patch = np.squeeze(patch, axis=-1)
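            # Assumes patch values lie in [0, 1] (model-dependent); scaling
            # by 255 maps them onto the uint8 range PIL expects.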
            formatted = (patch * 255).astype('uint8')
            img = Image.fromarray(formatted)
            img.save(os.path.join(train_dir, "saved_patch.png"))
 
  # Close (main loop)
  sess_train.close()
  if dataset_size_val > 0:
    sess_val.close()
  sys.exit()
Beispiel #13
def main(args):
    # Set reproducible random seed
    tf.set_random_seed(1234)

    # Directories
    # Get name
    split = FLAGS.load_dir.split('/')
    if split[-1]:
        name = split[-1]
    else:
        name = split[-2]

    # Get parent directory
    split = FLAGS.load_dir.split("/" + name)
    parent_dir = split[0]
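    # Path-parsing example (hypothetical): FLAGS.load_dir = 'logs/run1/'
    # gives name = 'run1' and parent_dir = 'logs'.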

    test_dir = '{}/{}/reconstructions'.format(parent_dir, name)
    test_summary_dir = test_dir + '/summary'

    # Clear the test log directory
    if (FLAGS.reset is True) and os.path.exists(test_dir):
        shutil.rmtree(test_dir)
    if not os.path.exists(test_summary_dir):
        os.makedirs(test_summary_dir)

    # Logger
    conf.setup_logger(logger_dir=test_dir, name="logger_test.txt")
    logger.info("name: " + name)
    logger.info("parent_dir: " + parent_dir)
    logger.info("test_dir: " + test_dir)

    # Load hyperparameters from train run
    conf.load_or_save_hyperparams()

    # Get dataset hyperparameters
    logger.info('Using dataset: {}'.format(FLAGS.dataset))

    # Dataset
    dataset_size_test = conf.get_dataset_size_test(FLAGS.dataset)
    num_classes = conf.get_num_classes(FLAGS.dataset)
    # train mode for random sampling
    create_inputs_test = conf.get_create_inputs(FLAGS.dataset, mode="train")

    # ----------------------------------------------------------------------------
    # GRAPH - TEST
    # ----------------------------------------------------------------------------
    logger.info('BUILD TEST GRAPH')
    g_test = tf.Graph()
    with g_test.as_default():
        tf.train.get_or_create_global_step()
        # Get data
        input_dict = create_inputs_test()
        batch_x = input_dict['image']
        batch_labels = input_dict['label']

        # Build architecture
        build_arch = conf.get_dataset_architecture(FLAGS.dataset)
        # for baseline
        # build_arch = conf.get_dataset_architecture('baseline')

        # --------------------------------------------------------------------------
        # MULTI GPU - TEST
        # --------------------------------------------------------------------------
        # Calculate the logits for each model tower
        with tf.device('/gpu:0'):
            with tf.name_scope('tower_0') as scope:
                with slim.arg_scope([slim.variable], device='/cpu:0'):
                    loss, logits, recon, cf_recon = tower_fn(
                        build_arch,
                        batch_x,
                        batch_labels,
                        scope,
                        num_classes,
                        reuse_variables=tf.AUTO_REUSE,
                        is_train=False)

                # Reshape reconstructions back to input image dimensions
                recon_images = tf.reshape(recon, batch_x.get_shape())
                cf_recon_images = tf.reshape(cf_recon, batch_x.get_shape())
                images = {
                    "reconstructed_images": recon_images,
                    "reconstructed_images_zeroed_background": cf_recon_images,
                    "input": batch_x
                }
        saver = tf.train.Saver(max_to_keep=None)

        # --------------------------------------------------------------------------
        # SESSION - TEST
        # --------------------------------------------------------------------------
        # sess_test = tf.Session(
        #    config=tf.ConfigProto(allow_soft_placement=True,
        #                          log_device_placement=False),
        #    graph=g_test)
        # Perry: added in for RTX 2070 incompatibility workaround
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        config.gpu_options.allow_growth = True
        sess_test = tf.Session(config=config, graph=g_test)

        # sess_test.run(tf.local_variables_initializer())
        # sess_test.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(test_summary_dir,
                                               graph=sess_test.graph)

        ckpts_to_test = []
        load_dir_chechpoint = os.path.join(FLAGS.load_dir, "train",
                                           "checkpoint")

        # Evaluate the latest ckpt in dir
        if FLAGS.ckpt_name is None:
            latest_ckpt = tf.train.latest_checkpoint(load_dir_checkpoint)
            ckpts_to_test.append(latest_ckpt)
        # Evaluate a specific named ckpt in dir
        else:
            ckpt_name = os.path.join(load_dir_checkpoint, FLAGS.ckpt_name)
            ckpts_to_test.append(ckpt_name)

        # --------------------------------------------------------------------------
        # MAIN LOOP
        # --------------------------------------------------------------------------
        # Run testing on checkpoints
        for ckpt in ckpts_to_test:
            saver.restore(sess_test, ckpt)

            for i in range(dataset_size_test):
                out = sess_test.run([images])
                reconstructed_image, reconstructed_image_zeroed_background, input_img =\
                    out[0]["reconstructed_images"], out[0]["reconstructed_images_zeroed_background"], out[0]["input"]
                if reconstructed_image.shape[0] == 1:
                    reconstructed_image = np.squeeze(reconstructed_image,
                                                     axis=0)
                    reconstructed_image_zeroed_background = np.squeeze(
                        reconstructed_image_zeroed_background, axis=0)
                    input_img = np.squeeze(input_img, axis=0)
                if reconstructed_image.shape[-1] == 1:
                    reconstructed_image = np.squeeze(reconstructed_image,
                                                     axis=-1)
                    reconstructed_image_zeroed_background = np.squeeze(
                        reconstructed_image_zeroed_background, axis=-1)
                    input_img = np.squeeze(input_img, axis=-1)
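                # Assumes the arrays are floats in [0, 1] (model-dependent);
                # scaling by 255 converts them to uint8 for PIL.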
                reconstructed_image = Image.fromarray(
                    (reconstructed_image * 255).astype('uint8'))
                reconstructed_image_zeroed_background = Image.fromarray(
                    (reconstructed_image_zeroed_background *
                     255).astype('uint8'))
                input_img = Image.fromarray((input_img * 255).astype('uint8'))
                # Three side-by-side panels: input, reconstruction, zeroed background
                fig = plt.figure(figsize=(9, 3))
                fig.add_subplot(1, 3, 1)
                plt.imshow(input_img)
                fig.add_subplot(1, 3, 2)
                plt.imshow(reconstructed_image)
                fig.add_subplot(1, 3, 3)
                plt.imshow(reconstructed_image_zeroed_background)
                plt.show()