def train():
    """Train Multi_modal_GEINet on the CV01 Dt1/Dt2 pair of GEI datasets.

    Both datasets are loaded with ``load_GEI(..., mode=True)`` and each gets
    its own ``SerialIterator`` (batch size 2, no shuffling). The two iterators
    feed ``Multi_modal_Updater`` on GPU 0 for a 6250-epoch Chainer run. Logs,
    snapshots, plots and the graph dump are written under
    ``/home/wutong/Setoguchi/chainer_files/result``.
    """
    dataset_dirs = (
        "/media/common-ns/New Volume/reseach/Dataset/OU-ISIR_by_Setoguchi/Gallery/signed/128_3ch/CV01/Dt1/CV01_Dt1_(Gallery&Probe)",
        # NOTE(review): the Dt2 set is read from under ".../CV01/Dt1/";
        # confirm this is intended and not a copy-paste slip for ".../CV01/Dt2/".
        "/media/common-ns/New Volume/reseach/Dataset/OU-ISIR_by_Setoguchi/Gallery/signed/128_3ch/CV01/Dt1/CV01_Dt2_(Gallery&Probe)",
    )
    datasets = [load_GEI(path_dir=path, mode=True) for path in dataset_dirs]

    model = Multi_modal_GEINet()
    model.to_gpu()

    # One non-shuffling iterator per modality, in Dt1, Dt2 order.
    dt1_iter, dt2_iter = [
        iterators.SerialIterator(data, batch_size=2, shuffle=False)
        for data in datasets
    ]

    optimizer = chainer.optimizers.MomentumSGD(lr=0.02, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.01))

    updater = Multi_modal_Updater(model, dt1_iter, dt2_iter, optimizer,
                                  device=0)

    total_epochs = 6250
    trainer = training.Trainer(
        updater, (total_epochs, 'epoch'),
        out='/home/wutong/Setoguchi/chainer_files/result')

    # Shared interval for the learning-rate decay and both snapshot kinds.
    stage_interval = (1250, 'epoch')

    trainer.extend(extensions.ExponentialShift(attr='lr', rate=0.56234),
                   trigger=stage_interval)
    trainer.extend(
        extensions.LogReport(log_name='SFDEI_log', trigger=(50, "epoch")))
    trainer.extend(extensions.snapshot(), trigger=stage_interval)
    weights_snapshot = extensions.snapshot_object(
        target=model, filename='model_snapshot_{.updater.epoch}')
    trainer.extend(weights_snapshot, trigger=stage_interval)
    trainer.extend(extensions.PrintReport(['epoch', 'accuracy', 'loss']))
    trainer.extend(
        extensions.dump_graph(root_name="loss", out_name="multi_modal.dot"))
    trainer.extend(extensions.PlotReport(["loss"]), trigger=(20, 'epoch'))
    trainer.extend(extensions.ProgressBar())

    trainer.run()
def recognition(model_name):
    """Evaluate a trained Multi_modal_GEINet with 1-nearest-neighbour matching.

    Loads the gallery and probe GEI sets for both modalities (Dt1 and Dt2) and
    restores the network weights from ``model_name``. Features are extracted
    with ``extract_features``. Each probe feature is then classified by a 1-NN
    model fitted on the gallery features. Every prediction is printed next to
    its label, followed by the fraction of probes predicted correctly.

    Args:
        model_name: path of the weights file read by ``serializers.load_npz``.
    """
    # Load the recognition data: gallery sets fit the classifier and probe
    # sets are scored. Only the Dt1 labels are used below.
    # NOTE(review): the Dt2 labels are loaded but never compared with the Dt1
    # labels; confirm both modalities list the subjects in the same order.
    gallery_dt1, gallery_labels = load_GEI(
        "/home/common/setoguchi/chainer_files/dataset/SFDEI/CV02(Gallery)_2nd",
        mode=False)
    gallery_dt2, _ = load_GEI(
        "/home/common/setoguchi/chainer_files/dataset/SFDEI/CV02_Dt2(Gallery)",
        mode=False)
    probe_dt1, probe_labels = load_GEI(
        "/home/common/setoguchi/chainer_files/dataset/SFDEI/CV02(Probe)_2nd",
        mode=False)
    probe_dt2, _ = load_GEI(
        "/home/common/setoguchi/chainer_files/dataset/SFDEI/CV02_Dt2(Probe)",
        mode=False)

    net = Multi_modal_GEINet()
    serializers.load_npz(model_name, obj=net)

    gallery_features = extract_features(net, gallery_dt1, gallery_dt2)
    probe_features = extract_features(net, probe_dt1, probe_dt2)

    knn = KNeighborsClassifier(n_neighbors=1)
    knn.fit(gallery_features, gallery_labels)

    hits = 0.0  # float so the final division is true division on Python 2
    for idx, feature in enumerate(probe_features):
        expected = probe_labels[idx]
        prediction = knn.predict([feature])[0]
        print("label:%d, predict:%d" % (expected, prediction))
        if prediction == expected:
            hits += 1

    accuracy = hits / len(probe_features)
    print(accuracy)
