Пример #1
0
def vis_3d(cfg, genlen=50, cond_steps=10):
    """Render the '3d.gif' visualization for the configured samples.

    Builds the subset of the ``cfg.val.mode`` dataset selected by
    ``cfg.vis.indices``, optionally restores the model from
    ``cfg.resume_ckpt`` (when ``cfg.resume`` is set) and delegates the
    rendering to the visual logger's ``show_gif``. The result is written to
    ``<cfg.resultdir>/<cfg.exp_name>/3d.gif``.

    Args:
        cfg: Experiment configuration object.
        genlen: Forwarded unchanged to ``show_gif`` (presumably the number of
            frames to render -- confirm against the visual logger).
        cond_steps: Forwarded unchanged to ``show_gif`` (presumably the number
            of conditioning frames -- confirm against the visual logger).
    """
    vis_indices = cfg.vis.indices

    print('\nLoading data...')
    subset = Subset(get_dataset(cfg, cfg.val.mode), indices=list(vis_indices))
    # The loader is constructed here, but show_gif below is handed the
    # dataset itself rather than this loader.
    loader = DataLoader(subset,
                        batch_size=cfg.val.batch_size,
                        num_workers=4)
    print('Data loaded.')

    print('Initializing model...')
    net = get_model(cfg).to(cfg.device)
    print('Model initialized.')

    ckpt_dir = os.path.join(cfg.checkpointdir, cfg.exp_name)
    checkpointer = Checkpointer(ckpt_dir, max_num=cfg.train.max_ckpt)

    resumed_step = 0
    if cfg.resume:
        state = checkpointer.load(cfg.resume_ckpt, net, None)
        if state:
            # Read from the checkpoint but not used further in this routine.
            resumed_epoch = state['epoch']
            resumed_step = state['global_step'] + 1
    if cfg.parallel:
        net = nn.DataParallel(net, device_ids=cfg.device_ids)

    gif_logger = get_vislogger(cfg)

    out_dir = os.path.join(cfg.resultdir, cfg.exp_name)
    os.makedirs(out_dir, exist_ok=True)
    gif_path = osp.join(out_dir, '3d.gif')
    gif_logger.show_gif(net, subset, vis_indices, cfg.device, cond_steps,
                        genlen, gif_path, fps=7)
Пример #2
0
def show(cfg):
    """Restore a trained model and save one visualization image to disk.

    Prints a short summary of the run, loads either ``cfg.resume_ckpt`` or,
    when that is empty, the last checkpoint found under
    ``<cfg.checkpointdir>/<cfg.exp_name>``, and asks the visual logger to
    render ``cfg.show.indices`` of the ``cfg.show.mode`` dataset into
    ``<cfg.demodir>/<cfg.exp_name>.png``.

    Args:
        cfg: Experiment configuration object. ``cfg.resume`` must be truthy.
    """
    assert cfg.resume, 'You must pass "resume True" to the if --task is "show"'

    # --- Run summary -------------------------------------------------------
    print('Experiment name:', cfg.exp_name)
    print('Dataset:', cfg.dataset)
    print('Model name:', cfg.model)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:', cfg.resume_ckpt or 'see below')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    # --- Data, model and helpers -------------------------------------------
    print('Loading data')
    dataset = get_dataset(cfg, cfg.show.mode)
    model = get_model(cfg).to(cfg.device)
    ckpt_dir = osp.join(cfg.checkpointdir, cfg.exp_name)
    checkpointer = Checkpointer(ckpt_dir, max_num=cfg.train.max_ckpt)
    vis_logger = get_vislogger(cfg)
    model.eval()

    # --- Restore weights ---------------------------------------------------
    on_cpu = 'cpu' in cfg.device
    if not cfg.resume_ckpt:
        # No explicit checkpoint given: fall back to the most recent one.
        checkpoint = checkpointer.load_last('', model, None, None,
                                            use_cpu=on_cpu)
    else:
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, None, None,
                                       use_cpu=on_cpu)

    if cfg.parallel:
        assert 'cpu' not in cfg.device, 'You can not set "parallel" to True when you set "device" to cpu'
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    # --- Render ------------------------------------------------------------
    os.makedirs(cfg.demodir, exist_ok=True)
    img_path = osp.join(cfg.demodir, '{}.png'.format(cfg.exp_name))
    vis_logger.show_vis(model, dataset, cfg.show.indices, img_path,
                        device=cfg.device)
    saved_to = osp.abspath(img_path)
    print('The result image has been saved to {}'.format(saved_to))
