示例#1
0
    def __init__(self, model_path):
        """Build a UBR_VGG model and restore its weights from *model_path*."""
        # Instantiate the network and materialize its layers before the
        # state dict is loaded into it.
        self.UBR = UBR_VGG()
        self.UBR.create_architecture()

        print("loading checkpoint %s" % (model_path))
        ckpt = torch.load(model_path)
        self.UBR.load_state_dict(ckpt['model'])

        # Inference-only wrapper: move to GPU and switch off train-mode
        # behavior (dropout / batch-norm updates).
        self.UBR.cuda()
        self.UBR.eval()
示例#2
0
File: ubr_wrapper.py  Project: djkim3/UBR
    def __init__(self, model_path):
        """Restore a box-refinement network from *model_path*.

        The checkpoint's ``'net'`` field selects which architecture to
        instantiate. An unrecognized value now raises ``ValueError``
        immediately instead of leaving ``self.UBR`` unset and failing
        later with a confusing ``AttributeError``.
        """
        print("loading checkpoint %s" % (model_path))
        checkpoint = torch.load(model_path)
        net = checkpoint['net']
        if net == 'UBR_VGG':
            self.UBR = UBR_VGG(None, False, True, True)
        elif net == 'UBR_RES':
            # NOTE(review): relies on a module-global `args` being in
            # scope here -- confirm against the surrounding file.
            self.UBR = UBR_RES(None, 1, not args.fc)
        elif net == 'UBR_RES_FC2':
            self.UBR = UBR_RES_FC2(None, 1)
        elif net == 'UBR_RES_FC3':
            self.UBR = UBR_RES_FC3(None, 1)
        else:
            raise ValueError('unknown network type in checkpoint: %s' % net)

        self.UBR.create_architecture()
        self.UBR.load_state_dict(checkpoint['model'])

        # Inference only: GPU + eval mode.
        self.UBR.cuda()
        self.UBR.eval()
示例#3
0
class UBRWrapper:
    """Convenience wrapper around a trained UBR_VGG box-refinement model."""

    def __init__(self, model_path):
        """Build the network and load trained weights from *model_path*."""
        self.UBR = UBR_VGG()
        self.UBR.create_architecture()
        print("loading checkpoint %s" % (model_path))
        checkpoint = torch.load(model_path)
        self.UBR.load_state_dict(checkpoint['model'])

        self.UBR.cuda()
        self.UBR.eval()

    # raw_img = h * w * 3 rgb image
    # bbox = n * 4 bounding boxes
    # return n * 4 refined boxes
    def query(self, raw_img, bbox):
        """Refine *bbox* on *raw_img*; returns an (n, 4) numpy array."""
        data, rois, im_scale = preprocess(raw_img, bbox)
        # The network expects rois of shape (n, 5): column 0 is the batch
        # index (left at zero here), columns 1:5 are [xmin, ymin, xmax, ymax].
        new_rois = torch.zeros((bbox.shape[0], 5))
        new_rois[:, 1:] = rois[:, :]
        rois = new_rois
        data = Variable(data.unsqueeze(0).cuda())
        rois = Variable(rois.cuda())
        bbox_pred = self.UBR(data, rois)
        refined_boxes = inverse_transform(rois[:, 1:].data.cpu(),
                                          bbox_pred.data.cpu())
        # Undo the preprocessing scale so boxes are in raw-image pixels.
        refined_boxes /= im_scale
        ret = np.zeros((refined_boxes.size(0), 4))
        # Clamp x coords (cols 0, 2) to [0, width-1] and y coords
        # (cols 1, 3) to [0, height-1]; replaces four copy-pasted blocks.
        for col in range(4):
            limit = raw_img.shape[1 - col % 2] - 1
            ret[:, col] = refined_boxes[:, col].clamp(min=0,
                                                      max=limit).numpy()

        return ret
