Beispiel #1
0
def main(hparams):
    """Fit an (AR)HMM with ssm on AE latents or labels and export the results.

    Steps: create a test-tube experiment, build the data generator, collect
    observations (aggregated and per session), construct an ``ssm.HMM`` whose
    observation model is chosen from ``noise_type`` and ``n_arhmm_lags``, run
    ``n_iters`` single-iteration EM fits (epoch 0 evaluates the initialized
    model), log train/val log-likelihoods normalized by the number of observed
    values, log per-trial test log-likelihoods, permute states by training
    usage, pickle the model, and optionally export states and training plots.

    Args:
        hparams (dict or argparse.Namespace): hyperparameters; a non-dict is
            converted with ``vars``.

    Raises:
        ValueError: if ``noise_type`` is not 'gaussian' or 'studentst', or if
            ``model_class`` does not start with 'arhmm' (``n_arhmm_lags > 0``)
            or 'hmm' (``n_arhmm_lags <= 0``).
    """
    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    # print hparams to console
    _print_hparams(hparams)

    # start at random times (so test tube creates separate folders);
    # pass both bounds: the first positional arg of np.random.uniform is `low`,
    # so `uniform(1)` would always return exactly 1.0
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(0, 1))

    # create test-tube experiment
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    # ####################
    # ### CREATE MODEL ###
    # ####################

    # get all latents in list
    n_datasets = len(data_generator)
    print('collecting observations from data generator...', end='')
    data_key = 'labels' if 'labels' in hparams['model_class'] else 'ae_latents'
    latents, trial_idxs = get_latent_arrays_by_dtype(
        data_generator, sess_idxs=list(range(n_datasets)), data_key=data_key)
    obs_dim = latents['train'][0].shape[1]

    hparams['total_train_length'] = np.sum([lat.shape[0] for lat in latents['train']])
    # get separated by dataset as well (keyed by session index 0..n_datasets-1)
    latents_sess = {}
    trial_idxs_sess = {}
    for d in range(n_datasets):
        latents_sess[d], trial_idxs_sess[d] = get_latent_arrays_by_dtype(
            data_generator, sess_idxs=d, data_key=data_key)
    print('done')

    if hparams['model_class'] in ('arhmm', 'hmm'):
        ae_latents_file = data_generator.datasets[0].paths['ae_latents']
        hparams['ae_model_path'] = os.path.dirname(ae_latents_file)
        hparams['ae_model_latents_file'] = ae_latents_file

    # collect model constructor inputs; the observation model depends on the
    # noise model and on whether autoregressive lags are used
    # noise_type -> (obs type with AR lags, obs type without AR lags)
    obs_types = {
        'gaussian': ('ar', 'gaussian'),
        'studentst': ('robust_ar', 'studentst'),
    }
    if hparams['noise_type'] not in obs_types:
        raise ValueError('%s is not a valid noise type' % hparams['noise_type'])
    ar_obs_type, static_obs_type = obs_types[hparams['noise_type']]
    if hparams['n_arhmm_lags'] > 0:
        if not hparams['model_class'].startswith('arhmm'):  # 'arhmm' or 'arhmm-labels'
            raise ValueError('Must specify model_class as arhmm when using AR lags')
        obs_type = ar_obs_type
        obs_kwargs = {'lags': hparams['n_arhmm_lags']}
        obs_init_kwargs = {'localize': True}
    else:
        if not hparams['model_class'].startswith('hmm'):  # 'hmm' or 'hmm-labels'
            raise ValueError('Must specify model_class as hmm when using 0 AR lags')
        obs_type = static_obs_type
        obs_kwargs = None
        obs_init_kwargs = {}

    if hparams['kappa'] == 0:
        transitions = 'stationary'
        transition_kwargs = None
    else:
        transitions = 'sticky'
        transition_kwargs = {'kappa': hparams['kappa']}

    print('constructing model...', end='')
    np.random.seed(hparams['rng_seed_model'])
    hmm = ssm.HMM(
        hparams['n_arhmm_states'], obs_dim,
        observations=obs_type, observation_kwargs=obs_kwargs,
        transitions=transitions, transition_kwargs=transition_kwargs)
    hmm.initialize(latents['train'])
    hmm.observations.initialize(latents['train'], **obs_init_kwargs)
    # save out hparams as csv and dict
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    print('done')

    # ####################
    # ### TRAIN MODEL ###
    # ####################

    # TODO: move fitting into own function
    # TODO: adopt early stopping strategy from ssm
    # precompute normalizers (total number of observed values) for the splits
    # whose log-likelihoods are normalized in aggregate; test log-likelihoods
    # are normalized per trial further below, so no test normalizer is needed
    n_datapoints = {}
    n_datapoints_sess = {}
    for dtype in ('train', 'val'):
        n_datapoints[dtype] = np.vstack(latents[dtype]).size
        n_datapoints_sess[dtype] = {
            d: np.vstack(latents_sess[d][dtype]).size for d in range(n_datasets)}

    for epoch in range(hparams['n_iters'] + 1):
        # Note: the 0th epoch has no training (randomly initialized model is evaluated) so we cycle
        # through `n_iters` training epochs

        print('epoch %03i/%03i' % (epoch, hparams['n_iters']))
        if epoch > 0:
            hmm.fit(latents['train'], method='em', num_iters=1, initialize=False)

        # export aggregated metrics on train/val data
        tr_ll = hmm.log_likelihood(latents['train']) / n_datapoints['train']
        val_ll = hmm.log_likelihood(latents['val']) / n_datapoints['val']
        exp.log({
            'epoch': epoch, 'dataset': -1, 'tr_loss': tr_ll, 'val_loss': val_ll, 'trial': -1})

        # export individual session metrics on train/val data; iterate over the
        # same session indices used to key `latents_sess`
        for d in range(n_datasets):
            tr_ll = hmm.log_likelihood(latents_sess[d]['train']) / n_datapoints_sess['train'][d]
            val_ll = hmm.log_likelihood(latents_sess[d]['val']) / n_datapoints_sess['val'][d]
            exp.log({
                'epoch': epoch, 'dataset': d, 'tr_loss': tr_ll, 'val_loss': val_ll, 'trial': -1})

    # export individual session metrics on test data (one entry per trial,
    # tagged with the final epoch)
    for d in range(n_datasets):
        for i, b in enumerate(trial_idxs_sess[d]['test']):
            trial_latents = latents_sess[d]['test'][i]
            test_ll = hmm.log_likelihood(trial_latents) / trial_latents.size
            exp.log({'epoch': epoch, 'dataset': d, 'test_loss': test_ll, 'trial': b})
    exp.save()

    # reconfigure model/states by usage (most-used state becomes state 0)
    zs = [hmm.most_likely_states(x) for x in latents['train']]
    usage = np.bincount(np.concatenate(zs), minlength=hmm.K)
    perm = np.argsort(usage)[::-1]
    hmm.permute(perm)

    # save model
    filepath = os.path.join(hparams['expt_dir'], 'version_%i' % exp.version, 'best_val_model.pt')
    with open(filepath, 'wb') as f:
        pickle.dump(hmm, f)

    # ######################
    # ### EVALUATE ARHMM ###
    # ######################

    # export states
    if hparams['export_states']:
        export_states(hparams, data_generator, hmm)

    # export training plots
    if hparams['export_train_plots']:
        print('creating training plots...', end='')
        version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % hparams['version'])
        save_file = os.path.join(version_dir, 'loss_training')
        export_train_plots(hparams, 'train', loss_type='ll', save_file=save_file)
        save_file = os.path.join(version_dir, 'loss_validation')
        export_train_plots(hparams, 'val', loss_type='ll', save_file=save_file)
        print('done')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)
