Code Example #1
import torch.nn as nn
# Repo-local helpers assumed importable here: _create_single_cpu_model,
# model_refinery_wrapper, refinery_loss, data_parallel, and Margloss.


def create_model(model_name,
                 model_state_file=None,
                 gpus=(0, 1, 2, 3),  # tuple default avoids the mutable-default pitfall
                 label_refinery=None,
                 label_refinery_state_file=None,
                 coslinear=True,
                 scale=5.0):
    model = _create_single_cpu_model(model_name, model_state_file, coslinear)
    if label_refinery is not None:
        assert label_refinery_state_file is not None, "Refinery state is None."
        label_refinery = _create_single_cpu_model(label_refinery,
                                                  label_refinery_state_file,
                                                  coslinear)
        model = model_refinery_wrapper.ModelRefineryWrapper(
            model, label_refinery, scale)
        loss = refinery_loss.RefineryLoss(cosln=coslinear, scl=scale)
    else:
        if coslinear:
            print('Using Margloss')
            loss = Margloss(s=scale)
        else:
            print('Using CrossEntropyLoss')
            # loss = F.cross_entropy
            loss = nn.CrossEntropyLoss()

    if len(gpus) > 0:
        model = model.cuda()
        loss = loss.cuda()
    if len(gpus) > 1:
        model = data_parallel.DataParallel(model, device_ids=gpus)
    return model, loss
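
For context, a minimal usage sketch of the factory above; the model name, GPU id, and the data loader are placeholders, not part of the source:

# Hypothetical usage of create_model; names and the loader are placeholders.
model, loss = create_model('resnet18',
                           gpus=(0,),            # single GPU
                           label_refinery=None,  # Margloss / cross-entropy path
                           coslinear=True,
                           scale=5.0)

for images, labels in train_loader:  # assumes an existing DataLoader
    logits = model(images.cuda(non_blocking=True))
    batch_loss = loss(logits, labels.cuda(non_blocking=True))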
Code Example #2
import torch.nn as nn
# Same repo-local helpers as above; this variant has no cosine-linear
# head or scale parameter.


def create_model(model_name,
                 model_state_file=None,
                 gpus=(),  # empty tuple: CPU by default; avoids a mutable default
                 label_refinery=None,
                 label_refinery_state_file=None):
    model = _create_single_cpu_model(model_name, model_state_file)
    if label_refinery is not None:
        assert label_refinery_state_file is not None, "Refinery state is None."
        label_refinery = _create_single_cpu_model(label_refinery,
                                                  label_refinery_state_file)
        model = model_refinery_wrapper.ModelRefineryWrapper(
            model, label_refinery)
        loss = refinery_loss.RefineryLoss()
    else:
        loss = nn.CrossEntropyLoss()

    if len(gpus) > 0:
        model = model.cuda()
        loss = loss.cuda()
    if len(gpus) > 1:
        model = data_parallel.DataParallel(model, device_ids=gpus)
    return model, loss
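
This is the plainer, label-refinery-style variant of the same factory. A short sketch of the refinery path, with placeholder model names and state-file paths:

# Hypothetical call; model names and state-file paths are placeholders.
model, loss = create_model('alexnet',
                           gpus=[0, 1],
                           label_refinery='resnet50',
                           label_refinery_state_file='resnet50.state')
# With a refinery, model is a ModelRefineryWrapper whose forward pass also
# yields the refinery's soft labels, which RefineryLoss trains against in
# place of the hard targets.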
Code Example #3
import os
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from PIL import Image
# Repo-local pieces assumed in scope: models (a torchvision-style model zoo
# that accepts dropout/dropconnect kwargs), refinery_loss, Quantization,
# PruneOp, train, validate, save_checkpoint, and the module-level globals
# best_acc1, signed, quan_modules, target_sparsity.


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    ## teacher model
    teacher_model = models.__dict__[args.teacher](
        pretrained=args.pretrained_teacher)
    teacher_model = torch.nn.DataParallel(teacher_model).cuda()

    for param in teacher_model.parameters():
        param.requires_grad = False

    ## student model
    model = models.__dict__[args.arch](pretrained=False,
                                       dropout=args.dropout,
                                       dropconnect=args.dropconnect)

    print("adding quantization ops ({}-bit)...".format(args.act_bit_width))
    scales = np.load(args.scales)
    idx = 0
    # `signed` and `quan_modules` are module-level globals: per-layer
    # signedness flags and the list of collected Quantization modules.
    for m in model.modules():
        if isinstance(m, Quantization):
            m.set_quantization_parameters(signed[idx], args.act_bit_width,
                                          scales[idx])
            quan_modules.append(m)
            idx += 1

    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    criterion_refinery = refinery_loss.RefineryLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    if 'alexnet' in args.arch:
        input_size = 227
    else:
        input_size = 224

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # Uniform per-channel std of 0.225; the usual ImageNet std is
    # [0.229, 0.224, 0.225].
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.225, 0.225, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            #transforms.RandomResizedCrop(input_size),
            transforms.RandomResizedCrop(input_size,
                                         interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    args.epoch_size = len(train_dataset) // args.batch_size

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose([
                #transforms.Resize(256),
                transforms.Resize(int(input_size / 0.875),
                                  interpolation=Image.BICUBIC),  # 256 for a 224 input
                transforms.CenterCrop(input_size),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    global prune_op
    # `target_sparsity` is defined at module level in the original script.
    prune_op = PruneOp(model, target_sparsity)
    # prune_op.init_pruning()

    if args.pretrained:
        state_dict = torch.load(args.pretrained)
        if 'state_dict' in state_dict:
            prune_op.set_masks(state_dict['masks'])
            state_dict = state_dict['state_dict']

        new_state_dict = OrderedDict()
        for key_ori, key_pre in zip(model.state_dict().keys(),
                                    state_dict.keys()):
            new_state_dict[key_ori] = state_dict[key_pre]
        model.load_state_dict(new_state_dict)

        prune_op.init_pruning()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            prune_op.set_masks(checkpoint['masks'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    print(args)
    print('Param sparsity:', prune_op.get_sparsity())

    prune_op.mask_params()

    # enable feature map quantization
    for q_module in quan_modules:
        if q_module.signed is not None:
            q_module.enable_quantization()

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    print('training...')
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, teacher_model, model, criterion_refinery,
              optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        print('BEST_ACC1:', best_acc1)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                    'masks': prune_op.get_masks(),
                },
                is_best,
                path=args.save_path)
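
main_worker follows the launch convention of the official PyTorch ImageNet example. A sketch of a typical launcher, assuming args already carries the fields referenced above:

import torch
import torch.multiprocessing as mp


def main(args):
    # Hypothetical launcher in the style of the PyTorch ImageNet example.
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # One worker process per GPU; world_size becomes the total
        # number of processes across all nodes.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)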