Example #1
def convert_data(src_collection: Type[Documents.FileRaw], dst_collection: Type[Documents.File]):
    """
    Convert entries of :class:`model_generation.data.documents.Documents.FileRaw` to :class:`model_generation.data.documents.Documents.File`
    Each record of :class:`model_generation.data.documents.Documents.File` or its subclasses has unique :attribute: path and
    :attribute licenses of :type list containing all the license findings found for given path in
    the original collection of :class:`model_generation.data.documents.Documents.FileRaw`
    """
    if src_collection == dst_collection:
        raise DocumentConversionError('Source and destination collections must be different')

    with db.connect():
        logger.info(f'Connected to: {db.get_connection_info()}')
        logger.info(f'Converting: {src_collection.__name__} ({db.get_collection_size(src_collection)})'
                    f' -> {dst_collection.__name__} ({db.get_collection_size(dst_collection)})')

    docs = src_collection.objects()
    total_count = docs.count()
    for current_count, src_doc in enumerate(docs):
        log_progress(current_count, total_count)

        try:
            mapped_doc = map_document(src_doc)
        except (DocumentConversionError, DocumentConstructionError) as e:
            logger.warning(f'Skipping: {src_doc} because of: {e}')
            continue

        mapped_doc.create_or_update()

    with db.connect():
        logger.info(f'Total {dst_collection.__name__} count: ({db.get_collection_size(dst_collection)})')
        logger.info(f'Documents.Conclusion count: ({db.get_collection_size(Documents.Conclusion)})')
        logger.info(f'Documents.License count: ({db.get_collection_size(Documents.License)})')
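
Note: the examples on this page call several different helpers that all share the name log_progress. Example 1 above (and Examples 2, 3, 5, 13 and 16 below) uses a counter-style call, log_progress(current, total, ...). The helper itself is not part of the listing; the following is a minimal sketch of what such a function could look like, assuming the standard logging module and treating the prefix and num_log_outputs parameters seen in Examples 2 and 16 as part of its signature:

import logging

logger = logging.getLogger(__name__)

def log_progress(current, total, prefix='', num_log_outputs=100):
    """Emit roughly num_log_outputs progress lines over a run of total steps."""
    if total <= 0:
        return
    step = max(total // num_log_outputs, 1)  # log every step-th call
    if current % step == 0 or current + 1 == total:
        logger.info('%sProgress: %d/%d (%.1f%%)', prefix, current, total,
                    100.0 * current / total)

With num_log_outputs=collection_size, as in Examples 2 and 16, step becomes 1 and every processed document is logged.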
Example #2
def update(update_function, n_cores, batch_size, collection, **query):
    collection_size = db.get_collection_size(collection, **query)
    batch_size = min(batch_size, collection_size // n_cores)
    if batch_size:
        skips = range(0, collection_size, batch_size)
        db.reset_connection_references(collection)
        logger.info(
            'Running updates on collection: %s (%s) - (n_cores: %d, batch_size: %d, collection_size: %d)',
            collection.__name__, query, n_cores, batch_size, collection_size)

        with ProcessPoolExecutor(max_workers=n_cores) as executor:
            future_to_chunk = [
                executor.submit(process_cursor, update_function, skip_n,
                                batch_size, collection, **query)
                for skip_n in skips
            ]

            current = 0
            for future in as_completed(future_to_chunk):
                current += future.result()[0]
                logger.debug('Thread: %s %s', future, future.result())
                log_progress(current,
                             collection_size,
                             '',
                             num_log_outputs=collection_size)
    else:
        logger.info(
            f'Query {query} for collection {collection} found {collection_size} documents'
            f'\nRunning sequentially ...')
        process_cursor(update_function, 0, collection_size, collection,
                       **query)
Example #3
    def analyze(self):
        docs = self.collection.objects(**self.query)
        doc_count = docs.count()
        logger.info('Running Collection analysis for %s, query: %s (%d)',
                    self.collection_name, self.query, doc_count)

        for doc_no, doc in enumerate(docs):
            log_progress(doc_no, doc_count)
            self.analyze_document(doc)
Example #4
    def retrieve_updated(self):
        """Build list of issues/PRs from the given repositories.

        If this is the first update, all of the open issues are
        retrieved. On subsequent updates, only issues (open and
        closed) that were updated since the last run are processed.

        Returns:
            dict:
                Issues index in format:
                {issue.html_url: github.Issue.Issue}
        """
        updated_issues = {}

        for repo_name in self._repo_names:
            repo = self._repos.setdefault(
                repo_name, self._gh_client.get_repo(repo_name)
            )
            self.prs_index.index_closed_prs(repo, self._repo_names)

            is_first_update = self._is_first_update(repo_name)

            logging.info("{repo}: processing issues".format(repo=repo.full_name))
            issues = repo.get_issues(**self._build_filter(repo_name))

            for ind, issue in enumerate(issues):
                # "since" filter returns the issue, which was
                # the last updated in previous filling - skip it
                if (
                    issue.updated_at == self._last_issue_updates[repo_name][0]
                    and issue.html_url == self._last_issue_updates[repo_name][1]
                ):
                    continue

                self._process_issue(issue, updated_issues)

                if issue.updated_at > self._last_issue_updates[repo_name][0]:
                    self._last_issue_updates[repo_name] = (
                        issue.updated_at,
                        issue.html_url,
                    )

                log_progress(is_first_update, issues.totalCount, ind, "issues")

            logging.info("{repo}: issues processed".format(repo=repo.full_name))

        save_update_stamps(
            "last_issue_updates", self._sheet_name, self._last_issue_updates
        )
        self.prs_index.save_updates()

        self._issues_index.update(updated_issues)
        return updated_issues
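
Examples 4 and 10 come from the same GitHub-to-spreadsheet project and use a different variant, log_progress(is_first_update, total, index, object_type). Only the call signature is visible here; a plausible sketch, assuming it reports a percentage on the first full indexing run (when totalCount reflects the whole workload) and a bare count on incremental runs, with a hypothetical reporting interval of 20 items:

import logging

def log_progress(is_first_update, total, index, object_type):
    """Periodically report progress while paging through GitHub results."""
    if index == 0 or index % 20:
        return
    if is_first_update and total:
        logging.info('processed %d of %d %s (%.0f%%)',
                     index, total, object_type, 100.0 * index / total)
    else:
        logging.info('processed %d %s', index, object_type)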
Example #5
def process_cursor_remove_tag(tag: str, _collection):
    logger.info('Started removing tag %s from %s', tag, _collection)
    with db.connect():
        query = {'tags': tag}
        docs = _collection.objects(**query).timeout(False).only('tags')
        docs_count = docs.count(with_limit_and_skip=True)
        try:
            for doc_no, doc in enumerate(docs):
                log_progress(doc_no, docs_count)
                doc.modify(pull__tags=tag)

        finally:
            docs._cursor.close()
            logger.info('Completed removing tag %s from %s', tag, _collection)
Example #6
File: sgan.py Project: fayiz7/sgan
    def generate_images(self, zs, truncation_psi, class_idx=None):
        Gs_kwargs = dnnlib.EasyDict()
        Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8,
                                          nchw_to_nhwc=True)
        Gs_kwargs.randomize_noise = False
        if not isinstance(truncation_psi, list):
            truncation_psi = [truncation_psi] * len(zs)

        imgs = []
        label = np.zeros([1] + self.Gs.input_shapes[1][1:])
        if class_idx is not None:
            label[:, class_idx] = 1
        else:
            label = None
        for z_idx, z in log_progress(enumerate(zs),
                                     size=len(zs),
                                     name="Generating images"):
            Gs_kwargs.truncation_psi = truncation_psi[z_idx]
            noise_rnd = np.random.RandomState(1)  # fix noise
            tflib.set_vars({
                var: noise_rnd.randn(*var.shape.as_list())
                for var in self.noise_vars
            })  # [height, width]
            images = self.Gs.run(
                z, label, **Gs_kwargs)  # [minibatch, height, width, channel]
            imgs.append(PIL.Image.fromarray(images[0], 'RGB'))
        return imgs
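
Examples 6, 11, 12, 15, 19 and 20 (and heard() in Example 8) use an iterator-wrapping variant instead: log_progress(sequence, every=..., size=..., name=...) yields the items of the wrapped sequence while reporting progress, which is why callers pass size= explicitly when the sequence is a bare enumerate(...). The signature matches the widely circulated Jupyter progress-widget recipe of the same name; a logging-based sketch (only the parameter names are taken from the call sites, the body is assumed):

import logging

def log_progress(sequence, every=None, size=None, name='Items'):
    """Yield items from sequence, logging '<name>: i/size' along the way."""
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            pass  # unsized iterator, e.g. a bare enumerate(...)
    if every is None:
        every = max(size // 20, 1) if size else 100
    for index, item in enumerate(sequence, start=1):
        if index % every == 0 or (size is not None and index == size):
            logging.info('%s: %d/%s', name, index, size if size else '?')
        yield item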
Example #7
 def predict(self, bugs):
     bugs_count = bugs.shape[0]
      self.logger.info('Start bugs prediction, bugs count: {}'.format(bugs_count))
     iteration_num = 0
     last_log_percentage = 0
     errors_count = 0
     avg_error_proba = 0
     time_start = datetime.now()
     result = bugs.copy()
     result[config.STORAGE_COLUMN_SUGGESTED_COMPONENT] = ''
     result[config.STORAGE_COLUMN_CONFIDENCE] = 0.0
     for idx, bug in bugs.iterrows():
         suggested_component, confidence = self.predict_bug(bug=bug)
         result.loc[idx, config.STORAGE_COLUMN_SUGGESTED_COMPONENT] = suggested_component
         result.loc[idx, config.STORAGE_COLUMN_CONFIDENCE] = confidence
         if suggested_component != bug[config.COMPONENT]:
             errors_count += 1
             avg_error_proba += confidence
         iteration_num += 1
         last_log_percentage = utils.log_progress(self.logger, time_start, iteration_num, bugs_count,
                                                  last_log_percentage, delta_percentage=10.0)
      self.logger.info('Accuracy: {0:.3f}, avg. error probability: {1:.3f}'.format(
          1.0 - (errors_count / bugs_count),
          avg_error_proba / errors_count if errors_count else 0.0))
     self.logger.info('Finish bugs prediction')
     return result
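
Example 7 uses yet another shape: utils.log_progress(logger, time_start, iteration, total, last_log_percentage, delta_percentage) returns the percentage at which it last produced output, and the caller feeds that value back in on the next iteration. A sketch consistent with that contract (the body is an assumption):

from datetime import datetime

def log_progress(logger, time_start, iteration, total,
                 last_log_percentage, delta_percentage=10.0):
    """Log whenever progress has advanced by delta_percentage points."""
    percentage = 100.0 * iteration / total
    if percentage - last_log_percentage < delta_percentage and iteration != total:
        return last_log_percentage
    elapsed = (datetime.now() - time_start).total_seconds()
    logger.info('Processed {}/{} ({:.1f}%) in {:.1f} s'.format(
        iteration, total, percentage, elapsed))
    return percentage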
Example #8
def heard():
    files = glob.glob('train/*')

    cat_files = [fn for fn in files if 'cat' in fn]
    dog_files = [fn for fn in files if 'dog' in fn]
    print(len(cat_files), len(dog_files))

    cat_train = np.random.choice(cat_files, size=1500, replace=False)
    dog_train = np.random.choice(dog_files, size=1500, replace=False)
    cat_files = list(set(cat_files) - set(cat_train))
    dog_files = list(set(dog_files) - set(dog_train))

    cat_val = np.random.choice(cat_files, size=500, replace=False)
    dog_val = np.random.choice(dog_files, size=500, replace=False)
    cat_files = list(set(cat_files) - set(cat_val))
    dog_files = list(set(dog_files) - set(dog_val))

    cat_test = np.random.choice(cat_files, size=500, replace=False)
    dog_test = np.random.choice(dog_files, size=500, replace=False)

    print('Cat datasets:', cat_train.shape, cat_val.shape, cat_test.shape)
    print('Dog datasets:', dog_train.shape, dog_val.shape, dog_test.shape)

    train_dir = 'training_data'
    val_dir = 'validation_data'
    test_dir = 'test_data'

    train_files = np.concatenate([cat_train, dog_train])
    validate_files = np.concatenate([cat_val, dog_val])
    test_files = np.concatenate([cat_test, dog_test])

    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    for fn in log_progress(train_files, name='Training Images'):
        shutil.copy(fn, train_dir)

    for fn in log_progress(validate_files, name='Validation Images'):
        shutil.copy(fn, val_dir)

    for fn in log_progress(test_files, name='Test Images'):
        shutil.copy(fn, test_dir)
Example #9
File: rasm.py Project: wajdiz/rasm
    def generate_animation(self, size=9, steps=10, trunc_psi=0.5):
      seeds = list(np.random.randint((2**32) - 1, size=size))
      seeds = seeds + [seeds[0]]  # repeat the first seed so the animation loops
      zs = self.generate_zs_from_seeds(seeds)

      imgs = self.generate_images(self.interpolate(zs, steps=steps), trunc_psi)
      movie_name = 'animation.mp4'
      with imageio.get_writer(movie_name, mode='I') as writer:
        for image in log_progress(list(imgs), name="Creating animation"):
            writer.append_data(np.array(image))
      return show_animation(movie_name)
Example #10
    def index_closed_prs(self, repo, repo_names):
        """Add closed pull requests into index.

        The method remembers each PR's last update time and does not
        re-index PRs that have not been updated since the last
        spreadsheet update.

        Args:
            repo (github.Repository.Repository): Repository object.
            repo_names (tuple): All tracked on this sheet repos names.
        """
        pulls = repo.get_pulls(state="closed",
                               sort="updated",
                               direction="desc")

        if pulls.totalCount:
            is_first_update = False
            logging.info(
                "{repo}: indexing pull requests".format(repo=repo.full_name))

            for index, pull in enumerate(pulls):
                if repo.full_name not in self._last_pr_updates.keys():
                    self._last_pr_updates[repo.full_name] = datetime.datetime(
                        1, 1, 1)
                    is_first_update = True

                if pull.updated_at < self._last_pr_updates[repo.full_name]:
                    break

                for key_phrase in try_match_keywords(pull.body, repo_names):
                    self.add(repo.html_url, pull, key_phrase)

                log_progress(is_first_update, pulls.totalCount, index,
                             "pull requests")

            self._last_pr_updates[repo.full_name] = pulls[0].updated_at
            logging.info("{repo}: all pull requests indexed".format(
                repo=repo.full_name))
Example #11
File: rasm.py Project: wajdiz/rasm
 def generate_from_zs(self, zs, truncation_psi=0.5):
     Gs_kwargs = dnnlib.EasyDict()
     Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
     Gs_kwargs.randomize_noise = False
     if not isinstance(truncation_psi, list):
         truncation_psi = [truncation_psi] * len(zs)

     for z_idx, z in log_progress(enumerate(zs), size=len(zs), name="Generating images"):
         Gs_kwargs.truncation_psi = truncation_psi[z_idx]
         noise_rnd = np.random.RandomState(1) # fix noise
         tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in self.noise_vars}) # [height, width]
         images = self.Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]
         img = PIL.Image.fromarray(images[0], 'RGB')
         imshow(img)
Example #12
File: rasm.py Project: wajdiz/rasm
    def generate_images_in_w_space(self, dlatents, truncation_psi):
        Gs_kwargs = dnnlib.EasyDict()
        Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
        Gs_kwargs.randomize_noise = False
        Gs_kwargs.truncation_psi = truncation_psi
        # dlatent_avg = self.Gs.get_var('dlatent_avg') # [component]

        imgs = []
        for _, dlatent in log_progress(enumerate(dlatents), name="Generating images"):
            #row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(truncation_psi, [-1, 1, 1]) + dlatent_avg
            # dl = (dlatent-dlatent_avg)*truncation_psi   + dlatent_avg
            row_images = self.Gs.components.synthesis.run(dlatent, **Gs_kwargs)
            imgs.append(PIL.Image.fromarray(row_images[0], 'RGB'))
        return imgs
Example #13
def add_tag(collection_analysis, tag, min_samples, factor):
    """
    :param collection_analysis: An object aggregating the distribution of license occurrences within the database
    :param tag: Tag to distribute
    :param min_samples: Minimum number of occurrences of a license in the whole database for it to be worth tagging
    :param factor: Multiplication factor used to compute, from the original distribution, how many tags each license receives
    """
    license_counts = Counter(
        collection_analysis.mapped_single_license) + Counter(
            collection_analysis.mapped_multiple_license)
    logger.info(
        f'Tagging with {tag} (factor {factor}) on {len(license_counts)} '
        f'unique licenses, min_samples={min_samples}')

    logger.info('Removing old tags...')
    with db.connect():
        Documents.Conclusion.objects().update(pull__tags=tag)
        db.reset_connection_references(Documents.Conclusion)

    with ProcessPoolExecutor(max_workers=n_cores) as executor:
        futures = [
            executor.submit(add_tag_for_max_samples, tag, license, count,
                            factor)
            for license, count in license_counts.items()
            if count >= min_samples
        ]

        current = 0
        for future in as_completed(futures):
            current += 1
            log_progress(current, len(license_counts))

    with db.connect():
        tagged_docs = Documents.Conclusion.objects(tags=tag)
        db.reset_connection_references(Documents.Conclusion)

    logger.info(f'\nTagged {tagged_docs.count()} documents with {tag}...')
Example #14
def aggregate_counters(vocabulary_txt, source_bytes, counters):
    overall_counter = Counter()
    progress_indicator = log_progress(total=source_bytes, format='bytes')
    while True:
        counter_and_read_bytes = counters.get()
        if counter_and_read_bytes == STOP_TOKEN:
            with open(vocabulary_txt, 'w') as vocabulary_file:
                vocabulary_file.write('\n'.join(str(word) for word, count in overall_counter.most_common(ARGS.vocabulary_size)))
            progress_indicator.end()
            return
        counter, read_bytes = counter_and_read_bytes
        overall_counter += counter
        progress_indicator.increment(value_difference=read_bytes)
        if len(overall_counter.keys()) > ARGS.keep_factor * ARGS.vocabulary_size:
            overall_counter = Counter(overall_counter.most_common(ARGS.vocabulary_size))
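
Example 14 treats log_progress as a stateful indicator object: constructed with total= and format=, advanced with increment(value_difference=...), and closed with end(). A small class satisfying that three-call interface; everything beyond those calls (the time-based throttling, the byte formatting) is an assumption:

import logging
import time

class log_progress:
    """Counter-style progress indicator with increment()/end()."""

    def __init__(self, total=None, format='items', log_interval=10.0):
        self.total = total
        self.format = format
        self.log_interval = log_interval
        self.value = 0
        self.started = self.last_logged = time.time()

    def _fmt(self, value):
        if self.format != 'bytes':
            return str(value)
        for unit in ('B', 'KiB', 'MiB', 'GiB'):  # human-readable byte counts
            if value < 1024 or unit == 'GiB':
                return '%.1f %s' % (value, unit)
            value /= 1024.0

    def increment(self, value_difference=1):
        self.value += value_difference
        now = time.time()
        if now - self.last_logged < self.log_interval:
            return
        self.last_logged = now
        if self.total:
            logging.info('%s / %s (%.1f%%)', self._fmt(self.value),
                         self._fmt(self.total), 100.0 * self.value / self.total)
        else:
            logging.info('%s processed', self._fmt(self.value))

    def end(self):
        logging.info('done: %s in %.1f s', self._fmt(self.value),
                     time.time() - self.started)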
Example #15
 def get_iou_score_v0(self, mode='simple'):
     iou_scores = []
     for i in log_progress(range(len(self.df)), name='Samples to Test'):
         self.selected_row = self.df.iloc[[i]]
         self.load_data()
         self.predict()
         iou_scores.append(
             self.__iou_score(
                 self.prediction.reshape(1, *self.input_shape[:2], 1),
                 self.msk.reshape(1, *self.input_shape[:2], 1)))
     if mode == 'simple':
         return min(iou_scores), max(
             iou_scores), sum(iou_scores) / len(iou_scores)
     elif mode == 'raw':
         return iou_scores
Example #16
def analyze_in_parallel(db, n_cores, batch_size, collection, **query):
    collection_size = db.get_collection_size(collection, **query)
    batch_size = min(batch_size, collection_size // n_cores)

    final_analysis = CollectionAnalysis(collection)
    if batch_size:
        skips = range(0, collection_size, batch_size)
        db.reset_connection_references(collection)
        logger.info(
            'Running analysis on collection: %s (%s) - (n_cores: %d, batch_size: %d, collection_size: %d)',
            collection.__name__, query, n_cores, batch_size, collection_size)

        with ProcessPoolExecutor(max_workers=n_cores) as executor:
            future_to_chunk = [
                executor.submit(analyze_batch, db, skip_n, batch_size,
                                collection, **query) for skip_n in skips
            ]

            current = 0
            for future in as_completed(future_to_chunk):
                ca = future.result()[0]
                current += ca.doc_count
                final_analysis.append(ca)

                logger.debug('Thread: %s %s', future, future.result())
                log_progress(current,
                             collection_size,
                             '',
                             num_log_outputs=collection_size)
    else:
        logger.info(
            f'Query {query} for collection {collection} found {collection_size} documents'
            f'\nRunning sequentially ...')
        final_analysis.append(
            analyze_batch(db, 0, collection_size, collection, **query)[0])

    return final_analysis
Example #17
 def get_seg_eval_metrics(self,
                          prediction_threshold=0.7,
                          dc_threshold=0.7,
                          print_output=False):
     DCs = []
     TPs = []
     FPs = []
     FNs = []
     names = []
     for i in log_progress(range(len(self.df)), name="Samples to Test"):
         self.selected_row = self.df.iloc[[i]]
         self.load_data()
         self.predict()
         pred = self.prediction > prediction_threshold
         msk = self.msk > 0.5
         DC = self.__dice_score(msk, pred)
         pred = pred.flatten()
         msk = msk.flatten()
          TP = np.sum(pred == msk) / len(msk)  # pixel-wise agreement rate
          FP = np.sum(pred & ~msk) / len(msk)  # predicted positive, GT negative
          FN = 0 if DC > dc_threshold else 1
         #for gt,p in zip(msk,pred):
         #    if p == 0 and gt == 1:
         #        FN += 1
         #FN /= len(msk)
         name = self.df.iloc[[i]].name
         DCs.append(DC)
         TPs.append(TP)
         FPs.append(FP)
         FNs.append(FN)
         names.append(name)
         if print_output:
             print(
                 str(DC) + " | " + str(TP) + " | " + str(FP) + " | " +
                 str(FN) + " | " + str(name))
     return DCs, TPs, FPs, FNs, names
Example #18
    def get_dice_coeff_score(self, mode='simple'):
        assert mode in ('simple', 'raw'), 'Mode must be "simple" or "raw"'
        dice_coeffs = []
        prediction_times = []  # Just for stats
        for i in log_progress(range(len(self.df)), name='Samples to Test'):
            self.selected_row = self.df.iloc[[i]]
            self.load_data()
            t = time.time()
            self.predict()
            prediction_times.append(time.time() - t)
            dice_coeffs.append(
                self.__dice_score(
                    self.prediction.reshape(1, *self.input_shape[:2], 1),
                    self.msk.reshape(1, *self.input_shape[:2], 1)))

        print("Average prediction time: %.2f s" %
              (sum(prediction_times) / len(prediction_times)))
        if mode == 'simple':
            return min(dice_coeffs), max(
                dice_coeffs), sum(dice_coeffs) / len(dice_coeffs)
        elif mode == 'raw':
            return dice_coeffs
Example #19
    def show_matrix(self, index, mode, rows=4):
        """
        Show a rows x 2 Matrix of images

        :param List of int or str: List of indexes to show, or "random"
        :param str "mode": 
                image : shows only image
                mask : shows only mask
                image_mask : shows image with overlayed mask
                image_prediction : shows image with overlayed prediction
                image_prediction_roots : shows image with GT mask and predicted roots
                image_prediction_contour : shows image with predicted segmentation and GT contours
        :param int "row": how much rows should be displayd
        :return: No return Value
        """
        self.mode = mode
        # Create empty header:
        selected_rows = pd.DataFrame().reindex_like(self.df).head(0)
        if index == 'random':
            n = rows * 2
            selected_rows = selected_rows.append(self.df.sample(n))
        else:
            n = len(index)
            rows = int(n / 2)
            if n <= 2:
                raise ValueError('Index length must be greater than 2')
            if n % 2 != 0:
                raise ValueError('Index length must be even')
            for i in index:
                selected_rows = selected_rows.append(
                    self.df[self.df['name'].str.contains(str(i))],
                    ignore_index=True)

        _, ax = plt.subplots(int(n / 2), 2, figsize=(15, 3 * n))

        for i in log_progress(range(rows), every=1, name='Rows'):
            rows = selected_rows[2 * i:2 * i + 2]
            self.__make_image_row(rows, ax[i])
        plt.subplots_adjust(wspace=0.01, hspace=0)
Example #20
 def log_root_precision_values(self, tolerance=30, print_output=False):
     tPs = []
     fPs = []
     fNs = []
     precicions = []
     recalls = []
     test_log = []
     for roots_per_image in range(1, 7):
         if print_output:
             print("Max roots_per_image: " + str(roots_per_image))
             print("TP\tFP\tFN\tPrecicion\tRecall\tImageName")
         for i, row in log_progress(self.df.iterrows(),
                                    every=1,
                                    size=len(self.df),
                                    name=str(roots_per_image)):
             if len(row["roots"]) <= roots_per_image:
                 #print(len(row["roots"]))
                 tP, fP, fN, precicion, recall = self.get_root_precicion(
                     row["name"],
                     tolerance=tolerance,
                     print_distance_matrix=False)
                 if print_output:
                     print("{}\t{}\t{}\t{:1.2f}\t{:1.2f}\t{}".format(
                         tP, fP, fN, precicion, recall, row['name']))
                 tPs.append(tP)
                 fPs.append(fP)
                 fNs.append(fN)
                 precicions.append(precicion)
                 recalls.append(recall)
         test_log.append((tPs, fPs, fNs, precicions, recalls))
         tPs = []
         fPs = []
         fNs = []
         precicions = []
         recalls = []
     return test_log
Example #21
 def run(self, steps, stages=("fw", "train")):
     ret = []
     for stage, target in [("fw", self._loss), ("train", self._train)]:
         if stage not in stages:
             continue
         result = ExperimentResult(
             workers=self._workers,
             base_model=self._model,
             batch_size=self._batch_size,
             ordering_algorithm=self._ordering_algorithm,
             stage=stage,
             steps=steps)
         with tf.train.MonitoredTrainingSession(
                 master=self._master) as sess:
             # Warm up run
             sess.run(target)
             for _ in log_progress(range(steps)):
                 with Timer() as timer:
                     with Timeline() as timeline:
                         sess.run(target, **timeline.kwargs())
                 result.times.append(timer.elapsed())
                 result.metadata.append(timeline.run_metadata)
         ret.append(result)
     return ret
Example #22
def train(model: Hidden, device: torch.device,
          hidden_config: HiDDenConfiguration, train_options: TrainingOptions,
          this_run_folder: str, tb_logger):
    """
    Trains the HiDDeN model
    :param model: The model
    :param device: torch.device object, usually the GPU if available, otherwise the CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :return:
    """

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)
    file_count = len(train_data.dataset)
    if file_count % train_options.batch_size == 0:
        steps_in_epoch = file_count // train_options.batch_size
    else:
        steps_in_epoch = file_count // train_options.batch_size + 1

    print_each = 10
    images_to_save = 8
    saved_images_size = (512, 512)

    for epoch in range(train_options.start_epoch,
                       train_options.number_of_epochs + 1):
        logging.info('\nStarting epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        logging.info('Batch size = {}\nSteps in epoch = {}'.format(
            train_options.batch_size, steps_in_epoch))
        training_losses = defaultdict(AverageMeter)
        epoch_start = time.time()
        step = 1
        #train
        for image, _ in train_data:
            image = image.to(device)
            """
            message = torch.Tensor(np.random.choice([0, 1], (image.shape[0], hidden_config.message_length))).to(device)
            losses, _ = model.train_on_batch([image, message])
            print(losses)
            """
            #crop imgs
            imgs = cropImg(32, image)
            #iterate img
            bitwise_arr = []
            main_losses = None
            encoded_imgs = []
            for img in imgs:
                img = img.to(device)
                message = torch.Tensor(
                    np.random.choice(
                        [0, 1], (img.shape[0],
                                 hidden_config.message_length))).to(device)
                losses, (encoded_images, noised_images,
                         decoded_messages) = model.train_on_batch(
                             [img, message])
                encoded_imgs.append(
                    encoded_images[0][0].cpu().detach().numpy())
                main_losses = losses
                for name, loss in losses.items():
                    if (name == 'bitwise-error  '):
                        bitwise_arr.append(loss)
            # blocking-effect statistics, assuming a 4x4 grid of 32x32 blocks
            # (cf. the %4 row stride and the /12 pair-count normalisation):
            # compare the 10-pixel bands on both sides of each block seam
            Total = 0
            Vcount = 0
            V_average = 0
            H_average = 0
            for i in range(0, len(encoded_imgs) - 1):
                if ((i + 1) % 4 != 0):  # skip pairs that wrap to the next row
                    img = encoded_imgs[i]
                    img_next = encoded_imgs[i + 1]
                    average_img = 0
                    average_img_next = 0
                    for j in range(0, 32):
                        for k in range(0, 10):
                            average_img = average_img + img[j][31 - k]
                            average_img_next = average_img_next + img_next[j][k]
                    average_blocking = np.abs(average_img -
                                              average_img_next) / 320
                    V_average = V_average + average_blocking
                    for j in range(0, 32):
                        distinct = np.abs(img[j][31] - img_next[j][0])
                        Total = Total + 1
                        if (distinct > 0.5):
                            Vcount = Vcount + 1
            V_average = V_average / 12
            Hcount = 0
            # same statistic across the seams between vertically adjacent
            # blocks (i and i + 4 lie in consecutive block rows)
            for i in range(0, len(encoded_imgs) - 4):
                img = encoded_imgs[i]
                img_next = encoded_imgs[i + 4]
                average_img = 0
                average_img_next = 0
                for j in range(0, 32):
                    for k in range(0, 10):
                        average_img = average_img + img[31 - k][j]
                        average_img_next = average_img_next + img_next[k][j]
                average_blocking = np.abs(average_img - average_img_next) / 320
                H_average = H_average + average_blocking
                for j in range(0, 32):
                    distinct = np.abs(img[31][j] - img_next[0][j])
                    Total = Total + 1
                    if (distinct > 0.5):
                        Hcount = Hcount + 1
            H_average = H_average / 12

            bitwise_arr = np.array(bitwise_arr)
            bitwise_avg = np.average(bitwise_arr)
            #blocking_loss = (Vcount+Hcount)/Total
            blocking_loss = (H_average + V_average) / 2

            for name, loss in main_losses.items():
                if (name == 'bitwise-error  '):
                    training_losses[name].update(bitwise_avg)
                else:
                    if (name == 'blocking_effect'):
                        training_losses[name].update(blocking_loss)
                    else:
                        training_losses[name].update(loss)

            if step % print_each == 0 or step == steps_in_epoch:
                logging.info('Epoch: {}/{} Step: {}/{}'.format(
                    epoch, train_options.number_of_epochs, step,
                    steps_in_epoch))
                utils.log_progress(training_losses)
                logging.info('-' * 40)
            step += 1

        train_duration = time.time() - epoch_start
        logging.info('Epoch {} training duration {:.2f} sec'.format(
            epoch, train_duration))
        logging.info('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        first_iteration = True
        validation_losses = defaultdict(AverageMeter)
        logging.info('Running validation for epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))

        #val
        for image, _ in val_data:
            image = image.to(device)
            #crop imgs
            imgs = cropImg(32, image)
            #iterate img
            bitwise_arr = []
            main_losses = None
            encoded_imgs = []
            blocking_imgs = []
            for img in imgs:
                img = img.to(device)
                message = torch.Tensor(
                    np.random.choice(
                        [0, 1], (img.shape[0],
                                 hidden_config.message_length))).to(device)
                losses, (encoded_images, noised_images,
                         decoded_messages) = model.validate_on_batch(
                             [img, message])
                encoded_imgs.append(encoded_images)
                blocking_imgs.append(
                    encoded_images[0][0].cpu().detach().numpy())
                main_losses = losses
                for name, loss in losses.items():
                    if (name == 'bitwise-error  '):
                        bitwise_arr.append(loss)

            Total = 0
            Vcount = 0
            V_average = 0
            H_average = 0
            for i in range(0, len(blocking_imgs) - 1):
                if ((i + 1) % 4 != 0):
                    img = blocking_imgs[i]
                    img_next = blocking_imgs[i + 1]
                    average_img = 0
                    average_img_next = 0
                    for j in range(0, 32):
                        for k in range(0, 10):
                            average_img = average_img + img[j][31 - k]
                            average_img_next = average_img_next + img_next[j][k]
                    average_blocking = np.abs(average_img -
                                              average_img_next) / 320
                    V_average = V_average + average_blocking
                    for j in range(0, 32):
                        distinct = np.abs(img[j][31] - img_next[j][0])
                        Total = Total + 1
                        if (distinct > 0.5):
                            Vcount = Vcount + 1
            V_average = V_average / 12
            Hcount = 0
            for i in range(0, len(blocking_imgs) - 4):
                img = blocking_imgs[i]
                img_next = blocking_imgs[i + 4]
                average_img = 0
                average_img_next = 0
                for j in range(0, 32):
                    for k in range(0, 10):
                        average_img = average_img + img[31 - k][j]
                        average_img_next = average_img_next + img_next[k][j]
                average_blocking = np.abs(average_img - average_img_next) / 320
                H_average = H_average + average_blocking
                for j in range(0, 32):
                    distinct = np.abs(img[31][j] - img_next[0][j])
                    Total = Total + 1
                    if (distinct > 0.5):
                        Hcount = Hcount + 1
            H_average = H_average / 12

            bitwise_arr = np.array(bitwise_arr)
            bitwise_avg = np.average(bitwise_arr)
            #blocking_loss = (Vcount+Hcount)/Total
            blocking_loss = (H_average + V_average) / 2
            for name, loss in main_losses.items():
                if (name == 'bitwise-error  '):
                    validation_losses[name].update(bitwise_avg)
                else:
                    if (name == 'blocking_effect'):
                        validation_losses[name].update(blocking_loss)
                    else:
                        validation_losses[name].update(loss)
            #concat image
            encoded_images = concatImgs(encoded_imgs)

            if first_iteration:
                if hidden_config.enable_fp16:
                    image = image.float()
                    encoded_images = encoded_images.float()
                utils.save_images(
                    image.cpu()[:images_to_save, :, :, :],
                    encoded_images[:images_to_save, :, :, :].cpu(),
                    epoch,
                    os.path.join(this_run_folder, 'images'),
                    resize_to=saved_images_size)
                first_iteration = False

        utils.log_progress(validation_losses)
        logging.info('-' * 40)
        utils.save_checkpoint(model, train_options.experiment_name, epoch,
                              os.path.join(this_run_folder, 'checkpoints'))
        utils.write_losses(os.path.join(this_run_folder, 'validation.csv'),
                           validation_losses, epoch,
                           time.time() - epoch_start)
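
Example 22 above (and Examples 24, 25, 26 and 30 below) are HiDDeN training scripts in which utils.log_progress receives a dict of accumulated losses (name -> AverageMeter) rather than a counter, and simply reports the running average of each loss. A sketch of that variant, assuming each meter exposes an .avg attribute as the surrounding code suggests:

import logging

def log_progress(losses_accu):
    """Log one aligned 'name  value' line per accumulated loss."""
    width = max(len(name) for name in losses_accu)
    for name, meter in losses_accu.items():
        logging.info('%s %.4f', name.ljust(width + 4), meter.avg)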
Example #23
def progress(it=None, desc=None, total=None):
    if desc is not None:
        logging.info(desc)
    return it if CLI_ARGS.no_progress else log_progress(
        it, interval=CLI_ARGS.progress_interval, total=total)
Example #24
def train(model: Hidden,
          device: torch.device,
          hidden_config: HiDDenConfiguration,
          train_options: TrainingOptions,
          this_run_folder: str,
          tb_logger):
    """
    Trains the HiDDeN model
    :param model: The model
    :param device: torch.device object, usually the GPU if available, otherwise the CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :return:
    """

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)
    block_size = hidden_config.block_size
    block_number = int(hidden_config.H / hidden_config.block_size)
    val_folder = train_options.validation_folder
    loss_type = train_options.loss_mode
    m_length = hidden_config.message_length
    alpha = train_options.alpha
    img_names = listdir(val_folder + "/valid_class")
    img_names.sort()
    out_folder = train_options.output_folder
    default = train_options.default
    beta = train_options.beta
    crop_width = int(beta * block_size)
    file_count = len(train_data.dataset)
    if file_count % train_options.batch_size == 0:
        steps_in_epoch = file_count // train_options.batch_size
    else:
        steps_in_epoch = file_count // train_options.batch_size + 1

    print_each = 10
    images_to_save = 8
    saved_images_size = (512, 512)
    icount = 0
    plot_block = []

    for epoch in range(train_options.start_epoch, train_options.number_of_epochs + 1):
        logging.info('\nStarting epoch {}/{}'.format(epoch, train_options.number_of_epochs))
        logging.info('Batch size = {}\nSteps in epoch = {}'.format(train_options.batch_size, steps_in_epoch))
        training_losses = defaultdict(AverageMeter)
        epoch_start = time.time()
        step = 1
        #train
        for image, _ in train_data:
            image = image.to(device)
            #crop imgs into blocks
            imgs, modified_imgs, entropies = cropImg(block_size, image, crop_width, alpha)
            bitwise_arr = []
            main_losses = None
            encoded_imgs = []
            batch = 0
            for img, modified_img, entropy in zip(imgs, modified_imgs, entropies):
                img = img.to(device)
                modified_img = modified_img.to(device)
                entropy = entropy.to(device)

                message = torch.Tensor(np.random.choice([0, 1], (img.shape[0], m_length))).to(device)
                losses, (encoded_images, noised_images, decoded_messages) = \
                    model.train_on_batch([img, message, modified_img, entropy, loss_type])
                encoded_imgs.append(encoded_images)
                batch = encoded_images.shape[0]
                #accumulate the per-block average of each loss
                if main_losses is None:
                    main_losses = losses
                    for k in losses:
                        main_losses[k] = losses[k] / len(imgs)
                else:
                    for k in main_losses:
                        main_losses[k] += losses[k] / len(imgs)

            #blocking effect loss calculation
            blocking_loss = blocking_value(encoded_imgs, batch, block_size, block_number)

            #update the training losses (blocking_effect uses the computed value)
            for name, loss in main_losses.items():
                if default == False and name == 'blocking_effect':
                    training_losses[name].update(blocking_loss)
                else:
                    training_losses[name].update(loss)
            #statistic
            if step % print_each == 0 or step == steps_in_epoch:
                logging.info(
                    'Epoch: {}/{} Step: {}/{}'.format(epoch, train_options.number_of_epochs, step, steps_in_epoch))
                utils.log_progress(training_losses)
                logging.info('-' * 40)
            step += 1

        train_duration = time.time() - epoch_start
        logging.info('Epoch {} training duration {:.2f} sec'.format(epoch, train_duration))
        logging.info('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'), training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        first_iteration = True
        validation_losses = defaultdict(AverageMeter)
        logging.info('Running validation for epoch {}/{}'.format(epoch, train_options.number_of_epochs))

        #validation
        ep_blocking = 0
        ep_total = 0
     
        for image, _ in val_data:
            image = image.to(device)
            #crop imgs
            imgs, modified_imgs, entropies = cropImg(block_size, image, crop_width, alpha)
            bitwise_arr = []
            main_losses = None
            encoded_imgs = []
            batch = 0

            for img, modified_img, entropy in zip(imgs, modified_imgs, entropies):
                img = img.to(device)
                modified_img = modified_img.to(device)
                entropy = entropy.to(device)

                message = torch.Tensor(np.random.choice([0, 1], (img.shape[0], m_length))).to(device)
                losses, (encoded_images, noised_images, decoded_messages) = \
                    model.train_on_batch([img, message, modified_img, entropy, loss_type])
                encoded_imgs.append(encoded_images)
                batch = encoded_images.shape[0]
                #accumulate the per-block average of each loss
                if main_losses is None:
                    main_losses = losses
                    for k in losses:
                        main_losses[k] = losses[k] / len(imgs)
                else:
                    for k in main_losses:
                        main_losses[k] += losses[k] / len(imgs)
                
            #blocking value for plotting
            blocking_loss = blocking_value(encoded_imgs, batch, block_size, block_number)
            ep_blocking = ep_blocking + blocking_loss
            ep_total = ep_total + 1

            for name, loss in main_losses.items():
                if default == False and name == 'blocking_effect':
                    validation_losses[name].update(blocking_loss)
                else:
                    validation_losses[name].update(loss)
            #concat image
            encoded_images = concatImgs(encoded_imgs, block_number)
            #save_image(encoded_images,"enc_img"+str(epoch)+".png")
            #save_image(image,"original_img"+str(epoch)+".png")
            if first_iteration:
                if hidden_config.enable_fp16:
                    image = image.float()
                    encoded_images = encoded_images.float()
                utils.save_images(image.cpu()[:images_to_save, :, :, :],
                                  encoded_images[:images_to_save, :, :, :].cpu(),
                                  epoch,
                                  os.path.join(this_run_folder, 'images'), resize_to=saved_images_size)
                first_iteration = False
            #save validation outputs in the last epoch
            if epoch == train_options.number_of_epochs:
                if train_options.ats:
                    for i in range(0, batch):
                        image = encoded_images[i].cpu()
                        image = (image + 1) / 2  # rescale from [-1, 1] to [0, 1]
                        f_dst = out_folder + "/" + img_names[icount]
                        save_image(image, f_dst)
                        icount = icount + 1
        #append block effect for plotting
        plot_block.append(ep_blocking / ep_total)
    
        utils.log_progress(validation_losses)
        logging.info('-' * 40)
        utils.save_checkpoint(model, train_options.experiment_name, epoch, os.path.join(this_run_folder, 'checkpoints'))
        utils.write_losses(os.path.join(this_run_folder, 'validation.csv'), validation_losses, epoch,
                           time.time() - epoch_start)
Example #25
def train_own_noise(model: Hidden, device: torch.device,
                    hidden_config: HiDDenConfiguration,
                    train_options: TrainingOptions, this_run_folder: str,
                    tb_logger, noise):
    """
    Trains the HiDDeN model
    :param model: The model
    :param device: torch.device object, usually the GPU if available, otherwise the CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :return:
    """

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)
    file_count = len(train_data.dataset)
    if file_count % train_options.batch_size == 0:
        steps_in_epoch = file_count // train_options.batch_size
    else:
        steps_in_epoch = file_count // train_options.batch_size + 1
    steps_in_epoch = 313  # hard-coded override of the value computed above

    print_each = 10
    images_to_save = 8
    saved_images_size = (
        512, 512)  # for qualitative check purpose to use a larger size

    for epoch in range(train_options.start_epoch,
                       train_options.number_of_epochs + 1):
        logging.info('\nStarting epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        logging.info('Batch size = {}\nSteps in epoch = {}'.format(
            train_options.batch_size, steps_in_epoch))
        training_losses = defaultdict(AverageMeter)

        if train_options.video_dataset:
            random.shuffle(train_data.dataset)

        epoch_start = time.time()
        step = 1
        for image, _ in train_data:
            image = image.to(device)
            message = torch.Tensor(
                np.random.choice(
                    [0, 1],
                    (image.shape[0], hidden_config.message_length))).to(device)
            losses, _ = model.train_on_batch([image, message])

            for name, loss in losses.items():
                training_losses[name].update(loss)
            if step % print_each == 0 or step == steps_in_epoch:
                #import pdb; pdb.set_trace()
                logging.info('Epoch: {}/{} Step: {}/{}'.format(
                    epoch, train_options.number_of_epochs, step,
                    steps_in_epoch))
                utils.log_progress(training_losses)
                logging.info('-' * 40)
            step += 1
            if step == steps_in_epoch:
                break

        train_duration = time.time() - epoch_start
        logging.info('Epoch {} training duration {:.2f} sec'.format(
            epoch, train_duration))
        logging.info('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        first_iteration = True
        validation_losses = defaultdict(AverageMeter)
        logging.info('Running validation for epoch {}/{} for noise {}'.format(
            epoch, train_options.number_of_epochs, noise))
        step = 1
        for image, _ in val_data:
            image = image.to(device)
            message = torch.Tensor(
                np.random.choice(
                    [0, 1],
                    (image.shape[0], hidden_config.message_length))).to(device)
            losses, (
                encoded_images, noised_images,
                decoded_messages) = model.validate_on_batch_specific_noise(
                    [image, message], noise=noise)
            for name, loss in losses.items():
                validation_losses[name].update(loss)
            if first_iteration:
                if hidden_config.enable_fp16:
                    image = image.float()
                    encoded_images = encoded_images.float()
                utils.save_images(
                    image.cpu()[:images_to_save, :, :, :],
                    encoded_images[:images_to_save, :, :, :].cpu(),
                    epoch,
                    os.path.join(this_run_folder, 'images'),
                    resize_to=saved_images_size)
                first_iteration = False
            step += 1
            if step == steps_in_epoch // 10:
                break

        utils.log_progress(validation_losses)
        logging.info('-' * 40)
        utils.save_checkpoint(model, train_options.experiment_name, epoch,
                              os.path.join(this_run_folder, 'checkpoints'))
        utils.write_losses(
            os.path.join(this_run_folder, 'validation_' + noise + '.csv'),
            validation_losses, epoch,
            time.time() - epoch_start)
Example #26
def train(model: Hidden, device: torch.device,
          hidden_config: HiDDenConfiguration, train_options: TrainingOptions,
          this_run_folder: str, tb_logger, vocab):
    """
    Trains the HiDDeN model
    :param model: The model
    :param device: torch.device object, usually the GPU if available, otherwise the CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :return:
    """

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options,
                                                  vocab)
    file_count = len(train_data.dataset)
    if file_count % train_options.batch_size == 0:
        steps_in_epoch = file_count // train_options.batch_size
    else:
        steps_in_epoch = file_count // train_options.batch_size + 1

    print_each = 10
    images_to_save = 8
    saved_images_size = (512, 512)

    for epoch in range(train_options.start_epoch,
                       train_options.number_of_epochs + 1):
        logging.info('\nStarting epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        logging.info('Batch size = {}\nSteps in epoch = {}'.format(
            train_options.batch_size, steps_in_epoch))
        training_losses = defaultdict(AverageMeter)
        epoch_start = time.time()
        step = 1
        for image, ekeys, dkeys, caption, length in train_data:
            image, caption, ekeys, dkeys = image.to(device), caption.to(
                device), ekeys.to(device), dkeys.to(device)

            losses, _ = model.train_on_batch(
                [image, ekeys, dkeys, caption, length])

            for name, loss in losses.items():
                training_losses[name].update(loss)
            if step % print_each == 0 or step == steps_in_epoch:
                logging.info('Epoch: {}/{} Step: {}/{}'.format(
                    epoch, train_options.number_of_epochs, step,
                    steps_in_epoch))
                utils.log_progress(training_losses)
                logging.info('-' * 40)
            step += 1

        train_duration = time.time() - epoch_start
        logging.info('Epoch {} training duration {:.2f} sec'.format(
            epoch, train_duration))
        logging.info('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        first_iteration = True
        validation_losses = defaultdict(AverageMeter)
        logging.info('Running validation for epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        for image, ekeys, dkeys, caption, length in val_data:
            image, caption, ekeys, dkeys = image.to(device), caption.to(
                device), ekeys.to(device), dkeys.to(device)

            losses, (encoded_images, noised_images, decoded_messages, predicted_sents) = \
                model.validate_on_batch([image, ekeys, dkeys, caption, length])

            #print(predicted)
            #exit()
            predicted_sents = predicted_sents.cpu().numpy()
            for i in range(train_options.batch_size):
                try:
                    #print(''.join([vocab.idx2word[int(w)] + ' ' for w in predicted.cpu().numpy()[i::train_options.batch_size]][1:length[i]-1]))
                    print("".join([
                        vocab.idx2word[int(idx)] + ' '
                        for idx in predicted_sents[i]
                    ]))
                    break
                except IndexError:
                    print(f'{i}th batch does not have enough length.')

            for name, loss in losses.items():
                validation_losses[name].update(loss)
            if first_iteration:
                if hidden_config.enable_fp16:
                    image = image.float()
                    encoded_images = encoded_images.float()
                utils.save_images(
                    image.cpu()[:images_to_save, :, :, :],
                    encoded_images[:images_to_save, :, :, :].cpu(),
                    epoch,
                    os.path.join(this_run_folder, 'images'),
                    resize_to=saved_images_size)
                first_iteration = False

        utils.log_progress(validation_losses)
        logging.info('-' * 40)
        utils.save_checkpoint(model, train_options.experiment_name, epoch,
                              os.path.join(this_run_folder, 'checkpoints'))
        utils.write_losses(os.path.join(this_run_folder, 'validation.csv'),
                           validation_losses, epoch,
                           time.time() - epoch_start)
Example #27
def progress(it=None, desc='Processing', total=None):
    print(desc, file=sys.stderr, flush=True)
    return it if CLI_ARGS.no_progress else log_progress(
        it, interval=CLI_ARGS.progress_interval, total=total)
Example #28
 def progress(it=None, desc='Processing', total=None):
     print(desc)
     return it if args.no_progress else log_progress(it, interval=args.progress_interval, total=total)
Example #29
 def progress(it=None, desc="Processing", total=None):
     logging.info(desc)
     return (it if args.no_progress else log_progress(
         it, interval=args.progress_interval, total=total))
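
Examples 23, 27, 28 and 29 wrap log_progress in a small progress() helper that can be disabled from the command line; the underlying variant takes an iterable plus interval= and total= keywords. A sketch, assuming interval is a reporting period in seconds:

import logging
import time

def log_progress(iterable, interval=10.0, total=None):
    """Pass items through, logging a status line every interval seconds."""
    last = time.time()
    for count, item in enumerate(iterable, start=1):
        yield item
        now = time.time()
        if now - last >= interval:
            last = now
            if total:
                logging.info('processed %d of %d (%.1f%%)',
                             count, total, 100.0 * count / total)
            else:
                logging.info('processed %d', count)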
Example #30
def train(model: Hidden, device: torch.device,
          hidden_config: HiDDenConfiguration, train_options: TrainingOptions,
          this_run_folder: str, tb_logger):
    """
    Trains the HiDDeN model
    :param model: The model
    :param device: torch.device object, usually the GPU if available, otherwise the CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :return:
    """

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)

    images_to_save = 8
    saved_images_size = (512, 512)

    best_epoch = train_options.best_epoch
    best_cond = train_options.best_cond
    for epoch in range(train_options.start_epoch,
                       train_options.number_of_epochs + 1):
        logging.info(
            f'\nStarting epoch {epoch}/{train_options.number_of_epochs} [{best_epoch}]'
        )
        training_losses = defaultdict(functions.AverageMeter)
        epoch_start = time.time()
        for image, _ in tqdm(train_data, ncols=80):
            image = image.to(device)  #.squeeze(0)
            message = torch.Tensor(
                np.random.choice(
                    [0, 1],
                    (image.shape[0], hidden_config.message_length))).to(device)
            losses, _ = model.train_on_batch([image, message])

            for name, loss in losses.items():
                training_losses[name].update(loss)

        train_duration = time.time() - epoch_start
        logging.info('Epoch {} training duration {:.2f} sec'.format(
            epoch, train_duration))
        logging.info('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses('train_loss', training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)
            tb_logger.writer.flush()

        validation_losses = defaultdict(functions.AverageMeter)
        logging.info('Running validation for epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        val_image_patches = ()
        val_encoded_patches = ()
        val_noised_patches = ()
        for image, _ in tqdm(val_data, ncols=80):
            image = image.to(device)  #.squeeze(0)
            message = torch.Tensor(
                np.random.choice(
                    [0, 1],
                    (image.shape[0], hidden_config.message_length))).to(device)
            losses, (encoded_images, noised_images,
                     decoded_messages) = model.validate_on_batch(
                         [image, message])
            for name, loss in losses.items():
                validation_losses[name].update(loss)

            if hidden_config.enable_fp16:
                image = image.float()
                encoded_images = encoded_images.float()
            pick = np.random.randint(0, image.shape[0])
            val_image_patches += (F.interpolate(
                image[pick:pick + 1, :, :, :].cpu(),
                size=(hidden_config.W, hidden_config.H)), )
            val_encoded_patches += (F.interpolate(
                encoded_images[pick:pick + 1, :, :, :].cpu(),
                size=(hidden_config.W, hidden_config.H)), )
            val_noised_patches += (F.interpolate(
                noised_images[pick:pick + 1, :, :, :].cpu(),
                size=(hidden_config.W, hidden_config.H)), )

        if tb_logger is not None:
            tb_logger.save_losses('val_loss', validation_losses, epoch)
            tb_logger.writer.flush()

        val_image_patches = torch.stack(val_image_patches).squeeze(1)
        val_encoded_patches = torch.stack(val_encoded_patches).squeeze(1)
        val_noised_patches = torch.stack(val_noised_patches).squeeze(1)
        utils.save_images(val_image_patches[:images_to_save, :, :, :],
                          val_encoded_patches[:images_to_save, :, :, :],
                          val_noised_patches[:images_to_save, :, :, :],
                          epoch,
                          os.path.join(this_run_folder, 'images'),
                          resize_to=saved_images_size)

        curr_cond = validation_losses['encoder_mse'].avg + validation_losses[
            'bitwise-error'].avg
        if best_cond is None or curr_cond < best_cond:
            best_cond = curr_cond
            best_epoch = epoch

        utils.log_progress(validation_losses)
        logging.info('-' * 40)
        utils.save_checkpoint(model, train_options.experiment_name, epoch,
                              best_epoch, best_cond,
                              os.path.join(this_run_folder, 'checkpoints'))
        logging.info(
            f'Current best epoch = {best_epoch}, loss = {best_cond:.6f}')
        utils.write_losses(os.path.join(this_run_folder, 'validation.csv'),
                           validation_losses, epoch,
                           time.time() - epoch_start)
Example #31

def clean(txt, stem=False, spell=False):
    # print txt
    txt = txt.lower().strip()
    txt = txt.replace('^p', '')
    txt = re.sub('[%s]' % re.escape(string.punctuation + u'“”«»–—―◦℅™•№▪'), ' ', txt)
    txt = u' '.join([convert_word(i, stem, spell) for i in tokenizer.tokenize(txt) if i not in russtop])

    # print txt
    return txt


if __name__ == '__main__':
    if len(sys.argv) != 3:
        print "Usage " + sys.argv[0] + " file cleanfile"
        sys.exit(1)

    train_file = sys.argv[1]
    clean_filename = sys.argv[2]

    prev_l = ''

    utils.reset_progress()
    with codecs.open(clean_filename, 'w', 'utf-8') as fw:
        for parts in utils.read_train(train_file):
            parts[3] = clean(parts[3], False, False)
            parts[4] = clean(parts[4], False, False)
            fw.write('\t'.join(parts))
            utils.log_progress()