def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3,
             kernel=2, virtual_dim=1, bn=False, dropout=0.0, seg=False,
             adaptive_mode=False, periodic_bc=False, parallel_eval=False,
             label_site=None, path=None, init_std=1e-9, use_bias=True,
             fixed_bias=True, cutoff=1e-10, merge_threshold=2000):
    """Stack of per-level MPS contraction blocks plus a final MPS head.

    Level ``i`` operates on the grid obtained by downsampling ``input_dim``
    by ``prod(self.ker[:i + 1])``; the final MPS maps the last level's
    sites to ``output_dim``.

    NOTE(review): ``feature_dim``, ``kernel``, ``virtual_dim``, ``seg`` and
    several trailing parameters are accepted but never used in this body —
    confirm against callers before pruning the signature.
    """
    super().__init__()
    self.input_dim = input_dim          # assumes a 1-D torch tensor of spatial sizes — TODO confirm
    nDim = len(self.input_dim)
    self.nDim = nDim
    self.bn = bn
    self.dropout = dropout              # original assigned this twice; kept once
    self.vdim = 1                       # virtual (carried) output dim per level

    # Fixed per-level squeeze kernels; depth equals their count.
    # (The original also derived a depth via log(input_dim[0]) / log(ker[0])
    # and immediately overwrote it with len(self.ker) — that dead computation
    # is removed.)
    self.ker = [16, 4]
    nCh = self.ker[0]**nDim * nCh       # channels after folding a ker[0]^nDim patch
    self.L = len(self.ker)
    print("Using depth of %d" % (self.L + 1))
    self.nCh = nCh
    self.lFeat = 1
    feature_dim = self.lFeat * nCh

    ### One MPS per level; levels i > 0 carry an extra (vdim - 1) factor.
    self.module = nn.ModuleList([
        MPS(input_dim=((self.vdim - 1) * int(i > 0) + 1) *
            torch.prod(self.input_dim // (np.prod(self.ker[:i + 1]))),
            output_dim=self.vdim *
            torch.prod(self.input_dim // np.prod(self.ker[:i + 1])),
            bond_dim=bond_dim,
            lFeat=self.lFeat,
            feature_dim=self.lFeat * (self.ker[i])**nDim,
            parallel_eval=parallel_eval,
            adaptive_mode=adaptive_mode,
            periodic_bc=periodic_bc)
        for i in range(self.L)
    ])

    if self.bn:
        # One BatchNorm1d per level, sized to that level's flattened site count.
        self.BN = nn.ModuleList([
            nn.BatchNorm1d(
                self.vdim *
                torch.prod(self.input_dim // (np.prod(self.ker[:i + 1]))).numpy(),
                affine=True)
            for i in range(self.L)
        ])

    ### Final MPS block: last level's sites -> output_dim.
    self.mpsFinal = MPS(
        input_dim=self.vdim * torch.prod(self.input_dim // np.prod(self.ker)),
        output_dim=output_dim,
        lFeat=self.lFeat,
        bond_dim=bond_dim,
        feature_dim=self.lFeat,
        adaptive_mode=adaptive_mode,
        periodic_bc=periodic_bc,
        parallel_eval=parallel_eval)
def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3,
             kernel=2, virtual_dim=1, adaptive_mode=False, periodic_bc=False,
             parallel_eval=False, label_site=None, path=None, init_std=1e-9,
             use_bias=True, fixed_bias=True, cutoff=1e-10,
             merge_threshold=2000):
    """Three-level hierarchy of small MPS blocks with a final MPS head.

    The input grid is first squeezed by ``kScale`` (pixels folded into
    channels), then each level contracts ``kernel x kernel`` patches and
    shrinks the grid by ``kernel``; the final MPS maps the level-3 outputs
    to ``output_dim``.

    NOTE(review): the ``virtual_dim`` argument is ignored here —
    ``self.virtual_dim`` is set to ``bond_dim`` instead; confirm intent.
    """
    super().__init__()
    self.input_dim = input_dim
    self.virtual_dim = bond_dim

    # Squeeze spatial dims by kScale, folding the squeezed pixels into channels.
    self.kScale = 4
    nCh = self.kScale ** 2 * nCh
    self.input_dim = self.input_dim // self.kScale
    self.nCh = nCh
    self.ker = kernel

    grid = self.input_dim // self.ker
    feat = 2 * nCh

    def _mps_bank(count, sites, channels, fdim):
        # One independent MPS per spatial site of the current grid.
        return nn.ModuleList([
            MPS(input_dim=sites, output_dim=self.virtual_dim, nCh=channels,
                bond_dim=bond_dim, feature_dim=fdim,
                parallel_eval=parallel_eval, adaptive_mode=adaptive_mode,
                periodic_bc=periodic_bc)
            for _ in range(count)
        ])

    ### Level 1: one MPS per ker x ker patch of the squeezed input.
    self.module1 = _mps_bank(torch.prod(grid), self.ker ** 2, nCh, feat)
    self.BN1 = nn.BatchNorm1d(torch.prod(grid).numpy(), affine=True)

    grid = grid // self.ker
    feat = 2 * self.virtual_dim

    ### Level 2: operates on level-1 outputs.
    self.module2 = _mps_bank(torch.prod(grid), self.ker ** 2,
                             self.virtual_dim, feat)
    self.BN2 = nn.BatchNorm1d(torch.prod(grid).numpy(), affine=True)

    grid = grid // self.ker

    ### Level 3: same feature dimension as level 2.
    self.module3 = _mps_bank(torch.prod(grid), self.ker ** 2,
                             self.virtual_dim, feat)
    self.BN3 = nn.BatchNorm1d(torch.prod(grid).numpy(), affine=True)

    ### Final MPS: one site per level-3 block.
    self.mpsFinal = MPS(input_dim=len(self.module3), output_dim=output_dim,
                        nCh=1, bond_dim=bond_dim, feature_dim=feat,
                        adaptive_mode=adaptive_mode, periodic_bc=periodic_bc,
                        parallel_eval=parallel_eval)
def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3,
             kernel1=(2, 2, 2), kernel2=(2, 2, 2), virtual_dim=1,
             adaptive_mode=False, periodic_bc=False, parallel_eval=False,
             label_site=None, path=None, init_std=1e-9, use_bias=True,
             fixed_bias=True, cutoff=1e-10, merge_threshold=2000):
    """Two parallel MPS branches — "LoTe" (spatially squeezed) and "Conv" —
    each with three levels of per-site MPS blocks, fused by a final MPS.

    Fixes vs. the original: dimension arithmetic uses floor division ``//``
    (true division produced float tensors, which break ``range()`` and
    ``BatchNorm1d`` sizing, and was inconsistent with the sibling network
    that already uses ``//``); the mutable list default arguments are now
    tuples copied into lists.

    NOTE(review): the ``virtual_dim`` argument is ignored here —
    ``self.virtual_dim`` is set to ``bond_dim``; confirm intent.
    """
    super().__init__()
    self.input_dim = input_dim
    self.virtual_dim = bond_dim

    ## LoTe branch: squeeze spatial dims by kScale, folding pixels into channels.
    self.LoTe_kScale = 4
    LoTe_nCh = self.LoTe_kScale**2 * nCh
    self.LoTe_input_dim = self.input_dim // self.LoTe_kScale
    self.LoTe_nCh = LoTe_nCh
    # Accept a scalar kernel (broadcast to three levels) or a 3-sequence.
    self.LoTe_ker = 3 * [kernel1] if isinstance(kernel1, int) else list(kernel1)
    LoTe_iDim = self.LoTe_input_dim // self.LoTe_ker[0]
    LoTe_feature_dim = 2 * LoTe_nCh

    ## Conv branch parameters (no spatial squeeze).
    self.Conv_input_dim = self.input_dim
    self.Conv_nCh = nCh
    self.Conv_ker = 3 * [kernel2] if isinstance(kernel2, int) else list(kernel2)
    Conv_iDim = self.Conv_input_dim // self.Conv_ker[0]
    Conv_feature_dim = 2 * self.Conv_ker[0]**2

    ### LoTe level 1: one MPS per ker[0] x ker[0] patch.
    self.LoTe_module1 = nn.ModuleList([
        MPS(input_dim=(self.LoTe_ker[0])**2, output_dim=self.virtual_dim,
            nCh=nCh, bond_dim=bond_dim, feature_dim=LoTe_feature_dim,
            parallel_eval=parallel_eval, adaptive_mode=adaptive_mode,
            periodic_bc=periodic_bc)
        for i in range(torch.prod(LoTe_iDim))
    ])
    self.LoTe_BN1 = nn.BatchNorm1d(torch.prod(LoTe_iDim).numpy(), affine=True)

    LoTe_iDim = LoTe_iDim // self.LoTe_ker[1]
    LoTe_feature_dim = 2 * self.virtual_dim

    ### LoTe level 2: each site takes a LoTe patch plus a Conv patch.
    self.LoTe_module2 = nn.ModuleList([
        MPS(input_dim=self.LoTe_ker[1]**2 + self.Conv_ker[1]**2,
            output_dim=self.virtual_dim, nCh=self.virtual_dim,
            bond_dim=bond_dim, feature_dim=LoTe_feature_dim,
            parallel_eval=parallel_eval, adaptive_mode=adaptive_mode,
            periodic_bc=periodic_bc)
        for i in range(torch.prod(LoTe_iDim))
    ])
    self.LoTe_BN2 = nn.BatchNorm1d(torch.prod(LoTe_iDim).numpy(), affine=True)

    LoTe_iDim = LoTe_iDim // self.LoTe_ker[2]

    ### LoTe level 3 (doubled feature dimension).
    self.LoTe_module3 = nn.ModuleList([
        MPS(input_dim=self.LoTe_ker[2]**2, output_dim=self.virtual_dim,
            nCh=self.virtual_dim, bond_dim=bond_dim,
            feature_dim=2 * LoTe_feature_dim,
            parallel_eval=parallel_eval, adaptive_mode=adaptive_mode,
            periodic_bc=periodic_bc)
        for i in range(torch.prod(LoTe_iDim))
    ])
    self.LoTe_BN3 = nn.BatchNorm1d(torch.prod(LoTe_iDim).numpy(), affine=True)

    ## Conv branch MPS blocks.
    ### Conv level 1: one MPS over the channel axis per site.
    self.Conv_module1 = nn.ModuleList([
        MPS(input_dim=self.Conv_nCh, output_dim=self.virtual_dim, nCh=nCh,
            bond_dim=bond_dim, feature_dim=Conv_feature_dim,
            parallel_eval=parallel_eval, adaptive_mode=adaptive_mode,
            periodic_bc=periodic_bc)
        for i in range(torch.prod(Conv_iDim))
    ])
    self.Conv_BN1 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True)

    Conv_iDim = Conv_iDim // self.Conv_ker[1]
    Conv_feature_dim = 2 * self.Conv_ker[1]**2

    ### Conv level 2: feature dim mixes the LoTe scale into the Conv features.
    self.Conv_module2 = nn.ModuleList([
        MPS(input_dim=self.virtual_dim, output_dim=self.virtual_dim,
            nCh=self.virtual_dim, bond_dim=bond_dim,
            feature_dim=2 * self.LoTe_kScale + Conv_feature_dim,
            parallel_eval=parallel_eval, adaptive_mode=adaptive_mode,
            periodic_bc=periodic_bc)
        for i in range(torch.prod(Conv_iDim))
    ])
    self.Conv_BN2 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True)
    # NOTE(review): BN2/BN3 duplicate Conv_BN2/Conv_BN3; kept because the
    # forward pass (not visible here) may reference either name.
    self.BN2 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True)

    Conv_iDim = Conv_iDim // self.Conv_ker[2]
    Conv_feature_dim = 2 * self.Conv_ker[2]**2

    ### Conv level 3: each site takes both branches' outputs (2 * virtual_dim).
    self.Conv_module3 = nn.ModuleList([
        MPS(input_dim=2 * self.virtual_dim, output_dim=self.virtual_dim,
            nCh=self.virtual_dim, bond_dim=bond_dim,
            feature_dim=Conv_feature_dim,
            parallel_eval=parallel_eval, adaptive_mode=adaptive_mode,
            periodic_bc=periodic_bc)
        for i in range(torch.prod(Conv_iDim))
    ])
    self.Conv_BN3 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True)
    self.BN3 = nn.BatchNorm1d(torch.prod(Conv_iDim).numpy(), affine=True)

    Conv_feature_dim = 2 * self.virtual_dim

    ### Final MPS: one site per Conv level-3 block.
    self.mpsFinal = MPS(input_dim=len(self.Conv_module3),
                        output_dim=output_dim, nCh=1, bond_dim=bond_dim,
                        feature_dim=2 * Conv_feature_dim,
                        adaptive_mode=adaptive_mode, periodic_bc=periodic_bc,
                        parallel_eval=parallel_eval)
