예제 #1
0
파일: train.py 프로젝트: chnadell/dlmCN
def main(flags):
    """Build the input pipeline and CNN described by ``flags``, then train it.

    Steps, in order:
      1. pick the network's final output width (last transposed-conv dim if
         any transposed-conv layers are configured, otherwise the last
         fully-connected width);
      2. create the train/validation input pipeline with
         ``data_reader.read_data``;
      3. build ``network_maker.CnnNetwork`` around ``utils.my_model_fn_tens``;
      4. create the training-loss, learning-rate and validation hooks
         (in that order);
      5. run ``ntwk.train`` for ``flags.train_step`` steps with those hooks.

    Args:
        flags: run configuration object; every hyper-parameter used below is
            read from it as an attribute.
    """
    # An empty tconv_dims means the network ends in a fully-connected layer.
    last_layer_dims = flags.tconv_dims if len(flags.tconv_dims) > 0 else flags.fc_filters
    output_size = last_layer_dims[-1]

    # The reader is asked for 2 * clip fewer outputs than the network width,
    # while the network itself receives ``clip``; presumably the network trims
    # ``clip`` points from each end of its output — confirm in network_maker.
    reader_kwargs = dict(
        input_size=flags.input_size,
        output_size=output_size - 2 * flags.clip,
        x_range=flags.x_range,
        y_range=flags.y_range,
        cross_val=flags.cross_val,
        val_fold=flags.val_fold,
        batch_size=flags.batch_size,
        shuffle_size=flags.shuffle_size,
    )
    features, labels, train_init_op, valid_init_op = data_reader.read_data(**reader_kwargs)

    # Layer / regularisation settings and the learning-rate schedule.
    architecture = dict(
        clip=flags.clip,
        fc_filters=flags.fc_filters,
        tconv_Fnums=flags.tconv_Fnums,
        tconv_dims=flags.tconv_dims,
        tconv_filters=flags.tconv_filters,
        n_filter=flags.n_filter,
        n_branch=flags.n_branch,
        reg_scale=flags.reg_scale,
    )
    schedule = dict(
        learn_rate=flags.learn_rate,
        decay_step=flags.decay_step,
        decay_rate=flags.decay_rate,
    )
    ntwk = network_maker.CnnNetwork(features, labels, utils.my_model_fn_tens,
                                    flags.batch_size, **architecture, **schedule)

    # Every hook writes summaries into the network's checkpoint directory.
    log_kwargs = dict(ckpt_dir=ntwk.ckpt_dir, write_summary=True)
    train_hook = network_helper.TrainValueHook(flags.verb_step, ntwk.loss, **log_kwargs)
    lr_hook = network_helper.TrainValueHook(flags.verb_step, ntwk.learn_rate,
                                            **log_kwargs, value_name='learning_rate')
    valid_hook = network_helper.ValidationHook(flags.eval_step, valid_init_op,
                                               ntwk.labels, ntwk.logits, ntwk.loss,
                                               ntwk.preconv, ntwk.preTconv,
                                               **log_kwargs)

    hooks = [train_hook, valid_hook, lr_hook]
    ntwk.train(train_init_op, flags.train_step, hooks, write_summary=True)
