Example #1
    def deserialize(file: IO):
        cnt, *points = file.readline().strip().split(",")
        cnt = int(cnt)
        points = list(map(int, points))
        assert len(points) == cnt * 4, \
            "Can't deserialize Fern. count = {}, coords = {}".format(cnt, len(points))

        return Fern(cnt, list(grouper(grouper(points, 2), 2)))
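Every example on this page assumes a `grouper` helper that collects an iterable into fixed-size chunks. The projects disagree on argument order (some call `grouper(iterable, n)`, others `grouper(n, iterable)`), but the behavior matches the classic itertools recipe. A minimal sketch of the `grouper(iterable, n)` variant:

from itertools import zip_longest

def grouper(iterable, n, fillvalue=None):
    # Collect data into fixed-length chunks: grouper('ABCDEFG', 3) -> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

Because the final chunk is padded out with `fillvalue` (None by default), most of the examples below filter None values out of each group before using it.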
Example #2
def compute_checksum(height, width, input_string):

    rows = grouper(input_string, width)
    layers = grouper(rows, height)

    layers = list(layers)
    LOG.debug(layers)
    idx_zeros_and_checksum = (
        # checksum_of_layer: assumed per-layer helper, named by analogy with
        # number_of_zeros_in_layer.
        (idx, number_of_zeros_in_layer(layer), checksum_of_layer(layer))
        for idx, layer in enumerate(layers))
    idx_zeros_and_checksum = list(idx_zeros_and_checksum)
    LOG.debug(idx_zeros_and_checksum)
    idx, num_zeros, checksum = min(idx_zeros_and_checksum, key=lambda s: s[1])
    LOG.debug(f"Layer {idx=} with {num_zeros} zeros has checksum {checksum}")
    return checksum
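To make the two-level grouping concrete, here is a toy run-through with the recipe above (illustrative values, not real puzzle input):

pixels = "12345678"                 # two 2x2 layers, flattened
rows = list(grouper(pixels, 2))     # [('1','2'), ('3','4'), ('5','6'), ('7','8')]
layers = list(grouper(rows, 2))     # [(('1','2'), ('3','4')), (('5','6'), ('7','8'))]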
Example #3
    def parse_gzip(self, filename):
        """
        Translate input in file at path defined by `filename` from FASTQ format after gzip decompression.

        Returns:
            List of dictionaries with the following keys:
                * id (str): Identifier for the sequence.
                * desc (str, optional): If the original identifier contains at least one space, desc will be the string after the last one.
                * dna (str): The sequence DNA string.
        """
        gzfile = open(filename, 'rb')
        d = zlib.decompressobj(16 + zlib.MAX_WBITS)
        eof = False
        data = ''
        while True:
            while True:
                cdata = gzfile.read(2000000)
                if not cdata:
                    eof = True
                    break
                data = ''.join([data, d.decompress(cdata)])
                lines = data.split('\n')
                num = len(lines) - 1
                if num >= 4:
                    slines = lines[0:num - (num % 4)]
                    data = '\n'.join(lines[num - (num % 4):])
                    break
            if eof:
                # Flush the decompressor and parse any complete records left in the buffer.
                data = ''.join([data, d.flush()])
                slines = [line for line in data.split('\n') if line]
            for ls in grouper(slines, 4):
                yield self._parse_lines(ls)
            if eof:
                break
Example #4
    def scrape_sections_html(self, html):

        def join(x):
            return map(lambda y: ''.join(y), x)

        rows = join(grouper(2, re.split(
            '(<TR>\n<TH CLASS="ddtitle")', html)[1:]))

        title_re = re.compile(
            '<TH CLASS="ddtitle".*><A .*>' +
            '.* - ([0-9]+) - .* (' + course_number_re_fragment + ')' +
            ' - (' + section_name_re_fragment + ')</A></TH>')

        def iter_sections():
            for row in rows:
                match = title_re.search(row.replace('\n', ''))
                if match is None: continue
                yield {
                    'crn': match.group(1),
                    'course': match.group(2),
                    'name': match.group(3),
                }

        return list(iter_sections())
Example #5
    def scrape_courses_html(self, html):
        """
        Don't try to parse HTML with regular expressions, right?
        Unfortunately, since Oscar doesn't believe in closing <tr>
        tags, BeautifulSoup does not parse this table intelligibly.

        This page consists of one giant single-column table.
        Each course is represented by two consecutive rows.
        The first row contains the course number and title.
        The second row contains the description, plus some other things.
        """

        def join(x):
            return map(lambda y: ''.join(y), x)

        rows = join(grouper(2, re.split('(<TR)', html)[1:]))

        title_re = re.compile('CLASS="nttitle".*>[A-Z]+ ('
                              + course_number_re_fragment + ') - (.*)</A></TD>')

        def iter_courses():
            for (row1, row2) in grouper(2, rows):
                course = {}
                match = title_re.search(row1.replace('\n', ''))
                if match is None: continue
                course['number'] = match.group(1)
                course['name'] = match.group(2)
                soup = BeautifulSoup(row2)
                td = soup.find('td')
                d = td.contents[0].strip().replace('\n', ' ') or None
                course['description'] = d
                yield course

        return list(iter_courses())
Example #6
File: game.py Project: ericyd/2048
 def group(self, grouping, array=None):
     array = array if array is not None else self.board
     # group into nested list of lists, 4 items each
     nested_array = grouper(array, 4)
     if grouping == 'columns':
         nested_array = transpose(*nested_array)
     return nested_array
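A quick sketch of the column trick, assuming `transpose` (not shown in this example) is a zip-style helper:

board = list(range(16))
rows = list(grouper(board, 4))   # [(0, 1, 2, 3), (4, 5, 6, 7), ...]
cols = list(zip(*rows))          # [(0, 4, 8, 12), (1, 5, 9, 13), ...]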
Example #7
def main(mode, inputs_dir, output_filename, n_jobs, chunk_size):
    from glob import glob
    if os.path.exists(output_filename):
        user_in = ""
        while len(user_in) != 1 or user_in not in "aoe":
            user_in = raw_input(
                "output filename `%s' already exists, append/overwrite/exit (aoe)?"
                % output_filename)
            if user_in == "e":
                print "Exiting."
                sys.exit(1)
            elif user_in == "o":
                print "Overwriting."
                output_file = open(output_filename, "w")
                with open("done_files.txt", "w") as f:
                    f.write("dummy")
                    pass
            elif user_in == "a":
                print "Appending."
                output_file = open(output_filename, "a")
    else:
        output_file = open(output_filename, "w")
    input_filenames = sorted(glob(os.path.join(inputs_dir, "*.json.bz2")))

    done_files = []
    if os.path.exists("done_files.txt"):
        with open("done_files.txt") as done_fd:
            done_files = [s.strip() for s in done_fd.readlines()]

    input_filenames = [f for f in input_filenames if f not in done_files]
    n_files = len(input_filenames)
    n_processed_files = 0

    widgets = [progressbar.ETA(), progressbar.Percentage()]
    pbar = progressbar.ProgressBar(widgets=widgets, maxval=n_files).start()
    logging.info("Processing %d / %d files", chunk_size, n_files)
    with Parallel(n_jobs=n_jobs) as parallel:
        for chunk in grouper(input_filenames, chunk_size):
            chunk = [f for f in chunk if f is not None]
            if len(chunk) == 0:
                continue
            try:
                sentences = parallel(
                    delayed(cleanup_bz2_file)(filename, mode)
                    for filename in chunk if filename is not None)
            except EOFError:
                continue
            sentences = [l for l in chain(*sentences)]
            logging.info("Found %d sentences in [%s]", len(sentences),
                         ", ".join(chunk))
            for sent in sentences:
                output_file.write(sent)
                output_file.write("\n")
            for f in chunk:
                done_files.append(f)
            with open("done_files.txt", "w") as f:
                f.write("\n".join(done_files))
            n_processed_files += len(chunk)
            pbar.update(n_processed_files)
    pbar.finish()
Example #8
def parse_xml_file(filename_glob, output_filename, mode, n_jobs):

    from glob import glob
    filenames = sorted(glob(filename_glob))
    n_parsed = 0
    sents = []
    #if os.path.exists(output_filename):
        #print("ERROR: output filename already exists")
        #sys.exit(1)
    outfile = open(output_filename, "a")
    n_ok = 0
    print("Number of files: ", len(filenames))
    with Parallel(n_jobs=n_jobs, verbose=4) as parallel:
        for chunk in grouper(filenames, 10 * n_jobs):
            try:
                sents = parallel(delayed(parse_files)([fn], mode)
                                 for fn in chunk if fn is not None)

                for sent in chain(*sents):
                    if sent.ok:
                        sent.write_yaml(outfile)
                        n_ok += 1
                    n_parsed += 1
                del sents

            except Exception as e:
                print("ERROR:", e)

    print("Total OK sentences:", n_ok)
    try:
        print("Frac of OK sentences:", n_ok / float(n_parsed))
    except ZeroDivisionError:
        print("No candidates found.")

    outfile.close()
Example #9
def do_evaluate(args):
    """
    Evaluate an existing model.
    """
    logging.info("Evaluating the model.")
    model = get_model_factory(args.model).load(args.model_path)

    data = list(process_snli_data(args.eval_data))
    X1, X2, Y = vectorize_data(data, args.input_length)

    emb = WordEmbeddings()
    cm = ConfusionMatrix(LABELS)
    writer = csv.writer(args.output, delimiter="\t")
    writer.writerow(["sentence1", "sentence2", "gold_label", "guess_label", "neutral", "contradiction", "entailment"])
    for batch in tqdm(grouper(args.batch_size, zip(data, X1, X2, Y)), total=int(len(data)/args.batch_size)):
        batch = [x for x in batch if x is not None]  # drop the None padding in the final batch
        objs, X1_batch, X2_batch, y_batch = zip(*batch)
        X1_batch = array([emb.weights[x,:] for x in X1_batch])
        X2_batch = array([emb.weights[x,:] for x in X2_batch])
        y_batch = array(y_batch)

        y_batch_ = model.predict_on_batch([X1_batch, X2_batch])

        for obj, y, y_ in zip(objs, y_batch, y_batch_):
            label = np.argmax(y)
            label_ = np.argmax(y_)
            writer.writerow([
                obj.sentence1,
                obj.sentence2,
                LABELS[label],
                LABELS[label_],
                ] + list(y_))
            cm.update(label, label_)
    cm.print_table()
    cm.summary()
    logging.info("Done.")
Example #10
def update_commutes(session, dests, modes, chunksize=50, delay=5, verbose=True, **kwargs):
    """
    Look up commute distances and times from Google Maps API for posts in the
    database that are missing commute information.
    """

    query = (
        session
        .query(ApartmentPost)
        .outerjoin(ApartmentPost.commutes)
        .filter(not_(ApartmentPost.commutes.any()))
        .filter(not_(ApartmentPost.latitude == None)))

    num_updates = query.count()
    num_processed = 0

    for posts in grouper(chunksize, query):
        posts = [p for p in posts if p is not None]  # drop the None padding in the final chunk
        _process_batch(session, posts, dests, modes, **kwargs)
        num_processed += len(posts)
        print "{}/{} commutes processed".format(num_processed, num_updates)
        _random_pause(delay)

    try:
        session.commit()
    except:
        session.rollback()
        raise
    finally:
        session.close()
Example #11
def do_train(args):
    """
    Train the model using the provided arguments.
    """

    # Assumption: it is cheap to store all the data in text form in
    # memory (it's only about 144mb)
    _, X, y = load_data_raw(args.input)
    X_train, y_train, X_val, y_val = split_data(X, y, args.dev_split)

    # Assumption: word vector model will also easily fit in memory.
    wvecs = WordVectorModel.from_file(args.wvecs, False, '*UNKNOWN*')

    # Typical values are 50, 50
    input_shape = (1,args.n_words, wvecs.dim)
    output_shape = len(LABELS)

    # Build model
    model = build_model(args, input_shape=input_shape, output_shape=output_shape, output_type=args.output_type)

    # Training data on the other hand will not. Each input instance is
    # 50x50 matrix with 8bytes per value: that's about 20kb.
    # Assuming we want to store only about 500mb in memory at a time,
    # that means we want at most 25k items in a batch.
    # Typically minibatches of 32-128 are probably ok. Let's keep it
    # that way?
    for epoch in range(args.n_epochs):
        log("== Training model, epoch {}", epoch)

        scorer = Scorer(model)
        for xy in tqdm(grouper(args.batch_size, zip(X_train, y_train))):
            xy = [p for p in xy if p is not None]  # drop the None padding in the final batch
            X_batch, y_batch = zip(*xy)
            X_batch, y_batch = wvecs.embed_sentences(X_batch), array(make_one_hot(y_batch, len(LABELS)))
            score = model.train_on_batch(X_batch, y_batch)
            scorer.update(score, len(X_batch))
        log("=== train error: {}", scorer)

        scorer = Scorer(model)
        for xy in tqdm(grouper(args.batch_size, zip(X_val, y_val))):
            xy = [p for p in xy if p is not None]  # drop the None padding in the final batch
            X_batch, y_batch = zip(*xy)
            X_batch, y_batch = wvecs.embed_sentences(X_batch), array(make_one_hot(y_batch, len(LABELS)))
            score = model.test_on_batch(X_batch, y_batch)
            scorer.update(score, len(X_batch))
        log("=== val error: {}", scorer)

    ## Save the model
    save_model(model, args.model, args.weights)
Example #12
	def exhaust_generator(self, sess):
		self.print('Starting exhaust_generator')
		self.generator = self.generator_fn(*self.generator_args)
		if self.enqueue_many is not None:
			self.generator = (tuple(zip(*group))
							  for group in grouper(self.generator, self.enqueue_many, self.fill_value))
		(self.multi_loop if isinstance(self.placeholders, list) else self.mono_loop)(sess)
		self.print('Loop ended')
Example #13
 def iter_courses():
     for (row1, row2) in grouper(2, rows):
         course = {}
         match = title_re.search(row1.replace('\n', ''))
         if match is None: continue
         course['number'] = match.group(1)
         course['name'] = match.group(2)
         soup = BeautifulSoup(row2)
         td = soup.find('td')
         d = td.contents[0].strip().replace('\n', ' ') or None
         course['description'] = d
         yield course
Example #14
    async def get_summoner_names_by_ids(self, summoner_ids):
        """Get summoner names by their ids

        Keyword arguments:
        summoner_ids -- list of summoner ids to query
        """
        results = []
        for subset in util.grouper(summoner_ids, 40):
            url = self.base_summoner_url + ','.join(str(summoner_id) for summoner_id in subset if summoner_id) + '/name'
            results.append(await self.get(url))

        return util.dict_merge(results)
Example #15
    async def get_summoners_info_by_names(self, summoner_names):
        """Get info about summoners by summoner names

        Keyword arguments:
        summoner_names -- list of summoner names to query
        """
        results = []
        for subset in util.grouper(summoner_names, 40):
            url = self.base_summoner_url + 'by-name/' + ','.join(name for name in subset if name)
            results.append(await self.get(url))

        return util.dict_merge(results)
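The interplay between grouper's padding and the `if name` filter in the URL join is easy to miss; a toy trace, assuming the recipe-style `grouper(iterable, n)` used here:

names = ["alice", "bob", "carol"]
for subset in grouper(names, 2):
    print(",".join(name for name in subset if name))
# alice,bob
# carol    <- the None padding in the last chunk is filtered out by the genexp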
Example #16
def evaluate(args, emb, model, X1X2Y, total=None):
    cm = ConfusionMatrix(LABELS)
    for batch in tqdm(grouper(args.batch_size, X1X2Y), total=int(total/args.batch_size)):
        batch = [x for x in batch if x is not None]  # drop the None padding in the final batch
        X1_batch, X2_batch, y_batch = zip(*batch)
        X1_batch = array([emb.weights[x,:] for x in X1_batch])
        X2_batch = array([emb.weights[x,:] for x in X2_batch])
        y_batch = array(y_batch)

        y_batch_ = model.predict_on_batch([X1_batch, X2_batch])
        for y, y_ in zip(y_batch, y_batch_): cm.update(np.argmax(y), np.argmax(y_))
    cm.print_table()
    cm.summary()
    return cm
Example #17
    def get_summoners_info_by_names(self, summoner_names):
        """Get info about summoners by summoner names

        Keyword arguments:
        summoner_names -- list of summoner names to query
        """

        results = []
        for subset in util.grouper(summoner_names, 40):
            url = self.base_url + "by-name/" + ",".join(name for name in subset if name)
            results.append(LeagueRequest.get(url))

        return util.dict_merge(results)
Example #18
    def get_summoner_names_by_ids(self, summoner_ids):
        """Get summoner names by their ids

        Keyword arguments:
        summoner_ids -- list of summoner ids to query
        """

        results = []
        for subset in util.grouper(summoner_ids, 40):
            url = self.base_url + ",".join(str(summoner_id) for summoner_id in subset if summoner_id) + "/name"
            results.append(LeagueRequest.get(url))

        return util.dict_merge(results)
Example #19
File: scraper.py Project: alecgorge/grouch
 def iter_courses():
     for (row1, row2) in grouper(2, rows):
         course = {}
         match = title_re.search(row1.replace("\n", ""))
         if match is None:
             continue
         course["number"] = match.group(1)
         course["name"] = match.group(2)
         soup = BeautifulSoup(row2)
         td = soup.find("td")
         d = td.contents[0].strip().replace("\n", " ") or None
         course["description"] = d
         yield course
Example #20
 def pretty(self):
     """Returns a human-readable representation."""
     horizontal_line = ("+",) + (("-",) * 7 + ("+",)) * 3 + ("\n",)
     r = []
     r.extend(horizontal_line)
     for square_row in range(3):
         for row_index in range(square_row * 3, square_row * 3 + 3):
             for group in util.grouper(self.row(row_index), 3):
                 r.append("| ")
                 for cell in group:
                     r.append(str(cell) if cell is not None else "x")
                     r.append(" ")
             r.append("|\n")
         r.extend(horizontal_line)
     return "".join(r[:-1])
Example #21
    async def get_prods_by_query(self, **kwargs):
        prod_details, total_pages = await self._get_products_by_query_for_page(
            page=1, return_page_num=True, **kwargs)

        if total_pages > 1:
            # Pages are 1-based and page 1 was fetched above, so request 2..total_pages.
            for sub_pages in util.grouper(range(2, total_pages + 1), 5):
                details = await asyncio.gather(*[
                    self._get_products_by_query_for_page(page=page, **kwargs)
                    for page in sub_pages if page
                ])

                for detail in details:
                    prod_details.extend(detail)

        return prod_details
Example #22
def solve_part_one():
    p = get_program('../inputs/day13.txt')
    c = Computer(p, Queue(), Queue())
    c.run_until_stop()
    all_outputs = queue_to_list(c.output_queue)
    all_outputs = all_outputs[:-1]  # skip the final "program terminated" output

    screen = defaultdict(int)
    for x, y, tile_id in grouper(all_outputs, 3):

        assert tile_id in tile_ids, f"invalid code: {tile_id}"
        screen[(x, y)] = tile_id

    return sum(1 for tid in screen.values() if tid == BLOCK)
Example #23
File: plugin.py Project: anthonyv5/webgdb
 def compute_data_view(self, view):
     unit = view['unit']
     count = view['count']
     addr = ez.eval_location(view['location'])
     data = ez.read(addr, unit * count)
     words = []
     for x in util.grouper(data, unit):
         val = util.unpack_le(x)
         words.append({
             'address': addr,
             'value': val,
             'smart': ez.make_smart(val),
         })
         addr += unit
     assert len(words) == count
     return words
Example #24
def measure_dataset(detector,
                    video,
                    frame_flags: TextIO,
                    gt_homography: TextIO,
                    gt_points: TextIO,
                    sample=None,
                    explore=False):
    if explore:
        logger.debug("Explore enabled")

    h, w = np.shape(sample)[:2]
    sample_bounds = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
    filter_bounds = sample_bounds.copy()
    filter_vel = np.float32([[0, 0], [0, 0], [0, 0], [0, 0]])
    alpha = 0.5
    beta = 0.2

    logger.debug("Start iterating over frames")
    result = []
    for idx, (frame, flag, Hline, Pline) in \
            enumerate(zip(get_frames(video), frame_flags, gt_homography, gt_points)):
        logger.debug("Evaluating frame {}".format(idx))
        truth = list(grouper(map(float, Pline.strip().split()), 2))
        flag = int(flag.strip())

        if idx % 2 == 0 or flag > 0:
            logger.debug("Frame {} dropped".format(idx))
            continue

        points, H = detector.detect(frame, orig_bounds=sample_bounds)

        filter_bounds += filter_vel
        if len(points) > 0:
            filter_r = points - filter_bounds
            filter_bounds += alpha * filter_r
            filter_vel += beta * filter_r

        metric = calc_metric(truth, points)

        if explore:
            examine_detection(detector, sample, frame, truth, points)
            examine_detection(detector, sample, frame, truth, filter_bounds)

        logger.debug("Metric value for frame {} = {}".format(idx, metric))
        result.append(metric)

    return result
Example #25
def train_detector(video, gt_points: TextIO):
    logger.info("Start detector training")
    frame = next(util.get_frames(video))

    gt_points = np.array(
        list(util.grouper(map(float,
                              next(gt_points).strip().split()), 2)))
    sample_corners = np.array([[0, 0], [640, 0], [640, 480], [0, 480]],
                              dtype=np.float32)

    H, _ = cv2.findHomography(gt_points, sample_corners, cv2.RANSAC, 5.0)
    sample = cv2.warpPerspective(frame, H, (640, 480))

    detector = fern.FernDetector.train(sample,
                                       max_train_corners=20,
                                       max_match_corners=500)
    logger.info("Detector trained")
    return detector
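The `grouper(..., 2)` call is what turns a flat, whitespace-separated ground-truth line into (x, y) pairs; a small illustration with made-up coordinates:

line = "10.0 20.0 30.0 40.0 50.0 60.0 70.0 80.0"
points = list(util.grouper(map(float, line.split()), 2))
# [(10.0, 20.0), (30.0, 40.0), (50.0, 60.0), (70.0, 80.0)]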
Example #26
def do_run(args):
    """
    Run the neural net to predict on new data.
    """
    # Load the model and weights
    model = load_model(args.model, args.weights)
    wvecs = WordVectorModel.from_file(args.wvecs, False, '*UNKNOWN*')

    data = ((tweet.id, tokenize(to_ascii(tweet.text))) for tweet in RowObjectFactory.from_stream(csv.reader(args.input, delimiter="\t")))
    writer = csv.writer(args.output, delimiter='\t')
    writer.writerow(['id',] + LABELS)

    for ix in tqdm(grouper(args.batch_size, data)):
        ix = [x for x in ix if x is not None]  # drop the None padding in the final batch
        ids_batch, X_batch = zip(*ix)
        X_batch = wvecs.embed_sentences(X_batch)
        labels = model.predict_on_batch(X_batch)
        for id, label in zip(ids_batch, labels):
            writer.writerow([id,] + [float(l) for l in label])
Example #27
	def process_data(self, conf):
		super(Schedule, self).process_data(conf)
		data = conf['data']
		triples = grouper(3, data)
		
		labels, begin_dates, end_dates = zip(*triples)
		
		begin_dates = map(self.parse_date, begin_dates)
		end_dates = map(self.parse_date, end_dates)

		# reconstruct the triples in a new order
		reordered_triples = zip(begin_dates, end_dates, labels)
		
		# because of the reordering, this will sort by begin_date
		#  then end_date, then label.
		reordered_triples.sort()
		
		conf['data'] = reordered_triples
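The flat `data` list is assumed to arrive in label, begin, end order, which `grouper(3, data)` (this project's n-first signature) rebuilds into triples; an illustrative run:

data = ['setup', '2009-01-01', '2009-01-31', 'build', '2009-02-01', '2009-03-15']
triples = list(grouper(3, data))
# [('setup', '2009-01-01', '2009-01-31'), ('build', '2009-02-01', '2009-03-15')]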
Example #28
def get_mfccs(track, dc=False, n_fft=2048, average=None, normalize=False,
              n_mfcc=20, fmin=20, fmax=None, hop_length=512, n_mels=128, **kwargs):
    audio, sr = librosa.load(track['file_path'])
    mfcc = librosa.feature.mfcc(audio, sr=sr, n_fft=n_fft, n_mfcc=n_mfcc,
                                fmin=fmin, fmax=fmax, hop_length=hop_length,
                                n_mels=n_mels)
    if not dc:
        mfcc = mfcc[1:]
    if normalize:
        # Normalize each feature vector between 0 and 1
        mfcc = mfcc - mfcc.min(axis=0)
        mfcc = mfcc / mfcc.max(axis=0)
    if average and average > 0:
        samples = int(sr * average / n_fft)  # grouper needs an integer chunk size
        chunks = util.grouper(mfcc.T, samples)
        averaged_chunk = [np.mean(group, axis=0) for group in chunks]
        mfcc = np.array(averaged_chunk).T
    return mfcc
Example #29
    def deserialize(file: IO):
        module_logger.info("Deserialiazing FernDetector from {}".format(
            file.name))
        version = int(file.readline().strip())

        if version != 1:
            msg = "Can't deserialize FernDetector from {}. Incorrect version of model. Expected 1, found {}"\
                .format(file.name, version)
            module_logger.error(msg)
            raise AssertionError(msg)

        num_ferns = int(file.readline().strip())
        ph, pw = map(int, file.readline().strip().split(","))

        with util.Timer("Deserializing ferns"):
            ferns = [Fern.deserialize(file) for _ in range(num_ferns)]

        fern_bits, max_train, max_match = map(
            int,
            file.readline().strip().split(","))

        with util.Timer("Deserializing fern_p"):
            F, C, K = map(int, file.readline().strip().split(","))
            fern_p = np.zeros((F, C, K), dtype=float)
            for fern_idx in range(F):
                for class_idx in range(C):
                    line = list(map(float, file.readline().strip().split(",")))
                    fern_p[fern_idx, class_idx, :] = line

        line = file.readline().strip().split(",")
        key_points = list(util.grouper(map(int, line), 2))

        module_logger.info("Creating FernDetector")
        detector = FernDetector(patch_size=(ph, pw),
                                max_train_corners=max_train,
                                max_match_corners=max_match,
                                ferns=ferns,
                                ferns_p=fern_p,
                                classes_cnt=C,
                                key_points=key_points,
                                fern_bits=fern_bits)
        module_logger.info("Deserialization complete.")
        return detector
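Reconstructed from the reads above (not from a format spec), the text layout this parser expects looks roughly like:

version                          e.g. "1"
num_ferns
patch_h,patch_w
<num_ferns serialized Fern blocks>
fern_bits,max_train,max_match
F,C,K
<F*C lines of K comma-separated floats, one (fern, class) row per line>
x1,y1,x2,y2,...                  flat key point coordinates, paired up by grouper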
Example #30
    async def add_items(self, items):
        temp_id = await self._redis.incr(self._temp_id_key)
        temp_key = self._keyname(f'temp_{temp_id}')

        # Load all items we're adding into a temp table
        for some_items in grouper(1000, items):
            some_items = [i for i in some_items if i is not None]  # drop the None padding
            await self._redis.sadd(temp_key, *some_items)

        # Find the new items (i.e. items not already in all_items)
        await self._redis.sdiffstore(temp_key, temp_key, self._items_key)

        # Save new items into all_items and unexplored
        await self._redis.sunionstore(self._items_key, self._items_key,
                                      temp_key)
        await self._redis.sunionstore(self._unexplored_key,
                                      self._unexplored_key, temp_key)

        # And clean up after ourselves
        await self._redis.delete(temp_key)
Example #31
File: storage.py Project: 0NtgO/Pyrit
 def __getattr__(self, name):
     if name == "passwords":
         result = zlib.decompress(self.pwbuffer).split('\n')
         assert len(result) == self.numElems
         md = hashlib.md5()
         md.update(self.essid)
         md.update(self.pmkbuffer)
         md.update(self.pwbuffer)
         if md.digest() != self.digest:
             raise DigestError("Digest check failed")
         self.passwords = result
         del self.pwbuffer
     elif name == "pmks":
         result = util.grouper(self.pmkbuffer, 32)
         assert len(result) == self.numElems
         self.pmks = result
         del self.pmkbuffer
     else:
         raise AttributeError
     return result
Example #32
def train(args, emb, model, X1X2Y, total=None):
    """
    Train the model using the embeddings @emb and input data batch X1X2Y.
    """
    cm = ConfusionMatrix(LABELS)
    scorer = Scorer(model.metrics_names)
    for batch in tqdm(grouper(args.batch_size, X1X2Y), total=int(total/args.batch_size)):
        batch = [x for x in batch if x is not None]  # drop the None padding in the final batch
        X1_batch, X2_batch, y_batch = zip(*batch)
        X1_batch = array([emb.weights[x,:] for x in X1_batch])
        X2_batch = array([emb.weights[x,:] for x in X2_batch])
        y_batch = array(y_batch)

        score = model.train_on_batch([X1_batch, X2_batch], y_batch)
        scorer.update(score, len(y_batch))
        y_batch_ = model.predict_on_batch([X1_batch, X2_batch])
        for y, y_ in zip(y_batch, y_batch_): cm.update(np.argmax(y), np.argmax(y_))
    logging.info("train error: %s", scorer)
    cm.print_table()
    cm.summary()
    return cm
Example #33
 def __getattr__(self, name):
     if name == "passwords":
         result = zlib.decompress(self.pwbuffer).split('\n')
         assert len(result) == self.numElems
         md = hashlib.md5()
         md.update(self.essid)
         md.update(self.pmkbuffer)
         md.update(self.pwbuffer)
         if md.digest() != self.digest:
             raise DigestError("Digest check failed")
         self.passwords = result
         del self.pwbuffer
     elif name == "pmks":
         result = util.grouper(self.pmkbuffer, 32)
         assert len(result) == self.numElems
         self.pmks = result
         del self.pmkbuffer
     else:
         raise AttributeError
     return result
Example #34
File: scraper.py Project: alecgorge/grouch
    def scrape_sections_html(self, html):
        def join(x):
            return map(lambda y: "".join(y), x)

        rows = join(grouper(2, re.split('(<TR>\n<TH CLASS="ddtitle")', html)[1:]))

        title_re = re.compile(
            '<TH CLASS="ddtitle".*><A .*>'
            ".* - ([0-9]+) - .* (" + course_number_re_fragment + ")"
            " - (" + section_name_re_fragment + ")</A></TH>"
        )

        def iter_sections():
            for row in rows:
                section = {}
                match = title_re.search(row.replace("\n", ""))
                if match is None:
                    continue
                yield {"crn": match.group(1), "course": match.group(2), "name": match.group(3)}

        return list(iter_sections())
Example #35
def train_detector(video, gt_points: TextIO):
    assert video.isOpened()
    frame = next(get_frames(video))

    gt_points = np.array(list(grouper(map(float, next(gt_points).strip().split()), 2)))
    lx, rx = (gt_points[0, 0] + gt_points[3, 0]) / 2, (gt_points[1, 0] + gt_points[2, 0]) / 2
    ty, by = (gt_points[0, 1] + gt_points[1, 1]) / 2, (gt_points[2, 1] + gt_points[3, 1]) / 2

    w = np.int32(rx - lx)
    h = np.int32(by - ty)

    sample_corners = np.array([[0, 0], [w, 0], [w, h], [0, h]], dtype=np.float32)

    H, _ = cv2.findHomography(gt_points, sample_corners, cv2.RANSAC, 5.0)
    sample = cv2.warpPerspective(frame, H, (w, h))

    detector = fern.FernDetector.train(sample,
                                       deform_param_gen=smart_deformations_gen(sample, 20, 20),
                                       max_train_corners=250,
                                       max_match_corners=500)
    return sample, detector
Example #36
 def __create_thread_objects(self, follow, stream_with):
     '''Split the specified follow list into groups of FOLLOW_LIMIT
     or smaller and then create ListenThread objects for those groups.
     
     >>> follow = range(1,1001)
     >>> monitor = ListenThreadMonitor(follow, consumer, token)
     >>> len(monitor.threads) >= len(follow)/FOLLOW_LIMIT
     True
     
     '''
     threads = []
     chunks = list(grouper(FOLLOW_LIMIT, follow))
     
     for follow in chunks:
         follow = [f for f in follow if f is not None]  # drop the None padding in the final chunk
         stream = SiteStream(follow, self.consumer, self.token,
                             stream_with, self.parser)
         thread = ListenThread(stream)
         thread.daemon = True
         threads.append(thread)
     
     logger.debug("Created %s new thread objects." % len(threads))
     return threads
Example #37
    def parse_fastq(self, filename):
        """
        Translate input in file at path defined by `filename` from FASTQ format.

        Returns:
            List of dictionaries with the following keys:
                * id (str): Identifier for the sequence.
                * desc (str, optional): If the original identifier contains at least one space, desc will be the string after the last one.
                * dna (str): The sequence DNA string.

        Raises:
            PathError
        """
        try:
            seqfile = open(filename, 'r')
            for lines in grouper(seqfile, 4):
                yield self._parse_lines(lines)
        except OSError:
            raise PathError(''.join(['ERROR: cannot open ', filename]))
Example #38
File: storage.py Project: slothg/pyrit
 def _unpack(self):
     with self._unpackLock:
         if hasattr(self, '_pwbuffer'):
             pwbuffer = zlib.decompress(self._pwbuffer)
             pwbuffer = pwbuffer.split(self._delimiter)
             assert len(pwbuffer) == self._numElems
             md = hashlib.md5()
             md.update(self.essid)
             if self._magic == 'PYR2':
                 md.update(self._pmkbuffer)
                 md.update(self._pwbuffer)
             else:
                 md.update(self._pmkbuffer)
                 md.update(''.join(pwbuffer))
             if md.digest() != self._digest:
                 raise DigestError("Digest check failed")
             self.results = zip(pwbuffer, util.grouper(self._pmkbuffer, 32))
             assert len(self.results) == self._numElems
             del self._pwbuffer
             del self._digest
             del self._magic
             del self._delimiter
Example #39
    def __create_thread_objects(self, follow, stream_with):
        '''Split the specified follow list into groups of FOLLOW_LIMIT
        or smaller and then create ListenThread objects for those groups.
        
        >>> follow = range(1,1001)
        >>> monitor = ListenThreadMonitor(follow, consumer, token)
        >>> len(monitor.threads) >= len(follow)/FOLLOW_LIMIT
        True
        
        '''
        threads = []
        chunks = list(grouper(FOLLOW_LIMIT, follow))

        for follow in chunks:
            follow = [f for f in follow if f is not None]  # drop the None padding in the final chunk
            stream = SiteStream(follow, self.consumer, self.token, stream_with,
                                self.parser)
            thread = ListenThread(stream)
            thread.daemon = True
            threads.append(thread)

        logger.debug("Created %s new thread objects." % len(threads))
        return threads
Example #40
 def _unpack(self):
     with self._unpackLock:
         if hasattr(self, '_pwbuffer'):
             pwbuffer = zlib.decompress(self._pwbuffer)
             pwbuffer = pwbuffer.split(self._delimiter)
             assert len(pwbuffer) == self._numElems
             md = hashlib.md5()
             md.update(self.essid)
             if self._magic == 'PYR2':
                 md.update(self._pmkbuffer)
                 md.update(self._pwbuffer)
             else:
                 md.update(self._pmkbuffer)
                 md.update(''.join(pwbuffer))
             if md.digest() != self._digest:
                 raise DigestError("Digest check failed")
             self.results = zip(pwbuffer, util.grouper(self._pmkbuffer, 32))
             assert len(self.results) == self._numElems
             del self._pwbuffer
             del self._digest
             del self._magic
             del self._delimiter
Example #41
 async def mark_explored(self, items):
     for some_items in grouper(1000, items):
         some_items = [i for i in some_items if i is not None]  # drop the None padding
         await self._redis.srem(self._unexplored_key, *some_items)
Example #42
def render_image(height, width, input_string):
    flat_layers = grouper(input_string, width * height)
    rendering = reduce(render_from_two_layers, flat_layers)
    render = list(map(list, grouper(map(int, rendering), width)))
    return render
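`render_from_two_layers` is not shown on this page; a plausible sketch, assuming the usual layered-image rule in which pixel value '2' is transparent and the upper layer otherwise wins:

def render_from_two_layers(top, bottom):
    # Keep the top pixel unless it is transparent ('2'); then the lower layer shows through.
    return [t if t != '2' else b for t, b in zip(top, bottom)]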
Example #43
    def run(self):
        for dirname, dirnames, filenames in os.walk(os.path.join(self.project_location,"src")):
            for filename in filenames:
                full_file_path = os.path.join(dirname, filename)
                ext = util.get_file_extension_no_period(full_file_path)
                if ext in apex_extensions_to_check:
                    self.apex_files_to_check.append(full_file_path)
                elif ext in vf_extensions_to_check:
                    self.vf_files_to_check.append(full_file_path)

        apex_parser_threads = []
        vf_parser_threads   = []

        apex_file_chunks    = list(util.grouper(8, self.apex_files_to_check))
        vf_file_chunks      = list(util.grouper(8, self.vf_files_to_check))

        for files in apex_file_chunks:
            files = [f for f in files if f is not None]  # drop the None padding in the final chunk
            apex_parser_thread = ApexParser(files)
            apex_parser_threads.append(apex_parser_thread)
            apex_parser_thread.start()

        for thread in apex_parser_threads:
            thread.join()
            if thread.complete:
                self.apex_parser_results.update(thread.result)

        for files in vf_file_chunks:
            files = [f for f in files if f is not None]  # drop the None padding in the final chunk
            vf_parser_thread = VfParser(files)
            vf_parser_threads.append(vf_parser_thread)
            vf_parser_thread.start()

        for thread in vf_parser_threads:
            thread.join()
            if thread.complete:
                self.vf_parser_results.update(thread.result)
        
        #pp = pprint.PrettyPrinter(indent=2)
        #pp.pprint(self.parser_results)

        for file_name in self.vf_files_to_check:
            parser_result = self.vf_parser_results[file_name]

            base_name = os.path.basename(file_name)
            self.vf_result[base_name] = {}
            file_body = util.get_file_as_string(file_name)

            ### ACTION POLLERS
            if "actionPollers" not in parser_result:
                #print file_name
                continue

            action_pollers = parser_result["actionPollers"]
            action_poller_matches = []
            if len(action_pollers) > 0:
                for p in action_pollers:
                     
                    line_contents = ""
                    for lnum in range(p["location"]["row"], p["location"]["row"]+1):
                        line_contents += print_file_line(file_name, lnum)
                    
                    p["lineNumber"]       = p["location"]["row"]
                    p["line_contents"]    = line_contents

                    action_poller_matches.append(p)  


                self.result["visualforce_statistics"]["action_poller"]["results"].append(
                    {
                        "file_name" : base_name,
                        "flagged"   : len(action_poller_matches) > 0,
                        "matches"   : action_poller_matches   
                    }
                )
                self.action_poller_count += len(action_poller_matches)

            ### HARDCODED URLS
            output_links = parser_result["outputLinks"]
            link_matches = []
            if len(output_links) > 0:
                for p in output_links:
                     
                    line_contents = ""
                    for lnum in range(p["location"]["row"], p["location"]["row"]+1):
                        line_contents += print_file_line(file_name, lnum)
                    
                    p["lineNumber"]       = p["location"]["row"]
                    p["line_contents"]    = line_contents

                    link_matches.append(p)  


                self.result["visualforce_statistics"]["hardcoded_url"]["results"].append(
                    {
                        "file_name" : base_name,
                        "flagged"   : len(link_matches) > 0,
                        "matches"   : link_matches   
                    }
                )
                self.hardcoded_link_count += len(link_matches)

            ## REFRESHERS
            refreshers = re.finditer(js_refresh_pattern, file_body)          
            js_matches      = []
            meta_matches    = []
            for match in refreshers:
                if match is not None and "meta" in match.group(0):
                    match_string = match.group(0).replace("<", "")
                    meta_matches.append(match_string)
                else:
                    match_string = match.group(0)
                    js_matches.append(match_string)
            if len(js_matches) > 0:
                self.result["visualforce_statistics"]["javascript_refresh"]["results"].append(
                    {
                        "file_name" : base_name,
                        "flagged"   : len(js_matches) > 0,
                        "matches"   : js_matches   
                    }
                )
                self.javascript_refresh_count += len(js_matches)

            if len(meta_matches) > 0:
                self.result["visualforce_statistics"]["meta_refresh"]["results"].append(
                    {
                        "file_name" : base_name,
                        "flagged"   : len(meta_matches) > 0,
                        "matches"   : meta_matches   
                    }
                )
                self.meta_refresh_count += len(meta_matches)



        for file_name in self.apex_files_to_check:
            parser_result = self.apex_parser_results[file_name]

            base_name = os.path.basename(file_name)
            self.apex_result[base_name] = {}
            file_body = util.get_file_as_string(file_name)

            ### WITHOUT SHARING
            without_sharings = re.finditer(without_sharing_pattern, file_body)          
            matches = []
            for match in without_sharings:
                matches.append(match.group(0))
            if len(matches) > 0:
                self.result["apex_statistics"]["without_sharing"]["results"].append(
                    {
                        "file_name" : base_name,
                        "flagged"   : len(matches) > 0,
                        "matches"   : matches   
                    }
                )
                self.without_sharing_count += len(matches)
            
            #print parser_result
            if "forLoops" not in parser_result:
                #print file_name
                continue

            for_loops       = parser_result["forLoops"]
            dml_statements  = parser_result["dmlStatements"]
            queries         = parser_result["queries"]
            methods         = parser_result["methods"]
            classes         = parser_result["classes"]

            #seealldata
            see_all_data_matches     = []
            for m in methods:
                if "annotations" in m and len(m["annotations"]) > 0:
                    for a in m["annotations"]:
                        if "pairs" in a:
                            for p in a["pairs"]:
                                if p["name"].lower() == "seealldata" and p["value"]["value"] == True:
                                    
                                    line_contents = ""
                                    for lnum in range(p["beginLine"], p["beginLine"]+2):
                                        line_contents += print_file_line(file_name, lnum)
                                    
                                    m["lineNumber"]       = p["beginLine"]
                                    m["line_contents"]    = line_contents

                                    see_all_data_matches.append(m)  

            for c in classes:
                if "annotations" in c and len(c["annotations"]) > 0:
                    for a in c["annotations"]:
                        if "pairs" in a:
                            for p in a["pairs"]:
                                if p["name"].lower() == "seealldata" and p["value"]["value"] == True:
                                    
                                    line_contents = ""
                                    for lnum in range(p["beginLine"], p["beginLine"]+2):
                                        line_contents += print_file_line(file_name, lnum)
                                    
                                    c["lineNumber"]       = p["beginLine"]
                                    c["line_contents"]    = line_contents

                                    see_all_data_matches.append(c)  
                        

            if len(see_all_data_matches) > 0:
                self.result["apex_statistics"]["see_all_data_annotations"]["results"].append(
                    {
                        "file_name" : base_name,
                        "flagged"   : len(see_all_data_matches) > 0,
                        "matches"   : see_all_data_matches
                    }
                )
                self.see_all_data_count += len(see_all_data_matches)

            ### SOQL WITHOUT WHERE CLAUSES
            no_where_clause_matches     = []
            negative_operator_matches   = []
            for query in queries:
                line_number = query["lineNumber"]
                lower_query = query["statement"].lower()
                #if ' where ' not in lower_query:
                if where_pattern.search(lower_query) is None:
                    no_where_clause_matches.append(query)
                #if ' not like ' in lower_query or '!=' in lower_query:
                if not_like_pattern.search(lower_query) is not None or "!=" in lower_query:
                    negative_operator_matches.append(query)

            if len(no_where_clause_matches) > 0:
                self.result["apex_statistics"]["soql_no_where_clause"]["results"].append(
                    {
                        "file_name" : base_name,
                        "flagged"   : len(no_where_clause_matches) > 0,
                        "matches"   : no_where_clause_matches
                    }
                )
                self.no_where_clause_count += len(no_where_clause_matches)

            if len(negative_operator_matches) > 0:
                self.result["apex_statistics"]["soql_negative_operators"]["results"].append(
                    {
                        "file_name" : base_name,
                        "flagged"   : len(negative_operator_matches) > 0,
                        "matches"   : negative_operator_matches
                    }
                )
                self.negative_soql_count += len(negative_operator_matches)


            ### DML INSIDE ITERATORS
            dml_matches     = []
            query_matches   = []

            if len(for_loops) > 0:
                for dml in dml_statements:
                    line_number = dml["statement"]["beginLine"]
                    for loop in for_loops:
                        if loop[0] < line_number < loop[1]:
                            #this is a dml statement inside an iterator
                            line_contents = ""
                            for lnum in range(loop[0], loop[1]+1):
                                line_contents += print_file_line(file_name, lnum)
                            
                            dml["lineNumber"]       = loop[0]
                            dml["line_contents"]    = line_contents
                            dml_matches.append(dml)

                for query in queries:
                    line_number = query["lineNumber"]
                    for loop in for_loops:
                        if loop[0] < line_number < loop[1]:
                            #this is a soql statement inside an iterator
                            query["line_contents"] = print_file_line(file_name, line_number)
                            query_matches.append(query)

                if len(dml_matches) > 0:
                    self.result["apex_statistics"]["dml_for_loop"]["results"].append(
                        {
                            "file_name" : base_name,
                            "flagged"   : len(dml_matches) > 0,
                            "matches"   : dml_matches   
                        }
                    )
                    self.dml_for_loop_count += len(dml_matches)

                if len(query_matches) > 0:
                    self.result["apex_statistics"]["soql_for_loop"]["results"].append(
                        {
                            "file_name" : base_name,
                            "flagged"   : len(query_matches) > 0,
                            "matches"   : query_matches   
                        }
                    )
                    self.soql_for_loop_count += len(query_matches)


        self.result["apex_statistics"]["without_sharing"]["count"]          = self.without_sharing_count
        self.result["apex_statistics"]["dml_for_loop"]["count"]             = self.dml_for_loop_count
        self.result["apex_statistics"]["soql_for_loop"]["count"]            = self.soql_for_loop_count
        self.result["apex_statistics"]["soql_negative_operators"]["count"]  = self.negative_soql_count
        self.result["apex_statistics"]["soql_no_where_clause"]["count"]     = self.no_where_clause_count
        self.result["apex_statistics"]["see_all_data_annotations"]["count"] = self.see_all_data_count

        self.result["visualforce_statistics"]["action_poller"]["count"]      = self.action_poller_count
        self.result["visualforce_statistics"]["javascript_refresh"]["count"] = self.javascript_refresh_count
        self.result["visualforce_statistics"]["meta_refresh"]["count"]       = self.meta_refresh_count
        self.result["visualforce_statistics"]["hardcoded_url"]["count"]      = self.hardcoded_link_count
        return self.result
Example #44
def main():
    # Make directories if they don't already exist
    util.make_directories()
    # Load model options
    model_options = constants.MAIN_MODEL_OPTIONS

    ########## DATA ##########
    if constants.PRINT_MODEL_STATUS: print("Loading data")

    dataset_map = util.load_dataset_map()
    train_captions, val_captions, test_captions = util.load_text_vec(
        'Data', constants.VEC_OUTPUT_FILE_NAME, dataset_map)
    train_image_dict, val_image_dict, test_image_dict = util.get_images(
        'Data', constants.DIRECTORY_PATH, constants.FLOWERS_DICTS_PATH)

    ########## MODEL ##########
    generator = CondBeganGenerator(model_options)
    discriminator = CondBeganDiscriminator(model_options)

    # Put G and D on cuda if GPU available
    if torch.cuda.is_available():
        if constants.PRINT_MODEL_STATUS: print("CUDA is available")
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        if constants.PRINT_MODEL_STATUS: print("Moved models to GPU")

    # Initialize weights
    generator.apply(util.weights_init)
    discriminator.apply(util.weights_init)

    ########## SAVED VARIABLES #########
    new_epoch = 0
    began_k = 0
    train_losses = {"generator": [], "discriminator": [], "converge": []}
    val_losses = {"generator": [], "discriminator": [], "converge": []}
    losses = {'train': train_losses, 'val': val_losses}

    ########## OPTIMIZER ##########
    g_optimizer = optim.Adam(generator.parameters(),
                             lr=constants.LR,
                             betas=constants.BETAS)
    # Changes the optimizer to SGD if declared in constants
    if constants.D_OPTIMIZER_SGD:
        d_optimizer = optim.SGD(discriminator.parameters(), lr=constants.LR)
    else:
        d_optimizer = optim.Adam(discriminator.parameters(),
                                 lr=constants.LR,
                                 betas=constants.BETAS)
    if constants.PRINT_MODEL_STATUS: print("Added optimizers")

    ########## RESUME OPTION ##########
    if args.resume:
        print("Resuming from epoch " + args.resume)
        checkpoint = torch.load(constants.SAVE_PATH + 'weights/epoch' +
                                str(args.resume))
        new_epoch = checkpoint['epoch'] + 1
        generator.load_state_dict(checkpoint['g_dict'])
        discriminator.load_state_dict(checkpoint['d_dict'])
        began_k = checkpoint['began_k']
        g_optimizer.load_state_dict(checkpoint['g_optimizer'])
        d_optimizer.load_state_dict(checkpoint['d_optimizer'])
        losses = torch.load(constants.SAVE_PATH + 'losses')

    ########## VARIABLES ##########
    noise_vec = torch.FloatTensor(constants.BATCH_SIZE, model_options['z_dim'])
    text_vec = torch.FloatTensor(constants.BATCH_SIZE,
                                 model_options['caption_vec_len'])
    real_img = torch.FloatTensor(constants.BATCH_SIZE,
                                 model_options['image_channels'],
                                 constants.IMAGE_SIZE, constants.IMAGE_SIZE)
    real_caption = torch.FloatTensor(constants.BATCH_SIZE,
                                     model_options['caption_vec_len'])
    if constants.USE_CLS:
        wrong_img = torch.FloatTensor(constants.BATCH_SIZE,
                                      model_options['image_channels'],
                                      constants.IMAGE_SIZE,
                                      constants.IMAGE_SIZE)
        wrong_caption = torch.FloatTensor(constants.BATCH_SIZE,
                                          model_options['caption_vec_len'])

    # Add cuda GPU option
    if torch.cuda.is_available():
        noise_vec = noise_vec.cuda()
        text_vec = text_vec.cuda()
        real_img = real_img.cuda()
        real_caption = real_caption.cuda()
        if constants.USE_CLS: wrong_img = wrong_img.cuda()

    ########## Training ##########
    num_iterations = 0
    for epoch in range(new_epoch, constants.NUM_EPOCHS):
        print("Epoch %d" % (epoch))
        st = time.time()

        for i, batch_iter in enumerate(
                util.grouper(train_captions.keys(), constants.BATCH_SIZE)):
            batch_keys = [x for x in batch_iter if x is not None]
            curr_batch_size = len(batch_keys)

            discriminator.train()
            generator.train()
            discriminator.zero_grad()  # Zero out gradient
            # Save computations for gradient calculations
            for p in discriminator.parameters():
                p.requires_grad = True  # Need this to be true to update generator as well

            ########## BATCH DATA #########
            noise_batch = torch.randn(curr_batch_size, model_options['z_dim'])
            text_vec_batch = torch.Tensor(
                util.get_text_description(train_captions, batch_keys))
            real_caption_batch = torch.Tensor(
                util.get_text_description(train_captions, batch_keys))
            real_img_batch = torch.Tensor(
                util.choose_real_image(train_image_dict, batch_keys))
            if constants.USE_CLS:
                wrong_img_batch = torch.Tensor(
                    util.choose_wrong_image(train_image_dict, batch_keys))
            if torch.cuda.is_available():
                noise_batch = noise_batch.cuda()
                text_vec_batch = text_vec_batch.cuda()
                real_caption_batch = real_caption_batch.cuda()
                real_img_batch = real_img_batch.cuda()
                if constants.USE_CLS: wrong_img_batch = wrong_img_batch.cuda()

            # Fill in tensors with batch data
            noise_vec.resize_as_(noise_batch).copy_(noise_batch)
            text_vec.resize_as_(text_vec_batch).copy_(text_vec_batch)
            real_caption.resize_as_(text_vec_batch).copy_(text_vec_batch)
            real_img.resize_as_(real_img_batch).copy_(real_img_batch)
            if constants.USE_CLS:
                wrong_img.resize_as_(wrong_img_batch).copy_(wrong_img_batch)

            ########## RUN THROUGH GAN ##########
            gen_image = generator.forward(Variable(text_vec),
                                          Variable(noise_vec))

            real_img_passed = discriminator.forward(Variable(real_img),
                                                    Variable(real_caption))
            fake_img_passed = discriminator.forward(gen_image.detach(),
                                                    Variable(real_caption))
            if constants.USE_CLS:
                wrong_img_passed = discriminator.forward(
                    Variable(wrong_img), Variable(real_caption))

            ########## TRAIN DISCRIMINATOR ##########
            if constants.USE_REAL_LS:
                # Real loss sensitivity
                # L_D = L(y_r) - k * (L(y_f) + L(y_f, r))
                # L_G = L(y_f) +  L(y_f, r)
                # k = k + lambda_k * (gamma * L(y_r) + L(y_f) +  L(y_f, r))
                d_real_loss = torch.mean(
                    torch.abs(real_img_passed - Variable(real_img)))
                d_fake_loss = torch.mean(torch.abs(fake_img_passed -
                                                   gen_image))
                d_real_sensitivity_loss = torch.mean(
                    torch.abs(fake_img_passed - Variable(real_img)))
                d_loss = d_real_loss - began_k * (
                    0.5 * d_fake_loss + 0.5 * d_real_sensitivity_loss)

                # Update began k value
                balance = (model_options['began_gamma'] * d_real_loss -
                           0.5 * d_fake_loss -
                           0.5 * d_real_sensitivity_loss).data[0]
                began_k = min(
                    max(began_k + model_options['began_lambda_k'] * balance,
                        0), 1)
            elif constants.USE_CLS:
                # Cond BEGAN Discriminator Loss with CLS
                # L(y_w) is the caption loss sensitivity CLS (makes sure that captions match the image)
                # L_D = L(y_r) + L(y_f, w) - k * L(y_f)
                # L_G = L(y_f)
                # k = k + lambda_k * (gamma * (L(y_r) + L(y_f, w)) - L(y_f))
                d_real_loss = torch.mean(
                    torch.abs(real_img_passed - Variable(real_img)))
                d_wrong_loss = torch.mean(
                    torch.abs(fake_img_passed - Variable(wrong_img)))
                d_fake_loss = torch.mean(torch.abs(fake_img_passed -
                                                   gen_image))
                d_loss = 0.5 * d_real_loss + 0.5 * d_wrong_loss - began_k * d_fake_loss

                # Update began k value
                balance = (model_options['began_gamma'] *
                           (0.5 * d_real_loss + 0.5 * d_wrong_loss) -
                           d_fake_loss).data[0]
                began_k = min(
                    max(began_k + model_options['began_lambda_k'] * balance,
                        0), 1)
            else:
                # No CLS option
                # Cond BEGAN Discriminator Loss
                # L_D = L(y_r) - k * L(y_f)
                # L_G = L(y_f)
                # k = k + lambda_k * (gamma * L(y_r) - L(y_f))
                d_real_loss = torch.mean(
                    torch.abs(real_img_passed - Variable(real_img)))
                d_fake_loss = torch.mean(torch.abs(fake_img_passed -
                                                   gen_image))
                d_loss = d_real_loss - began_k * d_fake_loss

                # Update began k value
                balance = (model_options['began_gamma'] * d_real_loss -
                           d_fake_loss).data[0]
                began_k = min(
                    max(began_k + model_options['began_lambda_k'] * balance,
                        0), 1)

            d_loss.backward()
            d_optimizer.step()

            ########## TRAIN GENERATOR ##########
            generator.zero_grad()
            for p in discriminator.parameters():
                p.requires_grad = False
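            # Freeze D's parameters so the generator's backward pass does not
            # accumulate gradients into the discriminator. (Presumably they
            # are re-enabled near the top of the loop before the next D step.)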

            # Optionally regenerate the image with fresh noise before the
            # generator update
            if constants.REGEN_IMAGE:
                noise_batch = torch.randn(curr_batch_size,
                                          model_options['z_dim'])
                if torch.cuda.is_available():
                    noise_batch = noise_batch.cuda()
                noise_vec.resize_as_(noise_batch).copy_(noise_batch)
                gen_image = generator.forward(Variable(text_vec),
                                              Variable(noise_vec))

            new_fake_img_passed = discriminator.forward(
                gen_image, Variable(real_caption))

            # Generator loss: base term L_G = L(y_f), adjusted per option below
            g_loss = torch.mean(torch.abs(new_fake_img_passed - gen_image))
            if constants.USE_REAL_LS:
                g_loss += torch.mean(
                    torch.abs(new_fake_img_passed - Variable(real_img)))
            elif constants.USE_CLS:
                g_loss -= torch.mean(
                    torch.abs(new_fake_img_passed - Variable(wrong_img)))
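            # The option-specific terms mirror the discriminator's sensitivity
            # losses: REAL_LS also pulls the fake reconstruction toward the
            # real image, while CLS pushes it away from the mismatched image.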

            g_loss.backward()
            g_optimizer.step()

            # BEGAN convergence measure: M = L(y_r) + |balance|
            # (equals L(y_r) + |gamma * L(y_r) - L(y_f)| in the plain case)
            convergence_val = d_real_loss + abs(balance)

            # learning rate decay
            g_optimizer = util.adjust_learning_rate(g_optimizer,
                                                    num_iterations)
            d_optimizer = util.adjust_learning_rate(d_optimizer,
                                                    num_iterations)
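            # (Hypothetical sketch of util.adjust_learning_rate, assumed to
            # decay lr on a fixed iteration schedule:
            #   for group in optimizer.param_groups:
            #       group['lr'] = base_lr * decay ** (step // decay_every))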

            if i % constants.LOSS_SAVE_IDX == 0:
                losses['train']['generator'].append((g_loss.data[0], epoch, i))
                losses['train']['discriminator'].append(
                    (d_loss.data[0], epoch, i))
                losses['train']['converge'].append(
                    (convergence_val.data[0], epoch, i))

            num_iterations += 1

        print('Total number of iterations: ', num_iterations)
        print('Training G Loss: ', g_loss.data[0])
        print('Training D Loss: ', d_loss.data[0])
        print('Training Convergence: ', convergence_val.data[0])
        print('K value: ', began_k)
        epoch_time = time.time() - st
        print("Time: ", epoch_time)

        if epoch == constants.REPORT_EPOCH:
            with open(constants.SAVE_PATH + 'report.txt', 'w') as f:
                f.write(constants.EXP_REPORT)
                f.write("Time per epoch: " + str(epoch_time))
            print("Saved report")

        ########## DEV SET ##########
        # Calculate dev set loss.
        # volatile=True below because we run in inference mode (no need to
        # compute gradients).
        generator.eval()
        discriminator.eval()
        for i, batch_iter in enumerate(
                util.grouper(val_captions.keys(), constants.BATCH_SIZE)):
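            # util.grouper pads the final chunk with a fill value
            # (itertools-style None), so drop those before sizing the batch.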
            batch_keys = [x for x in batch_iter if x is not None]
            curr_batch_size = len(batch_keys)

            # Gather batch data
            noise_batch = torch.randn(curr_batch_size, model_options['z_dim'])
            text_vec_batch = torch.Tensor(
                util.get_text_description(val_captions, batch_keys))
            real_caption_batch = torch.Tensor(
                util.get_text_description(val_captions, batch_keys))
            real_img_batch = torch.Tensor(
                util.choose_real_image(val_image_dict, batch_keys))
            if constants.USE_CLS:
                wrong_img_batch = torch.Tensor(
                    util.choose_wrong_image(val_image_dict, batch_keys))
            if torch.cuda.is_available():
                noise_batch = noise_batch.cuda()
                text_vec_batch = text_vec_batch.cuda()
                real_caption_batch = real_caption_batch.cuda()
                real_img_batch = real_img_batch.cuda()
                if constants.USE_CLS:
                    wrong_img_batch = wrong_img_batch.cuda()

            # Fill in tensors with batch data
            noise_vec.resize_as_(noise_batch).copy_(noise_batch)
            text_vec.resize_as_(text_vec_batch).copy_(text_vec_batch)
            real_caption.resize_as_(real_caption_batch).copy_(real_caption_batch)
            real_img.resize_as_(real_img_batch).copy_(real_img_batch)
            if constants.USE_CLS:
                wrong_img.resize_as_(wrong_img_batch).copy_(wrong_img_batch)

            # Run through generator
            gen_image = generator.forward(
                Variable(text_vec, volatile=True),
                Variable(noise_vec, volatile=True))  # Variable holding the image

            # Run through discriminator
            real_img_passed = discriminator.forward(
                Variable(real_img, volatile=True),
                Variable(real_caption, volatile=True))
            fake_img_passed = discriminator.forward(
                gen_image.detach(), Variable(real_caption, volatile=True))
            if constants.USE_CLS:
                wrong_img_passed = discriminator.forward(
                    Variable(wrong_img, volatile=True),
                    Variable(real_caption, volatile=True))

            # Calculate D loss
            if constants.USE_REAL_LS:
                d_real_loss = torch.mean(
                    torch.abs(real_img_passed - Variable(real_img)))
                d_fake_loss = torch.mean(torch.abs(fake_img_passed -
                                                   gen_image))
                d_real_sensitivity_loss = torch.mean(
                    torch.abs(fake_img_passed - Variable(real_img)))
                d_loss = d_real_loss - began_k * (
                    0.5 * d_fake_loss + 0.5 * d_real_sensitivity_loss)

                balance = (model_options['began_gamma'] * d_real_loss -
                           0.5 * d_fake_loss -
                           0.5 * d_real_sensitivity_loss).data[0]
            elif constants.USE_CLS:
                d_real_loss = torch.mean(
                    torch.abs(real_img_passed - Variable(real_img)))
                d_wrong_loss = torch.mean(
                    torch.abs(fake_img_passed - Variable(wrong_img)))
                d_fake_loss = torch.mean(torch.abs(fake_img_passed -
                                                   gen_image))
                d_loss = 0.5 * d_real_loss + 0.5 * d_wrong_loss - began_k * d_fake_loss

                balance = (model_options['began_gamma'] *
                           (0.5 * d_real_loss + 0.5 * d_wrong_loss) -
                           d_fake_loss).data[0]
            else:
                # No CLS option
                d_real_loss = torch.mean(
                    torch.abs(real_img_passed - Variable(real_img)))
                d_fake_loss = torch.mean(torch.abs(fake_img_passed -
                                                   gen_image))
                d_loss = d_real_loss - began_k * d_fake_loss

                # Compute balance for the convergence measure (k itself is
                # not updated on the dev set)
                balance = (model_options['began_gamma'] * d_real_loss -
                           d_fake_loss).data[0]

            # Calculate G loss
            if constants.USE_REAL_LS:
                g_loss = 0.5 * torch.mean(
                    torch.abs(fake_img_passed - gen_image))
                g_loss += 0.5 * torch.mean(
                    torch.abs(fake_img_passed - Variable(real_img)))
            elif constants.USE_CLS:
                g_loss = torch.mean(torch.abs(fake_img_passed - gen_image))
                g_loss -= 0.5 * torch.mean(
                    torch.abs(fake_img_passed - Variable(wrong_img)))
            else:
                # L_G = L(y_f)
                g_loss = torch.mean(torch.abs(fake_img_passed - gen_image))

            # BEGAN convergence measure: M = L(y_r) + |balance|
            convergence_val = d_real_loss + abs(balance)

            if i % constants.LOSS_SAVE_IDX == 0:
                losses['val']['generator'].append((g_loss.data[0], epoch, i))
                losses['val']['discriminator'].append(
                    (d_loss.data[0], epoch, i))
                losses['val']['converge'].append(
                    (convergence_val.data[0], epoch, i))

        print('Val G Loss: ', g_loss.data[0])
        print('Val D Loss: ', d_loss.data[0])
        print('Val Convergence: ', convergence_val.data[0])
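        # Assumed: restore training mode before the next epoch's updates
        # (harmless if the top of the epoch loop already calls .train()).
        generator.train()
        discriminator.train()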

        # Save losses
        torch.save(losses, constants.SAVE_PATH + 'losses')

        # Save two generated samples and their discriminator reconstructions
        vutils.save_image(gen_image[0].data.cpu(),
                          constants.SAVE_PATH + 'images/gen0_epoch' +
                          str(epoch) + '.png',
                          normalize=True)
        vutils.save_image(gen_image[1].data.cpu(),
                          constants.SAVE_PATH + 'images/gen1_epoch' +
                          str(epoch) + '.png',
                          normalize=True)
        vutils.save_image(fake_img_passed[0].data.cpu(),
                          constants.SAVE_PATH + 'images/gen_recon0_epoch' +
                          str(epoch) + '.png',
                          normalize=True)
        vutils.save_image(fake_img_passed[1].data.cpu(),
                          constants.SAVE_PATH + 'images/gen_recon1_epoch' +
                          str(epoch) + '.png',
                          normalize=True)
        # vutils.save_image(real_img_passed[0].data.cpu(),
        #             constants.SAVE_PATH + 'images/real_recon0_epoch' + str(epoch) + '.png',
        #             normalize=True)
        # vutils.save_image(real_img_passed[1].data.cpu(),
        #             constants.SAVE_PATH + 'images/real_recon1_epoch' + str(epoch) + '.png',
        #             normalize=True)

        # Save model
        if ((epoch % constants.CHECKPOINT_FREQUENCY == 0 and epoch != 0)
                or epoch == constants.NUM_EPOCHS - 1):
            save_checkpoint = {
                'epoch': epoch,
                'g_dict': generator.state_dict(),
                'd_dict': discriminator.state_dict(),
                'g_optimizer': g_optimizer.state_dict(),
                'd_optimizer': d_optimizer.state_dict(),
                'began_k': began_k
            }
            torch.save(save_checkpoint,
                       constants.SAVE_PATH + 'weights/epoch' + str(epoch))
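            # A minimal resume sketch (hypothetical; mirrors the keys above):
            #   ckpt = torch.load(constants.SAVE_PATH + 'weights/epoch' + str(epoch))
            #   generator.load_state_dict(ckpt['g_dict'])
            #   discriminator.load_state_dict(ckpt['d_dict'])
            #   g_optimizer.load_state_dict(ckpt['g_optimizer'])
            #   d_optimizer.load_state_dict(ckpt['d_optimizer'])
            #   began_k = ckpt['began_k']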