예제 #1
0
파일: n2d.py 프로젝트: snazari/n2d
    # Boolean switch: args.visualize is False unless --visualize is given.
    parser.add_argument('--visualize', default=False, action='store_true')
    args = parser.parse_args()
    # Print the parsed namespace.
    print(args)

    # Optimizer given by name; it is not used within the lines shown here.
    optimizer = 'adam'
    # Presumably the project's own datasets module (it must provide these
    # loader functions) -- confirm.
    from datasets import load_mnist, load_mnist_test, load_usps, load_pendigits, load_fashion, load_har

    # Only the 'fashion' and 'har' branches below assign label names; for every
    # other dataset this stays None.
    label_names = None
    # Bind x and y from the loader that matches args.dataset.
    # NOTE(review): there is no else branch, so an unrecognized args.dataset
    # leaves x and y unbound and the first later use of x raises NameError --
    # unless the parser restricts the accepted values (not visible here).
    if args.dataset == 'mnist':
        x, y = load_mnist()
    elif args.dataset == 'mnist-test':
        x, y = load_mnist_test()
    elif args.dataset == 'usps':
        x, y = load_usps()
    elif args.dataset == 'pendigits':
        x, y = load_pendigits()
    elif args.dataset == 'fashion':
        x, y, label_names = load_fashion()
    elif args.dataset == 'har':
        x, y, label_names = load_har()

    # Layer sizes handed to the autoencoder factory: the size of x's last axis,
    # then 500, 500, 2000, and finally args.n_clusters.
    shape = [x.shape[-1], 500, 500, 2000, args.n_clusters]
    # NOTE(review): this rebinds the name of the factory it calls. If these
    # lines sit inside a function, `autoencoder` is a local name and the call
    # raises UnboundLocalError; at module level it works once but the factory
    # is no longer reachable under this name afterwards. Confirm the scope.
    autoencoder = autoencoder(shape)

    # With the five-entry shape above, len(shape) - 2 == 3, so this looks up
    # the layer named 'encoder_3' and takes its output.
    hidden = autoencoder.get_layer(name='encoder_%d' % (len(shape) - 2)).output
    # Model that maps the autoencoder's input to that layer's output.
    encoder = Model(inputs=autoencoder.input, outputs=hidden)

    # Value of time() taken just before the pretraining step that follows;
    # presumably time.time, imported outside this excerpt -- confirm.
    pretrain_time = time()

    # Pretrain autoencoders before clustering
    if args.ae_weights is None:
예제 #2
0
파일: IDEC.py 프로젝트: epideep/source
    # An update_interval of 0 is treated specially further down: it is replaced
    # by a value computed from the dataset size and the batch size.
    parser.add_argument('--update_interval', default=1, type=int)
    parser.add_argument('--tol', default=0.001, type=float)
    # No type is given, so a supplied value stays a string; None when omitted.
    parser.add_argument('--ae_weights', default=None)
    parser.add_argument('--save_dir', default='results/idec')
    args = parser.parse_args()
    # Print the parsed namespace.
    print(args)

    # Load the dataset selected by args.dataset into x and y.
    # The optimizer is given by name and is not used within the lines shown
    # here; the trailing comment keeps a previously considered SGD setting.
    optimizer = 'adam'  # SGD(lr=0.01, momentum=0.99)
    # Presumably the project's own datasets module (it must provide these
    # loader functions) -- confirm.
    from datasets import load_mnist, load_reuters, load_usps, load_pendigits, load_mydata
    # NOTE(review): there is no else branch, so an unrecognized args.dataset
    # leaves x and y unbound and the first later use of x raises NameError --
    # unless the parser restricts the accepted values (not visible here).
    if args.dataset == 'mnist':  # recommends: n_clusters=10, update_interval=140
        x, y = load_mnist()
    elif args.dataset == 'usps':  # recommends: n_clusters=10, update_interval=30
        x, y = load_usps('data/usps')
    elif args.dataset == 'pendigits':
        x, y = load_pendigits('data/pendigits')
    elif args.dataset == 'reutersidf10k':  # recommends: n_clusters=4, update_interval=20
        x, y = load_reuters('data/reuters')
    elif args.dataset == 'mydata':  # recommends: n_clusters=4, update_interval=20
        x, y = load_mydata(path='./data/mydata')

    # An update_interval of 0 means "one epoch": replace it with the number of
    # whole batches in x (sample count divided by batch size, truncated).
    # NOTE(review): the result is 0 again when args.batch_size exceeds
    # x.shape[0]; args.batch_size itself is defined outside this excerpt.
    if args.update_interval == 0:  # one epoch
        args.update_interval = int(x.shape[0] / args.batch_size)

    # Define IDEC model
    # dims starts at the size of x's last axis, followed by 500, 500, 2000, 18.
    # NOTE(review): the final entry (18) is hard-coded rather than taken from
    # the command line -- confirm this is intended for every dataset.
    idec = IDEC(dims=[x.shape[-1], 500, 500, 2000, 18],
                n_clusters=args.n_clusters)
    #plot_model(idec.model, to_file='idec_model.png', show_shapes=True)
    idec.model.summary()

    # Value of time() taken after the model is built; presumably time.time,
    # imported outside this excerpt, used as a start timestamp -- confirm.
    t0 = time()