def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']),
                        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng,
                                skip_init=args.skip_init,
                                compute_alignment=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state)

    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)

    logger.debug("Run training")
    main_loop = MainLoop(train_data, None, None, lm_model, algo, state, None,
                         reset=state['reset'],
                         hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                         if state['hookFreq'] >= 0 else None)
    if state['reload']:
        main_loop.load()
    if state['loopIters'] > 0:
        main_loop.main()
def create_loop(state, skip_init=False):
    """Build the full training pipeline and return the main loop.

    :state: experiment configuration dictionary
    :skip_init: if True, skip parameter initialization (e.g. before a reload)
    :returns: a MainLoop ready to be run
    """
    log.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state["seed"])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    log.debug("Load data")
    train_data = get_batch_iterator(state)

    log.debug("Compile trainer")
    algo = eval(state["algo"])(lm_model, state, train_data)

    log.debug("Main loop created")
    return MainLoop(train_data, None, None, lm_model, algo, state, None,
                    reset=state["reset"],
                    hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                    if state["hookFreq"] >= 0 else None)
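# A minimal sketch of how create_loop might be driven from a script, kept as a
# comment so the module stays import-safe. "prototype_search_state" is assumed
# to be one of the prototypes exposed by experiments.nmt; substitute whichever
# prototype your experiment actually uses.
#
#   state = getattr(experiments.nmt, "prototype_search_state")()
#   loop = create_loop(state, skip_init=False)
#   if state["reload"]:
#       loop.load()
#   if state["loopIters"] > 0:
#       loop.main()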
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']),
                        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state)

    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)

    logger.debug("Run training")
    main_loop = MainLoop(train_data, None, None, lm_model, algo, state, None,
                         reset=state['reset'],
                         hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                         if state['hookFreq'] >= 0 else None)
    if state['reload']:
        main_loop.load()
    if state['loopIters'] > 0:
        main_loop.main()
def main():
    args = parse_args()
    print 'syscomb'

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']),
                        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    if state['syscomb']:
        enc_dec = SystemCombination(state, rng, args.skip_init)
    else:
        enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    if state['algo'] == 'SGD_mrt':
        train_sampler = enc_dec.create_sampler(many_samples=True)
    lm_model = enc_dec.create_lm_model()
    print 'lm model inputs:', lm_model.inputs

    logger.debug("Load data")
    if state['syscomb']:
        train_data = get_batch_iterator_multi(state)
        sampler = RandomSamplePrinter_multi(state, lm_model, train_data, enc_dec)
    else:
        train_data = get_batch_iterator(state)
        sampler = RandomSamplePrinter(state, lm_model, train_data)

    logger.debug("Compile trainer")
    if state['algo'] == 'SGD_mrt':
        algo = eval(state['algo'])(lm_model, state, train_data, train_sampler)
    else:
        algo = eval(state['algo'])(lm_model, state, train_data)

    logger.debug("Run training")
    main_loop = MainLoop(train_data, None, None, lm_model, algo, state, None,
                         reset=state['reset'],
                         hooks=[sampler] if state['hookFreq'] >= 0 else None)
    if state['reload']:
        main_loop.load()
    if state['loopIters'] > 0:
        main_loop.main()
def main():
    args = parse_args()

    # This loads the state specified in the prototype
    state = getattr(experiments.nmt, args.proto)()
    # This is based on the suggestion in the README.md in this folder
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']),
                        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    # If we are going to use validation with the BLEU script, we
    # will need early stopping
    bleu_validator = None
    if state['bleu_script'] is not None and state['validation_set'] is not None \
            and state['validation_set_grndtruth'] is not None:
        # Make beam search
        beam_search = BeamSearch(enc_dec)
        beam_search.compile()
        bleu_validator = BleuValidator(state, lm_model, beam_search,
                                       verbose=state['output_validation_set'])

    logger.debug("Load data")
    train_data = get_batch_iterator(state)

    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)

    logger.debug("Run training")
    main_loop = MainLoop(train_data, None, None, lm_model, algo, state, None,
                         reset=state['reset'],
                         bleu_val_fn=bleu_validator,
                         hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                         if state['hookFreq'] >= 0 and state['validation_set'] is not None
                         else None)
    if state['reload']:
        main_loop.load()
    if state['loopIters'] > 0:
        main_loop.main()
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']),
                        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    # Defaults for options that older state files may not contain
    if 'rolling_vocab' not in state:
        state['rolling_vocab'] = 0
    if 'save_algo' not in state:
        state['save_algo'] = 0
    if 'save_gs' not in state:
        state['save_gs'] = 0
    if 'fixed_embeddings' not in state:
        state['fixed_embeddings'] = False
    if 'save_iter' not in state:
        state['save_iter'] = -1
    if 'var_src_len' not in state:
        state['var_src_len'] = False
    if 'reprocess_each_iteration' not in state:
        state['reprocess_each_iteration'] = False

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state, rng)

    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)

    if state['rolling_vocab']:
        logger.debug("Initializing extra parameters")
        init_extra_parameters(lm_model, state)
        if not state['fixed_embeddings']:
            init_adadelta_extra_parameters(algo, state)
        with open(state['rolling_vocab_dict'], 'rb') as f:
            lm_model.rolling_vocab_dict = cPickle.load(f)
        lm_model.total_num_batches = max(lm_model.rolling_vocab_dict)
        lm_model.Dx_shelve = shelve.open(state['Dx_file'])
        lm_model.Dy_shelve = shelve.open(state['Dy_file'])

    hooks = []
    if state['hookFreq'] >= 0:
        hooks.append(RandomSamplePrinter(state, lm_model, train_data))
    if 'external_validation_script' in state and state['external_validation_script']:
        hooks.append(ExternalValidator(state, lm_model))

    logger.debug("Run training")
    main_loop = MainLoop(train_data, None, None, lm_model, algo, state, None,
                         reset=state['reset'], hooks=hooks)
    if state['reload']:
        main_loop.load()
    if state['loopIters'] > 0:
        main_loop.main()

    if state['rolling_vocab']:
        lm_model.Dx_shelve.close()
        lm_model.Dy_shelve.close()
def main(self):
    assert self.reset == -1

    print_mem("start")
    self.state["gotNaN"] = 0
    start_time = time.time()
    self.start_time = start_time
    self.batch_start_time = time.time()

    self.step = int(self.timings["step"])
    self.algo.step = self.step

    if self.state["save_iter"] < 0:
        self.save_iter = 0
        self.state["save_iter"] = 0
        self.save()
        if self.channel is not None:
            self.channel.save()
    else:
        # Fake saving
        self.save_iter += 1
        self.state["save_iter"] = self.save_iter
    self.save_time = time.time()

    last_cost = 1.0
    self.state["clr"] = self.state["lr"]
    self.train_data.start(self.timings["next_offset"]
                          if "next_offset" in self.timings else -1)
    if self.state["rolling_vocab"]:
        for i in xrange(self.timings["step"] - self.timings["super_step"]):
            self.train_data.next()

    if self.state["rolling_vocab"]:
        # Make sure the dictionary is current. If training is interrupted
        # while the vocabularies are being exchanged, things may get broken.
        step_modulo = self.step % self.model.total_num_batches
        if step_modulo in self.model.rolling_vocab_dict:  # 0 always in.
            cur_key = step_modulo
        else:
            # Find the largest key smaller than step_modulo
            cur_key = 0
            for key in self.model.rolling_vocab_dict:
                if (key < step_modulo) and (key > cur_key):
                    cur_key = key
        new_large2small_src = self.model.Dx_shelve[str(cur_key)]
        new_large2small_trgt = self.model.Dy_shelve[str(cur_key)]
        self.roll_vocab_update_dicts(new_large2small_src, new_large2small_trgt)
        self.zero_or_reload = True

    while (self.step < self.state["loopIters"]
           and last_cost > 0.1 * self.state["minerr"]
           and (time.time() - start_time) / 60.0 < self.state["timeStop"]
           and self.state["lr"] > self.state["minlr"]):
        if self.step > 0 and (time.time() - self.save_time) / 60.0 >= self.state["saveFreq"]:
            self.save()
            if self.channel is not None:
                self.channel.save()
            self.save_time = time.time()
        st = time.time()
        try:
            if self.state["rolling_vocab"]:
                step_modulo = self.step % self.model.total_num_batches
                if step_modulo in self.model.rolling_vocab_dict:
                    if not self.zero_or_reload:
                        # Not necessary for 0 or when reloading a properly saved model
                        self.roll_vocab_small2large()
                        new_large2small_src = self.model.Dx_shelve[str(step_modulo)]
                        new_large2small_trgt = self.model.Dy_shelve[str(step_modulo)]
                        # Done above for 0 or a reloaded model
                        self.roll_vocab_update_dicts(new_large2small_src,
                                                     new_large2small_trgt)
                    self.roll_vocab_large2small()
                    try:
                        tmp_batch = self.train_data.next(peek=True)
                    except StopIteration:
                        if self.state["reprocess_each_iteration"]:
                            logger.info("Reached end of file; re-preprocessing")
                            subprocess.check_call(self.state["reprocess_each_iteration"],
                                                  shell=True)
                            if self.state["rolling_vocab"]:
                                os.remove(self.state["Dx_file"])
                                os.remove(self.state["Dy_file"])
                                tmp_state = copy.deepcopy(self.state)
                                rolling_dicts.main(tmp_state)
                                with open(self.state["rolling_vocab_dict"], "rb") as f:
                                    self.model.rolling_vocab_dict = cPickle.load(f)
                                self.model.total_num_batches = max(self.model.rolling_vocab_dict)
                                self.model.Dx_shelve = shelve.open(self.state["Dx_file"])
                                self.model.Dy_shelve = shelve.open(self.state["Dy_file"])
                                # Round the number of steps up or down so the modulo
                                # is 0 (hack because total_num_batches can change)
                                logger.debug("step before restart: {0}".format(self.step))
                                if self.step % self.model.total_num_batches < self.model.total_num_batches / 2:
                                    self.step -= self.step % self.model.total_num_batches
                                else:
                                    self.step += self.model.total_num_batches - (
                                        self.step % self.model.total_num_batches)
                                logger.debug("step after restart: {0}".format(self.step))
                            logger.debug("Load data")
                            self.train_data = get_batch_iterator(
                                self.state, numpy.random.RandomState(self.state["seed"]))
                            self.train_data.start(-1)
                            self.timings["next_offset"] = -1
                            step_modulo = self.step % self.model.total_num_batches
                            if step_modulo in self.model.rolling_vocab_dict:
                                if not self.zero_or_reload:
                                    self.roll_vocab_small2large()
                                    new_large2small_src = self.model.Dx_shelve[str(step_modulo)]
                                    new_large2small_trgt = self.model.Dy_shelve[str(step_modulo)]
                                    self.roll_vocab_update_dicts(new_large2small_src,
                                                                 new_large2small_trgt)
                                self.roll_vocab_large2small()
                            self.algo.data = self.train_data
                            self.algo.step = self.step
                            tmp_batch = self.train_data.next(peek=True)
                            if self.hooks:
                                self.hooks[0].train_iter = self.train_data
                        else:
                            self.save()
                            raise
                    if (tmp_batch["x"][:, 0].tolist(),
                            tmp_batch["y"][:, 0].tolist()) == self.model.rolling_vocab_dict[step_modulo]:
                        logger.debug("Identical first sentences. OK")
                    else:
                        logger.error("Batches do not correspond.")
                # Hook first so that the peeked batch is the same as the one
                # used in algo. Use elif so as not to peek twice.
                elif self.state["hookFreq"] > 0 and self.step % self.state["hookFreq"] == 0 and self.hooks:
                    [fn() for fn in self.hooks]
            try:
                rvals = self.algo()
            except StopIteration:
                if self.state["reprocess_each_iteration"]:
                    logger.info("Reached end of file; re-preprocessing")
                    subprocess.check_call(self.state["reprocess_each_iteration"], shell=True)
                    logger.debug("Load data")
                    self.train_data = get_batch_iterator(
                        self.state, numpy.random.RandomState(self.state["seed"]))
                    self.train_data.start(-1)
                    self.timings["next_offset"] = -1
                    self.algo.data = self.train_data
                    self.algo.step = self.step
                    rvals = self.algo()
                    if self.hooks:
                        self.hooks[0].train_iter = self.train_data
                else:
                    self.save()
                    raise
            self.state["traincost"] = float(rvals["cost"])
            self.state["step"] = self.step
            last_cost = rvals["cost"]
            for name in rvals.keys():
                self.timings[name][self.step] = float(numpy.array(rvals[name]))
            if self.l2_params:
                for param in self.model.params:
                    self.timings["l2_" + param.name][self.step] = \
                        numpy.mean(param.get_value() ** 2) ** 0.5
            if (numpy.isinf(rvals["cost"]) or numpy.isnan(rvals["cost"])) \
                    and self.state["on_nan"] == "raise":
                self.state["gotNaN"] = 1
                self.save()
                if self.channel:
                    self.channel.save()
                print "Got NaN while training"
                last_cost = 0
            if self.valid_data is not None and self.step % self.state["validFreq"] == 0 and self.step > 1:
                valcost = self.validate()
                if valcost > self.old_cost * self.state["cost_threshold"]:
                    self.patience -= 1
                    if "lr_start" in self.state and self.state["lr_start"] == "on_error":
                        self.state["lr_start"] = self.step
                elif valcost < self.old_cost:
                    self.patience = self.state["patience"]
                    self.old_cost = valcost
                if self.state["divide_lr"] and self.patience < 1:
                    # Divide lr by the configured factor and restore the best
                    # parameters seen so far
                    self.algo.lr = self.algo.lr / self.state["divide_lr"]
                    bparams = dict(self.model.best_params)
                    self.patience = self.state["patience"]
                    for p in self.model.params:
                        p.set_value(bparams[p.name])
            if not self.state["rolling_vocab"]:  # Standard use of hooks
                if self.state["hookFreq"] > 0 and self.step % self.state["hookFreq"] == 0 and self.hooks:
                    [fn() for fn in self.hooks]
                if self.reset > 0 and self.step > 1 and self.step % self.reset == 0:
                    print "Resetting the data iterator"
                    self.train_data.reset()
            self.step += 1
            if self.state["rolling_vocab"]:
                self.zero_or_reload = False
                self.timings["step"] = self.step  # Step now
                if (self.step % self.model.total_num_batches) % self.state["sort_k_batches"] == 0:
                    # Start of a super_batch. This log should appear just before
                    # 'logger.debug("Start of a super batch")' in
                    # 'get_homogeneous_batch_iter()'
                    logger.debug("Set super_step and next_offset")
                    # Step at the start of the superbatch; super_step < step
                    self.timings["super_step"] = self.step
                    # Where to start after a reload. Will need to call next() a few times
                    self.timings["next_offset"] = self.train_data.next_offset
            else:
                self.timings["step"] = self.step
                self.timings["next_offset"] = self.train_data.next_offset
        except KeyboardInterrupt:
            break

    if self.state["rolling_vocab"]:
        self.roll_vocab_small2large()
    self.state["wholetime"] = float(time.time() - start_time)
    if self.valid_data is not None:
        self.validate()
    self.save()
    if self.channel:
        self.channel.save()
    print "Took", (time.time() - start_time) / 60.0, "min"
    avg_step = self.timings["time_step"][:self.step].mean()
    avg_cost2expl = self.timings["log2_p_expl"][:self.step].mean()
    print "Average step took {}".format(avg_step)
    print "That amounts to {} sentences in a day".format(1 / avg_step * 86400 * self.state["bs"])
    print "Average log2 per example is {}".format(avg_cost2expl)
def main():
    args = parse_args()

    state = prototype_state()
    with open(args.state) as src:
        state.update(cPickle.load(src))
    state.update(eval("dict({})".format(args.changes)))

    state['sort_k_batches'] = 1  # which means don't sort
    state['shuffle'] = False
    state['use_infinite_loop'] = False
    state['force_enc_repr_cpu'] = False

    logging.basicConfig(level=getattr(logging, state['level']),
                        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init=True, compute_alignment=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    lm_model.load(args.model_path)

    indx_word_src = cPickle.load(open(state['word_indx'], 'rb'))
    indx_word_trgt = cPickle.load(open(state['word_indx_trgt'], 'rb'))

    if args.mode == "batch":
        data_given = args.src or args.trg
        txt = data_given and not (args.src.endswith(".h5") and args.trg.endswith(".h5"))
        if data_given and not txt:
            state['source'] = [args.src]
            state['target'] = [args.trg]
        if not data_given and not txt:
            logger.info("Using the training data")
        if txt:
            data_iter = BatchBiTxtIterator(state,
                                           args.src, indx_word_src,
                                           args.trg, indx_word_trgt,
                                           state['bs'], raise_unk=not args.allow_unk)
            data_iter.start()
        else:
            data_iter = get_batch_iterator(state)
            data_iter.start(0)

        score_file = open(args.scores, "w") if args.scores else sys.stdout

        scorer = enc_dec.create_scorer(batch=True)

        count = 0
        n_samples = 0
        logger.info('Scoring phrases')
        for i, batch in enumerate(data_iter):
            if batch is None:
                continue
            if args.n_batches >= 0 and i == args.n_batches:
                break

            if args.y_noise:
                y = batch['y']
                random_words = numpy.random.randint(0, 100, y.shape).astype("int64")
                change_mask = numpy.random.binomial(1, args.y_noise, y.shape).astype("int64")
                y = change_mask * random_words + (1 - change_mask) * y
                batch['y'] = y

            st = time.time()
            [scores] = scorer(batch['x'], batch['y'], batch['x_mask'], batch['y_mask'])
            if args.print_probs:
                scores = numpy.exp(scores)
            up_time = time.time() - st
            for s in scores:
                print >>score_file, "{:.5e}".format(float(s))

            n_samples += batch['x'].shape[1]
            count += 1

            if count % 100 == 0:
                score_file.flush()
                logger.debug("Scores flushed")
                logger.debug("{} batches, {} samples, {} per sample; example scores: {}".format(
                    count, n_samples, up_time / scores.shape[0], scores[:5]))
        logger.info("Done")
        score_file.flush()
    elif args.mode == "interact":
        scorer = enc_dec.create_scorer()
        while True:
            try:
                compute_probs = enc_dec.create_probs_computer()
                src_line = raw_input('Source sequence: ')
                trgt_line = raw_input('Target sequence: ')
                src_seq = parse_input(state, indx_word_src, src_line,
                                      raise_unk=not args.allow_unk,
                                      unk_sym=state['unk_sym_source'],
                                      null_sym=state['null_sym_source'])
                trgt_seq = parse_input(state, indx_word_trgt, trgt_line,
                                       raise_unk=not args.allow_unk,
                                       unk_sym=state['unk_sym_target'],
                                       null_sym=state['null_sym_target'])
                print "Binarized source: ", src_seq
                print "Binarized target: ", trgt_seq
                probs = compute_probs(src_seq, trgt_seq)
                print "Probs: {}, cost: {}".format(probs, -numpy.sum(numpy.log(probs)))
            except Exception:
                traceback.print_exc()
    elif args.mode == "txt":
        assert args.src and args.trg
        scorer = enc_dec.create_scorer()
        src_file = open(args.src, "r")
        trg_file = open(args.trg, "r")
        compute_probs = enc_dec.create_probs_computer(return_alignment=True)
        try:
            numpy.set_printoptions(precision=3, linewidth=150, suppress=True)
            i = 0
            while True:
                src_line = next(src_file).strip()
                trgt_line = next(trg_file).strip()
                src_seq, src_words = parse_input(state, indx_word_src, src_line,
                                                 raise_unk=not args.allow_unk,
                                                 unk_sym=state['unk_sym_source'],
                                                 null_sym=state['null_sym_source'])
                trgt_seq, trgt_words = parse_input(state, indx_word_trgt, trgt_line,
                                                   raise_unk=not args.allow_unk,
                                                   unk_sym=state['unk_sym_target'],
                                                   null_sym=state['null_sym_target'])
                probs, alignment = compute_probs(src_seq, trgt_seq)
                if args.verbose:
                    print "Probs: ", probs.flatten()
                    if alignment.ndim == 3:
                        print "Alignment:".ljust(20), src_line, "<eos>"
                        # Use a separate loop variable so the sentence counter
                        # i is not clobbered
                        for t, word in enumerate(trgt_words):
                            print "{}{}".format(word.ljust(20), alignment[t, :, 0])
                        print "Generated by:"
                        for t, word in enumerate(trgt_words):
                            j = numpy.argmax(alignment[t, :, 0])
                            print "{} <--- {}".format(
                                word, src_words[j] if j < len(src_words) else "<eos>")
                i += 1
                if i % 100 == 0:
                    sys.stdout.flush()
                    logger.debug(i)
                print -numpy.sum(numpy.log(probs))
        except StopIteration:
            pass
    else:
        raise Exception("Unknown mode {}".format(args.mode))
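# The "txt" mode above prints -numpy.sum(numpy.log(probs)), the sentence
# negative log-likelihood. A hypothetical helper (not part of the original
# script) that turns the same probs vector into a per-word perplexity:
import numpy

def perplexity(probs):
    """Per-word perplexity from the per-word probabilities of one sentence."""
    return float(numpy.exp(-numpy.mean(numpy.log(probs))))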
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']),
                        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    if 'var_src_len' not in state:
        state['var_src_len'] = False
    if 'partial_Dxy' not in state:
        state['partial_Dxy'] = False
    state['rolling_vocab'] = True
    state['use_infinite_loop'] = False
    logger.debug("rolling_vocab set to True, use_infinite_loop set to False")

    rng = numpy.random.RandomState(state['seed'])

    logger.debug("Load data")
    train_data = get_batch_iterator(state, rng)
    train_data.start(-1)

    if not state.get('var_src_len'):
        dx = {}
        Dx = {}
        Cx = {}
    else:
        Dx = {}
    dy = {}
    Dy = {}
    Cy = {}

    if not state.get('var_src_len'):
        for i in xrange(state['n_sym_source']):
            Dx[i] = i
            Cx[i] = i
    for i in xrange(state['n_sym_target']):
        Dy[i] = i
        Cy[i] = i

    def update_dicts(arr, d, D, C, full):
        i_range, j_range = numpy.shape(arr)
        for i in xrange(i_range):
            for j in xrange(j_range):
                word = arr[i, j]
                if word not in d:
                    if len(d) == full:
                        return True
                    if word not in D:  # Also not in C
                        key, value = C.popitem()
                        del D[key]
                        d[word] = value
                        D[word] = value
                    else:  # Also in C, as (d UNION C) is D and (d INTERSECTION C) is empty
                        d[word] = D[word]
                        del C[word]
        return False

    def unlimited_update_dicts(arr, D, size):
        i_range, j_range = numpy.shape(arr)
        for i in xrange(i_range):
            for j in xrange(j_range):
                word = arr[i, j]
                if word not in D:
                    D[word] = size
                    size += 1

    prev_step = 0
    step = 0
    rolling_vocab_dict = {}
    Dx_dict = {}
    Dy_dict = {}
    output = False
    stop = False

    # Assumes the shuffling in get_homogeneous_batch_iter is always the same
    # (Is this true?)
    while not stop:
        try:
            batch = train_data.next()
            if step == 0:
                rolling_vocab_dict[step] = (batch['x'][:, 0].tolist(),
                                            batch['y'][:, 0].tolist())
        except StopIteration:
            batch = None
            stop = True
        if batch:
            if not state.get('var_src_len'):
                output = update_dicts(batch['x'], dx, Dx, Cx, state['n_sym_source'])
                output += update_dicts(batch['y'], dy, Dy, Cy, state['n_sym_target'])
            else:
                unlimited_update_dicts(batch['x'], Dx, len(Dx))
                output = update_dicts(batch['y'], dy, Dy, Cy, state['n_sym_target'])
            if output:
                # Save dictionaries for the batches preceding this one
                Dx_dict[prev_step] = Dx.copy()
                Dy_dict[prev_step] = Dy.copy()
                # When we get to this batch, we will need to use a new vocabulary.
                # Stores a tuple of the first sentences of the batch,
                # using large-vocabulary indices.
                rolling_vocab_dict[step] = (batch['x'][:, 0].tolist(),
                                            batch['y'][:, 0].tolist())
                # Flush to the shelves whenever a 100000-step boundary is
                # crossed (integer division)
                if state.get('partial_Dxy') and (step / 100000 - prev_step / 100000):
                    print 'Updating Dx, Dy'
                    Dx_file = shelve.open(state['Dx_file'])
                    Dy_file = shelve.open(state['Dy_file'])
                    for key in Dx_dict:
                        Dx_file[str(key)] = Dx_dict[key]
                        Dy_file[str(key)] = Dy_dict[key]
                    Dx_file.close()
                    Dy_file.close()
                    Dx_dict = {}
                    Dy_dict = {}
                prev_step = step
                if not state.get('var_src_len'):
                    print step
                    dx = {}
                    Cx = Dx.copy()
                else:
                    print step, len(Dx)
                    Dx = {}
                dy = {}
                Cy = Dy.copy()
                output = False
                if not state.get('var_src_len'):
                    # Assumes you cannot fill dx or dy with only 1 batch
                    update_dicts(batch['x'], dx, Dx, Cx, state['n_sym_source'])
                    update_dicts(batch['y'], dy, Dy, Cy, state['n_sym_target'])
                else:
                    unlimited_update_dicts(batch['x'], Dx, len(Dx))
                    update_dicts(batch['y'], dy, Dy, Cy, state['n_sym_target'])
            step += 1

    Dx_dict[prev_step] = Dx.copy()
    Dy_dict[prev_step] = Dy.copy()
    # Total number of batches. Don't store first sentences here
    rolling_vocab_dict[step] = 0

    with open(state['rolling_vocab_dict'], 'w') as f:
        cPickle.dump(rolling_vocab_dict, f)

    Dx_file = shelve.open(state['Dx_file'])
    Dy_file = shelve.open(state['Dy_file'])
    for key in Dx_dict:
        Dx_file[str(key)] = Dx_dict[key]
        Dy_file[str(key)] = Dy_dict[key]
    Dx_file.close()
    Dy_file.close()
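# A self-contained toy of the index-recycling scheme implemented by
# update_dicts above (hypothetical helper and numbers; a small-vocab budget of
# 3 indices instead of n_sym_source/n_sym_target). D maps large-vocabulary ids
# to small-vocabulary indices, C holds the entries that may still be evicted,
# and d records the words seen in the current segment.
import numpy

def toy_update(arr, d, D, C, full):
    for word in arr.flatten():
        if word not in d:
            if len(d) == full:
                return True  # small vocabulary exhausted: start a new segment
            if word not in D:  # unseen word: recycle an index popped from C
                key, value = C.popitem()
                del D[key]
                d[word] = value
                D[word] = value
            else:  # already in D, hence still in C
                d[word] = D[word]
                del C[word]
    return False

d, D, C = {}, {0: 0, 1: 1, 2: 2}, {0: 0, 1: 1, 2: 2}
print toy_update(numpy.array([[0, 7]]), d, D, C, 3)  # False: still room
print D  # word 7 now owns an index recycled via C, e.g. {0: 0, 2: 2, 7: 1}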
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']),
                        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    if args.proto == 'prototype_ntm_state' or args.proto == 'prototype_ntmencdec_state':
        print 'Neural Turing Machine'
        enc_dec = NTMEncoderDecoder(state, rng, args.skip_init)
    else:
        enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    #s_enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    #s_lm_model = s_enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state)
    train_data.start(-1)

    logger.debug("Compile trainer")
    #algo = eval(state['algo'])(lm_model, state, train_data)
    #algo()  # train

    print '---test training---'
    for i in range(1):
        batch = train_data.next()
    #print batch
    x = batch['x']
    print x.shape
    print x
    xs = x[:, 78:79]
    xsample = x[:, 78]
    print xs.shape
    print xs
    print xsample
    y = batch['y']
    ys = y[:, 78:79]
    print y.shape
    x_mask = batch['x_mask']
    xs_mask = x_mask[:, 78:79]
    y_mask = batch['y_mask']
    ys_mask = y_mask[:, 78:79]

    if not (args.proto == 'prototype_ntm_state' or args.proto == 'prototype_ntmencdec_state'):
        print '---search---'
        train_outputs = enc_dec.forward_training.rvalss + [enc_dec.predictions.out]
        test_train = theano.function(
            inputs=[enc_dec.x, enc_dec.x_mask, enc_dec.y, enc_dec.y_mask],
            outputs=train_outputs)
        result = test_train(x, x_mask, y, y_mask)
        for i in result:
            print i.shape
    else:
        print '---ntm---'
        train_outputs = enc_dec.forward_training.rvalss + [
            enc_dec.training_c.out,
            enc_dec.forward_training_c.out,
            enc_dec.forward_training_m.out,
            enc_dec.forward_training_rw.out,
            enc_dec.backward_training_c.out,
            enc_dec.backward_training_m.out,
            enc_dec.backward_training_rw.out,
        ]
        train_outputs = enc_dec.forward_training.rvalss + [
            enc_dec.predictions.out,
            enc_dec.training_c.out,
        ]
        test_train = theano.function(
            inputs=[enc_dec.x, enc_dec.x_mask, enc_dec.y, enc_dec.y_mask],
            outputs=train_outputs)
        result = test_train(x, x_mask, y, y_mask)
        for i in result:
            print i.shape

    # Small batch test
    print '---small---'
    results = test_train(xs, xs_mask, ys, ys_mask)
    for i in results:
        print i.shape

    print '---compare---'
    #print result[1][:, 4, :, :]
    #print results[1][:, 0, :, :]
    print results[-1].shape
    print result[-1][:, 78, :] - results[-1][:, 0, :]
    print numpy.sum(result[-1][:, 78, :] - results[-1][:, 0, :])
    #print numpy.sum(result[0][:, 4, :] - results[0][:, 0, :])
    tmp = copy.deepcopy(result[-1][:, 78, :])
    tmpm = copy.deepcopy(result[1][:, 78, :, :])

    # Sample
    #batch = train_data.next()
    #print batch
    print '---test sampling---'
    x = [7, 152, 429, 731, 10239, 1127, 747, 480, 30000]
    n_samples = 10
    n_steps = 10
    T = 1
    inps = [enc_dec.sampling_x, enc_dec.n_samples, enc_dec.n_steps, enc_dec.T]
    #test_sample = theano.function(inputs=[enc_dec.sampling_x],
    #                              outputs=[enc_dec.sample])
    test_outputs = [enc_dec.sampling_c,
                    enc_dec.forward_sampling_c,
                    enc_dec.forward_sampling_m,
                    enc_dec.forward_sampling_rw,
                    enc_dec.backward_sampling_c,
                    enc_dec.backward_sampling_m,
                    enc_dec.backward_sampling_rw]
    test_outputs = enc_dec.forward_sampling.rvalss
    #test_outputs += [enc_dec.sample, enc_dec.sample_log_prob, enc_dec.sampling_updates]
    #test_outputs = [enc_dec.sample, enc_dec.sample_log_prob]
    #sample_fn = theano.function(inputs=inps, outputs=test_outputs)
    sampler = enc_dec.create_sampler(many_samples=True)
    result = sampler(n_samples, n_steps, T, xsample)
    #print result

    print '---single repr---'
    c, m = enc_dec.create_representation_computer()(x)
    states = map(lambda s: s[None, :], enc_dec.create_initializers()(c))
    #print states
    print states[0].shape
    print m[-1:].shape
    '''
    next = enc_dec.create_next_states_computer(c, 0, inputs, m[-1:], *states)
    #print next[0]
    #print next[1]
    print next[0].shape
    print next[1].shape
    print c
    print m
    '''

    print '---repr compare---'
    print c.shape
    print m.shape
    print c - tmp[0:c.shape[0], :]
    print numpy.sum(c - tmp[0:c.shape[0], :], axis=1)
    print m - tmpm[0:m.shape[0], :, :]
    print numpy.sum(m - tmpm[0:m.shape[0], :, :], axis=1)
    return
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']),
                        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    if 'rolling_vocab' not in state:
        state['rolling_vocab'] = 0
    if 'save_algo' not in state:
        state['save_algo'] = 0
    if 'save_gs' not in state:
        state['save_gs'] = 0
    if 'fixed_embeddings' not in state:
        state['fixed_embeddings'] = False
    if 'save_iter' not in state:
        state['save_iter'] = -1
    if 'var_src_len' not in state:
        state['var_src_len'] = False

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state, rng)

    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)

    if state['rolling_vocab']:
        logger.debug("Initializing extra parameters")
        init_extra_parameters(lm_model, state)
        if not state['fixed_embeddings']:
            init_adadelta_extra_parameters(algo, state)
        with open(state['rolling_vocab_dict'], 'rb') as f:
            lm_model.rolling_vocab_dict = cPickle.load(f)
        lm_model.total_num_batches = max(lm_model.rolling_vocab_dict)
        lm_model.Dx_shelve = shelve.open(state['Dx_file'])
        lm_model.Dy_shelve = shelve.open(state['Dy_file'])

    logger.debug("Run training")
    main_loop = MainLoop(train_data, None, None, lm_model, algo, state, None,
                         reset=state['reset'],
                         hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                         if state['hookFreq'] >= 0 else None)
    if state['reload']:
        main_loop.load()
    if state['loopIters'] > 0:
        main_loop.main()

    if state['rolling_vocab']:
        lm_model.Dx_shelve.close()
        lm_model.Dy_shelve.close()
    state['word_indx_trgt'] = prel('vocab.lang2.pkl')
    update_custom_keys(state, conf,
                       ['bs', 'loopIters', 'timeStop', 'dim',
                        'null_sym_source', 'null_sym_target'])

    if conf['method'] == 'RNNenc-50':
        state['prefix'] = 'encdec-50_'
        state['seqlen'] = 50
        state['sort_k_batches'] = 20

    log.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, False)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    log.debug("Load data")
    train_data = get_batch_iterator(state)

    log.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)

    log.debug("Run training")
    main_loop = MainLoop(train_data, None, None, lm_model, algo, state, None,
                         reset=state['reset'],
                         hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                         if state['hookFreq'] >= 0 else None)
    if state['reload']:
        main_loop.load()
    if state['loopIters'] > 0:
        main_loop.main()