Beispiel #1
0
def get_train_sampler(balanced, augmentation, train_indexes_hdf5_path,
                      black_list_csv, batch_size):
    """Get train sampler and the collator that matches it.

    Args:
      balanced: str, 'none' (plain ``Sampler``) or 'balanced'
        (``BalancedSampler`` / ``BalancedMixupSampler``).
      augmentation: str, 'none', 'mixup' or 'mixup_from_0_epoch'. Only
        consulted when ``balanced == 'balanced'``; with ``balanced == 'none'``
        it is ignored and no mixup collator is used.
      train_indexes_hdf5_path: str
      black_list_csv: str
      batch_size: int, clips per training batch. With mixup the sampler is
        asked for ``2 * batch_size`` clips.

    Returns:
      train_sampler: object
      train_collector: object

    Raises:
      ValueError: if ``balanced`` or ``augmentation`` is not a recognised
        value (``ValueError`` subclasses ``Exception``, so existing
        ``except Exception`` callers are unaffected).
      AssertionError: if mixup is requested while GPUs are visible and
        ``batch_size`` is not divisible by the number of GPUs.
    """
    # Epoch from which mixup starts, for each supported mixup augmentation.
    mixup_start_epochs = {'mixup_from_0_epoch': 0, 'mixup': 1}

    if balanced == 'none':
        train_sampler = Sampler(indexes_hdf5_path=train_indexes_hdf5_path,
                                black_list_csv=black_list_csv,
                                batch_size=batch_size)
        train_collector = Collator(mixup_alpha=None)

    elif balanced == 'balanced':
        if augmentation == 'none':
            print('sampling balanced, non augmented dataset')
            train_sampler = BalancedSampler(
                indexes_hdf5_path=train_indexes_hdf5_path,
                black_list_csv=black_list_csv,
                batch_size=batch_size)
            train_collector = Collator(mixup_alpha=None)

        elif augmentation in mixup_start_epochs:
            start_mix_epoch = mixup_start_epochs[augmentation]

            # device_count() is 0 on CPU-only machines; an unguarded modulo
            # would raise ZeroDivisionError there. The divisibility
            # constraint only matters when the batch is split across GPUs.
            gpus_num = torch.cuda.device_count()
            if gpus_num > 0:
                assert batch_size % gpus_num == 0, \
                    'To let mixup working properly this must be satisfied.'

            train_sampler = BalancedMixupSampler(
                indexes_hdf5_path=train_indexes_hdf5_path,
                black_list_csv=black_list_csv,
                batch_size=batch_size * 2,
                start_mix_epoch=start_mix_epoch)
            train_collector = Collator(mixup_alpha=1.)

        else:
            raise ValueError('Incorrect augmentation!')

    else:
        raise ValueError('Incorrect balanced type!')

    return train_sampler, train_collector
