Пример #1
0
    def save(self, info):
        """Save this module's weights and config as a new ``.model`` checkpoint.

        Before writing, ``info['path']`` is scanned for ``.model`` files whose
        second-to-last ``_``-separated field (as an int) equals
        ``info['controller'].get_current_main_module()``. Those files are
        deleted after the new checkpoint has been written.

        Args:
            info: dict with at least ``'path'`` (target directory),
                ``'controller'`` and ``'module_name'``.

        Side effects:
            Creates the directory, writes the checkpoint with ``torch.save``,
            may delete superseded checkpoints, and stores the new path in the
            app variable ``$latest_<module_name>``.
        """
        dst = info['path']
        App.instance().make_save_dir(dst)

        superseded = []
        for name in os.listdir(dst):
            # os.path.splitext keeps the leading dot, so the extension is
            # '.model'. Comparing against 'model' never matched, and stale
            # checkpoints were never cleaned up.
            _, file_extension = os.path.splitext(name)
            if file_extension != '.model':
                continue
            fields = name.split('_')
            if len(fields) < 2:
                continue
            try:
                key = int(fields[-2])
            except ValueError:
                # Not a name produced by get_name_format(); leave it alone.
                continue
            # NOTE(review): assumes get_current_main_module() returns a value
            # comparable with this integer field -- confirm against controller.
            if key == info['controller'].get_current_main_module():
                superseded.append(name)

        saved_data = {
            'data': self.state_dict(),
            'config': Config.extraction_dictionary(self.config)
        }
        saved_path = os.path.join(dst,
                                  self.get_name_format(info['controller']))
        torch.save(saved_data, saved_path)

        # Remove superseded checkpoints only once the new one is on disk, and
        # never the file that was just written under the same name.
        for name in superseded:
            old_path = os.path.join(dst, name)
            if old_path != saved_path:
                os.remove(old_path)

        App.instance().set_variables('$latest_{}'.format(info['module_name']),
                                     saved_path)
Пример #2
0
 def save(self, info):
     """Append one line of loss values for the current training step to a file.

     The line starts with the app's current time and ``<epoch> epoch, <step>
     step``, followed by ``|Total: <total_loss>`` and ``|<name>: <loss>`` for
     every entry of ``self.loss_dict``. Each field ends with a tab.

     Args:
         info: dict with ``'state'`` (exposing ``epoch`` and ``step``) and
             ``'path'`` (the file that is appended to).
     """
     state = info['state']
     stamp = App.instance().current_time_format()
     pieces = [stamp + " {} epoch, {} step\n".format(state.epoch, state.step)]
     pieces.append("|{}: {}\t".format('Total', str(self.total_loss)))
     for loss_name, loss in self.loss_dict.items():
         pieces.append("|{}: {}\t".format(loss_name, str(loss)))
     line = ''.join(pieces)
     App.instance().smart_write(line + '\n', info['path'], 'a+')
Пример #3
0
 def get_name_format(self, controller):
     """Return the checkpoint file name for the controller's current state.

     The template combines the app time format, the app name and the main
     module's step / total_step, and ends in ``.model``. It is rendered by
     ``MainStateBasedFormatter``.
     """
     variables = {}
     variables['model_name'] = App.instance().name
     variables['time'] = App.instance().time_format()
     template = ('[$time]_[$model_name]_[$main:step:03]'
                 '_[$main:total_step:08].model')
     return MainStateBasedFormatter(controller, variables, template).Formatting()
    def __init__(self):
        """Initialise controller state from the app's ``MODEL_CONTROLLER`` config.

        ``MODEL`` and ``OPTIMIZER`` start as ``None`` (presumably assigned
        later, e.g. by a subclass). ``LOSSES`` is a ``LossContainer`` built
        from ``config.LOSSES`` and passed through ``App.set_gpu_device``.
        """
        self.config = App.instance().config.MODEL_CONTROLLER
        self.dataloader_controller = DataLoaderController.instance()

        self.MODEL: torch.nn.Module = None
        # An optimizer is a torch.optim.Optimizer, not an nn.Module.
        self.OPTIMIZER: torch.optim.Optimizer = None
        self.LOSSES: torch.nn.Module = App.instance().set_gpu_device(
            LossContainer(self.config.LOSSES))

        self.COMMAND_CONTROLLER = RunnableModuleController(
            self.config.COMMAND_CONTROLLER, self)
        self.sample = None

        # Computed last so that every attribute assigned above is included.
        self.all_callable = [method_name for method_name in dir(self)
                             if callable(getattr(self, method_name))]