def run_train(args):
    """Train a classifier on the CIFAR10 loaders and keep the best checkpoint.

    The model is evaluated once before training to initialise
    ``checkpointer.best_acc``. After every epoch it is evaluated again and
    saved as ``"model_best"`` whenever the first value returned by
    ``inference`` improves. ``"model_last"`` is saved once training ends.

    Args:
        args: Parsed run options (device, model_name, num_classes, pretrained,
            experiment, checkpoint_period, max_epoch, plus whatever the
            builder helpers read).
    """
    device = torch.device(args.device)

    net = build_model(args.model_name, args.num_classes, args.pretrained)
    net = net.to(device)

    # Optimisation and bookkeeping objects.
    optimizer = build_optimizer(args, net)
    scheduler = build_lr_scheduler(args, optimizer)
    checkpointer = Checkpointer(net, optimizer, scheduler, args.experiment,
                                args.checkpoint_period)
    logger = Logger(os.path.join(args.experiment, 'tf_log'))

    # Data.
    train_loader = CIFAR10_loader(args, is_train=True)
    test_loader = CIFAR10_loader(args, is_train=False)

    # Baseline accuracy before any training step.
    baseline_acc, _ = inference(net, test_loader, logger, device, 0, args)
    checkpointer.best_acc = baseline_acc

    for epoch in tqdm(range(args.max_epoch)):
        # Fourth argument: number of batches processed in earlier epochs.
        batches_so_far = len(train_loader) * epoch
        train_epoch(net, train_loader, optimizer, batches_so_far,
                    checkpointer, device, logger)

        epoch_acc, _mean_acc = inference(net, test_loader, logger, device,
                                         epoch + 1, args)
        if epoch_acc > checkpointer.best_acc:
            checkpointer.save("model_best")
            checkpointer.best_acc = epoch_acc
        scheduler.step()

    checkpointer.save("model_last")
Пример #4
0
def run_train(args):
    """Train a student network against teachers and keep the best checkpoint.

    The student is evaluated once before training to initialise
    ``checkpointer.best_acc``. After every epoch it is evaluated again and
    saved as ``"model_best"`` whenever the first value returned by
    ``inference`` improves. ``"model_last"`` is saved once training ends.

    Args:
        args: Parsed run options (device, student, num_classes, pretrained,
            experiment, checkpoint_period, max_epoch, plus whatever the
            builder helpers read).
    """
    device = torch.device(args.device)

    # Student network (the one being optimised).
    student_net = build_model(args.student, args.num_classes, args.pretrained)
    student_net = student_net.to(device)

    # Teacher networks used by the training step.
    teacher_nets = build_teachers(args, device)

    # Optimisation and bookkeeping objects, all bound to the student.
    optimizer = build_optimizer(args, student_net)
    scheduler = build_lr_scheduler(args, optimizer)
    checkpointer = Checkpointer(student_net, optimizer, scheduler,
                                args.experiment, args.checkpoint_period)
    logger = Logger(os.path.join(args.experiment, 'tf_log'))

    # Objective handed to the training step.
    criterion = loss_fn_kd

    # Data.
    train_loader = CIFAR10_loader(args, is_train=True)
    test_loader = CIFAR10_loader(args, is_train=False)

    # Baseline accuracy before any training step.
    baseline_acc, _mean_acc = inference(student_net, test_loader, logger,
                                        device, 0, args)
    checkpointer.best_acc = baseline_acc

    for epoch in tqdm(range(args.max_epoch)):
        do_train(student_net, teacher_nets, criterion, train_loader,
                 optimizer, checkpointer, device, logger, epoch)

        epoch_acc, _mean_acc = inference(student_net, test_loader, logger,
                                         device, epoch + 1, args)
        if epoch_acc > checkpointer.best_acc:
            checkpointer.save("model_best")
            checkpointer.best_acc = epoch_acc
        scheduler.step()

    checkpointer.save("model_last")
Пример #5
0
def eval_maze(cfg, cond_steps=5):
    """Run the maze evaluation and plot the ``num_mean`` curve it produced.

    After the evaluator has run, ``maze-<cfg.exp_name>.json`` is read back
    from ``<cfg.evaldir>/<cfg.exp_name>`` and its ``'num_mean'`` entry is
    plotted to ``plot_maze.png`` in the same directory.

    Args:
        cfg: Experiment configuration object; ``cfg.val.mode`` must be
            ``'test'``.
        cond_steps: Forwarded unchanged to ``evaluator.evaluate``.
    """
    print('\nLoading data...')
    assert cfg.val.mode == 'test', 'Please set cfg.val.mode to "test"'
    dataset = get_dataset(cfg, cfg.val.mode)
    dataloader = get_dataloader(cfg, cfg.val.mode)
    print('Data loaded.')

    print('Initializing model...')
    net = get_model(cfg).to(cfg.device)
    print('Model initialized.')

    ckpt_dir = os.path.join(cfg.checkpointdir, cfg.exp_name)
    checkpointer = Checkpointer(ckpt_dir, max_num=cfg.train.max_ckpt)

    resumed_step = 0
    if cfg.resume:
        state = checkpointer.load(cfg.resume_ckpt, net, None)
        if state:
            # Read from the checkpoint but not used further in this routine.
            resumed_epoch = state['epoch']
            resumed_step = state['global_step'] + 1
    if cfg.parallel:
        net = nn.DataParallel(net, device_ids=cfg.device_ids)

    evaluator = get_evaluator(cfg)

    out_dir = os.path.join(cfg.evaldir, cfg.exp_name)
    os.makedirs(out_dir, exist_ok=True)
    started_at = time.perf_counter()
    net.eval()

    evaluator.evaluate(net, dataloader, cond_steps, cfg.device, out_dir,
                       cfg.exp_name, cfg.resume_ckpt)

    # The metrics file is read back right after the evaluator has run.
    metrics_path = os.path.join(out_dir, 'maze-{}.json'.format(cfg.exp_name))
    with open(metrics_path) as metrics_file:
        metrics = json.load(metrics_file)
    agent_counts = metrics['num_mean']

    fig, axes = plt.subplots()
    axes.plot(agent_counts)
    axes.set_ylim(0, 3.5)
    axes.set_xlabel('Time step')
    axes.set_ylabel('#Agents')
    axes.set_title(cfg.exp_name)
    # savefig acts on the current figure, i.e. the one created just above.
    plt.savefig(os.path.join(out_dir, 'plot_maze.png'))
