import copy
import logging
import os
import sys
import time
import traceback

import numpy as np
import torchvision

# NOTE: FileUtils (lock_file / unlock_file / createDir) is a project-local helper
# and is assumed to be importable from the project's own modules.


def create_local(args):
    try:
        if args.name and len(args.name) > 0:
            filename = './tasks/' + args.report + '/runs/' + args.name + '/' + args.name + '.csv'
            if not os.path.exists(filename):
                headers = args.params_report_local
                with open(filename, 'w') as outfile:
                    FileUtils.lock_file(outfile)
                    outfile.write(','.join(headers) + '\n')
                    outfile.flush()
                    os.fsync(outfile)
                    FileUtils.unlock_file(outfile)
    except Exception as e:
        logging.error(str(e))
        exc_type, exc_value, exc_tb = sys.exc_info()
        logging.error(traceback.format_exception(exc_type, exc_value, exc_tb))
def create(args):
    try:
        if args.report and len(args.report) > 0:
            filename = os.path.join('reports', args.report) + '.csv'
            if not os.path.exists(filename):
                # copy so that extending the headers does not mutate args.params_report in place
                headers = list(args.params_report)
                if args.params_grid is not None:
                    headers += args.params_grid
                with open(filename, 'w') as outfile:
                    FileUtils.lock_file(outfile)
                    outfile.write(','.join(headers) + '\n')
                    outfile.flush()
                    os.fsync(outfile)
                    FileUtils.unlock_file(outfile)
    except Exception as e:
        logging.error(str(e))
        exc_type, exc_value, exc_tb = sys.exc_info()
        logging.error(traceback.format_exception(exc_type, exc_value, exc_tb))
def add_hparams(path_sequence, run_name, args_dict, metrics_dict, global_step):
    try:
        path_local_csv = f'{path_sequence}/{run_name}.csv'
        path_global_csv = f'{path_sequence}/sequence-{os.path.basename(path_sequence)}.csv'

        # work on copies so the caller's dicts are not modified
        args_dict = copy.copy(args_dict)
        metrics_dict = copy.copy(metrics_dict)

        # keep only scalar values that can be written into CSV cells
        # (np.floating / np.integer cover numpy scalars such as np.float32;
        # the np.float / np.int aliases are deprecated)
        for each_dict in [args_dict, metrics_dict]:
            for key in list(each_dict.keys()):
                if not isinstance(each_dict[key], (float, int, str, np.floating, np.integer)):
                    del each_dict[key]

        for path_csv in [path_local_csv, path_global_csv]:
            if os.path.exists(path_csv):
                with open(path_csv, 'r+') as outfile:
                    FileUtils.lock_file(outfile)
                    lines_all = outfile.readlines()
                    lines_all = [it.replace('\n', '').split(',') for it in lines_all if ',' in it]

                    if len(lines_all) == 0 or len(lines_all[0]) < 2:
                        headers = ['step'] + list(args_dict.keys()) + list(metrics_dict.keys())
                        headers = [str(it).replace(',', '_') for it in headers]
                        lines_all.append(headers)

                    values = [global_step] + list(args_dict.values()) + list(metrics_dict.values())
                    values = [str(it).replace(',', '_') for it in values]

                    if path_csv == path_local_csv:
                        # local CSV: append one row per step
                        lines_all.append(values)
                    else:
                        # global (sequence) CSV: overwrite the row whose args match, otherwise append
                        existing_line_idx = -1
                        args_values = list(args_dict.values())
                        args_values = [str(it).replace(',', '_') for it in args_values]
                        for idx_line, line in enumerate(lines_all):
                            if len(line) > 1:
                                is_match = True
                                for idx_arg in range(len(args_values)):
                                    if line[idx_arg + 1] != args_values[idx_arg]:
                                        is_match = False
                                        break
                                if is_match:
                                    existing_line_idx = idx_line
                                    break
                        if existing_line_idx >= 0:
                            lines_all[existing_line_idx] = values
                        else:
                            lines_all.append(values)

                    # rewrite the whole file in place
                    outfile.truncate(0)
                    outfile.seek(0)
                    outfile.flush()
                    rows = [','.join(it) for it in lines_all]
                    rows = [it for it in rows if len(it.replace('\n', '').strip()) > 0]
                    outfile.write('\n'.join(rows).strip())
                    outfile.flush()
                    os.fsync(outfile)
                    FileUtils.unlock_file(outfile)
    except Exception as e:
        logging.exception(e)
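# Hypothetical usage sketch for add_hparams (not part of the original module).
# It assumes the sequence directory and both CSV files already exist on disk,
# since add_hparams only updates files it finds; the paths and key names below
# are illustrative only.
#
#   path_sequence = './results/sequence-1'
#   add_hparams(
#       path_sequence=path_sequence,
#       run_name='run-1',
#       args_dict={'learning_rate': 1e-3, 'batch_size': 64},
#       metrics_dict={'train_loss': 0.42, 'test_acc': 0.91},
#       global_step=10)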
def __init__(self, args, is_test_data):
    super().__init__()
    self.args = args
    self.is_test_data = is_test_data

    path_data = f'{self.args.path_data}/{self.args.datasource_type}'
    FileUtils.createDir(path_data)

    # create a lock file so that only one process downloads the dataset at a time
    if not os.path.exists(f'{self.args.path_data}/{self.args.datasource_type}/lock'):
        with open(f'{self.args.path_data}/{self.args.datasource_type}/lock', 'w') as fp_download_lock:
            fp_download_lock.write('')
        time.sleep(1.0)

    with open(f'{self.args.path_data}/{self.args.datasource_type}/lock', 'r+') as fp_download_lock:
        FileUtils.lock_file(fp_download_lock)

        transform_colors = torchvision.transforms.ToTensor()
        if self.args.datasource_is_grayscale:
            transform_colors = torchvision.transforms.Compose([
                torchvision.transforms.Grayscale(),
                torchvision.transforms.ToTensor()
            ])

        if self.args.datasource_type == 'fassion_mnist':
            self.dataset = torchvision.datasets.FashionMNIST(
                path_data,
                download=True,
                train=not is_test_data,
                transform=torchvision.transforms.ToTensor())
        elif self.args.datasource_type == 'mnist':
            self.dataset = torchvision.datasets.MNIST(
                path_data,
                download=True,
                train=not is_test_data,
                transform=torchvision.transforms.ToTensor())
        elif self.args.datasource_type == 'cifar_10':
            self.dataset = torchvision.datasets.CIFAR10(
                path_data,
                download=True,
                train=not is_test_data,
                transform=transform_colors)
        elif self.args.datasource_type == 'cifar_100':
            self.dataset = torchvision.datasets.CIFAR100(
                path_data,
                download=True,
                train=not is_test_data,
                transform=transform_colors)
        elif self.args.datasource_type == 'emnist':
            # extended MNIST https://arxiv.org/pdf/1702.05373.pdf
            # EMNIST images are stored transposed, so rotate and flip them upright
            self.dataset = torchvision.datasets.EMNIST(
                path_data,
                download=True,
                split='balanced',
                train=not is_test_data,
                transform=torchvision.transforms.Compose([
                    lambda img: torchvision.transforms.functional.rotate(img, -90),
                    lambda img: torchvision.transforms.functional.hflip(img),
                    torchvision.transforms.ToTensor()
                ]))

        FileUtils.unlock_file(fp_download_lock)

    self.classes = np.arange(np.array(self.dataset.targets).max() + 1).tolist()

    # group samples by class label
    groups = [{'samples': [], 'counter': 0} for _ in self.classes]
    for img, label_idx in self.dataset:
        groups[int(label_idx)]['samples'].append(img)
        args.input_size = img.size(1)  # img is (channels, h, w); spatial size
        args.input_features = img.size(0)  # number of channels

    if not is_test_data:
        ids = [int(it) for it in self.args.datasource_exclude_train_class_ids]
        ids = sorted(ids, reverse=True)
        for remove_id in ids:
            del self.classes[remove_id]
            del groups[remove_id]
    else:
        if len(self.args.datasource_include_test_class_ids):
            ids = set(self.classes) - set([int(it) for it in self.args.datasource_include_test_class_ids])
            ids = list(ids)
            ids = sorted(ids, reverse=True)
            for remove_id in ids:
                del self.classes[remove_id]
                del groups[remove_id]

    self.classes = np.array(self.classes, dtype=int)

    self.size_samples = 0
    for idx, group in enumerate(groups):
        samples = group['samples']
        self.size_samples += len(samples)
    self.groups = groups

    # DEBUGGING: optionally reduce the dataset size
    if self.args.datasource_size_samples > 0:
        logging.info(f'debugging: reduced data size {self.args.datasource_size_samples}')
        self.size_samples = self.args.datasource_size_samples

    logging.info(
        f'{self.args.datasource_type} {"test" if is_test_data else "train"}: '
        f'classes: {len(groups)} total triplets: {self.size_samples}')

    if not is_test_data:
        self.args.datasource_classes_train = len(groups)  # override class count

    if self.args.batch_size % self.args.triplet_positives != 0 or self.args.batch_size <= self.args.triplet_positives:
        logging.error(
            f'batch does not accommodate triplet_positives {self.args.batch_size} {self.args.triplet_positives}')
        exit()

    self.reshuffle()
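# Hypothetical construction sketch (not from the original code): the dataset's
# __init__ expects an `args` namespace with at least the fields referenced above.
# The class name `TripletDataset` and all values below are assumptions for
# illustration only.
#
#   import argparse
#   args = argparse.Namespace(
#       path_data='./data',
#       datasource_type='mnist',
#       datasource_is_grayscale=False,
#       datasource_exclude_train_class_ids=[],
#       datasource_include_test_class_ids=[],
#       datasource_size_samples=0,
#       batch_size=64,
#       triplet_positives=4)
#   dataset_train = TripletDataset(args, is_test_data=False)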
def add_results(args, state):
    try:
        if args.report and len(args.report) > 0:
            filename = os.path.join('reports', args.report) + '.csv'

            # create the report CSV with a header row if it does not exist yet
            if not os.path.exists(filename):
                if not os.path.exists('./reports'):
                    os.mkdir('./reports')
                with open(filename, 'w') as outfile:
                    FileUtils.lock_file(outfile)
                    outfile.write(','.join(args.params_report) + '\n')
                    outfile.flush()
                    os.fsync(outfile)
                    FileUtils.unlock_file(outfile)

            lines_all = []
            with open(filename, 'r+') as outfile:
                FileUtils.lock_file(outfile)

                raw_lines = outfile.readlines()
                if len(raw_lines) > 0:
                    header_line = raw_lines[0].strip()
                    headers = header_line.split(',')
                else:
                    headers = args.params_report
                lines_all.append(headers)

                for line in raw_lines[1:]:  # skip the header row, it is already in lines_all
                    line = line.strip()
                    if len(line) > 0 and ',' in line:
                        parts = line.split(',')
                        lines_all.append(parts)

                # build the new row: values are taken from state first, then from args,
                # and left empty when the header key is found in neither
                line_new = []
                for key in headers:
                    if key in state:
                        line_new.append(str(state[key]))
                    elif key in vars(args):
                        line_new.append(str(getattr(args, key)))
                    else:
                        line_new.append('')

                # look for an existing line with the same run id to overwrite
                part_idx_id = headers.index('id')
                is_exist = False
                try:
                    for idx_line in range(1, len(lines_all)):
                        parts = lines_all[idx_line]
                        part_id = parts[part_idx_id]
                        if str(args.id) == part_id.strip():
                            lines_all[idx_line] = line_new
                            is_exist = True
                            break
                except Exception as e:
                    logging.error(str(e))
                    exc_type, exc_value, exc_tb = sys.exc_info()
                    logging.error(traceback.format_exception(exc_type, exc_value, exc_tb))

                if not is_exist:
                    lines_all.append(line_new)

                # rewrite the whole file in place
                outfile.truncate(0)
                outfile.seek(0)
                outfile.flush()
                rows = [','.join(it) for it in lines_all]
                outfile.write('\n'.join(rows))
                outfile.flush()
                os.fsync(outfile)
                FileUtils.unlock_file(outfile)
    except Exception as e:
        logging.error(str(e))
        exc_type, exc_value, exc_tb = sys.exc_info()
        logging.error(traceback.format_exception(exc_type, exc_value, exc_tb))
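# Hypothetical usage sketch for add_results (not part of the original module).
# `args` is expected to carry `report`, `id` and `params_report`, while `state`
# holds values gathered during training; cells are filled by matching the CSV
# header keys against state first and args second. The keys below are assumptions.
#
#   state = {'train_loss': 0.42, 'test_acc': 0.91, 'epoch': 10}
#   add_results(args, state)   # overwrites the row whose 'id' column equals args.id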