Exemplo n.º 1
0
    def __init__(self, ARCH, DATA, datadir, logdir, path=None):
        """Prepare data, model, loss, optimizer and LR schedule for training.

        Args:
            ARCH: nested configuration mapping; this method reads its
                "dataset", "train", "post", "backbone", "decoder" and "head"
                entries.
            DATA: nested dataset mapping; this method reads "name", "split",
                "labels", "color_map", "learning_map", "learning_map_inv",
                "content" and "learning_ignore".
            datadir: passed to the dataset parser as ``root``.
            logdir: base log location; the logger receives ``logdir + "/tb"``.
            path: forwarded unchanged to ``Segmentator`` (presumably a
                pretrained-weights location -- TODO confirm).

        Raises:
            Exception: when ARCH["train"] has no "loss" entry or the entry is
                not "xentropy".
        """
        # remember the constructor arguments
        self.ARCH, self.DATA = ARCH, DATA
        self.datadir, self.log, self.path = datadir, logdir, path

        # logger plus the running statistics, all starting at zero
        self.tb_logger = Logger(self.log + "/tb")
        self.info = dict.fromkeys(
            ("train_update", "train_loss", "train_acc", "train_iou",
             "valid_loss", "valid_acc", "valid_iou", "backbone_lr",
             "decoder_lr", "head_lr", "post_lr"), 0)

        # load the dataset-specific parser module from its file and build it
        parser_file = (booger.TRAIN_PATH + '/tasks/semantic/dataset/' +
                       self.DATA["name"] + '/parser.py')
        parser_mod = imp.load_source("parserModule", parser_file)
        split_cfg = self.DATA["split"]
        dataset_cfg = self.ARCH["dataset"]
        loader_cfg = self.ARCH["train"]
        self.parser = parser_mod.Parser(
            root=self.datadir,
            train_sequences=split_cfg["train"],
            valid_sequences=split_cfg["valid"],
            test_sequences=None,
            labels=self.DATA["labels"],
            color_map=self.DATA["color_map"],
            learning_map=self.DATA["learning_map"],
            learning_map_inv=self.DATA["learning_map_inv"],
            sensor=dataset_cfg["sensor"],
            max_points=dataset_cfg["max_points"],
            batch_size=loader_cfg["batch_size"],
            workers=loader_cfg["workers"],
            gt=True,
            shuffle_train=True)

        # class weights for the loss: inverse of the accumulated class content
        eps = self.ARCH["train"]["epsilon_w"]
        class_content = torch.zeros(self.parser.get_n_classes(),
                                    dtype=torch.float)
        for label, share in DATA["content"].items():
            # several dataset labels may map onto one training class
            class_content[self.parser.to_xentropy(label)] += share
        self.loss_w = 1 / (class_content + eps)
        for idx in range(len(self.loss_w)):
            # classes flagged in "learning_ignore" get zero weight
            if DATA["learning_ignore"][idx]:
                self.loss_w[idx] = 0
        print("Loss weights from content: ", self.loss_w.data)

        # build the network
        with torch.no_grad():
            self.model = Segmentator(self.ARCH, self.parser.get_n_classes(),
                                     self.path)
            print(self.model)

        # device selection: CPU by default, one GPU, or all GPUs
        self.gpu, self.multi_gpu, self.n_gpus = False, False, 0
        self.model_single = self.model
        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_cuda else "cpu")
        print("Training in device: ", self.device)
        visible_gpus = torch.cuda.device_count() if has_cuda else 0
        if visible_gpus > 0:
            cudnn.benchmark = True
            cudnn.fastest = True
            self.gpu = True
            self.n_gpus = 1
            self.model.cuda()
        if visible_gpus > 1:
            print("Let's use", visible_gpus, "GPUs!")
            self.model = nn.DataParallel(self.model)  # spread in gpus
            self.model = convert_model(self.model).cuda()  # sync batchnorm
            # unwrapped model, used below to reach the sub-modules by name
            self.model_single = self.model.module
            self.multi_gpu = True
            self.n_gpus = visible_gpus

        # loss: only "xentropy" is accepted
        train_cfg = self.ARCH["train"]
        if "loss" not in train_cfg or train_cfg["loss"] != "xentropy":
            raise Exception('Loss not defined in config file')
        self.criterion = nn.NLLLoss(weight=self.loss_w).to(self.device)
        if self.n_gpus > 1:
            # loss as dataparallel too (more images in batch)
            self.criterion = nn.DataParallel(self.criterion).cuda()

        # optimizer parameter groups: one per sub-module enabled for training
        self.lr_group_names = []
        self.train_dicts = []
        crf_cfg = self.ARCH["post"]["CRF"]
        if crf_cfg["use"] and crf_cfg["train"]:
            self.lr_group_names.append("post_lr")
            self.train_dicts.append(
                {'params': self.model_single.CRF.parameters()})
        for section, group_name in (("backbone", "backbone_lr"),
                                    ("decoder", "decoder_lr"),
                                    ("head", "head_lr")):
            if self.ARCH[section]["train"]:
                self.lr_group_names.append(group_name)
                self.train_dicts.append(
                    {'params': getattr(self.model_single,
                                       section).parameters()})

        # SGD over the selected groups
        self.optimizer = optim.SGD(self.train_dicts,
                                   lr=train_cfg["lr"],
                                   momentum=train_cfg["momentum"],
                                   weight_decay=train_cfg["w_decay"])

        # warm-up schedule: config values are per epoch, the scheduler wants
        # them per step
        steps_per_epoch = self.parser.get_train_size()
        warmup_steps = int(train_cfg["wup_epochs"] * steps_per_epoch)
        per_step_decay = train_cfg["lr_decay"]**(1 / steps_per_epoch)
        self.scheduler = warmupLR(optimizer=self.optimizer,
                                  lr=train_cfg["lr"],
                                  warmup_steps=warmup_steps,
                                  momentum=train_cfg["momentum"],
                                  decay=per_step_decay)
