コード例 #1
0
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.visdom = args.visdom
        if args.visdom:
            self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            config)
        self.target_train_loader, self.target_val_loader, self.target_test_loader, _ = make_target_data_loader(
            config)

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)

        self.D = Discriminator(num_classes=self.nclass, ndf=16)

        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': config.lr
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': config.lr * config.lr_ratio
        }]

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(),
                                            lr=config.lr,
                                            betas=(0.9, 0.99))

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.entropy_mini_loss = MinimizeEntropyLoss()
        self.bottleneck_loss = BottleneckLoss()
        self.instance_loss = InstanceLoss()
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')
        # labels for adversarial training
        self.source_label = 0
        self.target_label = 1

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

            self.D = torch.nn.DataParallel(self.D)
            patch_replication_callback(self.D)
            self.D = self.D.cuda()

        self.best_pred_source = 0.0
        self.best_pred_target = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            if args.cuda:
                self.model.module.load_state_dict(checkpoint)
            else:
                self.model.load_state_dict(checkpoint,
                                           map_location=torch.device('cpu'))
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))
コード例 #2
0
    def __init__(self, config, args):
        self.args = args
        self.config = config
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(config)
        self.target_train_loader, self.target_val_loader, self.target_test_loader, _ = make_target_data_loader(config)
        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                        backbone=config.backbone,
                        output_stride=config.out_stride,
                        sync_bn=config.sync_bn,
                        freeze_bn=config.freeze_bn)


        train_params = [{'params': self.model.get_1x_lr_params(), 'lr': config.lr},
                        {'params': self.model.get_10x_lr_params(), 'lr': config.lr * config.lr_ratio}]

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params, momentum=config.momentum,
                                    weight_decay=config.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.consistency = ConsistencyLoss(cuda=args.cuda)
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler, config.lr,
                                      config.epochs, len(self.train_loader),
                                      config.lr_step, config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

        self.best_pred_source = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if args.cuda:
                self.model.module.load_state_dict(checkpoint)
            else:
                self.model.load_state_dict(checkpoint, map_location=torch.device('cpu'))
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, args.start_epoch))