Beispiel #1
0
    def validate(self):
        """
        Evaluate the current model on the validation split.

        Accumulates a length-normalized cross-entropy loss and collects
        per-pixel predictions/ground truths to compute segmentation metrics,
        then always writes a checkpoint and additionally copies it to
        ``model_best.pth.tar`` when the mean IoU improved.

        Reads: self.model, self.val_loader, self.cuda, self.size_average,
        self.n_class, self.iteration, self.epoch, self.best_mean_iu, self.optim.
        Writes: self.best_mean_iu; checkpoint files in the logger directory.
        """
        logger.info("start validation....")
        val_loss = 0
        label_trues, label_preds = [], []

        # Evaluation
        for batch_idx, (data, target) in tqdm.tqdm(
                enumerate(self.val_loader),
                total=len(self.val_loader),
                desc='Validation iteration = {},epoch={}'.format(
                    self.iteration, self.epoch),
                leave=False):

            if self.cuda:
                data, target = data.cuda(), target.cuda()
            # Pre-0.4 PyTorch inference idiom: volatile=True disables autograd
            # history (the modern equivalent would be torch.no_grad()).
            data, target = Variable(data, volatile=True), Variable(target)

            score = self.model(data)

            loss = CrossEntropyLoss2d_Seg(score,
                                          target,
                                          size_average=self.size_average)

            # loss.data[0] is the pre-0.4 scalar access (later: loss.item()).
            # Abort early if the loss diverged.
            if np.isnan(float(loss.data[0])):
                raise ValueError('loss is nan while validating')
            # Normalize the batch loss by the batch size before accumulating.
            val_loss += float(loss.data[0]) / len(data)

            # argmax over the class channel -> per-pixel predicted labels.
            lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
            lbl_true = target.data.cpu().numpy()

            label_trues.append(lbl_true)
            label_preds.append(lbl_pred)

        # Computing the metrics
        acc, acc_cls, mean_iu, _ = torchfcn.utils.label_accuracy_score(
            label_trues, label_preds, self.n_class)
        val_loss /= len(self.val_loader)

        logger.info("iteration={},epoch={},validation mIoU = {}".format(
            self.iteration, self.epoch, mean_iu))

        # Track the best mean IoU seen so far; the checkpoint is written every
        # validation, the best one is kept as a separate copy below.
        is_best = mean_iu > self.best_mean_iu
        if is_best:
            self.best_mean_iu = mean_iu
        torch.save(
            {
                'epoch': self.epoch,
                'iteration': self.iteration,
                'arch': self.model.__class__.__name__,
                'optim_state_dict': self.optim.state_dict(),
                'model_state_dict': self.model.state_dict(),
                'best_mean_iu': self.best_mean_iu,
            }, osp.join(logger.get_logger_dir(), 'checkpoint.pth.tar'))
        if is_best:
            shutil.copy(
                osp.join(logger.get_logger_dir(), 'checkpoint.pth.tar'),
                osp.join(logger.get_logger_dir(), 'model_best.pth.tar'))
Beispiel #2
0
def do_test(logger, tdcnn_demo, dataloader_test):
    """Run the final evaluation of ``tdcnn_demo`` on the test split.

    Loads the best checkpoint from the logger directory, remaps its keys to
    match the DataParallel-wrapped model, evaluates via ``test_net`` and logs
    the resulting (mAP, per-class AP) tuple.
    """
    logger.info('do test')
    logger.info(args.test_nms)

    if torch.cuda.is_available():
        tdcnn_demo = tdcnn_demo.cuda()
        if isinstance(args.gpus, int):
            # Normalize a single gpu id into a one-element list.
            args.gpus = [args.gpus]
        tdcnn_demo = nn.parallel.DataParallel(tdcnn_demo, device_ids=args.gpus)

    checkpoint_path = os.path.join(logger.get_logger_dir(), "best_model.pth")
    state_dict = torch.load(checkpoint_path)['model']
    logger.info("best_model.pth loaded!")

    from collections import OrderedDict

    def remap(key):
        # Prefix with 'module.' for the DataParallel wrapper, and rename the
        # fusion submodule to its current name (no-op when absent).
        if 'module' not in key:
            key = 'module.{}'.format(key)
        return key.replace('modules_focal', '_fusion_modules')

    new_state_dict = OrderedDict((remap(k), v) for k, v in state_dict.items())
    tdcnn_demo.load_state_dict(new_state_dict)
    tdcnn_demo.eval()

    test_mAP, test_ap = test_net(tdcnn_demo,
                                 dataloader=dataloader_test,
                                 args=args,
                                 split='test',
                                 max_per_video=args.max_per_video,
                                 thresh=args.thresh)
    tdcnn_demo.train()
    logger.info("final test set result: {}".format((test_mAP, test_ap)))
    logger.info("Congrats~")
Beispiel #3
0
def do_python_eval(use_07=True):
    """Run PASCAL-VOC-style detection evaluation over every class.

    For each class in ``labelmap``: reads the per-class detection results
    file, evaluates it with ``voc_eval`` (IoU threshold 0.5), pickles the
    precision/recall curve into the logger directory, and logs the per-class
    AP.  Finally logs the mean AP over all classes.

    Args:
        use_07: if True, use the 11-point VOC2007 AP metric; otherwise the
            (post-2010) continuous-integration metric.
    """
    cachedir = os.path.join(Sim_ROOT, 'annotations_cache')
    aps = []
    # The PASCAL VOC metric changed in 2010
    use_07_metric = use_07
    log.l.info('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))

    # fix: original iterated `for i, cls in enumerate(labelmap)` but never
    # used the index `i`.
    for cls in labelmap:
        filename = get_voc_results_file_template(set_type, cls)
        rec, prec, ap = voc_eval(filename,
                                 annopath,
                                 cls,
                                 cachedir,
                                 ovthresh=0.5,
                                 use_07_metric=use_07_metric)
        aps.append(ap)  # fix: was `aps += [ap]`
        log.l.info('AP for {} = {:.4f}'.format(cls, ap))
        # Persist the PR curve so it can be re-plotted without re-evaluating.
        with open(os.path.join(logger.get_logger_dir(), cls + '_pr.pkl'),
                  'wb') as f:
            pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
    log.l.info('Mean AP = {:.4f}'.format(np.mean(aps)))
    log.l.info('~~~~~~~~')
    log.l.info('Results:')
    for ap in aps:
        log.l.info('{:.3f}'.format(ap))
    log.l.info('{:.3f}'.format(np.mean(aps)))
    log.l.info('~~~~~~~~')
    log.l.info('')
    log.l.info(
        '--------------------------------------------------------------')
    log.l.info('Results computed with the **unofficial** Python eval code.')
    log.l.info(
        'Results should be very close to the official MATLAB eval code.')
    log.l.info(
        '--------------------------------------------------------------')
Beispiel #4
0
def test_net(net, dataset):
    """Test a Fast R-CNN network on an image database.

    Runs the detector over every image in ``dataset``, converts the network's
    normalized box coordinates back to pixel coordinates, collects per-class
    detections into ``all_boxes``, pickles them, and finally calls
    ``evaluate_detections``.
    """
    num_images = len(dataset)
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(len(labelmap) + 1)]  #one background

    # timers
    _t = {'im_detect': Timer(), 'misc': Timer()}

    for i in tqdm(range(num_images), total=num_images):
        # Debug mode: only run the first handful of images.
        if i > 10 and is_debug == 1: break

        im, gt, h, w = dataset.pull_item(i)

        x = Variable(im.unsqueeze(0))
        if args.cuda:
            x = x.cuda()
        _t['im_detect'].tic()
        detections = net(x).data
        detect_time = _t['im_detect'].toc(average=False)

        # skip j = 0, because it's the background class
        for j in range(1, detections.size(1)):
            # dets rows are (score, x1, y1, x2, y2) in normalized coords —
            # assumed from the indexing below; TODO confirm against the model.
            dets = detections[0, j, :]
            # Keep only rows with positive score; the mask is broadcast to
            # all 5 columns so masked_select can rebuild an (N, 5) tensor.
            mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t()
            dets = torch.masked_select(dets, mask).view(-1, 5)
            if dets.dim() == 0:
                continue
            boxes = dets[:, 1:]
            # Scale normalized coordinates back to the original image size.
            boxes[:, 0] *= w
            boxes[:, 2] *= w
            boxes[:, 1] *= h
            boxes[:, 3] *= h
            scores = dets[:, 0].cpu().numpy()
            cls_dets = np.hstack((boxes.cpu().numpy(), scores[:, np.newaxis])) \
                .astype(np.float32, copy=False)
            all_boxes[j][i] = cls_dets

        log.l.info('im_detect: {:d}/{:d} {:.3f}s'.format(
            i + 1, num_images, detect_time))

    # Persist raw detections so evaluation can be re-run offline.
    with open(os.path.join(logger.get_logger_dir(), 'detections.pkl'),
              'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    log.l.info('Evaluating detections')
    evaluate_detections(all_boxes, dataset)
Beispiel #5
0
def train(args):
    """Train VGG16-LargeFoV (DeepLab v1) for semantic segmentation.

    Builds train/val dataloaders with augmentation, trains with SGD under a
    polynomial learning-rate schedule, evaluates mean IoU on the validation
    split after every epoch, and keeps the best checkpoint (by mIoU) in the
    logger directory.

    Args:
        args: parsed CLI namespace (dataset, img_rows/img_cols, img_norm,
            batch_size, l_rate, n_epoch, resume, ...).
    """
    logger.auto_set_dir()
    # NOTE(review): GPU id is hard-coded; consider promoting it to a CLI flag.
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           img_size=(args.img_rows, args.img_cols),
                           epoch_scale=4,
                           augmentations=data_aug,
                           img_norm=args.img_norm)
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           img_size=(args.img_rows, args.img_cols),
                           img_norm=args.img_norm)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup Model
    from model_zoo.deeplabv1 import VGG16_LargeFoV
    model = VGG16_LargeFoV(class_num=n_classes,
                           image_size=[args.img_cols, args.img_rows],
                           pretrained=True)

    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Check if model has custom optimizer / loss.
    # BUGFIX: the original branch logged the inverted message ("don't have
    # customzed optimizer, use default setting!") and read
    # `model.module.optimizer`, which raises AttributeError because the
    # DataParallel wrap above is commented out.
    if hasattr(model, 'optimizer'):
        logger.warn("model provides a customized optimizer, using it!")
        optimizer = model.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    optimizer_summary(optimizer)
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            logger.info("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("No checkpoint found at '{}'".format(args.resume))

    best_iou = -100.0
    for epoch in tqdm(range(args.n_epoch), total=args.n_epoch):
        model.train()
        for i, (images, labels) in tqdm(enumerate(trainloader),
                                        total=len(trainloader),
                                        desc="training epoch {}/{}".format(
                                            epoch, args.n_epoch)):
            # Global iteration index drives the polynomial LR decay.
            cur_iter = i + epoch * len(trainloader)
            cur_lr = adjust_learning_rate(optimizer,
                                          args.l_rate,
                                          cur_iter,
                                          args.n_epoch * len(trainloader),
                                          power=0.9)

            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images)
            loss = CrossEntropyLoss2d_Seg(input=outputs,
                                          target=labels,
                                          class_num=n_classes)

            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                # loss.data[0] is the pre-0.4 PyTorch scalar access.
                logger.info("Epoch [%d/%d] Loss: %.4f, lr: %.7f" %
                            (epoch + 1, args.n_epoch, loss.data[0], cur_lr))

        # Per-epoch validation pass (no gradients needed).
        model.eval()
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader),
                                                    total=len(valloader),
                                                    desc="validation"):
            images_val = Variable(images_val.cuda(), volatile=True)
            labels_val = Variable(labels_val.cuda(), volatile=True)

            outputs = model(images_val)
            pred = outputs.data.max(1)[1].cpu().numpy()
            gt = labels_val.data.cpu().numpy()
            running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            logger.info("{}: {}".format(k, v))
        running_metrics.reset()

        # Checkpoint whenever the validation mean IoU matches or beats the
        # best seen so far (key string is defined by runningScore).
        if score['Mean IoU : \t'] >= best_iou:
            best_iou = score['Mean IoU : \t']
            state = {
                'epoch': epoch + 1,
                'mIoU': best_iou,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state,
                       os.path.join(logger.get_logger_dir(), "best_model.pkl"))
