def sacred_experiment_with_config(config, name, main_fcn, db_name, base_dir, checkpoint_dir, sources=None, tune_config=None):
    """Create and run a sacred Experiment wrapping ``main_fcn``.

    Args:
        config: base configuration dict forwarded to ``main_fcn``.
        name: experiment name.
        main_fcn: callable invoked as ``main_fcn(config=..., checkpoint_dir=..., ex=...)``.
        db_name: MongoDB database name for the MongoObserver.
        base_dir: sacred base directory.
        checkpoint_dir: checkpoint directory forwarded to ``main_fcn``.
        sources: optional list of source file paths (without ``.py``) to attach.
        tune_config: optional tuning overrides merged on top of ``config``.

    Returns:
        The sacred Run object produced by ``ex.run()``.
    """
    # Avoid mutable default arguments (a list/dict default is shared
    # across calls and silently accumulates mutations).
    sources = [] if sources is None else sources
    tune_config = {} if tune_config is None else tune_config

    # creating a sacred experiment
    # https://github.com/IDSIA/sacred/issues/492
    from sacred import Experiment, SETTINGS
    SETTINGS.CONFIG.READ_ONLY_CONFIG = False

    ex = Experiment(name, base_dir=base_dir)
    ex.observers.append(MongoObserver(db_name=db_name))

    # sacred only accepts .py files as sources; copy each listed source
    # to a .py twin before attaching it.
    for f in sources:
        if isinstance(f, str):
            f_py = f + '.py'
            shutil.copy(f, f_py)
            ex.add_source_file(f_py)

    # Export the merged base + tuning configuration. Previously the merged
    # dict was built but never used — only tune_config was exported.
    export_config = dict(config)
    export_config.update(tune_config)
    ex.add_config(config=export_config, **export_config)

    @ex.main
    def run_train():
        return main_fcn(config=config, checkpoint_dir=checkpoint_dir, ex=ex)

    return ex.run()
示例#2
0
class SacredTrainer(BaseTrainer):
    """Trainer that optionally records runs via sacred (console capture,
    source archiving, and an optional MongoDB observer).

    Sacred is enabled only when ``args.loggers.sacred.use_sacred`` is set.
    """

    # TODO: this is old, I think it shouldn't be used anymore, Neptune is better suited
    def __init__(self, *args, **kwargs):
        """Initialize the base trainer, then wire up sacred if configured."""
        super().__init__(*args, **kwargs)
        args = self.args

        args_sacred = get_maybe_missing_args(args.loggers, 'sacred')
        if args_sacred is None:
            self.use_sacred = False
        else:
            self.use_sacred = args_sacred.use_sacred

        if self.use_sacred:
            self.sacred_exp = Experiment(args.exp_name)
            self.sacred_exp.captured_out_filter = apply_backspaces_and_linefeeds
            self.sacred_exp.add_config(vars(args))
            for source in self.get_sources():
                self.sacred_exp.add_source_file(source)

            # NOTE(review): attribute paths below are inconsistent —
            # `args_sacred.sacred.mongodb_disable` vs the direct
            # `args_sacred.mongodb_name`, and the url is formatted from
            # `args` rather than `args_sacred`. Verify against the actual
            # config schema before relying on this path.
            if not args_sacred.sacred.mongodb_disable:
                url = "{0.mongodb_url}:{0.mongodb_port}".format(args)
                if (args_sacred.mongodb_name is not None
                        and args_sacred.mongodb_name != ''):
                    db_name = args_sacred.mongodb_name
                else:
                    # Default DB name: prefix + alphanumeric-only dataset name.
                    db_name = args_sacred.mongodb_prefix + ''.join(
                        filter(str.isalnum, args.dataset_name.lower()))

                self.console_log.info('Connect to MongoDB@{}:{}'.format(
                    url, db_name))
                self.sacred_exp.observers.append(
                    MongoObserver.create(url=url, db_name=db_name))

    def log_sacred_scalar(self, name, val, step):
        """Forward a scalar metric to the active sacred run, if any."""
        if self.use_sacred and self.sacred_exp.current_run:
            self.sacred_exp.current_run.log_scalar(name, val, step)

    def get_sources(self):
        """Collect source files worth archiving with the run."""
        sources = []
        # The network file
        sources.append(inspect.getfile(self.model.__class__))
        # the main script
        sources.append(sys.argv[0])
        # and any user custom submodule (skip installed packages)
        for module in self.model.children():
            module_path = inspect.getfile(module.__class__)
            if 'site-packages' not in module_path:
                sources.append(module_path)
        return sources

    @main_ifsacred
    def fit(self):
        """Run the base fit loop (wrapped as the sacred main when enabled)."""
        super().fit()

    def json_results(self, savedir, testscore):
        """Write results via the base class and attach the JSON to sacred."""
        super().json_results(savedir, testscore)
        json_path = os.path.join(savedir, "results.json")
        if self.use_sacred and self.sacred_exp.current_run:
            # BUG FIX: was `current_run.current_run.add_artifact(...)` —
            # the doubled attribute access raised AttributeError.
            self.sacred_exp.current_run.add_artifact(json_path)
示例#3
0
def add_package_sources(ex: Experiment):
    """Attach every .py file in this module's directory and its parent to *ex*."""
    here = os.path.dirname(__file__)
    for folder in (here, os.path.join(here, "..")):
        for entry in os.listdir(folder):
            if entry.endswith(".py"):
                ex.add_source_file(os.path.abspath(os.path.join(folder, entry)))
