예제 #1
0
    def get_network_visualization(self, **kwargs):
        """
        Return an HTML visualization of a Torch network description.

        The description is saved to a temporary ``.lua`` file and the DIGITS
        Torch wrapper (``tools/torch/main.lua``) is run on it with
        ``--visualizeModel=yes``. Output lines printed between two
        'Network definition' log messages are collected, HTML-escaped and
        returned inside a ``<pre>`` element.

        Keyword arguments:
        desc -- network description (Lua source), as text or bytes

        Returns:
        flask.Markup holding the escaped network printout

        Raises:
        NetworkVisualizationError -- if no network definition was found in
            the tool output; its message is the concatenation of the output
            lines that were not recognized
        """
        desc = kwargs['desc']
        # save network description to temporary file
        temp_network_handle, temp_network_path = tempfile.mkstemp(suffix='.lua')
        p = None
        try:  # do this in a try..finally clause to make sure we delete the temp file
            # os.write() requires bytes and may perform a partial write; a file
            # object writes the whole buffer and closes the descriptor even if
            # encoding or writing fails.
            with os.fdopen(temp_network_handle, 'wb') as temp_network_file:
                if not isinstance(desc, bytes):
                    desc = desc.encode('utf-8')
                temp_network_file.write(desc)

            # build command line
            torch_bin = config_value('torch')['executable']
            args = [torch_bin,
                    os.path.join(os.path.dirname(digits.__file__), 'tools', 'torch', 'main.lua'),
                    '--network=%s' % os.path.splitext(os.path.basename(temp_network_path))[0],
                    '--networkDirectory=%s' % os.path.dirname(temp_network_path),
                    '--subtractMean=none',  # we are not providing a mean image
                    '--visualizeModel=yes',
                    '--type=float'
                    ]

            # execute command
            p = subprocess.Popen(args,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.STDOUT,
                                 close_fds=True,
                                 )

            # Color codes are added to the beginning and end of lines by the
            # torch binary ('th'); see
            # https://groups.google.com/forum/#!searchin/torch7/color$20codes/torch7/8O_0lSgSzuA/Ih6wYg9fgcwJ  # noqa
            # TODO: need to include regular expression for MAC color codes
            regex = re.compile(r'\x1b\[[0-9;]*m', re.UNICODE)

            # output lines are accumulated into net_lines while
            # collecting_net_definition is True
            collecting_net_definition = False
            net_lines = []
            unrecognized_output = []
            while p.poll() is None:
                for line in utils.nonblocking_readlines(p.stdout):
                    if not line:
                        # nothing to read right now: None is the "no data"
                        # signal, and an empty string carries no output either
                        time.sleep(0.05)
                        continue
                    # remove color codes before parsing the line
                    line = regex.sub('', line)
                    _, _, message = TorchTrainTask.preprocess_output_torch(line.strip())
                    if message:
                        # a log line: 'Network definition' messages open and
                        # close the section that holds the printed network
                        if message.startswith('Network definition'):
                            collecting_net_definition = not collecting_net_definition
                    elif collecting_net_definition:
                        net_lines.append(line)
                    elif line:
                        unrecognized_output.append(line)

            if not net_lines:
                # we did not find a network description
                raise NetworkVisualizationError(''.join(unrecognized_output))
            # Markup.join() escapes every element, so each line is HTML-escaped
            return (flask.Markup('<pre align="left">') +
                    flask.Markup('').join(net_lines) +
                    flask.Markup('</pre>'))
        finally:
            try:
                if p is not None:
                    if p.poll() is None:
                        # an exception escaped while the tool was still
                        # running: don't leave an orphaned process behind
                        p.kill()
                        p.wait()
                    p.stdout.close()
            finally:
                os.remove(temp_network_path)
예제 #2
0
 def create_train_task(self, **kwargs):
     """
     Build a Torch training task owned by this framework.

     Every keyword argument is forwarded unchanged to TorchTrainTask, and
     framework_id is always supplied from this instance. Passing framework_id
     in kwargs is therefore a duplicate keyword and raises TypeError.
     """
     owner_id = self.framework_id
     task = TorchTrainTask(framework_id=owner_id, **kwargs)
     return task
