示例#1
0
def main():
    """Train an EM-routing capsule network and report the best test accuracy.

    Parses CLI args, seeds the RNGs, trains for ``args.epochs`` epochs with
    an accuracy-plateau LR schedule, and snapshots the final model.
    """
    global args, best_prec1
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Seed both CPU and (when enabled) GPU RNGs for reproducibility.
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    device = torch.device("cuda" if args.cuda else "cpu")

    # datasets -- get_setting also reports the number of classes, which the
    # model head and the loss below require (num_class was previously
    # undefined in this function, causing a NameError at model creation).
    num_class, train_loader, test_loader = get_setting(args)

    # model capsule-layer channel sizes
    A, B, C, D = 64, 8, 16, 16
    # A, B, C, D = 32, 32, 32, 32
    model = capsules(A=A, B=B, C=C, D=D, E=num_class,
                     iters=args.em_iters).to(device)

    criterion = SpreadLoss(num_class=num_class, m_min=0.2, m_max=0.9)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # 'max' mode: decay the LR when the (maximized) accuracy plateaus.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=1)

    # Baseline accuracy of the untrained model.
    best_acc = test(test_loader, model, criterion, device)
    for epoch in range(1, args.epochs + 1):
        acc = train(train_loader, model, criterion, optimizer, epoch, device)
        # train() appears to accumulate per-batch accuracy; average it -- confirm.
        acc /= len(train_loader)
        scheduler.step(acc)
        if epoch % args.test_intvl == 0:
            best_acc = max(best_acc, test(test_loader, model, criterion, device))
    # Final evaluation in case the last epoch was not a test interval.
    best_acc = max(best_acc, test(test_loader, model, criterion, device))
    print('best test accuracy: {:.6f}'.format(best_acc))

    snapshot(model, args.snapshot_folder, args.epochs)
def main():
    """Train a capsule network (optionally with a decoder) and log to TensorBoard.

    Snapshots the model whenever either the test loss or test score improves,
    and saves a final "end" snapshot after the last epoch.
    """
    global args, train_writer, test_writer
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Seed both CPU and (when enabled) GPU RNGs for reproducibility.
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    device = torch.device("cuda" if args.cuda else "cpu")

    # tensorboard logging
    train_writer = SummaryWriter(comment='train')
    test_writer = SummaryWriter(comment='test')

    # dataset
    num_class, img_dim, train_loader, test_loader = get_setting(args)

    # model
    #     A, B, C, D = 64, 8, 16, 16
    A, B, C, D = 32, 32, 32, 32
    model = capsules(A=A,
                     B=B,
                     C=C,
                     D=D,
                     E=num_class,
                     iters=args.em_iters,
                     add_decoder=args.add_decoder,
                     img_dim=img_dim).to(device)

    # Count only parameters the optimizer will actually update: the message
    # says "trainable", but the previous code summed *all* parameters.
    print("Number of trainable parameters: {}".format(
        sum(param.numel() for param in model.parameters()
            if param.requires_grad)))
    criterion = CapsuleLoss(alpha=args.alpha,
                            mode='bce',
                            num_class=num_class,
                            add_decoder=args.add_decoder)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    # Baseline loss/score of the untrained model.
    best_loss, best_score = test(test_loader, model, criterion, 0, device)
    for epoch in range(1, args.epochs + 1):
        # train() logs its own metrics; its return value was unused here.
        train(train_loader, model, criterion, optimizer, epoch, device)

        if epoch % args.test_intvl == 0:
            test_loss, test_score = test(test_loader, model, criterion,
                                         epoch * len(train_loader), device)
            # Snapshot whenever either metric improves.
            if test_loss < best_loss or test_score > best_score:
                snapshot(model, args.snapshot_folder, epoch)
            best_loss = min(best_loss, test_loss)
            best_score = max(best_score, test_score)
    print('best test score: {:.6f}'.format(best_score))

    train_writer.close()
    test_writer.close()

    # save end model
    snapshot(model, args.snapshot_folder, 'end_{}'.format(args.epochs))
