def __init__(self,
                 style_image_path,
                 path,
                 style_weight=1.,
                 content_weight=1.,
                 input_size=(256, 256),
                 optimizer=None,
                 sub_path=None,
                 test_photo_path=None,
                 load_model=None,
                 load_only=False):
        """Set up a style transfer model for training and/or prediction.

        The generator network is always created first: built from scratch
        when `load_model` is None, otherwise restored via
        `build_from_path(load_model)`.  Unless `load_only` is set, the style
        image, the loss functions, the optimizer and the output directory
        are prepared as well.

        Args:
            style_image_path: path to the style image.
            path: base path; it is concatenated directly with `sub_path`,
                so it is expected to end with a path separator.
            style_weight: weight for the style loss.
            content_weight: weight for the content loss.
            input_size: size used for training (only used when the network
                is built from scratch).
            optimizer: optimizer to use; when None, `optimizers.Adam` is
                stored as given here (not instantiated in this method).
            sub_path: sub directory (relative to `path`) used for this
                model.  When None, the first free five-digit directory
                ('00000/', '00001/', ...) under `path` is created and used.
            test_photo_path: path to a folder of test photos to run on
                after training.
            load_model: path to a pretrained model.
            load_only: when True, only the network is set up; no training
                attributes are assigned and no directory is created.
        """

        if load_model is None:
            self.build_generator_net(input_size=input_size)
        else:
            self.build_from_path(load_model)

        # Prediction-only instances need nothing beyond the network itself.
        if load_only:
            return

        self.style_image_path = style_image_path
        self.style_image = load_and_process_img(style_image_path,
                                                resize=False)

        self.style_weight = style_weight
        self.content_weight = content_weight

        self.loss, self.cl, self.sl = get_loss_func(
            self.style_image,
            style_weight=style_weight,
            content_weight=content_weight)

        self.optimizer = optimizers.Adam if optimizer is None else optimizer

        self.path = path

        if sub_path is None:
            # Pick the first numbered sub directory that does not exist yet.
            index = 0
            while os.path.exists(self.path + '{:05d}/'.format(index)):
                index += 1
            self.sub_path = '{:05d}/'.format(index)
            os.mkdir(self.path + self.sub_path)
        else:
            # Previously a caller-supplied sub_path was silently dropped,
            # leaving self.sub_path undefined.  Honour it and make sure the
            # directory is there, mirroring the auto-generated branch.
            self.sub_path = sub_path
            target_dir = self.path + self.sub_path
            if not os.path.isdir(target_dir):
                os.makedirs(target_dir)

        self.test_photo_path = test_photo_path
Ejemplo n.º 2
0
def train(args):
    """Train a model on DCASE2017 Task4 features with periodic evaluation.

    Every 1000 iterations the model is evaluated with the 'test' and
    'evaluate' loaders; statistics are appended to the statistics container
    and predictions are written under the predictions directory.  From
    iteration 50000 on, a checkpoint (iteration, model and optimizer state)
    is written every 10000 iterations.  Training stops when the iteration
    counter equals `stop_iteration` or the training loader is exhausted.

    Args:
      args: namespace providing the attributes read below: dataset_dir,
        workspace, holdout_fold, model_type, pretrained_checkpoint_path,
        freeze_base, loss_type, augmentation, learning_rate, batch_size,
        few_shots, random_seed, resume_iteration, stop_iteration, cuda,
        mini_data, filename.
    """

    # Arguments & parameters
    dataset_dir = args.dataset_dir
    workspace = args.workspace
    holdout_fold = args.holdout_fold
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    freeze_base = args.freeze_base
    loss_type = args.loss_type
    augmentation = args.augmentation
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    few_shots = args.few_shots
    random_seed = args.random_seed
    resume_iteration = args.resume_iteration
    stop_iteration = args.stop_iteration
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    mini_data = args.mini_data
    filename = args.filename

    loss_func = get_loss_func(loss_type)
    pretrain = bool(pretrained_checkpoint_path)
    num_workers = 16
    use_mixup = 'mixup' in augmentation

    # Paths
    prefix = 'minidata_' if mini_data else ''
    run_name = '{}{}'.format(prefix, filename)

    train_hdf5_path = os.path.join(workspace, 'features',
        '{}training.waveform.h5'.format(prefix))

    # Only the training file name carries the mini-data prefix; the original
    # code called .format(prefix) on these two literals, which have no
    # placeholder, so the resulting paths are unchanged.
    test_hdf5_path = os.path.join(workspace, 'features',
        'testing.waveform.h5')

    evaluate_hdf5_path = os.path.join(workspace, 'features',
        'evaluation.waveform.h5')

    test_reference_csv_path = os.path.join(dataset_dir, 'metadata',
        'groundtruth_strong_label_testing_set.csv')

    evaluate_reference_csv_path = os.path.join(dataset_dir, 'metadata',
        'groundtruth_strong_label_evaluation_set.csv')

    # Shared directory components.  The historical layouts differ in where
    # 'batch_size=...' sits, so both orderings are kept to preserve paths.
    head_parts = [
        'holdout_fold={}'.format(holdout_fold), model_type,
        'pretrain={}'.format(pretrain), 'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation)]
    tail_parts = [
        'few_shots={}'.format(few_shots),
        'random_seed={}'.format(random_seed),
        'freeze_base={}'.format(freeze_base)]
    batch_part = 'batch_size={}'.format(batch_size)
    batch_first_parts = head_parts + [batch_part] + tail_parts
    batch_last_parts = head_parts + tail_parts + [batch_part]

    # Checkpoints use the bare filename (no mini-data prefix).
    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename,
        *batch_first_parts)
    create_folder(checkpoints_dir)

    tmp_submission_path = os.path.join(workspace, '_tmp_submission',
        run_name, *(batch_first_parts + ['_submission.csv']))
    create_folder(os.path.dirname(tmp_submission_path))

    statistics_path = os.path.join(workspace, 'statistics',
        run_name, *(batch_first_parts + ['statistics.pickle']))
    create_folder(os.path.dirname(statistics_path))

    predictions_dir = os.path.join(workspace, 'predictions',
        run_name, *batch_last_parts)
    create_folder(predictions_dir)

    logs_dir = os.path.join(workspace, 'logs', run_name, *batch_last_parts)
    create_logging(logs_dir, 'w')
    logging.info(args)

    if 'cuda' in device:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model
    # NOTE(review): eval() on a command-line supplied string executes
    # arbitrary code; acceptable only while model_type is trusted input.
    # sample_rate, window_size, ... are not defined in this function and
    # must be provided at module level.
    Model = eval(model_type)
    model = Model(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
        classes_num)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    if pretrain:
        logging.info('Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)

    resume_checkpoint = None
    if resume_iteration:
        resume_checkpoint_path = os.path.join(checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))
        logging.info('Load resume model from {}'.format(resume_checkpoint_path))
        # Load onto the CPU first so that a checkpoint written on a GPU
        # machine can also be resumed on a CPU-only host; the model is
        # moved to the target device further below.
        resume_checkpoint = torch.load(resume_checkpoint_path,
            map_location='cpu')
        model.load_state_dict(resume_checkpoint['model'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = resume_checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in device:
        model.to(device)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate,
        betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)

    # Checkpoints written below contain the optimizer state; restore it so
    # that the Adam moment estimates are not reset when resuming.
    if resume_checkpoint is not None and 'optimizer' in resume_checkpoint:
        optimizer.load_state_dict(resume_checkpoint['optimizer'])

    train_dataset = DCASE2017Task4Dataset(hdf5_path=train_hdf5_path)
    test_dataset = DCASE2017Task4Dataset(hdf5_path=test_hdf5_path)
    evaluate_dataset = DCASE2017Task4Dataset(hdf5_path=evaluate_hdf5_path)

    # With mixup the sampler is asked for twice the batch size.
    train_sampler = TrainSampler(
        hdf5_path=train_hdf5_path,
        batch_size=batch_size * 2 if use_mixup else batch_size,
        few_shots=few_shots,
        random_seed=random_seed)

    test_sampler = EvaluateSampler(dataset_size=len(test_dataset), batch_size=batch_size)
    evaluate_sampler = EvaluateSampler(dataset_size=len(evaluate_dataset), batch_size=batch_size)

    collector = Collator()

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
        batch_sampler=train_sampler, collate_fn=collector,
        num_workers=num_workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
        batch_sampler=test_sampler, collate_fn=collector,
        num_workers=num_workers, pin_memory=True)

    evaluate_loader = torch.utils.data.DataLoader(dataset=evaluate_dataset,
        batch_sampler=evaluate_sampler, collate_fn=collector,
        num_workers=num_workers, pin_memory=True)

    if use_mixup:
        mixup_augmenter = Mixup(mixup_alpha=1.)

    # Evaluator
    test_evaluator = Evaluator(
        model=model,
        generator=test_loader)

    evaluate_evaluator = Evaluator(
        model=model,
        generator=evaluate_loader)

    train_bgn_time = time.time()

    # Train on mini batches
    for batch_data_dict in train_loader:

        # Evaluate every 1000 iterations, but not on the very iteration a
        # run was resumed from.  The explicit truthiness test keeps this
        # safe when resume_iteration is None.
        just_resumed = (bool(resume_iteration) and resume_iteration > 0
            and iteration == resume_iteration)

        if iteration % 1000 == 0 and not just_resumed:
            logging.info('------------------------------------')
            logging.info('Iteration: {}'.format(iteration))

            train_fin_time = time.time()

            for (data_type, evaluator, reference_csv_path) in [
                ('test', test_evaluator, test_reference_csv_path),
                ('evaluate', evaluate_evaluator, evaluate_reference_csv_path)]:

                logging.info('{} statistics:'.format(data_type))

                (statistics, predictions) = evaluator.evaluate(
                    reference_csv_path, tmp_submission_path)

                statistics_container.append(data_type, iteration, statistics)

                prediction_path = os.path.join(predictions_dir,
                    '{}_iterations.prediction.{}.h5'.format(iteration, data_type))

                write_out_prediction(predictions, prediction_path)

            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'Train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(train_time, validate_time))

            train_bgn_time = time.time()

        # Save model
        if iteration % 10000 == 0 and iteration > 49999:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict()}

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        if use_mixup:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(len(batch_data_dict['waveform']))

        # Move data to GPU
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device)

        # Train
        model.train()

        if use_mixup:
            batch_output_dict = model(batch_data_dict['waveform'], batch_data_dict['mixup_lambda'])
            batch_target_dict = {'target': do_mixup(batch_data_dict['target'], batch_data_dict['mixup_lambda'])}
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            batch_target_dict = {'target': batch_data_dict['target']}

        # loss
        loss = loss_func(batch_output_dict, batch_target_dict)
        print(iteration, loss)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Stop learning
        if iteration == stop_iteration:
            break

        iteration += 1
