예제 #1
0
# Base interval handed to Timer; the id banner uses three times this interval.
T = 6
timer = Timer(T)
id_timer = Timer(3 * T)

# Train without end: each pass shuffles the keys and runs one step per key.
while True:
    random.shuffle(ks)
    for k in ks:
        hdf5_filename = keys[k]
        solver_inputs = solver_inputs_dic[hdf5_filename]
        sample = solver_inputs[k]
        net_blobs = caffe_net.solver.net.blobs

        # Copy the inputs into the net. The image blob is divided by 255 and
        # shifted down by 0.5; the other two blobs are copied unchanged.
        net_blobs['ZED_data_pool2'].data[:] = (
            sample['ZED_data_pool2'][:] / 255. - 0.5)
        for blob_name in ('metadata', 'steer_motor_target_data'):
            net_blobs[blob_name].data[:] = sample[blob_name][:]

        caffe_net.train_step()

        # Keep [target, 'ip2' output] pairs taken at row 0, columns 9 and 19.
        target = net_blobs['steer_motor_target_data'].data
        output = net_blobs['ip2'].data
        steer.append([target[0, 9], output[0, 9]])
        motor.append([target[0, 19], output[0, 19]])
        ctr += 1

        # Periodic reporting, each gated by its own timer.
        if timer.check():
            plot_performance(steer, motor, caffe_net.loss1000)
            timer.reset()
        if id_timer.check():
            cprint(solver_file_path, 'blue', 'on_yellow')
            id_timer.reset()
예제 #2
0
            # Announce and then sleep for a fixed delay; per the message text,
            # this gives the data thread time to load data.
            wait_delay = 60  # seconds, as passed to time.sleep below
            cprint(
                d2s('Waiting', wait_delay,
                    'seconds to let data thread load a lot of data.'))
            time.sleep(wait_delay)
        # Take the last item of data_list. If it cannot be read, report the
        # error and fall back to None so the check below never sees an unbound
        # name or a stale item left over from an earlier pass.
        try:
            data = data_list[-1]
        except Exception as e:
            cprint("********** Exception ***********************", 'red')
            print(e.message, e.args)
            data = None
        # Identity test: `!= None` would defer to the object's own comparison.
        if data is not None:
            caffe_net.train_step(data)
        else:
            print("data == None")


def plot_loss1000(paths='/home/karlzipser/Desktop/loss1000.pkl',
                  max_num_points=100000,
                  style='ro-'):
    """Load loss values from one or more files and cap how many are kept.

    paths: a single path, or a list/tuple of paths; each is passed to load_obj
        and the loaded values are concatenated in the order given.
    max_num_points: at most this many values are kept, counted from the start.
    style: not used by the statements in this body.
    """
    # Wrap a lone path so the loop always sees a sequence. isinstance also
    # accepts tuples and list subclasses, which an exact type comparison would
    # wrap as if they were one single path.
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    losses = []
    for path in paths:
        # extend grows the list in place rather than re-copying it per file.
        losses.extend(load_obj(path))
    # Slicing is a no-op when there are already few enough values.
    losses = losses[:max_num_points]
예제 #3
0
        # Clear the current figure and redraw the loss1000 curve.
        plt.clf()
        plot(caffe_net.loss1000)
        pause(0.00001)  # presumably gives the plotting backend a chance to draw - confirm
        counts_timer.reset()
        # Collect every value of bag_viewed_counter_dic and take the median.
        # NOTE(review): counts is appended to, so it must already exist before
        # this point; if it is not cleared between refreshes, values from
        # earlier refreshes are counted again - confirm against the caller.
        for c in bag_viewed_counter_dic.keys():
            counts.append(bag_viewed_counter_dic[c])
        counts = sorted(counts)
        #figure('counts')
        #hist(counts,bins=25)
        # count_median is only consumed by commented-out lines in this view.
        count_median = np.median(array(counts))
        #plt.title(count_median)
        #pause(0.0001)
    #cprint(count_median,'red','on_yellow')
    #get_data(BagFolder_dic,bag_img_dic,group_binned_timestamps,NUM_STATE_ONE_STEPS)

    # Run one training step, handing the solver both queues. How train_step
    # uses them is defined elsewhere and not visible here.
    caffe_net.train_step(solver_ready_queue, solver_waiting_queue)

    # Disabled: `if False` means this visualization call never runs.
    if False:
        visualize_data(data, 5, True)
    # Disabled: would increment a per-data['path'] counter in
    # bag_viewed_counter_dic, creating the entry at 0 on first sight.
    # NOTE(review): `data` is not assigned anywhere in the visible code, so
    # re-enabling either block needs it to be defined first - confirm.
    if False:
        bf = data['path']
        if bf not in bag_viewed_counter_dic:
            bag_viewed_counter_dic[bf] = 0
        bag_viewed_counter_dic[bf] += 1
    #print(d2s(bf,bag_viewed_counter_dic[bf]))
# Report the time elapsed since t2 as the training time.
t3 = time.time()
train_time = t3 - t2
print(d2s('train_time =', train_time))

# Report ctr together with ctr as a percentage of ctr2; multiplying ctr2 by
# 1.0 keeps the division in floating point under Python 2.
percent_of_ctr2 = 100 * ctr / (ctr2 * 1.0)
print(d2s(ctr, percent_of_ctr2, '%'))

# Record this run's [load, train] timings, then stop the script.
timing_data.append([load_time, train_time])
exit()