示例#1
0
文件: train.py 项目: WIEQLI/idlm_Ben
def tandemmain(flags):
    """Train a Tandem network from parsed training flags.

    Steps: optionally hide GPUs, build the data pipeline, build the Tandem
    network, attach monitoring hooks for the forward and tandem phases,
    train, then record the flags together with the best validation loss.

    :param flags: Training flags object. Fields read here include the
        data-pipeline settings, network hyper-parameters, ``train_step``,
        ``backward_train_step``, ``detail_train_loss_forward`` and
        ``forward_model_ckpt``. ``flags.geoboundary`` is overwritten
        when ``flags.normalize_input`` is truthy.
    :return: None
    """
    # Hide every GPU for a CPU-only run.
    # NOTE(review): this presumably only takes effect if no CUDA context
    # has been created yet in this process — confirm for callers.
    if flags.use_cpu_only:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    # Initialize the data reader.
    geometry, spectra, train_init_op, valid_init_op = data_reader.read_data(
        input_size=0,
        output_size=0,
        x_range=flags.x_range,
        y_range=flags.y_range,
        geoboundary=flags.geoboundary,
        cross_val=flags.cross_val,
        val_fold=flags.val_fold,
        batch_size=flags.batch_size,
        shuffle_size=flags.shuffle_size,
        data_dir=flags.data_dir,
        normalize_input=flags.normalize_input,
        test_ratio=0.1)

    # If the input is normalized, make the original boundary irrelevant by
    # replacing it with [-1, 1] on both ranges. This happens after
    # read_data, so the data reader still received the original boundary;
    # only the network below sees the normalized one.
    if flags.normalize_input:
        flags.geoboundary = [-1, 1, -1, 1]

    print("making network now")
    ntwk = Tandem_network_maker.TandemCnnNetwork(
        geometry,
        spectra,
        model_maker.tandem_model,
        flags.batch_size,
        clip=flags.clip,
        forward_fc_filters=flags.forward_fc_filters,
        backward_fc_filters=flags.backward_fc_filters,
        reg_scale=flags.reg_scale,
        learn_rate=flags.learn_rate,
        tconv_Fnums=flags.tconv_Fnums,
        tconv_dims=flags.tconv_dims,
        n_branch=flags.n_branch,
        tconv_filters=flags.tconv_filters,
        n_filter=flags.n_filter,
        decay_step=flags.decay_step,
        decay_rate=flags.decay_rate,
        geoboundary=flags.geoboundary,
        conv1d_filters=flags.conv1d_filters,
        conv_channel_list=flags.conv_channel_list)

    print("Setting the hooks now")
    # Tensors to monitor and their summary names; the two lists are
    # index-aligned. NOTE: "regularizaiton_loss" is a runtime tag name, so
    # its spelling is deliberately left unchanged.
    losses = [
        ntwk.loss, ntwk.mse_loss, ntwk.reg_loss, ntwk.bdy_loss, ntwk.learn_rate
    ]
    loss_names = [
        "train_loss", "mse_loss", "regularizaiton_loss", "boundary_loss",
        "Learning_rate"
    ]
    # Forward-phase hooks: whether the detailed training loss is logged is
    # controlled by flags.detail_train_loss_forward.
    forward_hooks = get_hook_list(flags, ntwk, valid_init_op, losses,
                                  loss_names, "forward_",
                                  flags.detail_train_loss_forward)
    # Tandem-phase hooks always log the detailed training loss.
    tandem_hooks = get_hook_list(flags, ntwk, valid_init_op, losses,
                                 loss_names, "tandem_",
                                 detail_train_loss=True)

    # Train the network (forward phase, then backward/tandem phase).
    print("Start the training now")
    ntwk.train(train_init_op,
               flags.train_step,
               flags.backward_train_step,
               forward_hooks,
               tandem_hooks,
               write_summary=True,
               load_forward_ckpt=flags.forward_model_ckpt)

    # Write the flags into the current folder and move them to the models/
    # folder along with the best validation error.
    flag_reader.write_flags_and_BVE(flags, ntwk.best_validation_loss)
    # Put the parameter.txt file into the latest model folder.
    put_param_into_folder()
