def main(hparams, *args):
    """Build, train and export an autoencoder-family model.

    The model class is chosen from ``hparams['model_class']``. Training is logged to a
    test-tube experiment, hparams are exported before and after training, and loss curves
    are optionally exported as plots.

    Args:
        hparams (dict or object): hyperparameters; a non-dict is converted with ``vars``.
        *args: ignored.

    Returns:
        None. Returns early (after printing a message) if the experiment already exists.

    Raises:
        ValueError: if a conv model requests more latents than ``hparams['max_latents']``.
        NotImplementedError: if ``hparams['model_class']`` is not a supported class.
    """

    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    if hparams['model_type'] == 'conv':
        # blend outer hparams with architecture hparams; outer hparams win on key clashes
        hparams = {**hparams['architecture_params'], **hparams}

    # print hparams to console
    _print_hparams(hparams)

    if hparams['model_type'] == 'conv' and \
            hparams['n_ae_latents'] > hparams['max_latents']:
        raise ValueError(
            'Number of latents higher than max latents, architecture will not work'
        )

    # Start at random times (so test tube creates separate folders). The bounds are passed
    # explicitly: np.random.uniform's first positional argument is `low`, so uniform(3)
    # would mean low=3, high=1 (documented as undefined) rather than [0, 3).
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(low=0.0, high=3.0))

    # create test-tube experiment
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    # ####################
    # ### CREATE MODEL ###
    # ####################

    print('constructing model...', end='')
    torch.manual_seed(hparams['rng_seed_model'])
    hparams['model_build_rng_seed'] = torch.get_rng_state()
    hparams['n_datasets'] = len(sess_ids)

    # model_class -> (class name in behavenet.models, whether the model needs n_labels)
    model_registry = {
        'ae': ('AE', False),
        'vae': ('VAE', False),
        'beta-tcvae': ('BetaTCVAE', False),
        'ps-vae': ('PSVAE', True),
        'msps-vae': ('MSPSVAE', True),
        'cond-vae': ('ConditionalVAE', True),
        'cond-ae': ('ConditionalAE', True),
        'cond-ae-msp': ('AEMSP', True),
    }
    model_class = hparams['model_class']
    if model_class not in model_registry:
        raise NotImplementedError(
            'The model class "%s" is not currently implemented' % model_class)
    class_name, needs_labels = model_registry[model_class]
    from behavenet import models as behavenet_models
    Model = getattr(behavenet_models, class_name)
    if needs_labels:
        # label-based models need the label dimensionality, read from one 'val' batch
        data, _ = data_generator.next_batch('val')
        hparams['n_labels'] = data['labels'].shape[2]  # [1, n_t, n_labels]
    model = Model(hparams)
    model.to(hparams['device'])

    # load pretrained weights if specified
    model = load_pretrained_ae(model, hparams)

    # Parallelize over gpus if desired
    if hparams['n_parallel_gpus'] > 1:
        from behavenet.models import CustomDataParallel
        model = CustomDataParallel(model)

    model.version = exp.version
    hparams['training_rng_seed'] = torch.get_rng_state()

    # save out hparams as csv and dict
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    print('done')

    # ###################
    # ### TRAIN MODEL ###
    # ###################

    print(model)

    fit(hparams, model, data_generator, exp, method='ae')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)

    # export training plots (train first, then validation)
    if hparams['export_train_plots']:
        print('creating training plots...', end='')
        version_dir = os.path.join(hparams['expt_dir'],
                                   'version_%i' % hparams['version'])
        is_msps = hparams['model_class'] == 'msps-vae'
        if is_msps:
            from behavenet.plotting.cond_ae_utils import plot_mspsvae_training_curves
            n_latents = (hparams['n_ae_latents'] - hparams['n_background'] -
                         hparams['n_labels'])
        for dtype, file_name in (('train', 'loss_training'),
                                 ('val', 'loss_validation')):
            save_file = os.path.join(version_dir, file_name)
            if is_msps:
                plot_mspsvae_training_curves(
                    hparams,
                    alpha=hparams['ps_vae.alpha'],
                    beta=hparams['ps_vae.beta'],
                    delta=hparams['ps_vae.delta'],
                    rng_seed_model=hparams['rng_seed_model'],
                    n_latents=n_latents,
                    n_background=hparams['n_background'],
                    n_labels=hparams['n_labels'],
                    dtype=dtype,
                    save_file=save_file,
                    format='png',
                    version_dir=version_dir)
            else:
                export_train_plots(hparams, dtype, save_file=save_file)
        print('done')
