# Example #1
# 0
def _log_parameter_histograms(logger, model, step):
    """Write value and gradient histograms for every parameter of ``model``.

    Dotted parameter names are rewritten with '/' for the logger tags. The
    gradient histogram is skipped for parameters whose ``.grad`` is still
    ``None`` (never reached by a backward pass), which would otherwise raise
    ``AttributeError``.
    """
    for name, param in model.named_parameters():
        tag = name.replace('.', '/')
        logger.histo_summary(tag, param.detach().cpu().numpy(), step)
        if param.grad is not None:
            logger.histo_summary(tag + '/grad',
                                 param.grad.detach().cpu().numpy(), step)


def _score_batch(model, criterion, inputs, device):
    """Run one collated batch through ``model`` and score it.

    Args:
        model: network called as ``model(claims_tensors, evidences_tensors)``.
        criterion: loss applied to ``(y_pred, target_class_indices)``.
        inputs: 5-tuple from the DataLoader's collate function:
            ``(claims_tensors, claims_text, evidences_tensors,
            evidences_text, labels)``.
        device: device the tensors are moved to before the forward pass.

    Returns:
        ``(claims_text, evidences_text, labels, y_pred, loss, binary_y,
        binary_pred)`` where ``binary_y`` / ``binary_pred`` are the arg-max
        indices along dim 1 of ``labels`` / ``y_pred``.
    """
    (claims_tensors, claims_text, evidences_tensors, evidences_text,
     labels) = inputs
    claims_tensors = claims_tensors.to(device)
    evidences_tensors = evidences_tensors.to(device)
    labels = labels.to(device)

    y_pred = model(claims_tensors, evidences_tensors).squeeze()
    if y_pred.dim() == 1:
        # squeeze() removes *every* size-1 axis, including the batch axis of
        # a one-example batch; restore it so the loss and max(..., 1) work.
        y_pred = y_pred.unsqueeze(0)

    binary_y = torch.max(labels, 1)[1]
    loss = criterion(y_pred, binary_y)
    binary_pred = torch.max(y_pred, 1)[1]
    return (claims_text, evidences_text, labels, y_pred, loss, binary_y,
            binary_pred)