Пример #6
0
def eval(cfg):
    """Evaluate a trained model on the test split.

    The checkpoint is chosen as follows: ``cfg.resume_ckpt`` when given,
    otherwise the last checkpoint (``cfg.eval.checkpoint == 'last'``) or the
    best checkpoint for ``cfg.eval.metric`` (``cfg.eval.checkpoint ==
    'best'``). The evaluator's ``test_eval`` is then run with
    ``<cfg.evaldir>/<cfg.exp_name>`` as its directory argument.

    Args:
        cfg: Experiment configuration object. ``cfg.resume`` must be truthy,
            ``cfg.eval.checkpoint`` one of ``'best'``/``'last'`` and
            ``cfg.eval.metric`` one of ``'ap_dot5'``/``'ap_avg'``.
    """
    assert cfg.resume
    assert cfg.eval.checkpoint in ['best', 'last']
    assert cfg.eval.metric in ['ap_dot5', 'ap_avg']

    # --- Run summary -------------------------------------------------------
    print('Experiment name:', cfg.exp_name)
    print('Dataset:', cfg.dataset)
    print('Model name:', cfg.model)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:', cfg.resume_ckpt or 'see below')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    # --- Data, model and helpers -------------------------------------------
    print('Loading data')
    testset = get_dataset(cfg, 'test')
    model = get_model(cfg).to(cfg.device)
    ckpt_dir = osp.join(cfg.checkpointdir, cfg.exp_name)
    checkpointer = Checkpointer(ckpt_dir, max_num=cfg.train.max_ckpt)
    evaluator = get_evaluator(cfg)
    model.eval()

    # --- Restore weights ---------------------------------------------------
    on_cpu = 'cpu' in cfg.device
    if cfg.resume_ckpt:
        # An explicit checkpoint takes precedence over cfg.eval.checkpoint.
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, None, None,
                                       on_cpu)
    elif cfg.eval.checkpoint == 'last':
        checkpoint = checkpointer.load_last('', model, None, None, on_cpu)
    elif cfg.eval.checkpoint == 'best':
        checkpoint = checkpointer.load_best(cfg.eval.metric, model, None,
                                            None, on_cpu)

    if cfg.parallel:
        assert 'cpu' not in cfg.device
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    # --- Evaluate ----------------------------------------------------------
    out_dir = osp.join(cfg.evaldir, cfg.exp_name)
    run_info = {'exp_name': cfg.exp_name}
    evaluator.test_eval(model, testset, testset.bb_path, cfg.device, out_dir,
                        run_info)
