コード例 #1
0
def get_point_cloud(B_matrix,
                    parallax_map,
                    color_map,
                    poses,
                    mask_lower_bound=2):
    """Back-project per-frame parallax maps into one registered point cloud.

    For each frame ``i`` the rows of ``parallax_map[i]`` whose third column
    exceeds ``mask_lower_bound`` are kept. They are multiplied by
    ``B_matrix``, divided by their 4th (homogeneous) coordinate, and
    transformed by ``poses[i]``. All frames are stacked, written to a
    timestamped ``.ply`` file under ``OP1_DIR``, and returned.

    Args:
        B_matrix: Back-projection matrix applied to every kept row
            (presumably 4x4, since row 3 of its output is used as the
            homogeneous coordinate -- confirm against the caller).
        parallax_map: Sequence of per-frame arrays. Rows are presumably
            (x, y, disparity, 1) -- confirm against the producer.
        color_map: Sequence of per-frame colour arrays aligned row-for-row
            with ``parallax_map``. Values are scaled from 0..255 to 0..1.
        poses: Sequence of per-frame pose matrices applied to the points.
        mask_lower_bound: Rows whose third column is <= this are dropped.

    Returns:
        ``(points, colors)``: the stacked point and colour arrays.

    Raises:
        ValueError: If ``parallax_map`` is empty.
    """
    if len(parallax_map) == 0:
        raise ValueError('parallax_map must contain at least one frame')

    point_cloud = []
    point_cloud_colors = []

    bar = FillingSquaresBar('Generating Frame Point Cloud',
                            max=len(parallax_map))
    for i, p_map in enumerate(parallax_map):
        # Keep only pixels whose disparity (third column) exceeds the bound.
        mask = p_map[:, 2] > mask_lower_bound
        p_map = p_map[mask, :]
        c_map = (color_map[i][mask, :] / 255.0).astype('float64')
        point_cloud_colors.append(c_map)

        points = B_matrix @ p_map.T
        # Divide every column by its 4th (homogeneous) coordinate.
        points = points / points[3]
        point_cloud.append((poses[i] @ points).T)

        bar.next()
    bar.finish()

    # Stack all frames with one concatenate. Growing the array inside a loop
    # re-copies it for every frame, which is quadratic in the frame count.
    registered_point_cloud = np.concatenate(point_cloud, axis=0)
    registered_point_cloud_colors = np.concatenate(point_cloud_colors, axis=0)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(registered_point_cloud)
    pcd.colors = o3d.utility.Vector3dVector(registered_point_cloud_colors)

    # NOTE(review): the timestamp contains ':', which Windows does not allow
    # in file names. Consider another format if that platform matters.
    o3d.io.write_point_cloud(
        OP1_DIR + "/point_cloud_{}.ply".format(
            datetime.datetime.now().strftime("%d-%b-%Y %H:%M:%S.%f")), pcd)

    return registered_point_cloud, registered_point_cloud_colors
コード例 #2
0
def createModeFolder(assetsDir, modeDir):
    """Build the mod folder of description patches for furniture objects.

    Changes the working directory to ``assetsDir`` and mirrors its directory
    tree under ``modeDir``. For each ``.object`` file whose ``category`` is
    ``furniture`` or ``decorative``, it writes
    ``<modeDir>/<relative path>.patch``. That file is a JSON patch replacing
    ``/description`` with the description followed by
    ``TAGS: <colonyTags>``. The description is optionally translated into
    the language named by ``enableTranslating``. Any previous ``TAGS:`` /
    ``ТЕГИ:`` suffix is cut off first. Objects missing a required key are
    skipped. Finally ``delEmpryDirs`` prunes ``modeDir``.
    """
    print('Begin editing furniture description files. Please wait!')
    os.chdir(assetsDir)
    bar = FillingSquaresBar('] Creating mode files', max=countTaskLen())

    for root, dirs, files in os.walk("."):
        for dir_name in dirs:
            # os.walk yields "./sub/dir". Dropping the leading "." gives
            # "/sub/dir", which is appended to modeDir.
            mirrored_dir = modeDir + os.path.join(root, dir_name)[1:]
            if not os.path.exists(mirrored_dir):
                os.mkdir(mirrored_dir)

        for file_name in files:
            if not file_name.endswith(".object"):
                continue
            bar.next()
            object_path = os.path.join(root, file_name)
            patch_path = modeDir + object_path[1:] + '.patch'

            with open(object_path, encoding='utf-8') as of:
                # Run jsmin over the raw text before json.loads (strips
                # comments/whitespace that strict JSON would reject).
                data = json.loads(jsmin.jsmin(of.read()))

            try:
                if data['category'] not in ['furniture', 'decorative']:
                    continue

                tags = data['colonyTags']
                desc = data['description']

                # Cut off tags written by a previous run, plus the separator
                # character before them. Clamp at 0 so a marker at the very
                # start yields "" instead of desc[:-1].
                marker = max(desc.find('TAGS:'), desc.find('ТЕГИ:'))
                if marker != -1:
                    desc = desc[:max(marker - 1, 0)]

                if enableTranslating:
                    # The connection is sometimes interrupted. Skip files
                    # that are already patched so a restarted run does not
                    # translate them again.
                    if os.path.isfile(patch_path):
                        continue

                    desc = ts.translate(text=desc,
                                        src='en',
                                        dest=[enableTranslating])

            except KeyError:
                continue

            generate = [{
                "value": desc + ' TAGS: ' + ', '.join(tags),
                "path": "/description",
                "op": "replace",
            }]

            with open(patch_path, 'w', encoding='utf-8') as outfile:
                json.dump(generate,
                          outfile,
                          sort_keys=True,
                          indent=2,
                          ensure_ascii=False)

    bar.finish()
    print('\nCleaning mode folder...')
    delEmpryDirs(modeDir)
