def __init__(self, config_dict: dict, base_experiment_path: str, model: BaseABCParam,
             train_dataset: Dataset, test_dataset: Dataset, val_dataset: Dataset = None,
             train_ratio: float = 0.8, n_trials: int = 10, early_stopping: bool = False):
    """Set up an experiment over `n_trials` random trials of `model`.

    :param config_dict: parsed YAML config with 'architecture', 'training', ... sections
    :param base_experiment_path: root directory under which trial folders are created
    :param model: model class (subclass of BaseABCParam) to instantiate per trial
    :param train_dataset: training data
    :param test_dataset: test data
    :param val_dataset: optional validation data; when None, a split of
        `train_dataset` is used (see `_set_train_val_data_from_train`)
    :param train_ratio: fraction of `train_dataset` used for training when
        `val_dataset` is None
    :param n_trials: number of independent trials (each gets its own random seed)
    :param early_stopping: default early-stopping flag, overridden by
        config_dict['training']['early_stopping'] when present
    """
    self.width = config_dict['architecture']['width']
    self.batch_size = config_dict['training']['batch_size']
    self._set_base_lr(config_dict)  # must run before super().__init__ consumes the config
    super().__init__(config_dict, base_experiment_path)

    # Epoch/step budgets: fall back to class-level defaults when absent.
    self.max_epochs = config_dict['training'].get('n_epochs', self.MAX_EPOCHS)
    self.max_steps = config_dict['training'].get('n_steps', self.MAX_STEPS)

    if val_dataset is None:
        # Carve a validation set out of the training set.
        self._set_train_val_data_from_train(train_dataset, train_ratio)
    else:
        # BUG FIX: previously these attributes were never set when an explicit
        # val_dataset was provided, breaking _set_data_loaders() and _run_trial().
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
    self.test_dataset = test_dataset
    self._set_data_loaders()

    self.model = model
    self.n_trials = n_trials

    # Config takes precedence over the constructor argument.
    self.early_stopping = config_dict['training'].get('early_stopping', early_stopping)
    self.early_stopping_callback = False  # this is modified in _set_tb_logger_and_callbacks if early_stopping=True

    set_random_seeds(self.SEED)  # set random seed for reproducibility
    self.trial_seeds = np.random.randint(0, 100, size=n_trials)  # define random seeds to use for each trial
