def test_routine_resnet(config):
    """Short fine-tuning run of ``ResNet50_o2`` through ``mincorig_finetune``.

    Runs 2 fine-tuning epochs followed by 50 further epochs, logging under the
    title ``'test_minc_orig_cov'``.

    NOTE(review): another ``def test_routine_resnet`` appears later in this
    module and rebinds the name, so this version is shadowed at import time.

    Parameters
    ----------
    config
        Experiment configuration, forwarded unchanged to ``run_finetune``.

    Returns
    -------
    None
    """
    # Data augmentation: same generator settings used by the sibling routines
    # (TARGET_SIZE, RESCALE_SMALL, horizontal_flip=True).
    augmenter = ImageDataGeneratorAdvanced(
        TARGET_SIZE, RESCALE_SMALL, True,
        horizontal_flip=True,
    )
    schedule = dict(nb_epoch_finetune=2, nb_epoch_after=50)
    run_finetune(
        ResNet50_o2, mincorig_finetune, input_shape, config, augmenter,
        nb_classes=nb_classes,
        title='test_minc_orig_cov',
        verbose=(2, 1),
        **schedule
    )
def run_residual_cov_resnet(exp):
    """Fine-tune ``ResNet50_o2`` through ``sun_finetune`` for a residual-cov experiment.

    The configuration is obtained from ``get_residual_cov_experiment(exp)``.
    Training runs 50 fine-tuning epochs followed by 50 further epochs and is
    logged under the title ``'sun_residual_cov'``.

    Parameters
    ----------
    exp
        Experiment identifier passed to ``get_residual_cov_experiment``.

    Returns
    -------
    None
    """
    epochs = dict(nb_epoch_finetune=50, nb_epoch_after=50)
    # Resolve the experiment config before building the data generator, as the
    # original routine did.
    experiment_config = get_residual_cov_experiment(exp)
    augmenter = ImageDataGeneratorAdvanced(
        TARGET_SIZE, RESCALE_SMALL, True,
        horizontal_flip=True,
    )
    run_finetune(
        ResNet50_o2, sun_finetune, input_shape, experiment_config, augmenter,
        nb_classes=nb_classes,
        title='sun_residual_cov',
        verbose=(2, 1),
        **epochs
    )
def test_routine_resnet(config, verbose=(1, 2), nb_epoch_finetune=1, nb_epoch_after=50):
    """Quick fine-tuning run of ``ResNet50_o2`` through ``dtd_finetune``.

    Results are logged under the title ``'test_dtd_resnet'``.

    NOTE(review): this definition rebinds an earlier ``test_routine_resnet``
    in the same module; only this version is reachable after import.

    Parameters
    ----------
    config
        Experiment configuration, forwarded unchanged to ``run_finetune``.
    verbose : tuple
        Verbosity setting forwarded to ``run_finetune``.
    nb_epoch_finetune : int
        Number of fine-tuning epochs.
    nb_epoch_after : int
        Number of epochs run after the fine-tuning phase.

    Returns
    -------
    None
    """
    training_options = {
        'nb_classes': nb_classes,
        'input_shape': input_shape,
        'config': config,
        'nb_epoch_finetune': nb_epoch_finetune,
        'nb_epoch_after': nb_epoch_after,
        'image_gen': ImageDataGeneratorAdvanced(
            TARGET_SIZE, RESCALE_SMALL, True, horizontal_flip=True),
        'title': 'test_dtd_resnet',
        'verbose': verbose,
    }
    run_finetune(ResNet50_o2, dtd_finetune, **training_options)
