if sum(j for i, j in skip_steps) + max(skip_steps)[0] > epochs: print("Exiting:skip_steps + last skip exceeds total Epochs") quit() FirstSkip = skip_steps[0][0] TotalSkips = sum([j for i, j in skip_steps]) reg_train_steps = 1 clusters = 5 project_paths = get_project_paths(sys.argv[0], to_tmp=False) callback_weights_reg = StoreWeights(project_paths["weights"], reg_train_steps=0, dtw_clusters=0, file_prefix="Oweights", weight_pred_ind=False, weighs_dtw_cluster_ind=False, replicate_csvs_at=FirstSkip) callback_weights_pred = StoreWeights(project_paths["weights"], reg_train_steps=reg_train_steps, dtw_clusters=0, file_prefix="Rweights", skip_array=skip_steps, weight_pred_ind=True, weighs_dtw_cluster_ind=True, replicate_csvs_at=0) checkpoint_path = project_paths["checkpoints"] + "/weights_epoch-{epoch}.ckpt" restore_path = project_paths["checkpoints"] + "/weights_epoch-" + str(
[85, 1], [91, 1] ] #check if totalskips + last skip at is within total epochs if sum(j for i,j in skip_steps) + max(skip_steps)[0] > epochs: print("Exiting:skip_steps + last skip exceeds total Epochs") quit() TotalSkips = sum([j for i,j in skip_steps]) reg_train_steps = 3 clusters = 5 project_paths = get_project_paths(sys.argv[0], to_tmp=False) callback_weights_reg = StoreWeights(project_paths["weights"], reg_train_steps=0,dtw_clusters=0, file_prefix ="Oweights" , weight_pred_ind=False,weighs_dtw_cluster_ind=False) callback_weights_pred = StoreWeights(project_paths["weights"], reg_train_steps=2,dtw_clusters=0, file_prefix ="Rweights" ,skip_array=skip_steps, weight_pred_ind=True,weighs_dtw_cluster_ind=True) fashion_mnist = tf.keras.datasets.fashion_mnist (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] train_images = train_images / 255.0 test_images = test_images / 255.0 # trimming inputs for fast processing