示例#1
0
def train(args):
    """Train DeepLab-v1 (VGG16-LargeFoV) on ``args.dataset`` and keep the best model.

    Each epoch runs SGD over the training split, adjusting the learning rate
    every iteration via ``adjust_learning_rate(..., power=0.9)`` (presumably a
    polynomial decay over all ``n_epoch * len(trainloader)`` iterations). It
    then evaluates on the validation split with ``runningScore``. Whenever
    validation mean IoU is not lower than the best value seen so far, it saves
    model and optimizer state to ``<logger dir>/best_model.pkl``.

    Args:
        args: parsed command-line namespace. Fields used here: ``dataset``,
            ``img_rows``, ``img_cols``, ``img_norm``, ``batch_size``,
            ``l_rate``, ``resume``, ``n_epoch``.
    """
    logger.auto_set_dir()
    # NOTE(review): the visible GPU is hard-coded; consider a CLI argument.
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           img_size=(args.img_rows, args.img_cols),
                           epoch_scale=4,
                           augmentations=data_aug,
                           img_norm=args.img_norm)
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           img_size=(args.img_rows, args.img_cols),
                           img_norm=args.img_norm)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup Model
    from model_zoo.deeplabv1 import VGG16_LargeFoV
    model = VGG16_LargeFoV(class_num=n_classes,
                           image_size=[args.img_cols, args.img_rows],
                           pretrained=True)

    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Use the model's own optimizer if it defines one. Look through a
    # DataParallel wrapper (``.module``) when present, but fall back to the
    # bare model: it is not wrapped here, so ``model.module`` does not exist.
    base_model = getattr(model, 'module', model)
    if hasattr(base_model, 'optimizer'):
        logger.info("use the model's customized optimizer!")
        optimizer = base_model.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    optimizer_summary(optimizer)
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            logger.info("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("No checkpoint found at '{}'".format(args.resume))

    best_iou = -100.0
    for epoch in tqdm(range(args.n_epoch), total=args.n_epoch):
        model.train()
        for i, (images, labels) in tqdm(enumerate(trainloader),
                                        total=len(trainloader),
                                        desc="training epoch {}/{}".format(
                                            epoch, args.n_epoch)):
            # Global iteration index drives the per-iteration lr schedule.
            cur_iter = i + epoch * len(trainloader)
            cur_lr = adjust_learning_rate(optimizer,
                                          args.l_rate,
                                          cur_iter,
                                          args.n_epoch * len(trainloader),
                                          power=0.9)

            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images)
            loss = CrossEntropyLoss2d_Seg(input=outputs,
                                          target=labels,
                                          class_num=n_classes)

            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                logger.info("Epoch [%d/%d] Loss: %.4f, lr: %.7f" %
                            (epoch + 1, args.n_epoch, loss.data[0], cur_lr))

        # Validation: predicted class = argmax over dim 1 of the model output.
        model.eval()
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader),
                                                    total=len(valloader),
                                                    desc="validation"):
            images_val = Variable(images_val.cuda(), volatile=True)
            labels_val = Variable(labels_val.cuda(), volatile=True)

            outputs = model(images_val)
            pred = outputs.data.max(1)[1].cpu().numpy()
            gt = labels_val.data.cpu().numpy()
            running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            logger.info("{}: {}".format(k, v))
        running_metrics.reset()

        # Save on ties too (>=), so the latest equally-good model wins.
        if score['Mean IoU : \t'] >= best_iou:
            best_iou = score['Mean IoU : \t']
            state = {
                'epoch': epoch + 1,
                'mIoU': best_iou,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state,
                       os.path.join(logger.get_logger_dir(), "best_model.pkl"))