コード例 #3
0
def load_images(img_path):
    '''
        Function to load images into the main memory

        img_path    : Path to the image directory (with or without a
                      trailing separator)
        return      : numpy array of images present in that directory
                      sorted in the numerical order (purely numeric file
                      stems first by value, then any other stems
                      alphabetically). Every entry is loaded as
                      '<stem>.png'.
    '''
    image_files_names = _sorted_png_paths(img_path)

    images = []
    bar = FillingSquaresBar('Loading Images from {}'.format(img_path),
                            max=len(image_files_names))
    for file_name in image_files_names:
        # cv2.imread returns None for unreadable files; kept as-is.
        images.append(cv2.imread(file_name))
        bar.next()
    bar.finish()

    return np.array(images)


def _sorted_png_paths(img_path):
    '''
        Return the '.png' path for every directory entry's stem, ordered
        numerically for decimal stems (so "10" comes after "2").
    '''
    stems = [name.split('.')[0] for name in os.listdir(img_path)]
    # Decimal stems sort by integer value (string as tie-break, e.g. "01"
    # vs "1"). Other stems follow alphabetically. The leading 0/1 keeps
    # ints and strings from ever being compared to each other.
    stems.sort(key=lambda stem: (0, int(stem), stem) if stem.isdecimal()
               else (1, 0, stem))
    return [os.path.join(img_path, stem + '.png') for stem in stems]
コード例 #4
0
 def extract_data(self):
     """Process every licence of the configured view(s).

     Uses ``self.views``, or ``[self.view]`` when ``views`` is empty. For
     each view, the table from ``self.db`` is narrowed in order by:
     ``self.range`` (``"start:stop"``, either bound may be empty),
     ``self.limit``, ``self.licence_id`` (``REFERENCE`` column) and
     ``self.id`` (``id`` column). ``self.get_licence`` is then called for
     every remaining row while a progress bar advances. Parcel and street
     errors are exported after each view. When ``self.iterate`` is True,
     ``self.data`` is re-validated.

     Raises:
         IterationError: If ``validate_data`` fails during an iterative run.
     """
     extract_data_views = self.views if self.views else [self.view]
     for extract_data_view in extract_data_views:
         folders = getattr(self.db, extract_data_view)
         if self.range:
             # Parse "start:stop" once. An empty bound means "from the
             # beginning" / "to the end", as in Python slicing.
             bounds = self.range.split(':')
             start = int(bounds[0]) if bounds[0] else None
             stop = int(bounds[1]) if len(bounds) > 1 and bounds[1] else None
             folders = folders[start:stop]
         if self.limit:
             folders = folders.head(self.limit)
         if self.licence_id:
             folders = folders[folders.REFERENCE == self.licence_id]
         if self.id:
             folders = folders[folders.id == int(self.id)]
         bar = FillingSquaresBar('Processing licences for {}'.format(extract_data_view), max=folders.shape[0])
         for row_id, licence in folders.iterrows():
             self.get_licence(row_id, licence)
             bar.next()
         bar.finish()
         export_error_csv([self.parcel_errors, self.street_errors])
         if self.iterate is True:
             try:
                 self.validate_data(self.data, 'GenericLicence')
             except Exception as exc:
                 # Chain the original error so its traceback is preserved.
                 raise IterationError('Schema change during iterative process') from exc
コード例 #5
0
ファイル: bars.py プロジェクト: xtrmbuster/kubestriker
 def wrapper(*args, **kwargs):
     """Animate a short blue progress bar labelled ``item``, then call ``fun``.

     The bar advances 100 times with a 5 ms pause per step. All positional
     and keyword arguments are forwarded unchanged to ``fun``, and its
     return value is returned.
     """
     d = stylize(item, fg("dodger_blue_1"))
     with FillingSquaresBar(d) as bar:
         for _ in range(100):
             sleep(0.005)
             bar.next()
     return fun(*args, **kwargs)
コード例 #6
0
 def _mine(self, progress=True):
     """Yield every place found around each point of ``self.grid``.

     For each of the first ``self.grid.dim`` entries of ``self.grid.points``,
     ``self.searcher`` is called with that point's ``lat``/``lng``,
     ``radius=self.r`` and ``types=self.place_type``. Every place returned
     by ``self.get_places`` for that result is yielded.

     Args:
         progress: When True, show a progress bar advanced once per point.
     """
     bar = None
     if progress:
         bar = FillingSquaresBar('Mining %s:' % self.grid.name,
                                 max=self.grid.dim)
     # Both modes share one loop; the non-progress path previously used the
     # undefined name `dim`.
     for i in range(self.grid.dim):
         point = self.grid.points[i]
         query_result = self.searcher(lat_lng={'lat': point[0],
                                               'lng': point[1]},
                                      radius=self.r,
                                      types=self.place_type)
         yield from self.get_places(query_result)
         if bar is not None:
             bar.next()
     if bar is not None:
         bar.finish()