예제 #3
0
    def get_network_visualization(self, desc):
        """
        Return an HTML visualization of a Torch network description.

        The description is saved to a temporary ``.lua`` file and the DIGITS
        Torch wrapper (``tools/torch/main.lua``) is run on it with
        ``--visualizeModel=yes``. Output lines printed between two
        'Network definition' log messages are collected, HTML-escaped and
        returned inside a ``<pre>`` element.

        Arguments:
        desc -- network description (Lua source), as text or bytes

        Returns:
        flask.Markup holding the escaped network printout

        Raises:
        NetworkVisualizationError -- if no network definition was found in
            the tool output; its message is the concatenation of the output
            lines that were not recognized
        """
        # save network description to temporary file
        temp_network_handle, temp_network_path = tempfile.mkstemp(suffix='.lua')
        p = None
        try:  # do this in a try..finally clause to make sure we delete the temp file
            # os.write() requires bytes and may perform a partial write; a file
            # object writes the whole buffer and closes the descriptor even if
            # encoding or writing fails.
            with os.fdopen(temp_network_handle, 'wb') as temp_network_file:
                if not isinstance(desc, bytes):
                    desc = desc.encode('utf-8')
                temp_network_file.write(desc)

            # build command line; '<PATHS>' means "find th on the PATH"
            torch_root = config_value('torch_root')
            if torch_root == '<PATHS>':
                torch_bin = 'th'
            else:
                torch_bin = os.path.join(torch_root, 'bin', 'th')

            args = [torch_bin,
                    os.path.join(os.path.dirname(os.path.dirname(digits.__file__)), 'tools', 'torch', 'main.lua'),
                    '--network=%s' % os.path.splitext(os.path.basename(temp_network_path))[0],
                    '--networkDirectory=%s' % os.path.dirname(temp_network_path),
                    '--subtractMean=none',  # we are not providing a mean image
                    '--visualizeModel=yes',
                    '--type=float'
                    ]

            # execute command
            p = subprocess.Popen(args,
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.STDOUT,
                                 close_fds=True,
                                 )

            # Color codes are added to the beginning and end of lines by the
            # torch binary ('th'); see
            # https://groups.google.com/forum/#!searchin/torch7/color$20codes/torch7/8O_0lSgSzuA/Ih6wYg9fgcwJ
            # TODO: need to include regular expression for MAC color codes
            regex = re.compile(r'\x1b\[[0-9;]*m', re.UNICODE)

            # output lines are accumulated into net_lines while
            # collecting_net_definition is True
            collecting_net_definition = False
            net_lines = []
            unrecognized_output = []
            while p.poll() is None:
                for line in utils.nonblocking_readlines(p.stdout):
                    if not line:
                        # nothing to read right now: None is the "no data"
                        # signal, and an empty string carries no output either
                        time.sleep(0.05)
                        continue
                    # remove color codes before parsing the line
                    line = regex.sub('', line)
                    _, _, message = TorchTrainTask.preprocess_output_torch(line.strip())
                    if message:
                        # a log line: 'Network definition' messages open and
                        # close the section that holds the printed network
                        if message.startswith('Network definition'):
                            collecting_net_definition = not collecting_net_definition
                    elif collecting_net_definition:
                        net_lines.append(line)
                    elif line:
                        unrecognized_output.append(line)

            if not net_lines:
                # we did not find a network description
                raise NetworkVisualizationError(''.join(unrecognized_output))
            # Markup.join() escapes every element, so each line is HTML-escaped
            return (flask.Markup('<pre>') +
                    flask.Markup('').join(net_lines) +
                    flask.Markup('</pre>'))
        finally:
            try:
                if p is not None:
                    if p.poll() is None:
                        # an exception escaped while the tool was still
                        # running: don't leave an orphaned process behind
                        p.kill()
                        p.wait()
                    p.stdout.close()
            finally:
                os.remove(temp_network_path)