def main():
    logger.critical('Loading the word embedding.')
    vocab, word_embeddings = load_word_embedding(args.vse)

    logger.critical('Building up the model.')
    model = CompletionModel(word_embeddings)
    if args.use_gpu:
        model.cuda()

    # Disable the cudnn benchmark.
    model.eval()
    cudnn.benchmark = False

    logger.critical('Loading the dataset.')
    dev_dataset = CompletionDataset(
        vocab,
        pjoin(args.data_dir, args.dev_img), pjoin(args.data_dir, args.dev_cap),
        mode=args.mode)
    test_dataset = CompletionDataset(
        vocab,
        pjoin(args.data_dir, args.test_img), pjoin(args.data_dir, args.test_cap),
        mode=args.mode)

    logger.critical('Building up the data loader.')
    dev_dataloader = make_dataloader(
        dev_dataset, num_workers=args.data_workers, batch_size=64,
        shuffle=False, drop_last=False, pin_memory=True)
    test_dataloader = make_dataloader(
        test_dataset, num_workers=args.data_workers, batch_size=64,
        shuffle=False, drop_last=False, pin_memory=True)

    for epoch_id in range(1, 11):
        load_weights(model, pjoin(args.load, 'epoch_{}.pth'.format(epoch_id)))

        for loader in [dev_dataloader, test_dataloader]:
            meters = GroupMeters()

            end = time.time()
            with tqdm_pbar(total=len(loader), leave=False) as pbar:
                for i, data in enumerate(loader):
                    feed_dict = mark_volatile(data)
                    if args.use_gpu:
                        feed_dict = async_copy_to(feed_dict, 0)
                    data_time = time.time() - end
                    end = time.time()

                    output_dict = model(feed_dict)
                    output_dict = as_numpy(output_dict)
                    gpu_time = time.time() - end
                    end = time.time()

                    meters.update(
                        {k: float(v) for k, v in output_dict.items() if k.startswith('top')},
                        n=len(feed_dict['image']))
                    meters.update({'time/data': data_time, 'time/gpu': gpu_time})

                    pbar.set_description(format_meters(
                        'sentid={}'.format(i), meters.val, '{}={:.4f}', ', '))
                    pbar.update()
                    end = time.time()

            print(epoch_id, sorted(meters.avg.items()))
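# NOTE: `make_dataloader` is used above but not defined in these snippets. The
# sketch below is only an assumption about its behavior -- a thin wrapper around
# torch.utils.data.DataLoader that forwards the keyword arguments used by the
# evaluation script -- not the project's actual implementation.
from torch.utils.data import DataLoader


def make_dataloader(dataset, num_workers, batch_size, shuffle, drop_last, pin_memory):
    # Forward the standard options to the PyTorch DataLoader; a custom
    # collate function for the dataset, if any, would also be passed here.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=pin_memory)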
def main():
    args.dump_dir = ensure_path(osp.join('dumps', args.dataset_name, args.desc_name, args.expr))
    args.ckpt_dir = ensure_path(osp.join(args.dump_dir, 'checkpoints'))
    args.meta_dir = ensure_path(osp.join(args.dump_dir, 'meta'))
    args.vis_dir = osp.join(args.dump_dir, 'vis', args.run_name)

    initialize_dataset(args.dataset)
    build_dataset = get_dataset_builder(args.dataset)
    dataset = build_dataset(args, configs, args.data_image_root, args.data_scenes_json, args.data_questions_json)

    dataset_split = int(len(dataset) * args.data_split) if args.data_split <= 1 else int(args.data_split)
    train_dataset, validation_dataset = dataset.split_trainval(dataset_split)

    logger.critical('Building the model.')
    model = desc.make_model(args, train_dataset.unwrapped.vocab)

    if args.use_gpu:
        model.cuda()
        # Use the customized data parallel if applicable.
        if args.gpu_parallel:
            from jactorch.parallel import JacDataParallel
            # from jactorch.parallel import UserScatteredJacDataParallel as JacDataParallel
            model = JacDataParallel(model, device_ids=args.gpus).cuda()
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    if args.load:
        from jactorch.io import load_weights
        if load_weights(model, args.load):
            logger.critical('Loaded weights from pretrained model: "{}".'.format(args.load))

    from jacinle.utils.meter import GroupMeters
    meters = GroupMeters()

    if args.embed:
        from IPython import embed
        embed()

    logger.critical('Building the data loader.')
    validation_dataloader = validation_dataset.make_dataloader(
        args.batch_size, shuffle=True, drop_last=False, nr_workers=args.data_workers)

    model.eval()
    validate_epoch(0, model, validation_dataloader, meters)
    logger.critical(meters.format_simple(
        'Validation',
        {k: v for k, v in meters.avg.items() if v != 0},
        compressed=False))
    return meters
def load_weights(self, filename, **kwargs):
    return load_weights(self._model, filename, **kwargs)
def main():
    # directories
    if not args.debug:
        args.dump_dir = ensure_path(osp.join('dumps', args.series_name, args.desc_name))
        args.ckpt_dir = ensure_path(osp.join(args.dump_dir, 'checkpoints'))
        args.meta_dir = ensure_path(osp.join(args.dump_dir, 'meta'))
        args.meta_file = osp.join(args.meta_dir, args.run_name + '.json')
        args.log_file = osp.join(args.meta_dir, args.run_name + '.log')
        args.meter_file = osp.join(args.meta_dir, args.run_name + '.meter.json')

    if not args.debug:
        logger.critical('Writing logs to file: "{}".'.format(args.log_file))
        set_output_file(args.log_file)

        logger.critical('Writing metainfo to file: "{}".'.format(args.meta_file))
        with open(args.meta_file, 'w') as f:
            f.write(dump_metainfo(args=args.__dict__, configs=configs))
    else:
        if args.use_tb:
            logger.warning('Disabling the tensorboard in the debug mode.')
            args.use_tb = False

    # TODO(Jiayuan Mao @ 04/23): load the dataset.
    logger.critical('Loading the dataset.')
    validation_dataset = None
    # configs.validate_dataset_compatibility(train_dataset)

    # TODO(Jiayuan Mao @ 04/23): build the model.
    logger.critical('Building the model.')
    model = desc.make_model(args)

    if args.use_gpu:
        model.cuda()
        # Use the customized data parallel if applicable.
        if args.gpu_parallel:
            from jactorch.parallel import JacDataParallel
            # from jactorch.parallel import UserScatteredJacDataParallel as JacDataParallel
            model = JacDataParallel(model, device_ids=args.gpus).cuda()
        # Disable the cudnn benchmark.
        cudnn.benchmark = False

    if load_weights(model, args.load):
        logger.critical('Loaded weights from pretrained model: "{}".'.format(args.load))

    if args.use_tb:
        from jactorch.train.tb import TBLogger, TBGroupMeters
        tb_logger = TBLogger(args.tb_dir)
        meters = TBGroupMeters(tb_logger)
        logger.critical('Writing tensorboard logs to: "{}".'.format(args.tb_dir))
    else:
        from jacinle.utils.meter import GroupMeters
        meters = GroupMeters()

    if not args.debug:
        logger.critical('Writing meter logs to file: "{}".'.format(args.meter_file))

    if args.embed:
        from IPython import embed
        embed()

    # TODO(Jiayuan Mao @ 04/23): make the data loader.
    logger.critical('Building the data loader.')
    validation_dataloader = validation_dataset.make_dataloader(
        args.batch_size, shuffle=False, drop_last=False, nr_workers=args.data_workers)

    model.eval()
    validate_epoch(model, validation_dataloader, meters)

    if not args.debug:
        meters.dump(args.meter_file)

    logger.critical(meters.format_simple('Test', compressed=False))
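# NOTE: `validate_epoch` is called by the scripts above but not defined in these
# snippets. The sketch below is an assumption about what it does, modeled on the
# evaluation loop of the completion script: it feeds every batch through the
# model and accumulates the returned scalar monitors into `meters`. It reuses
# the helpers already used above (tqdm_pbar, async_copy_to, as_numpy,
# format_meters, args), and assumes the model returns a dict of scalar values
# in eval mode.
def validate_epoch(model, dataloader, meters):
    end = time.time()
    with tqdm_pbar(total=len(dataloader), leave=False) as pbar:
        for feed_dict in dataloader:
            if args.use_gpu:
                feed_dict = async_copy_to(feed_dict, 0)
            data_time = time.time() - end
            end = time.time()

            output_dict = as_numpy(model(feed_dict))
            gpu_time = time.time() - end
            end = time.time()

            # Accumulate the monitors returned by the model, plus timing stats.
            meters.update({k: float(v) for k, v in output_dict.items()})
            meters.update({'time/data': data_time, 'time/gpu': gpu_time})

            pbar.set_description(format_meters('validation', meters.val, '{}={:.4f}', ', '))
            pbar.update()
            end = time.time()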
def load_weights(self, filename):
    return load_weights(self._model, filename)