示例#1
0
def main(args):
    """Distributed (DDP) entry point: build sound/frame/motion nets, then
    either run a single evaluation pass or train for ``args.num_epoch`` epochs.

    Expects to be launched with one process per GPU (``args.local_rank`` set).
    """
    builder = ModelBuilder(args)
    net_sound = builder.build_sound(arch=args.arch_sound, fc_dim=args.num_channels,
                                    weights=args.weights_sound)
    net_frame = builder.build_frame(arch=args.arch_frame, fc_dim=args.num_channels,
                                    pool_type=args.img_pool, weights=args.weights_frame)
    net_motion = builder.build_pretrained_3Dresnet_50(args.motion_path)

    nets = (net_sound, net_frame, net_motion)
    crit = builder.build_criterion(arch=args.loss)

    dataset_train = Music21_dataset(args, split='train')
    dataset_val = Music21_dataset(args, split='val')
    # The sampler shards the dataset across ranks; shuffle must stay False on
    # the loader because the sampler owns the ordering.
    sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=8,
        drop_last=True,
        pin_memory=True)

    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        drop_last=False,
        pin_memory=True)

    # Iterations per epoch as seen by each rank.
    args.epoch_iters = len(dataset_train) // (args.batch_size * args.num_gpus)
    optimizer = _create_optimizer(nets, args)
    netwrapper = NetWrapper(nets, crit)
    netwrapper = nn.parallel.DistributedDataParallel(netwrapper.cuda(), device_ids=[args.local_rank],
                                                     output_device=args.local_rank, find_unused_parameters=True)
    # Count only trainable parameters.
    total = sum(p.nelement() for p in netwrapper.parameters() if p.requires_grad)

    print("Number of parameter: %.2fM" % (total / 1e6))
    if args.reuse != 'None':
        # Remap tensors saved on cuda:0 onto this rank's own device.
        map_location = {'cuda:%d' % 0: 'cuda:%d' % args.local_rank}
        netwrapper.load_state_dict(torch.load(args.reuse, map_location=map_location))

    history = {
        'train': {'epoch': [], 'err': []},
        'val': {'epoch': [], 'err': [], 'sdr': [], 'sir': [], 'sar': []}}

    if args.mode == 'eval':
        evaluate(netwrapper, loader_val, history, 0, args)
        print('Evaluation Done!')
        return 0

    for epoch in range(1, args.num_epoch + 1):
        # Bug fix: without set_epoch() DistributedSampler yields the same
        # shard ordering every epoch, so data is never reshuffled.
        sampler.set_epoch(epoch)
        train(netwrapper, loader_train, optimizer, history, epoch, args)
        evaluate(netwrapper, loader_val, history, epoch, args)
        # Only rank 0 writes checkpoints (plain `if` instead of the previous
        # side-effecting conditional expression).
        if args.local_rank == 0:
            _checkpoint2(netwrapper, history, epoch, args)
        if epoch in args.lr_steps:
            _adjust_learning_rate(optimizer, args)

    print('Training Done!')