Ejemplo n.º 3
0
def train(args):
    """Train an AudioSet tagging model.

    Every 2000 iterations (after the resume point) the model is evaluated
    with the balanced-train and eval index loaders.  A checkpoint is written
    every 100000 iterations, and additionally whenever the balanced or the
    test mAP exceeds the best value seen so far in this run.

    Args:
      args: namespace providing the attributes read below:
        workspace: str
        data_type: 'balanced_train' | 'full_train'
        sample_rate: int
        window_size: int
        hop_size: int
        mel_bins: int
        fmin: int
        fmax: int
        model_type: str
        loss_type: 'clip_bce'
        balanced: 'none' | 'balanced' | 'alternate'
        augmentation: 'none' | 'mixup'
        batch_size: int
        learning_rate: float
        resume_iteration: int
        early_stop: int
        cuda: bool
        filename: str

    Raises:
      ValueError: if `balanced` is not one of the supported values.
    """

    # Arguments & parameters
    workspace = args.workspace
    data_type = args.data_type
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available(
    ) else torch.device('cpu')
    filename = args.filename

    num_workers = 128
    prefetch_factor = 4
    use_mixup = 'mixup' in augmentation

    #os.environ["MASTER_ADDR"] = "localhost"
    #os.environ["MASTER_PORT"] = "12355"
    #dist.init_process_group("nccl", rank=rank, world_size=args.world_size)

    clip_samples = config.clip_samples
    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)

    # Paths
    black_list_csv = None

    indexes_dir = os.path.join(workspace, 'hdf5s', 'indexes')
    train_indexes_hdf5_path = os.path.join(indexes_dir,
                                           '{}.h5'.format(data_type))
    eval_bal_indexes_hdf5_path = os.path.join(indexes_dir,
                                              'balanced_train.h5')
    eval_test_indexes_hdf5_path = os.path.join(indexes_dir, 'eval.h5')

    # Directory components shared by checkpoints, statistics and logs.
    run_parts = [
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size)]

    # Each run writes its checkpoints into its own timestamped directory.
    checkpoints_dir = os.path.join(
        workspace, 'checkpoints', filename,
        *(run_parts + [datetime.datetime.now().strftime("%d%m%Y_%H%M%S")]))

    #if rank == 0:
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(
        workspace, 'statistics', filename,
        *(run_parts + ['statistics.pkl']))

    #if rank == 0:
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename, *run_parts)

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if 'cuda' in str(device):
        logging.info('Using GPU.')
        device = 'cuda'
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')
        device = 'cpu'

    # Model
    # NOTE(review): eval() on a command-line supplied string executes
    # arbitrary code; acceptable only while model_type is trusted input.
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate,
                  window_size=window_size,
                  hop_size=hop_size,
                  mel_bins=mel_bins,
                  fmin=fmin,
                  fmax=fmax,
                  classes_num=classes_num)

    params_num = count_parameters(model)
    # flops_num = count_flops(model, clip_samples)
    logging.info('Parameters num: {}'.format(params_num))
    # logging.info('Flops num: {:.3f} G'.format(flops_num / 1e9))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and return a waveform and a target.
    dataset = AudioSetDataset(sample_rate=sample_rate)

    # Train sampler
    if balanced == 'none':
        Sampler = TrainSampler
    elif balanced == 'balanced':
        Sampler = BalancedTrainSampler
    elif balanced == 'alternate':
        Sampler = AlternateTrainSampler
    else:
        # Previously an unknown value surfaced as a NameError on `Sampler`.
        raise ValueError(
            "Incorrect balanced type: {!r} (expected 'none', 'balanced' "
            "or 'alternate')".format(balanced))

    # With mixup the sampler is asked for twice the batch size.
    train_sampler = Sampler(indexes_hdf5_path=train_indexes_hdf5_path,
                            batch_size=batch_size * 2 if use_mixup
                            else batch_size,
                            black_list_csv=black_list_csv)

    # Evaluate sampler
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path,
        batch_size=2 * batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path,
        batch_size=2 * batch_size)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=collate_fn,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               prefetch_factor=prefetch_factor)

    eval_bal_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_bal_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=prefetch_factor)

    eval_test_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_test_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=prefetch_factor)

    if use_mixup:
        mixup_augmenter = Mixup(mixup_alpha=1.)

    # Evaluator
    evaluator = Evaluator(model=model)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    train_bgn_time = time.time()

    # Resume training
    if resume_iteration > 0:
        # NOTE(review): checkpoints are written into a timestamped sub
        # directory (see checkpoints_dir above) while this path has no
        # timestamp component, so the checkpoint has to be placed here
        # manually before resuming - confirm the intended layout.
        resume_checkpoint_path = os.path.join(
            workspace, 'checkpoints', filename,
            *(run_parts + ['{}_iterations.pth'.format(resume_iteration)]))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        # Load onto the CPU first so that a checkpoint written on a GPU
        # machine can also be resumed on a CPU-only host; the model is
        # moved to the target device further below.
        checkpoint = torch.load(resume_checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']

    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in str(device):
        model.to(device)
        #model = model.cuda(rank)

    #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    #print([(s[0], s[1].is_cuda) for s in model.named_parameters()])

    def save_checkpoint(current_iteration, name_template):
        """Write model and sampler state to `checkpoints_dir`."""
        state = {
            'iteration': current_iteration,
            'model': model.module.state_dict(),
            'sampler': train_sampler.state_dict()
        }
        target_path = os.path.join(
            checkpoints_dir, name_template.format(current_iteration))
        torch.save(state, target_path)
        logging.info('Model saved to {}'.format(target_path))

    time1 = time.time()

    # Best mAP values seen so far in this run; an "improved" checkpoint is
    # only written when one of them is exceeded.
    best_bal_map = 0.0
    best_test_map = 0.0
    save_bal_model = False
    save_test_model = False

    for batch_data_dict in train_loader:
        # batch_data_dict: {
        #     'audio_name': (batch_size [*2 if mixup],),
        #     'waveform': (batch_size [*2 if mixup], clip_samples),
        #     'target': (batch_size [*2 if mixup], classes_num),
        #     (ifexist) 'mixup_lambda': (batch_size * 2,)}

        # Evaluate
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or (iteration == -1):
            train_fin_time = time.time()

            bal_statistics = evaluator.evaluate(eval_bal_loader)
            test_statistics = evaluator.evaluate(eval_test_loader)

            bal_map = np.mean(bal_statistics['average_precision'])
            test_map = np.mean(test_statistics['average_precision'])

            logging.info('Validate bal mAP: {:.3f}'.format(bal_map))
            logging.info('Validate test mAP: {:.3f}'.format(test_map))

            # The reference values used to stay at 0.0 forever, which made
            # every evaluation count as an improvement.  Track the best
            # value so only real improvements trigger a checkpoint.
            save_bal_model = bal_map > best_bal_map
            if save_bal_model:
                best_bal_map = bal_map

            save_test_model = test_map > best_test_map
            if save_test_model:
                best_test_map = test_map

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % 100000 == 0:
            save_checkpoint(iteration, '{}_iterations.pth')

        if save_bal_model:
            save_checkpoint(iteration, '{}_iterations_bal.pth')
            save_bal_model = False

        if save_test_model:
            save_checkpoint(iteration, '{}_iterations_test.pth')
            save_test_model = False

        # Mixup lambda
        if use_mixup:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                batch_size=len(batch_data_dict['waveform']))

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Forward
        model.train()

        if use_mixup:
            # batch_output_dict:
            #     {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])

            # batch_target_dict: {'target': (batch_size, classes_num)}
            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
        else:
            # batch_output_dict:
            #     {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'], None)

            # batch_target_dict: {'target': (batch_size, classes_num)}
            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward
        loss.backward()
        print(loss)

        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'\
                .format(iteration, time.time() - time1))
            time1 = time.time()

        # Stop learning
        if iteration == early_stop:
            break

        iteration += 1
