Example #1
0
def main():
    """Train the segmentation network adversarially against an FCDiscriminator.

    All settings are read from the module-level ``args`` namespace. Each outer
    iteration accumulates gradients over ``args.iter_size`` sub-batches for:

    * the segmentation network: supervised ``loss_calc`` plus an adversarial
      term from the discriminator, and optionally semi-supervised losses on
      the unlabeled ("remain") split once ``i_iter`` reaches
      ``args.semi_start_adv`` / ``args.semi_start``;
    * the discriminator: detached predictions are labelled ``pred_label`` and
      one-hot ground truth is labelled ``gt_label``.

    Snapshots of both networks are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and after the final step.

    Raises:
        ValueError: if semi-supervised losses are enabled but no unlabeled
            split exists (``args.partial_data`` is None).
    """

    def _next_batch(loader_iter, loader):
        """Return ``(iterator, batch)``, restarting ``loader`` when it is exhausted."""
        try:
            _, batch = next(loader_iter)
        except StopIteration:
            loader_iter = enumerate(loader)
            _, batch = next(loader_iter)
        return loader_iter, batch

    def _scalar(tensor):
        """Return a one-element loss tensor as a Python float.

        ``tensor.data.cpu().numpy()[0]`` fails for the 0-dim tensors produced
        by PyTorch >= 0.4; flattening first works for both 0-dim and
        1-element tensors.
        """
        return float(tensor.data.cpu().numpy().reshape(-1)[0])

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Only copy the params that exist in the current model with a matching
    # shape (caffe-like partial restore).
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))

    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))

    model_D = nn.DataParallel(model_D)
    model_D.train()
    model_D.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    # The unlabeled ("remain") split only exists when training on partial data.
    trainloader_remain = None
    trainloader_remain_iter = None

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # NOTE: pickle can execute arbitrary code; only load trusted id files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # Persist the split so the run can be reproduced via args.partial_id.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, sampler=train_sampler, num_workers=3, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size, sampler=train_remain_sampler, num_workers=3, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, sampler=train_gt_sampler, num_workers=3, pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    # NOTE(review): input_size is (h, w) but the size passed here is (w, h);
    # this only matters for non-square crops — confirm against VOCDataSet's
    # crop_size convention.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # File-name prefix shared by every snapshot written below.
    snapshot_prefix = 'VOC_' + os.path.abspath(__file__).split('/')[-1].split('.')[0] + '_'

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # Outputs of the semi-supervised pass; they stay None when that
            # pass is skipped so the D_remain concatenation can check them.
            pred_remain = None
            ignore_mask_remain = None

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0) and i_iter >= args.semi_start_adv:
                if trainloader_remain is None:
                    raise ValueError('semi-supervised losses need args.partial_data to be set '
                                     'so that an unlabeled split exists')
                trainloader_remain_iter, batch = _next_batch(trainloader_remain_iter, trainloader_remain)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda()

                pred = interp(model(images))
                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred, dim=1)))
                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                ignore_mask_remain = np.zeros(D_out_sigmoid.shape, dtype=bool)

                loss_semi_adv = args.lambda_semi_adv * bce_loss(D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv / args.iter_size

                # Log the un-weighted value; skip when the weight is 0 to avoid
                # dividing by zero.
                if args.lambda_semi_adv != 0:
                    loss_semi_adv_value += _scalar(loss_semi_adv) / args.lambda_semi_adv

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # Pseudo-label ignore mask: pixels whose top softmax
                    # probability is below 0.999999 are ignored. The
                    # discriminator-based mask (D_out_sigmoid < args.mask_T)
                    # is not applied. float64 keeps the threshold comparison
                    # in double precision.
                    confidence = F.softmax(pred, dim=1).data.cpu().numpy().max(axis=1).astype(np.float64)
                    semi_ignore_mask = (confidence < 0.999999)
                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        # No confident pixels: no pseudo-label loss, but the
                        # adversarial term must still be back-propagated.
                        loss_semi_adv.backward()
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(pred, semi_gt)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += _scalar(loss_semi) / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            # train with source
            trainloader_iter, batch = _next_batch(trainloader_iter, trainloader)

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization over the accumulated sub-batches
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += _scalar(loss_seg) / args.iter_size
            loss_adv_pred_value += _scalar(loss_adv_pred) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            if args.D_remain and pred_remain is not None:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += _scalar(loss_D)

            # train with gt
            trainloader_gt_iter, batch = _next_batch(trainloader_gt_iter, trainloader_gt)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda()
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += _scalar(loss_D)

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'.format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value, loss_D_value, loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, snapshot_prefix + str(args.num_steps) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, snapshot_prefix + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, snapshot_prefix + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, snapshot_prefix + str(i_iter) + '_D.pth'))

    # NOTE(review): ``start`` is assumed to be a module-level timer value.
    end = timeit.default_timer()
    print(end - start, 'seconds')
Example #2
0
def main():
    """Train the DeepLab segmentation network on liver CT slices (LITS data).

    Only the supervised segmentation loss is optimised: no discriminator is
    trained and no semi-supervised loss is used. Batches come from
    ``LiverDataset.get_next_batch()``. Each outer iteration accumulates
    gradients over ``args.iter_size`` sub-batches and then prints the loss and
    a foreground Dice score computed over those sub-batches. Snapshots are
    written to ``args.snapshot_dir``.

    NOTE(review): ``args``, ``RESTORE_FLAG``, ``iter_start`` and ``start`` are
    not defined in this function and are assumed to be module-level names.
    """
    from dataset.LiverDataset.liver_dataset import LiverDataset

    # Dataset configuration.
    user_name = 'give'
    validation_interval = 800
    batch_size = 1
    n_neighboringslices = 5
    input_size = 400
    slice_type = 'axial'
    oversample = False
    label_of_interest = 1
    label_required = 0
    max_slice_tries_val = 0
    max_slice_tries_train = 2
    fuse_labels = True
    apply_crop = False

    train_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_2"
    test_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_1"
    train_dataset = LiverDataset(data_dir=train_data_dir,
                                 slice_type=slice_type,
                                 n_neighboringslices=n_neighboringslices,
                                 input_size=input_size,
                                 oversample=oversample,
                                 label_of_interest=label_of_interest,
                                 label_required=label_required,
                                 max_slice_tries=max_slice_tries_train,
                                 fuse_labels=fuse_labels,
                                 apply_crop=apply_crop,
                                 interval=validation_interval,
                                 is_training=True,
                                 batch_size=batch_size,
                                 data_augmentation=False)
    # NOTE: val_dataset is built but not consumed; there is no validation pass.
    val_dataset = LiverDataset(data_dir=test_data_dir,
                               slice_type=slice_type,
                               n_neighboringslices=n_neighboringslices,
                               input_size=input_size,
                               oversample=oversample,
                               label_of_interest=label_of_interest,
                               label_required=label_required,
                               max_slice_tries=max_slice_tries_val,
                               fuse_labels=fuse_labels,
                               apply_crop=apply_crop,
                               interval=validation_interval,
                               is_training=False,
                               batch_size=batch_size)

    perfix_name = 'Liver'

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes,
                        input_channel=1,
                        slice_num=n_neighboringslices,
                        gpu_id=args.gpu)

    if RESTORE_FLAG:
        # load pretrained parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Only copy the params that exist in the current model with a
        # matching shape (caffe-like partial restore).
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            print(name)
            if name in saved_state_dict and param.size() == saved_state_dict[name].size():
                new_params[name].copy_(saved_state_dict[name])
                print('copy {}'.format(name))
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # optimizer for segmentation network; model.optim_parameters(args)
    # handles per-layer learning-rate settings.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # bilinear upsampling of the logits back to the input resolution
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size, input_size),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size, input_size), mode='bilinear')

    for i_iter in range(iter_start, args.num_steps):

        loss_seg_value = 0
        # Foreground pixel counts for the Dice score of this iteration.
        num_prediction = 0
        num_ground_truth = 0
        num_intersection = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):
            batch_image, batch_label = train_dataset.get_next_batch()
            # (N, H, W, C) -> (N, C, H, W) as expected by the network.
            batch_image = np.transpose(batch_image, axes=(0, 3, 1, 2))
            images = Variable(torch.Tensor(batch_image)).cuda(args.gpu)

            pred = interp(model(images))
            # Per-pixel class = argmax over the channel axis.
            pred_label_ny = np.squeeze(np.argmax(pred.data.cpu().numpy(), axis=1))

            # Count foreground (label >= 1) consistently for prediction,
            # ground truth and their intersection.
            gt_foreground = batch_label >= 1
            pred_foreground = pred_label_ny >= 1
            num_prediction += np.count_nonzero(pred_foreground)
            num_ground_truth += np.count_nonzero(gt_foreground)
            num_intersection += np.count_nonzero(np.logical_and(gt_foreground, pred_foreground))

            loss_seg = loss_calc(pred, batch_label, args.gpu)

            # proper normalization over the accumulated sub-batches
            loss = loss_seg / args.iter_size
            loss.backward()
            # reshape(-1)[0] works for both 0-dim and 1-element loss tensors.
            loss_seg_value += float(loss_seg.data.cpu().numpy().reshape(-1)[0]) / args.iter_size

        optimizer.step()
        dice = (2 * num_intersection + 1e-7) / (num_prediction +
                                                num_ground_truth + 1e-7)
        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value))
        print(
            'dice: %.4f, num_prediction: %d, num_ground_truth: %d, num_intersection: %d'
            % (dice, num_prediction, num_ground_truth, num_intersection))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         perfix_name + str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_model(model, args.snapshot_dir, perfix_name, i_iter, 2)

    end = timeit.default_timer()
    print(end - start, 'seconds')
