Ejemplo n.º 1
0
def main_wflwe70():
    """Dump the WFLW-70 training set as landmark-annotated images.

    Loads the dataset selected by the global ``config``, draws each
    sample's target landmarks onto its image, and writes the result under
    ``data/wflwe70/xximages`` mirroring the dataset's relative filenames.
    """
    args = parse_args()
    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fixed typo: was ``cudnn.determinstic``, which silently set an unused
    # attribute instead of enabling deterministic mode.
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)  # unused here; kept for parity with sibling mains

    dataset_type = get_dataset(config)
    dataset = dataset_type(config, is_train=True)

    for i in range(len(dataset)):
        img, fname, meta = dataset[i]
        filename = osp.join('data/wflwe70/xximages', fname)
        # exist_ok avoids the check-then-create race of the old
        # osp.exists()/os.makedirs() pair.
        os.makedirs(osp.dirname(filename), exist_ok=True)
        center = meta['center']
        tpts = meta['tpts']

        # Points are multiplied by 4 before drawing -- presumably a
        # heatmap-to-image stride; confirm against the dataset class.
        # Circle radius grows with face position (center[0] proxy).
        for spt in tpts:
            img = cv2.circle(img, (4 * spt[0], 4 * spt[1]),
                             1 + center[0] // 400, (255, 0, 0))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(filename, img)
Ejemplo n.º 2
0
def main_cofw():
    """Export the COFW training set as a test CSV plus BGR images.

    Writes one CSV row per sample (``path,1,128,128`` followed by the
    target landmark coordinates as integers) and saves each image,
    converted RGB->BGR, under ``data/cofw/test``.
    """
    args = parse_args()
    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fixed typo: was ``cudnn.determinstic`` (a silent no-op attribute).
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)  # unused here; kept for parity with sibling mains

    # Removed a leftover ipdb.set_trace() debug breakpoint here.
    dataset_type = get_dataset(config)
    dataset = dataset_type(config, is_train=True)

    # Context manager guarantees the CSV is closed even if an iteration fails.
    with open('data/cofw/test.csv', 'w') as fp:
        for i in range(len(dataset)):
            img, image_path, meta = dataset[i]
            fname = osp.join('data/cofw/test', osp.basename(image_path))
            # NOTE(review): the literal '1,128,128' looks like fixed
            # scale/center placeholders -- confirm against the CSV consumer.
            fp.write('%s,1,128,128' % fname)
            tpts = meta['tpts']
            for j in range(tpts.shape[0]):
                fp.write(',%d,%d' % (tpts[j, 0], tpts[j, 1]))
            fp.write('\n')
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(fname, img)
Ejemplo n.º 3
0
def main():
    """Run inference on the crabs dataset and preview predictions.

    Loads a checkpoint, runs ``function.inference`` on the test split,
    then draws the predicted points onto each source image listed in the
    crabs test CSV, showing each result for one second, and finally saves
    the predictions to ``final_output_dir/predictions.pth``.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fixed typo: was ``cudnn.determinstic`` (a silent no-op attribute).
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)

    gpus = list(config.GPUS)
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    # Checkpoints may be a raw state dict or wrap it under 'state_dict'.
    state_dict = torch.load(args.model_file)
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
        model.load_state_dict(state_dict)
    else:
        model.module.load_state_dict(state_dict)

    dataset_type = get_dataset(config)

    test_loader = DataLoader(dataset=dataset_type(config, is_train=False),
                             batch_size=config.TEST.BATCH_SIZE_PER_GPU *
                             len(gpus),
                             shuffle=False,
                             num_workers=config.WORKERS,
                             pin_memory=config.PIN_MEMORY)

    nme, predictions = function.inference(config, test_loader, model)
    with open('../data/crabs/crabs_data_test.csv', 'r') as f:
        data = np.loadtxt(f, str, delimiter=",", skiprows=1)
    paths = data[:, 0]
    for index, path in enumerate(paths):
        img = cv2.imread("../data/crabs/images/{}".format(path))
        points = predictions[index].numpy()

        # Previously the inner loop reused (shadowed) ``index``; the outer
        # value was only needed to index ``predictions``, done above.
        for px in points:
            cv2.circle(img, tuple(px), 1, (0, 0, 255), 3, 8, 0)

        cv2.imshow("img", img)
        cv2.waitKey(1000)

    torch.save(predictions, os.path.join(final_output_dir, 'predictions.pth'))
Ejemplo n.º 4
0
def main():
    """Run test inference with a face-alignment model, or export to ONNX.

    With ``--onnx-export`` the freshly built model is exported and the
    function returns early. Otherwise weights are loaded from
    ``args.model_file`` (CPU fallback when ``GPUS[0] == -1``) and the
    predictions are saved to ``final_output_dir/predictions.pth``.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fixed typo: was ``cudnn.determinstic`` (a silent no-op attribute).
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)
    if args.onnx_export:
        # NOTE(review): if config.IMAGE_SIZE is an (h, w) pair this call
        # needs unpacking (``torch.rand(1, 3, *config.IMAGE_SIZE)``) --
        # confirm before relying on the export path.
        torch_out = torch.onnx._export(model,
                                       torch.rand(1, 3, config.IMAGE_SIZE),
                                       osp.join(final_output_dir,
                                                args.onnx_export),
                                       export_params=True)
        return

    gpus = list(config.GPUS)
    use_gpu = gpus[0] > -1  # convention here: GPUS = [-1] means CPU-only
    if use_gpu:
        model = nn.DataParallel(model, device_ids=gpus).cuda()

    # Load weights; checkpoints may be a raw state dict or a training
    # checkpoint wrapping it under 'state_dict'.
    if use_gpu:
        state_dict = torch.load(args.model_file)
    else:
        state_dict = torch.load(args.model_file, map_location='cpu')
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
        model.load_state_dict(state_dict)
    else:
        if use_gpu:
            model.module.load_state_dict(state_dict)
        else:
            model.load_state_dict(state_dict)

    dataset_type = get_dataset(config)
    dataset = dataset_type(config, is_train=False)

    test_loader = DataLoader(dataset=dataset,
                             batch_size=config.TEST.BATCH_SIZE_PER_GPU *
                             len(gpus),
                             shuffle=False,
                             num_workers=config.WORKERS,
                             pin_memory=config.PIN_MEMORY)

    # Removed a leftover ipdb.set_trace() debug breakpoint here.
    nme, predictions = function.inference(config, test_loader, model)

    torch.save(predictions, os.path.join(final_output_dir, 'predictions.pth'))
