Beispiel #1
0
def make_network(model_name, class_num):
    """Construct a classification network selected by name.

    Args:
        model_name: 'resnet', 'mobilenet', or a raspnet variant name.
        class_num: number of output classes for the final layer.

    Returns:
        The network with its classification head sized to ``class_num``.
    """
    if model_name == 'resnet':
        network = resnet.resnet34(True)
        network.fc = nn.Linear(512, class_num)
        return network
    if model_name == 'mobilenet':
        network = mobilenet.mobilenet_v2(True)
        network.classifier[1] = nn.Linear(1280, class_num)
        return network
    # Anything else is treated as a raspnet variant name.
    return raspnet.raspnet(name=model_name, class_num=class_num)
Beispiel #2
0
def main_worker(gpu, ngpus_per_node, args):
    """Train and validate one model in a single worker process.

    Builds the network selected by ``args.arch``, optionally resumes from
    ``args.resume``, then runs the epoch loop (train -> validate -> save
    checkpoint) over the ImageFolder dataset rooted at ``args.data``.

    Args:
        gpu: GPU index to pin this process to; ``None`` means wrap the
            model in DataParallel over all visible devices.
        ngpus_per_node: unused here; part of the multiprocessing-spawn
            worker signature.
        args: parsed CLI namespace (arch, lr, data, resume, epochs, ...).
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = resnet.resnet50(g=args.groups, r=args.compression_rate, progressive=args.progressive) if args.arch == 'resnet50' else \
                            mobilenet.mobilenet_v2(g=args.groups, r=args.compression_rate, progressive=args.progressive)

    if args.gpu is not None:
        # Pin the model to the single requested device.
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # Only parameters still marked trainable (requires_grad) are optimized.
    optimizer = torch.optim.SGD(
        [p for p in model.parameters() if p.requires_grad],
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if 'epoch' in checkpoint.keys():
                args.start_epoch = checkpoint['epoch']
            # Accept both wrapped checkpoints ({'state_dict': ...}) and
            # bare state dicts.
            chck_state_dict = checkpoint[
                'state_dict'] if 'state_dict' in checkpoint.keys(
                ) else checkpoint
            model_state_dict = model.state_dict()
            # Keep only keys that exist in the current model.
            loaded_dict = {
                k: v
                for k, v in chck_state_dict.items() if k in model_state_dict
            }
            if not bool(loaded_dict
                        ):  # empty dictionary if model was trained in parallel
                # Retry with the 'module.' prefix that DataParallel adds
                # to every parameter name.
                loaded_dict = {
                    'module.' + k: v
                    for k, v in chck_state_dict.items()
                    if 'module.' + k in model_state_dict
                }
            model_state_dict.update(loaded_dict)
            model.load_state_dict(model_state_dict)
            if 'optimizer' in checkpoint.keys():
                optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # NOTE(review): the training pipeline uses Resize+CenterCrop with no
    # random augmentation — confirm this is intentional.
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        # Evaluation-only mode: run one validation pass and exit.
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # Checkpoint every epoch; 'epoch' is the next epoch to run on resume.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            name=args.arch,
            g=args.groups,
            r=args.compression_rate,
            mode=args.progressive)
Beispiel #3
0
import torch

from mobilenet import mobilenet_v2

# Build the pretrained model.
model = mobilenet_v2(pretrained=True)

# Dummy input: a batch of one 3-channel 224x224 image.
x = torch.randn(1, 3, 224, 224)

# Single forward pass to sanity-check the model.
out = model(x)
print('end')
Beispiel #4
0
def mobilenetv2():
    """Return a pretrained MobileNetV2 (download progress bar enabled)."""
    net = mobilenet.mobilenet_v2(pretrained=True, progress=True)
    return net
Beispiel #5
0
def mobilenet_v2(pretrained=False, progress=True):
    """Thin pass-through to ``mobilenet.mobilenet_v2`` with the same defaults."""
    return mobilenet.mobilenet_v2(pretrained=pretrained, progress=progress)
Beispiel #6
0
# Settings
# Use the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
network = 'mobilenet'
# Checkpoint path for the selected backbone.
model_dir = './model/{}/epoch-5.pth.tar'.format(network)
age_thresh = 6  # presumably the age cutoff for the binary label — TODO confirm
batch_size = 500
data_dir = r'./data'

if __name__ == '__main__':
    # define network
    if network == 'resnet':
        net = resnet.resnet18(True)
        net.fc = nn.Linear(512, 2)  # 2-class head
    elif network == 'mobilenet':
        net = mobilenet.mobilenet_v2(True)
        net.classifier[1] = nn.Linear(1280, 2)  # 2-class head
    else:
        raise NotImplementedError
    # Restore trained weights before moving the model to the device.
    net.load_state_dict(torch.load(model_dir)['state_dict'])
    net.to(device)

    transform_valid = transforms.Compose([
        transforms.ToPILImage(),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    # NOTE(review): this call is truncated in the visible chunk — the
    # remaining DataLoader arguments continue past this excerpt.
    valid_reader = DataLoader(reader.UTKDataLoader(data_dir,
                                                   'valid.hdf5',
Beispiel #7
0
def main():
	"""Train an age classifier on UTKFace.

	In curriculum mode (``load_prev``), warm-starts from a previously
	trained model with frozen parameters, then unfreezes them at
	``unfreeze_epoch``. Trains for ``epochs`` epochs, logging metrics to
	TensorBoard, and saves the final model to ``<curr_class_num>.pt``.
	All configuration comes from module-level globals.
	"""
	# get data (remember to substitute path to UTKFace dataset)
	x_train, y_train, x_valid, y_valid = utils.get_images(r'../UTKFace', age_thresh=custom_at, resize_shape=resize_shape)

	# define reader
	transform_train = transforms.Compose([
		transforms.ToPILImage(),
		transforms.RandomHorizontalFlip(),
		transforms.ToTensor(),
		transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
	])
	transform_valid = transforms.Compose([
		transforms.ToPILImage(),
		transforms.ToTensor(),
		transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
	])
	train_reader = DataLoader(utils.UTKDataLoader(x_train, y_train, tsfm=transform_train), batch_size=batch_size,
							  num_workers=4, shuffle=True)
	valid_reader = DataLoader(utils.UTKDataLoader(x_valid, y_valid, tsfm=transform_valid), batch_size=batch_size,
							  num_workers=4, shuffle=False)

	# network
	if model_type == 'large':
		if model_name == 'resnet':
			net = resnet.resnet34(True)
			net.fc = nn.Linear(512, prev_class_num)
		elif model_name == 'mobilenet':
			net = mobilenet.mobilenet_v2(True)
			net.classifier[1] = nn.Linear(1280, prev_class_num)
		else:
			raise NotImplementedError
	else:
		net = raspnet.raspnet(name=model_name, class_num=prev_class_num)
	# Log the network graph with a dummy input for TensorBoard.
	writer.add_graph(net, torch.rand(1, 3, *resize_shape))

	if load_prev == True:
		# NOTE load previously trained model here
		checkpoint = torch.load(prev_model)
		net.load_state_dict(checkpoint["state_dict"])
		for param in net.parameters():	# freeze parameters
			param.requires_grad = False

	# Replace the head for the new class count; these fresh parameters
	# are trainable even when the loaded backbone above was frozen.
	net.dense6_1 = nn.Linear(256, curr_class_num)
	#optimizer.load_state_dict(checkpoint["opt_dict"])

	net.to(device)
	net.train()	# set training mode

	# define loss function
	criterion = nn.CrossEntropyLoss().to(device)
	# set up optimizer
	if opt == 'Adam':
		optimizer = optim.Adam(net.parameters(), lr=learning_rate, amsgrad=True)
	else:
		optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
	scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay_step, gamma=decay_rate)

	"""
	# NOTE manually reset learning rate to desired value
	for g in optimizer.param_groups:
		g['lr'] = learning_rate
	"""

	# train
	for epoch in range(1, epochs+1):
		
		if load_prev:	# only do this when in curriculum training mode
			if epoch == unfreeze_epoch:
				print("### epoch: "+str(epoch)+" unfreezed parameters!")
				for param in net.parameters():
					param.requires_grad = True

		running_loss = 0.0
		pbar = tqdm(train_reader)
		for i, data in enumerate(pbar):
			inputs, labels = data
			inputs, labels = inputs.float().to(device), labels.long().to(device)
			optimizer.zero_grad()
			outputs = net(inputs)
			loss = criterion(outputs, labels)
			loss.backward()
			optimizer.step()
			# NOTE(review): the scheduler is stepped once per *batch*, so
			# decay_step milestones count optimizer steps, not epochs —
			# confirm this is intended.
			scheduler.step()

			running_loss += loss.item()
			if i % verb_step == verb_step - 1:
				pbar.set_description('Epoch {} Step {}: train cross entropy loss: {:.4f}'.
									 format(epoch, i + 1, running_loss / verb_step))
				running_loss = 0.0

		# validation
		correct = 0
		total = 0
		truth = []
		pred = []
		with torch.no_grad():
			for data in valid_reader:
				inputs, labels = data
				inputs, labels = inputs.float().to(device), labels.long().to(device)
				outputs = net(inputs)
				_, predicted = torch.max(outputs.data, 1)
				total += labels.size(0)
				correct += (predicted == labels).sum().item()
				truth.extend(labels.cpu().numpy())
				pred.extend(predicted.cpu().numpy())

		# Precision/recall/F1 against class 0.
		p, r, f1 = utils.f1_score(truth, pred, 0)
		print('Epoch {}: valid accuracy: {:.2f}, precision: {:.2f}, recall: {:.2f}, f1: {:.2f}'.format(
			epoch, 100 * correct / total, p, r, f1))
		writer.add_scalar('valid/acc', correct / total, epoch)
		writer.add_scalar('valid/precision', p, epoch)
		writer.add_scalar('valid/recall', r, epoch)
		writer.add_scalar('valid/f1', f1, epoch)
		
		#if load_prev == False:
		# NOTE(review): dead branch — periodic checkpointing is disabled.
		if False:
			if epoch % save_epoch == 0:
				save_name = os.path.join(save_dir, 'epoch-{}.pth.tar'.format(epoch))
				torch.save({
					'epoch': epochs,
					'state_dict': net.state_dict(),
					'opt_dict': optimizer.state_dict(),
				}, save_name)
				print('Saved model at {}'.format(save_name))

	print('Finished training')
	save_name = str(curr_class_num)+'.pt'
	# NOTE(review): 'epoch' stores the configured total epochs, not the
	# epoch the model was actually saved at — confirm intended.
	torch.save({
		'epoch': epochs,
		'state_dict': net.state_dict(),
		'opt_dict': optimizer.state_dict(),
	}, save_name)
	print('Saved model at {}'.format(save_name))
Beispiel #8
0
def init():
    """Load the pretrained MobileNetV2 into the module-level ``model``."""
    global model
    model = mobilenet_v2(pretrained=True)
    # Flask-style (body, status) response tuple.
    return 'OK', 200
Beispiel #9
0
def main():
    """Train an age classifier on UTKFace from scratch.

    Builds the network selected by the module-level ``model_type`` /
    ``model_name`` globals, trains for ``epochs`` epochs with Adam and a
    MultiStepLR schedule, validates each epoch, and checkpoints every
    ``save_epoch`` epochs into ``save_dir``.
    """
    # get data
    x_train, y_train, x_valid, y_valid = utils.get_images(
        r'./data/UTKFace', resize_shape=resize_shape)

    # define reader
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    transform_valid = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    train_reader = DataLoader(utils.UTKDataLoader(x_train,
                                                  y_train,
                                                  tsfm=transform_train),
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True)
    valid_reader = DataLoader(utils.UTKDataLoader(x_valid,
                                                  y_valid,
                                                  tsfm=transform_valid),
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=False)

    # network
    if model_type == 'large':
        if model_name == 'resnet':
            net = resnet.resnet34(True)
            net.fc = nn.Linear(512, class_num)
        elif model_name == 'mobilenet':
            net = mobilenet.mobilenet_v2(True)
            net.classifier[1] = nn.Linear(1280, class_num)
        else:
            raise NotImplementedError
    else:
        net = raspnet.raspnet(name=model_name, class_num=class_num)
    net.to(device)

    # define loss
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(net.parameters(), lr=learning_rate, amsgrad=True)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=[10, 15],
                                               gamma=0.1)

    # train
    for epoch in range(epochs):
        running_loss = 0.0
        pbar = tqdm(train_reader)
        for i, data in enumerate(pbar):
            inputs, labels = data
            inputs, labels = inputs.float().to(device), labels.long().to(
                device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # NOTE(review): the scheduler is stepped once per *batch*,
            # so the [10, 15] milestones count optimizer steps, not
            # epochs — confirm this is intended.
            scheduler.step()

            running_loss += loss.item()
            if i % verb_step == verb_step - 1:
                pbar.set_description(
                    'Epoch {} Step {}: train cross entropy loss: {:.4f}'.
                    format(epoch + 1, i + 1, running_loss / verb_step))
                running_loss = 0.0

        # validation
        correct = 0
        total = 0
        truth = []
        pred = []
        with torch.no_grad():
            for data in valid_reader:
                inputs, labels = data
                inputs, labels = inputs.float().to(device), labels.long().to(
                    device)
                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                truth.extend(labels.cpu().numpy())
                pred.extend(predicted.cpu().numpy())
        # Precision/recall/F1 against class 0.
        p, r, f1 = utils.f1_score(truth, pred, 0)
        print(
            'Epoch {}: valid accuracy: {:.2f}, precision: {:.2f}, recall: {:.2f}, f1: {:.2f}'
            .format(epoch + 1, 100 * correct / total, p, r, f1))

        if epoch % save_epoch == 0 and epoch != 0:
            save_name = os.path.join(save_dir,
                                     'epoch-{}.pth.tar'.format(epoch))
            # NOTE(review): 'epoch' stores the configured total epochs,
            # not the epoch being saved — confirm intended.
            torch.save(
                {
                    'epoch': epochs,
                    'state_dict': net.state_dict(),
                    'opt_dict': optimizer.state_dict(),
                }, save_name)
            print('Saved model at {}'.format(save_name))

    print('Finished training')
                    metavar='ARCH',
                    default='resnet50',
                    choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) +
                    ' (default: resnet50)')
# Compression hyper-parameters; must match the checkpoint's training
# configuration for the state dict keys/shapes to line up.
parser.add_argument('-g', '--groups', default=4, type=int)
parser.add_argument('-r', '--compression-rate', default=2.0, type=float)
parser.add_argument('--progressive',
                    action='store_true',
                    help='compression mode: ' + ' | '.join(modes) +
                    ' (default: uniform)')

args = parser.parse_args()

# Build the architecture selected on the command line.
model = resnet50(pretrained=False, g=args.groups, r=args.compression_rate, progressive=args.progressive) if args.arch == 'resnet50' else \
                 mobilenet_v2(pretrained=False, g=args.groups, r=args.compression_rate, progressive=args.progressive)

# Accept both wrapped checkpoints ({'state_dict': ...}) and bare state dicts.
checkpoint = torch.load(args.checkpoint)
checkpoint = checkpoint['state_dict'] if 'state_dict' in checkpoint.keys(
) else checkpoint
state_dict = model.state_dict()
# Keep only keys present in the freshly built model.
loaded_dict = {k: v for k, v in checkpoint.items() if k in state_dict}
if not bool(loaded_dict):  # empty dictionary if model was trained in parallel
    # Strip DataParallel's 'module.' prefix (7 characters) and retry.
    loaded_dict = {
        k[7:]: v
        for k, v in checkpoint.items() if k[7:] in state_dict
    }
state_dict.update(loaded_dict)
model.load_state_dict(state_dict)
model.cuda()
model = nn.DataParallel(model)