示例#1
0
def main(args):
    """Train an ESPDNet multi-task (segmentation + classification) model on
    the greenhouse RGB-D dataset and log progress to TensorBoard.

    Args:
        args: parsed command-line namespace; reads (among others) crop_size,
            batch_size, savedir, dataset, data_path, scale, use_depth, model,
            finetune, freeze_bn, lr, lr_mult, momentum, weight_decay, resume,
            workers, loss_type, ignore_idx, scheduler, epochs and s.

    Side effects: creates ``args.savedir`` and writes TensorBoard logs,
    ``arguments.json`` and model checkpoints into it.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # ---- dataset ---------------------------------------------------------
    if args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegCls, GREENHOUSE_CLASS_LIST
        train_dataset = GreenhouseRGBDSegCls(
            root=args.data_path,
            list_name='train_greenhouse_mult.txt',
            train=True,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth)
        val_dataset = GreenhouseRGBDSegCls(root=args.data_path,
                                           list_name='val_greenhouse_mult.txt',
                                           train=False,
                                           size=crop_size,
                                           scale=args.scale,
                                           use_depth=args.use_depth)
        # NOTE(review): only the first 4 stored weights are kept, while
        # seg_classes below is len(GREENHOUSE_CLASS_LIST) -- confirm the
        # weight file matches the class list.
        class_weights = np.load('class_weights.npy')[:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)

        seg_classes = len(GREENHOUSE_CLASS_LIST)
        # RGB colors used for the in-training visualization of predictions.
        color_encoding = OrderedDict([('end_of_plant', (0, 255, 0)),
                                      ('other_part_of_plant', (0, 255, 255)),
                                      ('artificial_objects', (255, 0, 0)),
                                      ('ground', (255, 255, 0)),
                                      ('background', (0, 0, 0))])
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ---- model -----------------------------------------------------------
    if args.model == 'espdnet':
        from model.segmentation.espdnet_mult import espdnet_mult
        args.classes = seg_classes
        args.cls_classes = 5
        model = espdnet_mult(args)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    # Optionally initialize from a pretrained weight file.
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    # Freeze BN statistics and affine parameters when requested.
    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    # Separate parameter groups: base network at args.lr, segmentation head
    # at args.lr * args.lr_mult, and (optionally) the depth encoder at args.lr.
    if args.use_depth:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }, {
            'params': model.get_depth_encoder_params(),
            'lr': args.lr
        }]
    else:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]

    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except Exception:  # fixed: bare except also swallowed SystemExit etc.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # ---- optional resume -------------------------------------------------
    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    print('device : ' + device)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Class weights for the auxiliary 5-way classification head.
    cls_class_weight = calc_cls_class_weight(train_loader, 5)
    print(cls_class_weight)

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion_seg = SegmentationLoss(n_classes=seg_classes,
                                     loss_type=args.loss_type,
                                     device=device,
                                     ignore_idx=args.ignore_idx,
                                     class_wts=class_wts.to(device))

    criterion_cls = nn.CrossEntropyLoss(
        weight=torch.from_numpy(cls_class_weight).float().to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion_seg = criterion_seg.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion_seg = DataParallelCriteria(criterion_seg)
            criterion_seg = criterion_seg.cuda()

        criterion_cls = criterion_cls.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            # NOTE(review): benchmark=True and deterministic=True are
            # contradictory settings -- confirm which one is intended.
            cudnn.benchmark = True
            cudnn.deterministic = True

    # ---- learning-rate scheduler ----------------------------------------
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    # Persist the full run configuration next to the logs.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    # ---- training loop ---------------------------------------------------
    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg
        if args.use_depth:
            optimizer.param_groups[2]['lr'] = lr_base

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))
        miou_train, train_loss, train_seg_loss, train_cls_loss = train(
            model,
            train_loader,
            optimizer,
            criterion_seg,
            seg_classes,
            epoch,
            criterion_cls,
            device=device,
            use_depth=args.use_depth)
        miou_val, val_loss, val_seg_loss, val_cls_loss = val(
            model,
            val_loader,
            criterion_seg,
            criterion_cls,
            seg_classes,
            device=device,
            use_depth=args.use_depth)

        # Visualize predictions on one validation batch.
        # fixed: iter(...).next() is Python 2 only; use the builtin next().
        batch = next(iter(val_loader))
        if args.use_depth:
            in_training_visualization_2(model,
                                        images=batch[0].to(device=device),
                                        depths=batch[2].to(device=device),
                                        labels=batch[1].to(device=device),
                                        class_encoding=color_encoding,
                                        writer=writer,
                                        epoch=epoch,
                                        data='Segmentation',
                                        device=device)
        else:
            in_training_visualization_2(model,
                                        images=batch[0].to(device=device),
                                        labels=batch[1].to(device=device),
                                        class_encoding=color_encoding,
                                        writer=writer,
                                        epoch=epoch,
                                        data='Segmentation',
                                        device=device)


#            image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy()
#            writer.add_image('Segmentation/results/val', image_grid, epoch)

# remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        # DataParallel wraps the model, so unwrap before saving on GPU.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/SegLoss/train', train_seg_loss, epoch)
        writer.add_scalar('Segmentation/ClsLoss/train', train_cls_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/SegLoss/val', val_seg_loss, epoch)
        writer.add_scalar('Segmentation/ClsLoss/val', val_cls_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        # NOTE(review): value and global_step look swapped here -- appears
        # intentional (plots best_miou against complexity); confirm.
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
示例#2
0
def main(args):
    """Train a segmentation model (ESPNetv2 or DiCENet) on PASCAL VOC or
    Cityscapes and log progress to TensorBoard.

    Args:
        args: parsed command-line namespace; reads (among others) crop_size,
            batch_size, savedir, dataset, data_path, scale, coco_path,
            coarse, model, finetune, freeze_bn, lr, lr_mult, momentum,
            weight_decay, resume, workers, loss_type, ignore_idx,
            scheduler, epochs and s.

    Side effects: creates ``args.savedir`` and writes TensorBoard logs,
    ``arguments.json`` and model checkpoints into it.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    # ---- dataset ---------------------------------------------------------
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        # PASCAL uses uniform class weights.
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Hard-coded per-class weights for Cityscapes; the last class
        # (index 19) is weighted 0.0, i.e. excluded from the loss.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ---- model -----------------------------------------------------------
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    # Optionally initialize from a pretrained weight file.
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    # Freeze BN statistics and affine parameters when requested.
    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # Two parameter groups: base network at args.lr and the segmentation
    # head at args.lr * args.lr_mult.
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]

    # No top-level lr is passed; each param group above supplies its own.
    optimizer = optim.SGD(train_params,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    # NOTE(review): bare except also swallows SystemExit/KeyboardInterrupt;
    # consider narrowing to `except Exception:`.
    except:
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # ---- optional resume -------------------------------------------------
    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            # NOTE(review): benchmark=True and deterministic=True are
            # contradictory settings -- confirm which one is intended.
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # ---- learning-rate scheduler ----------------------------------------
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    # Persist the full run configuration next to the logs.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    # ---- training loop ---------------------------------------------------
    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))
        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        # DataParallel wraps the model, so unwrap before saving on GPU.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        # NOTE(review): value and global_step look swapped here -- appears
        # intentional (plots best_miou against complexity); confirm.
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
示例#3
0
class SSLOnlineEvaluator(Callback):  # pragma: no cover
    """Attaches a MLP for fine-tuning using the standard self-supervised protocol.

    Example::

        # your datamodule must have 2 attributes
        dm = DataModule()
        dm.num_classes = ... # the num of classes in the datamodule
        dm.name = ... # name of the datamodule (e.g. ImageNet, STL10, CIFAR10)

        # your model must have 1 attribute
        model = Model()
        model.z_dim = ... # the representation dim

        online_eval = SSLOnlineEvaluator(
            z_dim=model.z_dim
        )
    """
    def __init__(
        self,
        z_dim: int,
        drop_p: float = 0.2,
        hidden_dim: Optional[int] = None,
        num_classes: Optional[int] = None,
        dataset: Optional[str] = None,
    ):
        """
        Args:
            z_dim: Representation dimension
            drop_p: Dropout probability
            hidden_dim: Hidden dimension for the fine-tune MLP
            num_classes: Number of target classes; inferred from the
                datamodule in ``setup`` when not given
            dataset: Dataset name; inferred from the datamodule in
                ``setup`` when not given
        """
        super().__init__()

        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.drop_p = drop_p

        # Created lazily in on_pretrain_routine_start (needs pl_module.device).
        self.optimizer: Optional[Optimizer] = None
        self.online_evaluator: Optional[SSLEvaluator] = None
        # fix: these two were previously assigned twice (first None, then the
        # constructor argument); a single assignment is equivalent.
        self.num_classes: Optional[int] = num_classes
        self.dataset: Optional[str] = dataset

        # Checkpoint state captured by on_load_checkpoint, applied once the
        # evaluator/optimizer exist.
        self._recovered_callback_state: Optional[Dict[str, Any]] = None

    def setup(self,
              trainer: Trainer,
              pl_module: LightningModule,
              stage: Optional[str] = None) -> None:
        """Fill in num_classes/dataset from the datamodule when not set."""
        if self.num_classes is None:
            self.num_classes = trainer.datamodule.num_classes
        if self.dataset is None:
            self.dataset = trainer.datamodule.name

    def on_pretrain_routine_start(self, trainer: Trainer,
                                  pl_module: LightningModule) -> None:
        """Build the online evaluator MLP and its optimizer on the right device."""
        # must move to device after setup, as during setup, pl_module is still on cpu
        self.online_evaluator = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.num_classes,
            p=self.drop_p,
            n_hidden=self.hidden_dim,
        ).to(pl_module.device)

        # switch for PL compatibility reasons: the accelerator connector
        # attribute was renamed across Lightning versions
        accel = (trainer.accelerator_connector if hasattr(
            trainer, "accelerator_connector") else
                 trainer._accelerator_connector)
        if accel.is_distributed:
            if accel.use_ddp:
                from torch.nn.parallel import DistributedDataParallel as DDP

                self.online_evaluator = DDP(self.online_evaluator,
                                            device_ids=[pl_module.device])
            elif accel.use_dp:
                from torch.nn.parallel import DataParallel as DP

                self.online_evaluator = DP(self.online_evaluator,
                                           device_ids=[pl_module.device])
            else:
                rank_zero_warn(
                    "Does not support this type of distributed accelerator. The online evaluator will not sync."
                )

        self.optimizer = torch.optim.Adam(self.online_evaluator.parameters(),
                                          lr=1e-4)

        # Restore evaluator/optimizer state saved by on_save_checkpoint.
        if self._recovered_callback_state is not None:
            self.online_evaluator.load_state_dict(
                self._recovered_callback_state["state_dict"])
            self.optimizer.load_state_dict(
                self._recovered_callback_state["optimizer_state"])

    def to_device(self, batch: Sequence,
                  device: Union[str, torch.device]) -> Tuple[Tensor, Tensor]:
        """Extract the (input, label) pair from *batch* and move it to *device*."""
        # get the labeled batch
        if self.dataset == "stl10":
            labeled_batch = batch[1]
            batch = labeled_batch

        inputs, y = batch

        # last input is for online eval
        x = inputs[-1]
        x = x.to(device)
        y = y.to(device)

        return x, y

    def shared_step(
        self,
        pl_module: LightningModule,
        batch: Sequence,
    ):
        """Run the frozen encoder + online MLP; return (accuracy, loss)."""
        # Representations are computed without gradients so only the MLP trains.
        with torch.no_grad():
            with set_training(pl_module, False):
                x, y = self.to_device(batch, pl_module.device)
                representations = pl_module(x).flatten(start_dim=1)

        # forward pass
        mlp_logits = self.online_evaluator(
            representations)  # type: ignore[operator]
        mlp_loss = F.cross_entropy(mlp_logits, y)

        acc = accuracy(mlp_logits.softmax(-1), y)

        return acc, mlp_loss

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Train the online MLP on this batch and log step-level metrics."""
        train_acc, mlp_loss = self.shared_step(pl_module, batch)

        # update finetune weights
        mlp_loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        pl_module.log("online_train_acc",
                      train_acc,
                      on_step=True,
                      on_epoch=False)
        pl_module.log("online_train_loss",
                      mlp_loss,
                      on_step=True,
                      on_epoch=False)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Evaluate the online MLP on this batch and log epoch-level metrics."""
        val_acc, mlp_loss = self.shared_step(pl_module, batch)
        pl_module.log("online_val_acc",
                      val_acc,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        pl_module.log("online_val_loss",
                      mlp_loss,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)

    def on_save_checkpoint(self, trainer: Trainer, pl_module: LightningModule,
                           checkpoint: Dict[str, Any]) -> dict:
        """Persist evaluator and optimizer state into the trainer checkpoint."""
        return {
            "state_dict": self.online_evaluator.state_dict(),
            "optimizer_state": self.optimizer.state_dict()
        }

    def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule,
                           callback_state: Dict[str, Any]) -> None:
        """Stash restored state; applied in on_pretrain_routine_start."""
        self._recovered_callback_state = callback_state
