Пример #1
0
 def animate(i):
     """Update the time label and all four images for animation frame ``i``.

     Returns every artist that was modified. When the animation runs with
     ``blit=True`` only the returned artists are redrawn, so all four images
     have to be part of the return value.
     """
     text.set_text(str(time[i]))
     im1.set_data(to_np(frames1[i]))
     im2.set_data(to_np(frames2[i]))
     im3.set_data(to_np(frames3[i]))
     im4.set_data(to_np(frames4[i]))
     # im4 was previously missing here, which left the fourth panel stale
     # under blitting even though its data had been updated.
     return im1, im2, im3, im4, text
Пример #2
0
 def init():
     """Reset the animation: blank the label and show frame 0 of each series."""
     text.set_text("")
     panels = ((im1, frames1), (im2, frames2), (im3, frames3), (im4, frames4))
     for image, series in panels:
         image.set_data(to_np(series[0]))
     return im1, im2, im3, im4, text
Пример #3
0
    def predict(self, initial_input, initial_state, nr_predictions):
        """Run the model freely for ``nr_predictions`` steps.

        Each step advances the reservoir state with ``self.esn_cell.forward``,
        computes an output frame from ``self.wout`` and feeds that output back
        in as the next input.

        Returns a tuple ``(outputs, states)`` holding one output frame and one
        state vector per step.
        """
        self.timer.begin("predict")

        n_steps = nr_predictions
        hidden = initial_state.shape[0]
        rows, cols = initial_input.shape
        inp = to_bh(initial_input)
        state = to_bh(initial_state)

        # Preallocate the result buffers: one state / one frame per step.
        states = bh.empty((n_steps, hidden), dtype=state.dtype)
        outputs = bh.empty((n_steps, rows, cols), dtype=inp.dtype)

        for step in range(n_steps):
            state = self.esn_cell.forward(inp, state)

            # The extended state [ones, flattened input, state] is assembled
            # with numpy and converted back afterwards.
            state_np = to_np(state)
            flat_inp_np = to_np(inp).reshape(-1)
            ones_np = to_np(self.ones)
            extended = np.concatenate([ones_np, flat_inp_np, state_np], axis=0)
            ext_state = to_bh(extended)

            output = bh_dot(self.wout, ext_state).reshape((rows, cols))

            # Feed the prediction back in as the next input.
            inp = output
            outputs[step] = to_bh(output)
            states[step] = to_bh(state)

        self.timer.end()
        return outputs, states
Пример #4
0
def train_predict_esn(model, dataset, outdir=None, shuffle=False, steps=1,
                      step_length=1, step_start=0):
    """Repeatedly train the model's output weights and predict ahead.

    For each of ``steps`` iterations a sample ``(inputs, labels, pred_labels)``
    is taken from ``dataset``, the reservoir states for ``inputs`` are
    computed, the output weights are optimized (skipping the first
    ``model.params.transient_length`` entries) and
    ``model.params.pred_length`` frames are predicted starting from the last
    label and the last training state.

    Parameters
    ----------
    model : object providing ``params``, ``esn_cell``, ``timer``, ``forward``,
        ``optimize`` and ``predict``.
    dataset : indexable returning ``(inputs, labels, pred_labels)``.
    outdir : str or pathlib.Path, optional
        If given, training data, the model and the predictions of every step
        are written below this directory.
    shuffle : bool
        Draw the dataset index at random instead of stepping through it.
    steps : int
        Number of train/predict iterations.
    step_length, step_start : int
        Index of iteration ``ii`` is ``ii * step_length + step_start``
        (ignored when ``shuffle`` is true).

    Returns
    -------
    tuple
        ``(model, outputs, pred_labels)`` of the last iteration. ``outputs``
        and ``pred_labels`` are ``None`` when ``steps`` is 0.
    """
    if outdir is not None and not isinstance(outdir, pathlib.Path):
        outdir = pathlib.Path(outdir)

    tlen = model.params.transient_length
    hidden_size = model.esn_cell.hidden_size
    backend = model.params.backend
    dtype = model.esn_cell.dtype

    if shuffle:
        logger.warning("If shuffle is True `step_length` has no effect!")

    # Defined up front so the return statement is valid even if the loop
    # body never runs (steps == 0).
    outputs = None
    pred_labels = None

    for ii in range(steps):
        model.timer.reset()

        logger.info(f"--- Train/Predict Step Nr. {ii+1} ---")
        if shuffle:
            idx = np.random.randint(low=0, high=len(dataset))
        else:
            idx = ii * step_length + step_start
        inputs, labels, pred_labels = dataset[idx]

        logger.info(f"Creating {inputs.shape[0]} training states")
        zero_state = initial_state(hidden_size, dtype, backend)
        _, states = model.forward(inputs, zero_state, states_only=True)

        if outdir is not None:
            outfile = outdir / f"train_data_idx{idx}.nc"
            logger.info(f"Saving training to {outfile}")
            dump_training(outfile, dataset, idx, states=states)

        logger.info("Optimizing output weights")
        # The first `tlen` steps are excluded from the optimization.
        model.optimize(inputs=inputs[tlen:], states=states[tlen:], labels=labels[tlen:])

        if outdir is not None:
            save_model(outdir, model, prefix=f"idx{idx}")

        logger.info(f"Predicting the next {model.params.pred_length} frames")
        init_inputs = labels[-1]
        outputs, out_states = model.predict(
            init_inputs, states[-1], nr_predictions=model.params.pred_length)

        logger.info(model.timer.pretty_print())

        if outdir is not None:
            outfile = outdir / f"pred_data_idx{idx}.nc"
            logger.info(f"Saving prediction to {outfile}")
            dump_prediction(
                outfile, outputs=to_np(outputs), labels=to_np(pred_labels), states=to_np(out_states))

    logger.info("Done")
    return model, outputs, pred_labels
