def get_vgg_net(model_folder, out_keys=['r11', 'r21', 'r31', 'r41', 'r51']):

    vgg_net = VGG(pool='avg', out_keys=out_keys)
    vgg_net.load_state_dict(torch.load(model_folder + 'vgg_conv.pth'))
    vgg_net.cuda()
    for param in vgg_net.parameters():
        param.requires_grad = False
    return vgg_net
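# Illustrative usage of get_vgg_net (assumptions: the VGG wrapper above returns one
# feature map per key in out_keys, and 'models/' is a stand-in weights folder):
#
#     vgg = get_vgg_net('models/')
#     img = torch.randn(1, 3, 224, 224).cuda()      # stand-in for a preprocessed image
#     r11, r21, r31, r41, r51 = vgg(img)            # multi-scale features, gradients frozen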
        # (fragment from a separate pruning example; the enclosing loop over the
        #  checkpoint's weight tensors is not shown)
        checkpoint['net'][k] = weights
        totwts += tot_wts
        totpru += totpru_wts
        totmcount += tot_mcnt

print("Total weights", totwts)
print("Total Pruned", totpru)
print("Total Mask Count", totmcount)
checkpoint['address'] = addressbook
checkpoint['mask'] = maskbook
net.load_state_dict(checkpoint['net'])

# Training

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),
                      lr=args.lr,
                      momentum=0.9,
                      weight_decay=5e-4)


def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    # teacher-network setup (likely from a knowledge-distillation script;
    # the enclosing function is not shown)
    teacher.load_state_dict(torch.load(args.teacher_path))
    teacher.cuda()
    if args.multi_gpu:
        teacher = torch.nn.DataParallel(teacher)

#############

criterion = CrossEntropyLossMaybeSmooth(smooth_eps=args.smooth_eps).cuda()
# args.smooth = args.smooth_eps > 0.0
# args.mixup = config.alpha > 0.0

optimizer_init_lr = args.warmup_lr if args.warmup else args.lr

optimizer = None
if args.optmzr == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(),
                                optimizer_init_lr,
                                momentum=0.9,
                                weight_decay=1e-4)
elif args.optmzr == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), optimizer_init_lr)

scheduler = None
if args.lr_scheduler == 'cosine':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     T_max=args.epochs *
                                                     len(train_loader),
                                                     eta_min=4e-08)
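    # Note: T_max is epochs * len(train_loader), i.e. it counts optimizer steps,
    # so this cosine schedule assumes scheduler.step() is called once per batch
    # (right after optimizer.step()), not once per epoch.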
elif args.lr_scheduler == 'default':
    # my learning rate scheduler for cifar, following https://github.com/kuangliu/pytorch-cifar
    epoch_milestones = [150, 250, 350]
Example #4
                running_loss += loss.item() * inputs.size(0)
                corrects_class += torch.sum(preds_class == labels_class)

            epoch_loss = running_loss / len(data_loaders[phase].dataset)
            Loss_list[phase].append(epoch_loss)
            epoch_acc_class = corrects_class.double() / len(data_loaders[phase].dataset)

            Accuracy_list_class[phase].append(100 * epoch_acc_class)
            print('{} Loss: {:.4f}  Acc_class: {:.2%}'.format(phase, epoch_loss, epoch_acc_class))

            if phase == 'val' and epoch_acc_class > best_acc:
                best_acc = epoch_acc_class
                best_model_wts = copy.deepcopy(model.state_dict())
                print('Best val class Acc: {:.2%}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    #torch.save(model.state_dict(), 'best_model.pt')
    print('Best val class Acc: {:.2%}'.format(best_acc))
    return model, Loss_list, Accuracy_list_class

network = VGG('VGG11').to(device)
optimizer = optim.SGD(network.parameters(), lr=0.01, momentum=0.9)  # raising lr to 0.05 converges a bit faster
criterion = nn.CrossEntropyLoss()
#exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) # Decay LR by a factor of 0.1 every 1 epochs
model, Loss_list, Accuracy_list_class = train_model(network, criterion, optimizer, num_epochs=10)




Example #5
testset = torchvision.datasets.CIFAR10(root='~/data',
                                       train=False,
                                       download=True,
                                       transform=transform)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=128,
                                         shuffle=False,
                                         num_workers=4)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

net = VGG('VGG16')
net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

print('Training Start...')
for epoch in range(200):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in tqdm(enumerate(trainloader)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
Example #6
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7')
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
Example #7
        ]),
    }

    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in [train, val]}

    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=8,
                                                  shuffle=True,
                                                  num_workers=4) for x in [train, val]}

    dataset_sizes = {x: len(image_datasets[x]) for x in [train, val]}

    use_gpu = torch.cuda.is_available()

    model = VGG(2)

    if os.path.exists(save_path):
        model.load_state_dict(torch.load(save_path))

    if use_gpu:
        model = model.cuda()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.95)

    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    print('*' * 10)
    print('start training')
    trainval(dataloaders, model, optimizer, scheduler, criterion, dataset_sizes, phase='train')
Example #8
def main(style_img_path: str,
         content_img_path: str, 
         img_dim: int,
         num_iter: int,
         style_weight: int,
         content_weight: int,
         variation_weight: int,
         print_every: int,
         save_every: int):

    assert style_img_path is not None
    assert content_img_path is not None

    # define the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # read the images
    style_img = Image.open(style_img_path)
    cont_img = Image.open(content_img_path)
    
    # define the transform
    transform = transforms.Compose([transforms.Resize((img_dim, img_dim)),
                                    transforms.ToTensor(), 
                                    transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])])
    
    # get the tensor of the image
    content_image = transform(cont_img).unsqueeze(0).to(device)
    style_image = transform(style_img).unsqueeze(0).to(device)
    
    # init the network
    vgg = VGG().to(device).eval()
    
    # replace the MaxPool with the AvgPool layers
    for name, child in vgg.vgg.named_children():
        if isinstance(child, nn.MaxPool2d):
            vgg.vgg[int(name)] = nn.AvgPool2d(kernel_size=2, stride=2)
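    # (average pooling is swapped in because Gatys et al. report it gives
    # slightly smoother, more visually appealing reconstructions than max pooling)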
            
    # lock the gradients
    for param in vgg.parameters():
        param.requires_grad = False
    
    # get the content activations of the content image and detach them from the graph
    content_activations = vgg.get_content_activations(content_image).detach()
    
    # unroll the content activations
    content_activations = content_activations.view(512, -1)
    
    # get the style activations of the style image
    style_activations = vgg.get_style_activations(style_image)
    
    # for every layer in the style activations
    for i in range(len(style_activations)):

        # unroll the activations and detach them from the graph
        style_activations[i] = style_activations[i].squeeze().view(style_activations[i].shape[1], -1).detach()

    # calculate the gram matrices of the style image
    style_grams = [gram(style_activations[i]) for i in range(len(style_activations))]
    
    # generate the Gaussian noise
    noise = torch.randn(1, 3, img_dim, img_dim, device=device, requires_grad=True)
    
    # define the adam optimizer
    # pass the feature map pixels to the optimizer as parameters
    adam = optim.Adam(params=[noise], lr=0.01, betas=(0.9, 0.999))

    # run the iteration
    for iteration in range(num_iter):

        # zero the gradient
        adam.zero_grad()

        # get the content activations of the Gaussian noise
        noise_content_activations = vgg.get_content_activations(noise)

        # unroll the feature maps of the noise
        noise_content_activations = noise_content_activations.view(512, -1)

        # calculate the content loss
        content_loss_ = content_loss(noise_content_activations, content_activations)

        # get the style activations of the noise image
        noise_style_activations = vgg.get_style_activations(noise)

        # for every layer
        for i in range(len(noise_style_activations)):

            # unroll the noise style activations
            noise_style_activations[i] = noise_style_activations[i].squeeze().view(noise_style_activations[i].shape[1], -1)

        # calculate the noise gram matrices
        noise_grams = [gram(noise_style_activations[i]) for i in range(len(noise_style_activations))]

        # calculate the total weighted style loss
        style_loss = 0
        for i in range(len(style_activations)):
            N, M = noise_style_activations[i].shape[0], noise_style_activations[i].shape[1]
            style_loss += (gram_loss(noise_grams[i], style_grams[i], N, M) / 5.)
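        # dividing by 5 weights the five style layers equally
        # (the w_l = 1/5 layer weighting from Gatys et al.)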

        # put the style loss on device
        style_loss = style_loss.to(device)
            
        # calculate the total variation loss
        variation_loss = total_variation_loss(noise).to(device)

        # weight the final losses and add them together
        total_loss = content_weight * content_loss_ + style_weight * style_loss + variation_weight * variation_loss

        if iteration % print_every == 0:
            print("Iteration: {}, Content Loss: {:.3f}, Style Loss: {:.3f}, Var Loss: {:.3f}".format(iteration, 
                                                                                                     content_weight * content_loss_.item(),
                                                                                                     style_weight * style_loss.item(), 
                                                                                                     variation_weight * variation_loss.item()))

        # create the folder for the generated images
        if not os.path.exists('./generated/'):
            os.mkdir('./generated/')
        
        # generate the image
        if iteration % save_every == 0:
            save_image(noise.cpu().detach(), './generated/iter_{}.png'.format(iteration))

        # backprop
        total_loss.backward()
        
        # update parameters
        adam.step()
Example #9
transform_train = transforms.Compose([
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

transform_test = transforms.Compose([
    transforms.TenCrop(crop_size),
    transforms.Lambda(lambda crops: torch.stack(
        [transforms.ToTensor()(crop) for crop in crops]))
])
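# TenCrop turns each test image into a stack of 10 crops, so batches from this
# transform come out shaped (bs, ncrops, C, H, W). A common evaluation pattern
# (sketch, assuming a model and a loader built from transform_test):
#
#     bs, ncrops, c, h, w = inputs.size()
#     outputs = model(inputs.view(-1, c, h, w))       # fold crops into the batch dim
#     outputs = outputs.view(bs, ncrops, -1).mean(1)  # average predictions over crops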

if args.checkpoint is None:
    start_epoch = 0
    model = VGG(args.model_name)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)
else:
    checkpoint = torch.load(args.checkpoint)
    start_epoch = checkpoint['epoch'] + 1
    print('\nLoaded checkpoint from epoch %d.\n' % start_epoch)
    model = VGG(args.model_name)
    model.load_state_dict(checkpoint["model_weights"])
    optimizer = checkpoint["optimizer"]
    if args.adjust_optim is not None:
        print("Adjust optimizer....")
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=5e-4)
Example #10
# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))
# Make a grid from batch
out = torchvision.utils.make_grid(inputs)
#DI.imshow(out, train_mean, train_std)#, title=[class_names[x] for x in classes])

#num_train = 100
#train_loader = torch.utils.data.DataLoader(train_data, batch_size=50, shuffle=False, num_workers=4,
#                                           sampler=OverfitSampler(num_train))
#val_loader = torch.utils.data.DataLoader(val_data, batch_size=50, shuffle=False, num_workers=4)

model = VGG()
#solver = Solver(optim_args={"lr": 1e-2})
#solver.train(model, dataloaders['train'], dataloaders['val'], log_nth=1, num_epochs=5)
if model.is_cuda:
    model = model.cuda()

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
Solver.train_model(dataloaders,
                   dataset_sizes,
                   model,
                   criterion,
                   optimizer_ft,
                   exp_lr_scheduler,
                   num_epochs=25)
Example #11
    def train(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = VGG(self.model_type, True).to(device)

        loss_function = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=self.learning_rate,
                                    weight_decay=5e-4)

        start_epoch = 0
        end_epoch = self.epochs
        average_loss_list = []
        writer = SummaryWriter('logs')

        if self.load_model_dir is not None:
            checkpoint = torch.load(self.load_model_dir)
            model.load_state_dict(checkpoint['model_state_dict'])
            start_epoch += checkpoint['epoch']
            end_epoch += checkpoint['epoch']
            average_loss_list = checkpoint['average_loss_list']
            for idx, loss in enumerate(average_loss_list):
                writer.add_scalar("Training Loss Average", loss, idx + 1)

        mean, std = self.compute_mean_std(
            datasets.CIFAR100(self.dataset_dir, train=True, download=True))

        milestone = [60, 120, 160, 180, 200, 220]

        transform_ops = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        train_dataset = datasets.CIFAR100(self.dataset_dir,
                                          train=True,
                                          download=True,
                                          transform=transform_ops)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=self.batch_size,
                                                   shuffle=True,
                                                   num_workers=self.num_worker)
        train_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestone, gamma=0.2)

        model.train()

        running_loss = 0
        running_idx = 0
        running_temp_idx = 0

        for epoch in range(start_epoch, end_epoch):
            total_loss = 0
            for batch_idx, (x, y) in enumerate(train_loader):
                x = x.to(device)
                y = y.to(device)
                optimizer.zero_grad()

                predicted = model(x)
                loss = loss_function(predicted, y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

                running_loss += loss.item()
                running_temp_idx += 1
                if batch_idx % 100 == 0:
                    print(
                        'Train Epoch: {}  {:.2f}% Percent Finished. Cumulative Loss: {:.6f}'
                        .format(epoch + 1, 100 * batch_idx / len(train_loader),
                                total_loss))

                if running_temp_idx % 50 == 0:
                    running_idx += 1
                    writer.add_scalar("Running Loss", running_loss / 50,
                                      running_idx)
                    running_temp_idx = 0
                    running_loss = 0

            writer.add_scalar("Training Loss Average",
                              total_loss / len(train_loader), epoch + 1)
            print('Epoch {} Finished! Total Loss: {:.2f}'.format(
                epoch + 1, total_loss))

            print("---------------Test Initalized!------------------")
            accuracy = test("", 128, model=model)
            writer.add_scalar("Test Accuracy", accuracy, epoch + 1)

            train_scheduler.step()

            average_loss_list.append(total_loss / len(train_loader))
            if (epoch + 1) % 50 == 0:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'model_state_dict': model.state_dict(),
                        'average_loss_list': average_loss_list,
                        'model_type': self.model_type
                    }, self.model_save_dir +
                    "vgg-checkpoint-{}.pth".format(epoch + 1))

        torch.save(
            {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'average_loss_list': average_loss_list,
                'model_type': self.model_type
            }, self.model_save_dir + "vgg.pth")
Example #12
    contentImg = contentImg.unsqueeze(0)
    styleImg,contentImg,content_iq = util.luminance_transfer(styleImg.numpy(),contentImg.numpy())
    styleImg = Variable(torch.from_numpy(styleImg))
    contentImg = Variable(torch.from_numpy(contentImg))
else:
    styleImg = load_image(opt.style_image) # 1x3x512x512
    contentImg = load_image(opt.content_image) # 1x3x512x512

if(opt.cuda):
    styleImg = styleImg.cuda()
    contentImg = contentImg.cuda()

###############   MODEL   ####################
vgg = VGG()
vgg.load_state_dict(torch.load(opt.vgg_dir))
for param in vgg.parameters():
    param.requires_grad = False
if(opt.cuda):
    vgg.cuda()
###########   LOSS & OPTIMIZER   ##########
class GramMatrix(nn.Module):
    def forward(self, input):
        b, c, h, w = input.size()
        f = input.view(b, c, h * w)  # b x c x (h*w)
        # torch.bmm(batch1, batch2): batch1 b x m x p, batch2 b x p x n -> b x m x n
        G = torch.bmm(f, f.transpose(1, 2))  # f: b x c x (h*w), f.T: b x (h*w) x c -> b x c x c
        return G.div_(h * w)
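# G[i, j] sums the product of channel-i and channel-j activations over all spatial
# positions, i.e. the channel co-activation statistics that define "style";
# dividing by h*w keeps the loss comparable across feature maps of different sizes.
# Shape check (illustrative): a 1x64x32x32 input yields a 1x64x64 Gram matrix.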

class styleLoss(nn.Module):
    def forward(self, input, target):
Example #13
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)

testset = torch.utils.data.TensorDataset(imgs[-10000:], labels[-10000:])
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Model
print('==> Building model..')
net = VGG('VGG11')

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(net.parameters(), lr=0.01)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda: inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
Example #14
def train():
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data',
                                            train=True,
                                            download=False,
                                            transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data',
                                           train=False,
                                           download=False,
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    model = VGG(vars(args))
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lrate,
                                momentum=0.9,
                                weight_decay=5e-4)

    if args.use_cuda:
        model = model.cuda()

    if args.eval:
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        accuracy = model.evaluate(testloader)
        exit()

    total_size = len(trainloader)
    lrate = args.lrate
    best_score = 0.0
    scores = []
    for epoch in range(1, args.epochs + 1):
        model.train()
        for i, (image, label) in enumerate(trainloader):

            loss = model(image, label)
            model.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                print('Epoch = %d, step = %d / %d, loss = %.5f lrate = %.5f' %
                      (epoch, i, total_size, loss, lrate))

        model.eval()
        accuracy = model.evaluate(testloader)
        scores.append(accuracy)

        with open(args.model_dir + "_scores.pkl", "wb") as f:
            pkl.dump(scores, f)

        if best_score < accuracy:
            best_score = accuracy
            print('saving %s ...' % args.model_dir)
            torch.save(model.state_dict(), args.model_dir)

        if epoch % args.decay_period == 0:
            lrate *= args.decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lrate
Example #15
            epoch_loss = running_loss / len(data_loaders[phase].dataset)
            Loss_list[phase].append(epoch_loss)
            epoch_acc_species = corrects_species.double() / len(
                data_loaders[phase].dataset)

            Accuracy_list_species[phase].append(100 * epoch_acc_species)
            print('{} Loss: {:.4f}  Acc_species: {:.2%}'.format(
                phase, epoch_loss, epoch_acc_species))

            if phase == 'val' and epoch_acc_species > best_acc:
                best_acc = epoch_acc_species
                best_model_wts = copy.deepcopy(model.state_dict())
                print('Best val species Acc: {:.2%}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    #torch.save(model.state_dict(), 'best_model.pt')
    print('Best val species Acc: {:.2%}'.format(best_acc))
    return model, Loss_list, Accuracy_list_species


network = VGG('VGG11').to(device)
optimizer = optim.SGD(network.parameters(), lr=0.01,
                      momentum=0.9)  # raising lr to 0.05 converges a bit faster
criterion = nn.CrossEntropyLoss()
#exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) # Decay LR by a factor of 0.1 every 1 epochs
model, Loss_list, Accuracy_list_species = train_model(network,
                                                      criterion,
                                                      optimizer,
                                                      num_epochs=10)
Example #16
    styleImg, contentImg, content_iq = util.luminance_transfer(
        styleImg.numpy(), contentImg.numpy())
    styleImg = Variable(torch.from_numpy(styleImg))
    contentImg = Variable(torch.from_numpy(contentImg))
else:
    styleImg = load_image(opt.style_image)  # 1x3x512x512
    contentImg = load_image(opt.content_image)  # 1x3x512x512

if (opt.cuda):
    styleImg = styleImg.cuda()
    contentImg = contentImg.cuda()

###############   MODEL   ####################
vgg = VGG()
vgg.load_state_dict(torch.load(opt.vgg_dir))
for param in vgg.parameters():
    param.requires_grad = False
if (opt.cuda):
    vgg.cuda()


###########   LOSS & OPTIMIZER   ##########
class GramMatrix(nn.Module):
    def forward(self, input):
        b, c, h, w = input.size()
        f = input.view(b, c, h * w)  # bxcx(hxw)
        # torch.bmm(batch1, batch2, out=None)   #
        # batch1: bxmxp, batch2: bxpxn -> bxmxn #
        G = torch.bmm(f, f.transpose(
            1, 2))  # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
        return G.div_(h * w)
Example #17
def main():
    best_acc = 0
    start_epoch = args.start_epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    trainloader = getdata(args, train=True)
    testloader = getdata(args, train=False)

    model = VGG(args.attention, args.nclass)

    if args.gpu:
        if torch.cuda.is_available():
            model = model.cuda()
            cudnn.benchmark = True
        else:
            print(
                'There is no cuda available on this machine use cpu instead.')
            args.gpu = False

    criterion = nn.CrossEntropyLoss()
    optimizer = None
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    else:
        print(args.optimizer, 'is not a supported optimizer')
        return

    title = 'cifar-10-' + args.attention

    if args.evaluate:
        print('\nEvaluation only')
        assert os.path.isfile(
            args.evaluate), 'Error: checkpoint file not found!'
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        test_loss, test_acc = test(model, testloader, criterion, args.gpu)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: checkpoint file not found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint,
                                     state['attention'] + '-' + 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint,
                                     state['attention'] + '-' + 'log.txt'),
                        title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
            'Valid Acc.'
        ])

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        adjust_learning_rate(optimizer, epoch)

        train_loss, train_acc = train(model, trainloader, criterion, optimizer,
                                      epoch, args.gpu)
        test_loss, test_acc = test(model, testloader, criterion, args.gpu)
        if sys.version_info[0] == 3:
            train_acc = train_acc.cpu().numpy().tolist()[0]
            test_acc = test_acc.cpu().numpy().tolist()[0]
        logger.append(
            [state['lr'], train_loss, test_loss, train_acc, test_acc])

        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'attention': state['attention'],
            },
            is_best,
            checkpoint=args.checkpoint)
        print('epoch time: {:.1f}s'.format(time.time() - start_time))
        print(
            "epoch: {:3d}, lr: {:.8f}, train-loss: {:.3f}, test-loss: {:.3f}, train-acc: {:2.3f}, test_acc:, {:2.3f}"
            .format(epoch, state['lr'], train_loss, test_loss, train_acc,
                    test_acc))

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint,
                         state['attention'] + '-' + 'log.eps'))

    print('Best acc:', best_acc)