示例#3
0
def main():
    """Entry point: build an EM-routing capsule model, optionally resume it
    from saved weights, train it, and snapshot whenever validation accuracy
    improves."""
    global args, best_prec1

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Reproducible runs: seed CPU and, when enabled, GPU RNGs.
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    device = torch.device("cuda" if args.cuda else "cpu")

    # datasets
    num_class, train_loader, val_loader = get_setting(args)

    # model
    A, B, C, D = 64, 8, 16, 16
    # A, B, C, D = 32, 32, 32, 32
    model = capsules(A=A, B=B, C=C, D=D, E=num_class,
                     iters=args.em_iters).to(device)

    print(model)
    if args.load_weights:
        # Resume from a previously saved snapshot.
        weights_path = os.path.join(args.snapshot_folder, args.load_weights)
        model.load_state_dict(torch.load(weights_path))

    print_number_parameters(model)
    criterion = SpreadLoss(num_class=num_class, m_min=0.2, m_max=0.9)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3)

    best_acc = 0.0
    for epoch in range(1, args.epochs + 1):
        acc, loss = train(train_loader, model, criterion, optimizer, epoch,
                          device)
        # scheduler.step(acc)

        # Only validate/report on test-interval epochs.
        if epoch % args.test_intvl != 0:
            continue

        val_acc, val_loss = test(val_loader, model, criterion, device)

        print("Train - Average loss: {:.6f} Average acc: {:.6f}".format(
            loss, acc))

        if val_acc > best_acc:
            best_acc = val_acc
            snapshot(model, args.snapshot_folder, epoch)

        print("Current Best: {:.6f}".format(best_acc))
        print('[EPOCH {}] TRAIN_LOSS: {:.6f} TRAIN_ACC: {:.6f}'.format(
            epoch, loss, acc))
        print('[EPOCH {}] VAL_LOSS: {:.6f} VAL_ACC: {:.6f}'.format(
            epoch, val_loss, val_acc))

    print('Best val accuracy: {:.6f}'.format(best_acc))
示例#4
0
def main():
    """Evaluate a floating-point capsule model against its RAVEN counterpart
    (built with reduced bit-width settings) on the test set, restoring both
    from the same trained snapshot."""
    global args, best_prec1
    args = parser.parse_args()

    print("Current setting: cycle-{}\tintwidth-max-{}\tfracwidth-max-{}"
          "\tbitwidth-reduce-{}\trounding-{}\n".format(
              args.cycle, args.intwidth_max, args.fracwidth_max,
              args.bitwidth_reduce, args.rounding))

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # Reproducible runs: seed CPU and, when enabled, GPU RNGs.
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    device = torch.device("cuda" if args.cuda else "cpu")

    # datasets
    num_class, train_loader, test_loader = get_setting(args)

    # model capsule-layer channel sizes
    A, B, C, D = 64, 8, 16, 16
    # A, B, C, D = 32, 32, 32, 32

    criterion = SpreadLoss(num_class=num_class, m_min=0.2, m_max=0.9)

    # Both models restore the same trained weights.
    snapshot_path = args.snapshot_folder + "/model_10.pth"

    print("Floating-point model:")
    model = capsules(A=A, B=B, C=C, D=D, E=num_class,
                     iters=args.em_iters).to(device)
    model.load_state_dict(torch.load(snapshot_path))
    model_test_acc = test(test_loader, model, criterion, device)

    print("RAVEN model:")
    modelRAVEN = CapsNetRAVEN(A=A, B=B, C=C, D=D, E=num_class,
                              iters=args.em_iters,
                              cycle=args.cycle,
                              intwidth=args.intwidth_max,
                              fracwidth=args.fracwidth_max,
                              bitwidth_reduce=args.bitwidth_reduce,
                              rounding="round").to(device)
    modelRAVEN.eval()
    # The RAVEN model may carry extra entries in its state dict, so merge the
    # trained weights into its own state before loading.
    modelRAVEN_state_dict = modelRAVEN.state_dict()
    modelRAVEN_state_dict.update(torch.load(snapshot_path))
    modelRAVEN.load_state_dict(modelRAVEN_state_dict)
    modelRAVEN_test_acc = test(test_loader, modelRAVEN, criterion, device)
示例#5
0
    full_dataset, [train_size, test_size])

# Wrap the split datasets in single-sample, shuffled loaders
# (batch_size=1: each image is handled individually below).
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=1,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=1,
                                          shuffle=True)

# Use the GPU only when requested *and* actually available.
device = torch.device("cuda" if (
    use_cuda and torch.cuda.is_available()) else "cpu")

# Capsule model
num_class = 43  # NOTE(review): 43 looks like GTSRB traffic signs -- confirm
A, B, C, D = 64, 8, 16, 16  # capsule-layer channel sizes
model = capsules(A=A, B=B, C=C, D=D, E=num_class, iters=2).to(device)
model.load_state_dict(torch.load(pretrained_model))
model.eval()  # inference mode; weights are frozen for the attack below


def fgsm_attack(image, epsilon, data_grad):
    """Generate an FGSM adversarial example.

    Perturbs each element of ``image`` by ``epsilon`` in the direction of
    the loss gradient's sign, then clips the result back into the valid
    [0, 1] image range.

    Args:
        image: input tensor in [0, 1].
        epsilon: perturbation magnitude.
        data_grad: gradient of the loss w.r.t. ``image``.

    Returns:
        The perturbed, clamped image tensor.
    """
    step = epsilon * data_grad.sign()
    return torch.clamp(image + step, 0, 1)