示例#4
0
class BaseTrainer(object):
    """Skeleton trainer: wires model, data loaders, optimizer, scheduler and
    (multi-)GPU placement together and drives the epoch loop
    (train -> optional validation -> periodic checkpoint).

    Subclasses must implement ``train_epochs`` and ``val_epochs``.
    """

    def __init__(self,
                 epochs,
                 model,
                 train_dataloader,
                 train_loss_func,
                 train_metrics_func,
                 optimizer,
                 log_dir,
                 checkpoint_dir,
                 checkpoint_frequency,
                 checkpoint_restore=None,
                 val_dataloader=None,
                 val_metrics_func=None,
                 lr_scheduler=None,
                 lr_reduce_metric=None,
                 use_gpu=False,
                 gpu_ids=None):
        # train settings
        self.epochs = epochs
        self.model = model
        self.train_dataloader = train_dataloader
        self.train_loss_func = train_loss_func
        self.train_metrics_func = train_metrics_func
        self.optimizer = optimizer
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_frequency = checkpoint_frequency
        self.writer = SummaryWriter(logdir=log_dir)

        # validation settings
        self.validation = val_dataloader is not None
        if self.validation:
            self.val_dataloader = val_dataloader
            self.val_metrics_func = val_metrics_func

        # lr scheduler settings
        self.lr_schedule = lr_scheduler is not None
        if self.lr_schedule:
            self.lr_scheduler = lr_scheduler
            if isinstance(lr_scheduler,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                # ReduceLROnPlateau needs a metric value passed to step()
                self.lr_reduce_metric = lr_reduce_metric

        # multi-gpu settings
        self.use_gpu = use_gpu
        # BUGFIX: the original iterated over gpu_ids unconditionally and
        # raised TypeError when gpu_ids was left at its default of None.
        if gpu_ids is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
                str(gpu_id) for gpu_id in gpu_ids)

        if use_gpu and torch.cuda.device_count() > 0:
            self.device = torch.device('cuda')
            self.model.cuda()
            if gpu_ids is not None:
                self.multi_gpu = len(gpu_ids) > 1
                if self.multi_gpu:
                    self.model = DataParallel(model, gpu_ids)
            else:
                self.multi_gpu = torch.cuda.device_count() > 1
                if self.multi_gpu:
                    self.model = DataParallel(model)
        else:
            self.multi_gpu = False
            self.device = torch.device('cpu')
            self.model = self.model.cpu()

        # checkpoint settings
        if checkpoint_restore is not None:
            # Checkpoints are saved from the *unwrapped* module (see train()),
            # so restore into model.module when running under DataParallel;
            # loading into the wrapper would fail on the 'module.' key prefix.
            target = self.model.module if self.multi_gpu else self.model
            target.load_state_dict(torch.load(checkpoint_restore))

    def train(self):
        """Run the full loop: one training (and optional validation) pass per
        epoch, TensorBoard logging, LR scheduling and periodic checkpoints."""
        for epoch in range(1, self.epochs + 1):
            logging.info('*' * 80)
            logging.info('start epoch %d training loop' % epoch)
            # train
            self.model.train()
            loss, metrics = self.train_epochs(epoch)

            self.writer.add_scalar('train_loss', loss, epoch)
            for key, value in metrics.items():
                self.writer.add_scalar(key, value, epoch)
            if self.lr_schedule:
                if isinstance(self.lr_scheduler,
                              torch.optim.lr_scheduler.ReduceLROnPlateau):
                    # plateau scheduling is driven by the chosen loss metric
                    self.lr_scheduler.step(loss[self.lr_reduce_metric])
                else:
                    self.lr_scheduler.step()
            logging.info('train loss result: %s' % str(loss))
            logging.info('train metrics result: %s' % str(metrics))

            # validation
            if self.validation:
                logging.info('validation start ... ')
                self.model.eval()
                loss, metrics = self.val_epochs(epoch)
                self.writer.add_scalar('val_loss', loss, epoch)
                for key, value in metrics.items():
                    self.writer.add_scalar(key, value, epoch)
                logging.info('validation loss result: %s' % str(loss))
                logging.info('validation metrics result: %s' % str(metrics))

            # model checkpoint
            if epoch % self.checkpoint_frequency == 0:
                logging.info('saving model...')
                checkpoint_name = 'checkpoint_%d.pth' % epoch
                # always save the unwrapped weights so checkpoints can be
                # loaded with or without DataParallel
                state = (self.model.module.state_dict()
                         if self.multi_gpu else self.model.state_dict())
                torch.save(state,
                           os.path.join(self.checkpoint_dir, checkpoint_name))
                logging.info('model have saved for epoch_%d ' % epoch)
            else:
                logging.info('saving model skipped.')

    def train_epochs(self, epoch) -> (dict, dict):
        """Run one training epoch.

        :rtype: loss -> dict , metrics -> dict
        """
        pass

    def val_epochs(self, epoch) -> (dict, dict):
        """Run one validation epoch.

        :rtype: loss -> dict , metrics -> dict
        """
        pass
示例#5
0
        else:
            log_str = add_to_log("=> no checkpoint found at '{}'".format(args.load_from)) 
    else:
        start = 0
    
    for epoch in range(start, args.epochs):

        train(epoch, losstype=losstype)
        val_loss = validate(losstype=losstype).cpu()
        scheduler.step(val_loss)
		
        if vloss > val_loss:
            vloss = val_loss
            is_best = epoch+1
       
        for param_group in optimizer.param_groups:
            lr = param_group['lr']

        saver_fn({
            'epoch': epoch + 1,
            'state_dict': net.state_dict(),
            'best_val_loss': vloss,
            'optimizer' : optimizer.state_dict(),
			'learning_rate': lr
        }, is_best) 
        
else:
    test()


示例#6
0
                pic = (torch.cat([
                    a_real_test, b_fake_test, a_rec_test, b_real_test,
                    a_fake_test, b_rec_test
                ],
                                 dim=0).data + 1) / 2.0

                save_dir = './sample_images_while_training/espgan_m2d_lam5/'
                utils.mkdir(save_dir)
                torchvision.utils.save_image(
                    pic,
                    '%sEpoch_(%d)_(%dof%d).jpg' %
                    (save_dir, epoch, i + 1, min(len(a_loader),
                                                 len(b_loader))),
                    nrow=batch_size)
    utils.save_checkpoint(
        {
            'epoch': epoch + 1,
            'Da': Da.state_dict(),
            'Db': Db.state_dict(),
            'Ga': Ga.state_dict(),
            'Gb': Gb.state_dict(),
            'IDE': IDE.state_dict(),
            'da_optimizer': da_optimizer.state_dict(),
            'db_optimizer': db_optimizer.state_dict(),
            'ga_optimizer': ga_optimizer.state_dict(),
            'gb_optimizer': gb_optimizer.state_dict(),
            'IDE_optimizer': IDE_optimizer.state_dict()
        },
        '%sEpoch_(%d).ckpt' % (ckpt_dir, epoch + 1),
        max_keep=6)
