예제 #1
0
def test_one_set(loader, device, net, categories, num_classes, output_size,
                 is_mixed_precision):
    """Evaluate ``net`` on one data loader and print segmentation metrics.

    The network is switched to eval mode and run without gradient tracking.
    Each output map (``net(images)['out']``) is bilinearly resized to
    ``output_size`` before its per-pixel argmax is accumulated into a
    ``ConfusionMatrix`` of ``num_classes`` classes. The network is left in
    eval mode; callers that keep training must call ``net.train()``.

    Args:
        loader: iterable of ``(images, targets)`` batches.
        device: device the batches are moved to.
        net: model whose call result is indexed with ``'out'``.
        categories: printed verbatim before the metrics.
        num_classes: size of the confusion matrix.
        output_size: ``size`` argument for ``torch.nn.functional.interpolate``.
        is_mixed_precision: passed positionally to ``autocast``
            (presumably its ``enabled`` flag -- confirm against the import).

    Returns:
        Tuple ``(global pixel accuracy * 100, mean IoU * 100)`` as floats.
    """
    # Evaluate on 1 data_loader (cudnn impact < 0.003%)
    net.eval()
    matrix = ConfusionMatrix(num_classes)
    with torch.no_grad():
        for batch_images, batch_targets in tqdm(loader):
            batch_images = batch_images.to(device)
            batch_targets = batch_targets.to(device)
            with autocast(is_mixed_precision):
                logits = torch.nn.functional.interpolate(
                    net(batch_images)['out'],
                    size=output_size,
                    mode='bilinear',
                    align_corners=True)
                predictions = logits.argmax(1)
                matrix.update(batch_targets.flatten(), predictions.flatten())

    pixel_acc, per_class_acc, per_class_iou = matrix.compute()
    pixel_acc_pct = pixel_acc.item() * 100
    mean_iou_pct = per_class_iou.mean().item() * 100
    row_strs = ['{:.2f}'.format(v) for v in (per_class_acc * 100).tolist()]
    iou_strs = ['{:.2f}'.format(v) for v in (per_class_iou * 100).tolist()]
    report_template = ('global correct: {:.2f}\n'
                       'average row correct: {}\n'
                       'IoU: {}\n'
                       'mean IoU: {:.2f}')

    print(categories)
    print(report_template.format(pixel_acc_pct, row_strs, iou_strs,
                                 mean_iou_pct))

    return pixel_acc_pct, mean_iou_pct
예제 #2
0
def _next_support_batch(loader_sup, iter_sup):
    """Draw the next support batch, restarting ``loader_sup`` once if exhausted.

    Returns ``(iterator, batch)``; the iterator is a fresh one when a restart
    happened, so callers must keep the returned iterator.

    Raises:
        ValueError: if ``loader_sup`` yields no batches even after a restart.
    """
    try:
        return iter_sup, next(iter_sup)
    except StopIteration:
        iter_sup = iter(loader_sup)
        try:
            return iter_sup, next(iter_sup)
        except StopIteration:
            raise ValueError('loader_sup yielded no batches') from None


def _combine_with_support(carry_batch, support_batch):
    """Concatenate a pseudo-labeled (carry) batch with a ground-truth batch.

    Support labels are expanded into a 2-channel "prob" tensor
    (label, confidence 1.0). The carry labels are used as-is for ``probs``,
    and their channel 0 is taken as the hard label. This presumes carry labels
    are ``(N, C, H, W)`` with the label in channel 0 and confidence after it --
    TODO confirm against the carry dataset.

    Returns:
        ``(inputs, labels, probs, num_carry)`` where ``num_carry`` is the
        number of carry samples at the front of the concatenated batch.
    """
    inputs_c, labels_c = carry_batch
    inputs_sup, labels_sup = support_batch

    # Formatting (prob: label + max confidence, label: just label)
    float_labels_sup = labels_sup.clone().float().unsqueeze(1)
    probs_sup = torch.cat(
        [float_labels_sup, torch.ones_like(float_labels_sup)], dim=1)
    probs_c = labels_c.clone()
    labels_c = labels_c[:, 0, :, :].long()

    inputs = torch.cat([inputs_c, inputs_sup])
    labels = torch.cat([labels_c, labels_sup])
    probs = torch.cat([probs_c, probs_sup])
    return inputs, labels, probs, inputs_c.shape[0]