Пример #7
0
def train(cfg):
    """Train the model described by ``cfg`` with periodic logging and saving.

    Two data paths exist:

    * ``cfg.exp_name == 'table'``: images are read from
      ``<cfg.dataset_roots.TABLE>/train/all_set_train.npy`` and batched by
      hand from a freshly shuffled index permutation every epoch. Batches are
      divided by 255 and permuted with ``[0, 3, 1, 2]`` before being fed to
      the model.
    * otherwise: batches come from ``get_dataloader(cfg, 'train')``, which is
      re-created at the start of every epoch.

    Training ends after ``cfg.train.max_epochs`` epochs or as soon as
    ``global_step`` exceeds ``cfg.train.max_steps``.

    Args:
        cfg: Experiment configuration object.
    """
    print('Experiment name:', cfg.exp_name)
    print('Dataset:', cfg.dataset)
    print('Model name:', cfg.model)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:',
              cfg.resume_ckpt if cfg.resume_ckpt else 'last checkpoint')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    print('Loading data')

    use_table = cfg.exp_name == 'table'
    if use_table:
        data_set = np.load('{}/train/all_set_train.npy'.format(
            cfg.dataset_roots.TABLE))
        data_size = len(data_set)
    else:
        # Built once up front, as before; a fresh loader is created again at
        # the start of every epoch below.
        trainloader = get_dataloader(cfg, 'train')
    if cfg.train.eval_on:
        # Only referenced by the (currently disabled) validation step below.
        valset = get_dataset(cfg, 'val')
        evaluator = get_evaluator(cfg)
    model = get_model(cfg)
    model = model.to(cfg.device)
    checkpointer = Checkpointer(osp.join(cfg.checkpointdir, cfg.exp_name),
                                max_num=cfg.train.max_ckpt)
    model.train()

    optimizer_fg, optimizer_bg = get_optimizers(cfg, model)

    start_epoch = 0
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load_last(cfg.resume_ckpt, model,
                                            optimizer_fg, optimizer_bg)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    writer = SummaryWriter(log_dir=os.path.join(cfg.logdir, cfg.exp_name),
                           flush_secs=30,
                           purge_step=global_step)
    vis_logger = get_vislogger(cfg)
    metric_logger = MetricLogger()

    print('Start training')
    end_flag = False
    for epoch in range(start_epoch, cfg.train.max_epochs):
        if end_flag:
            break
        if use_table:
            # Shuffle sample indices and cut them into batches. Plain slicing
            # is used instead of np.split, which raises ValueError whenever
            # the dataset size is not an exact multiple of the batch size;
            # here the final batch is simply shorter in that case.
            idx_set = np.arange(data_size)
            np.random.shuffle(idx_set)
            batch_size = int(cfg.train.batch_size)
            data_to_enumerate = [
                idx_set[lo:lo + batch_size]
                for lo in range(0, data_size, batch_size)
            ]
        else:
            trainloader = get_dataloader(cfg, 'train')
            data_to_enumerate = trainloader
        # Number of iterations in this epoch, used for the progress log.
        iters_per_epoch = len(data_to_enumerate)

        start = time.perf_counter()
        for i, enumerated_data in enumerate(data_to_enumerate):

            # Time spent waiting for the batch.
            end = time.perf_counter()
            data_time = end - start
            start = end

            model.train()
            if use_table:
                # enumerated_data is an array of sample indices here.
                data_i = data_set[enumerated_data]
                data_i = torch.from_numpy(data_i).float().to(cfg.device)
                data_i /= 255
                data_i = data_i.permute([0, 3, 1, 2])
                imgs = data_i
            else:
                imgs = enumerated_data
                imgs = imgs.to(cfg.device)

            loss, log = model(imgs, global_step)
            # In case of using DataParallel
            loss = loss.mean()
            optimizer_fg.zero_grad()
            optimizer_bg.zero_grad()
            loss.backward()
            if cfg.train.clip_norm:
                clip_grad_norm_(model.parameters(), cfg.train.clip_norm)

            optimizer_fg.step()

            # if cfg.train.stop_bg == -1 or global_step < cfg.train.stop_bg:
            optimizer_bg.step()

            # Time spent on forward/backward/optimizer steps.
            end = time.perf_counter()
            batch_time = end - start

            metric_logger.update(data_time=data_time)
            metric_logger.update(batch_time=batch_time)
            metric_logger.update(loss=loss.item())

            if global_step % cfg.train.print_every == 0:
                start = time.perf_counter()
                log.update({
                    'loss': metric_logger['loss'].median,
                })
                vis_logger.train_vis(writer, log, global_step, 'train')
                end = time.perf_counter()

                print(
                    'exp: {}, epoch: {}, iter: {}/{}, global_step: {}, loss: {:.2f}, batch time: {:.4f}s, data time: {:.4f}s, log time: {:.4f}s'
                    .format(cfg.exp_name, epoch + 1, i + 1, iters_per_epoch,
                            global_step, metric_logger['loss'].median,
                            metric_logger['batch_time'].avg,
                            metric_logger['data_time'].avg, end - start))
            if global_step % cfg.train.create_image_every == 0:
                vis_logger.test_create_image(
                    log,
                    '../output/{}_img_{}.png'.format(cfg.dataset, global_step))
            if global_step % cfg.train.save_every == 0:
                start = time.perf_counter()
                checkpointer.save_last(model, optimizer_fg, optimizer_bg,
                                       epoch, global_step)
                print('Saving checkpoint takes {:.4f}s.'.format(
                    time.perf_counter() - start))

            if global_step % cfg.train.eval_every == 0 and cfg.train.eval_on:
                # Validation is currently disabled. The previous code was:
                #   print('Validating...')
                #   start = time.perf_counter()
                #   checkpoint = [model, optimizer_fg, optimizer_bg, epoch, global_step]
                #   if cfg.exp_name == 'table':
                #       evaluator.train_eval(model, None, None, writer, global_step, cfg.device, checkpoint, checkpointer)
                #   else:
                #       evaluator.train_eval(model, valset, valset.bb_path, writer, global_step, cfg.device, checkpoint, checkpointer)
                #   print('Validation takes {:.4f}s.'.format(time.perf_counter() - start))
                pass

            # Restart the timer so logging/saving is not counted as data time.
            start = time.perf_counter()
            global_step += 1
            if global_step > cfg.train.max_steps:
                end_flag = True
                break
