コード例 #1
0
def train_on_batch(x, y):
    """Run one optimisation step of ``our_model`` on a single batch.

    Args:
        x: Input batch passed to ``our_model``.
        y: Targets compared against the model output with ``masked_mae_tf``.

    Returns:
        Tuple ``(loss, output, y)``: the masked-MAE training loss, the model
        output for ``x`` (computed before the weight update), and ``y``
        passed through unchanged.
    """
    with tf.GradientTape() as tape:
        output = our_model(x)
        loss = masked_mae_tf(y, output)
    # Take gradients outside the tape context so the gradient computation
    # itself is not recorded on the tape.
    grads = tape.gradient(loss, our_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, our_model.trainable_variables))
    return loss, output, y
コード例 #2
0
def train():
    """Meta-train ``our_model`` with a Reptile-style outer update over city tasks.

    Weights are first restored from the latest checkpoint in
    ``checkpoint_dir``. Each epoch does the following:

    * sample one ``(task_id, dataset)`` pair from ``dataloader.train_datasets``;
    * run ``inner_k`` optimiser steps on it via ``train_on_batch``;
    * move the weights from their pre-epoch values towards the post-inner-loop
      values by ``outer_step_size * (1 - epoch / epochs)`` (linearly annealed).

    Every 1000 epochs the learning rate is set to ``lr_fn(epoch, lr_reduce, lr)``
    and ``test(is_testing=False, ...)`` is run. A per-city checkpoint
    (``checkpoint_prefix_chengdu`` / ``checkpoint_prefix_porto``) is saved
    whenever that city's returned validation score drops below its best so far.

    Raises:
        ValueError: if a sampled task id is neither 0 (Chengdu) nor 1 (Porto).
    """
    my_config.is_training = True
    chengdu_initial_loss = 1e9
    porto_initial_loss = 1e9
    print("Loading Weights from checkpoints...")
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
    print("Loading Finished.")
    test_epoch = 0
    for epoch in range(epochs):
        lr = optimizer.learning_rate.numpy()
        print("\nepoch {}/{}".format(epoch + 1, epochs))
        weights_before = our_model.get_weights()
        task_dataset = random.sample(dataloader.train_datasets, 1)[0]
        task_id = task_dataset[0]
        if task_id not in (0, 1):
            # TODO you may add more tasks here
            raise ValueError(f"Unsupported task id: {task_id!r}")
        task_name = "Chengdu" if task_id == 0 else "Porto"

        pb_i = Progbar(inner_k, stateful_metrics=metrics_names)
        for _ in range(inner_k):
            features, labels = task_dataset[1].batch()
            loss, output, y = train_on_batch(features, labels)
            # Metrics are reported on the inverse-transformed labels/outputs;
            # transform each once and reuse for all three metrics.
            y_true = dataloader.scaler.inverse_transform(y, task_id, "label")
            y_pred = dataloader.scaler.inverse_transform(output, task_id,
                                                         "label")
            mae = masked_mae_tf(y_true, y_pred)
            mape = masked_mape_tf(y_true, y_pred)
            rmse = masked_rmse_tf(y_true, y_pred)
            loss_metrics.update_state(loss)
            mae_metrics.update_state(mae)
            mape_metrics.update_state(mape)
            rmse_metrics.update_state(rmse)
            pb_i.add(1,
                     values=[("LOSS", loss), ('MAE', mae), ('MAPE', mape),
                             ('RMSE', rmse)])
        epoch_loss = loss_metrics.result()
        epoch_mae = mae_metrics.result()
        epoch_mape = mape_metrics.result()
        epoch_rmse = rmse_metrics.result()
        loss_metrics.reset_states()
        mae_metrics.reset_states()
        mape_metrics.reset_states()
        rmse_metrics.reset_states()

        # Outer (meta) update: interpolate between the pre-epoch weights and
        # the weights reached by the inner loop, with a linearly decaying step.
        weights_after = our_model.get_weights()
        outer_step_size_calcu = outer_step_size * (1 - epoch / epochs)
        our_model.set_weights([
            before + (after - before) * outer_step_size_calcu
            for before, after in zip(weights_before, weights_after)
        ])

        with train_summary_writer.as_default():
            tf.summary.scalar(f'{task_name} Loss',
                              epoch_loss.numpy(),
                              step=epoch)
            tf.summary.scalar(f'{task_name} MAE',
                              epoch_mae.numpy(),
                              step=epoch)
        print(f'Task {task_name}:')
        print(
            f"EPOCH_MAE:{epoch_mae}, EPOCH_MAPE:{epoch_mape}, EPOCH_RMSE:{epoch_rmse}"
        )

        if (epoch + 1) % 1000 == 0:
            test_epoch += 1
            changed_lr = lr_fn(epoch, lr_reduce, lr)
            print('changed_lr:', changed_lr)
            K.set_value(optimizer.lr, changed_lr)
            print('validation begin:')
            chengdu_loss, porto_loss = test(is_testing=False, epoch=test_epoch)
            # test() sets my_config.is_training = False; switch back so the
            # remaining epochs keep training in training mode.
            my_config.is_training = True
            print('validation end.')
            if chengdu_loss < chengdu_initial_loss:
                checkpoint.save(file_prefix=checkpoint_prefix_chengdu)
                chengdu_initial_loss = chengdu_loss
            if porto_loss < porto_initial_loss:
                checkpoint.save(file_prefix=checkpoint_prefix_porto)
                porto_initial_loss = porto_loss
