Esempio n. 1
0
# Validate the skip schedule before building any callbacks.
# skip_steps is defined before this excerpt; presumably a list of
# [start_epoch, n_skipped] pairs -- confirm against its definition.
# The total number of skipped epochs plus the start of the last skip
# must fit within `epochs`.
TotalSkips = sum(j for _, j in skip_steps)

if TotalSkips + max(skip_steps)[0] > epochs:
    print("Exiting:skip_steps + last skip exceeds total Epochs")
    # quit() is a site-module convenience meant for the interactive shell;
    # sys.exit is the supported way to stop a script, and a non-zero status
    # tells the caller that the run was aborted.
    sys.exit(1)

# Use min() to mirror the max() used for the last skip above, so the first
# skip is found even if skip_steps is not listed in order (identical result
# for a sorted list).
FirstSkip = min(skip_steps)[0]

reg_train_steps = 1
clusters = 5
project_paths = get_project_paths(sys.argv[0], to_tmp=False)

# Weight-storing callback with file prefix "Oweights": no regression train
# steps, no DTW clusters, prediction/cluster indicators off; CSVs are
# replicated at the first skip epoch.
callback_weights_reg = StoreWeights(project_paths["weights"],
                                    reg_train_steps=0,
                                    dtw_clusters=0,
                                    file_prefix="Oweights",
                                    weight_pred_ind=False,
                                    weighs_dtw_cluster_ind=False,
                                    replicate_csvs_at=FirstSkip)

# Weight-storing callback with file prefix "Rweights": receives the skip
# schedule and reg_train_steps, with prediction/cluster indicators on.
callback_weights_pred = StoreWeights(project_paths["weights"],
                                     reg_train_steps=reg_train_steps,
                                     dtw_clusters=0,
                                     file_prefix="Rweights",
                                     skip_array=skip_steps,
                                     weight_pred_ind=True,
                                     weighs_dtw_cluster_ind=True,
                                     replicate_csvs_at=0)

# "{epoch}" is left unformatted here; presumably filled in by the
# checkpoint callback -- confirm where this path is consumed.
checkpoint_path = project_paths["checkpoints"] + "/weights_epoch-{epoch}.ckpt"
restore_path = project_paths["checkpoints"] + "/weights_epoch-" + str(
Esempio n. 2
0
    [85, 1],
    [91, 1]
    
]

# Check that the total number of skipped epochs plus the start of the last
# skip fits within `epochs`. skip_steps is defined before this excerpt;
# presumably a list of [start_epoch, n_skipped] pairs -- confirm.
TotalSkips = sum(j for _, j in skip_steps)

if TotalSkips + max(skip_steps)[0] > epochs:
    print("Exiting:skip_steps + last skip exceeds total Epochs")
    # quit() is intended for the interactive shell only; sys.exit is the
    # supported way to stop a script and a non-zero status signals failure.
    sys.exit(1)

reg_train_steps = 3
clusters = 5
project_paths = get_project_paths(sys.argv[0], to_tmp=False)

# Weight-storing callback with file prefix "Oweights": no regression train
# steps, no DTW clusters, prediction/cluster indicators off.
callback_weights_reg = StoreWeights(
    project_paths["weights"],
    reg_train_steps=0,
    dtw_clusters=0,
    file_prefix="Oweights",
    weight_pred_ind=False,
    weighs_dtw_cluster_ind=False,
)

# Weight-storing callback with file prefix "Rweights": receives the skip
# schedule, with prediction/cluster indicators on.
# NOTE(review): reg_train_steps=2 here while the module-level
# reg_train_steps above is 3. Kept unchanged; confirm which is intended.
callback_weights_pred = StoreWeights(
    project_paths["weights"],
    reg_train_steps=2,
    dtw_clusters=0,
    file_prefix="Rweights",
    skip_array=skip_steps,
    weight_pred_ind=True,
    weighs_dtw_cluster_ind=True,
)

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Scale pixel values into [0, 1] (raw values are divided by 255).
train_images = train_images / 255.0
test_images = test_images / 255.0


# trimming inputs for fast processing