Example 1
0
    def __init__(self,
                 config_dict: dict,
                 base_experiment_path: str,
                 model: BaseABCParam,
                 train_dataset: Dataset,
                 test_dataset: Dataset,
                 val_dataset: Dataset = None,
                 train_ratio: float = 0.8,
                 n_trials: int = 10,
                 early_stopping=False):
        """Set up the experiment: config, datasets/loaders, model class and trial seeds.

        :param config_dict: parsed YAML config with 'architecture' and 'training' sections.
        :param base_experiment_path: root directory where trial results are saved.
        :param model: model *class* (not instance) to instantiate per trial.
        :param train_dataset: training dataset; split into train/val when val_dataset is None.
        :param test_dataset: held-out test dataset.
        :param val_dataset: optional validation dataset; if None, carved out of train_dataset.
        :param train_ratio: fraction of train_dataset kept for training when splitting.
        :param n_trials: number of independent trials (one random seed each).
        :param early_stopping: fallback early-stopping flag when absent from the config.
        """
        self.width = config_dict['architecture']['width']
        self.batch_size = config_dict['training']['batch_size']
        self._set_base_lr(config_dict)

        super().__init__(config_dict, base_experiment_path)

        # Prefer explicit config values, fall back to class-level defaults.
        if 'n_epochs' in config_dict['training'].keys():
            self.max_epochs = config_dict['training']['n_epochs']
        else:
            self.max_epochs = self.MAX_EPOCHS

        if 'n_steps' in config_dict['training'].keys():
            self.max_steps = config_dict['training']['n_steps']
        else:
            self.max_steps = self.MAX_STEPS

        if val_dataset is None:
            self._set_train_val_data_from_train(train_dataset, train_ratio)
        else:
            # Bug fix: previously a user-supplied val_dataset was silently dropped
            # (self.train_dataset / self.val_dataset were never set), which breaks
            # _run_trial's len(self.train_dataset) / len(self.val_dataset) calls.
            self.train_dataset = train_dataset
            self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self._set_data_loaders()

        self.model = model
        self.n_trials = n_trials

        # Config value takes precedence over the constructor argument.
        if 'early_stopping' in config_dict['training'].keys():
            self.early_stopping = config_dict['training']['early_stopping']
        else:
            self.early_stopping = early_stopping

        self.early_stopping_callback = False  # this is modified in _set_tb_logger_and_callbacks in early_stopping=True

        set_random_seeds(self.SEED)  # set random seed for reproducibility
        self.trial_seeds = np.random.randint(
            0, 100, size=n_trials)  # define random seeds to use for each trial