Ejemplo n.º 4
0
def train(args):
    """Train a piano transcription system.

    Args:
      workspace: str, directory of your workspace
      model_type: str, e.g. 'Regressonset_regressoffset_frame_velocity_CRNN'
      loss_type: str, e.g. 'regress_onset_offset_frame_velocity_bce'
      augmentation: str, e.g. 'none'
      batch_size: int
      learning_rate: float
      reduce_iteration: int
      resume_iteration: int
      early_stop: int
      device: 'cuda' | 'cpu'
      mini_data: bool
    """

    # Arugments & parameters
    workspace = args.workspace
    model_type = args.model_type
    loss_type = args.loss_type
    augmentation = args.augmentation
    max_note_shift = args.max_note_shift
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    reduce_iteration = args.reduce_iteration
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available(
    ) else torch.device('cpu')
    mini_data = args.mini_data
    filename = args.filename

    sample_rate = config.sample_rate
    segment_seconds = config.segment_seconds
    hop_seconds = config.hop_seconds
    segment_samples = int(segment_seconds * sample_rate)
    frames_per_second = config.frames_per_second
    classes_num = config.classes_num
    num_workers = 8

    # Loss function
    loss_func = get_loss_func(loss_type)

    # Paths
    hdf5s_dir = os.path.join(workspace, 'hdf5s', 'maestro')

    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename,
                                   model_type,
                                   'loss_type={}'.format(loss_type),
                                   'augmentation={}'.format(augmentation),
                                   'max_note_shift={}'.format(max_note_shift),
                                   'batch_size={}'.format(batch_size))
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace, 'statistics', filename,
                                   model_type,
                                   'loss_type={}'.format(loss_type),
                                   'augmentation={}'.format(augmentation),
                                   'max_note_shift={}'.format(max_note_shift),
                                   'batch_size={}'.format(batch_size),
                                   'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename, model_type,
                            'loss_type={}'.format(loss_type),
                            'augmentation={}'.format(augmentation),
                            'max_note_shift={}'.format(max_note_shift),
                            'batch_size={}'.format(batch_size))
    create_folder(logs_dir)

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if 'cuda' in str(device):
        logging.info('Using GPU.')
        device = 'cuda'
    else:
        logging.info('Using CPU.')
        device = 'cpu'

    # Model
    Model = eval(model_type)
    model = Model(frames_per_second=frames_per_second, classes_num=classes_num)

    if augmentation == 'none':
        augmentor = None
    elif augmentation == 'aug':
        augmentor = Augmentor()
    else:
        raise Exception('Incorrect argumentation!')

    # Dataset
    train_dataset = MaestroDataset(hdf5s_dir=hdf5s_dir,
                                   segment_seconds=segment_seconds,
                                   frames_per_second=frames_per_second,
                                   max_note_shift=max_note_shift,
                                   augmentor=augmentor)

    evaluate_dataset = MaestroDataset(hdf5s_dir=hdf5s_dir,
                                      segment_seconds=segment_seconds,
                                      frames_per_second=frames_per_second,
                                      max_note_shift=0)

    # Sampler for training
    train_sampler = Sampler(hdf5s_dir=hdf5s_dir,
                            split='train',
                            segment_seconds=segment_seconds,
                            hop_seconds=hop_seconds,
                            batch_size=batch_size,
                            mini_data=mini_data)

    # Sampler for evaluation
    evaluate_train_sampler = TestSampler(hdf5s_dir=hdf5s_dir,
                                         split='train',
                                         segment_seconds=segment_seconds,
                                         hop_seconds=hop_seconds,
                                         batch_size=batch_size,
                                         mini_data=mini_data)

    evaluate_validate_sampler = TestSampler(hdf5s_dir=hdf5s_dir,
                                            split='validation',
                                            segment_seconds=segment_seconds,
                                            hop_seconds=hop_seconds,
                                            batch_size=batch_size,
                                            mini_data=mini_data)

    evaluate_test_sampler = TestSampler(hdf5s_dir=hdf5s_dir,
                                        split='test',
                                        segment_seconds=segment_seconds,
                                        hop_seconds=hop_seconds,
                                        batch_size=batch_size,
                                        mini_data=mini_data)

    # Dataloader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=collate_fn,
                                               num_workers=num_workers,
                                               pin_memory=True)

    evaluate_train_loader = torch.utils.data.DataLoader(
        dataset=evaluate_dataset,
        batch_sampler=evaluate_train_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True)

    validate_loader = torch.utils.data.DataLoader(
        dataset=evaluate_dataset,
        batch_sampler=evaluate_validate_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        dataset=evaluate_dataset,
        batch_sampler=evaluate_test_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True)

    # Evaluator
    evaluator = SegmentEvaluator(model, batch_size)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    # Resume training
    if resume_iteration > 0:
        resume_checkpoint_path = os.path.join(
            workspace, 'checkpoints', filename, model_type,
            'loss_type={}'.format(loss_type),
            'augmentation={}'.format(augmentation),
            'batch_size={}'.format(batch_size),
            '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        checkpoint = torch.load(resume_checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']

    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in str(device):
        model.to(device)

    train_bgn_time = time.time()

    for batch_data_dict in train_loader:

        # Evaluation
        if iteration % 5000 == 0:  # and iteration > 0:
            logging.info('------------------------------------')
            logging.info('Iteration: {}'.format(iteration))

            train_fin_time = time.time()

            evaluate_train_statistics = evaluator.evaluate(
                evaluate_train_loader)
            validate_statistics = evaluator.evaluate(validate_loader)
            test_statistics = evaluator.evaluate(test_loader)

            logging.info(
                '    Train statistics: {}'.format(evaluate_train_statistics))
            logging.info(
                '    Validation statistics: {}'.format(validate_statistics))
            logging.info('    Test statistics: {}'.format(test_statistics))

            statistics_container.append(iteration,
                                        evaluate_train_statistics,
                                        data_type='train')
            statistics_container.append(iteration,
                                        validate_statistics,
                                        data_type='validation')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info('Train time: {:.3f} s, validate time: {:.3f} s'
                         ''.format(train_time, validate_time))

            train_bgn_time = time.time()

        # Save model
        if iteration % 20000 == 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Reduce learning rate
        if iteration % reduce_iteration == 0 and iteration > 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.9

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        model.train()
        batch_output_dict = model(batch_data_dict['waveform'])

        loss = loss_func(model, batch_output_dict, batch_data_dict)

        print(iteration, loss)

        # Backward
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # Stop learning
        if iteration == early_stop:
            break

        iteration += 1
