Example #1
    def init_environment(self):
        # gpu
        if self.arg.use_gpu:
            gpus = torchlight.visible_gpu(self.arg.device)
            torchlight.occupy_gpu(gpus)
            self.gpus = gpus
            self.dev = "cuda:0"
        else:
            self.dev = "cpu"

        # random seed
        if self.arg.fix_random:
            SEED = 0
            #torch.backends.cudnn.deterministic = True
            torch.manual_seed(SEED)
            torch.cuda.manual_seed_all(SEED)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            np.random.seed(SEED)
            random.seed(SEED)

        # dir
        self.work_dir = get_dir(self.arg.work_dir, None)
        self.config_dir = get_dir(self.arg.work_dir, 'config')
        self.checkpoint_dir = get_dir(self.arg.work_dir, 'checkpoints')
        self.log_dir = get_dir(self.arg.work_dir, 'log')
        self.tflog_dir = get_dir(self.arg.work_dir, 'tflog')
        self.test_dir = get_dir(self.arg.work_dir, 'test_results')
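The get_dir helper used above is not shown in this example. As a rough sketch only, assuming it joins an optional subdirectory onto the work directory and creates it if missing (consistent with how it is called here), it might look like:

import os

def get_dir(base_dir, sub_dir=None):
    # Hypothetical reconstruction, inferred from the call sites above:
    # resolve the optional subdirectory under the work dir and create it.
    path = base_dir if sub_dir is None else os.path.join(base_dir, sub_dir)
    os.makedirs(path, exist_ok=True)
    return path
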
Example #2
    def init_environment(self):
        # super().init_environment()  # would run the parent-class method first; this override extends the parent's version
        # Parent-class code, copied here verbatim for easy reference
        self.io = torchlight.IO(  # class 'torchlight.io.IO'; torchlight.IO is essentially a container
            self.arg.work_dir,  # path for storing results, default './work_dir/tmp'
            save_log=self.arg.save_log,  # whether to save the log, default True
            print_log=self.arg.print_log)  # whether to print the log, default True
        self.io.save_arg(self.arg)

        # gpu
        if self.arg.use_gpu:  # if GPU use is requested, default to GPU 0; otherwise use the CPU
            gpus = torchlight.visible_gpu(
                self.arg.device)  # returns list(range(len(gpus)))
            torchlight.occupy_gpu(gpus)

            self.gpus = gpus
            print('GPUs currently in use:', self.gpus)
            self.dev = "cuda:0"
        else:
            self.dev = "cpu"
        # Subclass-specific code:
        # define several dict objects for bookkeeping
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
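As the commented-out call at the top of Example #2 suggests, the same setup can be obtained without copying the parent body. A minimal sketch of that variant, assuming the parent class implements the torchlight.IO and GPU setup shown above:

    def init_environment(self):
        # Reuse the parent-class setup (torchlight.IO, argument saving, GPU
        # selection) instead of duplicating it, then add the subclass dicts.
        super().init_environment()
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
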
Example #3
File: io.py  Project: zgsxwsdxg/st-gcn
    def init_environment(self):
        self.io = torchlight.IO(self.arg.work_dir,
                                save_log=self.arg.save_log,
                                print_log=self.arg.print_log)
        self.io.save_arg(self.arg)

        # gpu
        gpus = torchlight.visible_gpu(self.arg.device)
        torchlight.occupy_gpu(gpus)
        self.gpus = gpus
        self.dev = "cuda:0"
Example #4
    def init_environment(self):
        self.io = torchlight.IO(self.arg.work_dir,
                                save_log=self.arg.save_log,
                                print_log=self.arg.print_log)
        self.io.save_arg(self.arg)
        print('Mid of init environment')
        # gpu
        if self.arg.use_gpu:
            gpus = torchlight.visible_gpu(self.arg.device)
            torchlight.occupy_gpu(gpus)
            self.gpus = gpus
            self.dev = "cuda:0"
        else:
            self.dev = "cpu"
Example #5
    def init_environment(self):
        self.io = torchlight.IO(  # class 'torchlight.io.IO'; torchlight.IO is essentially a container
            self.arg.work_dir,  # path for storing results, default './work_dir/tmp'
            save_log=self.arg.save_log,  # whether to save the log, default True
            print_log=self.arg.print_log)  # whether to print the log, default True
        self.io.save_arg(self.arg)

        # gpu
        if self.arg.use_gpu:  # if GPU use is requested, default to GPU 0; otherwise use the CPU
            gpus = torchlight.visible_gpu(self.arg.device)
            torchlight.occupy_gpu(gpus)
            self.gpus = gpus
            self.dev = "cuda:0"
        else:
            self.dev = "cpu"
Example #6
    def init_environment(self):
        self.io = torchlight.IO(
            self.arg.work_dir,
            save_log=self.arg.save_log,
            print_log=self.arg.print_log)
        self.io.save_arg(self.arg)

        # gpu
        if self.arg.use_gpu:
            gpus = torchlight.visible_gpu(self.arg.device)
            torchlight.occupy_gpu(gpus)
            self.gpus = gpus
            self.dev = "cuda:0"
        else:
            self.dev = "cpu"
Example #7
File: io.py  Project: zzzzlalala/AS-GCN
    def init_environment(self):
        self.save_dir = os.path.join(self.arg.work_dir, self.arg.max_hop_dir,
                                     self.arg.lamda_act_dir)
        self.io = torchlight.IO(self.save_dir,
                                save_log=self.arg.save_log,
                                print_log=self.arg.print_log)
        self.io.save_arg(self.arg)

        # gpu
        if self.arg.use_gpu:
            gpus = torchlight.visible_gpu(self.arg.device)
            torchlight.occupy_gpu(gpus)
            self.gpus = gpus
            self.dev = "cuda:0"
        else:
            self.dev = "cpu"
Example #8
File: io.py  Project: zbliu98/DMGNN
    def init_environment(self):
        self.save_dir = os.path.join(self.arg.work_dir,
                                     self.arg.fusion_layer_dir,
                                     self.arg.learning_rate_dir,
                                     self.arg.lamda_dir,
                                     self.arg.crossw_dir,
                                     self.arg.note)
        self.io = torchlight.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log)
        self.io.save_arg(self.arg)

        if self.arg.use_gpu:
            gpus = torchlight.visible_gpu(self.arg.device)
            torchlight.occupy_gpu(gpus)
            self.gpus = gpus
            self.dev = "cuda:0"
        else:
            self.dev = "cpu"
Example #9
    def load_model(self):
        # pdb.set_trace()
        self.arg.device = torchlight.visible_gpu(self.arg.device)
        output_device = self.arg.device[0] if type(self.arg.device) is list else self.arg.device
        self.output_device = output_device
        Model = import_class(self.arg.model)
        shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
        print(Model)
        self.model = Model(**self.arg.model_args)
        # pdb.set_trace()
        print(self.model)
        self.loss = nn.CrossEntropyLoss().cuda(output_device)

        if self.arg.weights:
            self.global_step = int(self.arg.weights[:-3].split('-')[-1])
            self.print_log('Load weights from {}.'.format(self.arg.weights))
            if '.pkl' in self.arg.weights:
                with open(self.arg.weights, 'rb') as f:
                    weights = pickle.load(f)
            else:
                weights = torch.load(self.arg.weights)

            weights = OrderedDict(
                [[k.split('module.')[-1],
                  v.cuda(output_device)] for k, v in weights.items()])

            keys = list(weights.keys())
            for w in self.arg.ignore_weights:
                for key in keys:
                    if w in key:
                        if weights.pop(key, None) is not None:
                            self.print_log('Successfully Remove Weights: {}.'.format(key))
                        else:
                            self.print_log('Can Not Remove Weights: {}.'.format(key))

            try:
                self.model.load_state_dict(weights)
            except:
                state = self.model.state_dict()
                diff = list(set(state.keys()).difference(set(weights.keys())))
                print('Can not find these weights:')
                for d in diff:
                    print('  ' + d)
                state.update(weights)
                self.model.load_state_dict(state)
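All of the examples above depend on torchlight.visible_gpu and torchlight.occupy_gpu, whose implementations are not shown here. Based only on the inline comment in Example #2 (visible_gpu returns list(range(len(gpus)))) and on how the results are used, an assumed, non-authoritative sketch of their behaviour:

import os
import torch

def visible_gpu(device):
    # Assumed behaviour: expose only the requested device ids via
    # CUDA_VISIBLE_DEVICES and return their re-indexed ids,
    # i.e. list(range(len(gpus))).
    device = [device] if isinstance(device, int) else list(device)
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(d) for d in device)
    return list(range(len(device)))

def occupy_gpu(gpus):
    # Assumed behaviour: allocate a small tensor on each visible GPU so the
    # devices are initialised (and claimed) up front.
    for g in gpus:
        torch.zeros(1, device='cuda:{}'.format(g))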