Code Example #1
def create_ema_model(model):
    ema_model = Res_Deeplab(num_classes=num_classes)

    for param in ema_model.parameters():
        param.detach_()
    mp = list(model.parameters())
    mcp = list(ema_model.parameters())
    n = len(mp)
    for i in range(0, n):
        mcp[i].data[:] = mp[i].data[:].clone()
    if len(gpus) > 1:
        if use_sync_batchnorm:
            ema_model = convert_model(ema_model)
            ema_model = DataParallelWithCallback(ema_model, device_ids=gpus)
        else:
            ema_model = torch.nn.DataParallel(ema_model, device_ids=gpus)
    return ema_model
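
The EMA model created above is only useful together with a matching weight-update step (`update_ema_variables`, which is called in Code Examples #5 and #6 but whose implementation is not shown on this page). A minimal sketch of that step, assuming the standard mean-teacher exponential moving average:

# Hypothetical sketch of update_ema_variables (the repository's actual
# implementation is not shown in these examples).
def update_ema_variables(ema_model, model, alpha_teacher, iteration):
    # Ramp the decay up from 0 so early iterations track the student closely.
    alpha = min(1 - 1 / (iteration + 1), alpha_teacher)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
    return ema_model
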
Code Example #2
def create_ema_model(model):
    #ema_model = getattr(models, config['arch']['type'])(self.train_loader.dataset.num_classes, **config['arch']['args']).to(self.device)
    ema_model = Res_Deeplab(num_classes=num_classes)

    for param in ema_model.parameters():
        param.detach_()
    mp = list(model.parameters())
    mcp = list(ema_model.parameters())
    n = len(mp)
    for i in range(0, n):
        mcp[i].data[:] = mp[i].data[:].clone()
    #_, availble_gpus = self._get_available_devices(self.config['n_gpu'])
    #ema_model = torch.nn.DataParallel(ema_model, device_ids=availble_gpus)
    if len(gpus) > 1:
        #return torch.nn.DataParallel(ema_model, device_ids=gpus)
        if use_sync_batchnorm:
            ema_model = convert_model(ema_model)
            ema_model = DataParallelWithCallback(ema_model, device_ids=gpus)
        else:
            ema_model = torch.nn.DataParallel(ema_model, device_ids=gpus)
    return ema_model
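
Note that both variants call `param.detach_()`: the teacher is never trained by backpropagation, only overwritten by the EMA update, so its parameters are removed from the autograd graph. A usage sketch, assuming the same module-level globals (`num_classes`, `gpus`, `use_sync_batchnorm`) these snippets rely on:

# Usage sketch (assumes the module-level configuration globals above).
model = Res_Deeplab(num_classes=num_classes)
ema_model = create_ema_model(model)  # teacher starts as a copy of the student
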
Code Example #3
    def __init__(self,
                 model,
                 loss,
                 resume,
                 config,
                 train_loader,
                 val_loader=None,
                 train_logger=None):
        self.model = model
        self.loss = loss
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_logger = train_logger
        self.logger = logging.getLogger(self.__class__.__name__)
        self.do_validation = self.config['trainer']['val']
        self.start_epoch = 1
        self.improved = False

        # SETTING THE DEVICE
        self.device, availble_gpus = self._get_available_devices(
            self.config['n_gpu'])
        self.model.loss = loss
        if config["use_synch_bn"]:
            self.model = convert_model(self.model)
            self.model = DataParallelWithCallback(self.model,
                                                  device_ids=availble_gpus)
        else:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=availble_gpus)
        self.model.cuda()

        # CONFIGS
        cfg_trainer = self.config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']

        # OPTIMIZER
        if self.config['optimizer']['differential_lr']:
            if isinstance(self.model, torch.nn.DataParallel):
                trainable_params = [{
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.module.get_decoder_params())
                }, {
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.module.get_backbone_params()),
                    'lr':
                    config['optimizer']['args']['lr'] / 10
                }]
            else:
                trainable_params = [{
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.get_decoder_params())
                }, {
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.get_backbone_params()),
                    'lr':
                    config['optimizer']['args']['lr'] / 10
                }]
        else:
            trainable_params = filter(lambda p: p.requires_grad,
                                      self.model.parameters())
        self.optimizer = get_instance(torch.optim, 'optimizer', config,
                                      trainable_params)
        self.lr_scheduler = getattr(utils.lr_scheduler,
                                    config['lr_scheduler']['type'])(
                                        self.optimizer, self.epochs,
                                        len(train_loader))
        #self.lr_scheduler = getattr(torch.optim.lr_scheduler, config['lr_scheduler']['type'])(self.optimizer, **config['lr_scheduler']['args'])

        # MONITORING
        self.monitor = cfg_trainer.get('monitor', 'off')
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = -math.inf if self.mnt_mode == 'max' else math.inf
            self.early_stoping = cfg_trainer.get('early_stop', math.inf)

        # CHECKPOINTS & TENSORBOARD
        start_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
        self.checkpoint_dir = os.path.join(cfg_trainer['save_dir'],
                                           self.config['name'], start_time)
        helpers.dir_exists(self.checkpoint_dir)
        config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
        with open(config_save_path, 'w') as handle:
            json.dump(self.config, handle, indent=4, sort_keys=True)

        writer_dir = os.path.join(cfg_trainer['log_dir'], self.config['name'],
                                  start_time)
        self.writer = tensorboard.SummaryWriter(writer_dir)

        if resume: self._resume_checkpoint(resume)
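
The optimizer above is built through a `get_instance` helper that reads a class name and keyword arguments from the JSON config. The helper itself is not shown on this page; a minimal sketch consistent with how it is called here (`get_instance(torch.optim, 'optimizer', config, trainable_params)`), assuming the usual pytorch-template convention:

# Hypothetical sketch of get_instance: look up config[name]['type'] in
# `module` and instantiate it with positional args plus config[name]['args'].
def get_instance(module, name, config, *args):
    return getattr(module, config[name]['type'])(*args, **config[name]['args'])

With a config such as {"optimizer": {"type": "SGD", "args": {"lr": 0.01}}}, this resolves to torch.optim.SGD(trainable_params, lr=0.01).
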
Code Example #4
class BaseTrainer:
    def __init__(self,
                 model,
                 loss,
                 resume,
                 config,
                 train_loader,
                 val_loader=None,
                 train_logger=None):
        self.model = model
        self.loss = loss
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_logger = train_logger
        self.logger = logging.getLogger(self.__class__.__name__)
        self.do_validation = self.config['trainer']['val']
        self.start_epoch = 1
        self.improved = False
        self.not_improved_count = 0  # early-stopping counter used in train()

        # SETTING THE DEVICE
        self.device, availble_gpus = self._get_available_devices(
            self.config['n_gpu'])
        self.model.loss = loss
        if config["use_synch_bn"]:
            self.model = convert_model(self.model)
            self.model = DataParallelWithCallback(self.model,
                                                  device_ids=availble_gpus)
        else:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=availble_gpus)
        self.model.cuda()

        # CONFIGS
        cfg_trainer = self.config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']

        # OPTIMIZER
        if self.config['optimizer']['differential_lr']:
            if isinstance(self.model, torch.nn.DataParallel):
                trainable_params = [{
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.module.get_decoder_params())
                }, {
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.module.get_backbone_params()),
                    'lr':
                    config['optimizer']['args']['lr'] / 10
                }]
            else:
                trainable_params = [{
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.get_decoder_params())
                }, {
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.model.get_backbone_params()),
                    'lr':
                    config['optimizer']['args']['lr'] / 10
                }]
        else:
            trainable_params = filter(lambda p: p.requires_grad,
                                      self.model.parameters())
        self.optimizer = get_instance(torch.optim, 'optimizer', config,
                                      trainable_params)
        self.lr_scheduler = getattr(utils.lr_scheduler,
                                    config['lr_scheduler']['type'])(
                                        self.optimizer, self.epochs,
                                        len(train_loader))
        #self.lr_scheduler = getattr(torch.optim.lr_scheduler, config['lr_scheduler']['type'])(self.optimizer, **config['lr_scheduler']['args'])

        # MONITORING
        self.monitor = cfg_trainer.get('monitor', 'off')
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = -math.inf if self.mnt_mode == 'max' else math.inf
            self.early_stoping = cfg_trainer.get('early_stop', math.inf)

        # CHECKPOINTS & TENSORBOARD
        start_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
        self.checkpoint_dir = os.path.join(cfg_trainer['save_dir'],
                                           self.config['name'], start_time)
        helpers.dir_exists(self.checkpoint_dir)
        config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
        with open(config_save_path, 'w') as handle:
            json.dump(self.config, handle, indent=4, sort_keys=True)

        writer_dir = os.path.join(cfg_trainer['log_dir'], self.config['name'],
                                  start_time)
        self.writer = tensorboard.SummaryWriter(writer_dir)

        if resume: self._resume_checkpoint(resume)

    def _get_available_devices(self, n_gpu):
        sys_gpu = torch.cuda.device_count()
        if sys_gpu == 0:
            self.logger.warning('No GPUs detected, using the CPU')
            n_gpu = 0
        elif n_gpu > sys_gpu:
            self.logger.warning(
                f'Number of GPUs requested is {n_gpu} but only {sys_gpu} are available'
            )
            n_gpu = sys_gpu

        device = torch.device('cuda' if n_gpu > 0 else 'cpu')
        self.logger.info(f'Detected GPUs: {sys_gpu} Requested: {n_gpu}')
        available_gpus = list(range(n_gpu))
        return device, available_gpus

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            # RUN TRAIN (AND VAL)
            results = self._train_epoch(epoch)
            self.lr_scheduler.step()
            if self.do_validation and epoch % self.config['trainer'][
                    'val_per_epochs'] == 0:
                results = self._valid_epoch(epoch)

                # LOGGING INFO
                self.logger.info(f'\n         ## Info for epoch {epoch} ## ')
                for k, v in results.items():
                    self.logger.info(f'         {str(k):15s}: {v}')

            if self.train_logger is not None:
                log = {'epoch': epoch, **results}
                self.train_logger.add_entry(log)

            # CHECKING IF THIS IS THE BEST MODEL (ONLY FOR VAL)
            if self.mnt_mode != 'off' and epoch % self.config['trainer'][
                    'val_per_epochs'] == 0:
                try:
                    if self.mnt_mode == 'min':
                        self.improved = (log[self.mnt_metric] < self.mnt_best)
                    else:
                        self.improved = (log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        f'The metric being tracked ({self.mnt_metric}) has not been calculated. Training stops.'
                    )
                    break

                if self.improved:
                    self.mnt_best = log[self.mnt_metric]
                    self.not_improved_count = 0
                else:
                    self.not_improved_count += 1

                if self.not_improved_count > self.early_stoping:
                    self.logger.info(
                        f'\nPerformance didn\'t improve for {self.early_stoping} epochs'
                    )
                    self.logger.warning('Training stopped')
                    break

            # SAVE CHECKPOINT
            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=self.improved)

    def _save_checkpoint(self, epoch, save_best=False):
        state = {
            'arch': type(self.model).__name__,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        filename = os.path.join(self.checkpoint_dir,
                                f'checkpoint-epoch{epoch}.pth')
        self.logger.info(f'\nSaving a checkpoint: {filename} ...')
        torch.save(state, filename)

        if save_best:
            filename = os.path.join(self.checkpoint_dir, 'best_model.pth')
            torch.save(state, filename)
            self.logger.info("Saving current best: best_model.pth")

    def _resume_checkpoint(self, resume_path):
        self.logger.info(f'Loading checkpoint : {resume_path}')
        checkpoint = torch.load(resume_path)

        # Load last run info, the model params, the optimizer and the loggers
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        self.not_improved_count = 0

        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning(
                'Warning! Current model is not the same as the one in the checkpoint'
            )
        self.model.load_state_dict(checkpoint['state_dict'], strict=False)

        if checkpoint['config']['optimizer']['type'] != self.config[
                'optimizer']['type']:
            self.logger.warning(
                'Warning! Current optimizer is not the same as the one in the checkpoint'
            )
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        # if self.lr_scheduler:
        #     self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

        #self.train_logger = checkpoint['logger']
        #self.logger.info(f'Checkpoint <{resume_path}> (epoch {self.start_epoch}) was loaded')

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _valid_epoch(self, epoch):
        raise NotImplementedError

    def _eval_metrics(self, output, target):
        raise NotImplementedError
Code Example #5
def main():
    torch.cuda.empty_cache()
    print(config)

    best_mIoU = 0

    if consistency_loss == 'CE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(
                CrossEntropyLoss2dPixelWiseWeighted(ignore_index=ignore_label),
                device_ids=gpus).cuda()
        else:
            unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
                ignore_index=ignore_label).cuda()
    elif consistency_loss == 'MSE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(MSELoss2d(),
                                                   device_ids=gpus).cuda()
        else:
            unlabeled_loss = MSELoss2d().cuda()

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # load pretrained parameters
    if restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        saved_state_dict = torch.load(restore_from)

    # Copy loaded parameters to model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Initiate ema-model
    if train_unlabeled:
        ema_model = create_ema_model(model)
        ema_model.train()
        ema_model = ema_model.cuda()
    else:
        ema_model = None

    if len(gpus) > 1:
        if use_sync_batchnorm:
            model = convert_model(model)
            model = DataParallelWithCallback(model, device_ids=gpus)
        else:
            model = torch.nn.DataParallel(model, device_ids=gpus)
    model.train()
    model.cuda()

    cudnn.benchmark = True

    if dataset == 'pascal_voc':
        data_loader = get_loader(dataset)
        data_path = get_data_path(dataset)
        train_dataset = data_loader(data_path,
                                    crop_size=input_size,
                                    scale=random_scale,
                                    mirror=random_flip)

    elif dataset == 'cityscapes':
        data_loader = get_loader('cityscapes')
        data_path = get_data_path('cityscapes')
        if random_crop:
            data_aug = Compose([RandomCrop_city(input_size)])
        else:
            data_aug = None

        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    augmentations=data_aug,
                                    img_size=input_size)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    partial_size = labeled_samples
    print('Training on number of samples:', partial_size)
    if split_id is not None:
        train_ids = pickle.load(open(split_id, 'rb'))
        print('loading train ids from {}'.format(split_id))
    else:
        np.random.seed(random_seed)
        train_ids = np.arange(train_dataset_size)
        np.random.shuffle(train_ids)

    train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  sampler=train_sampler,
                                  num_workers=num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)

    if train_unlabeled:
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=1,
                                             pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)

    # Optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(
        config['training']['learning_rate'])

    if optimizer_type == 'SGD':
        if len(gpus) > 1:
            optimizer = optim.SGD(
                model.module.optim_parameters(learning_rate_object),
                lr=learning_rate,
                momentum=momentum,
                weight_decay=weight_decay)
        else:
            optimizer = optim.SGD(
                model.optim_parameters(
                    learning_rate_object),  ## DOES THIS CAUSE THE USERWARNING?
                lr=learning_rate,
                momentum=momentum,
                weight_decay=weight_decay)

    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)

    start_iteration = 0

    if args.resume:
        start_iteration, model, optimizer, ema_model = _resume_checkpoint(
            args.resume, model, optimizer, ema_model)

    accumulated_loss_l = []
    if train_unlabeled:
        accumulated_loss_u = []

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    with open(checkpoint_dir + '/config.json', 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=False)
    pickle.dump(train_ids,
                open(os.path.join(checkpoint_dir, 'train_split.pkl'), 'wb'))

    epochs_since_start = 0
    for i_iter in range(start_iteration, num_iterations):
        model.train()

        loss_l_value = 0
        if train_unlabeled:
            loss_u_value = 0

        optimizer.zero_grad()

        if lr_schedule:
            adjust_learning_rate(optimizer, i_iter)

        # Training loss for labeled data only
        try:
            batch = next(trainloader_iter)
            if batch[0].shape[0] != batch_size:
                batch = next(trainloader_iter)
        except StopIteration:
            # Loader exhausted; restart it and count one more epoch.
            epochs_since_start = epochs_since_start + 1
            print('Epochs since start: ', epochs_since_start)
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        weak_parameters = {"flip": 0}

        images, labels, _, _, _ = batch
        images = images.cuda()
        labels = labels.cuda()

        images, labels = weakTransform(weak_parameters,
                                       data=images,
                                       target=labels)
        intermediary_var = model(images)
        pred = interp(intermediary_var)

        L_l = loss_calc(pred, labels)

        if train_unlabeled:
            try:
                batch_remain = next(trainloader_remain_iter)
                if batch_remain[0].shape[0] != batch_size:
                    batch_remain = next(trainloader_remain_iter)
            except StopIteration:
                trainloader_remain_iter = iter(trainloader_remain)
                batch_remain = next(trainloader_remain_iter)

            images_remain, _, _, _, _ = batch_remain
            images_remain = images_remain.cuda()
            inputs_u_w, _ = weakTransform(weak_parameters, data=images_remain)
            logits_u_w = interp(ema_model(inputs_u_w))
            logits_u_w, _ = weakTransform(
                getWeakInverseTransformParameters(weak_parameters),
                data=logits_u_w.detach())

            softmax_u_w = torch.softmax(logits_u_w.detach(), dim=1)
            max_probs, argmax_u_w = torch.max(softmax_u_w, dim=1)

            if mix_mask == "class":

                for image_i in range(batch_size):
                    classes = torch.unique(argmax_u_w[image_i])
                    classes = classes[classes != ignore_label]
                    nclasses = classes.shape[0]
                    classes = (classes[torch.Tensor(
                        np.random.choice(nclasses,
                                         int((nclasses - nclasses % 2) / 2),
                                         replace=False)).long()]).cuda()
                    if image_i == 0:
                        MixMask = transformmasks.generate_class_mask(
                            argmax_u_w[image_i], classes).unsqueeze(0).cuda()
                    else:
                        MixMask = torch.cat(
                            (MixMask,
                             transformmasks.generate_class_mask(
                                 argmax_u_w[image_i],
                                 classes).unsqueeze(0).cuda()))

            elif mix_mask == 'cut':
                img_size = inputs_u_w.shape[2:4]
                for image_i in range(batch_size):
                    if image_i == 0:
                        MixMask = torch.from_numpy(
                            transformmasks.generate_cutout_mask(
                                img_size)).unsqueeze(0).cuda().float()
                    else:
                        MixMask = torch.cat(
                            (MixMask,
                             torch.from_numpy(
                                 transformmasks.generate_cutout_mask(
                                     img_size)).unsqueeze(0).cuda().float()))

            elif mix_mask == "cow":
                img_size = inputs_u_w.shape[2:4]
                sigma_min = 8
                sigma_max = 32
                p_min = 0.5
                p_max = 0.5
                for image_i in range(batch_size):
                    sigma = np.exp(
                        np.random.uniform(np.log(sigma_min),
                                          np.log(sigma_max)))  # Random sigma
                    p = np.random.uniform(p_min, p_max)  # Random p
                    if image_i == 0:
                        MixMask = torch.from_numpy(
                            transformmasks.generate_cow_mask(
                                img_size, sigma, p,
                                seed=None)).unsqueeze(0).cuda().float()
                    else:
                        MixMask = torch.cat(
                            (MixMask,
                             torch.from_numpy(
                                 transformmasks.generate_cow_mask(
                                     img_size, sigma, p,
                                     seed=None)).unsqueeze(0).cuda().float()))

            elif mix_mask is None:
                MixMask = torch.ones((inputs_u_w.shape)).cuda()

            strong_parameters = {"Mix": MixMask}
            if random_flip:
                strong_parameters["flip"] = random.randint(0, 1)
            else:
                strong_parameters["flip"] = 0
            if color_jitter:
                strong_parameters["ColorJitter"] = random.uniform(0, 1)
            else:
                strong_parameters["ColorJitter"] = 0
            if gaussian_blur:
                strong_parameters["GaussianBlur"] = random.uniform(0, 1)
            else:
                strong_parameters["GaussianBlur"] = 0

            inputs_u_s, _ = strongTransform(strong_parameters,
                                            data=images_remain)
            logits_u_s = interp(model(inputs_u_s))

            softmax_u_w_mixed, _ = strongTransform(strong_parameters,
                                                   data=softmax_u_w)
            max_probs, pseudo_label = torch.max(softmax_u_w_mixed, dim=1)

            if pixel_weight == "threshold_uniform":
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(pseudo_label.cpu()))
                pixelWiseWeight = unlabeled_weight * torch.ones(
                    max_probs.shape).cuda()
            elif pixel_weight == "threshold":
                pixelWiseWeight = max_probs.ge(0.968).long().cuda()
            elif pixel_weight == 'sigmoid':
                max_iter = 10000
                pixelWiseWeight = sigmoid_ramp_up(
                    i_iter, max_iter) * torch.ones(max_probs.shape).cuda()
            elif pixel_weight == False:
                pixelWiseWeight = torch.ones(max_probs.shape).cuda()

            if consistency_loss == 'CE':
                L_u = consistency_weight * unlabeled_loss(
                    logits_u_s, pseudo_label, pixelWiseWeight)
            elif consistency_loss == 'MSE':
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(pseudo_label.cpu()))
                #softmax_u_w_mixed = torch.cat((softmax_u_w_mixed[1].unsqueeze(0),softmax_u_w_mixed[0].unsqueeze(0)))
                L_u = consistency_weight * unlabeled_weight * unlabeled_loss(
                    logits_u_s, softmax_u_w_mixed)

            loss = L_l + L_u

        else:
            loss = L_l

        if len(gpus) > 1:
            loss = loss.mean()
            loss_l_value += L_l.mean().item()
            if train_unlabeled:
                loss_u_value += L_u.mean().item()
        else:
            loss_l_value += L_l.item()
            if train_unlabeled:
                loss_u_value += L_u.item()

        loss.backward()
        optimizer.step()

        # update Mean teacher network
        if ema_model is not None:
            alpha_teacher = 0.99
            ema_model = update_ema_variables(ema_model=ema_model,
                                             model=model,
                                             alpha_teacher=alpha_teacher,
                                             iteration=i_iter)

        if train_unlabeled:
            print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}, loss_u = {3:.3f}'.
                  format(i_iter, num_iterations, loss_l_value, loss_u_value))
        else:
            print('iter = {0:6d}/{1:6d}, loss_l = {2:.3f}'.format(
                i_iter, num_iterations, loss_l_value))

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            _save_checkpoint(i_iter, model, optimizer, config, ema_model)

        if use_tensorboard:
            if 'tensorboard_writer' not in locals():
                tensorboard_writer = tensorboard.SummaryWriter(log_dir,
                                                               flush_secs=30)

            accumulated_loss_l.append(loss_l_value)
            if train_unlabeled:
                accumulated_loss_u.append(loss_u_value)
            if i_iter % log_per_iter == 0 and i_iter != 0:

                tensorboard_writer.add_scalar('Training/Supervised loss',
                                              np.mean(accumulated_loss_l),
                                              i_iter)
                accumulated_loss_l = []

                if train_unlabeled:
                    tensorboard_writer.add_scalar('Training/Unsupervised loss',
                                                  np.mean(accumulated_loss_u),
                                                  i_iter)
                    accumulated_loss_u = []

        if i_iter % val_per_iter == 0 and i_iter != 0:
            model.eval()
            mIoU, eval_loss = evaluate(model,
                                       dataset,
                                       ignore_label=ignore_label,
                                       input_size=(512, 1024),
                                       save_dir=checkpoint_dir)

            model.train()

            if mIoU > best_mIoU and save_best_model:
                best_mIoU = mIoU
                _save_checkpoint(i_iter,
                                 model,
                                 optimizer,
                                 config,
                                 ema_model,
                                 save_best=True)

            if use_tensorboard:
                tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
                tensorboard_writer.add_scalar('Validation/Loss', eval_loss,
                                              i_iter)

        if save_unlabeled_images and train_unlabeled and i_iter % save_checkpoint_every == 0:
            # Saves two mixed images and the corresponding prediction
            save_image(inputs_u_s[0].cpu(), i_iter, 'input1',
                       palette.CityScpates_palette)
            save_image(inputs_u_s[1].cpu(), i_iter, 'input2',
                       palette.CityScpates_palette)
            _, pred_u_s = torch.max(logits_u_s, dim=1)
            save_image(pred_u_s[0].cpu(), i_iter, 'pred1',
                       palette.CityScpates_palette)
            save_image(pred_u_s[1].cpu(), i_iter, 'pred2',
                       palette.CityScpates_palette)

    _save_checkpoint(num_iterations, model, optimizer, config, ema_model)

    model.eval()
    mIoU, val_loss = evaluate(model,
                              dataset,
                              ignore_label=ignore_label,
                              input_size=(512, 1024),
                              save_dir=checkpoint_dir)

    model.train()
    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter,
                         model,
                         optimizer,
                         config,
                         ema_model,
                         save_best=True)

    if use_tensorboard:
        tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
        tensorboard_writer.add_scalar('Validation/Loss', val_loss, i_iter)

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')
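
`adjust_learning_rate` is called once per iteration when `lr_schedule` is set, but its implementation is not shown on this page. A sketch assuming the polynomial decay conventional for DeepLab-style training (the 10x rate for a second parameter group is also an assumption):

# Hypothetical sketch of the poly learning-rate schedule behind
# adjust_learning_rate; power=0.9 is the usual DeepLab default.
def lr_poly(base_lr, i_iter, max_iter, power=0.9):
    return base_lr * ((1 - float(i_iter) / max_iter) ** power)

def adjust_learning_rate(optimizer, i_iter):
    lr = lr_poly(learning_rate, i_iter, num_iterations)
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) > 1:
        # Common convention: train the classifier head at 10x the base rate.
        optimizer.param_groups[1]['lr'] = lr * 10
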
Code Example #6
def main():
    print(config)

    best_mIoU = 0

    if consistency_loss == 'MSE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(MSELoss2d(),
                                                   device_ids=gpus).cuda()
        else:
            unlabeled_loss = MSELoss2d().cuda()
    elif consistency_loss == 'CE':
        if len(gpus) > 1:
            unlabeled_loss = torch.nn.DataParallel(
                CrossEntropyLoss2dPixelWiseWeighted(ignore_index=ignore_label),
                device_ids=gpus).cuda()
        else:
            unlabeled_loss = CrossEntropyLoss2dPixelWiseWeighted(
                ignore_index=ignore_label).cuda()

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=num_classes)

    # load pretrained parameters
    #saved_state_dict = torch.load(args.restore_from)
    # load pretrained parameters
    if restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(restore_from)
    else:
        saved_state_dict = torch.load(restore_from)

    # Copy loaded parameters to model
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # init ema-model
    if train_unlabeled:
        ema_model = create_ema_model(model)
        ema_model.train()
        ema_model = ema_model.cuda()
    else:
        ema_model = None

    if len(gpus) > 1:
        if use_sync_batchnorm:
            model = convert_model(model)
            model = DataParallelWithCallback(model, device_ids=gpus)
        else:
            model = torch.nn.DataParallel(model, device_ids=gpus)
    model.train()
    model.cuda()

    cudnn.benchmark = True
    data_loader = get_loader(config['dataset'])
    # data_path = get_data_path(config['dataset'])
    # if random_crop:
    # data_aug = Compose([RandomCrop_city(input_size)])
    # else:
    # data_aug = None
    data_aug = Compose([RandomHorizontallyFlip()])
    if dataset == 'cityscapes':
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    augmentations=data_aug,
                                    img_size=input_size,
                                    img_mean=IMG_MEAN)
    elif dataset == 'multiview':
        # adaption data
        data_path = '/tmp/tcn_data/texture_multibot_push_left10050/videos/train_adaptation'
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    view_idx=0,
                                    number_views=1,
                                    load_seg_mask=False,
                                    augmentations=data_aug,
                                    img_size=input_size,
                                    img_mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    print('dataset size: ', train_dataset_size)

    if labeled_samples is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      pin_memory=True)

        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=num_workers,
                                             pin_memory=True)
        trainloader_remain_iter = iter(trainloader_remain)

    else:
        partial_size = labeled_samples
        print('Training on number of samples:', partial_size)
        np.random.seed(random_seed)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)

        trainloader_remain_iter = iter(trainloader_remain)

    #New loader for Domain transfer
    # if random_crop:
    # data_aug = Compose([RandomCrop_gta(input_size)])
    # else:
    # data_aug = None
    # SUPERVISED DATA
    data_path = '/tmp/tcn_data/texture_multibot_push_left10050/videos/train_adaptation'
    data_aug = Compose([RandomHorizontallyFlip()])
    if dataset == 'multiview':
        train_dataset = data_loader(data_path,
                                    is_transform=True,
                                    view_idx=0,
                                    number_views=1,
                                    load_seg_mask=True,
                                    augmentations=data_aug,
                                    img_size=input_size,
                                    img_mean=IMG_MEAN)
    else:
        data_loader = get_loader('gta')
        data_path = get_data_path('gta')
        train_dataset = data_loader(data_path,
                                    list_path='./data/gta5_list/train.txt',
                                    augmentations=data_aug,
                                    img_size=(1280, 720),
                                    mean=IMG_MEAN)

    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=True)

    # training loss for labeled data only
    trainloader_iter = iter(trainloader)
    print('gta size:', len(trainloader))

    #Load new data for domain_transfer

    # optimizer for segmentation network
    learning_rate_object = Learning_Rate_Object(
        config['training']['learning_rate'])

    if optimizer_type == 'SGD':
        if len(gpus) > 1:
            optimizer = optim.SGD(
                model.module.optim_parameters(learning_rate_object),
                lr=learning_rate,
                momentum=momentum,
                weight_decay=weight_decay)
        else:
            optimizer = optim.SGD(model.optim_parameters(learning_rate_object),
                                  lr=learning_rate,
                                  momentum=momentum,
                                  weight_decay=weight_decay)
    elif optimizer_type == 'Adam':
        if len(gpus) > 1:
            optimizer = optim.Adam(
                model.module.optim_parameters(learning_rate_object),
                lr=learning_rate,
                weight_decay=weight_decay)  # Adam takes no momentum argument
        else:
            optimizer = optim.Adam(
                model.optim_parameters(learning_rate_object),
                lr=learning_rate,
                weight_decay=weight_decay)

    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[0], input_size[1]),
                         mode='bilinear',
                         align_corners=True)
    start_iteration = 0

    if args.resume:
        start_iteration, model, optimizer, ema_model = _resume_checkpoint(
            args.resume, model, optimizer, ema_model)

    accumulated_loss_l = []
    accumulated_loss_u = []

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    with open(checkpoint_dir + '/config.json', 'w') as handle:
        json.dump(config, handle, indent=4, sort_keys=True)

    epochs_since_start = 0
    for i_iter in range(start_iteration, num_iterations):
        model.train()

        loss_u_value = 0
        loss_l_value = 0

        optimizer.zero_grad()

        if lr_schedule:
            adjust_learning_rate(optimizer, i_iter)

        # training loss for labeled data only
        try:
            batch = next(trainloader_iter)
            if batch[0].shape[0] != batch_size:
                batch = next(trainloader_iter)
        except StopIteration:
            # Loader exhausted; restart it and count one more epoch.
            epochs_since_start = epochs_since_start + 1
            print('Epochs since start: ', epochs_since_start)
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        #if random_flip:
        #    weak_parameters={"flip":random.randint(0,1)}
        #else:
        weak_parameters = {"flip": 0}

        images, labels, _, _ = batch
        images = images.cuda()
        labels = labels.cuda().long()

        #images, labels = weakTransform(weak_parameters, data = images, target = labels)

        pred = interp(model(images))
        L_l = loss_calc(pred, labels)  # Cross entropy loss for labeled data
        #L_l = torch.Tensor([0.0]).cuda()

        if train_unlabeled:
            try:
                batch_remain = next(trainloader_remain_iter)
                if batch_remain[0].shape[0] != batch_size:
                    batch_remain = next(trainloader_remain_iter)
            except StopIteration:
                trainloader_remain_iter = iter(trainloader_remain)
                batch_remain = next(trainloader_remain_iter)

            images_remain, *_ = batch_remain
            images_remain = images_remain.cuda()
            inputs_u_w, _ = weakTransform(weak_parameters, data=images_remain)
            #inputs_u_w = inputs_u_w.clone()
            logits_u_w = interp(ema_model(inputs_u_w))
            logits_u_w, _ = weakTransform(
                getWeakInverseTransformParameters(weak_parameters),
                data=logits_u_w.detach())

            pseudo_label = torch.softmax(logits_u_w.detach(), dim=1)
            max_probs, targets_u_w = torch.max(pseudo_label, dim=1)

            if mix_mask == "class":
                for image_i in range(batch_size):
                    classes = torch.unique(labels[image_i])
                    #classes=classes[classes!=ignore_label]
                    nclasses = classes.shape[0]
                    #if nclasses > 0:
                    classes = (classes[torch.Tensor(
                        np.random.choice(nclasses,
                                         int((nclasses + nclasses % 2) / 2),
                                         replace=False)).long()]).cuda()

                    if image_i == 0:
                        MixMask0 = transformmasks.generate_class_mask(
                            labels[image_i], classes).unsqueeze(0).cuda()
                    else:
                        MixMask1 = transformmasks.generate_class_mask(
                            labels[image_i], classes).unsqueeze(0).cuda()

            elif mix_mask is None:
                MixMask = torch.ones((inputs_u_w.shape)).cuda()

            strong_parameters = {"Mix": MixMask0}
            if random_flip:
                strong_parameters["flip"] = random.randint(0, 1)
            else:
                strong_parameters["flip"] = 0
            if color_jitter:
                strong_parameters["ColorJitter"] = random.uniform(0, 1)
            else:
                strong_parameters["ColorJitter"] = 0
            if gaussian_blur:
                strong_parameters["GaussianBlur"] = random.uniform(0, 1)
            else:
                strong_parameters["GaussianBlur"] = 0

            inputs_u_s0, _ = strongTransform(
                strong_parameters,
                data=torch.cat(
                    (images[0].unsqueeze(0), images_remain[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            inputs_u_s1, _ = strongTransform(
                strong_parameters,
                data=torch.cat(
                    (images[1].unsqueeze(0), images_remain[1].unsqueeze(0))))
            inputs_u_s = torch.cat((inputs_u_s0, inputs_u_s1))
            logits_u_s = interp(model(inputs_u_s))

            strong_parameters["Mix"] = MixMask0
            _, targets_u0 = strongTransform(strong_parameters,
                                            target=torch.cat(
                                                (labels[0].unsqueeze(0),
                                                 targets_u_w[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            _, targets_u1 = strongTransform(strong_parameters,
                                            target=torch.cat(
                                                (labels[1].unsqueeze(0),
                                                 targets_u_w[1].unsqueeze(0))))
            targets_u = torch.cat((targets_u0, targets_u1)).long()

            if pixel_weight == "threshold_uniform":
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(targets_u.cpu()))
                pixelWiseWeight = unlabeled_weight * torch.ones(
                    max_probs.shape).cuda()
            elif pixel_weight == "threshold":
                pixelWiseWeight = max_probs.ge(0.968).float().cuda()
            elif pixel_weight == False:
                pixelWiseWeight = torch.ones(max_probs.shape).cuda()

            onesWeights = torch.ones((pixelWiseWeight.shape)).cuda()
            strong_parameters["Mix"] = MixMask0
            _, pixelWiseWeight0 = strongTransform(
                strong_parameters,
                target=torch.cat((onesWeights[0].unsqueeze(0),
                                  pixelWiseWeight[0].unsqueeze(0))))
            strong_parameters["Mix"] = MixMask1
            _, pixelWiseWeight1 = strongTransform(
                strong_parameters,
                target=torch.cat((onesWeights[1].unsqueeze(0),
                                  pixelWiseWeight[1].unsqueeze(0))))
            pixelWiseWeight = torch.cat(
                (pixelWiseWeight0, pixelWiseWeight1)).cuda()

            if consistency_loss == 'MSE':
                unlabeled_weight = torch.sum(
                    max_probs.ge(0.968).long() == 1).item() / np.size(
                        np.array(targets_u.cpu()))
                #pseudo_label = torch.cat((pseudo_label[1].unsqueeze(0),pseudo_label[0].unsqueeze(0)))
                L_u = consistency_weight * unlabeled_weight * unlabeled_loss(
                    logits_u_s, pseudo_label)
            elif consistency_loss == 'CE':
                L_u = consistency_weight * unlabeled_loss(
                    logits_u_s, targets_u, pixelWiseWeight)

            loss = L_l + L_u

        else:
            loss = L_l

        if len(gpus) > 1:
            #print('before mean = ',loss)
            loss = loss.mean()
            #print('after mean = ',loss)
            loss_l_value += L_l.mean().item()
            if train_unlabeled:
                loss_u_value += L_u.mean().item()
        else:
            loss_l_value += L_l.item()
            if train_unlabeled:
                loss_u_value += L_u.item()

        loss.backward()
        optimizer.step()

        # update Mean teacher network
        if ema_model is not None:
            alpha_teacher = 0.99
            ema_model = update_ema_variables(ema_model=ema_model,
                                             model=model,
                                             alpha_teacher=alpha_teacher,
                                             iteration=i_iter)

        print(
            'iter = {0:6d}/{1:6d}, loss_l = {2:.3f}, loss_u = {3:.3f}'.format(
                i_iter, num_iterations, loss_l_value, loss_u_value))

        if i_iter % save_checkpoint_every == 0 and i_iter != 0:
            if epochs_since_start * len(trainloader) < save_checkpoint_every:
                _save_checkpoint(i_iter,
                                 model,
                                 optimizer,
                                 config,
                                 ema_model,
                                 overwrite=False)
            else:
                _save_checkpoint(i_iter, model, optimizer, config, ema_model)

        if config['utils']['tensorboard']:
            if 'tensorboard_writer' not in locals():
                tensorboard_writer = tensorboard.SummaryWriter(log_dir,
                                                               flush_secs=30)

            accumulated_loss_l.append(loss_l_value)
            if train_unlabeled:
                accumulated_loss_u.append(loss_u_value)
            if i_iter % log_per_iter == 0 and i_iter != 0:

                tensorboard_writer.add_scalar('Training/Supervised loss',
                                              np.mean(accumulated_loss_l),
                                              i_iter)
                accumulated_loss_l = []

                if train_unlabeled:
                    tensorboard_writer.add_scalar('Training/Unsupervised loss',
                                                  np.mean(accumulated_loss_u),
                                                  i_iter)
                    accumulated_loss_u = []

        if i_iter % val_per_iter == 0 and i_iter != 0:
            model.eval()
            if dataset == 'cityscapes':
                mIoU, eval_loss = evaluate(model,
                                           dataset,
                                           ignore_label=250,
                                           input_size=(512, 1024),
                                           save_dir=checkpoint_dir)
            elif dataset == 'multiview':
                mIoU, eval_loss = evaluate(model,
                                           dataset,
                                           ignore_label=255,
                                           input_size=(300, 300),
                                           save_dir=checkpoint_dir)
            else:
                print('error: unknown dataset: {}'.format(dataset))
            model.train()

            if mIoU > best_mIoU and save_best_model:
                best_mIoU = mIoU
                _save_checkpoint(i_iter,
                                 model,
                                 optimizer,
                                 config,
                                 ema_model,
                                 save_best=True)

            if config['utils']['tensorboard']:
                tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
                tensorboard_writer.add_scalar('Validation/Loss', eval_loss,
                                              i_iter)
                print('iter {}, mIoU: {}'.format(i_iter, mIoU))

        if save_unlabeled_images and train_unlabeled and i_iter % save_checkpoint_every == 0:
            # Saves two mixed images and the corresponding prediction
            save_image(inputs_u_s[0].cpu(), i_iter, 'input1',
                       palette.CityScpates_palette)
            save_image(inputs_u_s[1].cpu(), i_iter, 'input2',
                       palette.CityScpates_palette)
            _, pred_u_s = torch.max(logits_u_s, dim=1)
            save_image(pred_u_s[0].cpu(), i_iter, 'pred1',
                       palette.CityScpates_palette)
            save_image(pred_u_s[1].cpu(), i_iter, 'pred2',
                       palette.CityScpates_palette)

    _save_checkpoint(num_iterations, model, optimizer, config, ema_model)

    model.eval()
    if dataset == 'cityscapes':
        mIoU, val_loss = evaluate(model,
                                  dataset,
                                  ignore_label=250,
                                  input_size=(512, 1024),
                                  save_dir=checkpoint_dir)
    elif dataset == 'multiview':
        mIoU, val_loss = evaluate(model,
                                  dataset,
                                  ignore_label=255,
                                  input_size=(300, 300),
                                  save_dir=checkpoint_dir)
    else:
        print('error: unknown dataset: {}'.format(dataset))
    model.train()
    if mIoU > best_mIoU and save_best_model:
        best_mIoU = mIoU
        _save_checkpoint(i_iter,
                         model,
                         optimizer,
                         config,
                         ema_model,
                         save_best=True)

    if config['utils']['tensorboard']:
        tensorboard_writer.add_scalar('Validation/mIoU', mIoU, i_iter)
        tensorboard_writer.add_scalar('Validation/Loss', val_loss, i_iter)

    end = timeit.default_timer()
    print('Total time: ' + str(end - start) + ' seconds')
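
Both training scripts rely on `transformmasks.generate_class_mask` to build the ClassMix mask from a sampled half of the classes present in a label map. That module is not shown on this page; a minimal sketch consistent with how it is called (an (H, W) label map plus a 1-D tensor of class ids), assuming torch is imported as in the snippets above:

# Hypothetical sketch of transformmasks.generate_class_mask: the mask is 1
# wherever the label map belongs to any of the sampled classes.
def generate_class_mask(pred, classes):
    # pred: (H, W) long tensor; classes: (K,) long tensor of class ids.
    return pred.unsqueeze(0).eq(classes.view(-1, 1, 1)).any(dim=0).float()
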
Code Example #7
    def __init__(
        self,
        model,
        loss,
        resume,
        config,
        train_loader,
        val_loader=None,
        train_logger=None,
    ):
        self.model = model
        self.loss = loss
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_logger = train_logger
        self.logger = logging.getLogger(self.__class__.__name__)
        self.do_validation = self.config["trainer"]["val"]
        self.start_epoch = 1
        self.improved = False

        # SETTING THE DEVICE
        self.device, availble_gpus = self._get_available_devices(
            self.config["n_gpu"])
        if config["use_synch_bn"]:
            self.model = convert_model(self.model)
            self.model = DataParallelWithCallback(self.model,
                                                  device_ids=availble_gpus)
        else:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=availble_gpus)
        self.model.to(self.device)

        # CONFIGS
        cfg_trainer = self.config["trainer"]
        self.epochs = cfg_trainer["epochs"]
        self.save_period = cfg_trainer["save_period"]

        # OPTIMIZER
        if self.config["optimizer"]["differential_lr"]:
            if isinstance(self.model, torch.nn.DataParallel):
                trainable_params = [
                    {
                        "params":
                        filter(
                            lambda p: p.requires_grad,
                            self.model.module.get_decoder_params(),
                        )
                    },
                    {
                        "params":
                        filter(
                            lambda p: p.requires_grad,
                            self.model.module.get_backbone_params(),
                        ),
                        "lr":
                        config["optimizer"]["args"]["lr"] / 10,
                    },
                ]
            else:
                trainable_params = [
                    {
                        "params":
                        filter(lambda p: p.requires_grad,
                               self.model.get_decoder_params())
                    },
                    {
                        "params":
                        filter(lambda p: p.requires_grad,
                               self.model.get_backbone_params()),
                        "lr":
                        config["optimizer"]["args"]["lr"] / 10,
                    },
                ]
        else:
            trainable_params = filter(lambda p: p.requires_grad,
                                      self.model.parameters())
        self.optimizer = get_instance(torch.optim, "optimizer", config,
                                      trainable_params)
        self.lr_scheduler = getattr(utils.lr_scheduler,
                                    config["lr_scheduler"]["type"])(
                                        self.optimizer, self.epochs,
                                        len(train_loader))

        # MONITORING
        self.monitor = cfg_trainer.get("monitor", "off")
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]
            self.mnt_best = -math.inf if self.mnt_mode == "max" else math.inf
            self.early_stoping = cfg_trainer.get("early_stop", math.inf)

        # CHECKPOINTS & TENSORBOARD
        start_time = datetime.datetime.now().strftime("%m-%d_%H-%M")
        self.checkpoint_dir = os.path.join(cfg_trainer["save_dir"],
                                           self.config["name"], start_time)
        helpers.dir_exists(self.checkpoint_dir)
        config_save_path = os.path.join(self.checkpoint_dir, "config.json")
        with open(config_save_path, "w") as handle:
            json.dump(self.config, handle, indent=4, sort_keys=True)

        writer_dir = os.path.join(cfg_trainer["log_dir"], self.config["name"],
                                  start_time)
        self.writer = tensorboard.SummaryWriter(writer_dir)

        if resume:
            self._resume_checkpoint(resume)
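
The constructor above is driven entirely by a nested config dictionary. The sketch below reconstructs the shape that dictionary must have, using only the keys the code actually reads; the concrete values are illustrative assumptions, not the repository's defaults.

# Minimal config sketch for the trainer above; all values are assumptions
config = {
    "name": "PSPNet",
    "n_gpu": 1,
    "use_synch_bn": False,
    "arch": {"type": "PSPNet", "args": {}},
    "optimizer": {
        "type": "SGD",
        "differential_lr": True,
        "args": {"lr": 0.01, "momentum": 0.9, "weight_decay": 1e-4},
    },
    "lr_scheduler": {"type": "Poly"},
    "trainer": {
        "epochs": 80,
        "save_period": 10,
        "val": True,
        "val_per_epochs": 5,
        "monitor": "max Mean_IoU",
        "early_stop": 10,
        "save_dir": "saved/",
        "log_dir": "saved/runs",
    },
}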
Code Example #8
import datetime
import json
import logging
import math
import os

import torch
from torch.utils import tensorboard

# Project-specific helpers; the exact import paths below are assumptions and
# may differ between repositories
import utils.lr_scheduler
from utils import helpers
from utils.helpers import get_instance
from utils.sync_batchnorm import convert_model, DataParallelWithCallback


class BaseTrainer:
    def __init__(
        self,
        model,
        loss,
        resume,
        config,
        train_loader,
        val_loader=None,
        train_logger=None,
    ):
        self.model = model
        self.loss = loss
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_logger = train_logger
        self.logger = logging.getLogger(self.__class__.__name__)
        self.do_validation = self.config["trainer"]["val"]
        self.start_epoch = 1
        self.improved = False
        self.not_improved_count = 0  # early-stopping counter

        # SETTING THE DEVICE
        self.device, available_gpus = self._get_available_devices(
            self.config["n_gpu"])
        if config["use_synch_bn"]:
            # Replace BatchNorm layers with synchronized BatchNorm before wrapping
            self.model = convert_model(self.model)
            self.model = DataParallelWithCallback(self.model,
                                                  device_ids=available_gpus)
        else:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=available_gpus)
        self.model.to(self.device)

        # CONFIGS
        cfg_trainer = self.config["trainer"]
        self.epochs = cfg_trainer["epochs"]
        self.save_period = cfg_trainer["save_period"]

        # OPTIMIZER
        if self.config["optimizer"]["differential_lr"]:
            if isinstance(self.model, torch.nn.DataParallel):
                trainable_params = [
                    {
                        "params":
                        filter(
                            lambda p: p.requires_grad,
                            self.model.module.get_decoder_params(),
                        )
                    },
                    {
                        "params":
                        filter(
                            lambda p: p.requires_grad,
                            self.model.module.get_backbone_params(),
                        ),
                        "lr":
                        config["optimizer"]["args"]["lr"] / 10,
                    },
                ]
            else:
                trainable_params = [
                    {
                        "params":
                        filter(lambda p: p.requires_grad,
                               self.model.get_decoder_params())
                    },
                    {
                        "params":
                        filter(lambda p: p.requires_grad,
                               self.model.get_backbone_params()),
                        "lr":
                        config["optimizer"]["args"]["lr"] / 10,
                    },
                ]
        else:
            trainable_params = filter(lambda p: p.requires_grad,
                                      self.model.parameters())
        self.optimizer = get_instance(torch.optim, "optimizer", config,
                                      trainable_params)
        self.lr_scheduler = getattr(utils.lr_scheduler,
                                    config["lr_scheduler"]["type"])(
                                        self.optimizer, self.epochs,
                                        len(train_loader))

        # MONITORING
        self.monitor = cfg_trainer.get("monitor", "off")
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]
            self.mnt_best = -math.inf if self.mnt_mode == "max" else math.inf
            self.early_stopping = cfg_trainer.get("early_stop", math.inf)

        # CHECKPOINTS & TENSORBOARD
        start_time = datetime.datetime.now().strftime("%m-%d_%H-%M")
        self.checkpoint_dir = os.path.join(cfg_trainer["save_dir"],
                                           self.config["name"], start_time)
        helpers.dir_exists(self.checkpoint_dir)
        config_save_path = os.path.join(self.checkpoint_dir, "config.json")
        with open(config_save_path, "w") as handle:
            json.dump(self.config, handle, indent=4, sort_keys=True)

        writer_dir = os.path.join(cfg_trainer["log_dir"], self.config["name"],
                                  start_time)
        self.writer = tensorboard.SummaryWriter(writer_dir)

        if resume:
            self._resume_checkpoint(resume)

    def _get_available_devices(self, n_gpu):
        sys_gpu = torch.cuda.device_count()
        if sys_gpu == 0:
            self.logger.warning("No GPUs detected, using the CPU")
            n_gpu = 0
        elif n_gpu > sys_gpu:
            self.logger.warning(
                f"{n_gpu} GPUs were requested but only {sys_gpu} are available"
            )
            n_gpu = sys_gpu

        device = torch.device("cuda:0" if n_gpu > 0 else "cpu")
        self.logger.info(f"Detected GPUs: {sys_gpu} Requested: {n_gpu}")
        available_gpus = list(range(n_gpu))
        return device, available_gpus

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            # RUN TRAIN (AND VAL)
            results = self._train_epoch(epoch)
            if (self.do_validation
                    and epoch % self.config["trainer"]["val_per_epochs"] == 0):
                results = self._valid_epoch(epoch)

                # LOGGING INFO
                self.logger.info(f"\n         ## Info for epoch {epoch} ## ")
                for k, v in results.items():
                    self.logger.info(f"         {str(k):15s}: {v}")

            # `log` must always be built: the monitoring block below reads it
            # even when no train_logger is attached
            log = {"epoch": epoch, **results}
            if self.train_logger is not None:
                self.train_logger.add_entry(log)

            # CHECKING IF THIS IS THE BEST MODEL (ONLY FOR VAL)
            if (self.mnt_mode != "off"
                    and epoch % self.config["trainer"]["val_per_epochs"] == 0):
                try:
                    if self.mnt_mode == "min":
                        self.improved = log[self.mnt_metric] < self.mnt_best
                    else:
                        self.improved = log[self.mnt_metric] > self.mnt_best
                except KeyError:
                    self.logger.warning(
                        f"The metric being tracked ({self.mnt_metric}) has not been calculated. Training stops."
                    )
                    break

                if self.improved:
                    self.mnt_best = log[self.mnt_metric]
                    self.not_improved_count = 0
                else:
                    self.not_improved_count += 1

                if self.not_improved_count > self.early_stopping:
                    self.logger.info(
                        f"\nPerformance didn't improve for {self.early_stopping} epochs"
                    )
                    self.logger.warning("Training stopped")
                    break
                    break

            # SAVE CHECKPOINT
            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=self.improved)

    def _save_checkpoint(self, epoch, save_best=False):
        state = {
            "arch": type(self.model).__name__,
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "monitor_best": self.mnt_best,
            "config": self.config,
        }
        filename = os.path.join(self.checkpoint_dir,
                                f"checkpoint-epoch{epoch}.pth")
        self.logger.info(f"\nSaving a checkpoint: {filename} ...")
        torch.save(state, filename)

        if save_best:
            filename = os.path.join(self.checkpoint_dir, "best_model.pth")
            torch.save(state, filename)
            self.logger.info("Saving current best: best_model.pth")

    def _resume_checkpoint(self, resume_path):
        self.logger.info(f"Loading checkpoint : {resume_path}")
        checkpoint = torch.load(resume_path)

        # Load last run info, the model params, the optimizer and the loggers
        self.start_epoch = checkpoint["epoch"] + 1
        self.mnt_best = checkpoint["monitor_best"]
        self.not_improved_count = 0

        if checkpoint["config"]["arch"] != self.config["arch"]:
            self.logger.warning({
                "Warning! Current model is not the same as the one in the checkpoint"
            })
        self.model.load_state_dict(checkpoint["state_dict"])

        if (checkpoint["config"]["optimizer"]["type"] !=
                self.config["optimizer"]["type"]):
            self.logger.warning({
                "Warning! Current optimizer is not the same as the one in the checkpoint"
            })
        self.optimizer.load_state_dict(checkpoint["optimizer"])

        self.logger.info(
            f"Checkpoint <{resume_path}> (epoch {self.start_epoch}) was loaded"
        )

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _valid_epoch(self, epoch):
        raise NotImplementedError

    def _eval_metrics(self, output, target):
        raise NotImplementedError
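
BaseTrainer leaves the per-epoch logic abstract. A hypothetical minimal subclass is sketched below to show how the pieces fit together; the model/loss call signatures and the "val_loss" metric name are assumptions, and the returned metric must match whatever config["trainer"]["monitor"] tracks.

class Trainer(BaseTrainer):
    def _train_epoch(self, epoch):
        self.model.train()
        total_loss = 0.0
        for data, target in self.train_loader:
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            loss = self.loss(self.model(data), target)
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()  # this scheduler steps once per iteration
            total_loss += loss.item()
        return {"loss": total_loss / len(self.train_loader)}

    def _valid_epoch(self, epoch):
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                total_loss += self.loss(self.model(data), target).item()
        # The key below is what "min val_loss"-style monitoring would read
        return {"val_loss": total_loss / len(self.val_loader)}

    def _eval_metrics(self, output, target):
        return {}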