示例#7
0
def main(train_path, val_path, labels_path, embedding_vectors_path,
         embedding_word2idx_path, categories_def_path, uncertainty_output_path,
         batch_size, model_snapshot_prefix, pretrained_model_path,
         model_snapshot_interval):
    """Train a TextCNN humor-type classifier.

    Loads pretrained word embeddings, builds the vocabulary from the training
    set, trains for NUM_EPOCHS, optionally snapshots the model, and finally
    ranks unlabeled training sentences by prediction uncertainty.
    """
    embedding_vectors = bcolz.open(embedding_vectors_path)[:]
    # BUGFIX: the file object was opened inline and never closed; use a
    # context manager so the handle is released deterministically.
    with open(embedding_word2idx_path, 'rb') as word2idx_file:
        embedding_word2idx = pickle.load(word2idx_file)
    # Maps words to embedding vectors. These are all embeddings available to us
    embeddings = {
        w: embedding_vectors[embedding_word2idx[w]]
        for w in embedding_word2idx
    }

    # Build vocabulary using training set. Maps words to indices
    vocab = create_vocabulary(train_path)
    vocab_size = len(vocab)
    print(f'Vocabulary size: {vocab_size}\nBatch size: {batch_size}')
    # TODO: take advantage of the multiple annotations
    labels = load_existing_annotations(labels_path,
                                       load_first_annotation_only=True)

    if model_snapshot_interval:
        print(f'Taking model snapshot every {model_snapshot_interval} epochs')
    else:
        print('Taking model snapshot ONLY at the end of training')

    humor_types = load_sentences_or_categories(categories_def_path)
    # Map label IDs to indices so that when computing cross entropy we don't operate on raw label IDs
    label_id_to_idx = {
        label_id: idx
        for idx, label_id in enumerate(humor_types)
    }
    word_weight_matrix = create_weight_matrix(vocab, embeddings, device)

    # Stores indexes of sentences provided in the original dataset
    train_labeled_idx, train_labeled_data_unpadded, train_labels, train_unlabeled_idx, train_unlabeled_data_unpadded,\
        longest_sentence_length = load_unpadded_train_val_data(train_path, vocab, labels, label_id_to_idx)
    val_labeled_idx, val_labeled_data_unpadded, val_labels, val_unlabeled_idx, val_unlabeled_data_unpadded,\
        _ = load_unpadded_train_val_data(val_path, vocab, labels, label_id_to_idx)

    # Create padded train and val dataset
    # TODO: Do not use longest length to pad input. Find mean and std
    train_labeled_data = create_padded_data(train_labeled_data_unpadded,
                                            longest_sentence_length)
    val_labeled_data = create_padded_data(val_labeled_data_unpadded,
                                          longest_sentence_length)

    print(
        f'Num of labeled training data: {train_labeled_data.shape[0]}, labeled val: {val_labeled_data.shape[0]}'
    )

    num_iterations = train_labeled_data.shape[0] // batch_size

    textCNN = DataParallel(
        TextCNN(word_weight_matrix, NUM_FILTERS, WINDOW_SIZES,
                len(humor_types))).to(device)
    if pretrained_model_path:
        # initialize the wrapped module directly (bypasses DataParallel)
        textCNN.module.initialize_from_pretrained(pretrained_model_path)
    optimizer = torch.optim.Adam(textCNN.parameters(), lr=LR, eps=OPTIM_EPS)

    for i in range(NUM_EPOCHS):
        print(f'Epoch {i}')
        train_one_epoch(
            textCNN,
            create_batch_iterable(train_labeled_data, train_labels, batch_size,
                                  device), optimizer, val_labeled_data,
            val_labels, num_iterations)
        # Snapshot either every `model_snapshot_interval` epochs, or only
        # after the final epoch when no interval is configured.
        if model_snapshot_prefix:
            if (not model_snapshot_interval and i + 1 == NUM_EPOCHS) or \
                    (model_snapshot_interval and (i + 1) % model_snapshot_interval == 0):
                print('\nSaving model snapshot...')
                torch.save(textCNN.state_dict(),
                           f'{model_snapshot_prefix}_epoch{i}.mdl')
                print('Saved\n')

    if uncertainty_output_path:
        # Rank the unlabeled portion of the training set by model uncertainty
        train_unlabeled_data = create_padded_data(
            train_unlabeled_data_unpadded, longest_sentence_length)
        rank_unlabeled_train(
            textCNN,
            torch.tensor(train_unlabeled_data, dtype=torch.long,
                         device=device), train_unlabeled_idx,
            uncertainty_output_path)
def main():

    global args, best_prec1
    args = parser.parse_args()

    # Read list of training and validation data
    listfiles_train, labels_train = read_lists(TRAIN_OUT)
    listfiles_val, labels_val = read_lists(VAL_OUT)
    listfiles_test, labels_test = read_lists(TEST_OUT)
    dataset_train = Dataset(listfiles_train,
                            labels_train,
                            subtract_mean=False,
                            V=12)
    dataset_val = Dataset(listfiles_val, labels_val, subtract_mean=False, V=12)
    dataset_test = Dataset(listfiles_test,
                           labels_test,
                           subtract_mean=False,
                           V=12)

    # shuffle data
    dataset_train.shuffle()
    dataset_val.shuffle()
    dataset_test.shuffle()
    tra_data_size, val_data_size, test_data_size = dataset_train.size(
    ), dataset_val.size(), dataset_test.size()
    print 'training size:', tra_data_size
    print 'validation size:', val_data_size
    print 'testing size:', test_data_size

    batch_size = args.b
    print("batch_size is :" + str(batch_size))
    learning_rate = args.lr
    print("learning_rate is :" + str(learning_rate))
    num_cuda = cuda.device_count()
    print("number of GPUs have been detected:" + str(num_cuda))

    # creat model
    print("model building...")
    mvcnn = DataParallel(modelnet40_Alex(num_cuda, batch_size))
    #mvcnn = modelnet40(num_cuda, batch_size, multi_gpu = False)
    mvcnn.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint'{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            mvcnn.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    #print(mvcnn)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adadelta(mvcnn.parameters(), weight_decay=1e-4)
    # evaluate performance only
    if args.evaluate:
        print 'testing mode ------------------'
        validate(dataset_test, mvcnn, criterion, optimizer, batch_size)
        return

    print 'training mode ------------------'
    for epoch in xrange(args.start_epoch, args.epochs):
        print('epoch:', epoch)

        #adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(dataset_train, mvcnn, criterion, optimizer, epoch, batch_size)

        # evaluate on validation set
        prec1 = validate(dataset_val, mvcnn, criterion, optimizer, batch_size)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': mvcnn.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, epoch)
        elif epoch % 5 is 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': mvcnn.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, epoch)
示例#9
0
class BaseEngine(object):
    """Abstract training engine.

    Subclasses provide dataset/model/optimizer/loss/metric construction via
    the ``_make_*`` hooks; this base class wires them together, handles
    (multi-)GPU placement and implements checkpoint dump/load.
    """

    def __init__(self, args):
        self._make_dataset(args)
        self._make_model(args)
        # NOTE(review): seeding happens *after* model construction, so weight
        # initialization itself is not covered by the seed — confirm intended.
        tc.manual_seed(args.seed)
        if args.cuda and tc.cuda.is_available():
            tc.cuda.manual_seed_all(args.seed)
            if tc.cuda.device_count() > 1:
                # scale the effective batch size with the number of GPUs
                self.batch_size = args.batch_size * tc.cuda.device_count()
                # BUGFIX: DataParallel expects the wrapped module's parameters
                # on a CUDA device; the original never moved the model in the
                # multi-GPU branch (the single-GPU branch did call .cuda()).
                self.model = DataParallel(self.model.cuda())
            else:
                self.batch_size = args.batch_size
                self.model = self.model.cuda()
        else:
            self.batch_size = args.batch_size
        self._make_optimizer(args)
        self._make_loss(args)
        self._make_metric(args)
        self.num_training_samples = args.num_training_samples
        self.tag = args.tag or 'default'
        self.dump_dir = get_dir(args.dump_dir)
        self.train_logger = get_logger('train.{}.{}'.format(
            self.__class__.__name__, self.tag))

    def _make_dataset(self, args):
        # hook: build self.dataset / loaders
        raise NotImplementedError

    def _make_model(self, args):
        # hook: build self.model
        raise NotImplementedError

    def _make_optimizer(self, args):
        # hook: build self.optimizer (and optionally self.decayer)
        raise NotImplementedError

    def _make_loss(self, args):
        # hook: build the loss function
        raise NotImplementedError

    def _make_metric(self, args):
        # hook: build the evaluation metric
        raise NotImplementedError

    def dump(self, epoch, model=True, optimizer=True, decayer=True):
        """Serialize the selected training state to ``state_<tag>.pkl``
        inside the dump directory."""
        state = {'epoch': epoch}
        if model:
            state['model'] = self.model.state_dict()
        if optimizer:
            state['optimizer'] = self.optimizer.state_dict()
        if decayer and (getattr(self, 'decayer', None) is not None):
            state['decayer'] = self.decayer.state_dict()
        tc.save(state,
                os.path.join(self.dump_dir, 'state_{}.pkl'.format(self.tag)))
        self.train_logger.info('Checkpoint {} dumped'.format(self.tag))

    def load(self, model=True, optimizer=True, decayer=True):
        """Restore state from ``state_<tag>.pkl`` if present.

        Returns the stored epoch, or 0 when no checkpoint file exists.
        """
        try:
            state = tc.load(
                os.path.join(self.dump_dir, 'state_{}.pkl'.format(self.tag)))
        except FileNotFoundError:
            return 0
        if model and (state.get('model') is not None):
            self.model.load_state_dict(state['model'])
        if optimizer and (state.get('optimizer') is not None):
            self.optimizer.load_state_dict(state['optimizer'])
        if decayer and (state.get('decayer') is not None) and (getattr(
                self, 'decayer', None) is not None):
            self.decayer.load_state_dict(state['decayer'])
        return state['epoch']

    def eval(self):
        # hook: run evaluation
        raise NotImplementedError

    def test(self):
        # hook: run testing
        raise NotImplementedError

    def train(self, num_epochs, resume=False):
        # hook: run the training loop
        raise NotImplementedError