Beispiel #2
0
def train(args):
    """Train a DCASE 2017 Task 4 model with periodic evaluation and checkpoints.

    Every 1000 iterations the model is evaluated on the test and evaluation
    sets: statistics are appended and dumped via ``StatisticsContainer`` and
    predictions are written under ``predictions_dir``. Every 10000
    iterations, starting at iteration 50000, a checkpoint holding the
    iteration, model and optimizer state is saved. Training stops once
    ``iteration == args.stop_iteration``.

    Args:
      args: parsed command-line namespace. Fields read: dataset_dir,
        workspace, holdout_fold, model_type, pretrained_checkpoint_path,
        freeze_base, loss_type, augmentation, learning_rate, batch_size,
        few_shots, random_seed, resume_iteration, stop_iteration, cuda,
        mini_data, filename.
    """

    # Arguments & parameters
    dataset_dir = args.dataset_dir
    workspace = args.workspace
    holdout_fold = args.holdout_fold
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    freeze_base = args.freeze_base
    loss_type = args.loss_type
    augmentation = args.augmentation
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    few_shots = args.few_shots
    random_seed = args.random_seed
    resume_iteration = args.resume_iteration
    stop_iteration = args.stop_iteration
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    mini_data = args.mini_data
    filename = args.filename

    loss_func = get_loss_func(loss_type)
    pretrain = bool(pretrained_checkpoint_path)
    num_workers = 16

    # Paths. The 'minidata_' prefix selects the mini training features; the
    # testing and evaluation features always point at the full files.
    prefix = 'minidata_' if mini_data else ''

    train_hdf5_path = os.path.join(workspace, 'features',
        '{}training.waveform.h5'.format(prefix))

    test_hdf5_path = os.path.join(workspace, 'features',
        'testing.waveform.h5')

    evaluate_hdf5_path = os.path.join(workspace, 'features',
        'evaluation.waveform.h5')

    test_reference_csv_path = os.path.join(dataset_dir, 'metadata',
        'groundtruth_strong_label_testing_set.csv')

    evaluate_reference_csv_path = os.path.join(dataset_dir, 'metadata',
        'groundtruth_strong_label_evaluation_set.csv')

    # NOTE: the sub-directory order differs between outputs (batch_size comes
    # before few_shots for checkpoints/submission/statistics but last for
    # predictions/logs). Kept as-is so existing runs remain locatable.
    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename,
        'holdout_fold={}'.format(holdout_fold), model_type,
        'pretrain={}'.format(pretrain), 'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size),
        'few_shots={}'.format(few_shots), 'random_seed={}'.format(random_seed),
        'freeze_base={}'.format(freeze_base))
    create_folder(checkpoints_dir)

    tmp_submission_path = os.path.join(workspace, '_tmp_submission',
        '{}{}'.format(prefix, filename), 'holdout_fold={}'.format(holdout_fold),
        model_type, 'pretrain={}'.format(pretrain), 'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size),
        'few_shots={}'.format(few_shots), 'random_seed={}'.format(random_seed),
        'freeze_base={}'.format(freeze_base), '_submission.csv')
    create_folder(os.path.dirname(tmp_submission_path))

    statistics_path = os.path.join(workspace, 'statistics',
        '{}{}'.format(prefix, filename), 'holdout_fold={}'.format(holdout_fold),
        model_type, 'pretrain={}'.format(pretrain), 'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size),
        'few_shots={}'.format(few_shots), 'random_seed={}'.format(random_seed),
        'freeze_base={}'.format(freeze_base), 'statistics.pickle')
    create_folder(os.path.dirname(statistics_path))

    predictions_dir = os.path.join(workspace, 'predictions',
        '{}{}'.format(prefix, filename), 'holdout_fold={}'.format(holdout_fold),
        model_type, 'pretrain={}'.format(pretrain),
        'loss_type={}'.format(loss_type), 'augmentation={}'.format(augmentation),
        'few_shots={}'.format(few_shots), 'random_seed={}'.format(random_seed),
        'freeze_base={}'.format(freeze_base), 'batch_size={}'.format(batch_size))
    create_folder(predictions_dir)

    logs_dir = os.path.join(workspace, 'logs', '{}{}'.format(prefix, filename),
        'holdout_fold={}'.format(holdout_fold), model_type,
        'pretrain={}'.format(pretrain), 'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation), 'few_shots={}'.format(few_shots),
        'random_seed={}'.format(random_seed), 'freeze_base={}'.format(freeze_base),
        'batch_size={}'.format(batch_size))
    create_logging(logs_dir, 'w')
    logging.info(args)

    if 'cuda' in device:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model. NOTE(review): `eval` resolves the class from a command-line
    # string, which is only safe with trusted arguments. sample_rate,
    # window_size, hop_size, mel_bins, fmin, fmax and classes_num are not
    # defined in this function; presumably module-level config — verify.
    Model = eval(model_type)
    model = Model(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
        classes_num)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    if pretrain:
        logging.info('Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)

    resume_optimizer_state = None
    if resume_iteration:
        resume_checkpoint_path = os.path.join(checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))
        logging.info('Load resume model from {}'.format(resume_checkpoint_path))
        # Load onto the CPU so checkpoints saved on GPU can also be resumed on
        # CPU-only machines; tensors are copied into the model's parameters.
        resume_checkpoint = torch.load(resume_checkpoint_path, map_location='cpu')
        model.load_state_dict(resume_checkpoint['model'])
        resume_optimizer_state = resume_checkpoint.get('optimizer')
        statistics_container.load_state_dict(resume_iteration)
        iteration = resume_checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in device:
        model.to(device)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate,
        betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)

    if resume_optimizer_state is not None:
        # Restore the Adam/AMSGrad state saved with the checkpoint; otherwise a
        # resumed run would restart the optimizer moments from zero.
        optimizer.load_state_dict(resume_optimizer_state)

    train_dataset = DCASE2017Task4Dataset(hdf5_path=train_hdf5_path)
    test_dataset = DCASE2017Task4Dataset(hdf5_path=test_hdf5_path)
    evaluate_dataset = DCASE2017Task4Dataset(hdf5_path=evaluate_hdf5_path)

    # Mixup consumes pairs of clips, so the train sampler draws twice as many.
    train_sampler = TrainSampler(
        hdf5_path=train_hdf5_path,
        batch_size=batch_size * 2 if 'mixup' in augmentation else batch_size,
        few_shots=few_shots,
        random_seed=random_seed)

    test_sampler = EvaluateSampler(dataset_size=len(test_dataset), batch_size=batch_size)
    evaluate_sampler = EvaluateSampler(dataset_size=len(evaluate_dataset), batch_size=batch_size)

    collector = Collator()

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
        batch_sampler=train_sampler, collate_fn=collector,
        num_workers=num_workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
        batch_sampler=test_sampler, collate_fn=collector,
        num_workers=num_workers, pin_memory=True)

    evaluate_loader = torch.utils.data.DataLoader(dataset=evaluate_dataset,
        batch_sampler=evaluate_sampler, collate_fn=collector,
        num_workers=num_workers, pin_memory=True)

    if 'mixup' in augmentation:
        mixup_augmenter = Mixup(mixup_alpha=1.)

    # Evaluator
    test_evaluator = Evaluator(
        model=model,
        generator=test_loader)

    evaluate_evaluator = Evaluator(
        model=model,
        generator=evaluate_loader)

    train_bgn_time = time.time()

    # Train on mini batches
    for batch_data_dict in train_loader:

        # Evaluate every 1000 iterations, except at the iteration training was
        # just resumed from (the saving run evaluated it before checkpointing).
        just_resumed = bool(resume_iteration) and iteration == resume_iteration
        if iteration % 1000 == 0 and not just_resumed:
            logging.info('------------------------------------')
            logging.info('Iteration: {}'.format(iteration))

            train_fin_time = time.time()

            for (data_type, evaluator, reference_csv_path) in [
                ('test', test_evaluator, test_reference_csv_path),
                ('evaluate', evaluate_evaluator, evaluate_reference_csv_path)]:

                logging.info('{} statistics:'.format(data_type))

                (statistics, predictions) = evaluator.evaluate(
                    reference_csv_path, tmp_submission_path)

                statistics_container.append(data_type, iteration, statistics)

                prediction_path = os.path.join(predictions_dir,
                    '{}_iterations.prediction.{}.h5'.format(iteration, data_type))

                write_out_prediction(predictions, prediction_path)

            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'Train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(train_time, validate_time))

            train_bgn_time = time.time()

        # Save model every 10000 iterations from iteration 50000 onward.
        if iteration % 10000 == 0 and iteration > 49999:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict()}

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        if 'mixup' in augmentation:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(len(batch_data_dict['waveform']))

        # Move data to GPU
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device)

        # Train (the evaluators may have switched the model to eval mode).
        model.train()

        if 'mixup' in augmentation:
            batch_output_dict = model(batch_data_dict['waveform'], batch_data_dict['mixup_lambda'])
            batch_target_dict = {'target': do_mixup(batch_data_dict['target'], batch_data_dict['mixup_lambda'])}
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            batch_target_dict = {'target': batch_data_dict['target']}

        # loss
        loss = loss_func(batch_output_dict, batch_target_dict)
        print(iteration, loss)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Stop learning
        if iteration == stop_iteration:
            break

        iteration += 1
