Example #1
def train(args):
    logging.info('======= user config ======')
    logging.info(pformat(opt))   # pprint.pformat returns a string; pprint() prints to stdout and returns None
    logging.info(pformat(args))
    logging.info('======= end ======')

    train_data, valid_data = get_data_provider(opt)

    net = getattr(network, opt.network.name)(classes=opt.dataset.num_classes)
    optimizer = getattr(torch.optim,
                        opt.train.optimizer)(net.parameters(),
                                             lr=opt.train.lr,
                                             weight_decay=opt.train.wd,
                                             momentum=opt.train.momentum)
    ce_loss = nn.CrossEntropyLoss()
    lr_scheduler = LRScheduler(base_lr=opt.train.lr,
                               step=opt.train.step,
                               factor=opt.train.factor,
                               warmup_epoch=opt.train.warmup_epoch,
                               warmup_begin_lr=opt.train.warmup_begin_lr)
    net = nn.DataParallel(net)
    net = net.cuda()
    mod = Solver(opt, net)
    mod.fit(train_data=train_data,
            test_data=valid_data,
            optimizer=optimizer,
            criterion=ce_loss,
            lr_scheduler=lr_scheduler)
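This example (and Examples #3 and #9 below) configures a step-decay LRScheduler with optional warmup, and Example #9 calls lr_scheduler.update(epoch) once per epoch to fetch the new rate. A minimal sketch of that interface, assuming plain step decay with linear warmup; the real class ships with each project and may differ:

import bisect

class LRScheduler:
    def __init__(self, base_lr=0.01, step=(30, 60), factor=0.1,
                 warmup_epoch=0, warmup_begin_lr=0.0):
        self.base_lr = base_lr
        self.step = sorted(step)
        self.factor = factor
        self.warmup_epoch = warmup_epoch
        self.warmup_begin_lr = warmup_begin_lr

    def update(self, epoch):
        # linear warmup from warmup_begin_lr up to base_lr
        if epoch < self.warmup_epoch:
            frac = epoch / max(1, self.warmup_epoch)
            return self.warmup_begin_lr + (self.base_lr - self.warmup_begin_lr) * frac
        # decay by `factor` once for every step boundary already passed
        n_decays = bisect.bisect_right(self.step, epoch)
        return self.base_lr * (self.factor ** n_decays)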
Example #2
    def __init__(self, args):
        self.args = args
        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader
        data_kwargs = {'transform': input_transform,
                       'base_size': args.base_size,
                       'crop_size': args.crop_size}

        train_dataset = get_segmentation_dataset(args.dataset, split=args.train_split, mode='train', **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs)
        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            drop_last=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_size=1,
                                          shuffle=False)

        # create network
        self.model = get_fast_scnn(dataset=args.dataset, aux=args.aux)
        if torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model, device_ids=[0, 1, 2])
        self.model.to(args.device)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                name, ext = os.path.splitext(args.resume)
                assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
                print('Resuming training, loading {}...'.format(args.resume))
                self.model.load_state_dict(torch.load(args.resume, map_location=lambda storage, loc: storage))

        # create criterion
        self.criterion = MixSoftmaxCrossEntropyOHEMLoss(aux=args.aux, aux_weight=args.aux_weight,
                                                        ignore_index=-1).to(args.device)

        # optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling
        self.lr_scheduler = LRScheduler(mode='poly', base_lr=args.lr, nepochs=args.epochs,
                                        iters_per_epoch=len(self.train_loader),
                                        power=0.9)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        self.best_pred = 0.0
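This example (and Examples #5 and #8) builds the scheduler with mode='poly'. A minimal sketch of that schedule, assuming the usual polynomial decay to zero over nepochs * iters_per_epoch iterations; the project's class may wrap this per-iteration rule differently:

def poly_lr(base_lr, cur_iter, nepochs, iters_per_epoch, power=0.9):
    # learning rate anneals from base_lr toward 0 over training
    max_iter = nepochs * iters_per_epoch
    return base_lr * (1.0 - cur_iter / max_iter) ** power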
Example #3
def train(args):
    logging.info('======= user config ======')
    logging.info(pformat(opt))   # pprint.pformat returns a string; pprint() prints to stdout and returns None
    logging.info(pformat(args))
    logging.info('======= end ======')

    train_data, test_data, num_query = get_data_provider(opt)

    net = getattr(network, opt.network.name)(opt.dataset.num_classes,
                                             opt.network.last_stride)
    net = nn.DataParallel(net).cuda()

    optimizer = getattr(torch.optim,
                        opt.train.optimizer)(net.parameters(),
                                             lr=opt.train.lr,
                                             weight_decay=opt.train.wd)
    ce_loss = nn.CrossEntropyLoss()
    triplet_loss = TripletLoss(margin=opt.train.margin)

    def ce_loss_func(scores, feat, labels):
        ce = ce_loss(scores, labels)
        return ce

    def tri_loss_func(scores, feat, labels):
        tri = triplet_loss(feat, labels)[0]
        return tri

    def ce_tri_loss_func(scores, feat, labels):
        ce = ce_loss(scores, labels)
        triplet = triplet_loss(feat, labels)[0]
        return ce + triplet

    if opt.train.loss_fn == 'softmax':
        loss_fn = ce_loss_func
    elif opt.train.loss_fn == 'triplet':
        loss_fn = tri_loss_func
    elif opt.train.loss_fn == 'softmax_triplet':
        loss_fn = ce_tri_loss_func
    else:
        raise ValueError('Unknown loss func {}'.format(opt.train.loss_fn))

    lr_scheduler = LRScheduler(base_lr=opt.train.lr,
                               step=opt.train.step,
                               factor=opt.train.factor,
                               warmup_epoch=opt.train.warmup_epoch,
                               warmup_begin_lr=opt.train.warmup_begin_lr)

    mod = Solver(opt, net)
    mod.fit(train_data=train_data,
            test_data=test_data,
            num_query=num_query,
            optimizer=optimizer,
            criterion=loss_fn,
            lr_scheduler=lr_scheduler)
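TripletLoss(feat, labels)[0] above implies a loss that returns a tuple with the loss value first; in re-ID codebases this is commonly batch-hard triplet mining. A minimal sketch under that assumption, using Euclidean distances with the hardest positive and hardest negative per anchor:

import torch

def batch_hard_triplet(feat, labels, margin=0.3):
    dist = torch.cdist(feat, feat)                       # pairwise distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)    # same-identity mask
    d_ap = (dist * same.float()).max(1).values           # hardest positive per anchor
    d_an = dist.masked_fill(same, float('inf')).min(1).values  # hardest negative
    return torch.relu(d_ap - d_an + margin).mean()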
Example #4
    def __init__(self, settings: dict, settings_to_log: list):
        self.settings = settings
        self.settings_to_log = settings_to_log

        self.threshold = self.settings['threshold']
        self.start_epoch = self.settings['start_epoch']
        self.dataset = self.settings['dataset']
        self.batch_size = self.settings['batch_size']
        self.workers = self.settings['workers']
        self.cuda = self.settings['cuda']
        self.fp16 = self.settings['fp16']
        self.epochs = self.settings['epochs']
        self.ignore_index = self.settings['ignore_index']
        self.loss_reduction = self.settings['loss_reduction']

        # -------------------- Define Data loader ------------------------------
        self.loaders, self.nclass, self.plotter = make_data_loader(settings)
        self.train_loader, self.val_loader, self.test_loader = [self.loaders[key] for key in ['train', 'val', 'test']]

        # -------------------- Define model ------------------------------------
        self.model = get_model(self.settings)

        # -------------------- Define optimizer and its options ----------------
        self.optimizer = define_optimizer(self.model, self.settings['optimizer'], self.settings['optimizer_params'])
        if self.settings['lr_scheduler']:
            self.lr_scheduler = LRScheduler(self.settings['lr_scheduler'], self.optimizer, self.batch_size)

        # -------------------- Define loss -------------------------------------
        input_size = (self.batch_size, self.nclass, *self.settings['target_size'])
        self.criterion = CustomLoss(input_size=input_size, ignore_index=self.ignore_index, reduction=self.loss_reduction)

        self.evaluator = Evaluator(metrics=self.settings['metrics'], num_class=self.nclass, threshold=self.settings['threshold'])

        self.logger = MainLogger(loggers=self.settings['loggers'], settings=settings, settings_to_log=settings_to_log)
        if self.settings['resume']:
            self.resume_checkpoint(self.settings['resume'])

        self.metric_to_watch = 0.0
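define_optimizer above resolves the optimizer from names in the settings dict. A minimal sketch of what such a helper typically does, assuming settings['optimizer'] is a torch.optim class name and settings['optimizer_params'] a keyword dict (the project's helper may also set per-parameter groups):

import torch

def define_optimizer(model, name, params):
    # e.g. name='SGD', params={'lr': 0.01, 'momentum': 0.9}
    return getattr(torch.optim, name)(model.parameters(), **params)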
Example #5
    def __init__(self, args):
        self.args = args

        self.img, self.target = VOCSegmentation().get()

        self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone,
                                            aux=False, norm_layer=nn.BatchNorm2d).to(args.device)

        self.criterion = MixSoftmaxCrossEntropyLoss(False, 0., ignore_label=-1).to(args.device)

        # for EncNet
        # self.criterion = EncNetLoss(nclass=21, ignore_label=-1).to(args.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
        self.lr_scheduler = LRScheduler(mode='poly', base_lr=args.lr, nepochs=args.epochs,
                                        iters_per_epoch=1, power=0.9)
Example #6
    def __init__(self, config, logger):
        super().__init__(config, logger)

        # Initialize Model
        self.model = MNIST(config)
        self.logger.log_torch_model(self.model)

        # Initialize Dataloader
        self.data_loader = MNISTDataLoader(config)

        # Define Loss Function
        self.loss = define_loss(config.agent.trainer.loss)

        # Define Metrics List
        self.metrics = Metrics(config)

        # Initialize Optimizer
        self.optimizer = Optim(self.model.parameters(), config)
        self.lr_scheduler = LRScheduler(self.optimizer, config)
Example #7
    net = net.cuda()
    loss = loss.cuda()
    cudnn.benchmark = True
    #print(net.named_parameters())
    net = DataParallel(net)

    #for p in module.parameters():
    #p.requires_grad = True
    optimizer = torch.optim.SGD(net.parameters(),
                                lr,
                                momentum=0.9,
                                weight_decay=1e-4)
    lrs = LRScheduler(lr,
                      patience=patience,
                      factor=0.5,
                      min_lr=0.5 * 0.5 * 0.5 * lr,
                      best_loss=best_val_loss)
    #with torch.no_grad():
    #    val_metrics, val_time, val_iou, val_acc, val_mean_acc = validate(val_loader, net, loss, 0, num_class)
    #print (val_iou, val_acc, val_mean_acc)
    for epoch in range(start_epoch, epochs + 1):
        train_metrics, train_time = train(train_loader, net, loss, optimizer,
                                          lr)
        with torch.no_grad():
            val_metrics, val_time, val_iou, val_acc, val_mean_acc = validate(
                val_loader, net, loss, epoch, num_class)

        print_log(epoch,
                  lr,
                  train_metrics,
Example #8
    def __init__(self, args, cfg=None):
        # train_dataset = [build_dataset(cfg.data.train)]
        # self.dataset= train_dataset
        # val_dataset = [build_dataset(cfg.data.test)]
        # if len(cfg.workflow) == 2:
        #     train_dataset.append(build_dataset(cfg.data.val))
        # train_data_loaders = [
        #     build_dataloader(
        #         ds,
        #         cfg.data.imgs_per_gpu,
        #         cfg.data.workers_per_gpu,
        #         # cfg.gpus,
        #         dist=False) for ds in train_dataset
        # ]
        # val_data_loader = build_dataloader(
        #     val_dataset,
        #     imgs_per_gpu=1,
        #     workers_per_gpu=cfg.data.workers_per_gpu,
        #     dist=False,
        #     shuffle=False)
        # self.train_loader = train_data_loaders[0]
        # self.val_loader = val_data_loader

        self.args = args
        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[123.675, 116.28, 103.53],
                                 std=[58.395, 57.12, 57.375]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        train_dataset = get_segmentation_dataset(args.dataset,
                                                 split=args.train_split,
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='val',
                                               **data_kwargs)
        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            drop_last=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_size=1,
                                          shuffle=False)

        # create network
        self.model = get_fast_scnn(dataset=args.dataset, aux=args.aux)
        if torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=[0, 1, 2])
        self.model.to(args.device)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                name, ext = os.path.splitext(args.resume)
                assert ext in ('.pkl', '.pth'), 'Sorry only .pth and .pkl files supported.'
                print('Resuming training, loading {}...'.format(args.resume))
                self.model.load_state_dict(
                    torch.load(args.resume,
                               map_location=lambda storage, loc: storage))

        # create criterion
        self.criterion = MixSoftmaxCrossEntropyOHEMLoss(
            aux=args.aux, aux_weight=args.aux_weight,
            ignore_index=-1).to(args.device)

        # optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling
        self.lr_scheduler = LRScheduler(mode='poly',
                                        base_lr=args.lr,
                                        nepochs=args.epochs,
                                        iters_per_epoch=len(self.train_loader),
                                        power=0.9)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        self.best_pred = 0.0
Example #9
def train(args, network, train_data, valid_data, optimizer, criterion, device,
          log_path, label2name):
    lr_scheduler = LRScheduler(base_lr=0.01, step=(30, 60), factor=0.1)
    network = network.to(device)
    best_test_acc = -np.inf
    losses = AverageValueMeter()
    acces = AverageValueMeter()
    for epoch in range(120):
        losses.reset()
        acces.reset()
        network.train()

        lr = lr_scheduler.update(epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        # print_str = 'Epoch [%d] learning rate update to %.3e' % (epoch, lr)
        # print(print_str)
        # with open(log_path, 'a') as f: f.write(print_str + '\n')
        tic = time.time()
        for i, data in enumerate(train_data):
            imgs, labels = data
            imgs = imgs.to(device)
            labels = labels.to(device)
            scores = network(imgs)
            loss = criterion(scores, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.add(loss.item())
            acc = (scores.max(1)[1] == labels.long()).float().mean()
            acces.add(acc.item())

            if (i + 1) % args.log_interval == 0:
                loss_mean = losses.value()[0]
                acc_mean = acces.value()[0]
                print_str = 'Epoch[%d] Batch [%d]\tloss=%f\tacc=%f' % (
                    epoch, i + 1, loss_mean, acc_mean)
                print(print_str)
                with open(log_path, 'a') as f:
                    f.write(print_str + '\n')
                btic = time.time()

        loss_mean = losses.value()[0]
        acc_mean = acces.value()[0]

        print_str = '[Epoch %d] Training: loss=%f\tacc=%f\ttime cost: %.3f' % (
            epoch, loss_mean, acc_mean, time.time() - tic)
        print(print_str)
        with open(log_path, 'a') as f:
            f.write(print_str + '\n')

        is_best = False
        if valid_data is not None:
            test_acc = test(network, valid_data, device)
            print_str = '[Epoch %d] test acc: %f' % (epoch, test_acc)
            print(print_str)
            with open(log_path, 'a') as f:
                f.write(print_str + '\n')
            is_best = test_acc > best_test_acc
            if is_best:
                best_test_acc = test_acc
        state_dict = network.state_dict()
        if (epoch + 1) % args.save_step == 0:
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'epoch': epoch + 1,
                    'label2name': label2name,
                },
                is_best=is_best,
                save_dir=os.path.join(args.save_dir, 'models'),
                filename='model' + '.pth')
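AverageValueMeter above follows the torchnet convention: add() accumulates scalars, value()[0] returns the running mean, and reset() clears it. A minimal stand-in, assuming only the mean is consumed (the torchnet version also tracks the standard deviation):

class AverageValueMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum, self.n = 0.0, 0

    def add(self, value):
        self.sum += value
        self.n += 1

    def value(self):
        mean = self.sum / self.n if self.n else float('nan')
        return mean, None  # (mean, std); std omitted in this sketch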
Example #10
    loss = loss.cuda()
    cudnn.benchmark = True
    #print(net.named_parameters())
    net = DataParallel(net)
    if os.path.exists(args.resume):
        print('loading checkpoint %s' % (args.resume))
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch'] + 1
        lr = checkpoint['lr']
        best_val_loss = checkpoint['best_val_loss']
        net.load_state_dict(checkpoint['state_dict'])
        log_mode = 'a'

    optimizer = torch.optim.SGD(net.parameters(),
                                lr,
                                momentum=0.9,
                                weight_decay=5e-4)
    lrs = LRScheduler(lr,
                      patience=patience,
                      factor=0.5,
                      min_lr=0.5 * 0.5 * 0.5 * lr,
                      best_loss=best_val_loss)
    with torch.no_grad():
        #val_time, val_iou, val_acc, val_mean_acc = validate(val_loader, net, loss, 0, num_class)
        #val_time, val_iou_480, val_acc, val_mean_acc = validate(val_loader_480, net, loss, 0, num_class)
        #val_time, val_iou_840, val_acc, val_mean_acc = validate(val_loader_840, net, loss, 0, num_class)
        val_time, val_iou, val_acc, val_mean_acc = validate(
            val_loader, val_loader_2048, val_loader_480, val_loader_840, net,
            loss, 0, num_class)
    print(val_iou, val_acc, val_mean_acc)
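Examples #7 and #10 drive LRScheduler with patience, factor, min_lr, and best_loss, i.e. a reduce-on-plateau policy; torch.optim.lr_scheduler.ReduceLROnPlateau is the stock equivalent. A minimal sketch with a hypothetical update(val_loss) method:

class PlateauScheduler:
    def __init__(self, lr, patience=3, factor=0.5, min_lr=0.0,
                 best_loss=float('inf')):
        self.lr, self.patience, self.factor = lr, patience, factor
        self.min_lr, self.best_loss = min_lr, best_loss
        self.bad_epochs = 0

    def update(self, val_loss):
        if val_loss < self.best_loss:
            self.best_loss, self.bad_epochs = val_loss, 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                # shrink by `factor`, but never below min_lr
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.bad_epochs = 0
        return self.lr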
Example #11
def main():
    parser = argparse.ArgumentParser(description='model training')
    parser.add_argument('--save_dir',
                        type=str,
                        default='logs/tmp',
                        help='save model directory')
    # transforms
    parser.add_argument('--train_size',
                        type=int,
                        default=[224],
                        nargs='+',
                        help='train image size')
    parser.add_argument('--test_size',
                        type=int,
                        default=[224, 224],
                        nargs='+',
                        help='test image size')
    parser.add_argument('--h_filp',
                        action='store_true',
                        help='apply random horizontal flip augmentation')
    # dataset
    parser.add_argument('--dataset_dir',
                        type=str,
                        default='datasets',
                        help='datasets path')
    parser.add_argument('--valid_pect',
                        type=float,
                        default=0.2,
                        help='fraction of the training set held out for validation')
    parser.add_argument('--train_bs',
                        type=int,
                        default=64,
                        help='train images per batch')
    parser.add_argument('--test_bs',
                        type=int,
                        default=128,
                        help='test images per batch')
    # training
    parser.add_argument('--no_gpu',
                        action='store_true',
                        help='train on CPU instead of GPU')
    parser.add_argument('--gpus',
                        type=str,
                        default='0',
                        help='gpus to use in training')
    parser.add_argument('--opt_func',
                        type=str,
                        default='Adam',
                        help='optimizer function')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        help='base learning rate')
    parser.add_argument('--steps',
                        type=int,
                        default=(60, 90),
                        nargs='+',
                        help='epochs at which the learning rate decays')
    parser.add_argument('--factor',
                        type=float,
                        default=0.1,
                        help='learning rate decay factor')
    parser.add_argument('--wd', type=float, default=5e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='training momentum')
    parser.add_argument('--max_epoch',
                        type=int,
                        default=120,
                        help='number of training epochs')
    parser.add_argument('--log_interval',
                        type=int,
                        default=50,
                        help='batches between training-progress prints')
    parser.add_argument('--save_step',
                        type=int,
                        default=20,
                        help='save model every save_step')

    args = parser.parse_args()

    mkdir_if_missing(args.save_dir)
    log_path = os.path.join(args.save_dir, 'log.txt')
    with open(log_path, 'w') as f:
        f.write('{}'.format(args))

    device = "cuda:{}".format(args.gpus) if not args.no_gpu else "cpu"
    if not args.no_gpu:
        cudnn.benchmark = True

    # define train transforms and test transforms
    totensor = T.ToTensor()
    normalize = T.Normalize(mean=[0.491, 0.482, 0.446],
                            std=[0.202, 0.199, 0.201])
    train_tfms = list()
    train_size = args.train_size[0] if len(
        args.train_size) == 1 else args.train_size
    train_tfms.append(T.RandomResizedCrop(train_size))
    if args.h_filp:
        train_tfms.append(T.RandomHorizontalFlip())
    train_tfms.append(totensor)
    train_tfms.append(normalize)
    train_tfms = T.Compose(train_tfms)

    test_tfms = list()
    test_size = (args.test_size[0], args.test_size[0]) if len(
        args.test_size) == 1 else args.test_size
    test_tfms.append(T.Resize(test_size))
    test_tfms.append(totensor)
    test_tfms.append(normalize)
    test_tfms = T.Compose(test_tfms)

    # get dataloader
    train_list, valid_list, label2name = split_dataset(args.dataset_dir,
                                                       args.valid_pect)
    trainset = ImageDataset(train_list, train_tfms)
    validset = ImageDataset(valid_list, test_tfms)

    train_loader = DataLoader(trainset,
                              batch_size=args.train_bs,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    valid_loader = DataLoader(validset,
                              batch_size=args.test_bs,
                              shuffle=False,
                              num_workers=8,
                              pin_memory=True)

    # define network
    net = get_resnet50(len(label2name), pretrain=True)
    # layer_groups = [nn.Sequential(*flatten_model(net))]

    # define loss
    ce_loss = nn.CrossEntropyLoss()

    # define optimizer and lr scheduler
    if args.opt_func == 'Adam':
        optimizer = getattr(torch.optim, args.opt_func)(net.parameters(),
                                                        weight_decay=args.wd)
    else:
        optimizer = getattr(torch.optim, args.opt_func)(net.parameters(),
                                                        weight_decay=args.wd,
                                                        momentum=args.momentum)
    lr_scheduler = LRScheduler(base_lr=args.lr,
                               step=args.steps,
                               factor=args.factor)

    train(
        args=args,
        network=net,
        train_data=train_loader,
        valid_data=valid_loader,
        optimizer=optimizer,
        criterion=ce_loss,
        lr_scheduler=lr_scheduler,
        device=device,
        log_path=log_path,
        label2name=label2name,
    )
Example #12
class Trainer(object):
    def __init__(self, settings: dict, settings_to_log: list):
        self.settings = settings
        self.settings_to_log = settings_to_log

        self.threshold = self.settings['threshold']
        self.start_epoch = self.settings['start_epoch']
        self.dataset = self.settings['dataset']
        self.batch_size = self.settings['batch_size']
        self.workers = self.settings['workers']
        self.cuda = self.settings['cuda']
        self.fp16 = self.settings['fp16']
        self.epochs = self.settings['epochs']
        self.ignore_index = self.settings['ignore_index']
        self.loss_reduction = self.settings['loss_reduction']

        # -------------------- Define Data loader ------------------------------
        self.loaders, self.nclass, self.plotter = make_data_loader(settings)
        self.train_loader, self.val_loader, self.test_loader = [self.loaders[key] for key in ['train', 'val', 'test']]

        # -------------------- Define model ------------------------------------
        self.model = get_model(self.settings)

        # -------------------- Define optimizer and its options ----------------
        self.optimizer = define_optimizer(self.model, self.settings['optimizer'], self.settings['optimizer_params'])
        if self.settings['lr_scheduler']:
            self.lr_scheduler = LRScheduler(self.settings['lr_scheduler'], self.optimizer, self.batch_size)

        # -------------------- Define loss -------------------------------------
        input_size = (self.batch_size, self.nclass, *self.settings['target_size'])
        self.criterion = CustomLoss(input_size=input_size, ignore_index=self.ignore_index, reduction=self.loss_reduction)

        self.evaluator = Evaluator(metrics=self.settings['metrics'], num_class=self.nclass, threshold=self.settings['threshold'])

        self.logger = MainLogger(loggers=self.settings['loggers'], settings=settings, settings_to_log=settings_to_log)
        if self.settings['resume']:
            self.resume_checkpoint(self.settings['resume'])

        self.metric_to_watch = 0.0

    def activation(self, output):
        if self.nclass == 1:
            output = torch.sigmoid(output)
        else:
            output = torch.softmax(output, dim=1)
        return output

    def prepare_inputs(self, *inputs):
        if self.settings['cuda']:
            inputs = [i.cuda() for i in inputs]
        if self.settings['fp16']:
            inputs = [i.half() for i in inputs]
        return inputs

    def training(self, epoch: int):
        """
        Training loop for a certain epoch
        :param epoch: epoch id
        :return:
        """
        self.evaluator.reset()
        self.model.train()
        tbar = tqdm(self.train_loader, desc='train', file=sys.stdout)
        train_loss = 0.0
        output = {}
        for i, sample in enumerate(tbar):
            img, target = self.prepare_inputs(sample['image'], sample['label'])
            img, target, perm_target, gamma = random_joint_mix(img, target, self.settings['CutMix'], self.settings['MixUp'], p=self.settings['MixP'])

            self.optimizer.zero_grad()
            output['pred'], output['pred8'], output['pred16'] = self.model(img)

            if self.settings['MixUp'] or self.settings['CutMix']:
                loss = mix_criterion(self.criterion.train_loss, output, tgt_a=target, tgt_b=perm_target, gamma=gamma)
            else:
                loss = self.criterion.train_loss(**output, target=target)
            loss.backward()

            self.optimizer.step()
            train_loss += loss.item()

            if self.settings['lr_scheduler']:
                self.lr_scheduler(i, epoch, self.metric_to_watch)

            out = self.activation(output['pred'])
            self.evaluator.add_batch(out, target)
            tbar.set_description('Train loss: %.4f, Epoch: %d' % (train_loss / float(i + 1), epoch))

            self.logger.log_metric(metric_tuple=('TRAIN_LOSS', (train_loss / float(i + 1))))
        _ = self.evaluator.eval_metrics(reduction=self.settings['evaluator_reduction'], show=True)

    def validation(self, epoch: int):
        """
        Validation loop for a certain epoch
        :param epoch: epoch id
        :return:
        """
        self.evaluator.reset()
        self.model.eval()
        if self.settings['validation_only']:
            loader = self.loaders[self.settings['validation_only']]
        else:
            loader = self.val_loader
        tbar = tqdm(loader, desc='valid', file=sys.stdout)
        test_loss = 0.0
        with torch.no_grad():
            for i, sample in enumerate(tbar):
                img, target = self.prepare_inputs(sample['image'], sample['label'])

                output = self.model(img)

                loss = self.criterion.val_loss(pred=output, target=target)
                test_loss += loss.item()

                output = self.activation(output)
                self.evaluator.add_batch(output, target)
                tbar.set_description('Validation loss: %.3f, Epoch: %d' % (test_loss / (i + 1), epoch))

                if self.settings['log_artifacts']:
                    self.log_artifacts(epoch=epoch, sample=sample, output=output)

                self.logger.log_metric(metric_tuple=('VAL_LOSS', test_loss / (i + 1)))
        metrics_dict = self.evaluator.eval_metrics(reduction=self.settings['evaluator_reduction'], show=True)
        metrics_dict['val_loss'] = test_loss / (i + 1)
        self.metric_to_watch = metrics_dict[self.settings['metric_to_watch']].mean()
        if not self.settings['validation_only']:
            self.save_checkpoint(epoch=epoch, metrics_dict=metrics_dict)

    def save_checkpoint(self, epoch, metrics_dict):
        state = {
            'epoch': epoch + 1,
            'state_dict': self.model.module.state_dict() if self.cuda else self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'metrics': metrics_dict,
            'scheduler': self.lr_scheduler.state_dict() if self.settings['lr_scheduler'] else None,
        }
        self.logger.log_metrics(self.settings['metrics'], metrics_dict, epoch=epoch)
        self.logger.log_checkpoint(state, key_metric=self.metric_to_watch, filename=self.settings['check_suffix'])

    def log_artifacts(self, sample, output, epoch):
        last_epoch = epoch == (self.settings['epochs'] - 1)
        if epoch % self.settings['log_dilate'] == 0 or last_epoch:
            sample['image'] = denormalize_image(sample['image'], **self.settings['normalize_params'])
            image, target, output = tensors_to_numpy(sample['image'], sample['label'], output)
            for ind, value in enumerate(sample['id']):
                if value in self.settings['inputs_to_watch']:
                    fig = self.plotter(image[ind], output[ind], target[ind],
                                       alpha=0.4, threshold=self.threshold, show=self.settings['show_results'])
                    self.logger.log_artifact(artifact=fig, epoch=epoch, name=value.replace('_leftImg8bit', ''))
                    plt.close()

    def resume_checkpoint(self, resume):
        if not os.path.isfile(resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(resume))
        checkpoint = torch.load(resume)
        self.start_epoch = checkpoint['epoch']
        if self.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'], strict=True)
        else:
            self.model.load_state_dict(checkpoint['state_dict'], strict=True)
        if not self.settings['fine_tuning']:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            if checkpoint['scheduler']:
                self.lr_scheduler.load_state_dict(checkpoint['scheduler'])
        self.metric_to_watch = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch: {}, best_metric: {:.4f})"
              .format(resume, checkpoint['epoch'], self.metric_to_watch))

    def close(self):
        fig = plot_confusion_matrix(self.evaluator.confusion_matrix, normalize=True, title=None, cmap=plt.cm.Blues, show=False)
        self.logger.log_artifact(fig, epoch=-1, name='confusion_matrix.png')
        self.logger.close()
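training() above routes mixed targets through mix_criterion. A minimal sketch of the usual MixUp/CutMix criterion it implies, matching the call signature in the example (the project's implementation may differ):

def mix_criterion(loss_fn, output, tgt_a, tgt_b, gamma):
    # convex combination of the loss against the two target sets
    return (gamma * loss_fn(**output, target=tgt_a)
            + (1.0 - gamma) * loss_fn(**output, target=tgt_b))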
Example #13
    def __init__(self, args):
        self.args = args

        # define saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # define tensorboard summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.nclass, self.inchannel = make_data_loader(
            args, **kwargs)

        # define network
        model, parameters = build_baseline(args.baseline,
                                           n_channels=self.inchannel,
                                           n_classes=self.nclass,
                                           base_lr=args.lr)
        #stat(model, (1, 512, 512))
        #sys.exit()

        # define Optimizer
        optimizer = torch.optim.SGD(parameters,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # define Criterion
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weights_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        print(model)

        # define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # define lr scheduler
        self.scheduler = LRScheduler(args.lr_scheduler, args.lr, args.epochs,
                                     len(self.train_loader))

        # using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.cuda()

        # resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
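calculate_weights_labels above turns label frequencies into per-class weights when use_balanced_weights is set. A minimal sketch, assuming the 1 / log(1.02 + frequency) weighting that several segmentation codebases use; the project may choose another formula:

import numpy as np

def balanced_weights(label_counts):
    # label_counts: array of pixel counts per class
    freq = label_counts / label_counts.sum()
    return 1.0 / np.log(1.02 + freq)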