示例#4
0
文件: sacred.py 项目: yuan-yin/UNISST
def sacred_run(command, default_configs_root='configs/default'):
    """Run *command* as a sacred experiment with per-component default configs.

    Args:
        command: the main function to register and run via sacred.
        default_configs_root: directory holding one ``<component>.yaml``
            default file per top-level config component.
    """
    ex = Experiment('default')
    # Attach every project source file under ./src so runs are reproducible.
    for src_file in glob.glob('./src/**/*.py', recursive=True):
        ex.add_source_file(src_file)

    @ex.config_hook
    def default_config(config, command_name, logger):
        # Build one defaults entry per top-level component. Only the
        # component names are needed here (values are resolved inside
        # get_component_configs), so iterate keys instead of items —
        # the previous `for comp, conf in config.items()` never used `conf`.
        defaults = {}
        for comp in config:
            default_file_path = os.path.join(default_configs_root, f'{comp}.yaml')
            defaults[comp] = get_component_configs(config, comp, default_file_path)
        return defaults

    ex.main(command)
    ex.run_commandline()
示例#5
0
from sacred import Experiment
# https://sacred.readthedocs.io/en/latest/configuration.html#prefi
#from sacred.observers import MongoObserver
import os

# Sacred experiment for ISBI2012 U-Net training.
ex = Experiment('ISBI2012 U-Net')
# Archive the project sources so each run records the exact code used.
ex.add_source_file("main.py")
ex.add_source_file("ISBI2012Data.py")
ex.add_source_file("model.py")
ex.add_source_file("dataaug.py")

# Base configuration values come from a JSON file.
ex.add_config('config.json')


# if the directory already exists, add a number to it.
# therefore dont overwrite old stuff
@ex.config
def my_config(params):
    """Resolve params["savedir"] so existing run directories are never overwritten."""
    if not os.path.exists(params["savedir"]):
        # Fresh directory: just create it.
        os.makedirs(str(params["savedir"]))
    elif not params["resume"]:
        # Directory exists and we are not resuming: append an increasing
        # index until a free name is found.
        # NOTE(review): the [:-1] slicing assumes savedir ends with "/" — confirm.
        dirindex = 1
        while os.path.exists(params["savedir"][:-1] + str(dirindex) + "/"):
            dirindex += 1
        params["savedir"] = params["savedir"][:-1] + str(dirindex) + "/"
        os.makedirs(str(params["savedir"]))
    else:
        # Resuming: reuse the directory containing the resume checkpoint.
        params["savedir"] = params["resume"][:params["resume"].rfind("/")+1]

    # if not params["evaluate"]:
    #     mongourl = (("mongodb://mongodbconnection")
示例#6
0
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds

# NOTE(review): `sacred`, `itertools` and `glob` are used below but only
# `from sacred import ...` imports are visible in this snippet — presumably
# the full modules are imported elsewhere in the file; verify.
sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
sacred.SETTINGS.CAPTURE_MODE = 'no'

ex = Experiment('PANetExt')
ex.captured_out_filter = apply_backspaces_and_linefeeds

# Collect every .py file from the project folders and attach them to the run.
source_folders = ['.', './dataloaders', './models', './util']
sources_to_save = list(
    itertools.chain.from_iterable(
        [glob.glob(f'{folder}/*.py') for folder in source_folders]))
for source_file in sources_to_save:
    ex.add_source_file(source_file)


@ex.config
def cfg():
    """Default configurations"""
    input_size = (417, 417)
    seed = 1234
    # sic: "cuda_visable" (typo of cuda_visible) kept — renaming would break readers of this config key.
    cuda_visable = '0, 1, 2, 3, 4, 5, 6, 7'
    gpu_id = 0
    mode = 'visualize'  # 'train' or 'test' or 'visualize
    label_sets = None
    details = 'Base PANet configuration'

    if mode == 'train':
        dataset = 'VOC'  # 'VOC' or 'COCO'
    # NOTE(review): this snippet appears truncated here; the train branch likely continues.
示例#7
0
from derive_conceptualspace.settings import get_setting
from fb_classifier.preprocess_data import preprocess_data, create_traintest
from fb_classifier.dataset import load_data
from fb_classifier.train import TrainPipeline
# from src.fb_classifier.util.misc import get_all_debug_confs, clear_checkpoints_summary
from fb_classifier.util.misc import get_all_configs
from fb_classifier.settings import CLASSIFIER_CHECKPOINT_PATH, SUMMARY_PATH, MONGO_URI, DATA_BASE, DEBUG
import fb_classifier

ex = Experiment("Fachbereich_Classifier")
# Results go to MongoDB; the database name comes from the environment.
ex.observers.append(MongoObserver(url=MONGO_URI, db_name=os.environ["MONGO_DATABASE"]))
ex.captured_out_filter = apply_backspaces_and_linefeeds
ex.add_config(get_all_configs(as_string=False))
ex.add_config(DEBUG=DEBUG)
# Archive every .py file of the fb_classifier package as run sources.
for pyfile in [join(path, name) for path, subdirs, files in os.walk(dirname(fb_classifier.__file__)) for name in files if splitext(name)[1] == ".py"]:
    ex.add_source_file(pyfile)