def main(hparams, *args):
    """Build, train and export a convolutional decoder (``ConvDecoder``).

    Args:
        hparams (dict or object): hyperparameters; a non-dict is converted with ``vars``.
        *args: ignored.

    Returns:
        None. Returns early (after printing a message) if the experiment already exists.
    """

    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    if hparams['model_type'] == 'conv':
        # blend outer hparams with architecture hparams; here the architecture hparams
        # win on key clashes
        hparams = {**hparams, **hparams['architecture_params']}

    # print hparams to console
    _print_hparams(hparams)

    # Start at random times (so test tube creates separate folders). The bounds are passed
    # explicitly: np.random.uniform(1) means low=1, high=1 and always returns exactly 1.0,
    # which would make every process sleep the same amount.
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(low=0.0, high=1.0))

    # create test-tube experiment
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    # ####################
    # ### CREATE MODEL ###
    # ####################

    print('constructing model...', end='')
    torch.manual_seed(hparams['rng_seed_model'])
    hparams['model_build_rnd_seed'] = torch.get_rng_state()
    hparams['n_datasets'] = len(sess_ids)
    # the decoder needs the label dimensionality, read from one 'train' batch
    data, _ = data_generator.next_batch('train')
    hparams['n_labels'] = data['labels'].shape[2]  # [1, n_t, n_labels]
    model = ConvDecoder(hparams)
    model.to(hparams['device'])

    # NOTE: loading pretrained weights (load_pretrained_ae) is deliberately disabled here

    model.version = exp.version
    hparams['training_rnd_seed'] = torch.get_rng_state()

    # save out hparams as csv and dict
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    print('done')

    # ###################
    # ### TRAIN MODEL ###
    # ###################

    fit(hparams, model, data_generator, exp, method='conv-decoder')

    # export training plots (train first, then validation)
    if hparams['export_train_plots']:
        print('creating training plots...', end='')
        version_dir = os.path.join(hparams['expt_dir'],
                                   'version_%i' % hparams['version'])
        for dtype, file_name in (('train', 'loss_training'),
                                 ('val', 'loss_validation')):
            save_file = os.path.join(version_dir, file_name)
            export_train_plots(hparams, dtype, save_file=save_file)
        print('done')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)