# Example #3
# 0
def train(train_dir):
    """Train a single-modality GEINet classifier on one GEI dataset.

    Args:
        train_dir: directory handed to ``load_GEI(path_dir=..., mode=True)``.

    The network is wrapped in ``L.Classifier`` and trained on GPU 0 with
    ``training.StandardUpdater`` for 6250 epochs. Trainer output (logs,
    snapshots, plots, graph dump) goes to
    ``/home/wutong/Setoguchi/chainer_files/result``.
    """
    train_data = load_GEI(path_dir=train_dir, mode=True)

    model = L.Classifier(GEINet())
    model.to_gpu()

    train_iter = iterators.SerialIterator(train_data,
                                          batch_size=239,
                                          shuffle=False)

    optimizer = chainer.optimizers.MomentumSGD(lr=0.02, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.01))

    # StandardUpdater takes (iterator, optimizer, converter, device, ...).
    # The model reaches it through the optimizer. Passing the model
    # positionally first shifted every argument by one and crashed at
    # construction.
    updater = training.StandardUpdater(train_iter, optimizer, device=0)
    epoch = 6250

    trainer = training.Trainer(
        updater, (epoch, 'epoch'),
        out='/home/wutong/Setoguchi/chainer_files/result')

    trainer.extend(extensions.ExponentialShift(attr='lr', rate=0.56234),
                   trigger=(1250, 'epoch'))
    trainer.extend(
        extensions.LogReport(log_name='SFDEI_log', trigger=(50, "epoch")))
    trainer.extend(extensions.snapshot(), trigger=(1250, 'epoch'))
    trainer.extend(extensions.snapshot_object(
        target=model, filename='model_snapshot_{.updater.epoch}'),
                   trigger=(1250, 'epoch'))
    # L.Classifier reports its values through the optimizer target, which the
    # trainer registers under the observer name "main". The observed keys are
    # therefore "main/loss" and "main/accuracy"; a bare "loss" key does not
    # exist and made dump_graph raise KeyError.
    trainer.extend(
        extensions.PrintReport(['epoch', 'main/accuracy', 'main/loss']))
    trainer.extend(
        extensions.dump_graph(root_name="main/loss",
                              out_name="multi_modal.dot"))
    trainer.extend(extensions.PlotReport(["main/loss"]),
                   trigger=(20, 'epoch'))
    trainer.extend(extensions.ProgressBar())

    trainer.run()
# Example #4
# 0
def train(mode):
    """Train Multi_modal_GEINet on the eight CV01 Dt1..Dt8 GEI datasets.

    Each dataset gets its own ``SerialIterator`` (batch size 239, no
    shuffling). All eight are passed, in Dt1..Dt8 order, to
    ``Multi_modal_Updater`` on GPU 0 for a 6250-epoch run. Trainer output
    goes to ``/home/wutong/Setoguchi/chainer_files/result``.

    Args:
        mode: ``True`` trains from scratch. Any other value resumes by loading
            the trainer state written by a previous call before running.

    Side effects:
        After training, saves the trainer state to
        ``.../SFDEINet_multi_modal/SFDEINet_multi_modal_model_trainer`` and
        the model parameters to
        ``.../SFDEINet_multi_modal/SFDEINet_multi_modal_model``.
    """
    data_root = ("/media/wutong/New Volume/reseach/Dataset/"
                 "OU-ISIR_by_Setoguchi/Gallery/signed/128_3ch/")
    # Dt1 lives in a "_2nd" directory; Dt2..Dt8 share one naming pattern.
    train_dirs = [data_root + "CV01_(Gallery&Probe)_2nd"] + [
        data_root + "CV01_Dt%d_(Gallery&Probe)" % i for i in range(2, 9)
    ]
    train_sets = [load_GEI(path_dir=d, mode=True) for d in train_dirs]

    model = Multi_modal_GEINet()
    model.to_gpu()

    train_iters = [
        iterators.SerialIterator(ds, batch_size=239, shuffle=False)
        for ds in train_sets
    ]

    optimizer = chainer.optimizers.MomentumSGD(lr=0.02, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.01))

    # Positional order is (model, Dt1..Dt8 iterators, optimizer). The list is
    # built up front because Python 2 forbids positional args after *expr.
    updater_args = [model] + train_iters + [optimizer]
    updater = Multi_modal_Updater(*updater_args, device=0)
    epoch = 6250

    trainer = training.Trainer(
        updater, (epoch, 'epoch'),
        out='/home/wutong/Setoguchi/chainer_files/result')

    trainer.extend(extensions.ExponentialShift(attr='lr', rate=0.56234),
                   trigger=(1250, 'epoch'))
    trainer.extend(
        extensions.LogReport(log_name='SFDEI_log', trigger=(20, "epoch")))
    # The filename is str.format()-ed with the Trainer, which exposes
    # `.updater`; a `{.update.epoch}` field raises AttributeError.
    trainer.extend(extensions.snapshot_object(
        model, filename='model_snapshot_{.updater.epoch}'),
                   trigger=(1250, 'epoch'))
    trainer.extend(extensions.snapshot(), trigger=(1250, 'epoch'))
    trainer.extend(extensions.PrintReport(['epoch', 'accuracy', 'loss']))
    trainer.extend(
        extensions.dump_graph(root_name="loss", out_name="multi_modal_3.dot"))
    trainer.extend(extensions.PlotReport(["loss"]), trigger=(50, 'epoch'))
    trainer.extend(extensions.ProgressBar())

    model_path = ("/home/wutong/Setoguchi/chainer_files/SFDEINet_multi_modal/"
                  "SFDEINet_multi_modal_model")
    # Trainer state gets its own file. Sharing `model_path` meant the final
    # model-only save overwrote it, so a resume had no trainer state to load.
    trainer_state_path = model_path + "_trainer"

    if mode != True:  # noqa: E712 -- anything but an explicit True resumes
        serializers.load_npz(trainer_state_path, trainer)
    trainer.run()
    serializers.save_npz(trainer_state_path, trainer)
    serializers.save_npz(model_path, model)