Пример #5
0
def resample2d_numpy(image, size, timer=None):
    """Bilinearly resample the last two axes of ``image`` to ``size``.

    Parameters
    ----------
    image : array with at least two dimensions; the last two are treated as
        (rows, columns), any leading dimensions are kept.
    size : tuple ``(n, m)``
        Number of rows and columns of the result.
    timer : optional
        Passed to ``start_timer`` / ``end_timer``.

    Returns
    -------
    Array of shape ``image.shape[:-2] + (n, m)``.
    """
    start_timer(timer, "resample2d bilinear numpy")

    N, M = image.shape[-2:]
    n, m = size

    # Sample positions along each axis, shaped for broadcasting to (n, m).
    xs = np.linspace(1, M - 1, m)[None, :]
    ys = np.linspace(1, N - 1, n)[:, None]

    # np.modf returns (fractional part, integral part). The integral parts
    # are the indices of the two neighbouring pixels along each axis, the
    # fractional part is the distance from the lower ("minus") neighbour.
    yminus, iminus = np.modf(ys - 0.5)
    yplus, iplus = np.modf(ys + 0.5)
    xminus, jminus = np.modf(xs - 0.5)
    xplus, jplus = np.modf(xs + 0.5)

    # Flatten each 2D image so that neighbours can be gathered with
    # row-major flat indices (row * M + column).
    I = to_np(image.reshape((-1, M * N)))

    LD = (iminus * M + jminus).astype(np.uint64)  # x-,y-
    LU = (iplus * M + jminus).astype(np.uint64)   # x-,y+
    RD = (iminus * M + jplus).astype(np.uint64)   # x+,y-
    RU = (iplus * M + jplus).astype(np.uint64)    # x+,y+

    # Each corner is weighted by (1 - distance) along both axes. The x+
    # corners previously had their y weights swapped: RD (y-) was weighted
    # with yplus and RU (y+) with (1 - yminus).
    I_bilin = ((1 - xminus) * (1 - yminus) * I[:, LD]
               + (1 - xminus) * yplus * I[:, LU]
               + xplus * (1 - yminus) * I[:, RD]
               + xplus * yplus * I[:, RU])

    new_shape = image.shape[:-2] + (n, m)
    end_timer(timer)
    return I_bilin.reshape(new_shape)