コード例 #3
0
def test(is_testing=True, epoch=0):
    """Evaluate the model on the Chengdu and Porto datasets.

    Args:
        is_testing: If True, evaluate each city on ``dataloader.test_datasets``
            after restoring that city's latest checkpoint from
            ``checkpoint_dir + "_chengdu"`` / ``checkpoint_dir + "_porto"``.
            If False, evaluate the current weights on
            ``dataloader.val_datasets`` without restoring anything.
        epoch: Step value used for the TensorBoard summaries.

    Side effects:
        Sets ``my_config.is_training = False``. When ``test_file_str`` is not
        ``"None"``, results are appended to
        ``./experiments/results/<prefix>/<test_file_str>.txt``. Otherwise they
        are written as scalars to ``test_summary_writer``.

    Returns:
        ``(chengdu_mae, porto_mae)``: the mean per-batch masked MAE for each
        city, computed on ``dataloader.scaler.inverse_transform``-ed values.
    """
    log_to_file = test_file_str != "None"
    test_log_file = None
    if log_to_file:
        print("Testing mode on: " + test_file_str)
        test_filename = test_file_str.replace('\r', '')
        test_log_file = open(
            f"./experiments/results/{my_config.general_config['prefix']}/{test_filename}.txt",
            "a+")
    my_config.is_training = False

    def evaluate_city(city, dataset_index, checkpoint_suffix):
        """Evaluate one city's dataset, report its metrics and return its MAE."""
        if is_testing:
            print(f"Loading Weights from {city.lower()} checkpoints...")
            dataset = dataloader.test_datasets[dataset_index]
            checkpoint.restore(
                tf.train.latest_checkpoint(checkpoint_dir + checkpoint_suffix))
            print(f"Loading {city} Finished.")
        else:
            dataset = dataloader.val_datasets[dataset_index]
        print(f"{city} Begin:")
        batch_num = dataset[1].batch_num
        mae_loss = 0.0
        mape_loss = 0.0
        rmse_loss = 0.0
        pb_i = Progbar(batch_num, stateful_metrics=metrics_names)
        for _ in range(batch_num):
            x, y = dataset[1].batch()
            output = test_on_batch(x, y)
            # Transform labels and predictions once; reuse for every metric.
            y_true = dataloader.scaler.inverse_transform(y, dataset[0], "label")
            y_pred = dataloader.scaler.inverse_transform(output, dataset[0],
                                                         "label")
            mae = masked_mae_tf(y_true, y_pred)
            mape = masked_mape_tf(y_true, y_pred)
            rmse = masked_rmse_tf(y_true, y_pred)
            mae_loss += mae
            mape_loss += mape
            rmse_loss += rmse
            pb_i.add(1, values=[('MAE', mae), ('MAPE', mape), ('RMSE', rmse)])
        epoch_mae = mae_loss / batch_num
        epoch_mape = mape_loss / batch_num
        epoch_rmse = rmse_loss / batch_num
        if test_log_file is None:
            with test_summary_writer.as_default():
                tf.summary.scalar(f'{city} Test MAE', epoch_mae, step=epoch)
                tf.summary.scalar(f'{city} Test MAPE', epoch_mape, step=epoch)
                tf.summary.scalar(f'{city} Test RMSE', epoch_rmse, step=epoch)
        else:
            test_log_file.write(
                f"{city}: EPOCH_MAE: {epoch_mae}, EPOCH_MAPE: {epoch_mape}, EPOCH_RMSE: {epoch_rmse}\n"
            )
        print(
            f"EPOCH_MAE: {epoch_mae}, EPOCH_MAPE: {epoch_mape}, EPOCH_RMSE: {epoch_rmse}"
        )
        print(f"{city} End.")
        return epoch_mae

    # TODO This part can be refactored to a more reasonable pattern if more datasets are added
    try:
        chengdu_epoch_mae = evaluate_city("Chengdu", 0, "_chengdu")
        porto_epoch_mae = evaluate_city("Porto", 1, "_porto")
    finally:
        # Close the results file even if evaluation raises.
        if test_log_file is not None:
            test_log_file.close()
    if log_to_file:
        print("Testing mode end")
    return chengdu_epoch_mae, porto_epoch_mae