Example 2
0
def main(activation="relu", base_lr=0.01, batch_size=512, dataset="mnist"):
    """Run the muP rank-analysis experiment: train N_TRIALS muP models for two
    steps, then compute and plot the ranks of the weight matrices, their first
    two updates, and the corresponding (pre-)activations.

    :param activation: activation function name written into the config.
    :param base_lr: base learning rate written into the config.
    :param batch_size: training batch size.
    :param dataset: one of 'mnist', 'cifar10' (cifar100 not yet supported).

    NOTE: relies on module-level globals `width`, `L`, `SEED`, `N_TRIALS`,
    `CONFIG_PATH` and `FIGURES_DIR` — assumed defined at import time.
    """
    config_path = os.path.join(CONFIG_PATH, 'fc_ipllr_{}.yaml'.format(dataset))
    figures_dir = os.path.join(FIGURES_DIR, dataset)
    create_dir(figures_dir)
    log_path = os.path.join(figures_dir, 'log_muP_{}.txt'.format(activation))
    logger = set_up_logger(log_path)

    logger.info('Parameters of the run:')
    logger.info('activation = {}'.format(activation))
    logger.info('base_lr = {}'.format(base_lr))
    logger.info('batch_size = {:,}'.format(batch_size))
    logger.info('Random SEED : {:,}'.format(SEED))
    logger.info(
        'Number of random trials for each model : {:,}'.format(N_TRIALS))

    try:
        set_random_seeds(SEED)  # set random seed for reproducibility
        config_dict = read_yaml(config_path)

        version = 'L={}_m={}_act={}_lr={}_bs={}'.format(
            L, width, activation, base_lr, batch_size)
        template_name = 'muP_{}_ranks_{}_' + version

        config_dict['architecture']['width'] = width
        config_dict['architecture']['n_layers'] = L + 1
        config_dict['optimizer']['params']['lr'] = base_lr
        config_dict['activation']['name'] = activation

        base_model_config = ModelConfig(config_dict)

        # Load data & define models
        logger.info('Loading data ...')
        if dataset == 'mnist':
            from utils.dataset.mnist import load_data
        elif dataset == 'cifar10':
            from utils.dataset.cifar10 import load_data
        elif dataset == 'cifar100':
            # TODO : add cifar100 to utils.dataset
            # Bug fix: falling through with `pass` left `load_data` unbound and
            # produced a confusing NameError below; fail fast and explicitly.
            error = NotImplementedError(
                "cifar100 is not yet available in utils.dataset")
            logger.error(error)
            raise error
        else:
            error = ValueError(
                "dataset must be one of ['mnist', 'cifar10', 'cifar100'] but was {}"
                .format(dataset))
            logger.error(error)
            raise error

        training_dataset, test_dataset = load_data(download=False,
                                                   flatten=True)
        train_data_loader = DataLoader(training_dataset,
                                       shuffle=True,
                                       batch_size=batch_size)
        batches = list(train_data_loader)

        full_x = torch.cat([a for a, _ in batches], dim=0)
        full_y = torch.cat([b for _, b in batches], dim=0)

        logger.info('Defining models')
        base_model_config.scheduler = None
        muPs = [FCmuP(base_model_config) for _ in range(N_TRIALS)]

        # Scale the first (input) layer's lr by the input dimension + 1.
        for muP in muPs:
            for i, param_group in enumerate(muP.optimizer.param_groups):
                if i == 0:
                    param_group['lr'] = param_group['lr'] * (muP.d + 1)

        # save initial models
        muPs_0 = [deepcopy(muP) for muP in muPs]

        # train model one step
        logger.info('Training model a first step (t=1)')
        x, y = batches[0]
        muPs_1 = []
        for muP in muPs:
            train_model_one_step(muP, x, y, normalize_first=True)
            muPs_1.append(deepcopy(muP))

        # train models for a second step
        logger.info('Training model a second step (t=2)')
        x, y = batches[1]
        muPs_2 = []
        for muP in muPs:
            train_model_one_step(muP, x, y, normalize_first=True)
            muPs_2.append(deepcopy(muP))

        # set eval mode for all models
        for i in range(N_TRIALS):
            muPs[i].eval()
            muPs_0[i].eval()
            muPs_1[i].eval()
            muPs_2[i].eval()

        logger.info('Storing initial and update matrices')
        # define W0 and b0
        W0s = []
        b0s = []
        for muP_0 in muPs_0:
            W0, b0 = get_W0_dict(muP_0, normalize_first=True)
            W0s.append(W0)
            b0s.append(b0)

        # define Delta_W_1 and Delta_b_1
        Delta_W_1s = []
        Delta_b_1s = []
        for i in range(N_TRIALS):
            Delta_W_1, Delta_b_1 = get_Delta_W1_dict(muPs_0[i],
                                                     muPs_1[i],
                                                     normalize_first=True)
            Delta_W_1s.append(Delta_W_1)
            Delta_b_1s.append(Delta_b_1)

        # define Delta_W_2 and Delta_b_2
        Delta_W_2s = []
        Delta_b_2s = []
        for i in range(N_TRIALS):
            Delta_W_2, Delta_b_2 = get_Delta_W2_dict(muPs_1[i],
                                                     muPs_2[i],
                                                     normalize_first=True)
            Delta_W_2s.append(Delta_W_2)
            Delta_b_2s.append(Delta_b_2)

        x, y = full_x, full_y  # compute pre-activations on full batch

        # contributions after first step
        h0s = []
        delta_h_1s = []
        h1s = []
        x1s = []
        for i in range(N_TRIALS):
            h0, delta_h_1, h1, x1 = get_contributions_1(x,
                                                        muPs_1[i],
                                                        W0s[i],
                                                        b0s[i],
                                                        Delta_W_1s[i],
                                                        Delta_b_1s[i],
                                                        normalize_first=True)
            h0s.append(h0)
            delta_h_1s.append(delta_h_1)
            h1s.append(h1)  # bug fix: was appending h0, corrupting all h1 ranks
            x1s.append(x1)

        # ranks of initial weight matrices and first two updates
        logger.info('Computing ranks of weight matrices ...')
        weight_ranks_dfs_dict = dict()

        tol = None
        weight_ranks_dfs_dict['svd_default'] = [
            get_svd_ranks_weights(W0s[i],
                                  Delta_W_1s[i],
                                  Delta_W_2s[i],
                                  L,
                                  tol=tol) for i in range(N_TRIALS)
        ]

        tol = 1e-7
        weight_ranks_dfs_dict['svd_tol'] = [
            get_svd_ranks_weights(W0s[i],
                                  Delta_W_1s[i],
                                  Delta_W_2s[i],
                                  L,
                                  tol=tol) for i in range(N_TRIALS)
        ]

        weight_ranks_dfs_dict['squared_tr'] = [
            get_square_trace_ranks_weights(W0s[i], Delta_W_1s[i],
                                           Delta_W_2s[i], L)
            for i in range(N_TRIALS)
        ]

        # Concatenate the per-trial rank dfs, then average over trials.
        weight_ranks_df_dict = {
            key: get_concatenated_ranks_df(weight_ranks_dfs_dict[key])
            for key in weight_ranks_dfs_dict.keys()
        }
        avg_ranks_df_dict = {
            key: get_avg_ranks_dfs(weight_ranks_df_dict[key])
            for key in weight_ranks_df_dict.keys()
        }

        logger.info('Saving weights ranks data frames to csv ...')
        for key in weight_ranks_df_dict.keys():
            logger.info(key)
            logger.info('\n' + str(avg_ranks_df_dict[key]) + '\n\n')
            avg_ranks_df_dict[key].to_csv(os.path.join(
                figures_dir,
                template_name.format(key, 'weights') + '.csv'),
                                          header=True,
                                          index=True)

        ranks_dfs = [
            weight_ranks_df_dict['svd_default'],
            weight_ranks_df_dict['svd_tol'], weight_ranks_df_dict['squared_tr']
        ]

        # plot weights ranks
        logger.info('Plotting weights ranks')
        plt.figure(figsize=(12, 6))
        plot_weights_ranks_vs_layer('W0',
                                    ranks_dfs,
                                    tol,
                                    L,
                                    width,
                                    base_lr,
                                    batch_size,
                                    y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('W0', 'weights') + '.png'))

        plt.figure(figsize=(12, 6))
        plot_weights_ranks_vs_layer('Delta_W_1',
                                    ranks_dfs,
                                    tol,
                                    L,
                                    width,
                                    base_lr,
                                    batch_size,
                                    y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('Delta_W_1', 'weights') +
                         '.png'))

        plt.figure(figsize=(12, 6))
        plot_weights_ranks_vs_layer('Delta_W_2',
                                    ranks_dfs,
                                    tol,
                                    L,
                                    width,
                                    base_lr,
                                    batch_size,
                                    y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('Delta_W_2', 'weights') +
                         '.png'))

        # ranks of the pre-activations
        logger.info('Computing ranks of (pre-)activations ...')
        act_ranks_dfs_dict = dict()

        tol = None
        act_ranks_dfs_dict['svd_default'] = [
            get_svd_ranks_acts(h0s[i],
                               delta_h_1s[i],
                               h1s[i],
                               x1s[i],
                               L,
                               tol=tol) for i in range(N_TRIALS)
        ]

        tol = 1e-7
        act_ranks_dfs_dict['svd_tol'] = [
            get_svd_ranks_acts(h0s[i],
                               delta_h_1s[i],
                               h1s[i],
                               x1s[i],
                               L,
                               tol=tol) for i in range(N_TRIALS)
        ]

        act_ranks_dfs_dict['squared_tr'] = [
            get_square_trace_ranks_acts(h0s[i], delta_h_1s[i], h1s[i], x1s[i],
                                        L) for i in range(N_TRIALS)
        ]

        act_ranks_df_dict = {
            key: get_concatenated_ranks_df(act_ranks_dfs_dict[key])
            for key in act_ranks_dfs_dict.keys()
        }
        avg_ranks_df_dict = {
            key: get_avg_ranks_dfs(act_ranks_df_dict[key])
            for key in act_ranks_df_dict.keys()
        }

        logger.info('Saving (pre-)activation ranks data frames to csv ...')
        for key in avg_ranks_df_dict.keys():
            logger.info(key)
            logger.info('\n' + str(avg_ranks_df_dict[key]) + '\n\n')
            avg_ranks_df_dict[key].to_csv(os.path.join(
                figures_dir,
                template_name.format(key, 'acts') + '.csv'),
                                          header=True,
                                          index=True)

        ranks_dfs = [
            act_ranks_df_dict['svd_default'], act_ranks_df_dict['svd_tol'],
            act_ranks_df_dict['squared_tr']
        ]

        logger.info('Plotting (pre-)activation ranks')
        plt.figure(figsize=(12, 6))
        plot_acts_ranks_vs_layer('h0',
                                 ranks_dfs,
                                 tol,
                                 L,
                                 width,
                                 base_lr,
                                 batch_size,
                                 y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('h0', 'acts') + '.png'))

        plt.figure(figsize=(12, 6))
        plot_acts_ranks_vs_layer('h1',
                                 ranks_dfs,
                                 tol,
                                 L,
                                 width,
                                 base_lr,
                                 batch_size,
                                 y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('h1', 'acts') + '.png'))

        plt.figure(figsize=(12, 6))
        plot_acts_ranks_vs_layer('x1',
                                 ranks_dfs,
                                 tol,
                                 L,
                                 width,
                                 base_lr,
                                 batch_size,
                                 y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('x1', 'acts') + '.png'))

        plt.figure(figsize=(12, 6))
        plot_acts_ranks_vs_layer('delta_h_1',
                                 ranks_dfs,
                                 tol,
                                 L,
                                 width,
                                 base_lr,
                                 batch_size,
                                 y_scale='log')
        plt.savefig(
            os.path.join(figures_dir,
                         template_name.format('delta_h_1', 'acts') + '.png'))

        # diversity in terms of the index of the maximum entry
        logger.info(
            'Computing diversity of the maximum entry of pre-activations...')
        max_acts_diversity_dfs = [
            get_max_acts_diversity(h0s[i], delta_h_1s[i], h1s[i], L)
            for i in range(N_TRIALS)
        ]
        max_acts_diversity_df = get_concatenated_ranks_df(
            max_acts_diversity_dfs)
        avg_max_acts_diversity_df = get_avg_ranks_dfs(max_acts_diversity_df)
        logger.info('Diversity of the maximum activation index df:')
        logger.info(str(avg_max_acts_diversity_df))
        avg_max_acts_diversity_df.to_csv(os.path.join(
            figures_dir, 'muP_max_acts_' + version + '.csv'),
                                         header=True,
                                         index=True)

    except Exception as e:
        logger.exception("Exception when running the script : {}".format(e))