コード例 #7
0
 def validate_data(self, data, type):
     """Validate every licence in ``data`` against the schema ``type``.

     Each item is passed to ``self.validate_schema(licence, type)``. The
     progress bar is used as a context manager, so it is finished even if
     validation raises.
     """
     with FillingSquaresBar('Validating licences with : {}'.format(type),
                            max=len(data)) as bar:
         for licence in data:
             self.validate_schema(licence, type)
             bar.next()
コード例 #8
0
def run_bar():
    """Demo: advance a 'Bar' progress bar five times, half a second apart."""
    values = [1, 2, 3, 4, 5]
    progress = FillingSquaresBar('Bar', max=len(values))
    for _value in values:
        progress.next()
        time.sleep(0.5)
    progress.finish()
コード例 #9
0
ファイル: fillet.py プロジェクト: DustinTheGreat/phish-fillet
def main():
    """Run phish-fillet over every URL listed in the input file.

    Reads the URL list from ``args.file``, or from
    ``urls-with-exclusions.txt`` when exclusions are enabled. The line count
    is stored in ``filConfig.numOfUrls``, and each line is then processed
    with the GeoIP / connector / output helpers. In quiet mode a progress
    bar is shown. On KeyboardInterrupt or SystemExit a farewell banner is
    printed and the program exits.
    """
    import sys

    textfile = args.file

    if filConfig.exclusions:
        textfile = 'urls-with-exclusions.txt'
    if filConfig.download:
        target.download = filConfig.download

    # Read the URL list once, stripping surrounding whitespace such as each
    # line's trailing newline.
    with open(textfile, 'r') as f:
        content = [x.strip() for x in f]
    filConfig.numOfUrls = len(content)

    bar = None
    if filConfig.quiet:
        print("\n[ Searching {} urls ]".format(filConfig.numOfUrls))
        bar = FillingSquaresBar('Filleting Phish', max=filConfig.numOfUrls)
        print("\n")

    if filConfig.verbose:
        filConfig.show()

    for line in content:
        try:
            if bar is not None:
                bar.next()
            # NOTE(review): `line` is never handed to `target`, so every
            # iteration appears to process the same target -- confirm intent.

            # Launch GeoIP Function
            if filConfig.geoIpEnabled:
                fil_getGeoIP(target, filConfig)  # Can this be placed inside connector?

            # If output selected collect return and place in output function
            if filConfig.output:
                index = fil_connector(target, filConfig)
                fil_output(index, filConfig.output)
            else:
                # NOTE(review): the argument order differs from the call
                # above (filConfig, content vs. target, filConfig). Verify
                # against fil_connector's signature.
                fil_connector(filConfig, content)
                print("done")

        except (KeyboardInterrupt, SystemExit):
            if bar is not None:
                bar.finish()
            print("\n\n\
Goodbye!       ,-,\n\
             ,/.(     __\n\
          ,-'    `!._/ /\n\
         > X )<|    _ <\n\
          `-....,,;' \\_\n")
            sys.exit()

    if bar is not None:
        bar.finish()
コード例 #10
0
ファイル: test.py プロジェクト: sk0g/plant-classifier-pytorch
def test_model(saved_model_name):
    """Evaluate a saved model on its batch's testing split and report results.

    ``saved_model_name`` is expected to look like
    ``batch-{batch_num}-{model_name}.pth``. The second "-"-separated field
    selects the testing loader. The function prints top-1 / top-5 accuracy,
    then per-class accuracy, and then plots a confusion matrix.
    """
    model_to_test = torch.load(saved_model_name)

    # batch files are of format "batch-{batch_num}-{model_name}.pth", retrieve batch_num
    batch_number = saved_model_name.split("-")[1]
    testing_loader = get_testing_loader(batch_num=batch_number)

    truth_list, predictions_list = [], []
    top_1_accuracy, top_5_accuracy = 0, 0

    # no_grad call allows processing bigger batches at once
    with torch.no_grad():
        model_to_test.eval()
        testing_bar = FillingSquaresBar(message='Testing',
                                        max=len(testing_loader))

        for inputs, labels in testing_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # predict inputs, and reverse the LogSoftMax
            real_predictions = torch.exp(model_to_test(inputs))

            # Get top class of outputs
            _, top_1_class = real_predictions.topk(k=1)
            _, top_5_classes = real_predictions.topk(k=5)

            # Labels reshaped like the top-1 output so they broadcast across
            # the k columns. Each row matches at most one of its top-5
            # classes (assumes predictions are (batch, classes)).
            label_column = labels.view(*top_1_class.shape)
            top_1_accuracy += (top_1_class == label_column).sum().item()
            top_5_accuracy += (top_5_classes == label_column).sum().item()

            # Append to the confusion-matrix lists with one tolist() per
            # batch instead of one .item() call per sample.
            truth_list.extend(labels.view(-1).tolist())
            predictions_list.extend(top_1_class.view(-1).tolist())

            testing_bar.next()
        testing_bar.finish()

    top_1_testing_accuracy = top_1_accuracy / len(testing_loader.dataset)
    top_5_testing_accuracy = top_5_accuracy / len(testing_loader.dataset)
    print(f'''\nAccuracy
        top-1: {helper.to_percentage(top_1_testing_accuracy)}
        top-5: {helper.to_percentage(top_5_testing_accuracy)}''')

    print("Calculating and printing per-class accuracy...")
    print_per_class_accuracy(truth_list, predictions_list)

    print("Displaying confusion matrix...")
    confusionMatrixPrettyPrint.plot_confusion_matrix_from_data(
        y_test=truth_list,
        predictions=predictions_list,
        columns=class_names,
        figsize=[15, 15],
        cmap='twilight')
