def __init__(self, faster_rcnn):
        """Wrap a Faster R-CNN model with losses, target creators and meters."""
        super(FasterRCNNTrainer, self).__init__()

        self.faster_rcnn = faster_rcnn
        # Sigma hyper-parameters consumed by _faster_rcnn_loc_loss for the
        # smooth-L1 localization losses of the RPN and RoI head respectively.
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma

        # Picks 256 anchors (at most 128 positives) out of the tens of
        # thousands of candidates to train the RPN.
        self.anchor_target_creator = AnchorTargetCreator()
        # Picks 128 RoIs (at most 32 positives) out of the ~2000 RPN
        # proposals to train the RoI head.
        self.proposal_target_creator = ProposalTargetCreator()

        # Normalization constants for the regression targets, taken from the model.
        self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = faster_rcnn.loc_normalize_std

        self.optimizer = self.faster_rcnn.get_optimizer()
        # visdom-based visualization helper
        self.vis = Visualizer(env=opt.env)

        # Confusion matrices comparing predictions against ground truth; the
        # constructor argument is the number of classes: 2 for the RPN
        # (object / background), class_num + 1 for the RoI head (+ background).
        self.rpn_cm = ConfusionMeter(2)
        self.roi_cm = ConfusionMeter(opt.class_num + 1)
        # Running average of every loss in LossTuple.
        self.meters = {name: AverageValueMeter() for name in LossTuple._fields}