Beispiel #6
0
                  'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f}), best_knn_acc={best_knn_acc}'.format(
                epoch, args.epochs, batch_idx, len(trainloader), batch_time=batch_time, data_time=data_time, train_loss=train_loss,best_knn_acc=best_acc))
            wandb_logging(
                d=dict(loss1e4=loss.item() * 1e4, group0_lr=optimizer.state_dict()['param_groups'][0]['lr']),
                step=pytorchgo_args.get_args().step,
                use_wandb=pytorchgo_args.get_args().wandb,
                prefix="training epoch {}/{}: ".format(epoch, pytorchgo_args.get_args().epochs))
    return selflabels


# Top-level training driver: alternate self-label training epochs with a
# weighted-kNN evaluation of the learned features.
pytorchgo_args.get_args().step = 0
for epoch in range(start_epoch, start_epoch + args.epochs):
    # Debug mode: stop after a few epochs.
    if args.debug and epoch>=3:break
    selflabels = train(epoch, selflabels)
    # Switch the model to return features (not logits) for the kNN probe.
    feature_return_switch(model, True)
    logger.warning(logger.get_logger_dir())
    logger.warning("doing KNN evaluation.")
    acc = kNN(model, trainloader, testloader, K=10, sigma=0.1, dim=knn_dim)
    logger.warning("finish KNN evaluation.")
    feature_return_switch(model, False)
    if acc > best_acc:
        logger.info('get better result, saving..')
        # Snapshot model/optimizer plus the current self-labels.
        state = {
            'net': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'opt': optimizer.state_dict(),
            'L': selflabels,
        }
        if not os.path.isdir(args.exp):
            os.mkdir(args.exp)
        # NOTE(review): `state` is assembled and the output directory created,
        # but no save call and no `best_acc = acc` update appear in this block
        # — confirm the tail of this loop wasn't lost.
