Code example #1
def main(not_parsed_args):
    # we use a margin loss
    model = CapsNet()
    last_epoch = load_weights(model)
    model.compile(loss=margin_loss, optimizer=optimizers.Adam(FLAGS.lr), metrics=['accuracy'])
    model.summary()

    dataset = Dataset(FLAGS.dataset, FLAGS.batch_size)
    tensorboard = TensorBoard(log_dir='./tf_logs', batch_size=FLAGS.batch_size, write_graph=False, write_grads=True, write_images=True, update_freq='batch')
    tensorboard.set_model(model)

    for epoch in range(last_epoch, FLAGS.epochs):
        logging.info('Epoch %d' % epoch)
        model.fit_generator(generator=dataset,
            epochs=1,
            steps_per_epoch=len(dataset),
            verbose=1,
            validation_data=dataset.eval_dataset,
            validation_steps=len(dataset.eval_dataset))

        logging.info('Saving model')
        filename = 'model_%d.h5' % (epoch)
        path = os.path.join(FLAGS.model_dir, filename)
        path_info = os.path.join(FLAGS.model_dir, 'info')
        model.save_weights(path)
        with open(path_info, 'w') as f:
            f.write(filename)
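This example compiles the model with a `margin_loss` whose definition is not shown. Below is a minimal sketch of a Keras-style margin loss in the spirit of Sabour et al. (2017), assuming a tf.keras backend and that `y_pred` carries the per-class capsule lengths; the function body and constants are assumptions, not code from the project.

# Minimal sketch of a CapsNet margin loss for Keras (an assumption, not the
# project's own definition). y_true is one-hot, y_pred holds the per-class
# capsule lengths with shape (batch, num_classes).
from tensorflow.keras import backend as K

def margin_loss(y_true, y_pred, m_plus=0.9, m_minus=0.1, lam=0.5):
    # Present-class term: penalize lengths that fall below m_plus.
    present = y_true * K.square(K.maximum(0.0, m_plus - y_pred))
    # Absent-class term: penalize lengths above m_minus, down-weighted by lam.
    absent = lam * (1.0 - y_true) * K.square(K.maximum(0.0, y_pred - m_minus))
    # Sum over classes, average over the batch.
    return K.mean(K.sum(present + absent, axis=1))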
Code example #2
def main():
    """CapsNet run as module.

    Run full cycle when CapsNet is run as a module.
    """
    people = fetch_lfw_people(
        color=True,
        min_faces_per_person=25,
        # resize=1.,
        # slice_=(slice(48, 202), slice(48, 202))
    )

    data = preprocess(people)

    (x_train, y_train), (x_test, y_test) = data  # noqa

    model = CapsNet(x_train.shape[1:], len(np.unique(y_train, axis=0)))

    model.summary()

    # Start TensorBoard
    tensorboard = callbacks.TensorBoard('model/tensorboard_logs',
                                        batch_size=10,
                                        histogram_freq=1,
                                        write_graph=True,
                                        write_grads=True,
                                        write_images=True)
    model.train(data, batch_size=10, extra_callbacks=[tensorboard])
    model.save('/tmp')

    metrics = model.test(x_test, y_test)
    pprint(metrics)
Code example #3
def main():
    # Device configuration, check cuda availability
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hyperparameters
    BATCH_SIZE = 128
    EPOCHS_NUM = 50
    LR = 0.001
    GAMMA = 0.96

    transform = transforms.Compose([
        # shift by 2 pixels in either direction with zero padding.
        transforms.RandomCrop(28, padding=2),
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, ))
    ])

    train_dataset = torchvision.datasets.MNIST(root=op.join(
        sys.path[0], 'data/'),
                                               train=True,
                                               transform=transform,
                                               download=True)

    test_dataset = torchvision.datasets.MNIST(root=op.join(
        sys.path[0], 'data/'),
                                              train=False,
                                              transform=transform)

    # Data loader
    train_loader = Data.DataLoader(dataset=train_dataset,
                                   batch_size=BATCH_SIZE,
                                   num_workers=4,
                                   shuffle=True)

    test_loader = Data.DataLoader(dataset=test_dataset,
                                  batch_size=BATCH_SIZE,
                                  num_workers=4,
                                  shuffle=True)

    # Load model
    model = CapsNet().to(device)
    # Loss and optimizer
    criterion = CapsuleLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA)

    # Train and test model
    train(model, device, train_loader, criterion, optimizer, scheduler,
          EPOCHS_NUM)
    test(model, device, test_loader)

    # Save model
    torch.save(model, op.join(sys.path[0], 'model/mnist_capsnet.pt'))
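This example and code example #18 below both construct a `CapsuleLoss()` that is not shown. The sketch below assumes the usual CapsNet recipe of a margin loss plus a down-weighted reconstruction MSE, and the call signature `criterion(images, labels, logits, reconstruction)` seen in example #18; the class body and the 5e-4 weight are assumptions, not the projects' code.

# Minimal sketch of a CapsuleLoss module (an assumption, not the projects' own
# code). `labels` is expected to be one-hot, `logits` to hold per-class capsule
# lengths of shape (batch, num_classes).
import torch
from torch import nn

class CapsuleLoss(nn.Module):
    def __init__(self, m_plus=0.9, m_minus=0.1, lam=0.5, recon_weight=5e-4):
        super().__init__()
        self.m_plus, self.m_minus, self.lam = m_plus, m_minus, lam
        self.recon_weight = recon_weight
        self.mse = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, logits, reconstructions):
        # Margin loss: present classes should exceed m_plus,
        # absent classes should stay below m_minus.
        present = labels * torch.relu(self.m_plus - logits) ** 2
        absent = self.lam * (1.0 - labels) * torch.relu(logits - self.m_minus) ** 2
        margin = (present + absent).sum(dim=1).mean()
        # Down-weighted reconstruction loss on the flattened images.
        recon = self.mse(reconstructions.reshape(images.size(0), -1),
                         images.reshape(images.size(0), -1))
        return margin + self.recon_weight * recon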
Code example #4
def main(_):
    dataset = cfg.dataset
    input_shape, num_classes, use_test_queue = get_dataset_values(
        dataset, cfg.test_batch_size, is_training=False)

    tf.logging.info("Initializing CNN for {}...".format(dataset))
    model = CapsNet(input_shape,
                    num_classes,
                    is_training=False,
                    use_test_queue=use_test_queue)
    tf.logging.info("Finished initialization.")

    if not os.path.exists(cfg.logdir):
        os.mkdir(cfg.logdir)
    logdir = os.path.join(cfg.logdir, model.name)
    if not os.path.exists(logdir):
        os.mkdir(logdir)
    logdir = os.path.join(logdir, dataset)
    if not os.path.exists(logdir):
        os.mkdir(logdir)

    sv = tf.train.Supervisor(graph=model.graph,
                             logdir=logdir,
                             save_model_secs=0)

    tf.logging.info("Initialize evaluation...")
    evaluate(model, sv, dataset)
    tf.logging.info("Finished evaluation.")
Code example #5
File: main.py Project: longmoc/sample_models
def eval_single_image(image_path):
    image = read_image(image_path)
    model = CapsNet(train_path, image_size, classes, 1)
    sv = tf.train.Supervisor(graph=model.graph,
                             logdir=logdir,
                             save_model_secs=0)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    with sv.managed_session(config=tf.ConfigProto(
            gpu_options=gpu_options)) as sess:
        sv.saver.restore(sess, tf.train.latest_checkpoint(logdir))
        tf.logging.info('Model restored!')
        y_pred, y_pred_cls = sess.run((model.softmax_v, model.argmax_idx),
                                      {model.X: [image]})
        print(y_pred[0])
        print(y_pred_cls[0], y_pred[0][y_pred_cls[0]][0][0])
Code example #6
    def init_engine(self, is_training=True):
        """
        This function initialize the engine from the config by extracting some parameters
        It also create a saver and the config proto for training configuration

        :param is_training:
        :return:
        """
        self.tf_log_dir = get_from_config('log')
        self.checkpoint_path = get_from_config('checkpoint_path')
        self.batch_size = get_from_config('batch_size')
        self.num_epochs = get_from_config('epochs')
        self.model = CapsNet(is_training=is_training)
        self.saver = None

        self.create_config_proto()
Code example #7
File: main.py Project: longmoc/sample_models
def evaluate(valid_data, batch_size):
    model = CapsNet(train_path, image_size, classes, batch_size)
    num_valid_batch = int(valid_data.num_examples / batch_size) + 1
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    supervisor = tf.train.Supervisor(graph=model.graph,
                                     logdir=logdir,
                                     save_model_secs=0)
    with supervisor.managed_session(config=tf.ConfigProto(
            gpu_options=gpu_options)) as sess:
        supervisor.saver.restore(sess, tf.train.latest_checkpoint(logdir))
        tf.logging.info('Model restored!')
        test_acc = 0
        for step in range(num_valid_batch):
            x, y, _, _ = valid_data.next_batch(batch_size)
            acc = sess.run(model.accuracy, {model.X: x, model.labels: y})
            test_acc += acc
        test_acc = test_acc / num_valid_batch
        print(test_acc)
Code example #8
File: main.py Project: longmoc/sample_models
def train():
    data = dataset.read_train_sets(train_path,
                                   image_size,
                                   classes,
                                   validation_size=validation_size)
    num_tr_batch = int(data.train.num_examples / batch_size) + 1
    # trX, trY, num_tr_batch, valX, valY, num_val_batch = load_data('mnist', batch_size, is_training=True)
    model = CapsNet(train_path, image_size, classes, batch_size)
    sv = tf.train.Supervisor(graph=model.graph,
                             logdir=logdir,
                             save_model_secs=0)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    with sv.managed_session(config=tf.ConfigProto(
            gpu_options=gpu_options)) as sess:
        global_step = 0
        for epoch in range(num_epochs):
            print('Ep %d/%d' % (epoch, num_epochs - 1))
            if sv.should_stop():
                print('Should stop.')
                break
            for step in range(num_tr_batch):
                global_step = epoch * num_tr_batch + step
                img_data, label_data, _, _ = data.train.next_batch(batch_size)
                # Feed the current mini-batch into the model's placeholders.
                if global_step % 10 == 0:
                    _, loss, acc, summary, margin_loss = sess.run(
                        [
                            model.train_op, model.total_loss, model.accuracy,
                            model.train_summary, model.margin_loss
                        ],
                        feed_dict={model.X: img_data, model.labels: label_data})
                    print('Step %d: loss %f \t acc %f \t margin loss %f' %
                          (global_step, loss, acc, margin_loss))
                    sv.summary_writer.add_summary(summary, global_step)
                else:
                    sess.run(model.train_op,
                             feed_dict={model.X: img_data,
                                        model.labels: label_data})

            if (epoch + 1) % 2 == 0:
                sv.saver.save(
                    sess, logdir + '/model_epoch_%04d_step_%02d' %
                    (epoch, global_step))

        evaluate(data.valid, batch_size=10)
Code example #9
File: main.py Project: jarchv/capsnet-tensorflow
        plt.imshow(samples_imgs[img_idx], cmap='gray')
        plt.title("Input: " + str(mnist.test.labels[img_idx]))
        plt.axis("off")

    #plt.show()
    for img_idx in range(num_samples):
        plt.subplot(2, num_samples, num_samples + img_idx + 1)
        plt.imshow(reconstructions_imgs[img_idx], cmap='gray')
        plt.title("Output: " + str(y_pred_value[img_idx]))
        plt.axis("off")

    plt.show()


def count_params():
    size = lambda v: functools.reduce(lambda x, y: x * y,
                                      v.get_shape().as_list())
    n_trainable = sum(size(v) for v in tf.trainable_variables())
    #n_total = sum(size(v) for v in tf.all_variables())

    print("Model size (Trainable): {:.1f}M\n".format(n_trainable / 1000000.0))
    #print("Model size (Total): {}".format(n_total))


if __name__ == '__main__':
    tf.reset_default_graph()
    model = CapsNet(rounds=3)
    #train(model, False, 50)
    test(model)
#reconstruction(model, 5)
Code example #10
        loss = capsule_net.loss(data, output, target, reconstructions)

        test_loss += loss.data.item()
        correct += sum(
            np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(
                target.data.cpu().numpy(), 1))

    tqdm.write("Epoch: [{}/{}], test accuracy: {:.6f}, loss: {:.6f}".format(
        epoch, N_EPOCHS, correct / len(test_loader.dataset),
        test_loss / len(test_loader)))


if __name__ == '__main__':
    torch.manual_seed(1)
    dataset = 'mnist'

    config = Config(dataset)
    mnist = Dataset(dataset, BATCH_SIZE)

    capsule_net = CapsNet(config)
    capsule_net = torch.nn.DataParallel(capsule_net)
    if USE_CUDA:
        capsule_net = capsule_net.cuda()
    capsule_net = capsule_net.module

    optimizer = torch.optim.Adam(capsule_net.parameters(), lr=LEARNING_RATE)

    for e in range(1, N_EPOCHS + 1):
        train(capsule_net, optimizer, mnist.train_loader, e)
        test(capsule_net, mnist.test_loader, e)
Code example #11
File: main.py Project: JensOverby/MatrixCaps
                                              shuffle=False,
                                              drop_last=False)

    sup_iterator = iter(train_loader)
    test_iterator = iter(test_loader)
    _, imgs, labels = next(sup_iterator)
    sup_iterator = iter(train_loader)  # reset after peeking at one batch
    """
    Setup model, load it to CUDA and make JIT compilation
    """
    normalize = 0
    if args.normalize:
        normalize = args.normalize
    model = CapsNet(labels.shape[1],
                    img_shape=imgs[0].shape,
                    dataset=args.dataset,
                    normalize=normalize,
                    device=device)

    if not args.disable_cuda:
        model.cuda()
    """
    if args.jit:
        dummy1 = torch.rand(args.batch_size,4,100,100).float()
        dummy2 = torch.rand(args.batch_size,12).float()
        if not args.disable_cuda:
            dummy1 = dummy1.cuda()
            dummy2 = dummy2.cuda()
        model(lambda_, dummy1, dummy2)
        model = torch.jit.trace(model, (lambda_, dummy1, dummy2), check_inputs=[(lambda_, dummy1, dummy2)])
    """
Code example #12
checkpoint_path = os.path.join(args.directory, args.name)
n_samples = 5

idx = np.random.choice(x_test.shape[0], size=n_samples, replace=False)
sample_images = x_test[idx, :]
sample_images = sample_images.reshape(-1, 28, 28, 1)

# Placeholders
X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name="X")
y = tf.placeholder(shape=[None], dtype=tf.int64, name="y")
mask_with_labels = tf.placeholder_with_default(False,
                                               shape=(),
                                               name="mask_with_labels")

# Rebuild the model graph.
caps2_output = CapsNet(X)
y_prob = safe_norm(caps2_output, axis=-2, name="y_prob")
# Choose the class with the highest predicted probability.
y_prob_argmax = tf.argmax(y_prob, axis=2, name="y_predicted_argmax")
y_pred = tf.squeeze(y_prob_argmax, axis=[1, 2], name="y_pred")

if args.reconstruct:
    reconstruction_loss, decoder_output = reconstruct(caps2_output,
                                                      mask_with_labels, X, y,
                                                      y_pred, Labels,
                                                      outputDimension)

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, checkpoint_path)
    if args.reconstruct:
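The call to `safe_norm` above relies on a helper that is not shown. In CapsNet code this helper typically computes the capsule length with a small epsilon inside the square root, because the gradient of `tf.norm` is NaN when the vector is exactly zero. Below is a minimal TF1-style sketch under that assumption; it is not the project's own definition.

# Minimal sketch of safe_norm (an assumption about the helper used above):
# capsule length along `axis`, with an epsilon inside the sqrt so the gradient
# stays finite for zero-length vectors.
import tensorflow as tf

def safe_norm(s, axis=-1, epsilon=1e-7, keepdims=False, name=None):
    with tf.name_scope(name, default_name="safe_norm"):
        squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keepdims=keepdims)
        return tf.sqrt(squared_norm + epsilon)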
Code example #13
# Load MNIST and split it into training and test sets.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixel values to [0, 1] and reshape to (N, 28, 28, 1).
x_train = (x_train / 255).astype("float32").reshape([-1, 28, 28, 1])
x_test = (x_test / 255).astype("float32").reshape([-1, 28, 28, 1])
y_train = y_train.astype("int32")
y_test = y_test.astype("int32")

reconstructor = tf.keras.models.Sequential()
reconstructor.add(tf.keras.layers.Flatten())
reconstructor.add(tf.keras.layers.Dense(512, activation='relu'))
reconstructor.add(tf.keras.layers.Dense(1024, activation='relu'))
reconstructor.add(tf.keras.layers.Dense(784, activation='sigmoid'))

capsNet = CapsNet()
optimizer = tf.keras.optimizers.Adam()
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy')
train_reconstructor_mae = tf.keras.metrics.Mean(name='train_reconstructor_mae')
checkpoint_path = "./MNIST_checkpoints"
ckpt = tf.train.Checkpoint(capsNet=capsNet, reconstructor=reconstructor)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!!')

train_step_signature = [
    tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
Code example #14
File: train.py Project: erikreppel/capsulenet
        acc = utils.categorical_accuracy(y.float(), preds.cpu().data)
        logger.log(epoch,
                   i,
                   len(dataloader.dataset),
                   start + '_TEST',
                   loss=loss.data[0],
                   acc=acc)


trainloader, testloader = utils.mnist_dataloaders(args.data_path,
                                                  args.batch_size,
                                                  args.use_gpu)

model = CapsNet(n_conv_channel=256,
                n_primary_caps=8,
                primary_cap_size=1152,
                output_unit_size=16,
                n_routing_iter=3)

# load state from past runs
if args.load_checkpoint != '':
    model.load_state_dict(torch.load(args.load_checkpoint))

# move to GPU
model = model.cuda() if args.use_gpu else model
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# setup decoder for training
decoder = Decoder()
decoder = decoder.cuda() if args.use_gpu else decoder
decoder_optim = torch.optim.Adam(decoder.parameters(), lr=0.001)
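For context, `primary_cap_size=1152` appears to correspond to the 6x6x32 grid of primary capsules that the standard MNIST CapsNet produces (with `n_primary_caps=8` as the primary capsule dimension and `output_unit_size=16` as the DigitCaps dimension). The snippet below only reproduces that arithmetic; it is an inference from the usual architecture, not code from this project.

# Sanity check of the 1152 figure, assuming the standard MNIST CapsNet:
# 28x28 input -> 9x9 conv (stride 1) -> 20x20 feature map
# -> 9x9 conv (stride 2) -> 6x6 grid, with 32 channels of 8-D primary capsules.
conv1_out = 28 - 9 + 1                    # 20
primary_grid = (conv1_out - 9) // 2 + 1   # 6
n_primary_capsules = primary_grid * primary_grid * 32
print(n_primary_capsules)                 # 1152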
Code example #15
import torch
from torch import optim
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms

from capsnet import CapsNet
from functions import DigitMarginLoss
from functions import accuracy

train_loader = DataLoader(datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([
    # transforms.RandomShift(2),
    transforms.ToTensor()])), batch_size=1, shuffle=True)

test_loader = DataLoader(datasets.MNIST('data', train=False, transform=transforms.Compose([
    transforms.ToTensor()])), batch_size=1)

model = CapsNet()
optimizer = optim.Adam(model.parameters())
margin_loss = DigitMarginLoss()
reconstruction_loss = torch.nn.MSELoss(size_average=False)
model.train()

for epoch in range(1, 11):
    epoch_tot_loss = 0
    epoch_tot_acc = 0
    for batch, (data, target) in enumerate(train_loader, 1):
        data = Variable(data)
        target = Variable(target)

        digit_caps, reconstruction = model(data, target)
        loss = margin_loss(digit_caps, target) + 0.0005 * reconstruction_loss(reconstruction, data.view(-1))
        epoch_tot_loss += loss
Code example #16
        loss = capsule_net.loss(data, output, target, reconstructions)

        test_loss += loss.data.item()
        correct += sum(
            np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(
                target.data.cpu().numpy(), 1))

    tqdm.write("Epoch: [{}/{}], test accuracy: {:.6f}, loss: {:.6f}".format(
        epoch, N_EPOCHS, correct / len(test_loader.dataset),
        test_loss / len(test_loader)))


if __name__ == '__main__':
    torch.manual_seed(1)
    dataset = 'cifar10'
    # dataset = 'mnist'
    config = Config(dataset)
    mnist = Dataset(dataset, BATCH_SIZE)

    capsule_net = CapsNet(config)
    capsule_net = torch.nn.DataParallel(capsule_net)
    if USE_CUDA:
        capsule_net = capsule_net.cuda()
    capsule_net = capsule_net.module

    optimizer = torch.optim.Adam(capsule_net.parameters())

    for e in range(1, N_EPOCHS + 1):
        train(capsule_net, optimizer, mnist.train_loader, e)
        test(capsule_net, mnist.test_loader, e)
Code example #17
from capsnet import CapsNet
from getdata import get_train_data,get_test_data
import torch
from torch.autograd import Variable
import torch.utils.data as Data

# train_data: [17023, 6000, 1], train_tag: [17023], as numpy arrays
train_data, train_tag = get_train_data()
# test_data, test_tag = get_test_data()
net = CapsNet()
# Load the dataset
# Convert to tensors and permute [17023, 6000, 1] => [17023, 1, 6000]
train_data = torch.FloatTensor(train_data).permute(0, 2, 1)
train_tag = torch.LongTensor(train_tag)

train_set = Data.TensorDataset(train_data, train_tag)
train_loader = Data.DataLoader(dataset=train_set, batch_size=32, shuffle=True)
# Optimizer setup
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
loss_func = torch.nn.CrossEntropyLoss()

# train...
print('Starting training:')
accc = []
epoch = 50
maxnum = 0
for i in range(epoch):
    for j, (x, y) in enumerate(train_loader):

        x, y = Variable(x), Variable(y)
        out = net(x)
Code example #18
File: main.py Project: Xiangs18/CapsNet
def main():
    # Load model
    model = CapsNet().to(device)
    criterion = CapsuleLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

    # Load data
    transform = transforms.Compose([
        # shift by 2 pixels in either direction with zero padding.
        transforms.RandomCrop(28, padding=2),
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    DATA_PATH = "./data"
    BATCH_SIZE = 128
    train_loader = DataLoader(
        dataset=MNIST(root=DATA_PATH,
                      download=True,
                      train=True,
                      transform=transform),
        batch_size=BATCH_SIZE,
        num_workers=4,
        shuffle=True,
    )
    test_loader = DataLoader(
        dataset=MNIST(root=DATA_PATH,
                      download=True,
                      train=False,
                      transform=transform),
        batch_size=BATCH_SIZE,
        num_workers=4,
        shuffle=True,
    )

    # Train
    EPOCHES = 50
    model.train()
    for ep in range(EPOCHES):
        batch_id = 1
        correct, total, total_loss = 0, 0, 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            images = images.to(device)
            labels = torch.eye(10).index_select(dim=0, index=labels).to(device)
            logits, reconstruction = model(images)

            # Compute loss & accuracy
            loss = criterion(images, labels, logits, reconstruction)
            correct += torch.sum(
                torch.argmax(logits, dim=1) == torch.argmax(labels,
                                                            dim=1)).item()
            total += len(labels)
            accuracy = correct / total
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            print("Epoch {}, batch {}, loss: {}, accuracy: {}".format(
                ep + 1, batch_id, total_loss / batch_id, accuracy))
            batch_id += 1
        scheduler.step(ep)
        print("Total loss for epoch {}: {}".format(ep + 1, total_loss))

    # Eval
    model.eval()
    correct, total = 0, 0
    for images, labels in test_loader:
        # Move images to the device
        images = images.to(device)
        # Categorical (one-hot) encoding of the labels
        labels = torch.eye(10).index_select(dim=0, index=labels).to(device)
        logits, reconstructions = model(images)
        pred_labels = torch.argmax(logits, dim=1)
        correct += torch.sum(pred_labels == torch.argmax(labels, dim=1)).item()
        total += len(labels)
    print("Accuracy: {}".format(correct / total))

    # Save model
    torch.save(
        model.state_dict(),
        "./model/capsnet_ep{}_acc{}.pt".format(EPOCHES, correct / total),
    )
Code example #19
                                     download=True,
                                     transform=transform)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=False,
                                         num_workers=2)


def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


# Initialize capsnet
capsnet = CapsNet()
# Initialize optimizer and loss function
optimizer = torch.optim.Adam(capsnet.get_params())
# Using spread loss for capsnet
# loss_function = lambda x, y: spread_loss(x, y, 1)
loss_function = nn.CrossEntropyLoss()
loss_margin = 0.2

# Start training
start_time = time.time()
print('Training...')
for epoch in range(EPOCHS):
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        # Zero gradients
        optimizer.zero_grad()
Code example #20
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    DATASET_CONFIG = MNIST

    dataset_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST('./data',
                                   train=True,
                                   download=True,
                                   transform=dataset_transform)
    test_dataset = datasets.MNIST('./data',
                                  train=False,
                                  download=True,
                                  transform=dataset_transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    torch.manual_seed(1)

    capsule_net = CapsNet(**DATASET_CONFIG)
    capsule_net.to(DEVICE)

    optimizer = torch.optim.Adam(capsule_net.parameters(), lr=LEARNING_RATE)

    for e in range(1, 1 + EPOCHS):
        train(capsule_net, optimizer, train_loader, e, device=DEVICE)
        evaluate(capsule_net, test_loader, e, device=DEVICE)
Code example #21
    def __init__(self):
        self.capsnet = CapsNet()
Code example #22
File: train.py Project: rremani/capsnet.pytorch
import torch
from torch import optim
from torch.autograd import Variable
from torchvision import datasets, transforms

from capsnet import CapsNet  # model definition from the same project (assumed module name)
from loss import DigitMarginLoss
from utils import accuracy


train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([
        # transforms.RandomShift(2),
        transforms.ToTensor()
    ])), shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
    ])))

model = CapsNet()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
margin_loss = DigitMarginLoss()
reconstruction_loss = torch.nn.MSELoss(size_average=False)
model.train()

for epoch in range(1, 11):
    epoch_tot_loss = 0
    epoch_tot_acc = 0
    for batch, (input, target) in enumerate(train_loader, 1):
        input = Variable(input)
        target = Variable(target)
    
        digit_caps, reconstruction = model(input, target)
        loss = margin_loss(digit_caps, target) + 0.0005 * reconstruction_loss(reconstruction, input.view(-1))
        epoch_tot_loss += loss
Code example #23
def main(_):
    # Get the batches of training and testing data
    training_X_batch, training_Y_batch, testing_batch = u.create_training_and_testing_batches_for_ld_MMNIST(num_examples=num_examples)
    # training_X_batch, training_Y_batch, testing_batch = u.create_training_and_testing_batches()
    # training_X_batch_full, training_Y_batch_full, testing_batch_full = u.create_training_and_testing_batches()

    # Create a capsule network
    capsnet = CapsNet()

    # Get training errors and reconstructions
    (train_total_error,
     train_margin_error,
     train_reconstruction_error,
     train_reconstructed_combined_image,
     train_reconstructed_first_image,
     train_reconstructed_second_image,
     train_accuracy,
     train_memo_image_reconstructions,
     train_memo_margin_loss,
     train_memo_accuracy) = capsnet.compute_output(training_X_batch, training_Y_batch, keep_prob=0.5)

    # Create operations to minimize training loss
    train_op = tf.train.AdamOptimizer().minimize(train_total_error)
    train_memo_op = tf.train.AdamOptimizer().minimize(train_memo_margin_loss)

    # Get test errors and reconstructions
    (test_0px_sub_total_error,
     test_0px_sub_margin_error,
     test_0px_sub_reconstruction_error,
     _,  # For testing, we don't care about the reconstructed images
     _,  # Reconstructed image 1
     _,  # Reconstructed image 2
     test_0px_sub_accuracy,
     test_0px_sub_memo_image_reconstructions,
     test_0px_sub_memo_margin_loss,
     test_0px_sub_memo_accuracy) = capsnet.compute_output(testing_batch[0], testing_batch[1], keep_prob=1)

    (test_2px_sub_total_error,
     test_2px_sub_margin_error,
     test_2px_sub_reconstruction_error,
     _,  # For testing, we don't care about the reconstructed images
     _,  # Reconstructed image 1
     _,  # Reconstructed image 2
     test_2px_sub_accuracy,
     test_2px_sub_memo_image_reconstructions,
     test_2px_sub_memo_margin_loss,
     test_2px_sub_memo_accuracy) = capsnet.compute_output(testing_batch[2], testing_batch[3], keep_prob=1)

    (test_4px_sub_total_error,
     test_4px_sub_margin_error,
     test_4px_sub_reconstruction_error,
     _,  # For testing, we don't care about the reconstructed images
     _,  # Reconstructed image 1
     _,  # Reconstructed image 2
     test_4px_sub_accuracy,
     test_4px_sub_memo_image_reconstructions,
     test_4px_sub_memo_margin_loss,
     test_4px_sub_memo_accuracy) = capsnet.compute_output(testing_batch[4], testing_batch[5], keep_prob=1)

    (test_6px_sub_total_error,
     test_6px_sub_margin_error,
     test_6px_sub_reconstruction_error,
     _,  # For testing, we don't care about the reconstructed images
     _,  # Reconstructed image 1
     _,  # Reconstructed image 2
     test_6px_sub_accuracy,
     test_6px_sub_memo_image_reconstructions,
     test_6px_sub_memo_margin_loss,
     test_6px_sub_memo_accuracy) = capsnet.compute_output(testing_batch[6], testing_batch[7], keep_prob=1)

    (test_8px_sub_total_error,
     test_8px_sub_margin_error,
     test_8px_sub_reconstruction_error,
     _,  # For testing, we don't care about the reconstructed images
     _,  # Reconstructed image 1
     _,  # Reconstructed image 2
     test_8px_sub_accuracy,
     test_8px_sub_memo_image_reconstructions,
     test_8px_sub_memo_margin_loss,
     test_8px_sub_memo_accuracy) = capsnet.compute_output(testing_batch[8], testing_batch[9], keep_prob=1)

    (test_0px_full_total_error,
     test_0px_full_margin_error,
     test_0px_full_reconstruction_error,
     _,  # For testing, we don't care about the reconstructed images
     _,  # Reconstructed image 1
     _,  # Reconstructed image 2
     test_0px_full_accuracy,
     test_0px_full_memo_image_reconstructions,
     test_0px_full_memo_margin_loss,
     test_0px_full_memo_accuracy) = capsnet.compute_output(testing_batch[10], testing_batch[11], keep_prob=1)

    (test_2px_full_total_error,
     test_2px_full_margin_error,
     test_2px_full_reconstruction_error,
     _,  # For testing, we don't care about the reconstructed images
     _,  # Reconstructed image 1
     _,  # Reconstructed image 2
     test_2px_full_accuracy,
     test_2px_full_memo_image_reconstructions,
     test_2px_full_memo_margin_loss,
     test_2px_full_memo_accuracy) = capsnet.compute_output(testing_batch[12], testing_batch[13], keep_prob=1)

    (test_4px_full_total_error,
     test_4px_full_margin_error,
     test_4px_full_reconstruction_error,
     _,  # For testing, we don't care about the reconstructed images
     _,  # Reconstructed image 1
     _,  # Reconstructed image 2
     test_4px_full_accuracy,
     test_4px_full_memo_image_reconstructions,
     test_4px_full_memo_margin_loss,
     test_4px_full_memo_accuracy) = capsnet.compute_output(testing_batch[14], testing_batch[15], keep_prob=1)

    (test_6px_full_total_error,
     test_6px_full_margin_error,
     test_6px_full_reconstruction_error,
     _,  # For testing, we don't care about the reconstructed images
     _,  # Reconstructed image 1
     _,  # Reconstructed image 2
     test_6px_full_accuracy,
     test_6px_full_memo_image_reconstructions,
     test_6px_full_memo_margin_loss,
     test_6px_full_memo_accuracy) = capsnet.compute_output(testing_batch[16], testing_batch[17], keep_prob=1)

    (test_8px_full_total_error,
     test_8px_full_margin_error,
     test_8px_full_reconstruction_error,
     _,  # For testing, we don't care about the reconstructed images
     _,  # Reconstructed image 1
     _,  # Reconstructed image 2
     test_8px_full_accuracy,
     test_8px_full_memo_image_reconstructions,
     test_8px_full_memo_margin_loss,
     test_8px_full_memo_accuracy) = capsnet.compute_output(testing_batch[18], testing_batch[19], keep_prob=1)

    # For model saving
    saver = tf.train.Saver()

    # For output data
    f1 = open('out_capsgan_test_on_ld' + str(num_examples) + '.csv', 'w+', 0)

    with tf.Session() as sess:
        # Initialize the graph, and receive the queue coordinator and the training monitor
        coord, training_monitor = u.init(sess)

        saver.restore(sess, "/home/urops/andrewg/transfer-learning/generative_capsnet/saved_state/model51209.ckpt")
        print("Model restored.")

        # Pretrain the network on the first part--classifying and splitting
        for i in range(1, 1500):
            sys.stdout.write("Pretraining: " + str(i))
            # sys.stdout.write("Pretraining: (%d/1000)   \r" % (i))
            sys.stdout.flush()
            sess.run([train_op])

        # Now, run the actual training
        for batch_num in range(1, int(600000 / cfg.batch_size) * cfg.num_epochs):

            # Run the training operations, and get the corresponding errors
            (curr_train_total_error,
             curr_train_margin_error,
             curr_train_reconstruction_error,
             curr_train_accuracy,
             curr_memo_margin_loss,
             curr_memo_accuracy,
             _,
             _) = sess.run([train_total_error,
                            train_margin_error,
                            train_reconstruction_error,
                            train_accuracy,
                            train_memo_margin_loss,
                            train_memo_accuracy,
                            train_op,
                            train_memo_op])

            # Add all the losses to the training monitor
            training_monitor.addsix(curr_train_total_error,
                                    curr_train_margin_error,
                                    curr_train_reconstruction_error,
                                    curr_train_accuracy,
                                    curr_memo_margin_loss,
                                    curr_memo_accuracy)

            print(("Step: " + str(batch_num)).ljust(15)[:15]),
            print(("Total loss: " + str(curr_train_total_error)).ljust(25)[:25]),
            print(("Margin loss: " + str(curr_train_margin_error)).ljust(25)[:25]),
            print(("Reconstruct loss: " + str(curr_train_reconstruction_error)).ljust(25)[:25]),
            print(("Train Accuracy: " + str(curr_train_accuracy)).ljust(25)[:25]),
            print(("Memo margin: " + str(curr_memo_margin_loss)).ljust(25)[:25]),
            print(("Memo accuracy: " + str(curr_memo_accuracy)).ljust(25)[:25]),
            print("\n")

            # Every 100 iterations, display the current testing results.
            if batch_num % 100 == 9:
                # Get the errors on the test data
                (curr_total_error,
                 curr_margin_error,
                 curr_reconstruction_error,
                 curr_memo_margin_error,
                 curr_memo_accuracy,
                 curr_accuracy,
                 curr_accuracy_2px,
                 curr_accuracy_4px,
                 curr_accuracy_6px,
                 curr_accuracy_8px,
                 curr_sub_accuracy,
                 curr_sub_accuracy_2px,
                 curr_sub_accuracy_4px,
                 curr_sub_accuracy_6px,
                 curr_sub_accuracy_8px) = sess.run([test_0px_full_total_error,
                                                    test_0px_full_margin_error,
                                                    test_0px_full_reconstruction_error,
                                                    test_0px_full_memo_margin_loss,
                                                    test_0px_full_memo_accuracy,
                                                    test_0px_full_accuracy,
                                                    test_2px_full_accuracy,
                                                    test_4px_full_accuracy,
                                                    test_6px_full_accuracy,
                                                    test_8px_full_accuracy,
                                                    test_0px_sub_accuracy,
                                                    test_2px_sub_accuracy,
                                                    test_4px_sub_accuracy,
                                                    test_6px_sub_accuracy,
                                                    test_8px_sub_accuracy
                                                    ])

                # Add the losses to the training monitor
                training_monitor.addeleventest(curr_total_error,
                                               curr_margin_error,
                                               curr_reconstruction_error,
                                               curr_memo_margin_error,
                                               curr_memo_accuracy,
                                               curr_accuracy,
                                               curr_accuracy_2px,
                                               curr_accuracy_4px,
                                               curr_accuracy_6px,
                                               curr_accuracy_8px,
                                               curr_sub_accuracy,
                                               curr_sub_accuracy_2px,
                                               curr_sub_accuracy_4px,
                                               curr_sub_accuracy_6px,
                                               curr_sub_accuracy_8px
                                               )

                # Display the current training performance
                training_monitor.prints(file=f1, step=batch_num)
                f1.flush()

                # Save the model
                # save_path = saver.save(sess, "saved/model" + str(batch_num) + ".ckpt")
                # print("Model saved in path: %s" % save_path)

    f1.close()
Code example #24
                                               drop_last=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              num_workers=args.num_workers,
                                              shuffle=True,
                                              drop_last=False)
    sup_iterator = iter(train_loader)
    test_iterator = iter(test_loader)
    imgs, labels = next(sup_iterator)
    sup_iterator = iter(train_loader)  # reset after peeking at one batch
    """
    Setup model, load it to CUDA and make JIT compilation
    """
    #imgs = imgs[:2]
    stat = []
    model = CapsNet(args,
                    len(train_dataset) // (2 * args.batch_size) + 3, stat)

    use_cuda = not args.disable_cuda and torch.cuda.is_available()
    if use_cuda:
        model.cuda()
        imgs = imgs.cuda()
    if args.jit:
        model = torch.jit.trace(model, (imgs,), check_inputs=[(imgs,)])
    else:
        model(imgs)
    print("# model parameters:",
          sum(param.numel() for param in model.parameters()))
    """
    Construct optimizer, scheduler, and loss functions
    """
    optimizer = torch.optim.Adam(model.parameters(),