def run(self):
  """Train this individual (unless already trained) and store its cost."""
  if self.individual.cost is not None:
    return self.individual.cost
  start_time = time.time()
  hyper_param_mapping = self.individual.hyper_param_mapping
  print("Training %r using hyper params:" % self.individual.name, file=log.v2)
  for p in self.optim.hyper_params:
    print(" %s -> %s" % (p.description(), hyper_param_mapping[p]), file=log.v2)
  config = self.optim.create_config_instance(hyper_param_mapping, gpu_ids=self.gpu_ids)
  engine = Engine(config=config)
  train_data = StaticDataset.copy_from_dataset(self.optim.train_data)
  engine.init_train_from_config(config=config, train_data=train_data)
  # Not directly calling train() as we want to have full control.
  engine.epoch = 1
  train_data.init_seq_order(epoch=engine.epoch)
  batches = train_data.generate_batches(
    recurrent_net=engine.network.recurrent,
    batch_size=engine.batch_size,
    max_seqs=engine.max_seqs,
    max_seq_length=int(engine.max_seq_length),
    seq_drop=engine.seq_drop,
    shuffle_batches=engine.shuffle_batches,
    used_data_keys=engine.network.used_data_keys)
  engine.updater.set_learning_rate(engine.learning_rate)
  trainer = Runner(engine=engine, dataset=train_data, batches=batches, train=True)
  self.runner = trainer
  if self.cancel_flag:
    raise CancelTrainingException("Trainer cancel flag is set")
  trainer.run(report_prefix="hyper param tune train %r" % self.individual.name)
  if not trainer.finalized:
    print("Trainer exception:", trainer.run_exception, file=log.v1)
    raise trainer.run_exception
  cost = trainer.score["cost:output"]
  print(
    "Individual %s:" % self.individual.name,
    "Train cost:", cost,
    "elapsed time:", hms_fraction(time.time() - start_time),
    file=self.optim.log)
  self.individual.cost = cost
  return cost  # also return the cost after fresh training, consistent with the early return above
def run(self):
  """Train this individual (unless already trained) and store its cost."""
  if self.individual.cost is not None:
    return self.individual.cost
  start_time = time.time()
  hyper_param_mapping = self.individual.hyper_param_mapping
  print("Training %r using hyper params:" % self.individual.name, file=log.v2)
  for p in self.optim.hyper_params:
    print(" %s -> %s" % (p.description(), hyper_param_mapping[p]), file=log.v2)
  config = self.optim.create_config_instance(hyper_param_mapping, gpu_ids=self.gpu_ids)
  engine = Engine(config=config)
  train_data = StaticDataset.copy_from_dataset(self.optim.train_data)
  engine.init_train_from_config(config=config, train_data=train_data)
  # Not directly calling train() as we want to have full control.
  engine.epoch = 1
  train_data.init_seq_order(epoch=engine.epoch)
  batches = train_data.generate_batches(
    recurrent_net=engine.network.recurrent,
    batch_size=engine.batch_size,
    max_seqs=engine.max_seqs,
    max_seq_length=int(engine.max_seq_length),
    seq_drop=engine.seq_drop,
    shuffle_batches=engine.shuffle_batches,
    used_data_keys=engine.network.used_data_keys)
  # Newer updater API: the TF session has to be passed in explicitly.
  engine.updater.set_learning_rate(engine.learning_rate, session=engine.tf_session)
  trainer = Runner(engine=engine, dataset=train_data, batches=batches, train=True)
  self.runner = trainer
  if self.cancel_flag:
    raise CancelTrainingException("Trainer cancel flag is set")
  trainer.run(report_prefix="hyper param tune train %r" % self.individual.name)
  if not trainer.finalized:
    print("Trainer exception:", trainer.run_exception, file=log.v1)
    raise trainer.run_exception
  cost = trainer.score["cost:output"]
  print(
    "Individual %s:" % self.individual.name,
    "Train cost:", cost,
    "elapsed time:", hms_fraction(time.time() - start_time),
    file=self.optim.log)
  self.individual.cost = cost
  return cost  # also return the cost after fresh training, consistent with the early return above
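# The two run() variants above differ only in the updater call: newer RETURNN
# versions expect the TF session to be passed to set_learning_rate() explicitly.
# Below is a minimal sketch of the elapsed-time formatter used in the report
# line. hms_fraction() itself comes from RETURNN's Util module; this
# hypothetical reimplementation only illustrates the assumed behavior.

def hms_fraction_sketch(seconds, decimals=4):
  """Format a duration in seconds as H:MM:SS.ffff, e.g. 3845.25 -> '1:04:05.2500'."""
  hours, rest = divmod(seconds, 3600)
  minutes, secs = divmod(rest, 60)
  return "%i:%02i:%0*.*f" % (hours, minutes, decimals + 3, decimals, secs)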
def main(argv):
  argparser = argparse.ArgumentParser(description=__doc__)
  argparser.add_argument("config_file", type=str)
  argparser.add_argument("--epoch", required=False, type=int)
  argparser.add_argument(
    '--data', default="train",
    help="e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'")
  argparser.add_argument('--do_search', default=False, action='store_true')
  argparser.add_argument('--beam_size', default=12, type=int)
  argparser.add_argument('--dump_dir', required=True)
  argparser.add_argument("--device", default="gpu")
  argparser.add_argument("--layers", default=["att_weights"], action="append", help="Layer of subnet to grab")
  argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from; decoder")
  argparser.add_argument("--enc_layer", default="encoder")
  argparser.add_argument("--batch_size", type=int, default=5000)
  argparser.add_argument("--seq_list", default=[], action="append", help="predefined list of seqs")
  argparser.add_argument("--min_seq_len", default="0", help="can also be dict")
  argparser.add_argument("--output_format", default="npy", help="npy or png")
  argparser.add_argument("--dropout", default=None, type=float, help="if set, overwrites all dropout values")
  argparser.add_argument("--train_flag", action="store_true")
  args = argparser.parse_args(argv[1:])
  layers = args.layers
  assert isinstance(layers, list)
  model_name = ".".join(args.config_file.split("/")[-1].split(".")[:-1])
  init_returnn(config_fn=args.config_file, cmd_line_opts=["--device", args.device], args=args)
  if args.do_search:
    raise NotImplementedError
  min_seq_length = NumbersDict(eval(args.min_seq_len))
  if not os.path.exists(args.dump_dir):
    os.makedirs(args.dump_dir)
  assert args.output_format in ["npy", "png"]
  if args.output_format == "png":
    import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
    import matplotlib.ticker as ticker
  dataset_str = args.data
  if dataset_str in ["train", "dev", "eval"]:
    dataset_str = "config:%s" % dataset_str
  dataset = init_dataset(dataset_str)
  init_net(args, layers)
  network = rnn.engine.network
  extra_fetches = {}
  for rec_ret_layer in ["rec_%s" % l for l in layers]:
    extra_fetches[rec_ret_layer] = rnn.engine.network.layers[rec_ret_layer].output.get_placeholder_as_batch_major()
  extra_fetches.update({
    "output": network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
    "output_len": network.layers[args.rec_layer].output.get_sequence_lengths(),  # decoder length
    "encoder_len": network.layers[args.enc_layer].output.get_sequence_lengths(),  # encoder length
    "seq_idx": network.get_extern_data("seq_idx"),
    "seq_tag": network.get_extern_data("seq_tag"),
    "target_data": network.get_extern_data(network.extern_data.default_input),
    "target_classes": network.get_extern_data(network.extern_data.default_target),
  })
  dataset.init_seq_order(epoch=rnn.engine.epoch, seq_list=args.seq_list or None)
  dataset_batch = dataset.generate_batches(
    recurrent_net=network.recurrent,
    batch_size=args.batch_size,
    max_seqs=rnn.engine.max_seqs,
    max_seq_length=sys.maxsize,
    min_seq_length=min_seq_length,
    used_data_keys=network.used_data_keys)
  stats = {l: Stats() for l in layers}

  # (**dict[str, numpy.ndarray|str|list[numpy.ndarray|str]]) -> None
  def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output, output_len, encoder_len, **kwargs):
    for i in range(len(seq_idx)):
      for l in layers:
        att_weights = kwargs["rec_%s" % l][i]
        stats[l].collect(att_weights.flatten())
    if args.output_format == "npy":
      data = {}
      for i in range(len(seq_idx)):
        data[i] = {
          'tag': seq_tag[i],
          'data': target_data[i],
          'classes': target_classes[i],
          'output': output[i],
          'output_len': output_len[i],
          'encoder_len': encoder_len[i],
        }
        for l in [("rec_%s" % l) for l in layers]:
          assert l in kwargs
          out = kwargs[l][i]
          assert out.ndim >= 2
          assert out.shape[0] >= output_len[i] and out.shape[1] >= encoder_len[i]
          data[i][l] = out[:output_len[i], :encoder_len[i]]
      fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
      np.save(fname, data)
    elif args.output_format == "png":
      for i in range(len(seq_idx)):
        for l in layers:
          extra_postfix = ""
          if args.dropout is not None:
            extra_postfix += "_dropout%.2f" % args.dropout
          elif args.train_flag:
            extra_postfix += "_train"
          fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
            model_name, rnn.engine.epoch, seq_idx[i], l, extra_postfix)
          att_weights = kwargs["rec_%s" % l][i]
          att_weights = att_weights.squeeze(axis=2)  # (out,enc)
          assert att_weights.shape[0] >= output_len[i] and att_weights.shape[1] >= encoder_len[i]
          att_weights = att_weights[:output_len[i], :encoder_len[i]]
          print("Seq %i, %s: Dump att weights with shape %r to: %s" % (
            seq_idx[i], seq_tag[i], att_weights.shape, fname))
          plt.matshow(att_weights)
          title = seq_tag[i]
          if dataset.can_serialize_data(network.extern_data.default_target):
            title += "\n" + dataset.serialize_data(
              network.extern_data.default_target, target_classes[i][:output_len[i]])
            ax = plt.gca()
            tick_labels = [
              dataset.serialize_data(network.extern_data.default_target, np.array([x], dtype=target_classes[i].dtype))
              for x in target_classes[i][:output_len[i]]]
            ax.set_yticklabels([''] + tick_labels, fontsize=8)
            ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
          plt.title(title)
          plt.savefig(fname)
          plt.close()
    else:
      raise NotImplementedError("output format %r" % args.output_format)

  runner = Runner(
    engine=rnn.engine, dataset=dataset, batches=dataset_batch, train=False,
    train_flag=args.dropout is not None or args.train_flag,
    extra_fetches=extra_fetches, extra_fetches_callback=fetch_callback)
  runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)
  for l in layers:
    stats[l].dump(stream_prefix="Layer %r " % l)
  if not runner.finalized:
    print("Some error occurred, not finalized.")
    sys.exit(1)
  rnn.finalize()
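# The npy branch above stores a dict of per-sequence dicts via np.save(), which
# wraps the dict in a 0-d object array. A minimal sketch for loading such a
# dump back; the file name is a placeholder for one produced by the tool above.

import numpy as np

def load_att_dump(fname):
  batch = np.load(fname, allow_pickle=True).item()  # unwrap the 0-d object array back into the dict
  for i, entry in sorted(batch.items()):
    # Each entry holds 'tag', 'data', 'classes', 'output', 'output_len',
    # 'encoder_len' and one 'rec_<layer>' matrix per requested layer,
    # cropped to shape (output_len, encoder_len).
    print(i, entry['tag'], entry['output_len'], entry['encoder_len'])
  return batch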
def main(argv):
  argparser = argparse.ArgumentParser(description='Dump attention weights as npy files.')
  argparser.add_argument("crnn_config_file", type=str)
  argparser.add_argument("--epoch", required=False, type=int)
  argparser.add_argument('--data', default="test")
  argparser.add_argument('--do_search', default=False, action='store_true')
  argparser.add_argument('--beam_size', default=12, type=int)
  argparser.add_argument('--dump_dir', required=True)
  argparser.add_argument("--device", default="gpu")
  argparser.add_argument("--layers", default=["att_weights"], action="append", help="Layer of subnet to grab")
  argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from")
  argparser.add_argument("--batch_size", type=int, default=5000)
  args = argparser.parse_args(argv[1:])
  if not os.path.exists(args.dump_dir):
    os.makedirs(args.dump_dir)
  model = ".".join(args.crnn_config_file.split("/")[-1].split(".")[:-1])
  init(configFilename=args.crnn_config_file, commandLineOptions=["--device", args.device], args=args)
  if isinstance(args.layers, str):
    layers = [args.layers]
  else:
    layers = args.layers
  inject_retrieval_code(args, layers)
  network = rnn.engine.network
  assert rnn.eval_data is not None, "provide evaluation data"
  extra_fetches = {}
  for rec_ret_layer in ["rec_%s" % l for l in layers]:
    extra_fetches[rec_ret_layer] = rnn.engine.network.layers[rec_ret_layer].output.placeholder
  extra_fetches.update({
    "output": network.get_default_output_layer().output.get_placeholder_as_batch_major(),
    "output_len": network.get_default_output_layer().output.get_sequence_lengths(),  # decoder length
    "encoder_len": network.layers["encoder"].output.get_sequence_lengths(),  # encoder length
    "seq_idx": network.get_extern_data("seq_idx", mark_data_key_as_used=True),
    "seq_tag": network.get_extern_data("seq_tag", mark_data_key_as_used=True),
    "target_data": network.get_extern_data("data", mark_data_key_as_used=True),
    "target_classes": network.get_extern_data("classes", mark_data_key_as_used=True),
  })
  dataset_batch = rnn.eval_data.generate_batches(
    recurrent_net=network.recurrent,
    batch_size=args.batch_size,
    max_seqs=rnn.engine.max_seqs,
    max_seq_length=sys.maxsize,
    used_data_keys=network.used_data_keys)

  # (**dict[str, numpy.ndarray|str|list[numpy.ndarray|str]]) -> None
  def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output, output_len, encoder_len, **kwargs):
    data = {}
    for i in range(len(seq_idx)):
      data[i] = {
        'tag': seq_tag[i],
        'data': target_data[i],
        'classes': target_classes[i],
        'output': output[i],
        'output_len': output_len[i],
        'encoder_len': encoder_len[i],
      }
      for l in [("rec_%s" % l) for l in layers]:
        assert l in kwargs
        data[i][l] = kwargs[l][i]  # index the batch entry; kwargs[l] is the fetched tensor for the whole batch
    fname = os.path.join(
      args.dump_dir,
      '%s_ep%03d_data_%i_%i.npy' % (model, rnn.engine.epoch, seq_idx[0], seq_idx[-1]))
    np.save(fname, data)

  runner = Runner(
    engine=rnn.engine, dataset=rnn.eval_data, batches=dataset_batch, train=False,
    extra_fetches=extra_fetches, extra_fetches_callback=fetch_callback)
  runner.run(report_prefix="att-weights ")
  assert runner.finalized
  rnn.finalize()
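# In the variant above, the fetched tensors are batch-major, i.e. axis 0 is the
# batch axis, so each per-sequence value must be indexed with [i] and cropped
# to its true lengths. A self-contained numpy illustration of that pattern
# (dummy shapes, no RETURNN involved):

import numpy as np

batch_att = np.random.rand(2, 7, 11)  # (n_batch, max_dec_len, max_enc_len)
output_len, encoder_len = [5, 7], [9, 11]  # true per-sequence lengths
for i in range(batch_att.shape[0]):
  att = batch_att[i][:output_len[i], :encoder_len[i]]  # drop the padding
  assert att.shape == (output_len[i], encoder_len[i])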
def main(argv):
  argparser = argparse.ArgumentParser(description=__doc__)
  argparser.add_argument("config_file", type=str, help="RETURNN config, or model-dir")
  argparser.add_argument("--epoch", type=int)
  argparser.add_argument(
    '--data', default="train",
    help="e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'")
  argparser.add_argument('--do_search', default=False, action='store_true')
  argparser.add_argument('--beam_size', default=12, type=int)
  argparser.add_argument('--dump_dir', help="for npy or png")
  argparser.add_argument("--output_file", help="hdf")
  argparser.add_argument("--device", help="gpu or cpu (default: automatic)")
  argparser.add_argument("--layers", default=[], action="append", help="Layer of subnet to grab")
  argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from; decoder")
  argparser.add_argument("--enc_layer", default="encoder")
  argparser.add_argument("--batch_size", type=int, default=5000)
  argparser.add_argument("--seq_list", default=[], action="append", help="predefined list of seqs")
  argparser.add_argument("--min_seq_len", default="0", help="can also be dict")
  argparser.add_argument("--num_seqs", default=-1, type=int, help="stop after this many seqs")
  argparser.add_argument("--output_format", default="npy", help="npy, png or hdf")
  argparser.add_argument("--dropout", default=None, type=float, help="if set, overwrites all dropout values")
  argparser.add_argument("--train_flag", action="store_true")
  argparser.add_argument("--reset_partition_epoch", type=int, default=None)
  argparser.add_argument("--reset_seq_ordering", default="default")
  argparser.add_argument("--reset_epoch_wise_filter", default=None)
  argparser.add_argument('--hmm_fac_fo', default=False, action='store_true')
  argparser.add_argument('--encoder_sa', default=False, action='store_true')
  argparser.add_argument('--tf_log_dir', default=None, help="TF log dir")
  argparser.add_argument("--instead_save_encoder_decoder", action="store_true")
  args = argparser.parse_args(argv[1:])
  layers = args.layers
  assert isinstance(layers, list)
  config_fn = args.config_file
  explicit_model_dir = None
  if os.path.isdir(config_fn):
    # Assume we got a model dir.
    explicit_model_dir = config_fn
    train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn
    train_log_dir_configs = sorted(glob(train_log_dir_config_pattern))
    assert train_log_dir_configs
    config_fn = train_log_dir_configs[-1]
    print("Using this config via model dir:", config_fn)
  else:
    assert os.path.isfile(config_fn)
  model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1])
  init_returnn(config_fn=config_fn, args=args)
  if explicit_model_dir:
    config.set("model", "%s/%s" % (explicit_model_dir, os.path.basename(config.value('model', ''))))
  print("Model file prefix:", config.value('model', ''))
  min_seq_length = NumbersDict(eval(args.min_seq_len))
  assert args.output_format in ["npy", "png", "hdf"]
  if args.output_format in ["npy", "png"]:
    assert args.dump_dir
    if not os.path.exists(args.dump_dir):
      os.makedirs(args.dump_dir)
  plt = ticker = None
  if args.output_format == "png":
    import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
    import matplotlib.ticker as ticker
  dataset_str = args.data
  if dataset_str in ["train", "dev", "eval"]:
    dataset_str = "config:%s" % dataset_str
  extra_dataset_kwargs = {}
  if args.reset_partition_epoch:
    print("NOTE: We are resetting partition epoch to %i." % (args.reset_partition_epoch,))
    extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
  if args.reset_seq_ordering:
    print("NOTE: We will use %r seq ordering." % (args.reset_seq_ordering,))
    extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
  if args.reset_epoch_wise_filter:
    extra_dataset_kwargs["epoch_wise_filter"] = eval(args.reset_epoch_wise_filter)
  dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
  if hasattr(dataset, "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
    if dataset.epoch_wise_filter:
      print("NOTE: Resetting epoch_wise_filter to None.")
      dataset.epoch_wise_filter = None
  if args.reset_partition_epoch:
    assert dataset.partition_epoch == args.reset_partition_epoch
  if args.reset_seq_ordering:
    assert dataset.seq_ordering == args.reset_seq_ordering
  init_net(args, layers)
  network = rnn.engine.network
  hdf_writer = None
  if args.output_format == "hdf":
    assert args.output_file
    assert len(layers) == 1
    sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0]))
    from HDFDataset import SimpleHDFWriter
    hdf_writer = SimpleHDFWriter(filename=args.output_file, dim=sub_layer.output.dim, ndim=sub_layer.output.ndim)
  extra_fetches = {
    "output": network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
    "output_len": network.layers[args.rec_layer].output.get_sequence_lengths(),  # decoder length
    "encoder_len": network.layers[args.enc_layer].output.get_sequence_lengths(),  # encoder length
    "seq_idx": network.get_extern_data("seq_idx"),
    "seq_tag": network.get_extern_data("seq_tag"),
    "target_data": network.get_extern_data(network.extern_data.default_input),
    "target_classes": network.get_extern_data(network.extern_data.default_target),
  }
  if args.instead_save_encoder_decoder:
    extra_fetches["encoder"] = network.layers["encoder"]
    extra_fetches["decoder"] = network.layers["output"].get_sub_layer("decoder")
  else:
    for l in layers:
      sub_layer = rnn.engine.network.get_layer("%s/%s" % (args.rec_layer, l))
      extra_fetches["rec_%s" % l] = sub_layer.output.get_placeholder_as_batch_major()
      if args.do_search:
        o_layer = rnn.engine.network.get_layer("output")
        extra_fetches["beam_scores_" + l] = o_layer.get_search_choices().beam_scores
  dataset.init_seq_order(epoch=1, seq_list=args.seq_list or None)  # always use epoch 1, such that we get the same seqs
  dataset_batch = dataset.generate_batches(
    recurrent_net=network.recurrent,
    batch_size=args.batch_size,
    max_seqs=rnn.engine.max_seqs,
    max_seq_length=sys.maxsize,
    min_seq_length=min_seq_length,
    max_total_num_seqs=args.num_seqs,
    used_data_keys=network.used_data_keys)
  stats = {l: Stats() for l in layers}

  # (**dict[str, numpy.ndarray|str|list[numpy.ndarray|str]]) -> None
  def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output, output_len, encoder_len, **kwargs):
    """
    :param list[int] seq_idx: len is n_batch
    :param list[str] seq_tag: len is n_batch
    :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...)
    :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...)
    :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...)
    :param numpy.ndarray output_len: rec layer seq len, i.e. decoder length, shape (B,)
    :param numpy.ndarray encoder_len: encoder seq len, shape (B,)
    :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g. att weights) we are interested in
    """
    n_batch = len(seq_idx)
    for i in range(n_batch):
      if not args.instead_save_encoder_decoder:
        for l in layers:
          att_weights = kwargs["rec_%s" % l][i]
          stats[l].collect(att_weights.flatten())
    if args.output_format == "npy":
      data = {}
      if args.do_search:
        assert not args.instead_save_encoder_decoder, "Not implemented"
        # Find the axis with the correct beam size.
        axis_beam_size = n_batch * args.beam_size
        corr_axis = None
        num_axes = len(kwargs["rec_%s" % layers[0]].shape)
        for a in range(num_axes):
          if kwargs["rec_%s" % layers[0]].shape[a] == axis_beam_size:
            corr_axis = a
            break
        assert corr_axis is not None, "Att weights decoding: correct axis not found! Maybe check the beam size."
        # Move the beam axis to the front.
        for l, l_raw in zip([("rec_%s" % l) for l in layers], layers):
          swap = list(range(num_axes))
          del swap[corr_axis]
          swap.insert(0, corr_axis)
          kwargs[l] = np.transpose(kwargs[l], swap)
        for i in range(n_batch):
          # The first entry within each beam has the best score.
          i_beam = args.beam_size * i
          data[i] = {
            'tag': seq_tag[i],
            'data': target_data[i],
            'classes': target_classes[i],
            'output': output[i_beam],
            'output_len': output_len[i_beam],
            'encoder_len': encoder_len[i],
          }
          # if args.hmm_fac_fo is False:
          for l, l_raw in zip([("rec_%s" % l) for l in layers], layers):
            assert l in kwargs
            out = kwargs[l][i_beam]
            # Search with multi-head attention:
            # out is [I, H, 1, J] in the new version, [I, J, H, 1] in the old version.
            if len(out.shape) == 3 and min(out.shape) > 1:
              out = np.transpose(out, axes=(0, 3, 1, 2))  # [I, J, H, 1], new version
              out = np.squeeze(out, axis=-1)
            else:
              out = np.squeeze(out, axis=1)
            assert out.shape[0] >= output_len[i_beam] and out.shape[1] >= encoder_len[i]
            data[i][l] = out[:output_len[i_beam], :encoder_len[i]]
        fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
        np.save(fname, data)
      else:
        for i in range(n_batch):
          data[i] = {
            'tag': seq_tag[i],
            'data': target_data[i],
            'classes': target_classes[i],
            'output': output[i],
            'output_len': output_len[i],
            'encoder_len': encoder_len[i],
          }
          # if args.hmm_fac_fo is False:
          if args.instead_save_encoder_decoder:
            out = kwargs["encoder"][i]
            data[i]["encoder"] = out[:encoder_len[i]]
            out_2 = kwargs["decoder"][i]
            data[i]["decoder"] = out_2[:output_len[i]]
          else:
            for l in [("rec_%s" % l) for l in layers]:
              assert l in kwargs
              out = kwargs[l][i]
              if len(out.shape) == 3 and min(out.shape) > 1:
                # Multi-head attention.
                out = np.transpose(out, axes=(1, 2, 0))  # (I, J, H), new version
                # out = np.transpose(out, axes=(2, 0, 1))  # [I, J, H], old version
              else:
                # RNN.
                out = np.squeeze(out, axis=1)
              assert out.ndim >= 2
              assert out.shape[0] >= output_len[i] and out.shape[1] >= encoder_len[i]
              data[i][l] = out[:output_len[i], :encoder_len[i]]
        fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
        np.save(fname, data)
    elif args.output_format == "png":
      for i in range(n_batch):
        for l in layers:
          extra_postfix = ""
          if args.dropout is not None:
            extra_postfix += "_dropout%.2f" % args.dropout
          elif args.train_flag:
            extra_postfix += "_train"
          fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
            model_name, rnn.engine.epoch, seq_idx[i], l, extra_postfix)
          att_weights = kwargs["rec_%s" % l][i]
          att_weights = att_weights.squeeze(axis=2)  # (out,enc)
          assert att_weights.shape[0] >= output_len[i] and att_weights.shape[1] >= encoder_len[i]
          att_weights = att_weights[:output_len[i], :encoder_len[i]]
          print("Seq %i, %s: Dump att weights with shape %r to: %s" % (
            seq_idx[i], seq_tag[i], att_weights.shape, fname))
          plt.matshow(att_weights)
          title = seq_tag[i]
          if dataset.can_serialize_data(network.extern_data.default_target):
            title += "\n" + dataset.serialize_data(
              network.extern_data.default_target, target_classes[i][:output_len[i]])
            ax = plt.gca()
            tick_labels = [
              dataset.serialize_data(network.extern_data.default_target, np.array([x], dtype=target_classes[i].dtype))
              for x in target_classes[i][:output_len[i]]]
            ax.set_yticklabels([''] + tick_labels, fontsize=8)
            ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
          plt.title(title)
          plt.savefig(fname)
          plt.close()
    elif args.output_format == "hdf":
      assert len(layers) == 1
      att_weights = kwargs["rec_%s" % layers[0]]
      hdf_writer.insert_batch(inputs=att_weights, seq_len={0: output_len, 1: encoder_len}, seq_tag=seq_tag)
    else:
      raise Exception("output format %r" % args.output_format)

  runner = Runner(
    engine=rnn.engine, dataset=dataset, batches=dataset_batch, train=False,
    train_flag=bool(args.dropout) or args.train_flag,
    extra_fetches=extra_fetches, extra_fetches_callback=fetch_callback, eval=False)
  runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)
  if not args.instead_save_encoder_decoder:
    for l in layers:
      stats[l].dump(stream_prefix="Layer %r " % l)
  if not runner.finalized:
    print("Some error occurred, not finalized.")
    sys.exit(1)
  if hdf_writer:
    hdf_writer.close()
  rnn.finalize()
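# The do_search branch above locates the axis whose size equals
# n_batch * beam_size and moves it to the front before indexing individual
# beams. A self-contained numpy sketch of that move-axis-to-front pattern
# (dummy shapes, no RETURNN involved):

import numpy as np

def move_axis_to_front(x, axis_size):
  matching = [a for a in range(x.ndim) if x.shape[a] == axis_size]
  assert matching, "no axis of size %i found" % axis_size
  swap = list(range(x.ndim))
  del swap[matching[0]]
  swap.insert(0, matching[0])
  return np.transpose(x, swap)

y = move_axis_to_front(np.zeros((7, 24, 9)), axis_size=24)  # e.g. n_batch=2, beam_size=12
assert y.shape == (24, 7, 9)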
def main(argv):
  argparser = argparse.ArgumentParser(description=__doc__)
  argparser.add_argument("config_file", type=str, help="RETURNN config, or model-dir")
  argparser.add_argument("--epoch", type=int)
  argparser.add_argument(
    '--data', default="train",
    help="e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'")
  argparser.add_argument('--do_search', default=False, action='store_true')
  argparser.add_argument('--beam_size', default=12, type=int)
  argparser.add_argument('--dump_dir', help="for npy or png")
  argparser.add_argument("--output_file", help="hdf")
  argparser.add_argument("--device", help="gpu or cpu (default: automatic)")
  argparser.add_argument("--layers", default=["att_weights"], action="append", help="Layer of subnet to grab")
  argparser.add_argument("--rec_layer", default="output", help="Subnet layer to grab from; decoder")
  argparser.add_argument("--enc_layer", default="encoder")
  argparser.add_argument("--batch_size", type=int, default=5000)
  argparser.add_argument("--seq_list", default=[], action="append", help="predefined list of seqs")
  argparser.add_argument("--min_seq_len", default="0", help="can also be dict")
  argparser.add_argument("--num_seqs", default=-1, type=int, help="stop after this many seqs")
  argparser.add_argument("--output_format", default="npy", help="npy, png or hdf")
  argparser.add_argument("--dropout", default=None, type=float, help="if set, overwrites all dropout values")
  argparser.add_argument("--train_flag", action="store_true")
  argparser.add_argument("--reset_partition_epoch", type=int, default=1)
  argparser.add_argument("--reset_seq_ordering", default="sorted_reverse")
  argparser.add_argument("--reset_epoch_wise_filter", default=None)
  args = argparser.parse_args(argv[1:])
  layers = args.layers
  assert isinstance(layers, list)
  config_fn = args.config_file
  explicit_model_dir = None
  if os.path.isdir(config_fn):
    # Assume we got a model dir.
    explicit_model_dir = config_fn
    train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn
    train_log_dir_configs = sorted(glob(train_log_dir_config_pattern))
    assert train_log_dir_configs
    config_fn = train_log_dir_configs[-1]
    print("Using this config via model dir:", config_fn)
  else:
    assert os.path.isfile(config_fn)
  model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1])
  init_returnn(config_fn=config_fn, args=args)
  if explicit_model_dir:
    config.set("model", "%s/%s" % (explicit_model_dir, os.path.basename(config.value('model', ''))))
  print("Model file prefix:", config.value('model', ''))
  if args.do_search:
    raise NotImplementedError
  min_seq_length = NumbersDict(eval(args.min_seq_len))
  assert args.output_format in ["npy", "png", "hdf"]
  if args.output_format in ["npy", "png"]:
    assert args.dump_dir
    if not os.path.exists(args.dump_dir):
      os.makedirs(args.dump_dir)
  plt = ticker = None
  if args.output_format == "png":
    import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
    import matplotlib.ticker as ticker
  dataset_str = args.data
  if dataset_str in ["train", "dev", "eval"]:
    dataset_str = "config:%s" % dataset_str
  extra_dataset_kwargs = {}
  if args.reset_partition_epoch:
    print("NOTE: We are resetting partition epoch to %i." % (args.reset_partition_epoch,))
    extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
  if args.reset_seq_ordering:
    print("NOTE: We will use %r seq ordering." % (args.reset_seq_ordering,))
    extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
  if args.reset_epoch_wise_filter:
    extra_dataset_kwargs["epoch_wise_filter"] = eval(args.reset_epoch_wise_filter)
  dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
  if hasattr(dataset, "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
    if dataset.epoch_wise_filter:
      print("NOTE: Resetting epoch_wise_filter to None.")
      dataset.epoch_wise_filter = None
  if args.reset_partition_epoch:
    assert dataset.partition_epoch == args.reset_partition_epoch
  if args.reset_seq_ordering:
    assert dataset.seq_ordering == args.reset_seq_ordering
  init_net(args, layers)
  network = rnn.engine.network
  hdf_writer = None
  if args.output_format == "hdf":
    assert args.output_file
    assert len(layers) == 1
    sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0]))
    from HDFDataset import SimpleHDFWriter
    hdf_writer = SimpleHDFWriter(filename=args.output_file, dim=sub_layer.output.dim, ndim=sub_layer.output.ndim)
  extra_fetches = {
    "output": network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
    "output_len": network.layers[args.rec_layer].output.get_sequence_lengths(),  # decoder length
    "encoder_len": network.layers[args.enc_layer].output.get_sequence_lengths(),  # encoder length
    "seq_idx": network.get_extern_data("seq_idx"),
    "seq_tag": network.get_extern_data("seq_tag"),
    "target_data": network.get_extern_data(network.extern_data.default_input),
    "target_classes": network.get_extern_data(network.extern_data.default_target),
  }
  for l in layers:
    sub_layer = rnn.engine.network.get_layer("%s/%s" % (args.rec_layer, l))
    extra_fetches["rec_%s" % l] = sub_layer.output.get_placeholder_as_batch_major()
  dataset.init_seq_order(epoch=1, seq_list=args.seq_list or None)  # always use epoch 1, such that we get the same seqs
  dataset_batch = dataset.generate_batches(
    recurrent_net=network.recurrent,
    batch_size=args.batch_size,
    max_seqs=rnn.engine.max_seqs,
    max_seq_length=sys.maxsize,
    min_seq_length=min_seq_length,
    max_total_num_seqs=args.num_seqs,
    used_data_keys=network.used_data_keys)
  stats = {l: Stats() for l in layers}

  # (**dict[str, numpy.ndarray|str|list[numpy.ndarray|str]]) -> None
  def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output, output_len, encoder_len, **kwargs):
    """
    :param list[int] seq_idx: len is n_batch
    :param list[str] seq_tag: len is n_batch
    :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...)
    :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...)
    :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...)
    :param numpy.ndarray output_len: rec layer seq len, i.e. decoder length, shape (B,)
    :param numpy.ndarray encoder_len: encoder seq len, shape (B,)
    :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g. att weights) we are interested in
    """
    n_batch = len(seq_idx)
    for i in range(n_batch):
      for l in layers:
        att_weights = kwargs["rec_%s" % l][i]
        stats[l].collect(att_weights.flatten())
    if args.output_format == "npy":
      data = {}
      for i in range(n_batch):
        data[i] = {
          'tag': seq_tag[i],
          'data': target_data[i],
          'classes': target_classes[i],
          'output': output[i],
          'output_len': output_len[i],
          'encoder_len': encoder_len[i],
        }
        for l in [("rec_%s" % l) for l in layers]:
          assert l in kwargs
          out = kwargs[l][i]
          assert out.ndim >= 2
          assert out.shape[0] >= output_len[i] and out.shape[1] >= encoder_len[i]
          data[i][l] = out[:output_len[i], :encoder_len[i]]
      fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
      np.save(fname, data)
    elif args.output_format == "png":
      for i in range(n_batch):
        for l in layers:
          extra_postfix = ""
          if args.dropout is not None:
            extra_postfix += "_dropout%.2f" % args.dropout
          elif args.train_flag:
            extra_postfix += "_train"
          fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
            model_name, rnn.engine.epoch, seq_idx[i], l, extra_postfix)
          att_weights = kwargs["rec_%s" % l][i]
          att_weights = att_weights.squeeze(axis=2)  # (out,enc)
          assert att_weights.shape[0] >= output_len[i] and att_weights.shape[1] >= encoder_len[i]
          att_weights = att_weights[:output_len[i], :encoder_len[i]]
          print("Seq %i, %s: Dump att weights with shape %r to: %s" % (
            seq_idx[i], seq_tag[i], att_weights.shape, fname))
          plt.matshow(att_weights)
          title = seq_tag[i]
          if dataset.can_serialize_data(network.extern_data.default_target):
            title += "\n" + dataset.serialize_data(
              network.extern_data.default_target, target_classes[i][:output_len[i]])
            ax = plt.gca()
            tick_labels = [
              dataset.serialize_data(network.extern_data.default_target, np.array([x], dtype=target_classes[i].dtype))
              for x in target_classes[i][:output_len[i]]]
            ax.set_yticklabels([''] + tick_labels, fontsize=8)
            ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
          plt.title(title)
          plt.savefig(fname)
          plt.close()
    elif args.output_format == "hdf":
      assert len(layers) == 1
      att_weights = kwargs["rec_%s" % layers[0]]
      hdf_writer.insert_batch(inputs=att_weights, seq_len={0: output_len, 1: encoder_len}, seq_tag=seq_tag)
    else:
      raise Exception("output format %r" % args.output_format)

  runner = Runner(
    engine=rnn.engine, dataset=dataset, batches=dataset_batch, train=False,
    train_flag=bool(args.dropout) or args.train_flag,
    extra_fetches=extra_fetches, extra_fetches_callback=fetch_callback)
  runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)
  for l in layers:
    stats[l].dump(stream_prefix="Layer %r " % l)
  if not runner.finalized:
    print("Some error occurred, not finalized.")
    sys.exit(1)
  if hdf_writer:
    hdf_writer.close()
  rnn.finalize()
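# The hdf branch above writes through RETURNN's SimpleHDFWriter. The exact
# dataset layout depends on the RETURNN HDF format version, so this minimal
# inspection sketch just walks the resulting file with plain h5py instead of
# assuming specific dataset names:

import h5py

def inspect_hdf(fname):
  with h5py.File(fname, "r") as f:
    f.visititems(lambda name, obj: print(name, getattr(obj, "shape", "")))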