Example #1
def main():
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    args = args_process(args)

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    args.img_size = int(cfg['train_datasets']['search_size'])
    args.nms_threshold = float(cfg['train_datasets']['RPN_NMS'])
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True,
                       opts=args,
                       anchors=train_loader.dataset.anchors)
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)
    else:
        raise Exception("Pretrained weights must be loaded!")

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model,
                                       list(range(
                                           torch.cuda.device_count()))).cuda()

    logger.info('model prepare done')

    logger = logging.getLogger('global')
    val_avg = AverageMeter()

    validation(val_loader, dist_model, cfg, val_avg)
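
These examples all rely on an AverageMeter helper that accepts keyword updates (avg.update(batch_time=...)) and exposes a running average per metric (avg.batch_time.avg). Below is a minimal sketch of that interface, inferred only from how it is used in these snippets; the actual helper in the original repository may differ.

class _Meter:
    """Running statistics for a single metric (sketch, not the original code)."""

    def __init__(self, name):
        self.name = name
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val):
        self.val = float(val)  # single-value tensors also work here
        self.sum += self.val
        self.count += 1
        self.avg = self.sum / self.count

    def __str__(self):
        return '{} {:.4f} ({:.4f})'.format(self.name, self.val, self.avg)

    def __format__(self, spec):
        return str(self)  # supports '{batch_time:s}'-style format strings


class AverageMeter:
    """Keyword-driven collection of _Meter objects (sketch)."""

    def __init__(self):
        self._meters = {}

    def update(self, **kwargs):
        for key, value in kwargs.items():
            self._meters.setdefault(key, _Meter(key)).update(value)

    def __getattr__(self, key):  # avg.batch_time, avg.rpn_cls_loss, ...
        try:
            return self.__dict__['_meters'][key]
        except KeyError:
            raise AttributeError(key)


# quick usage check
avg = AverageMeter()
avg.update(batch_time=0.25, rpn_cls_loss=1.3)
print('{batch_time:s}'.format(batch_time=avg.batch_time), avg.batch_time.avg)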
Example #2
def train(train_loader, model, optimizer, lr_scheduler, epoch, cfg):
    global tb_index, best_acc, cur_lr, logger
    cur_lr = lr_scheduler.get_cur_lr()
    logger = logging.getLogger('global')
    avg = AverageMeter()
    model.train()
    model = model.cuda()
    end = time.time()

    def is_valid_number(x):
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    num_per_epoch = len(train_loader.dataset) // args.epochs // args.batch
    start_epoch = epoch
    epoch = epoch
    for iter, input in enumerate(train_loader):

        if epoch != iter // num_per_epoch + start_epoch:  # next epoch
            epoch = iter // num_per_epoch + start_epoch

            if not os.path.exists(args.save_dir):  # makedir/save model
                os.makedirs(args.save_dir)

            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.module.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'anchor_cfg': cfg['anchors']
                }, False,
                os.path.join(args.save_dir, 'checkpoint_e%d.pth' % (epoch)),
                os.path.join(args.save_dir, 'best.pth'))

            if epoch == args.epochs:
                return

            if model.module.features.unfix(epoch / args.epochs):
                logger.info('unfix part model.')
                optimizer, lr_scheduler = build_opt_lr(model.module, cfg, args,
                                                       epoch)

            lr_scheduler.step(epoch)
            cur_lr = lr_scheduler.get_cur_lr()

            logger.info('epoch:{}'.format(epoch))

        tb_index = iter
        if iter % num_per_epoch == 0 and iter != 0:
            for idx, pg in enumerate(optimizer.param_groups):
                logger.info("epoch {} lr {}".format(epoch, pg['lr']))
                tb_writer.add_scalar('lr/group%d' % (idx + 1), pg['lr'],
                                     tb_index)

        data_time = time.time() - end
        avg.update(data_time=data_time)
        x = {
            'cfg': cfg,
            'template': torch.autograd.Variable(input[0]).cuda(),
            'search': torch.autograd.Variable(input[1]).cuda(),
            'label_cls': torch.autograd.Variable(input[2]).cuda(),
            'label_loc': torch.autograd.Variable(input[3]).cuda(),
            'label_loc_weight': torch.autograd.Variable(input[4]).cuda(),
            'label_mask': torch.autograd.Variable(input[6]).cuda(),
            'label_mask_weight': torch.autograd.Variable(input[7]).cuda(),
        }

        outputs = model(x)

        rpn_cls_loss, rpn_loc_loss, rpn_mask_loss = torch.mean(
            outputs['losses'][0]), torch.mean(
                outputs['losses'][1]), torch.mean(outputs['losses'][2])
        mask_iou_mean, mask_iou_at_5, mask_iou_at_7 = torch.mean(
            outputs['accuracy'][0]), torch.mean(
                outputs['accuracy'][1]), torch.mean(outputs['accuracy'][2])

        cls_weight, reg_weight, mask_weight = cfg['loss']['weight']

        loss = rpn_cls_loss * cls_weight + rpn_loc_loss * reg_weight + rpn_mask_loss * mask_weight

        optimizer.zero_grad()
        loss.backward()

        if cfg['clip']['split']:
            torch.nn.utils.clip_grad_norm_(model.module.features.parameters(),
                                           cfg['clip']['feature'])
            torch.nn.utils.clip_grad_norm_(model.module.rpn_model.parameters(),
                                           cfg['clip']['rpn'])
            torch.nn.utils.clip_grad_norm_(
                model.module.mask_model.parameters(), cfg['clip']['mask'])
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.clip)  # gradient clip

        if is_valid_number(loss.item()):
            optimizer.step()

        siammask_loss = loss.item()

        batch_time = time.time() - end

        avg.update(batch_time=batch_time,
                   rpn_cls_loss=rpn_cls_loss,
                   rpn_loc_loss=rpn_loc_loss,
                   rpn_mask_loss=rpn_mask_loss,
                   siammask_loss=siammask_loss,
                   mask_iou_mean=mask_iou_mean,
                   mask_iou_at_5=mask_iou_at_5,
                   mask_iou_at_7=mask_iou_at_7)

        tb_writer.add_scalar('loss/cls', rpn_cls_loss, tb_index)
        tb_writer.add_scalar('loss/loc', rpn_loc_loss, tb_index)
        tb_writer.add_scalar('loss/mask', rpn_mask_loss, tb_index)
        tb_writer.add_scalar('mask/mIoU', mask_iou_mean, tb_index)
        tb_writer.add_scalar('mask/mIoU@0.5', mask_iou_at_5, tb_index)
        tb_writer.add_scalar('mask/mIoU@0.7', mask_iou_at_7, tb_index)
        end = time.time()

        if (iter + 1) % args.print_freq == 0:
            logger.info(
                'Epoch: [{0}][{1}/{2}] lr: {lr:.6f}\t{batch_time:s}\t{data_time:s}'
                '\t{rpn_cls_loss:s}\t{rpn_loc_loss:s}\t{rpn_mask_loss:s}\t{siammask_loss:s}'
                '\t{mask_iou_mean:s}\t{mask_iou_at_5:s}\t{mask_iou_at_7:s}'.
                format(epoch + 1, (iter + 1) % num_per_epoch,
                       num_per_epoch,
                       lr=cur_lr,
                       batch_time=avg.batch_time,
                       data_time=avg.data_time,
                       rpn_cls_loss=avg.rpn_cls_loss,
                       rpn_loc_loss=avg.rpn_loc_loss,
                       rpn_mask_loss=avg.rpn_mask_loss,
                       siammask_loss=avg.siammask_loss,
                       mask_iou_mean=avg.mask_iou_mean,
                       mask_iou_at_5=avg.mask_iou_at_5,
                       mask_iou_at_7=avg.mask_iou_at_7))
            print_speed(iter + 1, avg.batch_time.avg,
                        args.epochs * num_per_epoch)
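
The checkpoints written by save_checkpoint above bundle the epoch, architecture name, model and optimizer state, and the anchor config. A hedged sketch of reloading one such file follows; the function name and path are illustrative, only the dictionary keys come from the code above.

import torch

def load_training_checkpoint(model, optimizer, path='snapshot/checkpoint_e10.pth'):
    # Illustrative reload of a checkpoint saved by the training loop above;
    # the keys mirror the dict passed to save_checkpoint, the path is an example.
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt['epoch'], ckpt['best_acc'], ckpt['anchor_cfg']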
Example #3
def main():
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    args = args_process(args)

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    args.img_size = int(cfg['train_datasets']['search_size'])
    args.nms_threshold = float(cfg['train_datasets']['RPN_NMS'])
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True,
                       opts=args,
                       anchors=train_loader.dataset.anchors)
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model,
                                       list(range(
                                           torch.cuda.device_count()))).cuda()

    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(
            model, optimizer, args.resume)
        dist_model = torch.nn.DataParallel(
            model, list(range(torch.cuda.device_count()))).cuda()

    logger.info(lr_scheduler)

    logger.info('model prepare done')
    global cur_lr

    if not os.path.exists(args.save_dir):  # makedir/save model
        os.makedirs(args.save_dir)
    num_per_epoch = len(train_loader.dataset) // args.batch
    num_per_epoch_val = len(val_loader.dataset) // args.batch

    for epoch in range(args.start_epoch, args.epochs):
        lr_scheduler.step(epoch)
        cur_lr = lr_scheduler.get_cur_lr()
        logger = logging.getLogger('global')
        train_avg = AverageMeter()
        val_avg = AverageMeter()

        if dist_model.module.features.unfix(epoch / args.epochs):
            logger.info('unfix part model.')
            optimizer, lr_scheduler = build_opt_lr(dist_model.module, cfg,
                                                   args, epoch)

        train(train_loader, dist_model, optimizer, lr_scheduler, epoch, cfg,
              train_avg, num_per_epoch)

        if dist_model.module.features.unfix(epoch / args.epochs):
            logger.info('unfix part model.')
            optimizer, lr_scheduler = build_opt_lr(dist_model.module, cfg,
                                                   args, epoch)

        if (epoch + 1) % args.save_freq == 0:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': dist_model.module.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'anchor_cfg': cfg['anchors']
                }, False,
                os.path.join(args.save_dir, 'checkpoint_e%d.pth' % (epoch)),
                os.path.join(args.save_dir, 'best.pth'))

            validation(val_loader, dist_model, epoch, cfg, val_avg,
                       num_per_epoch_val)
Example #4
def validation(epoch, log_interval, test_dataloader, model, loss, writer,
               device):
    """Validate on test dataset.

    Current validation is only for loss, pos|neg_distance.
    In future, we will add more validation like MAP5|10|50|100. 
    (maybe in another file.)

    Args:
        log_interval:
            How many time will the logger log once.
        test_dataloader:
            It should not be none! A Triplet dataloader to validate data.
        model:
            The model that used to test on dataset.
        loss: 
            Loss metric.
        writer:
            Tensorboard writer
        device: 
            Device that model compute on

    Return:
        epoch avrage value:
            triplet_loss, pos_dists, neg_dists
    
    """
    logger.info(
        "\n------------------------- Start validation -------------------------\n"
    )
    # epoch average meter
    avg_test = AverageMeter()

    # get test batch count
    current_test_batch = 0
    total_test_batch = len(test_dataloader)

    # check dataloader is not None
    assert test_dataloader is not None, "test_dataloader should not be None."

    for batch_idx, batch_sample in enumerate(test_dataloader):
        # Skip last iteration to avoid the problem of having different number of tensors while calculating
        # averages (sizes of tensors must be the same for pairwise distance calculation)
        if batch_idx + 1 == len(test_dataloader):
            continue

        # switch to evaluation mode.
        for param in model.parameters():
            param.requires_grad = False
        model.eval()

        # start time counting
        batch_start_time_test = time.time()

        # Forward pass - compute embeddings
        anc_imgs = batch_sample['anchor_img']
        pos_imgs = batch_sample['pos_img']
        neg_imgs = batch_sample['neg_img']

        pos_cls = batch_sample['pos_cls']
        neg_cls = batch_sample['neg_cls']

        # move to device
        anc_imgs = anc_imgs.to(device)
        pos_imgs = pos_imgs.to(device)
        neg_imgs = neg_imgs.to(device)
        pos_cls = pos_cls.to(device)
        neg_cls = neg_cls.to(device)

        # forward
        output = model.forward_triplet(anc_imgs, pos_imgs, neg_imgs)

        # get output
        anc_emb = output['anchor_map']
        pos_emb = output['pos_map']
        neg_emb = output['neg_map']

        pos_dists = torch.mean(output['dist_pos'])
        neg_dists = torch.mean(output['dist_neg'])

        # loss compute
        loss_value = loss(anc_emb, pos_emb, neg_emb)

        # batch time & batch count
        current_test_batch += 1
        batch_time = time.time() - batch_start_time_test

        # update avg
        avg_test.update(time=batch_time,
                        triplet_loss=loss_value,
                        pos_dists=pos_dists,
                        neg_dists=neg_dists)
        if current_test_batch % log_interval == 0:
            print_speed(current_test_batch, batch_time, total_test_batch,
                        "global")
            logger.info(
                "\n current global average information:\n batch_time {0:.5f} | triplet_loss: {1:.5f} | pos_dists: {2:.5f} | neg_dists: {3:.5f} \n"
                .format(avg_test.time.avg, avg_test.triplet_loss.avg,
                        avg_test.pos_dists.avg, avg_test.neg_dists.avg))
    else:
        writer.add_scalar("Validate/Loss/train",
                          avg_test.triplet_loss.avg,
                          global_step=epoch)
        writer.add_scalar("Validate/Other/pos_dists",
                          avg_test.pos_dists.avg,
                          global_step=epoch)
        writer.add_scalar("Validate/Other/neg_dists",
                          avg_test.neg_dists.avg,
                          global_step=epoch)

    return avg_test.triplet_loss.avg, avg_test.pos_dists.avg, avg_test.neg_dists.avg
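
A hedged sketch of how this validation routine might be called; the loss, writer path, and dataloader are placeholders, only the argument order follows the signature above.

import torch
from torch.utils.tensorboard import SummaryWriter

# Placeholder setup: model and test_dataloader come from the surrounding
# training script and are not defined here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
writer = SummaryWriter('runs/triplet_val')          # illustrative log directory
criterion = torch.nn.TripletMarginLoss(margin=1.0)  # assumed triplet loss

val_loss, pos_d, neg_d = validation(epoch=0,
                                    log_interval=10,
                                    test_dataloader=test_dataloader,
                                    model=model,
                                    loss=criterion,
                                    writer=writer,
                                    device=device)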
Example #5
def train(train_loader, model, optimizer, lr_scheduler, epoch, cfg):
    global tb_index, best_acc, cur_lr, logger
    cur_lr = lr_scheduler.get_cur_lr()
    logger = logging.getLogger('global')
    avg = AverageMeter()
    model.train()
    # model.module.features.eval()
    # model.module.rpn_model.eval()
    # model.module.features.apply(BNtoFixed)
    # model.module.rpn_model.apply(BNtoFixed)
    #
    # model.module.mask_model.train()
    # model.module.refine_model.train()
    model = model.cuda()
    end = time.time()

    def is_valid_number(x):
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    num_per_epoch = len(train_loader.dataset) // args.epochs // args.batch
    start_epoch = epoch
    epoch = epoch
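    # NOTE: the forward passes below run under torch.no_grad(), so no autograd
    # graph is recorded and the later loss.backward() call would fail at
    # runtime; this variant looks intended for visualizing the first ~100
    # batches rather than for actual training.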
    with torch.no_grad():
        for iter, input in enumerate(train_loader):
            if iter > 100:
                break

            if epoch != iter // num_per_epoch + start_epoch:  # next epoch
                epoch = iter // num_per_epoch + start_epoch

                if epoch == args.epochs:
                    return

                if model.module.features.unfix(epoch / args.epochs):
                    logger.info('unfix part model.')
                    optimizer, lr_scheduler = build_opt_lr(
                        model.module, cfg, args, epoch)

                lr_scheduler.step(epoch)
                cur_lr = lr_scheduler.get_cur_lr()

                logger.info('epoch:{}'.format(epoch))

            tb_index = iter
            if iter % num_per_epoch == 0 and iter != 0:
                for idx, pg in enumerate(optimizer.param_groups):
                    logger.info("epoch {} lr {}".format(epoch, pg['lr']))
                    tb_writer.add_scalar('lr/group%d' % (idx + 1), pg['lr'],
                                         tb_index)

            data_time = time.time() - end
            avg.update(data_time=data_time)
            x_rpn = {
                'cfg': cfg,
                'template': torch.autograd.Variable(input[0]).cuda(),
                'search': torch.autograd.Variable(input[1]).cuda(),
                'label_cls': torch.autograd.Variable(input[2]).cuda(),
                'label_loc': torch.autograd.Variable(input[3]).cuda(),
                'label_loc_weight': torch.autograd.Variable(input[4]).cuda(),
                'label_mask': torch.autograd.Variable(input[6]).cuda()
            }
            x_kp = input[7]
            x_kp = {
                x: torch.autograd.Variable(y).cuda()
                for x, y in x_kp.items()
            }
            x_rpn['anchors'] = train_loader.dataset.anchors.all_anchors[0]

            outputs = model(x_rpn, x_kp)
            roi_box = outputs['predict'][-1]
            pred_kp = outputs['predict'][2]['hm_hp']
            batch_img = x_rpn['search'].expand(x_kp['hm_hp'].size(0), -1, -1,
                                               -1)
            gt_img, pred_img = save_gt_pred_heatmaps(
                batch_img, x_kp['hm_hp'], pred_kp,
                'test_imgs/test_{}.jpg'.format(iter))
            # rpn_pred_cls, rpn_pred_loc = outputs['predict'][:2]
            # rpn_pred_cls = outputs['predict'][-1]
            # anchors = train_loader.dataset.anchors.all_anchors[0]
            #
            # normalized_boxes = proposal_layer([rpn_pred_cls, rpn_pred_loc], anchors, config=cfg)
            # print('rpn_pred_cls: ', rpn_pred_cls.shape)

            rpn_cls_loss, rpn_loc_loss, kp_losses = torch.mean(outputs['losses'][0]),\
                                                        torch.mean(outputs['losses'][1]),\
                                                        outputs['losses'][3]
            kp_loss = torch.mean(kp_losses['loss'])
            kp_hp_loss = torch.mean(kp_losses['hp_loss'])
            kp_hm_hp_loss = torch.mean(kp_losses['hm_hp_loss'])
            kp_hp_offset_loss = torch.mean(kp_losses['hp_offset_loss'])

            # mask_iou_mean, mask_iou_at_5, mask_iou_at_7 = torch.mean(outputs['accuracy'][0]), torch.mean(outputs['accuracy'][1]), torch.mean(outputs['accuracy'][2])

            cls_weight, reg_weight, kp_weight = cfg['loss']['weight']

            loss = rpn_cls_loss * cls_weight + rpn_loc_loss * reg_weight + kp_loss * kp_weight

            optimizer.zero_grad()
            loss.backward()

            if cfg['clip']['split']:
                torch.nn.utils.clip_grad_norm_(
                    model.module.features.parameters(), cfg['clip']['feature'])
                torch.nn.utils.clip_grad_norm_(
                    model.module.rpn_model.parameters(), cfg['clip']['rpn'])
                torch.nn.utils.clip_grad_norm_(
                    model.module.mask_model.parameters(), cfg['clip']['mask'])
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip)  # gradient clip

            if is_valid_number(loss.item()):
                optimizer.step()

            siammask_loss = loss.item()

            batch_time = time.time() - end

            avg.update(batch_time=batch_time,
                       rpn_cls_loss=rpn_cls_loss,
                       rpn_loc_loss=rpn_loc_loss,
                       kp_hp_loss=kp_hp_loss,
                       kp_hm_hp_loss=kp_hm_hp_loss,
                       kp_hp_offset_loss=kp_hp_offset_loss,
                       kp_loss=kp_loss,
                       siammask_loss=siammask_loss)
            # mask_iou_mean=mask_iou_mean, mask_iou_at_5=mask_iou_at_5, mask_iou_at_7=mask_iou_at_7)

            tb_writer.add_scalar('loss/cls', rpn_cls_loss, tb_index)
            tb_writer.add_scalar('loss/loc', rpn_loc_loss, tb_index)
            tb_writer.add_scalar('loss/kp_hp_loss', kp_hp_loss, tb_index)
            tb_writer.add_scalar('loss/kp_hm_hp_loss', kp_hm_hp_loss, tb_index)
            tb_writer.add_scalar('loss/kp_hp_offset_loss', kp_hp_offset_loss,
                                 tb_index)
            # tb_writer.add_scalar('loss/kp', kp_loss, tb_index)
            end = time.time()

            if (iter + 1) % args.print_freq == 0:
                logger.info(
                    'Epoch: [{0}][{1}/{2}] lr: {lr:.6f}\t{batch_time:s}\t{data_time:s}'
                    '\t{rpn_cls_loss:s}\t{rpn_loc_loss:s}'
                    '\t{kp_hp_loss:s}\t{kp_hm_hp_loss:s}\t{kp_hp_offset_loss:s}'
                    '\t{kp_loss:s}\t{siammask_loss:s}'.format(
                        epoch + 1,
                        (iter + 1) % num_per_epoch,
                        num_per_epoch,
                        lr=cur_lr,
                        batch_time=avg.batch_time,
                        data_time=avg.data_time,
                        rpn_cls_loss=avg.rpn_cls_loss,
                        rpn_loc_loss=avg.rpn_loc_loss,
                        kp_hp_loss=avg.kp_hp_loss,
                        kp_hm_hp_loss=avg.kp_hm_hp_loss,
                        kp_hp_offset_loss=avg.kp_hp_offset_loss,
                        kp_loss=avg.kp_loss,
                        siammask_loss=avg.siammask_loss,
                    ))
                # mask_iou_mean=avg.mask_iou_mean,
                # mask_iou_at_5=avg.mask_iou_at_5,mask_iou_at_7=avg.mask_iou_at_7))
                print_speed(iter + 1, avg.batch_time.avg,
                            args.epochs * num_per_epoch)
Example #6
num_frame = cfg.model['input_num']
# print freq
print_freq = cfg.train['print_freq']

# initialize the logger
global_logger = init_log('global', level=logging.INFO)
add_file_handler("global",
                 os.path.join(os.getcwd(), 'logs',
                              '{}.log'.format(experiment_name)),
                 level=logging.DEBUG)

# log the cfg information
cfg.log_dict()

# initialize the averager
avg = AverageMeter()

# cuda
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

# prepare the dataset
train_set = MovingMNIST(root='./data/mnist',
                        train=True,
                        download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                        ]),
                        target_transform=transforms.Compose([
                            transforms.ToTensor(),
Example #7
def train(train_loader, model, optimizer, lr_scheduler, epoch, cfg):
    global tb_index, best_acc, cur_lr, logger
    cur_lr = lr_scheduler.get_cur_lr()
    logger = logging.getLogger('global')
    avg = AverageMeter()
    model.train()
    model = model.cuda()
    end = time.time()

    def is_valid_number(x):
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    num_per_epoch = len(train_loader.dataset) // args.epochs // args.batch
    start_epoch = epoch
    epoch = epoch
    for iter, input in enumerate(train_loader):

        if epoch != iter // num_per_epoch + start_epoch:  # next epoch
            epoch = iter // num_per_epoch + start_epoch

            if not os.path.exists(args.save_dir):  # makedir/save model
                os.makedirs(args.save_dir)

            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.module.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'anchor_cfg': cfg['anchors']
                }, False,
                os.path.join(args.save_dir, 'checkpoint_e%d.pth' % (epoch)),
                os.path.join(args.save_dir, 'best.pth'))

            if epoch == args.epochs:
                return

            if model.module.features.unfix(epoch / args.epochs):
                logger.info('unfix part model.')
                optimizer, lr_scheduler = build_opt_lr(model.module, cfg, args,
                                                       epoch)

            lr_scheduler.step(epoch)
            cur_lr = lr_scheduler.get_cur_lr()

            logger.info('epoch:{}'.format(epoch))

        tb_index = iter
        if iter % num_per_epoch == 0 and iter != 0:
            for idx, pg in enumerate(optimizer.param_groups):
                logger.info("epoch {} lr {}".format(epoch, pg['lr']))
                tb_writer.add_scalar('lr/group%d' % (idx + 1), pg['lr'],
                                     tb_index)

        data_time = time.time() - end
        avg.update(data_time=data_time)
        x = {
            'cfg': cfg,
            'template': torch.autograd.Variable(input[0]).cuda(),
            'search': torch.autograd.Variable(input[1]).cuda(),
            'label_cls': torch.autograd.Variable(input[2]).cuda(),
            'label_loc': torch.autograd.Variable(input[3]).cuda(),
            'label_loc_weight': torch.autograd.Variable(input[4]).cuda(),
            'label_mask': torch.autograd.Variable(input[6]).cuda(),
            'label_kp_weight': torch.autograd.Variable(input[7]).cuda(),
            'label_mask_weight': torch.autograd.Variable(input[8]).cuda(),
        }

        outputs = model(x)
        # print(x['search'].shape)
        pred_mask = outputs['predict'][2]
        pred_mask = select_pred_heatmap(
            pred_mask,
            x['label_mask_weight'])  # rpn_pred_mask, shape (bs, 17, 127, 127)

        true_search = select_gt_img(x['search'], x['label_mask_weight'])
        if true_search.shape:
            save_batch_heatmaps(true_search,
                                pred_mask,
                                vis_outpath + '{}.jpg'.format(iter),
                                normalize=True)

        # pred_mask = pred_mask.cpu().detach().numpy()
        # true_search = true_search.cpu().detach().numpy()

        # print("pose_mask", pred_mask.shape)
        # pose_heat = np.transpose(pred_mask[0,:,:,:],(1,2,0))   #shape (127,127,17)

        # plt.figure(num='image', figsize=(128,128))
        #
        # plt.subplot(1, 2, 1)
        # plt.title('origin image')
        # plt.imshow(np.transpose(true_search[0,:,:,:], (1,2,0)))
        #
        # plt.subplot(1, 2, 2)
        # plt.title('heatmap')
        # pose_map = np.zeros((127,127), np.float32)
        # for i in range(pred_mask.shape[1]):
        #     pose_map += pose_heat[:,:,i]
        # plt.imshow(pose_map)
        # plt.axis('off')
        #
        #
        # plt.show()

        # visualization: project all 17 keypoint maps onto a single black image

        rpn_cls_loss, rpn_loc_loss, rpn_mask_loss = torch.mean(outputs['losses'][0]),\
                                                    torch.mean(outputs['losses'][1]),\
                                                    torch.mean(outputs['losses'][2])

        # mask_iou_mean, mask_iou_at_5, mask_iou_at_7 = torch.mean(outputs['accuracy'][0]), torch.mean(outputs['accuracy'][1]), torch.mean(outputs['accuracy'][2])

        cls_weight, reg_weight, mask_weight = cfg['loss']['weight']

        loss = rpn_cls_loss * cls_weight + rpn_loc_loss * reg_weight + rpn_mask_loss * mask_weight

        optimizer.zero_grad()
        loss.backward()

        if cfg['clip']['split']:
            torch.nn.utils.clip_grad_norm_(model.module.features.parameters(),
                                           cfg['clip']['feature'])
            torch.nn.utils.clip_grad_norm_(model.module.rpn_model.parameters(),
                                           cfg['clip']['rpn'])
            torch.nn.utils.clip_grad_norm_(
                model.module.mask_model.parameters(), cfg['clip']['mask'])
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.clip)  # gradient clip

        if is_valid_number(loss.item()):
            optimizer.step()

        siammask_loss = loss.item()

        batch_time = time.time() - end

        avg.update(batch_time=batch_time,
                   rpn_cls_loss=rpn_cls_loss,
                   rpn_loc_loss=rpn_loc_loss,
                   rpn_mask_loss=rpn_mask_loss * mask_weight,
                   siammask_loss=siammask_loss)
        # mask_iou_mean=mask_iou_mean, mask_iou_at_5=mask_iou_at_5, mask_iou_at_7=mask_iou_at_7)

        tb_writer.add_scalar('loss/cls', rpn_cls_loss, tb_index)
        tb_writer.add_scalar('loss/loc', rpn_loc_loss, tb_index)
        tb_writer.add_scalar('loss/mask', rpn_mask_loss * mask_weight,
                             tb_index)
        # tb_writer.add_scalar('mask/mIoU', mask_iou_mean, tb_index)
        # tb_writer.add_scalar('mask/mIoU@0.5', mask_iou_at_5, tb_index)
        # tb_writer.add_scalar('mask/mIoU@0.7', mask_iou_at_7, tb_index)
        end = time.time()

        if (iter + 1) % args.print_freq == 0:
            logger.info(
                'Epoch: [{0}][{1}/{2}] lr: {lr:.6f}\t{batch_time:s}\t{data_time:s}'
                '\t{rpn_cls_loss:s}\t{rpn_loc_loss:s}\t{rpn_mask_loss:s}\t{siammask_loss:s}'
                .format(
                    epoch + 1,
                    (iter + 1) % num_per_epoch,
                    num_per_epoch,
                    lr=cur_lr,
                    batch_time=avg.batch_time,
                    data_time=avg.data_time,
                    rpn_cls_loss=avg.rpn_cls_loss,
                    rpn_loc_loss=avg.rpn_loc_loss,
                    rpn_mask_loss=avg.rpn_mask_loss,
                    siammask_loss=avg.siammask_loss,
                ))
            # mask_iou_mean=avg.mask_iou_mean,
            # mask_iou_at_5=avg.mask_iou_at_5,mask_iou_at_7=avg.mask_iou_at_7))
            print_speed(iter + 1, avg.batch_time.avg,
                        args.epochs * num_per_epoch)
Example #8
    # Start Training loop
    end_epoch = start_epoch + train_epochs
    logger.info(
        "\nTraining using triplet loss starting for {} epochs, current epoch: {}, target epoch: {};\n"
        .format(train_epochs, start_epoch, end_epoch))

    # for progress printing
    total_batch = train_epochs * len(train_dataloader)
    current_batch = 0
    batch_time = 0

    for epoch in range(start_epoch, end_epoch):
        # init avg meter
        # avg.update(time=1.1, accuracy=.99)
        avg = AverageMeter()

        # start training epoch
        logger.info(
            "\n------------------------- Start Training {} Epoch -------------------------\n"
            .format(epoch + 1))
        for batch_idx, batch_sample in enumerate(train_dataloader):
            # Skip last iteration to avoid the problem of having different number of tensors while calculating
            # averages (sizes of tensors must be the same for pairwise distance calculation)
            if batch_idx + 1 == len(train_dataloader):
                continue

            # switch to train mode
            for param in model.parameters():
                param.requires_grad = True
            model.train()
Example #9
def test_model(model, test_dataloader, log_interval, device):
    """Test and Return the feature vector of all sample in dataset with its index.

    Args:
        cfg: (dict) 
            config file of the test precedure.
        model: (nn.module)
            loaded model
        test_dataloader: (torch.Dataloader)
            It should not be none! A non-triplet dataloader to validate data.
            It's sample protocal is:
                {
                    "img": target image,
                    "cls": target class, 
                    "other": other information,
                        {
                            "index" : index,
                        }
                }
        writer: (tensorboard writer)
        device: cuda or cpu

    Return:
        a list of dict:[
            {
                "cls": class label of the sample,
                "feature": feature vectuer of the result,
                "other": other information,
                {
                    "index": index of the sample in the dataset,
                }
            },
            ...,
            {
                "cls": class label of the sample,
                "feature": feature vectuer of the result,
                "other": other information,
                {
                    "index": index of the sample in the dataset,
                }
            }] 
    """
    logger.info("\n------------------------- Start Forwarding Dataset -------------------------\n")
    
    # epoch average meter
    avg_test = AverageMeter()

    # get test batch count
    current_test_batch = 0
    total_test_batch = len(test_dataloader)

    # to return list
    out_sample_list = []
    for batch_idx, batch_sample in enumerate(test_dataloader):
        # Skip last iteration to avoid the problem of having different number of tensors while calculating
        # averages (sizes of tensors must be the same for pairwise distance calculation)
        if batch_idx + 1 == len(test_dataloader):
            continue
        batch_size = test_dataloader.batch_size

        # switch to evaluation mode.
        for param in model.parameters():
            param.requires_grad = False
        model.eval()

        # start time counting
        batch_start_time_test = time.time()

        # Forward pass - compute embeddings
        imgs = batch_sample["img"]
        cls = batch_sample["cls"]
        indexs = batch_sample["other"]["index"]

        imgs = imgs.to(device)

        out_put = model(imgs)

        out_put.to("cpu")

        for i in range(batch_size):
            out_dict = {
                "cls": cls[i],
                "feature": out_put[i],
                "other": {
                    "index" : indexs[i]
                    },
            }
            out_sample_list.append(out_dict)
        
        # batch time & batch count
        current_test_batch += 1
        batch_time = time.time() - batch_start_time_test

        if current_test_batch % log_interval == 0:
            print_speed(current_test_batch, batch_time, total_test_batch, "global")
    else:
        logger.info("\n------------------------- End Forwarding Dataset -------------------------\n")

    return out_sample_list
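
The list of dicts returned above can be collapsed into a single feature matrix for retrieval. A minimal sketch follows, assuming every 'feature' entry is a torch tensor of the same size; the helper name is hypothetical.

import torch

def build_feature_matrix(sample_list):
    # Hypothetical post-processing of test_model's output: stack per-sample
    # features for nearest-neighbour lookup and keep labels/indices aligned.
    feats = torch.stack([s['feature'].flatten() for s in sample_list])
    labels = [s['cls'] for s in sample_list]
    indices = [s['other']['index'] for s in sample_list]
    return feats, labels, indices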
Example #10
cfg = Configs(args.cfg)
# path for the tensorboard logs
board_path = cfg.meta["board_path"]
experiment_path = cfg.meta["experiment_path"]
arch = cfg.meta["arch"]
# training hyperparameters
batch_size = cfg.train['batch_size']
epoches = cfg.train['epoches']
lr = cfg.train['lr']
# initialize the number of future frames
input_num = cfg.model['input_num']
# print freq
print_freq = cfg.train['print_freq']

# initialize the averager
avg = AverageMeter()

# cuda
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

# prepare the datasets
train_set = MovingMNIST(root='./data/mnist', train=True, download=True,
                        transform=transforms.Compose([transforms.ToTensor(),]),
                        target_transform=transforms.Compose([transforms.ToTensor(),]))
test_set = MovingMNIST(root='./data/mnist', train=False, download=True,
                        transform=transforms.Compose([transforms.ToTensor(),]),
                        target_transform=transforms.Compose([transforms.ToTensor(),]))

# build the dataloaders
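
The snippet above stops at the dataloader comment. Below is a minimal sketch of how the loaders might be built from the two MovingMNIST datasets, assuming standard torch.utils.data.DataLoader usage; the shuffle and num_workers values are illustrative, batch_size is the value read from cfg.train above.

from torch.utils.data import DataLoader

# Hedged continuation, not from the source: wrap the datasets in loaders.
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)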
Example #11
        optimizer.load_state_dict(checkpoint['optimizer'])
        model.load_state_dict(checkpoint['model'])
        lr = optimizer.param_groups[0]['lr']
        print("loaded checkpoint {}".format(load_name))
    model.to(device)

    # loss
    trip_loss = TripletLoss(margin=1)
    for epoch in range(args.start_epoch, args.max_epochs):
        # train
        model.train()
        start = time.time()
        if epoch % (args.lr_decay_epoch + 1) == 0:
            adjust_learning_rate(optimizer, args.lr_decay_gamma)
            lr *= args.lr_decay_gamma
        avg = AverageMeter()
        print('Training ...')
        for step, value in tqdm(enumerate(train_dataloader)):
            end_time = time.time()
            batch_loss = 0
            for i in range(len(value[0])):
                visual_graph = value[0][i].to(device)
                points_graph = value[1][i].to(device)
                visual_feats = model(visual_graph)
                points_feats = model(points_graph)
                global_feats = torch.cat((visual_feats, points_feats),
                                         dim=0).contiguous()
                labels = torch.cat((visual_graph.y, points_graph.y),
                                   dim=0).contiguous()
                iloss, _, _, _, _, _ = global_loss(trip_loss, global_feats,
                                                   labels)