Пример #8
0
    device = torch.device(cfg.device)
    model = AIR().to(device)
    optimizer = optim.Adam([{
        'params': model.air_modules.parameters(),
        'lr': cfg.train.model_lr
    }, {
        'params': model.baseline_modules.parameters(),
        'lr': cfg.train.baseline_lr
    }])

    # checkpoint
    start_epoch = 0
    checkpoint_path = os.path.join(cfg.checkpointdir, cfg.exp_name)
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    checkpointer = Checkpointer(path=checkpoint_path)
    if cfg.resume:
        start_epoch = checkpointer.load(model, optimizer)

    # tensorboard
    writer = SummaryWriter(logdir=os.path.join(cfg.logdir, cfg.exp_name))

    # presence prior annealing
    prior_scheduler = PriorScheduler(cfg.anneal.initial, cfg.anneal.final,
                                     cfg.anneal.total_steps,
                                     cfg.anneal.interval, model, device)
    weight_scheduler = WeightScheduler(0.0, 1.0, 40000, 500, model, device)

    print('Start training')
    with autograd.detect_anomaly():
        for epoch in range(start_epoch, cfg.train.max_epochs):
Пример #9
0
def train(cfg):
    """Train the model described by ``cfg`` with two optimizers.

    Every iteration runs a forward pass, back-propagates the mean loss and
    steps both ``optimizer_fg`` and ``optimizer_bg``. Depending on
    ``global_step`` the loop also logs (``cfg.train.print_every``), saves the
    last checkpoint (``cfg.train.save_every``) and, when
    ``cfg.train.eval_on`` is set, validates (``cfg.train.eval_every``).

    Training ends after ``cfg.train.max_epochs`` epochs or as soon as
    ``global_step`` exceeds ``cfg.train.max_steps``.

    Args:
        cfg: Experiment configuration object.
    """
    # --- Run summary -------------------------------------------------------
    print('Experiment name:', cfg.exp_name)
    print('Dataset:', cfg.dataset)
    print('Model name:', cfg.model)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:', cfg.resume_ckpt or 'last checkpoint')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    # --- Data, model and helpers -------------------------------------------
    print('Loading data')

    trainloader = get_dataloader(cfg, 'train')
    if cfg.train.eval_on:
        valset = get_dataset(cfg, 'val')
        evaluator = get_evaluator(cfg)
    model = get_model(cfg).to(cfg.device)
    ckpt_dir = osp.join(cfg.checkpointdir, cfg.exp_name)
    checkpointer = Checkpointer(ckpt_dir, max_num=cfg.train.max_ckpt)
    model.train()

    optimizer_fg, optimizer_bg = get_optimizers(cfg, model)

    # --- Optional resume ---------------------------------------------------
    first_epoch = 0
    global_step = 0
    if cfg.resume:
        restored = checkpointer.load_last(cfg.resume_ckpt, model,
                                          optimizer_fg, optimizer_bg)
        if restored:
            first_epoch = restored['epoch']
            global_step = restored['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    writer = SummaryWriter(log_dir=os.path.join(cfg.logdir, cfg.exp_name),
                           flush_secs=30,
                           purge_step=global_step)
    vis_logger = get_vislogger(cfg)
    metric_logger = MetricLogger()

    # --- Training loop -----------------------------------------------------
    print('Start training')
    stop_training = False
    for epoch in range(first_epoch, cfg.train.max_epochs):
        if stop_training:
            break

        tick = time.perf_counter()
        for batch_idx, batch in enumerate(trainloader):
            # Time spent waiting for the batch.
            now = time.perf_counter()
            data_time = now - tick
            tick = now

            model.train()
            imgs = batch.to(cfg.device)
            loss, log = model(imgs, global_step)
            # DataParallel may return one loss per replica; reduce them.
            loss = loss.mean()
            optimizer_fg.zero_grad()
            optimizer_bg.zero_grad()
            loss.backward()
            if cfg.train.clip_norm:
                clip_grad_norm_(model.parameters(), cfg.train.clip_norm)

            optimizer_fg.step()
            optimizer_bg.step()

            # Time spent on forward/backward/optimizer steps.
            batch_time = time.perf_counter() - tick

            metric_logger.update(data_time=data_time)
            metric_logger.update(batch_time=batch_time)
            metric_logger.update(loss=loss.item())

            if global_step % cfg.train.print_every == 0:
                log_started = time.perf_counter()
                log.update({'loss': metric_logger['loss'].median})
                vis_logger.train_vis(writer, log, global_step, 'train')
                log_time = time.perf_counter() - log_started

                print(
                    'exp: {}, epoch: {}, iter: {}/{}, global_step: {}, loss: {:.2f}, batch time: {:.4f}s, data time: {:.4f}s, log time: {:.4f}s'
                    .format(cfg.exp_name, epoch + 1, batch_idx + 1,
                            len(trainloader), global_step,
                            metric_logger['loss'].median,
                            metric_logger['batch_time'].avg,
                            metric_logger['data_time'].avg, log_time))

            if global_step % cfg.train.save_every == 0:
                save_started = time.perf_counter()
                checkpointer.save_last(model, optimizer_fg, optimizer_bg,
                                       epoch, global_step)
                print('Saving checkpoint takes {:.4f}s.'.format(
                    time.perf_counter() - save_started))

            if global_step % cfg.train.eval_every == 0 and cfg.train.eval_on:
                print('Validating...')
                eval_started = time.perf_counter()
                checkpoint = [
                    model, optimizer_fg, optimizer_bg, epoch, global_step
                ]
                evaluator.train_eval(model, valset, valset.bb_path, writer,
                                     global_step, cfg.device, checkpoint,
                                     checkpointer)
                print('Validation takes {:.4f}s.'.format(
                    time.perf_counter() - eval_started))

            # Restart the timer so logging/saving is not counted as data time.
            tick = time.perf_counter()
            global_step += 1
            if global_step > cfg.train.max_steps:
                stop_training = True
                break
