Ejemplo n.º 1
0
def main(args):
    """Extract features from a previously trained model.

    Loads ``model.pt`` from ``RES_DIR/<args.name>`` and runs a
    ``FeatureExtractor`` over the training data loader for a single epoch with
    a batch size of 1. ``feature_dir`` (``<exp_dir>/training_features``) is
    created first and passed to the extractor.

    Parameters
    ----------
    args: argparse.Namespace
        Arguments. ``args.epochs``, ``args.batch_size`` and ``args.img_size``
        are overwritten in place by this function.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(args.log_level.upper())
    # getLogger(__name__) returns the same process-wide object on every call,
    # so only attach the console handler once; otherwise calling main() again
    # in the same process would print every log line multiple times.
    if not logger.handlers:
        formatter = logging.Formatter('%(asctime)s %(levelname)s - %(funcName)s: %(message)s',
                                      "%H:%M:%S")
        stream = logging.StreamHandler()
        stream.setLevel(args.log_level.upper())
        stream.setFormatter(formatter)
        logger.addHandler(stream)

    set_seed(args.seed)
    device = get_device(is_gpu=not args.no_cuda)
    exp_dir = os.path.join(RES_DIR, args.name)
    feature_dir = os.path.join(exp_dir, 'training_features')
    logger.info("Root directory for saving and loading experiments: {}".format(exp_dir))

    if not args.is_eval_only:
        create_safe_directory(feature_dir, logger=logger)

        # Feature extraction is a single pass over the data, one sample at a
        # time, so both the epoch count and the batch size are forced to 1.
        args.epochs = 1
        args.batch_size = 1

        # PREPARES DATA
        data_loader = get_dataloaders(args.dataset,
                                      batch_size=args.batch_size,
                                      logger=logger, test=False)
        logger.info("Train {} with {} samples".format(args.dataset, len(data_loader.dataset)))

        # PREPARES MODEL
        args.img_size = get_img_size(args.dataset)  # stores for metadata
        model = load_model(exp_dir, filename='model.pt')
        logger.info('Num parameters in model: {}'.format(get_n_param(model)))

        # EXTRACTS FEATURES
        model = model.to(device)  # make sure extractor and model are on the same device
        fe = FeatureExtractor(model,
                              save_dir=exp_dir,
                              is_progress_bar=not args.no_progress_bar)
        fe(data_loader,
           epochs=args.epochs,
           checkpoint_every=args.checkpoint_every,
           feature_dir=feature_dir)

        print('Done.')
Ejemplo n.º 2
0
def main(args):
    """Train a VLAE model and save it with the run's arguments as metadata.

    When ``args.is_eval_only`` is set, only seeding, device selection and the
    experiment-path message happen; nothing is trained or saved.

    Parameters
    ----------
    args: argparse.Namespace
        Arguments. ``args.img_size`` is written here so that it is included
        in the saved metadata.
    """
    set_seed(args.seed)
    device_name = 'cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    exp_dir = os.path.join(RES_DIR, args.name)
    print("save and load experiments at : {}".format(exp_dir))

    if args.is_eval_only:
        return

    create_directory(exp_dir)

    # Training data
    train_loader = get_dataloaders(args.dataset, batch_size=args.batch_size)

    # Model; img_size is recorded on args so it is saved as metadata.
    args.img_size = get_img_size(args.dataset)
    size_schedule = [1, 64, 128, 1024]  # NOTE(review): presumably per-layer channel sizes for VLAE — confirm
    model = VLAE(args, args.latent_dim, size_schedule)

    # Optimiser, device placement and training-time visualisation
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    model = model.to(device)  # trainer and visualiser must share one device
    gif_visualizer = GifTraversalsTraining(model, args.dataset, exp_dir)

    # The three regularisation coefficients, in order: reg_coeff0..reg_coeff2.
    reg_coeff = [getattr(args, 'reg_coeff{}'.format(idx)) for idx in range(3)]

    trainer = Trainer(model, optimizer, reg_coeff,
                      device=device,
                      save_dir=exp_dir,
                      is_progress_bar=not args.no_progress_bar,
                      gif_visualizer=gif_visualizer)
    trainer(args, train_loader,
            epochs=args.epochs,
            checkpoint_every=args.checkpoint_every)

    # Persist the trained weights together with all arguments.
    save_model(trainer.model, exp_dir, metadata=vars(args))
    print("Model has been saved")
Ejemplo n.º 3
0
def main(args):
    """Main train and evaluation function.

    Unless ``args.is_eval_only`` is set, builds a model with
    ``init_specific_model``, trains it and saves it under
    ``RES_DIR/<args.name>``. Then, if metrics or test losses are requested
    (``args.is_metrics`` or not ``args.no_test``), the saved model is
    reloaded and passed to an ``Evaluator``.

    Parameters
    ----------
    args: argparse.Namespace
        Arguments. For the ``"factor"`` loss, ``args.batch_size`` and
        ``args.epochs`` are doubled in place; ``args.img_size`` is also set
        here. All of them are saved as model metadata.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(args.log_level.upper())
    # getLogger(__name__) returns the same process-wide object on every call,
    # so only attach the console handler once; otherwise calling main() again
    # in the same process would print every log line multiple times.
    if not logger.handlers:
        formatter = logging.Formatter(
            '%(asctime)s %(levelname)s - %(funcName)s: %(message)s', "%H:%M:%S")
        stream = logging.StreamHandler()
        stream.setLevel(args.log_level.upper())
        stream.setFormatter(formatter)
        logger.addHandler(stream)

    set_seed(args.seed)
    device = get_device(is_gpu=not args.no_cuda)
    exp_dir = os.path.join(RES_DIR, args.name)
    logger.info("Root directory for saving and loading experiments: {}".format(
        exp_dir))

    if not args.is_eval_only:

        create_safe_directory(exp_dir, logger=logger)

        if args.loss == "factor":
            logger.info(
                "FactorVae needs 2 batches per iteration. To replicate this behavior while being consistent, we double the batch size and the number of epochs."
            )
            args.batch_size *= 2
            args.epochs *= 2

        # PREPARES DATA
        train_loader = get_dataloaders(args.dataset,
                                       batch_size=args.batch_size,
                                       logger=logger)
        logger.info("Train {} with {} samples".format(
            args.dataset, len(train_loader.dataset)))

        # PREPARES MODEL
        args.img_size = get_img_size(args.dataset)  # stores for metadata
        model = init_specific_model(args.model_type, args.img_size,
                                    args.latent_dim)
        logger.info('Num parameters in model: {}'.format(get_n_param(model)))

        # TRAINS
        optimizer = optim.Adam(model.parameters(), lr=args.lr)

        model = model.to(device)  # make sure trainer and viz on same device
        gif_visualizer = GifTraversalsTraining(model, args.dataset, exp_dir)
        loss_f = get_loss_f(args.loss,
                            n_data=len(train_loader.dataset),
                            device=device,
                            **vars(args))
        trainer = Trainer(model,
                          optimizer,
                          loss_f,
                          device=device,
                          logger=logger,
                          save_dir=exp_dir,
                          is_progress_bar=not args.no_progress_bar,
                          gif_visualizer=gif_visualizer)
        trainer(
            train_loader,
            epochs=args.epochs,
            checkpoint_every=args.checkpoint_every,
        )

        # SAVE MODEL AND EXPERIMENT INFORMATION
        save_model(trainer.model, exp_dir, metadata=vars(args))

    if args.is_metrics or not args.no_test:
        model = load_model(exp_dir, is_gpu=not args.no_cuda)
        metadata = load_metadata(exp_dir)
        # TODO: evaluation currently uses the train dataset.

        test_loader = get_dataloaders(metadata["dataset"],
                                      batch_size=args.eval_batchsize,
                                      shuffle=False,
                                      logger=logger)
        # NOTE(review): the loss is taken from args, not from the saved
        # metadata; confirm both agree when running with is_eval_only.
        loss_f = get_loss_f(args.loss,
                            n_data=len(test_loader.dataset),
                            device=device,
                            **vars(args))

        # Weights & Biases logging is disabled; set to True to enable it.
        use_wandb = False
        if use_wandb:
            loss = args.loss
            wandb.init(project="atmlbetavae", config={"VAE_loss": args.loss})
            if loss == "betaH":
                beta = loss_f.beta
                wandb.config["Beta"] = beta
        evaluator = Evaluator(model,
                              loss_f,
                              device=device,
                              logger=logger,
                              save_dir=exp_dir,
                              is_progress_bar=not args.no_progress_bar,
                              use_wandb=use_wandb)

        evaluator(test_loader,
                  is_metrics=args.is_metrics,
                  is_losses=not args.no_test)