# Example no. 5
# 0
def train(args):
    """Train an AudioSet tagging model.

    Builds the model class named by ``args.model_type``, trains it with
    Adam (AMSGrad) on batches from the train loader, evaluates on the
    'balanced_train' and 'eval' index files at iteration 0 and every 2000
    iterations past the resume point, and writes a checkpoint every 20000
    iterations.

    Args:
      args: namespace object; the following attributes are read:
        workspace: str, root directory for indexes, black list and outputs
        data_type: 'balanced_train' | 'unbalanced_train'
        window_size, hop_size, mel_bins, fmin, fmax: forwarded to the model
            constructor and embedded in the output directory names
        model_type: str, name of the model class to instantiate
        loss_type: 'bce'
        balanced: forwarded to get_train_sampler
        augmentation: str, mixup is enabled when it contains 'mixup'
        batch_size: int
        learning_rate: float
        resume_iteration: int, a value > 0 resumes from that checkpoint
        early_stop: int, training stops once this many iterations are done
        cuda: bool, use the GPU when one is available
        filename: str, used as a path component of every output directory
    """

    # Arguments & parameters
    workspace = args.workspace
    data_type = args.data_type
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available(
    ) else torch.device('cpu')
    filename = args.filename

    num_workers = 8
    sample_rate = config.sample_rate
    clip_samples = config.clip_samples
    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)

    # Paths
    black_list_csv = os.path.join(workspace, 'black_list',
                                  'dcase2017task4.csv')

    train_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                           '{}.h5'.format(data_type))

    eval_bal_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                              'balanced_train.h5')

    eval_test_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                               'eval.h5')

    # Directory components shared by the checkpoint, statistics and log
    # locations; built once so the three trees cannot drift apart.
    run_subdirs = (
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type),
        model_type,
        'loss_type={}'.format(loss_type),
        'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size),
    )

    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename,
                                   *run_subdirs)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace, 'statistics', filename,
                                   *run_subdirs, 'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename, *run_subdirs)

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    # From here on `device` is a plain string ('cuda' or 'cpu').
    if 'cuda' in str(device):
        logging.info('Using GPU.')
        device = 'cuda'
    else:
        logging.info('Using CPU.')
        device = 'cpu'

    # Model
    # NOTE(review): eval() executes whatever string is passed as model_type;
    # only safe while model_type comes from a trusted command line.
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate,
                  window_size=window_size,
                  hop_size=hop_size,
                  mel_bins=mel_bins,
                  fmin=fmin,
                  fmax=fmax,
                  classes_num=classes_num)

    params_num = count_parameters(model)
    # flops_num = count_flops(model, clip_samples)
    logging.info('Parameters num: {}'.format(params_num))
    # logging.info('Flops num: {:.3f} G'.format(flops_num / 1e9))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and return a waveform and a target.
    dataset = AudioSetDataset(clip_samples=clip_samples,
                              classes_num=classes_num)

    # Train sampler
    (train_sampler,
     train_collector) = get_train_sampler(balanced, augmentation,
                                          train_indexes_hdf5_path,
                                          black_list_csv, batch_size)

    # Evaluate sampler
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    eval_collector = Collator(mixup_alpha=None)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=train_collector,
                                               num_workers=num_workers,
                                               pin_memory=True)

    eval_bal_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_bal_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    eval_test_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_test_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    # Evaluator
    bal_evaluator = Evaluator(model=model, generator=eval_bal_loader)
    test_evaluator = Evaluator(model=model, generator=eval_test_loader)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    train_bgn_time = time.time()

    # Resume training
    resume_optimizer_state = None
    if resume_iteration > 0:
        # Checkpoints are written into checkpoints_dir below, so the resume
        # file is looked up in that same directory.
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        checkpoint = torch.load(resume_checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']
        # Kept aside until the model sits on its final device (see below).
        # .get() tolerates checkpoints written without optimizer state.
        resume_optimizer_state = checkpoint.get('optimizer')

    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in str(device):
        model.to(device)

    # Restore the optimizer only after the parameters have been moved:
    # load_state_dict() casts the saved moment buffers to the device of the
    # parameters at call time, so doing it earlier would leave them on CPU.
    if resume_optimizer_state is not None:
        optimizer.load_state_dict(resume_optimizer_state)

    time1 = time.time()

    for batch_data_dict in train_loader:
        # batch_data_dict: {
        #     'audio_name': (batch_size [*2 if mixup],),
        #     'waveform': (batch_size [*2 if mixup], clip_samples),
        #     'target': (batch_size [*2 if mixup], classes_num),
        #     (ifexist) 'mixup_lambda': (batch_size * 2,)}

        # Evaluate: at iteration 0, and every 2000 iterations strictly past
        # the resume point (the resume iteration itself is not re-evaluated).
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = bal_evaluator.evaluate()
            test_statistics = test_evaluator.evaluate()

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))

            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % 20000 == 0:
            checkpoint = {
                'iteration': iteration,
                # model is wrapped in DataParallel; save the inner module so
                # the keys match what load_state_dict expects on resume.
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Forward
        model.train()

        if 'mixup' in augmentation:
            # batch_output_dict: {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])

            # batch_target_dict: {'target': (batch_size, classes_num)}
            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
        else:
            # batch_output_dict: {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'], None)

            # batch_target_dict: {'target': (batch_size, classes_num)}
            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward
        loss.backward()
        print(loss)

        # Gradients are cleared right after the step, so every backward()
        # starts from zeroed gradients.
        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'\
                .format(iteration, time.time() - time1))
            time1 = time.time()

        iteration += 1

        # Stop learning
        if iteration == early_stop:
            break
# Example no. 6
# 0
                          audio_length=220500,
                          classes_num=7)