def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2, nCh=3,
             kernel=(2, 2, 2), virtual_dim=1, adaptive_mode=False,
             periodic_bc=False, parallel_eval=False, label_site=None,
             path=None, init_std=1e-9, use_bias=True, fixed_bias=True,
             cutoff=1e-10, merge_threshold=2000):
    """MERA-style network: alternating disentangler / isometry MPS layers.

    Each of ``num_layers`` layers holds a bank of 4-site -> 4-site
    "disentangler" MPS blocks followed by a bank of 4-site -> 1-site
    "isometry" MPS blocks; a final MPS maps the coarse-grained grid to
    ``output_dim``.

    Fixes vs. the original: ``np.int`` (removed in NumPy >= 1.24) replaced
    by ``int``; the layer banks are now ``nn.ModuleList`` instead of plain
    Python lists so their parameters are registered with the module (a plain
    list hid them from ``.parameters()`` and the optimizer); dimension
    arithmetic uses floor division so site counts stay integral.
    """
    super().__init__()
    self.input_dim = input_dim
    # NOTE(review): the virtual_dim argument is ignored; self.virtual_dim
    # is set to bond_dim — confirm intent.
    self.virtual_dim = bond_dim
    self.nCh = nCh
    # Accept a scalar kernel (broadcast to three levels) or a 3-sequence.
    self.ker = 3 * [kernel] if isinstance(kernel, int) else list(kernel)

    # Depth derived from the spatial size; np.int -> int (NumPy removal).
    num_layers = int(np.log2(input_dim[0]) - 2)
    self.num_layers = num_layers

    # ModuleList registers the per-layer banks as sub-modules.
    self.disentangler_list = nn.ModuleList()
    self.isometry_list = nn.ModuleList()

    # Sites shrink by (x - 2) // 2 per isometry step; // keeps it integral.
    iDim = (self.input_dim - 2) // 2
    for ii in range(num_layers):
        feature_dim = 2 * nCh
        ### Disentanglers: 4-site -> 4-site blocks, fixed bond_dim=4.
        self.disentangler_list.append(nn.ModuleList([
            MPS(input_dim=4, output_dim=4, nCh=nCh, bond_dim=4,
                feature_dim=feature_dim, parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode, periodic_bc=periodic_bc)
            for i in range(torch.prod(iDim))
        ]))
        iDim = iDim + 1
        ### Isometries: 4-site -> 1-site coarse-graining blocks.
        self.isometry_list.append(nn.ModuleList([
            MPS(input_dim=4, output_dim=1, nCh=nCh, bond_dim=bond_dim,
                feature_dim=feature_dim, parallel_eval=parallel_eval,
                adaptive_mode=adaptive_mode, periodic_bc=periodic_bc)
            for i in range(torch.prod(iDim))
        ]))
        iDim = (iDim - 2) // 2

    ### Final MPS head.
    # NOTE(review): input_dim=49 is hard-coded — presumably the 7x7 grid
    # remaining after num_layers coarse-grainings; verify for other input sizes.
    self.mpsFinal = MPS(input_dim=49, output_dim=output_dim, nCh=1,
                        bond_dim=bond_dim, feature_dim=feature_dim,
                        adaptive_mode=adaptive_mode, periodic_bc=periodic_bc,
                        parallel_eval=parallel_eval)
