Ejemplo n.º 1
0
 def __init__(self, setting):
     """Build the classifier backbone from the configured ResNet.

     When ``models.resNet.pretrain.trimFullConnectedLayer`` is truthy, the
     last ``nTrim`` child modules of the network are dropped and the
     remaining children are wrapped in an ``nn.Sequential``.
     """
     super(ResNetForClassificationModel, self).__init__(setting)
     ch = SettingHandler(setting)
     pretrain_cfg = setting['models']['resNet']['pretrain']
     self.net = ch.get_res_net()
     if pretrain_cfg['trimFullConnectedLayer']:
         # Only needed (and only required to exist) when trimming.
         n_trim = int(pretrain_cfg['nTrim'])
         children = list(self.net.children())
         # Slice with an explicit end index: ``children[:-0]`` is empty, so
         # the old form dropped every layer when nTrim == 0. Clamp at 0 so an
         # oversized nTrim still yields an empty network, as before.
         keep = max(len(children) - n_trim, 0)
         self.net = nn.Sequential(*children[:keep])
Ejemplo n.º 2
0
    def __init__(self, setting_path="settings"):
        """Load the settings file and build every component the training loop needs.

        Args:
            setting_path: path handed to ``load_setting``.

        Validation-only members (``valid_dataset``, ``valid_data_loader``,
        ``validator``) always exist and stay ``None`` when
        ``data.base.isValid`` is false.
        """
        self.settings = settings = EasyDict(load_setting(setting_path))
        self.sh = SettingHandler(settings)
        self.controller = self.sh.get_controller()
        self.dataset = DatasetFactory(settings).create(
            settings.data.base.datasetName)
        is_valid = settings.data.base.isValid
        self.valid_dataset = None
        self.valid_data_loader = None
        self.validator = None
        if is_valid:
            self.valid_dataset = DatasetFactory(settings).create(
                settings['data']['base']['valid']['datasetName'])
        self.data_loader = self.dataset.getDataLoader()
        if is_valid:
            self.valid_data_loader = self.valid_dataset.getDataLoader()

        self.checkpoint_dir = self.sh.get_check_points_dir()
        # Keep the historical (misspelled) attribute name for existing callers.
        self.sheckpoint_dir = self.checkpoint_dir
        self.viz = viewer.TensorBoardXViewer(settings)
        self.train_monitor = viewer.TrainMonitor(settings)

        # Iteration intervals between graph / image refreshes.
        self.n_update_graphs = self.sh.get_update_interval_of_graphs(
            self.dataset)
        self.n_update_images = self.sh.get_update_interval_of_images(
            self.dataset)
        if is_valid:
            self.validator = Validator(settings)
        # Series index passed as ``idx`` to the viewer for each phase.
        self.idx_dic = EasyDict({'train': 0, 'test': 1, 'valid': 2})
Ejemplo n.º 3
0
    def __init__(self,
                 setting):
        """Build a ResNet-style encoder/decoder image generator.

        Layout: 7x7 stem convolution, two stride-2 downsampling
        convolutions, ``nBlocks`` ``ResnetBlock``s, two transposed-convolution
        upsampling stages, and a 7x7 output convolution followed by Tanh.
        Hyper-parameters come from ``setting['models']['resNet']['generator']``.
        """
        super(ResNetForGeneratorModel, self).__init__(setting)

        self.setting = setting
        ch = SettingHandler(setting)
        gen_cfg = setting['models']['resNet']['generator']
        input_nc = int(setting['base']['numberOfInputImageChannels'])
        output_nc = int(setting['base']['numberOfOutputImageChannels'])
        ngf = int(gen_cfg['filterSize'])
        use_bias = gen_cfg['useBias']
        self.gpu_ids = ch.get_GPU_ID()
        self.pad_factory = PadFactory(setting)
        self.norm_layer = NormalizeFactory(setting).create(
            gen_cfg['normalizeLayer'])

        # Stem: reflection-pad then a 7x7 conv that keeps spatial size.
        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                  self.norm_layer(ngf),
                  nn.ReLU(True)]

        n_downsampling = 2
        # Encoder: each stride-2 conv halves the resolution and doubles channels.
        for depth in range(n_downsampling):
            in_ch = ngf * 2 ** depth
            out_ch = in_ch * 2
            layers.extend([nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2,
                                     padding=1, bias=use_bias),
                           self.norm_layer(out_ch),
                           nn.ReLU(True)])

        # Bottleneck residual blocks at the widest channel count.
        bottleneck_ch = ngf * 2 ** n_downsampling
        n_blocks = int(gen_cfg['nBlocks'])
        layers.extend(ResnetBlock(bottleneck_ch, setting) for _ in range(n_blocks))

        # Decoder: mirror of the encoder, widest stage first.
        for depth in range(n_downsampling, 0, -1):
            in_ch = ngf * 2 ** depth
            out_ch = in_ch // 2
            layers.extend([nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3,
                                              stride=2, padding=1,
                                              output_padding=1, bias=use_bias),
                           self.norm_layer(out_ch),
                           nn.ReLU(True)])

        # Output head: configurable padding, 7x7 conv, Tanh.
        pad_cls = self.pad_factory.create(gen_cfg['paddingType'])
        layers.append(pad_cls(3))
        layers.append(nn.Conv2d(ngf, output_nc, kernel_size=(7, 7), padding=0))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)
