Code Example #1
    def __init__(self, dataset_str, prof_model_str, model_str,
                 instance_input_shape, prof_input_shape,
                 NN_DENSITY_LEVEL_LIST, OPTIMIZER_STR, NUM_RUNS, NUM_EPOCHS,
                 BATCH_SIZE, STEP_SIZE, MASK_UPDATE_FREQ, LAMBDA_KER_DIST,
                 LAMBDA_L2_REG, MASK_UPDATE_BOOL=True,
                 VALIDATION_FRACTION=0.002, PRUNE_METHOD='magnitude',
                 GLOBAL_PRUNE_BOOL=False, INIT_RUN_INDEX=1, SAVE_BOOL=True,
                 save_dir='/content/neural-tangent-transfer/saved/ntt_results/'):
        """ 
        Args: 
            # Model options    
            dataset_str: a string that describes the dataset
            model_str: a string that describes the model
            instance_input_shape: the shape of a single input data sample
            NN_DENSITY_LEVEL_LIST: a list of desired weight density levels
            
            
            # Optimization options
            OPTIMIZER_STR: a string of optimizer
            NUM_RUNS: number of independent runs
            NUM_EPOCHS: number of epochs
            BATCH_SIZE: size of a minibatch
            STEP_SIZE: the learning rate
            MASK_UPDATE_FREQ: number of mini-batch iterations, after which mask will be updated according to the magnitudes of weights. 
            LAMBDA_KER_DIST: scaling parameter for kernel distance
            LAMBDA_L2_REG: scaling parameter for l2 penalty
            VALIDATION_FRACTION: the fraction of training data held-out as validation during NTT
            PRUNE_METHOD: the pruning method
            GLOBAL_PRUNE_BOOL: whether to use global (net-wise) pruning or not
            
            # Data & model saving options
            INIT_RUN_INDEX: the index of the initial run
            SAVE_BOOL: a boolean variable which decides whether the learned sparse parameters are saved.
            save_dir: the data saving directory.
        """
        
        self.model_str = model_str
        self.prof_model_str = prof_model_str
        self.NN_DENSITY_LEVEL_LIST = NN_DENSITY_LEVEL_LIST
        self.DATASET = Dataset(datasource=dataset_str, VALIDATION_FRACTION=VALIDATION_FRACTION)
        self.NUM_RUNS = NUM_RUNS
        self.BATCH_SIZE = BATCH_SIZE
        self.NUM_EPOCHS = NUM_EPOCHS
        self.OPTIMIZER_WITH_PARAMS = optimizer_dict[OPTIMIZER_STR](step_size = STEP_SIZE)
        self.LAMBDA_KER_DIST = LAMBDA_KER_DIST
        self.LAMBDA_L2_REG = LAMBDA_L2_REG
        self.SAVE_BOOL = SAVE_BOOL   
        self.INIT_RUN_INDEX = INIT_RUN_INDEX
        self.GLOBAL_PRUNE_BOOL = GLOBAL_PRUNE_BOOL
        self.PRUNE_METHOD = PRUNE_METHOD
        
        if not MASK_UPDATE_BOOL:
            logger.info("Mask update during NTT has been turned off")
            self.MASK_UPDATE_FREQ = np.inf  # never update the masks
        else:
            self.MASK_UPDATE_FREQ = MASK_UPDATE_FREQ  # update the masks every MASK_UPDATE_FREQ iterations
        
        if GLOBAL_PRUNE_BOOL:
            global_layerwise_str = 'global_prune'
        else:
            global_layerwise_str = 'layerwise_prune'
                                            
        self.unique_model_dir = save_dir + dataset_str + '_' + global_layerwise_str + '_' + self.model_str
        
        self.param_dict = dict(model_str = model_str,
                               prof_model_str = prof_model_str,
                               dataset_str = dataset_str, 
                               instance_input_shape = instance_input_shape,
                               prof_input_shape = prof_input_shape,
                               NN_DENSITY_LEVEL_LIST = NN_DENSITY_LEVEL_LIST, 
                               OPTIMIZER_STR = OPTIMIZER_STR,
                               NUM_RUNS = NUM_RUNS, 
                               NUM_EPOCHS = NUM_EPOCHS, 
                               BATCH_SIZE = BATCH_SIZE,  
                               STEP_SIZE = STEP_SIZE, 
                               MASK_UPDATE_FREQ = self.MASK_UPDATE_FREQ, 
                               LAMBDA_KER_DIST =  LAMBDA_KER_DIST, 
                               LAMBDA_L2_REG = LAMBDA_L2_REG,
                               SAVE_BOOL = SAVE_BOOL,
                               VALIDATION_FRACTION = VALIDATION_FRACTION,
                               GLOBAL_PRUNE_BOOL = GLOBAL_PRUNE_BOOL,
                               PRUNE_METHOD = PRUNE_METHOD)  
        
        # unpack the neural net architecture
        init_fun, apply_fn = model_dict[model_str](W_initializers_str = 'glorot_normal()', b_initializers_str = 'normal()')
        
        self.init_fun = init_fun
        self.apply_fn = apply_fn
        self.emp_ntk_fn = empirical_ntk_fn(apply_fn)

        # unpack the professor neural net architecture
        prof_init_fun, prof_apply_fn = model_dict[prof_model_str](W_initializers_str='glorot_normal()', b_initializers_str='normal()')

        self.prof_init_fun = prof_init_fun
        self.prof_apply_fn = prof_apply_fn
        self.prof_emp_ntk_fn = empirical_ntk_fn(prof_apply_fn)
        
        # prepend -1 as the batch dimension so whole minibatches can be reshaped
        self.batch_input_shape = [-1] + instance_input_shape
        self.prof_batch_input_shape = [-1] + prof_input_shape
    

        # use the first BATCH_SIZE validation samples, reshaped for the student and professor nets
        self.vali_samples = self.DATASET.dataset['val']['input'][:self.BATCH_SIZE, :].reshape(self.batch_input_shape)
        self.prof_vali_samples = self.DATASET.dataset['val']['input'][:self.BATCH_SIZE, :].reshape(self.prof_batch_input_shape)

        # split validation inputs into two collections, vali_inputs_1 and vali_inputs_2.
        half_vali_size = len(self.vali_samples) // 2
        self.vali_inputs_1 = self.vali_samples[:half_vali_size]
        self.vali_inputs_2 = self.vali_samples[half_vali_size:]

        # likewise split the professor net's validation inputs (which may have a different shape) into two halves
        prof_half_vali_size = len(self.prof_vali_samples) // 2
        self.prof_vali_inputs_1 = self.prof_vali_samples[:prof_half_vali_size]
        self.prof_vali_inputs_2 = self.prof_vali_samples[prof_half_vali_size:]
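
For orientation, here is a minimal usage sketch of the constructor above. The enclosing class name (`NeuralTangentTransfer`), the `'mnist'`/`'lenet'`/`'adam'` registry keys, the input shapes, and the hyperparameter values are not shown in the snippet, so treat all of them as illustrative placeholders for whatever your `model_dict`/`optimizer_dict` actually contain.

# Hypothetical usage sketch; names and values are assumptions, not the repo's API.
ntt = NeuralTangentTransfer(
    dataset_str='mnist',
    prof_model_str='lenet',
    model_str='lenet',
    instance_input_shape=[28, 28, 1],
    prof_input_shape=[28, 28, 1],
    NN_DENSITY_LEVEL_LIST=[0.1, 0.05],
    OPTIMIZER_STR='adam',
    NUM_RUNS=3,
    NUM_EPOCHS=10,
    BATCH_SIZE=128,
    STEP_SIZE=1e-3,
    MASK_UPDATE_FREQ=100,
    LAMBDA_KER_DIST=1.0,
    LAMBDA_L2_REG=1e-5)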
Code Example #2
    def supervised_optimization(self,
                                sup_density_list,
                                wiring_str,
                                save_supervised_result_bool,
                                dataset_str,
                                EXPLOITATION_NUM_EPOCHS,
                                EXPLOITATION_BATCH_SIZE,
                                OPTIMIZER_STR,
                                STEP_SIZE,
                                REG,
                                W_initializers_str='glorot_normal()',
                                b_initializers_str='normal()',
                                init_weight_rescale_bool=False,
                                EXPLOITATION_VALIDATION_FRACTION=0.1,
                                EXPLOIT_TRAIN_DATASET_FRACTION=1.0,
                                RECORD_ACC_FREQ=100,
                                DROPOUT_LAYER_POS=[],
                                **kwargs):
        """ 
        Train a neural network with loaded wiring from scratch.

        Args: 
            sup_density_list: a list of network density levels
            wiring_str: a string that represents the network wiring, e.g., trans, rand, snip
            dataset_str: a string used to retreive the dataset
            EXPLOITATION_NUM_EPOCHS: the number of epochs used in supervsied training
            EXPLOITATION_BATCH_SIZE: the batch size used in supervsied training
            OPTIMIZER_STR: a string used to retreive the optimzier
            STEP_SIZE: step size of the optimizer
            REG: l2 regularization constant
            EXPLOITATION_VALIDATION_FRACTION: the fraction of training data held out for validation purpose
            EXPLOIT_TRAIN_DATASET_FRACTION: the fraction of training data used in evaluation. 
            RECORD_ACC_FREQ: the frequency for recording train and test results

        Returns:
            train_acc_list_runs: a list of training accuracy
            test_acc_list_runs: a list of testing accuracy
        """

        for density in sup_density_list:
            if density not in self.ntt_setup_dict['NN_DENSITY_LEVEL_LIST']:
                raise ValueError(
                    'The desired density level for supervised training is not used in NTT.'
                )

        dataset_info = Dataset(
            datasource=dataset_str,
            VALIDATION_FRACTION=EXPLOITATION_VALIDATION_FRACTION)

        dataset = dataset_info.dataset

        # configure the dataset
        gen_batches = dataset_info.data_stream(EXPLOITATION_BATCH_SIZE)

        batch_input_shape = [-1] + self.ntt_setup_dict['instance_input_shape']

        nr_training_samples = len(dataset['train']['input'])

        nr_training_samples_subset = int(nr_training_samples *
                                         EXPLOIT_TRAIN_DATASET_FRACTION)

        train_input = dataset['train'][
            'input'][:nr_training_samples_subset].reshape(batch_input_shape)
        train_label = dataset['train']['label'][:nr_training_samples_subset]

        test_input = dataset['test']['input'].reshape(batch_input_shape)
        test_label = dataset['test']['label']

        # batches per epoch: all complete batches plus one partial batch if samples remain
        num_complete_batches, leftover = divmod(nr_training_samples,
                                                EXPLOITATION_BATCH_SIZE)

        num_mini_batches_per_epoch = num_complete_batches + bool(leftover)

        total_batch = EXPLOITATION_NUM_EPOCHS * num_mini_batches_per_epoch

        if len(DROPOUT_LAYER_POS) == 0:
            # in this case, dropout is NOT used
            init_fun_no_dropout, f_train = model_dict[self.model_str](
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)
            f_test = f_train
            f_no_dropout = f_train
            key_dropout = None
            subkey_dropout = None

        else:
            # in this case, dropout is used
            _, f_train = model_dict[self.model_str + '_dropout'](
                mode='train',
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)
            _, f_test = model_dict[self.model_str + '_dropout'](
                mode='test',
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)

            init_fun_no_dropout, f_no_dropout = model_dict[self.model_str](
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)

            key_dropout = random.PRNGKey(0)
            subkey_dropout = None  # replaced by a fresh subkey before every training step

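        # One jitted training step: gradient of the masked, l2-regularized
        # softmax cross-entropy loss, followed by an optimizer update.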
        @jit
        def step(i, opt_state, x, y, masks, key):
            this_step_params = get_params(opt_state)
            masked_g = grad(softmax_cross_entropy_with_logits_l2_reg)(
                this_step_params,
                f_train,
                x,
                y,
                masks,
                L2_REG_COEFF=REG,
                key=key)
            return opt_update(i, masked_g, opt_state)

        train_results_dict = {}
        test_results_dict = {}
        trained_masked_dict = {}

        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)

        # stagger parallel runs so the timestamped directory names don't collide
        time.sleep(orig_random.uniform(1, 5))
        now_str = '__' + str(datetime.now().strftime("%D:%H:%M:%S")).replace(
            '/', ':')

        supervised_model_info = '[u]' + self.ntt_file_name + '_[s]' + dataset_str

        supervised_model_wiring_info = supervised_model_info + '_' + wiring_str

        supervised_model_wiring_dir = self.supervised_result_path + supervised_model_info + '/' + supervised_model_wiring_info + now_str

        if save_supervised_result_bool:

            # append '_0' until the directory name is unique
            while os.path.exists(supervised_model_wiring_dir):
                supervised_model_wiring_dir += '_0'
            os.makedirs(supervised_model_wiring_dir)

            logging.basicConfig(filename=supervised_model_wiring_dir +
                                "/supervised_learning_log.log",
                                format='%(asctime)s %(message)s',
                                filemode='w',
                                level=logging.DEBUG)
        else:
            logging.basicConfig(filename="supervised_learning_log.log",
                                format='%(asctime)s %(message)s',
                                filemode='w',
                                level=logging.DEBUG)

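        # outer loop: one sweep per target density level; inner loop: independent runs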
        for nn_density_level in sup_density_list:

            nn_density_level = onp.round(nn_density_level, 2)
            train_acc_list_runs = []
            test_acc_list_runs = []
            trained_masked_params_runs = []

            for run_index in range(1, self.ntt_setup_dict['NUM_RUNS'] + 1):

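                # select the sparse wiring: NTT-transferred, random, dense, SNIP, or logit-SNIP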
                if wiring_str == 'trans':
                    # load ntt masks and parameters
                    density_run_dir = '/' + 'density_' + str(
                        nn_density_level) + '/' + 'run_' + str(run_index)

                    transferred_masks_fileName = '/transferred_masks_' + self.model_str + density_run_dir.replace(
                        '/', '_') + '.npy'

                    transferred_param_fileName = '/transferred_params_' + self.model_str + density_run_dir.replace(
                        '/', '_') + '.npy'

                    masks = list(
                        np.load(self.ntt_result_path + density_run_dir +
                                transferred_masks_fileName,
                                allow_pickle=True))

                    masked_params = list(
                        np.load(self.ntt_result_path + density_run_dir +
                                transferred_param_fileName,
                                allow_pickle=True))

                elif wiring_str == 'rand':
                    # randomly initialize masks and parameters

                    _, params = init_fun_no_dropout(random.PRNGKey(run_index),
                                                    tuple(batch_input_shape))

                    masks = get_masks_from_jax_params(
                        params,
                        nn_density_level,
                        global_bool=self.ntt_setup_dict['GLOBAL_PRUNE_BOOL'],
                        magnitude_base_bool=False,
                        reshuffle_seed=run_index)

                    masked_params = get_sparse_params_filtered_by_masks(
                        params, masks)

                elif wiring_str == 'dense':
                    # randomly initialize masks and parameters

                    _, params = init_fun_no_dropout(random.PRNGKey(run_index),
                                                    tuple(batch_input_shape))

                    logger.info("Dense net!!")

                    masks = None
                    masked_params = params

                elif wiring_str == 'snip':
                    # randomly initialize masks and parameters
                    if dataset_str == 'cifar-10':
                        num_examples_snip = 128
                    else:
                        num_examples_snip = 100

                    snip_input = dataset['train']['input'][:num_examples_snip]

                    snip_label = dataset['train']['label'][:num_examples_snip]

                    snip_batch = (snip_input, snip_label)

                    _, params = init_fun_no_dropout(random.PRNGKey(run_index),
                                                    tuple(batch_input_shape))

                    if not self.ntt_setup_dict['GLOBAL_PRUNE_BOOL']:
                        logger.info("Use layerwise snip")

                    masks = get_snip_masks(
                        params, nn_density_level, f_no_dropout, snip_batch,
                        batch_input_shape,
                        self.ntt_setup_dict['GLOBAL_PRUNE_BOOL'])

                    masked_params = get_sparse_params_filtered_by_masks(
                        params, masks)

                elif wiring_str == 'logit_snip':
                    # randomly initialize masks and parameters
                    if dataset_str == 'cifar-10':
                        num_examples_snip = 128
                    else:
                        num_examples_snip = 100

                    snip_input = dataset['train']['input'][:num_examples_snip]

                    _, params = init_fun_no_dropout(random.PRNGKey(run_index),
                                                    tuple(batch_input_shape))

                    masks = get_logit_snip_masks(
                        params, nn_density_level, f_no_dropout, snip_input,
                        batch_input_shape,
                        self.ntt_setup_dict['GLOBAL_PRUNE_BOOL'])

                    masked_params = get_sparse_params_filtered_by_masks(
                        params, masks)

                else:
                    raise ValueError('The wiring string is undefined.')

                # optionally, insert placeholder entries for dropout layers
                # (the `> 100` guard keeps this branch disabled: "test without dropout masks")
                if len(DROPOUT_LAYER_POS) > 100:
                    dropout_masked_params = [
                        ()
                    ] * (len(masked_params) + len(DROPOUT_LAYER_POS))

                    dropout_masks = [[]] * (len(masked_params) +
                                            len(DROPOUT_LAYER_POS))

                    print(len(masked_params))  # check dropout positions

                    num_inserted = 0
                    for i in range(len(dropout_masked_params)):
                        if i in DROPOUT_LAYER_POS:
                            num_inserted += 1
                        else:
                            dropout_masked_params[i] = masked_params[
                                i - num_inserted]
                            dropout_masks[i] = masks[i - num_inserted]

                    masks = dropout_masks
                    masked_params = dropout_masked_params

                if init_weight_rescale_bool:
                    logger.info(
                        "Init weight rescaled: W_scaled = W/sqrt(nn_density_level)"
                    )
                    scaled_params = []

                    for layer_params in masked_params:
                        if len(layer_params) == 2:
                            # (W, b) layer: rescale the weights, keep the biases
                            W, b = layer_params
                            scaled_params.append((W * np.sqrt(1 / nn_density_level), b))
                        else:
                            # parameterless layer (e.g., an activation): keep as-is
                            scaled_params.append(layer_params)

                    masked_params = scaled_params

                optimizer_with_params = optimizer_dict[OPTIMIZER_STR](
                    step_size=STEP_SIZE)

                opt_init, opt_update, get_params = optimizer_with_params

                opt_state = opt_init(masked_params)

                train_acc_list = []

                test_acc_list = []

                for iteration in range(total_batch):

                    batch_xs, batch_ys = next(gen_batches)

                    batch_xs = batch_xs.reshape(batch_input_shape)

                    if key_dropout is not None:
                        key_dropout, subkey_dropout = random.split(key_dropout)

                    opt_state = step(iteration,
                                     opt_state,
                                     batch_xs,
                                     batch_ys,
                                     masks=masks,
                                     key=subkey_dropout)

                    if iteration % RECORD_ACC_FREQ == 0:

                        masked_trans_params = get_params(opt_state)

                        train_acc = accuracy(masked_trans_params, f_test,
                                             train_input, train_label,
                                             key_dropout)
                        test_acc = accuracy(masked_trans_params, f_test,
                                            test_input, test_label,
                                            key_dropout)

                        train_acc_list.append(train_acc)
                        test_acc_list.append(test_acc)

                        logger.info(
                            "NN density %.2f | Run %03d/%03d | Iteration %03d/%03d | Train acc %.2f%% | Test acc %.2f%%",
                            nn_density_level, run_index,
                            self.ntt_setup_dict['NUM_RUNS'], iteration + 1,
                            total_batch, train_acc * 100, test_acc * 100)

                trained_masked_trans_params = get_params(opt_state)

                train_acc_list_runs.append(train_acc_list)
                test_acc_list_runs.append(test_acc_list)
                trained_masked_params_runs.append(trained_masked_trans_params)

            train_acc_list_runs = np.array(train_acc_list_runs)
            test_acc_list_runs = np.array(test_acc_list_runs)

            train_results_dict[str(nn_density_level)] = train_acc_list_runs
            test_results_dict[str(nn_density_level)] = test_acc_list_runs
            trained_masked_dict[str(
                nn_density_level)] = trained_masked_params_runs

            if save_supervised_result_bool:

                supervised_model_wiring_dir_run = supervised_model_wiring_dir + '/density_' + str(
                    round(nn_density_level, 2)) + '/'

                # append '_0' until the directory name is unique
                while os.path.exists(supervised_model_wiring_dir_run):
                    supervised_model_wiring_dir_run += '_0'

                os.makedirs(supervised_model_wiring_dir_run)

                model_summary_str = '[u]' + self.ntt_file_name + '_[s]' + dataset_str + '_density_' + str(
                    round(nn_density_level, 2))

                np.save(
                    supervised_model_wiring_dir_run + '/' +
                    'supervised_trained_' + model_summary_str, [
                        nn_density_level, train_acc_list_runs,
                        test_acc_list_runs, trained_masked_params_runs
                    ])

        output = dict(train_results=train_results_dict,
                      test_results=test_results_dict,
                      trained_params=trained_masked_dict)

        return output
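
A minimal call sketch for the method above, again with hypothetical names: `ntt` is the instance from the sketch after Code Example #1, and the `'mnist'`/`'adam'` strings must match the keys in your codebase's registries. Note that every density in `sup_density_list` must also appear in the `NN_DENSITY_LEVEL_LIST` used during NTT, or the method raises a ValueError.

# Hypothetical usage sketch; adapt names and values to your setup.
results = ntt.supervised_optimization(
    sup_density_list=[0.1],             # must be a subset of NN_DENSITY_LEVEL_LIST
    wiring_str='trans',                 # reuse the NTT-transferred masks and parameters
    save_supervised_result_bool=False,
    dataset_str='mnist',
    EXPLOITATION_NUM_EPOCHS=20,
    EXPLOITATION_BATCH_SIZE=128,
    OPTIMIZER_STR='adam',
    STEP_SIZE=1e-3,
    REG=1e-5)

# per-run test accuracies for density 0.1, recorded every RECORD_ACC_FREQ iterations
test_acc_per_run = results['test_results']['0.1']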