# Example 1 (fragment): per-season UNET training setup. The epoch loop at the
# bottom of this chunk continues beyond what is shown here.
for sea in seasons:
    print('===== {} training ====='.format(sea))
    # NOTE(review): 999 looks like a "best loss so far" sentinel for the epoch
    # loop below (not visible here) -- confirm.
    record = 999
    # model name
    # NOTE(review): the name is built from N_input and VAR, but the data globs
    # and the pre-trained weight file below hard-code 'TMEAN' (and 'UNET3');
    # confirm this script is only run with VAR == 'TMEAN' and N_input == 3.
    model_name = 'UNET{}_{}_{}_clean'.format(N_input, VAR, sea)
    model_path = temp_dir + model_name + '.hdf'  # model checkpoint
    # '.npy' outputs for the two stages (suffixes '_adam' / '_sgd'); not used in
    # the visible part -- presumably loss records saved by the epoch loop.
    train_path = temp_dir + model_name + '_adam.npy'
    tune_path = temp_dir + model_name + '_sgd.npy'

    # full list of files
    # validation batches: VORI-JRA; training batches: TORI-ERA (this season only)
    validfiles = glob(file_path +
                      'TMEAN_BATCH_*_VORI-JRA-clean_{}*'.format(sea))
    # NOTE(review): singular `trainfile` (vs. `validfiles`) -- check that later
    # code refers to this exact name.
    trainfile = glob(file_path +
                     'TMEAN_BATCH_*_TORI-ERA-clean_{}*'.format(sea))
    # validation generator uses every other validation file ([::2])
    gen_valid = tu.grid_grid_gen(validfiles[::2], labels, input_flag,
                                 output_flag)
    # model
    # spatial dims left as None so the network accepts any grid size;
    # presumably (height, width, channels) with N_input channels -- confirm.
    model = mu.UNET(N, (None, None, N_input))
    opt_sgd = keras.optimizers.SGD(lr=l)
    model.compile(loss=keras.losses.mean_absolute_error, optimizer=opt_sgd)
    # start from previously tuned weights loaded by the project helper
    # tu.dummy_loader (its exact behavior is defined in tu, not shown here)
    W = tu.dummy_loader(temp_dir + 'UNET3_TMEAN_{}_tune.hdf'.format(sea))
    model.set_weights(W)

    # loss backup
    # NaN-filled buffers (zeros * NaN) so unfilled entries stay NaN:
    # LOSS / LOSS_tune have epochs * L_train slots (presumably one per batch),
    # VLOSS / VLOSS_tune have one slot per epoch.
    LOSS = np.zeros([int(epochs * L_train)]) * np.nan
    LOSS_tune = np.zeros([int(epochs * L_train)]) * np.nan
    VLOSS = np.zeros([epochs]) * np.nan
    VLOSS_tune = np.zeros([epochs]) * np.nan

    # NOTE(review): presumably an early-stopping patience counter -- confirm.
    tol = 0
    # epoch loop (body continues beyond this chunk)
    for i in range(epochs):
# Ejemplo n.º 2 (Example no. 2)
# 0
# Example 2: two-stage training of a denoising autoencoder (DAE) for season
# `sea` and variable `VAR`.
#   Stage 1: Adam (learning rate l[0]); the best model by val_loss is
#            checkpointed to `model_path`.
#   Stage 2: a fresh DAE compiled with SGD (learning rate l[1]) is seeded with
#            the BEST stage-1 weights (the stage-2 fit itself is not shown here).
trainfiles = glob(file_path+'{}_BATCH_*_TORI*_{}*.npy'.format(VAR, sea))
validfiles = glob(file_path+'{}_BATCH_*_VORI*_{}*.npy'.format(VAR, sea))
# checkpoint path (.hdf); train_path is not used in this part of the script
model_path = temp_dir+'DAE_{}_{}_elev.hdf'.format(VAR, sea)
train_path = temp_dir+'DAE_{}_{}_elev.npy'.format(VAR, sea)

DAE = mu.DAE(N, input_size)

# optimizer & callback & compile
opt_ae = keras.optimizers.Adam(lr=l[0])
callbacks = [keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.00001, patience=2, verbose=True),
             keras.callbacks.ModelCheckpoint(filepath=model_path, verbose=True, monitor='val_loss', save_best_only=True)]
DAE.compile(loss=keras.losses.mean_absolute_error, optimizer=opt_ae, metrics=[keras.losses.mean_absolute_error])

