Пример #1
0
class trainNew(object):
    """Retrain a network built from a saved Auto-DeepLab search result.

    The cell genotype (``genotype.npy``) and the network-level path
    (``network_path_space.npy``) are read from ``args.saved_arch_path``.

    When CUDA is enabled and either several GPUs are visible or
    ``args.load_parallel`` is set, the encoder is wrapped in
    ``torch.nn.DataParallel``. Checkpoints always store the state dict of the
    unwrapped network.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Load the searched architecture (cell genotype + network-level path).
        cell_path = os.path.join(args.saved_arch_path, 'genotype.npy')
        network_path_space = os.path.join(args.saved_arch_path,
                                          'network_path_space.npy')

        new_cell_arch = np.load(cell_path)
        new_network_arch = np.load(network_path_space)

        # Define network
        model = newModel(network_arch=new_network_arch,
                         cell_arch=new_cell_arch,
                         num_classes=self.nclass,
                         num_layers=12)
        # NOTE(review): only ``model`` is moved to the GPU, optimized,
        # switched between train()/eval() and checkpointed; the decoder is
        # none of these. Confirm whether Decoder handles this itself.
        self.decoder = Decoder(self.nclass, 'autodeeplab', args, False)
        # TODO: consider separate 1x / 10x learning-rate parameter groups as
        # done for DeepLab (model.get_1x_lr_params / get_10x_lr_params).
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion, optionally with class-balanced weights that are
        # cached next to the dataset.
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler.
        # TODO: use min_lr? Figure out whether len(self.train_loader) should be
        # divided by two here (and in the other module as well).
        self.scheduler = LR_Scheduler(
            args.lr_scheduler, args.lr, args.epochs,
            len(self.train_loader))

        # Using cuda
        if args.cuda:
            if torch.cuda.device_count() > 1 or args.load_parallel:
                self.model = torch.nn.DataParallel(self.model.cuda())
                patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            self._resume(args)

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def _bare_model(self):
        """Return the network, unwrapped from DataParallel if it is wrapped."""
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` without DataParallel's 'module.' key prefix.

        Keys that do not carry the prefix are kept unchanged.
        """
        prefix = 'module.'
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items())

    @staticmethod
    def _visualize_every(num_iters):
        """Iteration interval for the ~10 image summaries per epoch (never 0)."""
        return max(num_iters // 10, 1)

    def _checkpoint_state(self, epoch):
        """Build the dict handed to ``Saver.save_checkpoint`` after ``epoch``."""
        return {
            'epoch': epoch + 1,
            'state_dict': self._bare_model().state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }

    def _resume(self, args):
        """Restore model/optimizer state from the checkpoint at ``args.resume``.

        Sets ``args.start_epoch`` and ``self.best_pred`` from the checkpoint.
        With ``args.clean_module``, the 'module.' prefix that DataParallel
        adds to parameter names is stripped before loading. The optimizer
        state is restored only when not fine-tuning (``args.ft``).

        Raises:
            RuntimeError: if ``args.resume`` is not an existing file.
        """
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']

        state_dict = checkpoint['state_dict']
        if args.clean_module:
            state_dict = self._strip_module_prefix(state_dict)
        # Load into the bare network so the keys line up whether or not the
        # model is currently wrapped in DataParallel.
        self._bare_model().load_state_dict(state_dict)

        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    def training(self, epoch):
        """Run one training epoch and log per-iteration and per-epoch loss.

        When ``args.no_val`` is set, a (non-best) checkpoint is saved at the
        end of every epoch.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_every = self._visualize_every(num_img_tr)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            encoder_output, low_level_feature = self.model(image)
            output = self.decoder(encoder_output, low_level_feature)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % vis_every == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(self._checkpoint_state(epoch), is_best)

    def validation(self, epoch):
        """Evaluate on the validation set and save a checkpoint on a new best mIoU."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                encoder_output, low_level_feature = self.model(image)
                output = self.decoder(encoder_output, low_level_feature)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Per-pixel class = argmax over the class axis (axis 1).
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(self._checkpoint_state(epoch), is_best)
Пример #2
0
class Trainer(object):
    """Train DeepLab on generator-based loaders with a CE + Dice loss.

    Sample counts come from ``./data_list/train_lite.csv`` and
    ``./data_list/val_lite.csv``. Each epoch draws
    ``int(len(list) / args.batch_size)`` batches from the matching generator
    with ``next()``.
    """

    def __init__(self, args):
        self.args = args
        self.train_dir = './data_list/train_lite.csv'
        self.train_list = pd.read_csv(self.train_dir)
        self.val_dir = './data_list/val_lite.csv'
        self.val_list = pd.read_csv(self.val_dir)
        self.train_length = len(self.train_list)
        self.val_length = len(self.val_list)
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Generator-based data loading ("approach 2")
        self.train_gen, self.val_gen, self.test_gen, self.nclass = make_data_loader2(args)
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Two parameter groups: one at the base learning rate, one at 10x.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)
        # Alternative: torch.optim.Adam(train_params, weight_decay=args.weight_decay)

        # Define Criterion: training and validation sum cross-entropy and Dice.
        self.criterion1 = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode='ce')
        self.criterion2 = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode='dice')

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler. It is stepped once per batch with an iteration
        # index in [0, batches_per_epoch), so it must be given the number of
        # batches per epoch rather than the number of samples.
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      self._num_iters(self.train_length, args.batch_size))

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # The model is never wrapped in DataParallel in this class, so the
            # state dict is loaded directly on both CPU and GPU.
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _num_iters(num_samples, batch_size):
        """Number of whole batches drawn per epoch (a partial last batch is skipped)."""
        return int(num_samples / batch_size)

    @staticmethod
    def _vis_interval(num_iters):
        """Iteration interval for the ~10 image summaries per epoch (never 0)."""
        return max(int(num_iters) // 10, 1)

    def training(self, epoch):
        """Run one training epoch, logging losses and training-set metrics.

        When ``args.no_val`` is set, a (non-best) checkpoint is saved at the
        end of every epoch.
        """
        train_loss = 0.0
        prev_time = time.time()
        self.model.train()
        self.evaluator.reset()

        num_img_tr = self._num_iters(self.train_length, self.args.batch_size)
        vis_every = self._vis_interval(num_img_tr)

        for iteration in range(num_img_tr):
            samples = next(self.train_gen)
            image, target = samples['image'], samples['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, iteration, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss1 = self.criterion1(output, target)
            loss2 = self.criterion2(output, make_one_hot(target.long(), num_classes=self.nclass))
            loss = loss1 + loss2
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            self.writer.add_scalar('train/total_loss_iter', loss.item(), iteration + num_img_tr * epoch)

            # Print a log line every 4 iterations (default log_iters = 4).
            if iteration % 4 == 0:
                end_time = time.time()
                print("Iter - %d: train loss: %.3f, celoss: %.4f, diceloss: %.4f, time cost: %.3f s"
                      % (iteration, loss.item(), loss1.item(), loss2.item(), end_time - prev_time))
                prev_time = time.time()

            # Show 10 * 3 inference results each epoch
            if iteration % vis_every == 0:
                global_step = iteration + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        print("input image shape/iter:", image.shape)

        # train evaluate
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        IoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU = np.nanmean(IoU)
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print("Acc_tr:{}, Acc_class_tr:{}, IoU_tr:{}, mIoU_tr:{}, fwIoU_tr: {}".format(Acc, Acc_class, IoU, mIoU, FWIoU))

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch; the model is not DataParallel-wrapped
            # here, so save it the same way validation() does.
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        """Evaluate on the validation generator; save a checkpoint on a new best mIoU."""
        self.model.eval()
        self.evaluator.reset()
        val_loss = 0.0
        prev_time = time.time()
        num_img_val = self._num_iters(self.val_length, self.args.batch_size)
        print("Validation:","epoch ", epoch)
        print(num_img_val)
        for iteration in range(num_img_val):
            samples = next(self.val_gen)
            image, target = samples['image'], samples['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss1 = self.criterion1(output, target)
            loss2 = self.criterion2(output, make_one_hot(target.long(), num_classes=self.nclass))
            loss = loss1 + loss2
            # Accumulate once per batch (previously added twice).
            val_loss += loss.item()
            self.writer.add_scalar('val/total_loss_iter', loss.item(), iteration + num_img_val * epoch)

            # Print a log line every 4 iterations (default log_iters = 4).
            if iteration % 4 == 0:
                end_time = time.time()
                print("Iter - %d: validation loss: %.3f, celoss: %.4f, diceloss: %.4f, time cost: %.3f s"
                      % (iteration, loss.item(), loss1.item(), loss2.item(), end_time - prev_time))
                prev_time = time.time()

            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        print(image.shape)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        IoU = self.evaluator.Mean_Intersection_over_Union()
        mIoU = np.nanmean(IoU)
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', val_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print("Acc_val:{}, Acc_class_val:{}, IoU:val:{}, mIoU_val:{}, fwIoU_val: {}".format(Acc, Acc_class, IoU, mIoU, FWIoU))
        print('Loss: %.3f' % val_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
class Trainer(object):
    """Train an FPN segmentation network on CamVid or Cityscapes.

    ``self.lr`` holds the learning rate reported in the training log.
    ``self.lr_stage`` lists the epochs intended for step decay; see the
    review note in ``training``.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        if args.dataset == 'CamVid':
            # Build the path portably instead of hard-coding Windows separators.
            camvid_dir = os.path.join(os.getcwd(), 'data', 'CamVid')
            train_file = os.path.join(camvid_dir, "train.csv")
            val_file = os.path.join(camvid_dir, "val.csv")
            print('=>loading datasets')
            train_data = CamVidDataset(csv_file=train_file, phase='train')
            self.train_loader = torch.utils.data.DataLoader(train_data,
                                                            batch_size=args.batch_size,
                                                            shuffle=True,
                                                            num_workers=args.num_workers)
            val_data = CamVidDataset(csv_file=val_file, phase='val', flip_rate=0)
            self.val_loader = torch.utils.data.DataLoader(val_data,
                                                          batch_size=args.batch_size,
                                                          shuffle=True,
                                                          num_workers=args.num_workers)
            self.num_class = 32
        elif args.dataset == 'Cityscapes':
            kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
            self.train_loader, self.val_loader, self.test_loader, self.num_class = make_data_loader(args, **kwargs)
        else:
            raise NotImplementedError('unsupported dataset: {}'.format(args.dataset))

        # Define network
        if args.net == 'resnet101':
            blocks = [2, 4, 23, 3]
            fpn = FPN(blocks, self.num_class, back_bone=args.net)
        else:
            raise NotImplementedError('unsupported backbone: {}'.format(args.net))

        # Define Optimizer
        self.lr = self.args.lr
        if args.optimizer == 'adam':
            # Adam runs at a 10x smaller learning rate. torch.optim.Adam has
            # no ``momentum`` argument (passing one raises TypeError).
            self.lr = self.lr * 0.1
            optimizer = torch.optim.Adam(fpn.parameters(), lr=self.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(fpn.parameters(), lr=args.lr, momentum=0, weight_decay=args.weight_decay)
        else:
            raise NotImplementedError('unsupported optimizer: {}'.format(args.optimizer))

        # Define Criterion (dataset was validated above, so exactly one branch applies)
        if args.dataset == 'CamVid':
            self.criterion = nn.CrossEntropyLoss()
        else:
            weight = None
            self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode='ce')

        self.model = fpn
        self.optimizer = optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.num_class)

        # multiple mGPUs
        if args.mGPUs:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume:
            output_dir = os.path.join(args.save_dir, args.dataset, args.checkname)
            # Sort by run number: a plain string sort puts experiment_10
            # before experiment_9.
            runs = sorted(glob.glob(os.path.join(output_dir, 'experiment_*')),
                          key=self._run_index)
            # Resume from the run before the newest one -- presumably because
            # the Saver created above already opened a fresh experiment
            # directory; confirm against Saver.
            run_id = self._run_index(runs[-1]) - 1 if runs else 0
            experiment_dir = os.path.join(output_dir, 'experiment_{}'.format(str(run_id)))
            load_name = os.path.join(experiment_dir,
                                     'checkpoint.pth.tar')
            if not os.path.isfile(load_name):
                raise RuntimeError("=> no checkpoint found at '{}'".format(load_name))
            checkpoint = torch.load(load_name)
            args.start_epoch = checkpoint['epoch']
            # Loaded into self.model as-is (wrapped or not), matching how
            # checkpoints are saved below.
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.lr = checkpoint['optimizer']['param_groups'][0]['lr']
            print("=> loaded checkpoint '{}'(epoch {})".format(load_name, checkpoint['epoch']))

        self.lr_stage = [68, 93]
        self.lr_staget_ind = 0

    @staticmethod
    def _run_index(run_dir):
        """Return N for a run directory path ending in ``experiment_N``."""
        return int(run_dir.split('_')[-1])

    def _unpack_batch(self, batch):
        """Return ``(image, target)`` using the dataset-specific batch keys.

        Raises:
            NotImplementedError: for datasets other than CamVid / Cityscapes.
        """
        if self.args.dataset == 'CamVid':
            return batch['X'], batch['l']
        if self.args.dataset == 'Cityscapes':
            return batch['image'], batch['label']
        raise NotImplementedError

    def training(self, epoch):
        """Run one training epoch and log the loss.

        When ``args.no_val`` is set, a (non-best) checkpoint is saved at the
        end of every epoch.
        """
        train_loss = 0.0
        self.model.train()
        num_img_tr = len(self.train_loader)
        # NOTE(review): lr_staget_ind starts at 0 and is only incremented
        # inside this branch, so ``self.lr_staget_ind > 1`` is never true and
        # the step decay at self.lr_stage epochs never fires. Confirm the
        # intended condition before changing it, since it alters training.
        if self.lr_staget_ind > 1 and epoch % (self.lr_stage[self.lr_staget_ind]) == 0:
            adjust_learning_rate(self.optimizer, self.args.lr_decay_gamma)
            self.lr *= self.args.lr_decay_gamma
            self.lr_staget_ind += 1
        for iteration, batch in enumerate(self.train_loader):
            image, target = self._unpack_batch(batch)
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.optimizer.zero_grad()

            outputs = self.model(image)
            loss = self.criterion(outputs, target.long())
            # loss is a single-element tensor (loss.item() below requires it),
            # so a plain backward() is equivalent to backward(ones_like(loss)).
            loss.backward()
            self.optimizer.step()
            loss_val = loss.item()
            train_loss += loss_val

            if iteration % 10 == 0:
                print("Epoch[{}]({}/{}):Loss:{:.4f}, learning rate={}".format(epoch, iteration, num_img_tr, loss_val, self.lr))

            self.writer.add_scalar('train/total_loss_iter', loss_val, iteration + num_img_tr * epoch)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, iteration * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch, in the same form validation() uses
            # and the resume path in __init__ loads.
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        """Evaluate on the validation loader; save a checkpoint on a new best mIoU."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        # Iterate the tqdm wrapper so the progress bar actually advances.
        for step, batch in enumerate(tbar):
            image, target = self._unpack_batch(batch)
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            # Same integer target cast as in training.
            loss = self.criterion(output, target.long())
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f ' % (test_loss / (step + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/FWIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, step * self.args.batch_size + image.shape[0]))
        print("Acc:{:.5f}, Acc_class:{:.5f}, mIoU:{:.5f}, fwIoU:{:.5f}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Пример #4
0
class Trainer(object):
    """AutoDeeplab architecture-search trainer.

    Each training step updates the network weights (SGD) on a batch from
    ``train_loaderA``; from epoch ``args.alpha_epoch`` onward it also updates
    the architecture parameters (Adam) on a batch from ``train_loaderB``.
    Optional apex AMP and multi-GPU ``nn.DataParallel`` are supported.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        # amp.initialize() below only runs on CUDA, so loss scaling must only
        # be enabled there; otherwise amp.scale_loss would run uninitialized.
        self.use_amp = bool(APEX_AVAILABLE and args.use_amp and args.cuda)
        self.opt_level = args.opt_level

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True,
            'drop_last': True
        }
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)
        # Persistent iterator over train_loaderB (see _next_search_batch).
        self._search_iter = None

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if not os.path.isfile(classes_weights_path):
                # Computing the weights on the fly is not supported: it is
                # undecided which of the two train loaders to use for it.
                raise NotImplementedError
            weight = torch.from_numpy(
                np.load(classes_weights_path).astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass, 12, self.criterion,
                            self.args.filter_multiplier,
                            self.args.block_multiplier, self.args.step)
        optimizer = torch.optim.SGD(model.weight_parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        self.model, self.optimizer = model, optimizer

        self.architect_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_lr,
            betas=(0.9, 0.999),
            weight_decay=args.arch_weight_decay)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        # TODO: Figure out if len(self.train_loaderA) should be divided by two
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      len(self.train_loaderA),
                                      min_lr=args.min_lr)
        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # mixed precision
        if self.use_amp:
            self._init_amp()

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level in ('O2', 'O3'):
                print(
                    'currently cannot run with nn.DataParallel and optimization level',
                    self.opt_level)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            self._resume(args)

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _version_older_than(version, major, minor):
        """Return True if ``version`` (e.g. '1.10.0+cu113') is below ``major.minor``.

        Compares numerically; a plain string comparison gets '1.10' < '1.3'
        wrong. A version string that does not start with 'N.N' is treated as
        not older.
        """
        import re
        match = re.match(r'\s*(\d+)\.(\d+)', str(version))
        if match is None:
            return False
        return (int(match.group(1)), int(match.group(2))) < (major, minor)

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with a leading 'module.' removed from its keys.

        nn.DataParallel prefixes parameter names with 'module.'; keys without
        that prefix are kept unchanged.
        """
        prefix = 'module.'
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items())

    @staticmethod
    def _unwrap_model(model):
        """Return the wrapped network if ``model`` is an nn.DataParallel, else ``model``."""
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    def _init_amp(self):
        """Wrap the model and both optimizers with apex AMP at ``self.opt_level``."""
        keep_batchnorm_fp32 = True if self.opt_level in ('O2',
                                                          'O3') else None

        # fix for older pytorch versions with opt_level 'O1': give BatchNorm
        # layers without affine parameters constant, frozen weight/bias.
        if self.opt_level == 'O1' and self._version_older_than(
                torch.__version__, 1, 3):
            for module in self.model.modules():
                if isinstance(module,
                              torch.nn.modules.batchnorm._BatchNorm):
                    stats = module.running_var
                    if module.weight is None:
                        module.weight = torch.nn.Parameter(
                            torch.ones(stats.shape,
                                       dtype=stats.dtype,
                                       device=stats.device),
                            requires_grad=False)
                    if module.bias is None:
                        module.bias = torch.nn.Parameter(
                            torch.zeros(stats.shape,
                                        dtype=stats.dtype,
                                        device=stats.device),
                            requires_grad=False)

        self.model, [self.optimizer,
                     self.architect_optimizer] = amp.initialize(
                         self.model,
                         [self.optimizer, self.architect_optimizer],
                         opt_level=self.opt_level,
                         keep_batchnorm_fp32=keep_batchnorm_fp32,
                         loss_scale="dynamic")

        print('cuda finished')

    def _resume(self, args):
        """Load network weights, optimizer state and ``best_pred`` from ``args.resume``.

        Raises:
            RuntimeError: if ``args.resume`` is not an existing file.
        """
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']

        state_dict = checkpoint['state_dict']
        if args.clean_module:
            # The weights were saved from an nn.DataParallel wrapper.
            state_dict = self._strip_module_prefix(state_dict)
        # Copy into the bare network whether or not it is currently wrapped in
        # nn.DataParallel; _save_checkpoint stores the bare network's weights.
        copy_state_dict(self._unwrap_model(self.model).state_dict(),
                        state_dict)

        if not args.ft:
            # NOTE(review): Optimizer.state_dict() builds a new packed dict, so
            # copying into it may not restore the optimizer's state (e.g. a
            # fresh optimizer has empty state). Consider
            # self.optimizer.load_state_dict(checkpoint['optimizer']); confirm
            # against copy_state_dict's semantics.
            copy_state_dict(self.optimizer.state_dict(),
                            checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    def _next_search_batch(self):
        """Return the next batch of ``train_loaderB``, restarting it when exhausted.

        One iterator is kept across calls. Creating a fresh
        ``iter(train_loaderB)`` every step would restart the loader each time:
        with worker processes it respawns them, and without shuffling it
        always yields the same first batch.
        """
        if getattr(self, '_search_iter', None) is None:
            self._search_iter = iter(self.train_loaderB)
        try:
            return next(self._search_iter)
        except StopIteration:
            self._search_iter = iter(self.train_loaderB)
            return next(self._search_iter)

    def _backward(self, loss, optimizer):
        """Backpropagate ``loss``, through apex loss scaling when AMP is enabled."""
        if self.use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    def _save_checkpoint(self, epoch, is_best):
        """Save the bare network, optimizer state and ``best_pred`` via the saver."""
        self.saver.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self._unwrap_model(self.model).state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def training(self, epoch):
        """Run one search epoch over ``train_loaderA``.

        Every batch takes a weight step. Once ``epoch >= args.alpha_epoch``,
        each batch also takes an architecture step on a batch from
        ``train_loaderB``. With ``args.no_val``, a checkpoint is saved at the
        end of every epoch.
        """
        train_loss = 0.0
        num_images = 0
        self.model.train()
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)
        # Show ~10 inference results per epoch; max() avoids a modulo by zero
        # when the loader has fewer than 10 batches.
        vis_interval = max(num_img_tr // 10, 1)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            self._backward(loss, self.optimizer)
            self.optimizer.step()

            if epoch >= self.args.alpha_epoch:
                search = self._next_search_batch()
                image_search, target_search = search['image'], search['label']
                if self.args.cuda:
                    image_search = image_search.cuda()
                    target_search = target_search.cuda()

                self.architect_optimizer.zero_grad()
                output_search = self.model(image_search)
                arch_loss = self.criterion(output_search, target_search)
                self._backward(arch_loss, self.architect_optimizer)
                self.architect_optimizer.step()

            train_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            self._save_checkpoint(epoch, is_best=False)

    def validation(self, epoch):
        """Evaluate on ``val_loader``, log metrics, and checkpoint on a new best mIoU."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0

        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = np.argmax(output.detach().cpu().numpy(), axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target.cpu().numpy(), pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self._save_checkpoint(epoch, is_best=True)
# Example #5
# 0
class Trainer(object):
    """Image-classification trainer with optional layer-wise freezing.

    With ``args.freeze`` or ``args.time``, after every training epoch each
    parameter whose name contains both 'conv' and 'layer' is compared with its
    snapshot from the previous epoch. A mean absolute change below
    ``_FREEZE_THRESHOLD`` counts as converged. With ``args.freeze``, such a
    parameter is also frozen (``requires_grad = False``).
    """

    # Mean absolute per-epoch change below which a parameter counts as converged.
    _FREEZE_THRESHOLD = 1e-04

    def __init__(self, args):
        self.args = args

        # Define saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define TensorBoard summary (no writer is created in test mode)
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        if not args.test:
            self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define Network
        model = Model(args, self.nclass)

        # Define Optimizer
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        criterion = nn.CrossEntropyLoss()

        self.model, self.optimizer, self.criterion = model, optimizer, criterion

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using CUDA
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("{}: No such checkpoint exists".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            if args.cuda:
                # Load only the checkpoint entries whose names exist in the
                # current network; all other weights keep their values.
                state_dict = self.model.module.state_dict()
                state_dict.update({
                    k: v
                    for k, v in checkpoint['state_dict'].items()
                    if k in state_dict
                })
                self.model.module.load_state_dict(state_dict)
            else:
                print("Please use CUDA")
                raise NotImplementedError

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("Loading {} (epoch {}) successfully done".format(
                args.resume, checkpoint['epoch']))

        if args.ft:
            args.start_epoch = 0

        # layer wise freezing
        self.histories = []  # (epoch, converged-or-frozen count) per train() call
        self.history = {}  # str(epoch) + parameter name -> CPU snapshot
        self.isTrained = False  # set once every parameter has been frozen
        self.freeze_count = 0
        self.total_count = sum(1 for _ in model.parameters())

    @staticmethod
    def _unwrap_model(model):
        """Return the wrapped network if ``model`` is an nn.DataParallel, else ``model``."""
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    @staticmethod
    def _param_snapshot(param):
        """Return a detached CPU copy of ``param`` that never shares its storage.

        ``param.cpu()`` is a no-op for a CPU tensor, so ``.cpu().detach()``
        would alias the live parameter; ``copy=True`` forces a real copy.
        """
        return param.detach().to('cpu', copy=True)

    def train(self, epoch):
        """Train for one epoch, then update the convergence/freezing statistics.

        Does nothing once every parameter has been frozen (``self.isTrained``).
        With ``args.no_val`` and not ``args.time``, a checkpoint is saved.
        """
        if self.isTrained:
            return
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print("[epoch: %d, loss: %.3f]" % (epoch, train_loss))

        if self.args.freeze or self.args.time:
            converged = self._update_freeze_state(epoch)
            if not self.args.freeze and self.args.time:
                self.histories.append((epoch, converged))
            else:
                self.histories.append((epoch, self.freeze_count))

        if self.args.no_val and not self.args.time:
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self._unwrap_model(self.model).state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def _update_freeze_state(self, epoch):
        """Snapshot all parameters for ``epoch``; count (and optionally freeze) converged ones.

        Returns the number of 'conv'/'layer' parameters whose mean absolute
        change since the previous epoch's snapshot is below
        ``_FREEZE_THRESHOLD``.
        """
        named = list(self.model.named_parameters())
        for n, p in named:
            self.history[str(epoch) + n] = self._param_snapshot(p)
            # Only the previous epoch is ever compared against, so drop older
            # snapshots instead of keeping a full model copy per epoch.
            self.history.pop(str(epoch - 2) + n, None)

        converged = 0
        if epoch >= 1:
            for n, p in named:
                if 'conv' not in n or 'layer' not in n:
                    continue
                prev = self.history.get(str(epoch - 1) + n)
                if prev is None:
                    # No snapshot from the previous epoch (e.g. first epoch
                    # after resuming from a checkpoint).
                    continue
                dif = self.history[str(epoch) + n] - prev
                if np.abs(dif.numpy()).mean() >= self._FREEZE_THRESHOLD:
                    continue
                converged += 1
                if p.requires_grad and self.args.freeze:
                    p.requires_grad = False
                    self.freeze_count += 1
                    if self.args.decomp:
                        self._decompose_layer_of(n)
            if self.freeze_count == self.total_count:
                self.isTrained = True
        return converged

    def _decompose_layer_of(self, param_name):
        """Replace child ``name[1]`` of ``self.model.layerK`` with its Tucker decomposition.

        NOTE(review): kept as originally written, but this path looks broken:
        ``nn.Module._modules`` is keyed by str, so ``_modules[int(...)]`` raises
        KeyError. Under nn.DataParallel, parameter names start with 'module.',
        so nothing is ever decomposed. The replacement's parameters are not
        registered with the optimizer. Confirm the intended behavior before
        relying on ``args.decomp``.
        """
        name = param_name.split('.')
        if name[0] == 'layer1':
            layer = self.model.layer1
        elif name[0] == 'layer2':
            layer = self.model.layer2
        elif name[0] == 'layer3':
            layer = self.model.layer3
        elif name[0] == 'layer4':
            layer = self.model.layer4
        else:
            return
        conv = layer._modules[int(name[1])]
        dec = tucker_decomposition_conv_layer(conv)
        layer._modules[int(name[1])] = dec

    def val(self, epoch):
        """Evaluate top-1/top-5 accuracy on ``val_loader``; checkpoint on a new best top-1."""
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')

        self.model.eval()

        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()

            # top accuracy record
            batch_size = image.size(0)
            num_images += batch_size
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

        # Fast test during the training
        _top1 = top1.avg
        _top5 = top5.avg
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/top1', _top1, epoch)
        self.writer.add_scalar('val/top5', _top5, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Top-1: %.3f, Top-5: %.3f" % (_top1, _top5))
        print('Loss: %.3f' % test_loss)

        new_pred = _top1
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = float(new_pred)
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self._unwrap_model(self.model).state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
# Example #6
# 0
class Trainer(object):
    """Configurable segmentation trainer: train/val loop, checkpointing, optuna hook."""

    def __init__(self,
                 batch_size=32,
                 optimizer_name="Adam",
                 lr=1e-3,
                 weight_decay=1e-5,
                 epochs=200,
                 model_name="model01",
                 gpu_ids=None,
                 resume=None,
                 tqdm=None,
                 is_develop=False):
        """
        args:
            batch_size = (int) batch_size of training and validation
            optimizer_name = (string) Name of the optimizer, passed to Optimizer().
            lr = (float) learning rate of optimization
            weight_decay = (float) weight decay of optimization
            epochs = (int) The number of epochs of training
            model_name = (string) The name of training model. Will be folder name.
            gpu_ids = (List) List of gpu_ids. (e.g. gpu_ids = [0, 1]). Use CPU if it is None
                      or if CUDA is not available.
            resume = (Dict) Dict of some settings. (resume = {"checkpoint_path":PATH_of_checkpoint, "fine_tuning":True or False}).
                     Learn from scratch, if it is None. With fine_tuning=True only the model
                     weights are loaded and training restarts from epoch 0; otherwise the
                     optimizer state and the epoch counter are restored as well.
            tqdm = (tqdm Object) progress bar object. Set your tqdm please.
                   Don't view progress bar, if it is None.
            is_develop = (bool) Passed through to make_data_loader().
        """
        # Set params
        self.batch_size = batch_size
        self.epochs = epochs
        self.start_epoch = 0
        # is_available must be *called*; the bare function object is always truthy.
        self.use_cuda = (gpu_ids is not None) and torch.cuda.is_available()
        self.tqdm = tqdm
        self.use_tqdm = tqdm is not None

        # Project utilities (usually no need to change):
        #   Saver: saves model weights. / <utils.saver.Saver()>
        #   TensorboardSummary: writes tensorboard files. / <utils.summaries.TensorboardSummary()>
        self.saver = Saver(model_name, lr, epochs)
        self.saver.save_experiment_config()
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Training components (change these for your task):
        #   make_data_loader: creates the DataLoaders. / <dataloader.__init__>
        #   Modeling: the model. / <modeling.modeling.Modeling()>
        #   Evaluator: metrics. / <utils.metrics.Evaluator()>
        #   Optimizer: optimizer. / <utils.optimizer.Optimizer()>
        #   SegmentationLosses: loss function. / <utils.loss>
        self.train_loader, self.val_loader, self.test_loader, self.num_classes = make_data_loader(
            batch_size, is_develop=is_develop)
        self.model = Modeling(self.num_classes)
        self.evaluator = Evaluator(self.num_classes)
        self.optimizer = Optimizer(self.model.parameters(),
                                   optimizer_name=optimizer_name,
                                   lr=lr,
                                   weight_decay=weight_decay)

        # Class weights for the 'ce' loss; they must live on the same device
        # as the model output (CUDA only when use_cuda).
        class_weight = torch.tensor([1.0, 1594.0])
        if self.use_cuda:
            class_weight = class_weight.cuda()
        self.criterion = SegmentationLosses(
            weight=class_weight).build_loss('ce')
        # Alternatives: SegmentationLosses().build_loss('focal'), BCEDiceLoss()

        # Using cuda
        if self.use_cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=gpu_ids).cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if resume is not None:
            self._resume(resume)

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with a leading 'module.' removed from its keys.

        nn.DataParallel prefixes parameter names with 'module.'; keys without
        that prefix are kept unchanged.
        """
        from collections import OrderedDict
        prefix = 'module.'
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items())

    @staticmethod
    def _unwrap_model(model):
        """Return the wrapped network if ``model`` is an nn.DataParallel, else ``model``."""
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    def _resume(self, resume):
        """Restore weights (and, unless fine-tuning, optimizer and epoch) from a checkpoint.

        Raises:
            RuntimeError: if ``resume["checkpoint_path"]`` is not an existing file.
        """
        checkpoint_path = resume["checkpoint_path"]
        if not os.path.isfile(checkpoint_path):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                checkpoint_path))
        # Load onto CPU so GPU-saved checkpoints also open on CPU-only hosts;
        # load_state_dict copies the tensors onto the model's own device.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        self.start_epoch = checkpoint['epoch']
        # Checkpoints may have been saved from the DataParallel wrapper
        # ('module.'-prefixed keys) or from the bare network; accept both.
        state_dict = self._strip_module_prefix(checkpoint['state_dict'])
        self._unwrap_model(self.model).load_state_dict(state_dict)
        if resume["fine_tuning"]:
            # Fine-tuning: keep the fresh optimizer and restart the epoch count.
            self.start_epoch = 0
        else:
            # Plain resume: continue with the saved optimizer state.
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_path, checkpoint['epoch']))

    def _run_epoch(self,
                   epoch,
                   mode="train",
                   leave_progress=True,
                   use_optuna=False):
        """
        run training or validation 1 epoch.
        You don't have to change almost of this method.

        args:
            epoch = (int) How many epochs this time.
            mode = {"train" or "val"}
            leave_progress = {True or False} Can choose whether leave progress bar or not.
            use_optuna = {True or False} Can choose whether use optuna or not.
                         When True, nothing is logged or printed.
        returns:
            mIoU of this epoch as computed by the evaluator.

        Change point (if you need):
        - Evaluation: You can change metrics of monitoring.
        - writer.add_scalar: You can change metrics to be saved in tensorboard.
        """
        leave_progress = leave_progress and not use_optuna
        assert (mode == "train") or (
            mode == "val"
        ), "argument 'mode' can be 'train' or 'val.' Not {}.".format(mode)
        is_train = mode == "train"

        ## Set model mode & tqdm (progress bar; it wraps the dataloader)
        loader = self.train_loader if is_train else self.val_loader
        data_loader = self.tqdm(
            loader, leave=leave_progress) if self.use_tqdm else loader
        if is_train:
            self.model.train()
        else:
            self.model.eval()
        ## Reset confusion matrix of evaluator
        self.evaluator.reset()

        epoch_loss = 0.0
        num_batches = 0
        for sample in data_loader:
            inputs, target = sample["input"], sample["label"]
            if self.use_cuda:
                inputs, target = inputs.cuda(), target.cuda()

            if is_train:
                self.optimizer.zero_grad()
                output = self.model(inputs)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
            else:
                with torch.no_grad():
                    output = self.model(inputs)
                loss = self.criterion(output, target)
            epoch_loss += loss.item()
            num_batches += 1

            if self.use_tqdm:
                data_loader.set_description('{} loss: {:.3f}'.format(
                    mode, epoch_loss / num_batches))
            ## Add batch results into evaluator
            pred = torch.argmax(output, dim=1).detach().cpu().numpy()
            self.evaluator.add_batch(target.cpu().numpy(), pred)

        ## Evaluate score (more metrics: <utils.metrics.Evaluator()>)
        miou = self.evaluator.Mean_Intersection_over_Union()

        if not use_optuna:
            # max() keeps an empty loader from dividing by zero.
            mean_loss = epoch_loss / max(num_batches, 1)
            self.writer.add_scalar('{}/loss_epoch'.format(mode), mean_loss,
                                   epoch)
            self.writer.add_scalar('{}/miou'.format(mode), miou, epoch)
            print('Total {} loss: {:.3f}'.format(mode, mean_loss))
            print("{0} mIoU:{1:.2f}".format(mode, miou))

        # Return score to watch. (update checkpoint or optuna's objective)
        return miou

    def run(self, leave_progress=True, use_optuna=False):
        """
        Run all epochs of training and validation.

        A checkpoint is saved whenever the validation mIoU beats best_pred.
        returns:
            best validation mIoU seen so far.
        """
        epochs = range(self.start_epoch, self.epochs)
        if self.use_tqdm:
            epochs = self.tqdm(epochs)
        for epoch in epochs:
            print(pycolor.GREEN + "[Epoch: {}]".format(epoch) + pycolor.END)

            ## ***Train***
            print(pycolor.YELLOW + "Training:" + pycolor.END)
            self._run_epoch(epoch,
                            mode="train",
                            leave_progress=leave_progress,
                            use_optuna=use_optuna)
            ## ***Validation***
            print(pycolor.YELLOW + "Validation:" + pycolor.END)
            score = self._run_epoch(epoch,
                                    mode="val",
                                    leave_progress=leave_progress,
                                    use_optuna=use_optuna)
            print("---------------------")
            if score > self.best_pred:
                print("model improve best score from {:.4f} to {:.4f}.".format(
                    self.best_pred, score))
                self.best_pred = score
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    # Save the bare network so the checkpoint loads with or
                    # without DataParallel.
                    'state_dict': self._unwrap_model(self.model).state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                })
        self.writer.close()
        return self.best_pred