示例#2
0
def training_from_flag(flags):
    """Run one training session described by *flags*.

    Pipeline: load the data, build the autoencoder/INN network, train its
    autoencoder, then persist the flags with the best validation loss.

    :param flags: Training flags read from the command line or parameter.py.
    :return: None
    """
    # 1. Data: the reader yields a (train, test) loader pair.
    loaders = data_reader.read_data(flags)
    train_loader, test_loader = loaders
    print("Making network now")

    # 2. Network built from the AutoEncoder and INN model classes.
    network = Network(AutoEncoder, INN, flags, train_loader, test_loader)

    # 3. Training: only the autoencoder stage is run here.
    print("Start training now...")
    network.train_autoencoder()

    # 4. Housekeeping: write the flags with the best validation loss, then
    #    file the parameters into the model folder. (Per the original
    #    comment, the flags object is also pickled — not verifiable here.)
    best_loss = network.best_validation_loss
    flag_reader.write_flags_and_BVE(flags, best_loss)
    put_param_into_folder()
示例#3
0
文件: train.py 项目: PL187/idlm_Ben
def VAEtrainmain(flags):
    """Train a VAE network from parsed training flags.

    Builds the data pipeline and the VAE network, attaches monitoring
    hooks, trains, then records the flags with the best validation loss.

    :param flags: Training flags object (data-pipeline settings, network
        hyper-parameters, ``latent_dim``, ``train_step``, ...).
        ``flags.geoboundary`` is overwritten when ``flags.normalize_input``
        is truthy.
    :return: None
    """
    # Initialize the data reader.
    geometry, spectra, train_init_op, valid_init_op = data_reader.read_data(
        input_size=0,
        output_size=0,
        x_range=flags.x_range,
        y_range=flags.y_range,
        geoboundary=flags.geoboundary,
        cross_val=flags.cross_val,
        val_fold=flags.val_fold,
        batch_size=flags.batch_size,
        shuffle_size=flags.shuffle_size,
        data_dir=flags.data_dir,
        normalize_input=flags.normalize_input,
        test_ratio=0.2)
    # If the input is normalized, make the original boundary irrelevant by
    # replacing it with [-1, 1] on both ranges. This happens after
    # read_data, so only the network below sees the normalized boundary.
    if flags.normalize_input:
        flags.geoboundary = [-1, 1, -1, 1]

    # Make the network.
    ntwk = VAE_network_maker.VAENetwork(
        geometry,
        spectra,
        model_maker.VAE,
        flags.batch_size,
        flags.latent_dim,
        spectra_fc_filters=flags.spectra_fc_filters,
        decoder_fc_filters=flags.decoder_fc_filters,
        encoder_fc_filters=flags.encoder_fc_filters,
        reg_scale=flags.reg_scale,
        learn_rate=flags.learn_rate,
        decay_step=flags.decay_step,
        decay_rate=flags.decay_rate,
        geoboundary=flags.geoboundary,
        conv1d_filters=flags.conv1d_filters,
        filter_channel_list=flags.filter_channel_list)

    print("Setting the hooks now")
    # Tensors to monitor and their summary names; the two lists are
    # index-aligned. NOTE: "regularizaiton_loss" is a runtime tag name, so
    # its spelling is deliberately left unchanged.
    losses = [
        ntwk.loss, ntwk.mse_loss, ntwk.reg_loss, ntwk.bdy_loss, ntwk.kl_loss,
        ntwk.learn_rate
    ]
    loss_names = [
        "train_loss", "mse_loss", "regularizaiton_loss", "boundary_loss",
        "KL_loss", "Learning_rate"
    ]
    # VAE monitoring hooks. No detail-loss argument is passed, so
    # get_hook_list presumably falls back to its default for that option.
    VAE_hooks = get_hook_list(flags, ntwk, valid_init_op, losses, loss_names,
                              "VAE_")

    print("Starting training now")
    ntwk.train(train_init_op, flags.train_step, VAE_hooks, write_summary=True)

    # Write the flags into the current folder and move them to the models/
    # folder along with the best validation error.
    flag_reader.write_flags_and_BVE(flags, ntwk.best_validation_loss)

    # Put the parameter.txt file into the latest model folder.
    put_param_into_folder()
