Example #1
0
def train_loop(hp, logger, writer):
    """Train ``Model`` epoch by epoch until ``hp.train.num_iter`` is passed.

    Builds the train/test dataloaders and the model (``Net_arch`` with an
    MSE loss), resumes from ``hp.load.resume_state_path`` when it is set,
    then for every epoch trains, saves a checkpoint every
    ``hp.log.chkpt_interval`` epochs, and evaluates on the test loader.

    Args:
        hp: hyper-parameter object read through attribute access.
        logger: logger used for progress and error messages.
        writer: summary writer forwarded to ``train_model``/``test_model``.

    Any exception raised inside the epoch loop is logged at ERROR level
    (with traceback) and swallowed, so the function then returns normally.
    """
    # make dataloader
    logger.info("Making train dataloader...")
    train_loader = create_dataloader(hp, DataloaderMode.train)
    logger.info("Making test dataloader...")
    test_loader = create_dataloader(hp, DataloaderMode.test)

    # init Model
    net_arch = Net_arch(hp)
    loss_f = torch.nn.MSELoss()
    model = Model(hp, net_arch, loss_f)

    if hp.load.resume_state_path is not None:
        model.load_training_state(logger)
    else:
        logger.info("Starting new training run.")

    try:
        # ``model.epoch`` is the loop target, so the epoch counter stored on
        # the model object is advanced in place.
        for model.epoch in itertools.count(model.epoch + 1):
            if model.epoch > hp.train.num_iter:
                break
            train_model(hp, model, train_loader, writer, logger)
            if model.epoch % hp.log.chkpt_interval == 0:
                model.save_network(logger)
                model.save_training_state(logger)
            test_model(hp, model, test_loader, writer)
        logger.info("End of Train")
    except Exception as e:
        # Report the failure at ERROR level with its traceback through the
        # logger, so every handler attached to it records it (previously it
        # was INFO-level and the traceback only went to stderr).
        logger.exception("Exiting due to exception: %s", e)
