def predict(model_file, data_file, image_prefix):
    """Render generator outputs for every (22, 16) sample in ``data_file``.

    Builds a generator (input 16x22, output 200x128), loads ``model_file``
    weights, z-scores each sample along axis 1, feeds the flattened samples
    through the generator and writes one PNG per output image to
    ``image_prefix + <index> + ".png"``.

    Args:
        model_file: Weights path passed to ``generator.load_weights``.
        data_file: ``.npy`` file holding a sequence of 2-D samples; samples
            whose shape is not (22, 16) are skipped.
        image_prefix: Path prefix for the written PNG files.

    Raises:
        ValueError: If ``data_file`` holds no sample of shape (22, 16).
    """
    generator = make_generator_model(input_width=16,
                                     input_height=22,
                                     output_width=200,
                                     output_height=128)
    generator.load_weights(model_file)

    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # trusted files here.
    raw = np.load(data_file, allow_pickle=True)
    samples = [sample for sample in raw if sample.shape == (22, 16)]
    if not samples:
        raise ValueError(
            "no samples of shape (22, 16) found in {}".format(data_file))
    input_array = np.stack(samples, axis=0)

    # Kept from the original. The shift does not affect the z-score below
    # (the mean subtraction cancels it), so a constant row still yields
    # std == 0 and a NaN/inf result.
    input_array = input_array + 0.01
    mean = input_array.mean(axis=1, keepdims=True)
    std = input_array.std(axis=1, keepdims=True)
    input_array = (input_array - mean) / std

    # Flatten each (rows, cols) sample into one vector for the generator.
    input_tensor = tf.constant(input_array)
    n_sample, rows, cols = input_tensor.shape
    input_tensor = tf.reshape(input_tensor, shape=(n_sample, rows * cols))
    generated_pic = np.array(generator(input_tensor, training=False))
    # Generator output is presumably in [-1, 1]; map it to integer 0..255.
    generated_pic = np.round(generated_pic * 127.5 + 127.5, 0).astype(np.int32)

    for idx, pic in enumerate(generated_pic):
        # One figure per image, closed after saving. Reusing the implicit
        # current figure stacked every image onto the same axes and leaked
        # memory.
        fig = plt.figure()
        plt.imshow(np.flipud(np.rot90(pic)))
        plt.savefig(image_prefix + str(idx) + ".png")
        plt.close(fig)
# Example #2
0
def run(train_data, valid_data, len_size, scale, EPOCHS, root_path='./', load_model_dir=None, saved_model_dir=None, log_dir=None, summary=False):
    """Build (optionally resume) the generator/discriminator, train, and save weights.

    Args:
        train_data: Training dataset, forwarded to ``model.fit``.
        valid_data: Validation dataset, forwarded to ``model.fit``.
        len_size: Block size; selects the model variant and weight sub-directory.
        scale: Resolution scale forwarded to both model constructors.
        EPOCHS: Number of training epochs.
        root_path: Base directory for the default log and save locations.
        load_model_dir: If given, load ``gen_model_<len_size>/gen_weights`` and
            ``dis_model_<len_size>/dis_weights`` from here when they exist.
        saved_model_dir: Where the final weights are written; defaults to
            ``<root_path>/our_model/saved_model``.
        log_dir: Log directory; defaults to ``<root_path>/our_model/logs/model``.
        summary: If True, log both model summaries and write G.png / D.png.
    """
    def _weights_exist(prefix):
        # TF-format ``save_weights`` writes '<prefix>.index' plus data shards,
        # not a file named exactly '<prefix>'. An HDF5 path is a single file.
        # Accept either form; checking only the bare path never matched
        # TF-format weights, so they were never reloaded.
        return os.path.exists(prefix) or os.path.exists(prefix + '.index')

    def _log_summary(net):
        # ``summary()`` prints and returns None, so ``logging.info(net.summary())``
        # only logged "None". Route each summary line into the logger instead.
        net.summary(print_fn=lambda line, *args, **kwargs: logging.info(line))

    if log_dir is None:
        log_dir = os.path.join(root_path, 'our_model', 'logs', 'model')
    logging.info(train_data)
    logging.info(valid_data)

    # get generator model and discriminator model
    Gen = model.make_generator_model(len_high_size=len_size, scale=scale)
    Dis = model.make_discriminator_model(len_high_size=len_size, scale=scale)

    if load_model_dir is not None:
        file_path = os.path.join(load_model_dir, 'gen_model_' + str(len_size), 'gen_weights')
        if _weights_exist(file_path):
            Gen.load_weights(file_path)
        else:
            logging.info("generator doesn't exist. create a new one.")
        file_path = os.path.join(load_model_dir, 'dis_model_' + str(len_size), 'dis_weights')
        if _weights_exist(file_path):
            Dis.load_weights(file_path)
        else:
            logging.info("discriminator model doesn't exist. create a new one")

    if summary:
        _log_summary(Gen)
        tf.keras.utils.plot_model(Gen, to_file='G.png', show_shapes=True)
        _log_summary(Dis)
        tf.keras.utils.plot_model(Dis, to_file='D.png', show_shapes=True)

    if saved_model_dir is None:
        saved_model_dir = os.path.join(root_path, 'our_model', 'saved_model')

    model.fit(Gen, Dis, train_data, EPOCHS, len_size, scale,
              valid_data, log_dir=log_dir, saved_model_dir=saved_model_dir)

    # Final weights go to the same layout that ``load_model_dir`` reads from.
    file_path = os.path.join(
        saved_model_dir, 'gen_model_' + str(len_size), 'gen_weights')
    Gen.save_weights(file_path)

    file_path = os.path.join(
        saved_model_dir, 'dis_model_' + str(len_size), 'dis_weights')
    Dis.save_weights(file_path)
