def main(args):
    """Build the sound/vision/synthesizer networks, evaluate, then train.

    All configuration is read from ``args``; ``args.epoch_iters`` is written
    back onto it. ``evaluate`` is always called once up front with epoch 0.
    When ``args.mode == 'eval'`` the function returns right after that pass.
    Otherwise it trains for ``args.num_epoch`` epochs, calling ``evaluate``
    and ``checkpoint`` on every epoch that is a multiple of
    ``args.eval_epoch``.
    """
    # Sub-networks and loss; construction order is kept as sound -> vision
    # -> synthesizer -> criterion.
    factory = ModelBuilder()
    sound_net = factory.build_sound(
        arch=args.arch_sound,
        fc_dim=args.num_channels,
        weights=args.weights_sound,
    )
    vision_net = factory.build_vision(
        arch=args.arch_vision,
        fc_dim=args.num_channels,
        weights=args.weights_vision,
    )
    synthesizer_net = factory.build_synthesizer(
        arch=args.arch_synthesizer,
        fc_dim=args.num_channels,
        weights=args.weights_synthesizer,
    )
    networks = (sound_net, vision_net, synthesizer_net)
    criterion = factory.build_criterion(arch=args.loss)

    # Datasets first, then their loaders.
    train_set = MUSICMixDataset(args.list_train, args, split='train')
    val_set = MUSICMixDataset(
        args.list_val, args, max_sample=args.num_val, split='val')
    train_loader = make_data_loader(train_set, args)
    val_loader = make_data_loader(val_set, args)

    # Number of full batches per epoch, stored on args for downstream code.
    args.epoch_iters = len(train_set) // args.batch_size
    if args.mode == 'train':
        print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Bundle networks + loss into one module and replicate it across GPUs.
    model = NetWrapper(networks, criterion)
    gpu_ids = range(args.num_gpus)
    model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    model.to(args.device)

    # The optimizer is built from the raw networks, not the wrapped module.
    optimizer = create_optimizer(networks, args)

    # Per-split metric logs, appended to by train/evaluate.
    history = {
        'train': {key: [] for key in ('epoch', 'err')},
        'val': {key: [] for key in ('epoch', 'err', 'sdr', 'sir', 'sar')},
    }

    # Initial evaluation pass runs regardless of mode.
    evaluate(model, val_loader, history, 0, args)
    if args.mode == 'eval':
        print('Evaluation Done!')
        return

    # Epochs are 1-based.
    for epoch in range(1, args.num_epoch + 1):
        train(model, train_loader, optimizer, history, epoch, args)

        # Only evaluate and checkpoint on the configured epoch interval.
        if epoch % args.eval_epoch != 0:
            continue
        evaluate(model, val_loader, history, epoch, args)
        checkpoint(networks, history, epoch, args)

    print('Training Done!')
