Example #1
    def __init__(self, config, rng):
        self.config = config
        self.rng = rng
        self.task = config.task
        self.model_dir = config.model_dir
        self.gpu_memory_fraction = config.gpu_memory_fraction
        self.checkpoint_secs = config.checkpoint_secs
        self.log_step = config.log_step
        self.num_epoch = config.num_epochs

        ## import data Loader ##
        data_dir = config.data_dir
        dataset_name = config.task
        batch_size = config.batch_size
        num_time_steps = config.num_time_steps
        num_node = config.num_node
        self.data_loader = BatchLoader(data_dir, dataset_name, batch_size,
                                       num_time_steps, num_node)

        ## Need to think about how we construct the adjacency matrix (W)
        W = self.data_loader.adj
        laplacian = W / W.max()
        laplacian = scipy.sparse.csr_matrix(laplacian, dtype=np.float32)
        lmax = graph.lmax(laplacian)

        #idx2char = batchLoader_.idx2char
        #char2idx = batchLoader_.char2idx
        #batch_x, batch_y = batchLoader_.next_batch(0) 0:train 1:valid 2:test
        #batchLoader_.reset_batch_pointer(0)

        ## define model ##
        self.model = Model(config, laplacian, lmax)

        ## model saver / summary writer ##
        self.saver = tf.train.Saver()
        self.model_saver = tf.train.Saver(self.model.model_vars)
        self.summary_train_writer = tf.summary.FileWriter(self.model_dir +
                                                          '/train')
        self.summary_test_writer = tf.summary.FileWriter(self.model_dir +
                                                         '/test')

        sv = tf.train.Supervisor(logdir=self.model_dir,
                                 is_chief=True,
                                 saver=self.saver,
                                 summary_op=None,
                                 summary_writer=self.summary_train_writer,
                                 save_summaries_secs=300,
                                 save_model_secs=self.checkpoint_secs,
                                 global_step=self.model.model_step)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.gpu_memory_fraction,
            allow_growth=True)  # seems not to work
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)

        self.sess = sv.prepare_or_wait_for_session(config=sess_config)
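
A note on the Laplacian preparation above (and in Examples #2 and #7, where `graph.lmax` is flagged "purpose unknown"): given that `Model` receives both the Laplacian and `lmax`, the helper almost certainly returns the largest eigenvalue of the Laplacian, which Chebyshev-style graph convolutions use to rescale eigenvalues into [-1, 1]. A minimal sketch of what such a helper could look like, assuming it wraps scipy's sparse eigensolver (the project's real implementation may differ):

import scipy.sparse
import scipy.sparse.linalg


def lmax(laplacian):
    # Largest-magnitude eigenvalue of a symmetric sparse Laplacian.
    return scipy.sparse.linalg.eigsh(
        laplacian, k=1, which='LM', return_eigenvectors=False)[0]


def rescale_laplacian(laplacian, lmax_val):
    # Chebyshev filters assume eigenvalues in [-1, 1]: L' = 2L/lmax - I.
    identity = scipy.sparse.identity(laplacian.shape[0],
                                     format='csr',
                                     dtype=laplacian.dtype)
    return (2.0 / lmax_val) * laplacian - identity
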
Example #2
    def __init__(self, config, rng):
        self.config = config
        self.rng = rng
        self.task = config.task
        self.model_dir = config.model_dir
        self.gpu_memory_fraction = config.gpu_memory_fraction
        self.checkpoint_secs = config.checkpoint_secs
        self.log_step = config.log_step
        self.num_epoch = config.num_epochs

        ## import data Loader ##
        data_dir = config.data_dir
        dataset_name = config.task
        batch_size = config.batch_size
        num_time_steps = config.num_time_steps
        self.data_loader = BatchLoader(data_dir, dataset_name, batch_size,
                                       num_time_steps)

        ## Need to think about how we construct the adjacency matrix (W)
        W = self.data_loader.adj
        laplacian = W / W.max()  # normalized by the largest entry
        laplacian = scipy.sparse.csr_matrix(
            laplacian, dtype=np.float32)  # store the matrix in CSR format
        lmax = graph.lmax(laplacian)  # Q: purpose unknown

        ## define model ##
        self.model = Model(config, laplacian, lmax)

        ## model saver / summary writer ##
        self.saver = tf.train.Saver()
        self.model_saver = tf.train.Saver(self.model.model_vars)
        self.summary_writer = tf.summary.FileWriter(self.model_dir)

        sv = tf.train.Supervisor(logdir=self.model_dir,
                                 is_chief=True,
                                 saver=self.saver,
                                 summary_op=None,
                                 summary_writer=self.summary_writer,
                                 save_summaries_secs=300,
                                 save_model_secs=self.checkpoint_secs,
                                 global_step=self.model.model_step)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.gpu_memory_fraction,
            allow_growth=True)  # seems not to work
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)

        self.sess = sv.prepare_or_wait_for_session(config=sess_config)
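
About the recurring "# seems not to work" comment on `allow_growth`: in TF1, `per_process_gpu_memory_fraction` caps the share of GPU memory the process may claim, while `allow_growth=True` asks TensorFlow to allocate incrementally instead of grabbing memory up front; combined, allocation should grow on demand up to the cap, though the observed behavior evidently disappointed these authors. A minimal sketch of the two options in isolation:

import tensorflow as tf

# Grow allocations on demand instead of reserving all GPU memory at startup.
growth_opts = tf.GPUOptions(allow_growth=True)

# Alternatively, hard-cap the process at a fraction of total GPU memory.
capped_opts = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)

sess_config = tf.ConfigProto(allow_soft_placement=True,
                             gpu_options=growth_opts)
sess = tf.Session(config=sess_config)
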
Example #3

# BatchLoader is the project's own helper; its import path is not shown here.
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix)
from sklearn.model_selection import train_test_split

def train(data, test=False):
    batch_loader = BatchLoader(data)
    with batch_loader as batch:
        print('Preparing data...')

        if test:
            X_train, X_test, y_train, y_test = train_test_split(batch.features,
                                                                batch.labels,
                                                                test_size=0.2,
                                                                random_state=0)
        else:
            X_train = batch.features
            y_train = batch.labels

        print('Train...')
        text_classifier = RandomForestClassifier(n_estimators=10,
                                                 random_state=0)
        text_classifier.fit(X_train, y_train)

        print('Training finished!')

        if test:
            predictions = text_classifier.predict(X_test)

            print(confusion_matrix(y_test, predictions))
            print(classification_report(y_test, predictions))
            print(accuracy_score(y_test, predictions))

        return text_classifier, batch_loader.vectorizer
Example #4

import pandas

def predict_to_csv(model, vectorizer, data, dest='prediction.csv'):
    batch_loader = BatchLoader(data, vectorizer, has_labels=False)
    with batch_loader as batch:
        predictions = model.predict(batch.features)

        df = pandas.DataFrame(predictions,
                              columns=['Category'],
                              index=batch.ids)
        df.index.name = 'Id'

        df.to_csv(dest, index=True, header=True)

    print("Predictions exported to {}.".format(dest))
Example #5
class Trainer(object):
    def __init__(self, config, rng):
        self.config = config
        self.rng = rng
        self.task = config.task
        self.model_dir = config.model_dir
        self.gpu_memory_fraction = config.gpu_memory_fraction
        self.checkpoint_secs = config.checkpoint_secs
        self.log_step = config.log_step
        self.num_epoch = config.num_epochs

        ## import data Loader ##
        data_dir = config.data_dir
        dataset_name = config.task
        batch_size = config.batch_size
        num_time_steps = config.num_time_steps
        self.data_loader = BatchLoader(data_dir, dataset_name, batch_size,
                                       num_time_steps)

        ## Need to think about how we construct the adjacency matrix (W)
        W = self.data_loader.adj
        laplacian = W / W.max()
        laplacian = scipy.sparse.csr_matrix(laplacian, dtype=np.float32)
        lmax = graph.lmax(laplacian)

        #idx2char = batchLoader_.idx2char
        #char2idx = batchLoader_.char2idx
        #batch_x, batch_y = batchLoader_.next_batch(0) 0:train 1:valid 2:test
        #batchLoader_.reset_batch_pointer(0)

        ## define model ##
        self.model = Model(config, laplacian, lmax)

        ## model saver / summary writer ##
        self.saver = tf.train.Saver()
        self.model_saver = tf.train.Saver(self.model.model_vars)
        self.summary_writer = tf.summary.FileWriter(self.model_dir)

        sv = tf.train.Supervisor(logdir=self.model_dir,
                                 is_chief=True,
                                 saver=self.saver,
                                 summary_op=None,
                                 summary_writer=self.summary_writer,
                                 save_summaries_secs=300,
                                 save_model_secs=self.checkpoint_secs,
                                 global_step=self.model.model_step)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.gpu_memory_fraction,
            allow_growth=True)  # seems not to work
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)

        self.sess = sv.prepare_or_wait_for_session(config=sess_config)

    def train(self):
        print("[*] Checking if previous run exists in {}"
              "".format(self.model_dir))
        latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
        if tf.train.latest_checkpoint(self.model_dir) is not None:
            print("[*] Saved result exists! loading...")
            self.saver.restore(self.sess, latest_checkpoint)
            print("[*] Loaded previously trained weights")
            self.b_pretrain_loaded = True
        else:
            print("[*] No previous result")
            self.b_pretrain_loaded = False

        print("[*] Training starts...")
        self.model_summary_writer = None

        ##Training
        for n_epoch in trange(self.num_epoch, desc="Training[epoch]"):
            self.data_loader.reset_batch_pointer(0)
            for k in trange(self.data_loader.sizes[0], desc="[per_batch]"):
                # Fetch training data
                batch_x, batch_y = self.data_loader.next_batch(0)
                batch_x_onehot = convert_to_one_hot(batch_x,
                                                    self.config.num_node)
                if self.config.model_type == 'lstm':
                    reshaped = batch_x_onehot.reshape([
                        self.config.batch_size, self.config.num_node,
                        self.config.num_time_steps
                    ])
                    batch_x = reshaped
                elif self.config.model_type == 'glstm':
                    reshaped = batch_x_onehot.reshape([
                        self.config.batch_size, self.config.num_time_steps, 1,
                        self.config.num_node
                    ])
                    batch_x = np.transpose(reshaped, (0, 3, 2, 1))

                feed_dict = {
                    self.model.rnn_input: batch_x,
                    self.model.rnn_output: batch_y
                }
                res = self.model.train(self.sess,
                                       feed_dict,
                                       self.model_summary_writer,
                                       with_output=True)
                self.model_summary_writer = self._get_summary_writer(res)

            if n_epoch % 10 == 0:
                self.saver.save(self.sess, self.model_dir)
                print(batch_x, batch_y)

    def test(self):
        self.model_summary_writer = None

        #Testing
        for n_sample in trange(self.data_loader.sizes[2], desc="Testing"):
            batch_x, batch_y = self.data_loader.next_batch(2)
            batch_x_onehot = convert_to_one_hot(batch_x, self.config.num_node)
            reshaped = batch_x_onehot.reshape([
                self.config.batch_size, self.config.num_time_steps, 1,
                self.config.num_node
            ])
            batch_x = np.transpose(reshaped, (0, 3, 2, 1))

            feed_dict = {
                self.model.rnn_input: batch_x,
                self.model.rnn_output: batch_y
            }
            res = self.model.test(self.sess,
                                  feed_dict,
                                  self.model_summary_writer,
                                  with_output=True)
            self.model_summary_writer = self._get_summary_writer(res)

    def _get_summary_writer(self, result):
        if result['step'] % self.log_step == 0:
            return self.summary_writer
        else:
            return None
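
Example #5 depends on a `convert_to_one_hot` helper that is not shown. Judging from the reshapes that follow it (the one-hot axis is folded into a dimension of size `num_node`), it presumably expands integer indices into one-hot vectors of depth `num_node`. A minimal sketch under that assumption:

import numpy as np


def convert_to_one_hot(indices, depth):
    # Stand-in for the undefined helper; the original may order axes
    # differently. Output shape: indices.shape + (depth,).
    indices = np.asarray(indices, dtype=np.int64)
    return np.eye(depth, dtype=np.float32)[indices]
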
Example #6
#!/bin/python


from utils import BatchLoader

loader = BatchLoader('data/training_eeg.csv', 5, 2)
loader.next_batch()

Example #7
class Trainer(object):
    def __init__(self, config, rng):
        self.config = config
        self.rng = rng
        self.task = config.task
        self.model_dir = config.model_dir
        self.gpu_memory_fraction = config.gpu_memory_fraction
        self.checkpoint_secs = config.checkpoint_secs
        self.log_step = config.log_step
        self.num_epoch = config.num_epochs

        ## import data Loader ##
        data_dir = config.data_dir
        dataset_name = config.task
        batch_size = config.batch_size
        num_time_steps = config.num_time_steps
        self.data_loader = BatchLoader(data_dir, dataset_name, batch_size,
                                       num_time_steps)

        ## Need to think about how we construct the adjacency matrix (W)
        W = self.data_loader.adj
        laplacian = W / W.max()  # normalized by the largest entry
        laplacian = scipy.sparse.csr_matrix(
            laplacian, dtype=np.float32)  # store the matrix in CSR format
        lmax = graph.lmax(laplacian)  # Q: purpose unknown

        ## define model ##
        self.model = Model(config, laplacian, lmax)

        ## model saver / summary writer ##
        self.saver = tf.train.Saver()
        self.model_saver = tf.train.Saver(self.model.model_vars)
        self.summary_writer = tf.summary.FileWriter(self.model_dir)

        sv = tf.train.Supervisor(logdir=self.model_dir,
                                 is_chief=True,
                                 saver=self.saver,
                                 summary_op=None,
                                 summary_writer=self.summary_writer,
                                 save_summaries_secs=300,
                                 save_model_secs=self.checkpoint_secs,
                                 global_step=self.model.model_step)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.gpu_memory_fraction,
            allow_growth=True)  # seems not to work
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)

        self.sess = sv.prepare_or_wait_for_session(config=sess_config)

    def train(self):
        # print("[*] Checking if previous run exists in {}"
        #       "".format(self.model_dir))
        # latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
        # if tf.train.latest_checkpoint(self.model_dir) is not None:
        #     print("[*] Saved result exists! loading...")
        #     self.saver.restore(
        #         self.sess,
        #         latest_checkpoint
        #     )
        #     print("[*] Loaded previously trained weights")
        #     self.b_pretrain_loaded = True
        # else:
        #     print("[*] No previous result")
        #     self.b_pretrain_loaded = False
        #
        # print("[*] Training starts...")
        # self.model_summary_writer = None

        ##Training

        for n_epoch in trange(self.num_epoch, desc="Training[epoch]"):
            self.data_loader.reset_batch_pointer(0)
            # sizes[0]: number of training batches
            for k in trange(self.data_loader.sizes[0], desc="[per_batch]"):
                # Fetch training data
                batch_x, batch_y = self.data_loader.next_batch(0)
                # get each batch row's data
                # batch_x_onehot = convert_to_one_hot(batch_x, self.config.num_node)
                # one-hot (COO-style) encoding conversion, disabled here
                if self.config.model_type == 'lstm':
                    reshaped = batch_x.reshape([
                        self.config.batch_size, self.config.num_node,
                        self.config.num_time_steps
                    ])

                    batch_x = reshaped
                elif self.config.model_type == 'glstm':
                    reshaped = batch_x.reshape([
                        self.config.batch_size, self.config.num_time_steps,
                        self.config.feat_in, self.config.num_node
                    ])
                    # [batch, time, feat_in, node] -> [batch, node, feat_in, time], e.g. [20, 50, 1, 50]
                    batch_x = np.transpose(reshaped, (0, 3, 2, 1))

                batch_y = np.transpose(batch_y, (0, 3, 2, 1))
                feed_dict = {
                    self.model.rnn_input: batch_x,
                    self.model.rnn_output: batch_y
                }
                # res = self.model.train(self.sess, feed_dict, self.model_summary_writer,
                #                        with_output=True)
                res = self.model.train(self.sess, feed_dict, with_output=True)

                # self.model_summary_writer = self._get_summary_writer(res)
            res_output = res['output']
            res_shape = res_output.shape
            threshold = 1.0
            total = self.config.batch_size * self.config.num_node * self.config.num_time_steps

            # res_output[:,:,0,:] - res_output[:,:,1,:]
            for i in range(res_shape[0]):
                for j in range(res_shape[1]):
                    for k in range(res_shape[3]):
                        x1, x2 = res_output[i, j, 0, k], res_output[i, j, 1, k]
                        if x1 - x2 > threshold:
                            res_output[i, j, 0, k] = 1
                            res_output[i, j, 1, k] = 0
                        elif x1 - x2 < -1 * threshold:
                            res_output[i, j, 0, k] = 0
                            res_output[i, j, 1, k] = 1
                        else:
                            res_output[i, j, 0, k] = 0
                            res_output[i, j, 1, k] = 0
            res_output = np.swapaxes(res_output, 2, 3).reshape([-1, 2])
            batch_y_bp = np.swapaxes(batch_y, 2, 3).reshape([-1, 2])
            # roc_auc_score expects (y_true, y_score); this is ROC AUC, not accuracy
            acc = roc_auc_score(batch_y_bp, res_output)

            print('acc:  ', acc)
            # print('res:',res)
            if n_epoch % 10 == 0:
                self.saver.save(self.sess, self.model_dir)
        with DeepExplain(session=self.sess) as de:
            logits = self.model.rnn_output
            xi, yi = self.data_loader.next_batch(0)
            reshaped = xi.reshape([
                self.config.batch_size, self.config.num_time_steps,
                self.config.feat_in, self.config.num_node
            ])
            # [batch, time, feat_in, node] -> [batch, node, feat_in, time], e.g. [20, 50, 1, 50]
            xi = np.transpose(reshaped, (0, 3, 2, 1))
            # x = np.reshape(xi[0, :, :, 0], [110, 2])
            yreshaped = yi.reshape([
                self.config.batch_size, self.config.num_time_steps,
                self.config.feat_in, self.config.num_node
            ])
            yi = np.transpose(yreshaped, (0, 3, 2, 1))
            # y = np.reshape(yi[0, :, :, 0], [110, 2])
            # applyy = logits * yi
            print("[*] Computing attributions with DeepExplain...")
            print(logits)
            # X = tf.placeholder("float", [self.config.batch_size,self.config.num_node,self.config.feat_in,self.config.num_time_steps])
            # de.explain(method, target_tensor, input_tensor, input_data)
            attributions = de.explain('grad*input', logits,
                                      self.model.rnn_input, xi)
            np.savetxt('0_features.csv', attributions[0], delimiter=', ')
            print('Done')

    def test(self):
        print("[*] Checking if previous run exists in {}"
              "".format(self.model_dir))
        latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
        if tf.train.latest_checkpoint(self.model_dir) is not None:
            print("[*] Saved result exists! loading...")
            self.saver.restore(self.sess, latest_checkpoint)
            print("[*] Loaded previously trained weights")
            self.b_pretrain_loaded = True
        else:
            print("[*] No previous result")
            self.b_pretrain_loaded = False

        print("[*] Testing starts...")
        self.model_summary_writer = None
        ##Testing
        for n_sample in trange(self.data_loader.sizes[2], desc="Testing"):
            batch_x, batch_y = self.data_loader.next_batch(2)
            # batch_x_onehot = convert_to_one_hot(batch_x, self.config.num_node)
            reshaped = batch_x.reshape([
                self.config.batch_size, self.config.num_time_steps,
                self.config.feat_in, self.config.num_node
            ])
            batch_x = np.transpose(reshaped, (0, 3, 2, 1))

            feed_dict = {
                self.model.rnn_input: batch_x,
                self.model.rnn_output: batch_y
            }
            res = self.model.test(self.sess,
                                  feed_dict,
                                  self.model_summary_writer,
                                  with_output=True)
            self.model_summary_writer = self._get_summary_writer(res)

    def _get_summary_writer(self, result):
        if result['step'] % self.log_step == 0:
            return self.summary_writer
        else:
            return None
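
The element-wise thresholding in Example #7's `train` (the triple loop over `res_output`) runs in pure Python over batch x node x time entries; the same decision rule vectorizes directly in NumPy. A behavior-equivalent sketch:

import numpy as np


def threshold_outputs(res_output, threshold=1.0):
    # channel-0 minus channel-1 decides the label:
    #   diff >  threshold -> (1, 0)
    #   diff < -threshold -> (0, 1)
    #   otherwise         -> (0, 0), i.e. undecided
    diff = res_output[:, :, 0, :] - res_output[:, :, 1, :]
    out = np.zeros_like(res_output)
    out[:, :, 0, :] = (diff > threshold).astype(res_output.dtype)
    out[:, :, 1, :] = (diff < -threshold).astype(res_output.dtype)
    return out
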
Example #8

import logging

import tensorflow as tf

# learning_rate and epochs are defined earlier in the original script (not shown).
batch_size = 100
display_step = 10

# Network Parameters
num_features = 10 # Number of dimensions in tangent space produced by pyriemann
timesteps = 6 # Number of eeg epochs per sequence
num_hidden = 2048 # hidden layer num of neurons
num_classes = 2 # distracted or concentrated
num_layers = 1 # number of hidden layers
input_keep_prob = 1 # portion of incoming connections to keep
output_keep_prob = 0.5 # portion of outgoing connections to keep

logging.info("LR = " + str(learning_rate) + " Epochs = " + str(epochs))

# Initialize data feed
train_loader = BatchLoader('data/training_eeg.csv', batch_size, timesteps, num_features, num_classes, train=True)
valid_loader = BatchLoader('data/valid_eeg.csv', batch_size, timesteps, num_features, num_classes, train=False)

# tf Graph input
X = tf.placeholder("float", [batch_size, timesteps, num_features])
Y = tf.placeholder("int64", [batch_size])

# Define weights
weights = {
    'out': tf.Variable(tf.random_normal([num_hidden, num_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([num_classes]))
}
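
Example #8 stops after declaring the placeholders and the output projection. In the usual TF1 pattern, these feed a dropout-wrapped LSTM whose final output is projected through `weights['out']` and `biases['out']`; a hedged continuation sketch reusing the parameters above (not from the original source):

# Build the recurrent stack and the loss in standard TF1 style.
cell = tf.nn.rnn_cell.LSTMCell(num_hidden)
cell = tf.nn.rnn_cell.DropoutWrapper(cell,
                                     input_keep_prob=input_keep_prob,
                                     output_keep_prob=output_keep_prob)

# outputs: [batch_size, timesteps, num_hidden]
outputs, _ = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)

# Classify from the last timestep's hidden state.
logits = tf.matmul(outputs[:, -1, :], weights['out']) + biases['out']
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=Y, logits=logits))
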

Example #9
    def __init__(self, config, rng):
        self.config = config
        self.rng = rng
        self.model_dir = config.model_dir
        self.gpu_memory_fraction = config.gpu_memory_fraction
        self.checkpoint_secs = config.checkpoint_secs
        self.log_step = config.log_step
        self.num_epoch = config.num_epochs
        self.stop_win_size = config.stop_win_size
        self.stop_early = config.stop_early

        ## import data Loader ##
        batch_size = config.batch_size
        server_name = config.server_name
        mode = config.mode
        target = config.target
        sample_rate = config.sample_rate
        win_size = config.win_size
        hist_range = config.hist_range
        s_month = config.s_month
        e_month = config.e_month
        e_date = config.e_date
        s_date = config.s_date
        data_rm = config.data_rm
        coarsening_level = config.coarsening_level
        cnn_mode = config.conv
        is_coarsen = config.is_coarsen

        self.data_loader = BatchLoader(server_name, mode, target, sample_rate,
                                       win_size, hist_range, s_month, s_date,
                                       e_month, e_date, data_rm, batch_size,
                                       coarsening_level, cnn_mode, is_coarsen)

        actual_node = self.data_loader.adj.shape[0]
        if config.conv == 'gcnn':
            graphs = self.data_loader.graphs
            if config.is_coarsen:
                L = [
                    graph.laplacian(A, normalized=config.normalized)
                    for A in graphs
                ]
            else:
                L = [
                    graph.laplacian(self.data_loader.adj,
                                    normalized=config.normalized)
                ] * len(graphs)
        elif config.conv == 'cnn':
            L = [actual_node]
            tmp_node = actual_node
            while tmp_node > 0:
                tmp_node = int(tmp_node / 2)
                L.append(tmp_node)
        else:
            raise ValueError("Unsupported config.conv {}".format(config.conv))

        tf.reset_default_graph()
        ## define model ##
        self.model = Model(config, L, actual_node)

        ## model saver / summary writer ##
        self.saver = tf.train.Saver()
        self.model_saver = tf.train.Saver(self.model.model_vars)
        self.summary_writer = tf.summary.FileWriter(self.model_dir)
        # Checkpoint
        # meta file: describes the saved graph structure, includes
        # GraphDef, SaverDef, and so on; then apply
        # tf.train.import_meta_graph('/tmp/model.ckpt.meta'),
        # will restore Saver and Graph.

        # index file: it is a string-string immutable
        # table(tensorflow::table::Table). Each key is a name of a tensor
        # and its value is a serialized BundleEntryProto.
        # Each BundleEntryProto describes the metadata of a
        # tensor: which of the "data" files contains the content of a tensor,
        # the offset into that file, checksum, some auxiliary data, etc.
        #
        # data file: it is TensorBundle collection, save the values of all variables.
        sv = tf.train.Supervisor(logdir=self.model_dir,
                                 is_chief=True,
                                 saver=self.saver,
                                 summary_op=None,
                                 summary_writer=self.summary_writer,
                                 save_summaries_secs=300,
                                 save_model_secs=self.checkpoint_secs,
                                 global_step=self.model.model_step)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.gpu_memory_fraction,
            allow_growth=True)  # seems not to work
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)
        #
        self.sess = sv.prepare_or_wait_for_session(config=sess_config)
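
The long checkpoint comment in Examples #9 and #10 (meta / index / data files) maps onto the two standard restore paths: either rebuild the graph in code and call `saver.restore`, or import the saved graph from the `.meta` file first. A minimal sketch of the second path, using the `/tmp/model.ckpt` prefix from the comment as a placeholder:

import tensorflow as tf

ckpt_prefix = '/tmp/model.ckpt'  # placeholder checkpoint prefix

with tf.Session() as sess:
    # The .meta file restores the GraphDef/SaverDef and returns a Saver.
    saver = tf.train.import_meta_graph(ckpt_prefix + '.meta')
    # The .index/.data files hold the variable values.
    saver.restore(sess, ckpt_prefix)
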
Example #10
class Trainer(object):
    def __init__(self, config, rng):
        self.config = config
        self.rng = rng
        self.model_dir = config.model_dir
        self.gpu_memory_fraction = config.gpu_memory_fraction
        self.checkpoint_secs = config.checkpoint_secs
        self.log_step = config.log_step
        self.num_epoch = config.num_epochs
        self.stop_win_size = config.stop_win_size
        self.stop_early = config.stop_early

        ## import data Loader ##
        batch_size = config.batch_size
        server_name = config.server_name
        mode = config.mode
        target = config.target
        sample_rate = config.sample_rate
        win_size = config.win_size
        hist_range = config.hist_range
        s_month = config.s_month
        e_month = config.e_month
        e_date = config.e_date
        s_date = config.s_date
        data_rm = config.data_rm
        coarsening_level = config.coarsening_level
        cnn_mode = config.conv
        is_coarsen = config.is_coarsen

        self.data_loader = BatchLoader(server_name, mode, target, sample_rate,
                                       win_size, hist_range, s_month, s_date,
                                       e_month, e_date, data_rm, batch_size,
                                       coarsening_level, cnn_mode, is_coarsen)

        actual_node = self.data_loader.adj.shape[0]
        if config.conv == 'gcnn':
            graphs = self.data_loader.graphs
            if config.is_coarsen:
                L = [
                    graph.laplacian(A, normalized=config.normalized)
                    for A in graphs
                ]
            else:
                L = [
                    graph.laplacian(self.data_loader.adj,
                                    normalized=config.normalized)
                ] * len(graphs)
        elif config.conv == 'cnn':
            L = [actual_node]
            tmp_node = actual_node
            while tmp_node > 0:
                tmp_node = int(tmp_node / 2)
                L.append(tmp_node)
        else:
            raise ValueError("Unsupported config.conv {}".format(config.conv))

        tf.reset_default_graph()
        ## define model ##
        self.model = Model(config, L, actual_node)

        ## model saver / summary writer ##
        self.saver = tf.train.Saver()
        self.model_saver = tf.train.Saver(self.model.model_vars)
        self.summary_writer = tf.summary.FileWriter(self.model_dir)
        # Checkpoint
        # meta file: describes the saved graph structure, includes
        # GraphDef, SaverDef, and so on; then apply
        # tf.train.import_meta_graph('/tmp/model.ckpt.meta'),
        # will restore Saver and Graph.

        # index file: it is a string-string immutable
        # table(tensorflow::table::Table). Each key is a name of a tensor
        # and its value is a serialized BundleEntryProto.
        # Each BundleEntryProto describes the metadata of a
        # tensor: which of the "data" files contains the content of a tensor,
        # the offset into that file, checksum, some auxiliary data, etc.
        #
        # data file: it is TensorBundle collection, save the values of all variables.
        sv = tf.train.Supervisor(logdir=self.model_dir,
                                 is_chief=True,
                                 saver=self.saver,
                                 summary_op=None,
                                 summary_writer=self.summary_writer,
                                 save_summaries_secs=300,
                                 save_model_secs=self.checkpoint_secs,
                                 global_step=self.model.model_step)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.gpu_memory_fraction,
            allow_growth=True)  # seems not to work
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)
        #
        self.sess = sv.prepare_or_wait_for_session(config=sess_config)

        # init = tf.global_variables_initializer()
        # self.sess = tf.Session(config=sess_config)
        # self.sess.run(init)

    def train(self, val_best_score=10, save=False, index=1, best_model=None):
        print("[*] Checking if previous run exists in {}"
              "".format(self.model_dir))
        latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
        if tf.train.latest_checkpoint(self.model_dir) is not None:
            print("[*] Saved result exists! loading...")
            self.saver.restore(self.sess, latest_checkpoint)
            print("[*] Loaded previously trained weights")
            self.b_pretrain_loaded = True
        else:
            print("[*] No previous result")
            self.b_pretrain_loaded = False

        print("[*] Training starts...")
        self.model_summary_writer = None

        val_loss = 0
        lr = 0
        tmp_best_loss = float('+inf')
        validation_loss_window = np.zeros(self.stop_win_size)
        validation_loss_window[:] = float('+inf')
        ##Training
        for n_epoch in trange(self.num_epoch, desc="Training[epoch]"):
            self.data_loader.reset_batch_pointer(0)
            loss_epoch = []
            for k in trange(self.data_loader.sizes[0], desc="[per_batch]"):
                # Fetch training data
                batch_x, batch_y, weight_y,\
                count_y, _ = self.data_loader.next_batch(0)

                feed_dict = {
                    self.model.cnn_input: batch_x,
                    self.model.output_label: batch_y,
                    self.model.ph_labels_weight: weight_y,
                    self.model.is_training: True
                }
                res = self.model.train(self.sess,
                                       feed_dict,
                                       self.model_summary_writer,
                                       with_output=True)
                loss_epoch.append(res['loss'])
                lr = res['lr']
                self.model_summary_writer = self._get_summary_writer(res)

            val_loss = self.validate()
            train_loss = np.mean(loss_epoch)

            validation_loss_window[n_epoch % self.stop_win_size] = val_loss

            if self.stop_early:
                if np.abs(validation_loss_window.mean() - val_loss) < 1e-4:
                    print('Validation loss did not decrease. Stopping early.')
                    break

            if n_epoch % 10 == 0:
                if save:
                    self.saver.save(self.sess, self.model_dir)
                if val_loss < val_best_score:
                    val_best_score = val_loss
                    best_model = self.model_dir
                if val_loss < tmp_best_loss:
                    tmp_best_loss = val_loss
                print("Searching {}...".format(index))
                print("Epoch {}: ".format(n_epoch))
                print("LR: ", lr)
                print("  Train Loss: ", train_loss)
                print("  Validate Loss: ", val_loss)
                print("  Current Best Loss: ", val_best_score)
                print("  Current Model Dir: ", best_model)

        return tmp_best_loss

    def validate(self):

        loss = []
        for n_sample in trange(self.data_loader.sizes[1], desc="Validating"):
            batch_x, batch_y, weight_y, count_y,\
            _ = self.data_loader.next_batch(1)

            feed_dict = {
                self.model.cnn_input: batch_x,
                self.model.output_label: batch_y,
                self.model.ph_labels_weight: weight_y,
                self.model.is_training: False
            }
            res = self.model.test(self.sess,
                                  feed_dict,
                                  self.summary_writer,
                                  with_output=True)
            loss.append(res['loss'])

        return np.nanmean(loss)

    def test(self):

        loss = []
        gt_y = []
        pred_y = []
        w_y = []
        counts_y = []
        vel_list_y = []
        for n_sample in trange(self.data_loader.sizes[2], desc="Testing"):
            batch_x, batch_y, weight_y, \
            count_y, vel_list = self.data_loader.next_batch(2)

            feed_dict = {
                self.model.cnn_input: batch_x,
                self.model.output_label: batch_y,
                self.model.ph_labels_weight: weight_y,
                self.model.is_training: False
            }
            res = self.model.test(self.sess,
                                  feed_dict,
                                  self.summary_writer,
                                  with_output=True)
            loss.append(res['loss'])
            gt_y.append(batch_y)
            w_y.append(weight_y)
            counts_y.append(count_y)
            vel_list_y.append(vel_list)
            pred_y.append(res['pred'])

        final_gt = np.concatenate(gt_y, axis=0)
        final_pred = np.concatenate(pred_y, axis=0)
        final_weight = np.concatenate(w_y, axis=0)
        final_count = np.concatenate(counts_y, axis=0)
        final_vel_list = np.concatenate(vel_list_y, axis=0)

        result_dict = {
            'ground_truth': final_gt,
            'prediction': final_pred,
            'weight': final_weight,
            'count': final_count,
            'vel_list': final_vel_list
        }

        test_loss = np.mean(loss)
        print("Test Loss: ", test_loss)

        return result_dict
        # self.model_summary_writer = self._get_summary_writer(res)

    def _get_summary_writer(self, result):
        if result['step'] % self.log_step == 0:
            return self.summary_writer
        else:
            return None
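
Example #10's early-stopping test fires when the rolling mean over the last `stop_win_size` validation losses is within 1e-4 of the current loss, i.e. when the loss has plateaued. A common alternative is patience-based stopping on the best score so far; a small sketch for comparison (not from the source):

class EarlyStopper(object):
    # Stop when validation loss has not improved for `patience` epochs.

    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('+inf')
        self.bad_epochs = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # meaningful improvement: reset counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
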