def run(args, train, sparse_evidences, claims_dict):
    """Train a CDSSM claim/evidence model, validating after every epoch.

    The first 80% of ``train`` is used for training and the remaining 20%
    for validation. Progress is printed, sent to a Comet ``Experiment`` and
    written to a ``Logger`` (scalar and histogram summaries). After each
    epoch a checkpoint is saved with ``putils.save_checkpoint``; ``is_best``
    is true when the epoch's mean validation loss is no worse than the best
    seen so far.

    Args:
        args: parsed command-line namespace. Reads ``batch_size``,
            ``learning_rate``, ``data_sampling``, ``epochs``, ``model`` (path
            of a saved model object to start from, or falsy to build a fresh
            ``cdssm.CDSSM``), ``no_randomize``, ``print``, ``data`` and
            ``sparse_evidences``.
        train: sliceable sequence of examples for
            ``pytorch_data_loader.WikiDataset``.
        sparse_evidences: forwarded to ``WikiDataset``.
        claims_dict: forwarded to ``WikiDataset``.
    """
    import os  # only for the optional COMET_API_KEY override below

    BATCH_SIZE = args.batch_size
    LEARNING_RATE = args.learning_rate
    DATA_SAMPLING = args.data_sampling
    NUM_EPOCHS = args.epochs
    MODEL = args.model
    RANDOMIZE = args.no_randomize
    PRINT = args.print

    # Use the selected device everywhere instead of hard-coded .cuda(), so
    # the script also runs on CPU-only hosts.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    logger = Logger('./logs/{}'.format(time.localtime()))

    if MODEL:
        print("Loading pretrained model...")
        # The file holds a whole pickled model object; a single load is
        # enough (re-loading it and copying its state_dict was a no-op).
        model = torch.load(MODEL, map_location=device)
    else:
        model = cdssm.CDSSM()
    model = model.to(device)

    if torch.cuda.device_count() > 0:
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
        model = nn.DataParallel(model)

    print("Created model with {:,} parameters.".format(
        putils.count_parameters(model)))

    print("Created dataset...")

    def make_loader(examples):
        """Build a WikiDataset over ``examples`` and a shuffled, padded loader."""
        dataset = pytorch_data_loader.WikiDataset(
            examples,
            claims_dict,
            data_sampling=DATA_SAMPLING,
            sparse_evidences=sparse_evidences,
            randomize=RANDOMIZE)
        loader = DataLoader(dataset,
                            batch_size=BATCH_SIZE,
                            num_workers=0,
                            shuffle=True,
                            collate_fn=pytorch_data_loader.PadCollate())
        return dataset, loader

    # use an 80/20 train/validate split!
    train_size = int(len(train) * 0.80)
    train_dataset, train_dataloader = make_loader(train[:train_size])
    val_dataset, val_dataloader = make_loader(train[train_size:])

    # Loss and optimizer
    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=LEARNING_RATE,
                                 weight_decay=1e-3)

    OUTPUT_FREQ = max(int((len(train_dataset) / BATCH_SIZE) * 0.02), 20)
    parameters = {
        "batch size": BATCH_SIZE,
        "epochs": NUM_EPOCHS,
        "learning rate": LEARNING_RATE,
        "optimizer": optimizer.__class__.__name__,
        "loss": criterion.__class__.__name__,
        "training size": train_size,
        "data sampling rate": DATA_SAMPLING,
        "data": args.data,
        "sparse_evidences": args.sparse_evidences,
        "randomize": RANDOMIZE,
        "model": MODEL
    }
    # SECURITY: a Comet API key is committed to source control. It can be
    # supplied via the COMET_API_KEY environment variable; the literal is
    # kept only as a backward-compatible fallback -- rotate and remove it.
    experiment = Experiment(api_key=os.environ.get(
        "COMET_API_KEY", "YLsW4AvRTYGxzdDqlWRGCOhee"),
                            project_name="clsm",
                            workspace="moinnadeem")
    experiment.add_tag("train")
    experiment.log_asset("cdssm.py")
    experiment.log_dataset_info(name=args.data)
    experiment.log_parameters(parameters)

    # Encode every hyper-parameter (except the source model path) in the
    # checkpoint name; '/' is replaced so string values stay one path part.
    model_checkpoint_dir = "models/saved_model"
    for key, value in parameters.items():
        if key == "model":
            continue
        if isinstance(value, str):
            value = value.replace("/", "-")
        model_checkpoint_dir += "_{}-{}".format(key.replace(" ", "_"), value)

    print("Training...")
    best_loss = torch.tensor(float("inf"),
                             dtype=torch.float)  # begin loss at infinity
    num_train_batches = len(train_dataloader)
    num_val_batches = len(val_dataloader)

    for epoch in range(NUM_EPOCHS):
        beginning_time = time.time()
        train_running_loss = 0.0
        train_running_accuracy = 0.0
        model.train()
        experiment.log_current_epoch(epoch)

        with experiment.train():
            for train_batch_num, inputs in enumerate(train_dataloader):
                (claims_text, evidences_text, labels, y_pred, loss, binary_y,
                 binary_pred) = _score_batch(model, criterion, inputs, device)

                accuracy = (binary_y == binary_pred).float().mean()
                train_running_accuracy += accuracy.item()
                train_running_loss += loss.item()

                if PRINT:
                    for idx in range(len(labels)):
                        # NOTE(review): always prints claims_text[0] rather
                        # than claims_text[idx]; confirm this is intended.
                        print(
                            "Claim: {}, Evidence: {}, Prediction: {}, Label: {}"
                            .format(claims_text[0], evidences_text[idx],
                                    torch.exp(y_pred[idx]),
                                    labels[idx].float()))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Report once per OUTPUT_FREQ batches so each window averages
                # exactly OUTPUT_FREQ batches.
                if (train_batch_num + 1) % OUTPUT_FREQ == 0:
                    # Monotonic across epochs (batch * (epoch + 1) repeated
                    # and reordered steps).
                    global_step = epoch * num_train_batches + train_batch_num
                    elapsed_time = time.time() - beginning_time
                    train_loss = train_running_loss / OUTPUT_FREQ
                    train_accuracy = train_running_accuracy / OUTPUT_FREQ
                    print(
                        "[{}:{}:{:.3f}s] training loss: {}, training accuracy: {}, training recall: {}"
                        .format(
                            epoch, train_batch_num /
                            (len(train_dataset) / BATCH_SIZE), elapsed_time,
                            train_loss, train_accuracy,
                            # recall of the current batch only
                            recall_score(binary_y.cpu().detach().numpy(),
                                         binary_pred.cpu().detach().numpy())))

                    # 1. Log scalar values (scalar summary)
                    info = {
                        'train_loss': train_loss,
                        'train_accuracy': train_accuracy
                    }
                    for tag, value in info.items():
                        experiment.log_metric(tag, value, step=global_step)
                        logger.scalar_summary(tag, value, global_step + 1)

                    # 2. Parameter values and gradients; logged after
                    # optimizer.step() so the gradients are this batch's.
                    _log_parameter_histograms(logger, model, global_step + 1)

                    train_running_loss = 0.0
                    train_running_accuracy = 0.0
                    beginning_time = time.time()

        print("Running validation...")
        model.eval()
        pred = []
        true = []
        avg_loss = 0.0
        val_running_accuracy = 0.0
        val_running_loss = 0.0
        beginning_time = time.time()
        # Validation never calls backward(): skip building autograd graphs.
        # Parameter histograms are not re-logged here -- weights are frozen
        # and the gradients are stale from training.
        with experiment.validate(), torch.no_grad():
            for val_batch_num, val_inputs in enumerate(val_dataloader):
                (_, _, _, _, loss, binary_y,
                 binary_pred) = _score_batch(model, criterion, val_inputs,
                                             device)
                true.extend(binary_y.tolist())
                pred.extend(binary_pred.tolist())

                accuracy = (binary_y == binary_pred).float().mean()
                val_running_accuracy += accuracy.item()
                val_running_loss += loss.item()
                avg_loss += loss.item()

                if (val_batch_num + 1) % OUTPUT_FREQ == 0:
                    global_step = epoch * num_val_batches + val_batch_num
                    elapsed_time = time.time() - beginning_time
                    val_accuracy = val_running_accuracy / OUTPUT_FREQ
                    print(
                        "[{}:{}:{:.3f}s] validation loss: {}, accuracy: {}, recall: {}"
                        .format(
                            epoch,
                            val_batch_num / (len(val_dataset) / BATCH_SIZE),
                            elapsed_time, val_running_loss / OUTPUT_FREQ,
                            val_accuracy,
                            recall_score(binary_y.cpu().detach().numpy(),
                                         binary_pred.cpu().detach().numpy())))

                    experiment.log_metric('val_accuracy',
                                          val_accuracy,
                                          step=global_step)
                    logger.scalar_summary('val_accuracy', val_accuracy,
                                          global_step + 1)

                    val_running_accuracy = 0.0
                    val_running_loss = 0.0
                    beginning_time = time.time()

        mean_val_loss = avg_loss / num_val_batches
        accuracy = accuracy_score(true, pred)
        print("[{}] mean accuracy: {}, mean loss: {}".format(
            epoch, accuracy, mean_val_loss))

        true = np.array(true).astype("int")
        pred = np.array(pred).astype("int")
        print(classification_report(true, pred))

        # Compare with the previous best before updating it. best_loss stays
        # a float32 tensor so the checkpoint payload/filename format is kept.
        is_best = mean_val_loss <= best_loss.item()
        best_loss = torch.tensor(min(mean_val_loss, best_loss.item()))

        putils.save_checkpoint(
            {
                "epoch": epoch,
                "model": model,
                "best_loss": best_loss
            },
            is_best,
            filename="{}_loss_{}".format(model_checkpoint_dir,
                                         best_loss.cpu().numpy()))