Exemplo n.º 2
0
    def __init__(self,
                 ARCH,
                 DATA,
                 datadir,
                 logdir,
                 path=None,
                 model_mode='salsanext'):
        """Build parser, SalsaNet model, losses, optimizer and LR schedule.

        When ``path`` is given, model, optimizer, scheduler, epoch counter and
        ``info`` are restored from the checkpoint ``path + "/SalsaNet"``.

        Args:
            ARCH: nested configuration mapping; the "dataset" and "train"
                entries are read here.
            DATA: nested dataset mapping; "name", "split", "labels",
                "color_map", "learning_map", "learning_map_inv", "content"
                and "learning_ignore" are read here.
            datadir: passed to the dataset parser as ``root``.
            logdir: base log location; the logger receives ``logdir + "/tb"``.
            path: optional checkpoint directory; also forwarded to
                ``SalsaNet``.
            model_mode: stored on the instance; not otherwise used in this
                method.
        """
        # remember the constructor arguments
        self.ARCH, self.DATA = ARCH, DATA
        self.datadir, self.log, self.path = datadir, logdir, path
        self.model_mode = model_mode

        # timing meters and epoch counter
        self.batch_time_t = AverageMeter()
        self.data_time_t = AverageMeter()
        self.batch_time_e = AverageMeter()
        self.epoch = 0

        # running statistics, all starting at zero
        self.info = {
            stat: 0
            for stat in ("train_update", "train_loss", "train_acc",
                         "train_iou", "valid_loss", "valid_acc", "valid_iou",
                         "best_train_iou", "best_val_iou")
        }

        # load the dataset-specific parser module from its file and build it
        parser_file = (booger.TRAIN_PATH + '/tasks/semantic/dataset/' +
                       self.DATA["name"] + '/parser.py')
        parser_mod = imp.load_source("parserModule", parser_file)
        split_cfg = self.DATA["split"]
        dataset_cfg = self.ARCH["dataset"]
        loader_cfg = self.ARCH["train"]
        self.parser = parser_mod.Parser(
            root=self.datadir,
            train_sequences=split_cfg["train"],
            valid_sequences=split_cfg["valid"],
            test_sequences=None,
            labels=self.DATA["labels"],
            color_map=self.DATA["color_map"],
            learning_map=self.DATA["learning_map"],
            learning_map_inv=self.DATA["learning_map_inv"],
            sensor=dataset_cfg["sensor"],
            max_points=dataset_cfg["max_points"],
            batch_size=loader_cfg["batch_size"],
            workers=loader_cfg["workers"],
            gt=True,
            shuffle_train=True)

        # class weights for the loss: inverse of the accumulated class content
        eps = self.ARCH["train"]["epsilon_w"]
        class_content = torch.zeros(self.parser.get_n_classes(),
                                    dtype=torch.float)
        for label, share in DATA["content"].items():
            # several dataset labels may map onto one training class
            class_content[self.parser.to_xentropy(label)] += share
        self.loss_w = 1 / (class_content + eps)
        for idx in range(len(self.loss_w)):
            # classes flagged in "learning_ignore" get zero weight
            if DATA["learning_ignore"][idx]:
                self.loss_w[idx] = 0
        print("Loss weights from content: ", self.loss_w.data)

        # build the network
        with torch.no_grad():
            self.model = SalsaNet(self.ARCH, self.parser.get_n_classes(),
                                  self.path)

        # the logger is created after the model because it receives it
        self.tb_logger = Logger(self.log + "/tb", self.model)

        # device selection: CPU by default, one GPU, or all GPUs
        self.gpu, self.multi_gpu, self.n_gpus = False, False, 0
        self.model_single = self.model

        # list every trainable parameter tensor and the overall count
        trainable_total = 0
        for pname, tensor in self.model.named_parameters():
            if not tensor.requires_grad:
                continue
            trainable_total += tensor.numel()
            print("{}: {:,}".format(pname, tensor.numel()))
        print("Total of Trainable Parameters: {:,}".format(trainable_total))

        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_cuda else "cpu")
        print("Training in device: ", self.device)
        visible_gpus = torch.cuda.device_count() if has_cuda else 0
        if visible_gpus > 0:
            cudnn.benchmark = True
            cudnn.fastest = True
            self.gpu = True
            self.n_gpus = 1
            self.model.cuda()
        if visible_gpus > 1:
            print("Let's use", visible_gpus, "GPUs!")
            self.model = nn.DataParallel(self.model)  # spread in gpus
            self.model = convert_model(self.model).cuda()  # sync batchnorm
            # unwrapped model kept for direct access
            self.model_single = self.model.module
            self.multi_gpu = True
            self.n_gpus = visible_gpus

        # losses: weighted NLL plus Lovasz-softmax
        self.criterion = nn.NLLLoss(weight=self.loss_w).to(self.device)
        self.ls = Lovasz_softmax(ignore=0).to(self.device)
        if self.n_gpus > 1:
            # losses as dataparallel too (more images in batch)
            self.criterion = nn.DataParallel(self.criterion).cuda()
            self.ls = nn.DataParallel(self.ls).cuda()

        # single SGD parameter group covering the whole model
        train_cfg = self.ARCH["train"]
        self.optimizer = optim.SGD([{'params': self.model.parameters()}],
                                   lr=train_cfg["lr"],
                                   momentum=train_cfg["momentum"],
                                   weight_decay=train_cfg["w_decay"])

        # warm-up schedule: config values are per epoch, the scheduler wants
        # them per step
        steps_per_epoch = self.parser.get_train_size()
        warmup_steps = int(train_cfg["wup_epochs"] * steps_per_epoch)
        per_step_decay = train_cfg["lr_decay"]**(1 / steps_per_epoch)
        self.scheduler = warmupLR(optimizer=self.optimizer,
                                  lr=train_cfg["lr"],
                                  warmup_steps=warmup_steps,
                                  momentum=train_cfg["momentum"],
                                  decay=per_step_decay)

        # nothing to restore without a checkpoint location
        if self.path is None:
            return

        # resume: restore weights, optimizer, scheduler, epoch and stats
        torch.nn.Module.dump_patches = True
        checkpoint = torch.load(path + "/SalsaNet",
                                map_location=lambda storage, loc: storage)
        self.model.load_state_dict(checkpoint['state_dict'], strict=True)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.epoch = checkpoint['epoch'] + 1  # continue with the next epoch
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        print("dict epoch:", checkpoint['epoch'])
        self.info = checkpoint['info']
        print("info", checkpoint['info'])