Ejemplo n.º 4
0
    def __init__(self, settings):
        """Assemble a U-Net generator out of nested ``UnetBlock`` levels.

        Level 0 is built first; each following level receives the previously
        built block as its single nested layer. ``numHierarchy`` is read from
        ``settings['models']['unet']['generator']``.
        """
        super(UnetForGeneratorModel, self).__init__(settings)
        self.settings = settings
        self.gpu_ids = SettingHandler(settings).get_GPU_ID()

        n_levels = int(settings['models']['unet']['generator']['numHierarchy'])
        block = UnetBlock(0, settings)
        for level in range(1, n_levels):
            block = UnetBlock(level, settings, [block])
        self.model = block
Ejemplo n.º 5
0
    def __init__(self, setting):
        """Create one TensorBoardX SummaryWriter per phase and read the tag lists."""
        handler = SettingHandler(setting)
        self.ch = handler
        self.setting = setting
        self.total_dataset_length = 0
        self.mode = "train"
        self.n_iter = 0

        # Each phase logs to its own run directory.
        writers = {}
        writers["train"] = SummaryWriter('runs/train')
        writers["test"] = SummaryWriter('runs/test')
        writers["valid"] = SummaryWriter('runs/valid')
        self.writer_dic = writers

        self.graph_taglst = handler.get_visdom_graph_tags()
        self.image_taglst = handler.get_visdom_image_tags()
Ejemplo n.º 6
0
    def __init__(self, setting=None):
        """Read dataset options from *setting* and set up the data source.

        Without settings, shuffling is enabled and the train/test ratio is 1.
        """
        super(BaseDataset, self).__init__()

        self.setting = setting
        self.ch = SettingHandler(setting)
        self.mode = 'train'
        self.nom_transform = self.ch.get_normalize_transform()
        self.to_tensor = transforms.ToTensor()

        if setting is None:
            self.is_shuffle = True
            self.train_test_ratio = 1
        else:
            base_cfg = setting['data']['base']
            self.is_shuffle = base_cfg['isShuffle']
            self.train_test_ratio = float(base_cfg['trainTestRatio'])

        self.dsf = dsf.DataSourceFactory(train_test_ratio=self.train_test_ratio)
        self._setDataSource()
Ejemplo n.º 7
0
    def __init__(self, setting):
        """Connect to the Visdom server and cache graph/image metadata from *setting*."""
        handler = SettingHandler(setting)
        self.ch = handler
        port = handler.get_visdom_port_number()
        self.setting = setting
        self.viz = Visdom(port=port)

        # Window handles and titles, keyed by tag.
        self.viz_image_dic, self.viz_graph_dic, self.title_dic = {}, {}, {}

        self.graph_taglst = handler.get_visdom_graph_tags()
        self.graph_xlabel_dic = handler.get_visdom_graph_x_labels()
        self.graph_ylabel_dic = handler.get_visdom_graph_y_labels()
        self.graph_title_dic = handler.get_visdom_graph_titles()
        self.image_taglst = handler.get_visdom_image_tags()
        self.image_title_dic = handler.get_visdom_image_titles()
Ejemplo n.º 8
0
    def __init__(self, settings):
        """Initialise the module and pick up GPU ids from *settings*.

        With ``settings=None`` no SettingHandler is created and ``gpu_ids``
        is ``[-1]``.
        """
        super(BaseModel, self).__init__()
        if settings is None:
            self.gpu_ids = [-1]
        else:
            self.setSetting(settings)
            handler = SettingHandler(settings)
            self.ch = handler
            self.gpu_ids = handler.get_GPU_ID()

        self.feature_image_dic = {}
Ejemplo n.º 9
0
    def __init__(self, settings=None, **kwargs):
        """Store *settings* and create the loss, optimizer and model factories.

        Extra keyword arguments are accepted and ignored. Without settings,
        data-parallel mode and model-info display are both off.
        """
        has_settings = settings is not None
        self.is_data_parallel = (
            settings['base']['isDataParallel'] if has_settings else False)
        self.is_showmode_info = (
            settings['ui']['base']['isShowModelInfo'] if has_settings else False)

        self.settings = settings
        handler = SettingHandler(settings)
        self.gpu_ids = handler.get_GPU_ID()
        self.checkpoints_dir = handler.get_check_points_dir()

        self.loss_factory = lf.LossFactory(settings)
        self.optimizer_factory = OptimizerFactory(settings)
        self.show_model_lst = []

        self.model_factory = ModelFactory(settings)
