# NOTE(review): whitespace-mangled fragment — many statements are collapsed onto
# the single physical line below, and the leading `for` / `else:` pair with an
# `if` header that falls outside this chunk, so the original nesting cannot be
# reconstructed here. Do not reformat without consulting the original file.
#
# Visible logic, as far as the code itself shows:
#   * One branch marks masks[row][col] = 1 for the first `split_param_size`
#     entries of `split_metrics` (entries are (index0, index1, metric) triples,
#     judging by the [0]/[1]/[2] accesses) and counts them in `num`.
#   * The `else:` branch instead derives `threshold = np.mean(metrics_list)`
#     and marks entries until the first one whose metric column [2] drops below
#     it — presumably `split_metrics` is sorted by metric descending, since the
#     loop `break`s at the first sub-threshold entry; TODO confirm the sort.
#   * The network is then reset to `init_param_value` and, when
#     `split_param_size != 0`, wrapped in an MLP_MaskedSplit whose input is
#     widened by len(difficulties) + 1 extra dimensions and which splits into
#     len(difficulties) + 1 sub-networks gated by `masks`; otherwise
#     `split_network` is a deep copy of `network` reset to the initial params.
#   * The placement of `split_param_size += 1` relative to the surrounding
#     control flow is ambiguous in this mangled form — verify against the
#     original source before editing.
for i in range(start, split_param_size): masks[split_metrics[i][0]][split_metrics[i][1]] = 1 num += 1 print('NUM: ', num) else: threshold = np.mean(metrics_list) # + np.std(metrics_list) size = 0 for i in range(len(split_metrics)): if split_metrics[i][2] < threshold: break else: masks[split_metrics[i][0]][split_metrics[i][1]] = 1 size += 1 print('split size: ', size) network.set_param_values(init_param_value) split_param_size += 1 if split_param_size != 0: split_network = MLP_MaskedSplit( input_shape=(in_dim + len(difficulties) + 1, ), output_dim=out_dim, hidden_sizes=hidden_size, hidden_nonlinearity=NL.tanh, output_nonlinearity=None, split_num=len(difficulties) + 1, split_masks=masks, init_net=network, ) else: split_network = copy.deepcopy(network) split_network.set_param_values(init_param_value)
# NOTE(review): whitespace-mangled fragment — the statements below are collapsed
# onto one physical line and the final `for region in range(len(task_grads)):`
# loop is truncated mid-body, so this chunk is incomplete. Do not reformat
# without consulting the original file.
#
# Visible logic, as far as the code itself shows:
#   * Snapshot the network's initial parameters into `init_param_value`.
#   * Prepare `task_grads` as one empty list per training task (len(trainingXs)).
#   * If `load_split_data` is false: for each of `epochs` iterations, record the
#     current parameter vector into `net_weight_values` and run one training
#     pass over the concatenated training data, then persist the snapshots with
#     joblib.dump under data/trained/gradient_temp/supervised_split_<append>/.
#     Otherwise load previously saved snapshots with joblib.load.
#     (NOTE(review): the snapshot appears to be taken BEFORE each train() call,
#     so entry i is the pre-update weight vector for epoch i — confirm this is
#     intended.)
#   * Replay each snapshot into the network and collect per-task gradients via
#     `grad_fn(trainingXs[j], trainingYs[j])` into task_grads[j].
#   * `pred = out(...)` computes test predictions; its result is not used in
#     this visible span.
#   * `split_counts` is initialised as zero arrays matching each parameter
#     tensor's shape (task_grads[0][0][i].shape), then a triple loop over
#     epochs / parameter tensors / tasks begins gathering `region_gradients` —
#     truncated here.
init_param_value = np.copy(network.get_param_values()) #Xs, Ys = synthesize_data(dim, 2000, tasks) task_grads = [] for i in range(len(trainingXs)): task_grads.append([]) if not load_split_data: net_weight_values = [] for i in range(epochs): net_weight_values.append(network.get_param_values()) train(train_fn, np.concatenate(trainingXs), np.concatenate(trainingYs), 1) joblib.dump(net_weight_values, 'data/trained/gradient_temp/supervised_split_' + append + '/net_weight_values.pkl', compress=True) else: net_weight_values = joblib.load('data/trained/gradient_temp/supervised_split_' + append + '/net_weight_values.pkl') for i in range(epochs): network.set_param_values(net_weight_values[i]) for j in range(len(trainingXs)): grad = grad_fn(trainingXs[j], trainingYs[j]) task_grads[j].append(grad) print('------- collected gradient info -------------') pred = out(np.concatenate(testingXs)) split_counts = [] for i in range(len(task_grads[0][0])): split_counts.append(np.zeros(task_grads[0][0][i].shape)) for i in range(len(task_grads[0])): for k in range(len(task_grads[0][i])): region_gradients = [] for region in range(len(task_grads)):