Exemplo n.º 3
0
    def __init__(self,
                 ARCH,
                 DATA,
                 datadir,
                 logdir,
                 path=None,
                 uncertainty=False):
        """Build parser, SalsaNext generator, discriminator and training state.

        Creates the dataset parser, the class weights for the NLL loss, the
        ``SalsaNext`` model and a ``Discriminator``, their losses, one SGD
        optimizer and one warm-up scheduler for each network, and optionally
        restores both networks from checkpoints under ``path``.

        Args:
            ARCH: nested configuration mapping; the "dataset" and "train"
                entries are read here.
            DATA: nested dataset mapping; "name", "split", "labels",
                "color_map", "learning_map", "learning_map_inv", "content"
                and "learning_ignore" are read here.
            datadir: passed to the dataset parser as ``root``.
            logdir: base log location; the logger receives ``logdir + "/tb"``.
            path: optional checkpoint directory; when not None the files
                ``path + "/SalsaNext"`` and ``path + "/SalsaNext_D"`` are
                loaded.
            uncertainty: stored on the instance; not otherwise used in this
                method.
        """
        # parameters
        self.ARCH = ARCH
        self.DATA = DATA
        self.datadir = datadir
        self.log = logdir
        self.path = path
        self.uncertainty = uncertainty

        # timing meters and epoch counter
        self.batch_time_t = AverageMeter()
        self.data_time_t = AverageMeter()
        self.batch_time_e = AverageMeter()
        self.epoch = 0

        # running statistics, all starting at zero
        self.info = {
            "train_update": 0,
            "train_loss": 0,
            "train_acc": 0,
            "train_iou": 0,
            "valid_loss": 0,
            "valid_acc": 0,
            "valid_iou": 0,
            "best_train_iou": 0,
            "best_val_iou": 0
        }

        # get the data
        parserModule = imp.load_source(
            "parserModule", booger.TRAIN_PATH + '/tasks/semantic/dataset/' +
            self.DATA["name"] + '/parser.py')
        self.parser = parserModule.Parser(
            root=self.datadir,
            train_sequences=self.DATA["split"]["train"],
            valid_sequences=self.DATA["split"]["valid"],
            test_sequences=None,
            labels=self.DATA["labels"],
            color_map=self.DATA["color_map"],
            learning_map=self.DATA["learning_map"],
            learning_map_inv=self.DATA["learning_map_inv"],
            sensor=self.ARCH["dataset"]["sensor"],
            max_points=self.ARCH["dataset"]["max_points"],
            batch_size=self.ARCH["train"]["batch_size"],
            workers=self.ARCH["train"]["workers"],
            gt=True,
            # to see consecutive frames when show_scan=True, change this
            # to False
            shuffle_train=True)

        # weights for loss: inverse of the accumulated class content
        epsilon_w = self.ARCH["train"]["epsilon_w"]
        content = torch.zeros(self.parser.get_n_classes(), dtype=torch.float)
        for cl, freq in DATA["content"].items():
            x_cl = self.parser.to_xentropy(
                cl)  # map actual class to xentropy class
            content[x_cl] += freq
        self.loss_w = 1 / (content + epsilon_w)  # get weights
        for x_cl in range(len(self.loss_w)):
            # classes flagged in "learning_ignore" get zero weight
            if DATA["learning_ignore"][x_cl]:
                self.loss_w[x_cl] = 0
        print("Loss weights from content: ", self.loss_w.data)

        # generator and discriminator networks
        with torch.no_grad():
            self.model = SalsaNext(self.parser.get_n_classes())
            self.discriminator = Discriminator()
        self.tb_logger = Logger(self.log + "/tb")

        # GPU? (this variant only ever uses a single GPU)
        self.gpu = False
        self.multi_gpu = False
        self.n_gpus = 0
        self.model_single = self.model
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        print("Training in device: ", self.device)
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            cudnn.benchmark = True
            cudnn.fastest = True
            self.gpu = True
            self.n_gpus = 1
            self.model.cuda()
            self.discriminator.cuda()

        # loss functions
        self.criterion = nn.NLLLoss(weight=self.loss_w).to(self.device)
        self.ls = Lovasz_softmax(ignore=0).to(self.device)
        self.SoftmaxHeteroscedasticLoss = SoftmaxHeteroscedasticLoss().to(
            self.device)
        self.criterion_pixelwise = torch.nn.SmoothL1Loss().to(self.device)
        self.criterion_GAN = torch.nn.BCEWithLogitsLoss().to(self.device)

        # one SGD optimizer per network, sharing the same hyper-parameters
        self.optimizer = optim.SGD([{
            'params': self.model.parameters()
        }],
                                   lr=self.ARCH["train"]["lr"],
                                   momentum=self.ARCH["train"]["momentum"],
                                   weight_decay=self.ARCH["train"]["w_decay"])
        self.optimizer_D = optim.SGD(
            [{
                'params': self.discriminator.parameters()
            }],
            lr=self.ARCH["train"]["lr"],
            momentum=self.ARCH["train"]["momentum"],
            weight_decay=self.ARCH["train"]["w_decay"])

        # Use warmup learning rate
        # post decay and step sizes come in epochs and we want it in steps
        steps_per_epoch = self.parser.get_train_size()
        up_steps = int(self.ARCH["train"]["wup_epochs"] * steps_per_epoch)
        final_decay = self.ARCH["train"]["lr_decay"]**(1 / steps_per_epoch)
        self.scheduler = warmupLR(optimizer=self.optimizer,
                                  lr=self.ARCH["train"]["lr"],
                                  warmup_steps=up_steps,
                                  momentum=self.ARCH["train"]["momentum"],
                                  decay=final_decay)
        # Same scheduler class as above, so the keyword must be `optimizer`
        # as well; the previous `optimizer_D=` keyword does not match the
        # name used by the generator scheduler call.
        self.scheduler_D = warmupLR(optimizer=self.optimizer_D,
                                    lr=self.ARCH["train"]["lr"],
                                    warmup_steps=up_steps,
                                    momentum=self.ARCH["train"]["momentum"],
                                    decay=final_decay)

        if self.path is not None:
            # generator
            torch.nn.Module.dump_patches = True
            w_dict = torch.load(path + "/SalsaNext",
                                map_location=lambda storage, loc: storage)
            self.model.load_state_dict(w_dict['state_dict'], strict=True)
            self.optimizer.load_state_dict(w_dict['optimizer'])
            self.epoch = w_dict['epoch'] + 1
            self.scheduler.load_state_dict(w_dict['scheduler'])
            print("dict epoch:", w_dict['epoch'])
            self.info = w_dict['info']
            print("info", w_dict['info'])
            # discriminator
            # NOTE(review): the epoch and info read from the discriminator
            # checkpoint overwrite the values just taken from the generator
            # checkpoint -- confirm both files are always saved together.
            w_dict_D = torch.load(path + "/SalsaNext_D",
                                  map_location=lambda storage, loc: storage)
            self.discriminator.load_state_dict(w_dict_D['state_dict'],
                                               strict=True)
            self.optimizer_D.load_state_dict(w_dict_D['optimizer_D'])
            self.epoch = w_dict_D['epoch'] + 1
            self.scheduler_D.load_state_dict(w_dict_D['scheduler_D'])
            print("dict epoch:", w_dict_D['epoch'])
            self.info = w_dict_D['info']
            print("info", w_dict_D['info'])
