Пример #1
0
def main():
    """
    Interactively plot the monitor channels of one or more trained models.

    Command line arguments:

    model_paths
        One or more files loadable with ``serial.load``. If a ``.yaml`` file
        fails to load, a hint is written to stderr and the program exits.
    --out PATH
        Save each figure to PATH (with the non-interactive 'Agg' backend)
        instead of showing it in a window.

    The channels of every model are merged into ``channels``, a mapping taken
    from the enclosing scope (it is not created here). When several models are
    given, each channel name is suffixed with ``:<model name>``.

    With more than one channel the user is asked which ones to plot: a short
    code (``A``), a range of codes (``C-F``) or a full name (``<test_err>``).
    ``e``, ``b``, ``s`` and ``h`` switch the x axis to epochs, batches,
    seconds or hours; ``q`` quits; ``o`` offers to smooth every channel with a
    trailing moving average. The prompt repeats until ``q`` is entered (an
    invalid code aborts the program). A single channel is plotted once
    without asking.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out")
    parser.add_argument("model_paths", nargs='+')
    options = parser.parse_args()
    model_paths = options.model_paths

    if options.out is not None:
        # The backend must be chosen before pyplot is imported.
        import matplotlib
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # raw_input is the non-evaluating prompt on Python 2; Python 3 renamed it
    # to input.
    try:
        read_input = raw_input
    except NameError:
        read_input = input

    print('generating names...')
    # '.pkl' is swapped for a '!' marker before calling unique_substrings
    # (defined elsewhere) and the marker is removed again afterwards.
    model_names = [model_path.replace('.pkl', '!')
                   for model_path in model_paths]
    model_names = unique_substrings(model_names, min_size=10)
    model_names = [model_name.replace('!', '') for model_name in model_names]
    print('...done')

    # Suffix channel names only when several models are plotted together.
    # sys.argv cannot be used for this test because it also counts options
    # such as --out.
    multiple_models = len(model_paths) > 1

    for i, arg in enumerate(model_paths):
        try:
            model = serial.load(arg)
        except Exception:
            if arg.endswith('.yaml'):
                sys.stderr.write(arg + " is a yaml config file, "
                                 "you need to load a trained model.\n")
                sys.exit(-1)
            raise
        this_model_channels = model.monitor.channels

        postfix = ":" + model_names[i] if multiple_models else ""

        for channel in this_model_channels:
            channels[channel + postfix] = this_model_channels[channel]

    # Line styles cycled through by the plotted channels.
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    styles = list(colors)
    styles += [color + '--' for color in colors]
    styles += [color + ':' for color in colors]

    # Number of previous points averaged into each point by option '1'.
    smoothing_window = 5

    while True:
        # Make a list of short codes for each channel so user can specify them
        # easily. Every channel is also reachable as '<channel name>'.
        tag_generator = _TagGenerator()
        codebook = {}
        sorted_codes = []
        for channel_name in sorted(channels,
                                   key=number_aware_alphabetical_key):
            code = tag_generator.get_tag()
            codebook[code] = channel_name
            codebook['<' + channel_name + '>'] = channel_name
            sorted_codes.append(code)

        x_axis = 'example'
        print('set x_axis to example')

        if not channels:
            print("there are no channels to plot")
            break

        # If there is more than one channel in the monitor ask which ones to
        # plot
        prompt = len(channels) > 1

        if prompt:
            # Display the codebook
            for code in sorted_codes:
                print(code + '. ' + codebook[code])

            print('')

            print("Put e, b, s or h in the list somewhere to plot "
                  "epochs, batches, seconds, or hours, respectively.")
            response = read_input('Enter a list of channels to plot '
                                  '(example: A, C,F-G, h, <test_err>) '
                                  'or q to quit or o for options: ')

            if response == 'o':
                print('1: smooth all channels')
                print('any other response: do nothing, go back to plotting')
                response = read_input('Enter your choice: ')
                if response == '1':
                    for channel in channels.values():
                        record = channel.val_record
                        smoothed = []
                        for i in range(len(record)):
                            # Trailing moving average over points
                            # max(0, i - smoothing_window) .. i.
                            window = record[max(0, i - smoothing_window):i + 1]
                            smoothed.append(sum(window, 0.) / len(window))
                        channel.val_record = smoothed
                continue

            if response == 'q':
                break

            # Remove spaces, then split into list
            codes = response.replace(' ', '').split(',')

            final_codes = set()

            for code in codes:
                if not code:
                    # Tolerate empty entries, e.g. from a trailing comma.
                    continue
                elif code == 'e':
                    x_axis = 'epoch'
                elif code == 'b':
                    x_axis = 'batche'
                elif code == 's':
                    x_axis = 'second'
                elif code == 'h':
                    x_axis = 'hour'
                elif code.startswith('<'):
                    # A full channel name written as <name>.
                    if not code.endswith('>') or code not in codebook:
                        print("Invalid code: " + code)
                        sys.exit(-1)
                    final_codes.add(code)
                elif '-' in code:
                    # The current list element is a range of codes
                    rng = code.split('-')

                    if len(rng) != 2:
                        print("Input not understood: " + code)
                        sys.exit(-1)

                    try:
                        start = sorted_codes.index(rng[0])
                    except ValueError:
                        print("Invalid code: " + rng[0])
                        sys.exit(-1)

                    try:
                        # The end of the range may not precede its start.
                        end = sorted_codes.index(rng[1], start)
                    except ValueError:
                        print("Invalid code: " + rng[1])
                        sys.exit(-1)

                    final_codes.update(sorted_codes[start:end + 1])
                else:
                    # The current list element is just a single code
                    if code not in codebook:
                        print("Invalid code: " + code)
                        sys.exit(-1)
                    final_codes.add(code)
        else:
            # Exactly one channel: plot it without asking.
            final_codes = set(sorted_codes)

        fig = plt.figure()
        ax = plt.subplot(1, 1, 1)

        # plot the requested channels
        for idx, code in enumerate(sorted(final_codes)):
            channel_name = codebook[code]
            channel = channels[channel_name]

            y = np.asarray(channel.val_record)

            if np.any(np.isnan(y)):
                print(channel_name + ' contains NaNs')

            if np.any(np.isinf(y)):
                print(channel_name + ' contains infinite values')

            if x_axis == 'example':
                x = np.asarray(channel.example_record)
            elif x_axis == 'batche':
                x = np.asarray(channel.batch_record)
            elif x_axis == 'epoch':
                try:
                    x = np.asarray(channel.epoch_record)
                except AttributeError:
                    # older saved monitors won't have epoch_record
                    x = np.arange(len(channel.batch_record))
            elif x_axis == 'second':
                x = np.asarray(channel.time_record)
            elif x_axis == 'hour':
                x = np.asarray(channel.time_record) / 3600.
            else:
                raise AssertionError('unexpected x axis: ' + x_axis)

            ax.plot(x,
                    y,
                    styles[idx % len(styles)],
                    marker='.',  # add point markers to lines
                    label=channel_name)

        # 'batche' + 's' yields the label '# batches'.
        plt.xlabel('# ' + x_axis + 's')
        ax.ticklabel_format(scilimits=(-3, 3), axis='both')

        # Legend goes below the axes; the bottom margin grows with the number
        # of plotted channels so it stays visible.
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, loc='upper center',
                  bbox_to_anchor=(0.5, -0.1))
        # 0.046 is the size of 1 legend box
        fig.subplots_adjust(bottom=0.11 + 0.046 * len(final_codes))

        if options.out is None:
            plt.show()
        else:
            plt.savefig(options.out)
            # Without this every pass through the loop leaves another open
            # figure behind.
            plt.close(fig)

        if not prompt:
            break
Пример #2
0
def main():
    """
    Interactively plot the monitor channels of one or more trained models.

    Command line arguments:

    model_paths
        One or more files loadable with ``serial.load``. If a ``.yaml`` file
        fails to load, a hint is printed to stderr and the program exits.
    --out PATH
        Save each figure to PATH (with the non-interactive 'Agg' backend)
        instead of showing it in a window.
    --yrange MIN:MAX
        Fix the y-axis limits of every figure, e.g. ``0:1``. A malformed
        value is reported as a usage error.

    The channels of every model are merged into ``channels``, a mapping taken
    from the enclosing scope (it is not created here). When several models are
    given, each channel name is suffixed with ``:<model name>``.

    With more than one channel the user is asked which ones to plot: a short
    code (``A``), a range of codes (``C-F``) or a full name (``<test_err>``).
    ``e``, ``b``, ``s`` and ``h`` switch the x axis to epochs, batches,
    seconds or hours; ``q`` quits; ``o`` offers to smooth every channel with a
    trailing moving average. The prompt repeats until ``q`` is entered (an
    invalid code aborts the program). A single channel is plotted once
    without asking.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out")
    parser.add_argument("model_paths", nargs='+')
    parser.add_argument("--yrange",
                        help='The y-range to be used for plotting, e.g.  0:1')

    options = parser.parse_args()
    model_paths = options.model_paths

    # Validate --yrange once, up front, instead of failing with a traceback
    # after the first figure has been built.
    yrange = None
    if options.yrange is not None:
        try:
            ymin, ymax = map(float, options.yrange.split(':'))
        except ValueError:
            parser.error("--yrange must have the form MIN:MAX, e.g. 0:1")
        yrange = (ymin, ymax)

    if options.out is not None:
        # The backend must be chosen before pyplot is imported.
        import matplotlib
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    print('generating names...')
    # '.pkl' is swapped for a '!' marker before calling unique_substrings
    # (defined elsewhere) and the marker is removed again afterwards.
    model_names = [
        model_path.replace('.pkl', '!') for model_path in model_paths
    ]
    model_names = unique_substrings(model_names, min_size=10)
    model_names = [model_name.replace('!', '') for model_name in model_names]
    print('...done')

    # Suffix channel names only when several models are plotted together.
    # sys.argv cannot be used for this test because it also counts options
    # such as --out and --yrange.
    multiple_models = len(model_paths) > 1

    for i, arg in enumerate(model_paths):
        try:
            model = serial.load(arg)
        except Exception:
            if arg.endswith('.yaml'):
                print(arg + " is a yaml config file, "
                      "you need to load a trained model.",
                      file=sys.stderr)
                sys.exit(-1)
            raise
        this_model_channels = model.monitor.channels

        postfix = ":" + model_names[i] if multiple_models else ""

        for channel in this_model_channels:
            channels[channel + postfix] = this_model_channels[channel]
        # Only the channels are kept; free the rest of the model right away.
        del model
        gc.collect()

    # Line styles cycled through by the plotted channels.
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    styles = list(colors)
    styles += [color + '--' for color in colors]
    styles += [color + ':' for color in colors]

    # Number of previous points averaged into each point by option '1'.
    smoothing_window = 5

    while True:
        # Make a list of short codes for each channel so user can specify them
        # easily. Every channel is also reachable as '<channel name>'.
        tag_generator = _TagGenerator()
        codebook = {}
        sorted_codes = []
        for channel_name in sorted(channels,
                                   key=number_aware_alphabetical_key):
            code = tag_generator.get_tag()
            codebook[code] = channel_name
            codebook['<' + channel_name + '>'] = channel_name
            sorted_codes.append(code)

        x_axis = 'example'
        print('set x_axis to example')

        if not channels:
            print("there are no channels to plot")
            break

        # If there is more than one channel in the monitor ask which ones to
        # plot
        prompt = len(channels) > 1

        if prompt:
            # Display the codebook
            for code in sorted_codes:
                print(code + '. ' + codebook[code])

            print()

            print("Put e, b, s or h in the list somewhere to plot "
                  "epochs, batches, seconds, or hours, respectively.")
            response = input('Enter a list of channels to plot '
                             '(example: A, C,F-G, h, <test_err>) or q to quit'
                             ' or o for options: ')

            if response == 'o':
                print('1: smooth all channels')
                print('any other response: do nothing, go back to plotting')
                response = input('Enter your choice: ')
                if response == '1':
                    for channel in channels.values():
                        record = channel.val_record
                        smoothed = []
                        for i in range(len(record)):
                            # Trailing moving average over points
                            # max(0, i - smoothing_window) .. i.
                            window = record[max(0, i - smoothing_window):i + 1]
                            smoothed.append(sum(window, 0.) / len(window))
                        channel.val_record = smoothed
                continue

            if response == 'q':
                break

            # Remove spaces, then split into list
            codes = response.replace(' ', '').split(',')

            final_codes = set()

            for code in codes:
                if not code:
                    # Tolerate empty entries, e.g. from a trailing comma.
                    continue
                elif code == 'e':
                    x_axis = 'epoch'
                elif code == 'b':
                    x_axis = 'batche'
                elif code == 's':
                    x_axis = 'second'
                elif code == 'h':
                    x_axis = 'hour'
                elif code.startswith('<'):
                    # A full channel name written as <name>.
                    if not code.endswith('>') or code not in codebook:
                        print("Invalid code: " + code)
                        sys.exit(-1)
                    final_codes.add(code)
                elif '-' in code:
                    # The current list element is a range of codes
                    rng = code.split('-')

                    if len(rng) != 2:
                        print("Input not understood: " + code)
                        sys.exit(-1)

                    try:
                        start = sorted_codes.index(rng[0])
                    except ValueError:
                        print("Invalid code: " + rng[0])
                        sys.exit(-1)

                    try:
                        # The end of the range may not precede its start.
                        end = sorted_codes.index(rng[1], start)
                    except ValueError:
                        print("Invalid code: " + rng[1])
                        sys.exit(-1)

                    final_codes.update(sorted_codes[start:end + 1])
                else:
                    # The current list element is just a single code
                    if code not in codebook:
                        print("Invalid code: " + code)
                        sys.exit(-1)
                    final_codes.add(code)
        else:
            # Exactly one channel: plot it without asking.
            final_codes = set(sorted_codes)

        fig = plt.figure()
        ax = plt.subplot(1, 1, 1)

        # plot the requested channels
        for idx, code in enumerate(sorted(final_codes)):
            channel_name = codebook[code]
            channel = channels[channel_name]

            y = np.asarray(channel.val_record)

            if contains_nan(y):
                print(channel_name + ' contains NaNs')

            if contains_inf(y):
                print(channel_name + ' contains infinite values')

            if x_axis == 'example':
                x = np.asarray(channel.example_record)
            elif x_axis == 'batche':
                x = np.asarray(channel.batch_record)
            elif x_axis == 'epoch':
                try:
                    x = np.asarray(channel.epoch_record)
                except AttributeError:
                    # older saved monitors won't have epoch_record
                    x = np.arange(len(channel.batch_record))
            elif x_axis == 'second':
                x = np.asarray(channel.time_record)
            elif x_axis == 'hour':
                x = np.asarray(channel.time_record) / 3600.
            else:
                raise AssertionError('unexpected x axis: ' + x_axis)

            ax.plot(
                x,
                y,
                styles[idx % len(styles)],
                marker='.',  # add point markers to lines
                label=channel_name)

        # 'batche' + 's' yields the label '# batches'.
        plt.xlabel('# ' + x_axis + 's')
        ax.ticklabel_format(scilimits=(-3, 3), axis='both')

        # Legend goes below the axes; the bottom margin grows with the number
        # of plotted channels so it stays visible.
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles,
                  labels,
                  loc='upper center',
                  bbox_to_anchor=(0.5, -0.1))
        # 0.046 is the size of 1 legend box
        fig.subplots_adjust(bottom=0.11 + 0.046 * len(final_codes))

        if yrange is not None:
            plt.ylim(*yrange)

        if options.out is None:
            plt.show()
        else:
            plt.savefig(options.out)
            # Without this every pass through the loop leaves another open
            # figure behind.
            plt.close(fig)

        if not prompt:
            break
