def train(hparams, num_epoch, tuning):
    """Train and evaluate the ResNet DR-screening classifier.

    Runs `num_epoch` epochs of training with per-epoch validation, then a
    final pass over the test set, logging accuracy and confusion-matrix
    images to TensorBoard. When `tuning` is False it also checkpoints the
    model every 5 epochs and renders one Grad-CAM-style activation map.

    Args:
        hparams: dict with at least 'HP_BS' (batch size) and 'HP_LR'
            (Adam learning rate).
        num_epoch: number of training epochs.
        tuning: when True, skip checkpointing/restore and visualization
            (hyper-parameter search mode).

    Returns:
        Final test-set accuracy as a scalar tensor.
    """
    log_dir = './results/'
    test_batch_size = 8
    # Load dataset
    training_set, valid_set = make_dataset(BATCH_SIZE=hparams['HP_BS'],
                                           file_name='train_tf_record',
                                           split=True)
    test_set = make_dataset(BATCH_SIZE=test_batch_size,
                            file_name='test_tf_record',
                            split=False)
    # Binary task: No Referable DR vs Referable DR.
    class_names = ['NRDR', 'RDR']
    # Model
    model = ResNet()
    # set optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=hparams['HP_LR'])
    # set metrics
    # Train metric consumes integer labels vs logit/probability rows;
    # valid/test metrics compare already-argmax'ed class ids.
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.Accuracy()
    valid_con_mat = ConfusionMatrix(num_class=2)
    test_accuracy = tf.keras.metrics.Accuracy()
    test_con_mat = ConfusionMatrix(num_class=2)
    # Save Checkpoint (only outside hyper-parameter tuning runs)
    if not tuning:
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer,
                                   net=model)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)
    # Set up summary writers (one run directory per wall-clock timestamp)
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tb_log_dir = log_dir + current_time + '/train'
    summary_writer = tf.summary.create_file_writer(tb_log_dir)
    # Restore Checkpoint
    if not tuning:
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            logging.info('Restored from {}'.format(manager.latest_checkpoint))
        else:
            logging.info('Initializing from scratch.')

    @tf.function
    def train_step(train_img, train_label):
        # Optimize the model
        # NOTE(review): `grad` returns (loss, gradients); the extra forward
        # pass below is only used to update the accuracy metric, so the
        # metric reflects post-update weights.
        loss_value, grads = grad(model, train_img, train_label)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_pred, _ = model(train_img)
        # Labels to shape (batch, 1) to line up with the prediction rows.
        train_label = tf.expand_dims(train_label, axis=1)
        train_accuracy.update_state(train_label, train_pred)

    for epoch in range(num_epoch):
        begin = time()
        # Training loop
        # NOTE(review): train images are only passed through
        # data_augmentation, never cast/scaled by 255 like valid/test —
        # presumably data_augmentation handles that; TODO confirm.
        for train_img, train_label, train_name in training_set:
            train_img = data_augmentation(train_img)
            train_step(train_img, train_label)
        with summary_writer.as_default():
            tf.summary.scalar('Train Accuracy', train_accuracy.result(),
                              step=epoch)
        # Validation loop: normalize to [0, 1] and compare hard predictions.
        for valid_img, valid_label, _ in valid_set:
            valid_img = tf.cast(valid_img, tf.float32)
            valid_img = valid_img / 255.0
            valid_pred, _ = model(valid_img, training=False)
            valid_pred = tf.cast(tf.argmax(valid_pred, axis=1),
                                 dtype=tf.int64)
            valid_con_mat.update_state(valid_label, valid_pred)
            valid_accuracy.update_state(valid_label, valid_pred)
        # Log the confusion matrix as an image summary
        cm_valid = valid_con_mat.result()
        figure = plot_confusion_matrix(cm_valid, class_names=class_names)
        cm_valid_image = plot_to_image(figure)
        with summary_writer.as_default():
            tf.summary.scalar('Valid Accuracy', valid_accuracy.result(),
                              step=epoch)
            tf.summary.image('Valid ConfusionMatrix', cm_valid_image,
                             step=epoch)
        end = time()
        logging.info(
            "Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} Time:{:.5}s"
            .format(epoch + 1, train_accuracy.result(),
                    valid_accuracy.result(), (end - begin)))
        # Metrics are reset so each epoch's numbers are per-epoch, not
        # cumulative.
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        valid_con_mat.reset_states()
        if not tuning:
            # Save every 5th epoch; the step counter advances every epoch.
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                logging.info('Saved checkpoint for epoch {}: {}'.format(
                    int(ckpt.step), save_path))
            ckpt.step.assign_add(1)

    # Final evaluation on the held-out test set (same preprocessing as valid).
    for test_img, test_label, _ in test_set:
        test_img = tf.cast(test_img, tf.float32)
        test_img = test_img / 255.0
        test_pred, _ = model(test_img, training=False)
        test_pred = tf.cast(tf.argmax(test_pred, axis=1), dtype=tf.int64)
        test_accuracy.update_state(test_label, test_pred)
        test_con_mat.update_state(test_label, test_pred)
    cm_test = test_con_mat.result()
    # Log the confusion matrix as an image summary
    figure = plot_confusion_matrix(cm_test, class_names=class_names)
    cm_test_image = plot_to_image(figure)
    with summary_writer.as_default():
        # `epoch` here is the last value from the training loop above.
        tf.summary.scalar('Test Accuracy', test_accuracy.result(), step=epoch)
        tf.summary.image('Test ConfusionMatrix', cm_test_image, step=epoch)
    logging.info("Trained finished. Final Accuracy in test set: {:.3%}".format(
        test_accuracy.result()))
    # Visualization
    if not tuning:
        # One Grad-CAM-style activation map for the first test image only
        # (note the `break` at the end of the loop body).
        for vis_img, vis_label, vis_name in test_set:
            vis_label = vis_label[0]
            vis_name = vis_name[0]
            vis_img = tf.cast(vis_img[0], tf.float32)
            vis_img = tf.expand_dims(vis_img, axis=0)
            vis_img = vis_img / 255.0
            with tf.GradientTape() as tape:
                vis_pred, conv_output = model(vis_img, training=False)
                pred_label = tf.argmax(vis_pred, axis=-1)
                # Gradient is taken w.r.t. the max class score.
                vis_pred = tf.reduce_max(vis_pred, axis=-1)
            grad_1 = tape.gradient(vis_pred, conv_output)
            # Channel weights: spatial mean of the gradients.
            # NOTE(review): the extra division by grad_1.shape[1] looks like
            # an additional normalization by feature-map height — confirm
            # this is intentional.
            weight = tf.reduce_mean(grad_1, axis=[1, 2]) / grad_1.shape[1]
            # Weighted sum over channels, ReLU'd, then upsampled to 256x256.
            act_map0 = tf.nn.relu(
                tf.reduce_sum(weight * conv_output, axis=-1))
            act_map0 = tf.squeeze(tf.image.resize(tf.expand_dims(act_map0,
                                                                 axis=-1),
                                                  (256, 256), antialias=True),
                                  axis=-1)
            plot_map(vis_img, act_map0, vis_pred, pred_label, vis_label,
                     vis_name)
            break
    return test_accuracy.result()
