Example #1
0
def m_embeding():
    """Train a small auto-encoder on matrix M and save its learned embedding.

    Reads ``M.txt`` from ``mid_result_path``, trains an AutoEncoder with an
    MSE reconstruction loss for a fixed number of epochs, then writes the
    encoded representation of every row to ``m_coder.txt``.
    """
    M = np.loadtxt(mid_result_path + "M.txt", delimiter=',')
    EPOCH = 10
    BATCH_SIZE = 64
    LR = 0.005
    autoencoder = AutoEncoder(M.shape[1])
    M = torch.tensor(M, dtype=torch.float32)
    M_train = Data.DataLoader(dataset=M, batch_size=BATCH_SIZE, shuffle=True)

    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
    loss_func = nn.MSELoss()
    for epoch in range(EPOCH):
        for step, x in enumerate(M_train):
            # Pure reconstruction objective: the target is the input itself.
            # torch.autograd.Variable is a deprecated no-op since 0.4, so the
            # tensors are used directly.
            encoded, decoded = autoencoder(x)
            loss = loss_func(decoded, x)
            print(loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    # Export pass: switch to eval mode and disable gradient tracking before
    # encoding the full matrix.
    autoencoder.eval()
    with torch.no_grad():
        encoded_data, _ = autoencoder(M)
    np.savetxt(mid_result_path + "m_coder.txt", encoded_data.detach().numpy(), delimiter=',', fmt='%.4f')
    print(encoded_data)
Example #2
0
    def buildModel(self):
        """
        Builds autoencoder models based on init parameters.

        Sets ``self.auto_encoder``, ``self.encoder`` and ``self.decoder``,
        then compiles the auto-encoder with an MSE loss. The "vae" model is
        built by ``vae_test`` and returned without the compile step; an
        unknown ``model_type`` prints a message and leaves the attributes
        untouched.
        """
        if self.model_type in ("VGG", "COAPNET", "FULLYCONNECTED"):
            autoencoder = AutoEncoder(model_type=self.model_type,
                                      latent_vector=self.latent_vector,
                                      latent_dim=self.latent_dim)
            # Each builder method shares its name with the model type
            # (AutoEncoder.VGG / .COAPNET / .FULLYCONNECTED), so dispatch
            # via getattr instead of three identical branches.
            builder = getattr(autoencoder, self.model_type)
            # construct the convolutional autoencoder
            self.auto_encoder, self.encoder, self.decoder = builder()

        elif self.model_type == "vae":
            from test import vae_test

            self.auto_encoder, self.encoder, self.decoder = vae_test()
            print(self.encoder.summary())
            print(self.decoder.summary())
            print(self.auto_encoder.summary())
            # The VAE comes pre-built; skip the generic compile step below.
            return

        else:
            # Unknown model type: bail out instead of falling through to
            # compile a model that was never created (the original raised
            # an AttributeError on self.auto_encoder here).
            print("No model with that name exists")
            return

        adam = tf.keras.optimizers.Adam(learning_rate=0.0001,
                                        beta_1=0.9,
                                        beta_2=0.999,
                                        amsgrad=False)
        # SGD is kept available as an alternative optimizer but is unused.
        sgd = tf.keras.optimizers.SGD(learning_rate=0.001,
                                      momentum=0.01,
                                      nesterov=False)
        self.auto_encoder.compile(loss="mse", optimizer=adam)  #SGD(lr=1e-3)
        print(self.encoder.summary())
        print(self.decoder.summary())
        print(self.auto_encoder.summary())
    def __init__(self, is_training, network, em, dataset):
        """Build the auto-encoder (AE) and reconstruction-clustering (RC)
        TF1 graphs.

        Args:
            is_training: training-mode flag stored for the build_* methods.
            network: dict with 'hidden_size', 'e_activation', 'd_activation'
                and 'tied' describing the auto-encoder architecture.
            em: dict of EM settings; 'k' (cluster count) and 'e_step' are
                read here.
            dataset: dict with 'shape' ([H, W, C]) and boolean 'binary'.
        """
        _shape = dataset['shape']

        # AE vars
        # Inputs are flattened to one vector per sample for the dense AE.
        _input_size = _shape[0] * _shape[1] * _shape[2]
        self._is_training = is_training
        self.ae_input_data = tf.placeholder(tf.float32, [None, _input_size])
        self.targets = tf.placeholder(tf.float32, [None, _input_size])

        # model
        self._ae = AutoEncoder(network['hidden_size'], _input_size, network['e_activation'],
                               network['d_activation'], tied=network['tied'])

        # Placeholders; populated by build_network_training_model below.
        self.outputs, self.cost, self.lr, self.n_vars, self.train_op = 5*[None]
        self.build_network_training_model(dataset['binary'])

        # RC vars
        self.rc_input_data = tf.placeholder(tf.float32, [1, None] + _shape + [1])
        self.pi = tf.placeholder(tf.float32, [1, None, 1, 1, 1, em['k']])
        self.gamma = tf.placeholder(tf.float32, [1, None] + _shape[:2] + [1, em['k']])

        # RC model
        # Pixel likelihood family follows the dataset: binomial for binary
        # data, gaussian otherwise.
        _distribution = 'binomial' if dataset['binary'] else 'gaussian'
        self._rc = ReconstructionClustering(self._ae, _shape, em['k'], em['e_step'], _distribution, dataset['binary'])

        # Placeholders; populated by build_reconstruction_clustering_model.
        self.new_gamma, self.new_pi, self.likelihood_post_m, self.likelihood_post_e = 4*[None]
        self.build_reconstruction_clustering_model()
def train(train_x, train_y, word_dict, args):
    """Pre-train an auto-encoder on train_x, checkpointing once per epoch.

    Args:
        train_x: training inputs, fed batch by batch.
        train_y: labels; only used to drive batch_iter (training itself is
            unsupervised).
        word_dict: vocabulary passed to the AutoEncoder constructor.
        args: parsed CLI args; args.model selects the model, args.save names
            the checkpoint/log destination.

    Raises:
        ValueError: if args.model is not "auto_encoder".
    """
    with tf.Session() as sess:
        if args.model == "auto_encoder":
            model = AutoEncoder(word_dict, MAX_DOCUMENT_LEN)
        else:
            raise ValueError("Not found model: {}.".format(args.model))

        # Define training procedure: Adam with global-norm gradient clipping.
        global_step = tf.Variable(0, trainable=False)
        params = tf.trainable_variables()
        gradients = tf.gradients(model.loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
        optimizer = tf.train.AdamOptimizer(0.001)
        train_op = optimizer.apply_gradients(zip(clipped_gradients, params),
                                             global_step=global_step)

        # Summary
        # loss_summary = tf.summary.scalar("loss", model.loss)
        # summary_op = tf.summary.merge_all()
        # summary_writer = tf.summary.FileWriter(args.save, sess.graph)

        # Checkpoint (only the most recent one is kept)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        def train_step(batch_x):
            # One optimisation step; logs the loss every 100 global steps.
            feed_dict = {model.x: batch_x}
            _, step, loss = sess.run([train_op, global_step, model.loss],
                                     feed_dict=feed_dict)
            # summary_writer.add_summary(summaries, step)

            if step % 100 == 0:
                print("step {0} : loss = {1}".format(step, loss))
                with open("pre-train-loss-all-" + args.save + ".txt",
                          "a") as f:
                    print("step {0} : loss = {1}".format(step, loss), file=f)

        # Training loop
        batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS)

        # Loop-invariant, so computed once. Guarded to at least 1: with
        # num_train < BATCH_SIZE the original computed 0 and the epoch
        # boundary check below crashed with a modulo-by-zero.
        steps_per_epoch = max(1, int(num_train / BATCH_SIZE))

        st = time.time()
        for batch_x, _ in batches:
            train_step(batch_x)
            step = tf.train.global_step(sess, global_step)

            if step % steps_per_epoch == 0:
                print("epoch: {}, step: {}, steps_per_epoch: {}".format(
                    int(step / steps_per_epoch), step, steps_per_epoch))
                saver.save(sess,
                           os.path.join(args.save, "model", "model.ckpt"),
                           global_step=step)
                print("save to {}, time of one epoch: {}".format(
                    args.save,
                    time.time() - st))
                st = time.time()
    def __init__(self):
        """Collect matching image/mask file paths, build the GTK preview
        window, and start the background encoder thread."""
        imgs_dir = "./imgs_comp_box"
        imgs_mask_dir = "./imgs_mask_box"

        # Parallel lists: str_imgs_fns[i] is the image for str_mask_fns[i].
        self.str_imgs_fns = []
        self.str_mask_fns = []

        dirs = []

        # First pass: list every sub-directory of the mask folder.
        for parent, dirnames, filenames in os.walk(imgs_mask_dir):
            for dirname in dirnames:
                dirs.append(dirname)

        # Second pass: for every mask file, record the mask path and the
        # corresponding image path (same sub-directory, same ".png" stem).
        for str_dir in dirs:
            str_dir_path = imgs_mask_dir + "/" + str_dir
            for parent, dirnames, filenames in os.walk(str_dir_path):
                for filename in filenames:
                    str_path = str_dir_path + "/" + filename
                    self.str_mask_fns.append(str_path)
                    idx = filename.find(".png")
                    str_img_path = imgs_dir + "/" + str_dir + "/" + filename[:
                                                                             idx] + ".png"
                    self.str_imgs_fns.append(str_img_path)

        self.autoencoder = AutoEncoder()

        # GTK window hosting a scrollable matplotlib canvas.
        self.win = Gtk.Window()
        self.win.connect("delete-event", self.win_quit)
        self.win.set_default_size(1000, 600)
        self.win.set_title("show imgs")

        self.sw = Gtk.ScrolledWindow()
        self.win.add(self.sw)
        self.sw.set_border_width(2)

        fig = Figure(figsize=(8, 8), dpi=80)
        self.canvas = FigureCanvas(fig)
        self.canvas.set_size_request(1000, 600)
        self.sw.add(self.canvas)
        self.win.show_all()

        # State shared with the worker thread, guarded by torch_lock;
        # "mess_quit" signals the thread to stop.
        self.torch_lock = threading.Lock()
        self.torch_show_data = {}
        self.n_test_imgs = 5
        self.torch_show_data["mess_quit"] = False

        # Background thread that runs the auto-encoder over the collected
        # files and reports back through update_torch_data.
        thread_torch = Encoder_Thread(self.update_torch_data,
                                      self.torch_lock,
                                      self.autoencoder,
                                      self.str_imgs_fns,
                                      self.str_mask_fns,
                                      self.torch_show_data,
                                      wh=97,
                                      max_n_loop=1,
                                      n_loop=0,
                                      idx_segment=0)

        thread_torch.start()
Example #6
0
def load_model(path):
    """Restore a trained AutoEncoder checkpoint from *path*.

    Returns the live TF session together with the restored model instance
    (the session stays open so the caller can run the graph).
    """
    graph = tf.Graph()
    with graph.as_default():
        session = tf.Session()
        with session.as_default():
            # Architecture hyper-parameters come from the module-level config.
            model = AutoEncoder(
                embedding_size=config['embedding_size'],
                num_hidden_layer=config['num_hidden_layer'],
                hidden_layers=config['hidden_layers'])
            model.saver.restore(session, path)
    return session, model
        os.mkdir(results_path)

    dataset_path = '..' + os.sep + args.data_path

    # dataset = Cifar10Dataset(configuration.batch_size, dataset_path,
    #                           configuration.shuffle_dataset)  # Create an instance of CIFAR10 dataset

    print(f"batch_size: {configuration.batch_size}")

    dataset = ArnoDataset(batch_size=configuration.batch_size,
                          path="../../../datasets/arno_v1",
                          shuffle_dataset=configuration.shuffle_dataset,
                          num_workers=6,
                          im_size=128)

    auto_encoder = AutoEncoder(device, configuration).to(
        device)  # Create an AutoEncoder model using our GPU device

    optimizer = optim.Adam(auto_encoder.parameters(),
                           lr=configuration.learning_rate,
                           amsgrad=True)  # Create an Adam optimizer instance
    trainer = Trainer(device, auto_encoder, optimizer,
                      dataset)  # Create a trainer instance
    trainer.train(configuration.num_training_updates, results_path,
                  args)  # Train our model on the CIFAR10 dataset
    auto_encoder.save(results_path + os.sep +
                      args.model_name)  # Save our trained model
    trainer.save_loss_plot(results_path + os.sep +
                           args.loss_plot_name)  # Save the loss plot

    evaluator = Evaluator(
        device, auto_encoder,
Example #8
0
def trainAndTest(dataset, enable_data_augmentation = False, percentage_similarity_loss = 0, LSTM = False, EPOCHS = 500, enable_same_noise = False, save_output = True, NlogN = True):
    """Train a convolutional auto-encoder on a time-series dataset, then
    evaluate reconstruction error and neighbourhood preservation.

    Args:
        dataset: py_ts_data dataset name.
        enable_data_augmentation: augment train sets smaller than 1000.
        percentage_similarity_loss: weight of the similarity loss term.
        LSTM: use the LSTM variant (auto-disabled for large/augmented data).
        EPOCHS: number of training epochs.
        enable_same_noise: forwarded to augment_data.
        save_output: write plots and a result row under ouput_dir_name.
        NlogN: draw N*int(log2(N)) random training pairs instead of N.
    """
    X_train, y_train, X_test, y_test, info = py_ts_data.load_data(dataset, variables_as_channels=True)

    print("Dataset shape: Train: {}, Test: {}".format(X_train.shape, X_test.shape))
    print(np.shape(y_train))

    if enable_data_augmentation or len(X_train) >= 1000:
        # LSTM will greatly extend the training time, so disable it if we have large data
        LSTM = False

    title = "{}-DA:{}-CoefSimilar:{}-LSTM:{}".format(dataset, enable_data_augmentation, percentage_similarity_loss, LSTM)

    ##### Preprocess Data ####
    num_train = len(X_train)
    if num_train < 1000 and enable_data_augmentation:
        X_train = augment_data(X_train, enable_same_noise = enable_same_noise)
        num_train = len(X_train)

    # randomly generate N pairs:
    # NlogN would be num_train * int(math.log2(num_train))
    if NlogN:
        num_of_pairs = num_train * int(math.log2(num_train))
    else:
        num_of_pairs = num_train
    X, Y = generateRandomPairs(num_of_pairs, X_train)
    # NlogN is too large, for N = 1000, NlogN would be 10K

    normalized_X, normalized_Y, distance = calculatePreSBD(X, Y)

    ###### Training Stage #####
    kwargs = {
        "input_shape": (X_train.shape[1], X_train.shape[2]),
        "filters": [32, 64, 128],
        "kernel_sizes": [5, 5, 5],
        "code_size": 16,
    }

    ae = AutoEncoder(**kwargs)

    # Training
    loss_history = []
    t1 = time.time()
    for epoch in range(EPOCHS):

        if epoch % 100 == 50:
            print("Epoch {}/{}".format(epoch, EPOCHS))
        total_loss = train_step(normalized_X, normalized_Y, distance, ae, alpha = percentage_similarity_loss, LSTM = LSTM)
        loss_history.append(total_loss)

    print("The training time for dataset {} is: {}".format(dataset, (time.time() - t1) / 60))

    # Plot the loss curve, skipping the first 5 noisy epochs.
    plt.clf()
    plt.xlabel("epoch starting from 5")
    plt.ylabel("loss")
    plt.title("Loss vs epoch")
    plt.plot(loss_history[5:])
    if save_output:
        if not os.path.isdir(ouput_dir_name + dataset):
            os.mkdir(ouput_dir_name + dataset)
            # Write the CSV header once, when the result directory is first
            # created. Fixed: the header was missing the trailing NlogN
            # column that the data row at the bottom of this function writes.
            with open(ouput_dir_name + dataset + "/record.txt", "a") as f:
                f.write("Dataset, Data Augmentation, Coefficient of Similarity Loss, LSTM, EPOCHS, Distance Measure, L2 Distance, 10-nn score, NlogN\n")

        plt.savefig(ouput_dir_name + dataset + "/" + title + "-loss.png")

    # Reconstruction sanity check on the (normalized) test set.
    X_test = normalize(X_test)
    code_test = ae.encode(X_test, LSTM = LSTM)
    decoded_test = ae.decode(code_test)
    plt.clf()
    plt.plot(X_test[0], label = "Original TS")
    plt.plot(decoded_test[0], label = "reconstructed TS")
    if save_output:
        plt.savefig(ouput_dir_name + dataset + "/" + title + "-reconstruction.png")

    # Mean L2 reconstruction error over the test set.
    losses = []
    for ground, predict in zip(X_test, decoded_test):
        losses.append(np.linalg.norm(ground - predict))

    L2_distance = np.array(losses).mean()
    print("Mean L2 distance: {}".format(L2_distance))

    # Neighbourhood preservation: compare 10-NN under SBD on the raw series
    # with 10-NN under euclidean distance on the learned codes.
    from sklearn.neighbors import NearestNeighbors

    nn_x_test = np.squeeze(X_test)
    baseline_nn = NearestNeighbors(n_neighbors=10, metric = SBD).fit(nn_x_test)
    code_nn = NearestNeighbors(n_neighbors=10).fit(code_test)  # the default metric is euclidean distance

    # For each item in the test data, find its 11 nearest neighbors in that dataset (the nn is itself)
    baseline_11nn = baseline_nn.kneighbors(nn_x_test, 11, return_distance=False)
    code_11nn     = code_nn.kneighbors(code_test, 11, return_distance=False)

    # On average, how many common items are in the 10nn?
    result = []
    for b, c in zip(baseline_11nn, code_11nn):
        # remove the first nn (itself)
        b = set(b[1:])
        c = set(c[1:])
        result.append(len(b.intersection(c)))

    ten_nn_score = np.array(result).mean()
    print("10-nn score is:", ten_nn_score)
    if save_output:
        with open(ouput_dir_name + dataset + "/record.txt", "a") as f:
            f.write(",".join([dataset, str(enable_data_augmentation), str(percentage_similarity_loss), str(LSTM), str(EPOCHS), distance_measure, str(round(L2_distance,2)), str(round(ten_nn_score,2)), str(NlogN)]) + "\n")
Example #9
0
def get_preds(args=None):
    """Run a trained AutoEncoder over an attack/test set and save the
    per-sample reconstruction losses as a .npy file.

    Raises:
        ValueError: if the requested checkpoint file does not exist.
    """
    parser = argparse.ArgumentParser(description='Simple testing script.')
    parser.add_argument('--cls_id', help='class id', type=int)
    parser.add_argument('--version', help='model version', type=float)
    parser.add_argument('--resume_epoch', help='trained model for resume',
                        type=int)
    parser.add_argument('--set_name', help='imply attack goal', type=str,
                        default='test_digi_ifgsm_hiding')
    parser.add_argument('--gamma', help='gamma for the SoftL1Loss',
                        type=float, default=9.0)
    parser.add_argument('--checkpoints', help='checkpoints path', type=str,
                        default='voc_checkpoints')
    parser.add_argument('--saves_dir',
                        help='the save path for tested reconstruction error',
                        type=str, default='voc_reconstruction_error')
    parser.add_argument('--batch_size', help='batch size for optimization',
                        type=int, default=1)
    opts = parser.parse_args(args)

    batch_size = opts.batch_size
    if not os.path.isdir(opts.saves_dir):
        os.mkdir(opts.saves_dir)

    # Checkpoints are stored per class under "<checkpoints>_<cls_name>".
    cls_name = classes[opts.cls_id]
    opts.checkpoints = '_'.join([opts.checkpoints, cls_name])

    checkpoint_name = os.path.join(
        opts.checkpoints,
        'model_{:1.1f}_epoch{:d}.pt'.format(opts.version, opts.resume_epoch))
    if not os.path.isfile(checkpoint_name):
        raise ValueError('No checkpoint file {:s}'.format(checkpoint_name))
    # The per-sample loop below records one loss per sample.
    assert batch_size == 1

    print('[data prepare]....')
    cls_dir = "../context_profile/voc_detection_{:s}_p10/"\
     .format(cls_name)
    dataloader_test = DataLoader(Fetch(opts.set_name, root_dir=cls_dir),
                                 batch_size=batch_size,
                                 num_workers=1,
                                 shuffle=False)

    print('[model prepare]....')
    use_gpu = torch.cuda.device_count() > 0
    model = AutoEncoder(opts.gamma)
    if use_gpu:
        model = torch.nn.DataParallel(model).cuda()
    model.load_state_dict(torch.load(checkpoint_name))
    print('model loaded from {:s}'.format(checkpoint_name))

    print('[model testing]...')
    model.eval()
    preds = []
    with torch.no_grad():
        for sample in dataloader_test:
            data = (sample['data'].cuda().float() if use_gpu
                    else sample['data'].float())
            preds.append(float(model(data)))
    preds_name = '_model{:1.1f}_' + opts.set_name
    save_name = os.path.join(opts.saves_dir,
                             cls_name + preds_name.format(opts.version))
    np.save(save_name, preds)
    print('save preds in {:s}'.format(save_name))
Example #10
0
def main(args=None):
    """Train an AutoEncoder on benign context profiles of one VOC class.

    Checkpoints are written to "<checkpoints>_<cls_name>/" at the end of
    every epoch after the first; the full loss history is saved to
    loss_hist.npy.

    Raises:
        ValueError: when --resume_epoch is set but its checkpoint is missing.
    """
    parser = argparse.ArgumentParser(description='Simple training script.')
    parser.add_argument('--cls_id', help='class id', type=int)
    parser.add_argument('--version', help='model version', type=float)
    parser.add_argument('--gamma', help='gamma for the SoftL1Loss', type=float, default=9.0)
    parser.add_argument('--lr', help='lr for optimization', type=float, default=1e-4)
    parser.add_argument('--epoches', help='num of epoches for optimization', type=int, default=4)
    parser.add_argument('--resume_epoch', help='trained model for resume', type=int, default=0)
    parser.add_argument('--batch_size', help='batch size for optimization', type=int, default=10)
    parser.add_argument('--checkpoints', help='checkpoints path', type=str, default='voc_checkpoints')
    parser = parser.parse_args(args)

    cls_name = classes[parser.cls_id]
    parser.checkpoints = '_'.join([parser.checkpoints, cls_name])
    if not os.path.isdir(parser.checkpoints):
        os.mkdir(parser.checkpoints)
    print('will save checkpoints in ' + parser.checkpoints)
    cls_dir = "../context_profile/voc_detection_{:s}_p10/"\
        .format(cls_name)
    batch_size = parser.batch_size
    print('[data prepare]....')
    dataloader_train = DataLoader(Fetch('train_benign', root_dir=cls_dir), batch_size=batch_size, num_workers=2, shuffle=True)

    print('[model prepare]....')
    use_gpu = torch.cuda.device_count() > 0

    model = AutoEncoder(parser.gamma)
    if use_gpu:
        model = torch.nn.DataParallel(model).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=parser.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, verbose=True)
    if parser.resume_epoch > 0:
        checkpoint_name = os.path.join(parser.checkpoints, 'model_{:1.1f}_epoch{:d}.pt'.format(parser.version, parser.resume_epoch))
        if not os.path.isfile(checkpoint_name):
            raise ValueError('No checkpoint file {:s}'.format(checkpoint_name))
        model.load_state_dict(torch.load(checkpoint_name))
        print('model loaded from {:s}'.format(checkpoint_name))

    print('[model training]...')
    loss_hist = []
    epoch_loss = []
    num_iter = len(dataloader_train)
    for epoch_num in range(parser.resume_epoch, parser.epoches):
        model.train()
        # (The dead `if True:  # try:` wrapper around this body was removed.)
        for iter_num, sample in enumerate(dataloader_train):
            optimizer.zero_grad()
            data = sample['data'].cuda().float() if use_gpu else sample['data'].float()

            loss = model(data).mean()
            if bool(loss == 0):
                continue
            loss.backward()
            # Clip gradients to keep updates stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            epoch_loss.append(float(loss))
            loss_hist.append(float(loss))
            if iter_num % 30 == 0:
                print('Epoch {:d}/{:d} | Iteration: {:d}/{:d} | loss: {:1.5f}'.format(
                    epoch_num+1, parser.epoches, iter_num+1, num_iter, float(loss)))
            if iter_num % 3000 == 0:
                # Step the plateau scheduler on the running mean loss.
                scheduler.step(np.mean(epoch_loss))
                epoch_loss = []
        if epoch_num < 1:
            # The very first epoch is deliberately not checkpointed.
            continue
        checkpoint_name = os.path.join(parser.checkpoints, 'model_{:1.1f}_epoch{:d}.pt'.format(parser.version, epoch_num+1))
        torch.save(model.state_dict(), checkpoint_name)
        print('Model saved as {:s}'.format(checkpoint_name))

    np.save('loss_hist.npy', loss_hist)
Example #11
0
def main():
    """Parse CLI flags, train the selected (auto)encoder variant on the
    GunPoint dataset with Comet logging, evaluate it, and save the trained
    sub-models under a metadata-derived directory tree.

    Raises:
        Exception: when none of the model-type flags (-a/-e/-s) is set.
    """

    parser = argparse.ArgumentParser()
    # Exactly one of these three flags selects the model variant.
    parser.add_argument("-a",
                        "--auto",
                        action="store_true",
                        help="autoencoder")
    parser.add_argument("-e",
                        "--encauto",
                        action="store_true",
                        help="encoder + autoencoder")
    parser.add_argument("-s",
                        "--seqencauto",
                        action="store_true",
                        help="Encoder(sim) + autoencoder(rec)")
    # Positional hyper-parameters (all parsed as strings, converted below).
    parser.add_argument("l")
    parser.add_argument("filter1")
    parser.add_argument("filter2")
    parser.add_argument("filter3")
    parser.add_argument("epoch")
    parser.add_argument("batch")
    args = parser.parse_args()

    m_type = None
    if args.auto:
        m_type = "autoencoder"
    elif args.encauto:
        m_type = "encoder_autoencoder"
    elif args.seqencauto:
        m_type = "Encoder_sim_autoencoder_rec"
    else:
        raise Exception("model type flag not set")

    # Human-readable run description used as the logged model_type.
    model_type_log = "{m_type} lambda={l} filter=[{filter1}, {filter2}, {filter3}] epoch={epoch} batch={batch}".format(
        m_type=m_type,
        l=args.l,
        filter1=args.filter1,
        filter2=args.filter2,
        filter3=args.filter3,
        epoch=args.epoch,
        batch=args.batch)

    filters = [int(args.filter1), int(args.filter2), int(args.filter3)]
    BATCH = int(args.batch)
    EPOCHS = int(args.epoch)
    lam = float(args.l)

    hyperparams["model_type"] = model_type_log
    hyperparams["epochs"] = EPOCHS
    hyperparams["batch_size"] = BATCH

    # Comet experiment; parameters are logged before training starts.
    experiment = Experiment(log_code=False)
    experiment.log_parameters(LAMBDA)
    experiment.log_parameters(hyperparams)

    dataset_name = "GunPoint"

    X_train, y_train, X_test, y_test, info = py_ts_data.load_data(
        dataset_name, variables_as_channels=True)
    print("Dataset shape: Train: {}, Test: {}".format(X_train.shape,
                                                      X_test.shape))

    # Only the training split is augmented.
    print(X_train.shape, y_train.shape)
    X_train, y_train = augmentation(X_train, y_train)
    # X_test, y_test = augmentation(X_test, y_test)
    print(X_train.shape, y_train.shape)
    # fig, axs = plt.subplots(1, 2, figsize=(10, 3))
    # axs[0].plot(X_train[200])
    # Scale both splits to [-1, 1].
    X_train = min_max(X_train, feature_range=(-1, 1))
    # axs[1].plot(X_train[200])
    X_test = min_max(X_test, feature_range=(-1, 1))
    # plt.show()

    kwargs = {
        "input_shape": (X_train.shape[1], X_train.shape[2]),
        # "filters": [32, 64, 128],
        # "filters": [128, 64, 32],
        "filters": filters,
        # "filters": [32, 32, 32],
        # "filters": [32, 32, 16],
        "kernel_sizes": [5, 5, 5],
        "code_size": 16,
    }

    # lambda_to_test = [0.9, ]
    # for l in range(1, 10):
    #     lam = l / 10

    # lam = 0.99
    ae = AutoEncoder(**kwargs)

    # A standalone Encoder built with the same hyper-parameters as the AE.
    # Note: this rebinds `filters` to the same list already stored in kwargs.
    input_shape = kwargs["input_shape"]
    code_size = kwargs["code_size"]
    filters = kwargs["filters"]
    kernel_sizes = kwargs["kernel_sizes"]
    encoder = Encoder(input_shape, code_size, filters, kernel_sizes)
    # training

    SHUFFLE_BUFFER = 100
    # K = number of distinct labels (computed but unused below).
    K = len(set(y_train))

    train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER).batch(BATCH)

    suffix = "lam={lam}".format(lam=lam)
    train(ae, encoder, EPOCHS, train_dataset, suffix, experiment, lam, args)

    # Evaluate reconstruction quality, then similarity preservation.
    code_test = recon_eval(ae, X_test, suffix, experiment)
    sim_eval(X_test, code_test, suffix, experiment)

    # Save paths encode the full hyper-parameter combination.
    cwd = os.path.abspath(os.getcwd())
    metadata = "lambda_{l}_filter_{filter1}{filter2}{filter3}_epoch_{epoch}_batch_{batch}".format(
        l=args.l,
        filter1=args.filter1,
        filter2=args.filter2,
        filter3=args.filter3,
        epoch=args.epoch,
        batch=args.batch)
    encoder_path = os.path.join(cwd, m_type, dataset_name, metadata, "encoder")
    ae_encoder_path = os.path.join(cwd, m_type, dataset_name, metadata,
                                   "auto_encoder")
    ae_decoder_path = os.path.join(cwd, m_type, dataset_name, metadata,
                                   "decoder")

    # NOTE(review): ae.encode / ae.decode appear to be sub-models exposing
    # .save — confirm against the AutoEncoder implementation.
    if not args.auto:
        encoder.save(encoder_path)
    ae.encode.save(ae_encoder_path)
    ae.decode.save(ae_decoder_path)
    sample_evaluation(ae.encode,
                      ae.encode,
                      ae.decode,
                      experiment,
                      suffix,
                      DATA=dataset_name)
        return dat
    else:
        return fetch_mldata("MNIST original")