def _report_confusion(conf_mat, categories):
    """Print metrics held in ``conf_mat``; return (pixel acc %, mean IoU %)."""
    acc_global, acc, iu = conf_mat.compute()
    print(categories)
    print(('global correct: {:.2f}\n'
           'average row correct: {}\n'
           'IoU: {}\n'
           'mean IoU: {:.2f}').format(
               acc_global.item() * 100,
               ['{:.2f}'.format(i) for i in (acc * 100).tolist()],
               ['{:.2f}'.format(i) for i in (iu * 100).tolist()],
               iu.mean().item() * 100))
    return acc_global.item() * 100, iu.mean().item() * 100


def train(writer,
          loader_c,
          loader_sup,
          validation_loader,
          device,
          criterion,
          net,
          optimizer,
          lr_scheduler,
          num_epochs,
          is_mixed_precision,
          with_sup,
          num_classes,
          categories,
          input_sizes,
          val_num_steps=1000,
          loss_freq=10,
          tensorboard_prefix='',
          best_mIoU=0):
    """Train ``net`` and checkpoint whenever validation mIoU improves.

    Naming: ``c`` is "carry" (pseudo-labeled data), ``sup`` is "support"
    (ground-truth labeled data). An epoch is one pass over ``loader_c``.
    When ``with_sup`` is true, every carry batch is concatenated with one
    support batch, and ``loader_sup`` is restarted whenever it runs out.
    The batch ratio therefore follows the two loaders' own batch sizes.
    ``lr_scheduler.step()`` is called once per iteration.

    Args:
        writer: receives ``add_scalar(tag, value, step)`` calls
            (TensorBoard-style).
        loader_c: carry loader, yielding ``(inputs, labels)``.
        loader_sup: support loader; only used when ``with_sup``.
        validation_loader: passed to ``test_one_set``.
        device: device that inputs, labels and probs are moved to.
        criterion: called as ``criterion(outputs, probs, num_carry)`` when
            ``with_sup``, else ``criterion(outputs, labels)``. Must return
            ``(loss, stats_dict)``.
        net: model whose call result is indexed with ``'out'``.
        optimizer, lr_scheduler: both stepped every iteration.
        num_epochs: number of passes over ``loader_c``.
        is_mixed_precision: enables ``autocast`` / ``GradScaler``.
        with_sup: whether to mix support batches into every step.
        num_classes, categories: forwarded to the confusion matrices and
            printouts.
        input_sizes: ``[0]`` is the training interpolation size, ``[2]`` the
            validation output size.
        val_num_steps: validate every this many steps (and at the final step).
        loss_freq: number of loss-logging windows per epoch (at least one
            step per window).
        tensorboard_prefix: prefix for every ``writer`` tag.
        best_mIoU: best validation mIoU so far; checkpoints are saved only
            above it.

    Returns:
        The best validation mIoU observed (or the incoming ``best_mIoU``).
    """
    steps_per_epoch = len(loader_c)
    total_steps = num_epochs * steps_per_epoch
    # Clamp to 1: a loader shorter than loss_freq would otherwise give 0 and
    # make the modulo below raise ZeroDivisionError.
    loss_num_steps = max(1, int(steps_per_epoch / loss_freq))
    net.train()
    epoch = 0
    if with_sup:
        iter_sup = iter(loader_sup)

    if is_mixed_precision:
        scaler = GradScaler()

    # Start every key at 0.0. The old -1/1.0 sentinels leaked into the first
    # logged average. Keys the criterion never reports are logged as 0.
    running_stats = {
        'disagree': 0.0,
        'current_win': 0.0,
        'avg_weights': 0.0,
        'loss': 0.0
    }
    window_steps = 0  # steps accumulated into running_stats since last log
    while epoch < num_epochs:
        conf_mat = ConfusionMatrix(num_classes)
        time_now = time.time()
        for i, data in enumerate(loader_c):
            if with_sup:
                iter_sup, support_batch = _next_support_batch(loader_sup,
                                                              iter_sup)
                inputs, labels, probs, num_carry = _combine_with_support(
                    data, support_batch)
                probs = probs.to(device)
            else:
                inputs, labels = data

            # Normal training
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            with autocast(is_mixed_precision):
                outputs = net(inputs)['out']
                outputs = torch.nn.functional.interpolate(outputs,
                                                          size=input_sizes[0],
                                                          mode='bilinear',
                                                          align_corners=True)
                conf_mat.update(labels.flatten(), outputs.argmax(1).flatten())

                if with_sup:
                    loss, stats = criterion(outputs, probs, num_carry)
                else:
                    loss, stats = criterion(outputs, labels)

            # NOTE(review): `accelerator` is not a parameter. It is presumably a
            # module-level (HF Accelerate-style) object; confirm that it does
            # not also apply its own loss scaling on top of GradScaler.
            if is_mixed_precision:
                accelerator.backward(scaler.scale(loss))
                scaler.step(optimizer)
                scaler.update()
            else:
                accelerator.backward(loss)
                optimizer.step()

            lr_scheduler.step()

            # Logging: average over the steps actually accumulated. The first
            # window is one step shorter than the later ones.
            for key, value in stats.items():
                running_stats[key] = running_stats.get(key, 0.0) + value
            window_steps += 1
            current_step_num = epoch * steps_per_epoch + i + 1  # 1-based
            if current_step_num % loss_num_steps == (loss_num_steps - 1):
                for key in running_stats:
                    avg = running_stats[key] / window_steps
                    print('[%d, %d] ' % (epoch + 1, i + 1) + key + ' : %.4f' %
                          avg)
                    writer.add_scalar(tensorboard_prefix + key, avg,
                                      current_step_num)
                    running_stats[key] = 0.0
                window_steps = 0

            # Validate and find the best snapshot. current_step_num is 1-based,
            # so the last step is exactly total_steps (it used to be
            # total_steps - 1, which skipped the final update).
            if current_step_num % val_num_steps == (val_num_steps - 1) or \
                    current_step_num == total_steps:
                # Apex bug https://github.com/NVIDIA/apex/issues/706, fixed in PyTorch1.6, kept here for BC
                test_pixel_accuracy, test_mIoU = test_one_set(
                    loader=validation_loader,
                    device=device,
                    net=net,
                    num_classes=num_classes,
                    categories=categories,
                    output_size=input_sizes[2],
                    is_mixed_precision=is_mixed_precision)
                writer.add_scalar(tensorboard_prefix + 'test pixel accuracy',
                                  test_pixel_accuracy, current_step_num)
                writer.add_scalar(tensorboard_prefix + 'test mIoU', test_mIoU,
                                  current_step_num)
                net.train()  # test_one_set leaves the net in eval mode

                # Record best model(Straight to disk)
                if test_mIoU > best_mIoU:
                    best_mIoU = test_mIoU
                    save_checkpoint(net=net,
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler,
                                    is_mixed_precision=is_mixed_precision)

        # Training accuracies use the same metric as validation. They are
        # computed on the fly from training predictions to save time.
        train_pixel_acc, train_mIoU = _report_confusion(conf_mat, categories)
        writer.add_scalar(tensorboard_prefix + 'train pixel accuracy',
                          train_pixel_acc, epoch + 1)
        writer.add_scalar(tensorboard_prefix + 'train mIoU', train_mIoU,
                          epoch + 1)

        epoch += 1
        print('Epoch time: %.2fs' % (time.time() - time_now))

    return best_mIoU