Пример #3
0
# Load every model named on the command line and merge its monitor channels
# into the ``channels`` mapping. Every argument after the script name is
# treated as a model path; there is no option parsing here.
# NOTE(review): ``channels``, ``serial``, ``_TagGenerator`` and
# ``number_aware_alphabetical_key`` are not defined in these lines --
# presumably set up earlier in the script; confirm.
for i, arg in enumerate(sys.argv[1:]):
    model = serial.load(arg)
    this_model_channels = model.monitor.channels

    # With more than one model path, suffix each channel name with
    # ":model<i>" so same-named channels of different models get distinct
    # keys in ``channels``.
    if len(sys.argv) > 2:
        postfix = ":model%d" % i
    else:
        postfix = ""

    for channel in this_model_channels:
        channels[channel+postfix] = this_model_channels[channel]


while True:
#Make a list of short codes for each channel so user can specify them easily
    # Codes are assigned in the order given by number_aware_alphabetical_key;
    # sorted_codes keeps that order and codebook maps code -> channel name.
    tag_generator = _TagGenerator()
    codebook = {}
    sorted_codes = []
    for channel_name in sorted(channels, key = number_aware_alphabetical_key):
        code = tag_generator.get_tag()
        codebook[code] = channel_name
        sorted_codes.append(code)

    # The x axis is reset to the example count on every pass of the loop.
    x_axis = 'example'

    if len(channels.values()) == 0:
        print "there are no channels to plot"
        break

    #if there is more than one channel in the monitor ask which ones to plot
    prompt = len(channels.values()) > 1
    # NOTE(review): this excerpt ends here. As shown, the loop only exits via
    # the "no channels" break above; the rest of the loop body is not visible.
