コード例 #1
0
# Alternate training between two models. For each interval of
# `args.switch_interval` epochs, one model is trained while the other one
# ("gutout_model") is wrapped in BatchGradCam and passed to `train`.
# NOTE(review): `training_flag`, models, optimizers and schedulers are assumed
# to be initialised before this loop — confirm against the surrounding code.
for epoch in range(args.epochs * args.switch_interval):
    # Swap roles at the start of every interval (epoch 0 included, so the
    # training/gutout models are always bound before first use).
    # Parenthesisation matters: `%` binds tighter than `+`.
    if epoch % args.switch_interval == 0:
        if training_flag == "a":
            training_model = model_b
            gutout_model = model_a
            optimizer = optimizer_b
            training_flag = "b"
        else:
            training_model = model_a
            gutout_model = model_b
            optimizer = optimizer_a
            training_flag = "a"

    # Grad-CAM is computed on the model that is NOT being trained this epoch.
    grad_cam = BatchGradCam(
        model=gutout_model,
        feature_module=gutout_model.layer4,
        target_layer_names=["1"],
        use_cuda=args.use_cuda,
    )
    train_accuracy = train(training_model, grad_cam, criterion, optimizer,
                           train_loader, max_num_batches)
    test_acc = test(training_model, test_loader, max_num_batches)

    tqdm.write(training_flag + " test_acc: %.3f" % (test_acc))
    # NOTE(review): `row` is not written anywhere in this excerpt; presumably a
    # CSV logger call follows — confirm.
    row = {
        "epoch": str(epoch),
        "train_acc": str(train_accuracy),
        "test_acc": str(test_acc),
    }

    # NOTE(review): only scheduler_a is stepped here; any scheduler for model b
    # is never advanced in this excerpt — confirm.
    if training_flag == "a":
        scheduler_a.step()
コード例 #2
0
# Create the experiment directory "<date>_<time>-experiment_<dataset>_<model>"
# with a checkpoints/ sub-directory and a CSV log, then train a single model.
experiment_id = args.dataset + "_" + args.model
# Include the date, not just the time of day: with "%H-%M-%S" alone, a run
# started at the same clock second on another day reuses the directory name
# and the os.makedirs call below raises FileExistsError.
current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
experiment_dir = current_time + "-experiment_" + experiment_id

# No exist_ok here: an already-existing experiment directory raises
# FileExistsError instead of being reused.
os.makedirs(experiment_dir)
os.makedirs(os.path.join(experiment_dir, "checkpoints/"), exist_ok=True)
csv_filename = os.path.join(experiment_dir, experiment_id + ".csv")
csv_logger = CSVLogger(
    args=args, fieldnames=["epoch", "train_acc", "test_acc"], filename=csv_filename
)

best_acc = -1
# Grad-CAM on target layer "0" inside model.layer1 of the same model that is
# being trained.
grad_cam = BatchGradCam(
    model=model,
    feature_module=model.layer1,
    target_layer_names=["0"],
    use_cuda=args.use_cuda,
)

# run training loop
for epoch in range(args.epochs):
    train_accuracy = train(
        model, grad_cam, criterion, optimizer, train_loader, max_num_batches
    )
    test_acc = test(model, test_loader, max_num_batches)
    # NOTE(review): `is_best` is unused and `best_acc` is never updated in this
    # excerpt; presumably checkpointing/logging follows — confirm.
    is_best = test_acc > best_acc
    tqdm.write("test_acc: %.3f" % (test_acc))

    # PyTorch >= 1.4 form; older versions used scheduler.step(epoch).
    scheduler.step()
コード例 #3
0
            else:
                # switch to training model a
                training_model = model_a
                optimizer = optimizer_a
                scheduler = scheduler_a
                csv_logger = csv_logger_a
                best_acc = copy.copy(best_acc_a)
                training_flag = "a"

                # if model a is training, the model b is the gutout model
                gutout_model = model_b

        # create the gradCAM model
        grad_cam = BatchGradCam(
            model=gutout_model,
            feature_module=getattr(gutout_model, args.feature_module),
            target_layer_names=[args.target_layer_names],
            use_cuda=args.use_cuda,
        )

        # run the training loop on a single model
        print(f"running epoch with model: {training_flag}")
        best_acc = run_epoch(training_model,
                             grad_cam,
                             criterion,
                             optimizer,
                             scheduler,
                             csv_logger,
                             train_loader,
                             test_loader,
                             epoch,
                             best_acc,
コード例 #4
0
import numpy as np

sys.path.append(os.path.join(os.path.dirname(__file__)))
sys.path.append(os.path.join(os.path.dirname(__file__), "../../"))

from src.models.resnet import resnet18
from src.gutout.gutout_utils import BatchGradCam, get_gutout_samples, gutout_images

# Run BatchGradCam on one sample CIFAR-10 image with a trained ResNet-18
# checkpoint (CPU only) and derive a boolean gutout mask from it.
num_classes = 10
# Grad-CAM values at or below this threshold become True in the gutout mask.
GUTOUT_THRESHOLD = 0.9

model = resnet18(num_classes=num_classes)

path = "run/checkpoints/cifar10_resnet18_Epoch_45acc0.7484_.pth"
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(path, map_location="cpu"))
# Inference on a single image: switch layers such as BatchNorm/Dropout (if the
# model has any) to eval behaviour. Gradients, needed by Grad-CAM, stay enabled.
model.eval()

grad_cam = BatchGradCam(model=model,
                        feature_module=model.layer3,
                        target_layer_names=["0"],
                        use_cuda=False)

img_path = "sample_imgs_cifar10/plane.png"
img = cv2.imread(img_path, 1)  # 1 == cv2.IMREAD_COLOR: 3-channel BGR image
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; fail here with a clear message rather than inside resize.
    raise FileNotFoundError(f"could not read image: {img_path}")
# NOTE(review): pixels are raw BGR values in 0-255; confirm this matches the
# preprocessing the checkpoint was trained with.
img = np.float32(cv2.resize(img, (32, 32)))
img = np.expand_dims(img, 0)                     # (H, W, C) -> (1, H, W, C)
img = torch.from_numpy(img).permute(0, 3, 1, 2)  # -> (1, C, H, W)

mask = grad_cam(img).numpy()

print(mask)
gutout_mask = mask <= GUTOUT_THRESHOLD
img = img.numpy()
img = np.squeeze(img, axis=0)
gutout_mask = np.squeeze(gutout_mask, axis=0)