def main(activation="relu", base_lr=0.01, batch_size=512, dataset="mnist"):
    """Train muP models for two steps and analyse the ranks of the initial
    weights (W0), the first two updates (Delta_W_1, Delta_W_2) and the
    corresponding (pre-)activation contributions, saving CSVs and plots.

    Relies on module-level globals: CONFIG_PATH, FIGURES_DIR, SEED, N_TRIALS,
    L, width, and the various helper functions imported at module scope.
    """
    config_path = os.path.join(CONFIG_PATH, 'fc_ipllr_{}.yaml'.format(dataset))
    figures_dir = os.path.join(FIGURES_DIR, dataset)
    create_dir(figures_dir)
    log_path = os.path.join(figures_dir, 'log_muP_{}.txt'.format(activation))
    logger = set_up_logger(log_path)

    logger.info('Parameters of the run:')
    logger.info('activation = {}'.format(activation))
    logger.info('base_lr = {}'.format(base_lr))
    logger.info('batch_size = {:,}'.format(batch_size))
    logger.info('Random SEED : {:,}'.format(SEED))
    logger.info(
        'Number of random trials for each model : {:,}'.format(N_TRIALS))

    try:
        set_random_seeds(SEED)  # set random seed for reproducibility
        config_dict = read_yaml(config_path)

        version = 'L={}_m={}_act={}_lr={}_bs={}'.format(
            L, width, activation, base_lr, batch_size)
        template_name = 'muP_{}_ranks_{}_' + version

        config_dict['architecture']['width'] = width
        config_dict['architecture']['n_layers'] = L + 1
        config_dict['optimizer']['params']['lr'] = base_lr
        config_dict['activation']['name'] = activation

        base_model_config = ModelConfig(config_dict)

        # Load data & define models
        logger.info('Loading data ...')
        if dataset == 'mnist':
            from utils.dataset.mnist import load_data
        elif dataset == 'cifar10':
            from utils.dataset.cifar10 import load_data
        elif dataset == 'cifar100':
            # TODO : add cifar100 to utils.dataset
            # BUG FIX: the previous `pass` left `load_data` unbound, which
            # surfaced later as a confusing NameError; fail fast instead.
            error = NotImplementedError(
                "cifar100 is not yet available in utils.dataset")
            logger.error(error)
            raise error
        else:
            error = ValueError(
                "dataset must be one of ['mnist', 'cifar10', 'cifar100'] but was {}"
                .format(dataset))
            logger.error(error)
            raise error

        training_dataset, test_dataset = load_data(download=False, flatten=True)
        train_data_loader = DataLoader(training_dataset, shuffle=True,
                                       batch_size=batch_size)
        batches = list(train_data_loader)
        # Full training set as a single tensor pair (used for the
        # pre-activation statistics below).
        full_x = torch.cat([a for a, _ in batches], dim=0)
        full_y = torch.cat([b for _, b in batches], dim=0)

        logger.info('Defining models')
        base_model_config.scheduler = None
        muPs = [FCmuP(base_model_config) for _ in range(N_TRIALS)]
        # Scale the input-layer lr by (d + 1) to compensate the input
        # normalization (param group 0 is the input layer).
        for muP in muPs:
            for i, param_group in enumerate(muP.optimizer.param_groups):
                if i == 0:
                    param_group['lr'] = param_group['lr'] * (muP.d + 1)

        # save initial models
        muPs_0 = [deepcopy(muP) for muP in muPs]

        # train model one step
        logger.info('Training model a first step (t=1)')
        x, y = batches[0]
        muPs_1 = []
        for muP in muPs:
            train_model_one_step(muP, x, y, normalize_first=True)
            muPs_1.append(deepcopy(muP))

        # train models for a second step
        logger.info('Training model a second step (t=2)')
        x, y = batches[1]
        muPs_2 = []
        for muP in muPs:
            train_model_one_step(muP, x, y, normalize_first=True)
            muPs_2.append(deepcopy(muP))

        # set eval mode for all models
        for i in range(N_TRIALS):
            muPs[i].eval()
            muPs_0[i].eval()
            muPs_1[i].eval()
            muPs_2[i].eval()

        logger.info('Storing initial and update matrices')

        # define W0 and b0
        W0s = []
        b0s = []
        for muP_0 in muPs_0:
            W0, b0 = get_W0_dict(muP_0, normalize_first=True)
            W0s.append(W0)
            b0s.append(b0)

        # define Delta_W_1 and Delta_b_1
        Delta_W_1s = []
        Delta_b_1s = []
        for i in range(N_TRIALS):
            Delta_W_1, Delta_b_1 = get_Delta_W1_dict(muPs_0[i], muPs_1[i],
                                                     normalize_first=True)
            Delta_W_1s.append(Delta_W_1)
            Delta_b_1s.append(Delta_b_1)

        # define Delta_W_2 and Delta_b_2
        Delta_W_2s = []
        Delta_b_2s = []
        for i in range(N_TRIALS):
            Delta_W_2, Delta_b_2 = get_Delta_W2_dict(muPs_1[i], muPs_2[i],
                                                     normalize_first=True)
            Delta_W_2s.append(Delta_W_2)
            Delta_b_2s.append(Delta_b_2)

        x, y = full_x, full_y  # compute pre-activations on full batch

        # contributions after first step
        h0s = []
        delta_h_1s = []
        h1s = []
        x1s = []
        for i in range(N_TRIALS):
            h0, delta_h_1, h1, x1 = get_contributions_1(
                x, muPs_1[i], W0s[i], b0s[i], Delta_W_1s[i], Delta_b_1s[i],
                normalize_first=True)
            h0s.append(h0)
            delta_h_1s.append(delta_h_1)
            # BUG FIX: was `h1s.append(h0)`, which discarded h1 and made the
            # h1 rank statistics identical to the h0 ones.
            h1s.append(h1)
            x1s.append(x1)

        # ranks of initial weight matrices and first two updates
        logger.info('Computing ranks of weight matrices ...')
        weight_ranks_dfs_dict = dict()

        tol = None  # default SVD tolerance
        weight_ranks_dfs_dict['svd_default'] = [
            get_svd_ranks_weights(W0s[i], Delta_W_1s[i], Delta_W_2s[i], L,
                                  tol=tol) for i in range(N_TRIALS)
        ]
        tol = 1e-7  # fixed SVD tolerance
        weight_ranks_dfs_dict['svd_tol'] = [
            get_svd_ranks_weights(W0s[i], Delta_W_1s[i], Delta_W_2s[i], L,
                                  tol=tol) for i in range(N_TRIALS)
        ]
        weight_ranks_dfs_dict['squared_tr'] = [
            get_square_trace_ranks_weights(W0s[i], Delta_W_1s[i],
                                           Delta_W_2s[i], L)
            for i in range(N_TRIALS)
        ]

        weight_ranks_df_dict = {
            key: get_concatenated_ranks_df(weight_ranks_dfs_dict[key])
            for key in weight_ranks_dfs_dict.keys()
        }
        avg_ranks_df_dict = {
            key: get_avg_ranks_dfs(weight_ranks_df_dict[key])
            for key in weight_ranks_df_dict.keys()
        }

        logger.info('Saving weights ranks data frames to csv ...')
        for key in weight_ranks_df_dict.keys():
            logger.info(key)
            logger.info('\n' + str(avg_ranks_df_dict[key]) + '\n\n')
            avg_ranks_df_dict[key].to_csv(os.path.join(
                figures_dir, template_name.format(key, 'weights') + '.csv'),
                header=True, index=True)

        ranks_dfs = [
            weight_ranks_df_dict['svd_default'],
            weight_ranks_df_dict['svd_tol'],
            weight_ranks_df_dict['squared_tr']
        ]

        # plot weights ranks
        logger.info('Plotting weights ranks')
        plt.figure(figsize=(12, 6))
        plot_weights_ranks_vs_layer('W0', ranks_dfs, tol, L, width, base_lr,
                                    batch_size, y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('W0', 'weights') + '.png'))

        plt.figure(figsize=(12, 6))
        plot_weights_ranks_vs_layer('Delta_W_1', ranks_dfs, tol, L, width,
                                    base_lr, batch_size, y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('Delta_W_1', 'weights') + '.png'))

        plt.figure(figsize=(12, 6))
        plot_weights_ranks_vs_layer('Delta_W_2', ranks_dfs, tol, L, width,
                                    base_lr, batch_size, y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('Delta_W_2', 'weights') + '.png'))

        # ranks of the pre-activations
        logger.info('Computing ranks of (pre-)activations ...')
        act_ranks_dfs_dict = dict()

        tol = None
        act_ranks_dfs_dict['svd_default'] = [
            get_svd_ranks_acts(h0s[i], delta_h_1s[i], h1s[i], x1s[i], L,
                               tol=tol) for i in range(N_TRIALS)
        ]
        tol = 1e-7
        act_ranks_dfs_dict['svd_tol'] = [
            get_svd_ranks_acts(h0s[i], delta_h_1s[i], h1s[i], x1s[i], L,
                               tol=tol) for i in range(N_TRIALS)
        ]
        act_ranks_dfs_dict['squared_tr'] = [
            get_square_trace_ranks_acts(h0s[i], delta_h_1s[i], h1s[i], x1s[i],
                                        L) for i in range(N_TRIALS)
        ]

        act_ranks_df_dict = {
            key: get_concatenated_ranks_df(act_ranks_dfs_dict[key])
            for key in act_ranks_dfs_dict.keys()
        }
        avg_ranks_df_dict = {
            key: get_avg_ranks_dfs(act_ranks_df_dict[key])
            for key in act_ranks_df_dict.keys()
        }

        logger.info('Saving (pre-)activation ranks data frames to csv ...')
        for key in avg_ranks_df_dict.keys():
            logger.info(key)
            logger.info('\n' + str(avg_ranks_df_dict[key]) + '\n\n')
            avg_ranks_df_dict[key].to_csv(os.path.join(
                figures_dir, template_name.format(key, 'acts') + '.csv'),
                header=True, index=True)

        ranks_dfs = [
            act_ranks_df_dict['svd_default'],
            act_ranks_df_dict['svd_tol'],
            act_ranks_df_dict['squared_tr']
        ]

        logger.info('Plotting (pre-)activation ranks')
        plt.figure(figsize=(12, 6))
        plot_acts_ranks_vs_layer('h0', ranks_dfs, tol, L, width, base_lr,
                                 batch_size, y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('h0', 'acts') + '.png'))

        plt.figure(figsize=(12, 6))
        plot_acts_ranks_vs_layer('h1', ranks_dfs, tol, L, width, base_lr,
                                 batch_size, y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('h1', 'acts') + '.png'))

        plt.figure(figsize=(12, 6))
        plot_acts_ranks_vs_layer('x1', ranks_dfs, tol, L, width, base_lr,
                                 batch_size, y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('x1', 'acts') + '.png'))

        plt.figure(figsize=(12, 6))
        plot_acts_ranks_vs_layer('delta_h_1', ranks_dfs, tol, L, width,
                                 base_lr, batch_size, y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('delta_h_1', 'acts') + '.png'))

        # diversity in terms of the index of the maximum entry
        logger.info(
            'Computing diversity of the maximum entry of pre-activations...')
        max_acts_diversity_dfs = [
            get_max_acts_diversity(h0s[i], delta_h_1s[i], h1s[i], L)
            for i in range(N_TRIALS)
        ]
        max_acts_diversity_df = get_concatenated_ranks_df(
            max_acts_diversity_dfs)
        avg_max_acts_diversity_df = get_avg_ranks_dfs(max_acts_diversity_df)
        logger.info('Diversity of the maximum activation index df:')
        logger.info(str(avg_max_acts_diversity_df))
        avg_max_acts_diversity_df.to_csv(os.path.join(
            figures_dir, 'muP_max_acts_' + version + '.csv'),
            header=True, index=True)

    except Exception as e:
        logger.exception("Exception when running the script : {}".format(e))