Пример #4
0
# Load every model named on the command line and merge its monitor channels
# into the ``channels`` mapping. Every argument after the script name is
# treated as a model path; there is no option parsing here.
# NOTE(review): ``channels``, ``serial``, ``_TagGenerator`` and
# ``number_aware_alphabetical_key`` are not defined in these lines --
# presumably set up earlier in the script; confirm.
for i, arg in enumerate(sys.argv[1:]):
    model = serial.load(arg)
    this_model_channels = model.monitor.channels

    # With more than one model path, suffix each channel name with
    # ":model<i>" so same-named channels of different models get distinct
    # keys in ``channels``.
    if len(sys.argv) > 2:
        postfix = ":model%d" % i
    else:
        postfix = ""

    for channel in this_model_channels:
        channels[channel+postfix] = this_model_channels[channel]


while True:
#Make a list of short codes for each channel so user can specify them easily
    # Every channel gets a short code and is also reachable by its full name
    # wrapped in angle brackets, e.g. '<test_err>'; sorted_codes holds only
    # the short codes, in number_aware_alphabetical_key order.
    tag_generator = _TagGenerator()
    codebook = {}
    sorted_codes = []
    for channel_name in sorted(channels, key = number_aware_alphabetical_key):
        code = tag_generator.get_tag()
        codebook[code] = channel_name
        codebook['<'+channel_name+'>'] = channel_name
        sorted_codes.append(code)

    # The x axis is reset to the example count on every pass of the loop.
    x_axis = 'example'
    print 'set x_axis to example'

    if len(channels.values()) == 0:
        print "there are no channels to plot"
        break
    # NOTE(review): this excerpt ends here. As shown, the loop only exits via
    # the "no channels" break above; the rest of the loop body is not visible.
