Example 1
    def test_from_name_file_model(self):
        # test that loading works even if they differ by a prefix
        for trained_model, fresh_model in [
            (self.create_model(), self.create_model()),
            (nn.DataParallel(self.create_model()), self.create_model()),
            (self.create_model(), nn.DataParallel(self.create_model())),
            (
                nn.DataParallel(self.create_model()),
                nn.DataParallel(self.create_model()),
            ),
        ]:
            with TemporaryDirectory() as f:
                checkpointer = Checkpointer(
                    trained_model, save_dir=f, save_to_disk=True
                )
                checkpointer.save("checkpoint_file")

                # on different folders
                with TemporaryDirectory() as g:
                    fresh_checkpointer = Checkpointer(fresh_model, save_dir=g)
                    self.assertFalse(fresh_checkpointer.has_checkpoint())
                    self.assertEqual(fresh_checkpointer.get_checkpoint_file(), "")
                    _ = fresh_checkpointer.load(os.path.join(f, "checkpoint_file.pth"))

            for trained_p, loaded_p in zip(
                trained_model.parameters(), fresh_model.parameters()
            ):
                # different tensor references
                self.assertFalse(id(trained_p) == id(loaded_p))
                # same content
                self.assertTrue(trained_p.equal(loaded_p))
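
A note on Example 1: the test checks that Checkpointer handles the "module." prefix that nn.DataParallel adds to parameter names, so a model saved with the wrapper loads cleanly into an unwrapped one and vice versa. Below is a minimal sketch of that round trip outside the test harness; the import path and the make_model helper are assumptions, and only the Checkpointer calls shown above are relied on.

import os
from tempfile import TemporaryDirectory

import torch.nn as nn
from fvcore.common.checkpoint import Checkpointer  # assumed import path

def make_model():
    # placeholder model; any nn.Module behaves the same way
    return nn.Linear(4, 2)

with TemporaryDirectory() as save_dir:
    trained = nn.DataParallel(make_model())
    Checkpointer(trained, save_dir=save_dir, save_to_disk=True).save("checkpoint_file")

    # load into a fresh, unwrapped model; the "module." prefix mismatch is handled by Checkpointer
    fresh = make_model()
    Checkpointer(fresh, save_dir=save_dir).load(os.path.join(save_dir, "checkpoint_file.pth"))
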
Example 2
def train(cfg, output_dir='', output_dir_merge='', output_dir_refine=''):
    logger = logging.getLogger('shaper.train')

    # build model
    set_random_seed(cfg.RNG_SEED)

    model_merge = nn.DataParallel(PointNetCls(in_channels=3,
                                              out_channels=128)).cuda()

    # build optimizer
    cfg['SCHEDULER']['StepLR']['step_size'] = 150
    cfg['SCHEDULER']['MAX_EPOCH'] = 20000
    optimizer_embed = build_optimizer(cfg, model_merge)

    # build lr scheduler
    scheduler_embed = build_scheduler(cfg, optimizer_embed)
    checkpointer_embed = Checkpointer(model_merge,
                                      optimizer=optimizer_embed,
                                      scheduler=scheduler_embed,
                                      save_dir=output_dir_merge,
                                      logger=logger)
    checkpoint_data_embed = checkpointer_embed.load(
        cfg.MODEL.WEIGHT,
        resume=cfg.AUTO_RESUME,
        resume_states=cfg.RESUME_STATES)

    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD

    # build data loader
    # Reset the random seed again in case the initialization of models changes the random state.
    set_random_seed(cfg.RNG_SEED)

    # build tensorboard logger (optional; comment out to disable)
    tensorboard_logger = TensorboardLogger(output_dir_merge)

    # train
    max_epoch = cfg.SCHEDULER.MAX_EPOCH
    start_epoch = checkpoint_data_embed.get('epoch', 0)
    best_metric_name = 'best_{}'.format(cfg.TRAIN.VAL_METRIC)
    best_metric = checkpoint_data_embed.get(best_metric_name, None)
    logger.info('Start training from epoch {}'.format(start_epoch))
    for epoch in range(start_epoch, max_epoch):
        cur_epoch = epoch + 1
        scheduler_embed.step()
        start_time = time.time()
        train_meters = train_one_epoch(
            model_merge,
            cur_epoch,
            optimizer_embed=optimizer_embed,
            output_dir_merge=output_dir_merge,
            max_grad_norm=cfg.OPTIMIZER.MAX_GRAD_NORM,
            freezer=None,
            log_period=cfg.TRAIN.LOG_PERIOD,
        )
        epoch_time = time.time() - start_time
        logger.info('Epoch[{}]-Train {}  total_time: {:.2f}s'.format(
            cur_epoch, train_meters.summary_str, epoch_time))

        tensorboard_logger.add_scalars(train_meters.meters,
                                       cur_epoch,
                                       prefix='train')

        # checkpoint
        if (ckpt_period > 0
                and cur_epoch % ckpt_period == 0) or cur_epoch == max_epoch:
            checkpoint_data_embed['epoch'] = cur_epoch
            checkpoint_data_embed[best_metric_name] = best_metric
            checkpointer_embed.save('model_{:03d}'.format(cur_epoch),
                                    **checkpoint_data_embed)

    return model_merge
Example 3
def train(cfg, output_dir=''):
    logger = logging.getLogger('shaper.train')

    # build model
    set_random_seed(cfg.RNG_SEED)
    model, loss_fn, metric = build_model(cfg)
    logger.info('Build model:\n{}'.format(str(model)))
    model = nn.DataParallel(model).cuda()
    # model = model.cuda()

    # build optimizer
    optimizer = build_optimizer(cfg, model)

    # build lr scheduler
    scheduler = build_scheduler(cfg, optimizer)

    # build checkpointer
    # Note that checkpointer will load state_dict of model, optimizer and scheduler.
    checkpointer = Checkpointer(model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                save_dir=output_dir,
                                logger=logger)
    checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT,
                                        resume=cfg.AUTO_RESUME,
                                        resume_states=cfg.RESUME_STATES)
    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD

    # build freezer
    if cfg.TRAIN.FROZEN_PATTERNS:
        freezer = Freezer(model, cfg.TRAIN.FROZEN_PATTERNS)
        freezer.freeze(verbose=True)  # sanity check
    else:
        freezer = None

    # build data loader
    # Reset the random seed again in case the initialization of models changes the random state.
    set_random_seed(cfg.RNG_SEED)
    train_dataloader = build_dataloader(cfg, mode='train')
    val_period = cfg.TRAIN.VAL_PERIOD
    val_dataloader = build_dataloader(cfg,
                                      mode='val') if val_period > 0 else None

    # build tensorboard logger (optional; comment out to disable)
    tensorboard_logger = TensorboardLogger(output_dir)

    # train
    max_epoch = cfg.SCHEDULER.MAX_EPOCH
    start_epoch = checkpoint_data.get('epoch', 0)
    best_metric_name = 'best_{}'.format(cfg.TRAIN.VAL_METRIC)
    best_metric = checkpoint_data.get(best_metric_name, None)
    logger.info('Start training from epoch {}'.format(start_epoch))
    for epoch in range(start_epoch, max_epoch):
        cur_epoch = epoch + 1
        scheduler.step()
        start_time = time.time()
        train_meters = train_one_epoch(
            model,
            loss_fn,
            metric,
            train_dataloader,
            optimizer=optimizer,
            max_grad_norm=cfg.OPTIMIZER.MAX_GRAD_NORM,
            freezer=freezer,
            log_period=cfg.TRAIN.LOG_PERIOD,
        )
        epoch_time = time.time() - start_time
        logger.info('Epoch[{}]-Train {}  total_time: {:.2f}s'.format(
            cur_epoch, train_meters.summary_str, epoch_time))

        tensorboard_logger.add_scalars(train_meters.meters,
                                       cur_epoch,
                                       prefix='train')

        # checkpoint
        if (ckpt_period > 0
                and cur_epoch % ckpt_period == 0) or cur_epoch == max_epoch:
            checkpoint_data['epoch'] = cur_epoch
            checkpoint_data[best_metric_name] = best_metric
            checkpointer.save('model_{:03d}'.format(cur_epoch),
                              **checkpoint_data)

        # validate
        if val_period > 0 and (cur_epoch % val_period == 0
                               or cur_epoch == max_epoch):
            start_time = time.time()
            val_meters = validate(
                model,
                loss_fn,
                metric,
                val_dataloader,
                log_period=cfg.TEST.LOG_PERIOD,
            )
            epoch_time = time.time() - start_time
            logger.info('Epoch[{}]-Val {}  total_time: {:.2f}s'.format(
                cur_epoch, val_meters.summary_str, epoch_time))

            tensorboard_logger.add_scalars(val_meters.meters,
                                           cur_epoch,
                                           prefix='val')

            # best validation
            if cfg.TRAIN.VAL_METRIC in val_meters.meters:
                cur_metric = val_meters.meters[cfg.TRAIN.VAL_METRIC].global_avg
                if best_metric is None or cur_metric > best_metric:
                    best_metric = cur_metric
                    checkpoint_data['epoch'] = cur_epoch
                    checkpoint_data[best_metric_name] = best_metric
                    checkpointer.save('model_best', **checkpoint_data)

    logger.info('Best val-{} = {}'.format(cfg.TRAIN.VAL_METRIC, best_metric))
    return model
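
The resume cycle shared by Examples 2 and 3, condensed into a sketch. It reuses the names from the examples above (Checkpointer, the build_* results, and the cfg fields are assumed to exist exactly as shown there); it illustrates the pattern rather than replacing either train() function.

# Construct the checkpointer with optimizer and scheduler so their state dicts
# are saved and restored together with the model's.
checkpointer = Checkpointer(model,
                            optimizer=optimizer,
                            scheduler=scheduler,
                            save_dir=output_dir,
                            logger=logger)

# load() restores the states and returns any extra key-value pairs that were
# passed to save(), e.g. 'epoch' and the best validation metric.
checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT,
                                    resume=cfg.AUTO_RESUME,
                                    resume_states=cfg.RESUME_STATES)
start_epoch = checkpoint_data.get('epoch', 0)

for epoch in range(start_epoch, cfg.SCHEDULER.MAX_EPOCH):
    cur_epoch = epoch + 1
    # ... one epoch of training ...
    if cfg.TRAIN.CHECKPOINT_PERIOD > 0 and cur_epoch % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
        checkpoint_data['epoch'] = cur_epoch
        checkpointer.save('model_{:03d}'.format(cur_epoch), **checkpoint_data)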