def main():
    """Evaluate the face-alignment model named by ``config.MODEL.NAME``.

    Loads a full serialized model from ``args.model_file`` if given,
    otherwise a state dict from ``final_state.pth`` in the output
    directory, runs inference on the test split, and saves predictions to
    ``predictions.pth``.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fixed typo: was ``cudnn.determinstic`` (a silent no-op attribute).
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    # getattr-based lookup replaces the previous eval() string dispatch --
    # same attribute resolution, no code execution from a config string.
    model = getattr(models, config.MODEL.NAME).get_face_alignment_net(
        config, is_train=True)

    gpus = list(config.GPUS)
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    if args.model_file:
        logger.info('=> loading model from {}'.format(args.model_file))
        # The checkpoint is a full serialized module; extract its weights.
        model_state = torch.load(args.model_file)
        model.module.load_state_dict(model_state.state_dict())
    else:
        # Fall back to the state dict written at the end of training.
        model_state_file = os.path.join(final_output_dir, 'final_state.pth')
        logger.info('=> loading model from {}'.format(model_state_file))
        model_state = torch.load(model_state_file)
        model.module.load_state_dict(model_state)

    dataset_type = get_dataset(config)

    test_loader = DataLoader(dataset=dataset_type(config, is_train=False),
                             batch_size=config.TEST.BATCH_SIZE_PER_GPU *
                             len(gpus),
                             shuffle=False,
                             num_workers=config.WORKERS,
                             pin_memory=config.PIN_MEMORY)

    nme, predictions = function.inference(config, test_loader, model)

    torch.save(predictions, os.path.join(final_output_dir, 'predictions.pth'))
Ejemplo n.º 6
0
    def eval_text(self):
        """Interactively render an input text line in sampled handwriting styles.

        Repeatedly prompts for a text string (empty input exits). For each
        prompt, a random batch of real IAM word images supplies the style
        references; the text is generated once per style and real/generated
        pairs are displayed with matplotlib.
        """
        self.set_mode('eval')

        # Random IAM word batches provide the style reference images.
        tst_loader = DataLoader(get_dataset('iam_word',
                                            self.opt.training.dset_split),
                                batch_size=self.opt.test.nrow,
                                shuffle=True,
                                collate_fn=self.collect_fn)

        def get_space_index(text):
            # Character positions of spaces; used below to blank the gap
            # columns in the generated strip.
            idxs = []
            for i, ch in enumerate(text):
                if ch == ' ':
                    idxs.append(i)
            return idxs

        with torch.no_grad():
            nrow = self.opt.test.nrow
            while True:
                text = input('input text: ')
                if len(text) == 0:
                    break  # empty line ends the interactive loop

                batch = next(iter(tst_loader))
                imgs, img_lens, lbs, lb_lens, wids = batch
                real_imgs, real_img_lens = imgs.to(self.device), img_lens.to(
                    self.device)
                # Encode the prompt once, then tile it across the nrow styles.
                fake_lbs = self.label_converter.encode(text)
                fake_lbs = torch.LongTensor(fake_lbs)
                fake_lb_lens = torch.IntTensor([len(text)])

                fake_lbs = fake_lbs.repeat(nrow, 1).to(self.device)
                fake_lb_lens = fake_lb_lens.repeat(nrow, ).to(self.device)
                # Final style vector fed to G is [noise | encoder(style image)].
                enc_styles = self.models.E(real_imgs, real_img_lens,
                                           self.models.W.cnn_backbone)
                noises = torch.randn((nrow, self.noise_dim)).to(self.device)
                enc_styles = torch.cat([noises, enc_styles], dim=-1)

                # Invert and rescale for display; presumably model images lie
                # in [0, 1] with 1 = ink -- TODO confirm.
                real_imgs = (1 - real_imgs).squeeze().cpu().numpy() * 127
                gen_imgs = self.models.G(enc_styles, fake_lbs, fake_lb_lens)
                gen_imgs = (1 - gen_imgs).squeeze().cpu().numpy() * 127
                # Whiten a 16px-wide column per space character; assumes a
                # fixed 16px per character in the generated strip -- confirm.
                space_indexs = get_space_index(text)
                for idx in space_indexs:
                    gen_imgs[:, :, idx * 16:(idx + 1) * 16] = 255

                plt.figure()

                # Each style row: real image on top, generated text below.
                for i in range(nrow):
                    plt.subplot(nrow * 2, 1, i * 2 + 1)
                    plt.imshow(real_imgs[i], cmap='gray')
                    plt.axis('off')
                    plt.subplot(nrow * 2, 1, i * 2 + 2)
                    plt.imshow(gen_imgs[i], cmap='gray')
                    plt.axis('off')
                plt.tight_layout()
                plt.show()
def main():
    """Run inference with a fully serialized model and draw one example.

    ``args.model_file`` holds a pickled nn.Module (not a state dict); it
    is loaded, wrapped in DataParallel, evaluated on the test split, and
    the first sample's predicted points are drawn onto
    ``data/wflw/images/my3.jpg`` and written to ``out.png``.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fixed typo: was ``cudnn.determinstic`` (a silent no-op attribute).
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()

    gpus = list(config.GPUS)

    # Removed the dead model built from config (and its DataParallel wrap):
    # it was immediately overwritten by the fully serialized checkpoint.
    model = torch.load(args.model_file)
    model.eval()
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    dataset_type = get_dataset(config)

    test_loader = DataLoader(dataset=dataset_type(config, is_train=False),
                             batch_size=config.TEST.BATCH_SIZE_PER_GPU *
                             len(gpus),
                             shuffle=False,
                             num_workers=config.WORKERS,
                             pin_memory=config.PIN_MEMORY)

    nme, predictions = function.inference(config, test_loader, model)

    import cv2
    img = cv2.imread('data/wflw/images/my3.jpg')
    print(predictions, predictions.shape)
    for item in predictions[0]:
        # cv2.circle requires integer pixel coordinates; predictions are
        # tensor scalars, so cast explicitly.
        cv2.circle(img, (int(item[0]), int(item[1])), 3, (0, 0, 255), -1)
    cv2.imwrite('out.png', img)

    torch.save(predictions, os.path.join(final_output_dir, 'predictions.pth'))
Ejemplo n.º 8
0
 def validate(self, guided=True):
     """Compute FID/KID between real validation images and generated ones.

     Args:
         guided: forwarded to ``image_generator``; presumably toggles
             style-guided generation -- TODO confirm against that method.

     Returns:
         The result of ``calculate_kid_fid`` (FID/KID scores).
     """
     self.set_mode('eval')
     # Fall back to the training dataset name when no validation name is set.
     dset_name = self.opt.valid.dset_name if self.opt.valid.dset_name \
                 else self.opt.dataset
     dset = get_dataset(dset_name, self.opt.valid.dset_split)
     dloader = DataLoader(dset,
                          collate_fn=self.collect_fn,
                          batch_size=self.opt.valid.batch_size,
                          shuffle=False,
                          num_workers=4)
     # style images are resized
     # NOTE(review): str.strip('_org') removes any trailing characters from
     # the set {_, o, r, g}, not the literal suffix '_org' -- works for
     # names like 'iam_org' but would over-strip e.g. 'color'. Confirm the
     # possible dataset names before relying on this.
     source_dloader = DataLoader(get_dataset(
         self.opt.valid.dset_name.strip('_org'), self.opt.valid.dset_split),
                                 collate_fn=self.collect_fn,
                                 batch_size=self.opt.valid.batch_size,
                                 shuffle=False,
                                 num_workers=4)
     generator = self.image_generator(source_dloader, guided)
     fid_kid = calculate_kid_fid(self.opt.valid, dloader, generator,
                                 self.max_valid_image_width, self.device)
     return fid_kid