def main(hparams, *args):
    """Fit an autoencoder-family model (AE/VAE/PS-VAE/conditional variants).

    Steps: blend architecture params for conv models, validate the latent
    budget, create a test-tube experiment and data generator, import and
    construct the model selected by ``model_class`` (inferring ``n_labels``
    from a validation batch for label-conditioned models), optionally load
    pretrained weights and wrap for multi-GPU, fit with ``method='ae'``, mark
    training complete, clean the test-tube dir, and optionally export plots.

    Args:
        hparams (dict or argparse.Namespace): hyperparameters; a non-dict is
            converted with ``vars``.
        *args: unused; accepted for call-site compatibility.

    Raises:
        ValueError: if a conv model requests more latents than ``max_latents``.
        NotImplementedError: if ``model_class`` is not recognized.
    """
    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    if hparams['model_type'] == 'conv':
        # blend outer hparams with architecture hparams (outer values win on
        # key collisions because they are unpacked last)
        hparams = {**hparams['architecture_params'], **hparams}

    # print hparams to console
    _print_hparams(hparams)

    if hparams['model_type'] == 'conv' and hparams['n_ae_latents'] > hparams[
            'max_latents']:
        raise ValueError(
            'Number of latents higher than max latents, architecture will not work'
        )

    # Start at random times (so test tube creates separate folders);
    # pass both bounds: the first positional arg of np.random.uniform is `low`,
    # so `uniform(3)` would mean low=3 > high=1, which numpy leaves undefined
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(0, 3))

    # create test-tube experiment
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    # ####################
    # ### CREATE MODEL ###
    # ####################

    def set_n_labels(data_generator, hparams):
        """Store the label dimension, read from one validation batch, in hparams."""
        data, _ = data_generator.next_batch('val')
        sh = data['labels'].shape
        hparams['n_labels'] = sh[2]  # [1, n_t, n_labels]

    print('constructing model...', end='')
    torch.manual_seed(hparams['rng_seed_model'])
    torch_rng_seed = torch.get_rng_state()
    hparams['model_build_rng_seed'] = torch_rng_seed
    hparams['n_datasets'] = len(sess_ids)

    model_class = hparams['model_class']
    if model_class == 'ae':
        from behavenet.models import AE as Model
    elif model_class == 'vae':
        from behavenet.models import VAE as Model
    elif model_class == 'beta-tcvae':
        from behavenet.models import BetaTCVAE as Model
    elif model_class == 'ps-vae':
        from behavenet.models import PSVAE as Model
    elif model_class == 'msps-vae':
        from behavenet.models import MSPSVAE as Model
    elif model_class == 'cond-vae':
        from behavenet.models import ConditionalVAE as Model
    elif model_class == 'cond-ae':
        from behavenet.models import ConditionalAE as Model
    elif model_class == 'cond-ae-msp':
        from behavenet.models import AEMSP as Model
    else:
        raise NotImplementedError(
            'The model class "%s" is not currently implemented' %
            hparams['model_class'])

    # label-aware models need `n_labels` before construction
    if model_class in ('ps-vae', 'msps-vae', 'cond-vae', 'cond-ae', 'cond-ae-msp'):
        set_n_labels(data_generator, hparams)

    model = Model(hparams)
    model.to(hparams['device'])

    # load pretrained weights if specified
    model = load_pretrained_ae(model, hparams)

    # Parallelize over gpus if desired
    if hparams['n_parallel_gpus'] > 1:
        from behavenet.models import CustomDataParallel
        model = CustomDataParallel(model)

    model.version = exp.version
    torch_rng_seed = torch.get_rng_state()
    hparams['training_rng_seed'] = torch_rng_seed

    # save out hparams as csv and dict
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    print('done')

    # ###################
    # ### TRAIN MODEL ###
    # ###################

    print(model)

    fit(hparams, model, data_generator, exp, method='ae')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)

    # export training plots
    if hparams['export_train_plots']:
        print('creating training plots...', end='')
        version_dir = os.path.join(hparams['expt_dir'],
                                   'version_%i' % hparams['version'])
        # (data split, output file stem) for each plot
        plot_specs = (('train', 'loss_training'), ('val', 'loss_validation'))
        if hparams['model_class'] == 'msps-vae':
            from behavenet.plotting.cond_ae_utils import plot_mspsvae_training_curves
            n_latents = (hparams['n_ae_latents'] - hparams['n_background'] -
                         hparams['n_labels'])
            for dtype, file_stem in plot_specs:
                plot_mspsvae_training_curves(
                    hparams,
                    alpha=hparams['ps_vae.alpha'],
                    beta=hparams['ps_vae.beta'],
                    delta=hparams['ps_vae.delta'],
                    rng_seed_model=hparams['rng_seed_model'],
                    n_latents=n_latents,
                    n_background=hparams['n_background'],
                    n_labels=hparams['n_labels'],
                    dtype=dtype,
                    save_file=os.path.join(version_dir, file_stem),
                    format='png',
                    version_dir=version_dir)
        else:
            for dtype, file_stem in plot_specs:
                export_train_plots(
                    hparams, dtype, save_file=os.path.join(version_dir, file_stem))
        print('done')
