Exemple #1
0
def main():
    """Train an attention-VGG classifier and log progress to TensorBoard.

    Configuration is read from the module-level ``opt`` namespace
    (``batch_size``, ``lr``, ``epochs``, ``outf``, ``attn_mode``,
    ``no_attention``, ``normalize_attn`` and ``log_images``).

    For every epoch the function runs ``num_aug`` passes over the training
    loader, saves the weights to ``<opt.outf>/net.pth`` (plus one extra
    snapshot halfway through training), evaluates on the test loader and,
    when ``opt.log_images`` is set, writes image grids and attention maps
    to the summary writer.

    Raises:
        NotImplementedError: if ``opt.attn_mode`` is neither ``'before'``
            nor ``'after'``.
    """
    ## load data
    print('\nloading the dataset ...\n')
    num_aug = 3
    im_size = 32
    # The normalisation constants are the per-channel statistics the original
    # author used for CIFAR-100 (500 training / 100 testing images per class).
    transform_train = transforms.Compose([
        transforms.RandomCrop(im_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))
    ])

    def _init_fn(worker_id):
        # Seed `random` differently in every DataLoader worker. `base_seed`
        # is not defined in this function; it must exist at module level.
        random.seed(base_seed + worker_id)

    # trainset = torchvision.datasets.CIFAR100(root='CIFAR100_data', train=True, download=True, transform=transform_train)
    trainset = DrawDataset(transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=opt.batch_size,
                                              shuffle=True,
                                              num_workers=8,
                                              worker_init_fn=_init_fn)
    # testset = torchvision.datasets.CIFAR100(root='CIFAR100_data', train=False, download=True, transform=transform_test)
    testset = DrawDataset(transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=5)
    print('done')

    ## load network
    print('\nloading the network ...\n')
    # use attention module?
    if not opt.no_attention:
        print('\nturn on attention ...\n')
    else:
        print('\nturn off attention ...\n')
    # (linear attn) insert attention before or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size,
                             num_classes=100,
                             attention=not opt.no_attention,
                             normalize_attn=opt.normalize_attn,
                             init='xavierUniform')
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size,
                            num_classes=100,
                            attention=not opt.no_attention,
                            normalize_attn=opt.normalize_attn,
                            init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")
    criterion = nn.CrossEntropyLoss()
    print('done')

    ## move to GPU
    print('\nmoving to GPU ...\n')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_ids = [0, 1]
    model = nn.DataParallel(net, device_ids=device_ids).to(device)
    criterion.to(device)
    print('done')

    ### optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    # Halve the learning rate every 25 epochs.
    lr_lambda = lambda epoch: np.power(0.5, int(epoch / 25))
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    # training
    print('\nstart training ...\n')
    step = 0
    running_avg_accuracy = 0
    writer = SummaryWriter(opt.outf)
    for epoch in range(opt.epochs):
        # images_disp[0]: first training batch, images_disp[1]: first test batch
        images_disp = []
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('train/learning_rate', current_lr, epoch)
        print("\nepoch %d learning rate %f\n" % (epoch, current_lr))
        # run for one epoch
        for aug in range(num_aug):
            for i, data in enumerate(trainloader, 0):
                model.train()
                # The optimizer owns every model parameter, so a single
                # zero_grad() clears all gradients.
                optimizer.zero_grad()
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                if aug == 0 and i == 0:
                    # archive images in order to save to logs
                    images_disp.append(inputs[0:36, :, :, :])
                # forward
                pred, __, __, __ = model(inputs)
                # backward
                loss = criterion(pred, labels)
                loss.backward()
                optimizer.step()
                # display results
                if i % 10 == 0:
                    model.eval()
                    # Monitoring only: no autograd graph is needed here.
                    with torch.no_grad():
                        pred, __, __, __ = model(inputs)
                    predict = torch.argmax(pred, 1)
                    total = labels.size(0)
                    correct = torch.eq(predict, labels).sum().double().item()
                    accuracy = correct / total
                    running_avg_accuracy = 0.9 * running_avg_accuracy + 0.1 * accuracy
                    writer.add_scalar('train/loss', loss.item(), step)
                    writer.add_scalar('train/accuracy', accuracy, step)
                    writer.add_scalar('train/running_avg_accuracy',
                                      running_avg_accuracy, step)
                    print(
                        "[epoch %d][aug %d/%d][%d/%d] loss %.4f accuracy %.2f%% running avg accuracy %.2f%%"
                        % (epoch, aug, num_aug - 1, i, len(trainloader) - 1,
                           loss.item(), (100 * accuracy),
                           (100 * running_avg_accuracy)))
                step += 1
        # the end of each epoch: test & log
        print('\none epoch done, saving records ...\n')
        torch.save(model.state_dict(), os.path.join(opt.outf, 'net.pth'))
        # Integer division so the halfway snapshot is also written when
        # opt.epochs is odd (a float comparison would never match).
        if epoch == opt.epochs // 2:
            torch.save(model.state_dict(),
                       os.path.join(opt.outf, 'net%d.pth' % epoch))
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            # log scalars
            for i, data in enumerate(testloader, 0):
                images_test, labels_test = data
                images_test, labels_test = images_test.to(
                    device), labels_test.to(device)
                if i == 0:
                    # Archive *test* images for the logs (the training batch
                    # is already stored in images_disp[0]).
                    images_disp.append(images_test[0:36, :, :, :])
                pred_test, __, __, __ = model(images_test)
                predict = torch.argmax(pred_test, 1)
                total += labels_test.size(0)
                correct += torch.eq(predict, labels_test).sum().double().item()
            writer.add_scalar('test/accuracy', correct / total, epoch)
            print("\n[epoch %d] accuracy on test data: %.2f%%\n" %
                  (epoch, 100 * correct / total))
            # log images
            if opt.log_images:
                print('\nlog images ...\n')
                I_train = utils.make_grid(images_disp[0],
                                          nrow=6,
                                          normalize=True,
                                          scale_each=True)
                writer.add_image('train/image', I_train, epoch)
                # Rebuild the test grid every epoch so the attention overlay
                # always matches images_disp[1]; the raw grid itself is only
                # logged once.
                I_test = utils.make_grid(images_disp[1],
                                         nrow=6,
                                         normalize=True,
                                         scale_each=True)
                if epoch == 0:
                    writer.add_image('test/image', I_test, epoch)
                if not opt.no_attention:
                    print('\nlog attention maps ...\n')
                    # base factor
                    min_up_factor = 1 if opt.attn_mode == 'before' else 2
                    # sigmoid or softmax
                    if opt.normalize_attn:
                        vis_fun = visualize_attn_softmax
                    else:
                        vis_fun = visualize_attn_sigmoid
                    # training data first, then test data
                    for split, grid, batch in (('train', I_train,
                                                images_disp[0]),
                                               ('test', I_test,
                                                images_disp[1])):
                        __, c1, c2, c3 = model(batch)
                        # Each deeper attention level is upsampled twice as
                        # much as the previous one.
                        for level, (attn_in, scale) in enumerate(
                                ((c1, 1), (c2, 2), (c3, 4)), 1):
                            if attn_in is None:
                                continue
                            attn = vis_fun(grid,
                                           attn_in,
                                           up_factor=min_up_factor * scale,
                                           nrow=6)
                            writer.add_image(
                                '%s/attention_map_%d' % (split, level), attn,
                                epoch)
        # Since PyTorch 1.1 the scheduler has to be stepped *after* the
        # optimizer; stepping at the top of the epoch skips the first value
        # of the schedule.
        scheduler.step()
    # Flush pending events to disk.
    writer.close()