Example #3
0
def main():
    """Train a segmentation network with an adversarial discriminator.

    Configuration is read from the module-level ``args`` namespace and the
    module-level constants IMG_DIRECTORY, MASK_DIRECTORY, TCGA_DIRECTORY,
    NUM_CLASSES and DIR_CHECKPOINTS (all defined outside this function).

    Each of ``args.num_steps`` iterations runs ``args.iter_size`` sub-steps:

    * semi-supervised step on an unlabeled batch (when a semi weight is
      positive and ``i_iter >= args.semi_start_adv``): an adversarial loss
      on the prediction and, from ``args.semi_start`` on, a self-training
      loss on pixels the discriminator trusts (``sigmoid(D) >= args.mask_T``);
    * supervised cross-entropy step on a labeled batch, plus an adversarial
      term once ``i_iter > args.semi_start_adv``;
    * every third iteration past ``args.semi_start_adv``, a discriminator
      update on detached predictions (noisy "fake" targets in [0, 0.1)) and
      on one-hot ground truth (noisy "real" targets in [0.9, 1.0)).

    Every 1000 iterations a checkpoint is saved (if ``args.save_cp``) and the
    network is evaluated on the validation split; the best accuracy seen so
    far is appended to ``result.txt``.

    Raises:
        ValueError: if ``args.mod`` is not a supported architecture name.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    cudnn.enabled = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # initialize parameters
    num_steps = args.num_steps
    batch_size = args.batch_size
    lr = args.lr
    save_cp = args.save_cp
    img_scale = args.scale
    val_percent = args.val / 100

    # data input: labeled data split into train/val, plus an unlabeled set
    dataset = BasicDataset(IMG_DIRECTORY, MASK_DIRECTORY, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    tcga_dataset = UnlabeledDataset(TCGA_DIRECTORY)
    n_unlabeled = len(tcga_dataset)

    # logging setup
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logging.info('Using device %s' % str(device))
    logging.info('Network %s' % args.mod)
    logging.info('''Starting training:
            Num_steps:          %.2f
            Batch size:      %.2f
            Learning rate:   %.4f
            Training size:   %.0f
            Validation size: %.0f
            Unlabeled size:  %.0f
            Checkpoints:     %s
            Device:          %s
            Scale:           %.2f
        ''' % (num_steps, batch_size, lr, n_train, n_val, n_unlabeled,
               str(save_cp), str(device.type), img_scale))

    # create network
    if args.mod == 'unet':
        net = UNet(n_channels=3, n_classes=NUM_CLASSES)
        print('channels = %d , classes = %d' % (net.n_channels, net.n_classes))
    elif args.mod == 'modified_unet':
        net = modified_UNet(n_channels=3, n_classes=NUM_CLASSES)
        print('channels = %d , classes = %d' % (net.n_channels, net.n_classes))
    elif args.mod == 'deeplabv3':
        net = DeepLabV3(nclass=NUM_CLASSES, pretrained_base=False)
        print('channels = 3 , classes = %d' % net.nclass)
    elif args.mod == 'deeplabv3plus':
        net = DeepLabV3Plus(nclass=NUM_CLASSES, pretrained_base=False)
        print('channels = 3 , classes = %d' % net.nclass)
    elif args.mod == 'nestedunet':
        net = NestedUNet(nclass=NUM_CLASSES, deep_supervision=False)
        # The original read the misspelled attribute `net.nlass`; report the
        # class count that was actually passed to the constructor instead.
        print('channels = 3 , classes = %d' % NUM_CLASSES)
    elif args.mod == 'inception3':
        # NOTE(review): n_classes is hard-coded to 4 here rather than
        # NUM_CLASSES; confirm this is intentional.
        net = Inception3(n_classes=4,
                         inception_blocks=None,
                         init_weights=True,
                         bilinear=True)
        print('channels = 3 , classes = %d' % net.n_classes)
    else:
        # Previously fell through and failed later with an unbound `net`.
        raise ValueError('Unsupported network: %s' % args.mod)

    net.to(device=device)
    net.train()

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    if args.semi_train is None:
        train_loader = DataLoader(train,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True)
        val_loader = DataLoader(val,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=8,
                                pin_memory=True)
    else:
        # read unlabeled data and labeled data
        train_loader = DataLoader(train,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=4,
                                  pin_memory=True)
        val_loader = DataLoader(val,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=4,
                                pin_memory=True)

        trainloader_remain = DataLoader(tcga_dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=4,
                                        pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(train_loader)

    # optimizer for segmentation network; its LR follows the cosine scheduler
    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    optimizer.zero_grad()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           10000,
                                                           eta_min=1e-6,
                                                           last_epoch=-1)

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    criterion = nn.CrossEntropyLoss()  # supervised loss; stateless, built once
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    # Best validation accuracy over the whole run (was reset every iteration).
    best_acc = 0

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        # The segmentation LR is driven by `scheduler`; only D's LR is
        # adjusted manually.
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G: don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False
            for param in net.parameters():
                param.requires_grad = True

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                ) and i_iter >= args.semi_start_adv:
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    # unlabeled loader exhausted: start a new pass
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images = batch['image']
                images = images.type(torch.FloatTensor)
                images = Variable(images).cuda()

                pred = net(images)
                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred, dim=1)))
                D_out_sigmoid = torch.sigmoid(
                    D_out).data.cpu().numpy().squeeze(axis=1)

                # generator wants D to classify its prediction as "real" (1)
                targetr = torch.ones_like(D_out)
                loss_semi_adv = args.lambda_semi_adv * bce_loss(D_out, targetr)
                loss_semi_adv = loss_semi_adv / args.iter_size

                # This branch also runs when only lambda_semi > 0, so guard
                # the un-weighting division against lambda_semi_adv == 0.
                if args.lambda_semi_adv > 0:
                    loss_semi_adv_value += (loss_semi_adv.item() /
                                            args.lambda_semi_adv)

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask: pixels D does not trust
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(
                        semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        # No trusted pixels: nothing is backpropagated for
                        # this unlabeled batch (not even loss_semi_adv).
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(pred, semi_gt)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.item() / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            # train with source (labeled batch)
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(train_loader)
                _, batch = next(trainloader_iter)

            images = batch['image']
            labels = batch['mask']
            images = images.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.long)
            labels = labels.squeeze(1)

            pred = net(images)
            loss_seg = criterion(pred, labels)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            if i_iter > args.semi_start_adv:
                targetr = torch.ones_like(D_out)
                loss_adv_pred = bce_loss(D_out, targetr)
                loss = loss_seg + args.lambda_adv_pred * loss_adv_pred
                loss_adv_pred_value += loss_adv_pred.item() / args.iter_size
            else:
                loss = loss_seg

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            # NOTE(review): the step happens inside the sub-step loop while
            # zero_grad() runs once per i_iter, so with iter_size > 1 each
            # step sees accumulated gradients. Kept as in the original.
            optimizer.step()
            loss_seg_value += loss_seg.item() / args.iter_size

            # train D (every third iteration once adversarial training starts)
            if i_iter > args.semi_start_adv and i_iter % 3 == 0:
                # bring back requires_grad
                for param in net.parameters():
                    param.requires_grad = False
                for param in model_D.parameters():
                    param.requires_grad = True

                # train with pred
                pred = pred.detach()

                if args.D_remain:
                    # relies on pred_remain from this sub-step's semi branch
                    pred = torch.cat((pred, pred_remain), 0)

                D_out = interp(model_D(F.softmax(pred, dim=1)))
                # label-smoothed "fake" targets in [0, 0.1)
                targetf = 0.1 * np.random.rand(*D_out.shape)
                targetf = torch.FloatTensor(targetf).cuda()
                loss_D = bce_loss(D_out, targetf)
                loss_D = loss_D / args.iter_size / 2
                loss_D.backward()
                loss_D_value += loss_D.item()

                # train with gt: fetch another labeled batch for its masks
                try:
                    _, batch = next(trainloader_iter)
                except StopIteration:
                    trainloader_iter = enumerate(train_loader)
                    _, batch = next(trainloader_iter)

                labels_gt = batch['mask']
                D_gt_v = Variable(one_hot(labels_gt)).cuda()

                D_out = interp(model_D(D_gt_v))
                # label-smoothed "real" targets in [0.9, 1.0)
                targetr = 0.1 * np.random.rand(*D_out.shape) + 0.9
                targetr = torch.FloatTensor(targetr).cuda()
                loss_D = bce_loss(D_out, targetr)
                loss_D = loss_D / args.iter_size / 2
                loss_D.backward()
                optimizer_D.step()
                loss_D_value += loss_D.item()
        scheduler.step()

        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        # save checkpoints
        if save_cp and (i_iter % 1000) == 0 and (i_iter != 0):
            try:
                os.mkdir(DIR_CHECKPOINTS)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       DIR_CHECKPOINTS + 'i_iter_%d.pth' % (i_iter + 1))
            logging.info('Checkpoint %d saved !' % (i_iter + 1))

        if (i_iter % 1000 == 0) and (i_iter != 0):
            val_score, accuracy, dice_avr, dice_panck, dice_nuclei, dice_lcell = eval_net(
                net, val_loader, device, n_val)
            logging.info('Validation cross entropy: {}'.format(val_score))
            if accuracy > best_acc:
                best_acc = accuracy
            # `with` guarantees the handle is closed (the original referenced
            # `result_file.close` without calling it).
            with open('result.txt', 'a', encoding='utf-8') as result_file:
                result_file.write('best_acc = ' + str(best_acc) + '\n' +
                                  'iter = ' + str(i_iter) + '\n')
Example #4
0
def main():
    """Train Res_Deeplab with a discriminator feature-matching loss.

    Configuration is read from the module-level ``args`` namespace and
    IMG_MEAN. ``criterion`` and ``start`` are also read from module scope;
    they are not defined in this function. Presumably they are a module-level
    loss for D's first output and a start timestamp — TODO confirm.

    Each of ``args.num_steps`` iterations runs ``args.iter_size``
    gradient-accumulation sub-steps, then one optimizer step for G and D:

    * supervised segmentation loss on a labeled batch;
    * from ``args.adv_start`` on: a feature-matching loss.
      It compares the batch-mean of ``model_D``'s second output for
      predictions (on a batch drawn from all training ids) with the same
      quantity for one-hot ground truth, weighted by ``args.lambda_fm``.
      This is followed by a D update using ``criterion`` with fake (0) and
      real (1) targets.

    Snapshots of G and D go to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and after the final iteration.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=16,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=16,
                                         pin_memory=True)

        # Without a partial split every id is labeled, so the "all data"
        # loader is the full training loader. It was previously undefined
        # here, which raised NameError below.
        trainloader_all = trainloader
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle files must be opened in binary mode on Python 3.
            # NOTE(review): unpickling can execute arbitrary code; only
            # pass trusted id files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler_all = data.sampler.SubsetRandomSampler(train_ids)
        train_gt_sampler_all = data.sampler.SubsetRandomSampler(train_ids)
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader_all = data.DataLoader(train_dataset,
                                          batch_size=args.batch_size,
                                          sampler=train_sampler_all,
                                          num_workers=16,
                                          pin_memory=True)
        # NOTE(review): trainloader_gt_all and trainloader_remain are not
        # used by the training loop below.
        trainloader_gt_all = data.DataLoader(train_gt_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_gt_sampler_all,
                                             num_workers=16,
                                             pin_memory=True)
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=16,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=16,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=16,
                                         pin_memory=True)

    trainloader_all_iter = iter(trainloader_all)
    trainloader_iter = iter(trainloader)
    trainloader_gt_iter = iter(trainloader_gt)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_fm_value = 0
        loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G: don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source (labeled batch)
            try:
                batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = iter(trainloader)
                batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)
            loss_seg.backward()
            # .item() works on the 0-dim losses of PyTorch >= 0.4, where
            # `.numpy()[0]` raises IndexError.
            loss_seg_value += loss_seg.item() / args.iter_size

            if i_iter >= args.adv_start:

                # fm loss calc: prediction on a batch from all training ids
                try:
                    batch = next(trainloader_all_iter)
                except StopIteration:
                    # Restart the *all-data* iterator; the original rebound
                    # trainloader_iter here and re-read the exhausted one.
                    trainloader_all_iter = iter(trainloader_all)
                    batch = next(trainloader_all_iter)

                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)
                pred = interp(model(images))

                _, D_out_y_pred = model_D(F.softmax(pred, dim=1))

                # Ground-truth batch from a persistent iterator. The original
                # rebuilt iter(trainloader_gt) every step, restarting all of
                # its worker processes each time.
                try:
                    batch = next(trainloader_gt_iter)
                except StopIteration:
                    trainloader_gt_iter = iter(trainloader_gt)
                    batch = next(trainloader_gt_iter)

                _, labels_gt, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)

                _, D_out_y_gt = model_D(D_gt_v)

                fm_loss = torch.mean(
                    torch.abs(
                        torch.mean(D_out_y_gt, 0) -
                        torch.mean(D_out_y_pred, 0)))

                # Backpropagate the *weighted* fm loss so args.lambda_fm
                # affects training, not only the logged total (loss_seg was
                # already backpropagated above).
                (args.lambda_fm * fm_loss).backward()
                loss_fm_value += fm_loss.item() / args.iter_size
                loss_value += (loss_seg.item() + args.lambda_fm *
                               fm_loss.item()) / args.iter_size

                # train D: bring back requires_grad
                for param in model_D.parameters():
                    param.requires_grad = True

                # train with pred
                pred = pred.detach()

                D_out_z, _ = model_D(F.softmax(pred, dim=1))
                y_fake_ = Variable(
                    torch.zeros(D_out_z.size(0), 1).cuda(args.gpu))
                loss_D_fake = criterion(D_out_z, y_fake_)

                # train with gt: reuse the one-hot ground truth built above
                # from the same batch.
                D_out_z_gt, _ = model_D(D_gt_v)

                y_real_ = Variable(
                    torch.ones(D_out_z_gt.size(0), 1).cuda(args.gpu))

                loss_D_real = criterion(D_out_z_gt, y_real_)
                loss_D = loss_D_fake + loss_D_real
                loss_D.backward()
                loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_D = {3:.3f}'.
              format(i_iter, args.num_steps, loss_seg_value, loss_D_value))
        print('fm_loss: ', loss_fm_value, ' g_loss: ', loss_value)

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Example #5
0
def main():
    """Run a single supervised training step of ``Res_Deeplab`` on liver CT data.

    Steps:
      1. Build training and validation ``LiverDataset`` readers from
         hard-coded LITS paths under ``/home/<user_name>/...``.
      2. Build ``Res_Deeplab`` and, if ``args.restore_from`` is set, copy every
         parameter whose name and size match the checkpoint (URL or path).
      3. Read and preprocess one NIfTI volume and assemble a neighbouring-slice
         window (``image_input``); the result is not consumed further here.
      4. Fetch one batch from the training set, compute ``loss_calc`` on the
         upsampled prediction, back-propagate and take one SGD step.
      5. Print the loss and the foreground (label >= 1) Dice score.

    Only the segmentation network is trained; no discriminator is built.
    Relies on module-level names from the rest of the file (``args``,
    ``Res_Deeplab``, ``loss_calc``, ``cudnn``, ``model_zoo``, ``version``,
    ``Variable``, ``np``, ``torch``, ``nn``, ``optim``, ``os``, ``timeit``).
    """
    start = timeit.default_timer()

    from dataset.LiverDataset.liver_dataset import LiverDataset
    user_name = 'give'
    validation_interval = 800
    batch_size = 1
    n_neighboringslices = 5
    # Number of slices taken on each side of the centre slice.
    half_num_slice = n_neighboringslices // 2
    input_size = 400
    slice_type = 'axial'
    oversample = False
    label_of_interest = 1
    label_required = 0
    max_slice_tries_val = 0
    max_slice_tries_train = 2
    fuse_labels = True
    apply_crop = False

    train_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_2"
    test_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_1"
    train_dataset = LiverDataset(data_dir=train_data_dir,
                                 slice_type=slice_type,
                                 n_neighboringslices=n_neighboringslices,
                                 input_size=input_size,
                                 oversample=oversample,
                                 label_of_interest=label_of_interest,
                                 label_required=label_required,
                                 max_slice_tries=max_slice_tries_train,
                                 fuse_labels=fuse_labels,
                                 apply_crop=apply_crop,
                                 interval=validation_interval,
                                 is_training=True,
                                 batch_size=batch_size,
                                 data_augmentation=True)
    # NOTE(review): constructed but not used in this single-step variant.
    val_dataset = LiverDataset(data_dir=test_data_dir,
                               slice_type=slice_type,
                               n_neighboringslices=n_neighboringslices,
                               input_size=input_size,
                               oversample=oversample,
                               label_of_interest=label_of_interest,
                               label_required=label_required,
                               max_slice_tries=max_slice_tries_val,
                               fuse_labels=fuse_labels,
                               apply_crop=apply_crop,
                               interval=validation_interval,
                               is_training=False,
                               batch_size=batch_size)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes,
                        slice_num=n_neighboringslices,
                        gpu_id=0)
    # Gate on the same value that is loaded below.
    if args.restore_from is not None:
        # load pretrained parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # only copy the params that exist in current model (caffe-like)
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            print(name)
            if (name in saved_state_dict
                    and param.size() == saved_state_dict[name].size()):
                new_params[name].copy_(saved_state_dict[name])
                print('copy {}'.format(name))
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    os.makedirs(args.snapshot_dir, exist_ok=True)

    # optimizer for segmentation network; model.optim_parameters(args)
    # handles the per-layer learning-rate setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # Bilinear upsampling of the network output; align_corners is only
    # accepted from torch 0.4.0 onwards.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size, input_size),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size, input_size), mode='bilinear')

    loss_list = []
    # Running totals; they must exist before the `+=` updates below or Python
    # raises UnboundLocalError.
    num_prediction = 0
    num_ground_truth = 0
    num_intersection = 0
    loss_seg_value = 0
    # Only one step is performed; used in the progress printout.
    i_iter = 0

    from dataset.LiverDataset.medicalImage import preprocessing_agumentation, read_image_file
    image_path = '/home/give/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_1/volume-0.nii'
    gt_path = '/home/give/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_1/segmentation-0.nii'
    image = read_image_file(image_path)
    # NOTE(review): gt_image is read but not used in this step.
    gt_image = read_image_file(gt_path)

    original_image = np.copy(image)
    processed_image = preprocessing_agumentation(original_image, input_size)

    # For every slice, gather n_neighboringslices consecutive slices centred
    # on it (indices clamped to the volume). image_input ends up holding the
    # window for the last slice.
    num_slices = processed_image.shape[2]
    image_input = np.zeros(
        (1,) + tuple(processed_image.shape[:2]) + (n_neighboringslices,),
        dtype=processed_image.dtype)
    for slice_idx in range(num_slices):
        for j in range(n_neighboringslices):
            cur_idx = min(max(slice_idx - half_num_slice + j, 0), num_slices - 1)
            image_input[0, :, :, j] = processed_image[:, :, cur_idx]

    batch_image, batch_label = train_dataset.get_next_batch()
    # Move the last axis to position 1 (presumably NHWC -> NCHW; confirm
    # against LiverDataset.get_next_batch).
    batch_image = np.transpose(batch_image, axes=(0, 3, 1, 2))
    images = Variable(torch.Tensor(batch_image)).cuda(args.gpu)

    pred = interp(model(images))
    pred_ny = np.transpose(pred.data.cpu().numpy(), axes=(0, 2, 3, 1))
    pred_label_ny = np.squeeze(np.argmax(pred_ny, axis=3))

    # Dice counts: foreground is any label >= 1 for both prediction and
    # ground truth.
    pred_fg = pred_label_ny >= 1
    gt_fg = batch_label >= 1
    num_prediction += np.sum(np.asarray(pred_fg, np.uint8))
    num_ground_truth += np.sum(np.asarray(gt_fg, np.uint8))
    num_intersection += np.sum(
        np.asarray(np.logical_and(gt_fg, pred_fg), np.uint8))

    loss_seg = loss_calc(pred, batch_label, args.gpu)
    # proper normalization over gradient-accumulation sub-steps
    loss = loss_seg / args.iter_size
    loss.backward()
    loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
    loss_list.append(loss_seg_value)

    optimizer.step()
    dice = (2 * num_intersection + 1e-7) / (num_prediction + num_ground_truth +
                                            1e-7)
    print('exp = {}'.format(args.snapshot_dir))
    print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(
        i_iter, args.num_steps, loss_seg_value))
    print(
        'dice: %.4f, num_prediction: %d, num_ground_truth: %d, num_intersection: %d'
        % (dice, num_prediction, num_ground_truth, num_intersection))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Example #6
