def trainval(exp_dict, savedir_base, reset, metrics_flag=True, datadir=None,
             cuda=False):
    # Bookkeeping
    # ---------------
    # Get the experiment directory.
    exp_id = hu.hash_dict(exp_dict)
    savedir = os.path.join(savedir_base, exp_id)

    if reset:
        # Delete and back up the experiment.
        hc.delete_experiment(savedir, backup_flag=True)

    # Create the folder and save the experiment dictionary.
    os.makedirs(savedir, exist_ok=True)
    hu.save_json(os.path.join(savedir, 'exp_dict.json'), exp_dict)
    pprint.pprint(exp_dict)
    print('Experiment saved in %s' % savedir)

    # Set seed
    # ==================
    seed = 42 + exp_dict['runs']
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        device = 'cuda'
        torch.cuda.manual_seed_all(seed)
    else:
        device = 'cpu'
    print('Running on device: %s' % device)

    # Dataset
    # ==================
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                     train_flag=True,
                                     datadir=datadir,
                                     exp_dict=exp_dict)
    train_loader = DataLoader(train_set,
                              drop_last=True,
                              shuffle=True,
                              sampler=None,
                              batch_size=exp_dict["batch_size"])

    # Load the validation dataset.
    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                   train_flag=False,
                                   datadir=datadir,
                                   exp_dict=exp_dict)

    # Model
    # ==================
    use_backpack = exp_dict['opt'].get("backpack", False)
    model = models.get_model(exp_dict["model"],
                             train_set=train_set,
                             backpack=use_backpack).to(device=device)
    if use_backpack:
        assert exp_dict['opt']['name'] in ['nus_wrapper', 'adaptive_second']
        from backpack import extend
        model = extend(model)

    # Choose the loss and metric functions.
    loss_function = metrics.get_metric_function(exp_dict["loss_func"])

    # Load Optimizer
    # ==============
    n_batches_per_epoch = len(train_set) / float(exp_dict["batch_size"])
    opt = optimizers.get_optimizer(opt=exp_dict["opt"],
                                   params=model.parameters(),
                                   n_batches_per_epoch=n_batches_per_epoch,
                                   n_train=len(train_set),
                                   train_loader=train_loader,
                                   model=model,
                                   loss_function=loss_function,
                                   exp_dict=exp_dict,
                                   batch_size=exp_dict["batch_size"])

    # Checkpointing
    # =============
    score_list_path = os.path.join(savedir, "score_list.pkl")
    model_path = os.path.join(savedir, "model_state_dict.pth")
    opt_path = os.path.join(savedir, "opt_state_dict.pth")

    if os.path.exists(score_list_path):
        # Resume the experiment from the last checkpoint.
        score_list = ut.load_pkl(score_list_path)
        if use_backpack:
            model.load_state_dict(torch.load(model_path), strict=False)
        else:
            model.load_state_dict(torch.load(model_path))
        opt.load_state_dict(torch.load(opt_path))
        s_epoch = score_list[-1]["epoch"] + 1
    else:
        # Start the experiment from scratch.
        score_list = []
        s_epoch = 0

    # Start Training
    # ==============
    n_train = len(train_loader.dataset)
    n_batches = len(train_loader)
    batch_size = train_loader.batch_size

    for epoch in range(s_epoch, exp_dict["max_epoch"]):
        # Reseed at each epoch so resumed runs see the same data order.
        seed = epoch + exp_dict['runs']
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        score_dict = {"epoch": epoch}

        # Validate
        # --------
        if metrics_flag:
            # 1. Compute the train loss over the train set.
            score_dict["train_loss"] = metrics.compute_metric_on_dataset(
                model, train_set,
                metric_name=exp_dict["loss_func"],
                batch_size=exp_dict['batch_size'])

            # 2. Compute the val accuracy over the val set.
            score_dict["val_acc"] = metrics.compute_metric_on_dataset(
                model, val_set,
                metric_name=exp_dict["acc_func"],
                batch_size=exp_dict['batch_size'])

        # Train
        # -----
        model.train()
        print("%d - Training model with %s..." %
              (epoch, exp_dict["loss_func"]))

        s_time = time.time()
        train_on_loader(model, train_set, train_loader, opt, loss_function,
                        epoch, use_backpack)
        e_time = time.time()

        # Record the step size and batch size.
        score_dict["step"] = opt.state.get("step", 0) / int(n_batches_per_epoch)
        score_dict["step_size"] = opt.state.get("step_size", {})
        score_dict["step_size_avg"] = opt.state.get("step_size_avg", {})
        score_dict["n_forwards"] = opt.state.get("n_forwards", {})
        score_dict["n_backwards"] = opt.state.get("n_backwards", {})
        score_dict["grad_norm"] = opt.state.get("grad_norm", {})
        score_dict["batch_size"] = batch_size
        score_dict["train_epoch_time"] = e_time - s_time
        score_dict.update(opt.state["gv_stats"])

        # Add score_dict to score_list.
        score_list += [score_dict]

        # Report and save.
        print(pd.DataFrame(score_list).tail())
        ut.save_pkl(score_list_path, score_list)
        ut.torch_save(model_path, model.state_dict())
        ut.torch_save(opt_path, opt.state_dict())
        print("Saved: %s" % savedir)

    return score_list
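
# Example invocation (a minimal sketch: the exp_dict keys are inferred from the
# accesses inside trainval; the concrete values are hypothetical placeholders).
example_exp_dict = {
    "dataset": "cifar10",           # passed to datasets.get_dataset
    "model": "resnet34",            # passed to models.get_model
    "loss_func": "softmax_loss",    # resolved via metrics.get_metric_function
    "acc_func": "softmax_accuracy",
    "opt": {"name": "sgd_armijo"},  # may also set "backpack": True
    "batch_size": 128,
    "max_epoch": 200,
    "runs": 0,                      # offsets the random seed
}
example_scores = trainval(example_exp_dict, savedir_base="./results",
                          reset=False, datadir="./data", cuda=True)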
def trainval_svrg(exp_dict, savedir, datadir, metrics_flag=True):
    '''SVRG-specific training and validation loop.'''
    pprint.pprint(exp_dict)

    # Load the train dataset.
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                     train_flag=True,
                                     datadir=datadir,
                                     exp_dict=exp_dict)
    train_loader = DataLoader(train_set,
                              drop_last=False,
                              shuffle=True,
                              batch_size=exp_dict["batch_size"])

    # Load the validation dataset.
    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                   train_flag=False,
                                   datadir=datadir,
                                   exp_dict=exp_dict)

    # Load the model.
    model = models.get_model(exp_dict["model"], train_set=train_set).cuda()

    # Choose the loss and metric functions.
    loss_function = metrics.get_metric_function(exp_dict["loss_func"])

    # Look up the learning rate.
    lr = get_svrg_step_size(exp_dict)

    # Load the optimizer.
    opt = get_svrg_optimizer(model, loss_function,
                             train_loader=train_loader, lr=lr)

    # Resume from the last saved state_dict, if one exists.
    if (not os.path.exists(savedir + "/run_dict.pkl")
            or not os.path.exists(savedir + "/score_list.pkl")):
        ut.save_pkl(savedir + "/run_dict.pkl", {"running": 1})
        score_list = []
        s_epoch = 0
    else:
        score_list = ut.load_pkl(savedir + "/score_list.pkl")
        model.load_state_dict(torch.load(savedir + "/model_state_dict.pth"))
        opt.load_state_dict(torch.load(savedir + "/opt_state_dict.pth"))
        s_epoch = score_list[-1]["epoch"] + 1

    for epoch in range(s_epoch, exp_dict["max_epoch"]):
        score_dict = {"epoch": epoch}

        if metrics_flag:
            # 1. Compute the train loss over the train set.
            score_dict["train_loss"] = metrics.compute_metric_on_dataset(
                model, train_set, metric_name=exp_dict["loss_func"])

            # 2. Compute the val accuracy over the val set.
            score_dict["val_acc"] = metrics.compute_metric_on_dataset(
                model, val_set, metric_name=exp_dict["acc_func"])

        # 3. Train over the train loader.
        model.train()
        print("%d - Training model with %s..." %
              (epoch, exp_dict["loss_func"]))

        s_time = time.time()
        for images, labels in tqdm.tqdm(train_loader):
            images, labels = images.cuda(), labels.cuda()

            opt.zero_grad()
            # The SVRG optimizer re-evaluates the loss on its own snapshot
            # model, so the closure takes the model as an argument.
            closure = lambda svrg_model: loss_function(
                svrg_model, images, labels, backwards=True)
            opt.step(closure)
        e_time = time.time()

        # Record the step size and batch size.
        score_dict["step_size"] = opt.state["step_size"]
        score_dict["batch_size"] = train_loader.batch_size
        score_dict["train_epoch_time"] = e_time - s_time

        # Add score_dict to score_list.
        score_list += [score_dict]

        # Report and save.
        print(pd.DataFrame(score_list).tail())
        ut.save_pkl(savedir + "/score_list.pkl", score_list)
        ut.torch_save(savedir + "/model_state_dict.pth", model.state_dict())
        ut.torch_save(savedir + "/opt_state_dict.pth", opt.state_dict())
        print("Saved: %s" % savedir)

    return score_list
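
# For reference, the variance-reduced step performed by opt.step(closure) has
# the classic SVRG form (a conceptual sketch, not this repo's exact code):
#
#   v = grad_i(w) - grad_i(w_snapshot) + full_grad(w_snapshot)
#   w = w - lr * v
#
# where w_snapshot and its full-batch gradient are refreshed periodically.
# The model-valued closure above is what lets the optimizer evaluate the same
# minibatch gradient at both w and w_snapshot.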
def get_dataset(dataset_name, train_flag, datadir, exp_dict):
    if dataset_name == "mnist":
        dataset = torchvision.datasets.MNIST(
            datadir,
            train=train_flag,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.5, ), (0.5, ))
            ]))

    if dataset_name == "cifar10":
        transform_function = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        dataset = torchvision.datasets.CIFAR10(root=datadir,
                                               train=train_flag,
                                               download=True,
                                               transform=transform_function)

    if dataset_name == "cifar100":
        transform_function = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        dataset = torchvision.datasets.CIFAR100(root=datadir,
                                                train=train_flag,
                                                download=True,
                                                transform=transform_function)

    if dataset_name in ["mushrooms", "w8a", "rcv1", "ijcnn"]:
        sigma_dict = {
            "mushrooms": 0.5,
            "w8a": 20.0,
            "rcv1": 0.25,
            "ijcnn": 0.05
        }

        X, y = load_libsvm(dataset_name, data_dir=datadir)

        labels = np.unique(y)
        y[y == labels[0]] = 0
        y[y == labels[1]] = 1

        # TODO: (amishkin) splits = train_test_split(X, y, test_size=0.2, shuffle=True, random_state=9513451)
        splits = train_test_split(X, y, test_size=0.2, shuffle=False,
                                  random_state=42)
        X_train, X_test, Y_train, Y_test = splits

        if train_flag:
            # fname_rbf = "%s/rbf_%s_train.pkl" % (datadir, dataset_name)
            # if os.path.exists(fname_rbf):
            #     k_train_X = ut.load_pkl(fname_rbf)
            # else:
            k_train_X = rbf_kernel(X_train, X_train, sigma_dict[dataset_name])
            # ut.save_pkl(fname_rbf, k_train_X)

            X_train = k_train_X
            X_train = torch.FloatTensor(X_train)
            Y_train = torch.FloatTensor(Y_train)
            dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        else:
            # fname_rbf = "%s/rbf_%s_test.pkl" % (datadir, dataset_name)
            # if os.path.exists(fname_rbf):
            #     k_test_X = ut.load_pkl(fname_rbf)
            # else:
            k_test_X = rbf_kernel(X_test, X_train, sigma_dict[dataset_name])
            # ut.save_pkl(fname_rbf, k_test_X)

            X_test = k_test_X
            X_test = torch.FloatTensor(X_test)
            Y_test = torch.FloatTensor(Y_test)
            dataset = torch.utils.data.TensorDataset(X_test, Y_test)

        return dataset

    if dataset_name == "synthetic":
        margin = exp_dict["margin"]

        X, y, _, _ = make_binary_linear(n=exp_dict["n_samples"],
                                        d=exp_dict["d"],
                                        margin=margin,
                                        y01=True,
                                        bias=True,
                                        separable=True,
                                        seed=42)
        # No shuffling, to keep the support vectors inside the training set.
        splits = train_test_split(X, y, test_size=0.2, shuffle=False,
                                  random_state=42)
        X_train, X_test, Y_train, Y_test = splits

        X_train = torch.FloatTensor(X_train)
        X_test = torch.FloatTensor(X_test)
        Y_train = torch.FloatTensor(Y_train)
        Y_test = torch.FloatTensor(Y_test)

        if train_flag:
            dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        else:
            dataset = torch.utils.data.TensorDataset(X_test, Y_test)

        return dataset

    if dataset_name == "matrix_fac":
        fname = datadir + 'matrix_fac.pkl'
        if not os.path.exists(fname):
            data = generate_synthetic_matrix_factorization_data()
            ut.save_pkl(fname, data)

        A, y = ut.load_pkl(fname)
        X_train, X_test, y_train, y_test = train_test_split(
            A, y, test_size=0.2, random_state=9513451)

        training_set = torch.utils.data.TensorDataset(
            torch.tensor(X_train, dtype=torch.float),
            torch.tensor(y_train, dtype=torch.float))
        test_set = torch.utils.data.TensorDataset(
            torch.tensor(X_test, dtype=torch.float),
            torch.tensor(y_test, dtype=torch.float))

        if train_flag:
            dataset = training_set
        else:
            dataset = test_set

    return dataset
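
# Example usage (hypothetical paths; DataLoader imported as in the training
# loops above), mirroring how trainval consumes get_dataset's output:
example_train_set = get_dataset("cifar10", train_flag=True,
                                datadir="./data", exp_dict={})
example_train_loader = DataLoader(example_train_set, drop_last=True,
                                  shuffle=True, batch_size=128)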
def get_dataset(dataset_name, train_flag, datadir, exp_dict):
    if dataset_name == 'tiny_imagenet':
        if train_flag:
            transform_train = transforms.Compose([
                # transforms.Resize(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_train)
        else:
            transform_test = transforms.Compose([
                # transforms.Resize(32),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_test)

    if dataset_name == 'imagenette2-160':
        if train_flag:
            transform_train = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                # transforms.RandomResizedCrop(224),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_train)
        else:
            transform_test = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_test)

    if dataset_name == 'imagewoof2-160':
        if train_flag:
            transform_train = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                # transforms.RandomResizedCrop(224),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_train)
        else:
            transform_test = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])
            dataset = torchvision.datasets.ImageFolder(
                root=datadir, transform=transform_test)

    if dataset_name == "mnist":
        # Flatten each image to a 784-dimensional vector.
        view = torchvision.transforms.Lambda(lambda x: x.view(-1).view(784))
        dataset = torchvision.datasets.MNIST(
            datadir,
            train=train_flag,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.5, ), (0.5, )),
                view
            ]))

    if dataset_name == "cifar10":
        transform_function = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        dataset = torchvision.datasets.CIFAR10(root=datadir,
                                               train=train_flag,
                                               download=True,
                                               transform=transform_function)

    if dataset_name == "cifar100":
        transform_function = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        dataset = torchvision.datasets.CIFAR100(root=datadir,
                                                train=train_flag,
                                                download=True,
                                                transform=transform_function)

    if dataset_name in ['B', 'C']:
        # Sparse, synthetic binary classification problems.
        bias = 1
        scaling = 10
        sparsity = 10
        solutionSparsity = 0.1
        n = 1000
        if dataset_name == 'C':
            p = 100
        if dataset_name == 'B':
            p = 10000

        A = np.random.randn(n, p) + bias
        A = A.dot(np.diag(scaling * np.random.randn(p)))
        A = A * (np.random.rand(n, p) < (sparsity * np.log(n) / n))
        w = np.random.randn(p) * (np.random.rand(p) < solutionSparsity)

        b = np.sign(A.dot(w))
        b = b * np.sign(np.random.rand(n) - 0.1)
        labels = np.unique(b)

        A = A / np.linalg.norm(A, axis=1)[:, None].clip(min=1e-6)
        A = A * 2

        b[b == labels[0]] = 0
        b[b == labels[1]] = 1

        # squared_max, squared_min = compute_max_eta_squared_loss(A)
        # logistic_max, logistic_min = compute_max_eta_logistic_loss(A)

        dataset = torch.utils.data.TensorDataset(torch.FloatTensor(A),
                                                 torch.FloatTensor(b))
        return DatasetWrapper(dataset)

    if dataset_name in [
            "mushrooms", "w8a", "rcv1", "ijcnn", 'a1a', 'a2a',
            "mushrooms_convex", "w8a_convex", "rcv1_convex", "ijcnn_convex",
            'a1a_convex', 'a2a_convex'
    ]:
        sigma_dict = {
            "mushrooms": 0.5,
            "w8a": 20.0,
            "rcv1": 0.25,
            "ijcnn": 0.05
        }

        X, y = load_libsvm(dataset_name.replace('_convex', ''),
                           data_dir=datadir)

        labels = np.unique(y)
        y[y == labels[0]] = 0
        y[y == labels[1]] = 1

        # Splits used in the experiments.
        splits = train_test_split(X, y, test_size=0.2, shuffle=True,
                                  random_state=9513451)
        X_train, X_test, Y_train, Y_test = splits

        if "_convex" in dataset_name:
            if train_flag:
                # Training set.
                X_train = torch.FloatTensor(X_train.toarray())
                Y_train = torch.FloatTensor(Y_train)
                dataset = torch.utils.data.TensorDataset(X_train, Y_train)
            else:
                # Test set.
                X_test = torch.FloatTensor(X_test.toarray())
                Y_test = torch.FloatTensor(Y_test)
                dataset = torch.utils.data.TensorDataset(X_test, Y_test)
            return DatasetWrapper(dataset)

        if train_flag:
            # Cache the RBF kernel matrix, since it is expensive to compute.
            # fname_rbf = "%s/rbf_%s_%s_train.pkl" % (datadir, dataset_name, sigma_dict[dataset_name])
            fname_rbf = "%s/rbf_%s_%s_train.npy" % (datadir, dataset_name,
                                                    sigma_dict[dataset_name])
            if os.path.exists(fname_rbf):
                k_train_X = np.load(fname_rbf)
            else:
                k_train_X = rbf_kernel(X_train, X_train,
                                       sigma_dict[dataset_name])
                np.save(fname_rbf, k_train_X)
                print('%s saved' % fname_rbf)

            X_train = k_train_X
            X_train = torch.FloatTensor(X_train)
            Y_train = torch.LongTensor(Y_train)
            dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        else:
            fname_rbf = "%s/rbf_%s_%s_test.npy" % (datadir, dataset_name,
                                                   sigma_dict[dataset_name])
            if os.path.exists(fname_rbf):
                k_test_X = np.load(fname_rbf)
            else:
                k_test_X = rbf_kernel(X_test, X_train,
                                      sigma_dict[dataset_name])
                np.save(fname_rbf, k_test_X)
                print('%s saved' % fname_rbf)

            X_test = k_test_X
            X_test = torch.FloatTensor(X_test)
            Y_test = torch.LongTensor(Y_test)
            dataset = torch.utils.data.TensorDataset(X_test, Y_test)

    if dataset_name == "synthetic":
        margin = exp_dict["margin"]

        X, y, _, _ = make_binary_linear(n=exp_dict["n_samples"],
                                        d=exp_dict["d"],
                                        margin=margin,
                                        y01=True,
                                        bias=True,
                                        separable=exp_dict.get(
                                            "separable", True),
                                        seed=42)
        # No shuffling, to keep the support vectors inside the training set.
        splits = train_test_split(X, y, test_size=0.2, shuffle=False,
                                  random_state=42)
        X_train, X_test, Y_train, Y_test = splits

        X_train = torch.FloatTensor(X_train)
        X_test = torch.FloatTensor(X_test)
        Y_train = torch.LongTensor(Y_train)
        Y_test = torch.LongTensor(Y_test)

        if train_flag:
            dataset = torch.utils.data.TensorDataset(X_train, Y_train)
        else:
            dataset = torch.utils.data.TensorDataset(X_test, Y_test)

    if dataset_name == "matrix_fac":
        fname = datadir + 'matrix_fac.pkl'
        if not os.path.exists(fname):
            data = generate_synthetic_matrix_factorization_data()
            ut.save_pkl(fname, data)

        A, y = ut.load_pkl(fname)
        X_train, X_test, y_train, y_test = train_test_split(
            A, y, test_size=0.2, random_state=9513451)

        training_set = torch.utils.data.TensorDataset(
            torch.tensor(X_train, dtype=torch.float),
            torch.tensor(y_train, dtype=torch.float))
        test_set = torch.utils.data.TensorDataset(
            torch.tensor(X_test, dtype=torch.float),
            torch.tensor(y_test, dtype=torch.float))

        if train_flag:
            dataset = training_set
        else:
            dataset = test_set

    return DatasetWrapper(dataset)
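
# rbf_kernel is imported from elsewhere in the repo. A minimal NumPy version
# consistent with the calls above might look like the sketch below, assuming
# the third argument is the bandwidth gamma in k(x, y) = exp(-gamma * ||x - y||^2);
# the repo's actual parameterization may differ.
def rbf_kernel_sketch(A, B, gamma):
    # Pairwise squared distances via ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2,
    # clipped at zero to guard against small negative values from round-off.
    sq_dists = ((A ** 2).sum(axis=1)[:, None]
                - 2.0 * A @ B.T
                + (B ** 2).sum(axis=1)[None, :])
    return np.exp(-gamma * np.clip(sq_dists, 0.0, None))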
def trainval(exp_dict, savedir, datadir, metrics_flag=True):
    # TODO: Do we get similar results with different seeds?
    # Set seed
    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    pprint.pprint(exp_dict)

    # Load the train dataset.
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                     train_flag=True,
                                     datadir=datadir,
                                     exp_dict=exp_dict)
    train_loader = DataLoader(train_set,
                              drop_last=True,
                              shuffle=True,
                              batch_size=exp_dict["batch_size"])

    # Load the validation dataset.
    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                   train_flag=False,
                                   datadir=datadir,
                                   exp_dict=exp_dict)

    # Load the model.
    model = models.get_model(exp_dict["model"], train_set=train_set).cuda()

    # Choose the loss and metric functions.
    loss_function = metrics.get_metric_function(exp_dict["loss_func"])

    # Load the optimizer.
    n_batches_per_epoch = len(train_set) / float(exp_dict["batch_size"])
    opt = optimizers.get_optimizer(opt=exp_dict["opt"],
                                   params=model.parameters(),
                                   n_batches_per_epoch=n_batches_per_epoch)

    # Resume from the last saved state_dict, if one exists.
    if (not os.path.exists(savedir + "/run_dict.pkl")
            or not os.path.exists(savedir + "/score_list.pkl")):
        ut.save_pkl(savedir + "/run_dict.pkl", {"running": 1})
        score_list = []
        s_epoch = 0
    else:
        score_list = ut.load_pkl(savedir + "/score_list.pkl")
        model.load_state_dict(torch.load(savedir + "/model_state_dict.pth"))
        opt.load_state_dict(torch.load(savedir + "/opt_state_dict.pth"))
        s_epoch = score_list[-1]["epoch"] + 1

    for epoch in range(s_epoch, exp_dict["max_epoch"]):
        # Reseed at each epoch so resumed runs see the same data order.
        np.random.seed(epoch)
        torch.manual_seed(epoch)
        torch.cuda.manual_seed_all(epoch)

        score_dict = {"epoch": epoch}

        if metrics_flag:
            # 1. Compute the train loss over the train set.
            score_dict["train_loss"] = metrics.compute_metric_on_dataset(
                model, train_set, metric_name=exp_dict["loss_func"])

            # 2. Compute the val accuracy over the val set.
            score_dict["val_acc"] = metrics.compute_metric_on_dataset(
                model, val_set, metric_name=exp_dict["acc_func"])

        # 3. Train over the train loader.
        model.train()
        print("%d - Training model with %s..." %
              (epoch, exp_dict["loss_func"]))

        s_time = time.time()
        for images, labels in tqdm.tqdm(train_loader):
            images, labels = images.cuda(), labels.cuda()

            opt.zero_grad()

            if exp_dict["opt"]["name"] in exp_configs.ours_opt_list + ["l4"]:
                # Line-search-style optimizers re-evaluate the loss inside
                # step(), so they receive a closure instead of gradients.
                closure = lambda: loss_function(model, images, labels,
                                                backwards=False)
                opt.step(closure)
            else:
                loss = loss_function(model, images, labels)
                loss.backward()
                opt.step()
        e_time = time.time()

        # Record the step size and batch size.
        score_dict["step_size"] = opt.state["step_size"]
        score_dict["n_forwards"] = opt.state["n_forwards"]
        score_dict["n_backwards"] = opt.state["n_backwards"]
        score_dict["batch_size"] = train_loader.batch_size
        score_dict["train_epoch_time"] = e_time - s_time

        # Add score_dict to score_list.
        score_list += [score_dict]

        # Report and save.
        print(pd.DataFrame(score_list).tail())
        ut.save_pkl(savedir + "/score_list.pkl", score_list)
        ut.torch_save(savedir + "/model_state_dict.pth", model.state_dict())
        ut.torch_save(savedir + "/opt_state_dict.pth", opt.state_dict())
        print("Saved: %s" % savedir)

    return score_list
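
# The closure protocol above exists to support backtracking line searches.
# As a conceptual sketch (the stochastic Armijo condition, not necessarily
# this repo's exact implementation): starting from a trial step size eta,
# the optimizer shrinks it by a factor beta in (0, 1) until
#
#   f(w - eta * g) <= f(w) - c * eta * ||g||^2,   with 0 < c < 1,
#
# where f is the minibatch loss that the closure re-evaluates at each trial
# point; only then is the step w <- w - eta * g accepted.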
def trainval(exp_dict, savedir, args):
    # Set seed and device
    # ===================
    seed = 42 + exp_dict['runs']
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda:
        device = 'cuda'
        torch.cuda.manual_seed_all(seed)
        assert torch.cuda.is_available(), \
            'cuda is not available, please run with "-c 0"'
    else:
        device = 'cpu'
    print('Running on device: %s' % device)

    # Load Datasets
    # ==================
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                     split='train',
                                     datadir=args.datadir,
                                     exp_dict=exp_dict)
    train_loader = DataLoader(train_set,
                              drop_last=True,
                              shuffle=True,
                              sampler=None,
                              batch_size=exp_dict["batch_size"])

    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                   split='val',
                                   datadir=args.datadir,
                                   exp_dict=exp_dict)

    # Load Model
    # ==================
    model = models.get_model(train_loader, exp_dict, device=device)

    model_path = os.path.join(savedir, "model.pth")
    score_list_path = os.path.join(savedir, "score_list.pkl")

    if os.path.exists(score_list_path):
        # Resume the experiment from the last checkpoint.
        score_list = ut.load_pkl(score_list_path)
        model.set_state_dict(torch.load(model_path))
        s_epoch = score_list[-1]["epoch"] + 1
    else:
        # Start the experiment from scratch.
        score_list = []
        s_epoch = 0

    # Train and Val
    # ==============
    for epoch in range(s_epoch, exp_dict["max_epoch"]):
        # Reseed at each epoch so resumed runs see the same data order.
        seed = epoch + exp_dict.get('runs', 0)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Train one epoch.
        s_time = time.time()
        model.train_on_loader(train_loader)
        e_time = time.time()

        # Validate one epoch.
        train_loss_dict = model.val_on_dataset(train_set,
                                               metric=exp_dict["loss_func"],
                                               name='loss')
        val_acc_dict = model.val_on_dataset(val_set,
                                            metric=exp_dict["acc_func"],
                                            name='score')

        # Record metrics.
        score_dict = {"epoch": epoch}
        score_dict.update(train_loss_dict)
        score_dict.update(val_acc_dict)
        score_dict["step_size"] = model.opt.state.get("step_size", {})
        score_dict["step_size_avg"] = model.opt.state.get("step_size_avg", {})
        score_dict["n_forwards"] = model.opt.state.get("n_forwards", {})
        score_dict["n_backwards"] = model.opt.state.get("n_backwards", {})
        score_dict["grad_norm"] = model.opt.state.get("grad_norm", {})
        score_dict["train_epoch_time"] = e_time - s_time
        score_dict.update(model.opt.state["gv_stats"])

        # Add score_dict to score_list.
        score_list += [score_dict]

        # Report and save.
        print(pd.DataFrame(score_list).tail())
        ut.save_pkl(score_list_path, score_list)
        ut.torch_save(model_path, model.get_state_dict())
        print("Saved: %s" % savedir)
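
# Example invocation (hypothetical values; the args object only needs the two
# attributes accessed above, so an argparse.Namespace stands in for the repo's
# real CLI parser):
from argparse import Namespace
trainval(exp_dict={"dataset": "mnist", "model": "mlp",
                   "loss_func": "softmax_loss",
                   "acc_func": "softmax_accuracy",
                   "opt": {"name": "sgd_armijo"},
                   "batch_size": 128, "max_epoch": 100, "runs": 0},
         savedir="./results/example",
         args=Namespace(cuda=0, datadir="./data"))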
def trainval_pls(exp_dict, savedir, datadir, metrics_flag=True):
    '''PLS-specific training and validation loop.'''
    pprint.pprint(exp_dict)

    # Load the train dataset.
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                     train_flag=True,
                                     datadir=datadir,
                                     exp_dict=exp_dict)
    train_loader = DataLoader(train_set,
                              drop_last=False,
                              shuffle=True,
                              batch_size=exp_dict["batch_size"])

    # Load the validation dataset.
    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset"],
                                   train_flag=False,
                                   datadir=datadir,
                                   exp_dict=exp_dict)

    # Load the model.
    model = models.get_model(exp_dict["model"], train_set=train_set).cuda()

    # Choose the loss and metric functions.
    if exp_dict["loss_func"] == 'logistic_loss':
        loss_function = logistic_loss_grad_moments
    else:
        raise ValueError("PLS only supports the logistic loss.")

    # Load the optimizer.
    opt = pls.PLS(model,
                  exp_dict["max_epoch"],
                  exp_dict["batch_size"],
                  expl_policy='exponential')

    # Resume from the last saved state_dict, if one exists.
    if (not os.path.exists(savedir + "/run_dict.pkl")
            or not os.path.exists(savedir + "/score_list.pkl")):
        ut.save_pkl(savedir + "/run_dict.pkl", {"running": 1})
        score_list = []
        s_epoch = 0
    else:
        score_list = ut.load_pkl(savedir + "/score_list.pkl")
        model.load_state_dict(torch.load(savedir + "/model_state_dict.pth"))
        opt.load_state_dict(torch.load(savedir + "/opt_state_dict.pth"))
        s_epoch = score_list[-1]["epoch"] + 1

    # PLS-specific tracking for iterations and epochs:
    epoch = s_epoch
    iter_num = 0
    iters_per_epoch = math.ceil(
        len(train_loader.dataset) / exp_dict['batch_size'])
    new_epoch = True

    while epoch < exp_dict["max_epoch"]:
        for images, labels in tqdm.tqdm(train_loader):
            # Record metrics at the start of a new epoch.
            if metrics_flag and new_epoch:
                new_epoch = False
                score_dict = {"epoch": epoch}

                # 1. Compute the train loss over the train set.
                score_dict["train_loss"] = metrics.compute_metric_on_dataset(
                    model, train_set, metric_name=exp_dict["loss_func"])

                # 2. Compute the val accuracy over the val set.
                score_dict["val_acc"] = metrics.compute_metric_on_dataset(
                    model, val_set, metric_name=exp_dict["acc_func"])

                # 3. Train over the train loader.
                model.train()
                print("%d - Training model with %s..." %
                      (epoch, exp_dict["loss_func"]))
                s_time = time.time()

            images, labels = images.cuda(), labels.cuda()
            closure = grad_moment_closure_factory(model, images, labels,
                                                  loss_function)

            # For PLS, calls to optimizer.step() do not correspond to a single
            # optimizer step. Instead, they correspond to one evaluation in
            # the line search, which may or may not be accepted.
            opt.step(closure)

            # Epoch and iteration tracking.
            if opt.state['complete']:
                iter_num = iter_num + 1
                # Potentially increment the epoch counter.
                if iter_num % iters_per_epoch == 0:
                    epoch = epoch + 1
                    new_epoch = True

            # Compute metrics at the end of the previous epoch.
            if new_epoch:
                e_time = time.time()

                # Record the step size and batch size.
                score_dict["step_size"] = opt.state["step_size"]
                score_dict["batch_size"] = train_loader.batch_size
                score_dict["train_epoch_time"] = e_time - s_time

                # Add score_dict to score_list.
                score_list += [score_dict]

                # Report and save.
                print(pd.DataFrame(score_list).tail())
                ut.save_pkl(savedir + "/score_list.pkl", score_list)
                ut.torch_save(savedir + "/model_state_dict.pth",
                              model.state_dict())
                ut.torch_save(savedir + "/opt_state_dict.pth",
                              opt.state_dict())
                print("Saved: %s" % savedir)

    return score_list
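
# Quick sanity check of the epoch accounting above (illustrative numbers):
# with 50,000 training examples and batch_size=128, the epoch counter advances
# after every ceil(50000 / 128) = 391 *accepted* line-search evaluations.
import math
assert math.ceil(50000 / 128) == 391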