def main(activation="relu", n_steps=300, base_lr=0.01, batch_size=512, dataset="mnist"):
    """Train three variants of muP models (plain, renormalized input, and
    renormalized input with scaled first-layer lr) sharing the same initial
    parameters, then plot the training losses and chi values.

    NOTE(review): relies on module-level globals CONFIG_PATH, FIGURES_DIR,
    SEED, N_TRIALS, L, width and renorm_first — confirm these are defined
    before this function is called.
    """
    config_path = os.path.join(CONFIG_PATH, 'fc_ipllr_{}.yaml'.format(dataset))
    figures_dir = os.path.join(FIGURES_DIR, dataset)
    create_dir(figures_dir)
    log_path = os.path.join(figures_dir, 'log_muP_{}.txt'.format(activation))
    logger = set_up_logger(log_path)

    logger.info('Parameters of the run:')
    logger.info('activation = {}'.format(activation))
    logger.info('n_steps = {:,}'.format(n_steps))
    logger.info('base_lr = {}'.format(base_lr))
    logger.info('batch_size = {:,}'.format(batch_size))
    logger.info('Random SEED : {:,}'.format(SEED))
    logger.info(
        'Number of random trials for each model : {:,}'.format(N_TRIALS))

    try:
        set_random_seeds(SEED)  # set random seed for reproducibility
        config_dict = read_yaml(config_path)

        fig_name_template = 'muP_{}_{}_L={}_m={}_act={}_lr={}_bs={}.png'

        # Override the YAML config with this run's hyper-parameters.
        config_dict['architecture']['width'] = width
        config_dict['architecture']['n_layers'] = L + 1
        config_dict['optimizer']['params']['lr'] = base_lr
        config_dict['activation']['name'] = activation

        base_model_config = ModelConfig(config_dict)

        # Load data & define models
        logger.info('Loading data ...')
        if dataset == 'mnist':
            from utils.dataset.mnist import load_data
        elif dataset == 'cifar10':
            from utils.dataset.cifar10 import load_data
        elif dataset == 'cifar100':
            # TODO : add cifar100 to utils.dataset
            # NOTE(review): this branch leaves `load_data` unbound (NameError
            # below), and the output_size change happens after ModelConfig was
            # already built — confirm cifar100 is actually unsupported here.
            config_dict['architecture']['output_size'] = 100
            pass
        else:
            error = ValueError(
                "dataset must be one of ['mnist', 'cifar10', 'cifar100'] but was {}"
                .format(dataset))
            logger.error(error)
            raise error

        training_dataset, test_dataset = load_data(download=False,
                                                   flatten=True)
        train_data_loader = DataLoader(training_dataset, shuffle=True,
                                       batch_size=batch_size)
        batches = list(train_data_loader)

        logger.info('Defining models')
        base_model_config.scheduler = None
        # Three model families, N_TRIALS independent models each.
        muPs = [FCmuP(base_model_config) for _ in range(N_TRIALS)]
        muPs_renorm = [FCmuP(base_model_config) for _ in range(N_TRIALS)]
        muPs_renorm_scale_lr = [
            FCmuP(base_model_config) for _ in range(N_TRIALS)
        ]

        # Scale the input-layer lr by (d + 1) for the third family
        # (param group 0 is the input layer).
        for muP in muPs_renorm_scale_lr:
            for i, param_group in enumerate(muP.optimizer.param_groups):
                if i == 0:
                    param_group['lr'] = param_group['lr'] * (muP.d + 1)

        # All three families start from the same initial parameters.
        logger.info('Copying parameters of base muP')
        for i in range(N_TRIALS):
            muPs_renorm[i].copy_initial_params_from_model(muPs[i])
            muPs_renorm_scale_lr[i].copy_initial_params_from_model(muPs[i])
            muPs_renorm[i].initialize_params()
            muPs_renorm_scale_lr[i].initialize_params()

        results = dict()
        logger.info('Generating training results ...')
        # Each entry is a list over trials; collect_training_losses returns
        # (losses, chis) per trial (see unpacking below).
        results['muP'] = [
            collect_training_losses(muPs[i], batches, n_steps,
                                    normalize_first=False)
            for i in range(N_TRIALS)
        ]
        results['muP_renorm'] = [
            collect_training_losses(muPs_renorm[i], batches, n_steps,
                                    normalize_first=True)
            for i in range(N_TRIALS)
        ]
        results['muP_renorm_scale_lr'] = [
            collect_training_losses(muPs_renorm_scale_lr[i], batches, n_steps,
                                    normalize_first=True)
            for i in range(N_TRIALS)
        ]

        mode = 'training'

        losses = dict()
        for key, res in results.items():
            losses[key] = [r[0] for r in res]

        chis = dict()
        for key, res in results.items():
            chis[key] = [r[1] for r in res]

        # Plot losses and derivatives
        logger.info('Saving figures at {}'.format(figures_dir))

        key = 'loss'
        plt.figure(figsize=(12, 8))
        plot_losses_models(losses, key=key, L=L, width=width,
                           activation=activation, lr=base_lr,
                           batch_size=batch_size, mode=mode,
                           normalize_first=renorm_first, marker=None,
                           name='muP')
        plt.ylim(0, 2.5)
        plt.savefig(
            os.path.join(
                figures_dir,
                fig_name_template.format(mode, key, L, width, activation,
                                         base_lr, batch_size)))

        key = 'chi'
        plt.figure(figsize=(12, 8))
        plot_losses_models(chis, key=key, L=L, width=width,
                           activation=activation, lr=base_lr,
                           batch_size=batch_size, mode=mode, marker=None,
                           name='muP')
        plt.savefig(
            os.path.join(
                figures_dir,
                fig_name_template.format(mode, key, L, width, activation,
                                         base_lr, batch_size)))

    except Exception as e:
        logger.exception("Exception when running the script : {}".format(e))
