def train(hparams, num_epoch, tuning):
    """Train a ResNet classifier (NRDR vs RDR) and evaluate it.

    Trains for ``num_epoch`` epochs, logging train/valid accuracy and a
    validation confusion matrix to TensorBoard each epoch, then evaluates
    on the held-out test set. When ``tuning`` is False it also saves and
    restores checkpoints and renders one class-activation-map
    visualization for a single test image.

    Args:
        hparams: dict of hyper-parameters; reads 'HP_BS' (batch size)
            and 'HP_LR' (Adam learning rate).
        num_epoch: number of training epochs. NOTE(review): assumed >= 1;
            ``epoch`` is reused after the loop for the test summaries and
            would be unbound if the loop never ran.
        tuning: True during hyper-parameter search; disables
            checkpointing and visualization.

    Returns:
        Final test-set accuracy (scalar tensor).
    """

    log_dir = './results/'
    test_batch_size = 8
    # Load dataset: train/valid come from one TFRecord with split=True,
    # test from its own TFRecord.
    training_set, valid_set = make_dataset(BATCH_SIZE=hparams['HP_BS'],
                                           file_name='train_tf_record',
                                           split=True)
    test_set = make_dataset(BATCH_SIZE=test_batch_size,
                            file_name='test_tf_record',
                            split=False)
    class_names = ['NRDR', 'RDR']

    # Model
    model = ResNet()

    # set optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=hparams['HP_LR'])
    # set metrics: training compares integer labels against the model's
    # class scores; valid/test compare integer labels against argmax ids.
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.Accuracy()
    valid_con_mat = ConfusionMatrix(num_class=2)
    test_accuracy = tf.keras.metrics.Accuracy()
    test_con_mat = ConfusionMatrix(num_class=2)

    # Save Checkpoint (skipped entirely while tuning)
    if not tuning:
        ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                   optimizer=optimizer,
                                   net=model)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)

    # Set up summary writers (one run directory per wall-clock timestamp)
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tb_log_dir = log_dir + current_time + '/train'
    summary_writer = tf.summary.create_file_writer(tb_log_dir)

    # Restore Checkpoint (no-op restore if none exists yet)
    if not tuning:
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            logging.info('Restored from {}'.format(manager.latest_checkpoint))
        else:
            logging.info('Initializing from scratch.')

    @tf.function
    def train_step(train_img, train_label):
        # Optimize the model. NOTE(review): the model runs a second
        # forward pass below solely to update the accuracy metric.
        loss_value, grads = grad(model, train_img, train_label)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_pred, _ = model(train_img)
        train_label = tf.expand_dims(train_label, axis=1)
        train_accuracy.update_state(train_label, train_pred)

    for epoch in range(num_epoch):

        begin = time()

        # Training loop; preprocessing happens inside data_augmentation.
        for train_img, train_label, train_name in training_set:
            train_img = data_augmentation(train_img)
            train_step(train_img, train_label)

        with summary_writer.as_default():
            tf.summary.scalar('Train Accuracy',
                              train_accuracy.result(),
                              step=epoch)

        # Validation loop: images are scaled to [0, 1] here, unlike the
        # training path which relies on data_augmentation.
        for valid_img, valid_label, _ in valid_set:
            valid_img = tf.cast(valid_img, tf.float32)
            valid_img = valid_img / 255.0
            valid_pred, _ = model(valid_img, training=False)
            valid_pred = tf.cast(tf.argmax(valid_pred, axis=1), dtype=tf.int64)
            valid_con_mat.update_state(valid_label, valid_pred)
            valid_accuracy.update_state(valid_label, valid_pred)

        # Log the confusion matrix as an image summary
        cm_valid = valid_con_mat.result()
        figure = plot_confusion_matrix(cm_valid, class_names=class_names)
        cm_valid_image = plot_to_image(figure)

        with summary_writer.as_default():
            tf.summary.scalar('Valid Accuracy',
                              valid_accuracy.result(),
                              step=epoch)
            tf.summary.image('Valid ConfusionMatrix',
                             cm_valid_image,
                             step=epoch)

        end = time()
        logging.info(
            "Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} Time:{:.5}s"
            .format(epoch + 1, train_accuracy.result(),
                    valid_accuracy.result(), (end - begin)))
        # Reset per-epoch metric state before the next epoch.
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        valid_con_mat.reset_states()
        if not tuning:
            # ckpt.step counts epochs; save a checkpoint every 5th one.
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                logging.info('Saved checkpoint for epoch {}: {}'.format(
                    int(ckpt.step), save_path))
            ckpt.step.assign_add(1)

    # Final evaluation on the test set.
    for test_img, test_label, _ in test_set:
        test_img = tf.cast(test_img, tf.float32)
        test_img = test_img / 255.0
        test_pred, _ = model(test_img, training=False)
        test_pred = tf.cast(tf.argmax(test_pred, axis=1), dtype=tf.int64)
        test_accuracy.update_state(test_label, test_pred)
        test_con_mat.update_state(test_label, test_pred)

    cm_test = test_con_mat.result()
    # Log the confusion matrix as an image summary
    figure = plot_confusion_matrix(cm_test, class_names=class_names)
    cm_test_image = plot_to_image(figure)
    with summary_writer.as_default():
        # NOTE(review): 'epoch' here is the leaked loop variable from the
        # training loop; NameError if num_epoch == 0.
        tf.summary.scalar('Test Accuracy', test_accuracy.result(), step=epoch)
        tf.summary.image('Test ConfusionMatrix', cm_test_image, step=epoch)

    logging.info("Trained finished. Final Accuracy in test set: {:.3%}".format(
        test_accuracy.result()))

    # Visualization: class-activation map for the first image of the
    # first test batch only (the loop breaks after one iteration).
    if not tuning:
        for vis_img, vis_label, vis_name in test_set:
            vis_label = vis_label[0]
            vis_name = vis_name[0]
            vis_img = tf.cast(vis_img[0], tf.float32)
            vis_img = tf.expand_dims(vis_img, axis=0)
            vis_img = vis_img / 255.0
            with tf.GradientTape() as tape:
                vis_pred, conv_output = model(vis_img, training=False)
                pred_label = tf.argmax(vis_pred, axis=-1)
                vis_pred = tf.reduce_max(vis_pred, axis=-1)
                # Gradient of the top class score w.r.t. the conv feature
                # map; channel weights are the spatially averaged
                # gradients. NOTE(review): the extra '/ grad_1.shape[1]'
                # scaling is unusual for Grad-CAM -- presumably a
                # deliberate normalization; confirm against plot_map.
                grad_1 = tape.gradient(vis_pred, conv_output)
                weight = tf.reduce_mean(grad_1, axis=[1, 2]) / grad_1.shape[1]
                act_map0 = tf.nn.relu(
                    tf.reduce_sum(weight * conv_output, axis=-1))
                act_map0 = tf.squeeze(tf.image.resize(tf.expand_dims(act_map0,
                                                                     axis=-1),
                                                      (256, 256),
                                                      antialias=True),
                                      axis=-1)
                plot_map(vis_img, act_map0, vis_pred, pred_label, vis_label,
                         vis_name)
            break

    return test_accuracy.result()