def main():
    """Train (or test) an attention-VGG model on X-ray / pacemaker data.

    Options come from ``argparser()``. The function builds the datasets and
    loaders, constructs the network selected by ``opt.attn_mode``, optionally
    restores pre-trained weights (from ``opt.chest_xray_pretrain_path`` or,
    if that option does not exist, from the ``record`` file in ``opt.outf``)
    and then either runs ``visual_test_image_softmax`` once
    (``opt.test_only``) or trains for ``opt.epochs`` epochs.

    After every epoch the weights are written to ``<opt.outf>/net.pth`` and a
    JSON ``record`` file (learning rate, epoch, step, checkpoint path) is
    written next to it so that a later run can resume.

    Raises:
        NotImplementedError: if ``opt.attn_mode`` is neither ``'before'``
            nor ``'after'``.
        Exception: if the criterion type is not one the accuracy code knows.
    """
    ## load data
    print('\nloading the dataset ...\n')
    opt = argparser()
    print(opt)

    num_aug = 1
    im_size = opt.image_size

    transform_train = transforms.Compose([
        transforms.Resize(im_size),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Original author's note on the data statistics: 0.498, std: 0.185
        transforms.Normalize((0.5, ), (0.185, ))
    ])

    transform_test = transforms.Compose([
        transforms.Resize(im_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.185, ))
    ])

    if opt.global_param != 'CIFAR':
        xray = XRAY(transform_train,
                    transform_test,
                    force_pre_process=False,
                    csv_file=opt.csv_path)
        trainset = xray.train_set
        testset = xray.test_set
        index_source = xray
    else:
        trainset = PacemakerDataset(transform=transform_train, is_train=True)
        testset = PacemakerDataset(transform=transform_test, is_train=False)
        index_source = trainset

    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=8,
        worker_init_fn=_worker_init_fn_)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=opt.batch_size,
                                             shuffle=False,
                                             num_workers=5)
    class_to_index = index_source.class_to_index()

    # NOTE(review): one entry of class_to_index is not counted as a class --
    # presumably a "no finding"/background label; confirm against the dataset.
    num_of_class = len(class_to_index.keys()) - 1
    device_ids = [0, 1]

    if isinstance(opt.loss, str):
        criterion = xray_loss.Loss(opt.loss)
    else:
        criterion = nn.BCELoss()
        #criterion = nn.CrossEntropyLoss()
    print("criterion = %s" % type(criterion))

    def _to_predictions(logits):
        # Convert raw network outputs into hard predictions that can be
        # compared with the labels. Multi-label criteria threshold the
        # sigmoid probability at 0.5; cross-entropy takes the arg-max class.
        if isinstance(criterion, nn.BCELoss):
            return (torch.sigmoid(logits) > 0.5).to(logits.dtype)
        if isinstance(criterion, nn.CrossEntropyLoss):
            return torch.argmax(logits, 1)
        if isinstance(criterion, xray_loss.Loss):
            return (torch.sigmoid(logits) > 0.5).to(logits.dtype)
        raise Exception("{} what is this?".format(criterion))

    print('done num_of_classes: %s [%s] , post crop size: %s' %
          (num_of_class, class_to_index, im_size))

    ## load network
    print('\nloading the network ...\n')
    # use attention module?
    if not opt.no_attention:
        print('\nturn on attention ...\n')
    else:
        print('\nturn off attention ...\n')
    # (linear attn) insert attention before or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size,
                             num_classes=num_of_class,
                             attention=not opt.no_attention,
                             normalize_attn=opt.normalize_attn,
                             init='xavierUniform',
                             _base_features=opt.base_feature_size,
                             dropout=opt.dropout)
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size,
                            num_classes=num_of_class,
                            attention=not opt.no_attention,
                            normalize_attn=opt.normalize_attn,
                            init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")

    print('done')

    ## move to GPU
    print('\nmoving to GPU ...\n')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_pre_trained = None
    record = None

    if opt.pre_train:
        try:
            if os.path.exists(opt.chest_xray_pretrain_path):
                # map_location lets a GPU checkpoint load on a CPU-only host.
                model_pre_trained = torch.load(opt.chest_xray_pretrain_path,
                                               map_location=device)
                record = None
            else:
                assert opt.test_only is False, "cannot run test only mode without pre train data"

        except AttributeError:
            # The option is absent: fall back to the record file written by
            # a previous training run in the output folder.
            record_path = os.path.join(opt.outf, 'record')
            if os.path.exists(record_path):
                with open(record_path, 'r') as frecord:
                    record = json.load(frecord)

                if os.path.exists(record['model']):
                    model_pre_trained = torch.load(record['model'],
                                                   map_location=device)

                print("found pre trained data: %s" % record)

    model = nn.DataParallel(net, device_ids=device_ids)
    if model_pre_trained is not None:
        model.load_state_dict(model_pre_trained)
    model = model.to(device)

    criterion.to(device)
    print('done')

    if opt.test_only:
        visual_test_image_softmax(model, opt.test_image, transform_test)
        return

    if record is None:
        lr = opt.lr
        first_epoch = 0
        step = 0
    else:
        lr = record['lr']
        first_epoch = record['epoch']
        step = record['step']

    slow_lr = getattr(opt, 'slow_lr', False)

    ### optimizer
    if opt.global_param == 'CIFAR':
        optimizer = optim.SGD(model.parameters(),
                              lr=lr,
                              momentum=0.9,
                              weight_decay=5e-4)
        lr_lambda = lambda epoch: np.power(0.5, int(epoch / 25))
    else:
        if opt.global_param == 'BCE':
            optimizer = optim.SGD(model.parameters(),
                                  lr=lr,
                                  momentum=opt.momentum,
                                  weight_decay=opt.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(),
                                   lr=lr,
                                   weight_decay=opt.weight_decay)
        # Halve every `rate` epochs, but never scale below 1e-4.
        rate = 50 if slow_lr else 25
        lr_lambda = lambda epoch: max(np.power(0.5, int(epoch / rate)), 1e-4)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    # training
    start = time.time()
    print('\nstart training [%s]...\n' % start)

    running_avg_accuracy = 0
    writer = SummaryWriter(opt.outf)
    for epoch in range(first_epoch, first_epoch + opt.epochs):
        # images_disp[0]: first training batch, images_disp[1]: "test" batch
        images_disp = []
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('train/learning_rate', current_lr, epoch)
        print("\nepoch %d learning rate %f\n" % (epoch, current_lr))
        # run for one epoch

        for aug in range(num_aug):
            for i, data in enumerate(trainloader, 0):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                model.train()
                # The optimizer owns every model parameter, so a single
                # zero_grad() clears all gradients.
                optimizer.zero_grad()
                if aug == 0 and i == 0:
                    # archive images in order to save to logs
                    print("input:", inputs.shape, "inputs[0:36, :, :, :] ->",
                          inputs[0:36, :, :, :].shape)
                    images_disp.append(inputs[0:36, :, :, :])

                # forward
                pred, __, __, __ = model(inputs)

                # backward
                if isinstance(criterion, nn.BCELoss):
                    # BCELoss expects probabilities, not logits.
                    pred = torch.sigmoid(pred)

                loss = criterion(pred, labels)
                loss.backward()
                optimizer.step()

                # display results
                if i % 10 == 0:
                    model.eval()
                    # Monitoring only: no autograd graph is needed here.
                    with torch.no_grad():
                        pred, __, __, __ = model(inputs)
                        predict = _to_predictions(pred)

                    # numel() equals size(0) * size(1) for 2-D labels and
                    # also works for 1-D class-index labels.
                    total = labels.numel()
                    correct = torch.eq(predict, labels).sum().double().item()
                    accuracy = correct / total
                    running_avg_accuracy = 0.9 * running_avg_accuracy + 0.1 * accuracy
                    writer.add_scalar('train/loss', loss.item(), step)
                    writer.add_scalar('train/accuracy', accuracy, step)
                    writer.add_scalar('train/running_avg_accuracy',
                                      running_avg_accuracy, step)
                    print(
                        "[epoch %d][aug %d/%d][%d/%d] loss %.4f accuracy %.2f%% running avg accuracy %.2f%%"
                        % (epoch, aug, num_aug - 1, i, len(trainloader) - 1,
                           loss.item(), (100 * accuracy),
                           (100 * running_avg_accuracy)))
                step += 1
        # the end of each epoch: test & log
        print('\none epoch done [took: %s], saving records ...\n' %
              (time.time() - start))
        state = os.path.join(opt.outf, 'net.pth')
        torch.save(model.state_dict(), state)
        # Use a dedicated name for the file handle so the resume dictionary
        # bound to `record` is not clobbered.
        with open(os.path.join(opt.outf, 'record'), 'w') as frecord:
            srecord = {
                "lr": optimizer.param_groups[0]['lr'],
                "epoch": epoch,
                "model": state,
                "step": step,
                'global_arg': str(opt)
            }
            json.dump(srecord, frecord)

        # Integer division so the halfway snapshot is also written when
        # opt.epochs is odd (a float comparison would never match).
        if epoch == opt.epochs // 2:
            torch.save(model.state_dict(),
                       os.path.join(opt.outf, 'net%d.pth' % epoch))
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            # log scalars
            test_disp = None
            if opt.global_param == 'PACEMAKER':  # TODO not needed, remove it
                print("\n[epoch %d] log images for pacemaker" % epoch)
            else:
                for i, data in enumerate(testloader, 0):
                    images_test, labels_test = data
                    images_test, labels_test = images_test.to(
                        device), labels_test.to(device)
                    if i == 0:
                        # archive test images in order to save to logs
                        test_disp = images_test[0:36, :, :, :]

                    pred_test, __, __, __ = model(images_test)
                    predict = _to_predictions(pred_test)

                    total += labels_test.numel()
                    correct += torch.eq(predict,
                                        labels_test).sum().double().item()

                writer.add_scalar('test/accuracy', correct / total, epoch)
                print("\n[epoch %d] accuracy on test data: %.2f%%\n" %
                      (epoch, 100 * correct / total))

            # When the test loop did not run, fall back to the last training
            # batch so that images_disp[1] always exists for the logging below.
            if test_disp is None:
                test_disp = inputs[0:36, :, :, :]
            images_disp.append(test_disp)

            # log images
            if opt.log_images:
                print('\nlog images ...\n')
                I_train = utils.make_grid(images_disp[0],
                                          nrow=6,
                                          normalize=True,
                                          scale_each=True)
                writer.add_image('train/image', I_train, epoch)
                # Rebuild the grid every epoch so the attention overlay
                # always matches images_disp[1]; the raw grid itself is only
                # logged for the first epoch of this run.
                I_test = utils.make_grid(images_disp[1],
                                         nrow=6,
                                         normalize=True,
                                         scale_each=True)
                if epoch == first_epoch:
                    writer.add_image('test/image', I_test, epoch)
                if not opt.no_attention:
                    print('\nlog attention maps ...\n')
                    # base factor
                    min_up_factor = 1 if opt.attn_mode == 'before' else 2
                    # sigmoid or softmax
                    if opt.normalize_attn:
                        vis_fun = visualize_attn_softmax
                    else:
                        vis_fun = visualize_attn_sigmoid
                    # training data first, then test data
                    for split, grid, batch in (('train', I_train,
                                                images_disp[0]),
                                               ('test', I_test,
                                                images_disp[1])):
                        __, c1, c2, c3 = model(batch)
                        # Each deeper attention level is upsampled twice as
                        # much as the previous one.
                        for level, (attn_in, scale) in enumerate(
                                ((c1, 1), (c2, 2), (c3, 4)), 1):
                            if attn_in is None:
                                continue
                            attn = vis_fun(grid,
                                           attn_in,
                                           up_factor=min_up_factor * scale,
                                           nrow=6)
                            writer.add_image(
                                '%s/attention_map_%d' % (split, level), attn,
                                epoch)

        # Since PyTorch 1.1 the scheduler has to be stepped *after* the
        # optimizer; stepping at the top of the epoch skips the first value
        # of the schedule.
        scheduler.step()
        # Restart the timer so the "took" message reports a single epoch.
        start = time.time()
    # Flush pending events to disk.
    writer.close()
