Пример #1
0
def train_missing_models(cfg, lbl_name, archi_fun):
    """Train every per-cluster model that has no file on disk yet.

    Each key of ``data/<uniqueid>/<lbl_name>.npz`` is read as a cluster count.
    For such a key the directory ``models/<uniqueid>/<lbl_name>/<key>`` is
    scanned for entries named ``model_<idx>_*``; every index in
    ``range(int(key))`` without such an entry is trained through
    ``train_one_model``.

    Parameters
    ----------
    cfg : dict
        Reads ``uniqueid``, ``file_path``, ``x_name`` and ``y_name``.
    lbl_name : str
        Name of the label file and of the model sub-directory.
    archi_fun : callable
        Forwarded unchanged to ``train_one_model``.
    """
    uid = cfg['uniqueid']
    label_data = np.load('data/%s/%s.npz' % (uid, lbl_name))
    data = np.load(cfg['file_path'])
    x, y = data[cfg['x_name']], data[cfg['y_name']]

    # keys() may be a view rather than a list, so sort into a fresh list
    keys = sorted(label_data.keys(), key=int)
    # matched against the file name only, so the path separator does not matter
    model_pattern = re.compile(r'model_(\d+)_')

    for key in keys:
        n_cluster = int(key)
        model_directory = 'models/%s/%s/%s' % (uid, lbl_name, key)
        # True means "no model file found for this index"
        miss_flag = np.ones(n_cluster, dtype=bool)
        for file_ in glob.glob(os.path.join(model_directory, '*')):
            found = model_pattern.match(os.path.basename(file_))
            if found is None:
                # classifier files and anything else that is not a model file
                continue
            mdl_idx = int(found.group(1))
            if mdl_idx < n_cluster:
                miss_flag[mdl_idx] = False
        miss_idx = np.where(miss_flag)[0]
        if miss_idx.size == 0:
            print('No missing model for key ', key)
            continue
        print('missing models are ', miss_idx)
        label = label_data[key]
        x_arr, y_arr, label_arr = pp.getSharedNumpy(x, y, label)
        for idx in miss_idx:
            train_one_model([idx, idx + 1], model_directory, archi_fun, x_arr, y_arr, label_arr)
Пример #2
0
def eval_on_valid(cfg, lbl_name, args):
    """Predict on the validation inputs and store the predictions on disk.

    The inputs come from ``cfg['valid_path']`` under the key
    ``cfg['valid_x_name']`` when that entry exists, else ``cfg['x_name']``.

    With ``args.snn`` set, the model at ``cfg['snn_path']`` is evaluated and its
    output saved to ``data/<uniqueid>/snn_validation_predict.npy``.  Otherwise
    every model set under ``models/<uniqueid>/<lbl_name>`` is evaluated and all
    outputs are saved, keyed by ``str(key)``, to
    ``data/<uniqueid>/<lbl_name>_validation_predict.npz``.
    """
    valid_set = np.load(cfg['valid_path'])
    x_key = cfg['valid_x_name'] if 'valid_x_name' in cfg else cfg['x_name']
    x = valid_set[x_key]
    uid = cfg['uniqueid']

    if args.snn:
        snn_predict = modelLoader(cfg['snn_path'])
        np.save('data/%s/snn_validation_predict.npy' % uid, snn_predict(x))
        return

    # mixture models: one (classifier, regressors) pair per key
    models = util.get_clus_reg_by_dir('models/%s/%s' % (uid, lbl_name))
    keys = models.keys()
    print('existing keys ', keys)
    predictions = {}
    for key in keys:
        print('For key ', key)
        classifier, regressors = models[key]
        predictions[str(key)] = MoMNet(classifier, regressors).getPredY(x)
    np.savez('data/%s/%s_validation_predict.npz' % (uid, lbl_name), **predictions)
