Example #1
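A Faster R-CNN training step: one forward pass returns the RPN and R-CNN classification and box-regression losses, which are summed into a single objective and backpropagated, with gradient clipping applied for the vgg16 backbone; every disp_interval steps the accumulated loss is averaged for logging, taking the per-GPU mean of each loss component when args.mGPUs is set.
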
            fasterRCNN.zero_grad()
            rois, cls_prob, bbox_pred, \
                rpn_loss_cls, rpn_loss_box, \
                RCNN_loss_cls, RCNN_loss_bbox, \
                rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)

            loss = rpn_loss_cls.mean() + rpn_loss_box.mean() \
                + RCNN_loss_cls.mean() + RCNN_loss_bbox.mean()
            loss_temp += loss.item()

            # backward
            optimizer.zero_grad()
            loss.backward()
            if args.net == "vgg16":
                clip_gradient(fasterRCNN, 10.)
            optimizer.step()

            if step % args.disp_interval == 0:
                end = time.time()
                if step > 0:
                    loss_temp /= (args.disp_interval + 1)

                if args.mGPUs:
                    loss_rpn_cls = rpn_loss_cls.mean().item()
                    loss_rpn_box = rpn_loss_box.mean().item()
                    loss_rcnn_cls = RCNN_loss_cls.mean().item()
                    loss_rcnn_box = RCNN_loss_bbox.mean().item()
                    fg_cnt = torch.sum(rois_label.data.ne(0))
                    bg_cnt = rois_label.data.numel() - fg_cnt
                else:
                    # single-GPU case: each loss is already a scalar tensor
                    loss_rpn_cls = rpn_loss_cls.item()
                    loss_rpn_box = rpn_loss_box.item()
                    loss_rcnn_cls = RCNN_loss_cls.item()
                    loss_rcnn_box = RCNN_loss_bbox.item()
                    fg_cnt = torch.sum(rois_label.data.ne(0))
                    bg_cnt = rois_label.data.numel() - fg_cnt
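
All four examples call a clip_gradient helper whose definition is not part of these snippets. A minimal sketch, assuming the common definition that rescales gradients so their combined L2 norm stays below the threshold (equivalent to torch.nn.utils.clip_grad_norm_):

def clip_gradient(model, clip_norm):
    # rescale all gradients in-place so their combined L2 norm is at most clip_norm
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.data.norm().item() ** 2
    total_norm = total_norm ** 0.5
    if total_norm > clip_norm:
        scale = clip_norm / total_norm
        for p in model.parameters():
            if p.grad is not None:
                p.grad.data.mul_(scale)
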
Example #2
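A complete transfer-detection training loop: each iteration draws one source batch (coco60, with box annotations) and one target batch (voc07_trainval, with image-level labels only), backpropagates the two losses weighted by 1 - alpha and alpha respectively, and periodically logs statistics, decays the learning rate, and writes resumable checkpoints.
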
def train():
    args = parse_args()
    print('Called with args:')
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    source_train_dataset = TDETDataset(['coco60_train2014', 'coco60_val2014'],
                                       args.data_dir,
                                       args.prop_method,
                                       num_classes=60,
                                       prop_min_scale=args.prop_min_scale,
                                       prop_topk=args.num_prop)
    target_train_dataset = TDETDataset(['voc07_trainval'],
                                       args.data_dir,
                                       args.prop_method,
                                       num_classes=20,
                                       prop_min_scale=args.prop_min_scale,
                                       prop_topk=args.num_prop)

    lr = args.lr

    if args.net == 'NEW_TDET':
        model = NEW_TDET(os.path.join(args.data_dir,
                                      'pretrained_model/vgg16_caffe.pth'),
                         20,
                         pooling_method=args.pooling_method,
                         share_level=args.share_level,
                         mil_topk=args.mil_topk)
    else:
        raise Exception('network is not defined')

    optimizer = model.get_optimizer(args.lr)

    if args.resume:
        load_name = os.path.join(
            output_dir, '{}_{}_{}.pth'.format(args.net, args.checksession,
                                              args.checkiter))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_iter = checkpoint['iterations'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr = optimizer.param_groups[0]['lr']
        print("loaded checkpoint %s" % (load_name))

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    if args.resume:
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()
    source_loss_sum = 0
    target_loss_sum = 0
    source_pos_prop_sum = 0
    source_neg_prop_sum = 0
    target_prop_sum = 0
    start = time.time()
    source_rand_perm = None
    target_rand_perm = None
    for step in range(args.start_iter, args.max_iter + 1):
        # reshuffle once per pass over each dataset; the None guard covers the
        # first iteration, where step may not land on an epoch boundary
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))
        if target_rand_perm is None or step % len(target_train_dataset) == 1:
            target_rand_perm = np.random.permutation(len(target_train_dataset))

        source_index = source_rand_perm[step % len(source_train_dataset)]
        target_index = target_rand_perm[step % len(target_train_dataset)]

        source_batch = source_train_dataset.get_data(
            source_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))
        target_batch = target_train_dataset.get_data(
            target_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

        source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
        source_proposals = source_batch['proposals']
        source_gt_boxes = source_batch['gt_boxes']
        source_proposals, source_labels, _, pos_cnt, neg_cnt = sample_proposals(
            source_gt_boxes, source_proposals, args.bs, args.pos_ratio)
        source_proposals = source_proposals.to(device)
        source_gt_boxes = source_gt_boxes.to(device)
        source_labels = source_labels.to(device)

        target_im_data = target_batch['im_data'].unsqueeze(0).to(device)
        target_proposals = target_batch['proposals'].to(device)
        target_image_level_label = target_batch['image_level_label'].to(device)

        optimizer.zero_grad()

        # source forward & backward
        _, source_loss = model.forward_det(source_im_data, source_proposals,
                                           source_labels)
        source_loss_sum += source_loss.item()
        source_loss = source_loss * (1 - args.alpha)
        source_loss.backward()

        # target forward & backward
        if args.cam_like:
            _, target_loss = model.forward_cls_camlike(
                target_im_data, target_proposals, target_image_level_label)
        else:
            _, target_loss = model.forward_cls(target_im_data,
                                               target_proposals,
                                               target_image_level_label)
        target_loss_sum += target_loss.item()
        target_loss = target_loss * args.alpha
        target_loss.backward()

        clip_gradient(model, 10.0)
        optimizer.step()
        source_pos_prop_sum += pos_cnt
        source_neg_prop_sum += neg_cnt
        target_prop_sum += target_proposals.size(0)

        if step % args.disp_interval == 0:
            end = time.time()
            loss_sum = source_loss_sum * (
                1 - args.alpha) + target_loss_sum * args.alpha
            loss_sum /= args.disp_interval
            source_loss_sum /= args.disp_interval
            target_loss_sum /= args.disp_interval
            source_pos_prop_sum /= args.disp_interval
            source_neg_prop_sum /= args.disp_interval
            target_prop_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.4f, src_loss: %.4f, tar_loss: %.4f, pos_prop: %.1f, neg_prop: %.1f, tar_prop: %.1f, lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, loss_sum, source_loss_sum, target_loss_sum, source_pos_prop_sum, source_neg_prop_sum, target_prop_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()
            source_loss_sum = 0
            target_loss_sum = 0
            source_pos_prop_sum = 0
            source_neg_prop_sum = 0
            target_prop_sum = 0
            start = time.time()

        # decay the learning rate by 10x at 4/7 and 6/7 of training
        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['pooling_method'] = args.pooling_method
            checkpoint['share_level'] = args.share_level
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
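
The loop above also relies on adjust_learning_rate and save_checkpoint, which are likewise not shown. Minimal sketches, assuming the conventional definitions (so the 0.1 passed above is a multiplicative decay factor):

def adjust_learning_rate(optimizer, decay):
    # multiply every param group's learning rate by the decay factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = decay * param_group['lr']

def save_checkpoint(state, filename):
    # thin wrapper around torch.save; state is the checkpoint dict built above
    torch.save(state, filename)
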
Example #3
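A per-class detection variant: a pretrained DC_VGG16_CLS backbone is wrapped in a DC_VGG16_DET head, and each iteration samples proposals for two randomly chosen ground-truth classes and accumulates a detection loss per class; the model is validated on voc07_test at every save interval.
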
def train():
    args = parse_args()
    print('Called with args:')
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if args.target_only:
        source_train_dataset = TDETDataset(['voc07_trainval'],
                                           args.data_dir,
                                           args.prop_method,
                                           num_classes=20,
                                           prop_min_scale=args.prop_min_scale,
                                           prop_topk=args.num_prop)
    else:
        source_train_dataset = TDETDataset(
            ['coco60_train2014', 'coco60_val2014'],
            args.data_dir,
            args.prop_method,
            num_classes=60,
            prop_min_scale=args.prop_min_scale,
            prop_topk=args.num_prop)
    target_val_dataset = TDETDataset(['voc07_test'],
                                     args.data_dir,
                                     args.prop_method,
                                     num_classes=20,
                                     prop_min_scale=args.prop_min_scale,
                                     prop_topk=args.num_prop)

    lr = args.lr

    if args.net == 'DC_VGG16_DET':
        base_model = DC_VGG16_CLS(None, 20 if args.target_only else 80, 3, 4)
        checkpoint = torch.load(args.pretrained_base_path)
        base_model.load_state_dict(checkpoint['model'])
        del checkpoint
        model = DC_VGG16_DET(base_model, args.pooling_method)
    else:
        raise Exception('network is not defined')

    optimizer = model.get_optimizer(args.lr)

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()
    source_loss_sum = 0
    source_pos_prop_sum = 0
    source_neg_prop_sum = 0
    start = time.time()
    optimizer.zero_grad()
    source_rand_perm = None
    for step in range(args.start_iter, args.max_iter + 1):
        # reshuffle once per pass; the None guard covers the first iteration
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))

        source_index = source_rand_perm[step % len(source_train_dataset)]

        source_batch = source_train_dataset.get_data(
            source_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

        source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
        source_proposals = source_batch['proposals']
        source_gt_boxes = source_batch['gt_boxes']
        if args.target_only:
            source_gt_labels = source_batch['gt_labels']
        else:
            source_gt_labels = source_batch['gt_labels'] + 20
        source_pos_cls = [i for i in range(80) if i in source_gt_labels]

        source_loss = 0
        for cls in np.random.choice(source_pos_cls, 2):
            indices = np.where(source_gt_labels.numpy() == cls)[0]
            here_gt_boxes = source_gt_boxes[indices]
            here_proposals, here_labels, _, pos_cnt, neg_cnt = sample_proposals(
                here_gt_boxes, source_proposals, args.bs // 2, args.pos_ratio)
            # plt.imshow(source_batch['raw_img'])
            # draw_box(here_proposals[:pos_cnt] / source_batch['im_scale'], 'black')
            # draw_box(here_proposals[pos_cnt:] / source_batch['im_scale'], 'yellow')
            # plt.show()
            here_proposals = here_proposals.to(device)
            here_labels = here_labels.to(device)
            here_loss = model(source_im_data, cls, here_proposals, here_labels)
            source_loss = source_loss + here_loss

            source_pos_prop_sum += pos_cnt
            source_neg_prop_sum += neg_cnt

        # average the loss over the two sampled classes
        source_loss = source_loss / 2

        source_loss_sum += source_loss.item()
        source_loss.backward()

        clip_gradient(model, 10.0)
        optimizer.step()
        optimizer.zero_grad()

        if step % args.disp_interval == 0:
            end = time.time()
            source_loss_sum /= args.disp_interval
            source_pos_prop_sum /= args.disp_interval
            source_neg_prop_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.4f, pos_prop: %.1f, neg_prop: %.1f, lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, source_loss_sum, source_pos_prop_sum, source_neg_prop_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()
            source_loss_sum = 0
            source_pos_prop_sum = 0
            source_neg_prop_sum = 0
            start = time.time()

        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            validate(model, target_val_dataset, args, device)
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['pooling_method'] = args.pooling_method
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
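
Examples #2 and #3 both depend on a sample_proposals helper that the snippets do not include. A plausible sketch, assuming IoU-based foreground/background sampling with a 0.5 threshold and positives placed first (consistent with the here_proposals[:pos_cnt] debug lines above); the threshold, the binary label scheme, and the third return value are assumptions:

import torch

def bbox_iou(boxes, gt):
    # pairwise IoU between (N, 4) and (M, 4) boxes in (x1, y1, x2, y2) format
    area1 = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    area2 = (gt[:, 2] - gt[:, 0]) * (gt[:, 3] - gt[:, 1])
    lt = torch.max(boxes[:, None, :2], gt[None, :, :2])
    rb = torch.min(boxes[:, None, 2:], gt[None, :, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, :, 0] * wh[:, :, 1]
    return inter / (area1[:, None] + area2[None, :] - inter)

def sample_proposals(gt_boxes, proposals, batch_size, pos_ratio, pos_thresh=0.5):
    # label each proposal by its best IoU with any ground-truth box, then
    # sample at most pos_ratio * batch_size positives and fill with negatives
    iou, gt_assignment = bbox_iou(proposals, gt_boxes).max(dim=1)
    pos_idx = torch.nonzero(iou >= pos_thresh).squeeze(1)
    neg_idx = torch.nonzero(iou < pos_thresh).squeeze(1)
    num_pos = min(int(batch_size * pos_ratio), pos_idx.numel())
    num_neg = min(batch_size - num_pos, neg_idx.numel())
    pos_idx = pos_idx[torch.randperm(pos_idx.numel())[:num_pos]]
    neg_idx = neg_idx[torch.randperm(neg_idx.numel())[:num_neg]]
    keep = torch.cat([pos_idx, neg_idx])
    labels = torch.zeros(keep.numel(), dtype=torch.long)
    labels[:num_pos] = 1  # binary fg/bg labels; the real scheme may differ
    return proposals[keep], labels, gt_assignment[keep], num_pos, num_neg
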
Example #4
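A CAM-based variant: CamDet is trained purely from image-level class labels, sampling up to bs positive classes per image; source and target losses are averaged unless target_only is set, in which case only the voc07 target loss is used.
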
def train():
    args = parse_args()
    print('Called with args:')
    print(args)
    assert args.bs % 2 == 0

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(device)
    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    target_only = args.target_only
    source_train_dataset = TDETDataset(['coco60_train2014', 'coco60_val2014'],
                                       args.data_dir,
                                       'eb',
                                       num_classes=60)
    target_train_dataset = TDETDataset(['voc07_trainval'],
                                       args.data_dir,
                                       'eb',
                                       num_classes=20)

    lr = args.lr

    if args.net == 'CAM_DET':
        model = CamDet(
            os.path.join(args.data_dir, 'pretrained_model/vgg16_caffe.pth')
            if not args.resume else None, 20 if target_only else 80,
            args.hidden_dim)
    else:
        raise Exception('network is not defined')

    optimizer = model.get_optimizer(args.lr)

    if args.resume:
        load_name = os.path.join(
            output_dir, '{}_{}_{}.pth'.format(args.net, args.checksession,
                                              args.checkiter))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_iter = checkpoint['iterations'] + 1
        model.load_state_dict(checkpoint['model'])
        print("loaded checkpoint %s" % (load_name))
        del checkpoint

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    if args.resume:
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()
    source_loss_sum = 0
    target_loss_sum = 0
    total_loss_sum = 0
    start = time.time()
    source_rand_perm = None
    target_rand_perm = None
    for step in range(args.start_iter, args.max_iter + 1):
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))
        if target_rand_perm is None or step % len(target_train_dataset) == 1:
            target_rand_perm = np.random.permutation(len(target_train_dataset))

        source_index = source_rand_perm[step % len(source_train_dataset)]
        target_index = target_rand_perm[step % len(target_train_dataset)]

        optimizer.zero_grad()
        if not target_only:
            source_batch = source_train_dataset.get_data(
                source_index,
                h_flip=np.random.rand() > 0.5,
                target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

            source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
            source_gt_labels = source_batch['gt_labels'] + 20
            source_pos_cls = [i for i in range(80) if i in source_gt_labels]
            source_pos_cls = torch.tensor(np.random.choice(
                source_pos_cls,
                min(args.bs, len(source_pos_cls)),
                replace=False),
                                          dtype=torch.long,
                                          device=device)

            source_loss, _, _ = model(source_im_data, source_pos_cls)
            source_loss_sum += source_loss.item()

        target_batch = target_train_dataset.get_data(
            target_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

        target_im_data = target_batch['im_data'].unsqueeze(0).to(device)
        target_gt_labels = target_batch['gt_labels']
        target_pos_cls = [i for i in range(80) if i in target_gt_labels]
        target_pos_cls = torch.tensor(np.random.choice(
            target_pos_cls, min(args.bs, len(target_pos_cls)), replace=False),
                                      dtype=torch.long,
                                      device=device)

        target_loss, _, _, _ = model(target_im_data, target_pos_cls)
        target_loss_sum += target_loss.item()
        if target_only:
            total_loss = target_loss
        else:
            total_loss = (source_loss + target_loss) * 0.5
        total_loss.backward()
        total_loss_sum += total_loss.item()
        clip_gradient(model, 10.0)
        optimizer.step()

        if step % args.disp_interval == 0:
            end = time.time()
            total_loss_sum /= args.disp_interval
            source_loss_sum /= args.disp_interval
            target_loss_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.8f, src_loss: %.8f, tar_loss: %.8f, lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, total_loss_sum, source_loss_sum, target_loss_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()
            total_loss_sum = 0
            source_loss_sum = 0
            target_loss_sum = 0
            start = time.time()

        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()