Ejemplo n.º 9
0
def main():
    """Test a face-alignment model, handling GPU and CPU-only machines.

    Loads the checkpoint from ``args.model_file`` (mapped to CPU storage
    when CUDA is unavailable) and runs ``function.test`` over the test
    split.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fixed typo: was ``cudnn.determinstic`` (a silent no-op attribute).
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)

    gpus = list(config.GPUS)
    if torch.cuda.is_available():
        model = nn.DataParallel(model, device_ids=gpus).cuda()

        # load model
        state_dict = torch.load(args.model_file)
    else:
        # CPU-only: keep tensors on whatever storage they deserialize to.
        state_dict = torch.load(args.model_file,
                                map_location=lambda storage, loc: storage)

    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
        model.load_state_dict(state_dict)
    else:
        try:
            # The checkpoint may be a whole serialized module.
            model.module.load_state_dict(state_dict.state_dict())
        except AttributeError:
            # Plain state dict (or bare, non-DataParallel model). The
            # original left a no-op ``state_dict`` expression with a
            # "remove first seven" note -- i.e. stripping a leading
            # 'module.' prefix was intended but never implemented. Keys
            # are loaded as-is; NOTE(review): confirm the checkpoint was
            # saved without DataParallel before relying on this path.
            model.load_state_dict(state_dict)

    dataset_type = get_dataset(config)

    test_loader = DataLoader(dataset=dataset_type(config, is_train=False),
                             batch_size=config.TEST.BATCH_SIZE_PER_GPU *
                             len(gpus),
                             shuffle=False,
                             num_workers=config.WORKERS,
                             pin_memory=config.PIN_MEMORY)

    function.test(config, test_loader, model, args.model_file)  #### testing
Ejemplo n.º 10
0
    def eval_style(self):
        """Interactively render each word of an input line, one style per row.

        Prompts for text (empty input exits), splits it on spaces, and for
        each of ``nrow`` random IAM style images generates every word in
        that style. Each figure row shows the real style image followed by
        the generated words.
        """
        self.set_mode('eval')

        # Random IAM word batches provide the style reference images.
        tst_loader = DataLoader(get_dataset('iam_word',
                                            self.opt.training.dset_split),
                                batch_size=self.opt.test.nrow,
                                shuffle=True,
                                collate_fn=self.collect_fn)

        with torch.no_grad():
            nrow, ncol = self.opt.test.nrow, 2
            while True:
                text = input('input text: ')
                if len(text) == 0:
                    break  # empty line ends the interactive loop

                texts = text.split(' ')
                ncol = len(texts)  # one generated column per word
                batch = next(iter(tst_loader))
                imgs, img_lens, lbs, lb_lens, wids = batch
                real_imgs, real_img_lens = imgs.to(self.device), img_lens.to(
                    self.device)
                # The converter apparently returns a flat label list for a
                # single word but (padded labels, lengths) for several --
                # hence the two branches; confirm against label_converter.
                if len(texts) == 1:
                    fake_lbs = self.label_converter.encode(texts)
                    fake_lbs = torch.LongTensor(fake_lbs)
                    fake_lb_lens = torch.IntTensor([len(texts[0])])
                else:
                    fake_lbs, fake_lb_lens = self.label_converter.encode(texts)

                fake_lbs = fake_lbs.repeat(nrow, 1).to(self.device)
                fake_lb_lens = fake_lb_lens.repeat(nrow, ).to(self.device)
                # Tile each row's style vector over its ncol words so every
                # word in a row shares that row's style; same for the noise.
                enc_styles = self.models.E(real_imgs, real_img_lens, self.models.W.cnn_backbone).unsqueeze(1).\
                                repeat(1, ncol, 1).view(nrow * ncol, self.opt.EncModel.style_dim)
                noises = torch.randn((nrow, self.noise_dim)).unsqueeze(1).\
                                repeat(1, ncol, 1).view(nrow * ncol, self.noise_dim).to(self.device)
                enc_styles = torch.cat([noises, enc_styles], dim=-1)

                gen_imgs = self.models.G(enc_styles, fake_lbs, fake_lb_lens)
                # Invert and rescale for display; presumably model images lie
                # in [0, 1] with 1 = ink -- TODO confirm.
                gen_imgs = (1 - gen_imgs).squeeze().cpu().numpy() * 127
                real_imgs = (1 - real_imgs).squeeze().cpu().numpy() * 127
                plt.figure()
                # Grid layout: column 0 is the real style image, columns
                # 1..ncol are that style's generated words.
                for i in range(nrow):
                    plt.subplot(nrow, 1 + ncol, i * (1 + ncol) + 1)
                    plt.imshow(real_imgs[i], cmap='gray')
                    plt.axis('off')
                    for j in range(ncol):
                        plt.subplot(nrow, 1 + ncol, i * (1 + ncol) + 2 + j)
                        plt.imshow(gen_imgs[i * ncol + j], cmap='gray')
                        plt.axis('off')
                plt.tight_layout()
                plt.show()
Ejemplo n.º 11
0
def main():
    """Evaluate a ResNet-101 landmark regressor on the test split.

    Replaces the torchvision classifier head with a linear layer sized to
    ``config.MODEL.OUTPUT_SIZE[0]``, loads the checkpoint, runs inference,
    and saves predictions to ``final_output_dir/predictions.pth``.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fixed typo: was ``cudnn.determinstic`` (a silent no-op attribute).
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()

    model = torchvision.models.resnet101(pretrained=config.MODEL.PRETRAINED,
                                         progress=True)
    # Swap the ImageNet classifier head for a landmark-regression head.
    num_ftrs = model.fc.in_features
    model.fc = torch.nn.Linear(num_ftrs, config.MODEL.OUTPUT_SIZE[0])

    gpus = list(config.GPUS)
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    # Checkpoints may be a raw state dict or wrap it under 'state_dict'.
    state_dict = torch.load(args.model_file)
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
        model.load_state_dict(state_dict)
    else:
        model.module.load_state_dict(state_dict)

    dataset_type = get_dataset(config)

    test_loader = DataLoader(dataset=dataset_type(config, is_train=False),
                             batch_size=config.TEST.BATCH_SIZE_PER_GPU *
                             len(gpus),
                             shuffle=False,
                             num_workers=config.WORKERS,
                             pin_memory=config.PIN_MEMORY)

    predictions = function.inference(config, test_loader, model)

    torch.save(predictions, os.path.join(final_output_dir, 'predictions.pth'))
