def build_model(config, train_dataset, train_labels, set_non_trainable=False, train_stats=None):
    """
    Construct the NN model selected by config['MODEL']['ModelArchitect'].

    The architecture name is matched case-insensitively. 'NodesList' and
    'Activation' are read from config['MODEL'] and must have the same length.
    For any architecture whose name contains 'cnn', 'LayerName' and 'Padding'
    are read as well and must also match the length of 'NodesList'.

    Args:
        config: mapping with a 'MODEL' section holding the keys above.
        train_dataset: forwarded unchanged to the selected model builder.
        train_labels: forwarded unchanged to the selected model builder.
        set_non_trainable: when truthy, `model.trainable` is set to False
            before the model is returned.
        train_stats: forwarded only to the
            'user_dnn_kregl1l2_gauss_grad' builder.

    Returns:
        The object returned by the selected model builder.

    Raises:
        ValueError: if the configured list lengths disagree, or if the
            architecture name is not one of the implemented ones.
    """

    def _not_implemented():
        # Same exception payload for both "unknown architecture" exits.
        return ValueError('Model architect = ', ModelArchitect,
                          ' is chosen, but is not implemented!')

    model_section = config['MODEL']
    ModelArchitect = model_section['ModelArchitect']
    NodesList = mrnn_utility.getlist_int(model_section['NodesList'])
    Activation = mrnn_utility.getlist_str(model_section['Activation'])

    # One activation is expected per entry of NodesList.
    if len(NodesList) != len(Activation):
        raise ValueError(
            'In the config file, number of NodesList != Activation list with NodesList = ',
            NodesList, ' and Activation = ', Activation)

    arch_name = ModelArchitect.lower()

    if arch_name == "dnn_kregl1l2_gauss".lower():
        model = DNN_kregl1l2_gauss(
            config, train_dataset, train_labels, NodesList, Activation)
    elif arch_name == "user_dnn_kregl1l2_gauss_grad".lower():
        model = user_DNN_kregl1l2_gauss_grad_setup(
            config, train_dataset, train_labels, NodesList, Activation,
            train_stats)
    elif 'cnn' in arch_name:
        # CNN variants need per-layer names and padding modes as well.
        LayerName = mrnn_utility.getlist_str(config['MODEL']['LayerName'])
        Padding = mrnn_utility.getlist_str(config['MODEL']['Padding'])

        if not (len(LayerName) == len(NodesList) == len(Padding)):
            raise ValueError(
                'In the config file, number of NodesList != LayerName with NodesList = ',
                NodesList, len(NodesList), ' and LayerName = ', LayerName,
                len(LayerName), 'and Padding = ', Padding, len(Padding))

        if arch_name == "CNN_user_supervise".lower():
            model = CNN_user_supervise_setup(
                config, train_dataset, train_labels, NodesList, Activation,
                LayerName, Padding)
        elif arch_name == "CNN_supervise".lower():
            model = CNN_supervise(
                config, train_dataset, train_labels, NodesList, Activation,
                LayerName, Padding)
        else:
            raise _not_implemented()
    else:
        raise _not_implemented()

    if set_non_trainable:
        model.trainable = False

    return model