示例#2
0
def main(args):
    """Build the grounding + separation networks, wrap them for multi-GPU
    execution, then either evaluate once or run the full training loop.
    """
    # Instantiate every sub-network from the configured architectures.
    model_builder = ModelBuilder()
    net_sound_ground = model_builder.build_sound_ground(
        arch=args.arch_sound_ground, weights=args.weights_sound_ground)
    net_frame_ground = model_builder.build_frame_ground(
        arch=args.arch_frame_ground,
        pool_type=args.img_pool,
        weights=args.weights_frame_ground)
    net_sound = model_builder.build_sound(
        arch=args.arch_sound,
        fc_dim=args.num_channels,
        weights=args.weights_sound)
    net_frame = model_builder.build_frame(
        arch=args.arch_frame,
        fc_dim=args.num_channels,
        pool_type=args.img_pool,
        weights=args.weights_frame)
    net_synthesizer = model_builder.build_synthesizer(
        arch=args.arch_synthesizer,
        fc_dim=args.num_channels,
        weights=args.weights_synthesizer)
    net_grounding = model_builder.build_grounding(
        arch=args.arch_grounding, weights=args.weights_grounding)
    nets = (net_sound_ground, net_frame_ground, net_sound, net_frame,
            net_synthesizer, net_grounding)
    crit = model_builder.build_criterion(arch=args.loss)

    # Data: shuffled training split, deterministic validation split.
    dataset_train = MUSICMixDataset(args.list_train, args, split='train')
    dataset_val = MUSICMixDataset(
        args.list_val, args, max_sample=args.num_val, split=args.split)

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=int(args.workers),
        drop_last=True)
    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        drop_last=False)
    args.epoch_iters = len(dataset_train) // args.batch_size
    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Single wrapper module spread across the available GPUs.
    wrapper = NetWrapper(nets, crit)
    wrapper = torch.nn.DataParallel(wrapper, device_ids=range(args.num_gpus))
    wrapper.to(args.device)

    # Optimizer is built over the raw (unwrapped) networks.
    optimizer = create_optimizer(nets, args)

    # Per-split metric history, filled in by train()/evaluate().
    history = {
        'train': {'epoch': [], 'err': []},
        'val': {'epoch': [], 'err': [], 'sdr': [], 'sir': [], 'sar': []},
    }

    # Evaluation-only mode: one validation pass, then exit.
    if args.mode == 'eval':
        args.testing = True
        evaluate(wrapper, loader_val, history, 0, args)
        print('Evaluation Done!')
        return

    # Main training loop.
    for epoch in range(1, args.num_epoch + 1):
        train(wrapper, loader_train, optimizer, history, epoch, args)

        if epoch % args.eval_epoch == 0:
            # Periodic validation (testing flag toggled around the pass),
            # followed by a checkpoint of the raw networks.
            args.testing = True
            evaluate(wrapper, loader_val, history, epoch, args)
            args.testing = False
            checkpoint(nets, history, epoch, args)

        # Scheduled learning-rate drops.
        if epoch in args.lr_steps:
            adjust_learning_rate(optimizer, args)

    print('Training Done!')
示例#3
0
def main(args):
    """Train (or resume) the sound/frame separation networks.

    Relies on a module-level ``checkpoint`` object: when it is not None, the
    net wrapper, optimizer, metric history and starting epoch are all restored
    from it.
    """
    # Network Builders
    builder = ModelBuilder()
    net_sound = builder.build_sound(
        arch=args.arch_sound,
        fc_dim=args.num_channels,
    )
    net_frame = builder.build_frame(
        arch=args.arch_frame,
        fc_dim=args.num_channels,
        pool_type=args.img_pool,
    )

    nets = (net_sound, net_frame)
    crit = builder.build_criterion(arch=args.loss)

    # Dataset and Loader
    dataset_train = RAWDataset(args.list_train, args, split='train')
    dataset_val = STFTDataset(args.list_val,
                              args,
                              max_sample=args.num_val,
                              split='val')

    loader_train = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=int(args.workers),
                                               drop_last=True)
    loader_val = torch.utils.data.DataLoader(dataset_val,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=2,
                                             drop_last=False)
    args.epoch_iters = len(loader_train)
    # disp_iter is reinterpreted here as "display N times per epoch".
    args.disp_iter = len(loader_train) // args.disp_iter
    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Wrap networks
    netWrapper = NetWrapper(nets, crit, checkpoint)
    netWrapper = torch.nn.DataParallel(netWrapper,
                                       device_ids=range(args.num_gpus))
    netWrapper.to(args.device)

    # Set up optimizer (restores its state from the checkpoint when resuming).
    optimizer = create_optimizer(nets, args, checkpoint)

    # History of performance, restored from the checkpoint when resuming.
    history = {
        'train': {
            'epoch': [],
            'err': []
        },
        'val': {
            'epoch': [],
            'err': [],
            'sdr': [],
            'sir': [],
            'sar': []
        }
    } if checkpoint is None else checkpoint['history']

    # Imported inside the function, presumably to avoid a circular import at
    # module load time -- TODO confirm. Hoisted out of the epoch loop so the
    # import does not run once per checkpoint.
    from epoch import train, evaluate
    from utils import save_checkpoint

    # Training loop -- resume from the checkpointed epoch when available.
    init_epoch = 1 if checkpoint is None else checkpoint['epoch']
    print('Training start at ', init_epoch)
    # Bug fix: the loop previously started at 1 even when resuming, which
    # re-ran epochs that had already been trained.
    for epoch in range(init_epoch, args.num_epoch + 1):
        train(netWrapper, loader_train, optimizer, history, epoch, args)

        # Evaluation and visualization
        if epoch % args.eval_epoch == 0:
            evaluate(netWrapper, loader_val, history, epoch, args)

            # checkpointing
            save_checkpoint(nets, history, optimizer, epoch, args)

        # drop learning rate
        if epoch in args.lr_steps:
            adjust_learning_rate(optimizer, args)

    print('Training Done!')