示例#2
0
def main():
    """Parse arguments, build loaders and models, and run ``MyTrainer_ROAD``.

    Source data is SYNTHIA (train split). Target data is CityScapes (train
    split), and validation uses CityScapes (val split). Besides the trainable
    segmentation ``model``, a second copy ``model_fix`` is built from the same
    pretrained weights and frozen. A ``Domain_classifer`` (``netD``) is trained
    with its own Adam optimizer. Relies on module-level constants
    (``max_epoch``, ``base_lr``, ``image_size``, ``class_num``, ``dis_lr``,
    ``Deeplabv2_restore_from``, ``LOSS_PRINT_INTERVAL``) and sets the global
    ``args``.

    Raises:
        ValueError: on an unsupported ``--model`` or ``--optimizer`` value
            (SGD is not supported yet).
    """
    logger.auto_set_dir()
    global args
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot',
                        default='/home/hutao/lab/pytorchgo/example/ROAD/data',
                        help='Path to source dataset')
    parser.add_argument('--batchSize',
                        type=int,
                        default=1,
                        help='input batch size')
    parser.add_argument('--max_epoch',
                        type=int,
                        default=max_epoch,
                        help='Number of training iterations')
    parser.add_argument('--optimizer',
                        type=str,
                        default='Adam',
                        help='Optimizer to use | SGD, Adam')
    parser.add_argument('--lr',
                        type=float,
                        default=base_lr,
                        help='learning rate')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.99,
                        help='Momentum for SGD')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.9,
                        help='beta1 for adam. default=0.9')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=0.0005,
                        help='Weight decay')
    parser.add_argument('--model', type=str, default='vgg16')
    parser.add_argument('--gpu', type=int, default=1)

    args = parser.parse_args()
    print(args)

    gpu = args.gpu

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    cuda = torch.cuda.is_available()
    torch.manual_seed(1337)
    if cuda:
        logger.info("random seed 1337")
        torch.cuda.manual_seed(1337)

    # Defining data loaders

    kwargs = {
        'num_workers': 4,
        'pin_memory': True,
        'drop_last': True
    } if cuda else {}
    train_loader = torch.utils.data.DataLoader(torchfcn.datasets.SYNTHIA(
        'SYNTHIA',
        args.dataroot,
        split='train',
        transform=True,
        image_size=image_size),
                                               batch_size=args.batchSize,
                                               shuffle=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(torchfcn.datasets.CityScapes(
        'cityscapes',
        args.dataroot,
        split='val',
        transform=True,
        image_size=image_size),
                                             batch_size=1,
                                             shuffle=False)

    target_loader = torch.utils.data.DataLoader(torchfcn.datasets.CityScapes(
        'cityscapes',
        args.dataroot,
        split='train',
        transform=True,
        image_size=image_size),
                                                batch_size=args.batchSize,
                                                shuffle=True)

    if cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    if args.model == "vgg16":
        model = origin_model = torchfcn.models.Seg_model(n_class=class_num)
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)

        model_fix = torchfcn.models.Seg_model(n_class=class_num)
        model_fix.copy_params_from_vgg16(vgg16)

    elif args.model == "deeplabv2":  # TODO may have problem!
        model = origin_model = torchfcn.models.Res_Deeplab(
            num_classes=class_num, image_size=image_size)
        saved_state_dict = model_zoo.load_url(Deeplabv2_restore_from)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
            # drop the leading prefix; skip the layer5 classifier weights
            # when class_num == 19.
            i_parts = i.split('.')
            if not class_num == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

        model_fix = torchfcn.models.Res_Deeplab(num_classes=class_num,
                                                image_size=image_size)
        model_fix.load_state_dict(new_params)
    else:
        raise ValueError("only support vgg16, deeplabv2!")

    # model_fix is a fixed reference copy that is never handed to an
    # optimizer, so freeze it for both backbones (previously only the vgg16
    # branch did this, leaving deeplabv2's copy tracking gradients).
    for param in model_fix.parameters():
        param.requires_grad = False

    netD = torchfcn.models.Domain_classifer(reverse=True)
    netD.apply(weights_init)

    model_summary([model, netD])

    if cuda:
        model = model.cuda()
        netD = netD.cuda()

    # Defining optimizer

    if args.optimizer == 'SGD':
        # TODO: SGD (with bias parameters at 2x lr) is not supported yet.
        raise ValueError("SGD is not prepared well..")
    elif args.optimizer == 'Adam':
        if args.model == "vgg16":
            # Bias parameters get twice the base learning rate.
            optim = torch.optim.Adam([
                {
                    'params': get_parameters(model, bias=False),
                    'weight_decay': args.weight_decay
                },
                {
                    'params': get_parameters(model, bias=True),
                    'lr': args.lr * 2,
                    'weight_decay': args.weight_decay
                },
            ],
                                     lr=args.lr,
                                     betas=(args.beta1, 0.999))
        elif args.model == "deeplabv2":
            optim = torch.optim.Adam(origin_model.optim_parameters(args.lr),
                                     lr=args.lr,
                                     betas=(args.beta1, 0.999),
                                     weight_decay=args.weight_decay)
        else:
            # args.model was already validated above. A bare ``raise`` here
            # (no active exception) would itself fail with RuntimeError.
            raise ValueError("only support vgg16, deeplabv2!")
    else:
        raise ValueError('Invalid optmizer argument. Has to be SGD or Adam')

    optimD = torch.optim.Adam(netD.parameters(),
                              lr=dis_lr,
                              weight_decay=args.weight_decay,
                              betas=(0.7, 0.999))

    optimizer_summary([optim, optimD])

    trainer = MyTrainer_ROAD(cuda=cuda,
                             model=model,
                             model_fix=model_fix,
                             netD=netD,
                             optimizer=optim,
                             optimizerD=optimD,
                             train_loader=train_loader,
                             target_loader=target_loader,
                             val_loader=val_loader,
                             batch_size=args.batchSize,
                             image_size=image_size,
                             loss_print_interval=LOSS_PRINT_INTERVAL)
    trainer.epoch = 0
    trainer.iteration = 0
    trainer.train()
