Example #1
def do_test(cfg, model, dataloader, logger, task, load_ckpt):
    if load_ckpt is not None:
        checkpointer = Checkpointer(model,
                                    save_dir=cfg.OUTPUT_DIR,
                                    logger=logger,
                                    monitor_unit='episode')
        checkpointer.load_checkpoint(load_ckpt)

    # the model exposes its own metric evaluator for accumulated top-k accuracy
    val_metrics = model.metric_evaluator

    # switch to evaluation mode and set up running meters for timing
    model.eval()
    num_images = 0
    meters = MetricLogger(delimiter="  ")

    logger.info('Start testing...')
    start_testing_time = time.time()
    end = time.time()
    for iteration, data in enumerate(dataloader):
        data_time = time.time() - end
        # data['images'] is a list of tensors, so concatenate it into one batch
        inputs = torch.cat(data['images']).to(device)
        labels = data['labels'].to(device)
        logits = model(inputs)
        # accumulate top-k statistics across the whole dataloader
        val_metrics.accumulated_update(logits, labels)
        num_images += logits.shape[0]

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (len(dataloader) - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        if iteration % 50 == 0 and iteration > 0:
            logger.info('eta: {}, iter: {}/{}'.format(eta_string, iteration,
                                                      len(dataloader)))

    val_metrics.gather_results()
    logger.info('num of images: {}'.format(num_images))
    logger.info('{} top1 acc: {:.4f}'.format(
        task, val_metrics.accumulated_topk_corrects['top1_acc']))

    total = time.time() - start_testing_time
    total_time_str = str(datetime.timedelta(seconds=total))

    logger.info("Total testing time: {}".format(total_time_str))
    return val_metrics
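A note on the interface used above: accumulated_update, gather_results and accumulated_topk_corrects belong to the model's metric evaluator, which is not part of this snippet. A minimal sketch of what such an accumulator could look like, assuming it only tracks top-1 accuracy (the class below is hypothetical, not the project's actual evaluator):

import torch


class SimpleTopkEvaluator:
    """Hypothetical stand-in for model.metric_evaluator in the example above."""

    def __init__(self):
        self._correct = 0
        self._total = 0
        self.accumulated_topk_corrects = {}

    def accumulated_update(self, logits, labels):
        # count top-1 hits in this batch and add them to the running totals
        preds = logits.argmax(dim=1)
        self._correct += (preds == labels).sum().item()
        self._total += labels.numel()

    def gather_results(self):
        # turn the accumulated counts into a final accuracy value
        acc = 100.0 * self._correct / max(self._total, 1)
        self.accumulated_topk_corrects['top1_acc'] = acc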
Example #2
def do_test(cfg, model, dataloader, logger, task, load_ckpt):
    checkpointer = Checkpointer(model, save_dir=cfg.OUTPUT_DIR, logger=logger)
    checkpointer.load_checkpoint(load_ckpt)
    val_metrics = TransferNetMetrics(cfg)

    model.eval()
    num_images = 0
    logger.info('Start testing...')
    for iteration, data in enumerate(dataloader):
        inputs, targets = data['image'].to(device), data['label_articleType'].to(device)
        cls_scores = model(inputs, targets)
        # accumulate top-1/top-5 and per-class statistics over the whole set
        val_metrics.accumulated_update(cls_scores, targets)
        num_images += len(targets)

    val_metrics.gather_results()
    logger.info('num of images: {}'.format(num_images))
    logger.info('{} top1/5 acc: {:.4f}/{:.4f}, mean class acc: {:.4f}'.format(
        task,
        val_metrics.accumulated_topk_corrects['top1_acc'],
        val_metrics.accumulated_topk_corrects['top5_acc'],
        val_metrics.mean_class_accuracy))
    return val_metrics
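The mean class accuracy reported here comes from TransferNetMetrics, which is not shown in the snippet. As a hedged illustration of how a mean-class accuracy can be computed from scores and targets (the function name and signature below are made up for this sketch):

import torch


def mean_class_accuracy(cls_scores, targets, num_classes):
    # average the per-class accuracies over the classes that actually appear
    preds = cls_scores.argmax(dim=1)
    correct = torch.zeros(num_classes)
    total = torch.zeros(num_classes)
    for cls in range(num_classes):
        mask = targets == cls
        total[cls] = mask.sum()
        correct[cls] = (preds[mask] == cls).sum()
    present = total > 0
    return (correct[present] / total[present]).mean().item()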
Example #3
def do_train(cfg, model, train_dataloader, logger, load_ckpt=None):
    # define optimizer
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM,
                              weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # define learning rate scheduler
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                    cfg.TRAIN.LR_DECAY)
    checkpointer = Checkpointer(model,
                                optimizer,
                                lr_scheduler,
                                cfg.OUTPUT_DIR,
                                logger,
                                monitor_unit='episode')

    training_args = {}
    # training_args['iteration'] = 1
    training_args['episode'] = 1
    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)

    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)

    start_episode = training_args['episode']
    episode = training_args['episode']

    meters = MetricLogger(delimiter="  ")
    end = time.time()
    start_training_time = time.time()

    model.train()
    # episodic training: keep cycling through the dataloader until MAX_EPISODE is reached
    break_while = False
    while not break_while:
        for inner_iter, data in enumerate(train_dataloader):
            training_args['episode'] = episode
            data_time = time.time() - end
            # targets = torch.cat(data['labels']).to(device)
            inputs = torch.cat(data['images']).to(device)
            logits = model(inputs)
            losses = model.loss_evaluator(logits)
            metrics = model.metric_evaluator(logits)

            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                eta_seconds = meters.time.global_avg * (cfg.TRAIN.MAX_EPISODE -
                                                        episode)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "episode: {ep}/{max_ep}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        ep=episode,
                        max_ep=cfg.TRAIN.MAX_EPISODE,
                        iter=inner_iter,
                        max_iter=len(train_dataloader),
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(torch.cuda.max_memory_allocated() / 1024.0 /
                                1024.0) if torch.cuda.is_available() else 0.,
                    ))

            if episode % cfg.TRAIN.LR_DECAY_EPISODE == 0:
                logger.info("lr decayed to {:.4f}".format(
                    optimizer.param_groups[-1]["lr"]))
                lr_scheduler.step()

            if episode == cfg.TRAIN.MAX_EPISODE:
                break_while = True
                checkpointer.save("model_{:06d}".format(episode),
                                  **training_args)
                break

            if episode % cfg.TRAIN.SAVE_CKPT_EPISODE == 0:
                checkpointer.save("model_{:06d}".format(episode),
                                  **training_args)

            episode += 1

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))

    logger.info("Total training time: {} ({:.4f} s / epoch)".format(
        total_time_str, total_training_time /
        (episode - start_episode if episode > start_episode else 1)))
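For context on the schedule above: ExponentialLR multiplies the learning rate by cfg.TRAIN.LR_DECAY each time lr_scheduler.step() is called, and the loop only calls it every LR_DECAY_EPISODE episodes, so the rate follows LR_BASE * LR_DECAY ** k after k decays. A standalone sketch of that pattern with illustrative values (0.1 base rate, 0.5 decay, decay every 100 episodes; these are not the project's config values):

import torch
import torch.optim as optim

# a dummy parameter just so the optimizer has something to hold
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

for episode in range(1, 501):
    optimizer.step()  # stands in for the real forward/backward/update
    if episode % 100 == 0:
        lr_scheduler.step()  # lr becomes 0.1 * 0.5 ** k after k decays
        print(episode, optimizer.param_groups[-1]["lr"])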
Example #4
def do_train(cfg, model, train_dataloader, val_dataloader, logger, load_ckpt=None):
    # define optimizer
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=cfg.TRAIN.LR_BASE,
                              momentum=cfg.TRAIN.MOMENTUM, weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        raise NotImplementedError

    # define learning rate scheduler
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, cfg.TRAIN.LR_DECAY)
    checkpointer = Checkpointer(model, optimizer, lr_scheduler, cfg.OUTPUT_DIR, logger)

    training_args = {}
    # training_args['iteration'] = 1
    training_args['epoch'] = 1
    training_args['val_best'] = 0.
    if load_ckpt:
        checkpointer.load_checkpoint(load_ckpt, strict=False)

    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load()
        training_args.update(extra_checkpoint_data)

    # start_iter = training_args['iteration']
    start_epoch = training_args['epoch']
    checkpointer.current_val_best = training_args['val_best']

    meters = MetricLogger(delimiter="  ")
    end = time.time()
    start_training_time = time.time()

    for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH + 1):
        training_args['epoch'] = epoch
        model.train()
        for inner_iter, data in enumerate(train_dataloader):
            # training_args['iteration'] = iteration
            # logger.info('inner_iter: {}, label: {}'.format(inner_iter, data['label_articleType'], len(data['label_articleType'])))
            data_time = time.time() - end
            inputs, targets = data['image'].to(device), data['label_articleType'].to(device)
            cls_scores = model(inputs, targets)
            losses = model.loss_evaluator(cls_scores, targets)
            metrics = model.metric_evaluator(cls_scores, targets)

            total_loss = sum(loss for loss in losses.values())
            meters.update(loss=total_loss, **losses, **metrics)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            if inner_iter % cfg.TRAIN.PRINT_PERIOD == 0:
                # ETA is estimated from the remaining iterations across all epochs
                eta_seconds = meters.time.global_avg * (
                    len(train_dataloader) * cfg.TRAIN.MAX_EPOCH
                    - (epoch - 1) * len(train_dataloader) - inner_iter)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                logger.info(
                    meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "epoch: {ep}/{max_ep} (iter: {iter}/{max_iter})",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "max mem: {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        ep=epoch,
                        max_ep=cfg.TRAIN.MAX_EPOCH,
                        iter=inner_iter,
                        max_iter=len(train_dataloader),
                        meters=str(meters),
                        lr=optimizer.param_groups[-1]["lr"],
                        memory=(torch.cuda.max_memory_allocated() / 1024.0 /
                                1024.0) if torch.cuda.is_available() else 0.,
                    )
                )

        if epoch % cfg.TRAIN.VAL_EPOCH == 0:
            logger.info('start evaluating at epoch {}'.format(epoch))
            val_metrics = do_eval(cfg, model, val_dataloader, logger, 'validation')
            # keep the best checkpoint by mean class accuracy and reset patience on improvement
            if val_metrics.mean_class_accuracy > checkpointer.current_val_best:
                checkpointer.current_val_best = val_metrics.mean_class_accuracy
                training_args['val_best'] = checkpointer.current_val_best
                checkpointer.save("model_{:04d}_val_{:.4f}".format(epoch, checkpointer.current_val_best),
                                  **training_args)
                checkpointer.patience = 0
            else:
                checkpointer.patience += 1

            logger.info('current patience: {}/{}'.format(checkpointer.patience, cfg.TRAIN.PATIENCE))

        if (epoch == cfg.TRAIN.MAX_EPOCH or epoch % cfg.TRAIN.SAVE_CKPT_EPOCH == 0
                or checkpointer.patience == cfg.TRAIN.PATIENCE):
            checkpointer.save("model_{:04d}".format(epoch), **training_args)

        if checkpointer.patience == cfg.TRAIN.PATIENCE:
            logger.info('Max patience triggered. Early terminate training')
            break

        if epoch % cfg.TRAIN.LR_DECAY_EPOCH == 0:
            logger.info("lr decayed to {:.4f}".format(optimizer.param_groups[-1]["lr"]))
            lr_scheduler.step()

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))

    logger.info(
        "Total training time: {} ({:.4f} s / epoch)".format(
            total_time_str, total_training_time / (epoch - start_epoch if epoch > start_epoch else 1)
        )
    )
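Example #4 relies on the Checkpointer tracking patience and current_val_best; those attributes are assumed by the snippet rather than shown. The early-stopping rule it implements (reset patience on a new best validation score, stop once patience reaches the limit) can be sketched in isolation as follows (function and names are illustrative only):

def should_stop(val_history, patience_limit):
    """Return True once the best score has not improved for patience_limit checks."""
    best = float('-inf')
    patience = 0
    for score in val_history:
        if score > best:
            best = score
            patience = 0   # improvement: reset the counter
        else:
            patience += 1  # no improvement at this validation step
        if patience >= patience_limit:
            return True
    return False

# e.g. should_stop([0.60, 0.62, 0.61, 0.61, 0.61], patience_limit=3) returns True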