Ejemplo n.º 1
0
    def setup_model(self, load_checkpoint=None, print_model_summary=False):
        """Build the model, restore weights when checkpoints exist, and compile.

        Side effects: sets ``self.checkpoints_path``, ``self.samples_path``,
        ``self.history_filename`` and ``self.epoch_num``; creates the
        training, checkpoints and samples directories when they are missing;
        stores the parameter count in ``self.config['model']['num_params']``;
        and writes ``config.json`` into the training directory if that file
        does not exist yet.

        Args:
            load_checkpoint: Optional path to a weights file. It is only used
                when the checkpoints directory already exists and contains
                files; the epoch counter then starts at 0. When ``None``, the
                most recently modified entry of the checkpoints directory is
                loaded instead.
            print_model_summary: When true, call ``model.summary()`` and
                ``util.pretty_json_dump(self.config)`` (no target path).

        Returns:
            The object returned by ``self.build_model()``, after
            ``self.compile_model`` has been called on it.
        """
        training_path = self.config['training']['path']

        self.checkpoints_path = os.path.join(training_path, 'checkpoints')
        self.samples_path = os.path.join(training_path, 'samples')
        # basename() rather than slicing after the last '/': the slice raised
        # ValueError for a path with no '/' at all (e.g. a bare directory
        # name) and ignored the platform's own path separator.
        self.history_filename = (
            'history_' + os.path.basename(training_path) + '.csv')

        model = self.build_model()

        has_checkpoints = (
            os.path.exists(self.checkpoints_path)
            and util.dir_contains_files(self.checkpoints_path))

        if has_checkpoints:
            if load_checkpoint is not None:
                # Explicit checkpoint: load it and restart epoch counting.
                last_checkpoint_path = load_checkpoint
                self.epoch_num = 0
            else:
                # Resume from the most recently modified checkpoint file.
                checkpoints = os.listdir(self.checkpoints_path)
                checkpoints.sort(key=lambda x: os.stat(
                    os.path.join(self.checkpoints_path, x)).st_mtime)
                last_checkpoint = checkpoints[-1]
                last_checkpoint_path = os.path.join(self.checkpoints_path,
                                                    last_checkpoint)
                # Characters 11..15 of the file name are parsed as the epoch.
                # NOTE(review): assumes a fixed checkpoint file-name layout
                # with a 5-digit epoch at that offset -- confirm against the
                # code that writes the checkpoints.
                self.epoch_num = int(last_checkpoint[11:16])
            print('Loading model from epoch: %d' % self.epoch_num)
            model.load_weights(last_checkpoint_path)

        else:
            # NOTE(review): a load_checkpoint argument is ignored on this
            # path (no existing checkpoints) -- confirm this is intended.
            print('Building new model...')

            if not os.path.exists(training_path):
                os.makedirs(training_path)

            if not os.path.exists(self.checkpoints_path):
                os.makedirs(self.checkpoints_path)

            self.epoch_num = 0

        if not os.path.exists(self.samples_path):
            os.makedirs(self.samples_path)

        if print_model_summary:
            model.summary()

        self.compile_model(model)
        self.config['model']['num_params'] = model.count_params()

        # Only write config.json once; an existing file is left untouched.
        config_path = os.path.join(training_path, 'config.json')
        if not os.path.exists(config_path):
            util.pretty_json_dump(self.config, config_path)

        if print_model_summary:
            util.pretty_json_dump(self.config)
        return model
Ejemplo n.º 2
0
                ## or last epoch
                if curr_sdr > valid_best_sdr or epoch == config['training']['num_epochs']:
                    util.myprint(history_file, 'Save Model')
                    valid_wait = 0
                    valid_best_sdr = curr_sdr
                    G.save(G_save_path, g_curr_step)

                else:
                    valid_wait += 1
                    if valid_wait == config['training']['half_lr_patience']:
                        glr /= 2; valid_wait = 0
                break

        util.write(os.path.join(config['training']['path'], 'tr_perm.csv'), tr_dataset.file_base, tr_audio_perm, epoch, config['training']['n_speaker'])

if __name__ == "__main__":
    # Script entry point: parse the command line, load the configuration,
    # make sure the output directory exists, then dispatch on the mode.
    cla = get_command_line_arguments()
    config = load_config(cla.config)
    output_dir = config['training']['path']
    print('Save model path : {}'.format(output_dir))

    # makedirs() instead of mkdir(): mkdir() raises when the configured path
    # is nested and its parent directories do not exist yet.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if cla.mode == 'train':
        # Store the configuration next to the training outputs, under the
        # same file name it was loaded from.
        util.pretty_json_dump(
            config, os.path.join(output_dir, os.path.basename(cla.config)))
        training(config, cla)

    elif cla.mode == 'test':
        test.test(config, cla)