0
def tikhonov(inputs, states, labels, beta):
    """Solve for the output weights with Tikhonov (ridge) regularization.

    With ``X`` the extended state matrix built from ``inputs`` and ``states``
    this solves ``(X X^T + beta * I) W = X labels`` for ``W``.

    Parameters
    ----------
    inputs, states : forwarded to ``_extended_states``.
    labels : target values, one row per column of ``X``.
    beta : float
        Regularization strength.

    Returns
    -------
    The transposed solution ``W.T``, converted with ``to_bh``.
    """
    X = to_np(_extended_states(inputs, states))

    Id = np.eye(X.shape[0])
    # The regularizer is beta times the identity. The previous `beta + Id`
    # added beta to every matrix element and 1 to the diagonal instead.
    A = np.dot(X, X.T) + beta * Id
    B = np.dot(X, labels)

    # Solve linear system instead of calculating inverse
    wout = np.linalg.solve(A, B)
    return to_bh(wout.T)
Пример #7
0
def dump_training(fname, dataset, idx, states, attrs=None):
    """Write one training sample and its states to a file via ``nc.Dataset``.

    Parameters
    ----------
    fname : str or pathlib.Path
        Output file. Missing parent directories are created.
    dataset : indexable returning ``(inputs, labels, pred_labels)``; its
        ``params.dict`` is checked for a ``"cycle_length"`` entry.
    idx : index of the sample to dump.
    states : array with one row of hidden states per training step.
    attrs : dict, optional
        Global attributes set on the file.

    Raises
    ------
    ValueError
        If the inputs of the sample are not a numpy array.
    """
    inputs, labels, pred_labels = dataset[idx]

    # Only numpy inputs are accepted. A conversion path for tensor inputs
    # used to follow this raise but was unreachable, so it has been removed.
    if not isinstance(inputs, np.ndarray):
        raise ValueError("Check that this acutally works...")

    if not isinstance(fname, pathlib.Path):
        fname = pathlib.Path(fname)
    # exist_ok avoids the race between a separate exists() check and mkdir().
    fname.parent.mkdir(parents=True, exist_ok=True)

    with nc.Dataset(fname, "w") as dst:

        dst.createDimension("train_length", inputs.shape[0])
        dst.createDimension("pred_length", pred_labels.shape[0])
        dst.createDimension("image_height", inputs.shape[1])
        dst.createDimension("image_width", inputs.shape[2])
        dst.createDimension("hidden_size", states.shape[1])

        dst.createVariable("inputs", float, ["train_length", "image_height", "image_width"])
        dst.createVariable("labels", float, ["train_length", "image_height", "image_width"])
        dst.createVariable("states", float, ["train_length", "hidden_size"])
        dst.createVariable("pred_labels", float, ["pred_length", "image_height", "image_width"])

        if "cycle_length" in dataset.params.dict:
            dump_cycles(dst, dataset)

        if attrs is not None:
            dst.setncatts(attrs)

        dst["inputs"][:] = to_np(inputs)
        dst["labels"][:] = to_np(labels)
        dst["states"][:] = to_np(states)
        dst["pred_labels"][:] = to_np(pred_labels)
Пример #8
0
    def forward(self, image, state):
        """Compute the next state as ``tanh(input_map(image) + state_map(state))``.

        The whole call is timed under "forward"; the tanh evaluation is
        additionally timed under a nested "tanh" timer.
        """
        start_timer(self.timer, "forward")

        self.check_dtypes(image, state)

        # Input contribution: map the image, join the pieces into one vector.
        mapped_inputs = self.input_map(to_np(image))
        joined_inputs = bh.concatenate(mapped_inputs)
        from_input = to_bh(joined_inputs)

        # State contribution.
        from_state = self.state_map(to_bh(state))

        start_timer(self.timer, "tanh")
        activated = bh.tanh(from_input + from_state)
        end_timer(self.timer)  # closes "tanh"
        end_timer(self.timer)  # closes "forward"
        return activated
