예제 #1
0
파일: models.py 프로젝트: preduct0r/diploma
    def forward(self, input, mixup_lambda=None):
        """Run the tagging network and return clip-level scores plus embedding.

        Args:
          input: batch handed to the spectrogram extractor; the shape noted
            by the original author is (batch_size, data_length).
          mixup_lambda: mixup coefficients or None. Only used while the
            module is in training mode.

        Returns:
          dict with the keys 'clipwise_output' and 'embedding'.
        """
        spec = self.spectrogram_extractor(
            input)  # (batch_size, 1, time_steps, freq_bins)
        feat = self.logmel_extractor(spec)  # (batch_size, 1, time_steps, mel_bins)

        # bn0 is applied with axes 1 and 3 swapped; swap back afterwards.
        feat = self.bn0(feat.transpose(1, 3)).transpose(1, 3)

        if self.training:
            feat = self.spec_augmenter(feat)

            # Mixup on spectrogram
            if mixup_lambda is not None:
                feat = do_mixup(feat, mixup_lambda)

        # Every convolution stage is followed by the same dropout; only the
        # pooling size of the final stage differs.
        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (2, 2)),
            (self.conv_block5, (2, 2)),
            (self.conv_block6, (1, 1)),
        )
        for conv_block, pool in stages:
            feat = conv_block(feat, pool_size=pool, pool_type='avg')
            feat = F.dropout(feat, p=0.2, training=self.training)

        # Average over the last axis, then combine max and mean over axis 2.
        pooled = torch.mean(feat, dim=3)
        peak, _ = torch.max(pooled, dim=2)
        average = torch.mean(pooled, dim=2)

        hidden = F.dropout(peak + average, p=0.5, training=self.training)
        hidden = F.relu_(self.fc1(hidden))

        # The classifier reads the pre-dropout activations; only the returned
        # embedding goes through the second dropout.
        embedding = F.dropout(hidden, p=0.5, training=self.training)
        clipwise_output = torch.sigmoid(self.fc_audioset(hidden))

        return {
            'clipwise_output': clipwise_output,
            'embedding': embedding
        }
예제 #2
0
    def forward(self, input, mixup_lambda=None):
        """Produce frame-level and clip-level scores with a recurrent layer.

        Args:
          input: batch handed to the spectrogram extractor. The original
            docstring gives (batch_size, times_steps, freq_bins) -- TODO
            confirm against the caller.
          mixup_lambda: mixup coefficients or None. Only used while the
            module is in training mode.

        Returns:
          dict with the keys 'framewise_output', 'clipwise_output' and
          'embedding'.
        """
        interpolate_ratio = 8

        spec = self.spectrogram_extractor(
            input)  # (batch_size, 1, time_steps, freq_bins)
        feat = self.logmel_extractor(spec)  # (batch_size, 1, time_steps, mel_bins)

        # bn0 is applied with axes 1 and 3 swapped; swap back afterwards.
        feat = self.bn0(feat.transpose(1, 3)).transpose(1, 3)

        if self.training:
            feat = self.spec_augmenter(feat)

            # Mixup on spectrogram
            if mixup_lambda is not None:
                feat = do_mixup(feat, mixup_lambda)

        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (1, 1)),
        )
        for conv_block, pool in stages:
            feat = conv_block(feat, pool_size=pool, pool_type='avg')

        # Average over the last axis and move time in front of channels.
        seq = torch.mean(feat, dim=3).transpose(1, 2)  # (batch_size, time_steps, channels)
        seq, _ = self.gru(seq)
        embedding = seq.transpose(1, 2)  # (batch_size, feature_maps, time_steps)

        # Framewise output: the classifier works on the time-major layout.
        framewise_output = torch.sigmoid(self.fc(embedding.transpose(1, 2)))
        framewise_output = interpolate(framewise_output, interpolate_ratio)

        # Clipwise output is the mean of the framewise scores over axis 1.
        clipwise_output = torch.mean(framewise_output, dim=1)

        return {
            'framewise_output': framewise_output,
            'clipwise_output': clipwise_output,
            'embedding': embedding
        }
예제 #3
0
    def forward(self, input, mixup_lambda=None):
        """Produce frame-level and clip-level scores with an attention head.

        Args:
          input: batch handed to the spectrogram extractor. The original
            docstring gives (batch_size, times_steps, freq_bins) -- TODO
            confirm against the caller.
          mixup_lambda: mixup coefficients or None. Only used while the
            module is in training mode.

        Returns:
          dict with the keys 'framewise_output', 'clipwise_output' and
          'embedding'.
        """
        interpolate_ratio = 8

        spec = self.spectrogram_extractor(
            input)  # (batch_size, 1, time_steps, freq_bins)
        feat = self.logmel_extractor(spec)  # (batch_size, 1, time_steps, mel_bins)

        # bn0 is applied with axes 1 and 3 swapped; swap back afterwards.
        feat = self.bn0(feat.transpose(1, 3)).transpose(1, 3)

        if self.training:
            feat = self.spec_augmenter(feat)

            # Mixup on spectrogram
            if mixup_lambda is not None:
                feat = do_mixup(feat, mixup_lambda)

        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (1, 1)),
        )
        for conv_block, pool in stages:
            feat = conv_block(feat, pool_size=pool, pool_type='avg')

        # Average over the last axis and move time in front of channels.
        seq = torch.mean(feat, dim=3).transpose(1, 2)  # (batch_size, time_steps, channels)

        # NOTE(review): the result is used directly as one tensor, so
        # self.multihead presumably returns a tensor rather than a tuple --
        # confirm against the module definition.
        seq = self.multihead(seq, seq, seq)
        embedding = seq.transpose(1, 2)  # (batch_size, feature_maps, time_steps)

        # cla: (batch_size, classes_num, time_steps); norm_att is not used.
        clipwise_output, _norm_att, cla = self.att_block(embedding)

        # Framewise output
        framewise_output = interpolate(cla.transpose(1, 2), interpolate_ratio)

        return {
            'framewise_output': framewise_output,
            'clipwise_output': clipwise_output,
            'embedding': embedding
        }