Пример #5
0
def main():
    """
    Interactively plot the monitor channels of one or more trained models.

    Command line arguments:

    model_paths
        One or more files loadable with ``serial.load``. If a ``.yaml`` file
        fails to load, a hint is written to stderr and the program exits.
    --out PATH
        Save each figure to PATH (with the non-interactive 'Agg' backend)
        instead of showing it in a window.

    The channels of every model are merged into ``channels``, a mapping taken
    from the enclosing scope (it is not created here). When several models are
    given, each channel name is suffixed with ``:<model name>``.

    With more than one channel the user is asked which ones to plot: a short
    code (``A``), a range of codes (``C-F``) or a full name (``<test_err>``).
    ``e``, ``b``, ``s`` and ``h`` switch the x axis to epochs, batches,
    seconds or hours; ``q`` quits; ``o`` offers to smooth every channel with a
    trailing moving average. The prompt repeats until ``q`` is entered (an
    invalid code aborts the program). A single channel is plotted once
    without asking. The legend is drawn to the right of a widened left-hand
    subplot.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out")
    parser.add_argument("model_paths", nargs='+')
    options = parser.parse_args()
    model_paths = options.model_paths

    if options.out is not None:
        # The backend must be chosen before pyplot is imported.
        import matplotlib
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # raw_input is the non-evaluating prompt on Python 2; Python 3 renamed it
    # to input.
    try:
        read_input = raw_input
    except NameError:
        read_input = input

    print('generating names...')
    # '.pkl' is swapped for a '!' marker before calling unique_substrings
    # (defined elsewhere) and the marker is removed again afterwards.
    model_names = [
        model_path.replace('.pkl', '!') for model_path in model_paths
    ]
    model_names = unique_substrings(model_names, min_size=10)
    model_names = [model_name.replace('!', '') for model_name in model_names]
    print('...done')

    # Suffix channel names only when several models are plotted together.
    # sys.argv cannot be used for this test because it also counts options
    # such as --out.
    multiple_models = len(model_paths) > 1

    for i, arg in enumerate(model_paths):
        try:
            model = serial.load(arg)
        except Exception:
            if arg.endswith('.yaml'):
                sys.stderr.write(arg + " is a yaml config file, you need to load a trained model.\n")
                sys.exit(-1)
            raise
        this_model_channels = model.monitor.channels

        postfix = ":" + model_names[i] if multiple_models else ""

        for channel in this_model_channels:
            channels[channel + postfix] = this_model_channels[channel]

    # Line styles cycled through by the plotted channels.
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    styles = list(colors)
    styles += [color + '--' for color in colors]
    styles += [color + ':' for color in colors]

    # Number of previous points averaged into each point by option '1'.
    smoothing_window = 5

    while True:
        # Make a list of short codes for each channel so user can specify them
        # easily. Every channel is also reachable as '<channel name>'.
        tag_generator = _TagGenerator()
        codebook = {}
        sorted_codes = []
        for channel_name in sorted(channels,
                                   key=number_aware_alphabetical_key):
            code = tag_generator.get_tag()
            codebook[code] = channel_name
            codebook['<' + channel_name + '>'] = channel_name
            sorted_codes.append(code)

        x_axis = 'example'
        print('set x_axis to example')

        if not channels:
            print("there are no channels to plot")
            break

        # If there is more than one channel in the monitor ask which ones to
        # plot
        prompt = len(channels) > 1

        if prompt:
            # Display the codebook
            for code in sorted_codes:
                print(code + '. ' + codebook[code])

            print('')

            print("Put e, b, s or h in the list somewhere to plot epochs, batches, seconds, or hours, respectively.")
            response = read_input(
                'Enter a list of channels to plot (example: A, C,F-G, h, <test_err>) or q to quit or o for options: '
            )

            if response == 'o':
                print('1: smooth all channels')
                print('any other response: do nothing, go back to plotting')
                response = read_input('Enter your choice: ')
                if response == '1':
                    for channel in channels.values():
                        record = channel.val_record
                        smoothed = []
                        for i in range(len(record)):
                            # Trailing moving average over points
                            # max(0, i - smoothing_window) .. i.
                            window = record[max(0, i - smoothing_window):i + 1]
                            smoothed.append(sum(window, 0.) / len(window))
                        channel.val_record = smoothed
                continue

            if response == 'q':
                break

            # Remove spaces, then split into list
            codes = response.replace(' ', '').split(',')

            final_codes = set()

            for code in codes:
                if not code:
                    # Tolerate empty entries, e.g. from a trailing comma.
                    continue
                elif code == 'e':
                    x_axis = 'epoch'
                elif code == 'b':
                    x_axis = 'batche'
                elif code == 's':
                    x_axis = 'second'
                elif code == 'h':
                    x_axis = 'hour'
                elif code.startswith('<'):
                    # A full channel name written as <name>.
                    if not code.endswith('>') or code not in codebook:
                        print("Invalid code: " + code)
                        sys.exit(-1)
                    final_codes.add(code)
                elif '-' in code:
                    # The current list element is a range of codes
                    rng = code.split('-')

                    if len(rng) != 2:
                        print("Input not understood: " + code)
                        sys.exit(-1)

                    try:
                        start = sorted_codes.index(rng[0])
                    except ValueError:
                        print("Invalid code: " + rng[0])
                        sys.exit(-1)

                    try:
                        # The end of the range may not precede its start.
                        end = sorted_codes.index(rng[1], start)
                    except ValueError:
                        print("Invalid code: " + rng[1])
                        sys.exit(-1)

                    final_codes.update(sorted_codes[start:end + 1])
                else:
                    # The current list element is just a single code
                    if code not in codebook:
                        print("Invalid code: " + code)
                        sys.exit(-1)
                    final_codes.add(code)
        else:
            # Exactly one channel: plot it without asking.
            final_codes = set(sorted_codes)

        fig = plt.figure()
        # Make 2 subplots so the legend gets a plot to itself and won't cover
        # up the plot
        ax = plt.subplot(1, 2, 1)

        # Grow current axis' width by 30%
        box = ax.get_position()
        try:
            x0, y0 = box.x0, box.y0
            width, height = box.width, box.height
        except AttributeError:
            # Fallback for a position returned as a plain sequence rather
            # than a Bbox (presumably older matplotlib). Unpack it in the same
            # left, bottom, width, height order that set_position() receives
            # below.
            x0, y0, width, height = box

        ax.set_position([x0, y0, width * 1.3, height])

        ax.ticklabel_format(scilimits=(-3, 3), axis='both')

        # 'batche' + 's' yields the label '# batches'.
        plt.xlabel('# ' + x_axis + 's')

        # plot the requested channels
        for idx, code in enumerate(sorted(final_codes)):
            channel_name = codebook[code]
            channel = channels[channel_name]

            y = np.asarray(channel.val_record)

            if np.any(np.isnan(y)):
                print(channel_name + ' contains NaNs')

            if np.any(np.isinf(y)):
                print(channel_name + ' contains infinite values')

            if x_axis == 'example':
                x = np.asarray(channel.example_record)
            elif x_axis == 'batche':
                x = np.asarray(channel.batch_record)
            elif x_axis == 'epoch':
                try:
                    x = np.asarray(channel.epoch_record)
                except AttributeError:
                    # older saved monitors won't have epoch_record
                    x = np.arange(len(channel.batch_record))
            elif x_axis == 'second':
                x = np.asarray(channel.time_record)
            elif x_axis == 'hour':
                x = np.asarray(channel.time_record) / 3600.
            else:
                raise AssertionError('unexpected x axis: ' + x_axis)

            plt.plot(x, y, styles[idx % len(styles)], label=channel_name)

        # Anchor the legend just outside the right edge of the plot.
        plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

        if options.out is None:
            plt.show()
        else:
            plt.savefig(options.out)
            # Without this every pass through the loop leaves another open
            # figure behind.
            plt.close(fig)

        if not prompt:
            break
Пример #6
0
def main():
    """
    Plot the monitor channels of one or more trained models to an image file.

    Command line arguments:

    model_paths
        One or more files loadable with ``serial.load``. If a ``.yaml`` file
        fails to load, a hint is written to stderr and the program exits.
    --out PATH
        File each figure is saved to. Defaults to
        ``plot_value_over_all_epochs.png``. This variant never opens a window.

    The channels of every model are merged into ``channels``, a mapping taken
    from the enclosing scope (it is not created here). When several models are
    given, each channel name is suffixed with ``:<model name>``.

    With more than one channel the user is asked which ones to plot: a short
    code (``A``), a range of codes (``C-F``) or a full name (``<test_err>``).
    ``e``, ``b``, ``s`` and ``h`` switch the x axis to epochs, batches,
    seconds or hours; ``q`` quits; ``o`` offers to smooth every channel with a
    trailing moving average. The prompt repeats until ``q`` is entered (an
    invalid code aborts the program), overwriting the output file each time.
    A single channel is plotted once without asking.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out")
    parser.add_argument("model_paths", nargs="+")
    options = parser.parse_args()
    model_paths = options.model_paths

    # Figures are always written to a file, never shown, so use the
    # non-interactive Agg backend unconditionally (this also works without a
    # display). The backend must be chosen before pyplot is imported.
    out_path = options.out
    if out_path is None:
        out_path = "plot_value_over_all_epochs.png"

    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    # raw_input is the non-evaluating prompt on Python 2; Python 3 renamed it
    # to input.
    try:
        read_input = raw_input
    except NameError:
        read_input = input

    print("generating names...")
    # '.pkl' is swapped for a '!' marker before calling unique_substrings
    # (defined elsewhere) and the marker is removed again afterwards.
    model_names = [model_path.replace(".pkl", "!") for model_path in model_paths]
    model_names = unique_substrings(model_names, min_size=10)
    model_names = [model_name.replace("!", "") for model_name in model_names]
    print("...done")

    # Suffix channel names only when several models are plotted together.
    # sys.argv cannot be used for this test because it also counts options
    # such as --out.
    multiple_models = len(model_paths) > 1

    for i, arg in enumerate(model_paths):
        try:
            model = serial.load(arg)
        except Exception:
            if arg.endswith(".yaml"):
                sys.stderr.write(arg + " is a yaml config file, " + "you need to load a trained model.\n")
                sys.exit(-1)
            raise
        this_model_channels = model.monitor.channels

        postfix = ":" + model_names[i] if multiple_models else ""

        for channel in this_model_channels:
            channels[channel + postfix] = this_model_channels[channel]
        # Only the channels are kept; free the rest of the model right away.
        del model
        gc.collect()

    # Line styles cycled through by the plotted channels.
    colors = ["b", "g", "r", "c", "m", "y", "k"]
    styles = list(colors)
    styles += [color + "--" for color in colors]
    styles += [color + ":" for color in colors]

    # Number of previous points averaged into each point by option '1'.
    smoothing_window = 5

    while True:
        # Make a list of short codes for each channel so user can specify them
        # easily. Every channel is also reachable as '<channel name>'.
        tag_generator = _TagGenerator()
        codebook = {}
        sorted_codes = []
        for channel_name in sorted(channels, key=number_aware_alphabetical_key):
            code = tag_generator.get_tag()
            codebook[code] = channel_name
            codebook["<" + channel_name + ">"] = channel_name
            sorted_codes.append(code)

        x_axis = "example"
        print("set x_axis to example")

        if not channels:
            print("there are no channels to plot")
            break

        # If there is more than one channel in the monitor ask which ones to
        # plot
        prompt = len(channels) > 1

        if prompt:
            # Display the codebook
            for code in sorted_codes:
                print(code + ". " + codebook[code])

            print("")

            print("Put e, b, s or h in the list somewhere to plot " + "epochs, batches, seconds, or hours, respectively.")
            response = read_input(
                "Enter a list of channels to plot "
                + "(example: A, C,F-G, h, <test_err>) or q to quit"
                + " or o for options: "
            )

            if response == "o":
                print("1: smooth all channels")
                print("any other response: do nothing, go back to plotting")
                response = read_input("Enter your choice: ")
                if response == "1":
                    for channel in channels.values():
                        record = channel.val_record
                        smoothed = []
                        for i in range(len(record)):
                            # Trailing moving average over points
                            # max(0, i - smoothing_window) .. i.
                            window = record[max(0, i - smoothing_window) : i + 1]
                            smoothed.append(sum(window, 0.0) / len(window))
                        channel.val_record = smoothed
                continue

            if response == "q":
                break

            # Remove spaces, then split into list
            codes = response.replace(" ", "").split(",")

            final_codes = set()

            for code in codes:
                if not code:
                    # Tolerate empty entries, e.g. from a trailing comma.
                    continue
                elif code == "e":
                    x_axis = "epoch"
                elif code == "b":
                    x_axis = "batche"
                elif code == "s":
                    x_axis = "second"
                elif code == "h":
                    x_axis = "hour"
                elif code.startswith("<"):
                    # A full channel name written as <name>.
                    if not code.endswith(">") or code not in codebook:
                        print("Invalid code: " + code)
                        sys.exit(-1)
                    final_codes.add(code)
                elif "-" in code:
                    # The current list element is a range of codes
                    rng = code.split("-")

                    if len(rng) != 2:
                        print("Input not understood: " + code)
                        sys.exit(-1)

                    try:
                        start = sorted_codes.index(rng[0])
                    except ValueError:
                        print("Invalid code: " + rng[0])
                        sys.exit(-1)

                    try:
                        # The end of the range may not precede its start.
                        end = sorted_codes.index(rng[1], start)
                    except ValueError:
                        print("Invalid code: " + rng[1])
                        sys.exit(-1)

                    final_codes.update(sorted_codes[start : end + 1])
                else:
                    # The current list element is just a single code
                    if code not in codebook:
                        print("Invalid code: " + code)
                        sys.exit(-1)
                    final_codes.add(code)
        else:
            # Exactly one channel: plot it without asking.
            final_codes = set(sorted_codes)

        fig = plt.figure()
        ax = plt.subplot(1, 1, 1)

        # plot the requested channels
        for idx, code in enumerate(sorted(final_codes)):
            channel_name = codebook[code]
            channel = channels[channel_name]

            y = np.asarray(channel.val_record)

            if np.any(np.isnan(y)):
                print(channel_name + " contains NaNs")

            if np.any(np.isinf(y)):
                print(channel_name + " contains infinite values")

            if x_axis == "example":
                x = np.asarray(channel.example_record)
            elif x_axis == "batche":
                x = np.asarray(channel.batch_record)
            elif x_axis == "epoch":
                try:
                    x = np.asarray(channel.epoch_record)
                except AttributeError:
                    # older saved monitors won't have epoch_record
                    x = np.arange(len(channel.batch_record))
            elif x_axis == "second":
                x = np.asarray(channel.time_record)
            elif x_axis == "hour":
                x = np.asarray(channel.time_record) / 3600.0
            else:
                raise AssertionError("unexpected x axis: " + x_axis)

            # marker="." adds point markers to the lines
            ax.plot(x, y, styles[idx % len(styles)], marker=".", label=channel_name)

        # 'batche' + 's' yields the label '# batches'.
        plt.xlabel("# " + x_axis + "s")
        ax.ticklabel_format(scilimits=(-3, 3), axis="both")

        # Legend goes below the axes; the bottom margin grows with the number
        # of plotted channels so it stays visible.
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, -0.1))
        # 0.046 is the size of 1 legend box
        fig.subplots_adjust(bottom=0.11 + 0.046 * len(final_codes))

        plt.savefig(out_path)
        # Without this every pass through the loop leaves another open figure
        # behind.
        plt.close(fig)

        if not prompt:
            break
