Esempio n. 1
0
def check_ldp(model_name, dataset, perc):
    """Evaluate a saved reduced-data classifier on its validation loader.

    Builds a fastai ``ImageDataBunch`` from the ``new<perc>`` folder of
    ``dataset``. It then loads the ``model0.pt`` state dict saved for
    ``model_name`` under ``../saved_models/<dataset>/less_data<perc>/`` and
    returns the accuracy computed by ``_get_accuracy`` on ``data.valid_dl``.

    Args:
        model_name: One of 'resnet10', 'resnet14', 'resnet18', 'resnet20',
            'resnet26'.
        dataset: One of 'imagenette', 'cifar10', 'imagewoof'.
        perc: Subset identifier. It selects both the ``new<perc>`` data folder
            and the ``less_data<perc>`` checkpoint folder (presumably the
            percentage of training data kept -- TODO confirm).

    Returns:
        Whatever ``_get_accuracy(data.valid_dl, net)`` returns. Other code in
        this file multiplies that value by 100, which suggests a fraction in
        [0, 1] -- TODO confirm.

    Raises:
        ValueError: If ``dataset`` or ``model_name`` is not supported. This is
            checked before any download, data loading or model construction.
    """
    supported_datasets = ('imagenette', 'cifar10', 'imagewoof')
    supported_models = ('resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26')
    # Validate up front. Without this, an unknown name left `path` / `net`
    # unbound and surfaced as an UnboundLocalError, possibly after the
    # dataset had already been downloaded and loaded.
    if dataset not in supported_datasets:
        raise ValueError(
            f'unsupported dataset {dataset!r}; expected one of {supported_datasets}')
    if model_name not in supported_models:
        raise ValueError(
            f'unsupported model_name {model_name!r}; expected one of {supported_models}')

    if dataset == 'imagenette':
        path = untar_data(URLs.IMAGENETTE)
    elif dataset == 'cifar10':
        path = untar_data(URLs.CIFAR)
    else:  # 'imagewoof' (guaranteed by the validation above)
        path = untar_data(URLs.IMAGEWOOF)

    new_path = path / ('new' + str(perc))
    tfms = get_transforms(do_flip=False)

    # CIFAR-10 images are 32x32 and use their own normalisation statistics;
    # the ImageNet-derived datasets are resized to 224.
    if dataset == 'cifar10':
        sz, stats = 32, cifar_stats
    else:
        sz, stats = 224, imagenet_stats

    # NOTE(review): an earlier version of this function set a split name of
    # 'test' for cifar10 but never passed it here, so every dataset validates
    # on its 'val' folder. Confirm whether cifar10 should use 'test' instead.
    data = ImageDataBunch.from_folder(
        new_path, train='train', valid='val', test='test',
        bs=64, size=sz, ds_tfms=tfms,
    ).normalize(stats)

    # Keep the if/elif chain rather than a dict of constructors, so that only
    # the requested constructor name is looked up.
    if model_name == 'resnet10':
        net = resnet10(pretrained=False, progress=False)
    elif model_name == 'resnet14':
        net = resnet14(pretrained=False, progress=False)
    elif model_name == 'resnet18':
        net = resnet18(pretrained=False, progress=False)
    elif model_name == 'resnet20':
        net = resnet20(pretrained=False, progress=False)
    else:  # 'resnet26' (guaranteed by the validation above)
        net = resnet26(pretrained=False, progress=False)

    savename = ('../saved_models/' + dataset + '/less_data' + str(perc)
                + '/' + model_name + '_classifier/model0.pt')
    # Load onto CPU first so checkpoints saved on GPU also load on CPU-only hosts.
    net.load_state_dict(torch.load(savename, map_location='cpu'))

    # Previously this called net.cuda() unconditionally, which fails without a
    # GPU. Other code in this file guards with torch.cuda.is_available().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    # Evaluate in inference mode so BatchNorm uses running statistics and
    # Dropout is disabled (harmless if _get_accuracy also calls eval()).
    net.eval()

    return _get_accuracy(data.valid_dl, net)
            for i, (images, labels) in enumerate(data.valid_dl):
                if torch.cuda.is_available():
                    images = torch.autograd.Variable(images).cuda().float()
                    labels = torch.autograd.Variable(labels).cuda()
                else:
                    images = torch.autograd.Variable(images).float()
                    labels = torch.autograd.Variable(labels)

                # Forward pass
                y_pred = net(images)

                loss = F.cross_entropy(y_pred, labels)
                val.append(loss.item())

        val_loss = sum(val) / len(val)
        val_loss_list.append(val_loss)
        val_acc = _get_accuracy(data.valid_dl, net)
        experiment.log_metric("train_loss", train_loss)
        experiment.log_metric("val_loss", val_loss)
        experiment.log_metric("val_acc", val_acc * 100)

        print('epoch : ', epoch + 1,
              ' / ', hyper_params["num_epochs"], ' | TL : ',
              round(train_loss, 6), ' | VL : ', round(val_loss, 6), ' | VA : ',
              round(val_acc * 100, 6))

        if (val_acc * 100) > min_val:
            print('saving model')
            min_val = val_acc * 100
            torch.save(net.state_dict(), savename)
Esempio n. 3
0
            for i, (images, labels) in enumerate(data.valid_dl) :
                if torch.cuda.is_available():
                    images = torch.autograd.Variable(images).cuda().float()
                    labels = torch.autograd.Variable(labels).cuda()
                else : 
                    images = torch.autograd.Variable(images).float()
                    labels = torch.autograd.Variable(labels)

                # Forward pass
                outputs = net(images)
                loss = F.cross_entropy(outputs, labels)
                val.append(loss)

        val_loss = (sum(val) / len(val)).item()
        val_loss_list.append(val_loss)
        val_acc = _get_accuracy(data.valid_dl, net)
        val_acc_list.append(val_acc)
        print('epoch : ', epoch + 1, ' / ', hyper_params['num_epochs'], ' | TL : ', round(train_loss, 4), ' | VL : ', round(val_loss, 4), ' | VA : ', round(val_acc * 100, 6))
        
        if (val_acc * 100) > min_val :
            print('saving model')
            min_val = val_acc * 100
            torch.save(net.state_dict(), savename)

#     plt.plot(range(hyper_params['num_epochs']), train_loss_list, 'r', label = 'training_loss')
#     plt.plot(range(hyper_params['num_epochs']), val_loss_list, 'b', label = 'validation_loss')
#     plt.legend()
#     plt.savefig('../figures/' + str(hyper_params['dataset']) + '/resnet14_no_teacher/training_losses' + str(repeated) + '.jpg')
#     plt.close()

#     plt.plot(range(hyper_params['num_epochs']), val_acc_list, 'r', label = 'validation_accuracy')