예제 #4
0
파일: models.py 프로젝트: thunderock/kaggle
    def forward(self, input, spec_aug=False, mixup_lambda=None):
        """Turn a batch into classifier outputs via the wrapped encoder.

        Args:
          input: batch tensor; it is cast to float before feature extraction.
          spec_aug: accepted but not read; augmentation is driven by
            self.training instead.
          mixup_lambda: mixup coefficients or None. Applied whenever it is
            given, regardless of training mode.

        Returns:
          The output of self.fc.
        """
        waveform = input.float()
        spec = self.spectrogram_extractor(
            waveform)  # (batch_size, 1, time_steps, freq_bins)
        mel = self.logmel_extractor(spec)  # (batch_size, 1, time_steps, mel_bins)

        if self.training:
            mel = self.spec_augmenter(mel)

        # Mixup on spectrogram
        if mixup_lambda is not None:
            mel = do_mixup(mel, mixup_lambda)

        features = self.encoder.forward_features(mel)
        pooled = self.avg_pool(features).flatten(1)
        pooled = self.dropout(pooled)
        return self.fc(pooled)
    def forward_orig(self, x, mixup_lambda=None):
        """Compute the interpolated framewise output for a batch.

        Args:
          x: batch handed to the spectrogram extractor.
          mixup_lambda: mixup coefficients or None. Only used while the
            module is in training mode.

        Returns:
          The segmentwise sigmoid scores after interpolation by
          self.interpolate_ratio. Padding back to the input frame count is
          currently disabled.
        """
        spec = self.spectrogram_extractor(x)   # (batch_size, 1, time_steps, freq_bins)
        feat = self.logmel_extractor(spec)    # (batch_size, 1, time_steps, mel_bins)

        # bn0 is applied with axes 1 and 3 swapped; swap back afterwards.
        feat = self.bn0(feat.transpose(1, 3)).transpose(1, 3)

        if self.training:
            feat = self.spec_augmenter(feat)

            # Mixup on spectrogram
            if mixup_lambda is not None:
                feat = do_mixup(feat, mixup_lambda)

        feat = self.features(feat)
        feat = torch.mean(feat, dim=3)

        # Sum of a size-3 max pool and a size-3 average pool along the last
        # axis; stride 1 with padding 1 keeps the length unchanged.
        pool_args = dict(kernel_size=3, stride=1, padding=1)
        local_max = F.max_pool1d(feat, **pool_args)
        local_avg = F.avg_pool1d(feat, **pool_args)

        hidden = F.dropout(local_max + local_avg, p=0.5, training=self.training)
        hidden = F.relu_(self.fc1(hidden.transpose(1, 2)))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        segmentwise_output = torch.sigmoid(self.fc_audioset(hidden))

        # Computed as in the original but not part of the return value.
        (clipwise_output, _) = torch.max(segmentwise_output, dim=1)

        # Get framewise output
        framewise_output = interpolate(segmentwise_output, self.interpolate_ratio)
        # TEMP DISABLE framewise_output = pad_framewise_output(framewise_output, frames_num)

        return framewise_output
