Code Example #1
 def __init__(self, parent, devices):
     """
 :type parent: TaskThread
 """
     threading.Thread.__init__(self,
                               name="DeviceThread %s" %
                               " ".join([dev.name for dev in devices]))
     self.alloc_devices = devices
     self.parent = parent
     self.devices_batches_idx = None
     self.run_start_batch_idx = None
     self.eval_info = None
     " :type: dict[str] | None "
     self.allocated = False
     self.processing = False
     self.finished = True
     self.crashed = False
     self.num_frames = NumbersDict(0)
     self.run_frames = NumbersDict(0)
     self.daemon = True
     self.active = True
     self.result = {
         'batchess': [],
         'results': [],
         'result_format': None,
         'num_frames': 0
     }
     if self.alloc_devices:
         self.start()
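
The NumbersDict(0) counters above come from returnn.util.basic. As a minimal sketch of the semantics this code relies on (elementwise arithmetic over the keys, with a scalar broadcast default value), assuming RETURNN is installed:

    from returnn.util.basic import NumbersDict

    frames = NumbersDict(0)  # scalar 0 broadcasts to every key
    frames += NumbersDict({"data": 12, "classes": 6})
    frames += NumbersDict({"data": 6, "classes": 3})
    print(frames["data"], frames["classes"])  # 18 9
    print(frames.max_value())  # 18, cf. the run_frames.max_value() assert in the allocate() examples below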
Code Example #2
File: __main__.py Project: e0397123/returnn
def analyze_data(config):  # pylint: disable=redefined-outer-name
    """
  :param Config config:
  """
    dss = config.value('analyze_dataset', 'train')
    ds = {"train": train_data, "dev": dev_data, "eval": eval_data}[dss]
    epoch = config.int('epoch', 1)
    print("Analyze dataset", dss, "epoch", epoch, file=log.v1)
    ds.init_seq_order(epoch=epoch)
    stat_prefix = config.value('statistics_save_prefix', 'statistics')
    dtype = config.value('statistics_dtype', 'float64')
    target = config.value('target', 'classes')
    data_key = config.value('data_key', 'data')
    assert ds.is_data_sparse(target), "needed for prior calculation"
    assert not ds.is_data_sparse(data_key), "needed for mean/var estimation"
    from returnn.util.basic import inplace_increment, progress_bar_with_time, NumbersDict

    priors = numpy.zeros((ds.get_data_dim(target), ), dtype=dtype)
    mean = numpy.zeros((ds.get_data_dim(data_key), ), dtype=dtype)
    mean_sq = numpy.zeros((ds.get_data_dim(data_key), ), dtype=dtype)
    total_targets_len = 0
    total_data_len = 0

    # Note: This is not stable! See :class:`Util.Stats` for a better alternative.
    seq_idx = 0
    while ds.is_less_than_num_seqs(seq_idx):
        progress_bar_with_time(ds.get_complete_frac(seq_idx))
        ds.load_seqs(seq_idx, seq_idx + 1)
        targets = ds.get_data(seq_idx, target)
        inplace_increment(priors, targets, 1)
        total_targets_len += targets.shape[0]
        data = ds.get_data(seq_idx, data_key)
        new_total_data_len = total_data_len + data.shape[0]
        f = float(total_data_len) / new_total_data_len
        mean = mean * f + numpy.mean(data, axis=0) * (1.0 - f)
        mean_sq = mean_sq * f + numpy.mean(data * data, axis=0) * (1.0 - f)
        total_data_len = new_total_data_len
        seq_idx += 1
    log_priors = numpy.log(priors)
    log_priors -= numpy.log(NumbersDict(ds.get_num_timesteps())[target])
    std_dev = numpy.sqrt(mean_sq - mean * mean)
    print("Finished. %i total target frames, %i total data frames" %
          (total_targets_len, total_data_len),
          file=log.v1)
    priors_fn = stat_prefix + ".log_priors.txt"
    mean_fn = stat_prefix + ".mean.txt"
    std_dev_fn = stat_prefix + ".std_dev.txt"
    print("Dump priors to", priors_fn, file=log.v1)
    numpy.savetxt(priors_fn, log_priors)
    print("Dump mean to", mean_fn, file=log.v1)
    numpy.savetxt(mean_fn, mean)
    print("Dump std dev to", std_dev_fn, file=log.v1)
    numpy.savetxt(std_dev_fn, std_dev)
    print("Done.", file=log.v1)
Code Example #3
 def __init__(self,
              task,
              network,
              devices,
              data,
              batches,
              eval_batch_size=0,
              start_batch=0,
              share_batches=False,
              reduction_rate=1.0,
              report_prefix=None,
              exclude=None,
              epoch=None):
     """
   :type task: str
   :type network: Network.LayerNetwork
   :type devices: list[Device.Device]
   :type data: Dataset.Dataset
   :type batches: EngineBatch.BatchSetGenerator
   :type start_batch: int
   :param str report_prefix: such as epoch or so. only for reporting
   """
     threading.Thread.__init__(self, name="TaskThread %s" % task)
     assert len(devices) > 0
     if eval_batch_size == 0:
         eval_batch_size = sys.maxsize
     self.share_batches = share_batches
     self.eval_batch_size = eval_batch_size
     self.eval_batch_idx = 0
     self.start_batch = start_batch
     self.reduction_rate = reduction_rate
     self.devices = devices
     self.network = network
     self.batches = batches
     self.exclude = exclude
     self.task = task
     self.data = data
     self.daemon = True
     self.elapsed = 0
     self.finalized = False
     self.score = {}
     self.error = {}
     self.results = {}
     self.num_frames = NumbersDict(0)
     self.batch_idx = None
     " :type: int | None "
     self.device_crash_batch = None
     " :type: int | None "
     self.report_prefix = report_prefix or self.task
     self.epoch = epoch
     self.lock = threading.Lock()
     self.start()
Code Example #4
 def allocate(self):
     self.devices_batches_idx = self.parent.batches.get_current_batch_idx()
     self.allocated_devices_batches = self.parent.allocate_devices(
         self.alloc_devices)
     self.run_frames = NumbersDict(0)
     for batches, device in zip(self.allocated_devices_batches,
                                self.alloc_devices):
         assert batches
         assert batches[0].seqs
         #assert batches[0].seqs[0].frame_length[1] > 0
         device.num_updates += 1 if not device.update_specs['block_size'] else int(
             ceil(sum([len(batch.seqs) for batch in batches]) /
                  float(device.update_specs['block_size'])))
         self.run_frames += sum(
             [batch.get_total_num_frames() for batch in batches])
     if self.parent.share_batches:
         self.run_frames /= len(self.alloc_devices)
     assert self.run_frames.max_value() > 0
     self.allocated = True
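
The num_updates bookkeeping above counts one update per allocation unless update_specs['block_size'] is set, in which case it is ceil(total seqs in the allocated batches / block_size). A small illustration with a hypothetical helper, mirroring that expression:

    from math import ceil

    def num_updates_for(seqs_per_batch, block_size):
        # seqs_per_batch plays the role of len(batch.seqs) per batch.
        if not block_size:
            return 1
        return int(ceil(sum(seqs_per_batch) / float(block_size)))

    print(num_updates_for([4, 4, 3], 0))  # 1
    print(num_updates_for([4, 4, 3], 5))  # ceil(11 / 5) = 3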
Code Example #5
 def allocate(self):
     self.devices_batches_idx = self.parent.batches.get_current_batch_idx()
     assert len(self.alloc_devices) == 1
     self.devices_batches = [None] * len(self.alloc_devices)
     self.num_frames = NumbersDict(13)
     batch_dim = 1
     self.alloc_devices[0].alloc_data(
         shapes={
             "data": (self.num_frames["data"], batch_dim,
                      config.typed_value("num_inputs")),
             "classes": (self.num_frames["classes"], batch_dim)
         })
     self.parent.num_frames += self.num_frames
     self.allocated = True
Code Example #6
def main(argv):
    """
  Main entry.
  """
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument("config_file",
                           type=str,
                           help="RETURNN config, or model-dir")
    argparser.add_argument("--epoch", type=int)
    argparser.add_argument(
        '--data',
        default="train",
        help="e.g. 'train', 'config:train', or sth like 'config:get_dataset('dev')'")
    argparser.add_argument('--do_search', default=False, action='store_true')
    argparser.add_argument('--beam_size', default=12, type=int)
    argparser.add_argument('--dump_dir', help="for npy or png")
    argparser.add_argument("--output_file", help="hdf")
    argparser.add_argument("--device", help="gpu or cpu (default: automatic)")
    argparser.add_argument("--layers",
                           default=["att_weights"],
                           action="append",
                           help="Layer of subnet to grab")
    argparser.add_argument("--rec_layer",
                           default="output",
                           help="Subnet layer to grab from; decoder")
    argparser.add_argument("--enc_layer", default="encoder")
    argparser.add_argument("--batch_size", type=int, default=5000)
    argparser.add_argument("--seq_list",
                           default=[],
                           action="append",
                           help="predefined list of seqs")
    argparser.add_argument("--min_seq_len",
                           default="0",
                           help="can also be dict")
    argparser.add_argument("--num_seqs",
                           default=-1,
                           type=int,
                           help="stop after this many seqs")
    argparser.add_argument("--output_format",
                           default="npy",
                           help="npy, png or hdf")
    argparser.add_argument("--dropout",
                           default=None,
                           type=float,
                           help="if set, overwrites all dropout values")
    argparser.add_argument("--train_flag", action="store_true")
    argparser.add_argument("--reset_partition_epoch", type=int, default=1)
    argparser.add_argument("--reset_seq_ordering", default="sorted_reverse")
    argparser.add_argument("--reset_epoch_wise_filter", default=None)
    args = argparser.parse_args(argv[1:])

    layers = args.layers
    assert isinstance(layers, list)
    config_fn = args.config_file
    explicit_model_dir = None
    if os.path.isdir(config_fn):
        # Assume we gave a model dir.
        explicit_model_dir = config_fn
        train_log_dir_config_pattern = "%s/train-*/*.config" % config_fn
        train_log_dir_configs = sorted(glob(train_log_dir_config_pattern))
        assert train_log_dir_configs
        config_fn = train_log_dir_configs[-1]
        print("Using this config via model dir:", config_fn)
    else:
        assert os.path.isfile(config_fn)
    model_name = ".".join(config_fn.split("/")[-1].split(".")[:-1])

    init_returnn(config_fn=config_fn, args=args)
    if explicit_model_dir:
        config.set(
            "model", "%s/%s" %
            (explicit_model_dir, os.path.basename(config.value('model', ''))))
    print("Model file prefix:", config.value('model', ''))

    if args.do_search:
        raise NotImplementedError
    min_seq_length = NumbersDict(eval(args.min_seq_len))

    assert args.output_format in ["npy", "png", "hdf"]
    if args.output_format in ["npy", "png"]:
        assert args.dump_dir
        if not os.path.exists(args.dump_dir):
            os.makedirs(args.dump_dir)
    plt = ticker = None
    if args.output_format == "png":
        import matplotlib.pyplot as plt  # need to import early? https://stackoverflow.com/a/45582103/133374
        import matplotlib.ticker as ticker

    dataset_str = args.data
    if dataset_str in ["train", "dev", "eval"]:
        dataset_str = "config:%s" % dataset_str
    extra_dataset_kwargs = {}
    if args.reset_partition_epoch:
        print("NOTE: We are resetting partition epoch to %i." %
              (args.reset_partition_epoch, ))
        extra_dataset_kwargs["partition_epoch"] = args.reset_partition_epoch
    if args.reset_seq_ordering:
        print("NOTE: We will use %r seq ordering." %
              (args.reset_seq_ordering, ))
        extra_dataset_kwargs["seq_ordering"] = args.reset_seq_ordering
    if args.reset_epoch_wise_filter:
        extra_dataset_kwargs["epoch_wise_filter"] = eval(
            args.reset_epoch_wise_filter)
    dataset = init_dataset(dataset_str, extra_kwargs=extra_dataset_kwargs)
    if hasattr(dataset,
               "epoch_wise_filter") and args.reset_epoch_wise_filter is None:
        if dataset.epoch_wise_filter:
            print("NOTE: Resetting epoch_wise_filter to None.")
            dataset.epoch_wise_filter = None
    if args.reset_partition_epoch:
        assert dataset.partition_epoch == args.reset_partition_epoch
    if args.reset_seq_ordering:
        assert dataset.seq_ordering == args.reset_seq_ordering

    init_net(args, layers)
    network = rnn.engine.network

    hdf_writer = None
    if args.output_format == "hdf":
        assert args.output_file
        assert len(layers) == 1
        sub_layer = network.get_layer("%s/%s" % (args.rec_layer, layers[0]))
        from returnn.datasets.hdf import SimpleHDFWriter
        hdf_writer = SimpleHDFWriter(filename=args.output_file,
                                     dim=sub_layer.output.dim,
                                     ndim=sub_layer.output.ndim)

    extra_fetches = {
        "output": network.layers[args.rec_layer].output.get_placeholder_as_batch_major(),
        "output_len": network.layers[args.rec_layer].output.get_sequence_lengths(),  # decoder length
        "encoder_len": network.layers[args.enc_layer].output.get_sequence_lengths(),  # encoder length
        "seq_idx": network.get_extern_data("seq_idx"),
        "seq_tag": network.get_extern_data("seq_tag"),
        "target_data": network.get_extern_data(network.extern_data.default_input),
        "target_classes": network.get_extern_data(network.extern_data.default_target),
    }
    for layer in layers:
        sub_layer = rnn.engine.network.get_layer("%s/%s" % (args.rec_layer, layer))
        extra_fetches["rec_%s" % layer] = sub_layer.output.get_placeholder_as_batch_major()
    # Always use epoch 1, so that we get the same seqs.
    dataset.init_seq_order(epoch=1, seq_list=args.seq_list or None)
    dataset_batch = dataset.generate_batches(
        recurrent_net=network.recurrent,
        batch_size=args.batch_size,
        max_seqs=rnn.engine.max_seqs,
        max_seq_length=sys.maxsize,
        min_seq_length=min_seq_length,
        max_total_num_seqs=args.num_seqs,
        used_data_keys=network.used_data_keys)

    stats = {layer: Stats() for layer in layers}

    # (**dict[str, numpy.ndarray|str|list[numpy.ndarray|str]]) -> None
    def fetch_callback(seq_idx, seq_tag, target_data, target_classes, output,
                       output_len, encoder_len, **kwargs):
        """
    :param list[int] seq_idx: len is n_batch
    :param list[str] seq_tag: len is n_batch
    :param numpy.ndarray target_data: extern data default input (e.g. "data"), shape e.g. (B,enc-T,...)
    :param numpy.ndarray target_classes: extern data default target (e.g. "classes"), shape e.g. (B,dec-T,...)
    :param numpy.ndarray output: rec layer output, shape e.g. (B,dec-T,...)
    :param numpy.ndarray output_len: rec layer seq len, i.e. decoder length, shape (B,)
    :param numpy.ndarray encoder_len: encoder seq len, shape (B,)
    :param kwargs: contains "rec_%s" % l for l in layers, the sub layers (e.g att weights) we are interested in
    """
        n_batch = len(seq_idx)
        for i in range(n_batch):
            # noinspection PyShadowingNames
            for layer in layers:
                att_weights = kwargs["rec_%s" % layer][i]
                stats[layer].collect(att_weights.flatten())
        if args.output_format == "npy":
            data = {}
            for i in range(n_batch):
                data[i] = {
                    'tag': seq_tag[i],
                    'data': target_data[i],
                    'classes': target_classes[i],
                    'output': output[i],
                    'output_len': output_len[i],
                    'encoder_len': encoder_len[i],
                }
                # noinspection PyShadowingNames
                for layer in [("rec_%s" % layer) for layer in layers]:
                    assert layer in kwargs
                    out = kwargs[layer][i]
                    assert out.ndim >= 2
                    assert out.shape[0] >= output_len[i] and out.shape[1] >= encoder_len[i]
                    data[i][layer] = out[:output_len[i], :encoder_len[i]]
            # Save the collected dict once per batch (files are named by seq range).
            fname = args.dump_dir + '/%s_ep%03d_data_%i_%i.npy' % (
                model_name, rnn.engine.epoch, seq_idx[0], seq_idx[-1])
            np.save(fname, data)
        elif args.output_format == "png":
            for i in range(n_batch):
                # noinspection PyShadowingNames
                for layer in layers:
                    extra_postfix = ""
                    if args.dropout is not None:
                        extra_postfix += "_dropout%.2f" % args.dropout
                    elif args.train_flag:
                        extra_postfix += "_train"
                    fname = args.dump_dir + '/%s_ep%03d_plt_%05i_%s%s.png' % (
                        model_name, rnn.engine.epoch, seq_idx[i], layer,
                        extra_postfix)
                    att_weights = kwargs["rec_%s" % layer][i]
                    att_weights = att_weights.squeeze(axis=2)  # (out,enc)
                    assert att_weights.shape[0] >= output_len[i] and att_weights.shape[1] >= encoder_len[i]
                    att_weights = att_weights[:output_len[i], :encoder_len[i]]
                    print("Seq %i, %s: Dump att weights with shape %r to: %s" %
                          (seq_idx[i], seq_tag[i], att_weights.shape, fname))
                    plt.matshow(att_weights)
                    title = seq_tag[i]
                    if dataset.can_serialize_data(
                            network.extern_data.default_target):
                        title += "\n" + dataset.serialize_data(
                            network.extern_data.default_target,
                            target_classes[i][:output_len[i]])
                        ax = plt.gca()
                        tick_labels = [
                            dataset.serialize_data(
                                network.extern_data.default_target,
                                np.array([x], dtype=target_classes[i].dtype))
                            for x in target_classes[i][:output_len[i]]
                        ]
                        ax.set_yticklabels([''] + tick_labels, fontsize=8)
                        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
                    plt.title(title)
                    plt.savefig(fname)
                    plt.close()
        elif args.output_format == "hdf":
            assert len(layers) == 1
            att_weights = kwargs["rec_%s" % layers[0]]
            hdf_writer.insert_batch(inputs=att_weights,
                                    seq_len={
                                        0: output_len,
                                        1: encoder_len
                                    },
                                    seq_tag=seq_tag)
        else:
            raise Exception("output format %r" % args.output_format)

    runner = Runner(engine=rnn.engine,
                    dataset=dataset,
                    batches=dataset_batch,
                    train=False,
                    train_flag=bool(args.dropout) or args.train_flag,
                    extra_fetches=extra_fetches,
                    extra_fetches_callback=fetch_callback)
    runner.run(report_prefix="att-weights epoch %i" % rnn.engine.epoch)
    for layer in layers:
        stats[layer].dump(stream_prefix="Layer %r " % layer)
    if not runner.finalized:
        print("Some error occured, not finalized.")
        sys.exit(1)

    if hdf_writer:
        hdf_writer.close()
    rnn.finalize()
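
For the npy output format above, np.save writes the per-batch Python dict as a pickled object array, so reading it back needs allow_pickle=True and .item(). A sketch (the file name is illustrative):

    import numpy as np

    data = np.load("model_ep001_data_0_4.npy", allow_pickle=True).item()
    for i, entry in sorted(data.items()):
        print(i, entry["tag"], entry["output_len"], entry["encoder_len"])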
Code Example #7
 def __init__(self):
     self.num_frames = NumbersDict(0)
     self.y = {}
     self.j = {}
Code Example #8
    def run_inner(self):
        self.start_time = time.time()
        for device in self.devices:
            device.prepare(epoch=self.epoch, **self.get_device_prepare_args())
        self.initialize()
        terminal_width, _ = terminal_size()
        self.interactive = (log.v[3] and terminal_width >= 0)
        print("starting task", self.task, file=log.v5)

        for device in self.devices:
            device.eval_batch_idx = -1
            device.start_epoch_stats()
            device.num_frames = 0
            device.num_updates = 0
            device.tot = 0

        num_device_runs = 1 if self.share_batches else len(self.devices)
        deviceRuns = [
            self.DeviceBatchRun(
                self,
                [self.devices[i]] if not self.share_batches else self.devices)
            for i in range(num_device_runs)
        ]

        results = {'batchess': [], 'results': [], 'num_frames': NumbersDict(0)}
        run_frames = NumbersDict(0)
        cost_result_format = -1

        crashed = False
        assert num_device_runs > 0

        while True:
            if getattr(sys, "exited", False):
                # This happens when we exit Python.
                # Without this check, this thread would keep running until all exit handlers of Python are done.
                print("%s stopped" % self, file=log.v5)
                crashed = True
                break

            for i in range(num_device_runs):
                if deviceRuns[i].crashed or not deviceRuns[i].is_alive():
                    crashed = True
                    break
                if deviceRuns[i].finished:
                    results['batchess'] += deviceRuns[i].result['batchess'][:]
                    results['results'] += deviceRuns[i].result['results'][:]
                    results['result_format'] = deviceRuns[i].result[
                        'result_format']
                    deviceRuns[i].finished = False
            if crashed:
                break

            if cost_result_format < 0 and deviceRuns[i].result['result_format']:
                for idx, fmt in enumerate(
                        deviceRuns[i].result['result_format']):
                    if fmt and fmt.startswith('cost:'):
                        cost_result_format = idx
            total_cost = 0
            if results['results'] and cost_result_format >= 0:
                total_cost = numpy.asarray(
                    results['results'])[:, cost_result_format].sum()
            if total_cost >= self.eval_batch_size or not self.batches.has_more():
                if all(not (dev.finished or dev.allocated or dev.processing)
                       for dev in deviceRuns):
                    results['num_frames'] = run_frames
                    self.num_frames += run_frames
                    if self.share_batches: run_frames *= len(self.devices)
                    self.reduce(run_frames)
                    self.eval_batch_idx += 1
                    run_frames = NumbersDict(0)
                    results['batchess'] = []
                    results['results'] = []
                    for device in self.devices:
                        device.num_frames = 0
                        device.num_updates = 0
                    if not self.batches.has_more():
                        break
                else:
                    time.sleep(0.01)

            match = True
            while (self.batches.has_more()
                   and total_cost < self.eval_batch_size and match):
                self.batch_idx = self.batches.get_current_batch_idx()
                if self.batch_idx < self.start_batch:
                    self.batches.advance(1)
                    break
                match = False
                for i in range(num_device_runs):
                    if not deviceRuns[i].allocated:
                        deviceRuns[i].allocate()
                        run_frames += deviceRuns[i].run_frames
                        match = True
                        break
            if not match:
                time.sleep(0.01)

        for run in deviceRuns:
            run.stop()
        if crashed: return
        for device in self.devices:
            device.finish_epoch_stats()
        self.finalize()
        if self.interactive: progress_bar()
        self.elapsed = (time.time() - self.start_time)
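
run_inner coordinates its DeviceBatchRun threads purely by polling their allocated/finished flags. A stripped-down sketch of that producer/consumer pattern, plain threading and no RETURNN; the names are illustrative:

    import threading
    import time

    class Worker(threading.Thread):
        """Stand-in for DeviceBatchRun: poll flags, run work, publish a result."""

        def __init__(self):
            super().__init__(daemon=True)
            self.allocated = False  # set True by the driver after assigning work
            self.finished = False   # set True by the worker after publishing a result
            self.work = None
            self.result = None
            self.active = True
            self.start()

        def run(self):
            while self.active:
                if self.allocated and not self.finished:
                    self.result = self.work * 2  # stand-in for device_run()/finish()
                    self.finished = True
                else:
                    time.sleep(0.01)

    workers = [Worker() for _ in range(2)]
    jobs = list(range(6))
    results = []
    while jobs or any(w.allocated for w in workers):
        for w in workers:
            if w.allocated and w.finished:  # collect, then free the worker
                results.append(w.result)
                w.allocated = False
                w.finished = False
            elif not w.allocated and jobs:  # hand out the next job
                w.work = jobs.pop(0)
                w.allocated = True
        time.sleep(0.01)
    for w in workers:
        w.active = False
    print(sorted(results))  # [0, 2, 4, 6, 8, 10]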
Code Example #9
    class DeviceBatchRun(threading.Thread):
        def __init__(self, parent, devices):
            """
        :type parent: TaskThread
        """
            threading.Thread.__init__(self,
                                      name="DeviceThread %s" %
                                      " ".join([dev.name for dev in devices]))
            self.alloc_devices = devices
            self.parent = parent
            self.devices_batches_idx = None
            self.run_start_batch_idx = None
            self.eval_info = None
            " :type: dict[str] | None "
            self.allocated = False
            self.processing = False
            self.finished = True
            self.crashed = False
            self.num_frames = NumbersDict(0)
            self.run_frames = NumbersDict(0)
            self.daemon = True
            self.active = True
            self.result = {
                'batchess': [],
                'results': [],
                'result_format': None,
                'num_frames': 0
            }
            if self.alloc_devices:
                self.start()

        def allocate(self):
            self.devices_batches_idx = self.parent.batches.get_current_batch_idx()
            self.allocated_devices_batches = self.parent.allocate_devices(
                self.alloc_devices)
            self.run_frames = NumbersDict(0)
            for batches, device in zip(self.allocated_devices_batches,
                                       self.alloc_devices):
                assert batches
                assert batches[0].seqs
                #assert batches[0].seqs[0].frame_length[1] > 0
                device.num_updates += 1 if not device.update_specs['block_size'] else int(
                    ceil(sum([len(batch.seqs) for batch in batches]) /
                         float(device.update_specs['block_size'])))
                self.run_frames += sum(
                    [batch.get_total_num_frames() for batch in batches])
            if self.parent.share_batches:
                self.run_frames /= len(self.alloc_devices)
            assert self.run_frames.max_value() > 0
            self.allocated = True

        def finish(self):
            """
        :returns whether everything is fine.
        """
            device_results, outputs_format = self.device_collect_results()
            if device_results is None:
                if not getattr(sys, "exited", False):
                    print("device crashed on batch",
                          self.run_start_batch_idx,
                          file=log.v3)
                self.parent.device_crash_batch = self.run_start_batch_idx
                self.crashed = True
                return False
            assert len(device_results) == len(self.alloc_devices) == len(
                self.running_devices_batches)

            if outputs_format and any(
                [k.startswith("gparam:") for k in outputs_format]):
                # WARNING: this code is untested and likely broken!
                for i in range(len(self.alloc_devices)):
                    res = Device.make_result_dict(device_results[i],
                                                  outputs_format)
                    self.alloc_devices[i].sync_net_train_params()
                    devnet = self.alloc_devices[i].get_net_train_params(
                        self.parent.network)
                    vars = self.parent.network.get_all_params_vars()
                    for p, q in zip(vars, devnet):
                        p.set_value(q)
                    gparams = {}
                    for p in vars:
                        gparams[p] = numpy.zeros(
                            p.get_value(borrow=True, return_internal_type=True).shape,
                            dtype=theano.config.floatX)
                    for p in vars:
                        q = res["gparam:%s" % p.name]
                        if q.shape == p.get_value().shape:
                            gparams[p] = q
                        elif q.shape:
                            print(
                                "warning: shape for gradient does not match:",
                                p.get_value().shape,
                                q.shape,
                                file=log.v2)
                    self.parent.updater.setNetParamDeltas(gparams)
                    self.parent.updater.update()
                    self.alloc_devices[i].set_net_params(self.parent.network)

            self.result = {
                'batchess': self.running_devices_batches,
                'results': device_results,
                'result_format': outputs_format,
                'num_frames': self.num_frames
            }
            self.eval_info = self.parent.evaluate(**self.result)
            self.parent.lock.acquire()
            self.print_process()
            self.parent.lock.release()
            return True

        def run(self):
            try:
                while self.active and not getattr(sys, "exited", False):
                    if self.allocated and not self.finished:
                        self.device_run()
                        self.num_frames = self.run_frames
                        self.processing = True
                        self.allocated = False
                        self.finish()
                        self.finished = True
                        self.processing = False
                    else:
                        time.sleep(0.01)
            except BaseException:
                self.crashed = True
                sys.excepthook(*sys.exc_info())
            finally:
                self.finished = True

        def stop(self):
            self.active = False

        def device_run(self):
            batch_idx = self.run_start_batch_idx = self.devices_batches_idx
            assert len(self.alloc_devices) == len(
                self.allocated_devices_batches)
            self.running_devices_batches = self.allocated_devices_batches
            for device, batches in zip(self.alloc_devices,
                                       self.running_devices_batches):
                if self.parent.network.recurrent:
                    print("running", device.targets["data"].shape[1], \
                                     "sequence slices (%i nts)" % (device.targets["data"].shape[0] * device.targets["data"].shape[1]), end=' ', file=log.v5)
                else:
                    print("running",
                          device.targets["data"].shape[0] *
                          device.targets["data"].shape[1],
                          "frames",
                          end=' ',
                          file=log.v5)
                if device.num_batches == 1:
                    print("of batch %i" % batch_idx, end=' ', file=log.v5)
                else:
                    print("of batches %i-%i" %
                          (batch_idx, batch_idx + device.num_batches - 1),
                          end=' ',
                          file=log.v5)
                print("on device", device.name, file=log.v5)
                device.run(self.parent.task)

        # if not self.parent.share_batches: batch_idx += device.num_batches

        def device_collect_results(self):
            device_results = []
            outputs_format = None
            for i, device in enumerate(self.alloc_devices):
                try:
                    result, outputs_format_new = device.result()
                except RuntimeError:
                    return None, None
                if result is None:
                    return None, None
                assert isinstance(result, list)
                assert len(result) > 0  # we always expect to get some result
                if i >= 1:
                    assert outputs_format == outputs_format_new, "We expect to always get the same output format."
                outputs_format = outputs_format_new
                device_results.append(result)
            return device_results, outputs_format

        def device_mem_usage_str(self, devices):
            """
        :type devices: list[Device.Device]
        :rtype: str | None
        """
            if not devices:
                return None
            mem_info = [device.get_memory_info() for device in devices]
            if len(mem_info) == 1 and mem_info[0] is None:
                return None
            mem_usage = [info.used if info else None for info in mem_info]
            s = [
                "%s MB" % (mem /
                           (1024 * 1024)) if mem is not None else "unknown"
                for mem in mem_usage
            ]
            return "/".join(s)

        def print_process(self):
            if not self.parent.interactive and not log.v[5]:
                return
            start_elapsed = time.time() - self.parent.start_time
            complete = self.parent.batches.completed_frac()
            assert complete > 0
            total_time_estimated = start_elapsed / complete
            remaining_estimated = total_time_estimated - start_elapsed
            if log.verbose[5]:
                mem_usage = self.device_mem_usage_str(self.alloc_devices)
                info = [
                    self.parent.report_prefix,
                    "batch %i" % self.run_start_batch_idx
                ]
                if self.eval_info:  # Such as score.
                    info += [
                        "%s %s" % item
                        for item in sorted(self.eval_info.items())
                    ]
                info += [
                    "elapsed %s" % hms(start_elapsed),
                    "exp. remaining %s" % hms(remaining_estimated),
                    "complete %.02f%%" % (complete * 100)
                ]
                if mem_usage:
                    info += ["memory %s" % mem_usage]
                print(", ".join(filter(None, info)), file=log.v5)
            if self.parent.interactive:
                progress_bar(complete, hms(remaining_estimated))
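
print_process estimates remaining time by scaling elapsed wall-clock time with the completed fraction of batches. The arithmetic in isolation:

    elapsed = 120.0  # seconds since self.parent.start_time
    complete = 0.25  # self.parent.batches.completed_frac()
    total_estimated = elapsed / complete      # 480.0
    remaining = total_estimated - elapsed     # 360.0
    print("elapsed %.0fs, exp. remaining %.0fs, complete %.02f%%"
          % (elapsed, remaining, complete * 100))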
Code Example #10
def test_iterate_seqs_chunking_varying_sequence_length():
    dataset = DummyDatasetMultipleSequenceLength(
        input_dim=2, output_dim=3, num_seqs=2, seq_len={'data': 24, 'classes': 12})
    dataset.init_seq_order(1)
    seqs = list(dataset.iterate_seqs(
        chunk_size={'data': 12, 'classes': 6},
        chunk_step={'data': 6, 'classes': 3},
        used_data_keys=None))
    for s in seqs:
        print(s)
    assert_equal(len(seqs), 8)
    assert_equal(seqs[0], (0, NumbersDict({'data': 0, 'classes': 0}), NumbersDict({'data': 12, 'classes': 6})))
    assert_equal(seqs[1], (0, NumbersDict({'data': 6, 'classes': 3}), NumbersDict({'data': 18, 'classes': 9})))
    assert_equal(seqs[2], (0, NumbersDict({'data': 12, 'classes': 6}), NumbersDict({'data': 24, 'classes': 12})))
    assert_equal(seqs[3], (0, NumbersDict({'data': 18, 'classes': 9}), NumbersDict({'data': 24, 'classes': 12})))
    assert_equal(seqs[4], (1, NumbersDict({'data': 0, 'classes': 0}), NumbersDict({'data': 12, 'classes': 6})))
    assert_equal(seqs[5], (1, NumbersDict({'data': 6, 'classes': 3}), NumbersDict({'data': 18, 'classes': 9})))
    assert_equal(seqs[6], (1, NumbersDict({'data': 12, 'classes': 6}), NumbersDict({'data': 24, 'classes': 12})))
    assert_equal(seqs[7], (1, NumbersDict({'data': 18, 'classes': 9}), NumbersDict({'data': 24, 'classes': 12})))
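
The expected tuples follow directly from the chunking arithmetic: a chunk starts every chunk_step frames and is clipped at the sequence end. A hypothetical chunk_ranges helper (not RETURNN's actual implementation) reproduces the 'data' numbers:

    def chunk_ranges(seq_len, chunk_size, chunk_step):
        start = 0
        while start < seq_len:
            yield start, min(start + chunk_size, seq_len)
            start += chunk_step

    print(list(chunk_ranges(24, 12, 6)))
    # [(0, 12), (6, 18), (12, 24), (18, 24)] -> 4 chunks per seq, 8 for 2 seqs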
Code Example #11
def analyze_dataset(options):
    """
  :param options: argparse.Namespace
  """
    print("Epoch: %i" % options.epoch, file=log.v3)
    print("Dataset keys:", dataset.get_data_keys(), file=log.v3)
    print("Dataset target keys:", dataset.get_target_list(), file=log.v3)
    assert options.key in dataset.get_data_keys()

    terminal_width, _ = util.terminal_size()
    show_interactive_process_bar = (log.verbose[3] and (not log.verbose[5])
                                    and terminal_width >= 0)

    start_time = time.time()
    num_seqs_stats = Stats()
    if options.endseq < 0:
        options.endseq = float("inf")

    recurrent = True
    used_data_keys = dataset.get_data_keys()
    batch_size = config.typed_value('batch_size', 1)
    max_seqs = config.int('max_seqs', -1)
    seq_drop = config.float('seq_drop', 0.0)
    max_seq_length = config.typed_value(
        'max_seq_length', None) or config.float('max_seq_length', 0)
    max_pad_size = config.typed_value("max_pad_size", None)

    batches = dataset.generate_batches(recurrent_net=recurrent,
                                       batch_size=batch_size,
                                       max_seqs=max_seqs,
                                       max_seq_length=max_seq_length,
                                       max_pad_size=max_pad_size,
                                       seq_drop=seq_drop,
                                       used_data_keys=used_data_keys)

    step = 0
    total_num_seqs = 0
    total_num_frames = NumbersDict()
    total_num_used_frames = NumbersDict()

    try:
        while batches.has_more():
            # See FeedDictDataProvider.
            batch, = batches.peek_next_n(1)
            assert isinstance(batch, Batch)
            if batch.start_seq > options.endseq:
                break
            dataset.load_seqs(batch.start_seq, batch.end_seq)
            complete_frac = batches.completed_frac()
            start_elapsed = time.time() - start_time
            try:
                num_seqs_s = str(dataset.num_seqs)
            except NotImplementedError:
                try:
                    num_seqs_s = "~%i" % dataset.estimated_num_seqs
                except TypeError:  # a number is required, not NoneType
                    num_seqs_s = "?"
            progress_prefix = "%i/%s" % (batch.start_seq, num_seqs_s)
            progress = "%s (%.02f%%)" % (progress_prefix, complete_frac * 100)
            if complete_frac > 0:
                total_time_estimated = start_elapsed / complete_frac
                remaining_estimated = total_time_estimated - start_elapsed
                progress += " (%s)" % hms(remaining_estimated)

            batch_max_time = NumbersDict.max(
                [seq.frame_length for seq in batch.seqs]) * len(batch.seqs)
            batch_num_used_frames = sum(
                [seq.frame_length for seq in batch.seqs], NumbersDict())
            total_num_seqs += len(batch.seqs)
            num_seqs_stats.collect(numpy.array([len(batch.seqs)]))
            total_num_frames += batch_max_time
            total_num_used_frames += batch_num_used_frames

            print("%s, batch %i, num seqs %i, frames %s, used %s (%s)" %
                  (progress, step, len(
                      batch.seqs), batch_max_time, batch_num_used_frames,
                   batch_num_used_frames / batch_max_time),
                  file=log.v5)
            if show_interactive_process_bar:
                util.progress_bar_with_time(complete_frac,
                                            prefix=progress_prefix)

            step += 1
            batches.advance(1)

    finally:
        print("Done. Total time %s. More seqs which we did not dumped: %s" %
              (hms(time.time() - start_time), batches.has_more()),
              file=log.v2)
        print("Dataset epoch %i, order %r." %
              (dataset.epoch, dataset.seq_ordering))
        print("Num batches (steps): %i" % step, file=log.v1)
        print("Num seqs: %i" % total_num_seqs, file=log.v1)
        num_seqs_stats.dump(stream=log.v1, stream_prefix="Batch num seqs ")
        for key in used_data_keys:
            print("Data key %r:" % key, file=log.v1)
            print("  Num frames: %s" % total_num_frames[key], file=log.v1)
            print("  Num used frames: %s" % total_num_used_frames[key],
                  file=log.v1)
            print("  Fraction used frames: %s" %
                  (total_num_used_frames / total_num_frames)[key],
                  file=log.v1)
        dataset.finish_epoch()
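
The "fraction used frames" printed above is plain NumbersDict arithmetic: the padded size is the per-key maximum frame length times the number of seqs, the used size is the sum of the true lengths. A minimal sketch, assuming RETURNN is installed:

    from returnn.util.basic import NumbersDict

    frame_lens = [NumbersDict({"data": 20, "classes": 10}),
                  NumbersDict({"data": 14, "classes": 7})]
    batch_max_time = NumbersDict.max(frame_lens) * len(frame_lens)  # padded frames
    used = sum(frame_lens, NumbersDict())                           # true frames
    frac = used / batch_max_time
    print(frac["data"], frac["classes"])  # 0.85 0.85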