# Training-script setup: builds the dataset and data loader, instantiates the
# model, and begins constructing an Adam optimizer with one parameter group
# per atlas layer.
# NOTE(review): `torch`, `Dataset` and `Pipeline` are used below but are not
# imported in this part of the file — presumably imported earlier; confirm.
from utils import masked_l1
import wandb
from tqdm import tqdm

# Turns on autograd anomaly detection for everything that follows.
# NOTE(review): this is typically a debugging aid — confirm it is meant to
# stay enabled for regular training runs.
torch.autograd.set_detect_anomaly(True)

# Select the GPU when CUDA reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The loader yields one dataset item per batch, in shuffled order.
dataset = Dataset(num_parts=24)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

# `num_parts` repeats the value passed to `Dataset` above (24).
# NOTE(review): H/W look like a 512x512 resolution and `num_features` like a
# per-location feature width — confirm against the Pipeline definition.
model = Pipeline(H=512, W=512, num_features=16, num_parts=24)
model.to(device)
# model.load_state_dict(torch.load("tmp.pth"))
model.train()

# Base learning rate; every parameter group visible below reuses it.
learning_rate = 1e-3
optimizer = torch.optim.Adam([
    # Apply increasing amount of regularization to finer layers
    # NOTE(review): in the groups visible here `weight_decay` goes from 1e-2
    # (layer1) to 1e-3 (layer2), i.e. it decreases with the layer index —
    # confirm which layer is the finest so the comment above matches the values.
    {
        "params": model.atlas.layer1,
        'weight_decay': 1e-2,
        'lr': learning_rate
    },
    {
        "params": model.atlas.layer2,
        'weight_decay': 1e-3,
        'lr': learning_rate
    },
    {