def __init__(self, setting):
    super(ResNetForClassificationModel, self).__init__(setting)
    ch = SettingHandler(setting)
    n_trim = int(setting['models']['resNet']['pretrain']['nTrim'])
    self.net = ch.get_res_net()
    # Optionally drop the last n_trim child modules (e.g. the fully
    # connected head) so the pretrained ResNet acts as a feature extractor.
    if setting['models']['resNet']['pretrain']['trimFullConnectedLayer']:
        self.net = nn.Sequential(*list(self.net.children())[:-n_trim])
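# Usage sketch (standalone, illustrative; torchvision is assumed but is not
# part of this module): what the trimming above produces. torchvision's
# resnet18 children() ends with [..., avgpool, fc], so nTrim=1 drops the
# fully connected head and leaves a feature extractor.
import torch
import torch.nn as nn
from torchvision import models

net = models.resnet18(weights=None)  # weights arg needs torchvision >= 0.13
feature_extractor = nn.Sequential(*list(net.children())[:-1])
features = feature_extractor(torch.randn(1, 3, 224, 224))
print(features.shape)  # torch.Size([1, 512, 1, 1])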
def __init__(self, setting):
    super(ResNetForGeneratorModel, self).__init__(setting)
    self.setting = setting
    ch = SettingHandler(setting)
    input_nc = int(setting['base']['numberOfInputImageChannels'])
    output_nc = int(setting['base']['numberOfOutputImageChannels'])
    ngf = int(setting['models']['resNet']['generator']['filterSize'])
    use_bias = setting['models']['resNet']['generator']['useBias']
    self.gpu_ids = ch.get_GPU_ID()
    self.pad_factory = PadFactory(setting)
    self.norm_layer = NormalizeFactory(setting).create(
        setting['models']['resNet']['generator']['normalizeLayer'])

    # Stem: reflection padding + 7x7 convolution keeps the spatial size.
    model = [nn.ReflectionPad2d(3),
             nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                       bias=use_bias),
             self.norm_layer(ngf),
             nn.ReLU(True)]

    # Two stride-2 convolutions halve the resolution and double the width.
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2 ** i
        model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                            stride=2, padding=1, bias=use_bias),
                  self.norm_layer(ngf * mult * 2),
                  nn.ReLU(True)]

    # Residual blocks at the bottleneck resolution.
    mult = 2 ** n_downsampling
    for i in range(int(setting['models']['resNet']['generator']['nBlocks'])):
        model += [ResnetBlock(ngf * mult, setting)]

    # Mirror the downsampling with transposed convolutions.
    for i in range(n_downsampling):
        mult = 2 ** (n_downsampling - i)
        model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                     kernel_size=3, stride=2, padding=1,
                                     output_padding=1, bias=use_bias),
                  self.norm_layer(int(ngf * mult / 2)),
                  nn.ReLU(True)]

    model += [self.pad_factory.create(
        setting['models']['resNet']['generator']['paddingType'])(3)]
    model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
    model += [nn.Tanh()]
    self.model = nn.Sequential(*model)
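# Shape check (standalone, illustrative): padding by 3 followed by a 7x7
# convolution with padding=0 preserves spatial size, which is why the
# generator above can open and close with this pair without changing H x W.
import torch
import torch.nn as nn

stem = nn.Sequential(nn.ReflectionPad2d(3),
                     nn.Conv2d(3, 64, kernel_size=7, padding=0))
x = torch.randn(1, 3, 128, 128)
print(stem(x).shape)  # torch.Size([1, 64, 128, 128])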
def __init__(self, settings):
    super(UnetForGeneratorModel, self).__init__(settings)
    self.settings = settings
    ch = SettingHandler(settings)
    self.gpu_ids = ch.get_GPU_ID()
    #self.pad_factory = PadFactory(settings)
    # Build the U-Net recursively: start from the innermost block (i = 0)
    # and wrap it with one hierarchy level per iteration.
    self.model = UnetBlock(0, settings)
    for i in range(
            1, int(settings['models']['unet']['generator']['numHierarchy'])):
        self.model = UnetBlock(i, settings, [self.model])
def __init__(self, settings):
    super(BaseModel, self).__init__()
    if settings is not None:
        self.setSetting(settings)
        ch = self.ch = SettingHandler(settings)
        self.gpu_ids = ch.get_GPU_ID()
    else:
        self.gpu_ids = [-1]
    self.feature_image_dic = {}
def __init__(self, settings=None, **kwargs):
    if settings is not None:
        self.is_data_parallel = settings['base']['isDataParallel']
        self.is_showmode_info = settings['ui']['base']['isShowModelInfo']
    else:
        self.is_data_parallel = False
        self.is_showmode_info = False
    self.settings = settings
    ch = SettingHandler(settings)
    self.gpu_ids = ch.get_GPU_ID()
    self.checkpoints_dir = ch.get_check_points_dir()
    self.loss_factory = lf.LossFactory(settings)
    self.optimizer_factory = OptimizerFactory(settings)
    self.show_model_lst = []
    self.model_factory = ModelFactory(settings)
def __init__(self, i, settings, medium_layers=[]):
    super(UnetBlock, self).__init__()
    self.ch = SettingHandler(settings)
    gen = settings['models']['unet']['generator']
    self.n_input = int(gen['numberOfInputImageChannels'])
    self.n_output = int(gen['numberOfOutputImageChannels'])
    self.n_conv = int(gen['numComvolutionEachHierarchy'])
    num_hierarchy = int(gen['numHierarchy'])
    self.scale_ratio = int(gen['scaleRatio'])
    self.feature_size = int(gen['featureSize'])
    self.input_nc = int(2 ** (num_hierarchy - i - 2) *
                        self.feature_size * self.scale_ratio)
    # use_bias = gen['useBias']
    self.norm_layer = NormalizeFactory(settings).create(gen['normalizeLayer'])
    self.inner_activation = ActivationFactory(settings).create(
        gen['innerActivation'])
    self.output_activation = ActivationFactory(settings).create(
        gen['outputActivation'])

    # Position of this block in the hierarchy: i == 0 is the innermost
    # (bottleneck) block, i == num_hierarchy - 1 the outermost.
    self.is_outermost = self.is_innermost = False
    if i == (num_hierarchy - 1):
        self.is_outermost = True
    elif i == 0:
        self.is_innermost = True

    if self.is_innermost:
        self.model = nn.Sequential(*self._genMediumSideLayers())
    elif self.is_outermost:
        input_outermost_layers, output_outermost_layers = \
            self._genOutermostLayers()
        self.model = nn.Sequential(*(input_outermost_layers +
                                     medium_layers +
                                     output_outermost_layers))
    else:
        input_layers = self._genInputSideLayers()
        output_layers = self._genOutputSideLayers()
        self.input_model = nn.Sequential(*input_layers)
        self.input_medium_model = nn.Sequential(*(input_layers +
                                                  medium_layers))
        self.output_model = nn.Sequential(*output_layers)
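# Forward sketch (hypothetical; UnetBlock's actual forward is defined
# elsewhere in this repo): a U-Net level concatenates its own input with the
# decoded output along the channel axis, which is why input_nc above is
# sized for the concatenated tensor.
import torch

def unet_skip(x, input_model, output_model):
    down = input_model(x)             # encode to the next hierarchy level
    up = output_model(down)           # decode back to x's spatial size
    return torch.cat([x, up], dim=1)  # skip connection doubles the channels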
class VisdomViewer:
    def __init__(self, setting):
        self.ch = SettingHandler(setting)
        port_number = self.ch.get_visdom_port_number()
        self.setting = setting
        self.viz = Visdom(port=port_number)
        self.viz_image_dic = {}
        self.viz_graph_dic = {}
        self.title_dic = {}
        self.graph_taglst = self.ch.get_visdom_graph_tags()
        self.graph_xlabel_dic = self.ch.get_visdom_graph_x_labels()
        self.graph_ylabel_dic = self.ch.get_visdom_graph_y_labels()
        self.graph_title_dic = self.ch.get_visdom_graph_titles()
        self.image_taglst = self.ch.get_visdom_image_tags()
        self.image_title_dic = self.ch.get_visdom_image_titles()

    def initGraph(self, tag, xlabel=None, ylabel=None, title=None):
        self.title_dic[tag] = title
        # Three traces per graph: train, test and valid (see updateGraph's idx).
        self.viz_graph_dic[tag] = self.viz.line(
            np.array([[0, 0, 0]]),
            np.array([[np.nan, np.nan, np.nan]]),
            opts=dict(xlabel=xlabel, ylabel=ylabel))

    def initGraphs(self):
        for tag in self.graph_taglst:
            self.initGraph(tag, self.graph_xlabel_dic[tag], tag, tag)

    def updateGraph(self, tag, x_value, y_value, opts=None, idx=0):
        y_arr = np.array([[np.nan, np.nan, np.nan]])
        y_arr[0, idx] = y_value
        if y_value is not None:
            self.viz.line(X=np.array([[x_value] * 3]),
                          Y=y_arr,
                          win=self.viz_graph_dic[tag],
                          update='append',
                          opts=opts)

    def updateGraphs(self, x_value, value_dic, opts=None, idx=0):
        for tag in self.graph_taglst:
            self.updateGraph(tag, x_value, value_dic[tag], opts=opts, idx=idx)

    def initImage(self, tag, title, dummy=torch.Tensor(3, 100, 100)):
        self.title_dic[tag] = title
        self.viz_image_dic[tag] = self.viz.image(dummy, opts=dict(title=title))

    def initImages(self, dummy=torch.Tensor(3, 100, 100)):
        for tag in self.image_taglst:
            self.initImage(tag, tag, dummy=dummy)

    def updateImage(self, tag, image, title, n_iter):
        self.title_dic[tag] = title
        if self.is_cuda(image):
            image = image.cpu()
        if image is not None:
            # Undo the [-1, 1] normalization before display.
            if 'normalize' in self.setting['data']['base']['inputTransform']:
                image = (image + 1) / 2.0 * 255
            if image.dim() == 3:
                self.viz.image(image, opts=dict(title=title),
                               win=self.viz_image_dic[tag])
            else:
                self.viz.images(image, opts=dict(title=title),
                                win=self.viz_image_dic[tag])

    def updateImages(self, image_dic, n_iter):
        for tag in self.image_taglst:
            self.updateImage(tag, image_dic[tag], tag, n_iter)

    def is_cuda(self, tensor):
        return "cuda" in str(type(tensor))

    def destructVisdom(self):
        pass
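# Usage sketch (illustrative; requires a running visdom server, e.g.
# `python -m visdom.server`, and a settings object shaped like this repo's
# settings file):
#
#   viewer = VisdomViewer(setting)
#   viewer.initGraphs()
#   viewer.initImages()
#   viewer.updateGraphs(x_value=step, value_dic=losses, idx=0)  # idx 0 = train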
class TensorBoardXViewer:
    def __init__(self, setting):
        self.ch = SettingHandler(setting)
        self.setting = setting
        self.total_dataset_length = 0
        self.mode = "train"
        self.n_iter = 0
        # One writer per phase so train/test/valid curves appear side by
        # side in TensorBoard.
        self.writer_dic = {
            "train": SummaryWriter('runs/train'),
            "test": SummaryWriter('runs/test'),
            "valid": SummaryWriter('runs/valid')
        }
        self.graph_taglst = self.ch.get_visdom_graph_tags()
        self.image_taglst = self.ch.get_visdom_image_tags()

    def initGraph(self, tag, xlabel=None, ylabel=None, title=None):
        self.n_iter = 0

    def initGraphs(self):
        self.n_iter = 0

    def setMode(self, mode):
        self.mode = mode

    def setTotalDataLoaderLength(self, total_dataset_length):
        self.total_dataset_length = total_dataset_length

    def updateGraph(self, tag, x_value, y_value, opts=None, idx=0):
        self.writer_dic[self.mode].add_scalar(tag, y_value, x_value)

    def updateGraphs(self, x_value, value_dic, opts=None, idx=0):
        for tag in self.graph_taglst:
            self.updateGraph(tag, x_value, value_dic[tag], opts=opts, idx=idx)

    def initImage(self, tag, title, dummy=torch.Tensor(3, 100, 100)):
        pass

    def initImages(self, dummy=torch.Tensor(3, 100, 100)):
        pass

    def updateImage(self, tag, image, title, n_iter):
        if len(image.shape) != 3:
            image = image[None]
        if 'normalize' in self.setting['data']['base']['inputTransform']:
            # Undo the [-1, 1] normalization before logging.
            image = ((image + 1) / 2.0 * 255).cpu().detach().numpy().astype(
                np.uint8)
        self.writer_dic[self.mode].add_image(tag, image, n_iter)

    def updateImages(self, image_dic, n_iter):
        for tag in self.image_taglst:
            self.writer_dic[self.mode].add_image(
                tag,
                vutils.make_grid(image_dic[tag],
                                 normalize=True,
                                 scale_each=True),
                n_iter)

    def is_cuda(self, tensor):
        return "cuda" in str(type(tensor))

    def destructVisdom(self):
        pass
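# Minimal standalone sketch of the tensorboardX calls used above; view the
# result with `tensorboard --logdir runs`.
from tensorboardX import SummaryWriter

writer = SummaryWriter('runs/demo')
for step in range(3):
    writer.add_scalar('loss', 1.0 / (step + 1), step)
writer.close()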
class TrainMonitor:
    def __init__(self, setting):
        self.ch = SettingHandler(setting)
        self.loss_tags = self.ch.get_visdom_graph_tags()
        self.loss_dic = self._genLossDic()
        self.th = TimeHandler()
        self._initialize()

    def _initialize(self):
        self.progress_bar = ''
        self.progress_value = 0
        self.thresh = 0
        for tag in self.loss_tags:
            self.loss_dic[tag] = []

    def flash(self):
        self._initialize()

    def _genLossDic(self):
        return {tag: [] for tag in self.loss_tags}

    def setLosses(self, loss_dic):
        for tag in self.loss_tags:
            self.setLoss(tag, loss_dic[tag])

    def setLoss(self, tag, value):
        self.loss_dic[tag] += [value]

    def _updateProgress(self, now_progress_value, now_progress_bar,
                        current_n_iter, n_iter):
        # Grow the ASCII progress bar by one character every 10 %.
        if now_progress_value > self.thresh:
            now_progress_bar += '8'
            self.thresh += 10
        progress_value = (1 + current_n_iter) / n_iter * 100
        return progress_value, now_progress_bar

    def dumpCurrentProgress(self, n_epoch, current_n_iter, n_iter):
        self.progress_value, self.progress_bar = \
            self._updateProgress(self.progress_value, self.progress_bar,
                                 current_n_iter, n_iter)
        losses = 'losses: '
        for k, v in self.loss_dic.items():
            if v[-1] is None:
                line = "{}:{}, "
            else:
                line = "{}:{:6.3f}, "
            losses += line.format(k, v[-1])
        sys.stdout.write(
            "\repoch {:03d}: [{}] [{:10s}] {:4.0f} % {} {:3d}/{:3d}".format(
                n_epoch, self.th.getElapsedTime(is_now=False),
                self.progress_bar, self.progress_value, losses,
                current_n_iter, n_iter))

    def dumpAverageLossOnEpoch(self, n_epoch):
        ave_losses = ''
        for k, v in self.loss_dic.items():
            if v[-1] is not None:
                ave_losses += \
                    "{} ave. loss: {:6.3f}, ".format(k, np.array(v).mean())
            else:
                ave_losses += "{} ave. loss: {}, ".format(k, None)
        sys.stdout.write("\repoch {:03d}: [{}] [{:10s}] {:4.0f} % {}\n".format(
            n_epoch, self.th.getElapsedTime(is_now=False), self.progress_bar,
            self.progress_value, ave_losses))
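# Usage sketch (illustrative; tag names must match the graph tags configured
# in the settings file):
#
#   monitor = TrainMonitor(setting)
#   for i, batch in enumerate(loader):
#       monitor.setLosses({'loss': 0.5})
#       monitor.dumpCurrentProgress(n_epoch=0, current_n_iter=i,
#                                   n_iter=len(loader))
#   monitor.dumpAverageLossOnEpoch(n_epoch=0)
#   monitor.flash()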
def __init__(self, settings, gpu_ids=[]):
    super(PatchGANModel, self).__init__(settings)
    self.settings = settings
    ch = SettingHandler(settings)
    self.gpu_ids = ch.get_GPU_ID()
    norm_layer = ch.get_norm_layer()

    # pix2pix discriminators see the input and output images concatenated
    # along the channel axis, so the input channel count is doubled.
    n = 1
    if settings['base']['controller'] in ('pix2pix', 'pix2pixMulti'):
        n = 2
    input_nc = settings['models']['patchGANDiscriminator']['input_n'] * n
    use_bias = True
    # if type(norm_layer) == functools.partial:
    #     use_bias = norm_layer.func == nn.InstanceNorm2d
    # else:
    #     use_bias = norm_layer == nn.InstanceNorm2d

    kw = int(settings['models']['patchGANDiscriminator']['kernelSize'])
    padw = int(settings['models']['patchGANDiscriminator']['paddingSize'])
    ndf = int(settings['models']['patchGANDiscriminator']
              ['numberOfDiscriminatorFilters'])
    n_layers = int(settings['models']['patchGANDiscriminator']['nLayers'])
    is_sigmoid = settings['models']['patchGANDiscriminator']['useSigmoid']

    sequence = [
        nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
        nn.LeakyReLU(0.2, True)
    ]
    # Stride-2 convolutions, doubling the filter count (capped at 8x ndf).
    nf_mult = 1
    for n in range(1, n_layers):
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                      stride=2, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
    nf_mult_prev = nf_mult
    nf_mult = min(2 ** n_layers, 8)
    sequence += [
        nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                  stride=1, padding=padw, bias=use_bias),
        norm_layer(ndf * nf_mult),
        nn.LeakyReLU(0.2, True),
        # Final 1-channel map: one logit per image patch.
        nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
    ]
    if is_sigmoid:
        sequence += [nn.Sigmoid()]
    self.model = nn.Sequential(*sequence)
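# Shape sketch (standalone; norm layers omitted for brevity, and the values
# kw=4, padw=1, ndf=64, nLayers=3 are the usual pix2pix choices -- the real
# values come from the settings file): the stack above maps a 256x256 image
# to a 30x30 grid of patch logits, one per receptive field.
import torch
import torch.nn as nn

d = nn.Sequential(
    nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.LeakyReLU(0.2, True),
    nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.LeakyReLU(0.2, True),
    nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.LeakyReLU(0.2, True),
    nn.Conv2d(256, 512, 4, stride=1, padding=1), nn.LeakyReLU(0.2, True),
    nn.Conv2d(512, 1, 4, stride=1, padding=1))
print(d(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 1, 30, 30])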
def __init__(self, settings):
    import aimaker.models.model_factory as mf
    import aimaker.loss.loss_factory as lf
    import aimaker.optimizers.optimizer_factory as of

    self.settings = settings
    ch = SettingHandler(settings)
    self.gpu_ids = ch.get_GPU_ID()
    self.checkpoints_dir = ch.get_check_points_dir()
    model_factory = mf.ModelFactory(settings)
    loss_factory = lf.LossFactory(settings)
    optimizer_factory = of.OptimizerFactory(settings)

    # Image pools keep a history of generated images for discriminator
    # regularization.
    self.pool_fake_A = ImagePool(
        int(settings['controllers']['cycleGAN']['imagePoolSize']))
    self.pool_fake_B = ImagePool(
        int(settings['controllers']['cycleGAN']['imagePoolSize']))

    # Two generators: A -> B and B -> A.
    name = settings['controllers']['cycleGAN']['generatorModel']
    self.netG_A = model_factory.create(name)
    self.netG_B = model_factory.create(name)
    if len(self.gpu_ids):
        self.netG_A = self.netG_A.cuda(self.gpu_ids[0])
        self.netG_B = self.netG_B.cuda(self.gpu_ids[0])

    # One discriminator per domain.
    name = settings['controllers']['cycleGAN']['discriminatorModel']
    self.netD_A = model_factory.create(name)
    self.netD_B = model_factory.create(name)
    if len(self.gpu_ids):
        self.netD_A = self.netD_A.cuda(self.gpu_ids[0])
        self.netD_B = self.netD_B.cuda(self.gpu_ids[0])

    self.loadModels()

    self.criterionGAN = loss_factory.create("GANLoss")
    self.criterionCycle = loss_factory.create(
        settings['controllers']['cycleGAN']['cycleLoss'])
    self.criterionIdt = loss_factory.create(
        settings['controllers']['cycleGAN']['idtLoss'])
    if len(self.gpu_ids):
        self.criterionGAN = self.criterionGAN.cuda(self.gpu_ids[0])
        self.criterionCycle = self.criterionCycle.cuda(self.gpu_ids[0])
        self.criterionIdt = self.criterionIdt.cuda(self.gpu_ids[0])

    # Initialize optimizers: one shared by both generators, one per
    # discriminator.
    self.optimizer_G = optimizer_factory.create(
        settings['controllers']['cycleGAN']['generatorOptimizer'])(
            it.chain(self.netG_A.parameters(), self.netG_B.parameters()),
            settings)
    if settings['data']['base']['isTrain']:
        self.optimizer_D_A = optimizer_factory.create(
            settings['controllers']['cycleGAN']['D_AOptimizer'])(
                self.netD_A.parameters(), settings)
        self.optimizer_D_B = optimizer_factory.create(
            settings['controllers']['cycleGAN']['D_BOptimizer'])(
                self.netD_B.parameters(), settings)

    if settings['ui']['base']['isShowModelInfo']:
        self.showModel()
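# Note on the image pools above (a sketch of the standard CycleGAN history
# trick; this assumes the repo's ImagePool exposes a query() method like the
# reference CycleGAN implementation): discriminators are fed a buffer of
# previously generated images rather than only the latest batch, which
# stabilizes adversarial training.
#
#   fake_B = self.netG_A(real_A)
#   loss_D_B = self.criterionGAN(
#       self.netD_B(self.pool_fake_B.query(fake_B)), False)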
class Trainer:
    def __init__(self, setting_path="settings"):
        self.settings = settings = EasyDict(load_setting(setting_path))
        self.sh = SettingHandler(settings)
        self.controller = self.sh.get_controller()
        self.dataset = DatasetFactory(settings).create(
            settings.data.base.datasetName)
        self.valid_dataset = None
        if settings.data.base.isValid:
            self.valid_dataset = DatasetFactory(settings).create(
                settings['data']['base']['valid']['datasetName'])
        self.data_loader = self.dataset.getDataLoader()
        if settings.data.base.isValid:
            self.valid_data_loader = self.valid_dataset.getDataLoader()
        self.checkpoint_dir = self.sh.get_check_points_dir()
        self.viz = viewer.TensorBoardXViewer(settings)
        self.train_monitor = viewer.TrainMonitor(settings)
        self.n_update_graphs = self.sh.get_update_interval_of_graphs(
            self.dataset)
        self.n_update_images = self.sh.get_update_interval_of_images(
            self.dataset)
        if settings.data.base.isValid:
            self.validator = Validator(settings)
        self.idx_dic = EasyDict({'train': 0, 'test': 1, 'valid': 2})

    def _getInfo(self):
        # Fresh bookkeeping for resumable training.
        info = EasyDict()
        info.current_epoch = 0
        info.train = EasyDict({"v_iter": 0, "current_n_iter": 0})
        info.test = EasyDict({"v_iter": 0, "current_n_iter": 0})
        info.valid = EasyDict({"v_iter": 0, "current_n_iter": 0})
        return info

    def train(self):
        n_epoch = self.settings['base']['nEpoch']
        if self.settings['base']['isView']:
            self.viz.initGraphs()
            self.viz.initImages()
        train_n_iter = len(self.data_loader)
        if self.settings.data.base.isValid:
            valid_n_iter = len(self.valid_data_loader)

        # Resume from a saved info file if one exists.
        if os.path.exists(self.settings.base.infoPath):
            with open(self.settings.base.infoPath) as fp:
                info = EasyDict(json.load(fp))
        else:
            info = self._getInfo()

        try:
            model_save_interval = self.sh.get_model_save_interval()
            if self.settings.data.base.isValid:
                model_save_interval_valid = \
                    self.sh.get_model_save_interval_for_valid()
            for current_epoch in range(info.current_epoch, n_epoch):
                info.current_epoch = current_epoch
                if self.settings['data']['base']['isTrain']:
                    info.mode = "train"
                    print('train:')
                    self.dataset.set_mode('train')
                    self.dataset.getTransforms()
                    self.viz.setMode('train')
                    self.data_loader = self.dataset.getDataLoader()
                    self.controller.set_mode('train')
                    info = self._learning('train', info, current_epoch,
                                          self.data_loader, train_n_iter)
                if self.settings['data']['base']['isTest']:
                    info.mode = "test"
                    print('test:')
                    self.dataset.set_mode('test')
                    self.dataset.getTransforms()
                    self.viz.setMode('test')
                    self.data_loader = self.dataset.getDataLoader()
                    self.controller.set_mode('test')
                    info = self._learning('test', info, current_epoch,
                                          self.data_loader, train_n_iter)
                if self.valid_dataset is not None:
                    if self.settings['data']['base']['isValid']:
                        info.mode = "valid"
                        print('valid:')
                        self.valid_dataset.set_mode('valid')
                        self.valid_dataset.getTransforms()
                        self.viz.setMode('valid')
                        self.valid_data_loader = \
                            self.valid_dataset.getDataLoader()
                        self.controller.set_mode('valid')
                        info = self._learning('valid', info, current_epoch,
                                              self.valid_data_loader,
                                              valid_n_iter)
                        if current_epoch != 0 and \
                                not current_epoch % model_save_interval_valid:
                            self.controller._save_model(
                                self.controller.get_model(),
                                self.settings['validator']['base']
                                ['modelPath'],
                                is_fcnt=False)
                            self.validator.upload()
                if not current_epoch % model_save_interval:
                    self.controller.save_models()
        except:
            # Deliberately broad: this also catches the SystemExit raised in
            # _learning on Ctrl-C, so the models are saved before shutdown.
            import traceback
            traceback.print_exc()
            self.controller.save_models()

        if self.settings['base']['isView']:
            self.viz.destructVisdom()

    def _saveInfo(self, info):
        with open(self.settings.base.infoPath, 'w') as fp:
            json.dump(info, fp)

    def _learning(self, mode, info, current_epoch, data_loader, train_n_iter):
        n_iter = len(data_loader)
        ratio = train_n_iter / n_iter
        v_iter = info[mode].v_iter
        self.viz.setTotalDataLoaderLength(len(self.data_loader))
        is_volatile = mode != 'train'
        for current_n_iter, data in enumerate(data_loader):
            # Skip batches already processed before a resume.
            if current_n_iter < info[mode].current_n_iter:
                continue
            info[mode].current_n_iter = current_n_iter
            info[mode].v_iter = v_iter
            try:
                self.controller.set_input(data, is_volatile)
                self.controller.forward()
                if mode == 'train':
                    self.controller.backward()
                loss_dic = self.controller.get_losses()
                if mode == 'valid':
                    self.validator.setLosses(loss_dic)
                self.train_monitor.setLosses(loss_dic)
                self.train_monitor.dumpCurrentProgress(current_epoch,
                                                       current_n_iter,
                                                       n_iter)
                if not current_n_iter % self.n_update_graphs:
                    if self.settings['base']['isView']:
                        self.viz.updateGraphs(ratio * v_iter, loss_dic,
                                              idx=self.idx_dic[mode])
                if not current_n_iter % self.n_update_images:
                    if self.settings['base']['isView']:
                        self.viz.updateImages(self.controller.get_images(),
                                              current_n_iter)
            except KeyboardInterrupt:
                self._saveInfo(info)
                sys.exit()
            except FileNotFoundError:
                import traceback
                traceback.print_exc()
            except Exception:
                import traceback
                traceback.print_exc()
                break
            v_iter += 1
        #if mode == 'valid':
        #    self.validator.uploadModelIfSOTA(current_epoch)
        info[mode].v_iter = v_iter
        info[mode].current_n_iter = 0  # reset for the next epoch
        self.train_monitor.dumpAverageLossOnEpoch(current_epoch)
        self.train_monitor.flash()
        return info
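# Entry-point sketch: training is driven entirely by the settings file.
#
#   trainer = Trainer(setting_path="settings")
#   trainer.train()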
class BaseDataset(data.Dataset):
    def __init__(self, setting=None):
        super(BaseDataset, self).__init__()
        self.setting = setting
        self.ch = SettingHandler(setting)
        self.mode = 'train'
        self.nom_transform = self.ch.get_normalize_transform()
        self.to_tensor = transforms.ToTensor()
        self.is_shuffle = self.setting['data']['base']['isShuffle'] \
            if setting is not None else True
        #self.port = setting['utils']['connection']['port']
        self.train_test_ratio = \
            float(setting['data']['base']['trainTestRatio']) \
            if setting is not None else 1
        self.dsf = dsf.DataSourceFactory(
            train_test_ratio=self.train_test_ratio)
        self._setDataSource()

    def getDataLoader(self):
        return dataloader.DataLoader(
            self,
            batch_size=self.ch.get_batch_size(self.mode),
            shuffle=self.is_shuffle)

    def getTransforms(self):
        self.input_transform = self.ch.get_input_transform(self.mode)
        self.target_transform = self.ch.get_target_transform(self.mode)
        self.common_transform = self.ch.get_common_transform(self.mode)

    def _setDataSource(self):
        # Placeholder; concrete datasets override this to set self.ds.
        self.ds = "dummy data source"

    def _getInput(self, index):
        pass

    def _getTarget(self, index):
        pass

    def _inputTransform(self, input):
        if isinstance(input, list):
            input = torch.cat([self.input_transform(x) for x in input])
        else:
            input = self.input_transform(input)
        return input

    def _targetTransform(self, target):
        if isinstance(target, list):
            target = torch.cat([self.target_transform(x) for x in target])
        else:
            target = self.target_transform(target)
        return target

    def _commonTransform(self, input, target):
        # Re-seed before every transform call so random augmentations are
        # applied identically to input and target.
        seed = random.randint(0, 2147483647)
        if isinstance(input, list):
            lst = []
            for i in input:
                random.seed(seed)
                lst += [self.common_transform(i)]
            input = lst
        else:
            random.seed(seed)
            input = self.common_transform(input)
        if isinstance(target, list):
            lst = []
            for i in target:
                random.seed(seed)
                lst += [self.common_transform(i)]
            target = lst
        else:
            random.seed(seed)
            target = self.common_transform(target)
        return input, target

    def setMode(self, mode):
        if mode == 'valid':
            mode = 'train'  # the data source only distinguishes train/test
        self.ds.setMode(mode)

    set_mode = setMode  # alias; Trainer calls set_mode()

    def __getitem__(self, index):
        input = self._getInput(index)
        target = self._getTarget(index)
        if self.common_transform:
            input, target = self._commonTransform(input, target)
        if self.input_transform:
            input = self._inputTransform(input)
        if self.target_transform:
            target = self._targetTransform(target)
        return input, target

    def __len__(self):
        return len(self.ds)
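# Subclassing sketch (hypothetical dataset; _setDataSource, _getInput and
# _getTarget are the hooks a concrete dataset overrides -- the data source
# name and accessors below are illustrative, not this repo's API):
#
#   class MyDataset(BaseDataset):
#       def _setDataSource(self):
#           self.ds = self.dsf.create('imageFolder')
#       def _getInput(self, index):
#           return self.ds.getInput(index)
#       def _getTarget(self, index):
#           return self.ds.getTarget(index)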