0
def train(log_file, arch, dataset, batch_size, iter_size, num_workers,
          partial_data, partial_data_size, partial_id, ignore_label, crop_size,
          eval_crop_size, is_training, learning_rate, learning_rate_d,
          supervised, lambda_adv_pred, lambda_semi, lambda_semi_adv, mask_t,
          semi_start, semi_start_adv, d_remain, momentum, not_restore_last,
          num_steps, power, random_mirror, random_scale, random_seed,
          restore_from, restore_from_d, eval_every, save_snapshot_every,
          snapshot_dir, weight_decay, device):
    """Train a segmentation network, optionally adversarially and semi-supervised.

    A segmentation model (``arch``) is trained on a Pascal VOC dataset. Unless
    ``supervised`` is true, an ``FCDiscriminator`` is trained alongside it:
    the segmentation net receives an adversarial loss on labelled images and,
    from iteration ``semi_start_adv`` on, adversarial and/or self-training
    losses on the unlabelled remainder of the training set.

    Args:
        log_file: Path passed to ``logger.LogFile``. If it is neither ``''``
            nor ``'none'`` and already exists, a message is printed and the
            function returns without training. ``'none'`` passes ``None``.
        arch: ``'deeplab2'``, ``'unet_resnet50'`` or ``'resnet101_deeplabv3'``.
        dataset: ``'pascal_aug'`` or ``'pascal'``.
        batch_size: Mini-batch size of the training loaders.
        iter_size: Gradient-accumulation sub-steps per optimizer step.
        num_workers: Currently unused; loader worker counts are hard-coded.
        partial_data: Fraction of the training set used as labelled data.
        partial_data_size: Absolute number of labelled images, or ``-1`` to
            use ``partial_data`` instead.
        partial_id: Optional path of a pickled list of training indices; the
            first ``partial_size`` entries are the labelled subset.
        ignore_label: Label value excluded from losses and evaluation.
        crop_size, eval_crop_size: ``"h,w"`` strings for train/eval crops.
        is_training: Currently unused.
        learning_rate, learning_rate_d: Base learning rates of the
            segmentation net and discriminator, decayed polynomially with
            exponent ``power`` over ``num_steps``.
        supervised: If true, the discriminator and all adversarial /
            semi-supervised losses are skipped.
        lambda_adv_pred: Weight of the adversarial loss on labelled data.
        lambda_semi: Weight of the self-training loss on unlabelled data.
        lambda_semi_adv: Weight of the adversarial loss on unlabelled data.
        mask_t: Unlabelled pixels whose sigmoid discriminator output is below
            this threshold are ignored by the self-training loss.
        semi_start: Iteration from which the self-training loss is applied.
        semi_start_adv: Iteration from which unlabelled data is used.
        d_remain: If true, unlabelled predictions are also fed to the
            discriminator as fake samples.
        momentum, weight_decay: SGD hyper-parameters of the segmentation net.
        not_restore_last: Currently unused.
        num_steps: Training iterations; the loop runs for
            ``i_iter`` in ``0..num_steps`` inclusive.
        random_mirror, random_scale: Training augmentation switches.
        random_seed: Seed of the labelled/unlabelled split permutation.
        restore_from: URL (``http...``) or path of a pretrained state dict;
            only parameters whose name and size match are copied.
        restore_from_d: Optional path of a discriminator state dict.
        eval_every: Evaluate mean IoU and print averaged losses every this
            many iterations.
        save_snapshot_every: Save snapshots every this many iterations.
        snapshot_dir: Output directory for snapshots and ``train_id.pkl``;
            ``None`` disables saving.
        device: Passed to ``torch.device``.
    """
    settings = locals().copy()

    # NOTE(review): cv2 is not used in this function; it may be imported
    # first on purpose (before torch) — confirm before removing.
    import cv2
    import os
    import os.path as osp
    import pickle
    import sys
    import time

    import numpy as np
    import torch
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils import data, model_zoo

    from model.deeplab import Res_Deeplab
    from model.unet import unet_resnet50
    from model.deeplabv3 import resnet101_deeplabv3
    from model.discriminator import FCDiscriminator
    from utils.loss import CrossEntropy2d, BCEWithLogitsLoss2d
    from utils.evaluation import EvaluatorIoU
    from dataset.voc_dataset import VOCDataSet
    import logger

    torch_device = torch.device(device)

    if log_file not in ('', 'none') and os.path.exists(log_file):
        print('Log file {} already exists; exiting...'.format(log_file))
        return

    with logger.LogFile(log_file if log_file != 'none' else None):
        if dataset == 'pascal_aug':
            ds = VOCDataSet(augmented_pascal=True)
        elif dataset == 'pascal':
            ds = VOCDataSet(augmented_pascal=False)
        else:
            print('Dataset {} not yet supported'.format(dataset))
            return

        print('Command: {}'.format(sys.argv[0]))
        print('Arguments: {}'.format(' '.join(sys.argv[1:])))
        print('Settings: {}'.format(', '.join(
            '{}={}'.format(k, settings[k]) for k in sorted(settings))))

        print('Loaded data')

        def loss_calc(pred, label):
            """Return the CrossEntropy2d loss of ``pred`` against integer ``label``."""
            label = label.long().to(torch_device)
            criterion = CrossEntropy2d()
            return criterion(pred, label)

        def lr_poly(base_lr, cur_iter, max_iter, power):
            """Polynomial decay: ``base_lr * (1 - cur_iter / max_iter) ** power``."""
            return base_lr * ((1 - float(cur_iter) / max_iter)**(power))

        def set_poly_lr(optimizer, base_lr, i_iter):
            """Apply the poly schedule; a second param group gets 10x the LR."""
            lr = lr_poly(base_lr, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def one_hot(label):
            """Convert a rank-3 integer label tensor to a float one-hot tensor.

            The result has shape ``(label.shape[0], ds.num_classes,
            label.shape[1], label.shape[2])`` and lives on ``torch_device``.
            Pixels matching no class index (e.g. ``ignore_label`` if it lies
            outside ``range(ds.num_classes)``) get an all-zero vector.
            """
            label = label.numpy()
            encoded = np.zeros((label.shape[0], ds.num_classes, label.shape[1],
                                label.shape[2]),
                               dtype=label.dtype)
            for i in range(ds.num_classes):
                encoded[:, i, ...] = (label == i)
            return torch.tensor(encoded,
                                dtype=torch.float,
                                device=torch_device)

        def make_D_label(label, ignore_mask):
            """Build a discriminator target filled with ``label``.

            A channel axis is inserted at position 1 of ``ignore_mask``;
            masked positions are set to ``ignore_label``.
            """
            ignore_mask = np.expand_dims(ignore_mask, axis=1)
            D_label = np.ones(ignore_mask.shape) * label
            D_label[ignore_mask] = ignore_label
            return torch.tensor(D_label,
                                dtype=torch.float,
                                device=torch_device)

        def next_batch(loader, it):
            """Return ``(batch, it)``, restarting ``loader`` once ``it`` is exhausted."""
            try:
                _, batch = next(it)
            except StopIteration:
                it = enumerate(loader)
                _, batch = next(it)
            return batch, it

        h, w = map(int, eval_crop_size.split(','))
        eval_crop_size = (h, w)

        h, w = map(int, crop_size.split(','))
        crop_size = (h, w)

        # create network
        if arch == 'deeplab2':
            model = Res_Deeplab(num_classes=ds.num_classes)
        elif arch == 'unet_resnet50':
            model = unet_resnet50(num_classes=ds.num_classes)
        elif arch == 'resnet101_deeplabv3':
            model = resnet101_deeplabv3(num_classes=ds.num_classes)
        else:
            print('Architecture {} not supported'.format(arch))
            return

        # load pretrained parameters
        if restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(restore_from)
        else:
            saved_state_dict = torch.load(restore_from)

        # only copy the params that exist in current model (caffe-like)
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            if (name in saved_state_dict
                    and param.size() == saved_state_dict[name].size()):
                new_params[name].copy_(saved_state_dict[name])
        model.load_state_dict(new_params)

        model.train()
        model = model.to(torch_device)

        # init D
        model_D = FCDiscriminator(num_classes=ds.num_classes)
        if restore_from_d is not None:
            model_D.load_state_dict(torch.load(restore_from_d))
        model_D.train()
        model_D = model_D.to(torch_device)

        print('Built model')

        if snapshot_dir is not None:
            os.makedirs(snapshot_dir, exist_ok=True)

        ds_train_xy = ds.train_xy(crop_size=crop_size,
                                  scale=random_scale,
                                  mirror=random_mirror,
                                  range01=model.RANGE01,
                                  mean=model.MEAN,
                                  std=model.STD)
        ds_train_y = ds.train_y(crop_size=crop_size,
                                scale=random_scale,
                                mirror=random_mirror,
                                range01=model.RANGE01,
                                mean=model.MEAN,
                                std=model.STD)
        ds_val_xy = ds.val_xy(crop_size=eval_crop_size,
                              scale=False,
                              mirror=False,
                              range01=model.RANGE01,
                              mean=model.MEAN,
                              std=model.STD)

        train_dataset_size = len(ds_train_xy)

        # Reject a labelled subset larger than the training set.
        if partial_data_size != -1 and partial_data_size > train_dataset_size:
            print('partial-data-size > |train|: exiting')
            return

        # NOTE(review): num_workers is ignored; worker counts are hard-coded.
        if partial_data == 1.0 and (partial_data_size == -1 or
                                    partial_data_size == train_dataset_size):
            trainloader = data.DataLoader(ds_train_xy,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=5,
                                          pin_memory=True)

            trainloader_gt = data.DataLoader(ds_train_y,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=5,
                                             pin_memory=True)

            trainloader_remain = None
            trainloader_remain_iter = None
            print('|train|={}'.format(train_dataset_size))
            print('|val|={}'.format(len(ds_val_xy)))
        else:
            # sample partial data
            if partial_data_size != -1:
                partial_size = partial_data_size
            else:
                partial_size = int(partial_data * train_dataset_size)

            if partial_id is not None:
                # NOTE: pickle.load can execute arbitrary code; only load
                # trusted id files.
                with open(partial_id, 'rb') as f:
                    train_ids = pickle.load(f)
                print('loading train ids from {}'.format(partial_id))
            else:
                rng = np.random.RandomState(random_seed)
                train_ids = list(rng.permutation(train_dataset_size))

            if snapshot_dir is not None:
                with open(osp.join(snapshot_dir, 'train_id.pkl'), 'wb') as f:
                    pickle.dump(train_ids, f)

            print('|train supervised|={}'.format(partial_size))
            print('|train unsupervised|={}'.format(train_dataset_size -
                                                   partial_size))
            print('|val|={}'.format(len(ds_val_xy)))

            print('supervised={}'.format(list(train_ids[:partial_size])))

            train_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])
            train_remain_sampler = data.sampler.SubsetRandomSampler(
                train_ids[partial_size:])
            train_gt_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])

            trainloader = data.DataLoader(ds_train_xy,
                                          batch_size=batch_size,
                                          sampler=train_sampler,
                                          num_workers=3,
                                          pin_memory=True)
            trainloader_remain = data.DataLoader(ds_train_xy,
                                                 batch_size=batch_size,
                                                 sampler=train_remain_sampler,
                                                 num_workers=3,
                                                 pin_memory=True)
            trainloader_gt = data.DataLoader(ds_train_y,
                                             batch_size=batch_size,
                                             sampler=train_gt_sampler,
                                             num_workers=3,
                                             pin_memory=True)

            trainloader_remain_iter = enumerate(trainloader_remain)

        testloader = data.DataLoader(ds_val_xy,
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True)

        print('Data loaders ready')

        trainloader_iter = enumerate(trainloader)
        trainloader_gt_iter = enumerate(trainloader_gt)

        # optimizer for segmentation network; model.optim_parameters handles
        # the per-layer learning-rate setting
        optimizer = optim.SGD(model.optim_parameters(learning_rate),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        optimizer.zero_grad()

        # optimizer for discriminator network
        optimizer_D = optim.Adam(model_D.parameters(),
                                 lr=learning_rate_d,
                                 betas=(0.9, 0.99))
        optimizer_D.zero_grad()

        bce_loss = BCEWithLogitsLoss2d()

        print('Built optimizer')

        # labels for adversarial training
        pred_label = 0
        gt_label = 1

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_mask_accum = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        t1 = time.time()

        print('Training for {} steps...'.format(num_steps))
        for i_iter in range(num_steps + 1):

            model.train()
            model.freeze_batchnorm()

            optimizer.zero_grad()
            set_poly_lr(optimizer, learning_rate, i_iter)
            optimizer_D.zero_grad()
            set_poly_lr(optimizer_D, learning_rate_d, i_iter)

            for _ in range(iter_size):

                # ---- train G ----

                if not supervised:
                    # don't accumulate grads in D
                    for param in model_D.parameters():
                        param.requires_grad = False

                # Unlabelled outputs reused by the D update when d_remain is
                # set; they stay None until the semi-supervised phase starts.
                pred_remain = None
                ignore_mask_remain = None

                # do semi first
                if (not supervised and (lambda_semi > 0 or lambda_semi_adv > 0)
                        and i_iter >= semi_start_adv
                        and trainloader_remain is not None):
                    batch, trainloader_remain_iter = next_batch(
                        trainloader_remain, trainloader_remain_iter)

                    # only access to img
                    images, _, _, _ = batch
                    images = images.float().to(torch_device)

                    pred = model(images)
                    pred_remain = pred.detach()

                    D_out = model_D(F.softmax(pred, dim=1))
                    D_out_sigmoid = torch.sigmoid(
                        D_out).data.cpu().numpy().squeeze(axis=1)

                    # builtin bool: np.bool was removed in NumPy 1.24
                    ignore_mask_remain = np.zeros(D_out_sigmoid.shape,
                                                  dtype=bool)

                    loss_semi_adv = lambda_semi_adv * bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask_remain))
                    loss_semi_adv = loss_semi_adv / iter_size

                    # Log the unweighted value. lambda_semi alone can enable
                    # this branch, so guard against dividing by a zero weight.
                    if lambda_semi_adv != 0:
                        loss_semi_adv_value += float(
                            loss_semi_adv) / lambda_semi_adv

                    if lambda_semi <= 0 or i_iter < semi_start:
                        loss_semi_adv.backward()
                        loss_semi_value = 0
                    else:
                        # produce ignore mask: trust only pixels the
                        # discriminator rates at least mask_t
                        semi_ignore_mask = (D_out_sigmoid < mask_t)

                        semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                        semi_gt[semi_ignore_mask] = ignore_label

                        semi_ratio = 1.0 - float(
                            semi_ignore_mask.sum()) / semi_ignore_mask.size

                        loss_semi_mask_accum += float(semi_ratio)

                        if semi_ratio == 0.0:
                            loss_semi_value += 0
                        else:
                            semi_gt = torch.FloatTensor(semi_gt)

                            loss_semi = lambda_semi * loss_calc(pred, semi_gt)
                            loss_semi = loss_semi / iter_size
                            loss_semi_value += float(loss_semi) / lambda_semi
                            loss_semi += loss_semi_adv
                            loss_semi.backward()

                # train with source
                batch, trainloader_iter = next_batch(trainloader,
                                                     trainloader_iter)

                images, labels, _, _ = batch
                images = images.float().to(torch_device)
                ignore_mask = (labels.numpy() == ignore_label)
                pred = model(images)

                loss_seg = loss_calc(pred, labels)

                if supervised:
                    loss = loss_seg
                else:
                    D_out = model_D(F.softmax(pred, dim=1))

                    loss_adv_pred = bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask))

                    loss = loss_seg + lambda_adv_pred * loss_adv_pred
                    loss_adv_pred_value += float(loss_adv_pred) / iter_size

                # proper normalization
                loss = loss / iter_size
                loss.backward()
                loss_seg_value += float(loss_seg) / iter_size

                if not supervised:
                    # ---- train D ----

                    # bring back requires_grad
                    for param in model_D.parameters():
                        param.requires_grad = True

                    # train with pred
                    pred = pred.detach()

                    # Unlabelled predictions only exist once the
                    # semi-supervised branch above has run.
                    if d_remain and pred_remain is not None:
                        pred = torch.cat((pred, pred_remain), 0)
                        ignore_mask = np.concatenate(
                            (ignore_mask, ignore_mask_remain), axis=0)

                    D_out = model_D(F.softmax(pred, dim=1))
                    loss_D = bce_loss(D_out,
                                      make_D_label(pred_label, ignore_mask))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

                    # train with gt
                    batch, trainloader_gt_iter = next_batch(
                        trainloader_gt, trainloader_gt_iter)

                    _, labels_gt, _, _ = batch
                    D_gt_v = one_hot(labels_gt)
                    ignore_mask_gt = (labels_gt.numpy() == ignore_label)

                    D_out = model_D(D_gt_v)
                    loss_D = bce_loss(D_out,
                                      make_D_label(gt_label, ignore_mask_gt))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

            optimizer.step()
            optimizer_D.step()

            sys.stdout.write('.')
            sys.stdout.flush()

            if i_iter % eval_every == 0 and i_iter != 0:
                model.eval()
                with torch.no_grad():
                    evaluator = EvaluatorIoU(ds.num_classes)
                    for index, batch in enumerate(testloader):
                        image, label, size, name = batch
                        size = size[0].numpy()
                        image = image.float().to(torch_device)
                        output = model(image)
                        output = output.cpu().data[0].numpy()

                        output = output[:, :size[0], :size[1]]
                        # builtin int: np.int was removed in NumPy 1.24
                        gt = np.asarray(label[0].numpy()[:size[0], :size[1]],
                                        dtype=int)

                        output = output.transpose(1, 2, 0)
                        output = np.asarray(np.argmax(output, axis=2),
                                            dtype=int)

                        evaluator.sample(gt, output, ignore_value=ignore_label)

                        sys.stdout.write('+')
                        sys.stdout.flush()

                per_class_iou = evaluator.score()
                mean_iou = per_class_iou.mean()

                loss_seg_value /= eval_every
                loss_adv_pred_value /= eval_every
                loss_D_value /= eval_every
                loss_semi_mask_accum /= eval_every
                loss_semi_value /= eval_every
                loss_semi_adv_value /= eval_every

                sys.stdout.write('\n')

                t2 = time.time()

                print(
                    'iter = {:8d}/{:8d}, took {:.3f}s, loss_seg = {:.6f}, loss_adv_p = {:.6f}, loss_D = {:.6f}, loss_semi_mask_rate = {:.3%} loss_semi = {:.6f}, loss_semi_adv = {:.3f}'
                    .format(i_iter, num_steps, t2 - t1, loss_seg_value,
                            loss_adv_pred_value, loss_D_value,
                            loss_semi_mask_accum, loss_semi_value,
                            loss_semi_adv_value))

                for i, (class_name,
                        iou) in enumerate(zip(ds.class_names, per_class_iou)):
                    print('class {:2d} {:12} IU {:.2f}'.format(
                        i, class_name, iou))

                print('meanIOU: ' + str(mean_iou) + '\n')

                loss_seg_value = 0
                loss_adv_pred_value = 0
                loss_D_value = 0
                loss_semi_value = 0
                loss_semi_mask_accum = 0
                loss_semi_adv_value = 0

                t1 = t2

            if snapshot_dir is not None and i_iter % save_snapshot_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

        if snapshot_dir is not None:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '_D.pth'))
