Example #1
 def __init__(self,
              network,
              start=0,
              end=1,
              dt=0.01,
              interval=100,
              figsize=(10, 10)):
     plt.style.use('classic')
     self.network = network
     self.start = start
     self.end = end
     self.dt = dt
     self.interval = interval
     self.fig, self.ax = plt.subplots(figsize=figsize)
     if is_notebook():
         plt.close()
     self.fig.set_facecolor((0.8, 0.8, 0.8))
     self.ax.set_facecolor((0.2, 0.2, 0.2))
     self.ax.set_title(f'time = {start}')
     self.Lx, self.Ly = network.connectivity.length_x, network.connectivity.length_y
     # frame the network with a 20% margin on each axis
     self.ax.axis(
         [-.2 * self.Lx, 1.2 * self.Lx, -.2 * self.Ly, 1.2 * self.Ly])
     self.ax.get_xaxis().set_ticks([])
     self.ax.get_yaxis().set_ticks([])
     self.compile()
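Every snippet here imports an `is_notebook()` helper from `utils` without showing it. A minimal sketch of what such a helper typically looks like (this implementation is an assumption, not the repository's code):

def is_notebook():
    """Best-effort check for a Jupyter/IPython notebook kernel."""
    try:
        from IPython import get_ipython
        # Jupyter notebook/lab kernels run a ZMQInteractiveShell;
        # get_ipython() returns None in a plain interpreter
        return get_ipython().__class__.__name__ == 'ZMQInteractiveShell'
    except ImportError:
        # IPython not installed at all
        return False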
Example #2
 def save(self, filename=None):
     if filename is None:
         from time import gmtime, strftime
         curr = strftime("%Y%m%d_%H%M", gmtime())
         filename = r'movie/random_movie_' + curr + '.mp4'
     else:
         filename = r'movie/' + filename
     Writer = animation.writers['ffmpeg']
     writer = Writer(fps=1, metadata=dict(artist='Me'), bitrate=1800)
     self.movie.save(filename, writer=writer)
     if is_notebook():
         plt.close()
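A hypothetical usage of this `save` method, assuming it belongs to the animation class from Example #1 (the class name `Animation` is an assumption), that `compile()` assigns a `matplotlib.animation.FuncAnimation` to `self.movie`, and that a `movie/` directory already exists:

anim = Animation(network, start=0, end=2, dt=0.01)
anim.save()              # writes movie/random_movie_<YYYYmmdd_HHMM>.mp4
anim.save('run_01.mp4')  # writes movie/run_01.mp4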
Example #3
 def __init__(self,
              network,
              start=0,
              end=1,
              dt=0.01,
              interval=100,
              figsize=(10, 10),
              edge_mode='current'):
     plt.style.use('classic')
     self.network = network
     self.start = start
     self.end = end
     self.dt = dt
     self.interval = interval
     self.edge_mode = edge_mode
     self.fig, self.ax = plt.subplots(figsize=figsize)
     if is_notebook():
         plt.close()
     self.fig.set_facecolor((0.8, 0.8, 0.8))
     self.ax.set_facecolor((0.8, 0.8, 0.8))
     self.ax.set_title(f'Network {edge_mode} flow at t = {start}')
     self.ax.get_xaxis().set_ticks([])
     self.ax.get_yaxis().set_ticks([])
     self.compile()
Example #4
import sys

try:
    # running on Google Colab: mount Drive so project modules can be imported
    from google.colab import drive
    IN_COLAB = True

    drive.mount('/content/drive')
    path = "/content/drive/My Drive/Colab Notebooks/"

    # for python imports from google drive
    sys.path.append(path)
except ImportError:
    IN_COLAB = False
    path = "./"

from utils.datasets import Characters as TextDataset
from utils import _Trainer, print_cuda_info, is_notebook
from hmlstm import HMLSTMNetwork, SlopeScheduler

if is_notebook():
    from tqdm.notebook import tqdm_notebook as tqdm
    print("running in notebook and/or colab")
else:
    from tqdm import tqdm

# %%


def train(architecture, data_loader: DataLoader, parameters: Dict[str, float],
          device: torch.device) -> nn.Module:

    epochs = parameters.get("epochs", 1)
    lr = parameters.get("lr", 0.01)

    samples = len(data_loader.dataset)
Example #5
def train_model(model,
                arch,
                dataloaders,
                criterion,
                optimizer,
                scheduler=None,
                num_epochs=25,
                output_path='result/',
                start_epoch=0):

    best_acc = 0.0
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dataset_sizes = {x: len(dataloaders[x].dataset) for x in ['train', 'val']}
    class_names = dataloaders['train'].dataset.classes

    os.makedirs(output_path, exist_ok=True)
    os.makedirs(os.path.join(output_path, "tb"), exist_ok=True)

    writer = SummaryWriter(os.path.join(output_path, "tb"))
    with open(os.path.join(output_path, "log.txt"), 'w') as f:
        metrics = "phase, epoch, loss, accuracy"
        f.write(metrics + '\n')

    if not os.path.exists(
            os.path.join(output_path.split("/")[0], "summary.txt")):
        open(os.path.join(output_path.split("/")[0], "summary.txt"),
             'w').close()

    ncols = 600 if is_notebook() else 80
    for epoch in trange(start_epoch, num_epochs, desc="epoch ", ncols=ncols):
        trainval_loss = {'train': 0.0, 'val': 0.0}
        trainval_acc = {'train': 0.0, 'val': 0.0}
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm(dataloaders[phase],
                                       total=len(dataloaders[phase]),
                                       desc=" {}".format(phase),
                                       ncols=ncols,
                                       leave=False):

                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backpropagation
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                loss_data = loss.detach().cpu().item() * inputs.size(0)
                running_loss += loss_data
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train' and scheduler is not None:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double().item() / dataset_sizes[phase]
            trainval_loss[phase] = epoch_loss
            trainval_acc[phase] = epoch_acc

            writer.add_scalar("loss/{}".format(phase), epoch_loss, epoch)
            writer.add_scalar("accuracy/{}".format(phase), epoch_acc, epoch)

            with open(os.path.join(output_path, "log.txt"), 'a') as f:
                metrics = "{}, {}, {:.10f}, {:.10f}".format(
                    phase, epoch, epoch_loss, epoch_acc)
                f.write(metrics + '\n')

            # save a rolling checkpoint after each validation phase
            if phase == 'val':
                torch.save(
                    {
                        'arch': arch,
                        'epoch': epoch,
                        'best_acc': best_acc,
                        'class_names': class_names,
                        'num_classes': len(
                            dataloaders['train'].dataset.classes),
                        'optim_state_dict': optimizer.state_dict(),
                        'model_state_dict': model.state_dict(),
                    }, os.path.join(output_path, 'checkpoint.pth'))

                # keep a separate copy of the weights with the best val accuracy
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(
                        {
                            'arch': arch,
                            'epoch': epoch,
                            'best_acc': best_acc,
                            'class_names': class_names,
                            'num_classes': len(
                                dataloaders['train'].dataset.classes),
                            'optim_state_dict': optimizer.state_dict(),
                            'model_state_dict': model.state_dict(),
                        }, os.path.join(output_path,
                                        '{}_best.pth'.format(arch)))

        writer.add_scalars('combined_loss', trainval_loss, epoch)
        writer.add_scalars('combined_accuracy', trainval_acc, epoch)

    print("Best Acc: {:4f}".format(best_acc))
    print("best trained model is saved to: {}".format(
        os.path.join(output_path, arch + '_best.pth')))
    with open(os.path.join(output_path.split("/")[0], "summary.txt"),
              'a') as f:
        f.write("{}, {}\n".format(output_path, best_acc))
Example #6
def main(arch="resnet18",
         data_path="dataset/",
         resume="",
         epochs=25,
         batch_size=4,
         img_size=224,
         use_scheduler=False,
         **kwargs):

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

    ## dataset
    dataloaders = prepare_dataloaders(data_path, img_size, batch_size)
    class_names = dataloaders['train'].dataset.classes
    n_class = len(class_names)

    print("preparing '{}' model with {} class: {}".format(
        arch, n_class, class_names))

    ## models
    model, criterion, optimizer = prepare_model(arch, n_class)

    start_epoch = 0
    if resume != '':
        checkpoint = torch.load(resume)
        if checkpoint["arch"] != arch:
            raise ValueError
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
    else:
        tb_path = os.path.join("result", arch, "tb")
        if os.path.exists(tb_path) and len(os.listdir(tb_path)) > 0:
            import shutil
            for f in os.listdir(tb_path):
                p = os.path.join(tb_path, f)
                if os.path.isdir(p):
                    shutil.rmtree(p)
                else:
                    os.remove(p)

    scheduler = None
    if use_scheduler:
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=10,
                                              gamma=0.1)

    print("Training {} on {}".format(arch, device))
    if is_notebook():
        print(
            "you can also check progress on tensorboard, execute in terminal:")
        print("  > tensorboard --logdir result/<model_name>/tb/")

    train_model(model,
                arch,
                dataloaders,
                criterion,
                optimizer,
                scheduler=scheduler,
                num_epochs=epochs,
                output_path=os.path.join("result", arch),
                start_epoch=start_epoch)
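A minimal entry-point sketch for `main`, with illustrative argument values:

if __name__ == "__main__":
    # fresh run; pass resume="result/resnet18/checkpoint.pth" to continue
    # training the same architecture from a saved checkpoint
    main(arch="resnet18", data_path="dataset/", epochs=25, batch_size=16)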