예제 #1
0
def train(conf_dict):
    """Build the training graph described by ``conf_dict`` and run the trainer.

    Keys read from ``conf_dict``: ``training_mode``, ``net_py``, ``net_class``,
    ``loss_py``, ``loss_class`` and ``learning_rate``. The whole dict is also
    passed to the network constructor, the data feed and
    ``controler.run_trainer``.

    Supported ``training_mode`` values:

    * ``"pointwise"``: the net predicts on each (left, right) input pair and
      the loss compares that prediction with the label from the data feed.
    * ``"pairwise"``: the net scores (left, right) and (left, negative) pairs
      and the loss compares the positive score with the negative score.

    Raises:
        SystemExit: with status 1 (after printing an error to stderr) when
            ``training_mode`` is not one of the supported values.
    """
    training_mode = conf_dict["training_mode"]
    if training_mode not in ("pointwise", "pairwise"):
        # Reject unknown modes up front, before importing/constructing the
        # network, and report the error on stderr (not stdout).
        print("training mode not supported", file=sys.stderr)
        sys.exit(1)

    # utility.import_object presumably resolves class `net_class` from module
    # `net_py`; the resulting class is instantiated with the full config.
    net = utility.import_object(
        conf_dict["net_py"], conf_dict["net_class"])(conf_dict)

    if training_mode == "pointwise":
        datafeed = datafeeds.TFPointwisePaddingData(conf_dict)
        input_l, input_r, label_y = datafeed.ops()
        pred = net.predict(input_l, input_r)
        # Adds a named softmax op to the graph (presumably so it can be
        # fetched by name later); the Python handle is not needed here.
        tf.nn.softmax(pred, -1, name="output_prob")
        loss_layer = utility.import_object(
            conf_dict["loss_py"], conf_dict["loss_class"])()
        loss = loss_layer.ops(pred, label_y)
    else:  # "pairwise"
        datafeed = datafeeds.TFPairwisePaddingData(conf_dict)
        input_l, input_r, neg_input = datafeed.ops()
        pos_score = net.predict(input_l, input_r)
        # NOTE(review): "output_preb" differs from the pointwise "output_prob"
        # and looks like a typo, but other code may look the node up by this
        # exact name -- confirm before renaming.
        tf.identity(pos_score, name="output_preb")
        neg_score = net.predict(input_l, neg_input)
        loss_layer = utility.import_object(
            conf_dict["loss_py"], conf_dict["loss_class"])(conf_dict)
        loss = loss_layer.ops(pos_score, neg_score)

    # learning_rate may be stored as a string in the config, hence float().
    lr = float(conf_dict["learning_rate"])
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

    # Hand the loss tensor and the training op to the trainer loop.
    controler.run_trainer(loss, optimizer, conf_dict)
예제 #2
0
def train(conf_dict):
    """Build a pointwise training graph from ``conf_dict`` and run the trainer.

    Reads ``net_py``/``net_class`` (network), ``loss_py``/``loss_class``
    (loss) and ``learning_rate`` from ``conf_dict``. The full dict is also
    passed to the network constructor, the pointwise data feed and
    ``controler.run_trainer``.
    """
    # Clear the TF1-style default graph before building a fresh one.
    tf.compat.v1.reset_default_graph()

    # Resolve and instantiate the configured network class.
    net_cls = utility.import_object(conf_dict["net_py"], conf_dict["net_class"])
    model = net_cls(conf_dict)

    # Input pipeline -> prediction.
    feed = datafeeds.TFPointwisePaddingData(conf_dict)
    left, right, labels = feed.ops()
    prediction = model.predict(left, right)

    # Resolve the configured loss class (constructed without arguments).
    loss_cls = utility.import_object(conf_dict["loss_py"],
                                     conf_dict["loss_class"])
    loss = loss_cls().ops(prediction, labels)

    # Adam training op; learning_rate may be stored as a string.
    learning_rate = float(conf_dict["learning_rate"])
    adam = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = adam.minimize(loss)

    controler.run_trainer(loss, train_op, conf_dict)