Example #7
0
def main():
    """Train a Res_Deeplab segmentation network on LiTS liver CT slices.

    Segmentation-only variant of the adversarial training script: the
    discriminator and its optimizer are intentionally not built.  Each
    iteration accumulates gradients over ``args.iter_size`` batches, prints
    the Dice score and segmentation loss, writes both to a TensorBoard
    summary under ``SUMMARY_DIR/train`` and periodically saves snapshots to
    ``args.snapshot_dir``.

    Relies on names defined outside this function (``args``, ``SUMMARY_DIR``,
    ``iter_start``, ``start``, ``save_model``, ``loss_calc``,
    ``adjust_learning_rate``, ``Res_Deeplab``, ``tf``, ...).
    """
    # LD ADD start
    from dataset.LiverDataset.liver_dataset import LiverDataset
    user_name = 'give'
    validation_interval = 800
    batch_size = 1
    n_neighboringslices = 5
    input_size = 400
    slice_type = 'axial'
    oversample = False
    label_of_interest = 1
    label_required = 0
    max_slice_tries_val = 0
    max_slice_tries_train = 2
    fuse_labels = True
    apply_crop = False

    dataset_root = ("/home/" + user_name +
                    "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS")
    train_data_dir = dataset_root + "/Training_Batch_2"
    test_data_dir = dataset_root + "/Training_Batch_1"

    # Settings shared by the training and validation datasets.
    common_dataset_kwargs = dict(slice_type=slice_type,
                                 n_neighboringslices=n_neighboringslices,
                                 input_size=input_size,
                                 oversample=oversample,
                                 label_of_interest=label_of_interest,
                                 label_required=label_required,
                                 fuse_labels=fuse_labels,
                                 apply_crop=apply_crop,
                                 interval=validation_interval,
                                 batch_size=batch_size)
    train_dataset = LiverDataset(data_dir=train_data_dir,
                                 max_slice_tries=max_slice_tries_train,
                                 is_training=True,
                                 data_augmentation=True,
                                 **common_dataset_kwargs)
    val_dataset = LiverDataset(data_dir=test_data_dir,
                               max_slice_tries=max_slice_tries_val,
                               is_training=False,
                               **common_dataset_kwargs)
    # LD ADD end

    # LD build for summary
    training_summary = tf.summary.FileWriter(os.path.join(
        SUMMARY_DIR, 'train'))
    val_summary = tf.summary.FileWriter(os.path.join(SUMMARY_DIR, 'val'))
    dice_placeholder = tf.placeholder(tf.float32, [], name='dice')
    loss_placeholder = tf.placeholder(tf.float32, [], name='loss')
    tf.summary.scalar('dice', dice_placeholder)
    tf.summary.scalar('loss', loss_placeholder)
    summary_op = tf.summary.merge_all()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    perfix_name = 'Liver'

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes,
                        slice_num=n_neighboringslices,
                        gpu_id=0)
    # Check the value that is actually loaded below, so a path passed on the
    # command line is honoured and a None value is never sliced.
    if args.restore_from is not None:
        # load pretrained parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # only copy the params that exist in current model with a matching
        # shape (caffe-like partial restore)
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            print(name)
            if name in saved_state_dict and param.size(
            ) == saved_state_dict[name].size():
                new_params[name].copy_(saved_state_dict[name])
                print('copy {}'.format(name))
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # optimizer for segmentation network; model.optim_parameters(args) is
    # expected to handle per-layer learning-rate settings
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # bilinear upsampling of the network output to the input resolution
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size, input_size),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size, input_size), mode='bilinear')

    try:
        for i_iter in range(iter_start, args.num_steps):

            loss_seg_value = 0
            num_prediction = 0
            num_ground_truth = 0
            num_intersection = 0
            # per-sub-iteration segmentation losses, averaged for TensorBoard
            loss_list = []

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)

            for _ in range(args.iter_size):
                batch_image, batch_label = train_dataset.get_next_batch()
                # move the last axis to position 1 (assumes get_next_batch
                # returns channels-last batches -- TODO confirm)
                batch_image = np.transpose(batch_image, axes=(0, 3, 1, 2))
                images = torch.Tensor(batch_image).cuda(args.gpu)

                pred = interp(model(images))
                pred_ny = np.transpose(pred.data.cpu().numpy(),
                                       axes=(0, 2, 3, 1))
                pred_label_ny = np.squeeze(np.argmax(pred_ny, axis=3))

                # Count foreground pixels (label >= 1) for prediction, ground
                # truth and intersection alike, so Dice stays consistent even
                # when more than two classes are predicted.
                pred_fg = pred_label_ny >= 1
                gt_fg = batch_label >= 1
                num_prediction += np.count_nonzero(pred_fg)
                num_ground_truth += np.count_nonzero(gt_fg)
                num_intersection += np.count_nonzero(
                    np.logical_and(gt_fg, pred_fg))

                loss_seg = loss_calc(pred, batch_label, args.gpu)

                # proper normalization for gradient accumulation
                loss = loss_seg / args.iter_size
                loss.backward()

                cur_loss = loss_seg.data.cpu().numpy()
                loss_seg_value += cur_loss / args.iter_size
                loss_list.append(cur_loss)

            optimizer.step()

            dice = (2 * num_intersection + 1e-7) / (num_prediction +
                                                    num_ground_truth + 1e-7)
            print('exp = {}'.format(args.snapshot_dir))
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value))
            print(
                'dice: %.4f, num_prediction: %d, num_ground_truth: %d, num_intersection: %d'
                % (dice, num_prediction, num_ground_truth, num_intersection))

            if i_iter >= args.num_steps - 1:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             perfix_name + str(args.num_steps) + '.pth'))
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                save_model(model, args.snapshot_dir, perfix_name, i_iter, 2)

            # update tensorboard
            feed_dict = {
                dice_placeholder: dice,
                loss_placeholder: np.mean(loss_list)
            }
            summery_value = sess.run(summary_op, feed_dict)
            training_summary.add_summary(summery_value, i_iter)
            training_summary.flush()
            # for validation
            # val_num_prediction = 0
            # val_num_ground_truth = 0
            # val_num_intersection = 0
            # loss_list = []
            # # model.train(False)
            # for idx in range(VAL_EXECUTE_TIMES):
            #     print(idx)
            #     batch_image, batch_label = val_dataset.get_next_batch()
            #     batch_image = np.transpose(batch_image, axes=(0, 3, 1, 2))
            #     batch_image_torch = torch.Tensor(batch_image)
            #     images = Variable(batch_image_torch).cuda(args.gpu)
            #
            #     pred = interp(model(images))
            #     pred_ny = pred.data.cpu().numpy()
            #     pred_ny = np.transpose(pred_ny, axes=(0, 2, 3, 1))
            #     pred_label_ny = np.squeeze(np.argmax(pred_ny, axis=3))
            #     val_num_prediction += np.sum(np.asarray(pred_label_ny, np.uint8))
            #     val_num_ground_truth += np.sum(np.asarray(batch_label >= 1, np.uint8))
            #     val_num_intersection += np.sum(np.asarray(np.logical_and(batch_label >= 1, pred_label_ny >= 1), np.uint8))
            #
            #     loss_seg = loss_calc(pred, batch_label, args.gpu)
            #     loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            #     loss_list.append(loss_seg)
            # dice = (2 * val_num_intersection + 1e-7) / (val_num_prediction + val_num_ground_truth + 1e-7)
            # feed_dict = {
            #     dice_placeholder: dice,
            #     loss_placeholder: np.mean(loss_list)
            # }
            # print('validation: dice:%.4f, loss: %.4f' % (dice, np.mean(loss_list)))
            # summery_value = sess.run(summary_op, feed_dict)
            # val_summary.add_summary(summery_value, i_iter)
            # val_summary.flush()
            # loss_list = []
            # print('\n')
    finally:
        # release TensorBoard writers and the TF session even if training fails
        training_summary.close()
        val_summary.close()
        sess.close()
    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # create network
    model = DeeplabMulti(num_classes=args.num_classes)
    #model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cuda:0')

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)
    #summary(model,(3,7,7))

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)
    #summary(model_D, (21,321,321))
    #quit()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = cityscapesDataSet(max_iters=args.num_steps *
                                      args.iter_size * args.batch_size,
                                      scale=args.random_scale)
    train_dataset_size = len(train_dataset)
    train_gt_dataset = cityscapesDataSet(max_iters=args.num_steps *
                                         args.iter_size * args.batch_size,
                                         scale=args.random_scale)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            train_ids = pickle.load(open(args.partial_id))
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        pickle.dump(train_ids,
                    open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb'))

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             sampler=train_remain_sampler,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         sampler=train_gt_sampler,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers,
                                         pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0
        loss_laplacian = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                ) and i_iter >= args.semi_start_adv:
                try:
                    _, batch = trainloader_remain_iter.__next__()
                except:
                    trainloader_remain_iter = enumerate(trainloader)
                    _, batch = trainloader_remain_iter.__next__()

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                try:
                    pred = interp(model(images))
                except RuntimeError as exception:
                    if "out of memory" in str(exception):
                        print("WARNING: out of memory")
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        raise exception

                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred)))
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)

                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(
                    np.bool)

                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv / args.iter_size

                #loss_semi_adv.backward()
                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.lambda_semi_adv

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(
                        semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(
                            pred, semi_gt, args.gpu)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy(
                        ) / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source

            try:
                _, batch = trainloader_iter.__next__()
            except:
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            try:
                pred = interp(model(images))
            except RuntimeError as exception:
                if "out of memory" in str(exception):
                    print("WARNING: out of memory")
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise exception
            for i in range(1):
                imagess = torch.zeros(1280, 720).cuda()
                for j in range(19):
                    try:
                        imagess += pred[i, j, :, :].reshape(1280, 720)
                    except IndexError:
                        pass
                try:
                    label = labels[i, :, :].reshape(1280, 720).cuda()
                except IndexError:
                    pass
                imagess = torch.from_numpy(
                    cv2.Laplacian(imagess.cpu().detach().numpy(), -1)).cuda()
                labell = torch.from_numpy(
                    cv2.Laplacian(label.cpu().detach().numpy(), -1)).cuda()
                imagess = imagess.reshape(1, 1, 1280, 720)
                labell = labell.reshape(1, 1, 1280, 720)
                l = bce_loss(imagess, labell)

            loss_laplacian = l
            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(F.softmax(pred)))

            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred - loss_laplacian

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            if args.D_remain:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain),
                                             axis=0)

            D_out = interp(model_D(F.softmax(pred)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train with gt
            # get gt labels
            try:
                _, batch = trainloader_gt_iter.__next__()
            except:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = trainloader_gt_iter.__next__()

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}, loss_laplacian = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value, loss_laplacian))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '_D.pth'))
            #torch.cuda.empty_cache()

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Train a DeepLab segmentation network adversarially, with a second
    discriminator that takes over the unlabeled branch after a split point.

    All settings are read from the module-level ``args`` namespace (not
    visible here -- presumably built by argparse). Training has two phases:

    * ``i_iter < args.discr_split``: a single discriminator ``model_D`` is
      trained on labeled predictions (target ``pred_label``), ``one_hot``
      encoded ground-truth labels (target ``gt_label``) and, when
      ``args.D_remain`` is set, predictions on unlabeled images.
    * ``i_iter >= args.discr_split``: ``model_D_ul`` is created as a copy of
      ``model_D`` and scores the unlabeled predictions instead. It is trained
      towards ``pred_label`` on unlabeled predictions and towards
      ``gt_label`` on the segmentation network's (no-grad) predictions for
      labeled images.

    When ``args.partial_data`` is set, that fraction of the training set is
    used as labeled data and the remainder feeds the unlabeled loader.
    Weights of ``model`` and ``model_D`` are saved to ``args.snapshot_dir``
    every ``args.save_pred_every`` iterations and after the last iteration.

    NOTE(review): ``start`` (used for the final elapsed-time print) is not
    defined in this function; it is assumed to be a module-level timer value.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    np.random.seed(args.random_seed)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model with a matching shape
    # (caffe-like partial restore)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)
    model.train()
    model.cuda(args.gpu)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # load dataset
    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
        # NOTE(review): no unlabeled loader is built in this branch, so the
        # semi-supervised step below (lambda_semi_adv > 0) would raise on
        # ``trainloader_remain``.
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # SECURITY: pickle.load can execute arbitrary code; only load
            # id files from a trusted source. Binary mode is required by
            # pickle on Python 3.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        # persist the labeled/unlabeled split next to the snapshots
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        # labeled data: the first ``partial_size`` ids
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        # unlabeled data: the remaining ids
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # optimizer for segmentation network
    # (model.optim_parameters(args) handles per-layer lr settings)
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling back to the input resolution
    bce_loss = BCEWithLogitsLoss2d()
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    def _save_snapshot(step):
        """Save ``model`` and ``model_D`` weights tagged with *step* and the seed."""
        prefix = osp.join(
            args.snapshot_dir,
            'VOC_' + str(step) + '_' + str(args.random_seed))
        torch.save(model.state_dict(), prefix + '.pth')
        torch.save(model_D.state_dict(), prefix + '_D.pth')

    for i_iter in range(args.num_steps):
        loss_seg_value = 0
        loss_adv_pred_value = 0

        loss_D_value = 0
        loss_D_ul_value = 0

        # no semi (self-training) loss is computed in this script; kept at 0
        # so the log line keeps its columns
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # create the 2nd discriminator as a copy of the 1st one
        if i_iter == args.discr_split:
            model_D_ul = FCDiscriminator(num_classes=args.num_classes)
            model_D_ul.load_state_dict(model_D.state_dict())
            model_D_ul.train()
            model_D_ul.cuda(args.gpu)

            optimizer_D_ul = optim.Adam(model_D_ul.parameters(),
                                        lr=args.learning_rate_D,
                                        betas=(0.9, 0.99))

        # start training the 2nd discriminator after the split
        if i_iter >= args.discr_split:
            optimizer_D_ul.zero_grad()
            adjust_learning_rate_D(optimizer_D_ul, i_iter)

        for sub_i in range(args.iter_size):
            # ---------------- train segmentation network ----------------

            # don't accumulate grads in D (nor D_ul once it exists)
            for param in model_D.parameters():
                param.requires_grad = False
            if i_iter >= args.discr_split:
                for param in model_D_ul.parameters():
                    param.requires_grad = False

            # semi-supervised adversarial step on unlabeled images first
            if args.lambda_semi_adv > 0 and i_iter >= args.semi_start_adv:
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    # loader exhausted: start a new epoch
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only the image is used
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))
                pred_remain = pred.detach()

                # choose discriminator depending on the phase
                if i_iter >= args.discr_split:
                    D_out = interp(model_D_ul(F.softmax(pred, dim=1)))
                else:
                    D_out = interp(model_D(F.softmax(pred, dim=1)))

                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy(
                ).squeeze(axis=1)
                # nothing is ignored on unlabeled images
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape,
                                              dtype=bool)

                # adversarial loss: push D to call unlabeled preds "gt"
                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain,
                                        args.gpu))
                loss_semi_adv = loss_semi_adv / args.iter_size

                # true loss value without multiplier
                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.lambda_semi_adv
                loss_semi_adv.backward()
            else:
                # marks "no unlabeled prediction in this sub-iteration" so
                # the discriminator steps below skip the unlabeled terms
                pred_remain = None
                ignore_mask_remain = None

            # train with labeled images
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            pred = interp(model(images))
            D_out = interp(model_D(F.softmax(pred, dim=1)))

            # computing loss
            loss_seg = loss_calc(pred, labels, args.gpu)
            loss_adv_pred = bce_loss(
                D_out, make_D_label(gt_label, ignore_mask, args.gpu))
            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization over gradient-accumulation steps
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # ---------------- train D and D_ul ----------------

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True
            if i_iter >= args.discr_split:
                for param in model_D_ul.parameters():
                    param.requires_grad = True

            # train D with pred
            pred = pred.detach()

            # before the split, D also sees unlabeled predictions
            if (args.D_remain and i_iter < args.discr_split
                    and pred_remain is not None):
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain),
                                             axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out,
                              make_D_label(pred_label, ignore_mask, args.gpu))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train D_ul with pred on unlabeled images
            if i_iter >= args.discr_split and pred_remain is not None:
                D_ul_out = interp(model_D_ul(F.softmax(pred_remain, dim=1)))
                loss_D_ul = bce_loss(
                    D_ul_out,
                    make_D_label(pred_label, ignore_mask_remain, args.gpu))
                loss_D_ul = loss_D_ul / args.iter_size / 2
                loss_D_ul.backward()
                loss_D_ul_value += loss_D_ul.data.cpu().numpy()

            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            images_gt, labels_gt, _, _ = batch
            images_gt = Variable(images_gt).cuda(args.gpu)
            with torch.no_grad():
                pred_l = interp(model(images_gt))

            # train D with gt
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out,
                              make_D_label(gt_label, ignore_mask_gt, args.gpu))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train D_ul with pseudo-gt (predictions on labeled images stand
            # in for gt)
            if i_iter >= args.discr_split:
                D_ul_out = interp(model_D_ul(F.softmax(pred_l, dim=1)))
                loss_D_ul = bce_loss(
                    D_ul_out, make_D_label(gt_label, ignore_mask_gt, args.gpu))
                loss_D_ul = loss_D_ul / args.iter_size / 2
                loss_D_ul.backward()
                loss_D_ul_value += loss_D_ul.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()
        if i_iter >= args.discr_split:
            optimizer_D_ul.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_D_ul={5:.3f}, loss_semi = {6:.3f}, loss_semi_adv = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_D_ul_value,
                    loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            _save_snapshot(args.num_steps)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            _save_snapshot(i_iter)

    end = timeit.default_timer()
    print(end - start, 'seconds')
# Example #10
# 0
def main():
    """Train ``Res_Deeplab`` (built with ``mode=args.mode``) as a multi-label
    image classifier on ``VOCClsDataSet``.

    All settings are read from the module-level ``args`` namespace. When
    ``args.partial_data`` is set, only that fraction of the dataset is used;
    the (shuffled or loaded) id list is pickled into ``args.snapshot_dir`` so
    the subset can be reproduced. Weights are saved every
    ``args.save_pred_every`` iterations and after the last iteration.

    NOTE(review): ``start`` (used for the final elapsed-time print) is not
    defined in this function; it is assumed to be a module-level timer value.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes, mode=args.mode)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model with a matching shape
    # (caffe-like partial restore)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    bce_loss = BCEWithLogitsLoss2d()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCClsDataSet(args.data_dir,
                                  args.data_list,
                                  crop_size=input_size,
                                  scale=True,
                                  mirror=True,
                                  mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # SECURITY: pickle.load can execute arbitrary code; only load
            # id files from a trusted source. Binary mode is required by
            # pickle on Python 3.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            # list(): a Python 3 range is immutable and cannot be shuffled
            # in place (Python 2's range already returned a list)
            train_ids = list(range(train_dataset_size))
            np.random.seed(args.seed)
            np.random.shuffle(train_ids)

        # persist the chosen ids next to the snapshots
        id_path = osp.join(args.snapshot_dir,
                           'train_id_seed_' + str(args.seed) + '_.pkl')
        with open(id_path, 'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    for i_iter in range(args.num_steps):
        loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):
            # train with source
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                # loader exhausted: start a new epoch
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, cls_label, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            cls_pred = model(images)

            cls_label = Variable(torch.FloatTensor(cls_label)).cuda(args.gpu)
            # append two singleton dims so the prediction goes through the 2d
            # loss; assumes cls_pred is (N, C) -> (N, C, 1, 1), TODO confirm
            loss = bce_loss(torch.unsqueeze(torch.unsqueeze(cls_pred, 2), 3),
                            cls_label)

            # proper normalization over gradient-accumulation steps
            loss = loss / args.iter_size
            loss.backward()
            loss_value += loss.data.cpu().numpy() / args.iter_size

        optimizer.step()

        print('iter = {0:8d}/{1:8d}, loss = {2:.3f}'.format(
            i_iter, args.num_steps, loss_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_classifier_VOCcls_pd_' +
                    str(args.partial_data) + '_seed_' + str(args.seed) + '_' +
                    str(args.num_steps) + '.pth'))

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'VOC_classifier_VOCcls_pd_' + str(args.partial_data) +
                    '_seed_' + str(args.seed) + '_' + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
# Example #11
# 0
def main():
    """Train PSPNet (ResNet-101 backend, 21 classes) on VOC with a joint
    segmentation + image-level classification loss.

    All settings are read from the module-level ``args`` namespace. The
    model returns ``(out, out_cls)``: ``out`` is scored with ``NLL2d``
    against the pixel labels and ``out_cls`` with ``BCEWithLogitsLoss``
    against the per-image targets ``y_cls``. No discriminator is trained
    here, so the adversarial / semi loss columns of the log are always 0.
    Weights are saved every ``args.save_pred_every`` iterations and after
    the last iteration, named after this script's file name.

    Raises:
        NotImplementedError: if ``args.partial_data`` is non-zero; training
            on a labeled subset is not implemented by this script.

    NOTE(review): ``start`` (used for the final elapsed-time print) is not
    defined in this function; it is assumed to be a module-level timer value.
    """
    models = {
        'resnet101':
        lambda: PSPNet(n_classes=21,
                       sizes=(1, 2, 3, 6),
                       psp_size=2048,
                       deep_features_size=1024,
                       backend='resnet101')
    }

    # fail fast: previously a non-zero partial_data left the data loaders
    # undefined and crashed later with a NameError
    if args.partial_data:
        raise NotImplementedError(
            'partial_data={} is not supported by this script; '
            'use 0 to train on the full dataset'.format(args.partial_data))

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = models['resnet101']()

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model with a matching shape
    # (caffe-like partial restore)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    trainloader = data.DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=5,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    seg_criterion = NLL2d().cuda()
    cls_criterion = nn.BCEWithLogitsLoss(weight=None)

    # snapshot files are named after this script (e.g. "train_psp.py" ->
    # "train_psp"); basename is portable, unlike splitting on '/'
    exp_name = osp.basename(osp.abspath(__file__)).split('.')[0]

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        # not computed in this script; kept at 0 for the log format
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                # loader exhausted: start a new epoch
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _, y_cls = batch
            labels = Variable(labels.long()).cuda()
            y_cls = Variable(y_cls.float()).cuda()
            images = Variable(images).cuda()

            out, out_cls = model(images)

            seg_loss = seg_criterion(out, labels)
            cls_loss = cls_criterion(out_cls, y_cls)
            loss = seg_loss + cls_loss

            # proper normalization over gradient-accumulation steps
            loss = loss / args.iter_size
            loss.backward()
            # ndarray.item() handles both the 0-d loss of torch >= 0.4 and the
            # 1-element loss of older versions (``[0]`` fails on 0-d arrays)
            loss_seg_value += seg_loss.data.cpu().numpy().item(
            ) / args.iter_size

        optimizer.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + exp_name + '_' +
                         str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + exp_name + '_' + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
# Example #12
# 0
def main():
    """Train the segmentation network adversarially against a discriminator.

    The segmentation network (``Res_Deeplab``) is optimized with a supervised
    segmentation loss plus ``args.lambda_adv_pred`` times an adversarial loss
    produced by a fully-convolutional discriminator (``FCDiscriminator``).
    The discriminator is trained to output ``pred_label`` (0) on detached
    network predictions and ``gt_label`` (1) on one-hot ground-truth maps.

    When ``args.lambda_semi > 0`` and ``i_iter >= args.semi_start``, every
    sub-iteration also draws an unlabeled batch. The network's own argmax
    predictions become pseudo-labels wherever the discriminator's sigmoid
    output is at least ``args.mask_T``. All other pixels are set to 255.

    Configuration is read from the module-level ``args``. ``start`` (a
    ``timeit`` timestamp) and ``IMG_MEAN`` are also expected at module level.

    Side effects: creates ``args.snapshot_dir`` if needed, writes
    ``train_id.pkl`` there when ``args.partial_data`` is set, and saves
    model/discriminator checkpoints every ``args.save_pred_every``
    iterations and after the final one.
    """
    # Parse "h,w" into an integer pair.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    # NOTE(review): cuDNN is disabled here, yet cudnn.benchmark is set to
    # True further down; the benchmark flag only matters with cuDNN on.
    cudnn.enabled = False
    gpu = args.gpu

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Only copy parameters that exist in the current model with a matching
    # shape (caffe-like partial loading). state_dict() holds both parameters
    # and persistent buffers, keyed by name; entries that do not match keep
    # their fresh initialization.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # Put the segmentation network in training mode.
    model.train()
    cudnn.benchmark = True

    model.cuda(gpu)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # Labeled subset of size partial_data * N; the remaining ids form the
        # unlabeled pool used by the semi-supervised branch.
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle needs binary mode on Python 3. pickle.load can execute
            # arbitrary code, so only point partial_id at trusted files.
            with open(args.partial_id, 'rb') as ids_file:
                train_ids = pickle.load(ids_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # Persist the split so the run can be reproduced via partial_id.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as ids_file:
            pickle.dump(train_ids, ids_file)

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling back to the crop size.
    # NOTE(review): size is (input_size[1], input_size[0]) == (w, h), while
    # nn.Upsample takes (H, W); this only matches (h, w) labels for square
    # crops — confirm against VOCDataSet.
    bce_loss = BCEWithLogitsLoss2d()
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):
        print("Iter:", i_iter)
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # Freeze D so the generator losses do not accumulate grads in it.
            for param in model_D.parameters():
                param.requires_grad = False

            # Semi-supervised term first.
            # NOTE(review): trainloader_remain only exists when
            # args.partial_data is set.
            if args.lambda_semi > 0 and i_iter >= args.semi_start:
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    # Unlabeled loader exhausted: start a new pass over it.
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # Only the images of unlabeled batches are used.
                images, _, _, _ = batch

                images = Variable(images).cuda(gpu)

                pred = interp(model(images))
                D_out = interp(model_D(F.softmax(pred)))

                # Per-pixel confidence that the prediction looks like
                # ground truth (D is trained towards gt_label == 1 on GT).
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)

                # Pixels where D is not confident enough are ignored.
                semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                semi_gt[semi_ignore_mask] = 255

                # Fraction of pixels kept as pseudo-labels.
                semi_ratio = 1.0 - float(
                    semi_ignore_mask.sum()) / semi_ignore_mask.size
                print('semi ratio: {:.4f}'.format(semi_ratio))

                # Skip the loss when every pixel was masked out.
                if semi_ratio != 0.0:
                    semi_gt = torch.FloatTensor(semi_gt)

                    loss_semi = args.lambda_semi * loss_calc(
                        pred, semi_gt, args.gpu)
                    loss_semi = loss_semi / args.iter_size
                    loss_semi.backward()
                    loss_semi_value += loss_semi.data.cpu().numpy(
                    )[0] / args.lambda_semi
            else:
                loss_semi = None

            # Supervised term on a labeled batch.
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch

            images = Variable(images).cuda(gpu)

            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)

            # Adversarial term: the generator wants D to output gt_label on
            # its predictions.
            D_out = interp(model_D(F.softmax(pred)))

            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # Divide by iter_size so accumulated gradients average over the
            # sub-iterations.
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy()[0] / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            )[0] / args.iter_size

            # ---- train D ----

            # Re-enable D's gradients.
            for param in model_D.parameters():
                param.requires_grad = True

            # Detached predictions, target pred_label; / 2 because the pred
            # and gt terms are averaged.
            pred = pred.detach()

            D_out = interp(model_D(F.softmax(pred)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

            # One-hot ground-truth maps, target gt_label.
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Example #13
0
def main():
    """Train the segmentation network with optional naive self-training.

    Each of the ``args.iter_size`` sub-iterations per optimizer step draws a
    labeled batch and back-propagates its segmentation loss. When
    ``args.lambda_semi > 0`` and ``i_iter >= args.semi_start``, it also draws
    an unlabeled batch. The network's own argmax predictions on that batch
    serve as pseudo-labels, with no confidence masking, weighted by
    ``args.lambda_semi``. No discriminator is trained in this variant.

    Configuration is read from the module-level ``args``. ``start`` (a
    ``timeit`` timestamp), ``IMG_MEAN`` and ``version`` are also expected at
    module level.

    Side effects: seeds NumPy's global RNG with ``args.random_seed``,
    creates ``args.snapshot_dir`` if needed, and writes ``train_id.pkl``
    there when ``args.partial_data`` is set. It also saves checkpoints whose
    names include ``args.lambda_semi`` and ``args.random_seed``.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    np.random.seed(args.random_seed)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Only copy parameters that exist in the current model with a matching
    # shape (caffe-like partial loading).
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # load dataset
    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    # NOTE(review): the GT dataset/loader is built but never consumed in
    # this variant; it is kept so loader/RNG setup matches the other scripts.
    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # Labeled subset of size partial_data * N; the remaining ids form the
        # unlabeled pool.
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle needs binary mode on Python 3. pickle.load can execute
            # arbitrary code, so only point partial_id at trusted files.
            with open(args.partial_id, 'rb') as ids_file:
                train_ids = pickle.load(ids_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        # Persist the split so the run can be reproduced via partial_id.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as ids_file:
            pickle.dump(train_ids, ids_file)

        # labeled data
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        # unlabeled data
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # Bilinear upsampling back to the crop size. From PyTorch 0.4.0 on,
    # align_corners=True reproduces the bilinear behaviour of earlier
    # releases (which lacked the argument).
    # NOTE(review): size is (input_size[1], input_size[0]) == (w, h), while
    # nn.Upsample takes (H, W); identical only for square crops — confirm.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    for i_iter in range(args.num_steps):
        loss_seg_value = 0
        loss_unlabeled_seg_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):
            # Supervised step on a labeled batch. next() works on Python 3;
            # the enumerate object has no .next() method there.
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                # Labeled loader exhausted: start a new pass over it.
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            # computing loss
            loss_seg = loss_calc(pred, labels, args.gpu)

            # Divide by iter_size so accumulated gradients average over the
            # sub-iterations.
            loss = loss_seg / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size

            if args.lambda_semi > 0 and i_iter >= args.semi_start:
                # Self-training step on an unlabeled batch.
                # NOTE(review): trainloader_remain only exists when
                # args.partial_data is set.
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # Only the images of unlabeled batches are used.
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))
                # Hard pseudo-labels: argmax over axis 1 (presumably the
                # class channel).
                semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                semi_gt = torch.FloatTensor(semi_gt)

                loss_unlabeled_seg = args.lambda_semi * loss_calc(
                    pred, semi_gt, args.gpu)
                loss_unlabeled_seg = loss_unlabeled_seg / args.iter_size
                loss_unlabeled_seg.backward()

                loss_unlabeled_seg_value += loss_unlabeled_seg.data.cpu(
                ).numpy() / args.lambda_semi
            elif args.lambda_semi > 0:
                # Self-training enabled but not started yet.
                loss_unlabeled_seg_value = 0
            else:
                # Self-training disabled.
                loss_unlabeled_seg_value = None

        optimizer.step()

        # None cannot be formatted with '.3f'; print it verbatim instead.
        if loss_unlabeled_seg_value is None:
            unlabeled_str = 'None'
        else:
            unlabeled_str = '{:.3f}'.format(loss_unlabeled_seg_value)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_unlabeled_seg = {3} '
            .format(i_iter, args.num_steps, loss_seg_value, unlabeled_str))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_' +
                    str(args.lambda_semi) + '_' + str(args.random_seed) +
                    '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'VOC_' + str(i_iter) + '_' + str(args.lambda_semi) + '_' +
                    str(args.random_seed) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Example #14
