Example #1
    def _save_params(self):
        """ Save the params of the model, like the current glob_step value
        Warning: if you modify this function, make sure the changes mirror _restore_params
        """
        config = configparser.ConfigParser()
        config['General'] = {}
        config['General']['version'] = self.CONFIG_VERSION
        config['General']['glob_step'] = str(self.glob_step)
        config['General']['keep_all'] = str(self.args.keep_all)
        config['General']['dataset_tag'] = self.args.dataset_tag
        config['General']['sample_length'] = str(self.args.sample_length)

        config['Network'] = {}
        config['Network']['hidden_size'] = str(self.args.hidden_size)
        config['Network']['num_layers'] = str(self.args.num_layers)
        config['Network']['target_weights'] = self.args.target_weights  # Could be modified manually
        config['Network']['scheduled_sampling'] = ' '.join(
            self.args.scheduled_sampling)

        # Keep track of the learning params (these are not model dependent, so they can be edited manually)
        config['Training'] = {}
        config['Training']['batch_size'] = str(self.args.batch_size)
        config['Training']['save_every'] = str(self.args.save_every)
        config['Training']['ratio_dataset'] = str(self.args.ratio_dataset)
        config['Training']['testing_curve'] = str(self.args.testing_curve)

        # Save the chosen modules and their configuration
        ModuleLoader.save_all(config)

        with open(os.path.join(self.model_dir, self.CONFIG_FILENAME),
                  'w') as config_file:
            config.write(config_file)
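
Note that configparser only stores strings, which is why every numeric or boolean argument is wrapped in str() above and read back with the typed getters (getint, getboolean, getfloat) in _restore_params below. A minimal standalone sketch of that round trip, with made-up values and a hypothetical file name:

import configparser

config = configparser.ConfigParser()
config['General'] = {}
config['General']['glob_step'] = str(2000)      # everything is written as a string
config['General']['keep_all'] = str(False)

with open('params.ini', 'w') as config_file:    # 'params.ini' is illustrative only
    config.write(config_file)

restored = configparser.ConfigParser()
restored.read('params.ini')
print(restored['General'].getint('glob_step'))      # 2000 (int again)
print(restored['General'].getboolean('keep_all'))   # False (bool again)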
Example #2
    def _restore_params(self):
        """ Load the some values associated with the current model, like the current glob_step value.
        Needs to be called before any other function because it initialize some variables used on the rest of the
        program

        Warning: if you modify this function, make sure the changes mirror _save_params, also check if the parameters
        should be reset in manage_previous_model
        """
        # Compute the current model path
        self.model_dir = os.path.join(self.args.root_dir, self.MODEL_DIR_BASE)
        if self.args.model_tag:
            self.model_dir += '-' + self.args.model_tag

        # If there is a previous model, restore some parameters
        config_name = os.path.join(self.model_dir, self.CONFIG_FILENAME)
        if not self.args.reset and not self.args.create_dataset and os.path.exists(
                config_name):
            # Loading
            config = configparser.ConfigParser()
            config.read(config_name)

            # Check the version
            current_version = config['General'].get('version')
            if current_version != self.CONFIG_VERSION:
                raise UserWarning(
                    'Present configuration version {0} does not match {1}. You can try manual changes on \'{2}\''
                    .format(current_version, self.CONFIG_VERSION, config_name))

            # Restore the parameters
            self.glob_step = config['General'].getint('glob_step')
            self.args.keep_all = config['General'].getboolean('keep_all')
            self.args.dataset_tag = config['General'].get('dataset_tag')
            if not self.args.test:  # When testing, we don't use the training length
                self.args.sample_length = config['General'].getint(
                    'sample_length')

            self.args.hidden_size = config['Network'].getint('hidden_size')
            self.args.num_layers = config['Network'].getint('num_layers')
            self.args.target_weights = config['Network'].get('target_weights')
            self.args.scheduled_sampling = config['Network'].get(
                'scheduled_sampling').split(' ')

            self.args.batch_size = config['Training'].getint('batch_size')
            self.args.save_every = config['Training'].getint('save_every')
            self.args.ratio_dataset = config['Training'].getfloat(
                'ratio_dataset')
            self.args.testing_curve = config['Training'].getint(
                'testing_curve')

            ModuleLoader.load_all(self.args, config)

            # Show the restored params
            print(
                'Warning: Restoring parameters from previous configuration (you should manually edit the file if you want to change one of those)'
            )

        # When testing, only predict one song at a time
        if self.args.test:
            self.args.batch_size = 1
            self.args.scheduled_sampling = [Model.ScheduledSamplingPolicy.NONE]
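
Since configparser values are plain strings, list-valued arguments such as scheduled_sampling are flattened with ' '.join(...) on save and rebuilt with .split(' ') on restore, as seen above. A tiny illustration of that round trip (the policy values are made up):

policy = ['linear', '2000', '0.9', '0.1']   # illustrative values only
serialized = ' '.join(policy)               # 'linear 2000 0.9 0.1'
assert serialized.split(' ') == policy      # back to the original list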
Example #3
    def _print_params(self):
        """ Print the current params
        """
        print()
        print('Current parameters:')
        print('glob_step: {}'.format(self.glob_step))
        print('keep_all: {}'.format(self.args.keep_all))
        print('dataset_tag: {}'.format(self.args.dataset_tag))
        print('sample_length: {}'.format(self.args.sample_length))

        print('hidden_size: {}'.format(self.args.hidden_size))
        print('num_layers: {}'.format(self.args.num_layers))
        print('target_weights: {}'.format(self.args.target_weights))
        print('scheduled_sampling: {}'.format(' '.join(self.args.scheduled_sampling)))

        print('batch_size: {}'.format(self.args.batch_size))
        print('save_every: {}'.format(self.args.save_every))
        print('ratio_dataset: {}'.format(self.args.ratio_dataset))
        print('testing_curve: {}'.format(self.args.testing_curve))

        ModuleLoader.print_all(self.args)
Example #4
    def main(self, args=None):
        """
        Launch the training and/or the interactive mode
        """
        print('Welcome to DeepMusic v0.1 !')
        print()
        print('TensorFlow detected: v{}'.format(tf.__version__))

        # General initialisations

        tf.logging.set_verbosity(
            tf.logging.INFO)  # DEBUG, INFO, WARN (default), ERROR, or FATAL

        ModuleLoader.register_all()  # Load available modules
        self.args = self._parse_args(args)
        if not self.args.root_dir:
            self.args.root_dir = os.getcwd()  # Use the current working directory

        self._restore_params()  # Update self.model_dir and self.glob_step; for now, not used when loading Model
        self._print_params()

        self.music_data = MusicData(self.args)
        if self.args.create_dataset:
            print('Dataset created! You can start training some models.')
            return  # No need to go further

        with tf.device(self._get_device()):
            self.model = Model(self.args)
        print("before:")
        print(self.model_dir)
        # Saver/summaries
        self.writer = tf.summary.FileWriter(
            os.path.join(self.model_dir, 'train'))
        self.writer_test = tf.summary.FileWriter(
            os.path.join(self.model_dir, 'test'))
        self.saver = tf.train.Saver(max_to_keep=200)  # Arbitrary limit?

        # TODO: Fix the seed (WARNING: if the dataset is shuffled, make sure to do that after saving the
        # dataset, otherwise everything that comes after the shuffling won't be reproducible when
        # reloading the dataset). How to restore the seed after loading? (with get_state/set_state)
        # Also fix the seed for np.random (does it work globally for all files?)

        # Running session
        config = tf.ConfigProto(allow_soft_placement=True)
        self.sess = tf.Session(config=config)

        print('Initialize variables...')
        self.sess.run(tf.global_variables_initializer())  # initialize_all_variables() is deprecated in TF 1.x

        # Reload the model if one exists; in testing mode, the models are not loaded here (but in main_test())
        self._restore_previous_model(self.sess)

        if self.args.test:
            if self.args.test == Composer.TestMode.ALL:
                self._main_test()
            elif self.args.test == Composer.TestMode.DAEMON:
                print('Daemon mode, running in background...')
                raise NotImplementedError('No daemon mode')  # Come back later
            else:
                raise RuntimeError('Unknown test mode: {}'.format(
                    self.args.test))  # Should never happen
        else:
            self._main_train()

        if self.args.test != Composer.TestMode.DAEMON:
            self.sess.close()
            print('The End! Thanks for using this program')
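
main() only sets up the FileWriters, the Saver and the session; the actual checkpointing and summary logging presumably happen later in _main_train(). A minimal TF 1.x sketch of how such a saver and writer are typically used together in a training loop (the variable, loss and directory are placeholders, not taken from this project):

import os
import tensorflow as tf  # TF 1.x API, matching the code above

model_dir = '/tmp/deepmusic-sketch'                      # hypothetical directory
x = tf.Variable(0.0)
loss = tf.square(x - 3.0)                                # dummy loss, for illustration only
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
tf.summary.scalar('loss', loss)
summary_op = tf.summary.merge_all()

saver = tf.train.Saver(max_to_keep=200)
writer = tf.summary.FileWriter(os.path.join(model_dir, 'train'))

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.global_variables_initializer())
    for glob_step in range(100):
        _, summary = sess.run([train_op, summary_op])
        writer.add_summary(summary, glob_step)           # logged under model_dir/train
        if glob_step % 50 == 0:                          # save a checkpoint periodically
            saver.save(sess, os.path.join(model_dir, 'model'), global_step=glob_step)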