コード例 #1
0
def evaluate(model, input, target, stats_dir, probs_dir, iteration):
    """Evaluate a model and write its predictions and statistics to disk.

    Predictions are pickled to ``probs_dir/prob_<iteration>_iters.p`` and the
    per-class statistics from ``utilities.calculate_stats`` are pickled to
    ``stats_dir/stat_<iteration>_iters.p``. Mean AP and mean AUC are logged.

    Args:
      model: torch module; forward passes run on CUDA only if its parameters
        already live on a CUDA device.
      input: 3d array, (clips_num, time_steps, freq_bins).
      target: 2d array, (samples_num, classes_num).
      stats_dir: str, directory to write out statistics.
      probs_dir: str, directory to write out output (samples_num, classes_num)
      iteration: int, used to name the output files.

    Returns:
      None
    """
    # Check if cuda: follow the device the model's parameters are on.
    cuda = next(model.parameters()).is_cuda

    utilities.create_folder(stats_dir)
    utilities.create_folder(probs_dir)

    # Predict presence probabilities
    callback_time = time.time()
    # Unpacking doubles as a check that input is 3-D; the sizes are unused.
    (clips_num, time_steps, freq_bins) = input.shape

    (input, target) = utilities.transform_data(input, target)

    output = forward_in_batch(model, input, batch_size=500, cuda=cuda)
    output = output.data.cpu().numpy()  # (clips_num, classes_num)

    # Write out presence probabilities. The context manager closes the file
    # (the previous bare open() leaked the handle).
    prob_path = os.path.join(probs_dir, "prob_{}_iters.p".format(iteration))
    with open(prob_path, 'wb') as prob_file:
        cPickle.dump(output, prob_file)

    # Calculate statistics
    stats = utilities.calculate_stats(output, target)

    # Write out statistics
    stat_path = os.path.join(stats_dir, "stat_{}_iters.p".format(iteration))
    with open(stat_path, 'wb') as stat_file:
        cPickle.dump(stats, stat_file)

    mAP = np.mean([stat['AP'] for stat in stats])
    mAUC = np.mean([stat['auc'] for stat in stats])
    logging.info("mAP: {:.6f}, AUC: {:.6f}, Callback time: {:.3f} s".format(
        mAP, mAUC,
        time.time() - callback_time))