# Example #2
# 0
def train(args):
    """Train an AudioSet tagging model.

    Evaluates mAP on the balanced-train and eval index sets at iteration 0
    and every 2000 iterations after ``resume_iteration``, appending results
    to a ``StatisticsContainer``. Every 20000 iterations a checkpoint holding
    the iteration, model, optimizer and sampler state is saved.

    Args (attributes read from ``args``):
      workspace: str, root directory containing ``hdf5s/indexes``,
        ``black_list`` and the ``checkpoints``/``statistics``/``logs`` trees
      data_type: str, training index file name under
        ``<workspace>/hdf5s/indexes`` without ``.h5``
        (e.g. 'balanced_train' | 'unbalanced_train')
      window_size: int
      hop_size: int
      mel_bins: int
      fmin, fmax: passed to the model constructor
      model_type: str, name of a model class in scope (resolved via ``eval``)
      loss_type: str, e.g. 'bce'; passed to ``get_loss_func``
      balanced: passed to ``get_train_sampler``
      augmentation: str; mixup is applied when it contains 'mixup'
      batch_size: int
      learning_rate: float
      resume_iteration: int; when > 0, resume from
        ``<checkpoints_dir>/<resume_iteration>_iterations.pth``
      early_stop: int; training stops when the iteration count reaches it
      cuda: bool; use the GPU when one is available
      filename: str; sub-directory name for checkpoints/statistics/logs
    """

    # Arguments & parameters
    workspace = args.workspace
    data_type = args.data_type
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available(
    ) else torch.device('cpu')
    filename = args.filename

    num_workers = 8
    sample_rate = config.sample_rate
    clip_samples = config.clip_samples
    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)

    # Paths
    black_list_csv = os.path.join(workspace, 'black_list',
                                  'dcase2017task4.csv')

    indexes_dir = os.path.join(workspace, 'hdf5s', 'indexes')
    train_indexes_hdf5_path = os.path.join(indexes_dir,
                                           '{}.h5'.format(data_type))
    eval_bal_indexes_hdf5_path = os.path.join(indexes_dir, 'balanced_train.h5')
    eval_test_indexes_hdf5_path = os.path.join(indexes_dir, 'eval.h5')

    # Sub-path encoding this run's configuration; shared by the checkpoint,
    # statistics and log directories and by the resume path.
    run_subdir = os.path.join(
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size))

    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename,
                                   run_subdir)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace, 'statistics', filename,
                                   run_subdir, 'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename, run_subdir)

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if 'cuda' in str(device):
        logging.info('Using GPU.')
        device = 'cuda'
    else:
        logging.info('Using CPU.')
        device = 'cpu'

    # Model
    # SECURITY: eval() executes arbitrary command-line text; only run with
    # trusted arguments (a name -> class mapping would be safer).
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate,
                  window_size=window_size,
                  hop_size=hop_size,
                  mel_bins=mel_bins,
                  fmin=fmin,
                  fmax=fmax,
                  classes_num=classes_num)

    params_num = count_parameters(model)
    logging.info('Parameters num: {}'.format(params_num))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and return a waveform and a target.
    dataset = AudioSetDataset(clip_samples=clip_samples,
                              classes_num=classes_num)

    # Train sampler
    (train_sampler,
     train_collector) = get_train_sampler(balanced, augmentation,
                                          train_indexes_hdf5_path,
                                          black_list_csv, batch_size)

    # Evaluate sampler
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    eval_collector = Collator(mixup_alpha=None)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=train_collector,
                                               num_workers=num_workers,
                                               pin_memory=True)

    eval_bal_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_bal_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    eval_test_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_test_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    # Evaluator
    bal_evaluator = Evaluator(model=model, generator=eval_bal_loader)
    test_evaluator = Evaluator(model=model, generator=eval_test_loader)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    train_bgn_time = time.time()

    # Resume training
    resume_checkpoint = None
    if resume_iteration > 0:
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        # The model is still on the CPU here; mapping to CPU also lets a
        # GPU-saved checkpoint load on a CPU-only host.
        resume_checkpoint = torch.load(resume_checkpoint_path,
                                       map_location='cpu')
        model.load_state_dict(resume_checkpoint['model'])
        train_sampler.load_state_dict(resume_checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = resume_checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in str(device):
        model.to(device)

    if resume_checkpoint is not None:
        if 'optimizer' in resume_checkpoint:
            # Restore Adam/AMSGrad state (and saved hyper-parameters such as
            # lr) only now that the parameters are on their final device:
            # Optimizer.load_state_dict casts state to each param's device.
            optimizer.load_state_dict(resume_checkpoint['optimizer'])
        resume_checkpoint = None  # release the CPU copy of the weights

    time1 = time.time()

    for batch_data_dict in train_loader:
        """batch_data_dict: {
            'audio_name': (batch_size [*2 if mixup],), 
            'waveform': (batch_size [*2 if mixup], clip_samples), 
            'target': (batch_size [*2 if mixup], classes_num), 
            (ifexist) 'mixup_lambda': (batch_size * 2,)}
        """

        # Evaluate
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = bal_evaluator.evaluate()
            test_statistics = test_evaluator.evaluate()

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))

            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % 20000 == 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Forward
        model.train()

        if 'mixup' in augmentation:
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])
            """{'clipwise_output': (batch_size, classes_num), ...}"""

            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
            """{'target': (batch_size, classes_num)}"""
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            """{'clipwise_output': (batch_size, classes_num), ...}"""

            batch_target_dict = {'target': batch_data_dict['target']}
            """{'target': (batch_size, classes_num)}"""

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward (the per-iteration debug print(loss) was removed; progress
        # is reported every 10 iterations below).
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'
                  .format(iteration, time.time() - time1))
            time1 = time.time()

        iteration += 1

        # Stop learning
        if iteration == early_stop:
            break