def build_learningrate(config):
    """
    Build the learning rate described by config['MODEL']['LearningRate'].

    The config entry is parsed into a list of strings, interpreted as:

        [rate]                      constant rate
        ['mono', rate]              constant rate
        ['decay_exp', initial_rate[, decay_steps[, decay_rate]]]
                                    staircase exponential decay;
                                    decay_steps defaults to 1000 and
                                    decay_rate to 0.96 when omitted

    Args:
        config: mapping with a 'MODEL' section holding 'LearningRate'.

    Returns:
        A float for the constant forms, otherwise the object returned by
        the TensorFlow `exponential_decay` call.

    Raises:
        ValueError: if the parsed list is empty, or if a multi-entry list
            starts with a keyword other than 'mono' or 'decay_exp'.
    """
    LR = mrnn_utility.getlist_str(config['MODEL']['LearningRate'])

    # Fail with a configuration error rather than falling through to the
    # return statement with no learning rate assigned.
    if len(LR) == 0:
        raise ValueError(
            "config['MODEL']['LearningRate'] is empty; expected a rate, "
            "'mono, <rate>' or 'decay_exp, <rate>[, <steps>[, <decay>]]'")

    if len(LR) == 1:
        return float(LR[0])

    LR_type = LR[0]
    if LR_type == 'mono':
        return float(LR[1])

    if LR_type != 'decay_exp':
        raise ValueError(
            'unknown choice for learning rate (mono, decay_exp): ', LR)

    # Exponential decay: optional trailing entries override the defaults.
    initial_learning_rate = float(LR[1])
    decay_steps = int(LR[2]) if len(LR) > 2 else 1000
    decay_rate = float(LR[3]) if len(LR) > 3 else 0.96

    tf_version = mrnn_utility.get_package_version(tf.__version__)
    if tf_version[0] == 1 and tf_version[1] <= 13:
        # NOTE(review): this variable is immediately superseded by
        # tf.train.get_global_step(); it is kept because creating it may
        # have side effects in the TF1 graph. Confirm whether it can be
        # dropped, and whether get_global_step() can return None here.
        global_step = tf.Variable(0, name='global_step', trainable=False)
        global_step = tf.train.get_global_step()
        LearningRate = tf.train.exponential_decay(
            initial_learning_rate,
            global_step,
            decay_steps=decay_steps,
            decay_rate=decay_rate,
            staircase=True)
    else:
        # NOTE: per the original author, use this learning rate with care;
        # there were occasions where tf1.13 gave better training
        # performance.
        global_step = tf.Variable(0, name='global_step', trainable=False)
        LearningRate = tf.compat.v1.train.exponential_decay(
            initial_learning_rate,
            global_step,
            decay_steps=decay_steps,
            decay_rate=decay_rate,
            staircase=True)

    return LearningRate
def build_callbacks(config):
    """
    Assemble the training callbacks named in config['MODEL']['CallBacks'].

    Recognised names are 'checkpoint', 'tensorboard' and 'printdot'; any
    other name in the list is ignored. The returned list always follows
    that fixed order, regardless of the order used in the config.

    Args:
        config: mapping with a 'MODEL' section holding 'CallBacks'; it is
            also passed to the checkpoint and tensorboard factories.

    Returns:
        A list with one callback object per recognised, requested name.
    """
    requested = mrnn_utility.getlist_str(config['MODEL']['CallBacks'])

    # Factories are wrapped in lambdas so a callback is only built (and its
    # factory name only resolved) when that callback was actually requested.
    available = (
        ('checkpoint', lambda: check_point_callback(config)),
        ('tensorboard', lambda: tensor_board_callback(config)),
        ('printdot', lambda: callback_PrintDot()),
    )

    return [make() for label, make in available if label in requested]
