Example #1
class Trainer(object):
    def __init__(self, mode):
        # Define Saver
        self.saver = Saver(opt, mode)
        self.logger = self.saver.logger

        # Visualize
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Dataset dataloader
        self.train_dataset, self.train_loader = make_data_loader(opt)
        self.nbatch_train = len(self.train_loader)
        self.val_dataset, self.val_loader = make_data_loader(opt, mode="val")
        self.nbatch_val = len(self.val_loader)

        # Model
        # Default sync_bn from the GPU count only when it was not set explicitly
        if opt.sync_bn is None:
            opt.sync_bn = len(opt.gpu_id) > 1
        # model = DeepLab(opt)
        # model = CSRNet()
        model = CRGNet(opt)
        model_info(model, self.logger)
        self.model = model.to(opt.device)

        # Loss
        if opt.use_balanced_weights:
            classes_weights_file = osp.join(opt.root_dir, 'train_classes_weights.npy')
            if os.path.isfile(classes_weights_file):
                weight = np.load(classes_weights_file)
            else:
                weight = calculate_weigths_labels(
                    self.train_loader, opt.root_dir)
            print(weight)
            opt.loss['weight'] = weight
        self.loss = build_loss(opt.loss)

        # Define Evaluator
        self.evaluator = Evaluator()  # use region to eval: class_num is 2

        # Resuming Checkpoint
        self.best_pred = 0.0
        self.start_epoch = 0
        if opt.resume:
            if os.path.isfile(opt.pre):
                print("=> loading checkpoint '{}'".format(opt.pre))
                checkpoint = torch.load(opt.pre)
                self.start_epoch = checkpoint['epoch']
                self.best_pred = checkpoint['best_pred']
                self.model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(opt.pre, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(opt.pre))

        if len(opt.gpu_id) > 1:
            print("Using multiple gpu")
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=opt.gpu_id)

        # Define Optimizer
        # train_params = [{'params': model.get_1x_lr_params(), 'lr': opt.lr},
        #                 {'params': model.get_10x_lr_params(), 'lr': opt.lr * 10}]
        # self.optimizer = torch.optim.SGD(train_params,
        #                                  momentum=opt.momentum,
        #                                  weight_decay=opt.decay)
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=opt.lr,
                                         momentum=opt.momentum,
                                         weight_decay=opt.decay)

        # Define lr scheduler
        # self.scheduler = LR_Scheduler(mode=opt.lr_scheduler,
        #                               base_lr=opt.lr,
        #                               num_epochs=opt.epochs,
        #                               iters_per_epoch=self.nbatch_train,
        #                               lr_step=140)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=[round(opt.epochs * x) for x in opt.steps],
            gamma=opt.gamma)

        # Time
        self.loss_hist = collections.deque(maxlen=500)
        self.timer = Timer(opt.epochs, self.nbatch_train, self.nbatch_val)
        self.step_time = collections.deque(maxlen=opt.print_freq)

    def train(self, epoch):
        self.model.train()
        if opt.freeze_bn:
            if len(opt.gpu_id) > 1:
                self.model.module.freeze_bn()
            else:
                self.model.freeze_bn()
        last_time = time.time()
        epoch_loss = []
        for iter_num, sample in enumerate(self.train_loader):
            # if iter_num >= 100: break
            try:
                imgs = sample["image"].to(opt.device)
                labels = sample["label"].to(opt.device)

                output = self.model(imgs)

                loss = self.loss(output, labels)
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3)
                self.loss_hist.append(float(loss))
                epoch_loss.append(float(loss.cpu().item()))

                self.optimizer.step()
                self.optimizer.zero_grad()
                # self.scheduler(self.optimizer, iter_num, epoch)

                # Visualize
                global_step = iter_num + self.nbatch_train * epoch + 1
                self.writer.add_scalar('train/loss', loss.cpu().item(), global_step) 
                batch_time = time.time() - last_time
                last_time = time.time()
                eta = self.timer.eta(global_step, batch_time)
                self.step_time.append(batch_time)
                if global_step % opt.print_freq == 0:
                    printline = ('Epoch: [{}][{}/{}] '
                                 'lr: {:1.5f}, '  # 10x:{:1.5f}), '
                                 'eta: {}, time: {:1.1f}, '
                                 'Loss: {:1.4f} '.format(
                                    epoch, iter_num+1, self.nbatch_train,
                                    self.optimizer.param_groups[0]['lr'],
                                    # self.optimizer.param_groups[1]['lr'],
                                    eta, np.sum(self.step_time),
                                    np.mean(self.loss_hist)))
                    self.logger.info(printline)

                del loss

            except Exception as e:
                print(e)
                continue

        self.scheduler.step()

    def validate(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0
        with torch.no_grad():
            tbar = tqdm(self.val_loader, desc='\r')
            for i, sample in enumerate(tbar):
                # if i > 3: break
                imgs = sample['image'].to(opt.device)
                labels = sample['label'].to(opt.device)
                path = sample["path"]

                output = self.model(imgs)

                loss = self.loss(output, labels)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.4f' % (test_loss / (i + 1)))

                # Visualize
                global_step = i + self.nbatch_val * epoch + 1
                if global_step % opt.plot_every == 0:
                    # pred = output.data.cpu().numpy()
                    if output.shape[1] > 1:
                        pred = torch.argmax(output, dim=1)
                    else:
                        pred = torch.clamp(output, min=0)
                    self.summary.visualize_image(self.writer,
                                                 opt.dataset,
                                                 imgs,
                                                 labels,
                                                 pred,
                                                 global_step)

                # metrics
                pred = output.data.cpu().numpy()
                target = labels.cpu().numpy() > 0
                if pred.shape[1] > 1:
                    pred = np.argmax(pred, axis=1)
                pred = (pred > opt.region_thd).reshape(target.shape)
                self.evaluator.add_batch(target, pred, path, opt.dataset)

            # Fast test during the training
            Acc = self.evaluator.Pixel_Accuracy()
            Acc_class = self.evaluator.Pixel_Accuracy_Class()
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            RRecall = self.evaluator.Region_Recall()
            RNum = self.evaluator.Region_Num()
            mean_loss = test_loss / self.nbatch_val
            result = 2 / (1 / mIoU + 1 / RRecall)
            titles = ["mean_loss", "mIoU", "Acc", "Acc_class", "fwIoU", "RRecall", "RNum", "Result"]
            values = [mean_loss, mIoU, Acc, Acc_class, FWIoU, RRecall, RNum, result]
            for title, value in zip(titles, values):
                self.writer.add_scalar('val/'+title, value, epoch)

            printline = ("Val, mean_loss: {:.4f}, mIoU: {:.4f}, "
                         "Acc: {:.4f}, Acc_class: {:.4f}, fwIoU: {:.4f}, "
                         "RRecall: {:.4f}, RNum: {:.1f}]").format(
                            *values[:-1])
            self.logger.info(printline)

        return result
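
A minimal driver sketch for the Trainer above, assuming the usual entry point: opt is the global config object used throughout these examples, and a save_checkpoint helper on the Saver (only shown for the last trainer in this collection) is available; both are assumptions to adapt to the actual project.

if __name__ == '__main__':
    # Hypothetical entry point; opt and saver.save_checkpoint are assumptions.
    trainer = Trainer(mode='train')
    for epoch in range(trainer.start_epoch, opt.epochs):
        trainer.train(epoch)
        result = trainer.validate(epoch)
        is_best = result > trainer.best_pred
        trainer.best_pred = max(trainer.best_pred, result)
        # 'epoch' + 1 so a resumed run continues at the next epoch,
        # matching start_epoch = checkpoint['epoch'] above.
        trainer.saver.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': trainer.model.state_dict(),
            'best_pred': trainer.best_pred,
        }, is_best)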
Example #2
    def __init__(self, mode):
        # Define Saver
        self.saver = Saver(opt, mode)
        self.logger = self.saver.logger

        # Visualize
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Dataset dataloader
        self.train_dataset, self.train_loader = make_data_loader(opt)
        self.nbatch_train = len(self.train_loader)
        self.val_dataset, self.val_loader = make_data_loader(opt, mode="val")
        self.nbatch_val = len(self.val_loader)

        # Model
        # Default sync_bn from the GPU count only when it was not set explicitly
        if opt.sync_bn is None:
            opt.sync_bn = len(opt.gpu_id) > 1
        # model = DeepLab(opt)
        # model = CSRNet()
        model = CRGNet(opt)
        model_info(model, self.logger)
        self.model = model.to(opt.device)

        # Loss
        if opt.use_balanced_weights:
            classes_weights_file = osp.join(opt.root_dir, 'train_classes_weights.npy')
            if os.path.isfile(classes_weights_file):
                weight = np.load(classes_weights_file)
            else:
                weight = calculate_weigths_labels(
                    self.train_loader, opt.root_dir)
            print(weight)
            opt.loss['weight'] = weight
        self.loss = build_loss(opt.loss)

        # Define Evaluator
        self.evaluator = Evaluator()  # use region to eval: class_num is 2

        # Resuming Checkpoint
        self.best_pred = 0.0
        self.start_epoch = 0
        if opt.resume:
            if os.path.isfile(opt.pre):
                print("=> loading checkpoint '{}'".format(opt.pre))
                checkpoint = torch.load(opt.pre)
                self.start_epoch = checkpoint['epoch']
                self.best_pred = checkpoint['best_pred']
                self.model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(opt.pre, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(opt.pre))

        if len(opt.gpu_id) > 1:
            print("Using multiple gpu")
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=opt.gpu_id)

        # Define Optimizer
        # train_params = [{'params': model.get_1x_lr_params(), 'lr': opt.lr},
        #                 {'params': model.get_10x_lr_params(), 'lr': opt.lr * 10}]
        # self.optimizer = torch.optim.SGD(train_params,
        #                                  momentum=opt.momentum,
        #                                  weight_decay=opt.decay)
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=opt.lr,
                                         momentum=opt.momentum,
                                         weight_decay=opt.decay)

        # Define lr scheduler
        # self.scheduler = LR_Scheduler(mode=opt.lr_scheduler,
        #                               base_lr=opt.lr,
        #                               num_epochs=opt.epochs,
        #                               iters_per_epoch=self.nbatch_train,
        #                               lr_step=140)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=[round(opt.epochs * x) for x in opt.steps],
            gamma=opt.gamma)

        # Time
        self.loss_hist = collections.deque(maxlen=500)
        self.timer = Timer(opt.epochs, self.nbatch_train, self.nbatch_val)
        self.step_time = collections.deque(maxlen=opt.print_freq)
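
The balanced-weights branch above either loads precomputed class weights or falls back to calculate_weigths_labels. A standalone sketch of how such weights are commonly estimated (inverse log frequency over the training labels); the project's real helper and its signature may differ.

import numpy as np

def estimate_class_weights(train_loader, num_classes=2, save_path=None):
    # Hypothetical helper, shown only to illustrate the idea behind
    # calculate_weigths_labels; not the project's actual implementation.
    counts = np.zeros(num_classes, dtype=np.float64)
    for sample in train_loader:
        labels = sample['label'].numpy()
        valid = (labels >= 0) & (labels < num_classes)
        counts += np.bincount(labels[valid].astype(np.int64),
                              minlength=num_classes)
    freq = counts / counts.sum()
    # Inverse log frequency keeps the weight of very rare classes bounded.
    weights = 1.0 / np.log(1.02 + freq)
    if save_path is not None:
        np.save(save_path, weights)
    return weights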
Example #3
class Trainer(object):
    def __init__(self, mode):
        # Define Saver
        self.saver = Saver(opt, mode)
        self.logger = self.saver.logger

        # visualize
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Dataset dataloader
        self.train_dataset, self.train_loader = make_data_loader(opt)
        self.nbatch_train = len(self.train_loader)
        self.val_dataset, self.val_loader = make_data_loader(opt, mode="val")
        self.nbatch_val = len(self.val_loader)

        # model
        # Default sync_bn from the GPU count only when it was not set explicitly
        if opt.sync_bn is None:
            opt.sync_bn = len(opt.gpu_id) > 1
        model = CRG2Net(opt)
        self.model = model.to(opt.device)

        # Loss
        if opt.use_balanced_weights:
            classes_weights_file = osp.join(opt.root_dir, 'train_classes_weights.npy')
            if os.path.isfile(classes_weights_file):
                weight = np.load(classes_weights_file)
            else:
                weight = calculate_weigths_labels(
                    self.train_loader, opt.root_dir)
            print(weight)
            opt.loss_region['weight'] = weight
        self.loss_region = build_loss(opt.loss_region)
        self.loss_density = build_loss(opt.loss_density)

        # Define Evaluator
        self.evaluator = Evaluator(dataset=opt.dataset)  # use region to eval: class_num is 2

        # Resuming Checkpoint
        self.best_pred = 0.0
        self.start_epoch = 0
        if opt.resume:
            if os.path.isfile(opt.pre):
                print("=> loading checkpoint '{}'".format(opt.pre))
                checkpoint = torch.load(opt.pre)
                self.start_epoch = checkpoint['epoch']
                self.best_pred = checkpoint['best_pred']
                self.model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(opt.pre, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(opt.pre))

        if len(opt.gpu_id) > 1:
            self.logger.info("Using multiple gpu")
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=opt.gpu_id)

        # Define Optimizer and Lr Scheduler
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=opt.lr,
                                         momentum=opt.momentum,
                                         weight_decay=opt.decay)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=[round(opt.epochs * x) for x in opt.steps],
            gamma=opt.gamma)

        # Time
        self.loss_hist = collections.deque(maxlen=500)
        self.timer = Timer(opt.epochs, self.nbatch_train, self.nbatch_val)
        self.step_time = collections.deque(maxlen=opt.print_freq)

    def train(self, epoch):
        self.model.train()
        if opt.freeze_bn:
            if len(opt.gpu_id) > 1:
                self.model.module.freeze_bn()
            else:
                self.model.freeze_bn()
        last_time = time.time()
        epoch_loss = []
        for iter_num, sample in enumerate(self.train_loader):
            # if iter_num >= 0: break
            try:
                imgs = sample["image"].to(opt.device)
                density_gt = sample["label"].to(opt.device)
                region_gt = (sample["label"] > 0).float().to(opt.device)

                region_pred, density_pred = self.model(imgs)

                region_loss = self.loss_region(region_pred, region_gt)
                density_loss = self.loss_density(density_pred, density_gt)
                loss = region_loss + density_loss
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
                self.loss_hist.append(float(loss))
                epoch_loss.append(float(loss.cpu().item()))

                self.optimizer.step()
                self.optimizer.zero_grad()
                # self.scheduler(self.optimizer, iter_num, epoch)

                # Visualize
                global_step = iter_num + self.nbatch_train * epoch + 1
                self.writer.add_scalar('train/loss', loss.cpu().item(), global_step) 
                batch_time = time.time() - last_time
                last_time = time.time()
                eta = self.timer.eta(global_step, batch_time)
                self.step_time.append(batch_time)
                if global_step % opt.print_freq == 0:
                    printline = ('Epoch: [{}][{}/{}] '
                                 'lr: {:1.5f}, '  # 10x:{:1.5f}), '
                                 'eta: {}, time: {:1.1f}, '
                                 'region loss: {:1.4f}, '
                                 'density loss: {:1.4f}, '
                                 'loss: {:1.4f}').format(
                                    epoch, iter_num+1, self.nbatch_train,
                                    self.optimizer.param_groups[0]['lr'],
                                    # self.optimizer.param_groups[1]['lr'],
                                    eta, np.sum(self.step_time),
                                    region_loss, density_loss,
                                    np.mean(self.loss_hist))
                    self.logger.info(printline)

                del loss, region_loss, density_loss

            except Exception as e:
                print(e)
                continue

        self.scheduler.step()

    def validate(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        SMAE = 0
        with torch.no_grad():
            tbar = tqdm(self.val_loader, desc='\r')
            for i, sample in enumerate(tbar):
                # if i > 3: break
                imgs = sample['image'].to(opt.device)
                density_gt = sample["label"].to(opt.device)
                region_gt = (sample["label"] > 0).float()
                path = sample["path"]

                region_pred, density_pred = self.model(imgs)

                # Visualize
                global_step = i + self.nbatch_val * epoch + 1
                if global_step % opt.plot_every == 0:
                    # pred = output.data.cpu().numpy()
                    pred = torch.argmax(region_pred, dim=1)
                    self.summary.visualize_image(self.writer,
                                                 opt.dataset,
                                                 imgs,
                                                 density_gt,
                                                 pred,
                                                 global_step)

                # metrics
                target = region_gt.numpy()
                pred = region_pred.data.cpu().numpy()
                pred = np.argmax(pred, axis=1).reshape(target.shape)
                self.evaluator.add_batch(target, pred, path)
                density_pred = density_pred.clamp(min=0.00018) * region_pred.argmax(1, keepdim=True)
                SMAE += (density_gt.sum() - density_pred.sum()).abs().item()

            # Fast test during the training
            MAE = SMAE / (len(self.val_dataset) * opt.norm_cfg['para'])
            Acc = self.evaluator.Pixel_Accuracy()
            Acc_class = self.evaluator.Pixel_Accuracy_Class()
            mIoU = self.evaluator.Mean_Intersection_over_Union()
            FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            RRecall = self.evaluator.Region_Recall()
            RNum = self.evaluator.Region_Num()
            result = 2 / (1 / mIoU + 1 / RRecall)
            titles = ["mIoU", "MAE", "Acc", "Acc_class", "fwIoU", "RRecall", "RNum", "Result"]
            values = [mIoU, MAE, Acc, Acc_class, FWIoU, RRecall, RNum, result]
            for title, value in zip(titles, values):
                self.writer.add_scalar('val/'+title, value, epoch)

            printline = ("Val: mIoU: {:.4f}, MAE: {:.4f}, "
                         "Acc: {:.4f}, Acc_class: {:.4f}, fwIoU: {:.4f}, "
                         "RRecall: {:.4f}, RNum: {:.1f}, Result: {:.4f}]").format(
                            *values)
            self.logger.info(printline)

        return result
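
In the validation above, the absolute count error is accumulated on normalized density maps and divided by opt.norm_cfg['para'] to recover an MAE in raw object counts. A standalone sketch of that metric, assuming norm_para is the factor the ground-truth densities were scaled by during preprocessing (an assumption about this project's convention).

def count_mae(density_preds, density_gts, norm_para):
    # density_preds / density_gts: iterables of per-image density tensors in
    # the normalized scale; norm_para: assumed preprocessing scale factor.
    total_abs_err = 0.0
    n = 0
    for pred, gt in zip(density_preds, density_gts):
        total_abs_err += (gt.sum() - pred.clamp(min=0).sum()).abs().item()
        n += 1
    return total_abs_err / (n * norm_para)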
Example #4
parser.add_argument('--validate', action='store_true',
                    help='validate')

# checkpoint
parser.add_argument('--resume', type=str, default=None,
                    help='put the path to resuming file if needed')
parser.add_argument('--checkname', type=str, default='VesselNN_Unsupervised',
                    help='set the checkpoint name')
args = parser.parse_args()

# Define Saver
saver = Saver(args)
saver.save_experiment_config()

# Define Tensorboard Summary
summary = TensorboardSummary(saver.experiment_dir)
writer = summary.create_summary()

# Data
dataset = Directory_Image_Train(images_path=args.train_images_path,
                                labels_path=args.train_labels_path,
                                data_shape=(32, 128, 128),
                                lables_shape=(32, 128, 128),
                                range_norm=args.range_norm)
dataloader = DataLoader(dataset, batch_size=torch.cuda.device_count() * args.batch_size, shuffle=True, num_workers=2)

# Data - validation
dataset_val = Single_Image_Eval(image_path=args.val_image_path,
                                label_path=args.val_label_path,
                                data_shape=(32, 128, 128),
                                lables_shape=(32, 128, 128),
Example #5
parser.add_argument('--resume',
                    type=str,
                    default=None,
                    help='put the path to resuming file if needed')
parser.add_argument('--checkname',
                    type=str,
                    default='VesselNN_supervised',
                    help='set the checkpoint name')
args = parser.parse_args()

# Define Saver
saver = Saver(args)
saver.save_experiment_config()

# Define Tensorboard Summary
summary = TensorboardSummary(saver.experiment_dir)
writer = summary.create_summary()

# Data
dataset = Directory_Image_Train(images_path=args.train_images_path,
                                labels_path=args.train_labels_path,
                                max_iter=20000,
                                range_norm=args.range_norm,
                                data_shape=(5, 85, 85),
                                lables_shape=(1, 1, 1))
dataloader = DataLoader(dataset,
                        batch_size=torch.cuda.device_count() * args.batch_size,
                        shuffle=True,
                        num_workers=2)

# Data - validation
Example #6
    def __init__(self, mode):
        # Define Saver
        self.saver = Saver(opt, mode)
        self.logger = self.saver.logger

        # visualize
        self.summary = TensorboardSummary(self.saver.experiment_dir, opt)
        self.writer = self.summary.writer

        # Define Dataloader
        # train dataset
        self.train_dataset, self.train_loader = make_data_loader(opt, train=True)
        self.nbatch_train = len(self.train_loader)
        self.num_classes = self.train_dataset.num_classes

        # val dataset
        self.val_dataset, self.val_loader = make_data_loader(opt, train=False)
        self.nbatch_val = len(self.val_loader)

        # Define Network
        # initialize the network here.
        self.model = Model(opt, self.num_classes)
        self.model = self.model.to(opt.device)

        # Detection post-processing (NMS, ...)
        self.post_pro = PostProcess(**opt.nms)

        # Define Optimizer
        if opt.adam:
            self.optimizer = optim.Adam(self.model.parameters(), lr=opt.lr)
        else:
            self.optimizer = optim.SGD(self.model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.decay)

        # Apex
        if opt.use_apex:
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1')

        # Resuming Checkpoint
        self.best_pred = 0.0
        self.start_epoch = 0
        if opt.resume:
            if os.path.isfile(opt.pre):
                print("=> loading checkpoint '{}'".format(opt.pre))
                checkpoint = torch.load(opt.pre)
                self.start_epoch = checkpoint['epoch'] + 1
                self.best_pred = checkpoint['best_pred']
                self.model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(opt.pre, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(opt.pre))

        # Define lr scheduler
        # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        #     self.optimizer, patience=3, verbose=True)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=[round(opt.epochs * x) for x in opt.steps],
            gamma=opt.gamma)
        self.scheduler.last_epoch = self.start_epoch - 1

        # Using multiple GPUs
        if len(opt.gpu_id) > 1:
            self.logger.info("Using multiple gpu")
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=opt.gpu_id)

        # metrics
        if opt.eval_type == 'cocoeval':
            self.eval = COCO_eval(self.val_dataset.coco)
        else:
            self.eval = VOC_eval(self.num_classes)

        self.loss_hist = collections.deque(maxlen=500)
        self.timer = Timer(opt.epochs, self.nbatch_train, self.nbatch_val)
        self.step_time = collections.deque(maxlen=opt.print_freq)
Example #7
class Trainer(object):
    def __init__(self, mode):
        # Define Saver
        self.saver = Saver(opt, mode)
        self.logger = self.saver.logger

        # visualize
        self.summary = TensorboardSummary(self.saver.experiment_dir, opt)
        self.writer = self.summary.writer

        # Define Dataloader
        # train dataset
        self.train_dataset, self.train_loader = make_data_loader(opt, train=True)
        self.nbatch_train = len(self.train_loader)
        self.num_classes = self.train_dataset.num_classes

        # val dataset
        self.val_dataset, self.val_loader = make_data_loader(opt, train=False)
        self.nbatch_val = len(self.val_loader)

        # Define Network
        # initialize the network here.
        self.model = Model(opt, self.num_classes)
        self.model = self.model.to(opt.device)

        # Detection post-processing (NMS, ...)
        self.post_pro = PostProcess(**opt.nms)

        # Define Optimizer
        if opt.adam:
            self.optimizer = optim.Adam(self.model.parameters(), lr=opt.lr)
        else:
            self.optimizer = optim.SGD(self.model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.decay)

        # Apex
        if opt.use_apex:
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1')

        # Resuming Checkpoint
        self.best_pred = 0.0
        self.start_epoch = 0
        if opt.resume:
            if os.path.isfile(opt.pre):
                print("=> loading checkpoint '{}'".format(opt.pre))
                checkpoint = torch.load(opt.pre)
                self.start_epoch = checkpoint['epoch'] + 1
                self.best_pred = checkpoint['best_pred']
                self.model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(opt.pre, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(opt.pre))

        # Define lr scheduler
        # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        #     self.optimizer, patience=3, verbose=True)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=[round(opt.epochs * x) for x in opt.steps],
            gamma=opt.gamma)
        self.scheduler.last_epoch = self.start_epoch - 1

        # Using multiple GPUs
        if len(opt.gpu_id) > 1:
            self.logger.info("Using multiple gpu")
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=opt.gpu_id)

        # metrics
        if opt.eval_type == 'cocoeval':
            self.eval = COCO_eval(self.val_dataset.coco)
        else:
            self.eval = VOC_eval(self.num_classes)

        self.loss_hist = collections.deque(maxlen=500)
        self.timer = Timer(opt.epochs, self.nbatch_train, self.nbatch_val)
        self.step_time = collections.deque(maxlen=opt.print_freq)

    def training(self, epoch):
        self.model.train()
        epoch_loss = []
        last_time = time.time()
        for iter_num, data in enumerate(self.train_loader):
            # if iter_num >= 0: break
            try:
                self.optimizer.zero_grad()
                inputs = data['img'].to(opt.device)
                targets = data['annot'].to(opt.device)

                losses = self.model(inputs, targets)
                loss, log_vars = parse_losses(losses)

                if bool(loss == 0):
                    continue
                if opt.use_apex:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), opt.grad_clip)
                self.optimizer.step()
                self.loss_hist.append(float(loss.cpu().item()))
                epoch_loss.append(float(loss.cpu().item()))

                # visualize
                global_step = iter_num + self.nbatch_train * epoch + 1
                loss_logs = ""
                for _key, _value in log_vars.items():
                    loss_logs += "{}: {:.4f}  ".format(_key, _value)
                    self.writer.add_scalar('train/{}'.format(_key),
                                           _value,
                                           global_step)

                batch_time = time.time() - last_time
                last_time = time.time()
                eta = self.timer.eta(global_step, batch_time)
                self.step_time.append(batch_time)
                if global_step % opt.print_freq == 0:
                    printline = ("Epoch: [{}][{}/{}]  "
                                 "lr: {}  eta: {}  time: {:1.1f}  "
                                 "{}"
                                 "Running loss: {:1.5f}").format(
                                    epoch, iter_num + 1, self.nbatch_train,
                                    self.optimizer.param_groups[0]['lr'],
                                    eta, np.sum(self.step_time),
                                    loss_logs,
                                    np.mean(self.loss_hist))
                    self.logger.info(printline)

            except Exception as e:
                print(e)
                continue

        # self.scheduler.step(np.mean(epoch_loss))
        self.scheduler.step()

    def validate(self, epoch):
        self.model.eval()
        # torch.backends.cudnn.benchmark = False
        # self.model.apply(uninplace_relu)
        # start collecting results
        with torch.no_grad():
            for ii, data in enumerate(self.val_loader):
                # if ii > 0: break
                scale = data['scale']
                index = data['index']
                inputs = data['img'].to(opt.device)
                targets = data['annot']

                # run network
                scores, labels, boxes = self.model(inputs)

                scores_bt, labels_bt, boxes_bt = self.post_pro(
                    scores, labels, boxes, inputs.shape[-2:])

                outputs = []
                for k in range(len(boxes_bt)):
                    outputs.append(torch.cat((
                        boxes_bt[k].clone(),
                        labels_bt[k].clone().unsqueeze(1).float(),
                        scores_bt[k].clone().unsqueeze(1)),
                        dim=1))

                # visualize
                global_step = ii + self.nbatch_val * epoch
                if global_step % opt.plot_every == 0:
                    self.summary.visualize_image(
                        inputs, targets, outputs,
                        self.val_dataset.labels,
                        global_step)

                # eval
                if opt.eval_type == "voceval":
                    self.eval.statistics(outputs, targets, iou_thresh=0.5)

                elif opt.eval_type == "cocoeval":
                    self.eval.statistics(outputs, scale, index)

                print('{}/{}'.format(ii, len(self.val_loader)), end='\r')

            if opt.eval_type == "voceval":
                stats, ap_class = self.eval.metric()
                for key, value in stats.items():
                    self.writer.add_scalar('val/{}'.format(key), value.mean(), epoch)
                self.saver.save_voc_eval_result(stats, ap_class, self.val_dataset.labels)
                return stats['AP']

            elif opt.eval_type == "cocoeval":
                stats = self.eval.metirc()
                self.saver.save_coco_eval_result(stats)
                self.writer.add_scalar('val/mAP', stats[0], epoch)
                return stats[0]

            else:
                raise NotImplementedError
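
parse_losses, used in training() above, is not shown in these examples. A minimal sketch following the common convention of summing every entry whose key contains 'loss' and returning per-term floats for logging; the project's actual helper may differ.

from collections import OrderedDict
import torch

def parse_losses(losses):
    # Hypothetical reimplementation for illustration only.
    log_vars = OrderedDict()
    for name, value in losses.items():
        if isinstance(value, torch.Tensor):
            log_vars[name] = value.mean()
        elif isinstance(value, (list, tuple)):
            log_vars[name] = sum(v.mean() for v in value)
        else:
            raise TypeError('{} is not a tensor or a list of tensors'.format(name))
    # Total loss to backprop: sum of every term whose name contains 'loss'.
    loss = sum(v for k, v in log_vars.items() if 'loss' in k)
    log_vars['loss'] = loss
    # Detach to plain floats for printing / TensorBoard.
    log_vars = OrderedDict((k, float(v.item())) for k, v in log_vars.items())
    return loss, log_vars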
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(args.logdir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        dltrain = DLDataset('trainval', "./data/pascal_voc_seg/tfrecord/")
        dlval = DLDataset('val', "./data/pascal_voc_seg/tfrecord/")
        # dltrain = DLDataset('trainval', "./data/pascal_voc_seg/VOCdevkit/VOC2012/")
        # dlval = DLDataset('val', "./data/pascal_voc_seg/VOCdevkit/VOC2012/")
        self.train_loader = DataLoader(dltrain,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.workers,
                                       pin_memory=True)
        self.val_loader = DataLoader(dlval,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True)

        # Define network
        model = Deeplab()

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = nn.CrossEntropyLoss(ignore_index=255).cuda()
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(21)
        # Define lr scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer)

        # Using cuda
        # if args.cuda:
        # self.model = torch.nn.DataParallel(self.model)
        self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # DataParallel is commented out above, so load into the bare model
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, (image, target) in enumerate(tbar):
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            self.optimizer.zero_grad()

            output = self.model(image)
            loss = self.criterion(output, target.long())
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            # Visualize inference results every 10 iterations
            # if i % (num_img_tr // 10) == 0:
            if i % 10 == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset,
                                             image, target, output,
                                             global_step)
        self.scheduler.step(train_loss)
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, (image, target) in enumerate(tbar):

            if self.args.cuda:
                image, target = image.cuda(), target.cuda()

            with torch.no_grad():
                output = self.model(image)

            loss = self.criterion(output, target.long())
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
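
A minimal sketch of the usual entry point for this trainer; args.epochs is assumed alongside the start_epoch and no_val fields already referenced above, and the attribute names may differ in the actual project.

if __name__ == '__main__':
    # Hypothetical entry point; args is the parsed namespace from above.
    trainer = Trainer(args)
    start_epoch = getattr(trainer.args, 'start_epoch', 0)
    for epoch in range(start_epoch, trainer.args.epochs):
        trainer.training(epoch)
        if not trainer.args.no_val:
            trainer.validation(epoch)
    trainer.writer.close()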