示例#4
0
def train():
    """Train a UBR box-refinement network.

    Reads all configuration from command-line arguments, builds the model
    selected by ``args.net``, optionally resumes from a checkpoint, then
    runs the epoch/step loop with periodic logging, validation, learning
    rate decay (fixed-step or patience-based) and checkpointing.

    BUG FIX: removed a leftover debug block (``plt.imshow(raw_img);
    plt.show(); continue``) that displayed every image and skipped the
    entire optimization step, so no training ever happened.
    """
    args = parse_args()

    print('Called with args:')
    print(args)
    # Fixed seeds so runs are repeatable.
    np.random.seed(4)
    torch.manual_seed(2017)
    torch.cuda.manual_seed(1086)

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    train_dataset = TDetDataset([args.dataset + '_train'],
                                training=True,
                                multi_scale=args.multiscale,
                                rotation=args.rotation,
                                pd=args.pd,
                                warping=args.warping,
                                prop_method=args.prop_method,
                                prop_min_scale=args.prop_min_scale,
                                prop_topk=args.prop_topk)
    val_dataset = TDetDataset([args.dataset + '_val'], training=False)
    tval_dataset = TDetDataset(['coco_voc_val'], training=False)

    lr = args.lr

    res_path = 'data/pretrained_model/resnet101_caffe.pth'
    vgg_path = 'data/pretrained_model/vgg16_caffe.pth'
    if args.net == 'UBR_VGG':
        UBR = UBR_VGG(vgg_path, not args.fc, not args.not_freeze,
                      args.no_dropout)
    elif args.net == 'UBR_RES':
        UBR = UBR_RES(res_path, 1, not args.fc)
    elif args.net == 'UBR_RES_FC2':
        UBR = UBR_RES_FC2(res_path, 1)
    elif args.net == 'UBR_RES_FC3':
        UBR = UBR_RES_FC3(res_path, 1)
    else:
        print("network is not defined")
        pdb.set_trace()

    UBR.create_architecture()

    # Per-parameter options: biases get 2x learning rate and no weight
    # decay; other weights get the base lr and (optionally) 5e-4 decay.
    params = []
    for key, value in dict(UBR.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params += [{
                    'params': [value],
                    'lr': lr * 2,
                    'weight_decay': 0
                }]
            else:
                params += [{
                    'params': [value],
                    'lr': lr,
                    'weight_decay': 0 if args.no_wd else 0.0005
                }]

    optimizer = torch.optim.SGD(params, momentum=0.9)

    # State for patience-based (auto) learning-rate decay.
    patience = 0
    last_optima = 999
    if args.resume:
        load_name = os.path.join(
            output_dir, '{}_{}_{}.pth'.format(args.net, args.checksession,
                                              args.checkepoch))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_epoch = checkpoint['epoch']
        UBR.load_state_dict(checkpoint['model'])
        if not args.no_optim:
            if 'patience' in checkpoint:
                patience = checkpoint['patience']
            if 'last_optima' in checkpoint:
                last_optima = checkpoint['last_optima']
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr = optimizer.param_groups[0]['lr']
        print("loaded checkpoint %s" % (load_name))

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    # Append to the existing log when resuming, otherwise start fresh.
    if args.resume:
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    UBR.cuda()

    if args.loss == 'smoothl1':
        criterion = UBR_SmoothL1Loss(args.iou_th)
    elif args.loss == 'iou':
        criterion = UBR_IoULoss(args.iou_th)

    if not args.use_prop:
        random_box_generator = NaturalUniformBoxGenerator(
            args.iou_th,
            pos_th=args.alpha,
            scale_min=1 - args.beta,
            scale_max=1 + args.beta)

    for epoch in range(args.start_epoch, args.max_epochs + 1):
        # setting to train mode
        UBR.train()
        loss_temp = 0

        effective_iteration = 0
        start = time.time()

        mean_boxes_per_iter = 0
        rand_perm = np.random.permutation(len(train_dataset))
        for step in range(1, len(train_dataset) + 1):
            index = rand_perm[step - 1]
            im_data, gt_boxes, box_labels, proposals, prop_scores, \
                image_level_label, im_scale, raw_img, im_id, _ = \
                train_dataset[index]

            data_height = im_data.size(1)
            data_width = im_data.size(2)
            im_data = Variable(im_data.unsqueeze(0).cuda())
            num_gt_box = gt_boxes.size(0)
            UBR.zero_grad()

            # generate random box from given gt box
            # the shape of rois is (n, 5), the first column is not used
            # so, rois[:, 1:5] is [xmin, ymin, xmax, ymax]
            num_per_base = 50
            if num_gt_box > 4:
                num_per_base = 200 // num_gt_box

            if args.use_prop:
                proposals = sample_pos_prop(proposals, gt_boxes, args.iou_th)
                if proposals is None:
                    continue

                rois = torch.zeros((proposals.size(0), 5))
                rois[:, 1:] = proposals
            else:
                rois = torch.zeros((num_per_base * num_gt_box, 5))
                cnt = 0
                for i in range(num_gt_box):
                    here = random_box_generator.get_rand_boxes(
                        gt_boxes[i, :], num_per_base, data_height, data_width)
                    if here is None:
                        continue
                    rois[cnt:cnt + here.size(0), :] = here
                    cnt += here.size(0)
                if cnt == 0:
                    log_file.write('@@@@ no box @@@@\n')
                    print('@@@@@ no box @@@@@')
                    continue
                rois = rois[:cnt, :]

            mean_boxes_per_iter += rois.size(0)
            rois = Variable(rois.cuda())
            gt_boxes = Variable(gt_boxes.cuda())

            bbox_pred, shared_feat = UBR(im_data, rois)

            loss, num_selected_rois, num_rois, refined_rois = criterion(
                rois[:, 1:5], bbox_pred, gt_boxes)

            # criterion returns None when no roi matched any gt box; use a
            # sentinel so the logged average makes the failure visible.
            if loss is None:
                loss_temp = 1000000
                loss = Variable(torch.zeros(1).cuda())
                print('zero matched')

            loss = loss.mean()
            loss_temp += loss.data[0]

            # backward
            optimizer.zero_grad()

            loss.backward()
            clip_gradient([UBR], 10.0)

            optimizer.step()
            effective_iteration += 1

            if step % args.disp_interval == 0:
                end = time.time()
                # Guard against division by zero when every iteration in
                # this display window was skipped ("no box" continue above).
                denom = max(effective_iteration, 1)
                loss_temp /= denom
                mean_boxes_per_iter /= denom

                print(
                    "[net %s][session %d][epoch %2d][iter %4d] loss: %.4f, lr: %.2e, time: %.1f, boxes: %.1f"
                    % (args.net, args.session, epoch, step, loss_temp, lr,
                       end - start, mean_boxes_per_iter))
                log_file.write(
                    "[net %s][session %d][epoch %2d][iter %4d] loss: %.4f, lr: %.2e, time: %.1f, boxes: %.1f\n"
                    % (args.net, args.session, epoch, step, loss_temp, lr,
                       end - start, mean_boxes_per_iter))
                loss_temp = 0
                effective_iteration = 0
                mean_boxes_per_iter = 0
                start = time.time()

            if math.isnan(loss_temp):
                print('@@@@@@@@@@@@@@nan@@@@@@@@@@@@@')
                log_file.write('@@@@@@@nan@@@@@@@@\n')
                return

        # End-of-epoch validation on the in-domain and transfer sets.
        val_loss = validate(UBR,
                            None if args.use_prop else random_box_generator,
                            criterion, val_dataset, args)
        tval_loss = validate(UBR,
                             None if args.use_prop else random_box_generator,
                             criterion, tval_dataset, args)
        print('[net %s][session %d][epoch %2d] validation loss: %.4f' %
              (args.net, args.session, epoch, val_loss))
        log_file.write(
            '[net %s][session %d][epoch %2d] validation loss: %.4f\n' %
            (args.net, args.session, epoch, val_loss))
        print(
            '[net %s][session %d][epoch %2d] transfer validation loss: %.4f' %
            (args.net, args.session, epoch, tval_loss))
        log_file.write(
            '[net %s][session %d][epoch %2d] transfer validation loss: %.4f\n'
            % (args.net, args.session, epoch, tval_loss))

        log_file.flush()

        # LR decay: patience-based on validation loss, or fixed-step.
        if args.auto_decay:
            if last_optima - val_loss < 0.001:
                patience += 1
            if last_optima > val_loss:
                last_optima = val_loss

            if patience >= args.decay_patience:
                adjust_learning_rate(optimizer, args.lr_decay_gamma)
                lr *= args.lr_decay_gamma
                patience = 0
        else:
            if epoch % args.lr_decay_step == 0:
                adjust_learning_rate(optimizer, args.lr_decay_gamma)
                lr *= args.lr_decay_gamma

        if epoch % args.save_interval == 0 or lr < 0.000005:
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  epoch))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['epoch'] = epoch + 1
            checkpoint['model'] = UBR.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            checkpoint['patience'] = patience
            checkpoint['last_optima'] = last_optima

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

        # Stop entirely once the learning rate has decayed to (near) zero.
        if lr < 0.000005:
            break

    log_file.close()
