import sys

# Ensure the repository root is importable when this script is run directly.
sys.path.insert(0, ".")

from adversarial_net.models import LanguageModel
from adversarial_net import arguments as flags
from adversarial_net.preprocessing import WordCounter
from adversarial_net import osp

flags.add_argument(
    name="save_model_dir",
    argtype=str,
    default="E:/kaggle/avito/imdb_testset/adversarial_net/model/lm_model/lm_model.ckpt")

if __name__ == "__main__":
    # Load the pickled word-frequency table and keep only the top
    # `vocab_size` entries; the language model uses these frequencies
    # (e.g. for its sampled-softmax loss).
    freqs_pickle_path = osp.join(flags["lm_inputs"]["datapath"],
                                 "imdb_word_freqs.pickle")
    vocab_freqs = (WordCounter()
                   .load(freqs_pickle_path)
                   .most_common_freqs(flags["lm_sequence"]["vocab_size"]))
    flags.add_variable(name="vocab_freqs", value=vocab_freqs)

    # Build the LM graph and train, checkpointing to save_model_dir.
    language_model = LanguageModel()
    language_model.build()
    language_model.fit(save_model_path=flags["save_model_dir"])
def configure():
    """Register every command-line flag and cross-scope association used by
    the adversarial-net training scripts.

    Scopes group related hyper-parameters (inputs, lm_sequence, lm_loss,
    ae_*, adv_cl_*, gan, summary, ...).  `add_association` aliases one flag
    to another so the two scopes always agree; `add_scope_association`
    aliases a whole scope.  `vocab_freqs` is registered as a runtime
    variable because it is computed from data, not parsed from the CLI.

    NOTE(review): the file mixes ``argtype=bool`` and ``argtype="bool"``.
    Which spelling the flags framework parses correctly from the CLI cannot
    be determined from this file — verify against
    adversarial_net.arguments; plain ``type=bool`` is a classic argparse
    pitfall (bool("False") is True).  Left unchanged here.
    """
    # Filled in at runtime from the word-frequency pickle (see the LM script).
    flags.register_variable(name="vocab_freqs")

    # --- input pipeline -------------------------------------------------
    flags.add_argument(scope="inputs", name="datapath", argtype=str)
    flags.add_argument(scope="inputs", name="dataset", argtype=str)
    flags.add_argument(scope="inputs", name="eval_count_examples", argtype=int, default=-1)
    flags.add_argument(scope="inputs", name="vocab_size", argtype=int, default=50000)
    flags.add_association(scope="inputs", name="eval_max_words", assoc_scope="inputs", assoc_name="vocab_size")
    flags.add_argument(scope="inputs", name="batch_size", argtype=int, default=256)
    flags.add_argument(scope="inputs", name="unroll_steps", argtype=int, default=200)
    # Input-side RNN geometry mirrors the lm_sequence scope.
    flags.add_association(scope="inputs", name="lstm_num_layers", assoc_scope="lm_sequence", assoc_name="rnn_num_layers")
    flags.add_association(scope="inputs", name="state_size", assoc_scope="lm_sequence", assoc_name="rnn_cell_size")
    flags.add_argument(scope="inputs", name="bidrec", argtype=bool, default=False)

    # --- language-model sequence ---------------------------------------
    flags.add_association(scope="lm_sequence", name="vocab_size", assoc_scope="inputs", assoc_name="vocab_size")
    flags.add_argument(scope="lm_sequence", name="embedding_dim", argtype=int, default=256)
    # should be same with gan rnn_cell_size
    flags.add_argument(scope="lm_sequence", name="rnn_cell_size", argtype=int, default=1024)
    flags.add_argument(scope="lm_sequence", name="normalize", argtype="bool", default=True)
    flags.add_argument(scope="lm_sequence", name="keep_embed_prob", argtype=float, default=1.0)
    flags.add_argument(scope="lm_sequence", name="lstm_keep_pro_out", argtype=float, default=1.0)
    flags.add_argument(scope="lm_sequence", name="rnn_num_layers", argtype=int, default=1)
    flags.add_association(scope="lm_sequence", name="vocab_freqs", assoc_name="vocab_freqs")
    flags.add_scope_association(scope="lm_inputs", assoc_scope="inputs")

    # --- language-model loss (sampled softmax) -------------------------
    flags.add_association(scope="lm_loss", name="vocab_size", assoc_scope="inputs", assoc_name="vocab_size")
    flags.add_argument(scope="lm_loss", name="num_candidate_samples", argtype=int, default=1024)
    flags.add_association(scope="lm_loss", name="vocab_freqs", assoc_name="vocab_freqs")

    # --- autoencoder reuses the LM configuration -----------------------
    flags.add_scope_association(scope="ae_sequence", assoc_scope="lm_sequence")
    flags.add_scope_association(scope="ae_inputs", assoc_scope="inputs")
    flags.add_scope_association(scope="ae_loss", assoc_scope="lm_loss")

    # --- adversarial classifier ----------------------------------------
    flags.add_scope_association(scope="adv_cl_inputs", assoc_scope="inputs")
    flags.add_association(scope="adv_cl_inputs", name="phase", assoc_name="phase")
    flags.add_argument(scope="adv_cl_sequence", name="hidden_size", argtype=int, default=30)
    flags.add_argument(scope="adv_cl_sequence", name="num_layers", argtype=int, default=1)
    flags.add_argument(scope="adv_cl_sequence", name="num_classes", argtype=int, default=2)
    # FIX: was argtype=int with default=1.0 — a keep-probability is a float
    # (cf. keep_embed_prob / lstm_keep_pro_out); int would truncate a CLI
    # override like 0.5 to 0 and disable the layer entirely.
    flags.add_argument(scope="adv_cl_sequence", name="keep_prob", argtype=float, default=1.0)
    flags.add_association(scope="adv_cl_sequence", name="input_size", assoc_scope="lm_sequence", assoc_name="rnn_cell_size")
    flags.add_argument(scope="adv_cl_loss", name="adv_reg_coeff", argtype=float, default=1.0)
    flags.add_argument(scope="adv_cl_loss", name="perturb_norm_length", argtype=float, default=5.0)

    # --- virtual adversarial training ----------------------------------
    flags.add_argument(scope="vir_adv_loss", name="num_power_iteration", argtype=int, default=1)
    flags.add_argument(scope="vir_adv_loss", name="small_constant_for_finite_diff", argtype=float, default=1e-1)

    # --- GAN / summary (decoder) ---------------------------------------
    flags.add_argument(scope="gan", name="critic_iters", argtype=int, default=5)
    flags.add_argument(scope="gan", name="rnn_cell_size", argtype=int, default=1024)
    flags.add_argument(scope="summary", name="rnn_cell_size", argtype=int, default=512)
    flags.add_argument(scope="summary", name="rnn_keep_prob_out", argtype=float, default=1.0)
    flags.add_argument(scope="summary", name="maximum_iterations", argtype=int, default=50)
    flags.add_argument(scope="summary", name="beam_width", argtype=int, default=3)

    # --- global training options ---------------------------------------
    flags.add_argument(name="phase", argtype=str, default="train")
    flags.add_argument(name="max_grad_norm", argtype=float, default=1.0)
    flags.add_argument(name="lr", argtype=float, default=1e-3)
    flags.add_argument(name="lr_decay", argtype=float, default=0.9999)
    flags.add_argument(name="max_steps", argtype=int, default=100000)
    flags.add_argument(name="save_steps", argtype=int, default=100)
    flags.add_argument(name="save_best", argtype="bool", default=True)
    flags.add_argument(name="save_best_check_steps", argtype=int, default=100)
    flags.add_argument(name="eval_acc", argtype=bool, default=False)
    flags.add_argument(name="eval_steps", argtype=int, default=100)
    flags.add_argument(name="should_restore_if_could", argtype="bool", default=True)
    flags.add_argument(name="tf_debug_trace", argtype=bool, default=False)
    flags.add_argument(name="tf_timeline_dir", argtype=str, default=None)
    flags.add_argument(name="no_need_clip_grads", argtype=bool, default=False)
    flags.add_argument(name="with_lr_decay", argtype="bool", default=True)
    flags.add_argument(name="lr_decay_step", argtype=int, default=1)
    # continue from break point
    flags.add_argument(name="best_loss_val", argtype=float, default=99999999.0)
    flags.add_argument(name="extra_save_dir", argtype=str, default=None)
    flags.add_argument(name="tfdebug_root", argtype=str, default=None)
    flags.add_argument(name="use_exp_mov_avg_loss", argtype=bool, default=False)
    flags.add_argument(name="use_exp_mov_avg_loss_decay", argtype=float, default=0.9)
    flags.add_argument(name="logging_file", argtype=str, default=None)