def main(input_array, input_width=16, input_height=22, output_width=100, output_height=64):
    """Train a DCGAN-style generator/discriminator pair on ``input_array``.

    ``input_array`` is rescaled with ``(x - 127.5) / 127.5``, which maps
    0..255 pixel data to [-1, 1]. The rescaled data is shuffled, batched and
    trained for ``EPOCHS`` epochs.

    Relies on module-level names defined elsewhere: ``BUFFER_SIZE``,
    ``BATCH_SIZE``, ``EPOCHS``, ``noise_dim``, ``num_examples_to_generate``,
    ``generator_loss``, ``discriminator_loss``, ``make_generator_model`` and
    ``make_discriminator_model``.

    Side effects: writes ``image_at_epoch_XXXX.png`` every epoch. Every 20
    epochs it also writes a checkpoint under ``./training_checkpoints`` and a
    ``model_epoch_<epoch>.h5`` weights file.
    """
    input_array = (input_array - 127.5) / 127.5
    train_dataset = tf.data.Dataset.from_tensor_slices(input_array).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

    generator = make_generator_model(input_width, input_height, output_width, output_height)
    discriminator = make_discriminator_model(output_width, output_height)

    generator_optimizer = tf.keras.optimizers.Adam(1e-4)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                     discriminator_optimizer=discriminator_optimizer,
                                     generator=generator,
                                     discriminator=discriminator)

    # Fixed latent vectors so the per-epoch preview images are comparable.
    seed = tf.random.normal([num_examples_to_generate, noise_dim])

    @tf.function
    def train_step(images):
        """Run one simultaneous generator/discriminator update on ``images``."""
        noise = tf.random.normal([BATCH_SIZE, noise_dim])

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = generator(noise, training=True)

            real_output = discriminator(images, training=True)
            fake_output = discriminator(generated_images, training=True)

            gen_loss = generator_loss(fake_output)
            disc_loss = discriminator_loss(real_output, fake_output)

        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

        generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    def generate_and_save_images(model, epoch, test_input):
        """Save (and show) a 4x4 grid of ``model`` outputs for ``test_input``."""
        # training=False runs every layer in inference mode (e.g. batch norm).
        predictions = model(test_input, training=False)

        fig = plt.figure(figsize=(4, 4))

        # NOTE: the 4x4 grid holds at most 16 samples.
        for i in range(predictions.shape[0]):
            plt.subplot(4, 4, i + 1)
            # Output is presumably in [-1, 1] (training data was scaled to that
            # range). imshow clips float RGB data to [0, 1], so the former
            # "* 127.5 + 127.5" (0..255 floats) rendered almost pure white.
            # Map to [0, 1] instead.
            image = tf.clip_by_value(predictions[i, :, :, :3] * 0.5 + 0.5, 0.0, 1.0)
            plt.imshow(image.numpy())
            plt.axis('off')

        plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
        plt.show()
        # Release the figure; otherwise one figure per epoch accumulates.
        plt.close(fig)

    def train(dataset, epochs):
        """Train for ``epochs`` epochs, previewing every epoch and checkpointing every 20."""
        for epoch in range(epochs):
            start = time.time()

            for image_batch in dataset:
                train_step(image_batch)

            # Produce images for the GIF as we go
            generate_and_save_images(generator,
                                     epoch + 1,
                                     seed)

            # Save the model every 20 epochs.
            if (epoch + 1) % 20 == 0:
                checkpoint.save(file_prefix=checkpoint_prefix)
                # NOTE: the .h5 name uses the 0-based epoch index
                # (e.g. "model_epoch_19.h5" after the 20th epoch).
                generator.save_weights("model_epoch_" + str(epoch) + ".h5")

            print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

        # Generate after the final epoch
        generate_and_save_images(generator,
                                 epochs,
                                 seed)

    train(train_dataset, EPOCHS)