コード例 #11
0
ファイル: bars.py プロジェクト: xtrmbuster/kubestriker
def scan_status(item):
    '''
    Show a short blue progress-bar animation labelled with ``item``.

    The bar advances 100 times with a 5 ms pause per step. It is closed by
    its context manager when the animation ends.
    '''
    label = stylize(item, fg("dodger_blue_1"))
    with FillingSquaresBar(label) as progress:
        step = 0
        while step < 100:
            sleep(0.005)
            progress.next()
            step += 1
コード例 #12
0
ファイル: bars.py プロジェクト: xtrmbuster/kubestriker
 def wrapper(*args, **kwargs):
     """Print a spacer, animate a green progress bar for ``item``, run ``fun``.

     The bar advances 100 times with a 10 ms pause per step. All positional
     and keyword arguments are forwarded unchanged to ``fun``, and its
     return value is returned.
     """
     print('\n')
     d = stylize(item, fg("green_1"))
     with FillingSquaresBar(d) as bar:
         for _ in range(100):
             sleep(0.01)
             bar.next()
         print('\n')
     return fun(*args, **kwargs)
コード例 #13
0
ファイル: uploader.py プロジェクト: vasivaas/personalWallet
 def save_data(self, data: List[object]) -> int:
     """
        Save prepared data to file.

        Calls ``file_manager.check_db_file()``, then overwrites
        ``self._file_path`` with ``data`` encoded by ``ObjectSerializer`` in
        ``constants.json_format``. The single-step progress bar completes
        after the write and is always finished.

        Returns 0.
     """
     file_manager.check_db_file()
     with FillingSquaresBar('Save data', suffix='%(percent)d%%', max=1) as bar:
         with self._file_path.open(mode='w') as f:
             f.write(ObjectSerializer().encode_object(constants.json_format,
                                                      data))
         # Advance only once the data has actually been written.
         bar.next()
     return 0
コード例 #14
0
def Pb5():
    """Demo of FillingSquaresBar: 100 steps with a 0.1 s pause after each."""
    from progress.bar import FillingSquaresBar
    import time

    total = 100  # bar length (max); adjustable, keep in sync with the loop
    progress = FillingSquaresBar('进度条5', max=total)

    step = 0
    while step < total:  # loop count must match the bar length
        progress.next()
        time.sleep(0.1)  # per-step delay; adjustable, 0.1-1 s works best
        step += 1

    progress.finish()
コード例 #15
0
def download_from_eoddata(start_date, end_date, market, driver):
    """Download EOD data for ``market`` between two dates from eoddata.com.

    Args:
        start_date: ``datetime.date`` of the first day to download.
        end_date: ``datetime.date`` of the last day to download.
        market: Exchange code sent as the ``e`` query parameter.
        driver: Selenium driver used to browse and trigger the download.

    The download page is scraped for an example ``filedownload.aspx`` link
    to obtain its ``k`` key. A request for the wanted market and dates is
    then issued through ``driver``. A 10-second progress bar gives the file
    time to download.

    Raises:
        RuntimeError: If no download link carrying a ``k`` key is found.
    """
    # navigate to the downloads page
    driver.get('http://www.eoddata.com/download.aspx')

    # get a list of all of the hyperlink tags in the page
    bs_obj = BeautifulSoup(driver.page_source, "lxml")
    url_list = bs_obj.find_all('a')

    # Step through the hyperlinks until one is an example download of the form
    # /data/filedownload.aspx?e=INDEX&sd=20180606&ed=20180606&d=4&k=ph72h4ynw2&o=d&ea=1&p=0
    # and extract its `k` field for use in our own request.
    k = ''
    for url in url_list:
        if not url.has_attr('href'):
            continue
        url_string = url.attrs['href']
        if re.match(r'/data/filedownload\.aspx', url_string):
            key_match = re.search(r'[?&]k=([^&]*)', url_string)
            if key_match:
                k = key_match.group(1)
                break
    if not k:
        raise RuntimeError(
            'Could not find a download link with a "k" key on the '
            'eoddata download page')

    # construct the URL according to the dates and market that we want to
    # download
    url_template = '{url_base}?e={e}&sd={sd}&ed={ed}&d={d}&k={k}&o={o}&ea={ea}&p={p}'
    url_download = url_template.format(
        url_base='http://www.eoddata.com/data/filedownload.aspx',
        e=market,
        sd=start_date.strftime('%Y%m%d'),
        ed=end_date.strftime('%Y%m%d'),
        d='4',
        k=k,
        o='d',
        ea='1',
        p='0')
    # submit the download request
    driver.get(url_download)

    # wait 10 seconds (100 x 0.1 s) to ensure that the file has time to download
    bar = FillingSquaresBar('Downloading data ', max=100)
    for _ in range(100):
        bar.next()
        time.sleep(0.1)
    bar.finish()