Пример #7
0
def main():
    """
    Interactively plot the monitor channels of one or more trained models.

    Each positional argument is a model file loaded with ``serial.load``;
    the entries of its ``model.monitor.channels`` are merged into the
    module-level ``channels`` dict. When several models are given, each
    channel name gets a ``":" + <model name>`` postfix so the curves of
    different models can be told apart.

    If more than one channel is available the user is asked, in a loop,
    which channels to plot: short codes (``A``), ranges of short codes
    (``C-F``) or full names (``<test_err>``), separated by commas. ``e``,
    ``b``, ``s`` or ``h`` in the list switch the x axis to epochs,
    batches, seconds or hours. ``q`` quits and ``o`` opens an options menu
    (currently a moving-average smoother applied to all channels). Invalid
    codes print a message and end the program via ``quit(-1)``.

    Command-line options:
        --out PATH      save each figure to PATH (using the Agg backend)
                        instead of showing it in a window.
        --yrange LO:HI  fix the y-axis limits, e.g. ``0:1``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out")
    parser.add_argument("model_paths", nargs='+')
    parser.add_argument("--yrange", help='The y-range to be used for plotting, e.g.  0:1')

    options = parser.parse_args()
    model_paths = options.model_paths

    if options.out is not None:
        # Select a non-interactive backend before pyplot is imported so the
        # figure can be saved without a display.
        import matplotlib
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    print('generating names...')
    # '.pkl' is temporarily replaced by a '!' marker while the unique
    # substrings are computed; the marker is removed again afterwards.
    model_names = [model_path.replace('.pkl', '!') for model_path in
                   model_paths]
    model_names = unique_substrings(model_names, min_size=10)
    model_names = [model_name.replace('!', '') for model_name in
                   model_names]
    print('...done')

    for i, arg in enumerate(model_paths):
        try:
            model = serial.load(arg)
        except Exception:
            if arg.endswith('.yaml'):
                print(arg + " is a yaml config file, "
                      "you need to load a trained model.", file=sys.stderr)
                quit(-1)
            raise
        this_model_channels = model.monitor.channels

        # Only disambiguate channel names when several models are plotted
        # together (checking sys.argv would also fire on --out / --yrange).
        if len(model_paths) > 1:
            postfix = ":" + model_names[i]
        else:
            postfix = ""

        for channel in this_model_channels:
            channels[channel + postfix] = this_model_channels[channel]
        # Only the channels are kept; drop the model itself.
        del model
        gc.collect()

    while True:
        # Make a list of short codes for each channel so user can specify them
        # easily
        tag_generator = _TagGenerator()
        codebook = {}
        sorted_codes = []
        for channel_name in sorted(channels,
                                   key=number_aware_alphabetical_key):
            code = tag_generator.get_tag()
            codebook[code] = channel_name
            codebook['<' + channel_name + '>'] = channel_name
            sorted_codes.append(code)

        x_axis = 'example'
        print('set x_axis to example')

        if not channels:
            print("there are no channels to plot")
            break

        # If there is more than one channel in the monitor ask which ones to
        # plot
        prompt = len(channels) > 1

        if prompt:

            # Display the codebook
            for code in sorted_codes:
                print(code + '. ' + codebook[code])

            print()

            print("Put e, b, s or h in the list somewhere to plot " +
                  "epochs, batches, seconds, or hours, respectively.")
            response = input('Enter a list of channels to plot ' +
                             '(example: A, C,F-G, h, <test_err>) or q to quit' +
                             ' or o for options: ')

            if response == 'o':
                print('1: smooth all channels')
                print('any other response: do nothing, go back to plotting')
                response = input('Enter your choice: ')
                if response == '1':
                    for channel in channels.values():
                        # Trailing moving average over the current point and
                        # up to k previous points.
                        k = 5
                        new_val_record = []
                        for t in range(len(channel.val_record)):
                            new_val = 0.
                            count = 0.
                            for s in range(max(0, t - k), t + 1):
                                new_val += channel.val_record[s]
                                count += 1.
                            new_val_record.append(new_val / count)
                        channel.val_record = new_val_record
                continue

            if response == 'q':
                break

            # Remove spaces and split into a list of codes
            codes = response.replace(' ', '').split(',')

            final_codes = set()

            for code in codes:
                if not code:
                    # Tolerate empty entries, e.g. from a trailing comma.
                    continue
                if code == 'e':
                    x_axis = 'epoch'
                elif code == 'b':
                    x_axis = 'batche'
                elif code == 's':
                    x_axis = 'second'
                elif code == 'h':
                    x_axis = 'hour'
                elif code.startswith('<'):
                    # A full channel name; checked before the range branch
                    # because channel names may themselves contain '-'.
                    if code not in codebook:
                        print("Invalid code: " + code)
                        quit(-1)
                    final_codes.add(code)
                elif '-' in code:
                    # The current list element is a range of short codes
                    rng = code.split('-')

                    if len(rng) != 2:
                        print("Input not understood: " + code)
                        quit(-1)

                    if rng[0] not in sorted_codes:
                        print("Invalid code: " + rng[0])
                        quit(-1)
                    first = sorted_codes.index(rng[0])

                    # The end of the range is only searched from its start
                    # onwards, so reversed ranges are rejected.
                    if rng[1] not in sorted_codes[first:]:
                        print("Invalid code: " + rng[1])
                        quit(-1)
                    last = sorted_codes.index(rng[1], first)

                    final_codes.update(sorted_codes[first:last + 1])
                else:
                    # The current list element is just a single code
                    if code not in codebook:
                        print("Invalid code: " + code)
                        quit(-1)
                    final_codes.add(code)
            # end for code in codes
        else:
            # Exactly one channel: plot it without asking. (The codebook has
            # two keys per channel, so it cannot be unpacked into one code.)
            final_codes = set(sorted_codes)

        colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
        styles = list(colors)
        styles += [color + '--' for color in colors]
        styles += [color + ':' for color in colors]

        fig = plt.figure()
        ax = plt.subplot(1, 1, 1)

        # plot the requested channels
        for idx, code in enumerate(sorted(final_codes)):

            channel_name = codebook[code]
            channel = channels[channel_name]

            y = np.asarray(channel.val_record)

            if contains_nan(y):
                print(channel_name + ' contains NaNs')

            if contains_inf(y):
                print(channel_name + ' contains infinite values')

            if x_axis == 'example':
                x = np.asarray(channel.example_record)
            elif x_axis == 'batche':
                x = np.asarray(channel.batch_record)
            elif x_axis == 'epoch':
                try:
                    x = np.asarray(channel.epoch_record)
                except AttributeError:
                    # older saved monitors won't have epoch_record
                    x = np.arange(len(channel.batch_record))
            elif x_axis == 'second':
                x = np.asarray(channel.time_record)
            elif x_axis == 'hour':
                x = np.asarray(channel.time_record) / 3600.
            else:
                assert False

            ax.plot(x,
                    y,
                    styles[idx % len(styles)],
                    marker='.',  # add point markers to lines
                    label=channel_name)

        plt.xlabel('# ' + x_axis + 's')
        ax.ticklabel_format(scilimits=(-3, 3), axis='both')

        handles, labels = ax.get_legend_handles_labels()
        lgd = ax.legend(handles, labels, loc='upper left',
                        bbox_to_anchor=(1.05, 1.02))

        # Get the axis positions and the height and width of the legend.
        # The draw call is presumably needed so the legend frame has its
        # rendered size before it is measured.
        plt.draw()
        ax_pos = ax.get_position()
        pad_width = ax_pos.x0 * fig.get_size_inches()[0]
        pad_height = ax_pos.y0 * fig.get_size_inches()[1]
        dpi = fig.get_dpi()
        lgd_width = lgd.get_frame().get_width() / dpi
        lgd_height = lgd.get_frame().get_height() / dpi

        # Adjust the bounding box to encompass both legend and axis.  Axis should be 3x3 inches.
        # I had trouble getting everything to align vertically.

        ax_width = 3
        ax_height = 3
        total_width = 2 * pad_width + ax_width + lgd_width
        total_height = 2 * pad_height + np.maximum(ax_height, lgd_height)

        fig.set_size_inches(total_width, total_height)
        ax.set_position([pad_width / total_width,
                         1 - 6 * pad_height / total_height,
                         ax_width / total_width,
                         ax_height / total_height])

        if options.yrange is not None:
            ymin, ymax = map(float, options.yrange.split(':'))
            plt.ylim(ymin, ymax)

        if options.out is None:
            plt.show()
        else:
            plt.savefig(options.out)
            # Release the saved figure; otherwise every pass through the
            # prompt loop leaves another open figure behind.
            plt.close(fig)

        if not prompt:
            break
Пример #8
0
def main():
    """
    Interactively plot the monitor channels of one or more trained models.

    Variant of the monitor-plotting script with extra plot controls (marked
    "XD" by their author). The x axis is always epochs, and short channel
    codes may be typed in lower case. Besides channel codes (``A``), ranges
    (``C-F``) and full names (``<test_err>``), the comma-separated response
    may contain:

    * ``max=V``   -- clip every curve from above at V (0 disables clipping);
    * ``base=V``  -- draw a horizontal baseline at V (0 disables it);
    * ``xlen=N``  -- keep only the first N points of each curve (0 = all);
    * ``start=N`` -- drop the first N points (defaults to 10; 0 keeps all);
    * ``stop=N``  -- keep only points before index N (0 = no limit).

    ``q`` quits and ``o`` opens an options menu with a moving-average
    smoother. Invalid codes print a message and end via ``quit(-1)``.

    Command-line options:
        --out PATH  save each figure to PATH (using the Agg backend)
                    instead of showing it in a window.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out")
    parser.add_argument("model_paths", nargs='+')
    options = parser.parse_args()
    model_paths = options.model_paths

    if options.out is not None:
        # Select a non-interactive backend before pyplot is imported.
        import matplotlib
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    print('generating names...')
    # '.pkl' is temporarily replaced by a '!' marker while the unique
    # substrings are computed; the marker is removed again afterwards.
    model_names = [model_path.replace('.pkl', '!') for model_path in model_paths]
    model_names = unique_substrings(model_names, min_size=10)
    model_names = [model_name.replace('!', '') for model_name in model_names]
    print('...done')

    for i, arg in enumerate(model_paths):
        try:
            model = serial.load(arg)
        except Exception:
            if arg.endswith('.yaml'):
                sys.stderr.write(arg + " is a yaml config file, you need to load a trained model.\n")
                quit(-1)
            raise
        this_model_channels = model.monitor.channels

        # Only disambiguate channel names when several models are plotted
        # together (checking sys.argv would also fire on --out).
        if len(model_paths) > 1:
            postfix = ":" + model_names[i]
        else:
            postfix = ""

        for channel in this_model_channels:
            channels[channel + postfix] = this_model_channels[channel]

    while True:
        # Plot modifiers, reset for every prompt (0 means "disabled").
        maxval = 0
        baseval = 0
        xlen = 0
        start = 10
        stop = 0

        # Make a list of short codes for each channel so user can specify them easily
        tag_generator = _TagGenerator()
        codebook = {}
        sorted_codes = []
        for channel_name in sorted(channels, key=number_aware_alphabetical_key):
            code = tag_generator.get_tag()
            codebook[code] = channel_name
            codebook['<' + channel_name + '>'] = channel_name
            sorted_codes.append(code)

        x_axis = 'epoch'

        if not channels:
            print("there are no channels to plot")
            break

        # if there is more than one channel in the monitor ask which ones to plot
        prompt = len(channels) > 1

        if prompt:

            # Display the codebook
            for code in sorted_codes:
                print(code + '. ' + codebook[code])

            print('')

            print("The x axis is epochs; channel codes are case-insensitive.")
            print("Put max=xx or base=xx  in the list somewhere to set bounds or baseline for plotting.")
            print("Put xlen=xx, start=xx (default 10) or stop=xx in the list somewhere to select the plotted points.")
            response = raw_input('Enter a list of channels to plot (example: A, C,F-G, <test_err>) or q to quit or o for options: ')

            if response == 'o':
                print('1: smooth all channels')
                print('any other response: do nothing, go back to plotting')
                response = raw_input('Enter your choice: ')
                if response == '1':
                    for channel in channels.values():
                        # Trailing moving average over the current point and
                        # up to k previous points.
                        k = 5
                        new_val_record = []
                        for t in range(len(channel.val_record)):
                            new_val = 0.
                            count = 0.
                            for s in range(max(0, t - k), t + 1):
                                new_val += channel.val_record[s]
                                count += 1.
                            new_val_record.append(new_val / count)
                        channel.val_record = new_val_record
                continue

            if response == 'q':
                break

            # Remove spaces and split into a list of codes
            codes = response.replace(' ', '').split(',')

            final_codes = set()

            for code in codes:
                if not code:
                    # Tolerate empty entries, e.g. from a trailing comma.
                    continue
                if code.startswith('max='):
                    maxval = float(code.split('=')[-1])
                    print('maxval= ' + str(maxval))
                elif code.startswith('base='):
                    baseval = float(code.split('=')[-1])
                elif code.startswith('xlen='):
                    # Slice bounds must be integers for numpy indexing.
                    xlen = int(float(code.split('=')[-1]))
                elif code.startswith('start='):
                    start = int(float(code.split('=')[-1]))
                elif code.startswith('stop='):
                    stop = int(float(code.split('=')[-1]))
                elif '-' in code and not code.startswith('<'):
                    # The current list element is a range of short codes.
                    # (<name> codes are excluded: names may contain '-'.)
                    rng = code.split('-')

                    if len(rng) != 2:
                        print("Input not understood: " + code)
                        quit(-1)

                    # Short codes may be typed in lower case.
                    first_code, last_code = [
                        end if end in sorted_codes else end.upper()
                        for end in rng]

                    if first_code not in sorted_codes:
                        print("Invalid code: " + rng[0])
                        quit(-1)
                    first = sorted_codes.index(first_code)

                    # The end of the range is only searched from its start
                    # onwards, so reversed ranges are rejected.
                    if last_code not in sorted_codes[first:]:
                        print("Invalid code: " + rng[1])
                        quit(-1)
                    last = sorted_codes.index(last_code, first)

                    final_codes.update(sorted_codes[first:last + 1])
                else:
                    # A single code: exact match first (needed for <name>
                    # codes), then the upper-case form of a short code.
                    key = code if code in codebook else code.upper()
                    if key not in codebook:
                        print("Invalid code: " + code)
                        quit(-1)
                    final_codes.add(key)
            # end for code in codes
        else:
            # Exactly one channel: plot it without asking. (The codebook has
            # two keys per channel, so it cannot be unpacked into one code.)
            final_codes = set(sorted_codes)

        fig = plt.figure()
        plt.xlabel('# ' + x_axis + 's')

        colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
        styles = list(colors)
        styles += [color + '--' for color in colors]
        styles += [color + ':' for color in colors]

        # plot the requested channels
        for idx, code in enumerate(sorted(final_codes)):
            channel_name = codebook[code]
            print('channel_name = ' + channel_name)

            channel = channels[channel_name]

            y = np.asarray(channel.val_record)

            if np.any(np.isnan(y)):
                print(channel_name + ' contains NaNs')

            if np.any(np.isinf(y)):
                print(channel_name + ' contains infinite values')

            if maxval != 0:
                # Clip from above. Values equal to maxval are kept and +inf
                # maps to maxval (the old mask arithmetic zeroed the former
                # and turned the latter into NaN).
                y = np.minimum(y, maxval)

            if x_axis == 'example':
                x = np.asarray(channel.example_record)
            elif x_axis == 'batche':
                x = np.asarray(channel.batch_record)
            elif x_axis == 'epoch':
                try:
                    x = np.asarray(channel.epoch_record)
                except AttributeError:
                    # older saved monitors won't have epoch_record
                    x = np.arange(len(channel.batch_record))
            elif x_axis == 'second':
                x = np.asarray(channel.time_record)
            elif x_axis == 'hour':
                x = np.asarray(channel.time_record) / 3600.
            else:
                assert False

            if xlen != 0:
                y = y[:xlen]
                x = x[:xlen]
            # start == 0 keeps everything from the beginning; stop == 0
            # means "no upper limit".
            upper = stop if stop != 0 else None
            y = y[start:upper]
            x = x[start:upper]

            plt.plot(x,
                     y,
                     styles[(idx + 1) % len(styles)],
                     label=channel_name)

            if baseval != 0 and idx == 0:
                # A constant line, independent of NaN/inf values in y.
                baseline = np.full(y.shape, baseval, dtype=float)
                plt.plot(x, baseline, styles[idx % len(styles)],
                         label='baseline=' + str(baseval))

        plt.legend(loc=2)

        if options.out is None:
            plt.show()
        else:
            plt.savefig(options.out)
            # Release the saved figure; otherwise every pass through the
            # prompt loop leaves another open figure behind.
            plt.close(fig)

        if not prompt:
            break