Пример #3
0
def eval_valid_error(cfg, lbl_name, args):
    """Print and plot the per-sample L1 prediction error on the validation set.

    The validation arrays are read from ``cfg['valid_path']``; the entries
    ``valid_x_name`` / ``valid_y_name`` take precedence over ``x_name`` /
    ``y_name`` when present.  If the file holds a ``flag`` array, only samples
    with ``flag == 1`` are evaluated.

    With ``args.snn`` set, the single model at ``cfg['snn_path']`` is
    evaluated; otherwise every model set under
    ``models/<uniqueid>/<lbl_name>`` is.  Either way the mean error is printed
    and a histogram of the per-sample errors (log-scaled counts) is shown.
    """
    def show_histogram(errors, labels):
        # one histogram per error array, counts on a log axis starting at 1
        _, ax = plt.subplots()
        ax.hist(errors, bins=20, label=labels)
        # `nonpositive` is the keyword current Matplotlib accepts here
        ax.set_yscale('log', nonpositive='clip')
        ax.set_ylim(1, ax.get_ylim()[1])
        ax.legend()
        plt.show()

    vdata = np.load(cfg['valid_path'])
    x = vdata[cfg['valid_x_name'] if 'valid_x_name' in cfg else cfg['x_name']]
    y = vdata[cfg['valid_y_name'] if 'valid_y_name' in cfg else cfg['y_name']]
    if 'flag' in vdata.keys():
        mask = np.where(vdata['flag'] == 1)
        x = x[mask]
        y = y[mask]
    print('validation set size ', x.shape, y.shape)
    uid = cfg['uniqueid']

    if args.snn:
        mdlfun = modelLoader(cfg['snn_path'])
        predy = mdlfun(x)
        error = np.mean(l1loss(y, predy), axis=1)  # error of each instance
        print(np.mean(error))
        # labelled so that the legend has an entry to show
        show_histogram(error, 'SNN')
        return

    # load MoE models from desired directory
    result = util.get_clus_reg_by_dir('models/%s/%s' % (uid, lbl_name))
    keys = result.keys()
    print('existing keys ', keys)
    v_error = []
    for key in keys:
        print('For key ', key)
        cls, regs = result[key]
        predy = MoMNet(cls, regs).getPredY(x)
        v_error.append(np.mean(l1loss(y, predy), axis=1))
    print('mean error is ', [np.mean(error) for error in v_error])
    # hist needs one plain string per dataset, not a dict keys view
    show_histogram(v_error, [str(key) for key in keys])
Пример #4
0
def train_model_warm(args):
    """Start training from a pre-trained gate network and new experts.

    The label configuration is resolved from ``args``; the classifier file
    stored under key ``5`` of ``models/<uniqueid>/<lbl_name>`` is loaded with
    ``torch.load`` and its ``'model'`` entry is used as the gate.  The experts
    are built as ``Experts([[2, 60, 75]] * 5)`` (presumably five experts with
    layer sizes 2-60-75 -- confirm against ``Experts``).

    NOTE: the ``'xScale'`` entry of the checkpoint is only printed; it is not
    applied to the gate network here.
    """
    cfg, lbl_name = util.get_label_cfg_by_args(args)
    model_root = 'models/%s/%s' % (cfg['uniqueid'], lbl_name)
    # key 5 is hard-coded; the regressor part of the entry is not needed
    gate_file, _ = util.get_clus_reg_by_dir(model_root)[5]

    checkpoint = torch.load(gate_file)
    print(checkpoint.keys())
    gate_net = checkpoint['model']
    xmean, xstd = checkpoint['xScale']
    print('xmean', xmean, 'xstd', xstd)

    run_the_training(args, gate_net, Experts([[2, 60, 75]] * 5))
