def init_history(ctx: Optional[dict]):
    history = {
        'train': {'epoch': [], 'err': []},
        'val': {'epoch': [], 'err': [], 'sdr': [], 'sir': [], 'sar': []}}
    from_epoch = 0
    if ctx:
        continue_training = get_ctx(ctx, 'continue_training')
        if continue_training == 'latest' or isinstance(continue_training, int):
            suffix_latest = 'latest.pth'
            from_epoch = torch.load('{}/epoch_{}'.format(
                get_ctx(ctx, 'path'), suffix_latest)) + 1
            history = torch.load('{}/history_{}'.format(
                get_ctx(ctx, 'path'), suffix_latest))
            if isinstance(continue_training, int):
                # resume from an explicit epoch: drop any later history entries
                from_epoch = continue_training
                for k in history:
                    for k1 in history[k]:
                        history[k][k1] = history[k][k1][:from_epoch]
            # replay the lr schedule up to the resumed epoch
            for step in get_ctx(ctx, 'lr_steps'):
                if step < from_epoch:
                    adjust_learning_rate(ctx)
        elif continue_training != '':
            raise ValueError(
                f'invalid value in continue_training: {continue_training}')
    return history, from_epoch
def detach_mask(ctx, mask, binary):
    N = get_ctx(ctx, 'num_mix')
    for n in range(N):
        mask[n] = mask[n].detach().cpu().numpy()
        if binary:
            mask[n] = (mask[n] > get_ctx(ctx, 'mask_thres')).astype(np.float32)
    return mask
def save_nets(ctx, suffix):
    path = get_ctx(ctx, 'path')
    nets = get_underlying_nets(get_ctx(ctx, 'net_wrapper'))
    (net_sound, net_frame, net_synthesizer) = nets
    torch.save(net_sound.state_dict(), os.path.join(path, f'sound_{suffix}'))
    torch.save(net_frame.state_dict(), os.path.join(path, f'frame_{suffix}'))
    torch.save(net_synthesizer.state_dict(),
               os.path.join(path, f'synthesizer_{suffix}'))
def test(ctx: dict):
    ctx['load_best_model'] = True
    ctx['net_wrapper'] = build_model(ctx)
    ctx['num_mix'] = 1
    dataset = MUSICMixDataset(get_ctx(ctx, 'list_test'),
                              ctx,
                              max_sample=get_ctx(ctx, 'num_test'),
                              split='test')
    ctx['loader_test'] = torch.utils.data.DataLoader(
        dataset,
        batch_size=get_ctx(ctx, 'batch_size'),
        shuffle=False,
        num_workers=2,
        drop_last=False)
    with torch.set_grad_enabled(False):
        _test(ctx)
def unwarp_log_scale(ctx, arr):
    N = get_ctx(ctx, 'num_mix')
    B = arr[0].size(0)
    linear = [None for _ in range(N)]
    if get_ctx(ctx, 'log_freq'):
        # the unwarp grid is identical for every mixture component,
        # so build it once
        w = warpgrid(B, get_ctx(ctx, 'stft_frame') // 2 + 1, arr[0].size(3),
                     warp=False)
        grid_unwarp = torch.from_numpy(w).to(get_ctx(ctx, 'device'))
    for n in range(N):
        if get_ctx(ctx, 'log_freq'):
            linear[n] = F.grid_sample(arr[n], grid_unwarp, align_corners=True)
        else:
            linear[n] = arr[n]
    return linear
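
# Sketch (illustration, not project code): warpgrid() is the project helper
# that builds the log-frequency (un)warping grid; the resampling itself is
# plain F.grid_sample. With an identity grid the input is reproduced exactly:
def _grid_sample_sketch():
    import torch
    import torch.nn.functional as F

    x = torch.arange(16.).view(1, 1, 4, 4)               # NxCxHxW
    theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])  # identity affine
    grid = F.affine_grid(theta, x.size(), align_corners=True)
    y = F.grid_sample(x, grid, align_corners=True)
    assert torch.allclose(x, y)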
def predict_bboxes(img):
    t = get_transform(False)
    inp = t(img)
    inp = inp[None].to(get_ctx(ctx, 'device'))
    with torch.no_grad():
        out = get_ctx(ctx, 'detector')(inp)[0]
    # nms returns the indices of the kept boxes, not the boxes themselves
    keep = torchvision.ops.nms(out['boxes'], out['scores'], 0.4)
    filtered_boxes = []
    for i in keep:
        box = out['boxes'][i]
        label = out['labels'][i]
        score = out['scores'][i]
        # label 1 is 'person' in the COCO-trained detector
        if label == 1 and score > 0.7:
            filtered_boxes.append(box)
    return filtered_boxes
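
# Sketch (illustration, not project code): torchvision.ops.nms returns the
# *indices* of the kept boxes sorted by score, which is why predict_bboxes
# indexes back into out['boxes'].
def _nms_sketch():
    import torch
    import torchvision

    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],    # IoU ~0.68 with box 0
                          [20., 20., 30., 30.]])
    scores = torch.tensor([0.9, 0.8, 0.7])
    keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.4)
    assert keep.tolist() == [0, 2]  # box 1 suppressed by box 0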
def evaluate(ctx: dict):
    ctx['load_best_model'] = True
    ctx['net_wrapper'] = build_model(ctx)
    dataset_val = MUSICMixDataset(get_ctx(ctx, 'list_val'),
                                  ctx,
                                  max_sample=get_ctx(ctx, 'num_val'),
                                  split='val')
    ctx['loader_val'] = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=get_ctx(ctx, 'batch_size'),
        shuffle=False,
        num_workers=2,
        drop_last=False)
    ctx['history'], _ = init_history(None)
    ctx['epoch'] = 0
    with torch.set_grad_enabled(False):
        _evaluate(ctx)
def checkpoint(ctx: dict):
    epoch = get_ctx(ctx, 'epoch')
    history = get_ctx(ctx, 'history')
    path = get_ctx(ctx, 'path')
    print('Saving checkpoints at {} epochs.'.format(epoch))
    suffix_latest = 'latest.pth'
    suffix_best = 'best.pth'

    torch.save(epoch, os.path.join(path, f'epoch_{suffix_latest}'))
    torch.save(history, os.path.join(path, f'history_{suffix_latest}'))
    save_nets(ctx, suffix_latest)

    cur_metrics = (history['val']['sdr'][-1] + history['val']['sir'][-1] +
                   history['val']['sar'][-1]) / 3
    if cur_metrics > get_ctx(ctx, 'best_metrics') \
            and epoch % get_ctx(ctx, 'eval_epoch') == 0:
        print(f'Best model, epoch = {epoch}, mean metrics = {cur_metrics}, '
              f'prev best = {get_ctx(ctx, "best_metrics")}')
        ctx['best_metrics'] = cur_metrics
        save_nets(ctx, suffix_best)
    if get_ctx(ctx, 'checkpoint_epoch') is not None \
            and epoch % get_ctx(ctx, 'checkpoint_epoch') == 0:
        save_nets(ctx, f'{epoch}.pth')
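
# Sketch (illustration, not project code): torch.save/torch.load pickle
# arbitrary Python objects, so the epoch counter and history dict above
# round-trip the same way the state_dicts do.
def _checkpoint_sketch():
    import os
    import tempfile
    import torch

    with tempfile.TemporaryDirectory() as d:
        torch.save(7, os.path.join(d, 'epoch_latest.pth'))
        torch.save({'train': {'err': [0.5]}},
                   os.path.join(d, 'history_latest.pth'))
        assert torch.load(os.path.join(d, 'epoch_latest.pth')) == 7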
def _forward_pixelwise(self, batch_data, ctx):
    mag_mix = batch_data['mag_mix']
    frames = batch_data['frames']
    mag_mix = mag_mix + 1e-10
    bs = mag_mix.size(0)
    T = mag_mix.size(3)

    # 0.0 warp the spectrogram
    if get_ctx(ctx, 'log_freq'):
        grid_warp = torch.from_numpy(warpgrid(bs, 256, T, warp=True)).to(
            get_ctx(ctx, 'device'))
        mag_mix = F.grid_sample(mag_mix, grid_warp, align_corners=True)

    # LOG magnitude
    log_mag_mix = torch.log(mag_mix).detach()

    # 1. forward net_sound -> BxCxHxW
    feat_sound = self.net_sound(log_mag_mix)
    feat_sound = activate(feat_sound, get_ctx(ctx, 'sound_activation'))

    # 2. forward net_frame -> BxCxHxW (pooled over time, spatial grid kept)
    frames = frames[0]  # num_mix == 1
    feat_frames = self.net_frame.forward_multiframe(frames, pool=False)
    (B, C, T, H, W) = feat_frames.size()
    feat_frames = feat_frames.permute(0, 1, 3, 4, 2)
    feat_frames = feat_frames.reshape(B * C, H * W, T)
    feat_frames = F.adaptive_avg_pool1d(feat_frames, 1)  # average over T
    feat_frames = feat_frames.view(B, C, H, W)
    feat_frames = activate(feat_frames, get_ctx(ctx, 'img_activation'))
    channels = feat_frames.detach().cpu().numpy()

    # 3. sound synthesizer: one mask per spatial cell of the frame features
    pred_masks = self.net_synthesizer.forward_pixelwise(feat_frames,
                                                        feat_sound)
    pred_masks = activate(pred_masks, get_ctx(ctx, 'output_activation'))

    return {
        'pred_masks': pred_masks,
        'processed_mag_mix': mag_mix,
        'feat_frames_channels': channels
    }
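
# Sketch (illustration, not project code): the permute/reshape/pool dance in
# _forward_pixelwise is just an average over the time axis that keeps the
# spatial grid intact.
def _temporal_pool_sketch():
    import torch
    import torch.nn.functional as F

    B, C, T, H, W = 2, 8, 3, 14, 14
    feat = torch.randn(B, C, T, H, W)
    x = feat.permute(0, 1, 3, 4, 2).reshape(B * C, H * W, T)  # time last
    x = F.adaptive_avg_pool1d(x, 1).view(B, C, H, W)
    assert torch.allclose(x, feat.mean(dim=2), atol=1e-6)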
def regions(ctx: dict):
    ctx['load_best_model'] = True
    ctx['net_wrapper'] = build_model(ctx)
    ctx['num_mix'] = 1
    dataset = MUSICMixDataset(get_ctx(ctx, 'list_regions'),
                              ctx,
                              max_sample=get_ctx(ctx, 'num_regions'),
                              split='regions')
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=get_ctx(ctx, 'batch_size'),
                                         shuffle=True,
                                         num_workers=1,
                                         drop_last=False)
    makedirs(get_ctx(ctx, 'vis_regions'), remove=True)
    cnt = 0
    with torch.no_grad():
        for data in tqdm(loader):
            output = ctx['net_wrapper'].forward(data, ctx, pixelwise=True)
            output_predictions(ctx, data, output)
            cnt += len(data['audios'][0])
def create_optimizer(nets, ctx):
    (net_sound, net_frame, net_synthesizer) = nets
    param_groups = [
        {'params': net_sound.parameters(), 'lr': get_ctx(ctx, 'lr_sound')},
        {'params': net_synthesizer.parameters(),
         'lr': get_ctx(ctx, 'lr_synthesizer')},
        {'params': net_frame.features.parameters(),
         'lr': get_ctx(ctx, 'lr_frame')},
        {'params': net_frame.fc.parameters(), 'lr': get_ctx(ctx, 'lr_frame')}]
    # note: the config's 'beta1' is reused as SGD momentum here
    return torch.optim.SGD(param_groups,
                           momentum=get_ctx(ctx, 'beta1'),
                           weight_decay=get_ctx(ctx, 'weight_decay'))
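
# Sketch (illustration, not project code): torch.optim.SGD accepts a list of
# parameter groups, each with its own learning rate, while momentum and
# weight_decay apply to every group.
def _param_groups_sketch():
    import torch

    backbone = torch.nn.Linear(4, 2)
    head = torch.nn.Linear(2, 1)
    opt = torch.optim.SGD(
        [{'params': backbone.parameters(), 'lr': 1e-4},
         {'params': head.parameters(), 'lr': 1e-3}],
        momentum=0.9, weight_decay=1e-4)
    assert [g['lr'] for g in opt.param_groups] == [1e-4, 1e-3]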
def lr_range_test_for_part(ctx: dict, groups: List[int]):
    ctx['net_wrapper'] = build_model(ctx)
    ctx['optimizer'] = create_optimizer(
        get_underlying_nets(get_ctx(ctx, 'net_wrapper')), ctx)
    dataset_train = MUSICMixDataset(get_ctx(ctx, 'list_train'), ctx,
                                    split='train')
    ctx['loader_train'] = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=get_ctx(ctx, 'batch_size'),
        shuffle=True,
        num_workers=int(get_ctx(ctx, 'workers')),
        drop_last=True)
    rates = []
    losses = []
    net_wrapper = get_ctx(ctx, 'net_wrapper')
    optimizer = get_ctx(ctx, 'optimizer')
    loader = get_ctx(ctx, 'loader_train')
    total = 1000
    min_lr = 1e-10
    max_lr = 1e1
    smooth_f = 0.05
    for step, batch_data in tqdm(enumerate(loader), total=total):
        if step == total:
            break
        # sweep the learning rate geometrically from min_lr to max_lr
        it = step / total
        lr = np.exp((1 - it) * np.log(min_lr) + it * np.log(max_lr))
        set_lr(optimizer, lr, groups)
        net_wrapper.zero_grad()
        loss, _ = net_wrapper.forward(batch_data, ctx)
        loss = loss.mean()
        # backward
        loss.backward()
        optimizer.step()
        # exponential smoothing of the recorded loss curve
        if step > 0:
            loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
        rates.append(lr)
        losses.append(loss.item())
    return np.array(rates), np.array(losses)
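
# Sketch (illustration, not project code): the lr sweep above interpolates
# log-linearly, i.e. geometrically, between min_lr and max_lr.
def _lr_sweep_sketch():
    import numpy as np

    min_lr, max_lr, total = 1e-10, 1e1, 1000
    it = np.arange(total) / total
    lrs = np.exp((1 - it) * np.log(min_lr) + it * np.log(max_lr))
    assert np.isclose(lrs[0], min_lr)
    assert np.all(np.diff(np.log(lrs)) > 0)  # strictly increasing in log-space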
def train_epoch(ctx: dict):
    net_wrapper = get_ctx(ctx, 'net_wrapper')
    optimizer = get_ctx(ctx, 'optimizer')
    loader = get_ctx(ctx, 'loader_train')
    history = get_ctx(ctx, 'history')
    epoch = get_ctx(ctx, 'epoch')
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # switch to train mode
    net_wrapper.train()

    # main loop
    synchronize(ctx)
    tic = time.perf_counter()
    for i, batch_data in enumerate(loader):
        # measure data time
        synchronize(ctx)
        data_time.update(time.perf_counter() - tic)

        # forward pass
        net_wrapper.zero_grad()
        err, _ = net_wrapper.forward(batch_data, ctx)
        err = err.mean()

        # backward
        err.backward()
        optimizer.step()

        # measure total time
        synchronize(ctx)
        batch_time.update(time.perf_counter() - tic)
        tic = time.perf_counter()

        # display
        if i % get_ctx(ctx, 'disp_iter') == 0:
            print(
                f'{get_timestr()} Epoch: [{epoch}][{i}/{get_ctx(ctx, "epoch_iters")}],'
                f' Time: {batch_time.average():.2f}, Data: {data_time.average():.2f}, '
                f'lr_sound: {get_ctx(ctx, "lr_sound")}, lr_frame: {get_ctx(ctx, "lr_frame")}, '
                f'lr_synthesizer: {get_ctx(ctx, "lr_synthesizer")}, '
                f'loss: {err.item():.4f}')
            fractional_epoch = epoch - 1 + 1. * i / get_ctx(ctx, 'epoch_iters')
            history['train']['epoch'].append(fractional_epoch)
            history['train']['err'].append(err.item())
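
# Sketch of the assumed AverageMeter interface (the project's own class is
# defined elsewhere; this is just the contract train_epoch relies on):
class _AverageMeterSketch:
    """Running mean over update() calls."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value):
        self.sum += value
        self.count += 1

    def average(self):
        return self.sum / max(self.count, 1)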
def _test(ctx: dict):
    makedirs(get_ctx(ctx, 'vis_test'), remove=True)
    net_wrapper = get_ctx(ctx, 'net_wrapper')
    net_wrapper.eval()
    loader = get_ctx(ctx, 'loader_test')

    # initialize HTML header
    visualizer = HTMLVisualizer(os.path.join(get_ctx(ctx, 'vis_test'),
                                             'index.html'))
    header = ['Filename', 'Input Audio']
    for n in range(1, get_ctx(ctx, 'num_mix') + 1):
        header += [f'Predicted Audio {n}', f'Predicted Mask {n}']
    visualizer.add_header(header)
    vis_rows = []
    for i, batch_data in tqdm(enumerate(loader)):
        _, outputs = net_wrapper.forward(batch_data, ctx)
        if len(vis_rows) < get_ctx(ctx, 'num_vis'):
            output_visuals(vis_rows, batch_data, outputs, ctx)
    visualizer.add_rows(vis_rows)
    visualizer.write_html()
def to_device(ctx, data):
    # with multiple GPUs, DataParallel scatters the batch itself
    if get_ctx(ctx, 'device').type != 'cpu' and len(get_ctx(ctx, 'gpu')) == 1:
        for k in data:
            if isinstance(data[k], torch.Tensor):
                data[k] = data[k].to(get_ctx(ctx, 'device'))
def build_model(ctx: dict):
    if get_ctx(ctx, 'load_best_model'):
        weights_sound = get_ctx(ctx, 'weights_sound_best')
        weights_frame = get_ctx(ctx, 'weights_frame_best')
        weights_synthesizer = get_ctx(ctx, 'weights_synthesizer_best')
    elif get_ctx(ctx, 'continue_training') == 'latest':
        weights_sound = get_ctx(ctx, 'weights_sound_latest')
        weights_frame = get_ctx(ctx, 'weights_frame_latest')
        weights_synthesizer = get_ctx(ctx, 'weights_synthesizer_latest')
    elif isinstance(get_ctx(ctx, 'continue_training'), int):
        ep = get_ctx(ctx, 'continue_training')
        weights_sound = get_ctx(ctx, f'weights_sound_{ep}')
        weights_frame = get_ctx(ctx, f'weights_frame_{ep}')
        weights_synthesizer = get_ctx(ctx, f'weights_synthesizer_{ep}')
    elif get_ctx(ctx, 'finetune'):
        weights_sound = get_ctx(ctx, 'weights_sound_finetune')
        weights_frame = get_ctx(ctx, 'weights_frame_finetune')
        weights_synthesizer = get_ctx(ctx, 'weights_synthesizer_finetune')
    else:
        weights_sound, weights_frame, weights_synthesizer = '', '', ''

    builder = ModelBuilder()
    net_sound = builder.build_sound(arch=get_ctx(ctx, 'arch_sound'),
                                    fc_dim=get_ctx(ctx, 'num_channels'),
                                    weights=weights_sound)
    net_frame = builder.build_frame(arch=get_ctx(ctx, 'arch_frame'),
                                    fc_dim=get_ctx(ctx, 'num_channels'),
                                    pool_type=get_ctx(ctx, 'img_pool'),
                                    weights=weights_frame)
    net_synthesizer = builder.build_synthesizer(
        arch=get_ctx(ctx, 'arch_synthesizer'),
        fc_dim=get_ctx(ctx, 'num_channels'),
        weights=weights_synthesizer)

    if get_ctx(ctx, 'finetune'):
        # freeze everything except the synthesizer
        for param in net_sound.parameters():
            param.requires_grad = False
        for param in net_frame.parameters():
            param.requires_grad = False

    nets = (net_sound, net_frame, net_synthesizer)
    crit = builder.build_criterion(arch=get_ctx(ctx, 'loss'))
    net_wrapper = NetWrapper(nets, crit)
    if get_ctx(ctx, 'device').type != 'cpu':
        net_wrapper = nn.DataParallel(net_wrapper,
                                      device_ids=get_ctx(ctx, 'gpu'))
        # net_wrapper.to(get_ctx(ctx, 'device'))
    return net_wrapper
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', help='train/eval/regions/lr_range_test')
    parser.add_argument('--config', default='',
                        help='Configuration in JSON file')
    args = parser.parse_args()

    config = read_config('configs/default.json')
    if args.config != '':
        config_override = read_config(args.config)
        config = {**config, **config_override}
    ctx = create_context(config)

    random.seed(get_ctx(ctx, 'seed'))
    np.random.seed(get_ctx(ctx, 'seed'))
    torch.manual_seed(get_ctx(ctx, 'seed'))

    if args.mode == 'train':
        train(ctx)
    elif args.mode == 'eval':
        evaluate(ctx)
    # elif args.mode == 'test':
    #     test(ctx)
    elif args.mode == 'regions':
        regions(ctx)
    elif args.mode == 'lr_range_test':
        lr_range_test(ctx)
    else:
        raise RuntimeError(f'unknown mode: {args.mode}')
def synchronize(ctx: dict):
    if get_ctx(ctx, 'device').type != 'cpu':
        torch.cuda.synchronize()
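
# Sketch (illustration, not project code): CUDA kernels launch
# asynchronously, so timings bracketed by synchronize() measure the GPU work
# itself rather than just the kernel launch.
def _timing_sketch():
    import time
    import torch

    if not torch.cuda.is_available():
        return
    x = torch.randn(4096, 4096, device='cuda')
    torch.cuda.synchronize()
    tic = time.perf_counter()
    _ = x @ x
    torch.cuda.synchronize()  # wait for the matmul to actually finish
    print(f'matmul took {time.perf_counter() - tic:.4f}s')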
def main():
    ctx['temp_dir'] = tempfile.mkdtemp()
    print(get_ctx(ctx, 'temp_dir'))
    sound_len = int(
        math.ceil(get_ctx(ctx, 'aud_len') / get_ctx(ctx, 'aud_rate')))
    print(sound_len)

    # extract frames and audio around the requested timestamp
    audio_fname = os.path.join(get_ctx(ctx, 'temp_dir'), 'audio.mp3')
    frame_path = os.path.join(get_ctx(ctx, 'temp_dir'), 'frame')
    os.mkdir(frame_path)
    frame_low_path = os.path.join(get_ctx(ctx, 'temp_dir'), 'frame_low')
    os.mkdir(frame_low_path)
    extract_frames_ffmpeg(get_ctx(ctx, 'video_path'), frame_path, BASE_FPS,
                          get_ctx(ctx, 'time'), sound_len)
    extract_frames_ffmpeg(get_ctx(ctx, 'video_path'), frame_low_path,
                          get_ctx(ctx, 'frame_rate') + 1,
                          get_ctx(ctx, 'time'), sound_len)
    extract_audio_ffmpeg(get_ctx(ctx, 'video_path'), audio_fname,
                         get_ctx(ctx, 'time'), sound_len)

    img = get_img_for_detection(frame_low_path)
    img_width, img_height = img.size
    bboxes = predict_bboxes(img)

    data, out = predict_masks(audio_fname, frame_low_path)
    pred_masks = out['pred_masks']
    _, grid_width, grid_height, mask_width, mask_height = pred_masks.size()

    # unwarp each grid cell's mask back to the linear frequency scale
    pred_masks_linear = torch.zeros(
        (grid_width, 1, grid_height, 512, 256)).to(get_ctx(ctx, 'device'))
    for h in range(grid_width):
        pred_masks_linear_h = unwarp_log_scale(ctx,
                                               [pred_masks[:, h, :, :, :]])
        pred_masks_linear[h] = pred_masks_linear_h[0]
    pred_masks_linear = pred_masks_linear.permute(1, 0, 2, 3, 4)
    pred_masks_linear = detach_mask(ctx, [pred_masks_linear],
                                    get_ctx(ctx, 'binary_mask'))[0]

    frame_out_path = os.path.join(get_ctx(ctx, 'temp_dir'), 'frame_out')
    shutil.copytree(frame_path, frame_out_path)

    # reconstruct the mixture waveform
    mag_mix = data['mag_mix'].numpy()
    phase_mix = data['phase_mix'].numpy()
    mix_wav = istft_reconstruction(mag_mix[0, 0], phase_mix[0, 0],
                                   hop_length=get_ctx(ctx, 'stft_hop'))
    mix_wav = resize_to_aud_len(mix_wav)
    wavfile.write(os.path.join(get_ctx(ctx, 'temp_dir'), 'mix.mp3'),
                  get_ctx(ctx, 'aud_rate'), mix_wav)

    mask_boxes = get_mask_boxes(img, grid_width, grid_height)
    cell_area = area(0, 0, img_width // grid_width, img_height // grid_height)
    for box in bboxes:
        # average the grid masks overlapping the detected box, apply the
        # result to the mixture, and append the separated audio and
        # annotated frames
        number = last_frame_number(frame_out_path)
        box_np = box.cpu().numpy()
        avg_mask = get_average_mask(box_np, mask_boxes, pred_masks_linear[0],
                                    cell_area, 512, 256, grid_width,
                                    grid_height)
        if get_ctx(ctx, 'binary_mask'):
            avg_mask = (avg_mask > get_ctx(ctx, 'mask_thres')).astype(
                np.float32)
        pred_mag = mag_mix[0, 0] * avg_mask
        preds_wav = istft_reconstruction(pred_mag, phase_mix[0, 0],
                                         hop_length=get_ctx(ctx, 'stft_hop'))
        preds_wav = resize_to_aud_len(preds_wav)
        wavfile.write(
            os.path.join(get_ctx(ctx, 'temp_dir'), f'{number:06d}.mp3'),
            get_ctx(ctx, 'aud_rate'), preds_wav)
        mix_wav = np.concatenate([mix_wav, preds_wav])
        for frame_name in sorted(os.listdir(frame_path)):
            with Image.open(os.path.join(frame_out_path, frame_name)) as im:
                draw = ImageDraw.Draw(im)
                draw.rectangle(box.tolist(), outline='red')
                number += 1
                im.save(os.path.join(frame_out_path, f'{number:06d}.jpg'))

    # assemble the final video: original mix followed by each separation
    all_frames = np.array([
        np.array(Image.open(os.path.join(frame_out_path, frame_name)))
        for frame_name in sorted(os.listdir(frame_out_path))
    ])
    wavfile.write(os.path.join(get_ctx(ctx, 'temp_dir'), 'full.mp3'),
                  get_ctx(ctx, 'aud_rate'), mix_wav)
    save_video(os.path.join(get_ctx(ctx, 'temp_dir'), 'full.mp4'), all_frames,
               BASE_FPS)
    combine_video_audio(os.path.join(get_ctx(ctx, 'temp_dir'), 'full.mp4'),
                        os.path.join(get_ctx(ctx, 'temp_dir'), 'full.mp3'),
                        os.path.join(get_ctx(ctx, 'temp_dir'), 'result.mp4'))
    shutil.copy(os.path.join(get_ctx(ctx, 'temp_dir'), 'result.mp4'),
                get_ctx(ctx, 'output'))
    shutil.rmtree(ctx['temp_dir'])
def resize_to_aud_len(wav):
    # zero-pad or trim so the waveform is exactly aud_len samples long
    if get_ctx(ctx, 'aud_len') > wav.shape[0]:
        return np.pad(wav, (0, get_ctx(ctx, 'aud_len') - wav.shape[0]))
    elif get_ctx(ctx, 'aud_len') < wav.shape[0]:
        return wav[:get_ctx(ctx, 'aud_len')]
    return wav
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('video', type=str)
    parser.add_argument('--time', type=int, required=True)
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--output', type=str, default='result.mp4')
    args = parser.parse_args()

    config = read_config('configs/default.json')
    if args.config != '':
        config_override = read_config(args.config)
        config = {**config, **config_override}
    ctx = create_context(config)
    ctx['video_path'] = args.video
    ctx['time'] = args.time
    ctx['output'] = args.output

    # person detector for locating players in the frame
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=True)
    ctx['detector'] = model.to(get_ctx(ctx, 'device')).eval()

    random.seed(get_ctx(ctx, 'seed'))
    np.random.seed(get_ctx(ctx, 'seed'))
    torch.manual_seed(get_ctx(ctx, 'seed'))
    main()
def _forward(self, batch_data, ctx):
    mag_mix = batch_data['mag_mix']
    mags = batch_data['mags']
    frames = batch_data['frames']
    mag_mix = mag_mix + 1e-10

    N = get_ctx(ctx, 'num_mix')
    B = mag_mix.size(0)
    T = mag_mix.size(3)

    # 0.0 warp the spectrogram
    if get_ctx(ctx, 'log_freq'):
        grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(
            get_ctx(ctx, 'device'))
        mag_mix = F.grid_sample(mag_mix, grid_warp, align_corners=True)
        for n in range(N):
            mags[n] = F.grid_sample(mags[n], grid_warp, align_corners=True)

    # 0.1 calculate loss weighting coefficient: magnitude of input mixture
    if get_ctx(ctx, 'weighted_loss'):
        weight = torch.log1p(mag_mix)
        weight = torch.clamp(weight, 1e-3, 10)
    else:
        weight = torch.ones_like(mag_mix)

    # 0.2 ground truth masks are computed after warping!
    gt_masks = [None for n in range(N)]
    for n in range(N):
        if get_ctx(ctx, 'binary_mask'):
            # for simplicity, mag_N > 0.5 * mag_mix
            gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
        else:
            gt_masks[n] = mags[n] / mag_mix
            # clamp to avoid large numbers in ratio masks
            gt_masks[n].clamp_(0., 5.)

    # LOG magnitude
    log_mag_mix = torch.log(mag_mix).detach()

    # 1. forward net_sound -> BxCxHxW
    feat_sound = self.net_sound(log_mag_mix)
    feat_sound = activate(feat_sound, get_ctx(ctx, 'sound_activation'))

    # 2. forward net_frame -> Bx1xC
    feat_frames = [None for n in range(N)]
    for n in range(N):
        feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
        feat_frames[n] = activate(feat_frames[n],
                                  get_ctx(ctx, 'img_activation'))

    # 3. sound synthesizer
    pred_masks = [None for n in range(N)]
    for n in range(N):
        pred_masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
        pred_masks[n] = activate(pred_masks[n],
                                 get_ctx(ctx, 'output_activation'))

    # 4. loss
    err = self.crit(pred_masks, gt_masks, weight).reshape(1)

    return err, \
        {'pred_masks': pred_masks, 'gt_masks': gt_masks,
         'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
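
# Sketch (illustration, not project code): the ground-truth masks on a toy
# 2x2 spectrogram pair, matching the binary/ratio branches above.
def _gt_mask_sketch():
    import torch

    mag_mix = torch.tensor([[4.0, 1.0], [2.0, 8.0]])
    mag_src = torch.tensor([[3.0, 0.2], [0.5, 8.0]])
    binary = (mag_src > 0.5 * mag_mix).float()      # source dominates bin?
    ratio = (mag_src / mag_mix).clamp(0., 5.)       # per-bin magnitude ratio
    assert binary.tolist() == [[1.0, 0.0], [0.0, 1.0]]
    assert torch.allclose(ratio, torch.tensor([[0.75, 0.2], [0.25, 1.0]]))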
def __init__(self, list_sample, ctx, **kwargs):
    super(MUSICMixDataset, self).__init__(list_sample, ctx, **kwargs)
    self.fps = get_ctx(ctx, 'frame_rate')
    self.num_mix = get_ctx(ctx, 'num_mix')
def __init__(self, list_sample, ctx, max_sample=-1, split='train'):
    # params
    self.num_frames = get_ctx(ctx, 'num_frames')
    self.stride_frames = get_ctx(ctx, 'stride_frames')
    self.frame_rate = get_ctx(ctx, 'frame_rate')
    self.img_size = get_ctx(ctx, 'img_size')
    self.aud_rate = get_ctx(ctx, 'aud_rate')
    self.aud_len = get_ctx(ctx, 'aud_len')
    self.aud_sec = 1. * self.aud_len / self.aud_rate
    self.binary_mask = get_ctx(ctx, 'binary_mask')

    # STFT params
    self.log_freq = get_ctx(ctx, 'log_freq')
    self.stft_frame = get_ctx(ctx, 'stft_frame')
    self.stft_hop = get_ctx(ctx, 'stft_hop')
    self.HS = get_ctx(ctx, 'stft_frame') // 2 + 1  # spectrogram height
    self.WS = (self.aud_len + 1) // self.stft_hop  # spectrogram width

    self.split = split

    # initialize video transform
    self._init_vtransform()

    # list_sample can be a python list or a csv file of list
    if isinstance(list_sample, str):
        # self.list_sample = [x.rstrip() for x in open(list_sample, 'r')]
        self.list_sample = []
        with open(list_sample, 'r') as f:
            for row in csv.reader(f, delimiter=','):
                if len(row) < 2:
                    continue
                self.list_sample.append(row)
    elif isinstance(list_sample, list):
        self.list_sample = list_sample
    else:
        raise RuntimeError('list_sample must be a list or a csv file path')

    if self.split == 'train':
        self.list_sample *= get_ctx(ctx, 'dup_trainset')
        random.shuffle(self.list_sample)

    if max_sample > 0:
        self.list_sample = self.list_sample[0:max_sample]

    num_sample = len(self.list_sample)
    assert num_sample > 0
    print('# samples: {}'.format(num_sample))
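
# Sketch (illustration, not project code): HS/WS above match torch.stft's
# output shape. The concrete values here are assumptions chosen to match the
# 512x256 spectrograms hard-coded elsewhere in this repo
# (stft_frame=1022, stft_hop=256, aud_len=65535).
def _stft_shape_sketch():
    import torch

    stft_frame, stft_hop, aud_len = 1022, 256, 65535
    spec = torch.stft(torch.randn(aud_len), n_fft=stft_frame,
                      hop_length=stft_hop,
                      window=torch.hann_window(stft_frame),
                      return_complex=True)
    HS = stft_frame // 2 + 1         # 512 frequency bins
    WS = (aud_len + 1) // stft_hop   # 256 time frames
    assert spec.shape == (HS, WS)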
def adjust_learning_rate(ctx):
    # decay both the ctx copies and the live optimizer groups by 10x
    ctx['lr_sound'] *= 0.1
    ctx['lr_frame'] *= 0.1
    ctx['lr_synthesizer'] *= 0.1
    for param_group in get_ctx(ctx, 'optimizer').param_groups:
        param_group['lr'] *= 0.1
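
# Sketch (illustration, not project code): learning rates live in the
# optimizer's param_groups, so the live groups must be decayed alongside the
# ctx copies to stay in sync.
def _lr_decay_sketch():
    import torch

    opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1e-3)
    for g in opt.param_groups:
        g['lr'] *= 0.1
    assert abs(opt.param_groups[0]['lr'] - 1e-4) < 1e-12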
def output_predictions(ctx, data, outputs):
    mag_mix = data['mag_mix']
    phase_mix = data['phase_mix']
    frames = data['frames']
    infos = data['infos']
    pred_masks_ = outputs['pred_masks']
    bs, im_h, im_w, _, _ = pred_masks_.shape

    # unwarp the per-cell masks back to the linear frequency scale
    pred_masks_linear = torch.zeros((im_h, bs, im_w, 512, 256)).to(
        get_ctx(ctx, 'device'))
    for h in range(im_h):
        pred_masks_linear_h = unwarp_log_scale(ctx,
                                               [pred_masks_[:, h, :, :, :]])
        pred_masks_linear[h] = pred_masks_linear_h[0]
    pred_masks_linear = pred_masks_linear.permute(1, 0, 2, 3, 4)

    # to cpu
    pred_masks_linear = detach_mask(ctx, [pred_masks_linear],
                                    get_ctx(ctx, 'binary_mask'))[0]
    pred_masks_ = detach_mask(ctx, [pred_masks_],
                              get_ctx(ctx, 'binary_mask'))[0]
    mag_mix = mag_mix.numpy()
    phase_mix = phase_mix.numpy()
    frames = frames[0]

    for i in range(bs):
        frames_tensor = np.asarray([
            recover_rgb(frames[i, :, t].cpu())
            for t in range(get_ctx(ctx, 'num_frames'))
        ])

        pth, id_ = os.path.split(infos[0][1][i])
        _, group = os.path.split(pth)
        prefix = group + '-' + id_
        folder = os.path.join(get_ctx(ctx, 'vis_regions'), prefix)
        sbr_folder = os.path.join(folder, 'sbr')
        grid_folder = os.path.join(sbr_folder, 'grid')
        makedirs(folder)
        makedirs(sbr_folder)
        makedirs(grid_folder)
        grid_pred_mask = np.zeros((14 * 256, 14 * 256))

        for j in range(get_ctx(ctx, 'num_frames')):
            imwrite(os.path.join(folder, f'frame{j}.jpg'), frames_tensor[j])

        mix_wav = istft_reconstruction(mag_mix[i, 0], phase_mix[i, 0],
                                       hop_length=get_ctx(ctx, 'stft_hop'))
        wavfile.write(os.path.join(folder, 'mix.wav'),
                      get_ctx(ctx, 'aud_rate'), mix_wav)

        # SBR
        for h in range(im_h):
            for w in range(im_w):
                name = f'{h}x{w}'

                # output audio
                pred_mag = mag_mix[i, 0] * pred_masks_linear[i, h, w]
                preds_wav = istft_reconstruction(
                    pred_mag, phase_mix[i, 0],
                    hop_length=get_ctx(ctx, 'stft_hop'))
                wavfile.write(os.path.join(grid_folder, f'{name}-pred.wav'),
                              get_ctx(ctx, 'aud_rate'), preds_wav)

                # output masks
                pred_mask = (np.clip(pred_masks_[i, h, w], 0, 1) *
                             255).astype(np.uint8)
                imwrite(os.path.join(grid_folder, f'{name}-predmask.jpg'),
                        pred_mask[::-1, :])
                grid_pred_mask[h * 256:(h + 1) * 256,
                               w * 256:(w + 1) * 256] = pred_mask[::-1, :]

                # output spectrogram (log of magnitude, show colormap)
                pred_mag = magnitude2heatmap(pred_mag)
                imwrite(os.path.join(grid_folder, f'{name}-predamp.jpg'),
                        pred_mag[::-1, :, :])

        imwrite(os.path.join(sbr_folder, 'masks-grid.jpg'), grid_pred_mask)
        grid_frame = frames_tensor[0]
        grid_frame[:, np.arange(16, 224, 16)] = 255
        grid_frame[np.arange(16, 224, 16), :] = 255
        imwrite(os.path.join(sbr_folder, 'frame.jpg'), grid_frame)

        with open(os.path.join(sbr_folder, 'sbr.html'), 'w') as text_file:
            text_file.write(sbr_html)
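
# Sketch of the assumed istft_reconstruction helper (the project's own
# implementation lives elsewhere; this is one plausible reading using
# librosa): recombine magnitude and phase into a complex STFT, then invert.
def _istft_sketch(mag, phase, hop_length):
    import numpy as np
    import librosa

    return librosa.istft(mag * np.exp(1j * phase), hop_length=hop_length)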
def train(ctx: dict):
    ctx['net_wrapper'] = build_model(ctx)
    ctx['optimizer'] = create_optimizer(
        get_underlying_nets(get_ctx(ctx, 'net_wrapper')), ctx)

    dataset_train = MUSICMixDataset(get_ctx(ctx, 'list_train'), ctx,
                                    split='train')
    ctx['loader_train'] = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=get_ctx(ctx, 'batch_size'),
        shuffle=True,
        num_workers=int(get_ctx(ctx, 'workers')),
        drop_last=True)
    ctx['epoch_iters'] = len(dataset_train) // get_ctx(ctx, 'batch_size')
    print(f'1 Epoch = {get_ctx(ctx, "epoch_iters")} iters')

    dataset_val = MUSICMixDataset(get_ctx(ctx, 'list_val'),
                                  ctx,
                                  max_sample=get_ctx(ctx, 'num_val'),
                                  split='val')
    ctx['loader_val'] = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=get_ctx(ctx, 'batch_size'),
        shuffle=False,
        num_workers=2,
        drop_last=False)

    ctx['history'], from_epoch = init_history(ctx)
    if get_ctx(ctx, 'continue_training') == '':
        makedirs(get_ctx(ctx, 'path'), remove=True)

    for epoch in range(from_epoch, get_ctx(ctx, 'num_epoch') + 1):
        ctx['epoch'] = epoch
        with torch.set_grad_enabled(True):
            train_epoch(ctx)
        with torch.set_grad_enabled(False):
            if epoch % get_ctx(ctx, 'eval_epoch') == 0:
                _evaluate(ctx)
            checkpoint(ctx)

        # drop learning rate
        if epoch in get_ctx(ctx, 'lr_steps'):
            adjust_learning_rate(ctx)
    print('Training Done!')