Example #1
def init_history(ctx: Optional[dict]):
    history = {
        'train': {'epoch': [], 'err': []},
        'val': {'epoch': [], 'err': [], 'sdr': [], 'sir': [], 'sar': []}}
    from_epoch = 0
    if ctx:
        continue_training = get_ctx(ctx, 'continue_training')
        if continue_training == 'latest' or isinstance(continue_training, int):
            # restore the epoch counter and metric history from the latest checkpoint
            suffix_latest = 'latest.pth'
            from_epoch = torch.load('{}/epoch_{}'.format(get_ctx(ctx, 'path'), suffix_latest)) + 1
            history = torch.load('{}/history_{}'.format(get_ctx(ctx, 'path'), suffix_latest))

            if isinstance(continue_training, int):
                # resuming from an explicit epoch: drop history recorded after it
                from_epoch = continue_training
                for k in history:
                    for k1 in history[k]:
                        history[k][k1] = history[k][k1][:from_epoch]

            # re-apply learning-rate decay for steps already passed before the resume point
            for step in get_ctx(ctx, 'lr_steps'):
                if step < from_epoch:
                    adjust_learning_rate(ctx)
        elif continue_training != "":
            raise ValueError(f'invalid value for continue_training: {continue_training}')

    return history, from_epoch
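As a rough, self-contained illustration of the same resume logic (plain dicts only, values made up), truncating each history list keeps already-logged epochs from appearing twice after resuming:

history = {'train': {'epoch': [1, 2, 3, 4, 5], 'err': [0.9, 0.7, 0.6, 0.55, 0.5]}}
from_epoch = 3  # hypothetical resume point
for split in history:
    for key in history[split]:
        history[split][key] = history[split][key][:from_epoch]
print(history['train']['epoch'])  # [1, 2, 3]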
Example #2
def detach_mask(ctx, mask, binary):
    N = get_ctx(ctx, 'num_mix')
    for n in range(N):
        mask[n] = mask[n].detach().cpu().numpy()
        if binary:
            mask[n] = (mask[n] > get_ctx(ctx, 'mask_thres')).astype(np.float32)

    return mask
Example #3
def save_nets(ctx, suffix):
    path = get_ctx(ctx, 'path')
    nets = get_underlying_nets(get_ctx(ctx, 'net_wrapper'))
    (net_sound, net_frame, net_synthesizer) = nets

    torch.save(net_sound.state_dict(), os.path.join(path, f'sound_{suffix}'))
    torch.save(net_frame.state_dict(), os.path.join(path, f'frame_{suffix}'))
    torch.save(net_synthesizer.state_dict(),
               os.path.join(path, f'synthesizer_{suffix}'))
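A minimal sketch of the state_dict round trip that save_nets relies on, assuming standard torch modules; the module and file names below are placeholders, not the project's:

import os
import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # stand-in for net_sound / net_frame / net_synthesizer
path = '/tmp/ckpt_example'
os.makedirs(path, exist_ok=True)

# save the weights only, mirroring save_nets
torch.save(net.state_dict(), os.path.join(path, 'sound_latest.pth'))

# later: rebuild an identically shaped module and load the weights back
net_restored = nn.Linear(4, 2)
net_restored.load_state_dict(torch.load(os.path.join(path, 'sound_latest.pth')))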
Example #4
def test(ctx: dict):
    ctx['load_best_model'] = True
    ctx['net_wrapper'] = build_model(ctx)
    ctx['num_mix'] = 1

    dataset = MUSICMixDataset(get_ctx(ctx, 'list_test'), ctx, max_sample=get_ctx(ctx, 'num_test'), split='test')
    ctx['loader_test'] = torch.utils.data.DataLoader(dataset, batch_size=get_ctx(ctx, 'batch_size'),
                                                     shuffle=False, num_workers=2, drop_last=False)

    with torch.set_grad_enabled(False):
        _test(ctx)
Example #5
def unwarp_log_scale(ctx, arr):
    N = get_ctx(ctx, 'num_mix')
    B = arr[0].size(0)
    linear = [None for _ in range(N)]

    for n in range(N):
        if get_ctx(ctx, 'log_freq'):
            w = warpgrid(B, get_ctx(ctx, 'stft_frame') // 2 + 1, arr[0].size(3), warp=False)
            grid_unwarp = torch.from_numpy(w).to(get_ctx(ctx, 'device'))
            linear[n] = F.grid_sample(arr[n], grid_unwarp, align_corners=True)
        else:
            linear[n] = arr[n]

    return linear
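unwarp_log_scale delegates the resampling to F.grid_sample; a standalone sketch (independent of warpgrid, whose output is assumed to be a normalized [-1, 1] sampling grid) to show the expected shapes:

import torch
import torch.nn.functional as F

x = torch.randn(2, 1, 256, 64)  # (batch, channel, freq, time)

# identity affine grid upsampling freq to 512 bins; grid shape is (N, H_out, W_out, 2)
theta = torch.eye(2, 3).unsqueeze(0).repeat(2, 1, 1)
grid = F.affine_grid(theta, size=(2, 1, 512, 64), align_corners=True)

y = F.grid_sample(x, grid, align_corners=True)
print(y.shape)  # torch.Size([2, 1, 512, 64])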
Example #6
def predict_bboxes(img):
    # preprocess and add a batch dimension, then run the detector without gradients
    t = get_transform(False)
    inp = t(img)
    inp = inp[None].to(get_ctx(ctx, 'device'))
    with torch.no_grad():
        out = get_ctx(ctx, 'detector')(inp)[0]

    # non-maximum suppression (IoU threshold 0.4); nms returns indices of kept boxes
    keep = torchvision.ops.nms(out['boxes'], out['scores'], 0.4)
    filtered_boxes = []
    for i in keep:
        box = out['boxes'][i]
        label = out['labels'][i]
        score = out['scores'][i]
        # keep confident detections of class 1 ('person' in the COCO label map)
        if label == 1 and score > 0.7:
            filtered_boxes.append(box)

    return filtered_boxes
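A self-contained sketch of the same NMS + score filtering pattern with made-up boxes; torchvision.ops.nms returns the indices of the boxes it keeps, ordered by descending score:

import torch
import torchvision

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.75])
labels = torch.tensor([1, 1, 2])

keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.4)
filtered = [boxes[i] for i in keep if labels[i] == 1 and scores[i] > 0.7]
print(len(filtered))  # 1: the overlapping pair collapses to one box, the third has the wrong label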
Example #7
def evaluate(ctx: dict):
    ctx['load_best_model'] = True
    ctx['net_wrapper'] = build_model(ctx)

    dataset_val = MUSICMixDataset(get_ctx(ctx, 'list_val'), ctx, max_sample=get_ctx(ctx, 'num_val'), split='val')
    ctx['loader_val'] = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=get_ctx(ctx, 'batch_size'),
        shuffle=False,
        num_workers=2,
        drop_last=False)

    ctx['history'], _ = init_history(None)

    ctx['epoch'] = 0
    with torch.set_grad_enabled(False):
        _evaluate(ctx)
Example #8
def checkpoint(ctx: dict):
    epoch = get_ctx(ctx, 'epoch')
    history = get_ctx(ctx, 'history')
    path = get_ctx(ctx, 'path')

    print('Saving checkpoints at epoch {}.'.format(epoch))
    suffix_latest = 'latest.pth'
    suffix_best = 'best.pth'

    torch.save(epoch, os.path.join(path, f'epoch_{suffix_latest}'))
    torch.save(history, os.path.join(path, f'history_{suffix_latest}'))
    save_nets(ctx, suffix_latest)

    cur_metrics = (history['val']['sdr'][-1] + history['val']['sir'][-1] +
                   history['val']['sar'][-1]) / 3
    if cur_metrics > get_ctx(
            ctx, 'best_metrics') and epoch % get_ctx(ctx, 'eval_epoch') == 0:
        print(
            f'Best model, epoch = {epoch}, mean metrics = {cur_metrics}, prev best = {get_ctx(ctx, "best_metrics")}'
        )
        ctx['best_metrics'] = cur_metrics
        save_nets(ctx, suffix_best)

    if get_ctx(ctx, 'checkpoint_epoch') is not None and epoch % get_ctx(
            ctx, 'checkpoint_epoch') == 0:
        save_nets(ctx, f'{epoch}.pth')
Example #9
    def _forward_pixelwise(self, batch_data, ctx):
        mag_mix = batch_data['mag_mix']
        frames = batch_data['frames']
        mag_mix = mag_mix + 1e-10

        bs = mag_mix.size(0)
        T = mag_mix.size(3)

        # 0.0 warp the spectrogram
        if get_ctx(ctx, 'log_freq'):
            grid_warp = torch.from_numpy(warpgrid(bs, 256, T, warp=True)).to(
                get_ctx(ctx, 'device'))
            mag_mix = F.grid_sample(mag_mix, grid_warp, align_corners=True)

        # LOG magnitude
        log_mag_mix = torch.log(mag_mix).detach()

        # 1. forward net_sound -> BxCxHxW
        feat_sound = self.net_sound(log_mag_mix)
        feat_sound = activate(feat_sound, get_ctx(ctx, 'sound_activation'))

        # 2. forward net_frame -> Bx1xC
        frames = frames[0]  # num_mix == 1
        feat_frames = self.net_frame.forward_multiframe(frames, pool=False)

        (B, C, T, H, W) = feat_frames.size()
        feat_frames = feat_frames.permute(0, 1, 3, 4, 2)
        feat_frames = feat_frames.reshape(B * C, H * W, T)
        feat_frames = F.adaptive_avg_pool1d(feat_frames, 1)
        feat_frames = feat_frames.view(B, C, H, W)

        feat_frames = activate(feat_frames, get_ctx(ctx, 'img_activation'))

        channels = feat_frames.detach().cpu().numpy()

        # 3. sound synthesizer
        pred_masks = self.net_synthesizer.forward_pixelwise(
            feat_frames, feat_sound)
        pred_masks = activate(pred_masks, get_ctx(ctx, 'output_activation'))

        return {
            'pred_masks': pred_masks,
            'processed_mag_mix': mag_mix,
            'feat_frames_channels': channels
        }
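The permute/reshape/adaptive_avg_pool1d sequence in step 2 is just a temporal mean over the frame features while the spatial grid is preserved; a dummy-shaped sketch (sizes chosen arbitrarily) confirming the equivalence:

import torch
import torch.nn.functional as F

B, C, T, H, W = 2, 32, 3, 14, 14
feat = torch.randn(B, C, T, H, W)

x = feat.permute(0, 1, 3, 4, 2).reshape(B * C, H * W, T)  # time on the last axis
x = F.adaptive_avg_pool1d(x, 1)                           # average over the T frames
pooled = x.view(B, C, H, W)

assert torch.allclose(pooled, feat.mean(dim=2), atol=1e-6)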
Example #10
def regions(ctx: dict):
    ctx['load_best_model'] = True
    ctx['net_wrapper'] = build_model(ctx)
    ctx['num_mix'] = 1

    dataset = MUSICMixDataset(get_ctx(ctx, 'list_regions'), ctx, max_sample=get_ctx(ctx, 'num_regions'),
                              split='regions')

    loader = torch.utils.data.DataLoader(dataset, batch_size=get_ctx(ctx, 'batch_size'), shuffle=True,
                                         num_workers=1, drop_last=False)

    makedirs(get_ctx(ctx, 'vis_regions'), remove=True)
    cnt = 0
    with torch.no_grad():
        for data in tqdm(loader):
            output = ctx['net_wrapper'].forward(data, ctx, pixelwise=True)
            output_predictions(ctx, data, output)
            cnt += len(data['audios'][0])
Example #11
def create_optimizer(nets, ctx):
    (net_sound, net_frame, net_synthesizer) = nets
    param_groups = [{'params': net_sound.parameters(), 'lr': get_ctx(ctx, 'lr_sound')},
                    {'params': net_synthesizer.parameters(), 'lr': get_ctx(ctx, 'lr_synthesizer')},
                    {'params': net_frame.features.parameters(), 'lr': get_ctx(ctx, 'lr_frame')},
                    {'params': net_frame.fc.parameters(), 'lr': get_ctx(ctx, 'lr_frame')}]
    return torch.optim.SGD(param_groups, momentum=get_ctx(ctx, 'beta1'),
                           weight_decay=get_ctx(ctx, 'weight_decay'))
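Per-network learning rates come from SGD parameter groups; a minimal sketch with throwaway modules and arbitrary rates:

import torch
import torch.nn as nn

net_a, net_b = nn.Linear(8, 8), nn.Linear(8, 8)
optimizer = torch.optim.SGD(
    [{'params': net_a.parameters(), 'lr': 1e-3},
     {'params': net_b.parameters(), 'lr': 1e-4}],
    momentum=0.9, weight_decay=1e-4)

# each group keeps its own learning rate
print([g['lr'] for g in optimizer.param_groups])  # [0.001, 0.0001]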
Example #12
def lr_range_test_for_part(ctx: dict, groups: List[int]):
    ctx['net_wrapper'] = build_model(ctx)
    ctx['optimizer'] = create_optimizer(get_underlying_nets(get_ctx(ctx, 'net_wrapper')), ctx)
    dataset_train = MUSICMixDataset(get_ctx(ctx, 'list_train'), ctx, split='train')
    ctx['loader_train'] = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=get_ctx(ctx, 'batch_size'),
        shuffle=True,
        num_workers=int(get_ctx(ctx, 'workers')),
        drop_last=True)

    rates = []
    losses = []

    net_wrapper = get_ctx(ctx, 'net_wrapper')
    optimizer = get_ctx(ctx, 'optimizer')
    loader = get_ctx(ctx, 'loader_train')

    total = 1000
    min_lr = 1e-10
    max_lr = 1e1
    smooth_f = 0.05
    for step, batch_data in tqdm(enumerate(loader), total=total):
        if step == total:
            break

        it = step / total
        lr = np.exp((1 - it) * np.log(min_lr) + it * np.log(max_lr))

        set_lr(optimizer, lr, groups)

        net_wrapper.zero_grad()
        loss, _ = net_wrapper.forward(batch_data, ctx)
        loss = loss.mean()

        # backward
        loss.backward()
        optimizer.step()

        if step > 0:
            loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
        rates.append(lr)
        losses.append(loss.item())

    return np.array(rates), np.array(losses)
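The sweep interpolates the learning rate geometrically between min_lr and max_lr; the schedule in isolation (no model or data) looks like this:

import numpy as np

total, min_lr, max_lr = 1000, 1e-10, 1e1
it = np.arange(total) / total
lrs = np.exp((1 - it) * np.log(min_lr) + it * np.log(max_lr))

print(lrs[0], lrs[-1])  # 1e-10 at the first step, just under 1e1 at the last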
Example #13
def train_epoch(ctx: dict):
    net_wrapper = get_ctx(ctx, 'net_wrapper')
    optimizer = get_ctx(ctx, 'optimizer')
    loader = get_ctx(ctx, 'loader_train')
    history = get_ctx(ctx, 'history')
    epoch = get_ctx(ctx, 'epoch')

    batch_time = AverageMeter()
    data_time = AverageMeter()
    # switch to train mode
    net_wrapper.train()

    # main loop
    synchronize(ctx)
    tic = time.perf_counter()
    for i, batch_data in enumerate(loader):
        # measure data time
        synchronize(ctx)
        data_time.update(time.perf_counter() - tic)

        # forward pass
        net_wrapper.zero_grad()
        err, _ = net_wrapper.forward(batch_data, ctx)
        err = err.mean()

        # backward
        err.backward()
        optimizer.step()

        # measure total time
        synchronize(ctx)
        batch_time.update(time.perf_counter() - tic)
        tic = time.perf_counter()

        # display
        if i % get_ctx(ctx, 'disp_iter') == 0:
            print(
                f'{get_timestr()} Epoch: [{epoch}][{i}/{get_ctx(ctx, "epoch_iters")}],'
                f' Time: {batch_time.average():.2f}, Data: {data_time.average():.2f}, '
                f'lr_sound: {get_ctx(ctx, "lr_sound")}, lr_frame: {get_ctx(ctx, "lr_frame")}, '
                f'lr_synthesizer: {get_ctx(ctx, "lr_synthesizer")}, '
                f'loss: {err.item():.4f}')
            fractional_epoch = epoch - 1 + 1. * i / get_ctx(ctx, 'epoch_iters')
            history['train']['epoch'].append(fractional_epoch)
            history['train']['err'].append(err.item())
Example #14
def _test(ctx: dict):
    makedirs(get_ctx(ctx, 'vis_test'), remove=True)

    net_wrapper = get_ctx(ctx, 'net_wrapper')
    net_wrapper.eval()
    loader = get_ctx(ctx, 'loader_test')

    # initialize HTML header
    visualizer = HTMLVisualizer(os.path.join(get_ctx(ctx, 'vis_test'), 'index.html'))
    header = ['Filename', 'Input Audio']
    for n in range(1, get_ctx(ctx, 'num_mix') + 1):
        header += [f'Predicted Audio {n:d}', f'Predicted Mask {n}']
    visualizer.add_header(header)
    vis_rows = []

    for i, batch_data in tqdm(enumerate(loader)):
        _, outputs = net_wrapper.forward(batch_data, ctx)

        if len(vis_rows) < get_ctx(ctx, 'num_vis'):
            output_visuals(vis_rows, batch_data, outputs, ctx)

    visualizer.add_rows(vis_rows)
    visualizer.write_html()
Example #15
def to_device(ctx, data):
    if get_ctx(ctx, 'device').type != 'cpu' and len(get_ctx(ctx, 'gpu')) == 1:
        for k in data:
            if isinstance(data[k], torch.Tensor):
                data[k] = data[k].to(get_ctx(ctx, 'device'))
Example #16
def build_model(ctx: dict):
    if get_ctx(ctx, 'load_best_model'):
        weights_sound = get_ctx(ctx, 'weights_sound_best')
        weights_frame = get_ctx(ctx, 'weights_frame_best')
        weights_synthesizer = get_ctx(ctx, 'weights_synthesizer_best')
    elif get_ctx(ctx, 'continue_training') == 'latest':
        weights_sound = get_ctx(ctx, 'weights_sound_latest')
        weights_frame = get_ctx(ctx, 'weights_frame_latest')
        weights_synthesizer = get_ctx(ctx, 'weights_synthesizer_latest')
    elif isinstance(get_ctx(ctx, 'continue_training'), int):
        ep = get_ctx(ctx, 'continue_training')
        weights_sound = get_ctx(ctx, f'weights_sound_{ep}')
        weights_frame = get_ctx(ctx, f'weights_frame_{ep}')
        weights_synthesizer = get_ctx(ctx, f'weights_synthesizer_{ep}')
    elif get_ctx(ctx, 'finetune'):
        weights_sound = get_ctx(ctx, 'weights_sound_finetune')
        weights_frame = get_ctx(ctx, 'weights_frame_finetune')
        weights_synthesizer = get_ctx(ctx, 'weights_synthesizer_finetune')
    else:
        weights_sound, weights_frame, weights_synthesizer = '', '', ''

    builder = ModelBuilder()
    net_sound = builder.build_sound(
        arch=get_ctx(ctx, 'arch_sound'),
        fc_dim=get_ctx(ctx, 'num_channels'),
        weights=weights_sound)
    net_frame = builder.build_frame(
        arch=get_ctx(ctx, 'arch_frame'),
        fc_dim=get_ctx(ctx, 'num_channels'),
        pool_type=get_ctx(ctx, 'img_pool'),
        weights=weights_frame)
    net_synthesizer = builder.build_synthesizer(
        arch=get_ctx(ctx, 'arch_synthesizer'),
        fc_dim=get_ctx(ctx, 'num_channels'),
        weights=weights_synthesizer)

    if get_ctx(ctx, 'finetune'):
        for param in net_sound.parameters():
            param.requires_grad = False
        for param in net_frame.parameters():
            param.requires_grad = False
    nets = (net_sound, net_frame, net_synthesizer)
    crit = builder.build_criterion(arch=get_ctx(ctx, 'loss'))
    net_wrapper = NetWrapper(nets, crit)
    if get_ctx(ctx, 'device').type != 'cpu':
        net_wrapper = nn.DataParallel(net_wrapper, device_ids=get_ctx(ctx, 'gpu'))
        # net_wrapper.to(get_ctx(ctx, 'device'))

    return net_wrapper
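For the DataParallel wrapping at the end, a minimal standalone sketch (guarded so it also runs on CPU); batches are split along dimension 0 across the listed GPUs and gathered back:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
x = torch.randn(8, 16)

if torch.cuda.is_available():
    model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))).cuda()
    x = x.cuda()

print(model(x).shape)  # torch.Size([8, 4])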
Example #17
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', help='train/eval/regions/lr_range_test')
    parser.add_argument('--config',
                        default='',
                        help='Configuration in JSON file')
    args = parser.parse_args()

    config = read_config('configs/default.json')
    if args.config != '':
        config_override = read_config(args.config)
        config = {**config, **config_override}

    ctx = create_context(config)

    random.seed(get_ctx(ctx, 'seed'))
    np.random.seed(get_ctx(ctx, 'seed'))
    torch.manual_seed(get_ctx(ctx, 'seed'))

    if args.mode == 'train':
        train(ctx)
    elif args.mode == 'eval':
        evaluate(ctx)
    # elif args.mode == 'test':
    #     test(ctx)
    elif args.mode == 'regions':
        regions(ctx)
    elif args.mode == 'lr_range_test':
        lr_range_test(ctx)
    else:
        raise RuntimeError(f'unknown mode: {args.mode}')
Example #18
def synchronize(ctx: dict):
    if get_ctx(ctx, 'device').type != 'cpu':
        torch.cuda.synchronize()
Example #19
def main():
    ctx['temp_dir'] = tempfile.mkdtemp()
    print(get_ctx(ctx, 'temp_dir'))

    sound_len = int(
        math.ceil(get_ctx(ctx, 'aud_len') / get_ctx(ctx, 'aud_rate')))
    print(sound_len)

    audio_fname = os.path.join(get_ctx(ctx, 'temp_dir'), 'audio.mp3')
    frame_path = os.path.join(get_ctx(ctx, 'temp_dir'), 'frame')
    os.mkdir(frame_path)
    frame_low_path = os.path.join(get_ctx(ctx, 'temp_dir'), 'frame_low')
    os.mkdir(frame_low_path)

    extract_frames_ffmpeg(get_ctx(ctx, 'video_path'), frame_path, BASE_FPS,
                          get_ctx(ctx, 'time'), sound_len)
    extract_frames_ffmpeg(get_ctx(ctx, 'video_path'), frame_low_path,
                          get_ctx(ctx, 'frame_rate') + 1, get_ctx(ctx, 'time'),
                          sound_len)
    extract_audio_ffmpeg(get_ctx(ctx, 'video_path'), audio_fname,
                         get_ctx(ctx, 'time'), sound_len)

    img = get_img_for_detection(frame_low_path)
    img_width, img_height = img.size
    bboxes = predict_bboxes(img)

    data, out = predict_masks(audio_fname, frame_low_path)
    pred_masks = out['pred_masks']
    _, grid_width, grid_height, mask_width, mask_height = out[
        'pred_masks'].size()

    # unwarp
    pred_masks_linear = torch.zeros(
        (grid_width, 1, grid_height, 512, 256)).to(get_ctx(ctx, 'device'))
    for h in range(grid_width):
        pred_masks_linear_h = unwarp_log_scale(ctx,
                                               [pred_masks[:, h, :, :, :]])
        pred_masks_linear[h] = pred_masks_linear_h[0]
    pred_masks_linear = pred_masks_linear.permute(1, 0, 2, 3, 4)
    pred_masks_linear = detach_mask(ctx, [pred_masks_linear],
                                    get_ctx(ctx, 'binary_mask'))[0]

    frame_out_path = os.path.join(get_ctx(ctx, 'temp_dir'), 'frame_out')
    shutil.copytree(frame_path, frame_out_path)

    mag_mix = data['mag_mix'].numpy()
    phase_mix = data['phase_mix'].numpy()
    mix_wav = istft_reconstruction(mag_mix[0, 0],
                                   phase_mix[0, 0],
                                   hop_length=get_ctx(ctx, 'stft_hop'))
    mix_wav = resize_to_aud_len(mix_wav)
    wavfile.write(os.path.join(get_ctx(ctx, 'temp_dir'), 'mix.mp3'),
                  get_ctx(ctx, 'aud_rate'), mix_wav)

    mask_boxes = get_mask_boxes(img, grid_width, grid_height)
    cell_area = area(0, 0, img_width // grid_width, img_height // grid_height)
    for box in bboxes:
        number = last_frame_number(frame_out_path)
        box_np = box.cpu().numpy()
        avg_mask = get_average_mask(box_np, mask_boxes, pred_masks_linear[0],
                                    cell_area, 512, 256, grid_width,
                                    grid_height)
        if get_ctx(ctx, 'binary_mask'):
            avg_mask = (avg_mask > get_ctx(ctx, 'mask_thres')).astype(
                np.float32)

        pred_mag = mag_mix[0, 0] * avg_mask
        preds_wav = istft_reconstruction(pred_mag,
                                         phase_mix[0, 0],
                                         hop_length=get_ctx(ctx, 'stft_hop'))
        preds_wav = resize_to_aud_len(preds_wav)
        wavfile.write(
            os.path.join(get_ctx(ctx, 'temp_dir'), f'{number:06d}.mp3'),
            get_ctx(ctx, 'aud_rate'), preds_wav)
        mix_wav = np.concatenate([mix_wav, preds_wav])

        for frame_name in sorted(os.listdir(frame_path)):
            with Image.open(os.path.join(frame_out_path, frame_name)) as im:
                draw = ImageDraw.Draw(im)
                draw.rectangle(box.tolist(), outline='red')

                number += 1
                im.save(os.path.join(frame_out_path, f'{number:06d}.jpg'))

    all_frames = np.array([
        np.array(Image.open(os.path.join(frame_out_path, frame_name)))
        for frame_name in sorted(os.listdir(frame_out_path))
    ])
    wavfile.write(os.path.join(get_ctx(ctx, 'temp_dir'), 'full.mp3'),
                  get_ctx(ctx, 'aud_rate'), mix_wav)
    save_video(os.path.join(get_ctx(ctx, 'temp_dir'), 'full.mp4'), all_frames,
               BASE_FPS)
    combine_video_audio(os.path.join(get_ctx(ctx, 'temp_dir'), 'full.mp4'),
                        os.path.join(get_ctx(ctx, 'temp_dir'), 'full.mp3'),
                        os.path.join(get_ctx(ctx, 'temp_dir'), 'result.mp4'))

    shutil.copy(os.path.join(get_ctx(ctx, 'temp_dir'), 'result.mp4'),
                get_ctx(ctx, 'output'))

    shutil.rmtree(ctx['temp_dir'])
Example #20
def resize_to_aud_len(wav):
    if get_ctx(ctx, 'aud_len') > wav.shape[0]:
        return np.pad(wav, (0, get_ctx(ctx, 'aud_len') - wav.shape[0]))
    elif get_ctx(ctx, 'aud_len') < wav.shape[0]:
        return wav[:get_ctx(ctx, 'aud_len')]
    return wav
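resize_to_aud_len is a pad-or-trim to a fixed number of samples; the same logic as a standalone helper with a toy target length:

import numpy as np

def pad_or_trim(wav, target_len):
    # zero-pad short signals, cut long ones, leave exact matches untouched
    if target_len > wav.shape[0]:
        return np.pad(wav, (0, target_len - wav.shape[0]))
    return wav[:target_len]

print(pad_or_trim(np.ones(3), 5))        # [1. 1. 1. 0. 0.]
print(pad_or_trim(np.ones(7), 5).shape)  # (5,)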
Example #21


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('video', type=str)
    parser.add_argument('--time', type=int, required=True)
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--output', type=str, default='result.mp4')
    args = parser.parse_args()

    config = read_config('configs/default.json')
    if args.config != '':
        config_override = read_config(args.config)
        config = {**config, **config_override}

    ctx = create_context(config)
    ctx['video_path'] = args.video
    ctx['time'] = args.time
    ctx['output'] = args.output

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=True)
    ctx['detector'] = model.to(get_ctx(ctx, 'device')).eval()

    random.seed(get_ctx(ctx, 'seed'))
    np.random.seed(get_ctx(ctx, 'seed'))
    torch.manual_seed(get_ctx(ctx, 'seed'))

    main()
Example #22
    def _forward(self, batch_data, ctx):
        mag_mix = batch_data['mag_mix']
        mags = batch_data['mags']
        frames = batch_data['frames']
        mag_mix = mag_mix + 1e-10

        N = get_ctx(ctx, 'num_mix')
        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # 0.0 warp the spectrogram
        if get_ctx(ctx, 'log_freq'):
            grid_warp = torch.from_numpy(warpgrid(B, 256, T, warp=True)).to(
                get_ctx(ctx, 'device'))
            mag_mix = F.grid_sample(mag_mix, grid_warp, align_corners=True)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp, align_corners=True)

        # 0.1 calculate loss weighting coefficient: magnitude of input mixture
        if get_ctx(ctx, 'weighted_loss'):
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # 0.2 ground truth masks are computed after warping!
        gt_masks = [None for n in range(N)]
        for n in range(N):
            if get_ctx(ctx, 'binary_mask'):
                # for simplicity, mag_N > 0.5 * mag_mix
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
            else:
                gt_masks[n] = mags[n] / mag_mix
                # clamp to avoid large numbers in ratio masks
                gt_masks[n].clamp_(0., 5.)

        # LOG magnitude
        log_mag_mix = torch.log(mag_mix).detach()

        # 1. forward net_sound -> BxCxHxW
        feat_sound = self.net_sound(log_mag_mix)
        feat_sound = activate(feat_sound, get_ctx(ctx, 'sound_activation'))

        # 2. forward net_frame -> Bx1xC
        feat_frames = [None for n in range(N)]
        for n in range(N):
            feat_frames[n] = self.net_frame.forward_multiframe(frames[n])
            feat_frames[n] = activate(feat_frames[n],
                                      get_ctx(ctx, 'img_activation'))

        # 3. sound synthesizer
        pred_masks = [None for n in range(N)]
        for n in range(N):
            pred_masks[n] = self.net_synthesizer(feat_frames[n], feat_sound)
            pred_masks[n] = activate(pred_masks[n],
                                     get_ctx(ctx, 'output_activation'))

        # 4. loss
        err = self.crit(pred_masks, gt_masks, weight).reshape(1)

        return err, \
            {'pred_masks': pred_masks, 'gt_masks': gt_masks,
             'mag_mix': mag_mix, 'mags': mags, 'weight': weight}
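Step 0.2 builds the targets directly from the (warped) magnitudes; a toy numpy sketch of the two ground-truth mask variants, binary dominance vs. clamped ratio:

import numpy as np

mag_mix = np.array([1.0, 0.5, 0.2]) + 1e-10  # mixture magnitude per bin
mag_n = np.array([0.8, 0.1, 0.199])          # magnitude of source n

binary_mask = (mag_n > 0.5 * mag_mix).astype(np.float32)  # source dominates the bin
ratio_mask = np.clip(mag_n / mag_mix, 0.0, 5.0)           # clamped to avoid blow-ups

print(binary_mask)  # [1. 0. 1.]
print(ratio_mask)   # approx [0.8 0.2 0.995]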
Example #23
 def __init__(self, list_sample, ctx, **kwargs):
     super(MUSICMixDataset, self).__init__(
         list_sample, ctx, **kwargs)
     self.fps = get_ctx(ctx, 'frame_rate')
     self.num_mix = get_ctx(ctx, 'num_mix')
Example #24
    def __init__(self, list_sample, ctx, max_sample=-1, split='train'):
        # params
        self.num_frames = get_ctx(ctx, 'num_frames')
        self.stride_frames = get_ctx(ctx, 'stride_frames')
        self.frame_rate = get_ctx(ctx, 'frame_rate')
        self.img_size = get_ctx(ctx, 'img_size')
        self.aud_rate = get_ctx(ctx, 'aud_rate')
        self.aud_len = get_ctx(ctx, 'aud_len')
        self.aud_sec = 1. * self.aud_len / self.aud_rate
        self.binary_mask = get_ctx(ctx, 'binary_mask')

        # STFT params
        self.log_freq = get_ctx(ctx, 'log_freq')
        self.stft_frame = get_ctx(ctx, 'stft_frame')
        self.stft_hop = get_ctx(ctx, 'stft_hop')
        self.HS = get_ctx(ctx, 'stft_frame') // 2 + 1
        self.WS = (self.aud_len + 1) // self.stft_hop

        self.split = split

        # initialize video transform
        self._init_vtransform()

        # list_sample can be a python list or a csv file of list
        if isinstance(list_sample, str):
            # self.list_sample = [x.rstrip() for x in open(list_sample, 'r')]
            self.list_sample = []
            with open(list_sample, 'r') as f:
                for row in csv.reader(f, delimiter=','):
                    if len(row) < 2:
                        continue
                    self.list_sample.append(row)
        elif isinstance(list_sample, list):
            self.list_sample = list_sample
        else:
            raise RuntimeError('Error list_sample!')

        if self.split == 'train':
            self.list_sample *= get_ctx(ctx, 'dup_trainset')
            random.shuffle(self.list_sample)

        if max_sample > 0:
            self.list_sample = self.list_sample[0:max_sample]

        num_sample = len(self.list_sample)
        assert num_sample > 0
        print('# samples: {}'.format(num_sample))
Example #25
def adjust_learning_rate(ctx):
    ctx['lr_sound'] *= 0.1
    ctx['lr_frame'] *= 0.1
    ctx['lr_synthesizer'] *= 0.1
    for param_group in get_ctx(ctx, 'optimizer').param_groups:
        param_group['lr'] *= 0.1
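Decaying the optimizer in place works by scaling every param_group's 'lr'; a minimal check with one throwaway parameter:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([{'params': [param], 'lr': 0.1}], momentum=0.9)

for group in optimizer.param_groups:  # same 10x decay as adjust_learning_rate
    group['lr'] *= 0.1

print(optimizer.param_groups[0]['lr'])  # ~0.01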
Example #26
def output_predictions(ctx, data, outputs):
    mag_mix = data['mag_mix']
    phase_mix = data['phase_mix']
    frames = data['frames']
    infos = data['infos']
    pred_masks_ = outputs['pred_masks']

    bs, im_h, im_w, _, _ = pred_masks_.shape

    # unwarp
    pred_masks_linear = torch.zeros((im_h, bs, im_w, 512, 256)).to(get_ctx(ctx, 'device'))
    for h in range(im_h):
        pred_masks_linear_h = unwarp_log_scale(ctx, [pred_masks_[:, h, :, :, :]])
        pred_masks_linear[h] = pred_masks_linear_h[0]
    pred_masks_linear = pred_masks_linear.permute(1, 0, 2, 3, 4)

    # to cpu
    pred_masks_linear = detach_mask(ctx, [pred_masks_linear], get_ctx(ctx, 'binary_mask'))[0]
    pred_masks_ = detach_mask(ctx, [pred_masks_], get_ctx(ctx, 'binary_mask'))[0]
    mag_mix = mag_mix.numpy()
    phase_mix = phase_mix.numpy()
    frames = frames[0]

    for i in range(bs):
        frames_tensor = np.asarray([recover_rgb(frames[i, :, t].cpu()) for t in range(get_ctx(ctx, 'num_frames'))])

        pth, id_ = os.path.split(infos[0][1][i])
        _, group = os.path.split(pth)
        prefix = group + '-' + id_
        folder = os.path.join(get_ctx(ctx, 'vis_regions'), prefix)
        sbr_folder = os.path.join(folder, 'sbr')
        grid_folder = os.path.join(sbr_folder, 'grid')

        makedirs(folder)
        makedirs(sbr_folder)
        makedirs(grid_folder)

        grid_pred_mask = np.zeros((14 * 256, 14 * 256))

        for j in range(get_ctx(ctx, 'num_frames')):
            imwrite(os.path.join(folder, f'frame{j}.jpg'), frames_tensor[j])

        mix_wav = istft_reconstruction(mag_mix[i, 0], phase_mix[i, 0], hop_length=get_ctx(ctx, 'stft_hop'))
        wavfile.write(os.path.join(folder, 'mix.wav'), get_ctx(ctx, 'aud_rate'), mix_wav)

        # SBR
        for h in range(im_h):
            for w in range(im_w):
                name = f'{h}x{w}'

                # output audio
                pred_mag = mag_mix[i, 0] * pred_masks_linear[i, h, w]
                preds_wav = istft_reconstruction(pred_mag, phase_mix[i, 0], hop_length=get_ctx(ctx, 'stft_hop'))
                wavfile.write(os.path.join(grid_folder, f'{name}-pred.wav'), get_ctx(ctx, 'aud_rate'), preds_wav)

                # output masks
                pred_mask = (np.clip(pred_masks_[i, h, w], 0, 1) * 255).astype(np.uint8)
                imwrite(os.path.join(grid_folder, f'{name}-predmask.jpg'), pred_mask[::-1, :])
                grid_pred_mask[h * 256:(h + 1) * 256, w * 256:(w + 1) * 256] = pred_mask[::-1, :]

                # output spectrogram (log of magnitude, show colormap)
                pred_mag = magnitude2heatmap(pred_mag)
                imwrite(os.path.join(grid_folder, f'{name}-predamp.jpg'), pred_mag[::-1, :, :])

        imwrite(os.path.join(sbr_folder, 'masks-grid.jpg'), grid_pred_mask)

        grid_frame = frames_tensor[0]
        grid_frame[:, np.arange(16, 224, 16)] = 255
        grid_frame[np.arange(16, 224, 16), :] = 255
        imwrite(os.path.join(sbr_folder, 'frame.jpg'), grid_frame)

        with open(os.path.join(sbr_folder, 'sbr.html'), 'w') as text_file:
            text_file.write(sbr_html)
Example #27
def train(ctx: dict):
    ctx['net_wrapper'] = build_model(ctx)
    ctx['optimizer'] = create_optimizer(
        get_underlying_nets(get_ctx(ctx, 'net_wrapper')), ctx)

    dataset_train = MUSICMixDataset(get_ctx(ctx, 'list_train'),
                                    ctx,
                                    split='train')
    ctx['loader_train'] = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=get_ctx(ctx, 'batch_size'),
        shuffle=True,
        num_workers=int(get_ctx(ctx, 'workers')),
        drop_last=True)

    ctx['epoch_iters'] = len(dataset_train) // get_ctx(ctx, 'batch_size')
    print(f'1 Epoch = {get_ctx(ctx, "epoch_iters")} iters')

    dataset_val = MUSICMixDataset(get_ctx(ctx, 'list_val'),
                                  ctx,
                                  max_sample=get_ctx(ctx, 'num_val'),
                                  split='val')
    ctx['loader_val'] = torch.utils.data.DataLoader(dataset_val,
                                                    batch_size=get_ctx(
                                                        ctx, 'batch_size'),
                                                    shuffle=False,
                                                    num_workers=2,
                                                    drop_last=False)

    ctx['history'], from_epoch = init_history(ctx)
    if get_ctx(ctx, 'continue_training') == '':
        makedirs(get_ctx(ctx, 'path'), remove=True)

    for epoch in range(from_epoch, get_ctx(ctx, 'num_epoch') + 1):
        ctx['epoch'] = epoch

        with torch.set_grad_enabled(True):
            train_epoch(ctx)

        with torch.set_grad_enabled(False):
            if epoch % get_ctx(ctx, 'eval_epoch') == 0:
                _evaluate(ctx)
            checkpoint(ctx)

        # drop learning rate
        if epoch in get_ctx(ctx, 'lr_steps'):
            adjust_learning_rate(ctx)

    print('Training Done!')