from framework.module import NamedModule
from framework.nn.ops.pairwise import L2Norm2
from loss.losses import Samples_Loss
from modules.image2measure import ResImageToMeasure
from modules.linear_ot import LinearTransformOT, LinearTransformOT_bk
from modules.measure2image import MeasureToImage, ResMeasureToImage, Measure2imgTmp
import os
import cv2

from parameters.dataset import DatasetParameters
from parameters.deformation import DeformationParameters
from parameters.gan import GanParameters

# Build the command-line parser from the shared dataset / GAN / deformation
# parameter groups; ArgumentDefaultsHelpFormatter shows defaults in --help.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    parents=[DatasetParameters(),
             GanParameters(),
             DeformationParameters()])
args = parser.parse_args()
# Echo every parsed option so the run log records the configuration.
for k, v in vars(args).items():
    print(f"{k}: {v}")

# NOTE(review): GPU index 3 is hard-coded; confirm it exists on the target host.
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
# torch.cuda.set_device rejects non-CUDA devices (ValueError), so only call it
# when a GPU was actually selected; otherwise CPU-only hosts would crash here.
if device.type == "cuda":
    torch.cuda.set_device(device)
dataset_test = SegmentationDataset(
    "/home/nazar/PycharmProjects/mrt",
    transform=albumentations.Compose([
        albumentations.Resize(args.image_size, args.image_size),
        albumentations.CenterCrop(args.image_size, args.image_size),
        ToTensorV2()
Example #2
0
    parser.add_argument('--r1', type=float, default=10)
    parser.add_argument('--path_regularize', type=float, default=2)
    parser.add_argument('--path_batch_shrink', type=int, default=2)
    parser.add_argument('--d_reg_every', type=int, default=16)
    parser.add_argument('--g_reg_every', type=int, default=4)
    parser.add_argument('--mixing', type=float, default=0.9)
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--lr', type=float, default=0.002)
    parser.add_argument('--channel_multiplier', type=int, default=1)
    parser.add_argument('--wandb', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    parser = argparse.ArgumentParser(
        parents=[
            DatasetParameters(),
            GanParameters(),
            DeformationParameters(),
            MunitParameters()
        ],
        # formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    munit_args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    torch.cuda.set_device(device)

    cont_style_encoder: MunitEncoder = cont_style_munit_enc(
        munit_args,
        None,  # "/home/ibespalov/pomoika/munit_content_encoder15.pt",
Example #3
0
from dataset.toheatmap import ToGaussHeatMap
from dataset.probmeasure import UniformMeasure2D01
import pandas as pd
import networkx as nx
import ot
from barycenters.sampler import Uniform2DBarycenterSampler, Uniform2DAverageSampler
from parameters.path import Paths
from joblib import Parallel, delayed

# Number of samples used below; also selects which precomputed file is loaded.
N = 100
# Precomputed array loaded from the models directory (file name embeds N).
D = np.load(f"{Paths.default.models()}/hum36_graph{N}.npy")
padding = 32
# Uniform weight vector of length `padding` that sums to 1.
prob = np.ones(padding) / padding
# NOTE(review): meaning of NS is not visible in this chunk — confirm with its users.
NS = 13410

parser = DatasetParameters()
args = parser.parse_args()
# Echo every parsed option so the run log records the configuration.
for k, v in vars(args).items():
    print(f"{k}: {v}")

data = SimpleHuman36mDataset()
data.initialize(args.data_path)


def LS(k):
    """Look up sample ``k`` in the module-level ``data`` and return its
    ``"paired_B"`` entry converted with ``.numpy()``."""
    sample = data[k]
    paired_b = sample["paired_B"]
    return paired_b.numpy()


# Collect LS(i) for indices 0..N-1 and pack the results into one NumPy array.
ls = np.asarray(list(map(LS, range(N))))
# ls2 = np.asarray([LS(k) for k in range(NN, 2 * NN)])