mnist = fetch_mnist()
mnist.data = mnist.data.astype(np.float32)
mnist.data /= 255
mnist.target = mnist.target.astype(np.int32)

N = 60000
x_train, x_test = np.split(mnist.data, [N])
y_train, y_test = np.split(mnist.target, [N])
N_test = y_test.size

#First layer
ae1 = AutoEncoder(784, n_units)

optimizer1 = optimizers.Adam()
optimizer1.setup(ae1.model.collect_parameters())

for epoch in xrange(1, n_epoch+1):
    print 'epoch', epoch
    perm = np.random.permutation(N)
    for i in xrange(0, N, batchsize):
        sum_loss = 0
        x_batch = x_train[perm[i:i+batchsize]]
        optimizer1.zero_grads()
        loss = ae1.train_once(x_batch)
        loss.backward()
        optimizer1.update()
        sum_loss += float(loss.data) * batchsize
Example #13
0
# run the basic auto-encoder

if RUN_AE or RUN_AE_DEFAULT:
    print('auto-encoder')

    create_dir(encoder_dir)
    create_dir(encoder_dir_to_user)

    if RUN_AE:
        # Hold out 20% of the vectors for testing; the auto-encoder is
        # unsupervised, so the inputs double as targets and the y-splits
        # from train_test_split are discarded.
        train_X, test_X, _, _ = train_test_split(vecs_data,
                                                 vecs_data,
                                                 test_size=0.20)

        # train  auto-encoder model
        ae = AutoEncoder(np.array(train_X), root, encoding_dim=ENCODING_DIM)
        ae.encoder_decoder()
        train_results = ae.fit_autoencoder(np.array(train_X),
                                           batch_size=50,
                                           epochs=EPOCHS)
        ae.save_ae()

        # plot loss and accuracy as a function of time
        plot_loss(train_results.history['loss'],
                  train_results.history['val_loss'], encoder_dir)

        # load trained model
        encoder = load_model(r'./weights_' + root + '/encoder_weights.h5')
