Example #1
def main():
    # Parse flags
    cfg = forge.config()
    cfg.num_workers = 0

    # Set manual seed
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Get data loaders
    train_loader, _, _ = fet.load(cfg.data_config, cfg)

    # Visualise
    for x in train_loader:
        fig, axes = plt.subplots(1, cfg.batch_size, figsize=(20, 10))

        img = x['input']
        for b_idx in range(cfg.batch_size):
            np_img = np.moveaxis(img.data.numpy()[b_idx], 0, -1)
            if img.shape[1] == 1:
                axes[b_idx].imshow(
                    np_img[:, :, 0], norm=NoNorm(), cmap='gray')
            elif img.shape[1] == 3:
                axes[b_idx].imshow(np_img)
            axes[b_idx].axis('off')

        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()
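
Example #1 reads cfg.seed, cfg.batch_size, cfg.num_workers and cfg.data_config without defining them; the sketch below shows the kind of flag file it assumes, using the same forge flags API as Example #3. The import path and default values are assumptions; only the flag names are taken from how the script uses them.

from forge import flags

flags.DEFINE_string('data_config', 'configs/my_data.py',
                    'Path to the data config loaded via fet.load.')
flags.DEFINE_integer('batch_size', 8, 'Mini-batch size.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('num_workers', 4, 'Number of data loading workers.')
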
Example #2
def main():
    # Parse flags
    config = forge.config()
    fet.EXPERIMENT_FOLDER = config.model_dir
    fet.FPRINT_FILE = 'fid_evaluation.txt'
    config.shuffle_test = True

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Using GPU?
    if torch.cuda.is_available() and config.gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        config.gpu = False
        torch.set_default_tensor_type('torch.FloatTensor')
    fet.print_flags()

    # Load data
    _, _, test_loader = fet.load(config.data_config, config)

    #  Load model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)
    fprint(model)
    # Put model on GPU
    if config.gpu:
        model = model.cuda()

    # Compute FID
    fid_from_model(model, test_loader, config.batch_size,
                   config.num_fid_images, config.feat_dim, config.img_dir)
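
The two pixel_coords entries are popped above, presumably because they are auxiliary coordinate buffers that do not need to match the checkpoint. A small self-contained sketch (with a toy nn.Linear standing in for the real model) shows that load_state_dict(strict=False) would tolerate such extra keys instead of popping them by hand:

import torch
import torch.nn as nn

# Toy stand-in for the real model; demonstrates that strict=False simply
# reports mismatching keys rather than raising an error.
toy = nn.Linear(4, 2)
state = toy.state_dict()
state['pixel_coords.g_1'] = torch.zeros(1)  # extra key, as in the checkpoint
result = toy.load_state_dict(state, strict=False)
print(result.unexpected_keys)  # ['pixel_coords.g_1']
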
Example #3
                    'Top directory for all experimental results.')
flags.DEFINE_string('run_name', 'mnist',
                    'Name of this job and name of results folder.')
flags.DEFINE_boolean('resume', False, 'Tries to resume a job if True.')

# Logging config
flags.DEFINE_integer('report_loss_every', 100,
                     'Number of iterations between reporting minibatch loss.')
flags.DEFINE_integer('train_epochs', 20, 'Maximum number of training epochs.')

# Experiment config
flags.DEFINE_integer('batch_size', 32, 'Mini-batch size.')
flags.DEFINE_float('learning_rate', 1e-5, 'SGD learning rate.')

# Parse flags
config = forge.config()

# Prepare environment
logdir = osp.join(config.results_dir, config.run_name)
logdir, resume_checkpoint = fet.init_checkpoint(logdir, config.data_config,
                                                config.model_config,
                                                config.resume)
checkpoint_name = osp.join(logdir, 'model.ckpt')

# Load data
train_loader = fet.load(config.data_config, config)
# Load model
model = fet.load(config.model_config, config)

# Print flags
fet.print_flags()
Example #4
def main():
    # Parse flags
    config = forge.config()
    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    pretrained_flags.batch_size = 1
    pretrained_flags.gpu = False
    pretrained_flags.debug = True
    fet.print_flags()

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    #  Load model
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    for _ in range(100):
        y, stats = model.sample(1, pretrained_flags.K_steps)
        fig, axes = plt.subplots(nrows=4, ncols=1 + pretrained_flags.K_steps)

        # Generated
        plot(axes, 0, 0, y, title='Generated scene', fontsize=12)
        # Empty plots
        plot(axes, 1, 0, fontsize=12)
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K generation steps in separate subfigures
        for step in range(pretrained_flags.K_steps):
            x_step = stats['x_k'][step]
            m_step = stats['log_m_k'][step].exp()
            mx_step = stats['mx_k'][step]
            if 'log_s_k' in stats:
                s_step = stats['log_s_k'][step].exp()
            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes,
                 2,
                 1 + step,
                 m_step,
                 pre + f'k={step+1}',
                 True,
                 fontsize=12)
            if 'log_s_k' in stats:
                pre = 'Scope ' if step == 0 else ''
                plot(axes,
                     3,
                     1 + step,
                     s_step,
                     pre + f'k={step+1}',
                     True,
                     axis=step == 0,
                     fontsize=12)

        # Beautify and show figure
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()
Example #5
def main():
    # Parse flags
    cfg = forge.config()
    cfg.num_workers = 0

    # Set manual seed
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Get data loaders
    train_loader, _, _ = fet.load(cfg.data_config, cfg)

    # Optimally distinct RGB colour palette (15 colours)
    colours = json.load(open('utils/colour_palette15.json'))

    # Visualise
    for x in train_loader:
        fig, axes = plt.subplots(2, cfg.batch_size, figsize=(20, 10))

        for f_idx, field in enumerate(['input', 'instances']):
            for b_idx in range(cfg.batch_size):
                axes[f_idx, b_idx].axis('off')

            if field not in x:
                continue
            img = x[field]

            # Colour instance masks
            if field == 'instances':
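                # Map each integer instance id to a distinct palette colour;
                # the three single-channel maps are then concatenated into
                # one RGB image per batch element.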
                img_list = []
                for b_idx in range(img.shape[0]):
                    instances = img[b_idx, :, :, :]
                    img_r = torch.zeros_like(instances)
                    img_g = torch.zeros_like(instances)
                    img_b = torch.zeros_like(instances)
                    ins_idx = 0
                    for ins in range(instances.max().numpy()):
                        ins_map = instances == ins + 1
                        if ins_map.any():
                            img_r[ins_map] = colours['palette'][ins_idx][0]
                            img_g[ins_map] = colours['palette'][ins_idx][1]
                            img_b[ins_map] = colours['palette'][ins_idx][2]
                            ins_idx += 1
                    img_list.append(torch.cat([img_r, img_g, img_b], dim=0))
                img = torch.stack(img_list, dim=0)

            for b_idx in range(cfg.batch_size):
                np_img = np.moveaxis(img.data.numpy()[b_idx], 0, -1)
                if img.shape[1] == 1:
                    axes[f_idx, b_idx].imshow(np_img[:, :, 0],
                                              norm=NoNorm(),
                                              cmap='gray')
                elif img.shape[1] == 3:
                    axes[f_idx, b_idx].imshow(np_img)

        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()
Example #6
def main():
    # Parse flags
    config = forge.config()
    config.batch_size = 1
    config.load_instances = True
    fet.print_flags()

    # Restore original model flags
    pretrained_flags = AttrDict(
        fet.json_load(os.path.join(config.model_dir, 'flags.json')))

    # Get validation loader
    train_loader, val_loader, test_loader = fet.load(config.data_config,
                                                     config)
    fprint(f"Split: {config.split}")
    if config.split == 'train':
        batch_loader = train_loader
    elif config.split == 'val':
        batch_loader = val_loader
    elif config.split == 'test':
        batch_loader = test_loader
    # Shuffle and prefetch to get same data for different models
    if 'gqn' not in config.data_config:
        batch_loader = torch.utils.data.DataLoader(batch_loader.dataset,
                                                   batch_size=1,
                                                   num_workers=0,
                                                   shuffle=True)
    # Prefetch batches
    prefetched_batches = []
    for i, x in enumerate(batch_loader):
        if i == config.num_images:
            break
        prefetched_batches.append(x)

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    fprint(model)
    model_path = os.path.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)

    # Set experiment folder and fprint file for logging
    fet.EXPERIMENT_FOLDER = config.model_dir
    fet.FPRINT_FILE = 'segmentation_metrics.txt'

    # Compute metrics
    model.eval()
    ari_fg_list, sc_fg_list, msc_fg_list = [], [], []
    with torch.no_grad():
        for i, x in enumerate(tqdm(prefetched_batches)):
            _, _, stats, _, _ = model(x['input'])
            # ARI
            ari_fg, _ = average_ari(stats.log_m_k,
                                    x['instances'],
                                    foreground_only=True)
            # Segmentation covering - foreground only
            gt_instances = x['instances'].clone()
            gt_instances[gt_instances == 0] = -100
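            # stats.log_m_k is a list of K per-slot log-masks; stacking them on
            # dim=1 and taking the argmax assigns each pixel to its most
            # probable slot.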
            ins_preds = torch.argmax(torch.stack(stats.log_m_k, dim=1), dim=1)
            sc_fg = average_segcover(gt_instances, ins_preds)
            msc_fg = average_segcover(gt_instances, ins_preds, False)
            # Recording
            ari_fg_list.append(ari_fg)
            sc_fg_list.append(sc_fg)
            msc_fg_list.append(msc_fg)

    # Print average metrics
    fprint(f"Average FG ARI: {sum(ari_fg_list)/len(ari_fg_list)}")
    fprint(f"Average FG SegCover: {sum(sc_fg_list)/len(sc_fg_list)}")
    fprint(f"Average FG MeanSegCover: {sum(msc_fg_list)/len(msc_fg_list)}")
Example #7
def main():
    # ------------------------
    # SETUP
    # ------------------------

    # Parse flags
    config = forge.config()
    if config.debug:
        config.num_workers = 0
        config.batch_size = 2

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Setup checkpoint or resume
    logdir = osp.join(config.results_dir, config.run_name)
    logdir, resume_checkpoint = fet.init_checkpoint(logdir, config.data_config,
                                                    config.model_config,
                                                    config.resume)
    checkpoint_name = osp.join(logdir, 'model.ckpt')

    # Using GPU(S)?
    if torch.cuda.is_available() and config.gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        config.gpu = False
        torch.set_default_tensor_type('torch.FloatTensor')
    fprint(f"Use GPU: {config.gpu}")
    if config.gpu and config.multi_gpu and torch.cuda.device_count() > 1:
        fprint(f"Using {torch.cuda.device_count()} GPUs!")
        config.num_workers = torch.cuda.device_count() * config.num_workers
    else:
        config.multi_gpu = False

    # Print flags
    # fet.print_flags()
    # TODO(martin) make this cleaner
    fprint(json.dumps(fet._flags.FLAGS.__flags, indent=4, sort_keys=True))

    # Setup TensorboardX SummaryWriter
    writer = SummaryWriter(logdir)

    # Load data
    train_loader, val_loader, test_loader = fet.load(config.data_config,
                                                     config)
    num_elements = 3 * config.img_size**2  # Assume three input channels

    # Load model
    model = fet.load(config.model_config, config)
    fprint(model)
    if config.geco:
        # Goal is specified per pixel & channel so it doesn't need to
        # be changed for different resolutions etc.
        geco_goal = config.g_goal * num_elements
        # Scale step size to get similar update at different resolutions
        geco_lr = config.g_lr * (64**2 / config.img_size**2)
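        # Worked example with hypothetical numbers: for 64x64 RGB inputs,
        # num_elements = 3*64*64 = 12288, so a per-element goal of 0.5 gives
        # geco_goal = 6144; at 128x128 geco_lr is divided by 4 to keep the
        # effective update comparable across resolutions.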
        geco = GECO(geco_goal, geco_lr, config.g_alpha, config.g_init,
                    config.g_min, config.g_speedup)
        beta = geco.beta
    else:
        beta = torch.tensor(config.beta)

    # Setup optimiser
    if config.optimiser == 'rmsprop':
        optimiser = optim.RMSprop(model.parameters(), config.learning_rate)
    elif config.optimiser == 'adam':
        optimiser = optim.Adam(model.parameters(), config.learning_rate)
    elif config.optimiser == 'sgd':
        optimiser = optim.SGD(model.parameters(), config.learning_rate, 0.9)

    # Try to restore model and optimiser from checkpoint
    iter_idx = 0
    if resume_checkpoint is not None:
        fprint(f"Restoring checkpoint from {resume_checkpoint}")
        checkpoint = torch.load(resume_checkpoint, map_location='cpu')
        # Restore model & optimiser
        model_state_dict = checkpoint['model_state_dict']
        model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                             None)
        model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                             None)
        model.load_state_dict(model_state_dict)
        optimiser.load_state_dict(checkpoint['optimiser_state_dict'])
        # Restore GECO
        if config.geco and 'beta' in checkpoint:
            geco.beta = checkpoint['beta']
        if config.geco and 'err_ema' in checkpoint:
            geco.err_ema = checkpoint['err_ema']
        # Update starting iter
        iter_idx = checkpoint['iter_idx'] + 1
    fprint(f"Starting training at iter = {iter_idx}")

    # Push model to GPU(s)?
    if config.multi_gpu:
        fprint("Wrapping model in DataParallel.")
        model = nn.DataParallel(model)
    if config.gpu:
        fprint("Pushing model to GPU.")
        model = model.cuda()
        if config.geco:
            geco.to_cuda()

    # ------------------------
    # TRAINING
    # ------------------------

    model.train()
    timer = time.time()
    while iter_idx <= config.train_iter:
        for train_batch in train_loader:
            # Parse data
            train_input = train_batch['input']
            if config.gpu:
                train_input = train_input.cuda()

            # Forward propagation
            optimiser.zero_grad()
            output, losses, stats, att_stats, comp_stats = model(train_input)

            # Reconstruction error
            err = losses.err.mean(0)
            # KL divergences
            kl_m, kl_l = torch.tensor(0), torch.tensor(0)
            # -- KL stage 1
            if 'kl_m' in losses:
                kl_m = losses.kl_m.mean(0)
            elif 'kl_m_k' in losses:
                kl_m = torch.stack(losses.kl_m_k, dim=1).mean(dim=0).sum()
            # -- KL stage 2
            if 'kl_l' in losses:
                kl_l = losses.kl_l.mean(0)
            elif 'kl_l_k' in losses:
                kl_l = torch.stack(losses.kl_l_k, dim=1).mean(dim=0).sum()

            # Compute ELBO
            elbo = (err + kl_l + kl_m).detach()
            err_new = err.detach()
            kl_new = (kl_m + kl_l).detach()
            # Compute MSE / RMSE
            mse_batched = ((train_input - output)**2).mean((1, 2, 3)).detach()
            rmse_batched = mse_batched.sqrt()
            mse, rmse = mse_batched.mean(0), rmse_batched.mean(0)

            # Main objective
            if config.geco:
                loss = geco.loss(err, kl_l + kl_m)
                beta = geco.beta
            else:
                if config.beta_warmup:
                    # Increase beta linearly over 20% of training
                    beta = config.beta * iter_idx / (0.2 * config.train_iter)
                    beta = torch.tensor(beta).clamp(0, config.beta)
                else:
                    beta = config.beta
                loss = err + beta * (kl_l + kl_m)

            # Backprop and optimise
            loss.backward()
            optimiser.step()

            # Heartbeat log
            if (iter_idx % config.report_loss_every == 0
                    or float(elbo) > ELBO_DIV or config.debug):
                # Print output and write to file
                ps = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                ps += f' {config.run_name} | '
                ps += f'[{iter_idx}/{config.train_iter:.0e}]'
                ps += f' elb: {float(elbo):.0f} err: {float(err):.0f} '
                if 'kl_m' in losses or 'kl_m_k' in losses:
                    ps += f' klm: {float(kl_m):.1f}'
                if 'kl_l' in losses or 'kl_l_k' in losses:
                    ps += f' kll: {float(kl_l):.1f}'
                ps += f' bet: {float(beta):.1e}'
                s_per_b = (time.time() - timer)
                if not config.debug:
                    s_per_b /= config.report_loss_every
                timer = time.time()  # Reset timer
                ps += f' - {s_per_b:.2f} s/b'
                fprint(ps)

                # TensorBoard logging
                # -- Optimisation stats
                writer.add_scalar('optim/beta', beta, iter_idx)
                writer.add_scalar('optim/s_per_batch', s_per_b, iter_idx)
                if config.geco:
                    writer.add_scalar('optim/geco_err_ema', geco.err_ema,
                                      iter_idx)
                    writer.add_scalar('optim/geco_err_ema_element',
                                      geco.err_ema / num_elements, iter_idx)
                # -- Main loss terms
                writer.add_scalar('train/err', err, iter_idx)
                writer.add_scalar('train/err_element', err / num_elements,
                                  iter_idx)
                writer.add_scalar('train/kl_m', kl_m, iter_idx)
                writer.add_scalar('train/kl_l', kl_l, iter_idx)
                writer.add_scalar('train/elbo', elbo, iter_idx)
                writer.add_scalar('train/loss', loss, iter_idx)
                writer.add_scalar('train/mse', mse, iter_idx)
                writer.add_scalar('train/rmse', rmse, iter_idx)
                # -- Per step loss terms
                for key in ['kl_l_k', 'kl_m_k']:
                    if key not in losses: continue
                    for step, val in enumerate(losses[key]):
                        writer.add_scalar(f'train_steps/{key}{step}',
                                          val.mean(0), iter_idx)
                # -- Attention stats
                if config.log_distributions and att_stats is not None:
                    for key in ['mu_k', 'sigma_k', 'pmu_k', 'psigma_k']:
                        if key not in att_stats: continue
                        for step, val in enumerate(att_stats[key]):
                            writer.add_histogram(f'att_{key}_{step}', val,
                                                 iter_idx)
                # -- Component stats
                if config.log_distributions and comp_stats is not None:
                    for key in ['mu_k', 'sigma_k', 'pmu_k', 'psigma_k']:
                        if key not in comp_stats: continue
                        for step, val in enumerate(comp_stats[key]):
                            writer.add_histogram(f'comp_{key}_{step}', val,
                                                 iter_idx)

            # Save checkpoints
            ckpt_freq = config.train_iter / config.num_checkpoints
            if iter_idx % ckpt_freq == 0:
                ckpt_file = '{}-{}'.format(checkpoint_name, iter_idx)
                fprint(f"Saving model training checkpoint to: {ckpt_file}")
                if config.multi_gpu:
                    model_state_dict = model.module.state_dict()
                else:
                    model_state_dict = model.state_dict()
                ckpt_dict = {
                    'iter_idx': iter_idx,
                    'model_state_dict': model_state_dict,
                    'optimiser_state_dict': optimiser.state_dict(),
                    'elbo': elbo
                }
                if config.geco:
                    ckpt_dict['beta'] = geco.beta
                    ckpt_dict['err_ema'] = geco.err_ema
                torch.save(ckpt_dict, ckpt_file)

            # Run validation and log images
            if (iter_idx % config.run_validation_every == 0
                    or float(elbo) > ELBO_DIV):
                # Weight and gradient histograms
                if config.log_grads_and_weights:
                    for name, param in model.named_parameters():
                        writer.add_histogram(f'weights/{name}', param.data,
                                             iter_idx)
                        writer.add_histogram(f'grads/{name}', param.grad,
                                             iter_idx)
                # TensorboardX logging - images
                visualise_inference(model, train_batch, writer, 'train',
                                    iter_idx)
                # Validation
                fprint("Running validation...")
                eval_model = model.module if config.multi_gpu else model
                evaluation(eval_model,
                           val_loader,
                           writer,
                           config,
                           iter_idx,
                           N_eval=config.N_eval)

            # Increment counter
            iter_idx += 1
            if iter_idx > config.train_iter:
                break

            # Exit if training has diverged
            if elbo.item() > ELBO_DIV:
                fprint(f"ELBO: {elbo.item()}")
                fprint(
                    f"ELBO has exceeded {ELBO_DIV} - training has diverged.")
                sys.exit()

    # ------------------------
    # TESTING
    # ------------------------

    # Save final checkpoint
    ckpt_file = '{}-{}'.format(checkpoint_name, 'FINAL')
    fprint(f"Saving model training checkpoint to: {ckpt_file}")
    if config.multi_gpu:
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()
    ckpt_dict = {
        'iter_idx': iter_idx,
        'model_state_dict': model_state_dict,
        'optimiser_state_dict': optimiser.state_dict()
    }
    if config.geco:
        ckpt_dict['beta'] = geco.beta
        ckpt_dict['err_ema'] = geco.err_ema
    torch.save(ckpt_dict, ckpt_file)

    # Test evaluation
    fprint("STARTING TESTING...")
    eval_model = model.module if config.gpu and config.multi_gpu else model
    final_elbo = evaluation(eval_model,
                            test_loader,
                            None,
                            config,
                            iter_idx,
                            N_eval=config.N_eval)
    fprint(f"TEST ELBO = {float(final_elbo)}")

    # FID computation
    try:
        fid_from_model(model, test_loader, img_dir=osp.join('/tmp', logdir))
    except NotImplementedError:
        fprint("Sampling not implemented for this model.")

    # Close writer
    writer.close()
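
The GECO object used in this example is constructed elsewhere in the project; the class below is only an illustrative stand-in for the usual GECO recipe (an EMA of the reconstruction-error constraint and a multiplicative update of the KL weight). The update rule and default values are assumptions, not the project's actual implementation.

import torch

class SimpleGECO:
    """Illustrative GECO-style state: loss = err + beta * kl with adaptive beta."""

    def __init__(self, goal, step_size, alpha=0.99, beta_init=1.0,
                 beta_min=1e-10):
        self.goal = goal            # target (summed) reconstruction error
        self.step_size = step_size  # multiplier learning rate
        self.alpha = alpha          # EMA decay for the error
        self.beta = torch.tensor(beta_init)
        self.beta_min = beta_min
        self.err_ema = None

    def loss(self, err, kl):
        with torch.no_grad():
            if self.err_ema is None:
                self.err_ema = err.detach()
            else:
                self.err_ema = (self.alpha * self.err_ema
                                + (1.0 - self.alpha) * err.detach())
            constraint = self.goal - self.err_ema
            # Shrink beta while the error is above the goal, grow it otherwise.
            self.beta = (self.beta *
                         torch.exp(self.step_size * constraint)).clamp(
                             min=self.beta_min)
        return err + self.beta * kl
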
Example #8
def main():
    config = forge.config()
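    # `device` and `collate_fn` used below are assumed to be defined at module
    # level in the original script (e.g. a torch.device and a dataset-specific
    # batch collate function); they are not part of this snippet.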

    # print(config.__dict__["__flags"])

    with open(osp.join(config.load_path, "flags.json"), "r") as f:
        run_config = json.load(f)
    # print(run_config)

    config = types.SimpleNamespace(**{**config.__dict__["__flags"], **run_config})

    print(config)

    # Load data
    dataloaders, num_species, charge_scale, ds_stats, data_name = fet.load(
        config.data_config, config=config
    )

    test_dataloader = DataLoader(
        dataloaders["test"].dataset,
        batch_size=300,
        num_workers=0,
        shuffle=False,
        pin_memory=False,
        collate_fn=collate_fn,
        drop_last=False,
    )

    actuals = torch.Tensor().to(device)
    median, mad = ds_stats

    with torch.no_grad():
        for data in test_dataloader:
            data = {k: v.to(device) for k, v in data.items()}
            actuals = torch.cat([actuals, data[config.task]])

    config.num_species = num_species
    config.charge_scale = charge_scale
    config.ds_stats = ds_stats

    # Load model
    model, model_name = fet.load(config.model_config, config)
    model.to(device)

    config.charge_scale = float(config.charge_scale.numpy())
    config.ds_stats = [float(stat.numpy()) for stat in config.ds_stats]

    load_checkpoint(osp.join(config.load_path, config.checkpoint), model)

    multisample = multisample_model(
        model, test_dataloader, actuals, config.samples, fix_seed=True
    )

    multisample_maes = (
        (multisample - actuals.unsqueeze(1)).abs().mean(dim=0).cpu().numpy()
    )

    multisample_average = multisample.mean(dim=1)
    multisample_average_mae = (multisample_average - actuals).abs().mean().cpu().numpy()

    other_samples = []
    for i in range(config.samples):
        with torch.no_grad():
            model.eval()
            test_mae = 0.0
            for data in test_dataloader:
                data = {k: v.to(device) for k, v in data.items()}
                outputs = model(data, compute_loss=True)
                test_mae = test_mae + outputs.mae

            other_samples.append(test_mae / len(test_dataloader))

    with open(
        osp.join(config.load_path, "sample_results_" + config.checkpoint + ".txt"), "w"
    ) as f:
        f.write("Sampled MAEs: " + str(multisample_maes) + "\r\n")
        f.write(
            "Spike:: "
            + "Min: "
            + str(multisample_maes.min())
            + ", Mean: "
            + str(multisample_maes.mean())
            + ", Max: "
            + str(multisample_maes.max())
            + "\r\n"
        )
        f.write("Averaged outputs MAE: " + str(multisample_average_mae) "\r\n")
        f.write("Other samples: " + str(other_samples))
Example #9
def main():
    # Parse flags
    config = forge.config()
    fet.print_flags()
    # Restore flags of pretrained model
    flag_path = osp.join(config.model_dir, 'flags.json')
    fprint(f"Restoring flags from {flag_path}")
    pretrained_flags = AttrDict(fet.json_load(flag_path))
    pretrained_flags.debug = True

    # Fix seeds. Always first thing to be done after parsing the config!
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # Make CUDA operations deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load data
    config.batch_size = 1
    _, _, test_loader = fet.load(config.data_config, config)

    # Load model
    model = fet.load(config.model_config, pretrained_flags)
    model_path = osp.join(config.model_dir, config.model_file)
    fprint(f"Restoring model from {model_path}")
    checkpoint = torch.load(model_path, map_location='cpu')
    model_state_dict = checkpoint['model_state_dict']
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_1',
                         None)
    model_state_dict.pop('comp_vae.decoder_module.seq.0.pixel_coords.g_2',
                         None)
    model.load_state_dict(model_state_dict)
    fprint(model)

    # Visualise
    model.eval()
    for count, batch in enumerate(test_loader):
        if count >= config.num_images:
            break

        # Forward pass
        output, _, stats, _, _ = model(batch['input'])
        # Set up figure
        fig, axes = plt.subplots(nrows=4, ncols=1 + pretrained_flags.K_steps)

        # Input and reconstruction
        plot(axes, 0, 0, batch['input'], title='Input image', fontsize=12)
        plot(axes, 1, 0, output, title='Reconstruction', fontsize=12)
        # Empty plots
        plot(axes, 2, 0, fontsize=12)
        plot(axes, 3, 0, fontsize=12)

        # Put K reconstruction steps into separate subfigures
        x_k = stats['x_r_k']
        log_m_k = stats['log_m_k']
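        # Masked reconstructions: weight each slot's RGB output by its mask
        # (masks are stored in log space, hence the .exp()).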
        mx_k = [x * m.exp() for x, m in zip(x_k, log_m_k)]
        log_s_k = stats['log_s_k'] if 'log_s_k' in stats else None
        for step in range(pretrained_flags.K_steps):
            mx_step = mx_k[step]
            x_step = x_k[step]
            m_step = log_m_k[step].exp()
            if log_s_k:
                s_step = log_s_k[step].exp()

            pre = 'Mask x RGB ' if step == 0 else ''
            plot(axes, 0, 1 + step, mx_step, pre + f'k={step+1}', fontsize=12)
            pre = 'RGB ' if step == 0 else ''
            plot(axes, 1, 1 + step, x_step, pre + f'k={step+1}', fontsize=12)
            pre = 'Mask ' if step == 0 else ''
            plot(axes,
                 2,
                 1 + step,
                 m_step,
                 pre + f'k={step+1}',
                 True,
                 fontsize=12)
            if log_s_k:
                pre = 'Scope ' if step == 0 else ''
                plot(axes,
                     3,
                     1 + step,
                     s_step,
                     pre + f'k={step+1}',
                     True,
                     axis=step == 0,
                     fontsize=12)

        # Beautify and show figure
        plt.subplots_adjust(wspace=0.05, hspace=0.15)
        manager = plt.get_current_fig_manager()
        manager.resize(*manager.window.maxsize())
        plt.show()
Example #10
def main():
    # Parse flags
    config = forge.config()

    # Load data
    dataloaders, num_species, charge_scale, ds_stats, data_name = fet.load(
        config.data_config, config=config)

    config.num_species = num_species
    config.charge_scale = charge_scale
    config.ds_stats = ds_stats

    # Load model
    model, model_name = fet.load(config.model_config, config)
    model.to(device)

    config.charge_scale = float(config.charge_scale.numpy())
    config.ds_stats = [float(stat.numpy()) for stat in config.ds_stats]

    # Prepare environment
    run_name = (config.run_name + "_bs" + str(config.batch_size) + "_lr" +
                str(config.learning_rate))

    if config.batch_fit != 0:
        run_name += "_bf" + str(config.batch_fit)

    if config.lr_schedule != "none":
        run_name += "_" + config.lr_schedule

    # Print flags
    fet.print_flags()

    # Setup optimizer
    model_params = model.predictor.parameters()

    opt_learning_rate = config.learning_rate
    model_opt = torch.optim.Adam(
        model_params,
        lr=opt_learning_rate,
        betas=(config.beta1, config.beta2),
        eps=1e-8,
    )
    # model_opt = torch.optim.SGD(model_params, lr=opt_learning_rate)

    # Cosine annealing learning rate
    if config.lr_schedule == "cosine":
        cos = cosLr(config.train_epochs)
        lr_sched = lambda e: max(cos(e), config.lr_floor * config.learning_rate
                                 )
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "cosine_warmup":
        cos = cosLr(config.train_epochs)
        lr_sched = lambda e: max(
            min(e / (config.warmup_length * config.train_epochs), 1) * cos(e),
            config.lr_floor * config.learning_rate,
        )
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "quadratic_warmup":
        lr_sched = lambda e: min(e / (0.01 * config.train_epochs), 1) * (
            1.0 / sqrt(1.0 + 10000.0 * (e / config.train_epochs)
                       )  # finish at 1/100 of initial lr
        )
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    elif config.lr_schedule == "none":
        lr_sched = lambda e: 1.0
        lr_schedule = optim.lr_scheduler.LambdaLR(model_opt, lr_sched)
    else:
        raise ValueError(
            f"{config.lr_schedule} is not a recognised learning rate schedule")

    num_params = param_count(model)
    if config.parameter_count:
        for (name, parameter) in model.predictor.named_parameters():
            print(name, parameter.dtype)

        print(model)
        print("============================================================")
        print(f"{model_name} parameters: {num_params:.5e}")
        print("============================================================")
        # from torchsummary import summary

        # data = next(iter(dataloaders["train"]))

        # data = {k: v.to(device) for k, v in data.items()}
        # print(
        #     summary(
        #         model.predictor,
        #         data,
        #         batch_size=config.batch_size,
        #     )
        # )

        parameters = sum(parameter.numel()
                         for parameter in model.predictor.parameters())
        parameters_grad = sum(
            parameter.numel() if parameter.requires_grad else 0
            for parameter in model.predictor.parameters())
        print(f"Parameters: {parameters:,}")
        print(f"Parameters grad: {parameters_grad:,}")

        memory_allocations = []

        for batch_idx, data in enumerate(dataloaders["train"]):
            print(batch_idx)
            data = {k: v.to(device) for k, v in data.items()}

            model_opt.zero_grad()
            outputs = model(data, compute_loss=True)
            # torch.cuda.empty_cache()
            # memory_allocations.append(torch.cuda.memory_reserved() / 1024 / 1024 / 1024)
            # outputs.loss.backward()

        # Guard: the memory tracking lines above are commented out, so the
        # list may be empty.
        if memory_allocations:
            print(
                f"max memory reserved in one pass: "
                f"{max(memory_allocations):0.4}GB")
        sys.exit(0)

    else:
        print(f"{model_name} parameters: {num_params:.5e}")

    # set up results folders
    results_folder_name = osp.join(
        data_name,
        model_name,
        run_name,
    )

    logdir = osp.join(config.results_dir,
                      results_folder_name.replace(".", "_"))
    logdir, resume_checkpoint = fet.init_checkpoint(logdir, config.data_config,
                                                    config.model_config,
                                                    config.resume)

    checkpoint_name = osp.join(logdir, "model.ckpt")

    # Try to restore model and optimizer from checkpoint
    if resume_checkpoint is not None:
        start_epoch, best_valid_mae = load_checkpoint(resume_checkpoint, model,
                                                      model_opt, lr_schedule)
    else:
        start_epoch = 1
        best_valid_mae = 1e12

    train_iter = (start_epoch - 1) * (len(dataloaders["train"].dataset) //
                                      config.batch_size) + 1

    print("Starting training at epoch = {}, iter = {}".format(
        start_epoch, train_iter))

    # Setup tensorboard writing
    summary_writer = SummaryWriter(logdir)

    report_all = defaultdict(list)
    # Saving model at epoch 0 before training
    print("saving model at epoch 0 before training ... ")
    save_checkpoint(checkpoint_name, 0, model, model_opt, lr_schedule, 0.0)
    print("finished saving model at epoch 0 before training")

    if (config.debug and config.model_config
            == "configs/dynamics/eqv_transformer_model.py"):
        model_components = ([(0, [], "embedding_layer")] + list(
            chain.from_iterable((
                (k, [], f"ema_{k}"),
                (
                    k,
                    ["ema", "kernel", "location_kernel"],
                    f"ema_{k}_location_kernel",
                ),
                (
                    k,
                    ["ema", "kernel", "feature_kernel"],
                    f"ema_{k}_feature_kernel",
                ),
            ) for k in range(1, config.num_layers + 1))) +
                            [(config.num_layers + 2, [], "output_mlp")]
                            )  # components to track for debugging
        grad_flows = []

    if config.init_activations:
        activation_tracked = [(name, module)
                              for name, module in model.named_modules()
                              if isinstance(module, Expression)
                              | isinstance(module, nn.Linear)
                              | isinstance(module, MultiheadLinear)
                              | isinstance(module, MaskBatchNormNd)]
        activations = {}

        def save_activation(name, mod, inpt, otpt):
            if isinstance(inpt, tuple):
                if isinstance(inpt[0], list) | isinstance(inpt[0], tuple):
                    activations[name + "_inpt"] = inpt[0][1].detach().cpu()
                else:
                    if len(inpt) == 1:
                        activations[name + "_inpt"] = inpt[0].detach().cpu()
                    else:
                        activations[name + "_inpt"] = inpt[1].detach().cpu()
            else:
                activations[name + "_inpt"] = inpt.detach().cpu()

            if isinstance(otpt, tuple):
                if isinstance(otpt[0], list):
                    activations[name + "_otpt"] = otpt[0][1].detach().cpu()
                else:
                    if len(otpt) == 1:
                        activations[name + "_otpt"] = otpt[0].detach().cpu()
                    else:
                        activations[name + "_otpt"] = otpt[1].detach().cpu()
            else:
                activations[name + "_otpt"] = otpt.detach().cpu()

        for name, tracked_module in activation_tracked:
            tracked_module.register_forward_hook(partial(
                save_activation, name))

    # Training
    start_t = time.perf_counter()

    iters_per_epoch = len(dataloaders["train"])
    last_valid_loss = 1000.0
    for epoch in tqdm(range(start_epoch, config.train_epochs + 1)):
        model.train()

        for batch_idx, data in enumerate(dataloaders["train"]):
            data = {k: v.to(device) for k, v in data.items()}

            model_opt.zero_grad()
            outputs = model(data, compute_loss=True)

            outputs.loss.backward()
            if config.clip_grad:
                # Clip gradient L2-norm at 1
                torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.predictor.parameters(), 1.0)
            model_opt.step()

            if config.init_activations:
                model_opt.zero_grad()
                outputs = model(data, compute_loss=True)
                outputs.loss.backward()
                for name, activation in activations.items():
                    print(name)
                    summary_writer.add_histogram(f"activations/{name}",
                                                 activation.numpy(), 0)

                sys.exit(0)

            if config.log_train_values:
                reports = parse_reports(outputs.reports)
                if batch_idx % config.report_loss_every == 0:
                    log_tensorboard(summary_writer, train_iter, reports,
                                    "train/")
                    report_all = log_reports(report_all, train_iter, reports,
                                             "train")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(dataloaders["train"].dataset) // config.batch_size,
                        prefix="train",
                    )

            # Logging
            if batch_idx % config.evaluate_every == 0:
                model.eval()
                with torch.no_grad():
                    valid_mae = 0.0
                    for data in dataloaders["valid"]:
                        data = {k: v.to(device) for k, v in data.items()}
                        outputs = model(data, compute_loss=True)
                        valid_mae = valid_mae + outputs.mae
                model.train()

                outputs["reports"].valid_mae = valid_mae / len(
                    dataloaders["valid"])

                reports = parse_reports(outputs.reports)

                log_tensorboard(summary_writer, train_iter, reports, "valid")
                report_all = log_reports(report_all, train_iter, reports,
                                         "valid")
                print_reports(
                    reports,
                    start_t,
                    epoch,
                    batch_idx,
                    len(dataloaders["train"].dataset) // config.batch_size,
                    prefix="valid",
                )

                loss_diff = (last_valid_loss -
                             (valid_mae / len(dataloaders["valid"])).item())
                if config.find_spikes and loss_diff < -0.1:
                    save_checkpoint(
                        checkpoint_name + "_spike",
                        epoch,
                        model,
                        model_opt,
                        lr_schedule,
                        outputs.loss,
                    )

                last_valid_loss = (valid_mae /
                                   len(dataloaders["valid"])).item()

                if outputs["reports"].valid_mae < best_valid_mae:
                    save_checkpoint(
                        checkpoint_name,
                        "best_valid_mae",
                        model,
                        model_opt,
                        lr_schedule,
                        best_valid_mae,
                    )
                    best_valid_mae = outputs["reports"].valid_mae

            train_iter += 1

            # Step the LR schedule
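            # (a fractional epoch is passed so the per-epoch lambdas above are
            # evaluated continuously, once per training iteration)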
            lr_schedule.step(train_iter / iters_per_epoch)

            # Track stuff for debugging
            if config.debug:
                model.eval()
                with torch.no_grad():
                    curr_params = {
                        k: v.detach().clone()
                        for k, v in model.state_dict().items()
                    }

                    # updates norm
                    if train_iter == 1:
                        update_norm = 0
                    else:
                        update_norms = []
                        for (k_prev,
                             v_prev), (k_curr,
                                       v_curr) in zip(prev_params.items(),
                                                      curr_params.items()):
                            assert k_prev == k_curr
                            if (
                                    "tracked" not in k_prev
                            ):  # ignore batch norm tracking. TODO: should this be ignored? if not, fix!
                                update_norms.append(
                                    (v_curr - v_prev).norm(1).item())
                        update_norm = sum(update_norms)

                    # gradient norm
                    grad_norm = 0
                    for p in model.parameters():
                        try:
                            grad_norm += p.grad.norm(1)
                        except AttributeError:
                            pass

                    # weights norm
                    if (config.model_config ==
                            "configs/dynamics/eqv_transformer_model.py"):
                        model_norms = {}
                        for comp_name in model_components:
                            comp = get_component(model.predictor.net,
                                                 comp_name)
                            norm = get_average_norm(comp)
                            model_norms[comp_name[2]] = norm

                        log_tensorboard(
                            summary_writer,
                            train_iter,
                            model_norms,
                            "debug/avg_model_norms/",
                        )

                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {
                            "avg_update_norm1": update_norm / num_params,
                            "avg_grad_norm1": grad_norm / num_params,
                        },
                        "debug/",
                    )
                    prev_params = curr_params

                    # gradient flow
                    ave_grads = []
                    max_grads = []
                    layers = []
                    for n, p in model.named_parameters():
                        if (p.requires_grad) and ("bias" not in n):
                            layers.append(n)
                            ave_grads.append(p.grad.abs().mean().item())
                            max_grads.append(p.grad.abs().max().item())

                    grad_flow = {
                        "layers": layers,
                        "ave_grads": ave_grads,
                        "max_grads": max_grads,
                    }
                    grad_flows.append(grad_flow)

                model.train()

        # Test model at end of batch
        with torch.no_grad():
            model.eval()
            test_mae = 0.0
            for data in dataloaders["test"]:
                data = {k: v.to(device) for k, v in data.items()}
                outputs = model(data, compute_loss=True)
                test_mae = test_mae + outputs.mae

        outputs["reports"].test_mae = test_mae / len(dataloaders["test"])

        reports = parse_reports(outputs.reports)

        log_tensorboard(summary_writer, train_iter, reports, "test")
        report_all = log_reports(report_all, train_iter, reports, "test")

        print_reports(
            reports,
            start_t,
            epoch,
            batch_idx,
            len(dataloaders["train"].dataset) // config.batch_size,
            prefix="test",
        )

        reports = {
            "lr": lr_schedule.get_lr()[0],
            "time": time.perf_counter() - start_t,
            "epoch": epoch,
        }

        log_tensorboard(summary_writer, train_iter, reports, "stats")
        report_all = log_reports(report_all, train_iter, reports, "stats")

        # Save the reports
        dd.io.save(logdir + "/results_dict.h5", report_all)

        # Save a checkpoint
        if epoch % config.save_check_points == 0:
            save_checkpoint(
                checkpoint_name,
                epoch,
                model,
                model_opt,
                lr_schedule,
                best_valid_mae,
            )
            if config.only_store_last_checkpoint:
                delete_checkpoint(checkpoint_name,
                                  epoch - config.save_check_points)

    save_checkpoint(
        checkpoint_name,
        "final",
        model,
        model_opt,
        lr_schedule,
        outputs.loss,
    )
Example #11
def main():
    # Parse flags
    config = forge.config()

    # Set device
    if torch.cuda.is_available():
        device = f"cuda:{config.device}"
        torch.cuda.set_device(device)
    else:
        device = "cpu"

    # Load data
    dataloaders, data_name = fet.load(config.data_config, config=config)

    train_loader = dataloaders["train"]
    test_loader = dataloaders["test"]
    val_loader = dataloaders["val"]

    # Load model
    model, model_name = fet.load(config.model_config, config)
    model = model.to(device)
    print(model)

    # Prepare environment
    params_in_run_name = [
        ("batch_size", "bs"),
        ("learning_rate", "lr"),
        ("num_heads", "nheads"),
        ("num_layers", "nlayers"),
        ("dim_hidden", "hdim"),
        ("kernel_dim", "kdim"),
        ("location_attention", "locatt"),
        ("model_seed", "mseed"),
        ("lr_schedule", "lrsched"),
        ("layer_norm", "ln"),
        ("batch_norm", "bn"),
        ("channel_width", "width"),
        ("attention_fn", "attfn"),
        ("output_mlp_scale", "mlpscale"),
        ("train_epochs", "epochs"),
        ("block_norm", "block"),
        ("kernel_type", "ktype"),
        ("architecture", "arch"),
        ("activation_function", "act"),
        ("space_dim", "spacedim"),
        ("num_particles", "prtcls"),
        ("n_train", "ntrain"),
        ("group", "group"),
        ("lift_samples", "ls"),
    ]

    run_name = ""
    for config_param in params_in_run_name:
        attr = config_param[0]
        abbrev = config_param[1]

        if hasattr(config, attr):
            run_name += abbrev
            run_name += str(getattr(config, attr))
            run_name += "_"

    if config.clip_grad_norm:
        run_name += "clipnorm" + str(config.max_grad_norm) + "_"

    results_folder_name = osp.join(
        data_name,
        model_name,
        config.run_name,
        run_name,
    )

    logdir = osp.join(config.results_dir,
                      results_folder_name.replace(".", "_"))
    logdir, resume_checkpoint = fet.init_checkpoint(logdir, config.data_config,
                                                    config.model_config,
                                                    config.resume)

    checkpoint_name = osp.join(logdir, "model.ckpt")

    # Print flags
    fet.print_flags()

    # Setup optimizer
    model_params = model.predictor.parameters()

    opt_learning_rate = config.learning_rate
    model_opt = torch.optim.Adam(model_params,
                                 lr=opt_learning_rate,
                                 betas=(config.beta1, config.beta2))

    if config.lr_schedule == "cosine_annealing":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            model_opt, config.train_epochs)
    elif config.lr_schedule == "cosine_annealing_warmup":
        num_warmup_epochs = int(0.05 * config.train_epochs)
        num_warmup_steps = len(train_loader) * num_warmup_epochs
        num_training_steps = len(train_loader) * config.train_epochs
        scheduler = transformers.get_cosine_schedule_with_warmup(
            model_opt,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    # Try to restore model and optimizer from checkpoint
    if resume_checkpoint is not None:
        start_epoch = load_checkpoint(resume_checkpoint, model, model_opt)
    else:
        start_epoch = 1

    train_iter = (start_epoch - 1) * (len(train_loader.dataset) //
                                      config.batch_size) + 1

    print("Starting training at epoch = {}, iter = {}".format(
        start_epoch, train_iter))

    # Setup tensorboard writing
    summary_writer = SummaryWriter(logdir)

    train_reports = []
    report_all = {}
    report_all_val = {}
    # Saving model at epoch 0 before training
    print("saving model at epoch 0 before training ... ")
    save_checkpoint(checkpoint_name, 0, model, model_opt, loss=0.0)
    print("finished saving model at epoch 0 before training")

    # if (
    #     config.debug
    #     and config.model_config == "configs/dynamics/eqv_transformer_model.py"
    # ):
    #     model_components = (
    #         [(0, [], "embedding_layer")]
    #         + list(
    #             chain.from_iterable(
    #                 (
    #                     (k, [], f"ema_{k}"),
    #                     (
    #                         k,
    #                         ["ema", "kernel", "location_kernel"],
    #                         f"ema_{k}_location_kernel",
    #                     ),
    #                     (
    #                         k,
    #                         ["ema", "kernel", "feature_kernel"],
    #                         f"ema_{k}_feature_kernel",
    #                     ),
    #                 )
    #                 for k in range(1, config.num_layers + 1)
    #             )
    #         )
    #         + [(config.num_layers + 2, [], "output_mlp")]
    #     )  # components to track for debugging

    num_params = param_count(model)
    print(f"Number of model parameters: {num_params}")

    # Training
    start_t = time.time()

    total_train_iters = len(train_loader) * config.train_epochs
    iters_per_eval = int(total_train_iters /
                         100)  # evaluate 100 times over the course of training

    assert (
        config.n_train % min(config.batch_size, config.n_train) == 0
    ), "Batch size doesn't divide dataset size. Can be problematic for loss computation (see below)."

    training_failed = False
    best_val_loss_so_far = 1e+7

    for epoch in tqdm(range(start_epoch, config.train_epochs + 1)):
        model.train()

        for batch_idx, data in enumerate(train_loader):
            data = nested_to(
                data, device, torch.float32
            )  # the format is ((z0, sys_params, ts), true_zs) for data
            outputs = model(data)

            if torch.isnan(outputs.loss):
                if not training_failed:
                    epoch_of_nan = epoch
                if (epoch > epoch_of_nan + 1) and training_failed:
                    raise ValueError("Loss Nan-ed.")
                training_failed = True

            model_opt.zero_grad()
            outputs.loss.backward(retain_graph=False)

            if config.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.max_grad_norm,
                                               norm_type=1)

            model_opt.step()

            train_reports.append(parse_reports_cpu(outputs.reports))

            if config.log_train_values:
                reports = parse_reports(outputs.reports)
                if batch_idx % config.report_loss_every == 0:
                    log_tensorboard(summary_writer, train_iter, reports,
                                    "train/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="train",
                    )
                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {"lr": model_opt.param_groups[0]["lr"]},
                        "hyperparams/",
                    )

            # Do learning rate schedule steps per STEP for cosine_annealing_warmup
            if config.lr_schedule == "cosine_annealing_warmup":
                scheduler.step()

            # Track stuff for debugging
            if config.debug:
                model.eval()
                with torch.no_grad():
                    curr_params = {
                        k: v.detach().clone()
                        for k, v in model.state_dict().items()
                    }

                    # updates norm
                    if train_iter == 1:
                        update_norm = 0
                    else:
                        update_norms = []
                        for (k_prev,
                             v_prev), (k_curr,
                                       v_curr) in zip(prev_params.items(),
                                                      curr_params.items()):
                            assert k_prev == k_curr
                            if ("tracked" not in k_prev):
                                update_norms.append(
                                    (v_curr - v_prev).norm(1).item())
                        update_norm = sum(update_norms)

                    # gradient norm
                    grad_norm = 0
                    for p in model.parameters():
                        try:
                            grad_norm += p.grad.norm(1)
                        except AttributeError:
                            pass

                    # Average V size:
                    (z0, sys_params, ts), true_zs = data

                    z = z0
                    m = sys_params[
                        ..., 0]  # assume the first component encodes masses
                    D = z.shape[-1]  # number of ODE dims = 2 * num_particles * space_dim
                    q = z[:, :D // 2].reshape(*m.shape, -1)
                    k = sys_params[..., 1]
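                    # presumably the second sys_params component holds the spring
                    # constants, matching the SpringV ground-truth potential below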

                    V = model.predictor.compute_V((q, sys_params))
                    V_true = SpringV(q, k)

                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {
                            "avg_update_norm1":
                            update_norm / num_params,
                            "avg_grad_norm1":
                            grad_norm / num_params,
                            "avg_predicted_potential_norm":
                            V.norm(1) / V.numel(),
                            "avg_true_potential_norm":
                            V_true.norm(1) / V_true.numel(),
                        },
                        "debug/",
                    )
                    prev_params = curr_params

                model.train()

            # Logging
            if train_iter % iters_per_eval == 0 or (
                    train_iter == total_train_iters
            ):  # batch_idx % config.evaluate_every == 0:
                model.eval()
                with torch.no_grad():

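                    # Accumulate report values over the whole test set and
                    # average them (assumes roughly equal mini-batch sizes).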
                    reports = None
                    for data in test_loader:
                        data = nested_to(data, device, torch.float32)
                        outputs = model(data)

                        if reports is None:
                            reports = {
                                k: v.detach().clone().cpu()
                                for k, v in outputs.reports.items()
                            }
                        else:
                            for k, v in outputs.reports.items():
                                reports[k] += v.detach().clone().cpu()

                    for k, v in reports.items():
                        reports[k] = v / len(test_loader)

                    reports = parse_reports(reports)
                    reports["time"] = time.time() - start_t
                    if report_all == {}:
                        report_all = deepcopy(reports)

                        for d in reports.keys():
                            report_all[d] = [report_all[d]]
                    else:
                        for d in reports.keys():
                            report_all[d].append(reports[d])

                    log_tensorboard(summary_writer, train_iter, reports,
                                    "test/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="test",
                    )

                    # Optionally abort runs that are clearly underperforming
                    # once 20% of the training epochs have elapsed.
                    if config.kill_if_poor:
                        if epoch > config.train_epochs * 0.2:
                            if reports["mse"] > 0.01:
                                raise RuntimeError(
                                    "Killed run due to poor performance.")

                    # repeat for validation data
                    reports = None
                    for data in val_loader:
                        data = nested_to(data, device, torch.float32)
                        outputs = model(data)

                        if reports is None:
                            reports = {
                                k: v.detach().clone().cpu()
                                for k, v in outputs.reports.items()
                            }
                        else:
                            for k, v in outputs.reports.items():
                                reports[k] += v.detach().clone().cpu()

                    for k, v in reports.items():
                        reports[k] = v / len(val_loader)

                    reports = parse_reports(reports)
                    reports["time"] = time.time() - start_t
                    if report_all_val == {}:
                        report_all_val = deepcopy(reports)

                        for d in reports.keys():
                            report_all_val[d] = [report_all_val[d]]
                    else:
                        for d in reports.keys():
                            report_all_val[d].append(reports[d])

                    log_tensorboard(summary_writer, train_iter, reports,
                                    "val/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="val",
                    )

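                    # Simple early stopping bookkeeping: whenever validation MSE
                    # improves, overwrite a dedicated "early_stop" checkpoint.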
                    if report_all_val['mse'][-1] < best_val_loss_so_far:
                        save_checkpoint(checkpoint_name,
                                        f"early_stop",
                                        model,
                                        model_opt,
                                        loss=outputs.loss)
                        best_val_loss_so_far = report_all_val['mse'][-1]

                model.train()

            train_iter += 1

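        # Plain cosine annealing steps the scheduler once per epoch; the warmup
        # variant above is stepped once per batch instead.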
        if config.lr_schedule == "cosine_annealing":
            scheduler.step()

        if epoch % config.save_check_points == 0:
            save_checkpoint(checkpoint_name,
                            train_iter,
                            model,
                            model_opt,
                            loss=outputs.loss)

        dd.io.save(logdir + "/results_dict_train.h5", train_reports)
        dd.io.save(logdir + "/results_dict.h5", report_all)
        dd.io.save(logdir + "/results_dict_val.h5", report_all_val)

    # always save final model
    save_checkpoint(checkpoint_name,
                    train_iter,
                    model,
                    model_opt,
                    loss=outputs.loss)

    if config.save_test_predictions:
        print(
            "Starting to make model predictions on test sets for *final model*."
        )
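        # Roll out the final model on test trajectories at two prediction horizons
        # and save ~2000 predicted/true trajectories for offline analysis.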
        for chunk_len in [5, 100]:
            start_t_preds = time.time()
            data_config = SimpleNamespace(
                **{
                    **config.__dict__["__flags"],
                    **{
                        "chunk_len": chunk_len,
                        "batch_size": 500
                    }
                })
            dataloaders, data_name = fet.load(config.data_config,
                                              config=data_config)
            test_loader_preds = dataloaders["test"]

            torch.cuda.empty_cache()
            with torch.no_grad():
                preds = []
                true = []
                num_datapoints = 0
                for idx, d in enumerate(test_loader_preds):
                    true.append(d[-1])
                    d = nested_to(d, device, torch.float32)
                    outputs = model(d)

                    pred_zs = outputs.prediction
                    preds.append(pred_zs)

                    num_datapoints += len(pred_zs)

                    if num_datapoints >= 2000:
                        break

                preds = torch.cat(preds, dim=0).cpu()
                true = torch.cat(true, dim=0).cpu()

                save_dir = osp.join(
                    logdir, f"traj_preds_{chunk_len}_steps_2k_test.pt")
                torch.save(preds, save_dir)

                save_dir = osp.join(logdir,
                                    f"traj_true_{chunk_len}_steps_2k_test.pt")
                torch.save(true, save_dir)

                print(
                    f"Completed making test predictions for chunk_len = {chunk_len} in {time.time() - start_t_preds:.2f} seconds."
                )

def main():
    # Parse flags
    config = forge.config()

    # Set device
    if torch.cuda.is_available():
        device = f"cuda:{config.device}"
        torch.cuda.set_device(device)
    else:
        device = "cpu"

    # Load data
    train_loader, test_loader, data_name = fet.load(config.data_config,
                                                    config=config)

    # Load model
    model, model_name = fet.load(config.model_config, config)
    model = model.to(device)
    print(model)
    torch.manual_seed(0)

    # Prepare environment

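    # Build a descriptive run name from the hyperparameters relevant to the
    # chosen model config (the set transformer uses a shorter list of flags).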
    if "set_transformer" in config.model_config:
        params_in_run_name = [
            ("batch_size", "bs"),
            ("learning_rate", "lr"),
            ("num_heads", "nheads"),
            ("patterns_reps", "reps"),
            ("model_seed", "mseed"),
            ("train_epochs", "epochs"),
            ("naug", "naug"),
        ]

    else:
        params_in_run_name = [
            ("batch_size", "bs"),
            ("learning_rate", "lr"),
            ("num_heads", "nheads"),
            ("num_layers", "nlayers"),
            ("dim_hidden", "hdim"),
            ("kernel_dim", "kdim"),
            ("location_attention", "locatt"),
            ("model_seed", "mseed"),
            ("lr_schedule", "lrsched"),
            ("layer_norm", "ln"),
            ("batch_norm_att", "bnatt"),
            ("batch_norm", "bn"),
            ("batch_norm_final_mlp", "bnfinalmlp"),
            ("k", "k"),
            ("attention_fn", "attfn"),
            ("output_mlp_scale", "mlpscale"),
            ("train_epochs", "epochs"),
            ("block_norm", "block"),
            ("kernel_type", "ktype"),
            ("architecture", "arch"),
            ("kernel_act", "actv"),
            ("patterns_reps", "reps"),
            ("lift_samples", "nsamples"),
            ("content_type", "content"),
            ("naug", "naug"),
        ]

    run_name = ""  # config.run_name
    for config_param in params_in_run_name:
        attr = config_param[0]
        abbrev = config_param[1]

        if hasattr(config, attr):
            run_name += abbrev
            run_name += str(getattr(config, attr))
            run_name += "_"

    results_folder_name = osp.join(
        data_name,
        model_name,
        config.run_name,
        run_name,
    )

    # results_folder_name = osp.join(data_name, model_name, run_name,)

    logdir = osp.join(config.results_dir,
                      results_folder_name.replace(".", "_"))
    logdir, resume_checkpoint = fet.init_checkpoint(logdir, config.data_config,
                                                    config.model_config,
                                                    config.resume)

    checkpoint_name = osp.join(logdir, "model.ckpt")

    # Print flags
    fet.print_flags()

    # Setup optimizer
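    # Note: only the encoder parameters (model.encoder) are handed to Adam here,
    # so any parameters outside the encoder are not updated by this optimiser.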
    model_params = model.encoder.parameters()

    opt_learning_rate = config.learning_rate
    model_opt = torch.optim.Adam(model_params,
                                 lr=opt_learning_rate,
                                 betas=(config.beta1, config.beta2))

    # Try to restore model and optimizer from checkpoint
    if resume_checkpoint is not None:
        start_epoch = load_checkpoint(resume_checkpoint, model, model_opt)
    else:
        start_epoch = 1

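    # Recover the global iteration counter when resuming from a checkpoint.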
    train_iter = (start_epoch - 1) * (len(train_loader.dataset) //
                                      config.batch_size) + 1

    print("Starting training at epoch = {}, iter = {}".format(
        start_epoch, train_iter))

    # Setup tensorboard writing
    summary_writer = SummaryWriter(logdir)

    train_reports = []
    report_all = {}
    # Saving model at epoch 0 before training
    print("saving model at epoch 0 before training ... ")
    save_checkpoint(checkpoint_name, 0, model, model_opt, loss=0.0)
    print("finished saving model at epoch 0 before training")

    if (config.debug and config.model_config
            == "configs/constellation/eqv_transformer_model.py"):
        model_components = ([(0, [], "embedding_layer")] + list(
            chain.from_iterable((
                (k, [], f"ema_{k}"),
                (
                    k,
                    ["ema", "kernel", "location_kernel"],
                    f"ema_{k}_location_kernel",
                ),
                (
                    k,
                    ["ema", "kernel", "feature_kernel"],
                    f"ema_{k}_feature_kernel",
                ),
            ) for k in range(1, config.num_layers + 1))) +
                            [(config.num_layers + 2, [], "output_mlp")]
                            )  # components to track for debugging

    num_params = param_count(model)
    print(f"Number of model parameters: {num_params}")

    # Training
    start_t = time.time()

    grad_flows = []
    training_failed = False
    for epoch in tqdm(range(start_epoch, config.train_epochs + 1)):
        model.train()

        for batch_idx, data in enumerate(train_loader):
            data, presence, target = [d.to(device).float() for d in data]
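            # Optional train-time augmentation: a random global offset for the T(2)
            # case, and a random rotation plus offset for the SE(2) case (note: the
            # offsets are Gaussian draws despite the `unifs` name).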
            if config.train_aug_t2:
                unifs = np.random.normal(
                    1
                )  # torch.rand(1) * 2 * math.pi # torch.randn(1).to(device) * 10
                data += unifs
            if config.train_aug_se2:
                # angle = torch.rand(1) * 2 * math.pi
                angle = np.random.random(size=1) * 2 * math.pi
                unifs = np.random.normal(1)  # torch.randn(1).to(device) * 10
                data = rotate(data, angle.item())
                data += unifs
            outputs = model([data, presence], target)

            if torch.isnan(outputs.loss):
                if not training_failed:
                    epoch_of_nan = epoch
                if (epoch > epoch_of_nan + 1) and training_failed:
                    raise ValueError("Loss Nan-ed.")
                training_failed = True

            model_opt.zero_grad()
            outputs.loss.backward(retain_graph=False)
            model_opt.step()

            train_reports.append(parse_reports_cpu(outputs.reports))

            if config.log_train_values:
                reports = parse_reports(outputs.reports)
                if batch_idx % config.report_loss_every == 0:
                    log_tensorboard(summary_writer, train_iter, reports,
                                    "train/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="train",
                    )

            # Track stuff for debugging
            if config.debug:
                model.eval()
                with torch.no_grad():
                    curr_params = {
                        k: v.detach().clone()
                        for k, v in model.state_dict().items()
                    }

                    # updates norm
                    if train_iter == 1:
                        update_norm = 0
                    else:
                        update_norms = []
                        for (k_prev,
                             v_prev), (k_curr,
                                       v_curr) in zip(prev_params.items(),
                                                      curr_params.items()):
                            assert k_prev == k_curr
                            if (
                                    "tracked" not in k_prev
                            ):  # ignore batch norm tracking. TODO: should this be ignored? if not, fix!
                                update_norms.append(
                                    (v_curr - v_prev).norm(1).item())
                        update_norm = sum(update_norms)

                    # gradient norm
                    grad_norm = 0
                    for p in model.parameters():
                        try:
                            grad_norm += p.grad.norm(1)
                        except AttributeError:
                            pass

                    # # weights norm
                    # if config.model_config == 'configs/constellation/eqv_transformer_model.py':
                    #     model_norms = {}
                    #     for comp_name in model_components:
                    #         comp = get_component(model.encoder.net, comp_name)
                    #         norm = get_average_norm(comp)
                    #         model_norms[comp_name[2]] = norm

                    #     log_tensorboard(
                    #         summary_writer,
                    #         train_iter,
                    #         model_norms,
                    #         "debug/avg_model_norms/",
                    #     )

                    log_tensorboard(
                        summary_writer,
                        train_iter,
                        {
                            "avg_update_norm1": update_norm / num_params,
                            "avg_grad_norm1": grad_norm / num_params,
                        },
                        "debug/",
                    )
                    prev_params = curr_params

                    # # gradient flow
                    # ave_grads = []
                    # max_grads= []
                    # layers = []
                    # for n, p in model.named_parameters():
                    #     if (p.requires_grad) and (p.grad is not None): # and ("bias" not in n):
                    #         layers.append(n)
                    #         ave_grads.append(p.grad.abs().mean().item())
                    #         max_grads.append(p.grad.abs().max().item())

                    # grad_flow = {"layers": layers, "ave_grads": ave_grads, "max_grads": max_grads}
                    # grad_flows.append(grad_flow)

                model.train()

            # Logging
            if batch_idx % config.evaluate_every == 0:
                model.eval()
                with torch.no_grad():
                    reports = None
                    for data in test_loader:
                        data, presence, target = [
                            d.to(device).float() for d in data
                        ]
                        # if config.data_config == "configs/constellation/constellation.py":
                        # if config.global_rotation_angle != 0.0:
                        # data = rotate(data, config.global_rotation_angle)
                        outputs = model([data, presence], target)

                        if reports is None:
                            reports = {
                                k: v.detach().clone().cpu()
                                for k, v in outputs.reports.items()
                            }
                        else:
                            for k, v in outputs.reports.items():
                                reports[k] += v.detach().clone().cpu()

                    for k, v in reports.items():
                        reports[k] = v / len(
                            test_loader
                        )  # XXX: note this is slightly incorrect since mini-batch sizes can vary (if batch_size doesn't divide train_size), but approximately correct.

                    reports = parse_reports(reports)
                    reports["time"] = time.time() - start_t
                    if report_all == {}:
                        report_all = deepcopy(reports)

                        for d in reports.keys():
                            report_all[d] = [report_all[d]]
                    else:
                        for d in reports.keys():
                            report_all[d].append(reports[d])

                    log_tensorboard(summary_writer, train_iter, reports,
                                    "test/")
                    print_reports(
                        reports,
                        start_t,
                        epoch,
                        batch_idx,
                        len(train_loader.dataset) // config.batch_size,
                        prefix="test",
                    )

                model.train()

            train_iter += 1

        if epoch % config.save_check_points == 0:
            save_checkpoint(checkpoint_name,
                            train_iter,
                            model,
                            model_opt,
                            loss=outputs.loss)

        dd.io.save(logdir + "/results_dict_train.h5", train_reports)
        dd.io.save(logdir + "/results_dict.h5", report_all)

        # if config.debug:
        #     # dd.io.save(logdir + "/grad_flows.h5", grad_flows)

    save_checkpoint(checkpoint_name,
                    train_iter,
                    model,
                    model_opt,
                    loss=outputs.loss)