예제 #6
0
def train(args):
    """Train a model on the DCASE 2017 Task 4 waveform features.

    The loop periodically evaluates on the testing and evaluation sets,
    writes predictions and statistics, and saves checkpoints.

    Args:
      args: namespace providing dataset_dir, workspace, holdout_fold,
        model_type, pretrained_checkpoint_path, freeze_base, loss_type,
        augmentation, learning_rate, batch_size, few_shots, random_seed,
        resume_iteration, stop_iteration, cuda, mini_data and filename.

    The model hyper-parameters (sample_rate, window_size, hop_size, mel_bins,
    fmin, fmax, classes_num) are read from names defined outside this
    function.
    """

    # Arguments & parameters
    dataset_dir = args.dataset_dir
    workspace = args.workspace
    holdout_fold = args.holdout_fold
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    freeze_base = args.freeze_base
    loss_type = args.loss_type
    augmentation = args.augmentation
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    few_shots = args.few_shots
    random_seed = args.random_seed
    resume_iteration = args.resume_iteration
    stop_iteration = args.stop_iteration
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    mini_data = args.mini_data
    filename = args.filename

    loss_func = get_loss_func(loss_type)
    pretrain = bool(pretrained_checkpoint_path)
    num_workers = 16

    # Paths
    if mini_data:
        prefix = 'minidata_'
    else:
        prefix = ''

    train_hdf5_path = os.path.join(workspace, 'features',
        '{}training.waveform.h5'.format(prefix))

    # The testing / evaluation feature files carry no mini-data prefix; the
    # original applied .format(prefix) to these literals, which was a no-op.
    test_hdf5_path = os.path.join(workspace, 'features',
        'testing.waveform.h5')

    evaluate_hdf5_path = os.path.join(workspace, 'features',
        'evaluation.waveform.h5')

    test_reference_csv_path = os.path.join(dataset_dir, 'metadata',
        'groundtruth_strong_label_testing_set.csv')

    evaluate_reference_csv_path = os.path.join(dataset_dir, 'metadata',
        'groundtruth_strong_label_evaluation_set.csv')

    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename,
        'holdout_fold={}'.format(holdout_fold), model_type,
        'pretrain={}'.format(pretrain), 'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size),
        'few_shots={}'.format(few_shots), 'random_seed={}'.format(random_seed),
        'freeze_base={}'.format(freeze_base))
    create_folder(checkpoints_dir)

    tmp_submission_path = os.path.join(workspace, '_tmp_submission',
        '{}{}'.format(prefix, filename), 'holdout_fold={}'.format(holdout_fold),
        model_type, 'pretrain={}'.format(pretrain), 'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size),
        'few_shots={}'.format(few_shots), 'random_seed={}'.format(random_seed),
        'freeze_base={}'.format(freeze_base), '_submission.csv')
    create_folder(os.path.dirname(tmp_submission_path))

    statistics_path = os.path.join(workspace, 'statistics',
        '{}{}'.format(prefix, filename), 'holdout_fold={}'.format(holdout_fold),
        model_type, 'pretrain={}'.format(pretrain), 'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation), 'batch_size={}'.format(batch_size),
        'few_shots={}'.format(few_shots), 'random_seed={}'.format(random_seed),
        'freeze_base={}'.format(freeze_base), 'statistics.pickle')
    create_folder(os.path.dirname(statistics_path))

    predictions_dir = os.path.join(workspace, 'predictions',
        '{}{}'.format(prefix, filename), 'holdout_fold={}'.format(holdout_fold),
        model_type, 'pretrain={}'.format(pretrain),
        'loss_type={}'.format(loss_type), 'augmentation={}'.format(augmentation),
        'few_shots={}'.format(few_shots), 'random_seed={}'.format(random_seed),
        'freeze_base={}'.format(freeze_base), 'batch_size={}'.format(batch_size))
    create_folder(predictions_dir)

    logs_dir = os.path.join(workspace, 'logs', '{}{}'.format(prefix, filename),
        'holdout_fold={}'.format(holdout_fold), model_type,
        'pretrain={}'.format(pretrain), 'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation), 'few_shots={}'.format(few_shots),
        'random_seed={}'.format(random_seed), 'freeze_base={}'.format(freeze_base),
        'batch_size={}'.format(batch_size))
    create_logging(logs_dir, 'w')
    logging.info(args)

    if 'cuda' in device:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model
    # SECURITY NOTE(review): eval() executes whatever string is passed as
    # model_type. Only acceptable while the value comes from a trusted command
    # line; consider an explicit name -> class mapping.
    Model = eval(model_type)
    model = Model(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
        classes_num)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    if pretrain:
        logging.info('Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)

    resume_checkpoint = None
    if resume_iteration:
        resume_checkpoint_path = os.path.join(checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))
        logging.info('Load resume model from {}'.format(resume_checkpoint_path))
        # The model is still on the CPU here, so map the stored tensors to the
        # CPU; this also lets a checkpoint written on a GPU machine be loaded
        # on a CPU-only one.
        resume_checkpoint = torch.load(resume_checkpoint_path, map_location='cpu')
        model.load_state_dict(resume_checkpoint['model'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = resume_checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in device:
        model.to(device)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate,
        betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)

    # Checkpoints store the optimizer state; restore it on resume so the Adam
    # moment estimates are not reset. Checkpoints without that key still load.
    if resume_checkpoint is not None and 'optimizer' in resume_checkpoint:
        optimizer.load_state_dict(resume_checkpoint['optimizer'])
        # Keep the learning rate given for this run authoritative.
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    train_dataset = DCASE2017Task4Dataset(hdf5_path=train_hdf5_path)
    test_dataset = DCASE2017Task4Dataset(hdf5_path=test_hdf5_path)
    evaluate_dataset = DCASE2017Task4Dataset(hdf5_path=evaluate_hdf5_path)

    # With mixup the sampler is asked for twice the batch size.
    train_sampler = TrainSampler(
        hdf5_path=train_hdf5_path,
        batch_size=batch_size * 2 if 'mixup' in augmentation else batch_size,
        few_shots=few_shots,
        random_seed=random_seed)

    test_sampler = EvaluateSampler(dataset_size=len(test_dataset), batch_size=batch_size)
    evaluate_sampler = EvaluateSampler(dataset_size=len(evaluate_dataset), batch_size=batch_size)

    collector = Collator()

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
        batch_sampler=train_sampler, collate_fn=collector,
        num_workers=num_workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
        batch_sampler=test_sampler, collate_fn=collector,
        num_workers=num_workers, pin_memory=True)

    evaluate_loader = torch.utils.data.DataLoader(dataset=evaluate_dataset,
        batch_sampler=evaluate_sampler, collate_fn=collector,
        num_workers=num_workers, pin_memory=True)

    if 'mixup' in augmentation:
        mixup_augmenter = Mixup(mixup_alpha=1.)

    # Evaluator
    test_evaluator = Evaluator(
        model=model,
        generator=test_loader)

    evaluate_evaluator = Evaluator(
        model=model,
        generator=evaluate_loader)

    train_bgn_time = time.time()

    # Train on mini batches
    for batch_data_dict in train_loader:

        # Evaluate every 1000 iterations, except on the exact iteration a run
        # was resumed from. bool() keeps this safe when resume_iteration is
        # None, matching the truthiness check used when loading the checkpoint.
        at_resume_point = (bool(resume_iteration) and resume_iteration > 0
            and iteration == resume_iteration)

        if iteration % 1000 == 0 and not at_resume_point:
            logging.info('------------------------------------')
            logging.info('Iteration: {}'.format(iteration))

            train_fin_time = time.time()

            for (data_type, evaluator, reference_csv_path) in [
                ('test', test_evaluator, test_reference_csv_path),
                ('evaluate', evaluate_evaluator, evaluate_reference_csv_path)]:

                logging.info('{} statistics:'.format(data_type))

                (statistics, predictions) = evaluator.evaluate(
                    reference_csv_path, tmp_submission_path)

                statistics_container.append(data_type, iteration, statistics)

                prediction_path = os.path.join(predictions_dir,
                    '{}_iterations.prediction.{}.h5'.format(iteration, data_type))

                write_out_prediction(predictions, prediction_path)

            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'Train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(train_time, validate_time))

            train_bgn_time = time.time()

        # Save model every 10000 iterations, starting at iteration 50000
        if iteration % 10000 == 0 and iteration > 49999:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict()}

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        if 'mixup' in augmentation:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(len(batch_data_dict['waveform']))

        # Move data to GPU
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device)

        # Train
        model.train()

        if 'mixup' in augmentation:
            batch_output_dict = model(batch_data_dict['waveform'], batch_data_dict['mixup_lambda'])
            batch_target_dict = {'target': do_mixup(batch_data_dict['target'], batch_data_dict['mixup_lambda'])}
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            batch_target_dict = {'target': batch_data_dict['target']}

        # loss
        loss = loss_func(batch_output_dict, batch_target_dict)
        print(iteration, loss)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Stop learning
        if iteration == stop_iteration:
            break

        iteration += 1
