def __init__(self, schema, lexicon, model_path, fact_check, decoding, timed_session=False, consecutive_entity=True, realizer=None):
    super(NeuralSystem, self).__init__()
    self.schema = schema
    self.lexicon = lexicon
    self.timed_session = timed_session
    self.consecutive_entity = consecutive_entity

    # Load arguments
    args_path = os.path.join(model_path, 'config.json')
    config = read_json(args_path)
    config['batch_size'] = 1
    config['gpu'] = 0  # Don't need GPU for batch_size=1
    config['decoding'] = decoding
    args = argparse.Namespace(**config)

    mappings_path = os.path.join(model_path, 'vocab.pkl')
    mappings = read_pickle(mappings_path)
    vocab = mappings['vocab']

    # TODO: different models have the same key now
    args.dropout = 0
    logstats.add_args('model_args', args)
    model = build_model(schema, mappings, args)

    # Tensorflow config
    if args.gpu == 0:
        print 'GPU is disabled'
        config = tf.ConfigProto(device_count={'GPU': 0})
    else:
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5, allow_growth=True)
        config = tf.ConfigProto(device_count={'GPU': 1}, gpu_options=gpu_options)

    # NOTE: need to close the session when done
    tf_session = tf.Session(config=config)
    tf.initialize_all_variables().run(session=tf_session)

    # Load TF model parameters
    ckpt = tf.train.get_checkpoint_state(model_path + '-best')
    assert ckpt, 'No checkpoint found'
    assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'
    saver = tf.train.Saver()
    saver.restore(tf_session, ckpt.model_checkpoint_path)

    self.model_name = args.model
    if self.model_name == 'attn-copy-encdec':
        args.entity_target_form = 'graph'
        copy = True
    else:
        copy = False
    preprocessor = Preprocessor(schema, lexicon, args.entity_encoding_form, args.entity_decoding_form, args.entity_target_form, args.prepend)
    textint_map = TextIntMap(vocab, mappings['entity'], preprocessor)

    Env = namedtuple('Env', ['model', 'tf_session', 'preprocessor', 'vocab', 'copy', 'textint_map', 'stop_symbol', 'remove_symbols', 'max_len', 'evaluator', 'prepend', 'consecutive_entity', 'realizer'])
    self.env = Env(model, tf_session, preprocessor, mappings['vocab'], copy, textint_map,
            stop_symbol=vocab.to_ind(markers.EOS),
            remove_symbols=map(vocab.to_ind, (markers.EOS, markers.PAD)),
            max_len=20,
            evaluator=FactEvaluator() if fact_check else None,
            prepend=args.prepend,
            consecutive_entity=self.consecutive_entity,
            realizer=realizer)
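# Usage sketch (not from the source): constructing a NeuralSystem from a
# trained model directory. `load_schema_and_lexicon`, the paths, and the
# decoding spec are hypothetical stand-ins for illustration; only the
# NeuralSystem signature above is taken from the code.
#
# schema, lexicon = load_schema_and_lexicon('data/schema.json')  # hypothetical helper
# system = NeuralSystem(schema, lexicon, 'checkpoint',
#         fact_check=False, decoding=['sample', '0.5'])
#
# The constructor builds the model, opens a tf.Session, and restores weights
# from 'checkpoint-best'; remember to close the session when done.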
# Load a saved model when --init-from is given; otherwise save the config so
# a later run can rebuild the model with the same hyperparameters.
if args.init_from:
    start = time.time()
    config_path = os.path.join(args.init_from, 'config.json')
    vocab_path = os.path.join(args.init_from, 'vocab.pkl')
    saved_config = read_json(config_path)
    saved_config['decoding'] = args.decoding
    saved_config['batch_size'] = args.batch_size
    model_args = argparse.Namespace(**saved_config)

    # Checkpoint
    if args.test and args.best:
        ckpt = tf.train.get_checkpoint_state(args.init_from + '-best')
    else:
        ckpt = tf.train.get_checkpoint_state(args.init_from)
    assert ckpt, 'No checkpoint found'
    assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'

    # Load vocab
    mappings = read_pickle(vocab_path)
    print 'Done [%fs]' % (time.time() - start)
else:
    # Save config
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)
    config_path = os.path.join(args.checkpoint, 'config.json')
    write_json(vars(args), config_path)
    model_args = args
    mappings = None
    ckpt = None

schema = Schema(model_args.schema_path, model_args.domain)
scenario_db = ScenarioDB.from_dict(schema, read_json(args.scenarios_path))
dataset = read_dataset(scenario_db, args)
print 'Building lexicon...'
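# Self-contained sketch of the config round-trip used above (paths and keys
# are illustrative, not the project's API): training options are dumped to
# config.json next to the checkpoint, then reloaded and selectively
# overridden at test time so the model is rebuilt with the same
# hyperparameters.
import argparse, json, os

def save_config(args, checkpoint_dir):
    # Persist every argparse option for later reloading.
    with open(os.path.join(checkpoint_dir, 'config.json'), 'w') as f:
        json.dump(vars(args), f)

def load_config(checkpoint_dir, **overrides):
    # Restore saved options; runtime flags (e.g. batch_size) take precedence.
    with open(os.path.join(checkpoint_dir, 'config.json')) as f:
        config = json.load(f)
    config.update(overrides)
    return argparse.Namespace(**config)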
parser.add_argument('--stats', nargs='+', help='Paths to files containing the stats of a transcript')
parser.add_argument('--names', nargs='+', help='Names of systems corresponding to the stats files')
parser.add_argument('--output', default='.', help='Path to output figure')
parser.add_argument('--attr', default=False, action='store_true', help='Plot mentioned attributes')
parser.add_argument('--completion', default=False, action='store_true', help='Plot completion')
parser.add_argument('--ngram-freqs', default=False, action='store_true', help='Plot ngram frequencies')
parser.add_argument('--utterance-freqs', default=False, action='store_true', help='Plot utterance frequencies')
parser.add_argument('--act-freqs', default=False, action='store_true', help='Plot speech act frequencies')
args = parser.parse_args()

if args.ngram_freqs:
    stats = {}
    stats_files = args.stats
    for name, stats_file in izip(args.names, stats_files):
        stats[name] = read_pickle(stats_file)

    k = 10
    interval = 0.2
    bar_width = 0.2
    # Put 'Human' first so the other systems are compared against it
    names = ['Human'] + [x for x in args.names if x != 'Human']
    colors = ['b', 'g', 'r', 'y', 'c', 'm'][:len(names)]
    for n in xrange(1, 4):  # unigrams, bigrams, trigrams
        plt.cla()
        # Take the union of each system's top-k most frequent n-grams,
        # then keep the 15 most frequent by human counts
        ngrams = set()
        for name in names:
            sorted_words = sorted(stats[name][n].iteritems(), key=lambda x: x[1], reverse=True)
            ngrams.update([x[0] for x in sorted_words[:k]])
        ngrams = sorted(list(ngrams), key=lambda x: stats['Human'][n][x], reverse=True)[:15]
        label = [' '.join(x) if isinstance(x, tuple) else x for x in ngrams]
        pos = np.arange(len(ngrams))[::-1]
        for i, (name, color) in enumerate(izip(names, colors)):
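# The inner loop is truncated in this excerpt. A self-contained sketch of the
# grouped horizontal bar chart it appears to build, offsetting each system's
# bars by bar_width; the labels and counts are made-up illustration data.
import numpy as np
import matplotlib.pyplot as plt

labels = ['i have', 'do you', 'no friends']
counts = {'Human': [30, 25, 10], 'System': [20, 28, 15]}
bar_width = 0.2
pos = np.arange(len(labels))[::-1]
for i, (name, color) in enumerate(zip(('Human', 'System'), ('b', 'g'))):
    # One row of bars per system, shifted so the groups don't overlap
    plt.barh(pos + i * bar_width, counts[name], bar_width, color=color, label=name)
plt.yticks(pos + bar_width / 2, labels)  # center tick labels within each group
plt.legend()
plt.savefig('ngram_freqs.png')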