    def __init__(
            self, lanmt_model, hidden_size=256, latent_size=8,
            noise=0.1, targets="logpy", decoder="fixed",
            training_mode="energy", imitation=False,
            imit_rand_steps=1, enable_valid_grad=True):
        """
        Args:
            lanmt_model(LANMTModel)
        """
        self.imitation = imitation
        self.imit_rand_steps = imit_rand_steps
        self.training_mode = training_mode
        self._hidden_size = latent_size * 4  # NOTE: tied to the latent size; the hidden_size argument is unused here
        self._latent_size = latent_size
        self.set_stepwise_training(False)
        lanmt_model.train(False)
        self._lanmt = [lanmt_model]
        super(LatentScoreNetwork5, self).__init__(src_vocab_size=1, tgt_vocab_size=1)
        self.enable_valid_grad = enable_valid_grad
        self.train()
        self._mycnt = 0

        self.noise = noise
        self.targets = targets

        self.tb_str = "{}_{}_{}_{}_{}_{}".format(targets, decoder, training_mode, noise, imitation, imit_rand_steps)
        if envswitch.who() == "shu":
            main_dir = "{}/data/wmt14_ende_fair/tensorboard".format(os.getenv("HOME"))
        else:
            main_dir = "/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/tensorboard/"

        self._tb = SummaryWriter(
            log_dir="{}/{}".format(main_dir, self.tb_str), flush_secs=10)
Example #2
    def __init__(self,
                 lanmt_model,
                 hidden_size=256,
                 latent_size=8,
                 noise=0.1,
                 targets="logpy",
                 decoder="fixed",
                 imitation=False,
                 line_search_c=0.1,
                 imit_rand_steps=1,
                 enable_valid_grad=True):
        """
        Args:
            lanmt_model(LANMTModel)
        """
        self.imitation = imitation
        self.imit_rand_steps = imit_rand_steps
        self._hidden_size = hidden_size
        self._latent_size = latent_size
        self.set_stepwise_training(False)
        super(LatentScoreNetwork5, self).__init__(src_vocab_size=1,
                                                  tgt_vocab_size=1)
        lanmt_model.train(False)
        self._lanmt = [lanmt_model]
        self.enable_valid_grad = enable_valid_grad
        self.train()
        self._mycnt = 0

        self.noise = noise
        self.targets = targets
        self.line_search_c = line_search_c

        if envswitch.who() == "shu":
            main_dir = "{}/data/wmt14_ende_fair/tensorboard".format(
                os.getenv("HOME"))
        else:
            main_dir = "/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/tensorboard/"
Example #3
ap.add_argument("--opt_disentangle", action="store_true")

# Paths
ap.add_argument(
    "--model_path",
    default="/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints/ebm.pt")
ap.add_argument(
    "--result_path",
    default="/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints/ebm.result"
)
OPTS.parse(ap)

OPTS.model_path = OPTS.model_path.replace(DATA_ROOT, OPTS.root)
OPTS.result_path = OPTS.result_path.replace(DATA_ROOT, OPTS.root)

if envswitch.who() == "shu":
    OPTS.model_path = os.path.join(DATA_ROOT,
                                   os.path.basename(OPTS.model_path))
    OPTS.result_path = os.path.join(DATA_ROOT,
                                    os.path.basename(OPTS.result_path))
    OPTS.fixbug1 = True
    OPTS.fixbug2 = True

if envswitch.who() == "jason_prince":
    OPTS.model_path = os.path.join(HOME_DIR, "checkpoints", "ebm", OPTS.dtok,
                                   os.path.basename(OPTS.model_path))
    OPTS.result_path = os.path.join(HOME_DIR, "checkpoints", "ebm", OPTS.dtok,
                                    os.path.basename(OPTS.result_path))
    os.makedirs(os.path.dirname(OPTS.model_path), exist_ok=True)

# Determine the number of GPUs to use
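envswitch is project-internal; throughout these snippets it dispatches per-user configuration (who() above, register() in the next example). A minimal stand-in showing the assumed shape of that API (an assumption, not the real module):

import getpass

class EnvSwitch:
    def __init__(self, users):
        self._users = users  # e.g. {"shu": "shu", "jason": "jason_prince"}
        self._registry = {}

    def who(self):
        # Map the current login to a project-level identity string.
        return self._users.get(getpass.getuser(), "unknown")

    def register(self, owner, key, value):
        # Stash a per-user value, e.g. a checkpoint path.
        self._registry[(owner, key)] = value

envswitch = EnvSwitch({"shu": "shu", "jason": "jason_prince"})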
Example #4
    default=
    "/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints/lanmt.result")
OPTS.parse(ap)

if OPTS.hidden_size == 512:
    envswitch.register(
        "shu", "lanmt_path",
        os.path.join(
            DATA_ROOT,
            "lanmt_annealbudget_batchtokens-8192_distill_dtok-wmt14_fair_ende_embedsz-512_fastanneal_heads-8_hiddensz-512_x5longertrain.pt.bak"
        ))

OPTS.model_path = OPTS.model_path.replace(DATA_ROOT, OPTS.root)
OPTS.result_path = OPTS.result_path.replace(DATA_ROOT, OPTS.root)

if envswitch.who() == "shu":
    OPTS.model_path = os.path.join(DATA_ROOT,
                                   os.path.basename(OPTS.model_path))
    OPTS.result_path = os.path.join(DATA_ROOT,
                                    os.path.basename(OPTS.result_path))
    OPTS.fixbug1 = True
    OPTS.fixbug2 = True

if envswitch.who() == "jason_prince":
    OPTS.model_path = os.path.join(HOME_DIR, "checkpoints/lanmt.pt")
    OPTS.result_path = os.path.join(HOME_DIR, "checkpoints/lanmt.result")

# Determine the number of GPUs to use
horovod_installed = importlib.util.find_spec("horovod") is not None
if envswitch.who() != "shu":
    horovod_installed = False
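The snippet is cut off before the GPU count is actually computed; a plausible continuation under the flag above (an assumption, not the repo's code):

import torch

if horovod_installed:
    import horovod.torch as hvd
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())  # one process per GPU under horovod
    n_gpus = hvd.size()
else:
    n_gpus = max(torch.cuda.device_count(), 1)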
Example #5
    def prepare(self):
        """Define the modules
        """
        # Embedding layers
        self.embed_layer = TransformerEmbedding(self._tgt_vocab_size,
                                                self.embed_size)
        embed_layer = self.embed_layer  # local alias captured by the FinalLinear closure below
        self.pos_embed_layer = PositionalEmbedding(self.hidden_size)
        self.x_encoder = TransformerEncoder(
            None, self.hidden_size,
            5 if OPTS.fix_layers else self.encoder_layers)

        # Prior p(z|x)
        self.prior_encoder = TransformerCrossEncoder(
            None, self.hidden_size,
            3 if OPTS.fix_layers else self.prior_layers)
        self.p_hid2lat = nn.Linear(self.hidden_size, self.latent_dim * 2)

        # Approximate Posterior q(z|y,x)
        self.q_encoder_xy = TransformerCrossEncoder(
            None, self.hidden_size, 3 if OPTS.fix_layers else self.q_layers)
        self.q_hid2lat = nn.Linear(self.hidden_size, self.latent_dim * 2)

        # Decoder p(y|x,z)
        self.lat2hid = nn.Linear(self.latent_dim, self.hidden_size)
        self.decoder = TransformerCrossEncoder(
            None,
            self.hidden_size,
            3 if OPTS.fix_layers else self.decoder_layers,
            skip_connect=True)

        # Length prediction: a small MLP producing logits over 100 length
        # classes (the old single linear layer is kept below for reference)
        #self.length_predictor = nn.Linear(self.hidden_size, 100)
        self.length_predictor = nn.Sequential(nn.Linear(self.hidden_size, 200),
                                              nn.ELU(), nn.Linear(200, 200),
                                              nn.ELU(), nn.Linear(200, 100))

        # Word probability estimator: an output projection whose weight is
        # tied to the embedding matrix, plus a learned per-token bias
        self.final_bias = nn.Parameter(torch.randn(self._tgt_vocab_size))
        self.final_bias.requires_grad = True  # redundant: Parameters require grad by default
        final_bias = self.final_bias  # local alias captured by the closure below

        class FinalLinear(object):
            # A plain callable rather than an nn.Module, so the tied embedding
            # weight is not registered as a second parameter.
            def __call__(self, x):
                return x @ torch.transpose(embed_layer.weight, 0, 1) + final_bias

        if envswitch.who() == "shu":
            if OPTS.dtok == "iwslt16_deen":
                self.expander_nn = FinalLinear()
            else:
                self.expander_nn = nn.Linear(self.hidden_size,
                                             self._tgt_vocab_size)
        else:
            # NOTE : FinalLinear for IWSLT, otherwise nn.Linear
            #self.expander_nn = FinalLinear()
            self.expander_nn = nn.Linear(self.hidden_size,
                                         self._tgt_vocab_size)

        self.label_smooth = LabelSmoothingKLDivLoss(0.1, self._tgt_vocab_size,
                                                    0)
        self.set_stepwise_training(False)
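Both p_hid2lat and q_hid2lat project to latent_dim * 2, the usual packing of a diagonal Gaussian's mean and log-variance. A generic sketch of how such an output is split and sampled with the reparameterization trick (an assumption about the convention, not the repo's code):

import torch

def sample_latent(lat_params):
    # lat_params: (..., latent_dim * 2) -> sample z = mean + std * eps
    mean, logvar = lat_params.chunk(2, dim=-1)
    std = torch.exp(0.5 * logvar)
    return mean + std * torch.randn_like(std)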
Example #6
    default="/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints/ebm.pt")
ap.add_argument(
    "--result_path",
    default="/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints/ebm.result"
)
OPTS.parse(ap)

OPTS.model_path = OPTS.model_path.replace(DATA_ROOT, OPTS.root)
OPTS.result_path = OPTS.result_path.replace(DATA_ROOT, OPTS.root)
result_dir = os.path.join(DATA_ROOT, "results")
if not os.path.exists(result_dir):
    os.mkdir(result_dir)
OPTS.result_path = "{}/{}.result".format(result_dir,
                                         short_tag(OPTS.result_tag))

if envswitch.who() == "shu":
    OPTS.model_path = os.path.join(DATA_ROOT,
                                   os.path.basename(OPTS.model_path))
    # OPTS.result_path = os.path.join(DATA_ROOT, os.path.basename(OPTS.result_path))
    OPTS.fixbug1 = True
    if OPTS.dtok != "iwslt16_deen":
        OPTS.fixbug2 = True
else:
    OPTS.model_path = os.path.join(HOME_DIR, "checkpoints", "ebm", OPTS.dtok,
                                   os.path.basename(OPTS.model_path))
    OPTS.result_path = os.path.join(HOME_DIR, "checkpoints", "ebm", OPTS.dtok,
                                    os.path.basename(OPTS.result_path))
    os.makedirs(os.path.dirname(OPTS.model_path), exist_ok=True)

# Determine the number of GPUs to use
horovod_installed = importlib.util.find_spec("horovod") is not None
Example #7
# Paths
ap.add_argument(
    "--model_path",
    default=
    "/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints_lvm/lanmt.pt")
ap.add_argument(
    "--result_path",
    default=
    "/misc/vlgscratch4/ChoGroup/jason/lanmt-ebm/checkpoints_lvm/lanmt.result")
OPTS.parse(ap)

OPTS.model_path = OPTS.model_path.replace(DATA_ROOT, OPTS.root)
OPTS.result_path = OPTS.result_path.replace(DATA_ROOT, OPTS.root)

if envswitch.who() == "shu":
    OPTS.model_path = os.path.join(DATA_ROOT,
                                   os.path.basename(OPTS.model_path))
    OPTS.result_path = os.path.join(DATA_ROOT,
                                    os.path.basename(OPTS.result_path))
    OPTS.fixbug1 = True
    OPTS.fixbug2 = True
else:
    OPTS.model_path = os.path.join(HOME_DIR, "checkpoints", "lvm", OPTS.dtok,
                                   os.path.basename(OPTS.model_path))
    OPTS.result_path = os.path.join(HOME_DIR, "checkpoints", "lvm", OPTS.dtok,
                                    os.path.basename(OPTS.result_path))
    os.makedirs(os.path.dirname(OPTS.model_path), exist_ok=True)
    #OPTS.fixbug2 = True

# Determine the number of GPUs to use
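OPTS is the project's global option holder, used in every snippet via OPTS.parse(ap) and plain attribute access. A minimal stand-in with the assumed behavior (not the real API):

import argparse

class Options:
    def parse(self, ap):
        # Expose every parsed argparse flag as an attribute, e.g. OPTS.model_path.
        for key, value in vars(ap.parse_args()).items():
            setattr(self, key, value)

OPTS = Options()

ap = argparse.ArgumentParser()
ap.add_argument("--model_path", default="ebm.pt")
OPTS.parse(ap)
print(OPTS.model_path)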