Пример #5
0
def show_picky_states(cfg, lbl_name, args):
    """Plot SNN, MoE and optimal trajectories for two hand-picked start states.

    For each start position the SNN model (``cfg['snn_path']``) and the MoE
    stored under key ``20`` of ``models/<uniqueid>/pca_kmean_label`` are
    queried with the start position concatenated with the obstacle vector.
    Both predictions and the matching row of
    ``data/droneone/the_two_sol.npy`` are drawn in one 3-D axis, which is saved
    to ``gallery/droneone/moe_vs_snn_example.pdf`` and shown.

    The ``lbl_name`` argument is overwritten and ``args`` is not used.
    """
    # hand-picked obstacle description and start positions
    obstacle = np.array([0, 4, 4, 3])
    starts = np.array([[0, 8, 8], [0, 8, 9]])

    uid = cfg['uniqueid']
    lbl_name = 'pca_kmean_label'
    snn_fun = modelLoader(cfg['snn_path'])
    models = util.get_clus_reg_by_dir('models/%s/%s' % (uid, lbl_name))
    classifier, regressors = models[20]
    net = MoMNet(classifier, regressors)

    # reference solutions, one row per start position
    sol = np.load('data/droneone/the_two_sol.npy')

    fig, ax = pld.get3dAxis()
    pld.addSphere(ax, *obstacle)
    legend_sets = []
    for i, start in enumerate(starts):
        tag = i + 1
        color = 'C%d' % i
        query = np.concatenate((start, obstacle))
        snn_y = snn_fun(query)
        moe_y = net.getPredY(query)

        # (raw output, extra line style, label template) in drawing order
        curves = (
            (snn_y, {'ls': ':'}, 'SNN %d'),
            (moe_y, {'ls': '--'}, 'MoE %d'),
            (sol[i], {}, 'Opt. %d'),
        )
        handles = []
        for raw, line_style, label_fmt in curves:
            traj, _, _ = parseX(raw)
            line, = ax.plot(traj[:, 0], traj[:, 1], traj[:, 2],
                            color=color, label=label_fmt % tag, **line_style)
            handles.append(line)
        legend_sets.append(handles)

        ax.scatter(*start, marker='*', color=color)
        ax.text(start[0], start[1], start[2], 'Start%d' % tag)

    ax.scatter(0, 0, 0, marker='o', color='r')
    ax.text(0, 0, 0, "Goal")
    # one legend per start state; the first is re-added so the second call
    # does not replace it
    first_legend = ax.legend(handles=legend_sets[0], loc=4, bbox_to_anchor=(0.8, 0.1))
    ax.add_artist(first_legend)
    ax.legend(handles=legend_sets[1], loc=2, bbox_to_anchor=(0.15, 0.85))
    ax.set_xlabel(r'$x$(m)')
    ax.set_ylabel(r'$y$(m)')
    ax.set_zlabel(r'$z$(m)')
    ax.view_init(elev=11, azim=10)
    fig.tight_layout()
    fig.savefig('gallery/droneone/moe_vs_snn_example.pdf')
    plt.show()
Пример #6
0
def eval_valid_constr_vio(cfg, lbl_name, args):
    """Evaluate models by the constraint violation of their predictions.

    The violation function comes from ``util.get_xy_vio_fun(cfg)`` and is
    called once per validation sample as ``vio_fun(x[i], predy[i])``.  The
    inputs are read from ``cfg['valid_path']`` (``valid_x_name`` takes
    precedence over ``x_name``); if the file holds a ``flag`` array, only
    samples with ``flag == 1`` are used.

    With ``args.snn`` set, the model at ``cfg['snn_path']`` is evaluated and a
    histogram of its violations is shown.  Otherwise every model set under
    ``models/<uniqueid>/<lbl_name>`` is evaluated, the moving average of each
    violation series is plotted, summary statistics are printed, and a
    histogram is saved to
    ``gallery/<uniqueid>/<lbl_name>_valid_constr_vio_hist.pdf``.
    """
    vio_fun = util.get_xy_vio_fun(cfg)
    vdata = np.load(cfg['valid_path'])
    x = vdata[cfg['valid_x_name'] if 'valid_x_name' in cfg else cfg['x_name']]
    if 'flag' in vdata.keys():
        x = x[np.where(vdata['flag'] == 1)]
    uid = cfg['uniqueid']
    n_data = x.shape[0]

    def violation_of(predy):
        # the violation function is applied one sample at a time
        error = np.zeros(n_data)
        for i in range(n_data):
            error[i] = vio_fun(x[i], predy[i])
        return error

    # first we try the snn case
    if args.snn:
        mdlfun = modelLoader(cfg['snn_path'])
        error = violation_of(mdlfun(x))
        print('average is %f' % (np.sum(error[error < 0]) / n_data))
        print('max error is ', np.amin(error))
        _, ax = plt.subplots()
        ax.hist(error)
        plt.show()
        return

    # load MoE models from desired directory
    result = util.get_clus_reg_by_dir('models/%s/%s' % (uid, lbl_name))
    keys = result.keys()
    print('existing keys ', keys)
    v_error = []
    # first figure: moving average of the violation, one line per model set
    _, ax = plt.subplots()
    for key in keys:
        print('For key ', key)
        cls, regs = result[key]
        error = violation_of(MoMNet(cls, regs).getPredY(x))
        v_error.append(error)
        ax.plot(get_moving_average(error))
    print('mean error is ', [np.mean(error) for error in v_error])
    print('mean neg error ',
          [np.sum(error[error < 0]) / error.shape[0] for error in v_error])
    print('max error is ', [np.amin(error) for error in v_error])

    # second figure: histogram of the violations
    fig, ax = plt.subplots()
    # hist needs one plain string per dataset, not a dict keys view
    ax.hist(v_error, bins=20, label=[str(key) for key in keys])
    # `nonpositive` is the keyword current Matplotlib accepts here
    ax.set_yscale('log', nonpositive='clip')
    ax.set_ylim(1, ax.get_ylim()[1])
    ax.legend()
    ax.set_xlabel('Constraint Violation')
    ax.set_ylabel('Count')
    fig.savefig('gallery/%s/%s_valid_constr_vio_hist.pdf' % (uid, lbl_name))
    plt.show()