def main(args):
    """Entry point: build the two-stream separation model, optionally restore
    a checkpoint from ``args.ckpt``, then evaluate once or train until
    ``args.num_epoch``.
    """
    # Sub-networks.
    net_builder = ModelBuilder()
    net_sound = net_builder.build_sound(
        arch=args.arch_sound,
        fc_dim=args.num_channels,
        weights=args.weights_sound)
    net_frame = net_builder.build_frame(
        arch=args.arch_frame,
        fc_dim=args.num_channels,
        pool_type=args.img_pool,
        weights=args.weights_frame)
    nets = (net_sound, net_frame)
    crit = net_builder.build_criterion(arch=args.loss)

    # Data: shuffled training split, deterministic validation split.
    dataset_train = MUSICMixDataset(args.list_train, args, split='train')
    dataset_val = MUSICMixDataset(
        args.list_val, args, max_sample=args.num_val, split='val')

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=int(args.workers),
        drop_last=True)
    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False)
    args.epoch_iters = len(dataset_train) // args.batch_size
    print('1 Epoch = {} iters'.format(args.epoch_iters))

    optimizer = create_optimizer(nets, args)

    # Per-split metric history, filled in by train()/evaluate().
    history = {
        'train': {'epoch': [], 'err': []},
        'val': {'epoch': [], 'err': [], 'sdr': [], 'sir': [], 'sar': []}}

    # Resume logic: prefer an on-disk checkpoint when one exists. In train
    # mode the "latest" snapshot is used and the epoch counter is restored.
    first_epoch = 1
    model_name = args.ckpt + '/checkpoint.pth'
    if os.path.exists(model_name):
        if args.mode == 'eval':
            nets = load_checkpoint_from_train(nets, model_name)
        elif args.mode == 'train':
            model_name = args.ckpt + '/checkpoint_latest.pth'
            nets, optimizer, first_epoch, history = load_checkpoint(
                nets, optimizer, history, model_name)
            print("Loading from previous checkpoint.")

    # Multi-GPU wrapper around the (possibly restored) networks.
    model = NetWrapper(nets, crit)
    model = torch.nn.DataParallel(
        model, device_ids=range(args.num_gpus)).cuda()
    model.to(args.device)

    # Evaluation-only mode: one validation pass, then exit.
    if args.mode == 'eval':
        evaluate(model, loader_val, history, 0, args)
        print('Evaluation Done!')
        return

    for epoch in range(first_epoch, args.num_epoch + 1):
        train(model, loader_train, optimizer, history, epoch, args)

        # Scheduled learning-rate drops.
        if epoch in args.lr_steps:
            adjust_learning_rate(optimizer, args)

        # Periodic validation followed by a checkpoint of the raw networks.
        if epoch % args.eval_epoch == 0:
            evaluate(model, loader_val, history, epoch, args)
            checkpoint(nets, optimizer, history, epoch, args)

    print('Training Done!')