示例#10
0
class Im2latex(BaseAgent):
    """Image-to-LaTeX agent.

    Builds the dataset/loaders, an Im2LatexModel wrapped in DataParallel,
    Adam + MultiStepLR, and implements the train / validate / predict loops
    with TensorBoard logging and checkpointing.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.device = get_device()
        cfg.device = self.device
        self.cfg = cfg

        # dataset
        train_dataset = Im2LatexDataset(cfg, mode="train")
        self.id2token = train_dataset.id2token
        self.token2id = train_dataset.token2id

        collate = custom_collate(self.token2id, cfg.max_len)

        self.train_loader = DataLoader(train_dataset,
                                       batch_size=cfg.bs,
                                       shuffle=cfg.data_shuffle,
                                       num_workers=cfg.num_w,
                                       collate_fn=collate,
                                       drop_last=True)
        if cfg.valid_img_path != "":
            # validation shares the vocabulary built from the training set
            valid_dataset = Im2LatexDataset(cfg,
                                            mode="valid",
                                            vocab={
                                                'id2token': self.id2token,
                                                'token2id': self.token2id
                                            })
            # beam search multiplies memory per sample, so shrink the batch
            self.valid_loader = DataLoader(valid_dataset,
                                           batch_size=cfg.bs //
                                           cfg.beam_search_k,
                                           shuffle=cfg.data_shuffle,
                                           num_workers=cfg.num_w,
                                           collate_fn=collate,
                                           drop_last=True)

        # define models
        self.model = Im2LatexModel(cfg)  # fill the parameters
        # weight initialization setting
        for name, param in self.model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    torch.nn.init.constant_(param, 0.0)
                elif 'weight' in name:
                    torch.nn.init.kaiming_normal_(param)
            except Exception as e:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

        self.model = DataParallel(self.model)
        # define criterion
        self.criterion = cal_loss

        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=cfg.lr,
                                          betas=(cfg.adam_beta_1,
                                                 cfg.adam_beta_2))

        milestones = cfg.milestones
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                              milestones,
                                                              gamma=cfg.gamma,
                                                              verbose=True)

        # initialize counter
        self.current_epoch = 1
        self.current_iteration = 1
        # best (lowest) average training perplexity observed so far
        self.best_metric = 100
        self.best_info = ''

        # set the manual seed for torch
        torch.cuda.manual_seed_all(self.cfg.seed)
        if self.cfg.cuda:
            self.model = self.model.to(self.device)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
        else:
            self.logger.info("Program will run on *****CPU*****\n")

        # Model Loading from cfg if not found start from scratch.
        self.exp_dir = os.path.join('./experiments', cfg.exp_name)
        self.load_checkpoint(cfg.checkpoint_filename)
        # Summary Writer
        self.summary_writer = SummaryWriter(
            log_dir=os.path.join(self.exp_dir, 'summaries'))

    def load_checkpoint(self, file_name):
        """
        Latest checkpoint loader
        :param file_name: name of the checkpoint file
        :return:
        """
        try:
            self.logger.info("Loading checkpoint '{}'".format(file_name))
            checkpoint = torch.load(file_name, map_location=self.device)

            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            # strict=False tolerates architecture changes between runs
            self.model.load_state_dict(checkpoint['model'], strict=False)
            self.optimizer.load_state_dict(checkpoint['optimizer'])

            info = "Checkpoint loaded successfully from "
            self.logger.info(
                info + "'{}' at (epoch {}) at (iteration {})\n".format(
                    file_name, checkpoint['epoch'], checkpoint['iteration']))

        except OSError as e:
            # missing file means a fresh run, not an error
            self.logger.info("Checkpoint not found in '{}'.".format(file_name))
            self.logger.info("**First time to train**")

    def save_checkpoint(self, file_name="checkpoint.pth", is_best=False):
        """
        Checkpoint saver
        :param file_name: name of the checkpoint file
        :param is_best: boolean flag to indicate whether current
                        checkpoint's accuracy is the best so far
        :return:
        """
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'model': self.model.state_dict(),
            'vocab': self.id2token,
            'optimizer': self.optimizer.state_dict()
        }

        # save the state
        checkpoint_dir = os.path.join(self.exp_dir, 'checkpoints')
        if is_best:
            torch.save(state, os.path.join(checkpoint_dir, 'best.pt'))
            self.best_info = 'best: e{}_i{}'.format(self.current_epoch,
                                                    self.current_iteration)
        else:
            file_name = "e{}-i{}.pt".format(self.current_epoch,
                                            self.current_iteration)
            torch.save(state, os.path.join(checkpoint_dir, file_name))

    def run(self):
        """
        The main operator
        :return:
        """
        try:
            if self.cfg.mode == 'train':
                self.train()
            elif self.cfg.mode == 'predict':
                self.predict()

        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        """
        Main training loop
        :return:
        """
        # (removed unused prev_perplexity/this_perplexity locals)
        for e in range(self.current_epoch, self.cfg.epochs + 1):
            self.train_one_epoch()
            self.save_checkpoint()
            self.scheduler.step()
            self.current_epoch += 1

        if self.cfg.valid_img_path:
            self.validate()

    def train_one_epoch(self):
        """
        One epoch of training
        :return:
        """
        tqdm_bar = tqdm(enumerate(self.train_loader, 1),
                        total=len(self.train_loader))

        self.model.train()
        last_avg_perplexity, avg_perplexity = 0, 0
        for i, (imgs, tgt) in tqdm_bar:
            imgs = imgs.float().to(self.device)
            tgt = tgt.long().to(self.device)

            # [B, MAXLEN, VOCABSIZE]
            logits = self.model(imgs, tgt, is_train=True)

            loss = self.criterion(logits, tgt)
            avg_perplexity += loss.item()

            # BUGFIX: gradients were never cleared, so they accumulated
            # across iterations; reset them before each backward pass.
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.cfg.grad_clip)
            self.optimizer.step()
            self.current_iteration += 1

            # logging
            if i % self.cfg.log_freq == 0:
                avg_perplexity = avg_perplexity / self.cfg.log_freq
                self.summary_writer.add_scalar(
                    'perplexity/train',
                    avg_perplexity,
                    global_step=self.current_iteration)
                # get_last_lr() returns a list (one entry per param group);
                # with a single group add_scalar accepts it as a scalar
                self.summary_writer.add_scalar(
                    'lr',
                    self.scheduler.get_last_lr(),
                    global_step=self.current_iteration)
                tqdm_bar.set_description("e{} | avg_perplexity: {:.3f}".format(
                    self.current_epoch, avg_perplexity))

                # save if best
                if avg_perplexity < self.best_metric:
                    self.save_checkpoint(is_best=True)

                    self.best_metric = avg_perplexity
                last_avg_perplexity = avg_perplexity
                avg_perplexity = 0

        # log one decoded example from the last batch (token id 2 is
        # presumably padding/EOS — verify against the tokenizer)
        mask = (tgt[0] != 2)
        pred = str(logits[0].argmax(1)[mask].cpu().detach().tolist())
        gt = str(tgt[0][mask].cpu().tolist())
        self.summary_writer.add_text('example/train',
                                     pred + '  \n' + gt,
                                     global_step=self.current_iteration)

        return last_avg_perplexity

    def validate(self):
        """
        One cycle of model validation
        :return:
        """
        tqdm_bar = tqdm(enumerate(self.valid_loader, 1),
                        total=len(self.valid_loader))
        self.model.eval()
        acc = 0
        with torch.no_grad():
            for i, (imgs, tgt) in tqdm_bar:
                imgs = imgs.to(self.device).float()
                tgt = tgt.to(self.device).long()

                logits = self.model(
                    imgs, is_train=False).long()  # [B, MAXLEN, VOCABSIZE]
                # exact-match accuracy over the full sequence
                acc += torch.all(tgt == logits, dim=1).sum() / imgs.size(0)
                tqdm_bar.set_description('acc {:.4f}'.format(acc / i))
                if i % self.cfg.log_freq == 0:
                    self.summary_writer.add_scalar(
                        'accuracy/valid',
                        acc.item() / i,
                        global_step=self.current_iteration)

    def predict(self):
        """
        get predict results
        :return:
        """
        from torchvision import transforms
        from pathlib import Path
        from PIL import Image
        from time import time

        self.model.eval()
        transform = transforms.ToTensor()
        image_path = Path(self.cfg.test_img_path)

        t = time()
        with torch.no_grad():
            images = []
            imgPath = list(image_path.glob('*.jpg')) + list(
                image_path.glob('*.png'))
            for i, img in enumerate(imgPath):
                print(i, ':', img)
                img = Image.open(img)
                img = transform(img)
                images.append(img)
            # NOTE(review): images are not moved to self.device here —
            # confirm inference is intended to run on CPU tensors.
            images = torch.stack(images, dim=0)
            out = self.model(images)  # [B, max_len, vocab_size]

        for i, output in enumerate(out):
            # token id 1 is skipped when detokenizing (presumably padding)
            print(
                i, ' '.join([
                    self.id2token[out.item()] for out in output
                    if out.item() != 1
                ]))

        print(time() - t)

    def finalize(self):
        """
        Finalizes all the operations of the 2 Main classes of the process,
        the operator and the data loader
        :return:
        """
        print(self.best_info)
        pass
def main():
    """Entry point: parse the config, build model/data, and train a triplet
    metric-learning model, checkpointing once per epoch."""
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    viewer = Visualizer(cfg.OUTPUT_DIR)
    #Model
    model = build_model(cfg)
    model = DataParallel(model).cuda()
    if cfg.MODEL.WEIGHT !="":
        model.module.backbone.load_state_dict(torch.load(cfg.MODEL.WEIGHT))
        #freeze backbone
        # for key,val in model.module.backbone.named_parameters():
        #     val.requires_grad = False

    batch_time = AverageMeter()
    data_time = AverageMeter()
    #optimizer
    optimizer = getattr(torch.optim,cfg.SOLVER.OPTIM)(model.parameters(),lr = cfg.SOLVER.BASE_LR,weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    lr_sche = torch.optim.lr_scheduler.MultiStepLR(optimizer,cfg.SOLVER.STEPS,gamma= cfg.SOLVER.GAMMA)
    #dataset
    datasets  = make_dataset(cfg)
    dataloaders = make_dataloaders(cfg,datasets,True)
    # number of epochs needed to cover MAX_ITER iterations
    iter_epoch = (cfg.SOLVER.MAX_ITER)//len(dataloaders[0])+1
    if not os.path.exists(cfg.OUTPUT_DIR):
        os.mkdir(cfg.OUTPUT_DIR)
    ite = 0
    # (removed unused local batch_it)

    # start time
    model.train()
    start = time.time()
    for epoch in tqdm.tqdm(range(iter_epoch),desc="epoch"):
        for dataloader in dataloaders:
            for imgs,labels,types in tqdm.tqdm(dataloader,desc="dataloader:"):
                data_time.update(time.time() - start)

                # the three triplet views are concatenated into one batch
                inputs = torch.cat([imgs[0].cuda(),imgs[1].cuda(),imgs[2].cuda()],dim=0)
                features = model(inputs)
                acc,loss = loss_opts.batch_triple_loss(features,labels,types,size_average=True)
                optimizer.zero_grad()
                loss.backward()

                optimizer.step()
                # BUGFIX: the scheduler was stepped *before* optimizer.step(),
                # which skips the initial LR value and triggers a PyTorch
                # warning since 1.1; step it after the optimizer instead.
                lr_sche.step()
                ite+=1
                # viewer.line("train/loss",loss.item()*100,ite)
                print(acc,loss)
                batch_time.update(time.time() - start)
                start = time.time()

                print('Epoch: [{0}][{1}/{2}]\n'
                      'Time: {data_time.avg:.4f} ({batch_time.avg:.4f})\n'.format(
                    epoch,ite, len(dataloader),
                    data_time=data_time, batch_time=batch_time),
                    flush=True)

        torch.save(model.state_dict(),os.path.join(cfg.OUTPUT_DIR,"{}_{}.pth".format(cfg.MODEL.META_ARCHITECTURE,epoch)))
示例#12
0
def main(args):
    """Two-stage segmentation training entry point.

    Stage 1 (skipped when ``args.finetune`` is set): train the selected
    model (espnetv2 / espdnet / espdnetue) on the original 13-class
    labels from ``import_dataset(label_conversion=False)``.

    Stage 2: rebuild the model for the label-converted task
    (``import_dataset(label_conversion=True)``), transfer every
    shape-compatible weight from stage 1, then finetune for 50 epochs
    with the base learning rate divided by 100.

    Losses, mIoU, learning rates and sample visualizations are written
    to TensorBoard under ``args.savedir``; the best checkpoint (by
    validation mIoU) is kept via ``save_checkpoint``.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    print('device : ' + device)

    # Get a summary writer for tensorboard
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')

    #
    # Training the model with 13 classes of CamVid dataset
    # TODO: This process should be done only if specified
    #
    if not args.finetune:
        train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
            label_conversion=False)  # 13 classes
        args.use_depth = False  # 'use_depth' is always false for camvid

        print_info_message('Training samples: {}'.format(len(train_dataset)))
        print_info_message('Validation samples: {}'.format(len(val_dataset)))

        # Import model
        if args.model == 'espnetv2':
            from model.segmentation.espnetv2 import espnetv2_seg
            args.classes = seg_classes
            model = espnetv2_seg(args)
        elif args.model == 'espdnet':
            from model.segmentation.espdnet import espdnet_seg
            args.classes = seg_classes
            print("Trainable fusion : {}".format(args.trainable_fusion))
            print("Segmentation classes : {}".format(seg_classes))
            model = espdnet_seg(args)
        elif args.model == 'espdnetue':
            from model.segmentation.espdnet_ue import espdnetue_seg2
            args.classes = seg_classes
            print("Trainable fusion : {}".format(args.trainable_fusion))
            # Fix: the original built this string but never printed it.
            print("Segmentation classes : {}".format(seg_classes))
            print(args.weights)
            model = espdnetue_seg2(args, False, fix_pyr_plane_proj=True)
        else:
            print_error_message('Arch: {} not yet supported'.format(
                args.model))
            exit(-1)

        # Freeze batch normalization layers?
        if args.freeze_bn:
            freeze_bn_layer(model)

        # Set learning rates: the segmentation head runs at lr * lr_mult,
        # the base network at lr.
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]

        # Define an optimizer
        optimizer = optim.SGD(train_params,
                              lr=args.lr * args.lr_mult,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

        # Compute the FLOPs and the number of parameters, and display it
        num_params, flops = show_network_stats(model, crop_size)

        try:
            writer.add_graph(model,
                             input_to_model=torch.Tensor(
                                 1, 3, crop_size[0], crop_size[1]))
        except Exception:  # narrowed from a bare `except:`
            print_log_message(
                "Not able to generate the graph. Likely because your model is not supported by ONNX"
            )

        #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
        criterion = SegmentationLoss(n_classes=seg_classes,
                                     loss_type=args.loss_type,
                                     device=device,
                                     ignore_idx=args.ignore_idx,
                                     class_wts=class_wts.to(device))
        nid_loss = NIDLoss(image_bin=32,
                           label_bin=seg_classes) if args.use_nid else None

        if num_gpus >= 1:
            if num_gpus == 1:
                # for a single GPU, we do not need DataParallel wrapper for Criteria.
                # So, falling back to its internal wrapper
                from torch.nn.parallel import DataParallel
                model = DataParallel(model)
                model = model.cuda()
                criterion = criterion.cuda()
                if args.use_nid:
                    nid_loss.cuda()
            else:
                from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
                model = DataParallelModel(model)
                model = model.cuda()
                criterion = DataParallelCriteria(criterion)
                criterion = criterion.cuda()
                if args.use_nid:
                    nid_loss = DataParallelCriteria(nid_loss)
                    nid_loss = nid_loss.cuda()

            if torch.backends.cudnn.is_available():
                import torch.backends.cudnn as cudnn
                cudnn.benchmark = True
                cudnn.deterministic = True

        # Get data loaders for training and validation data
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=20,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers)

        # Get a learning rate scheduler
        lr_scheduler = get_lr_scheduler(args.scheduler)

        write_stats_to_json(num_params, flops)

        extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
        #
        # Main training loop of 13 classes
        #
        start_epoch = 0
        best_miou = 0.0
        for epoch in range(start_epoch, args.epochs):
            lr_base = lr_scheduler.step(epoch)
            # set the optimizer with the learning rate
            # This can be done inside the MyLRScheduler
            lr_seg = lr_base * args.lr_mult
            optimizer.param_groups[0]['lr'] = lr_base
            optimizer.param_groups[1]['lr'] = lr_seg

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))

            # Use different training functions for espdnetue
            if args.model == 'espdnetue':
                from utilities.train_eval_seg import train_seg_ue as train
                from utilities.train_eval_seg import val_seg_ue as val
            else:
                from utilities.train_eval_seg import train_seg as train
                from utilities.train_eval_seg import val_seg as val

            miou_train, train_loss = train(model,
                                           train_loader,
                                           optimizer,
                                           criterion,
                                           seg_classes,
                                           epoch,
                                           device=device,
                                           use_depth=args.use_depth,
                                           add_criterion=nid_loss)
            miou_val, val_loss = val(model,
                                     val_loader,
                                     criterion,
                                     seg_classes,
                                     device=device,
                                     use_depth=args.use_depth,
                                     add_criterion=nid_loss)

            # Fix: `iter(...).next()` is Python-2 only; use the built-in next().
            batch_train = next(iter(train_loader))
            batch = next(iter(val_loader))
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)

            # remember best miou and save checkpoint
            is_best = miou_val > best_miou
            best_miou = max(miou_val, best_miou)

            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_miou': best_miou,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
            writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
            writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
            writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
            writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
            writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
            # NOTE(review): value/step arguments look swapped here (best_miou is
            # logged as the scalar at step=flops/params) — kept as in the original.
            writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                              math.ceil(flops))
            writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                              math.ceil(num_params))

        # Save the pretrained weights
        model_dict = copy.deepcopy(model.state_dict())
        del model
        torch.cuda.empty_cache()

    #
    # Finetuning with 4 classes
    #
    args.ignore_idx = 4
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=True)  # 5 classes

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    #set_parameters_for_finetuning()

    # Import model
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        print(args.weights)
        model = espdnetue_seg2(args, args.finetune, fix_pyr_plane_proj=True)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if not args.finetune:
        # Transfer every stage-1 weight whose (DataParallel-stripped) key
        # exists in the new model with an identical tensor shape.
        new_model_dict = model.state_dict()
        overlap_dict = {
            k.replace('module.', ''): v
            for k, v in model_dict.items()
            if k.replace('module.', '') in new_model_dict
            and new_model_dict[k.replace('module.', '')].size() == v.size()
        }
        # Fix: the original compared new_model_dict against itself, so this
        # diagnostic dict was always empty. Report the parameters of the new
        # model that did NOT receive a pretrained weight.
        no_overlap_dict = {
            k: v
            for k, v in new_model_dict.items() if k not in overlap_dict
        }
        print(no_overlap_dict.keys())

        new_model_dict.update(overlap_dict)
        model.load_state_dict(new_model_dict)

    # Sanity forward pass on a fixed-size dummy input.
    output = model(torch.ones(1, 3, 288, 480))
    print(output[0].size())

    print(seg_classes)
    print(class_wts.size())
    #print(model_dict.keys())
    #print(new_model_dict.keys())
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    # Set learning rates: finetuning runs at 1/100 of the stage-1 rate.
    args.lr /= 100
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]
    # Define an optimizer
    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Get data loaders for training and validation data
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Get a learning rate scheduler
    args.epochs = 50
    lr_scheduler = get_lr_scheduler(args.scheduler)

    # Compute the FLOPs and the number of parameters, and display it
    num_params, flops = show_network_stats(model, crop_size)
    write_stats_to_json(num_params, flops)

    extra_info_ckpt = '{}_{}_{}_{}'.format(args.model, seg_classes, args.s,
                                           crop_size[0])
    #
    # Main finetuning loop
    #
    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # Use different training functions for espdnetue
        if args.model == 'espdnetue':
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device,
                                       use_depth=args.use_depth,
                                       add_criterion=nid_loss)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device,
                                 use_depth=args.use_depth,
                                 add_criterion=nid_loss)

        # Fix: `iter(...).next()` is Python-2 only; use the built-in next().
        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        in_training_visualization_img(model,
                                      images=batch_train[0].to(device=device),
                                      labels=batch_train[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/train',
                                      device=device)
        in_training_visualization_img(model,
                                      images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/val',
                                      device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('SegmentationConv/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('SegmentationConv/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('SegmentationConv/Loss/train', train_loss, epoch)
        writer.add_scalar('SegmentationConv/Loss/val', val_loss, epoch)
        writer.add_scalar('SegmentationConv/mIOU/train', miou_train, epoch)
        writer.add_scalar('SegmentationConv/mIOU/val', miou_val, epoch)
        writer.add_scalar('SegmentationConv/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('SegmentationConv/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
示例#13
0
class GANLoss(nn.Module):
    """Adversarial loss that owns (and optionally updates) a discriminator.

    ``forward`` scores a (fake, real) image pair with the internal
    discriminator and returns a detached copy of the adversarial loss.
    When ``is_train`` is set, the same loss is back-propagated through
    the discriminator and its Adam optimizer is stepped; the
    discriminator is re-frozen afterwards and snapshotted every
    ``save_D_every`` steps.
    """

    def __init__(self, args):
        super(GANLoss, self).__init__()
        # By default, discriminator will be updated every step

        if args.discriminator == 'discriminator_vgg_128':
            self.netD = Discriminator_VGG_128(in_nc=3, nf=64)
        else:
            raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(args.discriminator))

        if args.pretrained_netD is not None:
            self.netD.load_state_dict(torch.load(args.pretrained_netD))
        if not args.cpu:
            self.netD = self.netD.to('cuda')
        if args.n_GPUs > 1:
            self.netD = DataParallel(self.netD)

        # Periodic discriminator snapshots.
        self.save_D_every = args.save_D_every
        if args.save_D_path == '...':
            self.save_D_path = "../experiments/{}/model/".format(args.name)
        else:
            self.save_D_path = args.save_D_path

        # Loss Type
        self.gan_type = args.gan_type
        if self.gan_type == 'gan' or self.gan_type == 'ragan':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan-gp':
            def wgan_loss(input, target):
                # target is boolean
                return -1 * input.mean() if target else input.mean()
            self.loss = wgan_loss
        else:
            raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))

        # Optimizer
        self.optimizer_D = torch.optim.Adam(
            params=self.netD.parameters(),
            lr=args.lr_D,
            betas=(args.beta1_D, args.beta2_D),
            weight_decay=args.weight_decay_D
        )

    def get_target_label(self, input, target_is_real):
        """Return a label tensor shaped like *input*: 1.0 = real, 0.0 = fake."""
        if target_is_real:
            return torch.empty_like(input).fill_(1.0)  # real_label_val
        else:
            return torch.empty_like(input).fill_(0.0)  # fake_label_val

    def update_D(self, loss, step):
        # Update the discriminator from the given loss, then freeze it
        # again so that subsequent generator updates do not touch it.
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        loss.backward(retain_graph=True)
        self.optimizer_D.step()

        if step % self.save_D_every == 0:
            torch.save(self.netD.state_dict(), "{}/{}.pth".format(self.save_D_path, step))
            # Fix: message typo ("Discriminatoro").
            print("Discriminator saved.")

        for p in self.netD.parameters():
            p.requires_grad = False

    def forward(self, fake, real, step=0, is_train=False):
        # Score the real/fake pair with the discriminator and compute the
        # adversarial loss for the configured GAN flavour.
        self.netD.train() if is_train else self.netD.eval()
        # print("REAL SIZE {}".format(real.size())) # torch.Size([32, 3, 112, 96])
        pred_d_real = self.netD(real)
        pred_d_fake = self.netD(fake)

        # NOTE(review): the branches below label the *fake* predictions as
        # real and vice versa (generator-style targets), yet the same loss is
        # used to update D in update_D() — confirm this inversion is intended.
        if self.gan_type == 'gan' or self.gan_type == 'lsgan':
            # Fix: 'lsgan' previously fell through unhandled, leaving `loss`
            # unbound (NameError); MSELoss takes the same (pred, target) form.
            target_real = self.get_target_label(pred_d_fake, True)
            target_fake = self.get_target_label(pred_d_real, False)
            loss = (self.loss(pred_d_fake, target_real) + self.loss(pred_d_real, target_fake)) / 2
        elif self.gan_type == 'ragan':
            target_real = self.get_target_label(pred_d_fake, True)
            target_fake = self.get_target_label(pred_d_real, False)
            loss = (
                self.loss(pred_d_fake - torch.mean(pred_d_real), target_real) +
                self.loss(pred_d_real - torch.mean(pred_d_fake), target_fake)
            ) / 2
        else:  # 'wgan-gp' — __init__ rejects anything else
            # Fix: previously unhandled here although accepted by __init__.
            # wgan_loss takes a boolean target; keep the same fake-as-real
            # labeling convention as the branches above.
            loss = (self.loss(pred_d_fake, True) + self.loss(pred_d_real, False)) / 2
        loss_copy = torch.tensor(loss.item())
        if is_train:
            self.update_D(loss, step)
        return loss_copy
示例#14
0
def main(args):
    logdir = args.savedir + '/logs/'
    if not os.path.isdir(logdir):
        os.makedirs(logdir)

    my_logger = Logger(60066, logdir)

    if args.dataset == 'pascal':
        crop_size = (512, 512)
        args.scale = (0.5, 2.0)
    elif args.dataset == 'city':
        crop_size = (768, 768)
        args.scale = (0.5, 2.0)

    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[1], crop_size[0], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    if args.model == 'espnetv2':
        from model.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espnet':
        from model.espnet import espnet_seg
        args.classes = seg_classes
        model = espnet_seg(args)
    elif args.model == 'mobilenetv2_1_0':
        from model.mobilenetv2 import get_mobilenet_v2_1_0_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_1_0_seg(args)
    elif args.model == 'mobilenetv2_0_35':
        from model.mobilenetv2 import get_mobilenet_v2_0_35_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_35_seg(args)
    elif args.model == 'mobilenetv2_0_5':
        from model.mobilenetv2 import get_mobilenet_v2_0_5_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_5_seg(args)
    elif args.model == 'mobilenetv3_small':
        from model.mobilenetv3 import get_mobilenet_v3_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_small_seg(args)
    elif args.model == 'mobilenetv3_large':
        from model.mobilenetv3 import get_mobilenet_v3_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_large_seg(args)
    elif args.model == 'mobilenetv3_RE_small':
        from model.mobilenetv3 import get_mobilenet_v3_RE_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_small_seg(args)
    elif args.model == 'mobilenetv3_RE_large':
        from model.mobilenetv3 import get_mobilenet_v3_RE_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_large_seg(args)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    train_params = []
    params_dict = dict(model.named_parameters())
    others = args.weight_decay * 0.01
    for key, value in params_dict.items():
        if len(value.data.shape) == 4:
            if value.data.shape[1] == 1:
                train_params += [{
                    'params': [value],
                    'lr': args.lr,
                    'weight_decay': 0.0
                }]
            else:
                train_params += [{
                    'params': [value],
                    'lr': args.lr,
                    'weight_decay': args.weight_decay
                }]
        else:
            train_params += [{
                'params': [value],
                'lr': args.lr,
                'weight_decay': others
            }]

    args.learning_rate = args.lr
    optimizer = get_optimizer(args.optimizer, train_params, args)
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[1], crop_size[0]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[1], crop_size[0], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    start_epoch = 0
    epochs_len = args.epochs
    best_miou = 0.0

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers,
                                               drop_last=True)
    if args.dataset == 'city':
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers,
                                                 drop_last=True)
    else:
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers,
                                                 drop_last=True)

    lr_scheduler = get_lr_scheduler(args)

    print_info_message(lr_scheduler)

    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

    if args.fp_epochs > 0:
        print_info_message("========== MODEL FP WARMUP ===========")

        for epoch in range(args.fp_epochs):
            lr = lr_scheduler.step(epoch)

            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            print_info_message(
                'Running epoch {} with learning rates: {:.6f}'.format(
                    epoch, lr))
            start_t = time.time()
            miou_train, train_loss = train(model,
                                           train_loader,
                                           optimizer,
                                           criterion,
                                           seg_classes,
                                           epoch,
                                           device=device)
    if args.optimizer.startswith('Q'):
        optimizer.is_warmup = False
        print('exp_sensitivity calibration fin.')

    if not args.fp_train:
        model.module.quantized.fuse_model()
        model.module.quantized.qconfig = torch.quantization.get_default_qat_qconfig(
            'qnnpack')
        torch.quantization.prepare_qat(model.module.quantized, inplace=True)

    if args.resume:
        start_epoch = args.start_epoch
        if os.path.isfile(args.resume):
            print_info_message('Loading weights from {}'.format(args.resume))
            weight_dict = torch.load(args.resume, device)
            model.module.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for resume. Please check.')

    for epoch in range(start_epoch, args.epochs):
        lr = lr_scheduler.step(epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        print_info_message(
            'Running epoch {} with learning rates: {:.6f}'.format(epoch, lr))
        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)
        if is_best:
            model_file_name = args.savedir + '/model_' + str(epoch +
                                                             1) + '.pth'
            torch.save(weights_dict, model_file_name)
            print('weights saved in {}'.format(model_file_name))
        info = {
            'Segmentation/LR': round(lr, 6),
            'Segmentation/Loss/train': train_loss,
            'Segmentation/Loss/val': val_loss,
            'Segmentation/mIOU/train': miou_train,
            'Segmentation/mIOU/val': miou_val,
            'Segmentation/Complexity/Flops': best_miou,
            'Segmentation/Complexity/Params': best_miou,
        }

        for tag, value in info.items():
            if tag == 'Segmentation/Complexity/Flops':
                my_logger.scalar_summary(tag, value, math.ceil(flops))
            elif tag == 'Segmentation/Complexity/Params':
                my_logger.scalar_summary(tag, value, math.ceil(num_params))
            else:
                my_logger.scalar_summary(tag, value, epoch + 1)

    print_info_message("========== TRAINING FINISHED ===========")