コード例 #16
0
 def _generate(self):
     """Generate all content asynchronously while showing a progress bar.

     Initialises and parses the content structure, then runs
     ``self._generate_async(progress_bar)`` to completion. The bar's length
     is ``self._structure["number"]``. Finally the elapsed wall-clock time
     is printed.
     """
     self._initialize_content()
     self._parse_structure()
     start_time = default_timer()
     progress_bar = FillingSquaresBar(
         "Generate content", max=self._structure["number"]
     )
     try:
         # asyncio.run creates and closes its own event loop. Calling
         # get_event_loop() without a running loop is deprecated in recent
         # Python versions.
         asyncio.run(self._generate_async(progress_bar))
     finally:
         progress_bar.finish()
     elapsed = default_timer() - start_time
     print("{:5.2f}s elapsed".format(elapsed))
コード例 #17
0
 def _mine(self, progress=True):
     """Yield the places found around every point in ``self.points``.

     Places come from ``self.get_places(point)`` and are yielded in order.
     When ``progress`` is truthy, a bar is stored on ``self.bar``, advanced
     once per point, and finished after the last point.
     """
     show_bar = bool(progress)
     if show_bar:
         self.bar = FillingSquaresBar('Mining:', max=self.dim)
     for point in self.points:
         for place in self.get_places(point):
             yield place
         if show_bar:
             self.bar.next()
     if show_bar:
         self.bar.finish()
コード例 #18
0
def download_video(link):
    """Download the first available stream of a YouTube video to SAVE_PATH.

    Prints the title, duration (``get_time``) and file size (``get_size``),
    then downloads the stream once. The progress bar only marks completion;
    it is not fed live download progress.
    """
    yt = pytube.YouTube(link)
    stream = yt.streams.first()
    video_length = get_time(yt.length)
    video_size = get_size(stream.filesize)
    print("Downloading \"" + yt.title + "\" Length : " + video_length)
    print("\tFile Size : " + video_size)
    bar = FillingSquaresBar("Progress : ", suffix="%(percent)d%%")
    # Download exactly once. The bar is filled afterwards instead of
    # re-downloading the stream for every tick.
    stream.download(SAVE_PATH)
    bar.goto(bar.max)
    bar.finish()
コード例 #19
0
def image_point_cloud(point_cloud, point_cloud_colors, poses, image_width,
                      image_height):
    """Render the coloured point cloud from each pose and return the images.

    Each pose ``P`` is split into ``R = P[:, :3]`` and ``t = P[:, 3]``.
    ``get_image`` receives ``R.T`` and ``-R.T @ t`` (the inverse transform,
    assuming ``R`` is a rotation), the module-level ``K`` (presumably the
    camera intrinsics -- confirm), and the output image size.

    Returns:
        list: The result of ``get_image`` for each pose, in pose order.
    """
    images = []
    bar = FillingSquaresBar('Imaging the Point Cloud with the given poses',
                            max=len(poses))
    for i, P in enumerate(poses):
        R_inv = P[:, :3].T
        T = -1 * (R_inv @ P[:, 3])
        images.append(get_image(i, point_cloud, point_cloud_colors, R_inv, T,
                                K, image_width, image_height))
        bar.next()
    bar.finish()
    return images
コード例 #20
0
def train(data):
    """Train ``n_net`` on ``data`` for ``epochs`` epochs and report accuracy.

    Each element of ``data`` stores its class index in position 0 and the
    network inputs after it. Targets are two-element vectors of 0.01 with
    0.99 at the class index. After every epoch the module-level ``test()``
    is called and its train/test accuracies are printed.
    """
    for e in range(epochs):
        with FillingSquaresBar('Processing epoch {}/{}'.format(e + 1, epochs),
                               max=len(data)) as bar:
            for elm in data:
                expected_output = [0.01] * 2
                expected_output[int(elm[0])] = 0.99
                n_net.train(elm[1:], expected_output)
                bar.next()
            # No explicit bar.finish(): leaving the `with` block already
            # finishes the bar.
        acc_test, acc_train = test()
        print('Acc_train: {:.4f} Acc_test: {:.4f}\n'.format(
            acc_train, acc_test))
コード例 #21
0
def create_parallax_map(images_left, images_right):
    '''
        Return parallax maps computed from pairs of stereo-rectified images

        images_left : np array of the left stereo images (BGR)
        images_right: np array of the right stereo images (BGR)
        return      : (parallax_map, disparity).
                      parallax_map is np.array(list of per-pair arrays).
                      Each per-pair array holds one row
                      [x, y, normalised disparity, 1] per pixel, in
                      row-major pixel order.
                      disparity is the normalised disparity of the last
                      pair ([] when no images are given).
                      Returns False if the two inputs differ in length.
    '''
    if len(images_left) != len(images_right):
        print("Error: #images_left must be equal to #images_right")
        return False

    window_size = 5
    minDisparity = -39
    numDisparities = 144
    stereo = cv2.StereoSGBM_create(minDisparity=minDisparity,
                                   numDisparities=numDisparities,
                                   blockSize=window_size,
                                   P1=8 * 3 * window_size**2,
                                   P2=64 * 3 * window_size**2,
                                   disp12MaxDiff=1,
                                   uniquenessRatio=10,
                                   speckleWindowSize=100,
                                   speckleRange=32,
                                   preFilterCap=63,
                                   mode=3)  # 3 is StereoSGBM MODE_HH4 in OpenCV

    disparity = []
    parallax_map = []

    bar = FillingSquaresBar('Extracting Disparity Map', max=len(images_left))
    for left_bgr, right_bgr in zip(images_left, images_right):
        im_right = cv2.cvtColor(right_bgr, cv2.COLOR_BGR2GRAY)
        im_left = cv2.cvtColor(left_bgr, cv2.COLOR_BGR2GRAY)
        # NOTE(review): StereoMatcher.compute takes (left, right). The images
        # are passed swapped here, and compute() returns fixed-point values
        # (x16) that are not rescaled below. Both kept as-is -- confirm.
        disparity = stereo.compute(im_right, im_left).astype('float64')
        disparity = (disparity - minDisparity) / numDisparities

        # One [x, y, d, 1] row per pixel in y-major / x-minor order (same as
        # a nested "for y: for x:" loop), built with vectorised NumPy ops.
        ys, xs = np.indices(disparity.shape)
        rows = np.stack((xs.ravel(), ys.ravel(), disparity.ravel(),
                         np.ones(disparity.size)), axis=1)
        parallax_map.append(rows)

        bar.next()

    parallax_map = np.array(parallax_map)

    bar.finish()
    return parallax_map, disparity