예제 #2
0
def main(flags):
    """Create a ``DataReader`` pipeline and a CNN from ``flags`` and train it.

    The network's output width is the last transposed-conv dimension, or the
    last fully-connected width when no transposed-conv layers are configured.
    Training runs for ``flags.train_step`` steps with a training-loss hook,
    a validation hook and a learning-rate hook.

    Args:
        flags: run configuration object; all hyper-parameters and the
            ``train_file`` / ``valid_file`` paths are read from it.
    """
    output_size = (flags.fc_filters if len(flags.tconv_dims) == 0
                   else flags.tconv_dims)[-1]

    # Input pipeline: reader configuration, then the (train, valid) files.
    reader = data_reader.DataReader(**{
        'input_size': flags.input_size,
        'output_size': output_size,
        'x_range': flags.x_range,
        'y_range': flags.y_range,
        'cross_val': flags.cross_val,
        'val_fold': flags.val_fold,
        'batch_size': flags.batch_size,
        'shuffle_size': flags.shuffle_size,
    })
    data_files = (flags.train_file, flags.valid_file)
    features, labels, train_init_op, valid_init_op = \
        reader.get_data_holder_and_init_op(data_files)

    ntwk = network_maker.CnnNetwork(
        features, labels, utils.my_model_fn, flags.batch_size,
        fc_filters=flags.fc_filters,
        tconv_dims=flags.tconv_dims,
        tconv_filters=flags.tconv_filters,
        learn_rate=flags.learn_rate,
        decay_step=flags.decay_step,
        decay_rate=flags.decay_rate)

    # Monitoring hooks, all writing summaries under the checkpoint directory.
    ckpt_dir = ntwk.ckpt_dir
    train_hook = network_helper.TrainValueHook(
        flags.verb_step, ntwk.loss, ckpt_dir=ckpt_dir, write_summary=True)
    lr_hook = network_helper.TrainValueHook(
        flags.verb_step, ntwk.learn_rate, ckpt_dir=ckpt_dir, write_summary=True,
        value_name='learning_rate')
    valid_hook = network_helper.ValidationHook(
        flags.eval_step, valid_init_op, ntwk.labels, ntwk.logits, ntwk.loss,
        ckpt_dir=ckpt_dir, write_summary=True)

    ntwk.train(train_init_op, flags.train_step,
               [train_hook, valid_hook, lr_hook], write_summary=True)
예제 #3
0
파일: train.py 프로젝트: PL187/idlm_Ben
def get_hook_list(flags,
                  ntwk,
                  valid_init_op,
                  losses,
                  loss_names,
                  forward_or_backward_str,
                  detail_train_loss=True,
                  summary_op=None):
    """Assemble the monitoring hooks for one training phase.

    Args:
        flags: run configuration; ``verb_step``, ``write_weight_step``,
            ``eval_step`` and ``stop_threshold`` are read here.
        ntwk: the built network; ``ckpt_dir``, ``merged_summary_op``,
            ``labels``, ``logits`` and ``mse_loss`` are read here.
        valid_init_op: op handed to the validation hook (presumably the
            validation-iterator initializer — confirm with the caller).
        losses: loss tensors to log one-by-one when ``detail_train_loss``.
        loss_names: display names paired positionally with ``losses``; as
            with ``zip``, surplus entries in the longer sequence are ignored.
        forward_or_backward_str: prefix prepended to every logged value name.
        detail_train_loss: when true, add one ``TrainValueHook`` per loss.
        summary_op: op for the ``SummaryWritingHook``. When ``None`` (the
            default) ``ntwk.merged_summary_op`` is used.

    Returns:
        list: the optional per-loss hooks, then the summary-writing hook,
        then the validation hook, which is always last.
    """
    hook_list = []
    if detail_train_loss:
        print("Losses:", losses)
        print("loss_name", loss_names)
        for loss, name in zip(losses, loss_names):
            print("forward_or_backward_str:", forward_or_backward_str)
            print("name:", name)
            print("loss:", loss)
            hook_list.append(
                network_helper.TrainValueHook(
                    flags.verb_step,
                    loss,
                    value_name=forward_or_backward_str + name,
                    ckpt_dir=ntwk.ckpt_dir,
                    write_summary=True))

    # Use the caller's summary op when one is supplied; otherwise fall back to
    # the network's merged summary op (histograms etc.).
    if summary_op is None:
        summary_op = ntwk.merged_summary_op
    print("Merged Summary op:", summary_op)
    summary_op_hook = network_helper.SummaryWritingHook(
        summary_op, flags.write_weight_step)
    hook_list.append(summary_op_hook)

    # The validation hook must be appended LAST: per the original author's
    # note, the final hook controls when training stops (it receives
    # stop_threshold).
    valid_hook = network_helper.ValidationHook(
        flags.eval_step,
        valid_init_op,
        ntwk.labels,
        ntwk.logits,
        ntwk.mse_loss,
        stop_threshold=flags.stop_threshold,
        value_name=forward_or_backward_str + "test_loss",
        ckpt_dir=ntwk.ckpt_dir,
        write_summary=True)
    hook_list.append(valid_hook)
    return hook_list