# Per-object inference run: for every name in args.objects, build the simulation
# setup, run cell-mass inference with a DiffPhysics engine, print the final test
# errors, and open '<obj>_seed<seed>.pkl' under args.outdir for binary writing.
#
# Statement sequence, in order:
#   * The '#idxes_train = ...', '#idxes_test = ...', '#idx_param = 9' runs look
#     like commented-out alternatives; the live values are idxes_test = [],
#     idx_param = obj2paramntrain[obj][0], and idxes_train looked up in idxes_set
#     via obj2paramntrain[obj][1]-1 and obj2sets['sims'][obj][0] (element [0]).
#   * setups comes from get_setups_tools_sim (imported from setups_sim); the
#     'masses' entry of the first target_infos element is then overwritten with
#     gts[obj]['masses'].
#   * engine.setup is called with n_group=-1.
#   * The search ranges depend on args.seed: mass_minmax is
#     2.5 +/- (10 - seed) * 0.25 and fric_minmax is 0.5 +/- (seed - 1) * 0.05,
#     so a larger seed narrows the mass range and widens the friction range.
#   * engine.infer(...) runs with infer_type='cell_mass', nsimslimit=500,
#     cellinfos_init=None and freq_update_fric=1; the last history entry
#     supplies the printed test_error_cell / test_error_xy / test_error_yaw.
#
# NOTE(review): as this line stands, every statement after the first '#' shares
# one physical line with it, so Python reads all of it as a comment and the
# for loop has no body -- confirm the original line breaks and indentation.
# NOTE(review): the import of get_setups_tools_sim appears to sit inside the
# loop; presumably harmless (imports are cached), but confirm it is intended.
for obj in args.objects: #obj = 'hammer' #idxes_train = [4,1,7,8] #idxes_test = [0,2,3,5,6,9] #idxes_train = [1] #idxes_test = [1] #idxes_train = range(10) idxes_test = [] #idx_param = 9 idx_param = obj2paramntrain[obj][0] idxes_train = idxes_set[obj2paramntrain[obj][1]-1][obj2sets['sims'][obj][0]][0] from setups_sim import get_setups_tools_sim setups = get_setups_tools_sim([obj],idxes_train,idxes_test,idx_param) setups['target_infos'][0]['masses'] = gts[obj]['masses'] engine = DiffPhysics(is_dbg=False, fp_video='') engine.setup(setups, n_group=-1) engine.mass_minmax = [2.5 - (10-args.seed)*0.25, 2.5 + (10-args.seed)*0.25] engine.fric_minmax = [0.5 - (args.seed-1 )*0.05, 0.5 + (args.seed-1 )*0.05] print(engine.fric_minmax) history = engine.infer(infer_type='cell_mass',nsimslimit=500, cellinfos_init=None, freq_update_fric=1) print('[Test Result] cell err: %.3f (m) pose err: %.3f (m) %.3f (deg)' % (history[-1]['test_error_cell'], history[-1]['test_error_xy'], history[-1]['test_error_yaw']) ) fp_save = '%s_seed%d.pkl' % (obj,args.seed) with open(os.path.join(args.outdir,fp_save),'wb') as f:
# Per-set inference run: for every set index in args.sets, pick the train/test
# indices, build an engine for args.method, run inference, and print the test
# errors taken from the last history entry.
#
# Statement sequence, in order (nesting is inferred from the if/else order):
#   * Prints the current idx_param / obj / n_trains / idx.
#   * When args.skip is set, fp_save is built from args.method, obj, idx_param,
#     n_trains, idx and obj2ngroups[obj]; if that file already exists the
#     iteration prints '[SKIP] ...' and continues.
#   * idxes_train, idxes_test come from idxes_sim[n_trains-1][idx]; the engine
#     comes from get_engine(args.method).
#   * Methods ending in '_cell':
#       - names starting with 'adiff' or 'asdiff' build setups with no
#         train/test indices and call engine.infer(n_trains, idxes_train[0]);
#       - all others build setups with idxes_train and call
#         engine.infer(infer_type='cell_mass', nsimslimit=500).
#     n_group is then set to len(engine.meta_targets[obj]['cells']['cells']).
#   * All other methods call engine.setup(setups, n_group=n_group) and
#     engine.infer(infer_type='group_mass').
#
# NOTE(review): these statements are run together on one physical line, which
# is not valid Python as written -- confirm the original line breaks, and in
# particular which branch the 'n_group = len(...)' assignment belongs to.
# NOTE(review): the non-'_cell' branch passes 'setups' and 'n_group' to
# engine.setup without assigning them first, and the plain '_cell' branch reads
# 'n_group' before the 'n_group = len(...)' assignment. Presumably both are
# bound by an earlier iteration or by code outside this loop; verify, since a
# first iteration that lands there would otherwise raise NameError.
for idx in args.sets: print('param:%d obj:%s n_train:%d set:%d' % (idx_param,obj,n_trains,idx)) if args.skip: fp_save = os.path.join(args.outdir, '%s_%s_param%d_train%d_set%d_c%d.pkl'%\ (args.method,obj,idx_param,n_trains,idx,obj2ngroups[obj])) if os.path.isfile(fp_save): print('[SKIP] %s'%fp_save) continue idxes_train, idxes_test = idxes_sim[n_trains-1][idx] engine = get_engine(args.method) if args.method.endswith('_cell'): if args.method.startswith('adiff') or args.method.startswith('asdiff'): setups = get_setups_tools_sim([obj],None,None,idx_param) engine.setup(setups) history = engine.infer(n_trains, idxes_train[0]) else: setups = get_setups_tools_sim([obj],idxes_train,None,idx_param) engine.setup(setups, n_group=n_group) history = engine.infer(infer_type='cell_mass',nsimslimit=500) n_group = len(engine.meta_targets[obj]['cells']['cells']) else: engine.setup(setups, n_group=n_group) history = engine.infer(infer_type='group_mass') print('Test Result: %.3f (m) %.3f (deg) %.3f (cell)' % (history[-1]['test_error_xy'], history[-1]['test_error_yaw'],