コード例 #1
0
def main():
    """Train a DBM on samples drawn from a generative DBM and save the results.

    Relies on a module-level ``args`` namespace (defined outside this function)
    providing ``learning_config``, ``learning_epoch``, ``output_directory`` and
    ``filename_suffix``.

    Steps:
        1. Load the JSON learning configuration.
        2. Build the model to be trained (``dbm``) and a generative model
           (``gen_dbm``) from the same layer spec but different model args.
        3. Draw ``config["data_size"]`` samples from ``gen_dbm`` as training data.
        4. Instantiate the optimizer named by ``config["optimizer"]`` from
           ``mltools.optimizer``.
        5. Train, then write the model parameters and the learning log as JSON
           files named ``<timestamp>[_<suffix>]_model.json`` / ``_log.json``
           inside ``args.output_directory``.
    """
    logging.debug("Loading configuration file.")
    # Use a context manager so the file handle is closed deterministically
    # rather than being left open until garbage collection.
    with open(args.learning_config, "r") as config_file:
        config = json.load(config_file)

    logging.debug("Loading Initial/Generative models.")
    # Both models share the layer structure; only their constructor kwargs differ.
    dbm = DBM(config["layers"], **config["initial_model_args"])
    gen_dbm = DBM(config["layers"], **config["generative_model_args"])

    logging.debug("Loading learning data.")
    learning_data = Data(gen_dbm.sampling(sampling_num=config["data_size"]))

    # Lazy %-style arguments: the message is only formatted if it is emitted.
    logging.info("Optimizer: %s", config["optimizer"])
    optimizer_class = getattr(mltools.optimizer, config["optimizer"])
    optimizer = optimizer_class(**config["optimizer_args"])

    # Run-level settings recorded alongside the config. Keys from `config`
    # override these on collision because update() is applied afterwards.
    setting_log = {
        "learning_epoch": args.learning_epoch,
        "learning_configfile": args.learning_config,
        "traindata_size": len(learning_data),
    }
    setting_log.update(config)
    learning_log = LearningLog(setting_log)

    logging.info("Train started.")
    # gen_dbm is also handed to train(); presumably as a reference model for
    # evaluation at each test_interval -- TODO confirm against DBM.train.
    dbm.train(learning_data,
              args.learning_epoch,
              gen_dbm,
              learning_log,
              optimizer,
              minibatch_size=config["minibatch_size"],
              test_interval=config["test_interval"])
    logging.info("Train ended.")

    # Output file stem: "<timestamp>" or "<timestamp>_<suffix>".
    timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
    suffix = "_" + args.filename_suffix if args.filename_suffix is not None else ""
    stem = timestamp + suffix
    model_filepath = os.path.join(args.output_directory,
                                  "{}_model.json".format(stem))
    log_filepath = os.path.join(args.output_directory,
                                "{}_log.json".format(stem))

    dbm.save(model_filepath)
    logging.info("Model parameters were dumped to: %s", model_filepath)
    learning_log.save(log_filepath)
    logging.info("Learning log was dumped to: %s", log_filepath)
コード例 #2
0
import argparse
import datetime
import json
import os

import tensorflow as tf
from tensorflow.keras.utils import to_categorical

from DRBM import DRBM
from mltools import LearningLog

# Command-line interface. The first positional parameter of ArgumentParser is
# `prog`, so the text is passed explicitly as `description` instead; otherwise
# usage/error lines would read "usage: DRBM learning script. ...".
parser = argparse.ArgumentParser(description="DRBM learning script.", add_help=False)
parser.add_argument("learning_config",
                    action="store",
                    type=str,
                    help="path of learning configuration file.")
parser.add_argument("learning_epoch",
                    action="store",
                    type=int,
                    help="numbers of epochs.")
parser.add_argument("-d",
                    "--output_directory",
                    action="store",
                    type=str,
                    default="./results/",
                    help="directory to output parameter & log")
parser.add_argument("-s",
                    "--filename_suffix",
                    action="store",
                    type=str,
                    default=None,
                    help="filename suffix")
args = parser.parse_args()

# Close the configuration file as soon as it has been parsed.
with open(args.learning_config, "r") as config_file:
    config = json.load(config_file)
ll = LearningLog(config)

dtype = config["dtype"]

# Synthetic training data: sample (x, y) pairs from a randomly-biased
# generative DRBM, then one-hot encode the labels.
gen_drbm = DRBM(*config["generative-layers"], **config["generative-args"], dtype=dtype, random_bias=True)
x_train, y_train = gen_drbm.stick_break(config["datasize"])
y_train = to_categorical(y_train, dtype=dtype)

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
optimizer = tf.keras.optimizers.Adamax(learning_rate=0.002, epsilon=1e-8)

# Fit the model being trained; gen_drbm and the LearningLog are also passed in
# (presumably for evaluation/logging during training -- TODO confirm).
drbm = DRBM(*config["training-layers"], **config["training-args"], dtype=dtype)
drbm.fit_generative(args.learning_epoch, config["datasize"], config["minibatch-size"], optimizer, train_ds, gen_drbm, ll)

now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
filename = [
コード例 #3
0
parser.add_argument("-d",
                    "--output_directory",
                    action="store",
                    type=str,
                    default="./results/",
                    help="directory to output parameter & log")
parser.add_argument("-s",
                    "--filename_suffix",
                    action="store",
                    type=str,
                    default=None,
                    help="filename suffix")
args = parser.parse_args()

# Close the configuration file as soon as it has been parsed.
with open(args.learning_config, "r") as config_file:
    config = json.load(config_file)
ll = LearningLog(config)

# Split column 0 (label) from the remaining columns (features).
y, x = np.split(np.loadtxt("./urban.txt", delimiter=","), [1], 1)

# Standardise each feature column to zero mean and unit variance. A constant
# column has zero std; divide by 1 there so it becomes all zeros, not NaN.
# NOTE(review): statistics are computed on the full dataset before the
# train/test split below, so test data leaks into the normalisation.
x_mean = np.mean(x, 0)
x_std = np.std(x, 0)
x_std = np.where(x_std == 0, 1.0, x_std)
x = (x - x_mean) / x_std

# One-hot encode the labels.
y = to_categorical(y)

dtype = config["dtype"]
x = x.astype(dtype)
y = y.astype(dtype)

x_train, x_test, y_train, y_test = train_test_split(x,
                                                    y,