# Example #7
# 0
class Trainer(object):
    """Jointly train a 3D DeepLab segmenter and a BiLSTM refinement model.

    For every training batch the DeepLab model is updated with SGD on a
    Dice+CE criterion. Its per-(sample, channel) min-max normalised output is
    then cut into grid/cube sequences, together with the raw image, and used
    to update the BiLSTM with Adam on a BCE loss against the label sequence.
    The BiLSTM path only reads sample 0 of the batch, so a batch size of 1 is
    assumed.
    """

    def __init__(self, args):
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define device and Dataloader
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        if DEBUG:
            print("get device: ", self.device)
        # NOTE(review): unlike the sibling trainers, no num_workers /
        # pin_memory kwargs are forwarded to make_data_loader; confirm this
        # is intentional.
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args)

        # Define networks. Both are moved to the GPU unconditionally.
        modelDeeplab = DeepLab3d(num_classes=self.nclass,
                                 backbone=args.backbone,
                                 output_stride=args.out_stride,
                                 sync_bn=args.sync_bn,
                                 freeze_bn=args.freeze_bn).cuda()
        # Feature width = three flattened cube_D**3 cubes per step (raw image
        # plus two DeepLab output channels, concatenated in _train_lstm).
        # Presumably these are the input/hidden sizes; confirm against BiLSTM.
        Bilstm = BiLSTM(cube_D * cube_D * cube_D * 3,
                        cube_D * cube_D * cube_D * 3, 1).cuda()

        # get_1x_lr_params at args.lr, get_10x_lr_params at 10x args.lr.
        train_params = [{
            'params': modelDeeplab.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': modelDeeplab.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizers
        optimizerSGD = torch.optim.SGD(train_params,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay,
                                       nesterov=args.nesterov)
        optimizerADAM = torch.optim.Adam(Bilstm.parameters())

        # Define Criteria
        self.deeplabCriterion = DiceCELoss().cuda()
        self.lstmCost = torch.nn.BCELoss().cuda()
        self.deeplab, self.Bilstm = modelDeeplab, Bilstm
        self.optimizerSGD, self.optimizerADAM = optimizerSGD, optimizerADAM

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler. It only drives the SGD optimizer.
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            self._resume(args)

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _unwrap(model):
        """Return ``model.module`` for DataParallel-style wrappers, else ``model``."""
        return getattr(model, 'module', model)

    @staticmethod
    def _normalize_output(output, eps=1e-12):
        """Min-max normalise ``output`` independently per (sample, channel).

        Works for any tensor of rank >= 2, where dims 0 and 1 are sample and
        channel. Gradients flow through the result. Constant channels
        (max == min) map to 0 instead of NaN, because the range is clamped
        to ``eps``. The result is float32.
        """
        n, c = output.shape[0], output.shape[1]
        flat = output.reshape(n, c, -1)
        lo = flat.min(dim=2, keepdim=True)[0]
        hi = flat.max(dim=2, keepdim=True)[0]
        normed = (flat - lo) / (hi - lo).clamp(min=eps)
        return normed.reshape(output.shape).float()

    def _resume(self, args):
        """Restore model and optimizer state from ``args.resume``.

        Raises:
            RuntimeError: if the checkpoint file does not exist.
        """
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        # The models are never wrapped in DataParallel here; _unwrap keeps
        # this correct either way.
        self._unwrap(self.deeplab).load_state_dict(checkpoint['state_dict'])
        if 'lstm_state_dict' in checkpoint:
            self._unwrap(self.Bilstm).load_state_dict(
                checkpoint['lstm_state_dict'])
        if not args.ft:
            # A checkpoint that only has an 'optimizer' entry is treated as
            # the SGD state.
            sgd_state = checkpoint.get('optimizerSGD',
                                       checkpoint.get('optimizer'))
            if sgd_state is not None:
                self.optimizerSGD.load_state_dict(sgd_state)
            if 'optimizerADAM' in checkpoint:
                self.optimizerADAM.load_state_dict(checkpoint['optimizerADAM'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    def _checkpoint_state(self, epoch):
        """Build the dict passed to ``saver.save_checkpoint`` after ``epoch``."""
        return {
            'epoch': epoch + 1,
            'state_dict': self._unwrap(self.deeplab).state_dict(),
            'lstm_state_dict': self._unwrap(self.Bilstm).state_dict(),
            'optimizerSGD': self.optimizerSGD.state_dict(),
            'optimizerADAM': self.optimizerADAM.state_dict(),
            'best_pred': self.best_pred,
        }

    def _train_lstm(self, output2, target_sque, img_sque):
        """Run one BiLSTM update per grid extracted from the current volume.

        Args:
            output2: normalised DeepLab output. Only sample 0, channels 0
                and 1, are used.
            target_sque: squeezed label volume, used for the bounding box
                and as the LSTM target.
            img_sque: squeezed raw image volume.

        Returns:
            Mean BCE loss over the processed grids, or 0.0 if there were none.
        """
        chan0 = output2[0][0].detach().cpu().numpy()
        chan1 = output2[0][1].detach().cpu().numpy()
        _, bbx_loc = get_bounding_box_loc(img=target_sque, bbx_ext=10)
        label_grids = load_nii2grid(grid_D, grid_ita,
                                    bbx_loc=bbx_loc, img=target_sque)
        chan1_grids = load_nii2grid(grid_D, grid_ita,
                                    img=chan1, bbx_loc=bbx_loc)
        chan0_grids = load_nii2grid(grid_D, grid_ita,
                                    img=chan0, bbx_loc=bbx_loc)
        image_grids = load_nii2grid(grid_D, grid_ita,
                                    img=img_sque, bbx_loc=bbx_loc)

        total_loss = 0.0
        num_grids = 0
        for g, image_grid in enumerate(image_grids):
            # Serialise each grid into a sequence of cubes. The raw image is
            # scaled by 1/255; labels and network outputs are left as-is.
            image_seq = partition_vol2grid2seq(image_grid, cube_D, cube_ita,
                                               norm_fact=255.0)
            label_seq = partition_vol2grid2seq(label_grids[g], cube_D,
                                               cube_ita, norm_fact=1.0)
            chan1_seq = partition_vol2grid2seq(chan1_grids[g], cube_D,
                                               cube_ita, norm_fact=1.0)
            chan0_seq = partition_vol2grid2seq(chan0_grids[g], cube_D,
                                               cube_ita, norm_fact=1.0)
            feat = np.concatenate((image_seq, chan1_seq, chan0_seq), axis=1)
            # Add a leading batch dimension for the LSTM.
            feat = torch.from_numpy(feat).unsqueeze(0).float().cuda()

            pred_seq = self.Bilstm(feat)
            self.optimizerADAM.zero_grad()
            label = torch.from_numpy(label_seq).float().cuda()
            lstm_loss = self.lstmCost(pred_seq, label)
            lstm_loss.backward()
            self.optimizerADAM.step()

            total_loss += lstm_loss.item()
            num_grids += 1
        return total_loss / num_grids if num_grids else 0.0

    def training(self, epoch):
        """Run one training epoch over ``self.train_loader``."""
        train_loss = 0.0
        dice_loss_count = 0.0
        ce_loss_count = 0.0
        num_count = 0
        num_images = 0

        # NOTE(review): the segmentation model is deliberately left in eval()
        # mode, as in the original. This is presumably to avoid train-mode
        # BatchNorm statistics with batch size 1; confirm before changing it
        # to train().
        self.deeplab.eval()

        tbar = tqdm(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample
            # Expected to drop the batch/channel dims (e.g. 50x384x384).
            target_sque = target.squeeze()
            img_sque = image.squeeze()
            if DEBUG:
                print("image, target ,sque size feed in model,", image.size(),
                      target.size(), target_sque.size())
            image, target = image.cuda(), target.cuda()
            num_images += image.shape[0]

            # ---- DeepLab update ----
            self.scheduler(self.optimizerSGD, i, epoch, self.best_pred)
            self.optimizerSGD.zero_grad()
            output = self.deeplab(image)
            if DEBUG:
                print(output.size())
            output2 = self._normalize_output(output)
            loss, dice_loss, ce_loss = self.deeplabCriterion(
                output, output2, target, self.device)
            loss.backward()
            self.optimizerSGD.step()

            # ---- BiLSTM update ----
            lstm_loss = self._train_lstm(output2, target_sque, img_sque)

            # Accumulate Python floats only, so no autograd graph is retained.
            train_loss += loss.item() + lstm_loss
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            dice_loss_count += dice_loss.item()
            ce_loss_count += ce_loss.item()
            num_count += 1

        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        denom = max(num_count, 1)
        print('Loss: %.3f, dice loss: %.3f, ce loss: %.3f' %
              (train_loss, dice_loss_count / denom, ce_loss_count / denom))

        if self.args.no_val:
            # Save a checkpoint every epoch when validation is disabled.
            self.saver.save_checkpoint(self._checkpoint_state(epoch), False)

    def validation(self, epoch):
        """Evaluate on ``self.val_loader`` and checkpoint on a new best mIoU."""
        self.deeplab.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')

        test_loss = 0.0
        dice_loss = 0.0
        ce_loss = 0.0
        num_count = 0
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample
            # The models always live on the GPU (see __init__), so the
            # inputs must be moved there as well.
            image, target = image.cuda(), target.cuda()
            num_images += image.shape[0]

            with torch.no_grad():
                output = self.deeplab(image)
                output2 = self._normalize_output(output)
                loss, dice, ce = self.deeplabCriterion(output, output2, target,
                                                       self.device)
            test_loss += loss.item()
            dice_loss += dice.item()
            ce_loss += ce.item()
            num_count += 1
            tbar.set_description(
                'Test loss: %.3f, dice loss: %.3f, ce loss: %.3f' %
                (test_loss / (i + 1), dice_loss / num_count,
                 ce_loss / num_count))

            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(np.squeeze(target), pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss, dice_loss, ce_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self.saver.save_checkpoint(self._checkpoint_state(epoch), True)
            print("ltt save ckpt!")
Пример #8
0
class Trainer(object):
    """Differentiable architecture-search trainer for AutoStereo.

    Weight parameters of the feature and matching nets are trained with SGD
    on ``train_loaderA``. Once ``epoch >= args.alpha_epoch`` the architecture
    parameters are also updated with Adam on batches drawn from
    ``train_loaderB``. Validation tracks end-point error (lower is better).
    """

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader = make_data_loader(args, **kwargs)
        # Persistent iterator over the search split; see _next_search_batch.
        self._search_iter = None

        # Define network
        model = AutoStereo(maxdisp=self.args.max_disp,
                           Fea_Layers=self.args.fea_num_layers, Fea_Filter=self.args.fea_filter_multiplier,
                           Fea_Block=self.args.fea_block_multiplier, Fea_Step=self.args.fea_step,
                           Mat_Layers=self.args.mat_num_layers, Mat_Filter=self.args.mat_filter_multiplier,
                           Mat_Block=self.args.mat_block_multiplier, Mat_Step=self.args.mat_step)

        # Weight optimizers (SGD), one per sub-network.
        optimizer_F = torch.optim.SGD(
            model.feature.weight_parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )
        optimizer_M = torch.optim.SGD(
            model.matching.weight_parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )
        self.model, self.optimizer_F, self.optimizer_M = model, optimizer_F, optimizer_M

        # Architecture optimizers (Adam), one per sub-network.
        self.architect_optimizer_F = torch.optim.Adam(self.model.feature.arch_parameters(),
                                                      lr=args.arch_lr, betas=(0.9, 0.999),
                                                      weight_decay=args.arch_weight_decay)
        self.architect_optimizer_M = torch.optim.Adam(self.model.matching.arch_parameters(),
                                                      lr=args.arch_lr, betas=(0.9, 0.999),
                                                      weight_decay=args.arch_weight_decay)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loaderA), min_lr=args.min_lr)
        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model).cuda()

        # Resuming checkpoint. best_pred is an error, so it starts high.
        self.best_pred = 100.0
        if args.resume is not None:
            self._resume(args)

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        bare = self._unwrap(self.model)
        print('Total number of model parameters : {}'.format(self._count_params(self.model)))
        print('Number of Feature Net parameters: {}'.format(self._count_params(bare.feature)))
        print('Number of Matching Net parameters: {}'.format(self._count_params(bare.matching)))

    @staticmethod
    def _unwrap(model):
        """Return ``model.module`` for DataParallel-style wrappers, else ``model``."""
        return getattr(model, 'module', model)

    @staticmethod
    def _count_params(module):
        """Total number of scalar parameters in ``module``."""
        return sum(p.numel() for p in module.parameters())

    @staticmethod
    def _strip_module_prefix(state_dict, prefix='module.'):
        """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys.

        Keys without the prefix are kept unchanged. Insertion order is preserved.
        """
        return OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items())

    @staticmethod
    def _vis_interval(num_batches, per_epoch=10):
        """Batch interval giving ~``per_epoch`` visualisations (never 0)."""
        return max(1, num_batches // per_epoch)

    @staticmethod
    def _disparity_errors(output, target, mask):
        """Return ``(EPE, error_rate)`` over the pixels selected by ``mask``.

        EPE is the mean absolute disparity error. A pixel counts as correct
        when its error is below 1 or below 5% of the true disparity. The
        returned rate is the fraction of masked pixels that are NOT correct.
        """
        abs_err = torch.abs(output[mask] - target[mask])
        epe = abs_err.mean().item()
        correct = (abs_err < 1) | (abs_err < target[mask] * 0.05)
        return epe, 1.0 - correct.float().mean().item()

    def _resume(self, args):
        """Restore weights and optimizer state from ``args.resume``.

        Raises:
            RuntimeError: if the checkpoint file does not exist.
        """
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']

        state_dict = checkpoint['state_dict']
        if args.clean_module:
            # Weights saved from a DataParallel wrapper carry a 'module.' prefix.
            state_dict = self._strip_module_prefix(state_dict)
        copy_state_dict(self._unwrap(self.model).state_dict(), state_dict)

        if not args.ft:
            # NOTE(review): copying into optimizer.state_dict() only affects
            # tensors already present in the fresh optimizer's state; confirm
            # this restores momentum as intended (vs. load_state_dict).
            copy_state_dict(self.optimizer_M.state_dict(), checkpoint['optimizer_M'])
            copy_state_dict(self.optimizer_F.state_dict(), checkpoint['optimizer_F'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))

    def _checkpoint_state(self, epoch):
        """Dict saved after ``epoch``. Weights are stored without a 'module.' prefix."""
        return {
            'epoch': epoch + 1,
            'state_dict': self._unwrap(self.model).state_dict(),
            'optimizer_F': self.optimizer_F.state_dict(),
            'optimizer_M': self.optimizer_M.state_dict(),
            'best_pred': self.best_pred,
        }

    def _next_search_batch(self):
        """Next batch from ``train_loaderB``, restarting it when exhausted.

        A single iterator is reused so the loader (and its workers) is not
        rebuilt on every training step.
        """
        if getattr(self, '_search_iter', None) is None:
            self._search_iter = iter(self.train_loaderB)
        try:
            return next(self._search_iter)
        except StopIteration:
            self._search_iter = iter(self.train_loaderB)
            return next(self._search_iter)

    def _arch_step(self):
        """One architecture-parameter update on a batch from the search split."""
        search = self._next_search_batch()
        input1_search, input2_search, target_search = search[0], search[1], search[2]
        if self.args.cuda:
            input1_search = input1_search.cuda()
            input2_search = input2_search.cuda()
            target_search = target_search.cuda()

        target_search = torch.squeeze(target_search, 1)
        mask_search = (target_search < self.args.max_disp).detach()
        if not mask_search.any():
            # The mean over an empty selection would be NaN.
            return

        self.architect_optimizer_F.zero_grad()
        self.architect_optimizer_M.zero_grad()
        output_search = self.model(input1_search, input2_search)
        arch_loss = F.smooth_l1_loss(output_search[mask_search], target_search[mask_search], reduction='mean')
        arch_loss.backward()
        self.architect_optimizer_F.step()
        self.architect_optimizer_M.step()

    def training(self, epoch):
        """Run one training (and, if enabled, architecture-search) epoch."""
        train_loss = 0.0
        valid_iteration = 0
        self.model.train()
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)
        vis_interval = self._vis_interval(num_img_tr)

        search_arch = epoch >= self.args.alpha_epoch
        if search_arch:
            print("Start searching architecture!...........")

        for i, batch in enumerate(tbar):
            input1, input2, target = batch[0], batch[1], batch[2]
            if self.args.cuda:
                input1 = input1.cuda()
                input2 = input2.cuda()
                target = target.cuda()

            target = torch.squeeze(target, 1)
            # Only pixels with disparity below max_disp contribute.
            mask = (target < self.args.max_disp).detach()
            if not mask.any():
                continue

            self.scheduler(self.optimizer_F, i, epoch, self.best_pred)
            self.scheduler(self.optimizer_M, i, epoch, self.best_pred)
            self.optimizer_F.zero_grad()
            self.optimizer_M.zero_grad()

            output = self.model(input1, input2)
            loss = F.smooth_l1_loss(output[mask], target[mask], reduction='mean')
            loss.backward()
            self.optimizer_F.step()
            self.optimizer_M.step()

            if search_arch:
                self._arch_step()

            train_loss += loss.item()
            valid_iteration += 1
            tbar.set_description('Train loss: %.3f' % (train_loss / valid_iteration))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show ~10 inference results each epoch (only for batches that
            # produced an output in this iteration).
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image_stereo(self.writer, input1, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print("=== Train ===> Epoch :{} Error: {:.4f}".format(epoch, train_loss / max(valid_iteration, 1)))
        print(self._unwrap(self.model).feature.alphas)

        # Save a checkpoint every epoch.
        self.saver.save_checkpoint(self._checkpoint_state(epoch), False,
                                   filename='checkpoint_{}.pth.tar'.format(epoch))

    def validation(self, epoch):
        """Evaluate EPE / error rate on ``val_loader`` and keep the best model."""
        self.model.eval()

        epoch_error = 0.0
        three_px_acc_all = 0.0
        valid_iteration = 0
        tbar = tqdm(self.val_loader, desc='\r')

        for i, batch in enumerate(tbar):
            input1, input2, target = batch[0], batch[1], batch[2]
            if self.args.cuda:
                input1 = input1.cuda()
                input2 = input2.cuda()
                target = target.cuda()

            target = torch.squeeze(target, 1)
            mask = (target < self.args.max_disp).detach()
            if not mask.any():
                continue

            with torch.no_grad():
                output = self.model(input1, input2)
                error, three_px_acc = self._disparity_errors(output, target, mask)

            epoch_error += error
            three_px_acc_all += three_px_acc
            valid_iteration += 1
            print("===> Test({}/{}): Error(EPE): ({:.4f} {:.4f})".format(i, len(self.val_loader), error, three_px_acc))

        if valid_iteration == 0:
            print("===> Test: no valid pixels in validation set")
            return

        avg_epe = epoch_error / valid_iteration
        avg_err_rate = three_px_acc_all / valid_iteration
        self.writer.add_scalar('val/EPE', avg_epe, epoch)
        self.writer.add_scalar('val/D1_all', avg_err_rate, epoch)
        print("===> Test: Avg. Error: ({:.4f} {:.4f})".format(avg_epe, avg_err_rate))

        # Lower EPE is better.
        new_pred = avg_epe
        if new_pred < self.best_pred:
            self.best_pred = new_pred
            self.saver.save_checkpoint(self._checkpoint_state(epoch), True)
class Trainer(object):
    """Trainer for a pairwise ConsistentDeepLab on proxy labels.

    Each sample is an image pair plus a proxy label. The model sees both
    images, and the loss and metrics are computed against the proxy label.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver and record the running process to a timestamped log.
        self.saver = Saver(args)
        sys.stdout = Logger(
            os.path.join(
                self.saver.experiment_dir,
                'log_train-%s.txt' % time.strftime("%Y-%m-%d-%H-%M-%S")))
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)
        if args.dataset == 'pairwise_lits':
            proxy_nclasses = self.nclass = 3
        elif args.dataset == 'pairwise_chaos':
            proxy_nclasses = 2 * self.nclass
        else:
            raise NotImplementedError

        # Define network
        model = ConsistentDeepLab(in_channels=3,
                                  num_classes=proxy_nclasses,
                                  pretrained=args.pretrained,
                                  backbone=args.backbone,
                                  output_stride=args.out_stride,
                                  sync_bn=args.sync_bn,
                                  freeze_bn=args.freeze_bn)

        # get_1x_lr_params at args.lr, get_10x_lr_params at 10x args.lr.
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.Adam(train_params,
                                     weight_decay=args.weight_decay)

        # Define Criterion, optionally with class-balanced weights.
        if args.use_balanced_weights:
            weights = calculate_weigths_labels(args.dataset, self.train_loader,
                                               proxy_nclasses)
        else:
            weights = None

        print("Initializing loss: {}".format(args.loss_type))
        self.criterion = losses.init_loss(args.loss_type, weights=weights)

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self._unwrap(self.model).load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _unwrap(model):
        """Return ``model.module`` for DataParallel-style wrappers, else ``model``."""
        return getattr(model, 'module', model)

    @staticmethod
    def _vis_interval(num_batches, per_epoch=10):
        """Batch interval giving ~``per_epoch`` visualisations (never 0)."""
        return max(1, num_batches // per_epoch)

    def _checkpoint_state(self, epoch):
        """Dict saved after ``epoch``. Weights are stored without a 'module.' prefix."""
        return {
            'epoch': epoch + 1,
            'state_dict': self._unwrap(self.model).state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }

    def training(self, epoch):
        """Run one training epoch over ``self.train_loader``."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_interval = self._vis_interval(num_img_tr)
        num_images = 0
        for i, (sample1, sample2, proxy_label,
                sample_indices) in enumerate(tbar):
            image1, target1 = sample1['image'], sample1['label']
            image2, target2 = sample2['image'], sample2['label']
            if self.args.cuda:
                image1, target1 = image1.cuda(), target1.cuda()
                image2, target2 = image2.cuda(), target2.cuda()
                proxy_label = proxy_label.cuda()
            num_images += image1.shape[0]

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image1, image2)
            loss = self.criterion(output, proxy_label)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Show ~10 inference results each epoch
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                # Concatenate the pair along dim -2 into one image.
                image = torch.cat((image1, image2), dim=-2)
                vis_output, vis_label = output, proxy_label
                if len(proxy_label.shape) > 3:
                    # Multi-channel proxy labels: show the first nclass channels.
                    vis_output = output[:, 0:self.nclass]
                    vis_label = torch.argmax(proxy_label[:, 0:self.nclass],
                                             dim=1)
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, vis_label, vis_output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # Save a checkpoint every epoch when validation is disabled.
            self.saver.save_checkpoint(self._checkpoint_state(epoch), False)

    def validation(self, epoch):
        """Evaluate on ``self.val_loader`` and checkpoint on a new best mIoU."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        val_time = 0
        num_images = 0
        for i, (sample1, sample2, proxy_label,
                sample_indices) in enumerate(tbar):
            image1, target1 = sample1['image'], sample1['label']
            image2, target2 = sample2['image'], sample2['label']
            if self.args.cuda:
                image1, target1 = image1.cuda(), target1.cuda()
                image2, target2 = image2.cuda(), target2.cuda()
                proxy_label = proxy_label.cuda()
            num_images += image1.shape[0]

            with torch.no_grad():
                start = time.time()
                output = self.model(image1, image2, is_val=True)
                end = time.time()
            val_time += end - start
            loss = self.criterion(output, proxy_label)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            proxy_label = proxy_label.cpu().numpy()

            # Add batch sample into evaluator
            if len(proxy_label.shape) > 3:
                pred = np.argmax(pred[:, 0:self.nclass], axis=1)
                proxy_label = np.argmax(proxy_label[:, 0:self.nclass], axis=1)
            else:
                pred = np.argmax(pred, axis=1)
            self.evaluator.add_batch(proxy_label, pred)

            if self.args.save_predict:
                self.saver.save_predict_mask(
                    pred, sample_indices, self.val_loader.dataset.data1_files)

        print("Val time: {}".format(val_time))
        print("Total paramerters: {}".format(
            sum(x.numel() for x in self.model.parameters())))
        if self.args.save_predict:
            # File base names without extension, in dataset order.
            namelist = [os.path.split(fname)[1].split('.')[0]
                        for fname in self.val_loader.dataset.data1_files]
            with gzip.open(
                    os.path.join(self.saver.save_dir, 'namelist.pkl.gz'),
                    'wb') as file:
                pickle.dump(namelist, file, protocol=-1)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        dice = self.evaluator.Dice()
        # Index 2 only exists with three or more classes.
        if len(dice) > 2:
            self.writer.add_scalar('val/Dice_2', dice[2], epoch)
        print("Dice:{}".format(dice))

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self.saver.save_checkpoint(self._checkpoint_state(epoch), True)
Пример #10
0
class Trainer(object):
    """Train a DeepLab on images stacked with an approximate depth map.

    A frozen single-channel DeepLab (``model_aprox_depth``, weights loaded from
    ``args.ckpt``) predicts a map for every image.  After ``self.infer.sigmoid``
    that map is concatenated to the image as an extra channel.  A 4->3 channel
    ``input_conv`` then maps the stacked tensor back to 3 channels for the
    trained DeepLab.  This assumes the image has 3 channels; confirm against
    the data loader.

    Loss types containing ``'depth'`` are evaluated with depth metrics, and the
    best checkpoint is the one with the lowest RMSE.  All other loss types are
    evaluated as segmentation, and the best checkpoint has the highest mIoU.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        # Some depth losses need a different number of output channels than
        # the dataset's class count.
        if args.loss_type == 'depth_loss_two_distributions':
            self.nclass = args.num_class + args.num_class2 + 1
        if args.loss_type == 'depth_avg_sigmoid_class':
            self.nclass = args.num_class + args.num_class2

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        print("\nDefine models...\n")
        # Frozen helper network that produces the approximate-depth channel.
        self.model_aprox_depth = DeepLab(num_classes=1,
                                         backbone=args.backbone,
                                         output_stride=args.out_stride,
                                         sync_bn=args.sync_bn,
                                         freeze_bn=args.freeze_bn)

        # Maps image + approximate depth (4 channels) back to 3 channels.
        self.input_conv = nn.Conv2d(4, 3, 3, padding=1)
        # Using cuda
        if args.cuda:
            self.model_aprox_depth = torch.nn.DataParallel(self.model_aprox_depth, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model_aprox_depth)
            self.model_aprox_depth = self.model_aprox_depth.cuda()
            self.input_conv = self.input_conv.cuda()

        print("\nLoad checkpoints...\n")
        if not args.cuda:
            ckpt_aprox_depth = torch.load(args.ckpt, map_location='cpu')
            self.model_aprox_depth.load_state_dict(ckpt_aprox_depth['state_dict'])
        else:
            ckpt_aprox_depth = torch.load(args.ckpt)
            self.model_aprox_depth.module.load_state_dict(ckpt_aprox_depth['state_dict'])

        # NOTE(review): input_conv becomes part of self.model (nn.Sequential at
        # the end of __init__), but its parameters are in no optimizer group.
        # It therefore keeps its random initialization.  Adding it would change
        # the optimizer's param groups and the compatibility of saved optimizer
        # state, so it is left as-is.  Confirm whether this is intended.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.Adam(train_params, args.lr)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        # training() and validation() use self.infer.sigmoid on the output of
        # model_aprox_depth for every loss type, so it must exist even when
        # the task itself is segmentation.
        self.infer = DepthLosses(weight=weight,
                                 cuda=args.cuda,
                                 min_depth=args.min_depth,
                                 max_depth=args.max_depth,
                                 num_class=args.num_class,
                                 cut_point=args.cut_point,
                                 num_class2=args.num_class2)
        if 'depth' in args.loss_type:
            self.criterion = DepthLosses(weight=weight,
                                         cuda=args.cuda,
                                         min_depth=args.min_depth,
                                         max_depth=args.max_depth,
                                         num_class=args.num_class,
                                         cut_point=args.cut_point,
                                         num_class2=args.num_class2).build_loss(mode=args.loss_type)
            self.evaluator_depth = EvaluatorDepth(args.batch_size)
        else:
            self.criterion = SegmentationLosses(cuda=args.cuda, weight=weight).build_loss(mode=args.loss_type)
            self.evaluator = Evaluator(self.nclass)

        self.model, self.optimizer = model, optimizer

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Best validation score so far.  Lower RMSE is better for depth losses;
        # higher mIoU is better for segmentation.  Both need a starting value,
        # because training() reads self.best_pred even when not resuming.
        if 'depth' in args.loss_type:
            self.best_pred = 1e6
        else:
            self.best_pred = 0.0

        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            if not args.cuda:
                checkpoint = torch.load(args.resume, map_location='cpu')
            else:
                checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Drop the last two state-dict entries (presumably the final
            # classifier's weight and bias) so that a checkpoint with a
            # different output size still loads.  strict=False tolerates the
            # missing keys.
            state_dict = checkpoint['state_dict']
            state_dict.popitem(last=True)
            state_dict.popitem(last=True)
            if args.cuda:
                self.model.module.load_state_dict(state_dict, strict=False)
            else:
                self.model.load_state_dict(state_dict, strict=False)
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            if 'depth' in args.loss_type:
                self.best_pred = 1e6
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        # add input layer to the model
        # NOTE(review): under cuda the DeepLab is already a DataParallel, so
        # this nests DataParallel inside DataParallel.  That likely fails with
        # more than one GPU: replicas of the inner module are not on
        # device_ids[0].  It is left unchanged because it also fixes the key
        # layout of saved checkpoints.
        self.model = nn.Sequential(
            self.input_conv,
            self.model
        )
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

    def training(self, epoch):
        """Run one training epoch.

        When ``args.no_val`` is set, this also saves a checkpoint every epoch.
        """
        train_loss = 0.0
        self.model.train()
        self.model_aprox_depth.eval()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Visualize about 10 times per epoch.  num_img_tr // 10 is 0 for
        # loaders with fewer than 10 batches, which would divide by zero.
        vis_interval = max(num_img_tr // 10, 1)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.dataset == 'apollo_seg' or self.args.dataset == 'farsight_seg':
                # Binarize the label map at the cut point.
                target[target <= self.args.cut_point] = 0
                target[target > self.args.cut_point] = 1
            if image.shape[0] == 1:
                # Duplicate single-sample batches (presumably so that BatchNorm
                # in train mode sees more than one sample).
                target = torch.cat([target, target], dim=0)
                image = torch.cat([image, image], dim=0)
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            # model_aprox_depth is frozen: it is in eval mode and not in the
            # optimizer.  Tracking gradients through it would only waste memory
            # and accumulate unused .grad tensors on its parameters.
            with torch.no_grad():
                aprox_depth = self.model_aprox_depth(image)
                aprox_depth = self.infer.sigmoid(aprox_depth)
            net_input = torch.cat([image, aprox_depth], dim=1)
            output = self.model(net_input)
            if self.args.loss_type == 'depth_sigmoid_loss_inverse':
                loss = self.criterion(output, target, inverse=True)
            else:
                loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                # Change NaN values to zero for display (avoids a TensorBoard warning).
                target[torch.isnan(target)] = 0
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step,
                                             n_class=self.args.num_class)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        """Evaluate on the validation set and checkpoint when the score improves.

        Depth losses are judged by RMSE (lower is better); segmentation losses
        are judged by mIoU (higher is better).
        """
        self.model.eval()
        self.model_aprox_depth.eval()
        if 'depth' in self.args.loss_type:
            self.evaluator_depth.reset()
        else:
            softmax = nn.Softmax(1)
            self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                aprox_depth = self.model_aprox_depth(image)
                aprox_depth = self.infer.sigmoid(aprox_depth)
                net_input = torch.cat([image, aprox_depth], dim=1)
                output = self.model(net_input)
                # Call the loss exactly as training() does, so that the train
                # and validation losses are comparable.
                if self.args.loss_type == 'depth_sigmoid_loss_inverse':
                    loss = self.criterion(output, target, inverse=True)
                else:
                    loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Val loss: %.3f' % (test_loss / (i + 1)))
            if 'depth' in self.args.loss_type:
                if self.args.loss_type == 'depth_loss':
                    pred = self.infer.pred_to_continous_depth(output)
                elif self.args.loss_type == 'depth_avg_sigmoid_class':
                    pred = self.infer.pred_to_continous_depth_avg(output)
                elif self.args.loss_type == 'depth_loss_combination':
                    pred = self.infer.pred_to_continous_combination(output)
                elif self.args.loss_type == 'depth_loss_two_distributions':
                    pred = self.infer.pred_to_continous_depth_two_distributions(output)
                elif 'depth_sigmoid_loss' in self.args.loss_type:
                    # NOTE(review): the '_inverse' variant also ends up here;
                    # confirm depth01_to_depth is correct for inverse targets.
                    output = self.infer.sigmoid(output.squeeze(1))
                    pred = self.infer.depth01_to_depth(output)
                else:
                    raise ValueError("Unsupported depth loss type: {}".format(self.args.loss_type))
                # Add batch sample into evaluator
                self.evaluator_depth.evaluateError(pred, target)
            else:
                output = softmax(output)
                pred = output.data.cpu().numpy()
                target = target.cpu().numpy()
                pred = np.argmax(pred, axis=1)
                # Add batch sample into evaluator
                self.evaluator.add_batch(target, pred)
        if 'depth' in self.args.loss_type:
            # Fast test during the training
            MSE = self.evaluator_depth.averageError['MSE']
            RMSE = self.evaluator_depth.averageError['RMSE']
            ABS_REL = self.evaluator_depth.averageError['ABS_REL']
            LG10 = self.evaluator_depth.averageError['LG10']
            MAE = self.evaluator_depth.averageError['MAE']
            DELTA1 = self.evaluator_depth.averageError['DELTA1']
            DELTA2 = self.evaluator_depth.averageError['DELTA2']
            DELTA3 = self.evaluator_depth.averageError['DELTA3']

            self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
            self.writer.add_scalar('val/MSE', MSE, epoch)
            self.writer.add_scalar('val/RMSE', RMSE, epoch)
            self.writer.add_scalar('val/ABS_REL', ABS_REL, epoch)
            self.writer.add_scalar('val/LG10', LG10, epoch)

            self.writer.add_scalar('val/MAE', MAE, epoch)
            self.writer.add_scalar('val/DELTA1', DELTA1, epoch)
            self.writer.add_scalar('val/DELTA2', DELTA2, epoch)
            self.writer.add_scalar('val/DELTA3', DELTA3, epoch)

            print('Validation:')
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
            print(
                "MSE:{}, RMSE:{}, ABS_REL:{}, LG10: {}\nMAE:{}, DELTA1:{}, DELTA2:{}, DELTA3: {}".format(
                    MSE, RMSE, ABS_REL, LG10, MAE, DELTA1, DELTA2, DELTA3))
            new_pred = RMSE
            if new_pred < self.best_pred:
                is_best = True
                self.best_pred = new_pred
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    # self.model is a DataParallel only when args.cuda; unwrap
                    # it only in that case (plain modules have no .module).
                    'state_dict': getattr(self.model, 'module', self.model).state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
        else:
            # Fast test during the training
            Acc = self.evaluator.Pixel_Accuracy()
            Acc_class = self.evaluator.Pixel_Accuracy_Class()
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
            self.writer.add_scalar('val/mIoU', mIoU, epoch)
            self.writer.add_scalar('val/Acc', Acc, epoch)
            self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
            self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
            print('Validation:')
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
            print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))

            new_pred = mIoU
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

        print('Loss: %.3f' % test_loss)
Пример #11
0
class Trainer(object):
    """Train a single-output DeepLab with BCE loss on remote-sensing patches.

    Patch geometry is read from the parameter file ``args.para``.  The patch
    dataset is split 90/10 into training and validation subsets.  Progress is
    appended to ``<args.checkname>.train_out.txt``.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Patch geometry comes from the saved parameter file.
        parameters.set_saved_parafile_path(args.para)
        patch_w = parameters.get_digit_parameters("", "train_patch_width", None, 'int')
        patch_h = parameters.get_digit_parameters("", "train_patch_height", None, 'int')
        overlay_x = parameters.get_digit_parameters("", "train_pixel_overlay_x", None, 'int')
        overlay_y = parameters.get_digit_parameters("", "train_pixel_overlay_y", None, 'int')
        # Read from the parameter file but currently unused.
        crop_height = parameters.get_digit_parameters("", "crop_height", None, 'int')
        crop_width = parameters.get_digit_parameters("", "crop_width", None, 'int')

        # Define Dataloader
        dataset = RemoteSensingImg(args.dataroot, args.list, patch_w, patch_h, overlay_x, overlay_y)

        # 90/10 train/validation split.
        # NOTE(review): random_split is unseeded, so a resumed run gets a
        # different split and may validate on patches it has trained on.
        train_length = int(len(dataset) * 0.9)
        validation_length = len(dataset) - train_length
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            dataset, (train_length, validation_length))
        print("len of train dataset is %d and val dataset is %d and total datalen is %d"
              % (len(self.train_dataset), len(self.val_dataset), len(dataset)))
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=args.batch_size, num_workers=args.workers,
            shuffle=True, drop_last=True)
        # Validation visits every sample in a fixed order.  Dropping the last
        # partial batch would silently shrink the validation set; for a split
        # smaller than one batch it would leave it empty (and
        # validation() would then divide by zero).
        self.val_loader = torch.utils.data.DataLoader(
            self.val_dataset, batch_size=args.batch_size, num_workers=args.workers,
            shuffle=False, drop_last=False)
        print("len of train loader is %d and val loader is %d" % (len(self.train_loader), len(self.val_loader)))

        # Define network
        model = DeepLab(num_classes=1,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # NOTE(review): nn.BCELoss requires inputs in [0, 1].  Confirm that this
        # DeepLab's single-channel output is already passed through a sigmoid;
        # otherwise nn.BCEWithLogitsLoss is the appropriate loss.
        self.criterion = nn.BCELoss()
        if args.cuda:
            self.criterion = self.criterion.cuda()

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(2)
        # Define lr scheduler
        print("lenght of train_loader is %d" % (len(self.train_loader)))
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda.  Without device_ids, DataParallel uses every visible GPU,
        # so args.gpu_ids is not consulted here.
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            if args.cuda:
                checkpoint = torch.load(args.resume)
            else:
                # A checkpoint saved on GPU must be remapped to load on a CPU-only host.
                checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            self._bare_model().load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {}) with best mIoU {}"
                  .format(args.resume, checkpoint['epoch'], checkpoint['best_pred']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
            self.best_pred = 0

    def _bare_model(self):
        """Return the underlying network, unwrapping DataParallel (used when args.cuda)."""
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    def training(self, epoch):
        """Run one training epoch, printing per-iteration loss and timing.

        Appends the epoch summary to ``<args.checkname>.train_out.txt``.  When
        ``args.no_val`` is set, it also saves a checkpoint every epoch.
        """
        train_start_time = time.time()
        train_loss = 0.0
        self.model.train()
        num_img_tr = len(self.train_loader)
        print("start training at epoch %d, with the training length of %d" % (epoch, num_img_tr))
        for i, (image, target) in enumerate(self.train_loader):
            start_time = time.time()
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            end_time = time.time()
            global_step = i + num_img_tr * epoch
            self.writer.add_scalar('train/total_loss_iter', loss.item(), global_step)
            print('[The loss for iteration %d is %.3f and the time used is %.3f]'
                  % (global_step, loss.item(), end_time - start_time))

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        train_end_time = time.time()
        num_images = i * self.args.batch_size + image.data.shape[0]
        mean_loss = train_loss / len(self.train_loader)
        print('[Epoch: %d, numImages: %5d, time used : %.3f hour]'
              % (epoch, num_images, (train_end_time - train_start_time) / 3600))
        print('Loss: %.3f' % mean_loss)

        with open(self.args.checkname + ".train_out.txt", 'a') as log:
            log.write('[Epoch: %d, numImages: %5d]' % (epoch, num_images) + '\n')
            log.write('Loss: %.3f' % mean_loss + '\n')
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self._bare_model().state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        """Evaluate on the validation split; save a checkpoint on a new best mIoU.

        Returns:
            False when a new best model was saved this epoch, True otherwise.
            The caller presumably uses this as a "no improvement" flag.
        """
        time_val_start = time.time()
        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0
        num_batches = len(self.val_loader)
        for i, (image, target) in enumerate(self.val_loader):
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            # NOTE(review): pred is the raw network output and is not
            # thresholded to 0/1 labels.  Confirm that Evaluator.add_batch
            # accepts such values.
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            print("validate on the %d patch of total %d patch" % (i, num_batches))
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        time_val_end = time.time()
        mean_loss = test_loss / num_batches
        metrics_message = "Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d, time used: %.3f hour]'
              % (epoch, num_batches, (time_val_end - time_val_start) / 3600))
        print(metrics_message)
        print('Validation Loss: %.3f' % mean_loss)

        with open(self.args.checkname + ".train_out.txt", 'a') as log:
            log.write('Validation:' + '\n')
            log.write(metrics_message + '\n')
            log.write('Validation Loss: %.3f' % mean_loss + '\n')
        new_pred = mIoU

        if new_pred > self.best_pred:
            print("saveing model")
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self._bare_model().state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best, self.args.checkname)
            return False
        else:
            return True
Пример #12
0
class MyTrainer(object):
    """Train / fine-tune a segmentation network on the fashion datasets.

    ``__init__`` builds the saver, TensorBoard writer, data loaders, model,
    optimizer, loss, evaluator and LR scheduler from ``args``.
    ``train_loop`` runs ``training`` every epoch, ``validation`` every
    ``args.eval_interval`` epochs (unless ``args.no_val``), and finally
    ``visulize_validation``.
    """

    # ``args.dataset`` -> (``type`` passed to ``FashionDataset``, expected class count)
    _FASHION_DATASETS = {
        'fashion_person': ('person', 2),
        'fashion_clothes': ('clothes', 7),
    }

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader (sets train_loader / val_loader / test_loader / nclass)
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self._build_loaders(args, kwargs)

        self.best_pred = 0.0

        # Define network and optimizer
        model, optimizer = self._build_model_and_optimizer(args)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
            print("weight is {}".format(weight))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            # TODO, ADD PARALLEL SUPPORT (NEED SYNC BATCH)
            self.model = self.model.cuda()

        # Epoch counting always restarts at 0. Any ``start_epoch`` taken from
        # a checkpoint loaded via ``args.resume`` is overridden here, because
        # that checkpoint is used as pretrained weights for fine-tuning.
        args.start_epoch = 0

    def _build_loaders(self, args, kwargs):
        """Create the fashion train/val loaders and set ``self.nclass``.

        Args:
            args: experiment options; reads ``dataset`` and ``batch_size``.
            kwargs: extra keyword arguments forwarded to ``DataLoader``.

        Raises:
            NotImplementedError: if ``args.dataset`` is not a supported
                fashion dataset.
        """
        if args.dataset not in self._FASHION_DATASETS:
            raise NotImplementedError(
                "Dataset '{}' is not supported".format(args.dataset))
        subset, expected_nclass = self._FASHION_DATASETS[args.dataset]

        root = Path.db_root_dir(args.dataset)
        train_set = fashion.FashionDataset(args, root, mode='train', type=subset)
        val_set = fashion.FashionDataset(args, root, mode='test', type=subset)
        self.nclass = train_set.nclass

        print("Train size {}, val size {}".format(len(train_set), len(val_set)))

        self.train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size,
                                       shuffle=True, **kwargs)
        self.val_loader = DataLoader(dataset=val_set, batch_size=args.batch_size,
                                     shuffle=False, **kwargs)
        self.test_loader = None

        assert self.nclass == expected_nclass

    def _build_model_and_optimizer(self, args):
        """Return ``(model, optimizer)`` for the network named by ``args.model``.

        Raises:
            RuntimeError: if ``args.resume`` is set but the file does not exist.
            NotImplementedError: for an unknown ``args.model``, or an unknown
                ``args.ft_type`` when ``args.model == 'deeplabv3+'``.
        """
        if args.model == 'deeplabv3+':
            # Built without ``num_classes`` (the original network), so a
            # pretrained checkpoint can be loaded. The head is then adapted
            # to ``self.nclass`` according to ``args.ft_type``.
            model = DeepLab(backbone=args.backbone,
                            output_stride=args.out_stride,
                            sync_bn=args.sync_bn,
                            freeze_bn=args.freeze_bn)

            # Loading pretrained VOC model
            if args.resume is not None:
                if not os.path.isfile(args.resume):
                    raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
                if args.cuda:
                    checkpoint = torch.load(args.resume)
                else:
                    checkpoint = torch.load(args.resume, map_location='cpu')
                args.start_epoch = checkpoint['epoch']

                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))

            # Freeze the backbone
            if args.freeze_backbone:
                set_parameter_requires_grad(model.backbone, False)

            # Different types of fine-tuning
            if args.ft_type == 'decoder':
                # Freeze the network, then attach a freshly built decoder.
                set_parameter_requires_grad(model, False)
                model.decoder = build_decoder(self.nclass, 'resnet', nn.BatchNorm2d)
            elif args.ft_type in ('last_layer', 'all'):
                if args.ft_type == 'last_layer':
                    # Freeze the network before swapping in the new output conv.
                    set_parameter_requires_grad(model, False)
                # Reset last layer, to generate output we want
                model.decoder.last_conv[8] = nn.Conv2d(in_channels=256, out_channels=self.nclass, kernel_size=1)
                model.decoder.last_conv[8].reset_parameters()
            else:
                raise NotImplementedError(
                    "Fine-tuning type '{}' is not supported".format(args.ft_type))

            train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                            {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

            # Define Optimizer
            optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                        weight_decay=args.weight_decay, nesterov=args.nesterov)

        elif args.model == "unet":
            model = UNet(num_categories=self.nclass, num_filters=args.num_filters)
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                         weight_decay=args.weight_decay)

        elif args.model == 'mydeeplab':
            model = My_DeepLab(num_classes=self.nclass, in_channels=3)
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                                        weight_decay=args.weight_decay, nesterov=args.nesterov)

        else:
            raise NotImplementedError("Model '{}' is not supported".format(args.model))

        return model, optimizer

    @staticmethod
    def _vis_interval(num_batches):
        """Return the number of batches between TensorBoard image dumps.

        Aims for about 10 dumps per epoch. The result is never 0, because
        ``num_batches // 10`` alone is 0 for loaders with fewer than 10
        batches and would make ``i % interval`` raise ZeroDivisionError.
        """
        return max(1, num_batches // 10)

    def training(self, epoch):
        """Run one training epoch and log losses and sample images to TensorBoard."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_every = self._vis_interval(num_img_tr)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % vis_every == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        """Evaluate on ``val_loader``, log metrics, and checkpoint when mIoU improves."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def _final_validation(self, on_batch):
        """Run one pass over ``val_loader`` and print the final metrics.

        Args:
            on_batch: callable ``on_batch(image, target, output, i)``. It is
                invoked after each forward pass, while the tensors are still
                on the compute device.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)

            on_batch(image, target, output, i)

            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Visualizing:')
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target.cpu().numpy(), pred)

        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Final Validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

    def visulize_validation(self):
        """Log prediction/ground-truth images for every val batch, then print metrics."""
        self._final_validation(
            lambda image, target, output, i: self.summary.visualize_pregt(
                self.writer, self.args.dataset, image, target, output, i))

    def output_validation(self):
        """Save predictions for every val batch via ``summary.save_pred``, then print metrics."""
        self._final_validation(
            lambda image, target, output, i: self.summary.save_pred(
                self.args.dataset, output, i))

    def _load_model(self, path):
        """Load ``checkpoint['state_dict']`` from ``path`` into ``self.model``."""
        if self.args.cuda:
            checkpoint = torch.load(path)
        else:
            checkpoint = torch.load(path, map_location='cpu')

        self.model.load_state_dict(checkpoint['state_dict'])

    def train_loop(self):
        """Train for the configured epochs, then always run the final visualized validation.

        Ctrl-C stops training early. The ``finally`` clause still runs the
        final validation and closes the writer.
        """
        try:
            for epoch in range(self.args.start_epoch, self.args.epochs):
                self.training(epoch)
                if not self.args.no_val and epoch % self.args.eval_interval == (self.args.eval_interval - 1):
                    self.validation(epoch)
        except KeyboardInterrupt:
            print('Early Stopping')
        finally:
            self.visulize_validation()
            self.writer.close()
Пример #13
0
class Trainer(object):
    """Train the click branch of a ``FusionNet`` on top of a frozen sbox network.

    A ``DeepLabX`` model ("sbox") is loaded from pretrained weights, frozen
    and kept in eval mode. Only ``model.click_net`` parameters are given to
    the optimizer.
    """

    # Pretrained sbox weights used when ``args.sbox_weights`` is absent or empty.
    DEFAULT_SBOX_WEIGHTS = 'run/sbox_513_8925.pth.tar'

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': False}
        if args.dataset == 'click':
            extract_hard_example(args, batch_size=32, recal=False)
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        sbox = DeepLabX(pretrain=False)
        sbox_path = getattr(args, 'sbox_weights', None) or self.DEFAULT_SBOX_WEIGHTS
        # Load on the CPU. load_state_dict copies the values into sbox's own
        # parameters, and the whole model is moved to the GPU below when
        # args.cuda is set. A hard-coded 'cuda:0' broke CPU-only runs.
        sbox.load_state_dict(
            torch.load(sbox_path, map_location='cpu')['state_dict'])
        click = ClickNet()
        model = FusionNet(sbox=sbox, click=click, pos_limit=2, neg_limit=2)
        # Freeze the sbox branch. training() re-applies eval() after model.train().
        model.sbox_net.eval()
        for para in model.sbox_net.parameters():
            para.requires_grad = False

        # Only the click branch is optimized.
        train_params = [
            {
                'params': model.click_net.parameters(),
                'lr': args.lr
            },
        ]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights (cached in a .npy file when available)
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # map_location=None keeps torch.load's default placement on GPU runs.
            checkpoint = torch.load(args.resume,
                                    map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            # self.model is never wrapped in DataParallel here, and checkpoints
            # are saved from self.model.state_dict(), so load directly.
            # self.model.module does not exist and would raise AttributeError.
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _vis_interval(num_batches):
        """Return the number of batches between TensorBoard image dumps.

        Aims for about 10 dumps per epoch. The result is never 0, so loaders
        with fewer than 10 batches do not hit ZeroDivisionError.
        """
        return max(1, num_batches // 10)

    def training(self, epoch):
        """Run one training epoch over ``train_loader`` and log to TensorBoard."""
        train_loss = 0.0
        self.model.train()
        # Keep the frozen sbox branch in eval mode despite model.train().
        self.model.sbox_net.eval()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_every = self._vis_interval(num_img_tr)
        for i, sample in enumerate(tbar):
            image, gt = sample['crop_image'], sample['crop_gt']
            if self.args.cuda:
                image, gt = image.cuda(), gt.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            sbox_pred, click_pred, sum_pred = self.model(image, crop_gt=gt)
            sum_pred = F.interpolate(sum_pred,
                                     size=gt.size()[-2:],
                                     align_corners=True,
                                     mode='bilinear')
            sbox_pred = F.interpolate(sbox_pred,
                                      size=gt.size()[-2:],
                                      align_corners=True,
                                      mode='bilinear')
            # Only the fused prediction is supervised. An extra
            # ``self.criterion(sbox_pred, gt)`` term was disabled here.
            loss1 = self.criterion(sum_pred, gt)

            loss1.backward()
            self.optimizer.step()
            total_loss = loss1.item()
            train_loss += total_loss
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_steps', total_loss,
                                   i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            if i % vis_every == 0:
                global_step = i + num_img_tr * epoch
                # ``range``/``value_range`` only matters when normalize=True,
                # and ``range`` was removed in newer torchvision, so omit it.
                grid_image = make_grid(decode_seg_map_sequence(
                    torch.max(sbox_pred[:3], 1)[1].detach().cpu().numpy(),
                    dataset=self.args.dataset),
                                       3,
                                       normalize=False)
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, sample['crop_gt'],
                                             sum_pred, global_step)
                self.writer.add_image('sbox_pred', grid_image, global_step)

        self.writer.add_scalar('train/total_epochs', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        """Evaluate on ``val_loader``, log metrics, and checkpoint (prefix 'click') on a new best mIoU."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, gt = sample['crop_image'], sample['crop_gt']
            if self.args.cuda:
                image, gt = image.cuda(), gt.cuda()
            with torch.no_grad():
                # NOTE: an interactive ``self.model.click_eval(image, gt)``
                # evaluation (with click counting) was disabled here.
                sbox_pred, click_pred, sum_pred = self.model(image, crop_gt=gt)
            sum_pred = F.interpolate(sum_pred,
                                     size=gt.size()[-2:],
                                     align_corners=True,
                                     mode='bilinear')
            loss1 = self.criterion(sum_pred, gt)
            total_loss = loss1.item()
            test_loss += total_loss
            pred = sum_pred.data.cpu().numpy()
            target = gt.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_epochs', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                },
                is_best,
                prefix='click')
Пример #14
0
class operater(object):
    """Student/teacher trainer for source -> target segmentation.

    Each training step combines two parts.

    - A supervised focal + Lovasz loss on a source batch (student model).
    - On a target batch, a consistency loss between student and teacher
      predictions on independently noised inputs, masked by the teacher's
      spatial attention, plus a Lovasz loss of the student prediction
      against ``trg_y``.

    ``optimizer`` and ``teacher_optimizer`` are both stepped every iteration.
    Validation uses the teacher model.
    """

    def __init__(self, args, student_model, teacher_model, src_loader,
                 trg_loader, val_loader, optimizer, teacher_optimizer):

        self.args = args
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.src_loader = src_loader
        self.trg_loader = trg_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.teacher_optimizer = teacher_optimizer
        # Define Evaluator
        self.evaluator = Evaluator(args.nclass)
        # Define lr scheduler (fine-tuning schedule). It is stepped once per
        # training() call, so the student lr halves after 20 epochs.
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                              milestones=[20],
                                                              gamma=0.5)
        self.best_pred = 0
        self.init_weight = 0.98
        # Define Saver
        self.saver = Saver(self.args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

    @staticmethod
    def _next_batch(iterator, loader):
        """Return ``(batch, iterator)``, restarting ``loader`` when ``iterator`` is exhausted.

        This lets the shorter of the source/target loaders cycle while the
        longer one defines the epoch length. An empty ``loader`` still raises
        StopIteration.
        """
        try:
            return next(iterator), iterator
        except StopIteration:
            iterator = iter(loader)
            return next(iterator), iterator

    def training(self, epoch, args):
        """Run one epoch of ``max(len(src_loader), len(trg_loader))`` iterations.

        Args:
            epoch: current epoch index, used for TensorBoard global steps.
            args: options; ``args.attention_threshold`` binarizes the
                teacher's spatial mask.
        """
        train_loss = 0.0
        self.student_model.train()
        self.teacher_model.train()
        num_src = len(self.src_loader)
        num_trg = len(self.trg_loader)
        num_itr = max(num_src, num_trg)
        # 1-based iteration counter: ``i`` is the number of batches processed so far.
        tbar = tqdm(range(1, num_itr + 1))
        print('Learning rate:', self.optimizer.param_groups[0]['lr'])
        iter_src = iter(self.src_loader)
        iter_trg = iter(self.trg_loader)
        noise = self.args.noise
        for i in tbar:

            (src_x1, src_x2, src_y, _), iter_src = self._next_batch(
                iter_src, self.src_loader)
            (trg_x1, trg_x2, trg_y, _), iter_trg = self._next_batch(
                iter_trg, self.trg_loader)

            # NOTE(review): labels are not moved to the GPU; presumably the
            # loss layers handle device placement -- confirm.
            if self.args.cuda:
                src_x1, src_x2 = src_x1.cuda(), src_x2.cuda()
                trg_x1, trg_x2 = trg_x1.cuda(), trg_x2.cuda()

            self.optimizer.zero_grad()

            # train with source: supervised focal + Lovasz loss
            _, _, src_output = self.student_model(src_x1, src_x2)
            src_output = F.softmax(src_output, dim=1)
            loss_focal = FocalLossLayer(src_output, src_y)
            loss_val_lovasz = LovaszLossLayer(src_output, src_y)
            loss_su = loss_val_lovasz + loss_focal

            # train with target: independent Gaussian noise for the student
            # and teacher views, generated on the inputs' own device.
            trg_x1_s = trg_x1 + torch.randn_like(trg_x1) * noise
            trg_x1_t = trg_x1 + torch.randn_like(trg_x1) * noise
            trg_x2_s = trg_x2 + torch.randn_like(trg_x2) * noise
            trg_x2_t = trg_x2 + torch.randn_like(trg_x2) * noise

            _, _, trg_predict_s = self.student_model(trg_x1_s, trg_x2_s)

            _, spatial_mask_prob, trg_predict_t = self.teacher_model(
                trg_x1_t, trg_x2_t)

            trg_predict_s = F.softmax(trg_predict_s, dim=1)
            trg_predict_t = F.softmax(trg_predict_t, dim=1)

            loss_tes_lovasz = LovaszLossLayer(trg_predict_s, trg_y)

            # Binary spatial mask from the teacher's attention probabilities.
            spatial_mask = (spatial_mask_prob > args.attention_threshold).float()

            num_pixel = spatial_mask.shape[0] * spatial_mask.shape[
                -2] * spatial_mask.shape[-1]
            mask_num_rate = torch.sum(spatial_mask).float() / num_pixel

            # Restrict both predictions to the masked pixels.
            trg_predict_s = trg_predict_s * spatial_mask
            trg_predict_t = trg_predict_t * spatial_mask

            # consistency loss
            loss_con = ConsistencyLossLayer(trg_predict_s, trg_predict_t)

            if mask_num_rate == 0.:
                # No pixel passed the threshold, so drop the consistency term.
                loss_con = torch.zeros((), device=trg_predict_s.device)

            loss = (loss_su + self.args.con_weight * loss_con
                    + self.args.teslab_weight * loss_tes_lovasz)

            step = i + num_itr * epoch
            self.writer.add_scalar('train/focal_loss_iter', loss_focal.item(),
                                   step)
            self.writer.add_scalar('train/supervised_loss_iter',
                                   loss_su.item(), step)
            self.writer.add_scalar('train/consistency_loss_iter',
                                   loss_con.item(), step)
            self.writer.add_scalar('train/teslab_loss_iter',
                                   loss_tes_lovasz.item(), step)

            loss.backward()
            self.optimizer.step()
            self.teacher_optimizer.step()

            train_loss += loss.item()
            # ``i`` is 1-based, so it already equals the number of batches seen.
            tbar.set_description('Train loss: %.3f' % (train_loss / i))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), step)

            #Show 10 * 3 inference results each epoch
            if i % 10 == 0:
                global_step = step
                if self.args.oly_s1 and not self.args.oly_s2:
                    self.summary.visualize_image(
                        self.writer, self.args.dataset, src_x1[:, [0], :, :],
                        trg_x1[:, [0], :, :], src_y, src_output, trg_predict_s,
                        trg_predict_t, trg_y, global_step)
                elif not self.args.oly_s1:
                    if self.args.rgb:
                        self.summary.visualize_image(self.writer,
                                                     self.args.dataset, src_x2,
                                                     trg_x2, src_y, src_output,
                                                     trg_predict_s,
                                                     trg_predict_t, trg_y,
                                                     global_step)
                    else:
                        self.summary.visualize_image(
                            self.writer, self.args.dataset,
                            src_x2[:, [2, 1, 0], :, :],
                            trg_x2[:, [2, 1, 0], :, :], src_y, src_output,
                            trg_predict_s, trg_predict_t, trg_y, global_step)
                else:
                    raise NotImplementedError

        self.scheduler.step()
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        # ``i`` is 1-based: (i - 1) full batches plus the final batch.
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, (i - 1) * self.args.batch_size + src_y.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            # NOTE(review): ``.module`` assumes both models are wrapped
            # (e.g. in DataParallel) -- confirm against the caller.
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'student_state_dict':
                    self.student_model.module.state_dict(),
                    'teacher_state_dict':
                    self.teacher_model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch, args):
        """Evaluate the teacher on ``val_loader`` and checkpoint when average class accuracy improves."""
        self.teacher_model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        for i, (x1, x2, y, index) in enumerate(tbar):

            if self.args.cuda:
                x1, x2 = x1.cuda(), x2.cuda()
            with torch.no_grad():
                _, _, output = self.teacher_model(x1, x2)

            output = F.softmax(output, dim=1)
            pred = output.data.cpu().numpy()
            # Channel 0 of ``y`` is used as the label map.
            target = y[:, 0, :, :].cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        OA = self.evaluator.Pixel_Accuracy()
        AA = self.evaluator.val_Pixel_Accuracy_Class()
        self.writer.add_scalar('val/OA', OA, epoch)
        self.writer.add_scalar('val/AA', AA, epoch)

        print('AVERAGE ACCURACY:', AA)

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + y.data.shape[0]))

        new_pred = AA
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'student_state_dict':
                    self.student_model.module.state_dict(),
                    'teacher_state_dict':
                    self.teacher_model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