Ejemplo n.º 12
0
def main_300w():
    """Export 300-W test annotations as a 70-landmark CSV with debug images.

    For each sample, the 68 ground-truth points plus two synthesized eye
    centers (means of points 36-41 and 42-47) are written to the CSV and
    drawn onto a copy of the image saved under ``data/300w/xximages``.
    """
    args = parse_args()
    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fixed typo: was ``cudnn.determinstic`` (a silent no-op attribute).
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)  # unused here; kept for parity with sibling mains

    dataset_type = get_dataset(config)
    dataset = dataset_type(config, is_train=False)

    # Context manager guarantees the CSV is closed even if an iteration fails.
    with open('data/300w/face_landmarks70_300w_test.csv', 'w') as fp:
        for i in range(len(dataset)):
            img, fname, meta = dataset[i]
            filename = osp.join('data/300w/xximages', fname)
            # exist_ok avoids the check-then-create race of the old
            # osp.exists()/os.makedirs() pair.
            os.makedirs(osp.dirname(filename), exist_ok=True)
            scale = meta['scale']
            center = meta['center']
            tpts = meta['tpts']

            # 68 original landmarks plus the mean of each eye contour as
            # synthetic eye-center points 69/70.
            selpts = list(tpts[:68])
            selpts.append(tpts[36:42].mean(0))
            selpts.append(tpts[42:48].mean(0))

            fp.write('%s,%.2f,%.1f,%.1f' % (fname, scale, center[0],
                                            center[1]))
            for spt in selpts:
                img = cv2.circle(img, (spt[0], spt[1]), 1 + center[0] // 400,
                                 (255, 0, 0))
                fp.write(',%f,%f' % (spt[0], spt[1]))
            fp.write('\n')
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(filename, img)
Ejemplo n.º 13
0
def main(hparams):
    """Train a deep-kernel-learning classifier (WideResNet features + GP head).

    Trains with the variational ELBO for 200 epochs via Ignite engines,
    logging per-epoch train/test metrics and periodic OoD (SVHN) scores to
    TensorBoard, then writes final metrics to ``results.json`` and saves
    the model and likelihood state dicts under the results directory.
    """
    results_dir = get_results_directory(hparams.output_dir)
    writer = SummaryWriter(log_dir=str(results_dir))

    ds = get_dataset(hparams.dataset, root=hparams.data_root)
    input_size, num_classes, train_dataset, test_dataset = ds

    hparams.seed = set_seed(hparams.seed)

    # Default: one inducing point per class.
    if hparams.n_inducing_points is None:
        hparams.n_inducing_points = num_classes

    print(f"Training with {hparams}")
    hparams.save(results_dir / "hparams.json")

    if hparams.ard:
        # Hardcoded to WRN output size
        ard = 640
    else:
        ard = None

    feature_extractor = WideResNet(
        spectral_normalization=hparams.spectral_normalization,
        dropout_rate=hparams.dropout_rate,
        coeff=hparams.coeff,
        n_power_iterations=hparams.n_power_iterations,
        batchnorm_momentum=hparams.batchnorm_momentum,
    )

    # Seed the GP from feature statistics of the (untrained) extractor.
    initial_inducing_points, initial_lengthscale = initial_values_for_GP(
        train_dataset, feature_extractor, hparams.n_inducing_points
    )

    gp = GP(
        num_outputs=num_classes,
        initial_lengthscale=initial_lengthscale,
        initial_inducing_points=initial_inducing_points,
        separate_inducing_points=hparams.separate_inducing_points,
        kernel=hparams.kernel,
        ard=ard,
        lengthscale_prior=hparams.lengthscale_prior,
    )

    model = DKL_GP(feature_extractor, gp)
    model = model.cuda()

    likelihood = SoftmaxLikelihood(num_classes=num_classes, mixing_weights=False)
    likelihood = likelihood.cuda()

    elbo_fn = VariationalELBO(likelihood, gp, num_data=len(train_dataset))

    # One shared learning rate across all three parameter groups.
    parameters = [
        {"params": feature_extractor.parameters(), "lr": hparams.learning_rate},
        {"params": gp.parameters(), "lr": hparams.learning_rate},
        {"params": likelihood.parameters(), "lr": hparams.learning_rate},
    ]

    optimizer = torch.optim.SGD(
        parameters, momentum=0.9, weight_decay=hparams.weight_decay
    )

    milestones = [60, 120, 160]

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=0.2
    )

    def step(engine, batch):
        # One optimization step: minimize the negative ELBO.
        model.train()
        likelihood.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        y_pred = model(x)
        elbo = -elbo_fn(y_pred, y)

        elbo.backward()
        optimizer.step()

        return elbo.item()

    def eval_step(engine, batch):
        # Forward pass only; returns (prediction, target) for metrics.
        model.eval()
        likelihood.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        with torch.no_grad():
            y_pred = model(x)

        return y_pred, y

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    # ``metric`` is rebound below for each attached metric; each attach()
    # registers it on its engine, so the rebinding is safe.
    metric = Average()
    metric.attach(trainer, "elbo")

    def output_transform(output):
        y_pred, y = output

        # Sample softmax values independently for classification at test time
        y_pred = y_pred.to_data_independent_dist()

        # The mean here is over likelihood samples
        y_pred = likelihood(y_pred).probs.mean(0)

        return y_pred, y

    metric = Accuracy(output_transform=output_transform)
    metric.attach(evaluator, "accuracy")

    metric = Loss(lambda y_pred, y: -elbo_fn(y_pred, y))
    metric.attach(evaluator, "elbo")

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hparams.batch_size,
        shuffle=True,
        drop_last=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=512, shuffle=False, **kwargs
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        # Per-epoch logging: train ELBO, spectral norms, lengthscales,
        # periodic OoD scores, and a full test-set evaluation.
        metrics = trainer.state.metrics
        elbo = metrics["elbo"]

        print(f"Train - Epoch: {trainer.state.epoch} ELBO: {elbo:.2f} ")
        writer.add_scalar("Likelihood/train", elbo, trainer.state.epoch)

        if hparams.spectral_normalization:
            for name, layer in model.feature_extractor.named_modules():
                if isinstance(layer, torch.nn.Conv2d):
                    writer.add_scalar(
                        f"sigma/{name}", layer.weight_sigma, trainer.state.epoch
                    )

        if not hparams.ard:
            # Otherwise it's too much to submit to tensorboard
            length_scales = model.gp.covar_module.base_kernel.lengthscale.squeeze()
            for i in range(length_scales.shape[0]):
                writer.add_scalar(
                    f"length_scale/{i}", length_scales[i], trainer.state.epoch
                )

        # OoD evaluation is expensive: only late in training, every 5 epochs.
        if trainer.state.epoch > 150 and trainer.state.epoch % 5 == 0:
            _, auroc, aupr = get_ood_metrics(
                hparams.dataset, "SVHN", model, likelihood, hparams.data_root
            )
            print(f"OoD Metrics - AUROC: {auroc}, AUPR: {aupr}")
            writer.add_scalar("OoD/auroc", auroc, trainer.state.epoch)
            writer.add_scalar("OoD/auprc", aupr, trainer.state.epoch)

        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        elbo = metrics["elbo"]

        print(
            f"Test - Epoch: {trainer.state.epoch} "
            f"Acc: {acc:.4f} "
            f"ELBO: {elbo:.2f} "
        )

        writer.add_scalar("Likelihood/test", elbo, trainer.state.epoch)
        writer.add_scalar("Accuracy/test", acc, trainer.state.epoch)

        # Scheduler advances once per epoch, inside the epoch-end handler.
        scheduler.step()

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    trainer.run(train_loader, max_epochs=200)

    # Done training - time to evaluate
    results = {}

    evaluator.run(train_loader)
    train_acc = evaluator.state.metrics["accuracy"]
    train_elbo = evaluator.state.metrics["elbo"]
    results["train_accuracy"] = train_acc
    results["train_elbo"] = train_elbo

    evaluator.run(test_loader)
    test_acc = evaluator.state.metrics["accuracy"]
    test_elbo = evaluator.state.metrics["elbo"]
    results["test_accuracy"] = test_acc
    results["test_elbo"] = test_elbo

    # Final OoD evaluation against SVHN.
    _, auroc, aupr = get_ood_metrics(
        hparams.dataset, "SVHN", model, likelihood, hparams.data_root
    )
    results["auroc_ood_svhn"] = auroc
    results["aupr_ood_svhn"] = aupr

    print(f"Test - Accuracy {results['test_accuracy']:.4f}")

    results_json = json.dumps(results, indent=4, sort_keys=True)
    (results_dir / "results.json").write_text(results_json)

    torch.save(model.state_dict(), results_dir / "model.pt")
    torch.save(likelihood.state_dict(), results_dir / "likelihood.pt")

    writer.close()
