def __call__(self, sample):
        train, truth = sample['train'], sample['truth']
        # note: this flips with probability (1 - self.probability);
        # use "<" to flip with probability self.probability
        if np.random.rand() >= self.probability:
            train = torch.fliplr(train)
            truth = torch.fliplr(truth)

        return {'train': train, 'truth': truth}
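A caveat that applies to this and several of the examples below: torch.fliplr always reverses dimension 1, so it is only a genuine left-right flip for (H, W) tensors. A minimal sketch of the difference across common image layouts (the tensors are illustrative):

import torch

hw = torch.arange(6).reshape(2, 3)         # (H, W)
chw = torch.arange(24).reshape(2, 3, 4)    # (C, H, W)

torch.fliplr(hw)             # reverses W: a true horizontal flip
torch.fliplr(chw)            # reverses H (dim 1), not the width
torch.flip(chw, dims=[-1])   # reverses the last dim for any layout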
Example 2
def augment_data(x, y, n=None):
    """
    Generate an augmented training dataset with random reflections
    and 90 degree rotations
    x, y : Image sets of shape (Samples, Width, Height, Channels)
        training images and next images
    n : number of augmented examples to generate (defaults to the size
        of the input set)
    """
    n_data = x.shape[0]

    if not n:
        n = n_data
    x_out, y_out = list(), list()

    for i in range(n):
        r = random.randint(0, n_data - 1)  # randint is inclusive on both ends
        x_r, y_r = x[r], y[r]

        if random.random() < 0.5:
            x_r = torch.fliplr(x_r)
            y_r = torch.fliplr(y_r)
        if random.random() < 0.5:
            x_r = torch.flipud(x_r)
            y_r = torch.flipud(y_r)

        num_rots = random.randint(0, 3)  # 0-3 quarter turns; k=4 repeats k=0
        x_r = torch.rot90(x_r, k=num_rots)
        y_r = torch.rot90(y_r, k=num_rots)

        x_out.append(x_r)
        y_out.append(y_r)
    return torch.stack(x_out), torch.stack(y_out)
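A quick usage sketch for the function above, assuming x and y are channel-last float tensors (the shapes are illustrative):

import torch

x = torch.rand(10, 32, 32, 3)    # (Samples, Width, Height, Channels)
y = torch.rand(10, 32, 32, 3)
x_aug, y_aug = augment_data(x, y, n=100)
print(x_aug.shape)               # torch.Size([100, 32, 32, 3])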
Example 3
 def test_step(self, batch, batch_idx):
     # OPTIONAL
     data = batch
     t1 = time.time()
     if self.hparams.self_ensemble:
         image = data['input']
         B, C, H, W = image.shape
         new_image = torch.zeros((B * 4, C, H, W),
                                 dtype=image.dtype,
                                 device=image.device)
         new_image[0:B] = image
         new_image[B:B * 2] = torch.flip(image, [3])  # horizontal flip; fliplr would reverse dim 1 (channels) on NCHW
         image = torch.rot90(image, 2, [2, 3])
         new_image[B * 2:B * 3] = image
         new_image[B * 3:B * 4] = torch.flip(image, [3])
         new_image2 = torch.rot90(new_image, 1, [2, 3])
         data['input'] = new_image
         out1 = self.forward(data)['image']
         data['input'] = new_image2
         out2 = self.forward(data)['image']
         out2 = torch.rot90(out2, 3, [2, 3])
         tempout = (out1 + out2) / 2
         tempout[B:B * 2] = torch.flip(tempout[B:B * 2], [3])
         tempout[B * 2:B * 3] = torch.rot90(tempout[B * 2:B * 3], 2, [2, 3])
         tempout[B * 3:B * 4] = torch.rot90(
             torch.flip(tempout[B * 3:B * 4], [3]), 2, [2, 3])
         for i in range(B):
             image[i] = torch.mean(tempout[i::B], 0)
         out = image
     else:
         out = self.forward(data)['image']
     t2 = time.time()
     return {'image': out, 'name': data['name'], 't': t2 - t1}
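The closing loop folds the four augmented predictions back into one prediction per sample by striding through the stacked batch: tempout[i::B] selects every copy of sample i. In miniature:

import torch

B = 2
stacked = torch.arange(8.)   # four stacked copies of a batch of B = 2
print(stacked[0::B])         # tensor([0., 2., 4., 6.]) -- every copy of sample 0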
Example 4
    def data_augmentation(image, mask):
        image = torch.Tensor(image)
        mask = torch.Tensor(mask)
        mask = mask.unsqueeze(0)

        if random.random() < 0.5:
            # flip left-right: reverse the width dim; fliplr on a
            # (C, H, W) tensor would reverse dim 1 (height) instead
            image = torch.flip(image, [2])
            mask = torch.flip(mask, [2])

        rot = np.random.choice([0, 1, 2, 3])
        image = torch.rot90(image, rot, [1, 2])
        mask = torch.rot90(mask, rot, [1, 2])

        if random.random() < 0.5:
            # flip up-down: reverse the height dim; flipud on a
            # (C, H, W) tensor would reverse dim 0 (channels) instead
            image = torch.flip(image, [1])
            mask = torch.flip(mask, [1])

        if intensity >= 1:  # 'intensity' comes from the enclosing scope

            # random crop
            cropsize = image.shape[2] // 2
            image, mask = random_crop(image, mask, cropsize=cropsize)

            std_noise = image.std()
            if random.random() < 0.5:
                # add noise per pixel and per channel
                pixel_noise = torch.rand(image.shape[1], image.shape[2])
                pixel_noise = torch.repeat_interleave(pixel_noise.unsqueeze(0),
                                                      image.size(0),
                                                      dim=0)
                image = image + pixel_noise * std_noise

            if random.random() < 0.5:
                channel_noise = torch.rand(
                    image.shape[0]).unsqueeze(1).unsqueeze(2)
                channel_noise = torch.repeat_interleave(
                    torch.repeat_interleave(channel_noise, image.shape[1], 1),
                    image.shape[2], 2)
                image = image + channel_noise * std_noise

            if random.random() < 0.5:
                # add noise
                noise = torch.rand(image.shape[0], image.shape[1],
                                   image.shape[2]) * std_noise
                image = image + noise

        if intensity >= 2:
            # channel shuffle
            if random.random() < 0.5:
                idxs = np.arange(image.shape[0])
                np.random.shuffle(idxs)  # random band indices
                image = image[idxs]

        mask = mask.squeeze(0)
        return image, mask
Example 5
 def __call__(self, img, mask):
     if random.random() < self.p:
         return (torch.fliplr(img), torch.fliplr(mask))
     return img, mask
Example 6
 def __call__(self, img, mask):
     if random.random() < self.p:
         img = torch.fliplr(img)
         mask = torch.fliplr(mask)
     return img, mask
Example 7
    def to_weights(self, coeffs, ind, zero_weights, linear1, linear2):

        zero_weights_ = zero_weights.clone()
        # scatter the coefficients into the zero template, then reverse its columns
        weights = torch.fliplr(zero_weights_.index_put_(tuple(ind), coeffs))
        weights = linear1(weights)
        weights = linear2(weights.transpose(-1, -2))
        return weights.transpose(-1, -2)
Example 8
def prepare_first_frame(curr_video,
                        save_prediction,
                        annotation,
                        sigma1=8,
                        sigma2=21,
                        inference_strategy='single',
                        probability_propagation=False,
                        scale=None):
    first_annotation = Image.open(annotation)
    (H, W) = np.asarray(first_annotation).shape
    H_d = int(np.ceil(H * Config.SCALE))
    W_d = int(np.ceil(W * Config.SCALE))
    palette = first_annotation.getpalette()
    label = np.asarray(first_annotation)
    d = np.max(label) + 1
    label = torch.Tensor(label).long().to(Config.DEVICE)  # (H, W)
    label_1hot = get_labels(label, d, H, W, H_d, W_d)

    weight_dense = get_spatial_weight((H_d, W_d), sigma1) if not probability_propagation else None
    weight_sparse = get_spatial_weight((H_d, W_d), sigma2) if not probability_propagation else None

    if save_prediction is not None:
        if not os.path.exists(save_prediction):
            os.makedirs(save_prediction)
        save_path = os.path.join(save_prediction, curr_video)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        first_annotation.save(os.path.join(save_path, '00000.png'))

    if inference_strategy == 'single':
        return label_1hot, d, palette, weight_dense, weight_sparse
    elif inference_strategy == 'hor-flip':
        label_1hot_flipped = get_labels(torch.fliplr(label), d, H, W, H_d, W_d)
        return label_1hot, label_1hot_flipped, d, palette, weight_dense, weight_sparse
    elif inference_strategy == 'ver-flip':
        label_1hot_flipped = get_labels(torch.flipud(label), d, H, W, H_d, W_d)
        return label_1hot, label_1hot_flipped, d, palette, weight_dense, weight_sparse
    elif inference_strategy == '2-scale' or inference_strategy == 'hor-2-scale':
        H_d_2 = int(np.ceil(H * Config.SCALE * scale))
        W_d_2 = int(np.ceil(W * Config.SCALE * scale))
        weight_dense_2 = get_spatial_weight((H_d_2, W_d_2), sigma1) if not probability_propagation else None
        weight_sparse_2 = get_spatial_weight((H_d_2, W_d_2), sigma2) if not probability_propagation else None
        label_1hot_2 = get_labels(label, d, H, W, H_d_2, W_d_2)
        return (label_1hot, label_1hot_2), d, palette, (weight_dense, weight_dense_2), (weight_sparse, weight_sparse_2)
    elif inference_strategy == 'multimodel':
        # that's right, do nothing
        pass
    elif inference_strategy == '3-scale':
        del weight_dense, weight_sparse
        H_d = int(np.ceil(H * Config.SCALE * scale))
        W_d = int(np.ceil(W * Config.SCALE * scale))
        weight_dense = get_spatial_weight((H_d, W_d), sigma1) if not probability_propagation else None
        weight_sparse = get_spatial_weight((H_d, W_d), sigma2) if not probability_propagation else None
        label_1hot = get_labels(label, d, H, W, H_d, W_d)
        return label_1hot, d, palette, weight_dense, weight_sparse

    return label_1hot, d, palette, weight_dense, weight_sparse
Example 9
    def get_dct_init(self, len_coeffs, dim_out, dim_in, diag_shift):

        factor = 1.
        init = torch.rand([dim_out, dim_in])
        if self.cuda:  # TODO update to device.
            init = init.cuda()

        initrange = 1.0 / math.sqrt(dim_out)
        nn.init.uniform_(init, -initrange, initrange)
        init_f = torch.fliplr(dct.dct_2d(init, norm='ortho'))  # 2-D DCT with columns reversed
        ind = torch.triu_indices(dim_out, dim_in, diag_shift)
        coeffs_init = init_f[tuple(ind)] * factor

        return coeffs_init
Example 10
    def compute_gradients(self, samples, x=None):
        """
        Compute the gradients of log q(x), given samples~q(x);
        If x=None, compute the gradients of log q(samples).
        """
        if x is None:
            kernel_width = self.heuristic_kernel_width(samples, samples)
            x = samples
        else:
            # _samples: [..., N + M, x_dim]
            _samples = torch.cat((samples, x), dim=-2)
            kernel_width = self.heuristic_kernel_width(_samples, _samples)

        M = samples.shape[0]
        # Kq: [..., M, M]
        # grad_K1: [..., M, M, x_dim]
        # grad_K2: [..., M, M, x_dim]
        Kq, grad_K1, grad_K2 = self.grad_gram(samples, samples, kernel_width)
        if self._eta is not None:
            Kq += self._eta * torch.eye(M).to(samples.device)
        # eigen_vectors: [..., M, M]
        # eigen_values: [..., M]
        eigen_values, eigen_vectors = torch.linalg.eigh(Kq)  # torch.symeig was removed in PyTorch 2.x

        if (self._n_eigen is None) and (self._n_eigen_threshold is not None):
            eigen_arr = torch.mean(torch.fliplr(eigen_values.view(1, -1)), axis=0)  # descending order
            eigen_arr /= torch.sum(eigen_arr)
            eigen_cum = torch.cumsum(eigen_arr, dim=0)
            self._n_eigen = torch.sum(torch.lt(eigen_cum, self._n_eigen_threshold))

        if self._n_eigen is not None:

            eigen_values = eigen_values[..., -self._n_eigen:]
            eigen_vectors = eigen_vectors[..., -self._n_eigen:]

        eigen_ext = self.nystrom_ext(samples, x, eigen_vectors, eigen_values, kernel_width)

        grad_K1_avg = torch.mean(grad_K1, dim=-3)
        beta = -torch.sqrt(torch.tensor(M).float()) * torch.matmul(torch.transpose(eigen_vectors, 0, 1),
                                                                   grad_K1_avg) / eigen_values.view(self._n_eigen, 1)
        grads = torch.matmul(eigen_ext, beta)

        return grads
Example 11
    def append(self, self_play_data: SelfPlayData):
        """ 向数据集中插入数据 """
        n = self.board_len
        z_list = Tensor(self_play_data.z_list)
        pi_list = self_play_data.pi_list
        feature_planes_list = self_play_data.feature_planes_list
        # Augment the dataset with rotations and mirror flips
        for z, pi, feature_planes in zip(z_list, pi_list, feature_planes_list):
            for i in range(4):
                # rotate counterclockwise by i*90°
                rot_features = torch.rot90(Tensor(feature_planes), i, (1, 2))
                rot_pi = torch.rot90(Tensor(pi.reshape(n, n)), i)
                self.__data_deque.append(
                    (rot_features, rot_pi.flatten(), z))

                # horizontally flip the rotated arrays
                flip_features = torch.flip(rot_features, [2])
                flip_pi = torch.fliplr(rot_pi)
                self.__data_deque.append(
                    (flip_features, flip_pi.flatten(), z))
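For a square board, the loop above enumerates the full dihedral group of eight symmetries (four rotations, each optionally mirrored). The same idea on a bare (n, n) tensor:

import torch

pi = torch.arange(9.).reshape(3, 3)
symmetries = []
for i in range(4):
    rot = torch.rot90(pi, i)               # counterclockwise i*90°
    symmetries.append(rot)
    symmetries.append(torch.fliplr(rot))   # mirrored copy
assert len(symmetries) == 8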
Example 12
        def collate_fn(items):
            '''
            Pads the context from the left and the response from the right.
            Results in batches with (contexts, responses), with batch-first.
            '''
            assert padding_value is not None, "Please provide pad_token_id to use for padding"
            inputs = torch.fliplr(
                pad_sequence(
                    [
                        torch.tensor(list(reversed(sample[0])))
                        for sample in items
                    ],
                    batch_first=True,
                    padding_value=padding_value  # Padding code
                ))
            targets = pad_sequence(
                [torch.tensor(sample[1]) for sample in items],
                batch_first=True,
                padding_value=padding_value  # Padding code
            )

            return inputs, targets
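The left-padding trick used here is: reverse each context, right-pad as usual with pad_sequence, then fliplr the resulting (batch, seq) matrix so the padding lands on the left. A self-contained sketch:

import torch
from torch.nn.utils.rnn import pad_sequence

contexts = [[5, 6, 7], [8, 9]]
reversed_ctx = [torch.tensor(list(reversed(c))) for c in contexts]
padded = pad_sequence(reversed_ctx, batch_first=True, padding_value=0)
print(torch.fliplr(padded))   # tensor([[5, 6, 7], [0, 8, 9]])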
Example 13
def polyval(p, x):
    y = torch.zeros_like(x)
    p = torch.fliplr(p.view(1, p.shape[0])).flatten()  # reverse coefficients so p[0] is the highest degree (Horner's scheme)
    for i in range(len(p)):
        y = y * x + p[i]
    return y
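A quick check of the evaluation order, assuming coefficients are passed lowest degree first (which is what the flip implies):

import torch

p = torch.tensor([1., 2., 3.])   # 1 + 2x + 3x^2, lowest degree first
x = torch.tensor([0., 1., 2.])
print(polyval(p, x))             # tensor([ 1.,  6., 17.])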
Example 14
    def predict(self, path, predpath):
        with rasterio.open(path, "r") as src:
            meta = src.meta
        self.model.eval()

        # for prediction
        predimage = os.path.join(predpath, os.path.basename(path))
        os.makedirs(predpath, exist_ok=True)
        meta["count"] = 1
        meta["dtype"] = "uint8" # storing as uint8 saves a lot of storage space

        #Window(col_off, row_off, width, height)
        H, W = self.image_size

        rows = np.arange(0, meta["height"], H)
        cols = np.arange(0, meta["width"], W)

        image_window = Window(0, 0, meta["width"], meta["height"])

        with rasterio.open(predimage, "w+", **meta) as dst:

            for r, c in tqdm(product(rows, cols), total=len(rows) * len(cols), leave=False):

                window = image_window.intersection(
                    Window(c-self.offset, r-self.offset, W+self.offset, H+self.offset))

                with rasterio.open(path) as src:
                    image = src.read(window=window)

                # if L1C image (13 bands). read only the 12 bands compatible with L2A data
                if (image.shape[0] == 13):
                    image = image[[l1cbands.index(b) for b in l2abands]]

                # to torch + normalize
                image = self.transform(torch.from_numpy(image.astype(np.float32)), [])[0].to(self.device)

                # predict
                with torch.no_grad():
                    x = image.unsqueeze(0)
                    y_logits = torch.sigmoid(self.model(x).squeeze(0))
                    if self.use_test_aug > 0:
                        y_logits += torch.sigmoid(torch.flip(self.model(torch.flip(x, [-1])), [-1])).squeeze(0)  # horizontal-flip TTA (fliplr on NCHW would hit the channel dim)
                        y_logits += torch.sigmoid(torch.flip(self.model(torch.flip(x, [-2])), [-2])).squeeze(0)  # vertical-flip TTA
                        if self.use_test_aug > 1:
                            for rot in [1, 2, 3]: # 90, 180, 270 degrees
                                y_logits += torch.sigmoid(torch.rot90(self.model(torch.rot90(x, rot, [2, 3])),-rot,[2,3]).squeeze(0))
                            y_logits /= 6
                        else:
                            y_logits /= 3

                    y_score = y_logits.cpu().detach().numpy()[0]
                    #y_score = y_score[:,self.offset:-self.offset, self.offset:-self.offset]

                data = dst.read(window=window)[0] / 255
                overlap = data > 0

                if overlap.any():
                    # smooth transition in overlapping regions
                    dx, dy = np.gradient(overlap.astype(float)) # get border
                    g = np.abs(dx) + np.abs(dy)
                    transition = gaussian_filter(g, sigma=self.offset / 2)
                    transition /= transition.max()
                    transition[~overlap] = 1.  # normalize to 1

                    y_score = transition * y_score + (1-transition) * data

                # write
                writedata = (np.expand_dims(y_score, 0).astype(np.float32) * 255).astype(np.uint8)
                dst.write(writedata, window=window)
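The test-time augmentation above averages sigmoid outputs over the identity, two flips, and optionally three rotations (six passes in total). The flip part in isolation, as a sketch assuming a model that maps (B, C, H, W) to (B, 1, H, W):

import torch

def flip_tta(model, x):
    # average predictions over identity + horizontal + vertical flips
    y = torch.sigmoid(model(x))
    y = y + torch.flip(torch.sigmoid(model(torch.flip(x, [-1]))), [-1])
    y = y + torch.flip(torch.sigmoid(model(torch.flip(x, [-2]))), [-2])
    return y / 3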
Example 15
    print(torch.maximum(a, b))
    print(torch.minimum(a, b))
    print(torch.fmod(a, 2))
    print(torch.dist(c, d, 1))  # p-norm
    print(torch.norm(c))
    print(torch.div(c, d))
    print(torch.true_divide(c, d))  # rounding_mode=None
    print(torch.sub(c, d, alpha=2.))
    print(c.add(d))
    print(torch.dot(c, d))
    print(torch.sigmoid(c))
    # print(torch.inner(c, d))
    """flip"""
    x = torch.arange(4).view(2, 2)
    print(torch.flipud(x))
    print(torch.fliplr(x))

    # logical
    print("logical function:")
    print(torch.eq(c, d))
    print(torch.ne(c, d))
    print(torch.gt(c, d))
    print(torch.logical_and(c, d))
    print(torch.logical_or(c, d))
    print(torch.logical_xor(c, d))
    print(torch.logical_not(c))
    print(torch.equal(c, d))  # if all equal
    a = torch.rand(2, 2) > 0.5  # random boolean tensor (.bool() on rand() is almost always all-True)
    print(a)
    print(torch.all(a))
    print(torch.all(a, dim=0))  # column-wise (reduce over dim 0)
Example 16
 def __call__(self, latent):
     ret = latent
     if torch.rand(1) < self.p:
         ret = torch.fliplr(latent)
     return ret
Example 17
def eprop(cfg, X, T, betas, W, grad=True):

    start = time.time()

    X, T = interpolate_inputs(cfg=cfg, inp=X, tar=T, stretch=cfg["Repeats"])

    X, T = trim_samples(X=X, T=T)

    n_steps = count_lengths(X=X)

    # Trim input down to layer size.
    X = X[:, :, :cfg["N_R"]]
    inp_size = X.shape[-1]

    X = np.pad(X, ((0, 0), (0, 0), (0, cfg["N_R"] - inp_size)))

    X = torch.tensor(X)
    T = torch.tensor(T)
    n_steps = torch.tensor(n_steps)

    M = initialize_model(cfg=cfg,
                         inp_size=inp_size,
                         tar_size=T.shape[-1],
                         batch_size=X.shape[0],
                         n_steps=max(n_steps))

    M['x'] = X[None, :]  # Insert subnetwork dimension
    M['t'] = T

    if cfg["n_directions"] == 2:
        M['x'] = torch.cat((M['x'], M['x']))
        for b in range(M['x'].shape[1]):
            M['x'][1, b, :n_steps[b]] = torch.flipud(M['x'][0, b, :n_steps[b]])  # reverse the copy along time (dim 0); fliplr would reverse the feature dim

    for t in np.arange(0, max(n_steps.cpu().numpy())):
        prev_syn_t, curr_syn_t = conn_t_idxs(
            track_synapse=cfg['Track_synapse'], t=t)
        prev_nrn_t, curr_nrn_t = conn_t_idxs(track_synapse=cfg['Track_neuron'],
                                             t=t)
        is_valid = torch.logical_not(torch.any(M['x'][0, :, t] == -1, axis=1))

        for r in range(
                cfg["N_Rec"]
        ):  # TODO: Can overwrite r instead of appending, except Z
            # TODO: See which is_valid masks can be removed

            if grad and cfg["neuron"] in ["ALIF", "STDP-ALIF"]:
                M['va'][:, :, curr_syn_t,
                        r] = (M['h'][:, :, prev_nrn_t, r, :, None] *
                              M['vv'][:, :, prev_syn_t, r] +
                              (cfg["rho"] -
                               (M['h'][:, :, prev_nrn_t, r, :, None] *
                                betas[:, None, r, :, None])) *
                              M['va'][:, :, prev_syn_t, r])
            elif grad and cfg["neuron"] == "Izhikevich":
                oldva = M['va'][:, :, prev_syn_t, r]
                M['va'][:, :, curr_syn_t, r] = (
                    cfg["IzhA1"] * (1 - M['z'][:, :, prev_nrn_t, r, :, None]) *
                    M['vv'][:, :, prev_syn_t, r] +
                    (1 + cfg["IzhA2"]) * M['va'][:, :, prev_syn_t, r])

            Z_prev_layer = (M['z'][:, :, curr_nrn_t,
                                   r - 1] if r else M['x'][:, :, t])
            Z_prev_time = M['z'][:, :, prev_nrn_t, r]

            # Z_in is incoming at current t, so from last t.
            M['z_in'][:, :, curr_nrn_t, r] = torch.cat(
                (Z_prev_layer, Z_prev_time), axis=2)

            if grad and cfg["neuron"] == "ALIF":
                M['vv'][:, :, curr_syn_t,
                        r] = (cfg["alpha"] * M['vv'][:, :, prev_syn_t, r] +
                              M['z_in'][:, :, curr_nrn_t, r, None, :])
            elif grad and cfg["neuron"] == "STDP-ALIF":
                M['vv'][:, :, curr_syn_t, r] = (
                    cfg["alpha"] * M['vv'][:, :, prev_syn_t, r] *
                    (1 - M['z'][:, :, prev_nrn_t, r] -
                     torch.where(M['tz'][:, :, r] == t - 1 - cfg["dt_refr"], 1,
                                 0))[..., None] +
                    M['z_in'][:, :, curr_nrn_t, r, None, :])
            elif grad and cfg["neuron"] == "Izhikevich":
                M['vv'][:, :, curr_syn_t, r] = (
                    (1 - M['z'][:, :, prev_nrn_t, r, :, None]) *
                    0.9  # TODO: CORRECTION FACTOR?
                    *
                    (1 +
                     (2 * cfg["IzhV1"] * M['v'][:, :, prev_nrn_t, r, :, None] +
                      cfg["IzhV2"])) * M['vv'][:, :, prev_syn_t, r] - oldva +
                    M['z_in'][:, :, curr_nrn_t, r, None, :])
                M['va'][:, :, curr_syn_t,
                        r] = torch.clip(M['vv'][:, :, curr_syn_t, r], -0.005,
                                        0.005)
                M['vv'][:, :, curr_syn_t,
                        r] = torch.clip(M['va'][:, :, curr_syn_t, r], -3., 3.)

            if r == 0:
                M['I_in'][:, :, curr_nrn_t,
                          r] = torch.sum(W['W_in'][:, None, r, :, :inp_size] *
                                         Z_prev_layer[:, :, None, :inp_size],
                                         axis=-1)

            else:
                M['I_in'][:, :, curr_nrn_t, r] = torch.sum(
                    W['W_in'][:, None, r] * Z_prev_layer[:, :, None, :],
                    axis=-1)

            M['I_rec'][:, :, curr_nrn_t, r] = torch.sum(
                W['W_rec'][:, None, r] * Z_prev_time[:, :, None, :], axis=-1)

            M['I'][:, :, curr_nrn_t,
                   r] = M['I_in'][:, :, curr_nrn_t,
                                  r] + M['I_rec'][:, :, curr_nrn_t, r]

            if cfg["neuron"] in ["ALIF", "STDP-ALIF"]:
                M['a'][:, :, curr_nrn_t,
                       r] = (cfg["rho"] * M['a'][:, :, prev_nrn_t, r] +
                             M['z'][:, :, prev_nrn_t, r])

            elif cfg["neuron"] == "Izhikevich":
                at = M['a'][:, :, prev_nrn_t,
                            r] + 2 * M['z'][:, :, prev_nrn_t, r]
                M['a'][:, :, curr_nrn_t,
                       r] = (at + cfg["IzhA1"] *
                             (M['v'][:, :, prev_nrn_t, r] -
                              (M['v'][:, :, prev_nrn_t, r] + cfg["IzhThr"]) *
                              M['z'][:, :, prev_nrn_t, r]) + cfg["IzhA2"] * at)

            A = cfg["thr"] + betas[:, None, r] * M['a'][:, :, curr_nrn_t, r]

            if cfg["neuron"] == "ALIF":
                M['v'][:, :, curr_nrn_t,
                       r] = ((cfg["alpha"] * M['v'][:, :, prev_nrn_t, r] +
                              M['I'][:, :, curr_nrn_t, r] -
                              M['z'][:, :, prev_nrn_t, r] *
                              (A if cfg["v_fix"] else cfg["thr"])))
            elif cfg["neuron"] == "STDP-ALIF":
                M['v'][:, :, curr_nrn_t, r] = ((
                    cfg["alpha"] * M['v'][:, :, prev_nrn_t, r] +
                    M['I'][:, :, curr_nrn_t, r] - M['z'][:, :, prev_nrn_t, r] *
                    cfg["alpha"] * M['v'][:, :, prev_nrn_t, r] -
                    cfg["alpha"] * M['v'][:, :, prev_nrn_t, r] *
                    torch.where(M['tz'][:, :, r] == t - cfg["dt_refr"], 1, 0)))
            elif cfg["neuron"] == "Izhikevich":
                vt = M['v'][:, :, prev_nrn_t,
                            r] - (M['v'][:, :, prev_nrn_t, r] - cfg["IzhReset"]
                                  ) * M['z'][:, :, prev_nrn_t, r]
                M['v'][:, :, curr_nrn_t,
                       r] = (vt + cfg["IzhV1"] * vt**2 + cfg["IzhV2"] * vt +
                             cfg["IzhV3"] - (M['a'][:, :, prev_nrn_t, r] +
                                             2 * M['z'][:, :, prev_nrn_t, r]) +
                             M['I'][:, :, curr_nrn_t, r])

            if cfg["neuron"] in ["ALIF", "STDP-ALIF"]:
                M['z'][:, :, curr_nrn_t, r] = torch.where(
                    torch.logical_and(t - M['tz'][:, :, r] > cfg["dt_refr"],
                                      M['v'][:, :, curr_nrn_t, r] >= A), 1, 0)
            elif cfg["neuron"] == "Izhikevich":
                M['z'][:, :, curr_nrn_t, r] = torch.where(
                    M['v'][:, :, curr_nrn_t, r] >= cfg["IzhThr"], 1, 0)

            M['tz'][:, :,
                    r] = torch.where(M['z'][:, :, curr_nrn_t, r] != 0,
                                     torch.ones_like(M['tz'][:, :, r]) * t,
                                     M['tz'][:, :, r])
            M['zs'][:, :,
                    r] += M['z'][:, :, curr_nrn_t, r] * is_valid[None, :, None]
            if not grad:
                continue

            if cfg["neuron"] in ['ALIF', "STDP-ALIF"]:
                M['h'][:, :, curr_nrn_t, r] = (
                    ((1 / (A if cfg["v_fix"] else cfg["thr"]))
                     if not cfg["v_fix_psi"] else 1)
                    # (1 / cfg["thr"])
                    * cfg["gamma"] * torch.clip(
                        1 - (abs(
                            (M['v'][:, :, curr_nrn_t, r] - A) / cfg["thr"])),
                        0, None))
            elif cfg["neuron"] == 'Izhikevich':
                cfg["gamma"] * torch.exp(
                    (torch.clip(M['v'][:, :, curr_nrn_t, r], None,
                                cfg["IzhThr"]) - cfg["IzhThr"]) /
                    cfg["IzhThr"])

            if cfg["neuron"] == "ALIF":
                M['h'][:, :, curr_nrn_t, r] = torch.where(
                    t - M['tz'][:, :, r] >= cfg["dt_refr"],
                    M['h'][:, :, curr_nrn_t, r],
                    torch.zeros_like(M['h'][:, :, curr_nrn_t, r]))
            elif cfg["neuron"] == "STDP-ALIF":
                M['h'][:, :, curr_nrn_t, r] = torch.where(
                    t - M['tz'][:, :, r] >= cfg["dt_refr"],
                    M['h'][:, :, curr_nrn_t, r],
                    torch.ones_like(M['h'][:, :, curr_nrn_t, r]) *
                    cfg["gamma"])

            if cfg["neuron"] in ['ALIF', "STDP-ALIF"]:
                M['et'][:, :, curr_syn_t, r] = (
                    M['h'][:, :, curr_nrn_t, r, :, None] *
                    (M['vv'][:, :, curr_syn_t, r] -
                     betas[:, None, r, :, None] * M['va'][:, :, curr_syn_t, r])
                )
            elif cfg["neuron"] == 'Izhikevich':
                M['et'][:, :, curr_syn_t,
                        r] = (M['h'][:, :, curr_nrn_t, r, :, None] *
                              M['vv'][:, :, curr_syn_t, r])

            M['etbar'][:, :, curr_syn_t,
                       r] = (cfg["kappa"] * M['etbar'][:, :, prev_syn_t, r] +
                             M['et'][:, :, curr_syn_t, r])
            M['zbar'][:, :, curr_nrn_t,
                      r] = (cfg["kappa"] * M['zbar'][:, :, prev_nrn_t, r] +
                            M['z'][:, :, curr_nrn_t, r])

        M['ysub'][:, :, curr_nrn_t] = (
            torch.sum(W['out'][:, None] * M['z'][:, :, curr_nrn_t, :, :, None],
                      axis=(-2, -3)) +
            cfg["kappa"] * M['ysub'][:, :, prev_nrn_t])

        M['y'][:, curr_nrn_t] = torch.sum(M['ysub'][:, :, curr_nrn_t],
                                          axis=0) + W['bias']

        M['p'][:,
               t] = torch.exp(M['y'][:, curr_nrn_t] -
                              torch.amax(M['y'][:, curr_nrn_t], axis=1)[:,
                                                                        None])

        M['p'][:, t] = M['p'][:, t] / torch.sum(M['p'][:, t], axis=1)[:, None]

        if not grad:  # Next steps all have to do with gradient calculation
            continue

        M['d'][:, curr_nrn_t] = (M['p'][:, t] - T[:, t])

        M['loss_pred'][:, :, curr_nrn_t] = torch.sum(
            W['B'][:, None, :, :] *
            M['d'][None, :, curr_nrn_t, None,
                   None, :],  # Checked correct (for batch size 1)
            axis=-1)

        M['GW_in'][:, curr_syn_t] += torch.mean(
            M['loss_pred'][:, :, curr_nrn_t, :, :, None] *
            M['etbar'][:, :, curr_syn_t, :, :, :cfg["N_R"]],
            axis=1)
        M['GW_rec'][:, curr_syn_t] += torch.mean(
            M['loss_pred'][:, :, curr_nrn_t, :, :, None] *
            M['etbar'][:, :, curr_syn_t, :, :, cfg["N_R"]:],
            axis=1)

        M['loss_reg'][:, curr_nrn_t] = torch.mean(
            cfg["FR_reg"]
            # * 2 * t  # Initial means are less informative
            # / n_steps[None, :, None, None] ** 2  # Square corrects previous term
            * (M['zs'] / (t + 1) - cfg["FR_target"]) *
            is_valid[None, :, None, None],
            axis=1)

        M['GW_in'][:, curr_syn_t] += M['loss_reg'][:, curr_nrn_t, :, :, None]
        M['GW_rec'][:, curr_syn_t] += M['loss_reg'][:, curr_nrn_t, :, :, None]

        M['Gout'][:, curr_syn_t] += torch.mean(
            M['d'][None, :, curr_nrn_t, None] *
            M['zbar'][:, :, curr_nrn_t, -1, :, None] *
            is_valid[None, :, None, None],
            axis=1)
        M['Gbias'][curr_syn_t] += torch.mean(torch.sum(M['d'][:, curr_nrn_t] *
                                                       is_valid[None, :, None],
                                                       axis=0),
                                             axis=0)

    a = torch.arange(M['p'].shape[1])
    for b in range(M['p'].shape[0]):
        M['pm'][b, a, M['p'][b].argmax(axis=1)] = 1

    M['correct'] = (M['pm'] == M['t']).all(axis=2)

    M['ce'] = -torch.sum(M['t'] * torch.log(1e-30 + M['p']), axis=2)
    M['reg_error'] = 0.5 * torch.sum(
        torch.mean((M['zs'] / n_steps[None, :, None, None] - cfg["FR_target"]),
                   axis=1))**2

    if grad:
        G = {}
        # Sum over time
        G['W_in'] = torch.sum(M['GW_in'], axis=1)
        G['W_rec'] = torch.sum(M['GW_rec'], axis=1)
        G['out'] = torch.sum(M['Gout'], axis=1)
        G['bias'] = torch.sum(M['Gbias'], axis=0)

        if cfg["L2_reg"]:
            G['W_in'] += cfg["L2_reg"] * W['W_in']
            G['W_rec'] += cfg["L2_reg"] * W['W_rec']

        # Don't update dead weights
        G['W_in'][W['W_in'] == 0] = 0
        G['W_rec'][W['W_rec'] == 0] = 0

        return G, M, n_steps.cpu().numpy()

    return None, M, n_steps.cpu().numpy()
Example 18
 def other_ops(self):
     a = torch.randn(4)
     b = torch.randn(4)
     c = torch.randint(0, 8, (5, ), dtype=torch.int64)
     e = torch.randn(4, 3)
     f = torch.randn(4, 4, 4)
     dims = [0, 1]
     return (
         torch.atleast_1d(a),
         torch.atleast_2d(a),
         torch.atleast_3d(a),
         torch.bincount(c),
         torch.block_diag(a),
         torch.broadcast_tensors(a),
         torch.broadcast_to(a, (4,)),
         # torch.broadcast_shapes(a),
         torch.bucketize(a, b),
         torch.cartesian_prod(a),
         torch.cdist(e, e),
         torch.clone(a),
         torch.combinations(a),
         torch.corrcoef(a),
         # torch.cov(a),
         torch.cross(e, e),
         torch.cummax(a, 0),
         torch.cummin(a, 0),
         torch.cumprod(a, 0),
         torch.cumsum(a, 0),
         torch.diag(a),
         torch.diag_embed(a),
         torch.diagflat(a),
         torch.diagonal(e),
         torch.diff(a),
         torch.einsum("iii", f),
         torch.flatten(a),
         torch.flip(e, dims),
         torch.fliplr(e),
         torch.flipud(e),
         torch.kron(a, b),
         torch.rot90(e),
         torch.gcd(c, c),
         torch.histc(a),
         torch.histogram(a),
         torch.meshgrid(a),
         torch.lcm(c, c),
         torch.logcumsumexp(a, 0),
         torch.ravel(a),
         torch.renorm(e, 1, 0, 5),
         torch.repeat_interleave(c),
         torch.roll(a, 1, 0),
         torch.searchsorted(a, b),
         torch.tensordot(e, e),
         torch.trace(e),
         torch.tril(e),
         torch.tril_indices(3, 3),
         torch.triu(e),
         torch.triu_indices(3, 3),
         torch.vander(a),
         torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
         torch.view_as_complex(torch.randn(4, 2)),
         torch.resolve_conj(a),
         torch.resolve_neg(a),
     )
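Among the ops listed above, the three flips are related by torch.fliplr(e) == torch.flip(e, [1]) and torch.flipud(e) == torch.flip(e, [0]), for example:

import torch

e = torch.arange(6).reshape(2, 3)
assert torch.equal(torch.fliplr(e), torch.flip(e, [1]))
assert torch.equal(torch.flipud(e), torch.flip(e, [0]))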
Example 19
    def compute(self, s: torch.Tensor) -> torch.Tensor:
        if s.shape != self.pre.shape:
            raise ValueError('Spikes shape is different from pre shape!')

        # flipping both axes rotates the kernel 180°, turning cross-correlation into true convolution
        return convolve(s, torch.flipud(torch.fliplr(self.kernel)))
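Flipping a 2-D kernel both up-down and left-right is the same as rotating it 180°, which is exactly what distinguishes convolution from cross-correlation. The equivalence in two lines:

import torch

k = torch.arange(9.).reshape(3, 3)
assert torch.equal(torch.flipud(torch.fliplr(k)), torch.rot90(k, 2))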
Example 20
def inference_hor_flip(model, inference_loader, total_len, annotation_dir, last_video, save, sigma_1, sigma_2,
                       frame_range, ref_num, temperature, probability_propagation, reduction_str, disable):
    global pred_visualize, palette, feats_history_l, label_history_l, weight_dense, weight_sparse, feats_history_r, label_history_r, d
    frame_idx = 0
    for input, (current_video,) in tqdm(inference_loader, total=total_len, disable=disable):
        if current_video != last_video:
            # save prediction
            pred_visualize = pred_visualize.cpu().numpy()
            save_predictions(pred_visualize, palette, save, last_video)
            frame_idx = 0
        if frame_idx == 0:
            input_l = input[0].to(Config.DEVICE)
            input_r = input[1].to(Config.DEVICE)
            with torch.cuda.amp.autocast():
                feats_history_l = model(input_l)
                feats_history_r = model(input_r)
            first_annotation = annotation_dir / current_video / '00000.png'
            label_history_l, label_history_r, d, palette, weight_dense, weight_sparse = prepare_first_frame(
                current_video,
                save,
                first_annotation,
                sigma_1,
                sigma_2,
                inference_strategy='hor-flip',
                probability_propagation=probability_propagation)
            frame_idx += 1
            last_video = current_video
            continue
        (batch_size, num_channels, H, W) = input[0].shape

        input_l = input[0].to(Config.DEVICE)
        input_r = input[1].to(Config.DEVICE)
        with torch.cuda.amp.autocast():
            features_l = model(input_l)
            features_r = model(input_r)

        (_, feature_dim, H_d, W_d) = features_l.shape
        prediction_l = predict(feats_history_l,
                               features_l[0],
                               label_history_l,
                               weight_dense,
                               weight_sparse,
                               frame_idx,
                               frame_range,
                               ref_num,
                               temperature,
                               probability_propagation)
        # Store all frames' features
        if probability_propagation:
            new_label_l = prediction_l.unsqueeze(1)
        else:
            new_label_l = index_to_onehot(torch.argmax(prediction_l, 0), d).unsqueeze(1)
        label_history_l = torch.cat((label_history_l, new_label_l), 1)
        feats_history_l = torch.cat((feats_history_l, features_l), 0)

        prediction_l = torch.nn.functional.interpolate(prediction_l.view(1, d, H_d, W_d),
                                                       size=(H, W),
                                                       mode='nearest')
        if not probability_propagation:
            prediction_l = torch.argmax(prediction_l, 1).squeeze()  # (H, W)

        prediction_r = predict(feats_history_r,
                               features_r[0],
                               label_history_r,
                               weight_dense,
                               weight_sparse,
                               frame_idx,
                               frame_range,
                               ref_num,
                               temperature,
                               probability_propagation)
        # Store all frames' features
        if probability_propagation:
            new_label_r = prediction_r.unsqueeze(1)
        else:
            new_label_r = index_to_onehot(torch.argmax(prediction_r, 0), d).unsqueeze(1)
        label_history_r = torch.cat((label_history_r, new_label_r), 1)
        feats_history_r = torch.cat((feats_history_r, features_r), 0)

        # 1. upsample, 2. argmax
        prediction_r = F.interpolate(prediction_r.view(1, d, H_d, W_d), size=(H, W), mode='nearest')
        if not probability_propagation:
            prediction_r = torch.argmax(prediction_r, 1).squeeze()  # (H, W)
        prediction_r = torch.fliplr(prediction_r).cpu()
        prediction_l = prediction_l.cpu()

        last_video = current_video
        frame_idx += 1

        if probability_propagation:
            reduction = REDUCTIONS.get(reduction_str)
            prediction = reduction(prediction_l, prediction_r).cpu().half()
            prediction = torch.argmax(prediction, 1).cpu()  # (1, H, W)
        else:
            prediction = torch.maximum(prediction_l, prediction_r).unsqueeze(0).cpu().half()

        if frame_idx == 2:
            pred_visualize = prediction
        else:
            pred_visualize = torch.cat((pred_visualize, prediction), 0)

    # save last video's prediction
    pred_visualize = pred_visualize.cpu().numpy()
    save_predictions(pred_visualize, palette, save, last_video)
Example 21
 def __call__(self, image):
     rand_ = random.uniform(0, 1)
     if rand_ < self.prob:
         image = torch.fliplr(image)
     return image
Example 22
 def test_fliplr_invalid(self, device, dtype):
     x = torch.randn(42).to(dtype)
     with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."):
         torch.fliplr(x)
     with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."):
         torch.fliplr(torch.tensor(42, device=device, dtype=dtype))
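As the test above documents, torch.fliplr rejects 0-d and 1-d inputs. To reverse a 1-d tensor, use torch.flip directly:

import torch

v = torch.tensor([1, 2, 3])
print(torch.flip(v, dims=[0]))   # tensor([3, 2, 1])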
Example 23
    def flip_code(code):

        return torch.fliplr(code)