Code Example #1
def predict():
    data = tf.placeholder(tf.float32, [None, 331, 331, 3])
    label = tf.placeholder(tf.float32, [None, None])
    dropout = tf.placeholder(tf.float32)
    phase = tf.placeholder(tf.bool)
    net = MyInceptionV4(data, label, dropout, phase)

    # batch data
    transformer = transforms.Sequential([
        transforms.Resize([331, 331]),
        transforms.Preprocess(),
        # transforms.RandomHorizontalFlip(),
    ])

    mt_loader = MTCSVLoader(
        root='E:/fashion-dataset/rank',
        csv_path='E:\\fashion-dataset\\rank\\Tests\\question.csv',
        batch_size=64,
        transformer_fn=transformer,
        shuffle=False,
        num_epochs=1,
        allow_smaller_final_batch=True)
    sess = session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(
        sess,
        '{}/model/MyInceptionV4/MultiTask/fashion-ai.ckpt'.format(hp.pro_path))
    queue = multiprocessing.Queue(maxsize=30)
    writer_process = multiprocessing.Process(
        target=csv_writer,
        args=['{}/result/sub.csv'.format(hp.pro_path), queue, 'stop'])
    writer_process.start()
    for attr_key, n_class in di.num_classes_v2.items():
        flat_logit = net.layers['cls_prob_{}'.format(attr_key)]
        y_pred = tf.argmax(flat_logit, axis=1)
        print('writing predictions...')
        try:
            while not mt_loader.coord.should_stop():
                img_batch, label_batch, name_batch = mt_loader.batch(
                    attr_key=attr_key)
                names = [bytes.decode(v) for v in name_batch]  # byte strings -> str
                probs, preds = sess.run([flat_logit, y_pred],
                                        feed_dict={
                                            net.data: img_batch,
                                            net.is_training: False,
                                            net.keep_prob: 1
                                        })
                queue.put(('continue', names, attr_key, probs))
                print(probs.shape)

        except tf.errors.OutOfRangeError:
            print('Predict {} Done.'.format(attr_key))
    queue.put(('stop', None, None, None))
    writer_process.join()
    sess.close()
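
csv_writer itself is not shown in this excerpt. A minimal sketch of a compatible consumer, inferred only from the (flag, names, attr_key, probs) tuples queued above (the project's real implementation may differ):

import csv

def csv_writer(csv_path, queue, stop_token):
    # consume result tuples until the stop token arrives, one CSV row per image
    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        while True:
            flag, names, attr_key, probs = queue.get()
            if flag == stop_token:
                break
            for name, prob in zip(names, probs):
                writer.writerow([name, attr_key, ';'.join('%.4f' % p for p in prob)])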
Code Example #2
    def __init__(self):
        self.config = config
        self.session = session()
        self.purchase_num = 1
        self.sku_id = self.config.get_config('config', 'sku_id')
        self.order_data = dict()
        self.init_info = dict()
        self.timers = Timer()

        log.info('Synchronizing system time')
Code Example #3
    def _build(self, dataset, sess=None):
        self._dataset = dataset

        if self._is_eager:
            self._eager_iterator = tfe.Iterator(dataset)
        else:
            self._iterator = dataset.make_initializable_iterator(
                shared_name='CelebA')
            self.img_iter_init = self._iterator.make_initializer(
                self._dataset)  # re-initializer so the iterator can be reset on the same dataset
            self._batch_op = self._iterator.get_next()
            if sess:
                self._sess = sess
            else:
                self._sess = session()

        try:
            self.reset()
        except Exception:
            # ignore reset failures (e.g. iterator not yet initialized)
            pass
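
A minimal usage sketch (an assumption, not from the project) of driving this wrapper in graph mode: batches come from running _batch_op on the stored session, and img_iter_init re-arms the iterator between passes.

loader._build(dataset)                          # wire up the iterator and batch op
loader._sess.run(loader.img_iter_init)          # (re)initialize the iterator
img_batch = loader._sess.run(loader._batch_op)  # fetch one batch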
Code Example #4
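    # WGAN weight clipping: after each critic update (d_step_), project every
    # discriminator variable back into [-clip, clip] to enforce the Lipschitz
    # constraint on the critic.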
    with tf.control_dependencies([d_step_]):
        d_step = tf.group(*(tf.assign(var, tf.clip_by_value(var, -clip, clip))
                            for var in d_var))
    g_step = tf.train.RMSPropOptimizer(learning_rate=lr).minimize(
        g_loss, var_list=g_var)

    # summaries
    d_summary = utils.summary({wd: 'wd'})
    g_summary = utils.summary({g_loss: 'g_loss'})

    # sample
    f_sample = generator(z, training=False)
""" train """
''' init '''
# session
sess = utils.session()
# iteration counter
it_cnt, update_cnt = utils.counter()
# saver
saver = tf.train.Saver(max_to_keep=5)
# summary writer
summary_writer = tf.summary.FileWriter('./summaries/celeba_wgan', sess.graph)
''' initialization '''
ckpt_dir = './checkpoints/celeba_wgan'
utils.mkdir(ckpt_dir + '/')
if not utils.load_checkpoint(ckpt_dir, sess):
    sess.run(tf.global_variables_initializer())
''' train '''
try:
    z_ipt_sample = np.random.normal(size=[100, z_dim])
Code Example #5
def train():
    alpha_span = 800000
    batch_size = 32
    ckpt_dir = './checkpoints/wgp'
    n_gen = 1
    n_critic = 1
    it_start = 0
    #epoch = 20*(alpha_span * 2 // (2*4936)) # 4936 is number of images
    
    def preprocess_fn(img):
        # resize with area interpolation, then scale pixels from [0, 255] to [-1, 1]
        img = tf.image.resize_images(img, [target_size, target_size], method=tf.image.ResizeMethod.AREA) / 127.5 - 1.0
        return img

    def preprocess_fn_dummy(img):
        img = tf.image.resize_images(img, [final_size, final_size], method=tf.image.ResizeMethod.AREA) / 127.5 - 1.0
        return img
    
    # dataset
    img_paths = glob.glob('./imgs/faces/*.png')
    data_pool = utils.DiskImageData(5, img_paths, batch_size//2, shape=[640, 640, 3], preprocess_fn=preprocess_fn)
    data_pool_dummy = utils.DiskImageData(7, img_paths, 1, shape=[640, 640, 3], preprocess_fn=preprocess_fn_dummy)    
    batch_epoch = len(data_pool) // (batch_size * n_critic)

    # build graph
    print('Building a graph ...')
    nodes = build(batch_size)
    # session
    sess = utils.session()
    saver = tf.train.Saver()
    # summary
    summary_writer = tf.summary.FileWriter('./summaries/wgp/', sess.graph)
    utils.mkdir(ckpt_dir + '/')

    print('Initializing all variables ...')
    sess.run(tf.global_variables_initializer())
    
    # run a dummy pass at the final size so the optimizer creates slots for all
    # variables it will use (currently disabled):
    #print('Running final size dummy session ...')
    #if target_size == initial_size and len(sys.argv) <= 3:
    #    _ = sess.run([nodes['dummy']['d']], feed_dict=get_ipt(2, final_size, 1.0, data_pool_dummy, z_dim, nodes['dummy']['input']))
    #    _ = sess.run([nodes['dummy']['g']], feed_dict=get_ipt(2, final_size, 1.0, data_pool_dummy, z_dim, nodes['dummy']['input']))
        
    # load checkpoint
    if len(sys.argv) > 3 and sys.argv[2] == 'resume':
        print('Loading the checkpoint ...')
        saver.restore(sess, ckpt_dir + '/model.ckpt')
        it_start = 1 + int(sys.argv[3])
    last_saved_iter = it_start - 1

    ''' train '''
    for it in range(it_start, 9999999999):
        # fade-in coefficient for the progressively grown layers
        alpha_ipt = it / (alpha_span / batch_size)
        if alpha_ipt > 1 or target_size == initial_size:
            alpha_ipt = 1.0
        print('Alpha : %f' % alpha_ipt)
        alpha_ipt = 1.0  # NOTE: this unconditional override disables the fade-in computed above
        
        # train D
        for i in range(n_critic):
            d_summary_opt, _ = sess.run([nodes['summaries']['d'], nodes['product']['d']],\
                feed_dict=get_ipt(batch_size, target_size, alpha_ipt, data_pool, z_dim, nodes['product']['input']))
        summary_writer.add_summary(d_summary_opt, it)

        # train G
        for i in range(n_gen):
            g_summary_opt, _ = sess.run([nodes['summaries']['g'], nodes['product']['g']],\
                feed_dict=get_ipt(batch_size, target_size, alpha_ipt, data_pool, z_dim, nodes['product']['input']))
        summary_writer.add_summary(g_summary_opt, it)
        
        # display
        epoch = it // batch_epoch
        it_epoch = it % batch_epoch + 1
        if it % 1 == 0:  # print every iteration
            print("iter : %8d, epoch : (%3d) (%5d/%5d) _ resume point : %d" % (it, epoch, it_epoch, batch_epoch, last_saved_iter))

        # sample
        if (it + 1) % batch_epoch == 0:
            f_sample_opt = sess.run(nodes['sample'], feed_dict=get_ipt_for_sample(batch_size, z_dim, nodes['product']['input']))
            f_sample_opt = np.clip(f_sample_opt, -1, 1)
            save_dir = './sample_images_while_training/wgp/'
            utils.mkdir(save_dir + '/')
            osz = int(math.sqrt(batch_size))+1
            utils.imwrite(utils.immerge(f_sample_opt, osz, osz), '%s/iter_(%d).png' % (save_dir, it))
            
        # save
        if (it + 1) % batch_epoch == 0:
            last_saved_iter = it
            save_path = saver.save(sess, '%s/model.ckpt' % (ckpt_dir))
            print('Model saved in file: %s' % save_path)
Code Example #6
File: train.py  Project: shosangit/Colorization
                        self._eval(
                            iteration,
                            self.train_config["validsize"],
                            v_list,
                        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SCFT")
    parser.add_argument('--session',
                        type=str,
                        default='scft',
                        help="session name")
    parser.add_argument('--data_path',
                        type=Path,
                        help="path containing color images")
    parser.add_argument('--sketch_path',
                        type=Path,
                        help="path containing sketch images")
    args = parser.parse_args()

    outdir, modeldir = session(args.session)

    with open("param.yaml", "r") as f:
        config = yaml.safe_load(f)
        pprint.pprint(config)

    trainer = Trainer(config, outdir, modeldir, args.data_path,
                      args.sketch_path)
    trainer()
Code Example #7
File: runtrain.py  Project: HansonSun/FaceHeadpose_TF
def run_training():

    # 1. create log and model directories named by the current datetime
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    models_dir = os.path.join("saved_models", subdir, "models")
    if not os.path.isdir(models_dir):  # Create the model directory if it doesn't exist
        os.makedirs(models_dir)
    logs_dir = os.path.join("saved_models", subdir, "logs")
    if not os.path.isdir(logs_dir):  # Create the log directory if it doesn't exist
        os.makedirs(logs_dir)
    topn_models_dir = os.path.join("saved_models", subdir, "topn")  # stores the top-accuracy models
    if not os.path.isdir(topn_models_dir):  # Create the topn model directory if it doesn't exist
        os.makedirs(topn_models_dir)
    topn_file = open(os.path.join(topn_models_dir, "topn_acc.txt"), "a+")  # create the accuracy log if missing
    topn_file.close()


    # 2. load dataset and define placeholders
    demo = TFRecordDataset()
    train_iterator, train_next_element = demo.generateDataset(tfrecord_path='tfrecord_dataset/train.tfrecords', batch_size=512)


    phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')
    images_placeholder = tf.placeholder(name='input', shape=[None, 96, 96, 3], dtype=tf.float32)
    binned_pose_placeholder = tf.placeholder(name='binned_pose', shape=[None,3 ], dtype=tf.int64)
    cont_labels_placeholder = tf.placeholder(name='cont_labels', shape=[None,3 ], dtype=tf.float32)

    yaw,pitch,roll = vgg.inference(images_placeholder,phase_train=phase_train_placeholder)

    yaw_ce   = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=yaw, labels=binned_pose_placeholder[:, 0])
    pitch_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pitch, labels=binned_pose_placeholder[:, 1])
    roll_ce  = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=roll, labels=binned_pose_placeholder[:, 2])

    # per-sample cross-entropy reduced to scalar classification losses
    loss_yaw   = tf.reduce_mean(yaw_ce)
    loss_pitch = tf.reduce_mean(pitch_ce)
    loss_roll  = tf.reduce_mean(roll_ce)


    softmax_yaw = tf.nn.softmax(yaw)
    softmax_pitch = tf.nn.softmax(pitch)
    softmax_roll = tf.nn.softmax(roll)

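    # Soft-argmax decoding: the expectation over the 67 bin indices (0..66),
    # scaled by the 3-degree bin width and shifted by -99, gives a continuous
    # angle in [-99, +99] degrees.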
    yaw_predicted   = tf.math.reduce_sum(softmax_yaw * tf.linspace(0.0, 66.0, 67), 1) * 3 - 99
    pitch_predicted = tf.math.reduce_sum(softmax_pitch * tf.linspace(0.0, 66.0, 67), 1) * 3 - 99
    roll_predicted  = tf.math.reduce_sum(softmax_roll * tf.linspace(0.0, 66.0, 67), 1) * 3 - 99



    yaw_mse_loss = tf.losses.mean_squared_error(labels=cont_labels_placeholder[:,0], predictions=yaw_predicted)
    pitch_mse_loss = tf.losses.mean_squared_error(labels=cont_labels_placeholder[:,1], predictions=pitch_predicted)
    roll_mse_loss = tf.losses.mean_squared_error(labels=cont_labels_placeholder[:,2], predictions=roll_predicted)

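    # Hopenet-style combined objective: cross-entropy on the binned pose plus a
    # small, alpha-weighted MSE term on the decoded continuous angles.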
    alpha = 0.001

    total_loss_softmax = loss_yaw + loss_pitch + loss_roll
    total_loss_mse = alpha * (yaw_mse_loss + pitch_mse_loss + roll_mse_loss)
    total_loss = total_loss_softmax + total_loss_mse

    

    yaw_correct_prediction = tf.equal(tf.argmax(yaw, 1), binned_pose_placeholder[:, 0])
    pitch_correct_prediction = tf.equal(tf.argmax(pitch, 1), binned_pose_placeholder[:, 1])
    roll_correct_prediction = tf.equal(tf.argmax(roll, 1), binned_pose_placeholder[:, 2])

    yaw_accuracy = tf.reduce_mean(tf.cast(yaw_correct_prediction, tf.float32))
    pitch_accuracy = tf.reduce_mean(tf.cast(pitch_correct_prediction, tf.float32))
    roll_accuracy = tf.reduce_mean(tf.cast(roll_correct_prediction, tf.float32))

    # adjust learning rate
    global_step = tf.Variable(0, trainable=False)
    #learning_rate = tf.train.exponential_decay(0.001,global_step,100000,0.98,staircase=True)
    learning_rate = tf.train.piecewise_constant(global_step, boundaries=[8000, 16000, 24000, 32000], values=[0.001, 0.0001, 0.0001, 0.00001, 0.000001], name='lr_schedule')



    #optimize loss and update
    #optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
    optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9, use_nesterov=True)
    #optimizer = tf.train.RMSPropOptimizer(learning_rate, decay=0.9, momentum=0.9, epsilon=1.0)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(total_loss,global_step=global_step)


    saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=5)

    sess = utils.session()
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # resume from an earlier checkpoint (user-specific absolute path)
    saver.restore(sess, "/home/hanson/work/FaceHeadpose_TF/saved_models/20190801-135403/models/0.563564.ckpt")

    minimum_loss_value = 999.0
    total_loss_value = 0.0
    for epoch in range(1000):
        sess.run(train_iterator.initializer)
        use_time = 0
        while True:
            try:
                images_train, binned_pose, cont_labels = sess.run(train_next_element)
                start_time = time.time()
                input_dict = {phase_train_placeholder: True, images_placeholder: images_train, binned_pose_placeholder: binned_pose, cont_labels_placeholder: cont_labels}
                
                total_loss_mse_value,total_loss_softmax_value,yaw_acc,pitch_acc,roll_acc,step,lr,train_loss,_ = sess.run([
                                        total_loss_mse,
                                        total_loss_softmax,
                                        yaw_accuracy,
                                        pitch_accuracy,
                                        roll_accuracy,
                                        global_step,
                                        learning_rate,
                                        total_loss,
                                        train_op],
                                        feed_dict=input_dict)

                total_loss_value += train_loss
                end_time = time.time()
                use_time += (end_time - start_time)

                # display train stats every 100 steps
                if step % 100 == 0:
                    use_time = 0
                    average_loss_value = total_loss_value / 100.0
                    total_loss_value = 0
                    print("step:%d lr:%f sloss:%f mloss:%f average_loss:%f YAW_ACC:%.2f PITCH_ACC:%.2f ROLL_ACC:%.2f epoch:%d" % (
                        step, lr, total_loss_softmax_value, total_loss_mse_value,
                        float(average_loss_value), yaw_acc, pitch_acc, roll_acc, epoch))
                    if average_loss_value < minimum_loss_value:
                        print("save ckpt")
                        filename_ckpt = os.path.join(models_dir, "%f.ckpt" % average_loss_value)
                        saver.save(sess, filename_ckpt)
                        minimum_loss_value = average_loss_value

            except tf.errors.OutOfRangeError:
                print("End of epoch ")
                break
Code Example #8
def train_net(attr_key):
    with tf.Graph().as_default():
        # placeholder
        data = tf.placeholder(tf.float32, [None, 331, 331, 3])
        label = tf.placeholder(tf.float32, [None, None])
        is_training = tf.placeholder(tf.bool)
        keep_prob = tf.placeholder(tf.float32)

        # network
        net = MyInceptionV4(data=data, label=label, keep_prob=keep_prob, is_training=is_training)

        # batch data
        transformer = transforms.Sequential(
            [
                transforms.Resize([331, 331]),
                transforms.Preprocess(),
                transforms.RandomHorizontalFlip(),
            ]
        )

        loader = STCSVLoader(root='E:/fashion-dataset/base', attr_key=attr_key,
                             csv_path='{}/dataset/labels/s1_label.csv'.format(hp.pro_path),
                             batch_size=hp.batch_size, transformer_fn=transformer,
                             shuffle=hp.shuffle, min_after_dequeue=hp.min_after_dequeue,
                             num_threads=hp.num_threads, allow_smaller_final_batch=hp.allow_smaller_final_batch,
                             seed=hp.seed, num_epochs=None)

        num_batch = loader.n_sample // hp.batch_size
        hp.display = 1
        hp.snapshot_iter = num_batch
        hp.stepsize = num_batch * 5

        global_step = tf.Variable(0, name='global_step', trainable=False)
        lr = tf.train.exponential_decay(hp.learning_rate, global_step, hp.stepsize, hp.lr_decay, staircase=True)
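        # Two optimizers with different learning rates: opt1 fine-tunes the
        # pretrained Inception-v4 body, while opt2 updates the last 16 trainable
        # variables (var2, presumably the new classification head) at hp.times * lr.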
        opt1 = tf.train.MomentumOptimizer(lr, hp.momentum)
        opt2 = tf.train.MomentumOptimizer(hp.times * lr, hp.momentum)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_op = tf.group(*update_ops)

        var1 = tf.trainable_variables()[0:-16]
        var2 = tf.trainable_variables()[-16:]

        cls_score = 'cls_score_{}'.format(attr_key)
        flat_logit = net.layers[cls_score]
        flat_label = net.layers['label']

        # acc
        y_pred = tf.argmax(flat_logit, axis=1)
        y_true = tf.argmax(flat_label, axis=1)
        accuracy = tf.reduce_mean(tf.cast(tf.equal(y_pred, y_true), tf.float32))

        # loss
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                logits=flat_logit,
                labels=flat_label
            )
        )

        with tf.control_dependencies([update_op]):
            cost = tf.identity(loss)
        # optimizer
        train_op1 = opt1.minimize(cost, global_step=global_step, var_list=var1)
        train_op2 = opt2.minimize(cost, var_list=var2)
        train_op = tf.group(train_op1, train_op2)

        # session
        sess = session()
        saver = tf.train.Saver(max_to_keep=50)
        sess.run(tf.global_variables_initializer())
        net.load(sess, '{}/model/Pretrain/{}'.format(hp.pro_path, 'inception_v4.ckpt'))
        step = -1
        last_snapshot_iter = -1
        max_step = hp.num_epoch * num_batch
        print('num_batch', num_batch)
        now_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print('%s Start Training...iter: %d(%d) / %d(%d) lr: %s' % (
            now_time, (step + 1), 0, max_step, hp.num_epoch, sess.run(lr)))

        for epoch in range(hp.num_epoch):
            for step in range((epoch * num_batch), ((epoch + 1) * num_batch)):
                data_batch, label_batch, name_batch = loader.batch()
                _ = sess.run(
                    train_op,
                    feed_dict={
                        net.data: data_batch,
                        net.label: label_batch,
                        net.is_training: True,
                        net.keep_prob: hp.keep_prob
                    }
                )
                if (step + 1) % hp.display == 0:
                    loss_value, acc, y1, y2 = sess.run(
                        [loss, accuracy, y_true, y_pred],
                        feed_dict={
                            net.data: data_batch,
                            net.label: label_batch,
                            net.is_training: False,
                            net.keep_prob: 1.0
                        }
                    )
                    print(y1)
                    print(y2)
                    now_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                    print('%s iter: %d(%d) / %d(%d), total loss: %.4f, acc: %.2f, lr: %s'
                          % (now_time, (step + 1), (epoch + 1), max_step, hp.num_epoch, loss_value, acc, sess.run(lr)))
                if (step + 1) % hp.snapshot_iter == 0:
                    last_snapshot_iter = step
                    net.save(sess, saver, step)
        if last_snapshot_iter != step:
            net.save(sess, saver, step)

        sess.close()
Code Example #9
    def __init__(self,
                 root,
                 csv_path,
                 batch_size,
                 height=331,
                 width=331,
                 transformer_fn=None,
                 num_epochs=None,
                 shuffle=True,
                 min_after_dequeue=25,
                 allow_smaller_final_batch=False,
                 num_threads=2,
                 seed=None):
        root = root.replace('\\', '/')
        root = root if root[-1] == '/' else '{}/'.format(root)
        self.root = root
        self.csv_path = csv_path.replace('\\', '/')
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.transformer_fn = transformer_fn
        self.num_epochs = num_epochs
        self.shuffle = shuffle
        self.min_after_dequeue = min_after_dequeue
        self.allow_smaller_final_batch = allow_smaller_final_batch
        self.num_threads = num_threads
        self.seed = seed

        df = pd.read_csv(self.csv_path, header=None)

        if shuffle:
            df = df.sample(frac=1., random_state=seed)

        try:
            df[2] = df.apply(lambda line: line[2].index('y') + len(line[2])
                             if line[3] else line[2].index('y'),
                             axis=1)
        except Exception as e:
            print(e)

        print('{}: create session!'.format(self.__class__.__name__))
        self.batch_ops = {}
        self.n_sample = {}
        self.graph = tf.Graph()
        with self.graph.as_default():
            with tf.device('/cpu:0'):
                for attr_key, n_class in di.num_classes_v2.items():
                    data_path = df[df[1].isin([attr_key])][0].values
                    label = df[df[1].isin([attr_key])][2].values

                    batch_ops, n_sample = csv_batch(
                        root=self.root,
                        data_path=data_path,
                        label=label,
                        n_class=n_class,
                        batch_size=batch_size,
                        height=height,
                        width=width,
                        transformer_fn=transformer_fn,
                        num_epochs=num_epochs,
                        shuffle=shuffle,
                        min_after_dequeue=min_after_dequeue,
                        allow_smaller_final_batch=allow_smaller_final_batch,
                        num_threads=num_threads,
                        seed=seed)
                    self.batch_ops[attr_key] = batch_ops
                    self.n_sample[attr_key] = n_sample
                if num_epochs is not None:
                    self.init = tf.local_variables_initializer()
        self.sess = session(graph=self.graph)
        if num_epochs is not None:
            self.sess.run(self.init)
        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(sess=self.sess,
                                                    coord=self.coord)
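
A sketch of the batch() helper used in Code Example #1, assuming it simply runs the per-attribute batch op on the internal session (the project's real method may differ):

    def batch(self, attr_key):
        # returns (img_batch, label_batch, name_batch) for one attribute group
        return self.sess.run(self.batch_ops[attr_key])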
Code Example #10
                    if iteration % self.train_config["snapshot_interval"] == 1:
                        self._eval(iteration, self.train_config["validsize"],
                                   v_list, self.outdir)
                        self._eval(iteration, self.train_config["validsize"],
                                   v_fix_list, self.outdir_fix)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="BicycleGAN")
    parser.add_argument('--session',
                        type=str,
                        default='bicyclegan',
                        help="session name")
    parser.add_argument('--data_path',
                        type=Path,
                        help="path containing color images")
    parser.add_argument('--sketch_path',
                        type=Path,
                        help="path containing color images")
    args = parser.parse_args()

    outdir, outdir_fix, modeldir = session(args.session)

    with open("param.yaml", "r") as f:
        config = yaml.safe_load(f)
        pprint.pprint(config)

    trainer = Trainer(config, outdir, outdir_fix, modeldir, args.data_path,
                      args.sketch_path)
    trainer()
Code Example #11
File: data_loader.py  Project: yanghedada/FlowS-Unet
    def __init__(self, data_root, mask_root, batch_size, height=256, width=256, transformer_fn=None, num_epochs=None,
                 data_suffix='npy', mask_suffix='tif', shuffle=True,
                 min_after_dequeue=25, allow_smaller_final_batch=False, num_threads=2, seed=None):
        self.data_root = data_root
        self.mask_root = mask_root
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.transformer_fn = transformer_fn
        self.num_epochs = num_epochs
        self.data_suffix = data_suffix
        self.mask_suffix = mask_suffix
        self.shuffle = shuffle
        self.min_after_dequeue = min_after_dequeue
        self.allow_smaller_final_batch = allow_smaller_final_batch
        self.num_threads = num_threads
        self.seed = seed

        data_root = data_root.replace('\\', '/')
        data_root = data_root if data_root[-1] == '/' else '{0}/'.format(data_root)

        mask_root = mask_root.replace('\\', '/')
        mask_root = mask_root if mask_root[-1] == '/' else '{0}/'.format(mask_root)

        data_paths = glob.glob('{0}*.{1}'.format(data_root, data_suffix))
        data_paths = list(map(lambda x: x.replace('\\', '/'), data_paths))
        data_names = list(map(lambda x: x.split('/')[-1].split('.')[0], data_paths))
        mask_paths = list(map(lambda x: '{}{}.{}'.format(mask_root, x, mask_suffix), data_names))

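        # Shuffle the three lists in unison: re-seeding with the same value
        # before each shuffle produces the same permutation, keeping paths,
        # names and masks index-aligned.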
        if shuffle:
            random.seed(seed)
            randnum = random.randint(0, 2018)
            random.seed(randnum)
            random.shuffle(data_paths)
            random.seed(randnum)
            random.shuffle(data_names)
            random.seed(randnum)
            random.shuffle(mask_paths)

        # print(data_paths)
        # print(data_names)
        # print(mask_paths)

        print('{}: create session!'.format(self.__class__.__name__))
        self.graph = tf.Graph()
        with self.graph.as_default():
            with tf.device('/cpu:0'):
                self.batch_ops, self.n_sample = self.folder_batch(data_paths=data_paths, mask_paths=mask_paths,
                                                                  data_names=data_names, batch_size=batch_size,
                                                                  height=height, width=width,
                                                                  transformer_fn=transformer_fn, num_epochs=num_epochs,
                                                                  shuffle=shuffle, min_after_dequeue=min_after_dequeue,
                                                                  allow_smaller_final_batch=allow_smaller_final_batch,
                                                                  num_threads=num_threads, seed=seed)
                if num_epochs is not None:
                    self.init = tf.local_variables_initializer()
        self.sess = session(graph=self.graph)
        if num_epochs is not None:
            self.sess.run(self.init)
        self.coord = tf.train.Coordinator()
        self.threads = tf.train.start_queue_runners(sess=self.sess, coord=self.coord)