def main():
    """Train a face-alignment network end to end.

    Builds train/val loaders from the global ``config``, trains with an MSE
    heatmap loss under ``nn.DataParallel``, validates every epoch, and
    checkpoints the best model by NME (normalized mean error).
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fix: the attribute is spelled `deterministic`; the original typo
    # (`determinstic`) created a new attribute and silently did nothing.
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)

    dataset_type = get_dataset(config)
    train_data = dataset_type(config, is_train=True)
    # Batch size scales with the number of GPUs used by DataParallel.
    train_loader = DataLoader(dataset=train_data,
                              batch_size=config.TRAIN.BATCH_SIZE_PER_GPU *
                              len(gpus),
                              shuffle=config.TRAIN.SHUFFLE,
                              num_workers=config.WORKERS,
                              pin_memory=config.PIN_MEMORY)

    val_loader = DataLoader(dataset=dataset_type(config, is_train=False),
                            batch_size=config.TEST.BATCH_SIZE_PER_GPU *
                            len(gpus),
                            shuffle=False,
                            num_workers=config.WORKERS,
                            pin_memory=config.PIN_MEMORY)

    model = models.get_face_alignment_net(config)

    # TensorBoard writer plus running step counters consumed by
    # function.train / function.validate.
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    model = nn.DataParallel(model, device_ids=gpus).cuda()

    # Fix: `size_average` is deprecated; reduction='mean' is the exact
    # modern equivalent of size_average=True.
    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    optimizer = utils.get_optimizer(config, model)

    best_nme = 100
    last_epoch = config.TRAIN.BEGIN_EPOCH

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_epoch - 1)

    if config.TRAIN.RESUME:
        # `latest.pth` is expected to be a symlink to the newest checkpoint.
        model_state_file = os.path.join(final_output_dir, 'latest.pth')
        if os.path.islink(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_nme = checkpoint['best_nme']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found")

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        function.train(config, train_loader, model, criterion, optimizer,
                       epoch, writer_dict)

        # evaluate
        nme, predictions = function.validate(config, val_loader, model,
                                             criterion, epoch, writer_dict)

        # Fix: step the scheduler AFTER the epoch's optimizer steps
        # (required ordering since PyTorch 1.1); stepping at loop start
        # skipped the initial learning rate entirely.
        lr_scheduler.step()

        is_best = nme < best_nme
        best_nme = min(nme, best_nme)

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        print("best:", is_best)
        utils.save_checkpoint(
            {
                # Fix: persist the state dict, not the nn.Module itself —
                # the resume path above calls model.load_state_dict() on
                # this entry, which requires a dict of tensors.
                "state_dict": model.state_dict(),
                "epoch": epoch + 1,
                "best_nme": best_nme,
                "optimizer": optimizer.state_dict(),
            }, predictions, is_best, final_output_dir,
            'checkpoint_{}.pth'.format(epoch))

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    # Unwrap DataParallel so the saved keys carry no `module.` prefix.
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Ejemplo n.º 15
0
    def train(self):
        """Run the full adversarial training loop.

        Alternates between (a) optimizing the discriminator D, text
        recognizer R and writer identifier W on real and generated images,
        and (b) optimizing the generator G and style encoder E. Generator
        losses (CTC, writer-ID, latent-style reconstruction, KL) are scaled
        relative to the adversarial loss via gradient-std ratios. The loop
        periodically prints/logs losses, dumps sample images, checkpoints
        every epoch, and validates with FID/KID to track the best model.
        """
        self.info()

        # KL divergence between N(mu, exp(logvar)) and a standard normal,
        # summed over latent dims and averaged over the batch.
        def KLloss(mu, logvar):
            return torch.mean(
                -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=1),
                dim=0)

        opt = self.opt
        # Reusable latent (z) and word-index (y) samplers, one pair sized
        # for training batches and one for evaluation batches.
        self.z = prepare_z_dist(opt.training.batch_size,
                                opt.GenModel.style_dim,
                                self.device,
                                seed=self.opt.seed)
        self.y = prepare_y_dist(opt.training.batch_size,
                                len(self.lexicon),
                                self.device,
                                seed=self.opt.seed)

        self.eval_z = prepare_z_dist(opt.training.eval_batch_size,
                                     opt.GenModel.style_dim,
                                     self.device,
                                     seed=self.opt.seed)
        self.eval_y = prepare_y_dist(opt.training.eval_batch_size,
                                     len(self.lexicon),
                                     self.device,
                                     seed=self.opt.seed)

        # Training loader plus two test loaders (validation split and a
        # second view of the training split) used elsewhere for evaluation.
        self.train_loader = DataLoader(get_dataset(opt.dataset,
                                                   opt.training.dset_split),
                                       batch_size=opt.training.batch_size,
                                       shuffle=True,
                                       collate_fn=self.collect_fn,
                                       num_workers=4,
                                       drop_last=True)

        self.tst_loader = DataLoader(
            get_dataset(opt.dataset, opt.valid.dset_split),
            batch_size=opt.training.eval_batch_size // 2,
            shuffle=True,
            collate_fn=self.collect_fn)

        self.tst_loader2 = DataLoader(
            get_dataset(opt.dataset, opt.training.dset_split),
            batch_size=opt.training.eval_batch_size // 2,
            shuffle=True,
            collate_fn=self.collect_fn,
            num_workers=4)

        # Two Adam optimizers: G covers generator + encoder; D covers
        # discriminator + recognizer + writer identifier.
        self.optimizers = Munch(
            G=torch.optim.Adam(chain(self.models.G.parameters(),
                                     self.models.E.parameters()),
                               lr=opt.training.lr,
                               betas=(opt.training.adam_b1,
                                      opt.training.adam_b2)),
            D=torch.optim.Adam(chain(self.models.D.parameters(),
                                     self.models.R.parameters(),
                                     self.models.W.parameters()),
                               lr=opt.training.lr,
                               betas=(opt.training.adam_b1,
                                      opt.training.adam_b2)),
        )

        self.lr_schedulers = Munch(G=get_scheduler(self.optimizers.G,
                                                   opt.training),
                                   D=get_scheduler(self.optimizers.D,
                                                   opt.training))

        self.averager_meters = AverageMeterManager([
            'adv_loss', 'fake_disc_loss', 'real_disc_loss', 'info_loss',
            'fake_ctc_loss', 'real_ctc_loss', 'fake_wid_loss', 'real_wid_loss',
            'kl_loss', 'gp_ctc', 'gp_info', 'gp_wid'
        ])
        device = self.device

        # Width downsampling factor from image to CTC time steps —
        # presumably matches R's CNN stride; TODO confirm against models.R.
        ctc_len_scale = 8
        best_kid = np.inf
        iter_count = 0
        for epoch in range(1, self.opt.training.epochs):
            for i, (imgs, img_lens, lbs, lb_lens,
                    wids) in enumerate(self.train_loader):
                #############################
                # Prepare inputs & Network Forward
                #############################
                self.set_mode('train')
                real_imgs, real_img_lens, real_wids = imgs.to(
                    device), img_lens.to(device), wids.to(device)
                real_lbs, real_lb_lens = lbs.to(device), lb_lens.to(device)

                #############################
                # Optimizing Recognizer & Writer Identifier & Discriminator
                #############################
                self.optimizers.D.zero_grad()
                # Freeze G/E, unfreeze R/D/W for the critic step.
                set_requires_grad([self.models.G, self.models.E], False)
                set_requires_grad(
                    [self.models.R, self.models.D, self.models.W], True)

                ### Compute CTC loss for real samples###
                real_ctc = self.models.R(real_imgs)
                real_ctc_lens = real_img_lens // ctc_len_scale
                real_ctc_loss = self.ctc_loss(real_ctc, real_lbs,
                                              real_ctc_lens, real_lb_lens)
                self.averager_meters.update('real_ctc_loss',
                                            real_ctc_loss.item())

                # Writer-identification loss on real samples.
                real_wid_logits = self.models.W(real_imgs, real_img_lens)
                real_wid_loss = self.classify_loss(real_wid_logits, real_wids)
                self.averager_meters.update('real_wid_loss',
                                            real_wid_loss.item())

                # Generate fakes without tracking gradients: one batch from
                # random style z, one from styles encoded off real images.
                with torch.no_grad():
                    self.y.sample_()
                    sampled_words = idx_to_words(
                        self.y, self.lexicon,
                        self.opt.training.capitalize_ratio)
                    fake_lbs, fake_lb_lens = self.label_converter.encode(
                        sampled_words)
                    fake_lbs, fake_lb_lens = fake_lbs.to(
                        device).detach(), fake_lb_lens.to(device).detach()

                    self.z.sample_()
                    fake_imgs = self.models.G(self.z, fake_lbs, fake_lb_lens)

                    enc_styles, _, _ = self.models.E(
                        real_imgs,
                        real_img_lens,
                        self.models.W.cnn_backbone,
                        vae_mode=True)
                    # Pad encoded styles with fresh noise up to the
                    # generator's full style dimensionality.
                    noises = torch.randn(
                        (real_imgs.size(0), self.opt.GenModel.style_dim -
                         self.opt.EncModel.style_dim)).float().to(device)
                    enc_z = torch.cat([noises, enc_styles], dim=-1)
                    style_imgs = self.models.G(enc_z, fake_lbs, fake_lb_lens)

                    # Both fake batches share labels, so lengths repeat 2x;
                    # image width is proportional to label length.
                    cat_fake_imgs = torch.cat([fake_imgs, style_imgs], dim=0)
                    cat_fake_lb_lens = fake_lb_lens.repeat(2, ).detach()
                    cat_fake_img_lens = cat_fake_lb_lens * self.opt.char_width

                ### Compute discriminative loss for real & fake samples ###
                # Hinge loss: push fake scores below -1, real scores above +1.
                fake_disc = self.models.D(cat_fake_imgs.detach(),
                                          cat_fake_img_lens, cat_fake_lb_lens)
                fake_disc_loss = torch.mean(F.relu(1.0 + fake_disc))

                real_disc = self.models.D(real_imgs, real_img_lens,
                                          real_lb_lens)
                real_disc_loss = torch.mean(F.relu(1.0 - real_disc))

                disc_loss = real_disc_loss + fake_disc_loss
                self.averager_meters.update('real_disc_loss',
                                            real_disc_loss.item())
                self.averager_meters.update('fake_disc_loss',
                                            fake_disc_loss.item())

                (real_ctc_loss + disc_loss + real_wid_loss).backward()
                self.optimizers.D.step()

                #############################
                # Optimizing Generator
                #############################
                # G/E are updated once every `num_critic_train` critic steps.
                if iter_count % self.opt.training.num_critic_train == 0:
                    self.optimizers.G.zero_grad()
                    set_requires_grad(
                        [self.models.D, self.models.R, self.models.W], False)
                    set_requires_grad([self.models.G, self.models.E], True)

                    ##########################
                    # Prepare Fake Inputs
                    ##########################
                    # Re-sample words/latents WITH gradients this time.
                    self.y.sample_()
                    sampled_words = idx_to_words(
                        self.y, self.lexicon,
                        self.opt.training.capitalize_ratio)
                    fake_lbs, fake_lb_lens = self.label_converter.encode(
                        sampled_words)
                    fake_lbs, fake_lb_lens = fake_lbs.to(
                        device).detach(), fake_lb_lens.to(device).detach()
                    fake_img_lens = fake_lb_lens * self.opt.char_width

                    self.z.sample_()
                    fake_imgs = self.models.G(self.z, fake_lbs, fake_lb_lens)

                    # VAE-style encoding of real images; mu/logvar feed the
                    # KL term below.
                    enc_styles, enc_mu, enc_logvar = self.models.E(
                        real_imgs,
                        real_img_lens,
                        self.models.W.cnn_backbone,
                        vae_mode=True)
                    noises = torch.randn(
                        (real_imgs.size(0), self.opt.GenModel.style_dim -
                         self.opt.EncModel.style_dim)).float().to(device)
                    enc_z = torch.cat([noises, enc_styles], dim=-1)
                    style_imgs = self.models.G(enc_z, fake_lbs, fake_lb_lens)
                    style_img_lens = fake_lb_lens * self.opt.char_width

                    ### Concatenating all generated images in a batch ###
                    cat_fake_imgs = torch.cat([fake_imgs, style_imgs], dim=0)
                    cat_fake_lbs = fake_lbs.repeat(2, 1).detach()
                    cat_fake_lb_lens = fake_lb_lens.repeat(2, ).detach()
                    cat_fake_img_lens = cat_fake_lb_lens * self.opt.char_width

                    ###################################################
                    # Calculating G Losses
                    ####################################################
                    ### deal with fake samples ###
                    ### Compute Adversarial loss ###
                    cat_fake_disc = self.models.D(cat_fake_imgs,
                                                  cat_fake_img_lens,
                                                  cat_fake_lb_lens)
                    adv_loss = -torch.mean(cat_fake_disc)

                    ### CTC Auxiliary loss ###
                    # Generated text must remain readable by the recognizer.
                    cat_fake_ctc = self.models.R(cat_fake_imgs)
                    cat_fake_ctc_lens = cat_fake_img_lens // ctc_len_scale
                    fake_ctc_loss = self.ctc_loss(cat_fake_ctc, cat_fake_lbs,
                                                  cat_fake_ctc_lens,
                                                  cat_fake_lb_lens)

                    ### Latent Style Reconstruction ###
                    # E must recover (the tail of) the z that produced the
                    # random-style fakes; L1 in latent space.
                    styles = self.models.E(fake_imgs, fake_img_lens,
                                           self.models.W.cnn_backbone)
                    info_loss = torch.mean(
                        torch.abs(
                            styles -
                            self.z[:, -self.opt.EncModel.style_dim:].detach()))

                    ### Writer Identify Loss ###
                    # Style-conditioned fakes should classify as the writer
                    # of the real images the style was taken from.
                    recn_wid_logits = self.models.W(style_imgs, style_img_lens)
                    fake_wid_loss = self.classify_loss(recn_wid_logits,
                                                       real_wids)

                    ### KL-Divergence Loss ###
                    kl_loss = KLloss(enc_mu, enc_logvar)

                    ### Gradient balance ###
                    # Scale each auxiliary loss so its gradient magnitude is
                    # comparable to the adversarial gradient (ratio of stds,
                    # detached, +1 as a floor).
                    grad_fake_adv = torch.autograd.grad(adv_loss,
                                                        cat_fake_imgs,
                                                        create_graph=True,
                                                        retain_graph=True)[0]
                    grad_fake_OCR = torch.autograd.grad(fake_ctc_loss,
                                                        cat_fake_ctc,
                                                        create_graph=True,
                                                        retain_graph=True)[0]
                    grad_fake_info = torch.autograd.grad(info_loss,
                                                         fake_imgs,
                                                         create_graph=True,
                                                         retain_graph=True)[0]
                    grad_fake_wid = torch.autograd.grad(fake_wid_loss,
                                                        recn_wid_logits,
                                                        create_graph=True,
                                                        retain_graph=True)[0]

                    std_grad_adv = torch.std(grad_fake_adv)
                    gp_ctc = torch.div(
                        std_grad_adv,
                        torch.std(grad_fake_OCR) + 1e-8).detach() + 1
                    gp_info = torch.div(
                        std_grad_adv,
                        torch.std(grad_fake_info) + 1e-8).detach() + 1
                    gp_wid = torch.div(
                        std_grad_adv,
                        torch.std(grad_fake_wid) + 1e-8).detach() + 1
                    self.averager_meters.update('gp_ctc', gp_ctc.item())
                    self.averager_meters.update('gp_info', gp_info.item())
                    self.averager_meters.update('gp_wid', gp_wid.item())

                    g_loss = 2 * adv_loss + \
                             gp_ctc * fake_ctc_loss + \
                             gp_info * info_loss + \
                             gp_wid * fake_wid_loss + \
                             self.opt.training.lambda_kl * kl_loss
                    g_loss.backward()
                    self.averager_meters.update('adv_loss', adv_loss.item())
                    self.averager_meters.update('fake_ctc_loss',
                                                fake_ctc_loss.item())
                    self.averager_meters.update('info_loss', info_loss.item())
                    self.averager_meters.update('fake_wid_loss',
                                                fake_wid_loss.item())
                    self.averager_meters.update('kl_loss', kl_loss.item())
                    self.optimizers.G.step()

                # Periodic console/TensorBoard logging of averaged losses.
                if iter_count % self.opt.training.print_iter_val == 0:
                    meter_vals = self.averager_meters.eval_all()
                    self.averager_meters.reset_all()
                    info = "[%3d|%3d]-[%4d|%4d] G:%.4f D-fake:%.4f D-real:%.4f " \
                           "CTC-fake:%.4f CTC-real:%.4f Wid-fake:%.4f Wid-real:%.4f " \
                           "Recn-z:%.4f Kl:%.4f" \
                           % (epoch, self.opt.training.epochs,
                              iter_count % len(self.train_loader), len(self.train_loader),
                              meter_vals['adv_loss'],
                              meter_vals['fake_disc_loss'], meter_vals['real_disc_loss'],
                              meter_vals['fake_ctc_loss'], meter_vals['real_ctc_loss'],
                              meter_vals['fake_wid_loss'], meter_vals['real_wid_loss'],
                              meter_vals['info_loss'], meter_vals['kl_loss'])
                    self.print(info)

                    if self.writer:
                        for key, val in meter_vals.items():
                            self.writer.add_scalar('loss/%s' % key, val,
                                                   iter_count + 1)

                # Periodic image sampling (lazily creates logger/writer).
                if (iter_count + 1) % self.opt.training.sample_iter_val == 0:
                    if not (self.logger and self.writer):
                        self.create_logger()

                    sample_root = os.path.join(self.log_root,
                                               self.opt.training.sample_dir)
                    if not os.path.exists(sample_root):
                        os.makedirs(sample_root)
                    self.sample_images(iter_count + 1)

                iter_count += 1

            # NOTE(review): `epoch` starts at 1 so this guard is always
            # true; kept as-is to preserve behavior.
            if epoch:
                ckpt_root = os.path.join(self.log_root,
                                         self.opt.training.ckpt_dir)
                if not os.path.exists(ckpt_root):
                    os.makedirs(ckpt_root)

                # Always save a rolling 'last' checkpoint; run FID/KID
                # validation on a schedule and keep the best-KID model.
                self.save('last', epoch)
                if epoch >= self.opt.training.start_save_epoch_val and \
                        epoch % self.opt.training.save_epoch_val == 0:
                    self.print('Calculate FID_KID')
                    scores = self.validate()
                    fid, kid = scores['FID'], scores['KID']
                    self.print('FID:{} KID:{}'.format(fid, kid))

                    if kid < best_kid:
                        best_kid = kid
                        self.save('best', epoch, KID=kid, FID=fid)
                    if self.writer:
                        self.writer.add_scalar('valid/FID', fid, epoch)
                        self.writer.add_scalar('valid/KID', kid, epoch)

            for scheduler in self.lr_schedulers.values():
                scheduler.step(epoch)
def main():
    """Run inference on the test split and report NME-style statistics.

    Loads a trained face-alignment model, predicts landmarks for the whole
    test set, saves the raw predictions, and prints per-landmark and
    overall normalized errors.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fix: the attribute is spelled `deterministic`; the original typo
    # (`determinstic`) had no effect.
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)

    gpus = list(config.GPUS)

    # Load weights: checkpoints may wrap the weights under a 'state_dict'
    # key or be a bare state dict.
    state_dict = torch.load(args.model_file)
    if 'state_dict' in state_dict.keys():
        state_dict = state_dict['state_dict']
    # Fix: at this point `model` is not wrapped in DataParallel yet, so it
    # has no `.module` attribute — the original else-branch
    # (`model.module.load_state_dict`) would raise AttributeError.
    model.load_state_dict(state_dict)
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    dataset_type = get_dataset(config)

    test_loader = DataLoader(dataset=dataset_type(config, is_train=False),
                             batch_size=config.TEST.BATCH_SIZE_PER_GPU *
                             len(gpus),
                             shuffle=False,
                             num_workers=config.WORKERS,
                             pin_memory=config.PIN_MEMORY)

    nme, predictions = function.inference(config, test_loader, model)
    torch.save(predictions, os.path.join(final_output_dir, 'predictions.pth'))

    target = test_loader.dataset.load_all_pts()
    # Rescale predictions to ground-truth coordinate space.
    # NOTE(review): the 16x factor presumably matches the model's output
    # stride — confirm against the dataset/model config.
    pred = 16 * predictions

    total_err = 0.0
    # Accumulated normalized error per landmark index.
    per_point_err = np.zeros(config.MODEL.NUM_JOINTS)

    # Per-axis residuals (views of the same difference array).
    diff = target - pred
    res_temp_x = diff[:, :, 0]
    res_temp_y = diff[:, :, 1]

    for i in range(len(pred)):
        # Normalization factor: distance between landmarks 0 and 1,
        # divided by 30 (original convention, kept as-is).
        trans = np.sqrt(
            pow(target[i][0][0] - target[i][1][0], 2) +
            pow(target[i][0][1] - target[i][1][1], 2)) / 30.0
        res_temp_x[i] = res_temp_x[i] / trans
        res_temp_y[i] = res_temp_y[i] / trans
        for j in range(len(target[i])):
            dist = np.sqrt(
                np.power((target[i][j][0] - pred[i][j][0]), 2) +
                np.power((target[i][j][1] - pred[i][j][1]), 2)) / trans
            total_err += dist
            per_point_err[j] += dist

    # Fix: `np.float` was removed in NumPy 1.20+; plain true division on
    # len() is the Python 3 equivalent.
    per_point_err /= len(pred)
    print(per_point_err)
    print(np.mean(per_point_err))
    total_err /= (len(pred) * len(pred[0]))
    print(total_err)