def main():
    """Create the model and start adversarial domain-adaptation training.

    Trains a segmentation network (DeepLab v2 or FCN8s) on a source domain
    (GTA5 or SYNTHIA) while two FC discriminators (only D2 is actually
    optimized in the visible loop) push target-domain (Cityscapes)
    predictions to look source-like.  Periodically evaluates on the target
    split and keeps the best checkpoint by mIoU.
    """
    # NOTE(review): args.input_size is unpacked as h, w — confirm the CLI
    # string really is "height,width"; the interp sizes below swap the order.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True
    from pytorchgo.utils.pytorch_utils import set_gpu
    set_gpu(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        logger.info("adopting Deeplabv2 base model..")
        model = Res_Deeplab(num_classes=args.num_classes, multi_scale=False)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy pretrained weights, dropping the first name component and
        # skipping the classifier (layer5) when the class count differs.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

        optimizer = optim.SGD(model.optim_parameters(args),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.model == "FCN8S":
        logger.info("adopting FCN8S base model..")
        from pytorchgo.model.MyFCN8s import MyFCN8s
        model = MyFCN8s(n_class=NUM_CLASSES)
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    else:
        raise ValueError

    model.train()
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda()

    model_D2.train()
    model_D2.cuda()

    if SOURCE_DATA == "GTA5":
        trainloader = data.DataLoader(GTA5DataSet(
            args.data_dir,
            args.data_list,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size,
            scale=args.random_scale,
            mirror=args.random_mirror,
            mean=IMG_MEAN),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_iter = enumerate(trainloader)
    elif SOURCE_DATA == "SYNTHIA":
        trainloader = data.DataLoader(SynthiaDataSet(
            args.data_dir,
            args.data_list,
            LABEL_LIST_PATH,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size,
            scale=args.random_scale,
            mirror=args.random_mirror,
            mean=IMG_MEAN),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_iter = enumerate(trainloader)
    else:
        raise ValueError

    targetloader = data.DataLoader(cityscapesDataSet(
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # Upsample network outputs back to the input resolutions.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    best_mIoU = 0

    model_summary([model, model_D1, model_D2])
    optimizer_summary([optimizer, optimizer_D1, optimizer_D2])

    for i_iter in tqdm(range(args.num_steps_stop),
                       total=args.num_steps_stop,
                       desc="training"):

        # NOTE(review): the *_value1 accumulators are logged but never updated
        # in this loop (only the "2" branch is trained) — they stay 0.
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        lr_D1 = adjust_learning_rate_D(optimizer_D1, i_iter)
        lr_D2 = adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            ######################### train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            # BUGFIX: `trainloader_iter.next()` is Python-2-only; the builtin
            # next() is equivalent on Py2 and required on Py3.
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda()

            pred2 = model(images)
            pred2 = interp(pred2)

            loss_seg2 = loss_calc(pred2, labels)
            loss = loss_seg2

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value2 += loss_seg2.data.cpu().numpy()[0] / args.iter_size

            # train with target

            # BUGFIX: same Py2-only .next() replaced by builtin next().
            _, batch = next(targetloader_iter)
            images, _, _, _ = batch
            images = Variable(images).cuda()

            pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            D_out2 = model_D2(F.softmax(pred_target2))

            # Fool D: push target predictions toward the source label.
            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda())

            loss = args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            )[0] / args.iter_size

            ################################## train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detach so G receives no gradient)
            pred2 = pred2.detach()
            D_out2 = model_D2(F.softmax(pred2))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda())

            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D2.backward()

            loss_D_value2 += loss_D2.data.cpu().numpy()[0]

            # train with target
            pred_target2 = pred_target2.detach()

            D_out2 = model_D2(F.softmax(pred_target2))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda())

            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D2.backward()

            loss_D_value2 += loss_D2.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % 100 == 0:
            logger.info(
                'iter = {}/{},loss_seg1 = {:.3f} loss_seg2 = {:.3f} loss_adv1 = {:.3f}, loss_adv2 = {:.3f} loss_D1 = {:.3f} loss_D2 = {:.3f}, lr={:.7f}, lr_D={:.7f}, best miou16= {:.5f}'
                .format(i_iter, args.num_steps_stop, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2,
                        lr, lr_D1, best_mIoU))

        # Periodic evaluation + checkpoint; best snapshot kept separately.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            logger.info("saving snapshot.....")
            cur_miou16 = proceed_test(model, input_size)
            is_best = True if best_mIoU < cur_miou16 else False
            if is_best:
                best_mIoU = cur_miou16
            torch.save(
                {
                    'iteration': i_iter,
                    'optim_state_dict': optimizer.state_dict(),
                    'optim_D1_state_dict': optimizer_D1.state_dict(),
                    'optim_D2_state_dict': optimizer_D2.state_dict(),
                    'model_state_dict': model.state_dict(),
                    'model_D1_state_dict': model_D1.state_dict(),
                    'model_D2_state_dict': model_D2.state_dict(),
                    'best_mean_iu': cur_miou16,
                }, osp.join(logger.get_logger_dir(), 'checkpoint.pth.tar'))
            if is_best:
                import shutil
                shutil.copy(
                    osp.join(logger.get_logger_dir(), 'checkpoint.pth.tar'),
                    osp.join(logger.get_logger_dir(), 'model_best.pth.tar'))

        if i_iter >= args.num_steps_stop - 1:
            break
Beispiel #8
0
def train():
    """Train the few-shot encoder + metric network pair.

    Iterates ``args.iterations`` task batches: each step samples an episodic
    batch from the train generator, updates both networks with Adam, decays
    the learning rate, periodically evaluates one-shot accuracy on
    val/test/train splits, and checkpoints both networks.
    """
    train_loader = generator.Generator(args.dataset_root,
                                       args,
                                       partition='train',
                                       dataset=args.dataset)
    logger.info('Batch size: ' + str(args.batch_size))

    #Try to load models
    enc_nn = models.load_model('enc_nn', args)
    metric_nn = models.load_model('metric_nn', args)

    # Fall back to freshly created models when no saved pair was found.
    if enc_nn is None or metric_nn is None:
        enc_nn, metric_nn = models.create_models(args=args)
    softmax_module = models.SoftmaxModule()

    if args.cuda:
        enc_nn.cuda()
        metric_nn.cuda()

    logger.info(str(enc_nn))
    logger.info(str(metric_nn))

    # mini_imagenet is the only dataset trained with weight decay.
    weight_decay = 0
    if args.dataset == 'mini_imagenet':
        logger.info('Weight decay ' + str(1e-6))
        weight_decay = 1e-6
    opt_enc_nn = optim.Adam(enc_nn.parameters(),
                            lr=args.lr,
                            weight_decay=weight_decay)
    opt_metric_nn = optim.Adam(metric_nn.parameters(),
                               lr=args.lr,
                               weight_decay=weight_decay)

    model_summary([enc_nn, metric_nn])
    optimizer_summary([opt_enc_nn, opt_metric_nn])
    enc_nn.train()
    metric_nn.train()
    counter = 0
    total_loss = 0
    val_acc, val_acc_aux = 0, 0
    test_acc = 0
    for batch_idx in range(args.iterations):

        ####################
        # Train
        ####################
        data = train_loader.get_task_batch(
            batch_size=args.batch_size,
            n_way=args.train_N_way,
            unlabeled_extra=args.unlabeled_extra,
            num_shots=args.train_N_shots,
            cuda=args.cuda,
            variable=True)
        [
            batch_x, label_x, _, _, batches_xi, labels_yi, oracles_yi,
            hidden_labels
        ] = data

        opt_enc_nn.zero_grad()
        opt_metric_nn.zero_grad()

        loss_d_metric = train_batch(model=[enc_nn, metric_nn, softmax_module],
                                    data=[
                                        batch_x, label_x, batches_xi,
                                        labels_yi, oracles_yi, hidden_labels
                                    ])

        opt_enc_nn.step()
        opt_metric_nn.step()

        adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn],
                             lr=args.lr,
                             iter=batch_idx)

        ####################
        # Display
        ####################
        counter += 1
        # loss_d_metric.data[0] is the pre-0.4 PyTorch scalar access.
        total_loss += loss_d_metric.data[0]
        if batch_idx % args.log_interval == 0:
            display_str = 'Train Iter: {}'.format(batch_idx)
            display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss /
                                                            counter)
            logger.info(display_str)
            counter = 0
            total_loss = 0

        ####################
        # Test
        ####################
        # batch_idx == 20 triggers an early small sanity evaluation.
        if (batch_idx + 1) % args.test_interval == 0 or batch_idx == 20:
            if batch_idx == 20:
                test_samples = 100
            else:
                test_samples = 3000
            if args.dataset == 'mini_imagenet':
                val_acc_aux = test.test_one_shot(
                    args,
                    model=[enc_nn, metric_nn, softmax_module],
                    test_samples=test_samples * 5,
                    partition='val')
            test_acc_aux = test.test_one_shot(
                args,
                model=[enc_nn, metric_nn, softmax_module],
                test_samples=test_samples * 5,
                partition='test')
            test.test_one_shot(args,
                               model=[enc_nn, metric_nn, softmax_module],
                               test_samples=test_samples,
                               partition='train')
            # test_one_shot switches to eval mode; restore training mode.
            enc_nn.train()
            metric_nn.train()

            # Model selection: report the test accuracy achieved at the best
            # validation accuracy.
            if val_acc_aux is not None and val_acc_aux >= val_acc:
                test_acc = test_acc_aux
                val_acc = val_acc_aux

            if args.dataset == 'mini_imagenet':
                logger.info("Best test accuracy {:.4f} \n".format(test_acc))

        ####################
        # Save model
        ####################
        # NOTE(review): torch.save on the whole module pickles the model
        # object, not just its state_dict.
        if (batch_idx + 1) % args.save_interval == 0:
            logger.info("saving model...")
            torch.save(enc_nn,
                       os.path.join(logger.get_logger_dir(), 'enc_nn.t7'))
            torch.save(metric_nn,
                       os.path.join(logger.get_logger_dir(), 'metric_nn.t7'))

    # Test after training
    test.test_one_shot(args,
                       model=[enc_nn, metric_nn, softmax_module],
                       test_samples=args.test_samples)
Beispiel #9
0
def train():
    """Train SSD on the dataset selected by ``args.dataset`` (VOC or COCO).

    Builds the detection dataset, the SSD network and its SGD optimizer from
    the module-level ``args``, then runs the iteration-based training loop:
    logs losses every 10 iterations, optionally updates visdom plots, and
    checkpoints the weights into the logger directory every 5000 iterations
    plus once at the end.
    """
    if args.dataset == 'COCO':
        if args.dataset_root == VOC_ROOT:
            # The user asked for COCO but kept the VOC default root: fall
            # back to COCO_ROOT when it exists, otherwise require an
            # explicit --dataset_root.
            if not os.path.exists(COCO_ROOT):
                parser.error('Must specify dataset_root if specifying dataset')
            logger.info("WARNING: Using default COCO dataset_root because " +
                        "--dataset_root was not specified.")
            args.dataset_root = COCO_ROOT
        cfg = coco
        dataset = COCODetection(root=args.dataset_root,
                                transform=SSDAugmentation(
                                    cfg['min_dim'], MEANS))
    elif args.dataset == 'VOC':
        if args.dataset_root == COCO_ROOT:
            parser.error('Must specify dataset if specifying dataset_root')
        cfg = voc
        dataset = VOCDetection(root=args.dataset_root,
                               transform=SSDAugmentation(
                                   cfg['min_dim'], MEANS))
    else:
        # Fail fast instead of hitting a NameError on `cfg` further down.
        raise ValueError('unsupported dataset: {}'.format(args.dataset))

    if args.visdom:
        import visdom
        # NOTE(review): viz is local to this function; create_vis_plot
        # presumably holds its own visdom handle — confirm.
        viz = visdom.Visdom()

    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net

    if args.cuda:
        #net = torch.nn.DataParallel(ssd_net), if only one gpu, just comment it!!
        cudnn.benchmark = True

    if args.resume:
        logger.info('Resuming training, loading {}...'.format(args.resume))
        ssd_net.load_weights(args.resume)
    else:
        # Start from the ImageNet-pretrained VGG backbone.
        vgg_weights = torch.load("weights/" + args.basenet)
        logger.info('Loading base network...')
        ssd_net.vgg.load_state_dict(vgg_weights)

    if args.cuda:
        net = net.cuda()

    if not args.resume:
        logger.info('Initializing weights...')
        # initialize newly added layers' weights with xavier method
        ssd_net.extras.apply(weights_init)
        ssd_net.loc.apply(weights_init)
        ssd_net.conf.apply(weights_init)

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,
                             False, args.cuda)

    net.train()
    # Per-epoch loss accumulators for the visdom epoch plot.
    loc_loss = 0
    conf_loss = 0
    epoch = 0
    logger.info('Loading the dataset...')

    epoch_size = len(dataset) // args.batch_size
    # BUGFIX: logger.info takes a single message string; the previous
    # logger.info('Training SSD on:', dataset.name) passed a stray
    # positional argument with no %s placeholder, which raises a logging
    # format error at emit time.
    logger.info('Training SSD on: {}'.format(dataset.name))
    logger.info('Using the specified args:')
    logger.info(args)

    step_index = 0

    if args.visdom:
        vis_title = 'SSD.PyTorch on ' + dataset.name
        vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
        iter_plot = create_vis_plot('Iteration', 'Loss', vis_title, vis_legend)
        epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend)

    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)

    # create batch iterator
    batch_iterator = iter(data_loader)
    for iteration in tqdm(range(args.start_iter, cfg['max_iter'])):
        if args.visdom and iteration != 0 and (iteration % epoch_size == 0):
            update_vis_plot(epoch, loc_loss, conf_loss, epoch_plot, None,
                            'append', epoch_size)
            # reset epoch loss counters
            loc_loss = 0
            conf_loss = 0
            epoch += 1

        if iteration in cfg['lr_steps']:
            step_index += 1
            adjust_learning_rate(optimizer, args.gamma, step_index)

        # Load train data; the loader is finite, so restart it when it is
        # exhausted mid-run. https://github.com/amdegroot/ssd.pytorch/issues/140
        try:
            images, targets = next(batch_iterator)
        except StopIteration:
            batch_iterator = iter(data_loader)
            images, targets = next(batch_iterator)

        if args.cuda:
            images = Variable(images.cuda())
            targets = [Variable(ann.cuda(), volatile=True) for ann in targets]
        else:
            images = Variable(images)
            targets = [Variable(ann, volatile=True) for ann in targets]
        # forward
        t0 = time.time()
        out = net(images)
        # backprop
        optimizer.zero_grad()
        loss_l, loss_c = criterion(out, targets)
        loss = loss_l + loss_c
        loss.backward()
        optimizer.step()
        t1 = time.time()
        loc_loss += loss_l.data[0]
        conf_loss += loss_c.data[0]

        if iteration % 10 == 0:
            logger.info('timer: {} sec.'.format(t1 - t0))
            logger.info('iter {}/{} || Loss: {} ||'.format(
                repr(iteration), cfg['max_iter'], loss.data[0]))

        if args.visdom:
            update_vis_plot(iteration, loss_l.data[0], loss_c.data[0],
                            iter_plot, epoch_plot, 'append')

        if iteration != 0 and iteration % 5000 == 0:
            logger.info('Saving state, iter: {}'.format(iteration))
            torch.save(
                ssd_net.state_dict(),
                os.path.join(
                    logger.get_logger_dir(),
                    'ssd300_{}_{}.pth'.format(args.dataset, repr(iteration))))
    # Final snapshot after the loop completes.
    torch.save(ssd_net.state_dict(),
               os.path.join(logger.get_logger_dir(), args.dataset + '.pth'))
