Exemple #1
0
    def _interpolate(im, x, y, out_size):
        """Bilinearly sample ``im`` at normalized coordinates ``(x, y)``.

        All intermediate tensors are created on ``im.device`` (the original
        forced everything onto CUDA), so the helper also runs on CPU inputs.

        Args:
            im: image batch unpacked as (num_batch, height, width, channels),
                i.e. NHWC.
            x, y: flat tensors of sampling coordinates in [-1, 1]; presumably
                num_batch * out_height * out_width entries each — TODO confirm
                against callers.
            out_size: (out_height, out_width) of the sampled grid.

        Returns:
            Float32 tensor with one row of ``channels`` interpolated values per
            sample point.
        """
        num_batch, height, width, channels = im.size()  # NHWC layout
        device = im.device
        out_height = out_size[0]
        out_width = out_size[1]

        x = x.to(device=device, dtype=torch.float32)
        y = y.to(device=device, dtype=torch.float32)

        # scale indices from [-1, 1] to [0, width/height]
        x = (x + 1.0) * float(width) / 2.0
        y = (y + 1.0) * float(height) / 2.0

        # integer neighbours of every sample point
        x0 = torch.floor(x).long()
        x1 = x0 + 1
        y0 = torch.floor(y).long()
        y1 = y0 + 1

        # clamp neighbours into the image so border pixels are repeated
        x0 = x0.clamp(0, width - 1)
        x1 = x1.clamp(0, width - 1)
        y0 = y0.clamp(0, height - 1)
        y1 = y1.clamp(0, height - 1)

        # pixel (b, row, col) lives at row b*H*W + row*W + col of im_flat
        dim2 = width
        dim1 = width * height
        # per-image offsets expanded by _repeat (presumably one copy per output
        # pixel — TODO confirm against _repeat)
        base = _repeat(torch.arange(num_batch) * dim1,
                       out_height * out_width).to(device)
        base_y0 = base + y0 * dim2
        base_y1 = base + y1 * dim2
        idx_a = base_y0 + x0  # (x0, y0)
        idx_b = base_y1 + x0  # (x0, y1)
        idx_c = base_y0 + x1  # (x1, y0)
        idx_d = base_y1 + x1  # (x1, y1)

        # use indices to look up pixels in the flattened images (one row per
        # pixel, channels kept as columns)
        im_flat = im.contiguous().view(-1, channels).float()
        Ia = im_flat[idx_a]
        Ib = im_flat[idx_b]
        Ic = im_flat[idx_c]
        Id = im_flat[idx_d]

        # bilinear weights from the (clamped) neighbour coordinates
        x0_f = x0.float()
        x1_f = x1.float()
        y0_f = y0.float()
        y1_f = y1.float()
        wa = ((x1_f - x) * (y1_f - y)).unsqueeze(1)
        wb = ((x1_f - x) * (y - y0_f)).unsqueeze(1)
        wc = ((x - x0_f) * (y1_f - y)).unsqueeze(1)
        wd = ((x - x0_f) * (y - y0_f)).unsqueeze(1)

        return wa * Ia + wb * Ib + wc * Ic + wd * Id
def _pairwise_distances(embeddings, squared=False):
    # Get the dot product between all embeddings
    # shape (batch_size, batch_size)
    dot_product = torch.matmul(embeddings, torch.transpose(embeddings, 0, 1))

    # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
    # This also provides more numerical stability (the diagonal of the result will be exactly 0).
    # shape (batch_size,)
    square_norm = torch.diag(dot_product)

    # Compute L2
    # shape (batch_size, batch_size)
    distances = torch.unsqueeze(square_norm, 1) - 2.0 * dot_product + torch.unsqueeze(square_norm, 0)

    # Because of computation errors, some distances might be negative so we put everything >= 0.0
    distances = torch.max(distances, torch.zeros_like(distances))

    if not squared:
        mask = torch._cast_Float(torch.eq(distances, 0.0))
        distances = distances + mask * 1e-16  # for sqrt(0)

        distances = torch.sqrt(distances)

        # Correct the epsilon added: set the distances on the mask to be exactly 0.0
        distances = distances * (1.0 - mask)

    return distances.cpu()
