Example #1
        print("[session %d][epoch %2d][iter %4d/%4d] loss: %.4f, lr: %.2e" \
                                % (args.session, epoch, step, iters_per_epoch, loss_temp, lr))
        print("\t\t\tfg/bg=(%d/%d), time cost: %f" % (fg_cnt, bg_cnt, end-start))
        print("\t\t\trpn_cls: %.4f, rpn_box: %.4f, rcnn_cls: %.4f, rcnn_box %.4f" \
                      % (loss_rpn_cls, loss_rpn_box, loss_rcnn_cls, loss_rcnn_box))
        if args.use_tfboard:
          info = {
            'loss': loss_temp,
            'loss_rpn_cls': loss_rpn_cls,
            'loss_rpn_box': loss_rpn_box,
            'loss_rcnn_cls': loss_rcnn_cls,
            'loss_rcnn_box': loss_rcnn_box,
            'learning_rate': lr
          }
          for tag, value in info.items():
            logger.scalar_summary(tag, value, (epoch-1)*iters_per_epoch+step)

        loss_temp = 0
        start = time.time()

    if (epoch % args.checkpoint_interval == 0) or (epoch == args.max_epochs):
      if args.mGPUs:
        save_name = os.path.join(output_dir, 'faster_rcnn_{}_{}_{}.pth'.format(args.session, epoch, step))
        save_checkpoint({
          'session': args.session,
          'epoch': epoch + 1,
          'model': fasterRCNN.module.state_dict(),
          'optimizer': optimizer.state_dict(),
          'pooling_mode': cfg.POOLING_MODE,
          'class_agnostic': args.class_agnostic,
        }, save_name)
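Every loop in these examples ends by calling a `save_checkpoint` helper that the snippets never define. A minimal sketch, assuming it does nothing beyond serializing the assembled state dict with torch.save (which is how the public faster-rcnn.pytorch training script defines it):

import torch

def save_checkpoint(state, filename):
    # persist model/optimizer state dicts plus metadata (session, epoch, ...)
    torch.save(state, filename)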
Example #2
                print("[session %d][epoch %2d][iter %4d/%4d] loss: %.4f, lr: %.2e" \
                      % (args.session, epoch, step, iters_per_epoch, loss_temp, lr))
                print("\t\t\tfg/bg=(%d/%d), time cost: %f" % (fg_cnt, bg_cnt, end-start))
                print("\t\t\trpn_cls: %.4f, rpn_box: %.4f, rcnn_cls: %.4f, rcnn_box %.4f" \
                      % (loss_rpn_cls, loss_rpn_box, loss_rcnn_cls, loss_rcnn_box))
                if args.use_tfboard:
                    info = {
                        'loss': loss_temp,
                        'loss_rpn_cls': loss_rpn_cls,
                        'loss_rpn_box': loss_rpn_box,
                        'loss_rcnn_cls': loss_rcnn_cls,
                        'loss_rcnn_box': loss_rcnn_box
                    }
                    for tag, value in info.items():
                        logger.scalar_summary(tag, value, step)

                loss_temp = 0
                start = time.time()

        if args.mGPUs:
            save_name = os.path.join(output_dir, 'faster_rcnn_{}_{}_{}.pth'.format(args.session, epoch, step))
            save_checkpoint({
                'session': args.session,
                'epoch': epoch + 1,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'pooling_mode': cfg.POOLING_MODE,
                'class_agnostic': args.class_agnostic,
            }, save_name)
        else:
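The `logger.scalar_summary` calls assume a thin `Logger` wrapper (`model.utils.logger.Logger` in these repos). A minimal sketch built on torch.utils.tensorboard rather than the TF1-based writer the original repos ship; the class name and method signature mirror the usage above:

from torch.utils.tensorboard import SummaryWriter

class Logger(object):
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        # one scalar point per call, keyed by the global step
        self.writer.add_scalar(tag, value, step)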
Example #3
                print("\t\tTime Details: RPN: %.3f, Pre-RoI: %.3f, RoI: %.3f, Subnet: %.3f" \
                      % (time_measure[0], time_measure[1], time_measure[2], time_measure[3]))
                print("\t\trpn_cls: %.4f, rpn_box: %.4f, rcnn_cls: %.4f, rcnn_box %.4f" \
                      % (loss_rpn_cls, loss_rpn_box, loss_rcnn_cls, loss_rcnn_box))
                if args.use_tfboard:
                    info = {
                        'Total Loss': loss_temp,
                        'loss_rpn_cls': loss_rpn_cls,
                        'loss_rpn_box': loss_rpn_box,
                        'loss_rcnn_cls': loss_rcnn_cls,
                        'loss_rcnn_box': loss_rcnn_box,
                        'Learning Rate': lr,
                        'Time Cost': end - start
                    }
                    for tag, value in info.items():
                        logger.scalar_summary(
                            tag, value, step + ((epoch - 1) * iters_per_epoch))

                loss_temp = 0
                start = time.time()

                save_name = os.path.join(
                    output_dir,
                    '{}_{}_{}_{}.pth'.format(prefix, args.session, epoch,
                                             step))
                save_checkpoint(
                    {
                        'session': args.session,
                        'epoch': epoch,
                        'model': _RCNN.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'pooling_mode': cfg.POOLING_MODE,
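The examples below also call an `adjust_learning_rate` helper with a multiplicative factor (e.g. `adjust_learning_rate(optimizer, args.lr_decay_gamma)`). A minimal sketch, assuming it scales every parameter group by the given decay, as in the faster-rcnn.pytorch training utilities:

def adjust_learning_rate(optimizer, decay=0.1):
    # multiply the learning rate of every parameter group by `decay`
    for param_group in optimizer.param_groups:
        param_group['lr'] = decay * param_group['lr']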
Example #4
def train():
    # check cuda devices
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Training can only be done by GPU. Please use --cuda to enable training."
        )
    if torch.cuda.is_available() and not args.cuda:
        raise RuntimeError(
            "You have a CUDA device, so you should probably run with --cuda")

    # init random seed
    np.random.seed(cfg.RNG_SEED)

    # init logger
    # TODO: RESUME LOGGER
    if args.use_tfboard:
        from model.utils.logger import Logger
        # Set the logger
        current_t = time.strftime("%Y_%m_%d") + "_" + time.strftime("%H:%M:%S")
        logger = Logger(
            os.path.join(
                '.', 'logs', current_t + "_" + args.frame + "_" +
                args.dataset + "_" + args.net))

    # init dataset
    imdb, roidb, ratio_list, ratio_index, cls_list = combined_roidb(
        args.imdb_name)
    train_size = len(roidb)
    print('{:d} roidb entries'.format(len(roidb)))
    sampler_batch = sampler(train_size, args.batch_size)
    iters_per_epoch = int(train_size / args.batch_size)
    if args.frame in {"fpn", "faster_rcnn", "efc_det"}:
        dataset = fasterrcnnbatchLoader(roidb, ratio_list, ratio_index, args.batch_size, \
                                        len(cls_list), training=True, cls_list=cls_list,
                                        augmentation=cfg.TRAIN.COMMON.AUGMENTATION)
    elif args.frame in {"ssd"}:
        dataset = ssdbatchLoader(roidb, ratio_list, ratio_index, args.batch_size, \
                                 len(cls_list), training=True, cls_list=cls_list,
                                 augmentation=cfg.TRAIN.COMMON.AUGMENTATION)
    elif args.frame in {"ssd_vmrn", "vam"}:
        dataset = svmrnbatchLoader(roidb, ratio_list, ratio_index, args.batch_size, \
                                   len(cls_list), training=True, cls_list=cls_list,
                                   augmentation=cfg.TRAIN.COMMON.AUGMENTATION)
    elif args.frame in {"faster_rcnn_vmrn"}:
        dataset = fvmrnbatchLoader(roidb, ratio_list, ratio_index, args.batch_size, \
                                   len(cls_list), training=True, cls_list=cls_list,
                                   augmentation=cfg.TRAIN.COMMON.AUGMENTATION)
    elif args.frame in {"fcgn"}:
        dataset = fcgnbatchLoader(roidb, ratio_list, ratio_index, args.batch_size, \
                                  len(cls_list), training=True, cls_list=cls_list,
                                  augmentation=cfg.TRAIN.COMMON.AUGMENTATION)
    elif args.frame in {"all_in_one"}:
        dataset = fallinonebatchLoader(roidb, ratio_list, ratio_index, args.batch_size, \
                                       len(cls_list), training=True, cls_list=cls_list,
                                       augmentation=cfg.TRAIN.COMMON.AUGMENTATION)
    elif args.frame in {"mgn"}:
        dataset = roignbatchLoader(roidb, ratio_list, ratio_index, args.batch_size, \
                                   len(cls_list), training=True, cls_list=cls_list,
                                   augmentation=cfg.TRAIN.COMMON.AUGMENTATION)
    else:
        raise RuntimeError("unsupported frame: {}".format(args.frame))
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             sampler=sampler_batch,
                                             num_workers=args.num_workers)

    args.iter_per_epoch = int(len(roidb) / args.batch_size)

    # init output directory for model saving
    output_dir = args.save_dir + "/" + args.dataset + "/" + args.net
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if args.vis:
        visualizer = dataViewer(cls_list)
        data_vis_dir = os.path.join(args.save_dir, args.dataset, 'data_vis',
                                    'train')
        if not os.path.exists(data_vis_dir):
            os.makedirs(data_vis_dir)
        id_number_to_name = {}
        for r in roidb:
            id_number_to_name[r["img_id"]] = r["image"]

    # init network
    Network, optimizer = init_network(args, len(cls_list))

    # init variables
    current_result, best_result, loss_temp, loss_rpn_cls, loss_rpn_box, loss_rcnn_cls, loss_rcnn_box, loss_rel_pred, \
    loss_grasp_box, loss_grasp_cls, fg_cnt, bg_cnt, fg_grasp_cnt, bg_grasp_cnt = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
    save_flag, rois, rpn_loss_cls, rpn_loss_box, rel_loss_cls, cls_prob, bbox_pred, rel_cls_prob, loss_bbox, loss_cls, \
    rois_label, grasp_cls_loss, grasp_bbox_loss, grasp_conf_label = \
        False, None, None, None, None, None, None, None, None, None, None, None, None, None

    # initialize step counter
    if args.resume:
        step = args.checkpoint
    else:
        step = 0

    for epoch in range(args.checkepoch, args.max_epochs + 1):
        # setting to train mode
        Network.train()

        start_epoch_time = time.time()
        start = time.time()

        data_iter = iter(dataloader)
        while True:

            if step >= iters_per_epoch:
                break

            # get data batch
            data_batch = next(data_iter)
            if args.vis:
                for i in range(data_batch[0].size(0)):
                    data_list = [
                        data_batch[d][i] for d in range(len(data_batch))
                    ]
                    im_vis = vis_gt(data_list,
                                    visualizer,
                                    args.frame,
                                    train_mode=True)
                    # img_name = id_number_to_name[data_batch[1][i][4].item()].split("/")[-1]
                    img_name = str(int(data_batch[1][i][4].item())) + ".jpg"
                    # When using cv2.imwrite, channel order should be BGR
                    cv2.imwrite(os.path.join(data_vis_dir, img_name),
                                im_vis[:, :, ::-1])
            # ship to cuda
            if args.cuda:
                data_batch = makeCudaData(data_batch)

            # setting gradients to zeros
            Network.zero_grad()
            optimizer.zero_grad()

            # forward process
            if args.frame == 'faster_rcnn_vmrn':
                rois, cls_prob, bbox_pred, rel_cls_prob, rpn_loss_cls, rpn_loss_box, loss_cls, \
                loss_bbox, rel_loss_cls, reg_loss, rois_label = Network(data_batch)
                loss = (rpn_loss_cls + rpn_loss_box + loss_cls + loss_bbox +
                        reg_loss + rel_loss_cls).mean()
            elif args.frame == 'faster_rcnn' or args.frame == 'fpn':
                rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_box, loss_cls, loss_bbox, \
                rois_label = Network(data_batch)
                loss = (rpn_loss_cls + rpn_loss_box + loss_cls +
                        loss_bbox).mean()
            elif args.frame == 'fcgn':
                bbox_pred, cls_prob, loss_bbox, loss_cls, rois_label, rois = Network(
                    data_batch)
                loss = (loss_bbox + loss_cls).mean()
            elif args.frame == 'mgn':
                rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_box, loss_cls, loss_bbox, rois_label, grasp_loc, \
                grasp_prob, grasp_bbox_loss, grasp_cls_loss, grasp_conf_label, grasp_all_anchors = Network(data_batch)
                loss = rpn_loss_box.mean() + rpn_loss_cls.mean() + loss_cls.mean() + loss_bbox.mean() + \
                       cfg.MGN.OBJECT_GRASP_BALANCE * (grasp_bbox_loss.mean() + grasp_cls_loss.mean())
            elif args.frame == 'all_in_one':
                rois, cls_prob, bbox_pred, rel_cls_prob, rpn_loss_cls, rpn_loss_box, loss_cls, loss_bbox, rel_loss_cls, reg_loss, rois_label, \
                grasp_loc, grasp_prob, grasp_bbox_loss, grasp_cls_loss, grasp_conf_label, grasp_all_anchors = Network(
                    data_batch)
                loss = (rpn_loss_box + rpn_loss_cls + loss_cls + loss_bbox + rel_loss_cls + reg_loss \
                        + cfg.MGN.OBJECT_GRASP_BALANCE * grasp_bbox_loss + grasp_cls_loss).mean()

            elif args.frame in {'ssd', 'efc_det'}:
                bbox_pred, cls_prob, loss_bbox, loss_cls = Network(data_batch)
                loss = loss_bbox.mean() + loss_cls.mean()
            elif args.frame == 'ssd_vmrn' or args.frame == 'vam':
                bbox_pred, cls_prob, rel_result, loss_bbox, loss_cls, rel_loss_cls, reg_loss = Network(
                    data_batch)
                loss = (loss_cls + loss_bbox + rel_loss_cls + reg_loss).mean()
            loss_temp += loss.data.item()

            # backward process
            loss.backward()
            g_norm = gradient_norm(Network)
            if args.net == "vgg16":
                clip_gradient(Network, 10.)
            # print("Gradient norm:{:.3f}".format(g_norm))
            optimizer.step()
            step += 1

            # record training information
            if len(args.mGPUs) > 0:
                # multi-GPU losses come back per-device; .mean() reduces them
                if rpn_loss_cls is not None and isinstance(
                        rpn_loss_cls, torch.Tensor):
                    loss_rpn_cls += rpn_loss_cls.mean().item()
                if rpn_loss_box is not None and isinstance(
                        rpn_loss_box, torch.Tensor):
                    loss_rpn_box += rpn_loss_box.mean().item()
                if loss_cls is not None and isinstance(loss_cls, torch.Tensor):
                    loss_rcnn_cls += loss_cls.mean().item()
                if loss_bbox is not None and isinstance(
                        loss_bbox, torch.Tensor):
                    loss_rcnn_box += loss_bbox.mean().item()
                if rel_loss_cls is not None and isinstance(
                        rel_loss_cls, torch.Tensor):
                    loss_rel_pred += rel_loss_cls.mean().item()
                if grasp_cls_loss is not None and isinstance(
                        grasp_cls_loss, torch.Tensor):
                    loss_grasp_cls += grasp_cls_loss.mean().item()
                if grasp_bbox_loss is not None and isinstance(
                        grasp_bbox_loss, torch.Tensor):
                    loss_grasp_box += grasp_bbox_loss.mean().item()
                if rois_label is not None and isinstance(
                        rois_label, torch.Tensor):
                    tempfg = torch.sum(rois_label.data.ne(0))
                    fg_cnt += tempfg
                    bg_cnt += (rois_label.data.numel() - tempfg)
                if grasp_conf_label is not None and isinstance(
                        grasp_conf_label, torch.Tensor):
                    tempfg = torch.sum(grasp_conf_label.data.ne(0))
                    fg_grasp_cnt += tempfg
                    bg_grasp_cnt += (grasp_conf_label.data.numel() - tempfg)
            else:
                if rpn_loss_cls is not None and isinstance(
                        rpn_loss_cls, torch.Tensor):
                    loss_rpn_cls += rpn_loss_cls.item()
                if rpn_loss_box is not None and isinstance(
                        rpn_loss_box, torch.Tensor):
                    loss_rpn_box += rpn_loss_box.item()
                if loss_cls is not None and isinstance(loss_cls, torch.Tensor):
                    loss_rcnn_cls += loss_cls.item()
                if loss_bbox is not None and isinstance(
                        loss_bbox, torch.Tensor):
                    loss_rcnn_box += loss_bbox.item()
                if rel_loss_cls is not None and isinstance(
                        rel_loss_cls, torch.Tensor):
                    loss_rel_pred += rel_loss_cls.item()
                if grasp_cls_loss is not None and isinstance(
                        grasp_cls_loss, torch.Tensor):
                    loss_grasp_cls += grasp_cls_loss.item()
                if grasp_bbox_loss is not None and isinstance(
                        grasp_bbox_loss, torch.Tensor):
                    loss_grasp_box += grasp_bbox_loss.item()
                if rois_label is not None and isinstance(
                        rois_label, torch.Tensor):
                    tempfg = torch.sum(rois_label.data.ne(0))
                    fg_cnt += tempfg
                    bg_cnt += (rois_label.data.numel() - tempfg)
                if grasp_conf_label is not None and isinstance(
                        grasp_conf_label, torch.Tensor):
                    tempfg = torch.sum(grasp_conf_label.data.ne(0))
                    fg_grasp_cnt += tempfg
                    bg_grasp_cnt += (grasp_conf_label.data.numel() - tempfg)

            if Network.iter_counter % args.disp_interval == 0:
                end = time.time()
                loss_temp /= args.disp_interval
                loss_rpn_cls /= args.disp_interval
                loss_rpn_box /= args.disp_interval
                loss_rcnn_cls /= args.disp_interval
                loss_rcnn_box /= args.disp_interval
                loss_rel_pred /= args.disp_interval
                loss_grasp_cls /= args.disp_interval
                loss_grasp_box /= args.disp_interval

                print("[session %d][epoch %2d][iter %4d/%4d] \n\t\t\tloss: %.4f, lr: %.2e" \
                      % (args.session, epoch, step, iters_per_epoch, loss_temp, optimizer.param_groups[0]['lr']))
                print('\t\t\ttime cost: %f' % (end - start, ))
                if rois_label is not None:
                    print("\t\t\tfg/bg=(%d/%d)" % (fg_cnt, bg_cnt))
                if grasp_conf_label is not None:
                    print("\t\t\tgrasp_fg/grasp_bg=(%d/%d)" %
                          (fg_grasp_cnt, bg_grasp_cnt))
                if rpn_loss_box is not None and rpn_loss_cls is not None:
                    print("\t\t\trpn_cls: %.4f\n\t\t\trpn_box: %.4f\n\t\t\trcnn_cls: %.4f\n\t\t\trcnn_box %.4f" \
                          % (loss_rpn_cls, loss_rpn_box, loss_rcnn_cls, loss_rcnn_box))
                else:
                    print("\t\t\trcnn_cls: %.4f\n\t\t\trcnn_box %.4f" \
                          % (loss_rcnn_cls, loss_rcnn_box))
                if rel_loss_cls is not None:
                    print("\t\t\trel_loss %.4f" \
                          % (loss_rel_pred,))
                if grasp_cls_loss is not None and grasp_bbox_loss is not None:
                    print("\t\t\tgrasp_cls: %.4f\n\t\t\tgrasp_box %.4f" \
                          % (loss_grasp_cls, loss_grasp_box))
                if args.use_tfboard:
                    info = {
                        'loss': loss_temp,
                        'loss_rcnn_cls': loss_rcnn_cls,
                        'loss_rcnn_box': loss_rcnn_box,
                    }
                    if rpn_loss_cls:
                        info['loss_rpn_cls'] = loss_rpn_cls
                    if rpn_loss_box:
                        info['loss_rpn_box'] = loss_rpn_box
                    if rel_loss_cls:
                        info['loss_rel_pred'] = loss_rel_pred
                    for tag, value in info.items():
                        logger.scalar_summary(tag, value, Network.iter_counter)

                loss_temp = 0.
                loss_rpn_cls = 0.
                loss_rpn_box = 0.
                loss_rcnn_cls = 0.
                loss_rcnn_box = 0.
                loss_rel_pred = 0.
                loss_grasp_box = 0.
                loss_grasp_cls = 0.
                fg_cnt = 0.
                bg_cnt = 0.
                fg_grasp_cnt = 0.
                bg_grasp_cnt = 0.
                start = time.time()

            # adjust learning rate
            if args.lr_decay_step == 0:
                # clr = lr / (1 + decay * n) -> lr_n / lr_n+1 = (1 + decay * (n+1)) / (1 + decay * n)
                decay = (1 + args.lr_decay_gamma * Network.iter_counter) / (
                    1 + args.lr_decay_gamma * (Network.iter_counter + 1))
                adjust_learning_rate(optimizer, decay)
            elif Network.iter_counter % (args.lr_decay_step) == 0:
                adjust_learning_rate(optimizer, args.lr_decay_gamma)

            # test and save
            if (Network.iter_counter -
                    1) % cfg.TRAIN.COMMON.SNAPSHOT_ITERS == 0:
                # test network and record results

                if cfg.TRAIN.COMMON.SNAPSHOT_AFTER_TEST:
                    Network.eval()
                    with torch.no_grad():
                        current_result = evalute_model(Network,
                                                       args.imdbval_name, args)
                    torch.cuda.empty_cache()
                    if args.use_tfboard:
                        for key in current_result.keys():
                            logger.scalar_summary(key, current_result[key],
                                                  Network.iter_counter)
                    Network.train()
                    if current_result["Main_Metric"] > best_result:
                        best_result = current_result["Main_Metric"]
                        save_flag = True
                else:
                    save_flag = True

                if save_flag:
                    save_name = os.path.join(
                        output_dir, args.frame +
                        '_{}_{}_{}.pth'.format(args.session, epoch, step))
                    save_checkpoint(
                        {
                            'session': args.session,
                            'model': Network.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'pooling_mode': cfg.RCNN_COMMON.POOLING_MODE,
                            'class_agnostic': args.class_agnostic,
                        }, save_name)
                    print('save model: {}'.format(save_name))
                    save_flag = False

        end_epoch_time = time.time()
        print("Epoch finished. Time costing: ",
              end_epoch_time - start_epoch_time, "s")
        step = 0  # reset step counter
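Example #4 leans on two small gradient helpers it does not define. A minimal sketch, assuming `gradient_norm` reports the global L2 norm for monitoring and `clip_gradient` rescales gradients to a maximum norm (here via torch.nn.utils.clip_grad_norm_ rather than whatever the original repo implements):

import torch

def gradient_norm(model):
    # global L2 norm over all parameter gradients, for monitoring only
    total = 0.0
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            total += p.grad.norm().item() ** 2
    return total ** 0.5

def clip_gradient(model, clip_norm):
    # rescale all gradients so their global L2 norm is at most clip_norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)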
Example #5
def train(dataset="kaggle_pna",
          train_ds="train",
          arch="couplenet",
          net="res152",
          start_epoch=1,
          max_epochs=20,
          disp_interval=100,
          save_dir="save",
          num_workers=4,
          cuda=True,
          large_scale=False,
          mGPUs=True,
          batch_size=4,
          class_agnostic=False,
          anchor_scales=4,
          optimizer="sgd",
          lr_decay_step=10,
          lr_decay_gamma=.1,
          session=1,
          resume=False,
          checksession=1,
          checkepoch=1,
          checkpoint=0,
          use_tfboard=False,
          flip_prob=0.0,
          scale=0.0,
          scale_prob=0.0,
          translate=0.0,
          translate_prob=0.0,
          angle=0.0,
          dist="cont",
          rotate_prob=0.0,
          shear_factor=0.0,
          shear_prob=0.0,
          rpn_loss_cls_wt=1,
          rpn_loss_box_wt=1,
          RCNN_loss_cls_wt=1,
          RCNN_loss_bbox_wt=1,
          **kwargs):
    print("Train Arguments: {}".format(locals()))

    # Import network definition
    if arch == 'rcnn':
        from model.faster_rcnn.resnet import resnet
    elif arch == 'rfcn':
        from model.rfcn.resnet_atrous import resnet
    elif arch == 'couplenet':
        from model.couplenet.resnet_atrous import resnet

    from roi_data_layer.pnaRoiBatchLoader import roibatchLoader
    from roi_data_layer.pna_roidb import combined_roidb

    print('Called with kwargs:')
    print(kwargs)

    # Set up logger
    if use_tfboard:
        from model.utils.logger import Logger
        # Set the logger
        logger = Logger('./logs')

    # Anchor settings: ANCHOR_SCALES: [8, 16, 32] or [4, 8, 16, 32]
    if anchor_scales == 3:
        scales = [8, 16, 32]
    elif anchor_scales == 4:
        scales = [4, 8, 16, 32]

    # Dataset related settings: MAX_NUM_GT_BOXES: 20, 30, 50
    if train_ds == "train":
        imdb_name = "pna_2018_train"
    elif train_ds == "trainval":
        imdb_name = "pna_2018_trainval"

    set_cfgs = [
        'ANCHOR_SCALES',
        str(scales), 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '30'
    ]

    import model
    model_repo_path = os.path.dirname(
        os.path.dirname(os.path.dirname(model.__file__)))

    cfg_file = "cfgs/{}_ls.yml".format(
        net) if large_scale else "cfgs/{}.yml".format(net)

    if cfg_file is not None:
        cfg_from_file(os.path.join(model_repo_path, cfg_file))
    if set_cfgs is not None:
        cfg_from_list(set_cfgs)

    train_kwargs = kwargs.pop("TRAIN", None)
    resnet_kwargs = kwargs.pop("RESNET", None)
    mobilenet_kwargs = kwargs.pop("MOBILENET", None)

    if train_kwargs is not None:
        for key, value in train_kwargs.items():
            cfg["TRAIN"][key] = value

    if resnet_kwargs is not None:
        for key, value in resnet_kwargs.items():
            cfg["RESNET"][key] = value

    if mobilenet_kwargs is not None:
        for key, value in mobilenet_kwargs.items():
            cfg["MOBILENET"][key] = value

    if kwargs is not None:
        for key, value in kwargs.items():
            cfg[key] = value

    print('Using config:')
    cfg.MODEL_DIR = os.path.abspath(cfg.MODEL_DIR)
    cfg.TRAIN_DATA_CLEAN_PATH = os.path.abspath(cfg.TRAIN_DATA_CLEAN_PATH)
    pprint.pprint(cfg)
    np.random.seed(cfg.RNG_SEED)
    print("LEARNING RATE: {}".format(cfg.TRAIN.LEARNING_RATE))

    # Warning to use cuda if available
    if torch.cuda.is_available() and not cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    # Train set
    # Note: for faster loading, use the validation set and disable flipping.
    cfg.TRAIN.USE_FLIPPED = True
    cfg.USE_GPU_NMS = cuda
    imdb, roidb, ratio_list, ratio_index = combined_roidb(imdb_name)
    train_size = len(roidb)

    print('{:d} roidb entries'.format(len(roidb)))

    # output_dir = os.path.join(save_dir, arch, net, dataset)
    output_dir = cfg.MODEL_DIR
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    sampler_batch = sampler(train_size, batch_size)

    dataset = roibatchLoader(roidb,
                             ratio_list,
                             ratio_index,
                             batch_size,
                             imdb.num_classes,
                             training=True)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             sampler=sampler_batch,
                                             num_workers=num_workers)

    # Initialize the tensor holders
    im_data = torch.FloatTensor(1)
    im_info = torch.FloatTensor(1)
    num_boxes = torch.LongTensor(1)
    gt_boxes = torch.FloatTensor(1)

    # Copy tensors in CUDA memory
    if cuda:
        im_data = im_data.cuda()
        im_info = im_info.cuda()
        num_boxes = num_boxes.cuda()
        gt_boxes = gt_boxes.cuda()

    # Make variable
    im_data = Variable(im_data)
    im_info = Variable(im_info)
    num_boxes = Variable(num_boxes)
    gt_boxes = Variable(gt_boxes)

    if cuda:
        cfg.CUDA = True

    # Initialize the network:
    if net == 'vgg16':
        # model = vgg16(imdb.classes, pretrained=True, class_agnostic=args.class_agnostic)
        print("Pretrained model is not downloaded and network is not used")
    elif net == 'res18':
        model = resnet(imdb.classes,
                       18,
                       pretrained=False,
                       class_agnostic=class_agnostic)  # TODO: Check dim error
    elif net == 'res34':
        model = resnet(imdb.classes,
                       34,
                       pretrained=False,
                       class_agnostic=class_agnostic)  # TODO: Check dim error
    elif net == 'res50':
        model = resnet(imdb.classes,
                       50,
                       pretrained=False,
                       class_agnostic=class_agnostic)  # TODO: Check dim error
    elif net == 'res101':
        model = resnet(imdb.classes,
                       101,
                       pretrained=True,
                       class_agnostic=class_agnostic)
    elif net == 'res152':
        model = resnet(imdb.classes,
                       152,
                       pretrained=True,
                       class_agnostic=class_agnostic)
    else:
        print("network is not defined")
        pdb.set_trace()

    # Create network architecture
    model.create_architecture()

    # Update model parameters
    lr = cfg.TRAIN.LEARNING_RATE
    # tr_momentum = cfg.TRAIN.MOMENTUM
    # tr_momentum = args.momentum

    params = []
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params += [{'params': [value], 'lr': lr * (cfg.TRAIN.DOUBLE_BIAS + 1), \
                            'weight_decay': cfg.TRAIN.BIAS_DECAY and cfg.TRAIN.WEIGHT_DECAY or 0}]
            else:
                params += [{
                    'params': [value],
                    'lr': lr,
                    'weight_decay': cfg.TRAIN.WEIGHT_DECAY
                }]

    # Optimizer
    if optimizer == "adam":
        lr = lr * 0.1
        optimizer = torch.optim.Adam(params)

    elif optimizer == "sgd":
        optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)

    # Resume training
    if resume:
        load_name = os.path.join(
            output_dir, '{}_{}_{}_{}.pth'.format(arch, checksession,
                                                 checkepoch, checkpoint))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        session = checkpoint['session'] + 1
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr = optimizer.param_groups[0]['lr']
        if 'pooling_mode' in checkpoint.keys():
            cfg.POOLING_MODE = checkpoint['pooling_mode']
        print("loaded checkpoint %s" % (load_name))

    # Train on Multiple GPUS
    if mGPUs:
        model = nn.DataParallel(model)

    # Copy network to CUDA memory
    if cuda:
        model.cuda()

    # Training loop
    iters_per_epoch = int(train_size / batch_size)

    sys.stdout.flush()

    for epoch in range(start_epoch, max_epochs + 1):
        # remove batch re-sizing for augmentation or adjust?
        dataset.resize_batch()

        # Set model to train mode
        model.train()
        loss_temp = 0
        start = time.time()

        # Update learning rate as per decay step
        if epoch % (lr_decay_step + 1) == 0:
            adjust_learning_rate(optimizer, lr_decay_gamma)
            lr *= lr_decay_gamma

        # Get batch data and train
        data_iter = iter(dataloader)
        for step in range(iters_per_epoch):
            sys.stdout.flush()
            data = next(data_iter)

            # Apply augmentations
            aug_img_tensors, aug_bbox_tensors = apply_augmentations(
                data[0],
                data[2],
                flip_prob=flip_prob,
                scale=scale,
                scale_prob=scale_prob,
                translate=translate,
                translate_prob=translate_prob,
                angle=angle,
                dist=dist,
                rotate_prob=rotate_prob,
                shear_factor=shear_factor,
                shear_prob=shear_prob)

            # im_data.data.resize_(data[0].size()).copy_(data[0])
            im_data.data.resize_(aug_img_tensors.size()).copy_(aug_img_tensors)
            im_info.data.resize_(data[1].size()).copy_(data[1])
            # gt_boxes.data.resize_(data[2].size()).copy_(data[2])
            gt_boxes.data.resize_(
                aug_bbox_tensors.size()).copy_(aug_bbox_tensors)
            num_boxes.data.resize_(data[3].size()).copy_(data[3])

            # Compute multi-task loss
            model.zero_grad()
            rois, cls_prob, bbox_pred, \
            rpn_loss_cls, rpn_loss_box, \
            RCNN_loss_cls, RCNN_loss_bbox, \
            rois_label = model(im_data, im_info, gt_boxes, num_boxes)

            loss = rpn_loss_cls_wt * rpn_loss_cls.mean() + rpn_loss_box_wt * rpn_loss_box.mean() + \
                   RCNN_loss_cls_wt * RCNN_loss_cls.mean() + RCNN_loss_bbox_wt * RCNN_loss_bbox.mean()
            loss_temp += loss.data[0]

            # Backward pass to compute gradients and update weights
            optimizer.zero_grad()
            loss.backward()
            if net == "vgg16":
                clip_gradient(model, 10.)
            optimizer.step()

            # Display training stats on terminal
            if step % disp_interval == 0:
                end = time.time()
                if step > 0:
                    loss_temp /= disp_interval

                if mGPUs:
                    batch_loss = loss.data[0]
                    loss_rpn_cls = rpn_loss_cls.mean().data[0]
                    loss_rpn_box = rpn_loss_box.mean().data[0]
                    loss_rcnn_cls = RCNN_loss_cls.mean().data[0]
                    loss_rcnn_box = RCNN_loss_bbox.mean().data[0]
                    fg_cnt = torch.sum(rois_label.data.ne(0))
                    bg_cnt = rois_label.data.numel() - fg_cnt
                else:
                    batch_loss = loss.data[0]
                    loss_rpn_cls = rpn_loss_cls.data[0]
                    loss_rpn_box = rpn_loss_box.data[0]
                    loss_rcnn_cls = RCNN_loss_cls.data[0]
                    loss_rcnn_box = RCNN_loss_bbox.data[0]
                    fg_cnt = torch.sum(rois_label.data.ne(0))
                    bg_cnt = rois_label.data.numel() - fg_cnt

                print("[session %d][epoch %2d][iter %4d/%4d] loss: %.4f, lr: %.2e" \
                      % (session, epoch, step, iters_per_epoch, loss_temp, lr))
                print("\t\t\tfg/bg=(%d/%d), time cost: %f" %
                      (fg_cnt, bg_cnt, end - start))
                print("\t\t\t batch_loss: %.4f, rpn_cls: %.4f, rpn_box: %.4f, rcnn_cls: %.4f, rcnn_box %.4f" \
                      % (batch_loss, loss_rpn_cls, loss_rpn_box, loss_rcnn_cls, loss_rcnn_box))
                if use_tfboard:
                    info = {
                        'loss': loss_temp,
                        'loss_rpn_cls': loss_rpn_cls,
                        'loss_rpn_box': loss_rpn_box,
                        'loss_rcnn_cls': loss_rcnn_cls,
                        'loss_rcnn_box': loss_rcnn_box
                    }
                    for tag, value in info.items():
                        logger.scalar_summary(tag, value, step)

                loss_temp = 0
                start = time.time()

        # Save model at checkpoints
        if mGPUs:
            save_name = os.path.join(
                output_dir, '{}_{}_{}_{}.pth'.format(arch, session, epoch,
                                                     step))
            save_checkpoint(
                {
                    'session': session,
                    'epoch': epoch + 1,
                    'model': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'pooling_mode': cfg.POOLING_MODE,
                    'class_agnostic': class_agnostic,
                }, save_name)
        else:
            save_name = os.path.join(
                output_dir, '{}_{}_{}_{}.pth'.format(arch, session, epoch,
                                                     step))
            save_checkpoint(
                {
                    'session': session,
                    'epoch': epoch + 1,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'pooling_mode': cfg.POOLING_MODE,
                    'class_agnostic': class_agnostic,
                }, save_name)
        print('save model: {}'.format(save_name))

        end = time.time()
        delete_older_checkpoints(
            os.path.join(cfg.MODEL_DIR,
                         "couplenet_{}_*.pth".format(session)))
        print("Run Time: ", end - start)
Example #6
                                       thickness=-1)
                        action_classes = list(set(action_classes))
                        for j in range(len(action_classes)):
                            cv2.putText(
                                img,
                                imdb._action_names[int(action_classes[j])],
                                (0, img.shape[0] - 30 * (j) - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                        # for roi in rois:
                        #     cv2.rectangle(img, (int(roi[1]), int(roi[2])), (int(roi[3]), int(roi[4])), (255,0,0), thickness=2)
                        img_bbox.append(img)
                    info['image_kp_bbox'] = np.array(img_bbox)

                    for tag, value in info.items():
                        if "loss" in tag:
                            logger.scalar_summary(
                                tag, value, step + epoch * iters_per_epoch)
                        elif "image" in tag:
                            logger.image_summary(
                                tag, value, step + epoch * iters_per_epoch)

                loss_temp = 0
                loss_act_temp = 0
                loss_obj_temp = 0
                start = time.time()

        if args.mGPUs:
            save_name = os.path.join(
                output_dir,
                "faster_rcnn_{}_{}_{}.pth".format(args.session, epoch, step))
            save_checkpoint(
                {
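Example #6 also calls `logger.image_summary(tag, value, step)` with a stack of HxWxC images. A hedged sketch of such a method for the Logger wrapper sketched earlier, using add_image with a channel-last dataformats argument:

import numpy as np
from torch.utils.tensorboard import SummaryWriter

class Logger(object):
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)

    def image_summary(self, tag, images, step):
        # `images`: iterable of HxWxC uint8 arrays, like img_bbox above
        for i, img in enumerate(images):
            self.writer.add_image('%s/%d' % (tag, i), np.asarray(img), step,
                                  dataformats='HWC')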
Example #7
def bld_train(args, ann_path=None, step=0):

    # print('Train from annotation {}'.format(ann_path))
    # print('Called with args:')
    # print(args)

    if args.use_tfboard:
        from model.utils.logger import Logger
        # Set the logger
        logger = Logger(
            os.path.join('./.logs', args.active_method,
                         "activestep" + str(step)))

    if args.dataset == "pascal_voc":
        args.imdb_name = "voc_2007_trainval"
        args.imdbval_name = "voc_2007_test"
        args.set_cfgs = [
            'ANCHOR_SCALES', '[8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]',
            'MAX_NUM_GT_BOXES', '20'
        ]
    elif args.dataset == "pascal_voc_0712":
        args.imdb_name = "voc_2007_trainval+voc_2012_trainval"
        args.imdbval_name = "voc_2007_test"
        args.set_cfgs = [
            'ANCHOR_SCALES', '[8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]',
            'MAX_NUM_GT_BOXES', '20'
        ]
    elif args.dataset == "coco":
        args.imdb_name = "coco_2014_train"
        args.imdbval_name = "coco_2014_minival"
        args.set_cfgs = [
            'ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]',
            'MAX_NUM_GT_BOXES', '50'
        ]
    elif args.dataset == "imagenet":
        args.imdb_name = "imagenet_train"
        args.imdbval_name = "imagenet_val"
        args.set_cfgs = [
            'ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]',
            'MAX_NUM_GT_BOXES', '30'
        ]
    elif args.dataset == "vg":
        # train sizes: train, smalltrain, minitrain
        # train scale: ['150-50-20', '150-50-50', '500-150-80', '750-250-150', '1750-700-450', '1600-400-20']
        args.imdb_name = "vg_150-50-50_minitrain"
        args.imdbval_name = "vg_150-50-50_minival"
        args.set_cfgs = [
            'ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]',
            'MAX_NUM_GT_BOXES', '50'
        ]
    elif args.dataset == "voc_coco":
        args.imdb_name = "voc_coco_2007_train+voc_coco_2007_val"
        args.imdbval_name = "voc_coco_2007_test"
        args.set_cfgs = [
            'ANCHOR_SCALES', '[8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]',
            'MAX_NUM_GT_BOXES', '20'
        ]
    else:
        raise NotImplementedError

    args.cfg_file = "cfgs/{}_ls.yml".format(
        args.net) if args.large_scale else "cfgs/{}.yml".format(args.net)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    # print('Using config:')
    # pprint.pprint(cfg)
    # np.random.seed(cfg.RNG_SEED)

    # torch.backends.cudnn.benchmark = True
    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    # train set = source set + target set
    # -- Note: for faster loading, use the validation set and disable flipping.
    cfg.TRAIN.USE_FLIPPED = True
    cfg.USE_GPU_NMS = args.cuda
    # source train set, fully labeled
    #ann_path_source = os.path.join(ann_path, 'voc_coco_2007_train_f.json')
    #ann_path_target = os.path.join(ann_path, 'voc_coco_2007_train_l.json')
    imdb, roidb, ratio_list, ratio_index = combined_roidb(
        args.imdb_name, ann_path=os.path.join(ann_path, 'source'))
    imdb_tg, roidb_tg, ratio_list_tg, ratio_index_tg = combined_roidb(
        args.imdb_name, ann_path=os.path.join(ann_path, 'target'))

    print('{:d} roidb entries for source set'.format(len(roidb)))
    print('{:d} roidb entries for target set'.format(len(roidb_tg)))

    output_dir = args.save_dir + "/" + args.net + "/" + args.dataset + "/" + args.active_method + "/activestep" + str(
        step)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    sampler_batch_tg = None  # do not sample target set

    bs_tg = 4
    dataset_tg = roibatchLoader(roidb_tg, ratio_list_tg, ratio_index_tg, bs_tg, \
                             imdb_tg.num_classes, training=True)

    assert imdb.num_classes == imdb_tg.num_classes

    dataloader_tg = torch.utils.data.DataLoader(dataset_tg,
                                                batch_size=bs_tg,
                                                sampler=sampler_batch_tg,
                                                num_workers=args.num_workers,
                                                worker_init_fn=_rand_fn())

    # initialize the tensor holders here.
    im_data = torch.FloatTensor(1)
    im_info = torch.FloatTensor(1)
    num_boxes = torch.LongTensor(1)
    gt_boxes = torch.FloatTensor(1)
    image_label = torch.FloatTensor(1)
    confidence = torch.FloatTensor(1)

    # ship to cuda
    if args.cuda:
        im_data = im_data.cuda()
        im_info = im_info.cuda()
        num_boxes = num_boxes.cuda()
        gt_boxes = gt_boxes.cuda()
        image_label = image_label.cuda()
        confidence = confidence.cuda()

    # make variable
    im_data = Variable(im_data)
    im_info = Variable(im_info)
    num_boxes = Variable(num_boxes)
    gt_boxes = Variable(gt_boxes)
    image_label = Variable(image_label)
    confidence = Variable(confidence)

    if args.cuda:
        cfg.CUDA = True

    # initialize the network here.
    if args.net == 'vgg16':
        fasterRCNN = vgg16(imdb.classes,
                           pretrained=True,
                           class_agnostic=args.class_agnostic)
    elif args.net == 'res101':
        fasterRCNN = resnet(imdb.classes,
                            101,
                            pretrained=True,
                            class_agnostic=args.class_agnostic)
    elif args.net == 'res50':
        fasterRCNN = resnet(imdb.classes,
                            50,
                            pretrained=True,
                            class_agnostic=args.class_agnostic)
    elif args.net == 'res152':
        fasterRCNN = resnet(imdb.classes,
                            152,
                            pretrained=True,
                            class_agnostic=args.class_agnostic)
    else:
        print("network is not defined")
        raise NotImplementedError

    # initialize the expectation network.
    if args.net == 'vgg16':
        fasterRCNN_val = vgg16(imdb.classes,
                               pretrained=True,
                               class_agnostic=args.class_agnostic)
    elif args.net == 'res101':
        fasterRCNN_val = resnet(imdb.classes,
                                101,
                                pretrained=True,
                                class_agnostic=args.class_agnostic)
    elif args.net == 'res50':
        fasterRCNN_val = resnet(imdb.classes,
                                50,
                                pretrained=True,
                                class_agnostic=args.class_agnostic)
    elif args.net == 'res152':
        fasterRCNN_val = resnet(imdb.classes,
                                152,
                                pretrained=True,
                                class_agnostic=args.class_agnostic)
    else:
        print("network is not defined")
        raise NotImplementedError

    fasterRCNN.create_architecture()
    fasterRCNN_val.create_architecture()

    # lr = cfg.TRAIN.LEARNING_RATE
    lr = args.lr
    # tr_momentum = cfg.TRAIN.MOMENTUM
    # tr_momentum = args.momentum

    params = []
    for key, value in dict(fasterRCNN.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params += [{'params': [value], 'lr': lr * (cfg.TRAIN.DOUBLE_BIAS + 1), \
                            'weight_decay': cfg.TRAIN.BIAS_DECAY and cfg.TRAIN.WEIGHT_DECAY or 0}]
            else:
                params += [{
                    'params': [value],
                    'lr': lr,
                    'weight_decay': cfg.TRAIN.WEIGHT_DECAY
                }]

    if args.optimizer == "adam":
        lr = lr * 0.1
        optimizer = torch.optim.Adam(params)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)
    else:
        raise NotImplementedError

    if args.resume:
        load_name = os.path.join(
            output_dir,
            'faster_rcnn_{}_{}_{}.pth'.format(args.checksession,
                                              args.checkepoch,
                                              args.checkpoint))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        args.session = checkpoint['session']
        args.start_epoch = checkpoint['epoch']
        fasterRCNN.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr = optimizer.param_groups[0]['lr']
        if 'pooling_mode' in checkpoint.keys():
            cfg.POOLING_MODE = checkpoint['pooling_mode']
        print("loaded checkpoint %s" % (load_name))

    # expectation model
    print("load checkpoint for expectation model: %s" % args.model_path)
    checkpoint = torch.load(args.model_path)
    fasterRCNN_val.load_state_dict(checkpoint['model'])
    if 'pooling_mode' in checkpoint.keys():
        cfg.POOLING_MODE = checkpoint['pooling_mode']

    fasterRCNN_val.eval()

    if args.mGPUs:
        fasterRCNN = nn.DataParallel(fasterRCNN)
        #fasterRCNN_val = nn.DataParallel(fasterRCNN_val)

    if args.cuda:
        fasterRCNN.cuda()
        fasterRCNN_val.cuda()

    # Evaluation
    # data_iter = iter(dataloader_tg)
    # for target_k in range( int(train_size_tg / args.batch_size)):
    fname = "noisy_annotations.pkl"
    if not os.path.isfile(fname):
        for batch_k, data in enumerate(dataloader_tg):
            im_data.data.resize_(data[0].size()).copy_(data[0])
            im_info.data.resize_(data[1].size()).copy_(data[1])
            gt_boxes.data.resize_(data[2].size()).copy_(data[2])
            num_boxes.data.resize_(data[3].size()).copy_(data[3])
            image_label.data.resize_(data[4].size()).copy_(data[4])
            b_size = len(im_data)
            # expectation pass
            rois, cls_prob, bbox_pred, \
            _, _, _, _, _ = fasterRCNN_val(im_data, im_info, gt_boxes, num_boxes)
            scores = cls_prob.data
            boxes = rois.data[:, :, 1:5]
            if cfg.TRAIN.BBOX_REG:
                # Apply bounding-box regression deltas
                box_deltas = bbox_pred.data
                if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
                    # Optionally normalize targets by a precomputed mean and stdev
                    if args.class_agnostic:
                        box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() \
                                     + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
                        box_deltas = box_deltas.view(b_size, -1, 4)
                    else:
                        box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() \
                                     + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
                        # print('DEBUG: Size of box_deltas is {}'.format(box_deltas.size()) )
                        box_deltas = box_deltas.view(b_size, -1,
                                                     4 * len(imdb.classes))

                pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
                pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
            else:
                # Simply repeat the boxes, once for each class
                pred_boxes = np.tile(boxes, (1, scores.shape[1]))

            # TODO: data distillation
            # Choose the confident samples
            for b_idx in range(b_size):
                # fill one confidence
                # confidence.data[b_idx, :] = 1 - (gt_boxes.data[b_idx, :, 4] == 0)
                # resize prediction
                pred_boxes[b_idx] /= data[1][b_idx][2]
                for j in range(1, imdb.num_classes):
                    if image_label.data[b_idx, j] != 1:
                        continue  # next if no image label

                    # filtering box outside of the image
                    not_keep = (pred_boxes[b_idx][:, j * 4] == pred_boxes[b_idx][:, j * 4 + 2]) | \
                               (pred_boxes[b_idx][:, j * 4 + 1] == pred_boxes[b_idx][:, j * 4 + 3])
                    keep = torch.nonzero(not_keep == 0).view(-1)
                    # decrease the number of pseudo ground truths (pgts)
                    thresh = 0.5
                    while torch.nonzero(
                            scores[b_idx, :,
                                   j][keep] > thresh).view(-1).numel() <= 0:
                        thresh = thresh * 0.5
                    inds = torch.nonzero(
                        scores[b_idx, :, j][keep] > thresh).view(-1)

                    # if there is no det, error
                    if inds.numel() <= 0:
                        print('Warning: no detections above threshold; this should not happen.')
                        continue

                    # find missing ID
                    missing_list = np.where(gt_boxes.data[b_idx, :, 4] == 0)[0]
                    if (len(missing_list) == 0): continue
                    missing_id = missing_list[0]
                    cls_scores = scores[b_idx, :, j][keep][inds]
                    cls_boxes = pred_boxes[b_idx][keep][inds][:, j *
                                                              4:(j + 1) * 4]
                    cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)),
                                         1)
                    keep = nms(cls_dets, 0.2)  # Magic number ????
                    keep = keep.view(-1).tolist()
                    sys.stdout.write(
                        'from {} predictions choose-> min({},4) as pseudo label  \r'
                        .format(len(cls_scores), len(keep)))
                    sys.stdout.flush()
                    _, order = torch.sort(cls_scores[keep], 0, True)
                    if len(keep) == 0: continue

                    max_keep = 4
                    for pgt_k in range(max_keep):
                        if len(order) <= pgt_k: break
                        if missing_id + pgt_k >= 20: break
                        gt_boxes.data[b_idx, missing_id +
                                      pgt_k, :4] = cls_boxes[keep][order[
                                          len(order) - 1 - pgt_k]]
                        gt_boxes.data[b_idx, missing_id + pgt_k,
                                      4] = j  # class
                        #confidence[b_idx, missing_id + pgt_k] = cls_scores[keep][order[len(order) - 1 - pgt_k]]
                        num_boxes[b_idx] = num_boxes[b_idx] + 1
                sample = roidb_tg[dataset_tg.ratio_index[batch_k * bs_tg +
                                                         b_idx]]
                pgt_boxes = np.array([
                    gt_boxes[b_idx, x, :4].cpu().data.numpy()
                    for x in range(int(num_boxes[b_idx]))
                ])
                pgt_classes = np.array([
                    gt_boxes[b_idx, x, 4].cpu().data[0]
                    for x in range(int(num_boxes[b_idx]))
                ])
                sample["boxes"] = pgt_boxes
                sample["gt_classes"] = pgt_classes
                # DEBUG
                assert np.array_equal(sample["label"],image_label[b_idx].cpu().data.numpy()), \
                    "Image labels are not equal! {} vs {}".format(sample["label"],image_label[b_idx].cpu().data.numpy())

        #with open(fname, 'w') as f:
        # pickle.dump(roidb_tg, f)
    else:
        pass
        # with open(fname) as f:  # Python 3: open(..., 'rb')
        # roidb_tg = pickle.load(f)

    print("-- Optimization Stage --")
    # Optimization
    print("######################################################l")

    roidb.extend(roidb_tg)  # merge two datasets
    print('before filtering, there are %d images...' % (len(roidb)))
    i = 0
    while i < len(roidb):
        # drop images that have no ground-truth boxes left
        if len(roidb[i]['boxes']) == 0:
            del roidb[i]
        else:
            i += 1

    print('after filtering, there are %d images...' % (len(roidb)))
    from roi_data_layer.roidb import rank_roidb_ratio
    ratio_list, ratio_index = rank_roidb_ratio(roidb)
    train_size = len(roidb)
    sampler_batch = sampler(train_size, args.batch_size)
    dataset = roibatchLoader(roidb, ratio_list, ratio_index, args.batch_size, \
                             imdb.num_classes, training=True)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             sampler=sampler_batch,
                                             num_workers=args.num_workers,
                                             worker_init_fn=_rand_fn())
    iters_per_epoch = int(train_size / args.batch_size)
    print("Training set size is {}".format(train_size))
    for epoch in range(args.start_epoch, args.max_epochs + 1):
        fasterRCNN.train()

        loss_temp = 0
        start = time.time()
        epoch_start = start

        # adjust learning rate
        if epoch % (args.lr_decay_step + 1) == 0:
            adjust_learning_rate(optimizer, args.lr_decay_gamma)
            lr *= args.lr_decay_gamma

        # one step
        data_iter = iter(dataloader)
        for step in range(iters_per_epoch):
            data = next(data_iter)
            im_data.data.resize_(data[0].size()).copy_(data[0])
            im_info.data.resize_(data[1].size()).copy_(data[1])
            gt_boxes.data.resize_(data[2].size()).copy_(data[2])
            num_boxes.data.resize_(data[3].size()).copy_(data[3])
            image_label.data.resize_(data[4].size()).copy_(data[4])

            #gt_boxes.data = \
            #    torch.cat((gt_boxes.data, torch.zeros(gt_boxes.size(0), gt_boxes.size(1), 1).cuda()), dim=2)
            conf_data = torch.zeros(gt_boxes.size(0), gt_boxes.size(1)).cuda()
            confidence.data.resize_(conf_data.size()).copy_(conf_data)

            fasterRCNN.zero_grad()

            # rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes, confidence)
            rois, cls_prob, bbox_pred, \
            rpn_loss_cls, rpn_loss_box, \
            RCNN_loss_cls, RCNN_loss_bbox, \
            rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)
            # rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes, confidence)

            loss = rpn_loss_cls.mean() + rpn_loss_box.mean() \
                   + RCNN_loss_cls.mean() + RCNN_loss_bbox.mean()
            loss_temp += loss.data[0]

            # backward
            optimizer.zero_grad()
            loss.backward()
            if args.net == "vgg16":
                clip_gradient(fasterRCNN, 10.)
            optimizer.step()

            if step % args.disp_interval == 0:
                end = time.time()
                if step > 0:
                    loss_temp /= args.disp_interval

                if args.mGPUs:
                    loss_rpn_cls = rpn_loss_cls.mean().data[0]
                    loss_rpn_box = rpn_loss_box.mean().data[0]
                    loss_rcnn_cls = RCNN_loss_cls.mean().data[0]
                    loss_rcnn_box = RCNN_loss_bbox.mean().data[0]
                    fg_cnt = torch.sum(rois_label.data.ne(0))
                    bg_cnt = rois_label.data.numel() - fg_cnt
                else:
                    loss_rpn_cls = rpn_loss_cls.data[0]
                    loss_rpn_box = rpn_loss_box.data[0]
                    loss_rcnn_cls = RCNN_loss_cls.data[0]
                    loss_rcnn_box = RCNN_loss_bbox.data[0]
                    fg_cnt = torch.sum(rois_label.data.ne(0))
                    bg_cnt = rois_label.data.numel() - fg_cnt

                print("[session %d][epoch %2d][iter %4d/%4d] loss: %.4f, lr: %.2e" \
                      % (args.session, epoch, step, iters_per_epoch, loss_temp, lr))
                print("\t\t\tfg/bg=(%d/%d), time cost: %f" %
                      (fg_cnt, bg_cnt, end - start))
                print("\t\t\trpn_cls: %.4f, rpn_box: %.4f, rcnn_cls: %.4f, rcnn_box %.4f" \
                      % (loss_rpn_cls, loss_rpn_box, loss_rcnn_cls, loss_rcnn_box))
                if args.use_tfboard:
                    info = {
                        'loss': loss_temp,
                        'loss_rpn_cls': loss_rpn_cls,
                        'loss_rpn_box': loss_rpn_box,
                        'loss_rcnn_cls': loss_rcnn_cls,
                        'loss_rcnn_box': loss_rcnn_box
                    }
                    for tag, value in info.items():
                        logger.scalar_summary(tag, value, step)

                    images = []
                    for k in range(args.batch_size):
                        image = draw_bounding_boxes(
                            im_data[k].data.cpu().numpy(),
                            gt_boxes[k].data.cpu().numpy(),
                            im_info[k].data.cpu().numpy(),
                            num_boxes[k].data.cpu().numpy())
                        images.append(image)
                    logger.image_summary("Train epoch %2d, iter %4d/%4d" % (epoch, step, iters_per_epoch), \
                                          images, step)
                loss_temp = 0
                start = time.time()

        if args.mGPUs:
            save_name = os.path.join(
                output_dir,
                'faster_rcnn_{}_{}_{}.pth'.format(args.session, epoch, step))
            save_checkpoint(
                {
                    'session': args.session,
                    'epoch': epoch + 1,
                    'model': fasterRCNN.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'pooling_mode': cfg.POOLING_MODE,
                    'class_agnostic': args.class_agnostic,
                }, save_name)
        else:
            save_name = os.path.join(
                output_dir,
                'faster_rcnn_{}_{}_{}.pth'.format(args.session, epoch, step))
            save_checkpoint(
                {
                    'session': args.session,
                    'epoch': epoch + 1,
                    'model': fasterRCNN.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'pooling_mode': cfg.POOLING_MODE,
                    'class_agnostic': args.class_agnostic,
                }, save_name)
        print('save model: {}'.format(save_name))

        epoch_end = time.time()
        print('Epoch time cost: {}'.format(epoch_end - epoch_start))

    print('finished!')
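Example #7 applies `nms(cls_dets, 0.2)` to an (N, 5) tensor whose last column holds scores. The original repos ship a compiled NMS extension; a minimal stand-in with the same call shape, assuming torchvision is available:

from torchvision.ops import nms as tv_nms

def nms(dets, thresh):
    # dets: (N, 5) float tensor of [x1, y1, x2, y2, score]
    # returns indices of kept boxes, sorted by decreasing score
    return tv_nms(dets[:, :4], dets[:, 4], thresh)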