Example #1
def test_set_matplotlib_close():
    cfg = _get_inline_config()
    cfg.close_figures = False
    display.set_matplotlib_close()
    assert cfg.close_figures
    display.set_matplotlib_close(False)
    assert not cfg.close_figures
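In this test, `_get_inline_config()` is a helper from IPython's own test suite; it most likely just returns the inline backend's singleton config object, whose `close_figures` flag is what `set_matplotlib_close` flips. A rough stand-in, assuming `matplotlib_inline` is installed:

from matplotlib_inline.config import InlineBackend

def _get_inline_config():
    # Hypothetical stand-in for the test helper: the singleton config object
    # whose close_figures flag set_matplotlib_close toggles.
    return InlineBackend.instance()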
Example #2
def test_set_matplotlib_close():
    cfg = _get_inline_config()
    cfg.close_figures = False
    with pytest.deprecated_call():
        display.set_matplotlib_close()
    assert cfg.close_figures
    with pytest.deprecated_call():
        display.set_matplotlib_close(False)
    assert not cfg.close_figures
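Example #2 is the same test adjusted for newer IPython releases, where `IPython.display.set_matplotlib_close` emits a DeprecationWarning and delegates to the `matplotlib_inline` package. A minimal sketch of the non-deprecated call, assuming `matplotlib_inline` is available:

from matplotlib_inline.backend_inline import set_matplotlib_close

set_matplotlib_close(False)  # keep figures open after each display
set_matplotlib_close(True)   # restore the default: close after display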
Example #3
 def __init__(self, data_dict, thresh, nrows, ncols, fig, axs, chans):
     # figure parameters
     self.thresh = thresh
     self.nrows = nrows
     self.ncols = ncols
     self.fig = fig
     self.axs = axs
     self.chans = chans
     # lists for colors and markers
     self.cl = [
         'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
         'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan'
     ]
     self.ml = ['o', '^', 's', 'p', 'P', '*', 'X', 'd']
     self.ic = 0  # index counter
     if self.nrows == 1 and self.ncols == 1:
         self.axs.cla()
         self.axs.plot(data_dict['tdcSamplesChan_%d' % self.chans[self.ic]],
                       data_dict['adcSamplesChan_%d' % self.chans[self.ic]],
                       color=self.cl[self.ic % len(self.cl)],
                       marker=self.ml[self.ic % len(self.ml)],
                       ls='',
                       label='Channel %d' % self.chans[self.ic])
         hit_loc = np.where(data_dict['adcSamplesChan_%d' % self.chans[self.ic]] > self.thresh)[0] + \
                   np.min(data_dict['tdcSamplesChan_%d' % self.chans[self.ic]])
         if len(hit_loc) != 0:
             self.axs.set_xlim(np.min(hit_loc) - 10, np.max(hit_loc) + 20)
         self.axs.set_ylim(0, 1024)
         self.axs.set_ylabel('ADC Value')
         self.axs.set_xlabel('TDC Sample Number')
         self.axs.legend(loc='best',
                         markerscale=0,
                         handletextpad=0,
                         handlelength=0)
         self.ic += 1
     if (self.nrows == 1 and self.ncols == 2) or (self.nrows == 2
                                                  and self.ncols == 1):
         for index in range(2):
             self.axs[index].cla()
             self.axs[index].plot(
                 data_dict['tdcSamplesChan_%d' % self.chans[self.ic]],
                 data_dict['adcSamplesChan_%d' % self.chans[self.ic]],
                 color=self.cl[self.ic % len(self.cl)],
                 marker=self.ml[self.ic % len(self.ml)],
                 ls='',
                 label='Channel %d' % self.chans[self.ic])
             hit_loc = np.where(data_dict['adcSamplesChan_%d' % self.chans[self.ic]] > self.thresh)[0] + \
                       np.min(data_dict['tdcSamplesChan_%d' % self.chans[self.ic]])
             if len(hit_loc) != 0:
                 self.axs[index].set_xlim(
                     np.min(hit_loc) - 10,
                     np.max(hit_loc) + 20)
             self.axs[index].set_ylim(0, 1024)
             self.axs[index].set_ylabel('ADC Value')
             self.axs[index].set_xlabel('TDC Sample Number')
             self.axs[index].legend(loc='best',
                                    markerscale=0,
                                    handletextpad=0,
                                    handlelength=0)
             self.ic += 1
     if self.nrows >= 2 and self.ncols >= 2:
         for row in range(self.nrows):
             for column in range(self.ncols):
                 self.axs[row, column].cla()
                 self.axs[row, column].plot(
                     data_dict['tdcSamplesChan_%d' % self.chans[self.ic]],
                     data_dict['adcSamplesChan_%d' % self.chans[self.ic]],
                     color=self.cl[self.ic % len(self.cl)],
                     marker=self.ml[self.ic % len(self.ml)],
                     ls='',
                     label='Channel %d' % self.chans[self.ic])
                 hit_loc = np.where(data_dict['adcSamplesChan_%d' % self.chans[self.ic]] > self.thresh)[0] + \
                          np.min(data_dict['tdcSamplesChan_%d' % self.chans[self.ic]])
                 if len(hit_loc) != 0:
                     self.axs[row, column].set_xlim(
                         np.min(hit_loc) - 10,
                         np.max(hit_loc) + 20)
                 self.axs[row, column].set_ylim(0, 1024)
                 if column == 0:
                     self.axs[row, column].set_ylabel('ADC Value')
                 if row == self.nrows - 1:
                     self.axs[row, column].set_xlabel('TDC Sample Number')
                 self.axs[row, column].legend(loc='best',
                                              markerscale=0,
                                              handletextpad=0,
                                              handlelength=0)
                 self.ic += 1
     plt.tight_layout()
     # plt.savefig('plots/event_%d.png' % ec)
     # plt.pause(0.05)
     # display(self.fig)
     set_matplotlib_close(False)
     clear_output(wait=True)
     plt.pause(0.005)
     # remove event data after being published
     for chan in range(1, numChans + 1):
         data_dict.pop('adcSamplesChan_%s' % str(chan), None)
         data_dict.pop('tdcSamplesChan_%s' % str(chan), None)
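The tail of this constructor is the notebook live-update idiom the class relies on: keep the figure open with `set_matplotlib_close(False)`, clear the previous frame with `clear_output(wait=True)`, and give the backend a moment to draw with `plt.pause`. A condensed sketch with hypothetical random data in place of `data_dict`:

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import set_matplotlib_close, clear_output

fig, ax = plt.subplots()
for event in range(5):
    ax.cla()
    ax.plot(np.random.randint(0, 1024, 100), ls='', marker='o')
    ax.set_ylim(0, 1024)
    set_matplotlib_close(False)  # keep the figure alive between refreshes
    clear_output(wait=True)      # replace the old frame only once the new one is ready
    plt.pause(0.005)             # let the backend flush the draw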
Example #4
def savePlotToFile(matplotlib_figure, filepath):
    abs_filepath = abspath(filepath)

    # turn off printing (savefig is annoying)
    # nullwrite = NullWriter()
    # oldstdout = sys.stdout
    # sys.stdout = nullwrite
    # save plot as image
    success = False
    try:
        # matplotlib.use('Agg')
        # figure = matplotlib_pyplot.gcf()
        # matplotlib_pyplot.show()
        # matplotlib_pyplot.draw()
        # matplotlib_pyplot.savefig(abs_filepath, format='png')
        # matplotlib_pyplot.plot()

        # figure = matplotlib_pyplot.gcf()
        # figure = matplotlib.pyplot.gcf()
        matplotlib_figure.savefig(abs_filepath,
                                  format="png",
                                  bbox_inches="tight",
                                  pad_inches=0)

        # display image
        img = mpimg.imread(abs_filepath)
        imgplot = plt.imshow(img)
        plt.axis("off")
        plt.title(None)
        set_matplotlib_close(False)
        plt.show()
        # plt.draw()

        # matplotlib_pyplot.show();
        # matplotlib_pyplot.close()
        success = True
    except Exception as err:
        error = err
    # turn back on printing
    finally:
        # sys.stdout = oldstdout
        pass

    # # turn off printing (savefig is annoying)
    # nullwrite = NullWriter()
    # oldstdout = sys.stdout
    # sys.stdout = nullwrite
    # # save plot as image
    # success = False
    # try:
    #     matplotlib.use('Agg')
    #     matplotlib_pyplot.savefig(abs_filepath)
    #     # matplotlib_pyplot.close()
    #     success = True
    # except Exception as err:
    #     pass
    # # turn back on printing
    # finally:
    #     sys.stdout = oldstdout

    # if success, return filepath
    if success:
        return abs_filepath
    # otherwise, throw error
    # (this control flow is a little funky to ensure printing is turned back on
    #  prior to throwing error)
    else:
        raise error
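Stripped of the commented-out experiments, the working core of `savePlotToFile` is save, re-read, re-display. A minimal sketch with a hypothetical figure and output path:

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from os.path import abspath
from IPython.display import set_matplotlib_close

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [4, 5, 6])

path = abspath("plot.png")   # hypothetical output path
fig.savefig(path, format="png", bbox_inches="tight", pad_inches=0)

img = mpimg.imread(path)     # read the saved PNG back in
plt.imshow(img)
plt.axis("off")
set_matplotlib_close(False)  # don't auto-close the figure before plt.show()
plt.show()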
Example #5
from ayx.helpers import convertObjToStr, fileExists, isDictMappingStrToStr
from ayx.Datafiles import FileFormat
from ayx.Compiled import pyxdb
from ayx import Settings
from IPython.display import set_matplotlib_close
from IPython import get_ipython
from os.path import abspath  # used by savePlotToFile below


# temp replacement for sys.stdout (write method does nothing)
class NullWriter(object):
    def write(self, arg):
        pass


# setup matplotlib for export to datastream
set_matplotlib_close(False)
try:
    get_ipython().run_line_magic("matplotlib", "inline")
except AttributeError:
    pass  # running outside of jupyter (eg, cmd line tests)


def savePlotToFile(matplotlib_figure, filepath):
    abs_filepath = abspath(filepath)

    # turn off printing (savefig is annoying)
    # nullwrite = NullWriter()
    # oldstdout = sys.stdout
    # sys.stdout = nullwrite
    # save plot as image
    success = False
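The module-level setup in Example #5 works outside Jupyter because `get_ipython()` returns `None` there, so calling `run_line_magic` on it raises `AttributeError`. An equivalent, slightly more explicit sketch:

from IPython import get_ipython
from IPython.display import set_matplotlib_close

set_matplotlib_close(False)  # figures survive display() so they can be re-shown later

ip = get_ipython()           # None when not running under IPython/Jupyter
if ip is not None:
    ip.run_line_magic("matplotlib", "inline")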
Example #6
File: fmtest.py Project: erl-j/FMPE
def trainNet(net, batch_size, n_epochs, learning_rate):

    # Print all of the hyperparameters of the training iteration:
    print("===== HYPERPARAMETERS =====")
    print("batch_size=", batch_size)
    print("epochs=", n_epochs)
    print("learning_rate=", learning_rate)
    print("=" * 30)

    # Get training data
    trn_loader = get_trn_loader(batch_size)
    n_batches = len(trn_loader)

    # Create our loss and optimizer functions
    loss, optimizer = createLossAndOptimizer(net, learning_rate)

    # Time for printing
    training_start_time = time.time()

    # Loop for n_epochs
    loss_progress = np.zeros((n_epochs, 1))
    attribute_progress = np.zeros((n_epochs, N_OPS * N_PARAMS))

    for epoch in range(n_epochs):

        running_loss = 0.0
        print_every = n_batches // 10
        start_time = time.time()
        total_train_loss = 0

        trn_delta = 0

        for i, data in enumerate(trn_loader, 0):
            # Get inputs
            inputs, labels = data

            # Wrap them in a Variable object
            inputs, labels = inputs.cuda(), labels.cuda()

            # Set the parameter gradients to zero
            optimizer.zero_grad()

            # Forward pass, backward pass, optimize
            outputs = net(inputs)

            loss_size = loss(outputs, labels)

            if isinstance(trn_delta, int):
                # first batch: replace the integer placeholder with per-sample deltas
                trn_delta = torch.abs(outputs - labels).cpu().detach().numpy()
            else:
                # later batches: accumulate the deltas once per batch
                trn_delta = trn_delta + \
                    torch.abs(outputs - labels).cpu().detach().numpy()

            # computes gradients
            loss_size.backward()

            # performs update step
            optimizer.step()

            # Print statistics
            running_loss += loss_size.data.item()
            total_train_loss += loss_size.data.item()

            # Print progress roughly every 10% of the way through an epoch
            if (i + 1) % (print_every + 1) == 0:

                print("Epoch {}, {:d}% \t train_loss: {:.5f} took: {:.2f}s".
                      format(epoch + 1, int(100 * (i + 1) / n_batches),
                             running_loss / print_every,
                             time.time() - start_time))
                # Reset running loss and time
                running_loss = 0.0
                start_time = time.time()

        per_attribute_delta_sum = np.sum(trn_delta, axis=0)

        attribute_progress[epoch, :] = per_attribute_delta_sum

        plt.plot(attribute_progress)
        fig_size = plt.gcf().get_size_inches()  # Get current size
        sizefactor = 1.5  # Set a zoom factor
        # Modify the current size by the factor
        plt.gcf().set_size_inches(sizefactor * fig_size)
        plt.show()

        loss_progress[epoch, :] = total_train_loss

        plt.plot(loss_progress)
        fig_size = plt.gcf().get_size_inches()  # Get current size
        sizefactor = 1.5  # Set a zoom factor
        # Modify the current size by the factor
        plt.gcf().set_size_inches(sizefactor * fig_size)
        plt.show()

        # print([val_outputs[0],labels[0]])

        try:
            play_audio(fm.generate(outputs[0].cpu().detach().numpy()),
                       SAMPLE_RATE)
        except Exception:
            print("failed producing clip")

        time.sleep(2)

        play_audio(fm.generate(labels[0].cpu().detach().numpy()), SAMPLE_RATE)

        ipd.set_matplotlib_close(close=True)

        print("saving model")
        torch.save(net.state_dict(), MODEL_PATH)

    print("Training finished, took {:.2f}s".format(time.time() -
                                                   training_start_time))
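The per-attribute delta bookkeeping in the training loop (initialize `trn_delta` to an integer sentinel, replace it on the first batch, accumulate afterwards) is easier to see in isolation. A minimal sketch with hypothetical shapes, where `N_ATTRS` stands in for `N_OPS * N_PARAMS`:

import numpy as np
import torch

N_ATTRS = 4                # hypothetical stand-in for N_OPS * N_PARAMS
trn_delta = 0              # integer sentinel before the first batch

for _ in range(3):         # stand-in for the batch loop
    outputs = torch.rand(8, N_ATTRS)
    labels = torch.rand(8, N_ATTRS)
    batch_delta = torch.abs(outputs - labels).detach().numpy()
    if isinstance(trn_delta, int):
        trn_delta = batch_delta              # first batch replaces the sentinel
    else:
        trn_delta = trn_delta + batch_delta  # later batches accumulate once per batch

per_attribute_delta_sum = np.sum(trn_delta, axis=0)  # one summed delta per attribute
print(per_attribute_delta_sum)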