Example #1
def test_vgg16_bn(self):
    self.run_model_test(vgg16_bn(), train=False,
                        batch_size=BATCH_SIZE)
Example #2
def test_vgg16_bn(self):
    # VGG 16-layer model (configuration "D") with batch normalization
    x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
    self.exportTest(toC(vgg16_bn()), toC(x))
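The toC and exportTest names above are test-harness helpers that are not shown in this listing. Outside that harness, the export this test exercises can be reproduced with public APIs only; a minimal sketch (the output file name and opset version are arbitrary choices, and a plain tensor replaces the deprecated Variable wrapper):

import torch
from torchvision.models import vgg16_bn

model = vgg16_bn()                    # randomly initialised weights
model.eval()                          # export in inference mode
dummy = torch.randn(1, 3, 224, 224)   # same 3x224x224 input shape as above
torch.onnx.export(model, dummy, "vgg16_bn.onnx", opset_version=11)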
Example #3
def test_vgg16_bn(self):
    self.run_model_test(vgg16_bn(), train=False, batch_size=BATCH_SIZE)
Example #4
def construct_unet(n_cls):  # no weights inited
    model = vgg16_bn(pretrained=False)
    encoder_blocks = _get_encoder_blocks(model)
    encoder_channels = [64, 128, 256, 512, 1024]  # vgg16 channels

    return UNet(encoder_blocks, encoder_channels, n_cls)
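Example #4 depends on a project-specific _get_encoder_blocks helper and a UNet class that this listing does not include. The sketch below is a hypothetical version of such a helper, splitting vgg16_bn().features into stages at each max-pooling layer; the repository's real implementation may differ (its channel list ending in 1024 hints at an additional bottleneck stage beyond the plain VGG-16 stages):

import torch.nn as nn
from torchvision.models import vgg16_bn

def _get_encoder_blocks(model):
    # Group model.features into sequential stages, cutting after each MaxPool2d.
    blocks, current = [], []
    for layer in model.features:
        current.append(layer)
        if isinstance(layer, nn.MaxPool2d):
            blocks.append(nn.Sequential(*current))
            current = []
    return blocks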
Example #5
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create model
    if args.network_name == 'vgg19':
        model = vgg19_bn(pretrained=False)
    elif args.network_name == 'vgg16':
        model = vgg16_bn(pretrained=False)
    elif 'resnet' in args.network_name:
        model = models.__dict__[args.arch](pretrained=False)
    else:
        raise NotImplementedError

    # Initialize network
    for layer in model.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            if args.init == 'normal_kaiming':
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            elif args.init == 'normal_kaiming_fout':
                nn.init.kaiming_normal_(layer.weight,
                                        nonlinearity='relu',
                                        mode='fan_out')
            elif args.init == 'normal_xavier':
                nn.init.xavier_normal_(layer.weight)
            elif args.init == 'orthogonal':
                nn.init.orthogonal_(layer.weight)
            else:
                raise ValueError(
                    f"Unrecognised initialisation parameter {args.init}")

    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    #############################
    ####    Pruning code     ####
    #############################

    pruning_factor = args.pruning_factor
    keep_masks = []
    filename = ''
    if pruning_factor != 1:
        print(f'Pruning network iteratively for {args.num_steps} steps')
        keep_masks = iterative_pruning(model,
                                       train_loader,
                                       device,
                                       pruning_factor,
                                       prune_method=args.prune_method,
                                       num_steps=args.num_steps,
                                       mode=args.mode,
                                       num_batches=args.num_batches)

        apply_prune_mask(model, keep_masks)

    # File where to save training history
    run_name = (args.network_name + '_IMAGENET' + '_spars' +
                str(1 - pruning_factor) + '_variant' + str(args.prune_method) +
                '_train-frac' + str(args.frac_data_for_train) +
                f'_steps{args.num_steps}_{args.mode}' + f'_{args.init}' +
                f'_batch{args.num_batches}' + f'_rseed_{seed}')
    writer_name = 'runs/' + run_name
    writer = SummaryWriter(writer_name)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    iterations = 0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # Train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args, writer)

        # Evaluate on validation set
        iterations = epoch * len(train_loader)
        acc1 = validate(val_loader, model, criterion, args, writer, iterations)

        # Save checkpoint
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            # Save every 5 epochs and at the final epoch
            if (epoch + 1) % 5 == 0 or (epoch + 1) == args.epochs:
                if not os.path.exists('saved_models/'):
                    os.makedirs('saved_models/')
                save_name = ('saved_models/' + run_name + '_cross_entropy_' +
                             str(epoch + 1) + '.model')
                torch.save(model.state_dict(), save_name)