Example #1
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import os
import network
from utils import soft_cross_entropy, kldiv
from utils.visualizer import VisdomPlotter
from utils.misc import pack_images, denormalize
from dataloader import get_dataloader
import torchvision
import random
import numpy as np

vp = VisdomPlotter('15550', env='DFAD-caltech101')


def train(args, teacher, student, generator, device, train_loader, optimizer,
          epoch):
    teacher.eval()
    student.train()
    generator.train()
    optimizer_S, optimizer_G = optimizer

    for i in range(args.epoch_itrs):
        for k in range(5):
            z = torch.randn((args.batch_size, args.nz, 1, 1)).to(device)
            optimizer_S.zero_grad()
            fake = generator(z).detach()
            t_logit = teacher(fake)
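            # The snippet is truncated here; a plausible continuation of the
            # student update, mirroring Examples #4 and #5 below (a sketch,
            # not the verbatim original):
            s_logit = student(fake)
            loss_S = F.l1_loss(s_logit, t_logit.detach())
            loss_S.backward()
            optimizer_S.step()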
Example #2
import numpy as np
import argparse
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

import network
from utils.visualizer import VisdomPlotter
from utils.loss import *
from dataloader import get_dataloader
from quantization import quantize_model


vp = VisdomPlotter('8097', env='ZAQ-main')

def train(args, p_model, q_model, generator, optimizer, epoch):
    p_model.eval()
    q_model.train()
    generator.train()
    optimizer_Q, optimizer_G = optimizer

    inter_loss = SCRM().to(args.device)

    for i in range(args.epoch_itrs):
        for k in range(5):
            z = torch.randn((args.batch_size, args.nz, 1, 1)).to(args.device)
            optimizer_Q.zero_grad()
            fake = generator(z).detach()
            g_p, p_logit = p_model(fake, True)
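            # Truncated here; a hedged sketch of the quantized-model update by
            # analogy with the other examples. The second return value of
            # q_model and the exact loss terms/weighting are assumptions:
            g_q, q_logit = q_model(fake, True)
            # match the quantized model's logits and intermediate activations
            # to the full-precision model's
            loss_Q = F.l1_loss(q_logit, p_logit.detach()) + inter_loss(g_q, g_p)
            loss_Q.backward()
            optimizer_Q.step()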
Example #3
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms
import torchvision
import network
from utils import soft_cross_entropy, kldiv
from utils.visualizer import VisdomPlotter
from utils.misc import pack_images, denormalize
from dataloader import get_dataloader
from utils.stream_metrics import StreamSegMetrics
import random, os
import numpy as np
from PIL import Image

vp = VisdomPlotter('15550', env='DFAD-nyuv2')


def train(args, teacher, student, generator, device, train_loader, optimizer,
          epoch):
    teacher.eval()
    student.train()
    generator.train()
    optimizer_S, optimizer_G = optimizer
    for i in range(args.epoch_itrs):

        for k in range(5):
            z = torch.randn((args.batch_size, args.nz, 1, 1)).to(device)
            optimizer_S.zero_grad()
            fake = generator(z).detach()
            t_logit = teacher(fake)
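            # Truncated here; a plausible continuation mirroring the student
            # step shown in Examples #4 and #5 (a sketch, not the verbatim
            # original):
            s_logit = student(fake)
            loss_S = F.l1_loss(s_logit, t_logit.detach())
            loss_S.backward()
            optimizer_S.step()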
Example #4
from __future__ import print_function
import argparse
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import network
from utils.visualizer import VisdomPlotter
from utils.misc import pack_images, denormalize
from dataloader import get_dataloader
import os, random
import numpy as np
import torchvision

vp = VisdomPlotter('15550', env='DFAD-cifar')


def train(args, teacher, student, generator, device, optimizer, epoch):
    teacher.eval()
    student.train()
    generator.train()
    optimizer_S, optimizer_G = optimizer

    for i in range(args.epoch_itrs):
        for k in range(5):
            z = torch.randn((args.batch_size, args.nz, 1, 1)).to(device)
            optimizer_S.zero_grad()
            fake = generator(z).detach()
            t_logit = teacher(fake)
            s_logit = student(fake)
            loss_S = F.l1_loss(s_logit, t_logit.detach())
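            # The snippet is truncated here; the standard completion of this
            # gradient step:
            loss_S.backward()
            optimizer_S.step()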
Example #5
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms
import torchvision
import network
from utils import soft_cross_entropy, kldiv
from utils.visualizer import VisdomPlotter
from utils.misc import pack_images, denormalize
from dataloader import get_dataloader
from utils.stream_metrics import StreamSegMetrics
import random, os
import numpy as np

vp = VisdomPlotter('15550', env='DFAD-camvid')

def train(args, teacher, student, generator, device, train_loader, optimizer, epoch):
    teacher.eval()
    student.train()
    generator.train()
    optimizer_S, optimizer_G = optimizer

    for i in range(args.epoch_itrs):
        for k in range(5):
            z = torch.randn((args.batch_size, args.nz, 1, 1)).to(device)
            optimizer_S.zero_grad()
            fake = generator(z).detach()
            t_logit = teacher(fake)
            s_logit = student(fake)
            loss_S = F.l1_loss(s_logit, t_logit.detach())
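            # The snippet is truncated here; the standard completion of this
            # gradient step:
            loss_S.backward()
            optimizer_S.step()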
Example #6
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description='Teacher training for semantic segmentation')
    parser.add_argument('--num_classes', type=int, default=11)
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        metavar='N',
                        help='input batch size for training (default: 16)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=16,
                        metavar='N',
                        help='input batch size for testing (default: 16)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--data_root', type=str, default='data')
    parser.add_argument('--dataset',
                        type=str,
                        default='camvid',
                        choices=['camvid', 'nyuv2'],
                        help='dataset name (default: camvid)')
    parser.add_argument(
        '--model',
        type=str,
        default='deeplabv3_resnet50',
        choices=['deeplabv3_resnet50', 'segnet_vgg19', 'segnet_vgg16'],
        help='model name (default: deeplabv3_resnet50)')
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--step_size',
                        type=int,
                        default=100,
                        metavar='S',
                        help='StepLR step size in epochs (default: 100)')
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument(
        '--log_interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--test_only', action='store_true', default=False)
    parser.add_argument('--download', action='store_true', default=False)
    parser.add_argument('--scheduler', action='store_true', default=False)
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    print(args)

    global vp
    vp = VisdomPlotter('15550', 'teacher-seg-%s' % args.dataset)

    train_loader, test_loader = get_dataloader(args)
    model = get_model(args)

    if args.ckpt is not None:
        model.load_state_dict(torch.load(args.ckpt))
    model = model.to(device)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          momentum=args.momentum)

    best_result = 0
    if args.scheduler:
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              args.step_size,
                                              gamma=args.gamma)

    if args.test_only:
        results = test(args, model, device, test_loader)
        return

    for epoch in range(1, args.epochs + 1):
        print("Lr = %.6f" % (optimizer.param_groups[0]['lr']))
        train(args, model, device, train_loader, optimizer, epoch)
        if args.scheduler:
            # step the scheduler after the epoch's optimizer updates
            # (calling it first skips the initial lr in PyTorch >= 1.1)
            scheduler.step()
        results = test(args, model, device, test_loader)
        vp.add_scalar('mIoU', epoch, results['Mean IoU'])
        if results['Mean IoU'] > best_result:
            best_result = results['Mean IoU']
            torch.save(
                model.state_dict(),
                "checkpoint/teacher/%s-%s.pt" % (args.dataset, args.model))
    print("Best mIoU=%.6f" % best_result)