# Example #4
0
import tensorflow as tf
import model
import matplotlib.pyplot as plt
import datetime
import os
import time
import numpy as np

from define import *

# build model
# Instantiate the generator through the project's ``model`` module using its
# default arguments.
generator = model.make_generator_model()

# checkpoint
# Restore the generator from the most recent checkpoint in ./training_checkpoints.
# NOTE(review): checkpoint_prefix is not used in the visible code. If the
# directory holds no checkpoint, tf.train.latest_checkpoint presumably returns
# None and the generator keeps its freshly initialised weights — confirm that
# is acceptable rather than an error.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator=generator)

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

# make seed
# Load pre-selected latent vectors from ./cherry_pick.npy.
# NOTE(review): shape presumably (num_seeds, noise_dim) — confirm.
seed = np.load("./cherry_pick.npy")

# Summary writer under logs/<YYYYmmdd-HHMMSS>/vector_arithmetic/.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
test_summary_writer = tf.summary.create_file_writer(
    f'logs/{current_time}/vector_arithmetic/')

# NOTE(review): the body of this loop is not visible in this chunk; only the
# commented-out random-seed lines below appear here.
for i in range(10):
    # seed1 = tf.random.normal([noise_dim])
    # seed2 = tf.random.normal([noise_dim])
    # seed3 = tf.random.normal([noise_dim])