コード例 #22
0
def insert_data_to_db():
    """Insert every file of each state directory into the ``empresas`` table.

    For each state code ``uf`` in ``estados.estados``, the process changes
    into ``<path.csv_path>/<uf>``. Every ``*.*`` file name there is passed
    to ``DB().insert_many(file, 'empresas')``, with one progress bar per
    state. Insertion is best-effort: a failing file is logged and skipped.

    Note: changes the process working directory.
    """
    import logging

    logger = logging.getLogger(__name__)
    for uf in estados.estados:
        state_dir = str(path.csv_path) + '/' + uf
        os.chdir(state_dir)
        # Glob once so the bar length and the processed files always agree.
        files = glob.glob('*.*')

        bar = FillingSquaresBar(uf, max=len(files))
        for file in files:
            bar.next()
            try:
                DB().insert_many(file, 'empresas')
            except Exception:
                # Best-effort import: keep going, but record the failure
                # instead of silently discarding it.
                logger.exception('Failed to insert %s/%s into empresas',
                                 uf, file)
        bar.finish()
コード例 #23
0
def train(logbook, net, device, loss_fn, opt, train_l):
    """Run one epoch of the training experiment.

    For every batch of ``train_l``:
      1. Move the batch to ``device``, run ``net``, and score it with
         ``loss_fn``.
      2. Update ``logbook.meter`` and show its summary in the progress-bar
         suffix.
      3. Back-propagate through ``opt`` and call ``step_postOptimStep()`` on
         each controller from ``indiv.Controller.getControllers(net)``.
    At the end, truthy loss/metric averages and the first param group's
    learning rate are written to ``logbook.writer`` at step
    ``logbook.i_epoch``.

    Returns:
        dict with ``'train_loss'`` and ``'train_metric'``.
    """
    logbook.meter.reset()
    n_batches = len(train_l)
    bar = FillingSquaresBar('Training \t', max=n_batches)
    controllers = indiv.Controller.getControllers(net)
    suffix_fmt = ('Total: {total:} | ETA: {eta:} | Epoch: {epoch:4d} | '
                  '({batch:5d}/{num_batches:5d})')

    for batch_idx, (inputs, gt_labels) in enumerate(train_l):
        inputs = inputs.to(device)
        gt_labels = gt_labels.to(device)

        # forward pass and loss
        pr_outs = net(inputs)
        loss = loss_fn(pr_outs, gt_labels)

        # statistics and progress display
        logbook.meter.update(pr_outs, gt_labels, loss.item(),
                             track_metric=logbook.track_metric)
        bar.suffix = suffix_fmt.format(total=bar.elapsed_td,
                                       eta=bar.eta_td,
                                       epoch=logbook.i_epoch,
                                       batch=batch_idx + 1,
                                       num_batches=n_batches) + logbook.meter.bar()
        bar.next()

        # backward pass, optimiser step, then controller hooks
        opt.zero_grad()
        loss.backward()
        opt.step()
        for controller in controllers:
            controller.step_postOptimStep()

    bar.finish()

    stats = dict(train_loss=logbook.meter.avg_loss,
                 train_metric=logbook.meter.avg_metric)
    step = logbook.i_epoch
    for tag, value in stats.items():
        if value:
            logbook.writer.add_scalar(tag, value, global_step=step)
    logbook.writer.add_scalar('learning_rate', opt.param_groups[0]['lr'],
                              global_step=step)
    return stats
コード例 #24
0
def dwnldfile():
    """Prompt for a file name and download it over the module-level ``ftp``.

    A cosmetic 100-step progress bar runs before the transfer. The file is
    then fetched with ``RETR`` in 1024-byte blocks into a local file of the
    same name, and the FTP session is closed with ``ftp.quit()``.
    """
    filename = input(colored(" [system]: filename: ", "white"))
    time.sleep(1)

    suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]'
    bar = FillingSquaresBar(" [system]: downloading ", suffix=suffix)
    for _ in bar.iter(range(100)):
        # NOTE(review): `sleep` is called with no argument. time.sleep()
        # requires one, so this only works if `sleep` is a project helper
        # -- confirm.
        sleep()

    # `with` guarantees the local file is closed even if the transfer fails.
    with open(filename, 'wb') as localfile:
        ftp.retrbinary('RETR ' + filename, localfile.write, 1024)
    print(" ")
    time.sleep(0.9)
    print(colored(" [system]: file downloaded", "white"))
    ftp.quit()
