Example #1
def dataset_to_clustering_error(results_paths, dataset_name, root, output_path, debug, n_samples,
                                param_file=None, coil20_unprocessed=False):

    dataset = Datasets(dataset=dataset_name, root_folder=root, flatten=True,
                       debug=debug, n_samples=n_samples, coil20_unprocessed=coil20_unprocessed)

    test_loader = DataLoader(dataset.test_data, shuffle=False)

    # Collect the ground-truth labels from the (unshuffled) test loader.
    true_labels = np.array([])
    for batch_idx, (sample, target) in enumerate(test_loader):
        true_labels = np.concatenate((true_labels, target.numpy()))

    # Count how many repeat runs produced a .results file for this dataset.
    repeats = len([f for f in listdir(results_paths)
                   if isfile(join(results_paths, f)) and not f.startswith('.') and f.endswith(".results")])

    ces = []
    num_nodes = []

    for i in range(repeats):
        results_file = join(results_paths, "{0}_{1}.results".format(dataset_name.split(".")[0], i))
        print(results_file)

        data_n_winners, found_clusters, _ = utils.read_results(results_file)
        # The last column of the results data holds the predicted cluster labels.
        predict_labels = pd.DataFrame(data_n_winners, dtype=np.float64).iloc[:, -1].values

        if debug:
            predict_labels = predict_labels[:n_samples]

        num_nodes.append(found_clusters)

        ce = cluster.predict_to_clustering_error(true_labels, predict_labels)
        ces.append(ce)

    # Write a summary CSV: the best (max) CE plus the node count and run index
    # where it occurred, then the per-run statistics via write_csv_body.
    with open(output_path + '.csv', 'w+') as output_file:
        line = "max_value," + str(np.nanmax(ces)) + "\n"

        max_value_index = np.nanargmax(ces)
        line += "num_nodes," + str(num_nodes[max_value_index]) + "\n"
        line += "index_set," + str(max_value_index) + "\n"

        write_csv_body([ces], [dataset_name], line, [np.nanmean(ces)], output_file, param_file, [np.nanstd(ces)])
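
A minimal sketch of how dataset_to_clustering_error might be invoked; the paths and the dataset name below are hypothetical placeholders, and utils.read_results / write_csv_body are assumed to come from this project.

# Hypothetical arguments, shown only to illustrate the call signature.
dataset_to_clustering_error(results_paths='results/mnist',
                            dataset_name='mnist.arff',
                            root='data/',
                            output_path='results/mnist_summary',
                            debug=False,
                            n_samples=None,
                            param_file=None)
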
def main(args):
    if not os.path.exists(args.config):
        raise FileNotFoundError('provided config file does not exist: %s' % args.config)
    config_yaml = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)

    if 'restart_log_dir_path' not in config_yaml['simclr']['train'].keys():
        config_yaml['simclr']['train']['restart_log_dir_path'] = None

    if args.data_dir_path is not None:
        config_yaml['simclr']['train']['data_dir_path'] = args.data_dir_path
        print('overriding data_dir_path: ', args.data_dir_path)

    config_yaml['logger_name'] = 'logreg'
    config = SimCLRConfig(config_yaml)

    if not os.path.exists(config.base.output_dir_path):
        os.mkdir(config.base.output_dir_path)

    if not os.path.exists(config.base.log_dir_path):
        os.makedirs(config.base.log_dir_path)

    logger = setup_logger(config.base.logger_name, config.base.log_file_path)
    logger.info('using config: %s' % config)

    config_copy_file_path = os.path.join(config.base.log_dir_path, 'config.yaml')
    shutil.copy(args.config, config_copy_file_path)

    writer = SummaryWriter(log_dir=config.base.log_dir_path)

    if not os.path.exists(args.model):
        raise FileNotFoundError('provided model directory does not exist: %s' % args.model)
    else:
        logger.info('using model directory: %s' % args.model)

    config.logistic_regression.model_path = args.model
    logger.info('using model_path: {}'.format(config.logistic_regression.model_path))

    config.logistic_regression.epoch_num = args.epoch_num
    logger.info('using epoch_num: {}'.format(config.logistic_regression.epoch_num))

    model_file_path = Path(config.logistic_regression.model_path).joinpath(
        'checkpoint_' + config.logistic_regression.epoch_num + '.pth')
    if not os.path.exists(model_file_path):
        raise FileNotFoundError('model file does not exist: %s' % model_file_path)
    else:
        logger.info('using model file: %s' % model_file_path)

    train_dataset, val_dataset, test_dataset, classes = Datasets.get_datasets(config,
                                                                              img_size=config.logistic_regression.img_size)
    num_classes = len(classes)

    train_loader, val_loader, test_loader = Datasets.get_loaders(config, train_dataset, val_dataset, test_dataset)

    simclr_model = load_simclr_model(config)
    simclr_model = simclr_model.to(config.base.device)
    simclr_model.eval()

    model = LogisticRegression(simclr_model.num_features, num_classes)
    model = model.to(config.base.device)

    learning_rate = config.logistic_regression.learning_rate
    momentum = config.logistic_regression.momentum
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, nesterov=True)
    criterion = torch.nn.CrossEntropyLoss()

    logger.info("creating features from pre-trained context model")
    (train_x, train_y, test_x, test_y) = get_features(
        config, simclr_model, train_loader, test_loader
    )

    feature_train_loader, feature_test_loader = get_data_loaders(
        config, train_x, train_y, test_x, test_y
    )

    # Track the best-performing weights seen over the training epochs.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = 0
    best_loss = 0

    for epoch in range(config.logistic_regression.epochs):
        loss_epoch, accuracy_epoch = train(
            config, feature_train_loader, model, criterion, optimizer
        )

        loss = loss_epoch / len(feature_train_loader)
        accuracy = accuracy_epoch / len(feature_train_loader)

        writer.add_scalar("Loss/train_epoch", loss, epoch)
        writer.add_scalar("Accuracy/train_epoch", accuracy, epoch)
        logger.info(
            "epoch [%3.i|%i] -> train loss: %f, accuracy: %f" % (
                epoch + 1, config.logistic_regression.epochs, loss, accuracy)
        )

        if accuracy > best_acc:
            best_loss = loss
            best_epoch = epoch + 1
            best_acc = accuracy
            best_model_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts)
    logger.info(
        "train dataset performance -> best epoch: {}, loss: {}, accuracy: {}".format(best_epoch, best_loss, best_acc, )
    )

    loss_epoch, accuracy_epoch = test(
        config, feature_test_loader, model, criterion
    )

    loss = loss_epoch / len(feature_test_loader)
    accuracy = accuracy_epoch / len(feature_test_loader)
    logger.info(
        "test dataset performance -> best epoch: {}, loss: {}, accuracy: {}".format(best_epoch, loss, accuracy)
    )
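
A minimal sketch of the CLI wiring this main() expects; the argument names (config, model, epoch_num, data_dir_path) are inferred from the attribute accesses above, while the help texts and defaults are assumptions.

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='logistic regression evaluation of a SimCLR model')
    parser.add_argument('--config', required=True, help='path to the YAML config file')
    parser.add_argument('--model', required=True, help='directory containing SimCLR checkpoints')
    parser.add_argument('--epoch_num', required=True, help='checkpoint epoch number to load')
    parser.add_argument('--data_dir_path', default=None, help='optional override for the dataset directory')
    main(parser.parse_args())
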
Example #3
def main(args):
    if not os.path.exists(args.config):
        raise FileNotFoundError('provided config file does not exist: %s' %
                                args.config)
    config_yaml = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)

    config_yaml['logger_name'] = 'onnx'
    config = SimCLRConfig(config_yaml)

    if not os.path.exists(config.base.output_dir_path):
        os.mkdir(config.base.output_dir_path)

    if not os.path.exists(config.base.log_dir_path):
        os.makedirs(config.base.log_dir_path)

    logger = setup_logger(config.base.logger_name, config.base.log_file_path)
    logger.info('using config: %s' % config)

    if not os.path.exists(args.model):
        raise FileNotFoundError('provided model directory does not exist: %s' %
                                args.model)
    else:
        logger.info('using model directory: %s' % args.model)

    config.onnx.model_path = args.model
    logger.info('using model_path: {}'.format(config.onnx.model_path))

    config.onnx.epoch_num = args.epoch_num
    logger.info('using epoch_num: {}'.format(config.onnx.epoch_num))

    model_file_path = Path(
        config.onnx.model_path).joinpath('checkpoint_' +
                                         config.onnx.epoch_num + '.pth')
    if not os.path.exists(model_file_path):
        raise FileNotFoundError('model file does not exist: %s' %
                                model_file_path)
    else:
        logger.info('using model file: %s' % model_file_path)

    train_dataset, val_dataset, test_dataset, classes = Datasets.get_datasets(
        config)
    num_classes = len(classes)

    train_loader, val_loader, test_loader = Datasets.get_loaders(
        config, train_dataset, val_dataset, test_dataset)

    torch_model = load_torch_model(config, num_classes)

    val_acc, test_acc = test_pt_model(config, torch_model, val_dataset,
                                      test_dataset, val_loader, test_loader)
    logger.info('torch model performance -> val_acc: {}, test_acc: {}'.format(
        val_acc, test_acc))

    torch_model = torch_model.to(torch.device('cpu'))
    onnx_model_file_path = save_onnx_model(torch_model,
                                           num_classes=num_classes,
                                           config=config,
                                           current_epoch=config.onnx.epoch_num)

    onnx_model = load_onnx_model(config, onnx_model_file_path)
    if onnx_model:
        logger.info('loaded onnx_model: {}'.format(onnx_model_file_path))

    val_acc, test_acc = test_onnx_model(config, onnx_model_file_path,
                                        val_dataset, test_dataset, val_loader,
                                        test_loader)
    logger.info('onnx model performance -> val_acc: {}, test_acc: {}'.format(
        val_acc, test_acc))
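
A hypothetical smoke test of the exported file with onnxruntime (not part of this script); the checkpoint path and the NCHW input shape are assumptions.

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('output/checkpoint_100.onnx')      # placeholder path
input_name = session.get_inputs()[0].name                         # model-specific input name
dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)         # assumed input shape
logits = session.run(None, {input_name: dummy})[0]
print('onnx output shape:', logits.shape)
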
Example #4
def train_som(root,
              dataset_path,
              parameters,
              device,
              use_cuda,
              workers,
              out_folder,
              batch_size,
              n_max=None,
              evaluate=False,
              summ_writer=None,
              coil20_unprocessed=False):
    dataset = Datasets(dataset=dataset_path,
                       root_folder=root,
                       flatten=True,
                       coil20_unprocessed=coil20_unprocessed)

    plots = HParams()
    clustering_errors = []
    for param_set in parameters.itertuples():
        n_max_som = param_set.n_max if n_max is None else n_max

        som = SOM(input_dim=dataset.dim_flatten,
                  n_max=n_max_som,
                  at=param_set.at,
                  ds_beta=param_set.ds_beta,
                  eb=param_set.eb,
                  eps_ds=param_set.eps_ds,
                  ld=param_set.ld,
                  device=device)
        som_epochs = param_set.epochs

        manual_seed = param_set.seed
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)

        if use_cuda:
            torch.cuda.manual_seed_all(manual_seed)
            som.cuda()
            cudnn.benchmark = True

        train_loader = DataLoader(dataset.train_data,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=workers)
        test_loader = DataLoader(dataset.test_data, shuffle=False)

        # Self-organization: present every training batch to the SOM once per epoch.
        for epoch in range(som_epochs):
            print('{} [epoch: {}]'.format(dataset_path, epoch))

            for batch_idx, (sample, target) in enumerate(train_loader):
                sample, target = sample.to(device), target.to(device)

                som(sample)

        # Cluster the held-out test data with the trained SOM and persist the assignments.
        cluster_result, predict_labels, true_labels = som.cluster(test_loader)
        filename = dataset_path.split(".arff")[0] + "_" + str(
            param_set.Index) + ".results"
        som.write_output(join(out_folder, filename), cluster_result)

        if evaluate:
            ce = metrics.cluster.predict_to_clustering_error(
                true_labels, predict_labels)
            clustering_errors.append(ce)
            print('{} \t exp_id {} \tCE: {:.3f}'.format(
                dataset_path, param_set.Index, ce))

    if evaluate and summ_writer is not None:
        clustering_errors = np.array(clustering_errors)
        plots.plot_tensorboard_x_y(parameters, 'CE', clustering_errors,
                                   summ_writer,
                                   dataset_path.split(".arff")[0])
Example #5
def main(args):
    if not os.path.exists(args.config):
        raise FileNotFoundError(
            'provided config file does not exist: {}'.format(args.config))
    config_yaml = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)

    if 'restart_log_dir_path' not in config_yaml['simclr']['train'].keys():
        config_yaml['simclr']['train']['restart_log_dir_path'] = None

    config_yaml['logger_name'] = 'simclr'
    config = SimCLRConfig(config_yaml)

    if config.simclr.train.start_epoch != 0 and config.simclr.train.restart_log_dir_path is None:
        raise ValueError(
            'provided config file is invalid. no restart_log_dir_path provided and start_epoch is not 0'
        )

    if not os.path.exists(config.base.output_dir_path):
        os.mkdir(config.base.output_dir_path)

    if not os.path.exists(config.base.log_dir_path):
        os.makedirs(config.base.log_dir_path)

    logger = setup_logger(config.base.logger_name, config.base.log_file_path)
    logger.info('using config: {}'.format(config))

    config_copy_file_path = os.path.join(config.base.log_dir_path,
                                         'config.yaml')
    shutil.copy(args.config, config_copy_file_path)

    writer = SummaryWriter(log_dir=config.base.log_dir_path)

    model = load_model(config)
    logger.info('loaded model')

    train_dataset, _ = Datasets.get_simclr_dataset(config)
    logger.info('using train_dataset. length: {}'.format(len(train_dataset)))

    train_loader = Datasets.get_simclr_loader(config, train_dataset)
    logger.info('created train_loader. length {}'.format(len(train_loader)))

    scheduler = None
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    logger.info('created optimizer')

    criterion = NTXent(config.simclr.train.batch_size,
                       config.simclr.train.temperature, config.base.device)
    logger.info('created criterion')

    config.simclr.train.current_epoch = config.simclr.train.start_epoch
    for epoch in range(config.simclr.train.start_epoch,
                       config.simclr.train.epochs):
        lr = optimizer.param_groups[0]['lr']
        loss_epoch = train(config, train_loader, model, criterion, optimizer,
                           writer)

        if scheduler:
            scheduler.step()

        if epoch % config.simclr.train.save_num_epochs == 0:
            save_model(config, model)

        writer.add_scalar("Loss/train_epoch", loss_epoch / len(train_loader),
                          epoch)
        writer.add_scalar("Learning rate", lr, epoch)
        logger.info("epoch [%5.i|%5.i] -> loss: %.15f, lr: %f" %
                    (epoch + 1, config.simclr.train.epochs,
                     loss_epoch / len(train_loader), round(lr, 5)))
        config.simclr.train.current_epoch += 1

    save_model(config, model)
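
Since scheduler is left as None above, the `if scheduler: scheduler.step()` branch never fires; one hypothetical way to activate it is to attach a cosine schedule where the optimizer is created inside main() (an assumption, not part of the original training setup).

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config.simclr.train.epochs)
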
def main(args):
    if not os.path.exists(args.config):
        raise FileNotFoundError(
            'provided config file does not exist: {}'.format(args.config))
    config_yaml = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)

    if 'restart_log_dir_path' not in config_yaml['simclr']['train'].keys():
        config_yaml['simclr']['train']['restart_log_dir_path'] = None

    config_yaml['logger_name'] = 'classification'
    config = SimCLRConfig(config_yaml)

    if not os.path.exists(config.base.output_dir_path):
        os.mkdir(config.base.output_dir_path)

    if not os.path.exists(config.base.log_dir_path):
        os.makedirs(config.base.log_dir_path)

    logger = setup_logger(config.base.logger_name, config.base.log_file_path)
    logger.info('using config: {}'.format(config))

    config_copy_file_path = os.path.join(config.base.log_dir_path,
                                         'config.yaml')
    shutil.copy(args.config, config_copy_file_path)

    writer = SummaryWriter(log_dir=config.base.log_dir_path)

    if not os.path.exists(args.model):
        raise FileNotFoundError('provided model directory does not exist: %s' %
                                args.model)
    else:
        logger.info('using model directory: {}'.format(args.model))

    config.fine_tuning.model_path = args.model
    logger.info('using model_path: {}'.format(config.fine_tuning.model_path))

    config.fine_tuning.epoch_num = args.epoch_num
    logger.info('using epoch_num: {}'.format(config.fine_tuning.epoch_num))

    model_file_path = Path(
        config.fine_tuning.model_path).joinpath('checkpoint_' +
                                                config.fine_tuning.epoch_num +
                                                '.pth')
    if not os.path.exists(model_file_path):
        raise FileNotFoundError(
            'model file does not exist: {}'.format(model_file_path))
    else:
        logger.info('using model file: {}'.format(model_file_path))

    train_dataset, val_dataset, test_dataset, classes = Datasets.get_datasets(
        config)
    num_classes = len(classes)

    train_loader, val_loader, test_loader = Datasets.get_loaders(
        config, train_dataset, val_dataset, test_dataset)

    dataloaders = {
        'train': train_loader,
        'val': val_loader,
    }

    dataset_sizes = {
        'train': len(train_loader.sampler),
        'val': len(val_loader.sampler)
    }

    simclr_model = load_model(config)
    logger.info('loaded simclr_model: {}'.format(
        config.fine_tuning.model_path))

    classification_model = to_classification_model(simclr_model, num_classes,
                                                   config)
    classification_model = classification_model.to(config.base.device)
    logger.info('created classification model from simclr model')

    criterion = torch.nn.CrossEntropyLoss()
    logger.info('created criterion')

    lr = config.fine_tuning.learning_rate
    momentum = config.fine_tuning.momentum
    optimizer_ft = torch.optim.SGD(classification_model.parameters(),
                                   lr=lr,
                                   momentum=momentum,
                                   nesterov=True)
    logger.info('created optimizer')

    step_size = config.fine_tuning.step_size
    gamma = config.fine_tuning.gamma
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=step_size,
                                                       gamma=gamma)
    logger.info('created learning rate scheduler')

    epochs = config.fine_tuning.epochs
    classification_model = train_model(classification_model, criterion,
                                       optimizer_ft, exp_lr_scheduler,
                                       dataloaders, dataset_sizes, config,
                                       epochs, writer)
    logger.info('completed model training')

    test_model(config, classification_model, test_loader)
    logger.info('completed model testing')

    trained_model_file_path = save_model(config, classification_model, epochs)
    logger.info('saved trained model: {}'.format(trained_model_file_path))
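
A rough sketch of the two-phase (train/val) loop that a train_model helper like the one called above typically performs, using the dataloaders and dataset_sizes dicts built in this script; this is an assumption about its behavior, not this project's implementation.

import torch

def train_model_sketch(model, criterion, optimizer, scheduler,
                       dataloaders, dataset_sizes, config, epochs, writer):
    for epoch in range(epochs):
        for phase in ('train', 'val'):
            # Toggle train/eval mode depending on the phase.
            model.train(phase == 'train')
            running_loss, running_corrects = 0.0, 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(config.base.device)
                labels = labels.to(config.base.device)
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (outputs.argmax(1) == labels).sum().item()
            if phase == 'train':
                scheduler.step()
            writer.add_scalar('{}/loss'.format(phase), running_loss / dataset_sizes[phase], epoch)
            writer.add_scalar('{}/acc'.format(phase), running_corrects / dataset_sizes[phase], epoch)
    return model
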
Example #7
def train_full_model(root, dataset_path, parameters, device, use_cuda, out_folder, debug, n_samples,
                     lr_cnn, batch_size, summ_writer, print_debug, coil20_unprocessed=False):
    dataset = Datasets(dataset=dataset_path, root_folder=root, debug=debug,
                       n_samples=n_samples, coil20_unprocessed=coil20_unprocessed)

    som_plotter = Plotter()
    tsne_plotter = Plotter()

    # Initialize all meters
    data_timer = utils.Timer()
    batch_timer = utils.Timer()
    batch_time = utils.AverageMeter()
    data_time = utils.AverageMeter()

    for param_set in parameters.itertuples():

        model = Net(d_in=dataset.d_in,
                    n_conv_layers=param_set.n_conv,
                    max_pool=bool(param_set.max_pool),
                    hw_in=dataset.hw_in,
                    som_input=param_set.som_in,
                    filters_list=param_set.filters_pow,
                    kernel_size_list=param_set.n_conv * [param_set.kernel_size],
                    stride_size_list=param_set.n_conv * [1],
                    padding_size_list=param_set.n_conv * [0],
                    max_pool2d_size=param_set.max_pool2d_size,
                    n_max=param_set.n_max,
                    at=param_set.at,
                    eb=param_set.eb,
                    ds_beta=param_set.ds_beta,
                    eps_ds=param_set.eps_ds,
                    ld=param_set.ld,
                    device=device)

        manual_seed = param_set.seed
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)

        if use_cuda:
            torch.cuda.manual_seed_all(manual_seed)
            model.cuda()
            cudnn.benchmark = True

        train_loader = DataLoader(dataset.train_data, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(dataset.test_data, shuffle=False)

        #  optimizer = optim.SGD(model.parameters(), lr=lr_cnn, momentum=0.5)
        optimizer = optim.Adam(model.parameters(), lr=param_set.lr_cnn)
        loss = nn.MSELoss(reduction='sum')

        model.train()
        for epoch in range(param_set.epochs):

            # Self-Organize and Backpropagate
            avg_loss = 0
            s = 0
            data_timer.tic()
            batch_timer.tic()
            for batch_idx, (sample, target) in enumerate(train_loader):

                data_time.update(data_timer.toc())
                sample, target = sample.to(device), target.to(device)
                optimizer.zero_grad()
                samples_high_at, weights_unique_nodes_high_at, relevances = model(sample)

                #  if only new nodes were created, the loss is zero, so there is nothing to backpropagate
                if len(samples_high_at) > 0:
                    weights_unique_nodes_high_at = weights_unique_nodes_high_at.view(-1, model.som_input_size)
                    loss_value = weightedMSELoss(samples_high_at, weights_unique_nodes_high_at, relevances)
                    loss_value.backward()
                    optimizer.step()
                    # Keep only the scalar value so the computation graph is not retained across batches.
                    out = loss_value.item()
                else:
                    out = 0.0

                avg_loss += out
                s += len(sample)

                batch_time.update(batch_timer.toc())
                data_timer.toc()

                if debug:
                    cluster_result, predict_labels, true_labels = model.cluster(test_loader)
                    print("Homogeneity: %0.3f" % metrics.cluster.homogeneity_score(true_labels, predict_labels))
                    print("Completeness: %0.3f" % metrics.cluster.completeness_score(true_labels, predict_labels))
                    print("V-measure: %0.3f" % metrics.cluster.v_measure_score(true_labels, predict_labels))
                    nmi = metrics.cluster.nmi(true_labels, predict_labels)
                    print("Normalized Mutual Information (NMI): %0.3f" % nmi)
                    ari = metrics.cluster.ari(true_labels, predict_labels)
                    print("Adjusted Rand Index (ARI): %0.3f" % ari)
                    clus_acc = metrics.cluster.acc(true_labels, predict_labels)
                    print("Clustering Accuracy (ACC): %0.3f" % clus_acc)
                    print('{0} \tCE: {1:.3f}'.format(dataset_path,
                                                     metrics.cluster.predict_to_clustering_error(true_labels,
                                                                                                 predict_labels)))

                    if summ_writer is not None:
                        summ_writer.add_scalar('/NMI', nmi, epoch)
                        summ_writer.add_scalar('/ARI', ari, epoch)
                        summ_writer.add_scalar('/Acc', clus_acc, epoch)

                if print_debug:
                    print('[{0:6d}/{1:6d}]\t'
                          '{batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                          '{data_time.val:.4f} ({data_time.avg:.4f})\t'.format(
                        batch_idx, len(train_loader), batch_time=batch_time,
                        data_time=data_time))

            samples = None
            t = None
            #  Calculate metrics or plot without changing the SOM map
            if debug:
                for batch_idx, (inputs, targets) in enumerate(train_loader):
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = model.cnn_extract_features(inputs)

                    if samples is None:
                        samples = outputs.cpu().detach().numpy()
                        t = targets.cpu().detach().numpy()
                    else:
                        samples = np.append(samples, outputs.cpu().detach().numpy(), axis=0)
                        t = np.append(t, targets.cpu().detach().numpy(), axis=0)

                centers, relevances, ma = model.som.get_prototypes()
                som_plotter.plot_data(samples, t, centers.cpu(), relevances.cpu() * 0.1)

                if summ_writer is not None:
                    summ_writer.add_scalar('Nodes', len(centers), epoch)

                # for center in centers:
                #     t = np.append(t, [10], axis=0)
                # samples = np.append(samples, centers.cpu().detach().numpy(), axis=0)
                # tsne = cumlTSNE(n_components=2, method='barnes_hut')
                # embedding = tsne.fit_transform(samples)
                # tsne_plotter.plot_data(embedding, t, None, None)

            print("Epoch: %d avg_loss: %.6f\n" % (epoch, avg_loss / s))
            if summ_writer is not None:
                summ_writer.add_scalar('Loss/train', avg_loss / s, epoch)

        #  Evaluation: switch to eval mode and cluster the test loader
        model.eval()

        print("Train Finished", flush=True)

        cluster_result, predict_labels, true_labels = model.cluster(test_loader)

        if not os.path.exists(join(out_folder, dataset_path.split(".arff")[0])):
            os.makedirs(join(out_folder, dataset_path.split(".arff")[0]))

        print("Homogeneity: %0.3f" % metrics.cluster.homogeneity_score(true_labels, predict_labels))
        print("Completeness: %0.3f" % metrics.cluster.completeness_score(true_labels, predict_labels))
        print("V-measure: %0.3f" % metrics.cluster.v_measure_score(true_labels, predict_labels))
        print("Normalized Mutual Information (NMI): %0.3f" % metrics.cluster.nmi(true_labels, predict_labels))
        print("Adjusted Rand Index (ARI): %0.3f" % metrics.cluster.ari(true_labels, predict_labels))
        print("Clustering Accuracy (ACC): %0.3f" % metrics.cluster.acc(true_labels, predict_labels))

        filename = dataset_path.split(".arff")[0] + "_" + str(param_set.Index) + ".results"
        model.write_output(join(out_folder, filename), cluster_result)

        print('{0} \tCE: {1:.3f}'.format(dataset_path,
                                         metrics.cluster.predict_to_clustering_error(true_labels,
                                                                                     predict_labels)))

        if debug:
            som_plotter.plot_hold()
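
weightedMSELoss is called in the training loop above but not defined here; one plausible definition consistent with that call (samples, matched node weights, relevances) is a relevance-weighted summed squared error. This is an assumption for illustration, not the project's implementation.

import torch

def weighted_mse_loss_sketch(samples, prototypes, relevances):
    # Squared error between each sample and its winning SOM prototype,
    # weighted per dimension by the node relevances and summed to a scalar.
    return torch.sum(relevances * (samples - prototypes) ** 2)
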
Example #8
def som_weights_visualization(root, dataset_path, parameters, grid_rows=10, grid_cols=10, lhs_samples=250):
    dataset = Datasets(dataset=dataset_path, root_folder=root, flatten=True)

    for param_set in parameters.itertuples():

        som = SOM(input_dim=dataset.dim_flatten,
                  n_max=grid_rows * grid_cols,
                  at=param_set.at,
                  ds_beta=param_set.ds_beta,
                  eb=param_set.eb,
                  eps_ds=param_set.eps_ds)

        manual_seed = param_set.seed
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)

        train_loader = DataLoader(dataset.train_data, batch_size=1, shuffle=True, num_workers=1)

        fig = plt.figure(figsize=(30, 30), constrained_layout=False)
        gs = fig.add_gridspec(1, 3, hspace=0.1, wspace=0.1)
        gs02 = gs[0, 2].subgridspec(grid_rows, grid_cols, wspace=0.1, hspace=0.0)
        ax1 = fig.add_subplot(gs[0, 0], xticks=np.array([]), yticks=np.array([]))
        ax2 = fig.add_subplot(gs[0, 1], xticks=np.array([]), yticks=np.array([]))

        fig.suptitle('Self-Organizing Map (SOM) Clustering', fontsize=16)
        for epoch in range(param_set.epochs):
            print('Experiment {} of {} [epoch: {} of {}]'.format(param_set.Index, lhs_samples, epoch, param_set.epochs))
            for batch_idx, (sample, target) in enumerate(train_loader):

                _, bmu_weights, _ = som(sample)
                _, bmu_indexes = som.get_winners(sample)
                ind_max = bmu_indexes.item()
                weights = som.weights[bmu_indexes]

                if dataset.d_in == 1:
                    ax1.imshow(sample.view(dataset.hw_in, dataset.hw_in), cmap='gray')
                    ax2.imshow(weights.view(dataset.hw_in, dataset.hw_in), cmap='gray')
                    images = [image.reshape(dataset.hw_in, dataset.hw_in) for image in som.weights]
                else:
                    ax1.imshow(sample.view(dataset.d_in, dataset.hw_in, dataset.hw_in).numpy().transpose((1, 2, 0)))
                    ax2.imshow(weights.view(dataset.d_in, dataset.hw_in, dataset.hw_in).numpy().transpose((1, 2, 0)))
                    images = [image.reshape(dataset.hw_in, dataset.hw_in, dataset.d_in) for image in som.weights]

                for x in range(grid_rows):
                    for y in range(grid_cols):
                        # Grid nodes are indexed column-major: index = grid_rows * y + x.
                        if ind_max == (grid_rows * y + x):
                            ax3 = fig.add_subplot(gs02[y, x])
                            if dataset.d_in == 1:
                                ax3.imshow(images[grid_rows * y + x], cmap='gray')
                            else:
                                ax3.imshow(images[grid_rows * y + x].view(dataset.d_in,
                                                                          dataset.hw_in,
                                                                          dataset.hw_in).numpy().transpose((1, 2, 0)))
                            ax3.set_xlabel('{label}'.format(label=ind_max))
                            plt.xticks(np.array([]))
                            plt.yticks(np.array([]))
                ax1.set_xlabel("Sample {} of {}".format(batch_idx, len(train_loader)))
                ax2.set_xlabel("Epoch {} of {}".format(epoch, param_set.epochs))
                ax1.set_title('Input Label: {label}'.format(label=target.item()))
                ax2.set_title('SOM BMU: {label}'.format(label=ind_max))
                plt.pause(0.001)
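
A minimal, hypothetical call to som_weights_visualization; the DataFrame columns (at, ds_beta, eb, eps_ds, seed, epochs) mirror the param_set fields read above, and the values and paths are placeholders.

import pandas as pd

params = pd.DataFrame([{'at': 0.98, 'ds_beta': 0.5, 'eb': 0.01,
                        'eps_ds': 0.01, 'seed': 1, 'epochs': 3}])
som_weights_visualization(root='data/', dataset_path='mnist.arff',
                          parameters=params, grid_rows=10, grid_cols=10,
                          lhs_samples=1)
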