Example #1
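All of the snippets below come from the same U-Net image-segmentation project and omit their module-level setup. A minimal sketch of the imports and constants they assume (names inferred from usage; the project's real module layout may differ, and the constant values are placeholders):

# Assumed preamble for the examples on this page (hypothetical module
# names; load_dataset, build_model, original_loss, SavePredictionCallback
# and SequenceGenerator are project-defined helpers not shown here).
import os
import random

import numpy as np
import tensorflow as tf  # TF 1.x API (tf.ConfigProto, tf.InteractiveSession, ...)
from keras.optimizers import Adam
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint

import loader as ld    # assumed: provides DataSet and load_dataset
import reporter as rp  # assumed: provides Reporter
import model           # assumed: provides UNet

NUM = 120                # assumed: total number of images in the dataset
SAVE = True              # assumed: save checkpoints every 10 epochs
CONTINUE = False         # assumed: resume training from a checkpoint
RESTORE_MODEL = "save_model_done.ckpt"  # assumed checkpoint file name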
def train(args):
    # Load the training and test data
    train_gen, test_gen = load_dataset(train_rate=args.trainrate)
    trainx = train_gen.images_original
    trainy = train_gen.images_segmented
    testx = test_gen.images_original
    testy = test_gen.images_segmented
    print(trainx.shape)
    print(testx.shape)
    # Create Reporter Object
    reporter = rp.Reporter(parser=args)
    accuracy_fig = reporter.create_figure("Accuracy", ("epoch", "accuracy"),
                                          ["train", "test"])
    loss_fig = reporter.create_figure("Loss", ("epoch", "loss"),
                                      ["train", "test"])

    epochs = args.epoch
    batch_size = args.batchsize
    is_augment = args.augmentation

    model = build_model(output_class_num=ld.DataSet.length_category(),
                        l2_reg=args.l2reg)
    model.compile(
        loss=original_loss,
        #loss='binary_crossentropy',
        optimizer=Adam(lr=0.01),
        metrics=['accuracy'])

    #train_sequence = SequenceGenerator(train_gen, batch_size=batch_size, is_augment=is_augment)
    #test_sequence = SequenceGenerator(test_gen, batch_size=batch_size, is_augment=False)
    callbacks_list = [
        SavePredictionCallback(train_gen, test_gen, reporter),
        ReduceLROnPlateau(monitor='loss',
                          factor=0.5,
                          patience=5,
                          min_lr=1e-15,
                          verbose=1,
                          mode='auto',
                          cooldown=0),
        ModelCheckpoint(filepath='./model_{epoch:02d}_{val_loss:.2f}.h5',
                        monitor='loss',
                        save_best_only=True,
                        verbose=1,
                        mode='auto')
    ]

    # model.fit_generator(
    #     generator=train_sequence,
    #     validation_data=test_sequence,
    #     epochs=epochs,
    #     callbacks=callbacks_list
    # )
    model.summary()
    history = model.fit(x=[trainx],
                        y=[trainy],
                        batch_size=batch_size,
                        epochs=epochs,
                        validation_data=([testx], [testy]),
                        callbacks=callbacks_list)
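
A minimal sketch of how this train(args) entry point might be wired up; the flag names are inferred from the attributes read above (trainrate, epoch, batchsize, augmentation, l2reg) and the defaults are assumptions, not the project's documented CLI:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train the U-Net model')
    parser.add_argument('-r', '--trainrate', type=float, default=0.85)
    parser.add_argument('-e', '--epoch', type=int, default=100)
    parser.add_argument('-b', '--batchsize', type=int, default=4)
    parser.add_argument('-a', '--augmentation', action='store_true')
    parser.add_argument('-l', '--l2reg', type=float, default=0.0001)
    args = parser.parse_args()
    train(args)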
Example #2
def train(parser):

    train, test = load_dataset(train_rate=parser.trainrate)
    valid = test.devide(0, int(NUM*0.1))
    test = test.devide(int(NUM*0.1), int(NUM*0.3))

    # Output files
    reporter = rp.Reporter(parser=parser)
    accuracy_fig = reporter.create_figure("Accuracy", ("epoch", "accuracy"), ["train", "test"])
    loss_fig = reporter.create_figure("Loss", ("epoch", "loss"), ["train", "test"])

    #GPU
    gpu = parser.gpu

    #model
    model_unet = model.UNet(size=(128, 128), l2_reg=parser.l2reg).model

    # Loss function
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=model_unet.teacher,
                                                                           logits=model_unet.outputs))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

    # Accuracy
    correct_prediction = tf.equal(tf.argmax(model_unet.outputs, 3), tf.argmax(model_unet.teacher, 3))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    #gpu config
    gpu_config = tf.ConfigProto(gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.7), device_count={'GPU': 1},
                                log_device_placement=False, allow_soft_placement=True)
    sess = tf.InteractiveSession(config=gpu_config) if gpu else tf.InteractiveSession()
    tf.global_variables_initializer().run()

    #parameter
    epochs = parser.epoch
    batch_size = parser.batchsize
    is_augment = parser.augmentation


    # Add a channel dimension (the inputs are grayscale)
    v_images_original = valid.images_original[:, :, :, np.newaxis]
    t_images_original = test.images_original[:, :, :, np.newaxis]

    train_dict = {model_unet.inputs: v_images_original, model_unet.teacher: valid.images_segmented,
                  model_unet.is_training: False}
    test_dict = {model_unet.inputs: t_images_original, model_unet.teacher: test.images_segmented,
                 model_unet.is_training: False}


    saver = tf.train.Saver()
    if not os.path.exists("./checkpoint"):
        os.makedirs("./checkpoint")

    if CONTINUE:
        if os.path.exists("./checkpoint/"+ RESTORE_MODEL):
            saver.restore(sess, "./checkpoint/"+ RESTORE_MODEL)

    for epoch in range(epochs):
        for batch in train(batch_size=batch_size, augment=is_augment):  # train is shuffled here
            # Batch data
            images_original = batch.images_original
            if not is_augment:
                images_original = images_original[:, :, :, np.newaxis]
            inputs = images_original
            teacher = batch.images_segmented

            sess.run(train_step, feed_dict={model_unet.inputs: inputs, model_unet.teacher: teacher,
                                            model_unet.is_training: True})

        # The inputs are grayscale, so add a dimension for the channel
        train_images_original = train.images_original[:, :, :, np.newaxis]

        # Evaluation
        if epoch % 1 == 0:
            loss_train = sess.run(cross_entropy, feed_dict=train_dict)
            loss_test = sess.run(cross_entropy, feed_dict=test_dict)
            accuracy_train = sess.run(accuracy, feed_dict=train_dict)
            accuracy_test = sess.run(accuracy, feed_dict=test_dict)
            print("Epoch:", epoch)
            print("[Train] Loss:", loss_train, " Accuracy:", accuracy_train)
            print("[Test]  Loss:", loss_test, "Accuracy:", accuracy_test)
            accuracy_fig.add([accuracy_train, accuracy_test], is_update=True)
            loss_fig.add([loss_train, loss_test], is_update=True)
            if epoch % 1 == 0:
                idx_train = random.randrange(int(NUM * 0.7))  # train size
                idx_test = random.randrange(int(NUM * 0.2))  # validation and test hold only 24 images
                outputs_train = sess.run(model_unet.outputs,
                                         feed_dict={model_unet.inputs: [train_images_original[idx_train]],
                                                    model_unet.is_training: False})
                outputs_test = sess.run(model_unet.outputs,
                                        feed_dict={model_unet.inputs: [t_images_original[idx_test]],
                                                   model_unet.is_training: False})
                train_set = [train_images_original[idx_train], outputs_train[0], train.images_segmented[idx_train]]  # TODO: train.images_segmented somehow gets shuffled
                test_set = [t_images_original[idx_test], outputs_test[0], test.images_segmented[idx_test]]
                reporter.save_image_from_ndarray(train_set, test_set, train.palette, epoch,
                                                 index_void=len(ld.DataSet.CATEGORY)-1)
        if SAVE and epoch % 10 == 0:
            save_path = saver.save(sess, "./checkpoint/save_model_epoch_" + str(epoch) + "_.ckpt")
    save_path = saver.save(sess, "./checkpoint/save_model_done_plus.ckpt")

    # Evaluate the model
    loss_test = sess.run(cross_entropy, feed_dict=test_dict)
    accuracy_test = sess.run(accuracy, feed_dict=test_dict)
    print("Result")
    print("[Test]  Loss:", loss_test, "Accuracy:", accuracy_test)

    sess.close()
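
devide above (and perm in the later examples) are DataSet methods whose implementations are not shown. A hypothetical sketch consistent with how they are called, returning a sub-range of the image pairs:

class DataSet:
    # Hypothetical sketch: the real class also carries palette, CATEGORY
    # and filenames, and is callable to yield shuffled (augmented) batches.
    def __init__(self, images_original, images_segmented):
        self.images_original = images_original
        self.images_segmented = images_segmented

    def devide(self, start, end):
        # Return a new DataSet over the half-open index range [start, end).
        return DataSet(self.images_original[start:end],
                       self.images_segmented[start:end])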
Example #3
def train(parser):
    # Load the training and test data
    train, test = load_dataset(train_rate=parser.trainrate)
    valid = train.perm(0, 30)
    test = test.perm(0, 150)

    # Create a Reporter object for saving results
    reporter = rp.Reporter(parser=parser)
    accuracy_fig = reporter.create_figure("Accuracy", ("epoch", "accuracy"),
                                          ["train", "test"])
    loss_fig = reporter.create_figure("Loss", ("epoch", "loss"),
                                      ["train", "test"])

    # Whether or not to use a GPU
    gpu = parser.gpu

    # Create the model
    model_unet = model.UNet(l2_reg=parser.l2reg).model

    # Set the loss function and the optimizer
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=model_unet.teacher,
                                                logits=model_unet.outputs))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

    # Calculate accuracy
    correct_prediction = tf.equal(tf.argmax(model_unet.outputs, 3),
                                  tf.argmax(model_unet.teacher, 3))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Initialize the session
    gpu_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.7),
        device_count={'GPU': 1},
        log_device_placement=False,
        allow_soft_placement=True)
    sess = tf.InteractiveSession(
        config=gpu_config) if gpu else tf.InteractiveSession()
    tf.global_variables_initializer().run()

    # Train the model
    epochs = parser.epoch
    batch_size = parser.batchsize
    is_augment = parser.augmentation
    train_dict = {
        model_unet.inputs: valid.images_original,
        model_unet.teacher: valid.images_segmented,
        model_unet.is_training: False
    }
    test_dict = {
        model_unet.inputs: test.images_original,
        model_unet.teacher: test.images_segmented,
        model_unet.is_training: False
    }

    for epoch in range(epochs):
        for batch in train(batch_size=batch_size, augment=is_augment):
            # Unpack the batch data
            inputs = batch.images_original
            teacher = batch.images_segmented
            # Training
            sess.run(train_step,
                     feed_dict={
                         model_unet.inputs: inputs,
                         model_unet.teacher: teacher,
                         model_unet.is_training: True
                     })

        # Evaluation
        if epoch % 1 == 0:
            loss_train = sess.run(cross_entropy, feed_dict=train_dict)
            loss_test = sess.run(cross_entropy, feed_dict=test_dict)
            accuracy_train = sess.run(accuracy, feed_dict=train_dict)
            accuracy_test = sess.run(accuracy, feed_dict=test_dict)
            print("Epoch:", epoch)
            print("[Train] Loss:", loss_train, " Accuracy:", accuracy_train)
            print("[Test]  Loss:", loss_test, "Accuracy:", accuracy_test)
            accuracy_fig.add([accuracy_train, accuracy_test], is_update=True)
            loss_fig.add([loss_train, loss_test], is_update=True)
            if epoch % 3 == 0:
                idx_train = random.randrange(10)
                idx_test = random.randrange(100)
                outputs_train = sess.run(
                    model_unet.outputs,
                    feed_dict={
                        model_unet.inputs: [train.images_original[idx_train]],
                        model_unet.is_training: False
                    })
                outputs_test = sess.run(model_unet.outputs,
                                        feed_dict={
                                            model_unet.inputs:
                                            [test.images_original[idx_test]],
                                            model_unet.is_training:
                                            False
                                        })
                train_set = [
                    train.images_original[idx_train], outputs_train[0],
                    train.images_segmented[idx_train]
                ]
                test_set = [
                    test.images_original[idx_test], outputs_test[0],
                    test.images_segmented[idx_test]
                ]
                reporter.save_image_from_ndarray(
                    train_set,
                    test_set,
                    train.palette,
                    epoch,
                    index_void=len(ld.DataSet.CATEGORY) - 1)

    # Test the trained model
    loss_test = sess.run(cross_entropy, feed_dict=test_dict)
    accuracy_test = sess.run(accuracy, feed_dict=test_dict)
    print("Result")
    print("[Test]  Loss:", loss_test, "Accuracy:", accuracy_test)
Example #4
def train(parser):
    assert parser.num is not None, 'please input the number of images, e.g. python3 main.py -n xxx'
    # Number of images
    NUM = parser.num
    # Image size used for training
    size = tuple(parser.size)
    #trainrate
    trainrate = parser.trainrate

    train, test = load_dataset(train_rate=trainrate, size=size)

    valid = train.devide(int(NUM * (trainrate - (1 - trainrate))),
                         int(NUM * trainrate))
    test = test.devide(0, int(NUM * (1 - trainrate)))

    # Output files
    reporter = rp.Reporter(parser=parser)
    accuracy_fig = reporter.create_figure("Accuracy", ("epoch", "accuracy"),
                                          ["train", "test"])
    loss_fig = reporter.create_figure("Loss", ("epoch", "loss"),
                                      ["train", "test"])

    # Name of the model to restore
    CONTINUE = parser.restore

    #GPU
    gpu = parser.gpu
    print(gpu)

    #model
    model_unet = model.UNet(size=size, l2_reg=parser.l2reg).model

    # Loss function
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=model_unet.teacher,
                                                logits=model_unet.outputs))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

    # Accuracy
    correct_prediction = tf.equal(tf.argmax(model_unet.outputs, 3),
                                  tf.argmax(model_unet.teacher, 3))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    #gpu config
    gpu_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=0.7,
            visible_device_list="",
            allow_growth=True),  #device_count={'GPU': 0},
        log_device_placement=False,
        allow_soft_placement=True)
    if gpu:
        sess = tf.InteractiveSession(config=gpu_config)
        print("gpu mode")
    else:
        sess = tf.InteractiveSession()
        print("cpu mode")

    tf.global_variables_initializer().run()
    #parameter
    epochs = parser.epoch
    batch_size = parser.batchsize

    t_images_original = test.images_original

    saver = tf.train.Saver(max_to_keep=100)
    if not os.path.exists("./checkpoint"):
        os.makedirs("./checkpoint")

    if CONTINUE is not None:
        saver.restore(sess, "./checkpoint/" + CONTINUE)
        print("restored")

    for epoch in range(epochs):
        for batch in train(batch_size=batch_size):
            # Batch data
            images_original = batch.images_original
            inputs = images_original
            teacher = batch.images_segmented

            sess.run(train_step,
                     feed_dict={
                         model_unet.inputs: inputs,
                         model_unet.teacher: teacher,
                         model_unet.is_training: True
                     })

        train_images_original = train.images_original

        # Evaluation
        accuracy_train = 0
        loss_train = 0
        accuracy_test = 0
        loss_test = 0

        if epoch % 1 == 0:
            num_batch = 0
            for batchs in valid(batch_size=batch_size):
                num_batch += 1
                images = batchs.images_original
                segmented = batchs.images_segmented
                accuracy_train += sess.run(accuracy,
                                           feed_dict={
                                               model_unet.inputs: images,
                                               model_unet.teacher: segmented,
                                               model_unet.is_training: False
                                           })
                loss_train += sess.run(cross_entropy,
                                       feed_dict={
                                           model_unet.inputs: images,
                                           model_unet.teacher: segmented,
                                           model_unet.is_training: False
                                       })
            accuracy_train /= num_batch
            loss_train /= num_batch

            num_batch = 0
            for batchs in test(batch_size=batch_size):
                num_batch += 1
                images = batchs.images_original
                segmented = batchs.images_segmented
                accuracy_test += sess.run(accuracy,
                                          feed_dict={
                                              model_unet.inputs: images,
                                              model_unet.teacher: segmented,
                                              model_unet.is_training: False
                                          })
                loss_test += sess.run(cross_entropy,
                                      feed_dict={
                                          model_unet.inputs: images,
                                          model_unet.teacher: segmented,
                                          model_unet.is_training: False
                                      })
            accuracy_test /= num_batch
            loss_test /= num_batch

            print("Epoch:", epoch)
            print("[Train] Loss:", loss_train, " Accuracy:", accuracy_train)
            print("[Test]  Loss:", loss_test, "Accuracy:", accuracy_test)
            accuracy_fig.add([accuracy_train, accuracy_test], is_update=True)
            loss_fig.add([loss_train, loss_test], is_update=True)
            if epoch % 1 == 0:
                idx_train = random.randrange(int(NUM * trainrate))
                idx_test = random.randrange(int(NUM * (1 - trainrate)))
                outputs_train = sess.run(
                    model_unet.outputs,
                    feed_dict={
                        model_unet.inputs: [train_images_original[idx_train]],
                        model_unet.is_training: False
                    })
                outputs_test = sess.run(model_unet.outputs,
                                        feed_dict={
                                            model_unet.inputs:
                                            [t_images_original[idx_test]],
                                            model_unet.is_training:
                                            False
                                        })
                train_set = [
                    train_images_original[idx_train], outputs_train[0],
                    train.images_segmented[idx_train]
                ]
                test_set = [
                    t_images_original[idx_test], outputs_test[0],
                    test.images_segmented[idx_test]
                ]
                reporter.save_image_from_ndarray(
                    train_set, test_set, train.palette, epoch,
                    index_void=0)  #index_void = background
        if SAVE and epoch % 10 == 0:
            save_path = saver.save(
                sess,
                "./checkpoint/save_model_epoch_" + str(epoch) + "_.ckpt")
    save_path = saver.save(sess, "./checkpoint/save_model_done.ckpt")

    # Evaluate the model (reset the accumulators before the final pass)
    accuracy_test = 0
    loss_test = 0
    num_batch = 0
    for batchs in test(batch_size=batch_size):
        num_batch += 1
        images = batchs.images_original
        segmented = batchs.images_segmented
        accuracy_test += sess.run(accuracy,
                                  feed_dict={
                                      model_unet.inputs: images,
                                      model_unet.teacher: segmented,
                                      model_unet.is_training: False
                                  })
        loss_test += sess.run(cross_entropy,
                              feed_dict={
                                  model_unet.inputs: images,
                                  model_unet.teacher: segmented,
                                  model_unet.is_training: False
                              })
    accuracy_test /= num_batch
    loss_test /= num_batch

    print("Result")
    print("[Test]  Loss:", loss_test, "Accuracy:", accuracy_test)

    sess.close()
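
The per-batch averaging above weights a smaller final batch the same as a full one. A sketch of an exact alternative that weights each batch by its size, reusing sess, accuracy and model_unet from the example above:

# Size-weighted test accuracy (exact mean over images, not over batches).
weighted_sum = 0.0
total_images = 0
for batchs in test(batch_size=batch_size):
    n = len(batchs.images_original)
    acc = sess.run(accuracy,
                   feed_dict={model_unet.inputs: batchs.images_original,
                              model_unet.teacher: batchs.images_segmented,
                              model_unet.is_training: False})
    weighted_sum += acc * n
    total_images += n
accuracy_test = weighted_sum / total_images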
Example #5
def implement(parser):

    #image data
    train = load_dataset()

    #reporter
    reporter = rp.Reporter(parser=parser)

    # GPU
    gpu = parser.gpu

    #model
    model_unet = model.UNet(size=(1024, 1024), l2_reg=parser.l2reg).model

    # Initialize session
    gpu_config = tf.ConfigProto(gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.7), device_count={'GPU': 1},
                                log_device_placement=False, allow_soft_placement=True)
    sess = tf.InteractiveSession(config=gpu_config) if gpu else tf.InteractiveSession()
    tf.global_variables_initializer().run()

    is_augment = parser.augmentation

    # Saver
    saver = tf.train.Saver()
    if not os.path.exists("./checkpoint"):
        os.makedirs("./checkpoint")

    if not os.path.exists("./output"):
        os.makedirs("./output")

    if not os.path.exists("./input"):
        os.makedirs("./input")

    saver.restore(sess, "./checkpoint/save_model_512new_epoch_70.ckpt")


    # Add a channel dimension (the inputs are grayscale)
    train_images_original = train.images_original[:, :, :, np.newaxis]

    for k in range(NUM):

        print("number: ", k)
        idx_train = k
        outputs_train = sess.run(model_unet.outputs,
                                 feed_dict={model_unet.inputs: [train_images_original[idx_train]],
                                            model_unet.is_training: False})

        images_original_size = train.images_original_size
        print(images_original_size[idx_train])

        pred_image = reporter.cast_to_out_image(outputs_train[0], images_original_size[idx_train])

        # Check the input image
        image_in_np = np.squeeze(train_images_original[idx_train])
        image_in_pil = reporter.cast_to_out_image_in(image_in_np, images_original_size[idx_train])
        if k >= 40:
            #reporter.save_simple_image(pred_image, k-40, "./output", "hv_")
            #reporter.save_simple_image(image_in_pil, k-40, "./input", "hv_")
            pass
        else:
            reporter.save_simple_image(pred_image, k, "./output", "hh_")
            reporter.save_simple_image(image_in_pil, k, "./input", "hh_")

    print("Result")

    sess.close()
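
cast_to_out_image is a Reporter method whose implementation is not shown. A hypothetical sketch of the post-processing it implies: argmax over the class channel, map class indices through a palette, and resize back to the original size (the palette format and (width, height) ordering are assumptions):

import numpy as np
from PIL import Image

def cast_to_out_image(output, original_size, palette):
    # output: network output of shape (height, width, n_classes);
    # palette: flat [R, G, B, ...] list as accepted by Image.putpalette.
    class_map = np.argmax(output, axis=2).astype(np.uint8)
    image = Image.fromarray(class_map, mode='P')
    image.putpalette(palette)
    return image.resize(original_size)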
Example #6
def debug(parser):
    # load test dataset
    _, test = load_dataset(train_rate=parser.trainrate)
    test = test.perm(0, 150)

    saver = tf.train.import_meta_graph(
        os.path.join(parser.saverpath, 'model.meta'))

    # Create Reporter Object
    reporter = rp.Reporter(result_dir=parser.saverpath, parser=parser)
    # Whether or not using a GPU
    gpu = parser.gpu

    # Create a model
    model_unet = model.UNet(l2_reg=parser.l2reg).model

    # Set a loss function
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=model_unet.teacher,
                                                logits=model_unet.outputs))

    # Calculate accuracy
    correct_prediction = tf.equal(tf.argmax(model_unet.outputs, 3),
                                  tf.argmax(model_unet.teacher, 3))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # Calculate mean IoU on class indices (argmax over the channel axis);
    # tf.metrics.mean_iou expects integer class maps, not one-hot/logits
    labels = tf.reshape(tf.argmax(model_unet.teacher, 3),
                        [tf.shape(model_unet.teacher)[0], -1])
    preds = tf.reshape(tf.argmax(model_unet.outputs, 3),
                       [tf.shape(model_unet.outputs)[0], -1])
    weights = tf.cast(
        tf.less_equal(preds,
                      len(ld.DataSet.CATEGORY) - 1),
        tf.int32)  # Ignoring all labels greater than or equal to n_classes.
    miou, update_op_miou = tf.metrics.mean_iou(labels=labels,
                                               predictions=preds,
                                               num_classes=len(
                                                   ld.DataSet.CATEGORY),
                                               weights=weights)
    # shape of each
    shape_teacher = tf.shape(model_unet.teacher)
    shape_output = tf.shape(model_unet.outputs)

    # Initialize session

    gpu_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.7),
        device_count={'GPU': 1},
        log_device_placement=False,
        allow_soft_placement=True)
    sess = tf.InteractiveSession(
        config=gpu_config) if gpu else tf.InteractiveSession()
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()
    saver.restore(sess, tf.train.latest_checkpoint(parser.saverpath))
    print("Saver restore model variables from ",
          tf.train.latest_checkpoint(parser.saverpath))

    # Set up the test dataset
    test_dict = {
        model_unet.inputs: test.images_original[0:2],
        model_unet.teacher: test.images_segmented[0:2],
        model_unet.is_training: False
    }

    # Test the trained model
    loss_test = sess.run(cross_entropy, feed_dict=test_dict)
    accuracy_test = sess.run(accuracy, feed_dict=test_dict)
    sess.run(update_op_miou, feed_dict=test_dict)
    miou_test = sess.run(miou, feed_dict=test_dict)
    shapes_test1 = sess.run(shape_teacher, feed_dict=test_dict)
    shapes_test2 = sess.run(shape_output, feed_dict=test_dict)
    print("TEST Result")
    print("[Test]  Loss:", loss_test, "Accuracy:", accuracy_test, " Mean IOU:",
          miou_test, " Shape", shapes_test1, shapes_test2,
          len(ld.DataSet.CATEGORY))
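
For reference, mean IoU as tf.metrics.mean_iou computes it: accumulate a confusion matrix over integer class maps, then average per-class intersection-over-union. A standalone NumPy sketch:

import numpy as np

def mean_iou(pred_classes, true_classes, num_classes):
    # Confusion matrix over flattened class indices.
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    for t, p in zip(true_classes.ravel(), pred_classes.ravel()):
        cm[t, p] += 1
    intersection = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - intersection
    valid = union > 0  # skip classes absent from both maps
    return np.mean(intersection[valid] / union[valid])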
Example #7
def train(parser):
    # Load train and test datas
    print("Start training")
    train, test = load_dataset(train_rate=parser.trainrate)
    valid = train.perm(0, 30)
    test = test.perm(0, 150)

    # Create Reporter Object
    reporter = rp.Reporter(parser=parser)
    #accuracy_fig = reporter.create_figure("Accuracy", ("epoch", "accuracy"), ["train", "test"])
    #loss_fig = reporter.create_figure("Loss", ("epoch", "loss"), ["train", "test"])

    # Whether or not using a GPU
    gpu = parser.gpu

    # Create a model
    model_unet = model.UNet(l2_reg=parser.l2reg).model

    # Set a loss function and an optimizer
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=model_unet.teacher,
                                                logits=model_unet.outputs))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

    # Calculate accuracy
    correct_prediction = tf.equal(tf.argmax(model_unet.outputs, 3),
                                  tf.argmax(model_unet.teacher, 3))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # Calculate mean IoU on class indices (argmax over the channel axis)
    labels = tf.reshape(tf.argmax(model_unet.teacher, 3),
                        [tf.shape(model_unet.teacher)[0], -1])
    preds = tf.reshape(tf.argmax(model_unet.outputs, 3),
                       [tf.shape(model_unet.outputs)[0], -1])
    weights = tf.cast(
        tf.less_equal(preds,
                      len(ld.DataSet.CATEGORY) - 1),
        tf.int32)  # Ignoring all labels greater than or equal to n_classes.
    #miou, update_op_miou = tf.metrics.mean_iou(labels = labels,
    #                                           predictions = preds,
    #                                           num_classes = len(ld.DataSet.CATEGORY),
    #                                           weights=weights)

    # Initialize session
    saver = tf.train.Saver()
    gpu_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.7),
        device_count={'GPU': 1},
        log_device_placement=False,
        allow_soft_placement=True)
    sess = tf.InteractiveSession(
        config=gpu_config) if gpu else tf.InteractiveSession()
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()

    # Train the model
    epochs = parser.epoch
    batch_size = parser.batchsize
    is_augment = parser.augmentation
    train_dict = {
        model_unet.inputs: valid.images_original,
        model_unet.teacher: valid.images_segmented,
        model_unet.is_training: False
    }
    test_dict = {
        model_unet.inputs: test.images_original,
        model_unet.teacher: test.images_segmented,
        model_unet.is_training: False
    }

    for epoch in range(epochs):
        for batch in train(batch_size=batch_size, augment=is_augment):
            # input images
            inputs = batch.images_original
            teacher = batch.images_segmented
            # Training
            sess.run(train_step,
                     feed_dict={
                         model_unet.inputs: inputs,
                         model_unet.teacher: teacher,
                         model_unet.is_training: True
                     })

        # Evaluation
        if epoch % 1 == 0:
            loss_train = sess.run(cross_entropy, feed_dict=train_dict)
            loss_test = sess.run(cross_entropy, feed_dict=test_dict)
            accuracy_train = sess.run(accuracy, feed_dict=train_dict)
            accuracy_test = sess.run(accuracy, feed_dict=test_dict)
            #sess.run(update_op_miou, feed_dict=train_dict)
            #sess.run(update_op_miou, feed_dict=test_dict)
            #miou_train = sess.run(miou, feed_dict=train_dict)
            #miou_test = sess.run(miou, feed_dict=test_dict)
            print("Epoch:", epoch)
            print("[Train] Loss:", loss_train, " Accuracy:", accuracy_train)
            print("[Test]  Loss:", loss_test, "Accuracy:", accuracy_test)
            #accuracy_fig.add([accuracy_train, accuracy_test], is_update=True)
            #loss_fig.add([loss_train, loss_test], is_update=True)
            if epoch % 3 == 0:
                saver.save(sess, os.path.join(reporter._result_dir, 'model'))
                idx_train = random.randrange(10)
                idx_test = random.randrange(100)
                outputs_train = sess.run(
                    model_unet.outputs,
                    feed_dict={
                        model_unet.inputs: [train.images_original[idx_train]],
                        model_unet.is_training: False
                    })
                outputs_test = sess.run(model_unet.outputs,
                                        feed_dict={
                                            model_unet.inputs:
                                            [test.images_original[idx_test]],
                                            model_unet.is_training:
                                            False
                                        })
                train_set = [
                    train.images_original[idx_train], outputs_train[0],
                    train.images_segmented[idx_train]
                ]
                test_set = [
                    test.images_original[idx_test], outputs_test[0],
                    test.images_segmented[idx_test]
                ]
                reporter.save_image_from_ndarray(
                    train_set,
                    test_set,
                    train.palette,
                    epoch,
                    index_void=len(ld.DataSet.CATEGORY) - 1)

    # Test the trained model
    loss_test = sess.run(cross_entropy, feed_dict=test_dict)
    accuracy_test = sess.run(accuracy, feed_dict=test_dict)
    # sess.run(update_op_miou, feed_dict=test_dict)
    # miou_test = sess.run(miou, feed_dict=test_dict)
    print("Result")
    print("[Test]  Loss:", loss_test, "Accuracy:", accuracy_test)
    save_path = saver.save(sess, os.path.join(reporter._result_dir, 'model'))
    print("Model saved in file: ", save_path)
    for ii in range(100):
        outputs_test = sess.run(model_unet.outputs,
                                feed_dict={
                                    model_unet.inputs:
                                    [test.images_original[ii]],
                                    model_unet.is_training: False
                                })
        test_set = [
            test.images_original[ii], outputs_test[0],
            test.images_segmented[ii], test.filenames[ii]
        ]
        reporter.save_image_from_ndarray([],
                                         test_set,
                                         test.palette,
                                         1000000,
                                         index_void=len(ld.DataSet.CATEGORY) -
                                         1,
                                         fnames=test.filenames[ii])
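
All of the TF1 examples use the same loss. For reference, a NumPy sketch of softmax cross-entropy with logits as used above (numerically stable, mean over all pixels, one-hot labels on the last axis):

import numpy as np

def softmax_cross_entropy(labels, logits):
    # Subtract the max for numerical stability, then log-sum-exp.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -np.mean((labels * log_softmax).sum(axis=-1))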