Example #1
def parse_summary(summary):
    content = ''
    headers = ['Module Name', 'Input Size', 'Weight Size', 'Output Size', 'Parameters', 'FLOPs']
    records = []
    for key in summary['module']:
        if not summary['module'][key]['params']:
            continue
        module_name = summary['module'][key]['module_name']
        input_size = str(summary['module'][key]['input_size'])
        weight_size = str(summary['module'][key]['params']['weight']['size']) if (
                'weight' in summary['module'][key]['params']) else 'N/A'
        output_size = str(summary['module'][key]['output_size'])
        num_params = 0
        for name in summary['module'][key]['params']:
            num_params += (summary['module'][key]['params'][name]['mask'] > 0).sum().item()
        num_flops = divide_by_unit(summary['module'][key]['flops'])
        records.append([module_name, input_size, weight_size, output_size, num_params, num_flops])
    total_num_param = '{} ({})'.format(summary['total_num_params'], divide_by_unit(summary['total_num_params']))
    total_num_flops = '{} ({})'.format(summary['total_num_flops'], divide_by_unit(summary['total_num_flops']))
    total_space = summary['total_space']
    total = {'num_params': summary['total_num_params'], 'num_flops': summary['total_num_flops'],
             'space': summary['total_space']}
    table = tabulate(records, headers=headers, tablefmt='github')
    content += table + '\n'
    content += '================================================================\n'
    content += 'Total Number of Parameters: {}\n'.format(total_num_param)
    content += 'Total Number of FLOPs: {}\n'.format(total_num_flops)
    content += 'Total Space (MB): {:.2f}\n'.format(total_space)
    makedir_exist_ok('./output')
    with open('./output/summary.md', 'w') as content_file:
        content_file.write(content)
    return content, total
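Every example on this page calls a makedir_exist_ok helper whose definition is not shown. A minimal sketch, assuming the usual os.makedirs-based idiom (this matches the common pre-exist_ok pattern, not necessarily the exact helper these snippets import):

import errno
import os


def makedir_exist_ok(dirpath):
    # Create the directory (and any missing parents) if it does not already exist.
    try:
        os.makedirs(dirpath)
    except OSError as e:
        if e.errno != errno.EEXIST:  # re-raise anything other than "already exists"
            raise

On Python 3.2+, os.makedirs(dirpath, exist_ok=True) has the same effect.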
Example #2
 def download(self):
     if self._check_exists():
         return
     makedir_exist_ok(self.raw_folder)
     filename = os.path.basename(self.url)
     file_path = os.path.join(self.raw_folder, filename)
     download_url(self.url,
                  root=self.raw_folder,
                  filename=filename,
                  md5=None)
     extract_file(file_path)
     vocab = Vocab()
     for split in self.data_file:
         token_path = os.path.join(self.raw_folder, 'wikitext-103',
                                   self.data_file[split])
         num_tokens = 0
         with open(token_path, 'r', encoding='utf-8') as f:
             for line in f:
                 line = line.split() + [u'<eos>']
                 num_tokens += len(line)
                 for symbol in line:
                     vocab.add(symbol)
         with open(token_path, 'r', encoding='utf-8') as f:
             data = torch.LongTensor(num_tokens)
             i = 0
             for line in f:
                 line = line.split() + [u'<eos>']
                 for symbol in line:
                     data[i] = vocab.symbol_to_index[symbol]
                     i += 1
         save(data,
              os.path.join(self.processed_folder, '{}.pt'.format(split)))
     save(vocab, os.path.join(self.processed_folder, 'meta.pt'))
     return
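The Vocab type used above is not defined in this snippet. A minimal sketch of the interface the code relies on (add and symbol_to_index), offered as an assumption rather than the actual implementation:

class Vocab:
    # Minimal stand-in: maps symbols to contiguous integer indices.
    def __init__(self):
        self.symbol_to_index = {}
        self.index_to_symbol = []

    def add(self, symbol):
        # Register a symbol the first time it is seen.
        if symbol not in self.symbol_to_index:
            self.symbol_to_index[symbol] = len(self.index_to_symbol)
            self.index_to_symbol.append(symbol)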
def download_url(url, root, filename, md5):
    from six.moves import urllib

    root = os.path.expanduser(root)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    # download the file unless a verified copy already exists
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(url,
                                       fpath,
                                       reporthook=gen_bar_updater(
                                           tqdm(unit='B', unit_scale=True)))
        except OSError:
            if url[:5] == 'https':
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + fpath)
                urllib.request.urlretrieve(url,
                                           fpath,
                                           reporthook=gen_bar_updater(
                                               tqdm(unit='B',
                                                    unit_scale=True)))
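gen_bar_updater adapts a tqdm bar to the reporthook(count, block_size, total_size) callback that urlretrieve expects. The snippet does not show the real definition; a sketch of what such an adapter can look like:

def gen_bar_updater(pbar):
    # Return a urlretrieve-style reporthook that drives the given tqdm bar.
    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)  # advance by the delta since the last call
    return bar_update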
Example #4
def parse_summary(summary):
    content = ''
    headers = ['Module Name', 'Input Size', 'Weight Size', 'Output Size', 'Number of Parameters']
    records = []
    for key in summary['module']:
        if 'weight' not in summary['module'][key]['params']:
            continue
        module_name = summary['module'][key]['module_name']
        input_size = str(summary['module'][key]['input_size'])
        weight_size = str(summary['module'][key]['params']['weight']['size'])
        output_size = str(summary['module'][key]['output_size'])
        num_params = 0
        for name in summary['module'][key]['params']:
            num_params += (summary['module'][key]['params'][name]['mask'] > 0).sum().item()
        records.append([module_name, input_size, weight_size, output_size, num_params])
    total_num_param = summary['total_num_param']
    total_space_param = summary['total_space_param']

    table = tabulate(records, headers=headers, tablefmt='github')
    content += table + '\n'
    content += '================================================================\n'
    content += 'Total Number of Parameters: {}\n'.format(total_num_param)
    content += 'Total Space of Parameters (MB): {:.2f}\n'.format(total_space_param)
    makedir_exist_ok('./output')
    with open('./output/summary.md', 'w') as content_file:
        content_file.write(content)
    return content
def download_url(url, root, filename, md5):
    from six.moves import urllib
    path = os.path.join(root, filename)
    makedir_exist_ok(root)
    if os.path.isfile(path) and check_integrity(path, md5):
        print('Using downloaded and verified file: ' + path)
    else:
        try:
            print('Downloading ' + url + ' to ' + path)
            urllib.request.urlretrieve(url,
                                       path,
                                       reporthook=make_bar_updater(
                                           tqdm(unit='B', unit_scale=True)))
        except OSError:
            if url[:5] == 'https':
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + path)
                urllib.request.urlretrieve(url,
                                           path,
                                           reporthook=make_bar_updater(
                                               tqdm(unit='B',
                                                    unit_scale=True)))
        if not check_integrity(path, md5):
            raise RuntimeError('Downloaded file failed the integrity check')
    return
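check_integrity is assumed by both download_url variants. A plausible sketch based on an MD5 comparison, treating md5=None as "skip the checksum":

import hashlib
import os


def check_integrity(fpath, md5=None):
    # A file passes if it exists and either no checksum is requested
    # or its MD5 digest matches the expected one.
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    hasher = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b''):
            hasher.update(chunk)
    return hasher.hexdigest() == md5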
 def download(self):
     makedir_exist_ok(self.raw_folder)
     for (url, md5) in self.file:
         filename = os.path.basename(url)
         download_url(url, self.raw_folder, filename, md5)
         extract_file(os.path.join(self.raw_folder, filename))
     return
def download_google(id, root, filename, md5):
    google_url = "https://docs.google.com/uc?export=download"
    path = os.path.join(root, filename)
    makedir_exist_ok(root)
    if os.path.isfile(path) and check_integrity(path, md5):
        print('Using downloaded and verified file: ' + path)
    else:
        session = requests.Session()
        response = session.get(google_url, params={'id': id}, stream=True)
        token = None
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                token = value
                break
        if token:
            params = {'id': id, 'confirm': token}
            response = session.get(google_url, params=params, stream=True)
        with open(path, "wb") as f:
            pbar = tqdm(total=None, unit='B', unit_scale=True)
            for chunk in response.iter_content(32768):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
                    pbar.update(len(chunk))
            pbar.close()
        if not check_integrity(path, md5):
            raise RuntimeError('Downloaded file failed the integrity check')
    return
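A usage sketch for download_google; the file id and names here are placeholders, not values taken from these examples:

# Hypothetical call: fetch a Google Drive file by its id into ./data.
download_google(id='FILE_ID_PLACEHOLDER', root='./data',
                filename='dataset.zip', md5=None)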
 def download(self):
     if self._check_exists():
         return
     makedir_exist_ok(self.raw_folder)
     makedir_exist_ok(self.processed_folder)
     for url in self.urls:
         filename = url.rpartition('/')[2]
         file_path = os.path.join(self.raw_folder, filename)
         download_url(url,
                      root=self.raw_folder,
                      filename=filename,
                      md5=None)
         self.extract_gzip(gzip_path=file_path, remove_finished=True)
     print('Processing...')
     training_set = (
         read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
         read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte')))
     test_set = (
         read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
         read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte')))
     with open(os.path.join(self.processed_folder, self.training_file),
               'wb') as f:
         torch.save(training_set, f)
     with open(os.path.join(self.processed_folder, self.test_file),
               'wb') as f:
         torch.save(test_set, f)
     print('Done!')
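Several of these download methods call an extract_gzip helper. A minimal sketch of such a helper, assuming the .gz suffix convention used above:

import gzip
import os


def extract_gzip(gzip_path, remove_finished=False):
    # Decompress '<name>.gz' to '<name>' next to it, optionally removing the archive.
    out_path = gzip_path[:-len('.gz')]
    with gzip.GzipFile(gzip_path) as zip_f, open(out_path, 'wb') as out_f:
        out_f.write(zip_f.read())
    if remove_finished:
        os.unlink(gzip_path)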
    def download(self):
        import shutil
        import zipfile

        if self._check_exists():
            return

        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)

        filename = self.url.rpartition('/')[2]
        file_path = os.path.join(self.raw_folder, filename)
        download_url(self.url,
                     root=self.raw_folder,
                     filename=filename,
                     md5=None)

        print('Extracting zip archive')
        with zipfile.ZipFile(file_path) as zip_f:
            zip_f.extractall(self.raw_folder)
        os.unlink(file_path)
        gzip_folder = os.path.join(self.raw_folder, 'gzip')
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith('.gz'):
                self.extract_gzip(
                    gzip_path=os.path.join(gzip_folder, gzip_file))

        for split in self.splits:
            print('Processing ' + split)
            training_set = (
                read_image_file(os.path.join(
                    gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(
                    gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split))))
            test_set = (
                read_image_file(os.path.join(
                    gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(
                    gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split))))
            with open(
                    os.path.join(self.processed_folder,
                                 self._training_file(split)), 'wb') as f:
                torch.save(training_set, f)
            with open(
                    os.path.join(self.processed_folder,
                                 self._test_file(split)), 'wb') as f:
                torch.save(test_set, f)
        shutil.rmtree(gzip_folder)

        print('Done!')
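read_image_file and read_label_file parse the big-endian IDX files that MNIST/EMNIST ship in. A sketch of the label reader under that assumption (the image reader is analogous, with magic number 2051 and two extra dimension fields):

import numpy as np
import torch


def read_label_file(path):
    # IDX label files: 4-byte magic (2049), 4-byte count, then uint8 labels.
    with open(path, 'rb') as f:
        data = f.read()
    assert int.from_bytes(data[:4], byteorder='big') == 2049
    length = int.from_bytes(data[4:8], byteorder='big')
    labels = np.frombuffer(data, dtype=np.uint8, offset=8)
    return torch.from_numpy(labels.copy()).view(length).long()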
def make_learning_curve(processed_result):
    ylim_dict = {'iid': {'global': {'MNIST': [95, 100], 'CIFAR10': [50, 100], 'WikiText2': [0, 20]}},
                 'non-iid-2': {'global': {'MNIST': [50, 100], 'CIFAR10': [0, 70]},
                               'local': {'MNIST': [95, 100], 'CIFAR10': [50, 100]}}}
    fig = {}
    for exp_name in processed_result:
        control = exp_name.split('_')
        data_name = control[0]
        metric_name = metric_name_dict[data_name]
        control_name = control[-4]
        if control_name not in ['a5-b5', 'a5-c5', 'a5-d5', 'a5-e5', 'a1-b1', 'a1-c1', 'a1-d1', 'a1-e1']:
            continue
        label_name = '-'.join([t[0] for t in control_name.split('-')])

        def plot_curve(scope, data_mode):
            # scope is 'Local' or 'Global'; data_mode selects the ylim_dict entry.
            y = processed_result[exp_name]['{}-{}_mean'.format(scope, metric_name)]
            x = np.arange(len(y))
            fig_name = '{}_lc_{}'.format('_'.join(control[:-4] + control[-3:]), scope.lower())
            fig[fig_name] = plt.figure(fig_name)
            plt.plot(x, y, linestyle='-', label=label_name)
            plt.legend(loc=loc_dict[data_name], fontsize=fontsize)
            plt.xlabel('Communication rounds', fontsize=fontsize)
            plt.ylabel('Test {}'.format(metric_name), fontsize=fontsize)
            plt.ylim(ylim_dict[data_mode][scope.lower()][data_name])
            plt.xticks(fontsize=fontsize)
            plt.yticks(fontsize=fontsize)

        if 'non-iid-2' in exp_name:
            plot_curve('Local', 'non-iid-2')
            plot_curve('Global', 'non-iid-2')
        else:
            plot_curve('Global', 'iid')
    for fig_name in fig:
        fig[fig_name] = plt.figure(fig_name)
        plt.grid()
        fig_path = '{}/{}.{}'.format(vis_path, fig_name, cfg['save_format'])
        makedir_exist_ok(vis_path)
        plt.savefig(fig_path, dpi=500, bbox_inches='tight', pad_inches=0)
        plt.close(fig_name)
    return
def make_vis(df):
    fig = {}
    for df_name in df:
        if 'fix' not in df_name or 'none' in df_name:
            continue
        control = df_name.split('_')
        data_name = control[0]
        metric_name = metric_name_dict[data_name]
        label_name = control[-1]
        x = df[df_name]['Params_mean']

        def plot_curve(scope, suffix):
            # scope is 'Local' or 'Global'; suffix distinguishes the figure files.
            fig_name = '{}_{}{}'.format('_'.join(control[:-1]), label_name[0], suffix)
            fig[fig_name] = plt.figure(fig_name)
            y = df[df_name]['{}-{}_mean'.format(scope, metric_name)]
            plt.plot(x, y, linestyle='-', marker=marker_dict[label_name], label=label_name)
            plt.legend(loc=loc_dict[data_name], fontsize=fontsize)
            plt.xlabel('Number of Model Parameters', fontsize=fontsize)
            plt.ylabel(metric_name, fontsize=fontsize)
            plt.xticks(fontsize=fontsize)
            plt.yticks(fontsize=fontsize)
            plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))

        if 'non-iid-2' in df_name:
            plot_curve('Local', '_local')
            plot_curve('Global', '_global')
        elif 'iid' in df_name:
            plot_curve('Global', '')
        else:
            raise ValueError('Not valid df name')
    for fig_name in fig:
        fig[fig_name] = plt.figure(fig_name)
        plt.grid()
        fig_path = '{}/{}.{}'.format(vis_path, fig_name, cfg['save_format'])
        makedir_exist_ok(vis_path)
        plt.savefig(fig_path, dpi=500, bbox_inches='tight', pad_inches=0)
        plt.close(fig_name)
    return
Example #12
    def split_train_test(self, dataframe):

        if self._check_exists():
            return
        if self.balance:
            dataframe = self.balance_dataframe(dataframe)
        makedir_exist_ok(self.processed_folder)

        train_df, test_df = train_test_split(dataframe, test_size=0.2)
        with open(os.path.join(self.processed_folder, self.training_file),
                  'wb') as f:
            torch.save(train_df, f)
        with open(os.path.join(self.processed_folder, self.test_file),
                  'wb') as f:
            torch.save(test_df, f)
        print("Split the training and test file successfully! {}".format(
            "It's the balanced data" if self.balance else ""))
Example #13
    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""

        if self._check_exists():
            return

        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)

        # download files
        for url in self.urls:
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.raw_folder, filename)
            download_url(url,
                         root=self.raw_folder,
                         filename=filename,
                         md5=None)
            self.extract_gzip(gzip_path=file_path, remove_finished=True)
def vis(s, control, processed_result):
    makedir_exist_ok(model_path)
    if 'exp' in processed_result:
        data_name = control[0]
        model_name = control[2]
        metric = control[-1]
        save_per_mode = 10
        filenames = ['generate', 'transit', 'create']
        pivot = 'is'
        if metric != pivot:
            return
        best_seed = exp[processed_result['argmax']]
        for filename in filenames:
            if model_name == 'vqvae' or (filename == 'transit'
                                         and 'pixelcnn' in model_name):
                continue
            label = model_name[2:] if 'mc' in model_name else model_name[1:]
            model_tag = '_'.join([best_seed] + control[:-1] + ['best'])
            shutil.copy(
                os.path.join(backup_path, label, 'model',
                             '{}.pt'.format(model_tag)),
                os.path.join(model_path, '{}.pt'.format(model_tag)))
            if 'pixelcnn' in model_name:
                ae_tag = '_'.join([
                    best_seed, control[0], control[1], cfg['ae_name'], 'best'
                ])
                shutil.copy(
                    os.path.join(backup_path, cfg['ae_name'], 'model',
                                 '{}.pt'.format(ae_tag)),
                    os.path.join(model_path, '{}.pt'.format(ae_tag)))
            script_name = '{}.py'.format(filename)
            control_name = '0.5' if 'mc' in model_name else None
            controls = [
                best_seed, data_name, model_name, control_name, save_per_mode
            ]
            s.append(
                'CUDA_VISIBLE_DEVICES="0" python {} --init_seed {} --data_name {} --model_name {} '
                '--control_name {} --save_per_mode {}\n'.format(script_name, *controls))
    else:
        for k, v in processed_result.items():
            vis(s, control + [k], v)
    return
Example #15
    def re_select(self):
        """
        Re-select the image to build the pairs for Siamese networks.
        There are four data set which are going to be built (training):
            - 250 images with highest virality score, 250 images with the lowest virality score.
            - random image pairs:
                - 250 highest images with 250 images randomly selected from the lower part
                - 250 images randomly selected from the highest part and 250 least viral images
                - randomly select 250 images from higher part and lower part respectively.
        This is the main function of the data set.
        """
        if self._check_exists():
            return
        makedir_exist_ok(self.processed_folder)

        # build the image pairs according to the selected pair mode
        if self.pair_mode == 1:
            img_pairs, labels = self._build_img_pair_1()
        elif self.pair_mode == 2:
            img_pairs, labels = self._build_img_pair_2()
        elif self.pair_mode == 3:
            img_pairs, labels = self._build_img_pair_3()
        elif self.pair_mode == 4:
            img_pairs, labels = self._build_img_pair_4()
        else:
            raise ValueError('Not valid pair mode')

        img_pairs_test, labels_test = self._build_test_pairs()
        # store the image pairs and their labels as tuples, then save them into the processed folder
        print('Processing... building image pairs')
        data_pair_labels = (img_pairs, labels)
        data_pair_labels_test = (img_pairs_test, labels_test)

        with open(os.path.join(self.processed_folder, self.training_file),
                  'wb') as f:
            torch.save(data_pair_labels, f)
        with open(os.path.join(self.processed_folder, self.test_file),
                  'wb') as f:
            torch.save(data_pair_labels_test, f)

        print('Done!')
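The _build_img_pair_* builders are not shown. Following the docstring, pair mode 1 pairs the 250 most viral images with the 250 least viral ones; a sketch under that reading, where self.scores and self.images are assumed attributes and the label convention is hypothetical:

def _build_img_pair_1(self):
    # Hypothetical: pair the 250 highest-scored images with the 250 lowest-scored ones.
    order = sorted(range(len(self.scores)), key=lambda i: self.scores[i])
    low, high = order[:250], order[-250:]
    img_pairs = [(self.images[h], self.images[l]) for h, l in zip(high, low)]
    labels = [1] * len(img_pairs)  # 1 = "first image is more viral" (assumed convention)
    return img_pairs, labels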
Example #16
def make_vis(processed_result):
    fontsize = 16
    fig = {}
    vis(fig, [], processed_result)
    for k, v in fig.items():
        metric_name = k.split('_')[-1].split('/')
        save_fig_name = '_'.join(k.split('_')[:-1] + ['_'.join(metric_name)])
        fig[k] = plt.figure(k)
        plt.xlabel('Epoch', fontsize=fontsize)
        plt.ylabel(metric_name[-1], fontsize=fontsize)
        plt.xticks(fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        plt.grid()
        if metric_name[-1] == 'FID':
            plt.legend(loc='upper right', fontsize=fontsize)
        else:
            plt.legend(loc='lower right', fontsize=fontsize)
        fig_path = '{}/{}.{}'.format(vis_path, save_fig_name,
                                     cfg['save_format'])
        makedir_exist_ok(vis_path)
        fig[k].savefig(fig_path, dpi=300, bbox_inches='tight', pad_inches=0)
        plt.close(k)
    return
Example #17
 def __init__(self):
     super(Model, self).__init__()
     self.model = make_model(config.PARAM['model'])
     makedir_exist_ok('./output/tmp')
Example #18
    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""

        if self._check_exists() and self._split_exists():
            print('Files already processed')
            return

        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)

        # download files
        for url in self.urls:
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.raw_folder, filename)
            download_url(url, root=self.raw_folder, filename=filename, md5=None)
            self.extract_gzip(gzip_path=file_path, remove_finished=True)

        # process and save as torch files
        print('Processing...')

        train_set = (
            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.processed_folder, self.train_file), 'wb') as f:
            torch.save(train_set, f)
        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        # split the full training set into train/val splits
        V = 10000  # total validation size
        C = len(self.classes)  # number of classes (10)
        VC = int(V / C + 0.5)  # validation samples per class, rounded
        data, targets = train_set

        vImages = torch.tensor([], dtype=torch.uint8)
        vLabels = torch.tensor([], dtype=torch.long)
        tImages = torch.tensor([], dtype=torch.uint8)
        tLabels = torch.tensor([], dtype=torch.long)
        for c in range(C):
            mask = targets.eq(c)
            index = mask.nonzero().squeeze().tolist()
            vIdx = random.sample(index, VC)  # randomly select the validation subset
            vSet = set(vIdx)
            tIdx = [item for item in index if item not in vSet]
            print(c, ':', len(index), '=', len(vIdx) + len(tIdx))
            vImages = torch.cat([vImages, torch.index_select(data, 0, torch.tensor(vIdx, dtype=torch.long))])
            vLabels = torch.cat([vLabels, torch.index_select(targets, 0, torch.tensor(vIdx, dtype=torch.long))])
            tImages = torch.cat([tImages, torch.index_select(data, 0, torch.tensor(tIdx, dtype=torch.long))])
            tLabels = torch.cat([tLabels, torch.index_select(targets, 0, torch.tensor(tIdx, dtype=torch.long))])

        with open(os.path.join(self.processed_folder, self.val_split_file), 'wb') as f:
            torch.save((vImages, vLabels), f)

        with open(os.path.join(self.processed_folder, self.train_split_file), 'wb') as f:
            torch.save((tImages, tLabels), f)

        print('Done!')
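A quick sanity check after running download(); the concrete file names below stand in for self.val_split_file and self.train_split_file, which this snippet does not define:

import torch

# Hypothetical file names mirroring the attributes used above.
val_images, val_labels = torch.load('processed/val_split.pt')
train_images, train_labels = torch.load('processed/train_split.pt')
assert len(val_labels) == 10000                        # V
assert len(val_labels) + len(train_labels) == 60000    # full MNIST training set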