Beispiel #3
0
def main(hparams, *args):
    """Fit a single-session-agnostic AE (``AE`` class) and export the results.

    Steps: print hparams, blend architecture params for conv models, validate
    the latent budget, create a test-tube experiment and data generator,
    construct the AE, fit with ``method='ae'``, optionally export training
    plots, mark training complete, and clean the test-tube dir.

    Args:
        hparams (dict or argparse.Namespace): hyperparameters; a non-dict is
            converted with ``vars``.
        *args: unused; accepted for call-site compatibility.

    Raises:
        ValueError: if a conv model requests more latents than ``max_latents``.
    """
    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    # print hparams to console
    _print_hparams(hparams)

    if hparams['model_type'] == 'conv':
        # blend outer hparams with architecture hparams (architecture values
        # win on key collisions because they are unpacked last)
        hparams = {**hparams, **hparams['architecture_params']}

    if hparams['model_type'] == 'conv' and hparams['n_ae_latents'] > hparams[
            'max_latents']:
        raise ValueError(
            'Number of latents higher than max latents, architecture will not work'
        )

    # Start at random times (so test tube creates separate folders);
    # pass both bounds: the first positional arg of np.random.uniform is `low`,
    # so `uniform(1)` would always return exactly 1.0
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(0, 1))

    # create test-tube experiment
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    # ####################
    # ### CREATE MODEL ###
    # ####################

    print('constructing model...', end='')
    torch.manual_seed(hparams['rng_seed_model'])
    torch_rnd_seed = torch.get_rng_state()
    hparams['model_build_rnd_seed'] = torch_rnd_seed
    hparams['n_datasets'] = len(sess_ids)
    model = AE(hparams)
    model.to(hparams['device'])
    model.version = exp.version
    # record RNG state after construction so training can be reproduced
    torch_rnd_seed = torch.get_rng_state()
    hparams['training_rnd_seed'] = torch_rnd_seed

    # save out hparams as csv and dict
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    print('done')

    # ####################
    # ### TRAIN MODEL ###
    # ####################

    fit(hparams, model, data_generator, exp, method='ae')

    # export training plots
    if hparams['export_train_plots']:
        print('creating training plots...', end='')
        version_dir = os.path.join(hparams['expt_dir'],
                                   'version_%i' % hparams['version'])
        save_file = os.path.join(version_dir, 'loss_training')
        export_train_plots(hparams, 'train', save_file=save_file)
        save_file = os.path.join(version_dir, 'loss_validation')
        export_train_plots(hparams, 'val', save_file=save_file)
        print('done')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)
