def notes_to_midi_file(notes, midi_file, midi_mapping):
    SP.header('MIDI MAPPING', '%d samples', len(midi_mapping))
    SP.print('sample midi base dur vol')
    fmt = '%6d %4d %4d %3d %3.2f'
    for sample_idx, midi_def in midi_mapping.items():
        SP.print(fmt, (sample_idx,) + tuple(midi_def))
    SP.leave()

    notes_per_channel = sort_groupby(notes, lambda n: n.col_idx)
    notes_per_channel = [list(grp) for (_, grp) in notes_per_channel]
    notes_per_channel = [
        list(mod_notes_to_midi_notes(notes, midi_mapping))
        for notes in notes_per_channel]
    notes = sorted(flatten(notes_per_channel))
    SP.print('Produced %d midi notes (on/offs).' % len(notes))

    # Group by column (9 for drums)
    note_groups = groupby(notes, lambda el: el[0])
    tracks = [MidiTrack(list(midi_notes_to_track(channel, note_group)))
              for (channel, note_group) in note_groups]

    midi = MidiFile(type=1)
    midi.tracks = tracks
    midi.save(midi_file)
def run_python_file(conn, root_path, file_name, args):
    cmds = [
        f'cd "{root_path}"',
        'export PYTHONPATH="."',
        'python3 %s %s' % (file_name, ' '.join(args))
    ]
    script = ' && '.join(cmds)
    SP.print('Running %s...' % file_name)
    conn.run(script, pty=True)
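# For reference, a minimal local sketch (no SSH involved) of the one-liner
# that run_python_file hands to conn.run; the root path, file name, and
# arguments below are hypothetical.
cmds = [
    'cd "/content/musicgen"',
    'export PYTHONPATH="."',
    'python3 %s %s' % ('train-model.py', ' '.join(['lstm', '--verbose'])),
]
print(' && '.join(cmds))
# cd "/content/musicgen" && export PYTHONPATH="." && python3 train-model.py lstm --verbose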
def load_index(corpus_path):
    index_file = corpus_path / 'index'
    if not index_file.exists():
        SP.print('Empty index.')
        return {}
    with open(index_file, 'rt') as f:
        mods = [IndexedModule.from_line(line) for line in f]
    return {mod.id: mod for mod in mods}
def save_index(corpus_path, index):
    index_file = corpus_path / 'index'
    mods = sorted(index.values(), key=lambda x: x.fname)
    SP.print('Saving index with %d modules.' % len(mods))
    with open(index_file, 'wt') as f:
        for mod in mods:
            args = (mod.fname, mod.id, mod.format, mod.kb_size,
                    mod.n_channels, mod.genre, mod.year, mod.n_downloads,
                    mod.member_rating, mod.reviewer_rating)
            f.write(long_line(mod))
def column_to_mod_notes(rows, col_idx, volumes):
    tempo = DEFAULT_TEMPO
    speed = DEFAULT_SPEED
    # Last sample and period seen in this channel.
    col_period_idx = None
    col_sample_idx = None
    notes = []
    for row_idx, row in enumerate(rows):
        tempo, speed = update_timings(row, tempo, speed)
        time_ms = int(calc_row_time(tempo, speed) * 1000)
        cell = row[col_idx]

        sample_idx = cell.sample_idx
        period = cell.period

        # Neither sample nor note, skipping
        if not sample_idx and not period:
            continue
        if sample_idx:
            col_sample_idx = sample_idx
        if period:
            col_period_idx = period

        # Sample but no note, we skip those.
        if sample_idx and not period:
            continue

        # Period but no sample
        if period and not sample_idx:
            sample_idx = col_sample_idx
            if sample_idx is None:
                fmt = 'Missing sample at cell %4d:%d and ' \
                    + 'no channel sample. MOD bug?'
                SP.print(fmt % (row_idx, col_idx))
                continue
            # fmt = 'Using last sample at cell %4d:%d'
            # SP.print(fmt % (row_idx, col_idx))

        vol_idx = sample_idx - 1
        if not 0 <= vol_idx < len(volumes):
            fmt = 'Sample %d out of bounds at cell %4d:%d. MOD bug?'
            SP.print(fmt % (sample_idx, row_idx, col_idx))
            continue
        vol = mod_note_volume(volumes[vol_idx], cell)
        pitch_idx = period_to_idx(period)
        assert 0 <= pitch_idx < 60
        note = ModNote(row_idx, col_idx, sample_idx, pitch_idx, vol, time_ms)
        notes.append(note)

    # Add durations: each note lasts until the next note in the column,
    # the last one until the end of the pattern.
    for n1, n2 in zip(notes, notes[1:]):
        n1.duration = n2.row_idx - n1.row_idx
    if notes:
        notes[-1].duration = len(rows) - notes[-1].row_idx
    return notes
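# A standalone sketch of the duration pass at the end of column_to_mod_notes,
# using a toy stand-in for ModNote: each note lasts until the next note in the
# column, and the last note lasts until the end of the pattern.
from dataclasses import dataclass

@dataclass
class ToyNote:
    row_idx: int
    duration: int = 0

toy_notes = [ToyNote(0), ToyNote(4), ToyNote(10)]
n_rows = 16
for n1, n2 in zip(toy_notes, toy_notes[1:]):
    n1.duration = n2.row_idx - n1.row_idx
toy_notes[-1].duration = n_rows - toy_notes[-1].row_idx
print([n.duration for n in toy_notes])  # [4, 6, 6]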
def mod_file_to_patterns(mod_file):
    SP.print(str(mod_file))
    try:
        mod = load_file(mod_file)
    except PowerPackerModule:
        return []
    rows = linearize_rows(mod)
    volumes = [header.volume for header in mod.sample_headers]
    notes = rows_to_mod_notes(rows, volumes)
    percussive = {s for (s, p) in sample_props(mod, notes)
                  if p.is_percussive}
    return [pattern_to_matrix(pat, percussive) for pat in mod.patterns]
def to_notes(pcode, rel_pitches, row_time):
    notes = to_notes_without_tempo(pcode, rel_pitches)
    for n in notes:
        n.time_ms = row_time
    fmt = 'Rel pitches: %s, row time: %s.'
    SP.print(fmt % (rel_pitches, row_time))

    # Fix durations
    cols = sort_groupby(notes, lambda n: n.col_idx)
    for _, col in cols:
        fix_durations(list(col))
    return notes
def load_fragment(root_path, code_type, n_prompt, n_generate, ofs):
    # Works by loading pcode_abs and converting to the desired format.
    _, _, td = load_training_data('pcode_abs', root_path)
    n_frag = n_prompt + n_generate

    # Pick a song fragment
    if ofs == 'random':
        ofs = random_rel_ofs(td, n_frag)
    else:
        ofs = tuple([int(s) for s in ofs.split('-')])
    s_i, ss_i, t_i, o = ofs
    name = td.songs[s_i][0]
    song = td.songs[s_i][1][ss_i][t_i]
    assert o + n_frag <= len(song)
    frag = song[o:o + n_frag]
    SP.print('Selected %s:%d of song %s.' % (ofs, len(frag), name))

    code = td.encoder.decode_chars(frag)
    code = normalize_pitches(code)

    # Split it into prompt and remainder.
    prompt, orig = code[:n_prompt], code[n_prompt:]
    if code_type == 'rcode2':
        # This is tricky... the rcoded lengths of both prompt and orig
        # need to be divisible by 2.
        prompt = list(rcode.from_pcode(prompt))
        orig = list(rcode.from_pcode(orig))
        if len(prompt) % 2 == 1:
            # Steal one token from orig
            prompt.append(orig[0])
            orig = orig[1:]
        if len(orig) % 2 == 1:
            # Pad
            orig.append((INSN_SILENCE, 1))

    # Convert it back to the native format
    _, _, td = load_training_data(code_type, root_path)
    code_mod = CODE_MODULES[code_type]
    prompt = list(code_mod.from_pcode(prompt))
    orig = list(code_mod.from_pcode(orig))

    prompt = td.encoder.encode_chars(prompt, False)
    orig = td.encoder.encode_chars(orig, False)
    SP.print('%d prompt and %d orig tokens.' % (len(prompt), len(orig)))
    return td, ofs, prompt, orig
def notes_to_audio_file(notes, audio_file, midi_mapping, stereo):
    type = 'stereo' if stereo else 'mono'
    SP.header('%d NOTES TO %s (%s)' % (len(notes), audio_file, type))
    if audio_file.suffix == '.mid':
        notes_to_midi_file(notes, audio_file, midi_mapping)
        SP.leave()
        return
    temp_dir = mkdtemp()
    temp_dir = Path(temp_dir)
    if stereo:
        left_notes = [n for n in notes if n.col_idx in {0, 3}]
        right_notes = [n for n in notes if n.col_idx in {1, 2}]
        for notes, side in [(left_notes, 'L'), (right_notes, 'R')]:
            mid = temp_dir / (side + '.mid')
            notes_to_midi_file(notes, mid, midi_mapping)
            system('timidity %s -OwM --preserve-silence' % mid)
        SP.print('Generating stereo output using sox.')
        fmt = 'sox -M -c 1 %s -c 1 %s -C 64.0 %s'
        system(fmt % (temp_dir / 'L.wav', temp_dir / 'R.wav', audio_file))
    else:
        mid = temp_dir / 't.mid'
        notes_to_midi_file(notes, mid, midi_mapping)
        system('timidity %s -OwM --preserve-silence' % mid)
        system('sox %s %s' % (temp_dir / 't.wav', audio_file))
    SP.leave()
    rmtree(temp_dir)
def main():
    args = docopt(__doc__, version='Colab Tool 1.0')
    SP.enabled = args['--verbose']

    root_path = args.get('--root-path')
    if not root_path:
        root_path = environ['MUSICGEN_ROOT_PATH']
    root_path = Path(root_path)

    auth = args.get('--authority')
    if not auth:
        auth = environ['MUSICGEN_AUTHORITY']
    userinfo, netloc = auth.split('@')
    _, password = userinfo.split(':')
    host, port = netloc.split(':')
    port = int(port)

    connect_kwargs = {'password': password}
    SP.print('Connecting to %s' % host)
    conn = Connection(host, 'root', port, connect_kwargs=connect_kwargs)
    sftp = conn.sftp()

    SP.print('Changing to dir "%s".' % root_path)
    remote_mkdir_safe(sftp, root_path)
    sftp.chdir(str(root_path))

    if args['get-data']:
        get_data(conn, sftp)
    elif args['upload-code']:
        upload_code(conn, sftp)
    elif args['upload-caches']:
        corpus_path = Path(args['<corpus-path>'])
        upload_caches(conn, corpus_path)
    elif args['upload-file']:
        local_path = Path(args['<local-file>'])
        upload_file(conn, local_path)
    elif args['upload-and-run-file']:
        src = Path(args['<file>'])
        dst = src.parent
        if args['--drop-path']:
            dst = Path('.')
        upload_files(conn, [(src, dst)])
        if args['--drop-path']:
            src = src.name
        run_python_file(conn, root_path, str(src), args['<args>'])
    elif args['run-file']:
        run_python_file(conn, root_path, args['<file>'], args['<args>'])
    else:
        assert False
def notes_to_matrix(notes, sample_props, n_rows):
    # Assign columns for percussive instruments.
    percussion = {s: i % 3 for i, s in enumerate(sample_props)
                  if sample_props[s].is_percussive}

    pitches = {n.pitch_idx for n in notes
               if n.sample_idx not in percussion}
    if not pitches:
        SP.print('No melody.')
        return None
    min_pitch = min(pitch for pitch in pitches)
    max_pitch = max(pitch for pitch in pitches)
    pitch_range = max_pitch - min_pitch
    if pitch_range >= 36:
        SP.print('Pitch range %d too large' % pitch_range)
        return None

    def note_to_triplet(n):
        si = n.sample_idx
        if si in percussion:
            col_idx = percussion[si]
            note_dur = 4
        else:
            col_idx = 3 + n.pitch_idx - min_pitch
            assert col_idx >= 3
            sample_dur = sample_props[si].note_duration
            # Should be correct since it is set in rows_to_mod_notes
            note_dur = min(n.duration, sample_dur)
        return n.row_idx, col_idx, note_dur
    notes = sorted([note_to_triplet(n) for n in notes])

    M = np.zeros((n_rows + 4, 3 + 36))

    # Fill matrix with notes
    for (row, col, dur) in notes:
        M[row][col] = 1.0
        assert dur > 0
        for fol in range(dur - 1):
            M[row + fol + 1][col] = 0.5

    # Clip silence.
    last_nonzero_row = np.nonzero(M)[0][-1]
    return M[:last_nonzero_row + 1]
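# Illustrative sketch of the piano-roll encoding used above, on a hypothetical
# 6x4 excerpt: a note starting at row 1 in column 2 with duration 3 is stored
# as a 1.0 onset followed by 0.5 sustain cells.
import numpy as np

M = np.zeros((6, 4))
row, col, dur = 1, 2, 3
M[row][col] = 1.0
for fol in range(dur - 1):
    M[row + fol + 1][col] = 0.5
print(M[:, col])  # [0.  1.  0.5 0.5 0.  0. ]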
def main():
    global SCALE
    args = docopt(__doc__, version='GAN Model 1.0')
    SP.enabled = args['--verbose']
    kb_limit = int(args['--kb-limit'])
    corpus_path = Path(args['<corpus>'])

    dataset = load_data(corpus_path, kb_limit)
    dataset = dataset.reshape(len(dataset), 64, 4, 1)

    # Scale to [0, 1]
    SCALE = np.max(dataset)
    dataset = dataset / SCALE

    n_patterns = len(dataset)
    n_train = int(n_patterns * 0.8)
    train, test = dataset[:n_train], dataset[n_train:]
    SP.print('%d train and %d test patterns.', (len(train), len(test)))

    latent_dim = 100
    d_model = define_discriminator((64, 4, 1))
    g_model = define_generator(latent_dim)
    gan_model = define_gan(g_model, d_model)

    n_batch = 256
    n_epochs = 500
    batches_per_epoch = n_patterns // n_batch
    for i in range(n_epochs):
        if i % 50 == 0:
            summarize_performance(i, g_model, d_model, test, latent_dim)
        for j in range(batches_per_epoch):
            X, y = generate_real_and_fake_samples(train, g_model,
                                                  latent_dim, n_batch)
            d_loss, _ = d_model.train_on_batch(X, y)
            X_gan = generate_latent_points(latent_dim, n_batch)
            y_gan = np.ones((n_batch, 1))
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            if j == 0:
                fmt = '>%d, %d/%d, d=%.3f, g=%.3f'
                print(fmt % (i + 1, j + 1, batches_per_epoch,
                             d_loss, g_loss))
    summarize_performance('final', g_model, d_model, test, latent_dim)
def training_data_to_dataset(td, sl, bs):
    SP.print('Creating samples from %d songs.' % len(td.songs))
    # For some reason NumPy thinks the arrays loaded from the pickle
    # cache are views.
    windows = []
    for _, s in td.songs:
        for ss in s:
            for t in ss:
                assert t.dtype == np.uint16
                t = t.copy()
                wins = slide_window(t, sl + 1, sl, None)
                for win in wins:
                    win = win.copy()
                    windows.append(win)
    shuffle(windows)
    SP.print('Created %d sliding windows.' % len(windows))

    # Length must be a multiple of bs
    n_samples = (len(windows) // bs) * bs
    SP.print('Truncating to %d samples.' % n_samples)
    windows = windows[:n_samples]

    xs = np.array([e[:-1] for e in windows])
    ys = np.array([e[1:] for e in windows])
    return xs, ys
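# A minimal illustration of how each window of length sl + 1 becomes an
# input/target pair shifted by one token. The real code uses the project's
# slide_window helper; a plain comprehension over a toy token sequence
# stands in for it here.
import numpy as np

tokens = np.arange(10, dtype=np.uint16)
sl = 4
windows = [tokens[i:i + sl + 1] for i in range(0, len(tokens) - sl, sl)]
xs = np.array([w[:-1] for w in windows])  # inputs
ys = np.array([w[1:] for w in windows])   # targets, shifted one step
print(xs[0], ys[0])  # [0 1 2 3] [1 2 3 4]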
def print_histogram(td):
    counts = tally_tokens(td.encoder, td.songs)
    total = sum(v for (_, v) in counts)
    SP.header('%d TOKENS %d TYPES' % (total, len(counts)))
    for (cmd, arg), cnt in counts:
        SP.print('%3s %10s %10d' % (cmd, arg, cnt))
    SP.leave()
def mod_notes_to_midi_notes(notes, midi_mapping):
    offset_ms = 0
    last_row_idx = 0
    for n in notes:
        row_delta = n.row_idx - last_row_idx

        # Update time
        offset_ms += row_delta * n.time_ms
        last_row_idx = n.row_idx

        program, midi_idx_base, note_dur, vol_adj \
            = midi_mapping[n.sample_idx]

        # Note duration is the minimum of the mapped duration and the
        # note's own duration.
        note_dur = min(note_dur, n.duration)

        # -2 indicates filtered notes.
        if program == -2:
            continue

        # Note velocity
        velocity = int(min((n.vol / 64) * 127 * vol_adj, 127))

        # On and off offsets
        note_on = offset_ms
        note_off = offset_ms + note_dur * n.time_ms

        # Clamp the pitch to the valid [0, 120) range in case the
        # network generates garbage.
        midi_idx = midi_idx_base + n.pitch_idx
        if not 0 <= midi_idx < 120:
            SP.print('Fixing midi note %d.', midi_idx)
            midi_idx = min(max(midi_idx, 0), 119)

        # Drum track/melodic
        if program == -1:
            yield 9, note_on, 1, None, midi_idx_base, velocity
        else:
            yield n.col_idx, note_on, 1, program, midi_idx, velocity
            yield n.col_idx, note_off, 0, program, midi_idx, 0
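# A quick standalone check of the velocity formula above (values are made up):
# a full-volume MOD note (vol = 64) with vol_adj = 1.0 maps to velocity 127,
# and the result is always capped at 127.
def mod_vol_to_velocity(vol, vol_adj):
    return int(min((vol / 64) * 127 * vol_adj, 127))

print(mod_vol_to_velocity(64, 1.0))  # 127
print(mod_vol_to_velocity(32, 1.0))  # 63
print(mod_vol_to_velocity(64, 1.5))  # 127 (capped)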
def get_data(connection, sftp):
    paths = [Path(p) for p in sftp.listdir()]
    paths = [p for p in paths if p.suffix in ('.mid', '.png')]
    SP.header('DOWNLOADING %d FILES' % len(paths))
    for path in paths:
        SP.print(path)
        connection.get(path)
    SP.leave()
def main():
    # Prologue
    args = docopt(__doc__, version='Pickle to audio 1.0')
    SP.enabled = args['--verbose']
    files = args['<files>']
    file_paths = [Path(f) for f in files]

    # Prompt is used to estimate tempo
    n_prompt = int(args['--n-prompt'])
    format = args['--format']

    for file_path in file_paths:
        code = load_pickle(file_path)
        code_type = file_path.name.split('-')[4]
        code_mod = CODE_MODULES[code_type]
        as_pcode = list(code_mod.to_pcode(code))
        row_time = pcode.estimate_row_time(as_pcode[:n_prompt], False)
        notes = code_mod.to_notes(code, row_time)

        # Creates a simple fadeout. Not sure if it is a good feature
        # or not.
        max_row = max(n.row_idx for n in notes)
        for n in notes:
            delim = 0.9
            if n.row_idx / max_row > delim:
                over = 1 - n.row_idx / max_row
                frac = over / (1 - delim)
                n.vol = 32 + int(16 * frac)

        prefix = '.'.join(str(file_path).split('.')[:-2])
        output_name = '%s.%s' % (prefix, format)
        output_path = file_path.parent / output_name
        SP.print('Creating %s.' % output_path)
        stereo = (format == 'mp3')
        notes_to_audio_file(notes, output_path, CODE_MIDI_MAPPING, stereo)
def is_percussive(n_pitches, n_unique, n_pitch_classes,
                  max_ringout, repeat_pct, longest_rep,
                  most_common_freq):
    if n_unique <= 2 and max_ringout <= 0.15:
        return True
    # Sample is not repeating
    if repeat_pct == 1.0:
        # Always percussive if only one note is played.
        if n_unique == 1:
            return True
        if most_common_freq > 0.9 and n_unique <= 2 and max_ringout < 0.6:
            return True
    # If the same note is repeated more than 40 times, it must be
    # percussive. This is of course completely arbitrary.
    if longest_rep >= 40:  # and max_ringout < 1.0:
        return True
    # Another arbitrary one.
    if n_unique == 3 and max_ringout <= 0.11 and longest_rep >= 23:
        return True
    # This heuristic is "unsafe" but removes a lot of noise.
    if n_unique == 2 and n_pitch_classes <= 1:
        SP.print('Only one pitch class (%d pitches)' % n_pitches)
        return True
    # Sample is repeating, might still be percussive
    # (alfrdchi_endofgame1.mod)
    if repeat_pct > 0.0 and n_unique == 1:
        if longest_rep >= 100:
            return True
        if longest_rep == n_pitches:
            return True
    return False
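# A few hypothetical calls showing which branches of is_percussive fire
# (argument order: n_pitches, n_unique, n_pitch_classes, max_ringout,
# repeat_pct, longest_rep, most_common_freq). Assumes is_percussive is
# importable from this module.
assert is_percussive(12, 2, 1, 0.10, 0.5, 5, 0.6)      # short ring-out, <= 2 unique pitches
assert is_percussive(80, 4, 3, 0.90, 0.0, 45, 0.3)     # same note repeated 40+ times
assert not is_percussive(30, 8, 6, 1.50, 0.0, 3, 0.2)  # long ring-out, varied pitches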
def convert_to_midi(code_type, mod_file):
    code_mod = CODE_MODULES[code_type]
    mod = load_file(mod_file)
    subsongs = linearize_subsongs(mod, 1)
    volumes = [header.volume for header in mod.sample_headers]
    for idx, (_, rows) in enumerate(subsongs):
        notes = rows_to_mod_notes(rows, volumes)
        percussion = guess_percussive_instruments(mod, notes)
        pitches = {n.pitch_idx for n in notes
                   if n.sample_idx not in percussion}
        min_pitch = min(pitches, default=0)
        for n in notes:
            n.pitch_idx -= min_pitch
        code = list(code_mod.to_code(notes, percussion))

        fmt = '%d notes, %d rows, %d tokens, %d ms/row, percussion %s'
        args = (len(notes), len(rows), len(code),
                notes[0].time_ms if notes else -1,
                set(percussion))
        SP.print(fmt % args)

        row_time = code_mod.estimate_row_time(code)
        notes = code_mod.to_notes(code, row_time)
        fname = Path('test-%02d.mid' % idx)
        notes_to_audio_file(notes, fname, CODE_MIDI_MAPPING, False)
def mod_file_to_piano_roll(file_path):
    SP.header('PARSING %s' % str(file_path))
    try:
        mod = load_file(file_path)
    except PowerPackerModule:
        SP.print('PowerPacker module.')
        return None
    rows = linearize_rows(mod)
    volumes = [header.volume for header in mod.sample_headers]
    notes = rows_to_mod_notes(rows, volumes)
    props = sample_props(mod, notes)
    mat = notes_to_matrix(notes, props, len(rows))
    SP.leave()
    return mat
def print_encoding_errors(errors):
    errors_per_type = sort_groupby(errors, lambda x: x[2][0])
    for error_type, subsongs in errors_per_type:
        subsongs = list(subsongs)
        n_subsongs = len(subsongs)
        if error_type == ERR_DISSONANCE:
            header_part = 'WITH DISSONANCE'
        elif error_type == ERR_FEW_MEL_NOTES:
            header_part = 'WITH TOO FEW MELODIC NOTES'
        elif error_type == ERR_PARSE_ERROR:
            header_part = 'WITH PARSE ERRORS'
        elif error_type == ERR_PITCH_RANGE:
            header_part = 'WITH TOO WIDE PITCH RANGES'
        elif error_type == ERR_FEW_NOTES:
            header_part = 'WITH TOO FEW NOTES'
        elif error_type == ERR_FEW_UNIQUE_PITCHES:
            header_part = 'WITH TOO FEW UNIQUE PITCHES'
        elif error_type == ERR_EXCESSIVE_PERCUSSION:
            header_part = 'WITH EXCESSIVE PERCUSSION'
        else:
            assert False
        SP.header('%d SUBSONGS %s' % (n_subsongs, header_part))
        for name, idx, err in subsongs:
            if error_type == ERR_DISSONANCE:
                args = name, idx, err[1], err[2]
                fmt = '%-40s %3d %.2f %4d'
            elif error_type == ERR_FEW_MEL_NOTES:
                args = name, idx, err[1]
                fmt = '%-40s %3d %4d'
            elif error_type == ERR_PARSE_ERROR:
                args = name, idx, err[1]
                fmt = '%-40s %3d %s'
            elif error_type == ERR_PITCH_RANGE:
                args = name, idx, err[1]
                fmt = '%-40s %3d %2d'
            elif error_type == ERR_FEW_NOTES:
                args = name, idx, err[1]
                fmt = '%-40s %3d %4d'
            elif error_type == ERR_FEW_UNIQUE_PITCHES:
                args = name, idx, err[1]
                fmt = '%-40s %3d %4d'
            elif error_type == ERR_EXCESSIVE_PERCUSSION:
                args = name, idx, err[1], err[2]
                fmt = '%-40s %3d %4d %4d'
            else:
                assert False
            SP.print(fmt % args)
        SP.leave()
def upload_files(connection, files):
    SP.header('UPLOADING %d FILE(S)' % len(files))
    for src, dst in sorted(files):
        SP.print('%-30s => %s' % (src, dst))
        connection.put(str(src), str(dst))
    SP.leave()
def main():
    args = docopt(__doc__, version='The Mod Archive download tool 1.0')
    SP.enabled = args['--verbose']
    corpus_path = Path(args['<corpus-path>'])
    corpus_path.mkdir(parents=True, exist_ok=True)
    if args['download']:
        format = args['--format']
        kb_limit = int(args['--kb-limit'])
        download_mods(corpus_path, format, kb_limit)
    elif args['update-index']:
        genre_id = args['--genre-id']
        n_random = args['--random']
        module_id = args['--module-id']
        index = load_index(corpus_path)
        if genre_id is not None:
            SP.header('GENRE', '%d', genre_id)
            mods = modules_for_genre(genre_id)
        elif module_id is not None:
            module_id = int(module_id)
            SP.header('MODULE', '%d', module_id)
            mods = [IndexedModule.from_modarchive(module_id)]
        elif n_random is not None:
            n_random = int(n_random)
            SP.header('RANDOM', '%d', n_random)
            mods = [IndexedModule.from_modarchive_random()
                    for _ in range(n_random)]
            mods = [m for m in mods if m.id not in index]
        SP.leave()
        SP.header('%d ENTRIES' % len(mods))
        for mod in mods:
            SP.print(short_line(mod))
        SP.leave()
        for mod in mods:
            index[mod.id] = mod
        save_index(corpus_path, index)
    elif args['print-stats']:
        format = args['--format']
        kb_limit = int(args['--kb-limit'])
        index = load_index(corpus_path)
        print_stats(index.values(), format, kb_limit)
def draw_plagiarism(lo_ngram, hi_ngram, code_gen,
                    measurements, training_data):
    SP.print('Loading training data...')
    td = load_pickle(training_data)
    songs = td[1]

    # Strip names
    songs = [c for (n, c) in songs]
    SP.print('Flattening %d songs...' % len(songs))
    tokens = flatten(flatten(flatten(songs)))
    tokens = np.array(tokens, dtype=np.uint16)

    SP.print('Loading samples...')
    data = load_pickle(measurements)
    stats = data[code_gen]
    gen = stats[False]
    seqs = list(gen.values())
    seqs = [s[0] for s in seqs]

    n_samples = 1000
    plag_ratios = {}
    for ngram in range(lo_ngram, hi_ngram):
        SP.header('FINDING MATCHES FOR NGRAMS OF LENGTH %d' % ngram)
        samples = [sample_seq(seqs, ngram) for _ in range(n_samples)]
        n_matches = find_samples(tokens, samples)
        frac = n_matches / n_samples
        SP.print('%d samples matched, %.2f%%.' % (n_matches, 100 * frac))
        SP.leave()
        plag_ratios[ngram] = frac
    print(plag_ratios)
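# Illustrative sketch (not the project's find_samples) of the matching step:
# count how many sampled n-grams occur verbatim in the flattened training
# token stream. The data below is made up.
import numpy as np

def count_ngram_matches(tokens, samples):
    n = len(samples[0])
    grams = {tokens[i:i + n].tobytes() for i in range(len(tokens) - n + 1)}
    return sum(s.tobytes() in grams for s in samples)

toy_tokens = np.array([1, 2, 3, 4, 2, 3, 4, 5], dtype=np.uint16)
toy_samples = [np.array([2, 3, 4], dtype=np.uint16),
               np.array([5, 1, 2], dtype=np.uint16)]
print(count_ngram_matches(toy_tokens, toy_samples))  # 1, i.e. 50% matched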
def rebuild_fn():
    n = len(mods)
    SP.print('Shuffling %d mods.' % n)
    indices = list(range(n))
    shuffle(indices)
    return indices
def mod_file_to_codes_w_progress(i, n, file_path, code_type):
    SP.header('[ %4d / %4d ] PARSING %s' % (i, n, file_path))
    try:
        mod = load_file(file_path)
    except UnsupportedModule as e:
        SP.print('Unsupported module format.')
        SP.leave()
        err_arg = e.args[0] if e.args else e.__class__.__name__
        yield False, 0, (ERR_PARSE_ERROR, err_arg)
        return
    code_mod = CODE_MODULES[code_type]
    subsongs = list(linearize_subsongs(mod, 1))
    volumes = [header.volume for header in mod.sample_headers]
    parsed_subsongs = []
    for idx, (order, rows) in enumerate(subsongs):
        SP.header('SUBSONG %d' % idx)
        notes = rows_to_mod_notes(rows, volumes)
        percussion = guess_percussive_instruments(mod, notes)
        if notes:
            fmt = '%d rows, %d ms/row, percussion %s, %d notes'
            args = (len(rows), notes[0].time_ms,
                    set(percussion), len(notes))
            SP.print(fmt % args)

        err = training_error(notes, percussion)
        if err:
            yield False, idx, err
        else:
            pitches = {n.pitch_idx for n in notes
                       if n.sample_idx not in percussion}
            min_pitch = min(pitches, default=0)

            # Subtract min pitch
            for n in notes:
                n.pitch_idx -= min_pitch
            code = list(code_mod.to_code(notes, percussion))
            if code_mod.is_transposable():
                codes = code_mod.code_transpositions(code)
            else:
                codes = [code]
            fmt = '%d transpositions of length %d'
            SP.print(fmt % (len(codes), len(code)))
            yield True, idx, codes
        SP.leave()
    SP.leave()
def guess_time_ms(mat):
    # Rows that contain at least one note onset (cells set to 1.0).
    mat2 = mat[np.count_nonzero(mat == 1.0, axis=1) > 0]
    zero_ratio = len(mat2) / len(mat)
    row_time = int(160 * zero_ratio)
    SP.print('Guessed row time %d ms.' % row_time)
    return row_time
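# Worked example with a hypothetical 4-row excerpt: onsets (1.0) on two of the
# four rows give a ratio of 0.5, so the guessed row time is 160 * 0.5 = 80 ms.
import numpy as np

mat = np.array([[1.0, 0.0],
                [0.5, 0.0],
                [0.0, 1.0],
                [0.0, 0.5]])
onset_rows = mat[np.count_nonzero(mat == 1.0, axis=1) > 0]
print(int(160 * len(onset_rows) / len(mat)))  # 80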
def main():
    # Prologue
    args = docopt(__doc__, version='Bulk generator 1.0')
    SP.enabled = args['--verbose']
    root_path = Path(args['<root-path>'])

    # Kind of code
    gen_name = args['<generator>']
    g = get_code_generator(gen_name)

    # Parse generation schedule
    n_prompt = int(args['--n-prompt'])
    n_generate = int(args['--n-generate'])
    n_clips = int(args['--n-clips'])
    if n_generate % 2 == 1 or n_prompt % 2 == 1:
        raise ValueError('The number of tokens in the prompt and '
                         'the number of tokens to generate must '
                         'be divisible by two.')

    # Load the generation schedule
    perp_path = root_path / 'perplexity'
    perp_path.mkdir(exist_ok=True)

    # Load the cached code
    _, td, _ = load_training_data(g['code-type'], root_path)
    if td.code_type == 'dcode':
        SP.print('Code type is dcode so halving generation sizes.')
        n_generate //= 2
        n_prompt //= 2
    n_frag = n_prompt + n_generate

    # We save the random indexes in a file so that the same bulk job
    # can be repeated using other code generators.
    schedule_name = 'schedule-%04d.pickle.gz' % n_clips
    schedule_path = perp_path / schedule_name
    def pickle_cache_fun():
        return [random_rel_ofs(td, n_frag) for _ in range(n_clips)]
    offsets = load_pickle_cache(schedule_path, pickle_cache_fun)

    # Save all measurements in a datafile
    measurements_name = 'measurements-%04d.pickle.gz' % n_clips
    measurements_path = perp_path / measurements_name
    measurements = load_pickle_cache(measurements_path, lambda: {})
    if gen_name not in measurements:
        measurements[gen_name] = {}
    use_original = args['--use-original']
    if use_original not in measurements[gen_name]:
        measurements[gen_name][use_original] = {}
    offsets = [o for o in offsets
               if o not in measurements[gen_name][use_original]]
    SP.print('Measuring %d offsets...' % len(offsets))

    # Split the load into chunks of 16. Too many sequences at once
    # either exhaust the memory or time out Google Colab.
    job_size = 16
    for i in range(0, len(offsets), job_size):
        job = offsets[i:i + job_size]
        log_probs = bulk_measure(g, root_path, job, td,
                                 n_prompt, n_generate, use_original)
        for offset, log_prob in zip(job, log_probs):
            logppl = -log_prob / n_generate
            measurements[gen_name][use_original][offset] = logppl
        save_pickle(measurements_path, measurements)

    for gen_name, cats in measurements.items():
        SP.header(gen_name)
        for use_original, offsets in cats.items():
            SP.header('%s' % use_original)
            for ofs, ppl in offsets.items():
                SP.print('%s %.3f' % (ofs, ppl))
            SP.leave()
        SP.leave()
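# The per-clip measurement above is the mean negative log-likelihood of the
# generated tokens. With hypothetical numbers: a summed log-probability of
# -12.0 over 8 generated tokens gives logppl = 1.5, which corresponds to a
# perplexity of exp(1.5), roughly 4.48.
import math

log_prob = -12.0
n_generate = 8
logppl = -log_prob / n_generate
print(logppl, math.exp(logppl))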
def main():
    # Prologue
    args = docopt(__doc__, version='Train MOD model 1.0')
    SP.enabled = args['--verbose']
    root_path = Path(args['<root-path>'])

    # Kind of code
    g = get_code_generator(args['<model>'])

    # Load training data
    train, valid, test = load_training_data(g['code-type'], root_path)
    vocab_size = len(train.encoder.ix2ch)
    args = len(train.songs), len(valid.songs), len(test.songs)
    train_x, train_y = training_data_to_dataset(train, g['sequence-length'],
                                                g['batch-size'])
    valid_x, valid_y = training_data_to_dataset(valid, g['sequence-length'],
                                                g['batch-size'])

    # Load the training model
    model = load_training_model(g, vocab_size)

    weights_dir = root_path / 'weights'
    weights_dir.mkdir(exist_ok=True)
    weights_path = weights_dir / weights_file(g)
    if weights_path.exists():
        SP.print('Loading weights from %s.' % weights_path)
        model.load_weights(str(weights_path))
    else:
        SP.print('Weights file not found.')
    model.reset_states()
    model.summary()

    # Logging
    log_path = weights_dir / log_file(g)
    def log_losses(epoch, logs):
        with open(log_path, 'at') as outf:
            outf.write('%d %.5f %.5f\n'
                       % (epoch, logs['loss'], logs['val_loss']))
    cb_epoch_end = LambdaCallback(on_epoch_end=log_losses)

    cb_best = ModelCheckpoint(str(weights_path),
                              monitor='val_loss',
                              verbose=1,
                              save_weights_only=True,
                              save_best_only=True,
                              mode='min')
    stopping = EarlyStopping(patience=30, verbose=1)
    reduce_lr = ReduceLROnPlateau(factor=0.2, patience=8,
                                  min_lr=g['learning-rate'] / 100,
                                  verbose=1)
    callbacks = [reduce_lr, cb_best, stopping, cb_epoch_end]
    SP.print('Batching samples...')
    model.fit(x=train_x, y=train_y,
              batch_size=g['batch-size'],
              epochs=200,
              callbacks=callbacks,
              verbose=1,
              validation_data=(valid_x, valid_y))
def print_hyper_params():
    SP.header('INITIALIZING GPT2 MODEL')
    SP.print('Vocab size : %5d' % VOCAB_SIZE)
    SP.print('Hidden size: %5d' % HIDDEN_SIZE)
    SP.print('N layers   : %5d' % N_LAYER)
    SP.print('N heads    : %5d' % N_HEAD)
    SP.leave()