예제 #1
0
    def test_first_weight_update_ipllr(self):
        """Compare the first weight updates of IPLLR, muP and NTK on random data.

        For every depth ``L`` and every base learning rate, IPLLR / muP / NTK
        fully-connected models are built from a copy of ``self.base_model_config``
        (only IPLLR keeps the scheduler), IPLLR and muP take their initial
        parameters from the NTK model, all input-layer biases are zeroed and
        ``self._test_first_weight_updates`` is run on every size-1 batch.
        The value returned for the first batch is passed back as
        ``loss_derivative_ratio`` for every later batch.
        """
        n_warmup_steps = int(10e5)
        n_samples = 200
        batch_size = 1
        output_size = 1

        inputs = torch.rand(size=(n_samples, self.input_size))
        targets = torch.rand(size=(n_samples, output_size))
        # consecutive (x, y) mini-batches; the last one may be shorter
        batches = [(inputs[start:start + batch_size, :], targets[start:start + batch_size, :])
                   for start in range(0, n_samples, batch_size)]

        width = 1024
        activation = 'relu'
        bias = False

        for L in [2, 3, 4]:
            for base_lr in [1., 0.1, 0.01]:
                model_config = deepcopy(self.base_model_config)
                model_config.scheduler.params['n_warmup_steps'] = n_warmup_steps
                architecture_overrides = (('n_layers', L + 1),
                                          ('output_size', output_size),
                                          ('width', width),
                                          ('bias', bias))
                for arch_key, arch_value in architecture_overrides:
                    model_config.architecture[arch_key] = arch_value
                model_config.activation.name = activation
                model_config.loss.name = 'mse'
                model_config.optimizer.params['lr'] = base_lr

                # IPLLR is built with the scheduler config; muP and NTK without it
                ipllr = FcIPLLR(model_config)
                model_config.scheduler = None
                muP = FCmuP(model_config)
                ntk = FCNTK(model_config)

                # IPLLR and muP start from the NTK model's initial parameters
                for model in (ipllr, muP):
                    model.copy_initial_params_from_model(ntk, check_model=True)
                    model.initialize_params()

                self.assertEqual(ipllr.base_lr, base_lr)
                self.assertEqual(ipllr.scheduler.base_lr, base_lr)
                for model in (muP, ntk):
                    self.assertEqual(model.base_lr, base_lr)

                # zero every input-layer bias
                with torch.no_grad():
                    for model in (ipllr, muP, ntk):
                        model.input_layer.bias.data.fill_(0.)

                ipllr_init, muP_init, ntk_init = deepcopy(ipllr), deepcopy(muP), deepcopy(ntk)

                for model in (ipllr, muP, ntk):
                    model.train()

                ratio = 0  # the first batch is always checked with a ratio of 0
                for batch_idx, batch in enumerate(batches):
                    returned_ratio = self._test_first_weight_updates(batch_idx, batch, ipllr, muP, ntk, ipllr_init,
                                                                     muP_init, ntk_init, width, L, base_lr,
                                                                     loss_derivative_ratio=ratio)
                    if batch_idx == 0:
                        # later batches reuse the ratio measured on the first one
                        ratio = returned_ratio