Beispiel #3
0
def train(args):
    """Train an AudioSet tagging model (Keras ``Cnn13`` implementation).

    Evaluates on the balanced-train and eval sets at iteration 0 and every
    2000 iterations after ``resume_iteration``. Model weights and the sampler
    state are saved every 20000 iterations after ``resume_iteration``.
    Training stops when ``iteration == early_stop``.

    Args (fields of ``args``):
      workspace: str
      data_type: 'balanced_train' | 'full_train'
      window_size: int
      hop_size: int
      mel_bins: int
      fmin: int
      fmax: int
      model_type: str, only used to name output directories; the network
        built here is always ``Cnn13``.
      loss_type: must be 'clip_bce' (trained with binary cross-entropy)
      balanced: must be 'balanced'
      augmentation: str; values containing 'mixup' select the mixup sampler
      batch_size: int
      learning_rate: float
      resume_iteration: int, 0 to train from scratch
      early_stop: int
      cuda: bool, only used for the device log message in this function
      filename: str

    Raises:
      ValueError: if ``data_type`` or ``balanced`` is not supported.
    """

    # Arguments & parameters
    workspace = args.workspace
    data_type = args.data_type
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    use_cuda = args.cuda and torch.cuda.is_available()
    filename = args.filename

    num_workers = 0
    sample_rate = config.sample_rate
    audio_length = config.audio_length
    classes_num = config.classes_num
    assert loss_type == 'clip_bce'

    # Fail fast on unsupported settings; previously these surfaced much later
    # as NameErrors on train_targets_hdf5_path / train_sampler.
    if data_type not in ('balanced_train', 'full_train'):
        raise ValueError('Incorrect data_type!')
    if balanced != 'balanced':
        raise ValueError('Incorrect balanced type!')

    # Paths
    black_list_csv = os.path.join(workspace, 'black_list',
                                  'dcase2017task4.csv')

    waveform_hdf5s_dir = os.path.join(workspace, 'hdf5s', 'waveforms')

    # Target hdf5 paths
    eval_train_targets_hdf5_path = os.path.join(workspace, 'hdf5s', 'targets',
                                                'balanced_train.h5')

    eval_test_targets_hdf5_path = os.path.join(workspace, 'hdf5s', 'targets',
                                               'eval.h5')

    train_targets_hdf5_path = os.path.join(workspace, 'hdf5s', 'targets',
                                           '{}.h5'.format(data_type))

    checkpoints_dir = os.path.join(
        workspace, 'checkpoints', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size))
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(
        workspace, 'statistics', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size), 'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(
        workspace, 'logs', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size))

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if use_cuda:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU.')

    # Model
    model = Cnn13(audio_length, sample_rate, window_size, hop_size, mel_bins,
                  fmin, fmax, classes_num)
    model.summary()
    logging.info('Parameters number: {}'.format(model.count_params()))

    # Optimizer
    optimizer = keras.optimizers.Adam(lr=learning_rate,
                                      beta_1=0.9,
                                      beta_2=0.999,
                                      amsgrad=True)

    # Loss
    loss_fn = keras.losses.binary_crossentropy

    model.compile(loss=loss_fn, optimizer=optimizer)

    # Dataset will be used by DataLoader later. Provide an index and return
    # waveform and target of audio
    train_dataset = AudioSetDataset(target_hdf5_path=train_targets_hdf5_path,
                                    waveform_hdf5s_dir=waveform_hdf5s_dir,
                                    audio_length=audio_length,
                                    classes_num=classes_num)

    bal_dataset = AudioSetDataset(
        target_hdf5_path=eval_train_targets_hdf5_path,
        waveform_hdf5s_dir=waveform_hdf5s_dir,
        audio_length=audio_length,
        classes_num=classes_num)

    test_dataset = AudioSetDataset(
        target_hdf5_path=eval_test_targets_hdf5_path,
        waveform_hdf5s_dir=waveform_hdf5s_dir,
        audio_length=audio_length,
        classes_num=classes_num)

    # Sampler (balanced == 'balanced' was validated above).
    # NOTE(review): train_on_batch below never reads
    # batch_data_dict['mixup_lambda']; whether mixup is actually applied
    # depends on Collator — verify.
    if 'mixup' in augmentation:
        # device_count() is 0 on CPU-only machines, where an unguarded modulo
        # raises ZeroDivisionError.
        gpus_num = torch.cuda.device_count()
        if gpus_num > 0:
            assert batch_size % gpus_num == 0, \
                'To let mixup working properly this must be satisfied.'
        train_sampler = BalancedSamplerMixup(
            target_hdf5_path=train_targets_hdf5_path,
            black_list_csv=black_list_csv,
            batch_size=batch_size,
            start_mix_epoch=1)
        train_collector = Collator(mixup_alpha=1.)
    else:
        train_sampler = BalancedSampler(
            target_hdf5_path=train_targets_hdf5_path,
            black_list_csv=black_list_csv,
            batch_size=batch_size)
        train_collector = Collator(mixup_alpha=None)

    bal_sampler = EvaluateSampler(dataset_size=len(bal_dataset),
                                  batch_size=batch_size)

    test_sampler = EvaluateSampler(dataset_size=len(test_dataset),
                                   batch_size=batch_size)

    eval_collector = Collator(mixup_alpha=None)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=train_collector,
                                               num_workers=num_workers,
                                               pin_memory=True)

    bal_loader = torch.utils.data.DataLoader(dataset=bal_dataset,
                                             batch_sampler=bal_sampler,
                                             collate_fn=eval_collector,
                                             num_workers=num_workers,
                                             pin_memory=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_sampler=test_sampler,
                                              collate_fn=eval_collector,
                                              num_workers=num_workers,
                                              pin_memory=True)

    # Evaluator
    bal_evaluator = Evaluator(model=model, generator=bal_loader)

    test_evaluator = Evaluator(model=model, generator=test_loader)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    train_bgn_time = time.time()

    # Resume training
    if resume_iteration > 0:
        resume_weights_path = os.path.join(
            checkpoints_dir,
            '{}_iterations.weights.h5'.format(resume_iteration))
        # The sampler state is a pickle despite the '.h5' suffix; the name is
        # kept so existing checkpoints still resume.
        resume_sampler_path = os.path.join(
            checkpoints_dir,
            '{}_iterations.sampler.h5'.format(resume_iteration))
        iteration = resume_iteration

        model.load_weights(resume_weights_path)
        # NOTE(review): unpickling executes arbitrary code; only resume from
        # trusted workspaces.
        with open(resume_sampler_path, 'rb') as f:
            sampler_state_dict = cPickle.load(f)
        train_sampler.load_state_dict(sampler_state_dict)
        statistics_container.load_state_dict(resume_iteration)

    else:
        iteration = 0

    for batch_data_dict in train_loader:

        # Evaluate
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = bal_evaluator.evaluate()
            test_statistics = test_evaluator.evaluate()

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))

            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model weights and sampler state every 20000 iterations past
        # the resume point.
        if iteration % 20000 == 0 and iteration > resume_iteration:
            weights_path = os.path.join(
                checkpoints_dir, '{}_iterations.weights.h5'.format(iteration))

            sampler_path = os.path.join(
                checkpoints_dir, '{}_iterations.sampler.h5'.format(iteration))

            model.save_weights(weights_path)
            # Context manager guarantees the pickle is flushed and closed.
            with open(sampler_path, 'wb') as f:
                cPickle.dump(train_sampler.state_dict(), f)

            logging.info('Model weights saved to {}'.format(weights_path))
            logging.info('Sampler saved to {}'.format(sampler_path))

        loss = model.train_on_batch(x=batch_data_dict['waveform'],
                                    y=batch_data_dict['target'])
        print(iteration, loss)

        iteration += 1

        # Stop learning
        if iteration == early_stop:
            break