Пример #15
0
class Trainer(object):
    """Train and validate a DeepLab semantic-segmentation model.

    ``__init__`` builds the saver, TensorBoard writer, data loaders, model,
    SGD optimizer, loss, evaluator and LR scheduler from ``args`` and, if
    ``args.resume`` is set, restores a checkpoint. ``training`` and
    ``validation`` each run one epoch.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Two parameter groups: one trained at args.lr and one at 10 * args.lr
        # (membership is decided by DeepLab.get_1x_lr_params /
        # get_10x_lr_params).
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion, optionally with class-balanced weights that are
        # read from a cached .npy file or computed from the train loader.
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self._bare_model().load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def _bare_model(self):
        """Return the wrapped network if ``self.model`` is ``nn.DataParallel``, else ``self.model``.

        Checkpoints store the unwrapped state dict. Using this helper instead of
        ``self.model.module`` keeps saving working when CUDA is disabled.
        """
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    @staticmethod
    def _vis_interval(num_iters):
        """Return the iteration period between image visualisations.

        Aims for about 10 visualisations per epoch. It never returns 0, so loaders
        with fewer than 10 batches no longer raise ZeroDivisionError.
        """
        return max(num_iters // 10, 1)

    def _save(self, epoch, is_best):
        """Write model/optimizer state and the current ``best_pred`` through ``self.saver``."""
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': self._bare_model().state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, is_best)

    def training(self, epoch):
        """Run one training epoch, logging per-iteration and per-epoch loss.

        When ``args.no_val`` is set, a non-best checkpoint is written at the end
        of every epoch, because ``validation`` will not be saving one.
        """
        train_loss = 0.0
        num_images = 0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_every = self._vis_interval(num_img_tr)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            global_step = i + num_img_tr * epoch
            self.writer.add_scalar('train/total_loss_iter', loss.item(), global_step)

            # Show 10 * 3 inference results each epoch
            if i % vis_every == 0:
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            self._save(epoch, is_best=False)

    def validation(self, epoch):
        """Evaluate on the validation set, log metrics, and checkpoint when mIoU improves."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for sample in tbar:
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            pred = np.argmax(output.cpu().numpy(), axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target.cpu().numpy(), pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self._save(epoch, is_best=True)
Пример #16
0
class Trainer(object):
    """Train and validate a segmentation model on tiles cut from large images.

    On construction, ``img_process`` tiles the raw training images into
    ``args.base_size`` squares and caches them as .npy files (unless the cache
    already exists). Then the saver, TensorBoard writer, data loaders, model
    (``modeling.<args.model_name>``), SGD optimizer, loss, evaluator and LR
    scheduler are built, and a checkpoint is optionally restored.
    """

    # Machine-specific locations; change these to match your environment.
    _NPY_ROOT = '/data/dingyifeng/pytorch-jingwei-master/npy_process'
    _RAW_ROOT = '/data/dingyifeng/jingwei/jingwei_round1_train_20190619'
    # (image file, label file) pairs under _RAW_ROOT, processed in order.
    _TRAIN_PAIRS = (('image_1.png', 'image_1_label.png'),
                    ('image_2.png', 'image_2_label.png'))

    def __init__(self, args):
        self.args = args

        # Generate .npy file for dataloader
        self.img_process(args)

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = getattr(modeling, args.model_name)(pretrained=args.pretrained)

        # Define Optimizer (a single parameter group at args.lr)
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self._bare_model().load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _spans(size, unit_size, stride):
        """Return the ``(start, end)`` window bounds along one axis of length ``size``.

        Windows are ``unit_size`` long and advance by ``stride``. A window that
        would overrun the edge is shifted back to end exactly at ``size``. If
        ``size < unit_size`` that start is negative, as in the original loop.
        """
        spans = []
        start, end = 0, unit_size
        while start < size:
            if end > size:
                start, end = size - unit_size, size
            spans.append((start, end))
            if end == size:
                break
            start += stride
            end += stride
        return spans

    @staticmethod
    def _tile_image(img, anno_map, unit_size, stride):
        """Cut ``img``/``anno_map`` into aligned tiles, keeping tiles with visible pixels.

        Args:
            img: Array indexed ``[row, col, channel]``. Its last channel is
                tested for the value 255 (presumably an alpha mask; confirm
                against the data). Only channels 0-2 are kept.
            anno_map: Label array indexed ``[row, col]``, aligned with ``img``.
            unit_size: Tile side length.
            stride: Step between consecutive tile origins.

        Returns:
            ``(tiles, labels)`` as arrays built with ``np.array``, in row-major
            tile order.
        """
        tiles, labels = [], []
        col_spans = Trainer._spans(img.shape[1], unit_size, stride)
        for x1, x2 in Trainer._spans(img.shape[0], unit_size, stride):
            for y1, y2 in col_spans:
                im = img[x1:x2, y1:y2, :]
                # Keep the tile only if its last channel contains a 255 pixel.
                if 255 in im[:, :, -1]:
                    tiles.append(im[:, :, 0:3])
                    labels.append(anno_map[x1:x2, y1:y2])
        return np.array(tiles), np.array(labels)

    def img_process(self, args):
        """Tile the large training images and cache the result as .npy files.

        Each raw image in ``_TRAIN_PAIRS`` is cut into non-overlapping
        ``args.base_size`` squares (see ``_tile_image``). Tiles from all images
        are concatenated, shuffled with seed 1, and saved to
        ``<_NPY_ROOT>/<base_size>/`` as ``train_img.npy``, ``train_label.npy``,
        ``val_img.npy`` and ``val_label.npy``. Nothing is regenerated if that
        directory already exists.
        """
        unit_size = args.base_size
        stride = unit_size  # int(unit_size / 2) would give overlapping tiles
        save_dir = os.path.join(self._NPY_ROOT, str(unit_size))
        if not os.path.exists(save_dir):
            # Raise PIL's pixel-count limit so the very large images can be opened.
            Image.MAX_IMAGE_PIXELS = 100000000000
            tile_parts, label_parts = [], []
            for img_name, label_name in self._TRAIN_PAIRS:
                img = np.asarray(Image.open(os.path.join(self._RAW_ROOT, img_name)))
                anno_map = np.asarray(
                    Image.open(os.path.join(self._RAW_ROOT, label_name)))
                tiles, labels = self._tile_image(img, anno_map, unit_size, stride)
                tile_parts.append(tiles)
                label_parts.append(labels)
                # Release the full-resolution arrays before loading the next pair.
                del img, anno_map
            Img = np.concatenate(tile_parts, axis=0)
            cat = np.concatenate(label_parts, axis=0)
            del tile_parts, label_parts

            # shuffle (seeds the global NumPy RNG, as before)
            np.random.seed(1)
            assert Img.shape[0] == cat.shape[0]
            shuffle_id = np.arange(Img.shape[0])
            np.random.shuffle(shuffle_id)
            Img = Img[shuffle_id]
            cat = cat[shuffle_id]

            # makedirs also creates a missing _NPY_ROOT, where os.mkdir would fail.
            # NOTE(review): an interrupted run leaves a partial directory that
            # suppresses regeneration next time.
            os.makedirs(save_dir)
            print("=> generate {}".format(unit_size))
            # The train split is the whole shuffled set (the 80% cut is
            # disabled), so the val split below is a subset of train.
            np.save(os.path.join(save_dir, 'train_img.npy'), Img)
            np.save(os.path.join(save_dir, 'train_label.npy'), cat)
            val_start = int(Img.shape[0] * 0.8)
            np.save(os.path.join(save_dir, 'val_img.npy'), Img[val_start:])
            np.save(os.path.join(save_dir, 'val_label.npy'), cat[val_start:])
            del Img, cat

            print("=> img_process finished!")
        else:
            print("{} file already exists!".format(unit_size))
        # Free memory from the large temporaries.
        import gc
        gc.collect()

    def _bare_model(self):
        """Return the wrapped network if ``self.model`` is ``nn.DataParallel``, else ``self.model``."""
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    @staticmethod
    def _vis_interval(num_iters):
        """Return the iteration period between image visualisations (about 10 per epoch, minimum 1)."""
        return max(num_iters // 10, 1)

    def _save(self, epoch, is_best):
        """Write model/optimizer state and the current ``best_pred`` through ``self.saver``."""
        self.saver.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self._bare_model().state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def training(self, epoch):
        """Run one training epoch. Save a checkpoint each epoch when ``args.no_val`` is set."""
        train_loss = 0.0
        num_images = 0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_every = self._vis_interval(num_img_tr)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            global_step = i + num_img_tr * epoch
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   global_step)

            # Show 10 * 3 inference results each epoch
            if i % vis_every == 0:
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            self._save(epoch, is_best=False)

    def validation(self, epoch):
        """Evaluate on the validation set, log metrics, and checkpoint when mIoU improves."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = np.argmax(output.cpu().numpy(), axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target.cpu().numpy(), pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self._save(epoch, is_best=True)
Пример #17
0
class Trainer(object):
    """Train and validate DeepLabV3+/SegNet/UNet-family (optionally wavelet) segmentation nets.

    Logging goes through ``args.printer.pprint``. Every 10th epoch the weights
    are written to ``args.weight_root``, and best-mIoU checkpoints go through
    ``Saver``.
    """

    # Nets whose models provide get_1x_lr_params / get_10x_lr_params.
    _LR_SPLIT_NETS = {'deeplabv3p', 'wdeeplabv3p', 'wsegnet', 'segnet', 'unet'}
    # Nets trained with separate weight / bias groups (no decay on biases).
    _WEIGHT_BIAS_NETS = {'waveunet', 'waveunet_v2'}

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # TensorBoard is not used; logging goes through args.printer.
        self.printer = args.printer

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass, self.class_names = make_data_loader(
            args, **kwargs)

        # Define network. get_net returns None for nets it cannot build; fail
        # here with a clear message instead of a later AttributeError.
        self.model = self.get_net()
        if self.model is None:
            raise ValueError('Unsupported net: {}'.format(args.net))

        # Define Optimizer parameter groups
        if args.net in self._LR_SPLIT_NETS:
            train_params = [{
                'params': self.model.get_1x_lr_params(),
                'lr': args.lr
            }, {
                'params': self.model.get_10x_lr_params(),
                'lr': args.lr * 10
            }]
        elif args.net in self._WEIGHT_BIAS_NETS:
            # Reachable only once get_net can build these nets.
            weight_p, bias_p = self._split_weight_bias(
                self.model.named_parameters())
            train_params = [{
                'params': weight_p,
                'weight_decay': args.weight_decay,
                'lr': args.lr
            }, {
                'params': bias_p,
                'weight_decay': 0,
                'lr': args.lr
            }]
        else:
            raise ValueError('Unsupported net: {}'.format(args.net))

        # NOTE(review): no optimizer-level weight_decay is passed, so the
        # lr-split groups above train with weight_decay=0. Only the weight/bias
        # groups set it explicitly. Confirm this is intended.
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=args.momentum,
                                         nesterov=args.nesterov)

        # Define Criterion, optionally with class-balanced weights that are
        # read from a cached .npy file or computed from the train loader.
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight,
            cuda=args.cuda,
            batch_average=self.args.batch_average).build_loss(
                mode=args.loss_type)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.args.printer.pprint(
            'Using {} LR Scheduler!, initialization lr = {}'.format(
                args.lr_scheduler, args.lr))
        if self.args.net.startswith('deeplab'):
            self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                          args.epochs, len(self.train_loader))
        else:
            self.scheduler = LR_Scheduler(args.lr_scheduler,
                                          args.lr,
                                          args.epochs,
                                          len(self.train_loader),
                                          net=self.args.net)

        for key, value in self.args.__dict__.items():
            if not key.startswith('_'):
                self.printer.pprint('{} ==> {}'.format(key.rjust(24), value))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu,
                                               output_device=args.out_gpu)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint: only the backbone loader is used, and only for
        # the pascal / cityscapes datasets.
        self.best_pred = 0.0
        if args.resume is not None:
            if args.dataset in ['pascal', 'cityscapes']:
                self.load_pretrained_model_cityscape()

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def get_net(self):
        """Build the network named by ``args.net``; return None if that name is not handled here."""
        model = None
        if self.args.net == 'deeplabv3p':
            model = DeepLabV3P(num_classes=self.nclass,
                               backbone=self.args.backbone,
                               output_stride=self.args.out_stride,
                               sync_bn=self.args.sync_bn,
                               freeze_bn=self.args.freeze_bn,
                               p_dropout=self.args.p_dropout)
        elif self.args.net == 'wdeeplabv3p':
            model = WDeepLabV3P(num_classes=self.nclass,
                                backbone=self.args.backbone,
                                output_stride=self.args.out_stride,
                                sync_bn=self.args.sync_bn,
                                freeze_bn=self.args.freeze_bn,
                                wavename=self.args.wn,
                                p_dropout=self.args.p_dropout)
        elif self.args.net == 'segnet':
            model = SegNet(num_classes=self.nclass, wavename=self.args.wn)
        elif self.args.net == 'unet':
            model = UNet(num_classes=self.nclass, wavename=self.args.wn)
        elif self.args.net == 'wsegnet':
            model = WSegNet(num_classes=self.nclass, wavename=self.args.wn)
        return model

    @staticmethod
    def _split_weight_bias(named_parameters):
        """Split ``(name, param)`` pairs into ``(weights, biases)`` by whether 'bias' is in the name."""
        weight_p, bias_p = [], []
        for name, param in named_parameters:
            (bias_p if 'bias' in name else weight_p).append(param)
        return weight_p, bias_p

    @staticmethod
    def _match_backbone_layers(model_dict, pre_model, prefix='module.backbone.'):
        """Return ``(prefix + k, v)`` for checkpoint entries whose prefixed key exists in
        ``model_dict`` with the same ``.shape``."""
        matched = []
        for key, value in pre_model.items():
            target_key = prefix + key
            if target_key in model_dict and model_dict[target_key].shape == value.shape:
                matched.append((target_key, value))
        return matched

    def _load_layers(self, model_dict, pre_layers):
        """Log each layer being loaded, merge ``pre_layers`` into ``model_dict``, and load it."""
        for name, _ in pre_layers:
            self.printer.pprint('CCCC - key in pre_model --> {}'.format(name))
        model_dict.update(pre_layers)
        self.model.load_state_dict(model_dict)

    def load_pretrained_model(self):
        """Load weights, and if present the epoch and optimizer state, from ``args.resume``.

        If the checkpoint lacks ``epoch``/``optimizer``/``state_dict`` or the
        optimizer state does not fit, training restarts from epoch 0. In that
        case the whole checkpoint is treated as a state dict for the lr-split
        nets. Keys are then remapped per net type and merged into the current
        model state.

        Raises:
            RuntimeError: If ``args.resume`` is not a file.
        """
        if not os.path.isfile(self.args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                self.args.resume))
        checkpoint = torch.load(self.args.resume,
                                map_location=self.args.gpu_map)
        try:
            self.args.start_epoch = checkpoint['epoch']
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            pre_model = checkpoint['state_dict']
        except Exception as err:  # best-effort fallback, but say why
            self.printer.pprint('What happened ?!')
            self.printer.pprint('{}: {}'.format(type(err).__name__, err))
            self.args.start_epoch = 0
            if self.args.net in self._LR_SPLIT_NETS:
                pre_model = checkpoint
            elif self.args.net in self._WEIGHT_BIAS_NETS:
                pre_model = checkpoint['state_dict']
        model_dict = self.model.state_dict()
        self.printer.pprint("=> loaded checkpoint '{}' (epoch {})".format(
            self.args.resume, self.args.start_epoch))
        for key in model_dict:
            self.printer.pprint('AAAA - key in model --> {}'.format(key))
        for key in pre_model:
            self.printer.pprint('BBBB - key in pre_model --> {}'.format(key))
        if self.args.net in {'deeplabv3p', 'wdeeplabv3p', 'segnet', 'unet'}:
            pre_layers = [('module.' + k, v) for k, v in pre_model.items()
                          if 'module.' + k in model_dict]
        elif self.args.net == 'wsegnet':
            # Drop the first 16 characters of each checkpoint key and re-root
            # it under 'module.features.'.
            pre_layers = [('module.features.' + k[16:], v)
                          for k, v in pre_model.items()
                          if 'module.features.' + k[16:] in model_dict]
        else:
            # No key-remapping rule for other nets; leave weights untouched.
            return
        self._load_layers(model_dict, pre_layers)

    def load_pretrained_model_cityscape(self):
        """Load backbone weights from ``args.resume`` into ``self.model``.

        Every checkpoint tensor ``k`` is copied to ``module.backbone.<k>`` when
        that key exists in the model with the same shape. All other weights
        keep their current values. ``args.start_epoch`` is always reset to 0.

        Raises:
            RuntimeError: If ``args.resume`` is not a file.
        """
        if not os.path.isfile(self.args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                self.args.resume))
        checkpoint = torch.load(self.args.resume,
                                map_location=self.args.gpu_map)
        self.args.start_epoch = 0
        try:
            pre_model = checkpoint['state_dict']
        except (KeyError, TypeError):
            # The checkpoint is itself a bare state dict.
            self.printer.pprint('What happened ?!')
            pre_model = checkpoint
        self.printer.pprint("=> loaded checkpoint '{}' (epoch {})".format(
            self.args.resume, self.args.start_epoch))
        # Applied for every net. The former guard
        # `net == 'deeplabv3p' or 'wdeeplabv3p_per'` was always true, since a
        # non-empty string is truthy.
        model_dict = self.model.state_dict()
        for key in model_dict:
            self.printer.pprint('AAAA - key in model --> {}'.format(key))
        for key in pre_model:
            self.printer.pprint('BBBB - key in pre_model --> {}'.format(key))
        pre_layers = self._match_backbone_layers(model_dict, pre_model)
        model_dict.update(pre_layers)
        self.model.load_state_dict(model_dict)

    def training(self, epoch):
        """Run one training epoch, logging every 10 iterations.

        Writes ``epoch_<epoch>.pth.tar`` to ``args.weight_root`` whenever
        ``epoch`` is a multiple of 10.
        """
        train_loss = 0.0
        self.model.train()
        num_img_tr = len(self.train_loader)
        time_epoch_begin = datetime.now()
        for i, sample in enumerate(self.train_loader):
            time_iter_begin = datetime.now()
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            time_iter_end = datetime.now()
            time_iter = time_iter_end - time_iter_begin
            time_iter_during = time_iter_end - self.args.time_begin
            if i % 10 == 0:
                self.printer.pprint('train: epoch = {:3d} / {:3d}, '
                                    'iter = {:4d} / {:5d}, '
                                    'loss = {:.3f} / {:.3f}, '
                                    'time = {} / {}, '
                                    'lr = {:.6f}'.format(
                                        epoch, self.args.epochs, i, num_img_tr,
                                        loss.item(), train_loss / (i + 1),
                                        time_iter, time_iter_during,
                                        self.optimizer.param_groups[0]['lr']))
        self.printer.pprint(
            '------------ Train_total_loss = {}, epoch = {}, Time = {}'.format(
                train_loss, epoch,
                datetime.now() - time_epoch_begin))
        self.printer.pprint(' ')
        if epoch % 10 == 0:
            filename = os.path.join(self.args.weight_root,
                                    'epoch_{}'.format(epoch) + '.pth.tar')
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, filename)

    def validation(self, epoch):
        """Evaluate on the validation set, log metrics and per-class IoU, and checkpoint on a new best mIoU."""
        self.model.eval()
        self.evaluator.reset()
        num_img_val = len(self.val_loader)
        test_loss = 0.0
        time_epoch_begin = datetime.now()
        for i, sample in enumerate(self.val_loader):
            time_iter_begin = datetime.now()
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            # Accumulate a Python float rather than a device tensor.
            test_loss += loss.item()
            _, pred = output.topk(1, dim=1)
            pred = pred.squeeze(dim=1).cpu().numpy()
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            time_iter_end = datetime.now()
            time_iter = time_iter_end - time_iter_begin
            time_iter_during = time_iter_end - self.args.time_begin
            self.printer.pprint('validation: epoch = {:3d} / {:3d}, '
                                'iter = {:4d} / {:5d}, '
                                'loss = {:.3f} / {:.3f}, '
                                'time = {} / {}'.format(
                                    epoch, self.args.epochs, i, num_img_val,
                                    loss.item(), test_loss / (i + 1),
                                    time_iter, time_iter_during))

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.printer.pprint('Validation, epoch = {}, Time = {}'.format(
            epoch,
            datetime.now() - time_epoch_begin))
        self.printer.pprint('------------ Total_loss = {}'.format(test_loss))
        self.printer.pprint(
            "------------ Acc: {:.4f}, mIoU: {:.4f}, fwIoU: {:.4f}".format(
                Acc, mIoU, FWIoU))
        self.printer.pprint('------------ Acc_class = {}'.format(Acc_class))
        Object_names = '\t'.join(self.class_names)
        Object_IoU = '\t'.join(
            ['{:0.3f}'.format(IoU * 100) for IoU in self.evaluator.IoU_class])
        self.printer.pprint('------------ ' + Object_names)
        self.printer.pprint('------------ ' + Object_IoU)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
class Trainer(object):
    """Train and validate a DeepLab model, optionally adding a DenseCRF loss term.

    When ``args.densecrfloss > 0`` the training loss is
    ``criterion(output, target) + densecrflosslayer(image, softmax(output), croppings)``
    and extra gradient / entropy diagnostics are written to TensorBoard.
    Checkpoints always store the unwrapped (non-DataParallel) model state, so
    saving works both with and without ``args.cuda``.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Two parameter groups: the second one trains at 10x the base LR.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion, optionally with class-balanced weights that are
        # loaded from disk when present and otherwise computed from the
        # training loader.
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        if args.densecrfloss > 0:
            self.densecrflosslayer = DenseCRFLoss(weight=args.densecrfloss, sigma_rgb=args.sigma_rgb, sigma_xy=args.sigma_xy, scale_factor=args.rloss_scale)
            print(self.densecrflosslayer)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler (stepped once per training batch)
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Checkpoints hold the unwrapped model's state_dict.
            self._unwrapped_model().load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _vis_interval(num_iters):
        """Return the batch stride for the ~10 visualizations per epoch (at least 1)."""
        # Clamped so loaders with fewer than 10 batches don't do `i % 0`.
        return max(num_iters // 10, 1)

    def _unwrapped_model(self):
        """Return the underlying network, stripping a (Distributed)DataParallel wrapper."""
        parallel_types = (torch.nn.DataParallel,
                          torch.nn.parallel.DistributedDataParallel)
        if isinstance(self.model, parallel_types):
            return self.model.module
        return self.model

    def _log_colorized_maps(self, tag, maps, global_step):
        """Colorize up to three 2-D maps and write them to TensorBoard as one image grid.

        ``maps`` is iterated per image; each map is modified in place
        (pixel [0, 0] is zeroed) before colorizing.
        """
        color_imgs = []
        for single_map in maps:
            # NOTE(review): zeroing one pixel presumably pins the color scale's
            # minimum at 0; confirm against `colorize`.
            single_map[0, 0] = 0
            color_imgs.append(colorize(single_map)[:, :, :3])
        color_imgs = torch.from_numpy(np.array(color_imgs).transpose([0, 3, 1, 2]))
        grid_image = make_grid(color_imgs[:3], 3, normalize=False, range=(0, 255))
        self.writer.add_image(tag, grid_image, global_step)

    def training(self, epoch):
        """Run one training epoch; optionally save a checkpoint every ``save_interval`` epochs."""
        train_loss = 0.0
        train_celoss = 0.0
        train_crfloss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_interval = self._vis_interval(num_img_tr)
        softmax = nn.Softmax(dim=1)
        for i, sample in enumerate(tbar):
            global_step = i + num_img_tr * epoch
            image, target = sample['image'], sample['label']
            # Pixels labeled 255 are unlabeled pixels; padded regions are
            # labeled 254 (see RandomScaleCrop in dataloaders/custom_transforms.py).
            # `croppings` records the non-padded area before the padding is
            # folded into the 255 label.
            croppings = (target != 254).float()
            target[target == 254] = 255
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)

            celoss = self.criterion(output, target)

            if self.args.densecrfloss == 0:
                loss = celoss
            else:
                max_output = max(torch.abs(torch.max(output)),
                                 torch.abs(torch.min(output)))
                mean_output = torch.mean(torch.abs(output)).item()
                probs = softmax(output)
                denormalized_image = denormalizeimage(sample['image'], mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
                densecrfloss = self.densecrflosslayer(denormalized_image, probs, croppings)
                if self.args.cuda:
                    densecrfloss = densecrfloss.cuda()
                loss = celoss + densecrfloss
                train_crfloss += densecrfloss.item()

                # Diagnostics: backprop the CRF loss alone through a detached
                # copy of the logits so its gradient can be visualized in
                # isolation from the CE term.
                logits_copy = output.detach().clone().requires_grad_(True)
                probs_copy = softmax(logits_copy)
                denormalized_image_copy = denormalized_image.detach().clone()
                densecrfloss_copy = self.densecrflosslayer(denormalized_image_copy, probs_copy, croppings)

                @torch.no_grad()
                def add_grad_map(grad, plot_name):
                    # Backward hook: log the per-pixel max |grad| over dim 1.
                    # Returns None, so the gradient itself is left unchanged.
                    if i % vis_interval == 0:
                        batch_grads = torch.max(torch.abs(grad), dim=1)[0].detach().cpu().numpy()
                        self._log_colorized_maps(plot_name, batch_grads, global_step)

                output.register_hook(lambda grad: add_grad_map(grad, 'Grad Logits'))
                probs.register_hook(lambda grad: add_grad_map(grad, 'Grad Probs'))

                logits_copy.register_hook(lambda grad: add_grad_map(grad, 'Grad Logits Rloss'))
                densecrfloss_copy.backward()

                if i % vis_interval == 0:
                    img_entropy = torch.sum(-probs * torch.log(probs + 1e-9), dim=1).detach().cpu().numpy()
                    self._log_colorized_maps('Entropy', img_entropy, global_step)

                    self.writer.add_histogram('train/total_loss_iter/logit_histogram', output, global_step)
                    self.writer.add_histogram('train/total_loss_iter/probs_histogram', probs, global_step)

                self.writer.add_scalar('train/total_loss_iter/rloss', densecrfloss.item(), global_step)
                self.writer.add_scalar('train/total_loss_iter/max_output', max_output.item(), global_step)
                self.writer.add_scalar('train/total_loss_iter/mean_output', mean_output, global_step)

            loss.backward()

            self.optimizer.step()
            train_loss += loss.item()
            train_celoss += celoss.item()

            tbar.set_description('Train loss: %.3f = CE loss %.3f + CRF loss: %.3f'
                                 % (train_loss / (i + 1), train_celoss / (i + 1), train_crfloss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), global_step)
            self.writer.add_scalar('train/total_loss_iter/ce', celoss.item(), global_step)

            # Show 10 * 3 inference results each epoch
            if i % vis_interval == 0:
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.save_interval:
            # save checkpoint every interval epoch
            is_best = False
            if (epoch + 1) % self.args.save_interval == 0:
                self.saver.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self._unwrapped_model().state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best, filename='checkpoint_epoch_{}.pth.tar'.format(str(epoch+1)))

    def validation(self, epoch):
        """Evaluate on the validation set; save a best checkpoint when mIoU improves."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            # Treat padding (254) the same as unlabeled pixels (255).
            target[target == 254] = 255
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self._unwrapped_model().state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Пример #19
0
class Trainer(object):
    """Train and validate an SCNN model with a binary-mask and an instance-embedding loss.

    The model's output is indexed as a pair: element [0] is the instance
    embedding and element [1] is the segmentation logits. Only the main process
    (any process when not distributed, rank 0 otherwise) creates the Saver /
    TensorBoard writer, logs, and checkpoints.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver and Tensorboard Summary. They exist only on the main
        # process; other ranks never get self.saver / self.summary / self.writer.
        if self._is_main_process():
            self.saver = Saver(args)
            self.saver.save_experiment_config()
            self.summary = TensorboardSummary(self.saver.experiment_dir)
            self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = SCNN(nclass=self.nclass, backbone=args.backbone, output_stride=args.out_stride,
                     cuda=args.cuda, extension=args.ext)

        # Define Optimizer
        optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)

        # Define Criterion: focal loss on the binary label, plus `disc_loss` on
        # the instance labels with a 21-dim embedding.
        self.criterion1 = FocalLoss(gamma=5, alpha=[0.2, 0.98], img_size=512 * 512)
        self.criterion2 = disc_loss(delta_v=0.5, delta_d=3.0, param_var=1.0, param_dist=1.0,
                                    param_reg=0.001, EMBEDDING_FEATS_DIMS=21, image_shape=[512, 512])

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler. It is stepped once per *training* batch, so
        # iters_per_epoch must be the training loader length (the validation
        # length made the schedule overrun its horizon).
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader), local_rank=args.local_rank)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()
            if args.distributed:
                self.model = DistributedDataParallel(self.model)

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            # NOTE(review): args.resume is overridden with
            # <experiment_dir>/checkpoint.pth.tar, and self.saver exists only on
            # the main process. Confirm this is intended.
            filename = 'checkpoint.pth.tar'
            args.resume = os.path.join(self.saver.experiment_dir, filename)
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Checkpoints store self.model.state_dict() (wrapped if DDP), so load
            # into the same object.
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

    def _is_main_process(self):
        """Return True when this process should log, print and checkpoint."""
        return not self.args.distributed or self.args.local_rank == 0

    @staticmethod
    def _vis_interval(num_iters):
        """Return the batch stride for the ~10 visualizations per epoch (at least 1)."""
        # Integer stride. The old float `i % (n / 10)` only matched on batch 0.
        return max(num_iters // 10, 1)

    def training(self, epoch):
        """Run one training epoch with the combined focal + embedding loss."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_interval = self._vis_interval(num_img_tr)
        is_main = self._is_main_process()
        for i, sample in enumerate(tbar):
            image, target, ins_target = sample['image'], sample['bin_label'], sample['label']

            # NOTE(review): ins_target is not moved to the GPU; presumably
            # criterion2 handles device placement. Confirm.
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)

            loss1 = self.criterion1(output, target)
            loss2 = self.criterion2(output, ins_target)
            loss = loss1 + 10 * loss2

            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            if is_main:
                self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch (segmentation head only)
            if i % vis_interval == 0 and is_main:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output[1], global_step)

        if is_main:
            self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
            print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
            print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        """Evaluate on the validation set; the main process logs and saves when F0 improves."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['bin_label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion1(output, target)
            test_loss += loss.item()
            embedding = output[0]
            pred = np.argmax(output[1].data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            if i % 30 == 0:
                self._dump_debug_images(epoch, i, image, pred, embedding)

        if not self._is_main_process():
            return

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        F0 = self.evaluator.F0()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}".format(Acc, Acc_class, mIoU))
        print('Loss: %.3f' % test_loss)

        # Model selection uses F0, not mIoU.
        new_pred = F0
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def _dump_debug_images(self, epoch, i, image, pred, embedding):
        """Write a prediction overlay and an RGB view of the embedding for the batch's first image.

        NOTE(review): the output directory is hard-coded, and ``misc.imsave``
        only exists in old SciPy releases. Confirm this is still wanted.
        """
        tag = str(self.args.distributed)
        rgb = np.transpose(image[0].cpu().numpy(), (1, 2, 0))

        overlay_path = '/mfc/user/1623600/.temp6/{:s}_val_epoch:{}_i:{}.png'.format(tag, epoch, i)
        mask = np.asarray(np.stack((pred[0], pred[0], pred[0]), axis=-1), dtype=np.uint8)
        misc.imsave(overlay_path, rgb + 3 * mask)
        os.chmod(overlay_path, 0o777)

        instance_seg = np.transpose(np.squeeze(embedding[0].data.cpu().numpy()), (1, 2, 0))
        # Fold the 21 embedding channels into 3 color channels (7 each), then
        # stretch each color channel to [0, 255] for display.
        temp_instance_seg = np.zeros_like(rgb)
        for j in range(21):
            temp_instance_seg[:, :, j // 7] += instance_seg[:, :, j]
        for k in range(3):
            temp_instance_seg[:, :, k] = self.minmax_scale(temp_instance_seg[:, :, k])
        instance_seg = np.array(temp_instance_seg, np.uint8)

        emb_path = '/mfc/user/1623600/.temp6/emb_{:s}_val_epoch:{}_i:{}.png'.format(tag, epoch, i)
        misc.imsave(emb_path, instance_seg[..., :3])
        os.chmod(emb_path, 0o777)

    def minmax_scale(self, input_arr):
        """Linearly rescale ``input_arr`` so its min maps to 0 and its max to 255.

        :param input_arr: numeric array to rescale.
        :return: the rescaled array. A constant input (max == min) maps to all
            zeros instead of NaN from a 0/0 division.
        """
        min_val = np.min(input_arr)
        max_val = np.max(input_arr)
        shifted = input_arr - min_val
        if max_val == min_val:
            # Same dtype as the normal path, all zeros.
            return shifted * 0.0
        return shifted * 255.0 / (max_val - min_val)
Пример #20
0
class Trainer(object):
    """Train and validate a DeepLab (resnet backbone, output stride 16) model.

    Hyper-parameters come from ``cfg.TRAIN`` / ``cfg.DATA``. TensorBoard logs
    go to ``result/run`` and resumed weights are read from
    ``result/weights/last.pt``.
    """

    def __init__(self, weight_path, resume, gpu_id):
        init_seeds(1)
        init_dirs("result")

        self.device = gpu.select_device(gpu_id)
        self.start_epoch = 0
        self.best_mIoU = 0.
        self.epochs = cfg.TRAIN["EPOCHS"]
        self.weight_path = weight_path

        self.train_loader, self.val_loader, _, self.num_class = make_data_loader(
        )

        self.model = DeepLab(num_classes=self.num_class,
                             backbone="resnet",
                             output_stride=16,
                             sync_bn=False,
                             freeze_bn=False).to(self.device)

        # Two parameter groups; the second trains at 10x the initial LR.
        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': cfg.TRAIN["LR_INIT"]
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': cfg.TRAIN["LR_INIT"] * 10
        }]

        self.optimizer = optim.SGD(train_params,
                                   momentum=cfg.TRAIN["MOMENTUM"],
                                   weight_decay=cfg.TRAIN["WEIGHT_DECAY"])

        self.criterion = SegmentationLosses().build_loss(
            mode=cfg.TRAIN["LOSS_TYPE"])

        self.scheduler = LR_Scheduler(mode=cfg.TRAIN["LR_SCHEDULER"],
                                      base_lr=cfg.TRAIN["LR_INIT"],
                                      num_epochs=self.epochs,
                                      iters_per_epoch=len(self.train_loader))
        self.evaluator = Evaluator(self.num_class)
        self.saver = Saver()
        self.summary = TensorboardSummary(os.path.join("result", "run"))

        if resume:
            self.__resume_model_weights()

    @staticmethod
    def _vis_interval(num_iters):
        """Return the batch stride for the ~10 visualizations per epoch (at least 1)."""
        # Clamped so loaders with fewer than 10 batches don't do `i % 0`.
        return max(num_iters // 10, 1)

    @staticmethod
    def _best_miou_from(chkpt, default=0.):
        """Read the best mIoU from a checkpoint dict.

        Older checkpoints written by this class only carry the value under the
        key ``'best_mAP'``; newer ones also store ``'best_mIoU'``. Accept either,
        falling back to ``default``.
        """
        if 'best_mIoU' in chkpt:
            return chkpt['best_mIoU']
        return chkpt.get('best_mAP', default)

    def _checkpoint_state(self, epoch):
        """Build the checkpoint dict written after training / validation."""
        return {
            'epoch': epoch,
            'best_mAP': self.best_mIoU,  # legacy key, kept for existing readers
            'best_mIoU': self.best_mIoU,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

    def __resume_model_weights(self):
        """Restore model/optimizer state, next epoch and best mIoU from result/weights/last.pt."""
        last_weight = os.path.join("result", "weights", "last.pt")
        chkpt = torch.load(last_weight, map_location=self.device)
        self.model.load_state_dict(chkpt['model'])

        self.start_epoch = chkpt['epoch'] + 1
        if chkpt['optimizer'] is not None:
            self.optimizer.load_state_dict(chkpt['optimizer'])
            # Previously chkpt['best_mIoU'], a key that was never written.
            self.best_mIoU = self._best_miou_from(chkpt, self.best_mIoU)
        del chkpt

        print("resume model weights from : {}".format(last_weight))

    def __training(self, epoch):
        """Run one training epoch, printing progress every 20 batches."""
        self.model.train()
        train_loss = 0.0
        num_batches = len(self.train_loader)
        vis_interval = self._vis_interval(num_batches)
        for i, sample in enumerate(self.train_loader):
            image = sample["image"].to(self.device)
            target = sample["label"].to(self.device)

            self.scheduler(self.optimizer, i, epoch, self.best_mIoU)
            self.optimizer.zero_grad()

            out = self.model(image)
            loss = self.criterion(logit=out, target=target)
            loss.backward()
            self.optimizer.step()

            # Running mean of this epoch's batch losses.
            train_loss = (train_loss * i + loss.item()) / (i + 1)
            if i % 20 == 0:
                s = 'Epoch:[ {:d} | {:d} ]    Batch:[ {:d} | {:d} ]    loss: {:.4f}    lr: {:.6f}'.format(
                    epoch, self.epochs - 1, i,
                    num_batches - 1, train_loss,
                    self.optimizer.param_groups[0]['lr'])
                print(s)

            global_step = i + num_batches * epoch
            self.summary.writer.add_scalar('train/total_loss_iter',
                                           loss.item(), global_step)
            if i % vis_interval == 0:
                self.summary.visualize_image(cfg.DATA["TYPE"], image, target,
                                             out, global_step)

        # Log the epoch's mean loss once, after all batches.
        self.summary.writer.add_scalar('train/total_loss_epoch',
                                       train_loss, epoch)

        # Save last.pt
        # NOTE(review): only written for epochs 0..20. Confirm intent.
        if epoch <= 20:
            self.saver.save_checkpoint(state=self._checkpoint_state(epoch),
                                       is_best=False)

    def __validating(self, epoch):
        """Evaluate on the validation set, log metrics and save a checkpoint."""
        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0
        for i, sample in enumerate(self.val_loader):
            image = sample["image"].to(self.device)
            target = sample["label"].to(self.device)

            with torch.no_grad():
                out = self.model(image)
            loss = self.criterion(logit=out, target=target)
            test_loss += loss.item()

            pred = np.argmax(out.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            self.evaluator.add_batch(target, pred)

        #  calculate the index
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        # Write
        self.summary.writer.add_scalar('val/total_loss_epoch', test_loss,
                                       epoch)
        self.summary.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.summary.writer.add_scalar('val/Acc', Acc, epoch)
        self.summary.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.summary.writer.add_scalar('val/fwIoU', FWIoU, epoch)

        # Save every epoch; flag as best when mIoU improves.
        is_best = bool(mIoU > self.best_mIoU)
        if is_best:
            self.best_mIoU = mIoU
        self.saver.save_checkpoint(state=self._checkpoint_state(epoch),
                                   is_best=is_best)

        # print
        print('*' * 20 + "Validate" + '*' * 20)
        print(
            "Acc: {}\nAcc_class: {}\nmIoU: {}\nfwIoU: {}\nLoss: {:.3f}\nbest_mIoU: {}"
            .format(Acc, Acc_class, mIoU, FWIoU, test_loss, self.best_mIoU))

    def train(self):
        """Train from start_epoch to cfg.TRAIN['EPOCHS'], validating after every epoch."""
        print(self.model)
        for epoch in range(self.start_epoch, self.epochs):
            self.__training(epoch)
            self.__validating(epoch)
class trainNew(object):
    """Train a searched two-head (device / cloud) segmentation network.

    ``new_cloud_Model`` returns ``(device_output, cloud_output)``. Both are
    scored against the same target with ``self.criterion``, and the sum of
    the two losses is optimised. Validation reports one mIoU per head and
    checkpoints whenever their mean improves on ``self.best_pred``.
    """

    def __init__(self, args):
        """Build loaders, network, optimizer, loss, scheduler and resume state.

        Args:
            args: parsed command-line namespace. Fields read here include
                ``saved_arch_path``, ``lr``, ``momentum``, ``weight_decay``,
                ``nesterov``, ``cuda``, ``gpu_ids``, ``use_amp``, ``opt_level``,
                ``resume``, ``clean_module`` and ``ft``.

        Raises:
            RuntimeError: if ``args.resume`` is given but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        # apex AMP is only initialised on the CUDA path below; enabling it
        # without CUDA would make ``amp.scale_loss`` run on un-initialised state.
        self.use_amp = bool(APEX_AVAILABLE and args.use_amp and args.cuda)
        self.opt_level = args.opt_level

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        cell_path_d = os.path.join(args.saved_arch_path, 'genotype_device.npy')
        cell_path_c = os.path.join(args.saved_arch_path, 'genotype_cloud.npy')
        network_path_space = os.path.join(args.saved_arch_path, 'network_path_space.npy')

        new_cell_arch_d = np.load(cell_path_d)
        new_cell_arch_c = np.load(cell_path_c)
        new_network_arch = np.load(network_path_space)
        # NOTE(review): this hard-coded path overrides the array loaded just
        # above; confirm this is intentional.
        new_network_arch = [1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2]
        # Define network
        model = new_cloud_Model(network_arch=new_network_arch,
                                cell_arch_d=new_cell_arch_d,
                                cell_arch_c=new_cell_arch_c,
                                num_classes=self.nclass,
                                device_num_layers=6)

        # TODO: look into different param groups as done in deeplab
        # (model.get_1x_lr_params() at lr, model.get_10x_lr_params() at 10 * lr).
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator (one per output head)
        self.evaluator_device = Evaluator(self.nclass)
        self.evaluator_cloud = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))  # TODO: use min_lr ?

        # Using cuda: the training loop moves inputs to the GPU, so the weights
        # must live there too (apex and DataParallel also expect a CUDA model).
        if args.cuda:
            self.model = self.model.cuda()

        # Mixed precision (apex AMP); use_amp already implies args.cuda.
        if self.use_amp:
            keep_batchnorm_fp32 = True if self.opt_level in ('O2', 'O3') else None

            # fix for older pytorch versions with opt_level 'O1'
            if self.opt_level == 'O1' and self._torch_version_tuple() < (1, 3):
                for module in self.model.modules():
                    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                        # Hack to fix BN fprop without affine transformation
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape, dtype=module.running_var.dtype,
                                           device=module.running_var.device), requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape, dtype=module.running_var.dtype,
                                            device=module.running_var.device), requires_grad=False)

            # This trainer owns a single optimizer (there is no architecture
            # optimizer here), so only that one is handed to apex.
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=self.opt_level,
                keep_batchnorm_fp32=keep_batchnorm_fp32, loss_scale="dynamic")

            print('cuda finished')

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print('currently cannot run with nn.DataParallel and optimization level', self.opt_level)
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            # Load on CPU first; load_state_dict copies tensors onto the
            # model's / optimizer's own devices.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']

            state_dict = checkpoint['state_dict']
            if args.clean_module:
                # Weights were saved from a DataParallel wrapper: drop the
                # leading 'module.' from the keys before loading.
                state_dict = self._strip_module_prefix(state_dict)
            # Always load into the bare network so keys match whether or not
            # self.model is wrapped in DataParallel.
            self._unwrapped_model().load_state_dict(state_dict)

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def _unwrapped_model(self):
        """Return the underlying network, unwrapping ``nn.DataParallel`` if present."""
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of *state_dict* with any leading ``'module.'`` removed from its keys."""
        prefix = 'module.'
        return OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items())

    @staticmethod
    def _torch_version_tuple(version=None):
        """Return ``(major, minor)`` of a torch version string as comparable ints.

        Comparing version strings directly is lexicographic (``'1.10' < '1.3'``),
        so the numeric parts are parsed instead. Unparsable strings give ``(0, 0)``.
        """
        import re
        match = re.match(r'(\d+)\.(\d+)', torch.__version__ if version is None else version)
        if match is None:
            return (0, 0)
        return int(match.group(1)), int(match.group(2))

    def training(self, epoch):
        """Run one epoch on the summed device + cloud loss and log to TensorBoard."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            device_output, cloud_output = self.model(image)

            device_loss = self.criterion(device_output, target)
            cloud_loss = self.criterion(cloud_output, target)
            loss = device_loss + cloud_loss
            if self.use_amp:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            self.optimizer.step()
            train_loss += loss.item()

            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            if i % 50 == 0:
                self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Visualise the averaged prediction of both heads every 100 iterations
            if i % 100 == 0:
                output = (device_output + cloud_output) / 2
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self._unwrapped_model().state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        """Evaluate both heads; checkpoint when the mean of their mIoUs improves."""
        self.model.eval()
        self.evaluator_device.reset()
        self.evaluator_cloud.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                device_output, cloud_output = self.model(image)
            device_loss = self.criterion(device_output, target)
            cloud_loss = self.criterion(cloud_output, target)
            loss = device_loss + cloud_loss
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred_d = device_output.data.cpu().numpy()
            pred_c = cloud_output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred_d = np.argmax(pred_d, axis=1)
            pred_c = np.argmax(pred_c, axis=1)
            # Add batch sample into evaluator
            self.evaluator_device.add_batch(target, pred_d)
            self.evaluator_cloud.add_batch(target, pred_c)

        mIoU_d = self.evaluator_device.Mean_Intersection_over_Union()
        mIoU_c = self.evaluator_cloud.Mean_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/device/mIoU', mIoU_d, epoch)
        self.writer.add_scalar('val/cloud/mIoU', mIoU_c, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("device_mIoU:{}, cloud_mIoU: {}".format(mIoU_d, mIoU_c))
        print('Loss: %.3f' % test_loss)

        new_pred = (mIoU_d + mIoU_c) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self._unwrapped_model().state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Пример #22
0
class Trainer(object):
    """DeepLab trainer whose training loss is weighted per sample.

    Every training sample carries a ``'weight'`` tensor. The loss of sample
    ``k`` is scaled by ``weight[k, 0, 0]`` and the scaled losses are summed
    over the batch. Checkpoints are selected on
    ``Evaluator.xy_Mean_Intersection_over_Union``.
    """

    def __init__(self, args):
        """Build loaders, DeepLab model, optimizer, loss, scheduler and resume state.

        Raises:
            RuntimeError: if ``args.resume`` is given but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
        # Backbone at the base lr, decoder/head at 10x.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset)[0], args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            # Fail loudly on a bad path instead of silently training from scratch.
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            self._unwrapped_model().load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            # NOTE: best_pred is deliberately not restored (previously
            # ``checkpoint['best_pred'] - 0.3``, now disabled).
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def _unwrapped_model(self):
        """Return the underlying network, unwrapping ``nn.DataParallel`` if present."""
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    def training(self, epoch):
        """Run one epoch with the per-sample weighted loss."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target, weight = sample['image'], sample['label'], sample['weight']
            if self.args.cuda:
                image, target, weight = image.cuda(), target.cuda(), weight.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            # The criterion is evaluated one sample at a time so each sample's
            # loss can be scaled by its own weight before summing.
            loss = 0
            for index in range(output.shape[0]):
                temp1 = output[index].unsqueeze(0)
                temp2 = target[index].unsqueeze(0)
                loss = loss + weight[index, 0, 0] * self.criterion(temp1, temp2)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        """Evaluate on the validation set; checkpoint when xy-mIoU improves."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)

            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        xy_mIoU = self.evaluator.xy_Mean_Intersection_over_Union()
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.test_batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print("min_mIoU{}".format(xy_mIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = xy_mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            # Save the bare network so the checkpoint loads with or without DataParallel.
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self._unwrapped_model().state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
Пример #23
0
class Trainer(object):
    """DeepLab trainer with TensorBoard logging, testing and single-image analysis.

    Supports 3-channel (RGB) or 4-channel (``args.nir``) input. Checkpoints are
    selected on validation mIoU.
    """

    def __init__(self, args):
        """Build loaders, DeepLab model, optimizer, loss, scheduler and resume state.

        Raises:
            RuntimeError: if ``args.resume`` is given but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        # NOTE: worker count is pinned to 0 here (args.workers is ignored).
        kwargs = {'num_workers': 0, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # 4 channels when a near-infrared band is stacked onto RGB.
        input_channels = 4 if args.nir else 3

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        in_channels=input_channels,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Backbone at the base lr, decoder/head at 10x.
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
                # Manual override of the cached weights for classes 0..2.
                weight[1] = 4
                weight[2] = 2
                weight[0] = 1
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self._unwrapped_model().load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def _unwrapped_model(self):
        """Return the underlying network, unwrapping ``nn.DataParallel`` if present."""
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    def training(self, epoch):
        """Run one training epoch, logging losses and sample predictions."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Visualise roughly 10 batches per epoch; clamp to 1 so loaders with
        # fewer than 10 batches do not raise ZeroDivisionError.
        vis_interval = max(num_img_tr // 10, 1)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self._unwrapped_model().state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def pred_single_image(self, path):
        """Predict one image, capture ReLU activations and write debug images.

        Writes ``rgb.png``, ``lbl.png`` and ``prediction.png`` into the
        experiment directory and runs ``Analysis.backtrace`` on the raw output.

        Args:
            path: image path. Its label image is read from
                ``<parent of the image directory>/lbl/<same file name>``.
        """
        self.model.eval()
        img_path = path
        lbl_path = os.path.join(
            os.path.split(os.path.split(path)[0])[0], 'lbl',
            os.path.split(path)[1])
        activations = collections.defaultdict(list)

        def save_activation(name, mod, inputs, output):
            activations[name].append(output.cpu())

        # Keep the handles so the hooks can be removed again; otherwise each
        # call would stack another set of hooks onto the model.
        hook_handles = [
            m.register_forward_hook(partial(save_activation, name))
            for name, m in self.model.named_modules()
            if isinstance(m, nn.ReLU)
        ]
        try:
            rgb = cv2.imread(path)
            label = cv2.imread(lbl_path)
            rgb = cv2.resize(rgb, (513, 513), interpolation=cv2.INTER_CUBIC)
            image = Image.open(img_path).convert('RGB')  # width x height x 3
            # NOTE(review): the label fed through the transforms is built from
            # the image itself, not from lbl_path; it is not used for prediction.
            _tmp = np.array(Image.open(img_path), dtype=np.uint8)
            _tmp[_tmp == 255] = 1
            _tmp[_tmp == 128] = 2
            _tmp = Image.fromarray(_tmp)

            mean = (0.485, 0.456, 0.406)
            std = (0.229, 0.224, 0.225)

            composed_transforms = transforms.Compose([
                tr.FixedResize(size=513),
                tr.Normalize(mean=mean, std=std),
                tr.ToTensor()
            ])
            sample = composed_transforms({'image': image, 'label': _tmp})

            image, target = sample['image'], sample['label']

            image = torch.unsqueeze(image, dim=0)
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
                output = output.data.cpu().numpy()
                prediction = np.squeeze(np.argmax(output, axis=1), axis=0)
                # Map class ids to grey levels for the saved PNG (1 -> 255, 2 -> 128).
                prediction[prediction == 1] = 255
                prediction[prediction == 2] = 128
                print(np.unique(prediction))

            see = Analysis(activations, label=1, path=self.saver.experiment_dir)
            see.backtrace(output)
        finally:
            for handle in hook_handles:
                handle.remove()

        cv2.imwrite(os.path.join(self.saver.experiment_dir, 'rgb.png'), rgb)
        cv2.imwrite(os.path.join(self.saver.experiment_dir, 'lbl.png'), label)
        # argmax yields int64, which cv2.imwrite cannot encode; values fit in uint8.
        cv2.imwrite(os.path.join(self.saver.experiment_dir, 'prediction.png'),
                    prediction.astype(np.uint8))

    def validation(self, epoch):
        """Evaluate on the validation set; checkpoint when mIoU improves."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self._unwrapped_model().state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def testing(self):
        """Evaluate on the test set and print the aggregate and per-class metrics."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.test_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            print(pred.shape)

        # Fast test during the testing
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print(
            '[INFO] Network performance measures on the test dataset are as follows: \n '
            'mIOU: {} \n FWIOU: {} \n Class accuracy: {} \n Pixel Accuracy: {}'
            .format(mIoU, FWIoU, Acc_class, Acc))

        self.evaluator.per_class_accuracy()

    def explain_image(self, path, counter):
        """Prepare one image/label pair for explanation (work in progress).

        Currently loads and transforms the pair, adds a batch dimension and
        prints ``'hold'``; no attribution method is run yet and ``counter``
        is unused.
        """
        self.model.eval()
        img_path = path
        lbl_path = os.path.join(
            os.path.split(os.path.split(path)[0])[0], 'lbl',
            os.path.split(path)[1])
        image = Image.open(img_path).convert('RGB')  # width x height x 3
        _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
        _tmp[_tmp == 255] = 1
        _tmp[_tmp == 128] = 2
        _tmp = Image.fromarray(_tmp)

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        composed_transforms = transforms.Compose([
            tr.FixedResize(size=513),
            tr.Normalize(mean=mean, std=std),
            tr.ToTensor()
        ])
        sample = composed_transforms({'image': image, 'label': _tmp})

        image, target = sample['image'], sample['label']

        image = torch.unsqueeze(image, dim=0)
        # TODO: run an attribution method (e.g. LRP or guided backprop) here.
        print('hold')
Пример #24
0
class Trainer(object):
    """Train and validate a ClassesNet / SpeciesNet / MultiNet classifier.

    For ``args.dataset == 'Multi'`` a two-headed MultiNet predicts a "class"
    label and a "species" label per image. It is optimised on the sum of the
    two cross-entropy losses. For any other dataset a single-headed
    ClassesNet is used.

    Metrics are written to TensorBoard. The checkpoint with the best
    validation accuracy is saved through ``Saver``.
    """

    def __init__(self, args):
        """Set up logging, data, model and optimizer, and optionally resume.

        Args:
            args: Parsed command-line options. Read here: ``workers``,
                ``dataset``, ``backbone``, ``learn_rate``, ``weight_decay``,
                ``nesterov``, ``cuda``, ``gpu_ids``, ``resume`` and ``ft``.
                ``start_epoch`` is written.

        Raises:
            RuntimeError: if ``args.resume`` is given but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.writer = SummaryWriter(log_dir=self.saver.experiment_dir)

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        if self.args.dataset != 'Multi':
            model = ClassesNet(backbone=self.args.backbone,
                               num_classes=self.nclass,
                               pretrained=True)
            if self.args.dataset == 'Classes':
                print("Training ClassesNet")
            else:
                print("Training SpeciesNet")
        else:
            model = MultiNet(backbone=self.args.backbone,
                             num_classes=self.nclass,
                             pretrained=True)
            print("Training MultiNet")

        self.model = model

        train_params = [{'params': model.get_params()}]
        # Define Optimizer
        # NOTE(review): ``args.nesterov`` is reused as Adam's ``amsgrad``
        # switch -- confirm this is intended.
        self.optimizer = torch.optim.Adam(train_params,
                                          self.args.learn_rate,
                                          weight_decay=args.weight_decay,
                                          amsgrad=args.nesterov)

        # Define Criterion
        # ``reduction='mean'`` is the supported spelling of the deprecated
        # ``size_average=True`` argument (identical averaging behaviour).
        self.criterion = nn.CrossEntropyLoss(reduction='mean')

        # Define lr scheduler
        # NOTE(review): this scheduler is neither stored nor stepped anywhere
        # in this class, so it never decays the learning rate during training.
        exp_lr_scheduler = lr_scheduler.StepLR(self.optimizer,
                                               step_size=1,
                                               gamma=0.1)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Run one training epoch and log loss/accuracy to TensorBoard.

        Args:
            epoch: Zero-based epoch index. Used for logging and global steps.
        """
        print('[Epoch: %d, learning rate: %.6f, previous best = %.4f]' %
              (epoch, self.args.learn_rate, self.best_pred))
        train_loss = 0.0
        corrects_labels = 0
        correct_classes = 0
        correct_species = 0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)

        for i, sample in enumerate(tbar):
            self.optimizer.zero_grad()
            if self.args.dataset != 'Multi':
                image, target = sample['image'], sample['label']
                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()
                output = self.model(image)
                loss = self.criterion(output, target)

                pred_label = output.data.cpu().numpy()
                target = target.cpu().numpy()
                pred_label = np.argmax(pred_label, axis=1)
                corrects_labels += np.sum(pred_label == target)

            else:
                image, target_classes, target_species = sample[
                    'image'], sample['classes_label'], sample['species_label']
                if self.args.cuda:
                    image, target_classes, target_species = image.cuda(
                    ), target_classes.cuda(), target_species.cuda()
                output_classes, output_species = self.model(image)
                classes_loss = self.criterion(output_classes, target_classes)
                species_loss = self.criterion(output_species, target_species)
                loss = classes_loss + species_loss

                pred_classes = output_classes.data.cpu().numpy()
                pred_species = output_species.data.cpu().numpy()
                target_classes = target_classes.data.cpu().numpy()
                target_species = target_species.data.cpu().numpy()
                pred_classes = np.argmax(pred_classes, axis=1)
                pred_species = np.argmax(pred_species, axis=1)

                tmp1 = pred_classes == target_classes
                tmp2 = target_species == pred_species
                correct_classes += np.sum(tmp1)  # count of correct "class" predictions
                correct_species += np.sum(tmp2)  # count of correct "species" predictions
                corrects_labels += np.sum(tmp1
                                          & tmp2)  # element-wise AND: both "class" and "species" correct

            loss.backward()
            self.optimizer.step()

            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)
            train_loss += loss.item()
            tbar.set_description('Train loss: %.5f' % (train_loss / (i + 1)))

        # Fast test during the training
        acc = corrects_labels / len(self.train_loader.dataset)
        self.writer.add_scalar('train/Acc', acc, epoch)
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        acc_classes, acc_species = 0.0, 0.0
        if self.args.dataset == 'Multi':
            acc_classes = correct_classes / len(self.train_loader.dataset)
            acc_species = correct_species / len(self.train_loader.dataset)
            self.writer.add_scalar('train/Acc_classes', acc_classes, epoch)
            self.writer.add_scalar('train/Acc_species', acc_species, epoch)

        print('train validation:')
        if self.args.dataset != 'Multi':
            print("Acc:{}".format(acc))
        else:
            print("Acc:{}, Acc_classes:{}, Acc_species:{}".format(
                acc, acc_classes, acc_species))
        print('Loss: %.5f' % train_loss)
        print('---------------------------------')

    def validation(self, epoch):
        """Evaluate on the validation loader and save a new best checkpoint.

        The accuracy for the epoch is the fraction of validation samples
        predicted fully correctly (for 'Multi', both heads must be correct).
        When it beats ``self.best_pred``, a checkpoint is saved.

        Args:
            epoch: Zero-based epoch index. Used for logging and the checkpoint.
        """
        test_loss = 0.0
        corrects_labels = 0
        correct_classes = 0
        correct_species = 0
        self.model.eval()
        tbar = tqdm(self.val_loader, desc='\r')
        num_img_val = len(self.val_loader)

        for i, sample in enumerate(tbar):

            if self.args.dataset != 'Multi':
                image, target = sample['image'], sample['label']
                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()
                with torch.no_grad():
                    output = self.model(image)

                loss = self.criterion(output, target)

                pred_label = output.data.cpu().numpy()
                target = target.cpu().numpy()
                pred_label = np.argmax(pred_label, axis=1)
                corrects_labels += np.sum(pred_label == target)
            else:
                image, target_classes, target_species = sample[
                    'image'], sample['classes_label'], sample['species_label']
                if self.args.cuda:
                    image, target_classes, target_species = image.cuda(
                    ), target_classes.cuda(), target_species.cuda()
                with torch.no_grad():
                    output_classes, output_species = self.model(image)

                classes_loss = self.criterion(output_classes, target_classes)
                species_loss = self.criterion(output_species, target_species)
                loss = classes_loss + species_loss

                pred_classes = output_classes.data.cpu().numpy()
                pred_species = output_species.data.cpu().numpy()
                target_classes = target_classes.data.cpu().numpy()
                target_species = target_species.data.cpu().numpy()
                pred_classes = np.argmax(pred_classes, axis=1)
                pred_species = np.argmax(pred_species, axis=1)

                tmp1 = pred_classes == target_classes
                tmp2 = target_species == pred_species
                correct_classes += np.sum(tmp1)  # count of correct "class" predictions
                correct_species += np.sum(tmp2)  # count of correct "species" predictions
                corrects_labels += np.sum(tmp1
                                          & tmp2)  # element-wise AND: both "class" and "species" correct

            test_loss += loss.item()
            tbar.set_description('Test loss: %.5f' % (test_loss / (i + 1)))
            self.writer.add_scalar('val/total_loss_iter', loss.item(),
                                   i + num_img_val * epoch)

        # Fast test during the training
        acc = corrects_labels / len(self.val_loader.dataset)
        self.writer.add_scalar('val/Acc', acc, epoch)
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        acc_classes, acc_species = 0.0, 0.0
        if self.args.dataset == 'Multi':
            acc_classes = correct_classes / len(self.val_loader.dataset)
            acc_species = correct_species / len(self.val_loader.dataset)
            self.writer.add_scalar('val/Acc_classes', acc_classes, epoch)
            self.writer.add_scalar('val/Acc_species', acc_species, epoch)

        print('test validation:')
        if self.args.dataset != 'Multi':
            print("Acc:{}".format(acc))
        else:
            print("Acc:{}, Acc_classes:{}, Acc_species:{}".format(
                acc, acc_classes, acc_species))
        print('Loss: %.5f' % test_loss)
        print('====================================')

        new_pred = acc
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            # The model is only wrapped in DataParallel (and so only has a
            # ``.module``) when running with CUDA; unwrap it when present.
            bare_model = getattr(self.model, 'module', self.model)
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': bare_model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
class Trainer(object):
    """Domain-adaptive DeepLabV3+ (MobileNetV2 backbone) segmentation trainer.

    The network is split into four parts:

    * a backbone;
    * an ASPP head;
    * a segmentation decoder (``y_model``);
    * a domain classifier (``d_model``) applied to the ASPP features.

    For the 'gtav' dataset only the segmentation (task) loss is optimised.
    For other datasets each batch also carries a target-domain image, and
    the domain-classifier losses are added to the task loss.
    """

    def __init__(self, args):
        """Build data, models, optimizers, losses and scheduler; maybe resume.

        Args:
            args: Parsed command-line options. Read here include
                ``sync_bn``, ``out_stride``, ``backbone``, ``optimizer``,
                ``lr``, ``momentum``, ``weight_decay``, ``nesterov``,
                ``use_balanced_weights``, ``dataset``, ``loss_type``,
                ``lr_scheduler``, ``epochs``, ``cuda``, ``gpu_ids``,
                ``resume`` and ``ft``.

        Raises:
            NotImplementedError: if ``args.optimizer`` is neither 'SGD' nor
                'Adam'.
            RuntimeError: if ``args.resume`` is given but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        if args.sync_bn == True:
            BN = SynchronizedBatchNorm2d
        else:
            BN = nn.BatchNorm2d
        ### deeplabV3 start ###
        self.backbone_model = MobileNetV2(output_stride=args.out_stride,
                                          BatchNorm=BN)
        self.assp_model = ASPP(backbone=args.backbone,
                               output_stride=args.out_stride,
                               BatchNorm=BN)
        self.y_model = Decoder(num_classes=self.nclass,
                               backbone=args.backbone,
                               BatchNorm=BN)
        ### deeplabV3 end ###
        self.d_model = DomainClassifer(backbone=args.backbone,
                                       BatchNorm=BN)
        f_params = list(self.backbone_model.parameters()) + list(self.assp_model.parameters())
        y_params = list(self.y_model.parameters())
        d_params = list(self.d_model.parameters())

        # Define Optimizer
        # NOTE(review): the feature-extractor params (f_params) belong to
        # task_optimizer, d_inv_optimizer and c_optimizer simultaneously.
        if args.optimizer == 'SGD':
            self.task_optimizer = torch.optim.SGD(f_params + y_params, lr=args.lr,
                                                  momentum=args.momentum,
                                                  weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.d_optimizer = torch.optim.SGD(d_params, lr=args.lr,
                                               momentum=args.momentum,
                                               weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.d_inv_optimizer = torch.optim.SGD(f_params, lr=args.lr,
                                                   momentum=args.momentum,
                                                   weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.c_optimizer = torch.optim.SGD(f_params + y_params, lr=args.lr,
                                               momentum=args.momentum,
                                               weight_decay=args.weight_decay, nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            self.task_optimizer = torch.optim.Adam(f_params + y_params, lr=args.lr)
            self.d_optimizer = torch.optim.Adam(d_params, lr=args.lr)
            self.d_inv_optimizer = torch.optim.Adam(f_params, lr=args.lr)
            self.c_optimizer = torch.optim.Adam(f_params + y_params, lr=args.lr)
        else:
            raise NotImplementedError

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            # os.path.join yields the same backslash path as before on Windows
            # and a real nested path (instead of a single file name that
            # contains backslashes) on POSIX systems.
            classes_weights_path = os.path.join(
                'dataloders', 'datasets',
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.train_loader, self.nclass, classes_weights_path, self.args.dataset)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.task_loss = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.domain_loss = DomainLosses(cuda=args.cuda).build_loss()
        # Placeholder; not used elsewhere in this class.
        self.ca_loss = ''

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.backbone_model = torch.nn.DataParallel(self.backbone_model, device_ids=self.args.gpu_ids)
            self.assp_model = torch.nn.DataParallel(self.assp_model, device_ids=self.args.gpu_ids)
            self.y_model = torch.nn.DataParallel(self.y_model, device_ids=self.args.gpu_ids)
            self.d_model = torch.nn.DataParallel(self.d_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.backbone_model)
            patch_replication_callback(self.assp_model)
            patch_replication_callback(self.y_model)
            patch_replication_callback(self.d_model)
            self.backbone_model = self.backbone_model.cuda()
            self.assp_model = self.assp_model.cuda()
            self.y_model = self.y_model.cuda()
            self.d_model = self.d_model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Load into the bare modules whether or not they are wrapped in
            # DataParallel.
            self._unwrap(self.backbone_model).load_state_dict(checkpoint['backbone_model_state_dict'])
            self._unwrap(self.assp_model).load_state_dict(checkpoint['assp_model_state_dict'])
            self._unwrap(self.y_model).load_state_dict(checkpoint['y_model_state_dict'])
            self._unwrap(self.d_model).load_state_dict(checkpoint['d_model_state_dict'])
            if not args.ft:
                self.task_optimizer.load_state_dict(checkpoint['task_optimizer'])
                self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
                self.d_inv_optimizer.load_state_dict(checkpoint['d_inv_optimizer'])
                self.c_optimizer.load_state_dict(checkpoint['c_optimizer'])
            if self.args.dataset == 'gtav':
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        Each iteration updates every optimizer's learning rate through
        ``self.scheduler``, computes the task loss on the source image and,
        for non-'gtav' datasets, the domain losses on source/target features.
        If ``args.no_val`` is set, a checkpoint is saved at the end.

        Args:
            epoch: Zero-based epoch index. Used for the LR schedule and logging.
        """
        train_loss = 0.0
        train_task_loss = 0.0
        train_d_loss = 0.0
        train_d_inv_loss = 0.0
        self.backbone_model.train()
        self.assp_model.train()
        self.y_model.train()
        self.d_model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Visualise ~10 times per epoch. max() avoids a modulo-by-zero when
        # the loader has fewer than 10 batches.
        vis_interval = max(num_img_tr // 10, 1)
        for i, sample in enumerate(tbar):
            if self.args.dataset == 'gtav':
                src_image, src_label = sample['image'], sample['label']
            else:
                src_image, src_label, tgt_image = sample['src_image'], sample['src_label'], sample['tgt_image']
            if self.args.cuda:
                if self.args.dataset != 'gtav':
                    src_image, src_label, tgt_image = src_image.cuda(), src_label.cuda(), tgt_image.cuda()
                else:
                    src_image, src_label = src_image.cuda(), src_label.cuda()
            self.scheduler(self.task_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.d_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.d_inv_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.c_optimizer, i, epoch, self.best_pred)
            self.task_optimizer.zero_grad()
            self.d_optimizer.zero_grad()
            self.d_inv_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            # source image feature
            src_high_feature_0, src_low_feature = self.backbone_model(src_image)
            src_high_feature = self.assp_model(src_high_feature_0)
            src_output = F.interpolate(self.y_model(src_high_feature, src_low_feature), src_image.size()[2:],
                                       mode='bilinear', align_corners=True)

            src_d_pred = self.d_model(src_high_feature)
            task_loss = self.task_loss(src_output, src_label)

            if self.args.dataset != 'gtav':
                # target image feature
                tgt_high_feature_0, tgt_low_feature = self.backbone_model(tgt_image)
                tgt_high_feature = self.assp_model(tgt_high_feature_0)
                tgt_output = F.interpolate(self.y_model(tgt_high_feature, tgt_low_feature), tgt_image.size()[2:],
                                           mode='bilinear', align_corners=True)
                tgt_d_pred = self.d_model(tgt_high_feature)

                d_loss, d_acc = self.domain_loss(src_d_pred, tgt_d_pred)
                d_inv_loss, _ = self.domain_loss(tgt_d_pred, src_d_pred)
                loss = task_loss + d_loss + d_inv_loss
                loss.backward()
                # NOTE(review): task_optimizer and d_inv_optimizer both own
                # the backbone/ASPP params and both step on the same combined
                # gradient, so those params are updated twice per iteration --
                # confirm this is intended.
                self.task_optimizer.step()
                self.d_optimizer.step()
                self.d_inv_optimizer.step()
            else:
                d_acc = 0
                d_loss = torch.tensor(0.0)
                d_inv_loss = torch.tensor(0.0)
                loss = task_loss
                loss.backward()
                self.task_optimizer.step()

            train_task_loss += task_loss.item()
            train_d_loss += d_loss.item()
            train_d_inv_loss += d_inv_loss.item()
            train_loss += task_loss.item() + d_loss.item() + d_inv_loss.item()

            tbar.set_description('Train loss: %.3f t_loss: %.3f d_loss: %.3f , d_inv_loss: %.3f  d_acc: %.2f'
                                 % (train_loss / (i + 1), train_task_loss / (i + 1),
                                    train_d_loss / (i + 1), train_d_inv_loss / (i + 1), d_acc * 100))

            self.writer.add_scalar('train/task_loss_iter', task_loss.item(), i + num_img_tr * epoch)
            # Show 10 * 3 inference results each epoch
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                if self.args.dataset != 'gtav':
                    image = torch.cat([src_image, tgt_image], dim=0)
                    output = torch.cat([src_output, tgt_output], dim=0)
                else:
                    image = src_image
                    output = src_output
                self.summary.visualize_image(self.writer, self.args.dataset, image, src_label, output, global_step)

        self.writer.add_scalar('train/task_loss_epoch', train_task_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + src_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(self._checkpoint_state(epoch), is_best)

    def validation(self, epoch):
        """Evaluate segmentation quality on ``self.val_loader``.

        Logs loss, mIoU, pixel accuracy, class accuracy and fwIoU. A
        checkpoint is saved whenever mIoU beats ``self.best_pred``.

        Args:
            epoch: Zero-based epoch index. Used for logging and the checkpoint.
        """
        self.backbone_model.eval()
        self.assp_model.eval()
        self.y_model.eval()
        self.d_model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                high_feature, low_feature = self.backbone_model(image)
                high_feature = self.assp_model(high_feature)
                output = F.interpolate(self.y_model(high_feature, low_feature), image.size()[2:],
                                       mode='bilinear', align_corners=True)
            task_loss = self.task_loss(output, target)
            test_loss += task_loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU, IoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU

        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(self._checkpoint_state(epoch), is_best)

    @staticmethod
    def _unwrap(model):
        """Return the wrapped module of a DataParallel model, else the model."""
        # Models are only wrapped in DataParallel when args.cuda is set.
        return getattr(model, 'module', model)

    def _checkpoint_state(self, epoch):
        """Build the checkpoint dict written after finishing ``epoch``."""
        return {
            'epoch': epoch + 1,
            'backbone_model_state_dict': self._unwrap(self.backbone_model).state_dict(),
            'assp_model_state_dict': self._unwrap(self.assp_model).state_dict(),
            'y_model_state_dict': self._unwrap(self.y_model).state_dict(),
            'd_model_state_dict': self._unwrap(self.d_model).state_dict(),
            'task_optimizer': self.task_optimizer.state_dict(),
            'd_optimizer': self.d_optimizer.state_dict(),
            'd_inv_optimizer': self.d_inv_optimizer.state_dict(),
            'c_optimizer': self.c_optimizer.state_dict(),
            'best_pred': self.best_pred,
        }
class Trainer(object):
    """DeepLab trainer whose model consumes an image plus a region-proposal map.

    Every sample provides ``'image'``, ``'rp'`` (region proposal) and
    ``'label'``. Training uses a cross-entropy loss with per-batch class
    weights.

    Validation reports mIoU together with class-2 metrics from the
    evaluator (recall/precision, IDR at several thresholds, false IDR and
    instance IoU). In 'train' mode it also keeps the best-mIoU checkpoint.
    """

    def __init__(self, args):
        """Build data, model, optimizer, criterion and scheduler; maybe resume.

        Args:
            args: Parsed command-line options. Read here include
                ``workers``, ``backbone``, ``out_stride``, ``sync_bn``,
                ``freeze_bn``, ``lr``, ``momentum``, ``weight_decay``,
                ``nesterov``, ``cuda``, ``gpu_ids``, ``lr_scheduler``,
                ``epochs``, ``resume``, ``ft`` and ``mode``.

        Raises:
            RuntimeError: if ``args.resume`` is given but is not a file.
        """
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # The 10x parameter group is trained with a ten times larger LR.
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion

        self.criterion = SegmentationLosses(cuda=args.cuda)
        self.model, self.optimizer = model, optimizer
        self.contexts = TemporalContexts(history_len=5)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning or in validation/test mode
        if args.ft or args.mode in ("val", "test"):
            args.start_epoch = 0
            self.best_pred = 0.0

    def training(self, epoch):
        """Run one training epoch and log losses and prediction grids.

        If ``args.no_val`` or ``args.save_all`` is set, a per-epoch
        checkpoint is saved.

        Args:
            epoch: Zero-based epoch index. Used for the LR schedule and logging.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        # Plot ~20 times per epoch. max() avoids a modulo-by-zero when the
        # loader has fewer than 20 batches.
        plot_interval = max(num_img_tr // 20, 1)

        for i, sample in enumerate(tbar):
            image, region_prop, target = sample['image'], sample['rp'], sample[
                'label']
            if self.args.cuda:
                image, region_prop, target = image.cuda(), region_prop.cuda(
                ), target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image, region_prop)
            # Class weights are recomputed from each batch.
            loss = self.criterion.CrossEntropyLoss(
                output,
                target,
                weight=torch.from_numpy(
                    calculate_weights_batch(sample,
                                            self.nclass).astype(np.float32)))
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            if i % plot_interval == 0:
                # Copy to CPU and compute softmax/argmax only when plotting.
                pred = output.clone().data.cpu()
                pred_softmax = F.softmax(pred, dim=1).numpy()
                pred = np.argmax(pred.numpy(), axis=1)

                global_step = i + num_img_tr * epoch
                self.summary.vis_grid(self.writer,
                                      self.args.dataset,
                                      image.data.cpu().numpy()[0],
                                      target.data.cpu().numpy()[0],
                                      pred[0],
                                      region_prop.data.cpu().numpy()[0],
                                      pred_softmax[0],
                                      global_step,
                                      split="Train")

        self.writer.add_scalar('train/total_loss_epoch',
                               train_loss / num_img_tr, epoch)
        print('Loss: {}'.format(train_loss / num_img_tr))

        if self.args.no_val or self.args.save_all:
            # save checkpoint every epoch
            is_best = False
            # Only wrapped in DataParallel (with ``.module``) under CUDA.
            bare_model = getattr(self.model, 'module', self.model)
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': bare_model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                },
                is_best,
                filename='checkpoint_' + str(epoch + 1) + '_.pth.tar')

    def validation(self, epoch):
        """Evaluate on the val loader ('train'/'val' mode) or test loader ('test').

        Every batch is added to the evaluator and visualised. Epoch-level
        metrics are logged afterwards. In 'train' mode a checkpoint is saved
        when mIoU improves on ``self.best_pred``.

        Args:
            epoch: Zero-based epoch index. Used for logging and the checkpoint.

        Raises:
            ValueError: if ``args.mode`` is not 'train', 'val' or 'test'.
        """
        if self.args.mode in ("train", "val"):
            loader = self.val_loader
        elif self.args.mode == "test":
            loader = self.test_loader
        else:
            # Any other mode has no loader to evaluate on.
            raise ValueError(
                "Unsupported mode for validation: {}".format(self.args.mode))

        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(loader, desc='\r')

        test_loss = 0.0
        idr_thresholds = [0.20, 0.30, 0.40, 0.50, 0.60, 0.65]

        num_itr = len(loader)

        for i, sample in enumerate(tbar):
            image, region_prop, target = sample['image'], sample['rp'], sample[
                'label']
            # orig_region_prop = region_prop.clone()
            # region_prop = self.contexts.temporal_prop(image.numpy(),region_prop.numpy())

            if self.args.cuda:
                image, region_prop, target = image.cuda(), region_prop.cuda(
                ), target.cuda()
            with torch.no_grad():
                output = self.model(image, region_prop)

            # loss = self.criterion.CrossEntropyLoss(output,target,weight=torch.from_numpy(calculate_weights_batch(sample,self.nclass).astype(np.float32)))
            # test_loss += loss.item()
            # tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            output = output.detach().data.cpu()
            pred_softmax = F.softmax(output, dim=1).numpy()
            pred = np.argmax(pred_softmax, axis=1)
            target = target.cpu().numpy()
            image = image.cpu().numpy()
            region_prop = region_prop.cpu().numpy()
            # orig_region_prop = orig_region_prop.numpy()

            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

            # Append buffer with original context(before temporal propagation)
            # self.contexts.append_buffer(image[0],orig_region_prop[0],pred[0])

            global_step = i + num_itr * epoch
            self.summary.vis_grid(self.writer,
                                  self.args.dataset,
                                  image[0],
                                  target[0],
                                  pred[0],
                                  region_prop[0],
                                  pred_softmax[0],
                                  global_step,
                                  split="Validation")

        # Fast test during the training
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        recall, precision = self.evaluator.pdr_metric(class_id=2)
        idr_avg = np.array([
            self.evaluator.get_idr(class_value=2, threshold=value)
            for value in idr_thresholds
        ])
        false_idr = self.evaluator.get_false_idr(class_value=2)
        instance_iou = self.evaluator.get_instance_iou(threshold=0.20,
                                                       class_value=2)

        # self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Recall/per_epoch', recall, epoch)
        self.writer.add_scalar('IDR/per_epoch(0.20)', idr_avg[0], epoch)
        self.writer.add_scalar('IDR/avg_epoch', np.mean(idr_avg), epoch)
        self.writer.add_scalar('False_IDR/epoch', false_idr, epoch)
        self.writer.add_scalar('Instance_IOU/epoch', instance_iou, epoch)
        self.writer.add_histogram(
            'Prediction_hist',
            self.evaluator.pred_labels[self.evaluator.gt_labels == 2], epoch)

        print('Validation:')
        # print('Loss: %.3f' % test_loss)
        # print('Recall/PDR:{}'.format(recall))
        print('IDR:{}'.format(idr_avg[0]))
        print('False Positive Rate: {}'.format(false_idr))
        print('Instance_IOU: {}'.format(instance_iou))

        # Checkpoints are only tracked while training.
        if self.args.mode == "train":
            new_pred = mIoU
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
                # Only wrapped in DataParallel (with ``.module``) under CUDA.
                bare_model = getattr(self.model, 'module', self.model)
                self.saver.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': bare_model.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'best_pred': self.best_pred,
                    }, is_best)
Пример #27
0
def main():
    """Run the meta-learning architecture search loop.

    Seeds the numpy / random / torch RNGs from ``args.seed`` and sets up
    file + console logging and TensorBoard under the ``Saver`` experiment
    directory. Exits with status 1 when no GPU is available. Otherwise it
    builds the ``Meta`` learner and train/valid/test ``MiniImagenet`` episode
    loaders. Then, for ``args.epoch`` epochs, it:

    * runs ``meta_train`` (with the ``Architect``) and ``meta_test``;
    * logs the current genotype and the softmaxed reduce-cell alphas;
    * writes a checkpoint, flagged as best when the last test accuracy improves.

    Relies on the module-level ``args`` namespace.
    """
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Safe without CUDA: the CUDA seed is applied lazily / ignored.
    torch.cuda.manual_seed(args.seed)

    saver = Saver(args)
    # set log
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    # Check for a GPU *before* touching the CUDA device: torch.cuda.set_device
    # raises on a CUDA-less machine, which would pre-empt this clean exit.
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') +
                         glob.glob('*.sh') + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()
    best_pred = 0

    logging.info(args)

    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss().to(device)
    maml = Meta(args, criterion).to(device)

    num = sum(p.numel() for p in maml.parameters() if p.requires_grad)
    logging.info(maml)
    logging.info('Total trainable tensors: {}'.format(num))

    # batch_size here means total episode number
    mini = MiniImagenet(args.data_path,
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batch_size=args.batch_size,
                        resize=args.img_size,
                        split=[0, args.train_portion])
    mini_valid = MiniImagenet(args.data_path,
                              mode='train',
                              n_way=args.n_way,
                              k_shot=args.k_spt,
                              k_query=args.k_qry,
                              batch_size=args.batch_size,
                              resize=args.img_size,
                              split=[args.train_portion, 1])
    # NOTE(review): the "test" episodes are drawn from the same
    # [train_portion, 1] slice of the 'train' split as validation, only with
    # a different batch size -- confirm this is intended.
    mini_test = MiniImagenet(args.data_path,
                             mode='train',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batch_size=args.test_batch_size,
                             resize=args.img_size,
                             split=[args.train_portion, 1])
    train_queue = DataLoader(mini,
                             args.meta_batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    valid_queue = DataLoader(mini_valid,
                             args.meta_batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    test_queue = DataLoader(mini_test,
                            args.meta_test_batch_size,
                            shuffle=True,
                            num_workers=args.num_workers,
                            pin_memory=True)
    architect = Architect(maml.model, args)

    for epoch in range(args.epoch):
        # fetch batch_size num of episode each time
        logging.info('--------- Epoch: {} ----------'.format(epoch))

        train_accs = meta_train(train_queue, valid_queue, maml, architect,
                                device, criterion, epoch, writer)
        logging.info('[Epoch: {}]\t Train acc: {}'.format(epoch, train_accs))
        test_accs = meta_test(test_queue, maml, device, epoch, writer)
        logging.info('[Epoch: {}]\t Test acc: {}'.format(epoch, test_accs))

        genotype = maml.model.genotype()
        logging.info('genotype = %s', genotype)

        # logging.info(F.softmax(maml.model.alphas_normal, dim=-1))
        logging.info(F.softmax(maml.model.alphas_reduce, dim=-1))

        # Save the meta model every epoch; flag it as best when the last
        # test accuracy improves on the best seen so far.
        new_pred = test_accs[-1]
        is_best = new_pred > best_pred
        if is_best:
            best_pred = new_pred
        model_to_save = maml.module if isinstance(maml, nn.DataParallel) else maml
        saver.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model_to_save.state_dict(),
                'best_pred': best_pred,
            }, is_best)
Пример #28
0
class Trainer(object):
    """DeepLab segmentation trainer configured by a nested ``config`` dict.

    Builds the data loaders, a ``DeepLab`` model and an SGD optimizer. The
    optimizer has two parameter groups: ``get_1x_lr_params()`` at the base LR
    and ``get_10x_lr_params()`` at 10x the base LR. It also builds the
    segmentation loss, an ``Evaluator`` and an ``LR_Scheduler``, optionally
    wraps the model in ``DataParallel``, and can restore weights, optimizer
    state and ``best_pred`` from a checkpoint.
    """

    def __init__(self, config):
        self.config = config
        self.best_pred = 0.0
        train_cfg = config['training']
        net_cfg = config['network']
        data_cfg = config['dataset']

        # Define Saver
        self.saver = Saver(config)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(train_cfg['tensorboard']['log_dir'])
        self.writer = self.summary.create_summary()

        self.train_loader, self.val_loader, self.test_loader, self.nclass = initialize_data_loader(config)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=net_cfg['backbone'],
                        output_stride=config['image']['out_stride'],
                        sync_bn=net_cfg['sync_bn'],
                        freeze_bn=net_cfg['freeze_bn'])

        train_params = [{'params': model.get_1x_lr_params(), 'lr': train_cfg['lr']},
                        {'params': model.get_10x_lr_params(), 'lr': train_cfg['lr'] * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=train_cfg['momentum'],
                                    weight_decay=train_cfg['weight_decay'],
                                    nesterov=train_cfg['nesterov'])

        # Define Criterion
        # Optional class-balanced weights, cached next to the dataset as
        # "<dataset_name>_classes_weights.npy" and computed if missing.
        if train_cfg['use_balanced_weights']:
            classes_weights_path = os.path.join(data_cfg['base_path'], data_cfg['dataset_name'] + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(config, data_cfg['dataset_name'], self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=net_cfg['use_cuda']).build_loss(mode=train_cfg['loss_type'])
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(train_cfg['lr_scheduler'], train_cfg['lr'],
                                      train_cfg['epochs'], len(self.train_loader))

        # Using cuda
        if net_cfg['use_cuda']:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        init_cfg = train_cfg['weights_initialization']
        if init_cfg['use_pretrained_weights']:
            restore_from = init_cfg['restore_from']
            if not os.path.isfile(restore_from):
                raise RuntimeError("=> no checkpoint found at '{}'".format(restore_from))

            if net_cfg['use_cuda']:
                checkpoint = torch.load(restore_from)
            else:
                # Remap every CUDA storage (not only cuda:0) onto the CPU.
                checkpoint = torch.load(restore_from, map_location='cpu')

            train_cfg['start_epoch'] = checkpoint['epoch']

            # NOTE(review): checkpoints are written from self.model.state_dict(),
            # so ones saved with use_cuda=True carry a 'module.' key prefix and
            # will not load into the unwrapped CPU model -- confirm intended.
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(restore_from, checkpoint['epoch']))

    @staticmethod
    def _log_interval(num_batches, times_per_epoch=10):
        """Return the batch stride giving about ``times_per_epoch`` events per epoch.

        Never returns 0, so ``i % interval`` stays valid for loaders with
        fewer than ``times_per_epoch`` batches (where
        ``num_batches // times_per_epoch`` alone would be 0 and raise
        ZeroDivisionError).
        """
        return max(num_batches // times_per_epoch, 1)

    def training(self, epoch):
        """Run one training epoch.

        Steps the LR scheduler every batch and logs per-iteration and
        per-epoch loss to TensorBoard. It also visualises predictions for
        about 10 batches and saves ``checkpoint_last.pth.tar``. When
        ``train_on_subset`` is enabled, it calls the dataset's
        ``shuffle_dataset()`` at the end.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_interval = self._log_interval(num_img_tr)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.config['network']['use_cuda']:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Visualise predictions for roughly 10 batches per epoch.
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.config['dataset']['dataset_name'], image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config['training']['batch_size'] + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        # Save the latest state every epoch, independently of validation.
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, is_best=False, filename='checkpoint_last.pth.tar')

        # If training on a subset, reshuffle the data for the next epoch.
        if self.config['training']['train_on_subset']['enabled']:
            self.train_loader.dataset.shuffle_dataset()

    def validation(self, epoch):
        """Evaluate on ``val_loader``, log metrics, and save a best checkpoint.

        A checkpoint named ``checkpoint_best.pth.tar`` is written whenever the
        validation mIoU exceeds ``self.best_pred``.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.config['network']['use_cuda']:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Val loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.config['training']['batch_size'] + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best=True, filename='checkpoint_best.pth.tar')
Пример #29
0
class Trainer(object):
    """3-D DeepLab trainer using a Dice + cross-entropy criterion.

    Builds the data loaders, a ``DeepLab3d`` model and an SGD optimizer. The
    optimizer uses ``get_1x_lr_params()`` at the base LR and
    ``get_10x_lr_params()`` at 10x the base LR. It also builds ``DiceCELoss``,
    an ``Evaluator`` and an ``LR_Scheduler``, optionally wraps the model in
    ``DataParallel``, and can resume from ``args.resume``.
    """

    def __init__(self, args):
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Dataloader
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if DEBUG:
            print("get device: ",self.device)
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab3d(num_classes=self.nclass,
                          backbone=args.backbone,
                          output_stride=args.out_stride,
                          sync_bn=args.sync_bn,
                          freeze_bn=args.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        # NOTE(review): `weight` is computed (and possibly cached to disk by
        # calculate_weigths_labels) but never passed to DiceCELoss below,
        # and it is not moved to the GPU -- confirm whether it should be used.
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(ROOT_PATH, args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = DiceCELoss()
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            # Load onto the CPU first; load_state_dict copies tensors onto the
            # model's own device, so this also works without a GPU.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            self._unwrap_model(self.model).load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _unwrap_model(model):
        """Return ``model.module`` for a ``DataParallel`` wrapper, else ``model``.

        The network is only wrapped when ``args.cuda`` is set, so checkpoint
        save/restore must handle both cases.
        """
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    @staticmethod
    def _minmax_normalize(output, eps=1e-12):
        """Min-max normalise every (sample, channel) volume of ``output``.

        ``output`` is expected to be 5-D ``(N, C, D, W, H)``. Min and max are
        taken over all trailing dims of each ``[n, c]`` slice, mapping it to
        [0, 1]. The denominator is clamped to ``eps`` so a constant slice maps
        to zeros instead of NaN (0/0). The result keeps ``output``'s device and
        dtype and stays in the autograd graph.
        """
        flat = output.flatten(start_dim=2)
        lo = flat.min(dim=2, keepdim=True)[0]
        hi = flat.max(dim=2, keepdim=True)[0]
        return ((flat - lo) / (hi - lo).clamp(min=eps)).reshape(output.shape)

    def training(self, epoch):
        """Run one epoch over ``train_loader`` with the Dice + CE criterion.

        Prints the summed epoch loss and the per-batch mean Dice / CE losses.
        When ``args.no_val`` is set, also saves a (non-best) checkpoint.
        """
        train_loss = 0.0
        dice_loss_count = 0.0
        ce_loss_count = 0.0
        num_count = 0

        # NOTE(review): `self.model.train()` was commented out by the author, so
        # the network is trained in eval mode (BatchNorm uses running stats,
        # dropout disabled). Kept as-is -- confirm this is intentional.
        self.model.eval()

        tbar = tqdm(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample
            if DEBUG:
                print("image, target size feed in model,", image.size(), target.size())
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()

            output = self.model(image)
            if DEBUG:
                print(output.size())
            # Per-(n, c) min-max normalised copy of the logits for the criterion.
            output_norm = self._minmax_normalize(output)
            loss, dice_loss, ce_loss = self.criterion(output, output_norm, target, self.device)

            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            dice_loss_count += dice_loss.item()
            ce_loss_count += ce_loss.item()
            num_count += 1

        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        # Summed loss over the epoch, but per-batch mean Dice / CE losses.
        print('Loss: %.3f, dice loss: %.3f, ce loss: %.3f' % (train_loss, dice_loss_count/num_count, ce_loss_count/num_count))

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self._unwrap_model(self.model).state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        """Evaluate on ``val_loader``; save a best checkpoint when mIoU improves."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')

        test_loss = 0.0
        dice_loss = 0.0
        ce_loss = 0.0
        num_count = 0
        for i, sample in enumerate(tbar):
            image, target = sample
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.model(image)
            # Per-(n, c) min-max normalised copy of the logits for the criterion.
            output_norm = self._minmax_normalize(output)

            loss, dice, ce = self.criterion(output, output_norm, target, self.device)
            test_loss += loss.item()
            dice_loss += dice.item()
            ce_loss += ce.item()
            num_count += 1
            tbar.set_description('Test loss: %.3f, dice loss: %.3f, ce loss: %.3f' % (test_loss / (i + 1), dice_loss / num_count, ce_loss / num_count))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            if DEBUG:
                print("check gt_image shape, pred img shape ",target.shape, pred.shape)
            self.evaluator.add_batch(np.squeeze(target), np.squeeze(pred))

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss, dice_loss, ce_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self._unwrap_model(self.model).state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
            print("ltt save ckpt!")
Пример #30
0
class Trainer(object):
    """DeepLab trainer whose samples carry the input under the ``'trace'`` key.

    Builds the data loaders, a ``DeepLab`` model and an SGD optimizer. The
    optimizer uses ``get_1x_lr_params()`` at the base LR and
    ``get_10x_lr_params()`` at 10x the base LR. It also builds the
    segmentation loss, an ``Evaluator`` and an ``LR_Scheduler``, optionally
    wraps the model in ``DataParallel``, and can resume from ``args.resume``.
    """

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Debug print of the model configuration, e.g. "2 resnet 16 False False".
        print(self.nclass, args.backbone, args.out_stride, args.sync_bn,
              args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # Optional class-balanced weights, cached as
        # "<dataset>_classes_weights.npy" and computed if missing.
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            self._unwrap_model(self.model).load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    @staticmethod
    def _unwrap_model(model):
        """Return ``model.module`` for a ``DataParallel`` wrapper, else ``model``.

        The network is only wrapped when ``args.cuda`` is set, so checkpoint
        save/restore must handle both cases; using ``.module`` unconditionally
        raises AttributeError on CPU runs.
        """
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    def _checkpoint_state(self):
        """Build the checkpoint dict (minus ``epoch``) for the current state."""
        return {
            'state_dict': self._unwrap_model(self.model).state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }

    def training(self, epoch):
        """Run one training epoch and log the loss to TensorBoard.

        When ``args.no_val`` is set, also saves a (non-best) checkpoint.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['trace'], sample['label']

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                dict({'epoch': epoch + 1}, **self._checkpoint_state()), is_best)

    def validation(self, epoch):
        """Evaluate on ``val_loader``, log metrics, and save a best checkpoint.

        A checkpoint is saved (flagged best) whenever validation mIoU exceeds
        ``self.best_pred``.
        """
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['trace'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                dict({'epoch': epoch + 1}, **self._checkpoint_state()), is_best)
class Trainer(object):
    """Train and evaluate a DeepLab segmentation model configured from ``args``.

    Wires together the data loaders, model, SGD optimizer, loss, LR scheduler
    and evaluator. In training mode (neither ``args.inference`` nor
    ``args.eval``) it also creates a checkpoint ``Saver`` and a TensorBoard
    writer. In inference/eval mode ``self.saver``, ``self.summary`` and
    ``self.writer`` are never created.
    """

    def __init__(self, args):
        self.args = args
        is_training_run = not args.inference and not args.eval

        # Define Saver and Tensorboard Summary (training runs only)
        if is_training_run:
            self.saver = Saver(args)
            self.saver.save_experiment_config()
            self.summary = TensorboardSummary(self.saver.experiment_dir)
            self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Two parameter groups: one trained at the base LR and one at 10x the
        # base LR, as split by the model's get_1x / get_10x helpers.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion, optionally with class-balanced weights. Weights are
        # loaded from a cached .npy file when present, otherwise computed from
        # the training loader.
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda: wrap in DataParallel; checkpoints store the unwrapped
        # module's state dict (see _bare_model).
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self._bare_model(self.model).load_state_dict(checkpoint['state_dict'])
            # When fine-tuning, keep a fresh optimizer state.
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # Palette for inferencing: RGB colour per class index, used by `combined`.
        self.palette = np.asarray([[0, 0, 0], [217, 83, 79], [91, 192, 222]], dtype=np.uint8)

    @staticmethod
    def _bare_model(model):
        """Return ``model.module`` for a DataParallel wrapper, else ``model``.

        Without CUDA the model is never wrapped, so accessing ``.module``
        unconditionally would raise AttributeError.
        """
        return model.module if isinstance(model, torch.nn.DataParallel) else model

    @staticmethod
    def _vis_interval(num_batches):
        """Batch interval between visualisations: ~10 per epoch, never 0."""
        return max(num_batches // 10, 1)

    def _save_checkpoint(self, epoch, is_best):
        """Save model/optimizer state for ``epoch`` through ``self.saver``."""
        self.saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': self._bare_model(self.model).state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }, is_best)

    def training(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        Logs per-iteration and per-epoch loss to TensorBoard and periodically
        visualises a batch. When ``args.no_val`` is set, a (non-best)
        checkpoint is saved after every epoch.
        """
        train_loss = 0.0
        num_images = 0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        vis_interval = self._vis_interval(num_img_tr)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            # The scheduler sets the LR for this iteration before the step.
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            loss_value = loss.item()
            train_loss += loss_value
            num_images += image.shape[0]
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss_value, i + num_img_tr * epoch)

            # Visualise a batch roughly 10 times per epoch.
            if i % vis_interval == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            self._save_checkpoint(epoch, is_best=False)

    def id(self, pred, filename):
        """Save ``pred[0]`` (a label map) as an 8-bit image ``./prd/<filename>``.

        NOTE: the name shadows the ``id`` builtin inside the class namespace;
        it is kept for backward compatibility with existing callers.
        """
        save_dir = './prd'
        os.makedirs(save_dir, exist_ok=True)
        result = Image.fromarray(pred[0].astype('uint8'))
        result.save(os.path.join(save_dir, filename))

    @staticmethod
    def _overlay(pred, origin, palette):
        """Blend class colours onto the de-normalised first image of a batch.

        Args:
            pred: integer label maps; ``pred[0]`` is the (H, W) map used.
            origin: normalised image batch; ``origin[0]`` is (3, H, W).
            palette: (num_classes, 3) RGB colour table indexed by label.

        Returns:
            (H, W, 3) uint8 array: the original pixel where the label is 0,
            otherwise an equal mix of the pixel and the class colour.
        """
        image = np.asarray(origin)[0].transpose(1, 2, 0)
        # Undo input normalisation (presumably per-channel std/mean of the
        # dataset — TODO confirm against the data loader) and scale to 0..255.
        image = (image * np.array([0.197, 0.198, 0.201]) + np.array([0.279, 0.293, 0.290])) * 255
        labels = np.asarray(pred)[0]
        colours = palette[labels]
        foreground = (labels != 0)[..., None]
        blended = np.where(foreground, image / 2 + colours / 2, image)
        # Clip before the cast: de-normalised values can leave [0, 255] and
        # astype(uint8) would otherwise wrap around.
        return np.clip(blended, 0, 255).astype(np.uint8)

    def combined(self, pred, origin, filename):
        """Save an RGB overlay of ``pred[0]`` on ``origin[0]`` to ``./prd/<filename>``."""
        save_dir = './prd'
        os.makedirs(save_dir, exist_ok=True)
        result = Image.fromarray(self._overlay(pred, origin, self.palette), 'RGB')
        result.save(os.path.join(save_dir, filename))

    def _write_submission(self):
        """Predict every test-set sample and save its label map via ``id``.

        Only ``sample['name'][0]`` and ``pred[0]`` are used, so this presumably
        expects a test batch size of 1 — TODO confirm.
        """
        tbar = tqdm(self.test_loader, desc='\r')
        for sample in tbar:
            image = sample['image']
            name = sample['name'][0]
            tbar.set_description('%s' % name)
            if self.args.cuda:
                image = image.cuda()
            with torch.no_grad():
                output = self.model(image)
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            self.id(pred, name)
        print("All done clear, good luck!")

    def validation(self, epoch, inference=False):
        """Evaluate on the validation set (or write test predictions).

        With ``args.submit`` set, predictions for the test set are written to
        ``./prd`` and nothing else happens. With ``inference=True`` only the
        first 100 validation batches are evaluated and an overlay image is
        saved per batch. In training runs, metrics are logged to TensorBoard
        and a checkpoint is saved whenever mIoU beats ``self.best_pred``.
        """
        self.model.eval()
        self.evaluator.reset()
        if self.args.submit:
            self._write_submission()
            return

        is_training_run = not self.args.inference and not self.args.eval
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        num_images = 0
        for i, sample in enumerate(tbar):
            if inference and i >= 100:
                break
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            num_images += image.shape[0]
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            target = target.cpu().numpy()
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            if inference:
                self.combined(pred, image.cpu().numpy(), str(i) + '.png')
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        if is_training_run:
            self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
            self.writer.add_scalar('val/mIoU', mIoU, epoch)
            self.writer.add_scalar('val/Acc', Acc, epoch)
            self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
            self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, num_images))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        # mIoU is the model-selection metric.
        if is_training_run and mIoU > self.best_pred:
            self.best_pred = mIoU
            self._save_checkpoint(epoch, is_best=True)