def run_alexnet_ann_recall_test_simulation_trial3():
    # instantiate AlexNet trained on MNIST
    alex_cnn = AlexNet()
    alex_cnn.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn.eval()
    alex_capture = Intermediate_Capture(alex_cnn.fc1) # capture fc1 activations for now
    return run_alexnet_ann_recall_simulation(alex_cnn=alex_cnn, alex_capture=alex_capture, output_name="alexnet_recall_task_trial3.txt", num_nodes=1024)
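Intermediate_Capture is used throughout these examples but never defined in them; a minimal sketch of a forward-hook capture helper with the same interface (an assumption, not the original class) might look like:

class Intermediate_Capture:
    """Keep the most recent output of `layer` via a PyTorch forward hook."""
    def __init__(self, layer):
        self.output = None
        self._hook = layer.register_forward_hook(self._capture)

    def _capture(self, module, inputs, output):
        # detach so the captured activations don't hold onto the autograd graph
        self.output = output.detach()

    def remove(self):
        self._hook.remove()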
Example #2
def bpda_all():
    model_list = {
        'naive-alexnet': "/media/dsg3/dsgprivate/yuhang/model/alexnet/naive/naive_param.pkl",
        'new-AT-alexnet': "/media/dsg3/dsgprivate/yuhang/model/alexnet/nat/naive_param.pkl",
        'ensemble-AT-alexnet': "/media/dsg3/dsgprivate/yuhang/model/alexnet/eat/naive_param.pkl",
        # 'DPLAT-alexnet': "/media/dsg3/dsgprivate/yuhang/model/alexnet/dplat/lat_param.pkl",
    }
    if args.model == 'alexnet':
        model = AlexNet(enable_lat=args.enable_lat,
                      epsilon=args.lat_epsilon,
                      pro_num=args.lat_pronum,
                      batch_size=args.model_batchsize,
                      num_classes=200,
                      if_dropout=args.dropout
                      ).cuda()
    elif args.model == 'alexnetBN':
        model = AlexNetBN(enable_lat=args.enable_lat,
                      epsilon=args.lat_epsilon,
                      pro_num=args.lat_pronum,
                      batch_size=args.model_batchsize,
                      num_classes=200,
                      if_dropout=args.dropout
                      ).cuda()    
    # if cifar then normalize epsilon from [0,255] to [0,1]
    
    '''
    if args.dataset == 'cifar10':
        eps = args.attack_epsilon / 255.0
    else:
        eps = args.attack_epsilon
    '''
    eps = args.attack_epsilon
    # the last layer of densenet is F.log_softmax, while CrossEntropyLoss already includes a softmax
    attack = Attack(dataroot = "/media/dsg3/dsgprivate/lat/data/sampled_imagenet/",
                    dataset  = args.dataset,
                    batch_size = args.attack_batchsize,
                    target_model = model,
                    criterion = nn.CrossEntropyLoss(),
                    epsilon = eps,
                    alpha =  args.attack_alpha,
                    iteration = args.attack_iter)
    
    
    for target in model_list:
        print('------Now target model is {} -------'.format(target))
        model_path = model_list[target]
        model.load_state_dict(torch.load(model_path))
        print('model successfully loaded')
        attack.i_fgsm()
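The comment about CrossEntropyLoss containing a softmax can be checked directly; this illustrative snippet (not part of the original script) shows the equivalence:

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)
assert torch.allclose(ce, nll)  # CrossEntropyLoss == log_softmax + NLLLoss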
Example #3
class TestNetwork():
    def __init__(self, dataset, batch_size, epochs):
        self.dataset = dataset
        self.batch_size = batch_size
        self.epochs = epochs

        # letters contains 27 classes, digits contains 10 classes
        num_classes = 27 if dataset == 'letters' else 10

        # Load model and use CUDA if available
        self.model = AlexNet(num_classes)
        if torch.cuda.is_available():
            self.model.cuda()

        # Load testing dataset
        kwargs = {
            'num_workers': 1,
            'pin_memory': True
        } if torch.cuda.is_available() else {}
        self.test_loader = torch.utils.data.DataLoader(EMNIST(
            './data',
            dataset,
            download=True,
            transform=transforms.Compose([
                transforms.Lambda(correct_rotation),
                transforms.Resize((224, 224)),
                transforms.Grayscale(3),
                transforms.ToTensor(),
            ]),
            train=False),
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       **kwargs)

        # Optimizer and loss function
        self.loss_fn = nn.CrossEntropyLoss()

    def test(self, epoch):
        """
        Test the model for one epoch with a pre-trained network
        :param epoch: Current epoch
        :return: None
        """
        # Load weights from trained model
        state_dict = torch.load(
            './trained_models/{}_{}.pth'.format(self.dataset, epoch),
            map_location=lambda storage, loc: storage)['model']
        self.model.load_state_dict(state_dict)
        self.model.eval()

        test_loss = 0
        test_correct = 0
        progress = None
        for batch_idx, (data, target) in enumerate(self.test_loader):
            # Get data and label
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)

            # forward pass: accumulate loss and count correct predictions
            output = self.model(data)
            loss = self.loss_fn(output, target)
            test_loss += loss.item()
            pred = output.data.max(1, keepdim=True)[1]
            test_correct += pred.eq(target.data.view_as(pred)).sum().item()

            # Print information about current step
            current_progress = int(100 * (batch_idx + 1) * self.batch_size /
                                   len(self.test_loader.dataset))
            if current_progress != progress and current_progress % 5 == 0:
                progress = current_progress
                print('Test Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx + 1) * len(data),
                    len(self.test_loader.dataset), current_progress,
                    loss.item()))

        test_loss /= (len(self.test_loader.dataset) / self.batch_size)
        test_correct /= len(self.test_loader.dataset)
        test_correct *= 100

        # Print information about current epoch
        print(
            'Test Epoch: {} \tCorrect: {:3.2f}%\tAverage loss: {:.6f}'.format(
                epoch, test_correct, test_loss))

    def start(self):
        """
        Start testing the network
        :return: None
        """
        for epoch in range(1, self.epochs + 1):
            self.test(epoch)
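A usage sketch (argument values are assumptions; the checkpoints ./trained_models/letters_<epoch>.pth are produced by the TrainNetwork class shown later):

tester = TestNetwork(dataset='letters', batch_size=64, epochs=10)
tester.start()  # evaluates trained_models/letters_1.pth ... letters_10.pth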
Example #4
    elif args.model == 'resnet':
        cnn = ResNet18(enable_lat=args.enable_lat,
                       epsilon=args.epsilon,
                       pro_num=args.pro_num,
                       batch_size=args.batchsize,
                       num_classes=200,
                       if_dropout=args.dropout)
        #cnn.apply(conv_init)
    elif args.model == 'alexnetBN':
        cnn = AlexNetBN(enable_lat=args.enable_lat,
                        epsilon=args.epsilon,
                        pro_num=args.pro_num,
                        batch_size=args.batchsize,
                        num_classes=200,
                        if_dropout=args.dropout)
    cnn.cuda()

    if os.path.exists(real_model_path):
        cnn.load_state_dict(torch.load(real_model_path))
        print('model successfully loaded.')
    else:
        print("load model failed.")

    if args.test_flag:
        if args.adv_flag:
            test_all(cnn)
        else:
            test_op(cnn)
    else:
        train_op(cnn)
Example #5
import torch
from torch.nn.functional import softmax

from alexnet import AlexNet
from utils import cifar10_loader, device, cifar10_classes

torch.random.manual_seed(128)
batch_size = 1
testloader = cifar10_loader(train=False, batch_size=batch_size)

net = AlexNet()
net.load_state_dict(torch.load("model/model.h5"))
net.eval()

correct = 0
total = 0


def run():
    global correct, total
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            inputs, labels = images.to(device), labels.to(device)
            outputs = net(inputs)
            _, predicted = torch.topk(outputs.data, 5)
            indexes = predicted.cpu().numpy()[0].tolist()
            #print(softmax(outputs, dim=1).numpy()[0][indexes])
            #print([cifar10_classes[i] for i in indexes])
            # tally top-5 accuracy (batch_size is 1)
            total += labels.size(0)
            correct += int(labels.item() in indexes)
Example #6
def test_task2(root_path):
    '''
    :param root_path: root path of test data, e.g. ./dataset/task2/test/0/
    :return results: a dict of classification results
    results = {'audio_0000.pkl': 23, 'audio_0001': 11, ...}
    This means audio 'audio_0000.pkl' is matched to video 'video_0023' and 'audio_0001' is matched to 'video_0011'.
    '''
    results = dict()

    os.chdir(os.path.split(os.path.realpath(__file__))[0])
    audio_transforms = transforms.Compose([transforms.ToTensor()])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    label_list = [9, 2, 3, 8, 7, 6, 5, 0, 4, 1]
    audio_model = ResNet.resnet18(num_classes=10)
    audio_model = audio_model.to(device)
    audio_model = torch.nn.DataParallel(audio_model)
    audio_model.load_state_dict(torch.load('./model/resnet18.pth'))
    audio_model.eval()

    audio_class_features = []
    video_class_features = []
    audio_motion_features = []
    video_motion_features = []
    Files = list(filter(lambda x: x.endswith(".pkl"), os.listdir(root_path)))
    Files.sort()
    for sample in Files:
        audio_path = root_path + '/' + sample
        data = np.load(audio_path, allow_pickle=True)['audio']
        for i in range(4):
            S = librosa.resample(data[:, i], orig_sr=44100, target_sr=11000)
            S = np.abs(librosa.stft(S[5650:-5650], n_fft=510, hop_length=128))
            S = np.log10(S + 0.0000001)
            S = np.clip(S, -5, 5)
            S -= np.min(S)
            S = 255 * (S / np.max(S))
            if S.shape[-1] < 256:
                S = np.pad(S, ((0, 0), (int(np.ceil((256 - S.shape[-1]) / 2)),
                                        int(np.floor(
                                            (256 - S.shape[-1]) / 2)))))
            if S.shape[-1] > 256:
                # center-crop to 256 frames; computing the start index once
                # avoids an empty slice when the right-hand overhang floors to 0
                start = int(np.ceil((S.shape[-1] - 256) / 2))
                S = S[:, start:start + 256]
            if i == 0:
                feature = np.uint8(S)[:, :, np.newaxis]
            else:
                feature = np.concatenate(
                    (np.uint8(S)[:, :, np.newaxis], feature), axis=-1)
        X = audio_transforms(feature)
        X = X.to(device)
        class_feature_t = torch.softmax(audio_model(X.unsqueeze(0)),
                                        dim=-1).squeeze(0)
        class_feature = np.zeros(10)
        for i in range(10):
            class_feature[label_list[i]] = class_feature_t[i]
        threshold = 0.35
        label_list2 = [1, 2, 0, 3]
        label = [0] * 4
        for i in range(4):
            if np.max(data[:, i]) > threshold:
                label[label_list2[i]] = 1
        audio_class_features.append(class_feature)
        audio_motion_features.append(label)

    net = AlexNet()
    net.load_state_dict(torch.load('./model/alexnet.pt'))

    k = 0
    while os.path.exists(root_path + '/video_' + str('%04d' % k)):
        video_class_feature, video_move_feature = get_video_feature(
            net, root_path + '/video_' + str('%04d' % k))
        video_class_features.append(video_class_feature)
        video_motion_features.append(video_move_feature)
        k = k + 1

    indices = pairing(audio_class_features, audio_motion_features,
                      video_class_features, video_motion_features, -100)
    j = 0
    for sample in Files:
        results[sample] = indices[j][1]
        j = j + 1
    return results
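A hypothetical invocation, following the docstring's example path:

results = test_task2('./dataset/task2/test/0/')
print(len(results), 'audio files matched to videos')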
Example #7
targets = torch.tensor(testset.targets)
targets[targets >= 5] = 5
testset.targets = targets.tolist()
testloader = torch.utils.data.DataLoader(testset, batch_size=args.bs, shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')

net = AlexNet(5)

checkpoint_path = 'cifar_10_alexnet.t7'
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
try:
    net.load_state_dict(checkpoint['net'])
except RuntimeError:
    # state dict was saved from a DataParallel model; strip the key prefix
    new_check_point = OrderedDict()
    for k, v in checkpoint['net'].items():
        name = k[7:]  # remove `module.`
        # name = k[9:]  # remove `module.1.`
        new_check_point[name] = v
    net.load_state_dict(new_check_point)


net = net.to(device)

if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
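The try/except in the snippet above works around checkpoints saved from a DataParallel-wrapped model; the same cleanup as a small reusable helper (a sketch, not part of the original script):

from collections import OrderedDict

def strip_module_prefix(state_dict, prefix='module.'):
    """Strip a DataParallel key prefix from every entry of a state dict."""
    return OrderedDict((k[len(prefix):] if k.startswith(prefix) else k, v)
                       for k, v in state_dict.items())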
Example #8
    if args.model == 'vgg':
        model = VGG16(enable_lat=False,
                      epsilon=0.6,
                      pro_num=5,
                      batch_size=args.batch_size,
                      if_dropout=True)
    elif args.model == 'resnet':
        model = ResNet50(enable_lat=False,
                         epsilon=0.6,
                         pro_num=5,
                         batch_size=args.batch_size,
                         if_dropout=True)
    elif args.model == 'alexnet':
        model = AlexNet(enable_lat=False,
                        epsilon=0.6,
                        pro_num=5,
                        batch_size=args.batch_size,
                        num_classes=200,
                        if_dropout=True)
    model.cuda()

    if os.path.exists(args.model_path):
        model.load_state_dict(torch.load(args.model_path))
        print('model loaded successfully.')
    else:
        print("load failed.")
    if args.test_flag:
        test_all(model)
    else:
        test_op(model)
Example #9
class TrainNetwork():
    def __init__(self, dataset, batch_size, epochs, lr, lr_decay_epoch,
                 momentum):
        assert (dataset == 'letters' or dataset == 'mnist')

        self.dataset = dataset
        self.batch_size = batch_size
        self.epochs = epochs
        self.lr = lr
        self.lr_decay_epoch = lr_decay_epoch
        self.momentum = momentum

        # letters contains 27 classes, digits contains 10 classes
        num_classes = 27 if dataset == 'letters' else 10

        # Load pre-trained AlexNet with the final layer resized to num_classes
        state_dict = torch.load('./trained_models/alexnet.pth')
        state_dict['classifier.6.weight'] = torch.zeros(num_classes, 4096)
        state_dict['classifier.6.bias'] = torch.zeros(num_classes)
        self.model = AlexNet(num_classes)
        self.model.load_state_dict(state_dict)

        # Use cuda if available
        if torch.cuda.is_available():
            self.model.cuda()

        # Load training dataset
        kwargs = {
            'num_workers': 1,
            'pin_memory': True
        } if torch.cuda.is_available() else {}
        self.train_loader = torch.utils.data.DataLoader(
            EMNIST('./data',
                   dataset,
                   download=True,
                   transform=transforms.Compose([
                       transforms.Lambda(correct_rotation),
                       transforms.Lambda(random_transform),
                       transforms.Resize((224, 224)),
                       transforms.RandomResizedCrop(224, (0.9, 1.1),
                                                    ratio=(0.9, 1.1)),
                       transforms.Grayscale(3),
                       transforms.ToTensor(),
                   ])),
            batch_size=batch_size,
            shuffle=True,
            **kwargs)

        # Optimizer and loss function
        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=self.lr,
                                   momentum=self.momentum)
        self.loss_fn = nn.CrossEntropyLoss()

    def reduce_learning_rate(self, epoch):
        """
        Reduce the learning rate by a factor of 0.1 every lr_decay_epoch epochs
        :param epoch: Current epoch
        :return: None
        """
        lr = self.lr * (0.1**(epoch // self.lr_decay_epoch))

        if epoch % self.lr_decay_epoch == 0:
            print('LR is set to {}'.format(lr))

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def train(self, epoch):
        """
        Train the model for one epoch and save the result as a .pth file
        :param epoch: Current epoch
        :return: None
        """
        self.model.train()

        train_loss = 0
        train_correct = 0
        progress = None
        for batch_idx, (data, target) in enumerate(self.train_loader):
            # Get data and label
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)

            # Optimize using backpropagation
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss_fn(output, target)
            train_loss += loss.item()
            pred = output.data.max(1, keepdim=True)[1]
            train_correct += pred.eq(target.data.view_as(pred)).sum().item()
            loss.backward()
            self.optimizer.step()

            # Print information about current step
            current_progress = int(100 * (batch_idx + 1) * self.batch_size /
                                   len(self.train_loader.dataset))
            if current_progress != progress and current_progress % 5 == 0:
                progress = current_progress
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx + 1) * len(data),
                    len(self.train_loader.dataset), current_progress,
                    loss.item()))

        train_loss /= (len(self.train_loader.dataset) / self.batch_size)
        train_correct /= len(self.train_loader.dataset)
        train_correct *= 100

        # Print information about current epoch
        print(
            'Train Epoch: {} \tCorrect: {:3.2f}%\tAverage loss: {:.6f}'.format(
                epoch, train_correct, train_loss))

        # Save snapshot
        torch.save(
            {
                'model': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict()
            }, './trained_models/{}_{}.pth'.format(self.dataset, epoch))

    def start(self):
        """
        Start training the network
        :return: None
        """
        for epoch in range(1, self.epochs + 1):
            self.reduce_learning_rate(epoch)
            self.train(epoch)
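To make reduce_learning_rate concrete: with lr=0.01 and lr_decay_epoch=10, the factor 0.1 ** (epoch // 10) keeps 0.01 for epochs 1-9, drops to 0.001 for epochs 10-19, and so on:

for epoch in (1, 9, 10, 19, 20):
    print(epoch, 0.01 * (0.1 ** (epoch // 10)))
# 1 -> 0.01, 9 -> 0.01, 10 -> 0.001, 19 -> 0.001, 20 -> 0.0001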
Example #10
# Model
print('==> Building model..')

net = AlexNet()
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.pth')
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),
                      lr=args.lr,
                      momentum=0.9,
                      weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
Example #11
class LiveShowcase:
    def __init__(self, path_to_model):
        num_classes = 27

        # Member variables
        self.status = 'Ready'
        self.last_words = None
        self.dictionary_set = set(nltk.corpus.words.words())

        # Load pre-trained AlexNet
        state_dict = torch.load(path_to_model, map_location=lambda storage, loc: storage)['model']
        self.model = AlexNet(num_classes)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def process_image(self, image, bounding_boxes):
        """
        Process image to find and classify characters and build 5 most probable words
        :param image: rgb image
        :param bounding_boxes: list of bounding boxes containing characters (min_x, min_y, width, height)
        :return: None
        """
        self.status = 'Processing'

        # Find 5 most probable words
        subimages = extract_characters(image, bounding_boxes)
        words = classify_characters(self.model, subimages)
        self.last_words = words[:5]

        self.status = 'Ready'

    def start(self, max_bounding_boxes=10):
        """
        Start the live showcase using a camera
        :return: None
        """
        # Try to open a connection to the camera
        cap = cv2.VideoCapture(0)
        if not cap.isOpened():
            print('Error: No camera found')
            return
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 960)

        print('Press q to stop the live showcase')
        while True:
            # Capture frame-by-frame
            ret, image = cap.read()
            if not ret:
                break
            output = image

            # Find bounding boxes for each character
            image = preprocess_image(image)
            bounding_boxes = find_bounding_boxes(image)
            bounding_boxes = filter_bounding_boxes(image, bounding_boxes)
            for box in bounding_boxes:
                cv2.rectangle(output, (box[0], box[1]), (box[0] + box[2], box[1] + box[3]), (0, 0, 255), 2)

            # Process image if no other image is processed
            if 'Ready' in self.status:
                if len(bounding_boxes) > max_bounding_boxes:
                    self.status = 'Ready [Warning: too many bounding boxes]'
                    self.last_words = None
                else:
                    thread = threading.Thread(target=self.process_image, args=(image, bounding_boxes), daemon=True)
                    thread.start()

            # Draw status bar with last recognized words
            cv2.putText(output, 'Status: {}'.format(self.status), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                        (0, 0, 0), 1, cv2.LINE_AA)
            if self.last_words:
                for offset, word in zip(range(len(self.last_words)), self.last_words):
                    color = (0, 0, 0)
                    # Use green color if word is in dictionary
                    if word[0].lower() in self.dictionary_set:
                        color = (0, 255, 0)
                    cv2.putText(output, '{} ({:5.2f}%)'.format(word[0], 100 * word[1]),
                                (10, 20 + (offset + 1) * 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                color, 1, cv2.LINE_AA)

            # Draw bounding box around detected word
            if self.last_words and len(bounding_boxes) > 1:
                word = self.last_words[0]
                color = (0, 0, 255)
                # Use green color if word is in dictionary
                if word[0].lower() in self.dictionary_set:
                    color = (0, 255, 0)
                text = '{} ({:5.2f}%)'.format(word[0], 100 * word[1])
                padding = 10
                top_left = (np.min([b[0] for b in bounding_boxes]) - padding,
                            np.min([b[1] for b in bounding_boxes]) - padding)
                bottom_right = (np.max([b[0]+b[2] for b in bounding_boxes]) + padding,
                                np.max([b[1]+b[3] for b in bounding_boxes]) + padding)
                text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
                cv2.rectangle(output, (top_left[0] - 1, top_left[1] - text_size[1] - 2 * padding),
                                          (top_left[0] + text_size[0] + 2 * padding, top_left[1]),
                              color, thickness=cv2.FILLED)
                cv2.putText(output, text, (top_left[0] + padding, top_left[1] - padding),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
                cv2.rectangle(output, top_left, bottom_right, color, 2)

            # Display the resulting frame
            cv2.imshow('Image internal', image)
            cv2.imshow('Showcase', output)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        # When everything done, release the capture
        cap.release()
        cv2.destroyAllWindows()
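A hypothetical way to launch the showcase (the checkpoint path is an assumption, following the {dataset}_{epoch}.pth naming used by the training code above):

showcase = LiveShowcase('./trained_models/letters_10.pth')
showcase.start(max_bounding_boxes=10)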
Example #12
def main():
    progress = default_progress()
    experiment_dir = 'experiment/miniplaces'
    # Here's our data
    train_loader = torch.utils.data.DataLoader(CachedImageFolder(
        'dataset/miniplaces/simple/train',
        transform=transforms.Compose([
            transforms.Resize(128),
            transforms.RandomCrop(119),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)
        ])),
                                               batch_size=64,
                                               shuffle=True,
                                               num_workers=6,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CachedImageFolder(
        'dataset/miniplaces/simple/val',
        transform=transforms.Compose([
            transforms.Resize(128),
            transforms.CenterCrop(119),
            transforms.ToTensor(),
            transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)
        ])),
                                             batch_size=512,
                                             shuffle=False,
                                             num_workers=6,
                                             pin_memory=True)
    # Create a simplified AlexNet with half resolution.
    model = AlexNet(first_layer='conv1',
                    last_layer='fc8',
                    layer_sizes=dict(fc6=2048, fc7=2048),
                    output_channels=100,
                    half_resolution=True,
                    include_lrn=False,
                    split_groups=False).cuda()
    # Use Kaiming initialization for the weights
    for name, val in model.named_parameters():
        if 'weight' in name:
            init.kaiming_uniform_(val)
        else:
            # Init positive bias in many layers to avoid dead neurons.
            assert 'bias' in name
            init.constant_(
                val, 0 if any(
                    name.startswith(layer)
                    for layer in ['conv1', 'conv3', 'fc8']) else 1)
    # An abbreviated training schedule; see the max_iter notes below.
    # TODO: tune these hyperparameters.
    init_lr = 0.002
    # max_iter = 40000 - 34.5% @1
    # max_iter = 50000 - 37% @1
    # max_iter = 80000 - 39.7% @1
    # max_iter = 100000 - 40.1% @1
    max_iter = 100000
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=init_lr,
        momentum=0.9,
        weight_decay=0.001)
    iter_num = 0
    best = dict(val_accuracy=0.0)
    model.train()
    # Oh, hold on.  Let's actually resume training if we already have a model.
    checkpoint_filename = 'miniplaces.pth.tar'
    best_filename = 'best_%s' % checkpoint_filename
    best_checkpoint = os.path.join(experiment_dir, best_filename)
    try_to_resume_training = False
    if try_to_resume_training and os.path.exists(best_checkpoint):
        checkpoint = torch.load(os.path.join(experiment_dir, best_filename))
        iter_num = checkpoint['iter']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best['val_accuracy'] = checkpoint['accuracy']

    def save_checkpoint(state, is_best):
        filename = os.path.join(experiment_dir, checkpoint_filename)
        ensure_dir_for(filename)
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename,
                            os.path.join(experiment_dir, best_filename))

    def validate_and_checkpoint():
        model.eval()
        val_loss, val_acc = AverageMeter(), AverageMeter()
        for input, target in progress(val_loader):
            # Load data
            input_var, target_var = [
                Variable(d.cuda(non_blocking=True)) for d in [input, target]
            ]
            # Evaluate model
            with torch.no_grad():
                output = model(input_var)
                loss = criterion(output, target_var)
                _, pred = output.max(1)
                accuracy = (target_var.eq(pred)
                            ).data.float().sum().item() / input.size(0)
            val_loss.update(loss.data.item(), input.size(0))
            val_acc.update(accuracy, input.size(0))
            # Check accuracy
            post_progress(l=val_loss.avg, a=val_acc.avg)
        # Save checkpoint
        save_checkpoint(
            {
                'iter': iter_num,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'accuracy': val_acc.avg,
                'loss': val_loss.avg,
            }, val_acc.avg > best['val_accuracy'])
        best['val_accuracy'] = max(val_acc.avg, best['val_accuracy'])
        post_progress(v=val_acc.avg)

    # Here is our training loop.
    while iter_num < max_iter:
        # Track the average training loss/accuracy for each epoch.
        train_loss, train_acc = AverageMeter(), AverageMeter()
        for input, target in progress(train_loader):
            # Load data
            input_var, target_var = [
                Variable(d.cuda(non_blocking=True)) for d in [input, target]
            ]
            # Evaluate model
            output = model(input_var)
            loss = criterion(output, target_var)
            train_loss.update(loss.data.item(), input.size(0))
            # Perform one step of SGD
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Also check training set accuracy
            _, pred = output.max(1)
            accuracy = (target_var.eq(pred)).data.float().sum().item() / (
                input.size(0))
            train_acc.update(accuracy)
            remaining = 1 - iter_num / float(max_iter)
            post_progress(l=train_loss.avg,
                          a=train_acc.avg,
                          v=best['val_accuracy'])
            # Advance
            iter_num += 1
            if iter_num >= max_iter:
                break
            # Linear learning rate decay
            lr = init_lr * remaining
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            # Occasionally check validation set accuracy and checkpoint
            if iter_num % 1000 == 0:
                validate_and_checkpoint()
                model.train()
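In isolation, the linear decay used in the loop above scales the learning rate from init_lr toward zero as iter_num approaches max_iter (illustrative values only):

init_lr, max_iter = 0.002, 100000
for it in (0, 50000, 99999):
    print(it, init_lr * (1 - it / float(max_iter)))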
Example #13
        magnitude_loss = 0.0001 * F.mse_loss(target=real_eig_vals,
                                             input=fake_eig_vals)
        structure_loss = -torch.sum(torch.mul(fake_eig_vecs, real_eig_vecs), 0)
        normalized_real_eig_vals = normalize_min_max(real_eig_vals)
        weighted_structure_loss = torch.sum(
            torch.mul(normalized_real_eig_vals, structure_loss))
        return magnitude_loss + weighted_structure_loss

    netG = Generator(ngpu).to(device)
    netG.apply(weights_init)
    if opt.netG != '':
        netG.load_state_dict(torch.load(opt.netG))
    print(netG)

    netC = AlexNet(ngpu).to(device)
    netC.load_state_dict(torch.load('./best_model.pth'))
    print(netC)
    netC.eval()

    netD = Discriminator(ngpu).to(device)
    netD.apply(weights_init)
    if opt.netD != '':
        netD.load_state_dict(torch.load(opt.netD))
    print(netD)

    criterion = nn.BCELoss()
    criterion_sum = nn.BCELoss(reduction='sum')

    fixed_noise = torch.randn(opt.batchSize, 100, 1, 1, device=device)

    real_label = 1
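normalize_min_max is not shown in this excerpt; a common min-max normalization (an assumption about the helper's behavior, not the authors' definition) would be:

def normalize_min_max(x, eps=1e-8):
    # scale the eigenvalue tensor into [0, 1]
    return (x - x.min()) / (x.max() - x.min() + eps)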
Example #14
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


test_dataset = datasets.ImageFolder(root=image_path + "/test",
                                        transform=data_transform)
test_num = len(test_dataset)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16, shuffle=True,
                                              num_workers=0)


model = AlexNet(num_classes=2)
model_weight_path = "./models/AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()


correct = 0
total = 0
acc = 0.0

confusion_matrix = meter.ConfusionMeter(2)
for i, sample in enumerate(test_loader):
    inputs, labels = sample[0], sample[1]
    outputs = model(inputs)
    # update the confusion meter, which was otherwise never filled
    confusion_matrix.add(outputs.detach(), labels)

    _, prediction = torch.max(outputs, 1)
    correct += (labels == prediction).sum().item()
    total += labels.size(0)
Example #15
def run_alexnet_ann_recall_test_simulation_trial4():
    num_nodes = 10
    alex_cnn1 = AlexNet()
    alex_cnn1.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn1.eval()
    alex_capture1 = Intermediate_Capture(alex_cnn1.layer3) # capture layer3 activations
    output_name = "alexnet_recall_task_trial4.txt"

    alex_cnn2 = AlexNet()
    alex_cnn2.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn2.eval()
    alex_capture2 = Intermediate_Capture(alex_cnn2.layer4) # capture layer4 activations

    alex_cnn3 = AlexNet()
    alex_cnn3.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn3.eval()
    alex_capture3 = Intermediate_Capture(alex_cnn3.layer5) # capture layer5 activations

    transform = transforms.ToTensor()
    data_raw = MNIST(
        root='./data/mnist',
        train=True,
        download=True,
        transform=transform)

    # creating a toy dataset for simple probing
    mnist_subset = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}
    per_class_sizes = {0:10, 1:10, 2:10, 3:10, 4:10, 5:10, 6:10, 7:10, 8:10, 9:10}
    for i in range(len(data_raw)):
        image, label = data_raw[i]
        if len(mnist_subset[label]) < per_class_sizes[label]:
            mnist_subset[label].append(torch.reshape(image, (1,1, 28,28)))
        done = True
        for k in mnist_subset:
            if len(mnist_subset[k]) < per_class_sizes[k]:
                done=False
        if done:
            break


    # converts mnist_subset into table that is usable for model input
    full_pattern_set = []
    full_label_set = []
    for k in mnist_subset:
        for v in mnist_subset[k]:
            full_pattern_set.append(v)
            full_label_set.append(k)

    # given a list of desired labels, randomly choose an example of each label from the mnist dataset to store
    stored_size_vs_performance = [] # list will store tuples of (hopfield perf, popularity perf, ortho perf)
    for desired_label_size in range(10):
        desired_labels = list(range(desired_label_size+1))
        full_stored_set, full_stored_labels = create_storage_set(desired_labels, mnist_subset, reshape=False, make_numpy=False)
        print("Num Stored: ", len(desired_labels))

        # evaluate hopnet performance
        ann_model = hopnet(6272) 
        model = CNN_ANN(alex_cnn1, ann_model, alex_capture1, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, full_stored_set, full_stored_labels, verbose=False)
        print("Alexnet Layer3:", num_succ, ":", num_fail)
        layer3_perf = int(num_succ)

        ann_model = hopnet(12544) 
        model = CNN_ANN(alex_cnn2, ann_model, alex_capture2, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, full_stored_set, full_stored_labels, verbose=False)
        print("Alexnet Layer3:", num_succ, ":", num_fail)
        layer4_perf = int(num_succ)

        ann_model = hopnet(2304) 
        model = CNN_ANN(alex_cnn3, ann_model, alex_capture3, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, full_stored_set, full_stored_labels, verbose=False)
        print("Alexnet Layer3:", num_succ, ":", num_fail)
        layer5_perf = int(num_succ)

        stored_size_vs_performance.append((layer3_perf, layer4_perf, layer5_perf))

    # write performance to file
    fh = open("data/graph_sources/" + output_name, "w")
    for perf in stored_size_vs_performance:
        fh.write(str(perf[0]) + "," + str(perf[1]) + "," + str(perf[2]) + "\n")
    fh.close()
    return stored_size_vs_performance
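The capture_process_fn lambda used throughout these trials binarizes activations to +/-1 around the mean of exp(x), the form a Hopfield-style network stores; a quick numeric check:

import numpy as np

x = np.array([0.0, 1.0, -1.0])
print(np.sign(np.exp(x) - np.exp(x).mean()))  # -> [-1.  1. -1.]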
Example #16
def test_sample_images(path_to_model, path_to_images, save_path):
    num_classes = 27

    # Load pre-trained AlexNet
    state_dict = torch.load(path_to_model,
                            map_location=lambda storage, loc: storage)['model']
    model = AlexNet(num_classes)
    model.load_state_dict(state_dict)
    model.eval()

    # Process every image
    dictionary = set(nltk.corpus.words.words())
    distances = defaultdict(lambda: defaultdict(lambda: 0))
    size_distances = defaultdict(lambda: defaultdict(lambda: 0))
    corrected_words = defaultdict(lambda: defaultdict(lambda: 0))
    with open('{}labels.txt'.format(path_to_images)) as f:
        for line in f:
            sections = line.split('; ')
            if len(sections) < 2:
                continue
            fname = sections[0]
            correct_word = sections[1]

            # Open image
            image = cv2.imread('{}{}'.format(path_to_images, fname))
            output = image

            # Find bounding boxes for each character
            image = preprocess_image(image)
            _, image = cv2.threshold(image, 90, 255, cv2.THRESH_BINARY_INV)
            bounding_boxes = find_bounding_boxes(image)
            bounding_boxes = filter_bounding_boxes(image, bounding_boxes)

            # Find 5 most probable results
            subimages = extract_characters(image, bounding_boxes)
            results = classify_characters(model, subimages)
            results = results[:5]

            # Check if word can be corrected
            corrected_word = ''
            for word in results:
                if word[0].lower() in dictionary and corrected_word == '':
                    corrected_word = word[0]

            # Append to evaluation dicts for evaluation
            most_probable_word = results[0][0]
            distance = Levenshtein.distance(most_probable_word, correct_word)
            distances[len(correct_word)][distance] += 1
            size_distances[len(correct_word)][len(most_probable_word)] += 1

            corrected_words[len(correct_word)][0] += 1
            if corrected_word == correct_word:
                corrected_words[len(correct_word)][1] += 1

            # Print information about current progress
            print(
                'Correct: {:12s}  Most probable: {:12s}  Corrected: {:12s}  Distance: {:1d}  Success: {}'
                .format(correct_word, most_probable_word, corrected_word,
                        distance, corrected_word == correct_word))

    #  Save results
    with open('{}/test_results_distance.txt'.format(save_path), 'w') as f:
        for size in sorted(distances):
            for distance in sorted(distances[size]):
                f.write('{};{};{}\n'.format(size, distance,
                                            distances[size][distance]))

    with open('{}/test_results_size.txt'.format(save_path), 'w') as f:
        for size in sorted(size_distances):
            for size_distance in sorted(size_distances[size]):
                f.write('{};{};{}\n'.format(
                    size, size_distance, size_distances[size][size_distance]))

    with open('{}/test_results_corrected.txt'.format(save_path), 'w') as f:
        for key in sorted(corrected_words):
            for count in sorted(corrected_words[key]):
                f.write('{};{};{}\n'.format(key, count,
                                            corrected_words[key][count]))
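The distance bookkeeping above relies on edit distance from the python-Levenshtein package; for instance:

import Levenshtein
# one deletion separates the words, so distances[5][1] gets incremented
assert Levenshtein.distance('HELLO', 'HELO') == 1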
Example #17
def run_alexnet_ann_recall_test_simulation_trial7():
    output_name="alexnet_recall_task_trial7.txt"
    num_nodes=10
    full_connection_mat = np.ones(shape=(num_nodes,num_nodes)) - np.eye(num_nodes)
    alex_cnn = AlexNet()
    alex_cnn.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn.eval()
    alex_capture = Intermediate_Capture(alex_cnn.fc3) # for now capture final output

    transform = transforms.ToTensor()
    data_raw = MNIST(
        root='./data/mnist',
        train=True,
        download=True,
        transform=transform)

    # creating a toy dataset for simple probing
    mnist_subset = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}
    per_class_sizes = {0:10, 1:10, 2:10, 3:10, 4:10, 5:10, 6:10, 7:10, 8:10, 9:10}
    for i in range(len(data_raw)):
        image, label = data_raw[i]
        if len(mnist_subset[label]) < per_class_sizes[label]:
            mnist_subset[label].append(torch.reshape(image, (1,1, 28,28)))
        done = True
        for k in mnist_subset:
            if len(mnist_subset[k]) < per_class_sizes[k]:
                done=False
        if done:
            break


    # converts mnist_subset into table that is usable for model input
    full_pattern_set = []
    full_label_set = []
    for k in mnist_subset:
        for v in mnist_subset[k]:
            full_pattern_set.append(v)
            full_label_set.append(k)

    # given a list of desired labels, randomly choose an example of each label from the mnist dataset to store
    stored_size_vs_performance = [] # list will store tuples of (hopfield perf, popularity perf, ortho perf)
    for desired_label_size in range(10):

        # need to generate probe set each time
        # when desired label size is k:
        # probe set is 10 instances each of labels 0 to k-1
        desired_labels = list(range(desired_label_size+1))
        sub_probe_set = []
        sub_probe_labels = []
        for des in desired_labels:
            # add 10 instances of des
            for inst in mnist_subset[des]:
                sub_probe_set.append(inst)
                sub_probe_labels.append(des)
        full_stored_set, full_stored_labels = create_storage_set(desired_labels, mnist_subset, reshape=False, make_numpy=False)
        print("Num Stored: ", len(desired_labels))

        # evaluate hopnet performance
        ann_model = hopnet(num_nodes) 
        model = CNN_ANN(alex_cnn, ann_model, alex_capture, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, sub_probe_set, sub_probe_labels, verbose=False)
        print("Hopfield:", num_succ, ":", num_fail)
        hopfield_perf = int(num_succ)

        # evaluate popularity ANN performance
        # hyperparams: set c = N-1, with randomly generated connectivity matrix
        ann_model = PopularityANN(N=num_nodes, c=num_nodes-1, connectivity_matrix=full_connection_mat)
        model = CNN_ANN(alex_cnn, ann_model, alex_capture, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, sub_probe_set, sub_probe_labels, verbose=False)
        print("PopularityANN:", num_succ, ":", num_fail)
        popularity_perf = int(num_succ)

        # evaluate orthogonal hebbs ANN performance
        ann_model = OrthogonalHebbsANN(N=num_nodes)
        model = CNN_ANN(alex_cnn, ann_model, alex_capture, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, sub_probe_set, sub_probe_labels, verbose=False)
        print("OrthogonalHebbsANN:", num_succ, ":", num_fail)
        ortho_perf = int(num_succ)

        stored_size_vs_performance.append((hopfield_perf, popularity_perf, ortho_perf))

    # write performance to file
    fh = open("data/graph_sources/" + output_name, "w")
    for perf in stored_size_vs_performance:
        fh.write(str(perf[0]) + "," + str(perf[1]) + "," + str(perf[2]) + "\n")
    fh.close()
    return stored_size_vs_performance
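Each line of the output file holds the three success counts written above; reading it back for plotting might look like:

with open("data/graph_sources/alexnet_recall_task_trial7.txt") as fh:
    rows = [tuple(int(v) for v in line.strip().split(",")) for line in fh]
# rows[k] holds (hopfield, popularity, ortho) successes with k+1 stored classes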
Example #18
def run_experiment(args):
    torch.manual_seed(args.seed)
    if not args.no_cuda:
        torch.cuda.manual_seed(args.seed)

    # Dataset
    if args.dataset == 'mnist':
        train_loader, test_loader, _, val_data = prepare_mnist(args)
    else:
        create_val_img_folder(args)
        train_loader, test_loader, _, val_data = prepare_imagenet(args)
    idx_to_class = {i: c for c, i in val_data.class_to_idx.items()}

    # Model & Criterion
    if args.model == 'AlexNet':
        if args.pretrained:
            model = models.__dict__['alexnet'](pretrained=True)
            # Change the last layer
            in_f = model.classifier[-1].in_features
            model.classifier[-1] = nn.Linear(in_f, args.classes)
        else:
            model = AlexNet(args.classes)
        criterion = nn.CrossEntropyLoss(reduction='sum')
    else:
        model = SVM(args.features, args.classes)
        criterion = MultiClassHingeLoss(margin=args.margin, size_average=False)
    if not args.no_cuda:
        model.cuda()

    # Load saved model and test on it
    if args.load:
        model.load_state_dict(torch.load(args.model_path))
        val_acc = test(model, criterion, test_loader, 0, [], [], idx_to_class,
                       args)

    # Optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters())
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum)

    total_minibatch_count = 0
    val_acc = 0
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []

    # Train and test
    for epoch in range(1, args.epochs + 1):
        total_minibatch_count = train(model, criterion, optimizer,
                                      train_loader, epoch,
                                      total_minibatch_count, train_losses,
                                      train_accs, args)

        val_acc = test(model, criterion, test_loader, epoch, val_losses,
                       val_accs, idx_to_class, args)

    # Save model
    if args.save:
        if not os.path.exists(args.models_dir):
            os.makedirs(args.models_dir)
        filename = '_'.join(
            [args.prefix, args.dataset, args.model, 'model.pt'])
        torch.save(model.state_dict(), os.path.join(args.models_dir, filename))

    # Plot graphs
    fig, axes = plt.subplots(1, 4, figsize=(13, 4))
    axes[0].plot(train_losses)
    axes[0].set_title('Loss')
    axes[1].plot(train_accs)
    axes[1].set_title('Acc')
    axes[1].set_ylim([0, 1])
    axes[2].plot(val_losses)
    axes[2].set_title('Val loss')
    axes[3].plot(val_accs)
    axes[3].set_title('Val Acc')
    axes[3].set_ylim([0, 1])
    # Images don't show on Ubuntu
    # plt.tight_layout()

    # Save results
    if not os.path.exists(args.results_dir):
        os.makedirs(args.results_dir)
    filename = '_'.join([args.prefix, args.dataset, args.model, 'plot.png'])
    fig.suptitle(filename)
    fig.savefig(os.path.join(args.results_dir, filename))
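run_experiment expects an argparse-style namespace; a minimal sketch of the fields it touches directly (all values here are assumptions, not the script's real defaults, and the helpers it calls such as prepare_mnist, train, and test will read further fields):

from argparse import Namespace

args = Namespace(seed=1, no_cuda=True, dataset='mnist', model='AlexNet',
                 pretrained=False, classes=10, load=False, save=False,
                 optimizer='adam', lr=0.01, momentum=0.9, epochs=1,
                 prefix='run', models_dir='./models', results_dir='./results',
                 model_path='', features=784, margin=1.0)
run_experiment(args)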
def train_generic_model(model_name="alexnet",
                        dataset="custom",
                        num_classes=-1,
                        batch_size=8,
                        is_transform=1,
                        num_workers=2,
                        lr_decay=1,
                        l2_reg=0,
                        hdf5_path="dataset-bosch-224x224.hdf5",
                        trainset_dir="./TRAIN_data_224_v8",
                        testset_dir="./TEST_data_224_v8",
                        convert_grey=False):
    CHKPT_PATH = "./checkpoint_{}.PTH".format(model_name)
    print("CUDA:")
    print(torch.cuda.is_available())
    if is_transform:

        trans_ls = []
        if convert_grey:
            trans_ls.append(transforms.Grayscale(num_output_channels=1))
        trans_ls.extend([
            transforms.Resize((224, 224)),
            # transforms.RandomCrop((224, 224)),
            # transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        transform = transforms.Compose(trans_ls)
    else:
        transform = None

    print("DATASET FORMAT: {}".format(dataset))
    print("TRAINSET PATH: {}".format(trainset_dir))
    print("TESTSET PATH: {}".format(testset_dir))
    print("HDF5 PATH: {}".format(hdf5_path))
    if dataset == "custom":
        trainset = torchvision.datasets.ImageFolder(root=trainset_dir,
                                                    transform=transform)
        train_size = len(trainset)
        testset = torchvision.datasets.ImageFolder(root=testset_dir,
                                                   transform=transform)
        test_size = len(testset)
    elif dataset == "cifar":
        trainset = torchvision.datasets.CIFAR10(root="CIFAR_TRAIN_data",
                                                train=True,
                                                download=True,
                                                transform=transform)
        train_size = len(trainset)
        testset = torchvision.datasets.CIFAR10(root="CIFAR_TEST_data",
                                               train=False,
                                               download=True,
                                               transform=transform)
        test_size = len(testset)
    elif dataset == "hdf5":
        if num_workers == 1:
            trainset = Hdf5Dataset(hdf5_path,
                                   transform=transform,
                                   is_test=False)
        else:
            trainset = Hdf5DatasetMPI(hdf5_path,
                                      transform=transform,
                                      is_test=False)
        train_size = len(trainset)
        if num_workers == 1:
            testset = Hdf5Dataset(hdf5_path, transform=transform, is_test=True)
        else:
            testset = Hdf5DatasetMPI(hdf5_path,
                                     transform=transform,
                                     is_test=True)
        test_size = len(testset)

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)

    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers)
    if model_name == "alexnet":
        net = AlexNet(num_classes=num_classes)
    elif model_name == "lenet5":
        net = LeNet5(num_classes=num_classes)
    elif model_name == "stn-alexnet":
        net = STNAlexNet(num_classes=num_classes)
    elif model_name == "stn-lenet5":
        net = LeNet5STN(num_classes=num_classes)
    elif model_name == "capsnet":
        net = CapsuleNet(num_classes=num_classes)
    elif model_name == "convneta":
        net = ConvNetA(num_classes=num_classes)
    elif model_name == "convnetb":
        net = ConvNetB(num_classes=num_classes)
    elif model_name == "convnetc":
        net = ConvNetC(num_classes=num_classes)
    elif model_name == "convnetd":
        net = ConvNetD(num_classes=num_classes)
    elif model_name == "convnete":
        net = ConvNetE(num_classes=num_classes)
    elif model_name == "convnetf":
        net = ConvNetF(num_classes=num_classes)
    elif model_name == "convnetg":
        net = ConvNetG(num_classes=num_classes)
    elif model_name == "convneth":
        net = ConvNetH(num_classes=num_classes)
    elif model_name == "convneti":
        net = ConvNetI(num_classes=num_classes)
    elif model_name == "convnetj":
        net = ConvNetJ(num_classes=num_classes)
    elif model_name == "convnetk":
        net = ConvNetK(num_classes=num_classes)
    elif model_name == "convnetl":
        net = ConvNetL(num_classes=num_classes)
    elif model_name == "convnetm":
        net = ConvNetM(num_classes=num_classes)
    elif model_name == "convnetn":
        net = ConvNetN(num_classes=num_classes)
    elif model_name == "resnet18":
        net = models.resnet18(pretrained=False, num_classes=num_classes)

    print(net)

    if torch.cuda.is_available():
        net = net.cuda()

    if model_name == "capsnet":
        criterion = CapsuleLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(net.parameters(),
                          lr=LEARNING_RATE,
                          momentum=0.9,
                          weight_decay=l2_reg)

    if lr_decay:
        scheduler = ReduceLROnPlateau(optimizer, 'min')

    best_acc = 0
    from_epoch = 0

    if os.path.exists(CHKPT_PATH):
        print("Checkpoint Found: {}".format(CHKPT_PATH))
        state = torch.load(CHKPT_PATH)
        net.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
        best_acc = state['best_accuracy']
        from_epoch = state['epoch']

    for epoch in range(from_epoch, NUM_EPOCHS):
        #print("Epoch: {}/{}".format(epoch + 1, NUM_EPOCHS))
        epoch_loss = 0
        correct = 0
        for i, data in enumerate(train_loader, 0):
            #print("Train \t Epoch: {}/{} \t Batch: {}/{}".format(epoch + 1,
            #                                            NUM_EPOCHS,
            #                                            i + 1,
            #                                            ceil(train_size / BATCH_SIZE)))
            inputs, labels = data
            inputs, labels = Variable(inputs).type(torch.FloatTensor),\
                             Variable(labels).type(torch.LongTensor)

            if model_name == "capsnet":
                inputs = augmentation(inputs)
                ground_truth = torch.eye(num_classes).index_select(
                    dim=0, index=labels)

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()

            if model_name == "capsnet":
                classes, reconstructions = net(inputs, ground_truth)
                loss = criterion(inputs, ground_truth, classes,
                                 reconstructions)
            else:
                outputs = net(inputs)
                loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            if model_name != "capsnet":
                log_outputs = F.softmax(outputs, dim=1)
            else:
                log_outputs = classes
            pred = log_outputs.data.max(1, keepdim=True)[1]
            correct += pred.eq(labels.data.view_as(pred)).sum().item()

        print(
            "Epoch: {} \t Training Loss: {:.4f} \t Training Accuracy: {:.2f} \t {}/{}"
            .format(epoch + 1, epoch_loss / train_size,
                    100 * correct / train_size, correct, train_size))

        correct = 0
        test_loss = 0
        for i, data in enumerate(test_loader, 0):
            # print("Test \t Epoch: {}/{} \t Batch: {}/{}".format(epoch + 1,
            #                                             NUM_EPOCHS,
            #                                             i + 1,
            #                                             ceil(test_size / BATCH_SIZE)))
            inputs, labels = data
            inputs, labels = Variable(inputs).type(
                torch.FloatTensor), Variable(labels).type(torch.LongTensor)

            if model_name == "capsnet":
                inputs = augmentation(inputs)
                ground_truth = torch.eye(num_classes).index_select(
                    dim=0, index=labels)

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()

            if model_name == "capsnet":
                classes, reconstructions = net(inputs)
                loss = criterion(inputs, ground_truth, classes,
                                 reconstructions)
            else:
                outputs = net(inputs)
                loss = criterion(outputs, labels)

            test_loss += loss.item()

            if model_name != "capsnet":
                log_outputs = F.softmax(outputs, dim=1)
            else:
                log_outputs = classes

            pred = log_outputs.data.max(1, keepdim=True)[1]
            correct += pred.eq(labels.data.view_as(pred)).sum().item()
        print(
            "Epoch: {} \t Testing Loss: {:.4f} \t Testing Accuracy: {:.2f} \t {}/{}"
            .format(epoch + 1, test_loss / test_size,
                    100 * correct / test_size, correct, test_size))
        if correct >= best_acc:
            if not os.path.exists("./models"):
                os.mkdir("./models")
            torch.save(
                net.state_dict(),
                "./models/model-{}-{}-{}-{}-val-acc-{:.2f}-train-{}-test-{}-epoch-{}.pb"
                .format(model_name, dataset, hdf5_path, str(datetime.now()),
                        100 * correct / test_size,
                        trainset_dir.replace(" ", "_").replace("/", "_"),
                        testset_dir.replace(" ", "_").replace("/",
                                                              "_"), epoch + 1))
        best_acc = max(best_acc, correct)

        # save checkpoint path
        state = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_accuracy': best_acc
        }
        torch.save(state, CHKPT_PATH)

        if lr_decay:
            # Note that step should be called after validate()
            scheduler.step(test_loss)

    print('Finished Training')

    print("")
    print("")