Code Example #1
def vgg_loss(image_a, image_b):
    vgg_a, vgg_b = Vgg19('vgg19.npy'), Vgg19('vgg19.npy')
    vgg_a.build(image_a)
    vgg_b.build(image_b)
    VGG_loss = tf.reduce_mean(
        tf.losses.absolute_difference(vgg_a.conv4_4, vgg_b.conv4_4))
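    # Note: tf.losses.absolute_difference already returns a mean over all
    # elements, so the division below additionally scales the loss down by
    # the conv4_4 feature-map size (h * w * c).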
    h = tf.cast(tf.shape(vgg_a.conv4_4)[1], tf.float32)
    w = tf.cast(tf.shape(vgg_a.conv4_4)[2], tf.float32)
    c = tf.cast(tf.shape(vgg_a.conv4_4)[3], tf.float32)
    VGG_loss = VGG_loss / (h * w * c)
    return VGG_loss
Code Example #2
File: baseline.py  Project: cjcchen/ML
def predict():

    config = img_config()

    w2d, d2w = data.get_word_to_id()

    #img = utils.load_image("/home/tusimple/junechen/ml_data/data/train2014/COCO_train2014_000000160629.jpg")
    img = utils.load_image("/home/tusimple/junechen/ml_data/data/train2014/COCO_train2014_000000318556.jpg")
    img = img.reshape((1, 224, 224, 3))

    images = tf.placeholder("float", [None, 224, 224, 3], name="image")

    with tf.name_scope("content_vgg"):
        cnn_net = Vgg19()
        cnn_net.build(images)

    with tf.device("/gpu:1"):
        image_feature = cnn_net.conv5_3 #[-1,14*14,512]

        cg = CaptionGenerator(w2d,n_time_step=config.seq_len)
        alphas, betas, sampled_captions =  cg.build_sampler()

    sv = tf.train.Supervisor(logdir=FLAGS.save_path)
    config_proto = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
    with sv.managed_session(config=config_proto) as sess:
        [feature] = sess.run( [image_feature], feed_dict={ images: img } )

        [image_label] = sess.run( [sampled_captions], feed_dict={
                                cg.features : feature.reshape(-1,196,512) } )

        print(image_label)
        print([d2w[p] for p in image_label[0]])
Code Example #3
    def _reverse_map_z(self, phi_z):
        """Reverse map z from phi(z)"""
        device = '/gpu:0' if self.FLAGS.gpu else '/cpu:0'
        # Open a second session on the selected device
        with tf.device(device):
            self._graph_var = tf.Graph()
            with self._graph_var.as_default():
                self._nn = Vgg19(model=self._model,
                                 input_placeholder=False,
                                 data_dir=self.FLAGS.data_dir,
                                 random_start=self.FLAGS.random_start,
                                 start_image_path=self.FLAGS.person_image)

                with tf.Session(graph=self._graph_var) as self._sess:
                    self._sess.run(tf.initialize_all_variables())

                    try:

                        self._conv_layer_tensors = [
                            self._graph_var.get_tensor_by_name(l)
                            for l in self._conv_layer_tensor_names
                        ]
                    except Exception as e:
                        raise Exception(
                            'Invalid layer names. Check out valid layer names '
                            'as defined in vgg19._init_empty_model() and make '
                            'sure this block has a relu layer. Exception: {}'.
                            format(e))

                    # Set z_tensor reference
                    self._z_tensor = self._nn.inputRGB

                    return self._optimize_z_tf(phi_z)
Code Example #4
def build_model(first_img_t,
                mid_img_t,
                end_img_t,
                vgg_data_dict=None,
                reuse_all=False):

    first_img_t = tf.cast(first_img_t, tf.float32) / 255.0
    mid_img_t = tf.cast(mid_img_t, tf.float32) / 255.0
    end_img_t = tf.cast(end_img_t, tf.float32) / 255.0

    assert vgg_data_dict is not None, 'Invalid vgg data dict'
    vgg = Vgg19(vgg_data_dict)
    pred_mid_img = model_interpolation(first_img_t,
                                       end_img_t,
                                       ctx_net=vgg,
                                       reuse=reuse_all)

    if loss_type == 'feature_reconstruct':
        pred_feat = vgg.relu4_4(pred_mid_img)
        gt_feat = vgg.relu4_4(mid_img_t)
        l1_loss = tf.reduce_mean(tf.square(gt_feat - pred_feat),
                                 axis=[0, 1, 2, 3])
    else:
        l1_loss = tf.losses.absolute_difference(pred_mid_img, mid_img_t)

    summary = [
        tf.summary.image('first_img', first_img_t),
        tf.summary.image('end_img', end_img_t),
        tf.summary.image('gt_mid_img', mid_img_t),
        tf.summary.image('pred_mid_img',
                         tf.clip_by_value(pred_mid_img, 0.0, 1.0)),
        tf.summary.scalar('l1_loss', l1_loss)
    ]

    return l1_loss, summary
Code Example #5
def create_vgg_net(inputs, vgg19_npy_path, output_classes=1000):
    vgg = Vgg19(
        vgg19_npy_path=vgg19_npy_path)  # Read model from pretrained vgg
    train_mode = tf.constant(False, dtype=tf.bool, name='train_mode')
    vgg.build(inputs, output_classes, train_mode=train_mode)
    vgg_19_net = vgg.net()
    return vgg_19_net
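A minimal usage sketch for this helper; the placeholder shape and the weights path are assumptions, not part of the original example:

    inputs = tf.placeholder(tf.float32, [None, 224, 224, 3], name='inputs')
    vgg_19_net = create_vgg_net(inputs, './vgg19.npy')  # hypothetical weights path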
Code Example #6
File: model.py  Project: vicyor/bysj
 def __init__(self, target_layer, content_path, style_path, alpha, pretrained_vgg, output_path, decoder_weights):
     self.target_layer = target_layer
     self.content_path = content_path
     self.style_path = style_path
     self.output_path = output_path
     self.alpha = alpha
     self.encoder = Vgg19(pretrained_vgg)
     self.decoder = Decoder()  
     self.decoder_weights = decoder_weights
Code Example #7
File: baseline.py  Project: cjcchen/ML
def train():

    config = img_config()
    #config.batch_size=1
    f, image,label,word, target, w2d,d2w = data.get_data(FLAGS.caption_path, FLAGS.image_path, max_len=config.seq_len+1, batch_size=config.batch_size)

    with tf.name_scope("content_vgg"):
        cnn_net = Vgg19()
        cnn_net.build(image)

    image_feature = cnn_net.conv5_3 #[-1,14*14,512]

    cg = CaptionGenerator(w2d,n_time_step=config.seq_len)
    loss = cg.build_model()

    tvars = tf.trainable_variables()
    grads, _= tf.clip_by_global_norm(tf.gradients(loss, tvars), config.max_grad_norm)
    lr = tf.Variable(0.0, trainable=False)
    new_lr = tf.Variable(0.0, trainable=False)

    lr_op=tf.assign(lr,new_lr)

    optimizer = tf.train.GradientDescentOptimizer(lr)
    train_op = optimizer.apply_gradients(
            zip(grads, tvars),
            global_step=tf.train.get_or_create_global_step())


    sv = tf.train.Supervisor(logdir=FLAGS.save_path)
    config_proto = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
    with sv.managed_session(config=config_proto) as sess:
        threads = tf.train.start_queue_runners(sess)

        summary_writer = tf.summary.FileWriter(FLAGS.log_path,sess.graph)

        for i in range(config.max_max_epoch):
            x_lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0)
            print ("lr:",x_lr_decay)
            _,c_lr=sess.run([lr_op,lr], feed_dict={new_lr:config.learning_rate * x_lr_decay})

            for j in range(414113 // config.batch_size):
                [img_f, feature,img_label] = sess.run( [f, image_feature, label] )

                _, c_loss = sess.run( [ train_op, loss ], feed_dict={
                                cg.features : feature.reshape(-1,196,512),
                                cg.captions : img_label})
                if j % 10 == 0:
                    summary = tf.Summary()
                    summary.value.add(tag='loss', simple_value=c_loss)
                    i_global=sess.run(tf.train.get_or_create_global_step())
                    print ("cost %f global step %d" % (c_loss, i_global))
                    summary_writer.add_summary(summary, i_global)#write eval to tensorboard
                if j % 100 == 0:
                    print(img_f)
                    print(img_label)
                    print([d2w[p] for p in img_label[0]])
Code Example #8
File: model.py  Project: vicyor/bysj
 def __init__(self, content_path, style_path, alpha, pretrained_vgg, output_path):
     self.content_path = content_path
     self.style_path = style_path
     self.output_path = output_path
     self.alpha = alpha
     self.encoder = Vgg19(pretrained_vgg)
     self.decoder = Decoder()
     # add the fifth weight
     # add -> models/decoder_5.ckpt
     self.decoder_weights = [
         'models/decoder_1.ckpt', 'models/decoder_2.ckpt',
         'models/decoder_3.ckpt', 'models/decoder_4.ckpt',
         'models/model.ckpt-15004'
     ]
Code Example #9
    def __init__(self, vgg19_npy_path):
        self.vgg19 = Vgg19(vgg19_npy_path)
        self.gpyr = self.getGaussianPyramid(5)

        self.source_inputs = tf.placeholder(dtype=tf.float32)
        self.source_outputs = self.gpyr(self.source_inputs)

        self.source_encoded = []
        for output in self.source_outputs:
            self.source_encoded.extend(self.vgg19.build(output))
Code Example #10
File: model.py  Project: vicyor/bysj
 def __init__(self, target_layer=None, pretrained_path=None, max_iterator=None, checkpoint_path=None, tfrecord_path=None, batch_size=None):
     self.pretrained_path = pretrained_path
     self.target_layer = target_layer
     # VGG encoder
     self.encoder = Vgg19(self.pretrained_path)
     # maximum number of iterations
     self.max_iterator = max_iterator
     self.checkpoint_path = checkpoint_path
     self.tfrecord_path = tfrecord_path
     self.batch_size = batch_size
Code Example #11
File: loss.py  Project: ahxc/CartoonGAN
def vgg_loss(real, fake):
    vgg = Vgg19('vgg19.npy')

    vgg.build(real)
    real_feature_map = vgg.conv4_4_no_activation

    vgg.build(fake)
    fake_feature_map = vgg.conv4_4_no_activation

    return L1_loss(real_feature_map, fake_feature_map)
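`L1_loss` is defined elsewhere in the project; a plausible equivalent, assuming it is the mean absolute difference between the two feature maps:

    def L1_loss(x, y):
        # Mean absolute difference (an assumption; the actual helper lives
        # elsewhere in ahxc/CartoonGAN).
        return tf.reduce_mean(tf.abs(x - y))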
Code Example #12
File: DeblurGAN.py  Project: JYongSmile/DeblurGAN
    def build_graph(self):

        # if self.in_memory:
        self.blur = tf.placeholder(name="blur", shape=[None, None, None, self.channel], dtype=tf.float32)
        self.sharp = tf.placeholder(name="sharp", shape=[None, None, None, self.channel], dtype=tf.float32)

        x = self.blur
        label = self.sharp

        self.epoch = tf.placeholder(name='train_step', shape=None, dtype=tf.int32)

        x = (2.0 * x / 255.0) - 1.0
        label = (2.0 * label / 255.0) - 1.0

        self.gene_img = self.generator(x, reuse=False)
        self.real_prob = self.discriminator(label, reuse=False)
        self.fake_prob = self.discriminator(self.gene_img, reuse=True)

        epsilon = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0.0, maxval=1.0)

        interpolated_input = epsilon * label + (1 - epsilon) * self.gene_img
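        # WGAN-GP gradient penalty: penalize the critic when its gradient
        # norm at these random interpolations deviates from 1.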
        gradient = tf.gradients(self.discriminator(interpolated_input, reuse=True), [interpolated_input])[0]
        GP_loss = tf.reduce_mean(tf.square(tf.sqrt(tf.reduce_mean(tf.square(gradient), axis=[1, 2, 3])) - 1))

        d_loss_real = - tf.reduce_mean(self.real_prob)
        d_loss_fake = tf.reduce_mean(self.fake_prob)

        self.vgg_net = Vgg19(self.vgg_path)
        self.vgg_net.build(tf.concat([label, self.gene_img], axis=0))
        self.content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(
            self.vgg_net.relu3_3[self.batch_size:] - self.vgg_net.relu3_3[:self.batch_size]), axis=3))

        self.D_loss = d_loss_real + d_loss_fake + 10.0 * GP_loss
        self.G_loss = - d_loss_fake + 100.0 * self.content_loss

        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        lr = tf.minimum(self.learning_rate, tf.abs(2 * self.learning_rate - (
                self.learning_rate * tf.cast(self.epoch, tf.float32) / self.decay_step)))
        self.D_train = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.D_loss, var_list=D_vars)
        self.G_train = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.G_loss, var_list=G_vars)

        self.PSNR = tf.reduce_mean(tf.image.psnr(((self.gene_img + 1.0) / 2.0), ((label + 1.0) / 2.0), max_val=1.0))
        self.ssim = tf.reduce_mean(tf.image.ssim(((self.gene_img + 1.0) / 2.0), ((label + 1.0) / 2.0), max_val=1.0))

        logging_D_loss = tf.summary.scalar(name='D_loss', tensor=self.D_loss)
        logging_G_loss = tf.summary.scalar(name='G_loss', tensor=self.G_loss)
        logging_PSNR = tf.summary.scalar(name='PSNR', tensor=self.PSNR)
        logging_ssim = tf.summary.scalar(name='ssim', tensor=self.ssim)

        self.output = (self.gene_img + 1.0) * 255.0 / 2.0
        self.output = tf.round(self.output)
        self.output = tf.cast(self.output, tf.uint8)
Code Example #13
File: train.py  Project: sunyasheng/iter_vstab
def build_model(first_img_t,
                mid_img_t,
                end_img_t,
                s_img_t,
                vgg_data_dict=None,
                reuse_all=False):

    first_img_t = tf.cast(first_img_t, tf.float32) / 255.0
    mid_img_t = tf.cast(mid_img_t, tf.float32) / 255.0
    end_img_t = tf.cast(end_img_t, tf.float32) / 255.0
    s_img_t = tf.cast(s_img_t, tf.float32) / 255.0

    assert vgg_data_dict is not None, 'Invalid vgg data dict'
    vgg = Vgg19(vgg_data_dict)
    img_int, img_out, [warped_first, warped_end, warped_mid
                       ] = training_stab_model(first_img_t, s_img_t, end_img_t,
                                               mid_img_t)

    int_feat = vgg.relu4_4(img_int)
    out_feat = vgg.relu4_4(img_out)
    s_feat = vgg.relu4_4(s_img_t)

    vgg_int_loss = tf.reduce_mean(tf.square(s_feat - int_feat),
                                  axis=[0, 1, 2, 3])
    vgg_out_loss = tf.reduce_mean(tf.square(s_feat - out_feat),
                                  axis=[0, 1, 2, 3])

    l1_int_loss = tf.losses.absolute_difference(img_int, s_img_t)
    l1_out_loss = tf.losses.absolute_difference(img_out, s_img_t)

    summary = [
        tf.summary.image('first_img', first_img_t),
        tf.summary.image('end_img', end_img_t),
        tf.summary.image('mid_img', mid_img_t),
        tf.summary.image('s_img', s_img_t),
        tf.summary.image('warped_first', tf.clip_by_value(warped_first, 0, 1)),
        tf.summary.image('warped_end', tf.clip_by_value(warped_end, 0, 1)),
        tf.summary.image('warped_mid', tf.clip_by_value(warped_mid, 0, 1)),
        tf.summary.image('int_img', tf.clip_by_value(img_int, 0, 1)),
        tf.summary.image('out_img', tf.clip_by_value(img_out, 0, 1)),
        tf.summary.scalar('vgg_int_loss', vgg_int_loss),
        tf.summary.scalar('vgg_out_loss', vgg_out_loss),
        tf.summary.scalar('l1_int_loss', l1_int_loss),
        tf.summary.scalar('l1_out_loss', l1_out_loss)
    ]
    tot_loss = vgg_int_loss + vgg_out_loss + \
                l1_int_loss + l1_out_loss
    # tot_loss = vgg_int_loss + l1_int_loss
    # tot_loss = l1_int_loss + l1_out_loss
    return tot_loss, summary
Code Example #14
    def build_model(self):
        img_in, img_gt = self.input_producer(self.batch_size)

        #tf.summary.image('img_in', im2uint8(img_in))
        #tf.summary.image('img_gt', im2uint8(img_gt))
        print('img_in, img_gt', img_in.get_shape(), img_gt.get_shape())

        # generator
        x_unwrap = self.generator(img_in, reuse=False, scope='g_net')
        # calculate multi-scale loss
        self.loss_total = 0
        for i in range(self.n_levels):
            _, hi, wi, _ = x_unwrap[i].get_shape().as_list()
            gt_i = tf.image.resize_images(img_gt, [hi, wi], method=0)
            loss = tf.reduce_mean((gt_i - x_unwrap[i])**2)  #MSE loss
            #perceptual_loss
            vgg_net = Vgg19(self.vgg_path)
            vgg_net.build(tf.concat([gt_i, x_unwrap[i]], axis=0))
            perceptual_loss = tf.reduce_mean(
                tf.reduce_sum(tf.square(vgg_net.relu3_3[self.batch_size:] -
                                        vgg_net.relu3_3[:self.batch_size]),
                              axis=3))
            self.loss_total += (loss + perceptual_loss * 0.01)

            tf.summary.image('out_' + str(i), im2uint8(x_unwrap[i]))
            tf.summary.scalar('loss_' + str(i), loss)

        # losses
        tf.summary.scalar('loss_total', self.loss_total)

        self.PSNR = tf.reduce_mean(
            tf.image.psnr(((x_unwrap[0] + 1.0) / 2.0), ((img_gt + 1.0) / 2.0),
                          max_val=1.0))
        self.ssim = tf.reduce_mean(
            tf.image.ssim(((x_unwrap[0] + 1.0) / 2.0), ((img_gt + 1.0) / 2.0),
                          max_val=1.0))
        tf.summary.scalar(name='PSNR', tensor=self.PSNR)
        tf.summary.scalar(name='SSIM', tensor=self.ssim)

        self.output = (x_unwrap[0] + 1.0) * 255.0 / 2.0
        self.output = tf.round(self.output)
        self.output = tf.cast(self.output, tf.uint8)

        # training vars
        all_vars = tf.trainable_variables()
        self.all_vars = all_vars
        self.g_vars = [var for var in all_vars if 'g_net' in var.name]
        self.lstm_vars = [var for var in all_vars if 'LSTM' in var.name]
        for var in all_vars:
            print(var.name)
Code Example #15
File: model.py  Project: nihaotiancai/style_transform
 def __init__(self, content_path, style_path, alpha_wct, alpha_swap,
              pretrained_vgg, output_path):
     self.content_path = content_path  # path to the content image
     self.style_path = style_path  # path to the style image
     self.output_path = output_path  # output path for the blended image
     self.alpha_wct = alpha_wct  # content/style weighting for the WCT method
     self.alpha_swap = alpha_swap  # content/style weighting for the WCT-swap method
     self.encoder = Vgg19(pretrained_vgg)  # load the VGG19 model
     self.decoder = Decoder()  # load the Decoder (deconvolution)
     # load the Decoder model weights
     self.decoder_weights = [
         'models/decoder_1.ckpt', 'models/decoder_2.ckpt',
         'models/decoder_3.ckpt', 'models/decoder_4.ckpt'
     ]
Code Example #16
def build_test_model(first_img_p,
                     end_img_p,
                     vgg_data_dict=None,
                     reuse=False,
                     training=False):
    first_img_p = tf.cast(first_img_p, tf.float32) / 255.0
    end_img_p = tf.cast(end_img_p, tf.float32) / 255.0
    assert vgg_data_dict is not None, 'Invalid vgg data dict'
    vgg = Vgg19(vgg_data_dict)
    pred_mid_img = model_interpolation(first_img_p,
                                       end_img_p,
                                       ctx_net=vgg,
                                       reuse=reuse,
                                       training=training)
    pred_mid_img = tf.clip_by_value(pred_mid_img, 0., 1.)
    return pred_mid_img
Code Example #17
File: validate_image.py  Project: hjyai94/VGG
def test_image(path_image, num_class):
    img_string = tf.read_file(path_image)
    img_decoded = tf.image.decode_png(img_string, channels=3)
    img_resized = tf.image.resize_images(img_decoded, [224, 224])
    img_resized = tf.reshape(img_resized, shape=[1, 224, 224, 3])
    model = Vgg19(bgr_image=img_resized, num_class=num_class, vgg19_npy_path='./vgg19.npy')
    score = model.fc8
    prediction = tf.argmax(score, 1)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, "./tmp/checkpoints/model_epoch100.ckpt")

        plt.imshow(img_decoded.eval())
        plt.title("Class:" + class_name[sess.run(prediction)[0]])
        plt.show()
Code Example #18
    def build_loss(self):
        # Compute losses.
        self.mse = tf.losses.mean_squared_error(labels=self.bg_img,
                                                predictions=self.output)

        perceptron = Vgg19(cfg.vgg_dir)
        perceptron.build(tf.concat([self.bg_img, self.output], axis=0))
        self.content_loss = tf.losses.mean_squared_error(
            perceptron.conv3_4[:cfg.batch_size],
            perceptron.conv3_4[cfg.batch_size:])

        self.ssim = tf.reduce_mean(
            tf.image.ssim(self.bg_img, self.output, max_val=255.0))
        self.psnr = tf.reduce_mean(
            tf.image.psnr(self.bg_img, self.output, max_val=255.0))
        self.total_loss = self.mse + 1e-3 * self.content_loss
Code Example #19
File: baseline_train.py  Project: cjcchen/ML
def main():
    train_caption_path = "/home/tusimple/junechen/ml_data/data/annotations/captions_train2014.json"
    train_image_path = "/home/tusimple/junechen/ml_data/data/train2014/"

    f, image, label, word, target, w2d, d2w = data.get_data(train_caption_path,
                                                            train_image_path,
                                                            max_len=16 + 1,
                                                            batch_size=50,
                                                            mode='train')

    cnn_net = Vgg19()
    cnn_net.build(image)
    image_feature = cnn_net.conv5_3  #[-1,14*14,512]

    print("cnn build done")
    model = CaptionGenerator(image_feature,
                             label,
                             w2d,
                             dim_feature=[196, 512],
                             dim_embed=512,
                             dim_hidden=1024,
                             n_time_step=16,
                             prev2out=True,
                             ctx2out=True,
                             alpha_c=1.0,
                             selector=True,
                             dropout=True)

    print("build done")
    solver = CaptioningSolver(model,
                              data,
                              None,
                              n_epochs=20,
                              batch_size=128,
                              update_rule='adam',
                              learning_rate=0.0005,
                              print_every=1000,
                              save_every=1,
                              image_path='./image/')

    solver.train()
Code Example #20
File: model.py  Project: nihaotiancai/style_transform
 def __init__(self,
              target_layer=None,
              pretrained_path=None,
              max_iterator=None,
              checkpoint_path=None,
              tfrecord_path=None,
              batch_size=None):
     self.pretrained_path = pretrained_path  # path where the pretrained model is stored
     self.target_layer = target_layer  # name of the target layer
     # load the vgg19 model by calling Vgg19
     self.encoder = Vgg19(self.pretrained_path)
     self.max_iterator = max_iterator  # maximum number of training iterations
     self.checkpoint_path = checkpoint_path  # path where the decoder is stored
     self.tfrecord_path = tfrecord_path  # path where the tfrecords are stored
     self.batch_size = batch_size  # training batch size
     # preset training dictionary, {name: ckpt filename}
     self.save_model_dir = {
         'relu1': 'decoder_1.ckpt',
         'relu2': 'decoder_2.ckpt',
         'relu3': 'decoder_3.ckpt',
         'relu4': 'decoder_4.ckpt',
         'relu5': 'decoder_5.ckpt',
     }
Code Example #21
def model(content_path, style_path, alpha, beta, lambdas, num_iterations,
          initial_learning_rate, learning_rate_decay):

    # Code to make Tensorflow release GPU Memory after it's done
    tf.reset_default_graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Read the content and style images, normalize them and also save the normalizing values
    content_means, content_variances, content_image = readAndNormalize(
        content_path)
    style_means, style_variances, style_image = readAndNormalize(style_path)

    # Compute the weighted average of the means and variances
    means = (alpha * content_means + beta * style_means) / (alpha + beta)
    variances = (alpha * content_variances + beta * style_variances) / (alpha +
                                                                        beta)

    # Add another dimension to the images because that's what the model needs
    content_image = content_image[np.newaxis, ...].astype(dtype=np.float32)
    style_image = style_image[np.newaxis, ...].astype(dtype=np.float32)
    target_means, target_variances, target_image = addGausianNoise(
        content_image[0, :, :, :])

    # Create the 3 model instances we need
    content_model = Vgg19()
    style_model = Vgg19()
    target_model = Vgg19()

    # Create the dictionaries we'll use to store the values computed
    style_activations = dict()
    target_style_activations = dict()

    # Define the target image to be a Tensorflow variable and initiate it to the noisy image
    target_image_variable = tf.get_variable(name='target_activation',
                                            dtype=tf.float32,
                                            initializer=target_image)

    # Build the instance for the content and calculate the activation for a layer
    content_model.build(content_image)
    content_activation = content_model.conv3_1

    # Build the instance for the style and calculate the activations for the layers we want
    style_model.build(style_image)
    style_activations['conv2_1'] = style_model.conv2_1
    style_activations['conv2_2'] = style_model.conv2_2
    style_activations['conv3_1'] = style_model.conv3_1
    style_activations['conv3_2'] = style_model.conv3_2

    # Build the instance for the target image and calculate the activations for the layers we used for CONTENT and STYLE
    # Make sure that you use the same layers as you did for the content and style activations
    target_model.build(target_image_variable)
    target_content_activation = target_model.conv3_1
    target_style_activations['conv2_1'] = target_model.conv2_1
    target_style_activations['conv2_2'] = target_model.conv2_2
    target_style_activations['conv3_1'] = target_model.conv3_1
    target_style_activations['conv3_2'] = target_model.conv3_2

    # Calculate the costs using the functions defined in the utils file
    content_cost = contentCost(content_activation, target_content_activation)
    style_cost = styleCost(style_activations, target_style_activations,
                           lambdas)
    total_cost = totalCost(content_cost, style_cost, alpha, beta)

    # Define the optimizer
    learning_rate = tf.placeholder(name='learning_rate', dtype=tf.float32)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_step = optimizer.minimize(total_cost)

    # Start the session using the config settings defined at the beginning of the file
    with tf.Session(config=config) as session:

        # Initialize the variable
        session.run(tf.global_variables_initializer())

        # Run the loop for the specified number of times
        for iteration in range(num_iterations):

            # Run one iteration of the training step
            session.run(train_step,
                        feed_dict={
                            learning_rate:
                            initial_learning_rate /
                            (1 + learning_rate_decay * iteration)
                        })

            # Print the cost every X iterations
            if iteration % 10 == 0:
                current_content_cost, current_style_cost = session.run(
                    [content_cost, style_cost])
                print("Iteration number {}: {} ~ 10^{}  <--->  {} ~ 10^{} ".
                      format(iteration, current_content_cost,
                             np.floor(np.log10(current_content_cost)),
                             current_style_cost,
                             np.floor(np.log10(current_style_cost))))

        # Save the final output
        output, final_style_cost, final_content_cost = session.run(
            [target_image_variable, style_cost, content_cost])

        # Revert the normalization
        output = output[0, :, :, :] * content_variances + content_means

        # Save the image
        content_path = content_path[:-4]
        content_path = content_path[17:]
        style_path = style_path[:-4]
        style_path = style_path[15:]
        name = 'generated images/{}+{}.jpg'.format(content_path, style_path)
        cv2.imwrite(name, output)
        print(
            name +
            " was created. Final content cost was {}; Final style cost was {}."
            .format(final_content_cost, final_style_cost))

        # Close the session and free the unused memory
        session.close()
        gc.collect()
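The cost helpers (`contentCost`, `styleCost`, `totalCost`) come from the project's utils file and are not shown. For orientation, a plausible `contentCost` under the standard squared-error formulation of Gatys et al. (the project's actual implementation may differ):

    def contentCost(content_activation, target_activation):
        # Mean squared difference between content and target activations.
        return tf.reduce_mean(tf.square(content_activation - target_activation))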
Code Example #22

def preprocess_image(image_path):

    img = load_image(image_path)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    img = transform(img)
    return img


vgg1 = Vgg19(requires_grad=False)

vgg1.cuda()

image_path = 'result.jpg'

style_path = 'style.jpg'

photo_path = 'photo.jpg'

style_img = preprocess_image(style_path)

content_img = preprocess_image(photo_path)

# img = preprocess_image(image_path)
Code Example #23
    def build_graph(self):
        
        self.out = tf.placeholder(name = "out", shape = [None, None, None, self.channel], dtype = tf.float32)
        self.inp = tf.placeholder(name = "inp", shape = [None, None, None, self.channel+self.focus_channel], dtype = tf.float32)
        self.delta = tf.placeholder(tf.float32, name="refocus_parameter", shape = [None,self.z_dim])

        x = self.inp
        label = self.out
        delta = self.delta
        
        self.epoch = tf.placeholder(name = 'train_step', shape = None, dtype = tf.int32)
        
        x = (2.0 * x / 255.0) - 1.0
        label = (2.0 * label / 255.0) - 1.0
        
        self.gene_img = self.generator(x, delta, reuse = False)

        #Input to the discriminator is the real/label image
        self.real_prob = self.discriminator(label, reuse = False)
        
        #Input to the discriminator is the image generated from the generator
        self.fake_prob = self.discriminator(self.gene_img, reuse = True)
        
        epsilon = tf.random_uniform(shape = [self.batch_size, 1, 1, 1], minval = 0.0, maxval = 1.0)
        interpolated_input = epsilon * label + (1 - epsilon) * self.gene_img
        gradient = tf.gradients(self.discriminator(interpolated_input, reuse = True), [interpolated_input])[0]
        GP_loss = tf.reduce_mean(tf.square(tf.sqrt(tf.reduce_mean(tf.square(gradient), axis = [1, 2, 3])) - 1))
        
        d_loss_real = - tf.reduce_mean(self.real_prob)
        d_loss_fake = tf.reduce_mean(self.fake_prob)
        
        """
            LOSS FUNCTIONS
            1.Perceptual Loss or Content loss
            3.Adversarial loss
            4.GP_loss

            Generator loss     = Adversarial loss + 100*Content loss 
        """ 
        if self.mode == 'train':
            self.vgg_net = Vgg19(self.vgg_path)
            self.vgg_net.build(tf.concat([label, self.gene_img], axis = 0))
            self.content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(self.vgg_net.relu3_3[self.batch_size:] - self.vgg_net.relu3_3[:self.batch_size]), axis = 3))

            self.D_loss = d_loss_real + d_loss_fake + 10.0 * GP_loss
            self.G_loss = - d_loss_fake + 100.0 * self.content_loss

            t_vars = tf.trainable_variables()
            G_vars = [var for var in t_vars if 'generator' in var.name]
            D_vars = [var for var in t_vars if 'discriminator' in var.name]

            lr = tf.minimum(self.learning_rate, tf.abs(2 * self.learning_rate - (self.learning_rate * tf.cast(self.epoch, tf.float32) / self.decay_step)))
            self.D_train = tf.train.AdamOptimizer(learning_rate = lr).minimize(self.D_loss, var_list = D_vars)
            self.G_train = tf.train.AdamOptimizer(learning_rate = lr).minimize(self.G_loss, var_list = G_vars)
            
            logging_D_loss = tf.summary.scalar(name = 'D_loss', tensor = self.D_loss)   
            logging_G_loss = tf.summary.scalar(name = 'G_loss', tensor = self.G_loss)
        
        self.output = (self.gene_img + 1.0) * 255.0 / 2.0
        self.output = tf.round(self.output)
        self.output = tf.cast(self.output, tf.uint8)
Code Example #24
def train():
    args = Train_Args()
    train_start = time.time()
    time_start = time.time()

    checkpoint_dir = args.model_save_path
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    srcfile = args.model_filename
    dstfile = args.model_save_path + srcfile
    shutil.copyfile(srcfile, dstfile)
    shutil.copyfile(args.args_filename,
                    args.model_save_path + args.args_filename)

    train_x, train_y, test_x, test_y = load_train_005_6_shuffle()
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # the second GPU
    best_train_accuracy = 0
    best_test_accuracy = 0

    with tf.Session() as sess:
        deepem = Vgg19(args)
        # deepem = dec14(args)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
        sess.run(tf.global_variables_initializer())
        print("train size is %d " % len(train_x), flush=True)
        for e in range(args.num_epochs):
            print('\n=============== Epoch %d/%d ===============' %
                  (e + 1, args.num_epochs),
                  flush=True)
            num_batch = len(train_x) // args.batch_size
            print("num_batch is %d" % num_batch, flush=True)
            acc_train = []
            for i in range(num_batch):
                batch_x = train_x[args.batch_size * i:args.batch_size *
                                  (i + 1)]
                batch_y = train_y[args.batch_size * i:args.batch_size *
                                  (i + 1)]
                loss, accuracy, lr, _ = sess.run([
                    deepem.loss, deepem.accuracy, deepem.lr, deepem.optimizer
                ], {
                    deepem.X: batch_x,
                    deepem.Y: batch_y
                })
                acc_train.append(accuracy)
                if i % 10 == 0:
                    print('lr: %.8f loss: %.6f  acc: %.6f' %
                          (lr, loss, accuracy),
                          flush=True)

            train_acc = np.mean(acc_train)
            if train_acc > best_train_accuracy:
                best_train_accuracy = train_acc

            print("avg acc: %.6f" % train_acc)
            print("best_train_accuracy: %.6f" % best_train_accuracy,
                  flush=True)

            if e % 10 == 0 or e == args.num_epochs - 1:
                print("\ntesting start.", flush=True)
                num_batch = len(test_x) // args.batch_size
                print("num_batch is %d" % num_batch, flush=True)
                acc_test = []
                for i in range(num_batch):
                    batch_x = test_x[args.batch_size * i:args.batch_size *
                                     (i + 1)]
                    batch_y = test_y[args.batch_size * i:args.batch_size *
                                     (i + 1)]
                    accuracy = sess.run(deepem.accuracy,
                                        feed_dict={
                                            deepem.X: batch_x,
                                            deepem.Y: batch_y
                                        })
                    print('acc: %.6f' % (accuracy), flush=True)
                    acc_test.append(accuracy)

                acc = np.mean(acc_test)
                if acc > best_test_accuracy:
                    best_test_accuracy = acc
                    ckpt_path = os.path.join(checkpoint_dir, 'model.ckpt')
                    saver.save(sess, ckpt_path, global_step=e)
                    print("model saved!")
                print("avg acc: %.6f" % np.mean(acc))
                print("best_test_accuracy: %.6f" % best_test_accuracy,
                      flush=True)

    output_name = args.model_save_path + "accuracy.txt"
    output = open(output_name, 'w')
    output.write("train accuracy: " + str(best_train_accuracy) + '\n')
    output.write("test accuracy: " + str(best_test_accuracy) + '\n')
    output.close()
    train_end = time.time()
    print("\ntrain done! totally cost: %.5f \n" % (train_end - train_start),
          flush=True)
    print("best_test_accuracy: %.6f" % best_test_accuracy, flush=True)
Code Example #25
imgs_path = [
    "./img-airplane-224x224.jpg", "./img-guitar-224x224.jpg",
    "./img-puzzle-224x224.jpg", "./img-tatoo-plane-224x224.jpg",
    "./img-dog-224x224.jpg", "./img-paper-plane-224x224.jpg",
    "./img-pyramid-224x224.jpg", "./img-tiger-224x224.jpg"
]
imgs = utils.load_images(*imgs_path)
print("The input image(s) is loaded.")
for i, img in enumerate(imgs_path):
    print("%d %s" % (i, img))
print("")

# Design the graph.
graph = tf.Graph()
with graph.as_default():
    nn = Vgg19(model=model)

# Run the graph in the session.
with tf.Session(graph=graph) as sess:
    tf.initialize_all_variables().run()
    print("Tensorflow initialized all variables.")

    # The OP to write logs to Tensorboard.
    summary_writer = tf.train.SummaryWriter(test_lib.tf_log_path,
                                            graph=sess.graph)

    preds = sess.run(nn.preds, feed_dict={nn.inputRGB: imgs})
    print("There you go the predictions.")

    for i, pred in enumerate(preds):
        utils.print_prediction(pred, test_lib.label_vgg, img_path=imgs_path[i])
Code Example #26
    def _get_phi_z_const(self):
        """
        Calculates the constant phi(z) = phi(x) + alpha * w
        :return: phi(z) = phi(x) + alpha * w
        """

        device = '/gpu:0' if self.FLAGS.gpu else '/cpu:0'
        if self.FLAGS.rebuild_cache or not os.path.isfile('cache.ch.npy'):
            print('Using device: {}'.format(device))
            with tf.device(device):
                self._graph_ph = tf.Graph()
                with self._graph_ph.as_default():
                    self._nn = Vgg19(model=self._model,
                                     input_placeholder=True,
                                     start_image_path=self.FLAGS.person_image)

                    with tf.name_scope('DFI-Graph') as scope:
                        # Run the graph in the session.
                        with tf.Session(graph=self._graph_ph) as self._sess:
                            self._sess.run(tf.initialize_all_variables())

                            self._conv_layer_tensors = [
                                self._graph_ph.get_tensor_by_name(l)
                                for l in self._conv_layer_tensor_names
                            ]

                            if self.FLAGS.discrete_knn:
                                atts = load_discrete_lfw_attributes(
                                    self.FLAGS.data_dir)
                            else:
                                atts = load_lfw_attributes(self.FLAGS.data_dir)

                            imgs_path = atts['path'].values

                            if self.FLAGS.person_image:
                                start_img_path = self.FLAGS.person_image
                            else:
                                start_img_path = imgs_path[
                                    self.FLAGS.person_index]

                            person_index = get_person_idx_by_path(
                                atts, start_img_path)

                            start_img = reduce_img_size(load_images(start_img_path))[0]
                            plt.imsave(fname='start_img.png', arr=start_img)

                            # Get image paths
                            pos_paths, neg_paths = self._get_sets(
                                atts, self.FLAGS.feature, person_index)

                            # Reduce image sizes
                            pos_imgs = reduce_img_size(load_images(*pos_paths))
                            neg_imgs = reduce_img_size(load_images(*neg_paths))

                            # Get pos/neg deep features
                            pos_deep_features = self._phi(pos_imgs)
                            neg_deep_features = self._phi(neg_imgs)

                            # Calc W
                            w = np.mean(pos_deep_features, axis=0) - np.mean(
                                neg_deep_features, axis=0)
                            w /= np.linalg.norm(w)

                            inv = -1 if self.FLAGS.invert else 1

                            # Calc phi(z)
                            phi = self._phi(start_img)
                            phi_z = phi + self.FLAGS.alpha * w * inv
                            np.save('cache.ch', phi_z)
        else:
            print('Loading cached phi_z')
            phi_z = np.load('cache.ch.npy')

        return phi_z
Code Example #27
def training(args):

    dtype = torch.float64
    if args.gpu:
        use_cuda = True
        print("Current device: %d" % torch.cuda.current_device())
        dtype = torch.cuda.FloatTensor

    print('content = {}'.format(args.content))
    print('style = {}'.format(args.style))

    img_transform = transforms.Compose([
        transforms.Grayscale(3),
        transforms.Resize(
            args.image_size,
            interpolation=Image.NEAREST),  # scale shortest side to image_size
        transforms.CenterCrop(args.image_size),  # crop center image_size out
        transforms.ToTensor(),  # turn image from [0-255] to [0-1]
        utils.normalize_imagenet(norm)  # normalize with ImageNet values
    ])

    content = Image.open(args.content)
    content = img_transform(content)  # Loaded already cropped
    content = Variable(content.repeat(1, 1, 1, 1),
                       requires_grad=False).type(dtype)

    # define network
    image_transformer = ImageTransformNet().type(dtype)
    optimizer = Adam(image_transformer.parameters(), 1e-5)

    loss_mse = torch.nn.MSELoss()

    # load vgg network
    vgg = Vgg19().type(dtype)

    # get training dataset
    dataset_transform = transforms.Compose([
        transforms.Grayscale(3),
        transforms.Resize(
            args.image_size,
            interpolation=Image.NEAREST),  # scale shortest side to image_size
        transforms.CenterCrop(args.image_size),  # crop center image_size out
        transforms.ToTensor(),  # turn image from [0-255] to [0-1]
        utils.normalize_imagenet(norm)  # normalize with ImageNet values
    ])
    train_dataset = datasets.ImageFolder(args.dataset, dataset_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)

    # style image
    style_transform = transforms.Compose([
        transforms.Grayscale(3),
        transforms.ToTensor(),  # turn image from [0-255] to [0-1]
        utils.normalize_imagenet(norm)  # normalize with ImageNet values
    ])

    style = Image.open(args.style)
    if "clinical" in args.style:
        style = style.crop(
            (20, 0, style.size[0],
             style.size[1]))  # Remove left bar from the style image
    style = style_transform(style)
    style = Variable(style.repeat(args.batch_size, 1, 1, 1)).type(dtype)

    # calculate gram matrices for target style layers
    style_features = vgg(style)
    style_gram = [utils.gram(feature) for feature in style_features]

    if args.loss == 1:
        print("Using average style on features")
        if "clinical" in args.style:
            with open('models/perceptual/us_clinical_ft_dict.pickle',
                      'rb') as handle:
                style_features = pickle.load(handle)
        else:
            with open('models/perceptual/us_hq_ft_dict.pickle',
                      'rb') as handle:
                style_features = pickle.load(handle)
        style_features = [
            style_features[label].type(dtype)
            for label in style_features.keys()
        ]
        style_gram = [utils.gram(feature) for feature in style_features]

    style_loss_list, content_loss_list, total_loss_list = [], [], []

    for e in range(args.epochs):
        count = 0
        img_count = 0

        # train network
        image_transformer.train()
        for batch_num, (x, label) in enumerate(train_loader):
            img_batch_read = len(x)
            img_count += img_batch_read

            # zero out gradients
            optimizer.zero_grad()

            # input batch to transformer network
            x = Variable(x).type(dtype)
            y_hat = image_transformer(x)

            # get vgg features
            y_c_features = vgg(x)
            y_hat_features = vgg(y_hat)

            # calculate style loss
            y_hat_gram = [utils.gram(feature) for feature in y_hat_features]
            style_loss = 0.0
            for j in range(5):
                style_loss += loss_mse(y_hat_gram[j],
                                       style_gram[j][:img_batch_read])
            style_loss = args.weights[0] * style_loss

            # calculate content loss (block5_conv2)
            recon = y_c_features[5]
            recon_hat = y_hat_features[5]
            content_loss = args.weights[1] * loss_mse(recon_hat, recon)

            # total loss
            total_loss = style_loss + content_loss

            # backprop
            total_loss.backward()
            optimizer.step()

            # print out status message
            if ((batch_num + 1) % 100 == 0):
                count = count + 1
                total_loss_list.append(total_loss.item())
                content_loss_list.append(content_loss.item())
                style_loss_list.append(style_loss.item())
                print(
                    "Epoch {}:\t [{}/{}]\t\t Batch:[{}]\t total: {:.6f}\t style: {:.6f}\t content: {:.6f}\t"
                    .format(e, img_count, len(train_dataset), batch_num + 1,
                            total_loss.item(), style_loss.item(),
                            content_loss.item()))

        image_transformer.eval()

        stylized = image_transformer(content).cpu()
        out_path = args.save_dir + "/opt/perc%d_%d.png" % (e, batch_num + 1)
        utils.save_image(out_path, stylized.data[0], norm)

        image_transformer.train()

    # save model
    image_transformer.eval()

    filename = 'models/perceptual/' + str(args.model_name)
    if not '.model' in filename:
        filename = filename + '.model'
    torch.save(image_transformer.state_dict(), filename)

    total_loss = np.array(total_loss_list)
    style_loss = np.array(style_loss_list)
    content_loss = np.array(content_loss_list)
    x = np.arange(0, np.size(total_loss)) / (count + 1)

    fig = plt.figure('Perceptual Loss')
    plt.plot(x, total_loss)
    plt.plot(x, content_loss)
    plt.plot(x, style_loss)
    plt.legend(['Total', 'Content', 'Style'])
    plt.title('Perceptual Loss')
    plt.savefig(args.save_dir + '/perc_loss.png')
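`utils.gram` is not shown; a typical Gram-matrix helper consistent with how it is used above (an assumption, since the project's own normalization may differ):

    def gram(feature):
        # feature: (batch, channels, height, width)
        b, c, h, w = feature.size()
        f = feature.view(b, c, h * w)
        # Channel-by-channel correlations, normalized by the layer size.
        return f.bmm(f.transpose(1, 2)) / (c * h * w)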
Code Example #28
def predict():
    time_start = time.time()
    args = Predict_Args()
    deepem = Vgg19(args)
    checkpoint_dir = args.model_save_path
    with tf.Session() as sess:
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('Restore model failed!', flush=True)
        if not os.path.exists(args.result_path):
            os.mkdir(args.result_path)

        print("start predicting....\n", flush=True)
        for num in range(args.start_mic_num, args.end_mic_num + 1):
            mrc_name = args.data_path + args.name_prefix + str(num).zfill(
                args.name_length) + ".mrc"
            if not os.path.exists(mrc_name):
                print("%s is not exist!" % mrc_name)
                continue

            print("\nprocessing mrc %s..." % mrc_name, flush=True)
            output_name = args.result_path + args.name_prefix + str(num).zfill(
                args.name_length) + '.box'
            output = open(output_name, 'w')
            test_x, test_index = load_predict(args, mrc_name)
            test_x = np.asarray(test_x).reshape(len(test_x), args.boxsize,
                                                args.boxsize, 1)

            num_batch = len(test_x) // args.batch_size
            print("num_of_box is %d" % len(test_x), flush=True)
            particle = []
            scores = []
            for i in range(num_batch):
                batch_x = test_x[args.batch_size * i:args.batch_size * (i + 1)]
                batch_test_index = test_index[args.batch_size *
                                              i:args.batch_size * (i + 1)]
                pred = sess.run(deepem.pred, feed_dict={deepem.X: batch_x})
                # print("pred: avg = %.10f, max = %.10f " % ( pred.mean(), pred.max()))
                for j in range(len(batch_test_index)):
                    sys.stdout.flush()
                    if pred[j] > args.accuracy:
                        #print("%d.mrc %d %d %.10f"%(num,batch_test_index[j][0],batch_test_index[j][1],pred[j]),flush=True)
                        particle.append(
                            [batch_test_index[j][0], batch_test_index[j][1]])
                        scores.append(pred[j])
            print("%d particles detected in %s!" % (len(particle), mrc_name))
            if len(particle) == 0:
                output.close()
                continue
            particle = np.asarray(particle)
            scores = np.asarray(scores)
            # remove overlapping particles
            result = non_max_suppression(particle, scores, args.boxsize,
                                         args.threhold)
            for i in range(len(result)):
                #print("%d.mrc %d %d "%(num,result[i][0],result[i][1]),flush=True)
                output.write(
                    str(result[i][0]) + '\t' + str(result[i][1]) + '\t' +
                    str(args.boxsize) + '\t' + str(args.boxsize) + '\n')
                output.flush()
            print("%d particles left in %s!" % (len(result), mrc_name))
            output.close()

    time_end = time.time()
    print("\ntotal %d mrc pictures." %
          (args.end_mic_num - args.start_mic_num + 1))
    print("predicting done! totally cost: %.5f \n" % (time_end - time_start),
          flush=True)
    print("cost of every single mrc file: %.5f \n" %
          ((time_end - time_start) /
           (args.end_mic_num - args.start_mic_num + 1)),
          flush=True)
Code Example #29
# Create a reinitializable iterator given the dataset structure
iterator = Iterator.from_structure(tr_data.data.output_types,
                                   tr_data.data.output_shapes)
next_batch = iterator.get_next()

# Ops for initializing the two different iterators
training_init_op = iterator.make_initializer(tr_data.data)
validation_init_op = iterator.make_initializer(val_data.data)

# TF placeholder for graph input and output
x = tf.placeholder(tf.float32, [batch_size, 224, 224, 3])
y = tf.placeholder(tf.float32, [batch_size, num_classes])
keep_prob = tf.placeholder(tf.float32)

# Initialize model
model = Vgg19(x, keep_prob, num_classes, train_layers)

# Link variable to model output
score = model.fc8

# List of trainable variables of the layers we want to train
var_list = [
    v for v in tf.trainable_variables() if v.name.split('/')[0] in train_layers
]

# Op for calculating the loss
with tf.name_scope("cross_ent"):
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=score, labels=y))

# Train op
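The snippet ends at the train op. A typical continuation in the same style, assuming a `learning_rate` hyperparameter defined earlier (the original project's optimizer choice is not shown):

    # Hypothetical: plain gradient descent over the selected variables.
    with tf.name_scope("train"):
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.minimize(loss, var_list=var_list)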
Code Example #30
File: train_KLH.py  Project: lsy8899/deepEM
def train():
    args = Train_Args()
    train_start = time.time()
    time_start = time.time()

    checkpoint_dir = args.model_save_path
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
        
    output_name = args.model_save_path + "training_messages.txt"
    output = open(output_name, 'w')
    print("read data start.",flush=True)
    output.write("read data start.\n")

    train_x, train_y, test_x, test_y = load_train(args)
    print("shape of train_x: " , train_x.shape)
    time_end = time.time()
    print("\nread done! totally cost: %.5f \n" %(time_end - time_start),flush=True)
    output.write("read done! totally cost: " + str(time_end - time_start) +'\n')
    output.flush()
    print("training start.",flush=True)
    # copy argument file
    srcfile = args.model_filename
    dstfile = args.model_save_path + srcfile
    shutil.copyfile(srcfile,dstfile)
    shutil.copyfile(args.args_filename,args.model_save_path + args.args_filename)

    time_start = time.time()

    tot_cost = []
    plot = []
    plot_train = []
    best_test_accuracy = 0
    best_train_accuracy = 0
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    with tf.Session() as sess:
        deepem = Vgg19(args)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep = 100)
        sess.run(tf.global_variables_initializer())
        print("train size is %d " % len(train_x), flush=True)

        for e in range(args.num_epochs):
            print('\n=============== Epoch %d/%d ==============='% (e + 1,args.num_epochs),flush=True)
            output.write("\n=============== Epoch " + str(e + 1) + "/" + str(args.num_epochs) + " ===============\n")
            cost = []
            num_batch = len(train_x) // args.batch_size
            print("num_batch is %d" % num_batch,flush=True)
            test_train_pred = []
            test_train_y = []
            for i in range(num_batch):
                batch_x = train_x[args.batch_size*i:args.batch_size*(i+1)]
                batch_y = train_y[args.batch_size*i:args.batch_size*(i+1)]
                # batch_x = batch_x.reshape((args.batch_size, args.resize, args.resize, 1))
                loss,pred_train,lr,_= sess.run([deepem.loss, deepem.pred,deepem.lr, deepem.optimizer], {deepem.X:batch_x, deepem.Y: batch_y})
                test_train_pred[args.batch_size*i:args.batch_size*(i+1)] = pred_train[:]
                test_train_y[args.batch_size*i:args.batch_size*(i+1)] = batch_y[:]
                cost.append(loss)
                if i % 10 == 0:
                    print('lr: %.8f loss: %.6f' % (lr, np.mean(cost)),flush=True)

            tot_cost.append([np.mean(cost)])
            output.write("average loss: " + str(np.mean(cost)) + '\n')
            output.flush()
            # print("test_train_pred",test_train_pred)
            test_train_pred = np.asarray(test_train_pred)
            print("avg = %.6f , min = %.6f, max = %.6f "% (test_train_pred.mean(),test_train_pred.min(),test_train_pred.max()))
            threhold = 0.5
            test_train_pred[test_train_pred<=threhold] = 0
            test_train_pred[test_train_pred>threhold] = 1
            accuracy_train = np.sum(np.equal(test_train_pred,test_train_y))/len(test_train_y)

            if accuracy_train > best_train_accuracy:
                best_train_accuracy = accuracy_train

            plot_train.append(1- accuracy_train)
            print("train accuracy: %.6f" % accuracy_train,flush=True)
            print("best_train_accuracy: %.6f" % best_train_accuracy,flush=True)
            output.write("train accuracy: " + str(accuracy_train) + '\n')
            output.flush()

            # start testing
            if e % 5 == 0 or e == args.num_epochs - 1:
                print("\ntesting start.",flush=True)
                num_batch = len(test_x) // args.batch_size
                print("num_batch is %d" % num_batch,flush=True)
                test_pred = []
                for i in range(num_batch):
                    batch_x = test_x[args.batch_size*i:args.batch_size*(i+1)]
                    # batch_x = batch_x.reshape((args.batch_size, args.resize, args.resize, 1))
                    batch_y = test_y[args.batch_size*i:args.batch_size*(i+1)]
                    pred = sess.run(deepem.pred,feed_dict={deepem.X: batch_x})
                    test_pred[args.batch_size*i:args.batch_size*(i+1)] = pred[:]
                test_pred = np.asarray(test_pred)
                print("avg = %.6f , min = %.6f, max = %.6f "% (test_pred.mean(),test_pred.min(),test_pred.max()))
                threhold = 0.5
                test_pred[test_pred<=threhold] = 0
                test_pred[test_pred>threhold] = 1
                accuracy = np.sum(np.equal(test_pred,test_y))/len(test_y)
                plot.append(1- accuracy)
                print("testing set accuracy: %.6f" % accuracy,flush=True)
                output.write("testing set accuracy: " + str(accuracy) + '\n')
                output.write("best_test_accuracy: " + str(best_test_accuracy) + '\n')
                output.flush()
                if accuracy > best_test_accuracy:
                    best_test_accuracy = accuracy
                    ckpt_path = os.path.join(checkpoint_dir, 'model.ckpt')
                    saver.save(sess, ckpt_path, global_step = e)
                    print("model saved!")
                print("best_test_accuracy: %.6f" % best_test_accuracy,flush=True)
        time_end = time.time()
        print("\ntraining done! totally cost: %.5f \n" %(time_end - time_start),flush=True)
        output.write("training done! totally cost: " + str(time_end - time_start) + '\n')
        output.flush()