Example #14
0
    parser.add_argument('--model_name', nargs='?', default='model.pth', type=str, help='The file name of trained model')
    parser.add_argument('--original_images_name', nargs='?', default='original_images.png', type=str, help='The file name of the original images used in evaluation')
    parser.add_argument('--validation_images_name', nargs='?', default='validation_images.png', type=str, help='The file name of the reconstructed images used in evaluation')
    args = parser.parse_args()

    # Dataset and model hyperparameters
    configuration = Configuration.build_from_args(args)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Use GPU if cuda is available

    # Set the result path and create the directory if it doesn't exist
    results_path = '..' + os.sep + args.results_path
    if not os.path.isdir(results_path):
        os.mkdir(results_path)
    
    dataset_path = '..' + os.sep + args.data_path

    dataset = Cifar10Dataset(configuration.batch_size, dataset_path, configuration.shuffle_dataset) # Create an instance of CIFAR10 dataset
    auto_encoder = AutoEncoder(device, configuration).to(device) # Create an AutoEncoder model using our GPU device

    optimizer = optim.Adam(auto_encoder.parameters(), lr=configuration.learning_rate, amsgrad=True) # Create an Adam optimizer instance
    trainer = Trainer(device, auto_encoder, optimizer, dataset) # Create a trainer instance
    trainer.train(configuration.num_training_updates) # Train our model on the CIFAR10 dataset
    auto_encoder.save(results_path + os.sep + args.model_name) # Save our trained model
    trainer.save_loss_plot(results_path + os.sep + args.loss_plot_name) # Save the loss plot

    evaluator = Evaluator(device, auto_encoder, dataset) # Create en Evaluator instance to evaluate our trained model
    evaluator.reconstruct() # Reconstruct our images from the embedded space
    evaluator.save_original_images_plot(results_path + os.sep + args.original_images_name) # Save the original images for comparaison purpose
    evaluator.save_validation_reconstructions_plot(results_path + os.sep + args.validation_images_name) # Reconstruct the decoded images and save them