Пример #9
0
def animate_imshow(frames, time=None, vmin=None, vmax=None,
                   cmap_name="inferno", figsize=(8, 5)):
    """Animate a sequence of 2D frames with ``imshow``.

    Parameters
    ----------
    frames : sequence of 2D arrays (anything supporting ``len`` and indexing).
    time : sequence, optional
        Label shown above the axes for each frame; defaults to the frame index.
    vmin, vmax, cmap_name : forwarded to ``imshow``.
    figsize : figure size passed to ``plt.figure``.

    Returns
    -------
    The ``FuncAnimation`` object. Note that calling this function replaces
    ``matplotlib.animation.Animation._blit_draw`` globally.
    """
    def _blit_draw(self, artists, bg_cache):
        # Handles blitted drawing, which renders only the artists given instead
        # of the entire figure.
        updated_ax = []
        for a in artists:
            # If we haven't cached the background for this axes object, do
            # so now. This might not always be reliable, but it's an attempt
            # to automate the process.
            if a.axes not in bg_cache:
                # Cache the whole figure rather than just the axes bbox.
                bg_cache[a.axes] = a.figure.canvas.copy_from_bbox(a.axes.figure.bbox)
            a.axes.draw_artist(a)
            updated_ax.append(a.axes)

        # After rendering all the needed artists, blit once per axes, again
        # using the figure bbox instead of the axes bbox.
        for ax in set(updated_ax):
            ax.figure.canvas.blit(ax.figure.bbox)
    matplotlib.animation.Animation._blit_draw = _blit_draw
    fig = plt.figure(figsize=figsize)
    ax = plt.gca()
    im = ax.imshow(to_np(frames[0]), animated=True, vmin=vmin, vmax=vmax,
                   cmap=plt.get_cmap(cmap_name))
    plt.colorbar(im)
    text = ax.text(.5, 1.05, '', transform=ax.transAxes, va='center')

    if time is None:
        # len() matches the frame count handed to FuncAnimation below and,
        # unlike `.shape[0]`, also works for plain lists of frames.
        time = np.arange(len(frames))

    def init():
        text.set_text("")
        im.set_data(to_np(frames[0]))
        return im, text

    def animate(i):
        text.set_text(str(time[i]))
        im.set_data(to_np(frames[i]))
        return im, text

    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=len(frames), interval=20, blit=True)
    return anim
Пример #10
0
def _extended_states(inputs, states):
    """Build the extended state matrix ``[1, inputs, states]``, transposed.

    A column of ones is prepended to the inputs and states along axis 1; the
    result is transposed so that there is one column per row of the inputs.
    """
    bias = np.ones([inputs.shape[0], 1], dtype=inputs.dtype)
    columns = [bias, to_np(inputs), to_np(states)]
    return np.concatenate(columns, axis=1).T
Пример #11
0
 def animate(i):
     """Show frame ``i`` and its time label; return the updated artists."""
     label = str(time[i])
     text.set_text(label)
     frame = to_np(frames[i])
     im.set_data(frame)
     return (im, text)
Пример #12
0
def plot_iteration(model, idx, inp, state):
    """Plot the intermediate quantities of one ESN cell update.

    Shows the input image, the state, every input-map output, the mapped
    state, the pre-activation sum and the new state, each as an image with
    its own colorbar, in a figure titled with ``idx``.

    Raises ``ValueError`` when there is no grid layout for the resulting
    number of panels.
    """
    cell = model.esn_cell
    new_state = to_np(cell.forward(inp, state))
    input_stack = to_np(cell.input_map(inp))
    x_input = to_np(cell.cat_input_map(input_stack))
    x_state = to_np(cell.state_map(state))

    def vec_to_rect(vec):
        # Zero-pad a vector to the next square length and reshape it into a
        # square image.
        length = vec.shape[0]
        side = int(np.ceil(length**.5))
        padding = np.zeros(side * side - length)
        return np.concatenate([vec, padding], axis=0).reshape((side, side))

    def show(axis, data, title):
        # Draw one panel: image, title and colorbar.
        image = axis.imshow(data)
        axis.set_title(title)
        plt.colorbar(image, ax=axis)

    # Grid layout (rows, columns) for each supported number of panels.
    nr_plots_to_dims = {
        6: (2, 3), 7: (2, 4), 8: (2, 4),
        9: (3, 3), 10: (2, 5), 11: (3, 4),
        12: (3, 4), 13: (3, 5), 14: (3, 5),
        15: (3, 5), 16: (4, 4), 17: (3, 6),
        18: (3, 6),
    }

    # Five fixed panels plus one per input map.
    n_maps = len(input_stack)
    nr_plots = n_maps + 5
    if nr_plots not in nr_plots_to_dims:
        raise ValueError("Too many input_map_specs to plot")

    height, width = nr_plots_to_dims[nr_plots]
    fig, ax = plt.subplots(height, width, figsize=(10, 10))
    ax = ax.flatten() if isinstance(ax, np.ndarray) else [ax]

    show(ax[0], inp, "image")
    show(ax[1], vec_to_rect(state), "state")

    for i in range(n_maps):
        x = input_stack[i]
        spec = cell.input_map_specs[i]
        axi = ax[i + 2]

        if spec["type"] == "random_weights":
            arr = vec_to_rect(x)
        else:
            arr = x.reshape(spec["dbg_size"])
        show(axi, arr, f"Win(image)_{spec['type']} spec: {i}")

    # The three summary panels always occupy the last axes of the grid.
    show(ax[-3], vec_to_rect(x_state), "W(state)")
    show(ax[-2], vec_to_rect(x_state + x_input), "W(state) + Win(image)")
    show(ax[-1], vec_to_rect(new_state), "tanh(W(state) + Win(image))")

    fig.suptitle(f"Iteration {idx}")
    plt.show()