Example 3
0
def main(activation="relu",
         n_steps=300,
         base_lr=0.01,
         batch_size=512,
         dataset="mnist"):
    """Train muP models (plain, renormalized, renormalized + scaled first-layer
    lr) for `n_steps` and plot training losses and chi values.

    :param activation: activation function name written into the config.
    :param n_steps: number of training steps per model.
    :param base_lr: base learning rate written into the config.
    :param batch_size: training batch size.
    :param dataset: one of 'mnist', 'cifar10' (cifar100 not yet supported).

    NOTE: relies on module-level globals `width`, `L`, `SEED`, `N_TRIALS`,
    `renorm_first`, `CONFIG_PATH` and `FIGURES_DIR` — TODO confirm they are
    defined at module scope (`renorm_first` in particular is only read here).
    """
    config_path = os.path.join(CONFIG_PATH, 'fc_ipllr_{}.yaml'.format(dataset))
    figures_dir = os.path.join(FIGURES_DIR, dataset)
    create_dir(figures_dir)
    log_path = os.path.join(figures_dir, 'log_muP_{}.txt'.format(activation))
    logger = set_up_logger(log_path)

    logger.info('Parameters of the run:')
    logger.info('activation = {}'.format(activation))
    logger.info('n_steps = {:,}'.format(n_steps))
    logger.info('base_lr = {}'.format(base_lr))
    logger.info('batch_size = {:,}'.format(batch_size))
    logger.info('Random SEED : {:,}'.format(SEED))
    logger.info(
        'Number of random trials for each model : {:,}'.format(N_TRIALS))

    try:
        set_random_seeds(SEED)  # set random seed for reproducibility
        config_dict = read_yaml(config_path)

        fig_name_template = 'muP_{}_{}_L={}_m={}_act={}_lr={}_bs={}.png'

        config_dict['architecture']['width'] = width
        config_dict['architecture']['n_layers'] = L + 1
        config_dict['optimizer']['params']['lr'] = base_lr
        config_dict['activation']['name'] = activation

        base_model_config = ModelConfig(config_dict)

        # Load data & define models
        logger.info('Loading data ...')
        if dataset == 'mnist':
            from utils.dataset.mnist import load_data
        elif dataset == 'cifar10':
            from utils.dataset.cifar10 import load_data
        elif dataset == 'cifar100':
            # TODO : add cifar100 to utils.dataset
            config_dict['architecture']['output_size'] = 100
            # Bug fix: falling through with `pass` left `load_data` unbound
            # and produced a confusing NameError below; fail fast instead.
            error = NotImplementedError(
                "cifar100 is not yet available in utils.dataset")
            logger.error(error)
            raise error
        else:
            error = ValueError(
                "dataset must be one of ['mnist', 'cifar10', 'cifar100'] but was {}"
                .format(dataset))
            logger.error(error)
            raise error

        training_dataset, test_dataset = load_data(download=False,
                                                   flatten=True)
        train_data_loader = DataLoader(training_dataset,
                                       shuffle=True,
                                       batch_size=batch_size)
        batches = list(train_data_loader)

        logger.info('Defining models')
        base_model_config.scheduler = None
        muPs = [FCmuP(base_model_config) for _ in range(N_TRIALS)]
        muPs_renorm = [FCmuP(base_model_config) for _ in range(N_TRIALS)]
        muPs_renorm_scale_lr = [
            FCmuP(base_model_config) for _ in range(N_TRIALS)
        ]

        # Scale the first (input) layer's lr by the input dimension + 1.
        for muP in muPs_renorm_scale_lr:
            for i, param_group in enumerate(muP.optimizer.param_groups):
                if i == 0:
                    param_group['lr'] = param_group['lr'] * (muP.d + 1)

        # All variants start from the same initial parameters as the base muP.
        logger.info('Copying parameters of base muP')
        for i in range(N_TRIALS):
            muPs_renorm[i].copy_initial_params_from_model(muPs[i])
            muPs_renorm_scale_lr[i].copy_initial_params_from_model(muPs[i])

            muPs_renorm[i].initialize_params()
            muPs_renorm_scale_lr[i].initialize_params()

        results = dict()
        logger.info('Generating training results ...')
        results['muP'] = [
            collect_training_losses(muPs[i],
                                    batches,
                                    n_steps,
                                    normalize_first=False)
            for i in range(N_TRIALS)
        ]

        results['muP_renorm'] = [
            collect_training_losses(muPs_renorm[i],
                                    batches,
                                    n_steps,
                                    normalize_first=True)
            for i in range(N_TRIALS)
        ]

        results['muP_renorm_scale_lr'] = [
            collect_training_losses(muPs_renorm_scale_lr[i],
                                    batches,
                                    n_steps,
                                    normalize_first=True)
            for i in range(N_TRIALS)
        ]

        mode = 'training'
        # collect_training_losses returns (losses, chis) per trial.
        losses = dict()
        for key, res in results.items():
            losses[key] = [r[0] for r in res]

        chis = dict()
        for key, res in results.items():
            chis[key] = [r[1] for r in res]

        # Plot losses and derivatives
        logger.info('Saving figures at {}'.format(figures_dir))
        key = 'loss'
        plt.figure(figsize=(12, 8))
        plot_losses_models(losses,
                           key=key,
                           L=L,
                           width=width,
                           activation=activation,
                           lr=base_lr,
                           batch_size=batch_size,
                           mode=mode,
                           normalize_first=renorm_first,
                           marker=None,
                           name='muP')
        plt.ylim(0, 2.5)
        plt.savefig(
            os.path.join(
                figures_dir,
                fig_name_template.format(mode, key, L, width, activation,
                                         base_lr, batch_size)))

        key = 'chi'
        plt.figure(figsize=(12, 8))
        plot_losses_models(chis,
                           key=key,
                           L=L,
                           width=width,
                           activation=activation,
                           lr=base_lr,
                           batch_size=batch_size,
                           mode=mode,
                           marker=None,
                           name='muP')
        plt.savefig(
            os.path.join(
                figures_dir,
                fig_name_template.format(mode, key, L, width, activation,
                                         base_lr, batch_size)))

    except Exception as e:
        logger.exception("Exception when running the script : {}".format(e))
