예제 #1
0
def main():
    """Run the evaluation once for every log directory in FLAGS.log_dir.

    For each log directory this:
      * checks that the directory and its validation sub-directory exist
        (the process exits with status 1 otherwise),
      * prepares the eval, prediction and optional visualization directories,
      * copies the model definition and this script into the eval directory,
      * collects the 'test-*.h5' files for the configured category/level,
      * loads the per-category thresholds written into the validation
        directory, and
      * calls ``eval`` with everything gathered above.
    """
    for LOG_DIR in FLAGS.log_dir:
        CKPT_DIR = os.path.join(LOG_DIR, 'trained_models')
        if not os.path.exists(LOG_DIR):
            print('ERROR: log_dir %s does not exist! Please Check!' % LOG_DIR)
            exit(1)
        VALID_DIR = os.path.join(LOG_DIR, FLAGS.valid_dir)
        if not os.path.exists(VALID_DIR):
            print('ERROR: valid_dir %s does not exist! Run valid.py first!' % VALID_DIR)
            exit(1)
        # From here on LOG_DIR names the eval sub-directory of the log dir.
        LOG_DIR = os.path.join(LOG_DIR, FLAGS.eval_dir)
        check_mkdir(LOG_DIR)
        PRED_DIR = os.path.join(LOG_DIR, FLAGS.pred_dir)
        force_mkdir(PRED_DIR)
        # Always bind VISU_DIR: it is passed to eval() below, and leaving it
        # unassigned raised UnboundLocalError whenever no visu_dir was given.
        # NOTE(review): eval() now receives None in that case -- confirm it
        # treats None as "no visualization".
        VISU_DIR = None
        if FLAGS.visu_dir is not None:
            VISU_DIR = os.path.join(LOG_DIR, FLAGS.visu_dir)
            force_mkdir(VISU_DIR)

        os.system('cp %s %s' % (MODEL_FILE, LOG_DIR)) # backup of the model definition
        os.system('cp %s %s' % (__file__, LOG_DIR)) # backup of this script
        # The context manager closes the log file even if evaluation fails.
        with open(os.path.join(LOG_DIR, 'log_eval.txt'), 'w') as LOG_FOUT:
            LOG_FOUT.write(str(FLAGS)+'\n')

            data_in_dir = '../../data/ins_seg_h5_for_sgpn/%s-%d/' % (FLAGS.category, FLAGS.level_id)
            test_h5_fn_list = [
                os.path.join(data_in_dir, item)
                for item in os.listdir(data_in_dir)
                if item.endswith('.h5') and item.startswith('test-')
            ]

            NUM_CLASSES = 1
            print('force Semantic Labels: ', NUM_CLASSES)
            NUM_INS = FLAGS.num_ins
            print('Number of Instances: ', NUM_INS)

            # load validation hyper-parameters
            pw_sim_thres = np.loadtxt(os.path.join(VALID_DIR, 'per_category_pointwise_similarity_threshold.txt')).reshape(NUM_CLASSES)
            avg_group_size = np.loadtxt(os.path.join(VALID_DIR, 'per_category_average_group_size.txt')).reshape(NUM_CLASSES)
            # The minimum group size is a quarter of the average group size.
            min_group_size = 0.25 * avg_group_size

            for i in range(NUM_CLASSES):
                print('%d %f %d' % (i, pw_sim_thres[i], min_group_size[i]))

            # main
            log_string('pid: %s'%(str(os.getpid())), LOG_FOUT)
            eval(NUM_CLASSES=NUM_CLASSES, NUM_POINT=NUM_POINT, NUM_INS=NUM_INS, CKPT_DIR=CKPT_DIR,
                 test_h5_fn_list=test_h5_fn_list, min_group_size=min_group_size, pw_sim_thres=pw_sim_thres,
                 PRED_DIR=PRED_DIR, VISU_DIR=VISU_DIR, LOG_FOUT=LOG_FOUT)
예제 #2
0
from PIL import Image
import scipy.misc as misc

from detect_adj import compute_adj
from detect_ref_sym import compute_ref_sym, atob_ref_sym
from detect_trans_sym import compute_trans_sym, atob_trans_sym
from detect_rot_sym import compute_rot_sym, atob_rot_sym