예제 #2
0
    def test_second_forward_scales(self):
        """Print per-layer forward contributions of IPLLR and muP after one training step.

        For every depth ``L`` and base learning rate, builds IPLLR / muP / NTK
        fully-connected models from a copy of ``self.base_model_config`` (only
        IPLLR keeps the scheduler), copies the NTK initial parameters into IPLLR
        and muP, zeroes all input-layer biases, runs
        ``self._train_models_one_step`` on the first batch and prints the mean
        ``init`` / ``update`` / ``total`` contributions per layer as computed by
        ``self._compute_contributions``.
        """
        n_warmup_steps = int(10e5)
        n_samples = 200
        batch_size = 1
        n_batches = math.ceil(n_samples / batch_size)
        output_size = 1

        xs = torch.rand(size=(n_samples, self.input_size))
        ys = torch.rand(size=(n_samples, output_size))
        batches = [(xs[i * batch_size: (i + 1) * batch_size, :], ys[i * batch_size: (i + 1) * batch_size, :])
                   for i in range(n_batches)]

        Ls = [3, 4, 5]
        base_lrs = [0.01, 0.1, 1.0, 10.0]
        width = 4000
        activation = 'relu'
        bias = False
        numeric_columns = ['init', 'update', 'total', 'id']

        for L in Ls:
            for base_lr in base_lrs:
                model_config = deepcopy(self.base_model_config)
                model_config.scheduler.params['n_warmup_steps'] = n_warmup_steps
                model_config.architecture['n_layers'] = L + 1
                model_config.architecture['output_size'] = output_size
                model_config.architecture['width'] = width
                model_config.architecture['bias'] = bias
                model_config.activation.name = activation
                model_config.loss.name = 'mse'
                model_config.optimizer.params['lr'] = base_lr

                # only IPLLR is built with the scheduler config
                ipllr = FcIPLLR(model_config)
                model_config.scheduler = None
                muP = FCmuP(model_config)
                ntk = FCNTK(model_config)

                ipllr.copy_initial_params_from_model(ntk, check_model=True)
                ipllr.initialize_params()

                muP.copy_initial_params_from_model(ntk, check_model=True)
                muP.initialize_params()

                self.assertEqual(ipllr.base_lr, base_lr)
                self.assertEqual(ipllr.scheduler.base_lr, base_lr)
                self.assertEqual(muP.base_lr, base_lr)
                self.assertEqual(ntk.base_lr, base_lr)

                # set all input biases to 0
                with torch.no_grad():
                    ipllr.input_layer.bias.data.fill_(0.)
                    muP.input_layer.bias.data.fill_(0.)
                    ntk.input_layer.bias.data.fill_(0.)

                ipllr_init = deepcopy(ipllr)
                muP_init = deepcopy(muP)

                ipllr.train()
                muP.train()
                ntk.train()

                # first train the models for one step
                x, y = batches[0]
                self._train_models_one_step(ipllr, muP, x, y)

                ipllr.eval()
                muP.eval()
                ntk.eval()

                contributions_df = pd.DataFrame(columns=['model', 'layer', 'init', 'update', 'total', 'id'])
                # Rebind through ``astype`` rather than assigning via ``.loc``: since
                # pandas 2.0 ``df.loc[:, cols] = ...`` writes values in place and keeps
                # the existing (object) dtype, so the float cast would not take effect.
                contributions_df = contributions_df.astype({col: float for col in numeric_columns})

                idx = self._compute_contributions(contributions_df, 'ipllr', ipllr, ipllr_init, batches, idx=0)
                _ = self._compute_contributions(contributions_df, 'muP', muP, muP_init, batches, idx)

                ipllr_contributions = contributions_df.loc[contributions_df.model == 'ipllr', :]
                muP_contributions = contributions_df.loc[contributions_df.model == 'muP', :]

                print('---- For L = {:,} and base_lr = {} ----'.format(L, base_lr))

                print('ipllr contributions per layer : ')
                print(ipllr_contributions.groupby(by='layer')[['init', 'update', 'total']].mean())
                print('')

                print('muP contributions per layer : ')
                print(muP_contributions.groupby(by='layer')[['init', 'update', 'total']].mean())
                print('\n\n')
예제 #3
0
    def test_ipllr_vs_muP(self):
        """Print average first-step chi values of IPLLR and muP over random batches.

        For every depth ``L`` and base learning rate, IPLLR / muP / NTK models are
        built from a copy of ``self.base_model_config`` (only IPLLR keeps the
        scheduler), IPLLR and muP take the NTK initial parameters and all
        input-layer biases are zeroed. ``self.test_first_forward_backward_ip_muP``
        is then called on every size-1 batch. Its first two return values
        (chi for IPLLR and muP) and the mean of its third (``x_L_ntk``) are
        collected, and their averages are printed.
        """
        n_warmup_steps = int(10e5)
        n_samples = 200
        batch_size = 1
        output_size = 1

        xs = torch.rand(size=(n_samples, self.input_size))
        targets = torch.rand(size=(n_samples, output_size))
        # consecutive (x, y) mini-batches; the last one may be shorter
        batches = [(xs[start:start + batch_size, :], targets[start:start + batch_size, :])
                   for start in range(0, n_samples, batch_size)]

        width = 1024
        activation = 'relu'
        bias = False

        for L in [2, 3, 4]:
            for base_lr in [1., 0.1, 0.01]:
                model_config = deepcopy(self.base_model_config)
                model_config.scheduler.params['n_warmup_steps'] = n_warmup_steps
                architecture_overrides = (('n_layers', L + 1),
                                          ('output_size', output_size),
                                          ('width', width),
                                          ('bias', bias))
                for arch_key, arch_value in architecture_overrides:
                    model_config.architecture[arch_key] = arch_value
                model_config.activation.name = activation
                model_config.loss.name = 'mse'
                model_config.optimizer.params['lr'] = base_lr

                # IPLLR is built with the scheduler config; muP and NTK without it
                ipllr = FcIPLLR(model_config)
                model_config.scheduler = None
                muP = FCmuP(model_config)
                ntk = FCNTK(model_config)

                for model in (ipllr, muP):
                    model.copy_initial_params_from_model(ntk, check_model=True)
                    model.initialize_params()

                self.assertEqual(ipllr.base_lr, base_lr)
                self.assertEqual(ipllr.scheduler.base_lr, base_lr)
                for model in (muP, ntk):
                    self.assertEqual(model.base_lr, base_lr)

                # zero every input-layer bias
                with torch.no_grad():
                    for model in (ipllr, muP, ntk):
                        model.input_layer.bias.data.fill_(0.)

                for model in (ipllr, muP, ntk):
                    model.train()

                chis_ipllr, chis_muP, xs_L_ntk = [], [], []
                for batch in batches:
                    ipllr_chi, muP_chi, ntk_last_layer = \
                        self.test_first_forward_backward_ip_muP(batch, ipllr, muP, ntk, width, L)
                    chis_ipllr.append(ipllr_chi)
                    chis_muP.append(muP_chi)
                    # reduce immediately so later calls cannot affect the stored value
                    xs_L_ntk.append(ntk_last_layer.mean().item())

                with torch.no_grad():
                    print('input means :', xs.mean().item())
                    print('mean x_L_ntk :', np.mean(xs_L_ntk))
                    print('mean chi_muP :', np.mean(chis_muP))
                    print('mean chi_ipllr :', np.mean(chis_ipllr))
                    print('')