# --- Beispiel #2 ---
    def __init__(self, faster_rcnn):
        """Set up the trainer around a Faster R-CNN network."""
        # initialize the parent module first
        super(FasterRCNNTrainer, self).__init__()

        self.faster_rcnn = faster_rcnn
        # sigma hyper-parameters for the smooth-L1 localization losses
        # (consumed by _faster_rcnn_loc_loss)
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma

        # Target creators build gt_bbox / gt_label etc. as training targets.
        # AnchorTargetCreator selects 256 of ~20000 candidate anchors and
        # supplies the ground truth for the RPN's binary classification and
        # location regression.
        self.anchor_target_creator = AnchorTargetCreator()
        # AnchorTargetCreator and ProposalTargetCreator exist only to build
        # training targets; ProposalCreator (inside the RPN) generates RoIs
        # for Fast R-CNN at both train and test time, so at test time ~300
        # RoIs are fed straight in without this extra intervention.
        self.proposal_target_creator = ProposalTargetCreator()

        # location-target normalization constants, e.g. (0., 0., 0., 0.)
        # and (0.1, 0.1, 0.2, 0.2)
        self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = faster_rcnn.loc_normalize_std

        # optimizer built by the wrapped network (SGD)
        self.optimizer = self.faster_rcnn.get_optimizer()
        # visdom wrapper for visualization
        self.vis = Visualizer(env=opt.env)

        # Training-status indicators: confusion matrices (2 = object vs
        # background for the RPN, 21 classes for the RoI head) and per-loss
        # running averages.
        self.rpn_cm = ConfusionMeter(2)
        self.roi_cm = ConfusionMeter(21)
        self.meters = {name: AverageValueMeter() for name in LossTuple._fields}
    def __init__(self, faster_rcnn):
        """Initialize losses, target creators and bookkeeping for training."""
        # parent-module initialization
        super(FasterRCNNTrainer, self).__init__()

        self.faster_rcnn = faster_rcnn
        # Both sigmas parameterize the smooth-L1 localization loss computed
        # in _faster_rcnn_loc_loss.
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma

        # Training-only target generators:
        #  * AnchorTargetCreator picks 256 of ~20000 anchors and hands the
        #    RPN its ground-truth labels/offsets for classification and
        #    regression.
        #  * ProposalTargetCreator samples RoIs for the head. ProposalCreator
        #    (part of the RPN) is different: it emits RoIs at train and test
        #    time alike, so ~300 RoIs go straight in at test time.
        self.anchor_target_creator = AnchorTargetCreator()
        self.proposal_target_creator = ProposalTargetCreator()

        # regression-target normalization, e.g. mean (0., 0., 0., 0.) and
        # std (0.1, 0.1, 0.2, 0.2)
        self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = faster_rcnn.loc_normalize_std
        # SGD optimizer supplied by the wrapped network
        self.optimizer = self.faster_rcnn.get_optimizer()
        # visualization helper (see vis_tool.py)
        self.vis = Visualizer(env=opt.env)

        # Confusion matrices: the constructor argument is the class count —
        # 2 for the RPN (object vs background), 21 for the RoI head
        # (20 object classes + 1 background).
        self.rpn_cm = ConfusionMeter(2)
        self.roi_cm = ConfusionMeter(21)
        # running average of every loss in LossTuple
        self.meters = {name: AverageValueMeter() for name in LossTuple._fields}
    def __init__(self,
                 faster_rcnn,
                 attacker=None,
                 layer_idx=None,
                 attack_mode=False):
        """Trainer that can couple the detector with an attacker component."""
        super(BRFasterRcnnTrainer, self).__init__()

        self.faster_rcnn = faster_rcnn
        # attacker / target layer / mode flag used by the BR training loop
        # (their semantics are defined where they are consumed)
        self.attacker = attacker
        self.layer_idx = layer_idx
        self.attack_mode = attack_mode
        # smooth-L1 sigmas for the localization losses
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma

        # training-target generators for the RPN and RoI head
        self.anchor_target_creator = AnchorTargetCreator()
        self.proposal_target_creator = ProposalTargetCreator()

        # regression-target normalization constants from the model
        self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = faster_rcnn.loc_normalize_std

        self.optimizer = self.faster_rcnn.get_optimizer()

        # visdom wrapper
        self.vis = Visualizer(env=opt.env)

        # confusion matrices plus running averages for both loss tuples
        self.rpn_cm = ConfusionMeter(2)
        self.roi_cm = ConfusionMeter(21)
        self.meters = {name: AverageValueMeter() for name in LossTuple._fields}
        self.BR_meters = {name: AverageValueMeter() for name in LossTupleBR._fields}
    def __init__(self, faster_rcnn):
        """Build the trainer around a FasterRCNN (e.g. FasterRCNNVGG16) model."""
        super(FasterRCNNTrainer, self).__init__()
        # The argument is documented as a FasterRCNN instance; in practice a
        # FasterRCNNVGG16, which subclasses FasterRCNN.
        self.faster_rcnn = faster_rcnn
        # sigma hyper-parameters for the l1_smooth_loss
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma

        # Target creators produce gt_bbox / gt_label etc. as training
        # targets; the anchor variant assigns ground-truth boxes to anchors.
        self.anchor_target_creator = AnchorTargetCreator()
        self.proposal_target_creator = ProposalTargetCreator()

        # location-normalization mean / std taken from the network
        self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = faster_rcnn.loc_normalize_std

        # optimizer supplied by the network itself
        self.optimizer = self.faster_rcnn.get_optimizer()
        # visdom wrapper
        self.vis = Visualizer(env=opt.env)

        # Training-status indicators: a 2x2 confusion matrix for foreground /
        # background and a 21x21 one for 20 classes + background.
        self.rpn_cm = ConfusionMeter(2)
        self.roi_cm = ConfusionMeter(21)
        # running average of every loss in LossTuple
        self.meters = {name: AverageValueMeter() for name in LossTuple._fields}