コード例 #25
0
def download_audio(link):
    """Download a YouTube video's audio-only stream and convert it to MP3.

    Characters from ``bad_chars`` in the title are replaced with ``_`` to
    form the file name. The first audio-only stream is downloaded once to
    ``SAVE_PATH``. ``ffmpeg`` then converts ``download/<name>.mp4`` into
    ``download/<name>.mp3``, and the intermediate MP4 is removed.
    """
    yt = pytube.YouTube(link)
    bar = FillingSquaresBar("Downloading Audio : ", suffix="%(percent)d%%")

    stream = yt.streams.filter(only_audio=True).first()
    # "/" is included so a title cannot introduce sub-directories.
    bad_chars = [";", ":", "!", "*", ' ', "$", "@", "(", ")", "[", "]", "|",
                 ".", "\"", "\'", ",", "/"]
    _filename = yt.title.translate(str.maketrans(dict.fromkeys(bad_chars, "_")))
    mp4_name = "download/%s.mp4" % _filename
    mp3_name = "download/%s.mp3" % _filename
    # NOTE(review): the file is saved under SAVE_PATH but converted from
    # "download/". The two only agree when SAVE_PATH points there.
    stream.download(SAVE_PATH, _filename)
    bar.goto(bar.max)
    bar.finish()

    print("\nPerforming required conversions...")
    # Pass an argument list without a shell: the title comes from YouTube
    # and must never be interpreted by a shell.
    subprocess.call(['ffmpeg', '-loglevel', 'panic', '-i', mp4_name, mp3_name])
    os.remove(mp4_name)
コード例 #26
0
def download_page(output, url):
    """Save the page at ``url`` together with its resources under ``output``.

    Prepares a directory named after the page, rewrites the page's resource
    tags via ``replace_tags``, and saves the HTML. Each resource is then
    fetched and stored at its mapped local path, with a progress bar over
    the resources. Progress is logged through the root logger and a final
    message is printed.
    """
    log = logging.getLogger()
    log.info('Start program!')

    page_name = prepare.get_valid_name(url)
    path_page, dir_files = prepare.prepare_directory(output, page_name)
    log.info('The directory creation process is complete')

    html_page, resources = replace_tags(getter.get_content(url), dir_files, url)
    getter.get_file(html_page, path_page)

    progress = FillingSquaresBar('Process download')
    for resource_url, local_path in progress.iter(resources.items()):
        getter.get_file(getter.get_content(resource_url).content, local_path)

    done_message = f'The download is complete. Data is saved in {output}'
    for message in ('The download page and the data is complete',
                    'Process download is complete!',
                    done_message):
        log.info(message)
    print('')
    print(done_message)
    print('')
コード例 #27
0
ファイル: raster_tools.py プロジェクト: selkind/LaSeRo
def make_stacked_chunks(data_manager, scene_id, chunk_width=512, chunk_height=512):
    """Cut a scene's band stack and label raster into fixed-size .npy chunks.

    Inputs:
      * ``<label_dir>/<scene_id>_label.TIF``
      * for every band ``b`` in the module-level ``bands``:
        ``<raw_image_dir>/<scene_id>/<scene_id>_B<b>.TIF``
    Band 1 of each file is read. The bands are stacked into an
    (height, width, n_bands) array.

    For every full ``chunk_height`` x ``chunk_width`` tile that contains at
    least one positive band value, ``chunk_<row>_<col>.npy`` and
    ``chunk_<row>_<col>_label.npy`` are written to
    ``<stack_dir>/<scene_id>``. Partial tiles at the right/bottom edges are
    not written.
    """
    scene = str(scene_id)
    label_path = os.path.join(data_manager.label_dir, scene + "_label.TIF")
    with rio.open(label_path) as label_file:
        label = label_file.read(1)

    band_rasters = []
    for b in bands:
        band_path = os.path.join(data_manager.raw_image_dir, scene,
                                 "{}_B{}.TIF".format(scene, b))
        with rio.open(band_path) as band_file:
            band_rasters.append(band_file.read(1))

    band_stack = np.stack(band_rasters, axis=0).transpose(1, 2, 0)
    band_rasters.clear()  # np.stack copied the data; drop the per-band arrays

    scene_height, scene_width = band_stack.shape[:2]
    vertical_chunks = scene_height // chunk_height
    horizontal_chunks = scene_width // chunk_width

    scene_chunk_dir = os.path.join(data_manager.stack_dir, scene)
    os.makedirs(scene_chunk_dir, exist_ok=True)

    with FillingSquaresBar('Processing', max=vertical_chunks * horizontal_chunks) as bar:
        for j in range(vertical_chunks):
            for k in range(horizontal_chunks):
                row_index = j * chunk_height
                col_index = k * chunk_width

                band_chunk = band_stack[row_index: row_index + chunk_height, col_index: col_index + chunk_width, :]
                label_chunk = label[row_index: row_index + chunk_height, col_index: col_index + chunk_width]

                # Test for any positive pixel directly. Summing the indices
                # returned by np.where misses data located only at index 0.
                if np.any(band_chunk > 0):
                    band_chunk_path = os.path.join(scene_chunk_dir, "chunk_{}_{}.npy".format(j, k))
                    label_chunk_path = os.path.join(scene_chunk_dir, "chunk_{}_{}_label.npy".format(j, k))
                    np.save(band_chunk_path, band_chunk, allow_pickle=True)
                    np.save(label_chunk_path, label_chunk, allow_pickle=True)
                bar.next()
