Пример #1
0
def load_data(mode, train=True):
    """Return the ``(mnist, usps)`` data pair for a supported adaptation mode.

    Args:
        mode: One of ``'s2m'``, ``'u2m'`` or ``'m2u'``.
        train: Forwarded unchanged to both ``get_mnist`` and ``get_usps``.

    Returns:
        ``(get_mnist(train), get_usps(train))`` for any supported mode, or
        ``None`` for any other value of ``mode``.

    NOTE(review): every supported mode returns the same pair in the same
    order; confirm whether e.g. ``'s2m'`` / ``'u2m'`` were meant to load a
    different source/target combination.
    """
    supported_modes = ('s2m', 'u2m', 'm2u')
    if mode not in supported_modes:
        return None
    # get_mnist is evaluated before get_usps, matching the original branches.
    return get_mnist(train), get_usps(train)
Пример #2
0
def run(dtrain, dtest, epochs=10, verbose=False):
    """Train a multi-input edge-dropout model on MNIST and return the last trace value.

    Args:
        dtrain: Forwarded to the model family as ``drop_comm_train``.
        dtest: Forwarded to the model family as ``drop_comm_test``.
        epochs: Used both as the collection's ``nepochs`` and as the
            ``niters`` passed to ``train``.
        verbose: Forwarded to the ``Collection``.

    Returns:
        ``traces[-1]['y']`` taken from the result of ``Collection.train``.
    """
    config = AttrDict(
        save_dir="_models",
        iters=epochs,
        epochs=epochs,
        bootstrap_epochs=1,
        ncams=1,
        verbose=verbose,
    )

    experiment_name = 'multiinput_edge_dropout_mpcc_{}'.format(config.ncams)
    collection = Collection(experiment_name, config.save_dir,
                            nepochs=config.epochs, verbose=config.verbose)
    collection.set_model_family(
        MultiInputEdgeDropoutFamily,
        ninputs=config.ncams,
        resume=False,
        merge_function="max_pool_concat",
        drop_comm_train=dtrain,
        drop_comm_test=dtest,
        input_dims=1,
        output_dims=10,
    )

    trainset, testset = get_mnist()
    collection.add_trainset(trainset)
    collection.add_testset(testset)

    # Single-point search space: each hyperparameter has exactly one candidate.
    collection.set_searchspace(
        nfilters_embeded=[3],
        nlayers_embeded=[2],
        nfilters_cloud=[3],
        nlayers_cloud=[2],
        lr=[1e-3],
        branchweight=[.1],
        ent_T=[100],
    )

    # Currently optimizes based on the validation accuracy of the main model.
    history = collection.train(niters=config.iters,
                               bootstrap_nepochs=config.bootstrap_epochs)
    return history[-1]['y']
Пример #3
0
def get_data_loader(name):
    """Get data loader by name.

    Args:
        name: ``'mnist'`` or ``'celeba'``.

    Returns:
        The result of ``get_mnist()`` or ``get_celeba()``.

    Raises:
        AssertionError: If ``name`` is not a supported dataset. Raised
            explicitly instead of via ``assert`` so the check is not stripped
            under ``python -O``. The exception type is kept so existing
            callers that catch ``AssertionError`` keep working.
    """
    if name == 'mnist':
        return get_mnist()
    if name == 'celeba':
        return get_celeba()
    raise AssertionError('[*] dataset not implement!')
Пример #4
0
def get_data_loader(name, path, train=True):
    """Look up a dataset loader by its upper-case name and build it.

    Args:
        name: ``"MNIST"``, ``"USPS"`` or ``"SVHN"`` (case-sensitive).
        path: Forwarded as the first argument to the matching ``get_*`` helper.
        train: Forwarded as the second argument to the matching helper.

    Returns:
        Whatever the matching helper returns, or ``None`` when ``name`` is
        not one of the supported names.
    """
    # Lambdas defer the helper lookup so only the selected helper is resolved.
    candidates = (
        ("MNIST", lambda: get_mnist(path, train)),
        ("USPS", lambda: get_usps(path, train)),
        ("SVHN", lambda: get_svhn(path, train)),
    )
    for label, build in candidates:
        if name == label:
            return build()
    return None
Пример #5
0
def get_loader(name, split, batch_size=50):
    """Return the loader registered under ``name``.

    Args:
        name: ``"mnist"``, ``"usps"``, ``"mnistBig"`` or ``"uspsBig"``.
        split: Forwarded as the first positional argument to the helper.
        batch_size: Forwarded as the second positional argument.

    Returns:
        The helper's result, or ``None`` for an unrecognised ``name``.
    """
    if name == "mnist":
        loader = get_mnist(split, batch_size)
    elif name == "usps":
        loader = get_usps(split, batch_size)
    elif name == "mnistBig":
        loader = mnistBig.get_mnist(split, batch_size)
    elif name == "uspsBig":
        loader = uspsBig.get_usps(split, batch_size)
    else:
        loader = None
    return loader
