def decoder():
    # Configure the GPU
    use_gpu = True
    config_gpu(use_gpu=use_gpu)

    # Load the vocabulary and the word-embedding matrix
    vocab = Vocab(config.path_vocab, config.vocab_size)

    wvtool = WVTool(ndim=config.emb_dim)
    embedding_matrix = wvtool.load_embedding_matrix(path_embedding_matrix=config.path_embedding_matrix)

    # Build the model
    logger.info('Building Seq2Seq model ...')
    model = Seq2Seq(config.beam_size, embedding_matrix=embedding_matrix)


    # Checkpoint management
    ckpt = tf.train.Checkpoint(Seq2Seq=model)
    ckpt_manager = tf.train.CheckpointManager(checkpoint=ckpt, directory=config.dir_ckpt, max_to_keep=10)
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        logger.info('Decoder checkpoint restored from: {}'.format(ckpt_manager.latest_checkpoint))
    else:
        logger.info('No checkpoint available to restore')

    # Load the test data for decoding
    batcher = Batcher(config.path_seg_test, vocab, mode='decode',
                      batch_size=config.beam_size, single_pass=True)

    # Give the batcher's background threads time to fill their queues
    time.sleep(20)
    # Run decoding
    # Inputs: the batcher, the model, and the vocabulary
    batch_decode(batcher, model=model, vocab=vocab)
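The decoder above only restores weights; the training side writes them with the same Checkpoint/CheckpointManager pair. A minimal sketch of that counterpart, assuming the same model, config, and logger objects as in decoder():

import tensorflow as tf

ckpt = tf.train.Checkpoint(Seq2Seq=model)
ckpt_manager = tf.train.CheckpointManager(checkpoint=ckpt, directory=config.dir_ckpt, max_to_keep=10)

# Call periodically during training, e.g. once per epoch.
save_path = ckpt_manager.save()
logger.info('Checkpoint saved to: {}'.format(save_path))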
Example #2
    def __init__(self, args, model_name=None):
        self.args = args
        vocab = args.vocab_path if args.vocab_path is not None else config.vocab_path
        self.vocab = Vocab(vocab, config.vocab_size, config.embeddings_file,
                           args)
        self.train_batcher = Batcher(args.train_data_path,
                                     self.vocab,
                                     mode='train',
                                     batch_size=args.batch_size,
                                     single_pass=False,
                                     args=args)
        self.eval_batcher = Batcher(args.eval_data_path,
                                    self.vocab,
                                    mode='eval',
                                    batch_size=args.batch_size,
                                    single_pass=True,
                                    args=args)
        # Give the batchers' background threads time to fill their queues
        time.sleep(30)

        if model_name is None:
            self.train_dir = os.path.join(config.log_root,
                                          'train_%d' % (int(time.time())))
        else:
            self.train_dir = os.path.join(config.log_root, model_name)

        if not os.path.exists(self.train_dir):
            os.mkdir(self.train_dir)
        self.model_dir = os.path.join(self.train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)
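A note on the directory guards above: the check-then-mkdir pattern fails if config.log_root itself is missing and races if another process creates the directory in between. os.makedirs with exist_ok=True avoids both; a drop-in sketch with the same end state:

import os
os.makedirs(self.model_dir, exist_ok=True)  # also creates train_dir and any missing parents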
Example #3
    def __init__(self, par):
        file = par.cfg
        self.par = par
        self.cuda = False
        self.reuse_words = False
        self.cell = None
        self.num_layers = None
        self.bidirectional = None
        self.hidden_size = None
        self.emb_src_size = None
        self.emb_tgt_size = None
        self.attention = 'dot'
        self.coverage = None
        self.pointer = None
        self.opt_method = None         
        self.max_grad_norm = None
        self.n_iters_sofar = None

        with open(file, 'r') as stream: opts = yaml.safe_load(stream)
        for o,v in opts.items():
            if   o=="cuda":          self.cuda = bool(v) and torch.cuda.is_available()
            elif o=="cell":          self.cell = v.lower()
            elif o=="reuse_words":   self.reuse_words = bool(v)
            elif o=="num_layers":    self.num_layers = int(v)
            elif o=="bidirectional": self.bidirectional = bool(v)
            elif o=="hidden_size":   self.hidden_size = int(v)
            elif o=="emb_src_size":  self.emb_src_size = int(v)
            elif o=="emb_tgt_size":  self.emb_tgt_size = int(v)
            elif o=="attention":     self.attention = v
            elif o=="coverage":      self.coverage = bool(v)
            elif o=="pointer":       self.pointer = bool(v)
            elif o=="opt_method":    self.opt_method = v
            elif o=="max_grad_norm": self.max_grad_norm = float(v)
            else: sys.exit("error: unknown config option '{}'.".format(o))

        if self.par.voc_src is None: sys.exit('error: missing -voc_src option')
        if self.coverage and self.attention != 'concat': sys.exit('error: option coverage must be used with attention: \'concat\'')
        self.svoc = Vocab(self.par.voc_src)
        if self.reuse_words:
            self.tvoc = self.svoc
            self.emb_tgt_size = self.emb_src_size
        else:
            if self.par.voc_tgt is None: sys.exit('error: missing -voc_tgt option')
            self.tvoc = Vocab(self.par.voc_tgt)
        self.out()
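For reference, a config file this loader accepts might look like the following; the keys are exactly those handled in the parsing loop above, while the values are only illustrative:

cuda: true
cell: lstm
reuse_words: false
num_layers: 2
bidirectional: true
hidden_size: 512
emb_src_size: 256
emb_tgt_size: 256
attention: concat   # coverage requires 'concat' (see the check above)
coverage: true
pointer: true
opt_method: adam
max_grad_norm: 5.0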
Example #4
def main():

    vocab = Vocab(config.vocab_path, config.vocab_size)
    train_batcher = Batcher(config.train_data_path, vocab, mode='train', batch_size=config.batch_size, single_pass=False)
    eval_batcher = Batcher(config.eval_data_path, vocab, mode='train', batch_size=config.batch_size, single_pass=False)
    model = build_model(config)
    criterion = LabelSmoothing(config.vocab_size, train_batcher.pad_id, smoothing=0.1)

    if args.mode == 'train':
        train(config.max_iters, train_batcher, eval_batcher, model, criterion, config, args.save_path)
    elif args.mode == 'eval':
        eval(config, args.model)
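main() expects a parsed args object with at least mode, save_path, and model attributes; a minimal sketch of the corresponding parser (flag names inferred from the attribute accesses above, not taken from the source) could be:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'eval'], default='train')
parser.add_argument('--save_path', help='where train() stores checkpoints')
parser.add_argument('--model', help='checkpoint path passed to eval()')
args = parser.parse_args()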
Example #5
    def __init__(self, model_file_path):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.eval_data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=config.batch_size,
                               single_pass=True)
        # Give the batcher's background threads time to fill their queues
        time.sleep(15)
        model_name = os.path.basename(model_file_path)

        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)
        self.summary_writer = SummaryWriter(eval_dir)
        self.model = Model(model_file_path, is_eval=True)
Example #6
    def __init__(self, model):

        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % ("model2"))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')

        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=1,
                               single_pass=True)
        self.model = model
Example #7
    def __init__(self, args, model_file_path, save_path):
        model_name = os.path.basename(model_file_path)
        self.args = args
        self._decode_dir = os.path.join(config.log_root, save_path,
                                        'decode_%s' % (model_name))
        self._structures_dir = os.path.join(self._decode_dir, 'structures')
        self._sent_single_heads_dir = os.path.join(self._decode_dir,
                                                   'sent_heads_preds')
        self._sent_single_heads_ref_dir = os.path.join(self._decode_dir,
                                                       'sent_heads_ref')
        self._contsel_dir = os.path.join(self._decode_dir, 'content_sel_preds')
        self._contsel_ref_dir = os.path.join(self._decode_dir,
                                             'content_sel_ref')
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')

        self._rouge_ref_file = os.path.join(self._decode_dir, 'rouge_ref.json')
        self._rouge_pred_file = os.path.join(self._decode_dir,
                                             'rouge_pred.json')
        self.stat_res_file = os.path.join(self._decode_dir, 'stats.txt')
        self.sent_count_file = os.path.join(self._decode_dir,
                                            'sent_used_counts.txt')
        for p in [
                self._decode_dir, self._structures_dir,
                self._sent_single_heads_ref_dir, self._sent_single_heads_dir,
                self._contsel_ref_dir, self._contsel_dir, self._rouge_ref_dir,
                self._rouge_dec_dir
        ]:
            if not os.path.exists(p):
                os.mkdir(p)
        vocab = args.vocab_path if args.vocab_path is not None else config.vocab_path
        self.vocab = Vocab(vocab, config.vocab_size, config.embeddings_file,
                           args)
        self.batcher = Batcher(args.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=args.beam_size,
                               single_pass=True,
                               args=args)
        self.batcher.setup_queues()
        # Give the batcher's background threads time to fill their queues
        time.sleep(30)

        self.model = Model(args, self.vocab).to(device)
        self.model.eval()
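device is not defined in this snippet; a common module-level definition, assumed here, is:

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')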
Example #8
    # load examples
    logging.info("Loading data...")
    if dataset == "A":
        train = load_pickle("./data/SemEval/Task{0}/train.pkl".format(dataset))
        val = load_pickle("./data/SemEval/Task{0}/val.pkl".format(dataset))

    if test_mode:
        test = load_pickle("./data/SemEval/TaskA/test.pkl")
        train = merge_splits(train, val)
        val = test
    logging.info("Number of training examples: {0}".format(len(train)))
    logging.info("Number of validation examples: {0}".format(len(val)))
    for ex in train[0][:3]:
        logging.info("Examples: {0}".format(ex))
    logging.info("Building vocab...")
    vocab = Vocab(train, min_freq, max_vocab_size)
    vocab_size = len(vocab.word2id)
    logging.info("Vocab size: {0}".format(vocab_size))

    # build vocab and data
    # use pretrained word embedding
    logging.info("Loading word embedding from Magnitude...")
    home = os.path.expanduser("~")
    if embedding_size in [50, 100, 200]:
        vectors = Magnitude(
            os.path.join(
                home, "WordEmbedding/glove.twitter.27B.{0}d.magnitude".format(
                    embedding_size)))
    elif embedding_size in [300]:
        # vectors = Magnitude(os.path.join(home, "WordEmbedding/GoogleNews-vectors-negative{0}.magnitude".format(embedding_size)))
        # assumed completion (the call is truncated in the source), mirroring
        # the commented-out line above:
        vectors = Magnitude(
            os.path.join(
                home,
                "WordEmbedding/GoogleNews-vectors-negative{0}.magnitude".format(
                    embedding_size)))
Beispiel #9
0
        logging.info("RESUME TRAINING")

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window,
                      noise_dir=args.noise_dir,
                      noise_prob=args.noise_prob,
                      noise_levels=(args.noise_min, args.noise_max))

    logging.info(audio_conf)

    with open(args.labels_path, encoding="utf-8") as label_file:
        labels = json.load(label_file)
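    # labels_path is assumed to hold a JSON array with one output token per
    # entry, e.g. (illustrative only): ["_", "'", "a", "b", ..., "z", " "]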

    vocab = Vocab()
    for label in labels:
        vocab.add_token(label)
        vocab.add_label(label)

    train_data_list = []
    for i in range(len(args.train_manifest_list)):
        if args.feat == "spectrogram":
            train_data = SpectrogramDataset(
                vocab,
                args,
                audio_conf,
                # select the i-th manifest so each iteration loads one dataset
                manifest_filepath_list=[args.train_manifest_list[i]],
                normalize=True,
                augment=args.augment,
                input_type=args.input_type)
            # assumed continuation (the snippet is truncated in the source):
            train_data_list.append(train_data)