Ejemplo n.º 10
0
    def __init__(self, i, settings, medium_layers=None):
        """Build one level of the U-Net generator.

        Args:
            i: level index; 0 is the innermost level and
                ``numHierarchy - 1`` the outermost (outermost wins when both
                hold, i.e. ``numHierarchy == 1``).
            settings: configuration mapping; values are read from
                ``settings['models']['unet']['generator']``.
            medium_layers: modules of the already-built inner level(s) to nest
                inside this one. ``None`` (the default) means none.
        """
        super(UnetBlock, self).__init__()
        # None instead of a shared mutable ``[]`` default argument.
        if medium_layers is None:
            medium_layers = []
        self.ch = SettingHandler(settings)
        gen_cfg = settings['models']['unet']['generator']
        self.n_input = int(gen_cfg['numberOfInputImageChannels'])
        self.n_output = int(gen_cfg['numberOfOutputImageChannels'])
        self.n_conv = int(gen_cfg['numComvolutionEachHierarchy'])
        num_hierarchy = int(gen_cfg['numHierarchy'])
        self.scale_ratio = int(gen_cfg['scaleRatio'])
        self.feature_size = int(gen_cfg['featureSize'])
        # Channel width of this level; it halves with each level further out.
        self.input_nc = int(2**(num_hierarchy - i - 2) * self.feature_size *
                            self.scale_ratio)
        self.norm_layer = NormalizeFactory(settings).create(
            gen_cfg['normalizeLayer'])
        self.inner_activation = ActivationFactory(settings).create(
            gen_cfg['innerActivation'])
        self.output_activation = ActivationFactory(settings).create(
            gen_cfg['outputActivation'])

        self.is_outermost = self.is_innermost = False
        if i == (num_hierarchy - 1):
            self.is_outermost = True
        elif i == 0:
            self.is_innermost = True

        if self.is_innermost:
            self.model = nn.Sequential(*self._genMediumSideLayers())
        elif self.is_outermost:
            input_outermost_layers, output_outermost_layers = \
                self._genOutermostLayers()
            self.model = nn.Sequential(*(input_outermost_layers +
                                         medium_layers +
                                         output_outermost_layers))
        else:
            input_layers = self._genInputSideLayers()
            output_layers = self._genOutputSideLayers()
            # NOTE(review): input_model and input_medium_model hold the same
            # input-side layer objects, so their parameters are shared, and no
            # ``self.model`` is set on this path — presumably forward() uses
            # these three sub-networks directly. Confirm against forward().
            self.input_model = nn.Sequential(*input_layers)
            self.input_medium_model = nn.Sequential(*(input_layers +
                                                      medium_layers))
            self.output_model = nn.Sequential(*output_layers)
Ejemplo n.º 11
0
class VisdomViewer:
    """Plot loss curves and images to a Visdom server.

    Graph/image tags, axis labels and titles are read from the settings via
    ``SettingHandler``. Every graph holds three series; ``updateGraph``
    writes to the one selected by ``idx``.
    """

    def __init__(self, setting):
        self.ch = SettingHandler(setting)
        portNumber = self.ch.get_visdom_port_number()
        self.setting = setting
        self.viz = Visdom(port=portNumber)
        # Window handles and titles, keyed by tag.
        self.viz_image_dic = {}
        self.viz_graph_dic = {}
        self.title_dic = {}

        self.graph_taglst = self.ch.get_visdom_graph_tags()
        self.graph_xlabel_dic = self.ch.get_visdom_graph_x_labels()
        self.graph_ylabel_dic = self.ch.get_visdom_graph_y_labels()
        self.graph_title_dic = self.ch.get_visdom_graph_titles()
        self.image_taglst = self.ch.get_visdom_image_tags()
        self.image_title_dic = self.ch.get_visdom_image_titles()

    def initGraph(self, tag, xlabel=None, ylabel=None, title=None):
        """Create the (empty) three-series line window for *tag*."""
        self.title_dic[tag] = title
        opts = dict(xlabel=xlabel, ylabel=ylabel)
        if title is not None:
            # Previously stored but never shown; initImage already passes it.
            opts['title'] = title
        # NOTE(review): Visdom's ``line`` takes Y first, so this seeds Y=0 at
        # X=NaN; confirm whether X=0, Y=NaN was intended.
        self.viz_graph_dic[tag] = self.viz.line(
            np.array([[0, 0, 0]]),
            np.array([[np.nan, np.nan, np.nan]]),
            opts=opts)

    def initGraphs(self):
        """Create a window for every configured graph tag (ylabel/title = tag)."""
        for tag in self.graph_taglst:
            self.initGraph(tag, self.graph_xlabel_dic[tag], tag, tag)

    def updateGraph(self, tag, x_value, y_value, opts=None, idx=0):
        """Append *y_value* at *x_value* to series *idx* of graph *tag*; ``None`` is skipped."""
        if y_value is None:
            return
        # Only column ``idx`` carries a value; the other series get NaN.
        y_arr = np.full((1, 3), np.nan)
        y_arr[0, idx] = y_value
        self.viz.line(X=np.array([[x_value] * 3]),
                      Y=y_arr,
                      win=self.viz_graph_dic[tag],
                      update='append',
                      opts=opts)

    def updateGraphs(self, x_value, value_dic, opts=None, idx=0):
        """Update every configured graph from ``value_dic[tag]``."""
        for tag in self.graph_taglst:
            self.updateGraph(tag, x_value, value_dic[tag], opts=opts, idx=idx)

    def initImage(self, tag, title, dummy=None):
        """Create the image window for *tag*, showing *dummy* (3x100x100 zeros by default)."""
        if dummy is None:
            # A fresh zero tensor instead of an uninitialised one created once
            # at definition time.
            dummy = torch.zeros(3, 100, 100)
        self.title_dic[tag] = title
        self.viz_image_dic[tag] = self.viz.image(dummy, opts=dict(title=title))

    def initImages(self, dummy=None):
        """Create an image window for every configured image tag."""
        for tag in self.image_taglst:
            self.initImage(tag, tag, dummy=dummy)

    def updateImage(self, tag, image, title, n_iter):
        """Show *image* (one image if 3-D, otherwise a batch) in window *tag*.

        When the input transform includes ``normalize`` the values are mapped
        with ``(x + 1) / 2 * 255``. ``n_iter`` is unused here.
        """
        self.title_dic[tag] = title
        if image is None:
            return
        if hasattr(image, 'detach'):
            # Tensors that require grad cannot be converted to numpy.
            image = image.detach()
        if self.is_cuda(image):
            image = image.cpu()

        if 'normalize' in self.setting['data']['base']['inputTransform']:
            image = (image + 1) / 2.0 * 255
        if image.dim() == 3:
            self.viz.image(image,
                           opts=dict(title=title),
                           win=self.viz_image_dic[tag])
        else:
            self.viz.images(image,
                            opts=dict(title=title),
                            win=self.viz_image_dic[tag])

    def updateImages(self, image_dic, n_iter):
        """Update every configured image window from ``image_dic[tag]``."""
        for tag in self.image_taglst:
            self.updateImage(tag, image_dic[tag], tag, n_iter)

    def is_cuda(self, tensor):
        """Return True if *tensor* lives on a CUDA device.

        Modern PyTorch CUDA tensors are plain ``torch.Tensor`` objects, so the
        ``is_cuda`` attribute is checked first; the type-name test keeps
        legacy ``torch.cuda.*Tensor`` types working.
        """
        return bool(getattr(tensor, 'is_cuda', False)) or \
            "cuda" in str(type(tensor))

    def destructVisdom(self):
        """No-op hook kept for interface compatibility."""
        pass