Beispiel #4
0
def main(hparams, *args):
    """Fit a neural <-> behavior decoder/encoder and export the results.

    Steps: create a test-tube experiment and data generator, derive
    ``input_size``/``output_size`` from ``model_class`` (neural dimensions are
    read from one example training trial of the first dataset), record the
    upstream AE/ARHMM paths, construct a ``Decoder``, fit with
    ``method='nll'``, mark training complete, and clean the test-tube dir.

    Args:
        hparams (dict or argparse.Namespace): hyperparameters; a non-dict is
            converted with ``vars``.
        *args: unused; accepted for call-site compatibility.

    Raises:
        ValueError: if ``model_class`` is not a supported decoder class.
    """
    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    # print hparams to console
    _print_hparams(hparams)

    # Start at random times (so test tube creates separate folders);
    # pass both bounds: the first positional arg of np.random.uniform is `low`,
    # so `uniform(1)` would always return exactly 1.0
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(0, 1))

    # create test-tube experiment
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    dataset = data_generator.datasets[0]
    ex_trial = dataset.batch_idxs['train'][0]
    i_sig = hparams['input_signal']
    o_sig = hparams['output_signal']
    model_class = hparams['model_class']

    # hparams key holding the behavioral-side dimension of each model class;
    # 'neural-*' classes map neural -> behavior, the others behavior -> neural
    behav_dim_keys = {
        'neural-arhmm': 'n_arhmm_states',
        'arhmm-neural': 'n_arhmm_states',
        'neural-ae': 'n_ae_latents',
        'neural-ae-me': 'n_ae_latents',
        'ae-neural': 'n_ae_latents',
        'neural-labels': 'n_labels',
        'labels-neural': 'n_labels',
    }
    if model_class not in behav_dim_keys:
        raise ValueError('%s is an invalid model class' %
                         hparams['model_class'])
    if model_class.startswith('neural-'):
        hparams['input_size'] = dataset[ex_trial][i_sig].shape[1]
        hparams['output_size'] = hparams[behav_dim_keys[model_class]]
    else:
        hparams['input_size'] = hparams[behav_dim_keys[model_class]]
        hparams['output_size'] = dataset[ex_trial][o_sig].shape[1]

    if model_class in ('neural-ae', 'neural-ae-me', 'ae-neural'):
        ae_latents_file = dataset.paths['ae_latents']
        hparams['ae_model_path'] = os.path.dirname(ae_latents_file)
        hparams['ae_model_latents_file'] = ae_latents_file
    elif model_class in ('neural-arhmm', 'arhmm-neural'):
        hparams['arhmm_model_path'] = os.path.dirname(
            dataset.paths['arhmm_states'])
        hparams['arhmm_model_states_file'] = dataset.paths['arhmm_states']

        # Store which AE was used for the ARHMM; close the file after reading.
        # NOTE(review): pickle.load executes arbitrary code on untrusted data;
        # this assumes meta_tags.pkl was written by a trusted prior fit.
        meta_tags_file = os.path.join(hparams['arhmm_model_path'], 'meta_tags.pkl')
        with open(meta_tags_file, 'rb') as f:
            tags = pickle.load(f)
        hparams['ae_model_latents_file'] = tags['ae_model_latents_file']

    # ####################
    # ### CREATE MODEL ###
    # ####################
    print('constructing model...', end='')
    torch.manual_seed(hparams['rng_seed_model'])
    torch_rng_seed = torch.get_rng_state()
    hparams['model_build_rng_seed'] = torch_rng_seed
    model = Decoder(hparams)
    model.to(hparams['device'])
    model.version = exp.version
    # record RNG state after construction so training can be reproduced
    torch_rng_seed = torch.get_rng_state()
    hparams['training_rng_seed'] = torch_rng_seed

    # save out hparams as csv and dict for easy reloading
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    print('done')

    # ####################
    # ### TRAIN MODEL ###
    # ####################

    fit(hparams, model, data_generator, exp, method='nll')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)
