Example #1
def build_callbacks(config):
    callback_list = []
    if config.TRAIN.CALLBACKS.LEARNING_RATE_MONITOR.ENABLE:
        callback_list.append(
            callbacks.LearningRateMonitor(
                logging_interval=config.TRAIN.CALLBACKS.LEARNING_RATE_MONITOR.LOGGING_INTERVAL
            )
        )
    if config.TRAIN.CALLBACKS.MODEL_CHECKPOINT.ENABLE:
        callback_list.append(
            callbacks.ModelCheckpoint(
                dirpath=config.OUTPUT,
                filename=config.TRAIN.CALLBACKS.MODEL_CHECKPOINT.FILE_NAME,
                monitor=config.TRAIN.CALLBACKS.MODEL_CHECKPOINT.MONITOR,
                save_top_k=config.TRAIN.CALLBACKS.MODEL_CHECKPOINT.SAVE_TOP_K,
                mode=config.TRAIN.CALLBACKS.MODEL_CHECKPOINT.MODE
            )
        )
    # if config.TRAIN.CALLBACKS.INTERVAL_STEP_VALIDATE.ENABLE:
    #     callback_list.append(
    #         IntervalStepValidate(config)
    #     )
    return callback_list

# Run validation every `INTERVAL` training steps
# class IntervalStepValidate(Callback):
#     def __init__(self, config):
#         self.config = config
#         self.total_steps = config.TRAIN.STEPS
#         self.validation_interval = config.TRAIN.CALLBACKS.INTERVAL_STEP_VALIDATE.INTERVAL

#     def on_batch_end(self, trainer, pl_module):
#         # Compare against the current global step, not the (constant) total
#         # step count; in practice Trainer(val_check_interval=N) does this.
#         if trainer.global_step % self.validation_interval == 0:
#             trainer.validate_step()
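
The nested upper-case config access suggests a yacs-style CfgNode. A minimal usage sketch, assuming that layout and that `callbacks` is PyTorch Lightning's callbacks module; the key names mirror the accesses above, but every value is a placeholder:

import pytorch_lightning as pl
from pytorch_lightning import callbacks
from yacs.config import CfgNode as CN

# Hypothetical config; only the keys build_callbacks reads are defined.
config = CN()
config.OUTPUT = "./output"
config.TRAIN = CN()
config.TRAIN.CALLBACKS = CN()
config.TRAIN.CALLBACKS.LEARNING_RATE_MONITOR = CN()
config.TRAIN.CALLBACKS.LEARNING_RATE_MONITOR.ENABLE = True
config.TRAIN.CALLBACKS.LEARNING_RATE_MONITOR.LOGGING_INTERVAL = "step"
config.TRAIN.CALLBACKS.MODEL_CHECKPOINT = CN()
config.TRAIN.CALLBACKS.MODEL_CHECKPOINT.ENABLE = True
config.TRAIN.CALLBACKS.MODEL_CHECKPOINT.FILE_NAME = "{epoch}-{val_loss:.3f}"
config.TRAIN.CALLBACKS.MODEL_CHECKPOINT.MONITOR = "val_loss"
config.TRAIN.CALLBACKS.MODEL_CHECKPOINT.SAVE_TOP_K = 1
config.TRAIN.CALLBACKS.MODEL_CHECKPOINT.MODE = "min"

trainer = pl.Trainer(callbacks=build_callbacks(config))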
Example #2
def load_callbacks():
    callbacks = []
    callbacks.append(
        plc.EarlyStopping(monitor='val_acc',
                          mode='max',
                          patience=10,
                          min_delta=0.001))

    callbacks.append(
        plc.ModelCheckpoint(monitor='val_acc',
                            filename='best-{epoch:02d}-{val_acc:.3f}',
                            save_top_k=1,
                            mode='max',
                            save_last=True))

    if args.lr_scheduler:
        callbacks.append(plc.LearningRateMonitor(logging_interval='epoch'))
    return callbacks
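
A minimal sketch of the context this function assumes: `plc` aliased to PyTorch Lightning's callbacks module and a module-level `args` namespace with an `lr_scheduler` flag (both assumptions, not shown in the snippet):

import pytorch_lightning as pl
from pytorch_lightning import callbacks as plc
from argparse import Namespace

args = Namespace(lr_scheduler=True)  # hypothetical CLI arguments

# The returned list plugs straight into a Trainer.
trainer = pl.Trainer(callbacks=load_callbacks())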
Example #3
def get_loggers_callbacks(args, model=None):

    try:
        # Setup logger(s) params
        csv_logger_params = dict(
            save_dir="./experiments",
            name=os.path.join(*args.save_dir.split("/")[1:-1]),
            version=args.save_dir.split("/")[-1],
        )
        wandb_logger_params = dict(
            log_model=False,
            name=os.path.join(*args.save_dir.split("/")[1:]),
            offline=args.debug,
            project="utime",
            save_dir=args.save_dir,
        )
        loggers = [
            pl_loggers.CSVLogger(**csv_logger_params),
            pl_loggers.WandbLogger(**wandb_logger_params),
        ]
        if model:
            loggers[-1].watch(model)

        # Setup callback(s) params
        # NOTE: `filepath` is the legacy (pre-1.2) ModelCheckpoint argument;
        # newer PyTorch Lightning splits it into `dirpath` and `filename`.
        checkpoint_monitor_params = dict(
            filepath=os.path.join(args.save_dir,
                                  "{epoch:03d}-{eval_loss:.2f}"),
            monitor=args.checkpoint_monitor,
            save_last=True,
            save_top_k=1,
        )
        earlystopping_parameters = dict(
            monitor=args.earlystopping_monitor,
            patience=args.earlystopping_patience,
        )
        callbacks = [
            pl_callbacks.ModelCheckpoint(**checkpoint_monitor_params),
            pl_callbacks.EarlyStopping(**earlystopping_parameters),
            pl_callbacks.LearningRateMonitor(),
        ]

        return loggers, callbacks
    except AttributeError:
        # Missing fields on `args` mean logging/callbacks are not configured.
        return None, None
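
A usage sketch under the same assumptions the function makes about `args`; the field values below are placeholders:

import pytorch_lightning as pl
from argparse import Namespace

args = Namespace(
    save_dir="./experiments/utime/run-001",
    debug=False,
    checkpoint_monitor="eval_loss",
    earlystopping_monitor="eval_loss",
    earlystopping_patience=10,
)

loggers, callbacks = get_loggers_callbacks(args)
# Trainer accepts a list of loggers as well as a single one.
trainer = pl.Trainer(logger=loggers, callbacks=callbacks)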
Example #4
def main():
    args = parse_args()

    if args.debug or not args.non_deterministic:
        np.random.seed(1)
        torch.manual_seed(1)
        torch.cuda.manual_seed(1)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        # torch.set_deterministic(True) # grid_sampler_2d_backward_cuda does not have a deterministic implementation

    if args.debug:
        torch.autograd.set_detect_anomaly(True)

    # Shared settings for every DataLoader below; note that shuffling is
    # disabled even for the training splits.
    dataloader_args = EasyDict(
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0 if args.debug else args.data_workers)
    if args.dataset == 'mnist':
        args.num_classes = 10
        args.im_channels = 1
        args.image_size = (40, 40)

        from torchvision.datasets import MNIST

        t = transforms.Compose([
            transforms.RandomCrop(size=(40, 40), pad_if_needed=True),
            transforms.ToTensor(),
            # norm_1c
        ])
        train_dataloader = DataLoader(
            MNIST(data_path / 'mnist', train=True, transform=t, download=True),
            **dataloader_args)
        val_dataloader = DataLoader(
            MNIST(data_path / 'mnist', train=False, transform=t,
                  download=True), **dataloader_args)
    elif args.dataset == 'usps':
        args.num_classes = 10
        args.im_channels = 1
        args.image_size = (40, 40)

        from torchvision.datasets import USPS

        t = transforms.Compose([
            transforms.RandomCrop(size=(40, 40), pad_if_needed=True),
            transforms.ToTensor(),
            # norm_1c
        ])
        train_dataloader = DataLoader(
            USPS(data_path / 'usps', train=True, transform=t, download=True),
            **dataloader_args)
        val_dataloader = DataLoader(
            USPS(data_path / 'usps', train=False, transform=t, download=True),
            **dataloader_args)
    elif args.dataset == 'constellation':

        data_gen = create_constellation(
            batch_size=args.batch_size,
            shuffle_corners=True,
            gaussian_noise=.0,
            drop_prob=0.5,
            which_patterns=[[0], [1], [0]],
            rotation_percent=180 / 360.,
            max_scale=3.,
            min_scale=3.,
            use_scale_schedule=False,
            schedule_steps=0,
        )

        # The same synthetic generator backs both the train and val loaders.
        train_dataloader = DataLoader(data_gen, **dataloader_args)
        val_dataloader = DataLoader(data_gen, **dataloader_args)

    elif args.dataset == 'cifar10':
        args.num_classes = 10
        args.im_channels = 3
        args.image_size = (32, 32)

        from torchvision.datasets import CIFAR10

        t = transforms.Compose([transforms.ToTensor()])
        train_dataloader = DataLoader(
            CIFAR10(data_path / 'cifar10',
                    train=True,
                    transform=t,
                    download=True), **dataloader_args)
        val_dataloader = DataLoader(
            CIFAR10(data_path / 'cifar10',
                    train=False,
                    transform=t,
                    download=True), **dataloader_args)
    elif args.dataset == 'svhn':
        args.num_classes = 10
        args.im_channels = 3
        args.image_size = (32, 32)

        from torchvision.datasets import SVHN

        t = transforms.Compose([transforms.ToTensor()])
        train_dataloader = DataLoader(
            SVHN(data_path / 'svhn', split='train', transform=t,
                 download=True), **dataloader_args)
        val_dataloader = DataLoader(
            SVHN(data_path / 'svhn', split='test', transform=t, download=True),
            **dataloader_args)
    else:
        raise NotImplementedError()

    logger = WandbLogger(project=args.log.project,
                         name=args.log.run_name,
                         entity=args.log.team,
                         config=args,
                         offline=not args.log.upload)

    if args.model == 'ccae':
        from scae.modules.attention import SetTransformer
        from scae.modules.capsule import CapsuleLayer
        from scae.models.ccae import CCAE

        encoder = SetTransformer(2)
        decoder = CapsuleLayer(input_dims=32,
                               n_caps=3,
                               n_caps_dims=2,
                               n_votes=4,
                               n_caps_params=32,
                               n_hiddens=128,
                               learn_vote_scale=True,
                               deformations=True,
                               noise_type='uniform',
                               noise_scale=4.,
                               similarity_transform=False)

        model = CCAE(encoder, decoder, args)

        # logger.watch(encoder._encoder, log='all', log_freq=args.log_frequency)
        # logger.watch(decoder, log='all', log_freq=args.log_frequency)
    elif args.model == 'pcae':
        from scae.modules.part_capsule_ae import CapsuleImageEncoder, TemplateImageDecoder
        from scae.models.pcae import PCAE

        encoder = CapsuleImageEncoder(args)
        decoder = TemplateImageDecoder(args)
        model = PCAE(encoder, decoder, args)

        logger.watch(encoder._encoder, log='all', log_freq=args.log.frequency)
        logger.watch(decoder, log='all', log_freq=args.log.frequency)
    elif args.model == 'ocae':
        from scae.modules.object_capsule_ae import SetTransformer, ImageCapsule
        from scae.models.ocae import OCAE

        encoder = SetTransformer()
        decoder = ImageCapsule()
        model = OCAE(encoder, decoder, args)

        #  TODO: after ccae
    else:
        raise NotImplementedError()

    # Execute Experiment
    lr_logger = cb.LearningRateMonitor(logging_interval='step')
    best_checkpointer = cb.ModelCheckpoint(save_top_k=1,
                                           monitor='val_rec_ll',
                                           filepath=logger.experiment.dir)
    last_checkpointer = cb.ModelCheckpoint(save_last=True,
                                           filepath=logger.experiment.dir)
    trainer = pl.Trainer(
        max_epochs=args.num_epochs,
        logger=logger,
        callbacks=[lr_logger, best_checkpointer, last_checkpointer])
    trainer.fit(model, train_dataloader, val_dataloader)
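
The `filepath` argument passed to ModelCheckpoint above is the legacy (pre-1.2) API; current PyTorch Lightning splits it into `dirpath` and `filename`. A sketch of the modern equivalent of the two checkpointers above:

import pytorch_lightning.callbacks as cb

best_checkpointer = cb.ModelCheckpoint(save_top_k=1,
                                       monitor='val_rec_ll',
                                       dirpath=logger.experiment.dir)
last_checkpointer = cb.ModelCheckpoint(save_last=True,
                                       dirpath=logger.experiment.dir)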
Example #5
def main(args):
    pl.seed_everything(RANDOM_STATE)

    exp_dir = pathlib.Path(args.exp_dir)
    exp_dir.mkdir(exist_ok=True, parents=True)

    train_params = TrainParams()

    train_tr = train_transform(train_params.img_size_in_batch)
    valid_tr = valid_transform(train_params.img_size_in_batch)

    ignore_train_images_list = None

    if args.ignore_images is not None:
        ignore_train_images_list = load_ignore_images(args.ignore_images)

    if train_params.train_size == 1:
        datamodule = FullLandmarkDataModule(
            path_to_dir=args.data_dir,
            annot_file=args.annot_file,
            ignore_train_images=ignore_train_images_list,
            train_batch_size=args.train_batch_size,
            val_batch_size=args.valid_batch_size,
            train_num_workers=args.train_num_workers,
            valid_num_workers=args.valid_num_workers,
            random_state=RANDOM_STATE,
            train_transforms=train_tr,
            precompute_data=args.precompute_data)
    else:
        datamodule = TrainTestLandmarkDataModule(
            path_to_dir=args.data_dir,
            annot_file=args.annot_file,
            ignore_train_images=ignore_train_images_list,
            train_batch_size=args.train_batch_size,
            val_batch_size=args.valid_batch_size,
            train_num_workers=args.train_num_workers,
            valid_num_workers=args.valid_num_workers,
            random_state=RANDOM_STATE,
            train_size=train_params.train_size,
            precompute_data=args.precompute_data,
            train_transforms=train_tr,
            val_transforms=valid_tr)

    model = get_model(train_params.num_landmarks, train_params.dropout_prob,
                      train_params.train_backbone)

    opt_params = OptimizerParams()
    scheduler_params = SchedulerPrams()

    target_metric_name = "MSE loss"

    train_module = ModelTrain(
        model=model,
        optimizer_params=opt_params,
        scheduler_params=scheduler_params,
        train_backbone_after_epoch=train_params.train_full_model_after_epoch,
        target_metric_name=target_metric_name,
        save_img_every_train_batch=100)

    checkpoint_dir = exp_dir / "checkpoint"
    checkpoint_dir.mkdir(exist_ok=True, parents=True)

    checkpoint_callback = callbacks.ModelCheckpoint(
        monitor=target_metric_name + '_epoch',
        dirpath=checkpoint_dir,
        filename=f"{{epoch}}-{{{target_metric_name}:.4f}}",
        verbose=True,
        save_last=True,
        save_top_k=2,
        mode="min",
        save_weights_only=False)

    lr_monitor = callbacks.LearningRateMonitor(logging_interval='step')

    log_dir = exp_dir / "logs"
    log_dir.mkdir(exist_ok=True, parents=True)

    logger = TensorBoardLogger(str(log_dir))

    gpus = -1 if torch.cuda.is_available() else None

    if gpus is None:
        logging.getLogger().warning(
            "GPU is not available; falling back to CPU training, which may be "
            "very slow.")

    trainer = pl.Trainer(
        amp_backend="native",
        auto_scale_batch_size="binsearch",
        gpus=gpus,
        logger=logger,
        auto_select_gpus=True,
        benchmark=True,
        check_val_every_n_epoch=train_params.check_val_every_n_epoch,
        flush_logs_every_n_steps=train_params.flush_logs_every_n_steps,
        default_root_dir=str(exp_dir),
        deterministic=False,
        fast_dev_run=args.fast_dev_run,
        progress_bar_refresh_rate=10,
        precision=train_params.precision,
        max_epochs=train_params.max_epochs,
        callbacks=[checkpoint_callback, lr_monitor])

    trainer.fit(train_module, datamodule=datamodule)
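
A minimal sketch of the module-level context this snippet assumes; every name and default below is an assumption inferred from the attribute accesses above:

import logging
import pathlib
from dataclasses import dataclass

import torch
import pytorch_lightning as pl
from pytorch_lightning import callbacks
from pytorch_lightning.loggers import TensorBoardLogger

RANDOM_STATE = 42  # hypothetical global seed

@dataclass
class TrainParams:
    # Hypothetical defaults; only fields read by main() are listed.
    img_size_in_batch: int = 256
    train_size: float = 1.0
    num_landmarks: int = 68
    dropout_prob: float = 0.2
    train_backbone: bool = False
    train_full_model_after_epoch: int = 5
    check_val_every_n_epoch: int = 1
    flush_logs_every_n_steps: int = 100
    precision: int = 16
    max_epochs: int = 100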