예제 #7
0
def _save_checkpoint(model, train_sampler, iteration, checkpoints_dir, suffix=''):
    """Save model and sampler state as '<iteration>_iterations<suffix>.pth'."""
    checkpoint = {
        'iteration': iteration,
        'model': model.module.state_dict(),
        'sampler': train_sampler.state_dict()
    }

    checkpoint_path = os.path.join(
        checkpoints_dir, '{}_iterations{}.pth'.format(iteration, suffix))

    torch.save(checkpoint, checkpoint_path)
    logging.info('Model saved to {}'.format(checkpoint_path))


def train(args):
    """Train AudioSet tagging model.

    Args:
      args: namespace providing
        workspace: str
        data_type: str, name of the training index file (without '.h5')
        sample_rate, window_size, hop_size, mel_bins, fmin, fmax: passed to
          the model constructor and used in the output paths
        model_type: str, name of the model class
        loss_type: str, passed to get_loss_func
        balanced: 'none' | 'balanced' | 'alternate'
        augmentation: str; mixup is enabled when it contains 'mixup'
        batch_size: int
        learning_rate: float
        resume_iteration: int, 0 to start from scratch
        early_stop: int, iteration at which training stops
        cuda: bool
        filename: str

    Raises:
      ValueError: if args.balanced is not one of the supported values.
    """

    # Arguments & parameters
    workspace = args.workspace
    data_type = args.data_type
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    device = 'cuda' if args.cuda and torch.cuda.is_available() else 'cpu'
    filename = args.filename

    num_workers = 128
    prefetch_factor = 4

    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)

    # Paths
    black_list_csv = None

    train_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                           '{}.h5'.format(data_type))

    eval_bal_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                              'balanced_train.h5')

    eval_test_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                               'eval.h5')

    # Checkpoints of every run go into their own timestamped directory.
    checkpoints_dir = os.path.join(
        workspace, 'checkpoints', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size),
        datetime.datetime.now().strftime("%d%m%Y_%H%M%S"))

    create_folder(checkpoints_dir)

    statistics_path = os.path.join(
        workspace, 'statistics', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size), 'statistics.pkl')

    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(
        workspace, 'logs', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size))

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if device == 'cuda':
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model
    # SECURITY NOTE(review): eval() executes whatever string is passed as
    # model_type. Only acceptable while the value comes from a trusted command
    # line; consider an explicit name -> class mapping.
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate,
                  window_size=window_size,
                  hop_size=hop_size,
                  mel_bins=mel_bins,
                  fmin=fmin,
                  fmax=fmax,
                  classes_num=classes_num)

    params_num = count_parameters(model)
    logging.info('Parameters num: {}'.format(params_num))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and return a waveform and a target.
    dataset = AudioSetDataset(sample_rate=sample_rate)

    # Train sampler
    if balanced == 'none':
        Sampler = TrainSampler
    elif balanced == 'balanced':
        Sampler = BalancedTrainSampler
    elif balanced == 'alternate':
        Sampler = AlternateTrainSampler
    else:
        # Fail with a clear message instead of an unbound-name error below.
        raise ValueError(
            "balanced must be 'none', 'balanced' or 'alternate', got {!r}"
            .format(balanced))

    # With mixup the sampler is asked for twice the batch size.
    train_sampler = Sampler(indexes_hdf5_path=train_indexes_hdf5_path,
                            batch_size=batch_size *
                            2 if 'mixup' in augmentation else batch_size,
                            black_list_csv=black_list_csv)

    # Evaluate sampler
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path,
        batch_size=2 * batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path,
        batch_size=2 * batch_size)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=collate_fn,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               prefetch_factor=prefetch_factor)

    eval_bal_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_bal_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=prefetch_factor)

    eval_test_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_test_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=prefetch_factor)

    if 'mixup' in augmentation:
        mixup_augmenter = Mixup(mixup_alpha=1.)

    # Evaluator
    evaluator = Evaluator(model=model)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    train_bgn_time = time.time()

    # Resume training
    if resume_iteration > 0:
        # NOTE(review): this path has no timestamp component, while
        # checkpoints_dir above ends in one, so a checkpoint has to be placed
        # here by hand before resuming -- confirm that this is intended.
        resume_checkpoint_path = os.path.join(
            workspace, 'checkpoints', filename,
            'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
            .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                    fmax), 'data_type={}'.format(data_type), model_type,
            'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
            'augmentation={}'.format(augmentation),
            'batch_size={}'.format(batch_size),
            '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        # The model is still on the CPU here, so map the stored tensors to the
        # CPU; this also lets a checkpoint written on a GPU machine be loaded
        # on a CPU-only one.
        checkpoint = torch.load(resume_checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']

    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if device == 'cuda':
        model.to(device)

    time1 = time.time()

    # Best validation mAP seen so far on each evaluation set.
    prev_bal_map = 0.0
    prev_test_map = 0.0
    save_bal_model = False
    save_test_model = False

    for batch_data_dict in train_loader:
        # batch_data_dict: {
        #     'audio_name': (batch_size [*2 if mixup],),
        #     'waveform': (batch_size [*2 if mixup], clip_samples),
        #     'target': (batch_size [*2 if mixup], classes_num),
        #     (ifexist) 'mixup_lambda': (batch_size * 2,)}

        # Evaluate
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or (iteration == -1):
            train_fin_time = time.time()

            bal_statistics = evaluator.evaluate(eval_bal_loader)
            test_statistics = evaluator.evaluate(eval_test_loader)

            bal_map = np.mean(bal_statistics['average_precision'])
            test_map = np.mean(test_statistics['average_precision'])

            logging.info('Validate bal mAP: {:.3f}'.format(bal_map))
            logging.info('Validate test mAP: {:.3f}'.format(test_map))

            # Request a "best" checkpoint only on improvement, and remember
            # the new best value. Previously the reference values were never
            # updated, so every evaluation triggered both saves.
            if bal_map > prev_bal_map:
                prev_bal_map = bal_map
                save_bal_model = True

            if test_map > prev_test_map:
                prev_test_map = test_map
                save_test_model = True

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % 100000 == 0:
            _save_checkpoint(model, train_sampler, iteration, checkpoints_dir)

        if save_bal_model:
            _save_checkpoint(model, train_sampler, iteration, checkpoints_dir,
                             suffix='_bal')
            save_bal_model = False

        if save_test_model:
            _save_checkpoint(model, train_sampler, iteration, checkpoints_dir,
                             suffix='_test')
            save_test_model = False

        # Mixup lambda
        if 'mixup' in augmentation:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                batch_size=len(batch_data_dict['waveform']))

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Forward
        model.train()

        # batch_output_dict: {'clipwise_output': (batch_size, classes_num), ...}
        # batch_target_dict: {'target': (batch_size, classes_num)}
        if 'mixup' in augmentation:
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])

            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)

            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward
        loss.backward()
        print(loss)

        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'\
                .format(iteration, time.time() - time1))
            time1 = time.time()

        # Stop learning
        if iteration == early_stop:
            break

        iteration += 1