# Data generator
# fit_generator below is called without steps_per_epoch, which Keras only
# accepts for keras.utils.Sequence inputs -- so grid_grid_gen presumably
# returns a Sequence (TODO confirm); that is also what makes shuffle=True and
# workers=8 meaningful.
gen_train = tu.grid_grid_gen(trainfiles, labels, input_flag, output_flag)
gen_valid = tu.grid_grid_gen(validfiles, labels, input_flag, output_flag)

# train
temp_hist = DAE.fit_generator(generator=gen_train, validation_data=gen_valid, callbacks=callbacks,
                              initial_epoch=0, epochs=epochs, verbose=1, shuffle=True, max_queue_size=8, workers=8)

# EarlyStopping (without restore_best_weights) leaves DAE holding the weights
# of the LAST epoch, which can be up to `patience` epochs worse than the best
# one. ModelCheckpoint(save_best_only=True) wrote the best epoch to model_path,
# so reload it before snapshotting the weights that seed fine-tuning.
DAE.load_weights(model_path)
W = DAE.get_weights() # backup weights (best stage-1 epoch)
DAE_tune = mu.DAE(N, input_size)
opt_ae = keras.optimizers.SGD(lr=l[1], decay=1e-2*l[1])

# NOTE(review): these callbacks reuse model_path. A new ModelCheckpoint starts
# from best = +inf, so the first stage-2 epoch overwrites the stage-1
# checkpoint whether or not it improves on it.
callbacks = [keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.0000001, patience=2, verbose=True),
             keras.callbacks.ModelCheckpoint(filepath=model_path, verbose=True, monitor='val_loss', save_best_only=True)]

DAE_tune.compile(loss=keras.losses.mean_absolute_error, optimizer=opt_ae, metrics=[keras.losses.mean_absolute_error])
DAE_tune.set_weights(W)
# Ejemplo n.º 3 (Example no. 3)
# 0
    # Example 3 (fragment): body of an enclosing loop whose header is not shown
    # (`sea` is presumably the season loop variable). Builds the file lists and
    # data generators for the elev tuner and the t2 tuner, then loads a
    # pre-trained model whose weights are reused below (continues past chunk).
    print('Preparing data generators')
    # t2 tuner: training = TORI + VORI batches, validation = TSUB + VSUB batches;
    # elev tuner: TMIX + VMIX batches. All filtered by VAR and season.
    trainfiles_t2 = glob(file_path+'{}_BATCH_*_TORI_*{}*.npy'.format(VAR, sea))+\
                    glob(file_path+'{}_BATCH_*_VORI_*{}*.npy'.format(VAR, sea)) # training domain, t2 target
    validfiles_t2 = glob(file_path+'{}_BATCH_*_TSUB_*{}*.npy'.format(VAR, sea))+\
                    glob(file_path+'{}_BATCH_*_VSUB_*{}*.npy'.format(VAR, sea))
    trainfiles_elev = glob(file_path+'{}_BATCH_*_TMIX_*{}*.npy'.format(VAR, sea))+\
                      glob(file_path+'{}_BATCH_*_VMIX_*{}*.npy'.format(VAR, sea)) # transferring & tuning domain, elev target
    # shuffle filenames (presumably random.shuffle, which works in place)
    shuffle(trainfiles_t2)
    shuffle(validfiles_t2)
    shuffle(trainfiles_elev)
    # the data generator for elev tuner and t2 tuner
    # elev generator uses out_flag_elev and sign_flag=True; the meaning of
    # sign_flag is defined in tu (not visible here).
    gen_elev = tu.grid_grid_gen(trainfiles_elev,
                                labels,
                                input_flag,
                                out_flag_elev,
                                sign_flag=True)
    # t2 generators use the *_multi variant of the project generator
    gen_train = tu.grid_grid_gen_multi(trainfiles_t2, labels, input_flag,
                                       output_flag)
    gen_valid = tu.grid_grid_gen_multi(validfiles_t2, labels, input_flag,
                                       output_flag)

    print('Importing pre-trained weights')
    # load the pre-trained model saved as temp_dir + 'UAE{N_input}_{VAR}_{sea}_tune.hdf'
    # (e.g. 'UAE3_TMAX_djf_tune.hdf' for N_input=3, VAR='TMAX', sea='djf')
    model_name = 'UAE{}_{}_{}_tune'.format(N_input, VAR, sea)
    model_path = temp_dir + model_name + '.hdf'
    print('\tmodel: {}'.format(model_name))
    backbone = keras.models.load_model(model_path)
    # keep only the weights; presumably transferred into the tuned model below
    W = backbone.get_weights()
    # tuned model (construction continues beyond this chunk)