class TrainNetwork():
    """
    Fine-tune a pre-trained AlexNet on an EMNIST split.

    The weights of the last classifier layer in the pre-trained checkpoint are
    replaced by zero tensors sized for the chosen split before loading, and a
    snapshot of the model and optimizer state is saved after every epoch.
    """

    def __init__(self, dataset, batch_size, epochs, lr, lr_decay_epoch,
                 momentum):
        """
        Build the model, the training data loader, the optimizer and the loss.

        :param dataset: EMNIST split to train on, either 'letters' or 'mnist'
        :param batch_size: Number of samples per mini-batch
        :param epochs: Number of epochs run by start()
        :param lr: Initial learning rate
        :param lr_decay_epoch: Number of epochs between learning rate
            reductions
        :param momentum: Momentum passed to the SGD optimizer
        """
        assert (dataset == 'letters' or dataset == 'mnist')

        self.dataset = dataset
        self.batch_size = batch_size
        self.epochs = epochs
        self.lr = lr
        self.lr_decay_epoch = lr_decay_epoch
        self.momentum = momentum

        # letters contains 27 classes, digits contains 10 classes
        num_classes = 27 if dataset == 'letters' else 10

        # Load pre learned AlexNet with changed number of output classes.
        # The checkpoint is mapped to CPU storage first (as TestNetwork does)
        # so that it also loads on machines without CUDA; the model is moved
        # to the GPU afterwards when one is available.
        state_dict = torch.load('./trained_models/alexnet.pth',
                                map_location=lambda storage, loc: storage)
        state_dict['classifier.6.weight'] = torch.zeros(num_classes, 4096)
        state_dict['classifier.6.bias'] = torch.zeros(num_classes)
        self.model = AlexNet(num_classes)
        self.model.load_state_dict(state_dict)

        # Use cuda if available
        if torch.cuda.is_available():
            self.model.cuda()

        # Load training dataset
        kwargs = {
            'num_workers': 1,
            'pin_memory': True
        } if torch.cuda.is_available() else {}
        self.train_loader = torch.utils.data.DataLoader(
            EMNIST('./data',
                   dataset,
                   download=True,
                   transform=transforms.Compose([
                       transforms.Lambda(correct_rotation),
                       transforms.Lambda(random_transform),
                       transforms.Resize((224, 224)),
                       transforms.RandomResizedCrop(224, (0.9, 1.1),
                                                    ratio=(0.9, 1.1)),
                       transforms.Grayscale(3),
                       transforms.ToTensor(),
                   ])),
            batch_size=batch_size,
            shuffle=True,
            **kwargs)

        # Optimizer and loss function
        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=self.lr,
                                   momentum=self.momentum)
        self.loss_fn = nn.CrossEntropyLoss()

    def reduce_learning_rate(self, epoch):
        """
        Set the learning rate of every parameter group of self.optimizer to
        self.lr reduced by factor 0.1 for every self.lr_decay_epoch epochs
        that have passed
        :param epoch: Current epoch
        :return: None
        """
        lr = self.lr * (0.1**(epoch // self.lr_decay_epoch))

        # Only announce the value on the epochs where it actually changes
        if epoch % self.lr_decay_epoch == 0:
            print('LR is set to {}'.format(lr))

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def train(self, epoch):
        """
        Train the model for one epoch and save the result as a .pth file
        :param epoch: Current epoch
        :return: None
        """
        self.model.train()

        use_cuda = torch.cuda.is_available()
        num_samples = len(self.train_loader.dataset)

        train_loss = 0.0
        train_correct = 0
        progress = None
        for batch_idx, (data, target) in enumerate(self.train_loader):
            # Get data and label
            if use_cuda:
                data, target = data.cuda(), target.cuda()

            # Optimize using backpropagation
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss_fn(output, target)
            loss.backward()
            self.optimizer.step()

            # Accumulate statistics as plain Python numbers; indexing a
            # 0-dim loss tensor (loss.data[0]) fails on current PyTorch and
            # an integer tensor cannot be true-divided in place later on.
            batch_loss = loss.item()
            train_loss += batch_loss
            pred = output.detach().max(1, keepdim=True)[1]
            train_correct += pred.eq(target.view_as(pred)).sum().item()

            # Print information about current step
            current_progress = int(100 * (batch_idx + 1) * self.batch_size /
                                   num_samples)
            if current_progress != progress and current_progress % 5 == 0:
                progress = current_progress
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx + 1) * len(data), num_samples,
                    current_progress, batch_loss))

        # Mean of the per-batch losses and accuracy in percent
        train_loss /= max(len(self.train_loader), 1)
        train_correct = 100.0 * train_correct / max(num_samples, 1)

        # Print information about current epoch
        print(
            'Train Epoch: {} \tCorrect: {:3.2f}%\tAverage loss: {:.6f}'.format(
                epoch, train_correct, train_loss))

        # Save snapshot
        torch.save(
            {
                'model': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict()
            }, './trained_models/{}_{}.pth'.format(self.dataset, epoch))

    def start(self):
        """
        Start training the network
        :return: None
        """
        for epoch in range(1, self.epochs + 1):
            self.reduce_learning_rate(epoch)
            self.train(epoch)
