Example #1
def fit_pcnorm(modelspec,
               est: recording.Recording,
               metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               n_pcs=2,
               **context):
    '''
    Required Arguments:
     est           A recording object
     modelspec     A modelspec object

    Optional Arguments:
     (currently the same as fit_tf; see that function for details)

    Returns
     dictionary: {'modelspec': updated_modelspec, 'save_context': True}
    '''

    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}
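    # note: fit_pcnorm shares fit_tf's signature, but it fits with scipy_minimize, so
    # the TF-specific arguments (use_modelspec_init, optimizer, learning_rate,
    # batch_size, seed, initializer, freeze_layers, epoch_name) are accepted but
    # not used here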

    start_time = time.time()

    modelspec = copy.deepcopy(modelspec)

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    conditions = [
        "_".join(k.split("_")[1:]) for k in est.signals.keys()
        if k.startswith("mask_")
    ]
    if (len(conditions) > 2) and any(
        [c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")
    #conditions = conditions[0:2]
    #conditions = ['large','small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"Data subset for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est['pred0'].as_continuous()
    residual = resp - pred0

    pca = PCA(n_components=n_pcs)
    pca.fit(residual.T)
    pc_axes = pca.components_

    pcproj = residual.T.dot(pc_axes.T).T

    group_pc = [pcproj[:, idx].std(axis=1) for idx in group_idx]
    resp_std = resp.std(axis=1)

    if metric is None:
        metric = lambda d: pc_err(d,
                                  pred_name='pred',
                                  pred0_name='pred0',
                                  group_idx=group_idx,
                                  group_pc=group_pc,
                                  pc_axes=pc_axes,
                                  resp_std=resp_std)

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translates to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker,
                      modelspec=modelspec,
                      data=est,
                      segmentor=segmentor,
                      evaluator=evaluator,
                      metric=metric,
                      display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Run the fitter to get an improved sigma (the packed parameter vector),
    # then unpack it back into a modelspec.
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e", start_err, final_err,
             final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', 'pcnorm')
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    ms.set_modelspec_metadata(improved_modelspec, 'n_parms',
                              len(improved_sigma))

    return {'modelspec': improved_modelspec.copy(), 'save_context': True}
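
# Illustrative sketch (standalone, synthetic data): how fit_pcnorm builds its fitting
# targets. Residuals (resp - pred0) are projected onto their top principal components,
# and the per-condition standard deviation along each PC (group_pc) becomes the target
# that the pc_err metric compares the model's prediction against.
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
resp = rng.normal(size=(20, 1000))            # (channels, time)
pred0 = rng.normal(size=(20, 1000))           # first-order prediction
group_idx = [np.arange(1000) < 500,           # e.g. boolean mask for one condition
             np.arange(1000) >= 500]          # and another

residual = resp - pred0
pca = PCA(n_components=2)
pca.fit(residual.T)                           # PCA across time samples
pc_axes = pca.components_                     # (n_pcs, channels)

pcproj = residual.T.dot(pc_axes.T).T          # (n_pcs, time)
group_pc = [pcproj[:, idx].std(axis=1) for idx in group_idx]
resp_std = resp.std(axis=1)
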
Example #2
def fit_tf(modelspec,
           est: recording.Recording,
           use_modelspec_init: bool = True,
           optimizer: str = 'adam',
           max_iter: int = 10000,
           cost_function: str = 'squared_error',
           early_stopping_steps: int = 5,
           early_stopping_tolerance: float = 5e-4,
           learning_rate: float = 1e-4,
           batch_size: typing.Union[None, int] = None,
           seed: int = 0,
           initializer: str = 'random_normal',
           filepath: typing.Union[str, Path] = None,
           freeze_layers: typing.Union[None, list] = None,
           IsReload: bool = False,
           epoch_name: str = "REFERENCE",
           **context) -> dict:
    """TODO

    :param est:
    :param modelspec:
    :param use_modelspec_init:
    :param optimizer:
    :param max_iter:
    :param cost_function:
    :param early_stopping_steps:
    :param early_stopping_tolerance:
    :param learning_rate:
    :param batch_size:
    :param seed:
    :param filepath:
    :param freeze_layers: Indexes of layers to freeze prior to training. These are modelspec indexes, which are offset
      from the TF model layer indexes.
    :param IsReload:
    :param epoch_name:
    :param context:

    :return: dict {'modelspec': modelspec}
    """

    if IsReload:
        return {}

    tf.random.set_seed(seed)
    np.random.seed(seed)
    #os.environ['TF_DETERMINISTIC_OPS'] = '1'   # makes output deterministic, but reduces prediction accuracy

    log.info('Building tensorflow keras model from modelspec.')
    nems.utils.progress_fun()

    # figure out where to save model checkpoints
    if filepath is None:
        filepath = modelspec.meta['modelpath']

    # if job is running on slurm, need to change model checkpoint dir
    job_id = os.environ.get('SLURM_JOBID', None)
    if job_id is not None:
        # keep a record of the job id
        modelspec.meta['slurm_jobid'] = job_id

        log_dir_root = Path('/mnt/scratch')
        assert log_dir_root.exists()
        log_dir_sub = Path('SLURM_JOBID' + job_id) / str(modelspec.meta['batch'])\
                      / modelspec.meta.get('cellid', "NOCELL")\
                      / modelspec.meta['modelname']
        filepath = log_dir_root / log_dir_sub

    filepath = Path(filepath)
    if not filepath.exists():
        filepath.mkdir(exist_ok=True, parents=True)

    checkpoint_filepath = filepath / 'weights.hdf5'
    tensorboard_filepath = filepath / 'logs'
    gradient_filepath = filepath / 'gradients'

    # update seed based on fit index so that successive fits start from different random initializations
    seed += modelspec.fit_index

    # need to get duration of stims in order to reshape data
    #epoch_name = 'REFERENCE'  # TODO: this should not be hardcoded
    # moved to input parameter

    input_name = modelspec.meta.get('input_name', 'stim')
    output_name = modelspec.meta.get('output_name', 'resp')

    # also grab the fs
    fs = est[input_name].fs

    if (epoch_name is not None) and (epoch_name != ""):
        # extract out the raw data, and reshape to (batch, time, channel)
        stim_train = np.transpose(
            est[input_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
        resp_train = np.transpose(
            est[output_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
    else:
        # extract data as a single batch size (1, time, channel)
        stim_train = np.transpose(
            est.apply_mask()[input_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])
        resp_train = np.transpose(
            est.apply_mask()[output_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])

    log.info(
        f'Feature dimensions: {stim_train.shape}; Data dimensions: {resp_train.shape}.'
    )

    # get state if present, and setup training data
    if 'state' in est.signals:
        if (epoch_name is not None) and (epoch_name != ""):
            state_train = np.transpose(
                est['state'].extract_epoch(epoch=epoch_name, mask=est['mask']),
                [0, 2, 1])
        else:
            state_train = np.transpose(
                est.apply_mask()['state'].as_continuous()[np.newaxis, ...],
                [0, 2, 1])
        state_shape = state_train.shape
        log.info(f'State dimensions: {state_shape}')
        train_data = [stim_train, state_train]
    else:
        state_train, state_shape = None, None
        train_data = stim_train

    # correlation for monitoring
    # TODO: tf.utils?
    def pearson(y_true, y_pred):
        return tfp.stats.correlation(y_true,
                                     y_pred,
                                     event_axis=None,
                                     sample_axis=None)

    # get the layers and build the model
    cost_fn = loss_functions.get_loss_fn(cost_function)
    model_layers = modelspec.modelspec2tf2(
        use_modelspec_init=use_modelspec_init,
        seed=seed,
        fs=fs,
        initializer=initializer)
    if np.any([isinstance(layer, Conv2D_NEMS) for layer in model_layers]):
        # need a "channel" dimension for Conv2D (like rgb channels, not frequency). Only 1 channel for our data.
        stim_train = stim_train[..., np.newaxis]
        train_data = train_data[..., np.newaxis]

    # do some batch sizing logic
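    # batch_size=0 means full-batch training (one batch containing every extracted
    # epoch); batch_size=None falls through to the Keras model.fit default of 32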
    batch_size = stim_train.shape[0] if batch_size == 0 else batch_size

    model = modelbuilder.ModelBuilder(
        name='Test-model',
        layers=model_layers,
        learning_rate=learning_rate,
        loss_fn=cost_fn,
        optimizer=optimizer,
        metrics=[pearson],
    ).build_model(input_shape=stim_train.shape,
                  state_shape=state_shape,
                  batch_size=batch_size)

    # tracking early termination
    model.early_terminated = False

    # create the callbacks
    early_stopping = callbacks.DelayedStopper(
        monitor='loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=False)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=str(checkpoint_filepath),
        save_best_only=False,
        save_weights_only=True,
        save_freq=100 * stim_train.shape[0],
        monitor='loss',
        verbose=0)
    sparse_logger = callbacks.SparseProgbarLogger(n_iters=10)
    nan_terminate = tf.keras.callbacks.TerminateOnNaN()
    nan_weight_terminate = callbacks.TerminateOnNaNWeights()
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=str(tensorboard_filepath),  # TODO: generic tensorboard dir?
        histogram_freq=0,  # record the distribution of the weights
        write_graph=False,
        update_freq='epoch',
        profile_batch=0)
    # gradient_logger = callbacks.GradientLogger(filepath=str(gradient_filepath),
    #                                            train_input=stim_train,
    #                                            model=model)

    # freeze layers
    if freeze_layers is not None:
        for freeze_index in freeze_layers:
            log.info(
                f'Freezing layer #{freeze_index}: "{model.layers[freeze_index + 1].name}".'
            )
            model.layers[freeze_index + 1].trainable = False

    # save an initial set of weights before fitting, in case of termination before any checkpoints are written
    #log.info('saving weights to : %s', str(checkpoint_filepath) )
    model.save_weights(str(checkpoint_filepath), overwrite=True)

    if version.parse(tf.__version__) >= version.parse("2.2.0"):
        callback0 = [sparse_logger]
        verbose = 0
    else:
        callback0 = []
        verbose = 2

    log.info(f'Fitting model (batch_size={batch_size})...')
    history = model.fit(
        train_data,
        resp_train,
        # validation_split=0.2,
        verbose=verbose,
        epochs=max_iter,
        batch_size=batch_size,
        callbacks=callback0 + [
            nan_terminate,
            nan_weight_terminate,
            early_stopping,
            checkpoint,
            # enable the below to log tracked parameters to tensorboard
            # tensorboard,
            # enable the below to record gradients to visualize in tensorboard; this is very slow,
            # and loading all this into tensorboard can use A LOT of memory
            # gradient_logger,
        ])

    # did we terminate on a nan loss or weights? Load checkpoint if so
    if np.all(np.isnan(model.predict(train_data))) or model.early_terminated:
        # TODO: should this be np.any()?
        log.warning(
            'Model terminated on nan loss or weights, restoring saved weights.'
        )
        try:
            # this can fail if it nans out before a single checkpoint gets saved, either because no saved weights
            # exist, or because it tries to load a different model than the init
            model.load_weights(str(checkpoint_filepath))
            log.warning('Reloaded previous saved weights after nan loss.')
        except (tf.errors.NotFoundError, ValueError):
            pass

    modelspec = tf2modelspec(model, modelspec)

    contains_tf_only_layers = np.any(
        ['tf_only' in m['fn'] for m in modelspec.modules])
    if not contains_tf_only_layers:
        # compare the predictions from the model and modelspec
        error = compare_ms_tf(modelspec, model, est, train_data)
        if error > 1e-5:
            log.warning(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
        else:
            log.info(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
    else:
        # nothing to compare, ms evaluation is not implemented for this type of model
        pass

    # add in some relevant meta information
    modelspec.meta['n_parms'] = len(modelspec.phi_vector)
    try:
        n_epochs = len(history.history['loss'])
    except KeyError:
        n_epochs = 0
    try:
        # keep the largest epoch count seen across fits
        prev_n_epochs = modelspec.meta['extra_results']
        modelspec.meta['extra_results'] = max(prev_n_epochs, n_epochs)
    except KeyError:
        modelspec.meta['extra_results'] = n_epochs

    nems.utils.progress_fun()

    return {'modelspec': modelspec}
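
# Illustrative sketch (standalone, stock Keras only): the callback pattern fit_tf
# relies on. The NEMS code uses custom callbacks (DelayedStopper, SparseProgbarLogger,
# TerminateOnNaNWeights); this sketch substitutes their closest built-in equivalents
# on a tiny stand-in model, just to show the structure of the training loop.
import numpy as np
import tensorflow as tf

x = np.random.randn(16, 50, 18).astype('float32')    # (batch, time, channel)
y = np.random.randn(16, 50, 1).astype('float32')

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(50, 18)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss='mse')

fit_callbacks = [
    tf.keras.callbacks.TerminateOnNaN(),
    tf.keras.callbacks.EarlyStopping(monitor='loss', patience=150, min_delta=5e-4),
    tf.keras.callbacks.ModelCheckpoint('weights.hdf5', monitor='loss',
                                       save_weights_only=True, save_best_only=False),
]
history = model.fit(x, y, epochs=50, batch_size=16, verbose=0, callbacks=fit_callbacks)
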
Example #3
def fit_ccnorm(modelspec,
               est: recording.Recording,
               metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               shrink_cc: float = 0,
               noise_pcs: int = 0,
               shared_pcs: int = 0,
               also_fit_resp: bool = False,
               force_psth: bool = False,
               use_metric: typing.Union[None, str] = None,
               alpha: float = 0.1,
               beta: float = 1,
               exclude_idx=None,
               exclude_after=None,
               freeze_idx=None,
               freeze_after=None,
               **context):
    '''
    Required Arguments:
     est           A recording object
     modelspec     A modelspec object

    Optional Arguments:
     (currently the same as fit_tf; see that function for details)

    Returns
     dictionary: {'modelspec': updated_modelspec}
    '''

    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    start_time = time.time()

    fit_index = modelspec.fit_index
    if (exclude_idx is not None) or (freeze_idx is not None) or \
            (exclude_after is not None) or (freeze_after is not None):
        modelspec0 = modelspec.copy()
        modelspec, include_set = modelspec_freeze_layers(
            modelspec,
            include_idx=None,
            exclude_idx=exclude_idx,
            exclude_after=exclude_after,
            freeze_idx=freeze_idx,
            freeze_after=freeze_after)
        modelspec0.set_fit(fit_index)
        modelspec.set_fit(fit_index)
    else:
        include_set = None

    # Computing PCs before masking out unwanted stimuli in order to
    # preserve match with epochs
    epoch_regex = "^STIM_"
    counts = est.epochs['name'].value_counts()
    stims = [name for name, count in counts.items()
             if bool(re.search(epoch_regex, name)) and (count >= 8)]

    Rall_u = est.apply_mask()['psth'].as_continuous().T
    # can't simply extract evoked periods for refs because they can be longer/shorter if they came after a target
    # and/or were the last stim, so masking prestim/poststim doesn't work. Do it manually.
    #d = est['resp'].extract_epochs(stims, mask=est['mask'])

    #R = [v.mean(axis=0) for (k, v) in d.items()]
    #R = [np.reshape(np.transpose(v,[1,0,2]),[v.shape[1],-1]) for (k, v) in d.items()]
    #Rall_u = np.hstack(R).T

    pca = PCA(n_components=2)
    pca.fit(Rall_u)
    pc_axes = pca.components_
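    # pc_axes: top-2 PCs of the trial-averaged response (psth), computed before the
    # mask is restricted below so they reflect the full stimulus set; pcproj_std
    # (computed further down) measures residual variance along these axes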

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask_small' in est.signals.keys():
        log.info('resetting mask to the mask_small + mask_large subset')
        est['mask'] = est['mask']._modified_copy(data=est['mask_small']._data +
                                                 est['mask_large']._data)

    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    # evaluate the first modules so that a first-order prediction (pred0) is
    # available if the model generates one (e.g., sdexp)
    est = modelspec.evaluate(est, stop=2)
    if ('pred0' in est.signals.keys()) and (not force_psth):
        input_name = 'pred0'
        log.info('Found pred0 for fitting CC')
    else:
        input_name = 'psth'
        log.info('No pred0, using psth for fitting CC')

    conditions = [
        "_".join(k.split("_")[1:]) for k in est.signals.keys()
        if k.startswith("mask_")
    ]
    if (len(conditions) > 2) and any(
        [c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")
    #conditions = conditions[0:2]
    #conditions = ['large','small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"cc data for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est[input_name].as_continuous()
    if shrink_cc > 0:
        log.info(f'cc approx: shrink_cc={shrink_cc}')
        group_cc = [
            cc_shrink(resp[:, idx] - pred0[:, idx], sigrat=shrink_cc)
            for idx in group_idx
        ]
    elif shared_pcs > 0:
        log.info(f'cc approx: shared_pcs={shared_pcs}')
        cc = np.cov(resp - pred0)
        u, s, vh = np.linalg.svd(cc)
        U = u[:, :shared_pcs] @ u[:, :shared_pcs].T
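        # U projects onto the subspace spanned by the top `shared_pcs` eigenvectors of
        # the pooled noise covariance; it is passed to cc_shared_space along with each
        # condition's residuals below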

        group_cc = [
            cc_shared_space(resp[:, idx] - pred0[:, idx], U)
            for idx in group_idx
        ]
    elif noise_pcs > 0:
        log.info(f'cc approx: noise_pcs={noise_pcs}')
        group_cc = [
            cc_lowrank(resp[:, idx] - pred0[:, idx], n_pcs=noise_pcs)
            for idx in group_idx
        ]
    else:
        group_cc = [np.cov(resp[:, idx] - pred0[:, idx]) for idx in group_idx]
    group_cc_raw = [np.cov(resp[:, idx] - pred0[:, idx]) for idx in group_idx]

    # variance of projection onto PCs (PCs computed above before masking)
    pcproj0 = (resp - pred0).T.dot(pc_axes.T).T
    pcproj_std = pcproj0.std(axis=1)

    if (use_metric == 'cc_err_w'):

        def metric(d, verbose=False):
            return metrics.cc_err_w(d,
                                    pred_name='pred',
                                    pred0_name=input_name,
                                    group_idx=group_idx,
                                    group_cc=group_cc,
                                    alpha=alpha,
                                    pcproj_std=None,
                                    pc_axes=None,
                                    verbose=verbose)

        log.info(f"fit_ccnorm metric: cc_err_w (alpha={alpha})")

    elif (metric is None) and also_fit_resp:
        log.info(f"resp_cc_err: pred0_name: {input_name} beta: {beta}")
        metric = lambda d: metrics.resp_cc_err(d,
                                               pred_name='pred',
                                               pred0_name=input_name,
                                               group_idx=group_idx,
                                               group_cc=group_cc,
                                               beta=beta,
                                               pcproj_std=None,
                                               pc_axes=None)

    elif (use_metric == 'cc_err_md'):

        def metric(d, verbose=False):
            return metrics.cc_err_md(d,
                                     pred_name='pred',
                                     pred0_name=input_name,
                                     group_idx=group_idx,
                                     group_cc=group_cc,
                                     pcproj_std=None,
                                     pc_axes=None)

        log.info(f"fit_ccnorm metric: cc_err_md")

    elif (metric is None):
        #def cc_err(result, pred_name='pred_lv', resp_name='resp', pred0_name='pred',
        #   group_idx=None, group_cc=None, pcproj_std=None, pc_axes=None):
        # current implementation of cc_err
        metric = lambda d: metrics.cc_err(d,
                                          pred_name='pred',
                                          pred0_name=input_name,
                                          group_idx=group_idx,
                                          group_cc=group_cc,
                                          pcproj_std=None,
                                          pc_axes=None)
        log.info(f"fit_ccnorm metric: cc_err")

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translates to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker,
                      modelspec=modelspec,
                      data=est,
                      segmentor=segmentor,
                      evaluator=evaluator,
                      metric=metric,
                      display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Run the fitter to get an improved sigma (the packed parameter vector),
    # then unpack it back into a modelspec.
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e", start_err, final_err,
             final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)

    if include_set is not None:
        # pull out updated phi values from improved_modelspec, include_set only
        improved_modelspec = \
            modelspec_unfreeze_layers(improved_modelspec, modelspec0, include_set)
        improved_modelspec.set_fit(fit_index)

    log.info(
        f"Updating improved modelspec with fit_idx={improved_modelspec.fit_index}"
    )
    improved_modelspec.meta['fitter'] = 'ccnorm'
    improved_modelspec.meta['n_parms'] = len(improved_sigma)
    if modelspec.fit_count == 1:
        improved_modelspec.meta['fit_time'] = elapsed_time
        improved_modelspec.meta['loss'] = final_err
    else:
        if modelspec.fit_index == 0:
            improved_modelspec.meta['fit_time'] = np.zeros(
                improved_modelspec.fit_count)
            improved_modelspec.meta['loss'] = np.zeros(
                improved_modelspec.fit_count)
        improved_modelspec.meta['fit_time'][fit_index] = elapsed_time
        improved_modelspec.meta['loss'][fit_index] = final_err

    return {'modelspec': improved_modelspec}
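
# Illustrative sketch (standalone, synthetic data): the shared-subspace covariance
# approximation used above when shared_pcs > 0. The leading eigenvectors of the pooled
# noise covariance define a projector U; projecting a condition's residuals through U
# before taking the covariance restricts the noise-correlation estimate to that shared
# subspace (cc_shared_space in nems performs the library's own version of this step).
import numpy as np

rng = np.random.default_rng(0)
resp = rng.normal(size=(20, 1000))            # (channels, time)
pred0 = rng.normal(size=(20, 1000))
k = 3                                         # number of shared noise PCs

cc = np.cov(resp - pred0)                     # pooled noise covariance, (20, 20)
u, s, vh = np.linalg.svd(cc)
U = u[:, :k] @ u[:, :k].T                     # projector onto the top-k subspace

idx = np.arange(1000) < 500                   # one condition's samples
cc_condition = np.cov(U @ (resp[:, idx] - pred0[:, idx]))   # covariance within the shared subspace
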
Example #4
def fit_tf(modelspec,
           est: recording.Recording,
           use_modelspec_init: bool = True,
           optimizer: str = 'adam',
           max_iter: int = 10000,
           cost_function: str = 'squared_error',
           early_stopping_steps: int = 5,
           early_stopping_tolerance: float = 5e-4,
           early_stopping_val_split: float = 0,
           learning_rate: float = 1e-4,
           variable_learning_rate: bool = False,
           batch_size: typing.Union[None, int] = None,
           seed: int = 0,
           initializer: str = 'random_normal',
           filepath: typing.Union[str, Path] = None,
           freeze_layers: typing.Union[None, list] = None,
           IsReload: bool = False,
           epoch_name: str = "REFERENCE",
           use_tensorboard: bool = False,
           kernel_regularizer: str = None,
           **context) -> dict:
    """TODO

    :param est:
    :param modelspec:
    :param use_modelspec_init:
    :param optimizer:
    :param max_iter:
    :param cost_function:
    :param early_stopping_steps:
    :param early_stopping_tolerance:
    :param early_stopping_val_split: fraction of the data held out as a validation split for early stopping (0 disables the split)
    :param learning_rate:
    :param variable_learning_rate: if True, use an exponentially decaying learning rate schedule instead of a fixed rate
    :param batch_size:
    :param seed:
    :param filepath:
    :param freeze_layers: Indexes of layers to freeze prior to training. These are modelspec indexes, which are offset
      from the TF model layer indexes.
    :param IsReload:
    :param epoch_name:
    :param use_tensorboard: if True, write training logs for TensorBoard
    :param kernel_regularizer: regularizer specification passed through to the TF layers
    :param context:

    :return: dict {'modelspec': modelspec}
    """

    if IsReload:
        return {}

    tf.random.set_seed(seed)
    np.random.seed(seed)
    #os.environ['TF_DETERMINISTIC_OPS'] = '1'   # makes output deterministic, but reduces prediction accuracy

    log.info('Building tensorflow keras model from modelspec.')
    nems.utils.progress_fun()

    # figure out where to save model checkpoints
    job_id = os.environ.get('SLURM_JOBID', None)
    if job_id is not None:
        # if job is running on slurm, need to change model checkpoint dir
        # keep a record of the job id
        modelspec.meta['slurm_jobid'] = job_id

        log_dir_root = Path('/mnt/scratch')
        assert log_dir_root.exists()
        log_dir_base = log_dir_root / Path('SLURM_JOBID' + job_id)
        log_dir_sub = Path(str(modelspec.meta['batch'])) \
                / modelspec.meta.get('cellid', "NOCELL") \
                / modelspec.get_longname()
        filepath = log_dir_base / log_dir_sub
        tbroot = filepath / 'logs'
    elif filepath is None:
        filepath = modelspec.meta['modelpath']
        tbroot = Path(f'/auto/data/tmp/tensorboard/')
    else:
        tbroot = Path(f'/auto/data/tmp/tensorboard/')

    filepath = Path(filepath)
    if not filepath.exists():
        filepath.mkdir(exist_ok=True, parents=True)
    cellid = modelspec.meta.get('cellid', 'CELL')
    tbpath = tbroot / (str(modelspec.meta['batch']) + '_' + cellid + '_' +
                       modelspec.meta['modelname'])
    # TODO: should this code just be deleted then?
    if 0 & use_tensorboard:
        # disabled, this is dumb. it deletes the previous round of fitting (eg, tfinit)
        fileList = glob.glob(str(tbpath / '*' / '*'))
        for filePath in fileList:
            try:
                os.remove(filePath)
            except:
                print("Error while deleting file : ", filePath)

    checkpoint_filepath = filepath / 'weights.hdf5'
    tensorboard_filepath = tbpath
    gradient_filepath = filepath / 'gradients'

    # update seed based on fit index
    seed += modelspec.fit_index

    # special case: if the frozen layers are exactly the first N layers of the model,
    # truncate the model (remove those input layers) instead of freezing them in place
    if (freeze_layers is not None) and len(freeze_layers) and \
            (len(freeze_layers) == freeze_layers[-1] + 1):
        truncate_model = True
        modelspec_trunc, est_trunc = \
            initializers.modelspec_remove_input_layers(modelspec, est, remove_count=len(freeze_layers))
        modelspec_original = modelspec
        est_original = est
        modelspec = modelspec_trunc
        est = est_trunc
        freeze_layers = None
        log.info(
            f"Special case of freezing: truncating model. fit_index={modelspec.fit_index} cell_index={modelspec.cell_index}"
        )
    else:
        truncate_model = False

    input_name = modelspec.meta.get('input_name', 'stim')
    output_name = modelspec.meta.get('output_name', 'resp')

    # also grab the fs
    fs = est[input_name].fs

    if (epoch_name is not None) and (epoch_name != ""):
        # extract out the raw data, and reshape to (batch, time, channel)
        stim_train = np.transpose(
            est[input_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
        resp_train = np.transpose(
            est[output_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
    else:
        # extract data as a single batch size (1, time, channel)
        stim_train = np.transpose(
            est.apply_mask()[input_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])
        resp_train = np.transpose(
            est.apply_mask()[output_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])

    log.info(
        f'Feature dimensions: {stim_train.shape}; Data dimensions: {resp_train.shape}.'
    )

    if True:
        log.info("adding a tiny bit of noise to resp_train")
        resp_train = resp_train + np.random.randn(*resp_train.shape) / 10000
    # get state if present, and setup training data
    if 'state' in est.signals:
        if (epoch_name is not None) and (epoch_name != ""):
            state_train = np.transpose(
                est['state'].extract_epoch(epoch=epoch_name, mask=est['mask']),
                [0, 2, 1])
        else:
            state_train = np.transpose(
                est.apply_mask()['state'].as_continuous()[np.newaxis, ...],
                [0, 2, 1])
        state_shape = state_train.shape
        log.info(f'State dimensions: {state_shape}')
        train_data = [stim_train, state_train]
    else:
        state_train, state_shape = None, None
        train_data = stim_train

    # get the layers and build the model
    cost_fn = loss_functions.get_loss_fn(cost_function)
    #model_layers = modelspec.modelspec2tf2(
    #    use_modelspec_init=use_modelspec_init, seed=seed, fs=fs,
    #    initializer=initializer, freeze_layers=freeze_layers,
    #    kernel_regularizer=kernel_regularizer)
    model_layers = modelbuilder.modelspec2tf(
        modelspec,
        use_modelspec_init=use_modelspec_init,
        seed=seed,
        fs=fs,
        initializer=initializer,
        freeze_layers=freeze_layers,
        kernel_regularizer=kernel_regularizer)

    if np.any([isinstance(layer, Conv2D_NEMS) for layer in model_layers]):
        # need a "channel" dimension for Conv2D (like rgb channels, not frequency). Only 1 channel for our data.
        stim_train = stim_train[..., np.newaxis]
        train_data = train_data[..., np.newaxis]

    # do some batch sizing logic
    batch_size = stim_train.shape[0] if batch_size == 0 else batch_size

    if variable_learning_rate:
        # TODO: allow other schedule options instead of hard-coding exp decay?
        # TODO: expose exp decay kwargs as kw options? not clear how to choose these parameters
        learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=learning_rate,
            decay_steps=10000,
            decay_rate=0.9)

    from nems.tf.loss_functions import pearson
    model = modelbuilder.ModelBuilder(
        name='Test-model',
        layers=model_layers,
        learning_rate=learning_rate,
        loss_fn=cost_fn,
        optimizer=optimizer,
        metrics=[pearson],
    ).build_model(input_shape=stim_train.shape,
                  state_shape=state_shape,
                  batch_size=batch_size)

    if freeze_layers is not None:
        for freeze_index in freeze_layers:
            log.info(
                f'TF layer #{freeze_index}: "{model.layers[freeze_index + 1].name}" is not trainable.'
            )

    # tracking early termination
    model.early_terminated = False

    # create the callbacks
    early_stopping = callbacks.DelayedStopper(
        monitor='val_loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=True)
    regular_stopping = callbacks.DelayedStopper(
        monitor='loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=True)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=str(checkpoint_filepath),
        save_best_only=False,
        save_weights_only=True,
        save_freq=100 * stim_train.shape[0],
        monitor='loss',
        verbose=0)
    sparse_logger = callbacks.SparseProgbarLogger(n_iters=50)
    nan_terminate = tf.keras.callbacks.TerminateOnNaN()
    nan_weight_terminate = callbacks.TerminateOnNaNWeights()
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=str(tensorboard_filepath),  # TODO: generic tensorboard dir?
        histogram_freq=0,  # record the distribution of the weights
        write_graph=False,
        update_freq='epoch',
        profile_batch=0)
    # gradient_logger = callbacks.GradientLogger(filepath=str(gradient_filepath),
    #                                            train_input=stim_train,
    #                                            model=model)

    # save an initial set of weights before fitting, in case of termination before any checkpoints are written
    #log.info('saving weights to : %s', str(checkpoint_filepath) )
    model.save_weights(str(checkpoint_filepath), overwrite=True)

    if version.parse(tf.__version__) >= version.parse("2.2.0"):
        callback0 = [sparse_logger]
        verbose = 0
    else:
        callback0 = []
        verbose = 2
    # enable the below to log tracked parameters to tensorboard
    if use_tensorboard:
        callback0.append(tensorboard)
        log.info(f'Enabling tensorboard, log: {str(tensorboard_filepath)}')
        # enable the below to record gradients to visualize in tensorboard; this is very slow,
        # and loading all this into tensorboard can use A LOT of memory
        # callback0.append(gradient_logger)

    if early_stopping_val_split > 0:
        callback0.append(early_stopping)
        log.info(
            f'Enabling early stopping, val split: {str(early_stopping_val_split)}'
        )
    else:
        callback0.append(regular_stopping)
        log.info(f'Stop tolerance: min_delta={early_stopping_tolerance}')

    log.info(f'Fitting model (batch_size={batch_size})...')
    history = model.fit(train_data,
                        resp_train,
                        validation_split=early_stopping_val_split,
                        verbose=verbose,
                        epochs=max_iter,
                        callbacks=callback0 + [
                            nan_terminate,
                            nan_weight_terminate,
                            checkpoint,
                        ],
                        batch_size=batch_size)

    # did we terminate on a nan loss or weights? Load checkpoint if so
    if np.all(np.isnan(model.predict(train_data))) or model.early_terminated:
        # TODO: should this be np.any()?
        log.warning(
            'Model terminated on nan loss or weights, restoring saved weights.'
        )
        try:
            # this can fail if it nans out before a single checkpoint gets saved, either because no saved weights
            # exist, or because it tries to load a different model than the init
            model.load_weights(str(checkpoint_filepath))
            log.warning('Reloaded previous saved weights after nan loss.')
        except (tf.errors.NotFoundError, ValueError):
            pass

    modelspec = tf2modelspec(model, modelspec)

    if truncate_model:
        log.info("Special case of freezing: restoring truncated model!!!")
        #modelspec_restored, rec_restored = modelspec_restore_input_layers(modelspec_trunc, rec_trunc, modelspec_original)
        modelspec_restored, est_restored = initializers.modelspec_restore_input_layers(
            modelspec, est, modelspec_original)
        est = est_original
        modelspec = modelspec_restored

    # debug: dump modelspec parameters
    #for i in range(len(modelspec)):
    #    log.info(modelspec.phi[i])

    contains_tf_only_layers = np.any(
        ['tf_only' in m['fn'] for m in modelspec.modules])
    if not contains_tf_only_layers:
        # compare the predictions from the model and modelspec
        error = compare_ms_tf(modelspec, model, est, train_data)
        if error > 1e-5:
            log.warning(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
        else:
            log.info(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
    else:
        # nothing to compare, ms evaluation is not implemented for this type of model
        pass

    # add in some relevant meta information
    modelspec.meta['n_parms'] = len(modelspec.phi_vector)
    try:
        n_epochs = len(history.history['loss'])
        if 'val_loss' in history.history.keys():
            #val_stop = np.argmin(history.history['val_loss'])
            #loss = history.history['loss'][val_stop]
            loss = np.nanmin(history.history['val_loss'])
        else:
            loss = np.nanmin(history.history['loss'])

    except KeyError:
        n_epochs = 0
        loss = 0
    if modelspec.fit_count == 1:
        modelspec.meta['n_epochs'] = n_epochs
        modelspec.meta['loss'] = loss
    else:
        if modelspec.fit_index == 0:
            modelspec.meta['n_epochs'] = np.zeros(modelspec.fit_count)
            modelspec.meta['loss'] = np.zeros(modelspec.fit_count)
        modelspec.meta['n_epochs'][modelspec.fit_index] = n_epochs
        modelspec.meta['loss'][modelspec.fit_index] = loss

    try:
        # keep the largest epoch count seen across fits
        prev_n_epochs = modelspec.meta['extra_results']
        modelspec.meta['extra_results'] = max(prev_n_epochs, n_epochs)
    except KeyError:
        modelspec.meta['extra_results'] = n_epochs

    nems.utils.progress_fun()

    # clean up temp files
    if job_id is not None:
        log.info('removing temporary weights file(s)')
        shutil.rmtree(log_dir_base)

    return {'modelspec': modelspec}
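
# Illustrative sketch (standalone): the learning-rate schedule enabled above by
# variable_learning_rate=True. ExponentialDecay decays the rate by a factor of
# decay_rate every decay_steps optimizer steps (smoothly, since staircase defaults
# to False); printing it at a few step counts shows the effective rate the Adam
# optimizer would see.
import tensorflow as tf

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4, decay_steps=10000, decay_rate=0.9)

for step in (0, 10000, 50000):
    print(step, float(schedule(step)))        # 1e-4, 9e-05, ~5.9e-05

optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)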