Exemple #3
0
    def get_4_pts(self, theta, batch_size):
        """Build the displaced mesh vertices and the 4 corners of every cell.

        Uses the module-level ``grid_h``, ``grid_w`` and ``do_crop_rate``.
        The regular (grid_h + 1) x (grid_w + 1) vertex mesh spans [-1, 1] in
        both axes. Vertex k (row-major) is offset by ``theta[:, 2k:2k + 2]``
        and clamped to [-1 / do_crop_rate, 1 / do_crop_rate].

        Args:
            theta: per-sample vertex offsets; presumably shape
                (batch_size, 2 * (grid_h + 1) * (grid_w + 1)) — TODO confirm.
            batch_size: number of samples in ``theta``.

        Returns:
            pts1: (batch_size, grid_h, grid_w, 8). For each cell, the
                x-coordinates of its 4 corners (top-left, top-right,
                bottom-left, bottom-right) followed by their y-coordinates.
            pts2: (batch_size, grid_h + 1, grid_w + 1, 2), the displaced
                vertices as (x, y).
        """
        step_h = 2.0 / grid_h
        step_w = 2.0 / grid_w
        bound = 1. / do_crop_rate
        vertices = []  # vertices[i][j]: (batch_size, 2, 1)
        flat_vertices = []  # row-major list of (batch_size, 1, 2)
        k = 0
        for i in range(grid_h + 1):
            row = []
            for j in range(grid_w + 1):
                # undisplaced vertex position (x, y), on theta's device
                anchor = torch.tensor([j * step_w - 1, i * step_h - 1],
                                      dtype=torch.float32,
                                      device=theta.device)
                p = (anchor + theta[:, 2 * k:2 * k + 2]).view(
                    [batch_size, 1, 2])
                k += 1
                p = torch.clamp(p, -bound, bound)
                row.append(p.view([batch_size, 2, 1]))
                flat_vertices.append(p)
            vertices.append(row)

        # concatenating (batch, 2, 1) corners on dim 2 gives (batch, 2, 4):
        # all x-coordinates first, then all y-coordinates
        cells = [
            torch.cat([
                vertices[i][j], vertices[i][j + 1], vertices[i + 1][j],
                vertices[i + 1][j + 1]
            ],
                      dim=2).view([batch_size, 1, 8])
            for i in range(grid_h) for j in range(grid_w)
        ]

        pts1 = torch.cat(cells, 1).view([batch_size, grid_h, grid_w, 8])
        pts2 = torch.cat(flat_vertices,
                         1).view([batch_size, grid_h + 1, grid_w + 1, 2])
        return pts1, pts2
