def __init__(self,
             lanmt_model,
             hidden_size=256,
             latent_size=8,
             noise=0.1,
             targets="logpy",
             decoder="fixed",
             training_mode="energy",
             imitation=False,
             imit_rand_steps=1,
             enable_valid_grad=True):
    """
    Args:
        lanmt_model (LANMTModel)
    """
    self.imitation = imitation
    self.imit_rand_steps = imit_rand_steps
    self.training_mode = training_mode
    # The hidden_size argument is unused here; the width is tied to the latent size.
    self._hidden_size = latent_size * 4
    self._latent_size = latent_size
    self.set_stepwise_training(False)
    lanmt_model.train(False)
    # Keep the pretrained LANMT model in a list so it is not registered as a submodule.
    self._lanmt = [lanmt_model]
    super(LatentScoreNetwork5, self).__init__(src_vocab_size=1, tgt_vocab_size=1)
    self.enable_valid_grad = True  # hard-coded on, regardless of the enable_valid_grad argument
    self.train()
    self._mycnt = 0
    self.noise = noise
    self.targets = targets
    self.tb_str = "{}_{}_{}_{}_{}_{}".format(
        targets, decoder, training_mode, noise, imitation, imit_rand_steps)
    if envswitch.who() == "shu":
        main_dir = "{}/data/wmt14_ende_fair/tensorboard".format(os.getenv("HOME"))
    else:
        main_dir = "/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/tensorboard/"
    self._tb = SummaryWriter(
        log_dir="{}/{}".format(main_dir, self.tb_str), flush_secs=10)
def __init__(self,
             lanmt_model,
             hidden_size=256,
             latent_size=8,
             noise=0.1,
             targets="logpy",
             decoder="fixed",
             imitation=False,
             line_search_c=0.1,
             imit_rand_steps=1,
             enable_valid_grad=True):
    """
    Args:
        lanmt_model (LANMTModel)
    """
    self.imitation = imitation
    self.imit_rand_steps = imit_rand_steps
    self._hidden_size = hidden_size
    self._latent_size = latent_size
    self.set_stepwise_training(False)
    super(LatentScoreNetwork5, self).__init__(src_vocab_size=1, tgt_vocab_size=1)
    lanmt_model.train(False)
    # Keep the pretrained LANMT model in a list so it is not registered as a submodule.
    self._lanmt = [lanmt_model]
    self.enable_valid_grad = True  # hard-coded on, regardless of the enable_valid_grad argument
    self.train()
    self._mycnt = 0
    self.noise = noise
    self.targets = targets
    self.line_search_c = line_search_c
    if envswitch.who() == "shu":
        main_dir = "{}/data/wmt14_ende_fair/tensorboard".format(os.getenv("HOME"))
    else:
        main_dir = "/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/tensorboard/"
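# Minimal usage sketch (not part of this file): constructing the score network
# on top of a pretrained LANMT model. `load_pretrained_lanmt` is a hypothetical
# helper used only for illustration; the constructor arguments are taken from
# the signatures above.
lanmt = load_pretrained_lanmt(OPTS.lanmt_path)  # a trained LANMTModel, loaded elsewhere
score_net = LatentScoreNetwork5(
    lanmt,
    hidden_size=256,
    latent_size=8,
    noise=0.1,
    targets="logpy",
    decoder="fixed",
    imitation=False,
    imit_rand_steps=1,
)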
ap.add_argument("--opt_disentangle", action="store_true") # Paths ap.add_argument( "--model_path", default="/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints/ebm.pt") ap.add_argument( "--result_path", default="/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints/ebm.result" ) OPTS.parse(ap) OPTS.model_path = OPTS.model_path.replace(DATA_ROOT, OPTS.root) OPTS.result_path = OPTS.result_path.replace(DATA_ROOT, OPTS.root) if envswitch.who() == "shu": OPTS.model_path = os.path.join(DATA_ROOT, os.path.basename(OPTS.model_path)) OPTS.result_path = os.path.join(DATA_ROOT, os.path.basename(OPTS.result_path)) OPTS.fixbug1 = True OPTS.fixbug2 = True if envswitch.who() == "jason_prince": OPTS.model_path = os.path.join(HOME_DIR, "checkpoints", "ebm", OPTS.dtok, os.path.basename(OPTS.model_path)) OPTS.result_path = os.path.join(HOME_DIR, "checkpoints", "ebm", OPTS.dtok, os.path.basename(OPTS.result_path)) os.makedirs(os.path.dirname(OPTS.model_path), exist_ok=True) # Determine the number of GPUs to use
default= "/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints/lanmt.result") OPTS.parse(ap) if OPTS.hidden_size == 512: envswitch.register( "shu", "lanmt_path", os.path.join( DATA_ROOT, "lanmt_annealbudget_batchtokens-8192_distill_dtok-wmt14_fair_ende_embedsz-512_fastanneal_heads-8_hiddensz-512_x5longertrain.pt.bak" )) OPTS.model_path = OPTS.model_path.replace(DATA_ROOT, OPTS.root) OPTS.result_path = OPTS.result_path.replace(DATA_ROOT, OPTS.root) if envswitch.who() == "shu": OPTS.model_path = os.path.join(DATA_ROOT, os.path.basename(OPTS.model_path)) OPTS.result_path = os.path.join(DATA_ROOT, os.path.basename(OPTS.result_path)) OPTS.fixbug1 = True OPTS.fixbug2 = True if envswitch.who() == "jason_prince": OPTS.model_path = os.path.join(HOME_DIR, "checkpoints/lanmt.pt") OPTS.result_path = os.path.join(HOME_DIR, "checkpoints/lanmt.result") # Determine the number of GPUs to use horovod_installed = importlib.util.find_spec("horovod") is not None if envswitch.who() != "shu": horovod_installed = False
def prepare(self):
    """Define the modules."""
    # Embedding layers
    self.embed_layer = TransformerEmbedding(self._tgt_vocab_size, self.embed_size)
    embed_layer = self.embed_layer
    self.pos_embed_layer = PositionalEmbedding(self.hidden_size)
    self.x_encoder = TransformerEncoder(
        None, self.hidden_size, 5 if OPTS.fix_layers else self.encoder_layers)
    # Prior p(z|x)
    self.prior_encoder = TransformerCrossEncoder(
        None, self.hidden_size, 3 if OPTS.fix_layers else self.prior_layers)
    self.p_hid2lat = nn.Linear(self.hidden_size, self.latent_dim * 2)
    # Approximate posterior q(z|y,x)
    self.q_encoder_xy = TransformerCrossEncoder(
        None, self.hidden_size, 3 if OPTS.fix_layers else self.q_layers)
    self.q_hid2lat = nn.Linear(self.hidden_size, self.latent_dim * 2)
    # Decoder p(y|x,z)
    self.lat2hid = nn.Linear(self.latent_dim, self.hidden_size)
    self.decoder = TransformerCrossEncoder(
        None, self.hidden_size, 3 if OPTS.fix_layers else self.decoder_layers,
        skip_connect=True)
    # Length prediction
    # self.length_predictor = nn.Linear(self.hidden_size, 100)
    self.length_predictor = nn.Sequential(
        nn.Linear(self.hidden_size, 200), nn.ELU(),
        nn.Linear(200, 200), nn.ELU(),
        nn.Linear(200, 100))
    # Word probability estimator
    self.final_bias = nn.Parameter(torch.randn(self._tgt_vocab_size))
    self.final_bias.requires_grad = True
    final_bias = self.final_bias

    class FinalLinear(object):
        """Weight-tied output projection: reuses the embedding matrix as softmax weights."""

        def __call__(self, x):
            return x @ torch.transpose(embed_layer.weight, 0, 1) + final_bias

    if envswitch.who() == "shu":
        if OPTS.dtok == "iwslt16_deen":
            self.expander_nn = FinalLinear()
        else:
            self.expander_nn = nn.Linear(self.hidden_size, self._tgt_vocab_size)
    else:
        # NOTE: FinalLinear for IWSLT, otherwise nn.Linear
        # self.expander_nn = FinalLinear()
        self.expander_nn = nn.Linear(self.hidden_size, self._tgt_vocab_size)
    self.label_smooth = LabelSmoothingKLDivLoss(0.1, self._tgt_vocab_size, 0)
    self.set_stepwise_training(False)
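# Illustrative sketch only: p_hid2lat / q_hid2lat above project hidden states to
# 2 * latent_dim values, which looks like the usual mean / log-variance
# parameterization of a diagonal Gaussian latent. The split-and-reparameterize
# step below is an assumption about how those outputs are consumed elsewhere in
# the model, not code taken from this repository.
def _reparameterize_sketch(lat_params):
    # lat_params: (batch, length, 2 * latent_dim), e.g. the output of q_hid2lat
    mu, logvar = lat_params.chunk(2, dim=-1)
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(mu)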
default="/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints/ebm.pt") ap.add_argument( "--result_path", default="/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints/ebm.result" ) OPTS.parse(ap) OPTS.model_path = OPTS.model_path.replace(DATA_ROOT, OPTS.root) OPTS.result_path = OPTS.result_path.replace(DATA_ROOT, OPTS.root) result_dir = os.path.join(DATA_ROOT, "results") if not os.path.exists(result_dir): os.mkdir(result_dir) OPTS.result_path = "{}/{}.result".format(result_dir, short_tag(OPTS.result_tag)) if envswitch.who() == "shu": OPTS.model_path = os.path.join(DATA_ROOT, os.path.basename(OPTS.model_path)) # OPTS.result_path = os.path.join(DATA_ROOT, os.path.basename(OPTS.result_path)) OPTS.fixbug1 = True if OPTS.dtok != "iwslt16_deen": OPTS.fixbug2 = True else: OPTS.model_path = os.path.join(HOME_DIR, "checkpoints", "ebm", OPTS.dtok, os.path.basename(OPTS.model_path)) OPTS.result_path = os.path.join(HOME_DIR, "checkpoints", "ebm", OPTS.dtok, os.path.basename(OPTS.result_path)) os.makedirs(os.path.dirname(OPTS.model_path), exist_ok=True) # Determine the number of GPUs to use horovod_installed = importlib.util.find_spec("horovod") is not None
# Paths
ap.add_argument(
    "--model_path",
    default="/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints_lvm/lanmt.pt")
ap.add_argument(
    "--result_path",
    default="/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints_lvm/lanmt.result")
OPTS.parse(ap)

OPTS.model_path = OPTS.model_path.replace(DATA_ROOT, OPTS.root)
OPTS.result_path = OPTS.result_path.replace(DATA_ROOT, OPTS.root)

if envswitch.who() == "shu":
    OPTS.model_path = os.path.join(DATA_ROOT, os.path.basename(OPTS.model_path))
    OPTS.result_path = os.path.join(DATA_ROOT, os.path.basename(OPTS.result_path))
    OPTS.fixbug1 = True
    OPTS.fixbug2 = True
else:
    OPTS.model_path = os.path.join(
        HOME_DIR, "checkpoints", "lvm", OPTS.dtok, os.path.basename(OPTS.model_path))
    OPTS.result_path = os.path.join(
        HOME_DIR, "checkpoints", "lvm", OPTS.dtok, os.path.basename(OPTS.result_path))
    os.makedirs(os.path.dirname(OPTS.model_path), exist_ok=True)
    # OPTS.fixbug2 = True

# Determine the number of GPUs to use