# Example #3
# 0
def train(args):
    """Train AudioSet tagging model. 

    Args:
      dataset_dir: str
      workspace: str
      data_type: 'balanced_train' | 'full_train'
      window_size: int
      hop_size: int
      mel_bins: int
      model_type: str
      loss_type: 'clip_bce'
      balanced: 'none' | 'balanced' | 'alternate'
      augmentation: 'none' | 'mixup'
      batch_size: int
      learning_rate: float
      resume_iteration: int
      early_stop: int
      accumulation_steps: int
      cuda: bool
    """

    # Arugments & parameters
    workspace = args.workspace
    data_type = args.data_type
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available(
    ) else torch.device('cpu')
    filename = args.filename

    num_workers = 128
    prefetch_factor = 4

    #os.environ["MASTER_ADDR"] = "localhost"
    #os.environ["MASTER_PORT"] = "12355"
    #dist.init_process_group("nccl", rank=rank, world_size=args.world_size)

    clip_samples = config.clip_samples
    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)

    # Paths
    black_list_csv = None

    train_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                           '{}.h5'.format(data_type))

    eval_bal_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                              'balanced_train.h5')

    eval_test_indexes_hdf5_path = os.path.join(workspace, 'hdf5s', 'indexes',
                                               'eval.h5')

    checkpoints_dir = os.path.join(
        workspace, 'checkpoints', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size),
        datetime.datetime.now().strftime("%d%m%Y_%H%M%S"))

    #if rank == 0:
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(
        workspace, 'statistics', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size), 'statistics.pkl')

    #if rank == 0:
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(
        workspace, 'logs', filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                fmax), 'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size))

    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if 'cuda' in str(device):
        logging.info('Using GPU.')
        device = 'cuda'
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')
        device = 'cpu'

    # Model
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate,
                  window_size=window_size,
                  hop_size=hop_size,
                  mel_bins=mel_bins,
                  fmin=fmin,
                  fmax=fmax,
                  classes_num=classes_num)

    params_num = count_parameters(model)
    # flops_num = count_flops(model, clip_samples)
    logging.info('Parameters num: {}'.format(params_num))
    # logging.info('Flops num: {:.3f} G'.format(flops_num / 1e9))

    # Dataset will be used by DataLoader later. Dataset takes a meta as input
    # and return a waveform and a target.
    dataset = AudioSetDataset(sample_rate=sample_rate)

    # Train sampler
    if balanced == 'none':
        Sampler = TrainSampler
    elif balanced == 'balanced':
        Sampler = BalancedTrainSampler
    elif balanced == 'alternate':
        Sampler = AlternateTrainSampler

    train_sampler = Sampler(indexes_hdf5_path=train_indexes_hdf5_path,
                            batch_size=batch_size *
                            2 if 'mixup' in augmentation else batch_size,
                            black_list_csv=black_list_csv)

    # Evaluate sampler
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path,
        batch_size=2 * batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path,
        batch_size=2 * batch_size)

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=collate_fn,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               prefetch_factor=prefetch_factor)

    eval_bal_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_bal_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=prefetch_factor)

    eval_test_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_test_sampler,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=prefetch_factor)

    if 'mixup' in augmentation:
        mixup_augmenter = Mixup(mixup_alpha=1.)

    # Evaluator
    evaluator = Evaluator(model=model)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    train_bgn_time = time.time()

    # Resume training
    if resume_iteration > 0:
        resume_checkpoint_path = os.path.join(
            workspace, 'checkpoints', filename,
            'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
            .format(sample_rate, window_size, hop_size, mel_bins, fmin,
                    fmax), 'data_type={}'.format(data_type), model_type,
            'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
            'augmentation={}'.format(augmentation),
            'batch_size={}'.format(batch_size),
            '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        checkpoint = torch.load(resume_checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']

    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in str(device):
        model.to(device)
        #model = model.cuda(rank)

    #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    #print([(s[0], s[1].is_cuda) for s in model.named_parameters()])

    time1 = time.time()

    prev_bal_map = 0.0
    prev_test_map = 0.0
    save_bal_model = 0
    save_test_model = 0

    for batch_data_dict in train_loader:
        """batch_data_dict: {
            'audio_name': (batch_size [*2 if mixup],), 
            'waveform': (batch_size [*2 if mixup], clip_samples), 
            'target': (batch_size [*2 if mixup], classes_num), 
            (ifexist) 'mixup_lambda': (batch_size * 2,)}
        """
        #print(batch_data_dict)
        # Evaluate
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or (iteration == -1):
            train_fin_time = time.time()

            bal_statistics = evaluator.evaluate(eval_bal_loader)
            test_statistics = evaluator.evaluate(eval_test_loader)

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))

            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            save_bal_model = 1 if np.mean(
                bal_statistics['average_precision']) > prev_bal_map else 0
            save_test_model = 1 if np.mean(
                test_statistics['average_precision']) > prev_test_map else 0

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if iteration % 100000 == 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        if save_bal_model:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations_bal.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))
            save_bal_model = 0

        if save_test_model:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations_test.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))
            save_test_model = 0

        # Mixup lambda
        if 'mixup' in augmentation:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                batch_size=len(batch_data_dict['waveform']))

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Forward
        model.train()

        if 'mixup' in augmentation:
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])
            """{'clipwise_output': (batch_size, classes_num), ...}"""

            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
            """{'target': (batch_size, classes_num)}"""
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            """{'clipwise_output': (batch_size, classes_num), ...}"""

            batch_target_dict = {'target': batch_data_dict['target']}
            """{'target': (batch_size, classes_num)}"""

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward
        loss.backward()
        print(loss)

        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'\
                .format(iteration, time.time() - time1))
            time1 = time.time()

        # Stop learning
        if iteration == early_stop:
            break

        iteration += 1
