示例#1
0
    def forward(self, xs, n_speakers, activation=None):
        """Embed a batch of sequences and score them against decoded attractors.

        Args:
            xs: sequence of per-utterance feature arrays; ``x.shape[0]`` of
                each item is used as that item's length.
            n_speakers: speaker count passed to the decoder, except when this
                link is observed as ``"validation_1/main"`` (then ``None``).
            activation: accepted for interface compatibility; not used.

        Returns:
            tuple ``(ys, P)``: ``ys`` is a list with one
            ``F.matmul(embedding, attractor.T)`` result per sequence, and
            ``P`` is the second value returned by ``self.decoder``.
        """
        seq_lens = [x.shape[0] for x in xs]
        # Pad every sequence to a common length (pad value -1) so the whole
        # batch goes through the encoder in one call.
        padded = F.pad_sequence(xs, padding=-1)
        pad_shape = padded.shape
        # Encoder output; original notes this as (B*T, E) — TODO confirm.
        flat = self.enc(padded)
        # Restore (B, T, -1), split along the batch axis, then trim each
        # item back to its unpadded length.
        per_seq = F.separate(flat.reshape(pad_shape[0], pad_shape[1], -1),
                             axis=0)
        trimmed = [F.get_item(seq, slice(0, n))
                   for seq, n in zip(per_seq, seq_lens)]
        # Randomly permuted copies (cupy RNG) are what the attractor encoder sees.
        shuffled = [cp.random.permutation(seq) for seq in trimmed]

        # The observer name registered for this link selects decoder settings:
        #   'main'               -> n_speakers, to_train=1
        #   'validation_1/main'  -> None,       to_train=0
        #   any other name       -> n_speakers, to_train=0
        observer_name = reporter.get_current_reporter()._observer_names[id(self)]
        if observer_name == "validation_1/main":
            decoder_speakers = None
        else:
            decoder_speakers = n_speakers
        train_flag = int(observer_name == 'main')

        # Encoder states; original comments give (1, B, F) — TODO confirm.
        h_0, c_0 = self.encoder(shuffled)
        attractors, probs = self.decoder(h_0,
                                         c_0,
                                         n_speakers=decoder_speakers,
                                         to_train=train_flag)
        # One score matrix per sequence: trimmed embeddings x attractors^T.
        scores = [F.matmul(seq, att.T) for att, seq in zip(attractors, trimmed)]

        return scores, probs
示例#2
0
    def __call__(self, trainer=None):
        """Executes the evaluator extension.

        Unlike usual extensions, this extension can be executed without passing
        a trainer object. This extension reports the performance on validation
        dataset using the :func:`~chainer.report` function. Thus, users can use
        this extension independently from any trainer by manually configuring
        a :class:`~chainer.Reporter` object.

        Args:
            trainer (~chainer.training.Trainer): Trainer object that invokes
                this extension. It can be omitted (defaults to ``None``) when
                calling this extension manually; this method does not use it.

        Returns:
            dict: Result dictionary that contains mean statistics of values
            reported by the evaluation function.

        """
        # Set up a dedicated reporter; every target (and each of its named
        # child links) is registered under the optional "<self.name>/" prefix.
        reporter = reporter_module.Reporter()
        prefix = self.name + '/' if self.name is not None else ''
        for name, target in six.iteritems(self._targets):
            reporter.add_observer(prefix + name, target)
            reporter.add_observers(prefix + name,
                                   target.namedlinks(skipself=True))

        # Evaluate with the dedicated reporter active and the 'train'
        # configuration switched off.
        with reporter:
            with configuration.using_config('train', False):
                result = self.evaluate()

        # ``report`` forwards to the caller's current reporter when one is
        # active and is a no-op otherwise. Calling get_current_reporter()
        # directly would raise when this extension is used manually with
        # no reporter in scope.
        reporter_module.report(result)
        return result
