Example #1
def makeTrainer(*, task='homo', device='cuda', lr=3e-3, bs=75, num_epochs=500, network=MolecLieResNet,
                net_config={'k':1536,'nbhd':100,'act':'swish','group':lieGroups.T(3),
                'bn':True,'aug':True,'mean':True,'num_layers':6}, recenter=False,
                subsample=False, trainer_config={'log_dir':None,'log_suffix':''}):#,'log_args':{'timeFrac':1/4,'minPeriod':0}}):
    # Create Training set and model
    device = torch.device(device)
    with FixedNumpySeed(0):
        datasets, num_species, charge_scale = QM9datasets()
        if subsample: datasets.update(split_dataset(datasets['train'],{'train':subsample}))
    ds_stats = datasets['train'].stats[task]
    if recenter:
        m = datasets['train'].data['charges']>0
        pos = datasets['train'].data['positions'][m]
        mean,std = pos.mean(dim=0),1#pos.std()
        for ds in datasets.values():
            ds.data['positions'] = (ds.data['positions']-mean[None,None,:])/std
    model = network(num_species,charge_scale,**net_config).to(device)
    # Create train and Val(Test) dataloaders and move elems to gpu
    dataloaders = {key:LoaderTo(DataLoader(dataset,batch_size=bs,num_workers=0,
                    shuffle=(key=='train'),pin_memory=False,collate_fn=collate_fn,drop_last=True),
                    device) for key,dataset in datasets.items()}
    # subsampled training dataloader for faster logging of training performance
    dataloaders['Train'] = islice(dataloaders['train'],len(dataloaders['test']))#islice(dataloaders['train'],len(dataloaders['train'])//10)
    
    # Initialize optimizer and learning rate schedule
    opt_constr = functools.partial(Adam, lr=lr)
    cos = cosLr(num_epochs)
    lr_sched = lambda e: min(e / (.01 * num_epochs), 1) * cos(e)
    return MoleculeTrainer(model,dataloaders,opt_constr,lr_sched,
                            task=task,ds_stats=ds_stats,**trainer_config)
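A minimal usage sketch for the trainer above (hypothetical: it assumes the returned MoleculeTrainer exposes a train(num_epochs) method, as olive-oil-ml-style Trainer classes typically do):
# Hypothetical usage only; .train() is an assumption, not shown in this snippet.
trainer = makeTrainer(task='homo', lr=3e-3, bs=75, num_epochs=500)
trainer.train(500)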
Example #2
def makeTrainer(*,
                network=CHNN,
                net_cfg={},
                lr=3e-3,
                n_train=800,
                regen=False,
                dataset=RigidBodyDataset,
                body=ChainPendulum(3),
                C=5,
                dtype=torch.float32,
                device=torch.device("cuda"),
                bs=200,
                num_epochs=100,
                trainer_config={},
                opt_cfg={'weight_decay': 1e-5}):
    # Create Training set and model
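    # CH/CL style constrained networks are presumed to operate on Cartesian coordinates,
    # so angular coordinates are only requested for the other network classes.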
    angular = not issubclass(network, (CH, CL))
    splits = {"train": n_train, "test": 200}
    with FixedNumpySeed(0):
        dataset = dataset(n_systems=n_train + 200,
                          regen=regen,
                          chunk_len=C,
                          body=body,
                          angular_coords=angular)
        datasets = split_dataset(dataset, splits)

    dof_ndim = dataset.body.D if angular else dataset.body.d
    model = network(dataset.body.body_graph,
                    dof_ndim=dof_ndim,
                    angular_dims=dataset.body.angular_dims,
                    **net_cfg)
    model = model.to(device=device, dtype=dtype)
    # Create train and Dev(Test) dataloaders and move elems to gpu
    dataloaders = {
        k: LoaderTo(DataLoader(v,
                               batch_size=min(bs, splits[k]),
                               num_workers=0,
                               shuffle=(k == "train")),
                    device=device,
                    dtype=dtype)
        for k, v in datasets.items()
    }
    dataloaders["Train"] = dataloaders["train"]
    # Initialize optimizer and learning rate schedule
    opt_constr = lambda params: AdamW(params, lr=lr, **opt_cfg)
    lr_sched = cosLr(num_epochs)
    return IntegratedDynamicsTrainer(model,
                                     dataloaders,
                                     opt_constr,
                                     lr_sched,
                                     log_args={
                                         "timeFrac": 1 / 4,
                                         "minPeriod": 0.0
                                     },
                                     **trainer_config)
Example #3
def makeTrainer(*,
                network=EMLP,
                num_epochs=500,
                seed=2020,
                aug=False,
                bs=50,
                lr=1e-3,
                device='cuda',
                net_config={
                    'num_layers': 3,
                    'ch': rep,
                    'group': Cube()
                },
                log_level='info',
                trainer_config={
                    'log_dir': None,
                    'log_args': {
                        'minPeriod': .02,
                        'timeFrac': 50
                    }
                },
                save=False):
    levels = {
        'critical': logging.CRITICAL,
        'error': logging.ERROR,
        'warn': logging.WARNING,
        'warning': logging.WARNING,
        'info': logging.INFO,
        'debug': logging.DEBUG
    }
    logging.getLogger().setLevel(levels[log_level])
    # Prep the datasets splits, model, and dataloaders
    with FixedNumpySeed(seed), FixedPytorchSeed(seed):
        datasets = {
            'train': InvertedCube(train=True),
            'test': InvertedCube(train=False)
        }
    model = Standardize(
        network(datasets['train'].rep_in, datasets['train'].rep_out,
                **net_config), datasets['train'].stats)
    dataloaders = {
        k: LoaderTo(
            DataLoader(v,
                       batch_size=min(bs, len(v)),
                       shuffle=(k == 'train'),
                       num_workers=0,
                       pin_memory=False))
        for k, v in datasets.items()
    }
    dataloaders['Train'] = dataloaders['train']
    opt_constr = objax.optimizer.Adam
    lr_sched = lambda e: lr * cosLr(num_epochs)(e)
    return ClassifierPlus(model, dataloaders, opt_constr, lr_sched,
                          **trainer_config)
Example #4
def makeTrainer():
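    # NOTE: net_config, loader_config, opt_config, sched_config, trainer_config,
    # and all_hypers are assumed to be defined at module scope in the original script.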
    device = torch.device('cuda')
    CNN = smallCNN(**net_config).to(device)
    fullCNN = nn.Sequential(C10augLayers(),CNN)
    trainset, testset = CIFAR10(False, '~/datasets/cifar10/')

    dataloaders = {}
    dataloaders['train'], dataloaders['dev'] = getLabLoader(trainset,**loader_config)
    dataloaders = {k: loader_to(device)(v) for k,v in dataloaders.items()}

    opt_constr = lambda params: optim.SGD(params, **opt_config)
    lr_sched = cosLr(**sched_config)
    return Classifier(fullCNN,dataloaders,opt_constr,lr_sched,**trainer_config,tracked_hypers=all_hypers)
Example #5
def makeTrainer(config):
    cfg = {
        'dataset': CIFAR10,
        'network': iCNN,
        'net_config': {},
        'loader_config': {
            'amnt_dev': 5000,
            'lab_BS': 32,
            'pin_memory': True,
            'num_workers': 3
        },
        'opt_config': {
            'lr': .0003,
        },  # 'momentum':.9, 'weight_decay':1e-4,'nesterov':True},
        'num_epochs': 100,
        'trainer_config': {},
        'parallel': False,
    }
    recursively_update(cfg, config)
    train_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.RandomHorizontalFlip()])
    trainset = cfg['dataset'](
        '~/datasets/{}/'.format(cfg['dataset']),
        flow=True,
    )
    device = torch.device('cuda')
    fullCNN = cfg['network'](num_classes=trainset.num_classes,
                             **cfg['net_config']).to(device)
    if cfg['parallel']: fullCNN = multigpu_parallelize(fullCNN, cfg)
    dataloaders = {}
    dataloaders['train'], dataloaders['dev'] = getLabLoader(
        trainset, **cfg['loader_config'])
    dataloaders['Train'] = islice(dataloaders['train'],
                                  10000 // cfg['loader_config']['lab_BS'])
    if len(dataloaders['dev']) == 0:
        testset = cfg['dataset']('~/datasets/{}/'.format(cfg['dataset']),
                                 train=False,
                                 flow=True)
        dataloaders['test'] = DataLoader(
            testset,
            batch_size=cfg['loader_config']['lab_BS'],
            shuffle=False)
    dataloaders = {k: LoaderTo(v, device) for k, v in dataloaders.items()}
    opt_constr = lambda params: torch.optim.Adam(params, **cfg['opt_config'])
    lr_sched = cosLr(cfg['num_epochs'])
    return Flow(fullCNN, dataloaders, opt_constr, lr_sched,
                **cfg['trainer_config'])
Example #6
def makeTrainer(*,
                dataset=YAHOO,
                network=SmallNN,
                num_epochs=15,
                bs=5000,
                lr=1e-3,
                optim=AdamW,
                device='cuda',
                trainer=Classifier,
                split={
                    'train': 20,
                    'val': 5000
                },
                net_config={},
                opt_config={'weight_decay': 1e-5},
                trainer_config={
                    'log_dir': os.path.expanduser('~/tb-experiments/UCI/'),
                    'log_args': {
                        'minPeriod': .1,
                        'timeFrac': 3 / 10
                    }
                },
                save=False):

    # Prep the datasets splits, model, and dataloaders
    with FixedNumpySeed(0):
        datasets = split_dataset(dataset(), splits=split)
        datasets['_unlab'] = dmap(lambda mb: mb[0], dataset())
        datasets['test'] = dataset(train=False)
        #print(datasets['test'][0])
    device = torch.device(device)
    model = network(num_classes=datasets['train'].num_classes,
                    dim_in=datasets['train'].dim,
                    **net_config).to(device)

    dataloaders = {
        k: LoaderTo(
            DataLoader(v,
                       batch_size=min(bs, len(datasets[k])),
                       shuffle=(k == 'train'),
                       num_workers=0,
                       pin_memory=False), device)
        for k, v in datasets.items()
    }
    dataloaders['Train'] = dataloaders['train']
    opt_constr = partial(optim, lr=lr, **opt_config)
    lr_sched = cosLr(num_epochs)  #lambda e:1#
    return trainer(model, dataloaders, opt_constr, lr_sched, **trainer_config)
Example #7
def makeTrainer(*,
                network=ResNet,
                num_epochs=5,
                seed=2020,
                aug=False,
                bs=30,
                lr=1e-3,
                device='cuda',
                split={
                    'train': -1,
                    'val': 10000
                },
                net_config={
                    'k': 512,
                    'num_layers': 4
                },
                log_level='info',
                trainer_config={
                    'log_dir': None,
                    'log_args': {
                        'minPeriod': .02,
                        'timeFrac': .2
                    }
                },
                save=False):
    # Prep the datasets splits, model, and dataloaders
    datasets = {split: TopTagging(split=split) for split in ['train', 'val']}
    model = network(4, 2, **net_config)
    dataloaders = {
        k: LoaderTo(
            DataLoader(v,
                       batch_size=bs,
                       shuffle=(k == 'train'),
                       num_workers=0,
                       pin_memory=False,
                       collate_fn=collate_fn,
                       drop_last=True))
        for k, v in datasets.items()
    }
    dataloaders['Train'] = islice(dataloaders['train'], 0, None,
                                  10)  # subsample the training set 10x for logging
    #equivariance_test(model,dataloaders['train'],net_config['group'])
    opt_constr = objax.optimizer.Adam
    lr_sched = lambda e: lr * cosLr(num_epochs)(e)
    return ClassifierPlus(model, dataloaders, opt_constr, lr_sched,
                          **trainer_config)
Example #8
def makeTrainer(*,
                dataset=MnistRotDataset,
                network=ImgLieResnet,
                num_epochs=100,
                bs=50,
                lr=3e-3,
                aug=True,
                optim=Adam,
                device='cuda',
                trainer=Classifier,
                split={'train': 12000},
                small_test=False,
                net_config={},
                opt_config={},
                trainer_config={'log_dir': None}):

    # Prep the datasets splits, model, and dataloaders
    datasets = split_dataset(dataset(f'~/datasets/{dataset}/'), splits=split)
    datasets['test'] = dataset(f'~/datasets/{dataset}/', train=False)
    device = torch.device(device)
    model = network(num_targets=datasets['train'].num_targets,
                    **net_config).to(device)
    if aug:
        model = torch.nn.Sequential(datasets['train'].default_aug_layers(),
                                    model)
    model, bs = try_multigpu_parallelize(model, bs)

    dataloaders = {
        k: LoaderTo(
            DataLoader(v,
                       batch_size=bs,
                       shuffle=(k == 'train'),
                       num_workers=0,
                       pin_memory=False), device)
        for k, v in datasets.items()
    }
    dataloaders['Train'] = islice(dataloaders['train'],
                                  1 + len(dataloaders['train']) // 10)
    if small_test:
        dataloaders['test'] = islice(dataloaders['test'],
                                     1 + len(dataloaders['train']) // 10)
    # Initialize optimizer constructor and cosine learning rate schedule
    opt_constr = partial(optim, lr=lr, **opt_config)
    lr_sched = cosLr(num_epochs)
    return trainer(model, dataloaders, opt_constr, lr_sched, **trainer_config)
Example #9
def makeTrainer(*,network,net_cfg,lr=1e-2,n_train=5000,regen=False,
                dtype=torch.float32,device=torch.device('cuda'),bs=200,num_epochs=2,
                trainer_config={'log_dir':'data_scaling_study_final'}):
    # Create Training set and model
    splits = {'train':n_train,'val':min(n_train,2000),'test':2000}
    dataset = SpringDynamics(n_systems=100000, regen=regen)
    with FixedNumpySeed(0):
        datasets = split_dataset(dataset,splits)
    model = network(**net_cfg).to(device=device,dtype=dtype)
    # Create train and Dev(Test) dataloaders and move elems to gpu
    dataloaders = {k:LoaderTo(DataLoader(v,batch_size=min(bs,n_train),num_workers=0,shuffle=(k=='train')),
                                device=device,dtype=dtype) for k,v in datasets.items()}
    dataloaders['Train'] = islice(dataloaders['train'],len(dataloaders['val']))
    # Initialize optimizer and learning rate schedule
    opt_constr = lambda params: Adam(params, lr=lr)
    lr_sched = cosLr(num_epochs)
    return IntegratedDynamicsTrainer2(model,dataloaders,opt_constr,lr_sched,
                                    log_args={'timeFrac':1/4,'minPeriod':0.0},**trainer_config)
Example #10
def makeTrainer(config):
    cfg = {
        'dataset': CIFAR10,
        'network': layer13s,
        'net_config': {},
        'loader_config': {
            'amnt_dev': 5000,
            'lab_BS': 20,
            'pin_memory': True,
            'num_workers': 2
        },
        'opt_config': {
            'lr': .3e-4
        },  #, 'momentum':.9, 'weight_decay':1e-4,'nesterov':True},
        'num_epochs': 100,
        'trainer_config': {},
    }
    recursively_update(cfg, config)
    trainset = cfg['dataset']('~/datasets/{}/'.format(cfg['dataset']),
                              flow=True)
    device = torch.device('cuda')
    fullCNN = torch.nn.Sequential(
        trainset.default_aug_layers(),
        cfg['network'](num_classes=trainset.num_classes,
                       **cfg['net_config']).to(device))
    dataloaders = {}
    dataloaders['train'], dataloaders['dev'] = getLabLoader(
        trainset, **cfg['loader_config'])
    dataloaders['Train'] = islice(dataloaders['train'],
                                  10000 // cfg['loader_config']['lab_BS'])
    if len(dataloaders['dev']) == 0:
        testset = cfg['dataset']('~/datasets/{}/'.format(cfg['dataset']),
                                 train=False)
        dataloaders['test'] = DataLoader(
            testset,
            batch_size=cfg['loader_config']['lab_BS'],
            shuffle=False)
    dataloaders = {k: LoaderTo(v, device) for k, v in dataloaders.items()}
    opt_constr = lambda params: torch.optim.Adam(params, **cfg['opt_config'])
    # alternatively: torch.optim.SGD(params, **cfg['opt_config'])
    lr_sched = cosLr(cfg['num_epochs'])
    return iClassifier(fullCNN, dataloaders, opt_constr, lr_sched,
                       **cfg['trainer_config'])
Example #11
def make_trainer(
        chunk_len: int,
        angular: Union[Tuple, bool],
        body,
        bs: int,
        dataset,
        dt: float,
        lr: float,
        n_train: int,
        n_val: int,
        n_test: int,
        net_cfg: dict,
        network,
        num_epochs: int,
        regen: bool,
        seed: int = 0,
        device=torch.device("cuda"),
        dtype=torch.float32,
        trainer_config={},
):
    # Create Training set and model
    splits = {"train": n_train, "val": n_val, "test": n_test}
    dataset = dataset(
        n_systems=n_train + n_val + n_test,
        regen=regen,
        chunk_len=chunk_len,
        body=body,
        dt=dt,
        integration_time=10,
        angular_coords=angular,
    )
    # dataset=CartpoleDataset(batch_size=500,regen=regen)
    with FixedNumpySeed(seed):
        datasets = split_dataset(dataset, splits)
    model = network(G=dataset.body.body_graph, **net_cfg).to(device=device,
                                                             dtype=dtype)

    # Create train and Dev(Test) dataloaders and move elems to gpu
    dataloaders = {
        k: LoaderTo(
            DataLoader(v,
                       batch_size=min(bs, splits[k]),
                       num_workers=0,
                       shuffle=(k == "train")),
            device=device,
            dtype=dtype,
        )
        for k, v in datasets.items()
    }
    dataloaders["Train"] = dataloaders["train"]
    # Initialize optimizer and learning rate schedule
    opt_constr = lambda params: Adam(params, lr=lr)
    lr_sched = cosLr(num_epochs)
    return IntegratedDynamicsTrainer(model,
                                     dataloaders,
                                     opt_constr,
                                     lr_sched,
                                     log_args={
                                         "timeFrac": 1 / 4,
                                         "minPeriod": 0.0
                                     },
                                     **trainer_config)
Example #12
def main():
    # Parse flags
    config = forge.config()

    # Load data
    dataloaders, num_species, charge_scale, ds_stats, data_name = fet.load(
        config.data_config, config=config)

    config.num_species = num_species
    config.charge_scale = charge_scale
    config.ds_stats = ds_stats

    # Load model
    model, model_name = fet.load(config.model_config, config)
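    # NOTE: `device` is assumed to be defined at module scope (e.g. torch.device("cuda")).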
    model.to(device)

    config.charge_scale = float(config.charge_scale.numpy())
    config.ds_stats = [float(stat.numpy()) for stat in config.ds_stats]

    # Prepare environment
    run_name = (config.run_name + "_bs" + str(config.batch_size) + "_lr" +
                str(config.learning_rate))

    if config.batch_fit != 0:
        run_name += "_bf" + str(config.batch_fit)

    if config.lr_schedule != "none":
        run_name += "_" + config.lr_schedule

    # Print flags
    fet.print_flags()

    # Setup optimizer
    model_params = model.predictor.parameters()

    opt_learning_rate = config.learning_rate
    model_opt = torch.optim.Adam(
        model_params,
        lr=opt_learning_rate,
        betas=(config.beta1, config.beta2),
        eps=1e-8,
    )
    # model_opt = torch.optim.SGD(model_params, lr=opt_learning_rate)

    # Cosine annealing learning rate
    if config.lr_schedule == "cosine":
        cos = cosLr(config.train_epochs)
        lr_sched = lambda e: max(cos(e), config.lr_floor * config.learning_rate)
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "cosine_warmup":
        cos = cosLr(config.train_epochs)
        lr_sched = lambda e: max(
            min(e / (config.warmup_length * config.train_epochs), 1) * cos(e),
            config.lr_floor * config.learning_rate,
        )
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "quadratic_warmup":
        # linear warmup, then decay to 1/100 of the initial lr by the final epoch
        lr_sched = lambda e: min(e / (0.01 * config.train_epochs), 1) * (
            1.0 / sqrt(1.0 + 10000.0 * (e / config.train_epochs)))
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "none":
        lr_sched = lambda e: 1.0
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    else:
        raise ValueError(
            f"{config.lr_schedule} is not a recognised learning rate schedule")

    num_params = param_count(model)
    if config.parameter_count:
        for (name, parameter) in model.predictor.named_parameters():
            print(name, parameter.dtype)

        print(model)
        print("============================================================")
        print(f"{model_name} parameters: {num_params:.5e}")
        print("============================================================")
        # from torchsummary import summary

        # data = next(iter(dataloaders["train"]))

        # data = {k: v.to(device) for k, v in data.items()}
        # print(
        #     summary(
        #         model.predictor,
        #         data,
        #         batch_size=config.batch_size,
        #     )
        # )

        parameters = sum(parameter.numel()
                         for parameter in model.predictor.parameters())
        parameters_grad = sum(
            parameter.numel() if parameter.requires_grad else 0
            for parameter in model.predictor.parameters())
        print(f"Parameters: {parameters:,}")
        print(f"Parameters grad: {parameters_grad:,}")

        memory_allocations = []

        for batch_idx, data in enumerate(dataloaders["train"]):
            print(batch_idx)
            data = {k: v.to(device) for k, v in data.items()}

            model_opt.zero_grad()
            outputs = model(data, compute_loss=True)
            # torch.cuda.empty_cache()
            # memory_allocations.append(torch.cuda.memory_reserved() / 1024 / 1024 / 1024)
            # outputs.loss.backward()

        if memory_allocations:  # only populated if the append above is uncommented
            print(
                f"max memory reserved in one pass: {max(memory_allocations):0.4}GB"
            )
        sys.exit(0)

    else:
        print(f"{model_name} parameters: {num_params:.5e}")

    # set up results folders
    results_folder_name = osp.join(
        data_name,
        model_name,
        run_name,
    )

    logdir = osp.join(config.results_dir,
                      results_folder_name.replace(".", "_"))
    logdir, resume_checkpoint = fet.init_checkpoint(logdir, config.data_config,
                                                    config.model_config,
                                                    config.resume)

    checkpoint_name = osp.join(logdir, "model.ckpt")

    # Try to restore model and optimizer from checkpoint
    if resume_checkpoint is not None:
        start_epoch, best_valid_mae = load_checkpoint(resume_checkpoint, model,
                                                      model_opt, lr_schedule)
    else:
        start_epoch = 1
        best_valid_mae = 1e12

    train_iter = (start_epoch - 1) * (len(dataloaders["train"].dataset) //
                                      config.batch_size) + 1

    print("Starting training at epoch = {}, iter = {}".format(
        start_epoch, train_iter))

    # Setup tensorboard writing
    summary_writer = SummaryWriter(logdir)

    report_all = defaultdict(list)
    # Saving model at epoch 0 before training
    print("saving model at epoch 0 before training ... ")
    save_checkpoint(checkpoint_name, 0, model, model_opt, lr_schedule, 0.0)
    print("finished saving model at epoch 0 before training")

    if (config.debug and config.model_config
            == "configs/dynamics/eqv_transformer_model.py"):
        model_components = ([(0, [], "embedding_layer")] + list(
            chain.from_iterable((
                (k, [], f"ema_{k}"),
                (
                    k,
                    ["ema", "kernel", "location_kernel"],
                    f"ema_{k}_location_kernel",
                ),
                (
                    k,
                    ["ema", "kernel", "feature_kernel"],
                    f"ema_{k}_feature_kernel",
                ),
            ) for k in range(1, config.num_layers + 1))) +
                            [(config.num_layers + 2, [], "output_mlp")]
                            )  # components to track for debugging
        grad_flows = []

    if config.init_activations:
        activation_tracked = [(name, module)
                              for name, module in model.named_modules()
                              if isinstance(module, (Expression, nn.Linear,
                                                     MultiheadLinear,
                                                     MaskBatchNormNd))]
        activations = {}

        def save_activation(name, mod, inpt, otpt):
            if isinstance(inpt, tuple):
                if isinstance(inpt[0], (list, tuple)):
                    activations[name + "_inpt"] = inpt[0][1].detach().cpu()
                else:
                    if len(inpt) == 1:
                        activations[name + "_inpt"] = inpt[0].detach().cpu()
                    else:
                        activations[name + "_inpt"] = inpt[1].detach().cpu()
            else:
                activations[name + "_inpt"] = inpt.detach().cpu()

            if isinstance(otpt, tuple):
                if isinstance(otpt[0], list):
                    activations[name + "_otpt"] = otpt[0][1].detach().cpu()
                else:
                    if len(otpt) == 1:
                        activations[name + "_otpt"] = otpt[0].detach().cpu()
                    else:
                        activations[name + "_otpt"] = otpt[1].detach().cpu()
            else:
                activations[name + "_otpt"] = otpt.detach().cpu()

        for name, tracked_module in activation_tracked:
            tracked_module.register_forward_hook(partial(
                save_activation, name))

    # Training
    start_t = time.perf_counter()

    iters_per_epoch = len(dataloaders["train"])
    last_valid_loss = 1000.0
    for epoch in tqdm(range(start_epoch, config.train_epochs + 1)):
        model.train()

        for batch_idx, data in enumerate(dataloaders["train"]):
            data = {k: v.to(device) for k, v in data.items()}

            model_opt.zero_grad()
            outputs = model(data, compute_loss=True)

            outputs.loss.backward()
            if config.clip_grad:
                # Clip gradient L2-norm at 1
                torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.predictor.parameters(), 1.0)
            model_opt.step()

            if config.init_activations:
                model_opt.zero_grad()
                outputs = model(data, compute_loss=True)
                outputs.loss.backward()
                for name, activation in activations.items():
                    print(name)
                    summary_writer.add_histogram(f"activations/{name}",
                                                 activation.numpy(), 0)

                sys.exit(0)

            if config.log_train_values:
                reports = parse_reports(outputs.reports)
                if batch_idx % config.report_loss_every == 0:
                    log_tensorboard(summary_writer, train_iter, reports,
                                    "train/")
                    report_all = log_reports(report_all, train_iter, reports,
                                             "train")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(dataloaders["train"].dataset) // config.batch_size,
                        prefix="train",
                    )

            # Logging
            if batch_idx % config.evaluate_every == 0:
                model.eval()
                with torch.no_grad():
                    valid_mae = 0.0
                    for data in dataloaders["valid"]:
                        data = {k: v.to(device) for k, v in data.items()}
                        outputs = model(data, compute_loss=True)
                        valid_mae = valid_mae + outputs.mae
                model.train()

                outputs["reports"].valid_mae = valid_mae / len(
                    dataloaders["valid"])

                reports = parse_reports(outputs.reports)

                log_tensorboard(summary_writer, train_iter, reports, "valid")
                report_all = log_reports(report_all, train_iter, reports,
                                         "valid")
                print_reports(
                    reports,
                    start_t,
                    epoch,
                    batch_idx,
                    len(dataloaders["train"].dataset) // config.batch_size,
                    prefix="valid",
                )

                loss_diff = (last_valid_loss -
                             (valid_mae / len(dataloaders["valid"])).item())
                # Save a spike checkpoint when the validation MAE jumps by more than 0.1
                if config.find_spikes and loss_diff < -0.1:
                    save_checkpoint(
                        checkpoint_name + "_spike",
                        epoch,
                        model,
                        model_opt,
                        lr_schedule,
                        outputs.loss,
                    )

                last_valid_loss = (valid_mae /
                                   len(dataloaders["valid"])).item()

                if outputs["reports"].valid_mae < best_valid_mae:
                    save_checkpoint(
                        checkpoint_name,
                        "best_valid_mae",
                        model,
                        model_opt,
                        lr_schedule,
                        best_valid_mae,
                    )
                    best_valid_mae = outputs["reports"].valid_mae

            train_iter += 1

            # Step the LR schedule
            lr_schedule.step(train_iter / iters_per_epoch)

            # Track stuff for debugging
            if config.debug:
                model.eval()
                with torch.no_grad():
                    curr_params = {
                        k: v.detach().clone()
                        for k, v in model.state_dict().items()
                    }

                    # updates norm
                    if train_iter == 1:
                        update_norm = 0
                    else:
                        update_norms = []
                        for (k_prev,
                             v_prev), (k_curr,
                                       v_curr) in zip(prev_params.items(),
                                                      curr_params.items()):
                            assert k_prev == k_curr
                            if (
                                    "tracked" not in k_prev
                            ):  # ignore batch norm tracking. TODO: should this be ignored? if not, fix!
                                update_norms.append(
                                    (v_curr - v_prev).norm(1).item())
                        update_norm = sum(update_norms)

                    # gradient norm
                    grad_norm = 0
                    for p in model.parameters():
                        try:
                            grad_norm += p.grad.norm(1)
                        except AttributeError:
                            pass

                    # weights norm
                    if (config.model_config ==
                            "configs/dynamics/eqv_transformer_model.py"):
                        model_norms = {}
                        for comp_name in model_components:
                            comp = get_component(model.predictor.net,
                                                 comp_name)
                            norm = get_average_norm(comp)
                            model_norms[comp_name[2]] = norm

                        log_tensorboard(
                            summary_writer,
                            train_iter,
                            model_norms,
                            "debug/avg_model_norms/",
                        )

                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {
                            "avg_update_norm1": update_norm / num_params,
                            "avg_grad_norm1": grad_norm / num_params,
                        },
                        "debug/",
                    )
                    prev_params = curr_params

                    # gradient flow
                    ave_grads = []
                    max_grads = []
                    layers = []
                    for n, p in model.named_parameters():
                        if (p.requires_grad) and ("bias" not in n):
                            layers.append(n)
                            ave_grads.append(p.grad.abs().mean().item())
                            max_grads.append(p.grad.abs().max().item())

                    grad_flow = {
                        "layers": layers,
                        "ave_grads": ave_grads,
                        "max_grads": max_grads,
                    }
                    grad_flows.append(grad_flow)

                model.train()

        # Test model at the end of each epoch
        with torch.no_grad():
            model.eval()
            test_mae = 0.0
            for data in dataloaders["test"]:
                data = {k: v.to(device) for k, v in data.items()}
                outputs = model(data, compute_loss=True)
                test_mae = test_mae + outputs.mae

        outputs["reports"].test_mae = test_mae / len(dataloaders["test"])

        reports = parse_reports(outputs.reports)

        log_tensorboard(summary_writer, train_iter, reports, "test")
        report_all = log_reports(report_all, train_iter, reports, "test")

        print_reports(
            reports,
            start_t,
            epoch,
            batch_idx,
            len(dataloaders["train"].dataset) // config.batch_size,
            prefix="test",
        )

        reports = {
            "lr": lr_schedule.get_lr()[0],
            "time": time.perf_counter() - start_t,
            "epoch": epoch,
        }

        log_tensorboard(summary_writer, train_iter, reports, "stats")
        report_all = log_reports(report_all, train_iter, reports, "stats")

        # Save the reports
        dd.io.save(logdir + "/results_dict.h5", report_all)

        # Save a checkpoint
        if epoch % config.save_check_points == 0:
            save_checkpoint(
                checkpoint_name,
                epoch,
                model,
                model_opt,
                lr_schedule,
                best_valid_mae,
            )
            if config.only_store_last_checkpoint:
                delete_checkpoint(checkpoint_name,
                                  epoch - config.save_check_points)

    save_checkpoint(
        checkpoint_name,
        "final",
        model,
        model_opt,
        lr_schedule,
        outputs.loss,
    )
Example #13
def make_trainer(
        train_data,
        test_data,
        bs=5000,
        split={
            'train': 200,
            'val': 5000
        },
        network=RealNVPTabularWPrior,
        net_config={},
        num_epochs=15,
        optim=AdamW,
        lr=1e-3,
        opt_config={'weight_decay': 1e-5},
        swag=False,
        swa_config={
            'swa_dec_pct': .5,
            'swa_start_pct': .75,
            'swa_freq_pct': .05,
            'swa_lr_factor': .1
        },
        swag_config={
            'subspace': 'covariance',
            'max_num_models': 20
        },
        #                 subspace='covariance', max_num_models=20,
        trainer=SemiFlow,
        trainer_config={
            'log_dir': os.path.expanduser('~/tb-experiments/UCI/'),
            'log_args': {
                'minPeriod': .1,
                'timeFrac': 3 / 10
            }
        },
        dev='cuda',
        save=False):
    with FixedNumpySeed(0):
        datasets = split_dataset(train_data, splits=split)
        datasets['_unlab'] = dmap(lambda mb: mb[0], train_data)
        datasets['test'] = test_data

    device = torch.device(dev)

    dataloaders = {
        k: LoaderTo(
            DataLoader(v,
                       batch_size=min(bs, len(datasets[k])),
                       shuffle=(k == 'train'),
                       num_workers=0,
                       pin_memory=False), device)
        for k, v in datasets.items()
    }
    dataloaders['Train'] = dataloaders['train']

    #     model = network(num_classes=train_data.num_classes, dim_in=train_data.dim, **net_config).to(device)
    #     swag_model = SWAG(model_cfg.base,
    #                     subspace_type=args.subspace, subspace_kwargs={'max_rank': args.max_num_models},
    #                     *model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)

    #     swag_model.to(args.device)
    opt_constr = partial(optim, lr=lr, **opt_config)
    model = network(num_classes=train_data.num_classes,
                    dim_in=train_data.dim,
                    **net_config).to(device)
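    # When swag=True a SWAG-wrapped copy of the flow and an SWA learning-rate schedule
    # are used; otherwise training falls back to a plain cosine schedule.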
    if swag:
        swag_model = RealNVPTabularSWAG(dim_in=train_data.dim,
                                        **net_config,
                                        **swag_config)
        #         swag_model = SWAG(RealNVPTabular,
        #                           subspace_type=subspace, subspace_kwargs={'max_rank': max_num_models},
        #                           num_classes=train_data.num_classes, dim_in=train_data.dim,
        #                           num_coupling_layers=coupling_layers,in_dim=dim_in,**net_config)
        #         swag_model.to(device)
        #         swag_model = SWAG(RealNVPTabular, num_classes=train_data.num_classes, dim_in=train_data.dim,
        #                         swag=True, **swag_config, **net_config)
        #         model.to(device)
        swag_model.to(device)
        swa_config['steps_per_epoch'] = len(dataloaders['_unlab'])
        swa_config['num_epochs'] = num_epochs
        lr_sched = swa_learning_rate(**swa_config)
        #         lr_sched = cosLr(num_epochs)
        return trainer(model,
                       dataloaders,
                       opt_constr,
                       lr_sched,
                       swag_model=swag_model,
                       **swa_config,
                       **trainer_config)
    else:
        #         model = network(num_classes=train_data.num_classes, dim_in=train_data.dim, **net_config).to(device)
        lr_sched = cosLr(num_epochs)
        #     lr_sched = lambda e:1
        return trainer(model, dataloaders, opt_constr, lr_sched,
                       **trainer_config)