Beispiel #10
0
                              pytorchgo_args.get_args().epochs))
            #optimizer_summary(optimizer)

    cpu_prototype = model.prototype_N2K.detach().cpu().numpy()
    return cpu_prototype


# Log a summary of the optimizer and model configuration before training.
optimizer_summary(optimizer)
model_summary(model)

# Global step counter, shared with train() through pytorchgo's args singleton.
pytorchgo_args.get_args().step = 0
for epoch in range(start_epoch, start_epoch + args.epochs):
    # In debug mode only run two epochs to keep the smoke test short.
    if args.debug and epoch >= 2: break
    prototype = train(epoch)
    # Switch the model to emit features so kNN evaluation can use embeddings.
    feature_return_switch(model, True)
    logger.warning(logger.get_logger_dir())
    logger.warning("doing KNN evaluation.")
    acc = kNN(model, trainloader, testloader, K=10, sigma=0.1, dim=knn_dim)
    logger.warning("finish KNN evaluation.")
    feature_return_switch(model, False)
    # NOTE(review): best_acc is compared but never reassigned in this span —
    # confirm it is updated elsewhere, otherwise every epoch beating the
    # initial value overwrites the checkpoint.
    if acc > best_acc:
        logger.info('get better result, saving..')
        state = {
            'net': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'opt': optimizer.state_dict(),
            'prototype': prototype,
        }
        torch.save(state, os.path.join(logger.get_logger_dir(),
                                       'best_ckpt.t7'))
Beispiel #11
0
def test_net(tdcnn_demo, dataloader, args, split, max_per_video=0, thresh=0):
    """Evaluate a temporal-detection model and return its mAP.

    Runs `tdcnn_demo` over `dataloader`, applies twin (temporal-box)
    regression and NMS per class, dumps the predictions as an
    ActivityNet-style JSON file into the logger directory, then scores it
    against the ground-truth JSON with ANETdetection.

    Args:
        tdcnn_demo: detection network; called as (video_data, gt_twins,
            support_data) and returning (rois, cls_prob, twin_pred).
        dataloader: yields (support_data, video_data, gt_twins, num_gt,
            video_info) batches.
        args: namespace providing batch_size, shot, num_classes, test_nms,
            dataset.
        split: dataset split name used for the pred/gt JSON filenames.
        max_per_video: cap on detections per video over all classes
            (0 = unlimited; the capping branch is effectively dead by default).
        thresh: minimum class score for a proposal to be kept.

    Returns:
        (mAP, ap): mAP at the first tIoU threshold, and the per-threshold
        AP array from ANETdetection.
    """
    np.random.seed(cfg.RNG_SEED)
    # NOTE(review): assumes every batch is full; the final partial batch
    # would make this an over-estimate — confirm the loader drops the last
    # incomplete batch.
    total_video_num = len(dataloader) * args.batch_size

    all_twins = [[[] for _ in range(total_video_num)]
                 for _ in range(args.num_classes)
                 ]  # class_num,video_num,proposal_num
    tdcnn_demo.eval()
    empty_array = np.transpose(np.array([[], [], []]), (1, 0))

    for data_idx, (support_data, video_data, gt_twins, num_gt,
                   video_info) in tqdm(enumerate(dataloader),
                                       desc="evaluation"):
        # `is_debug`/`fast_eval_samples` are module-level flags; truncate
        # the evaluation early when debugging.
        if is_debug and data_idx > fast_eval_samples:
            break

        video_data = video_data.cuda()
        for i in range(args.shot):
            support_data[i] = support_data[i].cuda()
        gt_twins = gt_twins.cuda()
        batch_size = video_data.shape[0]
        rois, cls_prob, twin_pred = tdcnn_demo(
            video_data, gt_twins, support_data
        )  ##torch.Size([1, 300, 3]),torch.Size([1, 300, 2]),torch.Size([1, 300, 4])
        scores_all = cls_prob.data
        # Columns 1:3 of each roi are the (start, end) twin coordinates.
        twins = rois.data[:, :, 1:3]

        if cfg.TEST.TWIN_REG:  # True
            # Apply bounding-twin regression deltas
            twin_deltas = twin_pred.data
            if cfg.TRAIN.TWIN_NORMALIZE_TARGETS_PRECOMPUTED:  # True
                # Optionally normalize targets by a precomputed mean and stdev
                twin_deltas = twin_deltas.view(-1, 2) * torch.FloatTensor(
                    cfg.TRAIN.TWIN_NORMALIZE_STDS
                ).type_as(twin_deltas) + torch.FloatTensor(
                    cfg.TRAIN.TWIN_NORMALIZE_MEANS).type_as(twin_deltas)
                twin_deltas = twin_deltas.view(
                    batch_size, -1,
                    2 * args.num_classes)  # torch.Size([1, 300, 4])

            pred_twins_all = twin_transform_inv(
                twins, twin_deltas, batch_size)  # torch.Size([1, 300, 4])
            pred_twins_all = clip_twins(pred_twins_all, cfg.TRAIN.LENGTH[0],
                                        batch_size)  # torch.Size([1, 300, 4])
        else:
            # Simply repeat the twins, once for each class
            pred_twins_all = np.tile(twins, (1, scores_all.shape[1]))

        for b in range(batch_size):
            if is_debug:
                logger.info(video_info)
            scores = scores_all[b]  # scores.squeeze()
            pred_twins = pred_twins_all[b]  # .squeeze()

            # skip j = 0, because it's the background class
            for j in range(1, args.num_classes):
                inds = torch.nonzero(scores[:, j] > thresh).view(-1)
                # if there is det
                if inds.numel() > 0:
                    cls_scores = scores[:, j][inds]
                    _, order = torch.sort(cls_scores, 0, True)
                    # Per-class twin coordinates live at columns [2j, 2j+2).
                    cls_twins = pred_twins[inds][:, j * 2:(j + 1) * 2]

                    cls_dets = torch.cat((cls_twins, cls_scores.unsqueeze(1)),
                                         1)
                    # cls_dets = torch.cat((cls_twins, cls_scores), 1)
                    cls_dets = cls_dets[order]
                    keep = nms_cpu(cls_dets.cpu(), args.test_nms)
                    if (len(keep) > 0):
                        if is_debug:
                            print("after nms, keep {}".format(len(keep)))
                        cls_dets = cls_dets[keep.view(-1).long()]
                    else:
                        print(
                            "warning, after nms, none of the rois is kept!!!")
                    all_twins[j][data_idx * batch_size +
                                 b] = cls_dets.cpu().numpy()
                else:
                    all_twins[j][data_idx * batch_size + b] = empty_array

            # Limit to max_per_video detections *over all classes*, useless code here, default max_per_video = 0
            if max_per_video > 0:
                video_scores = np.hstack([
                    all_twins[j][data_idx * batch_size + b][:, -1]
                    for j in range(1, args.num_classes)
                ])
                if len(video_scores) > max_per_video:
                    video_thresh = np.sort(video_scores)[-max_per_video]
                    for j in range(1, args.num_classes):
                        keep = np.where(
                            all_twins[j][data_idx * batch_size +
                                         b][:, -1] >= video_thresh)[0]
                        all_twins[j][data_idx * batch_size +
                                     b] = all_twins[j][data_idx * batch_size +
                                                       b][keep, :]

            # logger.info('im_detect: {:d}/{:d}'.format(i * batch_size + b + 1, len(dataloader)))

    # Assemble the ActivityNet-style prediction dictionary.
    pred = dict()
    pred['external_data'] = ''
    pred['version'] = ''
    pred['results'] = dict()
    for i_video in tqdm(range(total_video_num),
                        desc="generating prediction json.."):
        # NOTE(review): `batch_size` here is the leftover loop variable from
        # the evaluation loop above — confirm the loader is never empty.
        if is_debug and i_video > fast_eval_samples * batch_size - 2:
            break
        item_pre = []
        for j_roi in range(
                0, len(all_twins[1][i_video])
        ):  # binary class problem, here we only consider class_num=1, ignoring background class
            _d = dict()
            # Each row is [start, end, score]; see cls_dets above.
            _d['score'] = all_twins[1][i_video][j_roi][2].item()
            _d['label'] = 'c1'
            _d['segment'] = [
                all_twins[1][i_video][j_roi][0].item(),
                all_twins[1][i_video][j_roi][1].item()
            ]
            item_pre.append(_d)
        pred['results']["query_%05d" % i_video] = item_pre

    predict_filename = os.path.join(logger.get_logger_dir(),
                                    '{}_pred.json'.format(split))
    ground_truth_filename = os.path.join('preprocess/{}'.format(args.dataset),
                                         '{}_gt.json'.format(split))

    with open(predict_filename, 'w') as f:
        json.dump(pred, f)
        logger.info('dump pred.json complete..')

    # Score the dumped predictions with the official ActivityNet evaluator.
    sys.path.insert(0, "evaluation")
    from eval_detection import ANETdetection

    anet_detection = ANETdetection(ground_truth_filename,
                                   predict_filename,
                                   subset="test",
                                   tiou_thresholds=tiou_thresholds,
                                   verbose=True,
                                   check_status=False)
    anet_detection.evaluate()
    ap = anet_detection.mAP
    # mAP at the first tIoU threshold is used as the headline number.
    mAP = ap[0]
    return mAP, ap