예제 #4
0
def main(activation="relu", base_lr=0.01, batch_size=512, dataset="mnist"):
    """Compute, save and plot rank statistics of muP weights and (pre-)activations.

    Builds ``N_TRIALS`` ``FCmuP`` models from ``CONFIG_PATH/fc_ipllr_<dataset>.yaml``.
    It multiplies the learning rate of each model's first optimizer parameter
    group by ``muP.d + 1``, then trains every model for two steps: on the first
    and on the second training batch, with ``normalize_first=True``.

    From the snapshots at t=0, 1 and 2 it extracts:
      * the initial weights and the two weight updates;
      * the first-step (pre-)activation contributions on the full training set.

    It then computes rank estimates with the ``get_svd_ranks_*`` helpers
    (default tolerance and ``tol=1e-7``) and the ``get_square_trace_ranks_*``
    helpers. Averaged results are written as CSV files and PNG figures under
    ``FIGURES_DIR/<dataset>``, together with a CSV of the diversity of the
    maximum pre-activation index.

    Depth ``L`` and ``width`` are not parameters. Like ``SEED``, ``N_TRIALS``,
    ``CONFIG_PATH`` and ``FIGURES_DIR``, they are read from names defined
    outside this function.

    Args:
        activation: activation name written into the model config.
        base_lr: base learning rate written into the optimizer config.
        batch_size: training mini-batch size.
        dataset: 'mnist' or 'cifar10'; 'cifar100' is not implemented yet.

    Any exception raised after the logger is set up is logged with its
    traceback and not re-raised.
    """
    config_path = os.path.join(CONFIG_PATH, 'fc_ipllr_{}.yaml'.format(dataset))
    figures_dir = os.path.join(FIGURES_DIR, dataset)
    create_dir(figures_dir)
    log_path = os.path.join(figures_dir, 'log_muP_{}.txt'.format(activation))
    logger = set_up_logger(log_path)

    logger.info('Parameters of the run:')
    logger.info('activation = {}'.format(activation))
    logger.info('base_lr = {}'.format(base_lr))
    logger.info('batch_size = {:,}'.format(batch_size))
    logger.info('Random SEED : {:,}'.format(SEED))
    logger.info(
        'Number of random trials for each model : {:,}'.format(N_TRIALS))

    try:
        set_random_seeds(SEED)  # set random seed for reproducibility
        config_dict = read_yaml(config_path)

        version = 'L={}_m={}_act={}_lr={}_bs={}'.format(
            L, width, activation, base_lr, batch_size)
        template_name = 'muP_{}_ranks_{}_' + version

        config_dict['architecture']['width'] = width
        config_dict['architecture']['n_layers'] = L + 1
        config_dict['optimizer']['params']['lr'] = base_lr
        config_dict['activation']['name'] = activation

        base_model_config = ModelConfig(config_dict)

        # Load data & define models
        logger.info('Loading data ...')
        if dataset == 'mnist':
            from utils.dataset.mnist import load_data
        elif dataset == 'cifar10':
            from utils.dataset.cifar10 import load_data
        elif dataset == 'cifar100':
            # TODO : add cifar100 to utils.dataset
            # Fail explicitly instead of falling through with `load_data`
            # unbound (which surfaced as a confusing UnboundLocalError).
            error = NotImplementedError(
                "dataset 'cifar100' is not supported yet: no loader in utils.dataset")
            logger.error(error)
            raise error
        else:
            error = ValueError(
                "dataset must be one of ['mnist', 'cifar10', 'cifar100'] but was {}"
                .format(dataset))
            logger.error(error)
            raise error

        training_dataset, _ = load_data(download=False, flatten=True)
        train_data_loader = DataLoader(training_dataset,
                                       shuffle=True,
                                       batch_size=batch_size)
        batches = list(train_data_loader)

        # whole training set stacked into one tensor, used for the pre-activations
        full_x = torch.cat([a for a, _ in batches], dim=0)

        logger.info('Defining models')
        base_model_config.scheduler = None
        muPs = [FCmuP(base_model_config) for _ in range(N_TRIALS)]

        # scale the learning rate of the first parameter group only
        for muP in muPs:
            for param_group in muP.optimizer.param_groups[:1]:
                param_group['lr'] = param_group['lr'] * (muP.d + 1)

        # save initial models
        muPs_0 = [deepcopy(muP) for muP in muPs]

        # train model one step
        logger.info('Training model a first step (t=1)')
        x, y = batches[0]
        muPs_1 = []
        for muP in muPs:
            train_model_one_step(muP, x, y, normalize_first=True)
            muPs_1.append(deepcopy(muP))

        # train models for a second step
        logger.info('Training model a second step (t=2)')
        x, y = batches[1]
        muPs_2 = []
        for muP in muPs:
            train_model_one_step(muP, x, y, normalize_first=True)
            muPs_2.append(deepcopy(muP))

        # set eval mode for all models
        for trial_models in zip(muPs, muPs_0, muPs_1, muPs_2):
            for model in trial_models:
                model.eval()

        logger.info('Storing initial and update matrices')
        # define W0 and b0
        W0s, b0s = [], []
        for muP_0 in muPs_0:
            W0, b0 = get_W0_dict(muP_0, normalize_first=True)
            W0s.append(W0)
            b0s.append(b0)

        # define Delta_W_1 and Delta_b_1
        Delta_W_1s, Delta_b_1s = [], []
        for muP_0, muP_1 in zip(muPs_0, muPs_1):
            Delta_W_1, Delta_b_1 = get_Delta_W1_dict(muP_0, muP_1, normalize_first=True)
            Delta_W_1s.append(Delta_W_1)
            Delta_b_1s.append(Delta_b_1)

        # define Delta_W_2 and Delta_b_2
        Delta_W_2s, Delta_b_2s = [], []
        for muP_1, muP_2 in zip(muPs_1, muPs_2):
            Delta_W_2, Delta_b_2 = get_Delta_W2_dict(muP_1, muP_2, normalize_first=True)
            Delta_W_2s.append(Delta_W_2)
            Delta_b_2s.append(Delta_b_2)

        # contributions after first step, computed on the full training set
        h0s, delta_h_1s, h1s, x1s = [], [], [], []
        for i in range(N_TRIALS):
            h0, delta_h_1, h1, x1 = get_contributions_1(full_x,
                                                        muPs_1[i],
                                                        W0s[i],
                                                        b0s[i],
                                                        Delta_W_1s[i],
                                                        Delta_b_1s[i],
                                                        normalize_first=True)
            h0s.append(h0)
            delta_h_1s.append(delta_h_1)
            # was `h1s.append(h0)`, which made every h1 statistic a copy of h0's
            h1s.append(h1)
            x1s.append(x1)

        # explicit SVD tolerance for the 'svd_tol' variant; also shown on the plots
        tol = 1e-7

        # ranks of initial weight matrices and first two updates
        logger.info('Computing ranks of weight matrices ...')
        weight_mats = list(zip(W0s, Delta_W_1s, Delta_W_2s))
        weight_ranks_dfs_dict = {
            'svd_default': [get_svd_ranks_weights(W0, dW1, dW2, L, tol=None)
                            for W0, dW1, dW2 in weight_mats],
            'svd_tol': [get_svd_ranks_weights(W0, dW1, dW2, L, tol=tol)
                        for W0, dW1, dW2 in weight_mats],
            'squared_tr': [get_square_trace_ranks_weights(W0, dW1, dW2, L)
                           for W0, dW1, dW2 in weight_mats],
        }

        weight_ranks_df_dict = {
            key: get_concatenated_ranks_df(dfs)
            for key, dfs in weight_ranks_dfs_dict.items()
        }
        avg_weight_ranks_df_dict = {
            key: get_avg_ranks_dfs(df)
            for key, df in weight_ranks_df_dict.items()
        }

        logger.info('Saving weights ranks data frames to csv ...')
        for key, avg_df in avg_weight_ranks_df_dict.items():
            logger.info(key)
            logger.info('\n' + str(avg_df) + '\n\n')
            avg_df.to_csv(os.path.join(figures_dir, template_name.format(key, 'weights') + '.csv'),
                          header=True,
                          index=True)

        ranks_dfs = [
            weight_ranks_df_dict['svd_default'],
            weight_ranks_df_dict['svd_tol'],
            weight_ranks_df_dict['squared_tr'],
        ]

        # plot weights ranks
        logger.info('Plotting weights ranks')
        for name in ('W0', 'Delta_W_1', 'Delta_W_2'):
            plt.figure(figsize=(12, 6))
            plot_weights_ranks_vs_layer(name, ranks_dfs, tol, L, width, base_lr, batch_size, y_scale='log')
            plt.savefig(os.path.join(figures_dir, template_name.format(name, 'weights') + '.png'))

        # ranks of the pre-activations
        logger.info('Computing ranks of (pre-)activations ...')
        acts = list(zip(h0s, delta_h_1s, h1s, x1s))
        act_ranks_dfs_dict = {
            'svd_default': [get_svd_ranks_acts(h0, dh1, h1, x1, L, tol=None)
                            for h0, dh1, h1, x1 in acts],
            'svd_tol': [get_svd_ranks_acts(h0, dh1, h1, x1, L, tol=tol)
                        for h0, dh1, h1, x1 in acts],
            'squared_tr': [get_square_trace_ranks_acts(h0, dh1, h1, x1, L)
                           for h0, dh1, h1, x1 in acts],
        }

        act_ranks_df_dict = {
            key: get_concatenated_ranks_df(dfs)
            for key, dfs in act_ranks_dfs_dict.items()
        }
        avg_act_ranks_df_dict = {
            key: get_avg_ranks_dfs(df)
            for key, df in act_ranks_df_dict.items()
        }

        logger.info('Saving (pre-)activation ranks data frames to csv ...')
        for key, avg_df in avg_act_ranks_df_dict.items():
            logger.info(key)
            logger.info('\n' + str(avg_df) + '\n\n')
            avg_df.to_csv(os.path.join(figures_dir, template_name.format(key, 'acts') + '.csv'),
                          header=True,
                          index=True)

        ranks_dfs = [
            act_ranks_df_dict['svd_default'],
            act_ranks_df_dict['svd_tol'],
            act_ranks_df_dict['squared_tr'],
        ]

        logger.info('Plotting (pre-)activation ranks')
        for name in ('h0', 'h1', 'x1', 'delta_h_1'):
            plt.figure(figsize=(12, 6))
            plot_acts_ranks_vs_layer(name, ranks_dfs, tol, L, width, base_lr, batch_size, y_scale='log')
            plt.savefig(os.path.join(figures_dir, template_name.format(name, 'acts') + '.png'))

        # diversity in terms of the index of the maximum entry
        logger.info(
            'Computing diversity of the maximum entry of pre-activations...')
        max_acts_diversity_dfs = [
            get_max_acts_diversity(h0, dh1, h1, L)
            for h0, dh1, h1 in zip(h0s, delta_h_1s, h1s)
        ]
        max_acts_diversity_df = get_concatenated_ranks_df(max_acts_diversity_dfs)
        avg_max_acts_diversity_df = get_avg_ranks_dfs(max_acts_diversity_df)
        logger.info('Diversity of the maximum activation index df:')
        logger.info(str(avg_max_acts_diversity_df))
        avg_max_acts_diversity_df.to_csv(os.path.join(figures_dir, 'muP_max_acts_' + version + '.csv'),
                                         header=True,
                                         index=True)

    except Exception as e:
        # top-level boundary of the script: log with traceback, do not re-raise
        logger.exception("Exception when running the script : {}".format(e))
