def train_main(v2, use_refine_model, use_ecc, use_modulator, use_bn, use_residual, vis_roi_features, no_visrepr, vis_loss_ratio, no_vis_loss,
               modulate_from_vis, max_sample_frame, lr, weight_decay, batch_size, output_dir, ex_name):
    random.seed(12345)
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)
    np.random.seed(12345)
    torch.backends.cudnn.deterministic = True

    output_dir = osp.join(output_dir, ex_name)
    log_file = osp.join(output_dir, 'epoch_log.txt')

    if not osp.exists(output_dir):
        os.makedirs(output_dir)

    with open(log_file, 'w') as f:
        f.write('[Experiment name]%s\n\n' % ex_name)
        f.write('[Parameters]\n')
        f.write('use_ecc=%r\nuse_modulator=%r\nuse_bn=%r\nuse_residual=%r\nvis_roi_features=%r\nno_visrepr=%r\nvis_loss_ratio=%f\nno_vis_loss=%r\nmodulate_from_vis=%r\nmax_sample_frame=%d\nlr=%f\nweight_decay=%f\nbatch_size=%d\n\n' % 
            (use_ecc, use_modulator, use_bn, use_residual, vis_roi_features, no_visrepr, vis_loss_ratio, no_vis_loss, modulate_from_vis, max_sample_frame, lr, weight_decay, batch_size))
        f.write('[Loss log]\n')

    with open('experiments/cfgs/tracktor.yaml', 'r') as f:
        tracker_config = yaml.safe_load(f)

    #################
    # Load Datasets #
    #################
    train_set = MOT17SimpleReIDWrapper('train', 0.8, 0.0, max_sample_frame, tracker_cfg=tracker_config)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1, collate_fn=simple_reid_wrapper_collate)
    val_set = MOT17SimpleReIDWrapper('val', 0.8, 0.0, max_sample_frame, tracker_cfg=tracker_config)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=1, collate_fn=simple_reid_wrapper_collate)

    with open(osp.join(cfg.ROOT_DIR, 'output', 'precomputed_ecc_matrices_3.pkl'), 'rb') as f:
        ecc_dict = pickle.load(f)

    train_set.load_precomputed_ecc_warp_matrices(ecc_dict)
    val_set.load_precomputed_ecc_warp_matrices(ecc_dict)

    ########################
    # Initializing Modules #
    ########################
    obj_detect = FRCNN_FPN(num_classes=2)
    obj_detect.load_state_dict(torch.load(tracker_config['tracktor']['obj_detect_model'],
                               map_location=lambda storage, loc: storage))
    obj_detect.eval()
    obj_detect.cuda()

    if v2:
        motion_model = MotionModelSimpleReIDV2(use_modulator=use_modulator, use_bn=use_bn, use_residual=use_residual, 
                                               vis_roi_features=vis_roi_features, no_visrepr=no_visrepr, modulate_from_vis=modulate_from_vis)
    else:
        motion_model = MotionModelSimpleReID(use_modulator=use_modulator, use_bn=use_bn, use_residual=use_residual, 
                                             vis_roi_features=vis_roi_features, no_visrepr=no_visrepr, modulate_from_vis=modulate_from_vis)
    motion_model.train()
    motion_model.cuda()

    if use_refine_model:
        motion_model = RefineModel(motion_model)
        motion_model.train()
        motion_model.cuda()

    reid_network = resnet50(pretrained=False, output_dim=128)
    reid_network.load_state_dict(torch.load(tracker_config['tracktor']['reid_weights'],
                                 map_location=lambda storage, loc: storage))
    reid_network.eval()
    reid_network.cuda()

    optimizer = torch.optim.Adam(motion_model.parameters(), lr=lr, weight_decay=weight_decay)
    pred_loss_func = nn.SmoothL1Loss()
    vis_loss_func = nn.MSELoss()

    #######################
    # Training Parameters #
    #######################
    max_epochs = 100
    log_freq = 25

    train_pred_loss_epochs = []
    train_vis_loss_epochs = []
    val_pred_loss_epochs = []
    val_vis_loss_epochs = []
    lowest_val_loss = 9999999.9
    lowest_val_loss_epoch = -1
    last_save_epoch = 0
    save_epoch_freq = 5

    ############
    # Training #
    ############
    for epoch in range(max_epochs):
        n_iter = 0
        new_lowest_flag = False
        train_pred_loss_iters = []
        train_vis_loss_iters = []
        val_pred_loss_iters = []
        val_vis_loss_iters = []

        for data in train_loader:
            early_reid = get_batch_mean_early_reid(reid_network, data['early_reid_patches'])
            curr_reid = reid_network(data['curr_reid_patch'].cuda())

            conv_features, repr_features = get_features(obj_detect, data['curr_img'], data['curr_gt_app'])

            prev_loc = (data['prev_gt_warped'] if use_ecc else data['prev_gt']).cuda()
            curr_loc = (data['curr_gt_warped'] if use_ecc else data['curr_gt']).cuda()
            label_loc = data['label_gt'].cuda()
            curr_vis = data['curr_vis'].cuda()

            n_iter += 1
            optimizer.zero_grad()
            if use_refine_model:
                pred_loc_wh, vis = motion_model(obj_detect, data['label_img'], conv_features, repr_features, prev_loc, curr_loc,
                                                early_reid=early_reid, curr_reid=curr_reid)
                label_loc_wh = two_p_to_wh(label_loc)

                pred_loss = weighted_smooth_l1_loss(pred_loc_wh, label_loc_wh, curr_vis)
                vis_loss = vis_loss_func(vis, curr_vis)
            else:
                if v2:
                    pred_loc_wh, vis = motion_model(early_reid, curr_reid, repr_features, prev_loc, curr_loc)
                else:
                    pred_loc_wh, vis = motion_model(early_reid, curr_reid, conv_features, repr_features, prev_loc, curr_loc)
                label_loc_wh = two_p_to_wh(label_loc)

                pred_loss = pred_loss_func(pred_loc_wh, label_loc_wh)
                vis_loss = vis_loss_func(vis, curr_vis)
            
            if no_vis_loss:
                loss = pred_loss
            else:
                loss = pred_loss + vis_loss_ratio * vis_loss

            loss.backward()
            optimizer.step()

            train_pred_loss_iters.append(pred_loss.item())
            train_vis_loss_iters.append(vis_loss.item())
            if n_iter % log_freq == 0:
                print('[Train Iter %5d] train pred loss %.6f, vis loss %.6f ...' % 
                    (n_iter, np.mean(train_pred_loss_iters[n_iter-log_freq:n_iter]), np.mean(train_vis_loss_iters[n_iter-log_freq:n_iter])),
                    flush=True)

        mean_train_pred_loss = np.mean(train_pred_loss_iters)
        mean_train_vis_loss = np.mean(train_vis_loss_iters)
        train_pred_loss_epochs.append(mean_train_pred_loss)
        train_vis_loss_epochs.append(mean_train_vis_loss)
        print('Train epoch %4d end.' % (epoch + 1))

        motion_model.eval()

        with torch.no_grad():
            for data in val_loader:
                early_reid = get_batch_mean_early_reid(reid_network, data['early_reid_patches'])
                curr_reid = reid_network(data['curr_reid_patch'].cuda())

                conv_features, repr_features = get_features(obj_detect, data['curr_img'], data['curr_gt_app'])

                prev_loc = (data['prev_gt_warped'] if use_ecc else data['prev_gt']).cuda()
                curr_loc = (data['curr_gt_warped'] if use_ecc else data['curr_gt']).cuda()
                label_loc = data['label_gt'].cuda()
                curr_vis = data['curr_vis'].cuda()

                if use_refine_model:
                    pred_loc_wh, vis = motion_model(obj_detect, data['label_img'], conv_features, repr_features, prev_loc, curr_loc,
                                                    early_reid=early_reid, curr_reid=curr_reid)
                    label_loc_wh = two_p_to_wh(label_loc)

                    pred_loss = pred_loss_func(pred_loc_wh, label_loc_wh)
                    vis_loss = vis_loss_func(vis, curr_vis)
                else:
                    if v2:
                        pred_loc_wh, vis = motion_model(early_reid, curr_reid, repr_features, prev_loc, curr_loc)
                    else:
                        pred_loc_wh, vis = motion_model(early_reid, curr_reid, conv_features, repr_features, prev_loc, curr_loc)
                    label_loc_wh = two_p_to_wh(label_loc)

                    pred_loss = pred_loss_func(pred_loc_wh, label_loc_wh)
                    vis_loss = vis_loss_func(vis, curr_vis)

                val_pred_loss_iters.append(pred_loss.item())
                val_vis_loss_iters.append(vis_loss.item())

        mean_val_pred_loss = np.mean(val_pred_loss_iters)
        mean_val_vis_loss = np.mean(val_vis_loss_iters)
        val_pred_loss_epochs.append(mean_val_pred_loss)
        val_vis_loss_epochs.append(mean_val_vis_loss)

        print('[Epoch %4d] train pred loss %.6f, vis loss %.6f; val pred loss %.6f, vis loss %.6f' % 
            (epoch+1, mean_train_pred_loss, mean_train_vis_loss, mean_val_pred_loss, mean_val_vis_loss))

        motion_model.train()

        if mean_val_pred_loss < lowest_val_loss:
            lowest_val_loss, lowest_val_loss_epoch = mean_val_pred_loss, epoch + 1
            last_save_epoch = lowest_val_loss_epoch
            new_lowest_flag = True
            torch.save(motion_model.state_dict(), osp.join(output_dir, 'simple_reid_motion_model_epoch_%d.pth'%(epoch+1)))
        elif epoch + 1 - last_save_epoch == save_epoch_freq:
            last_save_epoch = epoch + 1
            torch.save(motion_model.state_dict(), osp.join(output_dir, 'simple_reid_motion_model_epoch_%d.pth'%(epoch+1)))

        with open(log_file, 'a') as f:
            f.write('Epoch %4d: train pred loss %.6f, vis loss %.6f; val pred loss %.6f, vis loss %.6f %s\n' % 
                (epoch+1, mean_train_pred_loss, mean_train_vis_loss, mean_val_pred_loss, mean_val_vis_loss, '*' if new_lowest_flag else ''))
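
# A minimal sketch of the weighted_smooth_l1_loss helper used in the
# RefineModel branch above; a hypothetical reconstruction, not the project's
# actual implementation: compute an element-wise smooth L1 loss and re-weight
# each box by its visibility (one plausible scheme; the real weighting may
# differ) before averaging.
def weighted_smooth_l1_loss_sketch(pred_loc_wh, label_loc_wh, vis):
    per_elem = nn.functional.smooth_l1_loss(pred_loc_wh, label_loc_wh,
                                            reduction='none')
    per_box = per_elem.mean(dim=-1)              # one value per box
    return (vis * per_box).sum() / vis.sum().clamp(min=1e-6)
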
def train_main(max_previous_frame, use_ecc, use_modulator, use_bn,
               vis_loss_ratio, no_vis_loss, lr, weight_decay, batch_size,
               output_dir, pretrain_vis_path, ex_name):
    random.seed(12345)
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)
    np.random.seed(12345)
    torch.backends.cudnn.deterministic = True

    output_dir = osp.join(output_dir, ex_name)
    log_file = osp.join(output_dir, 'epoch_log.txt')

    if not osp.exists(output_dir):
        os.makedirs(output_dir)

    with open(log_file, 'w') as f:
        f.write('[Experiment name]%s\n\n' % ex_name)
        f.write('[Parameters]\n')
        f.write(
            'max_previous_frame=%d\nuse_ecc=%r\nuse_modulator=%r\nvis_loss_ratio=%f\nno_vis_loss=%r\nlr=%f\nweight_decay=%f\nbatch_size=%d\n\n'
            % (max_previous_frame, use_ecc, use_modulator, vis_loss_ratio,
               no_vis_loss, lr, weight_decay, batch_size))
        f.write('[Loss log]\n')

    with open('experiments/cfgs/tracktor.yaml', 'r') as f:
        tracker_config = yaml.safe_load(f)

    #################
    # Load Datasets #
    #################
    train_set = MOT17TracksWrapper('train',
                                   0.8,
                                   0.0,
                                   input_track_len=max_previous_frame + 1,
                                   max_sample_frame=max_previous_frame,
                                   get_data_mode='sample' +
                                   (',ecc' if use_ecc else ''),
                                   tracker_cfg=tracker_config)
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1,
                              collate_fn=tracks_wrapper_collate)
    val_set = MOT17TracksWrapper('val',
                                 0.8,
                                 0.1,
                                 input_track_len=max_previous_frame + 1,
                                 max_sample_frame=max_previous_frame,
                                 get_data_mode='sample' +
                                 (',ecc' if use_ecc else ''),
                                 tracker_cfg=tracker_config)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=1,
                            collate_fn=tracks_wrapper_collate)

    with open(
            osp.join(cfg.ROOT_DIR, 'output', 'precomputed_ecc_matrices_3.pkl'),
            'rb') as f:
        ecc_dict = pickle.load(f)

    train_set.load_precomputed_ecc_warp_matrices(ecc_dict)
    val_set.load_precomputed_ecc_warp_matrices(ecc_dict)

    ########################
    # Initializing Modules #
    ########################

    motion_model = BackboneMotionModel(tracker_config=tracker_config,
                                       vis_conv_only=False,
                                       use_modulator=use_modulator,
                                       use_bn=use_bn)
    # motion_model.load_vis_pretrained(pretrain_vis_path)
    motion_model.train()

    # Keep every BatchNorm layer of the detection backbone in eval mode so its
    # pretrained running statistics are not drifted by the small batches used
    # here (rationale assumed from common fine-tuning practice).
    def set_bn_eval(m):
        classname = m.__class__.__name__
        if classname.find('BatchNorm') != -1:
            m.eval()

    motion_model.backbone.apply(set_bn_eval)
    motion_model.cuda()

    optimizer = torch.optim.Adam(motion_model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
    pred_loss_func = nn.SmoothL1Loss()
    vis_loss_func = nn.MSELoss()

    #######################
    # Training Parameters #
    #######################

    max_epochs = 100
    log_freq = 25

    train_pred_loss_epochs = []
    train_vis_loss_epochs = []
    val_pred_loss_epochs = []
    val_vis_loss_epochs = []
    lowest_val_loss = 9999999.9
    lowest_val_loss_epoch = -1

    ############
    # Training #
    ############

    for epoch in range(max_epochs):
        n_iter = 0
        train_pred_loss_iters = []
        train_vis_loss_iters = []
        val_pred_loss_iters = []
        val_vis_loss_iters = []

        for data, label in train_loader:
            images = data['curr_img']
            images = [img.cuda().squeeze(0) for img in images]

            # jitter target for getting roi features
            im_w = torch.tensor([img.size()[-1] for img in data['curr_img']],
                                dtype=data['curr_gt'].dtype)
            im_h = torch.tensor([img.size()[-2] for img in data['curr_img']],
                                dtype=data['curr_gt'].dtype)
            jittered_curr_gt = bbox_jitter(data['curr_gt'].clone(), im_w, im_h)

            target = jittered_curr_gt.cuda()
            target = [{"boxes": bbox.unsqueeze(0)} for bbox in target]

            prev_loc = (data['prev_gt_warped']
                        if use_ecc else data['prev_gt']).cuda()
            curr_loc = (data['curr_gt_warped']
                        if use_ecc else data['curr_gt']).cuda()
            label_loc = label['label_gt'].cuda()
            curr_vis = data['curr_vis'].cuda()
            label_loc_wh = two_p_to_wh(label_loc)

            n_iter += 1
            optimizer.zero_grad()
            pred_loc_wh, vis = motion_model(images, target, prev_loc, curr_loc)

            pred_loss = pred_loss_func(pred_loc_wh, label_loc_wh)
            vis_loss = vis_loss_func(vis, curr_vis)
            if no_vis_loss:
                loss = pred_loss
            else:
                loss = pred_loss + vis_loss_ratio * vis_loss

            loss.backward()
            optimizer.step()

            train_pred_loss_iters.append(pred_loss.item())
            train_vis_loss_iters.append(vis_loss.item())
            if n_iter % log_freq == 0:
                print(
                    '[Train Iter %5d] train pred loss %.6f, vis loss %.6f ...'
                    %
                    (n_iter,
                     np.mean(train_pred_loss_iters[n_iter - log_freq:n_iter]),
                     np.mean(train_vis_loss_iters[n_iter - log_freq:n_iter])),
                    flush=True)

        mean_train_pred_loss = np.mean(train_pred_loss_iters)
        mean_train_vis_loss = np.mean(train_vis_loss_iters)
        train_pred_loss_epochs.append(mean_train_pred_loss)
        train_vis_loss_epochs.append(mean_train_vis_loss)
        print('Train epoch %4d end.' % (epoch + 1))

        motion_model.eval()

        with torch.no_grad():
            for data, label in val_loader:
                images = data['curr_img']
                images = [img.cuda().squeeze(0) for img in images]

                # do not jitter for validation
                target = data['curr_gt'].cuda()
                target = [{"boxes": bbox.unsqueeze(0)} for bbox in target]

                prev_loc = (data['prev_gt_warped']
                            if use_ecc else data['prev_gt']).cuda()
                curr_loc = (data['curr_gt_warped']
                            if use_ecc else data['curr_gt']).cuda()
                label_loc = label['label_gt'].cuda()
                curr_vis = data['curr_vis'].cuda()
                label_loc_wh = two_p_to_wh(label_loc)

                pred_loc_wh, vis = motion_model(images, target, prev_loc,
                                                curr_loc)

                pred_loss = pred_loss_func(pred_loc_wh, label_loc_wh)
                vis_loss = vis_loss_func(vis, curr_vis)

                val_pred_loss_iters.append(pred_loss.item())
                val_vis_loss_iters.append(vis_loss.item())

        mean_val_pred_loss = np.mean(val_pred_loss_iters)
        mean_val_vis_loss = np.mean(val_vis_loss_iters)
        val_pred_loss_epochs.append(mean_val_pred_loss)
        val_vis_loss_epochs.append(mean_val_vis_loss)

        print(
            '[Epoch %4d] train pred loss %.6f, vis loss %.6f; val pred loss %.6f, vis loss %.6f'
            % (epoch + 1, mean_train_pred_loss, mean_train_vis_loss,
               mean_val_pred_loss, mean_val_vis_loss))
        with open(log_file, 'a') as f:
            f.write(
                'Epoch %4d: train pred loss %.6f, vis loss %.6f; val pred loss %.6f, vis loss %.6f\n'
                % (epoch + 1, mean_train_pred_loss, mean_train_vis_loss,
                   mean_val_pred_loss, mean_val_vis_loss))

        motion_model.train()
        if mean_val_pred_loss < lowest_val_loss:
            lowest_val_loss, lowest_val_loss_epoch = mean_val_pred_loss, epoch + 1
            torch.save(
                motion_model.state_dict(),
                osp.join(output_dir,
                         'motion_model_epoch_%d.pth' % (epoch + 1)))
        motion_model.backbone.apply(set_bn_eval)

# Example 3

def test_motion_model(val_loader, tracker_config, motion_model):
    obj_detect = FRCNN_FPN(num_classes=2)
    obj_detect.load_state_dict(
        torch.load(tracker_config['tracktor']['obj_detect_model'],
                   map_location=lambda storage, loc: storage))
    obj_detect.eval()
    obj_detect.cuda()

    reid_network = resnet50(pretrained=False, output_dim=128)
    reid_network.load_state_dict(
        torch.load(tracker_config['tracktor']['reid_weights'],
                   map_location=lambda storage, loc: storage))
    reid_network.eval()
    reid_network.cuda()

    pred_loss_func = nn.SmoothL1Loss()

    loss_iters = []
    low_vis_loss_sum = 0.0
    low_vis_num = 0
    high_vis_loss_sum = 0.0
    high_vis_num = 0
    total_iters = len(val_loader)
    n_iters = 0

    with torch.no_grad():
        for data in val_loader:
            n_iters += 1

            early_reid = get_batch_mean_early_reid(reid_network,
                                                   data['early_reid_patches'])
            curr_reid = reid_network(data['curr_reid_patch'].cuda())
            conv_features, repr_features = get_features(
                obj_detect, data['curr_img'], data['curr_gt_app'])

            prev_loc = data['prev_gt_warped'].cuda()
            curr_loc = data['curr_gt_warped'].cuda()
            label_loc = data['label_gt'].cuda()
            curr_vis = data['curr_vis'].cuda()

            pred_loc_wh, vis = motion_model(early_reid, curr_reid,
                                            conv_features, repr_features,
                                            prev_loc, curr_loc)
            label_loc_wh = two_p_to_wh(label_loc)

            pred_loss = pred_loss_func(pred_loc_wh, label_loc_wh)
            loss_iters.append(pred_loss.item())

            low_vis_ind = curr_vis < 0.3
            if low_vis_ind.any():
                low_vis_pred_loss = pred_loss_func(pred_loc_wh[low_vis_ind],
                                                   label_loc_wh[low_vis_ind])
                low_vis_loss_sum += (low_vis_pred_loss *
                                     torch.sum(low_vis_ind)).item()
                low_vis_num += torch.sum(low_vis_ind).item()

            high_vis_ind = curr_vis > 0.7
            if high_vis_ind.any():
                high_vis_pred_loss = pred_loss_func(pred_loc_wh[high_vis_ind],
                                                    label_loc_wh[high_vis_ind])
                high_vis_loss_sum += (high_vis_pred_loss *
                                      torch.sum(high_vis_ind)).item()
                high_vis_num += torch.sum(high_vis_ind).item()

            if n_iters % 50 == 0:
                print('Iter %5d/%5d finished.' % (n_iters, total_iters),
                      flush=True)

    mean_loss = np.mean(loss_iters)
    mean_low_vis_loss = low_vis_loss_sum / low_vis_num
    mean_high_vis_loss = high_vis_loss_sum / high_vis_num

    print('All finished! Loss %.6f, low vis loss %.6f, high vis loss %.6f.' %
          (mean_loss, mean_low_vis_loss, mean_high_vis_loss))
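
# A minimal sketch of get_batch_mean_early_reid as used above; a hypothetical
# reconstruction, not the project's actual helper (it mirrors the tracker code,
# which averages each track's early appearance features): embed every early
# patch of a sample with the ReID network and average the embeddings into a
# single 128-d feature per sample.
def get_batch_mean_early_reid_sketch(reid_network, early_reid_patches):
    feats = []
    for patches in early_reid_patches:           # patches: (n_early, 3, H, W)
        with torch.no_grad():
            emb = reid_network(patches.cuda())   # (n_early, 128)
        feats.append(emb.mean(dim=0))
    return torch.stack(feats, dim=0)             # (batch, 128)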

# Example 4

def test_tracktor_motion(val_loader, tracker_config, bbox_regression=True):
    obj_detect = FRCNN_FPN(num_classes=2)
    obj_detect.load_state_dict(
        torch.load(tracker_config['tracktor']['obj_detect_model'],
                   map_location=lambda storage, loc: storage))
    obj_detect.eval()
    obj_detect.cuda()

    pred_loss_func = nn.SmoothL1Loss()

    loss_iters = []
    low_vis_loss_sum = 0.0
    low_vis_num = 0
    high_vis_loss_sum = 0.0
    high_vis_num = 0
    total_iters = len(val_loader)
    n_iters = 0

    print('Total val batches: %d' % total_iters)

    with torch.no_grad():
        for data in val_loader:
            n_iters += 1

            prev_loc = data['prev_gt_warped'].cuda()
            curr_loc = data['curr_gt_warped'].cuda()
            label_loc = data['label_gt'].cuda()
            curr_vis = data['curr_vis'].cuda()

            # constant-velocity baseline: extrapolate the last inter-frame
            # motion of each box into the next frame
            pred_loc = curr_loc.clone()

            last_motion = curr_loc - prev_loc
            pred_loc += last_motion

            if bbox_regression:
                obj_detect.load_image(data['label_img'][0])
                pred_loc, _ = obj_detect.predict_boxes(pred_loc)

            label_loc_wh = two_p_to_wh(label_loc)
            pred_loc_wh = two_p_to_wh(pred_loc)

            pred_loss = pred_loss_func(pred_loc_wh, label_loc_wh)
            loss_iters.append(pred_loss.item())

            low_vis_ind = curr_vis < 0.3
            if low_vis_ind.any():
                low_vis_pred_loss = pred_loss_func(pred_loc_wh[low_vis_ind],
                                                   label_loc_wh[low_vis_ind])
                low_vis_loss_sum += (low_vis_pred_loss *
                                     torch.sum(low_vis_ind)).item()
                low_vis_num += torch.sum(low_vis_ind).item()

            high_vis_ind = curr_vis > 0.7
            if high_vis_ind.any():
                high_vis_pred_loss = pred_loss_func(pred_loc_wh[high_vis_ind],
                                                    label_loc_wh[high_vis_ind])
                high_vis_loss_sum += (high_vis_pred_loss *
                                      torch.sum(high_vis_ind)).item()
                high_vis_num += torch.sum(high_vis_ind).item()

            if n_iters % 500 == 0:
                print('Iter %5d/%5d finished.' % (n_iters, total_iters),
                      flush=True)

    mean_loss = np.mean(loss_iters)
    mean_low_vis_loss = low_vis_loss_sum / low_vis_num
    mean_high_vis_loss = high_vis_loss_sum / high_vis_num

    print('All finished! Loss %.6f, low vis loss %.6f, high vis loss %.6f.' %
          (mean_loss, mean_low_vis_loss, mean_high_vis_loss))
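
# A minimal sketch of the box-format helpers used throughout these examples.
# This is a hypothetical reconstruction (the actual parametrization could use
# corners instead of centers): two_p_to_wh maps a two-point box
# (x1, y1, x2, y2) to (cx, cy, w, h), and wh_to_two_p inverts it.
def two_p_to_wh_sketch(loc):
    w = loc[..., 2] - loc[..., 0]
    h = loc[..., 3] - loc[..., 1]
    cx = loc[..., 0] + 0.5 * w
    cy = loc[..., 1] + 0.5 * h
    return torch.stack([cx, cy, w, h], dim=-1)


def wh_to_two_p_sketch(loc_wh):
    cx, cy, w, h = loc_wh.unbind(dim=-1)
    return torch.stack([cx - 0.5 * w, cy - 0.5 * h,
                        cx + 0.5 * w, cy + 0.5 * h], dim=-1)


# The following method belongs to a tracker class (shown here out of its class
# context); it applies the trained motion model to advance track positions.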
    def motion(self, img, new_img=None):
        """
        Apply the neural motion model to all active tracks; apply last_v to all inactive tracks.
        Input:
            -img: (3, h, w) the LAST frame
            -new_img: (1, 3, h, w) the current frame; mandatory when using RefineModel.
        """
        last_pos_1 = [t.last_pos[-2] for t in self.tracks]
        last_pos_2 = [t.last_pos[-1] for t in self.tracks]  # same as t.pos
        last_pos_1 = torch.cat(last_pos_1, 0)
        last_pos_2 = torch.cat(last_pos_2, 0)

        pos = self.get_pos()
        if self.do_align:
            # use the unwarped pos
            pos_app = self.get_unwarped_pos()
        else:
            pos_app = pos

        if self.use_backbone_model:
            # clip boxes to the frame bounds before wrapping the image into
            # the list format expected by the backbone motion model
            pos_app = clip_boxes_to_image(pos_app, img.shape[-2:])
            img = [img.cuda()]

            target = [{"boxes": pos_app}]

            pred_motion = self.motion_model(img,
                                            target,
                                            last_pos_1,
                                            last_pos_2,
                                            output_motion=True)
        elif self.use_reid_motion_model:
            historical_reid_features = [
                torch.cat(list(t.features), 0) for t in self.tracks
            ]

            pos_app = clip_boxes_to_image(pos_app, img.shape[-2:])
            curr_reid_features = self.reid_network.test_rois(
                img.unsqueeze(0), pos_app)

            # note that the current frame has not been loaded yet, so we are still using the last frame
            conv_features, repr_features = self.get_pooled_features(pos_app)

            pred_motion = self.motion_model(historical_reid_features,
                                            curr_reid_features,
                                            conv_features,
                                            repr_features,
                                            last_pos_1,
                                            last_pos_2,
                                            output_motion=True)
        elif self.use_simple_reid_model or self.use_simple_reid_v2_model or self.use_refine_model or self.use_v3_model:
            early_reid_features = torch.stack([
                torch.mean(torch.cat(t.early_features, 0), 0)
                for t in self.tracks
            ], 0)

            pos_app = clip_boxes_to_image(pos_app, img.shape[-2:])
            curr_reid_features = self.reid_network.test_rois(
                img.unsqueeze(0), pos_app)

            conv_features, repr_features = self.get_pooled_features(pos_app)

            if self.use_simple_reid_model or self.use_v3_model:
                pred_motion = self.motion_model(early_reid_features,
                                                curr_reid_features,
                                                conv_features,
                                                repr_features,
                                                last_pos_1,
                                                last_pos_2,
                                                output_motion=True)
            elif self.use_simple_reid_v2_model:
                pred_motion = self.motion_model(early_reid_features,
                                                curr_reid_features,
                                                repr_features,
                                                last_pos_1,
                                                last_pos_2,
                                                output_motion=True)
            else:
                # RefineModel
                pred_motion = self.motion_model(self.obj_detect, [new_img],
                                                conv_features,
                                                repr_features,
                                                last_pos_1,
                                                last_pos_2,
                                                early_reid=early_reid_features,
                                                curr_reid=curr_reid_features,
                                                output_motion=True)

        else:
            # MotionModel/MotionModelV2
            pos_app = clip_boxes_to_image(pos_app, img.shape[-2:])
            conv_features, repr_features = self.get_pooled_features(pos_app)

            pred_motion = self.motion_model(conv_features,
                                            repr_features,
                                            last_pos_1,
                                            last_pos_2,
                                            output_motion=True)

        pos_wh = two_p_to_wh(pos)
        pred_pos_wh = decode_motion(pred_motion, pos_wh)
        pred_pos = wh_to_two_p(pred_pos_wh)

        for i in range(len(self.tracks)):
            self.tracks[i].last_v = pred_pos[i].unsqueeze(
                0) - self.tracks[i].pos
            self.tracks[i].pos = pred_pos[i].unsqueeze(0)

        if self.do_reid:
            for t in self.inactive_tracks:
                if t.last_v.nelement() > 0:
                    t.pos = t.pos + t.last_v
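
# A minimal sketch of decode_motion as used in the tracker method above.
# This is a hypothetical reconstruction assuming a Faster R-CNN style delta
# encoding: shift each (cx, cy) relative to the box size and scale (w, h)
# exponentially by the predicted motion deltas.
def decode_motion_sketch(pred_motion, pos_wh):
    cx = pos_wh[..., 0] + pred_motion[..., 0] * pos_wh[..., 2]
    cy = pos_wh[..., 1] + pred_motion[..., 1] * pos_wh[..., 3]
    w = pos_wh[..., 2] * torch.exp(pred_motion[..., 2])
    h = pos_wh[..., 3] * torch.exp(pred_motion[..., 3])
    return torch.stack([cx, cy, w, h], dim=-1)
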
def train_main(use_ecc, use_modulator, use_bn, use_residual, use_reid_distance, no_visrepr, vis_loss_ratio, no_vis_loss, motion_noise,
               lr, weight_decay, batch_size, output_dir, ex_name):
    random.seed(12345)
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)
    np.random.seed(12345)
    torch.backends.cudnn.deterministic = True

    output_dir = osp.join(output_dir, ex_name)
    log_file = osp.join(output_dir, 'epoch_log.txt')

    if not osp.exists(output_dir):
        os.makedirs(output_dir)

    with open(log_file, 'w') as f:
        f.write('[Experiment name]%s\n\n' % ex_name)
        f.write('[Parameters]\n')
        f.write('use_ecc=%r\nuse_modulator=%r\nuse_bn=%r\nuse_residual=%r\nuse_reid_distance=%r\nno_visrepr=%r\nvis_loss_ratio=%f\nno_vis_loss=%r\nmotion_noise=%f\nlr=%f\nweight_decay=%f\nbatch_size=%d\n\n' % 
            (use_ecc, use_modulator, use_bn, use_residual, use_reid_distance, no_visrepr, vis_loss_ratio, no_vis_loss, motion_noise, lr, weight_decay, batch_size))
        f.write('[Loss log]\n')

    with open('experiments/cfgs/tracktor.yaml', 'r') as f:
        tracker_config = yaml.safe_load(f)

    #################
    # Load Datasets #
    #################
    train_set = MOT17ClipsWrapper('train', 0.8, 0.0, clip_len=10, motion_noise=motion_noise, train_jitter=True, ecc=True, tracker_cfg=tracker_config)
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=1, collate_fn=clips_wrapper_collate)
    val_set = MOT17ClipsWrapper('val', 0.8, 0.0, clip_len=10, motion_noise=motion_noise, train_jitter=True, ecc=True, tracker_cfg=tracker_config)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1, collate_fn=clips_wrapper_collate)

    with open(osp.join(cfg.ROOT_DIR, 'output', 'precomputed_ecc_matrices_3.pkl'), 'rb') as f:
        ecc_dict = pickle.load(f)

    train_set.load_precomputed_ecc_warp_matrices(ecc_dict)
    val_set.load_precomputed_ecc_warp_matrices(ecc_dict)

    ########################
    # Initializing Modules #
    ########################
    obj_detect = FRCNN_FPN(num_classes=2)
    obj_detect.load_state_dict(torch.load(tracker_config['tracktor']['obj_detect_model'],
                               map_location=lambda storage, loc: storage))
    obj_detect.eval()
    obj_detect.cuda()

    motion_model = MotionModelReID(use_modulator=use_modulator, use_bn=use_bn, use_residual=use_residual, 
                                   use_reid_distance=use_reid_distance, no_visrepr=no_visrepr)

    motion_model.train()
    motion_model.cuda()

    reid_network = resnet50(pretrained=False, output_dim=128)
    reid_network.load_state_dict(torch.load(tracker_config['tracktor']['reid_weights'],
                                 map_location=lambda storage, loc: storage))
    reid_network.eval()
    reid_network.cuda()


    optimizer = torch.optim.Adam(motion_model.parameters(), lr=lr, weight_decay=weight_decay)
    pred_loss_func = nn.SmoothL1Loss()
    vis_loss_func = nn.MSELoss()

    #######################
    # Training Parameters #
    #######################

    # usage: historical_reid, curr_reid, roi_pool_output, representation_feature, prev_loc, curr_loc, curr_vis, label_loc
    batch_manager = BatchForgerManager([
        BatchForgerList(batch_size),
        BatchForger(batch_size, (motion_model.reid_dim,)),
        BatchForger(batch_size, (motion_model.roi_output_dim, motion_model.pool_size, motion_model.pool_size)),
        BatchForger(batch_size, (motion_model.representation_dim,)),
        BatchForger(batch_size, (4,)),
        BatchForger(batch_size, (4,)),
        BatchForger(batch_size, ()),
        BatchForger(batch_size, (4,))
    ])
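    # The BatchForger pattern, as used below: each clip yields a variable
    # number of per-track samples, so the manager buffers the eight streams
    # listed above and emits fixed-size batches via has_one_batch()/dump()
    # once enough samples have accumulated.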

    max_epochs = 100
    log_freq = 25

    train_pred_loss_epochs = []
    train_vis_loss_epochs = []
    val_pred_loss_epochs = []
    val_vis_loss_epochs = []
    lowest_val_loss = 9999999.9
    lowest_val_loss_epoch = -1

    ############
    # Training #
    ############

    for epoch in range(max_epochs):
        n_iter = 0
        train_pred_loss_iters = []
        train_vis_loss_iters = []
        val_pred_loss_iters = []
        val_vis_loss_iters = []

        for data in train_loader:
            historical_reid = get_batch_reid_features(reid_network, data['imgs'], data['historical'])
            curr_reid = get_curr_reid_features(reid_network, data['imgs'], data['curr_frame_offset'], data['curr_gt_app'])
            conv_features, repr_features = get_features(obj_detect, data['imgs'], data['curr_frame_offset'], data['curr_gt_app'])
            prev_loc = (data['prev_gt_warped'] if use_ecc else data['prev_gt']).cuda()
            curr_loc = (data['curr_gt_warped'] if use_ecc else data['curr_gt']).cuda()
            curr_vis = data['curr_vis'].cuda()
            label_loc = data['label_gt'].cuda()

            batch_manager.feed((historical_reid, curr_reid, conv_features, repr_features, prev_loc, curr_loc, curr_vis, label_loc))

            while batch_manager.has_one_batch():
                n_iter += 1
                historical_reid, curr_reid, conv_features, repr_features, prev_loc, curr_loc, curr_vis, label_loc = \
                    batch_manager.dump()

                optimizer.zero_grad()

                pred_loc_wh, vis = motion_model(historical_reid, curr_reid, conv_features, repr_features, prev_loc, curr_loc)
                label_loc_wh = two_p_to_wh(label_loc)

                pred_loss = pred_loss_func(pred_loc_wh, label_loc_wh)
                vis_loss = vis_loss_func(vis, curr_vis)
                if no_vis_loss:
                    loss = pred_loss
                else:
                    loss = pred_loss + vis_loss_ratio * vis_loss

                loss.backward()
                optimizer.step()

                train_pred_loss_iters.append(pred_loss.item())
                train_vis_loss_iters.append(vis_loss.item())
                if n_iter % log_freq == 0:
                    print('[Train Iter %5d] train pred loss %.6f, vis loss %.6f ...' % 
                        (n_iter, np.mean(train_pred_loss_iters[n_iter-log_freq:n_iter]), 
                         np.mean(train_vis_loss_iters[n_iter-log_freq:n_iter])), flush=True)

        mean_train_pred_loss = np.mean(train_pred_loss_iters)
        mean_train_vis_loss = np.mean(train_vis_loss_iters)
        train_pred_loss_epochs.append(mean_train_pred_loss)
        train_vis_loss_epochs.append(mean_train_vis_loss)
        print('Train epoch %4d end.' % (epoch + 1))

        batch_manager.reset()
        motion_model.eval()

        with torch.no_grad():
            for data in val_loader:
                historical_reid = get_batch_reid_features(reid_network, data['imgs'], data['historical'])
                curr_reid = get_curr_reid_features(reid_network, data['imgs'], data['curr_frame_offset'], data['curr_gt_app'])
                conv_features, repr_features = get_features(obj_detect, data['imgs'], data['curr_frame_offset'], data['curr_gt_app'])
                prev_loc = (data['prev_gt_warped'] if use_ecc else data['prev_gt']).cuda()
                curr_loc = (data['curr_gt_warped'] if use_ecc else data['curr_gt']).cuda()
                curr_vis = data['curr_vis'].cuda()
                label_loc = data['label_gt'].cuda()

                batch_manager.feed((historical_reid, curr_reid, conv_features, repr_features, prev_loc, curr_loc, curr_vis, label_loc))

                while batch_manager.has_one_batch():
                    historical_reid, curr_reid, conv_features, repr_features, prev_loc, curr_loc, curr_vis, label_loc = \
                        batch_manager.dump()

                    pred_loc_wh, vis = motion_model(historical_reid, curr_reid, conv_features, repr_features, prev_loc, curr_loc)
                    label_loc_wh = two_p_to_wh(label_loc)

                    pred_loss = pred_loss_func(pred_loc_wh, label_loc_wh)
                    vis_loss = vis_loss_func(vis, curr_vis)

                    val_pred_loss_iters.append(pred_loss.item())
                    val_vis_loss_iters.append(vis_loss.item())

        mean_val_pred_loss = np.mean(val_pred_loss_iters)
        mean_val_vis_loss = np.mean(val_vis_loss_iters)
        val_pred_loss_epochs.append(mean_val_pred_loss)
        val_vis_loss_epochs.append(mean_val_vis_loss)

        print('[Epoch %4d] train pred loss %.6f, vis loss %.6f; val pred loss %.6f, vis loss %.6f' % 
            (epoch+1, mean_train_pred_loss, mean_train_vis_loss, mean_val_pred_loss, mean_val_vis_loss), flush=True)
        with open(log_file, 'a') as f:
            f.write('Epoch %4d: train pred loss %.6f, vis loss %.6f; val pred loss %.6f, vis loss %.6f\n' % 
                (epoch+1, mean_train_pred_loss, mean_train_vis_loss, mean_val_pred_loss, mean_val_vis_loss))

        batch_manager.reset()
        motion_model.train()

        if mean_val_pred_loss < lowest_val_loss:
            lowest_val_loss, lowest_val_loss_epoch = mean_val_pred_loss, epoch + 1
            torch.save(motion_model.state_dict(), osp.join(output_dir, 'reid_motion_model_epoch_%d.pth'%(epoch+1)))

# Example 7

def train_main(oracle_training, no_visrepr, max_previous_frame, use_ecc,
               use_modulator, use_bn, use_residual, vis_loss_ratio,
               no_vis_loss, lr, weight_decay, batch_size, output_dir,
               pretrain_vis_path, ex_name):
    random.seed(12345)
    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)
    np.random.seed(12345)
    torch.backends.cudnn.deterministic = True

    output_dir = osp.join(output_dir, ex_name)
    log_file = osp.join(output_dir, 'epoch_log.txt')

    if not osp.exists(output_dir):
        os.makedirs(output_dir)

    with open(log_file, 'w') as f:
        f.write('[Experiment name]%s\n\n' % ex_name)
        f.write('[Parameters]\n')
        f.write(
            'oracle_training=%r\nno_visrepr=%r\nmax_previous_frame=%d\nuse_ecc=%r\nuse_modulator=%r\nuse_bn=%r\nuse_residual=%r\nvis_loss_ratio=%f\nno_vis_loss=%r\nlr=%f\nweight_decay=%f\nbatch_size=%d\n\n'
            % (oracle_training, no_visrepr, max_previous_frame, use_ecc,
               use_modulator, use_bn, use_residual, vis_loss_ratio,
               no_vis_loss, lr, weight_decay, batch_size))
        f.write('[Loss log]\n')

    with open('experiments/cfgs/tracktor.yaml', 'r') as f:
        tracker_config = yaml.safe_load(f)

    #################
    # Load Datasets #
    #################
    train_set = MOT17TracksWrapper('train',
                                   0.8,
                                   0.0,
                                   input_track_len=max_previous_frame + 1,
                                   max_sample_frame=max_previous_frame,
                                   get_data_mode='sample' +
                                   (',ecc' if use_ecc else ''),
                                   tracker_cfg=tracker_config)
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1,
                              collate_fn=tracks_wrapper_collate)
    val_set = MOT17TracksWrapper('val',
                                 0.8,
                                 0.1,
                                 input_track_len=max_previous_frame + 1,
                                 max_sample_frame=max_previous_frame,
                                 get_data_mode='sample' +
                                 (',ecc' if use_ecc else ''),
                                 tracker_cfg=tracker_config)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=1,
                            collate_fn=tracks_wrapper_collate)

    with open(
            osp.join(cfg.ROOT_DIR, 'output', 'precomputed_ecc_matrices_3.pkl'),
            'rb') as f:
        ecc_dict = pickle.load(f)

    train_set.load_precomputed_ecc_warp_matrices(ecc_dict)
    val_set.load_precomputed_ecc_warp_matrices(ecc_dict)

    ########################
    # Initializing Modules #
    ########################
    obj_detect = FRCNN_FPN(num_classes=2)
    obj_detect.load_state_dict(
        torch.load(tracker_config['tracktor']['obj_detect_model'],
                   map_location=lambda storage, loc: storage))
    obj_detect.eval()
    obj_detect.cuda()

    if oracle_training:
        motion_model = VisOracleMotionModel(vis_conv_only=False,
                                            use_modulator=use_modulator)
    else:
        if no_visrepr:
            motion_model = MotionModelNoVisRepr(vis_conv_only=False,
                                                use_modulator=use_modulator,
                                                use_bn=use_bn)
        else:
            motion_model = MotionModelV2(vis_conv_only=False,
                                         use_modulator=use_modulator,
                                         use_bn=use_bn,
                                         use_residual=use_residual)
    # motion_model.load_vis_pretrained(pretrain_vis_path)

    motion_model.train()
    motion_model.cuda()

    optimizer = torch.optim.Adam(motion_model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
    pred_loss_func = nn.SmoothL1Loss()
    vis_loss_func = nn.MSELoss()

    #######################
    # Training Parameters #
    #######################

    max_epochs = 100
    log_freq = 25

    train_pred_loss_epochs = []
    train_vis_loss_epochs = []
    val_pred_loss_epochs = []
    val_vis_loss_epochs = []
    lowest_val_loss = 9999999.9
    lowest_val_loss_epoch = -1

    ############
    # Training #
    ############

    for epoch in range(max_epochs):
        n_iter = 0
        train_pred_loss_iters = []
        train_vis_loss_iters = []
        val_pred_loss_iters = []
        val_vis_loss_iters = []

        for data, label in train_loader:
            # jitter bboxs for getting roi features
            im_w = torch.tensor([img.size()[-1] for img in data['curr_img']],
                                dtype=data['curr_gt'].dtype)
            im_h = torch.tensor([img.size()[-2] for img in data['curr_img']],
                                dtype=data['curr_gt'].dtype)
            jittered_curr_gt = bbox_jitter(data['curr_gt'].clone(), im_w, im_h)

            conv_features, repr_features = get_features(
                obj_detect, data['curr_img'], jittered_curr_gt)

            # for motion calculation, we still use the unjittered bboxs
            prev_loc = (data['prev_gt_warped']
                        if use_ecc else data['prev_gt']).cuda()
            curr_loc = (data['curr_gt_warped']
                        if use_ecc else data['curr_gt']).cuda()
            label_loc = label['label_gt'].cuda()
            curr_vis = data['curr_vis'].cuda()

            n_iter += 1
            # TODO the output bbox should be (x,y,w,h)?
            optimizer.zero_grad()
            if oracle_training:
                # the oracle variant receives the ground-truth visibility as an
                # extra input instead of estimating it
                pred_loc_wh, vis = motion_model(conv_features, repr_features,
                                                prev_loc, curr_loc,
                                                curr_vis.unsqueeze(-1))
            else:
                pred_loc_wh, vis = motion_model(conv_features, repr_features,
                                                prev_loc, curr_loc)
            label_loc_wh = two_p_to_wh(label_loc)

            pred_loss = pred_loss_func(pred_loc_wh, label_loc_wh)
            vis_loss = vis_loss_func(vis, curr_vis)
            if no_vis_loss:
                loss = pred_loss
            else:
                loss = pred_loss + vis_loss_ratio * vis_loss

            loss.backward()
            optimizer.step()

            train_pred_loss_iters.append(pred_loss.item())
            train_vis_loss_iters.append(vis_loss.item())
            if n_iter % log_freq == 0:
                print(
                    '[Train Iter %5d] train pred loss %.6f, vis loss %.6f ...'
                    %
                    (n_iter,
                     np.mean(train_pred_loss_iters[n_iter - log_freq:n_iter]),
                     np.mean(train_vis_loss_iters[n_iter - log_freq:n_iter])),
                    flush=True)

        mean_train_pred_loss = np.mean(train_pred_loss_iters)
        mean_train_vis_loss = np.mean(train_vis_loss_iters)
        train_pred_loss_epochs.append(mean_train_pred_loss)
        train_vis_loss_epochs.append(mean_train_vis_loss)
        print('Train epoch %4d end.' % (epoch + 1))

        motion_model.eval()

        with torch.no_grad():
            for data, label in val_loader:
                # do not jitter for validation
                conv_features, repr_features = get_features(
                    obj_detect, data['curr_img'], data['curr_gt'])

                prev_loc = (data['prev_gt_warped']
                            if use_ecc else data['prev_gt']).cuda()
                curr_loc = (data['curr_gt_warped']
                            if use_ecc else data['curr_gt']).cuda()
                label_loc = label['label_gt'].cuda()
                curr_vis = data['curr_vis'].cuda()

                if oracle_training:
                    pred_loc_wh, vis = motion_model(conv_features,
                                                    repr_features, prev_loc,
                                                    curr_loc,
                                                    curr_vis.unsqueeze(-1))
                else:
                    pred_loc_wh, vis = motion_model(conv_features,
                                                    repr_features, prev_loc,
                                                    curr_loc)
                label_loc_wh = two_p_to_wh(label_loc)

                pred_loss = pred_loss_func(pred_loc_wh, label_loc_wh)
                vis_loss = vis_loss_func(vis, curr_vis)

                val_pred_loss_iters.append(pred_loss.item())
                val_vis_loss_iters.append(vis_loss.item())

        mean_val_pred_loss = np.mean(val_pred_loss_iters)
        mean_val_vis_loss = np.mean(val_vis_loss_iters)
        val_pred_loss_epochs.append(mean_val_pred_loss)
        val_vis_loss_epochs.append(mean_val_vis_loss)

        print(
            '[Epoch %4d] train pred loss %.6f, vis loss %.6f; val pred loss %.6f, vis loss %.6f'
            % (epoch + 1, mean_train_pred_loss, mean_train_vis_loss,
               mean_val_pred_loss, mean_val_vis_loss))
        with open(log_file, 'a') as f:
            f.write(
                'Epoch %4d: train pred loss %.6f, vis loss %.6f; val pred loss %.6f, vis loss %.6f\n'
                % (epoch + 1, mean_train_pred_loss, mean_train_vis_loss,
                   mean_val_pred_loss, mean_val_vis_loss))

        motion_model.train()
        if mean_val_pred_loss < lowest_val_loss:
            lowest_val_loss, lowest_val_loss_epoch = mean_val_pred_loss, epoch + 1
            torch.save(
                motion_model.state_dict(),
                osp.join(output_dir,
                         'motion_model_epoch_%d.pth' % (epoch + 1)))
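
# A minimal sketch of bbox_jitter as used for the jittered ROI features above
# (hypothetical reconstruction, not the project's actual helper): perturb each
# (x1, y1, x2, y2) ground-truth box with noise proportional to its size, then
# clamp the result to the image bounds so the jittered ROI stays valid.
def bbox_jitter_sketch(boxes, im_w, im_h, noise_scale=0.05):
    w = (boxes[:, 2] - boxes[:, 0]).unsqueeze(1)
    h = (boxes[:, 3] - boxes[:, 1]).unsqueeze(1)
    scale = torch.cat([w, h, w, h], dim=1) * noise_scale
    jittered = boxes + torch.randn_like(boxes) * scale
    jittered[:, 0::2] = torch.min(jittered[:, 0::2].clamp(min=0),
                                  im_w.unsqueeze(1))
    jittered[:, 1::2] = torch.min(jittered[:, 1::2].clamp(min=0),
                                  im_h.unsqueeze(1))
    return jittered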