示例#3
0
def train() -> None:
    """Train the few-shot encoder and metric networks for ``args.iterations`` steps.

    Reads all configuration from the module-level ``args``. Each iteration:
    - samples a task batch from ``generator.Generator``;
    - runs ``train_batch`` and steps both Adam optimizers;
    - updates their learning rates via ``adjust_learning_rate``.

    Periodically it also:
    - logs the running mean loss;
    - evaluates with ``test.test_one_shot``;
    - saves both networks (whole-module ``torch.save``) into the logger
      directory.

    A final ``test.test_one_shot`` call runs after the loop.
    """
    train_loader = generator.Generator(args.dataset_root,
                                       args,
                                       partition='train',
                                       dataset=args.dataset)
    logger.info('Batch size: ' + str(args.batch_size))

    #Try to load models
    enc_nn = models.load_model('enc_nn', args)
    metric_nn = models.load_model('metric_nn', args)

    # If either network could not be loaded (load_model returned None),
    # create both from scratch so they stay consistent with each other.
    if enc_nn is None or metric_nn is None:
        enc_nn, metric_nn = models.create_models(args=args)
    softmax_module = models.SoftmaxModule()

    if args.cuda:
        enc_nn.cuda()
        metric_nn.cuda()

    logger.info(str(enc_nn))
    logger.info(str(metric_nn))

    # Weight decay is only applied for mini_imagenet.
    weight_decay = 0
    if args.dataset == 'mini_imagenet':
        logger.info('Weight decay ' + str(1e-6))
        weight_decay = 1e-6
    opt_enc_nn = optim.Adam(enc_nn.parameters(),
                            lr=args.lr,
                            weight_decay=weight_decay)
    opt_metric_nn = optim.Adam(metric_nn.parameters(),
                               lr=args.lr,
                               weight_decay=weight_decay)

    model_summary([enc_nn, metric_nn])
    optimizer_summary([opt_enc_nn, opt_metric_nn])
    enc_nn.train()
    metric_nn.train()
    # counter / total_loss accumulate the loss between two log lines.
    counter = 0
    total_loss = 0
    # val_acc / test_acc track the test accuracy at the best validation
    # accuracy seen so far. NOTE(review): val_acc_aux is only updated for
    # mini_imagenet; for other datasets it stays 0, so test_acc is simply
    # overwritten at every evaluation.
    val_acc, val_acc_aux = 0, 0
    test_acc = 0
    for batch_idx in range(args.iterations):

        ####################
        # Train
        ####################
        data = train_loader.get_task_batch(
            batch_size=args.batch_size,
            n_way=args.train_N_way,
            unlabeled_extra=args.unlabeled_extra,
            num_shots=args.train_N_shots,
            cuda=args.cuda,
            variable=True)
        [
            batch_x, label_x, _, _, batches_xi, labels_yi, oracles_yi,
            hidden_labels
        ] = data

        opt_enc_nn.zero_grad()
        opt_metric_nn.zero_grad()

        # Presumably train_batch runs forward + backward and returns the loss;
        # the optimizer steps below rely on gradients it produced — confirm.
        loss_d_metric = train_batch(model=[enc_nn, metric_nn, softmax_module],
                                    data=[
                                        batch_x, label_x, batches_xi,
                                        labels_yi, oracles_yi, hidden_labels
                                    ])

        opt_enc_nn.step()
        opt_metric_nn.step()

        adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn],
                             lr=args.lr,
                             iter=batch_idx)

        ####################
        # Display
        ####################
        counter += 1
        total_loss += loss_d_metric.data[0]
        if batch_idx % args.log_interval == 0:
            display_str = 'Train Iter: {}'.format(batch_idx)
            display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss /
                                                            counter)
            logger.info(display_str)
            counter = 0
            total_loss = 0

        ####################
        # Test
        ####################
        # Evaluate every test_interval iterations, plus one quick early check
        # at iteration 20 with far fewer samples (100 instead of 3000).
        if (batch_idx + 1) % args.test_interval == 0 or batch_idx == 20:
            if batch_idx == 20:
                test_samples = 100
            else:
                test_samples = 3000
            if args.dataset == 'mini_imagenet':
                val_acc_aux = test.test_one_shot(
                    args,
                    model=[enc_nn, metric_nn, softmax_module],
                    test_samples=test_samples * 5,
                    partition='val')
            test_acc_aux = test.test_one_shot(
                args,
                model=[enc_nn, metric_nn, softmax_module],
                test_samples=test_samples * 5,
                partition='test')
            test.test_one_shot(args,
                               model=[enc_nn, metric_nn, softmax_module],
                               test_samples=test_samples,
                               partition='train')
            # Restore training mode after evaluation.
            enc_nn.train()
            metric_nn.train()

            if val_acc_aux is not None and val_acc_aux >= val_acc:
                test_acc = test_acc_aux
                val_acc = val_acc_aux

            if args.dataset == 'mini_imagenet':
                logger.info("Best test accuracy {:.4f} \n".format(test_acc))

        ####################
        # Save model
        ####################
        # Whole modules (not state_dicts) are saved, overwriting each time.
        if (batch_idx + 1) % args.save_interval == 0:
            logger.info("saving model...")
            torch.save(enc_nn,
                       os.path.join(logger.get_logger_dir(), 'enc_nn.t7'))
            torch.save(metric_nn,
                       os.path.join(logger.get_logger_dir(), 'metric_nn.t7'))

    # Test after training
    test.test_one_shot(args,
                       model=[enc_nn, metric_nn, softmax_module],
                       test_samples=args.test_samples)
