def __init__(self, config):
        """Prepare output folders, data loading and models from ``config``.

        The configuration object is stored on the instance and then used
        to create the output folders, snapshot the configuration and
        scripts into the save folder, run the CUDA check, build the
        dataloader with its iterator, and build the models through
        ``self.build_models()``.

        Args:
            config: Configuration object. This method reads its path
                attributes, its dataloader settings and ``adv_loss``.
        """
        self.config = config

        # 0 unless a pre-trained model is used.
        self.start = 0

        # Create the output folders if they do not exist yet; attributes
        # are looked up one at a time, in this order.
        for folder_attr in ('save_path', 'model_weights_path',
                            'sample_images_path'):
            utils.make_folder(getattr(self.config, folder_attr))

        # Keep a copy of the configuration and scripts with the outputs.
        utils.write_config_to_file(self.config, self.config.save_path)
        utils.copy_scripts(self.config.save_path)

        # CUDA check; it receives the instance itself.
        utils.check_for_CUDA(self)

        # Names of the config attributes handed to utils.make_dataloader,
        # listed in the positional order the call expects.
        loader_fields = (
            'batch_size_in_gpu', 'dataset', 'data_path', 'shuffle',
            'drop_last', 'dataloader_args', 'resize', 'imsize',
            'centercrop', 'centercrop_size',
        )
        self.dataloader, self.num_of_classes = utils.make_dataloader(
            *(getattr(self.config, field) for field in loader_fields))

        # Iterator over the dataloader.
        self.data_iter = iter(self.dataloader)

        # Build G and D.
        self.build_models()

        # A BCE criterion is only created for the 'dcgan' adversarial
        # loss; no criterion attribute is set here for other values.
        if self.config.adv_loss == 'dcgan':
            self.criterion = nn.BCELoss()
示例#2
0
    def __init__(self, config):
        """Copy settings from ``config`` and prepare folders, data, models.

        Every setting used by the instance is copied from ``config`` onto
        an attribute of the same name. The save path is
        ``config.save_path`` joined with ``config.name``; the weights and
        sample folders live below it. After that the output folders are
        created, the configuration and scripts are copied into the save
        path, the CUDA check runs, the dataloader and its iterator are
        built, the models are built, and a pretrained model is loaded
        when ``config.pretrained_model`` is not the empty string.

        Args:
            config: Configuration object providing the attributes that
                are read below.
        """

        def mirror(*names):
            # Copy each named attribute from config onto self, in order.
            for attr in names:
                setattr(self, attr, getattr(config, attr))

        # Images data path & output path
        mirror('dataset', 'data_path')
        self.save_path = os.path.join(config.save_path, config.name)

        # Training settings
        mirror('batch_size', 'total_step', 'd_steps_per_iter',
               'g_steps_per_iter', 'd_lr', 'g_lr', 'beta1', 'beta2',
               'inst_noise_sigma', 'inst_noise_sigma_iters')
        self.start = 0  # Unless using pre-trained model

        # Image transforms
        mirror('shuffle', 'drop_last', 'resize', 'imsize', 'centercrop',
               'centercrop_size', 'tanh_scale', 'normalize')

        # Step sizes
        mirror('log_step', 'sample_step', 'model_save_step',
               'save_n_images', 'max_frames_per_gif')

        # Pretrained model
        mirror('pretrained_model')

        # Misc
        mirror('manual_seed', 'disable_cuda', 'parallel', 'dataloader_args')

        # Output paths, both located below the save path
        self.model_weights_path = os.path.join(
            self.save_path, config.model_weights_dir)
        self.sample_path = os.path.join(self.save_path, config.sample_dir)

        # Model hyper-parameters and model name
        mirror('adv_loss', 'z_dim', 'g_conv_dim', 'd_conv_dim',
               'lambda_gp', 'name')

        # Create directories if they do not exist
        for folder in (self.save_path, self.model_weights_path,
                       self.sample_path):
            utils.make_folder(folder)

        # Copy the configuration and scripts into the save path
        utils.write_config_to_file(config, self.save_path)
        utils.copy_scripts(self.save_path)

        # CUDA check; it receives the instance itself
        utils.check_for_CUDA(self)

        # Make dataloader
        self.dataloader, self.num_of_classes = utils.make_dataloader(
            self.batch_size,
            self.dataset,
            self.data_path,
            self.shuffle,
            self.drop_last,
            self.dataloader_args,
            self.resize,
            self.imsize,
            self.centercrop,
            self.centercrop_size,
        )

        # Data iterator
        self.data_iter = iter(self.dataloader)

        # Build G and D
        self.build_models()

        # Start from a pretrained model when one is configured
        if self.pretrained_model != '':
            utils.load_pretrained_model(self)

        # A BCE criterion is only created for the 'dcgan' adversarial
        # loss; no criterion attribute is set here for other values.
        if self.adv_loss == 'dcgan':
            self.criterion = nn.BCELoss()
示例#3
0
# Run-wide objects created at import time and used by the code below.
config = Config()
device = torch.device("cuda")

# TensorBoard writer under <config.path>/tb; the configuration is added
# as text (third argument 0 is presumably the global step - confirm).
writer = SummaryWriter(log_dir=os.path.join(config.path, "tb"))
writer.add_text("config", config.as_markdown(), 0)