コード例 #2
0
def test(args):
    """Evaluate a saved checkpoint on the evaluation set.

    Loads the weights stored in
    ``<workspace>/models/<filename>/balance_type=<..>/model_type=<..>/
    md_1200_iters_300batchsize.tar`` into ``args.model`` and logs test
    statistics through ``evaluate``.

    Args:
      args: namespace providing data_dir, workspace, balance_type, filename,
        model_type and model (a torch module to load the weights into).

    Returns:
      None
    """
    data_dir = args.data_dir
    workspace = args.workspace
    balance_type = args.balance_type
    filename = args.filename
    model_type = args.model_type
    model = args.model

    # Test data
    test_hdf5_path = os.path.join(data_dir, "eval1.h5")
    (test_x, test_y, _) = utilities.load_data(test_hdf5_path)

    # Output directories
    sub_dir = os.path.join(filename, 'balance_type={}'.format(balance_type),
                           'model_type={}'.format(model_type))

    models_dir = os.path.join(workspace, "models", sub_dir)
    utilities.create_folder(models_dir)

    stats_dir = os.path.join(workspace, "stats", sub_dir)
    utilities.create_folder(stats_dir)

    probs_dir = os.path.join(workspace, "probs", sub_dir)
    utilities.create_folder(probs_dir)

    iteration = 1200

    # Load the model weights. Only the weights are needed for evaluation, so
    # the optimizer state stored in the checkpoint is not restored.
    # map_location keeps the loaded tensors on the CPU so that a checkpoint
    # saved on a GPU also loads on a CPU-only machine; load_state_dict then
    # copies them onto whatever device the model's parameters are on.
    checkpoint_path = os.path.join(
        models_dir, "md_{}_iters_300batchsize.tar".format(iteration))
    checkpoint = torch.load(checkpoint_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    # Inference mode (dropout off, batch-norm uses running statistics).
    model.eval()

    logging.info("Test data shape: {}".format(test_x.shape))
    logging.info("Test target shape: {}".format(test_y.shape))

    logging.info("Test statistics:")
    evaluate(model=model,
             input=test_x,
             target=test_y,
             stats_dir=os.path.join(stats_dir, "test"),
             probs_dir=os.path.join(probs_dir, "test"),
             iteration=iteration)

    print('ready')
コード例 #3
0
    # Command-line sub-commands. The selected one is stored in args.mode
    # (None when no sub-command is given).
    # NOTE(review): `parser` and the options read below (workspace,
    # balance_type, model_type, ...) are defined outside the visible lines.
    subparsers = parser.add_subparsers(dest='mode')
    parser_train = subparsers.add_parser('train')
    parser_get_avg_stats = subparsers.add_parser('get_avg_stats')

    args = parser.parse_args()

    # Presumably the script's base name; used as a sub-directory for outputs.
    args.filename = utilities.get_filename(__file__)

    # Logs: one directory per (script, balance_type, model_type) combination.
    sub_dir = os.path.join(args.filename,
                           'balance_type={}'.format(args.balance_type),
                           'model_type={}'.format(args.model_type))

    logs_dir = os.path.join(args.workspace, 'logs', sub_dir)
    utilities.create_folder(logs_dir)
    # Rebinds the name `logging` to whatever create_logging returns.
    logging = utilities.create_logging(logs_dir, filemode='w')

    logging.info(os.path.abspath(__file__))
    logging.info(args)
    
    # Debug switch. While totest == 0, every mode except 'get_avg_stats'
    # (including 'train' and no sub-command) runs test(args). Set totest to a
    # non-zero value to dispatch 'train' to train(args).
    # NOTE(review): 'get_avg_stats' does nothing in either setting, and both
    # test() and train() read args.model, which is not assigned in the
    # visible lines - confirm it is set elsewhere.
    totest = 0 
    
    if totest == 0 and (not (args.mode == 'get_avg_stats')): 
        test(args)
        
    
    else: 
        if args.mode == "train":
            train(args)
コード例 #4
0
def train(args):
    """Train a Keras-style model (compile / train_on_batch / save) on HDF5 data.

    Loads the balanced training set (plus the unbalanced one unless
    ``args.mini_data`` is set) and the evaluation set, then trains on batches
    from the selected data generator. Every ``call_freq`` iterations the
    balanced training set and the test set are evaluated; every ``save_freq``
    iterations the model is saved. Training stops once iteration 50001 is
    reached.

    Args:
      args: namespace with data_dir, workspace, mini_data, balance_type,
        learning_rate, filename, model_type, model and batch_size.

    Raises:
      Exception: if ``args.balance_type`` is neither 'no_balance' nor
        'balance_in_batch'.
    """
    data_dir = args.data_dir
    workspace = args.workspace
    mini_data = args.mini_data
    balance_type = args.balance_type
    learning_rate = args.learning_rate
    filename = args.filename
    model_type = args.model_type
    model = args.model
    batch_size = args.batch_size

    # Path of hdf5 data
    bal_train_hdf5_path = os.path.join(data_dir, "bal_train.h5")
    unbal_train_hdf5_path = os.path.join(data_dir, "unbal_train.h5")
    test_hdf5_path = os.path.join(data_dir, "eval.h5")

    # Load data
    load_time = time.time()

    # The balanced subset is always loaded: it is part of the training data
    # and is also evaluated periodically below.
    (bal_train_x, bal_train_y, _) = utilities.load_data(bal_train_hdf5_path)

    if mini_data:
        # Train on balanced data only
        train_x = bal_train_x
        train_y = bal_train_y
    else:
        # Train on both balanced and unbalanced data
        (unbal_train_x, unbal_train_y, _) = utilities.load_data(
            unbal_train_hdf5_path)
        train_x = np.concatenate((bal_train_x, unbal_train_x))
        train_y = np.concatenate((bal_train_y, unbal_train_y))

    # Test data
    (test_x, test_y, _) = utilities.load_data(test_hdf5_path)

    logging.info("Loading data time: {:.3f} s".format(time.time() - load_time))
    logging.info("Training data shape: {}".format(train_x.shape))

    # Optimization method
    optimizer = Adam(lr=learning_rate)
    model.compile(loss='binary_crossentropy', optimizer=optimizer)

    # Output directories
    sub_dir = os.path.join(filename, 'balance_type={}'.format(balance_type),
                           'model_type={}'.format(model_type))

    models_dir = os.path.join(workspace, "models", sub_dir)
    utilities.create_folder(models_dir)

    stats_dir = os.path.join(workspace, "stats", sub_dir)
    utilities.create_folder(stats_dir)

    probs_dir = os.path.join(workspace, "probs", sub_dir)
    utilities.create_folder(probs_dir)

    # Data generator
    if balance_type == 'no_balance':
        DataGenerator = data_generator.VanillaDataGenerator
    elif balance_type == 'balance_in_batch':
        DataGenerator = data_generator.BalancedDataGenerator
    else:
        raise Exception("Incorrect balance_type!")

    train_gen = DataGenerator(x=train_x,
                              y=train_y,
                              batch_size=batch_size,
                              shuffle=True,
                              seed=1234)

    iteration = 0
    call_freq = 1000  # evaluate every call_freq iterations
    save_freq = 5000  # save a checkpoint every save_freq iterations
    train_time = time.time()

    for (batch_x, batch_y) in train_gen.generate():

        # Compute stats every several iterations
        if iteration % call_freq == 0:

            logging.info("------------------")

            logging.info("Iteration: {}, train time: {:.3f} s".format(
                iteration,
                time.time() - train_time))

            logging.info("Balance train statistics:")
            evaluate(model=model,
                     input=bal_train_x,
                     target=bal_train_y,
                     stats_dir=os.path.join(stats_dir, 'bal_train'),
                     probs_dir=os.path.join(probs_dir, 'bal_train'),
                     iteration=iteration)

            logging.info("Test statistics:")
            evaluate(model=model,
                     input=test_x,
                     target=test_y,
                     stats_dir=os.path.join(stats_dir, "test"),
                     probs_dir=os.path.join(probs_dir, "test"),
                     iteration=iteration)

            train_time = time.time()

        # Update params
        (batch_x, batch_y) = utilities.transform_data(batch_x, batch_y)
        model.train_on_batch(x=batch_x, y=batch_y)

        iteration += 1

        # Save model periodically. Saving after every single update wrote one
        # model file per iteration (tens of thousands of files).
        if iteration % save_freq == 0:
            save_out_path = os.path.join(models_dir,
                                         "md_{}_iters.h5".format(iteration))
            model.save(save_out_path)
            logging.info("Save model to {}".format(save_out_path))

        # Stop training when maximum iteration achieves
        if iteration == 50001:
            break
def _experiment_subdir(sample_rate, window_size, hop_size, mel_bins, fmin,
                       fmax, model_type, loss_type, balanced, augmentation,
                       batch_size):
    """Return the relative directory path that identifies one training configuration."""
    return os.path.join(
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        model_type,
        'loss_type={}'.format(loss_type),
        'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size))


def train(args):
    """Train / fine-tune an audio classifier on the German bird-call dataset.

    The filtered CSV is split 60/40 (stratified on ``gen``) into train and
    test sets. The model is evaluated on the test split every 200 iterations,
    and the weights with the best test mAP are kept. When training stops
    (``args.early_stop`` or end of the loader), the best weights, the
    validation history and the training-loss history are written to the
    current working directory.

    Args:
      args: namespace with workspace, sample_rate, window_size, hop_size,
        mel_bins, fmin, fmax, model_type, loss_type, balanced, augmentation,
        batch_size, learning_rate, early_stop, filename,
        pretrained_checkpoint_path and freeze_base_num.

    Returns:
      None
    """
    # Arguments & parameters (from main.py)
    workspace = args.workspace  # store experiment results in the workspace
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    early_stop = args.early_stop
    filename = args.filename
    # For fine-tuning pretrained models
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    freeze_base_num = args.freeze_base_num

    pretrain = bool(pretrained_checkpoint_path)

    # Saving paths: all outputs live under a directory named after the full
    # training configuration.
    config_dir = _experiment_subdir(sample_rate, window_size, hop_size,
                                    mel_bins, fmin, fmax, model_type,
                                    loss_type, balanced, augmentation,
                                    batch_size)

    # NOTE(review): this directory is created, but the best model is saved to
    # the current working directory at the end of training.
    best_model_dir = os.path.join(workspace, 'best_model', filename, config_dir)
    create_folder(os.path.dirname(best_model_dir))

    statistics_path = os.path.join(workspace, 'statistics', filename,
                                   config_dir, 'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename, config_dir)
    create_logging(logs_dir, filemode='w')

    # Dataset: each item is a waveform and a one-hot encoded target. The
    # training CSV already had minor classes filtered out (dropping
    # threshold 10).
    train_csv = pd.read_csv("German_Birdcall_Dataset_Preparation/Germany_Birdcall_resampled_filtered.csv")
    classes_num = len(train_csv["gen"].unique())
    audio_path = "German_Birdcall_Dataset_Preparation/Germany_Birdcall_resampled"

    # A single stratified 60/40 train/test split on the "gen" column.
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.4, random_state=42)
    train_idx, test_idx = next(splitter.split(X=train_csv, y=train_csv["gen"]))
    train_df = train_csv.loc[train_idx, :].reset_index(drop=True)
    test_df = train_csv.loc[test_idx, :].reset_index(drop=True)
    train_dataset = WaveformDataset(df=train_df, datadir=audio_path)
    test_dataset = WaveformDataset(df=test_df, datadir=audio_path)

    # Train sampler and loader. With mixup the sampler draws
    # batch_size * 2 clips per batch.
    num_workers = 10
    sampler_cls = BalancedSampler if balanced == 'balanced' else RandomSampler
    train_sampler = sampler_cls(
        df=train_df,
        batch_size=batch_size * 2 if 'mixup' in augmentation else batch_size)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_sampler=train_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True)

    eval_test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers)

    if 'mixup' in augmentation:
        mixup_augmenter = Mixup(mixup_alpha=1.)

    # Model initialization, e.g. model_type = "Transfer_Cnn14".
    # SECURITY NOTE: eval() executes the model_type string taken from args;
    # only pass trusted values.
    transfer_model = eval(model_type)
    model = transfer_model(sample_rate, window_size, hop_size, mel_bins, fmin,
                           fmax, classes_num, freeze_base_num)

    logging.info(args)

    # Load pretrained weights, e.g. "Cnn14_mAP=0.431.pth",
    # "Cnn10_mAP=0.380.pth" or "Cnn6_mAP=0.343.pth".
    if pretrain:
        logging.info('Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)
        print('Load pretrained model successfully!')

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if 'cuda' in device:
        model.to(device)
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    loss_func = get_loss_func(loss_type)

    # Evaluator: returns average precision and AUC per class.
    evaluator = Evaluator(model=model)

    statistics_container = StatisticsContainer(statistics_path)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate,
                           betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)

    # Training loop
    time_initial = time.time()
    train_bgn_time = time.time()
    time1 = time.time()
    iteration = 0
    loss_sum = 0
    loss_count = 0  # batches accumulated into loss_sum since the last report
    best_mAP = 0
    best_model = None
    # Validation results and training losses, one row per report.
    validation_results = pd.DataFrame(columns=["iteration", "mAP", "Auc"])
    training_results = pd.DataFrame(columns=["iteration", "average loss"])
    val_row = 0
    train_row = 0
    for batch_data_dict in train_loader:
        # batch_data_dict: {
        #     'audio_name': (batch_size [*2 if mixup],),
        #     'waveform': (batch_size [*2 if mixup], clip_samples),
        #     'target': (batch_size [*2 if mixup], classes_num),
        #     (if exists) 'mixup_lambda': (batch_size * 2,)}

        # Evaluate on the test split every 200 iterations (including 0).
        if iteration % 200 == 0:
            train_fin_time = time.time()

            test_statistics = evaluator.evaluate(eval_test_loader)
            current_mAP = np.mean(test_statistics['average_precision'])
            current_auc = np.mean(test_statistics['auc'])
            logging.info('Validate test mAP: {:.3f}'.format(current_mAP))
            logging.info('Validate test Auc: {:.3f}'.format(current_auc))
            validation_results.loc[val_row] = [iteration, current_mAP, current_auc]
            val_row += 1

            statistics_container.append(iteration, test_statistics, data_type='test')
            statistics_container.dump()

            # Keep a copy of the weights with the best test mAP so far.
            if current_mAP > best_mAP:
                best_mAP = current_mAP
                best_model = copy.deepcopy(model.state_dict())

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))
            logging.info('------------------------------------')

            train_bgn_time = time.time()  # reset after evaluation

        # Mixup lambda
        if 'mixup' in augmentation:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                batch_size=len(batch_data_dict['waveform']))

        # Move data to device
        for key in batch_data_dict:
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device)

        # Forward. Output: {'clipwise_output': (batch_size, classes_num), ...};
        # target: {'target': (batch_size, classes_num)}.
        model.train()
        if 'mixup' in augmentation:
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])
            batch_target_dict = {'target': do_mixup(batch_data_dict['target'],
                                                    batch_data_dict['mixup_lambda'])}
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward
        loss.backward()
        loss_sum += loss.item()
        loss_count += 1

        optimizer.step()
        optimizer.zero_grad()

        # Report the mean loss over the batches seen since the last report.
        # Dividing by the real count (not a fixed 200) keeps the first
        # report, which covers a single batch, correct.
        if iteration % 200 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 200 iterations ---'
                  .format(iteration, time.time() - time1))
            time1 = time.time()
            loss_average = loss_sum / loss_count
            print("average loss of recent 200 batches {:.5f}".format(loss_average))
            loss_sum = 0
            loss_count = 0
            training_results.loc[train_row] = [iteration, loss_average]
            train_row += 1

        # Stop learning
        if iteration == early_stop:
            break

        iteration += 1

    # Fall back to the final weights if no evaluation ever improved on mAP 0
    # (e.g. empty loader or every mAP was 0/NaN); previously this raised
    # NameError on the undefined best_model.
    if best_model is None:
        best_model = copy.deepcopy(model.state_dict())

    # Result files are written to the current working directory and tagged
    # with the configuration and the best mAP.
    run_tag = (model_type + balanced + augmentation + "freeze"
               + str(freeze_base_num) + "_mAP={:.3f}".format(best_mAP))

    # Save model
    torch.save(best_model, "best_" + run_tag + ".pth")

    # Save validation results
    validation_results.to_csv("validation_results" + run_tag + '.csv', index=False)

    # Save training results
    training_results.to_csv("training_results" + run_tag + '.csv', index=False)

    time_cost = time.time() - time_initial
    print("The whole training process takes: {:.3f} s".format(time_cost))