예제 #5
0
def main(activation="relu",
         n_steps=300,
         base_lr=0.01,
         batch_size=512,
         dataset="mnist"):
    """Train plain and renormalized muP models and plot their training curves.

    Builds three families of ``N_TRIALS`` ``FCmuP`` models from
    ``CONFIG_PATH/fc_ipllr_<dataset>.yaml``:
      * ``muP`` — trained with ``normalize_first=False``;
      * ``muP_renorm`` — same initial parameters, ``normalize_first=True``;
      * ``muP_renorm_scale_lr`` — same again, with the first optimizer parameter
        group's learning rate multiplied by ``muP.d + 1``.

    Each model is run through ``collect_training_losses`` for ``n_steps``. The
    first and second elements of every result are plotted as 'loss' and 'chi'
    figures under ``FIGURES_DIR/<dataset>``.

    Depth ``L`` and ``width`` are not parameters. Like ``SEED``, ``N_TRIALS``,
    ``CONFIG_PATH`` and ``FIGURES_DIR``, they are read from names defined
    outside this function.

    Args:
        activation: activation name written into the model config.
        n_steps: number of training steps passed to ``collect_training_losses``.
        base_lr: base learning rate written into the optimizer config.
        batch_size: training mini-batch size.
        dataset: 'mnist' or 'cifar10'; 'cifar100' is not implemented yet.

    Any exception raised after the logger is set up is logged with its
    traceback and not re-raised.
    """
    config_path = os.path.join(CONFIG_PATH, 'fc_ipllr_{}.yaml'.format(dataset))
    figures_dir = os.path.join(FIGURES_DIR, dataset)
    create_dir(figures_dir)
    log_path = os.path.join(figures_dir, 'log_muP_{}.txt'.format(activation))
    logger = set_up_logger(log_path)

    logger.info('Parameters of the run:')
    logger.info('activation = {}'.format(activation))
    logger.info('n_steps = {:,}'.format(n_steps))
    logger.info('base_lr = {}'.format(base_lr))
    logger.info('batch_size = {:,}'.format(batch_size))
    logger.info('Random SEED : {:,}'.format(SEED))
    logger.info(
        'Number of random trials for each model : {:,}'.format(N_TRIALS))

    try:
        set_random_seeds(SEED)  # set random seed for reproducibility
        config_dict = read_yaml(config_path)

        fig_name_template = 'muP_{}_{}_L={}_m={}_act={}_lr={}_bs={}.png'

        config_dict['architecture']['width'] = width
        config_dict['architecture']['n_layers'] = L + 1
        config_dict['optimizer']['params']['lr'] = base_lr
        config_dict['activation']['name'] = activation

        base_model_config = ModelConfig(config_dict)

        # Load data & define models
        logger.info('Loading data ...')
        if dataset == 'mnist':
            from utils.dataset.mnist import load_data
        elif dataset == 'cifar10':
            from utils.dataset.cifar10 import load_data
        elif dataset == 'cifar100':
            # TODO : add cifar100 to utils.dataset (and set
            # config_dict['architecture']['output_size'] = 100 *before*
            # ModelConfig is built above). Until then, fail explicitly instead
            # of falling through with `load_data` unbound.
            error = NotImplementedError(
                "dataset 'cifar100' is not supported yet: no loader in utils.dataset")
            logger.error(error)
            raise error
        else:
            error = ValueError(
                "dataset must be one of ['mnist', 'cifar10', 'cifar100'] but was {}"
                .format(dataset))
            logger.error(error)
            raise error

        training_dataset, _ = load_data(download=False, flatten=True)
        train_data_loader = DataLoader(training_dataset,
                                       shuffle=True,
                                       batch_size=batch_size)
        batches = list(train_data_loader)

        logger.info('Defining models')
        base_model_config.scheduler = None
        muPs = [FCmuP(base_model_config) for _ in range(N_TRIALS)]
        muPs_renorm = [FCmuP(base_model_config) for _ in range(N_TRIALS)]
        muPs_renorm_scale_lr = [
            FCmuP(base_model_config) for _ in range(N_TRIALS)
        ]

        # scale the learning rate of the first parameter group only
        for muP in muPs_renorm_scale_lr:
            for param_group in muP.optimizer.param_groups[:1]:
                param_group['lr'] = param_group['lr'] * (muP.d + 1)

        logger.info('Copying parameters of base muP')
        for base, renorm, renorm_scale_lr in zip(muPs, muPs_renorm, muPs_renorm_scale_lr):
            renorm.copy_initial_params_from_model(base)
            renorm_scale_lr.copy_initial_params_from_model(base)

            renorm.initialize_params()
            renorm_scale_lr.initialize_params()

        results = dict()
        logger.info('Generating training results ...')
        results['muP'] = [
            collect_training_losses(model, batches, n_steps, normalize_first=False)
            for model in muPs
        ]
        results['muP_renorm'] = [
            collect_training_losses(model, batches, n_steps, normalize_first=True)
            for model in muPs_renorm
        ]
        results['muP_renorm_scale_lr'] = [
            collect_training_losses(model, batches, n_steps, normalize_first=True)
            for model in muPs_renorm_scale_lr
        ]

        mode = 'training'
        # first element of each result is plotted as 'loss', second as 'chi'
        losses = {key: [r[0] for r in res] for key, res in results.items()}
        chis = {key: [r[1] for r in res] for key, res in results.items()}

        # Plot losses and derivatives
        logger.info('Saving figures at {}'.format(figures_dir))
        key = 'loss'
        plt.figure(figsize=(12, 8))
        # NOTE(review): `renorm_first` is not defined in this function; presumably
        # a module-level name — confirm, since the plotted models use both
        # normalize_first=False and normalize_first=True.
        plot_losses_models(losses,
                           key=key,
                           L=L,
                           width=width,
                           activation=activation,
                           lr=base_lr,
                           batch_size=batch_size,
                           mode=mode,
                           normalize_first=renorm_first,
                           marker=None,
                           name='muP')
        plt.ylim(0, 2.5)
        plt.savefig(
            os.path.join(
                figures_dir,
                fig_name_template.format(mode, key, L, width, activation,
                                         base_lr, batch_size)))

        key = 'chi'
        plt.figure(figsize=(12, 8))
        plot_losses_models(chis,
                           key=key,
                           L=L,
                           width=width,
                           activation=activation,
                           lr=base_lr,
                           batch_size=batch_size,
                           mode=mode,
                           marker=None,
                           name='muP')
        plt.savefig(
            os.path.join(
                figures_dir,
                fig_name_template.format(mode, key, L, width, activation,
                                         base_lr, batch_size)))

    except Exception as e:
        # top-level boundary of the script: log with traceback, do not re-raise
        logger.exception("Exception when running the script : {}".format(e))
    def test_scales_with_previous_multiple_steps_muP_without_renorm(self):
        """Print per-layer contributions of muP vs. calibrated IPLLR over many steps.

        For every depth ``L`` in ``Ls``, width in ``widths`` and base learning rate
        in ``base_lrs``:

        * a ``FCmuP`` model and a ``FcIPLLR`` model (scheduler config
          ``calibrate_base_lr=True``, ``default_calibration=False``, calibrated on
          the training batches) are started from the initial parameters of a
          shared reference ``FCmuP`` model;
        * both are trained jointly for ``n_steps`` steps with
          ``renorm_first=False``, cycling through the training batches;
        * after each step, the per-layer contributions relative to the initial
          model and to the model from the previous step are computed on the
          *next* batch, and their per-layer means are printed.

        This test makes no assertions: it only prints diagnostics.
        """
        n_steps = 150
        widths = [1024]
        Ls = [6]
        base_lrs = [0.1, 0.01]
        batch_size = 512
        config = deepcopy(self.base_model_config)
        # columns of the contributions DataFrame whose per-layer means are printed
        contrib_columns = ['init', 'previous_h', 'previous_Delta_h', 'delta_h', 'Delta_h', 'total']

        batches = list(
            DataLoader(self.training_dataset,
                       shuffle=True,
                       batch_size=batch_size))
        print('len(batches) :', len(batches))

        for L in Ls:
            # L + 1 layers in total (presumably L hidden layers plus the output layer)
            config.architecture['n_layers'] = L + 1
            for width in widths:
                config.architecture['width'] = width

                # reference model whose initial parameters are shared by every run below
                config.scheduler = None
                base_muP = FCmuP(config)

                for base_lr in base_lrs:
                    config.optimizer.params['lr'] = base_lr

                    config.scheduler = None
                    muP = FCmuP(config)

                    scheduler_config = {
                        'calibrate_base_lr': True,
                        'default_calibration': False
                    }
                    config.scheduler = BaseConfig(scheduler_config)
                    ipllr_calib = FcIPLLR(config,
                                          n_warmup_steps=12,
                                          lr_calibration_batches=batches)

                    # start both models from the same initialisation
                    muP.copy_initial_params_from_model(base_muP)
                    muP.initialize_params()

                    ipllr_calib.copy_initial_params_from_model(muP)
                    ipllr_calib.initialize_params()

                    # snapshots of the initial models, kept as the reference for contributions
                    muP_init = deepcopy(muP)
                    ipllr_calib_init = deepcopy(ipllr_calib)

                    for step in range(n_steps):
                        print('##### step {} ####'.format(step))
                        # cycle through the batches so n_steps may exceed len(batches)
                        x, y = batches[step % len(batches)]

                        # snapshot the models before this step's update
                        muP_previous = deepcopy(muP)
                        ipllr_calib_previous = deepcopy(ipllr_calib)

                        # train both models for one step on the current batch
                        self._train_models_one_step(muP,
                                                    ipllr_calib,
                                                    x,
                                                    y,
                                                    renorm_first=False)

                        # contributions are evaluated on the batch used at the following step
                        next_batch = batches[(step + 1) % len(batches)]

                        print('---- For L = {:,} and base_lr = {} ----'.format(
                            L, base_lr))
                        muP_contribs = \
                            self._compute_contributions_with_previous('muP', muP, muP_init, muP_previous, [next_batch],
                                                                      renorm_first=False)
                        print('muP contributions per layer : ')
                        print(muP_contribs.groupby(by='layer')[contrib_columns].mean())
                        print('')

                        ipllr_calib_contribs = \
                            self._compute_contributions_with_previous('ipllr_calib', ipllr_calib, ipllr_calib_init,
                                                                      ipllr_calib_previous, [next_batch],
                                                                      renorm_first=False)

                        print('calibrated ipllr contributions per layer : ')
                        print(ipllr_calib_contribs.groupby(by='layer')[contrib_columns].mean())
                        print('\n\n')
    def test_scales_multiple_steps_muP(self):
        """Print per-layer contributions of muP vs. calibrated IPLLR for a few steps.

        For every depth ``L`` in ``Ls``, width in ``widths`` and base learning rate
        in ``base_lrs``:

        * a ``FCmuP`` model and a ``FcIPLLR`` model (scheduler config
          ``calibrate_base_lr=True``, ``default_calibration=False``, calibrated on
          the training batches) both take their initial parameters from a shared
          reference ``FCmuP`` model;
        * both are trained jointly for ``n_steps`` steps;
        * after each step, the per-layer contributions relative to the initial
          model are computed on a fixed set of evaluation batches (the last
          ``n_batches`` batches of ``self.train_data_loader``), and their
          per-layer means are printed.

        This test makes no assertions: it only prints diagnostics.
        """
        n_steps = 10
        widths = [1024]
        Ls = [6]
        n_batches = 10
        base_lrs = [0.01, 0.1]
        config = deepcopy(self.base_model_config)
        # columns of the contributions DataFrame whose per-layer means are printed
        contrib_columns = ['init', 'update', 'total']

        batches = list(self.train_data_loader)
        print('len(batches) :', len(batches))
        # Fixed evaluation set, identical for every step and run. The original
        # recomputed this slice on every step.
        # NOTE(review): these batches overlap the training batches (batches[:n_steps])
        # when the loader yields fewer than n_steps + n_batches batches.
        reduced_batches = batches[-n_batches:]

        for L in Ls:
            # L + 1 layers in total (presumably L hidden layers plus the output layer)
            config.architecture['n_layers'] = L + 1
            for width in widths:
                config.architecture['width'] = width

                # reference model whose initial parameters are shared by every run below
                config.scheduler = None
                base_muP = FCmuP(config)

                for base_lr in base_lrs:
                    config.optimizer.params['lr'] = base_lr

                    config.scheduler = None
                    muP = FCmuP(config)

                    scheduler_config = {
                        'calibrate_base_lr': True,
                        'default_calibration': False
                    }
                    config.scheduler = BaseConfig(scheduler_config)
                    ipllr_calib = FcIPLLR(config,
                                          n_warmup_steps=12,
                                          lr_calibration_batches=batches)

                    # both models take their initial parameters from base_muP
                    muP.copy_initial_params_from_model(base_muP)
                    muP.initialize_params()

                    ipllr_calib.copy_initial_params_from_model(base_muP)
                    ipllr_calib.initialize_params()

                    # snapshots of the initial models, kept as the reference for contributions
                    muP_init = deepcopy(muP)
                    ipllr_calib_init = deepcopy(ipllr_calib)

                    for step in range(n_steps):
                        print('##### step {} ####'.format(step))
                        # train both models for one step. The modulo wraps the
                        # index so n_steps may exceed len(batches); the original
                        # raised IndexError in that case.
                        x, y = batches[step % len(batches)]
                        self._train_models_one_step(muP, ipllr_calib, x, y)

                        muP_contribs = self._compute_contributions(
                            'muP', muP, muP_init, reduced_batches)
                        ipllr_calib_contribs = self._compute_contributions(
                            'ipllr_calib', ipllr_calib, ipllr_calib_init,
                            reduced_batches)
                        print('---- For L = {:,} and base_lr = {} ----'.format(
                            L, base_lr))

                        print('muP contributions per layer : ')
                        print(muP_contribs.groupby(by='layer')[contrib_columns].mean())
                        print('')

                        print('calibrated ipllr contributions per layer : ')
                        print(ipllr_calib_contribs.groupby(by='layer')[contrib_columns].mean())
                        print('\n\n')