# The annotation id is the first command-line argument.
anno_id = sys.argv[1]

# Input locations for this annotation.
in_dir = os.path.join('../../data', anno_id)
render_dir = os.path.join(in_dir, 'parts_render_after_merging')
json_fn = os.path.join(in_dir, 'result_after_merging.json')

# Output locations: results/<anno_id>/visu/{parent,info,child}.
out_dir = os.path.join('results', anno_id)
visu_dir = os.path.join(out_dir, 'visu')
parent_dir = os.path.join(visu_dir, 'parent')
info_dir = os.path.join(visu_dir, 'info')
child_dir = os.path.join(visu_dir, 'child')

# Create the output tree, outermost directory first.
check_mkdir(out_dir)
force_mkdir(visu_dir)
os.mkdir(parent_dir)
os.mkdir(info_dir)
os.mkdir(child_dir)

# Only the first element of the top-level JSON value is used.
with open(json_fn, 'r') as fin:
    data = json.load(fin)[0]

found_edges = {}

from commons import check_mkdir, force_mkdir

# Command-line options: the log directory, the eval sub-directory inside it,
# and an optional visualization sub-directory.
parser = argparse.ArgumentParser()
parser.add_argument('--log_dir', type=str, default='log', help='Log dir [default: log]')
parser.add_argument('--eval_dir', type=str, default='eval', help='Eval dir [default: eval]')
parser.add_argument('--visu_dir', type=str, default=None, help='Visu dir [default: None, meaning no visu]')
FLAGS = parser.parse_args()

# The log directory must already exist; abort with status 1 otherwise.
LOG_DIR = FLAGS.log_dir
if not os.path.exists(LOG_DIR):
    print('ERROR: log_dir %s does not exist! Please Check!' % LOG_DIR)
    exit(1)
# From here on LOG_DIR names the eval sub-directory inside the log directory.
LOG_DIR = os.path.join(LOG_DIR, FLAGS.eval_dir)
# VISU_DIR is only bound when --visu_dir was given on the command line.
if FLAGS.visu_dir is not None:
    VISU_DIR = os.path.join(LOG_DIR, FLAGS.visu_dir)
    force_mkdir(VISU_DIR)

def get_palette(num_cls):
    """ Returns the color map for visualizing the segmentation mask.
    Args:
        num_cls: Number of classes
    Returns:
        The color map
    """

    # NOTE(review): the visible body never returns a palette, and everything
    # from `call(cmd, shell=True)` onward uses names (cmd, in_h5_fn,
    # out_h5_fn, NUM_POINT) that are not defined in this function. It looks
    # like the rest of get_palette was lost and the tail of another routine
    # (possibly reformat_data, which the script further down calls with
    # (in_h5_fn, out_h5_fn)) was spliced in -- confirm against the originals.
    n = num_cls
    # Flat list with three slots per class, all initialised to 0.
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        # Zero the first two slots of class j's triple.
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
    call(cmd, shell=True)

    # save h5
    pts = load_h5(in_h5_fn)
    print 'pts: ', pts.shape

    # get the first NUM_POINT points
    # (the slice is taken along the second of three indexed axes)
    pts = pts[:, :NUM_POINT, :]

    save_h5(out_h5_fn, pts)


# main
data_in_dir = '../../data/ins_seg_h5/%s/' % args.category
data_out_dir = '../../data/ins_seg_h5_for_detection'
force_mkdir(data_out_dir)
data_out_dir = os.path.join(data_out_dir,
                            '%s-%d' % (args.category, args.level_id))
force_mkdir(data_out_dir)

print args.category, args.level_id, 'test'

h5_fn_list = []
for item in os.listdir(data_in_dir):
    if item.endswith('.h5') and item.startswith('test-'):
        h5_fn_list.append(item)

for item in h5_fn_list:
    in_h5_fn = os.path.join(data_in_dir, item)
    out_h5_fn = os.path.join(data_out_dir, item)
    reformat_data(in_h5_fn, out_h5_fn)