コード例 #28
0
def test(logbook, net, device, loss_fn, test_l, valid=False, prefix=None):
    """Run a validation or test epoch over ``test_l`` without gradients.

    Each batch is moved to ``device`` and passed through
    ``net.forward_with_tensor_stats``. Loss and metric are accumulated in
    ``logbook.meter`` (the metric is always tracked) and shown in the
    progress-bar suffix.

    Args:
        valid: When True, the run is labelled "Validation". Truthy stats are
            written to ``logbook.writer`` as scalars, and the tensor
            statistics of the last batch are written as histograms.
        prefix: Key prefix for the returned stats. Defaults to ``'valid'``
            or ``'test'`` depending on ``valid``.

    Returns:
        dict with ``'<prefix>_loss'`` and ``'<prefix>_metric'``.
    """
    logbook.meter.reset()
    bar_title = 'Validation \t' if valid else 'Test \t'
    bar       = FillingSquaresBar(bar_title, max=len(test_l))
    # Defined up front so an empty loader cannot leave it unbound when the
    # histograms are written below.
    tensor_stats = []
    with torch.no_grad():
        for i_batch, data in enumerate(test_l):

            # load data onto device
            inputs, gt_labels     = data
            inputs                = inputs.to(device)
            gt_labels             = gt_labels.to(device)

            # forprop
            tensor_stats, pr_outs = net.forward_with_tensor_stats(inputs)
            loss                  = loss_fn(pr_outs, gt_labels)

            # update statistics
            logbook.meter.update(pr_outs, gt_labels, loss.item(), track_metric=True)
            bar.suffix = 'Total: {total:} | ETA: {eta:} | Epoch: {epoch:4d} | ({batch:5d}/{num_batches:5d})'.format(
                total=bar.elapsed_td,
                eta=bar.eta_td,
                epoch=logbook.i_epoch,
                batch=i_batch + 1,
                num_batches=len(test_l))
            bar.suffix = bar.suffix + logbook.meter.bar()
            bar.next()
    bar.finish()

    if prefix is None:
        prefix = 'valid' if valid else 'test'
    stats = {
        prefix+'_loss':   logbook.meter.avg_loss,
        prefix+'_metric': logbook.meter.avg_metric
    }
    if valid:
        for k, v in stats.items():
            if v:
                logbook.writer.add_scalar(k, v, global_step=logbook.i_epoch)
        for name, tensor in tensor_stats:
            logbook.writer.add_histogram(name, tensor, global_step=logbook.i_epoch)
    return stats
コード例 #29
0
ファイル: modules.py プロジェクト: gcasale82/file_order
    def generate_test_files(self):
        """Create ``self.files_number`` random files in a fresh test directory.

        ``self.test_directory`` is removed if present, recreated, and made
        the working directory. File ``i`` is named ``file<i>.<ext>``, where
        ``ext`` is picked from ``self.file_extention`` with ``rand``. It is
        filled with ``rand(1, 1000) * self.file_size + rand(1, 999)`` bytes
        from ``os.urandom``.
        """
        if os.path.exists(self.test_directory):
            shutil.rmtree(self.test_directory)
        os.makedirs(self.test_directory)
        os.chdir(self.test_directory)

        progress = FillingSquaresBar('Processing', max=self.files_number)
        for index in range(self.files_number):
            extensions = self.file_extention
            extension = extensions[rand(0, len(extensions) - 1)]
            name = f"file{index}.{extension}"
            with open(name, 'wb') as handle:
                byte_count = rand(1, 1000) * self.file_size + rand(1, 999)
                handle.write(os.urandom(byte_count))
            progress.next()
        progress.finish()
コード例 #30
0
 def get_trips(self) -> Tuple[Trip, ...]:
     """Split ``self._waypoints`` into trips and return all trips so far.

     The first waypoint seen while ``self.last_valid_waypoint`` is falsy
     becomes the starting point. A later waypoint closes a trip from
     ``self.last_valid_waypoint`` only when both hold:
       * ``calculate_distance`` is at least 15, and
       * ``calculate_minute_difference`` is greater than 3.
     That waypoint then becomes the new ``last_valid_waypoint``; all other
     waypoints are skipped. Trips are appended to ``self.trips``, which
     keeps trips from earlier calls, and the whole list is returned as a
     tuple. The progress bar is finished even if a helper raises.
     """
     with FillingSquaresBar("Processing", max=len(self._waypoints)) as bar:
         for waypoint in self._waypoints:
             bar.next()
             if not self.last_valid_waypoint:
                 self.last_valid_waypoint = waypoint
                 continue
             distance = calculate_distance(self.last_valid_waypoint, waypoint)
             if distance < 15:
                 continue
             time_difference = calculate_minute_difference(
                 self.last_valid_waypoint.timestamp, waypoint.timestamp)
             if time_difference <= 3:
                 continue
             trip = Trip(start=self.last_valid_waypoint,
                         end=waypoint,
                         distance=distance)
             self.last_valid_waypoint = waypoint
             self.trips.append(trip)
     return tuple(self.trips)