Example #15
0
from Tools.scripts import import_diagnostics
from keras.datasets import cifar10, mnist
from auto_encoder import AutoEncoder

# Noise ratios — presumably for salt-and-pepper corruption of the inputs;
# they are defined here but not passed anywhere in this snippet (TODO:
# confirm they are consumed downstream, e.g. inside AutoEncoder/train).
salt_ratio = .05
pepper_ratio = .15

# Build an auto-encoder over the CIFAR-10 dataset and train it.
encoder = AutoEncoder.AutoEncoder(data_set=cifar10)
encoder.train()
# Model hyper-parameters from the shared Config dict.
ENC_UNITS = Config['encoder_units']
DEC_UNITS = Config['decoder_units']
SEQUENCE_PRED = Config['recurrent_layer_output_sequence']

# use_seq_regen_acc = True
# if Config['accuracy_metric'] != 'sequence_regeneration_accuracy':
#   use_seq_regen_acc = False

NUM_EPOCHS = Config['num_epochs']
INIT_EPOCH = Config['starting_epoch']

""" Build Model """
auto_encoder = AutoEncoder(
  embedding_matrix,
  ENC_UNITS,
  DEC_UNITS,
  tokenizer,
  rnn_type=RNN_TYPE,
  enable_eager_execution=False
)