Пример #6
0
def get_data(name, start_date, end_date):
    """Load a named dataset and wrap it in an ``Entitys`` container.

    Args:
        name: ``'mnist'`` (or its alias ``'default'``), ``'cifar10'`` or
            ``'stock'``.
        start_date: Only used for ``'stock'``; forwarded to ``get_stock``.
        end_date: Only used for ``'stock'``; forwarded to ``get_stock``.

    Returns:
        An ``Entitys`` instance, or ``None`` when ``name`` is not recognised.
    """
    if name == 'mnist' or name == 'default':
        raw = get_mnist()
    elif name == 'cifar10':
        raw = get_cifar10()
    elif name == 'stock':
        raw = get_stock(start_date, end_date)
    else:
        return None

    # The getters yield fields in this order. Entitys takes them reordered:
    # input_shape before nb_classes, and y_train before x_test.
    nb_classes, input_shape, x_train, x_test, y_train, y_test = raw
    return Entitys(input_shape, nb_classes, x_train, y_train, x_test, y_test)
Пример #7
0
def get_loader(name, split, batch_size=50):
    """Return the loader registered under ``name``.

    Args:
        name: ``"mnist"``, ``"usps"``, ``"mnistBig"``, ``"uspsBig"`` or
            ``"coxs2v"``.
        split: Forwarded as the first positional argument to the helper.
        batch_size: Forwarded as the second positional argument.

    Returns:
        The selected helper's result.

    Raises:
        ValueError: If ``name`` is not supported. ``ValueError`` subclasses
            ``Exception``, so handlers written for the previous bare
            ``Exception`` still catch it.
    """
    if name == "mnist":
        return get_mnist(split, batch_size)
    if name == "usps":
        return get_usps(split, batch_size)
    if name == "mnistBig":
        return mnistBig.get_mnist(split, batch_size)
    if name == "uspsBig":
        return uspsBig.get_usps(split, batch_size)
    if name == "coxs2v":
        return coxs2v.get_coxs2v(split, batch_size)
    raise ValueError("Dataset name {} not supported".format(name))
Пример #8
0
def get_data_loader(name, train=True):
    """Get data loader by name.

    Args:
        name: ``"MNIST"`` or ``"USPS"`` (case-sensitive).
        train: Forwarded to the selected ``get_*`` helper.

    Returns:
        The helper's result, or ``None`` for any other ``name``.
    """
    if name == "MNIST":
        loader_fn = get_mnist
    elif name == "USPS":
        loader_fn = get_usps
    else:
        return None
    return loader_fn(train)
Пример #9
0
def get_data_iter(name, train):
    """Wrap the MNIST or USPS loader in ``get_inf_iterator``.

    Args:
        name: ``'MNIST'`` selects MNIST; any other value selects USPS.
        train: Forwarded as ``train=`` to the loader helper. It was previously
            ignored: both helpers always received ``train=True``.

    Returns:
        ``get_inf_iterator`` applied to the selected loader.
    """
    if name == 'MNIST':
        loader = mnist.get_mnist(train=train)
    else:
        # NOTE(review): any name other than 'MNIST' (including typos) falls
        # back to USPS rather than raising.
        loader = usps.get_usps(train=train)
    return get_inf_iterator(loader)
Пример #10
0
if __name__ == '__main__':
    # Ad-hoc debugging script: take one item from the USPS iterator, upscale
    # its first two images with make_larger_size, and display the first one.
    data_itr_tgt = get_data_iter("USPS", train=True)
    image_tgt, label_tgt = next(data_itr_tgt)
    image_tgt = image_tgt[0:2]  # keep only the first two entries along the first axis
    print(image_tgt.shape)
    print(image_tgt)
    new_tgt = make_larger_size(image_tgt)
    print(new_tgt.shape)
    print(new_tgt)
    # Assumes make_larger_size yields 36x36 images (36 * 36 elements each) - TODO confirm.
    plt.imshow(new_tgt[0].numpy().reshape(36, 36), cmap="gray")
    plt.show()
    print(label_tgt[0])
    # NOTE(review): exit(0) stops the script here, so everything after it in
    # this block is unreachable earlier exploration code. exit() is the
    # interactive helper from `site`; sys.exit() is the script-safe form.
    exit(0)
    #options = { 'dir' : 'data' , 'name' : 'MNIST' , 'batch_size' : 64 , 'dataset_mean' : (0.5,0.5,0.5) , 'dataset_std' : (0.5,0.5,0.5)}
    mnist_loader = mnist.get_mnist(train=True)
    usps_loader = usps.get_usps(train=True)
    # Load a saved figure and convert it to mode "L" (single-channel
    # greyscale, assuming Image is PIL.Image).
    im = Image.open("snapshots/Figure_2.png")
    im = im.convert("L")
    #im = im.resize((image_width, image_height))
    # im.show()
    data = im.getdata()
    # NOTE(review): np.matrix is discouraged by NumPy; np.asarray would do here.
    data = np.matrix(data)
    print(data.shape)
    # For every (images, labels) item from the USPS loader, display the first
    # image reshaped to 28x28 and print its raw values.
    for tgt_img, tgt_label in usps_loader:
        #print(type(tgt_img))
        #print(tgt_img[0])
        plt.imshow(tgt_img[0].numpy().reshape(28, 28), cmap="gray")

        plt.show()
        print(tgt_img[0].numpy())