Ejemplo n.º 1
0
def main(_):
    """Train the ``capsules_em`` capsule network on the MNIST training split.

    Builds a TF-Slim input pipeline, the capsule network, a spread loss and an
    Adam train op with gradient-norm clipping. It then hands the train op to
    ``slim.learning.train``, which writes checkpoints and summaries to
    ``FLAGS.log_dir``.

    Args:
        _: Unused. Presumably the argv list passed in by ``tf.app.run``;
            confirm against the entry point.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    FLAGS = settings()

    # Seed both NumPy and the TF graph-level RNG for reproducible runs.
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    # Slim dataset contains data sources, decoder, reader and other meta-information
    dataset = mnist.get_split('train', FLAGS.dataset_dir)
    # Whole batches per epoch, e.g. 60,000 samples / batch size 24 = 2500.
    iterations_per_epoch = dataset.num_samples // FLAGS.batch_size

    # images: Tensor (?, 28, 28, 1)
    # labels: Tensor (?)
    images, labels = load_batch(dataset, FLAGS.batch_size)

    # Tensor(?, dataset.num_classes)
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)

    # The network's output width must match the one-hot label width that
    # spread_loss compares it against. Take the class count from the dataset
    # instead of hard-coding 10, so the two cannot drift apart.
    # poses: Tensor(?, 10, 4, 4) activations: (?, 10) for MNIST's 10 classes
    poses, activations = m_capsules.nets.capsules_net(
        images,
        num_classes=dataset.num_classes,
        iterations=3,
        batch_size=FLAGS.batch_size,
        name='capsules_em')

    global_step = tf.train.get_or_create_global_step()
    # iterations_per_epoch and global_step are passed through, presumably so
    # the loss can schedule itself over training; confirm in spread_loss.
    loss = m_capsules.nets.spread_loss(one_hot_labels,
                                       activations,
                                       iterations_per_epoch,
                                       global_step,
                                       name='spread_loss')
    tf.summary.scalar('losses/spread_loss', loss)

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    # Clip the global gradient norm at 4.0 to guard against exploding updates.
    train_tensor = slim.learning.create_train_op(loss,
                                                 optimizer,
                                                 global_step=global_step,
                                                 clip_gradient_norm=4.0)

    # Summaries every 60 s, checkpoints every 600 s; only the 2 most recent
    # checkpoints are kept on disk.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.log_dir,
        log_every_n_steps=10,
        save_summaries_secs=60,
        saver=tf.train.Saver(max_to_keep=2),
        save_interval_secs=600,
    )
Ejemplo n.º 2
0
def main(_):
    """Train the ``capsules_em`` capsule network on the MNIST training split.

    Sets up the Slim MNIST input pipeline, builds the capsule network and the
    spread loss, and optimises it with Adam (gradient norm clipped to 4.0).
    The training loop runs via ``slim.learning.train``, which logs summaries
    and checkpoints under ``FLAGS.log_dir``.

    Args:
        _: Ignored. Presumably argv from ``tf.app.run``; confirm at the
            entry point.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    FLAGS = settings()

    # Seed NumPy and the TF graph-level RNG so runs are reproducible.
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    # Slim dataset contains data sources, decoder, reader and other meta-information
    dataset = mnist.get_split('train', FLAGS.dataset_dir)
    # Whole batches per epoch, e.g. 60,000 / 24 = 2500.
    iterations_per_epoch = dataset.num_samples // FLAGS.batch_size

    # images: Tensor (?, 28, 28, 1)
    # labels: Tensor (?)
    images, labels = load_batch(
        dataset,
        FLAGS.batch_size)

    # Tensor(?, dataset.num_classes)
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)

    # Use the dataset's class count (10 for MNIST) so the network output
    # always has the same width as the one-hot labels given to spread_loss.
    # poses: Tensor(?, 10, 4, 4) activations: (?, 10)
    poses, activations = m_capsules.nets.capsules_net(
        images,
        num_classes=dataset.num_classes,
        iterations=3,
        batch_size=FLAGS.batch_size,
        name='capsules_em',
    )

    global_step = tf.train.get_or_create_global_step()
    # The step counters are forwarded, presumably so the loss can change over
    # the course of training; confirm in m_capsules.nets.spread_loss.
    loss = m_capsules.nets.spread_loss(
        one_hot_labels, activations, iterations_per_epoch, global_step, name='spread_loss'
    )
    tf.summary.scalar('losses/spread_loss', loss)

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    # Global-norm clipping at 4.0 keeps single updates bounded.
    train_tensor = slim.learning.create_train_op(
        loss, optimizer, global_step=global_step, clip_gradient_norm=4.0
    )

    # Keep only the 2 newest checkpoints; write one every 10 minutes.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.log_dir,
        log_every_n_steps=10,
        save_summaries_secs=60,
        saver=tf.train.Saver(max_to_keep=2),
        save_interval_secs=600,
    )
Ejemplo n.º 3
0
# `warnings` must be imported before use; the filter is installed ahead of the
# TensorFlow/sklearn imports below so warnings raised at import time are
# silenced too.
import warnings

warnings.filterwarnings("ignore")
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from cap_config import settings
from cap_datasets import mnist
from sklearn.metrics import classification_report,confusion_matrix,roc_auc_score,accuracy_score
import m_capsules
import time
tf.logging.set_verbosity(tf.logging.INFO)

FLAGS = settings()

# Seed before loading data, so that any NumPy-based shuffling the loader may
# do is reproducible too. Whether mnist.class_8 uses np.random is not visible
# here; seeding earlier is harmless either way.
np.random.seed(FLAGS.seed)
tf.set_random_seed(FLAGS.seed)

# Number of target classes for the task loaded by mnist.class_8(). The
# one-hot label width and the network's output width must agree.
NUM_CLASSES = 2

# Importing datasets: train / validation / test arrays.
X_train, Y_train, valX, valY, testX, testY = mnist.class_8()

# Graph inputs: batches of 28x28x1 float images and int64 class labels.
X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name="X")
y = tf.placeholder(shape=[None], dtype=tf.int64, name="y")

# Whole batches per epoch over the training set.
iterations_per_epoch = len(X_train) // FLAGS.batch_size

# Tensor(?, NUM_CLASSES)
one_hot_labels = slim.one_hot_encoding(y, NUM_CLASSES)

# By analogy with the 10-class setup, poses is presumably
# (?, NUM_CLASSES, 4, 4) and activations (?, NUM_CLASSES). TODO confirm.
poses, activations = m_capsules.nets.capsules_net(
    X,
    num_classes=NUM_CLASSES,
    iterations=3,
    batch_size=FLAGS.batch_size,
    name='capsules_em')

# Predicted label = index of the largest activation along axis 1.
y_proba_argmax = tf.argmax(activations, axis=1, name="y_proba")
y_pred = y_proba_argmax
correct = tf.equal(y, y_pred, name="correct")
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")