Ejemplo n.º 12
0
 def __init__(self, setting):
     """Read the loss tags from *setting* and reset the progress state."""
     handler = SettingHandler(setting)
     self.ch = handler
     self.loss_tags = handler.get_visdom_graph_tags()
     self.loss_dic = self._genLossDic()
     # Timer consulted for the elapsed-time field of the progress line.
     self.th = TimeHandler()
     self._initialize()
Ejemplo n.º 13
0
class TensorBoardXViewer:
    """Log scalars and images to TensorBoardX, one SummaryWriter per phase.

    Exposes the same interface as ``VisdomViewer``; methods that only make
    sense for Visdom are no-ops here.
    """

    def __init__(self, setting):
        self.ch = SettingHandler(setting)
        self.setting = setting
        self.total_dataset_length = 0
        # Selects which writer in writer_dic receives the next log call.
        self.mode = "train"

        self.n_iter = 0
        self.writer_dic = {
            "train": SummaryWriter('runs/train'),
            "test": SummaryWriter('runs/test'),
            "valid": SummaryWriter('runs/valid')
        }
        self.graph_taglst = self.ch.get_visdom_graph_tags()
        self.image_taglst = self.ch.get_visdom_image_tags()

    def initGraph(self, tag, xlabel=None, ylabel=None, title=None):
        """Reset the iteration counter; TensorBoard needs no window setup."""
        self.n_iter = 0

    def initGraphs(self):
        """Reset the iteration counter; TensorBoard needs no window setup."""
        self.n_iter = 0

    def setMode(self, mode):
        """Route subsequent logging to the writer for *mode* ('train'/'test'/'valid')."""
        self.mode = mode

    def setTotalDataLoaderLength(self, total_dataset_length):
        """Remember the data loader length (currently only stored)."""
        self.total_dataset_length = total_dataset_length

    def updateGraph(self, tag, x_value, y_value, opts=None, idx=0):
        """Log scalar *y_value* at step *x_value*; ``opts`` and ``idx`` are ignored."""
        self.writer_dic[self.mode].add_scalar(tag, y_value, x_value)

    def updateGraphs(self, x_value, value_dic, opts=None, idx=0):
        """Log ``value_dic[tag]`` for every configured graph tag."""
        for tag in self.graph_taglst:
            self.updateGraph(tag, x_value, value_dic[tag], opts=opts, idx=idx)

    def initImage(self, tag, title, dummy=None):
        """No-op; kept for interface compatibility with VisdomViewer."""
        pass

    def initImages(self, dummy=None):
        """No-op; kept for interface compatibility with VisdomViewer."""
        pass

    @staticmethod
    def _toUint8Image(image):
        """Map a tensor with values in [-1, 1] to uint8 pixels in [0, 255].

        Values are clipped before the cast so out-of-range inputs saturate
        instead of wrapping around.
        """
        pixels = ((image + 1) / 2.0 * 255).cpu().detach().numpy()
        return np.clip(pixels, 0, 255).astype(np.uint8)

    def updateImage(self, tag, image, title, n_iter):
        """Log a single image under *tag* at step *n_iter*.

        A 2-D (H, W) image gets a leading channel axis so it is CHW. When
        the input transform includes ``normalize`` the values are mapped
        back to uint8 pixels. ``title`` is unused.
        """
        if len(image.shape) == 2:
            image = image[None]
        if 'normalize' in self.setting['data']['base']['inputTransform']:
            image = self._toUint8Image(image)
        self.writer_dic[self.mode].add_image(tag, image, n_iter)

    def updateImages(self, image_dic, n_iter):
        """Log ``image_dic[tag]`` for every image tag, tiled and rescaled by make_grid."""
        for tag in self.image_taglst:
            self.writer_dic[self.mode].add_image(
                tag,
                vutils.make_grid(image_dic[tag],
                                 normalize=True,
                                 scale_each=True), n_iter)

    def is_cuda(self, tensor):
        """Return True if *tensor* lives on a CUDA device.

        Checks the ``is_cuda`` attribute (modern PyTorch) and falls back to
        the legacy ``torch.cuda.*Tensor`` type-name test.
        """
        return bool(getattr(tensor, 'is_cuda', False)) or \
            "cuda" in str(type(tensor))

    def destructVisdom(self):
        """No-op; kept for interface compatibility with VisdomViewer."""
        pass
