Example no. 1
    def __init__(self,
                 input_dim,
                 output_dim,
                 bond_dim,
                 feature_dim=2,
                 nCh=3,
                 kernel=2,
                 virtual_dim=1,
                 bn=False,
                 dropout=0.0,
                 seg=False,
                 adaptive_mode=False,
                 periodic_bc=False,
                 parallel_eval=False,
                 label_site=None,
                 path=None,
                 init_std=1e-9,
                 use_bias=True,
                 fixed_bias=True,
                 cutoff=1e-10,
                 merge_threshold=2000):
        super().__init__()
        self.input_dim = input_dim
        nDim = len(self.input_dim)
        self.nDim = nDim
        self.bn = bn
        self.dropout = dropout
        self.vdim = 1  # virtual dimension fixed to 1; the virtual_dim argument is not used here
        ## Work out the depth of the tensor network
        self.ker = [16, 4]  # fixed per-level kernel sizes; these override the kernel argument

        # depth estimate from the input size (superseded below by the length of self.ker)
        nL = np.log(self.input_dim[0].item()) / np.log(self.ker[0])
        self.L = int(np.floor(nL)) - 1

        nCh = self.ker[0]**nDim * nCh
        self.L = len(self.ker)  # actual depth: one level per kernel entry
        print("Using depth of %d" % (self.L + 1))
        self.nCh = nCh
        self.lFeat = 1
        feature_dim = self.lFeat * nCh  # overrides the feature_dim argument

        ### First level MPS blocks
        self.module = nn.ModuleList([
            MPS(input_dim=((self.vdim - 1) * int(i > 0) + 1) *
                torch.prod(self.input_dim // (np.prod(self.ker[:i + 1]))),
                output_dim=self.vdim *
                torch.prod(self.input_dim // np.prod(self.ker[:i + 1])),
                bond_dim=bond_dim,
                lFeat=self.lFeat,
                feature_dim=self.lFeat * (self.ker[i])**nDim,
                parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc) for i in range(self.L)
        ])

        if self.bn:
            self.BN = nn.ModuleList([
                nn.BatchNorm1d(
                    self.vdim *
                    torch.prod(self.input_dim // np.prod(self.ker[:i + 1])).numpy(),
                    affine=True) for i in range(self.L)
            ])

        ### Final MPS block
        self.mpsFinal = MPS(input_dim=self.vdim *
                            torch.prod(self.input_dim // np.prod(self.ker)),
                            output_dim=output_dim,
                            lFeat=self.lFeat,
                            bond_dim=bond_dim,
                            feature_dim=self.lFeat,
                            adaptive_mode=adaptive_mode,
                            periodic_bc=periodic_bc,
                            parallel_eval=parallel_eval)
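
The per-level input sizes used by the MPS blocks above come from integer-dividing the spatial dimensions by the cumulative kernel product np.prod(ker[:i + 1]). A minimal sketch of that arithmetic, assuming a hypothetical 256x256 input (the sizes are illustrative, not from the source):

import numpy as np
import torch

input_dim = torch.tensor([256, 256])  # hypothetical spatial dimensions
ker = [16, 4]                         # per-level kernel sizes, as hard-coded above

for i in range(len(ker)):
    sites = torch.prod(input_dim // int(np.prod(ker[:i + 1])))
    print("level %d: %d MPS sites" % (i, sites.item()))
# level 0: 256 MPS sites (16 per side)
# level 1: 16 MPS sites  (4 per side)
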
Example no. 2
    def __init__(self,
                 input_dim,
                 output_dim,
                 bond_dim,
                 feature_dim=2,
                 nCh=3,
                 kernel=2,
                 virtual_dim=1,
                 adaptive_mode=False,
                 periodic_bc=False,
                 parallel_eval=False,
                 label_site=None,
                 path=None,
                 init_std=1e-9,
                 use_bias=True,
                 fixed_bias=True,
                 cutoff=1e-10,
                 merge_threshold=2000):
        super().__init__()
        self.input_dim = input_dim
        self.virtual_dim = bond_dim  # the virtual dimension is tied to bond_dim; the virtual_dim argument is unused

        ### Squeezing of spatial dimension in first step
        self.kScale = 4
        nCh = self.kScale**2 * nCh
        self.input_dim = self.input_dim // self.kScale

        self.nCh = nCh
        self.ker = kernel
        iDim = (self.input_dim // (self.ker))

        feature_dim = 2 * nCh

        ### First level MPS blocks
        self.module1 = nn.ModuleList([
            MPS(input_dim=(self.ker)**2,
                output_dim=self.virtual_dim,
                nCh=nCh,
                bond_dim=bond_dim,
                feature_dim=feature_dim,
                parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc) for i in range(torch.prod(iDim))
        ])

        self.BN1 = nn.BatchNorm1d(torch.prod(iDim).numpy(), affine=True)

        iDim = iDim // self.ker
        feature_dim = 2 * self.virtual_dim

        ### Second level MPS blocks
        self.module2 = nn.ModuleList([
            MPS(input_dim=self.ker**2,
                output_dim=self.virtual_dim,
                nCh=self.virtual_dim,
                bond_dim=bond_dim,
                feature_dim=feature_dim,
                parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc) for i in range(torch.prod(iDim))
        ])

        self.BN2 = nn.BatchNorm1d(torch.prod(iDim).numpy(), affine=True)

        iDim = iDim // self.ker

        ### Third level MPS blocks
        self.module3 = nn.ModuleList([
            MPS(input_dim=self.ker**2,
                output_dim=self.virtual_dim,
                nCh=self.virtual_dim,
                bond_dim=bond_dim,
                feature_dim=feature_dim,
                parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc) for i in range(torch.prod(iDim))
        ])

        self.BN3 = nn.BatchNorm1d(torch.prod(iDim).numpy(), affine=True)

        ### Final MPS block
        self.mpsFinal = MPS(input_dim=len(self.module3),
                            output_dim=output_dim,
                            nCh=1,
                            bond_dim=bond_dim,
                            feature_dim=feature_dim,
                            adaptive_mode=adaptive_mode,
                            periodic_bc=periodic_bc,
                            parallel_eval=parallel_eval)
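
The "squeezing of spatial dimension" step above (kScale = 4) shrinks H and W by kScale while multiplying the channel count by kScale**2, i.e. every 4x4 patch is folded into the channel axis. One way to realize that rearrangement, shown here as a sketch with F.pixel_unshuffle; whether the source uses this exact op is an assumption:

import torch
import torch.nn.functional as F

kScale = 4
x = torch.randn(8, 3, 128, 128)      # hypothetical batch: N x nCh x H x W
x_sq = F.pixel_unshuffle(x, kScale)  # space-to-depth: fold each 4x4 patch into channels
print(x_sq.shape)                    # torch.Size([8, 48, 32, 32]) == N x nCh*kScale**2 x H//kScale x W//kScale
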
Example no. 3
    def __init__(self,
                 input_dim,
                 output_dim,
                 bond_dim,
                 feature_dim=2,
                 nCh=3,
                 kernel1=[2, 2, 2],
                 kernel2=[2, 2, 2],
                 virtual_dim=1,
                 adaptive_mode=False,
                 periodic_bc=False,
                 parallel_eval=False,
                 label_site=None,
                 path=None,
                 init_std=1e-9,
                 use_bias=True,
                 fixed_bias=True,
                 cutoff=1e-10,
                 merge_threshold=2000):
        super().__init__()
        self.input_dim = input_dim
        self.virtual_dim = bond_dim

        ### Squeezing of spatial dimension in first step
        self.LoTe_kScale = 4  # spatial squeeze factor: each 4x4 patch is folded into the channel axis
        LoTe_nCh = self.LoTe_kScale**2 * nCh
        self.LoTe_input_dim = self.input_dim // self.LoTe_kScale  # integer division keeps the spatial sizes integral

        # print(nCh)
        self.LoTe_nCh = LoTe_nCh
        if isinstance(kernel1, int):
            LoTe_kernel = 3 * [kernel1]
        else:
            LoTe_kernel = kernel1
        self.LoTe_ker = LoTe_kernel
        LoTe_iDim = self.LoTe_input_dim // self.LoTe_ker[0]

        LoTe_feature_dim = 2 * LoTe_nCh

        ## Parameters for Conv
        self.Conv_input_dim = self.input_dim

        self.Conv_nCh = nCh
        if isinstance(kernel2, int):
            Conv_kernel = 3 * [kernel2]
        else:
            Conv_kernel = kernel2
        self.Conv_ker = Conv_kernel
        Conv_iDim = self.Conv_input_dim // self.Conv_ker[0]
        Conv_feature_dim = 2 * self.Conv_ker[0]**2

        ### First level MPS blocks
        self.LoTe_module1 = nn.ModuleList([
            MPS(input_dim=(self.LoTe_ker[0])**2,
                output_dim=self.virtual_dim,
                nCh=nCh,
                bond_dim=bond_dim,
                feature_dim=LoTe_feature_dim,
                parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc) for i in range(torch.prod(LoTe_iDim))
        ])

        self.LoTe_BN1 = nn.BatchNorm1d(torch.prod(LoTe_iDim).numpy(),
                                       affine=True)

        LoTe_iDim = LoTe_iDim // self.LoTe_ker[1]
        LoTe_feature_dim = 2 * self.virtual_dim

        ### Second level MPS blocks
        self.LoTe_module2 = nn.ModuleList([
            MPS(input_dim=self.LoTe_ker[1]**2 + self.Conv_ker[1]**2,
                output_dim=self.virtual_dim,
                nCh=self.virtual_dim,
                bond_dim=bond_dim,
                feature_dim=LoTe_feature_dim,
                parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc) for i in range(torch.prod(LoTe_iDim))
        ])

        self.LoTe_BN2 = nn.BatchNorm1d(torch.prod(LoTe_iDim).numpy(),
                                       affine=True)
        LoTe_iDim = LoTe_iDim // self.LoTe_ker[2]

        ### Third level MPS blocks
        self.LoTe_module3 = nn.ModuleList([
            MPS(input_dim=self.LoTe_ker[2]**2,
                output_dim=self.virtual_dim,
                nCh=self.virtual_dim,
                bond_dim=bond_dim,
                feature_dim=2 * LoTe_feature_dim,
                parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc) for i in range(torch.prod(LoTe_iDim))
        ])

        self.LoTe_BN3 = nn.BatchNorm1d(torch.prod(LoTe_iDim).numpy(),
                                       affine=True)

        ### Final MPS block
        # self.LoTe_mpsFinal = MPS(input_dim=len(self.LoTe_module3),
        # 					output_dim=output_dim, nCh=1,
        # 					bond_dim=bond_dim, feature_dim=LoTe_feature_dim,
        # 					adaptive_mode=adaptive_mode, periodic_bc=periodic_bc,
        # 					parallel_eval=parallel_eval)


        ## Parameters for Conv TN
        ### First level MPS blocks
        self.Conv_module1 = nn.ModuleList([
            MPS(input_dim=self.Conv_nCh,
                output_dim=self.virtual_dim,
                nCh=nCh,
                bond_dim=bond_dim,
                feature_dim=Conv_feature_dim,
                parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc) for i in range(torch.prod(Conv_iDim))
        ])

        self.Conv_BN1 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(),
                                       affine=True)

        Conv_iDim = Conv_iDim // self.Conv_ker[1]
        Conv_feature_dim = 2 * self.Conv_ker[1]**2

        ### Second level MPS blocks
        self.Conv_module2 = nn.ModuleList([
            MPS(input_dim=self.virtual_dim,
                output_dim=self.virtual_dim,
                nCh=self.virtual_dim,
                bond_dim=bond_dim,
                feature_dim=2 * self.LoTe_kScale + Conv_feature_dim,
                parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc) for i in range(torch.prod(Conv_iDim))
        ])

        self.Conv_BN2 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(),
                                       affine=True)
        self.BN2 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True)

        Conv_iDim = Conv_iDim // self.Conv_ker[2]
        Conv_feature_dim = 2 * self.Conv_ker[2]**2
        ### Third level MPS blocks
        self.Conv_module3 = nn.ModuleList([
            MPS(input_dim=2 * self.virtual_dim,
                output_dim=self.virtual_dim,
                nCh=self.virtual_dim,
                bond_dim=bond_dim,
                feature_dim=Conv_feature_dim,
                parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc) for i in range(torch.prod(Conv_iDim))
        ])

        self.Conv_BN3 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(),
                                       affine=True)
        self.BN3 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True)

        Conv_feature_dim = 2 * self.virtual_dim
        ### Final MPS block
        # self.Conv_mpsFinal = MPS(input_dim=len(self.Conv_module3),
        # 					output_dim=output_dim, nCh=1,
        # 					bond_dim=bond_dim, feature_dim=Conv_feature_dim,
        # 					adaptive_mode=adaptive_mode, periodic_bc=periodic_bc,
        # 					parallel_eval=parallel_eval)

        self.mpsFinal = MPS(input_dim=len(self.Conv_module3),
                            output_dim=output_dim,
                            nCh=1,
                            bond_dim=bond_dim,
                            feature_dim=2 * Conv_feature_dim,
                            adaptive_mode=adaptive_mode,
                            periodic_bc=periodic_bc,
                            parallel_eval=parallel_eval)
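
Throughout these constructors feature_dim is two per input channel (2 * nCh, 2 * virtual_dim, 2 * ker**2), which is consistent with a two-component local feature map applied per value before the MPS contraction. A sketch of one common choice, the linear map x -> [x, 1 - x]; whether the MPS class here uses exactly this map is an assumption:

import torch

def local_feature_map(x):
    # map intensities in [0, 1] to a 2-dim feature vector per value
    return torch.stack([x, 1.0 - x], dim=-1)

pix = torch.rand(8, 3, 16)                            # hypothetical batch x nCh x sites
feat = local_feature_map(pix)                         # (8, 3, 16, 2)
merged = feat.permute(0, 2, 1, 3).reshape(8, 16, -1)  # fold channels into the feature axis
print(merged.shape)                                   # torch.Size([8, 16, 6]) == (batch, sites, 2 * nCh)
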
Example no. 4
    def __init__(self,
                 input_dim,
                 output_dim,
                 bond_dim,
                 feature_dim=2,
                 nCh=3,
                 kernel=[2, 2, 2],
                 virtual_dim=1,
                 adaptive_mode=False,
                 periodic_bc=False,
                 parallel_eval=False,
                 label_site=None,
                 path=None,
                 init_std=1e-9,
                 use_bias=True,
                 fixed_bias=True,
                 cutoff=1e-10,
                 merge_threshold=2000):
        # NOTE: the bond_dim parameter does not make much sense here
        super().__init__()
        self.input_dim = input_dim
        self.virtual_dim = bond_dim

        ### Squeezing of spatial dimension in first step
        # self.kScale = 4  # what is this?
        # nCh = self.kScale ** 2 * nCh
        # self.input_dim = self.input_dim / self.kScale

        # print(nCh)
        self.nCh = nCh
        if isinstance(kernel, int):
            kernel = 3 * [kernel]
        self.ker = kernel

        num_layers = int(np.log2(input_dim[0]) - 2)
        self.num_layers = num_layers
        # ModuleLists so that the parameters of the MERA blocks are registered with the model
        self.disentangler_list = nn.ModuleList()
        self.isometry_list = nn.ModuleList()
        iDim = (self.input_dim - 2) // 2

        for ii in range(num_layers):
            # feature_dim = 2 * nCh
            feature_dim = 2 * nCh
            # print(feature_dim)
            # First level disentanglers
            # First level isometries

            ### First level MERA blocks
            self.disentangler_list.append(
                nn.ModuleList([
                    MPS(input_dim=4,
                        output_dim=4,
                        nCh=nCh,
                        bond_dim=4,
                        feature_dim=feature_dim,
                        parallel_eval=parallel_eval,
                        adaptive_mode=adaptive_mode,
                        periodic_bc=periodic_bc)
                    for i in range(torch.prod(iDim))
                ]))

            iDim = iDim + 1

            self.isometry_list.append(
                nn.ModuleList([
                    MPS(input_dim=4,
                        output_dim=1,
                        nCh=nCh,
                        bond_dim=bond_dim,
                        feature_dim=feature_dim,
                        parallel_eval=parallel_eval,
                        adaptive_mode=adaptive_mode,
                        periodic_bc=periodic_bc)
                    for i in range(torch.prod(iDim))
                ]))
            iDim = (iDim - 2) // 2

        ### Final MPS block
        self.mpsFinal = MPS(input_dim=49,
                            output_dim=output_dim,
                            nCh=1,
                            bond_dim=bond_dim,
                            feature_dim=feature_dim,
                            adaptive_mode=adaptive_mode,
                            periodic_bc=periodic_bc,
                            parallel_eval=parallel_eval)
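
The loop above shrinks iDim as (iDim - 2) // 2 after each disentangler/isometry pair, which fixes how many MERA blocks each layer holds. A small trace of that arithmetic for a hypothetical 64x64 input (the input_dim=49 of mpsFinal is hard-coded in the source and need not match this illustrative size):

import numpy as np
import torch

input_dim = torch.tensor([64, 64])                   # hypothetical spatial size
num_layers = int(np.log2(input_dim[0].item()) - 2)   # 4 layers for 64x64
iDim = (input_dim - 2) // 2

for layer in range(num_layers):
    n_disentanglers = int(torch.prod(iDim))
    iDim = iDim + 1
    n_isometries = int(torch.prod(iDim))
    iDim = (iDim - 2) // 2
    print("layer %d: %d disentanglers, %d isometries" %
          (layer, n_disentanglers, n_isometries))
# layer 0: 961 disentanglers, 1024 isometries
# layer 1: 225 disentanglers, 256 isometries
# layer 2: 49 disentanglers, 64 isometries
# layer 3: 9 disentanglers, 16 isometries
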
Example no. 5
                 nCh=nCh,
                 kernel=kernel,
                 bn=args.bn,
                 dropout=args.dropout,
                 bond_dim=args.bond_dim,
                 feature_dim=feature_dim,
                 adaptive_mode=adaptive_mode,
                 periodic_bc=periodic_bc,
                 virtual_dim=1)
elif args.tenetx:
    print("Tensornet-X baseline")
    #       pdb.set_trace()
    model = MPS(input_dim=torch.prod(dim),
                output_dim=output_dim,
                bond_dim=args.bond_dim,
                feature_dim=feature_dim,
                adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc,
                tenetx=args.tenetx)
elif args.densenet:
    print("Densenet Baseline!")
    model = DenseNet(depth=40,
                     growthRate=12,
                     reduction=0.5,
                     bottleneck=True,
                     nClasses=output_dim)
elif args.mlp:
    print("MLP Baseline!")
    model = BaselineMLP(inCh=torch.prod(dim), nhid=1, nClasses=output_dim)
else:
    print("Choose a model!")
Example no. 6
                          batch_size=batch_size, shuffle=True, pin_memory=True)
loader_valid = DataLoader(dataset=dataset_valid, drop_last=True, num_workers=1,
                          batch_size=batch_size, shuffle=False, pin_memory=True)
loader_test = DataLoader(dataset=dataset_test, drop_last=True, num_workers=1,
                         batch_size=batch_size, shuffle=False, pin_memory=True)
nValid = len(loader_valid)
nTrain = len(loader_train)
nTest = len(loader_test)

# Initialize the models
dim = dim // args.kernel
print("Using Strided Tenet with patches of size", dim)
output_dim = torch.prod(dim)
model = MPS(input_dim=torch.prod(dim),
            output_dim=output_dim,
            bond_dim=args.bond_dim,
            feature_dim=feature_dim * nCh,
            lFeat=feature_dim)
model = model.to(device)

# Initialize loss and metrics
accuracy = dice
loss_fun = dice_loss()

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, 
                             weight_decay=args.l2)

nParam = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters:%d"%(nParam))
print(f"Maximum MPS bond dimension = {args.bond_dim}")