Пример #5
0
    def run(self, controller):
        """Call ``save`` on every controller module named in ``config.required``.

        For each name, a deep copy of ``self.config.args[name]`` is extended
        with ``module_name``, ``controller`` and a resolved ``path`` (default
        ``$base/ckpt_<name>``), then passed to the module's ``save``.

        Names are logged as errors and skipped when:
        - they are not listed in ``dir(controller)``,
        - they are not ``torch.nn.Module`` instances, or
        - they have no callable ``save``.
        """
        for name in self.config.required:
            if name not in dir(controller):
                write_log('[ERROR] save fail: {} is not Modules.'.format(name),
                          controller)
                continue

            args = copy.deepcopy(self.config.args[name])
            args['module_name'] = name
            args['controller'] = controller
            args['path'] = dir_path_parser(
                args.get('path',
                         '$base/ckpt_{}'.format(name),
                         possible_none=False))

            module = getattr(controller, name)
            if not isinstance(module, torch.nn.Module):
                write_log('[ERROR] save fail: {} is not Modules'.format(name),
                          controller)
                continue
            # The None default lets a module without `save` reach the error
            # log below instead of raising AttributeError here.
            if not callable(getattr(module, 'save', None)):
                write_log(
                    '[ERROR] save fail: {} must to have save method.'.format(
                        name), controller)
                continue

            module.save(args)
            # Report the module that was just saved, not the whole list.
            write_log(
                '[INFO] {} Save scucesses [{}] '.format(
                    name,
                    App.instance().get_variables('$latest_{}'.format(name))),
                controller)
    def controller_factory(cls, config=None):
        """Resolve the configured controller class and return its ``instance()``.

        Args:
            config: controller config exposing ``MODULE_NAME`` and
                ``CLASS_NAME``. Defaults to
                ``App.instance().config.MODEL_CONTROLLER``.
        """
        if config is not None:
            cfg = config
        else:
            cfg = App.instance().config.MODEL_CONTROLLER
        controller_cls = get_class_object_from_name(cfg.MODULE_NAME,
                                                    cfg.CLASS_NAME)
        return controller_cls.instance()
Пример #7
0
def main(configs):
    """Set up the app from ``configs`` and run the model controller's commands.

    Args:
        configs: a list, which is passed to ``App.make_from_config_list``, or
            a single config, which is passed to ``App.instance`` and followed
            by ``update()``.

    Errors raised while running are reported with their full traceback rather
    than propagated. Worker processes are always removed at the end.
    """
    import traceback

    if isinstance(configs, list):
        App.make_from_config_list(configs)
    else:
        App.instance(configs)
        App.instance().update()
    # Called for its side effect (presumably creating the data-loader
    # singleton before the controller is built); the result is not needed.
    DataLoaderController.instance()

    trainer: ModelController = ModelController.controller_factory()

    try:
        trainer.COMMAND_CONTROLLER.run()
    except Exception:
        # print(e) showed only the message; keep the failure non-fatal but
        # report where it happened.
        traceback.print_exc()
    finally:
        MultipleProcessorController.instance().remove_all_process()
Пример #8
0
    def __init__(self, config):
        """Build one loss module per entry of ``config``, plus a total-loss holder.

        Args:
            config: mapping-like object from loss name to that loss's
                settings. Each entry becomes a module via
                ``LossFactory.instantiate_loss``.
        """
        super().__init__()

        losses = nn.ModuleDict()
        self.loss_dict = losses
        for key in config.keys():
            losses[key] = LossFactory.instantiate_loss(key, config[key])

        # Presumably holds the combined loss value (logged as 'Total').
        self.total_loss = LossEmpty('TotalLoss')
        self.device = App.instance().get_device()
