Code Example #1
File: tests.py  Project: IgorDavidyuk/mobilenetv3_FPN
def test_loaded_weights():
    # Check that load_pretrained_fpn copies the pretrained weights exactly
    torch.backends.cudnn.deterministic = True
    path = '/home/davidyuk/Projects/backbones/pytorch-mobilenet-v3/mobilenetv3_small_67.4.pth.tar'
    mn3_fpn = MobileNetV3_forFPN()
    mn3_fpn = load_pretrained_fpn(mn3_fpn, path)

    mobNetv3 = MobileNetV3(mode='small')
    state_dict = torch.load(path, map_location='cpu')
    mobNetv3.load_state_dict(state_dict)
    mobNetv3 = mobNetv3.features[:12]
    for param, base_param in zip(mn3_fpn.parameters(), mobNetv3.parameters()):
        assert ((param == base_param).all()), 'params differ'
    #print(len(tuple(mn3_fpn.parameters())),len(tuple(mobNetv3.parameters())))
    # mobNetv3.eval()
    # mn3_fpn.eval()

    image = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        output = mn3_fpn(image)
        output1 = mobNetv3(image)

    if (output == output1).all():
        print('test passed')
    else:
        print('test failed')
    torch.backends.cudnn.deterministic = False
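
The comparison above requires bit-exact outputs. A tolerance-based variant of the final check, using torch.allclose (a minimal sketch reusing the output and output1 tensors computed above; the tolerances are illustrative, not values from the project):

if torch.allclose(output, output1, rtol=1e-5, atol=1e-7):
    print('test passed')
else:
    print('test failed')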
Code Example #2
def main():
    # Parse the JSON arguments
    config_args = parse_args()

    # Create the experiment directories
    #_, config_args.summary_dir, config_args.checkpoint_dir = create_experiment_dirs(
    #    config_args.experiment_dir)

    model = MobileNetV3(n_class=200,
                        input_size=64,
                        classify=config_args.classify)

    if config_args.cuda:
        model.cuda()
        cudnn.enabled = True
        cudnn.benchmark = True

    print("Loading Data...")
    data = TinyImagenet(config_args)
    print("Data loaded successfully\n")

    trainer = Train(model, data.trainloader, data.testloader, config_args)

    if config_args.to_train:
        try:
            print("Training...")
            trainer.train()
            print("Training Finished\n")
        except KeyboardInterrupt:
            pass

    if config_args.to_test:
        print("Testing...")
        trainer.test(data.testloader)
        print("Testing Finished\n")
Code Example #3
 def __init__(self, device="cpu", jit=False):
     """ Required """
     self.device = device
     self.jit = jit
     self.model = MobileNetV3()
     if self.jit:
         self.model = torch.jit.script(self.model)
     input_size = (1, 3, 224, 224)
     self.example_inputs = (torch.randn(input_size),)
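
Assuming the __init__ above belongs to a small benchmark wrapper class (called Model here purely for illustration), a typical forward pass with the stored example inputs would be:

# Hypothetical usage of the wrapper; the class name Model is assumed
m = Model(device="cpu", jit=False)
m.model.eval()
with torch.no_grad():
    out = m.model(*m.example_inputs)
print(out.shape)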
Code Example #4
def load_pretrained_fpn(model, path):
    '''
    Copy weights from a checkpoint into the given model via a vanilla MobileNetV3 instance.
    '''
    import torch
    mobNetv3 = MobileNetV3(mode='small')
    state_dict = torch.load(path, map_location='cpu')
    mobNetv3.load_state_dict(state_dict)
    for param, base_param in zip(model.parameters(), mobNetv3.parameters()):
        if param.size() == base_param.size():
            param.data = base_param.data
        else:
            print('size mismatch:', param.size(), base_param.size())
    return model
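
Note that param.data = base_param.data makes the target model share parameter storage with the temporary MobileNetV3 instance. If independent copies are preferred, the loop can use an in-place copy instead (a sketch of the alternative; behaviour is otherwise the same as above):

with torch.no_grad():
    for param, base_param in zip(model.parameters(), mobNetv3.parameters()):
        if param.size() == base_param.size():
            param.copy_(base_param)
        else:
            print('size mismatch:', param.size(), base_param.size())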
Code Example #5
def convert_mobilenetv3():
    num_classes = 9
    model = MobileNetV3(n_class=num_classes,
                        mode="small",
                        dropout=0.2,
                        width_mult=1.0)
    # model = MobileNetV3(n_class=num_classes, mode="large", dropout=0.2, width_mult=1.0)
    # Load the model parameters
    path = r"checkpoint/mobilenetv3/000/moblienetv3_s_my_acc=65.4676.pth"
    to_path = r"checkpoint/mobilenetv3/000/moblienetv3_s_my_acc=65.4676.onnx"

    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["net"])
    print("loaded model with acc:{}".format(checkpoint["acc"]))
    model.cuda()
    # dummy_input = torch.randn(10, 3, 224, 224)
    dummy_input = torch.randn(1, 3, 224, 224, device='cuda')
    torch.onnx.export(model, dummy_input, to_path, verbose=True)
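
After export, the ONNX graph can be sanity-checked against the PyTorch model, assuming the onnx and onnxruntime packages are installed (a sketch that reuses model and to_path from the function above; not part of the original script):

import numpy as np
import onnx
import onnxruntime as ort

onnx.checker.check_model(onnx.load(to_path))  # structural validity check
sess = ort.InferenceSession(to_path, providers=['CPUExecutionProvider'])
x = np.random.randn(1, 3, 224, 224).astype(np.float32)
onnx_out = sess.run(None, {sess.get_inputs()[0].name: x})[0]
model.eval()
with torch.no_grad():
    torch_out = model(torch.from_numpy(x).cuda()).cpu().numpy()
print('max abs diff:', np.abs(onnx_out - torch_out).max())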
Code Example #6
def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Note: the cached dataset includes the transforms it was saved with
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
    else:
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Note: the cached dataset includes the transforms it was saved with
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
    else:
        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

    print("Creating model")
    if args.mode == 'large':
        model = MobileNetV3(mode='large')
    elif args.mode == 'small':
        model = MobileNetV3(mode='small')
    else:
        raise ValueError(
            "Unsupported MobileNetV3 mode '{}': expected 'small' or 'large'".format(args.mode))

    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.data_parallel:
        model = torch.nn.DataParallel(model)
        model_without_ddp = model.module

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=args.lr_step_size,
                                                   gamma=args.lr_gamma)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device, None)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device,
                        epoch, args.print_freq)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device, epoch)
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
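
The helper _get_cache_path used above is not shown in this snippet. A plausible implementation, modeled on the torchvision classification reference script and not necessarily identical to this project's version, derives a cache file name from a hash of the dataset directory:

import hashlib
import os

def _get_cache_path(filepath):
    # Hash the dataset path so each directory gets a stable cache file
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets",
                              "imagefolder", h[:10] + ".pt")
    return os.path.expanduser(cache_path)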
Code Example #7
from utils import calc_dataset_stats
from dataset import custom

from dataset.tinyimagenet import TinyImagenet
from mobilenetv3 import MobileNetV3
from train import Train
from utils import parse_args, create_experiment_dirs
import torch
import os

if __name__ == '__main__':

    model = MobileNetV3(n_class=200, input_size=64)
    model.cuda()
    model_dict = model.state_dict()
    config_args = parse_args()

    print("Loading Data...")
    data = TinyImagenet(config_args)
    print("Data loaded successfully\n")

    trainer = Train(model, data.trainloader, data.testloader, config_args)
    trainer.train()