0
def main():
    """Adversarial semi-supervised training of the segmentation network.

    Each of the ``args.iter_size`` sub-iterations per optimizer step does
    the following:

    1. Unlabeled batch. This runs when ``args.lambda_semi > 0`` or
       ``args.lambda_semi_adv > 0``, and ``i_iter >= args.semi_start_adv``.
       An adversarial loss, weighted by ``args.lambda_semi_adv``, pushes the
       discriminator's output on the network's predictions towards
       ``gt_label``. From ``args.semi_start`` on (if ``args.lambda_semi > 0``),
       pixels where the discriminator's sigmoid output is at least
       ``args.mask_T`` also become argmax pseudo-labels, weighted by
       ``args.lambda_semi``.
    2. Labeled batch. The loss is the segmentation loss plus
       ``args.lambda_adv_pred`` times the adversarial loss.
    3. Discriminator. It is trained towards ``pred_label`` on detached
       predictions (also on the unlabeled ones when ``args.D_remain`` is set
       and they were computed in this sub-iteration). It is trained towards
       ``gt_label`` on one-hot ground-truth maps.

    Configuration is read from the module-level ``args``. ``start``,
    ``IMG_MEAN`` and ``version`` are also expected at module level.

    Side effects:

    * creates ``args.snapshot_dir`` and ``logs/``;
    * appends one line of losses per iteration to ``logs/<timestamp>.txt``;
    * writes ``train_id.pkl`` when ``args.partial_data`` is set;
    * saves a checkpoint pair whenever the segmentation loss hits a new
      minimum, and calls ``delete_models`` for the previous best;
    * saves periodic and final checkpoints.
    """
    h, w = list(map(int, args.input_size.split(',')))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Only copy parameters that exist in the current model with a matching
    # shape (caffe-like partial loading).
    new_params = model.state_dict().copy()
    for name, param in list(new_params.items()):
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D (fully-convolutional discriminator)
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))

    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    if not os.path.exists('logs/'):
        os.makedirs('logs/')
    now_time = datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    log_file = 'logs/' + now_time + '.txt'
    # Loss log: one space-separated line per iteration, closed after training.
    loss_log = open(log_file, 'w')

    train_dataset = VOCDataSet(args.data_dir, args.data_list, args.label_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, args.label_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    if args.partial_data is None:
        # Use the whole training set as labeled data.
        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=5, pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=args.batch_size, shuffle=True,
                                         num_workers=5, pin_memory=True)
    else:
        # Labeled subset of size partial_data * N; the remaining ids form the
        # unlabeled pool.
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle needs binary mode on Python 3. pickle.load can execute
            # arbitrary code, so only point partial_id at trusted files.
            with open(args.partial_id, 'rb') as ids_file:
                train_ids = pickle.load(ids_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # Persist the split so the run can be reproduced via partial_id.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as ids_file:
            pickle.dump(train_ids, ids_file)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                        batch_size=args.batch_size, sampler=train_sampler, num_workers=3, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                        batch_size=args.batch_size, sampler=train_remain_sampler, num_workers=3, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                        batch_size=args.batch_size, sampler=train_gt_sampler, num_workers=3, pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling back to the crop size.
    # NOTE(review): size is (input_size[1], input_size[0]) == (w, h), while
    # nn.Upsample takes (H, W); identical only for square crops — confirm.
    bce_loss = BCEWithLogitsLoss2d()

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1
    best_loss = 1
    best_epoch = 0

    for i_iter in range(args.num_steps):
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):
            # ---- train G ----

            # Freeze D so the generator losses do not accumulate grads in it.
            for param in model_D.parameters():
                param.requires_grad = False

            # Reset every sub-iteration: D_remain must only reuse unlabeled
            # predictions that were actually produced in this step.
            pred_remain = None
            ignore_mask_remain = None

            # Unlabeled (semi-supervised) terms first.
            # NOTE(review): trainloader_remain only exists when
            # args.partial_data is set.
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0) and i_iter >= args.semi_start_adv:
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    # Unlabeled loader exhausted: start a new pass over it.
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # Only the images of unlabeled batches are used.
                images, _, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))
                # Gradient-free copy, reused for the D step when D_remain.
                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred, dim=1)))
                # Per-pixel confidence that the prediction looks like GT.
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                # No pixel is ignored in the adversarial loss on unlabeled data.
                # (np.bool was removed in NumPy 1.24; plain bool is equivalent.)
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape, dtype=bool)

                loss_semi_adv_raw = bce_loss(D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = args.lambda_semi_adv * loss_semi_adv_raw / args.iter_size

                # Log the unweighted term directly: dividing the weighted loss
                # by lambda_semi_adv would give 0/0 = NaN when it is 0.
                loss_semi_adv_value += loss_semi_adv_raw.data.cpu().numpy() / args.iter_size

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    # Adversarial term only.
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # Pixels where D is not confident enough are ignored.
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                    # Hard pseudo-labels from the per-pixel argmax.
                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    # Fraction of pixels kept as pseudo-labels.
                    semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        # No confident pixel: still apply the adversarial
                        # term, which is logged above either way.
                        loss_semi_adv.backward()
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(pred, semi_gt, args.gpu)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy() / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            else:
                loss_semi = None
                loss_semi_adv = None

            # Supervised term on a labeled batch.
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            # Segmentation loss on labeled data.
            loss_seg = loss_calc(pred, labels, args.gpu)

            # Adversarial term: the generator wants D to output gt_label on
            # its predictions.
            D_out = interp(model_D(F.softmax(pred, dim=1)))

            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # Divide by iter_size so accumulated gradients average over the
            # sub-iterations.
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy() / args.iter_size

            # ---- train D ----

            # Re-enable D's gradients.
            for param in model_D.parameters():
                param.requires_grad = True

            # Detached predictions, target pred_label; / 2 because the pred
            # and gt terms are averaged.
            pred = pred.detach()

            if args.D_remain and pred_remain is not None:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # One-hot ground-truth maps, target gt_label.
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, '
              'loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'.format(i_iter, args.num_steps, loss_seg_value,
                                                                    loss_adv_pred_value, loss_D_value,
                                                                    loss_semi_value, loss_semi_adv_value))

        loss_log.write('{0} {1} {2} {3} {4}\n'.format(loss_seg_value, loss_adv_pred_value, loss_D_value,
                                                      loss_semi_value, loss_semi_adv_value))

        # Keep the checkpoint pair with the lowest segmentation loss so far.
        # delete_models (defined elsewhere) receives the previous best's
        # 1-based epoch and loss, presumably to remove its files.
        if loss_seg_value < best_loss:
            torch.save(model.state_dict(), osp.join(args.snapshot_dir,
                                                    'VOC_epoch_{0}_seg_loss_{1}.pth'.format(i_iter + 1, loss_seg_value)))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir,
                                                      'VOC_epoch_{0}_seg_loss_{1}_D.pth'.format(i_iter + 1, loss_seg_value)))
            delete_models(best_epoch + 1, best_loss)
            best_loss = loss_seg_value
            best_epoch = i_iter

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
    loss_log.close()
