예제 #1
0
    def __call__(self, rgb_img, label_img, depth_img=None):
        """Pad each image up to self.crop_size if needed, then take one crop.

        Args:
            rgb_img (PIL Image): input RGB image.
            label_img (PIL Image or None): segmentation label map; padded with
                self.ignore_idx so padded pixels can be ignored by the loss.
            depth_img (PIL Image or None): optional depth map, padded with 0.

        Returns:
            (rgb_img, label_img) or (rgb_img, label_img, depth_img), each
            cropped with the same crop parameters so they stay aligned.
        """
        w, h = rgb_img.size
        # Per-side padding so the padded image is at least crop_size;
        # the +1 rounds up when the size difference is odd.
        pad_along_w = max(0, int((1 + self.crop_size[0] - w) / 2))
        pad_along_h = max(0, int((1 + self.crop_size[1] - h) / 2))
        # pad the images
        rgb_img = Pad(padding=(pad_along_w, pad_along_h),
                      fill=0,
                      padding_mode='constant')(rgb_img)
        if label_img is not None:
            label_img = Pad(padding=(pad_along_w, pad_along_h),
                            fill=self.ignore_idx,
                            padding_mode='constant')(label_img)
        if depth_img is not None:
            depth_img = Pad(padding=(pad_along_w, pad_along_h),
                            fill=0,
                            padding_mode='constant')(depth_img)

        # Sample one crop window and apply it identically to every modality.
        i, j, h, w = self.get_params(rgb_img, self.crop_size)
        rgb_img = F.crop(rgb_img, i, j, h, w)
        if label_img is not None:
            label_img = F.crop(label_img, i, j, h, w)
        if depth_img is not None:
            depth_img = F.crop(depth_img, i, j, h, w)

        if depth_img is not None:
            return rgb_img, label_img, depth_img
        else:
            return rgb_img, label_img