Example 4
0
    def _run_trial(self, idx):
        """Run one train/val/test trial in its own directory.

        Skips the trial entirely if its directory already exists (assumed to
        be a previous run). Otherwise seeds RNGs with the trial's seed, fits
        the model with pytorch-lightning, tests it, and pickles the results.

        :param idx: 0-based trial index; used for naming and seed lookup.
        :raises Exception: re-raises any failure from the fit/test pipeline
            after dumping the partial results to disk.
        """
        trial_name = 'trial_{}'.format(idx + 1)
        self.trial_dir = os.path.join(
            self.base_experiment_path,
            trial_name)  # folder to hold trial results

        if not os.path.exists(
                self.trial_dir):  # run trial only if it doesn't already exist
            create_dir(self.trial_dir)  # directory to save the trial
            set_random_seeds(
                self.trial_seeds[idx])  # set random seed for the trial

            self._set_tb_logger_and_callbacks(
                trial_name)  # tb logger, checkpoints and early stopping

            log_dir = os.path.join(
                self.trial_dir,
                self.LOG_NAME)  # define path to save the logs of the trial
            logger = set_up_logger(log_dir)

            config = ModelConfig(
                config_dict=self.config_dict
            )  # define the config as a class to pass to the model
            model = self.model(config)  # define the model

            logger.info('----- Trial {:,} ----- with model config {}\n'.format(
                idx + 1, self.model_config))
            self._log_experiment_info(len(self.train_dataset),
                                      len(self.val_dataset),
                                      len(self.test_dataset), model.std)
            logger.info('Random seed used for the script : {:,}'.format(
                self.SEED))
            logger.info('Number of model parameters : {:,}'.format(
                model.count_parameters()))
            logger.info('Model architecture :\n{}\n'.format(model))

            try:
                # training and validation pipeline
                trainer = pl.Trainer(
                    max_epochs=self.max_epochs,
                    max_steps=self.max_steps,
                    logger=self.tb_logger,
                    checkpoint_callback=self.checkpoint_callback,
                    num_sanity_val_steps=0,
                    early_stop_callback=self.early_stopping_callback)
                trainer.fit(model=model,
                            train_dataloader=self.train_data_loader,
                            val_dataloaders=self.val_data_loader)

                # test pipeline
                test_results = trainer.test(
                    model=model, test_dataloaders=self.test_data_loader)
                logger.info('Test results :\n{}\n'.format(test_results))

                # save all training, val and test results to pickle file
                with open(os.path.join(self.trial_dir, self.RESULTS_FILE),
                          'wb') as file:
                    pickle.dump(model.results, file)

            except Exception as e:
                # dump and save results before exiting
                with open(os.path.join(self.trial_dir, self.RESULTS_FILE),
                          'wb') as file:
                    pickle.dump(model.results, file)
                logger.warning('model results dumped before interruption')
                logger.exception(
                    "Exception while running the train-val-test pipeline : {}".
                    format(e))
                # Bug fix: `raise Exception(e)` discarded the original
                # exception type and traceback; a bare raise preserves both.
                raise

        else:
            logging.warning(
                "Directory for trial {:,} of experiment {} already exists".
                format(idx, self.model_config))