Example #2
0
def run_experiment(args):
    """
    Train and evaluate a classifier as configured by ``args``.

    Reads from ``args``: seed, no_cuda, dataset, model, pretrained, classes,
    features, margin, load, model_path, optimizer, lr, momentum, epochs,
    save, models_dir, prefix and results_dir.

    Side effects: optionally saves the model's state dict under
    ``args.models_dir`` and always writes a plot of the training/validation
    curves under ``args.results_dir``.

    :param args: parsed options object (see the attributes listed above)
    :return: None
    """
    torch.manual_seed(args.seed)
    if not args.no_cuda:
        torch.cuda.manual_seed(args.seed)

    # Dataset
    if args.dataset == 'mnist':
        train_loader, test_loader, _, val_data = prepare_mnist(args)
    else:
        create_val_img_folder(args)
        train_loader, test_loader, _, val_data = prepare_imagenet(args)
    idx_to_class = {i: c for c, i in val_data.class_to_idx.items()}

    # Model & Criterion
    if args.model == 'AlexNet':
        if args.pretrained:
            model = models.alexnet(pretrained=True)
            # Change the last layer
            in_f = model.classifier[-1].in_features
            model.classifier[-1] = nn.Linear(in_f, args.classes)
        else:
            model = AlexNet(args.classes)
        criterion = nn.CrossEntropyLoss(size_average=False)
    else:
        model = SVM(args.features, args.classes)
        criterion = MultiClassHingeLoss(margin=args.margin, size_average=False)
    if not args.no_cuda:
        model.cuda()

    # Load saved model and test on it
    if args.load:
        # On CPU-only runs remap the stored tensors to CPU so that a
        # checkpoint written on a GPU machine can still be loaded.
        map_location = (lambda storage, loc: storage) if args.no_cuda else None
        model.load_state_dict(
            torch.load(args.model_path, map_location=map_location))
        # Epoch 0 and throw-away lists: this run is not part of the curves.
        test(model, criterion, test_loader, 0, [], [], idx_to_class, args)

    # Optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters())
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum)

    total_minibatch_count = 0
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []

    # Train and test
    for epoch in range(1, args.epochs + 1):
        total_minibatch_count = train(model, criterion, optimizer,
                                      train_loader, epoch,
                                      total_minibatch_count, train_losses,
                                      train_accs, args)

        test(model, criterion, test_loader, epoch, val_losses, val_accs,
             idx_to_class, args)

    # Save model
    if args.save:
        # exist_ok avoids the check-then-create race of os.path.exists()
        os.makedirs(args.models_dir, exist_ok=True)
        filename = '_'.join(
            [args.prefix, args.dataset, args.model, 'model.pt'])
        torch.save(model.state_dict(), os.path.join(args.models_dir, filename))

    # Plot graphs
    fig, axes = plt.subplots(1, 4, figsize=(13, 4))
    axes[0].plot(train_losses)
    axes[0].set_title('Loss')
    axes[1].plot(train_accs)
    axes[1].set_title('Acc')
    axes[1].set_ylim([0, 1])
    axes[2].plot(val_losses)
    axes[2].set_title('Val loss')
    axes[3].plot(val_accs)
    axes[3].set_title('Val Acc')
    axes[3].set_ylim([0, 1])
    # Images don't show on Ubuntu
    # plt.tight_layout()

    # Save results
    os.makedirs(args.results_dir, exist_ok=True)
    filename = '_'.join([args.prefix, args.dataset, args.model, 'plot.png'])
    fig.suptitle(filename)
    fig.savefig(os.path.join(args.results_dir, filename))
