Example No. 1
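Both examples are excerpted from larger Detectron2/AdelaiDet training scripts, so their imports are not shown. A plausible header covering both functions (the remaining names are project-specific helpers and are assumed to be defined elsewhere in the repo):

import json
import os
import os.path as osp

import torch
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.data import build_detection_train_loader
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.utils import comm
from detectron2.utils.events import (CommonMetricPrinter, EventStorage,
                                     JSONWriter, TensorboardXWriter)
from adet.checkpoint import AdetCheckpointer  # AdelaiDet checkpointer (example 1 only)

# Project-specific helpers referenced below but not shown here:
# create_dir, build_detection_val_loader, do_val, build_train_loader,
# get_loss, do_test, get_es_result, logger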
def do_train(cfg, model, resume=False, val_set='firevysor_val'):
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=1e-6)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1, last_epoch=-1)
    metric = 0
    print_every = 50

    tensorboard_dir = osp.join(cfg.OUTPUT_DIR, 'tensorboard')
    checkpoint_dir = osp.join(cfg.OUTPUT_DIR, 'checkpoints')
    create_dir(tensorboard_dir)
    create_dir(checkpoint_dir)

    checkpointer = AdetCheckpointer(model,
                                    checkpoint_dir,
                                    optimizer=optimizer,
                                    scheduler=scheduler)
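    # resume_or_load returns the checkpoint's extra metadata; training restarts
    # at the iteration after the one stored in the checkpoint.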
    start_iter = (checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 cfg.SOLVER.CHECKPOINT_PERIOD,
                                                 max_iter=max_iter)

    writers = ([
        CommonMetricPrinter(max_iter),
        # JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(tensorboard_dir),
    ] if comm.is_main_process() else [])
    data_loader = build_detection_train_loader(cfg)
    val_dataloader = build_detection_val_loader(cfg, val_set)

    logger.info("Starting training from iteration {}".format(start_iter))

    # [PHAT]: Create a log file
    log_file = open(cfg.MY_CUSTOM.LOG_FILE, 'w')

    best_loss = 1e6
    count_not_improve = 0
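    # NOTE: hard-coded number of training images; one "epoch" below is
    # train_size / IMS_PER_BATCH iterations.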
    train_size = 2177
    epoch_size = int(train_size / cfg.SOLVER.IMS_PER_BATCH)
    n_early_epoch = 10

    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            iteration = iteration + 1
            storage.step()

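            # Forward pass: in training mode the model returns a dict of losses.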
            loss_dict = model(data)
            losses = sum(loss for loss in loss_dict.values())

            assert torch.isfinite(losses).all(), loss_dict

            # Update loss dict
            loss_dict_reduced = {
                k: v.item()
                for k, v in comm.reduce_dict(loss_dict).items()
            }
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced,
                                    **loss_dict_reduced)

            # Early stopping
            if (iteration > start_iter) and ((iteration - start_iter) %
                                             epoch_size == 0):
                val_loss = do_val(cfg, model, val_dataloader)

                if val_loss >= best_loss:
                    count_not_improve += 1
                    # stop if the model hasn't improved for <n_early_epoch> consecutive epochs
                    if count_not_improve >= n_early_epoch:
                        break
                else:
                    count_not_improve = 0
                    best_loss = val_loss
                    periodic_checkpointer.save("best_model_early")

                # print(f"epoch {iteration//epoch_size}, val_loss: {val_loss}")
                log_file.write(
                    f"Epoch {(iteration-start_iter)//epoch_size}, val_loss: {val_loss}\n"
                )
                comm.synchronize()

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            lr = optimizer.param_groups[0]["lr"]
            storage.put_scalar("lr", lr, smoothing_hint=False)
            scheduler.step()

            if iteration - start_iter > 5 and (
                (iteration - start_iter) % print_every == 0
                    or iteration == max_iter):
                for writer in writers:
                    writer.write()

                # Write my log
                log_file.write(
                    f"[iter {iteration}, best_loss: {best_loss}] total_loss: {losses_reduced}, lr: {lr}\n"
                )

            periodic_checkpointer.step(iteration)

    log_file.close()
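The first example calls a project-specific do_val() that is not shown. A minimal sketch of what it presumably does (average the reduced loss over the validation loader without tracking gradients); the original project's implementation may differ:

import torch
from detectron2.utils import comm

def do_val(cfg, model, val_dataloader):
    # Keep the model in training mode so it returns a loss dict, but skip gradients.
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for data in val_dataloader:
            loss_dict = model(data)
            # Average each loss across workers so every process sees the same value.
            loss_dict_reduced = {
                k: v.item() for k, v in comm.reduce_dict(loss_dict).items()
            }
            total_loss += sum(loss_dict_reduced.values())
            num_batches += 1
    return total_loss / max(num_batches, 1)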

Example No. 2
def do_train(cfg, model, resume=False):
    model.train()
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    checkpointer = DetectionCheckpointer(model,
                                         cfg.OUTPUT_DIR,
                                         optimizer=optimizer,
                                         scheduler=scheduler)
    start_iter = (
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
        .get("iteration", -1) + 1
    )  # FIXME: does not continue from iteration # when resume=True
    max_iter = cfg.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 cfg.SOLVER.CHECKPOINT_PERIOD,
                                                 max_iter=max_iter)

    writers = ([
        CommonMetricPrinter(max_iter),
        JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
        TensorboardXWriter(cfg.OUTPUT_DIR),
    ] if comm.is_main_process() else [])

    # init best monitor metric
    best_monitor_metric = None

    # init early stopping count
    es_count = 0

    # get train data loader
    data_loader = build_train_loader(cfg)
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data, iteration in zip(data_loader, range(start_iter, max_iter)):
            storage.step()

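            # get_loss is a project helper, presumably returning the loss dict
            # (unused here), the summed loss tensor for backprop, and the total
            # loss reduced across workers.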
            _, losses, losses_reduced = get_loss(data, model)
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr",
                               optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)
            scheduler.step()

            if (cfg.TEST.EVAL_PERIOD > 0
                    and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0
                    and iteration != max_iter - 1):
                results = do_test(cfg, model)
                storage.put_scalars(**results['metrics'])

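                # EARLY_STOPPING (ENABLE / MONITOR / MODE / PATIENCE) is a custom
                # config node added by this project, not a stock detectron2 key.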
                if cfg.EARLY_STOPPING.ENABLE:
                    curr = None
                    if cfg.EARLY_STOPPING.MONITOR in results['metrics'].keys():
                        curr = results['metrics'][cfg.EARLY_STOPPING.MONITOR]

                    if curr is None:
                        logger.warning(
                            "Early stopping enabled but cannot find metric: %s"
                            % cfg.EARLY_STOPPING.MONITOR)
                        logger.warning(
                            "Options for monitored metrics are: [%s]" %
                            ", ".join(map(str, results['metrics'].keys())))
                    elif best_monitor_metric is None:
                        best_monitor_metric = curr
                    elif get_es_result(cfg.EARLY_STOPPING.MODE, curr,
                                       best_monitor_metric):
                        best_monitor_metric = curr
                        es_count = 0
                        logger.info("Best metric %s improved to %0.4f" %
                                    (cfg.EARLY_STOPPING.MONITOR, curr))
                        # update best model
                        periodic_checkpointer.save(name="model_best",
                                                   **results['metrics'])
                        # save best metrics to a .txt file
                        with open(
                                os.path.join(cfg.OUTPUT_DIR,
                                             'best_metrics.txt'), 'w') as f:
                            json.dump(results['metrics'], f)
                    else:
                        logger.info(
                            "Early stopping metric %s did not improve, current %.04f, best %.04f"
                            % (cfg.EARLY_STOPPING.MONITOR, curr,
                               best_monitor_metric))
                        es_count += 1

                storage.put_scalar('val_loss', results['metrics']['val_loss'])

                comm.synchronize()

            if iteration - start_iter > 5 and ((iteration + 1) % 20 == 0
                                               or iteration == max_iter - 1):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)

            if es_count >= cfg.EARLY_STOPPING.PATIENCE:
                logger.info(
                    "Early stopping triggered, metric %s has not improved for %s validation steps"
                    %
                    (cfg.EARLY_STOPPING.MONITOR, cfg.EARLY_STOPPING.PATIENCE))
                break
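
The MONITOR comparison above relies on a small helper, get_es_result(), that is not shown. A plausible sketch, assuming MODE is either "max" (higher is better, e.g. AP) or "min" (lower is better, e.g. val_loss):

def get_es_result(mode, current, best):
    # Return True when `current` improves on `best` under the chosen mode.
    if mode == "max":
        return current > best
    return current < best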