コード例 #1
0
ファイル: processor_lstm.py プロジェクト: xiang526/STEP
    def __init__(self, args, ftype, data_loader, data_max, data_min, C, T, V, F, num_classes, n_z=1024, device='cuda:0'):
        """Set up logging, the CVAE model, its loss and the optimizer.

        Args:
            args: Parsed options. This method reads ``work_dir``,
                ``save_log``, ``print_log``, ``num_epoch``, ``step``,
                ``optimizer``, ``base_lr``, ``nesterov`` and ``weight_decay``.
            ftype: Stored as ``self.ftype``; not used in this method.
            data_loader: Stored as ``self.data_loader``.
            data_max: Stored as ``self.data_max`` (presumably the upper
                normalisation bound of the data — confirm against callers).
            data_min: Stored as ``self.data_min`` (presumably the lower
                normalisation bound — confirm against callers).
            C, T, V, F: Stored as attributes of the same name; ``F`` and
                ``T`` are also passed to ``CVAE.CVAE``.
            num_classes: Passed to ``CVAE.CVAE`` and stored.
            n_z: Passed to ``CVAE.CVAE`` (presumably the latent size).
            device: Stored as ``self.device``; the model is moved there.

        Raises:
            ValueError: If ``args.optimizer`` is neither ``'SGD'`` nor
                ``'Adam'``.
        """

        self.args = args
        self.ftype = ftype
        self.data_loader = data_loader
        self.data_max = data_max
        self.data_min = data_min
        self.num_classes = num_classes
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.device = device
        self.io = torchlight.IO(
            self.args.work_dir,
            save_log=self.args.save_log,
            print_log=self.args.print_log)

        # model
        self.C = C
        self.T = T
        self.V = V
        self.F = F
        self.n_z = n_z
        # makedirs also creates missing parents and tolerates the directory
        # already existing; it still raises if work_dir is an existing file.
        os.makedirs(self.args.work_dir, exist_ok=True)
        self.model = CVAE.CVAE(F, T, self.n_z, num_classes)
        # Honour the requested device instead of always using 'cuda:0'
        # (the default argument keeps the previous behaviour).
        self.model.to(self.device)
        self.model.apply(weights_init)
        self.loss = vae_loss
        # Presumably the best loss seen so far and a flag set when it improves.
        self.best_loss = np.inf
        self.loss_updated = False
        # Epochs at fractions ``args.step`` of ``num_epoch``, rounded up
        # (presumably the learning-rate decay milestones).
        self.step_epochs = [np.ceil(float(self.args.num_epoch * x)) for x in self.args.step]
        self.best_epoch = None
        # NOTE(review): not used in this method; presumably initial mean /
        # log-sigma values consumed elsewhere — confirm.
        self.mean = 0.
        self.lsig = 1.

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.args.base_lr,
                momentum=0.9,
                nesterov=self.args.nesterov,
                weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.args.base_lr,
                weight_decay=self.args.weight_decay)
        else:
            raise ValueError(
                'Unsupported optimizer: {!r} (expected \'SGD\' or \'Adam\')'.format(
                    self.args.optimizer))
        # Current learning rate; starts at the configured base rate.
        self.lr = self.args.base_lr
コード例 #2
0
ファイル: processor.py プロジェクト: xiang526/STEP
    def __init__(self,
                 args,
                 data_loader,
                 C,
                 num_classes,
                 graph_dict,
                 device='cuda:0',
                 verbose=True):
        """Set up logging, the classifier model, its loss and the optimizer.

        Args:
            args: Parsed options. This method reads ``work_dir``,
                ``save_log``, ``print_log``, ``num_epoch``, ``step``,
                ``topk``, ``optimizer``, ``base_lr``, ``nesterov`` and
                ``weight_decay``.
            data_loader: Stored as ``self.data_loader``.
            C: Passed to ``classifier.Classifier``.
            num_classes: Passed to ``classifier.Classifier`` and stored.
            graph_dict: Passed to ``classifier.Classifier``.
            device: Stored as ``self.device``; the model is moved there.
            verbose: Stored as ``self.verbose``.

        Raises:
            ValueError: If ``args.optimizer`` is neither ``'SGD'`` nor
                ``'Adam'``.
        """

        self.args = args
        self.data_loader = data_loader
        self.num_classes = num_classes
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.device = device
        self.verbose = verbose
        self.io = torchlight.IO(self.args.work_dir,
                                save_log=self.args.save_log,
                                print_log=self.args.print_log)

        # model
        # makedirs also creates missing parents and tolerates the directory
        # already existing; it still raises if work_dir is an existing file.
        os.makedirs(self.args.work_dir, exist_ok=True)
        self.model = classifier.Classifier(C, num_classes, graph_dict)
        # Honour the requested device instead of always using 'cuda:0'
        # (the default argument keeps the previous behaviour).
        self.model.to(self.device)
        self.model.apply(weights_init)
        self.loss = nn.CrossEntropyLoss()
        self.best_loss = math.inf
        # Epochs at fractions ``args.step`` of ``num_epoch``, rounded up
        # (presumably the learning-rate decay milestones).
        self.step_epochs = [
            math.ceil(float(self.args.num_epoch * x)) for x in self.args.step
        ]
        self.best_epoch = None
        # One row with max(args.topk) columns, zero-initialised; presumably
        # the best top-k accuracies seen so far.
        self.best_accuracy = np.zeros((1, np.max(self.args.topk)))
        self.accuracy_updated = False

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.args.base_lr,
                                       momentum=0.9,
                                       nesterov=self.args.nesterov,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.args.base_lr,
                                        weight_decay=self.args.weight_decay)
        else:
            raise ValueError(
                'Unsupported optimizer: {!r} (expected \'SGD\' or \'Adam\')'.format(
                    self.args.optimizer))
        # Current learning rate; starts at the configured base rate.
        self.lr = self.args.base_lr
コード例 #3
0
ファイル: io.py プロジェクト: haryuu/lighttrack-
    def init_environment(self):
        """Create the logger, save the run arguments, and pick the device.

        Always sets ``self.io`` (a ``torchlight.IO`` rooted at
        ``self.arg.work_dir``) and saves ``self.arg`` through it.  When
        ``self.arg.use_gpu`` is truthy, the GPUs named by ``self.arg.device``
        are selected and occupied via torchlight, stored in ``self.gpus``, and
        ``self.dev`` becomes ``"cuda:0"``; otherwise ``self.dev`` is ``"cpu"``.
        """
        opts = self.arg
        self.io = torchlight.IO(opts.work_dir,
                                save_log=opts.save_log,
                                print_log=opts.print_log)
        self.io.save_arg(opts)

        # CPU-only run: nothing to reserve.
        if not opts.use_gpu:
            self.dev = "cpu"
            return

        # NOTE(review): "cuda:0" presumably maps to the first GPU chosen by
        # visible_gpu (e.g. via visibility masking) — confirm in torchlight.
        selected = torchlight.visible_gpu(opts.device)
        torchlight.occupy_gpu(selected)
        self.gpus = selected
        self.dev = "cuda:0"