Example #3
0
def main(hparams, *args):
    """Build and train a decoder between neural activity and behavioral signals.

    Input/output sizes are derived from ``hparams['model_class']`` and from one example
    training trial of the first dataset. Paths of the upstream AE/ARHMM model are recorded
    in hparams before the model is fit with the 'nll' method.

    Args:
        hparams (dict or object): hyperparameters; a non-dict is converted with ``vars``.
        *args: ignored.

    Returns:
        None. Returns early (after printing a message) if the experiment already exists.

    Raises:
        ValueError: if ``hparams['model_class']`` is not a supported decoder class.
    """

    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    # print hparams to console
    _print_hparams(hparams)

    # Start at random times (so test tube creates separate folders). The bounds are passed
    # explicitly: np.random.uniform(1) means low=1, high=1 and always returns exactly 1.0,
    # which would make every process sleep the same amount.
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(low=0.0, high=1.0))

    # create test-tube experiment
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    dataset = data_generator.datasets[0]
    ex_trial = dataset.batch_idxs['train'][0]
    i_sig = hparams['input_signal']
    o_sig = hparams['output_signal']

    def _signal_dim(signal):
        """Return ``shape[1]`` of ``signal`` in the example training trial."""
        return dataset[ex_trial][signal].shape[1]

    model_class = hparams['model_class']
    if model_class == 'neural-arhmm':
        hparams['input_size'] = _signal_dim(i_sig)
        hparams['output_size'] = hparams['n_arhmm_states']
    elif model_class == 'arhmm-neural':
        hparams['input_size'] = hparams['n_arhmm_states']
        hparams['output_size'] = _signal_dim(o_sig)
    elif model_class in ('neural-ae', 'neural-ae-me'):
        hparams['input_size'] = _signal_dim(i_sig)
        hparams['output_size'] = hparams['n_ae_latents']
    elif model_class == 'ae-neural':
        hparams['input_size'] = hparams['n_ae_latents']
        hparams['output_size'] = _signal_dim(o_sig)
    elif model_class == 'neural-labels':
        hparams['input_size'] = _signal_dim(i_sig)
        hparams['output_size'] = hparams['n_labels']
    elif model_class == 'labels-neural':
        hparams['input_size'] = hparams['n_labels']
        hparams['output_size'] = _signal_dim(o_sig)
    else:
        raise ValueError('%s is an invalid model class' % model_class)

    # NOTE(review): 'neural-ae-me' is not included below (the original condition listed
    # 'neural-ae' twice); confirm whether its data paths contain an 'ae_latents' key.
    if model_class in ('neural-ae', 'ae-neural'):
        ae_latents_file = dataset.paths['ae_latents']
        hparams['ae_model_path'] = os.path.dirname(ae_latents_file)
        hparams['ae_model_latents_file'] = ae_latents_file
    elif model_class in ('neural-arhmm', 'arhmm-neural'):
        arhmm_states_file = dataset.paths['arhmm_states']
        hparams['arhmm_model_path'] = os.path.dirname(arhmm_states_file)
        hparams['arhmm_model_states_file'] = arhmm_states_file

        # Store which AE was used for the ARHMM; the file handle is closed by `with`.
        # pickle should only ever be loaded from trusted, locally produced files.
        meta_tags_file = os.path.join(hparams['arhmm_model_path'], 'meta_tags.pkl')
        with open(meta_tags_file, 'rb') as f:
            tags = pickle.load(f)
        hparams['ae_model_latents_file'] = tags['ae_model_latents_file']

    # ####################
    # ### CREATE MODEL ###
    # ####################
    print('constructing model...', end='')
    torch.manual_seed(hparams['rng_seed_model'])
    hparams['model_build_rng_seed'] = torch.get_rng_state()
    model = Decoder(hparams)
    model.to(hparams['device'])
    model.version = exp.version
    hparams['training_rng_seed'] = torch.get_rng_state()

    # save out hparams as csv and dict for easy reloading
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    print('done')

    # ####################
    # ### TRAIN MODEL ###
    # ####################

    fit(hparams, model, data_generator, exp, method='nll')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)