def _run_subdir(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
                data_type, model_type, loss_type, balanced, augmentation,
                batch_size):
    """Return the relative directory that identifies one training configuration.

    The same sub-path is appended under the checkpoints, statistics and logs
    roots so that all artefacts of a run share one directory layout.
    """
    return os.path.join(
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'
        .format(sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type), model_type,
        'loss_type={}'.format(loss_type), 'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size))


def train(args):
    """Fine-tune the audio-tagging model named by ``args.model_type``.

    The model is built once. It is optionally initialised from
    ``args.pretrained_checkpoint_path`` and trained with Adam on the
    balanced-train index file. Every 2000 iterations (and at iteration 0) it
    is evaluated on the balanced-train and eval index files. Statistics are
    dumped and checkpoints are saved every 20000 iterations.

    Args:
        args: parsed command-line namespace. Attributes read here:
            window_size, hop_size, mel_bins, fmin, fmax, model_type,
            pretrained_checkpoint_path, cuda, workspace_input,
            workspace_output, filename.

    Side effects:
        Creates checkpoint/statistics/log directories under
        ``args.workspace_output`` and writes checkpoints and statistics there.
    """
    # Arguments & parameters
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    workspace_input = args.workspace_input
    workspace_output = args.workspace_output
    filename = args.filename
    # NOTE(review): args.freeze_base is ignored; the base is always frozen
    # here. Confirm this override is intended.
    freeze_base = True
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'

    sample_rate = config.sample_rate
    classes_num = config.classes_num
    clip_samples = config.clip_samples

    # Fixed training configuration for this fine-tuning run.
    data_type = 'balanced_train'
    loss_type = 'clip_bce'
    balanced = 'balanced'
    augmentation = 'none'
    batch_size = 1
    learning_rate = 1e-3
    resume_iteration = 0
    early_stop = 100000
    num_workers = 8
    black_list_csv = 'metadata/black_list/groundtruth_weak_label_evaluation_set.csv'
    loss_func = get_loss_func(loss_type)

    # Index files. os.path.join works whether or not workspace_input ends
    # with a path separator (plain string concatenation did not).
    train_indexes_hdf5_path = os.path.join(workspace_input, 'hdf5s', 'indexes',
                                           '{}.h5'.format(data_type))
    eval_bal_indexes_hdf5_path = os.path.join(workspace_input, 'hdf5s',
                                              'indexes', 'balanced_train.h5')
    eval_test_indexes_hdf5_path = os.path.join(workspace_input, 'hdf5s',
                                               'indexes', 'eval.h5')

    run_subdir = _run_subdir(sample_rate, window_size, hop_size, mel_bins,
                             fmin, fmax, data_type, model_type, loss_type,
                             balanced, augmentation, batch_size)

    checkpoints_dir = os.path.join(workspace_output, 'checkpoints', filename,
                                   run_subdir)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace_output, 'statistics', filename,
                                   run_subdir, 'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace_output, 'logs', filename, run_subdir)
    create_logging(logs_dir, filemode='w')
    logging.info(args)
    logging.info('Using GPU.' if device == 'cuda' else 'Using CPU.')

    # Model. NOTE(review): eval() on a CLI-supplied name can execute arbitrary
    # code; acceptable only because args come from a trusted operator.
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate,
                  window_size=window_size,
                  hop_size=hop_size,
                  mel_bins=mel_bins,
                  fmin=fmin,
                  fmax=fmax,
                  classes_num=classes_num,
                  freeze_base=freeze_base)
    params_num = count_parameters(model)
    logging.info('Parameters num: {}'.format(params_num))

    # Load pretrained weights into the SAME model that the optimizer,
    # evaluators and training loop use (before DataParallel wrapping, since
    # load_from_pretrain is a method of the unwrapped model).
    if pretrained_checkpoint_path:
        logging.info(
            'Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)
        print('Load pretrained model successfully!')

    # Dataset takes a meta as input and returns a waveform and a target.
    dataset = AudioSetDataset(clip_samples=clip_samples,
                              classes_num=classes_num)

    # Samplers
    (train_sampler, train_collector) = get_train_sampler(
        balanced, augmentation, train_indexes_hdf5_path, black_list_csv,
        batch_size)
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)
    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)
    eval_collector = Collator(mixup_alpha=None)

    # Data loaders
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_sampler=train_sampler,
                                               collate_fn=train_collector,
                                               num_workers=num_workers,
                                               pin_memory=True)
    eval_bal_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_bal_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)
    eval_test_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_sampler=eval_test_sampler,
        collate_fn=eval_collector,
        num_workers=num_workers,
        pin_memory=True)

    # Evaluators hold the unwrapped model; model.to(device) below moves its
    # parameters in place, so they see the same device.
    bal_evaluator = Evaluator(model=model, generator=eval_bal_loader)
    test_evaluator = Evaluator(model=model, generator=eval_test_loader)

    statistics_container = StatisticsContainer(statistics_path)

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.,
                           amsgrad=True)

    train_bgn_time = time.time()

    # Resume training
    if resume_iteration > 0:
        # NOTE(review): checkpoints are read from workspace_input but saved
        # under workspace_output (checkpoints_dir). Confirm this is intended.
        resume_checkpoint_path = os.path.join(
            workspace_input, 'checkpoints', filename, run_subdir,
            '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        map_location = None if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(resume_checkpoint_path,
                                map_location=map_location)
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)
    if device == 'cuda':
        model.to(device)

    time1 = time.time()

    # Each batch_data_dict: {
    #     'audio_name': (batch_size [*2 if mixup],),
    #     'waveform': (batch_size [*2 if mixup], clip_samples),
    #     'target': (batch_size [*2 if mixup], classes_num),
    #     (if exists) 'mixup_lambda': (batch_size * 2,)}
    for batch_data_dict in train_loader:

        # Evaluate
        if (iteration % 2000 == 0
                and iteration > resume_iteration) or iteration == 0:
            train_fin_time = time.time()

            bal_statistics = bal_evaluator.evaluate()
            test_statistics = test_evaluator.evaluate()

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))
            logging.info('Validate test mAP: {:.3f}'.format(
                np.mean(test_statistics['average_precision'])))

            statistics_container.append(iteration,
                                        bal_statistics,
                                        data_type='bal')
            statistics_container.append(iteration,
                                        test_statistics,
                                        data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                .format(iteration, train_time, validate_time))
            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model (state of the unwrapped module, so it loads without
        # DataParallel).
        if iteration % 20000 == 0:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'sampler': train_sampler.state_dict()
            }

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key],
                                                       device)

        # Forward
        model.train()
        if 'mixup' in augmentation:
            # Output: {'clipwise_output': (batch_size, classes_num), ...}
            batch_output_dict = model(batch_data_dict['waveform'],
                                      batch_data_dict['mixup_lambda'])
            # Target: {'target': (batch_size, classes_num)}
            batch_target_dict = {
                'target':
                do_mixup(batch_data_dict['target'],
                         batch_data_dict['mixup_lambda'])
            }
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss & backward
        loss = loss_func(batch_output_dict, batch_target_dict)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            print('--- Iteration: {}, train time: {:.3f} s / 10 iterations ---'
                  .format(iteration, time.time() - time1))
            time1 = time.time()

        iteration += 1

        # Stop learning
        if iteration == early_stop:
            break