예제 #2
0
def adapt_size(cell):
    """Scale a cell image so its longer side equals char_hw, then pad the
    shorter side to produce a char_hw x char_hw square.

    For an odd padding amount the extra pixel goes on the left (width case)
    or on the top (height case). Returns the result as a tensor.
    """
    src_h, src_w = cell.shape
    pil = gen_ToPILImage(cell)
    if src_h > src_w:
        # Height is the long side: scale width proportionally, pad left/right.
        new_w = int((char_hw / src_h) * src_w)
        pil = Resize((char_hw, new_w), interpolation=PIL.Image.NEAREST)(pil)
        shortfall = char_hw - new_w
        if shortfall:
            left = shortfall // 2 + (shortfall % 2)
            pil = Pad((left, 0, shortfall // 2, 0), 0)(pil)
    else:
        # Width is the long side: scale height proportionally, pad top/bottom.
        new_h = int((char_hw / src_w) * src_h)
        pil = Resize((new_h, char_hw), interpolation=PIL.Image.NEAREST)(pil)
        shortfall = char_hw - new_h
        if shortfall:
            top = shortfall // 2 + (shortfall % 2)
            pil = Pad((0, top, 0, shortfall // 2), 0)(pil)
    return gen_ToTensor(np.array(pil))
예제 #3
0
파일: TCAE_data.py 프로젝트: yumingh97/TCAE
 def __init__(self,
              num_views,
              random_seed,
              dataset,
              additional_face=True,
              jittering=False):
     """Load a VoxCeleb1 identity list and build pose/face transform pipelines.

     Args:
         num_views: number of views sampled per identity.
         random_seed: seed for this dataset's private RNG.
         dataset: split selector -- 1=train, 2=val, 3=test.
         additional_face: not used in this constructor -- presumably consumed
             elsewhere in the class; TODO confirm.
         jittering: when True, a random crop size in [200, 224) is drawn once
             and RandomCrop jitter is added to both pipelines.
     """
     if dataset == 1:
         self.ids = np.load('../Datasets/voxceleb1_ori/train.npy')
     if dataset == 2:
         self.ids = np.load('../Datasets/voxceleb1_ori/val.npy')
     if dataset == 3:
         self.ids = np.load('../Datasets/voxceleb1_ori/test.npy')
     self.rng = np.random.RandomState(random_seed)
     self.num_views = num_views
     #self.base_file = os.environ['VOX_CELEB_LOCATION'] + '/%s/'
     self.base_file = VOX_CELEB_LOCATION + '/%s/'
     crop = 200
     if jittering == True:
         # Crop size is randomized once per dataset instance, not per item.
         precrop = crop + 24
         crop = self.rng.randint(crop, precrop)
         self.pose_transform = Compose([
             Scale((256, 256)),
             Pad((20, 80, 20, 30)),
             CenterCrop(precrop),
             RandomCrop(crop),
             Scale((256, 256)),
             ToTensor()
         ])
         self.transform = Compose([
             Scale((256, 256)),
             Pad((24, 24, 24, 24)),
             CenterCrop(precrop),
             RandomCrop(crop),
             Scale((256, 256)),
             ToTensor()
         ])
     else:
         precrop = crop + 24
         # NOTE(review): here pose_transform center-crops `crop` (200) while
         # `transform` uses `precrop` (224), unlike the jittering branch where
         # both use `precrop` -- confirm this asymmetry is intentional.
         self.pose_transform = Compose([
             Scale((256, 256)),
             Pad((20, 80, 20, 30)),
             CenterCrop(crop),
             Scale((256, 256)),
             ToTensor()
         ])
         self.transform = Compose([
             Scale((256, 256)),
             Pad((24, 24, 24, 24)),
             CenterCrop(precrop),
             Scale((256, 256)),
             ToTensor()
         ])
예제 #4
0
def detect_image(img: Image, img_size: int, model: Darknet, conf_thresh: float,
                 nms_thresh: float):
    """Run Darknet inference on one PIL image and return its detections.

    The image is resized to fit inside an img_size square (aspect ratio
    preserved) and the remainder is padded with grey (128) so the network
    input is square.

    Args:
        img: input PIL image.
        img_size: network input side length.
        model: Darknet detector, already on the target device.
        conf_thresh: objectness confidence threshold for NMS.
        nms_thresh: IoU threshold for non-max suppression.

    Returns:
        The detections for the single image (first element of the
        non_max_suppression output; 80 classes are assumed -- COCO).
    """
    # scale and pad image
    ratio = min(img_size / img.size[0], img_size / img.size[1])
    imw = round(img.size[0] * ratio)
    imh = round(img.size[1] * ratio)
    img_transforms = Compose([
        Resize((imh, imw)),
        # Pad tuple is (left, top, right, bottom); only the shorter side
        # gets nonzero padding, split evenly between the two ends.
        Pad(
            (
                max(int((imh - imw) / 2), 0),
                max(int((imw - imh) / 2), 0),
                max(int((imh - imw) / 2), 0),
                max(int((imw - imh) / 2), 0),
            ),
            (128, 128, 128),
        ),
        ToTensor(),
    ])
    # convert image to Tensor
    image_tensor = img_transforms(img).float()
    image_tensor = image_tensor.unsqueeze_(0)
    input_img = Variable(image_tensor.type(Tensor))
    # run inference on the model and get detections
    with torch.no_grad():
        detections = model(input_img)
        detections = non_max_suppression(detections, 80, conf_thresh,
                                         nms_thresh)
    return detections[0]
예제 #5
0
def get_train_test_loaders(dataset_name, path, batch_size, num_workers):
    """Create train and test DataLoaders for the named torchvision dataset.

    The training pipeline adds pad/crop/flip augmentation; the test pipeline
    only normalizes. The test loader runs at twice the training batch size.
    """
    assert dataset_name in datasets.__dict__, "Unknown dataset name {}".format(dataset_name)
    dataset_cls = datasets.__dict__[dataset_name]

    train_tf = Compose([
        Pad(2),
        RandomCrop(32),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25)),
    ])
    test_tf = Compose([
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25)),
    ])

    train_set = dataset_cls(root=path, train=True, transform=train_tf, download=True)
    test_set = dataset_cls(root=path, train=False, transform=test_tf, download=False)

    return (
        DataLoader(train_set, batch_size=batch_size, num_workers=num_workers, pin_memory=True),
        DataLoader(test_set, batch_size=batch_size * 2, num_workers=num_workers, pin_memory=True),
    )