ckpt_path = Config['pre_trained_ckpt']

# Restore pre-trained weights only when both checkpoint files are present
# (glob matches e.g. the .index/.data shards). `is not None` replaces the
# non-idiomatic `!= None` comparison.
if ckpt_path is not None and len(glob.glob(f'{ckpt_path}.*')) == 2:
  auto_encoder.load_weights(ckpt_path)
  print(f"Loaded checkpoint: {ckpt_path}")

# if use_seq_regen_acc:
#   metric = SequenceRegenerationAccuracy()
# else:
#   metric = tf.keras.metrics.SparseCategoricalAccuracy()
import glob

import tensorflow as tf
from tokenizer import get_tokenizer
from embedding_layer import get_embeddings_matrix
from auto_encoder import AutoEncoder
from auto_encoder_config import Config
""" Tokenizer """
tokenizer = get_tokenizer()
""" Embedding Layer """
embedding_matrix = get_embeddings_matrix(tokenizer.get_vocabulary())
""" Model Config """
ENC_UNITS = Config['encoder_units']
DEC_UNITS = Config['decoder_units']
""" Build Model """
auto_encoder = AutoEncoder(embedding_matrix,
                           ENC_UNITS,
                           DEC_UNITS,
                           tokenizer,
                           enable_eager_execution=False)

ckpt_path = Config['pre_trained_ckpt']

# Restore pre-trained weights only when both checkpoint files are present
# (glob matches the .index/.data shards). Fixed: `glob` was used without a
# visible import, and `!= None` replaced with the idiomatic `is not None`.
if ckpt_path is not None and len(glob.glob(f'{ckpt_path}.*')) == 2:
    auto_encoder.load_weights(ckpt_path)
