def main_train(validation_dataset):
    """Build the model, optionally load weights, and run attribute validation.

    Despite its name, this function performs no training: the trainer is
    created without an optimizer and the only work done on the data is a
    single ``validate_attribute`` pass over ``validation_dataset``.
    All configuration is read from the module-level ``args`` namespace
    (``use_gpu``, ``gpu_parallel``, ``gpus``, ``load``, ``batch_size``,
    ``data_workers``, ``output_attr_path``, ``setname``).

    Args:
        validation_dataset: object exposing ``make_dataloader(batch_size,
            shuffle=..., drop_last=..., nr_workers=...)``.

    Returns:
        The ``GroupMeters`` instance passed to ``validate_attribute``.
    """
    logger.critical('Building the model.')
    model = desc.make_model(args)

    if args.use_gpu:
        model.cuda()
        # Use the customized data parallel if applicable.
        if args.gpu_parallel:
            from jactorch.parallel import JacDataParallel
            model = JacDataParallel(model, device_ids=args.gpus).cuda()
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    # No optimizer: the trainer is only used here to load weights.
    trainer = TrainerEnv(model, None)
    if args.load:
        if trainer.load_weights(args.load):
            logger.critical(
                'Loaded weights from pretrained model: "{}".'.format(
                    args.load))
        else:
            # Do not silently evaluate a model whose requested weights
            # were not loaded.
            logger.warning(
                'Failed to load weights from "{}"; continuing with the '
                'weights the model was built with.'.format(args.load))

    from jacinle.utils.meter import GroupMeters
    meters = GroupMeters()

    logger.critical('Building the data loader.')
    validation_dataloader = validation_dataset.make_dataloader(
        args.batch_size, shuffle=False, drop_last=False,
        nr_workers=args.data_workers)

    meters.reset()
    model.eval()
    # exist_ok=True avoids the race of a separate isdir() check; a
    # non-directory at this path still raises FileExistsError.
    os.makedirs(args.output_attr_path, exist_ok=True)
    validate_attribute(model, validation_dataloader, meters, args.setname,
                       logger, args.output_attr_path)

    # Report only meters whose average is non-zero.
    logger.critical(
        meters.format_simple(
            args.setname,
            {k: v for k, v in meters.avg.items() if v != 0},
            compressed=False))
    return meters
def train(self, data_loader, nr_epochs, verbose=True, meters=None, early_stop=None, print_interval=1):
    """Train for up to ``nr_epochs`` epochs (numbered from 1).

    Each epoch resets the meters, then calls
    ``self.train_epoch(data_loader, meters=...)``. When ``verbose`` is true,
    the meters are logged every ``print_interval`` epochs. Because of the
    modulo, ``print_interval=0`` raises ZeroDivisionError when ``verbose``
    is true. If ``early_stop`` is given, it is called with ``self._model``
    after every epoch, and a truthy result ends training early.

    Args:
        data_loader: passed unchanged to ``self.train_epoch``.
        nr_epochs: maximum number of epochs to run.
        verbose: whether to log per-epoch meter summaries.
        meters: meter group to reuse; a fresh ``GroupMeters`` when None.
        early_stop: optional callable ``f(model) -> bool``.
        print_interval: log every this many epochs.

    Returns:
        None.
    """
    tracker = GroupMeters() if meters is None else meters

    for epoch in range(1, nr_epochs + 1):
        # Reset first so the logged summary covers only this epoch.
        tracker.reset()
        self.train_epoch(data_loader, meters=tracker)

        if verbose and epoch % print_interval == 0:
            logger.info(tracker.format_simple('Epoch: {}:'.format(epoch)))

        if early_stop is not None and early_stop(self._model):
            break