コード例 #4
0
ファイル: main.py プロジェクト: anonymous-repo-21/STMTTE
def test(is_testing=True, epoch=0):
    """Evaluate the model on the first test or validation dataset.

    Args:
        is_testing: If True, restore the latest checkpoint from
            ``checkpoint_dir`` and evaluate on ``dataloader.test_datasets[0]``.
            If False, evaluate the current weights on
            ``dataloader.val_datasets[0]``.
        epoch: Step value used for the TensorBoard summaries.

    Side effects:
        Sets ``my_config.is_training = False``. When ``test_file_str`` is not
        ``"None"``, creates the results directory if needed and appends the
        metrics to ``<results>/<prefix>/<test_file_str>.txt``. Otherwise the
        metrics are written as scalars to ``test_summary_writer``.

    Returns:
        The mean per-batch masked MAE, computed on
        ``dataloader.scaler.inverse_transform``-ed values.
    """
    log_to_file = test_file_str != "None"
    test_log_file = None
    if log_to_file:
        print("Testing mode on: " + test_file_str)
        test_filename = test_file_str.replace('\r', '')
        results_dir = f"/public/lhy/wms/wms-codebase/tte-single-city/results/{my_config.general_config['prefix']}/"
        # makedirs also creates missing parent directories and does not fail
        # if the directory already exists (e.g. created concurrently).
        os.makedirs(results_dir, exist_ok=True)
        test_log_file = open(f"{results_dir}{test_filename}.txt", "a+")
    my_config.is_training = False
    try:
        mae_loss = 0.0
        mape_loss = 0.0
        rmse_loss = 0.0
        if is_testing:
            print("Loading Weights from chengdu checkpoints...")
            dataset = dataloader.test_datasets[0]
            checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
            print("Loading Finished.")
        else:
            dataset = dataloader.val_datasets[0]
        batch_num = dataset[1].batch_num
        pb_i = Progbar(batch_num, stateful_metrics=metrics_names)
        for _ in range(batch_num):
            x, y = dataset[1].batch()
            output = test_on_batch(x, y)
            # NOTE(review): `which` is not defined in this function; presumably
            # a module-level global selecting the city for this single-city
            # run (the two-city variant passes dataset[0] here) — verify.
            y_true = dataloader.scaler.inverse_transform(y, which, "label")
            y_pred = dataloader.scaler.inverse_transform(output, which, "label")
            mae = masked_mae_tf(y_true, y_pred)
            mape = masked_mape_tf(y_true, y_pred)
            rmse = masked_rmse_tf(y_true, y_pred)
            mae_loss += mae
            mape_loss += mape
            rmse_loss += rmse
            pb_i.add(1, values=[('MAE', mae), ('MAPE', mape), ('RMSE', rmse)])
        epoch_mae = mae_loss / batch_num
        epoch_mape = mape_loss / batch_num
        epoch_rmse = rmse_loss / batch_num
        if test_log_file is None:
            with test_summary_writer.as_default():
                tf.summary.scalar('Test MAE', epoch_mae, step=epoch)
                tf.summary.scalar('Test MAPE', epoch_mape, step=epoch)
                tf.summary.scalar('Test RMSE', epoch_rmse, step=epoch)
        else:
            test_log_file.write(
                f"EPOCH_MAE: {epoch_mae}, EPOCH_MAPE: {epoch_mape}, EPOCH_RMSE: {epoch_rmse}\n"
            )
        print(
            f"EPOCH_MAE: {epoch_mae}, EPOCH_MAPE: {epoch_mape}, EPOCH_RMSE: {epoch_rmse}"
        )
    finally:
        # Close the results file even if evaluation raises.
        if test_log_file is not None:
            test_log_file.close()
    if log_to_file:
        print("Testing mode end")
    return epoch_mae