示例#6
0
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

from PIL import Image

from model import capsules

# Normalization constants commonly used for MNIST -- TODO confirm dataset.
transform=transforms.Compose([ transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))])

# Capsule-layer channel sizes; E is the number of output classes.
A=64
B=8
C=16
D=16
E=10
model = capsules(A=A, B=B, C=C, D=D, E=E, 
                     iters=2).cuda()
model.load_state_dict(torch.load('snapshots/model_10.pth'))
model.eval()  # inference only

# Load one grayscale bitmap and shape it as a (N=1, C=1, 28, 28) batch.
img = transform(Image.open('bmp/test_1.bmp')).reshape([1,1,28,28]).cuda()
pos, result = model(img)  # result: per-class outputs (argmax below picks the prediction)


# Print the index of the highest-scoring class.
print(np.argmax(result.cpu().detach().numpy()))
示例#7
0
def main():
    """Train a capsule network with checkpoint/resume support.

    Restores any previously saved training state via load(), trains up to
    args.epochs (or until Ctrl-C), saving state and test metrics after every
    epoch, and finally plots the accuracy and loss curves.
    """

    # Capsule-layer channel sizes.
    A, B, C, D = 256, 22, 22, 22


    global args
    args = SETTINGS()
    args.cuda = args.cuda and torch.cuda.is_available()

    # torch.manual_seed(args.seed)
    # if args.cuda:
    #     torch.cuda.manual_seed(args.seed)

    device = torch.device("cuda" if args.cuda else "cpu")

    num_class, train_loader, test_loader = load_database(args)

    model = capsules(A=A, B=B, C=C, D=D, E=num_class, iters=args.em_iters).to(device)

    criterion = SpreadLoss(num_class=num_class, m_min=0.2, m_max=0.9)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    #optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    #optimizer = optim.Adadelta(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    #optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=args.weight_decay, nesterov=True)
    #scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=1, verbose=True)
    # Step decay: multiply the learning rate by 0.81 after every epoch.
    scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma= 0.81, last_epoch=-1)

    # Metric histories and start epoch; load() replaces all of these with any
    # previously checkpointed state so training resumes where it stopped.
    epoch = 1
    _accs = []
    _accs_test = []
    _loss = []
    _loss_test = []
    model, epoch, optimizer, criterion, scheduler, _accs, _accs_test, _loss, _loss_test = load(model, epoch, optimizer, criterion, scheduler, _accs, _accs_test, _loss, _loss_test, args.save_folder, A, B, C, D)

    #optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=args.weight_decay)

    #plot_graph(_accs, _accs_test, "Accuracy")
    #plot_graph(_loss, _loss_test, "Loss")
    #quit()

    #best_acc = test(test_loader, model, criterion, device)
    #quit()
    print(str(model))
    print(str(optimizer))

    try:
        for epoch in range(epoch, args.epochs + 1):
            _s = time.time()  # epoch wall-clock timer
            acc, loss = train(train_loader, model, criterion, optimizer, epoch, device)
            # train() appears to accumulate per-batch metrics; average over
            # the number of batches -- confirm against train()'s contract.
            acc /= len(train_loader)
            _accs.append(acc)
            loss /= len(train_loader)
            _loss.append(loss)
            print("Average acc: " + str(acc))
            print("Average loss: " + str(loss))
            print("Epoch time: " + str(round(time.time() - _s, 3)))
            test_acc, test_loss = test(test_loader, model, criterion, device)
            _accs_test.append(test_acc)
            _loss_test.append(test_loss)
            # Checkpoint with epoch+1 so a resume starts at the *next* epoch.
            save(model, epoch+1, optimizer, criterion, scheduler, _accs, _accs_test, _loss, _loss_test, args.save_folder, A, B, C, D)
            #scheduler.step(acc)
            #print(optimizer)
        plot_graph(_accs, _accs_test, "Accuracy")
        plot_graph(_loss, _loss_test, "Loss")
    except KeyboardInterrupt:
        # On Ctrl-C, even out the metric histories so the plots line up,
        # then show the curves and exit without saving.
        if len(_accs_test) < len(_accs):
            test_acc, test_loss = test(test_loader, model, criterion, device)
            _accs_test.append(test_acc)
            _loss_test.append(test_loss)
        #    epoch += 1
        #save(model, epoch, optimizer, criterion, scheduler, _accs, _accs_test, _loss, _loss_test, args.save_folder, A, B, C, D)
        plot_graph(_accs, _accs_test, "Accuracy")
        plot_graph(_loss, _loss_test, "Loss")
        quit()