def main():
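    # NOTE: this snippet assumes `os` and `torch` are imported and that
    # `ResNet`, `create_data_loaders`, and `validate` are provided by the
    # surrounding sparse-to-dense.pytorch codebase.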
    global test_csv

    # evaluation mode
    evaluate_filepath = '/root/workspace/depth/sparse-to-dense.pytorch/results/uw_nyu.sparsifier=uar.samples=0.modality=rgb.arch=resnet50.decoder=upproj.criterion=l1.lr=0.01.bs=16.pretrained=True(old)'
    best_weights_path = os.path.join(evaluate_filepath, 'best_model.pkl')
    assert os.path.isfile(best_weights_path), \
        "=> no best weights found at '{}'".format(evaluate_filepath)
    print("=> loading best weights for Model '{}'".format(evaluate_filepath))

    val_loader = create_data_loaders()

    decoder = 'upproj'

    model = ResNet(layers=50,
                   decoder=decoder,
                   output_size=val_loader.dataset.output_size,
                   pretrained=False)
    model = model.cuda()
    model.load_state_dict(torch.load(best_weights_path))

    print("=> loaded best weights for Model")

    output_directory = os.path.join('results/uw_test', 'uw_test4')
    validate(val_loader, model, output_directory=output_directory)

def demo_from_best_model(resnet_layer, pretrained, num_classes, path):
    assert resnet_layer in (18, 50)
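    # NOTE: `device`, `val_loader`, and `save_confusion_matrix` are assumed to
    # be defined at module level in the source script.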

    net_best = ResNet(layer_num=resnet_layer, pretrained=pretrained, num_classes=num_classes)
    net_best = net_best.to(device)
    net_best.load_state_dict(torch.load(path))
    net_best.eval()
    best_acc = save_confusion_matrix(net_best, val_loader, 'backup_demo/cm_best.png')
    print('test_best_accuracy = %.2f' % best_acc)

def main():
    global test_csv

    # evaluation mode
    evaluate_filepath = '/root/workspace/depth/sparse-to-dense.pytorch/results/uw_nyu.sparsifier=uar.samples=0.modality=rgb.arch=resnet50.decoder=upproj.criterion=l1.lr=0.01.bs=16.pretrained=True(old)'
    best_weights_path = os.path.join(evaluate_filepath, 'best_model.pkl')
    assert os.path.isfile(best_weights_path), \
        "=> no best weights found at '{}'".format(evaluate_filepath)
    print(
        "=> loading best weights for SphereFCRN '{}'".format(evaluate_filepath))

    val_loader = create_data_loaders()

    decoder = 'upproj'

    model = ResNet(layers=50,
                   decoder=decoder,
                   output_size=val_loader.dataset.output_size,
                   pretrained=False)
    model = model.cuda()
    model.load_state_dict(torch.load(best_weights_path))
    # model.decoder.apply(weights_init)

    print("=> loaded best weights for SphereFCRN")

    # print(model)

    # create results folder, if not already exists
    output_directory = os.path.join('results', 'uw_test5')
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    test_csv = os.path.join(output_directory, 'test.csv')
    best_txt = os.path.join(output_directory, 'best.txt')

    # create a new csv file with only the header row *before* validation, so
    # rows appended by validate() are not wiped out afterwards
    # (`fieldnames` is assumed to be a module-level list of metric names)
    with open(test_csv, 'w') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

    result, img_merge = validate(val_loader, model, write_to_file=True)

    with open(best_txt, 'w') as txtfile:
        txtfile.write(
            "mse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n"
            .format(result.mse, result.rmse, result.absrel, result.lg10,
                    result.mae, result.delta1, result.gpu_time))
    if img_merge is not None:
        img_filename = output_directory + '/comparison_best.png'
        utils.save_image(img_merge, img_filename)
Example #4
def main():
    # load table data
    df_train = pd.read_csv("../input/train_curated.csv")
    df_noisy = pd.read_csv("../input/train_noisy.csv")
    df_test = pd.read_csv("../input/sample_submission.csv")
    labels = df_test.columns[1:].tolist()
    for label in labels:
        df_train[label] = df_train['labels'].apply(lambda x: label in x)
        df_noisy[label] = df_noisy['labels'].apply(lambda x: label in x)
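    # NOTE: `label in x` does substring matching on the comma-separated label
    # string; if one label name were a prefix of another, a split-based test
    # such as `label in x.split(',')` would be safer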

    df_train['path'] = "../input/mel128/train/" + df_train['fname']
    df_test['path'] = "../input/mel128/test/" + df_test['fname']
    df_noisy['path'] = "../input/mel128/noisy/" + df_noisy['fname']

    # fold splitting
    folds = list(
        KFold(n_splits=NUM_FOLD, shuffle=True,
              random_state=SEED).split(np.arange(len(df_train))))

    # build model
    model = ResNet(NUM_CLASS).cuda()

    # set generator
    dataset_noisy = MelDataset(df_noisy['path'], df_noisy[labels].values)
    noisy_loader = DataLoader(
        dataset_noisy,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
    )

    # predict
    preds_noisy = np.zeros(
        [NUM_FOLD, NUM_EPOCH // NUM_CYCLE,
         len(df_noisy), NUM_CLASS], np.float32)
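    # one prediction slot per (fold, cycle): each cyclic-LR snapshot saved
    # every NUM_CYCLE epochs is loaded below and ensembled over the noisy set
    # (`starttime` and the NUM_* / *_DIR constants are module-level settings)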
    for fold, (ids_train_split, ids_valid_split) in enumerate(folds):
        for cycle in range(NUM_EPOCH // NUM_CYCLE):
            print("fold: {} cycle: {}, sec: {:.1f}".format(
                fold + 1, cycle + 1,
                time.time() - starttime))
            model.load_state_dict(
                torch.load("{}/weight_fold_{}_epoch_{}.pth".format(
                    LOAD_DIR, fold + 1, NUM_CYCLE * (cycle + 1))))
            preds_noisy[fold, cycle] = predict(noisy_loader, model)

        np.save("{}/preds_noisy.npy".format(OUTPUT_DIR), preds_noisy)
Example #5
class viewpoint_classifier():
    def __init__(self, model, dataset_index=0, video_target=None):

        if args.video is None:
            self.video_target = video_target
            customset_train = CustomDataset(path=args.dataset_path, subset_type="training",
                                            dataset_index=dataset_index, video_target=video_target)
            customset_test = CustomDataset(path=args.dataset_path, subset_type="testing",
                                           dataset_index=dataset_index, video_target=video_target)

            self.trainloader = torch.utils.data.DataLoader(dataset=customset_train, batch_size=args.batch_size,
                                                           shuffle=True, num_workers=args.num_workers)
            self.testloader = torch.utils.data.DataLoader(dataset=customset_test, batch_size=args.batch_size,
                                                          shuffle=False, num_workers=args.num_workers)
        else:
            video_dataset = VideoDataset(video=args.video, batch_size=args.batch_size,
                                         frame_skip=int(args.frame_skip), image_folder=args.extract_frames_path,
                                         use_existing=args.use_existing_frames)

            self.videoloader = torch.utils.data.DataLoader(dataset=video_dataset, batch_size=1,
                                                           shuffle=False, num_workers=args.num_workers)

        if model == "alex":
            self.model = AlexNet()
        elif model == "vgg":
            self.model = VGG()
        elif model == "resnet":
            self.model = ResNet()

        if args.pretrained_model is not None:
            if not args.pretrained_finetuning:
                self.model.load_state_dict(torch.load(args.pretrained_model))
            else:
                print("DEBUG : Make it load only part of the resnet model")
                #print(self.model)
                #self.model.load_state_dict(torch.load(args.pretrained_model))
                #for param in self.model.parameters():
                #    param.requires_grad = False
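                # swap in a 1000-way head so the ImageNet-sized checkpoint
                # loads cleanly, then replace it with a fresh 3-way head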
                self.model.fc = nn.Linear(512, 1000)
                self.model.load_state_dict(torch.load(args.pretrained_model))
                self.model.fc = nn.Linear(512, 3)

        self.model.cuda()
        print("Using weight decay: ", args.weight_decay)
        self.optimizer = optim.SGD(self.model.parameters(), weight_decay=float(args.weight_decay),
                                   lr=0.01, momentum=0.9, nesterov=True)
        self.criterion = nn.CrossEntropyLoss().cuda()
Example #6
def main():
    global test_csv

    # evaluation mode
    evaluate_filepath = '/root/workspace/depth/sparse-to-dense.pytorch/results/uw_nyu.sparsifier=uar.samples=0.modality=rgb.arch=resnet50.decoder=upproj.criterion=l1.lr=0.01.bs=16.pretrained=True(old)'
    best_weights_path = os.path.join(evaluate_filepath, 'best_model.pkl')
    assert os.path.isfile(best_weights_path), \
        "=> no best weights found at '{}'".format(evaluate_filepath)
    print("=> loading best weights for model '{}'".format(evaluate_filepath))

    val_loader = create_data_loaders()

    decoder = 'upproj'

    model = ResNet(layers=50,
                   decoder=decoder,
                   output_size=val_loader.dataset.output_size,
                   pretrained=False)
    model = model.cuda()
    model.load_state_dict(torch.load(best_weights_path))
    # model.decoder.apply(weights_init)

    print("=> loaded best weights for model")

    # create results folder, if not already exists
    output_directory = os.path.join('results/uw_test', 'uw_test5')
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    best_txt = os.path.join(output_directory, 'best.txt')

    result = validate(val_loader, model, output_directory=output_directory)

    # write the best metrics to a text file
    with open(best_txt, 'w') as txtfile:
        txtfile.write("rmse={:.3f}\nabsrel={:.3f}\ndelta1={:.3f}\n".format(
            result[0], result[1], result[2]))
Example #7
class viewpoint_classifier():

    def weighted_sampling(self, dataset_index=0, path=None):

        if not os.path.isfile("./results/intermediate_data/sampling_weights_two_viewpoints.p"):
            customset_preprocess = CustomDatasetViewpoint(path=args.dataset_path, subset_type="training",
                                                          dataset_index=dataset_index, retrieve_images=False)
            self.processloader = torch.utils.data.DataLoader(dataset=customset_preprocess, batch_size=1,
                                                             shuffle=False, num_workers=int(args.num_workers))

            sample_views = []  # viewpoint label of each training sample

            for batch_idx, (imgs, label) in enumerate(self.processloader):
                sample_views.append(label.numpy()[0][0])

            class_presence = [0, 0]

            for view in sample_views:
                class_presence[view] += 1

            for i in range(2):
                class_presence[i] /= float(len(sample_views))

            class_weights = [0.0] * len(sample_views)
            for i in range(len(sample_views)):
                class_weights[i] = 1.0 / class_presence[sample_views[i]]
            m = 2.0 * len(sample_views)
            class_weights = [w / m for w in class_weights]
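            # each sample's weight is the inverse of its class frequency, so
            # both viewpoint classes are drawn with roughly equal probability;
            # the constant 1/(2N) rescale leaves relative probabilities intact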

            # Finished with sampler weighting
            sampler = torch.utils.data.sampler.WeightedRandomSampler(class_weights, len(self.processloader),
                                                                     replacement=True)
            pickle.dump(sampler, open("./results/intermediate_data/sampling_weights_two_viewpoints.p", "wb"))
        else:
            sampler = pickle.load(open("./results/intermediate_data/sampling_weights_two_viewpoints.p", "rb"))
        return sampler

    def __init__(self, model, dataset_index=0, path=None):

        self.sampler = self.weighted_sampling(dataset_index=dataset_index, path=path)

        customset_train = CustomDatasetViewpoint(path=path, subset_type="training", dataset_index=dataset_index)
        customset_test = CustomDatasetViewpoint(path=path, subset_type="testing", dataset_index=dataset_index)

        # NOTE: a DataLoader cannot combine a sampler with shuffle=True, so
        # shuffling is disabled on the weighted-sampler loader
        self.trainloader = torch.utils.data.DataLoader(pin_memory=True, dataset=customset_train,
                                                       sampler=self.sampler, batch_size=args.batch_size,
                                                       shuffle=False, num_workers=args.num_workers)
        self.trainloader_acc = torch.utils.data.DataLoader(dataset=customset_train, batch_size=args.batch_size,
                                                           shuffle=True, num_workers=args.num_workers)
        self.testloader_acc = torch.utils.data.DataLoader(dataset=customset_test, batch_size=args.batch_size,
                                                          shuffle=True, num_workers=args.num_workers)

        if model == "alex":
            self.model = AlexNet()
        elif model == "vgg":
            self.model = VGG(num_classes=2)
        elif model == "resnet":
            self.model = ResNet()

        if args.pretrained_model is not None:
            if args.pretrained_same_architecture:
                self.model.load_state_dict(torch.load(args.pretrained_model))
            else:
                if args.arch == "vgg":
                    # resize the final classifier layer to 1000 outputs so the
                    # ImageNet checkpoint loads, then swap in a 2-way layer
                    self.model.soft = None
                    classifier = list(self.model.classifier.children())
                    classifier.pop()
                    classifier.append(torch.nn.Linear(4096, 1000))
                    self.model.classifier = torch.nn.Sequential(*classifier)
                    self.model.load_state_dict(torch.load(args.pretrained_model))
                    classifier = list(self.model.classifier.children())
                    classifier.pop()
                    classifier.append(torch.nn.Linear(4096, 2))
                    self.model.classifier = torch.nn.Sequential(*classifier)
                    self.model.soft = nn.LogSoftmax(dim=1)
                else:
                    self.model.fc = nn.Linear(512, 1000)
                    self.model.load_state_dict(torch.load(args.pretrained_model))
                    self.model.fc = nn.Linear(512, 2)

        self.optimizer = optim.Adam(self.model.parameters(), weight_decay=float(args.weight_decay), lr=0.0001)
Example #8
    Linear_testloader = DataLoader(Linear_testset,
                                   batch_size=512,
                                   shuffle=False,
                                   num_workers=num_worker)

    # ========== [visualize] ==========
    if batch_size >= 64:
        visualize(Linear_trainloader, dir_log + '/' + 'visual.png')

    # ========== [device] =============
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # ========== [cnn model] ==========
    ckpt = torch.load(dir_ckpt + '/' + 'best.pt')
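    # the checkpoint is assumed to bundle the encoder weights under "cnn" and
    # the linear evaluation head under "clf"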
    model = ResNet(pretrain=False)
    model.load_state_dict(ckpt["cnn"])
    model.to(device)
    linear_clf = Linear_Classifier(classNum=10)
    linear_clf.load_state_dict(ckpt["clf"])
    linear_clf.to(device)
    # opt_clf = optim.SGD(linear_clf.parameters(),
    #                      lr=1e-2,
    #                      momentum=0.9,
    #                      weight_decay=5e-4
    #                      )
    opt_clf = optim.Adam(linear_clf.parameters(), lr=1e-2, weight_decay=5e-4)

    best_acc = 0.0
    for i in range(1, epoch + 1):
        Train_CLF(ep=i, fine_tune=True)
        Test_CLF(path=dir_ckpt + '/' + "best.pt")
Example #9
    model = ResNet(SEBasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


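# NOTE: `nn`, `ResNet`, and the SE blocks are assumed to be imported in the
# source file; `load_state_dict_from_url` comes from `torch.hub`.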
def se_resnet50(num_classes=1_000, pretrained=False):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    if pretrained:
        model.load_state_dict(
            load_state_dict_from_url(
                "https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl"
            ))
    return model


def se_resnet101(num_classes=1_000):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(SEBottleneck, [3, 4, 23, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model

Example #10
                'errors_fake_D': errors_fake_D,
                'errors_D': errors_D,
                'errors_G': errors_G,
                'losses_triplet': losses_triplet
            }, f)
else:
    print('Loading Existing Model.')
    model = ResNet(depth=50,
                   pretrained=False,
                   cut_at_pooling=False,
                   num_features=num_features,
                   norm=False,
                   dropout=0.5,
                   num_classes=datareader.num_class)
    model.load_state_dict(
        torch.load(osp.join(args.model_dir,
                            'best_triplet_%d.pth' % (hash_bit))))
    model.cuda()
''' ------------------------------- Testing --------------------------------- '''
if args.test:
    batch_size = args.triplet_batch_size
    ''' ============================= Testing ================================ '''
    model.eval()
    n_feat = hash_bit
    ''' Testing Query Features '''
    if args.dataset == 'cuhk03':
        prbX, galX, prbY, galY = datareader.read_pair_images(
            'test', need_augmentation=False)
    elif args.dataset == 'market1501':
        prbX, prbY, _ = datareader.read_images('query',
                                               need_augmentation=False,
Example #11
    net = MobileNetV2(num_classes=3).cuda()
    net._modules.get('features')[-1].register_forward_hook(hook_feature)
else:
    print("BackBone: ResNet18")
    net = ResNet(num_classes=3).cuda()
    net._modules.get('features')[-2].register_forward_hook(hook_feature)
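    # the forward hooks capture the last convolutional feature maps, which are
    # later combined with the classifier weights to render class activation
    # maps (CAM)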

optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4)


# load checkpoint
if RESUME:
    # epoch38-acc99.24812316894531-1586176538.pt
    print("===> Resuming from checkpoint.")
    assert os.path.isfile('checkpoint/epoch50-acc99.24812316894531-1586534447.pt'), 'Error: no checkpoint found!'
    net.load_state_dict(torch.load('checkpoint/epoch50-acc99.24812316894531-1586534447.pt'))

criterion = TCLoss(3)

# test and generate CAM video 
if EPOCH == 0:
    test(testloader, net, USE_CUDA, criterion, 0)
for epoch in range(1, EPOCH + 1):
    net = train(trainloader, net, USE_CUDA, epoch, EPOCH + 1, criterion, optimizer, time_consistency)
    test(testloader, net, USE_CUDA, criterion, epoch)
    # calculate average localization loss
    test_localization(net, features_blobs, classes)
# generate CAM and output videos
if CAM:
    emptyFolder('/userhome/30/yfyang/pytorch-CAM/result/CAM/CPowerSupply/*.jpg')
    emptyFolder('/userhome/30/yfyang/pytorch-CAM/result/CAM/CHardDrive/*.jpg')
Example #12
def main():
    # load table data
    df_train = pd.read_csv("../input/train_curated.csv")
    df_noisy = pd.read_csv("../input/train_noisy.csv")
    df_test = pd.read_csv("../input/sample_submission.csv")
    labels = df_test.columns[1:].tolist()
    for label in labels:
        df_train[label] = df_train['labels'].apply(lambda x: label in x)
        df_noisy[label] = df_noisy['labels'].apply(lambda x: label in x)

    df_train['path'] = "../input/mel128/train/" + df_train['fname']
    df_test['path'] = "../input/mel128/test/" + df_test['fname']
    df_noisy['path'] = "../input/mel128/noisy/" + df_noisy['fname']

    # calc sampling weight
    df_train['weight'] = 1
    df_noisy['weight'] = len(df_train) / len(df_noisy)
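    # curated samples keep weight 1 while noisy samples are scaled so the two
    # pools contribute equal total weight to the weighted sampler built later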

    # generate pseudo label with sharpening
    tmp = np.load("../input/pseudo_label/preds_noisy.npy").mean(axis=(0, 1))
    tmp = tmp**TEMPERATURE
    tmp = tmp / tmp.sum(axis=1)[:, np.newaxis]
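    # assuming TEMPERATURE > 1, exponentiating and renormalizing concentrates
    # probability mass on the most confident classes (pseudo-label sharpening)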
    df_noisy_pseudo = df_noisy.copy()
    df_noisy_pseudo[labels] = tmp

    # fold splitting
    folds = list(
        KFold(n_splits=NUM_FOLD, shuffle=True,
              random_state=SEED).split(np.arange(len(df_train))))
    folds_noisy = list(
        KFold(n_splits=NUM_FOLD, shuffle=True,
              random_state=SEED).split(np.arange(len(df_noisy))))

    # Training
    log_columns = [
        'epoch', 'bce', 'lwlrap', 'bce_noisy', 'lwlrap_noisy', 'semi_mse',
        'val_bce', 'val_lwlrap', 'time'
    ]
    for fold, (ids_train_split, ids_valid_split) in enumerate(folds):
        if fold + 1 not in FOLD_LIST:
            continue
        print("fold: {}".format(fold + 1))
        train_log = pd.DataFrame(columns=log_columns)

        # build model
        model = ResNet(NUM_CLASS).cuda()
        model.load_state_dict(
            torch.load("{}/weight_fold_{}_epoch_512.pth".format(
                LOAD_DIR, fold + 1)))

        # prepare data loaders
        df_train_fold = df_train.iloc[ids_train_split].reset_index(drop=True)
        dataset_train = MelDataset(
            df_train_fold['path'],
            df_train_fold[labels].values,
            crop=CROP_LENGTH,
            crop_mode='additional',
            crop_rate=CROP_RATE,
            mixup=True,
            freqmask=True,
            gain=True,
        )
        train_loader = DataLoader(
            dataset_train,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=1,
            pin_memory=True,
        )

        df_valid = df_train.iloc[ids_valid_split].reset_index(drop=True)
        dataset_valid = MelDataset(
            df_valid['path'],
            df_valid[labels].values,
        )
        valid_loader = DataLoader(
            dataset_valid,
            batch_size=1,
            shuffle=False,
            num_workers=1,
            pin_memory=True,
        )

        dataset_noisy = MelDataset(
            df_noisy['path'],
            df_noisy[labels].values,
            crop=CROP_LENGTH,
            crop_mode='additional',
            crop_rate=CROP_RATE,
            mixup=True,
            freqmask=True,
            gain=True,
        )
        noisy_loader = DataLoader(
            dataset_noisy,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=1,
            pin_memory=True,
        )
        noisy_itr = cycle(noisy_loader)

        df_semi = pd.concat([
            df_train.iloc[ids_train_split],
            df_noisy_pseudo.iloc[folds_noisy[fold][0]]
        ]).reset_index(drop=True)
        semi_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            df_semi['weight'].values, len(df_semi))
        dataset_semi = MelDataset(
            df_semi['path'],
            df_semi[labels].values,
            crop=CROP_LENGTH,
            crop_mode='additional',
            crop_rate=CROP_RATE,
            mixup=True,
            freqmask=True,
            gain=True,
        )
        semi_loader = DataLoader(
            dataset_semi,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=1,
            pin_memory=True,
            sampler=semi_sampler,
        )
        semi_itr = cycle(semi_loader)
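        # cycle() turns the noisy/semi loaders into infinite iterators so a
        # batch can be drawn from each alongside every curated train batch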

        # set optimizer and loss
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=LR[0])
        scheduler = CosineLR(optimizer,
                             step_size_min=LR[1],
                             t0=len(train_loader) * NUM_CYCLE,
                             tmult=1)
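        # CosineLR appears to anneal from LR[0] down to LR[1] and restart every
        # NUM_CYCLE epochs, matching the per-cycle snapshots saved below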

        # training
        for epoch in range(NUM_EPOCH):
            # train for one epoch
            bce, lwlrap, bce_noisy, lwlrap_noisy, mse_semi = train(
                (train_loader, noisy_itr, semi_itr), model, optimizer,
                scheduler, epoch)

            # evaluate on validation set
            val_bce, val_lwlrap = validate(valid_loader, model)

            # print log
            endtime = time.time() - starttime
            print("Epoch: {}/{} ".format(epoch + 1, NUM_EPOCH) +
                  "CE: {:.4f} ".format(bce) +
                  "LwLRAP: {:.4f} ".format(lwlrap) +
                  "Noisy CE: {:.4f} ".format(bce_noisy) +
                  "Noisy LWLRAP: {:.4f} ".format(lwlrap_noisy) +
                  "Semi MSE: {:.4f} ".format(mse_semi) +
                  "Valid CE: {:.4f} ".format(val_bce) +
                  "Valid LWLRAP: {:.4f} ".format(val_lwlrap) +
                  "sec: {:.1f}".format(endtime))

            # save log and weights
            train_log_epoch = pd.DataFrame(
                [[epoch + 1, bce, lwlrap, bce_noisy, lwlrap_noisy, mse_semi,
                  val_bce, val_lwlrap, endtime]],
                columns=log_columns)
            train_log = pd.concat([train_log, train_log_epoch])
            train_log.to_csv(
                "{}/train_log_fold{}.csv".format(OUTPUT_DIR, fold + 1),
                index=False)
            if (epoch + 1) % NUM_CYCLE == 0:
                torch.save(
                    model.state_dict(),
                    "{}/weight_fold_{}_epoch_{}.pth".format(
                        OUTPUT_DIR, fold + 1, epoch + 1))
Example #13
def initiate_cifar10(random_model=False):

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    trainset = datasets.CIFAR10(root='./data',
                                train=True,
                                download=True,
                                transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=128,
                                               shuffle=True,
                                               **kwargs)

    testset = datasets.CIFAR10(root='./data',
                               train=False,
                               download=True,
                               transform=transform)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=200,
                                              shuffle=False,
                                              **kwargs)

    classes = [
        'plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship',
        'truck'
    ]

    def show_images(images, labels):
        num_img = len(images)
        np_images = [img.numpy() for img in images]
        fig, axes = plt.subplots(nrows=1, ncols=num_img, figsize=(20, 45))

        for i, ax in enumerate(axes.flat):
            ax.set_axis_off()
            im = ax.imshow(np_images[i], vmin=0., vmax=1.)
            ax.set_title(f'{labels[i]}')
            plt.axis("off")

        fig.subplots_adjust(bottom=0.1,
                            top=0.9,
                            left=0.1,
                            right=0.8,
                            wspace=0.1,
                            hspace=0.25)

        plt.show()

    images, labels = next(iter(train_loader))  # .next() was removed in Python 3
    num_img_to_plot = 9
    images = [images[i].permute(1, 2, 0) for i in range(num_img_to_plot)]
    labels = [classes[i] for i in labels[:num_img_to_plot]]
    # show_images(images, labels)

    model = ResNet().to(device)
    model_2 = ResNet().to(device)
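    # per the checkpoint names below, `model` loads normally trained (NT)
    # weights and `model_2` loads RFGSM adversarially trained weights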
    if not random_model:
        # map_location=device covers both the CPU-only and the CUDA case
        model.load_state_dict(
            torch.load("checkpoints/cifar/resnet_NT_ep_100.pt",
                       map_location=device))
        model_2.load_state_dict(
            torch.load("checkpoints/cifar/resnet_RFGSM_eps_8_a_10_ep_100.pt",
                       map_location=device))
    model.eval()
    model_2.eval()
    test_loss, test_acc = test(model, test_loader)
    print(f'Clean \t loss: {test_loss:.4f} \t acc: {test_acc:.4f}')

    return model, model_2, train_loader, test_loader