def _run_trial(self, idx):
    """Run one train/val/test trial in its own directory.

    Skips the trial entirely if its directory already exists. On failure,
    dumps whatever results the model accumulated before re-raising.

    :param idx: 0-based trial index; selects the trial seed and names the
        trial folder 'trial_{idx+1}'.
    """
    trial_name = 'trial_{}'.format(idx + 1)
    self.trial_dir = os.path.join(
        self.base_experiment_path, trial_name)  # folder to hold trial results
    if not os.path.exists(
            self.trial_dir):  # run trial only if it doesn't already exist
        create_dir(self.trial_dir)  # directory to save the trial
        set_random_seeds(
            self.trial_seeds[idx])  # set random seed for the trial
        self._set_tb_logger_and_callbacks(
            trial_name)  # tb logger, checkpoints and early stopping
        log_dir = os.path.join(
            self.trial_dir,
            self.LOG_NAME)  # define path to save the logs of the trial
        logger = set_up_logger(log_dir)
        config = ModelConfig(
            config_dict=self.config_dict
        )  # define the config as a class to pass to the model
        model = self.model(config)  # define the model

        logger.info('----- Trial {:,} ----- with model config {}\n'.format(
            idx + 1, self.model_config))
        self._log_experiment_info(len(self.train_dataset),
                                  len(self.val_dataset),
                                  len(self.test_dataset), model.std)
        logger.info('Random seed used for the script : {:,}'.format(
            self.SEED))
        logger.info('Number of model parameters : {:,}'.format(
            model.count_parameters()))
        logger.info('Model architecture :\n{}\n'.format(model))

        try:
            # training and validation pipeline
            trainer = pl.Trainer(
                max_epochs=self.max_epochs,
                max_steps=self.max_steps,
                logger=self.tb_logger,
                checkpoint_callback=self.checkpoint_callback,
                num_sanity_val_steps=0,
                early_stop_callback=self.early_stopping_callback)
            trainer.fit(model=model,
                        train_dataloader=self.train_data_loader,
                        val_dataloaders=self.val_data_loader)

            # test pipeline
            test_results = trainer.test(
                model=model, test_dataloaders=self.test_data_loader)
            logger.info('Test results :\n{}\n'.format(test_results))

            # save all training, val and test results to pickle file
            with open(os.path.join(self.trial_dir, self.RESULTS_FILE),
                      'wb') as file:
                pickle.dump(model.results, file)
        except Exception as e:
            # dump and save results before exiting
            with open(os.path.join(self.trial_dir, self.RESULTS_FILE),
                      'wb') as file:
                pickle.dump(model.results, file)
            logger.warning('model results dumped before interruption')
            logger.exception(
                "Exception while running the train-val-test pipeline : {}".
                format(e))
            # BUG FIX: was `raise Exception(e)`, which replaced the original
            # exception type and discarded its traceback; a bare raise
            # re-raises the caught exception intact (callers catching
            # Exception are unaffected).
            raise
    else:
        # No per-trial logger exists in this branch, so use the module-level
        # logging facility.
        logging.warning(
            "Directory for trial {:,} of experiment {} already exists".
            format(idx, self.model_config))