示例#5
0
    # NOTE(review): this fragment starts mid-function -- its `def` line and
    # the construction of `args` lie outside the visible region.
    np.random.seed(10)

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # VOC2007 test split, evaluated one image at a time, in fixed order.
    dataset = VOCDetection('./data/VOCdevkit2007', [('2007', 'test')])
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=args.num_workers, shuffle=False)

    load_name = os.path.abspath(args.model_path)
    print("loading checkpoint %s" % (load_name))
    checkpoint = torch.load(load_name)

    # Initialize the network named in the checkpoint.
    if checkpoint['net'] == 'UBR_VGG':
        UBR = UBR_VGG()
    else:
        print("network is not defined")
        pdb.set_trace()

    UBR.create_architecture()
    UBR.load_state_dict(checkpoint['model'])
    print("loaded checkpoint %s" % (load_name))

    # Result file is keyed by net/session/epoch; checkpoints store
    # epoch + 1 (the epoch to resume from), hence the - 1 here.
    output_file_name = os.path.join(output_dir, 'eval_{}_{}_{}.txt'.format(checkpoint['net'], checkpoint['session'], checkpoint['epoch'] - 1))
    output_file = open(output_file_name, 'w')
    output_file.write(str(args))
    output_file.write('\n')

    # Evaluation only: GPU + eval mode.
    UBR.cuda()
    UBR.eval()