# Example #5
0
def predict(path='./data',
            raw_path='raw',
            raw_file='Rao2014-GM12878-DpnII-allreps-filtered.10kb.cool',
            model_path=None,
            sr_path='output',
            chromosome='22',
            scale=4,
            len_size=200,
            genomic_distance=2000000,
            start=None,
            end=None,
            draw_out=False):
    """Super-resolve one chromosome of a ``.cool`` Hi-C file and save the results.

    Steps:
    1. Read the balanced matrix for ``chr<chromosome>`` from
       ``path/raw_path/raw_file`` and remove zero rows/columns. Crop the
       result to ``[start:end]`` to get ``Mh``.
    2. Build a low-resolution input ``Ml`` with
       ``operations.sampling_hic(Mh, scale**2)``. Normalize both matrices.
    3. Cut each matrix into ``len_size`` pieces twice: a "front" pass that
       drops the trailing ``residual`` bins and an "offset" pass that drops
       the leading ones.
    4. Run the generator on both passes, merge the pieces, and average them
       where the passes overlap.

    Args:
        path: Root data directory.
        raw_path: Sub-directory of ``path`` that holds ``raw_file``.
        raw_file: ``.cool`` file name. Its first three '-'-separated fields and
            its second '.'-separated field form the output sub-directory name.
        model_path: Generator weights; defaults to
            ``./our_model/saved_model/gen_model_<len_size>/gen_weights``.
        sr_path: Output sub-directory under ``path``.
        chromosome: Chromosome name without the ``chr`` prefix.
        scale: Resolution scale used to build the generator. The LR input is
            sampled with ``scale**2``.
        len_size: Piece size in bins.
        genomic_distance: Maximum distance from the diagonal in bp, or None
            for no limit.
        start, end: Bin range (after zero removal) to process; default is all.
        draw_out: If True, show predicted vs. true matrices side by side.

    Returns:
        None. Returns early without writing results when the chromosome is
        absent from the file (the output directory and weights are set up
        before that check).
    """
    def _trim_tail(mat, r):
        # Drop the last ``r`` rows/columns. Unlike ``mat[0:-r, 0:-r]`` this is
        # a no-op for r == 0 instead of producing an empty matrix.
        return mat[:mat.shape[0] - r, :mat.shape[1] - r]

    def _split(mat):
        # Cut ``mat`` into len_size pieces within max_boundary of the diagonal.
        return operations.divide_pieces_hic(mat,
                                            block_size=len_size,
                                            max_distance=max_boundary,
                                            save_file=False)

    def _generate(lr_pieces):
        # Run the generator over LR pieces in batches of 9. Keep its third
        # output and drop the trailing channel axis.
        hic_lr_ds = tf.data.Dataset.from_tensor_slices(
            lr_pieces[..., np.newaxis]).batch(9)
        outputs = []
        for input_data in hic_lr_ds:
            [_, _, tmp, _, _] = Generator(input_data, training=False)
            outputs.append(tmp.numpy())
        return np.squeeze(np.concatenate(outputs, axis=0), axis=3)

    parts = raw_file.split('-')
    sr_file = '_'.join([parts[0], parts[1], parts[2], raw_file.split('.')[1]])
    sr_root = os.path.join(path, sr_path, sr_file, 'SR')
    directory_sr = os.path.join(sr_root, 'chr' + chromosome)
    os.makedirs(directory_sr, exist_ok=True)

    # get generator model
    if model_path is None:
        gan_model_weights_path = './our_model/saved_model/gen_model_' + \
            str(len_size) + '/gen_weights'
    else:
        gan_model_weights_path = model_path
    # Use the requested scale. It was hard-coded to 4, which silently
    # mismatched any non-default ``scale``.
    Generator = model.make_generator_model(len_high_size=len_size, scale=scale)
    Generator.load_weights(gan_model_weights_path)
    print(Generator)

    name = os.path.join(path, raw_path, raw_file)
    c = cooler.Cooler(name)
    resolution = c.binsize
    if 'chr' + chromosome not in c.chromnames:
        return
    mat = c.matrix(balance=True).fetch('chr' + chromosome)

    [Mh, idx] = operations.remove_zeros(mat)
    print('Shape HR: {}'.format(Mh.shape), end='\t')

    if start is None:
        start = 0
    if end is None:
        end = Mh.shape[0]

    Mh = Mh[start:end, start:end]
    print('MH: {}'.format(Mh.shape), end='\t')

    Ml = operations.sampling_hic(Mh, scale**2, fix_seed=True)
    print('ML: {}'.format(Ml.shape))

    # Normalization; the input must not be of type np.matrix.
    Ml = np.asarray(Ml)
    Mh = np.asarray(Mh)
    Ml, Dl = operations.scn_normalization(Ml, max_iter=3000)
    print('Dl shape:{}'.format(Dl.shape))
    Mh, Dh = operations.scn_normalization(Mh, max_iter=3000)
    print('Dh shape:{}'.format(Dh.shape))

    if genomic_distance is None:
        max_boundary = None
    else:
        max_boundary = np.ceil(genomic_distance / (resolution))
    # Remainder of the matrix size modulo half a piece. The front pass drops
    # this many bins from the end, the offset pass from the start.
    residual = Mh.shape[0] % int(len_size / 2)
    print('residual: {}'.format(residual))

    hic_hr_front, index_1d_2d_front, index_2d_1d_front = _split(
        _trim_tail(Mh, residual))
    hic_hr_front = np.asarray(hic_hr_front, dtype=np.float32)
    print('shape hic_hr front: ', hic_hr_front.shape)
    true_hic_hr_front = hic_hr_front
    print('shape true hic_hr: ', true_hic_hr_front.shape)

    hic_hr_offset, index_1d_2d_offset, index_2d_1d_offset = _split(
        Mh[residual:, residual:])
    hic_hr_offset = np.asarray(hic_hr_offset, dtype=np.float32)
    print('shape hic_hr offset: ', hic_hr_offset.shape)
    true_hic_hr_offset = hic_hr_offset
    print('shape true hic_hr: ', true_hic_hr_offset.shape)

    # Front pass: predict from the LR matrix without its trailing residual bins.
    hic_lr_front, _, _ = _split(_trim_tail(Ml, residual))
    hic_lr_front = np.asarray(hic_lr_front, dtype=np.float32)
    print('shape hic_lr: ', hic_lr_front.shape)
    predict_hic_hr_front = _generate(hic_lr_front)
    print('Shape of prediction front: ', predict_hic_hr_front.shape)

    file_path = os.path.join(directory_sr, sr_file + '_chr' + chromosome)
    np.savez_compressed(file_path + '_front.npz',
                        predict_hic=predict_hic_hr_front,
                        true_hic=true_hic_hr_front,
                        index_1D_2D=index_1d_2d_front,
                        index_2D_1D=index_2d_1d_front,
                        start_id=start,
                        end_id=end,
                        residual=0)

    predict_hic_hr_merge_front = operations.merge_hic(
        predict_hic_hr_front,
        index_1D_2D=index_1d_2d_front,
        max_distance=max_boundary)
    print('Shape of merge predict hic HR front',
          predict_hic_hr_merge_front.shape)

    # Offset pass: predict from the LR matrix without its leading residual bins.
    hic_lr_offset, _, _ = _split(Ml[residual:, residual:])
    hic_lr_offset = np.asarray(hic_lr_offset, dtype=np.float32)
    print('Shape hic_lr_offset: ', hic_lr_offset.shape)
    predict_hic_hr_offset = _generate(hic_lr_offset)
    print('Shape of prediction offset: ', predict_hic_hr_offset.shape)

    np.savez_compressed(file_path + '_offset.npz',
                        predict_hic=predict_hic_hr_offset,
                        true_hic=true_hic_hr_offset,
                        index_1D_2D=index_1d_2d_offset,
                        index_2D_1D=index_2d_1d_offset,
                        start_id=start,
                        end_id=end,
                        residual=residual)
    predict_hic_hr_merge_offset = operations.merge_hic(
        predict_hic_hr_offset,
        index_1D_2D=index_1d_2d_offset,
        max_distance=max_boundary)
    print('Shape of merge predict hic hr offset: ',
          predict_hic_hr_merge_offset.shape)

    # Sum both passes into a full-size matrix, then divide by how many passes
    # cover each cell. ``ave`` is 1 everywhere, with a block of ones of size
    # (n - 2*residual) added at (residual, residual). This presumes addAtPos
    # adds its second argument into the first at that offset.
    predict_hic_hr_merge = np.zeros(Mh.shape)
    predict_hic_hr_merge = addAtPos(predict_hic_hr_merge,
                                    predict_hic_hr_merge_front, (0, 0))
    predict_hic_hr_merge = addAtPos(predict_hic_hr_merge,
                                    predict_hic_hr_merge_offset,
                                    (residual, residual))

    ave = np.ones_like(predict_hic_hr_merge)
    twice = np.ones(shape=(Mh.shape[0] - 2 * residual,
                           Mh.shape[1] - 2 * residual))
    ave = addAtPos(ave, twice, (residual, residual))
    predict_hic_hr_merge = predict_hic_hr_merge / ave

    true_hic_hr_merge_front = operations.merge_hic(
        true_hic_hr_front,
        index_1D_2D=index_1d_2d_front,
        max_distance=max_boundary)
    true_hic_hr_merge_offset = operations.merge_hic(
        true_hic_hr_offset,
        index_1D_2D=index_1d_2d_offset,
        max_distance=max_boundary)
    true_hic_hr_merge = np.zeros(Mh.shape)
    true_hic_hr_merge = addAtPos(true_hic_hr_merge, true_hic_hr_merge_front,
                                 (0, 0))
    true_hic_hr_merge = addAtPos(true_hic_hr_merge, true_hic_hr_merge_offset,
                                 (residual, residual))
    true_hic_hr_merge = true_hic_hr_merge / ave
    print('Shape of true merge hic hr: {}'.format(true_hic_hr_merge.shape))

    # recover M from scn to origin
    Mh = operations.scn_recover(Mh, Dh)
    true_hic_hr_merge = operations.scn_recover(true_hic_hr_merge, Dh)
    predict_hic_hr_merge = operations.scn_recover(predict_hic_hr_merge, Dh)

    # Remove the diagonal and entries beyond max_boundary. ``max_boundary`` is
    # None when genomic_distance is None; the former unconditional
    # ``max_boundary.astype(int)`` crashed in that case.
    k = None if max_boundary is None else max_boundary.astype(int)
    Mh = operations.filter_diag_boundary(Mh, diag_k=0, boundary_k=k)
    true_hic_hr_merge = operations.filter_diag_boundary(true_hic_hr_merge,
                                                        diag_k=0,
                                                        boundary_k=k)
    predict_hic_hr_merge = operations.filter_diag_boundary(
        predict_hic_hr_merge, diag_k=0, boundary_k=k)

    print('sum Mh:', np.sum(np.abs(Mh)))
    print('sum true merge:', np.sum(np.abs(true_hic_hr_merge)))
    print('sum pred merge:', np.sum(np.abs(predict_hic_hr_merge)))
    diff = np.abs(Mh - predict_hic_hr_merge)
    print('sum Mh - pred square error: {:.5}'.format(np.sum(diff**2)))
    diff = np.abs(true_hic_hr_merge - predict_hic_hr_merge)
    print('sum true merge - pred square error: {:.5}'.format(np.sum(diff**2)))
    diff = np.abs(Mh - true_hic_hr_merge)
    print('sum Mh - true merge square error: {:.5}'.format(np.sum(diff**2)))

    # NOTE(review): ``idx`` comes from remove_zeros on the whole chromosome,
    # yet it is trimmed by ``residual`` (not by start/end) — confirm the
    # intended alignment. Slicing by length keeps every entry when
    # residual == 0; ``idx[0:-0]`` was empty.
    compact = idx[:len(idx) - residual]
    # NOTE(review): the prediction is written to SR/ but the ground truth to
    # SR/chr<chromosome>/ — kept as-is; confirm this is intended.
    file_name = 'predict_chr{}_{}.npz'.format(chromosome, resolution)
    np.savez_compressed(os.path.join(sr_root, file_name),
                        hic=predict_hic_hr_merge,
                        compact=compact)
    print('Saving file: {}, at {}'.format(file_name, sr_root))
    file_name = 'true_chr{}_{}.npz'.format(chromosome, resolution)
    np.savez_compressed(os.path.join(directory_sr, file_name),
                        hic=Mh,
                        compact=compact)
    print('Saving file: {}, at {}'.format(file_name, directory_sr))

    if draw_out:
        # Subsample every 10th bin to keep the plot light.
        predict_hic_hr_merge = predict_hic_hr_merge[::10, ::10]
        Mh = Mh[::10, ::10]
        fig, axs = plt.subplots(1, 2, figsize=(8, 15))
        axs[0].imshow(np.log1p(predict_hic_hr_merge), cmap='RdBu')
        axs[0].set_title('predict')
        axs[1].imshow(np.log1p(Mh), cmap='RdBu')
        axs[1].set_title('true')
        plt.tight_layout()
        plt.show()