auto_encoder.compile(
    optimizer=tf.optimizers.Adam(),
    # loss=MaskedLoss(sequence=SEQUENCE_PRED),
    loss=tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.SUM),
    # metrics=[metric],
    metrics=[tf.keras.metrics.CosineSimilarity()],
Exemple #18
0
class BpcvMain:
    """Gtk viewer for autoencoder reconstructions.

    Pairs every mask under ./imgs_mask_box with the same-named image under
    ./imgs_comp_box, restores an AutoEncoder checkpoint, and launches an
    Encoder_Thread that keeps pushing (original, mask, decoded) batches
    into a matplotlib canvas embedded in the window.
    """
    def __init__(self):
        # Directory layout: imgs_mask_dir/<sub>/<name>.png is the mask for
        # imgs_dir/<sub>/<name>.png.
        imgs_dir = "./imgs_comp_box"
        imgs_mask_dir = "./imgs_mask_box"

        self.str_imgs_fns = []  # image file paths, parallel to str_mask_fns
        self.str_mask_fns = []  # mask file paths

        dirs = []

        # Collect the immediate sub-directories of the mask root.
        for parent, dirnames, filenames in os.walk(imgs_mask_dir):
            for dirname in dirnames:
                dirs.append(dirname)

        # For every mask file, derive the matching image path by name.
        for str_dir in dirs:
            str_dir_path = imgs_mask_dir + "/" + str_dir
            for parent, dirnames, filenames in os.walk(str_dir_path):
                for filename in filenames:
                    str_path = str_dir_path + "/" + filename
                    self.str_mask_fns.append(str_path)
                    idx = filename.find(".png")
                    str_img_path = imgs_dir + "/" + str_dir + "/" + filename[:
                                                                             idx] + ".png"
                    self.str_imgs_fns.append(str_img_path)

        #str_pth_fn = "./models/bpcv_encoder_06000.pth"
        str_pth_fn = "./models/bpcv_encoder_12000.pth"

        self.autoencoder = AutoEncoder()

        # Checkpoint dict carries the weights plus the loop counters
        # needed to resume where the previous run stopped.
        bpcv_dict = torch.load(str_pth_fn)
        self.autoencoder.load_state_dict(bpcv_dict["net_state"])

        print("continue: ...  n_loop: {0:0>5d}  idx_loop: {1:0>5d}".format(
            bpcv_dict["n_loop"], bpcv_dict["idx_loop"]))
        print(
            ".............................................................................."
        )

        # Build the Gtk window with a scrolled matplotlib canvas.
        self.win = Gtk.Window()
        self.win.connect("delete-event", self.win_quit)
        self.win.set_default_size(1000, 600)
        self.win.set_title("show imgs")

        self.sw = Gtk.ScrolledWindow()
        self.win.add(self.sw)
        self.sw.set_border_width(2)

        fig = Figure(figsize=(8, 8), dpi=80)
        self.canvas = FigureCanvas(fig)
        self.canvas.set_size_request(1000, 600)
        self.sw.add(self.canvas)
        self.win.show_all()

        # State shared with the worker thread; guarded by torch_lock.
        self.torch_lock = threading.Lock()
        self.torch_show_data = {}
        self.n_test_imgs = 5  # images per row in the preview grid
        self.torch_show_data["mess_quit"] = False  # worker stop flag

        thread_torch = Encoder_Thread(self.update_torch_data,
                                      self.torch_lock,
                                      self.autoencoder,
                                      self.str_imgs_fns,
                                      self.str_mask_fns,
                                      self.torch_show_data,
                                      wh=97,
                                      max_n_loop=3,
                                      n_loop=bpcv_dict["n_loop"],
                                      idx_segment=bpcv_dict["idx_loop"])

        thread_torch.start()

    def update_torch_data(self, str_txt):
        """Worker callback: rebuild the preview grid from the shared dict.

        Draws a 3-row grid — originals, masks, reconstructions — with
        n_test_imgs columns. `str_txt` is status text from the worker,
        unused here.
        """
        self.torch_lock.acquire()
        np_imgs = self.torch_show_data["np_imgs"]
        np_mask_imgs = self.torch_show_data["np_mask_imgs"]
        np_decoded = self.torch_show_data["np_decoded"]
        self.torch_lock.release()

        # NCHW -> NHWC so imshow gets channels last.
        np_imgs = np_imgs.transpose((0, 2, 3, 1))

        self.sw.remove(self.canvas)

        axs = [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]

        fig = Figure(figsize=(8, 8), dpi=80)

        for n in range(3):
            for i in range(self.n_test_imgs):
                axs[n][i] = fig.add_subplot(3, self.n_test_imgs,
                                            n * self.n_test_imgs + i + 1)

        for i in range(self.n_test_imgs):
            # Inputs appear normalised to [-1, 1]; rescale to [0, 1] for display.
            axs[0][i].imshow(np_imgs[i] * 0.5 + 0.5, cmap='gray')
            axs[1][i].imshow(np_mask_imgs[i][0], cmap='gray')
            axs[2][i].imshow(np_decoded[i][0], cmap='gray')

        # Swap in a freshly drawn canvas.
        self.canvas = FigureCanvas(fig)
        self.canvas.set_size_request(1000, 600)
        self.sw.add(self.canvas)
        self.sw.show_all()

    def win_quit(self, a, b):
        """Gtk delete-event handler: tell the worker to stop, then quit."""
        self.torch_lock.acquire()
        self.torch_show_data["mess_quit"] = True
        self.torch_lock.release()
        Gtk.main_quit()
Exemple #19
0
def train_AE(X, y, k, i):
    """Greedy layer-wise pre-training of a 1024-512-300 autoencoder,
    then fine-tuning and k-means clustering of the learned codes.

    :param X: training set, n * d
    :param y: training labels, n * 1
    :param k: number of clusters
    :param i: index used to build the name of the saved .mat result file
    :return: (purity, NMI) of the k-means assignment on the embedding
    """
    # Hyper-parameters shared by every AutoEncoder instance below.
    shared = dict(delta=0.1, lam2=1e-2, rate=0.001,
                  file_name=PARAMS_NAMES[JAFFE], read_params=False,
                  activation=TANH)

    print('-----------------start pre-training 1-----------------')
    pre_1 = AutoEncoder(X, y, k, [1024, 512, 1024], max_iter=5, **shared)
    pre_1.train()

    print('-----------------start pre-training 2-----------------')
    # Second stack is pre-trained on the hidden codes of the first.
    pre_2 = AutoEncoder(pre_1.H, y, k, [512, 300, 512], max_iter=5, **shared)
    pre_2.train()

    print('-----------------start training-----------------')
    ae = AutoEncoder(X, y, k, [1024, 512, 300, 512, 1024], max_iter=35,
                     decay_threshold=50, **shared)

    # Seed the deep net with the pre-trained weights: outer layers from
    # pre_1, inner layers from pre_2 (mirrored on the decoder side).
    for dst, (src_net, src_idx) in ((1, (pre_1, 1)), (4, (pre_1, 2)),
                                    (2, (pre_2, 1)), (3, (pre_2, 2))):
        ae.W[dst] = src_net.W[src_idx]
        ae.b[dst] = src_net.b[src_idx]
    ae.train()

    # Cluster the learned representation and score it.
    y_pred = KMeans(k).fit_predict(ae.H, y)
    p, mi = cal_metric2(y_pred, y, k)
    scio.savemat(DIR_NAME + 'ae_' + str(i) + '.mat',
                 {'y_predicted': y_pred, 'y': y})
    return p, mi