Example #2
0
def main():
    """Command-line entry point: parse options, prepare per-run output
    directories and logging, build the dataloaders and call ``train``.

    Command-line options:
        -c/--config           YAML config file (required).
        -p/--checkpoint_path  checkpoint ``.pt`` file to resume from.
        -n/--name             run name, used as the sub-directory for both
                              logs and checkpoints (required).

    Raises:
        Exception: if ``hp.data.train`` or ``hp.data.test`` is empty.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-c', '--config', type=str, required=True,
                     help="yaml file for config.")
    cli.add_argument('-p', '--checkpoint_path', type=str, default=None,
                     help="path of checkpoint pt file for resuming")
    cli.add_argument('-n', '--name', type=str, required=True,
                     help="Name of the model. Used for both logging and saving chkpt.")
    args = cli.parse_args()
    run_name = args.name

    hp = HParam(args.config)
    hp_str = yaml.dump(hp)
    args_str = yaml.dump(vars(args))

    # One sub-directory per run under the configured checkpoint/log roots.
    pt_dir = os.path.join(hp.log.chkpt_dir, run_name)
    log_dir = os.path.join(hp.log.log_dir, run_name)
    for directory in (pt_dir, log_dir):
        os.makedirs(directory, exist_ok=True)

    # Log both to a timestamped file inside log_dir and to the console.
    log_file = os.path.join(log_dir, '%s-%d.log' % (run_name, time.time()))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    )
    logger = logging.getLogger()

    for message in ('Config by yaml file', hp_str,
                    'Command Line Config', args_str):
        logger.info(message)

    if hp.data.train == '' or hp.data.test == '':
        logger.error("train or test data directory cannot be empty.")
        raise Exception("Please specify directories of data in %s" %
                        args.config)

    writer = Writer(hp, log_dir)
    train_loader = create_dataloader(hp, args, DataloaderMode.train)
    test_loader = create_dataloader(hp, args, DataloaderMode.test)

    train(args, pt_dir, train_loader, test_loader, writer, logger, hp, hp_str)
def train_loop(rank, hp, world_size=1):
    """Run the training loop for one process (one rank of a distributed run).

    Args:
        rank: index of this process. Rank 0 owns the logger and the writer;
            in a distributed CUDA run it is also used as the CUDA device.
        hp: hyper-parameters; re-wrapped in ``DotDict`` for attribute access.
        world_size: number of processes. The distributed path
            (``setup``/``cleanup``) is taken only when ``hp.model.device``
            is ``"cuda"`` (case-insensitive) and ``world_size != 0``.

    Raises:
        Exception: on rank 0, if ``hp.data.train_dir`` or
            ``hp.data.test_dir`` is empty.

    Exceptions raised inside the epoch loop are reported (through the
    logger on rank 0, on stderr elsewhere) and swallowed.
    """
    # reload hp
    hp = DotDict(hp)
    # Decide once whether the distributed CUDA path is used. The same flag
    # selects the device below and guards cleanup() at the end, so cleanup()
    # never runs without a matching setup().
    use_dist = hp.model.device.lower() == "cuda" and world_size != 0
    if use_dist:
        setup(hp, rank, world_size)
    if rank != 0:
        logger = None
        writer = None
    else:
        # set logger
        logger = make_logger(hp)
        # set writer (tensorboard / wandb)
        writer = Writer(hp, hp.log.log_dir)
        hp_str = yaml.dump(hp.to_dict())
        logger.info("Config:")
        logger.info(hp_str)
        if hp.data.train_dir == "" or hp.data.test_dir == "":
            logger.error("train or test data directory cannot be empty.")
            raise Exception("Please specify directories of data")
        logger.info("Set up train process")

    if use_dist:
        hp.model.device = rank
        torch.cuda.set_device(rank)
    else:
        hp.model.device = hp.model.device.lower()

    # make dataloader
    if logger is not None:
        logger.info("Making train dataloader...")
    train_loader = create_dataloader(hp, DataloaderMode.train, rank,
                                     world_size)
    if logger is not None:
        logger.info("Making test dataloader...")
    test_loader = create_dataloader(hp, DataloaderMode.test, rank, world_size)

    # init Model
    net_arch = Net_arch(hp)
    loss_f = torch.nn.MSELoss()
    model = Model(hp, net_arch, loss_f, rank, world_size)

    # load training state
    if hp.load.resume_state_path is not None:
        model.load_training_state(logger)
    else:
        if logger is not None:
            logger.info("Starting new training run.")

    try:
        if world_size == 0 or hp.data.divide_dataset_per_gpu:
            # A step equal to world_size == 0 would make itertools.count()
            # yield the same epoch forever (infinite loop); use 1 instead.
            epoch_step = 1
        else:
            epoch_step = world_size
        for model.epoch in itertools.count(model.epoch + 1, epoch_step):
            if model.epoch > hp.train.num_iter:
                break
            train_model(hp, model, train_loader, writer, logger)
            if model.epoch % hp.log.chkpt_interval == 0:
                model.save_network(logger)
                model.save_training_state(logger)
            test_model(hp, model, test_loader, writer)
        if logger is not None:
            logger.info("End of Train")
    except Exception as e:
        if logger is not None:
            logger.exception("Exiting due to exception: %s", e)
        else:
            traceback.print_exc()
    finally:
        # Tear down the process group exactly once, and only if setup()
        # created it.
        if use_dist:
            cleanup()
Example #4
0
def train_loop(rank, cfg):
    """Run the training loop for one process (one rank of a distributed run).

    Args:
        rank: index of this process. In a distributed CUDA run it becomes
            ``cfg.device``; it is also forwarded to the dataloaders and
            ``Model``.
        cfg: OmegaConf/Hydra config (``device``, ``dist.gpus``, ``data``,
            ``log``, ``load``, ``num_epoch`` ...).

    Raises:
        Exception: on the logging process, if ``cfg.data.train_dir`` or
            ``cfg.data.test_dir`` is empty.

    Exceptions raised inside the epoch loop are reported (through the
    logger on the logging process, on stderr elsewhere) and swallowed. When
    ``cfg.dist.gpus != 0``, ``cleanup()`` runs in every case.
    """
    logger = get_logger(cfg, os.path.basename(__file__))
    if cfg.device == "cuda" and cfg.dist.gpus != 0:
        cfg.device = rank
        # turn off background generator when distributed run is on
        cfg.data.use_background_generator = False
        setup(cfg, rank)
        torch.cuda.set_device(cfg.device)

    # ``writer`` is passed to train_model/test_model on every rank, so it must
    # be bound on non-logging processes too. Previously those ranks raised
    # UnboundLocalError at the first train_model call and never trained.
    writer = None
    # setup writer
    if is_logging_process():
        # set log/checkpoint dir
        os.makedirs(cfg.log.chkpt_dir, exist_ok=True)
        # set writer (tensorboard / wandb)
        writer = Writer(cfg, "tensorboard")
        cfg_str = OmegaConf.to_yaml(cfg)
        logger.info("Config:\n" + cfg_str)
        if cfg.data.train_dir == "" or cfg.data.test_dir == "":
            logger.error("train or test data directory cannot be empty.")
            raise Exception("Please specify directories of data")
        logger.info("Set up train process")
        logger.info("BackgroundGenerator is turned off when Distributed running is on")

        # download MNIST dataset before making dataloader
        # TODO: This is example code. You should change this part as you need
        _ = torchvision.datasets.MNIST(
            root=hydra.utils.to_absolute_path("dataset/meta"),
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True,
        )
        _ = torchvision.datasets.MNIST(
            root=hydra.utils.to_absolute_path("dataset/meta"),
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True,
        )
    # Sync dist processes (because of download MNIST Dataset)
    # NOTE(review): barrier()/cleanup() are guarded by cfg.dist.gpus alone,
    # while setup() above additionally requires cfg.device == "cuda" --
    # confirm a non-CUDA run always has cfg.dist.gpus == 0.
    if cfg.dist.gpus != 0:
        dist.barrier()

    # make dataloader
    if is_logging_process():
        logger.info("Making train dataloader...")
    train_loader = create_dataloader(cfg, DataloaderMode.train, rank)
    if is_logging_process():
        logger.info("Making test dataloader...")
    test_loader = create_dataloader(cfg, DataloaderMode.test, rank)

    # init Model
    net_arch = Net_arch(cfg)
    loss_f = torch.nn.CrossEntropyLoss()
    model = Model(cfg, net_arch, loss_f, rank)

    # load training state / network checkpoint
    if cfg.load.resume_state_path is not None:
        model.load_training_state()
    elif cfg.load.network_chkpt_path is not None:
        model.load_network()
    else:
        if is_logging_process():
            logger.info("Starting new training run.")

    try:
        if cfg.dist.gpus == 0 or cfg.data.divide_dataset_per_gpu:
            epoch_step = 1
        else:
            epoch_step = cfg.dist.gpus
        for model.epoch in itertools.count(model.epoch + 1, epoch_step):
            if model.epoch > cfg.num_epoch:
                break
            train_model(cfg, model, train_loader, writer)
            if model.epoch % cfg.log.chkpt_interval == 0:
                model.save_network()
                model.save_training_state()
            test_model(cfg, model, test_loader, writer)
        if is_logging_process():
            logger.info("End of Train")
    except Exception:
        if is_logging_process():
            logger.error(traceback.format_exc())
        else:
            traceback.print_exc()
    finally:
        if cfg.dist.gpus != 0:
            cleanup()
Example #5
0
def train_loop(rank, hp, world_size=0):
    """Run the training loop for one process (one rank of a distributed run).

    Args:
        rank: process index. Rank 0 creates the logger and writer, checks
            the data directories and downloads MNIST; other ranks run with
            neither a logger nor a writer.
        hp: hyper-parameter object (attribute access plus ``to_dict()``).
        world_size: number of distributed processes; ``0`` skips the
            barrier and the final ``cleanup()``.

    Raises:
        Exception: on rank 0, if ``hp.data.train_dir`` or
            ``hp.data.test_dir`` is empty.

    Exceptions raised inside the epoch loop are reported (through the
    logger on rank 0, on stderr elsewhere) and swallowed.
    """
    if hp.model.device == "cuda" and world_size != 0:
        # distributed CUDA run: one device per rank, background generator off
        hp.model.device = rank
        hp.data.use_background_generator = False
        setup(hp, rank, world_size)
        torch.cuda.set_device(hp.model.device)

    def log_info(msg):
        """Log ``msg`` at INFO level when this process owns a logger."""
        if logger is not None:
            logger.info(msg)

    logger = writer = None
    if rank == 0:
        logger = make_logger(hp)
        # tensorboard / wandb writer rooted at <log_dir>/tensorboard
        writer = Writer(hp, os.path.join(hp.log.log_dir, "tensorboard"))
        hp_str = yaml.dump(hp.to_dict())
        for line in ("Config:", hp_str):
            logger.info(line)
        if hp.data.train_dir == "" or hp.data.test_dir == "":
            logger.error("train or test data directory cannot be empty.")
            raise Exception("Please specify directories of data")
        for line in ("Set up train process",
                     "BackgroundGenerator is turned off when Distributed running is on"):
            logger.info(line)

        # Fetch both MNIST splits on rank 0 before any dataloader is built.
        # TODO: This is example code. You should change this part as you need
        for is_train in (True, False):
            torchvision.datasets.MNIST(
                root="dataset/meta",
                train=is_train,
                transform=torchvision.transforms.ToTensor(),
                download=True,
            )

    # Keep all ranks in step so none builds a dataloader before the download.
    if world_size != 0:
        dist.barrier()

    log_info("Making train dataloader...")
    train_dl = create_dataloader(hp, DataloaderMode.train, rank, world_size)
    log_info("Making test dataloader...")
    test_dl = create_dataloader(hp, DataloaderMode.test, rank, world_size)

    model = Model(hp, Net_arch(hp), torch.nn.CrossEntropyLoss(), rank,
                  world_size)

    # full training state first, else network weights only, else fresh start
    if hp.load.resume_state_path is not None:
        model.load_training_state(logger)
    elif hp.load.network_chkpt_path is not None:
        model.load_network(logger=logger)
    else:
        log_info("Starting new training run.")

    try:
        single_step = world_size == 0 or hp.data.divide_dataset_per_gpu
        stride = 1 if single_step else world_size
        for model.epoch in itertools.count(model.epoch + 1, stride):
            if model.epoch > hp.train.num_epoch:
                break
            train_model(hp, model, train_dl, writer, logger)
            if model.epoch % hp.log.chkpt_interval == 0:
                model.save_network(logger)
                model.save_training_state(logger)
            test_model(hp, model, test_dl, writer, logger)
        log_info("End of Train")
    except Exception:
        if logger is None:
            traceback.print_exc()
        else:
            logger.error(traceback.format_exc())
    finally:
        if world_size != 0:
            cleanup()
    # NOTE(review): the statements below appear to be the tail of a different
    # example whose ``def`` line is missing (likely lost when these snippets
    # were concatenated). Because of their indentation they parse as the end
    # of the ``train_loop`` body above and run after its try/finally.
    # They use names that ``train_loop`` never binds (``log_dir``, ``args``,
    # ``out_dir``, ``chkpt_path``), and ``hp_str`` is bound there only on
    # rank 0. Reaching this code would therefore fail with NameError unless
    # something earlier raises first. Confirm the intended enclosing
    # function before relying on it.
    logger = logging.getLogger()  # no name given -> the root logger

    # Both data locations must be configured.
    if hp.data.train == '' or hp.data.val == '':
        logger.error("hp.data.train, hp.data.val cannot be empty")
        raise Exception("Please specify directories of train data.")

    # All three graph files must be configured.
    if hp.model.graph0 == '' or hp.model.graph1 == '' or hp.model.graph2 == '':
        logger.error("hp.model.graph0, graph1, graph2 cannot be empty")
        raise Exception("Please specify random DAG architecture.")

    # Load the three configured graphs, in order, with read_graph (defined
    # elsewhere).
    graphs = [
        read_graph(hp.model.graph0),
        read_graph(hp.model.graph1),
        read_graph(hp.model.graph2),
    ]

    writer = MyWriter(log_dir)

    # Assumption from the variable names: the boolean selects the train
    # (True) vs validation (False) split -- confirm in create_dataloader.
    trainset = create_dataloader(hp, args, True)
    valset = create_dataloader(hp, args, False)
    # in_channels=3 is passed explicitly (presumably RGB input -- confirm).
    train(out_dir,
          chkpt_path,
          trainset,
          valset,
          writer,
          logger,
          hp,
          hp_str,
          graphs,
          in_channels=3)