예제 #8
0
def train(args):
    """Train an AudioSet tagging model.

    Args:
      args: namespace object. The attributes read by this function are:
        workspace: str, root directory that holds `hdf5s/` and `black_list/`
            and receives `checkpoints/`, `statistics/` and `logs/`
        data_type: str, name of the training index file (without `.h5`),
            e.g. 'balanced_train' | 'unbalanced_train'
        window_size, hop_size, mel_bins, fmin, fmax: forwarded to the model
            constructor and encoded in the output directory names
        model_type: str, name of the model class to instantiate
        loss_type: str, forwarded to `get_loss_func`, e.g. 'bce'
        balanced: forwarded to `get_train_sampler`
        augmentation: str; mixup is applied when it contains 'mixup'
        batch_size: int
        learning_rate: float
        resume_iteration: int, a value > 0 resumes from that checkpoint
        early_stop: int, training stops once this many iterations are done
        cuda: bool, use the GPU when one is available
        filename: str, sub-directory name used under each output root
    """

    # Arguments & parameters
    workspace = args.workspace
    data_type = args.data_type
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    filename = args.filename

    num_workers = 8
    sample_rate = config.sample_rate
    clip_samples = config.clip_samples
    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)
    use_mixup = 'mixup' in augmentation

    # Paths
    black_list_csv = os.path.join(workspace, 'black_list',
                                  'dcase2017task4.csv')

    train_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                           '{}.h5'.format(data_type))

    eval_bal_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                              'balanced_train.h5')

    eval_test_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                               'eval.h5')

    # Directory components shared by the checkpoint, statistics and log
    # locations; they encode the configuration of this run.
    run_subdirs = (
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type),
        model_type,
        'loss_type={}'.format(loss_type),
        'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size),
    )

    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename,
                                   *run_subdirs)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace, 'statistics', filename,
                                   *(run_subdirs + ('statistics.pkl', )))
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename, *run_subdirs)

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if device == 'cuda':
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU.')

    # Model
    # NOTE(review): eval() executes whatever string is passed as model_type;
    # only call this with trusted command-line input.
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate,
                  window_size=window_size,
                  hop_size=hop_size,
                  mel_bins=mel_bins,
                  fmin=fmin,
                  fmax=fmax,
                  classes_num=classes_num)

    params_num = count_parameters(model)
    logging.info('Parameters num: {}'.format(params_num))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and return a waveform and a target.
    dataset = AudioSetDataset(clip_samples=clip_samples,
                              classes_num=classes_num)

    # Train sampler
    (train_sampler,
     train_collector) = get_train_sampler(balanced, augmentation,
                                          train_indexes_hdf5_path,
                                          black_list_csv, batch_size)

    # Evaluate samplers
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    eval_collector = Collator(mixup_alpha=None)

    # Data loaders
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=train_collector,
                                               num_workers=num_workers,
                                               pin_memory=True)

    eval_bal_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_bal_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    eval_test_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_test_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    # Evaluators
    bal_evaluator = Evaluator(model=model, generator=eval_bal_loader)
    test_evaluator = Evaluator(model=model, generator=eval_test_loader)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    train_bgn_time = time.time()

    # Resume training
    if resume_iteration > 0:
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        # Load onto the CPU first: the model still lives on the CPU here (it
        # is moved to `device` below), and this lets a checkpoint written by
        # a GPU run be resumed on a CPU-only host.
        checkpoint = torch.load(resume_checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if device == 'cuda':
        model.to(device)

    time1 = time.time()

    for batch_data_dict in train_loader:
        # batch_data_dict: {
        #     'audio_name': (batch_size [*2 if mixup],),
        #     'waveform': (batch_size [*2 if mixup], clip_samples),
        #     'target': (batch_size [*2 if mixup], classes_num),
        #     (if exist) 'mixup_lambda': (batch_size * 2,)}

        # Evaluate every 2000 iterations (and once before any training),
        # skipping the iteration a resumed run starts from.
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = bal_evaluator.evaluate()
            test_statistics = test_evaluator.evaluate()

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))

            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                .format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % 20000 == 0:
            checkpoint = {
                'iteration': iteration,
                # Unwrap DataParallel so the keys carry no 'module.' prefix.
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Forward
        model.train()

        if use_mixup:
            mixup_lambda = batch_data_dict['mixup_lambda']
            # {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'],
                                      mixup_lambda)
            # {'target': (batch_size, classes_num)}
            batch_target_dict = {
                'target': do_mixup(batch_data_dict['target'], mixup_lambda)
            }
        else:
            # {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'], None)
            # {'target': (batch_size, classes_num)}
            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward
        loss.backward()
        print(loss)

        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'
                  .format(iteration, time.time() - time1))
            time1 = time.time()

        iteration += 1

        # Stop learning
        if iteration == early_stop:
            break