# Validation loader: batches of 16, loaded in the main process
# (num_workers=0) and reshuffled on every pass.
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=16,
                                         num_workers=0,
                                         shuffle=True)

# Alternative model kept for reference (currently disabled):
# model = Cnn14(sample_rate=44100, window_size=400, hop_size=160, mel_bins=64, fmin=50, fmax=14000, classes_num=7)
# J and Q are forwarded to Cnn14_scatter and recorded in `logs` below.
# NOTE(review): presumably the scattering-transform scale / filters-per-octave
# settings -- confirm against the Cnn14_scatter definition.
J = 6
Q = 8

# NOTE(review): audio_length here is 110250 while the dataset call just above
# passes audio_length=220500 -- confirm the mismatch is intentional.
model = Cnn14_scatter(classes_num=7, J=J, Q=Q, audio_length=110250)

# Requires a CUDA device; there is no CPU fallback in this script.
model.cuda()

loss_func = get_loss_func('clip_bce')
# Adam with the AMSGrad variant enabled and no weight decay.
optimizer = optim.Adam(model.parameters(),
                       lr=1e-3,
                       betas=(0.9, 0.999),
                       eps=1e-08,
                       weight_decay=0.,
                       amsgrad=True)

# Run bookkeeping; how these are consumed is outside this fragment.
max_epoch = 10
best_acc = 0

# Record the J/Q settings alongside whatever else is logged later.
logs = {}
logs['J'] = J
logs['Q'] = Q

