def dummy_loss_model(VAR, sea, layer_id=None):
    """Build frozen feature-extraction models from a trained DAE's encoder.

    Loads the checkpoint ``temp_dir + 'DAE_{VAR}_{sea}.hdf'`` and copies its
    weights into a freshly built ``mu.DAE(AE_N, AE_input_size,
    AE_latent_size)``. It then freezes the sub-model found at
    ``DAE.layers[1]`` (used as the encoder) and returns one Keras model per
    requested encoder layer.

    Args:
        VAR: Variable name used to build the checkpoint filename.
        sea: Season tag used to build the checkpoint filename.
        layer_id: Iterable of indices into ``encoder.layers``. Each selected
            layer's output becomes the output of one loss model. If omitted,
            the module-level ``layer_id`` is used, which is the function's
            previous, implicit behavior.

    Returns:
        list of ``keras.models.Model``, one per entry of ``layer_id`` and in
        the same order. Each maps ``encoder.inputs`` to the output of the
        corresponding encoder layer.

    Raises:
        NameError: if ``layer_id`` is omitted and no module-level
            ``layer_id`` exists.
    """
    if layer_id is None:
        # Backward compatibility: this function used to read the global
        # ``layer_id`` implicitly; ``loss_model`` takes it as an argument.
        try:
            layer_id = globals()['layer_id']
        except KeyError:
            raise NameError("name 'layer_id' is not defined") from None

    AE_path = temp_dir + 'DAE_{}_{}.hdf'.format(VAR, sea)  # model checkpoint
    AE = keras.models.load_model(AE_path)
    W = AE.get_weights()

    # Rebuild the architecture and transfer the trained weights into it.
    # NOTE(review): presumably done so the encoder is reachable as a nested
    # sub-model at ``DAE.layers[1]`` -- confirm against ``mu.DAE``.
    DAE = mu.DAE(AE_N, AE_input_size, AE_latent_size)
    DAE.set_weights(W)

    # Encoder layer selection: assumed to be a nested model, since it is used
    # below through ``.layers`` and ``.inputs``.
    encoder = DAE.layers[1]

    # Freeze the encoder container and each of its layers.
    encoder.trainable = False
    for layer in encoder.layers:
        layer.trainable = False

    # Loss models: one per selected encoder layer, in ``layer_id`` order.
    f_sproj = [encoder.layers[i].output for i in layer_id]
    return [keras.models.Model(encoder.inputs, single_proj)
            for single_proj in f_sproj]
def loss_model(VAR, sea, layer_id):
    """Return frozen sub-models that expose intermediate DAE activations.

    Trained weights are read from ``temp_dir + 'DAE_{VAR}_{sea}_self.hdf'``.
    They are transferred into a new ``mu.DAE`` built with
    ``N = [48, 96, 192, 384]`` and ``input_size = (None, None, 1)``. The
    meaning of ``N`` (presumably per-stage channel counts) is defined by
    ``mu.DAE``. The network and every one of its layers are then marked
    non-trainable.

    Args:
        VAR: Variable name used in the checkpoint filename.
        sea: Season tag used in the checkpoint filename.
        layer_id: Iterable of indices into the rebuilt DAE's ``layers``.

    Returns:
        list of ``keras.models.Model`` in ``layer_id`` order. Each maps the
        DAE inputs to the output of the corresponding layer.
    """
    ckpt = temp_dir + 'DAE_{}_{}_self.hdf'.format(VAR, sea)
    trained_weights = keras.models.load_model(ckpt).get_weights()

    net = mu.DAE([48, 96, 192, 384], (None, None, 1))
    net.set_weights(trained_weights)

    # Freeze both the container and each individual layer.
    net.trainable = False
    for lyr in net.layers:
        lyr.trainable = False

    taps = [net.layers[idx].output for idx in layer_id]
    return [keras.models.Model(net.inputs, tap) for tap in taps]
l = [1e-4, 1e-5] # learning rate epochs = 200 # sapce for early stopping # DAE N = [48, 96, 192, 384] input_size = (None, None, 1) # training file location file_path = BATCH_dir trainfiles = glob(file_path+'{}_BATCH_*_TORI*_{}*.npy'.format(VAR, sea)) validfiles = glob(file_path+'{}_BATCH_*_VORI*_{}*.npy'.format(VAR, sea)) # model_path = temp_dir+'DAE_{}_{}_elev.hdf'.format(VAR, sea) train_path = temp_dir+'DAE_{}_{}_elev.npy'.format(VAR, sea) DAE = mu.DAE(N, input_size) # optimizer & callback & compile opt_ae = keras.optimizers.Adam(lr=l[0]) callbacks = [keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.00001, patience=2, verbose=True), keras.callbacks.ModelCheckpoint(filepath=model_path, verbose=True, monitor='val_loss', save_best_only=True)] DAE.compile(loss=keras.losses.mean_absolute_error, optimizer=opt_ae, metrics=[keras.losses.mean_absolute_error]) # Data generator gen_train = tu.grid_grid_gen(trainfiles, labels, input_flag, output_flag) gen_valid = tu.grid_grid_gen(validfiles, labels, input_flag, output_flag) # train temp_hist = DAE.fit_generator(generator=gen_train, validation_data=gen_valid, callbacks=callbacks, initial_epoch=0, epochs=epochs, verbose=1, shuffle=True, max_queue_size=8, workers=8)