예제 #6
0
 def setUpClass(cls):
     """Load padded (32x32) MNIST once for every test in this class.

     Also records the input shape, the class count, and the error rate of a
     uniform random classifier (1 - 1/num_classes) for later comparisons.
     """
     transform = Compose([Pad(2), ToTensor()])
     cls.train_mnist = MNIST(join(dirname(__file__), "tmp/mnist"), True, transform, download=True)
     cls.test_mnist = MNIST(join(dirname(__file__), "tmp/mnist"), False, transform, download=True)
     cls.input_size = cls.train_mnist[0][0].shape
     cls.output_size = 10
     cls.random_error = 1 - (1 / cls.output_size)
예제 #7
0
def get_test_transform(length=T):
    """Deterministic eval pipeline: pad horizontally, take ten fixed
    (1, length) crops, stack them into one tensor, then centre the values.

    Each step is wrapped in ConvertToTuple so the pipeline can carry
    (sample, target) pairs through torchvision transforms.
    """
    steps = [
        ToPILImage(),
        Pad((length // 2, 0)),
        TenCrop((1, length)),
        Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])),
        Centring(MAX_INT),
    ]
    wrapped = [ConvertToTuple(step) for step in steps]
    return transforms.Compose(wrapped)
예제 #8
0
def save_sample_sheet(cgn, u_fixed, sample_path, ep_str):
    """Render a qualitative sample sheet for a fixed noise vector.

    For each class id in a hand-picked list, generates with the CGN, lays out
    [mask | premask | foreground | background | composite | ground truth] in
    one row, and stacks all rows into a single PNG under sample_path.

    Args:
        cgn: the generative network; switched to eval() for generation and
            back to train() before returning.
        u_fixed: fixed noise input reused for every class.
        sample_path: output directory.
        ep_str: epoch tag used in the output filename.
    """
    cgn.eval()
    dev = u_fixed.to(cgn.get_device())
    # Hand-picked class ids to visualize.
    ys = [15, 251, 330, 382, 385, 483, 559, 751, 938, 947, 999]

    to_save = []
    with torch.no_grad():
        for y in ys:
            # generate
            y_vec = cgn.get_class_vec(y, sz=1)
            inp = (u_fixed.to(dev), y_vec.to(dev), cgn.truncation)
            x_gt, mask, premask, foreground, background, bg_mask = cgn(inp)
            # Composite: masked foreground blended over the background.
            x_gen = mask * foreground + (1 - mask) * background

            # build class grid
            to_plot = [premask, foreground, background, x_gen, x_gt]
            grid = make_grid(torch.cat(to_plot).detach().cpu(),
                             nrow=len(to_plot),
                             padding=2,
                             normalize=True)

            # Prepend the raw (unnormalized) mask, padded to match the grid's
            # 2px border, as the leftmost tile of the row.
            mask = Pad(2)(mask[0].repeat(3, 1, 1)).detach().cpu()
            grid = torch.cat([mask, grid], 2)

            # save to disk
            to_save.append(grid)
            del to_plot, mask, premask, foreground, background, x_gen, x_gt

    # save the image (note: the f-prefix on 'cls_sheet_' is redundant)
    path = join(sample_path, f'cls_sheet_' + ep_str + '.png')
    torchvision.utils.save_image(torch.cat(to_save, 1), path)
    cgn.train()