Пример #10
0
def train(cfg):
    """Train the model described by ``cfg`` with a single optimizer.

    Random number generators are seeded from ``cfg.seed`` and cuDNN is put in
    deterministic mode before anything else happens. Every iteration runs a
    forward pass, back-propagates the mean loss and steps the optimizer.
    Depending on ``global_step + 1`` the loop also logs
    (``cfg.train.print_every``), saves a checkpoint
    (``cfg.train.save_every``), visualizes (``cfg.vis.vis_every`` when
    ``cfg.vis.ison``) and evaluates (``cfg.val.val_every`` when
    ``cfg.val.ison``).

    Training ends after ``cfg.train.max_epochs`` epochs or as soon as
    ``global_step`` reaches ``cfg.train.max_steps``.

    Args:
        cfg: Experiment configuration object.
    """
    # Seed every RNG in use (the duplicated torch.manual_seed call that used
    # to follow np.random.seed was redundant and has been dropped).
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.cuda.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Some info
    print('Experiment name:', cfg.exp_name)
    print('Model name:', cfg.model)
    print('Dataset:', cfg.dataset)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:',
              cfg.resume_ckpt if cfg.resume_ckpt else 'last checkpoint')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    print('\nLoading data...')

    trainloader = get_dataloader(cfg, 'train')
    if cfg.val.ison or cfg.vis.ison:
        # Only needed by the visualization / evaluation steps below.
        valset = get_dataset(cfg, 'val')
        valloader = get_dataloader(cfg, 'val')
    print('Data loaded.')

    print('Initializing model...')
    model = get_model(cfg)
    model = model.to(cfg.device)
    print('Model initialized.')
    model.train()

    optimizer = get_optimizer(cfg, model)

    # Checkpointer will print information.
    checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name),
                                max_num=cfg.train.max_ckpt)

    start_epoch = 0
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, optimizer)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_step'] + 1
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    writer = SummaryWriter(log_dir=os.path.join(cfg.logdir, cfg.exp_name),
                           purge_step=global_step,
                           flush_secs=30)
    metric_logger = MetricLogger()
    vis_logger = get_vislogger(cfg)
    evaluator = get_evaluator(cfg)

    print('Start training')
    end_flag = False
    for epoch in range(start_epoch, cfg.train.max_epochs):
        if end_flag:
            break
        start = time.perf_counter()
        for i, data in enumerate(trainloader):
            # Time spent waiting for the batch.
            end = time.perf_counter()
            data_time = end - start
            start = end

            # Every element of the batch is moved to the device; only the
            # first one is used as model input.
            imgs, *_ = [d.to(cfg.device) for d in data]
            model.train()
            loss, log = model(imgs, global_step)
            # If you are using DataParallel
            loss = loss.mean()
            optimizer.zero_grad()
            loss.backward()
            if cfg.train.clip_norm:
                clip_grad_norm_(model.parameters(), cfg.train.clip_norm)
            optimizer.step()

            # Time spent on forward/backward/optimizer step.
            end = time.perf_counter()
            batch_time = end - start

            metric_logger.update(data_time=data_time)
            metric_logger.update(batch_time=batch_time)
            metric_logger.update(loss=loss.item())

            if (global_step + 1) % cfg.train.print_every == 0:
                start = time.perf_counter()
                log.update(loss=metric_logger['loss'].median)
                vis_logger.model_log_vis(writer, log, global_step + 1)
                end = time.perf_counter()
                device_text = cfg.device_ids if cfg.parallel else cfg.device
                print(
                    'exp: {}, device: {}, epoch: {}, iter: {}/{}, global_step: {}, loss: {:.2f}, batch time: {:.4f}s, data time: {:.4f}s, log time: {:.4f}s'
                    .format(cfg.exp_name, device_text, epoch + 1, i + 1,
                            len(trainloader), global_step + 1,
                            metric_logger['loss'].median,
                            metric_logger['batch_time'].avg,
                            metric_logger['data_time'].avg, end - start))

            if (global_step + 1) % cfg.train.save_every == 0:
                start = time.perf_counter()
                checkpointer.save(model, optimizer, epoch, global_step)
                print('Saving checkpoint takes {:.4f}s.'.format(
                    time.perf_counter() - start))

            if (global_step + 1) % cfg.vis.vis_every == 0 and cfg.vis.ison:
                print('Doing visualization...')
                start = time.perf_counter()
                # NOTE(review): this passes global_step while model_log_vis
                # above receives global_step + 1 -- confirm this is intended.
                vis_logger.train_vis(model,
                                     valset,
                                     writer,
                                     global_step,
                                     cfg.vis.indices,
                                     cfg.device,
                                     cond_steps=cfg.vis.cond_steps,
                                     fg_sample=cfg.vis.fg_sample,
                                     bg_sample=cfg.vis.bg_sample,
                                     num_gen=cfg.vis.num_gen)
                print(
                    'Visualization takes {:.4f}s.'.format(time.perf_counter() -
                                                          start))

            if (global_step + 1) % cfg.val.val_every == 0 and cfg.val.ison:
                print('Doing evaluation...')
                start = time.perf_counter()
                # NOTE(review): the evaluator is also passed as the first
                # positional argument; kept as-is, verify against train_eval.
                evaluator.train_eval(
                    evaluator, os.path.join(cfg.evaldir,
                                            cfg.exp_name), cfg.val.metrics,
                    cfg.val.eval_types, cfg.val.intervals, cfg.val.cond_steps,
                    model, valset, valloader, cfg.device, writer, global_step,
                    [model, optimizer, epoch, global_step], checkpointer)
                print('Evaluation takes {:.4f}s.'.format(time.perf_counter() -
                                                         start))

            # Restart the timer so logging/saving is not counted as data time.
            start = time.perf_counter()
            global_step += 1
            if global_step >= cfg.train.max_steps:
                end_flag = True
                break
