class nt_transfer_model():
    """Neural Tangent Transfer (NTT) trainer.

    Learns the parameters of a sparse "student" network so that its empirical
    NTK matrix and its predictions match those of a dense "professor" network,
    while an l2 penalty keeps parameter magnitudes controlled. Binary pruning
    masks are (optionally) refreshed during training from weight magnitudes.
    """

    def __init__(self, dataset_str, prof_model_str, model_str, instance_input_shape, prof_input_shape,
                 NN_DENSITY_LEVEL_LIST, OPTIMIZER_STR, NUM_RUNS, NUM_EPOCHS, BATCH_SIZE, STEP_SIZE,
                 MASK_UPDATE_FREQ, LAMBDA_KER_DIST, LAMBDA_L2_REG, MASK_UPDATE_BOOL = True,
                 VALIDATION_FRACTION = 0.002, PRUNE_METHOD = 'magnitude', GLOBAL_PRUNE_BOOL = False,
                 INIT_RUN_INDEX = 1, SAVE_BOOL = True,
                 save_dir = '/content/neural-tangent-transfer/saved/ntt_results/'):
        """
        Args:
            # Model options
            dataset_str: a string that describes the dataset.
            prof_model_str: a string that describes the professor (dense reference) model.
            model_str: a string that describes the student model.
            instance_input_shape: the shape of a single input data sample for the student net.
            prof_input_shape: the shape of a single input data sample for the professor net.
            NN_DENSITY_LEVEL_LIST: a list of desired weight density levels.

            # Optimization options
            OPTIMIZER_STR: a string naming the optimizer (key into ``optimizer_dict``).
            NUM_RUNS: number of independent runs.
            NUM_EPOCHS: number of epochs.
            BATCH_SIZE: size of a minibatch.
            STEP_SIZE: the learning rate.
            MASK_UPDATE_FREQ: number of mini-batch iterations, after which the mask will be
                updated according to the magnitudes of weights.
            LAMBDA_KER_DIST: scaling parameter for the kernel distance term.
            LAMBDA_L2_REG: scaling parameter for the l2 penalty term.
            MASK_UPDATE_BOOL: if False, masks are frozen for the whole run.
            VALIDATION_FRACTION: the fraction of training data held out as validation during NTT.
            PRUNE_METHOD: the pruning method ('magnitude' or 'logit_snip').
            GLOBAL_PRUNE_BOOL: whether to use global (net-wise) pruning or layerwise pruning.

            # Data & model saving options
            INIT_RUN_INDEX: the index of the initial run (also seeds the PRNG per run).
            SAVE_BOOL: whether the learned sparse parameters are saved to disk.
            save_dir: the data saving directory.
        """
        self.model_str = model_str
        self.prof_model_str = prof_model_str
        self.NN_DENSITY_LEVEL_LIST = NN_DENSITY_LEVEL_LIST
        # Dataset wrapper providing train/val/test splits and a batch generator.
        self.DATASET = Dataset(datasource = dataset_str, VALIDATION_FRACTION = VALIDATION_FRACTION )
        self.NUM_RUNS = NUM_RUNS
        self.BATCH_SIZE = BATCH_SIZE
        self.NUM_EPOCHS = NUM_EPOCHS
        # An (opt_init, opt_update, get_params) triple in the jax.example_libraries.optimizers style.
        self.OPTIMIZER_WITH_PARAMS = optimizer_dict[OPTIMIZER_STR](step_size = STEP_SIZE)
        self.LAMBDA_KER_DIST = LAMBDA_KER_DIST
        self.LAMBDA_L2_REG = LAMBDA_L2_REG
        self.SAVE_BOOL = SAVE_BOOL
        self.INIT_RUN_INDEX = INIT_RUN_INDEX
        self.GLOBAL_PRUNE_BOOL = GLOBAL_PRUNE_BOOL
        self.PRUNE_METHOD = PRUNE_METHOD

        if MASK_UPDATE_BOOL == False:
            logger.info("Mask update during NTT has been turned off")
            # np.inf makes both `num_iter % freq == 0` and `num_iter < total - freq`
            # unsatisfiable in optimize(), so masks are never refreshed.
            self.MASK_UPDATE_FREQ = np.inf  # not updating masks
        else:
            self.MASK_UPDATE_FREQ = MASK_UPDATE_FREQ  # updating masks every MASK_UPDATE_FREQ iterations

        # time.sleep(orig_random.uniform(1,20))
        # now = datetime.now()
        # now_str = str(now.strftime("%D:%H:%M:%S")).replace('/', ':')

        if GLOBAL_PRUNE_BOOL:
            global_layerwise_str = 'global_prune'
        else:
            global_layerwise_str = 'layerwise_prune'

        # self.unique_model_dir = save_dir + dataset_str + '_' + global_layerwise_str + '_' + self.model_str + '__' + now_str
        # NOTE(review): without the timestamp suffix, a second run with identical settings
        # will hit os.makedirs() on an existing directory in optimize() — confirm intended.
        self.unique_model_dir = save_dir + dataset_str + '_' + global_layerwise_str + '_' + self.model_str

        # Full experiment configuration, persisted alongside results for reproducibility.
        self.param_dict = dict(model_str = model_str, prof_model_str = prof_model_str,
                               dataset_str = dataset_str, instance_input_shape = instance_input_shape,
                               prof_input_shape = prof_input_shape,
                               NN_DENSITY_LEVEL_LIST = NN_DENSITY_LEVEL_LIST,
                               OPTIMIZER_STR = OPTIMIZER_STR, NUM_RUNS = NUM_RUNS,
                               NUM_EPOCHS = NUM_EPOCHS, BATCH_SIZE = BATCH_SIZE,
                               STEP_SIZE = STEP_SIZE, MASK_UPDATE_FREQ = self.MASK_UPDATE_FREQ,
                               LAMBDA_KER_DIST = LAMBDA_KER_DIST, LAMBDA_L2_REG = LAMBDA_L2_REG,
                               SAVE_BOOL = SAVE_BOOL, VALIDATION_FRACTION = VALIDATION_FRACTION,
                               GLOBAL_PRUNE_BOOL = GLOBAL_PRUNE_BOOL, PRUNE_METHOD = PRUNE_METHOD)

        # unpack the student neural net architecture: (init_fun, apply_fn) pair in stax style.
        init_fun, apply_fn = model_dict[model_str](W_initializers_str = 'glorot_normal()',
                                                   b_initializers_str = 'normal()')
        self.init_fun = init_fun
        self.apply_fn = apply_fn
        # Empirical (parameter-dependent) NTK function of the student net.
        self.emp_ntk_fn = empirical_ntk_fn(apply_fn)

        # unpack the professor neural net architecture.
        prof_init_fun, prof_apply_fn = model_dict[prof_model_str](W_initializers_str='glorot_normal()',
                                                                  b_initializers_str='normal()')
        self.prof_init_fun = prof_init_fun
        self.prof_apply_fn = prof_apply_fn
        self.prof_emp_ntk_fn = empirical_ntk_fn(prof_apply_fn)

        # Leading -1 lets reshape infer the batch dimension.
        self.batch_input_shape = [-1] + instance_input_shape
        self.prof_batch_input_shape = [-1] + prof_input_shape

        # A fixed batch of validation samples, reshaped for each network.
        # NOTE(review): both views are built from the same flat validation rows —
        # assumes instance_input_shape and prof_input_shape have equal element counts; confirm.
        self.vali_samples = self.DATASET.dataset['val']['input'][:self.BATCH_SIZE, :].reshape(self.batch_input_shape)
        self.prof_vali_samples = self.DATASET.dataset['val']['input'][:self.BATCH_SIZE, :].reshape(self.prof_batch_input_shape)

        # split validation inputs into two collections, vali_inputs_1 and vali_inputs_2,
        # used as the two arguments of the empirical NTK.
        half_vali_size = int(len(self.vali_samples)/2)
        self.vali_inputs_1 = self.vali_samples[:half_vali_size]
        self.vali_inputs_2 = self.vali_samples[half_vali_size:]

        # same split for the professor net, which may take a different input shape.
        prof_half_vali_size = int(len(self.prof_vali_samples) / 2)
        self.prof_vali_inputs_1 = self.prof_vali_samples[:prof_half_vali_size]
        self.prof_vali_inputs_2 = self.prof_vali_samples[prof_half_vali_size:]

    def kernel_dist_target_dist_l2_loss(self, student_ker_mat, student_pred, prof_ker_mat, prof_pred, masked_params):
        """Compute kernel distance, target distance, and parameter l2 loss.

        Args:
            student_ker_mat: a student-network NTK matrix.
            student_pred: a student-network prediction matrix.
            prof_ker_mat: a professor-network NTK matrix.
            prof_pred: a professor-network prediction matrix.
            masked_params: the student parameters after mask filtering.

        Returns:
            ker_dist: squared l2 difference between the two kernel matrices,
                normalized by the size of the matrix.
            target_dist: squared l2 difference between the two prediction matrices,
                normalized by the size of the matrix.
            params_norm_squared: squared l2 norm of the masked parameters.
        """
        # the normalized squared difference between professor and student NTK matrices
        ker_dist = np.sum(np.square(student_ker_mat - prof_ker_mat)) / prof_ker_mat.size
        #ker_dist = np.sum(np.square(student_ker_mat - teacher_ker_mat)) / teacher_ker_mat.size #t-s learning

        # the normalized squared difference between professor and student network predictions
        target_dist = np.sum(np.square(student_pred - prof_pred)) / student_pred.size
        #target_dist = np.sum(np.square(student_pred - teacher_pred)) / student_pred.size #t-s learning

        # squared norm of parameters (regularization term)
        params_norm_squared = stax_params_l2_square(masked_params)

        return ker_dist, target_dist, params_norm_squared

    def eval_nt_transfer_loss_on_vali_data(self, masked_student_net_params, vali_prof_prediction, vali_prof_ntk_mat, density_level, key):
        """Evaluate the NTT loss on the held-out validation batch.

        Args:
            masked_student_net_params: the masked student network's parameters.
            vali_prof_prediction: the professor network's prediction on validation inputs.
            vali_prof_ntk_mat: the professor network's NTK matrix on validation inputs.
            density_level: the current weight density level (scales the l2 term).
            key: integer seed for the PRNG.

        Returns:
            transfer_loss: the total (weighted) transfer loss.
            ker_dist: the kernel distance.
            target_dist: the target distance.
            param_squared_norm: the squared l2 norm of the parameters.
        """
        # evaluate the student NTK matrix using validation data.
        vali_student_ntk_mat = self.emp_ntk_fn(self.vali_inputs_1, self.vali_inputs_2,
                                               masked_student_net_params, keys=random.PRNGKey(key))

        # evaluate the student prediction using validation data.
        vali_student_prediction = self.apply_fn(masked_student_net_params, self.vali_samples,
                                                rng=random.PRNGKey(key))

        # calculate the kernel distance, target distance, and parameter l2 loss
        ker_dist, target_dist, param_squared_norm = self.kernel_dist_target_dist_l2_loss(
            vali_student_ntk_mat, vali_student_prediction,
            vali_prof_ntk_mat, vali_prof_prediction, masked_student_net_params)

        # weight these distances and sum them up. The l2 weight is divided by the
        # density level so that sparser nets are not penalized less in absolute terms.
        weighted_ker_dist = self.LAMBDA_KER_DIST * ker_dist
        weighted_parameters_squared_norm = (self.LAMBDA_L2_REG / density_level) * param_squared_norm
        transfer_loss = weighted_ker_dist + target_dist + weighted_parameters_squared_norm

        return transfer_loss, ker_dist, target_dist, param_squared_norm

    def nt_transfer_loss(self, student_net_params, masks, teacher_net_params, x, prof_x, prof_net_params, density_level, key):
        """The training loss of NTK transfer (differentiated by ``grad`` in optimize()).

        Args:
            student_net_params: student network parameters (pre-masking).
            masks: a collection of binary masks.
            teacher_net_params: parameters of the dense teacher network
                (currently unused in the loss body; kept in the signature — see
                the commented-out t-s learning lines).
            x: the student-network minibatch inputs.
            prof_x: the professor-network minibatch inputs.
            prof_net_params: parameters of the professor network.
            density_level: the current weight density level.
            key: integer seed for the PRNG.

        Returns:
            transfer_loss: the scalar loss value.
        """
        masked_student_net_params = get_sparse_params_filtered_by_masks(student_net_params, masks)

        # split inputs into two collections, x1 and x2, for the empirical NTK.
        x1 = x[:int(len(x)/2)]
        x2 = x[int(len(x)/2):]

        # split input into two collections for the professor batch
        prof_x1 = prof_x[:int(len(prof_x) / 2)]
        prof_x2 = prof_x[int(len(prof_x) / 2):]

        # student network prediction
        student_prediction = self.apply_fn(masked_student_net_params, x, rng=random.PRNGKey(key))
        # teacher network prediction
        #teacher_prediction = self.apply_fn(teacher_net_params, x, rng=random.PRNGKey(key))
        # professor network prediction
        prof_prediction = self.prof_apply_fn(prof_net_params, prof_x, rng=random.PRNGKey(key))

        # student network's NTK evaluated on x1 and x2
        student_ntk_mat = self.emp_ntk_fn(x1, x2, masked_student_net_params, keys=random.PRNGKey(key))
        # teacher network's NTK evaluated on x1 and x2
        #teacher_ntk_mat = self.emp_ntk_fn(x1, x2, teacher_net_params, keys=random.PRNGKey(key))
        # professor network's NTK evaluated on x1 and x2
        prof_ntk_mat = self.prof_emp_ntk_fn(prof_x1, prof_x2, prof_net_params, keys=random.PRNGKey(key))

        # compute kernel, target, and parameter l2 loss
        ker_dist, target_dist, param_squared_norm = self.kernel_dist_target_dist_l2_loss(
            student_ntk_mat, student_prediction, prof_ntk_mat, prof_prediction, masked_student_net_params)

        # weight these losses to get the transfer loss (same weighting as the validation path).
        transfer_loss = self.LAMBDA_KER_DIST * ker_dist + target_dist \
            + (self.LAMBDA_L2_REG / density_level) * param_squared_norm

        return transfer_loss

    def optimize(self, return_teacher_params_bool = False):
        """Carry out the NTT optimization over all density levels and runs.

        Args:
            return_teacher_params_bool: if True, additionally return the (dense)
                teacher initializations used for each run.

        Returns:
            nt_trans_params_all_sparsities_all_runs: transferred parameters,
                nested as [density][run].
            nt_trans_masks_all_sparsities_all_runs: transferred masks, same nesting.
            nt_trans_vali_all_sparsities_all_runs: per-run arrays of validation
                losses recorded during training.
            teacher_params_all_sparsities_all_runs: (only when
                return_teacher_params_bool) the teacher parameters per run.
        """
        gen_batches = self.DATASET.data_stream(self.BATCH_SIZE)
        num_complete_batches, leftover = divmod(self.DATASET.num_example['train'], self.BATCH_SIZE)

        # number of minibatches per epoch (a partial final batch counts as one).
        num_mini_batches_per_epochs = num_complete_batches + bool(leftover)
        # number of total iterations
        num_total_iters = self.NUM_EPOCHS * num_mini_batches_per_epochs
        # number of times that the sparsity masks get updated (unused below; kept for reference).
        num_sparsity_updates = num_total_iters // self.MASK_UPDATE_FREQ
        # masks are frozen for the last MASK_UPDATE_FREQ iterations so training can settle.
        mask_update_limit = num_total_iters - self.MASK_UPDATE_FREQ

        if self.SAVE_BOOL == True:
            # save the transferred results in the designated directory.
            trans_model_dir = self.unique_model_dir
            # while os.path.exists(trans_model_dir):
            #     trans_model_dir = trans_model_dir + '_0'
            os.makedirs(trans_model_dir)
            np.save(trans_model_dir + '/param_dict.npy', self.param_dict)

        nt_trans_params_all_sparsities_all_runs = []
        nt_trans_masks_all_sparsities_all_runs = []
        nt_trans_vali_all_sparsities_all_runs = []
        teacher_params_all_sparsities_all_runs = []

        num_sparisty_levels = len(self.NN_DENSITY_LEVEL_LIST)
        num_runs = len(range(self.INIT_RUN_INDEX, self.INIT_RUN_INDEX + self.NUM_RUNS ))
        # grand total of iterations across every density level and run (for ETA estimation).
        all_density_all_run_num_total_iters = num_sparisty_levels * num_runs * num_total_iters

        for nn_density_level in self.NN_DENSITY_LEVEL_LIST:
            nt_trans_params_all_runs = []
            nt_trans_masks_all_runs = []
            nt_trans_vali_all_runs = []
            teacher_params_all_runs = []

            for run_index in range(self.INIT_RUN_INDEX, self.INIT_RUN_INDEX + self.NUM_RUNS ):
                # reset logging handlers so each run logs to its own file via basicConfig.
                for handler in logging.root.handlers[:]:
                    logging.root.removeHandler(handler)

                # a string that summarizes the current ntt experiment
                model_summary_str = self.model_str + '_density_' + str(round(nn_density_level, 2) ) + '_run_' + str(run_index)
                prof_model_summary_str = self.prof_model_str + '_density_' + str(round(nn_density_level, 2)) + '_run_' + str(run_index)

                if self.SAVE_BOOL == True:
                    model_dir_density_run = trans_model_dir + '/' + 'density_' + str(round(nn_density_level, 2) ) + '/' + 'run_' + str(run_index) + '/'
                    os.makedirs(model_dir_density_run)
                    logging.basicConfig(filename = model_dir_density_run + "/" + model_summary_str + "_log.log",
                                        format='%(asctime)s %(message)s', filemode='w', level=logging.DEBUG)
                else:
                    logging.basicConfig(filename = model_summary_str + "_log.log",
                                        format='%(asctime)s %(message)s', filemode='w', level=logging.DEBUG)

                # for different run indices, randomly draw teacher and professor net parameters
                # (run_index seeds the PRNG, so each run is reproducible).
                _, teacher_net_params = self.init_fun(random.PRNGKey(run_index), tuple(self.batch_input_shape))
                _, prof_net_params = self.prof_init_fun(random.PRNGKey(run_index), tuple(self.prof_batch_input_shape))

                # the prediction of the professor net evaluated on validation samples
                #vali_teacher_prediction = self.apply_fn(teacher_net_params, self.vali_samples, rng=random.PRNGKey(run_index))
                vali_prof_prediction = self.prof_apply_fn(prof_net_params, self.prof_vali_samples, rng=random.PRNGKey(run_index))

                # the professor NTK evaluated on the validation split (fixed target for this run).
                #vali_teacher_ntk_mat = self.emp_ntk_fn(self.vali_inputs_1, self.vali_inputs_2, teacher_net_params, keys=random.PRNGKey(run_index))
                vali_prof_ntk_mat = self.prof_emp_ntk_fn(self.prof_vali_inputs_1, self.prof_vali_inputs_2, prof_net_params, keys=random.PRNGKey(run_index))

                # the initial binary mask
                if self.PRUNE_METHOD == 'magnitude':
                    masks = get_masks_from_jax_params(teacher_net_params, nn_density_level, global_bool = self.GLOBAL_PRUNE_BOOL)
                elif self.PRUNE_METHOD == 'logit_snip':
                    logger.info("Use logit snip method to get the initial mask")
                    num_examples_snip = 128
                    # gen_batches_logit_snip = self.DATASET.data_stream(num_examples_snip)
                    snip_input = self.DATASET.dataset['train']['input'][:num_examples_snip, :]
                    if self.GLOBAL_PRUNE_BOOL == False:
                        logger.warning("layerwise sparse net initialized with logit_snip")
                    # NOTE(review): the keyword is spelled 'GlOBAL_PRUNE_BOOL' (lowercase L) —
                    # presumably matching the helper's actual parameter name; verify against its definition.
                    masks = get_logit_snip_masks(teacher_net_params, nn_density_level, self.apply_fn,
                                                 snip_input, self.batch_input_shape,
                                                 GlOBAL_PRUNE_BOOL = self.GLOBAL_PRUNE_BOOL)
                else:
                    raise NotImplementedError("not implemented")

                # the initial student parameters: teacher init filtered through the mask.
                masked_student_net_params = get_sparse_params_filtered_by_masks(teacher_net_params, masks)

                # instantiate the optimizer triple
                opt_init, opt_update, get_params = self.OPTIMIZER_WITH_PARAMS
                opt_state = opt_init(teacher_net_params) #optimize toward teacher
                #opt_state = opt_init(prof_net_params) #optimize toward professor

                # one step of NTK transfer; closes over teacher/prof params, density, and run seed.
                @jit
                def nt_transfer_step(i, opt_state, x, prof_x, masks):
                    # parameters in the current optimizer state
                    student_net_params = get_params(opt_state)
                    # gradients that flow through the binary masks
                    masked_g = grad(self.nt_transfer_loss)(student_net_params, masks, teacher_net_params,
                                                           x, prof_x, prof_net_params, nn_density_level,
                                                           key=run_index)
                    return opt_update(i, masked_g, opt_state)

                # a list of validation loss tuples recorded over training.
                vali_loss_list = []

                # calculate the initial validation loss.
                vali_loss = self.eval_nt_transfer_loss_on_vali_data(masked_student_net_params,
                                                                    vali_prof_prediction, vali_prof_ntk_mat,
                                                                    nn_density_level, key=run_index)
                vali_loss_list.append(vali_loss)

                logger.info("Before transfer: trans dist %.3f | ntk dist %.3f | targ dist %.3f | l2 pentalty %.3f | nn density %.2f",
                            vali_loss[0], vali_loss[1], vali_loss[2], vali_loss[3], nn_density_level)

                itercount = itertools.count()
                t = time.time()

                # loop through iterations
                for num_iter in range(1, num_total_iters + 1):
                    # a batch of input data
                    batch_xs, _ = next(gen_batches)

                    # reshape the input to a proper format (2d array for MLP and 3d for CNN)
                    batch_xs = batch_xs.reshape(self.batch_input_shape)
                    # NOTE(review): reshapes the already-reshaped student batch into the professor
                    # shape — assumes the two shapes have equal element counts; confirm.
                    prof_batch_xs = batch_xs.reshape(self.prof_batch_input_shape)

                    # update the optimizer state
                    opt_state = nt_transfer_step(next(itercount), opt_state, batch_xs, prof_batch_xs, masks )

                    if num_iter % 100 == 0:
                        elapsed_time = time.time() - t
                        if (num_iter <= 500) and (run_index == self.INIT_RUN_INDEX) and (nn_density_level == self.NN_DENSITY_LEVEL_LIST[0]):
                            # estimate the program end time from the first few hundred iterations.
                            # NOTE(review): expected_end_time is only (re)assigned here; later log
                            # lines reuse the value computed during the first run/density.
                            remaining_iter_num = all_density_all_run_num_total_iters - num_iter
                            remaining_seconds = elapsed_time * ( remaining_iter_num / 100 )
                            expected_end_time = str(datetime.now() + timedelta(seconds = remaining_seconds))

                        # get parameters from the current optimizer state
                        student_net_params = get_params(opt_state)
                        # filter the parameters by masks
                        masked_student_net_params = get_sparse_params_filtered_by_masks(student_net_params , masks)
                        # validation loss
                        vali_loss = self.eval_nt_transfer_loss_on_vali_data(masked_student_net_params,
                                                                            vali_prof_prediction, vali_prof_ntk_mat,
                                                                            nn_density_level, key=run_index)
                        vali_loss_list.append(vali_loss)

                        logger.info('run: %02d/%02d | iter %04d/%04d | trans. dist %.3f | ntk dist %.3f | targ. dist %.3f | l2 %.3f | nn density %.2f | time %.2f [s] | expected finish time %s',
                                    run_index, self.NUM_RUNS + self.INIT_RUN_INDEX - 1, num_iter, num_total_iters,
                                    vali_loss[0], vali_loss[1], vali_loss[2], vali_loss[3],
                                    nn_density_level, elapsed_time, expected_end_time)
                        t = time.time()

                    # periodically refresh masks from current weight magnitudes, except near the end.
                    if (num_iter % self.MASK_UPDATE_FREQ == 0) and (num_iter < mask_update_limit):
                        # get parameters from the current optimizer state
                        student_net_params = get_params(opt_state)
                        # update masks
                        masks = get_masks_from_jax_params(student_net_params, nn_density_level, global_bool = self.GLOBAL_PRUNE_BOOL)
                        # if self.PRUNE_METHOD == 'logit_snip':
                        #     logit_snip_batch_xs, _ = next(gen_batches_logit_snip)
                        #     masks = get_logit_snip_masks(student_net_params, nn_density_level, self.apply_fn, snip_input, self.batch_input_shape, GlOBAL_PRUNE_BOOL = self.GLOBAL_PRUNE_BOOL)
                        # else:
                        #     masks = get_masks_from_jax_params(student_net_params, nn_density_level, global_bool = self.GLOBAL_PRUNE_BOOL)

                # final evaluation for this run.
                elapsed_time = time.time() - t
                student_net_params = get_params(opt_state)
                # filter the parameters by masks
                masked_student_net_params = get_sparse_params_filtered_by_masks(student_net_params , masks)
                vali_loss = self.eval_nt_transfer_loss_on_vali_data(masked_student_net_params,
                                                                    vali_prof_prediction, vali_prof_ntk_mat,
                                                                    nn_density_level, key=run_index)
                vali_loss_list.append(vali_loss)

                logger.info('run: %02d/%02d | iter %04d/%04d | trans. dist %.3f | ntk dist %.3f | targ. dist %.3f | l2 %.3f | nn density %.2f | time %.2f [s]',
                            run_index, self.NUM_RUNS + self.INIT_RUN_INDEX - 1, num_iter, num_total_iters,
                            vali_loss[0], vali_loss[1], vali_loss[2], vali_loss[3],
                            nn_density_level, elapsed_time )

                vali_loss_array = np.array(vali_loss_list)

                nt_trans_params_all_runs.append(masked_student_net_params)
                nt_trans_masks_all_runs.append(masks)
                nt_trans_vali_all_runs.append(vali_loss_array)
                teacher_params_all_runs.append(teacher_net_params )

                if self.SAVE_BOOL == True:
                    model_summary_str = self.model_str + '_density_' + str(round(nn_density_level, 2) ) + '_run_' + str(run_index)
                    prof_model_summary_str = self.prof_model_str + '_density_' + str(round(nn_density_level, 2)) + '_run_' + str(run_index)

                    prof_param_fileName = model_dir_density_run + 'prof_params_' + prof_model_summary_str
                    np.save(prof_param_fileName, prof_net_params)

                    teacher_param_fileName = model_dir_density_run + 'teacher_params_' + model_summary_str
                    np.save(teacher_param_fileName, teacher_net_params)

                    student_param_fileName = model_dir_density_run + 'transferred_params_' + model_summary_str
                    np.save(student_param_fileName, masked_student_net_params)

                    mask_fileName = model_dir_density_run + 'transferred_masks_' + model_summary_str
                    np.save(mask_fileName, masks)

                    loss_array_fileName = model_dir_density_run + 'loss_array_' + model_summary_str
                    np.save(loss_array_fileName, vali_loss_array)

            nt_trans_params_all_sparsities_all_runs.append( nt_trans_params_all_runs )
            nt_trans_masks_all_sparsities_all_runs.append( nt_trans_masks_all_runs )
            nt_trans_vali_all_sparsities_all_runs.append( nt_trans_vali_all_runs )
            teacher_params_all_sparsities_all_runs.append( teacher_params_all_runs )

        if return_teacher_params_bool:
            return nt_trans_params_all_sparsities_all_runs, nt_trans_masks_all_sparsities_all_runs, \
                nt_trans_vali_all_sparsities_all_runs, teacher_params_all_sparsities_all_runs
        else:
            return nt_trans_params_all_sparsities_all_runs, nt_trans_masks_all_sparsities_all_runs, \
                nt_trans_vali_all_sparsities_all_runs
    def supervised_optimization(self, sup_density_list, wiring_str, save_supervised_result_bool,
                                dataset_str, EXPLOITATION_NUM_EPOCHS, EXPLOITATION_BATCH_SIZE,
                                OPTIMIZER_STR, STEP_SIZE, REG,
                                W_initializers_str='glorot_normal()', b_initializers_str='normal()',
                                init_weight_rescale_bool=False, EXPLOITATION_VALIDATION_FRACTION=0.1,
                                EXPLOIT_TRAIN_DATASET_FRACTION=1.0, RECORD_ACC_FREQ=100,
                                DROPOUT_LAYER_POS=[], **kwargs):
        # NOTE(review): DROPOUT_LAYER_POS=[] is a mutable default argument — safe only
        # as long as no caller mutates it; consider None-sentinel in a future change.
        """Train a neural network with loaded wiring from scratch (supervised exploitation).

        Args:
            sup_density_list: a list of network density levels (each must appear in
                the NTT setup's NN_DENSITY_LEVEL_LIST).
            wiring_str: a string that represents the network wiring, e.g.,
                'trans', 'rand', 'dense', 'snip', 'logit_snip'.
            save_supervised_result_bool: whether to save results to disk.
            dataset_str: a string used to retrieve the dataset.
            EXPLOITATION_NUM_EPOCHS: the number of epochs used in supervised training.
            EXPLOITATION_BATCH_SIZE: the batch size used in supervised training.
            OPTIMIZER_STR: a string used to retrieve the optimizer.
            STEP_SIZE: step size of the optimizer.
            REG: l2 regularization constant.
            W_initializers_str / b_initializers_str: initializer spec strings for the model.
            init_weight_rescale_bool: if True, rescale initial weights by 1/sqrt(density).
            EXPLOITATION_VALIDATION_FRACTION: the fraction of training data held out
                for validation purposes.
            EXPLOIT_TRAIN_DATASET_FRACTION: the fraction of training data used in evaluation.
            RECORD_ACC_FREQ: the frequency for recording train and test results.
            DROPOUT_LAYER_POS: positions at which dropout layers are inserted
                (empty list disables dropout).

        Returns:
            A dict with 'train_results', 'test_results' (per-density accuracy arrays
            across runs) and 'trained_params' (per-density trained parameters).
        """
        # Reject density levels that were never produced by the NTT stage.
        for density in sup_density_list:
            if density not in self.ntt_setup_dict['NN_DENSITY_LEVEL_LIST']:
                raise ValueError(
                    'The desired density level for supervised training is not used in NTT.'
                )

        dataset_info = Dataset(
            datasource=dataset_str,
            VALIDATION_FRACTION=EXPLOITATION_VALIDATION_FRACTION)
        dataset = dataset_info.dataset

        # configure the dataset
        gen_batches = dataset_info.data_stream(EXPLOITATION_BATCH_SIZE)
        batch_input_shape = [-1] + self.ntt_setup_dict['instance_input_shape']

        nr_training_samples = len(dataset['train']['input'])
        # optionally evaluate train accuracy on only a fraction of the training set.
        nr_training_samples_subset = int(nr_training_samples * EXPLOIT_TRAIN_DATASET_FRACTION)

        train_input = dataset['train']['input'][:nr_training_samples_subset].reshape(batch_input_shape)
        train_label = dataset['train']['label'][:nr_training_samples_subset]
        test_input = dataset['test']['input'].reshape(batch_input_shape)
        test_label = dataset['test']['label']

        num_complete_batches, leftover = divmod(nr_training_samples, EXPLOITATION_BATCH_SIZE)
        num_mini_batches_per_epochs = num_complete_batches + bool(leftover)
        total_batch = EXPLOITATION_NUM_EPOCHS * num_mini_batches_per_epochs

        if len(DROPOUT_LAYER_POS) == 0:
            # in this case, dropout is NOT used: train/test/no-dropout nets are all the same.
            init_fun_no_dropout, f_train = model_dict[self.model_str](
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)
            f_test = f_train
            f_no_dropout = f_train
            key_dropout = None
            subkey_dropout = None
        else:
            # in this case, dropout is used: separate train/test apply functions,
            # plus a dropout-free net used only for SNIP-style mask computation.
            _, f_train = model_dict[self.model_str + '_dropout'](
                mode='train',
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)
            _, f_test = model_dict[self.model_str + '_dropout'](
                mode='test',
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)
            init_fun_no_dropout, f_no_dropout = model_dict[self.model_str](
                W_initializers_str=W_initializers_str,
                b_initializers_str=b_initializers_str)
            key_dropout = random.PRNGKey(0)

        # One SGD step; note get_params/opt_update are late-bound — they are assigned
        # inside the run loop below, before step() is first called.
        @jit
        def step(i, opt_state, x, y, masks, key):
            this_step_params = get_params(opt_state)
            masked_g = grad(softmax_cross_entropy_with_logits_l2_reg)(
                this_step_params, f_train, x, y, masks, L2_REG_COEFF=REG, key=key)
            return opt_update(i, masked_g, opt_state)

        train_results_dict = {}
        test_results_dict = {}
        trained_masked_dict = {}

        # reset logging handlers so basicConfig below takes effect for this call.
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)

        # random sleep + timestamp reduce directory-name collisions across parallel jobs.
        time.sleep(orig_random.uniform(1, 5))
        now_str = '__' + str(datetime.now().strftime("%D:%H:%M:%S")).replace(
            '/', ':')
        supervised_model_info = '[u]' + self.ntt_file_name + '_[s]' + dataset_str
        supervised_model_wiring_info = supervised_model_info + '_' + wiring_str
        supervised_model_wiring_dir = self.supervised_result_path + supervised_model_info + '/' + supervised_model_wiring_info + now_str

        if save_supervised_result_bool:
            # append '_0' suffixes until the directory name is unique.
            while os.path.exists(supervised_model_wiring_dir):
                temp = supervised_model_wiring_dir + '_0'
                supervised_model_wiring_dir = temp
            # print(supervised_model_wiring_dir)
            os.makedirs(supervised_model_wiring_dir)
            logging.basicConfig(filename=supervised_model_wiring_dir +
                                "/supervised_learning_log.log",
                                format='%(asctime)s %(message)s',
                                filemode='w',
                                level=logging.DEBUG)
        else:
            logging.basicConfig(filename="supervised_learning_log.log",
                                format='%(asctime)s %(message)s',
                                filemode='w',
                                level=logging.DEBUG)

        for nn_density_level in sup_density_list:
            nn_density_level = onp.round(nn_density_level, 2)
            train_acc_list_runs = []
            test_acc_list_runs = []
            trained_masked_params_runs = []

            for run_index in range(1, self.ntt_setup_dict['NUM_RUNS'] + 1):
                if wiring_str == 'trans':
                    # load ntt masks and parameters saved by the NTT stage.
                    density_run_dir = '/' + 'density_' + str(
                        nn_density_level) + '/' + 'run_' + str(run_index)
                    transferred_masks_fileName = '/transferred_masks_' + self.model_str + density_run_dir.replace(
                        '/', '_') + '.npy'
                    transferred_param_fileName = '/transferred_params_' + self.model_str + density_run_dir.replace(
                        '/', '_') + '.npy'
                    masks = list(
                        np.load(self.ntt_result_path + density_run_dir +
                                transferred_masks_fileName,
                                allow_pickle=True))
                    masked_params = list(
                        np.load(self.ntt_result_path + density_run_dir +
                                transferred_param_fileName,
                                allow_pickle=True))
                elif wiring_str == 'rand':
                    # randomly initialize masks and parameters
                    _, params = init_fun_no_dropout(random.PRNGKey(run_index), tuple(batch_input_shape))
                    masks = get_masks_from_jax_params(
                        params,
                        nn_density_level,
                        global_bool=self.ntt_setup_dict['GLOBAL_PRUNE_BOOL'],
                        magnitude_base_bool=False,
                        reshuffle_seed=run_index)
                    masked_params = get_sparse_params_filtered_by_masks(
                        params, masks)
                elif wiring_str == 'dense':
                    # randomly initialize parameters; no mask is applied at all.
                    _, params = init_fun_no_dropout(random.PRNGKey(run_index),
                                                    tuple(batch_input_shape))
                    # masks = get_masks_from_jax_params(params, nn_density_level, global_bool = self.ntt_setup_dict['GLOBAL_PRUNE_BOOL'], magnitude_base_bool = False, reshuffle_seed = run_index)
                    logger.info("Dense net!!")
                    masks = None
                    masked_params = params
                elif wiring_str == 'snip':
                    # randomly initialize parameters, prune with SNIP on a small labeled batch.
                    if dataset_str == 'cifar-10':
                        num_examples_snip = 128
                    else:
                        num_examples_snip = 100
                    snip_input = dataset['train']['input'][:num_examples_snip]
                    snip_label = dataset['train']['label'][:num_examples_snip]
                    snip_batch = (snip_input, snip_label)
                    _, params = init_fun_no_dropout(random.PRNGKey(run_index),
                                                    tuple(batch_input_shape))
                    if not self.ntt_setup_dict['GLOBAL_PRUNE_BOOL']:
                        logger.info("Use layerwise snip")
                    # NOTE(review): nesting reconstructed — masks must be computed for both the
                    # global and layerwise cases, so this sits outside the if above; confirm.
                    masks = get_snip_masks(
                        params, nn_density_level, f_no_dropout, snip_batch,
                        batch_input_shape,
                        self.ntt_setup_dict['GLOBAL_PRUNE_BOOL'])
                    masked_params = get_sparse_params_filtered_by_masks(
                        params, masks)
                elif wiring_str == 'logit_snip':
                    # randomly initialize parameters, prune with logit-SNIP (unlabeled inputs).
                    if dataset_str == 'cifar-10':
                        num_examples_snip = 128
                    else:
                        num_examples_snip = 100
                    snip_input = dataset['train']['input'][:num_examples_snip]
                    _, params = init_fun_no_dropout(random.PRNGKey(run_index),
                                                    tuple(batch_input_shape))
                    masks = get_logit_snip_masks(
                        params, nn_density_level, f_no_dropout, snip_input,
                        batch_input_shape,
                        self.ntt_setup_dict['GLOBAL_PRUNE_BOOL'])
                    # get_snip_masks(params, nn_density_level, f_no_dropout, snip_batch, batch_input_shape)
                    masked_params = get_sparse_params_filtered_by_masks(
                        params, masks)
                else:
                    raise ValueError('The wiring string is undefined.')

                # optionally, add dropout layers #Test without dropout masks
                # NOTE(review): `> 100` makes this branch effectively dead code for any
                # realistic DROPOUT_LAYER_POS; the inline comment suggests it was disabled
                # deliberately for testing — confirm before re-enabling.
                if len(DROPOUT_LAYER_POS) > 100:
                    # insert empty placeholder entries at each dropout position, shifting
                    # the original parameter/mask entries to the remaining slots.
                    dropout_masked_params = [
                        ()
                    ] * (len(masked_params) + len(DROPOUT_LAYER_POS))
                    dropout_masks = [[]] * (len(masked_params) +
                                            len(DROPOUT_LAYER_POS))
                    print(len(masked_params))  #check dropout position
                    #pprint(masked_params)  # check
                    num_inserted = 0
                    for i in range(len(dropout_masked_params)):
                        if i in DROPOUT_LAYER_POS:
                            num_inserted += 1
                        else:
                            dropout_masked_params[i] = masked_params[
                                i - num_inserted]
                            dropout_masks[i] = masks[i - num_inserted]
                    masks = dropout_masks
                    masked_params = dropout_masked_params

                if init_weight_rescale_bool == True:
                    logger.info(
                        "Init weight rescaled: W_scaled = W/sqrt(nn_density_level)"
                    )
                    # rescale only (W, b) layer tuples; pass through activation layers unchanged.
                    scaled_params = []
                    for i in range(len(masked_params)):
                        if len(masked_params[i]) == 2:
                            scaled_params.append(
                                (masked_params[i][0] *
                                 np.sqrt(1 / nn_density_level),
                                 masked_params[i][1]))
                        else:
                            scaled_params.append(masked_params[i])
                    masked_params = scaled_params

                # fresh optimizer state for each run; these bindings are the ones
                # captured (late) by the jitted step() above.
                optimizer_with_params = optimizer_dict[OPTIMIZER_STR](
                    step_size=STEP_SIZE)
                opt_init, opt_update, get_params = optimizer_with_params
                opt_state = opt_init(masked_params)

                train_acc_list = []
                test_acc_list = []
                itercount = itertools.count()

                for iteration in range(total_batch):
                    batch_xs, batch_ys = next(gen_batches)
                    batch_xs = batch_xs.reshape(batch_input_shape)
                    if key_dropout is not None:
                        # thread a fresh dropout subkey through each step.
                        key_dropout, subkey_dropout = random.split(key_dropout)
                    opt_state = step(next(itercount),
                                     opt_state,
                                     batch_xs,
                                     batch_ys,
                                     masks=masks,
                                     key=subkey_dropout)
                    if iteration % RECORD_ACC_FREQ == 0:
                        masked_trans_params = get_params(opt_state)
                        train_acc = accuracy(masked_trans_params, f_test,
                                             train_input, train_label,
                                             key_dropout)
                        test_acc = accuracy(masked_trans_params, f_test,
                                            test_input, test_label,
                                            key_dropout)
                        train_acc_list.append(train_acc)
                        test_acc_list.append(test_acc)
                        logger.info(
                            "NN density %.2f | Run %03d/%03d | Iteration %03d/%03d | Train acc %.2f%% | Test acc %.2f%%",
                            nn_density_level, run_index,
                            self.ntt_setup_dict['NUM_RUNS'], iteration + 1,
                            total_batch, train_acc * 100, test_acc * 100)

                trained_masked_trans_params = get_params(opt_state)
                train_acc_list_runs.append(train_acc_list)
                test_acc_list_runs.append(test_acc_list)
                trained_masked_params_runs.append(trained_masked_trans_params)

            train_acc_list_runs = np.array(train_acc_list_runs)
            test_acc_list_runs = np.array(test_acc_list_runs)

            train_results_dict[str(nn_density_level)] = train_acc_list_runs
            test_results_dict[str(nn_density_level)] = test_acc_list_runs
            trained_masked_dict[str(
                nn_density_level)] = trained_masked_params_runs

            if save_supervised_result_bool:
                supervised_model_wiring_dir_run = supervised_model_wiring_dir + '/density_' + str(
                    round(nn_density_level, 2)) + '/'
                while os.path.exists(supervised_model_wiring_dir_run):
                    temp = supervised_model_wiring_dir_run + '_0'
                    supervised_model_wiring_dir_run = temp
                os.makedirs(supervised_model_wiring_dir_run)
                model_summary_str = '[u]' + self.ntt_file_name + '_[s]' + dataset_str + '_density_' + str(
                    round(nn_density_level, 2))
                np.save(
                    supervised_model_wiring_dir_run + '/' +
                    'supervised_trained_' + model_summary_str, [
                        nn_density_level, train_acc_list_runs,
                        test_acc_list_runs, trained_masked_params_runs
                    ])

        output = dict(train_results=train_results_dict,
                      test_results=test_results_dict,
                      trained_params=trained_masked_dict)
        return output