コード例 #1
0
    def __init__(self, opt, phase, transform=None, quiet=False):
        """Load one split of the question/image HDF5 data and the vocab JSON.

        Args:
            opt: dict of options. Reads 'h5_ques_file', 'feature_type',
                'json_file', and either 'h5_img_file_train' (train phase)
                or 'h5_img_file_test' (any other phase).
            phase: 'train' loads the ``*_train`` datasets; any other value
                (validation or test) loads the ``*_test`` datasets.
            transform: stored unchanged on ``self.transform``.
            quiet: if True, suppress the progress messages printed here.

        Every HDF5 dataset is copied into an in-memory numpy array, so both
        HDF5 files are closed again before this method returns.
        """
        if not quiet:
            print('DataLoader loading h5 question file: ' +
                  opt['h5_ques_file'])

        # Compare with ``==``: ``is`` tests object identity, which only works
        # by accident for interned literals (and is a SyntaxWarning on 3.8+);
        # an equal string built at runtime would silently select test data.
        if phase == 'train':
            img_file_key = 'h5_img_file_train'
            img_msg = 'DataLoader loading h5 image train file: '
            suffix = 'train'
            # The training answers dataset is '/answers', not '/ans_train'.
            ans_dset = '/answers'
        else:  # valid or test
            img_file_key = 'h5_img_file_test'
            img_msg = 'DataLoader loading h5 image test file: '
            suffix = 'test'
            ans_dset = '/ans_test'

        # Context managers close both HDF5 files; np.array() reads each
        # dataset fully into memory first, so the arrays stay valid after
        # the files are closed.
        with h5py.File(opt['h5_ques_file'], 'r') as h5_file:
            if not quiet:
                print(img_msg + opt[img_file_key])
            with h5py.File(opt[img_file_key], 'r') as img_file:
                self.image = np.array(img_file['/images_' + suffix])
            self.ques = np.array(h5_file['/ques_' + suffix])
            self.ques_len = np.array(h5_file['/ques_len_' + suffix])
            self.img_pos = np.array(h5_file['/img_pos_' + suffix])
            self.ques_id = np.array(h5_file['/ques_id_' + suffix])
            self.ans = np.array(h5_file[ans_dset])
            self.split = np.array(h5_file['/split_' + suffix])

        self.feature_type = opt['feature_type']
        self.phase = phase
        self.transform = transform

        if not quiet:
            print('DataLoader loading json file: %s' % opt['json_file'])
        json_file = utils.read_json(opt['json_file'])
        # Presumably index -> word and index -> answer lookup tables.
        self.ix_to_word = json_file['ix_to_word']
        self.ix_to_ans = json_file['ix_to_ans']

        # Presumably the number of vocabulary entries -- confirm
        # utils.count_key semantics.
        self.vocab_size = utils.count_key(self.ix_to_word)
        # Assumes ques is (num_questions, max_question_length) -- TODO confirm.
        self.seq_length = self.ques.shape[1]
コード例 #2
0
def config_tasks_envs(eparams):
    '''
        Configure task and environment parameters on ``eparams`` in place.

        Reads the JSON file at ``eparams.env_configs``, selects the entry
        for ``eparams.env_name``, and copies each of its key/value pairs onto
        ``eparams``, overwriting any existing attribute with the same name.
        Environment and task parameters follow the PEARL paper; keys such
        as the following (example values shown) will be added:
        n_train_tasks   2
        n_eval_tasks    2
        n_tasks 2
        randomize_tasks true
        low_gear    false
        forward_backward    true
        num_evals   4
        num_steps_per_task  400
        num_steps_per_eval  400
        num_train_steps_per_itr 4000
    '''
    # Presumably the JSON maps env names to dicts of parameters; the result
    # must support .items() for the loop below.
    configs = read_json(eparams.env_configs)[eparams.env_name]
    # vars() returns eparams.__dict__ itself (eparams is presumably an
    # argparse.Namespace), so writing into temp_params sets attributes on
    # eparams directly rather than on a copy.
    temp_params = vars(eparams)
    for k, v in configs.items():
            temp_params[k] = v