def batch_hard_triplet_loss(labels, embeddings, margin=0.3, squared=False):
    """Batch-hard triplet loss.

    For every anchor, take the largest masked positive distance and the
    smallest valid negative distance in the batch. Return the mean over
    anchors of ``max(d(a, p) - d(a, n) + margin, 0)``.

    Args:
        labels: per-sample labels, consumed by the
            ``_get_anchor_positive_triplet_mask`` /
            ``_get_anchor_negative_triplet_mask`` helpers.
        embeddings: tensor of shape (batch_size, embed_dim).
        margin: required gap between positive and negative distances.
        squared: use squared Euclidean distances if True.

    Returns:
        Scalar tensor with the mean triplet loss.
    """
    # shape (batch_size, batch_size); _pairwise_distances returns it on the CPU
    pairwise_dist = _pairwise_distances(embeddings, squared=squared)

    # The masks are derived from `labels`, which may live on another device:
    # match the distance matrix's dtype AND device before combining them.
    mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).to(
        pairwise_dist)

    # Invalid (a, p) pairs (a == p or different labels) contribute 0.
    anchor_positive_dist = mask_anchor_positive * pairwise_dist

    # shape (batch_size, 1)
    hardest_positive_dist, _ = anchor_positive_dist.max(dim=1, keepdim=True)

    mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).to(
        pairwise_dist)

    # Add the row maximum to invalid negatives (label(a) == label(n)) so they
    # never win the min below.
    max_anchor_negative_dist, _ = pairwise_dist.max(dim=1, keepdim=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (
        1.0 - mask_anchor_negative)

    # shape (batch_size, 1)
    hardest_negative_dist, _ = anchor_negative_dist.min(dim=1, keepdim=True)

    # Combine biggest d(a, p) and smallest d(a, n) into the hinge loss
    triplet_loss = torch.max(
        hardest_positive_dist - hardest_negative_dist + margin,
        torch.zeros_like(hardest_positive_dist))

    return triplet_loss.mean()
Exemple #5
0
def train():
    """Train a DenseNet on CIFAR-10 with SGD + momentum and an L2 penalty.

    Loads the five training batches from ``./cifar-10-batches-mat``. The
    last 500 images are held out for validation. Runs 100000 iterations on
    random mini-batches of 64 and prints validation loss/accuracy every 100
    iterations.
    """
    weight_decay = 1e-4
    densenet = DenseNet()
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    densenet.to(device)
    optimizer = torch.optim.SGD(densenet.parameters(), lr=0.1, momentum=0.9)
    lr_scheduler_opt = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[50000, 75000], gamma=0.1)
    criterion = nn.CrossEntropyLoss()
    L2_reg = nn.MSELoss()

    batches = [
        sio.loadmat("./cifar-10-batches-mat/data_batch_%d.mat" % k)
        for k in range(1, 6)
    ]
    data = np.concatenate([b["data"] for b in batches], axis=0)
    labels = np.concatenate([b["labels"] for b in batches], axis=0)
    data = np.reshape(data, [-1, 3, 32, 32])
    labels = labels[:, 0]

    train_nums = 49500
    train_data = torch.tensor(data[:train_nums], dtype=torch.float32).to(device)
    train_labels = torch.tensor(labels[:train_nums],
                                dtype=torch.long).to(device)
    val_data = torch.tensor(data[train_nums:], dtype=torch.float32).to(device)
    val_labels = torch.tensor(labels[train_nums:], dtype=torch.long).to(device)

    params = list(densenet.parameters())
    for i in range(100000):
        rand_idx = np.random.randint(0, train_nums, [64])
        batch = train_data[rand_idx]
        batch_label = train_labels[rand_idx]
        logits = densenet(batch)
        # torch.stack keeps the autograd graph; torch.tensor([...]) would
        # detach every term and the penalty would have no gradient.
        reg = torch.mean(
            torch.stack([L2_reg(p, torch.zeros_like(p)) for p in params]))
        loss = criterion(logits, batch_label) + reg * weight_decay
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # previously missing: weights were never updated
        lr_scheduler_opt.step()
        if i % 100 == 0:
            # evaluate without touching BN running stats or building a graph
            densenet.eval()
            with torch.no_grad():
                logits = densenet(val_data)
                val_loss = criterion(logits, val_labels)
                val_acc = (torch.argmax(logits, dim=1) ==
                           val_labels).float().mean()
            densenet.train()
            print("Iteration: %d, Val_loss: %f, Val_acc: %f" %
                  (i, val_loss.item(), val_acc.item()))
Exemple #6
0
 def einstein_midpoint(self, x, c):
     """Einstein midpoint of a set of points, returned in the Poincare ball.

     The points are mapped to the Klein model, averaged with Lorentz-factor
     weights over the second-to-last axis, constrained, and mapped back to
     the Poincare model. Points whose Klein norm is exactly 0 are treated as
     padding and get zero weight.

     Args:
         x: points to average along the second-to-last axis; presumably
             (batch, n_points, dim) — TODO confirm.
         c: curvature argument forwarded to ``to_klein`` and
             ``klein_to_poincare``.

     Returns:
         The midpoints, with the second-to-last axis of ``x`` reduced.
     """
     x = self.to_klein(x, c)
     x_lorentz = self.lorentz_factors(x)
     x_norm = torch.norm(x, dim=-1)
     # deal with pad value: zero-norm points get zero weight
     x_lorentz = (1.0 - (x_norm == 0.0).to(x_lorentz.dtype)) * x_lorentz
     x_lorentz_sum = torch.sum(x_lorentz, dim=-1, keepdim=True)
     # reduce over the same (points) axis as the weight sum above; identical
     # to dim=1 for 3-D inputs, and also correct for other ranks
     x_midpoint = torch.sum(x_lorentz.unsqueeze(-1) * x, dim=-2) / x_lorentz_sum
     x_midpoint = self.klein_constraint(x_midpoint)
     x_p = self.klein_to_poincare(x_midpoint, c)
     return x_p
