def testOnlyClusterInitialization(dataset_name, arch, epochs):
    '''
    Pretrain the autoencoder described by arch on the named dataset.

    The dataset is loaded through DatasetHelper, a DCJC network is built from arch, and both are handed to
    DCJC.pretrainWithData. Progress is reported through rootLogger at each step.
    :param dataset_name: Name of the dataset with which the network will be trained [MNIST, COIL20]
    :param arch: Architecture of the network as a dictionary. Specification for architecture can be found in readme.md
    :param epochs: Number of train epochs, forwarded to pretrainWithData
    :return: None - (side effect) the pretraining call is expected to save the latent space and params of the trained network under the saved_params folder; that happens inside DCJC and is not visible here
    '''
    log = rootLogger.info

    # Step 1: dataset
    log("Loading dataset")
    data_helper = DatasetHelper(dataset_name)
    data_helper.loadDataset()
    log("Done loading dataset")

    # Step 2: network
    log("Creating network")
    network = DCJC(arch)
    log("Done creating network")

    # Step 3: pretraining
    log("Starting training")
    # The third positional argument is fixed to False; its meaning is defined by pretrainWithData - TODO confirm
    network.pretrainWithData(data_helper, epochs, False)
def testOnlyClusterImprovement(dataset_name, arch, epochs, method):
    '''
    Use an initialized autoencoder and train it along with clustering loss. Assumed that pretrained autoencoder params
    are available, i.e. testOnlyClusterInitialization has been run already with the given params
    :param dataset_name: Name of the dataset with which the network will be trained [MNIST, COIL20]
    :param arch: Architecture of the network as a dictionary. Specification for architecture can be found in readme.md
    :param epochs: Number of train epochs
    :param method: Can be KM or KLD - depending on whether the clustering loss is KLDivergence loss between the current KMeans distribution(Q) and a more desired one(Q^2), or if the clustering loss is just the Kmeans loss
    :return: None - (side effect) saves latent space and params of the trained network
    '''
    rootLogger.info("Loading dataset")
    dataset = DatasetHelper(dataset_name)
    dataset.loadDataset()
    rootLogger.info("Done loading dataset")
    rootLogger.info("Creating network")
    dcjc = DCJC(arch)
    rootLogger.info("Starting cluster improvement")
    if method == 'KM':
        dcjc.doClusteringWithKMeansLoss(dataset, epochs)
    elif method == 'KLD':
        dcjc.doClusteringWithKLdivLoss(dataset, True, epochs)