Example #1
def main():
    # Parse flags
    config = forge.config()
    config.batch_size = 1
    config.load_instances = True
    fet.print_flags()

    # Restore original model flags
    pretrained_flags = AttrDict(
        fet.json_load(os.path.join(config.model_dir, 'flags.json')))

    # Get validation loader
    train_loader, val_loader, test_loader = fet.load(config.data_config,
                                                     config)
    fprint(f"Split: {config.split}")
    if config.split == 'train':
        batch_loader = train_loader
    elif config.split == 'val':
        batch_loader = val_loader
    elif config.split == 'test':
        batch_loader = test_loader
    else:
        raise ValueError(f"Unknown split: {config.split}")
    # Shuffle and prefetch to get same data for different models
    if 'gqn' not in config.data_config:
        batch_loader = torch.utils.data.DataLoader(batch_loader.dataset,
                                                   batch_size=1,
                                                   num_workers=0,
                                                   shuffle=True)
    # Prefetch batches
    prefetched_batches = []
    for i, x in enumerate(batch_loader):
        if i == config.num_images:
            break
        prefetched_batches.append(x)

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    fprint(model)
    model_path = os.path.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)

    # Set experiment folder and fprint file for logging
    fet.EXPERIMENT_FOLDER = config.model_dir
    fet.FPRINT_FILE = 'segmentation_metrics.txt'

    # Compute metrics
    model.eval()
    ari_fg_list, msc_fg_list = [], []
    with torch.no_grad():
        for i, x in enumerate(tqdm(prefetched_batches)):
            _, _, stats, _, _ = model(x['input'])
            # ARI
            ari_fg, _ = average_ari(stats.log_m_k,
                                    x['instances'],
                                    foreground_only=True)
            # Segmentation covering - foreground only
            iseg = torch.argmax(torch.cat(stats.log_m_k, 1), 1, True)
            msc_fg, _ = average_segcover(x['instances'], iseg, True)
            # Recording
            ari_fg_list.append(ari_fg)
            msc_fg_list.append(msc_fg)

    # Print average metrics
    fprint(f"Average FG ARI: {sum(ari_fg_list)/len(ari_fg_list)}")
    fprint(f"Average FG MSC: {sum(msc_fg_list)/len(msc_fg_list)}")
Example #2
def main():
    # Parse flags
    config = forge.config()
    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    pretrained_flags.batch_size = 1
    pretrained_flags.gpu = False
    pretrained_flags.debug = True
    fet.print_flags()

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    for _ in range(100):
        y, stats = model.sample(1, pretrained_flags.K_steps)
        fig, axes = plt.subplots(nrows=4, ncols=1 + pretrained_flags.K_steps)

        # Generated
        plot(axes, 0, 0, y, title='Generated scene', fontsize=12)
        # Empty plots
        plot(axes, 1, 0, fontsize=12)
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K generation steps in separate subfigures
        for step in range(pretrained_flags.K_steps):
            x_step = stats['x_k'][step]
            m_step = stats['log_m_k'][step].exp()
            mx_step = stats['mx_k'][step]
            if 'log_s_k' in stats:
                s_step = stats['log_s_k'][step].exp()
            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes,
                 2,
                 1 + step,
                 m_step,
                 pre + f'k={step+1}',
                 True,
                 fontsize=12)
            if 'log_s_k' in stats:
                pre = 'Scope ' if step == 0 else ''
                plot(axes,
                     3,
                     1 + step,
                     s_step,
                     pre + f'k={step+1}',
                     True,
                     axis=step == 0,
                     fontsize=12)

        # Beautify and show figure
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()
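
The `plot` helper in these visualisation scripts is assumed from its call sites; a hypothetical implementation consistent with them (positional arguments after the image are the title and a grayscale flag, plus keyword axis/fontsize):

def plot_sketch(axes, row, col, img=None, title='', grayscale=False,
                axis=False, fontsize=12):
    # Draw one (B, C, H, W) tensor into grid cell (row, col),
    # or leave the cell blank when img is None.
    ax = axes[row, col]
    if img is not None:
        im = img[0].detach().cpu().permute(1, 2, 0).squeeze()
        ax.imshow(im, cmap='gray' if grayscale else None)
    ax.set_title(title, fontsize=fontsize)
    if not axis:
        ax.set_axis_off()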
Example #3
def main():
    # Parse flags
    config = forge.config()
    fet.print_flags()
    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    pretrained_flags.debug = True

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load data
    config.batch_size = 1
    _, _, test_loader = fet.load(config.data_config, config)

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    for count, batch in enumerate(test_loader):
        if count >= config.num_images:
            break

        # Forward pass
        output, _, stats, _, _ = model(batch['input'])
        # Set up figure
        fig, axes = plt.subplots(nrows=4, ncols=1 + pretrained_flags.K_steps)

        # Input and reconstruction
        plot(axes, 0, 0, batch['input'], title='Input image', fontsize=12)
        plot(axes, 1, 0, output, title='Reconstruction', fontsize=12)
        # Empty plots
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K reconstruction steps into separate subfigures
        x_k = stats['x_r_k']
        log_m_k = stats['log_m_k']
        mx_k = [x * m.exp() for x, m in zip(x_k, log_m_k)]
        log_s_k = stats['log_s_k'] if 'log_s_k' in stats else None
        for step in range(pretrained_flags.K_steps):
            mx_step = mx_k[step]
            x_step = x_k[step]
            m_step = log_m_k[step].exp()
            if log_s_k:
                s_step = log_s_k[step].exp()

            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes,
                 2,
                 1 + step,
                 m_step,
                 pre + f'k={step+1}',
                 True,
                 fontsize=12)
            if log_s_k:
                pre = 'Scope ' if step == 0 else ''
                plot(axes,
                     3,
                     1 + step,
                     s_step,
                     pre + f'k={step+1}',
                     True,
                     axis=step == 0,
                     fontsize=12)

        # Beautify and show figure
        plt.subplots_adjust(wspace=0.05, hspace=0.15)
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()
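
A quick sanity check relating the per-slot outputs above to the reconstruction, assuming (as in MONet-style models) that the exponentiated masks sum to one at every pixel:

# Masked components should recompose into (approximately) the reconstruction
recon = sum(m.exp() * x for x, m in zip(stats['x_r_k'], stats['log_m_k']))
print("max |recon - output|:", (recon - output).abs().max().item())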
Example #4
def main():
    # Parse flags
    config = forge.config()

    # Load data
    dataloaders, num_species, charge_scale, ds_stats, data_name = fet.load(
        config.data_config, config=config)

    config.num_species = num_species
    config.charge_scale = charge_scale
    config.ds_stats = ds_stats

    # Load model
    model, model_name = fet.load(config.model_config, config)
    # `device` is presumably defined at module level in the original file;
    # define it here so the example is self-contained.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    config.charge_scale = float(config.charge_scale.numpy())
    config.ds_stats = [float(stat.numpy()) for stat in config.ds_stats]

    # Prepare environment
    run_name = (config.run_name + "_bs" + str(config.batch_size) + "_lr" +
                str(config.learning_rate))

    if config.batch_fit != 0:
        run_name += "_bf" + str(config.batch_fit)

    if config.lr_schedule != "none":
        run_name += "_" + config.lr_schedule

    # Print flags
    fet.print_flags()

    # Setup optimizer
    model_params = model.predictor.parameters()

    opt_learning_rate = config.learning_rate
    model_opt = torch.optim.Adam(
        model_params,
        lr=opt_learning_rate,
        betas=(config.beta1, config.beta2),
        eps=1e-8,
    )
    # model_opt = torch.optim.SGD(model_params, lr=opt_learning_rate)

    # Cosine annealing learning rate
    if config.lr_schedule == "cosine":
        cos = cosLr(config.train_epochs)
        lr_sched = lambda e: max(cos(e),
                                 config.lr_floor * config.learning_rate)
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "cosine_warmup":
        cos = cosLr(config.train_epochs)
        lr_sched = lambda e: max(
            min(e / (config.warmup_length * config.train_epochs), 1) * cos(e),
            config.lr_floor * config.learning_rate,
        )
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "quadratic_warmup":
        # Finish at 1/100 of the initial lr
        lr_sched = lambda e: (min(e / (0.01 * config.train_epochs), 1) *
                              (1.0 / sqrt(1.0 + 10000.0 *
                                          (e / config.train_epochs))))
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "none":
        lr_sched = lambda e: 1.0
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    else:
        raise ValueError(
            f"{config.lr_schedule} is not a recognised learning rate schedule")

    num_params = param_count(model)
    if config.parameter_count:
        for (name, parameter) in model.predictor.named_parameters():
            print(name, parameter.dtype)

        print(model)
        print("============================================================")
        print(f"{model_name} parameters: {num_params:.5e}")
        print("============================================================")
        # from torchsummary import summary

        # data = next(iter(dataloaders["train"]))

        # data = {k: v.to(device) for k, v in data.items()}
        # print(
        #     summary(
        #         model.predictor,
        #         data,
        #         batch_size=config.batch_size,
        #     )
        # )

        parameters = sum(parameter.numel()
                         for parameter in model.predictor.parameters())
        parameters_grad = sum(
            parameter.numel() if parameter.requires_grad else 0
            for parameter in model.predictor.parameters())
        print(f"Parameters: {parameters:,}")
        print(f"Parameters grad: {parameters_grad:,}")

        memory_allocations = []

        for batch_idx, data in enumerate(dataloaders["train"]):
            print(batch_idx)
            data = {k: v.to(device) for k, v in data.items()}

            model_opt.zero_grad()
            outputs = model(data, compute_loss=True)
            # Track reserved memory so the summary below has data to report
            torch.cuda.empty_cache()
            memory_allocations.append(
                torch.cuda.memory_reserved() / 1024 / 1024 / 1024)
            # outputs.loss.backward()

        print(
            f"max memory reserved in one pass: {max(memory_allocations):0.4}GB"
        )
        sys.exit(0)

    else:
        print(f"{model_name} parameters: {num_params:.5e}")

    # set up results folders
    results_folder_name = osp.join(
        data_name,
        model_name,
        run_name,
    )

    logdir = osp.join(config.results_dir,
                      results_folder_name.replace(".", "_"))
    logdir, resume_checkpoint = fet.init_checkpoint(logdir, config.data_config,
                                                    config.model_config,
                                                    config.resume)

    checkpoint_name = osp.join(logdir, "model.ckpt")

    # Try to restore model and optimizer from checkpoint
    if resume_checkpoint is not None:
        start_epoch, best_valid_mae = load_checkpoint(resume_checkpoint, model,
                                                      model_opt, lr_schedule)
    else:
        start_epoch = 1
        best_valid_mae = 1e12

    train_iter = (start_epoch - 1) * (len(dataloaders["train"].dataset) //
                                      config.batch_size) + 1

    print("Starting training at epoch = {}, iter = {}".format(
        start_epoch, train_iter))

    # Setup tensorboard writing
    summary_writer = SummaryWriter(logdir)

    report_all = defaultdict(list)
    # Saving model at epoch 0 before training
    print("saving model at epoch 0 before training ... ")
    save_checkpoint(checkpoint_name, 0, model, model_opt, lr_schedule, 0.0)
    print("finished saving model at epoch 0 before training")

    if (config.debug and config.model_config
            == "configs/dynamics/eqv_transformer_model.py"):
        model_components = ([(0, [], "embedding_layer")] + list(
            chain.from_iterable((
                (k, [], f"ema_{k}"),
                (
                    k,
                    ["ema", "kernel", "location_kernel"],
                    f"ema_{k}_location_kernel",
                ),
                (
                    k,
                    ["ema", "kernel", "feature_kernel"],
                    f"ema_{k}_feature_kernel",
                ),
            ) for k in range(1, config.num_layers + 1))) +
                            [(config.num_layers + 2, [], "output_mlp")]
                            )  # components to track for debugging
        grad_flows = []

    if config.init_activations:
        activation_tracked = [(name, module)
                              for name, module in model.named_modules()
                              if isinstance(module, (Expression, nn.Linear,
                                                     MultiheadLinear,
                                                     MaskBatchNormNd))]
        activations = {}

        def save_activation(name, mod, inpt, otpt):
            if isinstance(inpt, tuple):
                if isinstance(inpt[0], (list, tuple)):
                    activations[name + "_inpt"] = inpt[0][1].detach().cpu()
                else:
                    if len(inpt) == 1:
                        activations[name + "_inpt"] = inpt[0].detach().cpu()
                    else:
                        activations[name + "_inpt"] = inpt[1].detach().cpu()
            else:
                activations[name + "_inpt"] = inpt.detach().cpu()

            if isinstance(otpt, tuple):
                if isinstance(otpt[0], list):
                    activations[name + "_otpt"] = otpt[0][1].detach().cpu()
                else:
                    if len(otpt) == 1:
                        activations[name + "_otpt"] = otpt[0].detach().cpu()
                    else:
                        activations[name + "_otpt"] = otpt[1].detach().cpu()
            else:
                activations[name + "_otpt"] = otpt.detach().cpu()

        for name, tracked_module in activation_tracked:
            tracked_module.register_forward_hook(partial(
                save_activation, name))

    # Training
    start_t = time.perf_counter()

    iters_per_epoch = len(dataloaders["train"])
    last_valid_loss = 1000.0
    for epoch in tqdm(range(start_epoch, config.train_epochs + 1)):
        model.train()

        for batch_idx, data in enumerate(dataloaders["train"]):
            data = {k: v.to(device) for k, v in data.items()}

            model_opt.zero_grad()
            outputs = model(data, compute_loss=True)

            outputs.loss.backward()
            if config.clip_grad:
                # Clip gradient L2-norm at 1
                torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.predictor.parameters(), 1.0)
            model_opt.step()

            if config.init_activations:
                model_opt.zero_grad()
                outputs = model(data, compute_loss=True)
                outputs.loss.backward()
                for name, activation in activations.items():
                    print(name)
                    summary_writer.add_histogram(f"activations/{name}",
                                                 activation.numpy(), 0)

                sys.exit(0)

            if config.log_train_values:
                reports = parse_reports(outputs.reports)
                if batch_idx % config.report_loss_every == 0:
                    log_tensorboard(summary_writer, train_iter, reports,
                                    "train/")
                    report_all = log_reports(report_all, train_iter, reports,
                                             "train")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(dataloaders["train"].dataset) // config.batch_size,
                        prefix="train",
                    )

            # Logging
            if batch_idx % config.evaluate_every == 0:
                model.eval()
                with torch.no_grad():
                    valid_mae = 0.0
                    for data in dataloaders["valid"]:
                        data = {k: v.to(device) for k, v in data.items()}
                        outputs = model(data, compute_loss=True)
                        valid_mae = valid_mae + outputs.mae
                model.train()

                outputs["reports"].valid_mae = valid_mae / len(
                    dataloaders["valid"])

                reports = parse_reports(outputs.reports)

                log_tensorboard(summary_writer, train_iter, reports, "valid")
                report_all = log_reports(report_all, train_iter, reports,
                                         "valid")
                print_reports(
                    reports,
                    start_t,
                    epoch,
                    batch_idx,
                    len(dataloaders["train"].dataset) // config.batch_size,
                    prefix="valid",
                )

                loss_diff = (last_valid_loss -
                             (valid_mae / len(dataloaders["valid"])).item())
                # A spike: validation loss jumped by more than 0.1
                if config.find_spikes and loss_diff < -0.1:
                    save_checkpoint(
                        checkpoint_name + "_spike",
                        epoch,
                        model,
                        model_opt,
                        lr_schedule,
                        outputs.loss,
                    )

                last_valid_loss = (valid_mae /
                                   len(dataloaders["valid"])).item()

                if outputs["reports"].valid_mae < best_valid_mae:
                    save_checkpoint(
                        checkpoint_name,
                        "best_valid_mae",
                        model,
                        model_opt,
                        lr_schedule,
                        best_valid_mae,
                    )
                    best_valid_mae = outputs["reports"].valid_mae

            train_iter += 1

            # Step the LR schedule
            lr_schedule.step(train_iter / iters_per_epoch)

            # Track stuff for debugging
            if config.debug:
                model.eval()
                with torch.no_grad():
                    curr_params = {
                        k: v.detach().clone()
                        for k, v in model.state_dict().items()
                    }

                    # updates norm
                    if train_iter == 1:
                        update_norm = 0
                    else:
                        update_norms = []
                        for (k_prev,
                             v_prev), (k_curr,
                                       v_curr) in zip(prev_params.items(),
                                                      curr_params.items()):
                            assert k_prev == k_curr
                            # Ignore batch norm tracking. TODO: should this
                            # be ignored? If not, fix!
                            if "tracked" not in k_prev:
                                update_norms.append(
                                    (v_curr - v_prev).norm(1).item())
                        update_norm = sum(update_norms)

                    # gradient norm
                    grad_norm = 0
                    for p in model.parameters():
                        try:
                            grad_norm += p.grad.norm(1)
                        except AttributeError:
                            pass

                    # weights norm
                    if (config.model_config ==
                            "configs/dynamics/eqv_transformer_model.py"):
                        model_norms = {}
                        for comp_name in model_components:
                            comp = get_component(model.predictor.net,
                                                 comp_name)
                            norm = get_average_norm(comp)
                            model_norms[comp_name[2]] = norm

                        log_tensorboard(
                            summary_writer,
                            train_iter,
                            model_norms,
                            "debug/avg_model_norms/",
                        )

                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {
                            "avg_update_norm1": update_norm / num_params,
                            "avg_grad_norm1": grad_norm / num_params,
                        },
                        "debug/",
                    )
                    prev_params = curr_params

                    # gradient flow
                    ave_grads = []
                    max_grads = []
                    layers = []
                    for n, p in model.named_parameters():
                        if (p.requires_grad) and ("bias" not in n):
                            layers.append(n)
                            ave_grads.append(p.grad.abs().mean().item())
                            max_grads.append(p.grad.abs().max().item())

                    grad_flow = {
                        "layers": layers,
                        "ave_grads": ave_grads,
                        "max_grads": max_grads,
                    }
                    grad_flows.append(grad_flow)

                model.train()

        # Test model at end of epoch
        with torch.no_grad():
            model.eval()
            test_mae = 0.0
            for data in dataloaders["test"]:
                data = {k: v.to(device) for k, v in data.items()}
                outputs = model(data, compute_loss=True)
                test_mae = test_mae + outputs.mae

        outputs["reports"].test_mae = test_mae / len(dataloaders["test"])

        reports = parse_reports(outputs.reports)

        log_tensorboard(summary_writer, train_iter, reports, "test")
        report_all = log_reports(report_all, train_iter, reports, "test")

        print_reports(
            reports,
            start_t,
            epoch,
            batch_idx,
            len(dataloaders["train"].dataset) // config.batch_size,
            prefix="test",
        )

        reports = {
            "lr": lr_schedule.get_lr()[0],
            "time": time.perf_counter() - start_t,
            "epoch": epoch,
        }

        log_tensorboard(summary_writer, train_iter, reports, "stats")
        report_all = log_reports(report_all, train_iter, reports, "stats")

        # Save the reports
        dd.io.save(logdir + "/results_dict.h5", report_all)

        # Save a checkpoint
        if epoch % config.save_check_points == 0:
            save_checkpoint(
                checkpoint_name,
                epoch,
                model,
                model_opt,
                lr_schedule,
                best_valid_mae,
            )
            if config.only_store_last_checkpoint:
                delete_checkpoint(checkpoint_name,
                                  epoch - config.save_check_points)

    save_checkpoint(
        checkpoint_name,
        "final",
        model,
        model_opt,
        lr_schedule,
        outputs.loss,
    )
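
For reference, a minimal stand-in for the `cosLr` factory used by the schedules above, assuming it maps a (possibly fractional) epoch to a half-cosine multiplier that decays from 1 to 0 over training:

import math

def cosLr_sketch(num_epochs):
    def sched(e):
        # 1 at the start of training, 0 at epoch num_epochs.
        return 0.5 * (1 + math.cos(math.pi * min(e, num_epochs) / num_epochs))
    return sched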
Example #5
checkpoint_name = osp.join(logdir, 'model.ckpt')

# Build the graph
tf.reset_default_graph()
# load data
data_dict = fet.load(config.data_config, config)
# load the model
loss, stats, _ = fet.load(config.model_config, config, **data_dict)

# Add summaries for reported stats
# summaries can be set up in the model config file
for (k, v) in stats.items():
    tf.summary.scalar(k, v)

# Print model stats
fet.print_flags()
fet.print_variables_by_scope()
fet.print_num_params()

# Setup the optimizer
global_step = tf.train.get_or_create_global_step()
opt = tf.train.RMSPropOptimizer(config.learning_rate, momentum=.9)

# Create the train step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_step = opt.minimize(loss, global_step=global_step)

# Create session and initialize variables
sess = fet.get_session()
sess.run(tf.global_variables_initializer())
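
Example #5 stops after building the graph and session. A typical TF1 loop that would follow, sketched on the assumption that the data pipeline feeds the graph directly (no feed_dict); the flags config.train_itr, config.report_loss_every and config.save_itr are hypothetical:

saver = tf.train.Saver()
for _ in range(config.train_itr):
    _, loss_value, step = sess.run([train_step, loss, global_step])
    if step % config.report_loss_every == 0:
        print("iter {}: loss = {:.4f}".format(step, loss_value))
    if step % config.save_itr == 0:
        saver.save(sess, checkpoint_name, global_step=global_step)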
Example #6
def main():
    # Parse flags
    config = forge.config()

    # Set device
    if torch.cuda.is_available():
        device = f"cuda:{config.device}"
        torch.cuda.set_device(device)
    else:
        device = "cpu"

    # Load data
    dataloaders, data_name = fet.load(config.data_config, config=config)

    train_loader = dataloaders["train"]
    test_loader = dataloaders["test"]
    val_loader = dataloaders["val"]

    # Load model
    model, model_name = fet.load(config.model_config, config)
    model = model.to(device)
    print(model)

    # Prepare environment
    params_in_run_name = [
        ("batch_size", "bs"),
        ("learning_rate", "lr"),
        ("num_heads", "nheads"),
        ("num_layers", "nlayers"),
        ("dim_hidden", "hdim"),
        ("kernel_dim", "kdim"),
        ("location_attention", "locatt"),
        ("model_seed", "mseed"),
        ("lr_schedule", "lrsched"),
        ("layer_norm", "ln"),
        ("batch_norm", "bn"),
        ("channel_width", "width"),
        ("attention_fn", "attfn"),
        ("output_mlp_scale", "mlpscale"),
        ("train_epochs", "epochs"),
        ("block_norm", "block"),
        ("kernel_type", "ktype"),
        ("architecture", "arch"),
        ("activation_function", "act"),
        ("space_dim", "spacedim"),
        ("num_particles", "prtcls"),
        ("n_train", "ntrain"),
        ("group", "group"),
        ("lift_samples", "ls"),
    ]

    run_name = ""
    for config_param in params_in_run_name:
        attr = config_param[0]
        abbrev = config_param[1]

        if hasattr(config, attr):
            run_name += abbrev
            run_name += str(getattr(config, attr))
            run_name += "_"

    if config.clip_grad_norm:
        run_name += "clipnorm" + str(config.max_grad_norm) + "_"

    results_folder_name = osp.join(
        data_name,
        model_name,
        config.run_name,
        run_name,
    )

    logdir = osp.join(config.results_dir,
                      results_folder_name.replace(".", "_"))
    logdir, resume_checkpoint = fet.init_checkpoint(logdir, config.data_config,
                                                    config.model_config,
                                                    config.resume)

    checkpoint_name = osp.join(logdir, "model.ckpt")

    # Print flags
    fet.print_flags()

    # Setup optimizer
    model_params = model.predictor.parameters()

    opt_learning_rate = config.learning_rate
    model_opt = torch.optim.Adam(model_params,
                                 lr=opt_learning_rate,
                                 betas=(config.beta1, config.beta2))

    if config.lr_schedule == "cosine_annealing":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            model_opt, config.train_epochs)
    elif config.lr_schedule == "cosine_annealing_warmup":
        num_warmup_epochs = int(0.05 * config.train_epochs)
        num_warmup_steps = len(train_loader) * num_warmup_epochs
        num_training_steps = len(train_loader) * config.train_epochs
        scheduler = transformers.get_cosine_schedule_with_warmup(
            model_opt,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    # Try to restore model and optimizer from checkpoint
    if resume_checkpoint is not None:
        start_epoch = load_checkpoint(resume_checkpoint, model, model_opt)
    else:
        start_epoch = 1

    train_iter = (start_epoch - 1) * (len(train_loader.dataset) //
                                      config.batch_size) + 1

    print("Starting training at epoch = {}, iter = {}".format(
        start_epoch, train_iter))

    # Setup tensorboard writing
    summary_writer = SummaryWriter(logdir)

    train_reports = []
    report_all = {}
    report_all_val = {}
    # Saving model at epoch 0 before training
    print("saving model at epoch 0 before training ... ")
    save_checkpoint(checkpoint_name, 0, model, model_opt, loss=0.0)
    print("finished saving model at epoch 0 before training")

    # if (
    #     config.debug
    #     and config.model_config == "configs/dynamics/eqv_transformer_model.py"
    # ):
    #     model_components = (
    #         [(0, [], "embedding_layer")]
    #         + list(
    #             chain.from_iterable(
    #                 (
    #                     (k, [], f"ema_{k}"),
    #                     (
    #                         k,
    #                         ["ema", "kernel", "location_kernel"],
    #                         f"ema_{k}_location_kernel",
    #                     ),
    #                     (
    #                         k,
    #                         ["ema", "kernel", "feature_kernel"],
    #                         f"ema_{k}_feature_kernel",
    #                     ),
    #                 )
    #                 for k in range(1, config.num_layers + 1)
    #             )
    #         )
    #         + [(config.num_layers + 2, [], "output_mlp")]
    #     )  # components to track for debugging

    num_params = param_count(model)
    print(f"Number of model parameters: {num_params}")

    # Training
    start_t = time.time()

    total_train_iters = len(train_loader) * config.train_epochs
    # Evaluate 100 times over the course of training
    iters_per_eval = int(total_train_iters / 100)

    assert (
        config.n_train % min(config.batch_size, config.n_train) == 0
    ), "Batch size doesn't divide dataset size. Can be problematic for loss computation (see below)."

    training_failed = False
    best_val_loss_so_far = 1e+7

    for epoch in tqdm(range(start_epoch, config.train_epochs + 1)):
        model.train()

        for batch_idx, data in enumerate(train_loader):
            # Data format: ((z0, sys_params, ts), true_zs)
            data = nested_to(data, device, torch.float32)
            outputs = model(data)

            if torch.isnan(outputs.loss):
                if not training_failed:
                    epoch_of_nan = epoch
                if (epoch > epoch_of_nan + 1) and training_failed:
                    raise ValueError("Loss Nan-ed.")
                training_failed = True

            model_opt.zero_grad()
            outputs.loss.backward(retain_graph=False)

            if config.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.max_grad_norm,
                                               norm_type=1)

            model_opt.step()

            train_reports.append(parse_reports_cpu(outputs.reports))

            if config.log_train_values:
                reports = parse_reports(outputs.reports)
                if batch_idx % config.report_loss_every == 0:
                    log_tensorboard(summary_writer, train_iter, reports,
                                    "train/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="train",
                    )
                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {"lr": model_opt.param_groups[0]["lr"]},
                        "hyperparams/",
                    )

            # Do learning rate schedule steps per STEP for cosine_annealing_warmup
            if config.lr_schedule == "cosine_annealing_warmup":
                scheduler.step()

            # Track stuff for debugging
            if config.debug:
                model.eval()
                with torch.no_grad():
                    curr_params = {
                        k: v.detach().clone()
                        for k, v in model.state_dict().items()
                    }

                    # updates norm
                    if train_iter == 1:
                        update_norm = 0
                    else:
                        update_norms = []
                        for (k_prev,
                             v_prev), (k_curr,
                                       v_curr) in zip(prev_params.items(),
                                                      curr_params.items()):
                            assert k_prev == k_curr
                            if ("tracked" not in k_prev):
                                update_norms.append(
                                    (v_curr - v_prev).norm(1).item())
                        update_norm = sum(update_norms)

                    # gradient norm
                    grad_norm = 0
                    for p in model.parameters():
                        try:
                            grad_norm += p.grad.norm(1)
                        except AttributeError:
                            pass

                    # Average V size:
                    (z0, sys_params, ts), true_zs = data

                    z = z0
                    # Assume the first component encodes masses
                    m = sys_params[..., 0]
                    # Number of ODE dims: 2 * num_particles * space_dim
                    D = z.shape[-1]
                    q = z[:, :D // 2].reshape(*m.shape, -1)
                    k = sys_params[..., 1]

                    V = model.predictor.compute_V((q, sys_params))
                    V_true = SpringV(q, k)

                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {
                            "avg_update_norm1":
                            update_norm / num_params,
                            "avg_grad_norm1":
                            grad_norm / num_params,
                            "avg_predicted_potential_norm":
                            V.norm(1) / V.numel(),
                            "avg_true_potential_norm":
                            V_true.norm(1) / V_true.numel(),
                        },
                        "debug/",
                    )
                    prev_params = curr_params

                model.train()

            # Logging
            if (train_iter % iters_per_eval == 0
                    or train_iter == total_train_iters):
                model.eval()
                with torch.no_grad():

                    reports = None
                    for data in test_loader:
                        data = nested_to(data, device, torch.float32)
                        outputs = model(data)

                        if reports is None:
                            reports = {
                                k: v.detach().clone().cpu()
                                for k, v in outputs.reports.items()
                            }
                        else:
                            for k, v in outputs.reports.items():
                                reports[k] += v.detach().clone().cpu()

                    for k, v in reports.items():
                        reports[k] = v / len(test_loader)

                    reports = parse_reports(reports)
                    reports["time"] = time.time() - start_t
                    if report_all == {}:
                        report_all = deepcopy(reports)

                        for d in reports.keys():
                            report_all[d] = [report_all[d]]
                    else:
                        for d in reports.keys():
                            report_all[d].append(reports[d])

                    log_tensorboard(summary_writer, train_iter, reports,
                                    "test/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="test",
                    )

                    if (config.kill_if_poor
                            and epoch > config.train_epochs * 0.2
                            and reports["mse"] > 0.01):
                        raise RuntimeError(
                            "Killed run due to poor performance.")

                    # repeat for validation data
                    reports = None
                    for data in val_loader:
                        data = nested_to(data, device, torch.float32)
                        outputs = model(data)

                        if reports is None:
                            reports = {
                                k: v.detach().clone().cpu()
                                for k, v in outputs.reports.items()
                            }
                        else:
                            for k, v in outputs.reports.items():
                                reports[k] += v.detach().clone().cpu()

                    for k, v in reports.items():
                        reports[k] = v / len(val_loader)

                    reports = parse_reports(reports)
                    reports["time"] = time.time() - start_t
                    if report_all_val == {}:
                        report_all_val = deepcopy(reports)

                        for d in reports.keys():
                            report_all_val[d] = [report_all_val[d]]
                    else:
                        for d in reports.keys():
                            report_all_val[d].append(reports[d])

                    log_tensorboard(summary_writer, train_iter, reports,
                                    "val/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="val",
                    )

                    if report_all_val['mse'][-1] < best_val_loss_so_far:
                        save_checkpoint(checkpoint_name,
                                        f"early_stop",
                                        model,
                                        model_opt,
                                        loss=outputs.loss)
                        best_val_loss_so_far = report_all_val['mse'][-1]

                model.train()

            train_iter += 1

        if config.lr_schedule == "cosine_annealing":
            scheduler.step()

        if epoch % config.save_check_points == 0:
            save_checkpoint(checkpoint_name,
                            train_iter,
                            model,
                            model_opt,
                            loss=outputs.loss)

        dd.io.save(logdir + "/results_dict_train.h5", train_reports)
        dd.io.save(logdir + "/results_dict.h5", report_all)
        dd.io.save(logdir + "/results_dict_val.h5", report_all_val)

    # always save final model
    save_checkpoint(checkpoint_name,
                    train_iter,
                    model,
                    model_opt,
                    loss=outputs.loss)

    if config.save_test_predictions:
        print(
            "Starting to make model predictions on test sets for *final model*."
        )
        for chunk_len in [5, 100]:
            start_t_preds = time.time()
            data_config = SimpleNamespace(
                **{
                    **config.__dict__["__flags"],
                    **{
                        "chunk_len": chunk_len,
                        "batch_size": 500
                    }
                })
            dataloaders, data_name = fet.load(config.data_config,
                                              config=data_config)
            test_loader_preds = dataloaders["test"]

            torch.cuda.empty_cache()
            with torch.no_grad():
                preds = []
                true = []
                num_datapoints = 0
                for idx, d in enumerate(test_loader_preds):
                    true.append(d[-1])
                    d = nested_to(d, device, torch.float32)
                    outputs = model(d)

                    pred_zs = outputs.prediction
                    preds.append(pred_zs)

                    num_datapoints += len(pred_zs)

                    if num_datapoints >= 2000:
                        break

                preds = torch.cat(preds, dim=0).cpu()
                true = torch.cat(true, dim=0).cpu()

                save_dir = osp.join(
                    logdir, f"traj_preds_{chunk_len}_steps_2k_test.pt")
                torch.save(preds, save_dir)

                save_dir = osp.join(logdir,
                                    f"traj_true_{chunk_len}_steps_2k_test.pt")
                torch.save(true, save_dir)

                print(
                    f"Completed making test predictions for chunk_len = {chunk_len} in {time.time() - start_t_preds:.2f} seconds."
                )
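
The `nested_to` helper above moves nested containers of tensors onto the device; a minimal recursive sketch (the repository's version may handle more container types):

def nested_to_sketch(x, device, dtype):
    # Recursively move tensors inside (possibly nested) tuples and lists.
    if isinstance(x, (tuple, list)):
        return type(x)(nested_to_sketch(e, device, dtype) for e in x)
    return x.to(device=device, dtype=dtype)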
Example #7
def main():
    # Parse flags
    config = forge.config()

    # Set device
    if torch.cuda.is_available():
        device = f"cuda:{config.device}"
        torch.cuda.set_device(device)
    else:
        device = "cpu"

    # Load data
    train_loader, test_loader, data_name = fet.load(config.data_config,
                                                    config=config)

    # Load model
    model, model_name = fet.load(config.model_config, config)
    model = model.to(device)
    print(model)
    torch.manual_seed(0)

    # Prepare environment

    if "set_transformer" in config.model_config:
        params_in_run_name = [
            ("batch_size", "bs"),
            ("learning_rate", "lr"),
            ("num_heads", "nheads"),
            ("patterns_reps", "reps"),
            ("model_seed", "mseed"),
            ("train_epochs", "epochs"),
            ("naug", "naug"),
        ]

    else:
        params_in_run_name = [
            ("batch_size", "bs"),
            ("learning_rate", "lr"),
            ("num_heads", "nheads"),
            ("num_layers", "nlayers"),
            ("dim_hidden", "hdim"),
            ("kernel_dim", "kdim"),
            ("location_attention", "locatt"),
            ("model_seed", "mseed"),
            ("lr_schedule", "lrsched"),
            ("layer_norm", "ln"),
            ("batch_norm_att", "bnatt"),
            ("batch_norm", "bn"),
            ("batch_norm_final_mlp", "bnfinalmlp"),
            ("k", "k"),
            ("attention_fn", "attfn"),
            ("output_mlp_scale", "mlpscale"),
            ("train_epochs", "epochs"),
            ("block_norm", "block"),
            ("kernel_type", "ktype"),
            ("architecture", "arch"),
            ("kernel_act", "actv"),
            ("patterns_reps", "reps"),
            ("lift_samples", "nsamples"),
            ("content_type", "content"),
            ("naug", "naug"),
        ]

    run_name = ""  # config.run_name
    for config_param in params_in_run_name:
        attr = config_param[0]
        abbrev = config_param[1]

        if hasattr(config, attr):
            run_name += abbrev
            run_name += str(getattr(config, attr))
            run_name += "_"

    results_folder_name = osp.join(
        data_name,
        model_name,
        config.run_name,
        run_name,
    )

    # results_folder_name = osp.join(data_name, model_name, run_name,)

    logdir = osp.join(config.results_dir,
                      results_folder_name.replace(".", "_"))
    logdir, resume_checkpoint = fet.init_checkpoint(logdir, config.data_config,
                                                    config.model_config,
                                                    config.resume)

    checkpoint_name = osp.join(logdir, "model.ckpt")

    # Print flags
    fet.print_flags()

    # Setup optimizer
    model_params = model.encoder.parameters()

    opt_learning_rate = config.learning_rate
    model_opt = torch.optim.Adam(model_params,
                                 lr=opt_learning_rate,
                                 betas=(config.beta1, config.beta2))

    # Try to restore model and optimizer from checkpoint
    if resume_checkpoint is not None:
        start_epoch = load_checkpoint(resume_checkpoint, model, model_opt)
    else:
        start_epoch = 1

    train_iter = (start_epoch - 1) * (len(train_loader.dataset) //
                                      config.batch_size) + 1

    print("Starting training at epoch = {}, iter = {}".format(
        start_epoch, train_iter))

    # Setup tensorboard writing
    summary_writer = SummaryWriter(logdir)

    train_reports = []
    report_all = {}
    # Saving model at epoch 0 before training
    print("saving model at epoch 0 before training ... ")
    save_checkpoint(checkpoint_name, 0, model, model_opt, loss=0.0)
    print("finished saving model at epoch 0 before training")

    if (config.debug and config.model_config
            == "configs/constellation/eqv_transformer_model.py"):
        model_components = ([(0, [], "embedding_layer")] + list(
            chain.from_iterable((
                (k, [], f"ema_{k}"),
                (
                    k,
                    ["ema", "kernel", "location_kernel"],
                    f"ema_{k}_location_kernel",
                ),
                (
                    k,
                    ["ema", "kernel", "feature_kernel"],
                    f"ema_{k}_feature_kernel",
                ),
            ) for k in range(1, config.num_layers + 1))) +
                            [(config.num_layers + 2, [], "output_mlp")]
                            )  # components to track for debugging

    num_params = param_count(model)
    print(f"Number of model parameters: {num_params}")

    # Training
    start_t = time.time()

    grad_flows = []
    training_failed = False
    for epoch in tqdm(range(start_epoch, config.train_epochs + 1)):
        model.train()

        for batch_idx, data in enumerate(train_loader):
            data, presence, target = [d.to(device).float() for d in data]
            if config.train_aug_t2:
                # Random translation offset (alternatives tried:
                # torch.rand(1) * 2 * math.pi, torch.randn(1).to(device) * 10)
                unifs = np.random.normal(1)
                data += unifs
            if config.train_aug_se2:
                # Random rotation and translation
                # (was: angle = torch.rand(1) * 2 * math.pi)
                angle = np.random.random(size=1) * 2 * math.pi
                unifs = np.random.normal(1)
                data = rotate(data, angle.item())
                data += unifs
            outputs = model([data, presence], target)

            if torch.isnan(outputs.loss):
                if not training_failed:
                    epoch_of_nan = epoch
                if (epoch > epoch_of_nan + 1) and training_failed:
                    raise ValueError("Loss Nan-ed.")
                training_failed = True

            model_opt.zero_grad()
            outputs.loss.backward(retain_graph=False)
            model_opt.step()

            train_reports.append(parse_reports_cpu(outputs.reports))

            if config.log_train_values:
                reports = parse_reports(outputs.reports)
                if batch_idx % config.report_loss_every == 0:
                    log_tensorboard(summary_writer, train_iter, reports,
                                    "train/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="train",
                    )

            # Track stuff for debugging
            if config.debug:
                model.eval()
                with torch.no_grad():
                    curr_params = {
                        k: v.detach().clone()
                        for k, v in model.state_dict().items()
                    }

                    # updates norm
                    if train_iter == 1:
                        update_norm = 0
                    else:
                        update_norms = []
                        for (k_prev,
                             v_prev), (k_curr,
                                       v_curr) in zip(prev_params.items(),
                                                      curr_params.items()):
                            assert k_prev == k_curr
                            # Ignore batch norm tracking. TODO: should this
                            # be ignored? If not, fix!
                            if "tracked" not in k_prev:
                                update_norms.append(
                                    (v_curr - v_prev).norm(1).item())
                        update_norm = sum(update_norms)

                    # gradient norm
                    grad_norm = 0
                    for p in model.parameters():
                        try:
                            grad_norm += p.grad.norm(1)
                        except AttributeError:
                            pass

                    # # weights norm
                    # if config.model_config == 'configs/constellation/eqv_transformer_model.py':
                    #     model_norms = {}
                    #     for comp_name in model_components:
                    #         comp = get_component(model.encoder.net, comp_name)
                    #         norm = get_average_norm(comp)
                    #         model_norms[comp_name[2]] = norm

                    #     log_tensorboard(
                    #         summary_writer,
                    #         train_iter,
                    #         model_norms,
                    #         "debug/avg_model_norms/",
                    #     )

                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {
                            "avg_update_norm1": update_norm / num_params,
                            "avg_grad_norm1": grad_norm / num_params,
                        },
                        "debug/",
                    )
                    prev_params = curr_params

                    # # gradient flow
                    # ave_grads = []
                    # max_grads= []
                    # layers = []
                    # for n, p in model.named_parameters():
                    #     if (p.requires_grad) and (p.grad is not None): # and ("bias" not in n):
                    #         layers.append(n)
                    #         ave_grads.append(p.grad.abs().mean().item())
                    #         max_grads.append(p.grad.abs().max().item())

                    # grad_flow = {"layers": layers, "ave_grads": ave_grads, "max_grads": max_grads}
                    # grad_flows.append(grad_flow)

                model.train()

            # Logging
            if batch_idx % config.evaluate_every == 0:
                model.eval()
                with torch.no_grad():
                    reports = None
                    for data in test_loader:
                        data, presence, target = [
                            d.to(device).float() for d in data
                        ]
                        # if config.data_config == "configs/constellation/constellation.py":
                        #     if config.global_rotation_angle != 0.0:
                        #         data = rotate(data, config.global_rotation_angle)
                        outputs = model([data, presence], target)

                        if reports is None:
                            reports = {
                                k: v.detach().clone().cpu()
                                for k, v in outputs.reports.items()
                            }
                        else:
                            for k, v in outputs.reports.items():
                                reports[k] += v.detach().clone().cpu()

                    # XXX: slightly incorrect since mini-batch sizes can vary
                    # (if batch_size doesn't divide the test set size), but
                    # approximately correct.
                    for k, v in reports.items():
                        reports[k] = v / len(test_loader)

                    reports = parse_reports(reports)
                    reports["time"] = time.time() - start_t
                    if report_all == {}:
                        report_all = deepcopy(reports)

                        for d in reports.keys():
                            report_all[d] = [report_all[d]]
                    else:
                        for d in reports.keys():
                            report_all[d].append(reports[d])

                    log_tensorboard(summary_writer, train_iter, reports,
                                    "test/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="test",
                    )

                model.train()

            train_iter += 1

        if epoch % config.save_check_points == 0:
            save_checkpoint(checkpoint_name,
                            train_iter,
                            model,
                            model_opt,
                            loss=outputs.loss)

        dd.io.save(logdir + "/results_dict_train.h5", train_reports)
        dd.io.save(logdir + "/results_dict.h5", report_all)

        # if config.debug:
        #     # dd.io.save(logdir + "/grad_flows.h5", grad_flows)

    save_checkpoint(checkpoint_name,
                    train_iter,
                    model,
                    model_opt,
                    loss=outputs.loss)
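
The SE(2) augmentation in the last example relies on a `rotate` helper. A plausible sketch for constellations stored as (B, N, 2) point sets, assuming rotation about the origin; the real helper may differ:

import math
import torch

def rotate_sketch(points, angle):
    # Rotate (B, N, 2) point sets counter-clockwise by `angle` radians.
    c, s = math.cos(angle), math.sin(angle)
    R = torch.tensor([[c, -s], [s, c]],
                     dtype=points.dtype, device=points.device)
    return points @ R.T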