Example #1
    def load(self, data, training=True):
        print('Loading Audio')
        self.training = training
        m = Manager()
        audio = set()
        for i in data:
            _, _, path = i.strip().split('@')
            audio.add(path + self.domain + '.npy')  # set membership removes duplicates
        self.feats = m.dict()
        if training:
            self.mean, self.std = m.list(), m.list()
        with Pool(processes=100) as p:
            with tqdm.tqdm(total=len(audio)) as pbar:
                for _ in p.imap_unordered(self._load, audio):
                    pbar.update()
        examples, lengths = {}, {}
        for i in data:
            key, _, path = i.strip().split('@')
            path += self.domain + '.npy'
            if path in self.feats:
                examples[key] = path
                lengths[key] = self.feats[path]
        # import pdb; pdb.set_trace()

        if self.training:
            self.mean = (sum(self.mean) / (len(self.mean) + 1))
            self.std = sum(self.std) / (len(self.std) + 1)
            self.dim = self.mean.shape[0]

        del self.feats

        return examples, lengths
Example #2
def img_rescaler(dir_in, extension_in, threads=1):
    """ 
    Import an image, rescale it to normal UBYTE (0-255, 8 bit) range, and re-save it.
    
    """

    dir_out = os.path.join(dir_in, "rescaled")
    
    total_files = 0
    for path, folder, filename in os.walk(dir_in):
        if dir_out not in path:
            for f in filename:
                if f.endswith(extension_in):
                    total_files += 1
    print("\nYou have {} images to analyze".format(total_files))
    
    for path, folder, filename in os.walk(dir_in):
        if dir_out not in path:   # Don't run in the output directory.

            # Make directory for saving objects
            subpath = path[len(dir_in)+1:]
            if not os.path.exists(os.path.join(dir_out, subpath)):
                os.mkdir(os.path.join(dir_out, subpath))

            # What we'll do:
            global _core_fn  # bad form for Pool.map() compatibility
            def _core_fn(filename):
                if filename.endswith(extension_in):
                    # count progress.

                    path_in = os.path.join(path, filename)
                    subpath_in = os.path.join(subpath, filename) # for printing purposes
                    path_out = os.path.join(dir_out, subpath, filename)

                    if os.path.exists(path_out): #skip
                        print("\nALREADY ANALYZED: {}. Skipping...\n".format(subpath_in))

                    else: #(try to) do it
                        try:
                            img = io.imread(path_in)  # load image
                            img = img_as_ubyte(img / np.max(img))
                            io.imsave(path_out, img)
                        except Exception:
                            print("Couldn't analyze {}".format(subpath_in))
                return
            
            # run it
            sleep(1)  # give everything time to load
            thread_pool = Pool(threads)
            # Work on _core_fn (and show a progress bar); wrap in list() so the
            # lazy imap_unordered iterator is actually consumed and tqdm updates
            list(tqdm.tqdm(thread_pool.imap_unordered(_core_fn,
                                                      filename,
                                                      chunksize=1),
                           total=total_files))
            # finish
            thread_pool.close()
            thread_pool.join()
    return
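The heart of `_core_fn` is the normalize-then-convert step. A minimal standalone sketch of just that step, assuming scikit-image and NumPy are installed ("input.tif" is a hypothetical file):

import numpy as np
from skimage import io, img_as_ubyte

img = io.imread("input.tif")             # load the image with its native dtype
img = img_as_ubyte(img / np.max(img))    # scale to 0-1, then convert to 8-bit (0-255)
io.imsave("rescaled.png", img)           # re-save the UBYTE version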
Example #3
 def fit(use_onecycle=False, model=model):
     print("Epoch\tTrn_loss\tVal_loss\tTrn_acc\t\tVal_acc")
     for j in range(epoch):
         t = tqdm.tqdm(train_loader, leave=False, total=len(train_loader))
         train(t, j, use_onecycle, model)
         t = tqdm.tqdm(train_loader, leave=False, total=len(train_loader))
         test(t, model)
         #pdb.set_trace()
         print(j + 1, trn_losses[-1], val_losses[-1],
               sum(trn_accs) / len(trn_accs),
               sum(val_accs) / len(val_accs))
Example #4
def sn7_convert_geojsons_to_csv(json_dirs, population='proposal'):
    '''
    Convert jsons to csv
    Population is either "ground" or "proposal" 
    '''

    first_file = True  # switch that will be turned off once we process the first file
    for json_dir in tqdm.tqdm(json_dirs):
        json_files = sorted(glob.glob(os.path.join(json_dir, '*.geojson')))
        for json_file in tqdm.tqdm(json_files):
            try:
                df = gpd.read_file(json_file)
            except (fiona.errors.DriverError):
                message = '! Invalid dataframe for %s' % json_file
                print(message)
                continue
                #raise Exception(message)
            if population == 'ground':
                file_name_col = df.image_fname.apply(
                    lambda x: os.path.splitext(x)[0])
            elif population == 'proposal':
                file_name_col = os.path.splitext(
                    os.path.basename(json_file))[0]
            else:
                raise Exception('! Invalid population')

            all_geom = []
            for g in df.geometry.scale(xfact=1 / scale,
                                       yfact=1 / scale,
                                       origin=(0, 0)):
                g0 = g.simplify(0.25)
                g0 = loads(dumps(g0, rounding_precision=2))
                all_geom.append(g0)
            df = gpd.GeoDataFrame({
                'filename': file_name_col,
                'id': df.Id.astype(int),
                'geometry': all_geom,
            })
            if len(df) == 0:
                message = '! Empty dataframe for %s' % json_file
                print(message)
                #raise Exception(message)

            if first_file:
                net_df = df
                first_file = False
            else:
                net_df = net_df.append(df)

    return net_df
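A hypothetical call, assuming the module-level `scale` factor and the geopandas/shapely imports used above are in place; it aggregates every AOI folder into one table (the 'predictions/aoi_*' layout is an assumption):

json_dirs = sorted(glob.glob(os.path.join('predictions', 'aoi_*')))
net_df = sn7_convert_geojsons_to_csv(json_dirs, population='proposal')
net_df.to_csv('proposal_polygons.csv', index=False)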
Example #5
    def get_features(self, image_list, model=None):
        self.image_list = image_list.files

        if Path('features/{}_features.npy'.format(self.dataset_name)).exists():
            logging.info('Feature files are found.')
            self.load_features()
        else:
            logging.info('Feature Extraction, It may take a while...')
            if model is not None:
                self.model = model

            embeddings = np.zeros((len(self.image_list), 512))

            for idx in tqdm.tqdm(
                    range(0, len(self.image_list), self.batch_size)):

                batch_list = self.image_list[idx:idx + self.batch_size]
                batch_data = self.get_batch_img(batch_list)
                batch_data = torch.FloatTensor(batch_data).to(self.device)

                embed = self.model(batch_data)
                embeddings[idx:idx +
                           self.batch_size] = embed.detach().cpu().numpy()

                embed = None
            self.features = embeddings
            self.save_feature()

        labels = image_list.Class_ID
        le = preprocessing.LabelEncoder()
        self.labels = le.fit_transform(labels)
Example #6
def parallel_apply(df_column, function, number_of_workers, loading_bars,
                   **props):
    """
    This function will run pandas.apply in parallel depending on the number of CPUS the user specifies.
    """

    steps = len(df_column) / number_of_workers
    mid_dfs = []
    for x in range(number_of_workers):
        if x == number_of_workers - 1:
            mid_dfs.append(df_column.iloc[int(steps * x):])
        else:
            mid_dfs.append(df_column.iloc[int(steps * x):int(steps * (x + 1))])

    main_df = None
    with cf.ProcessPoolExecutor(max_workers=number_of_workers) as executor:

        results = []
        for mid_df in mid_dfs:
            results.append(executor.submit(__aw__, mid_df, function, **props))

        if loading_bars:
            for f in tqdm.tqdm(cf.as_completed(results),
                               total=number_of_workers):
                if main_df is None:
                    main_df = f.result()
                else:
                    main_df = main_df.append(f.result())
        else:
            for f in cf.as_completed(results):
                if main_df is None:
                    main_df = f.result()
                else:
                    main_df = main_df.append(f.result())
    return main_df
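Assuming `cf` is `concurrent.futures` and the helper `__aw__` simply runs `mid_df.apply(function, **props)` on its chunk, a hypothetical call could look like the sketch below. Note that `Series.append`/`DataFrame.append` was removed in pandas 2.0, so on recent pandas the `main_df.append(...)` calls above would need `pd.concat` instead.

import pandas as pd

def square(x):
    return x * x

df = pd.DataFrame({"values": range(100_000)})
# Split the column across 4 workers, apply `square` to each chunk, reassemble.
squared = parallel_apply(df["values"], square, number_of_workers=4, loading_bars=True)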
Example #7
File: SFFS_run.py  Project: BigDaMa/DFS
def execute_feature_combo1(feature_combo, feature_combo_id=0, params=pARAMS):
    mask = np.zeros(mp_globalsfs.data_per_fold[0][0].shape[1], dtype=bool)
    for fc in feature_combo:
        mask[fc] = True

    mp_globalsfs.mask = mask

    hyperparameter_search_scores = []
    for c in params:
        pipeline = Pipeline([('imputation', SimpleImputer()),
                             ('selection', MaskSelection(mask)),
                             get_model(c)])

        mp_globalsfs.parameter = c

        cv_scores = []
        with Pool(processes=multiprocessing.cpu_count()) as p:
            cv_scores = list(
                tqdm.tqdm(p.imap(run_fold,
                                 range(len(mp_globalsfs.data_per_fold))),
                          total=len(mp_globalsfs.data_per_fold)))

        hyperparameter_search_scores.append(np.mean(cv_scores))

    return (feature_combo_id, np.max(hyperparameter_search_scores),
            params[np.argmax(hyperparameter_search_scores)])
Example #8
    def pred(self, sample_size=1):
        """
        Do prediction on train data.
        """
        predictions = []
        with torch.no_grad():
            self.model.eval()
            total = len(self.train_dialogue_dataloader)
            for i, data in tqdm.tqdm(enumerate(self.train_dialogue_dataloader),
                                     total=total):
                d_data = data

                post = d_data['post'].to(self.device)
                bs = len(post)

                enc_contexts = list()
                enc_contexts.append(self.model.encode(post))

                styles = torch.ones(bs).long().to(self.device)
                top_p = self.config.annealing_topp
                prediction, lens = self.model.top_k_top_p_search(
                    enc_contexts,
                    top_p=top_p,
                    styles=styles,
                    sample_size=sample_size)

                for j in range(bs):
                    post_str = self.ids2string(post[j])
                    for k in range(sample_size):
                        pred_str = self.ids2string(
                            prediction[j * sample_size + k],
                            lens[j * sample_size + k] - 1)
                        predictions.append((post_str, pred_str))

        return predictions
Example #9
def extract_object_feature(flags):
    is_cuda = t.cuda.is_available()  # whether a GPU is available
    cnn = factory(flags, is_cuda, is_cuda)

    games = []
    for split in ["train", "valid", "test"]:
        games.extend(get_games(flags.data_dir, split))

    objectset = ObjectSet(games, flags.image_dir)
    batch_size = flags.batch_size
    objectloader = DataLoader(objectset, batch_size=batch_size, shuffle=False, collate_fn=collate)

    fea_dir = os.path.join(flags.fea_dir, flags.arch)
    if not os.path.exists(fea_dir):
        os.mkdir(fea_dir)
    fea_dir = os.path.join(fea_dir, flags.feature)
    f = h5py.File(os.path.join(fea_dir, "crop.hdf5"), "w")
    shape = tuple([len(objectset)] + list(flags.shape))
    fea_set = f.create_dataset("feature", shape, chunks=True)
    index_set = f.create_dataset("index", (len(objectset),), dtype='i')

    index = 0
    for batch_input in tqdm.tqdm(objectloader):
        # retrieve id images
        fea, ids = batch_input
        fea = cnn(fea).detach().cpu().numpy()
        size = ids.shape[0]
        fea_set[index:index + size] = fea
        index_set[index:index + size] = ids
        index += size

    f.close()
Example #10
def install_zip(zip_file, end_dir, rpath, arg):
    print('unzipping...', file=sys.stderr)

    import tqdm
    from zipfile import ZipFile
    with ZipFile(zip_file, 'r') as zipObj:
        # Get a list of all archived file names from the zip
        files = zipObj.namelist()
        for i in arg:
            if i not in files:
                files.append(i)
        print('installing...')
        Bar = tqdm.tqdm(files)

        # Iterate over the file names
        first = None
        for fileName in Bar:
            # Extract a single file from zip
            if first is None:
                first = fileName
            try:
                zipObj.extract(fileName, end_dir)
            except Exception:
                pass
        Bar.close()
Example #11
    def subsample(self, down_to=1, new_path=None, verbose=True):
        """Pick a given number of sequences from the file pseudo-randomly."""
        # Pick the destination path #
        if new_path is None:
            subsampled = self.__class__(new_temp_path())
        elif isinstance(new_path, FASTA):
            subsampled = new_path
        else:
            subsampled = self.__class__(new_path)
        # Check size #
        if down_to > len(self):
            message = "Can't subsample %s down to %i. Only down to %i."
            print(Color.ylw + message % (self, down_to, len(self)) + Color.end)
            self.copy(new_path)
            return
        # Select verbosity #
        import tqdm
        if verbose: wrapper = lambda x: tqdm.tqdm(x, total=self.count)
        else: wrapper = lambda x: x

        # Generator #
        def iterator():
            for read in wrapper(isubsample(self, down_to)):
                yield read

        # Do it #
        subsampled.write(iterator())
        # Did it work #
        assert len(subsampled) == down_to
        # Return #
        return subsampled
Example #12
File: fid.py  Project: vdt/DeepPrivacy
def preprocess_images(images, use_multiprocessing):
    """Resizes and shifts the dynamic range of image to 0-1
    Args:
        images: np.array, shape: (N, H, W, 3), dtype: float32 between 0-1 or np.uint8
        use_multiprocessing: If multiprocessing should be used to pre-process the images
    Return:
        final_images: torch.tensor, shape: (N, 3, 299, 299), dtype: torch.float32 between 0-1
    """
    if use_multiprocessing:
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            jobs = []
            for im in tqdm.tqdm(images, desc="Starting FID jobs"):
                job = pool.apply_async(preprocess_image, (im, ))
                jobs.append(job)
            final_images = torch.zeros(images.shape[0], 3, 299, 299)
            for idx, job in enumerate(tqdm.tqdm(jobs, desc="finishing jobs")):
                im = job.get()
                final_images[idx] = im
    else:
        final_images = torch.zeros((len(images), 3, 299, 299),
                                   dtype=torch.float32)
        for idx in range(len(images)):
            im = preprocess_image(images[idx])
            final_images[idx] = im
    assert final_images.shape == (images.shape[0], 3, 299, 299)
    assert final_images.max() <= 1.0
    assert final_images.min() >= 0.0
    assert final_images.dtype == torch.float32
    return final_images
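`preprocess_image` is defined elsewhere in fid.py; a hypothetical sketch consistent with the assertions above (299x299, CHW, float32 in 0-1), assuming scikit-image is available:

import numpy as np
import torch
from skimage.transform import resize

def preprocess_image_sketch(im):
    if im.dtype == np.uint8:
        im = im.astype(np.float32) / 255.0             # shift dynamic range to 0-1
    im = resize(im, (299, 299), anti_aliasing=True)    # resize, values stay in 0-1
    return torch.from_numpy(im.transpose(2, 0, 1)).float()  # HWC -> CHW tensor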
Example #13
    def train_one_epoch(self):
        """
        Return:
            total_loss: the total loss during training
            accuracy: the mAP
        """
        pred_bboxes, pred_labels, pred_scores = list(), list(), list()
        gt_bboxes, gt_labels, gt_difficults = list(), list(), list()
        self.trainer.reset_meters()
        for ii, (img, sizes, bbox_, label_, scale, gt_difficults_) in \
                tqdm.tqdm(enumerate(self.dataloader)):
            scale = at.scalar(scale)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            self.trainer.train_step(img, bbox, label, scale)
            if (ii + 1) % self.opt.plot_every == 0:
                sizes = [sizes[0][0].item(), sizes[1][0].item()]
                pred_bboxes_, pred_labels_, pred_scores_ = \
                    self.faster_rcnn.predict(img, [sizes])
                pred_bboxes += pred_bboxes_
                pred_labels += pred_labels_
                pred_scores += pred_scores_
                gt_bboxes += list(bbox_.numpy())
                gt_labels += list(label_.numpy())
                gt_difficults += list(gt_difficults_.numpy())

        return self.trainer.get_meter_data()['total_loss']
Example #14
    def test(self):
        self.netG.eval()
        self.netD.eval()
        for inputs, label in tqdm.tqdm(self.test_dataloader, desc='test'):
            # convert tensor to variables
            real_img = inputs.cuda().float()
            real_cond = ops.group_to_one_hot(
                label, age_group=self.age_group).cuda().float()
            real_ordinal_cond = ops.group_to_binary(
                label, age_group=self.age_group).cuda().float()
            desired_cond = ops.desired_group_to_one_hot(
                label, self.age_group).cuda().float()

            # test D
            loss_d_prob, loss_d_cond, fake_imgs_masked = self.forward_D(
                real_img, real_cond, real_ordinal_cond, desired_cond)
            loss_d_gp = self.gradient_penalty_D(real_img, fake_imgs_masked)
            # save losses
            self.loss_d_prob.append(loss_d_prob.data.cpu().numpy())
            self.loss_d_cond.append(loss_d_cond.data.cpu().numpy())
            self.loss_d_gp.append(loss_d_gp.data.cpu().numpy())

            # test G
            loss_g_masked_fake, loss_g_masked_cond, loss_g_mask, loss_g_mask_smooth = self.forward_G(
                real_img, real_cond, desired_cond, desired_ordinal_cond)
            loss_G = loss_g_masked_fake + loss_g_masked_cond + loss_g_mask + loss_g_mask_smooth
            # save losses
            self.loss_g_masked_fake.append(
                loss_g_masked_fake.data.cpu().numpy())
            self.loss_g_masked_cond.append(
                loss_g_masked_cond.data.cpu().numpy())
            self.loss_g_mask.append(loss_g_mask.data.cpu().numpy())
            self.loss_g_mask_smooth.append(
                loss_g_mask_smooth.data.cpu().numpy())
Example #15
def copytree(src, dst, symlinks=False, ignore=None):
    import tqdm
    items = tqdm.tqdm(os.listdir(src))
    fi = False
    for item in items:
        if fi is False:
            fi = item  # remember the first copied item so it can be returned
        s = os.path.join(src, item)
        d = os.path.join(dst, item)
        s = s.replace('/', '\\')
        d = d.replace('/', '\\')
        if os.path.isdir(s):
            try:
                shutil.rmtree(d)
            except:
                pass
            try:
                shutil.copytree(s, d, symlinks, ignore)
            except Exception as ex:
                print('ERROR:', ex.__class__, str(ex))
        else:
            try:
                shutil.copy2(s, d)
            except Exception as ex:
                print('ERROR:', ex.__class__, str(ex))
    items.close()
    return fi
Example #16
def greedy_find_clique_number2(graph, progress=False, nodes=None, \
                               ordering_func=None, node_deletion=True):
    graph_order = graph.order()
    if graph_order == 0:
        return 0

    if nodes is None:
        nodes = list(graph.nodes) if ordering_func is None else ordering_func(graph)
        
    output=0
    node_index_iterator = tqdm.tqdm(range(graph_order)) if progress else\
        range(graph_order)
    for i in node_index_iterator:
        node = nodes[i]
        neighbours = set(graph.neighbors(node))
        if node_deletion:
            output = max(output,
                1 + greedy_find_clique_number2(graph.subgraph(neighbours).copy(),
                progress=False, node_deletion=node_deletion))
            graph.remove_node(node)
        else:
            output = max(output,
                1 + greedy_find_clique_number2(graph.subgraph(neighbours),
                progress=False, node_deletion=node_deletion))
        
    return output
Example #17
def greedy_find_clique(graph, progress=False, nodes=None, ordering_func=None,
                       node_deletion=True):
    """Finds the graph's clique (largest complete subgraph)
    """
    graph_order = graph.order()
    if graph_order == 0:
        return []

    if nodes is None:
        nodes = list(graph.nodes) if ordering_func is None else ordering_func(graph)
        
    output=[]
    node_index_iterator = tqdm.tqdm(range(graph_order)) if progress else\
        range(graph_order)
    for i in node_index_iterator:
        node = nodes[i]
        neighbours = set(graph.neighbors(node))
        if node_deletion:
            subgraph_output = [node] + \
                greedy_find_clique(graph.subgraph(neighbours).copy(),
                progress=False, nodes=[u for u in nodes[i+1:] if u in neighbours],
                node_deletion=node_deletion)
            
            graph.remove_node(node)
        else:
            subgraph_output = [node] + \
            greedy_find_clique(graph.subgraph(neighbours),
                progress=False, nodes=[u for u in nodes if u in neighbours],
                node_deletion=node_deletion)
            
        output = output if len(output) >= len(subgraph_output) else subgraph_output
        
    return output
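A hypothetical usage with networkx (the graph API used above matches networkx's Graph). Pass a copy when `node_deletion=True`, since the function removes nodes as it goes:

import networkx as nx

g = nx.erdos_renyi_graph(30, 0.5, seed=0)
clique = greedy_find_clique(g.copy(), progress=True)
print(len(clique), clique)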
Example #18
 def tiftopng(self, path):
     files = os.listdir(path)
     for x in tqdm.tqdm(files):
         one = path + x
         a = skimage.io.imread(one)
         new = path + x.replace('tif', 'png')
         skimage.io.imsave(new, a)
Example #19
def gen_vocab_lang8(dirs, vocab_out='', current_vocab=''):
    thu1 = thulac.thulac(T2S=True, seg_only=True)  # default mode: traditional-to-simplified, segmentation only
    dic = {}
    if current_vocab:
        with open(current_vocab) as fv:
            for line in fv:
                dic[line.split()[0]] = int(line.split()[1])
    cnt = 0
    for dir in dirs:
        with open(dir + '.src.txt') as fs, open(dir + '.trg.txt') as ft:
            doc_num = fs.readlines()
            for line_src in tqdm.tqdm(doc_num):
                # cnt += 1
                # if cnt > 10:
                #     break
                line_trg = ft.readline()
                texts = [k[0] for k in thu1.cut(line_src, text=False)]
                texts.extend([k[0] for k in thu1.cut(line_trg, text=False)])
                for t in set(texts):
                    if t in dic:
                        dic[t] += 1
                    else:
                        dic[t] = 1
    with open(vocab_out, 'w') as fvw:
        for k in dic:
            if k.strip():
                fvw.write(k + ' ' + str(dic[k]) + '\n')
Example #20
def generate_pianoroll(
    args,
    input_seq_ln,
    model,
    generated,
    samples_dir,
    min_note,
    max_note,
):
    temperature = float(args.temp)

    nr_samples = int(args.nr)

    for i in tqdm.tqdm(list(range(nr_samples))):

        for timestep in range(input_seq_ln, len(generated)):
            start_index = timestep - (input_seq_ln)
            sequence_for_prediction = generated[start_index:timestep]
            #         next_step, att = sample(model, sequence_for_prediction, temperature, withatt=True)
            next_step, _ = sample(model,
                                  sequence_for_prediction,
                                  temperature,
                                  withatt=args.att)
            #         print(att.argsort()[-10:][::-1])
            generated[timestep] = next_step

        generated_noseed = generated[input_seq_ln:]

        new_path = os.path.join(samples_dir,
                                "temp_%s_%s.mid" % (temperature, i))
        save_trim_pianoroll_seq(generated_noseed, min_note, max_note, new_path)
        plot_midifile(new_path, samples_dir,
                      "temp_%s_%s.png" % (temperature, i))
Example #21
    def __init__(self,
                 tokenizer: Type[transformers.PreTrainedTokenizer],
                 file_path: str,
                 block_size=512):

        file_path: pathlib.Path = pathlib.Path(file_path)
        assert file_path.is_file(), f'{file_path} is not file'

        block_size = block_size - (tokenizer.max_len -
                                   tokenizer.max_len_single_sentence)

        logger.info("Creating tokens from dataset file at %s", file_path)

        self.examples = []
        with open(str(file_path), encoding="utf-8") as f:
            text = f.read()

            tokenized_text = tokenizer.convert_tokens_to_ids(
                tokenizer.tokenize(text))

            # Truncate in block of block_size
            for i in tqdm.tqdm(
                    range(0,
                          len(tokenized_text) - block_size + 1, block_size)):
                self.examples.append(
                    tokenizer.build_inputs_with_special_tokens(
                        tokenized_text[i:i + block_size]))
Example #22
    def eval(self, dataloader, faster_rcnn, test_num=10000):
        pred_bboxes, pred_labels, pred_scores = list(), list(), list()
        gt_bboxes, gt_labels, gt_difficults = list(), list(), list()
        total_losses = list()
        for ii, (imgs, sizes, gt_bboxes_, gt_labels_, scale, gt_difficults_) \
                in tqdm.tqdm(enumerate(dataloader)):
            img = imgs.cuda().float()
            bbox = gt_bboxes_.cuda()
            label = gt_labels_.cuda()
            sizes = [sizes[0][0].item(), sizes[1][0].item()]
            pred_bboxes_, pred_labels_, pred_scores_ = \
                faster_rcnn.predict(imgs, [sizes])
            losses = self.trainer.forward(img, bbox, label, float(scale))
            total_losses.append(losses.total_loss.item())
            gt_bboxes += list(gt_bboxes_.numpy())
            gt_labels += list(gt_labels_.numpy())
            gt_difficults += list(gt_difficults_.numpy())
            pred_bboxes += pred_bboxes_
            pred_labels += pred_labels_
            pred_scores += pred_scores_
            if ii == test_num: break

        result = eval_detection_voc(
            pred_bboxes, pred_labels, pred_scores,
            gt_bboxes, gt_labels, gt_difficults,
            use_07_metric=True)
        total_loss = sum(total_losses) / len(total_losses)
        return total_loss, result
Example #23
    def eval(self, dataloader, yolo, test_num=10000):
        labels = []
        sample_metrics = []  # List of tuples (TP, confs, pred)
        total_losses = list()
        Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
        for batch_i, (_, imgs, targets) in enumerate(tqdm.tqdm(dataloader, desc="Detecting objects")):
            # Extract labels
            labels += targets[:, 1].tolist()
            # Rescale target
            targets = Variable(targets.to(self.device), requires_grad=False)

            imgs = Variable(imgs.type(Tensor), requires_grad=False)
            with torch.no_grad():
                loss, outputs = yolo(imgs, targets)
                outputs = non_max_suppression(outputs, conf_thres=0.5, nms_thres=0.5)
                total_losses.append(loss.item())
            targets = targets.to("cpu")
            targets[:, 2:] = xywh2xyxy(targets[:, 2:])
            targets[:, 2:] *= int(self.model_config['img_size'])
            sample_metrics += get_batch_statistics(outputs, targets, iou_threshold=0.5)
        if len(sample_metrics) > 0:
            true_positives, pred_scores, pred_labels = [np.concatenate(x, 0) for x in list(zip(*sample_metrics))]
            precision, recall, AP, f1, ap_class = ap_per_class(true_positives, pred_scores, pred_labels, labels)
        else:
            return 0.0, 0.0, 0.0
        total_loss = sum(total_losses) / len(total_losses)
        return total_loss, AP.mean(), recall.mean()
Example #24
 def build_images_folder(data_root, X, labels, dest_folder):
     images = data_root/"images"
     for i, (x, y) in tqdm.tqdm(enumerate(zip(X, labels))):
         folder = images/dest_folder/f"{y}"
         ensure_folder(folder)
         x = x.numpy()
         image = np.stack([x for ch in range(3)], axis=-1)
         PIL.Image.fromarray(image).save(folder/f"img{y}_{i:06d}.png")
Example #25
def window_powers_max_any(width, height, serial):

    previous = dict()

    powers = dict(((i, j, w), window_powers(i, j, serial, w, previous))
                  for w in tqdm.tqdm(range(1, 300)) for i in range(1, 300)
                  for j in range(1, 300))

    return max(previous, key=previous.get)
Example #26
    def _eval_train(self, epoch, after_step_funcs=[]):
        self.model.train()

        loss, acc, step_count = 0, 0, 0
        # self.logger.info('epoch %d, rank %d, before loop' % (epoch, self.rank))
        total = len(self.train_dataloader)
        for i, data in tqdm.tqdm(enumerate(self.train_dataloader),
                                 total=total):
            d_data = data

            text, style = d_data['text'].to(self.device), d_data['style'].to(
                self.device)
            text_len = d_data['text_len'].to(self.device)

            outputs = self.model(text[:, 1:], text_len - 1)
            batch_loss = self.criterion(outputs, style)
            batch_acc = (torch.argmax(outputs, dim=1) == style).float().mean()

            full_loss = batch_loss / self.config.batch_split
            full_loss.backward()

            loss += batch_loss.item()
            acc += batch_acc.item()
            step_count += 1

            # self.logger.info('epoch %d, batch %d' % (epoch, i))
            if (i + 1) % self.config.batch_split == 0:
                if self.config.clip_grad is not None:
                    for group in self.optimizer.param_groups:
                        nn.utils.clip_grad_norm_(group['params'],
                                                 self.config.clip_grad)
                # update weights
                self.optimizer.step()
                self.optimizer.zero_grad()

                if self.optimizer.curr_step() % self.config.save_interval == 0:
                    for func in after_step_funcs:
                        func(self.optimizer.curr_step(), self.device)

                # log from rank 0 (or single-process runs) at every step
                if self.rank == -1 or self.rank == 0:
                    loss /= step_count
                    acc /= step_count

                    self.train_writer.add_scalar('loss/loss', loss,
                                                 self.optimizer.curr_step())
                    self.train_writer.add_scalar('acc/acc', acc,
                                                 self.optimizer.curr_step())
                    self.train_writer.add_scalar('lr/lr',
                                                 self.optimizer.rate(),
                                                 self.optimizer.curr_step())
                    loss, acc, step_count = 0, 0, 0

                # only valid on dev and sample on dev data at every eval_steps
                if self.optimizer.curr_step() % self.config.eval_steps == 0:
                    self._eval_test(epoch, self.optimizer.curr_step(),
                                    self.optimizer.rate())
Example #27
def multiprocess_gdf(fxn, gdf, *args, num_cores=None, **kwargs):
    from joblib import Parallel, delayed
    num_cores = num_cores if num_cores else multiprocessing.cpu_count() - 2
    split_dfs = [gdf.iloc[[i]] for i in range(len(gdf))]
    # Run fxn in counts
    results = Parallel(n_jobs=num_cores)(delayed(fxn)(i, *args, **kwargs) for i in tqdm.tqdm(split_dfs))
    # Combine individual gdfs back into one
    output = pd.concat(results)

    return output
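A hypothetical usage sketch, assuming geopandas and joblib are installed and the module-level pandas/multiprocessing imports exist: `buffer_row` is an assumed user function that takes a single-row GeoDataFrame (plus any extra positional/keyword arguments) and returns a frame; `multiprocess_gdf` fans the rows out across workers and concatenates the results.

import geopandas as gpd

def buffer_row(row_gdf, distance):
    out = row_gdf.copy()
    out["geometry"] = out.geometry.buffer(distance)  # buffer the single row's geometry
    return out

gdf = gpd.read_file("parcels.shp")  # hypothetical input file
buffered = multiprocess_gdf(buffer_row, gdf, 10.0, num_cores=4)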
Example #28
def download_kgml_files(kegg_pathway_ids, path=KEGG_FILES):
    """Download KEGG KGML files by querying the KEGG API.

    :param list kegg_pathway_ids: list of kegg ids
    """
    for kegg_id in tqdm.tqdm(kegg_pathway_ids, desc='Downloading KEGG files'):
        request = requests.get(KEGG_KGML_URL.format(kegg_id))
        with open(os.path.join(path, '{}.xml'.format(kegg_id)), 'w+') as file:
            file.write(request.text)
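A hypothetical call, assuming `KEGG_KGML_URL` is a format string such as 'http://rest.kegg.jp/get/{}/kgml' and `KEGG_FILES` points at an existing directory:

download_kgml_files(['hsa04110', 'hsa04115'])  # cell cycle and p53 signaling pathways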
Example #29
def validation(model_dibert, criterion_cls, criterion_mlm, criterion_pp, valid_iter, epoch):
    model_dibert.eval()
    losses_cls = []
    losses_mlm = []
    losses_pp = []

    l_cls = []
    p_cls = []
    l_pp = []
    p_pp = []
    l_mlm = []
    p_mlm = []
    softmax = nn.Softmax(dim=-1)
    print('\nValid_Epoch:', epoch)

    with torch.no_grad():
        for batch in tqdm.tqdm(valid_iter):
            input_ids = batch['input_ids'].cuda()
            attention_mask = batch['attention_mask'].cuda()
            token_type_ids = batch['token_type_ids'].cuda()
            #truelabel_pp = batch['indexes'].cuda()
            truelabel_pp = batch['parent_ids'].cuda()
            truelabel_mlm = batch['mask_ids'].cuda()
            truelabel_cls = batch['cls_label'].cuda()

            logits_cls, logits_mlm, logits_pp = model_dibert(input_ids, attention_mask, token_type_ids)
            ## if out dim is (bs x seqlen x numclass) -> (total_words_batch x numclass)
            ## if true label is (bs x seqlen) -> (total_words_batch)

            loss_mlm = criterion_mlm(logits_mlm.view(-1, model.Config.vocab_size), truelabel_mlm.view(-1))
            loss_cls = criterion_cls(logits_cls.view(-1, 2), truelabel_cls.view(-1))

            if Config.is_dibert:
                #loss_pp = criterion_pp(logits_pp.view(-1, model.Config.max_len), truelabel_pp.view(-1))
                loss_pp = criterion_pp(logits_pp.view(-1, model.Config.vocab_size), truelabel_pp.view(-1))
                losses_pp.append(loss_pp.item())
                pred_pp = softmax(logits_pp).argmax(2)
                nptrue_pp, nppreds_pp = utils.prune_preds(truelabel_pp.view(-1), pred_pp.view(-1))
                l_pp.extend(nptrue_pp)
                p_pp.extend(nppreds_pp)

            losses_cls.append(loss_cls.item())
            losses_mlm.append(loss_mlm.item())

            # for now we are only interested in accuracy and f1 of the classification task
            l_cls.extend(truelabel_cls.cpu().detach().numpy())
            preds_cls = softmax(logits_cls).argmax(1)
            p_cls.extend(preds_cls.view(-1).cpu().detach().numpy())

            pred_mlm = softmax(logits_mlm).argmax(2)
            nptrue_mlm, nppreds_mlm = utils.prune_preds(truelabel_mlm.view(-1), pred_mlm.view(-1))
            l_mlm.extend(nptrue_mlm)
            p_mlm.extend(nppreds_mlm)

    return losses_cls, losses_mlm, losses_pp, l_cls, p_cls, l_mlm, p_mlm, l_pp, p_pp
Example #30
def ap_per_class(tp, conf, pred_cls, target_cls):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:    True positives (list).
        conf:  Objectness value from 0-1 (list).
        pred_cls: Predicted object classes (list).
        target_cls: True object classes (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes = np.unique(target_cls)

    # Create Precision-Recall curve and compute AP for each class
    ap, p, r = [], [], []
    for c in tqdm.tqdm(unique_classes, desc="Computing AP"):
        i = pred_cls == c
        n_gt = (target_cls == c).sum()  # Number of ground truth objects
        n_p = i.sum()  # Number of predicted objects

        if n_p == 0 and n_gt == 0:
            continue
        elif n_p == 0 or n_gt == 0:
            ap.append(0)
            r.append(0)
            p.append(0)
        else:
            # Accumulate FPs and TPs
            fpc = (1 - tp[i]).cumsum()
            tpc = (tp[i]).cumsum()

            # Recall
            recall_curve = tpc / (n_gt + 1e-16)
            r.append(recall_curve[-1])

            # Precision
            precision_curve = tpc / (tpc + fpc)
            p.append(precision_curve[-1])

            # AP from recall-precision curve
            ap.append(compute_ap(recall_curve, precision_curve))

    # Compute F1 score (harmonic mean of precision and recall)
    p, r, ap = np.array(p), np.array(r), np.array(ap)
    f1 = 2 * p * r / (p + r + 1e-16)

    return p, r, ap, f1, unique_classes.astype("int32")
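A toy call to illustrate the expected shapes (assumes the module's `compute_ap` helper is defined alongside): three detections of class 0 sorted by confidence, two of them true positives, against two ground-truth objects.

import numpy as np

tp = np.array([1.0, 0.0, 1.0])      # per-detection true-positive flags
conf = np.array([0.9, 0.8, 0.7])    # objectness scores
pred_cls = np.array([0, 0, 0])      # predicted class ids
target_cls = np.array([0, 0])       # ground-truth class ids
p, r, ap, f1, classes = ap_per_class(tp, conf, pred_cls, target_cls)
print(p, r, ap, f1, classes)        # per-class precision, recall, AP, F1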
Example #31
def select(save_path):

    train_word = np.load('./data/train_word.npy')
    train_pos1 = np.load('./data/train_pos1.npy')
    train_pos2 = np.load('./data/train_pos2.npy')
    train_entitypair = np.load('./data/train_entitypair.npy')
    y_train = np.load('data/train_y.npy')

    all_sentence_ebd = np.load('./data/all_sentence_ebd.npy')
    all_reward = np.load('./data/all_reward.npy')
    entity_ebd = np.load('origin_data/entity_ebd.npy')

    selected_word = []
    selected_pos1 = []
    selected_pos2 = []
    selected_y = []

    g_rl = tf.Graph()
    sess2 = tf.Session(graph=g_rl)
    env = environment(230)


    with g_rl.as_default():
        with sess2.as_default():

            myAgent = agent(0.02, entity_ebd, 460)
            init = tf.global_variables_initializer()
            sess2.run(init)
            saver = tf.train.Saver()
            saver.restore(sess2, save_path=save_path)
            g_rl.finalize()


            for epoch in range(1):

                total_reward = []
                num_chosen = 0

                all_list = list(range(len(all_sentence_ebd)))

                for batch in tqdm.tqdm(all_list):

                    batch_en1 = train_entitypair[batch][0]
                    batch_en2 = train_entitypair[batch][1]
                    batch_sentence_ebd = all_sentence_ebd[batch]
                    batch_reward = all_reward[batch]
                    batch_len = len(batch_sentence_ebd)

                    batch_word = train_word[batch]
                    batch_pos1 = train_pos1[batch]
                    batch_pos2 = train_pos2[batch]
                    batch_y = [y_train[batch] for x in range(len(batch_word))]

                    # reset environment
                    state = env.reset(batch_en1, batch_en2, batch_sentence_ebd, batch_reward)
                    old_prob = []

                    # get action
                    # start = time.time()
                    for i in range(batch_len):
                        state_in = np.append(state[0], state[1])
                        feed_dict = {}
                        feed_dict[myAgent.entity1] = [state[2]]
                        feed_dict[myAgent.entity2] = [state[3]]
                        feed_dict[myAgent.state_in] = [state_in]
                        prob = sess2.run(myAgent.prob, feed_dict=feed_dict)
                        old_prob.append(prob[0])
                        action = decide_action(prob)
                        # produce data for training cnn model
                        state = env.step(action)
                        if action == 1:
                            num_chosen+=1
                    #print (old_prob)
                    chosen_reward = [batch_reward[x] for x in env.list_selected]
                    total_reward += chosen_reward

                    selected_word += [batch_word[x] for x in env.list_selected]
                    selected_pos1 += [batch_pos1[x] for x in env.list_selected]
                    selected_pos2 += [batch_pos2[x] for x in env.list_selected]
                    selected_y += [batch_y[x] for x in env.list_selected]
                print(num_chosen)
    selected_word = np.array(selected_word)
    selected_pos1 = np.array(selected_pos1)
    selected_pos2 = np.array(selected_pos2)
    selected_y = np.array(selected_y)

    np.save('cnndata/selected_word.npy',selected_word)
    np.save('cnndata/selected_pos1.npy', selected_pos1)
    np.save('cnndata/selected_pos2.npy', selected_pos2)
    np.save('cnndata/selected_y.npy', selected_y)
Example #32
def train():

    train_word = np.load('./data/train_word.npy')
    train_pos1 = np.load('./data/train_pos1.npy')
    train_pos2 = np.load('./data/train_pos2.npy')
    train_entitypair = np.load('./data/train_entitypair.npy')
    y_train = np.load('data/train_y.npy')

    all_sentence_ebd = np.load('./data/all_sentence_ebd.npy')
    all_reward= np.load('./data/all_reward.npy')
    average_reward = np.load('data/average_reward.npy')
    entity_ebd = np.load('origin_data/entity_ebd.npy')


    g_cnn = tf.Graph()
    g_rl = tf.Graph()
    sess1 = tf.Session(graph=g_cnn)
    sess2 = tf.Session(graph=g_rl)


    with g_cnn.as_default():
        with sess1.as_default():
            interact = cnnmodel.interaction(sess1,save_path='model/origin_cnn_model.ckpt')
            tvars_best_cnn = interact.tvars()
            for index, var in enumerate(tvars_best_cnn):
                tvars_best_cnn[index] = var * 0

    g_cnn.finalize()
    env = environment(230)
    best_score = -100000



    with g_rl.as_default():
        with sess2.as_default():


            myAgent = agent(0.02,entity_ebd,460)
            updaterate = 0.01
            num_epoch = 25
            sampletimes = 3
            best_reward = -100000

            init = tf.global_variables_initializer()
            sess2.run(init)
            saver = tf.train.Saver()
            saver.restore(sess2, save_path='rlmodel/origin_rl_model.ckpt')

            tvars_best_rl = sess2.run(myAgent.tvars)
            for index, var in enumerate(tvars_best_rl):
                tvars_best_rl[index] = var * 0

            tvars_old = sess2.run(myAgent.tvars)


            gradBuffer = sess2.run(myAgent.tvars)
            for index, grad in enumerate(gradBuffer):
                gradBuffer[index] = grad * 0

            g_rl.finalize()


            for epoch in range(num_epoch):

                update_word = []
                update_pos1 = []
                update_pos2 = []
                update_y    = []

                all_list = list(range(len(all_sentence_ebd)))
                total_reward = []

                # shuffle bags
                random.shuffle(all_list)

                print ('update the rlmodel')
                for batch in tqdm.tqdm(all_list):
                #for batch in tqdm.tqdm(range(10000)):

                    batch_en1 = train_entitypair[batch][0]
                    batch_en2 = train_entitypair[batch][1]
                    batch_sentence_ebd = all_sentence_ebd[batch]
                    batch_reward = all_reward[batch]
                    batch_len = len(batch_sentence_ebd)

                    batch_word = train_word[batch]
                    batch_pos1 = train_pos1[batch]
                    batch_pos2 = train_pos2[batch]
                    batch_y = [y_train[batch] for x in range(len(batch_word))]


                    list_list_state = []
                    list_list_action = []
                    list_list_reward = []
                    avg_reward  = 0


                    # add sample times
                    for j in range(sampletimes):
                        #reset environment
                        state = env.reset( batch_en1, batch_en2,batch_sentence_ebd,batch_reward)
                        list_action = []
                        list_state = []
                        old_prob = []


                        #get action
                        #start = time.time()
                        for i in range(batch_len):

                            state_in = np.append(state[0],state[1])
                            feed_dict = {}
                            feed_dict[myAgent.entity1] = [state[2]]
                            feed_dict[myAgent.entity2] = [state[3]]
                            feed_dict[myAgent.state_in] = [state_in]
                            prob = sess2.run(myAgent.prob,feed_dict = feed_dict)
                            old_prob.append(prob[0])
                            action = get_action(prob)
                            '''
                            if action == None:
                                print (123)
                            action = 1
                            '''
                            #add produce data for training cnn model
                            list_action.append(action)
                            list_state.append(state)
                            state = env.step(action)
                        #end = time.time()
                        #print ('get action:',end - start)

                        if env.num_selected == 0:
                            tmp_reward = average_reward
                        else:
                            tmp_reward = env.reward()

                        avg_reward += tmp_reward
                        list_list_state.append(list_state)
                        list_list_action.append(list_action)
                        list_list_reward.append(tmp_reward)


                    avg_reward = avg_reward / sampletimes
                    # add sample times
                    for j in range(sampletimes):

                        list_state = list_list_state[j]
                        list_action = list_list_action[j]
                        reward = list_list_reward[j]

                        # compute gradient
                        # start = time.time()
                        list_reward = [reward - avg_reward for x in range(batch_len)]
                        list_state_in = [np.append(state[0],state[1]) for state in list_state]
                        list_entity1 = [state[2] for state in list_state]
                        list_entity2 = [state[3] for state in list_state ]

                        feed_dict = {}
                        feed_dict[myAgent.state_in] = list_state_in
                        feed_dict[myAgent.entity1] = list_entity1
                        feed_dict[myAgent.entity2] = list_entity2
                        feed_dict[myAgent.reward_holder] = list_reward
                        feed_dict[myAgent.action_holder] = list_action

                        grads = sess2.run(myAgent.gradients, feed_dict=feed_dict)
                        for index, grad in enumerate(grads):
                            gradBuffer[index] += grad
                        #end = time.time()
                        #print('get loss and update:', end - start)

                    #decide action and compute reward
                    state = env.reset(batch_en1, batch_en2, batch_sentence_ebd, batch_reward)
                    old_prob = []
                    for i in range(batch_len):
                        state_in = np.append(state[0], state[1])
                        feed_dict = {}
                        feed_dict[myAgent.entity1] = [state[2]]
                        feed_dict[myAgent.entity2] = [state[3]]
                        feed_dict[myAgent.state_in] = [state_in]
                        prob = sess2.run(myAgent.prob, feed_dict=feed_dict)
                        old_prob.append(prob[0])
                        action = decide_action(prob)
                        state = env.step(action)
                    chosen_reward = [batch_reward[x] for x in env.list_selected]
                    total_reward += chosen_reward

                    update_word += [batch_word[x] for x in env.list_selected]
                    update_pos1 += [batch_pos1[x] for x in env.list_selected]
                    update_pos2 += [batch_pos2[x] for x in env.list_selected]
                    update_y += [batch_y[x] for x in env.list_selected]
                print ('finished')

                #print (len(update_word),len(update_pos1),len(update_pos2),len(update_y),updaterate)

                #train and update cnnmodel
                print('update the cnnmodel')
                interact.update_cnn(update_word,update_pos1,update_pos2,update_y,updaterate)
                print('finished')

                #produce new embedding
                print ('produce new embedding')
                average_reward, all_sentence_ebd, all_reward = interact.produce_new_embedding()
                average_score = average_reward
                print ('finished')

                #update the rlmodel
                #apply gradient
                feed_dict = dict(zip(myAgent.gradient_holders, gradBuffer))
                sess2.run(myAgent.update_batch, feed_dict=feed_dict)
                for index, grad in enumerate(gradBuffer):
                    gradBuffer[index] = grad * 0

                #get tvars_new
                tvars_new = sess2.run(myAgent.tvars)

                # update old variables of the target network
                tvars_update = sess2.run(myAgent.tvars)
                for index, var in enumerate(tvars_update):
                    tvars_update[index] = updaterate * tvars_new[index] + (1-updaterate) * tvars_old[index]

                feed_dict = dict(zip(myAgent.tvars_holders, tvars_update))
                sess2.run(myAgent.update_tvar_holder, feed_dict)
                tvars_old = sess2.run(myAgent.tvars)
                #break


                #find the best parameters
                chosen_size = len(total_reward)
                total_reward = np.mean(np.array(total_reward))


                if (total_reward > best_reward):
                    best_reward = total_reward
                    tvars_best_rl = tvars_old

                if  average_score > best_score:
                    best_score = average_score
                    #tvars_best_rl = tvars_old
                print ('epoch:',epoch)
                print ('chosen sentence size:',chosen_size)
                print ('total_reward:',total_reward)
                print ('best_reward',best_reward)
                print ('average score',average_score)
                print ('best score',best_score)


            #set parameters = best_tvars
            feed_dict = dict(zip(myAgent.tvars_holders, tvars_best_rl))
            sess2.run(myAgent.update_tvar_holder, feed_dict)
            #save model
            saver.save(sess2, save_path='rlmodel/union_rl_model.ckpt')

    #interact.update_tvars(tvars_best_cnn)
    interact.save_cnnmodel(save_path='model/union_cnn_model.ckpt')
Example #33
def train():


    train_entitypair = np.load('./data/train_entitypair.npy')
    all_sentence_ebd = np.load('./data/all_sentence_ebd.npy')
    all_reward= np.load('./data/all_reward.npy')
    average_reward = np.load('data/average_reward.npy')
    entity_ebd = np.load('origin_data/entity_ebd.npy')

    g_rl = tf.Graph()
    sess2 = tf.Session(graph=g_rl)
    env = environment(230)


    with g_rl.as_default():
        with sess2.as_default():

            myAgent = agent(0.03,entity_ebd,460)
            updaterate = 1
            num_epoch = 25
            sampletimes = 3
            best_reward = -100000

            init = tf.global_variables_initializer()
            sess2.run(init)
            saver = tf.train.Saver()
            #saver.restore(sess2, save_path='rlmodel/rl.ckpt')

            tvars_best = sess2.run(myAgent.tvars)
            for index, var in enumerate(tvars_best):
                tvars_best[index] = var * 0

            tvars_old = sess2.run(myAgent.tvars)


            gradBuffer = sess2.run(myAgent.tvars)
            for index, grad in enumerate(gradBuffer):
                gradBuffer[index] = grad * 0

            g_rl.finalize()

            for epoch in range(num_epoch):

                all_list = list(range(len(all_sentence_ebd)))
                total_reward = []

                # shuffle bags
                random.shuffle(all_list)

                for batch in tqdm.tqdm(all_list):
                #for batch in tqdm.tqdm(range(10000)):

                    batch_en1 = train_entitypair[batch][0]
                    batch_en2 = train_entitypair[batch][1]
                    batch_sentence_ebd = all_sentence_ebd[batch]
                    batch_reward = all_reward[batch]
                    batch_len = len(batch_sentence_ebd)

                    list_list_state = []
                    list_list_action = []
                    list_list_reward = []
                    avg_reward  = 0


                    # add sample times
                    for j in range(sampletimes):
                        #reset environment
                        state = env.reset( batch_en1, batch_en2,batch_sentence_ebd,batch_reward)
                        list_action = []
                        list_state = []
                        old_prob = []


                        #get action
                        #start = time.time()
                        for i in range(batch_len):

                            state_in = np.append(state[0],state[1])
                            feed_dict = {}
                            feed_dict[myAgent.entity1] = [state[2]]
                            feed_dict[myAgent.entity2] = [state[3]]
                            feed_dict[myAgent.state_in] = [state_in]
                            prob = sess2.run(myAgent.prob,feed_dict = feed_dict)
                            old_prob.append(prob[0])
                            action = get_action(prob)
                            #add produce data for training cnn model
                            list_action.append(action)
                            list_state.append(state)
                            state = env.step(action)
                        #end = time.time()
                        #print ('get action:',end - start)

                        if env.num_selected == 0:
                            tmp_reward = average_reward
                        else:
                            tmp_reward = env.reward()

                        avg_reward += tmp_reward
                        list_list_state.append(list_state)
                        list_list_action.append(list_action)
                        list_list_reward.append(tmp_reward)


                    avg_reward = avg_reward / sampletimes
                    # add sample times
                    for j in range(sampletimes):

                        list_state = list_list_state[j]
                        list_action = list_list_action[j]
                        reward = list_list_reward[j]

                        # compute gradient
                        # start = time.time()
                        list_reward = [reward - avg_reward for x in range(batch_len)]
                        list_state_in = [np.append(state[0],state[1]) for state in list_state]
                        list_entity1 = [state[2] for state in list_state]
                        list_entity2 = [state[3] for state in list_state ]

                        feed_dict = {}
                        feed_dict[myAgent.state_in] = list_state_in
                        feed_dict[myAgent.entity1] = list_entity1
                        feed_dict[myAgent.entity2] = list_entity2
                        feed_dict[myAgent.reward_holder] = list_reward
                        feed_dict[myAgent.action_holder] = list_action
                        '''
                        loss =sess2.run(myAgent.loss, feed_dict=feed_dict)
                        if loss == float("-inf"):
                            probs,pis = sess2.run([myAgent.prob,myAgent.pi], feed_dict=feed_dict)
                            print(' ')
                            print ('batch:',batch)
                            print (old_prob)
                            print (list_action)
                            print(probs)
                            print (pis)
                            print('error!')
                            return 0
                        '''
                        grads = sess2.run(myAgent.gradients, feed_dict=feed_dict)
                        for index, grad in enumerate(grads):
                            gradBuffer[index] += grad
                        #end = time.time()
                        #print('get loss and update:', end - start)
                        '''
                        print (len(list_state),len(list_action),len(list_reward),len(list_entity1),len(list_entity2))
                        print (list_action)
                        print (list_reward)
                        print (list_entity1)
                        print (list_entity2)
                        break
                        '''
                    #decide action and compute reward
                    state = env.reset(batch_en1, batch_en2, batch_sentence_ebd, batch_reward)
                    old_prob = []
                    for i in range(batch_len):
                        state_in = np.append(state[0], state[1])
                        feed_dict = {}
                        feed_dict[myAgent.entity1] = [state[2]]
                        feed_dict[myAgent.entity2] = [state[3]]
                        feed_dict[myAgent.state_in] = [state_in]
                        prob = sess2.run(myAgent.prob, feed_dict=feed_dict)
                        old_prob.append(prob[0])
                        action = decide_action(prob)
                        state = env.step(action)
                    chosen_reward = [batch_reward[x] for x in env.list_selected]
                    total_reward += chosen_reward


                #apply gradient
                feed_dict = dict(zip(myAgent.gradient_holders, gradBuffer))
                sess2.run(myAgent.update_batch, feed_dict=feed_dict)
                for index, grad in enumerate(gradBuffer):
                    gradBuffer[index] = grad * 0

                #get tvars_new
                tvars_new = sess2.run(myAgent.tvars)

                # update old variables of the target network
                tvars_update = sess2.run(myAgent.tvars)
                for index, var in enumerate(tvars_update):
                    tvars_update[index] = updaterate * tvars_new[index] + (1-updaterate) * tvars_old[index]

                feed_dict = dict(zip(myAgent.tvars_holders, tvars_update))
                sess2.run(myAgent.update_tvar_holder, feed_dict)
                tvars_old = sess2.run(myAgent.tvars)
                #break


                #find the best parameters
                chosen_size = len(total_reward)
                total_reward = np.mean(np.array(total_reward))

                if (total_reward > best_reward):
                    best_reward = total_reward
                    tvars_best = tvars_old
                print ('chosen sentence size:',chosen_size)
                print ('total_reward:',total_reward)
                print ('best_reward',best_reward)


            #set parameters = best_tvars
            feed_dict = dict(zip(myAgent.tvars_holders, tvars_best))
            sess2.run(myAgent.update_tvar_holder, feed_dict)
            #save model
            saver.save(sess2, save_path='rlmodel/origin_rl_model.ckpt')