def train(unit, dropout, learning_rate, num_epoch, tuning=True):
    """Train and evaluate the LSTM activity classifier on the HAPT dataset.

    Runs `num_epoch` epochs with per-epoch validation, then a final test
    pass, logging accuracies and confusion-matrix images to TensorBoard.
    Labels use 0 as padding; padded steps are masked out of every metric
    via `sample_weight`.

    Args:
        unit: LSTM hidden-unit count.
        dropout: dropout rate for the model.
        learning_rate: Adam learning rate.
        num_epoch: number of epochs (coerced to int for tuner callers).
        tuning: when True, skip checkpointing/restore and visualization
            (hyper-parameter search mode).

    Returns:
        Final test-set accuracy as a scalar tensor.
    """
    num_epoch = int(num_epoch)
    log_dir = './results/'
    # Load dataset
    path = os.getcwd()
    train_file = path + '/hapt_tfrecords/hapt_train.tfrecords'
    val_file = path + '/hapt_tfrecords/hapt_val.tfrecords'
    test_file = path + '/hapt_tfrecords/hapt_test.tfrecords'
    train_dataset = make_dataset(train_file, overlap=True)
    val_dataset = make_dataset(val_file)
    test_dataset = make_dataset(test_file)
    # NOTE(review): ' SIT_TO_LIE' carries a leading space in the original
    # label list; preserved byte-for-byte since it appears in rendered plots.
    class_names = [
        'WALKING', 'WALKING_UPSTAIRS', 'WALKING_DOWNSTAIRS', 'SITTING',
        'STANDING', 'LAYING', 'STAND_TO_SIT', 'SIT_TO_STAND', ' SIT_TO_LIE',
        'LIE_TO_SIT', 'STAND_TO_LIE', 'LIE_TO_STAND'
    ]
    # set a random batch number to visualize the result in test dataset.
    len_test = len(list(test_dataset))
    show_index = random.randint(10, len_test)
    # Model
    model = Lstm(unit=unit, drop_out=dropout)
    # set optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    # set Metrics
    # Train metric consumes one-hot labels; val/test metrics compare
    # 1-based integer class ids. num_class=13 = 12 classes + padding 0.
    train_accuracy = tf.keras.metrics.CategoricalAccuracy()
    val_accuracy = tf.keras.metrics.Accuracy()
    val_con_mat = ConfusionMatrix(num_class=13)
    test_accuracy = tf.keras.metrics.Accuracy()
    test_con_mat = ConfusionMatrix(num_class=13)
    # Save Checkpoint (only outside hyper-parameter tuning runs)
    if not tuning:
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer,
                                   net=model)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)
    # Set up summary writers
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tb_log_dir = log_dir + current_time
    summary_writer = tf.summary.create_file_writer(tb_log_dir)
    # Restore Checkpoint
    if not tuning:
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            logging.info('Restored from {}'.format(manager.latest_checkpoint))
        else:
            logging.info('Initializing from scratch.')

    # calculate losses, update network and metrics.
    @tf.function
    def train_step(inputs, label, sample_weight):
        # BUGFIX: `sample_weight` is now an explicit argument. It was
        # previously read as a closure variable; tf.function captures
        # closed-over tensors when the function is traced and does NOT
        # retrace when the Python variable is rebound, so every batch after
        # the first was weighted with the first batch's padding mask.
        # Optimize the model
        loss_value, grads = grad(model, inputs, label)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Extra forward pass so the metric reflects post-update weights.
        train_pred = model(inputs, training=True)
        train_pred = tf.squeeze(train_pred)
        label = tf.squeeze(label)
        train_accuracy.update_state(label, train_pred,
                                    sample_weight=sample_weight)

    for epoch in range(num_epoch):
        begin = time()
        # Training loop
        for exp_num, index, label, train_inputs in train_dataset:
            train_inputs = tf.expand_dims(train_inputs, axis=0)
            # One-hot coding is applied; raw labels are 1-based with 0 as
            # padding, so shift to 0-based and mask the former-0 entries.
            label = label - 1
            sample_weight = tf.cast(tf.math.not_equal(label, -1), tf.int64)
            label = tf.expand_dims(tf.one_hot(label, depth=12), axis=0)
            train_step(train_inputs, label, sample_weight)
        # Validation loop: mask padding (label 0) out of the metrics.
        for exp_num, index, label, val_inputs in val_dataset:
            val_inputs = tf.expand_dims(val_inputs, axis=0)
            sample_weight = tf.cast(
                tf.math.not_equal(label, tf.constant(0, dtype=tf.int64)),
                tf.int64)
            val_pred = model(val_inputs, training=False)
            val_pred = tf.squeeze(val_pred)
            # argmax is 0-based; +1 maps back to the 1-based label space.
            val_pred = tf.cast(tf.argmax(val_pred, axis=1),
                               dtype=tf.int64) + 1
            val_con_mat.update_state(label, val_pred,
                                     sample_weight=sample_weight)
            val_accuracy.update_state(label, val_pred,
                                      sample_weight=sample_weight)
        # Log the confusion matrix as an image summary
        cm_valid = val_con_mat.result()
        figure = plot_confusion_matrix(cm_valid, class_names=class_names)
        cm_valid_image = plot_to_image(figure)
        with summary_writer.as_default():
            tf.summary.scalar('Train Accuracy', train_accuracy.result(),
                              step=epoch)
            tf.summary.scalar('Valid Accuracy', val_accuracy.result(),
                              step=epoch)
            tf.summary.image('Valid ConfusionMatrix', cm_valid_image,
                             step=epoch)
        end = time()
        logging.info(
            "Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} Time:{:.5}s"
            .format(epoch + 1, train_accuracy.result(),
                    val_accuracy.result(), (end - begin)))
        # Reset so each epoch's numbers are per-epoch, not cumulative.
        train_accuracy.reset_states()
        val_accuracy.reset_states()
        val_con_mat.reset_states()
        if not tuning:
            # Save every 5th epoch; the step counter advances every epoch.
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                logging.info('Saved checkpoint for epoch {}: {}'.format(
                    int(ckpt.step), save_path))
            ckpt.step.assign_add(1)

    # Final evaluation on the held-out test set.
    i = 0
    for exp_num, index, label, test_inputs in test_dataset:
        test_inputs = tf.expand_dims(test_inputs, axis=0)
        sample_weight = tf.cast(
            tf.math.not_equal(label, tf.constant(0, dtype=tf.int64)),
            tf.int64)
        test_pred = model(test_inputs, training=False)
        test_pred = tf.cast(tf.argmax(test_pred, axis=2), dtype=tf.int64)
        test_pred = tf.squeeze(test_pred, axis=0) + 1
        test_accuracy.update_state(label, test_pred,
                                   sample_weight=sample_weight)
        test_con_mat.update_state(label, test_pred,
                                  sample_weight=sample_weight)
        i += 1
        # visualize the result for one randomly chosen test batch
        if i == show_index:
            if not tuning:
                visualization_path = path + '/visualization/'
                image_path = visualization_path + current_time + '.png'
                inputs = tf.squeeze(test_inputs)
                show(index, label, inputs, test_pred, image_path)
    # Log the confusion matrix as an image summary
    cm_test = test_con_mat.result()
    figure = plot_confusion_matrix(cm_test, class_names=class_names)
    cm_test_image = plot_to_image(figure)
    with summary_writer.as_default():
        # `epoch` here is the last value from the training loop above.
        tf.summary.scalar('Test Accuracy', test_accuracy.result(), step=epoch)
        tf.summary.image('Test ConfusionMatrix', cm_test_image, step=epoch)
    logging.info("Trained finished. Final Accuracy in test set: {:.3%}".format(
        test_accuracy.result()))
    return test_accuracy.result()