コード例 #1
0
def train(args):
    """Train a model with an event-driven trainer, validating after each epoch.

    Builds the data iterators, model, loss, optimizer, trainer and evaluator
    through the project's ``get_*`` factories. It then registers three trainer
    handlers:
    - a start banner;
    - per-iteration logging;
    - per-epoch logging plus one validation pass over ``iters['val']``.
    Finally it runs the trainer over ``iters['train']`` for
    ``args.max_epochs`` epochs.
    """
    iters, vocab = get_iterator(args)

    model = get_model(args, vocab)
    loss_fn = get_loss(args, vocab)
    optimizer = get_optimizer(args, model)

    trainer = get_trainer(args, model, loss_fn, optimizer)
    evaluator = get_evaluator(args, model, loss_fn, get_metrics(args, vocab))

    logger = get_logger(args)

    @trainer.on(Events.STARTED)
    def announce_start(engine):
        print("Begin Training")

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_iteration(engine):
        state = engine.state
        log_results(logger, 'train/iter', state, state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate_after_epoch(engine):
        train_state = engine.state
        epoch = train_state.epoch
        log_results(logger, 'train/epoch', train_state, epoch)
        # Validation runs on the separate evaluator engine.
        val_state = evaluate_once(evaluator, iterator=iters['val'])
        log_results(logger, 'valid/epoch', val_state, epoch)

    trainer.run(iters['train'], max_epochs=args.max_epochs)
コード例 #2
0
ファイル: train.py プロジェクト: sjoerdapp/vtt_challenge_2019
def train(args):
    """Train from a (possibly restored) checkpoint, validating and saving each epoch.

    ``get_model_ckpt`` returns the (possibly updated) args, model, iterators,
    vocabulary and whether a checkpoint was loaded. After building the loss,
    optimizer, trainer and evaluator, it registers handlers that:
    - log every iteration;
    - at the end of each epoch, log the epoch, validate on ``iters['val']``
      and call ``save_ckpt`` with the trainer's ``'loss'`` metric.
    """
    args, model, iters, vocab, ckpt_available = get_model_ckpt(args)
    if ckpt_available:
        print("loaded checkpoint {}".format(args.ckpt_name))

    loss_fn = get_loss(args, vocab)
    optimizer = get_optimizer(args, model)
    trainer = get_trainer(args, model, loss_fn, optimizer)
    evaluator = get_evaluator(args, model, loss_fn, get_metrics(args, vocab))
    logger = get_logger(args)

    @trainer.on(Events.STARTED)
    def announce_start(engine):
        print("Begin Training")

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_iteration(engine):
        state = engine.state
        log_results(logger, 'train/iter', state, state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate_and_checkpoint(engine):
        train_state = engine.state
        epoch = train_state.epoch
        log_results(logger, 'train/epoch', train_state, epoch)
        val_state = evaluate_once(evaluator, iterator=iters['val'])
        log_results(logger, 'valid/epoch', val_state, epoch)
        # The checkpoint uses the trainer engine's own 'loss' metric, not
        # the validation state's.
        save_ckpt(args, epoch, train_state.metrics['loss'], model, vocab)

    trainer.run(iters['train'], max_epochs=args.max_epochs)
コード例 #3
0
def train(args):
    """Optionally pretrain, then train with per-epoch validation and checkpointing.

    ``get_model_ckpt`` returns the (possibly updated) args, model, data
    iterators, vocabulary and whether a checkpoint was loaded.

    If ``args.pretrain_epochs > 0``, a separate pretrainer engine first runs
    over ``iters['pretrain']`` for that many epochs. The main trainer then
    runs over ``iters['train']`` for ``args.max_epochs`` epochs.

    After every training epoch, the handler:
    - validates on ``iters['val']`` and logs the results (logger and command
      line);
    - saves a checkpoint;
    - runs ``evaluate_by_logic_level``.
    """
    args, model, iters, vocab, ckpt_available = get_model_ckpt(args)

    if ckpt_available:
        print("loaded checkpoint {}".format(args.ckpt_name))
    loss_fn = get_loss(args, vocab)
    optimizer = get_optimizer(args, model)

    # Both engines are built from the same model, loss and optimizer.
    pretrainer = get_pretrainer(args, model, loss_fn, optimizer)
    trainer = get_trainer(args, model, loss_fn, optimizer)

    metrics = get_metrics(args, vocab)
    evaluator = get_evaluator(args, model, loss_fn, metrics)

    logger = get_logger(args)

    # ---- Pretraining handlers -------------------------------------------
    # Each handler has a unique name so the training handlers below do not
    # rebind (shadow) these local functions.
    @pretrainer.on(Events.STARTED)
    def on_pretraining_started(engine):
        print("Begin Pretraining")

    @pretrainer.on(Events.ITERATION_COMPLETED)
    def log_pretrain_iter_results(engine):
        log_results(logger, 'pretrain/iter', engine.state, engine.state.iteration)

    @pretrainer.on(Events.EPOCH_COMPLETED)
    def log_pretrain_epoch_results(engine):
        # Logging only: no validation pass during pretraining.
        log_results(logger, 'pretrain/epoch', engine.state, engine.state.epoch)

    # Disabled: unfreeze the language model once pretraining completes.
    # @pretrainer.on(Events.COMPLETED)
    # def unfreeze_language_model(engine):
    #     for param in model.module.language_model.base_model.parameters():
    #         param.requires_grad = True

    # ---- Training handlers ----------------------------------------------
    @trainer.on(Events.STARTED)
    def on_training_started(engine):
        print("Begin Training")

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_train_iter_results(engine):
        log_results(logger, 'train/iter', engine.state, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate_train_epoch(engine):
        log_results(logger, 'train/epoch', engine.state, engine.state.epoch)
        state = evaluate_once(evaluator, iterator=iters['val'])
        log_results(logger, 'valid/epoch', state, engine.state.epoch)
        log_results_cmd('valid/epoch', state, engine.state.epoch)
        # Checkpoint uses the trainer engine's 'loss' metric.
        save_ckpt(args, engine.state.epoch, engine.state.metrics['loss'], model, vocab)
        evaluate_by_logic_level(args, model, iterator=iters['val'])

    if args.pretrain_epochs > 0:
        pretrainer.run(iters['pretrain'], max_epochs=args.pretrain_epochs)
    trainer.run(iters['train'], max_epochs=args.max_epochs)
コード例 #4
0
ファイル: eval_maze.py プロジェクト: zhixuan-lin/G-SWM
def eval_maze(cfg, cond_steps=5):
    """Evaluate a trained model on the maze test split and plot agent counts.

    Restores the model from ``cfg.resume_ckpt`` when ``cfg.resume`` is set and
    runs the project evaluator with its output directory set to
    ``<cfg.evaldir>/<cfg.exp_name>``. It then reads ``maze-<exp_name>.json``
    from that directory. The evaluator is expected to have written this file;
    TODO confirm. The ``num_mean`` series from the file is plotted to
    ``plot_maze.png`` in the same directory.

    Args:
        cfg: experiment config; ``cfg.val.mode`` must be ``'test'``.
        cond_steps: number of conditioning steps forwarded to the evaluator.
    """
    print('\nLoading data...')
    assert cfg.val.mode == 'test', 'Please set cfg.val.mode to "test"'
    # NOTE(review): the dataset object itself is unused below; the call is kept
    # in case get_dataset has side effects the dataloader relies on — confirm.
    get_dataset(cfg, cfg.val.mode)
    dataloader = get_dataloader(cfg, cfg.val.mode)
    print('Data loaded.')

    print('Initializing model...')
    model = get_model(cfg)
    model = model.to(cfg.device)
    print('Model initialized.')

    checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name),
                                max_num=cfg.train.max_ckpt)

    if cfg.resume:
        # Evaluation only needs the weights: no optimizer is restored, and
        # the epoch/step counters stored in the checkpoint are not used.
        checkpointer.load(cfg.resume_ckpt, model, None)
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    evaluator = get_evaluator(cfg)

    evaldir = os.path.join(cfg.evaldir, cfg.exp_name)
    os.makedirs(evaldir, exist_ok=True)
    start = time.perf_counter()
    model.eval()

    evaluator.evaluate(model, dataloader, cond_steps, cfg.device, evaldir,
                       cfg.exp_name, cfg.resume_ckpt)
    print('Evaluation takes {}s.'.format(time.perf_counter() - start))

    jsonpath = os.path.join(evaldir, 'maze-{}.json'.format(cfg.exp_name))
    with open(jsonpath, encoding='utf-8') as json_file:
        metrics = json.load(json_file)
    num_mean = metrics['num_mean']

    fig, ax = plt.subplots()
    ax: plt.Axes
    ax.plot(num_mean)
    ax.set_ylim(0, 3.5)
    ax.set_xlabel('Time step')
    ax.set_ylabel('#Agents')
    ax.set_title(cfg.exp_name)
    fig.savefig(os.path.join(evaldir, 'plot_maze.png'))
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
コード例 #5
0
def train(args):
    """Train a model loaded (or built) by ``get_model_ckpt``.

    Builds the loss, optimizer and trainer, then runs the trainer over
    ``iters['train']`` for ``args.max_epochs`` epochs.
    """
    args, model, iters, ckpt_available = get_model_ckpt(args)

    if ckpt_available:
        print("loaded checkpoint {}".format(args.ckpt_name))
    loss_fn = get_loss(args)
    optimizer = get_optimizer(args, model)

    trainer = get_trainer(args, model, loss_fn, optimizer)

    # NOTE(review): metrics, evaluator and logger are built but never attached
    # to the trainer, so no validation or logging happens in this function.
    # Confirm whether event handlers are missing here.
    metrics = get_metrics(args)
    evaluator = get_evaluator(args, model, loss_fn, metrics)

    logger = get_logger(args)

    # Fixed an unbalanced parenthesis that made this line a SyntaxError.
    trainer.run(iters['train'], max_epochs=args.max_epochs)
コード例 #6
0
def train(cfg):
    """Run the main optimization loop described by ``cfg``.

    Seeds the RNGs and makes cuDNN deterministic. Builds the train (and,
    when validation or visualization is on, the val) data, model, optimizer
    and checkpointer, and optionally resumes from ``cfg.resume_ckpt``. It
    then trains until ``cfg.train.max_epochs`` epochs or
    ``cfg.train.max_steps`` global steps, whichever comes first.

    Every N global steps, with N taken from the config, it also:
    - ``print_every``: logs to TensorBoard/stdout;
    - ``save_every``: saves a checkpoint;
    - ``vis_every``: renders visualizations;
    - ``val_every``: runs validation.

    Args:
        cfg: experiment configuration (reads seed, exp_name, model, dataset,
            resume, resume_ckpt, device, parallel, device_ids, checkpointdir,
            logdir, evaldir and the ``train``/``val``/``vis`` sub-configs).
    """
    # Reproducibility: seed torch (CPU + current CUDA device) and numpy, and
    # force deterministic cuDNN kernels.
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.cuda.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Some info
    print('Experiment name:', cfg.exp_name)
    print('Model name:', cfg.model)
    print('Dataset:', cfg.dataset)
    print('Resume:', cfg.resume)
    if cfg.resume:
        print('Checkpoint:',
              cfg.resume_ckpt if cfg.resume_ckpt else 'last checkpoint')
    print('Using device:', cfg.device)
    if 'cuda' in cfg.device:
        print('Using parallel:', cfg.parallel)
    if cfg.parallel:
        print('Device ids:', cfg.device_ids)

    print('\nLoading data...')

    trainloader = get_dataloader(cfg, 'train')
    # valset/valloader exist only when validation or visualization is on;
    # they are used only under those same flags below.
    if cfg.val.ison or cfg.vis.ison:
        valset = get_dataset(cfg, 'val')
        valloader = get_dataloader(cfg, 'val')
    print('Data loaded.')

    print('Initializing model...')
    model = get_model(cfg)
    model = model.to(cfg.device)
    print('Model initialized.')
    model.train()

    optimizer = get_optimizer(cfg, model)

    # Checkpointer will print information.
    checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name),
                                max_num=cfg.train.max_ckpt)

    start_epoch = 0
    global_step = 0
    if cfg.resume:
        checkpoint = checkpointer.load(cfg.resume_ckpt, model, optimizer)
        if checkpoint:
            start_epoch = checkpoint['epoch']
            # The stored step was already completed; continue after it.
            global_step = checkpoint['global_step'] + 1
    # Wrap after loading so checkpoint keys match the unwrapped model.
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    # purge_step discards TensorBoard events logged after the resume point.
    writer = SummaryWriter(log_dir=os.path.join(cfg.logdir, cfg.exp_name),
                           purge_step=global_step,
                           flush_secs=30)
    metric_logger = MetricLogger()
    vis_logger = get_vislogger(cfg)
    evaluator = get_evaluator(cfg)

    print('Start training')
    end_flag = False
    for epoch in range(start_epoch, cfg.train.max_epochs):
        if end_flag: break
        # `start` is reused for several timings. It is reset at the end of
        # every iteration, so data_time measures only the loader's wait.
        start = time.perf_counter()
        for i, data in enumerate(trainloader):
            end = time.perf_counter()
            data_time = end - start
            start = end

            imgs, *_ = [d.to(cfg.device) for d in data]
            # Validation/visualization may switch the model to eval mode.
            model.train()
            loss, log = model(imgs, global_step)
            # If you are using DataParallel
            loss = loss.mean()
            optimizer.zero_grad()
            loss.backward()
            if cfg.train.clip_norm:
                clip_grad_norm_(model.parameters(), cfg.train.clip_norm)
            optimizer.step()

            end = time.perf_counter()
            batch_time = end - start

            metric_logger.update(data_time=data_time)
            metric_logger.update(batch_time=batch_time)
            metric_logger.update(loss=loss.item())

            if (global_step + 1) % cfg.train.print_every == 0:
                start = time.perf_counter()
                log.update(loss=metric_logger['loss'].median)
                vis_logger.model_log_vis(writer, log, global_step + 1)
                end = time.perf_counter()
                device_text = cfg.device_ids if cfg.parallel else cfg.device
                print(
                    'exp: {}, device: {}, epoch: {}, iter: {}/{}, global_step: {}, loss: {:.2f}, batch time: {:.4f}s, data time: {:.4f}s, log time: {:.4f}s'
                    .format(cfg.exp_name, device_text, epoch + 1, i + 1,
                            len(trainloader), global_step + 1,
                            metric_logger['loss'].median,
                            metric_logger['batch_time'].avg,
                            metric_logger['data_time'].avg, end - start))

            if (global_step + 1) % cfg.train.save_every == 0:
                start = time.perf_counter()
                checkpointer.save(model, optimizer, epoch, global_step)
                print('Saving checkpoint takes {:.4f}s.'.format(
                    time.perf_counter() - start))

            if (global_step + 1) % cfg.vis.vis_every == 0 and cfg.vis.ison:
                print('Doing visualization...')
                start = time.perf_counter()
                # NOTE(review): this passes global_step while model_log_vis
                # above uses global_step + 1. Confirm the intended step index.
                vis_logger.train_vis(model,
                                     valset,
                                     writer,
                                     global_step,
                                     cfg.vis.indices,
                                     cfg.device,
                                     cond_steps=cfg.vis.cond_steps,
                                     fg_sample=cfg.vis.fg_sample,
                                     bg_sample=cfg.vis.bg_sample,
                                     num_gen=cfg.vis.num_gen)
                print(
                    'Visualization takes {:.4f}s.'.format(time.perf_counter() -
                                                          start))

            if (global_step + 1) % cfg.val.val_every == 0 and cfg.val.ison:
                print('Doing evaluation...')
                start = time.perf_counter()
                # NOTE(review): the evaluator is passed to its own method as
                # the first argument. Verify this matches train_eval's signature.
                evaluator.train_eval(
                    evaluator, os.path.join(cfg.evaldir,
                                            cfg.exp_name), cfg.val.metrics,
                    cfg.val.eval_types, cfg.val.intervals, cfg.val.cond_steps,
                    model, valset, valloader, cfg.device, writer, global_step,
                    [model, optimizer, epoch, global_step], checkpointer)
                print('Evaluation takes {:.4f}s.'.format(time.perf_counter() -
                                                         start))

            # Restart the data-loading timer for the next batch.
            start = time.perf_counter()
            global_step += 1
            if global_step >= cfg.train.max_steps:
                end_flag = True
                break