예제 #5
0
# Run-time settings copied from the parsed command-line flags.
BATCH_SIZE = FLAGS.batch_size
NUM_POINT = FLAGS.num_point
GPU_INDEX = FLAGS.gpu

MODEL = importlib.import_module(FLAGS.model)  # import network module
# Path of the model's source file, used for the backup copy below.
MODEL_FILE = os.path.join(ROOT_DIR, 'models', FLAGS.model + '.py')
# The log directory must already exist; CKPT_DIR points at its
# 'trained_models' sub-directory.
LOG_DIR = FLAGS.log_dir
CKPT_DIR = os.path.join(LOG_DIR, 'trained_models')
if not os.path.exists(LOG_DIR):
    print('ERROR: log_dir %s does not exist! Please Check!' % LOG_DIR)
    exit(1)
# From here on LOG_DIR names the eval sub-directory inside the log directory.
LOG_DIR = os.path.join(LOG_DIR, FLAGS.eval_dir)
check_mkdir(LOG_DIR)
PRED_DIR = os.path.join(LOG_DIR, FLAGS.pred_dir)
force_mkdir(PRED_DIR)
# VISU_DIR is only bound when a visualization directory was requested.
if FLAGS.visu_dir is not None:
    VISU_DIR = os.path.join(LOG_DIR, FLAGS.visu_dir)
    force_mkdir(VISU_DIR)

os.system('cp %s %s' % (MODEL_FILE, LOG_DIR))  # backup of the model definition
os.system('cp %s %s' % (__file__, LOG_DIR))  # backup of this script
# Evaluation log; the parsed flags are written as its first line.
LOG_FOUT = open(os.path.join(LOG_DIR, 'log_eval.txt'), 'w')
LOG_FOUT.write(str(FLAGS) + '\n')

# load meta data files
stat_in_fn = '../../stats/after_merging_label_ids/%s-level-%d.txt' % (
    FLAGS.category, FLAGS.level_id)
print('Reading from ', stat_in_fn)
# Each line is split on whitespace and its second field is kept as the part
# name.
with open(stat_in_fn, 'r') as fin:
    part_name_list = [item.rstrip().split()[1] for item in fin.readlines()]
예제 #6
0
FLAGS = parser.parse_args()

# Training settings copied from the parsed command-line flags.
BATCH_SIZE = FLAGS.batch_size
NUM_POINT = FLAGS.num_point
MAX_EPOCH = FLAGS.max_epoch
BASE_LEARNING_RATE = FLAGS.learning_rate
GPU_INDEX = FLAGS.gpu
MOMENTUM = FLAGS.momentum
OPTIMIZER = FLAGS.optimizer

MODEL = importlib.import_module(FLAGS.model) # import network module
# Path of the model's source file, used for the backup copy below.
MODEL_FILE = os.path.join(ROOT_DIR, 'models', FLAGS.model+'.py')
LOG_DIR = FLAGS.log_dir
check_mkdir(LOG_DIR)
# CKPT_DIR is the 'trained_models' sub-directory of the log directory.
CKPT_DIR = os.path.join(LOG_DIR, 'trained_models')
force_mkdir(CKPT_DIR)

# VISU_DIR is only bound when a visualization directory was requested.
if FLAGS.visu_dir is not None:
    VISU_DIR = os.path.join(LOG_DIR, FLAGS.visu_dir)
    force_mkdir(VISU_DIR)

os.system('cp %s %s' % (MODEL_FILE, LOG_DIR)) # backup of the model definition
os.system('cp %s %s' % (__file__, LOG_DIR)) # backup of this training script
# Training log; the parsed flags are written as its first line.
LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
LOG_FOUT.write(str(FLAGS)+'\n')

def log_string(out_str, log_fout=None):
    """Write one line to a log file and echo it to stdout.

    Args:
        out_str: Text to log; a newline is appended when writing to the file.
        log_fout: Optional writable file-like object to log to. When omitted,
            the module-level LOG_FOUT is used, so existing one-argument
            calls behave exactly as before.
    """
    if log_fout is None:
        log_fout = LOG_FOUT
    log_fout.write(out_str+'\n')
    # Flush immediately so the line reaches the file without waiting for the
    # write buffer to fill.
    log_fout.flush()
    print(out_str)