示例#15
0
class VideoBaseModel(BaseModel):
    """Video restoration model with a single generator network ``netG``.

    Wraps network construction (optionally in DistributedDataParallel or
    DataParallel), pixel-loss selection, optimizer / LR-scheduler setup,
    and the standard feed / optimize / test / save / load plumbing.
    """

    def __init__(self, opt):
        """Build ``netG``, the pixel criterion, optimizer and schedulers.

        :param opt: full option dict; ``opt['train']`` holds training
            hyper-parameters, ``opt['dist']`` toggles distributed mode.
        """
        super(VideoBaseModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        # self.print_network()
        self.load()

        self.log_dict = OrderedDict()

        #### loss
        loss_type = train_opt['pixel_criterion']
        if loss_type == 'l1':
            self.cri_pix = nn.L1Loss(reduction='mean').to(
                self.device)  # Change from sum to mean
        elif loss_type == 'l2':
            self.cri_pix = nn.MSELoss(reduction='mean').to(
                self.device)  # Change from sum to mean
        elif loss_type == 'cb':
            self.cri_pix = CharbonnierLoss().to(self.device)
        elif loss_type == 'huber':
            self.cri_pix = HuberLoss().to(self.device)
        else:
            raise NotImplementedError(
                'Loss type [{:s}] is not recognized.'.format(loss_type))
        self.l_pix_w = train_opt['pixel_weight']

        if self.is_train:
            self.netG.train()

            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            if train_opt['ft_tsa_only']:
                # Fine-tune only the TSA fusion module: keep it in its own
                # param group so set_params_lr_zero() can freeze the rest
                # (group 0) during the warm-up steps.
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'tsa_fusion' in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                'Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': tsa_fusion_params,
                        'lr': train_opt['lr_G']
                    },
                ]
            # BUGFIX: this branch was a plain `if`, which detached the
            # `elif`/`else` below from the `ft_tsa_only` case above and let
            # the trailing `else` overwrite the param groups built for TSA
            # fine-tuning. `elif` ensures exactly one branch builds
            # `optim_params`.
            elif opt['train']['freeze_front']:
                # Freeze the front blocks by training them with lr = 0.
                normal_params = []
                freeze_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'module.conv3d_1' in k or 'module.dense_block_1' in k or 'module.dense_block_2' in k or 'module.dense_block_3' in k:
                            freeze_params.append(v)
                        else:
                            normal_params.append(v)
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': freeze_params,
                        'lr': 0
                    },
                ]
            elif train_opt['small_offset_lr']:
                # Offset/feature-extraction layers train at a 10x smaller lr.
                normal_params = []
                conv_offset_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if 'pcd_align' in k or 'fea_L' in k or 'feature_extraction' in k or 'conv_first' in k:
                            conv_offset_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                'Params [{:s}] will not optimize.'.format(k))
                optim_params = [
                    {  # add normal params first
                        'params': normal_params,
                        'lr': train_opt['lr_G']
                    },
                    {
                        'params': conv_offset_params,
                        'lr': train_opt['lr_G'] * 0.1
                    },
                ]
            else:
                # Default: a single group with every trainable parameter.
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                'Params [{:s}] will not optimize.'.format(k))

            if train_opt['optim'] == 'SGD':
                self.optimizer_G = torch.optim.SGD(optim_params,
                                                   lr=train_opt['lr_G'],
                                                   weight_decay=wd_G)
            else:
                self.optimizer_G = torch.optim.Adam(optim_params,
                                                    lr=train_opt['lr_G'],
                                                    weight_decay=wd_G,
                                                    betas=(train_opt['beta1'],
                                                           train_opt['beta2']))

            self.optimizers.append(self.optimizer_G)

            #### schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError()

    def feed_data(self, data, need_GT=True):
        """Move a batch onto the model device (LQ always, GT when asked)."""
        self.var_L = data['LQs'].to(self.device)
        if need_GT:
            self.real_H = data['GT'].to(self.device)

    def set_params_lr_zero(self):
        """Zero the lr of param group 0 ("normal" params).

        Used while fine-tuning only the TSA fusion module, which lives in
        param group 1 (see the `ft_tsa_only` branch of ``__init__``).
        """
        # fix normal module
        self.optimizers[0].param_groups[0]['lr'] = 0

    def optimize_parameters(self, step):
        """One training step: forward, pixel loss, backward, update."""
        if self.opt['train'][
                'ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def optimize_by_loss(self, loss):
        """Back-propagate an externally computed loss and step the optimizer."""
        if self.opt['train']['ft_tsa_only']:
            self.set_params_lr_zero()

        self.optimizer_G.zero_grad()
        loss.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = loss.item()

    def calculate_loss(self):
        """Forward pass and return the weighted pixel loss (also logged)."""
        self.fake_H = self.netG(self.var_L)
        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
        self.log_dict['l_pix'] = l_pix.item()
        return l_pix

    def test(self):
        """Run inference on the current LQ batch without gradients."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        """Return the dict of most recent scalar logs."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return CPU float tensors of the first sample: LQ, result, opt. GT."""
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log netG's structure string and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)

    def load(self, verbose=True):
        """Load pretrained generator weights if a path is configured."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            if verbose:
                logger.info(
                    'Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])

    def load_for_test(self):
        """Load the latest saved generator checkpoint for testing."""
        load_path_G = os.path.join(self.opt['path']['models'], 'latest_G.pth')
        # NOTE(review): os.path.join never returns None, so the else branch
        # is unreachable; an os.path.isfile() check may be intended here.
        if load_path_G is not None:
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        else:
            print('No models are saved!')

    def save(self, iter_label):
        """Save netG under the given iteration label."""
        self.save_network(self.netG, 'G', iter_label)

    def save_for_test(self):
        """Save a CPU copy of the (unwrapped) netG state dict as latest_G.pth."""
        save_path = os.path.join(self.opt['path']['models'], 'latest_G.pth')
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            network = self.netG.module
            state_dict = network.state_dict()
        else:
            state_dict = self.netG.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path)
示例#16
0
def main(args):
    """Train a depth-to-RGB ESPNetv2 autoencoder and log progress.

    Builds the dataset/loaders from ``args.dataset``, trains with SGD and
    the configured LR scheduler, writes TensorBoard summaries, and saves a
    checkpoint (flagging the best validation loss) every 5th epoch.

    :param args: parsed argparse namespace (crop_size, batch_size, savedir,
        dataset, lr, scheduler options, etc.).
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # ---- dataset selection -------------------------------------------------
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Per-class weights (median-frequency style); class 19 is ignored.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0

    elif args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation, GreenhouseDepth, GREENHOUSE_CLASS_LIST
        train_dataset = GreenhouseDepth(root=args.data_path,
                                        list_name='train_depth_ae.txt',
                                        train=True,
                                        size=crop_size,
                                        scale=args.scale,
                                        use_filter=True)
        val_dataset = GreenhouseRGBDSegmentation(root=args.data_path,
                                                 list_name='val_depth_ae.txt',
                                                 train=False,
                                                 size=crop_size,
                                                 scale=args.scale,
                                                 use_depth=True)
        class_weights = np.load('class_weights.npy')[:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)

        seg_classes = len(GREENHOUSE_CLASS_LIST)
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ---- model and optimizer ----------------------------------------------
    from model.autoencoder.depth_autoencoder import espnetv2_autoenc
    args.classes = 3  # autoencoder reconstructs 3-channel RGB
    model = espnetv2_autoenc(args)

    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr * args.lr_mult
    }]

    optimizer = optim.SGD(train_params,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 1, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate; graph export is best-effort only.
    except Exception:
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    start_epoch = 0

    print('device : ' + device)

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    #criterion = SegmentationLoss(n_classes=seg_classes, loss_type=args.loss_type,
    #                             device=device, ignore_idx=args.ignore_idx,
    #                             class_wts=class_wts.to(device))
    criterion = nn.MSELoss()
    # criterion = nn.L1Loss()

    # ---- (multi-)GPU wrapping ---------------------------------------------
    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # ---- learning-rate scheduler ------------------------------------------
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    # BUGFIX: best_loss was initialized to 0.0, so with a non-negative MSE
    # loss `is_best` could never become True and no best checkpoint was
    # ever flagged. Lower-is-better, so start at +inf.
    best_loss = float('inf')
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_seg
        # optimizer.param_groups[1]['lr'] = lr_seg

        # Train
        model.train()
        losses = AverageMeter()
        for i, batch in enumerate(train_loader):
            inputs = batch[1].to(device=device)  # Depth
            target = batch[0].to(device=device)  # RGB

            outputs = model(inputs)

            if device == 'cuda':
                loss = criterion(outputs, target).mean()
                # Multi-GPU wrapper returns per-device outputs; gather them
                # onto one device for logging below.
                if isinstance(outputs, (list, tuple)):
                    target_dev = outputs[0].device
                    outputs = gather(outputs, target_device=target_dev)
            else:
                loss = criterion(outputs, target)

            losses.update(loss.item(), inputs.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #             if not (i % 10):
            #                 print("Step {}, write images".format(i))
            #                 image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy()
            #                 writer.add_image('Autoencoder/results/train', image_grid, len(train_loader) * epoch + i)

            writer.add_scalar('Autoencoder/Loss/train', loss.item(),
                              len(train_loader) * epoch + i)

            print_info_message('Running batch {}/{} of epoch {}'.format(
                i + 1, len(train_loader), epoch + 1))

        train_loss = losses.avg

        writer.add_scalar('Autoencoder/LR/seg', round(lr_seg, 6), epoch)

        # Val (every 5th epoch).
        # NOTE(review): the model stays in train() mode here, and depth is
        # read from batch[2] (vs batch[1] in training) — presumably because
        # the val dataset class returns (rgb, label, depth). Confirm against
        # GreenhouseRGBDSegmentation.
        if epoch % 5 == 0:
            losses = AverageMeter()
            with torch.no_grad():
                for i, batch in enumerate(val_loader):
                    inputs = batch[2].to(device=device)  # Depth
                    target = batch[0].to(device=device)  # RGB

                    outputs = model(inputs)

                    if device == 'cuda':
                        loss = criterion(outputs, target)  # .mean()
                        if isinstance(outputs, (list, tuple)):
                            target_dev = outputs[0].device
                            outputs = gather(outputs, target_device=target_dev)
                    else:
                        loss = criterion(outputs, target)

                    losses.update(loss.item(), inputs.size(0))

                    image_grid = torchvision.utils.make_grid(
                        outputs.data.cpu()).numpy()
                    writer.add_image('Autoencoder/results/val', image_grid,
                                     epoch)
                    image_grid = torchvision.utils.make_grid(
                        inputs.data.cpu()).numpy()
                    writer.add_image('Autoencoder/inputs/val', image_grid,
                                     epoch)
                    image_grid = torchvision.utils.make_grid(
                        target.data.cpu()).numpy()
                    writer.add_image('Autoencoder/target/val', image_grid,
                                     epoch)

            val_loss = losses.avg

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))

            # remember best loss and save checkpoint (lower is better)
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)

            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Autoencoder/Loss/val', val_loss, epoch)

    writer.close()