0
 def animate(i):
     """Show frame ``i`` of both series and stamp both labels with ``time[i]``."""
     for label in (t1, t2):
         label.set_text(str(time[i]))
     for image, series in ((im1, frames1), (im2, frames2)):
         image.set_data(to_np(series[i]))
     return im1, im2, t1, t2
Пример #14
0
 def init():
     """Reset the animation: blank both labels and show frame 0 of both series."""
     for label in (t1, t2):
         label.set_text("")
     for image, series in ((im1, frames1), (im2, frames2)):
         image.set_data(to_np(series[0]))
     return im1, im2, t1, t2
Пример #15
0
def animate_double_imshow(frames1, frames2,
                          time=None, vmin=None, vmax=None,
                          cmap_name="inferno", figsize=(12, 4), title=None, labels=None):
    """Animate two frame sequences side by side.

    Both sequences share ``vmin``, ``vmax`` and the colormap. ``title`` is
    used as figure title, ``labels`` as the two axes titles, and ``time``
    (default: the frame index) is shown in the corner of each panel.

    Returns the ``FuncAnimation`` object. Calling this function replaces
    ``matplotlib.animation.Animation._blit_draw`` globally.
    """
    def _blit_draw(self, artists, bg_cache):
        # Blitted drawing that caches and blits the bounding box of the
        # whole figure instead of the individual axes.
        drawn_axes = []
        for artist in artists:
            axes = artist.axes
            # Cache the background the first time an axes is seen.
            if axes not in bg_cache:
                bg_cache[axes] = artist.figure.canvas.copy_from_bbox(axes.figure.bbox)
            axes.draw_artist(artist)
            drawn_axes.append(axes)
        # Blit once per distinct axes after all artists are rendered.
        for axes in set(drawn_axes):
            axes.figure.canvas.blit(axes.figure.bbox)
    matplotlib.animation.Animation._blit_draw = _blit_draw

    fig, ax = plt.subplots(1, 2, figsize=figsize)
    left, right = ax[0], ax[1]
    if title is not None:
        fig.suptitle(title)
    if labels is not None:
        left.set_title(labels[0])
        right.set_title(labels[1])

    im1 = left.imshow(
        to_np(frames1[0]), animated=True, vmin=vmin, vmax=vmax,
        cmap=plt.get_cmap(cmap_name))
    im2 = right.imshow(
        to_np(frames2[0]), animated=True, vmin=vmin, vmax=vmax,
        cmap=plt.get_cmap(cmap_name))

    for image, axes in ((im1, left), (im2, right)):
        plt.colorbar(image, ax=axes, fraction=0.046, pad=0.04)

    t1 = left.text(0.9, 0.05, '', transform=left.transAxes, va='center')
    t2 = right.text(0.9, 0.05, '', transform=right.transAxes, va='center')

    if time is None:
        time = np.arange(len(frames1))

    def init():
        # Blank labels, first frame of each series.
        for label in (t1, t2):
            label.set_text("")
        for image, series in ((im1, frames1), (im2, frames2)):
            image.set_data(to_np(series[0]))
        return im1, im2, t1, t2

    def animate(i):
        # Time stamp on both panels, then frame i of each series.
        for label in (t1, t2):
            label.set_text(str(time[i]))
        for image, series in ((im1, frames1), (im2, frames2)):
            image.set_data(to_np(series[i]))
        return im1, im2, t1, t2

    return animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=len(frames1), interval=20, blit=True)
Пример #16
0
 def init():
     """Reset the animation: blank the label and show the first frame."""
     text.set_text("")
     first_frame = to_np(frames[0])
     im.set_data(first_frame)
     return (im, text)