def main(hparams):
    """Fit an HMM/ARHMM with ``ssm`` to AE latents or labels and export the results.

    Epoch 0 evaluates the initialized model, and each later epoch runs one EM iteration.
    Losses are negative log-likelihoods normalized by the number of observation entries
    and are logged to the test-tube experiment, both aggregated over sessions and per
    session. Training stops early once the relative change of the aggregate validation
    loss falls below ``hparams['arhmm_es_tol']`` (checked only after epoch 10). States
    are then reordered by usage and the model is pickled to ``best_val_model.pt``.

    Args:
        hparams (dict or object): hyperparameters; a non-dict is converted with ``vars``.

    Returns:
        None. Returns early (after printing a message) if ``transitions`` and ``kappa``
        are inconsistent, or if the experiment already exists.

    Raises:
        ValueError: if ``model_class`` disagrees with ``n_arhmm_lags``, or if
            ``noise_type`` / ``transitions`` is not recognised.
    """
    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    # sticky transitions require kappa > 0; every other transition type requires kappa == 0
    if hparams['transitions'] == 'sticky' and hparams['kappa'] == 0:
        print('Cannot fit sticky transitions with kappa=0! Aborting fit')
        return
    if hparams['transitions'] != 'sticky' and hparams['kappa'] > 0:
        print('Cannot fit %s transitions with kappa>0! Aborting fit' %
              hparams['transitions'])
        return

    # print hparams to console
    _print_hparams(hparams)

    # Start at random times (so test tube creates separate folders). The bounds are passed
    # explicitly: np.random.uniform(1) means low=1, high=1 and always returns exactly 1.0,
    # which would make every process sleep the same amount.
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(low=0.0, high=1.0))

    # create test-tube experiment
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    # ####################
    # ### CREATE MODEL ###
    # ####################

    # collect observations pooled over all sessions, and separated by session
    n_datasets = len(data_generator)
    print('collecting observations from data generator...', end='')
    data_key = 'labels' if 'labels' in hparams['model_class'] else 'ae_latents'
    latents, _ = get_latent_arrays_by_dtype(
        data_generator, sess_idxs=list(range(n_datasets)), data_key=data_key)
    obs_dim = latents['train'][0].shape[1]

    hparams['total_train_length'] = np.sum([z.shape[0] for z in latents['train']])
    latents_sess = {}
    trial_idxs_sess = {}
    for d in range(n_datasets):
        latents_sess[d], trial_idxs_sess[d] = get_latent_arrays_by_dtype(
            data_generator, sess_idxs=d, data_key=data_key)
    print('done')

    if hparams['model_class'] in ('arhmm', 'hmm'):
        ae_latents_file = data_generator.datasets[0].paths['ae_latents']
        hparams['ae_model_path'] = os.path.dirname(ae_latents_file)
        hparams['ae_model_latents_file'] = ae_latents_file

    # 'arhmm*' classes ('arhmm', 'arhmm-labels') need lags; 'hmm*' classes need 0 lags
    has_lags = hparams['n_arhmm_lags'] > 0
    if has_lags:
        if not hparams['model_class'].startswith('arhmm'):
            raise ValueError(
                'Must specify model_class as arhmm when using AR lags')
    elif not hparams['model_class'].startswith('hmm'):
        raise ValueError(
            'Must specify model_class as hmm when using 0 AR lags')

    # determine observation model:
    # noise type -> (observation model without AR lags, observation model with AR lags)
    obs_types = {
        'gaussian': ('gaussian', 'ar'),
        'studentst': ('studentst', 'robust_ar'),
        'diagonal_gaussian': ('diagonal_gaussian', 'diagonal_ar'),
        'diagonal_studentst': ('diagonal_studentst', 'diagonal_robust_ar'),
    }
    if hparams['noise_type'] not in obs_types:
        raise ValueError('%s is not a valid noise type' %
                         hparams['noise_type'])
    obs_type = obs_types[hparams['noise_type']][1 if has_lags else 0]

    if has_lags:
        obs_kwargs = {'lags': hparams['n_arhmm_lags']}
        obs_init_kwargs = {'localize': True}
    else:
        obs_kwargs = None
        obs_init_kwargs = {}

    # determine transition model ('standard' is an alias for 'stationary')
    transition_type = hparams['transitions']
    if transition_type in ('stationary', 'standard'):
        transitions, transition_kwargs = 'stationary', None
    elif transition_type == 'sticky':
        transitions, transition_kwargs = 'sticky', {'kappa': hparams['kappa']}
    elif transition_type in ('recurrent', 'recurrent_only'):
        transitions, transition_kwargs = transition_type, None
    else:
        raise ValueError('%s is not a valid transition type' % transition_type)

    print('constructing model...', end='')
    np.random.seed(hparams['rng_seed_model'])
    hmm = ssm.HMM(hparams['n_arhmm_states'],
                  obs_dim,
                  observations=obs_type,
                  observation_kwargs=obs_kwargs,
                  transitions=transitions,
                  transition_kwargs=transition_kwargs)
    hmm.initialize(latents['train'])
    hmm.observations.initialize(latents['train'], **obs_init_kwargs)
    # save out hparams as csv and dict
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    hmm.hparams = hparams
    print('done')

    # ####################
    # ### TRAIN MODEL ###
    # ####################

    # TODO: move fitting into own function
    # precompute normalizers (number of observation entries per split / session)
    n_datapoints = {}
    n_datapoints_sess = {}
    for dtype in ('train', 'val', 'test'):
        n_datapoints[dtype] = np.vstack(latents[dtype]).size
        n_datapoints_sess[dtype] = {
            d: np.vstack(latents_sess[d][dtype]).size for d in range(n_datasets)}

    val_ll_prev = np.inf
    tolerance = hparams.get('arhmm_es_tol', 0)
    for epoch in range(hparams['n_iters'] + 1):
        # Note: the 0th epoch has no training (randomly initialized model is evaluated) so we
        # cycle through `n_iters` training epochs
        print('epoch %03i/%03i' % (epoch, hparams['n_iters']))
        if epoch > 0:
            hmm.fit(latents['train'],
                    method='em',
                    num_iters=1,
                    initialize=False)

        # export aggregated metrics on train/val data
        tr_ll = -hmm.log_likelihood(latents['train']) / n_datapoints['train']
        val_ll = -hmm.log_likelihood(latents['val']) / n_datapoints['val']
        exp.log({
            'epoch': epoch,
            'dataset': -1,
            'tr_loss': tr_ll,
            'val_loss': val_ll,
            'trial': -1
        })

        # export individual session metrics on train/val data; these use their own names so
        # they do not overwrite the aggregate `val_ll` used for the convergence check
        for d in range(data_generator.n_datasets):
            tr_ll_sess = -hmm.log_likelihood(
                latents_sess[d]['train']) / n_datapoints_sess['train'][d]
            val_ll_sess = -hmm.log_likelihood(
                latents_sess[d]['val']) / n_datapoints_sess['val'][d]
            exp.log({
                'epoch': epoch,
                'dataset': d,
                'tr_loss': tr_ll_sess,
                'val_loss': val_ll_sess,
                'trial': -1
            })

        # check for convergence on the aggregate validation loss
        if epoch > 10 and np.abs((val_ll - val_ll_prev) / val_ll) < tolerance:
            print(
                'relative change less than tolerance=%1.2f; training terminating!'
                % tolerance)
            break

        val_ll_prev = val_ll

    # export individual session metrics on test data, one entry per test trial
    for d in range(n_datasets):
        for i, b in enumerate(trial_idxs_sess[d]['test']):
            n = latents_sess[d]['test'][i].size
            test_ll = -hmm.log_likelihood(latents_sess[d]['test'][i]) / n
            exp.log({
                'epoch': epoch,
                'dataset': d,
                'test_loss': test_ll,
                'trial': b
            })
    exp.save()

    # reorder states by decreasing usage on the training data
    zs = [hmm.most_likely_states(x) for x in latents['train']]
    usage = np.bincount(np.concatenate(zs), minlength=hmm.K)
    perm = np.argsort(usage)[::-1]
    hmm.permute(perm)

    # save model
    filepath = os.path.join(hparams['expt_dir'], 'version_%i' % exp.version,
                            'best_val_model.pt')
    with open(filepath, 'wb') as f:
        pickle.dump(hmm, f)

    # ######################
    # ### EVALUATE ARHMM ###
    # ######################

    # export states
    if hparams['export_states']:
        export_states(hparams, data_generator, hmm)

    # export training plots (train first, then validation)
    if hparams['export_train_plots']:
        print('creating training plots...', end='')
        version_dir = os.path.join(hparams['expt_dir'],
                                   'version_%i' % hparams['version'])
        for dtype, file_name in (('train', 'loss_training'),
                                 ('val', 'loss_validation')):
            save_file = os.path.join(version_dir, file_name)
            export_train_plots(hparams, dtype, loss_type='ll', save_file=save_file)
        print('done')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)