def main_train(validation_dataset):
    """Build the model, run attribute validation on a dataset, and log the results.

    Despite its name, this function performs no training: it constructs the
    model described by the module-level ``args``, optionally restores
    pretrained weights, switches the model to evaluation mode and runs
    ``validate_attribute`` over ``validation_dataset``.

    Args:
        validation_dataset: dataset object exposing
            ``make_dataloader(batch_size, shuffle=..., drop_last=...,
            nr_workers=...)``.

    Returns:
        The ``GroupMeters`` instance that was passed to ``validate_attribute``.
    """
    # Imported locally, as in the original code; it must be imported
    # unconditionally because the meters are needed on every code path.
    from jacinle.utils.meter import GroupMeters

    logger.critical('Building the model.')
    model = desc.make_model(args)

    if args.use_gpu:
        model.cuda()
        # Use the customized data parallel if applicable.
        if args.gpu_parallel:
            from jactorch.parallel import JacDataParallel
            # from jactorch.parallel import UserScatteredJacDataParallel as JacDataParallel
            model = JacDataParallel(model, device_ids=args.gpus).cuda()
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    # NOTE(review): the second argument is presumably the optimizer, which is
    # not needed for evaluation -- confirm against TrainerEnv's signature.
    trainer = TrainerEnv(model, None)

    if args.load:
        if trainer.load_weights(args.load):
            logger.critical(
                'Loaded weights from pretrained model: "{}".'.format(
                    args.load))

    # Previously created only inside the ``if args.load:`` branch, which made
    # every later use of ``meters`` fail with UnboundLocalError when no
    # checkpoint was given.
    meters = GroupMeters()

    logger.critical('Building the data loader.')
    validation_dataloader = validation_dataset.make_dataloader(
        args.batch_size,
        shuffle=False,
        drop_last=False,
        nr_workers=args.data_workers)

    meters.reset()
    model.eval()

    # exist_ok avoids the check-then-create race of a separate isdir() test.
    os.makedirs(args.output_attr_path, exist_ok=True)
    validate_attribute(model, validation_dataloader, meters, args.setname,
                       logger, args.output_attr_path)

    # Only report meters whose average is non-zero.
    nonzero_averages = {k: v for k, v in meters.avg.items() if v != 0}
    logger.critical(
        meters.format_simple(args.setname, nonzero_averages,
                             compressed=False))
    return meters
Esempio n. 2
0
    def train(self,
              data_loader,
              nr_epochs,
              verbose=True,
              meters=None,
              early_stop=None,
              print_interval=1):
        """Run ``train_epoch`` for up to ``nr_epochs`` epochs.

        Args:
            data_loader: forwarded unchanged to ``self.train_epoch``.
            nr_epochs: maximum number of epochs to run (epochs are numbered
                starting from 1).
            verbose: when true, log the meters summary every
                ``print_interval`` epochs.
            meters: meter group to reset before, and fill during, each epoch;
                a fresh ``GroupMeters`` is created when omitted.
            early_stop: optional callable invoked with ``self._model`` after
                each epoch; a truthy result ends training immediately.
            print_interval: logging period in epochs; only consulted when
                ``verbose`` is true.
        """
        meters = GroupMeters() if meters is None else meters
        has_early_stop = early_stop is not None

        for epoch_index in range(nr_epochs):
            epoch_no = epoch_index + 1

            # Each epoch starts from clean meters so the summary is per-epoch.
            meters.reset()
            self.train_epoch(data_loader, meters=meters)

            if verbose and epoch_no % print_interval == 0:
                logger.info(
                    meters.format_simple('Epoch: {}:'.format(epoch_no)))

            # The stop criterion is evaluated once per completed epoch.
            if has_early_stop and early_stop(self._model):
                break