def main(activation="relu", n_steps=300, base_lr=0.01, batch_size=512, dataset="mnist"):
    """Train calibrated IPLLR variants (plain calibration, calibration with
    input renormalization, and additionally with scaled first-layer warm lr)
    that share the same initial parameters, then plot losses and chi values.

    NOTE(review): relies on module-level globals CONFIG_PATH, FIGURES_DIR,
    SEED, N_TRIALS, L, width, n_warmup_steps and renorm_first — confirm these
    are defined before this function is called.
    """
    config_path = os.path.join(CONFIG_PATH, 'fc_ipllr_{}.yaml'.format(dataset))
    figures_dir = os.path.join(FIGURES_DIR, dataset)
    create_dir(figures_dir)
    log_path = os.path.join(figures_dir, 'log_ipllr_{}.txt'.format(activation))
    logger = set_up_logger(log_path)

    logger.info('Parameters of the run:')
    logger.info('activation = {}'.format(activation))
    logger.info('n_steps = {:,}'.format(n_steps))
    logger.info('base_lr = {}'.format(base_lr))
    logger.info('batch_size = {:,}'.format(batch_size))
    logger.info('dataset = {}'.format(dataset))
    logger.info('Random SEED : {:,}'.format(SEED))
    logger.info(
        'Number of random trials for each model : {:,}'.format(N_TRIALS))

    try:
        set_random_seeds(SEED)  # set random seed for reproducibility
        config_dict = read_yaml(config_path)

        fig_name_template = 'IPLLRs_1_last_small_{}_{}_L={}_m={}_act={}_lr={}_bs={}.png'

        # Override the YAML config with this run's hyper-parameters.
        config_dict['architecture']['width'] = width
        config_dict['architecture']['n_layers'] = L + 1
        config_dict['optimizer']['params']['lr'] = base_lr
        config_dict['activation']['name'] = activation
        # Warmup-switch scheduler with base-lr calibration enabled.
        config_dict['scheduler'] = {
            'name': 'warmup_switch',
            'params': {
                'n_warmup_steps': n_warmup_steps,
                'calibrate_base_lr': True,
                'default_calibration': False
            }
        }

        # Load data & define models
        logger.info('Loading data ...')
        if dataset == 'mnist':
            from utils.dataset.mnist import load_data
        elif dataset == 'cifar10':
            from utils.dataset.cifar10 import load_data
        elif dataset == 'cifar100':
            # TODO : add cifar100 to utils.dataset
            # NOTE(review): this branch leaves `load_data` unbound, so the
            # call below raises NameError — confirm cifar100 is unsupported.
            pass
        else:
            error = ValueError(
                "dataset must be one of ['mnist', 'cifar10', 'cifar100'] but was {}"
                .format(dataset))
            logger.error(error)
            raise error

        training_dataset, test_dataset = load_data(download=False,
                                                   flatten=True)
        train_data_loader = DataLoader(training_dataset, shuffle=True,
                                       batch_size=batch_size)
        batches = list(train_data_loader)
        logger.info('Number of batches (steps) per epoch : {:,}'.format(
            len(batches)))
        logger.info('Number of epochs : {:,}'.format(n_steps // len(batches)))

        # Uncalibrated models: only used as the common source of initial
        # parameters for the calibrated families below.
        config_dict['scheduler']['params']['calibrate_base_lr'] = False
        config = ModelConfig(config_dict)
        logger.info('Defining models')
        ipllrs = [FcIPLLR(config) for _ in range(N_TRIALS)]

        config_dict['scheduler']['params']['calibrate_base_lr'] = True
        config = ModelConfig(config_dict)
        ipllrs_calib = [
            FcIPLLR(config, lr_calibration_batches=batches)
            for _ in range(N_TRIALS)
        ]
        ipllrs_calib_renorm = [
            FcIPLLR(config, lr_calibration_batches=batches)
            for _ in range(N_TRIALS)
        ]
        ipllrs_calib_renorm_scale_lr = [
            FcIPLLR(config, lr_calibration_batches=batches)
            for _ in range(N_TRIALS)
        ]

        logger.info('Copying parameters of base ipllr')
        for i in range(N_TRIALS):
            ipllrs_calib[i].copy_initial_params_from_model(ipllrs[i])
            ipllrs_calib_renorm[i].copy_initial_params_from_model(ipllrs[i])
            ipllrs_calib_renorm_scale_lr[i].copy_initial_params_from_model(
                ipllrs[i])

            ipllrs_calib[i].initialize_params()
            ipllrs_calib_renorm[i].initialize_params()
            ipllrs_calib_renorm_scale_lr[i].initialize_params()

        # Make sure calibration takes into account normalization
        logger.info('Recalibrating lrs with new initialisation')
        for ipllr in ipllrs_calib:
            initial_base_lrs = ipllr.scheduler.calibrate_base_lr(
                ipllr, batches=batches, normalize_first=False)
            ipllr.scheduler._set_param_group_lrs(initial_base_lrs)

        for ipllr in ipllrs_calib_renorm:
            initial_base_lrs = ipllr.scheduler.calibrate_base_lr(
                ipllr, batches=batches, normalize_first=True)
            ipllr.scheduler._set_param_group_lrs(initial_base_lrs)

        for ipllr in ipllrs_calib_renorm_scale_lr:
            initial_base_lrs = ipllr.scheduler.calibrate_base_lr(
                ipllr, batches=batches, normalize_first=True)
            ipllr.scheduler._set_param_group_lrs(initial_base_lrs)

        # scale lr of first layer if needed
        for ipllr in ipllrs_calib_renorm_scale_lr:
            ipllr.scheduler.warm_lrs[0] = ipllr.scheduler.warm_lrs[0] * (
                ipllr.d + 1)

        # with calibration
        results = dict()
        logger.info('Generating training results ...')
        # Each entry is a list over trials; collect_training_losses returns
        # (losses, chis) per trial (see unpacking below).
        results['ipllr_calib'] = [
            collect_training_losses(ipllrs_calib[i], batches, n_steps,
                                    normalize_first=False)
            for i in range(N_TRIALS)
        ]
        results['ipllr_calib_renorm'] = [
            collect_training_losses(ipllrs_calib_renorm[i], batches, n_steps,
                                    normalize_first=True)
            for i in range(N_TRIALS)
        ]
        results['ipllr_calib_renorm_scale_lr'] = [
            collect_training_losses(ipllrs_calib_renorm_scale_lr[i], batches,
                                    n_steps, normalize_first=True)
            for i in range(N_TRIALS)
        ]

        mode = 'training'

        losses = dict()
        for key, res in results.items():
            losses[key] = [r[0] for r in res]

        chis = dict()
        for key, res in results.items():
            chis[key] = [r[1] for r in res]

        # Plot losses and derivatives
        logger.info('Saving figures at {}'.format(figures_dir))

        key = 'loss'
        plt.figure(figsize=(12, 8))
        plot_losses_models(losses, key=key, L=L, width=width,
                           activation=activation, lr=base_lr,
                           batch_size=batch_size, mode=mode,
                           normalize_first=renorm_first, marker=None,
                           name='IPLLR')
        plt.savefig(
            os.path.join(
                figures_dir,
                fig_name_template.format(mode, key, L, width, activation,
                                         base_lr, batch_size)))

        key = 'chi'
        plt.figure(figsize=(12, 8))
        plot_losses_models(chis, key=key, L=L, width=width,
                           activation=activation, lr=base_lr,
                           batch_size=batch_size, mode=mode, marker=None,
                           name='IPLLR')
        plt.savefig(
            os.path.join(
                figures_dir,
                fig_name_template.format(mode, key, L, width, activation,
                                         base_lr, batch_size)))

    except Exception as e:
        logger.exception("Exception when running the script : {}".format(e))
def main():
    """Train an IPLLR model for two steps, reconstruct the effective initial
    weights (W0, b0) and the first two updates (Delta_W_1/2, Delta_b_1/2)
    from the layer scales, then compute the exact row-echelon rank of the
    input layer's first update with sympy.
    """
    print('ROOT :', ROOT)
    print('CONFIG_PATH :', CONFIG_PATH)

    # constants
    SEED = 30
    L = 6
    width = 1024
    n_warmup_steps = 1
    batch_size = 512
    base_lr = 0.1

    set_random_seeds(SEED)  # set random seed for reproducibility
    config_dict = read_yaml(CONFIG_PATH)

    # Override the YAML config with this run's hyper-parameters.
    config_dict['architecture']['width'] = width
    config_dict['architecture']['n_layers'] = L + 1
    config_dict['optimizer']['params']['lr'] = base_lr
    config_dict['scheduler'] = {
        'name': 'warmup_switch',
        'params': {
            'n_warmup_steps': n_warmup_steps,
            'calibrate_base_lr': True,
            'default_calibration': False
        }
    }

    base_model_config = ModelConfig(config_dict)

    # Load data & define model
    training_dataset, test_dataset = load_data(download=False, flatten=True)
    train_data_loader = DataLoader(training_dataset, shuffle=True,
                                   batch_size=batch_size)
    batches = list(train_data_loader)
    # Full training set as single tensors (full_y is currently unused below).
    full_x = torch.cat([a for a, _ in batches], dim=0)
    full_y = torch.cat([b for _, b in batches], dim=0)

    # Define model
    # NOTE(review): n_warmup_steps=12 here overrides the scheduler config's
    # n_warmup_steps = 1 set above — confirm this discrepancy is intended.
    ipllr = FcIPLLR(base_model_config, n_warmup_steps=12,
                    lr_calibration_batches=batches)
    # Scale the input layer's warm lr by (d + 1).
    ipllr.scheduler.warm_lrs[0] = ipllr.scheduler.warm_lrs[0] * (ipllr.d + 1)

    # Save initial model : t=0
    ipllr_0 = deepcopy(ipllr)

    # Train model one step : t=1
    x, y = batches[0]
    train_model_one_step(ipllr, x, y, normalize_first=True)
    ipllr_1 = deepcopy(ipllr)

    # Train model for a second step : t=2
    x, y = batches[1]
    train_model_one_step(ipllr, x, y, normalize_first=True)
    ipllr_2 = deepcopy(ipllr)

    # Switch every snapshot to eval mode before reading weights.
    ipllr.eval()
    ipllr_0.eval()
    ipllr_1.eval()
    ipllr_2.eval()

    layer_scales = ipllr.layer_scales
    # Attribute names of the intermediate layers l = 2..L.
    intermediate_layer_keys = [
        "layer_{:,}_intermediate".format(l) for l in range(2, L + 1)
    ]

    # Define W0 and b0 (effective weights: raw weights times layer scale;
    # the input layer is additionally divided by sqrt(d + 1)).
    with torch.no_grad():
        W0 = {
            1:
            layer_scales[0] * ipllr_0.input_layer.weight.data.detach() /
            math.sqrt(ipllr_0.d + 1)
        }
        for i, l in enumerate(range(2, L + 1)):
            layer = getattr(ipllr_0.intermediate_layers,
                            intermediate_layer_keys[i])
            W0[l] = layer_scales[l - 1] * layer.weight.data.detach()
        W0[L + 1] = layer_scales[L] * ipllr_0.output_layer.weight.data.detach()

    with torch.no_grad():
        b0 = layer_scales[0] * ipllr_0.input_layer.bias.data.detach(
        ) / math.sqrt(ipllr_0.d + 1)

    # Define Delta_W_1 and Delta_b_1 (weight change from t=0 to t=1).
    with torch.no_grad():
        Delta_W_1 = {
            1:
            layer_scales[0] * (ipllr_1.input_layer.weight.data.detach() -
                               ipllr_0.input_layer.weight.data.detach()) /
            math.sqrt(ipllr_1.d + 1)
        }
        for i, l in enumerate(range(2, L + 1)):
            layer_1 = getattr(ipllr_1.intermediate_layers,
                              intermediate_layer_keys[i])
            layer_0 = getattr(ipllr_0.intermediate_layers,
                              intermediate_layer_keys[i])
            Delta_W_1[l] = layer_scales[l - 1] * (
                layer_1.weight.data.detach() - layer_0.weight.data.detach())
        Delta_W_1[
            L + 1] = layer_scales[L] * (ipllr_1.output_layer.weight.data.detach() -
                                        ipllr_0.output_layer.weight.data.detach())

    with torch.no_grad():
        Delta_b_1 = layer_scales[0] * (
            ipllr_1.input_layer.bias.data.detach() -
            ipllr_0.input_layer.bias.data.detach()) / math.sqrt(ipllr_1.d + 1)

    # Define Delta_W_2 (weight change from t=1 to t=2)
    with torch.no_grad():
        Delta_W_2 = {
            1:
            layer_scales[0] * (ipllr_2.input_layer.weight.data.detach() -
                               ipllr_1.input_layer.weight.data.detach()) /
            math.sqrt(ipllr_2.d + 1)
        }
        for i, l in enumerate(range(2, L + 1)):
            layer_2 = getattr(ipllr_2.intermediate_layers,
                              intermediate_layer_keys[i])
            layer_1 = getattr(ipllr_1.intermediate_layers,
                              intermediate_layer_keys[i])
            Delta_W_2[l] = layer_scales[l - 1] * (
                layer_2.weight.data.detach() - layer_1.weight.data.detach())
        Delta_W_2[
            L + 1] = layer_scales[L] * (ipllr_2.output_layer.weight.data.detach() -
                                        ipllr_1.output_layer.weight.data.detach())

    with torch.no_grad():
        # NOTE(review): this uses ipllr_1.d in the denominator where the t=2
        # pattern above uses ipllr_2.d — both copies share the same d, so the
        # value is identical, but confirm which was intended for consistency.
        Delta_b_2 = layer_scales[0] * (
            ipllr_2.input_layer.bias.data.detach() -
            ipllr_1.input_layer.bias.data.detach()) / math.sqrt(ipllr_1.d + 1)

    # Ranks: exact rank of the input layer's first update via row echelon
    # form (sympy works in exact arithmetic, hence the timing prints).
    print('computing sympy Matrix ...')
    M = sympy.Matrix(Delta_W_1[1].numpy().tolist())
    print('Computing row echelon form ...')
    start = time()
    row_echelon = M.rref()
    end = time()
    print('Time for computing row echelon form : {:.3f} minutes'.format(
        (end - start) / 60))

    print(row_echelon)
    print(row_echelon[1])  # indices of pivot columns; their count is the rank
    print(len(row_echelon[1]))