Ejemplo n.º 14
0
class TrainMonitor:
    """Collect per-iteration losses and print a console progress line."""

    def __init__(self, setting):
        self.ch = SettingHandler(setting)
        self.loss_tags = self.ch.get_visdom_graph_tags()
        self.loss_dic = self._genLossDic()
        self.th = TimeHandler()
        self._initialize()

    def _initialize(self):
        """Reset the progress bar and clear the recorded losses."""
        self.progress_bar = ''
        self.progress_value = 0
        self.thresh = 0
        for tag in self.loss_tags:
            self.loss_dic[tag] = []

    def flash(self):
        """Reset state between epochs (same as ``_initialize``)."""
        self._initialize()

    def _genLossDic(self):
        """Return a dict mapping each loss tag to an empty list."""
        return {tag: [] for tag in self.loss_tags}

    def setLosses(self, loss_dic):
        """Record ``loss_dic[tag]`` for every configured loss tag."""
        for tag in self.loss_tags:
            self.setLoss(tag, loss_dic[tag])

    def setLoss(self, tag, value):
        """Append *value* (may be ``None``) to the history of *tag*."""
        self.loss_dic[tag].append(value)

    def _updateProgress(self, now_progress_value, now_progress_bar,
                        current_n_iter, n_iter):
        """Return the new ``(percent, bar)`` pair.

        One ``'8'`` is appended to the bar each time the previous percentage
        passes the next 10 % threshold.
        """
        if now_progress_value > self.thresh:
            now_progress_bar += r'8'
            self.thresh += 10
        progress_value = (1 + current_n_iter) / n_iter * 100
        return progress_value, now_progress_bar

    def _formatCurrentLosses(self):
        """Format the latest value of every loss; missing or ``None`` shows as ``None``."""
        text = 'losses: '
        for tag, values in self.loss_dic.items():
            # An empty history (nothing recorded yet) must not raise IndexError.
            latest = values[-1] if values else None
            if latest is None:
                text += "{}:{}, ".format(tag, latest)
            else:
                text += "{}:{:6.3f}, ".format(tag, latest)
        return text

    def _formatAverageLosses(self):
        """Format the mean of the recorded non-``None`` values of every loss."""
        text = ''
        for tag, values in self.loss_dic.items():
            # Skip None entries so a partially-None history averages cleanly
            # and an empty one prints None instead of raising.
            recorded = [v for v in values if v is not None]
            if recorded:
                text += "{} ave. loss: {:6.3f}, ".format(
                    tag, np.array(recorded).mean())
            else:
                text += "{} ave. loss: {}, ".format(tag, None)
        return text

    def dumpCurrentProgress(self, n_epoch, current_n_iter, n_iter):
        """Rewrite the current console line with progress and latest losses."""
        self.progress_value, self.progress_bar = \
            self._updateProgress(self.progress_value,
                                 self.progress_bar,
                                 current_n_iter,
                                 n_iter)

        losses = self._formatCurrentLosses()
        sys.stdout.write(
            "\repoch {:03d}: [{}] [{:10s}] {:4.0f} % {} {:3d}/{:3d}".format(
                n_epoch, self.th.getElapsedTime(is_now=False),
                self.progress_bar, self.progress_value, losses, current_n_iter,
                n_iter))

    def dumpAverageLossOnEpoch(self, n_epoch):
        """Finish the console line with the per-epoch average of every loss."""
        ave_losses = self._formatAverageLosses()
        sys.stdout.write("\repoch {:03d}: [{}] [{:10s}] {:4.0f} % {}\n".format(
            n_epoch, self.th.getElapsedTime(is_now=False), self.progress_bar,
            self.progress_value, ave_losses))
Ejemplo n.º 15
0
    def __init__(self,
                 settings,
                 gpu_ids=None):
        """Build a PatchGAN discriminator from ``settings['models']['patchGANDiscriminator']``.

        Args:
            settings: configuration mapping.
            gpu_ids: unused. GPU ids always come from the settings via
                ``SettingHandler.get_GPU_ID()``; the parameter is kept so
                existing callers that pass it keep working.
        """
        super(PatchGANModel, self).__init__(settings)

        self.settings = settings
        ch = SettingHandler(settings)
        self.gpu_ids = ch.get_GPU_ID()
        norm_layer = ch.get_norm_layer()

        disc_cfg = settings['models']['patchGANDiscriminator']

        # pix2pix controllers double the discriminator's input channels
        # (the original code notes this as "for image pooling").
        if settings['base']['controller'] in ('pix2pix', 'pix2pixMulti'):
            n_input_groups = 2
        else:
            n_input_groups = 1
        # Cast like every other numeric setting so string-valued configs work.
        input_nc = int(disc_cfg['input_n']) * n_input_groups

        # Bias is always enabled; it is not derived from the norm layer type.
        use_bias = True

        kw = int(disc_cfg['kernelSize'])
        padw = int(disc_cfg['paddingSize'])
        ndf = int(disc_cfg['numberOfDiscriminatorFilters'])
        n_layers = int(disc_cfg['nLayers'])
        is_sigmoid = disc_cfg['useSigmoid']

        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        # Stride-2 stages; the channel multiplier doubles per stage, capped at 8.
        nf_mult = 1
        nf_mult_prev = 1
        for layer_idx in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**layer_idx, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # One stride-1 stage, then a 1-channel stride-1 conv producing the patch map.
        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev,
                      ndf * nf_mult,
                      kernel_size=kw,
                      stride=1,
                      padding=padw,
                      bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * nf_mult,
                      1,
                      kernel_size=kw,
                      stride=1,
                      padding=padw)]

        if is_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)
