def train(net, epoch):
    """Run one training epoch over ``train_loader`` and print progress.

    Relies on module-level names defined elsewhere in the file:
    ``train_loader``, ``optimizer``, ``criterion``, ``use_cuda``, ``args``,
    ``utils``, ``torch`` and ``progress_bar``.

    Args:
        net: model to train; it is switched to training mode.
        epoch: epoch index, used only for the printed header.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    # Mask loaded from disk and adjusted via utils.setMask(mask, 4, 1);
    # the meaning of those arguments lives in utils (not visible here).
    mask_channel = torch.load('mask_null.dat')
    mask_channel = utils.setMask(mask_channel, 4, 1)
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        # Tensors carry autograd state directly (PyTorch >= 0.4, which the
        # use of .item() already requires), so no Variable wrapper is needed.
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        # NOTE(review): the mask (and optional quantization) is applied after
        # backward() but before optimizer.step(), so the step may move masked
        # or quantized weights again unless the utils helpers also act on the
        # gradients -- confirm this ordering is intended.
        net = utils.netMaskMul(net, mask_channel)
        if args.fixed == 1:
            net = utils.quantize(net, args.pprec)
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += float(predicted.eq(targets).cpu().sum())

        progress_bar(
            batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
Пример #2
0
def test(net):
    """Evaluate ``net`` on ``test_loader`` and checkpoint on a new best accuracy.

    Relies on module-level names defined elsewhere in the file:
    ``test_loader``, ``criterion``, ``use_cuda``, ``args``, ``utils``,
    ``net2``, ``torch``, ``os`` and ``progress_bar``.

    Side effects: sets the global ``glob_blur`` to 1. When accuracy beats
    the global ``best_acc``, it updates ``best_acc``. In that case it also
    saves a checkpoint to ./checkpoint/ckpt_20190802_half_clean_B1.t0, but
    only if ``args.mode`` is nonzero.

    Args:
        net: model to evaluate; it is switched to eval mode.

    Returns:
        Top-1 accuracy over the test set, in percent.
    """
    global glob_blur
    global best_acc
    glob_blur = 1
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    mask_channel = torch.load('mask_null.dat')
    mask_channel = utils.setMask(utils.setMask(mask_channel, 3, 1), 4, 0)
    if args.mode > 0:
        net = utils.netMaskMul(net, mask_channel)
        net = utils.addNetwork(net, net2)
    if args.fixed == 1:
        net = utils.quantize(net, args.pprec)

    # Inference only: no_grad() stops autograd from building (and holding
    # memory for) a graph on every batch; outputs are unchanged.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += float(predicted.eq(targets).cpu().sum())

            progress_bar(
                batch_idx, len(test_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    # Save checkpoint.
    acc = 100. * correct / total
    if acc > best_acc:
        state = {
            # Unwrap .module when running on CUDA (presumably a DataParallel
            # wrapper -- confirm against where net is constructed).
            'net': net.module if use_cuda else net,
            'acc': acc,
        }
        os.makedirs('checkpoint', exist_ok=True)
        if args.mode != 0:
            print('Saving..')
            torch.save(state, './checkpoint/ckpt_20190802_half_clean_B1.t0')
        best_acc = acc

    return acc