コード例 #7
0
def eval_balls(cfg):
    """Evaluate a trained model on the balls test split for each eval type.

    For every entry of ``cfg.val.eval_types`` (``'tracking'`` or
    ``'generation'``), it runs ``evaluator.evaluate`` with the matching model
    call and collects the IoU, Euclidean and MED summaries. Each result set
    is then dumped via ``evaluator.dump_to_json``.

    When generation and the ``'med'`` metric are both requested, the per-step
    MED curve is plotted to ``plot_balls.png`` in
    ``<cfg.evaldir>/<cfg.exp_name>``.

    Args:
        cfg: experiment config; ``cfg.val.mode`` must be ``'test'``.

    Raises:
        ValueError: if ``cfg.val.eval_types`` contains an unsupported type.
    """
    # Model call used for each evaluation type. Named functions (instead of
    # lambdas re-bound in a loop) guarantee an unknown type cannot silently
    # reuse the previous type's function.
    def track_fn(model, imgs):
        return model.track(imgs, discovery_dropout=0)

    def generate_fn(model, imgs):
        return model.generate(imgs,
                              cond_steps=cfg.val.cond_steps,
                              fg_sample=False,
                              bg_sample=False)

    model_fns = {'tracking': track_fn, 'generation': generate_fn}

    # Fail fast, before any data loading or model work.
    unknown = [t for t in cfg.val.eval_types if t not in model_fns]
    if unknown:
        raise ValueError(f'Unsupported eval types: {unknown!r}; '
                         f'expected any of {sorted(model_fns)!r}')

    print('\nLoading data...')
    assert cfg.val.mode == 'test', 'Please set cfg.val.mode to "test"'
    dataset = get_dataset(cfg, cfg.val.mode)
    dataloader = get_dataloader(cfg, cfg.val.mode)
    print('Data loaded.')

    print('Initializing model...')
    model = get_model(cfg)
    model = model.to(cfg.device)
    print('Model initialized.')

    checkpointer = Checkpointer(os.path.join(cfg.checkpointdir, cfg.exp_name),
                                max_num=cfg.train.max_ckpt)

    if cfg.resume:
        # Evaluation only needs the weights: no optimizer is restored, and
        # the epoch/step counters stored in the checkpoint are not used.
        checkpointer.load(cfg.resume_ckpt, model, None)
    if cfg.parallel:
        model = nn.DataParallel(model, device_ids=cfg.device_ids)

    evaluator = get_evaluator(cfg)

    evaldir = os.path.join(cfg.evaldir, cfg.exp_name)
    os.makedirs(evaldir, exist_ok=True)
    print("Evaluating...")
    start = time.perf_counter()
    model.eval()
    results = {}
    for eval_type in cfg.val.eval_types:
        model_fn = model_fns[eval_type]

        print(f'Evaluating {eval_type}...')
        # Generation metrics skip the conditioning frames.
        skip = cfg.val.cond_steps if eval_type == 'generation' else 0
        (iou_summary, euclidean_summary,
         med_summary) = evaluator.evaluate(eval_type, model, model_fn, skip,
                                           dataset, dataloader, evaldir,
                                           cfg.device, cfg.val.metrics)

        results[eval_type] = [iou_summary, euclidean_summary, med_summary]

    for eval_type in cfg.val.eval_types:
        evaluator.dump_to_json(*results[eval_type], evaldir, 'ours',
                               cfg.dataset.lower(), eval_type, cfg.run_num,
                               cfg.resume_ckpt, cfg.exp_name)
    print('Evaluation takes {}s.'.format(time.perf_counter() - start))

    # Plot figure
    if 'generation' in cfg.val.eval_types and 'med' in cfg.val.metrics:
        med_list = results['generation'][-1]['meds_over_time']
        # NOTE(review): hard-coded to 90 predicted steps plotted at time steps
        # 10..99 (presumably cond_steps=10 on 100-frame sequences) — confirm.
        assert len(med_list) == 90
        steps = np.arange(10, 100)
        fig, ax = plt.subplots()
        ax: plt.Axes
        ax.plot(steps, med_list)
        ax.set_xlabel('Time step')
        ax.set_ylim(0.0, 0.6)
        ax.set_ylabel('Position error')
        ax.set_title(cfg.exp_name)
        plot_path = os.path.join(evaldir, 'plot_balls.png')
        fig.savefig(plot_path)
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        print('Plot saved to', plot_path)
        print('MED summed over the first 10 prediction steps: ',
              sum(med_list[:10]))