def train(args):
    """Train and evaluate.

    Args:
      args: namespace object. The attributes read by this function are:
        dataset_dir: str, root of the dataset (its `metadata/` directory
            holds the reference csv files used for evaluation)
        workspace: str, root directory holding `hdf5s/` and receiving
            `checkpoints/`, `_tmp_submission/`, `statistics/` and `logs/`
        holdout_fold: '1'
        model_type: str, e.g., 'Cnn_9layers_Gru_FrameAtt'
        loss_type: str, e.g., 'clip_bce'
        augmentation: str, e.g., 'mixup'
        learning_rate: float
        batch_size: int
        resume_iteration: int, a truthy value resumes from that checkpoint
        stop_iteration: int, training stops at this iteration
        cuda: bool, use the GPU when one is available
        mini_data: bool, use the 'minidata_' prefixed training/testing files
        filename: str, sub-directory name used under each output root
    """

    # Arguments & parameters
    dataset_dir = args.dataset_dir
    workspace = args.workspace
    holdout_fold = args.holdout_fold
    model_type = args.model_type
    loss_type = args.loss_type
    augmentation = args.augmentation
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    resume_iteration = args.resume_iteration
    stop_iteration = args.stop_iteration
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    mini_data = args.mini_data
    filename = args.filename

    loss_func = get_loss_func(loss_type)
    num_workers = 8
    use_mixup = 'mixup' in augmentation

    # Paths
    prefix = 'minidata_' if mini_data else ''

    train_hdf5_path = os.path.join(workspace, 'hdf5s',
                                   '{}training.h5'.format(prefix))

    test_hdf5_path = os.path.join(workspace, 'hdf5s',
                                  '{}testing.h5'.format(prefix))

    # The evaluation file name never carried the mini-data prefix (the
    # original called .format() on a string without a placeholder).
    evaluate_hdf5_path = os.path.join(workspace, 'hdf5s', 'evaluation.h5')

    test_reference_csv_path = os.path.join(
        dataset_dir, 'metadata', 'groundtruth_strong_label_testing_set.csv')

    evaluate_reference_csv_path = os.path.join(
        dataset_dir, 'metadata', 'groundtruth_strong_label_evaluation_set.csv')

    # Directory components shared by every output location of this run.
    run_subdirs = (
        '{}{}'.format(prefix, filename),
        'holdout_fold={}'.format(holdout_fold),
        'model_type={}'.format(model_type),
        'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size),
    )

    checkpoints_dir = os.path.join(workspace, 'checkpoints', *run_subdirs)
    create_folder(checkpoints_dir)

    tmp_submission_path = os.path.join(
        workspace, '_tmp_submission', *(run_subdirs + ('_submission.csv', )))
    create_folder(os.path.dirname(tmp_submission_path))

    statistics_path = os.path.join(
        workspace, 'statistics', *(run_subdirs + ('statistics.pickle', )))
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', *run_subdirs)
    create_logging(logs_dir, 'w')
    logging.info(args)

    if 'cuda' in device:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model
    assert model_type, 'Please specify model_type!'
    # NOTE(review): eval() executes whatever string is passed as model_type;
    # only call this with trusted command-line input.
    Model = eval(model_type)
    # sample_rate ... classes_num are not defined in this function;
    # presumably module-level names imported from the config -- verify.
    model = Model(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
                  classes_num)

    # Statistics. Created before the resume branch because resuming restores
    # the container's state; the original created it only after the data
    # loaders, so every resumed run failed with UnboundLocalError.
    statistics_container = StatisticsContainer(statistics_path)

    if resume_iteration:
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))
        logging.info(
            'Load resume model from {}'.format(resume_checkpoint_path))
        # Load onto the CPU first: the model is still on the CPU here, and
        # this lets a checkpoint from a GPU run be resumed on a CPU host.
        resume_checkpoint = torch.load(resume_checkpoint_path,
                                       map_location='cpu')
        model.load_state_dict(resume_checkpoint['model'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = resume_checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in device:
        model.to(device)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    # Dataset
    dataset = DCASE2017Task4Dataset()

    # Samplers. With mixup the train sampler draws twice as many clips per
    # batch.
    train_batch_size = batch_size * 2 if use_mixup else batch_size
    train_sampler = TrainSampler(hdf5_path=train_hdf5_path,
                                 batch_size=train_batch_size)

    test_sampler = TestSampler(hdf5_path=test_hdf5_path, batch_size=batch_size)

    evaluate_sampler = TestSampler(hdf5_path=evaluate_hdf5_path,
                                   batch_size=batch_size)

    # Data loaders
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=collate_fn,
                                               num_workers=num_workers,
                                               pin_memory=True)

    test_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_sampler=test_sampler,
                                              collate_fn=collate_fn,
                                              num_workers=num_workers,
                                              pin_memory=True)

    evaluate_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=evaluate_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True)

    if use_mixup:
        mixup_augmenter = Mixup(mixup_alpha=1.)

    # Evaluator
    evaluator = Evaluator(model=model)

    train_bgn_time = time.time()

    # Train on mini batches
    for batch_data_dict in train_loader:

        # Evaluate every 1000 iterations, skipping the iteration a resumed
        # run starts from (and iteration 0 of a fresh run).
        if iteration % 1000 == 0 and iteration > resume_iteration:

            logging.info('------------------------------------')
            logging.info('Iteration: {}'.format(iteration))

            train_fin_time = time.time()

            eval_targets = [
                ('test', test_loader, test_reference_csv_path),
                ('evaluate', evaluate_loader, evaluate_reference_csv_path)
            ]

            for (data_type, data_loader, reference_csv_path) in eval_targets:

                # Calculate statistics
                (statistics, _) = evaluator.evaluate(data_loader,
                                                     reference_csv_path,
                                                     tmp_submission_path)

                logging.info('{} statistics:'.format(data_type))
                logging.info('    Clipwise mAP: {:.3f}'.format(
                    np.mean(statistics['clipwise_ap'])))
                logging.info('    Framewise mAP: {:.3f}'.format(
                    np.mean(statistics['framewise_ap'])))
                logging.info('    {}'.format(
                    statistics['sed_metrics']['overall']['error_rate']))

                statistics_container.append(data_type, iteration, statistics)

            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info('Train time: {:.3f} s, validate time: {:.3f} s'
                         .format(train_time, validate_time))

            train_bgn_time = time.time()

        # Save model
        if iteration % 10000 == 0:
            checkpoint = {
                'iteration': iteration,
                # Unwrap DataParallel so the keys carry no 'module.' prefix.
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        if use_mixup:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                batch_size=len(batch_data_dict['waveform']))

        # Move data to GPU
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Train
        model.train()

        if use_mixup:
            mixup_lambda = batch_data_dict['mixup_lambda']
            batch_output_dict = model(batch_data_dict['waveform'],
                                      mixup_lambda)
            batch_target_dict = {
                'target': do_mixup(batch_data_dict['target'], mixup_lambda)
            }
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)
        print(iteration, loss)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Stop learning
        if iteration == stop_iteration:
            break

        iteration += 1
