# project-local modules referenced throughout this file
import data_reader
import network_maker
import network_helper
import utils


def main(flags):
    # initialize data reader
    # output size depends on what type of layer the network ends with
    if len(flags.tconv_dims) == 0:
        output_size = flags.fc_filters[-1]
    else:
        output_size = flags.tconv_dims[-1]
    features, labels, train_init_op, valid_init_op = data_reader.read_data(
        input_size=flags.input_size,
        output_size=output_size - 2 * flags.clip,
        x_range=flags.x_range,
        y_range=flags.y_range,
        cross_val=flags.cross_val,
        val_fold=flags.val_fold,
        batch_size=flags.batch_size,
        shuffle_size=flags.shuffle_size)

    # make network
    ntwk = network_maker.CnnNetwork(features, labels, utils.my_model_fn_tens, flags.batch_size,
                                    clip=flags.clip, fc_filters=flags.fc_filters,
                                    tconv_Fnums=flags.tconv_Fnums, tconv_dims=flags.tconv_dims,
                                    tconv_filters=flags.tconv_filters, n_filter=flags.n_filter,
                                    n_branch=flags.n_branch, reg_scale=flags.reg_scale,
                                    learn_rate=flags.learn_rate, decay_step=flags.decay_step,
                                    decay_rate=flags.decay_rate)

    # define hooks for monitoring training
    train_hook = network_helper.TrainValueHook(flags.verb_step, ntwk.loss,
                                               ckpt_dir=ntwk.ckpt_dir, write_summary=True)
    lr_hook = network_helper.TrainValueHook(flags.verb_step, ntwk.learn_rate,
                                            ckpt_dir=ntwk.ckpt_dir, write_summary=True,
                                            value_name='learning_rate')
    valid_hook = network_helper.ValidationHook(flags.eval_step, valid_init_op,
                                               ntwk.labels, ntwk.logits, ntwk.loss,
                                               ntwk.preconv, ntwk.preTconv,
                                               ckpt_dir=ntwk.ckpt_dir, write_summary=True)

    # train the network
    ntwk.train(train_init_op, flags.train_step, [train_hook, valid_hook, lr_hook],
               write_summary=True)
def main(flags):
    # initialize data reader
    if len(flags.tconv_dims) == 0:
        output_size = flags.fc_filters[-1]
    else:
        output_size = flags.tconv_dims[-1]
    reader = data_reader.DataReader(input_size=flags.input_size,
                                    output_size=output_size,
                                    x_range=flags.x_range,
                                    y_range=flags.y_range,
                                    cross_val=flags.cross_val,
                                    val_fold=flags.val_fold,
                                    batch_size=flags.batch_size,
                                    shuffle_size=flags.shuffle_size)
    features, labels, train_init_op, valid_init_op = reader.get_data_holder_and_init_op(
        (flags.train_file, flags.valid_file))

    # make network
    ntwk = network_maker.CnnNetwork(features, labels, utils.my_model_fn, flags.batch_size,
                                    fc_filters=flags.fc_filters, tconv_dims=flags.tconv_dims,
                                    tconv_filters=flags.tconv_filters,
                                    learn_rate=flags.learn_rate, decay_step=flags.decay_step,
                                    decay_rate=flags.decay_rate)

    # define hooks for monitoring training
    train_hook = network_helper.TrainValueHook(flags.verb_step, ntwk.loss,
                                               ckpt_dir=ntwk.ckpt_dir, write_summary=True)
    lr_hook = network_helper.TrainValueHook(flags.verb_step, ntwk.learn_rate,
                                            ckpt_dir=ntwk.ckpt_dir, write_summary=True,
                                            value_name='learning_rate')
    valid_hook = network_helper.ValidationHook(flags.eval_step, valid_init_op,
                                               ntwk.labels, ntwk.logits, ntwk.loss,
                                               ckpt_dir=ntwk.ckpt_dir, write_summary=True)

    # train the network
    ntwk.train(train_init_op, flags.train_step, [train_hook, valid_hook, lr_hook],
               write_summary=True)
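import argparse


# Entry-point sketch (assumption): the real project may read its flags from a dedicated
# flag/config module instead. This hypothetical argparse helper only shows a few of the
# flags used by main() above, with illustrative defaults; the remaining flags are omitted.
def parse_flags():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_size', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--learn_rate', type=float, default=1e-3)
    parser.add_argument('--train_step', type=int, default=10000)
    # ... remaining flags (fc_filters, tconv_dims, verb_step, eval_step, ...) omitted
    return parser.parse_args()


if __name__ == '__main__':
    main(parse_flags())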
def get_hook_list(flags, ntwk, valid_init_op, losses, loss_names, forward_or_backward_str,
                  detail_train_loss=True, summary_op=None):
    hook_list = []
    if detail_train_loss:
        # add one training-value hook per individual loss (debug prints kept for traceability)
        print("Losses:", losses)
        print("loss_names:", loss_names)
        for loss, name in zip(losses, loss_names):
            print("forward_or_backward_str:", forward_or_backward_str)
            print("name:", name)
            print("loss:", loss)
            hook_list.append(network_helper.TrainValueHook(flags.verb_step, loss,
                                                           value_name=forward_or_backward_str + name,
                                                           ckpt_dir=ntwk.ckpt_dir,
                                                           write_summary=True))

    # add a summary op hook for histograms
    print("Merged summary op:", ntwk.merged_summary_op)
    summary_op_hook = network_helper.SummaryWritingHook(ntwk.merged_summary_op,
                                                        flags.write_weight_step)
    hook_list.append(summary_op_hook)

    # add the validation hook at the END: it controls the stopping of training
    valid_hook = network_helper.ValidationHook(flags.eval_step, valid_init_op,
                                               ntwk.labels, ntwk.logits, ntwk.mse_loss,
                                               stop_threshold=flags.stop_threshold,
                                               value_name=forward_or_backward_str + "test_loss",
                                               ckpt_dir=ntwk.ckpt_dir, write_summary=True)
    hook_list.append(valid_hook)  # the validation hook is always in the list
    return hook_list
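# Usage sketch: how the hook list built above could be fed to ntwk.train(), mirroring the
# call pattern in main(). The choice of losses/loss_names here is illustrative only;
# ntwk.loss and ntwk.mse_loss are the tensors already referenced in this file, and the
# 'forward_' prefix is just an example value for forward_or_backward_str.
def train_with_hooks(flags, ntwk, train_init_op, valid_init_op):
    hooks = get_hook_list(flags, ntwk, valid_init_op,
                          losses=[ntwk.loss, ntwk.mse_loss],
                          loss_names=['total_loss', 'mse_loss'],
                          forward_or_backward_str='forward_')
    ntwk.train(train_init_op, flags.train_step, hooks, write_summary=True)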