def main():
    """Create the model and start the training.

    Adversarial domain-adaptation training for semantic segmentation. Each
    step it does three things:
    - trains the segmentation model with ``loss_calc`` on labelled source
      images (GTA5 or SYNTHIA, chosen by ``SOURCE_DATA``);
    - adds an adversarial BCE loss that pushes ``model_D2``'s output on
      target (Cityscapes) predictions towards the source label;
    - trains ``model_D2`` to separate source predictions (label 0) from
      target predictions (label 1).

    ``model_D1`` is built and stepped but never receives a forward pass in
    this loop, so the ``*_value1`` losses always log as 0. Every
    ``args.save_pred_every`` iterations the model is evaluated with
    ``proceed_test``. A checkpoint is written to the logger directory, and
    it is copied to ``model_best.pth.tar`` when mIoU improves.

    Raises:
        ValueError: on an unsupported ``args.model`` or ``SOURCE_DATA``.
    """
    import shutil

    # NOTE(review): the names say (h, w), but the tuples are later passed to
    # nn.Upsample as (input_size[1], input_size[0]), so the strings are
    # presumably given as "width,height" — confirm before relying on names.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True
    from pytorchgo.utils.pytorch_utils import set_gpu
    set_gpu(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        logger.info("adopting Deeplabv2 base model..")
        model = Res_Deeplab(num_classes=args.num_classes, multi_scale=False)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
            # drop the leading prefix; skip the layer5 classifier weights
            # when num_classes == 19.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

        optimizer = optim.SGD(model.optim_parameters(args),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.model == "FCN8S":
        logger.info("adopting FCN8S base model..")
        from pytorchgo.model.MyFCN8s import MyFCN8s
        model = MyFCN8s(n_class=NUM_CLASSES)
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    else:
        raise ValueError("unsupported model: {}".format(args.model))

    model.train()
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda()

    model_D2.train()
    model_D2.cuda()

    # The source loaders are sized (max_iters) to cover every step, so the
    # enumerate() iterators below are consumed manually with next().
    if SOURCE_DATA == "GTA5":
        trainloader = data.DataLoader(GTA5DataSet(
            args.data_dir,
            args.data_list,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size,
            scale=args.random_scale,
            mirror=args.random_mirror,
            mean=IMG_MEAN),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_iter = enumerate(trainloader)
    elif SOURCE_DATA == "SYNTHIA":
        trainloader = data.DataLoader(SynthiaDataSet(
            args.data_dir,
            args.data_list,
            LABEL_LIST_PATH,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size,
            scale=args.random_scale,
            mirror=args.random_mirror,
            mean=IMG_MEAN),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_iter = enumerate(trainloader)
    else:
        raise ValueError("unsupported SOURCE_DATA: {}".format(SOURCE_DATA))

    targetloader = data.DataLoader(cityscapesDataSet(
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    best_mIoU = 0

    model_summary([model, model_D1, model_D2])
    optimizer_summary([optimizer, optimizer_D1, optimizer_D2])

    for i_iter in tqdm(range(args.num_steps_stop),
                       total=args.num_steps_stop,
                       desc="training"):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        lr_D1 = adjust_learning_rate_D(optimizer_D1, i_iter)
        lr_D2 = adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over iter_size sub-batches (each loss is
        # divided by iter_size) before the optimizer steps below.
        for sub_i in range(args.iter_size):

            ######################### train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            # Builtin next() works on Python 2 and 3; ``.next()`` is Py2-only.
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda()

            pred2 = model(images)
            pred2 = interp(pred2)

            loss_seg2 = loss_calc(pred2, labels)
            loss = loss_seg2

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value2 += loss_seg2.data.cpu().numpy()[0] / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _, _ = batch
            images = Variable(images).cuda()

            pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            D_out2 = model_D2(F.softmax(pred_target2))

            # Fool D: target predictions should be classified as source.
            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda())

            loss = args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            )[0] / args.iter_size

            ################################## train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached so no gradient reaches the model)
            pred2 = pred2.detach()
            D_out2 = model_D2(F.softmax(pred2))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda())

            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D2.backward()

            loss_D_value2 += loss_D2.data.cpu().numpy()[0]

            # train with target
            pred_target2 = pred_target2.detach()

            D_out2 = model_D2(F.softmax(pred_target2))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda())

            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D2.backward()

            loss_D_value2 += loss_D2.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % 100 == 0:
            logger.info(
                'iter = {}/{},loss_seg1 = {:.3f} loss_seg2 = {:.3f} loss_adv1 = {:.3f}, loss_adv2 = {:.3f} loss_D1 = {:.3f} loss_D2 = {:.3f}, lr={:.7f}, lr_D={:.7f}, best miou16= {:.5f}'
                .format(i_iter, args.num_steps_stop, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2,
                        lr, lr_D1, best_mIoU))

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            logger.info("saving snapshot.....")
            cur_miou16 = proceed_test(model, input_size)
            is_best = cur_miou16 > best_mIoU
            if is_best:
                best_mIoU = cur_miou16
            torch.save(
                {
                    'iteration': i_iter,
                    'optim_state_dict': optimizer.state_dict(),
                    'optim_D1_state_dict': optimizer_D1.state_dict(),
                    'optim_D2_state_dict': optimizer_D2.state_dict(),
                    'model_state_dict': model.state_dict(),
                    'model_D1_state_dict': model_D1.state_dict(),
                    'model_D2_state_dict': model_D2.state_dict(),
                    'best_mean_iu': cur_miou16,
                }, osp.join(logger.get_logger_dir(), 'checkpoint.pth.tar'))
            if is_best:
                shutil.copy(
                    osp.join(logger.get_logger_dir(), 'checkpoint.pth.tar'),
                    osp.join(logger.get_logger_dir(), 'model_best.pth.tar'))

        if i_iter >= args.num_steps_stop - 1:
            break
示例#5
0
                loss1e4=loss.item() * 1e4,
                group0_lr=optimizer.state_dict()['param_groups'][0]['lr'],
                sk_err=err,
                sk_time_sec=tim_sec),
                          step=pytorchgo_args.get_args().step,
                          use_wandb=pytorchgo_args.get_args().wandb,
                          prefix="training epoch {}/{}: ".format(
                              epoch,
                              pytorchgo_args.get_args().epochs))
            #optimizer_summary(optimizer)

    cpu_prototype = model.prototype_N2K.detach().cpu().numpy()
    return cpu_prototype


optimizer_summary(optimizer)
model_summary(model)

pytorchgo_args.get_args().step = 0
for epoch in range(start_epoch, start_epoch + args.epochs):
    if args.debug and epoch >= 2: break
    prototype = train(epoch)
    feature_return_switch(model, True)
    logger.warning(logger.get_logger_dir())
    logger.warning("doing KNN evaluation.")
    acc = kNN(model, trainloader, testloader, K=10, sigma=0.1, dim=knn_dim)
    logger.warning("finish KNN evaluation.")
    feature_return_switch(model, False)
    if acc > best_acc:
        logger.info('get better result, saving..')
        state = {
def train(args):
    """Train DeepLab (ResNet backbone) on PASCAL VOC and keep the best model.

    Trains ``Res_Deeplab`` on the VOC ``train_aug`` split with SGD. The
    learning rate is adjusted every iteration via
    ``adjust_learning_rate(..., power=0.9)``. After every epoch the model is
    evaluated on ``val``. Whenever mean IoU is not lower than the best value
    so far, model and optimizer state are saved to
    ``<logger dir>/best_model.pth``.

    The module-level ``is_debug`` flag truncates the loops: 1 for a few
    batches, 2 for about 200 batches.

    Args:
        args: parsed command-line namespace. Fields used here: ``gpu``,
            ``batch_size``, ``l_rate``, ``resume``, ``n_epoch``.
    """
    logger.auto_set_dir()
    from pytorchgo.utils.pytorch_utils import set_gpu
    set_gpu(args.gpu)

    # Setup Dataloader
    from pytorchgo.augmentation.segmentation import SubtractMeans, PIL2NP, RGB2BGR, PIL_Scale, Value255to0, ToLabel
    from torchvision.transforms import Compose, Normalize, ToTensor
    img_transform = Compose([  # notice the order!!!
        PIL_Scale(train_img_shape, Image.BILINEAR),
        PIL2NP(),
        RGB2BGR(),
        SubtractMeans(),
        ToTensor(),
    ])

    label_transform = Compose([
        PIL_Scale(train_img_shape, Image.NEAREST),
        PIL2NP(),
        Value255to0(),
        ToLabel()
    ])

    val_img_transform = Compose([
        PIL_Scale(train_img_shape, Image.BILINEAR),
        PIL2NP(),
        RGB2BGR(),
        SubtractMeans(),
        ToTensor(),
    ])
    val_label_transform = Compose([
        PIL_Scale(train_img_shape, Image.NEAREST),
        PIL2NP(),
        ToLabel(),
        # notice here, training, validation size difference, this is very tricky.
        # (validation labels deliberately skip Value255to0)
    ])

    from pytorchgo.dataloader.pascal_voc_loader import pascalVOCLoader as common_voc_loader
    train_loader = common_voc_loader(split="train_aug",
                                     epoch_scale=1,
                                     img_transform=img_transform,
                                     label_transform=label_transform)
    validation_loader = common_voc_loader(split='val',
                                          img_transform=val_img_transform,
                                          label_transform=val_label_transform)

    n_classes = train_loader.n_classes
    trainloader = data.DataLoader(train_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)

    valloader = data.DataLoader(validation_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup Model
    from pytorchgo.model.deeplab_resnet import Res_Deeplab

    model = Res_Deeplab(NoLabels=n_classes, pretrained=True, output_all=False)

    from pytorchgo.utils.pytorch_utils import model_summary, optimizer_summary
    model_summary(model)

    def get_validation_miou(model):
        """Evaluate ``model`` on ``valloader``, log all scores, return mean IoU."""
        model.eval()
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader),
                                                    total=len(valloader),
                                                    desc="validation"):
            if i_val > 5 and is_debug == 1: break
            if i_val > 200 and is_debug == 2: break

            output = model(Variable(images_val, volatile=True).cuda())
            # Predicted class = argmax over dim 1 of the model output.
            pred = output.data.max(1)[1].cpu().numpy()

            gt = labels_val.numpy()

            running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            logger.info("{}: {}".format(k, v))
        running_metrics.reset()
        return score['Mean IoU : \t']

    model.cuda()

    # Use the model's own optimizer if it defines one. Look through a
    # DataParallel wrapper (``.module``) when present, but fall back to the
    # bare model: it is not wrapped here, so ``model.module`` does not exist.
    base_model = getattr(model, 'module', model)
    if hasattr(base_model, 'optimizer'):
        logger.info("use the model's customized optimizer!")
        optimizer = base_model.optimizer
    else:
        optimizer = torch.optim.SGD(model.optimizer_params(args.l_rate),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    optimizer_summary(optimizer)
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            logger.info("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("No checkpoint found at '{}'".format(args.resume))

    best_iou = 0
    logger.info('start!!')
    for epoch in tqdm(range(args.n_epoch), total=args.n_epoch):
        model.train()
        for i, (images, labels) in tqdm(enumerate(trainloader),
                                        total=len(trainloader),
                                        desc="training epoch {}/{}".format(
                                            epoch, args.n_epoch)):
            if i > 10 and is_debug == 1: break

            if i > 200 and is_debug == 2: break

            # Global iteration index drives the per-iteration lr schedule.
            cur_iter = i + epoch * len(trainloader)
            cur_lr = adjust_learning_rate(optimizer,
                                          args.l_rate,
                                          cur_iter,
                                          args.n_epoch * len(trainloader),
                                          power=0.9)

            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images)  # use fusion score
            loss = CrossEntropyLoss2d_Seg(input=outputs,
                                          target=labels,
                                          class_num=n_classes)

            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                logger.info(
                    "Epoch [%d/%d] Loss: %.4f, lr: %.7f, best mIoU: %.7f" %
                    (epoch + 1, args.n_epoch, loss.data[0], cur_lr, best_iou))

        cur_miou = get_validation_miou(model)
        # Save on ties too (>=), so the latest equally-good model wins.
        if cur_miou >= best_iou:
            best_iou = cur_miou
            state = {
                'epoch': epoch + 1,
                'mIoU': best_iou,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state,
                       os.path.join(logger.get_logger_dir(), "best_model.pth"))
示例#7
0
                                             method=args.method, uses_one_classifier=args.uses_one_classifier,
                                             is_data_parallel=args.is_data_parallel)
    optimizer_g = get_optimizer(model_g.parameters(), lr=args.lr, momentum=args.momentum, opt=args.opt,
                                weight_decay=args.weight_decay)
    optimizer_f = get_optimizer(list(model_f1.parameters()) + list(model_f2.parameters()), opt=args.opt,
                                lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
# With a single shared classifier, the second head is just an alias of the
# first one.
if args.uses_one_classifier:
    logger.warn ("f1 and f2 are same!")
    model_f2 = model_f1

from pytorchgo.utils.pytorch_utils import model_summary, optimizer_summary

model_summary([model_g, model_f1, model_f2])
optimizer_summary([optimizer_g, optimizer_f])

# Experiment tag: "<src_dataset>-<src_split>2<tgt_dataset>-<tgt_split>_<input_ch>ch".
mode = "%s-%s2%s-%s_%sch" % (
    args.src_dataset,
    args.src_split,
    args.tgt_dataset,
    args.tgt_split,
    args.input_ch,
)
# fcn / psp networks additionally encode their resolution in the name.
model_name = (
    "%s-%s-%s-res%s" % (args.method, args.savename, args.net, args.res)
    if args.net in ("fcn", "psp")
    else "%s-%s-%s" % (args.method, args.savename, args.net)
)

outdir = os.path.join(logger.get_logger_dir(), mode)

# Checkpoint directory (created eagerly) and TensorBoard log directory path.
pth_dir = os.path.join(outdir, "pth")
mkdir_if_not_exist(pth_dir)
tflog_dir = os.path.join(outdir, "tflog", model_name)