Ejemplo n.º 16
0
    def __init__(self, settings):
        """Build the CycleGAN networks, losses and optimizers from *settings*.

        Creates generators ``netG_A``/``netG_B`` and discriminators
        ``netD_A``/``netD_B`` (moved to ``gpu_ids[0]`` when GPU ids are
        configured) and calls ``loadModels()``. It then creates the GAN,
        cycle and identity criteria, the generator optimizer, and, when
        ``data.base.isTrain`` is set, one optimizer per discriminator.
        """
        import aimaker.models.model_factory as mf
        import aimaker.loss.loss_factory as lf
        import aimaker.optimizers.optimizer_factory as of

        self.settings = settings
        ch = SettingHandler(settings)
        cycle_cfg = settings['controllers']['cycleGAN']

        self.gpu_ids = ch.get_GPU_ID()
        self.checkpoints_dir = ch.get_check_points_dir()
        model_factory = mf.ModelFactory(settings)
        loss_factory = lf.LossFactory(settings)
        optimizer_factory = of.OptimizerFactory(settings)
        use_gpu = len(self.gpu_ids) > 0

        # Image pools for discriminator regularization.
        self.pool_fake_A = ImagePool(int(cycle_cfg['imagePoolSize']))
        self.pool_fake_B = ImagePool(int(cycle_cfg['imagePoolSize']))

        name = cycle_cfg['generatorModel']
        self.netG_A = model_factory.create(name)
        self.netG_B = model_factory.create(name)
        if use_gpu:
            self.netG_A = self.netG_A.cuda(self.gpu_ids[0])
            self.netG_B = self.netG_B.cuda(self.gpu_ids[0])

        # Discriminators must exist on CPU too; they were previously only
        # created inside the GPU branch, so CPU runs crashed below.
        name = cycle_cfg['discriminatorModel']
        self.netD_A = model_factory.create(name)
        self.netD_B = model_factory.create(name)
        if use_gpu:
            self.netD_A = self.netD_A.cuda(self.gpu_ids[0])
            self.netD_B = self.netD_B.cuda(self.gpu_ids[0])

        self.loadModels()

        self.criterionGAN = loss_factory.create("GANLoss")
        self.criterionCycle = loss_factory.create(cycle_cfg['cycleLoss'])
        self.criterionIdt = loss_factory.create(cycle_cfg['idtLoss'])
        if use_gpu:
            self.criterionGAN = self.criterionGAN.cuda(self.gpu_ids[0])
            self.criterionCycle = self.criterionCycle.cuda(self.gpu_ids[0])
            self.criterionIdt = self.criterionIdt.cuda(self.gpu_ids[0])

        # Initialize optimizers: one shared by both generators ...
        self.optimizer_G = optimizer_factory.create(
            cycle_cfg['generatorOptimizer'])(
                it.chain(self.netG_A.parameters(),
                         self.netG_B.parameters()), settings)

        # ... and one per discriminator when training.
        if settings['data']['base']['isTrain']:
            self.optimizer_D_A = optimizer_factory.create(
                cycle_cfg['D_AOptimizer'])(self.netD_A.parameters(), settings)
            self.optimizer_D_B = optimizer_factory.create(
                cycle_cfg['D_BOptimizer'])(self.netD_B.parameters(), settings)

        if settings['ui']['base']['isShowModelInfo']:
            self.showModel()