Esempio n. 2
0
def main(args):
    """Build the six grounding/separation networks, then evaluate or train.

    All configuration is read from ``args``; ``args.epoch_iters`` and
    ``args.testing`` are written back onto it. When ``args.mode == 'eval'``
    a single ``evaluate`` call (epoch 0) is made and the function returns.
    Otherwise it trains for ``args.num_epoch`` epochs; on every epoch that is
    a multiple of ``args.eval_epoch`` it evaluates (with ``args.testing`` set
    to True for the duration of the call) and checkpoints, and on every epoch
    listed in ``args.lr_steps`` it calls ``adjust_learning_rate``.
    """
    # Sub-networks and loss; construction order matches the tuple order below.
    factory = ModelBuilder()
    sound_ground_net = factory.build_sound_ground(
        arch=args.arch_sound_ground,
        weights=args.weights_sound_ground,
    )
    frame_ground_net = factory.build_frame_ground(
        arch=args.arch_frame_ground,
        pool_type=args.img_pool,
        weights=args.weights_frame_ground,
    )
    sound_net = factory.build_sound(
        arch=args.arch_sound,
        fc_dim=args.num_channels,
        weights=args.weights_sound,
    )
    frame_net = factory.build_frame(
        arch=args.arch_frame,
        fc_dim=args.num_channels,
        pool_type=args.img_pool,
        weights=args.weights_frame,
    )
    synthesizer_net = factory.build_synthesizer(
        arch=args.arch_synthesizer,
        fc_dim=args.num_channels,
        weights=args.weights_synthesizer,
    )
    grounding_net = factory.build_grounding(
        arch=args.arch_grounding,
        weights=args.weights_grounding,
    )
    networks = (
        sound_ground_net,
        frame_ground_net,
        sound_net,
        frame_net,
        synthesizer_net,
        grounding_net,
    )
    criterion = factory.build_criterion(arch=args.loss)

    # Datasets; the validation split name comes from args.split.
    train_set = MUSICMixDataset(args.list_train, args, split='train')
    val_set = MUSICMixDataset(
        args.list_val, args, max_sample=args.num_val, split=args.split)

    # Training batches are shuffled and the last partial batch is dropped;
    # validation keeps its order and every sample.
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=int(args.workers),
        drop_last=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        drop_last=False,
    )

    # Number of full batches per epoch, stored on args for downstream code.
    args.epoch_iters = len(train_set) // args.batch_size
    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Bundle networks + loss into one module and replicate it across GPUs.
    model = NetWrapper(networks, criterion)
    gpu_ids = range(args.num_gpus)
    model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    model.to(args.device)

    # The optimizer is built from the raw networks, not the wrapped module.
    optimizer = create_optimizer(networks, args)

    # Per-split metric logs, appended to by train/evaluate.
    history = {
        'train': {key: [] for key in ('epoch', 'err')},
        'val': {key: [] for key in ('epoch', 'err', 'sdr', 'sir', 'sar')},
    }

    # Evaluation-only run: flag testing, evaluate once, and stop.
    if args.mode == 'eval':
        args.testing = True
        evaluate(model, val_loader, history, 0, args)
        print('Evaluation Done!')
        return

    # Epochs are 1-based.
    for epoch in range(1, args.num_epoch + 1):
        train(model, train_loader, optimizer, history, epoch, args)

        is_eval_epoch = epoch % args.eval_epoch == 0
        if is_eval_epoch:
            # The testing flag is raised only while evaluate() runs.
            args.testing = True
            evaluate(model, val_loader, history, epoch, args)
            args.testing = False
            checkpoint(networks, history, epoch, args)

        # Scheduled learning-rate drop.
        if epoch in args.lr_steps:
            adjust_learning_rate(optimizer, args)

    print('Training Done!')
Esempio n. 3
0
def main(args):
    """Build the sound/frame/synthesizer networks and run an evaluation pass.

    All configuration is read from ``args``. Only the validation dataset and
    loader are constructed here. ``evaluate`` is called once with epoch 0,
    and when ``args.mode == 'eval'`` the function returns right after it.
    """
    # Sub-networks and loss; construction order is kept as sound -> frame
    # -> synthesizer -> criterion.
    factory = ModelBuilder()
    sound_net = factory.build_sound(
        arch=args.arch_sound,
        fc_dim=args.num_channels,
        weights=args.weights_sound,
    )
    frame_net = factory.build_frame(
        arch=args.arch_frame,
        fc_dim=args.num_channels,
        pool_type=args.img_pool,
        weights=args.weights_frame,
    )
    synthesizer_net = factory.build_synthesizer(
        arch=args.arch_synthesizer,
        fc_dim=args.num_channels,
        weights=args.weights_synthesizer,
    )
    networks = (sound_net, frame_net, synthesizer_net)
    criterion = factory.build_criterion(arch=args.loss)

    # NOTE: no training dataset/loader is built in this variant, and
    # args.epoch_iters is not set; only the validation split is loaded.
    val_set = MUSICMixDataset(
        args.list_val, args, max_sample=args.num_val, split='val')
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        drop_last=False,
    )

    # Bundle networks + loss into one module and replicate it across GPUs.
    model = NetWrapper(networks, criterion)
    gpu_ids = range(args.num_gpus)
    model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    model.to(args.device)

    # The optimizer is still created from the raw networks.
    optimizer = create_optimizer(networks, args)

    # Per-split metric logs, appended to by evaluate.
    history = {
        'train': {key: [] for key in ('epoch', 'err')},
        'val': {key: [] for key in ('epoch', 'err', 'sdr', 'sir', 'sar')},
    }

    # Evaluation pass runs regardless of mode.
    evaluate(model, val_loader, history, 0, args)
    if args.mode == 'eval':
        print('Evaluation Done!')
        return