Beispiel #4
0
def train(args):
    """Train an AudioSet tagging model (PyTorch).

    Evaluates on the balanced-train and eval index sets at iteration 0 and
    every 2000 iterations after ``resume_iteration``. A checkpoint (iteration,
    model, optimizer and sampler state) is saved every 20000 iterations.
    Training stops when ``iteration == early_stop``.

    Args (fields of ``args``):
      workspace: str
      data_type: str, selects ``<workspace>/hdf5s/indexes/<data_type>.h5`` as
        the training indexes, e.g. 'balanced_train'
      window_size: int
      hop_size: int
      mel_bins: int
      fmin: int
      fmax: int
      model_type: str, name of the model class to instantiate
      loss_type: str, passed to ``get_loss_func``
      balanced: str, passed to ``get_train_sampler``
      augmentation: str, passed to ``get_train_sampler``; values containing
        'mixup' train on mixed-up inputs and targets
      batch_size: int
      learning_rate: float
      resume_iteration: int, 0 to train from scratch
      early_stop: int
      cuda: bool
      filename: str
    """

    # Arguments & parameters
    workspace = args.workspace
    data_type = args.data_type
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available(
    ) else torch.device('cpu')
    filename = args.filename

    num_workers = 8
    sample_rate = config.sample_rate
    clip_samples = config.clip_samples
    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)

    # Paths
    black_list_csv = os.path.join(workspace, 'black_list',
                                  'dcase2017task4.csv')

    train_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                           '{}.h5'.format(data_type))

    eval_bal_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                              'balanced_train.h5')

    eval_test_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                               'eval.h5')

    checkpoints_dir = os.path.join(
        workspace, 'checkpoints', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size))
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(
        workspace, 'statistics', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size), 'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(
        workspace, 'logs', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size))

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if 'cuda' in str(device):
        logging.info('Using GPU.')
        device = 'cuda'
    else:
        logging.info('Using CPU.')
        device = 'cpu'

    # Model. NOTE(review): `eval` resolves the class from a command-line
    # string, which is only safe with trusted arguments.
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate,
                  window_size=window_size,
                  hop_size=hop_size,
                  mel_bins=mel_bins,
                  fmin=fmin,
                  fmax=fmax,
                  classes_num=classes_num)

    params_num = count_parameters(model)
    logging.info('Parameters num: {}'.format(params_num))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and return a waveform and a target.
    dataset = AudioSetDataset(clip_samples=clip_samples,
                              classes_num=classes_num)

    # Train sampler
    (train_sampler,
     train_collector) = get_train_sampler(balanced, augmentation,
                                          train_indexes_hdf5_path,
                                          black_list_csv, batch_size)

    # Evaluate sampler
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    eval_collector = Collator(mixup_alpha=None)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=train_collector,
                                               num_workers=num_workers,
                                               pin_memory=True)

    eval_bal_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_bal_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    eval_test_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_test_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    # Evaluator
    bal_evaluator = Evaluator(model=model, generator=eval_bal_loader)
    test_evaluator = Evaluator(model=model, generator=eval_test_loader)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    train_bgn_time = time.time()

    # Resume training
    resume_optimizer_state = None
    if resume_iteration > 0:
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        # map_location='cpu' lets checkpoints saved on GPU resume on CPU-only
        # machines; the model is moved to the target device further below.
        checkpoint = torch.load(resume_checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        resume_optimizer_state = checkpoint.get('optimizer')
        iteration = checkpoint['iteration']

    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in str(device):
        model.to(device)

    if resume_optimizer_state is not None:
        # Restored only after the parameters reach their final device:
        # Optimizer.load_state_dict casts the loaded state to each parameter's
        # device, so the Adam moments end up next to the weights.
        optimizer.load_state_dict(resume_optimizer_state)

    time1 = time.time()

    for batch_data_dict in train_loader:
        # batch_data_dict: {
        #     'audio_name': (batch_size [*2 if mixup],),
        #     'waveform': (batch_size [*2 if mixup], clip_samples),
        #     'target': (batch_size [*2 if mixup], classes_num),
        #     (if exists) 'mixup_lambda': (batch_size * 2,)}

        # Evaluate
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = bal_evaluator.evaluate()
            test_statistics = test_evaluator.evaluate()

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))

            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % 20000 == 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Forward (the evaluators may have switched the model to eval mode).
        model.train()

        if 'mixup' in augmentation:
            # output: {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])

            # target: {'target': (batch_size, classes_num)}
            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
        else:
            # output: {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'], None)

            # target: {'target': (batch_size, classes_num)}
            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward
        loss.backward()
        print(loss)

        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'\
                .format(iteration, time.time() - time1))
            time1 = time.time()

        iteration += 1

        # Stop learning
        if iteration == early_stop:
            break
