Example #1
import logging
import os
import sys
import traceback

def create_local(args):
    try:
        if args.name and len(args.name) > 0:
            filename = os.path.join('tasks', args.report, 'runs', args.name, args.name + '.csv')
            if not os.path.exists(filename):
                headers = args.params_report_local
                # write the CSV header once, under an exclusive lock
                with open(filename, 'w') as outfile:
                    FileUtils.lock_file(outfile)
                    outfile.write(','.join(headers) + '\n')
                    outfile.flush()
                    os.fsync(outfile.fileno())
                    FileUtils.unlock_file(outfile)
    except Exception as e:
        logging.error(str(e))
        exc_type, exc_value, exc_tb = sys.exc_info()
        logging.error(''.join(
            traceback.format_exception(exc_type, exc_value, exc_tb)))
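All five examples lean on an external FileUtils helper that is never shown. A minimal sketch of what it could look like on POSIX, assuming an fcntl.flock-based implementation (the class and method names come from the examples; the bodies are an assumption, not the source's code):

    import fcntl
    import os

    class FileUtils:
        @staticmethod
        def lock_file(fp):
            # Block until this process holds an exclusive advisory lock
            # on the open file object (assumed implementation).
            fcntl.flock(fp.fileno(), fcntl.LOCK_EX)

        @staticmethod
        def unlock_file(fp):
            # Release the advisory lock taken by lock_file.
            fcntl.flock(fp.fileno(), fcntl.LOCK_UN)

        @staticmethod
        def createDir(path):
            # Create the directory (and parents) if missing; used in Example #4.
            os.makedirs(path, exist_ok=True)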
Example #2
import logging
import os
import sys
import traceback

def create(args):
    try:
        if args.report and len(args.report) > 0:
            filename = os.path.join('reports', args.report) + '.csv'
            if not os.path.exists(filename):
                # copy the list so appending grid columns does not
                # mutate args.params_report in place
                headers = list(args.params_report)
                if args.params_grid is not None:
                    headers += args.params_grid
                with open(filename, 'w') as outfile:
                    FileUtils.lock_file(outfile)
                    outfile.write(','.join(headers) + '\n')
                    outfile.flush()
                    os.fsync(outfile.fileno())
                    FileUtils.unlock_file(outfile)
    except Exception as e:
        logging.error(str(e))
        exc_type, exc_value, exc_tb = sys.exc_info()
        logging.error(''.join(
            traceback.format_exception(exc_type, exc_value, exc_tb)))
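A hedged usage sketch; the attribute values below are invented for illustration, only the attribute names come from what create() actually reads:

    from types import SimpleNamespace

    # Hypothetical invocation (values are made up):
    args = SimpleNamespace(
        report='grid_search',
        params_report=['id', 'epoch', 'test_loss'],
        params_grid=['learning_rate', 'batch_size'],  # or None
    )
    create(args)  # writes reports/grid_search.csv with the combined header row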
Example #3
    # Static method of a CSV-utils class in the source; needs copy, logging,
    # os and numpy (as np) imported at module level.
    def add_hparams(path_sequence, run_name, args_dict, metrics_dict, global_step):
        try:
            path_local_csv = f'{path_sequence}/{run_name}.csv'
            path_global_csv = f'{path_sequence}/sequence-{os.path.basename(path_sequence)}.csv'

            # work on shallow copies so the caller's dicts are not mutated
            args_dict = copy.copy(args_dict)
            metrics_dict = copy.copy(metrics_dict)

            # keep only scalar values that fit in a CSV cell; np.float and
            # np.int are removed NumPy aliases, so np.floating / np.integer
            # are used to cover np.float32, np.int64, etc.
            for each_dict in [args_dict, metrics_dict]:
                for key in list(each_dict.keys()):
                    if not isinstance(each_dict[key],
                                      (float, int, str, np.floating, np.integer)):
                        del each_dict[key]

            for path_csv in [path_local_csv, path_global_csv]:
                # rows are only written if the CSV file already exists
                if os.path.exists(path_csv):
                    with open(path_csv, 'r+') as outfile:
                        FileUtils.lock_file(outfile)
                        lines_all = outfile.readlines()
                        lines_all = [it.replace('\n', '').split(',') for it in lines_all if ',' in it]
                        if len(lines_all) == 0 or len(lines_all[0]) < 2:
                            headers = ['step'] + list(args_dict.keys()) + list(metrics_dict.keys())
                            headers = [str(it).replace(',', '_') for it in headers]
                            lines_all.append(headers)

                        values = [global_step] + list(args_dict.values()) + list(metrics_dict.values())
                        values = [str(it).replace(',', '_') for it in values]
                        if path_csv == path_local_csv:
                            lines_all.append(values)
                        else:
                            # global CSV: overwrite the row whose hyperparameter
                            # columns match; otherwise append a new row
                            existing_line_idx = -1
                            args_values = list(args_dict.values())
                            args_values = [str(it).replace(',', '_') for it in args_values]
                            for idx_line, line in enumerate(lines_all):
                                if len(line) > 1:
                                    is_match = True
                                    for idx_arg in range(len(args_values)):
                                        if line[idx_arg + 1] != args_values[idx_arg]:
                                            is_match = False
                                            break
                                    if is_match:
                                        existing_line_idx = idx_line
                                        break
                            if existing_line_idx >= 0:
                                lines_all[existing_line_idx] = values
                            else:
                                lines_all.append(values)

                        # rewrite the whole file in place
                        outfile.truncate(0)
                        outfile.seek(0)
                        outfile.flush()
                        rows = [','.join(it) for it in lines_all]
                        rows = [it for it in rows if len(it.replace('\n', '').strip()) > 0]
                        outfile.write('\n'.join(rows).strip())
                        outfile.flush()
                        os.fsync(outfile.fileno())
                        FileUtils.unlock_file(outfile)

        except Exception as e:
            logging.exception(e)
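A hedged usage sketch (paths and dict contents are invented): each call appends one row to the run's local CSV, while in the shared sequence CSV the row whose hyperparameter columns match is overwritten instead of duplicated.

    # Hypothetical call, assuming ./results/seq-42/run-7.csv tracks this run:
    add_hparams(
        path_sequence='./results/seq-42',
        run_name='run-7',
        args_dict={'learning_rate': 1e-3, 'batch_size': 32},
        metrics_dict={'test_acc': 0.91, 'test_loss': 0.27},
        global_step=1000,
    )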
Example #4
    # __init__ of a torch.utils.data.Dataset subclass in the source; needs
    # logging, os, time, numpy (as np) and torchvision imported at module level.
    def __init__(self, args, is_test_data):
        super().__init__()

        self.args = args
        self.is_test_data = is_test_data

        path_data = f'{self.args.path_data}/{self.args.datasource_type}'
        FileUtils.createDir(path_data)

        # create a shared lock file once, then hold an exclusive lock on it so
        # that only one process downloads the dataset at a time
        path_lock = f'{path_data}/lock'
        if not os.path.exists(path_lock):
            with open(path_lock, 'w') as fp_download_lock:
                fp_download_lock.write('')
            time.sleep(1.0)

        with open(path_lock, 'r+') as fp_download_lock:
            FileUtils.lock_file(fp_download_lock)

            transform_colors = torchvision.transforms.ToTensor()
            if self.args.datasource_is_grayscale:
                transform_colors = torchvision.transforms.Compose([
                    torchvision.transforms.Grayscale(),
                    torchvision.transforms.ToTensor()
                ])

            if self.args.datasource_type == 'fassion_mnist':  # (sic) key spelled as in the source config
                self.dataset = torchvision.datasets.FashionMNIST(
                    path_data,
                    download=True,
                    train=not is_test_data,
                    transform=torchvision.transforms.ToTensor())
            elif self.args.datasource_type == 'mnist':
                self.dataset = torchvision.datasets.MNIST(
                    path_data,
                    download=True,
                    train=not is_test_data,
                    transform=torchvision.transforms.ToTensor())
            elif self.args.datasource_type == 'cifar_10':

                self.dataset = torchvision.datasets.CIFAR10(
                    path_data,
                    download=True,
                    train=not is_test_data,
                    transform=transform_colors)
            elif self.args.datasource_type == 'cifar_100':
                self.dataset = torchvision.datasets.CIFAR100(
                    path_data,
                    download=True,
                    train=not is_test_data,
                    transform=transform_colors)
            elif self.args.datasource_type == 'emnist':  # extended mnist https://arxiv.org/pdf/1702.05373.pdf
                self.dataset = torchvision.datasets.EMNIST(
                    path_data,
                    download=True,
                    split='balanced',
                    train=not is_test_data,
                    transform=torchvision.transforms.Compose([
                        lambda img: torchvision.transforms.functional.rotate(
                            img, -90), lambda img: torchvision.transforms.
                        functional.hflip(img),
                        torchvision.transforms.ToTensor()
                    ]))

            FileUtils.unlock_file(fp_download_lock)

        self.classes = np.arange(np.array(self.dataset.targets).max() + 1).tolist()
        groups = [{'samples': [], 'counter': 0} for _ in self.classes]

        # bucket samples by label; iterating the dataset applies the transforms
        for img, label_idx in self.dataset:
            groups[int(label_idx)]['samples'].append(img)

        # img is the last transformed sample, shaped (channels, height, width)
        args.input_size = img.size(1)  # spatial size (square images assumed)
        args.input_features = img.size(0)  # channel count

        if not is_test_data:
            ids = [
                int(it) for it in self.args.datasource_exclude_train_class_ids
            ]
            ids = sorted(ids, reverse=True)
            for remove_id in ids:
                del self.classes[remove_id]
                del groups[remove_id]
        else:
            if len(self.args.datasource_include_test_class_ids):
                ids = set(self.classes) - set([
                    int(it)
                    for it in self.args.datasource_include_test_class_ids
                ])
                ids = list(ids)
                ids = sorted(ids, reverse=True)
                for remove_id in ids:
                    del self.classes[remove_id]
                    del groups[remove_id]

        self.classes = np.array(self.classes, dtype=int)  # np.int is a removed NumPy alias
        self.size_samples = 0
        for idx, group in enumerate(groups):
            samples = group['samples']
            self.size_samples += len(samples)
        self.groups = groups

        # debugging: optionally cap the number of samples
        if self.args.datasource_size_samples > 0:
            logging.info(
                f'debugging: reduced data size {self.args.datasource_size_samples}'
            )
            self.size_samples = self.args.datasource_size_samples

        logging.info(
            f'{self.args.datasource_type} {"test" if is_test_data else "train"}: classes: {len(groups)} total triplets: {self.size_samples}'
        )

        if not is_test_data:
            self.args.datasource_classes_train = len(
                groups)  # override class count

        if self.args.batch_size % self.args.triplet_positives != 0 or self.args.batch_size <= self.args.triplet_positives:
            logging.error(
                f'batch_size {self.args.batch_size} must be a multiple of, and larger than, triplet_positives {self.args.triplet_positives}'
            )
            exit()
        self.reshuffle()
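A hedged construction sketch; the class name TripletDataset and every attribute value below are invented for illustration (only the attribute names come from what __init__ reads):

    from types import SimpleNamespace

    args = SimpleNamespace(
        path_data='./data', datasource_type='mnist',
        datasource_is_grayscale=False,
        datasource_exclude_train_class_ids=['8', '9'],
        datasource_include_test_class_ids=[],
        datasource_size_samples=0,
        batch_size=36, triplet_positives=4,  # 36 % 4 == 0, so the check passes
    )
    dataset_train = TripletDataset(args, is_test_data=False)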
Example #5
    # Static method of a CSV-utils class in the source; needs logging, os, sys
    # and traceback imported at module level.
    def add_results(args, state):
        try:
            if args.report and len(args.report) > 0:
                filename = os.path.join('reports', args.report) + '.csv'

                if not os.path.exists(filename):
                    if not os.path.exists('./reports'):
                        os.mkdir('./reports')
                    with open(filename, 'w') as outfile:
                        FileUtils.lock_file(outfile)
                        outfile.write(','.join(args.params_report) + '\n')
                        outfile.flush()
                        os.fsync(outfile)
                        FileUtils.unlock_file(outfile)

                lines_all = []
                with open(filename, 'r+') as outfile:
                    FileUtils.lock_file(outfile)
                    raw_lines = outfile.readlines()
                    if len(raw_lines) > 0:
                        header_line = raw_lines[0].strip()
                        headers = header_line.split(',')
                    else:
                        headers = args.params_report
                        lines_all.append(headers)

                    for line in raw_lines:
                        line = line.strip()
                        if len(line) > 0 and ',' in line:
                            parts = line.split(',')
                            lines_all.append(parts)

                    # build the new row: take each column from state first,
                    # then fall back to args, else leave the cell empty
                    line_new = []
                    for key in headers:
                        if key in state:
                            line_new.append(str(state[key]))
                        elif key in vars(args):
                            line_new.append(str(getattr(args, key)))
                        else:
                            line_new.append('')

                    # look for existing line to override
                    part_idx_id = headers.index('id')
                    is_exist = False
                    try:
                        for idx_line in range(1, len(lines_all)):
                            parts = lines_all[idx_line]
                            part_id = parts[part_idx_id]
                            if str(args.id) == part_id.strip():
                                lines_all[idx_line] = line_new
                                is_exist = True
                                break
                    except Exception as e:
                        logging.error(str(e))
                        exc_type, exc_value, exc_tb = sys.exc_info()
                        logging.error(''.join(
                            traceback.format_exception(exc_type, exc_value,
                                                       exc_tb)))

                    if not is_exist:
                        lines_all.append(line_new)

                    # rewrite the whole file in place
                    outfile.truncate(0)
                    outfile.seek(0)
                    outfile.flush()
                    rows = [','.join(it) for it in lines_all]
                    outfile.write('\n'.join(rows))
                    outfile.flush()
                    os.fsync(outfile.fileno())
                    FileUtils.unlock_file(outfile)
        except Exception as e:
            logging.error(str(e))
            exc_type, exc_value, exc_tb = sys.exc_info()
            logging.error(''.join(
                traceback.format_exception(exc_type, exc_value, exc_tb)))
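A hedged usage sketch; the args/state contents are invented, and the 'id' column, which add_results uses to overwrite a run's previous row, is assumed to be listed in params_report:

    from types import SimpleNamespace

    args = SimpleNamespace(
        id=7, report='grid_search',
        params_report=['id', 'epoch', 'test_acc'],
    )
    state = {'epoch': 20, 'test_acc': 0.93}
    add_results(args, state)  # the row with id 7 is replaced, not duplicated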