示例#17
0
    )

    if torch.cuda.device_count() > 1:
        model = DP(model)
        # model = DDP(model)

    model.to(device=device)
    model.__DEBUG__ = False

    try:
        train(
            model=model,
            model_config=model_config,
            config=train_config,
            device=device,
            logger=logger,
            debug=train_config.debug,
        )
    except KeyboardInterrupt:
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "model_config": model_config,
                "train_config": config,
            }, os.path.join(config.checkpoints, "INTERRUPTED.pth.tar"))
        logger.info("Saved interrupt")
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
示例#18
0
def create_and_test_triplet_network(batch_triplet_indices_loader,
                                    experiment_name,
                                    path_to_emb_net,
                                    unseen_triplets,
                                    dataset_name,
                                    model_name,
                                    logger,
                                    test_n,
                                    n,
                                    dim,
                                    layers,
                                    learning_rate=5e-2,
                                    epochs=20,
                                    hl_size=100):
    """
    Construct the OENN network, define an optimizer and train the network on
    the data w.r.t. triplet loss, then report the unseen-triplet error.

    A pre-trained, frozen "train" network embeds the positive/negative points;
    a fresh "test" network embeds the anchors and is the only model optimized.

    :param batch_triplet_indices_loader: DataLoader yielding binary-encoded
        triplet batches of shape (batch, 3 * digits).
    :param experiment_name: basename for the checkpoint file.
    :param path_to_emb_net: path to the pre-trained network's checkpoint.
    :param unseen_triplets: held-out triplet index array used for evaluation.
    :param dataset_name: dataset identifier (logging only).
    :param model_name: architecture identifier forwarded to define_model.
    :param logger: logger for progress reporting.
    :param test_n: number of test points.
    :param n: number of training points.
    :param dim: embedding dimension (# features).
    :param layers: number of hidden layers.
    :param learning_rate: learning rate of the optimizer.
    :param epochs: number of training epochs.
    :param hl_size: width of the hidden layers.
    :return: unseen triplet error (float).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    # Bits needed to binary-encode point indices 0..n-1.
    digits = int(math.ceil(math.log2(n)))

    #  Define train model (frozen feature extractor).
    emb_net_train = define_model(model_name=model_name,
                                 digits=digits,
                                 hl_size=hl_size,
                                 dim=dim,
                                 layers=layers)
    emb_net_train = emb_net_train.to(device)

    for param in emb_net_train.parameters():
        param.requires_grad = False

    # Load the checkpoint BEFORE any DataParallel wrapping: a DataParallel
    # module expects every state-dict key to carry a 'module.' prefix, so
    # loading a stripped state dict into the wrapper would fail.
    checkpoint = torch.load(path_to_emb_net)['model_state_dict']
    key_word = list(checkpoint.keys())[0].split('.')[0]
    if key_word == 'module':
        # Checkpoint was saved from a DataParallel model; strip the prefix.
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in checkpoint.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        emb_net_train.load_state_dict(new_state_dict)
    else:
        emb_net_train.load_state_dict(checkpoint)

    if torch.cuda.device_count() > 1:
        emb_net_train = DataParallel(emb_net_train)
        print('multi-gpu')

    emb_net_train.eval()

    #  Define test model (the one actually optimized below).
    emb_net_test = define_model(model_name=model_name,
                                digits=digits,
                                hl_size=hl_size,
                                dim=dim,
                                layers=layers)
    emb_net_test = emb_net_test.to(device)

    if torch.cuda.device_count() > 1:
        emb_net_test = DataParallel(emb_net_test)
        print('multi-gpu')

    # Optimizer
    optimizer = torch.optim.Adam(emb_net_test.parameters(), lr=learning_rate)
    criterion = nn.TripletMarginLoss(margin=1, p=2)
    criterion = criterion.to(device)

    logger.info('#### Dataset Selection #### \n')
    # Note: logging's info(msg, *args) %-formats msg with args; the original
    # logger.info('dataset:', dataset_name) had no placeholder and raised a
    # formatting error.  Concatenate instead, matching the style below.
    logger.info('dataset: ' + dataset_name)
    logger.info('#### Network and learning parameters #### \n')
    logger.info('------------------------------------------ \n')
    logger.info('Model Name: ' + model_name + '\n')
    logger.info('Number of hidden layers: ' + str(layers) + '\n')
    logger.info('Hidden layer width: ' + str(hl_size) + '\n')
    logger.info('Embedding dimension: ' + str(dim) + '\n')
    logger.info('Learning rate: ' + str(learning_rate) + '\n')
    logger.info('Number of epochs: ' + str(epochs) + '\n')

    logger.info(' #### Training begins #### \n')
    logger.info('---------------------------\n')

    # Binary codes for the n training points; reused for the final embedding.
    bin_array = data_utils.get_binary_array(n, digits)

    # Training begins
    train_time = 0
    for ep in range(epochs):
        # Epoch is one pass over the dataset
        epoch_loss = 0

        for batch_ind, trips in enumerate(batch_triplet_indices_loader):
            sys.stdout.flush()
            trip = trips.squeeze().to(device).float()

            # Training time
            begin_train_time = time.time()
            # Forward pass: anchor through the trainable test net, positive
            # and negative through the frozen train net.
            embedded_a = emb_net_test(trip[:, :digits])
            embedded_p = emb_net_train(trip[:, digits:2 * digits])
            embedded_n = emb_net_train(trip[:, 2 * digits:])
            # Compute loss
            loss = criterion(embedded_a, embedded_p, embedded_n).to(device)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # End of training
            end_train_time = time.time()
            if batch_ind % 50 == 0:
                logger.info('Epoch: ' + str(ep) + ' Mini batch: ' +
                            str(batch_ind) + '/' +
                            str(len(batch_triplet_indices_loader)) +
                            ' Loss: ' + str(loss.item()))
                sys.stdout.flush()  # Prints faster to the out file
            epoch_loss += loss.item()
            train_time = train_time + end_train_time - begin_train_time

        # Log
        logger.info('Epoch ' + str(ep) + ' - Average Epoch Loss:  ' +
                    str(epoch_loss / len(batch_triplet_indices_loader)) +
                    ' Training time ' + str(train_time))
        sys.stdout.flush()  # Prints faster to the out file

        # Saving the results
        logger.info('Saving the models and the results')
        sys.stdout.flush()  # Prints faster to the out file

        os.makedirs('test_checkpoints', mode=0o777, exist_ok=True)
        model_path = 'test_checkpoints/' + \
                     experiment_name + \
                     '.pt'
        torch.save(
            {
                'epochs': ep,
                'model_state_dict': emb_net_test.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,  # key was 'loss:' (typo) in older runs
            }, model_path)

    # Compute the embedding of the data points.  Use .to(device) rather than
    # .cuda() so evaluation also works on CPU-only machines, and skip autograd
    # bookkeeping since no gradients are needed here.
    bin_array_test = data_utils.get_binary_array(test_n, digits)
    with torch.no_grad():
        test_embeddings = emb_net_test(
            torch.Tensor(bin_array_test).to(device).float()).cpu().numpy()
        train_embeddings = emb_net_train(
            torch.Tensor(bin_array).to(device).float()).cpu().numpy()
    unseen_triplet_error, _ = data_utils.triplet_error_unseen(
        test_embeddings, train_embeddings, unseen_triplets)

    logger.info('Unseen triplet error is ' + str(unseen_triplet_error))
    return unseen_triplet_error
示例#19
0
class TrainValProcess():
    """Train/validate an ET_Net segmentation model per the global ARGS config.

    Builds the network (optionally loading weights and moving to GPU), the
    train/val datasets, an Adam optimizer with a polynomial LR schedule, and
    a TensorBoard writer.
    """

    def __init__(self):
        self.net = ET_Net()
        if (ARGS['weight']):
            # Resume from an explicit checkpoint when one is configured.
            self.net.load_state_dict(torch.load(ARGS['weight']))
        else:
            # Otherwise initialize only the encoder from pre-trained weights.
            self.net.load_encoder_weight()
        if (ARGS['gpu']):
            self.net = DataParallel(module=self.net.cuda())

        self.train_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                         part='train')
        self.val_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                       part='val')

        self.optimizer = Adam(self.net.parameters(), lr=ARGS['lr'])
        # Use / to get an approximate result, // to get an accurate result
        total_iters = len(
            self.train_dataset) // ARGS['batch_size'] * ARGS['num_epochs']
        # Polynomial decay from the base LR down to 0 over total_iters.
        self.lr_scheduler = LambdaLR(
            self.optimizer,
            lr_lambda=lambda iter:
            (1 - iter / total_iters)**ARGS['scheduler_power'])
        self.writer = SummaryWriter()

    def train(self, epoch):
        """Run one training epoch.

        :param epoch: 1-based epoch number (used for logging/TensorBoard).
        """
        start = time.time()
        self.net.train()
        # NOTE(review): shuffle=False is unusual for a training loader —
        # confirm this is intentional before changing it.
        train_dataloader = DataLoader(self.train_dataset,
                                      batch_size=ARGS['batch_size'],
                                      shuffle=False)
        epoch_loss = 0.
        for batch_index, items in enumerate(train_dataloader):
            images, labels, edges = items['image'], items['label'], items[
                'edge']
            images = images.float()
            labels = labels.long()
            edges = edges.long()

            if ARGS['gpu']:
                labels = labels.cuda()
                images = images.cuda()
                edges = edges.cuda()

            self.optimizer.zero_grad()
            outputs_edge, outputs = self.net(images)
            # Edge and segmentation heads each get a Lovasz-Softmax loss,
            # blended by combine_alpha.
            loss_edge = lovasz_softmax(outputs_edge,
                                       edges)  # Lovasz-Softmax loss
            loss_seg = lovasz_softmax(outputs, labels)  #
            loss = ARGS['combine_alpha'] * loss_seg + (
                1 - ARGS['combine_alpha']) * loss_edge
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            # Global iteration index; requires epoch to be 1-based.
            n_iter = (epoch - 1) * len(train_dataloader) + batch_index + 1

            # Hard predictions; bitwise ops act as logical and/or on 0/1 masks.
            pred = torch.max(outputs, dim=1)[1]
            iou = torch.sum(pred & labels) / (torch.sum(pred | labels) + 1e-6)

            print(
                'Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tL_edge: {:0.4f}\tL_seg: {:0.4f}\tL_all: {:0.4f}\tIoU: {:0.4f}\tLR: {:0.4f}'
                .format(loss_edge.item(),
                        loss_seg.item(),
                        loss.item(),
                        iou.item(),
                        self.optimizer.param_groups[0]['lr'],
                        epoch=epoch,
                        trained_samples=batch_index * ARGS['batch_size'],
                        total_samples=len(train_dataloader.dataset)))

            epoch_loss += loss.item()

            # update training loss for each iteration
            # self.writer.add_scalar('Train/loss', loss.item(), n_iter)

        # Histogram every parameter, grouped as "<layer>/<attr>" (e.g.
        # "conv1/weight") by splitting the dotted parameter name.
        for name, param in self.net.named_parameters():
            layer, attr = os.path.splitext(name)
            attr = attr[1:]
            self.writer.add_histogram("{}/{}".format(layer, attr), param,
                                      epoch)

        epoch_loss /= len(train_dataloader)
        self.writer.add_scalar('Train/loss', epoch_loss, epoch)
        finish = time.time()

        print('epoch {} training time consumed: {:.2f}s'.format(
            epoch, finish - start))

    def validate(self, epoch):
        """Run one validation epoch (no gradient updates).

        :param epoch: 1-based epoch number (used for logging/TensorBoard).
        """
        start = time.time()
        self.net.eval()
        # Clamp the batch size so a tiny validation set still forms one batch.
        val_batch_size = min(ARGS['batch_size'], len(self.val_dataset))
        val_dataloader = DataLoader(self.val_dataset,
                                    batch_size=val_batch_size)
        epoch_loss = 0.
        for batch_index, items in enumerate(val_dataloader):
            images, labels, edges = items['image'], items['label'], items[
                'edge']

            if ARGS['gpu']:
                labels = labels.cuda()
                images = images.cuda()
                edges = edges.cuda()

            print('image shape:', images.size())

            with torch.no_grad():
                outputs_edge, outputs = self.net(images)
                loss_edge = lovasz_softmax(outputs_edge,
                                           edges)  # Lovasz-Softmax loss
                loss_seg = lovasz_softmax(outputs, labels)  #
                loss = ARGS['combine_alpha'] * loss_seg + (
                    1 - ARGS['combine_alpha']) * loss_edge

            pred = torch.max(outputs, dim=1)[1]
            iou = torch.sum(pred & labels) / (torch.sum(pred | labels) + 1e-6)

            print(
                'Validating Epoch: {epoch} [{val_samples}/{total_samples}]\tLoss: {:0.4f}\tIoU: {:0.4f}'
                .format(loss.item(),
                        iou.item(),
                        epoch=epoch,
                        val_samples=batch_index * val_batch_size,
                        total_samples=len(val_dataloader.dataset)))

            # Accumulate a Python float, not a tensor (was `loss`, which kept
            # a tensor alive and logged a tensor to TensorBoard).
            epoch_loss += loss.item()

        epoch_loss /= len(val_dataloader)
        self.writer.add_scalar('Val/loss', epoch_loss, epoch)

        finish = time.time()

        print('epoch {} validating time consumed: {:.2f}s'.format(
            epoch, finish - start))

    def train_val(self):
        """Alternate training and validation for ARGS['num_epochs'] epochs,
        checkpointing every ARGS['epoch_save'] epochs.
        """
        print('Begin training and validating:')
        # Epochs are 1-based: train() derives its global iteration counter
        # from (epoch - 1), which went negative with the old 0-based loop.
        # Saved filenames (epoch_1.pth, ...) are unchanged.
        for epoch in range(1, ARGS['num_epochs'] + 1):
            self.train(epoch)
            self.validate(epoch)
            print(f'Finish training and validating epoch #{epoch}')
            if epoch % ARGS['epoch_save'] == 0:
                os.makedirs(ARGS['weight_save_folder'], exist_ok=True)
                torch.save(
                    self.net.state_dict(),
                    os.path.join(ARGS['weight_save_folder'],
                                 f'epoch_{epoch}.pth'))
                print(f'Model saved for epoch #{epoch}.')
        print('Finish training and validating.')