Example 1
def test_basic(self):
    """
    Basic test to check that the calculation is sensible and conforms to the formula.
    """
    test_tensor = torch.Tensor([[0.5, 0.5], [0.0, 1.0]])
    output = target_distribution(test_tensor)
    # compare element-wise; assertAlmostEqual on whole tuples only passes when they are exactly equal
    for value, expected in zip(output[0].tolist(), (0.75, 0.25)):
        self.assertAlmostEqual(value, expected)
    for value, expected in zip(output[1].tolist(), (0.0, 1.0)):
        self.assertAlmostEqual(value, expected)
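For context, a target_distribution consistent with this test follows the auxiliary distribution from the DEC paper: p_ij = (q_ij^2 / f_j) / sum_j'(q_ij'^2 / f_j'), where f_j = sum_i q_ij is the soft cluster frequency. A minimal sketch (the implementation under test may differ in detail):

import torch


def target_distribution(batch: torch.Tensor) -> torch.Tensor:
    """
    DEC auxiliary target distribution: square the soft assignments,
    divide by the soft cluster frequency, then renormalize each row.
    Sketch consistent with the test above.
    """
    weight = (batch ** 2) / torch.sum(batch, 0)     # q_ij^2 / f_j
    return (weight.t() / torch.sum(weight, 1)).t()  # row-wise renormalization

On the test tensor this yields [[0.75, 0.25], [0.0, 1.0]], matching the assertions.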
Example 2
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader
from tqdm import tqdm

# target_distribution, cluster_accuracy and predict are project helpers defined elsewhere.


def train(
        dataset: torch.utils.data.Dataset,
        model: torch.nn.Module,
        epochs: int,
        batch_size: int,
        optimizer: torch.optim.Optimizer,
        stopping_delta: Optional[float] = None,
        cuda: bool = True,
        sampler: Optional[torch.utils.data.sampler.Sampler] = None,
        silent: bool = False,
        update_freq: int = 10,
        evaluate_batch_size: int = 1024,
        update_callback: Optional[Callable[[float, float, float], None]] = None,
        epoch_callback: Optional[Callable[[int, torch.nn.Module], None]] = None) -> None:
    """
    Train the DEC model given a dataset, a model instance and various configuration parameters.

    :param dataset: instance of Dataset to use for training
    :param model: instance of DEC model to train
    :param epochs: number of training epochs
    :param batch_size: size of the batch to train with
    :param optimizer: instance of optimizer to use
    :param stopping_delta: label delta as a proportion to use for stopping, None to disable, default None
    :param cuda: whether to use CUDA, defaults to True
    :param sampler: optional sampler to use in the DataLoader, defaults to None
    :param silent: set to True to prevent printing out summary statistics, defaults to False
    :param update_freq: frequency of batches with which to update counter, None disables, default 10
    :param evaluate_batch_size: batch size for evaluation stage, default 1024
    :param update_callback: optional function of accuracy, loss and label delta to update, default None
    :param epoch_callback: optional function of epoch and model, default None
    :return: None
    """
    static_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=False,
        sampler=sampler,
        shuffle=False
    )
    train_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=True
    )
    data_iterator = tqdm(
        static_dataloader,
        leave=True,
        unit='batch',
        postfix={
            'epo': -1,
            'acc': '%.4f' % 0.0,
            'lss': '%.8f' % 0.0,
            'dlb': '%.4f' % -1,
        },
        disable=silent
    )
    kmeans = KMeans(n_clusters=model.cluster_number, n_init=20)
    model.train()
    features = []
    actual = []
    # form initial cluster centres
    for index, batch in enumerate(data_iterator):
        if isinstance(batch, (tuple, list)) and len(batch) == 2:
            batch, value = batch  # if we have a prediction label, separate it to actual
            actual.append(value)
        if cuda:
            batch = batch.cuda(non_blocking=True)
        features.append(model.encoder(batch).detach().cpu())
        # features.append(model.encoder(batch))
    actual = torch.cat(actual).long()
    predicted = kmeans.fit_predict(torch.cat(features).numpy())
    predicted_previous = torch.tensor(np.copy(predicted), dtype=torch.long)
    _, accuracy = cluster_accuracy(predicted, actual.cpu().numpy())
    cluster_centers = torch.tensor(kmeans.cluster_centers_, dtype=torch.float)
    if cuda:
        cluster_centers = cluster_centers.cuda(non_blocking=True)
    model.assignment.cluster_centers = torch.nn.Parameter(cluster_centers)
    loss_function = nn.KLDivLoss(reduction='sum')  # size_average=False is deprecated
    delta_label = None
    for epoch in range(epochs):
        features = []
        data_iterator = tqdm(
            train_dataloader,
            leave=True,
            unit='batch',
            postfix={
                'epo': epoch,
                'acc': '%.4f' % (accuracy or 0.0),
                'lss': '%.8f' % 0.0,
                'dlb': '%.4f' % (delta_label or 0.0),
            },
            disable=silent,
        )
        model.train()
        for index, batch in enumerate(data_iterator):
            if isinstance(batch, (tuple, list)) and len(batch) == 2:
                batch, _ = batch  # if we have a prediction label, strip it away
            if cuda:
                batch = batch.cuda(non_blocking=True)
            output = model(batch)  # soft cluster assignments q
            target = target_distribution(output).detach()  # auxiliary distribution p, no gradient
            loss = loss_function(output.log(), target) / output.shape[0]  # batch-mean KL(p || q)
            data_iterator.set_postfix(
                epo=epoch,
                acc='%.4f' % (accuracy or 0.0),
                lss='%.8f' % float(loss.item()),
                dlb='%.4f' % (delta_label or 0.0),
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step(closure=None)
            features.append(model.encoder(batch).detach().cpu())
            # features.append(model.encoder(batch))
            if update_freq is not None and index % update_freq == 0:
                loss_value = float(loss.item())
                data_iterator.set_postfix(
                    epo=epoch,
                    acc='%.4f' % (accuracy or 0.0),
                    lss='%.8f' % loss_value,
                    dlb='%.4f' % (delta_label or 0.0),
                )
                if update_callback is not None:
                    update_callback(accuracy, loss_value, delta_label)
        predicted, actual = predict(dataset, model, evaluate_batch_size, silent=True, return_actual=True, cuda=cuda)
        delta_label = float((predicted != predicted_previous).float().sum().item()) / predicted_previous.shape[0]
        # if stopping_delta is not None and delta_label < stopping_delta:
        #     print('Early stopping as label delta "%1.5f" less than "%1.5f".' % (delta_label, stopping_delta))
        #     break
        predicted_previous = predicted
        _, accuracy = cluster_accuracy(predicted.cpu().numpy(), actual.cpu().numpy())
        data_iterator.set_postfix(
            epo=epoch,
            acc='%.4f' % (accuracy or 0.0),
            lss='%.8f' % 0.0,
            dlb='%.4f' % (delta_label or 0.0),
        )
        if epoch_callback is not None:
            epoch_callback(epoch, model)
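A hypothetical call to this train function could look as follows. Everything here is a placeholder: dataset is assumed to yield (sample, label) pairs, autoencoder is a pretrained autoencoder, and DEC is assumed to follow the pt-dec API (DEC(cluster_number, hidden_dimension, encoder)); adapt the names to your project.

import torch
from torch.optim import SGD

# Hypothetical usage sketch under the assumptions stated above.
model = DEC(cluster_number=10, hidden_dimension=10, encoder=autoencoder.encoder)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
train(
    dataset=dataset,
    model=model,
    epochs=100,
    batch_size=256,
    optimizer=optimizer,
    stopping_delta=0.000001,
    cuda=torch.cuda.is_available(),
)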
Example 3
import time
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm

# target_distribution, cluster_accuracy, predict, test_dataset, DataSetExtract,
# SSKMeans and PositiveRatioClusters are project helpers defined elsewhere.


def train(dataset: torch.utils.data.Dataset,
          wdec: torch.nn.Module,
          epochs: int,
          batch_size: int,
          optimizer: torch.optim.Optimizer,
          reinitKMeans: bool = True,
          scheduler=None,  # optional learning-rate scheduler, stepped once per batch
          positive_ratio: float = 0.6,  # note: not referenced in the function body shown
          stopping_delta: Optional[float] = None,
          collate_fn=default_collate,
          cuda: bool = True,
          sampler: Optional[torch.utils.data.sampler.Sampler] = None,
          silent: bool = False,
          update_freq: int = 10,
          evaluate_batch_size: int = 1024,
          update_callback: Optional[Callable[[float, float, float], None]] = None,
          epoch_callback: Optional[Callable[[int, torch.nn.Module], None]] = None,
          start_time: Optional[float] = None,
          ) -> None:
    """
    Train the DEC model given a dataset, a model instance and various configuration parameters.

    :param dataset: instance of Dataset to use for training
    :param wdec: instance of DEC model to train
    :param epochs: number of training epochs
    :param batch_size: size of the batch to train with
    :param optimizer: instance of optimizer to use
    :param reinitKMeans: if True, cluster centres are re-initialized with KMeans before training
    :param scheduler: instance of lr_scheduler to use, stepped once per batch
    :param positive_ratio: positive-ratio parameter (not referenced in the body shown)
    :param stopping_delta: label delta as a proportion to use for stopping, None to disable, default None
    :param collate_fn: function to merge a list of samples into mini-batch
    :param cuda: whether to use CUDA, defaults to True
    :param sampler: optional sampler to use in the DataLoader, defaults to None
    :param silent: set to True to prevent printing out summary statistics, defaults to False
    :param update_freq: frequency of batches with which to update counter, None disables, default 10
    :param evaluate_batch_size: batch size for evaluation stage, default 1024
    :param update_callback: optional function of accuracy, loss and label delta to update, default None
    :param epoch_callback: optional function of epoch and model, default None
    :param start_time: optional starting time of training process, default None
    :return: None
    """
    static_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        pin_memory=False,
        sampler=sampler,
        shuffle=False
    )
    train_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        sampler=sampler,
        shuffle=True
    )
    data_iterator = tqdm(
        static_dataloader,
        leave=True,
        unit='batch',
        postfix={
            'epo': -1,
            'acc': '%.4f' % 0.0,
            'lss': '%.8f' % 0.0,
            'dlb': '%.4f' % -1,
        },
        disable=silent
    )
    wdec.train()
    
    test_dataset(dataset)

    if reinitKMeans:
        # get all data needed for KMeans.
        if start_time is not None:
            print('\nLinearizing data')
            print(f'@ {time.time() - start_time}\n')
        features, actual, idxs, boxs, videos, frames = DataSetExtract(dataset, wdec)
               
        # KMeans.
        if start_time is not None:
            print('\nPerforming KMeans')
            print(f'@ {time.time() - start_time}\n')
        predicted, kmeans = SSKMeans(
            wdec, features, actual, idxs, boxs, videos, frames
        )
        # compute the positive-ratio scores and the positive-ratio clusters
        cpr = PositiveRatioClusters(
            predicted, actual, wdec.assignment.cluster_number,
        )
        predicted_previous = torch.tensor(np.copy(predicted), dtype=torch.long)
        _, accuracy        = cluster_accuracy(predicted, actual.cpu().numpy())
        cluster_centers    = torch.tensor(
            kmeans.cluster_centers_,
            dtype=torch.float, requires_grad=True
        )
        predicted_idxed    = torch.cat(
            [idxs.reshape(-1,1), torch.tensor(predicted).reshape(-1,1).long()],
            dim = -1
        )
        del features, actual, idxs, boxs, videos, frames
        if cuda:
            wdec.cuda()
            cluster_centers = cluster_centers.cuda(non_blocking=True)
        with torch.no_grad():
            # initialise the cluster centers
            wdec.state_dict()['assignment.cluster_centers'].copy_(cluster_centers)
            # wdec.state_dict()['assignment.cluster_predicted'].copy_(predicted_idxed)
            # wdec.state_dict()['assignment.cluster_positive_ratio'].copy_(cpr)
            wdec.assignment.cluster_predicted = predicted_idxed.clone()
            wdec.assignment.cluster_positive_ratio = cpr.clone()
    else:
        predicted, actual = predict(
              dataset,
              wdec,
              batch_size=evaluate_batch_size,
              collate_fn=collate_fn,
              silent=True,
              return_actual=True,
              cuda=cuda
        )
        predicted_previous = torch.tensor(np.copy(predicted), dtype=torch.long)
        _, accuracy = cluster_accuracy(predicted.cpu().numpy(), actual.cpu().numpy())
        
    if start_time is not None:
        print('\ntraining DEC')
        print(f'@ {time.time() - start_time}\n')

    loss_function = nn.KLDivLoss(reduction='sum')  # size_average=False is deprecated
    delta_label = None
    for epoch in range(epochs):
        # features = [] ### I see no use for this
        data_iterator = tqdm(
            train_dataloader,
            leave=True,
            unit='batch',
            postfix={
                'epo': epoch,
                'acc': '%.4f' % (accuracy or 0.0),
                'lss': '%.8f' % 0.0,
                'dlb': '%.4f' % (delta_label or 0.0),
            },
            disable=silent,
        )
        wdec.train()
        for index, batch in enumerate(data_iterator):
            if isinstance(batch, (tuple, list)) and len(batch) == 6:
                batch, actual, idxs, _, _, _ = batch  # keep labels and indices, drop box/video/frame metadata
            if cuda:
                batch  = batch.cuda(non_blocking=True)
                actual = actual.cuda()
                idxs   = idxs.cuda()
            output = wdec(batch, actual, idxs)
            target = target_distribution(output).detach()
            loss   = loss_function(output.log(), target) / output.shape[0]
            data_iterator.set_postfix(
                epo = epoch,
                acc = '%.4f' % (accuracy or 0.0),
                lss = '%.8f' % float(loss.item()),
                dlb = '%.4f' % (delta_label or 0.0),
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step(closure=None)
            if scheduler is not None:
                scheduler.step()
            # features.append(model.encoder(batch).detach().cpu()) ### I see no use for this
            if update_freq is not None and index % update_freq == 0:
                loss_value = float(loss.item())
                data_iterator.set_postfix(
                    epo=epoch,
                    acc='%.4f' % (accuracy or 0.0),
                    lss='%.8f' % loss_value,
                    dlb='%.4f' % (delta_label or 0.0),
                )
                if update_callback is not None:
                    update_callback(accuracy, loss_value, delta_label)
        predicted, actual = predict(
            dataset,
            wdec,
            batch_size=evaluate_batch_size,
            collate_fn=collate_fn,
            silent=True,
            return_actual=True,
            cuda=cuda
        )
        delta_label = float((predicted != predicted_previous).float().sum().item()) / predicted_previous.shape[0]
        if stopping_delta is not None and delta_label < stopping_delta:
            print('Early stopping as label delta "%1.5f" less than "%1.5f".' % (delta_label, stopping_delta))
            break
        predicted_previous = predicted
        _, accuracy = cluster_accuracy(predicted.cpu().numpy(), actual.cpu().numpy())
        data_iterator.set_postfix(
            epo=epoch,
            acc='%.4f' % (accuracy or 0.0),
            lss='%.8f' % 0.0,
            dlb='%.4f' % (delta_label or 0.0),
        )
        if epoch_callback is not None:
            epoch_callback(epoch, wdec)
    wdec.cpu()
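Both train variants depend on a cluster_accuracy helper that first matches predicted cluster ids to ground-truth labels and then scores accuracy under that matching. A common way to implement it, sketched here with the Hungarian algorithm (this is an assumption; the project's actual helper may differ):

import numpy as np
from scipy.optimize import linear_sum_assignment


def cluster_accuracy(y_predicted: np.ndarray, y_true: np.ndarray):
    """
    Match cluster ids to true labels one-to-one via the Hungarian algorithm,
    then compute accuracy under that matching. Returns (reassignment, accuracy).
    Sketch of a standard approach, not necessarily the helper used above.
    """
    cluster_number = int(max(y_predicted.max(), y_true.max())) + 1
    count_matrix = np.zeros((cluster_number, cluster_number), dtype=np.int64)
    for predicted, true in zip(y_predicted, y_true):
        count_matrix[predicted, true] += 1
    # linear_sum_assignment minimizes cost, so negate the match counts
    row_ind, col_ind = linear_sum_assignment(-count_matrix)
    reassignment = dict(zip(row_ind.tolist(), col_ind.tolist()))
    accuracy = count_matrix[row_ind, col_ind].sum() / y_predicted.size
    return reassignment, accuracy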