# --- Beispiel #6 ---
    def __init__(self, faster_rcnn):
        """Trainer variant with extra penalty weights and tuned IoU sampling."""
        super(FasterRCNNTrainer, self).__init__()

        self.faster_rcnn = faster_rcnn
        # localization-loss sigmas plus additional penalty coefficients
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma
        self.rpn_pen = opt.rpn_pen
        self.roi_pen = opt.roi_pen

        # Target creators build gt_bbox / gt_label training targets.
        # FLAG: add params. Initial best here was pos 0.2, neg 0.1.
        self.anchor_target_creator = AnchorTargetCreator(pos_ratio=0.5,
                                                         pos_iou_thresh=0.7,
                                                         neg_iou_thresh=0.3)
        # Initial best here was pos 0.2, neg 0.2.
        self.proposal_target_creator = ProposalTargetCreator(pos_ratio=0.5,
                                                             pos_iou_thresh=0.5,
                                                             neg_iou_thresh_hi=0.5)

        # regression-target normalization constants from the model
        self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = faster_rcnn.loc_normalize_std

        self.optimizer = self.faster_rcnn.get_optimizer()
        # visdom wrapper
        self.vis = Visualizer(env=opt.env)

        # training-status indicators (note: 4-way RoI confusion matrix here)
        self.rpn_cm = ConfusionMeter(2)
        self.roi_cm = ConfusionMeter(4)
        # running average of every loss in LossTuple
        self.meters = {name: AverageValueMeter() for name in LossTuple._fields}
    def __init__(self, faster_rcnn, logger):
        """Trainer wrapper that also records progress through ``logger``."""
        super(FasterRCNNTrainer, self).__init__()

        self.logger = logger

        self.faster_rcnn = faster_rcnn
        # smooth-L1 sigmas for the RPN / RoI localization losses
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma
        # running aggregate of training losses
        self.losses = AverageVal()

        # target creators produce gt_bbox / gt_label etc. as training targets
        self.anchor_target_creator = AnchorTargetCreator()
        self.proposal_target_creator = ProposalTargetCreator()

        # regression-target normalization constants from the model
        self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = faster_rcnn.loc_normalize_std

        self.optimizer = self.faster_rcnn.get_optimizer()
        # visdom wrapper
        self.vis = Visualizer(env=opt.env)

        # indicators for training status
        self.rpn_cm = ConfusionMeter(2)
        self.roi_cm = ConfusionMeter(21)
        # per-loss running averages
        self.meters = {name: AverageValueMeter()
                       for name in LossTuple._fields}
# --- Beispiel #8 ---
    def __init__(self, faster_rcnn):
        """Set up loss hyper-parameters, target samplers and training meters."""
        super(FasterRCNNTrainer, self).__init__()

        self.faster_rcnn = faster_rcnn
        # hyper-parameters consumed by faster_rcnn_loc_loss when computing
        # the localization losses
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma

        # Target creators produce gt_bbox / gt_label etc. as training targets.
        # Samples 256 of ~20000 candidate anchors for RPN binary
        # classification / location-regression training.
        self.anchor_target_creator = AnchorTargetCreator()
        # Re-samples 128 RoIs out of the ~2000 filtered ones to train the
        # RoI head.
        self.proposal_target_creator = ProposalTargetCreator()
        # Mean / std for the location targets: positions fed into the network
        # for training are normalized with these.
        self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = faster_rcnn.loc_normalize_std

        self.optimizer = self.faster_rcnn.get_optimizer()
        # visdom wrapper
        self.vis = Visualizer(env=opt.env)

        # indicators for training status
        self.rpn_cm = ConfusionMeter(2)
        self.roi_cm = ConfusionMeter(21)
        # per-loss running averages
        self.meters = {name: AverageValueMeter()
                       for name in LossTuple._fields}
# --- Beispiel #9 ---
    def __init__(self, faster_rcnn):
        """Attach loss hyper-parameters, target creators and visualization."""
        super(FasterRCNNTrainer, self).__init__()

        self.faster_rcnn = faster_rcnn
        # smooth-L1 sigmas for the RPN / RoI localization losses
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma

        # creators that produce gt_bbox / gt_label etc. as training targets
        self.anchor_target_creator = AnchorTargetCreator()
        self.proposal_target_creator = ProposalTargetCreator()

        # normalization constants for the regression targets
        self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = faster_rcnn.loc_normalize_std

        self.optimizer = self.faster_rcnn.get_optimizer()
        # visdom wrapper
        self.vis = Visualizer(env=opt.env)