Exemplo n.º 4
0
    def __init__(self, ARCH, DATA, datadir, logdir, path=None):
        """Build the training state, then benchmark the model's forward pass.

        Creates the dataset parser, the class weights for the NLL loss, the
        ``Segmentator`` model, the optimizer parameter groups and the warm-up
        scheduler. Afterwards the model is profiled with ``thop`` on random
        inputs and 20 timed forward passes are printed.

        The benchmark tensors are placed on ``self.device`` so construction
        also works on hosts without CUDA.

        Args:
            ARCH: nested configuration mapping; this method reads its
                "dataset", "train", "post", "backbone", "decoder" and "head"
                entries.
            DATA: nested dataset mapping; this method reads "name", "split",
                "labels", "color_map", "learning_map", "learning_map_inv",
                "content" and "learning_ignore".
            datadir: passed to the dataset parser as ``root``.
            logdir: base log location; the logger receives ``logdir + "/tb"``.
            path: forwarded unchanged to ``Segmentator`` (presumably a
                pretrained-weights location -- TODO confirm).

        Raises:
            Exception: when ARCH["train"] has no "loss" entry or the entry is
                not "xentropy".
        """
        # parameters
        self.ARCH = ARCH
        self.DATA = DATA
        self.datadir = datadir
        self.log = logdir
        self.path = path

        # put logger where it belongs
        self.tb_logger = Logger(self.log + "/tb")
        self.info = {
            "train_update": 0,
            "train_loss": 0,
            "train_acc": 0,
            "train_iou": 0,
            "valid_loss": 0,
            "valid_acc": 0,
            "valid_iou": 0,
            "backbone_lr": 0,
            "decoder_lr": 0,
            "head_lr": 0,
            "post_lr": 0
        }

        # get the data
        parserModule = imp.load_source(
            "parserModule", booger.TRAIN_PATH + '/tasks/semantic/dataset/' +
            self.DATA["name"] + '/parser.py')
        self.parser = parserModule.Parser(
            root=self.datadir,
            train_sequences=self.DATA["split"]["train"],
            valid_sequences=self.DATA["split"]["valid"],
            test_sequences=None,
            labels=self.DATA["labels"],
            color_map=self.DATA["color_map"],
            learning_map=self.DATA["learning_map"],
            learning_map_inv=self.DATA["learning_map_inv"],
            sensor=self.ARCH["dataset"]["sensor"],
            max_points=self.ARCH["dataset"]["max_points"],
            batch_size=self.ARCH["train"]["batch_size"],
            workers=self.ARCH["train"]["workers"],
            gt=True,
            shuffle_train=True)

        # weights for loss: inverse of the accumulated class content
        epsilon_w = self.ARCH["train"]["epsilon_w"]
        content = torch.zeros(self.parser.get_n_classes(), dtype=torch.float)
        for cl, freq in DATA["content"].items():
            x_cl = self.parser.to_xentropy(
                cl)  # map actual class to xentropy class
            content[x_cl] += freq
        self.loss_w = 1 / (content + epsilon_w)  # get weights
        # self.loss_w = np.power(self.loss_w, 0.50)
        for x_cl in range(len(self.loss_w)):
            # classes flagged in "learning_ignore" get zero weight
            if DATA["learning_ignore"][x_cl]:
                self.loss_w[x_cl] = 0
        print("Loss weights from content: ", self.loss_w.data)

        # concatenate the encoder and the head
        with torch.no_grad():
            self.model = Segmentator(self.ARCH, self.parser.get_n_classes(),
                                     self.path)

        # GPU?
        self.gpu = False
        self.multi_gpu = False
        self.n_gpus = 0
        self.model_single = self.model
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        print("Training in device: ", self.device)
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            cudnn.benchmark = True
            cudnn.fastest = True
            self.gpu = True
            self.n_gpus = 1
            self.model.cuda()

        # The multi-GPU path is commented out in this variant, so n_gpus
        # never exceeds 1 here.
        # if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        #   print("Let's use", torch.cuda.device_count(), "GPUs!")
        #   self.model = nn.DataParallel(self.model)   # spread in gpus
        #   self.model = convert_model(self.model).cuda()  # sync batchnorm
        #   self.model_single = self.model.module  # single model to get weight names
        #   self.multi_gpu = True
        #   self.n_gpus = torch.cuda.device_count()

        # loss
        if "loss" in self.ARCH["train"].keys(
        ) and self.ARCH["train"]["loss"] == "xentropy":
            self.criterion = nn.NLLLoss(weight=self.loss_w).to(self.device)
        else:
            raise Exception('Loss not defined in config file')
        # loss as dataparallel too (more images in batch); not reached while
        # the multi-GPU path above stays disabled
        if self.n_gpus > 1:
            self.criterion = nn.DataParallel(
                self.criterion).cuda()  # spread in gpus

        # optimizer parameter groups: one per sub-module enabled for training
        if self.ARCH["post"]["CRF"]["use"] and self.ARCH["post"]["CRF"][
                "train"]:
            self.lr_group_names = ["post_lr"]
            self.train_dicts = [{'params': self.model_single.CRF.parameters()}]
        else:
            self.lr_group_names = []
            self.train_dicts = []
        if self.ARCH["backbone"]["train"]:
            self.lr_group_names.append("backbone_lr")
            self.train_dicts.append(
                {'params': self.model_single.backbone.parameters()})
        if self.ARCH["decoder"]["train"]:
            self.lr_group_names.append("decoder_lr")
            self.train_dicts.append(
                {'params': self.model_single.decoder.parameters()})
        if self.ARCH["head"]["train"]:
            self.lr_group_names.append("head_lr")
            self.train_dicts.append(
                {'params': self.model_single.head.parameters()})

        # Use SGD optimizer to train
        self.optimizer = optim.SGD(self.train_dicts,
                                   lr=self.ARCH["train"]["lr"],
                                   momentum=self.ARCH["train"]["momentum"],
                                   weight_decay=self.ARCH["train"]["w_decay"])

        # Use warmup learning rate
        # post decay and step sizes come in epochs and we want it in steps
        steps_per_epoch = self.parser.get_train_size()
        up_steps = int(self.ARCH["train"]["wup_epochs"] * steps_per_epoch)
        final_decay = self.ARCH["train"]["lr_decay"]**(1 / steps_per_epoch)
        self.scheduler = warmupLR(optimizer=self.optimizer,
                                  lr=self.ARCH["train"]["lr"],
                                  warmup_steps=up_steps,
                                  momentum=self.ARCH["train"]["momentum"],
                                  decay=final_decay)

        # ---- forward-pass benchmark on random inputs ----
        from thop import profile

        # Tensors go to self.device rather than .cuda(), so a CPU-only host
        # no longer crashes here.
        inputs = torch.randn(1, 10, 64, 2048).to(self.device)

        # dummy multi-resolution inputs, four tensors per representation
        representations = {'image': [], 'points': []}
        for n_points in (131072, 32768, 8192, 2048):
            representations['points'].append(
                torch.randn(1, 10, 9, n_points).to(self.device))
        for height, width in ((64, 2048), (32, 1024), (16, 512), (8, 256)):
            representations['image'].append(
                torch.randn(1, 10, height, width).to(self.device))

        # only the first value returned by thop is reported below
        flops, _ = profile(self.model,
                           inputs=([inputs, representations], ),
                           verbose=False)
        time_train = []

        # two untimed warm-up passes; the outputs are discarded, so no
        # autograd graph is needed
        with torch.no_grad():
            self.model([inputs, representations])
            self.model([inputs, representations])

        for _ in range(20):
            inputs = torch.randn(1, 10, 64, 2048).to(self.device)

            start_time = time.time()
            with torch.no_grad():
                self.model([inputs, representations])

            if self.gpu:
                # wait for cuda to finish (cuda is asynchronous!)
                torch.cuda.synchronize()
            fwt = time.time() - start_time
            time_train.append(fwt)
            print("Forward time per img: %.3f (Mean: %.3f)" %
                  (fwt, sum(time_train) / len(time_train)))
            print("Total number of flops (G): ", flops / 1000000000.)
            time.sleep(0.3)
