Code example #1
File: image_model.py  Project: fabriquant/AutoFuzz
    def __init__(self, hparams, teacher_path=''):
        super().__init__()

        # addition (hack): convert a dict back to an argparse.Namespace when
        # necessary, so the attribute-style access below keeps working
        if isinstance(hparams, dict):
            import argparse
            args = argparse.Namespace()
            for k, v in hparams.items():
                setattr(args, k, v)
            hparams = args

        self.hparams = hparams
        self.to_heatmap = ToHeatmap(hparams.heatmap_radius)

        if teacher_path:
            # modification: pass the checkpoint path as str
            self.teacher = MapModel.load_from_checkpoint(str(teacher_path))
            self.teacher.freeze()

        self.net = SegmentationModel(10,
                                     4,
                                     hack=hparams.hack,
                                     temperature=hparams.temperature)
        self.converter = Converter()
        self.controller = RawController(4)
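
The dict-to-Namespace shim above recurs in several of these examples: when hparams arrives as a plain dict (as restored checkpoints sometimes provide), attribute access such as hparams.heatmap_radius would fail without it. A minimal standalone sketch of the same conversion (the values here are hypothetical):

import argparse

hparams = {'heatmap_radius': 5, 'hack': False, 'temperature': 10.0}  # hypothetical values

if isinstance(hparams, dict):
    args = argparse.Namespace()
    for k, v in hparams.items():
        setattr(args, k, v)
    hparams = args

assert hparams.heatmap_radius == 5  # attribute access now works

# An equivalent one-liner: argparse.Namespace(**some_dict)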
Code example #2
    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        self.net = SegmentationModel(4, 4)

        self.teacher = MapModel.load_from_checkpoint(
            pathlib.Path('/home/bradyzhou/code/carla_random/') / hparams.teacher_path)
        # self.teacher.eval()

        self.converter = Converter()
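
Note the contrast with example #1, which wraps the checkpoint path in str() before load_from_checkpoint, presumably because some Lightning versions expected a plain string rather than a pathlib.Path. A trivial sketch of the / join used here (the relative path is a hypothetical stand-in):

import pathlib

base = pathlib.Path('/home/bradyzhou/code/carla_random/')
teacher_path = 'epoch=10.ckpt'  # hypothetical stand-in for hparams.teacher_path
ckpt = base / teacher_path      # pathlib joins paths with the / operator
print(str(ckpt))                # the plain-string form that example #1 passes instead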
Code example #3
    def __init__(self, hparams):
        super().__init__()

        # addition (hack): convert a dict back to an argparse.Namespace when
        # necessary, so the attribute-style access below keeps working
        if isinstance(hparams, dict):
            import argparse
            args = argparse.Namespace()
            for k, v in hparams.items():
                setattr(args, k, v)
            hparams = args

        self.hparams = hparams
        self.to_heatmap = ToHeatmap(hparams.heatmap_radius)
        self.net = SegmentationModel(10, 4, hack=hparams.hack, temperature=hparams.temperature)
        self.controller = RawController(4)
Code example #4
def get_model(ie, args):
    if args.architecture_type == 'segmentation':
        return SegmentationModel(ie, args.model), SegmentationVisualizer(
            args.colors)
    if args.architecture_type == 'salient_object_detection':
        return SalientObjectDetectionModel(
            ie, args.model), SaliencyMapVisualizer()
    # addition: fail loudly instead of silently returning None
    raise ValueError(
        'Unsupported architecture_type: {}'.format(args.architecture_type))
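
A hedged usage sketch of the dispatch above. get_model_demo is a stub with the same shape (the real function returns wrapped OpenVINO models and visualizers), and the args object stands in for the parsed CLI arguments:

from types import SimpleNamespace

def get_model_demo(ie, args):
    # Same dispatch shape as get_model above, with string stand-ins.
    if args.architecture_type == 'segmentation':
        return 'segmentation-model', 'segmentation-visualizer'
    if args.architecture_type == 'salient_object_detection':
        return 'sod-model', 'saliency-visualizer'
    raise ValueError('Unsupported architecture_type: %s' % args.architecture_type)

args = SimpleNamespace(architecture_type='segmentation', model='model.xml', colors=None)
model, visualizer = get_model_demo(None, args)  # ie is unused by the stub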
Code example #5
File: image_model.py  Project: fabriquant/AutoFuzz
class ImageModel(pl.LightningModule):
    def __init__(self, hparams, teacher_path=''):
        super().__init__()

        # addition (hack): convert a dict back to an argparse.Namespace when
        # necessary, so the attribute-style access below keeps working
        if isinstance(hparams, dict):
            import argparse
            args = argparse.Namespace()
            for k, v in hparams.items():
                setattr(args, k, v)
            hparams = args

        self.hparams = hparams
        self.to_heatmap = ToHeatmap(hparams.heatmap_radius)

        if teacher_path:
            # modification: pass the checkpoint path as str
            self.teacher = MapModel.load_from_checkpoint(str(teacher_path))
            self.teacher.freeze()

        self.net = SegmentationModel(10,
                                     4,
                                     hack=hparams.hack,
                                     temperature=hparams.temperature)
        self.converter = Converter()
        self.controller = RawController(4)

    def forward(self, img, target):
        target_cam = self.converter.map_to_cam(target)
        target_heatmap_cam = self.to_heatmap(target, img)[:, None]
        out = self.net(torch.cat((img, target_heatmap_cam), 1))
        return out, (target_cam, target_heatmap_cam)

    @torch.no_grad()
    def _get_labels(self, topdown, target):
        out, (target_heatmap, ) = self.teacher.forward(topdown,
                                                       target,
                                                       debug=True)
        control = self.teacher.controller(out)

        return out, control, (target_heatmap, )

    def training_step(self, batch, batch_nb):
        img, topdown, points, target, actions, meta = batch

        # Ground truth command.
        lbl_map, ctrl_map, (target_heatmap, ) = self._get_labels(
            topdown, target)
        lbl_cam = self.converter.map_to_cam((lbl_map + 1) / 2 * 256)
        lbl_cam[..., 0] = (lbl_cam[..., 0] / 256) * 2 - 1
        lbl_cam[..., 1] = (lbl_cam[..., 1] / 144) * 2 - 1

        out, (target_cam, target_heatmap_cam) = self.forward(img, target)

        alpha = torch.rand(out.shape[0], out.shape[1], 1).type_as(out)
        between = alpha * out + (1 - alpha) * lbl_cam
        out_ctrl = self.controller(between)

        point_loss = torch.nn.functional.l1_loss(out,
                                                 lbl_cam,
                                                 reduction='none').mean((1, 2))
        ctrl_loss_raw = torch.nn.functional.l1_loss(out_ctrl,
                                                    ctrl_map,
                                                    reduction='none')
        ctrl_loss = ctrl_loss_raw.mean(1)
        steer_loss = ctrl_loss_raw[:, 0]
        speed_loss = ctrl_loss_raw[:, 1]

        loss_gt = (point_loss + self.hparams.command_coefficient * ctrl_loss)
        loss_gt_mean = loss_gt.mean()

        # Random command.
        indices = np.random.choice(RANDOM_POINTS.shape[0], topdown.shape[0])
        target_aug = torch.from_numpy(RANDOM_POINTS[indices]).type_as(img)

        lbl_map_aug, ctrl_map_aug, (target_heatmap_aug, ) = self._get_labels(
            topdown, target_aug)
        lbl_cam_aug = self.converter.map_to_cam((lbl_map_aug + 1) / 2 * 256)
        lbl_cam_aug[..., 0] = (lbl_cam_aug[..., 0] / 256) * 2 - 1
        lbl_cam_aug[..., 1] = (lbl_cam_aug[..., 1] / 144) * 2 - 1

        out_aug, (target_cam_aug,
                  target_heatmap_cam_aug) = self.forward(img, target_aug)

        alpha = torch.rand(out.shape[0], out.shape[1], 1).type_as(out)
        between_aug = alpha * out_aug + (1 - alpha) * lbl_cam_aug
        out_ctrl_aug = self.controller(between_aug)

        point_loss_aug = torch.nn.functional.l1_loss(out_aug,
                                                     lbl_cam_aug,
                                                     reduction='none').mean(
                                                         (1, 2))
        ctrl_loss_aug_raw = torch.nn.functional.l1_loss(out_ctrl_aug,
                                                        ctrl_map_aug,
                                                        reduction='none')
        ctrl_loss_aug = ctrl_loss_aug_raw.mean(1)
        steer_loss_aug = ctrl_loss_aug_raw[:, 0]
        speed_loss_aug = ctrl_loss_aug_raw[:, 1]

        loss_aug = (point_loss_aug +
                    self.hparams.command_coefficient * ctrl_loss_aug)
        loss_aug_mean = loss_aug.mean()

        loss = loss_gt_mean + loss_aug_mean
        metrics = {
            'train_loss': loss.item(),
            'train_point': point_loss.mean().item(),
            'train_ctrl': ctrl_loss.mean().item(),
            'train_steer': steer_loss.mean().item(),
            'train_speed': speed_loss.mean().item(),
            'train_point_aug': point_loss_aug.mean().item(),
            'train_ctrl_aug': ctrl_loss_aug.mean().item(),
            'train_steer_aug': steer_loss_aug.mean().item(),
            'train_speed_aug': speed_loss_aug.mean().item(),
        }

        if batch_nb % 250 == 0:
            metrics['train_image'] = viz(batch, out, out_ctrl, target_cam,
                                         lbl_cam, lbl_map, ctrl_map,
                                         point_loss, ctrl_loss)
            metrics['train_image_aug'] = viz(batch, out_aug, out_ctrl_aug,
                                             target_cam_aug, lbl_cam_aug,
                                             lbl_map_aug, ctrl_map_aug,
                                             point_loss_aug, ctrl_loss_aug)

        self.logger.log_metrics(metrics, self.global_step)

        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        img, topdown, points, target, actions, meta = batch

        # Ground truth command.
        lbl_map, ctrl_map, (target_heatmap, ) = self._get_labels(
            topdown, target)
        lbl_cam = self.converter.map_to_cam((lbl_map + 1) / 2 * 256)
        lbl_cam[..., 0] = (lbl_cam[..., 0] / 256) * 2 - 1
        lbl_cam[..., 1] = (lbl_cam[..., 1] / 144) * 2 - 1

        out, (target_cam, target_heatmap_cam) = self.forward(img, target)
        out_ctrl = self.controller(out)
        out_ctrl_gt = self.controller(lbl_cam)

        point_loss = torch.nn.functional.l1_loss(out,
                                                 lbl_cam,
                                                 reduction='none').mean((1, 2))
        ctrl_loss_raw = torch.nn.functional.l1_loss(out_ctrl,
                                                    ctrl_map,
                                                    reduction='none')
        ctrl_loss = ctrl_loss_raw.mean(1)
        steer_loss = ctrl_loss_raw[:, 0]
        speed_loss = ctrl_loss_raw[:, 1]

        ctrl_loss_gt_raw = torch.nn.functional.l1_loss(out_ctrl_gt,
                                                       ctrl_map,
                                                       reduction='none')
        ctrl_loss_gt = ctrl_loss_gt_raw.mean(1)
        steer_loss_gt = ctrl_loss_gt_raw[:, 0]
        speed_loss_gt = ctrl_loss_gt_raw[:, 1]

        loss_gt = (point_loss + self.hparams.command_coefficient * ctrl_loss)
        loss_gt_mean = loss_gt.mean()

        # Random command.
        indices = np.random.choice(RANDOM_POINTS.shape[0], topdown.shape[0])
        target_aug = torch.from_numpy(RANDOM_POINTS[indices]).type_as(img)

        lbl_map_aug, ctrl_map_aug, (target_heatmap_aug, ) = self._get_labels(
            topdown, target_aug)
        lbl_cam_aug = self.converter.map_to_cam((lbl_map_aug + 1) / 2 * 256)
        lbl_cam_aug[..., 0] = (lbl_cam_aug[..., 0] / 256) * 2 - 1
        lbl_cam_aug[..., 1] = (lbl_cam_aug[..., 1] / 144) * 2 - 1
        out_aug, (target_cam_aug,
                  target_heatmap_cam_aug) = self.forward(img, target_aug)
        out_ctrl_aug = self.controller(out_aug)
        out_ctrl_gt_aug = self.controller(lbl_cam_aug)

        point_loss_aug = torch.nn.functional.l1_loss(out_aug,
                                                     lbl_cam_aug,
                                                     reduction='none').mean(
                                                         (1, 2))

        ctrl_loss_aug_raw = torch.nn.functional.l1_loss(out_ctrl_aug,
                                                        ctrl_map_aug,
                                                        reduction='none')
        ctrl_loss_aug = ctrl_loss_aug_raw.mean(1)
        steer_loss_aug = ctrl_loss_aug_raw[:, 0]
        speed_loss_aug = ctrl_loss_aug_raw[:, 1]

        ctrl_loss_gt_aug_raw = torch.nn.functional.l1_loss(out_ctrl_gt_aug,
                                                           ctrl_map_aug,
                                                           reduction='none')
        ctrl_loss_gt_aug = ctrl_loss_gt_aug_raw.mean(1)
        steer_loss_gt_aug = ctrl_loss_gt_aug_raw[:, 0]
        speed_loss_gt_aug = ctrl_loss_gt_aug_raw[:, 1]

        loss_gt_aug = (point_loss_aug +
                       self.hparams.command_coefficient * ctrl_loss_aug)
        loss_gt_aug_mean = loss_gt_aug.mean()

        if batch_nb == 0:
            self.logger.log_metrics(
                {
                    'val_image':
                    viz(batch, out, out_ctrl, target_cam, lbl_cam, lbl_map,
                        ctrl_map, point_loss, ctrl_loss),
                    'val_image_aug':
                    viz(batch, out_aug, out_ctrl_aug, target_cam_aug,
                        lbl_cam_aug, lbl_map_aug, ctrl_map_aug, point_loss_aug,
                        ctrl_loss_aug)
                }, self.global_step)

        return {
            'val_loss': (loss_gt_mean + loss_gt_aug_mean).item(),
            'val_point': point_loss.mean().item(),
            'val_ctrl': ctrl_loss.mean().item(),
            'val_steer': steer_loss.mean().item(),
            'val_speed': speed_loss.mean().item(),
            'val_ctrl_gt': ctrl_loss_gt.mean().item(),
            'val_steer_gt': steer_loss_gt.mean().item(),
            'val_speed_gt': speed_loss_gt.mean().item(),
            'val_point_aug': point_loss_aug.mean().item(),
            'val_ctrl_aug': ctrl_loss_aug.mean().item(),
            'val_steer_aug': steer_loss_aug.mean().item(),
            'val_speed_aug': speed_loss_aug.mean().item(),
            'val_ctrl_gt_aug': ctrl_loss_gt_aug.mean().item(),
            'val_steer_gt_aug': steer_loss_gt_aug.mean().item(),
            'val_speed_gt_aug': speed_loss_gt_aug.mean().item(),
        }

    def validation_epoch_end(self, outputs):
        results = dict()

        for output in outputs:
            for key in output:
                if key not in results:
                    results[key] = list()

                results[key].append(output[key])

        summary = {key: np.mean(val) for key, val in results.items()}
        self.logger.log_metrics(summary, self.global_step)

        return summary

    def configure_optimizers(self):
        optim = torch.optim.Adam(list(self.net.parameters()) +
                                 list(self.controller.parameters()),
                                 lr=self.hparams.lr,
                                 weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                               mode='min',
                                                               factor=0.5,
                                                               patience=5,
                                                               min_lr=1e-6,
                                                               verbose=True)

        return [optim], [scheduler]

    def train_dataloader(self):
        return get_dataset(self.hparams.dataset_dir,
                           True,
                           self.hparams.batch_size,
                           sample_by=self.hparams.sample_by)

    def val_dataloader(self):
        return get_dataset(self.hparams.dataset_dir,
                           False,
                           self.hparams.batch_size,
                           sample_by=self.hparams.sample_by)

    def state_dict(self):
        return {
            k: v
            for k, v in super().state_dict().items() if 'teacher' not in k
        }

    def load_state_dict(self, state_dict):
        errors = super().load_state_dict(state_dict, strict=False)

        print(errors)
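
The state_dict/load_state_dict overrides at the end of this example keep the frozen teacher's weights out of checkpoints, and reload with strict=False so the missing teacher keys are reported rather than raised. A self-contained sketch of the same pattern on a plain nn.Module (the module names here are illustrative):

import torch

class Student(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 2)
        self.teacher = torch.nn.Linear(4, 2)  # frozen helper we don't want to checkpoint

    def state_dict(self, *args, **kwargs):
        # Drop teacher weights, as ImageModel.state_dict does above.
        return {k: v for k, v in super().state_dict(*args, **kwargs).items()
                if 'teacher' not in k}

model = Student()
ckpt = model.state_dict()
assert all('teacher' not in k for k in ckpt)

errors = model.load_state_dict(ckpt, strict=False)  # tolerate the missing teacher keys
print(errors.missing_keys)  # ['teacher.weight', 'teacher.bias']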
Code example #6
class MapModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()

        # addition (hack): convert a dict back to an argparse.Namespace when
        # necessary, so the attribute-style access below keeps working
        if isinstance(hparams, dict):
            import argparse
            args = argparse.Namespace()
            for k, v in hparams.items():
                setattr(args, k, v)
            hparams = args

        self.hparams = hparams
        self.to_heatmap = ToHeatmap(hparams.heatmap_radius)
        self.net = SegmentationModel(10, 4, hack=hparams.hack, temperature=hparams.temperature)
        self.controller = RawController(4)

    def forward(self, topdown, target, debug=False):
        target_heatmap = self.to_heatmap(target, topdown)[:, None]
        out = self.net(torch.cat((topdown, target_heatmap), 1))

        if not debug:
            return out

        return out, (target_heatmap,)

    def training_step(self, batch, batch_nb):
        img, topdown, points, target, actions, meta = batch
        out, (target_heatmap,) = self.forward(topdown, target, debug=True)

        alpha = torch.rand(out.shape).type_as(out)
        between = alpha * out + (1-alpha) * points
        out_cmd = self.controller(between)

        loss_point = torch.nn.functional.l1_loss(out, points, reduction='none').mean((1, 2))
        loss_cmd_raw = torch.nn.functional.l1_loss(out_cmd, actions, reduction='none')

        loss_cmd = loss_cmd_raw.mean(1)
        loss = (loss_point + self.hparams.command_coefficient * loss_cmd).mean()

        metrics = {
                'point_loss': loss_point.mean().item(),
                'cmd_loss': loss_cmd.mean().item(),
                'loss_steer': loss_cmd_raw[:, 0].mean().item(),
                'loss_speed': loss_cmd_raw[:, 1].mean().item()
                }

        if batch_nb % 250 == 0:
            metrics['train_image'] = visualize(batch, out, between, out_cmd, loss_point, loss_cmd, target_heatmap)

        self.logger.log_metrics(metrics, self.global_step)
        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        img, topdown, points, target, actions, meta = batch
        out, (target_heatmap,) = self.forward(topdown, target, debug=True)

        alpha = 0.0
        between = alpha * out + (1-alpha) * points
        out_cmd = self.controller(between)
        out_cmd_pred = self.controller(out)

        loss_point = torch.nn.functional.l1_loss(out, points, reduction='none').mean((1, 2))
        loss_cmd_raw = torch.nn.functional.l1_loss(out_cmd, actions, reduction='none')
        loss_cmd_pred_raw = torch.nn.functional.l1_loss(out_cmd_pred, actions, reduction='none')

        loss_cmd = loss_cmd_raw.mean(1)
        loss = (loss_point + self.hparams.command_coefficient * loss_cmd).mean()

        if batch_nb == 0:
            self.logger.log_metrics({
                'val_image': visualize(batch, out, between, out_cmd, loss_point, loss_cmd, target_heatmap)
                }, self.global_step)

        return {
                'val_loss': loss.item(),
                'val_point_loss': loss_point.mean().item(),

                'val_cmd_loss': loss_cmd_raw.mean(1).mean().item(),
                'val_steer_loss': loss_cmd_raw[:, 0].mean().item(),
                'val_speed_loss': loss_cmd_raw[:, 1].mean().item(),

                'val_cmd_pred_loss': loss_cmd_pred_raw.mean(1).mean().item(),
                'val_steer_pred_loss': loss_cmd_pred_raw[:, 0].mean().item(),
                'val_speed_pred_loss': loss_cmd_pred_raw[:, 1].mean().item(),
                }

    def validation_epoch_end(self, batch_metrics):
        results = dict()

        for metrics in batch_metrics:
            for key in metrics:
                if key not in results:
                    results[key] = list()

                results[key].append(metrics[key])

        summary = {key: np.mean(val) for key, val in results.items()}
        self.logger.log_metrics(summary, self.global_step)

        return summary

    def configure_optimizers(self):
        optim = torch.optim.Adam(
                list(self.net.parameters()) + list(self.controller.parameters()),
                lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optim, mode='min', factor=0.5, patience=5, min_lr=1e-6,
                verbose=True)

        return [optim], [scheduler]

    def train_dataloader(self):
        return get_dataset(self.hparams.dataset_dir, True, self.hparams.batch_size, sample_by=self.hparams.sample_by)

    def val_dataloader(self):
        return get_dataset(self.hparams.dataset_dir, False, self.hparams.batch_size, sample_by=self.hparams.sample_by)
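
Both training_step implementations feed the controller a random interpolation between the predicted waypoints and the labels (between = alpha * out + (1 - alpha) * points), so the controller also trains on imperfect inputs; at validation time alpha is fixed to 0. A minimal standalone sketch of the mixing (shapes are hypothetical):

import torch

batch_size, n_points = 8, 4
out = torch.rand(batch_size, n_points, 2)     # predicted waypoints
points = torch.rand(batch_size, n_points, 2)  # ground-truth waypoints

# MapModel draws one alpha per element; ImageModel draws one per sample
# and broadcasts it over points and coordinates, as below.
alpha = torch.rand(batch_size, 1, 1)
between = alpha * out + (1 - alpha) * points

assert between.shape == out.shape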
Code example #7
def main():
    metrics = PerformanceMetrics()
    args = build_argparser().parse_args()

    log.info('Initializing Inference Engine...')
    ie = IECore()

    plugin_config = get_plugin_configs(args.device, args.num_streams,
                                       args.num_threads)

    log.info('Loading network...')

    model = SegmentationModel(ie, args.model)

    pipeline = AsyncPipeline(ie,
                             model,
                             plugin_config,
                             device=args.device,
                             max_num_requests=args.num_infer_requests)

    cap = open_images_capture(args.input, args.loop)

    next_frame_id = 0
    next_frame_id_to_show = 0

    log.info('Starting inference...')
    print(
        "To close the application, press 'CTRL+C' here or switch to the output window and press ESC key"
    )

    visualizer = Visualizer(args.colors)
    presenter = None
    video_writer = cv2.VideoWriter()

    while True:
        if pipeline.is_ready():
            # Get new image/frame
            start_time = perf_counter()
            frame = cap.read()
            if frame is None:
                if next_frame_id == 0:
                    raise ValueError("Can't read an image from the input")
                break
            if next_frame_id == 0:
                presenter = monitors.Presenter(
                    args.utilization_monitors, 55,
                    (round(frame.shape[1] / 4), round(frame.shape[0] / 8)))
                if args.output and not video_writer.open(
                        args.output, cv2.VideoWriter_fourcc(*'MJPG'),
                        cap.fps(), (frame.shape[1], frame.shape[0])):
                    raise RuntimeError("Can't open video writer")
            # Submit for inference
            pipeline.submit_data(frame, next_frame_id, {
                'frame': frame,
                'start_time': start_time
            })
            next_frame_id += 1
        else:
            # Wait for empty request
            pipeline.await_any()

        if pipeline.callback_exceptions:
            raise pipeline.callback_exceptions[0]
        # Process all completed requests
        results = pipeline.get_result(next_frame_id_to_show)
        if results:
            objects, frame_meta = results
            frame = frame_meta['frame']
            start_time = frame_meta['start_time']

            frame = visualizer.overlay_masks(frame, objects)
            presenter.drawGraphs(frame)
            metrics.update(start_time, frame)

            if video_writer.isOpened() and (
                    args.output_limit <= 0
                    or next_frame_id_to_show <= args.output_limit - 1):
                video_writer.write(frame)

            if not args.no_show:
                cv2.imshow('Segmentation Results', frame)
                key = cv2.waitKey(1)
                if key in {27, ord('q'), ord('Q')}:  # waitKey returns an int, so compare key codes
                    break
                presenter.handleKey(key)
            next_frame_id_to_show += 1

    pipeline.await_all()
    # Process completed requests
    while pipeline.has_completed_request():
        results = pipeline.get_result(next_frame_id_to_show)
        if results:
            objects, frame_meta = results
            frame = frame_meta['frame']
            start_time = frame_meta['start_time']

            frame = visualizer.overlay_masks(frame, objects)
            presenter.drawGraphs(frame)
            metrics.update(start_time, frame)

            if video_writer.isOpened() and (
                    args.output_limit <= 0
                    or next_frame_id_to_show <= args.output_limit - 1):
                video_writer.write(frame)

            if not args.no_show:
                cv2.imshow('Segmentation Results', frame)
                key = cv2.waitKey(1)
            next_frame_id_to_show += 1
        else:
            break

    metrics.print_total()
    print(presenter.reportMeans())
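
The main loop above interleaves submitting frames to the AsyncPipeline with draining completed results in frame order, then flushes stragglers after await_all(). A simplified, library-free sketch of that submit/collect pattern using concurrent.futures (purely illustrative, not the OpenVINO API):

from concurrent.futures import ThreadPoolExecutor

def infer(frame):
    return frame * 2  # stand-in for asynchronous model inference

frames = list(range(10))
next_to_show = 0

with ThreadPoolExecutor(max_workers=4) as pool:
    futures = {i: pool.submit(infer, f) for i, f in enumerate(frames)}
    # Consume results strictly in submission order, mirroring
    # pipeline.get_result(next_frame_id_to_show) in the loop above.
    while next_to_show < len(frames):
        print(next_to_show, futures[next_to_show].result())
        next_to_show += 1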
Code example #8
class ImageModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        self.net = SegmentationModel(4, 4)

        self.teacher = MapModel.load_from_checkpoint(
            pathlib.Path('/home/bradyzhou/code/carla_random/') / hparams.teacher_path)
        # self.teacher.eval()

        self.converter = Converter()

    def forward(self, x, *args, **kwargs):
        return self.net(x, *args, **kwargs)

    @torch.no_grad()
    def _get_labels(self, batch):
        img, topdown, points, heatmap, heatmap_img, meta = batch
        out = self.teacher.forward(torch.cat([topdown, heatmap], 1))

        return out

    def training_step(self, batch, batch_nb):
        img, topdown, points, heatmap, heatmap_img, meta = batch
        labels_map = self._get_labels(batch)

        labels_cam = self.converter.map_to_cam((labels_map + 1) / 2 * 256)
        labels_cam[..., 0] = (labels_cam[..., 0] / 256) * 2 - 1
        labels_cam[..., 1] = (labels_cam[..., 1] / 144) * 2 - 1

        out = self.forward(torch.cat([img, heatmap_img], 1))

        loss = torch.nn.functional.l1_loss(out, labels_cam, reduction='none').mean((1, 2))
        loss_mean = loss.mean()

        metrics = {'train_loss': loss_mean.item()}

        if batch_nb % 250 == 0:
            metrics['train_image'] = visualize(batch, out, labels_cam, labels_map, loss)

        self.logger.log_metrics(metrics, self.global_step)

        return {'loss': loss_mean}

    def validation_step(self, batch, batch_nb):
        img, topdown, points, heatmap, heatmap_img, meta = batch
        labels_map = self._get_labels(batch)

        labels_cam = self.converter.map_to_cam((labels_map + 1) / 2 * 256)
        labels_cam[..., 0] = (labels_cam[..., 0] / 256) * 2 - 1
        labels_cam[..., 1] = (labels_cam[..., 1] / 144) * 2 - 1

        out = self.forward(torch.cat([img, heatmap_img], 1))

        loss = torch.nn.functional.l1_loss(out, points, reduction='none').mean((1, 2))
        loss_mean = loss.mean()

        if batch_nb == 0:
            self.logger.log_metrics({
                'val_image': visualize(batch, out, labels_cam, labels_map, loss)
                }, self.global_step)

        return {'val_loss': loss_mean.item()}

    def validation_epoch_end(self, outputs):
        results = {'val_loss': list()}

        for output in outputs:
            for key in results:
                results[key].append(output[key])

        summary = {key: np.mean(val) for key, val in results.items()}
        self.logger.log_metrics(summary, self.global_step)

        return summary

    def configure_optimizers(self):
        return torch.optim.Adam(
                self.net.parameters(),
                lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)

    def train_dataloader(self):
        return get_dataset(self.hparams.dataset_dir, True, self.hparams.batch_size)

    def val_dataloader(self):
        return get_dataset(self.hparams.dataset_dir, False, self.hparams.batch_size)
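
Examples #2, #5 and #8 share the same teacher-student scheme: a pretrained MapModel generates labels under torch.no_grad() and the image network is regressed onto them with an L1 loss. A self-contained sketch of that distillation loop with toy linear modules:

import torch

teacher = torch.nn.Linear(4, 2)
student = torch.nn.Linear(4, 2)
for p in teacher.parameters():  # freeze, as teacher.freeze() does in examples #1/#5
    p.requires_grad_(False)

optim = torch.optim.Adam(student.parameters(), lr=1e-3)

x = torch.rand(8, 4)
with torch.no_grad():           # labels come from the frozen teacher
    labels = teacher(x)

optim.zero_grad()
loss = torch.nn.functional.l1_loss(student(x), labels)
loss.backward()
optim.step()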