# --- Beispiel #10 ---
    def __init__(self, siam_reid):
        """Trainer wrapper around the siam_reid network."""
        super(Trainer, self).__init__()

        self.siam_reid = siam_reid
        # smooth-L1 sigmas for the localization losses
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma

        # anchor-target generator plus a proposal filter
        self.anchor_target_creator = AnchorTargetCreator()
        self.proposal_filter = ProposalFilter()
        # hard-coded regression-target normalization constants
        self.loc_normalize_mean = (0., 0., 0., 0.)
        self.loc_normalize_std = (0.1, 0.1, 0.2, 0.2)

        self.optimizer = self.siam_reid.get_optimizer()
        # visdom wrapper
        self.vis = Visualizer(env=opt.env)

        # confusion matrices are intentionally disabled here
        # self.rpn_cm = ConfusionMeter(2)
        # self.roi_cm = ConfusionMeter(2)
        # per-loss running averages
        self.meters = {name: AverageValueMeter()
                       for name in LossTuple._fields}
# --- Beispiel #11 ---
    def __init__(self, faster_rcnn):
        """Wrap the detector with training targets, optimizer and meters."""
        super(FasterRCNNTrainer, self).__init__()

        self.faster_rcnn = faster_rcnn
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma

        # Target creators produce gt_bbox / gt_label etc. as training
        # targets: the anchor one picks 256 of ~20000 candidate anchors and
        # supplies ground truth for the RPN's predicted classes/locations;
        # together with ProposalTargetCreator it exists purely to build
        # training targets.
        self.anchor_target_creator = AnchorTargetCreator()
        self.proposal_target_creator = ProposalTargetCreator()

        # regression-target normalization constants from the model
        self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = faster_rcnn.loc_normalize_std

        # optimizer built by the wrapped network (SGD)
        self.optimizer = self.faster_rcnn.get_optimizer()
        # visdom wrapper
        self.vis = Visualizer(env=opt.env)

        # Training-status indicators: multi-class confusion matrices
        # (2-way for the RPN, 21-way for the RoI head) plus loss averages.
        self.rpn_cm = ConfusionMeter(2)
        self.roi_cm = ConfusionMeter(21)
        self.meters = {name: AverageValueMeter()
                       for name in LossTuple._fields}
    def __init__(self, faster_rcnn):
        """Hold the network under training plus its loss and display helpers."""
        super(FasterRCNNTrainer, self).__init__()

        # the network being trained
        self.faster_rcnn = faster_rcnn
        # hyper-parameters for rpn_loc_loss / roi_loc_loss
        self.rpn_sigma = opt.rpn_sigma
        self.roi_sigma = opt.roi_sigma

        # used at the RPN: matches anchors against ground truth
        self.anchor_target_creator = AnchorTargetCreator()
        # used at the RoI head: samples positive/negative RoIs
        self.proposal_target_creator = ProposalTargetCreator()

        # mean / std used to normalize box coordinates
        self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
        self.loc_normalize_std = faster_rcnn.loc_normalize_std

        # optimizer instance obtained from the network
        self.optimizer = self.faster_rcnn.get_optimizer()
        # visualization helpers: visdom wrapper, confusion matrices, meters
        self.vis = Visualizer(env=opt.env)

        self.rpn_cm = ConfusionMeter(2)
        self.roi_cm = ConfusionMeter(21)
        # running average of every loss in LossTuple
        self.meters = {name: AverageValueMeter()
                       for name in LossTuple._fields}
