Exemplo n.º 1
0
def test(epoch, analyzer):
    """Evaluate ``net`` on ``testloader`` and compare it with ``analyzer``'s predictions.

    Uses the names ``net``, ``testloader``, ``device``, ``criterion`` and
    ``progress_bar``, which are defined outside this function.

    Args:
        epoch: Epoch index forwarded to ``analyzer.start_test`` / ``end_test``.
        analyzer: Object whose ``update_batch(outputs, targets)`` is called for
            every batch. Its return value is compared element-wise against
            ``targets`` via ``.eq`` (presumably a tensor of predicted labels —
            confirm against the analyzer implementation).
    """
    analyzer.start_test(epoch)

    net.eval()
    test_loss = 0
    correct = 0
    correct_nbdt = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # The analyzer's predictions are compared against CPU targets.
            if device == 'cuda':
                targets = targets.cpu()

            predicted_nbdt = analyzer.update_batch(outputs, targets)
            # Accumulate instead of printing raw per-batch tensors, so the
            # comparison is reported alongside the network's own accuracy.
            correct_nbdt += predicted_nbdt.eq(targets).sum().item()

            progress_bar(
                batch_idx, len(testloader),
                'Loss: %.3f | Acc: %.3f%% (%d/%d) | NBDT Acc: %.3f%% (%d/%d)'
                % (test_loss / (batch_idx + 1), 100. * correct / total,
                   correct, total, 100. * correct_nbdt / total,
                   correct_nbdt, total))

    # An empty loader leaves total == 0; report 0 instead of crashing.
    acc = 100. * correct / total if total else 0.0
    acc_nbdt = 100. * correct_nbdt / total if total else 0.0
    print("Accuracy: {}, {}/{}".format(acc, correct, total))
    print("NBDT Accuracy: {}, {}/{}".format(acc_nbdt, correct_nbdt, total))

    analyzer.end_test(epoch)
Exemplo n.º 2
0
    def train(epoch):
        """Run one training epoch over ``trainloader`` and step the LR scheduler.

        Reads ``net``, ``trainloader``, ``trainset``, ``optimizer``,
        ``scheduler``, ``criterion``, ``metric``, ``analyzer``, ``device`` and
        ``args`` from the enclosing/global scope (not visible here).

        Args:
            epoch: Current epoch index. It is passed with ``args.epochs`` to
                ``criterion.set_epoch`` when the criterion defines that method.
        """
        if hasattr(criterion, "set_epoch"):
            criterion.set_epoch(epoch, args.epochs)

        print("\nEpoch: %d / LR: %.04f" % (epoch, scheduler.get_last_lr()[0]))
        net.train()
        train_loss = 0
        metric.clear()
        # The inverse transform does not depend on the batch, so build it and
        # move it to the device once per epoch rather than once per batch.
        transform = trainset.transform_val_inverse().to(device)
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            metric.forward(outputs, targets)
            # The analyzer also receives the un-normalized inputs.
            stat = analyzer.update_batch(outputs, targets, transform(inputs))

            progress_bar(
                batch_idx,
                len(trainloader),
                "Loss: %.3f | Acc: %.3f%% (%d/%d) %s" % (
                    train_loss / (batch_idx + 1),
                    100.0 * metric.report(),
                    metric.correct,
                    metric.total,
                    f"| {analyzer.name}: {stat}" if stat else "",
                ),
            )
        scheduler.step()