def main():
    """Visualise the attention maps of a trained AttnVGG model.

    Reads options from the module-level ``opt`` namespace: the checkpoint
    ``opt.model``, the attention settings ``opt.attn_mode`` and
    ``opt.normalize_attn``, the input (``opt.img`` for a single file,
    otherwise every entry of ``opt.img_dir``) and the optional
    ``opt.output_dir`` whose path prefix is handed to the visualisation
    function as ``hm_file``. When exactly one image is processed, the image
    and its attention maps are also shown in a matplotlib figure.

    Raises:
        NotImplementedError: if ``opt.attn_mode`` is neither ``'before'``
            nor ``'after'``.
    """
    im_size = 32
    mean, std = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
    # No ToTensor() here: images are converted to tensors by hand below.
    transform_test = transforms.Compose([
        transforms.Normalize(mean, std)
    ])
    print('done')

    ## load network
    print('\nloading the network ...\n')
    # (linear attn) insert attention before or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size,
                             num_classes=100,
                             attention=True,
                             normalize_attn=opt.normalize_attn,
                             init='xavierUniform')
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size,
                            num_classes=100,
                            attention=True,
                            normalize_attn=opt.normalize_attn,
                            init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")
    print('done')

    ## load model
    print('\nloading the model ...\n')
    state_dict = torch.load(opt.model, map_location=str(device))
    # Remove the 'module.' prefix that nn.DataParallel adds to parameter
    # names -- but only where it is present, so checkpoints saved from a bare
    # network load as well.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    net.load_state_dict(state_dict)
    net = net.to(device)
    net.eval()
    print('done')

    model = net

    # base factor
    min_up_factor = 1 if opt.attn_mode == 'before' else 2
    # sigmoid or softmax
    if opt.normalize_attn:
        vis_fun = visualize_attn_softmax
    else:
        vis_fun = visualize_attn_sigmoid

    if opt.output_dir:
        print("\nwill save heatmaps\n")
        # Make sure the target folder exists before heatmaps are written.
        os.makedirs(opt.output_dir, exist_ok=True)

    if opt.img:
        img_dir = ""
        filenames = [opt.img]
    else:
        img_dir = opt.img_dir
        # Sorted for a reproducible processing order.
        filenames = sorted(os.listdir(img_dir))

    # Only pop up a figure when a single image is inspected.
    display_fig = len(filenames) == 1

    with torch.no_grad():
        for filename in filenames:
            ## load image
            path = os.path.join(img_dir, filename)
            img = imread(path)
            if len(img.shape) == 2:
                # Grey-scale input: replicate the channel three times.
                img = img[:, :, np.newaxis]
                img = np.concatenate([img, img, img], axis=2)
            img = np.array(Image.fromarray(img).resize((im_size, im_size)))
            orig_img = img.copy()
            # HWC uint8 -> CHW float in [0, 1]
            img = img.transpose(2, 0, 1)
            img = img / 255.
            img = torch.FloatTensor(img).to(device)
            image = transform_test(img)  # (3, 32, 32)

            if opt.output_dir:
                file_prefix = os.path.join(
                    opt.output_dir,
                    os.path.splitext(os.path.basename(filename))[0])
            else:
                file_prefix = None

            batch = image[np.newaxis, :, :, :]
            __, c1, c2, c3 = model(batch)
            if display_fig:
                fig, axs = plt.subplots(1, 4)
                axs[0].imshow(orig_img)

            # Each deeper attention level is upsampled twice as much as the
            # previous one and is drawn in the next subplot.
            levels = ((c1, 1, "_c1.npy"), (c2, 2, "_c2.npy"),
                      (c3, 4, "_c3.npy"))
            for position, (attn_in, scale, suffix) in enumerate(levels, 1):
                if attn_in is None:
                    continue
                hm_file = None if file_prefix is None else file_prefix + suffix
                attn = vis_fun(img,
                               attn_in,
                               up_factor=min_up_factor * scale,
                               nrow=1,
                               hm_file=hm_file)
                if display_fig:
                    # .cpu() is a no-op on CPU tensors and avoids an error
                    # should the result live on the GPU.
                    axs[position].imshow(attn.cpu().numpy().transpose(
                        1, 2, 0))

            if display_fig:
                plt.show()