def eval_from(value):
    """Flag validator: `value` must be one of the known evaluation sources.

    Used as an `argtype` callable below; returns the value unchanged when
    valid, otherwise fails the assertion.
    """
    eval_from_vals = [
        "generator", "topic_generator", "pretrain_cl", "final_cl"
    ]
    assert value in eval_from_vals, "step is one of %s" % eval_from_vals
    return value


def adv_type(value):
    """Flag validator: `value` must name an adversarial-training mode
    ("adv" = adversarial, "vir_adv" = virtual adversarial)."""
    adv_types = ["adv", "vir_adv"]
    assert value in adv_types, "adv_type is one of %s" % adv_types
    return value


# NOTE(review): `training_step` is not defined anywhere in this view of the
# file — as written this line would raise NameError at import time.  It is
# presumably a validator like `eval_from`/`adv_type` defined in another part
# of the file or meant to be added; confirm before running this config.
flags.add_argument(name="step", argtype=training_step)
# NOTE(review): `save_model_dir` is also registered near the top of the file
# with a default path — this duplicate registration (no default) suggests the
# two registrations belong to different entry scripts; verify the flags
# framework tolerates re-registration.
flags.add_argument(name="save_model_dir", argtype=str)
flags.add_argument(name="pretrain_model_dir", argtype=str, default=None)
# Evaluation options; `eval_from` above validates the allowed sources.
flags.add_argument(name="eval_from", argtype=eval_from, default="generator")
flags.add_argument(name="eval_batch_size", argtype=int, default=2)
flags.add_argument(name="eval_topic_count", argtype=int, default=2)
flags.add_argument(name="eval_seq_length", argtype=int, default=200)
# lm/ae model args
flags.add_argument(name="no_loss_sampler", argtype=bool, default=False)
flags.add_argument(name="hard_mode", argtype=bool, default=False)
flags.add_argument(name="forget_bias", argtype=float, default=0.0)
# model prefix
flags.add_argument(name="model_prefix", argtype=str, default=None)
# adversarial training type; `adv_type` above validates the value
flags.add_argument(name="adv_type", argtype=adv_type, default="adv")