Example No. 1
def main(opts):
    # Set parameters
    p = OrderedDict()  # Parameters to include in report
    p['trainBatch'] = opts.batch  # Training batch size
    testBatch = 1  # Testing batch size
    useTest = True  # See evolution of the test set when training
    nTestInterval = opts.testInterval  # Run on test set every nTestInterval epochs
    snapshot = 1  # Store a model every snapshot epochs
    p['nAveGrad'] = 1  # Average the gradient of several iterations
    p['lr'] = opts.lr  # Learning rate
    p['wd'] = 5e-4  # Weight decay
    p['momentum'] = 0.9  # Momentum
    p['epoch_size'] = opts.step  # How many epochs to change learning rate
    p['num_workers'] = opts.numworker
    model_path = opts.pretrainedModel
    backbone = 'xception'  # Use xception or resnet as feature extractor
    nEpochs = opts.epochs

    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
    runs = glob.glob(os.path.join(save_dir_root, 'run', 'run_*'))
    for r in runs:
        run_id = int(r.split('_')[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
    save_dir = os.path.join(save_dir_root, 'run', 'run_' + str(max_id))

    # Device
    if opts.device == "gpu":
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            device = torch.device("cuda")
            #torch.cuda.set_device(args.gpu_ids[0])
            print("Device:", device)
            print("GPU name:", torch.cuda.get_device_name(device))
            print("torch.cuda.current_device() =", torch.cuda.current_device())
        else:
            print("Cannot use GPU; falling back to CPU.")
            device = torch.device("cpu")
            print("Device:", device)
    else:
        device = torch.device("cpu")
        print("Device:", device)

    # Network definition
    if backbone == 'xception':
        net_ = deeplab_xception_universal.deeplab_xception_end2end_3d(
            n_classes=20,
            os=16,
            hidden_layers=opts.hidden_layers,
            source_classes=7,
            middle_classes=18,
        )
    elif backbone == 'resnet':
        # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    modelName = 'deeplabv3plus-' + backbone + '-voc' + datetime.now().strftime(
        '%b%d_%H-%M-%S')
    criterion = ut.cross_entropy2d

    if gpu_id >= 0:
        # torch.cuda.set_device(device=gpu_id)
        net_.to(device)

    # net load weights
    if not model_path == '':
        x = torch.load(model_path)
        net_.load_state_dict_new(x)
        print('load pretrainedModel.')
    else:
        print('no pretrainedModel.')

    if not opts.loadmodel == '':
        x = torch.load(opts.loadmodel)
        net_.load_source_model(x)
        print('load model:', opts.loadmodel)
    else:
        print('no trained model loaded.')

    print(net_)

    log_dir = os.path.join(
        save_dir, 'models',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)
    writer.add_text('load model', opts.loadmodel, 1)
    writer.add_text('setting', sys.argv[0], 1)

    # Use the following optimizer
    optimizer = optim.SGD(net_.parameters(),
                          lr=p['lr'],
                          momentum=p['momentum'],
                          weight_decay=p['wd'])

    composed_transforms_tr = transforms.Compose([
        tr.RandomSized_new(opts.image_size),
        tr.Normalize_xception_tf(),
        tr.ToTensor_()
    ])

    composed_transforms_ts = transforms.Compose(
        [tr.Normalize_xception_tf(),
         tr.ToTensor_()])

    composed_transforms_ts_flip = transforms.Compose(
        [tr.HorizontalFlip(),
         tr.Normalize_xception_tf(),
         tr.ToTensor_()])

    #all_train = cihp_pascal_atr.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
    #voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
    #voc_val_flip = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)
    all_train = cihp_pascal_atr.VOCSegmentation(
        cihp_dir="./data/datasets/CIHP_4w",
        split='train',
        transform=composed_transforms_tr,
        flip=True)
    #voc_val = pascal.VOCSegmentation(base_dir="./data/datasets/pascal", split='val', transform=composed_transforms_ts)
    #voc_val_flip = pascal.VOCSegmentation(base_dir="./data/datasets/pascal", split='val', transform=composed_transforms_ts_flip)

    num_cihp, num_pascal, num_atr = all_train.get_class_num()
    ss = sam.Sampler_uni(num_cihp, num_pascal, num_atr, opts.batch)
    # balance datasets based on pascal
    ss_balanced = sam.Sampler_uni(num_cihp,
                                  num_pascal,
                                  num_atr,
                                  opts.batch,
                                  balance_id=1)

    trainloader = DataLoader(all_train,
                             batch_size=p['trainBatch'],
                             shuffle=False,
                             num_workers=p['num_workers'],
                             sampler=ss,
                             drop_last=True)
    trainloader_balanced = DataLoader(all_train,
                                      batch_size=p['trainBatch'],
                                      shuffle=False,
                                      num_workers=p['num_workers'],
                                      sampler=ss_balanced,
                                      drop_last=True)
    #testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])
    #testloader_flip = DataLoader(voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])

    num_img_tr = len(trainloader)
    num_img_balanced = len(trainloader_balanced)
    #num_img_ts = len(testloader)
    num_img_ts = 0
    running_loss_tr = 0.0
    running_loss_tr_atr = 0.0
    running_loss_ts = 0.0
    aveGrad = 0
    global_step = 0
    print("Training Network")
    net = torch.nn.DataParallel(net_)

    id_list = torch.LongTensor(range(opts.batch))
    pascal_iter = int(num_img_tr // opts.batch)

    # Get graphs
    train_graph, test_graph = get_graphs(opts, device)
    adj1, adj2, adj3, adj4, adj5, adj6 = train_graph
    adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test = test_graph

    # Main Training and Testing Loop
    for epoch in range(resume_epoch, int(1.5 * nEpochs)):
        start_time = timeit.default_timer()

        if epoch % p['epoch_size'] == p['epoch_size'] - 1 and epoch < nEpochs:
            lr_ = ut.lr_poly(p['lr'], epoch, nEpochs, 0.9)
            optimizer = optim.SGD(net_.parameters(),
                                  lr=lr_,
                                  momentum=p['momentum'],
                                  weight_decay=p['wd'])
            print('(poly lr policy) learning rate: ', lr_)
            writer.add_scalar('data/lr_', lr_, epoch)
        elif epoch % p['epoch_size'] == p['epoch_size'] - 1 and epoch > nEpochs:
            lr_ = ut.lr_poly(p['lr'], epoch - nEpochs, int(0.5 * nEpochs), 0.9)
            optimizer = optim.SGD(net_.parameters(),
                                  lr=lr_,
                                  momentum=p['momentum'],
                                  weight_decay=p['wd'])
            print('(poly lr policy) learning rate: ', lr_)
            writer.add_scalar('data/lr_', lr_, epoch)

        net_.train()
        if epoch < nEpochs:
            for ii, sample_batched in enumerate(trainloader):
                inputs, labels = sample_batched['image'], sample_batched[
                    'label']
                dataset_lbl = sample_batched['pascal'][0].item()
                # Forward-Backward of the mini-batch
                inputs, labels = Variable(inputs,
                                          requires_grad=True), Variable(labels)
                global_step += 1

                if gpu_id >= 0:
                    inputs, labels = inputs.to(device), labels.to(device)

                if dataset_lbl == 0:
                    # 0 is cihp -- target
                    #print( "inputs.shape : ", inputs.shape )    # torch.Size([batch, 3, 512, 512])
                    #print( "adj1.shape : ", adj1.shape )        # torch.Size([1, 1, 20, 20])
                    #print( "adj2.shape : ", adj2.shape )        # torch.Size([1, 1, 7, 7])
                    #print( "adj3.shape : ", adj3.shape )        # torch.Size([1, 1, 20, 20])
                    _, outputs, _ = net.forward(
                        None,
                        input_target=inputs,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )

                    #print( "outputs.shape : ", outputs.shape )  # torch.Size([2, 20, 512, 512])

                elif dataset_lbl == 1:
                    # pascal is source
                    outputs, _, _ = net.forward(
                        inputs,
                        input_target=None,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                else:
                    # atr
                    _, _, outputs = net.forward(
                        None,
                        input_target=None,
                        input_middle=inputs,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                # print(sample_batched['pascal'])
                # print(outputs.size(),)
                # print(labels)
                loss = criterion(outputs, labels, batch_average=True)
                running_loss_tr += loss.item()

                # Print stuff
                if ii % num_img_tr == (num_img_tr - 1):
                    running_loss_tr = running_loss_tr / num_img_tr
                    writer.add_scalar('data/total_loss_epoch', running_loss_tr,
                                      epoch)
                    print('[Epoch: %d, numImages: %5d]' %
                          (epoch, ii * p['trainBatch'] + inputs.data.shape[0]))
                    print('Loss: %f' % running_loss_tr)
                    running_loss_tr = 0
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) +
                          "\n")

                # Backward the averaged gradient
                loss /= p['nAveGrad']
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
                if aveGrad % p['nAveGrad'] == 0:
                    writer.add_scalar('data/total_loss_iter', loss.item(),
                                      global_step)
                    if dataset_lbl == 0:
                        writer.add_scalar('data/total_loss_iter_cihp',
                                          loss.item(), global_step)
                    if dataset_lbl == 1:
                        writer.add_scalar('data/total_loss_iter_pascal',
                                          loss.item(), global_step)
                    if dataset_lbl == 2:
                        writer.add_scalar('data/total_loss_iter_atr',
                                          loss.item(), global_step)
                    optimizer.step()
                    optimizer.zero_grad()
                    # optimizer_gcn.step()
                    # optimizer_gcn.zero_grad()
                    aveGrad = 0

                # Show 10 * 3 images results each epoch
                if ii % (num_img_tr // 10) == 0:
                    #                if ii % (num_img_tr // 4000) == 0:
                    grid_image = make_grid(inputs[:3].clone().cpu().data,
                                           3,
                                           normalize=True)
                    writer.add_image('Image', grid_image, global_step)
                    grid_image = make_grid(ut.decode_seg_map_sequence(
                        torch.max(outputs[:3], 1)[1].detach().cpu().numpy()),
                                           3,
                                           normalize=False,
                                           range=(0, 255))
                    writer.add_image('Predicted label', grid_image,
                                     global_step)
                    grid_image = make_grid(ut.decode_seg_map_sequence(
                        torch.squeeze(labels[:3], 1).detach().cpu().numpy()),
                                           3,
                                           normalize=False,
                                           range=(0, 255))
                    writer.add_image('Groundtruth label', grid_image,
                                     global_step)

                print('step {} | loss is {}'.format(ii,
                                                    loss.cpu().item()),
                      flush=True)
        else:
            # Balanced sampling across the datasets
            for ii, sample_batched in enumerate(trainloader_balanced):
                inputs, labels = sample_batched['image'], sample_batched[
                    'label']
                dataset_lbl = sample_batched['pascal'][0].item()
                # Forward-Backward of the mini-batch
                inputs, labels = Variable(inputs,
                                          requires_grad=True), Variable(labels)
                global_step += 1

                if gpu_id >= 0:
                    inputs, labels = inputs.to(device), labels.to(device)

                if dataset_lbl == 0:
                    # 0 is cihp -- target
                    _, outputs, _ = net.forward(
                        None,
                        input_target=inputs,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                elif dataset_lbl == 1:
                    # pascal is source
                    outputs, _, _ = net.forward(
                        inputs,
                        input_target=None,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                else:
                    # atr
                    _, _, outputs = net.forward(
                        None,
                        input_target=None,
                        input_middle=inputs,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                # print(sample_batched['pascal'])
                # print(outputs.size(),)
                # print(labels)
                loss = criterion(outputs, labels, batch_average=True)
                running_loss_tr += loss.item()

                # Print stuff
                if ii % num_img_balanced == (num_img_balanced - 1):
                    running_loss_tr = running_loss_tr / num_img_balanced
                    writer.add_scalar('data/total_loss_epoch', running_loss_tr,
                                      epoch)
                    print('[Epoch: %d, numImages: %5d]' %
                          (epoch, ii * p['trainBatch'] + inputs.data.shape[0]))
                    print('Loss: %f' % running_loss_tr)
                    running_loss_tr = 0
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) +
                          "\n")

                # Backward the averaged gradient
                loss /= p['nAveGrad']
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
                if aveGrad % p['nAveGrad'] == 0:
                    writer.add_scalar('data/total_loss_iter', loss.item(),
                                      global_step)
                    if dataset_lbl == 0:
                        writer.add_scalar('data/total_loss_iter_cihp',
                                          loss.item(), global_step)
                    if dataset_lbl == 1:
                        writer.add_scalar('data/total_loss_iter_pascal',
                                          loss.item(), global_step)
                    if dataset_lbl == 2:
                        writer.add_scalar('data/total_loss_iter_atr',
                                          loss.item(), global_step)
                    optimizer.step()
                    optimizer.zero_grad()

                    aveGrad = 0

                # Show 10 * 3 images results each epoch
                if ii % (num_img_balanced // 10) == 0:
                    grid_image = make_grid(inputs[:3].clone().cpu().data,
                                           3,
                                           normalize=True)
                    writer.add_image('Image', grid_image, global_step)
                    grid_image = make_grid(ut.decode_seg_map_sequence(
                        torch.max(outputs[:3], 1)[1].detach().cpu().numpy()),
                                           3,
                                           normalize=False,
                                           range=(0, 255))
                    writer.add_image('Predicted label', grid_image,
                                     global_step)
                    grid_image = make_grid(ut.decode_seg_map_sequence(
                        torch.squeeze(labels[:3], 1).detach().cpu().numpy()),
                                           3,
                                           normalize=False,
                                           range=(0, 255))
                    writer.add_image('Groundtruth label', grid_image,
                                     global_step)

                print('loss is ', loss.cpu().item(), flush=True)

        # Save the model
        if (epoch % snapshot) == snapshot - 1:
            torch.save(
                net_.state_dict(),
                os.path.join(save_dir, 'models',
                             modelName + '_epoch-' + str(epoch) + '.pth'))
            print("Save model at {}\n".format(
                os.path.join(save_dir, 'models',
                             modelName + '_epoch-' + str(epoch) + '.pth')))

        # One testing epoch (test loop omitted in this example)
Example No. 2
def main(opts):
    p = OrderedDict()  # Parameters to include in report
    p['trainBatch'] = opts.batch  # Training batch size
    testBatch = 1  # Testing batch size
    useTest = True  # See evolution of the test set when training
    nTestInterval = opts.testInterval  # Run on test set every nTestInterval epochs
    snapshot = 1  # Store a model every snapshot epochs
    p['nAveGrad'] = 1  # Average the gradient of several iterations
    p['lr'] = opts.lr  # Learning rate
    p['lrFtr'] = 1e-5
    p['lraspp'] = 1e-5
    p['lrpro'] = 1e-5
    p['lrdecoder'] = 1e-5
    p['lrother'] = 1e-5
    p['wd'] = 5e-4  # Weight decay
    p['momentum'] = 0.9  # Momentum
    p['epoch_size'] = opts.step  # How many epochs to change learning rate
    p['num_workers'] = opts.numworker
    model_path = opts.pretrainedModel
    backbone = 'xception'  # Use xception or resnet as feature extractor,
    nEpochs = opts.epochs

    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
    runs = glob.glob(os.path.join(save_dir_root, 'run_cihp', 'run_*'))
    for r in runs:
        run_id = int(r.split('_')[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    save_dir = os.path.join(save_dir_root, 'run_cihp', 'run_' + str(max_id))

    # Device
    if opts.device == "gpu":
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            device = torch.device("cuda")
            #torch.cuda.set_device(args.gpu_ids[0])
            print("Device:", device)
            print("GPU name:", torch.cuda.get_device_name(device))
            print("torch.cuda.current_device() =", torch.cuda.current_device())
        else:
            print("Cannot use GPU; falling back to CPU.")
            device = torch.device("cpu")
            print("Device:", device)
    else:
        device = torch.device("cpu")
        print("Device:", device)

    # Network definition
    if backbone == 'xception':
        net_ = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(
            n_classes=opts.classes,
            os=16,
            hidden_layers=opts.hidden_layers,
            source_classes=7,
        )
    elif backbone == 'resnet':
        # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    modelName = 'deeplabv3plus-' + backbone + '-voc' + datetime.now().strftime(
        '%b%d_%H-%M-%S')
    criterion = util.cross_entropy2d

    if gpu_id >= 0:
        # torch.cuda.set_device(device=gpu_id)
        #net_.cuda()
        net_.to(device)

    # net load weights
    if not model_path == '':
        x = torch.load(model_path)
        net_.load_state_dict_new(x)
        print('load pretrainedModel:', model_path)
    else:
        print('no pretrainedModel.')
    if not opts.loadmodel == '':
        x = torch.load(opts.loadmodel)
        net_.load_source_model(x)
        print('load model:', opts.loadmodel)
    else:
        print('no model loaded.')

    log_dir = os.path.join(
        save_dir, 'models',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)
    writer.add_text('load model', opts.loadmodel, 1)
    writer.add_text('setting', sys.argv[0], 1)

    if opts.freezeBN:
        net_.freeze_bn()

    print(net_)

    # Use the following optimizer
    optimizer = optim.SGD(net_.parameters(),
                          lr=p['lr'],
                          momentum=p['momentum'],
                          weight_decay=p['wd'])

    composed_transforms_tr = transforms.Compose([
        tr.RandomSized_new(opts.image_size),
        tr.Normalize_xception_tf(),
        tr.ToTensor_()
    ])

    composed_transforms_ts = transforms.Compose(
        [tr.Normalize_xception_tf(),
         tr.ToTensor_()])

    composed_transforms_ts_flip = transforms.Compose(
        [tr.HorizontalFlip(),
         tr.Normalize_xception_tf(),
         tr.ToTensor_()])

    #voc_train = cihp.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
    #voc_val = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts)
    #voc_val_flip = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)
    voc_train = cihp.VOCSegmentation(base_dir="../../data/datasets/CIHP_4w",
                                     split='train',
                                     transform=composed_transforms_tr,
                                     flip=True)
    voc_val = cihp.VOCSegmentation(base_dir="../../data/datasets/CIHP_4w",
                                   split='val',
                                   transform=composed_transforms_ts)
    voc_val_flip = cihp.VOCSegmentation(base_dir="../../data/datasets/CIHP_4w",
                                        split='val',
                                        transform=composed_transforms_ts_flip)

    trainloader = DataLoader(voc_train,
                             batch_size=p['trainBatch'],
                             shuffle=True,
                             num_workers=p['num_workers'],
                             drop_last=True)
    testloader = DataLoader(voc_val,
                            batch_size=testBatch,
                            shuffle=False,
                            num_workers=p['num_workers'])
    testloader_flip = DataLoader(voc_val_flip,
                                 batch_size=testBatch,
                                 shuffle=False,
                                 num_workers=p['num_workers'])

    num_img_tr = len(trainloader)
    num_img_ts = len(testloader)
    running_loss_tr = 0.0
    running_loss_ts = 0.0
    aveGrad = 0
    global_step = 0
    print("Training Network")

    print("num_img_tr : ", num_img_tr)

    net = torch.nn.DataParallel(net_)
    train_graph, test_graph = get_graphs(opts, device)
    adj1, adj2, adj3 = train_graph

    # Main Training and Testing Loop
    for epoch in range(resume_epoch, nEpochs):
        start_time = timeit.default_timer()

        if epoch % p['epoch_size'] == p['epoch_size'] - 1:
            lr_ = util.lr_poly(p['lr'], epoch, nEpochs, 0.9)
            optimizer = optim.SGD(net_.parameters(),
                                  lr=lr_,
                                  momentum=p['momentum'],
                                  weight_decay=p['wd'])
            writer.add_scalar('data/lr_', lr_, epoch)
            print('(poly lr policy) learning rate: ', lr_)

        net.train()
        for ii, sample_batched in enumerate(trainloader):

            inputs, labels = sample_batched['image'], sample_batched['label']
            # Forward-Backward of the mini-batch
            inputs, labels = Variable(inputs,
                                      requires_grad=True), Variable(labels)
            global_step += inputs.data.shape[0]

            if gpu_id >= 0:
                #inputs, labels = inputs.cuda(), labels.cuda()
                inputs, labels = inputs.to(device), labels.to(device)

            #print( "inputs.shape : ", inputs.shape )    # torch.Size([batch, 3, 512, 512])
            #print( "adj1.shape : ", adj1.shape )        # torch.Size([8, 1, 20, 20])
            #print( "adj2.shape : ", adj2.shape )        # torch.Size([8, 1, 20, 7])
            #print( "adj3.shape : ", adj3.shape )        # torch.Size([8, 1, 7, 7])
            outputs = net.forward(inputs, adj1, adj3, adj2)
            #print( "outputs.shape : ", outputs.shape )  # torch.Size([2, 20, 512, 512])

            loss = criterion(outputs, labels, batch_average=True)
            running_loss_tr += loss.item()

            # Print stuff
            if ii % num_img_tr == (num_img_tr - 1):
                running_loss_tr = running_loss_tr / num_img_tr
                writer.add_scalar('data/total_loss_epoch', running_loss_tr,
                                  epoch)
                print('[Epoch: %d, numImages: %5d]' %
                      (epoch, ii * p['trainBatch'] + inputs.data.shape[0]))
                print('Loss: %f' % running_loss_tr)
                running_loss_tr = 0
                stop_time = timeit.default_timer()
                print("Execution time: " + str(stop_time - start_time) + "\n")

            # Backward the averaged gradient
            loss /= p['nAveGrad']
            loss.backward()
            aveGrad += 1

            # Update the weights once in p['nAveGrad'] forward passes
            if aveGrad % p['nAveGrad'] == 0:
                writer.add_scalar('data/total_loss_iter', loss.item(),
                                  ii + num_img_tr * epoch)
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

            # Show 10 * 3 images results each epoch
            if ii % (num_img_tr // 10) == 0:
                #            if ii % (num_img_tr // 4000) == 0:
                grid_image = make_grid(inputs[:3].clone().cpu().data,
                                       3,
                                       normalize=True)
                writer.add_image('Image', grid_image, global_step)
                grid_image = make_grid(util.decode_seg_map_sequence(
                    torch.max(outputs[:3], 1)[1].detach().cpu().numpy()),
                                       3,
                                       normalize=False,
                                       range=(0, 255))
                writer.add_image('Predicted label', grid_image, global_step)
                grid_image = make_grid(util.decode_seg_map_sequence(
                    torch.squeeze(labels[:3], 1).detach().cpu().numpy()),
                                       3,
                                       normalize=False,
                                       range=(0, 255))
                writer.add_image('Groundtruth label', grid_image, global_step)
            print('loss is ', loss.cpu().item(), flush=True)

        # Save the model
        if (epoch % snapshot) == snapshot - 1:
            torch.save(
                net_.state_dict(),
                os.path.join(save_dir, 'models',
                             modelName + '_epoch-' + str(epoch) + '.pth'))
            print("Save model at {}\n".format(
                os.path.join(save_dir, 'models',
                             modelName + '_epoch-' + str(epoch) + '.pth')))

        torch.cuda.empty_cache()

        # One testing epoch
        if useTest and epoch % nTestInterval == (nTestInterval - 1):
            val_cihp(net_,
                     testloader=testloader,
                     testloader_flip=testloader_flip,
                     test_graph=test_graph,
                     epoch=epoch,
                     writer=writer,
                     criterion=criterion,
                     classes=opts.classes,
                     device=device)
        torch.cuda.empty_cache()
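
The examples above rebuild the SGD optimizer with a polynomially decayed learning rate via ut.lr_poly / util.lr_poly, whose definition is not shown in these snippets. A minimal standalone sketch of what such a helper usually computes (assumed formula, hypothetical demo values):

def lr_poly(base_lr, iter_, max_iter, power=0.9):
    # Polynomial decay: base_lr * (1 - iter / max_iter) ** power
    return base_lr * ((1.0 - float(iter_) / max_iter) ** power)

if __name__ == '__main__':
    # Illustrative only; the training loops above pass p['lr'], epoch and nEpochs.
    for epoch in (0, 25, 50, 75, 99):
        print(epoch, lr_poly(0.007, epoch, 100, 0.9))
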
Example No. 3
def main(opts):
    p = OrderedDict()  # Parameters to include in report
    p["trainBatch"] = opts.batch  # Training batch size
    testBatch = 1  # Testing batch size
    useTest = True  # See evolution of the test set when training
    nTestInterval = opts.testInterval  # Run on test set every nTestInterval epochs
    snapshot = 1  # Store a model every snapshot epochs
    p["nAveGrad"] = 1  # Average the gradient of several iterations
    p["lr"] = opts.lr  # Learning rate
    p["lrFtr"] = 1e-5
    p["lraspp"] = 1e-5
    p["lrpro"] = 1e-5
    p["lrdecoder"] = 1e-5
    p["lrother"] = 1e-5
    p["wd"] = 5e-4  # Weight decay
    p["momentum"] = 0.9  # Momentum
    p["epoch_size"] = opts.step  # How many epochs to change learning rate
    p["num_workers"] = opts.numworker
    model_path = opts.pretrainedModel
    backbone = "xception"  # Use xception or resnet as feature extractor,
    nEpochs = opts.epochs

    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    exp_name = os.path.dirname(os.path.abspath(__file__)).split("/")[-1]
    runs = glob.glob(os.path.join(save_dir_root, "run_cihp", "run_*"))
    for r in runs:
        run_id = int(r.split("_")[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    save_dir = os.path.join(save_dir_root, "run_cihp", "run_" + str(max_id))

    # Network definition
    if backbone == "xception":
        net_ = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(
            n_classes=opts.classes,
            os=16,
            hidden_layers=opts.hidden_layers,
            source_classes=7,
        )
    elif backbone == "resnet":
        # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    modelName = ("deeplabv3plus-" + backbone + "-voc" +
                 datetime.now().strftime("%b%d_%H-%M-%S"))
    criterion = util.cross_entropy2d

    if gpu_id >= 0:
        # torch.cuda.set_device(device=gpu_id)
        net_.cuda()

    # net load weights
    if not model_path == "":
        x = torch.load(model_path)
        net_.load_state_dict_new(x)
        print("load pretrainedModel:", model_path)
    else:
        print("no pretrainedModel.")
    if not opts.loadmodel == "":
        x = torch.load(opts.loadmodel)
        net_.load_source_model(x)
        print("load model:", opts.loadmodel)
    else:
        print("no model load !!!!!!!!")

    log_dir = os.path.join(
        save_dir,
        "models",
        datetime.now().strftime("%b%d_%H-%M-%S") + "_" + socket.gethostname(),
    )
    writer = SummaryWriter(log_dir=log_dir)
    writer.add_text("load model", opts.loadmodel, 1)
    writer.add_text("setting", sys.argv[0], 1)

    if opts.freezeBN:
        net_.freeze_bn()

    # Use the following optimizer
    optimizer = optim.SGD(net_.parameters(),
                          lr=p["lr"],
                          momentum=p["momentum"],
                          weight_decay=p["wd"])

    composed_transforms_tr = transforms.Compose(
        [tr.RandomSized_new(512),
         tr.Normalize_xception_tf(),
         tr.ToTensor_()])

    composed_transforms_ts = transforms.Compose(
        [tr.Normalize_xception_tf(),
         tr.ToTensor_()])

    composed_transforms_ts_flip = transforms.Compose(
        [tr.HorizontalFlip(),
         tr.Normalize_xception_tf(),
         tr.ToTensor_()])

    voc_train = cihp.VOCSegmentation(split="train",
                                     transform=composed_transforms_tr,
                                     flip=True)
    voc_val = cihp.VOCSegmentation(split="val",
                                   transform=composed_transforms_ts)
    voc_val_flip = cihp.VOCSegmentation(split="val",
                                        transform=composed_transforms_ts_flip)

    trainloader = DataLoader(
        voc_train,
        batch_size=p["trainBatch"],
        shuffle=True,
        num_workers=p["num_workers"],
        drop_last=True,
    )
    testloader = DataLoader(voc_val,
                            batch_size=testBatch,
                            shuffle=False,
                            num_workers=p["num_workers"])
    testloader_flip = DataLoader(voc_val_flip,
                                 batch_size=testBatch,
                                 shuffle=False,
                                 num_workers=p["num_workers"])

    num_img_tr = len(trainloader)
    num_img_ts = len(testloader)
    running_loss_tr = 0.0
    running_loss_ts = 0.0
    aveGrad = 0
    global_step = 0
    print("Training Network")

    net = torch.nn.DataParallel(net_)
    train_graph, test_graph = get_graphs(opts)
    adj1, adj2, adj3 = train_graph

    # Main Training and Testing Loop
    for epoch in range(resume_epoch, nEpochs):
        start_time = timeit.default_timer()

        if epoch % p["epoch_size"] == p["epoch_size"] - 1:
            lr_ = util.lr_poly(p["lr"], epoch, nEpochs, 0.9)
            optimizer = optim.SGD(net_.parameters(),
                                  lr=lr_,
                                  momentum=p["momentum"],
                                  weight_decay=p["wd"])
            writer.add_scalar("data/lr_", lr_, epoch)
            print("(poly lr policy) learning rate: ", lr_)

        net.train()
        for ii, sample_batched in enumerate(trainloader):

            inputs, labels = sample_batched["image"], sample_batched["label"]
            # Forward-Backward of the mini-batch
            inputs, labels = Variable(inputs,
                                      requires_grad=True), Variable(labels)
            global_step += inputs.data.shape[0]

            if gpu_id >= 0:
                inputs, labels = inputs.cuda(), labels.cuda()

            outputs = net.forward(inputs, adj1, adj3, adj2)

            loss = criterion(outputs, labels, batch_average=True)
            running_loss_tr += loss.item()

            # Print stuff
            if ii % num_img_tr == (num_img_tr - 1):
                running_loss_tr = running_loss_tr / num_img_tr
                writer.add_scalar("data/total_loss_epoch", running_loss_tr,
                                  epoch)
                print("[Epoch: %d, numImages: %5d]" %
                      (epoch, ii * p["trainBatch"] + inputs.data.shape[0]))
                print("Loss: %f" % running_loss_tr)
                running_loss_tr = 0
                stop_time = timeit.default_timer()
                print("Execution time: " + str(stop_time - start_time) + "\n")

            # Backward the averaged gradient
            loss /= p["nAveGrad"]
            loss.backward()
            aveGrad += 1

            # Update the weights once in p['nAveGrad'] forward passes
            if aveGrad % p["nAveGrad"] == 0:
                writer.add_scalar("data/total_loss_iter", loss.item(),
                                  ii + num_img_tr * epoch)
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

            # Show 10 * 3 images results each epoch
            if ii % (num_img_tr // 10) == 0:
                grid_image = make_grid(inputs[:3].clone().cpu().data,
                                       3,
                                       normalize=True)
                writer.add_image("Image", grid_image, global_step)
                grid_image = make_grid(
                    util.decode_seg_map_sequence(
                        torch.max(outputs[:3], 1)[1].detach().cpu().numpy()),
                    3,
                    normalize=False,
                    range=(0, 255),
                )
                writer.add_image("Predicted label", grid_image, global_step)
                grid_image = make_grid(
                    util.decode_seg_map_sequence(
                        torch.squeeze(labels[:3], 1).detach().cpu().numpy()),
                    3,
                    normalize=False,
                    range=(0, 255),
                )
                writer.add_image("Groundtruth label", grid_image, global_step)
            print("loss is ", loss.cpu().item(), flush=True)

        # Save the model
        if (epoch % snapshot) == snapshot - 1:
            torch.save(
                net_.state_dict(),
                os.path.join(save_dir, "models",
                             modelName + "_epoch-" + str(epoch) + ".pth"),
            )
            print("Save model at {}\n".format(
                os.path.join(save_dir, "models",
                             modelName + "_epoch-" + str(epoch) + ".pth")))

        torch.cuda.empty_cache()

        # One testing epoch
        if useTest and epoch % nTestInterval == (nTestInterval - 1):
            val_cihp(
                net_,
                testloader=testloader,
                testloader_flip=testloader_flip,
                test_graph=test_graph,
                epoch=epoch,
                writer=writer,
                criterion=criterion,
                classes=opts.classes,
            )
        torch.cuda.empty_cache()
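
The training loops above log scalars and image grids with a SummaryWriter and what is presumably torchvision.utils.make_grid; the imports are not shown. A minimal, self-contained sketch of that logging pattern, assuming torch.utils.tensorboard (the snippets may instead use tensorboardX) and placeholder data:

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

writer = SummaryWriter(log_dir='runs/demo')  # hypothetical log directory
for step in range(3):
    loss = 1.0 / (step + 1)            # placeholder scalar
    images = torch.rand(3, 3, 64, 64)  # placeholder batch of three RGB images
    writer.add_scalar('data/total_loss_iter', loss, step)
    writer.add_image('Image', make_grid(images, 3, normalize=True), step)
writer.close()
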
Example No. 4
        return _img, _target, type_lbl

    def __str__(self):
        return 'datasets(split=' + str(self.split) + ')'


if __name__ == '__main__':
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt

    composed_transforms_tr = transforms.Compose([
        # tr.RandomHorizontalFlip(),
        tr.RandomSized_new(512),
        tr.RandomRotate(15),
        tr.ToTensor_()
    ])

    voc_train = VOCSegmentation(split='train',
                                transform=composed_transforms_tr)

    dataloader = DataLoader(voc_train,
                            batch_size=5,
                            shuffle=True,
                            num_workers=1)

    for ii, sample in enumerate(dataloader):
        if ii > 10:
            break
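
The __main__ block above only iterates a few batches; the imported decode_segmap is presumably used in the full script to turn integer label masks into RGB images before plotting, but it is not exercised in the truncated loop. As a rough, hypothetical stand-in (the real decode_segmap in dataloaders.utils may have a different signature and a fixed palette):

import numpy as np

def colorize_mask(mask, num_classes=20, seed=0):
    # Map an (H, W) integer label mask to an (H, W, 3) uint8 RGB image
    # using a reproducible random color table; class 0 stays black.
    palette = np.random.RandomState(seed).randint(0, 256, size=(num_classes, 3), dtype=np.uint8)
    palette[0] = 0
    return palette[mask]

if __name__ == '__main__':
    dummy = np.random.randint(0, 20, size=(64, 64))
    print(colorize_mask(dummy).shape)  # (64, 64, 3)
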
Example No. 5
def main(opts):

	# Some of the settings are not used
	p = OrderedDict()  # Parameters to include in report
	p['trainBatch'] = opts.batch  # Training batch size
	testBatch = 1  # Testing batch size
	useTest = True  # See evolution of the test set when training
	nTestInterval = opts.testInterval  # Run on test set every nTestInterval epochs
	snapshot = 1  # Store a model every snapshot epochs
	p['nAveGrad'] = 1  # Average the gradient of several iterations
	p['lr'] = opts.lr  # Learning rate
	p['lrFtr'] = 1e-5
	p['lraspp'] = 1e-5
	p['lrpro'] = 1e-5
	p['lrdecoder'] = 1e-5
	p['lrother'] = 1e-5
	p['wd'] = 5e-4  # Weight decay
	p['momentum'] = 0.9  # Momentum
	p['epoch_size'] = opts.step  # How many epochs to change learning rate
	p['num_workers'] = opts.numworker
	backbone = 'xception'  # Use xception or resnet as feature extractor,
	nEpochs = opts.epochs

	resume_epoch = opts.resume_epoch

	max_id = 0
	save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
	exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]
	runs = glob.glob(os.path.join(save_dir_root, 'run_cihp', 'run_*'))
	for r in runs:
		run_id = int(r.split('_')[-1])
		if run_id >= max_id:
			max_id = run_id + 1
	save_dir = os.path.join(save_dir_root, 'run_cihp', 'run_' + str(max_id))

	print(save_dir)

	# Network definition
	net_ = grapy_net.GrapyMutualLearning(os=16, hidden_layers=opts.hidden_graph_layers)

	modelName = 'deeplabv3plus-' + backbone + '-voc' + datetime.now().strftime('%b%d_%H-%M-%S')
	criterion = util.cross_entropy2d

	log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
	writer = SummaryWriter(log_dir=log_dir)
	writer.add_text('load model', opts.loadmodel, 1)
	writer.add_text('setting', sys.argv[0], 1)

	# Use the following optimizer
	optimizer = optim.SGD(net_.parameters(), lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])

	composed_transforms_tr = transforms.Compose([
		tr.RandomSized_new(512),
		tr.Normalize_xception_tf(),
		tr.ToTensor_()])

	composed_transforms_ts = transforms.Compose([
		tr.Normalize_xception_tf(),
		tr.ToTensor_()])

	composed_transforms_ts_flip = transforms.Compose([
		tr.HorizontalFlip(),
		tr.Normalize_xception_tf(),
		tr.ToTensor_()])

	if opts.train_mode == 'cihp_pascal_atr':
		all_train = cihp_pascal_atr.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
		num_cihp, num_pascal, num_atr = all_train.get_class_num()

		voc_val = atr.VOCSegmentation(split='val', transform=composed_transforms_ts)
		voc_val_flip = atr.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

		ss = sam.Sampler_uni(num_cihp, num_pascal, num_atr, opts.batch)

		trainloader = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=18, sampler=ss, drop_last=True)

	elif opts.train_mode == 'cihp_pascal_atr_1_1_1':
		all_train = cihp_pascal_atr.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
		num_cihp, num_pascal, num_atr = all_train.get_class_num()

		voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
		voc_val_flip = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

		ss_uni = sam.Sampler_uni(num_cihp, num_pascal, num_atr, opts.batch, balance_id=1)

		trainloader = DataLoader(all_train, batch_size=p['trainBatch'], shuffle=False, num_workers=1, sampler=ss_uni, drop_last=True)

	elif opts.train_mode == 'cihp':
		voc_train = cihp.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
		voc_val = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts)
		voc_val_flip = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

		trainloader = DataLoader(voc_train, batch_size=p['trainBatch'], shuffle=True, num_workers=18, drop_last=True)

	elif opts.train_mode == 'pascal':

		# here we train without flip but test with flip
		voc_train = pascal_flip.VOCSegmentation(split='train', transform=composed_transforms_tr)
		voc_val = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts)
		voc_val_flip = pascal.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

		trainloader = DataLoader(voc_train, batch_size=p['trainBatch'], shuffle=True, num_workers=18, drop_last=True)

	elif opts.train_mode == 'atr':

		# here we train without flip but test with flip
		voc_train = atr.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
		voc_val = atr.VOCSegmentation(split='val', transform=composed_transforms_ts)
		voc_val_flip = atr.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)

		trainloader = DataLoader(voc_train, batch_size=p['trainBatch'], shuffle=True, num_workers=18, drop_last=True)

	else:
		raise NotImplementedError

	if not opts.loadmodel == '':
		x = torch.load(opts.loadmodel)
		net_.load_state_dict_new(x, strict=False)
		print('load model:', opts.loadmodel)
	else:
		print('no model loaded.')

	if not opts.resume_model == '':
		x = torch.load(opts.resume_model)
		net_.load_state_dict(x)
		print('resume model:', opts.resume_model)

	else:
		print('we are not resuming from any model')

	# We only validate on pascal dataset to save time
	testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=3)
	testloader_flip = DataLoader(voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=3)

	num_img_tr = len(trainloader)
	num_img_ts = len(testloader)

	# Set the category relations
	c1, c2, p1, p2, a1, a2 = [[0], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]],\
							 [[0], [1, 2, 4, 13], [5, 6, 7, 10, 11, 12], [3, 14, 15], [8, 9, 16, 17, 18, 19]], \
							 [[0], [1, 2, 3, 4, 5, 6]], [[0], [1], [2], [3, 4], [5, 6]], [[0], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]],\
							 [[0], [1, 2, 3, 11], [4, 5, 7, 8, 16, 17], [14, 15], [6, 9, 10, 12, 13]]

	net_.set_category_list(c1, c2, p1, p2, a1, a2)
	if gpu_id >= 0:
		# torch.cuda.set_device(device=gpu_id)
		net_.cuda()

	running_loss_tr = 0.0
	running_loss_ts = 0.0

	running_loss_tr_main = 0.0
	running_loss_tr_aux = 0.0
	aveGrad = 0
	global_step = 0
	miou = 0
	cur_miou = 0
	print("Training Network")

	net = torch.nn.DataParallel(net_)

	# Main Training and Testing Loop
	for epoch in range(resume_epoch, nEpochs):
		start_time = timeit.default_timer()

		if opts.poly:
			if epoch % p['epoch_size'] == p['epoch_size'] - 1:
				lr_ = util.lr_poly(p['lr'], epoch, nEpochs, 0.9)
				optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd'])
				writer.add_scalar('data/lr_', lr_, epoch)
				print('(poly lr policy) learning rate: ', lr_)

		net.train()
		for ii, sample_batched in enumerate(trainloader):

			inputs, labels = sample_batched['image'], sample_batched['label']
			# Forward-Backward of the mini-batch
			inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
			global_step += inputs.data.shape[0]

			if gpu_id >= 0:
				inputs, labels = inputs.cuda(), labels.cuda()

			if opts.train_mode == 'cihp_pascal_atr' or opts.train_mode == 'cihp_pascal_atr_1_1_1':
				num_dataset_lbl = sample_batched['pascal'][0].item()

			elif opts.train_mode == 'cihp':
				num_dataset_lbl = 0

			elif opts.train_mode == 'pascal':
				num_dataset_lbl = 1

			else:
				num_dataset_lbl = 2

			outputs, outputs_aux = net.forward((inputs, num_dataset_lbl))

			# print(inputs.shape, labels.shape, outputs.shape, outputs_aux.shape)

			loss_main = criterion(outputs, labels, batch_average=True)
			loss_aux = criterion(outputs_aux, labels, batch_average=True)

			loss = opts.beta_main * loss_main + opts.beta_aux * loss_aux

			running_loss_tr_main += loss_main.item()
			running_loss_tr_aux += loss_aux.item()
			running_loss_tr += loss.item()

			# Print stuff
			if ii % num_img_tr == (num_img_tr - 1):
				running_loss_tr = running_loss_tr / num_img_tr
				running_loss_tr_aux = running_loss_tr_aux / num_img_tr
				running_loss_tr_main = running_loss_tr_main / num_img_tr

				writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)

				writer.add_scalars('data/scalar_group', {'loss': running_loss_tr_main,
														 'loss_aux': running_loss_tr_aux}, epoch)

				print('[Epoch: %d, numImages: %5d]' % (epoch, ii * p['trainBatch'] + inputs.data.shape[0]))
				print('Loss: %f' % running_loss_tr)
				running_loss_tr = 0
				stop_time = timeit.default_timer()
				print("Execution time: " + str(stop_time - start_time) + "\n")

			# Backward the averaged gradient
			loss /= p['nAveGrad']
			loss.backward()
			aveGrad += 1

			# Update the weights once in p['nAveGrad'] forward passes
			if aveGrad % p['nAveGrad'] == 0:
				writer.add_scalar('data/total_loss_iter', loss.item(), ii + num_img_tr * epoch)

				if num_dataset_lbl == 0:
					writer.add_scalar('data/total_loss_iter_cihp', loss.item(), global_step)
				if num_dataset_lbl == 1:
					writer.add_scalar('data/total_loss_iter_pascal', loss.item(), global_step)
				if num_dataset_lbl == 2:
					writer.add_scalar('data/total_loss_iter_atr', loss.item(), global_step)

				optimizer.step()
				optimizer.zero_grad()
				aveGrad = 0

			# Show 10 * 3 image results each epoch
			# print(ii, (num_img_tr * 10), (ii % (num_img_tr * 10) == 0))
			if ii % (num_img_tr * 10) == 0:
				grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
				writer.add_image('Image', grid_image, global_step)
				grid_image = make_grid(
					util.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3,
					normalize=False,
					range=(0, 255))
				writer.add_image('Predicted label', grid_image, global_step)
				grid_image = make_grid(
					util.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3,
					normalize=False, range=(0, 255))
				writer.add_image('Groundtruth label', grid_image, global_step)
			print('loss is ', loss.cpu().item(), flush=True)

		# Save the model
		# One testing epoch
		if useTest and epoch % nTestInterval == (nTestInterval - 1):

			cur_miou = validation(net_, testloader=testloader, testloader_flip=testloader_flip, classes=opts.classes,
								epoch=epoch, writer=writer, criterion=criterion, dataset=opts.train_mode)

		torch.cuda.empty_cache()

		if (epoch % snapshot) == snapshot - 1:

			torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch' + '_current' + '.pth'))
			print("Save model at {}\n".format(
				os.path.join(save_dir, 'models', modelName + str(epoch) + '_epoch-' + str(epoch) + '.pth as our current model')))

			if cur_miou > miou:
				miou = cur_miou
				torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_best' + '.pth'))
				print("Save model at {}\n".format(
					os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth as our best model')))

		torch.cuda.empty_cache()
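
Every training loop in these examples divides the loss by p['nAveGrad'] and only steps the optimizer once every nAveGrad backward passes, i.e. plain gradient accumulation (with nAveGrad = 1 here, every pass updates the weights). Stripped of the logging, the pattern reduces to the sketch below; the model, data and hyperparameters are placeholders, not the networks used above:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
n_ave_grad, ave_grad = 4, 0  # corresponds to p['nAveGrad'] / aveGrad above

for step in range(16):
    inputs, labels = torch.randn(4, 8), torch.randint(0, 2, (4,))
    loss = criterion(model(inputs), labels) / n_ave_grad  # average over n_ave_grad iterations
    loss.backward()                       # gradients accumulate in .grad
    ave_grad += 1
    if ave_grad % n_ave_grad == 0:        # update weights once per n_ave_grad passes
        optimizer.step()
        optimizer.zero_grad()
        ave_grad = 0
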
Example No. 6
def main(opts):
    # Set parameters
    p = OrderedDict()  # Parameters to include in report
    p["trainBatch"] = opts.batch  # Training batch size
    testBatch = 1  # Testing batch size
    useTest = True  # See evolution of the test set when training
    nTestInterval = opts.testInterval  # Run on test set every nTestInterval epochs
    snapshot = 1  # Store a model every snapshot epochs
    p["nAveGrad"] = 1  # Average the gradient of several iterations
    p["lr"] = opts.lr  # Learning rate
    p["wd"] = 5e-4  # Weight decay
    p["momentum"] = 0.9  # Momentum
    p["epoch_size"] = opts.step  # How many epochs to change learning rate
    p["num_workers"] = opts.numworker
    model_path = opts.pretrainedModel
    backbone = "xception"  # Use xception or resnet as feature extractor
    nEpochs = opts.epochs

    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    exp_name = os.path.dirname(os.path.abspath(__file__)).split("/")[-1]
    runs = glob.glob(os.path.join(save_dir_root, "run", "run_*"))
    for r in runs:
        run_id = int(r.split("_")[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
    save_dir = os.path.join(save_dir_root, "run", "run_" + str(max_id))

    # Network definition
    if backbone == "xception":
        net_ = deeplab_xception_universal.deeplab_xception_end2end_3d(
            n_classes=20,
            os=16,
            hidden_layers=opts.hidden_layers,
            source_classes=7,
            middle_classes=18,
        )
    elif backbone == "resnet":
        # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    modelName = (
        "deeplabv3plus-" + backbone + "-voc" + datetime.now().strftime("%b%d_%H-%M-%S")
    )
    criterion = ut.cross_entropy2d

    if gpu_id >= 0:
        # torch.cuda.set_device(device=gpu_id)
        net_.cuda()

    # net load weights
    if not model_path == "":
        x = torch.load(model_path)
        net_.load_state_dict_new(x)
        print("load pretrainedModel.")
    else:
        print("no pretrainedModel.")

    if not opts.loadmodel == "":
        x = torch.load(opts.loadmodel)
        net_.load_source_model(x)
        print("load model:", opts.loadmodel)
    else:
        print("no trained model load !!!!!!!!")

    log_dir = os.path.join(
        save_dir,
        "models",
        datetime.now().strftime("%b%d_%H-%M-%S") + "_" + socket.gethostname(),
    )
    writer = SummaryWriter(log_dir=log_dir)
    writer.add_text("load model", opts.loadmodel, 1)
    writer.add_text("setting", sys.argv[0], 1)

    # Use the following optimizer
    optimizer = optim.SGD(
        net_.parameters(), lr=p["lr"], momentum=p["momentum"], weight_decay=p["wd"]
    )

    composed_transforms_tr = transforms.Compose(
        [tr.RandomSized_new(512), tr.Normalize_xception_tf(), tr.ToTensor_()]
    )

    composed_transforms_ts = transforms.Compose(
        [tr.Normalize_xception_tf(), tr.ToTensor_()]
    )

    composed_transforms_ts_flip = transforms.Compose(
        [tr.HorizontalFlip(), tr.Normalize_xception_tf(), tr.ToTensor_()]
    )

    all_train = cihp_pascal_atr.VOCSegmentation(
        split="train", transform=composed_transforms_tr, flip=True
    )
    voc_val = pascal.VOCSegmentation(split="val", transform=composed_transforms_ts)
    voc_val_flip = pascal.VOCSegmentation(
        split="val", transform=composed_transforms_ts_flip
    )

    num_cihp, num_pascal, num_atr = all_train.get_class_num()
    ss = sam.Sampler_uni(num_cihp, num_pascal, num_atr, opts.batch)
    # balance datasets based on pascal
    ss_balanced = sam.Sampler_uni(
        num_cihp, num_pascal, num_atr, opts.batch, balance_id=1
    )

    trainloader = DataLoader(
        all_train,
        batch_size=p["trainBatch"],
        shuffle=False,
        num_workers=p["num_workers"],
        sampler=ss,
        drop_last=True,
    )
    trainloader_balanced = DataLoader(
        all_train,
        batch_size=p["trainBatch"],
        shuffle=False,
        num_workers=p["num_workers"],
        sampler=ss_balanced,
        drop_last=True,
    )
    testloader = DataLoader(
        voc_val, batch_size=testBatch, shuffle=False, num_workers=p["num_workers"]
    )
    testloader_flip = DataLoader(
        voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=p["num_workers"]
    )
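    # shuffle=False because ordering is delegated to the samplers above, and
    # drop_last=True keeps every training batch at full size; the validation
    # loaders run with batch size testBatch (= 1) and no sampler.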

    num_img_tr = len(trainloader)
    num_img_balanced = len(trainloader_balanced)
    num_img_ts = len(testloader)
    running_loss_tr = 0.0
    running_loss_tr_atr = 0.0
    running_loss_ts = 0.0
    aveGrad = 0
    global_step = 0
    print("Training Network")
    net = torch.nn.DataParallel(net_)

    id_list = torch.LongTensor(range(opts.batch))
    pascal_iter = int(num_img_tr // opts.batch)

    # Get graphs
    train_graph, test_graph = get_graphs(opts)
    adj1, adj2, adj3, adj4, adj5, adj6 = train_graph
    adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test = test_graph
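    # Six adjacency matrices drive the graph-transfer heads (naming follows the
    # forward() keywords used below): adj1 = target (CIHP) part graph,
    # adj2 = source (PASCAL) part graph, adj4 = middle (ATR) part graph,
    # adj3 = source<->target transfer, adj5 = source<->middle transfer,
    # adj6 = target<->middle transfer; .transpose(2, 3) gives the reverse
    # direction of each transfer graph.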

    # Main Training and Testing Loop
    for epoch in range(resume_epoch, int(1.5 * nEpochs)):
        start_time = timeit.default_timer()

        if epoch % p["epoch_size"] == p["epoch_size"] - 1 and epoch < nEpochs:
            lr_ = ut.lr_poly(p["lr"], epoch, nEpochs, 0.9)
            optimizer = optim.SGD(
                net_.parameters(), lr=lr_, momentum=p["momentum"], weight_decay=p["wd"]
            )
            print("(poly lr policy) learning rate: ", lr_)
            writer.add_scalar("data/lr_", lr_, epoch)
        elif epoch % p["epoch_size"] == p["epoch_size"] - 1 and epoch > nEpochs:
            lr_ = ut.lr_poly(p["lr"], epoch - nEpochs, int(0.5 * nEpochs), 0.9)
            optimizer = optim.SGD(
                net_.parameters(), lr=lr_, momentum=p["momentum"], weight_decay=p["wd"]
            )
            print("(poly lr policy) learning rate: ", lr_)
            writer.add_scalar("data/lr_", lr_, epoch)

        net_.train()
        if epoch < nEpochs:
            for ii, sample_batched in enumerate(trainloader):
                inputs, labels = sample_batched["image"], sample_batched["label"]
                dataset_lbl = sample_batched["pascal"][0].item()
                # Forward-Backward of the mini-batch
                inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
                global_step += 1

                if gpu_id >= 0:
                    inputs, labels = inputs.cuda(), labels.cuda()

                if dataset_lbl == 0:
                    # dataset 0: CIHP, supervises the target branch
                    _, outputs, _ = net.forward(
                        None,
                        input_target=inputs,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                elif dataset_lbl == 1:
                    # dataset 1: PASCAL, supervises the source branch
                    outputs, _, _ = net.forward(
                        inputs,
                        input_target=None,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                else:
                    # dataset 2: ATR, supervises the middle branch
                    _, _, outputs = net.forward(
                        None,
                        input_target=None,
                        input_middle=inputs,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                # print(sample_batched['pascal'])
                # print(outputs.size(),)
                # print(labels)
                loss = criterion(outputs, labels, batch_average=True)
                running_loss_tr += loss.item()

                # End-of-epoch logging
                if ii % num_img_tr == (num_img_tr - 1):
                    running_loss_tr = running_loss_tr / num_img_tr
                    writer.add_scalar("data/total_loss_epoch", running_loss_tr, epoch)
                    print(
                        "[Epoch: %d, numImages: %5d]"
                        % (epoch, ii * p["trainBatch"] + inputs.data.shape[0])
                    )
                    print("Loss: %f" % running_loss_tr)
                    running_loss_tr = 0
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) + "\n")

                # Backward the averaged gradient
                loss /= p["nAveGrad"]
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
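                # (with nAveGrad = 1 this is a plain per-iteration update; a
                # larger value would accumulate gradients over nAveGrad
                # mini-batches before each optimizer.step(), emulating a
                # larger effective batch size)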
                if aveGrad % p["nAveGrad"] == 0:
                    writer.add_scalar("data/total_loss_iter", loss.item(), global_step)
                    if dataset_lbl == 0:
                        writer.add_scalar(
                            "data/total_loss_iter_cihp", loss.item(), global_step
                        )
                    if dataset_lbl == 1:
                        writer.add_scalar(
                            "data/total_loss_iter_pascal", loss.item(), global_step
                        )
                    if dataset_lbl == 2:
                        writer.add_scalar(
                            "data/total_loss_iter_atr", loss.item(), global_step
                        )
                    optimizer.step()
                    optimizer.zero_grad()
                    # optimizer_gcn.step()
                    # optimizer_gcn.zero_grad()
                    aveGrad = 0

                # Log 3-image grids (input / prediction / ground truth) roughly 10 times per epoch
                if ii % (num_img_tr // 10) == 0:
                    grid_image = make_grid(
                        inputs[:3].clone().cpu().data, 3, normalize=True
                    )
                    writer.add_image("Image", grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(
                            torch.max(outputs[:3], 1)[1].detach().cpu().numpy()
                        ),
                        3,
                        normalize=False,
                        range=(0, 255),
                    )
                    writer.add_image("Predicted label", grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(
                            torch.squeeze(labels[:3], 1).detach().cpu().numpy()
                        ),
                        3,
                        normalize=False,
                        range=(0, 255),
                    )
                    writer.add_image("Groundtruth label", grid_image, global_step)

                print("loss is ", loss.cpu().item(), flush=True)
        else:
            # After nEpochs, train with the balanced sampler so the datasets are represented more evenly
            for ii, sample_batched in enumerate(trainloader_balanced):
                inputs, labels = sample_batched["image"], sample_batched["label"]
                dataset_lbl = sample_batched["pascal"][0].item()
                # Forward-Backward of the mini-batch
                inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
                global_step += 1

                if gpu_id >= 0:
                    inputs, labels = inputs.cuda(), labels.cuda()

                if dataset_lbl == 0:
                    # dataset 0: CIHP, supervises the target branch
                    _, outputs, _ = net.forward(
                        None,
                        input_target=inputs,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                elif dataset_lbl == 1:
                    # dataset 1: PASCAL, supervises the source branch
                    outputs, _, _ = net.forward(
                        inputs,
                        input_target=None,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                else:
                    # dataset 2: ATR, supervises the middle branch
                    _, _, outputs = net.forward(
                        None,
                        input_target=None,
                        input_middle=inputs,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                # print(sample_batched['pascal'])
                # print(outputs.size(),)
                # print(labels)
                loss = criterion(outputs, labels, batch_average=True)
                running_loss_tr += loss.item()

                # End-of-epoch logging
                if ii % num_img_balanced == (num_img_balanced - 1):
                    running_loss_tr = running_loss_tr / num_img_balanced
                    writer.add_scalar("data/total_loss_epoch", running_loss_tr, epoch)
                    print(
                        "[Epoch: %d, numImages: %5d]"
                        % (epoch, ii * p["trainBatch"] + inputs.data.shape[0])
                    )
                    print("Loss: %f" % running_loss_tr)
                    running_loss_tr = 0
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) + "\n")

                # Backward the averaged gradient
                loss /= p["nAveGrad"]
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
                if aveGrad % p["nAveGrad"] == 0:
                    writer.add_scalar("data/total_loss_iter", loss.item(), global_step)
                    if dataset_lbl == 0:
                        writer.add_scalar(
                            "data/total_loss_iter_cihp", loss.item(), global_step
                        )
                    if dataset_lbl == 1:
                        writer.add_scalar(
                            "data/total_loss_iter_pascal", loss.item(), global_step
                        )
                    if dataset_lbl == 2:
                        writer.add_scalar(
                            "data/total_loss_iter_atr", loss.item(), global_step
                        )
                    optimizer.step()
                    optimizer.zero_grad()

                    aveGrad = 0

                # Log 3-image grids (input / prediction / ground truth) roughly 10 times per epoch
                if ii % (num_img_balanced // 10) == 0:
                    grid_image = make_grid(
                        inputs[:3].clone().cpu().data, 3, normalize=True
                    )
                    writer.add_image("Image", grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(
                            torch.max(outputs[:3], 1)[1].detach().cpu().numpy()
                        ),
                        3,
                        normalize=False,
                        range=(0, 255),
                    )
                    writer.add_image("Predicted label", grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(
                            torch.squeeze(labels[:3], 1).detach().cpu().numpy()
                        ),
                        3,
                        normalize=False,
                        range=(0, 255),
                    )
                    writer.add_image("Groundtruth label", grid_image, global_step)

                print("loss is ", loss.cpu().item(), flush=True)

        # Save the model
        if (epoch % snapshot) == snapshot - 1:
            torch.save(
                net_.state_dict(),
                os.path.join(
                    save_dir, "models", modelName + "_epoch-" + str(epoch) + ".pth"
                ),
            )
            print(
                "Save model at {}\n".format(
                    os.path.join(
                        save_dir, "models", modelName + "_epoch-" + str(epoch) + ".pth"
                    )
                )
            )
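            # The snapshot can later be restored with, schematically,
            #   net_.load_state_dict(torch.load("<save_dir>/models/<modelName>_epoch-N.pth"))
            # where N is the epoch index used above.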

        # One testing epoch
        if useTest and epoch % nTestInterval == (nTestInterval - 1):
            val_pascal(
                net_=net_,
                testloader=testloader,
                testloader_flip=testloader_flip,
                test_graph=test_graph,
                criterion=criterion,
                epoch=epoch,
                writer=writer,
            )
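            # val_pascal receives both the plain and the flipped val loaders;
            # it is assumed to average the two predictions (standard flip
            # test-time augmentation) before computing the validation metrics.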