Example #15
0
def main():
    """Train a segmentation network adversarially against a discriminator on CPU.

    All hyper-parameters come from the module-level ``args`` namespace.
    Each of the ``args.num_steps`` iterations accumulates gradients over
    ``args.iter_size`` sub-batches, then steps both optimizers:

    * the segmentation network (``DeepLab``) is trained with its segmentation
      loss plus ``args.lambda_adv_pred`` times an adversarial loss that pushes
      the discriminator to label its (softmaxed) predictions as ``gt_label``;
    * the discriminator is trained to label detached predictions as
      ``pred_label`` and one-hot ground-truth maps as ``gt_label``. Each half is
      weighted by ``1 / (2 * args.iter_size)``.

    The semi-supervised branch is commented out, so ``loss_semi`` and
    ``loss_semi_adv`` are always reported as 0.

    Snapshots of both networks are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations (excluding iteration 0) and after the
    final iteration.

    NOTE(review): ``start`` (used for the final elapsed-time print) is not
    defined in this function. It is assumed to be a module-level
    ``timeit.default_timer()`` value; confirm it exists.
    """

    # parse input size: args.input_size is the string "h,w"
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    # cudnn.enabled = True
    # gpu = args.gpu

    # create segmentation network
    model = DeepLab(num_classes=args.num_classes)

    # load pretrained parameters
    # if args.restore_from[:4] == 'http' :
    #     saved_state_dict = model_zoo.load_url(args.restore_from)
    # else:
    #     saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    # new_params = model.state_dict().copy()
    # for name, param in new_params.items():
    #     if name in saved_state_dict and param.size() == saved_state_dict[name].size():
    #         new_params[name].copy_(saved_state_dict[name])
    # model.load_state_dict(new_params)

    model.train()
    model.cpu()
    # model.cuda(args.gpu)
    # cudnn.benchmark = True

    # create discriminator network
    model_D = Discriminator(num_classes=args.num_classes)
    # if args.restore_from_D is not None:
    #     model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cpu()
    # model_D.cuda(args.gpu)

    # MILESTONE 1
    print("Printing MODELS ...")
    print(model)
    print(model_D)

    # Create directory to save snapshots of the model
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Load train data and ground truth labels
    # train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
    #                 scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)
    # train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
    #                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    # trainloader = data.DataLoader(train_dataset,
    #                 batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=False)
    # trainloader_gt = data.DataLoader(train_gt_dataset,
    #                 batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=False)

    train_dataset = MyCustomDataset()
    train_gt_dataset = MyCustomDataset()

    trainloader = data.DataLoader(train_dataset, batch_size=5, shuffle=True)
    trainloader_gt = data.DataLoader(train_gt_dataset,
                                     batch_size=5,
                                     shuffle=True)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # MILESTONE 2
    print("Printing Loaders")
    print(trainloader_iter)
    print(trainloader_gt_iter)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # MILESTONE 3
    print("Printing OPTIMIZERS ...")
    print(optimizer)
    print(optimizer_D)

    # loss / bilinear upsampling.
    # nn.Upsample expects size=(height, width); input_size is already (h, w).
    # Swapping the two would only line up with the labels for square crops.
    bce_loss = BCEWithLogitsLoss2d()
    interp = nn.Upsample(size=input_size,
                         mode='bilinear',
                         align_corners=True)

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # Gradients are accumulated over iter_size sub-batches. Every loss
        # is divided by iter_size so one optimizer step sees their average.
        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            # if (args.lambda_semi > 0 or args.lambda_semi_adv > 0 ) and i_iter >= args.semi_start_adv :
            #     try:
            #         _, batch = next(trainloader_remain_iter)
            #     except:
            #         trainloader_remain_iter = enumerate(trainloader_remain)
            #         _, batch = next(trainloader_remain_iter)

            #     # only access to img
            #     images, _, _, _ = batch
            #     images = Variable(images).cuda(args.gpu)

            #     pred = interp(model(images))
            #     pred_remain = pred.detach()

            #     D_out = interp(model_D(F.softmax(pred)))
            #     D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

            #     ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(np.bool)

            #     loss_semi_adv = args.lambda_semi_adv * bce_loss(D_out, make_D_label(gt_label, ignore_mask_remain))
            #     loss_semi_adv = loss_semi_adv/args.iter_size

            #     #loss_semi_adv.backward()
            #     loss_semi_adv_value += loss_semi_adv.data.cpu().numpy()/args.lambda_semi_adv

            #     if args.lambda_semi <= 0 or i_iter < args.semi_start:
            #         loss_semi_adv.backward()
            #         loss_semi_value = 0
            #     else:
            #         # produce ignore mask
            #         semi_ignore_mask = (D_out_sigmoid < args.mask_T)

            #         semi_gt = pred.data.cpu().numpy().argmax(axis=1)
            #         semi_gt[semi_ignore_mask] = 255

            #         semi_ratio = 1.0 - float(semi_ignore_mask.sum())/semi_ignore_mask.size
            #         print('semi ratio: {:.4f}'.format(semi_ratio))

            #         if semi_ratio == 0.0:
            #             loss_semi_value += 0
            #         else:
            #             semi_gt = torch.FloatTensor(semi_gt)

            #             loss_semi = args.lambda_semi * loss_calc(pred, semi_gt, args.gpu)
            #             loss_semi = loss_semi/args.iter_size
            #             loss_semi_value += loss_semi.data.cpu().numpy()/args.lambda_semi
            #             loss_semi += loss_semi_adv
            #             loss_semi.backward()

            # else:
            #     loss_semi = None
            #     loss_semi_adv = None

            # train with source.
            # Restart the loader only when it is exhausted. Any other error
            # raised while loading a batch must propagate, not be masked.
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cpu()
            # images = Variable(images).cuda(args.gpu)
            # pixels labelled 255 are excluded from the discriminator target
            ignore_mask = (labels.numpy() == 255)

            # segmentation prediction, upsampled to the input resolution
            pred = interp(model(images))
            # (spatial multi-class) cross entropy loss
            loss_seg = loss_calc(pred, labels)
            # loss_seg = loss_calc(pred, labels, args.gpu)

            # discriminator prediction on the per-pixel class probabilities.
            # dim=1 is the dimension F.softmax implicitly picks for 4-D input.
            D_out = interp(model_D(F.softmax(pred, dim=1)))
            # adversarial loss: G wants D to think its prediction is ground truth
            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            # multi-task loss
            # lambda_adv - weight for minimizing loss
            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # loss normalization
            loss = loss / args.iter_size

            # back propagation
            loss.backward()

            loss_seg_value += loss_seg.item() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred; detach so D's loss does not backprop into G
            pred = pred.detach()

            # if args.D_remain:
            #     pred = torch.cat((pred, pred_remain), 0)
            #     ignore_mask = np.concatenate((ignore_mask,ignore_mask_remain), axis = 0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            # halved because D's loss is split between pred and gt halves
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

            # train with gt
            # get gt labels (restart the gt loader only when it is exhausted)
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cpu()
            # D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        # final snapshot, then stop
        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        # periodic snapshot (skipping iteration 0)
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')