def train(args):
    """Fine-tune a (optionally pretrained) AudioSet tagging model.

    The run configuration other than the feature/model parameters is
    hard-coded below (balanced training data, 'clip_bce' loss, no
    augmentation, batch size 1, learning rate 1e-3, 100000 iterations).

    Args:
      args: namespace object. The attributes read by this function are:
        window_size, hop_size, mel_bins, fmin, fmax: forwarded to the model
            constructor and encoded in the output directory names
        model_type: str, name of the model class to instantiate
        pretrained_checkpoint_path: str, when non-empty the weights are
            loaded through `model.load_from_pretrain`
        cuda: bool, use the GPU when one is available
        workspace_input: str, root directory holding `hdf5s/indexes/`
        workspace_output: str, root directory receiving `checkpoints/`,
            `statistics/` and `logs/`
        filename: str, sub-directory name used under each output root
    """

    # Arguments & parameters
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    # NOTE(review): the original read args.freeze_base and then
    # unconditionally overrode it with True; that behaviour is preserved.
    # Confirm whether the command-line value should be honoured instead.
    freeze_base = True
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    workspace_input = args.workspace_input
    workspace_output = args.workspace_output
    filename = args.filename
    pretrain = bool(pretrained_checkpoint_path)

    sample_rate = config.sample_rate
    classes_num = config.classes_num
    clip_samples = config.clip_samples

    # Hard-coded run configuration
    data_type = 'balanced_train'
    loss_type = 'clip_bce'
    balanced = 'balanced'
    augmentation = 'none'
    batch_size = 1
    learning_rate = 1e-3
    resume_iteration = 0
    early_stop = 100000
    num_workers = 8
    loss_func = get_loss_func(loss_type)
    black_list_csv = 'metadata/black_list/groundtruth_weak_label_evaluation_set.csv'
    use_mixup = 'mixup' in augmentation

    # Paths. Built with os.path.join so that workspace_input works with or
    # without a trailing path separator (the original concatenated strings
    # when creating the samplers and left these joined paths unused).
    train_indexes_hdf5_path = os.path.join(workspace_input, 'hdf5s', 'indexes',
                                           '{}.h5'.format(data_type))

    eval_bal_indexes_hdf5_path = os.path.join(workspace_input, 'hdf5s',
                                              'indexes', 'balanced_train.h5')

    eval_test_indexes_hdf5_path = os.path.join(workspace_input, 'hdf5s',
                                               'indexes', 'eval.h5')

    # Directory components shared by the checkpoint, statistics and log
    # locations; they encode the configuration of this run.
    run_subdirs = (
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type),
        model_type,
        'loss_type={}'.format(loss_type),
        'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size),
    )

    checkpoints_dir = os.path.join(workspace_output, 'checkpoints', filename,
                                   *run_subdirs)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace_output, 'statistics', filename,
                                   *(run_subdirs + ('statistics.pkl', )))
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace_output, 'logs', filename, *run_subdirs)

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if device == 'cuda':
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU.')

    # Model. Built exactly once: the original constructed a model, loaded the
    # pretrained weights into it, and then replaced it with a second, freshly
    # initialised instance, so the pretrained weights were never used.
    # NOTE(review): eval() executes whatever string is passed as model_type;
    # only call this with trusted command-line input.
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate,
                  window_size=window_size,
                  hop_size=hop_size,
                  mel_bins=mel_bins,
                  fmin=fmin,
                  fmax=fmax,
                  classes_num=classes_num,
                  freeze_base=freeze_base)

    # Load pretrained model
    if pretrain:
        logging.info(
            'Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)
        print('Load pretrained model successfully!')

    params_num = count_parameters(model)
    logging.info('Parameters num: {}'.format(params_num))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and return a waveform and a target.
    dataset = AudioSetDataset(clip_samples=clip_samples,
                              classes_num=classes_num)

    # Train sampler
    (train_sampler,
     train_collector) = get_train_sampler(balanced, augmentation,
                                          train_indexes_hdf5_path,
                                          black_list_csv, batch_size)

    # Evaluate samplers
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    eval_collector = Collator(mixup_alpha=None)

    # Data loaders
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=train_collector,
                                               num_workers=num_workers,
                                               pin_memory=True)

    eval_bal_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_bal_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    eval_test_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_test_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    # Evaluators
    bal_evaluator = Evaluator(model=model, generator=eval_bal_loader)
    test_evaluator = Evaluator(model=model, generator=eval_test_loader)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    train_bgn_time = time.time()

    # Resume training (inactive while resume_iteration is hard-coded to 0).
    # Note that checkpoints are read from workspace_input, not from
    # checkpoints_dir under workspace_output.
    if resume_iteration > 0:
        resume_checkpoint_path = os.path.join(
            workspace_input, 'checkpoints', filename,
            *(run_subdirs + ('{}_iterations.pth'.format(resume_iteration), )))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        # Remap to the CPU only when no GPU is present, as the original did.
        map_location = None if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(resume_checkpoint_path,
                                map_location=map_location)
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if device == 'cuda':
        model.to(device)

    time1 = time.time()

    for batch_data_dict in train_loader:
        # batch_data_dict: {
        #     'audio_name': (batch_size [*2 if mixup],),
        #     'waveform': (batch_size [*2 if mixup], clip_samples),
        #     'target': (batch_size [*2 if mixup], classes_num),
        #     (if exist) 'mixup_lambda': (batch_size * 2,)}

        # Evaluate every 2000 iterations (and once before any training),
        # skipping the iteration a resumed run starts from.
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = bal_evaluator.evaluate()
            test_statistics = test_evaluator.evaluate()

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))

            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                .format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % 20000 == 0:
            checkpoint = {
                'iteration': iteration,
                # Unwrap DataParallel so the keys carry no 'module.' prefix.
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Forward
        model.train()
        if use_mixup:
            mixup_lambda = batch_data_dict['mixup_lambda']
            # {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'],
                                      mixup_lambda)
            # {'target': (batch_size, classes_num)}
            batch_target_dict = {
                'target': do_mixup(batch_data_dict['target'], mixup_lambda)
            }
        else:
            # {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'], None)
            # {'target': (batch_size, classes_num)}
            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'
                  .format(iteration, time.time() - time1))
            time1 = time.time()

        iteration += 1

        # Stop learning
        if iteration == early_stop:
            break