Ejemplo n.º 4
0
def main(args: argparse.Namespace):
    """Main train and evaluation function.

    Unless ``args.is_eval_only`` is set, trains a ``VAE`` and saves it under
    ``RES_DIR/<args.name>``. The saved model is then always reloaded and
    evaluated twice: once on the test split (``train=False``) and once on the
    train split (``train=True``).

    Parameters
    ----------
    args: argparse.Namespace
        Arguments. For the ``"factor"`` loss, ``args.batch_size`` and
        ``args.epochs`` are doubled in place; ``args.img_size`` is also set
        here. All of them are saved as model metadata.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel("INFO")
    # getLogger(__name__) returns the same process-wide object on every call,
    # so only attach the console handler once; otherwise calling main() again
    # in the same process would print every log line multiple times.
    if not logger.handlers:
        formatter = logging.Formatter(
            '%(asctime)s %(levelname)s - %(funcName)s: %(message)s', "%H:%M:%S")
        stream = logging.StreamHandler()
        stream.setLevel("INFO")
        stream.setFormatter(formatter)
        logger.addHandler(stream)

    set_seed(args.seed)
    device = get_device(is_gpu=not args.no_cuda)
    exp_dir = os.path.join(RES_DIR, args.name)
    logger.info(
        f"Root directory for saving and loading experiments: {exp_dir}")

    if not args.is_eval_only:

        create_safe_directory(exp_dir, logger=logger)

        if args.loss == "factor":
            logger.info(
                "FactorVae needs 2 batches per iteration. " +
                "To replicate this behavior, double batch size and epochs.")
            args.batch_size *= 2
            args.epochs *= 2

        # PREPARES DATA
        train_loader = get_dataloaders(args.dataset,
                                       noise=args.noise,
                                       batch_size=args.batch_size,
                                       logger=logger)
        logger.info(
            f"Train {args.dataset} with {len(train_loader.dataset)} samples")

        # PREPARES MODEL
        args.img_size = get_img_size(args.dataset)  # stores for metadata
        model = VAE(args.img_size, args.latent_dim)
        logger.info(f'Num parameters in model: {get_n_param(model)}')

        # TRAINS
        optimizer = optim.Adam(model.parameters(), lr=args.lr)

        model = model.to(device)
        gif_visualizer = GifTraversalsTraining(model, args.dataset, exp_dir)
        loss_f = get_loss_f(args.loss,
                            n_data=len(train_loader.dataset),
                            device=device,
                            **vars(args))

        # The tdGJS / tGJS losses have their own trainable parameters, which
        # get a separate Adam optimizer; other losses have none.
        if args.loss in ['tdGJS', 'tGJS']:
            loss_optimizer = optim.Adam(loss_f.parameters(), lr=args.lr)
        else:
            loss_optimizer = None
        logger.debug(f"Loss optimizer: {loss_optimizer}")
        trainer = Trainer(model,
                          optimizer,
                          loss_f,
                          device=device,
                          logger=logger,
                          save_dir=exp_dir,
                          is_progress_bar=not args.no_progress_bar,
                          gif_visualizer=gif_visualizer,
                          loss_optimizer=loss_optimizer,
                          denoise=args.noise is not None)
        trainer(
            train_loader,
            epochs=args.epochs,
            checkpoint_every=args.checkpoint_every,
        )

        # SAVE MODEL AND EXPERIMENT INFORMATION
        save_model(trainer.model, exp_dir, metadata=vars(args))

    # EVALUATION ON THE TEST SPLIT
    model = load_model(exp_dir, is_gpu=not args.no_cuda)
    metadata = load_metadata(exp_dir)

    test_loader = get_dataloaders(metadata["dataset"],
                                  noise=args.noise,
                                  train=False,
                                  batch_size=128,
                                  logger=logger)
    loss_f = get_loss_f(args.loss,
                        n_data=len(test_loader.dataset),
                        device=device,
                        **vars(args))
    evaluator = Evaluator(model,
                          loss_f,
                          device=device,
                          is_metrics=args.is_metrics,
                          is_train=False,
                          logger=logger,
                          save_dir=exp_dir,
                          is_progress_bar=not args.no_progress_bar,
                          denoise=args.noise is not None)
    evaluator(test_loader)

    # EVALUATION ON THE TRAIN SPLIT
    # NOTE(review): unlike the test-split pass above, this pass passes neither
    # noise= to the loader nor denoise= to the Evaluator; confirm that
    # evaluating on un-noised data is intended.
    train_eval_loader = get_dataloaders(metadata["dataset"],
                                        train=True,
                                        batch_size=128,
                                        logger=logger)
    loss_f = get_loss_f(args.loss,
                        n_data=len(train_eval_loader.dataset),
                        device=device,
                        **vars(args))
    evaluator = Evaluator(model,
                          loss_f,
                          device=device,
                          is_metrics=args.is_metrics,
                          is_train=True,
                          logger=logger,
                          save_dir=exp_dir,
                          is_progress_bar=not args.no_progress_bar)
    evaluator(train_eval_loader)
Ejemplo n.º 5
0
            create_safe_directory(exp_dir, logger=logger)

            if args.loss == "factor":
                logger.info("FactorVae needs 2 batches per iteration. To replicate this behavior while being consistent, we double the batch size and the the number of epochs.")
                args.batch_size *= 2
                args.epochs *= 2                                                                                                        

            # PREPARES DATA
            train_loader = get_dataloaders(args.dataset,
                                        batch_size=args.batch_size,
                                        logger=logger)
            logger.info("Train {} with {} samples".format(args.dataset, len(train_loader.dataset)))

            # PREPARES MODEL
            args.img_size = get_img_size(args.dataset)  # stores for metadata
            model = init_specific_model(args.model_type, args.img_size, args.latent_dim)
           
            logger.info('Num parameters in model: {}'.format(get_n_param(model)))
           
            
            # TRAINS
            optimizer = optim.Adam(model.parameters(), lr=args.lr)

            model = model.to(device)  # make sure trainer and viz on same device
            gif_visualizer = GifTraversalsTraining(model, args.dataset, exp_dir)
            loss_f = get_loss_f(args.loss,
                                n_data=len(train_loader.dataset),
                                device=device,
                                **vars(args))
            wandb.watch(model, optimizer, log="parameters", log_freq=1000)