Exemple #20
0
def train(is_debug=False):
    """Train the intent model on ATIS, dump vocab/intent lists, then train
    and evaluate an auto-encoder stacked on top of the trained model.

    :param is_debug: when True, wrap the session in the TensorFlow CLI
        debugger and install an inf/nan tensor filter.
    """
    model = get_model()
    sess = tf.Session()
    if is_debug:
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
    sess.run(tf.global_variables_initializer())

    # `use_neg_data` selects a training file augmented with negative samples.
    if use_neg_data:
        train_data = open("dataset/atis-2.train.w-intent_with_neg.iob",
                          "r").readlines()
    else:
        train_data = open("dataset/atis-2.train.w-intent.iob", "r").readlines()
    test_data = open("dataset/atis-2.dev.w-intent.iob", "r").readlines()
    train_data_ed = data_pipeline(train_data)
    test_data_ed = data_pipeline(test_data)
    word2index, index2word, intent2index, index2intent = get_info_from_training_data(
        train_data_ed)
    index_train = to_index(train_data_ed, word2index, intent2index)
    index_test = to_index(test_data_ed, word2index, intent2index)
    print("%20s%20s%20s" % ("Epoch#", "Train Loss", "Intent Accuracy"))

    def add_to_vocab_file(fh, data):
        # Write every distinct word in `data` to `fh`, one per line.
        # BUG FIX: the original wrote to the closed-over `f_vocab_list`
        # instead of the `fh` parameter; both happened to be the same
        # object at every call site, but the parameter is the intent.
        # (`all` also shadowed the builtin; renamed to `seen`.)
        seen = set()
        bsize = 16
        for batch in getBatch(bsize, data):
            for index in range(len(batch)):
                sen_len = batch[index][1]  # words beyond sen_len are padding
                current_vocabs = index_seq2word(batch[index][0],
                                                index2word)[:sen_len]
                for w in current_vocabs:
                    if w not in seen:
                        fh.write(w + "\n")
                        seen.add(w)

    def add_to_intent_file(fh, data):
        # Write every distinct intent label in `data` to `fh`, one per line.
        # BUG FIX: same closure-vs-parameter issue as add_to_vocab_file.
        seen = set()
        bsize = 16
        for batch in getBatch(bsize, data):
            for index in range(len(batch)):
                w = index2intent[batch[index][2]]
                if w not in seen:
                    fh.write(w + "\n")
                    seen.add(w)

    f_vocab_list = open("vocab_list.in", "w")
    add_to_vocab_file(f_vocab_list, index_train)
    add_to_vocab_file(f_vocab_list, index_test)
    f_vocab_list.close()

    f_vocab_list = open("intent_list.in", "w")
    add_to_intent_file(f_vocab_list, index_train)
    add_to_intent_file(f_vocab_list, index_test)
    f_vocab_list.close()

    # --- Intent classifier training loop ---
    for epoch in range(epoch_num):
        train_loss = 0.0
        for i, batch in enumerate(getBatch(batch_size, index_train)):
            _, loss, intent, _ = model.step(sess, "train", batch)
            train_loss += loss
        train_loss /= (i + 1)  # i is the index of the last batch

        intent_accs = []
        for j, batch in enumerate(getBatch(batch_size, index_test)):
            intent, _ = model.step(sess, "test", batch)
            intent_acc = accuracy_score(list(zip(*batch))[2], intent)
            intent_accs.append(intent_acc)
        print("%20d%20f%20f" % (epoch, train_loss, np.average(intent_accs)))

    # --- Auto-encoder stacked on the trained model ---
    print("Training auto-encoder...")
    print("%20s%20s%20s%20s%20s" %
          ("Epoch#", "Train Loss", "Neg Data Loss", "Good Data Loss", "Ratio"))
    ae_model = AutoEncoder(model)
    ae_model.tf_init(sess)

    if train_ae:
        for epoch in range(epoch_num_ae):
            train_loss = 0.0
            for i, batch in enumerate(getBatch(batch_size, index_train)):
                intent, _, _, output_true, output_layer, loss, _ = ae_model.step(
                    sess, "train", batch)
                train_loss += loss
            train_loss /= (i + 1)
            result1, result2 = run_batch_test(ae_model, sess, word2index,
                                              index2intent, index_test, epoch)
            # Relative gap between neg-data and good-data loss, in percent.
            r = (result1 - result2) / result1 * 100
            print("%20d%20f%20f%20f%20f" %
                  (epoch, train_loss, result1, result2, r))
    else:
        run_batch_test(ae_model, sess, word2index, index2intent, index_test, 0)
Exemple #21
0
    if os.path.isfile(path):
        raise RuntimeError('the save path should be a dir')
    if not os.path.isdir(path):
        os.makedirs(path)
    checkpoint_path = os.path.join(path, "autoencoder.ckpt")
    model.saver.save(sess, checkpoint_path)


# Load a labelled news-title corpus and persist the run configuration.
data = load_corpus(r'D:\标注数据集\新闻标题数据集\train_label0.txt')
save_config(data, 'config.json')

