Exemplo n.º 1
0
from torch.utils.data import DataLoader
from optparse import OptionParser
from torchvision.datasets import ImageFolder
from torchvision import transforms, utils,datasets
from utils import CenterLoss, AverageMeter, TopKAccuracyMetric, ModelCheckpoint, batch_augment
from models import WSDAN,inception_v3
from dataset import *
# NOTE(review): assumes a CUDA-capable GPU; torch.device("cuda") with no
# availability check will fail at tensor.to(device) time on CPU-only hosts.
device = torch.device("cuda")

# General loss functions
cross_entropy_loss = nn.CrossEntropyLoss()
center_loss = CenterLoss()

# loss and metric
loss_container = AverageMeter(name='loss')  # running average of the training loss
raw_metric = TopKAccuracyMetric(topk=(1, 3))   # top-1/top-3 accuracy on raw predictions
crop_metric = TopKAccuracyMetric(topk=(1, 3))  # top-1/top-3 accuracy after attention crop
drop_metric = TopKAccuracyMetric(topk=(1, 3))  # top-1/top-3 accuracy after attention drop

def main():
    """Build the command-line option parser for training.

    NOTE(review): this snippet appears truncated by the source scrape --
    the parser is constructed and populated but ``parser.parse_args()``
    is never called in the visible code, so the options are not yet
    consumed here. Confirm against the full original file.
    """
    parser = OptionParser()
    parser.add_option('-j', '--workers', dest='workers', default=16, type='int',
                      help='number of data loading workers (default: 16)')
    parser.add_option('-e', '--epochs', dest='epochs', default=80, type='int',
                      help='number of epochs (default: 80)')
    parser.add_option('-b', '--batch-size', dest='batch_size', default=16, type='int',
                      help='batch size (default: 16)')
    parser.add_option('-c', '--ckpt', dest='ckpt', default=False,
                      help='load checkpoint model (default: False)')
    parser.add_option('-v', '--verbose', dest='verbose', default=100, type='int',
                      help='show information for each <verbose> iterations (default: 100)')