# Free-form tag identifying this experiment run.
num_try = '1_scatter_fixed'
def train(args):
    """Train and evaluate.

    Trains the model class named by ``args.model_type`` on the training
    hdf5, evaluates on the 'test' and 'evaluate' sets every 1000 iterations
    past the resume point, and writes a checkpoint every 10000 iterations.

    Args:
      args: namespace object; the following attributes are read:
        dataset_dir: str
        workspace: str
        holdout_fold: '1'
        model_type: str, e.g., 'Cnn_9layers_Gru_FrameAtt'
        loss_type: str, e.g., 'clip_bce'
        augmentation: str, e.g., 'mixup'
        learning_rate: float
        batch_size: int
        resume_iteration: int, a non-zero value resumes from that checkpoint
        stop_iteration: int
        cuda: bool, use the GPU when one is available
        mini_data: bool, selects the 'minidata_'-prefixed train/test files
        filename: str, used as a path component of every output directory
    """

    # Arguments & parameters
    dataset_dir = args.dataset_dir
    workspace = args.workspace
    holdout_fold = args.holdout_fold
    model_type = args.model_type
    loss_type = args.loss_type
    augmentation = args.augmentation
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    resume_iteration = args.resume_iteration
    stop_iteration = args.stop_iteration
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    mini_data = args.mini_data
    filename = args.filename

    loss_func = get_loss_func(loss_type)
    num_workers = 8

    # Paths
    if mini_data:
        prefix = 'minidata_'
    else:
        prefix = ''

    train_hdf5_path = os.path.join(workspace, 'hdf5s',
                                   '{}training.h5'.format(prefix))

    test_hdf5_path = os.path.join(workspace, 'hdf5s',
                                  '{}testing.h5'.format(prefix))

    # The evaluation file name carries no prefix (the original applied
    # .format(prefix) to a string without a placeholder, which is a no-op).
    evaluate_hdf5_path = os.path.join(workspace, 'hdf5s', 'evaluation.h5')

    test_reference_csv_path = os.path.join(
        dataset_dir, 'metadata', 'groundtruth_strong_label_testing_set.csv')

    evaluate_reference_csv_path = os.path.join(
        dataset_dir, 'metadata', 'groundtruth_strong_label_evaluation_set.csv')

    # Directory components shared by all per-run output locations.
    run_name = '{}{}'.format(prefix, filename)
    run_subdirs = (
        'holdout_fold={}'.format(holdout_fold),
        'model_type={}'.format(model_type),
        'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size),
    )

    checkpoints_dir = os.path.join(workspace, 'checkpoints', run_name,
                                   *run_subdirs)
    create_folder(checkpoints_dir)

    tmp_submission_path = os.path.join(workspace, '_tmp_submission', run_name,
                                       *run_subdirs, '_submission.csv')
    create_folder(os.path.dirname(tmp_submission_path))

    statistics_path = os.path.join(workspace, 'statistics', run_name,
                                   *run_subdirs, 'statistics.pickle')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', run_name, *run_subdirs)
    create_logging(logs_dir, 'w')
    logging.info(args)

    if 'cuda' in device:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model
    assert model_type, 'Please specify model_type!'
    # NOTE(review): eval() executes whatever string is passed as model_type;
    # only safe while model_type comes from a trusted command line.
    Model = eval(model_type)
    # NOTE(review): sample_rate, window_size, hop_size, mel_bins, fmin, fmax
    # and classes_num are not defined in this function -- presumably
    # module-level names imported from the project config; confirm.
    model = Model(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
                  classes_num)

    # Statistics
    # Created before the resume branch: that branch calls
    # statistics_container.load_state_dict(), which previously ran before the
    # container existed and raised UnboundLocalError on every resume.
    statistics_container = StatisticsContainer(statistics_path)

    resume_checkpoint = None
    if resume_iteration:
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))
        logging.info(
            'Load resume model from {}'.format(resume_checkpoint_path))
        resume_checkpoint = torch.load(resume_checkpoint_path)
        model.load_state_dict(resume_checkpoint['model'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = resume_checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in device:
        model.to(device)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    # Checkpoints store the optimizer state; restore it on resume so the
    # Adam moment estimates continue instead of restarting from zero. The
    # key check tolerates checkpoints written without optimizer state.
    if resume_checkpoint is not None and 'optimizer' in resume_checkpoint:
        optimizer.load_state_dict(resume_checkpoint['optimizer'])

    # Dataset
    dataset = DCASE2017Task4Dataset()

    # Sampler
    # With mixup the sampler draws twice as many clips per batch.
    train_sampler = TrainSampler(hdf5_path=train_hdf5_path,
                                 batch_size=batch_size *
                                 2 if 'mixup' in augmentation else batch_size)

    test_sampler = TestSampler(hdf5_path=test_hdf5_path, batch_size=batch_size)

    evaluate_sampler = TestSampler(hdf5_path=evaluate_hdf5_path,
                                   batch_size=batch_size)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=collate_fn,
                                               num_workers=num_workers,
                                               pin_memory=True)

    test_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_sampler=test_sampler,
                                              collate_fn=collate_fn,
                                              num_workers=num_workers,
                                              pin_memory=True)

    evaluate_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=evaluate_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True)

    if 'mixup' in augmentation:
        mixup_augmenter = Mixup(mixup_alpha=1.)

    # Evaluator
    evaluator = Evaluator(model=model)

    train_bgn_time = time.time()

    # Train on mini batches
    for batch_data_dict in train_loader:

        # Evaluate every 1000 iterations strictly past the resume point
        # (iteration 0 is therefore never evaluated).
        if (iteration % 1000 == 0
                and iteration > resume_iteration):  # or (iteration == 0):

            logging.info('------------------------------------')
            logging.info('Iteration: {}'.format(iteration))

            train_fin_time = time.time()

            for (data_type, data_loader, reference_csv_path) in [
                ('test', test_loader, test_reference_csv_path),
                ('evaluate', evaluate_loader, evaluate_reference_csv_path)
            ]:

                # Calculate statistics
                (statistics, _) = evaluator.evaluate(data_loader,
                                                     reference_csv_path,
                                                     tmp_submission_path)

                logging.info('{} statistics:'.format(data_type))
                logging.info('    Clipwise mAP: {:.3f}'.format(
                    np.mean(statistics['clipwise_ap'])))
                logging.info('    Framewise mAP: {:.3f}'.format(
                    np.mean(statistics['framewise_ap'])))
                logging.info('    {}'.format(
                    statistics['sed_metrics']['overall']['error_rate']))

                statistics_container.append(data_type, iteration, statistics)

            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info('Train time: {:.3f} s, validate time: {:.3f} s'
                         ''.format(train_time, validate_time))

            train_bgn_time = time.time()

        # Save model
        if iteration % 10000 == 0:
            checkpoint = {
                'iteration': iteration,
                # model is wrapped in DataParallel; save the inner module so
                # the keys match what load_state_dict expects on resume.
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        if 'mixup' in augmentation:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                batch_size=len(batch_data_dict['waveform']))

        # Move data to GPU
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Train
        model.train()

        if 'mixup' in augmentation:
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])
            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            batch_target_dict = {'target': batch_data_dict['target']}

        # loss
        loss = loss_func(batch_output_dict, batch_target_dict)
        print(iteration, loss)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Stop learning
        if iteration == stop_iteration:
            break

        iteration += 1
