import tensorflow as tf
from tqdm import tqdm

# DataLoader and DenseNet are project-local modules; the import paths below
# are assumptions -- adjust them to the actual repo layout.
from data_loader import DataLoader
from model import DenseNet


def main(config):
    # Load the CIFAR data and build the train/test pipelines.
    data = DataLoader(config)
    train_loader, test_loader = data.prepare_data()

    model = DenseNet(config)
    model.build((config["trainer"]["batch_size"], 224, 224, 3))
    model.summary()  # summary() prints itself; wrapping it in print() just outputs "None"

    # `lr` is deprecated in tf.keras; use `learning_rate`.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    loss_object = tf.keras.losses.CategoricalCrossentropy()
    train_loss = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
    train_accuracy = tf.keras.metrics.CategoricalAccuracy(name="train_accuracy")

    def train_step(images, labels):
        with tf.GradientTape() as tape:
            # training=True so dropout/batch norm run in training mode.
            predictions = model(images, training=True)
            loss = loss_object(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_loss(loss)
        train_accuracy(labels, predictions)

    for epoch in range(config["trainer"]["epochs"]):
        # Reset the running metrics so each epoch reports its own averages
        # instead of an average over all epochs so far.
        train_loss.reset_states()
        train_accuracy.reset_states()
        for step, (images, labels) in tqdm(
                enumerate(train_loader),
                total=int(len(data) / config["trainer"]["batch_size"])):
            train_step(images, labels)
        template = "Epoch {}, Loss: {:.4f}, Accuracy: {:.4f}"
        print(template.format(epoch + 1,
                              train_loss.result(),
                              train_accuracy.result() * 100))
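# --- Usage sketch (assumption): run as a script with a JSON config file.
# The key layout mirrors what main() reads (config["trainer"]["batch_size"]
# and config["trainer"]["epochs"]); the default path "configs/config.json"
# is hypothetical.
if __name__ == "__main__":
    import argparse
    import json

    parser = argparse.ArgumentParser(description="Train DenseNet on CIFAR")
    parser.add_argument("--config", default="configs/config.json",
                        help="path to a JSON config with a 'trainer' section")
    args = parser.parse_args()

    with open(args.config) as f:
        config = json.load(f)  # e.g. {"trainer": {"batch_size": 32, "epochs": 10}}

    main(config)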
import os

from tensorflow.keras import callbacks
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam

# DenseNet is the project's own builder; the import path is an assumption --
# adjust it to the actual repo layout.
from densenet import DenseNet

# Resume from the last checkpoint if one exists, otherwise build a fresh model.
# r, c (image height/width) are defined elsewhere in the script.
if 'model-last_epoch.h5' in os.listdir('outputs/'):
    print('last model loaded')
    model = load_model('outputs/model-last_epoch.h5')
else:
    print('created a new model instead')
    model = DenseNet(input_shape=(r, c, 1),
                     dense_blocks=5,
                     dense_layers=-1,
                     growth_rate=8,
                     dropout_rate=0.2,
                     bottleneck=True,
                     compression=1.0,
                     weight_decay=1e-4,
                     depth=40)

# Training parameters (`lr` is deprecated in tf.keras; use `learning_rate`).
adamOpt = Adam(learning_rate=0.0001)
model.compile(loss='mean_squared_error', optimizer=adamOpt, metrics=['mae', 'mse'])
model.summary(line_length=200)
model.save("test")

# Logging and checkpointing callbacks (currently disabled):
# log_filename = 'outputs/' + 'landmarks' + '_results.csv'
# csv_log = callbacks.CSVLogger(log_filename, separator=',', append=True)
# checkpoint = callbacks.ModelCheckpoint(checkpoint_filepath, monitor='val_loss',
#                                        verbose=1, save_best_only=True, mode='min')
# callbacks_list = [csv_log, checkpoint]
# callbacks_list.append(ReduceLROnPlateau(factor=reduceLearningRate, patience=200,
#                                         verbose=True))
# callbacks_list.append(EarlyStopping(verbose=True, patience=200))
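# --- Training sketch (assumption): one way to wire the disabled callbacks
# above into model.fit once re-enabled. X_train/y_train, checkpoint_filepath
# and reduceLearningRate are assumed to be defined elsewhere in this script;
# the batch size, epoch count and validation split are hypothetical.
callbacks_list = [
    callbacks.CSVLogger('outputs/landmarks_results.csv', separator=',', append=True),
    callbacks.ModelCheckpoint(checkpoint_filepath, monitor='val_loss',
                              verbose=1, save_best_only=True, mode='min'),
    ReduceLROnPlateau(factor=reduceLearningRate, patience=200, verbose=True),
    EarlyStopping(verbose=True, patience=200),
]
history = model.fit(X_train, y_train,
                    batch_size=16,         # hypothetical
                    epochs=1000,           # hypothetical
                    validation_split=0.1,  # hypothetical
                    callbacks=callbacks_list)
model.save('outputs/model-last_epoch.h5')  # enables the resume branch above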