Exemple #4
0
def main():
    """Print the images the model classifies with the highest confidence.

    Loads the checkpoint ``opt.model`` into the network selected by
    ``opt.attn_mode``, runs every regular file in ``opt.image_dir`` through
    it and prints, for the ``opt.num_images`` most confident predictions,
    one line per image: ``<file name> <softmax probability> <class index>``.

    Raises:
        NotImplementedError: if ``opt.attn_mode`` is neither ``'before'``
            nor ``'after'``.
    """
    im_size = 32
    mean, std = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
    # No ToTensor() here: images are converted to tensors by hand below.
    transform_test = transforms.Compose([
        transforms.Normalize(mean, std)
    ])
    print('done')

    ## load network
    print('\nloading the network ...\n')
    # (linear attn) insert attention before or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size,
                             num_classes=100,
                             attention=True,
                             normalize_attn=opt.normalize_attn,
                             init='xavierUniform')
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size,
                            num_classes=100,
                            attention=True,
                            normalize_attn=opt.normalize_attn,
                            init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")
    print('done')

    ## load model
    print('\nloading the model ...\n')
    state_dict = torch.load(opt.model, map_location=str(device))
    # Remove the 'module.' prefix that nn.DataParallel adds to parameter
    # names -- but only where it is present, so checkpoints saved from a bare
    # network load as well.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    net.load_state_dict(state_dict)
    net = net.to(device)
    net.eval()
    print('done')

    model = net

    # (confidence, class index, file name) for every processed image
    results = []
    # The context manager releases the directory handle held by scandir.
    with torch.no_grad(), os.scandir(opt.image_dir) as entries:
        for img_file in entries:
            if not img_file.is_file():
                # Sub-directories cannot be read as images.
                continue
            ## load image
            img = imread(img_file.path)
            if len(img.shape) == 2:
                # Grey-scale input: replicate the channel three times.
                img = img[:, :, np.newaxis]
                img = np.concatenate([img, img, img], axis=2)
            img = np.array(Image.fromarray(img).resize((im_size, im_size)))
            # HWC uint8 -> CHW float in [0, 1]
            img = img.transpose(2, 0, 1)
            img = img / 255.
            img = torch.FloatTensor(img).to(device)
            image = transform_test(img)  # (3, im_size, im_size)

            batch = image[np.newaxis, :, :, :]
            pred, __, __, __ = model(batch)
            # Highest softmax probability and the class it belongs to.
            out, cls = torch.max(F.softmax(pred, dim=1), 1)
            results.append((out.item(), cls.item(), img_file.name))

    # Tuples sort by confidence first, so this is most-confident-first.
    sorted_results = sorted(results, reverse=True)
    print("\n".join(f"{result[2]} {result[0]} {result[1]}"
                    for result in sorted_results[:opt.num_images]))