def main(args):
    """Build the sound/vision/synthesizer models, then evaluate or train.

    An initial evaluation pass (epoch 0) always runs. When ``args.mode`` is
    ``'eval'`` the function returns right after it; otherwise it trains for
    ``args.num_epoch`` epochs, re-evaluating and checkpointing every
    ``args.eval_epoch`` epochs.

    Args:
        args: parsed run configuration (architectures, weight paths, data
            list files, batch size, device, GPU count, epoch settings, ...).
            ``args.epoch_iters`` is written as a side effect.
    """
    # ---- networks -------------------------------------------------------
    model_builder = ModelBuilder()
    width = args.num_channels
    sound_net = model_builder.build_sound(
        arch=args.arch_sound, fc_dim=width, weights=args.weights_sound)
    vision_net = model_builder.build_vision(
        arch=args.arch_vision, fc_dim=width, weights=args.weights_vision)
    synth_net = model_builder.build_synthesizer(
        arch=args.arch_synthesizer, fc_dim=width,
        weights=args.weights_synthesizer)
    networks = (sound_net, vision_net, synth_net)
    criterion = model_builder.build_criterion(arch=args.loss)

    # ---- data -----------------------------------------------------------
    train_set = MUSICMixDataset(args.list_train, args, split='train')
    val_set = MUSICMixDataset(
        args.list_val, args, max_sample=args.num_val, split='val')
    train_loader, val_loader = (
        make_data_loader(ds, args) for ds in (train_set, val_set))

    args.epoch_iters = len(train_set) // args.batch_size
    if args.mode == 'train':
        print('1 Epoch = {} iters'.format(args.epoch_iters))

    # ---- model wrapper & optimizer -------------------------------------
    model = torch.nn.DataParallel(
        NetWrapper(networks, criterion), device_ids=range(args.num_gpus))
    model.to(args.device)
    optimizer = create_optimizer(networks, args)

    # Per-split metric history, passed to train/evaluate/checkpoint.
    history = {
        'train': {key: [] for key in ('epoch', 'err')},
        'val': {key: [] for key in ('epoch', 'err', 'sdr', 'sir', 'sar')},
    }

    # Baseline evaluation before any training step.
    evaluate(model, val_loader, history, 0, args)
    if args.mode == 'eval':
        print('Evaluation Done!')
        return

    for epoch in range(1, args.num_epoch + 1):
        train(model, train_loader, optimizer, history, epoch, args)
        if epoch % args.eval_epoch != 0:
            continue
        # Periodic evaluation, then checkpoint with the updated history.
        evaluate(model, val_loader, history, epoch, args)
        checkpoint(networks, history, epoch, args)

    print('Training Done!')
