예제 #1
0
파일: Train.py 프로젝트: ihatasi/Learning
def main():
    """Train the Vanilla autoencoder on MNIST images and snapshot it.

    Model snapshots are written to ``result/model_snapshot_epoch_{N}.npz``
    every ``--snapshot`` epochs, and the training log to ``result/log``.

    Command-line options:
        --batchsize/-b  minibatch size (default 128)
        --epoch/-e      number of training epochs (default 100)
        --gpu/-g        device id passed to the updater/evaluator; a
                        negative value follows Chainer's CPU convention
                        (default 0)
        --snapshot/-s   snapshot interval in epochs (default 10)
        --n_dimz/-z     value passed to ``Network.AE`` as ``n_dimz``,
                        presumably the latent size (default 64)
    """
    parser = argparse.ArgumentParser(description="Vanilla_AE")
    parser.add_argument("--batchsize", "-b", type=int, default=128)
    parser.add_argument("--epoch", "-e", type=int, default=100)
    parser.add_argument("--gpu", "-g", type=int, default=0)
    parser.add_argument("--snapshot", "-s", type=int, default=10)
    parser.add_argument("--n_dimz", "-z", type=int, default=64)
    args = parser.parse_args()

    # Echo the run configuration.
    print("GPU:{}".format(args.gpu))
    print("epoch:{}".format(args.epoch))
    print("Minibatch_size:{}".format(args.batchsize))
    print('')

    batchsize = args.batchsize
    max_epoch = args.epoch

    # Labels are not requested (withlabel=False): the autoencoder only needs
    # images. The first 50000 training images are used for training and the
    # remainder for validation; seed=0 makes the split reproducible. The
    # official test split is not used by this script.
    train_val, _ = mnist.get_mnist(withlabel=False, ndim=1)
    train, valid = split_dataset_random(train_val, 50000, seed=0)

    model = Network.AE(n_dimz=args.n_dimz, n_out=784)  # 784 = 28 * 28 pixels
    if args.gpu >= 0:
        # Put the model on the same device the updater/evaluator are given.
        # This is idempotent, so it is harmless if AEUpdater also moves it.
        # It is done before optimizer setup so optimizer state lives there too.
        model.to_gpu(args.gpu)

    # Training iterates forever with shuffling; validation is a single
    # ordered pass each time the evaluator runs.
    train_iter = iterators.SerialIterator(train, batchsize)
    valid_iter = iterators.SerialIterator(valid,
                                          batchsize,
                                          repeat=False,
                                          shuffle=False)

    def make_optimizer(model, alpha=0.0002, beta1=0.5):
        """Return an Adam optimizer set up on ``model`` with weight decay 1e-4."""
        optimizer = optimizers.Adam(alpha=alpha, beta1=beta1)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.WeightDecay(0.0001))
        return optimizer

    opt = make_optimizer(model)

    updater = Updater.AEUpdater(model=model,
                                iterator=train_iter,
                                optimizer=opt,
                                device=args.gpu)
    trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='result')
    trainer.extend(extensions.LogReport(log_name='log'))
    trainer.extend(
        Evaluator.AEEvaluator(iterator=valid_iter,
                              target=model,
                              device=args.gpu))
    trainer.extend(extensions.snapshot_object(
        model, filename='model_snapshot_epoch_{.updater.epoch}.npz'),
                   trigger=(args.snapshot, 'epoch'))
    trainer.extend(
        extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss']))
    trainer.extend(extensions.ProgressBar())
    trainer.run()
예제 #2
0
def main():
    """Reconstruct 25 MNIST test digits with a trained AE snapshot and plot them.

    Loads ``result/model_snapshot_epoch_{snapshot}.npz`` into a CPU copy of
    ``Network.AE``. It then runs the first 25 test images through the model
    in inference mode and draws the reconstructions in a 5x5 grid, each
    titled (in red) with its true label. The figure is saved to
    ``pict/epoch_{snapshot}.png`` and then shown.

    Command-line options:
        --snapshot/-s  epoch number of the snapshot to load (default 100)
        --n_dimz/-z    passed to ``Network.AE``; should match the value the
                       snapshot was trained with (default 64)
        --batchsize/-b, --gpu/-g  accepted for CLI compatibility but unused:
                       evaluation is a single 25-image batch on the CPU.
    """
    parser = argparse.ArgumentParser(description="Vanilla_AE")
    parser.add_argument("--batchsize", "-b", type=int, default=128)
    parser.add_argument("--gpu", "-g", type=int, default=0)
    parser.add_argument("--snapshot", "-s", type=int, default=100)
    parser.add_argument("--n_dimz", "-z", type=int, default=64)

    args = parser.parse_args()
    os.makedirs('pict', exist_ok=True)

    def plot_mnist_data(samples):
        """Draw up to 25 (flat image, label) pairs in a 5x5 grid, then save and show."""
        for index, (data, label) in enumerate(samples):
            plt.subplot(5, 5, index + 1)  # (rows, cols, 1-based plot index)
            plt.axis('off')
            # 'nearest' renders each of the 28x28 pixels as a crisp square
            # instead of smoothing between them.
            plt.imshow(data.reshape(28, 28), cmap="gray",
                       interpolation="nearest")
            plt.title(str(int(label)), color='red')
        plt.savefig("pict/epoch_{}.png".format(args.snapshot))
        plt.show()

    _, test = mnist.get_mnist(withlabel=True, ndim=1)
    model = Network.AE(n_dimz=args.n_dimz, n_out=784)
    model.to_cpu()
    load_path = 'result/model_snapshot_epoch_{}.npz'.format(args.snapshot)
    chainer.serializers.load_npz(load_path, model)

    samples = test[:25]
    images = np.asarray([data for data, _ in samples], dtype=np.float32)
    labels = [label for _, label in samples]

    # Inference only: switch off train-mode behaviour (e.g. dropout or
    # batch-norm statistic updates, if the network has any) and skip
    # building the backprop graph.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        reconstructions = model(images).data

    # Pair each reconstruction with its true label for plotting.
    plot_mnist_data(list(zip(reconstructions, labels)))