예제 #11
0
def train(args):
    """Train a model on mini batches with periodic validation and checkpoints.

    Args:
        args: namespace-like object. The attributes read here are
            ``workspace``, ``holdout_fold``, ``model_type``,
            ``pretrained_checkpoint_path``, ``freeze_base``, ``loss_type``,
            ``augmentation``, ``learning_rate``, ``batch_size``,
            ``resume_iteration``, ``stop_iteration``, ``cuda`` and
            ``filename``.

    Side effects visible in this function: creates the checkpoint and
    statistics folders, sets up logging, validates every 200 iterations,
    and writes a checkpoint every 2000 iterations. Returns nothing.
    """

    # Arguments & parameters
    workspace = args.workspace
    holdout_fold = args.holdout_fold
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    freeze_base = args.freeze_base
    loss_type = args.loss_type
    augmentation = args.augmentation
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    resume_iteration = args.resume_iteration
    stop_iteration = args.stop_iteration
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    filename = args.filename
    num_workers = 8

    loss_func = get_loss_func(loss_type)
    pretrain = bool(pretrained_checkpoint_path)

    hdf5_path = os.path.join(workspace, 'features', 'waveform.h5')

    # Sub-directory chain shared by the checkpoint, statistics and log
    # locations; built once so the three paths cannot drift apart.
    run_subdirs = [
        filename,
        'holdout_fold={}'.format(holdout_fold),
        model_type,
        'pretrain={}'.format(pretrain),
        'loss_type={}'.format(loss_type),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size),
        'freeze_base={}'.format(freeze_base),
    ]

    checkpoints_dir = os.path.join(workspace, 'checkpoints', *run_subdirs)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace, 'statistics', *run_subdirs,
                                   'statistics.pickle')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', *run_subdirs)
    create_logging(logs_dir, 'w')
    logging.info(args)

    if 'cuda' in device:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model
    # NOTE(review): eval() executes whatever string is in args.model_type.
    # Acceptable only while that value is trusted; consider an explicit
    # name -> class mapping instead.
    Model = eval(model_type)
    # sample_rate, window_size, hop_size, mel_bins, fmin, fmax and
    # classes_num are not defined in this function; they must be provided
    # by the enclosing module.
    model = Model(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
                  classes_num, freeze_base)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    if pretrain:
        logging.info(
            'Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)

    if resume_iteration:
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))
        logging.info(
            'Load resume model from {}'.format(resume_checkpoint_path))
        # The model is still on the CPU here (it is moved to `device`
        # further down), so map the stored tensors to CPU. Without
        # map_location, a checkpoint written from CUDA tensors cannot be
        # loaded on a machine that has no GPU.
        resume_checkpoint = torch.load(resume_checkpoint_path,
                                       map_location='cpu')
        model.load_state_dict(resume_checkpoint['model'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = resume_checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    dataset = GtzanDataset()

    use_mixup = 'mixup' in augmentation

    # Data generator. With mixup the sampler is asked for twice as many
    # items per batch.
    train_sampler = TrainSampler(
        hdf5_path=hdf5_path,
        holdout_fold=holdout_fold,
        batch_size=batch_size * 2 if use_mixup else batch_size)

    validate_sampler = EvaluateSampler(hdf5_path=hdf5_path,
                                       holdout_fold=holdout_fold,
                                       batch_size=batch_size)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=collate_fn,
                                               num_workers=num_workers,
                                               pin_memory=True)

    validate_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=validate_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True)

    if 'cuda' in device:
        model.to(device)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    if use_mixup:
        mixup_augmenter = Mixup(mixup_alpha=1.)

    # Evaluator
    evaluator = Evaluator(model=model)

    train_bgn_time = time.time()

    # Train on mini batches
    for batch_data_dict in train_loader:

        # Evaluate every 200 iterations, but skip the evaluation on the
        # very iteration a run was resumed from.
        if (iteration % 200 == 0 and iteration > 0
                and not (resume_iteration > 0
                         and iteration == resume_iteration)):
            logging.info('------------------------------------')
            logging.info('Iteration: {}'.format(iteration))

            train_fin_time = time.time()

            statistics = evaluator.evaluate(validate_loader)
            logging.info('Validate accuracy: {:.3f}'.format(
                statistics['accuracy']))

            statistics_container.append(iteration, statistics, 'validate')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info('Train time: {:.3f} s, validate time: {:.3f} s'
                         ''.format(train_time, validate_time))

            train_bgn_time = time.time()

        # Save model every 2000 iterations. The state dict is taken from
        # model.module so the DataParallel wrapper is not part of the keys.
        if iteration % 2000 == 0 and iteration > 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        if use_mixup:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                len(batch_data_dict['waveform']))

        # Move data to GPU
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Train
        model.train()

        if use_mixup:
            # batch_output_dict:
            #   {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])

            # batch_target_dict: {'target': (batch_size, classes_num)}
            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
        else:
            # batch_output_dict:
            #   {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'], None)

            # batch_target_dict: {'target': (batch_size, classes_num)}
            batch_target_dict = {'target': batch_data_dict['target']}

        # loss
        loss = loss_func(batch_output_dict, batch_target_dict)
        print(iteration, loss)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Stop learning (checked before the counter is advanced, so the
        # batch for iteration == stop_iteration is still trained on)
        if iteration == stop_iteration:
            break

        iteration += 1