def train(args):
    """Train ``Res_Deeplab`` on PASCAL VOC and keep the best checkpoint.

    Builds the ``train_aug`` / ``val`` data pipelines, creates a pretrained
    ``Res_Deeplab``, optionally resumes from ``args.resume``, then trains for
    ``args.n_epoch`` epochs, calling ``adjust_learning_rate`` (power 0.9)
    before every iteration. After each epoch the validation mean IoU is
    computed; whenever it is at least as good as the best so far, the model
    and optimizer state are saved to ``best_model.pth`` in the logger
    directory.

    Args:
        args: parsed command-line namespace. Reads ``gpu``, ``batch_size``,
            ``l_rate``, ``n_epoch`` and ``resume``.
    """

    logger.auto_set_dir()
    from pytorchgo.utils.pytorch_utils import set_gpu
    set_gpu(args.gpu)

    # Setup Dataloader
    from pytorchgo.augmentation.segmentation import SubtractMeans, PIL2NP, RGB2BGR, PIL_Scale, Value255to0, ToLabel
    from torchvision.transforms import Compose, ToTensor
    img_transform = Compose([  # notice the order!!!
        PIL_Scale(train_img_shape, Image.BILINEAR),
        PIL2NP(),
        RGB2BGR(),
        SubtractMeans(),
        ToTensor(),
    ])

    label_transform = Compose([
        PIL_Scale(train_img_shape, Image.NEAREST),
        PIL2NP(),
        Value255to0(),
        ToLabel()
    ])

    val_img_transform = Compose([
        PIL_Scale(train_img_shape, Image.BILINEAR),
        PIL2NP(),
        RGB2BGR(),
        SubtractMeans(),
        ToTensor(),
    ])
    # NOTE(review): unlike label_transform, this pipeline has no Value255to0(),
    # so label value 255 is not remapped for validation — presumably
    # intentional so the metric can treat it specially; confirm.
    val_label_transform = Compose([
        PIL_Scale(train_img_shape, Image.NEAREST),
        PIL2NP(),
        ToLabel(),
        # notice here, training, validation size difference, this is very tricky.
    ])

    from pytorchgo.dataloader.pascal_voc_loader import pascalVOCLoader as common_voc_loader
    train_loader = common_voc_loader(split="train_aug",
                                     epoch_scale=1,
                                     img_transform=img_transform,
                                     label_transform=label_transform)
    validation_loader = common_voc_loader(split='val',
                                          img_transform=val_img_transform,
                                          label_transform=val_label_transform)

    n_classes = train_loader.n_classes
    trainloader = data.DataLoader(train_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)

    valloader = data.DataLoader(validation_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup Model
    from pytorchgo.model.deeplab_resnet import Res_Deeplab

    model = Res_Deeplab(NoLabels=n_classes, pretrained=True, output_all=False)

    from pytorchgo.utils.pytorch_utils import model_summary, optimizer_summary
    model_summary(model)

    def get_validation_miou(model):
        """Return the validation mean IoU of ``model``, logging every score.

        Switches ``model`` to eval mode, accumulates predictions over
        ``valloader`` into ``running_metrics``, logs each score and resets
        the metric so the next call starts from a clean state.
        """
        model.eval()
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader),
                                                    total=len(valloader),
                                                    desc="validation"):
            # Debug modes stop after a few batches.
            if i_val > 5 and is_debug == 1:
                break
            if i_val > 200 and is_debug == 2:
                break

            output = model(Variable(images_val, volatile=True).cuda())
            # Arg-max over dim 1 -> predicted class per pixel
            # (assumes output is (N, C, H, W) — confirm for Res_Deeplab).
            pred = output.data.max(1)[1].cpu().numpy()
            gt = labels_val.numpy()

            running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            logger.info("{}: {}".format(k, v))
        running_metrics.reset()
        return score['Mean IoU : \t']

    model.cuda()

    # Prefer an optimizer supplied by the model itself; otherwise use SGD.
    if hasattr(model, 'optimizer'):
        # The model is not wrapped in DataParallel here, so there is no
        # ``model.module``; the optimizer is an attribute of the model itself.
        optimizer = model.optimizer
    else:
        logger.warn("don't have customzed optimizer, use default setting!")
        optimizer = torch.optim.SGD(model.optimizer_params(args.l_rate),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    optimizer_summary(optimizer)

    start_epoch = 0
    best_iou = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            # Checkpoints record the number of completed epochs and the best
            # mIoU. Restore both so the LR schedule continues where it
            # stopped and a worse epoch cannot overwrite best_model.pth.
            start_epoch = checkpoint['epoch']
            best_iou = checkpoint.get('mIoU', 0)
            logger.info("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("No checkpoint found at '{}'".format(args.resume))

    logger.info('start!!')
    iters_per_epoch = len(trainloader)
    max_iter = args.n_epoch * iters_per_epoch
    for epoch in tqdm(range(start_epoch, args.n_epoch),
                      total=max(args.n_epoch - start_epoch, 0)):
        model.train()
        for i, (images, labels) in tqdm(enumerate(trainloader),
                                        total=iters_per_epoch,
                                        desc="training epoch {}/{}".format(
                                            epoch, args.n_epoch)):
            # Debug modes stop after a few batches.
            if i > 10 and is_debug == 1:
                break
            if i > 200 and is_debug == 2:
                break

            # Global iteration index drives the per-iteration LR schedule.
            cur_iter = i + epoch * iters_per_epoch
            cur_lr = adjust_learning_rate(optimizer,
                                          args.l_rate,
                                          cur_iter,
                                          max_iter,
                                          power=0.9)

            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images)  # use fusion score
            loss = CrossEntropyLoss2d_Seg(input=outputs,
                                          target=labels,
                                          class_num=n_classes)

            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                logger.info(
                    "Epoch [%d/%d] Loss: %.4f, lr: %.7f, best mIoU: %.7f" %
                    (epoch + 1, args.n_epoch, loss.data[0], cur_lr, best_iou))

        cur_miou = get_validation_miou(model)
        if cur_miou >= best_iou:
            best_iou = cur_miou
            state = {
                'epoch': epoch + 1,
                'mIoU': best_iou,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state,
                       os.path.join(logger.get_logger_dir(), "best_model.pth"))
# Example No. 2
# 0
def train(args):
    """Train ``VGG16_LargeFoV`` on ``args.dataset`` and save the best checkpoint.

    Builds train/val loaders via ``get_loader`` / ``get_data_path`` (the
    training loader gets rotation/flip augmentation), creates the model from
    ``model_zoo.deeplabv1``, optionally resumes from ``args.resume``, and
    trains for ``args.n_epoch`` epochs. By default it uses SGD (momentum
    0.99, weight decay 5e-4) and calls ``adjust_learning_rate`` (power 0.9)
    before every iteration. After each epoch it computes the validation
    scores and writes ``best_model.pkl`` to the logger directory whenever
    the mean IoU is at least as good as the best so far.

    Args:
        args: parsed command-line namespace. Reads ``dataset``, ``img_rows``,
            ``img_cols``, ``img_norm``, ``batch_size``, ``l_rate``,
            ``n_epoch`` and ``resume``.

    NOTE(review): the tail of the epoch loop (starting at ``tgt_imgs = ...``)
    uses names this function never defines and cannot run as written.
    """

    logger.auto_set_dir()
    # NOTE(review): GPU index is hard-coded; only device 3 will be visible.
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'

    # Setup Augmentations
    # Applied only to the training loader below.
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    # NOTE(review): epoch_scale=4 presumably repeats the training set 4x per
    # "epoch" — confirm in the loader implementation.
    t_loader = data_loader(data_path,
                           is_transform=True,
                           img_size=(args.img_rows, args.img_cols),
                           epoch_scale=4,
                           augmentations=data_aug,
                           img_norm=args.img_norm)
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           img_size=(args.img_rows, args.img_cols),
                           img_norm=args.img_norm)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup Model
    from model_zoo.deeplabv1 import VGG16_LargeFoV
    # NOTE(review): image_size is given as [cols, rows] here, while the loaders
    # get (rows, cols) — confirm which order VGG16_LargeFoV expects.
    model = VGG16_LargeFoV(class_num=n_classes,
                           image_size=[args.img_cols, args.img_rows],
                           pretrained=True)

    # DataParallel wrapping is disabled, so `model` has no `.module` attribute.
    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Check if model has custom optimizer / loss
    # NOTE(review): this branch is taken when the model DOES have an optimizer,
    # yet it logs that it does not. It also reads `model.module.optimizer`, which
    # raises AttributeError because the model is not DataParallel-wrapped.
    if hasattr(model, 'optimizer'):
        logger.warn("don't have customzed optimizer, use default setting!")
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    # NOTE(review): optimizer_summary is not imported in this function; this
    # assumes a module-level import.
    optimizer_summary(optimizer)
    # Resume restores weights and optimizer state only; the epoch loop and LR
    # schedule below still start from epoch 0.
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            logger.info("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("No checkpoint found at '{}'".format(args.resume))

    # Start far below any (presumably non-negative) mean IoU so the first
    # validation result is always saved.
    best_iou = -100.0
    for epoch in tqdm(range(args.n_epoch), total=args.n_epoch):
        model.train()
        for i, (images, labels) in tqdm(enumerate(trainloader),
                                        total=len(trainloader),
                                        desc="training epoch {}/{}".format(
                                            epoch, args.n_epoch)):
            # Global iteration index drives the per-iteration LR schedule.
            cur_iter = i + epoch * len(trainloader)
            cur_lr = adjust_learning_rate(optimizer,
                                          args.l_rate,
                                          cur_iter,
                                          args.n_epoch * len(trainloader),
                                          power=0.9)
            #if i > 10:break

            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images)
            #print(np.unique(outputs.data[0].cpu().numpy()))
            loss = CrossEntropyLoss2d_Seg(input=outputs,
                                          target=labels,
                                          class_num=n_classes)

            loss.backward()
            optimizer.step()

            # NOTE(review): `loss.data[0]` and `volatile=True` below are
            # pre-0.4 PyTorch idioms.
            if (i + 1) % 100 == 0:
                logger.info("Epoch [%d/%d] Loss: %.4f, lr: %.7f" %
                            (epoch + 1, args.n_epoch, loss.data[0], cur_lr))

        # Validation pass: accumulate per-batch predictions into running_metrics.
        model.eval()
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader),
                                                    total=len(valloader),
                                                    desc="validation"):
            images_val = Variable(images_val.cuda(), volatile=True)
            labels_val = Variable(labels_val.cuda(), volatile=True)

            outputs = model(images_val)
            pred = outputs.data.max(1)[1].cpu().numpy()
            gt = labels_val.data.cpu().numpy()
            running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            logger.info("{}: {}".format(k, v))
        # Reset so the next epoch's validation starts from empty statistics.
        running_metrics.reset()

        if score['Mean IoU : \t'] >= best_iou:
            best_iou = score['Mean IoU : \t']
            state = {
                'epoch': epoch + 1,
                'mIoU': best_iou,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state,
                       os.path.join(logger.get_logger_dir(), "best_model.pkl"))
        # NOTE(review): from here on, the code appears to be a fragment of a
        # different (source/target, two-classifier) training loop pasted into
        # this epoch loop.
        #  - `target`, `model_g`, `model_f1`, `model_f2`, `optimizer_g`,
        #    `optimizer_f` and `criterion_d` are not defined in this function.
        #  - `src_imgs` / `src_lbls` are assigned on the `.cuda()` line below,
        #    which makes them locals. They are read before any assignment, so
        #    this raises UnboundLocalError (or NameError earlier, on `target`)
        #    on the first epoch, right after the checkpoint save.
        # This section most likely should be removed.
        tgt_imgs = Variable(target[0])

        if torch.cuda.is_available():
            src_imgs, src_lbls, tgt_imgs = src_imgs.cuda(), src_lbls.cuda(), tgt_imgs.cuda()

        # update generator and classifiers by source samples
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        loss = 0
        d_loss = 0
        outputs = model_g(src_imgs)

        outputs1 = model_f1(outputs)
        outputs2 = model_f2(outputs)

        c_loss = CrossEntropyLoss2d_Seg(outputs1, src_lbls, class_num=args.n_class)
        c_loss += CrossEntropyLoss2d_Seg(outputs2, src_lbls,  class_num=args.n_class)
        c_loss.backward(retain_graph=True)
        ####################
        lambd = 1.0
        model_f1.set_lambda(lambd)
        model_f2.set_lambda(lambd)
        outputs = model_g(tgt_imgs)
        outputs1 = model_f1(outputs, reverse=True)
        outputs2 = model_f2(outputs, reverse=True)
        loss = - criterion_d(outputs1, outputs2)
        loss.backward()
        optimizer_f.step()
        optimizer_g.step()

        d_loss = -loss.data[0]