Example #1
 def save(self, save_path=None, items=None):
     path_ = "./config/" if save_path is None else save_path
     ensure_path(path_)
     self.model.save(path_, "model") # save model
     self.trainer.save(path_, "trainer")
     self.optimizer.save(path_, "optimizer") # save optimizer
     self.data_loader.save(path_, "data_loader") # save data_loader
Example #2
    def __init__(self,
                 run_label: str,
                 path_to_results: str,
                 pg_model: str = '') -> None:
        self.path = path_to_results
        self.run_label: str = run_label
        utils.ensure_path(f'{path_to_results}/results.csv')

        utils.ensure_path(f'{path_to_results}/{run_label}/results.csv')

        self.tensorboard_writer = SummaryWriter(
            f'{path_to_results}/tboard_logs', comment=f'pg_{pg_model}')

        self.columns = [
            'env', 'seed', 'run_label', 'run_nr_epochs', 'nr_episodes',
            'run_model', 'timestamp', 'epoch', 'step_size', 'epoch_duration',
            'nr_steps', 'perf', 'hardware'
        ]

        self.results = pd.DataFrame(columns=self.columns)

        if (not os.path.exists(f'{path_to_results}/results.csv')
                or os.path.getsize(f'{path_to_results}/results.csv') == 0):
            with open(f'{self.path}/results.csv', 'w') as f:
                writer = csv.writer(f)
                writer.writerow(self.columns)
Example #3
 def save(self, save_path, save_name):
     ensure_path(save_path)
     #self.update_before_save()
     with open(save_path + save_name, 'wb') as f:
         self.to(torch.device('cpu'))
         torch.save(self.dict, f)
         self.to(self.device)
Example #4
    def plot_weight(self,
                    ax=None,
                    save=True,
                    save_path='./',
                    save_name='RNN_Navi_weight_plot.png',
                    cmap='jet'):
        if ax is None:
            plt.close('all')
            row_num, col_num = 2, 2
            fig, axes = plt.subplots(nrows=row_num,
                                     ncols=col_num,
                                     figsize=(5 * col_num, 5 *
                                              row_num))  # figsize unit: inches

        fig.suptitle('Weight Visualization of 1-layer RNN')

        # plot recurrent weight
        ax = axes[0, 0]  # raises an error if row_num==col_num==1

        self.plot_recurrent_weight(ax, cmap)

        # plot input_weight
        if self.init_method in ['linear']:
            ax = axes[0, 1]
        elif self.init_method in ['mlp']:
            pass
        else:
            pass

        plt.tight_layout()
        if save:
            ensure_path(save_path)
            plt.savefig(save_path + save_name)
Example #5
File: Trainers.py Project: wwf194/EINet
    def __init__(self, dict_, load=False, options=None):
        '''
        if options is not None:
            self.receive_options(options)
        else:
            raise Exception('Trainer: options is none.')
        '''
        self.dict = dict_
        
        '''
        self.epoch_now = get_from_dict(self.dict, 'epoch_now', default=self.epoch_start, write_default=True)
        self.epoch_start = get_from_dict(self.dict, 'epoch_start', default=1, write_default=True)
        self.epoch_end = get_from_dict(self.dict, 'epoch_end', default=self.epoch_um, write_default=True)
        '''        
        self.epoch_now = 0
        #print(self.dict.keys())
        self.epoch_num = self.dict['epoch_num']
        self.epoch_end = self.epoch_num - 1

        # save directory setting
        self.save_model_path = search_dict(self.dict, ['save_model_path', 'save_dir_model', 'save_path_model'], 
            default='./SavedModels/', write_default=True, write_default_key='save_model_path')
        #print(self.save_model_path)
        ensure_path(self.save_model_path)

        self.save_model = get_from_dict(self.dict, 'save_model', default=True, write_default=True)
        self.save_after_train = get_from_dict(self.dict, 'save_after_train', default=True, write_default=True)
        self.save_before_train = get_from_dict(self.dict, 'save_before_train', default=True, write_default=True)

        if self.save_model:
            self.save_interval = get_from_dict(self.dict, 'save_model_interval', default=True, write_default=True)

        self.anal_path = search_dict(self.dict, ['anal_path'], default='./', write_default=True)
        #print(self.anal_path)
        ensure_path(self.anal_path)
Example #6
def main(
    original_path: str,
    executable_path: str = EXECUTABLE_PATH,
    destination_path: str = None,
    overwrite: bool = False,
):
    """ Run batch conversion process.

    Args:
        original_path: path to the dataset to convert.
        executable_path: path to the converter executable.
        destination_path: destination path.
        overwrite: if True, any previous .edf
        with the same name will be overwritten.
    """
    print('1 - List the files to convert...')
    files = [
        os.path.abspath(file_) for file_ in list_files(original_path)
        if os.path.basename(file_).lower().endswith('.eeg')
    ]

    n_files = len(files)
    print('{0} file(s) will be converted.'.format(n_files))

    print('2 - Convert files')
    for index, file_ in enumerate(sorted(files), start=1):
        # Destination file path
        if destination_path is None:
            file_destination_path = file_[:-4] + '.EDF'
        else:
            file_destination_path = os.path.join(
                destination_path, os.path.relpath(file_, original_path))

        print(
            '({0}/{1}) Convert "{2}" to "{3}"'.format(
                index,
                n_files,
                file_,
                file_destination_path,
            ), )

        ensure_path(path=os.path.dirname(file_destination_path))

        if os.path.isfile(file_destination_path) and not overwrite:
            print('File has already been converted.')
        else:
            if os.path.isfile(file_destination_path):
                print('File has already been converted (will be overwritten).')
            convert_coh3_to_edf(
                executable_path=executable_path,
                eeg_path=file_,
                edf_path=file_destination_path,
            )

    if n_files:
        print('3 - Kill the converter process(es).')
        os.system(
            'taskkill /f /im {0}'.format(
                os.path.basename(executable_path), ), )
Example #7
 def anal(self, title=None, save_path=None, verbose=True):
     if save_path is None:
         if title is None:
             save_path = self.anal_path + 'epoch=%d/' % (self.epoch_index)
         else:
             save_path = self.anal_path + 'epoch=%s/' % (title)
     ensure_path(save_path)
     self.agent.anal(save_path=save_path, trainer=self)
Example #8
    def __init__(self, dict_, load=False, options=None):
        if options is not None:
            self.receive_options(options)

        self.dict = dict_
        #set_instance_variable(self, self.dict)
        self.epoch_num = self.dict['epoch_num']
        self.batch_num = self.dict['batch_num']
        self.batch_size = self.dict['batch_size']

        if not hasattr(self, 'anal_path'):
            self.anal_path = self.dict.setdefault('anal_path', './anal/')
        '''
        self.epoch_index = get_from_dict(self.dict, 'epoch_index', default=self.epoch_start, write_default=True)
        self.epoch_start = get_from_dict(self.dict, 'epoch_start', default=1, write_default=True)
        self.epoch_end = get_from_dict(self.dict, 'epoch_end', default=self.epoch_um, write_default=True)
        '''
        self.epoch_index = 0
        self.epoch_end = self.epoch_num - 1

        # save directory setting
        self.save_path = search_dict(
            self.dict, ['save_path', 'save_model_path', 'save_dir_model'],
            default='./saved_models/',
            write_default=True,
            write_default_key='save_path')
        ensure_path(self.save_path)

        self.save = search_dict(self.dict, ['save', 'save_model'],
                                default=True,
                                write_default=True)
        self.save_after_train = get_from_dict(self.dict,
                                              'save_after_train',
                                              default=True,
                                              write_default=True)
        self.save_before_train = get_from_dict(self.dict,
                                               'save_before_train',
                                               default=True,
                                               write_default=True)
        self.anal_before_train = get_from_dict(self.dict,
                                               'anal_before_train',
                                               default=True,
                                               write_default=True)

        if self.save:
            self.save_interval = search_dict(
                self.dict, ['save_interval', 'save_model_interval'],
                default=int(self.epoch_num / 10),
                write_default=True)
        '''
        if options is not None:
            self.options = options
            self.set_options()
        '''
        self.test_performs = self.dict['test_performs'] = {}
        self.train_performs = self.dict['train_performs'] = {}

        self.anal_model = self.dict.setdefault('anal_model', True)
Example #9
def cross_val(input_path, output_dir, num_folds, dataset_name="data"):
    kf = KFold(n_splits=num_folds)
    data = np.array(utils.load_text_as_list(input_path))
    fold = 1
    for train_index, test_index in kf.split(data):
        fold_dir = os.path.join(output_dir, "fold_{}".format(fold))
        print("Creating fold {} at {}".format(fold, output_dir))
        data_train, data_test = data[train_index], data[test_index]
        utils.save_list_as_text(data_train, utils.ensure_path(os.path.join(fold_dir, "{}.train".format(dataset_name))))
        utils.save_list_as_text(data_test, utils.ensure_path(os.path.join(fold_dir, "{}.val".format(dataset_name))))
        fold += 1
Example #10
def ensure_path(params):
    """
    Ensure a certain path
    """
    params = utils.format_params(params)
    
    if not 'path' in params:
        abort('No path set')
    
    utils.ensure_path(path=params['path'])
    
    print(green("Ensure path `%s`." % (params['path']))) 
Example #11
def fetch_dict(url, dest):
    assert url.endswith('.tar.bz2'), url
    filename = os.path.basename(url)
    utils.fetch(url, os.path.join(DICT_DIR, filename))
    utils.run(["tar", "xjvf", filename], cwd=DICT_DIR)
    name = filename[:-len('.tar.bz2')]
    path = os.path.join(DICT_DIR, name)
    utils.run(["./configure", "--vars", "DESTDIR=tmp"], cwd=path)
    utils.run(["make"], cwd=path)
    utils.run(["make", "install"], cwd=path)
    result_dir = os.path.join(path, 'tmp/usr/lib/aspell')
    utils.ensure_path(dest)
    for dict_file in os.listdir(result_dir):
        shutil.copy2(os.path.join(result_dir, dict_file), os.path.join(dest, dict_file))
Example #12
def fetch_dict(url, dest):
    assert url.endswith(".tar.bz2"), url
    filename = os.path.basename(url)
    utils.fetch(url, os.path.join(DICT_DIR, filename))
    utils.run(["tar", "xjvf", filename], cwd=DICT_DIR)
    name = filename[: -len(".tar.bz2")]
    path = os.path.join(DICT_DIR, name)
    utils.run(["./configure", "--vars", "DESTDIR=tmp"], cwd=path)
    utils.run(["make"], cwd=path)
    utils.run(["make", "install"], cwd=path)
    result_dir = os.path.join(path, "tmp/usr/lib/aspell")
    utils.ensure_path(dest)
    for dict_file in os.listdir(result_dir):
        shutil.copy2(os.path.join(result_dir, dict_file), os.path.join(dest, dict_file))
Example #13
 def plot_place_cells_coords(self, ax=None, arena=None, save=True, save_path='./', save_name='place_cells_coords.png'):
     arena = self.arenas.get_current_arena() if arena is None else arena
     if ax is None:
         plt.close('all')
         fig, ax = plt.subplots()
     arena.plot_arena(ax, save=False)
     ax.scatter(self.coords_np[:,0], self.coords_np[:,1], marker='d', color=(0.0, 1.0, 0.0), edgecolors=(0.0,0.0,0.0), label='Start Positions') # marker='d' for diamond
     ax.set_title('Place Cells Positions')
     if save:
         ensure_path(save_path)
         #cv.imwrite(save_path + save_name, imgs) # so that origin is in left-bottom corner.
         ensure_path(save_path)
         plt.savefig(save_path + save_name)
         plt.close()
Example #14
    def directory(self):
        """Gets the workflow cache directory.

        .. note::
           The directory is calculated based on the workflow environment variable. If such a variable
           is not defined, then it defaults to:

           Alfred 3:
           ~/Library/Caches/com.runningwithcrayons.Alfred-3/Workflow Data/

           Alfred 2:
           ~/Library/Caches/com.runningwithcrayons.Alfred-2/Workflow Data/

        :return: the workflow cache directory.
        :rtype: ``str``.
        """

        if not self._directory:
            if self.workflow.environment('workflow_cache'):
                self._directory = self.workflow.environment('workflow_cache')
            elif self.workflow.environment('version_build') >= 652:
                self._directory = os.path.join(
                    os.path.expanduser(
                        '~/Library/Caches/com.runningwithcrayons.Alfred-3/Workflow Data/'
                    ), self.workflow.bundle)
            elif self.workflow.environment('version_build') < 652:
                self._directory = os.path.join(
                    os.path.expanduser(
                        '~/Library/Caches/com.runningwithcrayons.Alfred-2/Workflow Data/'
                    ), self.workflow.bundle)

        return ensure_path(self._directory)
Example #15
    def directory(self):
        """Gets the workflow cache directory.

        .. note::
           The directory is calculated based on the workflow environment variable. If such a variable
           is not defined, then it defaults to:

           Alfred 3:
           ~/Library/Caches/com.runningwithcrayons.Alfred-3/Workflow Data/

           Alfred 2:
           ~/Library/Caches/com.runningwithcrayons.Alfred-2/Workflow Data/

        :return: the workflow cache directory.
        :rtype: ``str``.
        """

        if not self._directory:
            if self.workflow.environment('workflow_cache'):
                self._directory = self.workflow.environment('workflow_cache')
            elif self.workflow.environment('version_build') >= 652:
                self._directory = os.path.join(
                    os.path.expanduser('~/Library/Caches/com.runningwithcrayons.Alfred-3/Workflow Data/'),
                    self.workflow.bundle
                )
            elif self.workflow.environment('version_build') < 652:
                self._directory = os.path.join(
                    os.path.expanduser('~/Library/Caches/com.runningwithcrayons.Alfred-2/Workflow Data/'),
                    self.workflow.bundle
                )

        return ensure_path(self._directory)
Example #16
def main(args):
    utils.delete_path(args.log_dir)
    utils.delete_path(args.save_dir)
    utils.ensure_path(args.save_dir)
    utils.ensure_path(args.log_dir)
    utils.write_dict(vars(args), os.path.join(args.save_dir, 'arguments.csv'))

    torch.manual_seed(args.seed)
    cudnn.benchmark = True 
    torch.cuda.manual_seed_all(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

    if args.mode == 'train':
        train(args)
    elif args.mode == 'test':
        test(args)
Example #17
    def plot_arena_plt(self, ax=None, save=True, save_path='./', save_name='arena_random_xy.png', line_width=2, color=(0,0,0)):
        if ax is None:
            figure, ax = plt.subplots()

        ax.set_xlim(self.x0, self.x1)
        ax.set_ylim(self.y0, self.y1)
        ax.set_xticks(np.linspace(self.x0, self.x1, 5))
        ax.set_yticks(np.linspace(self.y0, self.y1, 5))
        ax.set_aspect(1)

        circle = plt.Circle(self.center_coord, self.radius, color=(0.0, 0.0, 0.0), fill=False)        
        ax.add_patch(circle)
        if save:
            ensure_path(save_path)
            plt.savefig(save_path + save_name)

        return ax
Example #18
    def plot_random_xy_plt(self, ax=None, save=True, save_path='./', save_name='arena_random_xy.png', num=100, color=(0.0,1.0,0.0), plot_arena=True, **kw):
        if ax is None:
            figure, ax = plt.subplots()
            if plot_arena:
                self.plot_arena_plt(ax, save=False)
            else:
                ax.set_xlim(self.x0, self.x1)
                ax.set_ylim(self.y0, self.y1)
                ax.set_aspect(1) # so that x and y have the same unit length in the image.
        points = self.get_random_xy(num=num)
        ax.scatter(points[:,0], points[:,1], marker='o', color=color, edgecolors=(0.0,0.0,0.0), label='Start Positions')
        plt.legend()

        if save:
            ensure_path(save_path)
            plt.savefig(save_path + save_name)
        return ax
Example #19
 def plot_random_xy_cv(self, img=None, save=True, save_path='./', save_name='arena_random_xy.png', num=100, color=(0,255,0), plot_arena=True, **kw):
     if img is None:
         res = search_dict(kw, ['res', 'resolution'], default=100)
         res_x, res_y = get_res_xy(res, self.width, self.height)
         if plot_arena:
             img = self.plot_arena(save=False, res=res)
         else:
             img = np.zeros([res_x, res_y, 3], dtype=np.uint8)
             img[:,:,:] = (255, 255, 255)
     else:
         res_x, res_y = img.shape[0], img.shape[1]
      points = self.get_random_xy(num=num)
     points_int = get_int_coords_np(points, self.xy_range, res_x, res_y)
     for point in points_int:
         cv.circle(img, (point[0], point[1]), radius=0, color=color, thickness=4)
     if save:
         ensure_path(save_path)
         cv.imwrite(save_path + save_name, img[:, ::-1, :])
     return img
Example #20
def copy_project_files(args):
    path = args.path
    if path is None:
        raise Exception(
            'copy_project_files: args.path must not be none. please give path to copy files to'
        )
    ensure_path(args.path)
    if args.param_path is None:
        param_path = './params/'
    else:
        param_path = args.param_path
    print(path)
    if not path.endswith('/'):
        path += '/'
    file_list = [
        #'cmd.py',
        'Models',
        'Optimizers',
        'Trainers.py',
        'DataLoaders.py',
        #'Analyzer.py',
        'utils.py',
        'utils_anal.py',
        'utils_model.py',
        'config_sys.py',
    ]
    copy_files(file_list, path_from='./src/', path_to=path + 'src/')
    file_list = ['main.py', 'params/__init__.py']
    copy_files(file_list, path_from='./', path_to=path)
    param_files = get_param_files(args)
    #param_files = list(map(lambda file:param_path + file, param_files))
    model_file = param_files['model_file']
    optimizer_file = param_files['optimizer_file']
    trainer_file = param_files['trainer_file']
    data_loader_file = param_files['data_loader_file']
    component_files = [
        model_file, optimizer_file, trainer_file, data_loader_file
    ]
    if param_files.get('config_file') is not None:
        component_files.append(param_files['config_file'])
    #print(component_files)
    copy_files(component_files,
               path_from=param_path,
               path_to=path + param_path)
Example #21
File: video.py Project: kunyilu/eta
    def run(self, inpath, outpath):
        '''Run the ffmpeg binary with the specified input/outpath paths.

        Args:
            inpath: the input path. If inpath is "-", input streaming mode is
                activated and data can be passed via the stream() method
            outpath: the output path. Existing files are overwritten, and the
                directory is created if needed. If outpath is "-", output
                streaming mode is activated and data can be read via the
                read() method

        Raises:
            ExecutableNotFoundError: if the ffmpeg binary cannot be found
            ExecutableRuntimeError: if the ffmpeg binary raises an error during
                execution
        '''
        self.is_input_streaming = (inpath == "-")
        self.is_output_streaming = (outpath == "-")

        self._args = (
            [self._executable] +
            self._global_opts +
            self._in_opts + ["-i", inpath] +
            self._out_opts + [outpath]
        )

        if not self.is_output_streaming:
            utils.ensure_path(outpath)

        try:
            self._p = Popen(self._args, stdin=PIPE, stdout=PIPE, stderr=PIPE)
        except OSError as e:
            if e.errno == errno.ENOENT:
                raise utils.ExecutableNotFoundError(self._executable)
            else:
                raise

        # Run non-streaming jobs immediately
        if not (self.is_input_streaming or self.is_output_streaming):
            err = self._p.communicate()[1]
            if self._p.returncode != 0:
                raise utils.ExecutableRuntimeError(self.cmd, err)
Example #22
File: video.py Project: kunyilu/eta
    def __init__(self, outpath, fps, size):
        '''Constructs a VideoWriter with OpenCV backend.

        Args:
            outpath: the output video path, e.g., "/path/to/video.mp4". Existing
                files are overwritten, and the directory is created if needed
            fps: the frame rate
            size: the (width, height) of each frame

        Raises:
            VideoWriterError: if the writer failed to open
        '''
        self.outpath = outpath
        self.fps = fps
        self.size = size
        self._writer = cv2.VideoWriter()

        utils.ensure_path(self.outpath)
        self._writer.open(self.outpath, -1, self.fps, self.size, True)
        if not self._writer.isOpened():
            raise VideoWriterError("Unable to open '%s'" % self.outpath)
Example #23
File: Trainers.py Project: wwf194/EINet
    def plot_perform(self, save_path='./', save_name='perform.png', col_num=3):
        # plot test_performs
        epochs = np.sort(np.array(list(self.test_performs.keys())))
        items = list(self.test_performs[epochs[0]].keys())
        item_num = len(items)

        row_num = item_num // col_num
        if item_num % col_num > 0:
            row_num += 1  # extra row for the remaining items

        fig, axes = plt.subplots(nrows=row_num, ncols=col_num)

        for index, item in enumerate(items):
            row_index = index // col_num
            col_index = index % col_num
            ax = axes[row_index, col_index]
            ax.set_title(item)
        plt.suptitle('%s Test Performance' % self.model.dict['name'])
        ensure_path(save_path)
        plt.savefig(save_path + save_name)
Example #24
def clone(params):
    """
    The clone command can be used to clone a new repo.
    If it's already an existing path it will prompt for overwrite
    """
    print(yellow("Warning git.clone is deprecated from version 1.0"))

    if 'repo_path' in params:
        abort(red("repo_path is deprecated, use git_repo_path"))
        
    if 'repo_url' in params:
        abort(red("repo_url is deprecated, use git_repo_url"))
      
    # Try to get global params
    params = utils.get_global_params(params,
                                     'git_repo_path', 
                                     'git_repo_url')
    
    if 'git_repo_path' not in params:
        abort(red("git_repo_path can't be empty?"))
    
    if 'git_repo_url' not in params:
        abort(red("git_repo_url can't be empty?"))
    
    params = utils.format_params(params)
    
    if exists(params['git_repo_path']):
        if confirm("Repo path `%s` found, do you want to reinstall?" % params['git_repo_path']):
            print(yellow("Repo path `%s` will be deleted" % params['git_repo_path']))
            run('rm -Rf %s' % params['git_repo_path'])
        else:
            abort("Aborted...")
        
    utils.ensure_path(params['git_repo_path'])
    
    run('git clone --recursive %s %s' % (params['git_repo_url'], params['git_repo_path']))
         
    print(green("Repo `%s` successfully cloned" % params['git_repo_url']))
Example #25
    def test_ensure_path(self):
        """ Test the function ensure_path. """
        with tempfile.TemporaryDirectory() as tmpdirname:
            subfolder_path = os.path.join(tmpdirname, 'folder', 'subfolder')
            folder_path = os.path.join(tmpdirname, 'folder')

            # No folder and subfolder should exist
            self.assertEqual(os.path.exists(subfolder_path), False)
            self.assertEqual(os.path.exists(folder_path), False)

            # Create folder and subfolder
            ensure_path(subfolder_path)

            # Folder and subfolder should exist
            self.assertEqual(os.path.exists(subfolder_path), True)
            self.assertEqual(os.path.exists(folder_path), True)

            # Should not change anything
            ensure_path(folder_path)

            # Folder and subfolder should exist
            self.assertEqual(os.path.exists(subfolder_path), True)
            self.assertEqual(os.path.exists(folder_path), True)
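Every snippet in this collection delegates directory creation to an `ensure_path` helper, and the test above pins down the behavior most of them rely on: nested directories are created on demand and repeated calls are harmless. Below is a minimal sketch consistent with that test; the use of `os.makedirs(..., exist_ok=True)` and the returned path are assumptions, since each project ships its own variant (some also accept file paths, add logging, or prompt before overwriting).

import os


def ensure_path(path):
    """Create `path` (and any missing parent directories) if needed, then return it.

    Minimal sketch only: it matches the behavior exercised by test_ensure_path above,
    but the actual implementations in the projects listed here may differ in detail.
    """
    os.makedirs(path, exist_ok=True)  # no-op when the directory already exists
    return path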
Example #26
def backup_db(params):
    """" 
    This command backups the database based on a backup folder
    The output dump will be a iso date formatted filename
    """
    print yellow("Warning mysql.backup_db is deprecated from version 1.0")
    params = utils.format_params(params)
    
    command = """
    mysqldump -h %(host)s -u %(user)s --password='******' %(database)s > %(backup_file)s
    """
    
    backup_path = os.path.dirname(params['backup_file'])
    utils.ensure_path(backup_path)
     
    # Make params
    command_params = {'user': params['user'],
                      'password': params['password'],
                      'database': params['database'],
                      'host': params['host'],
                      'backup_file':params['backup_file']}
    
    with hide('running'):
        run(command % command_params)      
    
    with cd(backup_path):
        filename = os.path.basename(params['backup_file'])
        clean_filename = os.path.splitext(filename)[0]
        tarfilename = "%s.tar.gz" % clean_filename
        run("tar czvf %s %s" % (tarfilename, filename))
        run("rm %s" % filename)
        
    full_tar_file_path = "%s/%s" % (backup_path ,tarfilename)
    print(green("Mysql backup `%s` successfully stored." % full_tar_file_path)) 
    
    if 'download_tar_to_local_file' in params:
        get(remote_path=full_tar_file_path, local_path=params['download_tar_to_local_file'])
Example #27
    def plot_arena_plt(self, ax=None, save=True, save_path='./', save_name='arena_random_xy.png', line_width=2, color=(0,0,0)):
        if ax is None:
            plt.close('all')
            fig, ax = plt.subplots()

        ax.set_xlim(self.x0, self.x1)
        ax.set_ylim(self.y0, self.y1)
        ax.set_xticks(np.linspace(self.x0, self.x1, 5))
        ax.set_yticks(np.linspace(self.y0, self.y1, 5))
        ax.set_aspect(1)
        
        vertices = self.vertices
        vertex_num = self.vertices.shape[0]

        for i in range(vertex_num):
            vertex_0 = vertices[i]
            vertex_1 = vertices[(i+1)%vertex_num]
            ax.add_line(Line2D([vertex_0[0], vertex_1[0]], [vertex_0[1], vertex_1[1]], linewidth=line_width, color=color))

        if save:
            ensure_path(save_path)
            plt.savefig(save_path + save_name)

        return ax
Example #28
    def plot_arena_cv(self, img=None, line_color=(0,0,0), line_width=1, line_type=4, save=False, save_path='./', save_name='arena_plot.png', **kw):# line_color: (b, g, r)
        if img is None:
            res = search_dict(kw, ['res', 'resolution'], default=100)
            res_x, res_y = get_res_xy(res, self.width, self.height)
            img = np.zeros([res_x, res_y, 3], dtype=np.uint8)
            img[:,:,:] = (255, 255, 255)

        vertices, width, height = self.dict['vertices'], self.width, self.height
        vertex_num = vertices.shape[0]
        
        #print('xy_range:%s'%str(xy_range))
        
        for i in range(vertex_num):
            vertex_0 = get_int_coords(vertices[i, 0], vertices[i, 1], self.xy_range, res_x, res_y)
            vertex_1 = get_int_coords(vertices[(i+1)%vertex_num, 0], vertices[(i+1)%vertex_num, 1], self.xy_range, res_x, res_y)
            #print('plot line: (%d, %d) to (%d, %d)'%(vertex_0[0], vertex_0[1], vertex_1[0], vertex_1[1]))
            cv.line(img, vertex_0, vertex_1, line_color, line_width, line_type) # line_width seems to be in unit of pixel.
        
        if save:
            ensure_path(save_path)
            #if save_name is None:
            #    save_name = 'arena_plot.png'
            cv.imwrite(save_path + save_name, img[:, ::-1, :]) # so that origin is in left-bottom corner.
        return img
Example #29
def main():
    # Basic configuration
    data_dir = utils.ensure_path('~/data_analysis/fcb/')
#     group_id = 153748404666241
    group_id = 597682743580084
    
    # Load basic arguments
    log("Parsing basic arguments")
    missing_arg = utils.check_required_arg(utils.login_opt, utils.password_opt, utils.ch_driver_opt)
    if missing_arg is not None:
        utils.exit_program('Missing required argument ' + missing_arg)
    login = utils.get_arg(utils.login_opt)
    password = utils.get_arg(utils.password_opt)
    driver_path = utils.get_arg(utils.ch_driver_opt)


    log('Starting browser')
    browser = webdriver.Chrome(executable_path=driver_path)

    scrapper = FcbBrowserScrapper(browser, login, password)
    scrapper.log_in()
    
    # Scrap
    users = scrapper.scrap_group_members(group_id)

#     Temporary code for loading user from file rather from the web
#     to speed up the development
#     users_file = utils.ensure_path('~/data_analysis/mlyny_group.txt')
#     utils.save_data(users, users_file)
#     users = []
#     with open(users_file, encoding='utf-8') as f:
#         for line in f:
#             if line == '':
#                 break
#             s = line.split(',')
#             users.append((s[0], s[1]))

    # Process
    scrapper.process_users(users, data_dir)

    log('Closing the browser')
    browser.close()
    log('The end')
Example #30
            str(args.shot),
            str(args.query),
            str(args.way),
            str(args.validation_way),
            str(args.step_size),
            str(args.gamma),
            str(args.lr),
            str(args.temperature),
            str(args.hyperbolic),
            str(args.dim),
            str(args.c)[:5],
            str(args.train_c),
            str(args.train_x)
        ])
        args.save_path = save_path1 + '_' + save_path2 + '_' + args.exp_addendum
        ensure_path(args.save_path)
    else:
        ensure_path(args.save_path)

    if args.dataset == 'MiniImageNet':
        # Handle MiniImageNet
        from dataloader.mini_imagenet import MiniImageNet as Dataset
    elif args.dataset == 'CUB':
        from dataloader.cub import CUB as Dataset
    else:
        raise ValueError('Non-supported Dataset.')

    # train n_batch is 100 by default, val n_batch is 500 by default
    trainset = Dataset('train', args)
    train_sampler = CategoriesSampler(trainset.label, 100, args.way,
                                      args.shot + args.query)
Example #31
WINE_RN_WIN_DIR = os.path.join(WINE_RN_DIR, 'win')
PYINSTALLER = os.path.join(DRIVE_C, 'PyInstaller-2.1', 'pyinstaller.py')
SPEC = os.path.join(BASE_DIR, 'win', 'rednotebook.spec')
WINE_SPEC = os.path.join(WINE_RN_WIN_DIR, 'rednotebook.spec')
WINE_BUILD = os.path.join(DRIVE_C, 'build')
WINE_DIST = os.path.join(DRIVE_C, 'dist')
LOCALE_DIR = os.path.join(WINE_DIST, 'share', 'locale')
WINE_RN_EXE = os.path.join(WINE_DIST, 'rednotebook.exe')
WINE_PYTHON = os.path.join(DRIVE_C, 'Python27', 'python.exe')

if os.path.exists(WINE_DIR):
    answer = raw_input('The build dir exists. Overwrite it? (Y/n): ').strip()
    if answer and answer.lower() != 'y':
        sys.exit('Aborting')
    shutil.rmtree(WINE_DIR)
os.environ['WINEPREFIX'] = WINE_DIR
os.mkdir(WINE_DIR)
run(['tar', '-xzf', WINE_TARBALL, '--directory', WINE_DIR])

archive = '/tmp/rednotebook-archive.tar'
run(['git', 'archive', 'HEAD', '-o', archive], cwd=BASE_DIR)
utils.ensure_path(WINE_RN_DIR)
run(['tar', '-xf', archive], cwd=WINE_RN_DIR)
shutil.copy2(SPEC, WINE_SPEC)

run(['wine', WINE_PYTHON, PYINSTALLER, '--workpath', WINE_BUILD,
     '--distpath', DRIVE_C, WINE_SPEC])  # will be built at ...DRIVE_C/dist
run(['./build-translations.py', LOCALE_DIR], cwd=DIR)

#run(['wine', WINE_RN_EXE])
Example #32
BASE_DIR = os.path.dirname(DIR)
DIST_DIR = os.path.abspath(args.dist_dir)
DRIVE_C = os.path.join(DIST_DIR, 'drive_c')
BUILD_DIR = os.path.abspath(args.build_dir)
assert os.path.exists(BUILD_DIR), BUILD_DIR
WINE_RN_DIR = os.path.join(DRIVE_C, 'rednotebook')
SPEC = os.path.join(BASE_DIR, 'win', 'rednotebook.spec')
WINE_BUILD = os.path.join(DRIVE_C, 'build')
WINE_DIST = os.path.join(DRIVE_C, 'dist')
LOCALE_DIR = os.path.join(WINE_DIST, 'share', 'locale')
WINE_RN_EXE = os.path.join(WINE_DIST, 'rednotebook.exe')
WINE_PYTHON = os.path.join(DRIVE_C, 'Python34', 'python.exe')

utils.confirm_overwrite(DIST_DIR)
os.environ['WINEPREFIX'] = DIST_DIR
utils.ensure_path(os.path.dirname(DIST_DIR))
print('Start copying {} to {}'.format(BUILD_DIR, DIST_DIR))
utils.fast_copytree(BUILD_DIR, DIST_DIR)
print('Finished copying')

archive = '/tmp/rednotebook-archive.tar'
stash_name = utils.get_output(['git', 'stash', 'create'], cwd=BASE_DIR)
stash_name = stash_name or 'HEAD'
run(['git', 'archive', stash_name, '-o', archive], cwd=BASE_DIR)
utils.ensure_path(WINE_RN_DIR)
run(['tar', '-xf', archive], cwd=WINE_RN_DIR)

os.mkdir(os.path.join(DRIVE_C, 'Python34/share'))
shutil.copytree(os.path.join(DRIVE_C, 'Python34/Lib/site-packages/gnome/share/gir-1.0/'), os.path.join(DRIVE_C, 'Python34/share/gir-1.0/'))

run(['wine', WINE_PYTHON, '-m', 'PyInstaller', '--workpath', WINE_BUILD,
Example #33
def main(config):
    svname = args.name
    if svname is None:
        svname = 'classifier_{}'.format(config['train_dataset'])
        svname += '_' + config['model_args']['encoder']
        clsfr = config['model_args']['classifier']
        if clsfr != 'linear-classifier':
            svname += '-' + clsfr
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    augmentations = [
        transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.RandomRotation(35),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    ]
    train_dataset.transform = augmentations[int(config['_a'])]
    print(train_dataset.transform)
    print("_a", config['_a'])
    input("Continue with these augmentations?")

    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=0,
                              pin_memory=True)
    utils.log('train dataset: {} (x{}), {}'.format(train_dataset[0][0].shape,
                                                   len(train_dataset),
                                                   train_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = datasets.make(config['val_dataset'],
                                    **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=0,
                                pin_memory=True)
        utils.log('val dataset: {} (x{}), {}'.format(val_dataset[0][0].shape,
                                                     len(val_dataset),
                                                     val_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    else:
        eval_val = False

    # few-shot eval
    if config.get('fs_dataset'):
        ef_epoch = config.get('eval_fs_epoch')
        if ef_epoch is None:
            ef_epoch = 5
        eval_fs = True

        fs_dataset = datasets.make(config['fs_dataset'],
                                   **config['fs_dataset_args'])
        utils.log('fs dataset: {} (x{}), {}'.format(fs_dataset[0][0].shape,
                                                    len(fs_dataset),
                                                    fs_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(fs_dataset, 'fs_dataset', writer)

        n_way = 5
        n_query = 15
        n_shots = [1, 5]
        fs_loaders = []
        for n_shot in n_shots:
            fs_sampler = CategoriesSampler(fs_dataset.label,
                                           200,
                                           n_way,
                                           n_shot + n_query,
                                           ep_per_batch=4)
            fs_loader = DataLoader(fs_dataset,
                                   batch_sampler=fs_sampler,
                                   num_workers=0,
                                   pin_memory=True)
            fs_loaders.append(fs_loader)
    else:
        eval_fs = False

    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if eval_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1 + 1):
        if epoch == max_epoch + 1:
            if not config.get('epoch_ex'):
                break
            train_dataset.transform = train_dataset.default_transform
            print(train_dataset.transform)
            train_loader = DataLoader(train_dataset,
                                      config['batch_size'],
                                      shuffle=True,
                                      num_workers=0,
                                      pin_memory=True)

        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va']
        if eval_fs:
            for n_shot in n_shots:
                aves_keys += ['fsa-' + str(n_shot)]
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):
            # for data, label in train_loader:
            data, label = data.cuda(), label.cuda()
            logits = model(data)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        if eval_val:
            model.eval()
            for data, label in tqdm(val_loader, desc='val', leave=False):
                data, label = data.cuda(), label.cuda()
                with torch.no_grad():
                    logits = model(data)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for i, n_shot in enumerate(n_shots):
                np.random.seed(0)
                for data, _ in tqdm(fs_loaders[i],
                                    desc='fs-' + str(n_shot),
                                    leave=False):
                    x_shot, x_query = fs.split_shot_query(data.cuda(),
                                                          n_way,
                                                          n_shot,
                                                          n_query,
                                                          ep_per_batch=4)
                    label = fs.make_nk_label(n_way, n_query,
                                             ep_per_batch=4).cuda()
                    with torch.no_grad():
                        logits = fs_model(x_shot, x_query).view(-1, n_way)
                        acc = utils.compute_acc(logits, label)
                    aves['fsa-' + str(n_shot)].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}'.format(aves['vl'], aves['va'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            log_str += ', fs'
            for n_shot in n_shots:
                key = 'fsa-' + str(n_shot)
                log_str += ' {}: {:.4f}'.format(n_shot, aves[key])
                writer.add_scalars('acc', {key: aves[key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()
Example #34
def main(config):
    # Environment setup
    save_dir = config['save_dir']
    utils.ensure_path(save_dir)
    with open(osp.join(save_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f, sort_keys=False)
    global log, writer
    logger = set_logger(osp.join(save_dir, 'log.txt'))
    log = logger.info
    writer = SummaryWriter(osp.join(save_dir, 'tensorboard'))

    os.environ['WANDB_NAME'] = config['exp_name']
    os.environ['WANDB_DIR'] = config['save_dir']
    if not config.get('wandb_upload', False):
        os.environ['WANDB_MODE'] = 'dryrun'
    t = config['wandb']
    os.environ['WANDB_API_KEY'] = t['api_key']
    wandb.init(project=t['project'], entity=t['entity'], config=config)

    log('logging init done.')
    log(f'wandb id: {wandb.run.id}')

    # Dataset, model and optimizer
    train_dataset = datasets.make((config['train_dataset']))
    test_dataset = datasets.make((config['test_dataset']))

    model = models.make(config['model'], args=None).cuda()
    log(f'model #params: {utils.compute_num_params(model)}')

    n_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
    if n_gpus > 1:
        model = nn.DataParallel(model)

    optimizer = utils.make_optimizer(model.parameters(), config['optimizer'])

    train_loader = DataLoader(train_dataset, config['batch_size'], shuffle=True,
                              num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_dataset, config['batch_size'],
                             num_workers=8, pin_memory=True)

    # Ready for training
    max_epoch = config['max_epoch']
    n_milestones = config.get('n_milestones', 1)
    milestone_epoch = max_epoch // n_milestones
    min_test_loss = 1e18

    sample_batch_train = sample_data_batch(train_dataset).cuda()
    sample_batch_test = sample_data_batch(test_dataset).cuda()

    epoch_timer = utils.EpochTimer(max_epoch)
    for epoch in range(1, max_epoch + 1):
        log_text = f'epoch {epoch}'

        # Train
        model.train()

        adjust_lr(optimizer, epoch, max_epoch, config)
        log_temp_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        ave_scalars = {k: utils.Averager() for k in ['loss']}

        pbar = tqdm(train_loader, desc='train', leave=False)
        for data in pbar:
            data = data.cuda()
            t = train_step(model, data, data, optimizer)
            for k, v in t.items():
                ave_scalars[k].add(v, len(data))
            pbar.set_description(desc=f"train loss:{t['loss']:.4f}")

        log_text += ', train:'
        for k, v in ave_scalars.items():
            v = v.item()
            log_text += f' {k}={v:.4f}'
            log_temp_scalar('train/' + k, v, epoch)

        # Test
        model.eval()

        ave_scalars = {k: utils.Averager() for k in ['loss']}

        pbar = tqdm(test_loader, desc='test', leave=False)
        for data in pbar:
            data = data.cuda()
            t = eval_step(model, data, data)
            for k, v in t.items():
                ave_scalars[k].add(v, len(data))
            pbar.set_description(desc=f"test loss:{t['loss']:.4f}")

        log_text += ', test:'
        for k, v in ave_scalars.items():
            v = v.item()
            log_text += f' {k}={v:.4f}'
            log_temp_scalar('test/' + k, v, epoch)

        test_loss = ave_scalars['loss'].item()

        if epoch % milestone_epoch == 0:
            with torch.no_grad():
                pred = model(sample_batch_train).clamp(0, 1)
                video_batch = torch.cat([sample_batch_train, pred], dim=0)
                log_temp_videos('train/videos', video_batch, epoch)
                img_batch = video_batch[:, :, 3, :, :]
                log_temp_images('train/images', img_batch, epoch)

                pred = model(sample_batch_test).clamp(0, 1)
                video_batch = torch.cat([sample_batch_test, pred], dim=0)
                log_temp_videos('test/videos', video_batch, epoch)
                img_batch = video_batch[:, :, 3, :, :]
                log_temp_images('test/images', img_batch, epoch)

        # Summary and save
        log_text += ', {} {}/{}'.format(*epoch_timer.step())
        log(log_text)

        model_ = model.module if n_gpus > 1 else model
        model_spec = config['model']
        model_spec['sd'] = model_.state_dict()
        optimizer_spec = config['optimizer']
        optimizer_spec['sd'] = optimizer.state_dict()
        pth_file = {
            'model': model_spec,
            'optimizer': optimizer_spec,
            'epoch': epoch,
        }

        if test_loss < min_test_loss:
            min_test_loss = test_loss
            wandb.run.summary['min_test_loss'] = min_test_loss
            torch.save(pth_file, osp.join(save_dir, 'min-test-loss.pth'))

        torch.save(pth_file, osp.join(save_dir, 'epoch-last.pth'))

        writer.flush()
Example #35
 def write_info(self, save_path='./', save_name='Arena Dict.txt'):
     ensure_path(save_path)
     utils.write_dict()
     with open(save_path + save_name, 'w') as f:
         return # to be implemented
Example #36
    def plot_place_cells(self, act_map=None, arena=None, res=50, plot_num=100, col_num=15, save=True, save_path='./', save_name='place_cells_plot.png', cmap='jet'):
        arena = self.arena if arena is None else arena
        
        if act_map is None:
            act_map, arena_mask = self.get_act_map(arena=arena, res=res) # [N_num, res_x, res_y]
        else:
            res_x, res_y = act_map.shape[1], act_map.shape[2]
            arena_mask = arena.get_mask(res_x=res_x, res_y=res_y)

        act_map = act_map[:, ::-1, :] # when plotting image, default origin is on top-left corner.
        act_max = np.max(act_map)
        act_min = np.min(act_map)
        #print('PlaceCells.plot_place_cells: act_min:%.2e act_max:%.2e'%(act_min, act_max))

        if plot_num < self.N_num:
            plot_index = np.sort(random.sample(range(self.N_num), plot_num)) # default order: ascending.
        else:
            plot_num = self.N_num
            plot_index = range(self.N_num)

        img_num = plot_num + 1
        row_num = ( plot_num + 1 ) // col_num
        if img_num % col_num > 0:
            row_num += 1

        #print('row_num:%d col_num:%d'%(row_num, col_num))
        fig, axes = plt.subplots(nrows=row_num, ncols=col_num, figsize=(5*col_num, 5*row_num))
        
        act_map_norm = ( act_map - act_min ) / (act_max - act_min) # normalize to [0, 1]

        cmap_func = plt.cm.get_cmap(cmap)
        act_map_mapped = cmap_func(act_map_norm) # [N_num, res_x, res_y, (r,g,b,a)]
        for i in range(act_map_mapped.shape[0]):
            act_map_mapped[i,:,:,:] = cv.GaussianBlur(act_map_mapped[i,:,:,:], ksize=(3,3), sigmaX=1, sigmaY=1)
        
        arena_mask_white = (~arena_mask).astype(np.int)[:, :, np.newaxis] * np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float)
        #print(np.sum(arena_mask.astype(np.int)))
        #print(~arena_mask)
        #print(arena_mask_white.shape)
        act_map_mapped = act_map_mapped * arena_mask.astype(np.int)[np.newaxis, :, :, np.newaxis] + arena_mask_white[np.newaxis, :, :, :]
        #arena_mask = arena_mask[np.newaxis, :, :]
        #for i in range(act_map_mapped.shape[0]):
        #    act_map_mapped[i, arena_mask] = (1.0, 1.0, 1.0, 1.0)

        for i in range(plot_num):
            row_index = (i+1) // col_num
            col_index = (i+1) % col_num
            N_index = plot_index[i]
            ax = axes[row_index, col_index]
            im = ax.imshow(act_map_mapped[N_index], extent=(arena.x0, arena.x1, arena.y0, arena.y1)) # extent: rescaling axis to arena size.
            ax.set_xticks(np.linspace(arena.x0, arena.x1, 5))
            ax.set_yticks(np.linspace(arena.y0, arena.y1, 5))
            ax.set_aspect(1)
            ax.set_title('Place Cells No.%d @ (%.2f, %.2f)'%(N_index, self.xy[N_index][0], self.xy[N_index][1]))
            #self.arena.plot_arena(ax, save=False)
            
        for i in range(plot_num + 1, row_num * col_num):
            row_index = i // col_num
            col_index = i % col_num
            ax = axes[row_index, col_index]
            ax.axis('off')
        
        ax = axes[0, 0]
        ax.axis('off')
        norm = mpl.colors.Normalize(vmin=act_min, vmax=act_max)
        ax_ = ax.inset_axes([0.0, 0.4, 1.0, 0.2]) # left, bottom, width, height. all are ratios to sub-canvas of ax.
        cbar = ax.figure.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), 
            cax=ax_, 
            #pad=.05, # ?Fraction of original axes between colorbar and new image axes
            fraction=1.0, 
            ticks=np.linspace(act_min, act_max, num=5),
            aspect = 5, # ratio of colorbar height to width
            anchor=(0.5, 0.5), # coord of anchor point of colorbar
            panchor=(0.5, 0.5), # coord of colorbar's anchor point in parent ax.
            orientation='horizontal')
        cbar.set_label('Average fire rate', loc='center')

        ax.axis('off')
        
        if save:
            ensure_path(save_path)
            #cv.imwrite(save_path + save_name, imgs) # so that origin is in left-bottom corner.
            plt.savefig(save_path + save_name)
            plt.close()
Example #37
    def process_users(self, users, directory):
        
        utils.ensure_path(directory + '/pics/')

        c = 0
        stale_count = 0
        no_such_el_count = 0
        unprocessed_ids = []
        bad_page_ids = []
        html_file = open(directory + '/index.html', 'w', encoding='utf-8')
        html_file.write('<html><body>\n')

        for user in users:
            try:
                c += 1
                log(user)
                user_id = user[1].strip()
                log('Processing ' + user_id + " ({}/{})".format(c, len(users)))

                url = FcbBrowserScrapper.root_url + user_id
                self.browser.get(url)

                a_el = self.browser.find_element_by_class_name('profilePicThumb')
                a_el.click()

                while True:
                    try:
                        img_el = self.browser.find_element_by_class_name('spotlight')
                        break
                    except NoSuchElementException:
                        try:
                            img_el = self.browser.find_element_by_css_selector('img._4-od')
                            break
                        except NoSuchElementException:
                            pass
                    # The ajax call has not finished yet, so wait 1 second
                    # and try again.
                    time.sleep(1)
                
                try:
                    pic_url = img_el.get_attribute('src')
                except StaleElementReferenceException:
                    stale_count += 1
                    utils.log('Stale element encountered ({})'.format(stale_count))
                    unprocessed_ids.append(user_id)
                    continue

                # Retrieve the photo...
                response = requests.get(pic_url)
                # ...and save it to disk
                with open(directory + '/pics/{}.jpg'.format(user_id), 'wb') as photo_file:
                    photo_file.write(response.content)

            except NoSuchElementException:
                utils.log("Failed scrapping page for user {} - no such element".format(user_id))
                unprocessed_ids.append(user_id)
                bad_page_ids.append(user_id)
                no_such_el_count += 1

            
            
            # Add a row for this user to the index page
            page_row = '<a href="{}"><img src="pics/{}" title="{}"></a>'.format(url, user_id + '.jpg', user[0])
            html_file.write(page_row + '\n')

        utils.log('Total staled: {}'.format(stale_count))
        utils.log('Total not found: {}'.format(no_such_el_count))
        utils.log('Unprocessed ids: {}'.format(unprocessed_ids))
        html_file.write('</body></html>')
        html_file.close()
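The manual polling loop above can also be expressed with Selenium's explicit waits. A minimal sketch, assuming the same two selectors; the helper name is hypothetical and not part of the original scraper:

from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait

def wait_for_profile_photo(browser, timeout=30):
    # Poll until either the 'spotlight' image or the 'img._4-od' fallback appears,
    # raising TimeoutException if neither shows up within `timeout` seconds.
    def any_photo_element(drv):
        els = drv.find_elements(By.CLASS_NAME, 'spotlight')
        if not els:
            els = drv.find_elements(By.CSS_SELECTOR, 'img._4-od')
        return els[0] if els else False
    return WebDriverWait(browser, timeout).until(any_photo_element)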
Example #38
def main(config):
    # `args` is expected to be the module-level argparse namespace parsed in __main__
    svname = args.name
    if svname is None:
        svname = 'meta_{}-{}shot'.format(config['train_dataset'],
                                         config['n_shot'])
        svname += '_' + config['model']
        if config['model_args'].get('encoder'):
            svname += '-' + config['model_args']['encoder']
        if config['model_args'].get('prog_synthesis'):
            svname += '-' + config['model_args']['prog_synthesis']
    svname += '-seed' + str(args.seed)
    if args.tag is not None:
        svname += '_' + args.tag

    save_path = os.path.join(args.save_dir, svname)
    utils.ensure_path(save_path, remove=False)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    logger = utils.Logger(file_name=os.path.join(save_path, "log_sdout.txt"),
                          file_mode="a+",
                          should_flush=True)

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']

    if config.get('n_train_way') is not None:
        n_train_way = config['n_train_way']
    else:
        n_train_way = n_way
    if config.get('n_train_shot') is not None:
        n_train_shot = config['n_train_shot']
    else:
        n_train_shot = n_shot
    if config.get('ep_per_batch') is not None:
        ep_per_batch = config['ep_per_batch']
    else:
        ep_per_batch = 1

    random_state = np.random.RandomState(args.seed)
    print('seed:', args.seed)

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    utils.log('train dataset: {} (x{})'.format(train_dataset[0][0].shape,
                                               len(train_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)
    train_sampler = BongardSampler(train_dataset.n_tasks,
                                   config['train_batches'], ep_per_batch,
                                   random_state.randint(2**31))
    train_loader = DataLoader(train_dataset,
                              batch_sampler=train_sampler,
                              num_workers=8,
                              pin_memory=True)

    # tvals
    tval_loaders = {}
    tval_name_ntasks_dict = {
        'tval': 2000,
        'tval_ff': 600,
        'tval_bd': 480,
        'tval_hd_comb': 400,
        'tval_hd_novel': 320
    }  # numbers depend on dataset
    for tval_type in tval_name_ntasks_dict.keys():
        if config.get('{}_dataset'.format(tval_type)):
            tval_dataset = datasets.make(
                config['{}_dataset'.format(tval_type)],
                **config['{}_dataset_args'.format(tval_type)])
            utils.log('{} dataset: {} (x{})'.format(tval_type,
                                                    tval_dataset[0][0].shape,
                                                    len(tval_dataset)))
            if config.get('visualize_datasets'):
                utils.visualize_dataset(tval_dataset, 'tval_ff_dataset',
                                        writer)
            tval_sampler = BongardSampler(
                tval_dataset.n_tasks,
                n_batch=tval_name_ntasks_dict[tval_type] // ep_per_batch,
                ep_per_batch=ep_per_batch,
                seed=random_state.randint(2**31))
            tval_loader = DataLoader(tval_dataset,
                                     batch_sampler=tval_sampler,
                                     num_workers=8,
                                     pin_memory=True)
            tval_loaders.update({tval_type: tval_loader})
        else:
            tval_loaders.update({tval_type: None})

    # val
    val_dataset = datasets.make(config['val_dataset'],
                                **config['val_dataset_args'])
    utils.log('val dataset: {} (x{})'.format(val_dataset[0][0].shape,
                                             len(val_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    val_sampler = BongardSampler(val_dataset.n_tasks,
                                 n_batch=900 // ep_per_batch,
                                 ep_per_batch=ep_per_batch,
                                 seed=random_state.randint(2**31))
    val_loader = DataLoader(val_dataset,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)

    ########

    #### Model and optimizer ####

    if config.get('load'):
        print('loading pretrained model: ', config['load'])
        model = models.load(torch.load(config['load']))
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            print('loading pretrained encoder: ', config['load_encoder'])
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

        if config.get('load_prog_synthesis'):
            print('loading pretrained program synthesis model: ',
                  config['load_prog_synthesis'])
            prog_synthesis = models.load(
                torch.load(config['load_prog_synthesis']))
            model.prog_synthesis.load_state_dict(prog_synthesis.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'vl', 'va']
    tval_tuple_lst = []
    for k, v in tval_loaders.items():
        if v is not None:
            loss_key = 'tvl' + k.split('tval')[-1]
            acc_key = 'tva' + k.split('tval')[-1]
            aves_keys.append(loss_key)
            aves_keys.append(acc_key)
            tval_tuple_lst.append((k, v, loss_key, acc_key))

    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):

            x_shot, x_query = fs.split_shot_query(data.cuda(),
                                                  n_train_way,
                                                  n_train_shot,
                                                  n_query,
                                                  ep_per_batch=ep_per_batch)
            label_query = fs.make_nk_label(n_train_way,
                                           n_query,
                                           ep_per_batch=ep_per_batch).cuda()

            if config['model'] == 'snail':  # only use one selected label_query
                query_dix = random_state.randint(n_train_way * n_query)
                label_query = label_query.view(ep_per_batch, -1)[:, query_dix]
                x_query = x_query[:, query_dix:query_dix + 1]

            if config['model'] == 'maml':  # need grad in maml
                model.zero_grad()

            logits = model(x_shot, x_query).view(-1, n_train_way)
            loss = F.cross_entropy(logits, label_query)
            acc = utils.compute_acc(logits, label_query)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        model.eval()

        for name, loader, name_l, name_a in [('val', val_loader, 'vl', 'va')] + tval_tuple_lst:

            if config.get('{}_dataset'.format(name)) is None:
                continue

            np.random.seed(0)
            for data, _ in tqdm(loader, desc=name, leave=False):
                x_shot, x_query = fs.split_shot_query(
                    data.cuda(),
                    n_way,
                    n_shot,
                    n_query,
                    ep_per_batch=ep_per_batch)
                label_query = fs.make_nk_label(
                    n_way, n_query, ep_per_batch=ep_per_batch).cuda()

                if config['model'] == 'snail':  # only use one randomly selected label_query
                    query_dix = random_state.randint(n_train_way)
                    label_query = label_query.view(ep_per_batch, -1)[:, query_dix]
                    x_query = x_query[:, query_dix:query_dix + 1]

                if config['model'] == 'maml':  # need grad in maml
                    model.zero_grad()
                    logits = model(x_shot, x_query, eval=True).view(-1, n_way)
                    loss = F.cross_entropy(logits, label_query)
                    acc = utils.compute_acc(logits, label_query)
                else:
                    with torch.no_grad():
                        logits = model(x_shot, x_query,
                                       eval=True).view(-1, n_way)
                        loss = F.cross_entropy(logits, label_query)
                        acc = utils.compute_acc(logits, label_query)

                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        log_str = 'epoch {}, train {:.4f}|{:.4f}, val {:.4f}|{:.4f}'.format(
            epoch, aves['tl'], aves['ta'], aves['vl'], aves['va'])
        for tval_name, _, loss_key, acc_key in tval_tuple_lst:
            log_str += ', {} {:.4f}|{:.4f}'.format(tval_name, aves[loss_key],
                                                   aves[acc_key])
            writer.add_scalars('loss', {tval_name: aves[loss_key]}, epoch)
            writer.add_scalars('acc', {tval_name: aves[acc_key]}, epoch)
        log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        utils.log(log_str)

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                       os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()

    print('finished training!')
    logger.close()
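The save_obj written above bundles the model name, its constructor arguments and its weights, so a checkpoint can be restored without the training config. A minimal sketch, assuming the same models registry used by models.make above (the checkpoint path is a placeholder):

import torch
import models  # the registry providing models.make / models.load above

ckpt = torch.load('path/to/max-va.pth', map_location='cpu')  # placeholder path
model = models.make(ckpt['model'], **ckpt['model_args'])
model.load_state_dict(ckpt['model_sd'])
model.eval()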
Example #39
from torch.utils.data import DataLoader
from diabeticretinopathy import DiabeticRetinopathy
from samplers import CategoriesSampler
from convnet import Convnet
from utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric
# NOTE: PrototypicalBatchSampler (used below) is assumed to be importable from the project's sampler module.

if __name__ == '__main__':
    save_epoch = 10
    train_way = 4
    test_way = 4
    shot = 1
    query = 3
    gpu = 0
    save_path = './save/proto-1'
    max_epoch = 60

    ensure_path(save_path)

    trainset = DiabeticRetinopathy('train')
    train_sampler_ = PrototypicalBatchSampler(trainset.label, 3, train_way,
                                              shot + query)
    train_loader = DataLoader(dataset=trainset,
                              batch_sampler=train_sampler_,
                              num_workers=8)

    valset = DiabeticRetinopathy('val')
    val_sampler = PrototypicalBatchSampler(valset.label, 4, test_way,
                                           shot + query)
    val_loader = DataLoader(dataset=valset,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)
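The snippet stops after building the loaders; the imports (euclidean_metric, count_acc, Convnet) suggest a standard prototypical-network training step follows. A minimal sketch of one such step, assuming each batch is laid out as shot * train_way support images followed by the query images (it reuses train_loader, shot, train_way and query from above; the optimizer choice is illustrative):

import torch
import torch.nn.functional as F

model = Convnet().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for data, _ in train_loader:
    data = data.cuda()
    p = shot * train_way
    data_shot, data_query = data[:p], data[p:]  # assumed support/query layout

    # class prototypes: mean embedding of each class's support images
    proto = model(data_shot).reshape(shot, train_way, -1).mean(dim=0)
    label = torch.arange(train_way).repeat(query).cuda()

    logits = euclidean_metric(model(data_query), proto)
    loss = F.cross_entropy(logits, label)
    acc = count_acc(logits, label)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()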
Example #40
def deploy(params):
    """
    This command will update a repo and copy the contents to another folder.
    It can be used to create versioned deployments of the codebase
    When executed it will prompt for the tag to deploy if it's not known
    """
    print(yellow("Warning: git.deploy is deprecated from version 1.0"))
    
    # Try to get global params
    params = utils.get_global_params(params,
                                     'git_repo_path', 
                                     'git_repo_url',
                                     'git_branch',
                                     'git_source_path')
    
    # Old params
    if 'tag_path' in params:
        abort(red("tag_path is deprecated, use git_source_path!"))

    if 'target_path' in params:
        abort(red("target_path is deprecated, use git_source_path!"))

    if 'tag' in params:
        abort(red("tag is deprecated, use git_branch!"))

    if 'branch' in params:
        abort(red("branch is deprecated, use git_branch!"))

    if 'repo_path' in params:
        abort(red("repo_path is deprecated, use git_repo_path!"))

    # Check required params
    if 'git_source_path' not in params:
        abort(red("git_source_path is required!"))

    if 'git_repo_url' not in params:
        abort(red("git_repo_url is required!"))

    if 'git_repo_path' not in params:
        abort(red("git_repo_path is required!"))

    if 'git_branch' not in params or len(params['git_branch']) == 0:
        abort(red("git_branch is required!"))
    
    params = utils.format_params(params)
    
    if not exists(params['git_repo_path']):
        abort(red("Repo path not existing... is the project installed?"))
    
    # If it already exists, remove the whole source directory
    if exists(params['git_source_path']):
        print(yellow("Deploy target path `%s` already exists... its contents will be removed and reset." % params['git_source_path']))
        run('rm -Rf %s' % params['git_source_path'])
    
    utils.ensure_path(params['git_source_path'])
    
    with cd(params['git_repo_path']):
        run('git fetch')
        
        # Update local repo with latest code
        run('git checkout %s' % (params['git_branch']))
        
        run('git pull origin %s' % (params['git_branch']))
        
        run('git submodule update')
        
        # Copy source code to version
        run('cp -Rf %s/* %s/' % (params['git_repo_path'], params['git_source_path']))