@ex.main
def run_experiment(_run):
    """Entry point: choose checkpoint/summary directories for this run.

    Fresh runs are namespaced by the sacred run id; otherwise the most
    recent existing checkpoint directory is continued.
    """
    args = parse_command_line_args()
    setup_logging(args.loglevel, args.logfile)

    # if args.restart: #delete checkpoint- and summary-dir
    #     clear_checkpoints_summary()

    if args.no_continue or not os.listdir(CLASSIFIER_CHECKPOINT_PATH):
        # Fresh run: keep checkpoints/summaries under the sacred run id.
        classifier_checkpoint_path = join(CLASSIFIER_CHECKPOINT_PATH, str(_run._id))
        summary_path = join(SUMMARY_PATH, str(_run._id))
    else:
        # Continue the latest previous run (highest numeric directory name).
        latest_exp = max(int(i) for i in os.listdir(CLASSIFIER_CHECKPOINT_PATH))
        classifier_checkpoint_path = join(CLASSIFIER_CHECKPOINT_PATH, str(latest_exp))
        # NOTE(review): this snippet appears truncated here (summary_path not set in this branch).
import numpy as np
import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import mean_squared_error
from sacred import Experiment

from dataset import read_dataset
from bots import LSTNetBot
from io_utils import export_validation, export_test

# Show only warnings and above from libraries.
logging.basicConfig(level=logging.WARNING)

ex = Experiment('LSTNet')
# Archive the preprocessing scripts alongside the training code.
ex.add_source_file("preprocess.py")
ex.add_source_file("prepare_seq_data.py")


@ex.named_config
def cnn_7():
    """LSTNet variant: 7-wide CNN kernel with weekly (skip=7) recurrence."""
    batch_size = 128
    ar_window_size = 28
    model_details = {
        "cnn_hidden_size": 256,
        "rnn_hidden_size": 256,
        "skip_hidden_size": 256,
        "skip": 7,
        "cnn_kernel": 7,
        "odrop": 0.25,
        "edrop": 0.25,
        # NOTE(review): the dict/function appear truncated in this snippet.
示例#9
0
def create_experiment(task,
                      name,
                      dataset_configs,
                      training_configs,
                      model_configs,
                      observers,
                      experiment_configs=None):
    """Build a sacred Experiment for *task* plus a runner closure.

    Returns a tuple ``(ex, run_fn)`` where ``run_fn(_config, seed)``
    performs one full train/validate/test cycle and logs its metrics
    to ``ex``.
    """
    dataset, load_dataset = get_dataset_ingredient(task)

    # Assemble the experiment from the dataset/model/training ingredients
    # and push the caller-supplied configuration into each of them.
    ex = Experiment(name=name, ingredients=[dataset, model, training])
    for ingredient, configs in ((dataset, dataset_configs),
                                (training, training_configs),
                                (model, model_configs)):
        update_configs_(ingredient, configs)

    if experiment_configs is not None:
        update_configs_(ex, experiment_configs)

    # Runtime options
    save_folder = '../../data/sims/deladd/temp/'
    ex.add_config({'no_cuda': False})

    # Record code and framework version for reproducibility.
    ex.add_source_file('../../src/model/subLSTM/nn.py')
    ex.add_source_file('../../src/model/subLSTM/functional.py')
    ex.add_package_dependency('torch', torch.__version__)
    ex.observers.extend(observers)

    def _log_training(tracer):
        # Flush the most recent training loss, then reset the trace buffer.
        ex.log_scalar('training_loss', tracer.trace[-1])
        tracer.trace.clear()

    def _log_validation(engine):
        for metric, value in engine.state.metrics.items():
            ex.log_scalar('val_{}'.format(metric), value)

    def _run_experiment(_config, seed):
        device = set_seed_and_device(seed, _config['no_cuda'])
        training_set, test_set, validation_set = load_dataset(
            batch_size=_config['training']['batch_size'])
        model = init_model(device=device)

        trainer, validator, checkpoint, metrics = setup_training(
            model,
            validation_set,
            save=save_folder,
            device=device,
            trace=False,
            time=False)[:4]

        # Wire metric logging into the training/validation loops.
        tracer = Tracer().attach(trainer)
        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  lambda e: _log_training(tracer))
        validator.add_event_handler(Events.EPOCH_COMPLETED, _log_validation)

        test_metrics = run_training(model=model,
                                    train_data=training_set,
                                    trainer=trainer,
                                    test_data=test_set,
                                    metrics=metrics,
                                    model_checkpoint=checkpoint,
                                    device=device)

        # Save best model performance and state.
        for metric, value in test_metrics.items():
            ex.log_scalar('test_{}'.format(metric), value)

        ex.add_artifact(str(checkpoint._saved[-1][1][0]), 'trained-model')

    return ex, _run_experiment
示例#10
0
import os
from sacred import Experiment

ex = Experiment()

# Attach each listed source file to the run, skipping files that are
# missing from the current working directory.
source_filepaths = ["train.py"]

for fpath in source_filepaths:
    if os.path.exists(fpath):
        ex.add_source_file(fpath)


@ex.config
def ex_config():
    # Expected input image shape (height, width, channels).
    input_shape = (150, 150, 3)
示例#11
0
from PIL import Image
from sacred import Experiment
from sacred.utils import apply_backspaces_and_linefeeds
from skimage.metrics._structural_similarity import structural_similarity as compare_ssim
from skimage.metrics.simple_metrics import peak_signal_noise_ratio as compare_psnr
from torch.utils.data import DataLoader

import lib.pytorch_ssim as pytorch_ssim
from lib.data import get_training_set, is_image_file, get_Low_light_training_set
from lib.utils import TVLoss, print_network
from model import DLN

# Sacred experiment for DLN low-light enhancement training.
Name_Exp = 'DLN'
exp = Experiment(Name_Exp)
# exp.observers.append(MongoObserver(url='Host:27017', db_name='low_light'))
# Archive the training/model sources with every run.
exp.add_source_file("train.py")
exp.add_source_file("model.py")
exp.add_source_file("lib/dataset.py")
exp.captured_out_filter = apply_backspaces_and_linefeeds