Beispiel #12
0
def train():
    """Train the few-shot detector: a support network plus a query SSD.

    Loads VGG16 weights into both networks, trains them jointly with
    MultiBoxLoss, periodically evaluates on the validation setting and keeps
    the best checkpoint (``cherry.pth``) in the logger directory, then runs a
    final evaluation on the test setting.
    """
    logger.info("current cuda device: {}".format(torch.cuda.current_device()))

    few_shot_net = build_ssd(args.dim, num_classes)
    support_net = build_ssd_support(args.dim, num_classes)

    # Copy the pretrained VGG16 weights into the support branch.
    vgg16_state_dict = torch.load(args.basenet)
    new_params = {}
    for index, i in enumerate(vgg16_state_dict):
        #if index >= 20:
        #    continue
        new_params[i] = vgg16_state_dict[i]
        logger.info(
            "recovering weight for student model(loading vgg16 weight): {}".
            format(i))
    support_net.support_vgg.load_state_dict(new_params)

    logger.info('Loading base network...')
    few_shot_net.query_vgg.load_state_dict(torch.load(args.basenet))

    few_shot_net = few_shot_net.cuda()
    support_net = support_net.cuda()

    def xavier(param):
        init.xavier_uniform(param)

    def weights_init(m):
        # Xavier-init conv weights, zero the biases.
        if isinstance(m, nn.Conv2d):
            xavier(m.weight.data)
            m.bias.data.zero_()

    logger.info('Initializing weights...')
    # initialize newly added layers' weights with xavier method
    few_shot_net.extras.apply(weights_init)
    few_shot_net.loc.apply(weights_init)
    few_shot_net.conf.apply(weights_init)

    # Both networks are optimized jointly.
    optimizer = optim.SGD(list(few_shot_net.parameters()) +
                          list(support_net.parameters()),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    criterion = MultiBoxLoss(num_classes,
                             size=args.dim,
                             overlap_thresh=0.5,
                             prior_for_matching=True,
                             bkg_label=0,
                             neg_mining=True,
                             neg_pos=3,
                             neg_overlap=0.5,
                             encode_target=False,
                             use_gpu=True)

    few_shot_net.train()
    support_net.train()
    best_val_result = 0
    logger.info('Loading Dataset...')

    dataset = FSLDataset(params=train_setting,
                         image_size=(args.dim, args.dim),
                         query_image_augs=SSDAugmentation(args.dim),
                         ds_name=args.dataset)

    epoch_size = len(dataset) // args.batch_size

    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  pin_memory=True,
                                  collate_fn=detection_collate)
    batch_iterator = iter(data_loader)

    lr = args.lr
    for iteration in tqdm(range(args.iterations + 1),
                          total=args.iterations,
                          desc="training {}".format(logger.get_logger_dir())):
        if (not batch_iterator) or (iteration % epoch_size == 0):
            # create batch iterator
            batch_iterator = iter(data_loader)

        if iteration in stepvalues:
            lr = adjust_learning_rate(optimizer, lr, iteration)

        # BUGFIX: len(dataset) is not always a multiple of batch_size, so the
        # iterator can be exhausted before `iteration % epoch_size == 0`
        # recreates it; guard against StopIteration like the other training
        # loop does (https://github.com/amdegroot/ssd.pytorch/issues/140).
        try:
            first_images, images, targets, metadata = next(batch_iterator)
        except StopIteration:
            batch_iterator = iter(data_loader)
            first_images, images, targets, metadata = next(batch_iterator)
        #embed()

        first_images = Variable(first_images.cuda())
        images = Variable(images.cuda())
        targets = [Variable(anno.cuda(), volatile=True) for anno in targets]

        # Support features condition the query-detector forward pass.
        fusion_support = support_net(first_images)
        out = few_shot_net(fusion_support, images, is_train=True)
        # backprop
        optimizer.zero_grad()
        loss_l, loss_c = criterion(out, targets)
        loss = loss_l + loss_c
        loss.backward()
        optimizer.step()

        if iteration % log_per_iter == 0 and iteration > 0:
            logger.info(
                '''LR: {}\t Iter: {}\t Loss_l: {:.5f}\t Loss_c: {:.5f}\t Loss_total: {:.5f}\t best_result: {:.5f}'''
                .format(lr, iteration, loss_l.data[0], loss_c.data[0],
                        loss.data[0], best_val_result))

        if iteration % save_per_iter == 0 and iteration > 0:
            # Periodic validation; only keep the best-performing snapshot.
            few_shot_net.eval()
            support_net.eval()
            cur_eval_result = do_eval(
                few_shot_net,
                support_net=support_net,
                test_setting=val_setting,
                base_dir=logger.get_logger_dir(),
            )
            few_shot_net.train()
            support_net.train()

            is_best = True if cur_eval_result > best_val_result else False
            if is_best:
                best_val_result = cur_eval_result
                torch.save(
                    {
                        'iteration': iteration,
                        'optim_state_dict': optimizer.state_dict(),
                        'support_model_state_dict': support_net.state_dict(),
                        'query_model_state_dict': few_shot_net.state_dict(),
                        'best_mean_iu': best_val_result,
                    }, os.path.join(logger.get_logger_dir(), 'cherry.pth'))
            else:
                logger.info("current snapshot is not good enough, skip~~")

            logger.info('current iter: {} current_result: {:.5f}'.format(
                iteration, cur_eval_result))

    # Final evaluation on the held-out test setting.
    few_shot_net.eval()
    support_net.eval()
    test_result = do_eval(few_shot_net,
                          support_net,
                          test_setting=test_setting,
                          base_dir=logger.get_logger_dir())
    logger.info("test result={:.5f}, best validation result={:.5f}".format(
        test_result, best_val_result))
    logger.info("Congrats~")