def run_routine_resnet(config, verbose=(2, 2), nb_epoch_finetune=15, nb_epoch_after=50,
                       stiefel_observed=None, stiefel_lr=0.01):
    """Fine-tune ``ResNet50_o2`` through ``mincorig_finetune``.

    With ``stiefel_observed`` left as ``None`` the plain ``run_finetune`` routine
    is used (title ``'minc_orig_resnet50'``). Otherwise,
    ``run_finetune_with_Stiefel_layer`` is used (title
    ``'minc_orig_resnet50_stiefel'``), receiving ``stiefel_observed`` and
    ``stiefel_lr``.

    Parameters
    ----------
    config
        Experiment configuration, forwarded unchanged.
    verbose : tuple
        Verbosity setting forwarded to the training routine.
    nb_epoch_finetune : int
        Number of fine-tuning epochs.
    nb_epoch_after : int
        Number of epochs run after the fine-tuning phase.
    stiefel_observed
        Passed as ``observed_keywords`` to the Stiefel routine. ``None`` selects
        the plain routine.
    stiefel_lr : float
        Passed as ``lr`` to the Stiefel routine.

    Returns
    -------
    None
    """
    image_gen = ImageDataGeneratorAdvanced(TARGET_SIZE, RESCALE_SMALL, True,
                                           horizontal_flip=True)
    shared = dict(
        nb_classes=nb_classes,
        input_shape=input_shape,
        config=config,
        nb_epoch_finetune=nb_epoch_finetune,
        nb_epoch_after=nb_epoch_after,
        image_gen=image_gen,
        verbose=verbose,
        # Layer classes to monitor and the measure recorded for them.
        monitor_classes=(O2Transform, SecondaryStatistic),
        monitor_measures=['matrix_image'],
    )
    if stiefel_observed is not None:
        run_finetune_with_Stiefel_layer(
            ResNet50_o2, mincorig_finetune,
            title='minc_orig_resnet50_stiefel',
            observed_keywords=stiefel_observed,
            lr=stiefel_lr,
            **shared)
        return
    run_finetune(ResNet50_o2, mincorig_finetune,
                 title='minc_orig_resnet50', **shared)
def run_model_with_config(model, config, title='cifar10', image_gen=None,
                          verbose=(2, 2), nb_epoch_finetune=15, nb_epoch_after=50,
                          stiefel_observed=None, stiefel_lr=0.01,
                          weight_norm=False, lr_decay=False):
    """Fine-tune ``model`` through ``cifar_train`` using exactly one training routine.

    The routine is selected as follows:

    * ``stiefel_observed`` is not ``None`` -> ``run_finetune_with_Stiefel_layer``,
      titled ``title + "-stiefel"``;
    * otherwise ``weight_norm`` is truthy -> ``run_finetune_with_weight_norm``,
      titled ``title + '-weight_norm'``;
    * otherwise -> ``run_finetune``, titled ``title``.

    Previously, ``weight_norm=True`` ran the weight-norm routine and then also
    the plain ``run_finetune``, training the model twice.

    Parameters
    ----------
    model
        Passed unchanged as the first argument of the chosen routine
        (elsewhere in this module a model builder such as ``ResNet50_o2`` is
        used here).
    config
        Experiment configuration, forwarded unchanged.
    title : str
        Base run title; a routine-specific suffix may be appended.
    image_gen
        Data generator forwarded as ``image_gen``.
    verbose : tuple
        Verbosity setting forwarded to the routine.
    nb_epoch_finetune : int
        Number of fine-tuning epochs.
    nb_epoch_after : int
        Number of epochs run after the fine-tuning phase.
    stiefel_observed
        Passed as ``observed_keywords`` to the Stiefel routine. ``None``
        disables that routine.
    stiefel_lr : float
        Passed as ``lr`` to the Stiefel routine.
    weight_norm : bool
        Select the weight-norm routine when no Stiefel layer is requested.
    lr_decay : bool
        Forwarded as ``lr_decay`` to whichever routine runs.

    Returns
    -------
    None
    """
    # Arguments common to all three training routines. Other monitor measures
    # used during experimentation were 'weight_norm' and 'output_norm'.
    shared = dict(
        nb_classes=nb_classes,
        input_shape=input_shape,
        config=config,
        nb_epoch_finetune=nb_epoch_finetune,
        nb_epoch_after=nb_epoch_after,
        image_gen=image_gen,
        verbose=verbose,
        monitor_classes=(O2Transform, SecondaryStatistic),
        monitor_measures=['matrix_image'],
        lr_decay=lr_decay,
    )
    if stiefel_observed is not None:
        run_finetune_with_Stiefel_layer(
            model, cifar_train,
            title=title + "-stiefel",
            observed_keywords=stiefel_observed,
            lr=stiefel_lr,
            **shared)
    elif weight_norm:
        run_finetune_with_weight_norm(
            model, cifar_train, title=title + '-weight_norm', **shared)
    else:
        run_finetune(model, cifar_train, title=title, **shared)