Example 5
0
def main(activation="relu",
         n_steps=300,
         base_lr=0.01,
         batch_size=512,
         dataset="mnist"):
    """Train calibrated IPLLR variants (plain calibration, calibration with
    renormalization, and renormalization with scaled first-layer lr) for
    `n_steps`, then plot training losses and chi values.

    :param activation: activation function name written into the config.
    :param n_steps: number of training steps per model.
    :param base_lr: base learning rate written into the config.
    :param batch_size: training batch size.
    :param dataset: one of 'mnist', 'cifar10' (cifar100 not yet supported).

    NOTE: relies on module-level globals `width`, `L`, `SEED`, `N_TRIALS`,
    `n_warmup_steps`, `renorm_first`, `CONFIG_PATH` and `FIGURES_DIR` —
    assumed defined at module scope.
    """
    config_path = os.path.join(CONFIG_PATH, 'fc_ipllr_{}.yaml'.format(dataset))
    figures_dir = os.path.join(FIGURES_DIR, dataset)
    create_dir(figures_dir)
    log_path = os.path.join(figures_dir, 'log_ipllr_{}.txt'.format(activation))
    logger = set_up_logger(log_path)

    logger.info('Parameters of the run:')
    logger.info('activation = {}'.format(activation))
    logger.info('n_steps = {:,}'.format(n_steps))
    logger.info('base_lr = {}'.format(base_lr))
    logger.info('batch_size = {:,}'.format(batch_size))
    logger.info('dataset = {}'.format(dataset))
    logger.info('Random SEED : {:,}'.format(SEED))
    logger.info(
        'Number of random trials for each model : {:,}'.format(N_TRIALS))

    try:
        set_random_seeds(SEED)  # set random seed for reproducibility
        config_dict = read_yaml(config_path)

        fig_name_template = 'IPLLRs_1_last_small_{}_{}_L={}_m={}_act={}_lr={}_bs={}.png'

        config_dict['architecture']['width'] = width
        config_dict['architecture']['n_layers'] = L + 1
        config_dict['optimizer']['params']['lr'] = base_lr
        config_dict['activation']['name'] = activation
        config_dict['scheduler'] = {
            'name': 'warmup_switch',
            'params': {
                'n_warmup_steps': n_warmup_steps,
                'calibrate_base_lr': True,
                'default_calibration': False
            }
        }

        # Load data & define models
        logger.info('Loading data ...')
        if dataset == 'mnist':
            from utils.dataset.mnist import load_data
        elif dataset == 'cifar10':
            from utils.dataset.cifar10 import load_data
        elif dataset == 'cifar100':
            # TODO : add cifar100 to utils.dataset
            # Bug fix: falling through with `pass` left `load_data` unbound
            # and produced a confusing NameError below; fail fast instead.
            error = NotImplementedError(
                "cifar100 is not yet available in utils.dataset")
            logger.error(error)
            raise error
        else:
            error = ValueError(
                "dataset must be one of ['mnist', 'cifar10', 'cifar100'] but was {}"
                .format(dataset))
            logger.error(error)
            raise error

        training_dataset, test_dataset = load_data(download=False,
                                                   flatten=True)
        train_data_loader = DataLoader(training_dataset,
                                       shuffle=True,
                                       batch_size=batch_size)
        batches = list(train_data_loader)
        logger.info('Number of batches (steps) per epoch : {:,}'.format(
            len(batches)))
        logger.info('Number of epochs : {:,}'.format(n_steps // len(batches)))

        # Base models are built without lr calibration ...
        config_dict['scheduler']['params']['calibrate_base_lr'] = False
        config = ModelConfig(config_dict)

        logger.info('Defining models')
        ipllrs = [FcIPLLR(config) for _ in range(N_TRIALS)]

        # ... calibrated variants re-enable it and get calibration batches.
        config_dict['scheduler']['params']['calibrate_base_lr'] = True
        config = ModelConfig(config_dict)
        ipllrs_calib = [
            FcIPLLR(config, lr_calibration_batches=batches)
            for _ in range(N_TRIALS)
        ]
        ipllrs_calib_renorm = [
            FcIPLLR(config, lr_calibration_batches=batches)
            for _ in range(N_TRIALS)
        ]
        ipllrs_calib_renorm_scale_lr = [
            FcIPLLR(config, lr_calibration_batches=batches)
            for _ in range(N_TRIALS)
        ]

        # All variants start from the same initial parameters as the base ipllr.
        logger.info('Copying parameters of base ipllr')
        for i in range(N_TRIALS):
            ipllrs_calib[i].copy_initial_params_from_model(ipllrs[i])
            ipllrs_calib_renorm[i].copy_initial_params_from_model(ipllrs[i])
            ipllrs_calib_renorm_scale_lr[i].copy_initial_params_from_model(
                ipllrs[i])

            ipllrs_calib[i].initialize_params()
            ipllrs_calib_renorm[i].initialize_params()
            ipllrs_calib_renorm_scale_lr[i].initialize_params()

        # Make sure calibration takes into account normalization
        logger.info('Recalibrating lrs with new initialisation')
        for ipllr in ipllrs_calib:
            initial_base_lrs = ipllr.scheduler.calibrate_base_lr(
                ipllr, batches=batches, normalize_first=False)
            ipllr.scheduler._set_param_group_lrs(initial_base_lrs)

        for ipllr in ipllrs_calib_renorm:
            initial_base_lrs = ipllr.scheduler.calibrate_base_lr(
                ipllr, batches=batches, normalize_first=True)
            ipllr.scheduler._set_param_group_lrs(initial_base_lrs)

        for ipllr in ipllrs_calib_renorm_scale_lr:
            initial_base_lrs = ipllr.scheduler.calibrate_base_lr(
                ipllr, batches=batches, normalize_first=True)
            ipllr.scheduler._set_param_group_lrs(initial_base_lrs)

        # scale lr of first layer if needed
        for ipllr in ipllrs_calib_renorm_scale_lr:
            ipllr.scheduler.warm_lrs[0] = ipllr.scheduler.warm_lrs[0] * (
                ipllr.d + 1)

        # with calibration
        results = dict()
        logger.info('Generating training results ...')
        results['ipllr_calib'] = [
            collect_training_losses(ipllrs_calib[i],
                                    batches,
                                    n_steps,
                                    normalize_first=False)
            for i in range(N_TRIALS)
        ]

        results['ipllr_calib_renorm'] = [
            collect_training_losses(ipllrs_calib_renorm[i],
                                    batches,
                                    n_steps,
                                    normalize_first=True)
            for i in range(N_TRIALS)
        ]

        results['ipllr_calib_renorm_scale_lr'] = [
            collect_training_losses(ipllrs_calib_renorm_scale_lr[i],
                                    batches,
                                    n_steps,
                                    normalize_first=True)
            for i in range(N_TRIALS)
        ]

        mode = 'training'
        # collect_training_losses returns (losses, chis) per trial.
        losses = dict()
        for key, res in results.items():
            losses[key] = [r[0] for r in res]

        chis = dict()
        for key, res in results.items():
            chis[key] = [r[1] for r in res]

        # Plot losses and derivatives
        logger.info('Saving figures at {}'.format(figures_dir))
        key = 'loss'
        plt.figure(figsize=(12, 8))
        plot_losses_models(losses,
                           key=key,
                           L=L,
                           width=width,
                           activation=activation,
                           lr=base_lr,
                           batch_size=batch_size,
                           mode=mode,
                           normalize_first=renorm_first,
                           marker=None,
                           name='IPLLR')

        plt.savefig(
            os.path.join(
                figures_dir,
                fig_name_template.format(mode, key, L, width, activation,
                                         base_lr, batch_size)))

        key = 'chi'
        plt.figure(figsize=(12, 8))
        plot_losses_models(chis,
                           key=key,
                           L=L,
                           width=width,
                           activation=activation,
                           lr=base_lr,
                           batch_size=batch_size,
                           mode=mode,
                           marker=None,
                           name='IPLLR')
        plt.savefig(
            os.path.join(
                figures_dir,
                fig_name_template.format(mode, key, L, width, activation,
                                         base_lr, batch_size)))

    except Exception as e:
        logger.exception("Exception when running the script : {}".format(e))
Example 6
0
def main():
    """Run a two-step IPLLR training experiment and inspect weight-update ranks.

    Builds an FcIPLLR fully-connected model (L hidden layers, width 1024)
    with a 'warmup_switch' LR scheduler, snapshots the model at t=0, t=1
    and t=2 (one training step per batch via ``train_model_one_step``),
    then forms the scaled initial weights W0 and the scaled weight updates
    Delta_W_1 / Delta_W_2 for every layer, and finally computes the row
    echelon form of the first-layer update with sympy to read off its rank.

    Side effects: prints progress and the rref result to stdout; no return value.
    """
    print('ROOT :', ROOT)
    print('CONFIG_PATH :', CONFIG_PATH)

    # constants
    SEED = 30
    L = 6  # number of hidden layers; the network has L + 1 layers in total
    width = 1024
    n_warmup_steps = 1
    batch_size = 512
    base_lr = 0.1

    set_random_seeds(SEED)  # set random seed for reproducibility
    config_dict = read_yaml(CONFIG_PATH)

    # Override the YAML config with the experiment constants above.
    config_dict['architecture']['width'] = width
    config_dict['architecture']['n_layers'] = L + 1
    config_dict['optimizer']['params']['lr'] = base_lr
    config_dict['scheduler'] = {
        'name': 'warmup_switch',
        'params': {
            'n_warmup_steps': n_warmup_steps,
            'calibrate_base_lr': True,
            'default_calibration': False
        }
    }

    base_model_config = ModelConfig(config_dict)

    # Load data & define model

    training_dataset, test_dataset = load_data(download=False, flatten=True)
    train_data_loader = DataLoader(training_dataset,
                                   shuffle=True,
                                   batch_size=batch_size)
    # Materialize all batches up-front so the same batches can be reused
    # both for LR calibration and for the two training steps below.
    batches = list(train_data_loader)

    # NOTE(review): full_x / full_y are built but never used afterwards in
    # this function — possibly dead code left over from an earlier version.
    full_x = torch.cat([a for a, _ in batches], dim=0)
    full_y = torch.cat([b for _, b in batches], dim=0)

    # Define model

    # NOTE(review): n_warmup_steps=12 here ignores the local constant
    # n_warmup_steps = 1 defined above (that constant only reaches the
    # scheduler config dict) — confirm which value is intended.
    ipllr = FcIPLLR(base_model_config,
                    n_warmup_steps=12,
                    lr_calibration_batches=batches)
    # Rescale the first warm LR by the input dimension (+1, presumably for
    # the bias term — TODO confirm against FcIPLLR's definition of d).
    ipllr.scheduler.warm_lrs[0] = ipllr.scheduler.warm_lrs[0] * (ipllr.d + 1)

    # Save initial model : t=0
    ipllr_0 = deepcopy(ipllr)

    # Train model one step : t=1
    x, y = batches[0]
    train_model_one_step(ipllr, x, y, normalize_first=True)
    ipllr_1 = deepcopy(ipllr)

    # Train model for a second step : t=2
    x, y = batches[1]
    train_model_one_step(ipllr, x, y, normalize_first=True)
    ipllr_2 = deepcopy(ipllr)

    # Put every snapshot in eval mode before reading weights.
    ipllr.eval()
    ipllr_0.eval()
    ipllr_1.eval()
    ipllr_2.eval()

    layer_scales = ipllr.layer_scales
    # Attribute names of the intermediate layers (layers 2..L).
    # NOTE(review): the "{:,}" format spec adds thousands separators, which
    # is a no-op for l <= 9 here; a plain "{}" was likely intended.
    intermediate_layer_keys = [
        "layer_{:,}_intermediate".format(l) for l in range(2, L + 1)
    ]

    # Define W0 and b0
    # W0[l] is the scale-corrected weight matrix of layer l at t=0; the
    # input layer is additionally divided by sqrt(d + 1).
    with torch.no_grad():
        W0 = {
            1:
            layer_scales[0] * ipllr_0.input_layer.weight.data.detach() /
            math.sqrt(ipllr_0.d + 1)
        }
        for i, l in enumerate(range(2, L + 1)):
            layer = getattr(ipllr_0.intermediate_layers,
                            intermediate_layer_keys[i])
            W0[l] = layer_scales[l - 1] * layer.weight.data.detach()

        W0[L + 1] = layer_scales[L] * ipllr_0.output_layer.weight.data.detach()

    with torch.no_grad():
        b0 = layer_scales[0] * ipllr_0.input_layer.bias.data.detach(
        ) / math.sqrt(ipllr_0.d + 1)

    # Define Delta_W_1 and Delta_b_1
    # Delta_W_1[l] is the scale-corrected weight update of layer l after
    # the first training step (t=1 minus t=0).
    with torch.no_grad():
        Delta_W_1 = {
            1:
            layer_scales[0] * (ipllr_1.input_layer.weight.data.detach() -
                               ipllr_0.input_layer.weight.data.detach()) /
            math.sqrt(ipllr_1.d + 1)
        }
        for i, l in enumerate(range(2, L + 1)):
            layer_1 = getattr(ipllr_1.intermediate_layers,
                              intermediate_layer_keys[i])
            layer_0 = getattr(ipllr_0.intermediate_layers,
                              intermediate_layer_keys[i])
            Delta_W_1[l] = layer_scales[l - 1] * (
                layer_1.weight.data.detach() - layer_0.weight.data.detach())

        Delta_W_1[
            L +
            1] = layer_scales[L] * (ipllr_1.output_layer.weight.data.detach() -
                                    ipllr_0.output_layer.weight.data.detach())

    with torch.no_grad():
        Delta_b_1 = layer_scales[0] * (
            ipllr_1.input_layer.bias.data.detach() -
            ipllr_0.input_layer.bias.data.detach()) / math.sqrt(ipllr_1.d + 1)

    # Define Delta_W_2
    # Same as above for the second training step (t=2 minus t=1).
    with torch.no_grad():
        Delta_W_2 = {
            1:
            layer_scales[0] * (ipllr_2.input_layer.weight.data.detach() -
                               ipllr_1.input_layer.weight.data.detach()) /
            math.sqrt(ipllr_2.d + 1)
        }
        for i, l in enumerate(range(2, L + 1)):
            layer_2 = getattr(ipllr_2.intermediate_layers,
                              intermediate_layer_keys[i])
            layer_1 = getattr(ipllr_1.intermediate_layers,
                              intermediate_layer_keys[i])
            Delta_W_2[l] = layer_scales[l - 1] * (
                layer_2.weight.data.detach() - layer_1.weight.data.detach())

        Delta_W_2[
            L +
            1] = layer_scales[L] * (ipllr_2.output_layer.weight.data.detach() -
                                    ipllr_1.output_layer.weight.data.detach())

    # NOTE(review): this uses ipllr_1.d where Delta_W_2 above uses
    # ipllr_2.d — harmless if d is identical across snapshots (they are
    # deepcopies of the same model), but inconsistent; confirm intent.
    with torch.no_grad():
        Delta_b_2 = layer_scales[0] * (
            ipllr_2.input_layer.bias.data.detach() -
            ipllr_1.input_layer.bias.data.detach()) / math.sqrt(ipllr_1.d + 1)

    # Ranks
    # Convert the first-layer update to an exact sympy matrix; rref() can
    # be very slow for a width x (d+1) dense matrix, hence the timing.
    print('computing sympy Matrix ...')
    M = sympy.Matrix(Delta_W_1[1].numpy().tolist())

    print('Computing row echelon form ...')
    start = time()
    row_echelon = M.rref()  # returns (rref matrix, tuple of pivot column indices)
    end = time()

    print('Time for computing row echelon form : {:.3f} minutes'.format(
        (end - start) / 60))

    print(row_echelon)
    print(row_echelon[1])
    # Number of pivot columns == rank of Delta_W_1[1].
    print(len(row_echelon[1]))