示例#1
0
    def __init__(self, args, test_rollouts=100):
        """Create the environments and accuracy bookkeeping for testing.

        Args:
            args: Parsed run configuration. ``args.save_acc`` and
                ``args.goal`` are read here, and ``args`` is forwarded to
                ``make_env``.
            test_rollouts: Number of extra (env, env_test) pairs created when
                ``args.save_acc`` is set. Defaults to 100, the value that was
                previously hard-coded.
        """
        self.args = args
        self.env = make_env(args)
        self.env_test = make_env(args)

        # Metric names reported by this tester; populated below only when
        # accuracy logging is enabled.
        self.info = []
        if args.save_acc:
            make_dir('log/accs', clear=False)
            self.test_rollouts = test_rollouts

            # Keep the interleaved creation order (env, env_test, env, ...)
            # in case make_env has order-dependent side effects.
            self.env_List = []
            self.env_test_List = []
            for _ in range(self.test_rollouts):
                self.env_List.append(make_env(args))
                self.env_test_List.append(make_env(args))

            # One success history per goal, plus the matching metric name.
            self.acc_record = {self.args.goal: []}
            self.info.extend('Success/' + key + '@blue' for key in self.acc_record)
示例#2
0
	def __init__(self, args, test_rollouts=100, after_train_test=False):
		"""Build the environments and metric bookkeeping used during testing.

		Args:
			args: Parsed run configuration. ``args.save_acc`` and ``args.goal``
				are read here, and ``args`` is forwarded to ``make_env``.
			test_rollouts: How many additional environments to build when
				``args.save_acc`` is set.
			after_train_test: Stored on the instance as ``self.after_train_test``.
		"""
		self.args = args
		self.env = make_env(args)
		self.env_test = make_env(args)

		self.info = []
		self.calls = 0
		self.after_train_test = after_train_test

		if args.save_acc:
			make_dir('log/accs', clear=False)
			self.test_rollouts = test_rollouts
			self.env_List = [make_env(args) for _ in range(self.test_rollouts)]

			self.acc_record = {self.args.goal: []}
			# The metric names do not depend on the goal key; they are added
			# once per acc_record entry (there is exactly one here).
			for _goal in self.acc_record:
				self.info.extend(['Success@blue', 'MaxDistance', 'MinDistance'])

		# Only consulted by the tests run after training.
		self.coll_tol = 0
示例#3
0
文件: test.py 项目: MouseHu/gem
    def __init__(self, args):
        """Create the environment and the optional reward / Q-value logs.

        Args:
            args: Parsed run configuration. ``args.save_rews``,
                ``args.save_Q`` and ``args.env`` are read here, and ``args``
                is forwarded to ``make_env``.
        """
        self.args = args
        self.env = make_env(args)
        self.info = []

        if args.save_rews:
            make_dir('log/rews', clear=False)
            self.rews_record = {args.env: []}

        if args.save_Q:
            for log_dir in ('log/Q_std', 'log/Q_net', 'log/Q_ground'):
                make_dir(log_dir, clear=False)
            # Each record gets its own history list, keyed by the env id.
            self.Q_std_record = {args.env: []}
            self.Q_net_record = {args.env: []}
            self.Q_ground_record = {args.env: []}
            self.info.extend(['Q_error/mean', 'Q_error/std'])
示例#4
0
        [ppairs[inds1][:, 0, :], ppairs_around_obstacle[inds2][:, 0, :]],
        axis=1)
    return new_ppairs


def flat_entries(bboxes_list, ppair):
    """Flatten the bounding boxes and the point pair into one 1-D array.

    Args:
        bboxes_list: Array-like of bounding boxes, of any shape.
        ppair: Array-like point pair, of any shape.

    Returns:
        A 1-D ``np.ndarray`` with every element of ``bboxes_list``
        (row-major order) followed by every element of ``ppair``.
    """
    # np.ravel (unlike the ndarray .ravel() method) also accepts lists and
    # tuples; for ndarray inputs the result is identical.
    return np.concatenate([np.ravel(bboxes_list), np.ravel(ppair)])