train_dataset, train_labels, train_stats=train_stats) if (config['RESTART']['RestartWeight'].lower() == 'y'): print('checkpoint_dir for restart: ', checkpoint_dir) latest = tf.train.latest_checkpoint(checkpoint_dir) print("latest checkpoint: ", latest) if (latest != None): model.load_weights(latest) print("Successfully load weight: ", latest) else: print("No saved weights, start to train the model from the beginning!") pass metrics = mrnn_utility.getlist_str(config['MODEL']['Metrics']) optimizer = mrnn_models.build_optimizer(config) loss = mrnn_models.my_mse_loss_with_grad(BetaP=0.0) model.compile(loss=loss, optimizer=optimizer, metrics=metrics) label_scale = float(config['TEST']['LabelScale']) callbacks = mrnn_models.build_callbacks(config) train_dataset = train_dataset.to_numpy() train_labels = train_labels.to_numpy() val_dataset = val_dataset.to_numpy() val_labels = val_labels.to_numpy() test_dataset = test_dataset.to_numpy() test_labels = test_labels.to_numpy() # make sure that the derivative data is scaled correctly
def shift_labels(config, dataset, dataset_index, dataset_frame, data_file): """ Shift label based on the trained-NN predictions """ # print("---!!!!--- reach shift_labels!!!!") # print("---!!!!--- Remember to modify old 'std', old 'mean' for DNN based KBNN") trained_model_lists = mrnn_utility.getlist_str( config['KBNN']['LabelShiftingModels']) if len(trained_model_lists) > 0: # all_fields = mrnn_utility.getlist_str(config['TEST']['AllFields']) label_fields = mrnn_utility.getlist_str(config['TEST']['LabelFields']) if len(label_fields) > 1: # raise ValueError( # 'Shift labels is not working for two labels shifting yet!') print('Shift labels is not working for two labels shifting yet!') # if label_fields[0] != all_fields[-1]: # raise ValueError('the single label for KBNN should put at the end of all label fields!') # print("---!!!!--- load trained model!!!!") old_models = load_trained_model(trained_model_lists) # print("---!!!!--- after load trained model!!!!") key0 = label_fields[0] old_label_scale = mrnn_utility.getlist_float( config['KBNN']['OldShiftLabelScale']) # print('old shift features: ', config['KBNN']['OldShiftFeatures']) # to switch between vtk and other features if (config['KBNN']['OldShiftFeatures'].find('.vtk') >= 0): # """ """ # print("--- here: vtk for label shift") # index should not be used anymore. # use base frame info to do the base free energy shifting dataset_old = mrnn_utility.load_data_from_vtk_for_label_shift_frame( config, dataset_frame, normalization_flag=True, verbose=0) elif (config['KBNN']['OldShiftFeatures'].find('.npy') >= 0): # """ """ # print("--- here: npy for label shift") # index should not be used anymore. 
# use base frame info to do the base free energy shifting dataset_old = mrnn_utility.load_data_from_npy_for_label_shift_frame( config, dataset_frame, normalization_flag=True, verbose=0) else: old_feature_fields = mrnn_utility.getlist_str( config['KBNN']['OldShiftFeatures']) raw_dataset_old = mrnn_utility.read_csv_fields( data_file, old_feature_fields) dataset_old = raw_dataset_old.copy() if len(old_feature_fields) > 0: old_mean = mrnn_utility.getlist_float( config['KBNN']['OldShiftMean']) old_std = mrnn_utility.getlist_float( config['KBNN']['OldShiftStd']) old_data_norm = int(config['KBNN']['OldShiftDataNormOption']) if (old_data_norm != 2): dataset_old = (dataset_old - old_mean) / old_std else: dataset_old = (dataset_old - old_mean) / old_std + 0.5 raise ValueError( "This part is not carefully checked. Please check it before you disable it." ) if (len(old_models) > 0): # convert dataset_old to numpy data array in case it is not try: dataset_old = dataset_old.to_numpy() except: try: dataset_old = dataset_old.numpy() except: pass pass for model0 in old_models: label_shift_amount = [] batch_size = int(config['MODEL']['BatchSize']) print('run...', model0) # use model.predict() will run it in the eager mode and evaluate the tensor properly. 
for i0 in range(0, len(dataset_old), batch_size): tmp_shift = model0.predict( mrnn_utility.special_input_case( dataset_old[i0:i0 + batch_size]) ) / old_label_scale # numpy type label_shift_amount.extend(tmp_shift) for i0 in range(0, len(dataset[key0])): a = dataset[key0][i0] - label_shift_amount[i0] # tf1.13 if (i0 % 200 == 0): print('--i0--', i0, 'DNS:', dataset[key0][i0], '\t', 'NN:', label_shift_amount[i0], ' key0 = ', key0, ' dataset size = ', len(dataset[key0]), ' label shift size = ', len(label_shift_amount)) # for tf2.0 # print('--i0--',i0, 'DNS:', dataset[key0][i0],'\t', 'NN:',label_shift_amount[i0].numpy()[0], '\t', a.numpy()[0], '\t', abs(a.numpy()[0]/dataset[key0][i0])*100, 'new label', new_label[key0][i0]) dataset[key0][i0] = dataset[key0][i0] - label_shift_amount[i0]