Example #3
0
def main():
    """
    Train AlexNet on the ImageNet-10 image folders for 50 epochs.

    Builds the train/val transforms and data loaders, moves the model and the
    loss to the first configured GPU and runs plain SGD, printing the tensor
    shapes of every batch and a progress line every 10 batches.
    """
    # Alternative multi-GPU setup: gpus = [4, 5, 6, 7]
    gpus = [0]
    print("GPUs :", gpus)
    print("prepare data")

    # Standard ImageNet channel statistics
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    val_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
    train_tfs = transforms.Compose(train_steps)
    val_tfs = transforms.Compose(val_steps)

    train_ds = datasets.ImageFolder('/home/gw/data/imagenet_10/train',
                                    train_tfs)
    val_ds = datasets.ImageFolder('/home/gw/data/imagenet_10/val', val_tfs)

    # Settings shared by both loaders
    worker_opts = {'num_workers': 4, 'pin_memory': True}
    train_ld = torch.utils.data.DataLoader(train_ds,
                                           batch_size=256,
                                           shuffle=True,
                                           **worker_opts)
    # The validation loader is built but not consumed by the loop below.
    val_ld = torch.utils.data.DataLoader(val_ds,
                                         batch_size=64,
                                         shuffle=False,
                                         **worker_opts)

    print("construct model")
    # Other options tried here: ResNet50(), torchvision.models.AlexNet(),
    # and wrapping the model in torch.nn.DataParallel(model, device_ids=gpus).
    model = AlexNet()
    model.cuda()

    criterion = nn.CrossEntropyLoss().cuda(gpus[0])
    optimizer = torch.optim.SGD(model.parameters(),
                                0.01,
                                momentum=0.875,
                                weight_decay=3.0517578125e-05)

    model.train()
    print("begin trainning")
    for epoch in range(50):
        # Fresh meters for every epoch
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        meters = [batch_time, data_time, losses, top1, top5]

        progress = ProgressMeter(len(train_ld),
                                 meters,
                                 prefix="Epoch: [{}]".format(epoch))

        tick = time.time()
        for step, (images, labels) in enumerate(train_ld):
            # Time spent waiting for the loader
            data_time.update(time.time() - tick)
            print('image shape: ', images.shape)
            print('labels shape: ', labels.shape)
            images = images.cuda(gpus[0], non_blocking=True)
            labels = labels.cuda(gpus[0], non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Record loss and top-1 / top-5 accuracy for this batch
            n_images = images.size(0)
            acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
            losses.update(loss.item(), n_images)
            top1.update(acc1[0], n_images)
            top5.update(acc5[0], n_images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Time for the whole step, then restart the clock
            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % 10 == 0:
                progress.display(step)
Example #4
0
def train(train_loader, eval_loader, opt):
    """
    Train AlexNet with SGD, evaluate after every epoch and keep the best model.

    Per-step training loss/accuracy and per-epoch evaluation loss/accuracy are
    logged to a summary writer under ``./runs/<timestamp>``; whenever the
    evaluation accuracy improves the whole model is saved to
    ``./weights/best.pt``.

    :param train_loader: iterable yielding (inputs, targets) training batches
    :param eval_loader: iterable yielding (inputs, targets) evaluation batches
    :param opt: options object; ``opt.base_lr`` and ``opt.epochs`` are read
    :return: None
    """
    import os  # local import: only needed to create the weights directory

    print('==> Start training...')

    summary_writer = SummaryWriter('./runs/' + str(int(time.time())))

    is_cuda = torch.cuda.is_available()
    model = AlexNet()
    if is_cuda:
        model = model.cuda()

    optimizer = optim.SGD(
        params=model.parameters(),
        lr=opt.base_lr,
        momentum=0.9,
    )
    criterion = nn.CrossEntropyLoss()

    best_eval_acc = -0.1
    losses = AverageMeter()
    accuracies = AverageMeter()
    global_step = 0
    try:
        for epoch in range(1, opt.epochs + 1):
            # train
            model.train()
            # Start every training phase from empty meters; otherwise the
            # averages of epoch >= 2 would still contain the statistics of
            # the previous evaluation pass.
            losses.reset()
            accuracies.reset()
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                global_step += 1
                if is_cuda:
                    inputs = inputs.cuda()
                    targets = targets.cuda()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                batch_loss = loss.item()
                batch_size = len(targets)
                losses.update(batch_loss, outputs.shape[0])
                summary_writer.add_scalar('train/loss', batch_loss,
                                          global_step)

                _, preds = torch.max(outputs, dim=1)
                acc = preds.eq(targets).sum().item() / batch_size
                # Weight by batch size so a short last batch does not skew
                # the epoch average.
                accuracies.update(acc, batch_size)
                summary_writer.add_scalar('train/acc', acc, global_step)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            summary_writer.add_scalar('lr', optimizer.param_groups[0]['lr'],
                                      global_step)
            print(
                '==> Epoch: %d; Average Train Loss: %.4f; Average Train Acc: %.4f'
                % (epoch, losses.avg, accuracies.avg))

            # eval
            model.eval()
            losses.reset()
            accuracies.reset()
            # No gradients are needed for evaluation; skipping the autograd
            # graph saves memory and time.
            with torch.no_grad():
                for batch_idx, (inputs, targets) in enumerate(eval_loader):
                    if is_cuda:
                        inputs = inputs.cuda()
                        targets = targets.cuda()
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    batch_size = len(targets)
                    losses.update(loss.item(), outputs.shape[0])

                    _, preds = torch.max(outputs, dim=1)
                    acc = preds.eq(targets).sum().item() / batch_size
                    accuracies.update(acc, batch_size)

            summary_writer.add_scalar('eval/loss', losses.avg, global_step)
            summary_writer.add_scalar('eval/acc', accuracies.avg, global_step)
            if accuracies.avg > best_eval_acc:
                best_eval_acc = accuracies.avg
                os.makedirs('./weights', exist_ok=True)
                torch.save(model, './weights/best.pt')
            print(
                '==> Epoch: %d; Average Eval Loss: %.4f; Average/Best Eval Acc: %.4f / %.4f'
                % (epoch, losses.avg, accuracies.avg, best_eval_acc))
    finally:
        # Flush pending events even when training is interrupted
        summary_writer.close()
Example #5
0
    elif args.model == 'resnet':
        cnn = ResNet18(enable_lat=args.enable_lat,
                       epsilon=args.epsilon,
                       pro_num=args.pro_num,
                       batch_size=args.batchsize,
                       num_classes=200,
                       if_dropout=args.dropout)
        #cnn.apply(conv_init)
    elif args.model == 'alexnetBN':
        cnn = AlexNetBN(enable_lat=args.enable_lat,
                        epsilon=args.epsilon,
                        pro_num=args.pro_num,
                        batch_size=args.batchsize,
                        num_classes=200,
                        if_dropout=args.dropout)
    cnn.cuda()

    if os.path.exists(real_model_path):
        cnn.load_state_dict(torch.load(real_model_path))
        print('model successfully loaded.')
    else:
        print("load model failed.")

    if args.test_flag:
        if args.adv_flag:
            test_all(cnn)
        else:
            test_op(cnn)
    else:
        train_op(cnn)
Example #6
0
class TestNetwork():
    """
    Evaluate the per-epoch snapshots of an AlexNet on the EMNIST test split.

    For every epoch the snapshot ``./trained_models/<dataset>_<epoch>.pth`` is
    loaded and its accuracy and average loss on the test data are printed.
    """

    def __init__(self, dataset, batch_size, epochs):
        """
        Build the model, the test data loader and the loss function.

        :param dataset: EMNIST split to test on
        :param batch_size: Number of samples per mini-batch
        :param epochs: Number of epoch snapshots evaluated by start()
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.epochs = epochs

        # letters contains 27 classes, digits contains 10 classes
        num_classes = 27 if dataset == 'letters' else 10

        # Load model and use cuda if available
        self.model = AlexNet(num_classes)
        if torch.cuda.is_available():
            self.model.cuda()

        # Load testing dataset
        kwargs = {
            'num_workers': 1,
            'pin_memory': True
        } if torch.cuda.is_available() else {}
        self.test_loader = torch.utils.data.DataLoader(EMNIST(
            './data',
            dataset,
            download=True,
            transform=transforms.Compose([
                transforms.Lambda(correct_rotation),
                transforms.Resize((224, 224)),
                transforms.Grayscale(3),
                transforms.ToTensor(),
            ]),
            train=False),
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       **kwargs)

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()

    def test(self, epoch):
        """
        Test the model for one epoch with a pre trained network
        :param epoch: Current epoch
        :return: None
        """
        # Load weights from trained model (mapped to CPU storage so the
        # snapshot loads regardless of the device it was saved from)
        state_dict = torch.load(
            './trained_models/{}_{}.pth'.format(self.dataset, epoch),
            map_location=lambda storage, loc: storage)['model']
        self.model.load_state_dict(state_dict)
        self.model.eval()

        use_cuda = torch.cuda.is_available()
        num_samples = len(self.test_loader.dataset)

        test_loss = 0.0
        test_correct = 0
        progress = None
        # Evaluation needs no gradients; skip building the autograd graph
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.test_loader):
                # Get data and label
                if use_cuda:
                    data, target = data.cuda(), target.cuda()

                # Forward pass and statistics, kept as plain Python numbers
                # (loss.data[0] fails on 0-dim tensors in current PyTorch)
                output = self.model(data)
                batch_loss = self.loss_fn(output, target).item()
                test_loss += batch_loss
                pred = output.max(1, keepdim=True)[1]
                test_correct += pred.eq(target.view_as(pred)).sum().item()

                # Print information about current step
                current_progress = int(100 * (batch_idx + 1) *
                                       self.batch_size / num_samples)
                if (current_progress != progress
                        and current_progress % 5 == 0):
                    progress = current_progress
                    print(
                        'Test Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                            epoch, (batch_idx + 1) * len(data), num_samples,
                            current_progress, batch_loss))

        # Mean of the per-batch losses and accuracy in percent
        test_loss /= max(len(self.test_loader), 1)
        test_correct = 100.0 * test_correct / max(num_samples, 1)

        # Print information about current epoch
        print(
            'Test Epoch: {} \tCorrect: {:3.2f}%\tAverage loss: {:.6f}'.format(
                epoch, test_correct, test_loss))

    def start(self):
        """
        Start testing the network
        :return: None
        """
        for epoch in range(1, self.epochs + 1):
            self.test(epoch)
Example #7
0
    SAVE_PATH = './cp_large.bin'
    logger = Logger('./largecnnlogs')

    lossfunction = nn.MSELoss()

    dataset = Rand_num()
    sampler = RandomSampler(dataset)
    loader = DataLoader(dataset,
                        batch_size=20,
                        sampler=sampler,
                        shuffle=False,
                        num_workers=1,
                        drop_last=True)
    net = AlexNet(3)
    #net.load_state_dict(torch.load(SAVE_PATH))
    net.cuda()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(10000):
        for i, data in enumerate(loader, 0):
            net.zero_grad()
            video, labels = data
            video = video.view(-1, 3, 227, 227)
            labels = labels.view(-1, 3)
            labels = torch.squeeze(Variable(labels.float().cuda()))
            video = torch.squeeze(Variable((video.float() / 256).cuda()))
            net.train()
            outputs = net.forward(video)
            loss = lossfunction(outputs, labels)
            loss.backward()
            optimizer.step()
            if i == 0:
def train():
    """
    Train the classifier head of an AlexNet and record the training history.

    Only ``model.classifier`` parameters are handed to the optimizer. After
    every epoch the model is evaluated on the training and validation loaders
    and the plateau scheduler is stepped with the epoch's mean training loss.
    When training finishes the model is evaluated on the test loader, the
    history is saved and the whole model is written to ``./model/save.pt``.

    :return: None
    """
    import os  # local import: only needed to create the output directory

    lr = 0.001
    epoch_num = 40
    iteration = 10  # report the running loss every `iteration` mini-batches
    tra_acc_list, val_acc_list, loss_list = [], [], []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load data
    trainLoader, validLoader, testLoader = prepare_data()

    # init model
    model = AlexNet()

    # optimizer and loss
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.classifier.parameters(), lr, momentum=0.9)
    # NOTE(review): threshold=lr reuses the learning rate value as the
    # relative improvement threshold -- confirm this is intended.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=5,
        verbose=True,
        threshold=lr,
        threshold_mode='rel',
        cooldown=0,
        min_lr=0,
        eps=1e-08)

    # Move everything to the selected device (falls back to CPU instead of
    # crashing on machines without CUDA)
    if device.type == "cuda":
        torch.cuda.empty_cache()
    model = model.to(device)
    criterion = criterion.to(device)
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

    # train process
    for epoch in range(epoch_num):  # loop over the dataset multiple times
        epoch_str = f' Epoch {epoch + 1}/{epoch_num} '
        print(f'{epoch_str:-^40s}')
        print(f'Learning rate: {optimizer.param_groups[0]["lr"]}')
        running_loss = 0.0
        epoch_loss = 0.0
        batch_count = 0
        for i, data in enumerate(trainLoader, 0):
            # get the inputs
            inputs, labels = data[0].to(device), data[1].to(device)
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_loss += batch_loss
            batch_count += 1
            if i % iteration == iteration - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / iteration))
                loss_list.append(running_loss / iteration)
                # Reset only after reporting, so the printed value really is
                # the mean over the last `iteration` mini-batches.
                running_loss = 0.0

        eval_model_train(model, trainLoader, device, tra_acc_list)
        eval_model_validation(model, validLoader, device, val_acc_list)
        # Step on the epoch mean rather than on the noisy last-batch loss;
        # skipped when the loader produced no batches.
        if batch_count:
            scheduler.step(epoch_loss / batch_count)

    print('Finished Training')
    eval_model_test(model, testLoader, device)
    save_history(tra_acc_list, val_acc_list, loss_list)
    os.makedirs('./model', exist_ok=True)
    torch.save(model, './model/save.pt')
Example #9
0
import torchvision.models as models

from alexnet import AlexNet

# Seed every random number generator in use and request deterministic cuDNN
# kernels so that runs are repeatable.
torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed(0)
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

# Prefer the GPU when one is present (set device to 'cpu' for CPU-only runs).
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# 10-class AlexNet, moved to the GPU when available.
net = AlexNet(num_classes=10)
if torch.cuda.is_available():
    net = net.cuda()

if device == 'cuda':
    # Multi-GPU execution could be enabled by wrapping the network:
    # net = torch.nn.DataParallel(net)
    cudnn.benchmark = True


def train_dataset(path, shuffle=True):
    # NOTE(review): only the transform setup is visible in this excerpt;
    # `path` and `shuffle` are not used by the lines shown here.

    # Preprocessing pipeline: resize, random horizontal flip, then conversion
    # to a tensor.
    transformation = torchvision.transforms.Compose([
        # NOTE(review): (224, 244) is not square -- possibly a typo for
        # (224, 224); confirm against the input size the model expects.
        torchvision.transforms.Resize((224, 244)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor()
    ])
Example #10
0
    if args.model == 'vgg':
        model = VGG16(enable_lat=False,
                      epsilon=0.6,
                      pro_num=5,
                      batch_size=args.batch_size,
                      if_dropout=True)
    elif args.model == 'resnet':
        model = ResNet50(enable_lat=False,
                         epsilon=0.6,
                         pro_num=5,
                         batch_size=args.batch_size,
                         if_dropout=True)
    elif args.model == 'alexnet':
        model = AlexNet(enable_lat=False,
                        epsilon=0.6,
                        pro_num=5,
                        batch_size=args.batch_size,
                        num_classes=200,
                        if_dropout=True)
    model.cuda()

    if os.path.exists(args.model_path):
        model.load_state_dict(torch.load(args.model_path))
        print('load model successfully.')
    else:
        print("load failed.")
    if args.test_flag:
        test_all(model)
    else:
        test_op(model)
def train_generic_model(model_name="alexnet",
                        dataset="custom",
                        num_classes=-1,
                        batch_size=8,
                        is_transform=1,
                        num_workers=2,
                        lr_decay=1,
                        l2_reg=0,
                        hdf5_path="dataset-bosch-224x224.hdf5",
                        trainset_dir="./TRAIN_data_224_v8",
                        testset_dir="./TEST_data_224_v8",
                        convert_grey=False):
    """
    Train one of the supported models and evaluate it after every epoch.

    Training resumes from ``./checkpoint_<model_name>.PTH`` when that file
    exists; a checkpoint is rewritten after every epoch and the weights are
    additionally saved under ``./models`` whenever the number of correctly
    classified test samples is at least the best seen so far. The number of
    epochs and the learning rate come from the module-level ``NUM_EPOCHS``
    and ``LEARNING_RATE``.

    :param model_name: which network to build (e.g. "alexnet", "lenet5",
        "capsnet", "convneta" ... "convnetn", "resnet18")
    :param dataset: data source format: "custom" (image folders), "cifar"
        or "hdf5"
    :param num_classes: number of output classes passed to the model
    :param batch_size: mini-batch size for both loaders
    :param is_transform: when truthy, apply resize / to-tensor / normalize
    :param num_workers: loader worker count; for "hdf5" it also selects
        between Hdf5Dataset (== 1) and Hdf5DatasetMPI (otherwise)
    :param lr_decay: when truthy, reduce the learning rate on test-loss
        plateaus
    :param l2_reg: weight decay passed to SGD
    :param hdf5_path: file used by the "hdf5" dataset format
    :param trainset_dir: training image folder for the "custom" format
    :param testset_dir: test image folder for the "custom" format
    :param convert_grey: when true, convert images to single-channel
        greyscale before the other transforms
    :raises ValueError: if ``dataset`` or ``model_name`` is not supported
    :return: None
    """
    # Fail early with a clear message instead of a NameError further down.
    if dataset not in ("custom", "cifar", "hdf5"):
        raise ValueError("unsupported dataset format: {!r}".format(dataset))

    CHKPT_PATH = "./checkpoint_{}.PTH".format(model_name)
    use_cuda = torch.cuda.is_available()
    print("CUDA:")
    print(use_cuda)
    if is_transform:

        trans_ls = []
        if convert_grey:
            trans_ls.append(transforms.Grayscale(num_output_channels=1))
        # Normalize needs one mean/std entry per channel: greyscale images
        # have a single channel, everything else is treated as 3-channel.
        channels = 1 if convert_grey else 3
        trans_ls.extend([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ) * channels, (0.5, ) * channels)
        ])
        transform = transforms.Compose(trans_ls)
    else:
        transform = None

    print("DATASET FORMAT: {}".format(dataset))
    print("TRAINSET PATH: {}".format(trainset_dir))
    print("TESTSET PATH: {}".format(testset_dir))
    print("HDF5 PATH: {}".format(hdf5_path))
    if dataset == "custom":
        trainset = torchvision.datasets.ImageFolder(root=trainset_dir,
                                                    transform=transform)
        testset = torchvision.datasets.ImageFolder(root=testset_dir,
                                                   transform=transform)
    elif dataset == "cifar":
        trainset = torchvision.datasets.CIFAR10(root="CIFAR_TRAIN_data",
                                                train=True,
                                                download=True,
                                                transform=transform)
        testset = torchvision.datasets.CIFAR10(root="CIFAR_TEST_data",
                                               train=False,
                                               download=True,
                                               transform=transform)
    else:  # "hdf5" (validated above)
        if num_workers == 1:
            trainset = Hdf5Dataset(hdf5_path,
                                   transform=transform,
                                   is_test=False)
            testset = Hdf5Dataset(hdf5_path, transform=transform, is_test=True)
        else:
            trainset = Hdf5DatasetMPI(hdf5_path,
                                      transform=transform,
                                      is_test=False)
            testset = Hdf5DatasetMPI(hdf5_path,
                                     transform=transform,
                                     is_test=True)
    train_size = len(trainset)
    test_size = len(testset)

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)

    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers)
    if model_name == "alexnet":
        net = AlexNet(num_classes=num_classes)
    elif model_name == "lenet5":
        net = LeNet5(num_classes=num_classes)
    elif model_name == "stn-alexnet":
        net = STNAlexNet(num_classes=num_classes)
    elif model_name == "stn-lenet5":
        net = LeNet5STN(num_classes=num_classes)
    elif model_name == "capsnet":
        net = CapsuleNet(num_classes=num_classes)
    elif model_name == "convneta":
        net = ConvNetA(num_classes=num_classes)
    elif model_name == "convnetb":
        net = ConvNetB(num_classes=num_classes)
    elif model_name == "convnetc":
        net = ConvNetC(num_classes=num_classes)
    elif model_name == "convnetd":
        net = ConvNetD(num_classes=num_classes)
    elif model_name == "convnete":
        net = ConvNetE(num_classes=num_classes)
    elif model_name == "convnetf":
        net = ConvNetF(num_classes=num_classes)
    elif model_name == "convnetg":
        net = ConvNetG(num_classes=num_classes)
    elif model_name == "convneth":
        net = ConvNetH(num_classes=num_classes)
    elif model_name == "convneti":
        net = ConvNetI(num_classes=num_classes)
    elif model_name == "convnetj":
        net = ConvNetJ(num_classes=num_classes)
    elif model_name == "convnetk":
        net = ConvNetK(num_classes=num_classes)
    elif model_name == "convnetl":
        net = ConvNetL(num_classes=num_classes)
    elif model_name == "convnetm":
        net = ConvNetM(num_classes=num_classes)
    elif model_name == "convnetn":
        net = ConvNetN(num_classes=num_classes)
    elif model_name == "resnet18":
        net = models.resnet18(pretrained=False, num_classes=num_classes)
    else:
        raise ValueError("unsupported model name: {!r}".format(model_name))

    print(net)

    if use_cuda:
        net = net.cuda()

    is_capsnet = model_name == "capsnet"
    if is_capsnet:
        criterion = CapsuleLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(net.parameters(),
                          lr=LEARNING_RATE,
                          momentum=0.9,
                          weight_decay=l2_reg)

    if lr_decay:
        scheduler = ReduceLROnPlateau(optimizer, 'min')

    best_acc = 0
    from_epoch = 0

    if os.path.exists(CHKPT_PATH):
        print("Checkpoint Found: {}".format(CHKPT_PATH))
        state = torch.load(CHKPT_PATH)
        net.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
        # Older checkpoints may hold a one-element tensor here
        best_acc = int(state['best_accuracy'])
        # The checkpoint is written after its epoch has finished, so resume
        # with the following one instead of repeating it.
        from_epoch = state['epoch'] + 1

    def path_tag(path):
        """Make a path safe to embed in a file name."""
        return path.replace(" ", "_").replace("/", "_")

    for epoch in range(from_epoch, NUM_EPOCHS):
        # ---- training pass ----
        net.train()
        epoch_loss = 0.0
        correct = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.float(), labels.long()

            if is_capsnet:
                inputs = augmentation(inputs)
                # One-hot encoding of the labels
                # NOTE(review): ground_truth is built from the CPU labels and
                # is not moved to the GPU below -- presumably CapsuleNet /
                # CapsuleLoss handle that; confirm.
                ground_truth = torch.eye(num_classes).index_select(
                    dim=0, index=labels)

            if use_cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()

            if is_capsnet:
                classes, reconstructions = net(inputs, ground_truth)
                loss = criterion(inputs, ground_truth, classes,
                                 reconstructions)
                scores = classes
            else:
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                # argmax of the logits equals argmax of their softmax
                scores = outputs

            loss.backward()
            optimizer.step()
            # Plain Python numbers: loss.data[0] fails on 0-dim tensors in
            # current PyTorch and tensor counts break the formatting below.
            epoch_loss += loss.item()
            pred = scores.detach().max(1, keepdim=True)[1]
            correct += pred.eq(labels.view_as(pred)).sum().item()

        print(
            "Epoch: {} \t Training Loss: {:.4f} \t Training Accuracy: {:.2f} \t {}/{}"
            .format(epoch + 1, epoch_loss / train_size,
                    100 * correct / train_size, correct, train_size))

        # ---- evaluation pass ----
        # Evaluation mode (e.g. disables dropout) and no autograd graph
        net.eval()
        correct = 0
        test_loss = 0.0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.float(), labels.long()

                if is_capsnet:
                    inputs = augmentation(inputs)
                    ground_truth = torch.eye(num_classes).index_select(
                        dim=0, index=labels)

                if use_cuda:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                if is_capsnet:
                    classes, reconstructions = net(inputs)
                    loss = criterion(inputs, ground_truth, classes,
                                     reconstructions)
                    scores = classes
                else:
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    scores = outputs

                test_loss += loss.item()

                pred = scores.max(1, keepdim=True)[1]
                correct += pred.eq(labels.view_as(pred)).sum().item()
        print(
            "Epoch: {} \t Testing Loss: {:.4f} \t Testing Accuracy: {:.2f} \t {}/{}"
            .format(epoch + 1, test_loss / test_size,
                    100 * correct / test_size, correct, test_size))
        if correct >= best_acc:
            os.makedirs("./models", exist_ok=True)
            torch.save(
                net.state_dict(),
                "./models/model-{}-{}-{}-{}-val-acc-{:.2f}-train-{}-test-{}-epoch-{}.pb"
                .format(model_name, dataset, path_tag(hdf5_path),
                        str(datetime.now()), 100 * correct / test_size,
                        path_tag(trainset_dir), path_tag(testset_dir),
                        epoch + 1))
        best_acc = max(best_acc, correct)

        # save checkpoint path
        state = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_accuracy': best_acc
        }
        torch.save(state, CHKPT_PATH)

        if lr_decay:
            # Note that step should be called after validate()
            scheduler.step(test_loss)

    print('Finished Training')

    print("")
    print("")