def train(args):
    """Fine-tune an audio-tagging model on the balanced AudioSet index.

    Builds a single ``args.model_type`` model, optionally loads pretrained
    weights into it, and trains it with a clip-level loss. It periodically
    evaluates on the balanced-train and eval index files, dumps statistics
    and saves checkpoints.

    Args:
      args: argparse-style namespace. Attributes read: window_size,
        hop_size, mel_bins, fmin, fmax, model_type,
        pretrained_checkpoint_path, freeze_base (overridden, see below),
        cuda, workspace_input, workspace_output, filename.

    Side effects:
      Creates checkpoint / statistics / log folders under
      ``args.workspace_output``, configures logging, writes
      ``<iteration>_iterations.pth`` checkpoints and ``statistics.pkl``.
    """
    # Arguments & parameters
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    # NOTE(review): the command-line --freeze_base value is overridden here;
    # the model is always built with freeze_base=True. Replace with
    # `args.freeze_base` to honour the flag.
    freeze_base = True
    pretrain = bool(pretrained_checkpoint_path)
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    sample_rate = config.sample_rate
    classes_num = config.classes_num
    clip_samples = config.clip_samples

    # Fixed training configuration (copied from main.py)
    workspace_input = args.workspace_input
    workspace_output = args.workspace_output
    filename = args.filename
    data_type = 'balanced_train'
    loss_type = 'clip_bce'
    balanced = 'balanced'
    augmentation = 'none'
    batch_size = 1
    learning_rate = 1e-3
    resume_iteration = 0
    early_stop = 100000
    num_workers = 8
    loss_func = get_loss_func(loss_type)
    black_list_csv = 'metadata/black_list/groundtruth_weak_label_evaluation_set.csv'

    # Index files; os.path.join works whether or not workspace_input ends
    # with a path separator.
    indexes_dir = os.path.join(workspace_input, 'hdf5s', 'indexes')
    train_indexes_hdf5_path = os.path.join(indexes_dir,
                                           '{}.h5'.format(data_type))
    eval_bal_indexes_hdf5_path = os.path.join(indexes_dir, 'balanced_train.h5')
    eval_test_indexes_hdf5_path = os.path.join(indexes_dir, 'eval.h5')

    # Sub-path encoding the run settings, shared by checkpoints, statistics
    # and logs.
    experiment_subdir = os.path.join(
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size))

    checkpoints_dir = os.path.join(workspace_output, 'checkpoints', filename,
                                   experiment_subdir)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace_output, 'statistics', filename,
                                   experiment_subdir, 'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace_output, 'logs', filename,
                            experiment_subdir)
    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if device == 'cuda':
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU.')

    # Model. NOTE(review): `eval` executes whatever --model_type contains;
    # acceptable for a trusted CLI only, never for untrusted input.
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate,
                  window_size=window_size,
                  hop_size=hop_size,
                  mel_bins=mel_bins,
                  fmin=fmin,
                  fmax=fmax,
                  classes_num=classes_num,
                  freeze_base=freeze_base)

    # Load pretrained weights into the same model instance that is trained
    # below.
    if pretrain:
        logging.info(
            'Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)
        print('Load pretrained model successfully!')

    params_num = count_parameters(model)
    logging.info('Parameters num: {}'.format(params_num))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and return a waveform and a target.
    dataset = AudioSetDataset(clip_samples=clip_samples,
                              classes_num=classes_num)

    # Train sampler
    (train_sampler, train_collector) = get_train_sampler(
        balanced, augmentation, train_indexes_hdf5_path, black_list_csv,
        batch_size)

    # Evaluate samplers
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    eval_collector = Collator(mixup_alpha=None)

    # Data loaders
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=train_collector,
                                               num_workers=num_workers,
                                               pin_memory=True)

    eval_bal_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_bal_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    eval_test_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_test_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    # Evaluators (hold the unwrapped model; it is moved to the device in
    # place below).
    bal_evaluator = Evaluator(model=model, generator=eval_bal_loader)
    test_evaluator = Evaluator(model=model, generator=eval_test_loader)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    train_bgn_time = time.time()
    if resume_iteration > 0:
        # Checkpoints are written to `checkpoints_dir` below, so resume from
        # the same place. The optimizer state is not restored.
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        map_location = None if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(resume_checkpoint_path,
                                map_location=map_location)
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if device == 'cuda':
        model.to(device)

    time1 = time.time()

    for batch_data_dict in train_loader:
        # batch_data_dict: {
        #     'audio_name': (batch_size [*2 if mixup],),
        #     'waveform': (batch_size [*2 if mixup], clip_samples),
        #     'target': (batch_size [*2 if mixup], classes_num),
        #     (ifexist) 'mixup_lambda': (batch_size * 2,)}

        # Evaluate
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = bal_evaluator.evaluate()
            test_statistics = test_evaluator.evaluate()

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))

            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % 20000 == 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Move data to device
        for key in batch_data_dict:
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Forward
        model.train()
        if 'mixup' in augmentation:
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])
            # {'clipwise_output': (batch_size, classes_num), ...}

            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
            # {'target': (batch_size, classes_num)}
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            # {'clipwise_output': (batch_size, classes_num), ...}

            batch_target_dict = {'target': batch_data_dict['target']}
            # {'target': (batch_size, classes_num)}

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'\
                .format(iteration, time.time() - time1))
            time1 = time.time()

        iteration += 1

        # Stop learning (>= so a resumed run past early_stop still stops).
        if iteration >= early_stop:
            break