Exemplo n.º 5
0
    def __init__(self, ARCH, DATA, datadir, logdir, path=None):
        """Prepare data, MaskRegressor model, loss, optimizer and LR schedule.

        Args:
            ARCH: nested configuration mapping; this method reads its
                "dataset", "train", "backbone", "decoder" and "head" entries.
            DATA: nested dataset mapping; this method reads "name", "split",
                "labels", "color_map", "learning_map" and "learning_map_inv".
            datadir: passed to the dataset parser as ``root``.
            logdir: base log location; the logger receives ``logdir + "/tb"``.
            path: forwarded unchanged to ``MaskRegressor`` (presumably a
                pretrained-weights location -- TODO confirm).

        Raises:
            Exception: when ARCH["train"] has no "loss" entry or the entry is
                not "xentropy".
        """
        # remember the constructor arguments
        self.ARCH, self.DATA = ARCH, DATA
        self.datadir, self.log, self.path = datadir, logdir, path

        # logger plus the running statistics, all starting at zero
        self.tb_logger = Logger(self.log + "/tb")
        self.info = dict.fromkeys(
            ("train_update", "train_loss", "train_acc", "train_iou",
             "valid_loss", "valid_acc", "valid_iou", "backbone_lr",
             "decoder_lr", "head_lr", "post_lr"), 0)

        # load the dataset-specific parser module from its file and build it
        parser_file = (booger.TRAIN_PATH + '/tasks/mask_regression/dataset/' +
                       self.DATA["name"] + '/parser.py')
        parser_mod = imp.load_source("parserModule", parser_file)
        split_cfg = self.DATA["split"]
        dataset_cfg = self.ARCH["dataset"]
        loader_cfg = self.ARCH["train"]
        self.parser = parser_mod.Parser(
            root=self.datadir,
            train_sequences=split_cfg["train"],
            valid_sequences=split_cfg["valid"],
            test_sequences=None,
            labels=self.DATA["labels"],
            color_map=self.DATA["color_map"],
            learning_map=self.DATA["learning_map"],
            learning_map_inv=self.DATA["learning_map_inv"],
            sensor=dataset_cfg["sensor"],
            max_points=dataset_cfg["max_points"],
            batch_size=loader_cfg["batch_size"],
            workers=loader_cfg["workers"],
            gt=True,
            shuffle_train=True)

        # build the network
        with torch.no_grad():
            self.model = MaskRegressor(self.ARCH, self.path)

        # device selection: CPU by default, one GPU, or all GPUs
        self.gpu, self.multi_gpu, self.n_gpus = False, False, 0
        self.model_single = self.model
        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_cuda else "cpu")
        print("Training in device: ", self.device)
        visible_gpus = torch.cuda.device_count() if has_cuda else 0
        if visible_gpus > 0:
            cudnn.benchmark = True
            cudnn.fastest = True
            self.gpu = True
            self.n_gpus = 1
            self.model.cuda()
        if visible_gpus > 1:
            print("Let's use", visible_gpus, "GPUs!")
            self.model = nn.DataParallel(self.model)  # spread in gpus
            self.model = convert_model(self.model).cuda()  # sync batchnorm
            # unwrapped model, used below to reach the sub-modules by name
            self.model_single = self.model.module
            self.multi_gpu = True
            self.n_gpus = visible_gpus

        # loss: only "xentropy" is accepted in the config; the criterion is
        # BCE-with-logits with a fixed positive-class weight of 14
        train_cfg = self.ARCH["train"]
        if "loss" not in train_cfg or train_cfg["loss"] != "xentropy":
            raise Exception('Loss not defined in config file')
        positive_weight = torch.full((1, ), 14.0, dtype=torch.float)
        self.criterion = nn.BCEWithLogitsLoss(
            pos_weight=positive_weight).to(self.device)
        if self.n_gpus > 1:
            # loss as dataparallel too (more images in batch)
            self.criterion = nn.DataParallel(self.criterion).cuda()

        # optimizer parameter groups: one per sub-module enabled for training
        self.lr_group_names = []
        self.train_dicts = []
        for section, group_name in (("backbone", "backbone_lr"),
                                    ("decoder", "decoder_lr"),
                                    ("head", "head_lr")):
            if self.ARCH[section]["train"]:
                self.lr_group_names.append(group_name)
                self.train_dicts.append(
                    {'params': getattr(self.model_single,
                                       section).parameters()})

        # SGD over the selected groups
        self.optimizer = optim.SGD(self.train_dicts,
                                   lr=train_cfg["lr"],
                                   momentum=train_cfg["momentum"],
                                   weight_decay=train_cfg["w_decay"])

        # warm-up schedule: config values are per epoch, the scheduler wants
        # them per step
        steps_per_epoch = self.parser.get_train_size()
        warmup_steps = int(train_cfg["wup_epochs"] * steps_per_epoch)
        per_step_decay = train_cfg["lr_decay"]**(1 / steps_per_epoch)
        self.scheduler = warmupLR(optimizer=self.optimizer,
                                  lr=train_cfg["lr"],
                                  warmup_steps=warmup_steps,
                                  momentum=train_cfg["momentum"],
                                  decay=per_step_decay)