Beispiel #1
0
def main():
    """Validate the last checkpoint of a run against several WebP noises.

    Command-line arguments: ``--hostname`` (defaults to the local host name),
    ``--data-dir`` (required), ``--runs_root`` and ``--batch-size``.

    NOTE: the value given for ``--data-dir`` is always replaced by a path
    chosen from the host name, so the command-line value is never used.

    The options file and the last checkpoint are read from ``runs_root``.
    For each noise, every batch yielded by the validation loader is passed
    to ``model.validate_on_batch_specific_noise`` under ``torch.no_grad()``
    and the accumulated losses are written to ``validation_run.csv`` in
    ``runs_root`` (the header flag is set only for the first noise).
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    parser = argparse.ArgumentParser(description='Training of HiDDeN nets')
    parser.add_argument('--hostname',
                        default=socket.gethostname(),
                        help='the  host name of the running server')
    parser.add_argument('--data-dir',
                        '-d',
                        required=True,
                        type=str,
                        help='The directory where the data is stored.')
    parser.add_argument(
        '--runs_root',
        '-r',
        default=os.path.join('.', 'experiments'),
        type=str,
        help='The root folder where data about experiments are stored.')
    parser.add_argument('--batch-size',
                        '-b',
                        default=1,
                        type=int,
                        help='Validation batch size.')

    args = parser.parse_args()

    # host name -> (data directory, host name to keep in args afterwards)
    known_hosts = {
        'ee898-System-Product-Name':
        ('/home/ee898/Desktop/chaoning/ImageNet', 'ee898'),
        'DL178': ('/media/user/SSD1TB-2/ImageNet', 'DL178'),
    }
    fallback = ('/workspace/data_local/imagenet_pytorch', args.hostname)
    args.data_dir, args.hostname = known_hosts.get(args.hostname, fallback)
    assert args.data_dir

    report_interval = 25

    # Sub-folders of the runs root, listed for information only.
    completed_runs = []
    for entry in os.listdir(args.runs_root):
        if not os.path.isdir(os.path.join(args.runs_root, entry)):
            continue
        if entry == 'no-noise-defaults':
            continue
        completed_runs.append(entry)

    print(completed_runs)

    current_run = args.runs_root
    print(f'Run folder: {current_run}')
    options_file = os.path.join(current_run, 'options-and-config.pickle')
    loaded_options = utils.load_options(options_file)
    train_options, hidden_config, noise_config = loaded_options

    # Both the train and validation folders point at the 'val' split.
    val_folder = os.path.join(args.data_dir, 'val')
    train_options.train_folder = val_folder
    train_options.validation_folder = os.path.join(args.data_dir, 'val')
    train_options.batch_size = args.batch_size

    checkpoint_dir = os.path.join(current_run, 'checkpoints')
    checkpoint, chpt_file_name = utils.load_last_checkpoint(checkpoint_dir)
    print(f'Loaded checkpoint from file {chpt_file_name}')

    noiser = Noiser(noise_config, device, 'jpeg')
    model = Hidden(hidden_config, device, noiser, tb_logger=None)
    utils.model_from_checkpoint(model, checkpoint)

    print('Model loaded successfully. Starting validation run...')
    _, val_data = utils.get_data_loaders(hidden_config, train_options)
    file_count = len(val_data.dataset)
    # Ceiling division: a partial final batch counts as one more step.
    full_batches, leftover = divmod(file_count, train_options.batch_size)
    steps_in_epoch = full_batches if leftover == 0 else full_batches + 1

    with torch.no_grad():
        noises = ['webp_10', 'webp_25', 'webp_50', 'webp_75', 'webp_90']
        for noise_index, noise in enumerate(noises):
            losses_accu = {}
            for step, (image, _) in enumerate(val_data, start=1):
                image = image.to(device)
                message_shape = (image.shape[0], hidden_config.message_length)
                random_bits = np.random.choice([0, 1], message_shape)
                message = torch.Tensor(random_bits).to(device)
                losses, _ = model.validate_on_batch_specific_noise(
                    [image, message], noise=noise)
                if not losses_accu:
                    # First batch with losses: one meter per reported loss.
                    losses_accu = {name: AverageMeter() for name in losses}
                for name, loss in losses.items():
                    losses_accu[name].update(loss)
                if step % report_interval == 0 or step == steps_in_epoch:
                    print(f'Step {step}/{steps_in_epoch}')
                    utils.print_progress(losses_accu)
                    print('-' * 40)

            # The header flag is set only for the first noise.
            write_validation_loss(os.path.join(args.runs_root,
                                               'validation_run.csv'),
                                  losses_accu,
                                  noise,
                                  checkpoint['epoch'],
                                  write_header=(noise_index == 0))
def train_own_noise(model: Hidden, device: torch.device,
                    hidden_config: HiDDenConfiguration,
                    train_options: TrainingOptions, this_run_folder: str,
                    tb_logger, noise):
    """
    Trains the HiDDeN model and validates it against one specific noise after
    every epoch.

    Each epoch trains on at most 313 batches and validates on at most
    313 // 10 = 31 batches, regardless of how large the datasets are.
    After each epoch the losses are written to ``train.csv`` and
    ``validation_<noise>.csv`` in the run folder, sample images from the
    first validation batch are saved and a checkpoint is stored.

    :param model: The model
    :param device: torch.device object, usually this is GPU (if available), otherwise CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :param noise: The noise passed to ``model.validate_on_batch_specific_noise``.
                It is also concatenated into the validation CSV file name, so it must be a string.
    :return: None
    """
    # Hard caps on the number of batches processed per epoch.
    max_train_steps = 313
    max_val_steps = max_train_steps // 10

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)
    file_count = len(train_data.dataset)
    batch_size = train_options.batch_size
    # Ceiling division; assumes the loader keeps the final partial batch
    # -- TODO confirm against utils.get_data_loaders.
    available_steps = (file_count + batch_size - 1) // batch_size
    # Only used for logging and for detecting the last step of an epoch; the
    # loop itself stops on the hard cap or when the loader is exhausted.
    steps_in_epoch = min(available_steps, max_train_steps)

    print_each = 10
    images_to_save = 8
    saved_images_size = (
        512, 512)  # for qualitative check purpose to use a larger size

    for epoch in range(train_options.start_epoch,
                       train_options.number_of_epochs + 1):
        logging.info('\nStarting epoch %s/%s', epoch,
                     train_options.number_of_epochs)
        logging.info('Batch size = %s\nSteps in epoch = %s',
                     train_options.batch_size, steps_in_epoch)
        training_losses = defaultdict(AverageMeter)

        if train_options.video_dataset:
            random.shuffle(train_data.dataset)

        epoch_start = time.time()
        for step, (image, _) in enumerate(train_data, start=1):
            image = image.to(device)
            message = _random_messages(image.shape[0],
                                       hidden_config.message_length, device)
            losses, _ = model.train_on_batch([image, message])

            for name, loss in losses.items():
                training_losses[name].update(loss)
            if step % print_each == 0 or step == steps_in_epoch:
                logging.info('Epoch: %s/%s Step: %s/%s', epoch,
                             train_options.number_of_epochs, step,
                             steps_in_epoch)
                utils.log_progress(training_losses)
                logging.info('-' * 40)
            # Check the cap after the batch has been processed so that
            # exactly max_train_steps batches are trained and the final
            # step is logged.
            if step >= max_train_steps:
                break

        train_duration = time.time() - epoch_start
        logging.info('Epoch %s training duration %.2f sec', epoch,
                     train_duration)
        logging.info('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        validation_losses = defaultdict(AverageMeter)
        logging.info('Running validation for epoch %s/%s for noise %s', epoch,
                     train_options.number_of_epochs, noise)
        for step, (image, _) in enumerate(val_data, start=1):
            image = image.to(device)
            message = _random_messages(image.shape[0],
                                       hidden_config.message_length, device)
            losses, (encoded_images, _, _) = \
                model.validate_on_batch_specific_noise([image, message],
                                                       noise=noise)
            for name, loss in losses.items():
                validation_losses[name].update(loss)
            if step == 1:
                # Save a few original/encoded images from the first batch.
                if hidden_config.enable_fp16:
                    image = image.float()
                    encoded_images = encoded_images.float()
                utils.save_images(
                    image.cpu()[:images_to_save, :, :, :],
                    encoded_images[:images_to_save, :, :, :].cpu(),
                    epoch,
                    os.path.join(this_run_folder, 'images'),
                    resize_to=saved_images_size)
            # Same post-processing check as in the training loop: validate on
            # exactly max_val_steps batches.
            if step >= max_val_steps:
                break

        utils.log_progress(validation_losses)
        logging.info('-' * 40)
        utils.save_checkpoint(model, train_options.experiment_name, epoch,
                              os.path.join(this_run_folder, 'checkpoints'))
        # The duration written here is measured from the start of the epoch,
        # so it covers training as well as validation.
        utils.write_losses(
            os.path.join(this_run_folder, 'validation_' + noise + '.csv'),
            validation_losses, epoch,
            time.time() - epoch_start)


def _random_messages(batch_size, message_length, device):
    """Return a (batch_size, message_length) tensor of random 0/1 values on ``device``."""
    bits = np.random.choice([0, 1], (batch_size, message_length))
    return torch.Tensor(bits).to(device)