示例#3
0
def image(images, name=None, ch_axis=1, row=0, mode=None):
    """Register images on the global observation for visualization.

    The images are reported through the current :class:`chainer.Reporter`
    for ``chainerui_image_observer`` and are collected later by
    :class:`chainerui.extensions.ImageReporter`. This function is meant to be
    used with :class:`chainer.training.Trainer`. Without a trainer, a
    :class:`chainer.Reporter` must be set up before calling it.

    Args:
        images (:class:`numpy.ndarray`, :class:`cupy.ndarray` or
            :class:`chainer.Variable`): a batch of images or a single image.
        name (str): name of the images. Defaults to ``'0'`` when not given.
        ch_axis (int): index of the channel dimension, 1 by default. Unless
            it is ``-1``, that axis is moved to the last position.
        row (int): number of rows used to lay out batched images. When 0,
            the images are shown unstacked. If ``images`` holds only one
            image, the row size is ignored.
        mode (str): color space of the images when they are not RGB or RGBA,
            stored lower-cased. ChainerUI supports ``'HSV'``.
    """
    active = reporter.get_current_reporter()
    img_observer = chainerui_image_observer
    with reporter.report_scope(img_observer.observation):
        # TODO(disktnk): support tupled image and increment automatically
        key = '0' if name is None else name

        data = images.data if isinstance(images, chainer.Variable) else images
        data = cuda.to_cpu(data)

        if ch_axis != -1:
            # Keep the remaining axes in order and put the channel axis last.
            others = np.delete(np.arange(data.ndim), ch_axis)
            data = data.transpose(np.append(others, ch_axis))

        payload = {'array': data}
        if row > 0:
            payload['row'] = row
        if mode is not None:
            payload['mode'] = mode.lower()
        active.report({key: payload}, img_observer)
 def __call__(self, trainer=None):
     """Run the evaluation and report its result.

     Args:
         trainer: not used by this method; may be omitted for manual calls.

     Returns:
         The value returned by ``self._evaluate()``.
     """
     result = self._evaluate()
     # Presumably ``reporter_`` is chainer.reporter: ``report`` forwards to
     # the current reporter when one is active and does nothing otherwise.
     # Fetching get_current_reporter() up front made manual calls without an
     # active reporter fail before any evaluation happened.
     reporter_.report(result)
     return result
示例#5
0
文件: vgg.py 项目: hisakaz0/py-utils
    def evaluate(self):
        """Evaluate the main target over the whole (sequential) main iterator.

        Network outputs are pushed into ``self.pool``; each time the pool is
        full, ``self.eval_func(label)`` is run and its result is reported for
        the observer of that label (observers are registered with the
        ``'validation/'`` prefix).

        Returns:
            dict: mean of the collected observations
            (``DictSummary.compute_mean()``).

        Raises:
            ValueError: if the iterator shuffles its data.
            TypeError: if ``self.converter`` does not return a tuple.
        """
        iterator = self._iterators['main']
        target = self._targets['main']
        eval_func = self.eval_func

        if self.eval_hook:
            self.eval_hook(self)

        if hasattr(iterator, 'reset'):
            iterator.reset()
            it = iterator
        else:
            it = copy.copy(iterator)

        # The pooling logic below relies on consecutive samples sharing a
        # label, so a shuffling iterator cannot be used.
        if it.shuffle:
            raise ValueError("This evaluator accepts only sequential iterator")

        # One anonymous observer per label; results for label i are reported
        # against observers[i].
        observers = [object() for _ in range(self.num_labels)]
        prefix = 'validation/'
        reporter = reporter_module.get_current_reporter()
        for index, observer in enumerate(observers):
            o = (str(index), observer)
            reporter.add_observer(*o)
            reporter.add_observers(prefix, [o])

        # self.pool outlives this call. Without resetting it, leftovers from a
        # previous evaluation would be mixed in, and ``label`` would be unbound
        # on the first loop iteration if the pool were not empty. Uses the
        # same ``index = 0`` reset as inside the loop.
        self.pool.index = 0

        summary = reporter_module.DictSummary()
        len_it = len(it.dataset) // it.batch_size  # only for the progress line

        interval = 0
        for index, batch in enumerate(it):
            observation = {}
            with reporter_module.report_scope(observation):
                in_arrays = self.converter(batch, self.device)
                with function.no_backprop_mode():
                    if not isinstance(in_arrays, tuple):
                        raise TypeError(
                            'converter must return a tuple, got {}'.format(
                                type(in_arrays).__name__))
                    images, labels = in_arrays
                    prob, = target(inputs={'data': images}, outputs=['prob'])
                    res = Queue(prob)

                    # VGG test flow (flowchart):
                    # 1. Repeat until the activations are exhausted
                    #    (signalled here by ``res.is_full``).
                    # 2. Starting at the pool index, fill the free part of the
                    #    pool with activations taken from the head of ``res``.
                    #    When the pool index is 0, take the pool label from
                    #    ``labels``.
                    # 3. Check the pool state.
                    # 4. If it is full go to 5, otherwise go back to 1.
                    # 5. Run eval_func, empty the pool, and go back to 1.
                    while not res.is_full:
                        if self.pool.is_empty:
                            label = labels[res.index]

                        set_remain(self.pool, res)

                        if self.pool.is_full:
                            acc = eval_func(label)
                            report(acc, observers[int(label)])
                            self.pool.index = 0  # reset (empty) the pool

            summary.add(observation)
            # Hand-rolled progress output: print_report would have to be
            # aligned with the training schedule, so it is not used here.
            if interval >= self.interval:
                sys.stdout.write("\riteration: {}/{}".format(index, len_it))
                # A '\r'-terminated line is otherwise held in the buffer.
                sys.stdout.flush()
                interval = 0
            else:
                interval += 1

        return summary.compute_mean()