def show_picky_states(cfg, lbl_name, args):
    """Draw SNN vs. MoE predictions for validation states the SNN predicts badly.

    Validation samples are ranked by the Euclidean norm of the SNN prediction
    error; the samples ranked 8th to 20th worst are visualised one figure at a
    time.  The left panel shows the SNN, MoE and optimal trajectories, the
    right panel shows the training trajectories returned by
    ``Query.getIndex`` for that state, coloured by their cluster label.

    The MoE and the cluster labels are the ones stored under key ``10`` for
    the ``pca_kmean_label`` label set; the ``lbl_name`` argument is
    overwritten and ``args`` is not used.

    NOTE: every figure is written to the same path,
    ``gallery/car/car_snn_vs_moe_traj.pdf``, so the file ends up holding the
    last figure drawn.
    """
    uid = cfg['uniqueid']
    lbl_name = 'pca_kmean_label'
    # training data (used for the neighbour query) and validation data
    train_set = npload(cfg['file_path'], uid)
    xname, yname = cfg['x_name'], cfg['y_name']
    datax, datay = train_set[xname], train_set[yname]
    query = Query(datax, scale=True)
    valid_set = np.load(cfg['valid_path'])
    vx, vy = valid_set[xname], valid_set[yname]
    # the two predictors being compared
    snn_fun = modelLoader(cfg['snn_path'])
    models = util.get_clus_reg_by_dir('models/%s/%s' % (uid, lbl_name))
    classifier, regressors = models[10]
    net = MoMNet(classifier, regressors)
    # cluster label of every training sample
    label = np.load('data/%s/%s.npz' % (uid, lbl_name))['10']

    # rank validation samples by SNN prediction error, ascending
    pred_vy = snn_fun(vx)
    error_order = np.argsort(np.linalg.norm(pred_vy - vy, axis=1))

    arrow_len = 1
    snn_color = '#ff7f0e'
    for rank in range(7, 20):
        vx_idx = error_order[-1 - rank]
        bad_x0 = vx[vx_idx]
        moe_pred = net.getPredY(bad_x0)

        snn_traj, _, _ = parseX(pred_vy[vx_idx])
        opt_traj, _, _ = parseX(vy[vx_idx])
        moe_traj, _, _ = parseX(moe_pred)

        # training samples selected by the query for this state
        index = query.getIndex(bad_x0)
        print('index ', index, 'label ', label[index])

        fig, axes = plt.subplots(1, 2)
        seen_clusters = []
        for ind in index:
            nn_traj, _, _ = parseX(datay[ind])
            cluster = label[ind]
            line_kwargs = {'color': 'C%d' % cluster}
            if cluster not in seen_clusters:
                # only the first line of each cluster gets a legend entry
                line_kwargs['label'] = 'Cluster %d' % cluster
                seen_clusters.append(cluster)
            axes[1].plot(nn_traj[:, 0], nn_traj[:, 1], **line_kwargs)

        axes[0].plot(snn_traj[:, 0], snn_traj[:, 1], color=snn_color, linewidth=2, ls='--', label='SNN')
        axes[0].plot(moe_traj[:, 0], moe_traj[:, 1], color='g', linewidth=2, ls='--', label='MoE')
        axes[0].plot(opt_traj[:, 0], opt_traj[:, 1], color='k', linewidth=2, label='Opt.')

        # arrows at the final state; the third state component is the angle
        snn_end = snn_traj[-1]
        snn_dir = [arrow_len * np.sin(snn_end[2]), arrow_len * np.cos(snn_end[2])]
        moe_end = moe_traj[-1]
        moe_dir = [arrow_len * np.sin(moe_end[2]), arrow_len * np.cos(moe_end[2])]

        for panel, ax in enumerate(axes):
            is_left = panel == 0
            if is_left:
                ax.arrow(snn_end[0], snn_end[1], snn_dir[0], snn_dir[1], color=snn_color, linewidth=2, width=0.1)
            ax.arrow(moe_end[0], moe_end[1], moe_dir[0], moe_dir[1], color='g', linewidth=2, width=0.1)
            ax.scatter(0, 0, s=50, color='r')
            ax.annotate('Goal', (0, 0), xytext=(0.2, 0.2), textcoords='data')
            ax.scatter(bad_x0[0], bad_x0[1], s=50, color='k', marker='*')
            if is_left:
                start_text = (-1 + bad_x0[0], -0.8 + bad_x0[1])
            else:
                start_text = (bad_x0[0], 0.3 + bad_x0[1])
            ax.annotate('Start', (bad_x0[0], bad_x0[1]), xytext=start_text, textcoords='data')
            ax.set_xlabel(r'$x$')
            ax.axis('equal')
            if is_left:
                xlim = ax.get_xlim()
                ax.set_ylabel(r'$y$')
                ax.legend()
                ax.set_xlim(-2.5, xlim[1] + 1)
            else:
                ax.legend(loc=4)
                xlim = ax.get_xlim()
                ax.set_xlim(xlim[0] - 1, xlim[1] + 1.5)
        fig.tight_layout()
        fig.savefig('gallery/car/car_snn_vs_moe_traj.pdf')
        plt.show()