nCh=nCh, kernel=kernel, bn=args.bn, dropout=args.dropout, bond_dim=args.bond_dim, feature_dim=feature_dim, adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, virtual_dim=1) elif args.tenetx: print("Tensornet-X baseline") # pdb.set_trace() model = MPS(input_dim=torch.prod(dim), output_dim=output_dim, bond_dim=args.bond_dim, feature_dim=feature_dim, adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, tenetx=args.tenetx) elif args.densenet: print("Densenet Baseline!") model = DenseNet(depth=40, growthRate=12, reduction=0.5, bottleneck=True, nClasses=output_dim) elif args.mlp: print("MLP Baseline!") model = BaselineMLP(inCh=torch.prod(dim), nhid=1, nClasses=output_dim) else: print("Choose a model!")
batch_size=batch_size, shuffle=True,pin_memory=True) loader_valid = DataLoader(dataset = dataset_valid, drop_last=True,num_workers=1, batch_size=batch_size, shuffle=False,pin_memory=True) loader_test = DataLoader(dataset = dataset_test, drop_last=True,num_workers=1, batch_size=batch_size, shuffle=False,pin_memory=True) nValid = len(loader_valid) nTrain = len(loader_train) nTest = len(loader_test) # Initialize the models dim = dim//args.kernel print("Using Strided Tenet with patches of size",dim) output_dim = torch.prod(dim) model = MPS(input_dim=torch.prod(dim), output_dim=output_dim, bond_dim=args.bond_dim, feature_dim=feature_dim*nCh, lFeat=feature_dim) model = model.to(device) # Initialize loss and metrics accuracy = dice loss_fun = dice_loss() # Initialize optimizer optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2) nParam = sum(p.numel() for p in model.parameters() if p.requires_grad) print("Number of parameters:%d"%(nParam)) print(f"Maximum MPS bond dimension = {args.bond_dim}")