Beispiel #13
0
def main():
    """Entry point for LSD / source-only domain-adaptation segmentation.

    Parses command-line arguments, builds the SYNTHIA (source), SYNTHIA-val
    and CityScapes (target) data loaders, constructs the FCN8s model (plus
    GAN components for the LSD method), the optimizers, and hands everything
    to the appropriate trainer.
    """
    logger.auto_set_dir()
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataroot',
        default='/home/hutao/lab/pytorchgo/example/LSD-seg/data',
        help='Path to source dataset')
    parser.add_argument('--batchSize',
                        type=int,
                        default=1,
                        help='input batch size')
    parser.add_argument('--num_iters',
                        type=int,
                        default=100000,
                        help='Number of training iterations')
    parser.add_argument('--optimizer',
                        type=str,
                        default='Adam',
                        help='Optimizer to use | SGD, Adam')
    parser.add_argument('--lr',
                        type=float,
                        default=1.0e-5,
                        help='learning rate')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.99,
                        help='Momentum for SGD')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.9,
                        help='beta1 for adam. default=0.5')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=0.0005,
                        help='Weight decay')
    parser.add_argument(
        '--interval_validate',
        type=int,
        default=500,
        help=
        'Period for validation. Model is validated every interval_validate iterations'
    )
    parser.add_argument(
        '--resume',
        default='',
        help=
        "path to the current checkpoint for resuming training. Do not specify if model has to be trained from scratch"
    )
    parser.add_argument('--method',
                        default='LSD',
                        help="Method to use for training | LSD, sourceonly")
    parser.add_argument('--l1_weight', type=float, default=1, help='L1 weight')
    parser.add_argument('--adv_weight',
                        type=float,
                        default=0.1,
                        help='Adv_weight')
    parser.add_argument('--c_weight', type=float, default=0.1, help='C_weight')
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()
    # Consistency fix: use the logger like the rest of the file instead of
    # a bare print().
    logger.info(args)

    gpu = args.gpu
    out = logger.get_logger_dir()
    resume = args.resume
    # Pin the process to the requested GPU before any CUDA query.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    cuda = torch.cuda.is_available()
    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # Defining data loaders

    image_size = [640, 320]
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    train_loader = torch.utils.data.DataLoader(torchfcn.datasets.SYNTHIA(
        'SYNTHIA',
        args.dataroot,
        split='train',
        transform=True,
        image_size=image_size),
                                               batch_size=args.batchSize,
                                               shuffle=True,
                                               **kwargs)
    val_loader = torch.utils.data.DataLoader(torchfcn.datasets.SYNTHIA(
        'SYNTHIA',
        args.dataroot,
        split='val',
        transform=True,
        image_size=image_size),
                                             batch_size=args.batchSize,
                                             shuffle=False,
                                             **kwargs)
    # Consistency fix: the target-domain loader previously omitted the
    # num_workers/pin_memory kwargs that the source loaders pass.
    target_loader = torch.utils.data.DataLoader(torchfcn.datasets.CityScapes(
        'cityscapes',
        args.dataroot,
        split='train',
        transform=True,
        image_size=image_size),
                                                batch_size=args.batchSize,
                                                shuffle=True,
                                                **kwargs)

    # Defining models

    start_epoch = 0
    start_iteration = 0
    if args.method == 'sourceonly':
        model = torchfcn.models.FCN8s_sourceonly(n_class=class_num)
    elif args.method == 'LSD':
        # LSD additionally trains a generator/discriminator pair.
        model = torchfcn.models.FCN8s_LSD(n_class=class_num)
        netG = torchfcn.models._netG()
        netD = torchfcn.models._netD()
        netD.apply(weights_init)
        netG.apply(weights_init)
    else:
        raise ValueError('method argument can be either sourceonly or LSD')

    if resume:
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    else:
        # Fresh run: initialize the FCN backbone from pretrained VGG16.
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)
    if cuda:
        model = model.cuda()
        if args.method == 'LSD':
            netD = netD.cuda()
            netG = netG.cuda()

    # Defining optimizer

    if args.optimizer == 'SGD':
        # Biases get double learning rate, following the FCN training recipe.
        optim = torch.optim.SGD([
            {
                'params': get_parameters(model, bias=False)
            },
            {
                'params': get_parameters(model, bias=True),
                'lr': args.lr * 2,
                'weight_decay': args.weight_decay
            },
        ],
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    elif args.optimizer == 'Adam':
        optim = torch.optim.Adam([
            {
                'params': get_parameters(model, bias=False)
            },
            {
                'params': get_parameters(model, bias=True),
                'lr': args.lr * 2
            },
        ],
                                 lr=args.lr,
                                 betas=(args.beta1, 0.999))
    else:
        raise ValueError('Invalid optmizer argument. Has to be SGD or Adam')

    if args.method == 'LSD':
        optimD = torch.optim.Adam(netD.parameters(),
                                  lr=0.0001,
                                  betas=(0.7, 0.999))
        optimG = torch.optim.Adam(netG.parameters(),
                                  lr=0.0001,
                                  betas=(0.7, 0.999))

    if resume:
        optim.load_state_dict(checkpoint['optim_state_dict'])

    # Defining trainer object, and start training
    if args.method == 'sourceonly':
        trainer = torchfcn.Trainer_sourceonly(
            cuda=cuda,
            model=model,
            optimizer=optim,
            train_loader=train_loader,
            target_loader=target_loader,
            val_loader=val_loader,
            out=out,
            max_iter=args.num_iters,
            interval_validate=args.interval_validate)
        trainer.epoch = start_epoch
        trainer.iteration = start_iteration
        trainer.train()
    elif args.method == 'LSD':
        trainer = Trainer_LSD(cuda=cuda,
                              model=model,
                              netD=netD,
                              netG=netG,
                              optimizer=optim,
                              optimizerD=optimD,
                              optimizerG=optimG,
                              train_loader=train_loader,
                              target_loader=target_loader,
                              l1_weight=args.l1_weight,
                              adv_weight=args.adv_weight,
                              c_weight=args.c_weight,
                              val_loader=val_loader,
                              out=out,
                              max_iter=args.num_iters,
                              interval_validate=args.interval_validate,
                              image_size=image_size)
        trainer.epoch = start_epoch
        trainer.iteration = start_iteration
        trainer.train()
Beispiel #14
0
        tdcnn_demo.train()  #recover for training mode

        logger.info("current result: {},{}".format(test_mAP, test_ap))
        if test_mAP > args.best_result:
            logger.info(
                "current result {} better than {}, save best_model.".format(
                    test_mAP, args.best_result))
            args.best_result = test_mAP
            save_checkpoint(
                {
                    'model':
                    tdcnn_demo.module.state_dict()
                    if len(args.gpus) > 1 else tdcnn_demo.state_dict(),
                    'best':
                    args.best_result,
                }, os.path.join(logger.get_logger_dir(), 'best_model.pth'))

    # reload the best weight, do final testing
    state_dict = torch.load(
        os.path.join(logger.get_logger_dir(), 'best_model.pth'))['model']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if 'module' not in k:
            k = 'module.{}'.format(k)
        new_state_dict[k] = v
    tdcnn_demo.load_state_dict(new_state_dict)
    tdcnn_demo.eval()

    test_mAP, test_ap = test_net(tdcnn_demo,
                                 dataloader=dataloader_test,
Beispiel #15
0
def do_eval(few_shot_net,
            support_net,
            test_setting,
            base_dir=None):
    """Evaluate the few-shot detector and return its mAP.

    Writes per-image ground-truth and detection text files (Pascal-VOC mAP
    tool format) under ``base_dir/eval_tmp``, runs the detector over
    `test_setting`, then scores the dumps with the mAP library.

    Args:
        few_shot_net: query detector; called as (fusion_support, x,
            is_train=False).
        support_net: support-feature extractor applied to the support images.
        test_setting: dataset parameters passed to FSLDataset.
        base_dir: directory for temporary evaluation files; defaults to the
            current logger directory.

    Returns:
        The mAP computed by the mAP library.
    """
    # BUGFIX: the old default `base_dir=logger.get_logger_dir()` was
    # evaluated once at function-definition time; resolve it at call time.
    if base_dir is None:
        base_dir = logger.get_logger_dir()

    def create_dirs(dir_name):
        # (Re)create a clean directory tree for this evaluation round.
        global tmp_eval, ground_truth_dir, predicted_dir, vis_dir, mAP_result_dir

        tmp_eval = os.path.join(base_dir, dir_name)
        ground_truth_dir = os.path.join(tmp_eval, "ground_truth")
        predicted_dir = os.path.join(tmp_eval, "detection")
        vis_dir = os.path.join(tmp_eval, "vis")
        mAP_result_dir = os.path.join(tmp_eval, "mAP_result")

        if os.path.isdir(tmp_eval):
            import shutil
            shutil.rmtree(tmp_eval, ignore_errors=True)
        os.makedirs(tmp_eval)
        os.makedirs(ground_truth_dir)
        os.makedirs(predicted_dir)
        os.makedirs(vis_dir)
        os.mkdir(mAP_result_dir)
        return (tmp_eval, ground_truth_dir, predicted_dir, vis_dir,
                mAP_result_dir)

    # NOTE: the pre-computed paths that used to precede this call were dead
    # code — create_dirs() returns every path it builds.
    tmp_eval, ground_truth_dir, predicted_dir, vis_dir, mAP_result_dir = create_dirs(
        dir_name="eval_tmp")

    dataset = FSLDataset(params=test_setting,
                         image_size=(args.dim, args.dim),
                         ds_name=args.dataset)
    num_images = len(dataset)

    data_loader = data.DataLoader(dataset,
                                  batch_size=1,
                                  num_workers=args.num_workers,
                                  shuffle=False,
                                  pin_memory=True,
                                  collate_fn=detection_collate)

    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)] for _ in range(num_classes)]
    # `image_size` is a module-level constant; boxes are predicted in
    # normalized [0, 1] coordinates and scaled back to pixels below.
    w = image_size
    h = image_size

    for i, batch in tqdm(enumerate(data_loader),
                         total=len(data_loader),
                         desc="online {}".format(test_setting['image_sets'])):

        with open(os.path.join(ground_truth_dir, "{}.txt".format(i)),
                  "w") as f_gt:
            with open(os.path.join(predicted_dir, "{}.txt".format(i)),
                      "w") as f_predict:
                # if i > 500:break
                first_images, images, targets, metadata = batch
                class_name = metadata[0]['class_name']

                first_images = Variable(first_images.cuda())
                x = Variable(images.cuda())

                query_origin_img = metadata[0][
                    'cl_query_image']  #np.transpose(images.numpy()[0], (1, 2, 0))#W*H*C
                # Write ground-truth boxes, scaled to pixel coordinates.
                gt_bboxes = targets[0].numpy()
                for _ in range(gt_bboxes.shape[0]):
                    gt_bboxes[_, 0] *= w
                    gt_bboxes[_, 2] *= w
                    gt_bboxes[_, 1] *= h
                    gt_bboxes[_, 3] *= h
                    f_gt.write("targetobject {} {} {} {}\n".format(
                        int(gt_bboxes[_, 0]), int(gt_bboxes[_, 1]),
                        int(gt_bboxes[_, 2]), int(gt_bboxes[_, 3])))

                fusion_support = support_net(first_images)
                detections = few_shot_net(fusion_support, x,
                                          is_train=False).data

                vis_flag = 0
                # skip j = 0, because it's the background class
                for j in range(1, detections.size(1)):
                    # Each detection row is (score, x1, y1, x2, y2); keep
                    # positive-score rows only.
                    dets = detections[0, j, :]
                    mask = dets[:, 0].gt(0).expand(
                        5,
                        dets.size(0)).t()  # greater than 0 will be visualized!
                    dets = torch.masked_select(dets, mask).view(-1, 5)
                    torch_dets = dets.clone()
                    if dets.dim() == 0:
                        continue
                    boxes = dets[:, 1:].cpu().numpy()
                    boxes[:, 0] = (boxes[:, 0] * w)
                    boxes[:, 1] = (boxes[:, 1] * h)
                    boxes[:, 2] = (boxes[:, 2] * w)
                    boxes[:, 3] = (boxes[:, 3] * h)

                    # Clamp boxes to the image frame.
                    boxes[:, 0][boxes[:, 0] < 0] = 0
                    boxes[:, 1][boxes[:, 1] < 0] = 0
                    boxes[:, 2][boxes[:, 2] > image_size] = image_size
                    boxes[:, 3][boxes[:, 3] > image_size] = image_size

                    scores = dets[:, 0].cpu().numpy()
                    cls_dets = np.hstack(
                        (boxes, scores[:, np.newaxis])).astype(np.float32,
                                                               copy=False)
                    all_boxes[j][i] = cls_dets

                    # Dump detections as "<class> <score> <x1> <y1> <x2> <y2>".
                    for _ in range(cls_dets.shape[0]):
                        f_predict.write("targetobject {} {} {} {} {}\n".format(
                            cls_dets[_, 4], int(cls_dets[_, 0]),
                            int(cls_dets[_, 1]), int(cls_dets[_, 2]),
                            int(cls_dets[_, 3])))

    from cl_utils.mAP_lib.pascalvoc_interactive import get_mAP
    # get_mAP chdirs internally; restore the working directory afterwards.
    cwd = os.getcwd()
    mAP = get_mAP(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), tmp_eval),
        "ground_truth", "detection", "mAP_result")
    os.chdir(cwd)
    return mAP