# Logger bound to <config.path>/<config.name>.log; the command line and
# the parameters are the first things written to it.
logger = utils.get_logger(os.path.join(config.path, f"{config.name}.log"))
logger.info(f"Run options = {sys.argv}")
config.print_params(logger.info)

# Copy the "*.py" scripts into the run folder.
utils.copy_scripts("*.py", config.path)


def main():
    """Log the training start and set up the GPU and random seeds.

    Relies on the module-level ``logger`` and ``config``: selects the GPU
    given by ``config.gpu``, seeds NumPy and PyTorch with ``config.seed``
    and turns on cuDNN benchmark mode.
    """
    logger.info("Logger is set - training start")

    # Select the configured GPU.
    gpu_id = config.gpu
    logger.info(f"Set GPU device {gpu_id}")
    torch.cuda.set_device(gpu_id)

    # Apply the same seed to NumPy, PyTorch and all CUDA devices.
    seed = config.seed
    for apply_seed in (np.random.seed, torch.manual_seed,
                       torch.cuda.manual_seed_all):
        apply_seed(seed)

    # Enable cuDNN benchmark mode.
    torch.backends.cudnn.benchmark = True
示例#4
0
def main():
    """Train and validate a FractalNet model, checkpointing every epoch.

    Creates the configuration, the CUDA device and the TensorBoard
    writer, then hands over to ``_run_training``. The writer is closed
    when the run ends, including when it ends with an exception, so that
    buffered TensorBoard events are not lost.
    """
    config = Config()
    device = torch.device("cuda")

    # tensorboard
    writer = SummaryWriter(log_dir=os.path.join(config.path, "tb"))
    try:
        _run_training(config, device, writer)
    finally:
        # Flush and release the event file even if training failed.
        writer.close()


def _run_training(config, device, writer):
    """Run logging setup, model construction and the train/validate loop.

    Args:
        config: Run configuration created by ``main``.
        device: Device the model and criterion are moved to.
        writer: Open TensorBoard writer; it is not closed here.
    """
    writer.add_text('config', config.as_markdown(), 0)

    # logger
    logger = utils.get_logger(
        os.path.join(config.path, "{}.log".format(config.name)))
    logger.info("Run options = {}".format(sys.argv))
    config.print_params(logger.info)

    # copy scripts
    utils.copy_scripts("*.py", config.path)

    logger.info("Logger is set - training start")

    # set gpu device id
    logger.info("Set GPU device {}".format(config.gpu))
    torch.cuda.set_device(config.gpu)

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    # get dataset
    train_data, valid_data, data_shape = get_dataset(config.data,
                                                     config.data_path,
                                                     config.aug_lv)

    # build model
    criterion = nn.CrossEntropyLoss().to(device)
    model = FractalNet(data_shape,
                       config.columns,
                       config.init_channels,
                       p_ldrop=config.p_ldrop,
                       dropout_probs=config.dropout_probs,
                       gdrop_ratio=config.gdrop_ratio,
                       gap=config.gap,
                       init=config.init,
                       pad_type=config.pad,
                       doubling=config.doubling,
                       dropout_pos=config.dropout_pos,
                       consist_gdrop=config.consist_gdrop)
    model = model.to(device)

    # model size
    m_params = utils.param_size(model)
    logger.info("Models:\n{}".format(model))
    logger.info("Model size (# of params) = {:.3f} M".format(m_params))

    # weights optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                config.lr,
                                momentum=config.momentum)

    # setup data loaders; only the training loader is shuffled
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=config.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_data,
                                               batch_size=config.batch_size,
                                               shuffle=False,
                                               num_workers=config.workers,
                                               pin_memory=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.lr_milestone)

    best_top1 = 0.
    # training loop
    for epoch in range(config.epochs):
        # training
        train(train_loader, model, optimizer, criterion, epoch, config, writer,
              logger, device)
        lr_scheduler.step()

        # validation; cur_step is the number of training batches seen so far
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader,
                        model,
                        criterion,
                        epoch,
                        cur_step,
                        config,
                        writer,
                        logger,
                        device,
                        deepest=False)

        # save: a checkpoint is written every epoch, flagged as best when
        # the validation score improved. bool() keeps is_best a plain
        # True/False whatever comparable type validate() returns.
        is_best = bool(best_top1 < top1)
        if is_best:
            best_top1 = top1
        utils.save_checkpoint(model.state_dict(), config.path, is_best)

        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
示例#5
0
    # Seed
    torch.manual_seed(args.seed)

    # IMAGES DATALOADER
    train_loader, valid_loader = utils.make_dataloader(args)

    print(args)

    # OUT PATH
    if not os.path.exists(args.out_path):
        print("Making", args.out_path)
        os.makedirs(args.out_path)

    # Copy all scripts
    utils.copy_scripts(args.out_path)

    # Save all args
    utils.write_config_to_file(args, args.out_path)

    # MODEL

    torch.manual_seed(args.seed)
    if args.pth != '':
        pth_dir_name = os.path.dirname(args.pth)
        full_model_pth = os.path.join(pth_dir_name, 'model.pth')
        if os.path.exists(full_model_pth):
            print("Loading", full_model_pth)
            model = torch.load(full_model_pth)
            print("Loading pretrained state_dict", args.pth)
            model.load_state_dict(torch.load(args.pth))