# --- Beispiel #13 ---
def train(**kwargs):
    """Train the fast-neural-style TransformerNet.

    Keyword args are forwarded to ``opt._parse`` to override configuration.
    Optimizes a content loss (VGG16 relu2_2 MSE) plus a Gram-matrix style
    loss, plots progress to visdom, and saves a checkpoint after each epoch.
    """
    opt._parse(kwargs)

    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    vis = Visualizer(opt.env)

    # Data loading
    # NOTE(review): 'transfroms' is a typo of 'transforms' (local name only).
    transfroms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)  # rescale to the 0..255 range
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transfroms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # style transformer network
    transformer = TransformerNet()
    if opt.model_path:
        # identity map_location keeps tensors on their saved device
        transformer.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))
    transformer.to(device)

    # Vgg16 for Perceptual Loss — frozen feature extractor
    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    # Optimizer: use Adam
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Get style image
    style = utils.get_style_data(opt.style_path)
    # *0.225 + 0.45 roughly undoes the normalization for display purposes
    vis.img('style', (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    style = style.to(device)

    # print("style.shape: ", style.shape)

    # gram matrix for style image — computed once, used as the style target
    with t.no_grad():
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]

    # Loss meter
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            # Train
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
            # normalize both so VGG sees consistently-normalized inputs
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # content loss: MSE between relu2_2 activations
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # style loss: Gram-matrix MSE summed over all feature layers
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Loss smooth for visualization
            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())

            if (ii + 1) % opt.plot_every == 0:
                # drop into the debugger when the debug file exists
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # visualization
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                # denorm input/output, since we have applied (utils.normalize_batch)
                vis.img('output',
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                        max=1))

        # save checkpoint once per epoch (file name keyed by epoch number)
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
# --- Beispiel #14 ---
                    help='env name for visdom')
parser.add_argument('--save',
                    default='./work',
                    type=str,
                    help='directory for saving')
parser.add_argument('--log',
                    default='log_combine.txt',
                    type=str,
                    help='log_file')
args = parser.parse_args()
print(args)

logger = logging.getLogger('__name__')
logger.setLevel(logging.DEBUG)

vis = Visualizer(env=args.vis_env)
# ------------------------------------------------------------------------------------------------

gpu = [int(i) for i in args.gpu.split(',')]
torch.cuda.set_device(gpu[0])
gpu_num = len(gpu)
logger.info('there are %d gpus are used' % gpu_num)

filepath = 'dataset/MOT/dataset.json'
train_dataset = MOT(file=filepath,
                    train=True,
                    rangex=args.rangex,
                    max_num=args.sample_num)