def train(args):
    """Train a DeepLab-ResNet semantic-segmentation model on PASCAL VOC.

    Builds train/val dataloaders with Caffe-style preprocessing (resize,
    RGB->BGR, per-channel mean subtraction), trains with poly-decayed SGD,
    evaluates mean IoU after every epoch, and keeps the best checkpoint in
    the logger directory as ``best_model.pth``.

    Args:
        args: parsed CLI namespace; must provide ``gpu``, ``batch_size``,
            ``l_rate``, ``n_epoch`` and ``resume``.
    """
    logger.auto_set_dir()
    from pytorchgo.utils.pytorch_utils import set_gpu
    set_gpu(args.gpu)

    # Setup Dataloader
    from pytorchgo.augmentation.segmentation import SubtractMeans, PIL2NP, RGB2BGR, PIL_Scale, Value255to0, ToLabel
    from torchvision.transforms import Compose, Normalize, ToTensor
    img_transform = Compose([  # notice the order!!!
        PIL_Scale(train_img_shape, Image.BILINEAR),
        PIL2NP(),
        RGB2BGR(),
        SubtractMeans(),
        ToTensor(),
    ])

    label_transform = Compose([
        PIL_Scale(train_img_shape, Image.NEAREST),
        PIL2NP(),
        Value255to0(),  # remap the VOC ignore value 255 before training
        ToLabel()
    ])

    val_img_transform = Compose([
        PIL_Scale(train_img_shape, Image.BILINEAR),
        PIL2NP(),
        RGB2BGR(),
        SubtractMeans(),
        ToTensor(),
    ])
    val_label_transform = Compose([
        PIL_Scale(train_img_shape, Image.NEAREST),
        PIL2NP(),
        ToLabel(),
        # notice here, training, validation size difference, this is very tricky.
    ])

    from pytorchgo.dataloader.pascal_voc_loader import pascalVOCLoader as common_voc_loader
    train_loader = common_voc_loader(split="train_aug",
                                     epoch_scale=1,
                                     img_transform=img_transform,
                                     label_transform=label_transform)
    validation_loader = common_voc_loader(split='val',
                                          img_transform=val_img_transform,
                                          label_transform=val_label_transform)

    n_classes = train_loader.n_classes
    trainloader = data.DataLoader(train_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)

    valloader = data.DataLoader(validation_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup Model
    from pytorchgo.model.deeplabv1 import VGG16_LargeFoV
    from pytorchgo.model.deeplab_resnet import Res_Deeplab

    model = Res_Deeplab(NoLabels=n_classes, pretrained=True, output_all=False)

    from pytorchgo.utils.pytorch_utils import model_summary, optimizer_summary
    model_summary(model)

    def get_validation_miou(model):
        """Run one pass over the validation set and return the mean IoU."""
        model.eval()
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader),
                                                    total=len(valloader),
                                                    desc="validation"):
            # Short-circuit validation in debug modes.
            if i_val > 5 and is_debug == 1: break
            if i_val > 200 and is_debug == 2: break

            # volatile=True: pre-0.4 PyTorch idiom to disable autograd
            # tracking during inference.
            output = model(Variable(images_val, volatile=True).cuda())
            pred = output.data.max(1)[1].cpu().numpy()
            gt = labels_val.numpy()

            running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            logger.info("{}: {}".format(k, v))
        running_metrics.reset()
        return score['Mean IoU : \t']

    model.cuda()

    # Check if model has custom optimizer / loss.
    # BUGFIX: the original logged the warning in the wrong branch and read
    # ``model.module.optimizer`` -- ``.module`` only exists on DataParallel
    # wrappers, so that line would crash whenever the model actually
    # provided its own optimizer.
    if hasattr(model, 'optimizer'):
        logger.info("use the model's customized optimizer!")
        optimizer = model.optimizer
    else:
        logger.warn("don't have customized optimizer, use default setting!")
        optimizer = torch.optim.SGD(model.optimizer_params(args.l_rate),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    optimizer_summary(optimizer)
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            logger.info("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("No checkpoint found at '{}'".format(args.resume))

    best_iou = 0
    logger.info('start!!')
    for epoch in tqdm(range(args.n_epoch), total=args.n_epoch):
        model.train()
        for i, (images, labels) in tqdm(enumerate(trainloader),
                                        total=len(trainloader),
                                        desc="training epoch {}/{}".format(
                                            epoch, args.n_epoch)):
            # Short-circuit training in debug modes.
            if i > 10 and is_debug == 1: break
            if i > 200 and is_debug == 2: break

            # Poly learning-rate decay over the whole training schedule.
            cur_iter = i + epoch * len(trainloader)
            cur_lr = adjust_learning_rate(optimizer,
                                          args.l_rate,
                                          cur_iter,
                                          args.n_epoch * len(trainloader),
                                          power=0.9)

            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images)  # use fusion score
            loss = CrossEntropyLoss2d_Seg(input=outputs,
                                          target=labels,
                                          class_num=n_classes)

            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                logger.info(
                    "Epoch [%d/%d] Loss: %.4f, lr: %.7f, best mIoU: %.7f" %
                    (epoch + 1, args.n_epoch, loss.data[0], cur_lr, best_iou))

        # Validate once per epoch; keep only the best-scoring checkpoint.
        cur_miou = get_validation_miou(model)
        if cur_miou >= best_iou:
            best_iou = cur_miou
            state = {
                'epoch': epoch + 1,
                'mIoU': best_iou,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state,
                       os.path.join(logger.get_logger_dir(), "best_model.pth"))
Beispiel #17
0
def train():
    """Run the SSD detector training loop for ``args.iterations`` steps.

    Saves an intermediate snapshot every ``validate_per`` iterations and a
    final snapshot when training finishes.

    NOTE(review): depends on module-level globals defined elsewhere in this
    file (``net``, ``ssd_net``, ``args``, ``optimizer``, ``criterion``,
    ``means``, ``stepvalues``, ``start_iter``, ``validate_per``) -- confirm
    they are initialized before this is called.
    """
    net.train()
    logger.info('Loading Dataset...')

    # Detection dataset with SSD-style data augmentation.
    dataset = SimDetection(transform=SSDAugmentation(args.dim, means))

    # Number of optimizer steps that make up one pass over the dataset.
    epoch_size = len(dataset) // args.batch_size
    logger.info('Training SSD on {}'.format(dataset.name))
    logger.info("epoch size: {}".format(epoch_size))
    step_index = 0
    batch_iterator = None
    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)

    lr = args.lr
    epoch = 0
    for iteration in tqdm(range(start_iter, args.iterations),
                          desc="epoch {}/{} training".format(
                              epoch, args.iterations // epoch_size)):
        # (Re)create the batch iterator at start-up and whenever an epoch's
        # worth of iterations has been consumed.
        if (not batch_iterator) or (iteration % epoch_size == 0):
            # create batch iterator
            batch_iterator = iter(data_loader)
            epoch += 1
        # Step-wise LR schedule: decay by args.gamma at each milestone
        # iteration listed in `stepvalues`.
        if iteration in stepvalues:
            step_index += 1
            old_lr = lr
            lr = adjust_learning_rate(optimizer, args.gamma, step_index)
            logger.info("iter {}, change lr from {:.8f} to {:.8f}".format(
                iteration, old_lr, lr))

        images, targets = next(batch_iterator)
        if args.cuda:
            images = Variable(images.cuda())
            # volatile=True: pre-0.4 PyTorch idiom -- no autograd tracking
            # on the ground-truth annotations.
            targets = [
                Variable(anno.cuda(), volatile=True) for anno in targets
            ]
        else:
            images = Variable(images)
            targets = [Variable(anno, volatile=True) for anno in targets]
        # forward
        t0 = time.time()
        out = net(images)
        # backprop
        optimizer.zero_grad()
        # Localization + confidence loss terms from the criterion.
        loss_l, loss_c = criterion(out, targets)
        loss = loss_l + loss_c
        loss.backward()
        optimizer.step()
        t1 = time.time()
        if iteration % 10 == 0:
            logger.info('''
                Timer: {:.5f} sec.\t LR: {:.7f}.\t Iter: {}.\t Loss_l: {:.5f}.\t Loss_c: {:.5f}. Loss: {:.5f}
                '''.format((t1 - t0), lr, iteration, loss_l.data[0],
                           loss_c.data[0], loss.data[0]))

        # Periodic snapshot of the (unwrapped) SSD network weights.
        if iteration % validate_per == 0 and iteration > 0:
            logger.info('Saving state, iter={}'.format(iteration))
            torch.save(
                ssd_net.state_dict(),
                os.path.join(logger.get_logger_dir(),
                             'ssd-{}.pth'.format(repr(iteration))))
    # Final snapshot after the last iteration.
    torch.save(
        ssd_net.state_dict(),
        os.path.join(logger.get_logger_dir(), 'ssd_{}.pth'.format(iteration)))
    logger.info("Congratulations..")
                                             is_data_parallel=args.is_data_parallel)
    optimizer_g = get_optimizer(model_g.parameters(), lr=args.lr, momentum=args.momentum, opt=args.opt,
                                weight_decay=args.weight_decay)
    optimizer_f = get_optimizer(list(model_f1.parameters()) + list(model_f2.parameters()), opt=args.opt,
                                lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
# Optionally tie the two classifiers together so f1 and f2 share weights.
if args.uses_one_classifier:
    logger.warn ("f1 and f2 are same!")
    model_f2 = model_f1

# Experiment identifier: <src>-<src_split>2<tgt>-<tgt_split>_<channels>ch
mode = "%s-%s2%s-%s_%sch" % (args.src_dataset, args.src_split, args.tgt_dataset, args.tgt_split, args.input_ch)
if args.net in ["fcn", "psp"]:
    # ResNet-backboned nets also encode the residual depth in the name.
    model_name = "%s-%s-%s-res%s" % (args.method, args.savename, args.net, args.res)
else:
    model_name = "%s-%s-%s" % (args.method, args.savename, args.net)

outdir = os.path.join(logger.get_logger_dir(), mode)

# Create Model Dir
pth_dir = os.path.join(outdir, "pth")
mkdir_if_not_exist(pth_dir)

# Create Model Dir and  Set TF-Logger
tflog_dir = os.path.join(outdir, "tflog", model_name)
mkdir_if_not_exist(tflog_dir)
configure(tflog_dir, flush_secs=5)

# Save param dic
if resume_flg:
    # NOTE(review): the resume path uses ``args.outdir`` while the fresh-run
    # path uses the local ``outdir`` computed above -- confirm ``args.outdir``
    # exists and is intended when resuming.
    json_fn = os.path.join(args.outdir, "param-%s_resume.json" % model_name)
else:
    json_fn = os.path.join(outdir, "param-%s.json" % model_name)
Beispiel #19
0
def main():
    """Entry point: configure and train a Temporal Relation Network (TRN).

    Parses command-line options, builds a TSN model with TRN consensus,
    wraps it in DataParallel for multi-GPU training, optionally resumes
    from a checkpoint, then alternates training epochs with periodic
    validation, checkpointing on best top-1 precision.

    NOTE(review): depends on module-level names defined elsewhere in this
    file (``best_prec1``, ``datasets_video``, ``TSN``, ``TSNDataSet``,
    ``train``, ``validate``, ``save_checkpoint``, ``adjust_learning_rate``
    and the Group*/Stack transform helpers) -- confirm they exist.
    """
    logger.auto_set_dir()

    # NOTE(review): ``best_prec1`` is read below before any visible
    # assignment in this function -- presumably it has a module-level
    # default (0) unless restored from a checkpoint; verify.
    global args, best_prec1

    import argparse
    parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks")
    parser.add_argument('--dataset', type=str,default="something", choices=['something', 'jester', 'moments'])
    parser.add_argument('--modality', type=str, default="RGB", choices=['RGB', 'Flow'])
    parser.add_argument('--train_list', type=str, default="")
    parser.add_argument('--val_list', type=str, default="")
    parser.add_argument('--root_path', type=str, default="")
    parser.add_argument('--store_name', type=str, default="")
    # ========================= Model Configs ==========================
    parser.add_argument('--arch', type=str, default="BNInception")
    parser.add_argument('--num_segments', type=int, default=3)
    parser.add_argument('--consensus_type', type=str, default='avg')
    parser.add_argument('--k', type=int, default=3)

    parser.add_argument('--dropout', '--do', default=0.8, type=float,
                        metavar='DO', help='dropout ratio (default: 0.5)')
    parser.add_argument('--loss_type', type=str, default="nll",
                        choices=['nll'])
    parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")

    # ========================= Learning Configs ==========================
    parser.add_argument('--epochs', default=120, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=128, type=int,
                        metavar='N', help='mini-batch size (default: 256)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--lr_steps', default=[50, 100], type=float, nargs="+",
                        metavar='LRSteps', help='epochs to decay learning rate by 10')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('--clip-gradient', '--gd', default=20, type=float,
                        metavar='W', help='gradient norm clipping (default: disabled)')
    parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")

    # ========================= Monitor Configs ==========================
    parser.add_argument('--print-freq', '-p', default=20, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--eval-freq', '-ef', default=5, type=int,
                        metavar='N', help='evaluation frequency (default: 5)')

    # ========================= Runtime Configs ==========================
    parser.add_argument('-j', '--workers', default=30, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--snapshot_pref', type=str, default="")
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--gpu', type=str, default='4')
    parser.add_argument('--flow_prefix', default="", type=str)
    parser.add_argument('--root_log', type=str, default='log')
    parser.add_argument('--root_model', type=str, default='model')
    parser.add_argument('--root_output', type=str, default='output')

    args = parser.parse_args()

    # Force TRN consensus regardless of the CLI value.
    args.consensus_type = "TRN"
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # This setup requires multi-GPU DataParallel training.
    device_ids = [int(id) for id in args.gpu.split(',')]
    assert len(device_ids) >1, "TRN must run with GPU_num > 1"

    # Redirect all log/model/output paths into the logger directory.
    args.root_log = logger.get_logger_dir()
    args.root_model = logger.get_logger_dir()
    args.root_output = logger.get_logger_dir()

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(args.dataset, args.modality)
    num_class = len(categories)


    args.store_name = '_'.join(['TRN', args.dataset, args.modality, args.arch, args.consensus_type, 'segment%d'% args.num_segments])
    print('storing name: ' + args.store_name)

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    # Query preprocessing parameters and per-layer optimizer policies from
    # the model before it is wrapped in DataParallel.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)#TODO, , device_ids=[int(id) for id in args.gpu.split(',')]

    if torch.cuda.is_available():
       model.cuda()

    # Optionally resume model weights / epoch counter / best score.
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print(("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Frames sampled per segment: 1 for RGB, 5 stacked for Flow/RGBDiff.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=(args.arch in ['BNInception','InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception','InceptionV3'])),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception','InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception','InceptionV3'])),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        logger.info('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # NOTE(review): validate() is called with 4 args here but 5 args in the
    # training loop below -- confirm its signature has a default for the
    # logging handle.
    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, (epoch + 1) * len(train_loader), log_training)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)