Пример #8
0
def show_picky_examples(cfg):
    """Create the four pendulum illustration figures under ``gallery/pen``.

    Samples are labelled manually from column 48 of ``Sol`` (presumably the
    final angle): ``int((value + pi) / 2 / pi)``.  The figures are:

    * ``blackPendulum.pdf`` -- trajectories on the border of the 21 x 61
      sample grid, or whose four grid neighbours do not all share their
      column-48 value, drawn in black;
    * ``colorPendulum.pdf`` -- per label, the trajectories whose initial state
      is extreme in omega (for its theta) or in theta (for its omega);
    * ``penSNNbadPred.pdf`` -- optimal vs. SNN trajectory for the sample
      closest to the state ``(2.4, 1.6)``;
    * ``penMoMgoodPred.pdf`` -- the same comparison for the MoE stored under
      key ``5`` of ``models/pen/pca_kmean_label``.

    The grid shape (61, 21) is hard-coded, so the data set must hold exactly
    61 * 21 samples.
    """
    SIZE = 24
    font = pld.getSizeFont(SIZE)
    N, dimx, dimu = 25, 2, 1

    data = np.load(cfg['file_path'])
    x0 = data['x0']
    Sol = data['Sol']
    n_samples = x0.shape[0]

    # manual labels from column 48 of the solution
    label = ((Sol[:, 48] + np.pi) / 2 / np.pi).astype(int)
    print([np.sum(label == i) for i in range(3)])

    # sample whose initial state is closest to (2.4, 1.6)
    onex0 = np.array([2.4, 1.6])
    dist = np.linalg.norm(Sol[:, :dimx] - onex0, axis=1)
    theta_grid = np.reshape(Sol[:, 48], (61, 21)).T
    theind = np.argmin(dist)
    thex0 = Sol[theind, :dimx]
    print('x0 is {}'.format(thex0))

    def mark_targets(ax, angles):
        # red dots at (angle, 0)
        for angle in angles:
            ax.plot(angle, 0, color='r', marker='o', markersize=5)

    def label_axes(ax):
        ax.set_xlabel(r'$\theta$', fontproperties=font)
        ax.set_ylabel(r'$\omega$', fontproperties=font)

    def is_extreme(value, peers):
        # within 1e-4 of the smallest or largest peer value
        return value < peers.min() + 1e-4 or value > peers.max() - 1e-4

    def compare_with_optimal(predict, pred_label, out_path):
        # optimal trajectory (solid) against a model prediction (dashed)
        fig, ax = plt.subplots()
        traj, _, _ = parseX(Sol[theind], N, dimx, dimu)
        ax.plot(traj[:, 0], traj[:, 1], color='k', label='Optimal')
        predy = predict(thex0)
        traj, _, _ = parseX(predy, N, dimx, dimu)
        ax.plot(traj[:, 0], traj[:, 1], color='k', linestyle='--', label=pred_label)
        mark_targets(ax, (3*np.pi,))
        label_axes(ax)
        ax.legend(fontsize=SIZE)
        pld.setTickSize(ax, SIZE)
        pld.savefig(fig, out_path)

    all_targets = (np.pi, -np.pi, 3*np.pi)

    # figure 1: grid border and label-boundary trajectories in black
    blackName = 'gallery/pen/blackPendulum.pdf'
    fig, ax = plt.subplots()
    for i in range(n_samples):
        row, col = np.unravel_index(i, (21, 61), order='F')
        on_border = row in (0, 20) or col in (0, 60)
        if not on_border:
            neighbours = [theta_grid[row - 1, col], theta_grid[row + 1, col],
                          theta_grid[row, col - 1], theta_grid[row, col + 1]]
            if np.allclose(neighbours, theta_grid[row, col]):
                # interior sample that agrees with all four neighbours
                continue
        traj, _, _ = parseX(Sol[i], N, dimx, dimu)
        ax.plot(traj[:, 0], traj[:, 1], color='k')
    mark_targets(ax, all_targets)
    pld.setTickSize(ax, SIZE)
    label_axes(ax)
    pld.savefig(fig, blackName)

    # figure 2: per-label boundary trajectories in colour
    colorName = 'gallery/pen/colorPendulum.pdf'
    ncluster = 3
    colors = pld.getColorCycle()
    fig, ax = plt.subplots()
    for k in range(ncluster):
        members = label == k
        x0_k = x0[members]
        sol_k = Sol[members]
        for (theta, omega), sol_row in zip(x0_k, sol_k):
            # extreme omega among members with (almost) the same theta
            same_theta = np.abs(x0_k[:, 0] - theta) < 1e-2
            boundary = is_extreme(omega, x0_k[same_theta, 1])
            if not boundary:
                # otherwise extreme theta among members with the same omega
                same_omega = np.abs(x0_k[:, 1] - omega) < 1e-2
                boundary = is_extreme(theta, x0_k[same_omega, 0])
            if boundary:
                traj, _, _ = parseX(sol_row, N, dimx, dimu)
                ax.plot(traj[:, 0], traj[:, 1], color=colors[k])
    mark_targets(ax, all_targets)
    label_axes(ax)
    pld.setTickSize(ax, SIZE)
    pld.savefig(fig, colorName)

    # figure 3: SNN prediction for the selected state
    snnbadName = 'gallery/pen/penSNNbadPred.pdf'
    mdl = modelLoader(cfg['snn_path'])
    compare_with_optimal(mdl, 'SNN Pred.', snnbadName)

    # figure 4: MoE prediction for the selected state
    momgoodName = 'gallery/pen/penMoMgoodPred.pdf'
    result = util.get_clus_reg_by_dir('models/pen/pca_kmean_label')
    cls, regs = result[5]
    net = MoMNet(cls, regs)
    compare_with_optimal(net.getPredY, 'MoE Pred.', momgoodName)

    plt.show()