Ejemplo n.º 17
0
class Trainer:
    """Run the train / test / valid phases of every epoch for one controller.

    Everything is configured from the settings loaded in ``__init__``.
    Progress is written to ``base.infoPath`` on Ctrl-C and resumed from it
    by ``train``.
    """

    def __init__(self, setting_path="settings"):
        """Load the settings file and build every component the loop needs.

        Validation-only members (``valid_dataset``, ``valid_data_loader``,
        ``validator``) always exist and stay ``None`` when
        ``data.base.isValid`` is false.
        """
        self.settings = settings = EasyDict(load_setting(setting_path))
        self.sh = SettingHandler(settings)
        self.controller = self.sh.get_controller()
        self.dataset = DatasetFactory(settings).create(
            settings.data.base.datasetName)
        is_valid = settings.data.base.isValid
        self.valid_dataset = None
        self.valid_data_loader = None
        self.validator = None
        if is_valid:
            self.valid_dataset = DatasetFactory(settings).create(
                settings['data']['base']['valid']['datasetName'])
        self.data_loader = self.dataset.getDataLoader()
        if is_valid:
            self.valid_data_loader = self.valid_dataset.getDataLoader()

        self.checkpoint_dir = self.sh.get_check_points_dir()
        # Keep the historical (misspelled) attribute name for existing callers.
        self.sheckpoint_dir = self.checkpoint_dir
        self.viz = viewer.TensorBoardXViewer(settings)
        self.train_monitor = viewer.TrainMonitor(settings)

        # Iteration intervals between graph / image refreshes.
        self.n_update_graphs = self.sh.get_update_interval_of_graphs(
            self.dataset)
        self.n_update_images = self.sh.get_update_interval_of_images(
            self.dataset)
        if is_valid:
            self.validator = Validator(settings)
        # Series index passed as ``idx`` to the viewer for each phase.
        self.idx_dic = EasyDict({'train': 0, 'test': 1, 'valid': 2})

    def _getInfo(self):
        """Return a fresh progress record: epoch 0 and all counters at zero."""
        info = EasyDict()
        info.current_epoch = 0
        info.train = EasyDict({"v_iter": 0, "current_n_iter": 0})
        info.test = EasyDict({"v_iter": 0, "current_n_iter": 0})
        info.valid = EasyDict({"v_iter": 0, "current_n_iter": 0})
        return info

    @staticmethod
    def _iterationRatio(train_n_iter, n_iter):
        """Return ``train_n_iter / n_iter``, or 0.0 for an empty loader (no division by zero)."""
        return train_n_iter / n_iter if n_iter else 0.0

    def _preparePhase(self, mode, dataset):
        """Switch *dataset*, viewer and controller to *mode* and return a fresh loader."""
        print('{}:'.format(mode))
        dataset.set_mode(mode)
        dataset.getTransforms()
        self.viz.setMode(mode)
        data_loader = dataset.getDataLoader()
        self.controller.set_mode(mode)
        return data_loader

    def train(self):
        """Run all remaining epochs, resuming from ``base.infoPath`` if it exists.

        Errors are printed and end the run early. Models are saved and the
        viewer is torn down in every case, including Ctrl-C / ``sys.exit``,
        which then propagate instead of being swallowed.
        """
        import traceback

        settings = self.settings
        n_epoch = settings['base']['nEpoch']
        is_view = settings['base']['isView']
        is_valid = settings.data.base.isValid
        if is_view:
            self.viz.initGraphs()
            self.viz.initImages()

        train_n_iter = len(self.data_loader)
        if is_valid:
            valid_n_iter = len(self.valid_data_loader)

        info_path = settings.base.infoPath
        if os.path.exists(info_path):
            # Resume from the snapshot written by _saveInfo; close the file.
            with open(info_path) as fp:
                info = EasyDict(json.load(fp))
        else:
            info = self._getInfo()

        try:
            model_save_interval = self.sh.get_model_save_interval()
            if is_valid:
                model_save_interval_valid = \
                    self.sh.get_model_save_interval_for_valid()
            for current_epoch in range(info.current_epoch, n_epoch):
                info.current_epoch = current_epoch
                if settings['data']['base']['isTrain']:
                    info.mode = "train"
                    self.data_loader = self._preparePhase('train', self.dataset)
                    info = self._learning('train', info, current_epoch,
                                          self.data_loader, train_n_iter)
                if settings['data']['base']['isTest']:
                    info.mode = "test"
                    self.data_loader = self._preparePhase('test', self.dataset)
                    info = self._learning('test', info, current_epoch,
                                          self.data_loader, train_n_iter)
                if self.valid_dataset is not None and is_valid:
                    info.mode = "valid"
                    self.valid_data_loader = self._preparePhase(
                        'valid', self.valid_dataset)
                    info = self._learning('valid', info, current_epoch,
                                          self.valid_data_loader,
                                          valid_n_iter)
                    if current_epoch != 0 and not current_epoch % model_save_interval_valid:
                        self.controller._save_model(
                            self.controller.get_model(),
                            settings['validator']['base']['modelPath'],
                            is_fcnt=False)
                        self.validator.upload()

                if not current_epoch % model_save_interval:
                    self.controller.save_models()
        except Exception:
            traceback.print_exc()
        finally:
            self.controller.save_models()
            if is_view:
                self.viz.destructVisdom()

    def _saveInfo(self, info):
        """Write *info* as JSON to ``base.infoPath`` so a later ``train`` can resume."""
        with open(self.settings.base.infoPath, 'w') as fp:
            json.dump(info, fp)

    def _learning(self, mode, info, current_epoch, data_loader, train_n_iter):
        """Run one pass of *mode* over *data_loader* and return the updated *info*.

        Batches before ``info[mode].current_n_iter`` are skipped (resume).
        Ctrl-C saves *info* and exits; a FileNotFoundError skips the batch;
        any other exception is printed and ends the pass.
        """
        import traceback

        n_iter = len(data_loader)
        ratio = self._iterationRatio(train_n_iter, n_iter)
        v_iter = info[mode].v_iter
        # NOTE(review): this is always the train/test loader's length, even in
        # valid mode — presumably intentional since graph x-values are scaled
        # to the train axis; confirm.
        self.viz.setTotalDataLoaderLength(len(self.data_loader))
        is_volatile = mode != 'train'
        for current_n_iter, data in enumerate(data_loader):
            if current_n_iter < info[mode].current_n_iter:
                continue
            info[mode].current_n_iter = current_n_iter
            info[mode].v_iter = v_iter
            try:
                self.controller.set_input(data, is_volatile)
                self.controller.forward()
                if mode == 'train':
                    self.controller.backward()
                loss_dic = self.controller.get_losses()
                if mode == 'valid':
                    self.validator.setLosses(loss_dic)
                self.train_monitor.setLosses(loss_dic)
                self.train_monitor.dumpCurrentProgress(current_epoch,
                                                       current_n_iter, n_iter)

                if not current_n_iter % self.n_update_graphs:
                    if self.settings['base']['isView']:
                        self.viz.updateGraphs(ratio * v_iter,
                                              loss_dic,
                                              idx=self.idx_dic[mode])
                if not current_n_iter % self.n_update_images:
                    if self.settings['base']['isView']:
                        self.viz.updateImages(self.controller.get_images(),
                                              current_n_iter)
            except KeyboardInterrupt:
                # Persist progress so the next run resumes here, then exit.
                self._saveInfo(info)
                sys.exit()
            except FileNotFoundError:
                traceback.print_exc()
            except Exception:
                traceback.print_exc()
                break
            v_iter += 1
        info[mode].v_iter = v_iter
        info[mode].current_n_iter = 0  # reset for the next epoch
        self.train_monitor.dumpAverageLossOnEpoch(current_epoch)
        self.train_monitor.flash()
        return info