val_dataset = MOT(file=filepath,
                  train=False,
                  rangex=args.rangex,
# --- Beispiel #15 ---
    model = UNet(input_channels=1, nclasses=1)
    if opt.is_train:
        # split all data to train and validation, set split = True
        train_loader, val_loader = get_train_val_loader(
            opt.root_dir,
            batch_size=opt.batch_size,
            val_ratio=0.15,
            shuffle=True,
            num_workers=4,
            pin_memory=False)

        optimizer = optim.Adam(model.parameters(),
                               lr=opt.learning_rate,
                               weight_decay=opt.weight_decay)
        criterion = nn.BCELoss()
        vis = Visualizer(env=opt.env)

        if opt.is_cuda:
            model.cuda()
            criterion.cuda()
            if opt.n_gpu > 1:
                model = nn.DataParallel(model)

        run(model, train_loader, val_loader, criterion, vis)
    else:
        if opt.is_cuda:
            model.cuda()
            if opt.n_gpu > 1:
                model = nn.DataParallel(model)
        test_loader = get_test_loader(batch_size=20,
                                      shuffle=True,
# --- Beispiel #16 ---
def main():
    """Train the weakly-supervised 3D segmentation model.

    Runs ``opt.train_epoch`` epochs; each epoch trains, saves a checkpoint,
    optionally validates (``opt.val_run``) and appends a metrics row to the
    log file.
    """
    # print the parameters:
    opt._parse()

    # data augmentation (joint transforms applied to image and label)
    train_aug = JointCompose([
        JointRandomFlip(),
        JointRandomRotation(),
        JointRandomGaussianNoise(8),
        JointRandomSubTractGaussianNoise(8),
        JointRandomBrightness([-0.3, 0.3]),
        JointRandomIntensityChange([-0.3, 0.3]),
    ])

    # load the dataloader
    prefix = 'WeaklySupervised'
    train_loader = GetDatasetLoader(opt,
                                    prefix,
                                    phase='train',
                                    augument=train_aug)
    val_loader = GetDatasetLoader(opt, prefix, phase='val')

    # get the network
    model = GetModel(opt)

    # training from the beginning or transfer learning.
    # if transfer learning: opt.load_state = True
    # if opt.load_state:
    #     ##### fine tuning ####
    #     ## replace your model
    #     parameters_name = 'XXXXXXXXXXX'
    #     model_CKPT = torch.load(parameters_name)
    #     model.load_state_dict(model_CKPT['state_dict'])

    optimizer = GetOptimizer(opt, model)
    scheduler = GetScheduler(opt, optimizer)

    # get the loss function
    criterion = DiceLossPlusCrossEntrophy()

    # inspect training process
    vis_tool = Visualizer(env='VoxResNet_3D')
    log = WriteLog(opt)
    log.write(
        'epoch |train_loss |train_dice |train_recall |valid_loss |valid dice |valid_recall |time          \n'
    )
    log.write('------------------------------------------------------------\n')

    # record the training value carried between epochs
    record_value = np.array([0, 0])

    # begin to train:
    for epoch_num in range(opt.train_epoch):
        start_time = time.time()
        # NOTE(review): stepping the scheduler before training preserves the
        # original schedule; newer PyTorch expects it after optimizer.step().
        scheduler.step()

        avg_loss, train_dice, train_recall, best_value = TrainOneEpoch(
            train_loader, model, optimizer, criterion, epoch_num, vis_tool,
            record_value)
        record_value = best_value

        # the information
        run_time = time.time() - start_time
        print('Train Epoch{} run time is:{:.3f}m and {:.3f}s'.format(
            epoch_num, run_time // 60, run_time % 60))
        print('Loss:{:.3f}  Recall:{:.3f}  Dice:{:.3f}'.format(
            avg_loss, train_recall, train_dice))

        # save the parameters
        save_name1 = (opt.save_parameters_name + '_epoch_{}').format(epoch_num)
        state = {
            'epoch': epoch_num + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        save_parameters(state, save_name1)

        # Bug fix: the log line below always references the validation
        # metrics, but originally they were only assigned when opt.val_run
        # was true, raising NameError otherwise. Default them to NaN each
        # epoch so the log stays well-formed without validation.
        val_avgloss = val_dice = val_recall = float('nan')

        # using val_run for validation
        if opt.val_run:
            val_avgloss, val_dice, val_recall = val_model(
                opt, val_loader, model, criterion, vis_tool)
            print('Test Loss:{:.3f}  Recall:{:.3f}  Dice:{:.3f}'.format(
                val_avgloss, val_recall, val_dice))

        log.write('%d |%0.3f |%0.3f |%0.3f |%0.3f |%0.3f |%0.3f |%0.3f \n' %
                  (epoch_num, avg_loss, train_dice, train_recall, val_avgloss,
                   val_dice, val_recall, run_time))
        log.write('\n')
# --- Beispiel #17 ---
def main():
    """Train VoxResNet on randomly-cropped samples and log per-epoch metrics.

    Optionally regenerates random crops first (``opt.random_sample_flag``),
    then trains for ``opt.train_epoch`` epochs, checkpointing every epoch and
    validating when ``opt.val_run`` is set.
    """
    ##########################################################
    # 1. generate random cropped samples from the selected images for training
    if opt.random_sample_flag == 1:
        Generate_Random_Samples(opt)

    # 2. load the parameters and print them
    opt._parse()

    # load the dataloader
    train_loader = GetDatasetLoader(opt, phase='train')
    val_loader = GetDatasetLoader(opt, phase='val')

    # get the network
    model = GetModel(opt)

    # if opt.load_state:
    #     ##### fine tune the result
    #     parameters_name = 'checkpoints/'+opt.dataset_prefix+'_epoch_100.ckpt'
    #     model_CKPT = torch.load(parameters_name)
    #     model.load_state_dict(model_CKPT['state_dict'])

    optimizer = GetOptimizer(opt, model)
    scheduler = GetScheduler(opt, optimizer)

    # get the loss function
    criterion = DiceLossPlusCrossEntrophy()

    # visualization and log file setup
    vis_tool = Visualizer(env='VoxResNet_3D')
    log_path = 'result'
    log_name = os.path.join(log_path, 'log_{}.txt'.format(opt.dataset_prefix))
    log = WriteLog(log_path, log_name, opt)
    log.write('selected dataset \n')
    log.write('weighted and dice: 0.5:1 \n')
    log.write(
        'epoch |train_loss |train_dice |train_recall |valid_loss |valid dice |valid_recall |time          \n'
    )
    log.write('------------------------------------------------------------\n')

    # begin to train:
    for epoch_num in range(opt.train_epoch):
        start_time = time.time()
        # NOTE(review): stepping the scheduler before training preserves the
        # original schedule; newer PyTorch expects it after optimizer.step().
        scheduler.step()

        avg_loss, train_dice, train_recall = TrainOneEpoch(
            train_loader, model, optimizer, criterion, epoch_num, vis_tool,
            opt.dataset_prefix)

        # the information
        run_time = time.time() - start_time
        print('Train Epoch{} run time is:{:.3f}m and {:.3f}s'.format(
            epoch_num, run_time // 60, run_time % 60))
        print('Loss:{:.3f}  Recall:{:.3f}  Dice:{:.3f}'.format(
            avg_loss, train_recall, train_dice))

        # save the parameters every epoch
        if epoch_num >= 0:
            save_name1 = (opt.dataset_prefix + '_epoch_{}').format(epoch_num)
            state = {
                'epoch': epoch_num + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }

            save_parameters(state, save_name1)

        # Bug fix: the log line below always references the validation
        # metrics, but originally they were only assigned when opt.val_run
        # was true, raising NameError otherwise. Default them to NaN each
        # epoch so the log stays well-formed without validation.
        val_avgloss = val_dice = val_recall = float('nan')

        # judge whether to test or not
        if opt.val_run:
            val_avgloss, val_dice, val_recall = val_model(
                opt, val_loader, model, criterion, vis_tool)
            print('Test Loss:{:.3f}  Recall:{:.3f}  Dice:{:.3f}'.format(
                val_avgloss, val_recall, val_dice))

        log.write('%d |%0.3f |%0.3f |%0.3f |%0.3f |%0.3f |%0.3f |%0.3f \n' %
                  (epoch_num, avg_loss, train_dice, train_recall, val_avgloss,
                   val_dice, val_recall, run_time))
        log.write('\n')
# --- Beispiel #18 ---
def train():
    """Train model F (optionally warm-started with G) on the denoising task.

    Plots losses and sample images to visdom every 20 batches, evaluates on
    the validation set each epoch (saving every 10th epoch), and after epoch
    100 keeps the checkpoint with the best validation SNR.
    """
    opt._parse()
    vis_tool = Visualizer(env=opt.env)

    print('load data')
    train_dataset = Datatrain(opt.rootpath, mode="train/")
    val_dataset = Dataval(opt.rootpath, mode="val/")

    trainer = Ftrainer(opt, image_size=opt.image_size)
    if opt.load_G:
        trainer.load_G(opt.load_G)
    print('model G construct completed')

    if opt.load_F:
        trainer.load_F(opt.load_F)
        print('model F construct completed')

    best_map = 0.0
    for epoch in range(opt.epoch):
        trainer.train()
        # dataloaders are rebuilt every epoch (equivalent to re-iterating)
        train_dataloader = data_.DataLoader(train_dataset,
                                            batch_size=opt.train_batch_size,
                                            shuffle=True,
                                            num_workers=opt.num_workers)
        val_dataloader = data_.DataLoader(val_dataset,
                                          batch_size=opt.test_batch_size,
                                          num_workers=opt.num_workers,
                                          shuffle=False)

        # test_model(test_dataloader, trainer, epoch, ifsave=True, test_num=opt.test_num)
        for ii, (img, oriimg, mask) in tqdm(enumerate(train_dataloader),
                                            total=len(train_dataloader)):
            loss, loss1, loss2 = trainer.train_onebatch(img, oriimg, mask)
            # every 20 batches: switch to eval mode, plot losses and images,
            # then switch back to train mode
            if ii % 20 == 0:
                trainer.eval()
                vis_tool.plot("totalloss", loss.detach().cpu().numpy())
                vis_tool.plot("loss_r", loss1.detach().cpu().numpy())
                vis_tool.plot("loss_t", loss2.detach().cpu().numpy())
                # forward pass on the first two samples for visualization
                snr, output, edg, edg2 = trainer(img[0:2, :, :, :],
                                                 oriimg[0:2, :, :, :],
                                                 mask[0:2, :, :, :],
                                                 vis=True)
                vis_tool.plot("snr_train", snr)
                # NOTE(review): 'input' shadows the builtin (local-only).
                input = img[0][0].numpy()
                input = (input * 255).astype(np.uint8)
                vis_tool.img("input", input)
                label = oriimg[0][0].numpy()
                label = (label * 255).astype(np.uint8)
                vis_tool.img("label", label)
                snr = round(snr, 2)
                vis_pic(output, snr, vis_tool)
                vis_tool.img("predict_segm", edg[0])
                vis_tool.img("ori_segm", edg2[0])
                trainer.train()

        # evaluate; checkpoints are written only every 10th epoch
        ifsave = False
        if (epoch + 1) % 10 == 0:
            ifsave = True
        eval_result = test_model(val_dataloader,
                                 trainer,
                                 epoch,
                                 ifsave=ifsave,
                                 test_num=opt.test_num)
        print('eval_loss: ', eval_result)

        vis_tool.plot("SNR_val", eval_result["SNR"])
        # keep the best model by validation SNR, but only after epoch 100
        if epoch > 100 and eval_result["SNR"] > best_map:
            best_map = eval_result["SNR"]
            best_path = trainer.save_F(best_map=best_map)
            print("save to %s !" % best_path)
# --- Beispiel #19 ---
    # log
    parser.add_argument('--env_name',
                        dest='env_name',
                        help='name of visdom environment',
                        default='HopeNet',
                        type=str)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    # os.environ['CUDA_VISIBLE_DEVICES'] =  args.gpu
    device = torch.device('cuda:{}'.format(args.gpu))
    # logger
    vis = Visualizer(env=args.env_name)
    # dataset
    trainset = graph_dataset(subsets=('zara01', 'eth', 'hotel', 'univ'))
    validset = graph_dataset(subsets=('zara02', ))
    train_dataloader = DataLoader(trainset,
                                  batch_size=args.bs,
                                  shuffle=True,
                                  collate_fn=trainset.collate_fn,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    valid_dataloader = DataLoader(validset,
                                  batch_size=args.bs,
                                  shuffle=False,
                                  collate_fn=validset.collate_fn,
                                  num_workers=args.num_workers,
                                  pin_memory=True)