Пример #11
0
def vis_maze(cfg, genlen=50, num_gen=4, cond_steps=5):
    """Generate several maze rollouts per sample and save them as 'maze.gif'.

    For every batch of the ``cfg.vis.indices`` subset, the model's
    ``generate`` is called ``num_gen`` times on the first ``genlen`` frames.
    Each reconstruction is converted to a ``uint8`` array (values scaled by
    255) and collected per sample. ``maze_vis`` turns the collected sequences
    into frames, which are written to
    ``<cfg.resultdir>/<cfg.exp_name>/maze.gif``.

    Args:
        cfg: Experiment configuration object; ``cfg.val.mode`` must be
            ``'val'``.
        genlen: Number of leading frames of each sequence kept as input.
        num_gen: Number of generations produced per batch.
        cond_steps: Forwarded to ``model.generate`` as ``cond_steps``.
    """
    assert cfg.val.mode == 'val'
    vis_indices = cfg.vis.indices

    print('\nLoading data...')
    subset = Subset(get_dataset(cfg, cfg.val.mode), indices=list(vis_indices))
    loader = DataLoader(subset,
                        batch_size=cfg.val.batch_size,
                        num_workers=4)
    print('Data loaded.')

    print('Initializing model...')
    net = get_model(cfg).to(cfg.device)
    print('Model initialized.')

    ckpt_dir = os.path.join(cfg.checkpointdir, cfg.exp_name)
    checkpointer = Checkpointer(ckpt_dir, max_num=cfg.train.max_ckpt)

    resumed_step = 0
    if cfg.resume:
        state = checkpointer.load(cfg.resume_ckpt, net, None)
        if state:
            # Read from the checkpoint but not used further in this routine.
            resumed_epoch = state['epoch']
            resumed_step = state['global_step'] + 1
    if cfg.parallel:
        net = nn.DataParallel(net, device_ids=cfg.device_ids)

    print("Maze...")
    started_at = time.perf_counter()
    net.eval()
    results = {}

    def generate(module, frames):
        """Sample one rollout: foreground sampled, background not."""
        return module.generate(frames,
                               cond_steps=cond_steps,
                               fg_sample=True,
                               bg_sample=False)

    all_sequences = []
    for batch in tqdm(loader):
        batch = [item.to(cfg.device) for item in batch]
        # Only the first element (the images) is used; the original author
        # annotated it as (B, T, C, H, W).
        imgs, *_ = batch
        imgs = imgs[:, :genlen]
        B, T, C, H, W = imgs.size()

        # One list of generated rollouts per sample in the batch.
        per_sample = [[] for _ in range(B)]
        for _ in range(num_gen):
            out = AttrDict(generate(net, imgs))
            recon = out.recon
            for b in range(B):
                # Move the channel axis last and convert to 8-bit values.
                rollout = recon[b].permute(0, 2, 3, 1).cpu().numpy() * 255
                per_sample[b].append(rollout.astype(np.uint8))

        all_sequences.extend(per_sample)

    gif_frames = maze_vis(all_sequences)
    out_dir = os.path.join(cfg.resultdir, cfg.exp_name)
    os.makedirs(out_dir, exist_ok=True)
    make_gif(gif_frames, os.path.join(out_dir, 'maze.gif'))