# Example #6
0
def extract_features(path='./data',
                     raw_path='raw',
                     raw_file='Rao2014-GM12878-DpnII-allreps-filtered.10kb.cool',
                     model_path=None,
                     sr_path='output',
                     chromosome='22',
                     scale=4,
                     len_size=200,
                     genomic_distance=2000000,
                     start=None, end=None):
    """Run the generator on one chromosome and plot its intermediate feature maps.

    The preprocessing matches the Hi-C ``predict`` pipeline but uses a single
    pass with no residual trimming. The function saves prediction and truth
    ``.npz`` files and grids of the ``dsd_x2`` / ``dsd_x4`` layer activations.
    It also plots the predicted, true and low-resolution matrices into
    ``path/sr_path/<sr_file>/extract_features``.

    Args:
        path: Root data directory.
        raw_path: Sub-directory of ``path`` that holds ``raw_file``.
        raw_file: ``.cool`` file name. Its first three '-'-separated fields and
            its second '.'-separated field form the output sub-directory name.
        model_path: Generator weights; defaults to
            ``./our_model/saved_model/gen_model_<len_size>/gen_weights``.
        sr_path: Output sub-directory under ``path``.
        chromosome: Chromosome name without the ``chr`` prefix.
        scale: Resolution scale used to build the generator. The LR input is
            sampled with ``scale**2``.
        len_size: Piece size in bins.
        genomic_distance: Maximum distance from the diagonal in bp, or None.
        start, end: Bin range (after zero removal) to process; default is all.
    """
    def _layer_model(layer_name):
        # Sub-model mapping the named layer's input to its output.
        # NOTE(review): it is fed the generator's LR input batch — confirm the
        # layer input shapes line up.
        layer = Generator.get_layer(layer_name)
        return keras.Model(inputs=layer.input, outputs=layer.output)

    def _save_feature_grid(features, nr, nc, first, step, out_name, show_shape=False):
        # Min-max normalize ``features`` (batch of one) and rank its channels by
        # total activation, largest first. Plot channels first, first+step, ...
        # on an nr x nc grid and save the grid as a JPEG.
        fig, axs = plt.subplots(nrows=nr, ncols=nc, figsize=(25, 20))
        interm = np.squeeze(features.numpy(), axis=0)
        interm = (interm - interm.min()) / (interm.max() - interm.min())
        order = np.sum(interm, axis=(0, 1)).argsort()[::-1]
        interm = interm[:, :, order]
        if show_shape:
            print(interm.shape)
        pcm = None
        for cell in range(nr * nc):
            channel = first + cell * step
            # ``>=``: the former ``>`` let channel == shape[2] through and
            # raised IndexError.
            if channel >= interm.shape[2]:
                continue
            row, col = divmod(cell, nc)
            pcm = axs[row, col].imshow(np.log1p(np.squeeze(interm[:, :, channel])), cmap='OrRd')
        plt.tight_layout()
        if pcm is not None:
            fig.colorbar(pcm, ax=axs, shrink=0.3)
        plt.savefig(os.path.join(directory_sr, out_name), format='jpg')
        plt.close(fig)

    parts = raw_file.split('-')
    sr_file = '_'.join([parts[0], parts[1], parts[2], raw_file.split('.')[1]])
    directory_sr = os.path.join(path, sr_path, sr_file, 'extract_features')
    os.makedirs(directory_sr, exist_ok=True)

    # get generator model
    if model_path is None:
        gan_model_weights_path = './our_model/saved_model/gen_model_' + \
            str(len_size) + '/gen_weights'
    else:
        gan_model_weights_path = model_path
    # Use the requested scale instead of a hard-coded 4.
    Generator = model.make_generator_model(len_high_size=len_size, scale=scale)
    Generator.load_weights(gan_model_weights_path)
    print(Generator)

    name = os.path.join(path, raw_path, raw_file)
    c = cooler.Cooler(name)
    resolution = c.binsize
    mat = c.matrix(balance=True).fetch('chr' + chromosome)

    [Mh, idx] = operations.remove_zeros(mat)

    nonzero_idx = np.array(np.where(idx)).flatten()
    print(idx.shape)
    print(nonzero_idx.shape)
    print(nonzero_idx)
    print(start, end, nonzero_idx[start:end])
    print('Shape HR: {}'.format(Mh.shape), end='\t')

    if start is None:
        start = 0
    if end is None:
        end = Mh.shape[0]

    Mh = Mh[start:end, start:end]
    print('MH: {}'.format(Mh.shape), end='\t')

    Ml = operations.sampling_hic(Mh, scale**2, fix_seed=True)
    print('ML: {}'.format(Ml.shape))

    # Normalization; the input must not be of type np.matrix.
    Ml = np.asarray(Ml)
    Mh = np.asarray(Mh)
    Ml, Dl = operations.scn_normalization(Ml, max_iter=3000)
    print('Dl shape:{}'.format(Dl.shape))
    Mh, Dh = operations.scn_normalization(Mh, max_iter=3000)
    print('Dh shape:{}'.format(Dh.shape))

    if genomic_distance is None:
        max_boundary = None
    else:
        max_boundary = np.ceil(genomic_distance / (resolution))

    hic_hr, index_1d_2d, index_2d_1d = operations.divide_pieces_hic(Mh, block_size=len_size, max_distance=max_boundary, save_file=False)
    hic_hr = np.asarray(hic_hr, dtype=np.float32)
    print('shape hic_hr front: ', hic_hr.shape)
    true_hic_hr = hic_hr
    print('shape true hic_hr: ', true_hic_hr.shape)

    hic_lr, _, _ = operations.divide_pieces_hic(Ml, block_size=len_size, max_distance=max_boundary, save_file=False)
    hic_lr = np.asarray(hic_lr, dtype=np.float32)
    print('shape hic_lr: ', hic_lr.shape)
    hic_lr_ds = tf.data.Dataset.from_tensor_slices(hic_lr[..., np.newaxis]).batch(9)

    # The multi-scale outputs (out_low_x2, ...) and the last input batch are
    # kept from the final batch only; they feed the plots below.
    predictions = []
    last_batch = None
    for input_data in hic_lr_ds:
        [out_low_x2, out_low_x4, tmp, low_x2, low_x4] = Generator(input_data, training=False)
        predictions.append(tmp.numpy())
        last_batch = input_data

    # Only the last batch's activations were ever used, so build each
    # sub-model once and apply it to that batch.
    intermediate_x2 = _layer_model('dsd_x2')(last_batch)
    intermediate_x4 = _layer_model('dsd_x4')(last_batch)

    predict_hic_hr = np.squeeze(np.concatenate(predictions, axis=0), axis=3)
    print('Shape of prediction front: ', predict_hic_hr.shape)

    file_path = os.path.join(directory_sr, sr_file + '_chr' + chromosome)
    np.savez_compressed(file_path + '.npz', predict_hic=predict_hic_hr, true_hic=true_hic_hr,
                        index_1D_2D=index_1d_2d, index_2D_1D=index_2d_1d,
                        start_id=start, end_id=end, residual=0)

    predict_hic_hr_merge = operations.merge_hic(predict_hic_hr, index_1D_2D=index_1d_2d, max_distance=max_boundary)
    print('Shape of merge predict hic HR', predict_hic_hr_merge.shape)

    true_hic_hr_merge = operations.merge_hic(true_hic_hr, index_1D_2D=index_1d_2d, max_distance=max_boundary)
    print('Shape of merge true hic HR: {}'.format(true_hic_hr_merge.shape))

    # recover M from scn to origin
    Mh = operations.scn_recover(Mh, Dh)
    true_hic_hr_merge = operations.scn_recover(true_hic_hr_merge, Dh)
    predict_hic_hr_merge = operations.scn_recover(predict_hic_hr_merge, Dh)

    # Remove diagonal (diag_k=1) only. No boundary cut is applied here, so
    # the unused ``max_boundary.astype(int)`` was dropped; it crashed when
    # genomic_distance is None.
    Mh = operations.filter_diag_boundary(Mh, diag_k=1, boundary_k=None)
    true_hic_hr_merge = operations.filter_diag_boundary(true_hic_hr_merge, diag_k=1, boundary_k=None)
    predict_hic_hr_merge = operations.filter_diag_boundary(predict_hic_hr_merge, diag_k=1, boundary_k=None)

    print('sum Mh:', np.sum(np.abs(Mh)))
    print('sum true merge:', np.sum(np.abs(true_hic_hr_merge)))
    print('sum pred merge:', np.sum(np.abs(predict_hic_hr_merge)))
    diff = np.abs(Mh - predict_hic_hr_merge)
    print('sum Mh - pred square error: {:.5}'.format(np.sum(diff**2)))
    diff = np.abs(true_hic_hr_merge - predict_hic_hr_merge)
    print('sum true merge - pred square error: {:.5}'.format(np.sum(diff**2)))
    diff = np.abs(Mh - true_hic_hr_merge)
    print('sum Mh - true merge square error: {:.5}'.format(np.sum(diff**2)))

    compact = idx
    file_name = 'predict_chr{}_{}.npz'.format(chromosome, resolution)
    np.savez_compressed(os.path.join(directory_sr, file_name), hic=predict_hic_hr_merge, compact=compact)
    print('Saving file: {}, at {}'.format(file_name, directory_sr))
    file_name = 'true_chr{}_{}.npz'.format(chromosome, resolution)
    np.savez_compressed(os.path.join(directory_sr, file_name), hic=Mh, compact=compact)
    print('Saving file: {}, at {}'.format(file_name, directory_sr))

    # Feature grids: 48 channels of dsd_x2 starting at rank 40 with stride 2,
    # and 35 channels of dsd_x4 starting at rank 10 with stride 1.
    _save_feature_grid(intermediate_x2, 6, 8, 40, 2,
                       'features_x2_chr{}_{}_{}.jpg'.format(chromosome, start, end),
                       show_shape=True)
    _save_feature_grid(intermediate_x4, 5, 7, 10, 1,
                       'features_x4_chr{}_{}_{}.jpg'.format(chromosome, start, end))

    name = 'prediction_chr{}_{}_{}'.format(chromosome, start, end)
    plot_hic_matrix(predict_hic_hr_merge, directory_sr, name, title='Prediction Hi-C (10kb)')

    name = 'prediction_x2_chr{}_{}_{}'.format(chromosome, start, end)
    plot_hic_matrix(out_low_x2, directory_sr, name, 'Prediction Hi-C (20kb)')

    name = 'prediction_x4_chr{}_{}_{}'.format(chromosome, start, end)
    plot_hic_matrix(out_low_x4, directory_sr, name, 'Prediction Hi-C (40kb)')

    name = 'HR_chr{}_{}_{}'.format(chromosome, start, end)
    plot_hic_matrix(Mh, directory_sr, name, 'True Hi-C (10kb)')

    # The x2 level is 20kb, matching the 'prediction_x2' title; it was
    # mislabelled 40kb.
    name = 'true_x2_chr{}_{}_{}'.format(chromosome, start, end)
    plot_hic_matrix(low_x2, directory_sr, name, 'True Hi-C (20kb, by average pooling)')

    name = 'true_x4_chr{}_{}_{}'.format(chromosome, start, end)
    plot_hic_matrix(low_x4, directory_sr, name, 'True Hi-C (40kb, by average pooling)')

    name = 'LR_chr{}_{}_{}'.format(chromosome, start, end)
    plot_hic_matrix(Ml, directory_sr, name, title='True Hi-C (40kb, x16 downsampling)')