Example #1
def main(args):
    # Fall back to the model.yaml stored in the log directory when no
    # explicit config path is given.
    config_file = args.config or os.path.join(args.logdir, 'model.yaml')
    with open(config_file, 'rb') as f:
        config = Configuration.from_yaml(f)
    logger.debug(config)

    model, trainer, encoding = config.configure(
        _init,
        logdir=args.logdir,
        train_mode=(args.action == 'train'),
        sampling_seed=getattr(args, 'seed', None))

    if args.action == 'train':
        trainer.train()
    elif args.action == 'run':
        # Restore the trained weights, then run inference on the pickled,
        # pre-encoded examples.
        trainer.load_variables(checkpoint_file=args.checkpoint)
        data = pickle.load(args.input_file)
        dataset = make_simple_dataset(_make_data_generator(encoding, data),
                                      output_types=(tf.int32, tf.int32,
                                                    tf.int32),
                                      output_shapes=([None], [None], [None]),
                                      batch_size=args.batch_size)

        output_ids = model.run(trainer.session, dataset, args.sample,
                               args.softmax_temperature)
        # Decode the generated token IDs back into note sequences.
        output = [encoding.decode(seq) for seq in output_ids]
        pickle.dump(output, args.output_file)
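
Every snippet on this page builds its input pipeline with make_simple_dataset, whose definition is not shown. As a rough mental model only (an assumption about its behavior, not the library's actual code), a TF 1.x-style helper compatible with these call sites might look like:

import tensorflow as tf

def make_simple_dataset(generator, output_types, output_shapes, batch_size):
    # Sketch: wrap the Python generator in a tf.data pipeline and pad-batch
    # the variable-length sequences to a common length per batch.
    dataset = tf.data.Dataset.from_generator(
        generator, output_types=output_types, output_shapes=output_shapes)
    return dataset.padded_batch(batch_size, padded_shapes=output_shapes)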

Example #2

    def run(self,
            pipeline,
            batch_size=None,
            filters='program',
            sample=False,
            softmax_temperature=1.,
            normalize_velocity=False,
            options=None):
        metadata_list = []  # gather metadata about each item of the dataset
        # filters == 'program' restricts filtering to the program (instrument)
        # filter; any other value enables all filters.
        apply_filters = '__program__' if filters == 'program' else True
        dataset = make_simple_dataset(
            self._load_data(tqdm.tqdm(pipeline),
                            apply_filters=apply_filters,
                            normalize_velocity=normalize_velocity,
                            metadata_list=metadata_list),
            output_types=self.input_types,
            output_shapes=self.input_shapes,
            batch_size=batch_size
            or self._cfg['data_prep'].get('val_batch_size'))
        output_ids = self.model.run(self.trainer.session,
                                    dataset,
                                    sample,
                                    softmax_temperature,
                                    options=options) or []
        sequences = [self.output_encoding.decode(ids) for ids in output_ids]
        merged_sequences = []
        instrument_id = 0
        for seq, meta in zip(sequences, metadata_list):
            instrument_id += 1
            # Start a fresh output sequence whenever this item belongs to a new
            # input, restarting the instrument numbering within it.
            while meta['input_index'] > len(merged_sequences) - 1:
                merged_sequences.append(music_pb2.NoteSequence())
                instrument_id = 0

            # Apply features (instrument, velocity)
            if meta['note_features'] is not None:
                if self._cfg['output_encoding'].get('use_velocity', False):
                    # The output already carries velocity information, so drop
                    # the style's velocity feature instead of overriding it.
                    del meta['note_features']['velocity']

                set_note_fields(seq,
                                **meta['note_features'],
                                instrument=instrument_id)
            else:
                # If the style input had no notes, force the output to be empty
                seq.Clear()

            # Merge
            merged_sequences[-1].notes.extend(seq.notes)
            merged_sequences[-1].total_time = max(
                merged_sequences[-1].total_time, seq.total_time)
            instrument_info = merged_sequences[-1].instrument_infos.add()
            instrument_info.instrument = instrument_id
            instrument_info.name = meta['filter_name']

        return merged_sequences
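
run() returns one merged NoteSequence per input item, with each decoded style track stored under its own instrument number. A hypothetical usage sketch (the owning class and its construction are not shown on this page, so `experiment` is a stand-in name):

merged = experiment.run(pipeline)
for i, ns in enumerate(merged):
    names = {info.instrument: info.name for info in ns.instrument_infos}
    print(f'input {i}: {len(ns.notes)} notes across {len(names)} instruments')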

Example #3

    def run(self, args):
        self.trainer.load_variables(checkpoint_file=args.checkpoint)
        data = pickle.load(args.input_file)

        def generator():
            # Pair each pickled example with the numeric ID of the requested
            # target style.
            style_id = self.style_vocabulary.to_id(args.target_style)
            for example in data:
                segment_id, notes = example
                yield self.input_encoding.encode(notes), style_id, [], []

        dataset = make_simple_dataset(generator,
                                      output_types=self.input_types,
                                      output_shapes=self.input_shapes,
                                      batch_size=args.batch_size)

        output_ids = self.model.run(self.trainer.session, dataset, args.sample,
                                    args.softmax_temperature)
        # Re-attach each decoded output to the segment ID of its source example.
        outputs = [(segment_id, self.output_encoding.decode(seq))
                   for seq, (segment_id, _) in zip(output_ids, data)]

        pickle.dump(outputs, args.output_file)
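
This method and the _run function in the next example both expect args.input_file to contain a pickled list of (segment_id, notes) pairs, matching how generator() unpacks each example. A hypothetical way to produce a compatible input file (the segment IDs are made up, and the empty note lists stand in for whatever input_encoding.encode accepts):

import pickle

examples = [('segment-0', []), ('segment-1', [])]
with open('input.pkl', 'wb') as f:
    pickle.dump(examples, f)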

Example #4

def _run(model, trainer, encoding, style_vocabulary, config, args):
    trainer.load_variables(checkpoint_file=args.checkpoint)
    data = pickle.load(args.input_file)

    def generator():
        style_id = style_vocabulary.to_id(args.target_style)
        for example in data:
            segment_id, notes = example
            # Unlike Example #3, encode without start/end tokens.
            yield encoding.encode(notes, add_start=False,
                                  add_end=False), style_id, [], []

    dataset = make_simple_dataset(generator,
                                  output_types=(tf.int32, tf.int32, tf.int32,
                                                tf.int32),
                                  output_shapes=([None], [], [None], [None]),
                                  batch_size=args.batch_size)

    output_ids = model.run(trainer.session, dataset, args.sample,
                           args.softmax_temperature)
    output = [(segment_id, encoding.decode(seq))
              for seq, (segment_id, _) in zip(output_ids, data)]

    pickle.dump(output, args.output_file)
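
All four entry points read their settings from an argparse-style namespace. The actual command-line definition is not shown on this page; a hypothetical parser that would satisfy the attributes accessed above (all flag names are guesses) could be:

import argparse

def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('action', choices=['train', 'run'])
    parser.add_argument('--logdir', required=True)
    parser.add_argument('--config', default=None)
    parser.add_argument('--checkpoint', default=None)
    parser.add_argument('--target-style')
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--sample', action='store_true')
    parser.add_argument('--softmax-temperature', type=float, default=1.)
    parser.add_argument('--seed', type=int, default=None)
    # argparse.FileType yields already-open file objects, which is why the
    # snippets can call pickle.load/pickle.dump on them directly.
    parser.add_argument('--input-file', type=argparse.FileType('rb'))
    parser.add_argument('--output-file', type=argparse.FileType('wb'))
    return parser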