def vis_3d(cfg, genlen=50, cond_steps=10):
    indices = cfg.vis.indices

    print('\nLoading data...')
    dataset = get_dataset(cfg, cfg.val.mode)
    dataset = Subset(dataset, indices=list(indices))
    # dataloader = get_dataloader(cfg, cfg.val.mode)
    dataloader = DataLoader(dataset, batch_size=cfg.val.batch_size, num_workers=4)
    print('Data loaded.')

    print('Initializing model...')
    model = get_model(cfg)
    model = model.to(cfg.device)
    print('Model initialized.')

    checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name), max_num=cfg.train.max_ckpt)
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, None)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    vislogger = get_vislogger(cfg)
    resultdir = os.path.join(cfg.resultdir, cfg.exp_name)
    os.makedirs(resultdir, exist_ok=True)
    path = osp.join(resultdir, '3d.gif')
    vislogger.show_gif(model, dataset, indices, cfg.device, cond_steps, genlen, path, fps=7)

def show(cfg):
    assert cfg.resume, 'You must pass "resume True" if --task is "show"'
    print('Experiment name:', cfg.exp_name)
    print('Dataset:', cfg.dataset)
    print('Model name:', cfg.model)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:', cfg.resume_ckpt if cfg.resume_ckpt else 'see below')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    print('Loading data')
    dataset = get_dataset(cfg, cfg.show.mode)
    model = get_model(cfg)
    model = model.to(cfg.device)
    checkpointer = Checkpointer(osp.join(cfg.checkpointdir, cfg.exp_name), max_num=cfg.train.max_ckpt)
    vis_logger = get_vislogger(cfg)

    model.eval()
    use_cpu = 'cpu' in cfg.device
    if cfg.resume_ckpt:
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, None, None, use_cpu=use_cpu)
    else:
        # Load the last checkpoint
        checkpoint = checkpointer.load_last('', model, None, None, use_cpu=use_cpu)
    if cfg.parallel:
        assert 'cpu' not in cfg.device, 'You cannot set "parallel" to True when "device" is cpu'
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    os.makedirs(cfg.demodir, exist_ok=True)
    img_path = osp.join(cfg.demodir, '{}.png'.format(cfg.exp_name))
    vis_logger.show_vis(model, dataset, cfg.show.indices, img_path, device=cfg.device)
    print('The result image has been saved to {}'.format(osp.abspath(img_path)))

def run_train(args):
    device = torch.device(args.device)
    model = build_model(args.model_name, args.num_classes, args.pretrained)
    model = model.to(device)

    # Build checkpointer, optimizer, scheduler and logger
    optimizer = build_optimizer(args, model)
    scheduler = build_lr_scheduler(args, optimizer)
    checkpointer = Checkpointer(model, optimizer, scheduler, args.experiment, args.checkpoint_period)
    logger = Logger(os.path.join(args.experiment, 'tf_log'))

    # Data loaders
    train_loader = CIFAR10_loader(args, is_train=True)
    test_loader = CIFAR10_loader(args, is_train=False)

    # Evaluate once before training to initialize the best accuracy
    acc1, _ = inference(model, test_loader, logger, device, 0, args)
    checkpointer.best_acc = acc1

    for epoch in tqdm(range(0, args.max_epoch)):
        train_epoch(model, train_loader, optimizer, len(train_loader) * epoch, checkpointer, device, logger)
        acc1, m_acc1 = inference(model, test_loader, logger, device, epoch + 1, args)
        if acc1 > checkpointer.best_acc:
            checkpointer.save("model_best")
            checkpointer.best_acc = acc1
        scheduler.step()

    checkpointer.save("model_last")

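# The CIFAR10_loader used above is defined elsewhere in the repo. For context
# only, a minimal torchvision-based loader might look like the sketch below;
# the function name `cifar10_loader_sketch`, the `data_root` argument and the
# transform pipeline (normalization omitted for brevity) are assumptions, not
# the repo's actual implementation.
def cifar10_loader_sketch(args, is_train):
    import torchvision.transforms as T
    from torchvision.datasets import CIFAR10
    from torch.utils.data import DataLoader

    if is_train:
        # Standard CIFAR-10 augmentation: random crop with padding and horizontal flip.
        transform = T.Compose([T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(), T.ToTensor()])
    else:
        transform = T.Compose([T.ToTensor()])
    dataset = CIFAR10(root=args.data_root, train=is_train, download=True, transform=transform)
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=is_train, num_workers=4)
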
def run_train(args):
    device = torch.device(args.device)

    # Build the student
    student = build_model(args.student, args.num_classes, args.pretrained)
    student = student.to(device)
    # Build the teachers
    teachers = build_teachers(args, device)

    # Build checkpointer, optimizer, scheduler and logger
    optimizer = build_optimizer(args, student)
    scheduler = build_lr_scheduler(args, optimizer)
    checkpointer = Checkpointer(student, optimizer, scheduler, args.experiment, args.checkpoint_period)
    logger = Logger(os.path.join(args.experiment, 'tf_log'))

    # Objective function used to train the student (knowledge distillation)
    loss_fn = loss_fn_kd

    # Data loaders
    train_loader = CIFAR10_loader(args, is_train=True)
    test_loader = CIFAR10_loader(args, is_train=False)

    # Evaluate the untrained student once to initialize the best accuracy
    acc1, m_acc1 = inference(student, test_loader, logger, device, 0, args)
    checkpointer.best_acc = acc1

    for epoch in tqdm(range(0, args.max_epoch)):
        do_train(student, teachers, loss_fn, train_loader, optimizer, checkpointer, device, logger, epoch)
        acc1, m_acc1 = inference(student, test_loader, logger, device, epoch + 1, args)
        if acc1 > checkpointer.best_acc:
            checkpointer.save("model_best")
            checkpointer.best_acc = acc1
        scheduler.step()

    checkpointer.save("model_last")

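# loss_fn_kd is imported from elsewhere in the repo. For reference, the
# standard knowledge-distillation objective (Hinton et al.) combines a
# temperature-scaled KL term against the teacher's logits with ordinary
# cross-entropy on the labels. The sketch below is that textbook formulation,
# not necessarily the exact loss used here; `temperature` and `alpha` are
# assumed hyperparameters.
def loss_fn_kd_sketch(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.9):
    import torch.nn.functional as F

    # Soft targets: KL divergence between softened student and teacher distributions.
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction='batchmean',
    ) * (temperature ** 2)
    # Hard targets: ordinary cross-entropy on the ground-truth labels.
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1.0 - alpha) * hard_loss
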
def eval_maze(cfg, cond_steps=5):
    print('\nLoading data...')
    assert cfg.val.mode == 'test', 'Please set cfg.val.mode to "test"'
    dataset = get_dataset(cfg, cfg.val.mode)
    dataloader = get_dataloader(cfg, cfg.val.mode)
    print('Data loaded.')

    print('Initializing model...')
    model = get_model(cfg)
    model = model.to(cfg.device)
    print('Model initialized.')

    checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name), max_num=cfg.train.max_ckpt)
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, None)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    evaluator = get_evaluator(cfg)
    evaldir = os.path.join(cfg.evaldir, cfg.exp_name)
    os.makedirs(evaldir, exist_ok=True)

    start = time.perf_counter()
    model.eval()
    evaluator.evaluate(model, dataloader, cond_steps, cfg.device, evaldir, cfg.exp_name, cfg.resume_ckpt)

    file_name = 'maze-{}.json'.format(cfg.exp_name)
    jsonpath = os.path.join(evaldir, file_name)
    with open(jsonpath) as f:
        metrics = json.load(f)
    num_mean = metrics['num_mean']

    f, ax = plt.subplots()
    ax: plt.Axes
    ax.plot(num_mean)
    ax.set_ylim(0, 3.5)
    ax.set_xlabel('Time step')
    ax.set_ylabel('#Agents')
    ax.set_title(cfg.exp_name)
    plt.savefig(os.path.join(evaldir, 'plot_maze.png'))

def eval(cfg):
    assert cfg.resume
    assert cfg.eval.checkpoint in ['best', 'last']
    assert cfg.eval.metric in ['ap_dot5', 'ap_avg']

    print('Experiment name:', cfg.exp_name)
    print('Dataset:', cfg.dataset)
    print('Model name:', cfg.model)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:', cfg.resume_ckpt if cfg.resume_ckpt else 'see below')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    print('Loading data')
    testset = get_dataset(cfg, 'test')
    model = get_model(cfg)
    model = model.to(cfg.device)
    checkpointer = Checkpointer(osp.join(cfg.checkpointdir, cfg.exp_name), max_num=cfg.train.max_ckpt)
    evaluator = get_evaluator(cfg)

    model.eval()
    use_cpu = 'cpu' in cfg.device
    if cfg.resume_ckpt:
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, None, None, use_cpu)
    elif cfg.eval.checkpoint == 'last':
        checkpoint = checkpointer.load_last('', model, None, None, use_cpu)
    elif cfg.eval.checkpoint == 'best':
        checkpoint = checkpointer.load_best(cfg.eval.metric, model, None, None, use_cpu)
    if cfg.parallel:
        assert 'cpu' not in cfg.device
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    evaldir = osp.join(cfg.evaldir, cfg.exp_name)
    info = {'exp_name': cfg.exp_name}
    evaluator.test_eval(model, testset, testset.bb_path, cfg.device, evaldir, info)

def train(cfg):
    print('Experiment name:', cfg.exp_name)
    print('Dataset:', cfg.dataset)
    print('Model name:', cfg.model)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:', cfg.resume_ckpt if cfg.resume_ckpt else 'last checkpoint')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    print('Loading data')
    if cfg.exp_name == 'table':
        data_set = np.load('{}/train/all_set_train.npy'.format(cfg.dataset_roots.TABLE))
        data_size = len(data_set)
    else:
        trainloader = get_dataloader(cfg, 'train')
        data_size = len(trainloader)
    if cfg.train.eval_on:
        valset = get_dataset(cfg, 'val')
        # valloader = get_dataloader(cfg, 'val')
        evaluator = get_evaluator(cfg)

    model = get_model(cfg)
    model = model.to(cfg.device)
    checkpointer = Checkpointer(osp.join(cfg.checkpointdir, cfg.exp_name), max_num=cfg.train.max_ckpt)
    model.train()

    optimizer_fg, optimizer_bg = get_optimizers(cfg, model)

    start_epoch = 0
    start_iter = 0
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load_last(cfg.resume_ckpt, model, optimizer_fg, optimizer_bg)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    writer = SummaryWriter(log_dir=os.path.join(cfg.logdir, cfg.exp_name), flush_secs=30, purge_step=global_step)
    vis_logger = get_vislogger(cfg)
    metric_logger = MetricLogger()

    print('Start training')
    end_flag = False
    for epoch in range(start_epoch, cfg.train.max_epochs):
        if end_flag:
            break

        if cfg.exp_name == 'table':
            # Create shuffled indices and split them into batches so the raw
            # numpy array can be accessed batch by batch.
            idx_set = np.arange(data_size)
            np.random.shuffle(idx_set)
            idx_set = np.split(idx_set, len(idx_set) / cfg.train.batch_size)
            data_to_enumerate = idx_set
        else:
            trainloader = get_dataloader(cfg, 'train')
            data_to_enumerate = trainloader
            data_size = len(trainloader)

        start = time.perf_counter()
        for i, enumerated_data in enumerate(data_to_enumerate):
            end = time.perf_counter()
            data_time = end - start
            start = end

            model.train()
            if cfg.exp_name == 'table':
                data_i = data_set[enumerated_data]
                data_i = torch.from_numpy(data_i).float().to(cfg.device)
                data_i /= 255
                data_i = data_i.permute([0, 3, 1, 2])
                imgs = data_i
            else:
                imgs = enumerated_data
                imgs = imgs.to(cfg.device)

            loss, log = model(imgs, global_step)
            # In case of using DataParallel
            loss = loss.mean()

            optimizer_fg.zero_grad()
            optimizer_bg.zero_grad()
            loss.backward()
            if cfg.train.clip_norm:
                clip_grad_norm_(model.parameters(), cfg.train.clip_norm)
            optimizer_fg.step()
            # if cfg.train.stop_bg == -1 or global_step < cfg.train.stop_bg:
            optimizer_bg.step()

            end = time.perf_counter()
            batch_time = end - start

            metric_logger.update(data_time=data_time)
            metric_logger.update(batch_time=batch_time)
            metric_logger.update(loss=loss.item())

            if global_step % cfg.train.print_every == 0:
                start = time.perf_counter()
                log.update({
                    'loss': metric_logger['loss'].median,
                })
                vis_logger.train_vis(writer, log, global_step, 'train')
                end = time.perf_counter()
                print(
                    'exp: {}, epoch: {}, iter: {}/{}, global_step: {}, loss: {:.2f}, '
                    'batch time: {:.4f}s, data time: {:.4f}s, log time: {:.4f}s'.format(
                        cfg.exp_name, epoch + 1, i + 1, data_size, global_step,
                        metric_logger['loss'].median, metric_logger['batch_time'].avg,
                        metric_logger['data_time'].avg, end - start))

            if global_step % cfg.train.create_image_every == 0:
                vis_logger.test_create_image(
                    log, '../output/{}_img_{}.png'.format(cfg.dataset, global_step))

            if global_step % cfg.train.save_every == 0:
                start = time.perf_counter()
                checkpointer.save_last(model, optimizer_fg, optimizer_bg, epoch, global_step)
                print('Saving checkpoint takes {:.4f}s.'.format(time.perf_counter() - start))

            if global_step % cfg.train.eval_every == 0 and cfg.train.eval_on:
                pass
                '''print('Validating...')
                start = time.perf_counter()
                checkpoint = [model, optimizer_fg, optimizer_bg, epoch, global_step]
                if cfg.exp_name == 'table':
                    evaluator.train_eval(model, None, None, writer, global_step,
                                         cfg.device, checkpoint, checkpointer)
                else:
                    evaluator.train_eval(model, valset, valset.bb_path, writer, global_step,
                                         cfg.device, checkpoint, checkpointer)
                print('Validation takes {:.4f}s.'.format(time.perf_counter() - start))'''

            start = time.perf_counter()
            global_step += 1
            if global_step > cfg.train.max_steps:
                end_flag = True
                break

device = torch.device(cfg.device)
model = AIR().to(device)
optimizer = optim.Adam([
    {'params': model.air_modules.parameters(), 'lr': cfg.train.model_lr},
    {'params': model.baseline_modules.parameters(), 'lr': cfg.train.baseline_lr},
])

# Checkpointing
start_epoch = 0
checkpoint_path = os.path.join(cfg.checkpointdir, cfg.exp_name)
if not os.path.exists(checkpoint_path):
    os.makedirs(checkpoint_path)
checkpointer = Checkpointer(path=checkpoint_path)
if cfg.resume:
    start_epoch = checkpointer.load(model, optimizer)

# TensorBoard
writer = SummaryWriter(logdir=os.path.join(cfg.logdir, cfg.exp_name))

# Presence prior annealing
prior_scheduler = PriorScheduler(cfg.anneal.initial, cfg.anneal.final, cfg.anneal.total_steps,
                                 cfg.anneal.interval, model, device)
weight_scheduler = WeightScheduler(0.0, 1.0, 40000, 500, model, device)

print('Start training')
with autograd.detect_anomaly():
    for epoch in range(start_epoch, cfg.train.max_epochs):

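# PriorScheduler and WeightScheduler are defined elsewhere in the repo; both
# appear to linearly anneal a scalar (the presence prior and a loss weight)
# from an initial to a final value over a fixed number of steps, updating the
# model every `interval` steps. The standalone helper below only sketches that
# kind of schedule; the class name and the way the value is written back to
# the model are assumptions, not the repo's actual API.
class LinearAnnealerSketch:
    def __init__(self, initial, final, total_steps, interval):
        self.initial = initial
        self.final = final
        self.total_steps = total_steps
        self.interval = interval

    def value(self, step):
        # Linear interpolation, clamped at the final value after total_steps.
        progress = min(step / self.total_steps, 1.0)
        return self.initial + (self.final - self.initial) * progress

    def maybe_update(self, step, set_fn):
        # `set_fn` is a hypothetical callback that writes the annealed value
        # into the model (e.g. updating the presence prior).
        if step % self.interval == 0:
            set_fn(self.value(step))
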
def train(cfg):
    print('Experiment name:', cfg.exp_name)
    print('Dataset:', cfg.dataset)
    print('Model name:', cfg.model)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:', cfg.resume_ckpt if cfg.resume_ckpt else 'last checkpoint')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    print('Loading data')
    trainloader = get_dataloader(cfg, 'train')
    if cfg.train.eval_on:
        valset = get_dataset(cfg, 'val')
        # valloader = get_dataloader(cfg, 'val')
        evaluator = get_evaluator(cfg)

    model = get_model(cfg)
    model = model.to(cfg.device)
    checkpointer = Checkpointer(osp.join(cfg.checkpointdir, cfg.exp_name), max_num=cfg.train.max_ckpt)
    model.train()

    optimizer_fg, optimizer_bg = get_optimizers(cfg, model)

    start_epoch = 0
    start_iter = 0
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load_last(cfg.resume_ckpt, model, optimizer_fg, optimizer_bg)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    writer = SummaryWriter(log_dir=os.path.join(cfg.logdir, cfg.exp_name), flush_secs=30, purge_step=global_step)
    vis_logger = get_vislogger(cfg)
    metric_logger = MetricLogger()

    print('Start training')
    end_flag = False
    for epoch in range(start_epoch, cfg.train.max_epochs):
        if end_flag:
            break

        start = time.perf_counter()
        for i, data in enumerate(trainloader):
            end = time.perf_counter()
            data_time = end - start
            start = end

            model.train()
            imgs = data
            imgs = imgs.to(cfg.device)

            loss, log = model(imgs, global_step)
            # In case of using DataParallel
            loss = loss.mean()

            optimizer_fg.zero_grad()
            optimizer_bg.zero_grad()
            loss.backward()
            if cfg.train.clip_norm:
                clip_grad_norm_(model.parameters(), cfg.train.clip_norm)
            optimizer_fg.step()
            # if cfg.train.stop_bg == -1 or global_step < cfg.train.stop_bg:
            optimizer_bg.step()

            end = time.perf_counter()
            batch_time = end - start

            metric_logger.update(data_time=data_time)
            metric_logger.update(batch_time=batch_time)
            metric_logger.update(loss=loss.item())

            if global_step % cfg.train.print_every == 0:
                start = time.perf_counter()
                log.update({
                    'loss': metric_logger['loss'].median,
                })
                vis_logger.train_vis(writer, log, global_step, 'train')
                end = time.perf_counter()
                print(
                    'exp: {}, epoch: {}, iter: {}/{}, global_step: {}, loss: {:.2f}, '
                    'batch time: {:.4f}s, data time: {:.4f}s, log time: {:.4f}s'.format(
                        cfg.exp_name, epoch + 1, i + 1, len(trainloader), global_step,
                        metric_logger['loss'].median, metric_logger['batch_time'].avg,
                        metric_logger['data_time'].avg, end - start))

            if global_step % cfg.train.save_every == 0:
                start = time.perf_counter()
                checkpointer.save_last(model, optimizer_fg, optimizer_bg, epoch, global_step)
                print('Saving checkpoint takes {:.4f}s.'.format(time.perf_counter() - start))

            if global_step % cfg.train.eval_every == 0 and cfg.train.eval_on:
                print('Validating...')
                start = time.perf_counter()
                checkpoint = [model, optimizer_fg, optimizer_bg, epoch, global_step]
                evaluator.train_eval(model, valset, valset.bb_path, writer, global_step,
                                     cfg.device, checkpoint, checkpointer)
                print('Validation takes {:.4f}s.'.format(time.perf_counter() - start))

            start = time.perf_counter()
            global_step += 1
            if global_step > cfg.train.max_steps:
                end_flag = True
                break

def train(cfg):
    # Seed everything for reproducibility
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.cuda.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Some info
    print('Experiment name:', cfg.exp_name)
    print('Model name:', cfg.model)
    print('Dataset:', cfg.dataset)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:', cfg.resume_ckpt if cfg.resume_ckpt else 'last checkpoint')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    print('\nLoading data...')
    trainloader = get_dataloader(cfg, 'train')
    if cfg.val.ison or cfg.vis.ison:
        valset = get_dataset(cfg, 'val')
        valloader = get_dataloader(cfg, 'val')
    print('Data loaded.')

    print('Initializing model...')
    model = get_model(cfg)
    model = model.to(cfg.device)
    print('Model initialized.')
    model.train()

    optimizer = get_optimizer(cfg, model)

    # Checkpointer will print information.
    checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name), max_num=cfg.train.max_ckpt)
    start_epoch = 0
    start_iter = 0
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, optimizer)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    writer = SummaryWriter(log_dir=os.path.join(cfg.logdir, cfg.exp_name), purge_step=global_step, flush_secs=30)
    metric_logger = MetricLogger()
    vis_logger = get_vislogger(cfg)
    evaluator = get_evaluator(cfg)

    print('Start training')
    end_flag = False
    for epoch in range(start_epoch, cfg.train.max_epochs):
        if end_flag:
            break

        start = time.perf_counter()
        for i, data in enumerate(trainloader):
            end = time.perf_counter()
            data_time = end - start
            start = end

            imgs, *_ = [d.to(cfg.device) for d in data]
            model.train()
            loss, log = model(imgs, global_step)
            # If you are using DataParallel
            loss = loss.mean()
            optimizer.zero_grad()
            loss.backward()
            if cfg.train.clip_norm:
                clip_grad_norm_(model.parameters(), cfg.train.clip_norm)
            optimizer.step()

            end = time.perf_counter()
            batch_time = end - start

            metric_logger.update(data_time=data_time)
            metric_logger.update(batch_time=batch_time)
            metric_logger.update(loss=loss.item())

            if (global_step + 1) % cfg.train.print_every == 0:
                start = time.perf_counter()
                log.update(loss=metric_logger['loss'].median)
                vis_logger.model_log_vis(writer, log, global_step + 1)
                end = time.perf_counter()
                device_text = cfg.device_ids if cfg.parallel else cfg.device
                print(
                    'exp: {}, device: {}, epoch: {}, iter: {}/{}, global_step: {}, loss: {:.2f}, '
                    'batch time: {:.4f}s, data time: {:.4f}s, log time: {:.4f}s'.format(
                        cfg.exp_name, device_text, epoch + 1, i + 1, len(trainloader), global_step + 1,
                        metric_logger['loss'].median, metric_logger['batch_time'].avg,
                        metric_logger['data_time'].avg, end - start))

            if (global_step + 1) % cfg.train.save_every == 0:
                start = time.perf_counter()
                checkpointer.save(model, optimizer, epoch, global_step)
                print('Saving checkpoint takes {:.4f}s.'.format(time.perf_counter() - start))

            if (global_step + 1) % cfg.vis.vis_every == 0 and cfg.vis.ison:
                print('Doing visualization...')
                start = time.perf_counter()
                vis_logger.train_vis(model, valset, writer, global_step, cfg.vis.indices, cfg.device,
                                     cond_steps=cfg.vis.cond_steps, fg_sample=cfg.vis.fg_sample,
                                     bg_sample=cfg.vis.bg_sample, num_gen=cfg.vis.num_gen)
                print('Visualization takes {:.4f}s.'.format(time.perf_counter() - start))

            if (global_step + 1) % cfg.val.val_every == 0 and cfg.val.ison:
                print('Doing evaluation...')
                start = time.perf_counter()
                evaluator.train_eval(
                    evaluator, os.path.join(cfg.evaldir, cfg.exp_name), cfg.val.metrics,
                    cfg.val.eval_types, cfg.val.intervals, cfg.val.cond_steps, model, valset,
                    valloader, cfg.device, writer, global_step,
                    [model, optimizer, epoch, global_step], checkpointer)
                print('Evaluation takes {:.4f}s.'.format(time.perf_counter() - start))

            start = time.perf_counter()
            global_step += 1
            if global_step >= cfg.train.max_steps:
                end_flag = True
                break

def vis_maze(cfg, genlen=50, num_gen=4, cond_steps=5):
    assert cfg.val.mode == 'val'
    indices = cfg.vis.indices

    print('\nLoading data...')
    dataset = get_dataset(cfg, cfg.val.mode)
    dataset = Subset(dataset, indices=list(indices))
    # dataloader = get_dataloader(cfg, cfg.val.mode)
    dataloader = DataLoader(dataset, batch_size=cfg.val.batch_size, num_workers=4)
    print('Data loaded.')

    print('Initializing model...')
    model = get_model(cfg)
    model = model.to(cfg.device)
    print('Model initialized.')

    checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name), max_num=cfg.train.max_ckpt)
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, None)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    print("Maze...")
    start = time.perf_counter()
    model.eval()
    results = {}
    model_fn = lambda model, imgs: model.generate(imgs, cond_steps=cond_steps, fg_sample=True, bg_sample=False)

    seqs_all = []
    for i, data in enumerate(tqdm(dataloader)):
        data = [d.to(cfg.device) for d in data]
        # (B, T, C, H, W), (B, T, O, 2), (B, T, O)
        imgs, *_ = data
        # (B, T, C, H, W)
        imgs = imgs[:, :genlen]
        B, T, C, H, W = imgs.size()
        seqs = [list() for _ in range(B)]
        for j in range(num_gen):
            log = model_fn(model, imgs)
            log = AttrDict(log)
            # (B, T, C, H, W)
            recon = log.recon
            for b in range(B):
                seqs[b].append((recon[b].permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8))
        # (N, G, T, H, W, 3)
        seqs_all.extend(seqs)

    frames = maze_vis(seqs_all)
    resultdir = os.path.join(cfg.resultdir, cfg.exp_name)
    os.makedirs(resultdir, exist_ok=True)
    path = os.path.join(resultdir, 'maze.gif')
    make_gif(frames, path)

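# make_gif is imported from the repo's utilities. A thin wrapper around
# imageio is one plausible implementation (assumption only); `frames` is
# expected to be a sequence of HxWx3 uint8 arrays, as produced by maze_vis
# above. The actual helper in the repo may differ.
def make_gif_sketch(frames, path, fps=7):
    import imageio
    # Note: imageio v2's GIF writer accepts an fps keyword; newer releases
    # may expect a per-frame `duration` instead.
    imageio.mimsave(path, frames, fps=fps)
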
def eval_balls(cfg):
    print('\nLoading data...')
    assert cfg.val.mode == 'test', 'Please set cfg.val.mode to "test"'
    dataset = get_dataset(cfg, cfg.val.mode)
    dataloader = get_dataloader(cfg, cfg.val.mode)
    print('Data loaded.')

    print('Initializing model...')
    model = get_model(cfg)
    model = model.to(cfg.device)
    print('Model initialized.')

    checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name), max_num=cfg.train.max_ckpt)
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, None)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    evaluator = get_evaluator(cfg)
    evaldir = os.path.join(cfg.evaldir, cfg.exp_name)
    os.makedirs(evaldir, exist_ok=True)

    print("Evaluating...")
    start = time.perf_counter()
    model.eval()
    results = {}
    for eval_type in cfg.val.eval_types:
        if eval_type == 'tracking':
            model_fn = lambda model, imgs: model.track(imgs, discovery_dropout=0)
        elif eval_type == 'generation':
            model_fn = lambda model, imgs: model.generate(
                imgs, cond_steps=cfg.val.cond_steps, fg_sample=False, bg_sample=False)

        print(f'Evaluating {eval_type}...')
        skip = cfg.val.cond_steps if eval_type == 'generation' else 0
        (iou_summary, euclidean_summary, med_summary) = evaluator.evaluate(
            eval_type, model, model_fn, skip, dataset, dataloader, evaldir, cfg.device, cfg.val.metrics)
        # print('iou_summary: {}'.format(iou_summary))
        # print('euclidean_summary: {}'.format(euclidean_summary))
        # print('med_summary: {}'.format(med_summary))
        results[eval_type] = [iou_summary, euclidean_summary, med_summary]

    for eval_type in cfg.val.eval_types:
        evaluator.dump_to_json(*results[eval_type], evaldir, 'ours', cfg.dataset.lower(),
                               eval_type, cfg.run_num, cfg.resume_ckpt, cfg.exp_name)
    print('Evaluation takes {}s.'.format(time.perf_counter() - start))

    # Plot figure
    if 'generation' in cfg.val.eval_types and 'med' in cfg.val.metrics:
        med_list = results['generation'][-1]['meds_over_time']
        assert len(med_list) == 90
        steps = np.arange(10, 100)
        f, ax = plt.subplots()
        ax: plt.Axes
        ax.plot(steps, med_list)
        ax.set_xlabel('Time step')
        ax.set_ylim(0.0, 0.6)
        ax.set_ylabel('Position error')
        ax.set_title(cfg.exp_name)
        plt.savefig(os.path.join(evaldir, 'plot_balls.png'))
        print('Plot saved to', os.path.join(evaldir, 'plot_balls.png'))
        print('MED summed over the first 10 prediction steps: ', sum(med_list[:10]))