Пример #17
0
def animate_quad_imshow(frames1, frames2, frames3, frames4,
                        time=None, vmin=None, vmax=None,
                        cmap_name="inferno", figsize=(6, 6), title=None,
                        axes_labels=None):
    """Animate four frame sequences in a 2x2 grid.

    Parameters
    ----------
    frames1, frames2, frames3, frames4 : sequences of 2D arrays (anything
        supporting ``len`` and indexing); ``frames1`` sets the frame count.
    time : sequence, optional
        Label shown above the first panel for each frame; defaults to the
        frame index.
    vmin, vmax, cmap_name : forwarded to every ``imshow`` call.
    figsize : figure size passed to ``plt.subplots``.
    title : str, optional
        Figure title.
    axes_labels : sequence of str, optional
        x-labels for the panels, in row-major order.

    Returns
    -------
    The ``FuncAnimation`` object. Note that calling this function replaces
    ``matplotlib.animation.Animation._blit_draw`` globally.
    """
    def _blit_draw(self, artists, bg_cache):
        # Handles blitted drawing, which renders only the artists given instead
        # of the entire figure.
        updated_ax = []
        for a in artists:
            # If we haven't cached the background for this axes object, do
            # so now. This might not always be reliable, but it's an attempt
            # to automate the process.
            if a.axes not in bg_cache:
                # Cache the whole figure rather than just the axes bbox.
                bg_cache[a.axes] = a.figure.canvas.copy_from_bbox(a.axes.figure.bbox)
            a.axes.draw_artist(a)
            updated_ax.append(a.axes)
        # After rendering all the needed artists, blit once per axes, again
        # using the figure bbox instead of the axes bbox.
        for ax in set(updated_ax):
            ax.figure.canvas.blit(ax.figure.bbox)
    matplotlib.animation.Animation._blit_draw = _blit_draw
    fig, ax = plt.subplots(2, 2, figsize=figsize)
    ax = ax.flatten()
    if title is not None:
        fig.suptitle(title)

    im1 = ax[0].imshow(
        to_np(frames1[0]), animated=True, vmin=vmin, vmax=vmax,
        cmap=plt.get_cmap(cmap_name))
    im2 = ax[1].imshow(
        to_np(frames2[0]), animated=True, vmin=vmin, vmax=vmax,
        cmap=plt.get_cmap(cmap_name))
    im3 = ax[2].imshow(
        to_np(frames3[0]), animated=True, vmin=vmin, vmax=vmax,
        cmap=plt.get_cmap(cmap_name))
    im4 = ax[3].imshow(
        to_np(frames4[0]), animated=True, vmin=vmin, vmax=vmax,
        cmap=plt.get_cmap(cmap_name))

    plt.colorbar(im1, ax=ax[0], fraction=0.046, pad=0.04)
    plt.colorbar(im2, ax=ax[1], fraction=0.046, pad=0.04)
    plt.colorbar(im3, ax=ax[2], fraction=0.046, pad=0.04)
    plt.colorbar(im4, ax=ax[3], fraction=0.046, pad=0.04)
    text = ax[0].text(.5, 1.05, '', transform=ax[0].transAxes, va='center')
    if axes_labels is not None:
        for lbl, a in zip(axes_labels, ax):
            a.set_xlabel(lbl)
    plt.tight_layout()

    if time is None:
        # len() matches the frame count handed to FuncAnimation below and,
        # unlike `.shape[0]`, also works for plain lists of frames.
        time = np.arange(len(frames1))

    def init():
        text.set_text("")
        im1.set_data(to_np(frames1[0]))
        im2.set_data(to_np(frames2[0]))
        im3.set_data(to_np(frames3[0]))
        im4.set_data(to_np(frames4[0]))
        return im1, im2, im3, im4, text

    def animate(i):
        text.set_text(str(time[i]))
        im1.set_data(to_np(frames1[i]))
        im2.set_data(to_np(frames2[i]))
        im3.set_data(to_np(frames3[i]))
        im4.set_data(to_np(frames4[i]))
        # With blit=True only the returned artists are redrawn; im4 was
        # previously missing, leaving the fourth panel stale.
        return im1, im2, im3, im4, text

    anim = animation.FuncAnimation(
        fig, animate, init_func=init, frames=len(frames1),
        interval=20, blit=True)
    return anim