def main(hparams, *args):
    """Fit a convolutional decoder (labels -> frames) and export the results.

    Steps: blend architecture params for conv models, create a test-tube
    experiment and data generator, infer ``n_labels`` from one training batch,
    construct a ``ConvDecoder``, fit with ``method='conv-decoder'``, optionally
    export training plots, mark training complete, and clean the test-tube dir.

    Args:
        hparams (dict or argparse.Namespace): hyperparameters; a non-dict is
            converted with ``vars``.
        *args: unused; accepted for call-site compatibility.
    """
    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    if hparams['model_type'] == 'conv':
        # blend outer hparams with architecture hparams (architecture values
        # win on key collisions because they are unpacked last)
        hparams = {**hparams, **hparams['architecture_params']}

    # print hparams to console
    _print_hparams(hparams)

    # Start at random times (so test tube creates separate folders);
    # pass both bounds: the first positional arg of np.random.uniform is `low`,
    # so `uniform(1)` would always return exactly 1.0
    np.random.seed(random.randint(0, 1000))
    time.sleep(np.random.uniform(0, 1))

    # create test-tube experiment
    hparams, sess_ids, exp = create_tt_experiment(hparams)
    if hparams is None:
        print('Experiment exists! Aborting fit')
        return

    # build data generator
    data_generator = build_data_generator(hparams, sess_ids)

    # ####################
    # ### CREATE MODEL ###
    # ####################

    print('constructing model...', end='')
    torch.manual_seed(hparams['rng_seed_model'])
    torch_rnd_seed = torch.get_rng_state()
    hparams['model_build_rnd_seed'] = torch_rnd_seed
    hparams['n_datasets'] = len(sess_ids)
    # infer the label dimension from one training batch
    data, _ = data_generator.next_batch('train')
    sh = data['labels'].shape
    hparams['n_labels'] = sh[2]  # [1, n_t, n_labels]
    model = ConvDecoder(hparams)
    model.to(hparams['device'])

    # NOTE(review): unlike the AE fitting script, pretrained weights are not
    # loaded here (load_pretrained_ae was disabled); confirm this is intended.

    model.version = exp.version
    # record RNG state after construction so training can be reproduced
    torch_rnd_seed = torch.get_rng_state()
    hparams['training_rnd_seed'] = torch_rnd_seed

    # save out hparams as csv and dict
    hparams['training_completed'] = False
    export_hparams(hparams, exp)
    print('done')

    # ###################
    # ### TRAIN MODEL ###
    # ###################

    fit(hparams, model, data_generator, exp, method='conv-decoder')

    # export training plots
    if hparams['export_train_plots']:
        print('creating training plots...', end='')
        version_dir = os.path.join(hparams['expt_dir'],
                                   'version_%i' % hparams['version'])
        save_file = os.path.join(version_dir, 'loss_training')
        export_train_plots(hparams, 'train', save_file=save_file)
        save_file = os.path.join(version_dir, 'loss_validation')
        export_train_plots(hparams, 'val', save_file=save_file)
        print('done')

    # update hparams upon successful training
    hparams['training_completed'] = True
    export_hparams(hparams, exp)

    # get rid of unneeded logging info
    _clean_tt_dir(hparams)