Exemple #7
0
def train():
    """Train a 10-class ResNet on CIFAR-10 with Adam.

    Loads the five training batches from ``./cifar-10-batches-mat``. The
    last 500 images are held out for validation. Runs 100000 iterations on
    random mini-batches of 256 and prints the training loss and validation
    accuracy every 100 iterations.
    """
    resnet50 = ResNet(10)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    resnet50.to(device)
    optimizer = torch.optim.Adam(resnet50.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    batches = [
        sio.loadmat("./cifar-10-batches-mat/data_batch_%d.mat" % k)
        for k in range(1, 6)
    ]
    data = np.concatenate([b["data"] for b in batches], axis=0)
    labels = np.concatenate([b["labels"] for b in batches], axis=0)
    data = np.reshape(data, [-1, 3, 32, 32])
    labels = labels[:, 0]

    train_nums = 49500
    train_data = torch.tensor(data[:train_nums], dtype=torch.float32).to(device)
    train_labels = torch.tensor(labels[:train_nums],
                                dtype=torch.long).to(device)
    val_data = torch.tensor(data[train_nums:], dtype=torch.float32).to(device)
    val_labels = torch.tensor(labels[train_nums:], dtype=torch.long).to(device)

    for i in range(100000):
        rand_idx = np.random.randint(0, train_nums, [256])
        batch = train_data[rand_idx]
        batch_labels = train_labels[rand_idx]
        logits = resnet50(batch)
        train_loss = criterion(logits, batch_labels)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        if i % 100 == 0:
            # evaluate without touching BN running stats or building a graph
            resnet50.eval()
            with torch.no_grad():
                logits = resnet50(val_data)
                val_acc = (torch.argmax(logits, dim=1) ==
                           val_labels).float().mean()
            resnet50.train()
            print("Iteration: %d, Loss: %f, Val Accuracy: %f" %
                  (i, train_loss.item(), val_acc.item()))
Exemple #8
0
    def get_Hs(theta):
        """Compute one homography per mesh cell with ``get_H``.

        For every cell (i, j) of the ``grid_h`` x ``grid_w`` mesh over
        [-1, 1]^2, ``get_H`` receives the cell's 4 undisplaced corners and the
        4 matching vertices taken from ``theta``. Both are given as
        (x, y) pairs ordered top-left, top-right, bottom-left, bottom-right.

        Args:
            theta: mesh vertices; the slicing implies shape
                (num_batch, grid_h + 1, grid_w + 1, 2) — TODO confirm.

        Returns:
            Tensor of shape (num_batch, grid_h, grid_w, 9).
        """
        num_batch = theta.size()[0]
        h = 2.0 / grid_h
        w = 2.0 / grid_w
        Hs = []
        for i in range(grid_h):
            for j in range(grid_w):
                hh = i * h - 1
                ww = j * w - 1
                # undisplaced corners, same for every sample in the batch
                ori = torch.tensor(
                    [ww, hh, ww + w, hh, ww, hh + h, ww + w, hh + h],
                    dtype=torch.float32,
                    device=theta.device).view([1, 8]).repeat([num_batch, 1])
                # matching displaced corners from theta, same corner order
                tar = torch.stack([
                    theta[:, i, j], theta[:, i, j + 1], theta[:, i + 1, j],
                    theta[:, i + 1, j + 1]
                ],
                                  dim=1).view([num_batch, 8])
                Hs.append(get_H(ori, tar).view([num_batch, 1, 9]))

        return torch.cat(Hs, dim=1).view([num_batch, grid_h, grid_w, 9])
Exemple #9
0
    def sample_walks(self, data, steps=None, start_p=1.0):
        """
        Sample random walks on the (batched) graph and attach walk features.

        :param data: Preprocessed PyTorch Geometric data object; must provide
            ``x``, ``edge_index``, ``adj_offset``, ``degrees``, ``node_id``,
            ``adj_bits``, ``batch``, ``graph_offset`` and ``order``.
        :param steps: Number of walk steps (if None, ``self.steps`` is used)
        :param start_p: Probability of starting a walk at each node. It is only
            used when ``self.training`` is set and ``start_p < 1.0``;
            otherwise a walk starts at every node. In both cases nodes of
            degree 0 are dropped.
        :return: The data object with ``walk_nodes`` (num_walks, steps + 1),
            ``walk_edges`` (num_walks, steps) and ``walk_x``
            (num_walks, num_channels, steps + 1), or None when neither
            ``compute_id`` nor ``compute_adj`` is set.
        """

        device = data.x.device

        # get adjacency data
        adj_nodes = data.edge_index[1]
        adj_offset = data.adj_offset
        degrees = data.degrees
        node_id = data.node_id
        adj_bits = data.adj_bits
        graph_idx = data.batch
        graph_offset = data.graph_offset
        order = data.order

        # use the configured number of steps if not specified
        if steps is None:
            steps = self.steps

        # set dimensions
        s = self.win_size  # how many previous walk positions are compared
        n = degrees.shape[0]
        l = steps + 1

        # sample starting nodes
        if self.training and start_p < 1.0:
            start, _ = Walker.sample_start(start_p, graph_idx, graph_offset,
                                           order, device)
        else:
            # built on `device` so the degree mask below (which lives on the
            # same device as `degrees`) can index it
            start = torch.arange(0, n, dtype=torch.int64, device=device)
        # nodes without neighbours cannot start a walk
        start = start[degrees[start] > 0]

        # init tensor to hold walk indices
        w = start.shape[0]
        walks = torch.zeros((l, w), dtype=torch.int64, device=device)
        walks[0] = start

        walk_edges = torch.zeros((l - 1, w), dtype=torch.int64, device=device)

        # get all random decisions at once (faster than individual calls)
        choices = torch.randint(0, MAXINT, (steps, w), device=device)

        if self.compute_id:
            id_enc = torch.zeros((l, s, w), dtype=torch.bool, device=device)

        if self.compute_adj:
            edges = torch.zeros((l, s, w), dtype=torch.bool, device=device)

        # remove one choice of each node with deg > 1 for no_backtrack walks
        nb_degrees = torch.where(degrees == 1, degrees, degrees - 1)

        for i in range(steps):
            chosen_edges = self.unweighted_choice(i, walks, adj_nodes,
                                                  adj_offset, degrees,
                                                  nb_degrees, choices)

            # update nodes
            walks[i + 1] = adj_nodes[chosen_edges]

            # update edge features
            walk_edges[i] = chosen_edges

            # compare the new node with (up to) the s previous walk positions
            o = min(s, i + 1)
            prev = walks[i + 1 - o:i + 1]

            if self.compute_id:
                # get local identity relation
                id_enc[i + 1, s - o:] = torch.eq(walks[i + 1].view(1, w), prev)

            if self.compute_adj:
                # look up edges in the bit-wise adjacency encoding: node ids
                # appear to be packed 63 bits per adj_bits word (word index
                # id // 63, bit index id % 63)
                cur_id = node_id[walks[i + 1]]
                cur_int = (cur_id // 63).view(1, -1, 1).repeat(o, 1, 1)
                edges[i + 1, s - o:] = (
                    torch.gather(adj_bits[prev], 2, cur_int).view(o, -1) >>
                    (cur_id % 63).view(1, -1)) % 2 == 1

        # permute walks into the correct shapes
        data.walk_nodes = walks.permute(1, 0)
        data.walk_edges = walk_edges.permute(1, 0)

        # combine id, adj and edge features
        feat = []
        if self.compute_id:
            feat.append(id_enc.permute(2, 1, 0).float())
        if self.compute_adj:
            feat.append(edges.permute(2, 1, 0).float()[:, :-1, :])
        data.walk_x = torch.cat(feat, dim=1) if len(feat) > 0 else None

        return data
Exemple #10
0
def train(model, train_iter, val_iter):
    """Train ``model`` for ``model.config['epochs']`` epochs and log metrics.

    Each epoch makes one Adam pass over ``train_iter`` and evaluates on
    ``val_iter`` with the module-level ``eval`` helper. Loss and accuracy go
    to TensorBoard in ``model.model_dir``, and one row per epoch goes to
    ``model.model_dir/results.csv``. A model with ``out_dim == 1`` is trained
    with sigmoid + binary cross-entropy, otherwise with softmax
    cross-entropy. Uses the module-level ``device``.
    """
    writer = SummaryWriter(model.model_dir)
    res_path = os.path.join(model.model_dir, 'results.csv')

    try:
        with open(res_path, 'w', newline='') as f:
            csv_writer = DictWriter(
                f, fieldnames=['train_acc', 'test_acc', 'test_std'])
            csv_writer.writeheader()

            model.to(device)
            opt = torch.optim.Adam(model.parameters(),
                                   lr=model.config['lr'],
                                   weight_decay=model.config['weight_decay'])
            # multiply the learning rate by `decay_factor` every `patience` epochs
            sch = torch.optim.lr_scheduler.StepLR(
                opt,
                gamma=model.config['decay_factor'],
                step_size=model.config['patience'])

            max_epochs = model.config['epochs']

            binary = model.out_dim == 1

            for e in range(1, max_epochs + 1):
                train_acc = []
                train_loss = []

                model.train()
                for data in train_iter:
                    data = data.to(device)
                    y = data.y
                    opt.zero_grad()
                    data = model(data)
                    logits = data.y_pred

                    if binary:
                        y_pred = torch.sigmoid(logits)
                        loss = F.binary_cross_entropy(y_pred,
                                                      y.float(),
                                                      reduction='mean')
                        acc = (y_pred > 0.5).int().eq(y).sum() / float(
                            y.shape[0])
                    else:
                        loss = F.cross_entropy(logits,
                                               y.reshape(-1),
                                               reduction='mean')
                        acc = logits.argmax(dim=1).eq(
                            y.reshape(-1)).sum() / float(y.shape[0])

                    loss.backward()
                    opt.step()

                    train_acc.append(acc.item())
                    train_loss.append(loss.item())

                torch.cuda.empty_cache()

                train_acc = np.mean(train_acc)
                train_loss = np.mean(train_loss)
                val_acc, val_std = eval(model,
                                        val_iter,
                                        repeats=model.config['eval_rep'])

                writer.add_scalar('Loss/train', train_loss, e)
                writer.add_scalar('Acc/train', train_acc, e)
                writer.add_scalar('Acc/val', val_acc, e)

                csv_writer.writerow({
                    'train_acc': train_acc,
                    'test_acc': val_acc,
                    'test_std': val_std
                })
                # keep per-epoch results on disk even if the run is interrupted
                f.flush()

                # `e` already counts from 1, matching the TensorBoard/CSV step
                print(
                    f'Epoch {e} Loss: {train_loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f} (+-{val_std:.4f})'
                )
                sch.step()
    finally:
        writer.close()
Exemple #11
0
def feat_transform(graph):
    """One-hot encode the integer node and edge features of ``graph`` in place.

    ``graph.x`` and ``graph.edge_attr`` are flattened, one-hot encoded with
    the module-level ``num_node_feat`` / ``num_edge_feat`` class counts, and
    cast to float32.

    Returns:
        The same ``graph`` object, so the function can be used as a transform.
    """
    graph.x = F.one_hot(graph.x.view(-1), num_node_feat).float()
    graph.edge_attr = F.one_hot(graph.edge_attr.view(-1),
                                num_edge_feat).float()
    return graph
Exemple #12
0
def train(model, train_iter, val_iter):
    """Train ``model`` with BCE-with-logits and track the evaluator score.

    Each epoch makes one Adam pass over ``train_iter``; targets that are NaN
    are excluded from the loss. The train score comes from the module-level
    ``evaluator`` under ``score_key``. The model is checkpointed every epoch
    as ``model_epoch_<e>`` and additionally whenever the validation score
    (from the module-level ``eval``) ties or beats the best so far. Metrics
    are logged to TensorBoard in ``model.model_dir``. Uses the module-level
    ``device``.
    """
    writer = SummaryWriter(model.model_dir)

    model.to(device)

    decay_step = model.config['decay_step']
    opt = torch.optim.Adam(model.parameters(),
                           lr=model.config['lr'],
                           weight_decay=model.config['weight_decay'])
    sch = torch.optim.lr_scheduler.MultiStepLR(
        opt, [decay_step], gamma=model.config['decay_factor'], verbose=True)

    best_val_score = 0.0

    walk_start_p = model.config['train_start_ratio']
    max_epochs = model.config['epochs']

    try:
        for e in range(1, max_epochs + 1):
            train_loss = []
            all_y_pred = []
            all_y_true = []

            model.train()
            for data in tqdm(train_iter):
                opt.zero_grad()
                data = data.to(device)

                data = model(data, walk_start_p=walk_start_p)
                y_true = data.y

                # NaN targets are excluded from the loss
                labeled = ~torch.isnan(y_true)
                loss = F.binary_cross_entropy_with_logits(
                    data.y_pred[labeled], y_true[labeled].float())

                loss.backward()
                opt.step()

                # detach first so the logging path builds no autograd graph
                all_y_pred.append(
                    torch.sigmoid(data.y_pred.detach()).cpu().numpy())
                all_y_true.append(y_true.detach().cpu().numpy())
                train_loss.append(loss.item())

            score = evaluator.eval({
                "y_true": np.vstack(all_y_true),
                "y_pred": np.vstack(all_y_pred)
            })
            train_score = score[score_key]
            train_loss = np.mean(train_loss)

            model.save(f'model_epoch_{e}')

            val_score, val_std = eval(model,
                                      val_iter,
                                      repeats=model.config['eval_rep'])
            print(
                f'Epoch {e} Loss: {train_loss:.4f}, Train: {train_score:.4f}, Val: {val_score:.4f} (+-{val_std:.4f})'
            )

            if val_score >= best_val_score:
                best_val_score = val_score
                model.save()

            writer.add_scalar('Score/val', val_score, e)

            writer.add_scalar('Loss/train', train_loss, e)
            writer.add_scalar('Score/train', train_score, e)

            sch.step()
    finally:
        writer.close()
Exemple #13
0
    def _transform3(theta, input_dim):
        """Warp an image batch with one homography per mesh cell.

        Uses ``height``, ``width``, ``grid_h`` and ``grid_w`` from the
        enclosing scope, and the helpers ``get_Hs``, ``_meshgrid2`` and
        ``_interpolate``.

        Args:
            theta: mesh parameters handed to ``get_Hs`` (cast to float32).
            input_dim: image batch in NCHW layout (permuted to NHWC here).

        Returns:
            output: warped images, (num_batch, num_channels, height, width).
            black_pix: float mask (num_batch, height, width), 1 where the
                sampling location lies outside [-1, 1] in x or y.
            img: sampling coordinates (x, y), (num_batch, height, width, 2).
            x_map_, y_map_: copies of the x / y sampling maps, each
                (num_batch, height, width, 1).
        """
        input_dim = input_dim.permute([0, 2, 3, 1])
        num_batch = input_dim.size()[0]
        num_channels = input_dim.size()[3]
        theta = theta.float()
        Hs = get_Hs(theta)
        gh = int(math.floor(height / grid_h))
        gw = int(math.floor(width / grid_w))
        x_ = []
        y_ = []

        for i in range(grid_h):
            row_x_ = []
            row_y_ = []
            for j in range(grid_w):
                H = Hs[:, i:i + 1, j:j + 1, :].view(num_batch, 3, 3)
                # pixel span of this cell; the last row/column of cells
                # absorbs any remainder of the integer division above
                sh = i * gh
                eh = (i + 1) * gh - 1
                sw = j * gw
                ew = (j + 1) * gw - 1
                if i == grid_h - 1:
                    eh = height - 1
                if j == grid_w - 1:
                    ew = width - 1
                grid = _meshgrid2(height, width, sh, eh, sw, ew)
                grid = grid.unsqueeze(0).repeat([num_batch, 1, 1])

                T_g = torch.matmul(H, grid)
                x_s = T_g[:, 0:1, :]
                y_s = T_g[:, 1:2, :]
                z_s = T_g[:, 2:3, :]

                # push z away from 0 in the direction of its sign (z == 0
                # counts as positive) before the perspective divide
                z_s_flat = z_s.contiguous().view(-1)
                sign_z_flat = (z_s_flat >= 0).float() * 2 - 1
                z_s_flat = z_s_flat + sign_z_flat * 1e-8
                x_s_flat = x_s.contiguous().view(-1) / z_s_flat
                y_s_flat = y_s.contiguous().view(-1) / z_s_flat

                cell_shape = [num_batch, eh - sh + 1, ew - sw + 1]
                row_x_.append(x_s_flat.view(cell_shape))
                row_y_.append(y_s_flat.view(cell_shape))
            x_.append(torch.cat(row_x_, dim=2))
            y_.append(torch.cat(row_y_, dim=2))

        x = torch.cat(x_, dim=1).view([num_batch, height, width, 1])
        y = torch.cat(y_, dim=1).view([num_batch, height, width, 1])
        x_map_ = x.clone()
        y_map_ = y.clone()
        img = torch.cat([x, y], dim=3)
        x_s_flat = x.view(-1)
        y_s_flat = y.view(-1)

        # 1 where the sample falls outside the normalized [-1, 1] square
        out_of_range = (x_s_flat < -1) | (x_s_flat > 1) | \
            (y_s_flat < -1) | (y_s_flat > 1)
        black_pix = out_of_range.float().view([num_batch, height, width])

        out_size = (height, width)
        input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat,
                                         out_size)

        output = input_transformed.view(
            [num_batch, height, width, num_channels])
        output = output.permute([0, 3, 1, 2])

        return output, black_pix, img, x_map_, y_map_