cfg.POOLING_MODE = checkpoint['pooling_mode']
        print("loaded checkpoint %s" % (load_name))

    if args.use_det:
        # DET and VID batches alternate (see below), so one epoch needs
        # twice as many iterations to cover the video data once.
        iters_per_epoch = 2 * int(train_size / args.batch_size)
    else:
        iters_per_epoch = int(train_size / args.batch_size)

    for epoch in range(args.start_epoch, args.max_epochs + 1):
        
        model.train()
        loss_temp = 0
        start = time.time()

        if epoch % (args.lr_decay_step + 1) == 0:
            adjust_learning_rate(optimizer, args.lr_decay_gamma)
            lr *= args.lr_decay_gamma

        if args.use_det:
            vid_data_iter = iter(vid_dataloader)
            det_data_iter = iter(det_dataloader)
        else:
            data_iter = iter(dataloader)

        for step in range(iters_per_epoch):
            
            if args.use_det:
                # Alternate training with samples from VID and DET:
                # odd steps draw a VID batch, even steps a DET batch.
                if step % 2 != 0:
                    data = next(vid_data_iter)
                else:
                    # this branch is cut off in the original listing;
                    # restored on the assumption that even steps draw DET
                    data = next(det_data_iter)
Example #2
def train():
    args = parse_args()

    print('Called with args:')
    print(args)
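    # fix the numpy / torch / CUDA RNG seeds so runs are reproducible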
    np.random.seed(4)
    torch.manual_seed(2017)
    torch.cuda.manual_seed(1086)

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    train_dataset = TDetDataset([args.dataset + '_train'],
                                training=True,
                                multi_scale=args.multiscale,
                                rotation=args.rotation,
                                pd=args.pd,
                                warping=args.warping,
                                prop_method=args.prop_method,
                                prop_min_scale=args.prop_min_scale,
                                prop_topk=args.prop_topk)
    val_dataset = TDetDataset([args.dataset + '_val'], training=False)
    tval_dataset = TDetDataset(['coco_voc_val'], training=False)

    lr = args.lr

    res_path = 'data/pretrained_model/resnet101_caffe.pth'
    vgg_path = 'data/pretrained_model/vgg16_caffe.pth'
    if args.net == 'UBR_VGG':
        UBR = UBR_VGG(vgg_path, not args.fc, not args.not_freeze,
                      args.no_dropout)
    elif args.net == 'UBR_RES':
        UBR = UBR_RES(res_path, 1, not args.fc)
    elif args.net == 'UBR_RES_FC2':
        UBR = UBR_RES_FC2(res_path, 1)
    elif args.net == 'UBR_RES_FC3':
        UBR = UBR_RES_FC3(res_path, 1)
    else:
        print("network is not defined")
        pdb.set_trace()

    UBR.create_architecture()

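    # Per-parameter optimizer groups, following the common Caffe-style
    # convention: biases get twice the base learning rate and no weight
    # decay; weights keep the base lr and (unless --no_wd) 0.0005 decay.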
    params = []
    for key, value in dict(UBR.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params += [{
                    'params': [value],
                    'lr': lr * 2,
                    'weight_decay': 0
                }]
            else:
                params += [{
                    'params': [value],
                    'lr': lr,
                    'weight_decay': 0 if args.no_wd else 0.0005
                }]

    optimizer = torch.optim.SGD(params, momentum=0.9)

    patience = 0
    last_optima = 999  # best validation loss seen so far (sentinel)
    if args.resume:
        load_name = os.path.join(
            output_dir, '{}_{}_{}.pth'.format(args.net, args.checksession,
                                              args.checkepoch))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_epoch = checkpoint['epoch']
        UBR.load_state_dict(checkpoint['model'])
        if not args.no_optim:
            if 'patience' in checkpoint:
                patience = checkpoint['patience']
            if 'last_optima' in checkpoint:
                last_optima = checkpoint['last_optima']
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr = optimizer.param_groups[0]['lr']
        print("loaded checkpoint %s" % (load_name))

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    if args.resume:
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    UBR.cuda()

    if args.loss == 'smoothl1':
        criterion = UBR_SmoothL1Loss(args.iou_th)
    elif args.loss == 'iou':
        criterion = UBR_IoULoss(args.iou_th)

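    # without precomputed proposals, training rois are random boxes sampled
    # around each gt box, with scales drawn from [1 - beta, 1 + beta]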
    if not args.use_prop:
        random_box_generator = NaturalUniformBoxGenerator(
            args.iou_th,
            pos_th=args.alpha,
            scale_min=1 - args.beta,
            scale_max=1 + args.beta)

    for epoch in range(args.start_epoch, args.max_epochs + 1):
        # setting to train mode
        UBR.train()
        loss_temp = 0

        effective_iteration = 0
        start = time.time()

        mean_boxes_per_iter = 0
        rand_perm = np.random.permutation(len(train_dataset))
        for step in range(1, len(train_dataset) + 1):
            index = rand_perm[step - 1]
            im_data, gt_boxes, box_labels, proposals, prop_scores, image_level_label, im_scale, raw_img, im_id, _ = train_dataset[
                index]

            data_height = im_data.size(1)
            data_width = im_data.size(2)
            im_data = Variable(im_data.unsqueeze(0).cuda())
            num_gt_box = gt_boxes.size(0)
            UBR.zero_grad()

            # generate random box from given gt box
            # the shape of rois is (n, 5), the first column is not used
            # so, rois[:, 1:5] is [xmin, ymin, xmax, ymax]
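            # e.g. one row of rois: [0.0, xmin, ymin, xmax, ymax]; column 0
            # is conventionally a batch-image index and stays 0 here since
            # each iteration feeds a single image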
            # sample 50 random boxes per gt box, capping the total at
            # roughly 200 when an image has many gt boxes
            num_per_base = 50
            if num_gt_box > 4:
                num_per_base = 200 // num_gt_box

            if args.use_prop:
                proposals = sample_pos_prop(proposals, gt_boxes, args.iou_th)
                if proposals is None:
                    # log_file.write('@@@@ no box @@@@\n')
                    # print('@@@@@ no box @@@@@')
                    continue

                rois = torch.zeros((proposals.size(0), 5))
                rois[:, 1:] = proposals
            else:
                rois = torch.zeros((num_per_base * num_gt_box, 5))
                cnt = 0
                for i in range(num_gt_box):
                    here = random_box_generator.get_rand_boxes(
                        gt_boxes[i, :], num_per_base, data_height, data_width)
                    if here is None:
                        continue
                    rois[cnt:cnt + here.size(0), :] = here
                    cnt += here.size(0)
                if cnt == 0:
                    log_file.write('@@@@ no box @@@@\n')
                    print('@@@@@ no box @@@@@')
                    continue
                rois = rois[:cnt, :]

            # debug visualization, disabled: as written, the bare `continue`
            # skipped every training step below
            # plt.imshow(raw_img)
            # plt.show()
            # continue

            mean_boxes_per_iter += rois.size(0)
            rois = Variable(rois.cuda())
            gt_boxes = Variable(gt_boxes.cuda())

            bbox_pred, shared_feat = UBR(im_data, rois)

            #refined_boxes = inverse_transform(rois[:, 1:].data, bbox_pred.data)
            #plt.imshow(raw_img)
            #draw_box(rois[:, 1:].data / im_scale)
            #draw_box(refined_boxes / im_scale, 'yellow')
            #draw_box(gt_boxes.data / im_scale, 'black')
            #plt.show()
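            # the criterion matches rois to gt boxes by IoU (threshold
            # iou_th), computes the regression loss over the matches, and
            # returns None when no roi matches any gt box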
            loss, num_selected_rois, num_rois, refined_rois = criterion(
                rois[:, 1:5], bbox_pred, gt_boxes)

            if loss is None:
                loss_temp = 1000000
                loss = Variable(torch.zeros(1).cuda())
                print('zero matched')

            loss = loss.mean()
            loss_temp += loss.data[0]

            # backward
            optimizer.zero_grad()

            loss.backward()
            clip_gradient([UBR], 10.0)

            optimizer.step()
            effective_iteration += 1

            if step % args.disp_interval == 0:
                end = time.time()
                loss_temp /= effective_iteration
                mean_boxes_per_iter /= effective_iteration

                print(
                    "[net %s][session %d][epoch %2d][iter %4d] loss: %.4f, lr: %.2e, time: %.1f, boxes: %.1f"
                    % (args.net, args.session, epoch, step, loss_temp, lr,
                       end - start, mean_boxes_per_iter))
                log_file.write(
                    "[net %s][session %d][epoch %2d][iter %4d] loss: %.4f, lr: %.2e, time: %.1f, boxes: %.1f\n"
                    % (args.net, args.session, epoch, step, loss_temp, lr,
                       end - start, mean_boxes_per_iter))
                loss_temp = 0
                effective_iteration = 0
                mean_boxes_per_iter = 0
                start = time.time()

            if math.isnan(loss_temp):
                print('@@@@@@@@@@@@@@nan@@@@@@@@@@@@@')
                log_file.write('@@@@@@@nan@@@@@@@@\n')
                return

        val_loss = validate(UBR,
                            None if args.use_prop else random_box_generator,
                            criterion, val_dataset, args)
        tval_loss = validate(UBR,
                             None if args.use_prop else random_box_generator,
                             criterion, tval_dataset, args)
        print('[net %s][session %d][epoch %2d] validation loss: %.4f' %
              (args.net, args.session, epoch, val_loss))
        log_file.write(
            '[net %s][session %d][epoch %2d] validation loss: %.4f\n' %
            (args.net, args.session, epoch, val_loss))
        print(
            '[net %s][session %d][epoch %2d] transfer validation loss: %.4f' %
            (args.net, args.session, epoch, tval_loss))
        log_file.write(
            '[net %s][session %d][epoch %2d] transfer validation loss: %.4f\n'
            % (args.net, args.session, epoch, tval_loss))

        log_file.flush()

        if args.auto_decay:
            # patience-based decay: each epoch that improves validation
            # loss by less than 0.001 costs one unit of patience; once
            # patience is exhausted, decay the learning rate and reset
            if last_optima - val_loss < 0.001:
                patience += 1
            if last_optima > val_loss:
                last_optima = val_loss

            if patience >= args.decay_patience:
                adjust_learning_rate(optimizer, args.lr_decay_gamma)
                lr *= args.lr_decay_gamma
                patience = 0
        else:
            if epoch % args.lr_decay_step == 0:
                adjust_learning_rate(optimizer, args.lr_decay_gamma)
                lr *= args.lr_decay_gamma

        if epoch % args.save_interval == 0 or lr < 0.000005:
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  epoch))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['epoch'] = epoch + 1
            checkpoint['model'] = UBR.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            checkpoint['patience'] = patience
            checkpoint['last_optima'] = last_optima

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

        if lr < 0.000005:
            break

    log_file.close()
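
These examples call a few helpers (adjust_learning_rate, clip_gradient, save_checkpoint) that the listing does not define. Below is a minimal sketch of what they typically look like in faster-rcnn.pytorch-derived codebases; treat the exact signatures as assumptions, not this repository's actual code. Note that clip_gradient is written to take a list of models, matching the call clip_gradient([UBR], 10.0) above.

import torch


def adjust_learning_rate(optimizer, decay=0.1):
    # scale the learning rate of every parameter group in place
    for param_group in optimizer.param_groups:
        param_group['lr'] = decay * param_group['lr']


def clip_gradient(models, clip_norm):
    # rescale the gradients of all given models so their joint L2 norm
    # does not exceed clip_norm
    totalnorm = 0.0
    for model in models:
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                totalnorm += float(p.grad.data.norm()) ** 2
    totalnorm = totalnorm ** 0.5
    norm = clip_norm / max(totalnorm, clip_norm)
    for model in models:
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad.data.mul_(norm)


def save_checkpoint(state, filename):
    torch.save(state, filename)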