Code Example #1
File: train.py  Project: SigmaQuan/NMT-Coverage
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init=args.skip_init, compute_alignment=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state)
    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)
    logger.debug("Run training")
    main = MainLoop(train_data, None, None, lm_model, algo, state, None,
            reset=state['reset'],
            hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                if state['hookFreq'] >= 0
                else None)
    if state['reload']:
        main.load()
    if state['loopIters'] > 0:
        main.main()
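All of these train.py variants share the same configuration pattern: a prototype function from experiments.nmt returns a state dictionary, an optional state file overrides it, and each command-line change of the form key=value is folded in via eval("dict({})".format(change)). A minimal, self-contained sketch of that override step (the keys and change strings below are made up for illustration):

# Hypothetical illustration of the state-override mechanism used above.
state = {'seed': 1234, 'algo': 'SGD_adadelta', 'bs': 80, 'level': 'DEBUG'}
changes = ["bs=64", "prefix='search_model_'"]
for change in changes:
    state.update(eval("dict({})".format(change)))
# state['bs'] is now 64 and state['prefix'] is 'search_model_'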
Code Example #2
File: __init__.py  Project: krzwolk/GroundHog
def create_loop(state, skip_init=False):
    """TODO: Docstring for create_loop.

    :state: TODO
    :skip_init: TODO
    :returns: TODO

    """
    log.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state["seed"])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    log.debug("Load data")
    train_data = get_batch_iterator(state)
    log.debug("Compile trainer")
    algo = eval(state["algo"])(lm_model, state, train_data)
    log.debug("Run training")
    return MainLoop(
        train_data,
        None,
        None,
        lm_model,
        algo,
        state,
        None,
        reset=state["reset"],
        hooks=[RandomSamplePrinter(state, lm_model, train_data)] if state["hookFreq"] >= 0 else None,
    )
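A possible usage sketch for create_loop (hedged: the prototype name prototype_search_state and the overridden keys are illustrative assumptions, and the GroundHog experiments.nmt package must be importable):

# Illustrative usage only; the prototype name and overrides are assumptions.
import experiments.nmt

state = experiments.nmt.prototype_search_state()  # assumed prototype returning a state dict
state.update(dict(reload=False, loopIters=1000))  # hypothetical overrides
loop = create_loop(state, skip_init=False)
loop.main()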
Code Example #3
File: train.py  Project: ronuchit/GroundHog
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state)
    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)
    logger.debug("Run training")
    main = MainLoop(train_data, None, None, lm_model, algo, state, None,
            reset=state['reset'],
            hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                if state['hookFreq'] >= 0
                else None)
    if state['reload']:
        main.load()
    if state['loopIters'] > 0:
        main.main()
Code Example #4
File: loaddata_try.py  Project: syscomb/GroundHog
def main():
    args = parse_args()
    print 'syscomb'

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    '''
    rng = numpy.random.RandomState(state['seed'])
    if state['syscomb']:
        enc_dec = SystemCombination(state, rng, args.skip_init)
    else:
        enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    '''

    logger.debug("Load data")
    train_data = get_batch_iterator(state)
    train_data.start(-1)
    '''
Code Example #5
File: train.py  Project: syscomb/GroundHog
def main():
    args = parse_args()
    print 'syscomb'

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    if state['syscomb']:
        enc_dec = SystemCombination(state, rng, args.skip_init)
    else:
        enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    if state['algo'] == 'SGD_mrt':
        train_sampler = enc_dec.create_sampler(many_samples=True)
    lm_model = enc_dec.create_lm_model()
    print 'lm model inputs:', lm_model.inputs

    logger.debug("Load data")
    if state['syscomb']:
        train_data = get_batch_iterator_multi(state)
        sampler = RandomSamplePrinter_multi(state, lm_model, train_data, enc_dec)
    else:
        train_data = get_batch_iterator(state)
        sampler = RandomSamplePrinter(state, lm_model, train_data)
    
    logger.debug("Compile trainer")
    if state['algo'] == 'SGD_mrt':
        algo = eval(state['algo'])(lm_model, state, train_data, train_sampler)
    else:
        algo = eval(state['algo'])(lm_model, state, train_data)
    
    logger.debug("Run training")
    main = MainLoop(train_data, None, None, lm_model, algo, state, None,
            reset=state['reset'],
            hooks=[sampler]
                if state['hookFreq'] >= 0
                else None)
    if state['reload']:
        main.load()
    if state['loopIters'] > 0:
        main.main()
Code Example #6
File: train.py  Project: kelvinxu/GroundHog
def main():
    args = parse_args()

    # this loads the state specified in the prototype 
    state = getattr(experiments.nmt, args.proto)()
    # this is based on the suggestion in the README.md in this folder
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    # If we are going to use validation with the bleu script, we 
    # will need early stopping 
    bleu_validator = None
    if state['bleu_script'] is not None and state['validation_set'] is not None\
        and state['validation_set_grndtruth'] is not None:
        # make beam search       
        beam_search = BeamSearch(enc_dec)
        beam_search.compile()
        bleu_validator = BleuValidator(state, lm_model, beam_search, verbose=state['output_validation_set']) 

    logger.debug("Load data")
    train_data = get_batch_iterator(state)
    logger.debug("Compile trainer")

    algo = eval(state['algo'])(lm_model, state, train_data)
    logger.debug("Run training")
    
    main = MainLoop(train_data, None, None, lm_model, algo, state, None,
            reset=state['reset'],
            bleu_val_fn = bleu_validator, 
            hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                if state['hookFreq'] >= 0 and state['validation_set'] is not None
                else None)

    if state['reload']:
        main.load()
    if state['loopIters'] > 0:
        main.main()
Code Example #7
File: train.py  Project: rsennrich/LV_groundhog
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    if 'rolling_vocab' not in state:
        state['rolling_vocab'] = 0
    if 'save_algo' not in state:
        state['save_algo'] = 0
    if 'save_gs' not in state:
        state['save_gs'] = 0
    if 'fixed_embeddings' not in state:
        state['fixed_embeddings'] = False
    if 'save_iter' not in state:
        state['save_iter'] = -1
    if 'var_src_len' not in state:
        state['var_src_len'] = False
    if 'reprocess_each_iteration' not in state:
        state['reprocess_each_iteration'] = False

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state, rng)
    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)

    if state['rolling_vocab']:
        logger.debug("Initializing extra parameters")
        init_extra_parameters(lm_model, state)
        if not state['fixed_embeddings']:
            init_adadelta_extra_parameters(algo, state)
        with open(state['rolling_vocab_dict'], 'rb') as f:
            lm_model.rolling_vocab_dict = cPickle.load(f)
        lm_model.total_num_batches = max(lm_model.rolling_vocab_dict)
        lm_model.Dx_shelve = shelve.open(state['Dx_file'])
        lm_model.Dy_shelve = shelve.open(state['Dy_file'])

    hooks = []
    if state['hookFreq'] >= 0:
        hooks.append(RandomSamplePrinter(state, lm_model, train_data))
        if 'external_validation_script' in state and state['external_validation_script']:
            hooks.append(ExternalValidator(state, lm_model))

    logger.debug("Run training")
    main = MainLoop(train_data, None, None, lm_model, algo, state, None,
            reset=state['reset'],
            hooks= hooks)
    if state['reload']:
        main.load()
    if state['loopIters'] > 0:
        main.main()
    if state['rolling_vocab']:
        lm_model.Dx_shelve.close()
        lm_model.Dy_shelve.close()
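The block of "if '...' not in state:" checks near the top of this example backfills defaults for options that older pickled states may lack. A compact equivalent sketch using dict.setdefault (the keys and defaults are copied from the example; the starting state is made up):

# Equivalent backfilling of missing state keys with dict.setdefault.
state = {'save_iter': 3}  # e.g. a reloaded state missing most of these keys
defaults = {'rolling_vocab': 0, 'save_algo': 0, 'save_gs': 0,
            'fixed_embeddings': False, 'save_iter': -1,
            'var_src_len': False, 'reprocess_each_iteration': False}
for key, value in defaults.items():
    state.setdefault(key, value)
# state['save_iter'] stays 3; the other keys receive their defaults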
Code Example #8
File: mainLoop.py  Project: rsennrich/LV_groundhog
    def main(self):
        assert self.reset == -1

        print_mem("start")
        self.state["gotNaN"] = 0
        start_time = time.time()
        self.start_time = start_time
        self.batch_start_time = time.time()

        self.step = int(self.timings["step"])
        self.algo.step = self.step

        if self.state["save_iter"] < 0:
            self.save_iter = 0
            self.state["save_iter"] = 0
            self.save()
            if self.channel is not None:
                self.channel.save()
        else:  # Fake saving
            self.save_iter += 1
            self.state["save_iter"] = self.save_iter
        self.save_time = time.time()

        last_cost = 1.0
        self.state["clr"] = self.state["lr"]
        self.train_data.start(self.timings["next_offset"] if "next_offset" in self.timings else -1)

        if self.state["rolling_vocab"]:
            for i in xrange(self.timings["step"] - self.timings["super_step"]):
                self.train_data.next()

        if self.state["rolling_vocab"]:
            # Make sure dictionary is current.
            # If training is interrupted when the vocabularies are exchanged,
            # things may get broken.
            step_modulo = self.step % self.model.total_num_batches
            if step_modulo in self.model.rolling_vocab_dict:  # 0 always in.
                cur_key = step_modulo
            else:
                cur_key = 0
                for key in self.model.rolling_vocab_dict:
                    if (key < step_modulo) and (key > cur_key):  # Find largest key smaller than step_modulo
                        cur_key = key
            new_large2small_src = self.model.Dx_shelve[str(cur_key)]
            new_large2small_trgt = self.model.Dy_shelve[str(cur_key)]
            self.roll_vocab_update_dicts(new_large2small_src, new_large2small_trgt)
            self.zero_or_reload = True

        while (
            self.step < self.state["loopIters"]
            and last_cost > 0.1 * self.state["minerr"]
            and (time.time() - start_time) / 60.0 < self.state["timeStop"]
            and self.state["lr"] > self.state["minlr"]
        ):
            if self.step > 0 and (time.time() - self.save_time) / 60.0 >= self.state["saveFreq"]:
                self.save()
                if self.channel is not None:
                    self.channel.save()
                self.save_time = time.time()
            st = time.time()
            try:
                if self.state["rolling_vocab"]:
                    step_modulo = self.step % self.model.total_num_batches
                    if step_modulo in self.model.rolling_vocab_dict:
                        if not self.zero_or_reload:
                            self.roll_vocab_small2large()  # Not necessary for 0 or when reloading a properly saved model
                            new_large2small_src = self.model.Dx_shelve[str(step_modulo)]
                            new_large2small_trgt = self.model.Dy_shelve[str(step_modulo)]
                            self.roll_vocab_update_dicts(
                                new_large2small_src, new_large2small_trgt
                            )  # Done above for 0 or reloaded model
                        self.roll_vocab_large2small()
                        try:
                            tmp_batch = self.train_data.next(peek=True)
                        except StopIteration:
                            if self.state["reprocess_each_iteration"]:
                                logger.info("Reached end of file; re-preprocessing")
                                subprocess.check_call(self.state["reprocess_each_iteration"], shell=True)
                                if self.state["rolling_vocab"]:
                                    os.remove(self.state["Dx_file"])
                                    os.remove(self.state["Dy_file"])
                                    tmp_state = copy.deepcopy(self.state)
                                    rolling_dicts.main(tmp_state)
                                    with open(self.state["rolling_vocab_dict"], "rb") as f:
                                        self.model.rolling_vocab_dict = cPickle.load(f)
                                    self.model.total_num_batches = max(self.model.rolling_vocab_dict)
                                    self.model.Dx_shelve = shelve.open(self.state["Dx_file"])
                                    self.model.Dy_shelve = shelve.open(self.state["Dy_file"])
                                    # round up/down number of steps so modulo is 0 (hack because total_num_batches can change)
                                    logger.debug("step before restart: {0}".format(self.step))
                                    if self.step % self.model.total_num_batches < self.model.total_num_batches / 2:
                                        self.step -= self.step % self.model.total_num_batches
                                    else:
                                        self.step += self.model.total_num_batches - (
                                            self.step % self.model.total_num_batches
                                        )
                                    logger.debug("step after restart: {0}".format(self.step))

                                logger.debug("Load data")
                                self.train_data = get_batch_iterator(
                                    self.state, numpy.random.RandomState(self.state["seed"])
                                )
                                self.train_data.start(-1)
                                self.timings["next_offset"] = -1

                                step_modulo = self.step % self.model.total_num_batches
                                if step_modulo in self.model.rolling_vocab_dict:
                                    if not self.zero_or_reload:
                                        self.roll_vocab_small2large()  # Not necessary for 0 or when reloading a properly saved model
                                        new_large2small_src = self.model.Dx_shelve[str(step_modulo)]
                                        new_large2small_trgt = self.model.Dy_shelve[str(step_modulo)]
                                        self.roll_vocab_update_dicts(
                                            new_large2small_src, new_large2small_trgt
                                        )  # Done above for 0 or reloaded model
                                    self.roll_vocab_large2small()

                                self.algo.data = self.train_data
                                self.algo.step = self.step
                                tmp_batch = self.train_data.next(peek=True)

                                if self.hooks:
                                    self.hooks[0].train_iter = self.train_data
                            else:
                                self.save()
                                raise
                        if (
                            tmp_batch["x"][:, 0].tolist(),
                            tmp_batch["y"][:, 0].tolist(),
                        ) == self.model.rolling_vocab_dict[step_modulo]:
                            logger.debug("Identical first sentences. OK")
                        else:
                            logger.error("Batches do not correspond.")
                    elif self.state["hookFreq"] > 0 and self.step % self.state["hookFreq"] == 0 and self.hooks:
                        [fn() for fn in self.hooks]
                    # Hook first so that the peeked batch is the same as the one used in algo
                    # Use elif not to peek twice
                try:
                    rvals = self.algo()
                except StopIteration:
                    if self.state["reprocess_each_iteration"]:
                        logger.info("Reached end of file; re-preprocessing")
                        subprocess.check_call(self.state["reprocess_each_iteration"], shell=True)
                        logger.debug("Load data")
                        self.train_data = get_batch_iterator(self.state, numpy.random.RandomState(self.state["seed"]))
                        self.train_data.start(-1)
                        self.timings["next_offset"] = -1
                        self.algo.data = self.train_data
                        self.algo.step = self.step

                        rvals = self.algo()

                        if self.hooks:
                            self.hooks[0].train_iter = self.train_data
                    else:
                        self.save()
                        raise
                self.state["traincost"] = float(rvals["cost"])
                self.state["step"] = self.step
                last_cost = rvals["cost"]
                for name in rvals.keys():
                    self.timings[name][self.step] = float(numpy.array(rvals[name]))
                if self.l2_params:
                    for param in self.model.params:
                        self.timings["l2_" + param.name][self.step] = numpy.mean(param.get_value() ** 2) ** 0.5

                if (numpy.isinf(rvals["cost"]) or numpy.isnan(rvals["cost"])) and self.state["on_nan"] == "raise":
                    self.state["gotNaN"] = 1
                    self.save()
                    if self.channel:
                        self.channel.save()
                    print "Got NaN while training"
                    last_cost = 0
                if self.valid_data is not None and self.step % self.state["validFreq"] == 0 and self.step > 1:
                    valcost = self.validate()
                    if valcost > self.old_cost * self.state["cost_threshold"]:
                        self.patience -= 1
                        if "lr_start" in self.state and self.state["lr_start"] == "on_error":
                            self.state["lr_start"] = self.step
                    elif valcost < self.old_cost:
                        self.patience = self.state["patience"]
                        self.old_cost = valcost

                    if self.state["divide_lr"] and self.patience < 1:
                        # Divide lr by 2
                        self.algo.lr = self.algo.lr / self.state["divide_lr"]
                        bparams = dict(self.model.best_params)
                        self.patience = self.state["patience"]
                        for p in self.model.params:
                            p.set_value(bparams[p.name])

                if not self.state["rolling_vocab"]:  # Standard use of hooks
                    if self.state["hookFreq"] > 0 and self.step % self.state["hookFreq"] == 0 and self.hooks:
                        [fn() for fn in self.hooks]
                if self.reset > 0 and self.step > 1 and self.step % self.reset == 0:
                    print "Resetting the data iterator"
                    self.train_data.reset()

                self.step += 1
                if self.state["rolling_vocab"]:
                    self.zero_or_reload = False
                    self.timings["step"] = self.step  # Step now
                    if (self.step % self.model.total_num_batches) % self.state[
                        "sort_k_batches"
                    ] == 0:  # Start of a super_batch.
                        logger.debug("Set super_step and next_offset")
                        # This log should appear just before 'logger.debug("Start of a super batch")' in 'get_homogeneous_batch_iter()'
                        self.timings["super_step"] = self.step
                        # Step at start of superbatch. super_step < step
                        self.timings["next_offset"] = self.train_data.next_offset
                        # Where to start after reload. Will need to call next() a few times
                else:
                    self.timings["step"] = self.step
                    self.timings["next_offset"] = self.train_data.next_offset

            except KeyboardInterrupt:
                break

        if self.state["rolling_vocab"]:
            self.roll_vocab_small2large()
        self.state["wholetime"] = float(time.time() - start_time)
        if self.valid_data is not None:
            self.validate()
        self.save()
        if self.channel:
            self.channel.save()
        print "Took", (time.time() - start_time) / 60.0, "min"
        avg_step = self.timings["time_step"][: self.step].mean()
        avg_cost2expl = self.timings["log2_p_expl"][: self.step].mean()
        print "Average step took {}".format(avg_step)
        print "That amounts to {} sentences in a day".format(1 / avg_step * 86400 * self.state["bs"])
        print "Average log2 per example is {}".format(avg_cost2expl)
Code Example #9
File: score.py  Project: DmitryKey/GroundHog
def main():
    args = parse_args()

    state = prototype_state()
    with open(args.state) as src:
        state.update(cPickle.load(src))
    state.update(eval("dict({})".format(args.changes)))

    state['sort_k_batches'] = 1
    state['shuffle'] = False
    state['use_infinite_loop'] = False
    state['force_enc_repr_cpu'] = False

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init=True, compute_alignment=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    lm_model.load(args.model_path)

    indx_word_src = cPickle.load(open(state['word_indx'],'rb'))
    indx_word_trgt = cPickle.load(open(state['word_indx_trgt'], 'rb'))

    if args.mode == "batch":
        data_given = args.src or args.trg
        txt = data_given and not (args.src.endswith(".h5") and args.trg.endswith(".h5"))
        if data_given and not txt:
            state['source'] = [args.src]
            state['target'] = [args.trg]
        if not data_given and not txt:
            logger.info("Using the training data")
        if txt:
            data_iter = BatchBiTxtIterator(state,
                    args.src, indx_word_src, args.trg, indx_word_trgt,
                    state['bs'], raise_unk=not args.allow_unk)
            data_iter.start()
        else:
            data_iter = get_batch_iterator(state)
            data_iter.start(0)

        score_file = open(args.scores, "w") if args.scores else sys.stdout

        scorer = enc_dec.create_scorer(batch=True)

        count = 0
        n_samples = 0
        logger.info('Scoring phrases')
        for i, batch in enumerate(data_iter):
            if batch == None:
                continue
            if args.n_batches >= 0 and i == args.n_batches:
                break

            if args.y_noise:
                y = batch['y']
                random_words = numpy.random.randint(0, 100, y.shape).astype("int64")
                change_mask = numpy.random.binomial(1, args.y_noise, y.shape).astype("int64")
                y = change_mask * random_words + (1 - change_mask) * y
                batch['y'] = y

            st = time.time()
            [scores] = scorer(batch['x'], batch['y'],
                    batch['x_mask'], batch['y_mask'])
            if args.print_probs:
                scores = numpy.exp(scores)
            up_time = time.time() - st
            for s in scores:
                print >>score_file, "{:.5e}".format(float(s))

            n_samples += batch['x'].shape[1]
            count += 1

            if count % 100 == 0:
                score_file.flush()
                logger.debug("Scores flushed")
            logger.debug("{} batches, {} samples, {} per sample; example scores: {}".format(
                count, n_samples, up_time/scores.shape[0], scores[:5]))

        logger.info("Done")
        score_file.flush()
    elif args.mode == "interact":
        scorer = enc_dec.create_scorer()
        while True:
            try:
                compute_probs = enc_dec.create_probs_computer()
                src_line = raw_input('Source sequence: ')
                trgt_line = raw_input('Target sequence: ')
                src_seq = parse_input(state, indx_word_src, src_line, raise_unk=not args.allow_unk, 
                                      unk_sym=state['unk_sym_source'], null_sym=state['null_sym_source'])
                trgt_seq = parse_input(state, indx_word_trgt, trgt_line, raise_unk=not args.allow_unk,
                                       unk_sym=state['unk_sym_target'], null_sym=state['null_sym_target'])
                print "Binarized source: ", src_seq
                print "Binarized target: ", trgt_seq
                probs = compute_probs(src_seq, trgt_seq)
                print "Probs: {}, cost: {}".format(probs, -numpy.sum(numpy.log(probs)))
            except Exception:
                traceback.print_exc()
    elif args.mode == "txt":
        assert args.src and args.trg
        scorer = enc_dec.create_scorer()
        src_file = open(args.src, "r")
        trg_file = open(args.trg, "r")
        compute_probs = enc_dec.create_probs_computer(return_alignment=True)
        try:
            numpy.set_printoptions(precision=3, linewidth=150, suppress=True)
            i = 0
            while True:
                src_line = next(src_file).strip()
                trgt_line = next(trg_file).strip()
                src_seq, src_words = parse_input(state,
                        indx_word_src, src_line, raise_unk=not args.allow_unk,
                        unk_sym=state['unk_sym_source'], null_sym=state['null_sym_source'])
                trgt_seq, trgt_words = parse_input(state,
                        indx_word_trgt, trgt_line, raise_unk=not args.allow_unk,
                        unk_sym=state['unk_sym_target'], null_sym=state['null_sym_target'])
                probs, alignment = compute_probs(src_seq, trgt_seq)
                if args.verbose:
                    print "Probs: ", probs.flatten()
                    if alignment.ndim == 3:
                        print "Alignment:".ljust(20), src_line, "<eos>"
                        for i, word in enumerate(trgt_words):
                            print "{}{}".format(word.ljust(20), alignment[i, :, 0])
                        print "Generated by:"
                        for i, word in enumerate(trgt_words):
                            j = numpy.argmax(alignment[i, :, 0])
                            print "{} <--- {}".format(word,
                                    src_words[j] if j < len(src_words) else "<eos>")
                i += 1
                if i % 100 == 0:
                    sys.stdout.flush()
                    logger.debug(i)
                print -numpy.sum(numpy.log(probs))
        except StopIteration:
            pass
    else:
        raise Exception("Unknown mode {}".format(args.mode))
Code Example #10
File: score.py  Project: ihsgnef/Groundhog
def main():
    args = parse_args()

    state = prototype_state()
    with open(args.state) as src:
        state.update(cPickle.load(src))
    state.update(eval("dict({})".format(args.changes)))

    state['sort_k_batches'] = 1  # which means don't sort
    state['shuffle'] = False
    state['use_infinite_loop'] = False
    state['force_enc_repr_cpu'] = False

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state,
                                rng,
                                skip_init=True,
                                compute_alignment=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    lm_model.load(args.model_path)

    indx_word_src = cPickle.load(open(state['word_indx'], 'rb'))
    indx_word_trgt = cPickle.load(open(state['word_indx_trgt'], 'rb'))

    if args.mode == "batch":
        data_given = args.src or args.trg
        txt = data_given and not (args.src.endswith(".h5")
                                  and args.trg.endswith(".h5"))
        if data_given and not txt:
            state['source'] = [args.src]
            state['target'] = [args.trg]
        if not data_given and not txt:
            logger.info("Using the training data")
        if txt:
            data_iter = BatchBiTxtIterator(state,
                                           args.src,
                                           indx_word_src,
                                           args.trg,
                                           indx_word_trgt,
                                           state['bs'],
                                           raise_unk=not args.allow_unk)
            data_iter.start()
        else:
            data_iter = get_batch_iterator(state)
            data_iter.start(0)

        score_file = open(args.scores, "w") if args.scores else sys.stdout

        scorer = enc_dec.create_scorer(batch=True)

        count = 0
        n_samples = 0
        logger.info('Scoring phrases')
        for i, batch in enumerate(data_iter):
            if batch == None:
                continue
            if args.n_batches >= 0 and i == args.n_batches:
                break

            if args.y_noise:
                y = batch['y']
                random_words = numpy.random.randint(0, 100,
                                                    y.shape).astype("int64")
                change_mask = numpy.random.binomial(1, args.y_noise,
                                                    y.shape).astype("int64")
                y = change_mask * random_words + (1 - change_mask) * y
                batch['y'] = y

            st = time.time()
            [scores] = scorer(batch['x'], batch['y'], batch['x_mask'],
                              batch['y_mask'])
            if args.print_probs:
                scores = numpy.exp(scores)
            up_time = time.time() - st
            for s in scores:
                print >> score_file, "{:.5e}".format(float(s))

            n_samples += batch['x'].shape[1]
            count += 1

            if count % 100 == 0:
                score_file.flush()
                logger.debug("Scores flushed")
            logger.debug(
                "{} batches, {} samples, {} per sample; example scores: {}".
                format(count, n_samples, up_time / scores.shape[0],
                       scores[:5]))

        logger.info("Done")
        score_file.flush()
    elif args.mode == "interact":
        scorer = enc_dec.create_scorer()
        while True:
            try:
                compute_probs = enc_dec.create_probs_computer()
                src_line = raw_input('Source sequence: ')
                trgt_line = raw_input('Target sequence: ')
                src_seq = parse_input(state,
                                      indx_word_src,
                                      src_line,
                                      raise_unk=not args.allow_unk,
                                      unk_sym=state['unk_sym_source'],
                                      null_sym=state['null_sym_source'])
                trgt_seq = parse_input(state,
                                       indx_word_trgt,
                                       trgt_line,
                                       raise_unk=not args.allow_unk,
                                       unk_sym=state['unk_sym_target'],
                                       null_sym=state['null_sym_target'])
                print "Binarized source: ", src_seq
                print "Binarized target: ", trgt_seq
                probs = compute_probs(src_seq, trgt_seq)
                print "Probs: {}, cost: {}".format(
                    probs, -numpy.sum(numpy.log(probs)))
            except Exception:
                traceback.print_exc()
    elif args.mode == "txt":
        assert args.src and args.trg
        scorer = enc_dec.create_scorer()
        src_file = open(args.src, "r")
        trg_file = open(args.trg, "r")
        compute_probs = enc_dec.create_probs_computer(return_alignment=True)
        try:
            numpy.set_printoptions(precision=3, linewidth=150, suppress=True)
            i = 0
            while True:
                src_line = next(src_file).strip()
                trgt_line = next(trg_file).strip()
                src_seq, src_words = parse_input(
                    state,
                    indx_word_src,
                    src_line,
                    raise_unk=not args.allow_unk,
                    unk_sym=state['unk_sym_source'],
                    null_sym=state['null_sym_source'])
                trgt_seq, trgt_words = parse_input(
                    state,
                    indx_word_trgt,
                    trgt_line,
                    raise_unk=not args.allow_unk,
                    unk_sym=state['unk_sym_target'],
                    null_sym=state['null_sym_target'])
                probs, alignment = compute_probs(src_seq, trgt_seq)
                if args.verbose:
                    print "Probs: ", probs.flatten()
                    if alignment.ndim == 3:
                        print "Alignment:".ljust(20), src_line, "<eos>"
                        for i, word in enumerate(trgt_words):
                            print "{}{}".format(word.ljust(20), alignment[i, :,
                                                                          0])
                        print "Generated by:"
                        for i, word in enumerate(trgt_words):
                            j = numpy.argmax(alignment[i, :, 0])
                            print "{} <--- {}".format(
                                word, src_words[j]
                                if j < len(src_words) else "<eos>")
                i += 1
                if i % 100 == 0:
                    sys.stdout.flush()
                    logger.debug(i)
                print -numpy.sum(numpy.log(probs))
        except StopIteration:
            pass
    else:
        raise Exception("Unknown mode {}".format(args.mode))
Code Example #11
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    if 'var_src_len' not in state:
        state['var_src_len'] = False
    if 'partial_Dxy' not in state:
        state['partial_Dxy'] = False
    state['rolling_vocab'] = True
    state['use_infinite_loop'] = False
    logger.debug("rolling_vocab set to True, 'use_infinite_loop' set to False")
    rng = numpy.random.RandomState(state['seed'])

    logger.debug("Load data")
    train_data = get_batch_iterator(state, rng)
    train_data.start(-1)

    if not state.get('var_src_len'):
        dx = {}
        Dx = {}
        Cx = {}
    else:
        Dx = {}
    dy = {}
    Dy = {}
    Cy = {}

    if not state.get('var_src_len'):
        for i in xrange(state['n_sym_source']):
            Dx[i] = i
            Cx[i] = i
    for i in xrange(state['n_sym_target']):
        Dy[i] = i
        Cy[i] = i

    def update_dicts(arr, d, D, C, full):
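        # d: the small vocabulary being built for the current segment
        #    (large-vocab index -> small-vocab index).
        # D: every large-vocab word that currently owns a small-vocab slot.
        # C: slots inherited from the previous segment and not yet reused, so
        #    d and C partition D (d union C == D, d intersection C is empty).
        # A new word either steals a slot from C (popitem, evicting the old
        # owner from D) or, if it is still in C, keeps the slot it already had.
        # Returns True as soon as the small vocabulary of size `full` cannot
        # absorb the current batch, signalling that a vocabulary switch is due.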
        i_range, j_range = numpy.shape(arr)
        for i in xrange(i_range):
            for j in xrange(j_range):
                word = arr[i, j]
                if word not in d:
                    if len(d) == full:
                        return True
                    if word not in D:  # Also not in C
                        key, value = C.popitem()
                        del D[key]
                        d[word] = value
                        D[word] = value
                    else:  # Also in C as (d UNION C) is D. (d INTERSECTION C) is the empty set.
                        d[word] = D[word]
                        del C[word]
        return False

    def unlimited_update_dicts(arr, D, size):
        i_range, j_range = numpy.shape(arr)
        for i in xrange(i_range):
            for j in xrange(j_range):
                word = arr[i, j]
                if word not in D:
                    D[word] = size
                    size += 1

    prev_step = 0
    step = 0
    rolling_vocab_dict = {}
    Dx_dict = {}
    Dy_dict = {}

    output = False
    stop = False

    while not stop:  # Assumes the shuffling in get_homogeneous_batch_iter is always the same (Is this true?)
        try:
            batch = train_data.next()
            if step == 0:
                rolling_vocab_dict[step] = (batch['x'][:, 0].tolist(),
                                            batch['y'][:, 0].tolist())
        except:
            batch = None
            stop = True

        if batch:
            if not state.get('var_src_len'):
                output = update_dicts(batch['x'], dx, Dx, Cx,
                                      state['n_sym_source'])
                output += update_dicts(batch['y'], dy, Dy, Cy,
                                       state['n_sym_target'])
            else:
                unlimited_update_dicts(batch['x'], Dx, len(Dx))
                output = update_dicts(batch['y'], dy, Dy, Cy,
                                      state['n_sym_target'])

            if output:
                Dx_dict[prev_step] = Dx.copy()  # Save dictionaries for the batches preceding this one
                Dy_dict[prev_step] = Dy.copy()
                rolling_vocab_dict[step] = (batch['x'][:, 0].tolist(),
                                            batch['y'][:, 0].tolist())  # When we get to this batch, we will need to use a new vocabulary
                if state.get('partial_Dxy') and (step / 100000 -
                                                 prev_step / 100000):
                    print 'Updating Dx, Dy'
                    Dx_file = shelve.open(state['Dx_file'])
                    Dy_file = shelve.open(state['Dy_file'])
                    for key in Dx_dict:
                        Dx_file[str(key)] = Dx_dict[key]
                        Dy_file[str(key)] = Dy_dict[key]
                    Dx_file.close()
                    Dy_file.close()
                    Dx_dict = {}
                    Dy_dict = {}
                # tuple of first sentences of the batch # Uses large vocabulary indices
                prev_step = step
                if not state.get('var_src_len'):
                    print step
                    dx = {}
                    Cx = Dx.copy()
                else:
                    print step, len(Dx)
                    Dx = {}
                dy = {}
                Cy = Dy.copy()
                output = False

                if not state.get('var_src_len'):
                    update_dicts(
                        batch['x'], dx, Dx, Cx, state['n_sym_source']
                    )  # Assumes you cannot fill dx or dy with only 1 batch
                    update_dicts(batch['y'], dy, Dy, Cy, state['n_sym_target'])
                else:
                    unlimited_update_dicts(batch['x'], Dx, len(Dx))
                    update_dicts(batch['y'], dy, Dy, Cy, state['n_sym_target'])

            step += 1

    Dx_dict[prev_step] = Dx.copy()
    Dy_dict[prev_step] = Dy.copy()
    rolling_vocab_dict[step] = 0  # Total number of batches # Don't store first sentences here

    with open(state['rolling_vocab_dict'], 'w') as f:
        cPickle.dump(rolling_vocab_dict, f)
    Dx_file = shelve.open(state['Dx_file'])
    Dy_file = shelve.open(state['Dy_file'])
    for key in Dx_dict:
        Dx_file[str(key)] = Dx_dict[key]
        Dy_file[str(key)] = Dy_dict[key]
    Dx_file.close()
    Dy_file.close()
Code Example #12
File: test.py  Project: Glaceon31/GroundHog
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()

    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    if args.proto == 'prototype_ntm_state' or args.proto == 'prototype_ntmencdec_state':
        print 'Neural Turing Machine'
        enc_dec = NTMEncoderDecoder(state, rng, args.skip_init)
    else:
        enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    #s_enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    #s_lm_model = s_enc_dec.create_lm_model()
    
    logger.debug("Load data")
    train_data = get_batch_iterator(state)
    train_data.start(-1)
    logger.debug("Compile trainer")
    #algo = eval(state['algo'])(lm_model, state, train_data)

    #algo()
    #train
    print '---test training---'
    
    for i in range(1):

        batch = train_data.next()
        #print batch
        x = batch['x']
        print x.shape
        print x
        xs = x[:,78:79]
        xsample = x[:,78]
        print xs.shape
        print xs
        print xsample
        y = batch['y']
        ys = y[:,78:79]

        print y.shape
        x_mask = batch['x_mask']
        xs_mask = x_mask[:,78:79]
        y_mask = batch['y_mask']
        ys_mask = y_mask[:,78:79]
        if not (args.proto == 'prototype_ntm_state' or args.proto == 'prototype_ntmencdec_state'):
            print '---search---'
            train_outputs = enc_dec.forward_training.rvalss+[enc_dec.predictions.out]
            test_train = theano.function(inputs=[enc_dec.x, enc_dec.x_mask, enc_dec.y, enc_dec.y_mask],
                                        outputs=train_outputs)
            result = test_train(x,x_mask,y,y_mask)
            for i in result:
                print i.shape
        else:
            print '---ntm---'
            train_outputs = enc_dec.forward_training.rvalss+[
                        enc_dec.training_c.out,
                        enc_dec.forward_training_c.out,
                        enc_dec.forward_training_m.out,
                        enc_dec.forward_training_rw.out,
                        enc_dec.backward_training_c.out,
                        enc_dec.backward_training_m.out,
                        enc_dec.backward_training_rw.out,
                        ]
            train_outputs = enc_dec.forward_training.rvalss+[\
                        enc_dec.predictions.out,
                        enc_dec.training_c.out
                        ]
            test_train = theano.function(inputs=[enc_dec.x, enc_dec.x_mask, enc_dec.y, enc_dec.y_mask],
                                        outputs=train_outputs)
            result = test_train(x,x_mask,y,y_mask)
            for i in result:
                print i.shape
            #small batch test
            print '---small---'
            results = test_train(xs,xs_mask,ys,ys_mask)
            for i in results:
                print i.shape
            print '---compare---'
            #print result[1][:,4,:,:]
            #print results[1][:,0,:,:]
            print results[-1].shape
            print result[-1][:,78,:]-results[-1][:,0,:]
            print numpy.sum(result[-1][:,78,:]-results[-1][:,0,:])
            #print numpy.sum(result[0][:,4,:]-results[0][:,0,:])
            tmp = copy.deepcopy(result[-1][:,78,:])
            tmpm = copy.deepcopy(result[1][:,78,:,:])
            


    #sample
    #batch = train_data.next()
    #print batch
    print '---test sampling---'
    x = [7,152,429,731,10239,1127,747,480,30000]
    n_samples=10
    n_steps=10
    T=1
    inps = [enc_dec.sampling_x,
            enc_dec.n_samples,
            enc_dec.n_steps,
            enc_dec.T]
    #test_sample = theano.function(inputs=[enc_dec.sampling_x],
    #                            outputs=[enc_dec.sample])
    test_outputs = [enc_dec.sampling_c,
                    enc_dec.forward_sampling_c,
                    enc_dec.forward_sampling_m,
                    enc_dec.forward_sampling_rw,
                    enc_dec.backward_sampling_c,
                    enc_dec.backward_sampling_m,
                    enc_dec.backward_sampling_rw
                    ]
    test_outputs = enc_dec.forward_sampling.rvalss#+[enc_dec.sample,enc_dec.sample_log_prob,enc_dec.sampling_updates]
    #test_outputs = [enc_dec.sample,enc_dec.sample_log_prob]
    #sample_fn = theano.function(inputs=inps,outputs=test_outputs)
    sampler = enc_dec.create_sampler(many_samples=True)
    result = sampler(n_samples, n_steps,T,xsample)
    #print result
    print '---single repr---'
    
    c,m = enc_dec.create_representation_computer()(x)
    states = map(lambda x : x[None, :], enc_dec.create_initializers()(c))

    #print states
    print states[0].shape
    print m[-1:].shape
    '''
    next = enc_dec.create_next_states_computer(c, 0, inputs, m[-1:],*states)
    #print next[0]
    #print next[1]
    print next[0].shape
    print next[1].shape

    print c
    print m
    '''
    print '---repr compare---'
    print c.shape
    print m.shape
    print c-tmp[0:c.shape[0],:]
    print numpy.sum(c-tmp[0:c.shape[0],:],axis=1)
    print m-tmpm[0:m.shape[0],:,:]
    print numpy.sum(m-tmpm[0:m.shape[0],:,:],axis=1)

    return
Code Example #13
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    if 'rolling_vocab' not in state:
        state['rolling_vocab'] = 0
    if 'save_algo' not in state:
        state['save_algo'] = 0
    if 'save_gs' not in state:
        state['save_gs'] = 0
    if 'fixed_embeddings' not in state:
        state['fixed_embeddings'] = False
    if 'save_iter' not in state:
        state['save_iter'] = -1
    if 'var_src_len' not in state:
        state['var_src_len'] = False

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state, rng)
    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)

    if state['rolling_vocab']:
        logger.debug("Initializing extra parameters")
        init_extra_parameters(lm_model, state)
        if not state['fixed_embeddings']:
            init_adadelta_extra_parameters(algo, state)
        with open(state['rolling_vocab_dict'], 'rb') as f:
            lm_model.rolling_vocab_dict = cPickle.load(f)
        lm_model.total_num_batches = max(lm_model.rolling_vocab_dict)
        lm_model.Dx_shelve = shelve.open(state['Dx_file'])
        lm_model.Dy_shelve = shelve.open(state['Dy_file'])

    logger.debug("Run training")
    main = MainLoop(train_data,
                    None,
                    None,
                    lm_model,
                    algo,
                    state,
                    None,
                    reset=state['reset'],
                    hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                    if state['hookFreq'] >= 0 else None)
    if state['reload']:
        main.load()
    if state['loopIters'] > 0:
        main.main()
    if state['rolling_vocab']:
        lm_model.Dx_shelve.close()
        lm_model.Dy_shelve.close()
Code Example #14
File: rolling_dicts.py  Project: lvapeab/LV_groundhog
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    if 'var_src_len' not in state:
        state['var_src_len'] = False
    if 'partial_Dxy' not in state:
        state['partial_Dxy'] = False
    state['rolling_vocab'] = True
    state['use_infinite_loop'] = False
    logger.debug("rolling_vocab set to True, 'use_infinite_loop' set to False")
    rng = numpy.random.RandomState(state['seed'])

    logger.debug("Load data")
    train_data = get_batch_iterator(state, rng)
    train_data.start(-1)

    if not state.get('var_src_len'):
        dx = {}
        Dx = {}
        Cx = {}
    else:
        Dx = {}
    dy = {}
    Dy = {}
    Cy = {}

    if not state.get('var_src_len'):
        for i in xrange(state['n_sym_source']):
            Dx[i] = i
            Cx[i] = i
    for i in xrange(state['n_sym_target']):
        Dy[i] = i
        Cy[i] = i

    def update_dicts(arr, d, D, C, full):
        i_range, j_range = numpy.shape(arr)
        for i in xrange(i_range):
            for j in xrange(j_range):
                word = arr[i,j]
                if word not in d:
                    if len(d) == full:
                        return True
                    if word not in D: # Also not in C
                        key, value = C.popitem()
                        del D[key]
                        d[word] = value
                        D[word] = value
                    else: # Also in C as (d UNION C) is D. (d INTERSECTION C) is the empty set.
                        d[word] = D[word]
                        del C[word]
        return False

    def unlimited_update_dicts(arr, D, size):
        i_range, j_range = numpy.shape(arr)
        for i in xrange(i_range):
            for j in xrange(j_range):
                word = arr[i,j]
                if word not in D:
                    D[word] = size
                    size += 1

    prev_step = 0
    step = 0
    rolling_vocab_dict = {}
    Dx_dict = {}
    Dy_dict = {}

    output = False
    stop = False

    while not stop: # Assumes the shuffling in get_homogeneous_batch_iter is always the same (Is this true?)
        try:
            batch = train_data.next()
            if step == 0:
                rolling_vocab_dict[step] = (batch['x'][:,0].tolist(), batch['y'][:,0].tolist())
        except:
            batch = None
            stop = True

        if batch:
            if not state.get('var_src_len'):
                output = update_dicts(batch['x'], dx, Dx, Cx, state['n_sym_source'])
                output += update_dicts(batch['y'], dy, Dy, Cy, state['n_sym_target'])
            else:
                unlimited_update_dicts(batch['x'], Dx, len(Dx))
                output = update_dicts(batch['y'], dy, Dy, Cy, state['n_sym_target'])

            if output:
                Dx_dict[prev_step] = Dx.copy() # Save dictionaries for the batches preceding this one
                Dy_dict[prev_step] = Dy.copy()
                rolling_vocab_dict[step] = (batch['x'][:,0].tolist(), batch['y'][:,0].tolist()) # When we get to this batch, we will need to use a new vocabulary
                if state.get('partial_Dxy') and (step/100000 - prev_step/100000):
                    print 'Updating Dx, Dy'
                    Dx_file = shelve.open(state['Dx_file'])
                    Dy_file = shelve.open(state['Dy_file'])
                    for key in Dx_dict:
                        Dx_file[str(key)] = Dx_dict[key]
                        Dy_file[str(key)] = Dy_dict[key]
                    Dx_file.close()
                    Dy_file.close()
                    Dx_dict = {}
                    Dy_dict = {}
                # tuple of first sentences of the batch # Uses large vocabulary indices
                prev_step = step
                if not state.get('var_src_len'):
                    print step
                    dx = {}
                    Cx = Dx.copy()
                else:
                    print step, len(Dx)
                    Dx = {}
                dy = {}
                Cy = Dy.copy()
                output = False

                if not state.get('var_src_len'):
                    update_dicts(batch['x'], dx, Dx, Cx, state['n_sym_source']) # Assumes you cannot fill dx or dy with only 1 batch
                    update_dicts(batch['y'], dy, Dy, Cy, state['n_sym_target'])
                else:
                    unlimited_update_dicts(batch['x'], Dx, len(Dx))
                    update_dicts(batch['y'], dy, Dy, Cy, state['n_sym_target'])
            
            step += 1

    Dx_dict[prev_step] = Dx.copy()
    Dy_dict[prev_step] = Dy.copy()
    rolling_vocab_dict[step]=0 # Total number of batches # Don't store first sentences here

    with open(state['rolling_vocab_dict'],'w') as f:
        cPickle.dump(rolling_vocab_dict, f)
    Dx_file = shelve.open(state['Dx_file'])
    Dy_file = shelve.open(state['Dy_file'])
    for key in Dx_dict:
        Dx_file[str(key)] = Dx_dict[key]
        Dy_file[str(key)] = Dy_dict[key]
    Dx_file.close()
    Dy_file.close()
Code Example #15
File: train.py  Project: krzwolk/GroundHog
    state['word_indx_trgt'] = prel('vocab.lang2.pkl')

    update_custom_keys(state, conf, ['bs', 'loopIters', 'timeStop', 'dim',
                                     'null_sym_source', 'null_sym_target'])
    if conf['method'] == 'RNNenc-50':
        state['prefix'] = 'encdec-50_'
        state['seqlen'] = 50
        state['sort_k_batches'] = 20

    log.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, False)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    log.debug("Load data")
    train_data = get_batch_iterator(state)
    log.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)
    log.debug("Run training")
    main = MainLoop(train_data, None, None, lm_model, algo, state, None,
                    reset=state['reset'],
                    hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                    if state['hookFreq'] >= 0
                    else None)
    if state['reload']:
        main.load()
    if state['loopIters'] > 0:
        main.main()
Code Example #16
File: train.py  Project: Blues5/GroundHog
def main():
    args = parse_args()

    # this loads the state specified in the prototype
    state = getattr(experiments.nmt, args.proto)()
    # this is based on the suggestion in the README.md in this folder
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    # If we are going to use validation with the bleu script, we
    # will need early stopping
    bleu_validator = None
    if state['bleu_script'] is not None and state['validation_set'] is not None\
        and state['validation_set_grndtruth'] is not None:
        # make beam search
        beam_search = BeamSearch(enc_dec)
        beam_search.compile()
        bleu_validator = BleuValidator(state,
                                       lm_model,
                                       beam_search,
                                       verbose=state['output_validation_set'])

    logger.debug("Load data")
    train_data = get_batch_iterator(state)
    logger.debug("Compile trainer")

    algo = eval(state['algo'])(lm_model, state, train_data)
    logger.debug("Run training")

    main = MainLoop(train_data,
                    None,
                    None,
                    lm_model,
                    algo,
                    state,
                    None,
                    reset=state['reset'],
                    bleu_val_fn=bleu_validator,
                    hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                    if state['hookFreq'] >= 0
                    and state['validation_set'] is not None else None)

    if state['reload']:
        main.load()
    if state['loopIters'] > 0:
        main.main()