Example #1
def __init__(self, args, condition, c2b, nor):
    self.data = DataLM(os.path.join(args.root_path, args.data_path),
                       args.batch_size,
                       args.eval_batch_size,
                       condition=condition)
    self.c2b = c2b
    if nor:
        args.model_run = args.model_nor
    else:
        args.model_run = args.model_vmf
    self.args = load_args(args.exp_path, args.model_run)
    self.model = load_model(self.args, len(self.data.dictionary),
                            args.exp_path, args.model_run)
    self.learner = Code2Code(self.model.lat_dim, self.model.ninp)
    self.learner.cuda()
    self.optim = torch.optim.Adam(self.learner.parameters(), lr=0.001)
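
All of the examples on this page construct DataLM the same way: a joined root/data path, a training batch size, an evaluation batch size, and a condition flag. A standalone sketch of just that step, assuming the import path used in Example #6 and an argparse-style namespace (the concrete paths and sizes are placeholders):

import os
from argparse import Namespace

from NVLL.data.lm import DataLM  # import path taken from Example #6

args = Namespace(root_path='.', data_path='data/ptb',  # placeholder values
                 batch_size=20, eval_batch_size=10)
data = DataLM(os.path.join(args.root_path, args.data_path),
              args.batch_size,
              args.eval_batch_size,
              condition=False)
print(len(data.dictionary))  # vocabulary size, as passed to load_model above
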
Example #2
    def __init__(self, args):
        self.data = DataLM(os.path.join(args.root_path, args.data_path),
                           args.batch_size,
                           args.eval_batch_size,
                           condition=True)
        word_list = sorted(self.data.dictionary.word2idx.items(),
                           key=itemgetter(1))

        vmf_args = load_args(args.exp_path, args.model_vmf)
        vmf_model = load_model(vmf_args, len(self.data.dictionary),
                               args.exp_path, args.model_vmf)
        vmf_emb = vmf_model.emb.weight
        self.write_word_embedding(args.exp_path, args.model_vmf + '_emb',
                                  word_list, vmf_emb)
        nor_args = load_args(args.exp_path, args.model_nor)
        nor_model = load_model(nor_args, len(self.data.dictionary),
                               args.exp_path, args.model_nor)
        nor_emb = nor_model.emb.weight
        self.write_word_embedding(args.exp_path, args.model_nor + '_emb',
                                  word_list, nor_emb)
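
write_word_embedding itself is not part of the snippet. A plausible sketch, consistent with how it is called here (an experiment path, an output name, the (word, index) pairs sorted by index, and an embedding weight matrix); the plain-text output format is an assumption, not necessarily the project's actual format:

import os

def write_word_embedding(exp_path, fname, word_list, emb_weight):
    # word_list is [(word, idx), ...] sorted by idx, so row idx of the
    # embedding matrix corresponds to that word.
    emb = emb_weight.detach().cpu()
    with open(os.path.join(exp_path, fname), 'w') as fd:
        for word, idx in word_list:
            vec = ' '.join('{:.6f}'.format(v) for v in emb[idx].tolist())
            fd.write('{} {}\n'.format(word, vec))
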
Example #3

def load_data(data_path, eval_batch_size, condition):
    data = DataLM(data_path, eval_batch_size, eval_batch_size, condition)
    return data
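
A minimal call to the helper above, assuming a dataset directory laid out the way DataLM expects (the path is a placeholder); the batch shape follows Example #4:

data = load_data('data/ptb', eval_batch_size=10, condition=False)
for batch in data.test:
    seq_len, batch_sz = batch.size()  # batches are (seq_len, batch_size) tensors
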
Example #4
class CodeLearner():
    def __init__(self, args, condition, c2b, nor):
        self.data = DataLM(os.path.join(args.root_path, args.data_path),
                           args.batch_size,
                           args.eval_batch_size,
                           condition=condition)
        self.c2b = c2b
        if nor:
            args.model_run = args.model_nor
        else:
            args.model_run = args.model_vmf
        self.args = load_args(args.exp_path, args.model_run)
        self.model = load_model(self.args, len(self.data.dictionary),
                                args.exp_path, args.model_run)
        self.learner = Code2Code(self.model.lat_dim, self.model.ninp)
        self.learner.cuda()
        self.optim = torch.optim.Adam(self.learner.parameters(), lr=0.001)

    def run_train(self):
        valid_acc = []
        for e in range(10):
            print("EPO: {}".format(e))
            self.train_epo(self.data.train)
            acc = self.evaluate(self.data.test)
            valid_acc.append(acc)
        return min(valid_acc)

    def train_epo(self, train_batches):
        self.learner.train()
        print("Epo start")
        acc_loss = 0
        cnt = 0

        random.shuffle(train_batches)
        for idx, batch in enumerate(train_batches):
            self.optim.zero_grad()
            seq_len, batch_sz = batch.size()
            if self.data.condition:
                seq_len -= 1

                if self.model.input_cd_bit > 1:
                    bit = batch[0, :]
                    bit = GVar(bit)
                else:
                    bit = None
                batch = batch[1:, :]
            else:
                bit = None
            feed = self.data.get_feed(batch)

            seq_len, batch_sz = feed.size()
            emb = self.model.drop(self.model.emb(feed))

            if self.model.input_cd_bit > 1:
                bit = self.model.enc_bit(bit)
            else:
                bit = None

            h = self.model.forward_enc(emb, bit)
            tup, kld, vecs = self.model.forward_build_lat(
                h)  # batchsz, lat dim
            if self.model.dist_type == 'vmf':
                code = tup['mu']
            elif self.model.dist_type == 'nor':
                code = tup['mean']
            else:
                raise NotImplementedError
            emb = torch.mean(emb, dim=0)
            # NOTE: both branches call the learner with the same argument
            # order, so the c2b flag has no effect on this call.
            if self.c2b:
                loss = self.learner(code, emb)
            else:
                loss = self.learner(code, emb)
            loss.backward()
            self.optim.step()
            acc_loss += loss.data[0]
            cnt += 1
            if idx % 400 == 0 and (idx > 0):
                print("Training {}".format(acc_loss / cnt))
                acc_loss = 0
                cnt = 0

    def evaluate(self, dev_batches):
        self.learner.eval()
        print("Test start")
        acc_loss = 0
        cnt = 0
        random.shuffle(dev_batches)
        for idx, batch in enumerate(dev_batches):
            self.optim.zero_grad()
            seq_len, batch_sz = batch.size()
            if self.data.condition:
                seq_len -= 1

                if self.model.input_cd_bit > 1:
                    bit = batch[0, :]
                    bit = GVar(bit)
                else:
                    bit = None
                batch = batch[1:, :]
            else:
                bit = None
            feed = self.data.get_feed(batch)

            seq_len, batch_sz = feed.size()
            emb = self.model.drop(self.model.emb(feed))

            if self.model.input_cd_bit > 1:
                bit = self.model.enc_bit(bit)
            else:
                bit = None

            h = self.model.forward_enc(emb, bit)
            tup, kld, vecs = self.model.forward_build_lat(
                h)  # batchsz, lat dim
            if self.model.dist_type == 'vmf':
                code = tup['mu']
            elif self.model.dist_type == 'nor':
                code = tup['mean']
            else:
                raise NotImplementedError
            emb = torch.mean(emb, dim=0)
            # NOTE: as in train_epo, both branches are identical here.
            if self.c2b:
                loss = self.learner(code, emb)
            else:
                loss = self.learner(code, emb)
            acc_loss += loss.data[0]
            cnt += 1
            if idx % 400 == 0:
                acc_loss = 0
                cnt = 0
        # print("===============test===============")
        # print(acc_loss / cnt)
        print(acc_loss / cnt)
        return float(acc_loss / cnt)
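
A short driver for CodeLearner, assuming the project's argparse namespace carries the attributes the constructor reads (root_path, data_path, batch_size, eval_batch_size, exp_path, model_vmf, model_nor); every value below is a placeholder:

from argparse import Namespace

args = Namespace(root_path='.', data_path='data/ptb',
                 batch_size=20, eval_batch_size=10,
                 exp_path='exps', model_vmf='ptb_vmf', model_nor='ptb_nor')  # placeholders

learner = CodeLearner(args, condition=False, c2b=True, nor=False)  # nor=False selects the vMF model
best = learner.run_train()  # minimum of the per-epoch evaluation losses over 10 epochs
print(best)
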
Example #5
def synthesis_bow_rep(args):
    data = DataLM(os.path.join(args.root_path, args.data_path),
                  args.batch_size,
                  args.eval_batch_size,
                  condition=True)
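
Only the data-loading part of synthesis_bow_rep is shown. With condition=True the first row of each batch carries the condition bit and the remaining rows are token ids, as Example #4 demonstrates; a minimal sketch of consuming such batches (the names follow Example #4):

for batch in data.train:
    bit = batch[0, :]      # condition bit row (present because condition=True)
    tokens = batch[1:, :]  # remaining rows are the token ids
    feed = data.get_feed(tokens)
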
Example #6
def main():
    args = NVLL.argparser.parse_arg()

    if args.model == 'nvdm':
        set_seed(args)
        args, writer = set_save_name_log_nvdm(args)
        logging.info("Device Info: {}".format(device))
        print("Current dir {}".format(os.getcwd()))

        from NVLL.data.ng import DataNg
        # from NVLL.model.nvdm import BowVAE
        # For rcv use nvdm_v2?

        if args.data_name == '20ng':
            from NVLL.model.nvdm import BowVAE
        elif args.data_name == 'rcv':
            from NVLL.model.nvdm import BowVAE
        else:
            raise NotImplementedError
        from NVLL.framework.train_eval_nvdm import Runner
        # Datarcv_Distvmf_Modelnvdm_Emb400_Hid400_lat50
        data = DataNg(args)
        model = BowVAE(args,
                       vocab_size=data.vocab_size,
                       n_hidden=args.nhid,
                       n_lat=args.lat_dim,
                       n_sample=args.nsample,
                       dist=args.dist)
        # Automatic matching loading
        if args.load is not None:
            model.load_state_dict(torch.load(args.load), strict=False)
        else:
            print("Auto load temp closed.")
            """
            files = os.listdir(os.path.join(args.exp_path))
            files = [f for f in files if f.endswith(".model")]
            current_name = "Data{}_Dist{}_Model{}_Emb{}_Hid{}_lat{}".format(args.data_name, str(args.dist),
                                                                                      args.model,
            for f in files:
                if current_name in f:
                    try:
                        model.load_state_dict(torch.load(os.path.join(
                            args.exp_path, f)), strict=False)
                        print("Auto Load success! File Name: {}".format(f))
                        break
                    except RuntimeError:
                        print("Automatic Load failed!")
            """
        model.to(device)
        # if torch.cuda.is_available() and GPU_FLAG:
        #     print("Model in GPU")
        #     model = model.cuda()
        runner = Runner(args, model, data, writer)
        runner.start()
        runner.end()

    elif args.model == 'nvrnn':

        set_seed(args)
        args, writer = set_save_name_log_nvrnn(args)
        logging.info("Device Info: {}".format(device))
        print("Current dir {}".format(os.getcwd()))

        from NVLL.data.lm import DataLM
        from NVLL.model.nvrnn import RNNVAE
        from NVLL.framework.train_eval_nvrnn import Runner
        if args.data_name in ('ptb', 'trec', 'yelp_sent'):
            data = DataLM(os.path.join(args.root_path, args.data_path),
                          args.batch_size,
                          args.eval_batch_size,
                          condition=False)
        elif args.data_name == 'yelp':
            data = DataLM(os.path.join(args.root_path, args.data_path),
                          args.batch_size,
                          args.eval_batch_size,
                          condition=True)
        else:
            raise NotImplementedError
        model = RNNVAE(args,
                       args.enc_type,
                       len(data.dictionary),
                       args.emsize,
                       args.nhid,
                       args.lat_dim,
                       args.nlayers,
                       dropout=args.dropout,
                       tie_weights=False,
                       input_z=args.input_z,
                       mix_unk=args.mix_unk,
                       condition=(args.cd_bit or args.cd_bow),
                       input_cd_bow=args.cd_bow,
                       input_cd_bit=args.cd_bit)
        # Automatic matching loading
        if args.load is not None:
            model.load_state_dict(torch.load(args.load), strict=False)
        else:
            print("Auto load temp closed.")
            """
            files = os.listdir(os.path.join( args.exp_path))
            files = [f for f in files if f.endswith(".model")]
            current_name = "Data{}_Dist{}_Model{}_Enc{}Bi{}_Emb{}_Hid{}_lat{}".format(args.data_name, str(args.dist),
                                                                                      args.model, args.enc_type,
                                                                                      args.bi,
                                                                                      args.emsize,
                                                                                      args.nhid, args.lat_dim)
            for f in files:
                if current_name in f and ("mixunk0.0" in f):
                    try:
                        model.load_state_dict(torch.load(os.path.join(
                                                                      args.exp_path, f)), strict=False)
                        print("Auto Load success! {}".format(f))
                        break
                    except RuntimeError:
                        print("Automatic Load failed!")
            """
        model.to(device)
        # if torch.cuda.is_available() and GPU_FLAG:
        #     model = model.cuda()
        runner = Runner(args, model, data, writer)
        runner.start()
        runner.end()
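
main() pulls its whole configuration from NVLL.argparser.parse_arg(), so the usual entry point is just a guard; the command in the comment is only an assumption about how the flags might be spelled, based on the attribute names used above:

if __name__ == '__main__':
    # e.g. python main.py --model nvrnn --data_name ptb ...  (flag names assumed)
    main()
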
Example #7
def load_data(self, data_path):
    data = DataLM(data_path, self.args.batch_size,
                  self.args.eval_batch_size)
    return data