def main(args):
    """Train or evaluate the three-network model (sound, frame, avol).

    Seeds all RNGs for reproducibility, builds the networks, restores a
    checkpoint from ``args.ckpt`` when one exists (falling back to a
    pretrained appearance+sound model via ``args.weights_model``), then
    either runs a single evaluation pass or trains until ``args.num_epoch``.
    """
    # Network Builders
    # Fix every RNG so runs are reproducible.
    # NOTE(review): with DataParallel over several GPUs this seeds only the
    # current device; torch.cuda.manual_seed_all would cover all -- confirm.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    builder = ModelBuilder()
    net_sound = builder.build_sound(
        arch=args.arch_sound,
        input_channel=1,
        output_channel=args.num_channels,
        fc_dim=args.num_channels,
        weights=args.weights_sound)
    net_frame = builder.build_frame(
        arch=args.arch_frame,
        fc_dim=args.num_channels,
        pool_type=args.img_pool,
        weights=args.weights_frame)
    # NOTE(review): the avol net is initialized from args.weights_frame, not
    # a dedicated args.weights_avol -- looks like a copy-paste slip; confirm
    # this is intentional before relying on it.
    net_avol = builder.build_avol(
        arch=args.arch_avol,
        fc_dim=args.num_channels,
        weights=args.weights_frame)

    # Two losses: BCE for localization, configurable loss for separation.
    crit_loc = nn.BCELoss()
    crit_sep = builder.build_criterion(arch=args.loss)

    # Dataset and Loader
    dataset_train = MUSICMixDataset(
        args.list_train, args, split='train')
    dataset_val = MUSICMixDataset(
        args.list_val, args, max_sample=args.num_val, split='val')

    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=int(args.workers),
        drop_last=True)
    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        drop_last=False)
    args.epoch_iters = len(dataset_train) // args.batch_size
    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Set up optimizer over the three raw (unwrapped) networks.
    optimizer = create_optimizer(net_sound, net_frame, net_avol, args)

    # History of performance metrics, appended to by train()/evaluate().
    history = {
        'train': {'epoch': [], 'err': [], 'err_loc': [], 'err_sep': [], 'acc': []},
        'val': {'epoch': [], 'err': [],  'err_loc': [], 'err_sep': [], 'acc': [], 'sdr': [], 'sir': [], 'sar': []}}


    # Training loop
    # Load from pretrained models: an existing checkpoint takes priority;
    # otherwise fall back to the appearance+sound weights for a fresh run.
    start_epoch = 1
    model_name = args.ckpt + '/checkpoint.pth'
    if os.path.exists(model_name):
        if args.mode == 'eval':
            net_sound, net_frame, net_avol = load_checkpoint_from_train(net_sound, net_frame, net_avol, model_name)
        elif args.mode == 'train':
            # Resume training from the most recent snapshot, restoring the
            # optimizer state, epoch counter and metric history as well.
            model_name = args.ckpt + '/checkpoint_latest.pth'
            net_sound, net_frame, net_avol, optimizer, start_epoch, history = load_checkpoint(net_sound, net_frame, net_avol, optimizer, history, model_name)
            print("Loading from previous checkpoint.")
    
    else:
        if args.mode == 'train' and start_epoch==1 and os.path.exists(args.weights_model):
            net_sound, net_frame = load_sep(net_sound, net_frame, args.weights_model)
            print("Loading from appearance + sound checkpoint.")
    
    # Wrap networks -- each net gets its own DataParallel wrapper so they can
    # be forwarded independently by train()/evaluate().
    netWrapper1 = NetWrapper1(net_sound)
    netWrapper1 = torch.nn.DataParallel(netWrapper1, device_ids=range(args.num_gpus)).cuda()
    netWrapper1.to(args.device)

    netWrapper2 = NetWrapper2(net_frame)
    netWrapper2 = torch.nn.DataParallel(netWrapper2, device_ids=range(args.num_gpus)).cuda()
    netWrapper2.to(args.device)

    netWrapper3 = NetWrapper3(net_avol)
    netWrapper3 = torch.nn.DataParallel(netWrapper3, device_ids=range(args.num_gpus)).cuda()
    netWrapper3.to(args.device)


    # Eval mode: a single validation pass, then exit.
    #evaluate(crit_loc, crit_sep, netWrapper1, netWrapper2, netWrapper3, loader_val, history, 0, args)
    if args.mode == 'eval':
        evaluate(crit_loc, crit_sep, netWrapper1, netWrapper2, netWrapper3, loader_val, history, 0, args)
        print('Evaluation Done!')
        return

        
    for epoch in range(start_epoch, args.num_epoch + 1):    
        train(crit_loc, crit_sep, netWrapper1, netWrapper2, netWrapper3, loader_train, optimizer, history, epoch, args)

        # drop learning rate on the scheduled epochs
        if epoch in args.lr_steps:
            adjust_learning_rate(optimizer, args)

        ## Evaluation and visualization
        if epoch % args.eval_epoch == 0:
            evaluate(crit_loc, crit_sep, netWrapper1, netWrapper2, netWrapper3, loader_val, history, epoch, args)
            # checkpointing of the raw networks + optimizer + history
            checkpoint(net_sound, net_frame, net_avol, optimizer, history, epoch, args)

    print('Training Done!')
