Example #1
def trainval(exp_dict,
             savedir_base,
             reset,
             metrics_flag=True,
             datadir=None,
             cuda=False):
    # bookkeeping
    # ---------------

    # get experiment directory
    exp_id = hu.hash_dict(exp_dict)
    savedir = os.path.join(savedir_base, exp_id)

    if reset:
        # delete and backup experiment
        hc.delete_experiment(savedir, backup_flag=True)

    # create folder and save the experiment dictionary
    os.makedirs(savedir, exist_ok=True)
    hu.save_json(os.path.join(savedir, 'exp_dict.json'), exp_dict)
    pprint.pprint(exp_dict)
    print('Experiment saved in %s' % savedir)

    # set seed
    # ==================
    seed = 42 + exp_dict['runs']
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        device = 'cuda'
        torch.cuda.manual_seed_all(seed)
    else:
        device = 'cpu'

    print('Running on device: %s' % device)

    # Dataset
    # ==================
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                     train_flag=True,
                                     datadir=datadir,
                                     exp_dict=exp_dict)

    train_loader = DataLoader(train_set,
                              drop_last=True,
                              shuffle=True,
                              sampler=None,
                              batch_size=exp_dict["batch_size"])

    # Load Val Dataset
    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                   train_flag=False,
                                   datadir=datadir,
                                   exp_dict=exp_dict)

    # Model
    # ==================
    use_backpack = exp_dict['opt'].get("backpack", False)

    model = models.get_model(exp_dict["model"],
                             train_set=train_set,
                             backpack=use_backpack).to(device=device)
    if use_backpack:
        assert exp_dict['opt']['name'] in ['nus_wrapper', 'adaptive_second']
        from backpack import extend
        model = extend(model)

    # Choose loss and metric function
    loss_function = metrics.get_metric_function(exp_dict["loss_func"])

    # Load Optimizer
    # ==============
    n_batches_per_epoch = len(train_set) / float(exp_dict["batch_size"])
    opt = optimizers.get_optimizer(opt=exp_dict["opt"],
                                   params=model.parameters(),
                                   n_batches_per_epoch=n_batches_per_epoch,
                                   n_train=len(train_set),
                                   train_loader=train_loader,
                                   model=model,
                                   loss_function=loss_function,
                                   exp_dict=exp_dict,
                                   batch_size=exp_dict["batch_size"])

    # Checkpointing
    # =============
    score_list_path = os.path.join(savedir, "score_list.pkl")
    model_path = os.path.join(savedir, "model_state_dict.pth")
    opt_path = os.path.join(savedir, "opt_state_dict.pth")

    if os.path.exists(score_list_path):
        # resume experiment
        score_list = ut.load_pkl(score_list_path)
        if use_backpack:
            model.load_state_dict(torch.load(model_path), strict=False)
        else:
            model.load_state_dict(torch.load(model_path))
        opt.load_state_dict(torch.load(opt_path))
        s_epoch = score_list[-1]["epoch"] + 1
    else:
        # restart experiment
        score_list = []
        s_epoch = 0

    # Start Training
    # ==============
    n_train = len(train_loader.dataset)
    n_batches = len(train_loader)
    batch_size = train_loader.batch_size

    for epoch in range(s_epoch, exp_dict["max_epoch"]):
        # Set seed
        seed = epoch + exp_dict['runs']
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        score_dict = {"epoch": epoch}

        # Validate
        # --------
        if metrics_flag:
            # 1. Compute train loss over train set
            score_dict["train_loss"] = metrics.compute_metric_on_dataset(
                model,
                train_set,
                metric_name=exp_dict["loss_func"],
                batch_size=exp_dict['batch_size'])

            # 2. Compute val acc over val set
            score_dict["val_acc"] = metrics.compute_metric_on_dataset(
                model,
                val_set,
                metric_name=exp_dict["acc_func"],
                batch_size=exp_dict['batch_size'])

        # Train
        # -----
        model.train()
        print("%d - Training model with %s..." %
              (epoch, exp_dict["loss_func"]))

        s_time = time.time()

        train_on_loader(model, train_set, train_loader, opt, loss_function,
                        epoch, use_backpack)

        e_time = time.time()

        # Record step size and batch size
        score_dict["step"] = opt.state.get("step",
                                           0) / int(n_batches_per_epoch)
        score_dict["step_size"] = opt.state.get("step_size", {})
        score_dict["step_size_avg"] = opt.state.get("step_size_avg", {})
        score_dict["n_forwards"] = opt.state.get("n_forwards", {})
        score_dict["n_backwards"] = opt.state.get("n_backwards", {})
        score_dict["grad_norm"] = opt.state.get("grad_norm", {})
        score_dict["batch_size"] = batch_size
        score_dict["train_epoch_time"] = e_time - s_time
        score_dict.update(opt.state["gv_stats"])

        # Add score_dict to score_list
        score_list += [score_dict]

        # Report and save
        print(pd.DataFrame(score_list).tail())
        ut.save_pkl(score_list_path, score_list)
        ut.torch_save(model_path, model.state_dict())
        ut.torch_save(opt_path, opt.state_dict())
        print("Saved: %s" % savedir)

    return score_list
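
A rough usage sketch (not part of the original source): every exp_dict key below mirrors a read performed inside trainval, but the values, the optimizer name, and the paths are hypothetical placeholders.

# Hypothetical invocation of the trainval above.
exp_dict = {
    'dataset': 'mnist',
    'model': 'mlp',                   # assumed model name
    'loss_func': 'softmax_loss',      # assumed metric names
    'acc_func': 'softmax_accuracy',
    'opt': {'name': 'sgd_armijo'},    # assumed optimizer spec
    'batch_size': 128,
    'max_epoch': 100,
    'runs': 0,
}

score_list = trainval(exp_dict,
                      savedir_base='/tmp/results',  # placeholder path
                      reset=False,
                      datadir='/tmp/data',          # placeholder path
                      cuda=torch.cuda.is_available())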
Example #2
def trainval_svrg(exp_dict, savedir, datadir, metrics_flag=True):
    '''
        SVRG-specific training and validation loop.
    '''
    pprint.pprint(exp_dict)

    # Load Train Dataset
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                     train_flag=True,
                                     datadir=datadir,
                                     exp_dict=exp_dict)

    train_loader = DataLoader(train_set,
                              drop_last=False,
                              shuffle=True,
                              batch_size=exp_dict["batch_size"])

    # Load Val Dataset
    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                   train_flag=False,
                                   datadir=datadir,
                                   exp_dict=exp_dict)

    # Load model
    model = models.get_model(exp_dict["model"], train_set=train_set).cuda()

    # Choose loss and metric function
    loss_function = metrics.get_metric_function(exp_dict["loss_func"])

    # lookup the learning rate
    lr = get_svrg_step_size(exp_dict)

    # Load Optimizer
    opt = get_svrg_optimizer(model,
                             loss_function,
                             train_loader=train_loader,
                             lr=lr)

    # Resume from last saved state_dict
    if (not os.path.exists(savedir + "/run_dict.pkl")
            or not os.path.exists(savedir + "/score_list.pkl")):
        ut.save_pkl(savedir + "/run_dict.pkl", {"running": 1})
        score_list = []
        s_epoch = 0
    else:
        score_list = ut.load_pkl(savedir + "/score_list.pkl")
        model.load_state_dict(torch.load(savedir + "/model_state_dict.pth"))
        opt.load_state_dict(torch.load(savedir + "/opt_state_dict.pth"))
        s_epoch = score_list[-1]["epoch"] + 1

    for epoch in range(s_epoch, exp_dict["max_epoch"]):
        score_dict = {"epoch": epoch}

        if metrics_flag:
            # 1. Compute train loss over train set
            score_dict["train_loss"] = metrics.compute_metric_on_dataset(
                model, train_set, metric_name=exp_dict["loss_func"])

            # 2. Compute val acc over val set
            score_dict["val_acc"] = metrics.compute_metric_on_dataset(
                model, val_set, metric_name=exp_dict["acc_func"])

        # 3. Train over train loader
        model.train()
        print("%d - Training model with %s..." %
              (epoch, exp_dict["loss_func"]))

        s_time = time.time()
        for images, labels in tqdm.tqdm(train_loader):
            images, labels = images.cuda(), labels.cuda()

            opt.zero_grad()
            closure = lambda svrg_model: loss_function(
                svrg_model, images, labels, backwards=True)
            opt.step(closure)

        e_time = time.time()

        # Record step size and batch size
        score_dict["step_size"] = opt.state["step_size"]
        score_dict["batch_size"] = train_loader.batch_size
        score_dict["train_epoch_time"] = e_time - s_time

        # Add score_dict to score_list
        score_list += [score_dict]

        # Report and save
        print(pd.DataFrame(score_list).tail())
        ut.save_pkl(savedir + "/score_list.pkl", score_list)
        ut.torch_save(savedir + "/model_state_dict.pth", model.state_dict())
        ut.torch_save(savedir + "/opt_state_dict.pth", opt.state_dict())
        print("Saved: %s" % savedir)

    return score_list
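
The closure above takes the model as an argument, presumably so the SVRG optimizer can evaluate the same minibatch on both the current iterate and its stored snapshot. PyTorch's built-in closure-based optimizers use a zero-argument closure instead; for comparison, a self-contained LBFGS example of that standard pattern (plain PyTorch API, unrelated to the SVRG wrapper):

import torch
import torch.nn.functional as F

x, y = torch.randn(32, 4), torch.randn(32, 1)
model = torch.nn.Linear(4, 1)
opt = torch.optim.LBFGS(model.parameters(), lr=0.1)

def closure():
    # LBFGS may call this several times per step to evaluate trial points
    opt.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)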
Example #3
def get_dataset(dataset_name, train_flag, datadir, exp_dict):
    if dataset_name == "mnist":
        dataset = torchvision.datasets.MNIST(
            datadir,
            train=train_flag,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.5, ), (0.5, ))
            ]))

    if dataset_name == "cifar10":
        transform_function = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        dataset = torchvision.datasets.CIFAR10(root=datadir,
                                               train=train_flag,
                                               download=True,
                                               transform=transform_function)

    if dataset_name == "cifar100":
        transform_function = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        dataset = torchvision.datasets.CIFAR100(root=datadir,
                                                train=train_flag,
                                                download=True,
                                                transform=transform_function)

    if dataset_name in ["mushrooms", "w8a", "rcv1", "ijcnn"]:

        sigma_dict = {
            "mushrooms": 0.5,
            "w8a": 20.0,
            "rcv1": 0.25,
            "ijcnn": 0.05
        }

        X, y = load_libsvm(dataset_name, data_dir=datadir)

        labels = np.unique(y)

        y[y == labels[0]] = 0
        y[y == labels[1]] = 1
        # TODO: (amishkin) splits = train_test_split(X, y, test_size=0.2, shuffle=True, random_state=9513451)
        splits = train_test_split(X,
                                  y,
                                  test_size=0.2,
                                  shuffle=False,
                                  random_state=42)
        X_train, X_test, Y_train, Y_test = splits

        if train_flag:
            # fname_rbf = "%s/rbf_%s_train.pkl" % (datadir, dataset_name)

            # if os.path.exists(fname_rbf):
            #     k_train_X = ut.load_pkl(fname_rbf)
            # else:
            k_train_X = rbf_kernel(X_train, X_train, sigma_dict[dataset_name])
            # ut.save_pkl(fname_rbf, k_train_X)

            X_train = k_train_X
            X_train = torch.FloatTensor(X_train)
            Y_train = torch.FloatTensor(Y_train)

            dataset = torch.utils.data.TensorDataset(X_train, Y_train)

        else:
            # fname_rbf = "%s/rbf_%s_test.pkl" % (datadir, dataset_name)
            # if os.path.exists(fname_rbf):
            #     k_test_X = ut.load_pkl(fname_rbf)
            # else:
            k_test_X = rbf_kernel(X_test, X_train, sigma_dict[dataset_name])
            # ut.save_pkl(fname_rbf, k_test_X)

            X_test = k_test_X
            X_test = torch.FloatTensor(X_test)
            Y_test = torch.FloatTensor(Y_test)

            dataset = torch.utils.data.TensorDataset(X_test, Y_test)

        return dataset

    if dataset_name == "synthetic":
        margin = exp_dict["margin"]

        X, y, _, _ = make_binary_linear(n=exp_dict["n_samples"],
                                        d=exp_dict["d"],
                                        margin=margin,
                                        y01=True,
                                        bias=True,
                                        separable=True,
                                        seed=42)
        # No shuffling to keep the support vectors inside the training set
        splits = train_test_split(X,
                                  y,
                                  test_size=0.2,
                                  shuffle=False,
                                  random_state=42)
        X_train, X_test, Y_train, Y_test = splits

        X_train = torch.FloatTensor(X_train)
        X_test = torch.FloatTensor(X_test)

        Y_train = torch.FloatTensor(Y_train)
        Y_test = torch.FloatTensor(Y_test)

        if train_flag:
            dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        else:
            dataset = torch.utils.data.TensorDataset(X_test, Y_test)

        return dataset

    if dataset_name == "matrix_fac":
        fname = os.path.join(datadir, 'matrix_fac.pkl')
        if not os.path.exists(fname):
            data = generate_synthetic_matrix_factorization_data()
            ut.save_pkl(fname, data)

        A, y = ut.load_pkl(fname)

        X_train, X_test, y_train, y_test = train_test_split(
            A, y, test_size=0.2, random_state=9513451)

        training_set = torch.utils.data.TensorDataset(
            torch.tensor(X_train, dtype=torch.float),
            torch.tensor(y_train, dtype=torch.float))
        test_set = torch.utils.data.TensorDataset(
            torch.tensor(X_test, dtype=torch.float),
            torch.tensor(y_test, dtype=torch.float))

        if train_flag:
            dataset = training_set
        else:
            dataset = test_set

    return dataset
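
rbf_kernel is imported from elsewhere in the project. A minimal NumPy sketch consistent with the call sites above, assuming dense inputs and the common parameterization K(a, b) = exp(-||a - b||^2 / (2 * sigma^2)); the project's version may scale sigma differently or handle sparse matrices.

import numpy as np

def rbf_kernel_sketch(A, B, sigma):
    # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b, no Python loops
    sq_dists = (np.sum(A ** 2, axis=1)[:, None]
                + np.sum(B ** 2, axis=1)[None, :]
                - 2.0 * A @ B.T)
    return np.exp(-sq_dists / (2.0 * sigma ** 2))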
Example #4
File: datasets.py, Project: kiminh/ada_sls
def get_dataset(dataset_name, train_flag, datadir, exp_dict):
    if dataset_name in ['B', 'C']:
        bias = 1
        scaling = 10
        sparsity = 10
        solutionSparsity = 0.1
        n = 1000

        if dataset_name == 'C':
            p = 100
        if dataset_name == 'B':
            p = 10000

        A = np.random.randn(n, p) + bias
        A = A.dot(np.diag(scaling * np.random.randn(p)))
        A = A * (np.random.rand(n, p) < (sparsity * np.log(n) / n))
        w = np.random.randn(p) * (np.random.rand(p) < solutionSparsity)

        b = np.sign(A.dot(w))
        b = b * np.sign(np.random.rand(n) - 0.1)
        labels = np.unique(b)
        A = A / np.linalg.norm(A, axis=1)[:, None].clip(min=1e-6)
        A = A * 2
        b[b == labels[0]] = 0
        b[b == labels[1]] = 1

        dataset = torch.utils.data.TensorDataset(torch.FloatTensor(A),
                                                 torch.FloatTensor(b))

        return DatasetWrapper(dataset)

    if dataset_name == 'tiny_imagenet':
        if train_flag:
            transform_train = transforms.Compose([
                # transforms.Resize(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            # define dataloader
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_train)

        else:
            transform_test = transforms.Compose([
                # transforms.Resize(32),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_test)

    if dataset_name == 'imagenette2-160':
        if train_flag:
            transform_train = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                # transforms.RandomResizedCrop(224),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            # define dataloader
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_train)

        else:
            transform_test = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_test)

    if dataset_name == 'imagewoof2-160':
        if train_flag:
            transform_train = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                # transforms.RandomResizedCrop(224),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            # define dataloader
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_train)

        else:
            transform_test = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_test)

    if dataset_name == "mnist":
        view = torchvision.transforms.Lambda(lambda x: x.view(784))  # flatten to a 784-vector
        dataset = torchvision.datasets.MNIST(
            datadir,
            train=train_flag,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.5, ), (0.5, )), view
            ]))

    if dataset_name == "cifar10":
        transform_function = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        dataset = torchvision.datasets.CIFAR10(root=datadir,
                                               train=train_flag,
                                               download=True,
                                               transform=transform_function)

    if dataset_name == "cifar100":
        transform_function = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        dataset = torchvision.datasets.CIFAR100(root=datadir,
                                                train=train_flag,
                                                download=True,
                                                transform=transform_function)

    if dataset_name in [
            "mushrooms", "w8a", "rcv1", "ijcnn", 'a1a', 'a2a',
            "mushrooms_convex", "w8a_convex", "rcv1_convex", "ijcnn_convex",
            'a1a_convex', 'a2a_convex'
    ]:

        sigma_dict = {
            "mushrooms": 0.5,
            "w8a": 20.0,
            "rcv1": 0.25,
            "ijcnn": 0.05
        }

        X, y = load_libsvm(dataset_name.replace('_convex', ''),
                           data_dir=datadir)

        labels = np.unique(y)

        y[y == labels[0]] = 0
        y[y == labels[1]] = 1
        # splits used in experiments
        splits = train_test_split(X,
                                  y,
                                  test_size=0.2,
                                  shuffle=True,
                                  random_state=9513451)
        X_train, X_test, Y_train, Y_test = splits

        if "_convex" in dataset_name:
            if train_flag:
                # training set
                X_train = torch.FloatTensor(X_train.toarray())
                Y_train = torch.FloatTensor(Y_train)
                dataset = torch.utils.data.TensorDataset(X_train, Y_train)
            else:
                # test set
                X_test = torch.FloatTensor(X_test.toarray())
                Y_test = torch.FloatTensor(Y_test)
                dataset = torch.utils.data.TensorDataset(X_test, Y_test)

            return DatasetWrapper(dataset)

        if train_flag:
            # fname_rbf = "%s/rbf_%s_%s_train.pkl" % (datadir, dataset_name, sigma_dict[dataset_name])
            fname_rbf = "%s/rbf_%s_%s_train.npy" % (datadir, dataset_name,
                                                    sigma_dict[dataset_name])
            if os.path.exists(fname_rbf):
                k_train_X = np.load(fname_rbf)
            else:
                k_train_X = rbf_kernel(X_train, X_train,
                                       sigma_dict[dataset_name])
                np.save(fname_rbf, k_train_X)
                print('%s saved' % fname_rbf)

            X_train = k_train_X
            X_train = torch.FloatTensor(X_train)
            Y_train = torch.LongTensor(Y_train)

            dataset = torch.utils.data.TensorDataset(X_train, Y_train)

        else:
            fname_rbf = "%s/rbf_%s_%s_test.npy" % (datadir, dataset_name,
                                                   sigma_dict[dataset_name])
            if os.path.exists(fname_rbf):
                k_test_X = np.load(fname_rbf)
            else:
                k_test_X = rbf_kernel(X_test, X_train,
                                      sigma_dict[dataset_name])
                np.save(fname_rbf, k_test_X)
                print('%s saved' % fname_rbf)

            X_test = k_test_X
            X_test = torch.FloatTensor(X_test)
            Y_test = torch.LongTensor(Y_test)

            dataset = torch.utils.data.TensorDataset(X_test, Y_test)

    if dataset_name == "synthetic":
        margin = exp_dict["margin"]

        X, y, _, _ = make_binary_linear(n=exp_dict["n_samples"],
                                        d=exp_dict["d"],
                                        margin=margin,
                                        y01=True,
                                        bias=True,
                                        separable=exp_dict.get(
                                            "separable", True),
                                        seed=42)
        # No shuffling to keep the support vectors inside the training set
        splits = train_test_split(X,
                                  y,
                                  test_size=0.2,
                                  shuffle=False,
                                  random_state=42)
        X_train, X_test, Y_train, Y_test = splits

        X_train = torch.FloatTensor(X_train)
        X_test = torch.FloatTensor(X_test)

        Y_train = torch.LongTensor(Y_train)
        Y_test = torch.LongTensor(Y_test)

        if train_flag:
            dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        else:
            dataset = torch.utils.data.TensorDataset(X_test, Y_test)

    if dataset_name == "matrix_fac":
        fname = datadir + 'matrix_fac.pkl'
        if not os.path.exists(fname):
            data = generate_synthetic_matrix_factorization_data()
            ut.save_pkl(fname, data)

        A, y = ut.load_pkl(fname)

        X_train, X_test, y_train, y_test = train_test_split(
            A, y, test_size=0.2, random_state=9513451)

        training_set = torch.utils.data.TensorDataset(
            torch.tensor(X_train, dtype=torch.float),
            torch.tensor(y_train, dtype=torch.float))
        test_set = torch.utils.data.TensorDataset(
            torch.tensor(X_test, dtype=torch.float),
            torch.tensor(y_test, dtype=torch.float))

        if train_flag:
            dataset = training_set
        else:
            dataset = test_set

    return DatasetWrapper(dataset)
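
DatasetWrapper is likewise defined elsewhere in the project. A hypothetical minimal version consistent with how it is used here: it forwards indexing and length, and attaches the sample index, which per-sample step-size methods typically need.

import torch

class DatasetWrapper(torch.utils.data.Dataset):
    # Illustrative sketch only; the real class in ada_sls may differ.
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        data, target = self.dataset[index]
        return {'images': data, 'labels': target, 'meta': {'index': index}}

    def __len__(self):
        return len(self.dataset)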
Example #5
def trainval(exp_dict, savedir, datadir, metrics_flag=True):
    # TODO: Do we get similar results with different seeds?
    # Set seed
    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    pprint.pprint(exp_dict)

    # Load Train Dataset
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                     train_flag=True,
                                     datadir=datadir,
                                     exp_dict=exp_dict)

    train_loader = DataLoader(train_set,
                              drop_last=True,
                              shuffle=True,
                              batch_size=exp_dict["batch_size"])

    # Load Val Dataset
    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                   train_flag=False,
                                   datadir=datadir,
                                   exp_dict=exp_dict)

    # Load model
    model = models.get_model(exp_dict["model"], train_set=train_set).cuda()

    # Choose loss and metric function
    loss_function = metrics.get_metric_function(exp_dict["loss_func"])

    # Load Optimizer
    n_batches_per_epoch = len(train_set) / float(exp_dict["batch_size"])
    opt = optimizers.get_optimizer(opt=exp_dict["opt"],
                                   params=model.parameters(),
                                   n_batches_per_epoch=n_batches_per_epoch)

    # Resume from last saved state_dict
    if (not os.path.exists(savedir + "/run_dict.pkl")
            or not os.path.exists(savedir + "/score_list.pkl")):
        ut.save_pkl(savedir + "/run_dict.pkl", {"running": 1})
        score_list = []
        s_epoch = 0
    else:
        score_list = ut.load_pkl(savedir + "/score_list.pkl")
        model.load_state_dict(torch.load(savedir + "/model_state_dict.pth"))
        opt.load_state_dict(torch.load(savedir + "/opt_state_dict.pth"))
        s_epoch = score_list[-1]["epoch"] + 1

    for epoch in range(s_epoch, exp_dict["max_epoch"]):
        # Set seed
        np.random.seed(epoch)
        torch.manual_seed(epoch)
        torch.cuda.manual_seed_all(epoch)

        score_dict = {"epoch": epoch}

        if metrics_flag:
            # 1. Compute train loss over train set
            score_dict["train_loss"] = metrics.compute_metric_on_dataset(
                model, train_set, metric_name=exp_dict["loss_func"])

            # 2. Compute val acc over val set
            score_dict["val_acc"] = metrics.compute_metric_on_dataset(
                model, val_set, metric_name=exp_dict["acc_func"])

        # 3. Train over train loader
        model.train()
        print("%d - Training model with %s..." %
              (epoch, exp_dict["loss_func"]))

        s_time = time.time()
        for images, labels in tqdm.tqdm(train_loader):
            images, labels = images.cuda(), labels.cuda()

            opt.zero_grad()

            if exp_dict["opt"]["name"] in exp_configs.ours_opt_list + ["l4"]:
                closure = lambda: loss_function(
                    model, images, labels, backwards=False)
                opt.step(closure)

            else:
                loss = loss_function(model, images, labels)
                loss.backward()
                opt.step()

        e_time = time.time()

        # Record step size and batch size
        score_dict["step_size"] = opt.state["step_size"]
        score_dict["n_forwards"] = opt.state["n_forwards"]
        score_dict["n_backwards"] = opt.state["n_backwards"]
        score_dict["batch_size"] = train_loader.batch_size
        score_dict["train_epoch_time"] = e_time - s_time

        # Add score_dict to score_list
        score_list += [score_dict]

        # Report and save
        print(pd.DataFrame(score_list).tail())
        ut.save_pkl(savedir + "/score_list.pkl", score_list)
        ut.torch_save(savedir + "/model_state_dict.pth", model.state_dict())
        ut.torch_save(savedir + "/opt_state_dict.pth", opt.state_dict())
        print("Saved: %s" % savedir)

    return score_list
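
Re-seeding NumPy and torch at the top of every epoch makes the DataLoader's shuffle order a pure function of the epoch number, so a run resumed from a checkpoint replays the same batch sequence it would have seen uninterrupted. A small self-contained check of that property:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))

def epoch_order(epoch):
    torch.manual_seed(epoch)  # same per-epoch seeding as in the loop above
    return [int(x) for (x,) in DataLoader(dataset, shuffle=True)]

assert epoch_order(3) == epoch_order(3)  # identical across "resumes"
print(epoch_order(3), epoch_order(4))    # orders typically differ per epoch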
Example #6
def trainval(exp_dict, savedir, args):
    # Set seed and device
    # ===================
    seed = 42 + exp_dict['runs']
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda:
        device = 'cuda'
        torch.cuda.manual_seed_all(seed)
        assert torch.cuda.is_available(), \
            'CUDA is not available, please run with "-c 0"'
    else:
        device = 'cpu'

    print('Running on device: %s' % device)

    # Load Datasets
    # ==================
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                     split='train',
                                     datadir=args.datadir,
                                     exp_dict=exp_dict)

    train_loader = DataLoader(train_set,
                              drop_last=True,
                              shuffle=True,
                              sampler=None,
                              batch_size=exp_dict["batch_size"])

    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                   split='val',
                                   datadir=args.datadir,
                                   exp_dict=exp_dict)

    # Load Model
    # ==================
    model = models.get_model(train_loader, exp_dict, device=device)
    model_path = os.path.join(savedir, "model.pth")
    score_list_path = os.path.join(savedir, "score_list.pkl")

    if os.path.exists(score_list_path):
        # resume experiment
        score_list = ut.load_pkl(score_list_path)
        model.set_state_dict(torch.load(model_path))
        s_epoch = score_list[-1]["epoch"] + 1
    else:
        # restart experiment
        score_list = []
        s_epoch = 0

    # Train and Val
    # ==============
    for epoch in range(s_epoch, exp_dict["max_epoch"]):
        # Set seed
        seed = epoch + exp_dict.get('runs', 0)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Train one epoch
        s_time = time.time()
        model.train_on_loader(train_loader)
        e_time = time.time()

        # Validate one epoch
        train_loss_dict = model.val_on_dataset(train_set,
                                               metric=exp_dict["loss_func"],
                                               name='loss')
        val_acc_dict = model.val_on_dataset(val_set,
                                            metric=exp_dict["acc_func"],
                                            name='score')

        # Record metrics
        score_dict = {"epoch": epoch}
        score_dict.update(train_loss_dict)
        score_dict.update(val_acc_dict)
        score_dict["step_size"] = model.opt.state.get("step_size", {})
        score_dict["step_size_avg"] = model.opt.state.get("step_size_avg", {})
        score_dict["n_forwards"] = model.opt.state.get("n_forwards", {})
        score_dict["n_backwards"] = model.opt.state.get("n_backwards", {})
        score_dict["grad_norm"] = model.opt.state.get("grad_norm", {})
        score_dict["train_epoch_time"] = e_time - s_time
        score_dict.update(model.opt.state["gv_stats"])

        # Add score_dict to score_list
        score_list += [score_dict]

        # Report and save
        print(pd.DataFrame(score_list).tail())
        ut.save_pkl(score_list_path, score_list)
        ut.torch_save(model_path, model.get_state_dict())
        print("Saved: %s" % savedir)
Example #7
def trainval_pls(exp_dict, savedir, datadir, metrics_flag=True):
    '''
        PLS-specific training and validation loop.
    '''
    pprint.pprint(exp_dict)

    # Load Train Dataset
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                     train_flag=True,
                                     datadir=datadir,
                                     exp_dict=exp_dict)

    train_loader = DataLoader(train_set,
                              drop_last=False,
                              shuffle=True,
                              batch_size=exp_dict["batch_size"])

    # Load Val Dataset
    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                   train_flag=False,
                                   datadir=datadir,
                                   exp_dict=exp_dict)

    # Load model
    model = models.get_model(exp_dict["model"], train_set=train_set).cuda()

    # Choose loss and metric function
    if exp_dict["loss_func"] == 'logistic_loss':
        loss_function = logistic_loss_grad_moments
    else:
        raise ValueError("PLS only supports the logistic loss.")

    # Load Optimizer
    opt = pls.PLS(model,
                  exp_dict["max_epoch"],
                  exp_dict["batch_size"],
                  expl_policy='exponential')

    # Resume from last saved state_dict
    if (not os.path.exists(savedir + "/run_dict.pkl")
            or not os.path.exists(savedir + "/score_list.pkl")):
        ut.save_pkl(savedir + "/run_dict.pkl", {"running": 1})
        score_list = []
        s_epoch = 0
    else:
        score_list = ut.load_pkl(savedir + "/score_list.pkl")
        model.load_state_dict(torch.load(savedir + "/model_state_dict.pth"))
        opt.load_state_dict(torch.load(savedir + "/opt_state_dict.pth"))
        s_epoch = score_list[-1]["epoch"] + 1

    # PLS-specific tracking for iterations and epochs:
    epoch = s_epoch
    iter_num = 0
    iters_per_epoch = math.ceil(
        len(train_loader.dataset) / exp_dict['batch_size'])
    new_epoch = True

    while epoch < exp_dict["max_epoch"]:
        for images, labels in tqdm.tqdm(train_loader):
            # Record metrics at the start of a new epoch
            if new_epoch:
                new_epoch = False
                score_dict = {"epoch": epoch}

                if metrics_flag:
                    # 1. Compute train loss over train set
                    score_dict["train_loss"] = metrics.compute_metric_on_dataset(
                        model, train_set, metric_name=exp_dict["loss_func"])

                    # 2. Compute val acc over val set
                    score_dict["val_acc"] = metrics.compute_metric_on_dataset(
                        model, val_set, metric_name=exp_dict["acc_func"])

                # 3. Train over train loader
                model.train()
                print("%d - Training model with %s..." %
                      (epoch, exp_dict["loss_func"]))

                s_time = time.time()

            images, labels = images.cuda(), labels.cuda()
            closure = grad_moment_closure_factory(model, images, labels,
                                                  loss_function)

            # For PLS, a call to opt.step() does not correspond to a single
            # optimizer step. Instead, it performs one evaluation in the line
            # search, which may or may not be accepted.
            opt.step(closure)

            # Epoch and iteration tracking.
            if opt.state['complete']:
                iter_num = iter_num + 1

                # potentially increment the epoch counter.
                if iter_num % iters_per_epoch == 0:
                    epoch = epoch + 1
                    new_epoch = True

            # compute metrics at end of previous epoch
            if new_epoch:
                e_time = time.time()

                # Record step size and batch size
                score_dict["step_size"] = opt.state["step_size"]
                score_dict["batch_size"] = train_loader.batch_size
                score_dict["train_epoch_time"] = e_time - s_time

                # Add score_dict to score_list
                score_list += [score_dict]

                # Report and save
                print(pd.DataFrame(score_list).tail())
                ut.save_pkl(savedir + "/score_list.pkl", score_list)
                ut.torch_save(savedir + "/model_state_dict.pth",
                              model.state_dict())
                ut.torch_save(savedir + "/opt_state_dict.pth",
                              opt.state_dict())
                print("Saved: %s" % savedir)

    return score_list
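
PLS here is a probabilistic line search, so each opt.step(closure) is one trial evaluation that the search may accept or reject. As a purely illustrative analogue of that accept/reject semantics, a deterministic backtracking (Armijo) line search on a NumPy parameter vector; this is not the PLS algorithm itself.

import numpy as np

def armijo_step(w, loss_and_grad, lr0=1.0, c=1e-4, beta=0.5):
    # One "step" is a loop of trial evaluations, each accepted or rejected.
    loss, grad = loss_and_grad(w)
    lr = lr0
    while True:
        trial = w - lr * grad
        trial_loss, _ = loss_and_grad(trial)
        if trial_loss <= loss - c * lr * (grad @ grad):  # sufficient decrease
            return trial, lr                             # accepted
        lr *= beta                                       # rejected: shrink

# Example on a quadratic; loss_and_grad returns (f(w), grad f(w))
quad = lambda w: (0.5 * w @ w, w)
w, lr = armijo_step(np.array([3.0, -2.0]), quad)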