with tf.Graph().as_default():
    sess = tf.Session()
    with sess.as_default():
        # embedding_size is taken from the width of the first sample.
        auto_encoder_instance = AutoEncoder(
            embedding_size=len(data[0]),
            num_hidden_layer=FLAGS.num_hidden_layer,
            hidden_layers=FLAGS.hidden_layers)
        optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
        # Explicit compute/apply split (instead of optimizer.minimize) so the
        # model's own global_step counter is advanced.
        grads_and_vars = optimizer.compute_gradients(
            auto_encoder_instance.loss)
        train_op = optimizer.apply_gradients(
            grads_and_vars, global_step=auto_encoder_instance.global_step)

        train_writer = tf.summary.FileWriter('./', sess.graph)
        sess.run(tf.global_variables_initializer())
        batch_manager = BatchManager(data, FLAGS.batch_size)
        best_loss = sys.maxsize  # track best loss across epochs
        for epoch in range(FLAGS.epochs):
            step = 0
            # NOTE(review): this excerpt is truncated mid-statement below.
            for train_batch in batch_manager.iter_batch():
                _, global_step, loss, merge_summary = sess.run(
        out = F.relu(self.lin0(x))
        h = out.unsqueeze(0)

        for i in range(num_step_message_passing):
            prev = h
            m = F.relu(self.conv(out, edge_index, edge_attr))
            h = F.relu(self.lin_h(h))
            out = F.relu(self.lin_h_m(torch.cat(
                (h.squeeze(0), m), dim=1))) + prev
            out = out.squeeze(0)
        out = F.relu(out)

        return out


# Graph auto-encoder over an edge-split dataset, sampled with
# neighbourhood loaders (torch-geometric style API).
model = AutoEncoder(node_hidden_dim, Encoder()).to(dev)

# Clear any node-level masks; splitting is done on edges instead.
data.train_mask = data.val_mask = data.test_mask = None
tr_data, val_data, ts_data = model.split_edges(data,
                                               val_ratio=val_ratio,
                                               test_ratio=test_ratio)
# 5 sampled neighbours per hop, one hop per message-passing step.
tr_loader = NeighborSampler(tr_data,
                            size=[5] * num_step_message_passing,
                            num_hops=num_step_message_passing,
                            batch_size=batch_size,
                            bipartite=False,
                            shuffle=True)
# NOTE(review): this excerpt is truncated mid-statement below.
val_loader = NeighborSampler(val_data,
                             size=[5] * num_step_message_passing,
                             num_hops=num_step_message_passing,
                             batch_size=batch_size,
Exemple #23
0
    def __init__(self):
        """Collect (image, mask) file pairs, restore the autoencoder
        checkpoint, build the Gtk preview window, and start the worker
        thread that feeds it reconstructions.
        """
        # Directory layout: imgs_mask_dir/<sub>/<name>.png is the mask for
        # imgs_dir/<sub>/<name>.png.
        imgs_dir = "./imgs_comp_box"
        imgs_mask_dir = "./imgs_mask_box"

        self.str_imgs_fns = []  # image file paths, parallel to str_mask_fns
        self.str_mask_fns = []  # mask file paths

        dirs = []

        # Collect the immediate sub-directories of the mask root.
        for parent, dirnames, filenames in os.walk(imgs_mask_dir):
            for dirname in dirnames:
                dirs.append(dirname)

        # For every mask file, derive the matching image path by name.
        for str_dir in dirs:
            str_dir_path = imgs_mask_dir + "/" + str_dir
            for parent, dirnames, filenames in os.walk(str_dir_path):
                for filename in filenames:
                    str_path = str_dir_path + "/" + filename
                    self.str_mask_fns.append(str_path)
                    idx = filename.find(".png")
                    str_img_path = imgs_dir + "/" + str_dir + "/" + filename[:
                                                                             idx] + ".png"
                    self.str_imgs_fns.append(str_img_path)

        #str_pth_fn = "./models/bpcv_encoder_06000.pth"
        str_pth_fn = "./models/bpcv_encoder_12000.pth"

        self.autoencoder = AutoEncoder()

        # Checkpoint dict carries the weights plus the loop counters
        # needed to resume where the previous run stopped.
        bpcv_dict = torch.load(str_pth_fn)
        self.autoencoder.load_state_dict(bpcv_dict["net_state"])

        print("continue: ...  n_loop: {0:0>5d}  idx_loop: {1:0>5d}".format(
            bpcv_dict["n_loop"], bpcv_dict["idx_loop"]))
        print(
            ".............................................................................."
        )

        # Build the Gtk window with a scrolled matplotlib canvas.
        self.win = Gtk.Window()
        self.win.connect("delete-event", self.win_quit)
        self.win.set_default_size(1000, 600)
        self.win.set_title("show imgs")

        self.sw = Gtk.ScrolledWindow()
        self.win.add(self.sw)
        self.sw.set_border_width(2)

        fig = Figure(figsize=(8, 8), dpi=80)
        self.canvas = FigureCanvas(fig)
        self.canvas.set_size_request(1000, 600)
        self.sw.add(self.canvas)
        self.win.show_all()

        # State shared with the worker thread; guarded by torch_lock.
        self.torch_lock = threading.Lock()
        self.torch_show_data = {}
        self.n_test_imgs = 5  # images per row in the preview grid
        self.torch_show_data["mess_quit"] = False  # worker stop flag

        thread_torch = Encoder_Thread(self.update_torch_data,
                                      self.torch_lock,
                                      self.autoencoder,
                                      self.str_imgs_fns,
                                      self.str_mask_fns,
                                      self.torch_show_data,
                                      wh=97,
                                      max_n_loop=3,
                                      n_loop=bpcv_dict["n_loop"],
                                      idx_segment=bpcv_dict["idx_loop"])

        thread_torch.start()
# Command-line driven training / feature extraction with a fully-connected
# autoencoder over grayscale images.
epochs = args.epochs
save_name = args.save_name
sample_percent = args.sample_percent
feature_output_path = args.feature_output_path
# Hidden layer sizes come in as a comma-separated string, e.g. "256,64".
layers = list(map(int, args.layers.split(',')))

X, labels, file_names = load_data(image_dir,
                                  sample_percent=sample_percent,
                                  return_names=True)
# Flatten each (length x width) image into a single input vector.
length, width = X[0].shape
input_size = length * width
X = torch.Tensor(X).view(-1, input_size).type(torch.float32)
X /= 255  # scale pixel values to [0, 1]

# Either train a fresh model or load a previously saved one.
if load_path is None:
    model = AutoEncoder([input_size] + layers)
    model.train(X=X, batch_size=batch_size, epochs=epochs, verbose=True)
else:
    model = AutoEncoder.load(load_path)

if feature_output_path is not None:
    print('Saving learned features...')
    new_features = model(X)
    new_features = new_features.detach().numpy()
    root_dir = os.getcwd()
    try:
        os.mkdir(root_dir + '\\' + feature_output_path)
    except FileExistsError:  # directory already exists — that's fine
        pass
    # NOTE(review): this excerpt is truncated mid-statement below.
    for label in np.unique(labels):
        try:
Exemple #25
0
def main():
    """Train a convolutional autoencoder on the GunPoint time-series
    dataset, evaluate reconstruction/similarity, and export the trained
    encoder, auto-encoder and decoder models.
    """
    experiment = Experiment(log_code=False)
    experiment.log_parameters(LAMBDA)
    experiment.log_parameters(hyperparams)

    dataset_name = "GunPoint"

    X_train, y_train, X_test, y_test, info = py_ts_data.load_data(
        dataset_name, variables_as_channels=True)
    print(f"Dataset shape: Train: {X_train.shape}, Test: {X_test.shape}")

    # Augment only the training split, then rescale both splits to [-1, 1].
    print(X_train.shape, y_train.shape)
    X_train, y_train = augmentation(X_train, y_train)
    print(X_train.shape, y_train.shape)
    X_train = min_max(X_train, feature_range=(-1, 1))
    X_test = min_max(X_test, feature_range=(-1, 1))

    # Architecture: three conv blocks narrowing to a 16-dim code.
    model_cfg = {
        "input_shape": (X_train.shape[1], X_train.shape[2]),
        "filters": [64, 32, 16],
        "kernel_sizes": [5, 5, 5],
        "code_size": 16,
    }

    lam = 0.99  # loss-mixing weight
    ae = AutoEncoder(**model_cfg)
    encoder = Encoder(model_cfg["input_shape"], model_cfg["code_size"],
                      model_cfg["filters"], model_cfg["kernel_sizes"])

    # Training pipeline: shuffle-then-batch tf.data dataset.
    shuffle_buffer = 100
    num_classes = len(set(y_train))  # kept for parity with the original script

    train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_dataset = train_dataset.shuffle(shuffle_buffer).batch(BATCH)

    suffix = f"lam={lam}"
    train(ae, encoder, EPOCHS, train_dataset, suffix, experiment, lam)

    # Evaluate reconstruction quality and code-space similarity.
    code_test = recon_eval(ae, X_test, suffix, experiment)
    sim_eval(X_test, code_test, suffix, experiment)

    # Export trained sub-models.
    encoder.save(
        r"C:\Users\jiang\Desktop\2270\cs227_final_project\enc_auto_643216_50_50\GunPoint\encoder"
    )
    ae.encode.save(
        r"C:\Users\jiang\Desktop\2270\cs227_final_project\enc_auto_643216_50_50\GunPoint\auto_encoder"
    )
    ae.decode.save(
        r"C:\Users\jiang\Desktop\2270\cs227_final_project\enc_auto_643216_50_50\GunPoint\decoder"
    )
    # NOTE(review): ae.encode is passed twice here — looks suspicious but is
    # preserved verbatim; confirm against sample_evaluation's signature.
    sample_evaluation(ae.encode,
                      ae.encode,
                      ae.decode,
                      experiment,
                      suffix,
                      DATA=dataset_name)