示例#4
0
def Backpropmain(flags):
    """Train a back-propagation CNN network from parsed training flags.

    Builds the data pipeline and the network, attaches forward monitoring
    hooks, trains (optionally starting from a forward checkpoint), then
    records the flags with the best validation loss.

    :param flags: Training flags object (data-pipeline settings, network
        hyper-parameters, ``train_step``, ``forward_model_ckpt``, ...).
        ``flags.geoboundary`` is overwritten when ``flags.normalize_input``
        is truthy.
    :return: None
    """
    # Initialize the data reader.
    geometry, spectra, train_init_op, valid_init_op = data_reader.read_data(
        input_size=0,
        output_size=0,
        x_range=flags.x_range,
        y_range=flags.y_range,
        geoboundary=flags.geoboundary,
        cross_val=flags.cross_val,
        val_fold=flags.val_fold,
        batch_size=flags.batch_size,
        shuffle_size=flags.shuffle_size,
        data_dir=flags.data_dir,
        normalize_input=flags.normalize_input,
        test_ratio=0.2)
    # If the input is normalized, make the original boundary irrelevant by
    # replacing it with [-1, 1] on both ranges. This happens after
    # read_data, so only the network below sees the normalized boundary.
    if flags.normalize_input:
        flags.geoboundary = [-1, 1, -1, 1]

    print("boundary is set at:", flags.geoboundary)
    print("making network now")
    # Make the network. The boundary is passed under the keyword
    # `boundary` for this network class.
    ntwk = Backprop_network_maker.BackPropCnnNetwork(
        geometry,
        spectra,
        model_maker.back_prop_model,
        flags.batch_size,
        clip=flags.clip,
        forward_fc_filters=flags.forward_fc_filters,
        reg_scale=flags.reg_scale,
        learn_rate=flags.learn_rate,
        tconv_Fnums=flags.tconv_Fnums,
        tconv_dims=flags.tconv_dims,
        n_branch=flags.n_branch,
        tconv_filters=flags.tconv_filters,
        n_filter=flags.n_filter,
        decay_step=flags.decay_step,
        decay_rate=flags.decay_rate,
        boundary=flags.geoboundary)

    print("Setting the hooks now")
    # Tensors to monitor and their summary names; the two lists are
    # index-aligned. NOTE: "regularizaiton_loss" is a runtime tag name, so
    # its spelling is deliberately left unchanged.
    losses = [
        ntwk.loss, ntwk.mse_loss, ntwk.reg_loss, ntwk.bdy_loss, ntwk.learn_rate
    ]
    loss_names = [
        "train_loss", "mse_loss", "regularizaiton_loss", "boundary_loss",
        "Learning_rate"
    ]
    # Forward monitoring hooks. No detail-loss argument is passed, so
    # get_hook_list presumably falls back to its default for that option.
    forward_hooks = get_hook_list(flags, ntwk, valid_init_op, losses,
                                  loss_names, "forward_")

    # Train the network, loading the forward model from the given checkpoint.
    print("Start the training now")
    ntwk.train(train_init_op,
               flags.train_step,
               forward_hooks,
               write_summary=True,
               load_forward_ckpt=flags.forward_model_ckpt)

    # Write the flags into the current folder and move them to the models/
    # folder along with the best validation error.
    flag_reader.write_flags_and_BVE(flags, ntwk.best_validation_loss)
    # Put the parameter.txt file into the latest model folder.
    put_param_into_folder()
示例#5
0
文件: train.py 项目: WIEQLI/idlm_Ben
    # NOTE(review): this is the body of a training entry point whose `def`
    # line is not part of this snippet; it trains a Forward model.
    # Read the training parameters (flags) to be used below.
    flags = flag_reader.read_flag()

    # Get the data: a (train_loader, test_loader) pair built from the flags.
    train_loader, test_loader = data_reader.read_data(x_range=flags.x_range,
                                                      y_range=flags.y_range,
                                                      geoboundary=flags.geoboundary,
                                                      batch_size=flags.batch_size,
                                                      normalize_input=flags.normalize_input,
                                                      data_dir=flags.data_dir)
    # If the input is normalized, reset the boundary to [-1, 1] on both
    # ranges. This runs after read_data, so the reader saw the original value.
    if flags.normalize_input:
        flags.geoboundary = [-1, 1, -1, 1]

    print("Boundary is set at:", flags.geoboundary)
    print("Making network now")

    # Make the network: the Forward model class wrapped with flags and loaders.
    ntwk = Network(Forward, flags, train_loader, test_loader)

    # Training process.
    print("Start training now...")
    ntwk.train()

    # Housekeeping: write the flags with the best validation loss, then put
    # the parameter file into the model folder.
    flag_reader.write_flags_and_BVE(flags, ntwk.best_validation_loss)
    put_param_into_folder()