コード例 #6
0
def train(args):
    """Train a PyTorch model on HDF5 features with binary cross-entropy.

    Loads the balanced training set (plus the unbalanced one unless
    ``args.mini_data`` is set) and the evaluation set, then trains with Adam
    on batches from the selected data generator. Every ``call_freq``
    iterations (after the first) the balanced training set and the test set
    are evaluated. Every 5000 iterations a checkpoint with the model and
    optimizer state is saved. Training stops once iteration 20001 is reached.

    Args:
      args: namespace with data_dir, workspace, mini_data, balance_type,
        learning_rate, filename, model_type, model and batch_size.

    Raises:
      Exception: if ``args.balance_type`` is neither 'no_balance' nor
        'balance_in_batch'.
    """
    data_dir = args.data_dir
    workspace = args.workspace
    mini_data = args.mini_data
    balance_type = args.balance_type
    learning_rate = args.learning_rate
    filename = args.filename
    model_type = args.model_type
    model = args.model
    batch_size = args.batch_size
    # Use the GPU only when one is available instead of assuming it.
    cuda = torch.cuda.is_available()

    # Move model to gpu
    if cuda:
        model.cuda()

    # Path of hdf5 data
    bal_train_hdf5_path = os.path.join(data_dir, "bal_train.h5")
    unbal_train_hdf5_path = os.path.join(data_dir, "unbal_train.h5")
    test_hdf5_path = os.path.join(data_dir, "eval.h5")

    # Load data
    load_time = time.time()

    # The balanced subset is always loaded: it is part of the training data
    # and is also evaluated periodically below.
    (bal_train_x, bal_train_y, _) = utilities.load_data(bal_train_hdf5_path)

    if mini_data:
        # Train on balanced data only
        train_x = bal_train_x
        train_y = bal_train_y
    else:
        # Train on both balanced and unbalanced data
        (unbal_train_x, unbal_train_y, _) = utilities.load_data(
            unbal_train_hdf5_path)
        train_x = np.concatenate((bal_train_x, unbal_train_x))
        train_y = np.concatenate((bal_train_y, unbal_train_y))

    # Test data
    (test_x, test_y, _) = utilities.load_data(test_hdf5_path)

    logging.info("Loading data time: {:.3f} s".format(time.time() - load_time))
    logging.info("Training data shape: {}".format(train_x.shape))

    # Optimization method. Uses the configured learning rate (it was
    # previously read from args but ignored in favour of a hard-coded 1e-3).
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-07)

    # Output directories
    sub_dir = os.path.join(filename, 'balance_type={}'.format(balance_type),
                           'model_type={}'.format(model_type))

    models_dir = os.path.join(workspace, "models", sub_dir)
    utilities.create_folder(models_dir)

    stats_dir = os.path.join(workspace, "stats", sub_dir)
    utilities.create_folder(stats_dir)

    probs_dir = os.path.join(workspace, "probs", sub_dir)
    utilities.create_folder(probs_dir)

    # Data generator
    if balance_type == 'no_balance':
        DataGenerator = data_generator.VanillaDataGenerator
    elif balance_type == 'balance_in_batch':
        DataGenerator = data_generator.BalancedDataGenerator
    else:
        raise Exception("Incorrect balance_type!")

    train_gen = DataGenerator(x=train_x,
                              y=train_y,
                              batch_size=batch_size,
                              shuffle=True,
                              seed=1234)

    iteration = 0
    call_freq = 1000  # evaluate every call_freq iterations (skipping 0)
    train_time = time.time()

    for (batch_x, batch_y) in train_gen.generate():

        # Compute stats every several iterations
        if iteration % call_freq == 0 and iteration > 1:

            logging.info("------------------")

            logging.info("Iteration: {}, train time: {:.3f} s".format(
                iteration,
                time.time() - train_time))

            logging.info("Balance train statistics:")
            evaluate(model=model,
                     input=bal_train_x,
                     target=bal_train_y,
                     stats_dir=os.path.join(stats_dir, 'bal_train'),
                     probs_dir=os.path.join(probs_dir, 'bal_train'),
                     iteration=iteration)

            logging.info("Test statistics:")
            evaluate(model=model,
                     input=test_x,
                     target=test_y,
                     stats_dir=os.path.join(stats_dir, "test"),
                     probs_dir=os.path.join(probs_dir, "test"),
                     iteration=iteration)

            train_time = time.time()

        (batch_x, batch_y) = utilities.transform_data(batch_x, batch_y)

        batch_x = move_data_to_gpu(batch_x, cuda)
        batch_y = move_data_to_gpu(batch_y, cuda)

        # Forward.
        model.train()
        output = model(batch_x)

        # Loss.
        loss = F.binary_cross_entropy(output, batch_y)

        # Backward.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iteration += 1

        # Save model.
        if iteration % 5000 == 0:
            save_out_dict = {
                'iteration': iteration,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            save_out_path = os.path.join(models_dir,
                                         "md_{}_iters.tar".format(iteration))
            torch.save(save_out_dict, save_out_path)
            logging.info("Save model to {}".format(save_out_path))

        # Stop training when maximum iteration achieves
        if iteration == 20001:
            break
コード例 #7
0
def evaluate(model, input, target, stats_dir, probs_dir, iteration,
             out_path=None):
    """Run a model over ``input`` and dump its raw outputs to an HDF5 file.

    This variant does not compute or log mAP/AUC statistics. It stores the
    four tensors returned by ``forward_in_batch`` (``output``, ``cla``,
    ``norm_att`` and ``mult``) as HDF5 datasets with those names.

    Args:
      model: object, forwarded to ``forward_in_batch``.
      input: 3d array, (clips_num, time_steps, freq_bins).
      target: 2d array, (samples_num, classes_num); passed through
        ``utilities.transform_data`` but otherwise unused here.
      stats_dir: str, created if missing; nothing is written to it here.
      probs_dir: str, created if missing; nothing is written to it here.
      iteration: int, unused here; kept so the signature matches the other
        ``evaluate`` variant.
      out_path: str or None, HDF5 file to (over)write. Defaults to the
        original hard-coded ``...\\data\\multyt.h5`` location.

    Returns:
      None
    """
    # NOTE(review): CUDA is forced on; the author disabled the device check
    # (next(model.parameters()).is_cuda). Restore it for CPU-only runs.
    cuda = True

    utilities.create_folder(stats_dir)
    utilities.create_folder(probs_dir)

    # Unpacking doubles as a check that input is 3-D; the sizes are unused.
    (clips_num, time_steps, freq_bins) = input.shape

    (input, target) = utilities.transform_data(input, target)

    # Clip-level output plus three intermediate tensors from the model.
    output, cla, norm_att, mult = forward_in_batch(model,
                                                   input,
                                                   batch_size=350,
                                                   cuda=cuda)

    output = output.data.cpu().numpy()  # (clips_num, classes_num)

    print("output_all cat: ", output.shape)
    print("cla_all cat: ", cla.shape)
    print("cla_all cat: ", cla)
    print("mult_all cat: ", mult)
    print("norm_att_all cat: ", norm_att)

    cla = cla.data.cpu().numpy()
    norm_att = norm_att.data.cpu().numpy()
    mult = mult.data.cpu().numpy()

    dataset = {
        'output': output,
        'cla': cla,
        'norm_att': norm_att,
        'mult': mult,
    }

    if out_path is None:
        # Plain concatenation (not os.path.join) reproduces the original
        # Windows path on every platform.
        out_path = (r'C:\Users\AdexGomez\Downloads\Master\third_semester\Deep_Learning\Project\02456_project_audioset_attention\data'
                    + '\\multyt.h5')
    print(out_path)

    # The context manager guarantees the file is flushed and closed; the
    # handle used to be left open.
    with h5py.File(out_path, 'w') as f:
        for k in dataset:
            print('\n  ' + k + ' type=', type(dataset[k][0]))
            print('    ... saving ' + k + '...')
            f.create_dataset(k, data=dataset[k])