import torch


def get_flops(net, input_shape=(1, 3, 300, 300)):
    from flops_benchmark import add_flops_counting_methods
    # A dummy all-ones input suffices: the FLOP count depends only on
    # tensor shapes. torch.autograd.Variable has been a no-op wrapper since
    # PyTorch 0.4, so a plain tensor is passed directly.
    input = torch.ones(input_shape)

    net = add_flops_counting_methods(net)
    net = net.train()
    net.start_flops_count()

    _ = net(input)

    # The counter reports two FLOPs per multiply-add; halving yields GMACs.
    return net.compute_average_flops_cost() / 1e9 / 2
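A minimal usage sketch; the torchvision model is only an illustration, any nn.Module works:

# Hypothetical usage: count GFLOPs for a torchvision ResNet-18 at 224x224.
import torchvision.models as models

if __name__ == '__main__':
    net = models.resnet18()
    gflops = get_flops(net, input_shape=(1, 3, 224, 224))
    print('ResNet-18: {:.2f} GFLOPs'.format(gflops))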
Example #2
def get_flops(net, input_size=(300, 300)):
    from flops_benchmark import add_flops_counting_methods
    input_size = (1, 3, input_size[0], input_size[1])
    # Random data is fine: only the tensor shape affects the count.
    input = torch.randn(input_size).cuda()  # this variant requires a GPU

    net = add_flops_counting_methods(net)
    net = net.cuda().eval()  # unlike Example #1, counts in eval mode
    net.start_flops_count()

    _ = net(input)

    return net.compute_average_flops_cost() / 1e9 / 2
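A hedged sweep sketch showing how the count scales with resolution; it assumes a CUDA device (as this variant requires) and that `net` is any model, e.g. the ResNet-18 above:

# Hypothetical usage: GFLOPs at several input resolutions (CUDA required).
for side in (224, 300, 512):
    print('{0}x{0}: {1:.2f} GFLOPs'.format(side, get_flops(net, (side, side))))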
Example #3
def get_flops(net, input_size=(1, 3, 224, 224), method='benchmark'):
    from flops_benchmark import add_flops_counting_methods
    if method == 'profile':
        # `profile` is expected to come from the surrounding project
        # (a THOP-style profiler returning FLOPs and a parameter count).
        flops, param = profile(net, input_size)
        flops = flops / 1e6  # MFLOPs
    else:
        inputs = torch.randn(input_size)

        net = add_flops_counting_methods(net)
        net = net.eval()
        net.start_flops_count()

        # `volatile=True` was removed in PyTorch 0.4; torch.no_grad() is
        # the modern equivalent for an inference-only pass.
        with torch.no_grad():
            _ = net(inputs)
        # the counter reports two FLOPs per multiply-add, hence the /2
        flops = net.compute_average_flops_cost() / 1e6 / 2
    return flops
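A usage sketch for the counter-based path; the 'profile' branch depends on a `profile` helper defined elsewhere in the project, so only method='benchmark' is exercised here (which layer types contribute depends on the hooks flops_benchmark registers):

# Hypothetical usage: MFLOPs of a small CNN via the 'benchmark' method.
import torch.nn as nn

toy = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))
print('toy CNN: {:.1f} MFLOPs'.format(get_flops(toy)))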
Example #4
def main():
    print(time.ctime())
    global args, best_prec1
    args = parser.parse_args()
    torch.manual_seed(args.seed)

    torch.cuda.manual_seed(args.seed)

    if args.visdom:
        global plotter
        plotter = VisdomLinePlotter(env_name=args.expname)

    # set the target rates for each layer: start at 1.0 everywhere,
    # then lower layers 7-29 to a 0.5 target
    target_rates_list = [1.0] * 33
    for i in range(7, 30):
        target_rates_list[i] = 0.5
    target_rates = dict(enumerate(target_rates_list))

    model = ResNet101_ImageNet()

    # optionally initialize from pretrained
    if args.pretrained:
        latest_checkpoint = args.pretrained
        if os.path.isfile(latest_checkpoint):
            print("=> loading checkpoint '{}'".format(latest_checkpoint))
            # TODO: clean this part up
            checkpoint = torch.load(latest_checkpoint)
            state = model.state_dict()
            loaded_state_dict = checkpoint
            # copy matching keys, remapping torchvision-style names
            # ('fc' -> 'linear', 'downsample' -> 'shortcut') where needed
            for k in loaded_state_dict:
                if k in state:
                    state[k] = loaded_state_dict[k]
                else:
                    if 'fc' in k:
                        state[k.replace('fc', 'linear')] = loaded_state_dict[k]
                    if 'downsample' in k:
                        state[k.replace('downsample',
                                        'shortcut')] = loaded_state_dict[k]
            model.load_state_dict(state)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                latest_checkpoint, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(latest_checkpoint))

    model = torch.nn.DataParallel(model).cuda()
    model = add_flops_counting_methods(model)
    model.start_flops_count()

    # Data loading code. The ImageNet directories and normalization are kept
    # for reference but unused: the FacesDataset below uses hard-coded paths.
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = FacesDataset(
        "../images/train",
        "../images/train_labels.csv",
        transforms.Compose([
            transforms.ColorJitter(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),  # default resample/expand/center
            #transforms.Scale(224),
            transforms.ToTensor(),
            #normalize,
        ]))

    val_dataset = FacesDataset("../images/val", "../images/val_labels.csv",
                               transforms.Compose([transforms.ToTensor()]))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=2)

    # optionally resume from a checkpoint
    if args.resume:
        latest_checkpoint = os.path.join(args.resume, 'checkpoint.pth.tar')
        if os.path.isfile(latest_checkpoint):
            print("=> loading checkpoint '{}'".format(latest_checkpoint))
            checkpoint = torch.load(latest_checkpoint)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                latest_checkpoint, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(latest_checkpoint))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    # the classifier ('fc') parameters train at args.lrfact times the base LR
    fc_params = [p for n, p in model.named_parameters() if 'fc' in n]
    base_params = [p for n, p in model.named_parameters() if 'fc' not in n]
    optimizer = optim.SGD([{'params': fc_params,
                            'lr': args.lrfact * args.lr,
                            'weight_decay': args.weight_decay},
                           {'params': base_params,
                            'lr': args.lr,
                            'weight_decay': args.weight_decay}],
                          momentum=args.momentum)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))

    if args.test:
        # evaluation-only mode; epoch 60 just keys the accumulator and plots
        validate(val_loader, model, criterion, 60, target_rates)
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, target_rates)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, target_rates)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
    print('Best accuracy: ', best_prec1)
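main() relies on helpers defined elsewhere in the script (adjust_learning_rate, save_checkpoint, train), whose bodies are not shown in this listing. Below is a minimal sketch of the first two in the style of the stock PyTorch ImageNet example; the 30-epoch decay schedule and checkpoint filenames are assumptions, not taken from this codebase:

import shutil

def adjust_learning_rate(optimizer, epoch):
    # Assumed schedule: decay both LRs by 10x every 30 epochs; group 0 is
    # the 'fc' group created in main(), which keeps its args.lrfact factor.
    decay = 0.1 ** (epoch // 30)
    optimizer.param_groups[0]['lr'] = args.lrfact * args.lr * decay
    optimizer.param_groups[1]['lr'] = args.lr * decay

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Assumed convention: always write the latest state, and copy it to a
    # separate file whenever it is the best seen so far.
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')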
Example #5
def validate(val_loader, model, criterion, epoch, target_rates):
    """Perform validation on the validation set"""
    print(time.ctime())
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    accumulator = ActivationAccum_img(epoch)
    activations = AverageMeter()

    # Temperature of the Gumbel Softmax
    # We simply keep it fixed
    temp = 1

    # switch to evaluate mode
    model.eval()

    model = add_flops_counting_methods(model)
    model.start_flops_count()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            # non_blocking=True replaces the pre-0.4 `async=True` flag
            # (`async` is a reserved word since Python 3.7); the Variable
            # wrappers are likewise obsolete and dropped
            target = target.cuda(non_blocking=True)
            input = input.cuda()

            # compute output (per-batch FLOPs can be read afterwards via
            # model.compute_average_flops_cost(), cf. the helpers above)
            output, activation_rates = model(input, temperature=temp)

            # classification loss
            loss = criterion(output, target)
            # average the open rates of all gates that have a target below 1
            acts = 0
            for j, act in enumerate(activation_rates):
                if target_rates[j] <= 1:
                    acts += torch.mean(act)
                else:
                    acts += 1
            # averaging again is important when using DataParallel
            acts = torch.mean(acts / len(activation_rates))

            # skip batches whose activation statistic came out NaN
            if math.isnan(acts.item()):
                continue

            #print("Here")

            # accumulate statistics over eval set
            #accumulator.accumulate(activation_rates, target_var, target_rates)
            #print("Here")

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))
            activations.update(acts.item(), 1)
            #print("Here")

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            #print("Here")

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                      'Activations: {act.val:.3f} ({act.avg:.3f})'.format(
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1,
                          top5=top5,
                          act=activations))

    activ_output = accumulator.getoutput()

    print('gate activation rates:')
    print(activ_output[0])

    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))

    if args.visdom:
        plotter.plot('act', 'test', epoch, activations.avg)
        plotter.plot('top1', 'test', epoch, top1.avg)
        plotter.plot('top5', 'test', epoch, top5.avg)
        plotter.plot('loss', 'test', epoch, losses.avg)
        for gate in activ_output[0]:
            plotter.plot('gates', '{}'.format(gate), epoch,
                         activ_output[0][gate])

        # Plot more detailed stats like activation heatmaps for key epochs
        if epoch in [30, 60, 99]:
            for category in activ_output[1]:
                plotter.plot('classes', '{}'.format(category), epoch,
                             activ_output[1][category])

            heatmap = activ_output[2]
            means = np.mean(heatmap, axis=0)
            stds = np.std(heatmap, axis=0)
            normalized_stds = np.array(stds / (means + 1e-10)).squeeze()

            plotter.plot_heatmap(activ_output[2], epoch)
            for counter in range(len(normalized_stds)):
                plotter.plot('activations{}'.format(epoch), 'activations',
                             counter, normalized_stds[counter])
            for counter in range(len(means)):
                plotter.plot('opening_rate{}'.format(epoch), 'opening_rate',
                             counter, means[counter])
    return top1.avg
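Finally, the AverageMeter and accuracy helpers used throughout validate() are not shown in this listing; a sketch assuming they follow the stock PyTorch ImageNet-example utilities:

class AverageMeter(object):
    """Tracks the current value, running sum, count, and average."""
    def __init__(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k precision (in percent) for a batch of logits."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)  # indices of top-k classes
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res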