Пример #9
0
    def run(self, controller):
        """Queue the images of each required sample tensor to be written to disk.

        For each name in ``self.config.required``:
        - ``controller.sample[name]`` is cloned and converted by
          ``numpy_trasnsform``;
        - every entry along the first axis becomes one file in the configured
          directory (default ``$base/visual/img``), named from the
          epoch / step / content / batch template.

        The writes are handed to one worker via
        ``MultipleProcessorController.push_data``.
        """
        out_dir = dir_path_parser(
            self.config.args.get('path',
                                 '$base/visual/img',
                                 possible_none=False))
        img_format = self.config.args.get('format', 'png', possible_none=False)

        App.instance().make_save_dir(out_dir)
        formatter = MainStateBasedFormatter(
            controller,
            {'content': '', 'format': img_format, 'batch': 0},
            format=
            '[$main:epoch:03]e_[$main:step:08]s_[$content]_[$batch].[$format]')

        jobs = []
        for name in self.config.required:
            formatter.contents['content'] = name
            converted = self.numpy_trasnsform(
                {name: controller.sample[name].clone()})
            batch = converted[name]
            # Unpacking enforces a 4-D array, as before.
            batch_size, _, _, _ = batch.shape
            for idx in range(batch_size):
                formatter.contents['batch'] = str(idx).zfill(4)
                jobs.append((os.path.join(out_dir, formatter.Formatting()),
                             batch[idx]))

        import time

        def batched_image_save(queue):
            # Worker loop: a None item tells the worker to stop.
            while True:
                item = queue.get()
                if item is None:
                    break
                img_path, img = item
                # NOTE(review): if `misc` is scipy.misc, imsave was removed in
                # SciPy 1.2 -- this needs an old SciPy; consider imageio.
                misc.imsave(img_path, img)
                time.sleep(0.001)

        MultipleProcessorController.instance().push_data(
            self.__class__.__name__, batched_image_save, jobs, num_worker=1)
Пример #10
0
    def __init__(self):
        """Build the example model and, if one was created, its optimizer.

        The model comes from ``BaseModel.model_factory`` using
        ``config.MODEL.MODEL_CONFIG_PATH`` and is passed through
        ``App.set_gpu_device``. ``all_callable`` is recomputed so that it
        reflects the attributes set here.
        """
        super(ExampleContoller, self).__init__()

        self.MODEL = BaseModel.model_factory(
            self.config.MODEL.MODEL_CONFIG_PATH)
        self.MODEL = App.instance().set_gpu_device(self.MODEL)

        if self.MODEL is not None:
            # Describe the model only after the None check. Previously this
            # ran first and raised AttributeError when no model was built.
            print(self.MODEL.description())
            self.OPTIMIZER = make_optimizer(self.config.OPTIMIZER, self.MODEL)

        self.all_callable = [
            method_name for method_name in dir(self)
            if callable(getattr(self, method_name))
        ]
        self.sample = None
Пример #11
0
 def get_save_name(self, state):
     """Return ``<formatted app name>_<epoch>_<step>.opt`` for the given state."""
     prefix = App.instance().name_format(App.instance().name)
     suffix = "_{}_{}.opt".format(state.epoch, state.step)
     return prefix + suffix
Пример #12
0
 def forward(self, samples):
     """Move each tensor named in ``self.inputs_name`` to the app device as float.

     ``samples`` is updated in place and also returned.
     """
     for key in self.inputs_name:
         moved = samples[key].to(device=App.instance().get_device())
         samples[key] = moved.float()
     return samples
Пример #13
0
 def description(self):
     """Return ``str(self)`` plus input/output names and device/GPU ids, ending in a newline."""
     head = str(self)
     io_part = "=> IN:{}, OUT:{} ".format(self.inputs_name, self.outputs_name)
     device_part = "device: {}:{}\n".format(App.instance().get_device(),
                                           App.instance().get_gpu_ids())
     return head + io_part + device_part
Пример #14
0
def dir_path_parser(path):
    """Resolve ``path`` via ``App.variable_parsing`` and join the parts.

    ``variable_parsing`` is called with ``'/'`` as its second argument, and
    its result is unpacked into ``os.path.join``.

    NOTE(review): if the parts start with an empty string (an absolute path
    split on '/'), ``os.path.join`` drops the leading root -- confirm what
    ``variable_parsing`` returns for absolute paths.
    """
    segments = App.instance().variable_parsing(path, '/')
    return os.path.join(*segments)