def finetune_q_model():
    """Fine-tune a pretrained Q model on stacked-frame steering data.

    Loads the train/validation split from ``data_folder_path`` via
    ``Datasets.make_stacked_frame_data`` and keeps only the first target
    entry (the steering targets). It then builds a ``q_categorical`` model
    for ``(120, 160, 4)`` inputs and loads the pretrained weights from
    ``saved_models/robin_track_v2_highres.h5``. It trains with categorical
    cross-entropy, an Adam optimizer (lr=1e-4) and early stopping on
    ``val_loss``.

    The resulting weights are written to
    ``saved_models/finetune_q_<YYYYmmddHHMMSS>.h5``.
    """
    X_train, X_valid, Y_train, Y_valid = Datasets.make_stacked_frame_data(
        data_folder_path)

    # Only need steering targets
    Y_train = Y_train[0]
    Y_valid = Y_valid[0]

    model = q_categorical(input_dimension=(120, 160, 4))

    # Load pretrained Q model for finetuning
    model.load_weights('saved_models/robin_track_v2_highres.h5')
    # NOTE(review): reassigning ``activation`` on an already-built layer may
    # not change the computation graph in every Keras version/backend --
    # confirm the model output is actually softmax-normalised.
    model.layers[-1].activation = softmax

    adam = Adam(lr=1e-4)  # Use a smaller learning rate for fine-tuning?
    model.compile(loss='categorical_crossentropy', optimizer=adam)
    print("weights Load Successfully!")

    # The callbacks must be handed to fit(); otherwise early stopping never
    # runs. TODO: consider also adding
    # ModelCheckpoint(filepath='best_model.h5', monitor='val_loss',
    #                 save_best_only=True)
    callbacks = [EarlyStopping(monitor='val_loss', patience=3)]
    model.fit(X_train, Y_train,
              epochs=20,
              batch_size=64,
              validation_data=(X_valid, Y_valid),
              callbacks=callbacks)

    timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    model.save_weights('saved_models/finetune_q_' + timestamp + '.h5',
                       overwrite=True)
def train_stacked_frame_simple_model():
    """Train a ``simple_categorical`` model on stacked-frame data from scratch.

    Loads the train/validation split from ``data_folder_path`` via
    ``Datasets.make_stacked_frame_data`` and passes the targets to ``fit``
    unchanged, without selecting a single target entry. It builds a model
    for ``(120, 160, 4)`` inputs and trains it for 3 epochs with early
    stopping on ``val_loss``.

    The resulting weights are written to
    ``saved_models/stackedframe_simple_categorical_<YYYYmmddHHMMSS>.h5``.
    """
    X_train, X_valid, Y_train, Y_valid = Datasets.make_stacked_frame_data(
        data_folder_path)

    model = simple_categorical(input_dimension=(120, 160, 4))

    # The callbacks must be handed to fit(); otherwise early stopping never
    # runs. With only 3 epochs and patience=3 it is unlikely to trigger, but
    # it keeps the configuration honest if the epoch count is raised.
    # TODO: consider also adding
    # ModelCheckpoint(filepath='best_model.h5', monitor='val_loss',
    #                 save_best_only=True)
    callbacks = [EarlyStopping(monitor='val_loss', patience=3)]
    model.fit(X_train, Y_train,
              epochs=3,
              batch_size=64,
              validation_data=(X_valid, Y_valid),
              callbacks=callbacks)

    timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    model.save_weights(
        'saved_models/stackedframe_simple_categorical_' + timestamp + '.h5',
        overwrite=True)