def fit_pcnorm(modelspec, est: recording.Recording, metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               n_pcs=2,
               **context):
    '''
    Required Arguments:
        est          A recording object
        modelspec    A modelspec object

    Optional Arguments:
        (keyword defaults copied from fit_tf for now)

    Returns dictionary:
        {'modelspec': updated_modelspec}
    '''

    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    start_time = time.time()

    modelspec = copy.deepcopy(modelspec)

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    conditions = [
        "_".join(k.split("_")[1:]) for k in est.signals.keys()
        if k.startswith("mask_")
    ]
    if (len(conditions) > 2) and any(
            [c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")

    #conditions = conditions[0:2]
    #conditions = ['large','small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"Data subset for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est['pred0'].as_continuous()
    residual = resp - pred0

    pca = PCA(n_components=n_pcs)
    pca.fit(residual.T)
    pc_axes = pca.components_

    pcproj = residual.T.dot(pc_axes.T).T

    group_pc = [pcproj[:, idx].std(axis=1) for idx in group_idx]
    resp_std = resp.std(axis=1)
    #import pdb; pdb.set_trace()

    if metric is None:
        metric = lambda d: pc_err(d, pred_name='pred', pred0_name='pred0',
                                  group_idx=group_idx, group_pc=group_pc,
                                  pc_axes=pc_axes, resp_std=resp_std)

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translate to and from modelspecs.
    # The mapper has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker, modelspec=modelspec,
                      data=est, segmentor=segmentor,
                      evaluator=evaluator, metric=metric,
                      display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e", start_err, final_err, final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', 'ccnorm')
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    ms.set_modelspec_metadata(improved_modelspec, 'n_parms', len(improved_sigma))

    return {'modelspec': improved_modelspec.copy(), 'save_context': True}
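
# Illustrative sketch (hypothetical helper, not referenced elsewhere in NEMS): the
# residual-PCA computation that fit_pcnorm performs before handing off to pc_err.
# Argument names and shapes are assumptions chosen for the example; in fit_pcnorm
# the inputs come from the recording's 'resp' and 'pred0' signals and the
# 'mask_*' condition signals.
def _example_residual_pc_stats(resp, pred0, group_idx, n_pcs=2):
    """Return per-condition std of the residual projected onto its top PCs.

    resp, pred0 : arrays of shape (n_cells, n_timebins)
    group_idx   : list of boolean arrays, each of shape (n_timebins,)
    """
    from sklearn.decomposition import PCA

    residual = resp - pred0              # model error, cells x time
    pca = PCA(n_components=n_pcs)
    pca.fit(residual.T)                  # PCs computed across cells
    pc_axes = pca.components_            # (n_pcs, n_cells)
    pcproj = pc_axes @ residual          # residual projected onto PCs, (n_pcs, n_timebins)
    # std of the projection within each condition mask -- the "group_pc" target above
    return [pcproj[:, idx].std(axis=1) for idx in group_idx]
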
def fit_tf(modelspec, est: recording.Recording,
           use_modelspec_init: bool = True,
           optimizer: str = 'adam',
           max_iter: int = 10000,
           cost_function: str = 'squared_error',
           early_stopping_steps: int = 5,
           early_stopping_tolerance: float = 5e-4,
           learning_rate: float = 1e-4,
           batch_size: typing.Union[None, int] = None,
           seed: int = 0,
           initializer: str = 'random_normal',
           filepath: typing.Union[str, Path] = None,
           freeze_layers: typing.Union[None, list] = None,
           IsReload: bool = False,
           epoch_name: str = "REFERENCE",
           **context) -> dict:
    """TODO

    :param est:
    :param modelspec:
    :param use_modelspec_init:
    :param optimizer:
    :param max_iter:
    :param cost_function:
    :param early_stopping_steps:
    :param early_stopping_tolerance:
    :param learning_rate:
    :param batch_size:
    :param seed:
    :param filepath:
    :param freeze_layers: Indexes of layers to freeze prior to training. Indexes are modelspec indexes,
      so are offset from model layer indexes.
    :param IsReload:
    :param epoch_name:
    :param context:

    :return: dict {'modelspec': modelspec}
    """
    if IsReload:
        return {}

    tf.random.set_seed(seed)
    np.random.seed(seed)
    #os.environ['TF_DETERMINISTIC_OPS'] = '1'  # makes output deterministic, but reduces prediction accuracy

    log.info('Building tensorflow keras model from modelspec.')
    nems.utils.progress_fun()

    # figure out where to save model checkpoints
    if filepath is None:
        filepath = modelspec.meta['modelpath']

    # if job is running on slurm, need to change model checkpoint dir
    job_id = os.environ.get('SLURM_JOBID', None)
    if job_id is not None:
        # keep a record of the job id
        modelspec.meta['slurm_jobid'] = job_id

        log_dir_root = Path('/mnt/scratch')
        assert log_dir_root.exists()
        log_dir_sub = Path('SLURM_JOBID' + job_id) / str(modelspec.meta['batch'])\
                      / modelspec.meta.get('cellid', "NOCELL")\
                      / modelspec.meta['modelname']
        filepath = log_dir_root / log_dir_sub

    filepath = Path(filepath)
    if not filepath.exists():
        filepath.mkdir(exist_ok=True, parents=True)

    checkpoint_filepath = filepath / 'weights.hdf5'
    tensorboard_filepath = filepath / 'logs'
    gradient_filepath = filepath / 'gradients'

    # update seed based on fit index
    seed += modelspec.fit_index

    # need to get duration of stims in order to reshape data
    #epoch_name = 'REFERENCE'  # TODO: this should not be hardcoded
    # moved to input parameter

    input_name = modelspec.meta.get('input_name', 'stim')
    output_name = modelspec.meta.get('output_name', 'resp')

    # also grab the fs
    fs = est[input_name].fs

    if (epoch_name is not None) and (epoch_name != ""):
        # extract out the raw data, and reshape to (batch, time, channel)
        stim_train = np.transpose(
            est[input_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
        resp_train = np.transpose(
            est[output_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
    else:
        # extract data as a single batch size (1, time, channel)
        stim_train = np.transpose(
            est.apply_mask()[input_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])
        resp_train = np.transpose(
            est.apply_mask()[output_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])

    log.info(f'Feature dimensions: {stim_train.shape}; Data dimensions: {resp_train.shape}.')

    # get state if present, and setup training data
    if 'state' in est.signals:
        if (epoch_name is not None) and (epoch_name != ""):
            state_train = np.transpose(
                est['state'].extract_epoch(epoch=epoch_name, mask=est['mask']),
                [0, 2, 1])
        else:
            state_train = np.transpose(
                est.apply_mask()['state'].as_continuous()[np.newaxis, ...],
                [0, 2, 1])
        state_shape = state_train.shape
        log.info(f'State dimensions: {state_shape}')
        train_data = [stim_train, state_train]
    else:
        state_train, state_shape = None, None
        train_data = stim_train

    # correlation for monitoring
    # TODO: tf.utils?
    def pearson(y_true, y_pred):
        return tfp.stats.correlation(y_true, y_pred, event_axis=None, sample_axis=None)

    # get the layers and build the model
    cost_fn = loss_functions.get_loss_fn(cost_function)
    model_layers = modelspec.modelspec2tf2(
        use_modelspec_init=use_modelspec_init, seed=seed, fs=fs,
        initializer=initializer)
    if np.any([isinstance(layer, Conv2D_NEMS) for layer in model_layers]):
        # need a "channel" dimension for Conv2D (like rgb channels, not frequency). Only 1 channel for our data.
        stim_train = stim_train[..., np.newaxis]
        train_data = train_data[..., np.newaxis]

    # do some batch sizing logic
    batch_size = stim_train.shape[0] if batch_size == 0 else batch_size

    model = modelbuilder.ModelBuilder(
        name='Test-model',
        layers=model_layers,
        learning_rate=learning_rate,
        loss_fn=cost_fn,
        optimizer=optimizer,
        metrics=[pearson],
    ).build_model(input_shape=stim_train.shape, state_shape=state_shape, batch_size=batch_size)

    # tracking early termination
    model.early_terminated = False

    # create the callbacks
    early_stopping = callbacks.DelayedStopper(
        monitor='loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=False)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=str(checkpoint_filepath),
        save_best_only=False,
        save_weights_only=True,
        save_freq=100 * stim_train.shape[0],
        monitor='loss',
        verbose=0)
    sparse_logger = callbacks.SparseProgbarLogger(n_iters=10)
    nan_terminate = tf.keras.callbacks.TerminateOnNaN()
    nan_weight_terminate = callbacks.TerminateOnNaNWeights()
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=str(tensorboard_filepath),  # TODO: generic tensorboard dir?
        histogram_freq=0,  # record the distribution of the weights
        write_graph=False,
        update_freq='epoch',
        profile_batch=0)
    # gradient_logger = callbacks.GradientLogger(filepath=str(gradient_filepath),
    #                                            train_input=stim_train,
    #                                            model=model)

    # freeze layers
    if freeze_layers is not None:
        for freeze_index in freeze_layers:
            log.info(f'Freezing layer #{freeze_index}: "{model.layers[freeze_index + 1].name}".')
            model.layers[freeze_index + 1].trainable = False

    # save an initial set of weights before freezing, in case of termination before any checkpoints
    #log.info('saving weights to : %s', str(checkpoint_filepath))
    model.save_weights(str(checkpoint_filepath), overwrite=True)

    if version.parse(tf.__version__) >= version.parse("2.2.0"):
        callback0 = [sparse_logger]
        verbose = 0
    else:
        callback0 = []
        verbose = 2

    log.info(f'Fitting model (batch_size={batch_size})...')
    history = model.fit(
        train_data,
        resp_train,
        # validation_split=0.2,
        verbose=verbose,
        epochs=max_iter,
        batch_size=batch_size,
        callbacks=callback0 + [
            nan_terminate,
            nan_weight_terminate,
            early_stopping,
            checkpoint,
            # enable the below to log tracked parameters to tensorboard
            # tensorboard,
            # enable the below to record gradients to visualize in tensorboard; this is very slow,
            # and loading all this into tensorboard can use A LOT of memory
            # gradient_logger,
        ])

    # did we terminate on a nan loss or weights? Load checkpoint if so
    if np.all(np.isnan(model.predict(train_data))) or model.early_terminated:  # TODO: should this be np.any()?
        log.warning('Model terminated on nan loss or weights, restoring saved weights.')
        try:
            # this can fail if it nans out before a single checkpoint gets saved, either because no saved weights
            # exist, or it tries to load into a different model from the init
            model.load_weights(str(checkpoint_filepath))
            log.warning('Reloaded previous saved weights after nan loss.')
        except (tf.errors.NotFoundError, ValueError):
            pass

    modelspec = tf2modelspec(model, modelspec)

    contains_tf_only_layers = np.any(['tf_only' in m['fn'] for m in modelspec.modules])
    if not contains_tf_only_layers:
        # compare the predictions from the model and modelspec
        error = compare_ms_tf(modelspec, model, est, train_data)
        if error > 1e-5:
            log.warning(f'Mean difference between NEMS and TF model prediction: {error}')
        else:
            log.info(f'Mean difference between NEMS and TF model prediction: {error}')
    else:
        # nothing to compare, ms evaluation is not implemented for this type of model
        pass

    # add in some relevant meta information
    modelspec.meta['n_parms'] = len(modelspec.phi_vector)
    try:
        n_epochs = len(history.history['loss'])
    except KeyError:
        n_epochs = 0
    try:
        max_iter = modelspec.meta['extra_results']
        modelspec.meta['extra_results'] = max(max_iter, n_epochs)
    except KeyError:
        modelspec.meta['extra_results'] = n_epochs

    nems.utils.progress_fun()

    return {'modelspec': modelspec}
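
# Illustrative sketch (hypothetical helper, not called by fit_tf): how fit_tf reshapes
# a NEMS signal into the (batch, time, channel) array the keras model expects. The
# signal and epoch names are example assumptions; fit_tf reads them from
# modelspec.meta and the `epoch_name` argument.
def _example_signal_to_tf_batches(rec, signal_name='stim', epoch_name='REFERENCE'):
    import numpy as np

    if epoch_name:
        # one batch per epoch occurrence: (occurrence, channel, time) -> (batch, time, channel)
        folded = rec[signal_name].extract_epoch(epoch=epoch_name, mask=rec['mask'])
        return np.transpose(folded, [0, 2, 1])
    # otherwise a single batch spanning the whole (masked) recording: (1, time, channel)
    continuous = rec.apply_mask()[signal_name].as_continuous()[np.newaxis, ...]
    return np.transpose(continuous, [0, 2, 1])
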
def fit_ccnorm(modelspec, est: recording.Recording, metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               shrink_cc: float = 0,
               noise_pcs: int = 0,
               shared_pcs: int = 0,
               also_fit_resp: bool = False,
               force_psth: bool = False,
               use_metric: typing.Union[None, str] = None,
               alpha: float = 0.1,
               beta: float = 1,
               exclude_idx=None, exclude_after=None,
               freeze_idx=None, freeze_after=None,
               **context):
    '''
    Required Arguments:
        est          A recording object
        modelspec    A modelspec object

    Optional Arguments:
        (keyword defaults copied from fit_tf for now)

    Returns dictionary:
        {'modelspec': updated_modelspec}
    '''

    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    start_time = time.time()

    fit_index = modelspec.fit_index
    if (exclude_idx is not None) | (freeze_idx is not None) | \
            (exclude_after is not None) | (freeze_after is not None):
        modelspec0 = modelspec.copy()
        modelspec, include_set = modelspec_freeze_layers(
            modelspec, include_idx=None,
            exclude_idx=exclude_idx, exclude_after=exclude_after,
            freeze_idx=freeze_idx, freeze_after=freeze_after)
        modelspec0.set_fit(fit_index)
        modelspec.set_fit(fit_index)
    else:
        include_set = None

    # Computing PCs before masking out unwanted stimuli in order to
    # preserve match with epochs
    epoch_regex = "^STIM_"
    stims = (est.epochs['name'].value_counts() >= 8)
    stims = [
        stims.index[i] for i, s in enumerate(stims)
        if bool(re.search(epoch_regex, stims.index[i])) and s == True
    ]

    Rall_u = est.apply_mask()['psth'].as_continuous().T
    # can't simply extract evoked for refs because can be longer/shorter if it came after target
    # and/or if it was the last stim. So, masking prestim/poststim doesn't work. Do it manually.
    #d = est['resp'].extract_epochs(stims, mask=est['mask'])
    #R = [v.mean(axis=0) for (k, v) in d.items()]
    #R = [np.reshape(np.transpose(v,[1,0,2]),[v.shape[1],-1]) for (k, v) in d.items()]
    #Rall_u = np.hstack(R).T

    pca = PCA(n_components=2)
    pca.fit(Rall_u)
    pc_axes = pca.components_

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask_small' in est.signals.keys():
        log.info('resetting mask with mask_small+mask_large subset')
        est['mask'] = est['mask']._modified_copy(data=est['mask_small']._data + est['mask_large']._data)

    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    # if we want to fit to first-order cc error,
    # uncomment this and make sure sdexp is generating a pred0 signal
    est = modelspec.evaluate(est, stop=2)
    if ('pred0' in est.signals.keys()) & (not force_psth):
        input_name = 'pred0'
        log.info('Found pred0 for fitting CC')
    else:
        input_name = 'psth'
        log.info('No pred0, using psth for fitting CC')

    conditions = [
        "_".join(k.split("_")[1:]) for k in est.signals.keys()
        if k.startswith("mask_")
    ]
    if (len(conditions) > 2) and any(
            [c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")

    #conditions = conditions[0:2]
    #conditions = ['large','small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"cc data for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est[input_name].as_continuous()
    #import pdb; pdb.set_trace()

    if shrink_cc > 0:
        log.info(f'cc approx: shrink_cc={shrink_cc}')
        group_cc = [cc_shrink(resp[:, idx] - pred0[:, idx], sigrat=shrink_cc)
                    for idx in group_idx]
    elif shared_pcs > 0:
        log.info(f'cc approx: shared_pcs={shared_pcs}')
        cc = np.cov(resp - pred0)
        u, s, vh = np.linalg.svd(cc)
        U = u[:, :shared_pcs] @ u[:, :shared_pcs].T

        group_cc = [cc_shared_space(resp[:, idx] - pred0[:, idx], U)
                    for idx in group_idx]
    elif noise_pcs > 0:
        log.info(f'cc approx: noise_pcs={noise_pcs}')
        group_cc = [cc_lowrank(resp[:, idx] - pred0[:, idx], n_pcs=noise_pcs)
                    for idx in group_idx]
    else:
        group_cc = [np.cov(resp[:, idx] - pred0[:, idx]) for idx in group_idx]
    group_cc_raw = [np.cov(resp[:, idx] - pred0[:, idx]) for idx in group_idx]

    # variance of projection onto PCs (PCs computed above before masking)
    pcproj0 = (resp - pred0).T.dot(pc_axes.T).T
    pcproj_std = pcproj0.std(axis=1)

    if (use_metric == 'cc_err_w'):
        def metric(d, verbose=False):
            return metrics.cc_err_w(d, pred_name='pred', pred0_name=input_name,
                                    group_idx=group_idx, group_cc=group_cc,
                                    alpha=alpha,
                                    pcproj_std=None, pc_axes=None,
                                    verbose=verbose)
        log.info(f"fit_ccnorm metric: cc_err_w (alpha={alpha})")

    elif (metric is None) and also_fit_resp:
        log.info(f"resp_cc_err: pred0_name: {input_name} beta: {beta}")
        metric = lambda d: metrics.resp_cc_err(d, pred_name='pred', pred0_name=input_name,
                                               group_idx=group_idx, group_cc=group_cc,
                                               beta=beta,
                                               pcproj_std=None, pc_axes=None)

    elif (use_metric == 'cc_err_md'):
        def metric(d, verbose=False):
            return metrics.cc_err_md(d, pred_name='pred', pred0_name=input_name,
                                     group_idx=group_idx, group_cc=group_cc,
                                     pcproj_std=None, pc_axes=None)
        log.info(f"fit_ccnorm metric: cc_err_md")

    elif (metric is None):
        #def cc_err(result, pred_name='pred_lv', resp_name='resp', pred0_name='pred',
        #           group_idx=None, group_cc=None, pcproj_std=None, pc_axes=None):
        # current implementation of cc_err
        metric = lambda d: metrics.cc_err(d, pred_name='pred', pred0_name=input_name,
                                          group_idx=group_idx, group_cc=group_cc,
                                          pcproj_std=None, pc_axes=None)
        log.info(f"fit_ccnorm metric: cc_err")

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translate to and from modelspecs.
    # The mapper has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker, modelspec=modelspec,
                      data=est, segmentor=segmentor,
                      evaluator=evaluator, metric=metric,
                      display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e", start_err, final_err, final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)

    if include_set is not None:
        # pull out updated phi values from improved_modelspec, include_set only
        improved_modelspec = \
            modelspec_unfreeze_layers(improved_modelspec, modelspec0, include_set)
        improved_modelspec.set_fit(fit_index)

    log.info(f"Updating improved modelspec with fit_idx={improved_modelspec.fit_index}")
    improved_modelspec.meta['fitter'] = 'ccnorm'
    improved_modelspec.meta['n_parms'] = len(improved_sigma)
    if modelspec.fit_count == 1:
        improved_modelspec.meta['fit_time'] = elapsed_time
        improved_modelspec.meta['loss'] = final_err
    else:
        if modelspec.fit_index == 0:
            improved_modelspec.meta['fit_time'] = np.zeros(improved_modelspec.fit_count)
            improved_modelspec.meta['loss'] = np.zeros(improved_modelspec.fit_count)
        improved_modelspec.meta['fit_time'][fit_index] = elapsed_time
        improved_modelspec.meta['loss'][fit_index] = final_err

    return {'modelspec': improved_modelspec}
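
# Illustrative sketch (hypothetical helper, not called by fit_ccnorm): the per-condition
# noise-covariance targets that fit_ccnorm asks the model to reproduce. This mirrors the
# default branch above (no shrinkage, no low-rank or shared-space approximation);
# argument names and shapes are assumptions for the example.
def _example_group_noise_cov(resp, pred0, group_idx):
    """resp, pred0: (n_cells, n_timebins) arrays; group_idx: list of boolean masks."""
    import numpy as np

    residual = resp - pred0
    # one n_cells x n_cells noise covariance matrix per stimulus/condition subset
    return [np.cov(residual[:, idx]) for idx in group_idx]
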
def fit_tf(modelspec, est: recording.Recording,
           use_modelspec_init: bool = True,
           optimizer: str = 'adam',
           max_iter: int = 10000,
           cost_function: str = 'squared_error',
           early_stopping_steps: int = 5,
           early_stopping_tolerance: float = 5e-4,
           early_stopping_val_split: float = 0,
           learning_rate: float = 1e-4,
           variable_learning_rate: bool = False,
           batch_size: typing.Union[None, int] = None,
           seed: int = 0,
           initializer: str = 'random_normal',
           filepath: typing.Union[str, Path] = None,
           freeze_layers: typing.Union[None, list] = None,
           IsReload: bool = False,
           epoch_name: str = "REFERENCE",
           use_tensorboard: bool = False,
           kernel_regularizer: str = None,
           **context) -> dict:
    """TODO

    :param est:
    :param modelspec:
    :param use_modelspec_init:
    :param optimizer:
    :param max_iter:
    :param cost_function:
    :param early_stopping_steps:
    :param early_stopping_tolerance:
    :param learning_rate:
    :param batch_size:
    :param seed:
    :param filepath:
    :param freeze_layers: Indexes of layers to freeze prior to training. Indexes are modelspec indexes,
      so are offset from model layer indexes.
    :param IsReload:
    :param epoch_name:
    :param context:

    :return: dict {'modelspec': modelspec}
    """
    if IsReload:
        return {}

    tf.random.set_seed(seed)
    np.random.seed(seed)
    #os.environ['TF_DETERMINISTIC_OPS'] = '1'  # makes output deterministic, but reduces prediction accuracy

    log.info('Building tensorflow keras model from modelspec.')
    nems.utils.progress_fun()

    # figure out where to save model checkpoints
    job_id = os.environ.get('SLURM_JOBID', None)
    if job_id is not None:
        # if job is running on slurm, need to change model checkpoint dir
        # keep a record of the job id
        modelspec.meta['slurm_jobid'] = job_id

        log_dir_root = Path('/mnt/scratch')
        assert log_dir_root.exists()
        log_dir_base = log_dir_root / Path('SLURM_JOBID' + job_id)
        log_dir_sub = Path(str(modelspec.meta['batch'])) \
                      / modelspec.meta.get('cellid', "NOCELL") \
                      / modelspec.get_longname()
        filepath = log_dir_base / log_dir_sub
        tbroot = filepath / 'logs'
    elif filepath is None:
        filepath = modelspec.meta['modelpath']
        tbroot = Path(f'/auto/data/tmp/tensorboard/')
    else:
        tbroot = Path(f'/auto/data/tmp/tensorboard/')

    filepath = Path(filepath)
    if not filepath.exists():
        filepath.mkdir(exist_ok=True, parents=True)
    cellid = modelspec.meta.get('cellid', 'CELL')
    tbpath = tbroot / (str(modelspec.meta['batch']) + '_' + cellid + '_' + modelspec.meta['modelname'])

    # TODO: should this code just be deleted then?
    if 0 & use_tensorboard:
        # disabled, this is dumb. it deletes the previous round of fitting (eg, tfinit)
        fileList = glob.glob(str(tbpath / '*' / '*'))
        for filePath in fileList:
            try:
                os.remove(filePath)
            except:
                print("Error while deleting file : ", filePath)

    checkpoint_filepath = filepath / 'weights.hdf5'
    tensorboard_filepath = tbpath
    gradient_filepath = filepath / 'gradients'

    # update seed based on fit index
    seed += modelspec.fit_index

    if (freeze_layers is not None) and len(freeze_layers) and (len(freeze_layers) == freeze_layers[-1] + 1):
        truncate_model = True
        modelspec_trunc, est_trunc = \
            initializers.modelspec_remove_input_layers(modelspec, est, remove_count=len(freeze_layers))
        modelspec_original = modelspec
        est_original = est
        modelspec = modelspec_trunc
        est = est_trunc
        freeze_layers = None
        log.info(f"Special case of freezing: truncating model."
                 f" fit_index={modelspec.fit_index} cell_index={modelspec.cell_index}")
    else:
        truncate_model = False

    input_name = modelspec.meta.get('input_name', 'stim')
    output_name = modelspec.meta.get('output_name', 'resp')

    # also grab the fs
    fs = est[input_name].fs

    if (epoch_name is not None) and (epoch_name != ""):
        # extract out the raw data, and reshape to (batch, time, channel)
        stim_train = np.transpose(
            est[input_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
        resp_train = np.transpose(
            est[output_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
    else:
        # extract data as a single batch size (1, time, channel)
        stim_train = np.transpose(
            est.apply_mask()[input_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])
        resp_train = np.transpose(
            est.apply_mask()[output_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])

    log.info(f'Feature dimensions: {stim_train.shape}; Data dimensions: {resp_train.shape}.')

    if True:
        log.info("adding a tiny bit of noise to resp_train")
        resp_train = resp_train + np.random.randn(*resp_train.shape) / 10000

    # get state if present, and setup training data
    if 'state' in est.signals:
        if (epoch_name is not None) and (epoch_name != ""):
            state_train = np.transpose(
                est['state'].extract_epoch(epoch=epoch_name, mask=est['mask']),
                [0, 2, 1])
        else:
            state_train = np.transpose(
                est.apply_mask()['state'].as_continuous()[np.newaxis, ...],
                [0, 2, 1])
        state_shape = state_train.shape
        log.info(f'State dimensions: {state_shape}')
        train_data = [stim_train, state_train]
    else:
        state_train, state_shape = None, None
        train_data = stim_train

    # get the layers and build the model
    cost_fn = loss_functions.get_loss_fn(cost_function)
    #model_layers = modelspec.modelspec2tf2(
    #    use_modelspec_init=use_modelspec_init, seed=seed, fs=fs,
    #    initializer=initializer, freeze_layers=freeze_layers,
    #    kernel_regularizer=kernel_regularizer)
    model_layers = modelbuilder.modelspec2tf(
        modelspec, use_modelspec_init=use_modelspec_init, seed=seed, fs=fs,
        initializer=initializer, freeze_layers=freeze_layers,
        kernel_regularizer=kernel_regularizer)

    if np.any([isinstance(layer, Conv2D_NEMS) for layer in model_layers]):
        # need a "channel" dimension for Conv2D (like rgb channels, not frequency). Only 1 channel for our data.
        stim_train = stim_train[..., np.newaxis]
        train_data = train_data[..., np.newaxis]

    # do some batch sizing logic
    batch_size = stim_train.shape[0] if batch_size == 0 else batch_size

    if variable_learning_rate:
        # TODO: allow other schedule options instead of hard-coding exp decay?
        # TODO: expose exp decay kwargs as kw options? not clear how to choose these parameters
        learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=learning_rate, decay_steps=10000, decay_rate=0.9)

    from nems.tf.loss_functions import pearson
    model = modelbuilder.ModelBuilder(
        name='Test-model',
        layers=model_layers,
        learning_rate=learning_rate,
        loss_fn=cost_fn,
        optimizer=optimizer,
        metrics=[pearson],
    ).build_model(input_shape=stim_train.shape, state_shape=state_shape, batch_size=batch_size)

    if freeze_layers is not None:
        for freeze_index in freeze_layers:
            log.info(f'TF layer #{freeze_index}: "{model.layers[freeze_index + 1].name}" is not trainable.')

    # tracking early termination
    model.early_terminated = False

    # create the callbacks
    early_stopping = callbacks.DelayedStopper(
        monitor='val_loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=True)
    regular_stopping = callbacks.DelayedStopper(
        monitor='loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=True)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=str(checkpoint_filepath),
        save_best_only=False,
        save_weights_only=True,
        save_freq=100 * stim_train.shape[0],
        monitor='loss',
        verbose=0)
    sparse_logger = callbacks.SparseProgbarLogger(n_iters=50)
    nan_terminate = tf.keras.callbacks.TerminateOnNaN()
    nan_weight_terminate = callbacks.TerminateOnNaNWeights()
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=str(tensorboard_filepath),  # TODO: generic tensorboard dir?
        histogram_freq=0,  # record the distribution of the weights
        write_graph=False,
        update_freq='epoch',
        profile_batch=0)
    # gradient_logger = callbacks.GradientLogger(filepath=str(gradient_filepath),
    #                                            train_input=stim_train,
    #                                            model=model)

    # save an initial set of weights before freezing, in case of termination before any checkpoints
    #log.info('saving weights to : %s', str(checkpoint_filepath))
    model.save_weights(str(checkpoint_filepath), overwrite=True)

    if version.parse(tf.__version__) >= version.parse("2.2.0"):
        callback0 = [sparse_logger]
        verbose = 0
    else:
        callback0 = []
        verbose = 2

    # enable the below to log tracked parameters to tensorboard
    if use_tensorboard:
        callback0.append(tensorboard)
        log.info(f'Enabling tensorboard, log: {str(tensorboard_filepath)}')
    # enable the below to record gradients to visualize in tensorboard; this is very slow,
    # and loading all this into tensorboard can use A LOT of memory
    # callback0.append(gradient_logger)

    if early_stopping_val_split > 0:
        callback0.append(early_stopping)
        log.info(f'Enabling early stopping, val split: {str(early_stopping_val_split)}')
    else:
        callback0.append(regular_stopping)
        log.info(f'Stop tolerance: min_delta={early_stopping_tolerance}')

    log.info(f'Fitting model (batch_size={batch_size})...')
    history = model.fit(
        train_data,
        resp_train,
        validation_split=early_stopping_val_split,
        verbose=verbose,
        epochs=max_iter,
        callbacks=callback0 + [
            nan_terminate,
            nan_weight_terminate,
            checkpoint,
        ],
        batch_size=batch_size)

    # did we terminate on a nan loss or weights? Load checkpoint if so
    if np.all(np.isnan(model.predict(train_data))) or model.early_terminated:  # TODO: should this be np.any()?
        log.warning('Model terminated on nan loss or weights, restoring saved weights.')
        try:
            # this can fail if it nans out before a single checkpoint gets saved, either because no saved weights
            # exist, or it tries to load into a different model from the init
            model.load_weights(str(checkpoint_filepath))
            log.warning('Reloaded previous saved weights after nan loss.')
        except (tf.errors.NotFoundError, ValueError):
            pass

    modelspec = tf2modelspec(model, modelspec)

    if truncate_model:
        log.info("Special case of freezing: restoring truncated model!!!")
        #modelspec_restored, rec_restored = modelspec_restore_input_layers(modelspec_trunc, rec_trunc, modelspec_original)
        modelspec_restored, est_restored = initializers.modelspec_restore_input_layers(
            modelspec, est, modelspec_original)
        est = est_original
        modelspec = modelspec_restored

    # debug: dump modelspec parameters
    #for i in range(len(modelspec)):
    #    log.info(modelspec.phi[i])

    contains_tf_only_layers = np.any(['tf_only' in m['fn'] for m in modelspec.modules])
    if not contains_tf_only_layers:
        # compare the predictions from the model and modelspec
        error = compare_ms_tf(modelspec, model, est, train_data)
        if error > 1e-5:
            log.warning(f'Mean difference between NEMS and TF model prediction: {error}')
        else:
            log.info(f'Mean difference between NEMS and TF model prediction: {error}')
    else:
        # nothing to compare, ms evaluation is not implemented for this type of model
        pass

    # add in some relevant meta information
    modelspec.meta['n_parms'] = len(modelspec.phi_vector)
    try:
        n_epochs = len(history.history['loss'])
        if 'val_loss' in history.history.keys():
            #val_stop = np.argmin(history.history['val_loss'])
            #loss = history.history['loss'][val_stop]
            loss = np.nanmin(history.history['val_loss'])
        else:
            loss = np.nanmin(history.history['loss'])
    except KeyError:
        n_epochs = 0
        loss = 0
    if modelspec.fit_count == 1:
        modelspec.meta['n_epochs'] = n_epochs
        modelspec.meta['loss'] = loss
    else:
        if modelspec.fit_index == 0:
            modelspec.meta['n_epochs'] = np.zeros(modelspec.fit_count)
            modelspec.meta['loss'] = np.zeros(modelspec.fit_count)
        modelspec.meta['n_epochs'][modelspec.fit_index] = n_epochs
        modelspec.meta['loss'][modelspec.fit_index] = loss

    try:
        max_iter = modelspec.meta['extra_results']
        modelspec.meta['extra_results'] = max(max_iter, n_epochs)
    except KeyError:
        modelspec.meta['extra_results'] = n_epochs

    nems.utils.progress_fun()

    # clean up temp files
    if job_id is not None:
        log.info('removing temporary weights file(s)')
        shutil.rmtree(log_dir_base)

    return {'modelspec': modelspec}
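
# Illustrative sketch (hypothetical helper, simplified from the bookkeeping at the end of
# fit_tf and fit_ccnorm): with a single fit, meta values are stored as scalars; with
# multiple fit indexes they are stored in arrays indexed by fit_index. The real code keys
# the array allocation off fit_index == 0 rather than the key check used here.
def _example_store_fit_meta(meta, key, value, fit_index, fit_count):
    import numpy as np

    if fit_count == 1:
        meta[key] = value
    else:
        if key not in meta:
            # allocate one slot per fit the first time this key is stored
            meta[key] = np.zeros(fit_count)
        meta[key][fit_index] = value
    return meta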