示例#6
0
def main(args):
    """Build the Minus/Plus separation networks, then either run a single
    evaluation pass or train with periodic evaluation and checkpointing.
    """
    # Sub-networks for the "Minus" (M) and "Plus" (P) stages.
    model_builder = ModelBuilder()
    net_sound_M = model_builder.build_sound(arch=args.arch_sound,
                                            fc_dim=args.num_channels,
                                            weights=args.weights_sound_M)
    net_frame_M = model_builder.build_frame(arch=args.arch_frame,
                                            fc_dim=args.num_channels,
                                            pool_type=args.img_pool,
                                            weights=args.weights_frame_M)

    # The P-stage sound net takes a 2-channel input and a single fc unit.
    net_sound_P = model_builder.build_sound(
        input_nc=2,
        arch=args.arch_sound,
        fc_dim=1,
        weights=args.weights_sound_P)

    nets = (net_sound_M, net_frame_M, net_sound_P)
    crit = model_builder.build_criterion(arch=args.loss)

    # One wrapper module serves all three forward modes used by the
    # different training stages: 'Minus', 'Plus', 'Minus_Plus'.
    wrapper = NetWrapper(nets, crit, mode=args.forward_mode)
    wrapper = torch.nn.DataParallel(wrapper, device_ids=range(args.num_gpus))
    wrapper.to(args.device)

    # Data: shuffled training split, deterministic validation split.
    dataset_train = MUSICMixDataset(args.list_train, args, split='train')
    dataset_val = MUSICMixDataset(args.list_val,
                                  args,
                                  max_sample=args.num_val,
                                  split='val')

    loader_train = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=int(args.workers),
                                               drop_last=True)
    loader_val = torch.utils.data.DataLoader(dataset_val,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=2,
                                             drop_last=False)
    args.epoch_iters = len(dataset_train) // args.batch_size
    print('1 Epoch = {} iters'.format(args.epoch_iters))

    # The trainer owns the optimizer, checkpointing and LR schedule.
    optimizer = MP_Trainer.create_optimizer(nets, args)
    trainer = MP_Trainer(wrapper, optimizer, args)

    # Always evaluate once up front; in 'eval' mode that is all we do.
    trainer.evaluate(loader_val)
    if trainer.mode == 'eval':
        print('Evaluation Done!')
        return

    for epoch in range(1, args.num_epoch + 1):
        trainer.epoch = epoch
        trainer.train(loader_train)

        # Periodic validation followed by a checkpoint.
        if epoch % args.eval_epoch == 0:
            trainer.evaluate(loader_val)
            trainer.checkpoint()

        # Scheduled learning-rate drops.
        if epoch in args.lr_steps:
            trainer.adjust_learning_rate()

    print('Training Done!')
    trainer.writer.close()
示例#7
0
def main(args):
    """Evaluation-oriented entry point: builds the sound/frame/synthesizer
    networks and runs a validation pass (the training path is commented out
    and appears disabled in this variant).
    """
    # Network Builders
    builder = ModelBuilder()
    net_sound = builder.build_sound(arch=args.arch_sound,
                                    fc_dim=args.num_channels,
                                    weights=args.weights_sound)
    net_frame = builder.build_frame(arch=args.arch_frame,
                                    fc_dim=args.num_channels,
                                    pool_type=args.img_pool,
                                    weights=args.weights_frame)
    net_synthesizer = builder.build_synthesizer(
        arch=args.arch_synthesizer,
        fc_dim=args.num_channels,
        weights=args.weights_synthesizer)
    nets = (net_sound, net_frame, net_synthesizer)
    crit = builder.build_criterion(arch=args.loss)

    # Dataset and Loader -- only the validation split is ever loaded here;
    # the training dataset/loader construction is deliberately commented out.
    # dataset_train = MUSICMixDataset(
    #     args.list_train, args, split='train')
    dataset_val = MUSICMixDataset(args.list_val,
                                  args,
                                  max_sample=args.num_val,
                                  split='val')

    # loader_train = torch.utils.data.DataLoader(
    #     dataset_train,
    #     batch_size=args.batch_size,
    #     shuffle=True,
    #     num_workers=int(args.workers),
    #     drop_last=True)
    loader_val = torch.utils.data.DataLoader(dataset_val,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=2,
                                             drop_last=False)
    # args.epoch_iters = len(dataset_train) // args.batch_size
    # print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Wrap networks in a single module and spread it over the GPUs.
    netWrapper = NetWrapper(nets, crit)
    netWrapper = torch.nn.DataParallel(netWrapper,
                                       device_ids=range(args.num_gpus))
    netWrapper.to(args.device)

    # Set up optimizer over the raw (unwrapped) networks.
    optimizer = create_optimizer(nets, args)

    # History of performance metrics, appended to by evaluate().
    history = {
        'train': {
            'epoch': [],
            'err': []
        },
        'val': {
            'epoch': [],
            'err': [],
            'sdr': [],
            'sir': [],
            'sar': []
        }
    }

    # Eval mode -- note the validation pass runs unconditionally, before the
    # mode check; 'eval' mode merely returns early afterwards.
    evaluate(netWrapper, loader_val, history, 0, args)
    if args.mode == 'eval':
        print('Evaluation Done!')
        return