Пример #12
0
def eval_balls(cfg):
    """Evaluate tracking and/or generation on the test split and dump results.

    For every entry of ``cfg.val.eval_types`` the evaluator is run with a
    matching model function (``model.track`` for ``'tracking'``,
    ``model.generate`` for ``'generation'``). The three returned summaries are
    dumped to JSON via ``evaluator.dump_to_json``. When generation was
    evaluated and ``'med'`` is in ``cfg.val.metrics``, the
    ``'meds_over_time'`` curve is also plotted to ``plot_balls.png``.

    Args:
        cfg: Experiment configuration object; ``cfg.val.mode`` must be
            ``'test'``.

    Raises:
        ValueError: If ``cfg.val.eval_types`` contains anything other than
            ``'tracking'`` or ``'generation'``.
    """
    print('\nLoading data...')
    assert cfg.val.mode == 'test', 'Please set cfg.val.mode to "test"'
    dataset = get_dataset(cfg, cfg.val.mode)
    dataloader = get_dataloader(cfg, cfg.val.mode)
    print('Data loaded.')

    print('Initializing model...')
    model = get_model(cfg)
    model = model.to(cfg.device)
    print('Model initialized.')

    checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name),
                                max_num=cfg.train.max_ckpt)

    if cfg.resume:
        # Only the restored weights matter here; the returned bookkeeping
        # (epoch / global step) is not needed for evaluation.
        checkpointer.load(cfg.resume_ckpt, model, None)
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    evaluator = get_evaluator(cfg)

    evaldir = os.path.join(cfg.evaldir, cfg.exp_name)
    os.makedirs(evaldir, exist_ok=True)
    print("Evaluating...")
    start = time.perf_counter()
    model.eval()

    # One model function per supported evaluation type.
    model_fns = {
        'tracking':
        lambda model, imgs: model.track(imgs, discovery_dropout=0),
        'generation':
        lambda model, imgs: model.generate(imgs,
                                           cond_steps=cfg.val.cond_steps,
                                           fg_sample=False,
                                           bg_sample=False),
    }

    results = {}
    for eval_type in cfg.val.eval_types:
        if eval_type not in model_fns:
            # Previously an unknown type either hit an unbound model_fn or
            # silently reused the function of the preceding type.
            raise ValueError(
                'Unknown eval type {!r}; expected one of {}'.format(
                    eval_type, sorted(model_fns)))
        model_fn = model_fns[eval_type]

        print(f'Evaluating {eval_type}...')
        # For generation, cfg.val.cond_steps is passed as the evaluator's
        # skip argument; tracking passes 0.
        skip = cfg.val.cond_steps if eval_type == 'generation' else 0
        (iou_summary, euclidean_summary,
         med_summary) = evaluator.evaluate(eval_type, model, model_fn, skip,
                                           dataset, dataloader, evaldir,
                                           cfg.device, cfg.val.metrics)

        results[eval_type] = [iou_summary, euclidean_summary, med_summary]

    for eval_type in cfg.val.eval_types:
        evaluator.dump_to_json(*results[eval_type], evaldir, 'ours',
                               cfg.dataset.lower(), eval_type, cfg.run_num,
                               cfg.resume_ckpt, cfg.exp_name)
    print('Evaluation takes {}s.'.format(time.perf_counter() - start))

    # Plot figure
    if 'generation' in cfg.val.eval_types and 'med' in cfg.val.metrics:
        med_list = results['generation'][-1]['meds_over_time']
        # NOTE(review): the curve length (90) and x-range (10..99) are
        # hard-coded -- presumably 100-frame sequences with 10 conditioning
        # steps; confirm before using other settings.
        assert len(med_list) == 90
        steps = np.arange(10, 100)
        f, ax = plt.subplots()
        ax.plot(steps, med_list)
        ax.set_xlabel('Time step')
        ax.set_ylim(0.0, 0.6)
        ax.set_ylabel('Position error')
        ax.set_title(cfg.exp_name)
        plot_path = os.path.join(evaldir, 'plot_balls.png')
        plt.savefig(plot_path)
        print('Plot saved to', plot_path)
        print('MED summed over the first 10 prediction steps: ',
              sum(med_list[:10]))