# Example #2
def train(unit, dropout, learning_rate, num_epoch, tuning=True):
    """Train an LSTM activity classifier on the HAPT TFRecord dataset.

    Runs ``num_epoch`` epochs of training with per-epoch validation,
    logs accuracies and a validation confusion matrix to TensorBoard,
    then evaluates on the test set. When ``tuning`` is False it also
    checkpoints the model every 5 epochs and saves one randomly chosen
    test-sequence visualization under ./visualization/.

    Args:
        unit: number of LSTM units, forwarded to the Lstm model.
        dropout: dropout rate, forwarded to the Lstm model.
        learning_rate: Adam learning rate.
        num_epoch: number of training epochs; coerced with int().
        tuning: True during hyper-parameter search; disables
            checkpointing and visualization.

    Returns:
        Final test-set accuracy (scalar tensor).
    """

    num_epoch = int(num_epoch)
    log_dir = './results/'

    # Load dataset
    path = os.getcwd()
    train_file = path + '/hapt_tfrecords/hapt_train.tfrecords'
    val_file = path + '/hapt_tfrecords/hapt_val.tfrecords'
    test_file = path + '/hapt_tfrecords/hapt_test.tfrecords'

    train_dataset = make_dataset(train_file, overlap=True)
    val_dataset = make_dataset(val_file)
    test_dataset = make_dataset(test_file)
    class_names = [
        'WALKING', 'WALKING_UPSTAIRS', 'WALKING_DOWNSTAIRS', 'SITTING',
        'STANDING', 'LAYING', 'STAND_TO_SIT', 'SIT_TO_STAND', ' SIT_TO_LIE',
        'LIE_TO_SIT', 'STAND_TO_LIE', 'LIE_TO_STAND'
    ]

    # set a random batch number to visualize the result in test dataset.
    len_test = len(list(test_dataset))
    show_index = random.randint(10, len_test)

    # Model
    model = Lstm(unit=unit, drop_out=dropout)

    # set optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # set Metrics: training compares one-hot labels against probabilities,
    # validation/test compare integer labels against argmax predictions.
    train_accuracy = tf.keras.metrics.CategoricalAccuracy()
    val_accuracy = tf.keras.metrics.Accuracy()
    val_con_mat = ConfusionMatrix(num_class=13)
    test_accuracy = tf.keras.metrics.Accuracy()
    test_con_mat = ConfusionMatrix(num_class=13)

    # Save Checkpoint (skipped entirely while tuning)
    if not tuning:
        ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                                   optimizer=optimizer,
                                   net=model)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)

    # Set up summary writers
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tb_log_dir = log_dir + current_time
    summary_writer = tf.summary.create_file_writer(tb_log_dir)

    # Restore Checkpoint (no-op restore if none exists yet)
    if not tuning:
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            logging.info('Restored from {}'.format(manager.latest_checkpoint))
        else:
            logging.info('Initializing from scratch.')

    # calculate losses, update network and metrics.
    # BUG FIX: sample_weight is now an explicit argument. The original
    # read it from the enclosing scope inside @tf.function, so the tensor
    # captured at trace time could be silently reused on later calls
    # instead of the current batch's padding mask.
    @tf.function
    def train_step(inputs, label, sample_weight):
        # Optimize the model. NOTE(review): the model runs a second
        # forward pass below solely to update the accuracy metric.
        loss_value, grads = grad(model, inputs, label)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_pred = model(inputs, training=True)
        train_pred = tf.squeeze(train_pred)
        label = tf.squeeze(label)
        train_accuracy.update_state(label,
                                    train_pred,
                                    sample_weight=sample_weight)

    for epoch in range(num_epoch):
        begin = time()

        # Training loop
        for exp_num, index, label, train_inputs in train_dataset:
            train_inputs = tf.expand_dims(train_inputs, axis=0)
            # Labels are 1-based; shift to 0-based for one-hot coding and
            # give padding entries (original label 0 -> -1) zero weight.
            label = label - 1
            sample_weight = tf.cast(tf.math.not_equal(label, -1), tf.int64)
            label = tf.expand_dims(tf.one_hot(label, depth=12), axis=0)
            train_step(train_inputs, label, sample_weight)

        # Validation loop: padding entries (label 0) are masked out.
        for exp_num, index, label, val_inputs in val_dataset:
            val_inputs = tf.expand_dims(val_inputs, axis=0)
            sample_weight = tf.cast(
                tf.math.not_equal(label, tf.constant(0, dtype=tf.int64)),
                tf.int64)
            val_pred = model(val_inputs, training=False)
            val_pred = tf.squeeze(val_pred)
            # Back to the dataset's 1-based label space for comparison.
            val_pred = tf.cast(tf.argmax(val_pred, axis=1), dtype=tf.int64) + 1
            val_con_mat.update_state(label,
                                     val_pred,
                                     sample_weight=sample_weight)
            val_accuracy.update_state(label,
                                      val_pred,
                                      sample_weight=sample_weight)
        # Log the confusion matrix as an image summary
        cm_valid = val_con_mat.result()
        figure = plot_confusion_matrix(cm_valid, class_names=class_names)
        cm_valid_image = plot_to_image(figure)

        with summary_writer.as_default():
            tf.summary.scalar('Train Accuracy',
                              train_accuracy.result(),
                              step=epoch)
            tf.summary.scalar('Valid Accuracy',
                              val_accuracy.result(),
                              step=epoch)
            tf.summary.image('Valid ConfusionMatrix',
                             cm_valid_image,
                             step=epoch)
        end = time()
        logging.info(
            "Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} Time:{:.5}s"
            .format(epoch + 1, train_accuracy.result(), val_accuracy.result(),
                    (end - begin)))

        # Reset per-epoch metric state before the next epoch.
        train_accuracy.reset_states()
        val_accuracy.reset_states()
        val_con_mat.reset_states()

        if not tuning:
            # ckpt.step counts epochs; save a checkpoint every 5th one.
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                logging.info('Saved checkpoint for epoch {}: {}'.format(
                    int(ckpt.step), save_path))
            ckpt.step.assign_add(1)

    # Final evaluation on the test set; same masking scheme as validation.
    i = 0
    for exp_num, index, label, test_inputs in test_dataset:
        test_inputs = tf.expand_dims(test_inputs, axis=0)
        sample_weight = tf.cast(
            tf.math.not_equal(label, tf.constant(0, dtype=tf.int64)), tf.int64)
        test_pred = model(test_inputs, training=False)
        test_pred = tf.cast(tf.argmax(test_pred, axis=2), dtype=tf.int64)
        test_pred = tf.squeeze(test_pred, axis=0) + 1
        test_accuracy.update_state(label,
                                   test_pred,
                                   sample_weight=sample_weight)
        test_con_mat.update_state(label,
                                  test_pred,
                                  sample_weight=sample_weight)
        i += 1

        # visualize the result for one randomly selected test sequence
        if i == show_index:
            if not tuning:
                visualization_path = path + '/visualization/'
                image_path = visualization_path + current_time + '.png'
                inputs = tf.squeeze(test_inputs)
                show(index, label, inputs, test_pred, image_path)

    # Log the confusion matrix as an image summary
    cm_test = test_con_mat.result()
    figure = plot_confusion_matrix(cm_test, class_names=class_names)
    cm_test_image = plot_to_image(figure)

    # FIX: derive the final step from num_epoch instead of reusing the
    # leaked loop variable 'epoch', which was unbound when num_epoch == 0.
    # For num_epoch >= 1 the value (num_epoch - 1) is identical.
    final_epoch = max(num_epoch - 1, 0)
    with summary_writer.as_default():
        tf.summary.scalar('Test Accuracy',
                          test_accuracy.result(),
                          step=final_epoch)
        tf.summary.image('Test ConfusionMatrix',
                         cm_test_image,
                         step=final_epoch)

    logging.info("Trained finished. Final Accuracy in test set: {:.3%}".format(
        test_accuracy.result()))

    return test_accuracy.result()