Ejemplo n.º 18
0
class BaseDataset(data.Dataset):
    """Base class for (input, target) datasets configured from settings.

    Subclasses override ``_setDataSource``, ``_getInput`` and ``_getTarget``.
    ``getTransforms`` loads the transforms for the current mode; until it is
    called, items are returned untransformed.
    """

    def __init__(self, setting=None):
        super(BaseDataset, self).__init__()

        self.setting = setting
        self.ch = SettingHandler(setting)
        self.mode = 'train'
        self.nom_transform = self.ch.get_normalize_transform()
        self.to_tensor = transforms.ToTensor()
        self.is_shuffle = self.setting['data']['base']['isShuffle'] if setting is not None else True

        self.train_test_ratio = float(setting['data']['base']['trainTestRatio']) if setting is not None else 1

        # Defined up front so __getitem__ never hits AttributeError before
        # getTransforms() runs; set before _setDataSource so subclasses may
        # override them there.
        self.input_transform = None
        self.target_transform = None
        self.common_transform = None

        self.dsf = dsf.DataSourceFactory(train_test_ratio=self.train_test_ratio)
        self._setDataSource()

    def getDataLoader(self):
        """Return a DataLoader over this dataset using the batch size for the current mode."""
        return dataloader.DataLoader(self,
                                     batch_size=self.ch.get_batch_size(self.mode),
                                     shuffle=self.is_shuffle,
                                     )

    def getTransforms(self):
        """Load the input, target and common transforms for the current mode."""
        self.input_transform = self.ch.get_input_transform(self.mode)
        self.target_transform = self.ch.get_target_transform(self.mode)
        self.common_transform = self.ch.get_common_transform(self.mode)

    def _setDataSource(self):
        """Placeholder data source; subclasses replace ``self.ds``."""
        self.ds = "dammy data Source"

    def _getInput(self, index):
        """Return the raw input for *index*; subclasses override."""
        pass

    def _getTarget(self, index):
        """Return the raw target for *index*; subclasses override."""
        pass

    @staticmethod
    def _applyTransform(transform, value):
        """Apply *transform*; a list is transformed item-wise and joined with ``torch.cat``."""
        if isinstance(value, list):
            return torch.cat([transform(x) for x in value])
        return transform(value)

    def _inputTransform(self, input):
        return self._applyTransform(self.input_transform, input)

    def _targetTransform(self, target):
        return self._applyTransform(self.target_transform, target)

    def _seededCommonTransform(self, value, seed):
        """Apply ``common_transform`` to *value* with the random generators seeded by *seed*.

        Python's ``random`` is re-seeded as before. Torch's CPU generator is
        seeded as well, because newer torchvision random transforms draw from
        it. ``fork_rng`` restores the torch state afterwards, so the global
        torch RNG is left untouched.
        """
        random.seed(seed)
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(seed)
            return self.common_transform(value)

    def _commonTransform(self, input, target):
        """Apply the same random ``common_transform`` to input and target (items or lists)."""
        seed = random.randint(0, 2147483647)

        def apply(value):
            if isinstance(value, list):
                return [self._seededCommonTransform(v, seed) for v in value]
            return self._seededCommonTransform(value, seed)

        return apply(input), apply(target)

    def setMode(self, mode):
        """Forward *mode* to the data source, mapping 'valid' onto 'train'."""
        if mode == 'valid':
            # Was ``mode == 'train'`` — a no-op comparison, so 'valid' leaked through.
            mode = 'train'
        self.ds.setMode(mode)

    def __getitem__(self, index):
        input = self._getInput(index)
        target = self._getTarget(index)

        if self.common_transform:
            input, target = self._commonTransform(input, target)

        if self.input_transform:
            input = self._inputTransform(input)
        if self.target_transform:
            target = self._targetTransform(target)

        return input, target

    def __len__(self):
        return len(self.ds)