Ejemplo n.º 17
0
def main():
    """Train the face-alignment network (validation-disabled variant).

    Trains on a single GPU with an MSE heatmap loss, skips per-epoch
    validation, and saves a raw ``state_dict`` checkpoint every epoch plus
    a final state file.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    dataset_type = get_dataset(config)
    train_dataset = dataset_type(config, is_train=True)

    # Fix: reuse the dataset built above — the original constructed a
    # second, identical dataset instance just for the loader.
    # Single-GPU batch size (no DataParallel in this variant).
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS)

    model = models.get_face_alignment_net(config)

    # TensorBoard writer plus running step counters consumed by
    # function.train.
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    model.to("cuda")

    # Fix: `size_average` is deprecated; reduction='mean' is the exact
    # modern equivalent of size_average=True.
    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    optimizer = utils.get_optimizer(config, model)
    best_nme = 100
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir,
                                        'latest.pth')

        if os.path.isfile(model_state_file):
            # The checkpoint here is a bare state dict (matching what this
            # function saves below), not a full training checkpoint, so
            # only the weights are restored and training restarts at 1.
            with open(model_state_file, "rb") as fp:
                state_dict = torch.load(fp)
                model.load_state_dict(state_dict)
                last_epoch = 1
            print("=> loaded checkpoint (epoch {})"
                  .format(last_epoch))
        else:
            print("=> no checkpoint found")

    # Halve the learning rate every 10 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        function.train(config, train_loader, model, criterion,
                       optimizer, epoch, writer_dict)

        # Fix: step the scheduler AFTER the epoch's optimizer steps
        # (required ordering since PyTorch 1.1); stepping at loop start
        # skipped the initial learning rate entirely.
        lr_scheduler.step()

        # Validation is disabled in this variant; every epoch is treated
        # as "best" and saved.
        nme = 0
        is_best = True
        best_nme = min(nme, best_nme)

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        print("best:", is_best)
        torch.save(model.state_dict(), os.path.join(final_output_dir, 'mse_relu_checkpoint_{}.pth'.format(epoch)))

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth')
    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(model.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Ejemplo n.º 18
0
def main():
    """Train the face-alignment network on a single GPU.

    Builds train/val loaders from the global ``config``, trains with an
    MSE heatmap loss on ``config.GPUS[0]``, validates every epoch, and
    checkpoints the best model by NME (normalized mean error).
    """
    args = parse_args()

    # set logger and dir
    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.experiment_name, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # set cudnn
    cudnn.benchmark = config.CUDNN.BENCHMARK
    # Fix: the attribute is spelled `deterministic`; the original typo
    # (`determinstic`) had no effect.
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # Only single-GPU is supported for now; TODO: add multi-GPU support.
    # set model and loss and criterion
    model = models.get_face_alignment_net(config)
    model = model.cuda(config.GPUS[0])
    # Fix: `size_average` is deprecated; reduction='mean' is the exact
    # modern equivalent of size_average=True.
    criterion = torch.nn.MSELoss(reduction='mean').cuda(config.GPUS[0])
    # criterion = AdaptiveWingLoss()
    optimizer = utils.get_optimizer(config, model)

    # get dataset
    dataset_type = get_dataset(config)

    # get dataloader
    train_loader = DataLoader(dataset=dataset_type(config, is_train=True),
                              batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
                              shuffle=config.TRAIN.SHUFFLE,
                              num_workers=config.WORKERS,
                              pin_memory=config.PIN_MEMORY)

    val_loader = DataLoader(dataset=dataset_type(config, is_train=False),
                            batch_size=config.TEST.BATCH_SIZE_PER_GPU,
                            shuffle=False,
                            num_workers=config.WORKERS,
                            pin_memory=config.PIN_MEMORY)

    # set lr_scheduler
    # NOTE(review): the scheduler is built before the resume block updates
    # `last_epoch`, so a resumed run restarts the LR schedule — confirm
    # whether that is intended.
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_epoch - 1)

    # set training writer
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # Fix: initialize best_nme BEFORE the resume block — the original set
    # it to 10000 after resuming, clobbering the checkpoint's best_nme.
    best_nme = 10000

    # set training resume function
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'latest.pth')
        if os.path.islink(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_nme = checkpoint['best_nme']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found")

    # starting training
    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):

        # training
        function.train(config, train_loader, model, criterion, optimizer,
                       epoch, writer_dict)

        # evaluating
        nme, predictions = function.validate(config, val_loader, model,
                                             criterion, epoch, writer_dict)

        # Fix: step the scheduler AFTER the epoch's optimizer steps
        # (required ordering since PyTorch 1.1); stepping at loop start
        # skipped the initial learning rate entirely.
        lr_scheduler.step()

        # saving
        is_best = nme < best_nme
        best_nme = min(nme, best_nme)

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        print("best:", is_best)
        utils.save_checkpoint(
            {
                # NOTE(review): this stores the nn.Module itself under
                # 'state_dict' while the resume path above expects a state
                # dict; kept as-is pending confirmation of what
                # utils.save_checkpoint does with this entry.
                "state_dict": model,
                "epoch": epoch + 1,
                "best_nme": best_nme,
                "optimizer": optimizer.state_dict(),
            }, predictions, is_best, final_output_dir,
            'checkpoint_{}.pth'.format(epoch))

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    torch.save(model.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Ejemplo n.º 19
0
def main_wflw():
    """Convert 98-point WFLW training annotations to a 70-point layout.

    For every training sample this selects/averages the 98 WFLW landmarks
    down to 70 points, draws them on the image, writes the annotated image
    under ``data/wflw/xximages/<fold>/`` and appends one CSV row
    (relative path, scale, center, landmark coordinates) to
    ``data/wflw/face_landmarks70_wflw_train.csv``.
    """
    args = parse_args()
    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # fix: attribute is 'deterministic'; the misspelled 'determinstic'
    # was silently ignored by cuDNN
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)

    dataset_type = get_dataset(config)
    dataset = dataset_type(config, is_train=True)

    # 'with' guarantees the CSV is flushed/closed even if a sample raises
    with open('data/wflw/face_landmarks70_wflw_train.csv', 'w') as fp:
        for i in range(len(dataset)):
            img, image_path, meta = dataset[i]
            fold, name = image_path.split('/')[-2], image_path.split('/')[-1]
            folder = osp.join('data/wflw/xximages', fold)
            # exist_ok avoids the check-then-create race of the original
            os.makedirs(folder, exist_ok=True)
            fname = osp.join(folder, name)
            scale = meta['scale']
            center = meta['center']
            tpts = meta['tpts']

            selpts = []
            # jaw contour: keep every other of WFLW's 33 contour points
            for j in range(0, 33, 2):
                selpts.append(tpts[j])
            # eyebrows: WFLW stores upper and lower brow contours;
            # average paired points into a single 5-point brow per side
            selpts.append(tpts[33])
            selpts.append((tpts[34] + tpts[41]) / 2)
            selpts.append((tpts[35] + tpts[40]) / 2)
            selpts.append((tpts[36] + tpts[39]) / 2)
            selpts.append((tpts[37] + tpts[38]) / 2)
            selpts.append((tpts[42] + tpts[50]) / 2)
            selpts.append((tpts[43] + tpts[49]) / 2)
            selpts.append((tpts[44] + tpts[48]) / 2)
            selpts.append((tpts[45] + tpts[47]) / 2)
            selpts.append(tpts[46])
            # nose: taken over unchanged
            for j in range(51, 60):
                selpts.append(tpts[j])
            # eyes: collapse WFLW's 8-point eyes to 6 points by averaging
            # the paired upper/lower lid points
            selpts.append(tpts[60])
            selpts.append((tpts[61] + tpts[62]) / 2)
            selpts.append(tpts[63])
            selpts.append(tpts[64])
            selpts.append(tpts[65])
            selpts.append((tpts[66] + tpts[67]) / 2)
            selpts.append(tpts[68])
            selpts.append(tpts[69])
            selpts.append((tpts[70] + tpts[71]) / 2)
            selpts.append(tpts[72])
            selpts.append((tpts[73] + tpts[74]) / 2)
            selpts.append(tpts[75])
            # mouth: taken over unchanged
            for j in range(76, 98):
                selpts.append(tpts[j])

            fp.write('%s,%.2f,%.1f,%.1f' %
                     (osp.join(fold, name), scale, center[0], center[1]))
            for spt in selpts:
                # fix: OpenCV >= 4 rejects float pixel coordinates, so
                # cast for drawing; the CSV keeps the float precision
                cv2.circle(img, (int(spt[0]), int(spt[1])), 1, (0, 0, 255))
                fp.write(',%f,%f' % (spt[0], spt[1]))
            fp.write('\n')
            # dataset yields RGB; cv2.imwrite expects BGR
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(fname, img)
Ejemplo n.º 20
0
def main():
    """Run landmark inference on the test split and dump predictions.

    Loads model weights from ``args.model_file``, runs
    ``function.inference`` over the test dataloader and writes one row per
    image (file name plus 194 x/y landmark coordinates, columns
    ``Point_M{i}_X`` / ``Point_M{i}_Y``) to ``pred_test.csv``.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    config.defrost()
    config.MODEL.INIT_WEIGHTS = False  # weights come from the checkpoint
    config.freeze()
    model = models.get_face_alignment_net(config)

    gpus = list(config.GPUS)
    model.to("cuda")

    # fix: map_location lets a GPU-trained checkpoint deserialize on any
    # host; the model itself was already moved to the GPU above
    with open(args.model_file, "rb") as fp:
        state_dict = torch.load(fp, map_location="cpu")
        model.load_state_dict(state_dict)

    dataset_type = get_dataset(config)

    test_loader = DataLoader(
        dataset=dataset_type(config, is_train=False),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=False,  # keep row order aligned with the dataset
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY
    )

    predictions = function.inference(config, test_loader, model)

    # one CSV row per image: file name + 194 integer (x, y) landmarks
    df_predictions = []
    for pred in predictions:
        row = {'file_name': pred[0]}
        for id_point in range(194):
            row[f'Point_M{id_point}_X'] = int(pred[1][id_point])
            row[f'Point_M{id_point}_Y'] = int(pred[2][id_point])
        df_predictions.append(row)
    pd.DataFrame(df_predictions).to_csv('pred_test.csv', index=False)