# Example no. 8
# 0
def train(args):
    """Run a model over the validation sampler and export its predictions.

    Despite the name, no optimisation happens here: the model is built,
    optionally initialised from a pretrained checkpoint, evaluated once on
    the validation loader, and its 'clipwise_output2' scores are merged with
    a metadata CSV and written to a hard-coded output CSV.

    Args:
      args: namespace object; the attributes read are dataset_dir,
        workspace, holdout_fold, model_type, pretrained_checkpoint_path,
        freeze_base, loss_type, augmentation, learning_rate, batch_size,
        resume_iteration, stop_iteration, cuda and filename.
    """

    # Arguments & parameters
    dataset_dir = args.dataset_dir
    workspace = args.workspace
    holdout_fold = args.holdout_fold
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    freeze_base = args.freeze_base
    loss_type = args.loss_type
    augmentation = args.augmentation
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    resume_iteration = args.resume_iteration
    stop_iteration = args.stop_iteration
    use_cuda = args.cuda and torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'
    filename = args.filename
    num_workers = 8

    loss_func = get_loss_func(loss_type)
    pretrain = bool(pretrained_checkpoint_path)

    # TODO: restore the path to the full set of preprocessed data
    hdf5_path = os.path.join(workspace, 'features_ramas',
                             'waveform_meta_test.h5')
    # hdf5_path = os.path.join(workspace, 'features', 'waveform.h5')

    checkpoints_dir = os.path.join(workspace, 'checkpoints')
    create_folder(checkpoints_dir)

    # Sub-directories identifying this configuration; the statistics file and
    # the log directory are both keyed on the same components.
    config_parts = [
        'holdout_fold={}'.format(holdout_fold),
        model_type,
        'pretrain={}'.format(pretrain),
        'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size),
        'freeze_base={}'.format(freeze_base),
    ]

    statistics_path = os.path.join(workspace, 'statistics', filename,
                                   *config_parts, 'statistics.pickle')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename, *config_parts)
    create_logging(logs_dir, 'w')
    logging.info(args)

    if 'cuda' not in device:
        logging.info('Using CPU. Set --cuda flag to use GPU.')
    else:
        logging.info('Using GPU.')

    # Model
    Model = eval(model_type)

    # TODO: the number of classes (4) is hard-coded here, which is not good
    # NOTE(review): sample_rate, window_size, hop_size, mel_bins, fmin and
    # fmax are not defined in this function -- presumably module-level names.
    model = Model(sample_rate, window_size, hop_size, mel_bins, fmin, fmax, 4,
                  freeze_base)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    if pretrain:
        logging.info(
            'Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    # Validation data: one clip per batch.
    dataset = GtzanDataset()
    validate_sampler = EvaluateSampler(hdf5_path=hdf5_path,
                                       holdout_fold=holdout_fold,
                                       batch_size=1)
    validate_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=validate_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True)

    if 'cuda' in device:
        model.to(device)

    # Evaluator
    evaluator = Evaluator(model=model)

    torch.manual_seed(729720439)

    statistics, output_dict = evaluator.evaluate(validate_loader)
    for metric in ('precision', 'recall', 'f_score'):
        logging.info('Validate {}: {:.3f}'.format(metric, statistics[metric]))
    logging.info('\n' + str(statistics['cm']))

    # Keep only the metadata rows whose label is one of the four emotions.
    df_audio = pd.read_csv(
        '/home/den/DATASETS/AUDIO/preprocessed/ramas/meta_test.csv')
    wanted_labels = ['ang', 'hap', 'sad', 'neu']
    df_audio = df_audio[df_audio['cur_label'].isin(wanted_labels)]

    # One row per evaluated clip: its name plus the per-class scores.
    score_columns = ['hap', 'ang', 'sad', 'neu']
    temp_df = pd.DataFrame(columns=['cur_name'] + score_columns)
    temp_df['cur_name'] = output_dict['audio_name']
    temp_df.loc[:, score_columns] = np.vstack(
        output_dict['clipwise_output2'])

    merge_df = pd.merge(df_audio, temp_df, on='cur_name', how='inner')
    merge_df.to_csv(
        '/home/den/Documents/diploma/panns/panns_ramas_inference.csv',
        index=False)
def train(args):
    """Fine-tune a model on the balanced training index with periodic evaluation.

    The model class named by ``args.model_type`` is instantiated once, the
    pretrained checkpoint (if any) is loaded into that same instance, and the
    instance is then trained.  Evaluation runs every 2000 iterations (and at
    iteration 0), a checkpoint is written every 20000 iterations, and training
    stops after ``early_stop`` iterations.

    Args:
        args: namespace providing ``window_size``, ``hop_size``, ``mel_bins``,
            ``fmin``, ``fmax``, ``model_type``, ``pretrained_checkpoint_path``,
            ``cuda``, ``workspace_input``, ``workspace_output`` and
            ``filename``.
    """

    # Arguments & parameters
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    workspace_input = args.workspace_input
    workspace_output = args.workspace_output
    filename = args.filename
    # NOTE(review): the previous code read args.freeze_base and then
    # unconditionally overwrote it with True; that effective value is kept.
    freeze_base = True
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    sample_rate = config.sample_rate
    classes_num = config.classes_num
    clip_samples = config.clip_samples
    pretrain = bool(pretrained_checkpoint_path)

    # Fixed hyper-parameters (originally copied from main.py)
    data_type = 'balanced_train'
    loss_type = 'clip_bce'
    balanced = 'balanced'
    augmentation = 'none'
    batch_size = 1
    learning_rate = 1e-3
    resume_iteration = 0
    early_stop = 100000
    num_workers = 8
    loss_func = get_loss_func(loss_type)
    black_list_csv = 'metadata/black_list/groundtruth_weak_label_evaluation_set.csv'

    # Index files. Built with os.path.join so that workspace_input works with
    # or without a trailing path separator.
    indexes_dir = os.path.join(workspace_input, 'hdf5s', 'indexes')
    train_indexes_hdf5_path = os.path.join(indexes_dir,
                                           '{}.h5'.format(data_type))
    eval_bal_indexes_hdf5_path = os.path.join(indexes_dir,
                                              'balanced_train.h5')
    eval_test_indexes_hdf5_path = os.path.join(indexes_dir, 'eval.h5')

    # Path components shared by the checkpoint, statistics and log locations.
    run_parts = (
        filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type),
        model_type,
        'loss_type={}'.format(loss_type),
        'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size),
    )

    checkpoints_dir = os.path.join(workspace_output, 'checkpoints', *run_parts)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace_output, 'statistics',
                                   *(run_parts + ('statistics.pkl', )))
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace_output, 'logs', *run_parts)
    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if 'cuda' in device:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU.')

    # Model
    # NOTE(review): eval() on a command-line supplied string executes
    # arbitrary code; consider an explicit name -> class mapping instead.
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate,
                  window_size=window_size,
                  hop_size=hop_size,
                  mel_bins=mel_bins,
                  fmin=fmin,
                  fmax=fmax,
                  classes_num=classes_num,
                  freeze_base=freeze_base)

    # Load the pretrained weights into the instance that is actually trained.
    # Previously they were loaded into a first model object that was then
    # replaced by a freshly constructed one, so they were never used.
    if pretrain:
        logging.info(
            'Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)
        print('Load pretrained model successfully!')

    params_num = count_parameters(model)
    logging.info('Parameters num: {}'.format(params_num))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and return a waveform and a target.
    dataset = AudioSetDataset(clip_samples=clip_samples,
                              classes_num=classes_num)

    # Train sampler
    (train_sampler, train_collector) = get_train_sampler(
        balanced, augmentation, train_indexes_hdf5_path, black_list_csv,
        batch_size)

    # Evaluate samplers
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    eval_collector = Collator(mixup_alpha=None)

    # Data loaders
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=train_collector,
                                               num_workers=num_workers,
                                               pin_memory=True)

    eval_bal_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_bal_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    eval_test_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_test_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    # Evaluators hold the unwrapped module, exactly as before; the
    # DataParallel wrapper created below shares the same parameters.
    bal_evaluator = Evaluator(model=model, generator=eval_bal_loader)
    test_evaluator = Evaluator(model=model, generator=eval_test_loader)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    train_bgn_time = time.time()

    # Resume training. With resume_iteration fixed to 0 above this branch is
    # currently never taken.
    if resume_iteration > 0:
        resume_checkpoint_path = os.path.join(
            workspace_input, 'checkpoints',
            *(run_parts + ('{}_iterations.pth'.format(resume_iteration), )))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        if torch.cuda.is_available():
            checkpoint = torch.load(resume_checkpoint_path)
        else:
            checkpoint = torch.load(resume_checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in device:
        model.to(device)

    time1 = time.time()

    for batch_data_dict in train_loader:
        # batch_data_dict: {
        #     'audio_name': (batch_size [*2 if mixup],),
        #     'waveform': (batch_size [*2 if mixup], clip_samples),
        #     'target': (batch_size [*2 if mixup], classes_num),
        #     (ifexist) 'mixup_lambda': (batch_size * 2,)}

        # Evaluate
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = bal_evaluator.evaluate()
            test_statistics = test_evaluator.evaluate()

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))

            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % 20000 == 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Forward
        model.train()
        if 'mixup' in augmentation:
            # batch_output_dict: {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])

            # batch_target_dict: {'target': (batch_size, classes_num)}
            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
        else:
            # batch_output_dict: {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'], None)

            # batch_target_dict: {'target': (batch_size, classes_num)}
            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward; gradients are cleared after the step so that the next
        # iteration starts from zero.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'
                  .format(iteration, time.time() - time1))
            time1 = time.time()

        iteration += 1

        # Stop learning
        if iteration == early_stop:
            break