def train(epoch, analyzer):
    """Train ``net`` for a single epoch, reporting each batch to ``analyzer``.

    The learning rate comes from ``adjust_learning_rate(epoch, args.lr)``.
    ``net``, ``trainloader``, ``criterion``, ``device``, ``optim``, ``args``
    and ``progress_bar`` are resolved outside this function.

    Args:
        epoch: Epoch index, used for the LR schedule and forwarded to
            ``analyzer.start_train`` / ``analyzer.end_train``.
        analyzer: Receives ``update_batch(outputs, targets)`` per batch. A
            truthy return value is appended to the progress line.
    """
    analyzer.start_train(epoch)
    lr = adjust_learning_rate(epoch, args.lr)
    # NOTE(review): a new SGD optimizer is created on every call, so momentum
    # buffers restart from zero at the beginning of each epoch.
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

    print('\nEpoch: %d' % epoch)
    net.train()

    running_loss = 0
    n_correct = 0
    n_seen = 0
    for step, (batch_x, batch_y) in enumerate(trainloader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = net(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]  # index 1 of the (values, indices) result
        n_seen += batch_y.size(0)
        n_correct += predicted.eq(batch_y).sum().item()

        stat = analyzer.update_batch(logits, batch_y)
        suffix = f'| {stat}' if stat else ''

        message = 'Loss: %.3f | Acc: %.3f%% (%d/%d) %s' % (
            running_loss / (step + 1),
            100. * n_correct / n_seen,
            n_correct,
            n_seen,
            suffix,
        )
        progress_bar(step, len(trainloader), message)

    analyzer.end_train(epoch)
def test(epoch, analyzer, checkpoint=True):
    """Evaluate ``net`` on ``testloader`` and checkpoint on a new best accuracy.

    Reads ``net``, ``testloader``, ``criterion``, ``device``, ``progress_bar``
    and ``checkpoint_fname`` from outside this function, and updates the
    global ``best_acc``.

    Args:
        epoch: Epoch index, stored in the checkpoint and forwarded to
            ``analyzer.start_test`` / ``analyzer.end_test``.
        analyzer: Receives ``update_batch(outputs, targets)`` per batch. A
            truthy return value is appended to the progress line.
        checkpoint: When true, save ``{'net', 'acc', 'epoch'}`` to
            ``./checkpoint/{checkpoint_fname}.pth`` if accuracy exceeds
            ``best_acc``.
    """
    global best_acc

    analyzer.start_test(epoch)

    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # The analyzer is handed CPU targets when running on 'cuda'.
            if device == 'cuda':
                targets = targets.cpu()

            stat = analyzer.update_batch(outputs, targets)
            extra = f'| {stat}' if stat else ''

            progress_bar(
                batch_idx, len(testloader),
                'Loss: %.3f | Acc: %.3f%% (%d/%d) %s' %
                (test_loss / (batch_idx + 1), 100. * correct / total, correct,
                 total, extra))

    # Save checkpoint. An empty loader leaves total == 0; report 0 instead of
    # raising ZeroDivisionError.
    acc = 100. * correct / total if total else 0.0
    print("Accuracy: {}, {}/{}".format(acc, correct, total))
    if acc > best_acc and checkpoint:
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the isdir()/mkdir() check-then-create race.
        os.makedirs('checkpoint', exist_ok=True)

        print(f'Saving to {checkpoint_fname} ({acc})..')
        torch.save(state, f'./checkpoint/{checkpoint_fname}.pth')
        best_acc = acc

    analyzer.end_test(epoch)
Exemplo n.º 5
0
    def test(epoch, checkpoint=True):
        nonlocal best_acc
        net.eval()
        test_loss = 0
        metric.clear()
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)

                if not args.disable_test_eval:
                    loss = criterion(outputs, targets)
                    test_loss += loss.item()
                    metric.forward(outputs, targets)
                transform = testset.transform_val_inverse().to(device)
                stat = analyzer.update_batch(outputs, targets,
                                             transform(inputs))

                progress_bar(
                    batch_idx,
                    len(testloader),
                    "Loss: %.3f | Acc: %.3f%% (%d/%d) %s" % (
                        test_loss / (batch_idx + 1),
                        100.0 * metric.report(),
                        metric.correct,
                        metric.total,
                        f"| {analyzer.name}: {stat}" if stat else "",
                    ),
                )

        # Save checkpoint.
        acc = 100.0 * metric.report()
        print("Accuracy: {}, {}/{} | Best Accurracy: {}".format(
            acc, metric.correct, metric.total, best_acc))
        if acc > best_acc and checkpoint:
            Colors.green(f"Saving to {checkpoint_fname} ({acc})..")
            state = {
                "net": net.state_dict(),
                "acc": acc,
                "epoch": epoch,
            }
            os.makedirs("checkpoint", exist_ok=True)
            torch.save(state, f"./checkpoint/{checkpoint_fname}.pth")
            best_acc = acc