Code example #1
File: eval_attdet.py  Project: deneb2016/TDET
def show():
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    load_name = os.path.join(args.save_dir, 'tdet',
                             '{}.pth'.format(args.model_name))
    print("loading checkpoint %s" % (load_name))
    checkpoint = torch.load(load_name)
    if checkpoint['net'] == 'ATT_DET':
        model = AttentiveDetGlobal(None, 20 if args.target_only else 80,
                                   args.hidden_dim, args.im_size)
    else:
        raise Exception('network is not defined')
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()
    print("loaded checkpoint %s" % (load_name))

    test_dataset = TDETDataset(['voc07_test'],
                               args.data_dir,
                               args.prop_method,
                               num_classes=20,
                               prop_min_scale=args.prop_min_scale,
                               prop_topk=args.num_prop)

    for data_idx in range(len(test_dataset)):
        batch = test_dataset.get_data(data_idx, False, args.im_size, True)
        img = cv2.resize(batch['raw_img'],
                         None,
                         None,
                         fx=batch['im_scale'][0],
                         fy=batch['im_scale'][1],
                         interpolation=cv2.INTER_LINEAR)
        im_data = batch['im_data'].unsqueeze(0).to(device)
        gt_labels = batch['gt_labels']
        pos_cls = [i for i in range(80) if i in gt_labels]
        pos_cls = torch.tensor(pos_cls, dtype=torch.long, device=device)

        if len(pos_cls) < 2:
            continue
        print(pos_cls)
        score, loss, attention, densified_attention = model(im_data, pos_cls)
        for i, cls in enumerate(pos_cls):
            print(VOC_CLASSES[cls])
            print(score[i])
            print(score[i, cls])
            maxi = attention[i].max()
            print(maxi)
            #attention[i][attention[i] < maxi] = 0
            plt.imshow(img)
            plt.show()
            plt.imshow(attention[i].cpu().detach().numpy())
            plt.show()
            plt.imshow(densified_attention[i].cpu().detach().numpy())
            plt.show()
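
Note that the image, the attention map, and the densified attention map are shown in three separate figures above, which makes visual comparison awkward. A minimal overlay sketch with synthetic data (the real map shape comes from AttentiveDetGlobal, which is not listed here):

import cv2
import numpy as np
from matplotlib import pyplot as plt

# Synthetic stand-ins for the resized image and a coarse attention map.
img = np.random.rand(300, 400, 3).astype(np.float32)
attention = np.random.rand(19, 25).astype(np.float32)

# Upsample the attention map to image resolution, then blend the two.
att_up = cv2.resize(attention, (img.shape[1], img.shape[0]),
                    interpolation=cv2.INTER_LINEAR)
plt.imshow(img)
plt.imshow(att_up, cmap='jet', alpha=0.5)  # semi-transparent heatmap
plt.axis('off')
plt.show()
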
Code example #2
File: eval_dc_vgg16.py  Project: deneb2016/TDET
def show():
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    load_name = os.path.join(args.save_dir, 'tdet',
                             '{}.pth'.format(args.model_name))
    print("loading checkpoint %s" % (load_name))
    checkpoint = torch.load(load_name)
    if checkpoint['net'] == 'DC_VGG16':
        model = DC_VGG16_CLS(
            os.path.join(args.data_dir, 'pretrained_model/vgg16_caffe.pth'),
            20, 1, args.specific_from, args.specific_to)
    else:
        raise Exception('network is not defined')
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()
    print("loaded checkpoint %s" % (load_name))
    print(torch.cuda.device_count())

    test_dataset = TDETDataset(['voc07_test'],
                               args.data_dir,
                               args.prop_method,
                               num_classes=20,
                               prop_min_scale=args.prop_min_scale,
                               prop_topk=args.num_prop)

    for data_idx in range(len(test_dataset)):
        batch = test_dataset.get_data(data_idx, False, 600)
        im_data = batch['im_data'].unsqueeze(0).to(device)
        gt_labels = batch['gt_labels']
        pos_cls = [i for i in range(80) if i in gt_labels]
        print(pos_cls)

        for cls in range(14, 15):  # inspect a single class only (index 14)
            if cls not in pos_cls:
                continue
            score, spe_feat = model(im_data, [cls])
            C, H, W = spe_feat.size(1), spe_feat.size(2), spe_feat.size(3)
            print(cls, score)
            print(spe_feat.detach().view(C, -1).mean(1).topk(5))
            activation = torch.mean(spe_feat.detach().view(C, H, W), 0)
            plt.imshow(activation.cpu().numpy())  # move off the GPU before plotting
            plt.show()
            plt.imshow(batch['raw_img'])
            plt.show()
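
The interesting step above is collapsing the class-specific feature map into a single heatmap by averaging over channels. The same operation in isolation, on a synthetic (1, C, H, W) tensor so the shapes are explicit:

import torch
from matplotlib import pyplot as plt

# Synthetic conv feature map shaped like a model output: (1, C, H, W).
spe_feat = torch.randn(1, 512, 38, 50)
C, H, W = spe_feat.size(1), spe_feat.size(2), spe_feat.size(3)

# Top-5 channels ranked by mean activation, as in the snippet above.
print(spe_feat.detach().view(C, -1).mean(1).topk(5))

# The channel-wise mean collapses (C, H, W) into one (H, W) heatmap.
activation = torch.mean(spe_feat.detach().view(C, H, W), 0)
plt.imshow(activation.numpy())  # detach/convert to numpy before plotting
plt.show()
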
Code example #3
File: eval.py  Project: deneb2016/TDET
def eval():
    print('Called with args:')
    print(args)

    np.random.seed(3)
    torch.manual_seed(4)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(5)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    eval_kit = voc_eval_kit('test', '2007',
                            os.path.join(args.data_dir, 'VOCdevkit2007'))

    test_dataset = TDETDataset(['voc07_test'],
                               args.data_dir,
                               args.prop_method,
                               num_classes=20,
                               prop_min_scale=args.prop_min_scale,
                               prop_topk=args.num_prop)

    load_name = os.path.join(args.save_dir, 'tdet',
                             '{}.pth'.format(args.model_name))
    print("loading checkpoint %s" % (load_name))
    checkpoint = torch.load(load_name)
    if checkpoint['net'] == 'TDET_VGG16':
        model = TDET_VGG16(None,
                           20,
                           pooling_method=checkpoint['pooling_method'],
                           cls_specific_det=checkpoint['cls_specific'] if
                           checkpoint['cls_specific'] is not False else 'no',
                           share_level=checkpoint['share_level'],
                           det_softmax=checkpoint['det_softmax']
                           if 'det_softmax' in checkpoint else 'no',
                           det_choice=checkpoint['det_choice']
                           if 'det_choice' in checkpoint else 1)
    else:
        raise Exception('network is not defined')
    model.load_state_dict(checkpoint['model'])
    print("loaded checkpoint %s" % (load_name))

    model.to(device)
    model.eval()

    start = time.time()

    num_images = len(test_dataset)
    # heuristic: keep an average of 40 detections per class per image prior
    # to NMS
    max_per_set = 40 * num_images
    # heuristic: keep at most 100 detections per class per image prior to NMS
    max_per_image = 100
    # detection threshold for each class (this is adaptively set based on the
    # max_per_set constraint)
    thresh = -np.inf * np.ones(20)
    # thresh = 0.1 * np.ones(imdb.num_classes)
    # top_scores will hold one minheap of scores per class (used to enforce
    # the max_per_set constraint)
    top_scores = [[] for _ in range(20)]
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)] for _ in range(20)]
    cls_scores_set = [[[] for _ in range(num_images)] for _ in range(20)]
    det_scores_set = [[[] for _ in range(num_images)] for _ in range(20)]

    for index in range(len(test_dataset)):
        scores = 0
        c_scores = 0
        d_scores = 0
        N = 0
        if args.multiscale:
            comb = itertools.product([False, True], [480, 576, 688, 864, 1200])
        else:
            comb = itertools.product([False], [688])
        for h_flip, im_size in comb:
            test_batch = test_dataset.get_data(index, h_flip, im_size)

            im_data = test_batch['im_data'].unsqueeze(0).to(device)
            proposals = test_batch['proposals'].to(device)

            local_scores, local_cls_scores, local_det_scores = model(
                im_data, proposals)
            local_scores = local_scores.detach().cpu().numpy()
            local_cls_scores = local_cls_scores.detach().cpu().numpy()
            local_det_scores = local_det_scores.detach().cpu().numpy()
            scores = scores + local_scores
            c_scores = c_scores + local_cls_scores
            d_scores = d_scores + local_det_scores
            N += 1
        scores = 100 * scores / N
        c_scores = 10 * c_scores / N
        d_scores = 10 * d_scores / N

        boxes = test_dataset.get_raw_proposal(index)

        for cls in range(20):
            inds = np.where((scores[:, cls] > thresh[cls]))[0]
            cls_scores = scores[inds, cls]
            cls_c_scores = c_scores[inds, cls]
            if checkpoint['cls_specific']:
                cls_d_scores = d_scores[inds, cls]
            else:
                cls_d_scores = d_scores[inds, 0]
            cls_boxes = boxes[inds].copy()

            top_inds = np.argsort(-cls_scores)[:max_per_image]
            cls_scores = cls_scores[top_inds]
            cls_c_scores = cls_c_scores[top_inds]
            cls_d_scores = cls_d_scores[top_inds]
            cls_boxes = cls_boxes[top_inds, :]

            # if cls_scores[0] > 10:
            #     print(cls)
            #     plt.imshow(test_batch['raw_img'])
            #     draw_box(cls_boxes[0:10, :])
            #     draw_box(test_batch['gt_boxes'] / test_batch['im_scale'], 'black')
            #     plt.show()

            # push new scores onto the minheap
            for val in cls_scores:
                heapq.heappush(top_scores[cls], val)
            # if we've collected more than the max number of detection,
            # then pop items off the minheap and update the class threshold
            if len(top_scores[cls]) > max_per_set:
                while len(top_scores[cls]) > max_per_set:
                    heapq.heappop(top_scores[cls])
                thresh[cls] = top_scores[cls][0]

            all_boxes[cls][index] = np.hstack(
                (cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32,
                                                               copy=False)
            cls_scores_set[cls][index] = cls_c_scores
            det_scores_set[cls][index] = cls_d_scores

        if index % 100 == 99:
            print('%d images complete, elapsed time:%.1f' %
                  (index + 1, time.time() - start))

    for j in range(20):
        for i in range(len(test_dataset)):
            inds = np.where(all_boxes[j][i][:, -1] > thresh[j])[0]
            all_boxes[j][i] = all_boxes[j][i][inds, :]
            cls_scores_set[j][i] = cls_scores_set[j][i][inds]
            det_scores_set[j][i] = det_scores_set[j][i][inds]

    if args.multiscale:
        save_name = os.path.join(args.save_dir, 'detection_result',
                                 '{}_multiscale.pkl'.format(args.model_name))
    else:
        save_name = os.path.join(args.save_dir, 'detection_result',
                                 '{}.pkl'.format(args.model_name))
    with open(save_name, 'wb') as f:
        pickle.dump(
            {
                'all_boxes': all_boxes,
                'cls': cls_scores_set,
                'det': det_scores_set
            }, f)

    print('Detection complete, elapsed time: %.1f' % (time.time() - start))

    for cls in range(20):
        for index in range(len(test_dataset)):
            dets = all_boxes[cls][index]
            if len(dets) == 0:  # dets is a numpy array, not a list
                continue
            keep = nms(dets, 0.3)
            all_boxes[cls][index] = dets[keep, :].copy()
    print('NMS complete, elapsed time: %.1f' % (time.time() - start))

    eval_kit.evaluate_detections(all_boxes)
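
eval() delegates the final suppression to nms(dets, 0.3) over (x1, y1, x2, y2, score) rows. The project's own implementation is not listed; a plain-NumPy greedy NMS with the same calling convention looks like this (the +1 pixel convention follows the classic VOC/R-CNN code and is an assumption here):

import numpy as np

def nms(dets, thresh):
    # dets: (N, 5) array of (x1, y1, x2, y2, score); returns kept indices.
    x1, y1, x2, y2 = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3]
    scores = dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]  # descending by score
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # Intersection of the current top box with the remaining boxes.
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # Drop every box overlapping the kept one above the threshold.
        order = order[np.where(iou <= thresh)[0] + 1]
    return keep
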
Code example #4
File: train_dc_det.py  Project: deneb2016/TDET
def train():
    args = parse_args()
    print('Called with args:')
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if args.target_only:
        source_train_dataset = TDETDataset(['voc07_trainval'],
                                           args.data_dir,
                                           args.prop_method,
                                           num_classes=20,
                                           prop_min_scale=args.prop_min_scale,
                                           prop_topk=args.num_prop)
    else:
        source_train_dataset = TDETDataset(
            ['coco60_train2014', 'coco60_val2014'],
            args.data_dir,
            args.prop_method,
            num_classes=60,
            prop_min_scale=args.prop_min_scale,
            prop_topk=args.num_prop)
    target_val_dataset = TDETDataset(['voc07_test'],
                                     args.data_dir,
                                     args.prop_method,
                                     num_classes=20,
                                     prop_min_scale=args.prop_min_scale,
                                     prop_topk=args.num_prop)

    lr = args.lr

    if args.net == 'DC_VGG16_DET':
        base_model = DC_VGG16_CLS(None, 20 if args.target_only else 80, 3, 4)
        checkpoint = torch.load(args.pretrained_base_path)
        base_model.load_state_dict(checkpoint['model'])
        del checkpoint
        model = DC_VGG16_DET(base_model, args.pooling_method)
    else:
        raise Exception('network is not defined')

    optimizer = model.get_optimizer(args.lr)

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()
    source_loss_sum = 0
    source_pos_prop_sum = 0
    source_neg_prop_sum = 0
    start = time.time()
    optimizer.zero_grad()
    source_rand_perm = None
    for step in range(args.start_iter, args.max_iter + 1):
        # Reshuffle once per epoch; the None check also covers resumed runs
        # whose start_iter does not land on an epoch boundary.
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))

        source_index = source_rand_perm[step % len(source_train_dataset)]

        source_batch = source_train_dataset.get_data(
            source_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

        source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
        source_proposals = source_batch['proposals']
        source_gt_boxes = source_batch['gt_boxes']
        if args.target_only:
            source_gt_labels = source_batch['gt_labels']
        else:
            source_gt_labels = source_batch['gt_labels'] + 20
        source_pos_cls = [i for i in range(80) if i in source_gt_labels]

        source_loss = 0
        for cls in np.random.choice(source_pos_cls, 2):
            indices = np.where(source_gt_labels.numpy() == cls)[0]
            here_gt_boxes = source_gt_boxes[indices]
            here_proposals, here_labels, _, pos_cnt, neg_cnt = sample_proposals(
                here_gt_boxes, source_proposals, args.bs // 2, args.pos_ratio)
            # plt.imshow(source_batch['raw_img'])
            # draw_box(here_proposals[:pos_cnt] / source_batch['im_scale'], 'black')
            # draw_box(here_proposals[pos_cnt:] / source_batch['im_scale'], 'yellow')
            # plt.show()
            here_proposals = here_proposals.to(device)
            here_labels = here_labels.to(device)
            here_loss = model(source_im_data, cls, here_proposals, here_labels)
            source_loss = source_loss + here_loss

            source_pos_prop_sum += pos_cnt
            source_neg_prop_sum += neg_cnt

        source_loss = source_loss / 2

        source_loss_sum += source_loss.item()
        source_loss.backward()

        clip_gradient(model, 10.0)
        optimizer.step()
        optimizer.zero_grad()

        if step % args.disp_interval == 0:
            end = time.time()
            source_loss_sum /= args.disp_interval
            source_pos_prop_sum /= args.disp_interval
            source_neg_prop_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.4f, pos_prop: %.1f, neg_prop: %.1f, lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, source_loss_sum, source_pos_prop_sum, source_neg_prop_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()
            source_loss_sum = 0
            source_pos_prop_sum = 0
            source_neg_prop_sum = 0
            start = time.time()

        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            validate(model, target_val_dataset, args, device)
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['pooling_method'] = args.pooling_method
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
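
train() leans on sample_proposals(here_gt_boxes, source_proposals, args.bs // 2, args.pos_ratio) to build a balanced minibatch of boxes. Its source is not included here; below is a self-contained sketch of what such a sampler typically does (the 0.5 IoU threshold and the returned tuple layout are assumptions inferred from the call site, not the project's actual code):

import torch

def sample_proposals(gt_boxes, proposals, batch_size, pos_ratio, iou_thresh=0.5):
    # Pairwise IoU between proposals (N, 4) and gt boxes (M, 4).
    tl = torch.max(proposals[:, None, :2], gt_boxes[None, :, :2])
    br = torch.min(proposals[:, None, 2:], gt_boxes[None, :, 2:])
    wh = (br - tl).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    area_p = ((proposals[:, 2] - proposals[:, 0]) *
              (proposals[:, 3] - proposals[:, 1]))[:, None]
    area_g = ((gt_boxes[:, 2] - gt_boxes[:, 0]) *
              (gt_boxes[:, 3] - gt_boxes[:, 1]))[None, :]
    max_iou, _ = (inter / (area_p + area_g - inter)).max(dim=1)

    # Split by best IoU with any gt box, then subsample to the batch size.
    pos_idx = torch.nonzero(max_iou >= iou_thresh).view(-1)
    neg_idx = torch.nonzero(max_iou < iou_thresh).view(-1)
    num_pos = min(int(batch_size * pos_ratio), pos_idx.numel())
    num_neg = min(batch_size - num_pos, neg_idx.numel())
    pos_idx = pos_idx[torch.randperm(pos_idx.numel())[:num_pos]]
    neg_idx = neg_idx[torch.randperm(neg_idx.numel())[:num_neg]]

    keep = torch.cat([pos_idx, neg_idx])
    labels = torch.cat([torch.ones(num_pos, dtype=torch.long),
                        torch.zeros(num_neg, dtype=torch.long)])
    return proposals[keep], labels, keep, num_pos, num_neg
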
Code example #5
def train():
    args = parse_args()
    print('Called with args:')
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    source_train_dataset = TDETDataset(['coco60_train2014', 'coco60_val2014'],
                                       args.data_dir,
                                       args.prop_method,
                                       num_classes=60,
                                       prop_min_scale=args.prop_min_scale,
                                       prop_topk=args.num_prop)
    target_train_dataset = TDETDataset(['voc07_trainval'],
                                       args.data_dir,
                                       args.prop_method,
                                       num_classes=20,
                                       prop_min_scale=args.prop_min_scale,
                                       prop_topk=args.num_prop)

    lr = args.lr

    if args.net == 'NEW_TDET':
        model = NEW_TDET(os.path.join(args.data_dir,
                                      'pretrained_model/vgg16_caffe.pth'),
                         20,
                         pooling_method=args.pooling_method,
                         share_level=args.share_level,
                         mil_topk=args.mil_topk)
    else:
        raise Exception('network is not defined')

    optimizer = model.get_optimizer(args.lr)

    if args.resume:
        load_name = os.path.join(
            output_dir, '{}_{}_{}.pth'.format(args.net, args.checksession,
                                              args.checkiter))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_iter = checkpoint['iterations'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr = optimizer.param_groups[0]['lr']
        print("loaded checkpoint %s" % (load_name))

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    if args.resume:
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()
    source_loss_sum = 0
    target_loss_sum = 0
    source_pos_prop_sum = 0
    source_neg_prop_sum = 0
    target_prop_sum = 0
    start = time.time()
    source_rand_perm = None
    target_rand_perm = None
    for step in range(args.start_iter, args.max_iter + 1):
        # Reshuffle once per epoch; the None checks also cover resumed runs
        # whose start_iter does not land on an epoch boundary.
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))
        if target_rand_perm is None or step % len(target_train_dataset) == 1:
            target_rand_perm = np.random.permutation(len(target_train_dataset))

        source_index = source_rand_perm[step % len(source_train_dataset)]
        target_index = target_rand_perm[step % len(target_train_dataset)]

        source_batch = source_train_dataset.get_data(
            source_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))
        target_batch = target_train_dataset.get_data(
            target_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

        source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
        source_proposals = source_batch['proposals']
        source_gt_boxes = source_batch['gt_boxes']
        source_proposals, source_labels, _, pos_cnt, neg_cnt = sample_proposals(
            source_gt_boxes, source_proposals, args.bs, args.pos_ratio)
        source_proposals = source_proposals.to(device)
        source_gt_boxes = source_gt_boxes.to(device)
        source_labels = source_labels.to(device)

        target_im_data = target_batch['im_data'].unsqueeze(0).to(device)
        target_proposals = target_batch['proposals'].to(device)
        target_image_level_label = target_batch['image_level_label'].to(device)

        optimizer.zero_grad()

        # source forward & backward
        _, source_loss = model.forward_det(source_im_data, source_proposals,
                                           source_labels)
        source_loss_sum += source_loss.item()
        source_loss = source_loss * (1 - args.alpha)
        source_loss.backward()

        # target forward & backward
        if args.cam_like:
            _, target_loss = model.forward_cls_camlike(
                target_im_data, target_proposals, target_image_level_label)
        else:
            _, target_loss = model.forward_cls(target_im_data,
                                               target_proposals,
                                               target_image_level_label)
        target_loss_sum += target_loss.item()
        target_loss = target_loss * args.alpha
        target_loss.backward()

        clip_gradient(model, 10.0)
        optimizer.step()
        source_pos_prop_sum += pos_cnt
        source_neg_prop_sum += neg_cnt
        target_prop_sum += target_proposals.size(0)

        if step % args.disp_interval == 0:
            end = time.time()
            loss_sum = source_loss_sum * (
                1 - args.alpha) + target_loss_sum * args.alpha
            loss_sum /= args.disp_interval
            source_loss_sum /= args.disp_interval
            target_loss_sum /= args.disp_interval
            source_pos_prop_sum /= args.disp_interval
            source_neg_prop_sum /= args.disp_interval
            target_prop_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.4f, src_loss: %.4f, tar_loss: %.4f, pos_prop: %.1f, neg_prop: %.1f, tar_prop: %.1f, lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, loss_sum, source_loss_sum, target_loss_sum, source_pos_prop_sum, source_neg_prop_sum, target_prop_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()
            source_loss_sum = 0
            target_loss_sum = 0
            source_pos_prop_sum = 0
            source_neg_prop_sum = 0
            target_prop_sum = 0
            start = time.time()

        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['pooling_method'] = args.pooling_method
            checkpoint['share_level'] = args.share_level
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
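
Both training loops call clip_gradient(model, 10.0) and adjust_learning_rate(optimizer, 0.1) without showing them. These helpers are commonly implemented as below (an assumption based on the standard faster-rcnn.pytorch-style utilities, not this project's code):

import torch

def clip_gradient(model, clip_norm):
    # Rescale all gradients so their combined L2 norm is at most clip_norm.
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)

def adjust_learning_rate(optimizer, decay=0.1):
    # Multiply the learning rate of every parameter group by `decay`.
    for param_group in optimizer.param_groups:
        param_group['lr'] = decay * param_group['lr']
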
Code example #6
load_name = os.path.join(args.save_dir, 'tdet', '{}.pth'.format(args.model_name))
print("loading checkpoint %s" % (load_name))
checkpoint = torch.load(load_name)
if checkpoint['net'] == 'NEW_TDET':
    model = NEW_TDET(None, 20, pooling_method=checkpoint['pooling_method'], share_level=checkpoint['share_level'])
else:
    raise Exception('network is not defined')
model.load_state_dict(checkpoint['model'])
print("loaded checkpoint %s" % (load_name))

model.to(device)
model.eval()

cls = 10
for index in range(len(test_dataset)):
    batch = test_dataset.get_data(index, False, 640)
    if batch['image_level_label'][cls] == 0:
        continue
    im_data = batch['im_data'].unsqueeze(0).to(device)

    # proposals = [
    #     [100, 100, 200, 200]
    # ]
    # prop_tensor = torch.tensor(proposals, dtype=torch.float, device=device)
    # prop_tensor = prop_tensor * batch['im_scale']

    prop_tensor = batch['proposals'].to(device)
    _, c_scores, d_scores = model(im_data, prop_tensor)
    c_scores = c_scores.detach().cpu().numpy()
    d_scores = d_scores.detach().cpu().numpy()
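
The commented-out block above shows the intent: probe the detector with hand-written boxes. Since the model expects proposals in resized-image coordinates, raw pixel boxes have to be rescaled first. A sketch, with the scale factors invented for illustration (the real (fx, fy) values come from get_data):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
im_scale = (1.28, 1.28)  # hypothetical (fx, fy) resize factors

# Hand-written probe boxes in original-image pixels: (x1, y1, x2, y2).
proposals = [[100, 100, 200, 200],
             [50, 80, 300, 240]]
prop_tensor = torch.tensor(proposals, dtype=torch.float, device=device)
# Rescale x-coordinates by fx and y-coordinates by fy.
scale = torch.tensor([im_scale[0], im_scale[1], im_scale[0], im_scale[1]],
                     device=device)
prop_tensor = prop_tensor * scale
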
Code example #7
File: train_camdet.py  Project: deneb2016/TDET
def train():
    args = parse_args()
    print('Called with args:')
    print(args)
    assert args.bs % 2 == 0

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(device)
    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    target_only = args.target_only
    source_train_dataset = TDETDataset(['coco60_train2014', 'coco60_val2014'],
                                       args.data_dir,
                                       'eb',
                                       num_classes=60)
    target_train_dataset = TDETDataset(['voc07_trainval'],
                                       args.data_dir,
                                       'eb',
                                       num_classes=20)

    lr = args.lr

    if args.net == 'CAM_DET':
        model = CamDet(
            os.path.join(args.data_dir, 'pretrained_model/vgg16_caffe.pth')
            if not args.resume else None, 20 if target_only else 80,
            args.hidden_dim)
    else:
        raise Exception('network is not defined')

    optimizer = model.get_optimizer(args.lr)

    if args.resume:
        load_name = os.path.join(
            output_dir, '{}_{}_{}.pth'.format(args.net, args.checksession,
                                              args.checkiter))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_iter = checkpoint['iterations'] + 1
        model.load_state_dict(checkpoint['model'])
        print("loaded checkpoint %s" % (load_name))
        del checkpoint

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    if args.resume:
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()
    source_loss_sum = 0
    target_loss_sum = 0
    total_loss_sum = 0
    start = time.time()
    source_rand_perm = None
    target_rand_perm = None
    for step in range(args.start_iter, args.max_iter + 1):
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))
        if target_rand_perm is None or step % len(target_train_dataset) == 1:
            target_rand_perm = np.random.permutation(len(target_train_dataset))

        source_index = source_rand_perm[step % len(source_train_dataset)]
        target_index = target_rand_perm[step % len(target_train_dataset)]

        optimizer.zero_grad()
        if not target_only:
            source_batch = source_train_dataset.get_data(
                source_index,
                h_flip=np.random.rand() > 0.5,
                target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

            source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
            source_gt_labels = source_batch['gt_labels'] + 20
            source_pos_cls = [i for i in range(80) if i in source_gt_labels]
            source_pos_cls = torch.tensor(np.random.choice(
                source_pos_cls,
                min(args.bs, len(source_pos_cls)),
                replace=False),
                                          dtype=torch.long,
                                          device=device)

            source_loss, _, _ = model(source_im_data, source_pos_cls)
            source_loss_sum += source_loss.item()

        target_batch = target_train_dataset.get_data(
            target_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

        target_im_data = target_batch['im_data'].unsqueeze(0).to(device)
        target_gt_labels = target_batch['gt_labels']
        target_pos_cls = [i for i in range(80) if i in target_gt_labels]
        target_pos_cls = torch.tensor(np.random.choice(
            target_pos_cls, min(args.bs, len(target_pos_cls)), replace=False),
                                      dtype=torch.long,
                                      device=device)

        target_loss, _, _, _ = model(target_im_data, target_pos_cls)
        target_loss_sum += target_loss.item()
        if args.target_only:
            total_loss = target_loss
        else:
            total_loss = (source_loss + target_loss) * 0.5
        total_loss.backward()
        total_loss_sum += total_loss.item()
        clip_gradient(model, 10.0)
        optimizer.step()

        if step % args.disp_interval == 0:
            end = time.time()
            total_loss_sum /= args.disp_interval
            source_loss_sum /= args.disp_interval
            target_loss_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.8f, src_loss: %.8f, tar_loss: %.8f, lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, total_loss_sum, source_loss_sum, target_loss_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()
            total_loss_sum = 0
            source_loss_sum = 0
            target_loss_sum = 0
            start = time.time()

        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
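
CamDet itself is not listed, but its name points at class activation mapping. As background only, the standard CAM computation (Zhou et al., 2016) weights the final conv feature maps by one class's classifier weights and sums over channels; none of the identifiers below come from this project:

import torch

def class_activation_map(feat, fc_weight, cls):
    # feat: (C, H, W) conv features; fc_weight: (num_classes, C).
    cam = torch.einsum('c,chw->hw', fc_weight[cls], feat)
    cam = cam - cam.min()
    return cam / (cam.max() + 1e-8)  # normalize to [0, 1]

feat = torch.randn(512, 14, 14)   # synthetic feature map
fc_weight = torch.randn(20, 512)  # synthetic classifier weights
print(class_activation_map(feat, fc_weight, cls=7).shape)  # (14, 14)
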
Code example #8
from datasets.tdet_dataset import TDETDataset
from matplotlib import pyplot as plt
import numpy as np
from utils.box_utils import all_pair_iou


def draw_box(boxes, col=None):
    for j, (xmin, ymin, xmax, ymax) in enumerate(boxes):
        if col is None:
            c = np.random.rand(3)
        else:
            c = col
        plt.hlines(ymin, xmin, xmax, colors=c, lw=2)
        plt.hlines(ymax, xmin, xmax, colors=c, lw=2)
        plt.vlines(xmin, ymin, ymax, colors=c, lw=2)
        plt.vlines(xmax, ymin, ymax, colors=c, lw=2)


dataset = TDETDataset(dataset_names=['coco60_val'],
                      data_dir='../data',
                      prop_method='mcg',
                      prop_min_size=0,
                      prop_topk=3000,
                      num_classes=60)
tot = 0.0
det = 0.0
for i in range(len(dataset)):
    here = dataset.get_data(i)

    iou = all_pair_iou(here['gt_boxes'], here['proposals'])
    det += iou.max(1)[0].gt(0.8).sum().item()
    tot += iou.size(0)
    recall = det / tot
    if i % 100 == 99:
        print('%d: %f, %f, %.3f' % (i + 1, det, tot, recall))
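
This recall check hinges on all_pair_iou(gt_boxes, proposals) returning an IoU matrix with one row per ground-truth box (that shape is inferred from iou.max(1) and iou.size(0); the project's own implementation may differ). A compatible PyTorch sketch:

import torch

def all_pair_iou(boxes_a, boxes_b):
    # boxes_a: (N, 4), boxes_b: (M, 4) in (x1, y1, x2, y2) -> (N, M) IoU.
    tl = torch.max(boxes_a[:, None, :2], boxes_b[None, :, :2])
    br = torch.min(boxes_a[:, None, 2:], boxes_b[None, :, 2:])
    wh = (br - tl).clamp(min=0)
    inter = wh[:, :, 0] * wh[:, :, 1]
    area_a = ((boxes_a[:, 2] - boxes_a[:, 0]) *
              (boxes_a[:, 3] - boxes_a[:, 1]))[:, None]
    area_b = ((boxes_b[:, 2] - boxes_b[:, 0]) *
              (boxes_b[:, 3] - boxes_b[:, 1]))[None, :]
    return inter / (area_a + area_b - inter)
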