# Example no. 10 (original marker text: "Ejemplo n.º 10" / "0"); not part of the code.
def train(args):
    """Train a model on one hold-out split with periodic validation.

    Builds the model named by ``args.model_type``, optionally loads pretrained
    weights and/or a resume checkpoint, then iterates over the training
    loader.  Validation runs every 200 iterations, a checkpoint is written
    every 2000 iterations, and the loop stops once ``iteration`` equals
    ``args.stop_iteration``.

    The model hyper-parameters ``sample_rate``, ``window_size``, ``hop_size``,
    ``mel_bins``, ``fmin``, ``fmax`` and ``classes_num`` are not read from
    ``args``; they are taken from the enclosing module scope.

    Args:
        args: namespace providing ``workspace``, ``holdout_fold``,
            ``model_type``, ``pretrained_checkpoint_path``, ``freeze_base``,
            ``loss_type``, ``augmentation``, ``learning_rate``,
            ``batch_size``, ``resume_iteration``, ``stop_iteration``,
            ``cuda`` and ``filename``.
    """

    # Arguments & parameters
    workspace = args.workspace
    holdout_fold = args.holdout_fold
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    freeze_base = args.freeze_base
    loss_type = args.loss_type
    augmentation = args.augmentation
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    resume_iteration = args.resume_iteration
    stop_iteration = args.stop_iteration
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    filename = args.filename
    num_workers = 8

    loss_func = get_loss_func(loss_type)
    pretrain = bool(pretrained_checkpoint_path)
    use_mixup = 'mixup' in augmentation

    hdf5_path = os.path.join(workspace, 'features', 'waveform.h5')

    # Path components shared by the checkpoint, statistics and log locations.
    run_parts = (
        filename,
        'holdout_fold={}'.format(holdout_fold),
        model_type,
        'pretrain={}'.format(pretrain),
        'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size),
        'freeze_base={}'.format(freeze_base),
    )

    checkpoints_dir = os.path.join(workspace, 'checkpoints', *run_parts)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace, 'statistics',
                                   *(run_parts + ('statistics.pickle', )))
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', *run_parts)
    create_logging(logs_dir, 'w')
    logging.info(args)

    if 'cuda' in device:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model
    # NOTE(review): eval() on a command-line supplied string executes
    # arbitrary code; consider an explicit name -> class mapping instead.
    Model = eval(model_type)
    model = Model(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
                  classes_num, freeze_base)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    if pretrain:
        logging.info(
            'Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)

    if resume_iteration:
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))
        logging.info(
            'Load resume model from {}'.format(resume_checkpoint_path))
        # map_location lets a checkpoint saved on GPU be resumed on a
        # CPU-only host instead of failing while deserializing CUDA tensors.
        resume_checkpoint = torch.load(resume_checkpoint_path,
                                       map_location=device)
        model.load_state_dict(resume_checkpoint['model'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = resume_checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    dataset = GtzanDataset()

    # Data generator. The training sampler is given a doubled batch size
    # when mixup is enabled.
    train_batch_size = batch_size * 2 if use_mixup else batch_size
    train_sampler = TrainSampler(hdf5_path=hdf5_path,
                                 holdout_fold=holdout_fold,
                                 batch_size=train_batch_size)

    validate_sampler = EvaluateSampler(hdf5_path=hdf5_path,
                                       holdout_fold=holdout_fold,
                                       batch_size=batch_size)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=collate_fn,
                                               num_workers=num_workers,
                                               pin_memory=True)

    validate_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=validate_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True)

    if 'cuda' in device:
        model.to(device)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    if use_mixup:
        mixup_augmenter = Mixup(mixup_alpha=1.)

    # Evaluator
    evaluator = Evaluator(model=model)

    train_bgn_time = time.time()

    # Train on mini batches
    for batch_data_dict in train_loader:

        # Evaluate every 200 iterations, except at iteration 0 and at the
        # iteration a resumed run starts from.
        if iteration % 200 == 0 and iteration > 0:
            just_resumed = (resume_iteration > 0
                            and iteration == resume_iteration)
            if not just_resumed:
                logging.info('------------------------------------')
                logging.info('Iteration: {}'.format(iteration))

                train_fin_time = time.time()

                statistics = evaluator.evaluate(validate_loader)
                logging.info('Validate accuracy: {:.3f}'.format(
                    statistics['accuracy']))

                statistics_container.append(iteration, statistics, 'validate')
                statistics_container.dump()

                train_time = train_fin_time - train_bgn_time
                validate_time = time.time() - train_fin_time

                logging.info('Train time: {:.3f} s, validate time: {:.3f} s'
                             ''.format(train_time, validate_time))

                train_bgn_time = time.time()

        # Save model
        if iteration % 2000 == 0 and iteration > 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        if use_mixup:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                len(batch_data_dict['waveform']))

        # Move data to GPU
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Train
        model.train()

        if use_mixup:
            # batch_output_dict: {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])

            # batch_target_dict: {'target': (batch_size, classes_num)}
            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
        else:
            # batch_output_dict: {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'], None)

            # batch_target_dict: {'target': (batch_size, classes_num)}
            batch_target_dict = {'target': batch_data_dict['target']}

        # loss
        loss = loss_func(batch_output_dict, batch_target_dict)
        print(iteration, loss)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Stop learning
        if iteration == stop_iteration:
            break

        iteration += 1