def main(args):
    """Build the grounding + separation models, then evaluate or train.

    In ``'eval'`` mode a single evaluation pass (epoch 0) runs with
    ``args.testing = True`` and the function returns. In any other mode the
    model is trained for ``args.num_epoch`` epochs with ``args.testing`` set
    to False. Every ``args.eval_epoch`` epochs it is evaluated (with the
    flag temporarily True) and checkpointed. The learning rate is adjusted
    at each epoch listed in ``args.lr_steps``.

    Args:
        args: parsed run configuration. The object is mutated:
            ``args.epoch_iters`` and ``args.testing`` are written here.
    """
    # Network Builders
    builder = ModelBuilder()
    net_sound_ground = builder.build_sound_ground(
        arch=args.arch_sound_ground, weights=args.weights_sound_ground)
    net_frame_ground = builder.build_frame_ground(
        arch=args.arch_frame_ground, pool_type=args.img_pool,
        weights=args.weights_frame_ground)
    net_sound = builder.build_sound(
        arch=args.arch_sound, fc_dim=args.num_channels,
        weights=args.weights_sound)
    net_frame = builder.build_frame(
        arch=args.arch_frame, fc_dim=args.num_channels,
        pool_type=args.img_pool, weights=args.weights_frame)
    net_synthesizer = builder.build_synthesizer(
        arch=args.arch_synthesizer, fc_dim=args.num_channels,
        weights=args.weights_synthesizer)
    net_grounding = builder.build_grounding(
        arch=args.arch_grounding, weights=args.weights_grounding)
    nets = (net_sound_ground, net_frame_ground, net_sound, net_frame,
            net_synthesizer, net_grounding)
    crit = builder.build_criterion(arch=args.loss)

    # Dataset and Loader
    dataset_train = MUSICMixDataset(args.list_train, args, split='train')
    dataset_val = MUSICMixDataset(
        args.list_val, args, max_sample=args.num_val, split=args.split)
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=int(args.workers),
        drop_last=True)
    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        drop_last=False)

    args.epoch_iters = len(dataset_train) // args.batch_size
    if args.mode != 'eval':
        # The per-epoch iteration count only matters when we actually train.
        print('1 Epoch = {} iters'.format(args.epoch_iters))

    # Wrap networks
    netWrapper = NetWrapper(nets, crit)
    netWrapper = torch.nn.DataParallel(
        netWrapper, device_ids=range(args.num_gpus))
    netWrapper.to(args.device)

    # Set up optimizer
    optimizer = create_optimizer(nets, args)

    # History of performance
    history = {
        'train': {'epoch': [], 'err': []},
        'val': {'epoch': [], 'err': [], 'sdr': [], 'sir': [], 'sar': []},
    }

    # Eval mode
    if args.mode == 'eval':
        args.testing = True
        evaluate(netWrapper, loader_val, history, 0, args)
        print('Evaluation Done!')
        return

    # Set the training-phase flag explicitly so epochs before the first
    # evaluation see the same value as epochs after it. The same ``args``
    # object is handed to the datasets above, so the flag is presumably
    # read there and/or in NetWrapper — NOTE(review): confirm.
    args.testing = False

    # Training loop
    for epoch in range(1, args.num_epoch + 1):
        train(netWrapper, loader_train, optimizer, history, epoch, args)

        # Evaluation and visualization
        if epoch % args.eval_epoch == 0:
            args.testing = True
            try:
                evaluate(netWrapper, loader_val, history, epoch, args)
            finally:
                # Always return to training mode, even if evaluation fails.
                args.testing = False

            # checkpointing
            checkpoint(nets, history, epoch, args)

        # drop learning rate
        if epoch in args.lr_steps:
            adjust_learning_rate(optimizer, args)

    print('Training Done!')
def main(args):
    """Build the sound/frame/synthesizer models and run one evaluation pass.

    This entry point has no training loop. It evaluates once (epoch 0) on
    the ``'val'`` split. When ``args.mode == 'eval'`` it then prints a
    completion message. In any other mode it returns silently after that
    evaluation.

    Args:
        args: parsed run configuration (architectures, weight paths,
            validation list, batch size, device, GPU count, ...).
    """
    # Network Builders
    builder = ModelBuilder()
    net_sound = builder.build_sound(
        arch=args.arch_sound, fc_dim=args.num_channels,
        weights=args.weights_sound)
    net_frame = builder.build_frame(
        arch=args.arch_frame, fc_dim=args.num_channels,
        pool_type=args.img_pool, weights=args.weights_frame)
    net_synthesizer = builder.build_synthesizer(
        arch=args.arch_synthesizer, fc_dim=args.num_channels,
        weights=args.weights_synthesizer)
    nets = (net_sound, net_frame, net_synthesizer)
    crit = builder.build_criterion(arch=args.loss)

    # Dataset and Loader: only the validation split is used here.
    dataset_val = MUSICMixDataset(
        args.list_val, args, max_sample=args.num_val, split='val')
    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        drop_last=False)

    # Wrap networks
    netWrapper = NetWrapper(nets, crit)
    netWrapper = torch.nn.DataParallel(
        netWrapper, device_ids=range(args.num_gpus))
    netWrapper.to(args.device)

    # NOTE(review): the optimizer is never used below. The call is kept only
    # in case create_optimizer has side effects other code relies on —
    # confirm and drop it if not.
    optimizer = create_optimizer(nets, args)  # noqa: F841

    # History of performance
    history = {
        'train': {'epoch': [], 'err': []},
        'val': {'epoch': [], 'err': [], 'sdr': [], 'sir': [], 'sar': []},
    }

    # Single evaluation pass; there is no training path in this variant.
    evaluate(netWrapper, loader_val, history, 0, args)
    if args.mode == 'eval':
        print('Evaluation Done!')
        return