if __name__ == "__main__":
    # Script entry point: prepare the per-environment data folder and
    # (re)create the CSV files for the distance dataset.
    args = get_args()
    # Build the data folder path for this env (data/<env>/ next to this file)
    # and create it if it does not exist.
    this_file_dir = os.path.dirname(os.path.abspath(__file__)) + '/'
    base_data_dir = this_file_dir + 'data/'
    env_data_dir = base_data_dir + args.env + '/'
    # NOTE(review): presumably clear=False leaves existing folder contents in
    # place — confirm against make_dir.
    make_dir(env_data_dir, clear=False)

    # load_vaes is only called when args.vae_dist_help is set.
    if args.vae_dist_help:
        load_vaes(args)
    load_field_parameters(args)
    env = make_temp_env(args)

    # Column names shared by the train / validation / test CSV files.
    field_names = ['ppair', 'bbox', 'distance']
    csv_file_path = env_data_dir + 'distances.csv'
    csv_file_path_val = env_data_dir + 'distances_val.csv'
    csv_file_path_test = env_data_dir + 'distances_test.csv'
    # Start each run from fresh files: delete any previous output, then open
    # a new file and wrap it in a DictWriter with the columns above.
    # NOTE(review): mode 'w' already truncates, so os.remove is redundant; the
    # csv docs also recommend open(..., newline='') to avoid blank rows on
    # Windows.
    for csv_path in [csv_file_path, csv_file_path_val, csv_file_path_test]:
        if os.path.exists(csv_path):
            os.remove(csv_path)
        with open(csv_path, 'w') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=field_names)
示例#5
0
                        default=6)
    parser.add_argument('--beta',
                        help='beta val for the reconstruction loss',
                        type=np.float,
                        default=8.)  #5#8
    parser.add_argument('--gamma',
                        help='gamma val for the mask loss',
                        type=np.float,
                        default=5.)  #2.)#5
    parser.add_argument('--bg_sigma', help='', type=np.float, default=0.09)
    parser.add_argument('--fg_sigma', help='', type=np.float, default=0.11)

    args = parser.parse_args()

    # get names corresponding folders, and files where to store data
    make_dir(this_file_dir + 'results/', clear=False)
    base_data_dir = this_file_dir + '../data/'
    data_dir = base_data_dir + args.env + '/'

    train_file = data_dir + 'all_set.npy'
    weights_path = data_dir + 'all_sb_model'

    if args.task == 'train':
        train_Vae(epochs=args.train_epochs,
                  batch_size=args.batch_size,
                  img_size=args.img_size,
                  latent_size=args.latent_size,
                  train_file=train_file,
                  vae_weights_path=weights_path,
                  beta=args.beta,
                  gamma=args.gamma,
示例#6
0
    parser.add_argument('--task', help='the task for the generation of data', type=str,
                        default='generate', choices=['generate', 'mix', 'show'], required=True)
    #CURRENTLY using just FetchGenerativeEnv-v1; argument is left since it is used in different parts of code
    parser.add_argument('--env', help='gym env id', type=str, default='FetchGenerativeEnv-v1', choices=Robotics_envs_id)
    args, _ = parser.parse_known_args()
    if args.task == 'mix':
        parser.add_argument('--file_1', help='first file to mix', type=str)
        parser.add_argument('--file_2', help='second file to mix', type=str)
        parser.add_argument('--output_file', help='name of output file for mixed dataset', type=str)
        args = parser.parse_args()

        this_file_dir = os.path.dirname(os.path.abspath(__file__)) + '/'
        base_data_dir = this_file_dir + '../data/'
        env_data_dir = base_data_dir + args.env + '/'
        make_dir(env_data_dir, clear=False)
        data_file_1 = env_data_dir + args.file_1
        data_file_2 = env_data_dir + args.file_2

        data_1 = np.load(data_file_1)
        data_2 = np.load(data_file_2)
        mixed_data = np.concatenate([data_1, data_2], axis=0)
        np.random.shuffle(mixed_data)

        output_file = env_data_dir + args.output_file
        np.save(output_file, mixed_data)

    else:
        if args.env == 'HandReach-v0':
            parser.add_argument('--goal', help='method of goal generation', type=str, default='reach',
                                choices=['vanilla', 'reach'])