예제 #9
0
def get_train_transform(length=T):
    """Training pipeline: pad horizontally, take one random (1, length) crop,
    convert to tensor and centre the values.

    Each step is wrapped in ConvertToTuple so (sample, target) pairs can flow
    through torchvision transforms.
    """
    steps = [
        ToPILImage(),
        Pad((length // 2, 0)),
        RandomCrop((1, length)),
        ToTensor(),
        Centring(MAX_INT),
    ]
    return transforms.Compose([ConvertToTuple(step) for step in steps])
예제 #10
0
    def __call__(self, img):
        """
        Pad the shorter axis of the image towards a square.

        Args:
            img (PIL Image): Image to be padded.

        Returns:
            PIL Image: Image padded symmetrically on its shorter axis.

        Note:
            With an odd width/height difference the per-side pad rounds
            down, so the result can be one pixel short of square (kept
            identical to the original behaviour). The previous version also
            allocated `np.zeros` arrays that were never used; they have been
            removed.
        """
        width, height = img.size
        if width >= height:
            # Landscape (or square): pad top and bottom to reach the width.
            size_diff = width - height
            return Pad(padding=(0, size_diff // 2))(img)
        # Portrait: pad left and right to reach the height.
        size_diff = height - width
        return Pad(padding=(size_diff // 2, 0))(img)
예제 #11
0
def get_train_test_loaders(path,
                           batch_size,
                           num_workers,
                           distributed=False,
                           pin_memory=True):
    """Create CIFAR10 train/test DataLoaders, optionally with DDP samplers.

    CIFAR10 is downloaded into `path` when the directory is missing or empty.
    The train loader drops the last incomplete batch; the test loader runs at
    twice the training batch size.

    Args:
        path: dataset root directory.
        batch_size: training batch size.
        num_workers: DataLoader worker count for both loaders.
        distributed: when True, wrap both datasets in DistributedSampler.
        pin_memory: forwarded to both DataLoaders.

    Returns:
        (train_loader, test_loader) tuple.
    """
    train_transform = Compose([
        Pad(4),
        RandomCrop(32, fill=128),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    test_transform = Compose([
        ToTensor(),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Download only when the target directory is absent or empty.
    if not os.path.exists(path):
        os.makedirs(path)
        download = True
    else:
        download = True if len(os.listdir(path)) < 1 else False

    train_ds = datasets.CIFAR10(root=path,
                                train=True,
                                download=download,
                                transform=train_transform)
    test_ds = datasets.CIFAR10(root=path,
                               train=False,
                               download=False,
                               transform=test_transform)

    train_sampler = None
    test_sampler = None
    if distributed:
        train_sampler = DistributedSampler(train_ds)
        test_sampler = DistributedSampler(test_ds, shuffle=False)

    train_labelled_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size * 2,
        sampler=test_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_labelled_loader, test_loader
예제 #12
0
def get_train_transform(length=None):
    """Training transform: pad, random (1, length) crop, tensor, centring.

    NOTE(review): the default length=None crashes at `length // 2`
    (TypeError); callers apparently always pass an explicit length --
    confirm before relying on the default.
    """
    transforms = [
        ToPILImage(),
        Pad((length // 2, 0)),
        RandomCrop((1, length)),
        ToTensor(),
        Centring(MAX_INT)
    ]
    return torchvision.transforms.Compose(transforms)
예제 #13
0
def get_augmentation(augmentation, dataset, data_shape):
    """Return the list of PIL-level augmentation transforms for a dataset.

    Args:
        augmentation: None, 'horizontal_flip', 'neta' or 'eta'.
        dataset: dataset identifier (unused here; kept for API compatibility).
        data_shape: (channels, height, width) of the data.

    Returns:
        A list of transforms to run on PIL images (empty when
        `augmentation` is None).

    Raises:
        ValueError: if `augmentation` is not one of the known names.
            (Previously an unknown name fell through every branch and
            raised UnboundLocalError at the return.)
    """
    c, h, w = data_shape
    if augmentation is None:
        pil_transforms = []
    elif augmentation == 'horizontal_flip':
        pil_transforms = [RandomHorizontalFlip(p=0.5)]
    elif augmentation == 'neta':
        # Edge-pad by 4% of the side, jitter by up to 4%, crop back to h.
        assert h == w
        pil_transforms = [Pad(int(math.ceil(h * 0.04)), padding_mode='edge'),
                          RandomAffine(degrees=0, translate=(0.04, 0.04)),
                          CenterCrop(h)]
    elif augmentation == 'eta':
        # Same as 'neta' plus a horizontal flip.
        assert h == w
        pil_transforms = [RandomHorizontalFlip(),
                          Pad(int(math.ceil(h * 0.04)), padding_mode='edge'),
                          RandomAffine(degrees=0, translate=(0.04, 0.04)),
                          CenterCrop(h)]
    else:
        raise ValueError('Unknown augmentation: {}'.format(augmentation))
    return pil_transforms
예제 #14
0
def get_test_transform(length=None):
    """Eval transform: pad, ten fixed (1, length) crops stacked, centring.

    NOTE(review): the default length=None crashes at `length // 2`
    (TypeError); callers apparently always pass an explicit length --
    confirm before relying on the default.
    """
    transforms = [
        ToPILImage(),
        Pad((length // 2, 0)),
        TenCrop((1, length)),
        Lambda(
            lambda crops: torch.stack([ToTensor()(crop) for crop in crops])),
        Centring(MAX_INT)
    ]
    return torchvision.transforms.Compose(transforms)
    def __init__(self, input_size=(1920, 1080), side: int = 416):
        """Letterbox pipeline: pad the frame square, then resize to side x side.

        Args:
            input_size: (width, height) of the incoming frames.
            side: target square edge length for the network input.
        """
        # Factor to map outputs back to the original resolution.
        self.scale = np.max(input_size) / side

        # Pad the shorter dimension on both sides to make the frame square.
        if input_size[0] > input_size[1]:
            self.pad = (0, int(input_size[0] - input_size[1]) // 2)
        else:
            self.pad = (int(input_size[1] - input_size[0]) // 2, 0)
        super().__init__(
            [Pad(self.pad, fill=0),
             Resize((side, side)),
             ToTensor()])
예제 #16
0
def SliceStack(height: int = 40, width : int = 40):
    """Return a callable that cuts four (height x width) slices from an image
    at fixed horizontal offsets and stacks the normalized crops into a tensor.
    """
    # Left edges of the four slices within the source image.
    offsets = [0, 25, 55, 80]
    transes = Compose([
        Pad(12),
        CenterCrop(64),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    def stack(img):
        # F.crop(img, top, left, height, width): same row, varying column.
        crops = [F.crop(img, 0, offset, height, width) for offset in offsets]
        return torch.stack([transes(crop) for crop in crops])
    return stack
예제 #17
0
 def __init__(self, root, train=True, download=False, process=False):
     """QuickDraw wrapper with fixed training-style augmentation:
     4px pad, random 32px crop, horizontal flip, tensor conversion.
     """
     transforms = Compose([
         Pad(padding=4),
         RandomCrop(size=32),
         RandomHorizontalFlip(p=0.5),
         ToTensor()
     ])
     super(WrappedQuickDraw, self).__init__(root,
                                            train,
                                            download,
                                            process,
                                            transforms=transforms)
예제 #18
0
def estimate(tenOrig):
    """Run the lazily-constructed edge network on one image tensor.

    The input is edge-padded by 32px, pushed through the network at three
    shift settings, and the per-pixel maximum over the shifts is returned at
    the original size.

    Args:
        tenOrig: image tensor; the indexing below implies (C, H, W) layout
            with C == 3 -- TODO confirm.

    Returns:
        Tensor of tenOrig's shape holding the element-wise max response.
    """
    global netNetwork

    # Build the network once and cache it at module level (requires CUDA).
    if netNetwork is None:
        netNetwork = Network('bsds500').cuda().eval()
    # end

    intPadWidth = 32
    tenInput = Pad(intPadWidth, padding_mode='edge')(tenOrig).float()
    intWidth = tenInput.shape[2]
    intHeight = tenInput.shape[1]

    tenOutput = torch.zeros(tenOrig.shape)
    arrShift = [100, 150, 200]
    for intShift in arrShift:
        tenInf = netNetwork(tenInput.cuda().view(1, 3, intHeight, intWidth),
                            intShift)[0, :, :, :].cpu()
        # Crop back to the original H x W before taking the running max.
        tenOutput = torch.maximum(
            tenOutput,
            CenterCrop(tenOrig.shape[1:3])(tenInf).float())

    return tenOutput
예제 #19
0
def get_train_transforms(jitter, dataset):
    """Training-time augmentation; STL10 uses 88px crops, everything else 32px.

    When `jitter` is set, one of two ColorJitter variants (brightness-only or
    contrast-only) is randomly chosen per sample and applied before the shared
    tensor/pad/crop/flip pipeline.
    """
    crop_size = 88 if dataset == STL10 else 32
    base = [
        ToTensor(),
        Pad(2),
        RandomCrop(crop_size),
        RandomHorizontalFlip(p=0.5),
    ]
    if not jitter:
        return Compose(base)
    jitter_choice = RandomChoice([
        ColorJitter(brightness=(0.3, 1.0)),
        ColorJitter(contrast=(0.3, 1.0)),
    ])
    return Compose([jitter_choice] + base)
예제 #20
0
def CreateMnistDataloader(path, batch_size):
    """Build MNIST train/test DataLoaders with 2px padding (28x28 -> 32x32).

    Expects the dataset to already exist at `path` (download=False).
    The train loader shuffles; the test loader does not.
    """
    transform = Compose([Pad(padding=2), ToTensor()])

    trainset = torchvision.datasets.MNIST(root=path, train=True,
                                          download=False, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.MNIST(root=path, train=False,
                                         download=False, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    return trainloader, testloader
예제 #21
0
    def _build(self, ):
        """Assemble the submodules: padding + conv feature embedding,
        positional embedding, LSTM sequence model and a sampling MDN head.
        """
        # Half-kernel padding so the conv keeps the spatial size ('same').
        self.pad = Pad(self.kernel_size // 2, padding_mode=self.padding_mode)
        self.feature_embedding = nn.Sequential(
            nn.Conv2d(self.in_channels, self.embedding_size, self.kernel_size),
            nn.Dropout(self.dropout),
            nn.Flatten(),
            nn.BatchNorm1d(self.embedding_size),
            nn.ReLU(),
        )

        # Maps a 2-d (x, y) coordinate into the embedding space.
        self.positional_embedding = nn.Linear(2, self.embedding_size)

        self.sequence_module = nn.LSTM(
            self.embedding_size,
            self.hidden_size,
            num_layers=1,
            batch_first=True,
        )

        # MDN head samples (embedding + 2 coords) per step.
        self.mdn_module = MixtureDensityNetwork(
            self.hidden_size,
            sample_size=self.embedding_size + 2,
            n_components=self.n_components,
            forward_mode=MixtureDensityNetwork.FORWARD_SAMPLE)
예제 #22
0
def validate_dir(root: str,
                 checkpoint_path: str,
                 position: int,
                 ignore_case: bool = True,
                 batch_size: int = 100):
    """Validate a ResNet-18 classifier checkpoint on an image folder.

    Args:
        root: dataset root passed to the project's ImageFolder.
        checkpoint_path: torch checkpoint containing 'model' (state dict)
            and 'classes' (ordered class-name list).
        position: which slice (0..3) to crop out and evaluate.
        ignore_case: when True, upper/lower-case class pairs are merged by
            mapping every prediction and label to its lower-case class index.
        batch_size: evaluation batch size.
    """
    assert position >= 0 and position <= 3

    checkpoint = torch.load(checkpoint_path)
    classes = checkpoint['classes']
    model = resnet18(pretrained=False)
    # Replace the ImageNet head with one sized for the checkpoint's classes.
    model.fc = nn.Linear(model.fc.in_features, len(checkpoint['classes']))
    model.load_state_dict(checkpoint['model'])
    model = model.cuda()

    transes = Compose([
        SliceCrop(position),
        Pad(12),
        CenterCrop(64),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = ImageFolder(root,
                          position,
                          checkpoint['classes'],
                          transform=transes)
    dataloader = DataLoader(dataset,
                            shuffle=False,
                            num_workers=4,
                            batch_size=batch_size)

    def classes_map(ignore_case):
        # Index i -> index of the lower-cased version of class i.
        low_classes = torch.tensor([classes.index(c.lower())
                                    for c in classes]).cuda()

        def tmap(preds: torch.LongTensor, y: torch.LongTensor):
            if not ignore_case:
                return preds, y
            return low_classes[preds], low_classes[y]

        return tmap

    with torch.no_grad():
        acc = validate(model, dataloader, classes_map(ignore_case))
    print(f'validate acc: {acc:.3f}')
    def __init__(self,
                 root,
                 train=True,
                 transform=None,
                 target_transform=None,
                 download=False,
                 pad=32,
                 translate=0.4):
        """Dataset wrapper that pads 32px images and applies random translation.

        Args:
            root: dataset root (passed to the parent dataset).
            train: select the training split.
            transform: image transform, stored as self._transform and
                deliberately NOT forwarded to the parent -- presumably
                applied by this class around the pad/affine; TODO confirm.
            target_transform: label transform, forwarded to the parent.
            download: whether the parent should download the data.
            pad: pixels of padding added on every side.
            translate: max fraction of image size for the random translation.
        """
        super().__init__(root,
                         train=train,
                         target_transform=target_transform,
                         download=download)

        self._transform = transform

        # Padded image side length: original 32px plus `pad` on each side.
        self.p = pad
        self.sz = (2 * self.p) + 32

        self._pad = Pad(pad)
        self._affine = RandomAffine(0, translate=(translate, translate))
예제 #24
0
 def __init__(self, root, train=True, download=False, *args, **kwargs):
     """CIFAR-100 wrapper: augmented pipeline for training, deterministic
     normalize-only pipeline for evaluation.

     Fix: the eval branch previously applied RandomHorizontalFlip as well,
     which made test-time inputs nondeterministic; augmentation now runs
     only when train=True.
     """
     if train:
         transforms = Compose([
             Pad(padding=4),
             RandomCrop(size=32),
             RandomHorizontalFlip(p=0.5),
             ToTensor(),
             Normalize(mean=(0.485, 0.456, 0.406),
                       std=(0.229, 0.224, 0.225))
         ])
     else:
         transforms = Compose([
             ToTensor(),
             Normalize(mean=(0.485, 0.456, 0.406),
                       std=(0.229, 0.224, 0.225))
         ])
     super(WrappedCIFAR100, self).__init__(
         root=root, train=train, transform=transforms,
         download=download, *args, **kwargs)
예제 #25
0
파일: cli.py 프로젝트: chr5tphr/attbtor
def data(ctx, bsize, train, force_trainset, datapath, regex, dataset, download, shuffle, workers):
    """CLI callback: build the requested dataset's DataLoader and stash it on
    the context object (ctx.obj.loader) for downstream subcommands.

    `regex`, when given, restricts Imagenet-12 ImageFolder files to paths that
    fully match the pattern.
    """
    if regex is not None:
        rvalid = re.compile(regex)
        def is_valid_file(fpath):
            return rvalid.fullmatch(fpath) is not None
    else:
        is_valid_file = None

    if dataset == 'CIFAR10':
        # Crop/flip augmentation only in training mode.
        transf = Compose(([RandomCrop(32, padding=4), RandomHorizontalFlip()] if train else []) + [ToTensor(), Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
        dset = CIFAR10(root=datapath, train=train or force_trainset, transform=transf, download=download)
    elif dataset == 'MNIST':
        transf = Compose([Pad(2), ToTensor(), Normalize(0.4734, 0.2009)])
        dset = MNIST(root=datapath, train=train or force_trainset, transform=transf, download=download)
    elif dataset == 'Imagenet-12':
        transf = Compose(([RandomResizedCrop(224), RandomHorizontalFlip()] if train else [CenterCrop(224)]) + [ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
        dset = ImageFolder(root=datapath, transform=transf, is_valid_file=is_valid_file)
    else:
        raise RuntimeError('No such dataset!')
    loader  = DataLoader(dset, bsize, shuffle=shuffle, num_workers=workers)
    ctx.obj.loader = loader
예제 #26
0
    def __init__(self, train=True):
        """CIFAR-100 wrapper that computes per-channel mean/std over the whole
        split and builds an augmentation + normalization transform from them.
        """
        super(CustomCIFAR100, self).__init__()
        self.cifar_100 = CIFAR100(root='datasets', train=train, download=True)
        #  each item: index[0] = data, index[1] = label, format = PIL

        tensors = list()
        for i in range(len(self.cifar_100)):
            tensors.append(ToTensor()(
                self.cifar_100[i][0]).numpy())  # cifar100
            # ToTensor converts an image/ndarray (H, W, C) into a (C, H, W) tensor
        mean = np.mean(tensors,
                       axis=(0, 2,
                             3))  # per-channel mean; channel axis is 1 after ToTensor
        std = np.std(tensors,
                     axis=(0, 2,
                           3))  # per-channel std; channel axis is 1 after ToTensor
        print("mean: {}, std: {}".format(mean, std))

        transform = [RandomHorizontalFlip()]
        transform += [Pad(4), RandomCrop(32)]
        transform += [ToTensor(), Normalize(mean=mean, std=std)]
        self.transform = Compose(transform)
예제 #27
0
    def test_dataset_generation(self):
        """End-to-end check: train a ResNet-20 to a minimum on padded MNIST,
        then generate a NetworkInputModel dataset and verify its size.
        """
        transform = Compose([Pad(2), ToTensor()])
        train_mnist = MNIST(join(dirname(__file__), "tmp/mnist"),
                            True,
                            transform,
                            download=True)
        input_size = train_mnist[0][0].shape
        number_of_classes = 10
        resnet = ResNet(20, input_size, number_of_classes)

        # Find a minimiser for the network
        optim_wrapper = ModelWrapper(
            DataModel(CompareModel(resnet, NLLLoss()), {"train": train_mnist}))
        optim_wrapper.to("cuda")
        optim_config = OptimConfig(100, SGD, {"lr": 0.1}, None, None,
                                   EvalConfig(128))
        minimum = find_minimum(optim_wrapper, optim_config)
        # Load the found minimiser back into the wrapped model.
        optim_wrapper.set_coords_no_grad(minimum["coords"])

        nim = NetworkInputModel(resnet, input_size, 0)
        nim.cuda()
        resnet.cuda()
        dataset = nim.generate_dataset(train_mnist, number_of_classes)
        self.assertEqual(len(dataset), 100)
예제 #28
0
import ignite.distributed as idist
from torchvision import datasets
from torchvision.transforms import (
    Compose,
    Normalize,
    Pad,
    RandomCrop,
    RandomHorizontalFlip,
    ToTensor,
)

# CIFAR-style training augmentation: 4px pad + random 32px crop + flip,
# followed by ImageNet-statistics normalization.
train_transform = Compose(
    [
        Pad(4),
        RandomCrop(32, fill=128),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

# Deterministic evaluation pipeline: normalization only.
eval_transform = Compose(
    [
        ToTensor(),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)


def get_datasets(path):
    local_rank = idist.get_local_rank()
예제 #29
0
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize, Compose, Pad, Lambda

from horch.config import cfg
from horch.datasets import train_test_split, CombineDataset
from horch.nas.darts.model_search_gdas import Network
from horch.nas.darts.trainer import DARTSTrainer
from horch.train import manual_seed
from horch.train.metrics import TrainLoss, Loss
from horch.train.metrics.classification import Accuracy

manual_seed(0)

# MNIST 28x28 padded to 32x32, normalized, then expanded to 3 channels so
# 3-channel architectures can consume it.
train_transform = Compose([
    Pad(2),
    ToTensor(),
    Normalize((0.1307, ), (0.3081, )),
    Lambda(lambda x: x.expand(3, -1, -1))
])

# NOTE(review): machine-specific dataset path -- parameterize before reuse.
root = '/Users/hrvvi/Code/study/pytorch/datasets'
ds_all = MNIST(root=root, train=True, download=True, transform=train_transform)

# Keep only ~0.1% of MNIST, then split it 50/50 into search-train/val;
# CombineDataset pairs the two splits for bi-level (DARTS-style) training.
ds = train_test_split(ds_all, test_ratio=0.001, random=True)[1]
ds_train, ds_val = train_test_split(ds, test_ratio=0.5, random=True)
ds = CombineDataset(ds_train, ds_val)

train_loader = DataLoader(ds, batch_size=2, pin_memory=True, num_workers=2)
val_loader = DataLoader(ds_val, batch_size=2, pin_memory=True, num_workers=2)
# Script hyper-parameters; `args` is presumably an argparse namespace defined
# earlier in the file -- TODO confirm.
batch_size = args.batch_size
epoch = args.num_epoch
save_path = 'model_save/'

#normalize for ImageNet
normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])

# Jittered crop: a random crop size in [200, 224) is drawn once at startup
# (not per sample) and reused for the whole run.
crop = 200
rng = np.random.RandomState(args.random_seed)
precrop = crop + 24
crop = rng.randint(crop, precrop)
transformations = Compose([
    Scale((256, 256)),
    Pad((24, 24, 24, 24)),
    CenterCrop(precrop),
    RandomCrop(crop),
    Scale((256, 256)),
    ToTensor(), normalize
])


#define a batch-wise l2 loss
def criterion_l2(input_f, target_f):
    """Squared-error loss summed over dim 2, keeping the leading dims."""
    diff = input_f - target_f
    return (diff * diff).sum(dim=2)