示例#6
0
    # NOTE(review): fragment -- the enclosing `def` line is outside the
    # visible region, and the parameter-group loop below is cut off
    # mid-dictionary.
    output_dir = args.save_dir + "/" + args.net
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # COCO train2014 annotations with the VOC categories removed.
    dataset = COCODataset(
        './data/coco/annotations/instances_train2014_subtract_voc.json',
        './data/coco/images/train2014/',
        training=True)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             num_workers=args.num_workers,
                                             shuffle=True)

    # Initialize the network selected on the command line.
    if args.net == 'vgg16':
        UBR = UBR_VGG()
    else:
        print("network is not defined")
        pdb.set_trace()

    UBR.create_architecture()

    lr = args.lr

    # Per-parameter optimizer options: biases get double learning rate.
    params = []
    for key, value in dict(UBR.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params += [{
                    'params': [value],
                    'lr': lr * 2,
示例#7
0
def extract_feature():
    """Extract pooled RoI features (with class labels) from a trained
    UBR_VGG model over the transfer-validation set and pickle them.

    BUG FIX: previously, when ``get_rand_boxes`` returned ``None`` for a
    gt box, no entry was appended to ``cnt_per_base``, so the later
    ``enumerate(cnt_per_base)`` index no longer lined up with
    ``gt_labels`` and features were saved with the wrong labels. A zero
    count is now recorded to keep the alignment.

    NOTE(review): only ``tval_dataloader`` is iterated; the train/val
    dataloaders built below are unused here -- presumably kept for
    symmetry with sibling scripts. The dump filename is hard-coded.
    """
    args = parse_args()

    print('Called with args:')
    print(args)
    # Fixed seeds for reproducibility.
    np.random.seed(3)
    torch.manual_seed(2016)
    torch.cuda.manual_seed(1085)

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Hard-coded annotation files override whatever was parsed from argv.
    args.train_anno = './data/coco/annotations/coco60_train_21413_61353.json'
    args.val_anno = './data/coco/annotations/coco60_val_900_2575.json'
    args.tval_anno = './data/coco/annotations/voc20_val_740_2844.json'

    train_dataset = COCODataset(args.train_anno,
                                args.train_images,
                                training=True,
                                multi_scale=False)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        num_workers=args.num_workers,
        shuffle=True)
    val_dataset = COCODataset(args.val_anno,
                              args.val_images,
                              training=False,
                              multi_scale=False)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=1,
                                                 num_workers=args.num_workers,
                                                 shuffle=False)
    tval_dataset = COCODataset(args.tval_anno,
                               args.val_images,
                               training=False,
                               multi_scale=False)
    tval_dataloader = torch.utils.data.DataLoader(tval_dataset,
                                                  batch_size=1,
                                                  num_workers=args.num_workers,
                                                  shuffle=False)

    load_name = os.path.abspath(args.model_path)
    print("loading checkpoint %s" % (load_name))
    checkpoint = torch.load(load_name)

    # Initialize the network named in the checkpoint.
    if checkpoint['net'] == 'UBR_VGG':
        UBR = UBR_VGG()
    else:
        print("network is not defined")
        pdb.set_trace()

    UBR.create_architecture()
    UBR.load_state_dict(checkpoint['model'])
    print("loaded checkpoint %s" % (load_name))

    UBR.cuda()
    UBR.eval()

    # One jittered box per gt box, constrained by the 0.5 IoU threshold.
    random_box_generator = UniformBoxGenerator(0.5)
    extracted_features = []
    feature_labels = []

    data_iter = iter(tval_dataloader)
    for step in range(1, len(tval_dataset) + 1):
        im_data, gt_boxes, gt_labels, data_height, data_width, im_scale, raw_img, im_id = next(
            data_iter)
        # Strip the batch dimension added by the dataloader (batch_size=1).
        raw_img = raw_img.squeeze().numpy()
        gt_labels = gt_labels[0, :]
        gt_boxes = gt_boxes[0, :, :]
        data_height = data_height[0]
        data_width = data_width[0]
        im_scale = im_scale[0]
        im_id = im_id[0]
        im_data = Variable(im_data.cuda())
        num_gt_box = gt_boxes.size(0)
        UBR.zero_grad()

        # generate random box from given gt box
        # the shape of rois is (n, 5), the first column is not used
        # so, rois[:, 1:5] is [xmin, ymin, xmax, ymax]
        num_per_base = 1

        rois = torch.zeros((num_per_base * num_gt_box, 5))
        cnt = 0
        # cnt_per_base[i] = number of boxes generated for gt box i. A zero
        # is recorded on failure so indices stay aligned with gt_labels.
        cnt_per_base = []
        for i in range(num_gt_box):
            here = random_box_generator.get_rand_boxes(gt_boxes[i, :],
                                                       num_per_base,
                                                       data_height, data_width)
            if here is None:
                cnt_per_base.append(0)
                continue
            rois[cnt:cnt + here.size(0), :] = here
            cnt += here.size(0)
            cnt_per_base.append(here.size(0))
        if cnt == 0:
            print('@@@@@ no box @@@@@')
            continue
        rois = rois[:cnt, :]
        rois = Variable(rois.cuda())

        # Only the pooled feature is used here; bbox_pred is discarded.
        bbox_pred, shared_feat = UBR(im_data, rois)
        shared_feat = shared_feat.view(rois.size(0), -1)

        # Flatten the per-gt-box feature groups into parallel
        # feature/label lists.
        begin_idx = 0
        for i, c in enumerate(cnt_per_base):
            label = gt_labels[i]
            for j in range(c):
                feat = shared_feat[begin_idx + j, :].data.cpu().numpy()
                extracted_features.append(feat)
                feature_labels.append(label)
            begin_idx += c

        print(step)

    # Use a context manager so the dump file is always closed.
    with open('cal_tval_pooled_features', 'wb') as dump_file:
        pickle.dump({
            'label': feature_labels,
            'feature': extracted_features
        }, dump_file)
示例#8
0
def train():
    """Adversarially align UBR features between a source domain (COCO-60)
    and a target domain (VOC07).

    The UBR backbone acts as the generator (optimizerG); discriminator D
    (optimizerD) tries to tell source features from target features. Each
    step samples one source and one target image, trains D on detached
    features, then trains UBR with a combined GAN + box-regression loss.
    The backward ordering below (D on detached features first, then the
    combined generator loss) is deliberate -- do not reorder.
    """
    args = parse_args()

    print('Called with args:')
    print(args)
    # Fixed seeds for reproducibility.
    np.random.seed(4)
    torch.manual_seed(2017)
    torch.cuda.manual_seed(1086)

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    source_train_dataset = TDetDataset(['coco60_train'], training=False)
    target_train_dataset = TDetDataset(['voc07_trainval'], training=False)
    val_dataset = TDetDataset(['coco60_val'], training=False)
    tval_dataset = TDetDataset(['coco_voc_val'], training=False)

    lr = args.lr

    # Build the generator (UBR backbone) selected on the command line.
    if args.net == 'UBR_VGG':
        UBR = UBR_VGG(None, not args.fc, not args.not_freeze, args.no_dropout)
    elif args.net == 'UBR_AUG':
        UBR = UBR_AUG(None, no_dropout=args.no_dropout)
    elif args.net == 'UBR_TANH0':
        UBR = UBR_TANH(0, None, not args.fc, not args.not_freeze, args.no_dropout)
    elif args.net == 'UBR_TANH1':
        UBR = UBR_TANH(1, None, not args.fc, not args.not_freeze, args.no_dropout)
    elif args.net == 'UBR_TANH2':
        UBR = UBR_TANH(2, None, not args.fc, not args.not_freeze, args.no_dropout)

    else:
        print("network is not defined")
        pdb.set_trace()

    D = BoxDiscriminator(args.dim)

    UBR.create_architecture()

    # Per-parameter options for the generator: biases get 2x lr, no decay.
    paramsG = []
    for key, value in dict(UBR.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                paramsG += [{'params': [value], 'lr': lr * 2, 'weight_decay': 0}]
            else:
                paramsG += [{'params': [value], 'lr': lr, 'weight_decay': 0}]

    # Same scheme for the discriminator.
    paramsD = []
    for key, value in dict(D.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                paramsD += [{'params': [value], 'lr': lr * 2, 'weight_decay': 0}]
            else:
                paramsD += [{'params': [value], 'lr': lr, 'weight_decay': 0}]

    if args.optim == 'sgd':
        optimizerG = torch.optim.SGD(paramsG, momentum=0.9)
        optimizerD = torch.optim.SGD(paramsD, momentum=0.9)
    elif args.optim == 'adam':
        optimizerG = torch.optim.Adam(paramsG, lr=lr, betas=(0.5, 0.9))
        optimizerD = torch.optim.Adam(paramsD, lr=lr, betas=(0.5, 0.9))

    # The generator always starts from a pretrained UBR checkpoint.
    load_name = args.pretrained_model
    print("loading checkpoint %s" % (load_name))
    checkpoint = torch.load(load_name)
    UBR.load_state_dict(checkpoint['model'])
    #optimizer.load_state_dict(checkpoint['optimizer'])
    print("loaded checkpoint %s" % (load_name))

    log_file_name = os.path.join(output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')


    UBR.cuda()
    D.cuda()

    # setting to train mode
    UBR.train()
    D.train()

    # Running sums for the display window, reset every disp_interval steps.
    lossG_temp = 0
    lossD_real_temp = 0
    lossD_fake_temp = 0
    lossD_temp = 0
    iou_loss_temp = 0
    effective_iteration = 0
    start = time.time()

    if args.loss == 'smoothl1':
        criterion = UBR_SmoothL1Loss(args.iou_th)
    elif args.loss == 'iou':
        criterion = UBR_IoULoss(args.iou_th)

    random_box_generator = NaturalUniformBoxGenerator(args.iou_th)

    for step in range(1, args.max_iter + 1):
        # One random image from each domain per step.
        src_idx = np.random.choice(len(source_train_dataset))
        tar_idx = np.random.choice(len(target_train_dataset))
        src_im_data, src_gt_boxes, _, _, src_im_scale, src_raw_img, src_im_id, _ = source_train_dataset[src_idx]
        tar_im_data, tar_gt_boxes, _, _, tar_im_scale, tar_raw_img, tar_im_id, _ = target_train_dataset[tar_idx]

        # generate random box from given gt box
        # the shape of rois is (n, 5), the first column is not used
        # so, rois[:, 1:5] is [xmin, ymin, xmax, ymax]
        num_src_gt = src_gt_boxes.size(0)
        num_per_base = 60 // num_src_gt
        src_rois = torch.zeros((num_per_base * num_src_gt, 5))
        cnt = 0
        for i in range(num_src_gt):
            here = random_box_generator.get_rand_boxes(src_gt_boxes[i, :], num_per_base, src_im_data.size(1), src_im_data.size(2))
            if here is None:
                continue
            src_rois[cnt:cnt + here.size(0), :] = here
            cnt += here.size(0)
        if cnt == 0:
            log_file.write('@@@@ no box @@@@\n')
            print('@@@@@ no box @@@@@')
            continue
        src_rois = src_rois[:cnt, :]
        src_rois = Variable(src_rois.cuda())

        # Same sampling for the target image.
        num_tar_gt = tar_gt_boxes.size(0)
        num_per_base = 60 // num_tar_gt
        tar_rois = torch.zeros((num_per_base * num_tar_gt, 5))
        cnt = 0
        for i in range(num_tar_gt):
            here = random_box_generator.get_rand_boxes(tar_gt_boxes[i, :], num_per_base, tar_im_data.size(1),
                                                       tar_im_data.size(2))
            if here is None:
                continue
            tar_rois[cnt:cnt + here.size(0), :] = here
            cnt += here.size(0)
        if cnt == 0:
            log_file.write('@@@@ no box @@@@\n')
            print('@@@@@ no box @@@@@')
            continue
        tar_rois = tar_rois[:cnt, :]
        tar_rois = Variable(tar_rois.cuda())

        ##############################################################################################
        # train D with real: detach() stops D's gradients from reaching UBR.
        optimizerD.zero_grad()
        src_im_data = Variable(src_im_data.unsqueeze(0).cuda())
        src_feat = UBR.get_tanh_feat(src_im_data, src_rois)
        output_real = D(src_feat.detach())
        label_real = Variable(torch.ones(output_real.size()).cuda())
        loss_real = F.binary_cross_entropy_with_logits(output_real, label_real) * args.alpha
        loss_real.backward()

        # train D with fake (target-domain features, also detached)
        tar_im_data = Variable(tar_im_data.unsqueeze(0).cuda())
        tar_feat = UBR.get_tanh_feat(tar_im_data, tar_rois)
        output_fake = D(tar_feat.detach())
        label_fake = Variable(torch.zeros(output_fake.size()).cuda())
        loss_fake = F.binary_cross_entropy_with_logits(output_fake, label_fake) * args.alpha
        loss_fake.backward()

        lossD_real_temp += loss_real.data[0]
        lossD_fake_temp += loss_fake.data[0]
        # lossD is logging-only; the two backward() calls above already
        # accumulated the discriminator gradients.
        lossD = loss_real + loss_fake
        clip_gradient([D], 10.0)
        optimizerD.step()
        #############################################################################################

        # train G: fool D into labelling target features as real
        # (tar_feat is NOT detached here, so gradients reach UBR).
        optimizerG.zero_grad()
        output = D(tar_feat)
        label_real = Variable(torch.ones(output.size()).cuda())
        lossG = F.binary_cross_entropy_with_logits(output, label_real)
        #lossG.backward()

        # train UBR: keep box-regression quality on the source domain
        bbox_pred = UBR.forward_with_tanh_feat(src_feat)
        src_gt_boxes = Variable(src_gt_boxes.cuda())
        iou_loss, num_selected_rois, num_rois, refined_rois = criterion(src_rois[:, 1:5], bbox_pred, src_gt_boxes)
        iou_loss = iou_loss.mean()
        #iou_loss.backward()

        # Single combined backward for the generator.
        loss = lossG * args.alpha + iou_loss
        loss.backward()
        clip_gradient([UBR], 10.0)
        optimizerG.step()
        ##############################################################################################

        effective_iteration += 1
        lossG_temp += lossG.data[0]
        lossD_temp += lossD.data[0]
        iou_loss_temp += iou_loss.data[0]

        if step % args.disp_interval == 0:
            end = time.time()
            lossG_temp /= effective_iteration
            lossD_temp /= effective_iteration
            lossD_fake_temp /= effective_iteration
            lossD_real_temp /= effective_iteration
            iou_loss_temp /= effective_iteration

            print("[net %s][session %d][iter %4d] iou_loss: %.4f, lossG: %.4f, lossD: %.4f, lr: %.2e, time: %.1f" %
                  (args.net, args.session, step, iou_loss_temp, lossG_temp, lossD_temp,  lr,  end - start))
            log_file.write("[net %s][session %d][iter %4d] iou_loss: %.4f, lossG: %.4f, lossD: %.4f, lr: %.2e, time: %.1f\n" %
                           (args.net, args.session, step, iou_loss_temp, lossG_temp, lossD_temp, lr,  end - start))

            #print('%f %f' % (lossD_real_temp, lossD_fake_temp))
            effective_iteration = 0
            lossG_temp = 0
            lossD_temp = 0
            lossD_real_temp = 0
            lossD_fake_temp = 0
            iou_loss_temp = 0
            start = time.time()

        if step % args.val_interval == 0:
            val_loss = validate(UBR, random_box_generator, criterion, val_dataset)
            tval_loss = validate(UBR, random_box_generator, criterion, tval_dataset)
            print('[net %s][session %d][step %2d] validation loss: %.4f' % (args.net, args.session, step, val_loss))
            log_file.write('[net %s][session %d][step %2d] validation loss: %.4f\n' % (args.net, args.session, step, val_loss))
            print('[net %s][session %d][step %2d] transfer validation loss: %.4f' % (args.net, args.session, step, tval_loss))
            log_file.write('[net %s][session %d][step %2d] transfer validation loss: %.4f\n' % (args.net, args.session, step, tval_loss))

            log_file.flush()

    log_file.close()
示例#9
0
def _make_param_groups(model, lr):
    """Build SGD per-parameter groups for *model*.

    Convention used throughout this file: bias parameters get double the
    learning rate and no weight decay; all other parameters get the base
    learning rate with 0.0005 weight decay.  Frozen parameters
    (``requires_grad == False``) are skipped.
    """
    groups = []
    for key, value in dict(model.named_parameters()).items():
        if not value.requires_grad:
            continue
        if 'bias' in key:
            groups.append({'params': [value], 'lr': lr * 2,
                           'weight_decay': 0})
        else:
            groups.append({'params': [value], 'lr': lr,
                           'weight_decay': 0.0005})
    return groups


def train():
    """Adversarially align a target-domain UBR_VGG to a frozen source model.

    A pretrained UBR_VGG checkpoint is loaded into two copies: a frozen
    ``source_model`` (eval mode) and a trainable ``target_model``.  Each step:

    1. the discriminator ``D`` is trained to classify the source image's
       global-average-pooled conv feature as real (label 1) and the target
       image's as fake (label 0);
    2. the target model (the "generator") is updated so its feature fools D.

    Losses are logged every ``args.disp_interval`` steps and the target model
    is validated with the box-refinement criterion every
    ``args.val_interval`` steps.
    """
    args = parse_args()

    print('Called with args:')
    print(args)
    # Fixed seeds for reproducibility across runs.
    np.random.seed(4)
    torch.manual_seed(2017)
    torch.cuda.manual_seed(1086)

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    source_train_dataset = TDetDataset(['coco60_train'], training=False)
    target_train_dataset = TDetDataset(['voc07_trainval'], training=False)
    val_dataset = TDetDataset(['coco60_val'], training=False)
    tval_dataset = TDetDataset(['coco_voc_val'], training=False)

    lr = args.lr

    source_model = UBR_VGG(None, not args.fc, not args.not_freeze,
                           args.no_dropout)
    target_model = UBR_VGG(None, not args.fc, not args.not_freeze,
                           args.no_dropout)
    D = Discriminator(512)

    source_model.create_architecture()
    target_model.create_architecture()

    # Identical per-parameter policy for generator and discriminator
    # (was duplicated inline twice).
    paramsG = _make_param_groups(target_model, lr)
    paramsD = _make_param_groups(D, lr)

    optimizerG = torch.optim.SGD(paramsG, momentum=0.9)
    optimizerD = torch.optim.SGD(paramsD, momentum=0.9)

    # Both models start from the same pretrained source-domain weights.
    load_name = args.pretrained_model
    print("loading checkpoint %s" % (load_name))
    checkpoint = torch.load(load_name)
    assert checkpoint['net'] == 'UBR_VGG'
    source_model.load_state_dict(checkpoint['model'])
    target_model.load_state_dict(checkpoint['model'])
    print("loaded checkpoint %s" % (load_name))

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    source_model.cuda()
    target_model.cuda()
    D.cuda()

    # Only the target model and D are trained; the source model stays frozen.
    target_model.train()
    source_model.eval()
    D.train()

    lossG_temp = 0
    lossD_real_temp = 0
    lossD_fake_temp = 0
    lossD_temp = 0
    effective_iteration = 0
    # BUGFIX: was used below (`mean_boxes_per_iter += ...`) without ever
    # being initialized, which raised NameError on the first iteration.
    mean_boxes_per_iter = 0
    start = time.time()

    if args.loss == 'smoothl1':
        criterion = UBR_SmoothL1Loss(args.iou_th)
    elif args.loss == 'iou':
        criterion = UBR_IoULoss(args.iou_th)
    else:
        # BUGFIX: previously `criterion` was silently left unbound and the
        # failure surfaced much later, inside validate().
        raise ValueError('unknown loss type: %s' % args.loss)

    random_box_generator = NaturalUniformBoxGenerator(args.iou_th)

    for step in range(1, args.max_iter + 1):
        src_idx = np.random.choice(len(source_train_dataset))
        tar_idx = np.random.choice(len(target_train_dataset))
        src_im_data, src_gt_boxes, _, _, src_im_scale, src_raw_img, src_im_id, _ = source_train_dataset[
            src_idx]
        tar_im_data, tar_gt_boxes, _, _, tar_im_scale, tar_raw_img, tar_im_id, _ = target_train_dataset[
            tar_idx]

        # Generate random boxes from the source gt boxes.
        # rois has shape (n, 5); the first column is unused, so
        # rois[:, 1:5] is [xmin, ymin, xmax, ymax].
        # NOTE(review): in this loop the rois are only counted, never fed to
        # a model -- the adversarial step below uses whole-image features.
        num_src_gt = src_gt_boxes.size(0)
        # BUGFIX: integer division (py3 `/` yields float and breaks
        # torch.zeros below) and consistent use of the names actually defined
        # above (the original referenced undefined num_sr / num_gt_box /
        # gt_boxes / data_height / data_width).
        num_per_base = 60 // num_src_gt
        # assumes im_data is a (C, H, W) tensor -- TODO confirm against
        # TDetDataset's __getitem__
        data_height = src_im_data.size(1)
        data_width = src_im_data.size(2)
        rois = torch.zeros((num_per_base * num_src_gt, 5))
        cnt = 0
        for i in range(num_src_gt):
            here = random_box_generator.get_rand_boxes(src_gt_boxes[i, :],
                                                       num_per_base,
                                                       data_height, data_width)
            if here is None:
                continue
            rois[cnt:cnt + here.size(0), :] = here
            cnt += here.size(0)
        if cnt == 0:
            log_file.write('@@@@ no box @@@@\n')
            print('@@@@@ no box @@@@@')
            continue
        rois = rois[:cnt, :]
        mean_boxes_per_iter += rois.size(0)
        rois = Variable(rois.cuda())

        ##############################################################################################
        # Train D with real: source-domain global feature labeled 1.
        optimizerD.zero_grad()
        src_im_data = Variable(src_im_data.unsqueeze(0).cuda())
        src_feat = F.adaptive_avg_pool2d(
            source_model.get_conv_feat(src_im_data), 1).view(1, 512)
        if args.tanh:
            src_feat = F.tanh(src_feat)
        label_real = Variable(torch.FloatTensor([[1]]).cuda())
        # detach(): D's update must not backprop into the feature extractors.
        output_real = D(src_feat.detach())
        loss_real = F.binary_cross_entropy_with_logits(output_real, label_real)
        loss_real.backward()

        # Train D with fake: target-domain global feature labeled 0.
        tar_im_data = Variable(tar_im_data.unsqueeze(0).cuda())
        tar_feat = F.adaptive_avg_pool2d(
            target_model.get_conv_feat(tar_im_data), 1).view(1, 512)
        if args.tanh:
            tar_feat = F.tanh(tar_feat)
        label_fake = Variable(torch.FloatTensor([[0]]).cuda())
        output_fake = D(tar_feat.detach())
        loss_fake = F.binary_cross_entropy_with_logits(output_fake, label_fake)
        loss_fake.backward()

        lossD_real_temp += loss_real.data[0]
        lossD_fake_temp += loss_fake.data[0]
        lossD = loss_real + loss_fake
        clip_gradient([D], 10.0)
        optimizerD.step()
        #############################################################################################

        # Train G: push the target feature towards D's "real" decision.
        optimizerG.zero_grad()
        label_real = Variable(torch.FloatTensor([[1]]).cuda())
        output = D(tar_feat)
        lossG = F.binary_cross_entropy_with_logits(output, label_real)
        lossG.backward()
        clip_gradient([target_model], 10.0)
        optimizerG.step()
        ##############################################################################################

        effective_iteration += 1
        lossG_temp += lossG.data[0]
        lossD_temp += lossD.data[0]

        if step % args.disp_interval == 0:
            end = time.time()
            # BUGFIX: guard against division by zero when every step since
            # the last display hit the no-box `continue` above.
            denom = max(effective_iteration, 1)
            lossG_temp /= denom
            lossD_temp /= denom
            lossD_fake_temp /= denom
            lossD_real_temp /= denom
            print(
                "[net %s][session %d][iter %4d] lossG: %.4f, lossD: %.4f, lr: %.2e, time: %.1f"
                % (args.net, args.session, step, lossG_temp, lossD_temp, lr,
                   end - start))
            log_file.write(
                "[net %s][session %d][iter %4d] lossG: %.4f, lossD: %.4f, lr: %.2e, time: %.1f\n"
                % (args.net, args.session, step, lossG_temp, lossD_temp, lr,
                   end - start))

            effective_iteration = 0
            lossG_temp = 0
            lossD_temp = 0
            lossD_real_temp = 0
            lossD_fake_temp = 0
            start = time.time()

        if step % args.val_interval == 0:
            # Validate the trainable target model on source-domain (val) and
            # target-domain (tval) held-out sets.
            val_loss = validate(target_model, random_box_generator, criterion,
                                val_dataset)
            tval_loss = validate(target_model, random_box_generator, criterion,
                                 tval_dataset)
            print('[net %s][session %d][step %2d] validation loss: %.4f' %
                  (args.net, args.session, step, val_loss))
            log_file.write(
                '[net %s][session %d][step %2d] validation loss: %.4f\n' %
                (args.net, args.session, step, val_loss))
            print(
                '[net %s][session %d][step %2d] transfer validation loss: %.4f'
                % (args.net, args.session, step, tval_loss))
            log_file.write(
                '[net %s][session %d][step %2d] transfer validation loss: %.4f\n'
                % (args.net, args.session, step, tval_loss))

            log_file.flush()

    log_file.close()