@exp.config
def cfg():
    """Default config: training hyper-parameters taken from argparse defaults/CLI flags."""
    # NOTE(review): relies on `argparse` being imported elsewhere in the file.
    parser = argparse.ArgumentParser(description='PyTorch Low-Light Enhancement')
    parser.add_argument('--batchSize', type=int, default=32, help='training batch size')
    parser.add_argument('--nEpochs', type=int, default=500, help='number of epochs to train for')
    parser.add_argument('--snapshots', type=int, default=10, help='Snapshots')
    parser.add_argument('--start_iter', type=int, default=0, help='Starting Epoch')
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning Rate. Default=0.0001')
    parser.add_argument('--gpu_mode', type=bool, default=True)
    parser.add_argument('--threads', type=int, default=16, help='number of threads for data loader to use')
    # NOTE(review): this snippet appears truncated here.
示例#12
0
class Trainer:
    def __init__(self,
                 network,
                 optim_class,
                 train_loader,
                 val_loader,
                 test_loader,
                 params,
                 lr_scheduler=None,
                 params_scheduler=None,
                 verbose=True):
        """Set up trainer state, the log directory and (optionally) sacred.

        :param network: model to train; must expose ``parameters()`` and
            ``children()`` (torch.nn.Module interface).
        :param optim_class: optimizer class, instantiated during training.
        :param train_loader: DataLoader whose dataset exposes
            ``num_classes`` and ``classes``.
        :param val_loader: validation DataLoader.
        :param test_loader: test DataLoader.
        :param params: namespace of run options (seed, log_dir, sacred, ...).
        :param lr_scheduler: optional LR-scheduler factory applied to the
            optimizer at training time.
        :param params_scheduler: optional scheduler stepped once per epoch.
        :param verbose: whether to show per-batch progress bars.
        """
        super().__init__()
        self.network = network
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.num_classes = train_loader.dataset.num_classes
        self.class_names = train_loader.dataset.classes
        self.num_weights = self.get_num_weights()
        self.num_trainable_weights = self.get_num_weights(trainable=True)

        self.verbose = verbose
        self.params = params
        self.seed = params.seed
        self.restart_path = params.restart_path
        self.max_epochs = params.max_epochs
        self.max_beaten_epochs = params.max_beaten_epochs
        self.learning_rate = params.learning_rate
        self.lr_param_multipliers = params.lr_param_multipliers
        self.optimize_every = params.optimize_every
        self.clip_grad_norm = params.clip_grad_norm
        self.keep_only_best_checkpoint = params.keep_only_best_checkpoint

        # Unique log dir per run: UTC timestamp + experiment name.
        timestring = strftime("%Y-%m-%d_%H-%M-%S",
                              gmtime()) + "_%s" % params.exp_name
        # self.log_dir = self.restart_path if self.restart_path else os.path.join(params.log_dir, timestring)
        self.log_dir = os.path.join(params.log_dir, timestring)

        self.use_sacred = params.sacred
        if params.sacred:
            self.sacred_exp = Experiment(params.exp_name)
            self.sacred_exp.captured_out_filter = apply_backspaces_and_linefeeds
            # BUG FIX: vars(params) returns params' own __dict__, so the
            # previous code's .update() calls silently injected
            # num_weights/num_trainable_weights/log_dir into `params`
            # itself. Work on a copy instead.
            configs = dict(vars(params))
            configs.update({'num_weights': self.num_weights})
            configs.update(
                {'num_trainable_weights': self.num_trainable_weights})
            configs.update({'log_dir': self.log_dir})
            self.sacred_exp.add_config(self.mongo_compatible(configs))
            for source in self.get_sources():
                self.sacred_exp.add_source_file(source)

            if not params.mongodb_disable:
                url = "{0.mongodb_url}:{0.mongodb_port}".format(params)
                # Default DB name: last non-empty component of data_dir,
                # unless an explicit mongodb_name is provided.
                db_name = [
                    d for d in params.data_dir.split('/') if len(d) > 0
                ][-1]
                if hasattr(params, 'mongodb_name') and params.mongodb_name:
                    db_name = params.mongodb_name

                print(
                    colored('Connect to MongoDB@{}:{}'.format(url, db_name),
                            "green"))
                self.sacred_exp.observers.append(
                    MongoObserver.create(url=url, db_name=db_name))

        # Progress counters used by training / early stopping.
        self.seen = 0
        self.epoch = 0
        self.steps = 0
        self.best_epoch = 0
        self.best_epoch_score = 0
        self.beaten_epochs = 0
        self.optim_class = optim_class
        self.lr_scheduler = lr_scheduler
        self.params_scheduler = params_scheduler
        self.optimizer = None
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

    @classmethod
    def mongo_compatible(cls, obj):
        """Recursively rewrite dict keys so MongoDB accepts them.

        '.' and '$' are reserved characters in MongoDB key names; they are
        replaced with ',' and '£' respectively. Tuples/lists become lists;
        any other value is returned unchanged.
        """
        if isinstance(obj, dict):
            return {
                key.replace(".", ',').replace("$", "£"): cls.mongo_compatible(val)
                for key, val in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [cls.mongo_compatible(item) for item in obj]
        return obj

    def set_seed(self, seed):
        """Seed numpy's and torch's RNGs for reproducibility."""
        np.random.seed(seed)
        torch.manual_seed(seed)
        print(colored("Using seed %d" % seed, "green"))

    def get_sources(self):
        """Collect source files worth archiving with the experiment run."""
        # The network implementation file, followed by the main script.
        sources = [inspect.getfile(self.network.__class__), sys.argv[0]]
        # Any user-defined submodule (installed packages are skipped).
        for child in self.network.children():
            child_path = inspect.getfile(child.__class__)
            if 'site-packages' not in child_path:
                sources.append(child_path)

        # The configuration file
        # if hasattr(self.params, "config") and self.params.config:
        #     if os.path.exists(self.params.config):
        #         sources.append(self.params.config)

        return sources

    def get_num_weights(self, trainable=False):
        """Count network parameters; only those requiring grad when trainable=True."""
        total = 0
        for param in self.network.parameters():
            if not trainable or param.requires_grad:
                total += param.numel()
        return total

    def log_sacred_scalar(self, name, val, step):
        """Forward a scalar metric to the active sacred run, if there is one."""
        if not self.use_sacred or not hasattr(self, 'sacred_exp'):
            return
        run = self.sacred_exp.current_run
        if run:
            run.log_scalar(name, val, step)

    def log_params(self, logger):
        """Write the script name and sorted run arguments as logger text."""
        script_name = os.path.basename(sys.argv[0])
        entries = ["  %s: %s \n" % (arg, val)
                   for arg, val in sorted(vars(self.params).items())]
        # Join and drop the trailing " \n" pair of characters.
        args_text = "".join(entries)[:-2]
        logger.add_text("Script arguments", script_name + "\n" + args_text)

    def json_params(self, savedir):
        """Dump the run parameters to <savedir>/params.json (best-effort)."""
        try:
            json_path = os.path.join(savedir, "params.json")
            with open(json_path, 'w') as fp:
                json.dump(vars(self.params), fp)
        except Exception as e:
            # Parameter export is non-critical: report and continue.
            print(
                colored("An error occurred while saving parameters into JSON:",
                        "red"))
            print(e)

    def yaml_params(self, savedir):
        """Write run parameters as simple ``key: value`` YAML lines (best-effort).

        Lists are rendered as "[a, b, ...]"; empty strings and None values
        are skipped entirely.
        """
        try:
            yaml_path = os.path.join(savedir, "params.yaml")
            with open(yaml_path, 'w') as fp:
                for key, value in vars(self.params).items():
                    if isinstance(value, list):
                        value = "[" + ", ".join(
                            "{}".format(item) for item in value) + "]"
                    if isinstance(value, str) and len(value) == 0:
                        continue
                    if value is None:
                        continue
                    fp.write("{}: {}\n".format(key, value))
        except Exception as e:
            # Parameter export is non-critical: report and continue.
            print(
                colored("An error occurred while saving parameters into YAML:",
                        "red"))
            print(e)

    def json_results(self, savedir, test_score):
        """Persist final training stats to results.json and attach it to sacred."""
        try:
            json_path = os.path.join(savedir, "results.json")
            summary = dict(seen=self.seen,
                           epoch=self.epoch,
                           best_epoch=self.best_epoch,
                           beaten_epochs=self.beaten_epochs,
                           best_epoch_score=self.best_epoch_score,
                           test_score=test_score)

            with open(json_path, 'w') as fp:
                json.dump(summary, fp)

            # Attach the results file as an artifact of the active sacred run.
            if (self.use_sacred and hasattr(self, 'sacred_exp')
                    and self.sacred_exp.current_run):
                self.sacred_exp.current_run.add_artifact(json_path)
        except Exception as e:
            print("An error occurred while saving results into JSON:")
            print(e)

    def log_gradients(self, logger, global_step):
        """Log the L2 norm of every trainable parameter's gradient."""
        for name, param in self.network.named_parameters():
            if not param.requires_grad or param.grad is None:
                continue
            logger.add_scalar("gradients/" + name,
                              param.grad.norm(2).item(),
                              global_step=global_step)

    def restart_exp(self):
        """Resume from the highest-numbered epoch checkpoint in restart_path."""
        epoch_re = re.compile(r'.*epoch(\d+)\.ckpt')
        ckpts = glob.glob(os.path.join(self.restart_path, "*.ckpt"))
        # Order by the epoch number embedded in the filename; the last
        # entry is the most recent checkpoint.
        ckpts.sort(key=lambda f: int(epoch_re.findall(f)[0]))
        self.load_trainer_checkpoint("", ckpts[-1])

    def save_checkpoint(self, path, filename):
        """Serialize trainer progress plus model/optimizer state to path/filename."""
        os.makedirs(path, exist_ok=True)

        try:
            state = {
                'seen': self.seen,
                'epoch': self.epoch,
                'best_epoch': self.best_epoch,
                'beaten_epochs': self.beaten_epochs,
                'best_epoch_score': self.best_epoch_score,
                'model_state_dict': self.network.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()
            }
            torch.save(state, os.path.join(path, filename))
        except Exception as e:
            print("An error occurred while saving the checkpoint:")
            print(e)

    def remove_checkpoint(self, path, filename):
        """Delete a checkpoint file, ignoring the case where it is already gone.

        Only OS-level errors (missing file, permissions) are suppressed;
        the previous bare ``except:`` also swallowed KeyboardInterrupt and
        genuine programming errors.
        """
        try:
            # The checkpoint may already be removed
            os.remove(os.path.join(path, filename))
        except OSError:
            pass

    def load_trainer_checkpoint(self, path, filename):
        """Restore trainer, model and optimizer state from a checkpoint.

        If the checkpoint's final fc layer has a different number of output
        classes than the current network, only the compatible weights are
        loaded (non-strict) and the trainer counters stay untouched.
        """
        ckpt_path = os.path.join(path, filename)

        checkpoint = torch.load(ckpt_path)
        model_state_dict = checkpoint['model_state_dict']
        # NOTE(review): assumes a ResNet-style network exposing `resnet.fc`
        # — confirm for other architectures.
        if model_state_dict['resnet.fc.weight'].shape[
                0] != self.network.resnet.fc.out_features:
            # Class-count mismatch: drop the classifier head weights and
            # load the remaining layers non-strictly.
            model_state_dict.pop('resnet.fc.weight')
            prev_class_num = model_state_dict.pop('resnet.fc.bias').shape[0]
            print('popping checkpoint last from {}'.format(prev_class_num))
            self.network.load_state_dict(model_state_dict, strict=False)
        else:
            # Full restore: bookkeeping counters, weights and optimizer.
            self.seen = checkpoint['seen']
            self.epoch = checkpoint['epoch']
            self.best_epoch = checkpoint['best_epoch']
            self.beaten_epochs = checkpoint['beaten_epochs']
            self.best_epoch_score = checkpoint['best_epoch_score']
            self.network.load_state_dict(model_state_dict)
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    def load_network_checkpoint(self, path, filename):
        """Load only the model weights from a checkpoint file."""
        checkpoint = torch.load(os.path.join(path, filename))
        self.network.load_state_dict(checkpoint['model_state_dict'])

    def correct(self, batch_probas, batch_labels):
        """
        Computes the number of correct predictions given
        :param batch_probas: A tensor of shape [batch_size, num_classes] containing the predicted probabilities
        :param batch_labels: A tensor of shape [num_classes] containing the target labels
        :return: a scalar representing the number of correct predictions
        """
        predictions = batch_probas.argmax(dim=-1)
        return (predictions == batch_labels).sum()

    def confusion_matrix(self, batch_probas, batch_labels):
        """Compute the confusion matrix for one batch.

        :param batch_probas: tensor of shape [batch_size, num_classes]
        :param batch_labels: target labels (tensor or ndarray)
        :return: ndarray of shape [num_classes, num_classes]
        """
        _, predictions = torch.max(batch_probas, dim=-1)
        # sklearn expects CPU numpy arrays.
        if isinstance(predictions, torch.Tensor):
            predictions = predictions.clone().cpu().data.numpy()
        if isinstance(batch_labels, torch.Tensor):
            batch_labels = batch_labels.clone().cpu().data.numpy()
        # Pin the label set so the matrix shape is always num_classes².
        return metrics.confusion_matrix(batch_labels,
                                        predictions,
                                        labels=np.arange(self.num_classes))

    def optimizer_parameters(self, base_lr, params_mult):
        """
        Associates network parameters with learning rates
        :param float base_lr: the basic learning rate
        :param OrderedDict params_mult: an OrderedDict containing 'param_name':lr_multiplier pairs. All parameters containing
        'param_name' in their name are be grouped together and assigned to a lr_multiplier*base_lr learning rate.
        Parameters not matching any 'param_name' are assigned to the base_lr learning_rate
        :return: A list of dictionaries [{'params': <list_params>, 'lr': lr}, ...]
        """
        matched_names = []
        grouped_params = []
        if params_mult is not None:
            for group_name, multiplier in params_mult.items():
                members = []
                for param_name, param in self.network.named_parameters():
                    if group_name not in param_name:
                        continue
                    # A parameter may belong to at most one group.
                    if param_name in matched_names:
                        raise RuntimeError(
                            "%s matches with multiple parameters groups!" %
                            param_name)
                    members.append(param)
                    matched_names.append(param_name)
                if members:
                    grouped_params.append({
                        'params': members,
                        'lr': multiplier * base_lr
                    })

        # Everything left over trains at the base learning rate.
        remaining = [
            param for param_name, param in self.network.named_parameters()
            if param_name not in matched_names
        ]
        grouped_params.append({'params': remaining, 'lr': base_lr})
        # Sanity check: every parameter ended up in exactly one group.
        assert len(matched_names) + len(remaining) == len(
            list(self.network.parameters()))

        return grouped_params

    @main_ifsacred
    def train_network(self):
        """
        Performs a complete training procedure by performing early stopping using the provided validation set
        """

        # Reset all bookkeeping counters for a fresh run.
        self.seen = 0
        self.steps = 0
        self.epoch = 0
        self.best_epoch = 0
        self.beaten_epochs = 0
        self.best_epoch_score = 0
        # Initializes the network and optimizer states
        self.set_seed(self.seed)
        self.network.init_params()
        # Moves the network to 'device' (GPU if available)
        self.network = self.network.to(self.device)

        # Per-group learning rates from lr_param_multipliers; base lr otherwise.
        self.optimizer = self.optim_class(self.optimizer_parameters(
            base_lr=self.learning_rate, params_mult=self.lr_param_multipliers),
                                          lr=self.learning_rate)
        if self.lr_scheduler is not None:
            self.lr_scheduler = self.lr_scheduler(self.optimizer)
        # Loads the initial checkpoint if provided
        if self.restart_path is not None:
            self.restart_exp()

        # Use the same directory as the restart experiment
        os.makedirs(self.log_dir, exist_ok=True)
        logger = SummaryWriter(log_dir=self.log_dir)
        self.log_params(logger)
        self.json_params(self.log_dir)
        self.yaml_params(self.log_dir)

        # Early stopping: train until the validation score hasn't improved
        # for max_beaten_epochs epochs, or max_epochs is reached.
        while self.beaten_epochs < self.max_beaten_epochs and self.epoch < self.max_epochs:

            if self.lr_scheduler is not None:
                for i, lr in enumerate(self.lr_scheduler.get_lr()):
                    logger.add_scalar("train/lr%d" % i, lr, self.epoch)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            if self.params_scheduler is not None:
                self.params_scheduler.step()

            # Performs one epoch of training
            self.train_epoch(self.train_loader, logger=logger)
            self.save_checkpoint(self.log_dir, "epoch%d.ckpt" % self.epoch)
            # Validate the network against the validation set
            valid_score = self.validate_network(self.val_loader,
                                                log_descr="validation",
                                                logger=logger)

            if valid_score > self.best_epoch_score or self.epoch == 0:
                # New best epoch: optionally remove the previous best and
                # the immediately preceding checkpoint to save disk space.
                if self.keep_only_best_checkpoint and self.epoch != 0:
                    self.remove_checkpoint(self.log_dir,
                                           "epoch%d.ckpt" % self.best_epoch)
                    self.remove_checkpoint(self.log_dir,
                                           "epoch%d.ckpt" % (self.epoch - 1))
                self.best_epoch = self.epoch
                self.best_epoch_score = valid_score
                self.beaten_epochs = 0
            else:
                # No improvement: drop the previous (non-best) checkpoint.
                if self.keep_only_best_checkpoint and self.beaten_epochs > 0:
                    self.remove_checkpoint(self.log_dir,
                                           "epoch%d.ckpt" % (self.epoch - 1))
                self.beaten_epochs += 1

            self.epoch += 1

        # Reload the best checkpoint and evaluate it on the test set.
        self.load_network_checkpoint(self.log_dir,
                                     "epoch%d.ckpt" % self.best_epoch)
        test_score = self.validate_network(self.test_loader,
                                           log_descr="test",
                                           logger=logger)
        self.json_results(self.log_dir, test_score)
        self.log_sacred_scalar("test/accuracy", test_score, self.epoch)

        return test_score

    def train_epoch(self, dataloader, logger=None):
        """
        Performs one entire epoch of training
        :param dataloader: A DataLoader object producing training samples
        :return: a tuple (epoch_loss, epoch_accuracy)
        """

        running_loss = 0
        running_correct = 0
        running_samples = 0
        running_batches = 0
        running_iter_loss = 0
        running_iter_correct = 0
        running_iter_samples = 0

        tot_num_batches = len(dataloader)
        running_optimize_every = min(self.optimize_every,
                                     tot_num_batches - running_batches)

        # Enters train mode
        self.network.train()
        # Zero the parameter gradients
        self.optimizer.zero_grad()

        if not self.verbose:
            print("Starting training phase ...")

        pbar_descr_prefix = "Epoch %d (best: %d, beaten: %d)" % (
            self.epoch, self.best_epoch, self.beaten_epochs)
        with tqdm(total=tot_num_batches,
                  disable=not self.verbose,
                  desc=pbar_descr_prefix +
                  " - Mini-batch progress") as iterator:
            for batch in dataloader:
                # Get the inputs
                batch_lengths, batch_events, batch_labels = batch
                # Moves batch to the proper device based on GPU availability
                batch_lengths = batch_lengths.to(self.device)
                batch_events = batch_events.to(self.device).type(torch.float32)
                batch_labels = batch_labels.to(self.device)

                # forward + backward + optimize
                batch_outputs = self.network.forward(batch_events,
                                                     batch_lengths)
                loss = self.network.loss(batch_outputs, batch_labels)
                norm_loss = loss / running_optimize_every
                norm_loss.backward()

                loss_b = loss.item()
                running_loss += loss_b
                running_iter_loss += loss_b
                correct_b = self.correct(batch_outputs, batch_labels).item()
                running_correct += correct_b
                running_iter_correct += correct_b

                running_batches += 1
                samples_b = batch_labels.shape[0]
                self.seen += samples_b
                running_samples += samples_b
                running_iter_samples += samples_b

                if running_batches % running_optimize_every == 0:
                    if self.clip_grad_norm is not None:
                        torch.nn.utils.clip_grad_norm_(
                            self.network.parameters(), self.clip_grad_norm)
                    self.optimizer.step()

                    self.log_gradients(logger, global_step=self.steps)
                    self.network.log_parameters(logger, global_step=self.steps)
                    logger.add_scalar("train/loss",
                                      running_iter_loss /
                                      running_optimize_every,
                                      global_step=self.seen)
                    logger.add_scalar("train/accuracy",
                                      running_iter_correct /
                                      running_iter_samples,
                                      global_step=self.seen)

                    # Zero the parameter gradients
                    self.optimizer.zero_grad()
                    # Either self.optimize_every or the number of remaining samples
                    running_optimize_every = min(
                        self.optimize_every, tot_num_batches - running_batches)
                    running_iter_samples = 0
                    running_iter_correct = 0
                    running_iter_loss = 0
                    self.steps += 1

                # Only print infos at 25%, 50%, 75%
                if not self.verbose and self.steps % (tot_num_batches //
                                                      4) == 0:
                    print(("Training phase {}% - " + pbar_descr_prefix +
                           " - metric: {},  loss: {}").format(
                               int((running_batches / tot_num_batches) * 100),
                               running_correct / running_samples,
                               running_loss / running_batches))

                iterator.update()

            tot_metric = running_correct / running_samples
            tot_loss = running_loss / running_batches
            iterator.set_description(pbar_descr_prefix +
                                     " - metric: %f,  loss: %f" %
                                     (tot_metric, tot_loss))
            self.log_sacred_scalar("train/accuracy", tot_metric, self.epoch)

        if not self.verbose:
            print("Finished " + pbar_descr_prefix +
                  " - metric: %f,  loss: %f" % (tot_metric, tot_loss))

        return tot_loss, tot_metric

    def validate_network(self,
                         dataloader,
                         log_descr="validation",
                         logger=None):
        """
        Computes the accuracy of the network against a validation/test set.

        :param dataloader: a DataLoader yielding (lengths, events, labels)
            batches
        :param log_descr: label prefixed to logged scalars/figures and to
            progress messages (e.g. "validation" or "test")
        :param logger: optional TensorBoard-style summary writer; when None
            (the default) scalar/figure logging is skipped instead of
            raising AttributeError as before
        :return: the accuracy over the whole dataset
        """

        running_samples = 0
        running_correct = 0
        # NOTE: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin int maps to the same default integer dtype.
        running_confmatrix = np.zeros([self.num_classes, self.num_classes],
                                      dtype=int)

        if not self.verbose:
            print("Starting " + log_descr + " phase ...")

        # Enters eval mode
        self.network.eval()
        # Disable autograd while evaluating the model
        with no_grad_ifnotscript(self.network):
            with tqdm(total=len(dataloader),
                      disable=not self.verbose,
                      desc="    %s - Mini-batch progress" %
                      log_descr.title()) as iterator:
                for batch in dataloader:
                    # Get the inputs
                    batch_lengths, batch_events, batch_labels = batch
                    # Moves batch to the proper device based on GPU availability
                    batch_lengths = batch_lengths.to(self.device)
                    batch_events = batch_events.to(self.device).type(
                        torch.float32)
                    batch_labels = batch_labels.to(self.device)

                    batch_outputs = self.network.forward(
                        batch_events, batch_lengths)
                    running_correct += self.correct(batch_outputs,
                                                    batch_labels).item()
                    running_confmatrix += self.confusion_matrix(
                        batch_outputs, batch_labels)
                    running_samples += batch_labels.shape[0]

                    iterator.update()

                tot_metric = running_correct / running_samples
                # Guard external logging: `logger` defaults to None, and the
                # previous code crashed on logger.add_scalar in that case.
                if logger is not None:
                    logger.add_scalar(log_descr + "/accuracy",
                                      tot_metric,
                                      global_step=self.epoch)
                self.log_sacred_scalar((log_descr + "/accuracy"), tot_metric,
                                       self.epoch)
                if logger is not None:
                    confmatrix_fig = confusion_matrix_fig(running_confmatrix,
                                                          self.class_names)
                    logger.add_figure(log_descr + "/confusion_matrix",
                                      confmatrix_fig,
                                      global_step=self.epoch,
                                      close=True)
                # NOTE(review): log_validation is assumed to tolerate a None
                # logger (opaque helper) — confirm against its definition.
                self.network.log_validation(logger, global_step=self.epoch)

                iterator.set_description("    %s - metric: %f" %
                                         (log_descr.title(), tot_metric))

        if not self.verbose:
            print("Finished %s - metric: %f" % (log_descr, tot_metric))

        return tot_metric
# 示例#13 (Example #13 — separator artifact left by code-sample aggregation;
# commented out so the module stays syntactically valid)
# 0
# Load the MongoDB connection settings and register a Mongo observer
# so Sacred records this experiment's runs in the database.
with open('mongodb.json') as settings_file:
    mongodb_settings = json.load(settings_file)

mongo_url = 'mongodb://{user}:{pwd}@{ip}:{port}'.format(**mongodb_settings)
mongo_db = '{db}'.format(**mongodb_settings)
ex.observers.append(MongoObserver.create(url=mongo_url, db_name=mongo_db))

# Also report runs to Telegram, but only when a bot config file exists.
telegram_config = c.ROOT / 'telegram.json'
if telegram_config.exists():
    ex.observers.append(TelegramObserver.from_config('telegram.json'))

# Attach every Python file under the project root as an experiment source.
for source_path in c.ROOT.glob('**/*.py'):
    resolved = source_path.absolute()
    print("Saving File: {}".format(resolved))
    ex.add_source_file(resolved)


@ex.config
def default_config():
    # Sacred config scope: every local variable assigned here becomes a
    # configuration entry of the experiment (here, the project name used
    # by custom_config to locate the matching JSON config file).
    project = 'dcase20191b'


@ex.config
def custom_config(project):
    # Sacred config scope that depends on the `project` entry defined by an
    # earlier config scope; it loads additional configuration values from
    # configs/<project>.json under the project root.
    ex.add_config(str(c.ROOT / 'configs' / '{}.json'.format(project)))


@ex.automain
def run(_config, _run, _rnd):
    # experiment, data_set, model, training, resources, _run, _rnd
import numpy as np
import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from transformer.Optim import ScheduledOptim
from sacred import Experiment

from dataset import read_dataset
from bots import TransformerBot
from io_utils import export_validation, export_test


# Only emit warnings and above from the root logger.
logging.basicConfig(level=logging.WARNING)

# Sacred experiment for the Transformer model; the preprocessing script is
# tracked as an additional source file of the experiment.
ex = Experiment('Transformer')
ex.add_source_file("preprocess.py")


@ex.named_config
def no_tf_2l():
    batch_size = 32
    model_details = {
        "odrop": 0.25,
        "edrop": 0.25,
        "hdrop": 0.1,
        "d_model": 128,
        "d_inner_hid": 256,
        "n_layers": 2,
        "n_head": 4,
        "propagate": False
    }