class TorchRoadSignsModel(ZMLModel):
    """Road-signs image classifier built on ZTorchSimpleModel.

    Loads the ``train/`` and ``valid/`` image folders, creates the network,
    trains it with SGD + cross-entropy and saves it to ``model_filename``.
    :meth:`build` skips all of this when ``model_filename`` already exists.
    """

    # Side length (pixels) every image is resized to; the network input is
    # (3, IMAGE_SIZE, IMAGE_SIZE). A previously used value was 128.
    IMAGE_SIZE = 64

    ##
    # Constructor
    def __init__(self,
                 dataset_id,
                 epochs=0,
                 mainv=None,
                 ipaddress="127.0.0.1",
                 port=7777):
        """Initialise the wrapper (no data is loaded and no network is built).

        Args:
            dataset_id: dataset identifier; only ``ROADSIGNS`` is supported by
                :meth:`load_dataset`.
            epochs: number of training epochs passed to ``fit``.
            mainv: view object forwarded to ``ZMLModel`` (may be None).
            ipaddress: host handed to the epoch-change notifier callback.
            port: port handed to the epoch-change notifier callback.
        """
        # NOTE(review): the meaning of the leading 0 is defined by ZMLModel.
        super(TorchRoadSignsModel, self).__init__(0, mainv)

        self._start(self.__init__.__name__)

        self.write("dataset_id:{}, ephochs:{}, mainv:{}".format(
            dataset_id, epochs, mainv))
        self.ipaddress = ipaddress
        self.port = port
        self.model = None  # created by create()
        self.dataset_id = dataset_id  # ROADSIGNS
        self.dataset = None
        self.epochs = epochs
        # Also derives model_filename and resets nclasses.
        self.set_dataset_id(dataset_id)

        notifier = "{}-{}".format(self.__class__.__name__, self.dataset_id)

        # NOTE(review): the "+ 10" on the epoch count is kept from the original;
        # its meaning is defined by ZTorchEpochChangeNotifier.
        self.callbacks = [
            ZTorchEpochChangeNotifier(ipaddress, port, notifier,
                                      int(self.epochs) + 10)
        ]

        self._end(self.__init__.__name__)

    def set_dataset_id(self, dataset_id):
        """Select the dataset and derive the weight-file name from it.

        The (relative) file name is ``<ClassName>_<dataset_id>.pt``;
        ``nclasses`` is reset to 0 until :meth:`load_dataset` runs.
        """
        self._start(self.set_dataset_id.__name__)
        self.dataset_id = dataset_id

        self.model_filename = "{}_{}.pt".format(self.__class__.__name__,
                                                self.dataset_id)

        self.nclasses = 0
        self.write("model_filename  " + self.model_filename)

        self._end(self.set_dataset_id.__name__)

    def build(self):
        """Load data, create, train and save the model unless already trained.

        Failures are reported with a printed traceback rather than propagated.
        """
        self.write("====================================")
        self._start(self.build.__name__)

        if not self.is_trained():
            try:
                self.load_dataset()
                self.create()

                self.train()
                self.save()

            except Exception:
                # Best-effort: report and keep the caller running, but do not
                # swallow KeyboardInterrupt / SystemExit as a bare except did.
                traceback.print_exc()

        self._end(self.build.__name__)

    #
    def load_dataset(self,
                     data_root="./dataset/",
                     batch_size_train=64,
                     batch_size_test=16,
                     try_augmentation=False):
        """Create the train/valid datasets and their DataLoaders.

        Args:
            data_root: directory containing ``train/`` and ``valid/`` folders.
            batch_size_train: batch size of the shuffled training loader.
            batch_size_test: batch size of the (unshuffled) validation loader.
            try_augmentation: if True, apply random crop + horizontal flip to
                the training images.

        Raises:
            ValueError: if ``dataset_id`` is not ``ROADSIGNS``.
        """
        if self.dataset_id != ROADSIGNS:
            raise ValueError(
                "Unsupported dataset_id: {}".format(self.dataset_id))

        self._start(self.load_dataset.__name__)

        size = (self.IMAGE_SIZE, self.IMAGE_SIZE)
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

        if try_augmentation:
            # Resize slightly larger and randomly crop back to IMAGE_SIZE so
            # augmented images keep the input size create() builds the network
            # for (the original cropped to IMAGE_SIZE-4, a size mismatch).
            # NOTE(review): horizontal flips mirror directional signs
            # (e.g. left/right arrows) -- confirm this is acceptable.
            self.train_transformer = transforms.Compose([
                transforms.Resize((self.IMAGE_SIZE + 4, self.IMAGE_SIZE + 4)),
                transforms.RandomCrop(size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            self.train_transformer = transforms.Compose([
                transforms.Resize(size),
                transforms.ToTensor(),
                normalize,
            ])

        self.valid_transformer = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            normalize,
        ])

        # Create train/valid datasets.
        train_root = data_root + "/train/"
        self.train_dataset = TorchRoadSignsDataset(
            root=train_root, transform=self.train_transformer)
        valid_root = data_root + "/valid/"
        self.valid_dataset = TorchRoadSignsDataset(
            root=valid_root, transform=self.valid_transformer)
        self.classes = self.train_dataset.classes
        self.nclasses = self.train_dataset.nclasses

        # Create train/valid loaders.
        self.train_loader = DataLoader(dataset=self.train_dataset,
                                       batch_size=batch_size_train,
                                       shuffle=True,
                                       num_workers=2)
        self.valid_loader = DataLoader(dataset=self.valid_dataset,
                                       batch_size=batch_size_test,
                                       shuffle=False,
                                       num_workers=2)

        self._end(self.load_dataset.__name__)

    # Create a sequential model
    def create(self):
        """Instantiate ZTorchSimpleModel for ``nclasses`` outputs and move it
        to CUDA when available, otherwise to the CPU."""
        self._start(self.create.__name__)
        self.image_size = (3, self.IMAGE_SIZE, self.IMAGE_SIZE)

        print("classes {}".format(self.nclasses))
        self.model = ZTorchSimpleModel(self.image_size, self.nclasses,
                                       self.model_filename)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.model.to(device)

        self._end(self.create.__name__)

    def train(self):
        """Fit the model with SGD (lr=0.01, momentum=0.9, weight_decay=5e-4)
        and cross-entropy loss, then log the elapsed time and a summary."""
        self._start(self.train.__name__)
        start = time.time()
        criterion = nn.CrossEntropyLoss()

        optimizer = optim.SGD(self.model.parameters(),
                              lr=0.01,
                              momentum=0.9,
                              weight_decay=5e-4)

        self.model.fit(self.train_loader, self.valid_loader, self.callbacks,
                       self.epochs, criterion, optimizer)

        elapsed_time = time.time() - start
        self.write("Train elapsed_time:{0}[sec]".format(elapsed_time))
        self.model.summary()
        self._end(self.train.__name__)

    def predict(self, input):
        """Return ``self.model.predict(input)``.

        ``input`` is passed through unchanged; any transform/batching must be
        done by the caller.
        """
        return self.model.predict(input)

    def save(self):
        """Save the model via ``self.model.save()``."""
        self._start(self.save.__name__)
        self.model.save()
        self._end(self.save.__name__)

    def load(self):
        """Load weights via ``self.model.load_model()``; errors are written to
        the log instead of being raised."""
        self._start(self.load.__name__)

        try:
            self.model.load_model()
        except Exception:
            self.write(formatted_traceback())

        self._end(self.load.__name__)

    def get_model(self):
        """Return the wrapped model (None until create() has run)."""
        return self.model

    def is_trained(self):
        """Return True if ``model_filename`` exists as a regular file."""
        found = os.path.isfile(self.model_filename)
        if found:
            self.write("Found model_filename:'{}'".format(self.model_filename))
        return found
# Example No. 2
# 0
class TorchCIFARModel(ZMLModel):
    """CIFAR-10 / CIFAR-100 classifier built on ZTorchSimpleModel.

    Downloads the selected torchvision CIFAR dataset, creates the network,
    trains it with SGD + cross-entropy and saves it to ``model_filename``.
    :meth:`build` skips all of this when ``model_filename`` already exists.
    """

    ##
    # Constructor
    def __init__(self,
                 dataset_id,
                 epochs=0,
                 mainv=None,
                 ipaddress="127.0.0.1",
                 port=7777):
        """Initialise the wrapper (no data is loaded and no network is built).

        Args:
            dataset_id: ``CIFAR10`` or ``CIFAR100``.
            epochs: number of training epochs passed to ``fit``.
            mainv: view object forwarded to ``ZMLModel`` (may be None).
            ipaddress: host handed to the epoch-change notifier callback.
            port: port handed to the epoch-change notifier callback.
        """
        # NOTE(review): the meaning of the leading 0 is defined by ZMLModel.
        super(TorchCIFARModel, self).__init__(0, mainv)

        self._start(self.__init__.__name__)

        self.write("dataset_id:{}, ephochs:{}, mainv:{}".format(
            dataset_id, epochs, mainv))
        self.ipaddress = ipaddress
        self.port = port
        self.model = None  # created by create()
        self.dataset_id = dataset_id  # CIFAR10 or CIFAR100
        self.dataset = None
        self.epochs = epochs
        # Also derives model_filename and resets nclasses.
        self.set_dataset_id(dataset_id)

        notifier = "{}-{}".format(self.__class__.__name__, self.dataset_id)

        # NOTE(review): the "+ 10" on the epoch count is kept from the original;
        # its meaning is defined by ZTorchEpochChangeNotifier.
        self.callbacks = [
            ZTorchEpochChangeNotifier(ipaddress, port, notifier,
                                      int(self.epochs) + 10)
        ]

        self._end(self.__init__.__name__)

    def set_dataset_id(self, dataset_id):
        """Select the dataset and derive the weight-file name from it.

        The (relative) file name is ``<ClassName>_<dataset_id>.pt``;
        ``nclasses`` is reset to 0 until :meth:`load_dataset` runs.
        """
        self._start(self.set_dataset_id.__name__)
        self.dataset_id = dataset_id

        self.model_filename = "{}_{}.pt".format(self.__class__.__name__,
                                                self.dataset_id)

        self.nclasses = 0
        self.write("model_filename  " + self.model_filename)

        self._end(self.set_dataset_id.__name__)

    def build(self):
        """Load data, create, train and save the model unless already trained.

        Failures are reported with a printed traceback rather than propagated.
        """
        self.write("====================================")
        self._start(self.build.__name__)

        if not self.is_trained():
            try:
                self.load_dataset()
                self.create()

                self.train()
                self.save()

            except Exception:
                # Best-effort: report and keep the caller running, but do not
                # swallow KeyboardInterrupt / SystemExit as a bare except did.
                traceback.print_exc()

        self._end(self.build.__name__)

    #
    def load_dataset(self,
                     data_root="./data",
                     batch_size_train=128,
                     batch_size_test=64):
        """Download (if needed) the selected CIFAR dataset and build loaders.

        Args:
            data_root: directory torchvision stores/downloads the data in.
            batch_size_train: batch size of the shuffled training loader.
            batch_size_test: batch size of the (unshuffled) test loader.

        Raises:
            ValueError: if ``dataset_id`` is neither ``CIFAR10`` nor
                ``CIFAR100``.
        """
        # torchvision dataset class and number of classes per supported id.
        supported = {
            CIFAR10: (torchvision.datasets.CIFAR10, 10),
            CIFAR100: (torchvision.datasets.CIFAR100, 100),
        }
        if self.dataset_id not in supported:
            raise ValueError(
                "Unsupported dataset_id: {}".format(self.dataset_id))
        dataset_class, nclasses = supported[self.dataset_id]

        self._start(self.load_dataset.__name__)

        self.train_transformer = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.test_transformer = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.trainset = dataset_class(root=data_root,
                                      train=True,
                                      download=True,
                                      transform=self.train_transformer)

        self.train_loader = torch.utils.data.DataLoader(
            self.trainset,
            batch_size=batch_size_train,
            shuffle=True,
            num_workers=2)

        self.testset = dataset_class(root=data_root,
                                     train=False,
                                     download=True,
                                     transform=self.test_transformer)

        self.test_loader = torch.utils.data.DataLoader(
            self.testset,
            batch_size=batch_size_test,
            shuffle=False,
            num_workers=2)

        self.nclasses = nclasses

        self._end(self.load_dataset.__name__)

    # Create a sequential model
    def create(self):
        """Instantiate ZTorchSimpleModel for 3x32x32 inputs and ``nclasses``
        outputs, and move it to CUDA when available, otherwise to the CPU."""
        self._start(self.create.__name__)
        self.image_size = (3, 32, 32)

        print("classes {}".format(self.nclasses))
        self.model = ZTorchSimpleModel(self.image_size, self.nclasses,
                                       self.model_filename)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.model.to(device)

        self._end(self.create.__name__)

    def train(self):
        """Fit the model with SGD (lr=0.01, momentum=0.9, weight_decay=5e-4)
        and cross-entropy loss, then log the elapsed time and a summary."""
        self._start(self.train.__name__)
        start = time.time()
        criterion = nn.CrossEntropyLoss()

        optimizer = optim.SGD(self.model.parameters(),
                              lr=0.01,
                              momentum=0.9,
                              weight_decay=5e-4)

        self.model.fit(self.train_loader, self.test_loader, self.callbacks,
                       self.epochs, criterion, optimizer)

        elapsed_time = time.time() - start
        self.write("Train elapsed_time:{0}[sec]".format(elapsed_time))
        self.model.summary()
        self._end(self.train.__name__)

    def predict(self, input):
        """Return ``self.model.predict(input)``.

        ``input`` is passed through unchanged; any transform/batching must be
        done by the caller.
        """
        return self.model.predict(input)

    def save(self):
        """Save the model via ``self.model.save()``."""
        self._start(self.save.__name__)
        self.model.save()
        self._end(self.save.__name__)

    def load(self):
        """Load weights via ``self.model.load_model()``; errors are written to
        the log instead of being raised."""
        self._start(self.load.__name__)

        try:
            self.model.load_model()
        except Exception:
            self.write(formatted_traceback())

        self._end(self.load.__name__)

    def get_model(self):
        """Return the wrapped model (None until create() has run)."""
        return self.model

    def is_trained(self):
        """Return True if ``model_filename`` exists as a regular file."""
        found = os.path.isfile(self.model_filename)
        if found:
            self.write("Found model_filename:'{}'".format(self.model_filename))
        return found

    def evaluate(self):
        """Placeholder for test-set evaluation; currently does nothing.

        TODO: implement evaluation (e.g. loss/accuracy over self.test_loader).
        """
        self._start(self.evaluate.__name__)
        self._end(self.evaluate.__name__)