Exemplo n.º 2
0
def main():
    """Evaluate a trained WS-DAN checkpoint on the CarDataset test split.

    Expects the checkpoint path as the first command-line argument:
        python3 eval.py <model.ckpt>

    Reports top-1/top-5 accuracy of the averaged raw + attention-crop
    predictions. When the module-level ``visualize`` flag is set, also
    writes raw / attention / heatmap images to ``savepath``.
    """
    logging.basicConfig(
        format=
        '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # Only a missing argv entry means "show usage"; the original bare
    # ``except:`` would also have swallowed SystemExit/KeyboardInterrupt.
    try:
        ckpt = sys.argv[1]
    except IndexError:
        logging.info('Usage: python3 eval.py <model.ckpt>')
        return

    ##################################
    # Dataset for testing
    ##################################
    test_dataset = CarDataset(phase='test', resize=448)
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=1,
                             pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes,
                M=32,
                net='inception_mixed_6e')

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    cudnn.benchmark = True
    net.to(device)
    net = nn.DataParallel(net)
    net.eval()

    ##################################
    # Prediction
    ##################################
    accuracy = TopKAccuracyMetric(topk=(1, 5))
    accuracy.reset()

    # ToPILImage is a transform *class*: instantiate once and reuse.
    # (The original called the class with a tensor, which returns a
    # transform object that has no ``.save`` method.)
    to_pil = ToPILImage()

    # Defined up-front so the final logging line cannot raise NameError
    # when the test loader is empty.
    epoch_acc = (0.0, 0.0)

    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # WS-DAN forward: raw logits, feature matrix, attention maps
            y_pred_raw, feature_matrix, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X,
                                       attention_maps,
                                       mode='crop',
                                       theta=0.1)

            # Final prediction averages raw and attention-crop logits.
            y_pred_crop, _, _ = net(crop_image)
            pred = (y_pred_raw + y_pred_crop) / 2.

            if visualize:
                # Upsample attention maps to the input resolution.
                # F.upsample_bilinear is deprecated; its documented
                # equivalent is interpolate(..., mode='bilinear',
                # align_corners=True).
                attention_maps = F.interpolate(attention_maps,
                                               size=(X.size(2), X.size(3)),
                                               mode='bilinear',
                                               align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() /
                                            attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                for batch_idx in range(X.size(0)):
                    rimg = to_pil(raw_image[batch_idx])
                    raimg = to_pil(raw_attention_image[batch_idx])
                    haimg = to_pil(heat_attention_image[batch_idx])
                    rimg.save(
                        os.path.join(savepath,
                                     '%03d_raw.jpg' % (i + batch_idx)))
                    raimg.save(
                        os.path.join(savepath,
                                     '%03d_raw_atten.jpg' % (i + batch_idx)))
                    haimg.save(
                        os.path.join(savepath,
                                     '%03d_heat_atten.jpg' % (i + batch_idx)))

            # Top K
            epoch_acc = accuracy(pred, y)

            # end of this batch
            batch_info = 'Val Acc ({:.2f}, {:.2f})'.format(
                epoch_acc[0], epoch_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()

    # show information for this epoch
    logging.info('Accuracy: %.2f, %.2f' % (epoch_acc[0], epoch_acc[1]))
Exemplo n.º 3
0
def main():
    """Evaluate a WS-DAN checkpoint (path taken from ``config.eval_ckpt``)
    on the CarDataset test split.

    Writes a ``predictions.csv`` submission (columns ``id``, ``label``)
    into ``savepath`` and reports top-1/top-5 accuracy for both the raw
    and the crop-refined predictions. When ``visualize`` is set, also
    dumps attention visualizations per image.
    """
    logging.basicConfig(
        format=
        '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # A missing config attribute is the only condition this guard is
    # meant for; the original bare ``except:`` hid every other error too.
    try:
        ckpt = config.eval_ckpt
    except AttributeError:
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    # _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size)
    test_dataset = CarDataset('test')
    test_loader = DataLoader(test_dataset,
                             batch_size=config.batch_size,
                             shuffle=False,
                             num_workers=2,
                             pin_memory=True)
    name2label, label2name = mapping('../training_labels.csv')
    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes,
                M=config.num_attentions,
                net=config.net)

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 5))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 5))
    raw_accuracy.reset()
    ref_accuracy.reset()

    # ToPILImage is a transform *class*: instantiate once and reuse.
    to_pil = ToPILImage()

    net.eval()
    logits = []
    ids = []
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        # ``sample_id`` instead of ``id`` -- the original shadowed the builtin.
        for i, (X, y, sample_id) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)
            ids.extend(sample_id)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X,
                                       attention_maps,
                                       mode='crop',
                                       theta=0.1,
                                       padding_ratio=0.05)

            y_pred_crop, _, _ = net(crop_image)
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # Accumulate logits; the submission file is written once after
            # the loop (the original re-concatenated and rewrote the whole
            # CSV on every batch, which is quadratic in dataset size).
            logits.append(y_pred.cpu())

            if visualize:
                # Upsample attention maps to the input resolution.
                # F.upsample_bilinear is deprecated; interpolate with
                # align_corners=True is its documented equivalent.
                attention_maps = F.interpolate(attention_maps,
                                               size=(X.size(2), X.size(3)),
                                               mode='bilinear',
                                               align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() /
                                            attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                for batch_idx in range(X.size(0)):
                    rimg = to_pil(raw_image[batch_idx])
                    raimg = to_pil(raw_attention_image[batch_idx])
                    haimg = to_pil(heat_attention_image[batch_idx])
                    rimg.save(
                        os.path.join(
                            savepath, '%03d_raw.jpg' %
                            (i * config.batch_size + batch_idx)))
                    raimg.save(
                        os.path.join(
                            savepath, '%03d_raw_atten.jpg' %
                            (i * config.batch_size + batch_idx)))
                    haimg.save(
                        os.path.join(
                            savepath, '%03d_heat_atten.jpg' %
                            (i * config.batch_size + batch_idx)))

            # Top K
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0],
                epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()

    # Save the predictions once; the final file is identical to what the
    # per-batch rewrite in the original produced on its last iteration.
    if logits:
        prediction = torch.argmax(torch.cat(logits, dim=0), dim=1)
        submission = pd.DataFrame(
            [ids,
             [label2name[x] for x in prediction.numpy()]]).transpose()
        submission.columns = ['id', 'label']
        submission.to_csv(savepath + 'predictions.csv', index=False)
Exemplo n.º 4
0
def main():
    """Evaluate a plain ``resnet34_plus`` classifier checkpoint
    (``config.eval_ckpt``) on the test split.

    Unlike the WS-DAN variants, this model returns logits only, so there
    is no attention-crop refinement step: raw and "refined" accuracies
    are computed from the same predictions.
    """
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # A missing config attribute is the only condition this guard is
    # meant for; the original bare ``except:`` hid every other error too.
    try:
        ckpt = config.eval_ckpt
    except AttributeError:
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    # net = WSDAN(num_classes=test_dataset.num_classes, M=config.num_attentions, net=config.net)
    net = resnet34_plus(num_classes=2)

    # This checkpoint stores the state_dict directly (no 'state_dict' key).
    checkpoint = torch.load(ckpt)

    # Load weights
    net.load_state_dict(checkpoint)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1,))
    ref_accuracy = TopKAccuracyMetric(topk=(1,))
    raw_accuracy.reset()
    ref_accuracy.reset()

    net.eval()
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # Plain classifier forward pass -- logits only.
            y_pred_raw = net(X)

            # No attention-crop augmentation for this model.
            y_pred = y_pred_raw

            if visualize:
                # NOTE(review): resnet34_plus exposes no attention maps,
                # so the original visualization code here referenced an
                # undefined ``attention_maps`` and raised NameError on the
                # first visualized batch. Warn and skip instead.
                logging.warning(
                    'visualize requested, but this model produces no '
                    'attention maps; skipping visualization')

            # Top K (top-1 only -- topk=(1,))
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            # end of this batch (top-1 repeated to fill the shared format)
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[0], epoch_ref_acc[0], epoch_ref_acc[0])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()