Ejemplo n.º 21
0
def main():
    """Train the face-alignment network end to end.

    Builds train/valid dataloaders from the configured dataset, trains
    with MSE loss, validates each epoch (NME metric, lower is better),
    checkpoints every epoch, dumps per-epoch diff arrays and the best
    epoch's per-image predictions, and finally writes the loss history
    (``loss2.csv``) and the final weights (``final_state.pth``).
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    # fix: attribute is 'deterministic'; the misspelled 'determinstic'
    # was silently ignored by cuDNN
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)

    dataset_type = get_dataset(config)
    train_data = dataset_type(config, split="train")
    train_loader = DataLoader(dataset=train_data,
                              batch_size=config.TRAIN.BATCH_SIZE_PER_GPU *
                              len(gpus),
                              shuffle=config.TRAIN.SHUFFLE,
                              num_workers=config.WORKERS,
                              pin_memory=config.PIN_MEMORY)

    val_data = dataset_type(config, split="valid")
    val_loader = DataLoader(dataset=val_data,
                            batch_size=config.TEST.BATCH_SIZE_PER_GPU *
                            len(gpus),
                            shuffle=False,
                            num_workers=config.WORKERS,
                            pin_memory=config.PIN_MEMORY)

    model = models.get_face_alignment_net(config)

    # TensorBoard writer plus global step counters that function.train /
    # function.validate increment as they log
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    model = nn.DataParallel(model, device_ids=gpus).cuda()

    # fix: size_average is deprecated; reduction='mean' is the exact
    # equivalent of size_average=True
    criterion = torch.nn.MSELoss(reduction='mean').cuda()
    optimizer = utils.get_optimizer(config, model)

    best_nme = 100
    last_epoch = config.TRAIN.BEGIN_EPOCH

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_epoch - 1)

    # resume only when 'final.pth' exists as a symlink
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'final.pth')
        if os.path.islink(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_nme = checkpoint['best_nme']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found")
    loss = []  # per-epoch training loss history, written out at the end

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        losses, diff = function.train(config, train_loader, model, criterion,
                                      optimizer, epoch, writer_dict)
        loss.append(losses)
        lr_scheduler.step()

        np.save(
            os.path.join(final_output_dir, "train_diff@epoch{}".format(epoch)),
            diff)

        # evaluate (NME: normalized mean error, lower is better)
        nme, predictions, diff = function.validate(config, val_loader, model,
                                                   criterion, epoch,
                                                   writer_dict)

        np.save(
            os.path.join(final_output_dir, "valid_diff@epoch{}".format(epoch)),
            diff)

        is_best = nme < best_nme
        best_nme = min(nme, best_nme)

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        print("best:", is_best)
        # NOTE(review): 'state_dict' stores the model object, but the
        # resume path above calls model.load_state_dict(...) on it —
        # probably should be model.state_dict(); left unchanged because
        # utils.save_checkpoint's handling of this key is not visible here
        utils.save_checkpoint(
            {
                "state_dict": model,
                "epoch": epoch + 1,
                "best_nme": best_nme,
                "optimizer": optimizer.state_dict(),
            }, predictions, is_best, final_output_dir,
            'checkpoint_{}.pth'.format(epoch))
        if is_best:
            # dump per-image landmark predictions for the new best epoch,
            # scaled back by the dataset's resize factor (y,x order)
            for i in range(len(predictions)):
                afile = val_data.annotation_files[i]
                new_afile = '{}.{}.txt'.format(
                    afile,
                    os.path.basename(args.cfg).split('.')[0])
                with open(new_afile, 'wt') as f:
                    pts = predictions[i].cpu().numpy()
                    for j in range(len(pts)):
                        f.write("{},{}\n".format(
                            pts[j][1] / val_data.factor[1],
                            pts[j][0] / val_data.factor[0]))

    pd.DataFrame(data=loss).to_csv('loss2.csv')
    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    # unwrap DataParallel so the saved keys have no 'module.' prefix
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()