def main():
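    """
    Image-modeling experiment: load the data, train the flow, then validate and (optionally) test.
    """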
    # =========================================================================
    # PARSE EXPERIMENT SETTINGS, SETUP SNAPSHOTS DIRECTORY, LOGGING
    # =========================================================================
    args, kwargs = parse_args()

    # =========================================================================
    # LOAD DATA
    # =========================================================================
    logger.info('LOADING DATA:')
    train_loader, val_loader, test_loader, args = load_image_dataset(
        args, **kwargs)
    args.z_size = args.input_size
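    # flows are bijective, so the latent dimensionality matches the input dimensionality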

    # =========================================================================
    # SAVE EXPERIMENT SETTINGS
    # =========================================================================
    logger.info(f'EXPERIMENT SETTINGS:\n{args}\n')
    torch.save(args, os.path.join(args.snap_dir, 'config.pt'))

    # =========================================================================
    # INITIALIZE MODEL AND OPTIMIZATION
    # =========================================================================
    model = init_model(args)
    optimizer, scheduler = init_optimizer(model, args)
    num_params = sum([param.nelement() for param in model.parameters()])
    logger.info(f"MODEL:\nNumber of model parameters={num_params}\n{model}\n")

    # =========================================================================
    # TRAINING
    # =========================================================================
    logger.info('TRAINING:')
    train(model, train_loader, val_loader, optimizer, scheduler, args)

    # =========================================================================
    # VALIDATION
    # =========================================================================
    logger.info('VALIDATION:')
    val_loss = evaluate(model, val_loader, args)

    # =========================================================================
    # TESTING
    # =========================================================================
    if args.testing:
        logger.info("TESTING:")
        test_loss = evaluate(model, test_loader, args)
def train(model, data_loaders, optimizer, scheduler, args):
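    """
    Main training loop for both standard and boosted flows.

    For boosted models, components are trained one at a time: when a component
    converges, the best checkpoint is reloaded, rho is updated, and training
    moves on to the next component.
    """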
    writer = SummaryWriter(args.snap_dir) if args.tensorboard else None

    header_msg = f'| Epoch | {"TRAIN": <14}{"Loss": >4} | {"VALIDATION": <14}{"Loss": >4} | {"TIMING":<8}{"(sec)":>4} | {"Improved": >8} |'
    header_msg += f' {"Component": >9} | {"All Trained": >11} | {"Rho": >{min(8, args.num_components) * 6}} |' if args.boosted else ''
    logger.info('|' + "-" * (len(header_msg) - 2) + '|')
    logger.info(header_msg)
    logger.info('|' + "-" * (len(header_msg) - 2) + '|')

    best_loss = np.array([np.inf] * args.num_components)
    early_stop_count = 0
    converged_epoch = 0  # for boosting, tracks how long the current component has been training

    if args.boosted:
        prev_lr = init_boosted_lr(model, optimizer, args)
    else:
        prev_lr = []

    grad_norm = None
    epoch_times = []
    epoch_train = []
    epoch_valid = []

    pval_loss = 0.0
    val_losses = {'g_nll': 9999999.9}
    step = 0
    for epoch in range(args.init_epoch, args.epochs + 1):

        model.train()
        train_loss = []
        t_start = time.time()

        for batch_id, (x, _) in enumerate(data_loaders['train']):

            # initialize data and optimizer
            x = x.to(args.device)
            optimizer.zero_grad()

            # initialize ActNorm on first steps
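            # (Glow-style ActNorm layers use a data-dependent initialization, so the
            #  first few batches are run forward without taking a gradient step)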
            if (args.flow == 'glow' or args.component_type == 'glow') and \
                    step < args.num_init_batches:
                with torch.no_grad():
                    if args.boosted:
                        for i in range(args.num_components):
                            model(x=x, components=i)
                    else:
                        model(x=x)

                    step += 1
                    continue

            # compute loss and gradients
            losses = compute_kl_pq_loss(model, x, args)
            train_loss.append(losses['nll'].detach())
            losses['nll'].backward()

            if args.max_grad_norm > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.max_grad_norm)

            # Adjust learning rates for boosted model, keep fixed components frozen
            if args.boosted:
                update_learning_rates(prev_lr, model, optimizer, step, args)

            # batch level reporting
            batch_reporting(writer, optimizer, losses, grad_norm, step, args)

            # Perform gradient update, modify learning rate according to learning rate schedule
            optimizer.step()
            if not args.no_lr_schedule:
                prev_lr = update_scheduler(prev_lr, model, optimizer,
                                           scheduler, val_losses['g_nll'],
                                           step, args)

                if args.lr_schedule == "test":
                    if step % 50 == 0:
                        pval_loss = evaluate(model, data_loaders['val'],
                                             args)['nll']

                    if writer is not None:
                        writer.add_scalar('step/val_nll', pval_loss, step)

            step += 1

        # Validation, collect results
        val_losses = evaluate(model, data_loaders['val'], args)
        train_loss = torch.stack(train_loss).mean().item()
        epoch_times.append(time.time() - t_start)
        epoch_train.append(train_loss)
        epoch_valid.append(val_losses['nll'])

        # Assess convergence
        component = (model.component, model.all_trained) if args.boosted else 0
        converged, model_improved, early_stop_count, best_loss = check_convergence(
            early_stop_count, val_losses, best_loss, epoch - converged_epoch,
            component, args)
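        # check_convergence handles early stopping: it reports whether the current
        # component (or the full model) has converged, whether the validation loss
        # improved, and returns the updated early-stop count and per-component best loss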
        if model_improved:
            fname = f'model_c{model.component}.pt' if args.boosted and args.save_intermediate_checkpoints else 'model.pt'
            save(model, optimizer, args.snap_dir + fname, scheduler)

        # epoch level reporting
        epoch_msg = epoch_reporting(writer, model, train_loss, val_losses,
                                    epoch_times, model_improved, epoch, args)

        if converged:
            logger.info(epoch_msg + ' |')
            logger.info("-" * (len(header_msg)))

            if args.boosted:
                converged_epoch = epoch

                # revert back to the last best version of the model and update rho
                fname = f'model_c{model.component}.pt' if args.save_intermediate_checkpoints else 'model.pt'
                load(model=model,
                     optimizer=optimizer,
                     path=args.snap_dir + fname,
                     args=args,
                     scheduler=scheduler,
                     verbose=False)
                model.update_rho(data_loaders['train'])
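                # rho (the per-component weights shown in the epoch header) is
                # re-estimated on the training data before moving to the next component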

                last_component = model.component == (args.num_components - 1)
                no_fine_tuning = args.epochs <= args.epochs_per_component * args.num_components
                fine_tuning_done = model.all_trained and last_component  # no early stopping if burnin employed
                if (fine_tuning_done or no_fine_tuning) and last_component:
                    # stop the full model after all components have been trained
                    logger.info(
                        f"Model converged, training complete, saving: {args.snap_dir + 'model.pt'}"
                    )
                    model.all_trained = True
                    save(model, optimizer, args.snap_dir + 'model.pt',
                         scheduler)
                    break

                # else if not done training: save model with updated rho
                save(model, optimizer, args.snap_dir + fname, scheduler)

                # temporary: look at results after each component
                test_loss = evaluate(model, data_loaders['test'], args)
                logger.info(
                    f"Loss after training {model.component + 1} components: {test_loss['nll']:8.3f}"
                )
                logger.info("-" * (len(header_msg)))

                # reset optimizer, scheduler, and early_stop_count and train the next component
                model.increment_component()
                early_stop_count = 0
                val_losses = {'g_nll': 9999999.9}
                optimizer, scheduler = init_optimizer(model,
                                                      args,
                                                      verbose=False)
                prev_lr = init_boosted_lr(model, optimizer, args)
            else:
                # if a standard model converges once, break
                logger.info(f"Model converged, stopping training.")
                break

        else:
            logger.info(epoch_msg + ' |')
            if epoch == args.epochs:
                if args.boosted and args.save_intermediate_checkpoints:
                    # Save the best version of the model trained up to the current component with filename model.pt
                    # This is to protect against times when the model is trained/re-trained but doesn't run long enough
                    # for all components to converge / train completely
                    copyfile(args.snap_dir + f'model_c{model.component}.pt',
                             args.snap_dir + 'model.pt')
                    logger.info(
                        f"Resaving last improved version of {f'model_c{model.component}.pt'} as 'model.pt' for future testing"
                    )
                else:
                    logger.info(
                        f"Stopping training after {epoch} epochs.")

    logger.info('|' + "-" * (len(header_msg) - 2) + '|\n')
    if args.tensorboard:
        writer.close()

    epoch_times, epoch_train, epoch_valid = np.array(epoch_times), np.array(
        epoch_train), np.array(epoch_valid)
    timing_msg = f"Stopped after {epoch_times.shape[0]} epochs. "
    timing_msg += f"Average train time per epoch: {np.mean(epoch_times):.2f} +/- {np.std(epoch_times, ddof=1):.2f}"
    logger.info(timing_msg + '\n')
    if args.save_results:
        np.savetxt(args.snap_dir + '/train_loss.csv',
                   epoch_train,
                   fmt='%f',
                   delimiter=',')
        np.savetxt(args.snap_dir + '/valid_loss.csv',
                   epoch_valid,
                   fmt='%f',
                   delimiter=',')
        np.savetxt(args.snap_dir + '/epoch_times.csv',
                   epoch_times,
                   fmt='%f',
                   delimiter=',')
        with open(args.exp_log, 'a') as ff:
            timestamp = str(datetime.datetime.now())[0:19].replace(' ', '_')
            setup_msg = '\n'.join([timestamp, args.snap_dir
                                   ]) + '\n' + repr(args)
            print('\n' + setup_msg + '\n' + timing_msg, file=ff)
def train_boosted(train_loader, val_loader, model, optimizer, scheduler, args):
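    """
    Boosted training loop for the VAE-style experiments: trains one component at a
    time, annealing the KL term (beta) and occasionally sampling from all components
    (prob_all); returns per-epoch training and validation statistics.
    """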
    train_times = []
    train_loss = []
    train_rec = []
    train_G = []
    train_p = []
    train_entropy = []

    val_loss = []
    val_rec = []
    val_kl = []

    # for early stopping
    best_loss = np.array([np.inf] * args.num_components)
    best_tr_ratio = np.array([-np.inf] * args.num_components)
    early_stop_count = 0
    converged_epoch = 0  # corrects the annealing schedule when a component converges early
    v_loss = 9999999.9

    # initialize learning rates for boosted components
    prev_lr = init_boosted_lr(model, optimizer, args)
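    # (each component keeps its own learning rate so that fixed components stay
    #  frozen while the current component trains)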

    args.step = 0
    for epoch in range(args.init_epoch, args.epochs + 1):

        # compute annealing rate for KL loss term
        beta = kl_annealing_rate(epoch - converged_epoch, model.component,
                                 model.all_trained, args)
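        # (beta is the usual KL warm-up weight: the KL term is down-weighted early
        #  in a component's training and annealed toward its full value)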

        # occasionally sample from all components to keep decoder from focusing solely on new component
        prob_all = sample_from_all_prob(epoch - converged_epoch,
                                        model.component, model.all_trained,
                                        args)

        # Train model
        t_start = time.time()
        tr_loss, tr_rec, tr_G, tr_p, tr_entropy, tr_ratio, prev_lr = train_epoch_boosted(
            epoch, train_loader, model, optimizer, scheduler, beta, prob_all,
            prev_lr, v_loss, args)
        train_times.append(time.time() - t_start)
        train_loss.append(tr_loss)
        train_rec.append(tr_rec)
        train_G.append(tr_G)
        train_p.append(tr_p)
        train_entropy.append(tr_entropy)

        # Evaluate model
        v_loss, v_rec, v_kl = evaluate(val_loader, model, args, epoch=epoch)
        val_loss.append(v_loss)
        val_rec.append(v_rec)
        val_kl.append(v_kl)

        # Assess convergence
        component_converged, model_improved, early_stop_count, best_loss, best_tr_ratio = check_convergence(
            early_stop_count, v_loss, best_loss, tr_ratio, best_tr_ratio,
            epoch - converged_epoch, model, args)

        # epoch level reporting
        epoch_msg = epoch_reporting(model, tr_loss, tr_rec, tr_G, tr_p,
                                    tr_entropy, tr_ratio, v_loss, v_rec, v_kl,
                                    beta, prob_all, train_times, epoch,
                                    model_improved, args)

        if model_improved:
            fname = f'model_c{model.component}.pt' if args.boosted and args.save_intermediate_checkpoints else 'model.pt'
            save(model, optimizer, args.snap_dir + fname, scheduler)

        if component_converged:
            logger.info(epoch_msg + f'{"| ": >4}')
            logger.info("-" * 206)
            converged_epoch = epoch

            # revert back to the last best version of the model and update rho
            fname = f'model_c{model.component}.pt' if args.save_intermediate_checkpoints else 'model.pt'
            load(model=model,
                 optimizer=optimizer,
                 path=args.snap_dir + fname,
                 args=args,
                 scheduler=scheduler,
                 verbose=False)
            model.update_rho(train_loader)

            last_component = model.component == (args.num_components - 1)
            no_fine_tuning = args.epochs <= args.epochs_per_component * args.num_components
            fine_tuning_done = model.all_trained and last_component
            if (fine_tuning_done or no_fine_tuning) and last_component:
                # stop the full model after all components have been trained
                logger.info(
                    f"Model converged, training complete, saving: {args.snap_dir + 'model.pt'}"
                )
                model.all_trained = True
                save(model, optimizer, args.snap_dir + 'model.pt', scheduler)
                break

            save(model, optimizer,
                 args.snap_dir + f'model_c{model.component}.pt', scheduler)

            # reset early_stop_count and train the next component
            model.increment_component()
            early_stop_count = 0
            v_loss = 9999999.9
            optimizer, scheduler = init_optimizer(model, args, verbose=False)
            prev_lr = init_boosted_lr(model, optimizer, args)
        else:
            logger.info(epoch_msg + f'{"| ": >4}')
            if epoch == args.epochs:
                if args.boosted and args.save_intermediate_checkpoints:
                    # Save the best version of the model trained up to the current component with filename model.pt
                    # This is to protect against times when the model is trained/re-trained but doesn't run long enough
                    #   for all components to converge / train completely
                    copyfile(args.snap_dir + f'model_c{model.component}.pt',
                             args.snap_dir + 'model.pt')
                    logger.info(
                        f"Resaving last improved version of {f'model_c{model.component}.pt'} as 'model.pt' for future testing"
                    )
                else:
                    logger.info(
                        f"Stopping training after {epoch} epochs.")

    train_loss = np.hstack(train_loss)
    train_rec = np.hstack(train_rec)
    train_G = np.hstack(train_G)
    train_p = np.hstack(train_p)
    train_entropy = np.hstack(train_entropy)

    val_loss = np.array(val_loss)
    val_rec = np.array(val_rec)
    val_kl = np.array(val_kl)
    train_times = np.array(train_times)
    return train_loss, train_rec, train_G, train_p, train_entropy, val_loss, val_rec, val_kl, train_times
def main(main_args=None):
    """
    use main_args to run this script as function in another script
    """

    # =========================================================================
    # PARSE EXPERIMENT SETTINGS, SETUP SNAPSHOTS DIRECTORY, LOGGING
    # =========================================================================
    args, kwargs = parse_args(main_args)

    # =========================================================================
    # LOAD DATA
    # =========================================================================
    logger.info('LOADING DATA:')
    data_loaders, args = load_density_dataset(args)

    # =========================================================================
    # SAVE EXPERIMENT SETTINGS
    # =========================================================================
    logger.info(f'EXPERIMENT SETTINGS:\n{args}\n')
    torch.save(args, os.path.join(args.snap_dir, 'config.pt'))

    # =========================================================================
    # INITIALIZE MODEL AND OPTIMIZATION
    # =========================================================================
    model = init_model(args)
    optimizer, scheduler = init_optimizer(model, args)
    num_params = sum([param.nelement() for param in model.parameters()])
    logger.info(f"MODEL:\nNumber of model parameters={num_params}\n{model}\n")

    if args.load:
        logger.info(f'LOADING CHECKPOINT FROM PRE-TRAINED MODEL: {args.load}')
        init_with_args = args.flow == "boosted" and args.loaded_init_component is not None and args.loaded_all_trained is not None
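        # (init_with_args is only True when resuming a boosted model whose component
        #  state was explicitly specified)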
        load(model=model,
             optimizer=optimizer,
             path=args.load,
             args=args,
             init_with_args=init_with_args,
             scheduler=scheduler)
        logger.info(
            'Warning: boosted models can currently only be loaded to train a new component (until the PyTorch bug is fixed); the optimizer and scheduler will be reset. Non-boosted models cannot be loaded at all (loading will fail).'
        )
        optimizer, scheduler = init_optimizer(model, args, verbose=False)

    # =========================================================================
    # TRAINING
    # =========================================================================
    if args.epochs > 0:
        logger.info('TRAINING:')
        if args.tensorboard:
            logger.info(f'Follow progress on tensorboard: tb {args.snap_dir}')

        train(model, data_loaders, optimizer, scheduler, args)

    # =========================================================================
    # VALIDATION
    # =========================================================================
    logger.info('VALIDATION:')
    load(model=model,
         optimizer=optimizer,
         path=args.snap_dir + 'model.pt',
         args=args)
    val_loss = evaluate(model,
                        data_loaders['val'],
                        args,
                        results_type='Validation')

    # =========================================================================
    # TESTING
    # =========================================================================
    if args.testing:
        logger.info("TESTING:")
        test_loss = evaluate(model,
                             data_loaders['test'],
                             args,
                             results_type='Test')
def main(main_args=None):
    """
    use main_args to run this script as function in another script
    """

    # =========================================================================
    # PARSE EXPERIMENT SETTINGS, SETUP SNAPSHOTS DIRECTORY, LOGGING
    # =========================================================================
    args, kwargs = parse_args(main_args)

    # =========================================================================
    # LOAD DATA
    # =========================================================================
    logger.info('LOADING DATA:')
    train_loader, val_loader, test_loader, args = load_image_dataset(args, **kwargs)

    # =========================================================================
    # SAVE EXPERIMENT SETTINGS
    # =========================================================================
    logger.info(f'EXPERIMENT SETTINGS:\n{args}\n')
    torch.save(args, os.path.join(args.snap_dir, 'config.pt'))

    # =========================================================================
    # INITIALIZE MODEL AND OPTIMIZATION
    # =========================================================================
    model = init_model(args)
    optimizer, scheduler = init_optimizer(model, args)
    num_params = sum([param.nelement() for param in model.parameters()])
    logger.info(f"MODEL:\nNumber of model parameters={num_params}\n{model}\n")

    if args.load:
        logger.info(f'LOADING CHECKPOINT FROM PRE-TRAINED MODEL: {args.load}')
        init_with_args = args.flow == "boosted" and args.loaded_init_component is not None and args.loaded_all_trained is not None
        load(model, optimizer, args.load, args, init_with_args)

    # =========================================================================
    # TRAINING
    # =========================================================================
    training_required = args.epochs > 0 or args.load is None
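    # train unless a checkpoint was loaded and zero epochs were requested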
    if training_required:
        logger.info('TRAINING:')
        if args.tensorboard:
            logger.info(f'Follow progress on tensorboard: tb {args.snap_dir}')

        train_loss, val_loss = train(train_loader, val_loader, model, optimizer, scheduler, args)

    # =========================================================================
    # VALIDATION
    # =========================================================================
    logger.info('VALIDATION:')
    if training_required:
        load(model, optimizer, args.snap_dir + 'model.pt', args)
    val_loss, val_rec, val_kl = evaluate(val_loader, model, args, results_type='Validation')

    # =========================================================================
    # TESTING
    # =========================================================================
    if args.testing:
        logger.info("TESTING:")
        test_loss, test_rec, test_kl = evaluate(test_loader, model, args, results_type='Test')
        test_nll = evaluate_likelihood(test_loader, model, args, S=args.nll_samples, MB=args.nll_mb, results_type='Test')