"""Module setup: third-party imports, project search paths and shared defaults.

The project-local modules (``basedir``, ``dataset``, ``augmentation``) are
imported only after the ``jupytools.syspath.add`` calls, so the statement
order in this section is significant and should be preserved.
"""
import os
from os.path import dirname, join

import numpy as np
import pandas as pd
import PIL.Image
import matplotlib.pyplot as plt
from catalyst.utils import get_one_hot
from imageio import imread
import jupytools
from jupytools import is_notebook
import pretrainedmodels
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from visdom import Visdom

# Register extra import locations: a ``protein_project`` directory next to the
# current working directory, and ``rxrx1-utils``. Presumably these are what
# make the project-local imports below resolvable — verify if paths change.
jupytools.syspath.add(join(dirname(os.getcwd()), 'protein_project'))
jupytools.syspath.add('rxrx1-utils')

# Pick the notebook-friendly progress bar under Jupyter, the console one
# otherwise. Both branches must bind the same name, ``tqdm``, because the
# training/validation loops call ``tqdm(total=...)`` directly.
if jupytools.is_notebook():
    from tqdm import tqdm_notebook as tqdm
else:
    from tqdm import tqdm

from basedir import ROOT, NUM_CLASSES
from dataset import build_stats_index

# Newly created tensors default to CPU float32 (torch.FloatTensor).
torch.set_default_tensor_type(torch.FloatTensor)

from augmentation import JoinChannels, SwapChannels, Resize, ToFloat, Rescale
from augmentation import VerticalFlip, HorizontalFlip, PixelStatsNorm, composer
from augmentation import AugmentedImages, bernoulli

# Default image-reading function; PIL.Image.open was the alternative noted here.
default_open_fn = imread  # PIL.Image.open
bar.update(1) vis.line(X=[iteration], Y=[avg_loss], win='loss', name='avg_loss', update='append') val_dl = dataset['valid'] n = len(val_dl) model.eval() with torch.no_grad(): matches = [] with tqdm(total=n) as bar: for batch in val_dl: x = batch['features'].to(device) y = batch['targets'].to(device) out = model(x) y_pred = out.softmax(dim=1).argmax(dim=1) matched = (y == y_pred).detach().cpu().numpy().tolist() matches.extend(matched) bar.update(1) acc = np.mean(matches) vis.line(X=[epoch], Y=[acc], win='acc', name='val_acc', update='append') print(f'validation accuracy: {acc:2.2%}') acc_str = str(int(round(acc * 10_000, 0))) path = os.path.join(logdir, f'train.{epoch}.{acc_str}.pth') torch.save(model.state_dict(), path) if __name__ == '__main__': if not is_notebook(): ancli.make_cli(train)