def _blocking_effect(blocks, block_size=32, border=10, grid_width=4,
                     seam_count=12):
    """
    Scores the intensity jump across the seams between neighbouring encoded blocks.

    For every pair of neighbouring blocks, the sum of a ``border``-pixel-wide strip on one
    side of the shared edge is compared with the sum of the matching strip on the other
    side; the absolute difference is divided by the strip area (``block_size * border``).
    Left/right neighbours are ``blocks[i]`` and ``blocks[i + 1]`` (the last block of each
    grid row is skipped), top/bottom neighbours are ``blocks[i]`` and
    ``blocks[i + grid_width]``.

    :param blocks: sequence of 2-D numpy arrays, one per block. The indexing implies the
                   blocks are ordered row by row -- TODO confirm against cropImg.
    :param block_size: number of rows/columns of each block that take part in the sums.
    :param border: width, in pixels, of the strip summed on each side of a seam.
    :param grid_width: number of blocks per grid row.
    :param seam_count: fixed divisor applied to each direction's total; 12 is the number
                       of seams per direction in a 4 x 4 grid of blocks.
    :return: the mean of the left/right and the top/bottom seam averages.
    """
    strip_area = block_size * border
    far_edge = block_size - border

    # Seams between left/right neighbours.
    v_average = 0
    for i in range(len(blocks) - 1):
        if (i + 1) % grid_width == 0:
            # Last block of a grid row: blocks[i + 1] starts the next row.
            continue
        left, right = blocks[i], blocks[i + 1]
        # Strip sums are computed from scratch for every pair.
        left_sum = np.sum(left[:block_size, far_edge:block_size],
                          dtype=np.float64)
        right_sum = np.sum(right[:block_size, :border], dtype=np.float64)
        v_average += np.abs(left_sum - right_sum) / strip_area
    v_average = v_average / seam_count

    # Seams between top/bottom neighbours.
    h_average = 0
    for i in range(len(blocks) - grid_width):
        upper, lower = blocks[i], blocks[i + grid_width]
        upper_sum = np.sum(upper[far_edge:block_size, :block_size],
                           dtype=np.float64)
        lower_sum = np.sum(lower[:border, :block_size], dtype=np.float64)
        h_average += np.abs(upper_sum - lower_sum) / strip_area
    h_average = h_average / seam_count

    return (h_average + v_average) / 2


def _process_blocks(batch_fn, image, device, message_length, block_size=32,
                    keep_encoded=False):
    """
    Crops ``image`` into blocks and feeds each block, with a fresh random message, to ``batch_fn``.

    :param batch_fn: callable taking ``[block, message]`` and returning
                     ``(losses, (encoded_images, noised_images, decoded_messages))``.
    :param image: the image batch to crop; passed to ``cropImg`` unchanged.
    :param device: device every block and message is moved to.
    :param message_length: number of random bits per message.
    :param block_size: block size handed to ``cropImg``.
    :param keep_encoded: when True, the full encoded tensors of every block are returned too.
    :return: tuple ``(last_losses, bitwise_errors, seam_blocks, encoded_batches)`` where
             ``last_losses`` is the loss dict of the last block (None if there were no blocks),
             ``bitwise_errors`` lists the 'bitwise-error  ' value of every block,
             ``seam_blocks`` holds ``encoded_images[0][0]`` of every block as a numpy array
             (presumably first sample, first channel -- TODO confirm) and
             ``encoded_batches`` is empty unless ``keep_encoded`` is set.
    """
    last_losses = None
    bitwise_errors = []
    seam_blocks = []
    encoded_batches = []
    for block in cropImg(block_size, image):
        block = block.to(device)
        message = torch.Tensor(
            np.random.choice([0, 1],
                             (block.shape[0], message_length))).to(device)
        losses, (encoded_images, noised_images,
                 decoded_messages) = batch_fn([block, message])
        if keep_encoded:
            encoded_batches.append(encoded_images)
        seam_blocks.append(encoded_images[0][0].cpu().detach().numpy())
        last_losses = losses
        # The key really does end with two spaces.
        if 'bitwise-error  ' in losses:
            bitwise_errors.append(losses['bitwise-error  '])
    return last_losses, bitwise_errors, seam_blocks, encoded_batches


def _update_loss_meters(meters, batch_losses, bitwise_avg, blocking_loss):
    """
    Feeds one image batch's losses into the running meters.

    The bitwise error is replaced by its average over all blocks and 'blocking_effect'
    by the seam score; every other loss is taken as reported in ``batch_losses``.
    """
    for name, loss in batch_losses.items():
        if name == 'bitwise-error  ':
            meters[name].update(bitwise_avg)
        elif name == 'blocking_effect':
            meters[name].update(blocking_loss)
        else:
            meters[name].update(loss)


def train(model: Hidden, device: torch.device,
          hidden_config: HiDDenConfiguration, train_options: TrainingOptions,
          this_run_folder: str, tb_logger):
    """
    Trains the HiDDeN model block by block and tracks a blocking-effect score.

    Every image batch is cropped into 32-pixel blocks; each block is trained (or validated)
    with its own random message. Per image batch the meters receive the bitwise error
    averaged over all blocks, the seam score of the encoded blocks under 'blocking_effect',
    and the remaining losses of the last block only.

    :param model: The model
    :param device: torch.device object, usually this is GPU (if available), otherwise CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :return:
    """

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)
    file_count = len(train_data.dataset)
    # Number of batches per epoch, rounded up.
    if file_count % train_options.batch_size == 0:
        steps_in_epoch = file_count // train_options.batch_size
    else:
        steps_in_epoch = file_count // train_options.batch_size + 1

    print_each = 10
    images_to_save = 8
    saved_images_size = (512, 512)
    block_size = 32

    for epoch in range(train_options.start_epoch,
                       train_options.number_of_epochs + 1):
        logging.info('\nStarting epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        logging.info('Batch size = {}\nSteps in epoch = {}'.format(
            train_options.batch_size, steps_in_epoch))
        training_losses = defaultdict(AverageMeter)
        epoch_start = time.time()
        step = 1
        # Training pass.
        for image, _ in train_data:
            image = image.to(device)
            main_losses, bitwise_errors, seam_blocks, _ = _process_blocks(
                model.train_on_batch, image, device,
                hidden_config.message_length, block_size=block_size)
            blocking_loss = _blocking_effect(seam_blocks,
                                             block_size=block_size)
            bitwise_avg = np.average(np.array(bitwise_errors))
            _update_loss_meters(training_losses, main_losses, bitwise_avg,
                                blocking_loss)

            if step % print_each == 0 or step == steps_in_epoch:
                logging.info('Epoch: {}/{} Step: {}/{}'.format(
                    epoch, train_options.number_of_epochs, step,
                    steps_in_epoch))
                utils.log_progress(training_losses)
                logging.info('-' * 40)
            step += 1

        train_duration = time.time() - epoch_start
        logging.info('Epoch {} training duration {:.2f} sec'.format(
            epoch, train_duration))
        logging.info('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        first_iteration = True
        validation_losses = defaultdict(AverageMeter)
        logging.info('Running validation for epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))

        # Validation pass.
        for image, _ in val_data:
            image = image.to(device)
            main_losses, bitwise_errors, seam_blocks, encoded_imgs = \
                _process_blocks(model.validate_on_batch, image, device,
                                hidden_config.message_length,
                                block_size=block_size, keep_encoded=True)
            blocking_loss = _blocking_effect(seam_blocks,
                                             block_size=block_size)
            bitwise_avg = np.average(np.array(bitwise_errors))
            _update_loss_meters(validation_losses, main_losses, bitwise_avg,
                                blocking_loss)

            # Stitch the encoded blocks back into full images.
            encoded_images = concatImgs(encoded_imgs)

            # Only the first validation batch of an epoch is written to disk.
            if first_iteration:
                if hidden_config.enable_fp16:
                    image = image.float()
                    encoded_images = encoded_images.float()
                utils.save_images(
                    image.cpu()[:images_to_save, :, :, :],
                    encoded_images[:images_to_save, :, :, :].cpu(),
                    epoch,
                    os.path.join(this_run_folder, 'images'),
                    resize_to=saved_images_size)
                first_iteration = False

        utils.log_progress(validation_losses)
        logging.info('-' * 40)
        utils.save_checkpoint(model, train_options.experiment_name, epoch,
                              os.path.join(this_run_folder, 'checkpoints'))
        utils.write_losses(os.path.join(this_run_folder, 'validation.csv'),
                           validation_losses, epoch,
                           time.time() - epoch_start)
# Example no. 2
# 0
def train(model: Hidden, device: torch.device,
          hidden_config: HiDDenConfiguration, train_options: TrainingOptions,
          this_run_folder: str, tb_logger):
    """
    Trains the HiDDeN model and keeps track of the best validation epoch.

    After every epoch the validation score ``encoder_mse + bitwise-error`` (meter averages)
    is compared with the best score so far; the best epoch and score are passed on to
    ``utils.save_checkpoint``. One randomly picked image per validation batch is collected
    and the first ``images_to_save`` of them are written with ``utils.save_images``.

    :param model: The model
    :param device: torch.device object, usually this is GPU (if available), otherwise CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings; ``best_epoch`` and ``best_cond`` seed the
                best-epoch tracking.
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :return:
    """

    def random_messages(batch_size):
        # One random 0/1 message per image of the batch.
        bits = np.random.choice([0, 1],
                                (batch_size, hidden_config.message_length))
        return torch.Tensor(bits).to(device)

    def single_patch(batch, index):
        # NOTE(review): F.interpolate documents `size` as the output spatial size
        # (height first for 2-D input), while (W, H) is passed here -- this only
        # matters if W != H; confirm.
        return F.interpolate(batch[index:index + 1, :, :, :].cpu(),
                             size=(hidden_config.W, hidden_config.H))

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)

    images_to_save = 8
    saved_images_size = (512, 512)

    best_epoch, best_cond = train_options.best_epoch, train_options.best_cond
    first_epoch = train_options.start_epoch
    for epoch in range(first_epoch, train_options.number_of_epochs + 1):
        logging.info(
            f'\nStarting epoch {epoch}/{train_options.number_of_epochs} [{best_epoch}]'
        )

        # ---- training pass ----
        training_losses = defaultdict(functions.AverageMeter)
        epoch_start = time.time()
        for image, _ in tqdm(train_data, ncols=80):
            image = image.to(device)
            message = random_messages(image.shape[0])
            batch_losses, _ = model.train_on_batch([image, message])
            for loss_name, loss_value in batch_losses.items():
                training_losses[loss_name].update(loss_value)

        train_duration = time.time() - epoch_start
        logging.info('Epoch {} training duration {:.2f} sec'.format(
            epoch, train_duration))
        logging.info('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses('train_loss', training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)
            tb_logger.writer.flush()

        # ---- validation pass ----
        validation_losses = defaultdict(functions.AverageMeter)
        logging.info('Running validation for epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        original_patches = []
        encoded_patches = []
        noised_patches = []
        for image, _ in tqdm(val_data, ncols=80):
            image = image.to(device)
            message = random_messages(image.shape[0])
            batch_losses, outputs = model.validate_on_batch([image, message])
            encoded_images, noised_images, decoded_messages = outputs
            for loss_name, loss_value in batch_losses.items():
                validation_losses[loss_name].update(loss_value)

            if hidden_config.enable_fp16:
                image = image.float()
                encoded_images = encoded_images.float()
            # Keep one random sample of this batch for the preview image.
            chosen = np.random.randint(0, image.shape[0])
            original_patches.append(single_patch(image, chosen))
            encoded_patches.append(single_patch(encoded_images, chosen))
            noised_patches.append(single_patch(noised_images, chosen))

        if tb_logger is not None:
            tb_logger.save_losses('val_loss', validation_losses, epoch)
            tb_logger.writer.flush()

        # Each collected patch has a leading batch axis of size 1; drop it after stacking.
        stacked_originals = torch.stack(original_patches).squeeze(1)
        stacked_encoded = torch.stack(encoded_patches).squeeze(1)
        stacked_noised = torch.stack(noised_patches).squeeze(1)
        utils.save_images(stacked_originals[:images_to_save, :, :, :],
                          stacked_encoded[:images_to_save, :, :, :],
                          stacked_noised[:images_to_save, :, :, :],
                          epoch,
                          os.path.join(this_run_folder, 'images'),
                          resize_to=saved_images_size)

        # NOTE(review): the meters are a defaultdict, so a key that the model never
        # reported silently yields a fresh meter here -- verify that 'encoder_mse' and
        # 'bitwise-error' match the loss names returned by validate_on_batch.
        mse_avg = validation_losses['encoder_mse'].avg
        bit_avg = validation_losses['bitwise-error'].avg
        curr_cond = mse_avg + bit_avg
        if best_cond is None or curr_cond < best_cond:
            best_epoch = epoch
            best_cond = curr_cond

        utils.log_progress(validation_losses)
        logging.info('-' * 40)
        utils.save_checkpoint(model, train_options.experiment_name, epoch,
                              best_epoch, best_cond,
                              os.path.join(this_run_folder, 'checkpoints'))
        logging.info(
            f'Current best epoch = {best_epoch}, loss = {best_cond:.6f}')
        utils.write_losses(os.path.join(this_run_folder, 'validation.csv'),
                           validation_losses, epoch,
                           time.time() - epoch_start)
def _run_entropy_blocks(batch_fn, image, device, block_size, crop_width, alpha,
                        m_length, loss_type):
    """
    Crops ``image`` into blocks, runs ``batch_fn`` on every block and averages the losses.

    :param batch_fn: callable taking ``[img, message, modified_img, entropy, loss_type]`` and
                     returning ``(losses, (encoded_images, noised_images, decoded_messages))``.
    :param image: image batch handed to ``cropImg``.
    :param device: device every tensor of a block is moved to.
    :param block_size: block size handed to ``cropImg``.
    :param crop_width: crop width handed to ``cropImg``.
    :param alpha: passed through to ``cropImg``.
    :param m_length: number of random bits in each message.
    :param loss_type: passed through to ``batch_fn`` as the last batch element.
    :return: tuple ``(mean_losses, encoded_imgs, batch)``: the per-key sum of
             ``loss / len(imgs)`` over all blocks (None if there were no blocks), the encoded
             tensors of every block, and ``shape[0]`` of the last encoded tensor (0 if none).
    """
    imgs, modified_imgs, entropies = cropImg(block_size, image, crop_width,
                                             alpha)
    block_count = len(imgs)
    mean_losses = None
    encoded_imgs = []
    batch = 0
    for img, modified_img, entropy in zip(imgs, modified_imgs, entropies):
        img = img.to(device)
        modified_img = modified_img.to(device)
        entropy = entropy.to(device)

        message = torch.Tensor(
            np.random.choice([0, 1], (img.shape[0], m_length))).to(device)
        losses, (encoded_images, noised_images, decoded_messages) = \
            batch_fn([img, message, modified_img, entropy, loss_type])
        encoded_imgs.append(encoded_images)
        batch = encoded_images.shape[0]
        if mean_losses is None:
            # Build a new dict so the one returned by the model is left untouched.
            mean_losses = {k: v / block_count for k, v in losses.items()}
        else:
            for k in mean_losses:
                mean_losses[k] += losses[k] / block_count
    return mean_losses, encoded_imgs, batch


def train(model: Hidden,
          device: torch.device,
          hidden_config: HiDDenConfiguration,
          train_options: TrainingOptions,
          this_run_folder: str,
          tb_logger):
    """
    Trains the HiDDeN model block by block, with a per-batch blocking-effect value.

    Every image batch is cropped into blocks by ``cropImg``; each block is passed to the
    model with a random message. Losses are averaged over the blocks. Unless
    ``train_options.default`` equals False is not the case, i.e. when it equals False, the
    'blocking_effect' meter receives the value of ``blocking_value`` instead of the model's
    own loss. In the last epoch, if ``train_options.ats`` is set, every encoded validation
    image is written to ``train_options.output_folder`` under the sorted file names found in
    ``<validation_folder>/valid_class``.

    :param model: The model
    :param device: torch.device object, usually this is GPU (if available), otherwise CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :return:
    """

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)
    block_size = hidden_config.block_size
    block_number = int(hidden_config.H / hidden_config.block_size)
    val_folder = train_options.validation_folder
    loss_type = train_options.loss_mode
    m_length = hidden_config.message_length
    alpha = train_options.alpha
    # Sorted names used for the images written in the last epoch.
    img_names = listdir(val_folder + "/valid_class")
    img_names.sort()
    out_folder = train_options.output_folder
    default = train_options.default
    beta = train_options.beta
    crop_width = int(beta * block_size)
    file_count = len(train_data.dataset)
    # Number of batches per epoch, rounded up.
    if file_count % train_options.batch_size == 0:
        steps_in_epoch = file_count // train_options.batch_size
    else:
        steps_in_epoch = file_count // train_options.batch_size + 1

    print_each = 10
    images_to_save = 8
    saved_images_size = (512, 512)
    icount = 0  # index into img_names, advanced across validation batches
    plot_block = []  # mean validation blocking value per epoch

    for epoch in range(train_options.start_epoch, train_options.number_of_epochs + 1):
        logging.info('\nStarting epoch {}/{}'.format(epoch, train_options.number_of_epochs))
        logging.info('Batch size = {}\nSteps in epoch = {}'.format(train_options.batch_size, steps_in_epoch))
        training_losses = defaultdict(AverageMeter)
        epoch_start = time.time()
        step = 1
        # Training pass.
        for image, _ in train_data:
            image = image.to(device)
            main_losses, encoded_imgs, batch = _run_entropy_blocks(
                model.train_on_batch, image, device, block_size, crop_width,
                alpha, m_length, loss_type)

            blocking_loss = blocking_value(encoded_imgs, batch, block_size, block_number)

            for name, loss in main_losses.items():
                # `== False` is kept on purpose: `not default` would also match None/0-like values.
                if default == False and name == 'blocking_effect':
                    training_losses[name].update(blocking_loss)
                else:
                    training_losses[name].update(loss)

            if step % print_each == 0 or step == steps_in_epoch:
                logging.info(
                    'Epoch: {}/{} Step: {}/{}'.format(epoch, train_options.number_of_epochs, step, steps_in_epoch))
                utils.log_progress(training_losses)
                logging.info('-' * 40)
            step += 1

        train_duration = time.time() - epoch_start
        logging.info('Epoch {} training duration {:.2f} sec'.format(epoch, train_duration))
        logging.info('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'), training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        first_iteration = True
        validation_losses = defaultdict(AverageMeter)
        logging.info('Running validation for epoch {}/{}'.format(epoch, train_options.number_of_epochs))

        # Validation pass.
        ep_blocking = 0
        ep_total = 0

        for image, _ in val_data:
            image = image.to(device)
            # NOTE(review): validation goes through model.train_on_batch, exactly like the
            # training pass. If train_on_batch updates the weights, the model is being
            # trained on validation data -- confirm whether a validate_on_batch accepting
            # this 5-element batch exists and should be used here.
            main_losses, encoded_imgs, batch = _run_entropy_blocks(
                model.train_on_batch, image, device, block_size, crop_width,
                alpha, m_length, loss_type)

            # Blocking value, also accumulated for the per-epoch mean.
            blocking_loss = blocking_value(encoded_imgs, batch, block_size, block_number)
            ep_blocking = ep_blocking + blocking_loss
            ep_total = ep_total + 1

            for name, loss in main_losses.items():
                if default == False and name == 'blocking_effect':
                    validation_losses[name].update(blocking_loss)
                else:
                    validation_losses[name].update(loss)

            # Stitch the encoded blocks back into full images.
            encoded_images = concatImgs(encoded_imgs, block_number)

            # Only the first validation batch of an epoch is written as a preview.
            if first_iteration:
                if hidden_config.enable_fp16:
                    image = image.float()
                    encoded_images = encoded_images.float()
                utils.save_images(image.cpu()[:images_to_save, :, :, :],
                                  encoded_images[:images_to_save, :, :, :].cpu(),
                                  epoch,
                                  os.path.join(this_run_folder, 'images'), resize_to=saved_images_size)
                first_iteration = False

            # In the last epoch, optionally write every encoded validation image.
            if epoch == train_options.number_of_epochs and train_options.ats:
                for i in range(batch):
                    # Shift and halve before saving (maps [-1, 1] to [0, 1]).
                    out_image = (encoded_images[i].cpu() + 1) / 2
                    f_dst = out_folder + "/" + img_names[icount]
                    save_image(out_image, f_dst)
                    icount = icount + 1

        # Record the epoch's mean blocking value; skipped when there were no validation batches.
        if ep_total:
            plot_block.append(ep_blocking / ep_total)

        utils.log_progress(validation_losses)
        logging.info('-' * 40)
        utils.save_checkpoint(model, train_options.experiment_name, epoch, os.path.join(this_run_folder, 'checkpoints'))
        utils.write_losses(os.path.join(this_run_folder, 'validation.csv'), validation_losses, epoch,
                           time.time() - epoch_start)
# Example no. 4
# 0
def train_own_noise(model: Hidden, device: torch.device,
                    hidden_config: HiDDenConfiguration,
                    train_options: TrainingOptions, this_run_folder: str,
                    tb_logger, noise):
    """
    Trains the HiDDeN model and validates it against one specific noise.

    Training uses ``model.train_on_batch``; validation uses
    ``model.validate_on_batch_specific_noise`` with ``noise`` and writes its losses to
    ``validation_<noise>.csv``. Both loops are cut short by a step limit derived from a
    fixed ``steps_in_epoch`` of 313.

    :param model: The model
    :param device: torch.device object, usually this is GPU (if available), otherwise CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :param noise: the noise handed to ``validate_on_batch_specific_noise``; it is also
                concatenated into the validation CSV file name, so it has to be a string.
    :return:
    """

    def random_messages(batch_size):
        # One random 0/1 message per image of the batch.
        bits = np.random.choice([0, 1],
                                (batch_size, hidden_config.message_length))
        return torch.Tensor(bits).to(device)

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)
    full_batches, leftover = divmod(len(train_data.dataset),
                                    train_options.batch_size)
    steps_in_epoch = full_batches + 1 if leftover else full_batches
    # NOTE(review): the value computed above is immediately replaced by a fixed 313.
    steps_in_epoch = 313

    print_each = 10
    images_to_save = 8
    # for qualitative check purpose to use a larger size
    saved_images_size = (512, 512)

    final_epoch = train_options.number_of_epochs
    for epoch in range(train_options.start_epoch, final_epoch + 1):
        logging.info('\nStarting epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        logging.info('Batch size = {}\nSteps in epoch = {}'.format(
            train_options.batch_size, steps_in_epoch))
        training_losses = defaultdict(AverageMeter)

        if train_options.video_dataset:
            # Shuffles the dataset object in place before each epoch.
            random.shuffle(train_data.dataset)

        epoch_start = time.time()
        for step, (image, _) in enumerate(train_data, start=1):
            image = image.to(device)
            batch_losses, _ = model.train_on_batch(
                [image, random_messages(image.shape[0])])

            for loss_name, loss_value in batch_losses.items():
                training_losses[loss_name].update(loss_value)

            reached_last_step = step == steps_in_epoch
            if step % print_each == 0 or reached_last_step:
                logging.info('Epoch: {}/{} Step: {}/{}'.format(
                    epoch, train_options.number_of_epochs, step,
                    steps_in_epoch))
                utils.log_progress(training_losses)
                logging.info('-' * 40)
            # Stops once the next step number would equal steps_in_epoch, i.e. after
            # steps_in_epoch - 1 batches.
            if step + 1 == steps_in_epoch:
                break

        train_duration = time.time() - epoch_start
        logging.info('Epoch {} training duration {:.2f} sec'.format(
            epoch, train_duration))
        logging.info('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        validation_losses = defaultdict(AverageMeter)
        logging.info('Running validation for epoch {}/{} for noise {}'.format(
            epoch, train_options.number_of_epochs, noise))
        validation_limit = steps_in_epoch // 10
        images_saved = False
        for step, (image, _) in enumerate(val_data, start=1):
            image = image.to(device)
            message = random_messages(image.shape[0])
            batch_losses, outputs = model.validate_on_batch_specific_noise(
                [image, message], noise=noise)
            encoded_images, noised_images, decoded_messages = outputs
            for loss_name, loss_value in batch_losses.items():
                validation_losses[loss_name].update(loss_value)

            # Only the first validation batch of an epoch is written to disk.
            if not images_saved:
                if hidden_config.enable_fp16:
                    image = image.float()
                    encoded_images = encoded_images.float()
                utils.save_images(
                    image.cpu()[:images_to_save, :, :, :],
                    encoded_images[:images_to_save, :, :, :].cpu(),
                    epoch,
                    os.path.join(this_run_folder, 'images'),
                    resize_to=saved_images_size)
                images_saved = True

            # Validation stops after validation_limit - 1 batches.
            if step + 1 == validation_limit:
                break

        utils.log_progress(validation_losses)
        logging.info('-' * 40)
        utils.save_checkpoint(model, train_options.experiment_name, epoch,
                              os.path.join(this_run_folder, 'checkpoints'))
        utils.write_losses(
            os.path.join(this_run_folder, 'validation_' + noise + '.csv'),
            validation_losses, epoch,
            time.time() - epoch_start)
# Example no. 5
# 0
def train(model: Hidden, device: torch.device,
          hidden_config: HiDDenConfiguration, train_options: TrainingOptions,
          this_run_folder: str, tb_logger, vocab):
    """
    Trains and validates the HiDDeN model on batches of
    (image, ekeys, dkeys, caption, length).

    For every epoch from ``train_options.start_epoch`` up to and including
    ``train_options.number_of_epochs`` this makes one pass over the training
    loader, writes the accumulated training losses to ``train.csv``, makes one
    pass over the validation loader (printing one decoded sentence per batch
    and calling ``utils.save_images`` for the first batch only), saves a
    checkpoint and writes the validation losses to ``validation.csv``.

    :param model: The model; must provide train_on_batch and validate_on_batch
    :param device: torch.device object the batch tensors are moved to
    :param hidden_config: The network configuration
    :param train_options: The training settings (batch size, epoch range, experiment name)
    :param this_run_folder: The parent folder of the current run; csv files, images
                and checkpoints are written below it
    :param tb_logger: TensorBoardLogger object, or None to disable TensorboardX logging
    :param vocab: Vocabulary forwarded to utils.get_data_loaders; its idx2word is used
                to turn predicted token ids back into words
    :return: None
    """

    train_loader, val_loader = utils.get_data_loaders(hidden_config,
                                                      train_options, vocab)

    # Number of batches per epoch, counting a trailing partial batch.
    full_batches, leftover = divmod(len(train_loader.dataset),
                                    train_options.batch_size)
    steps_in_epoch = (full_batches + 1) if leftover else full_batches

    log_interval = 10
    preview_count = 8
    preview_size = (512, 512)
    rule = '-' * 40

    first_epoch = train_options.start_epoch
    for epoch in range(first_epoch, train_options.number_of_epochs + 1):
        logging.info('\nStarting epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        logging.info('Batch size = {}\nSteps in epoch = {}'.format(
            train_options.batch_size, steps_in_epoch))

        # ---- training pass ----
        train_meters = defaultdict(AverageMeter)
        epoch_start = time.time()
        for step, batch in enumerate(train_loader, start=1):
            images, enc_keys, dec_keys, captions, lengths = batch
            images = images.to(device)
            captions = captions.to(device)
            enc_keys = enc_keys.to(device)
            dec_keys = dec_keys.to(device)

            batch_losses, _ = model.train_on_batch(
                [images, enc_keys, dec_keys, captions, lengths])
            for loss_name, loss_value in batch_losses.items():
                train_meters[loss_name].update(loss_value)

            should_report = (step % log_interval == 0
                             or step == steps_in_epoch)
            if should_report:
                logging.info('Epoch: {}/{} Step: {}/{}'.format(
                    epoch, train_options.number_of_epochs, step,
                    steps_in_epoch))
                utils.log_progress(train_meters)
                logging.info(rule)

        train_duration = time.time() - epoch_start
        logging.info('Epoch {} training duration {:.2f} sec'.format(
            epoch, train_duration))
        logging.info(rule)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           train_meters, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(train_meters, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        # ---- validation pass ----
        val_meters = defaultdict(AverageMeter)
        logging.info('Running validation for epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        for batch_no, batch in enumerate(val_loader):
            images, enc_keys, dec_keys, captions, lengths = batch
            images = images.to(device)
            captions = captions.to(device)
            enc_keys = enc_keys.to(device)
            dec_keys = dec_keys.to(device)

            batch_losses, outputs = model.validate_on_batch(
                [images, enc_keys, dec_keys, captions, lengths])
            encoded, _noised, _decoded, sentences = outputs

            # Print the first sentence of the batch that can be decoded;
            # rows that raise IndexError are reported and skipped.
            sentences = sentences.cpu().numpy()
            for row in range(train_options.batch_size):
                try:
                    words = [
                        vocab.idx2word[int(token)] + ' '
                        for token in sentences[row]
                    ]
                except IndexError:
                    print(f'{row}th batch does not have enough length.')
                    continue
                print("".join(words))
                break

            for loss_name, loss_value in batch_losses.items():
                val_meters[loss_name].update(loss_value)

            # Only the first validation batch of the epoch is saved as images.
            if batch_no == 0:
                if hidden_config.enable_fp16:
                    images = images.float()
                    encoded = encoded.float()
                utils.save_images(
                    images.cpu()[:preview_count, :, :, :],
                    encoded[:preview_count, :, :, :].cpu(),
                    epoch,
                    os.path.join(this_run_folder, 'images'),
                    resize_to=preview_size)

        utils.log_progress(val_meters)
        logging.info(rule)
        utils.save_checkpoint(model, train_options.experiment_name, epoch,
                              os.path.join(this_run_folder, 'checkpoints'))
        utils.write_losses(os.path.join(this_run_folder, 'validation.csv'),
                           val_meters, epoch,
                           time.time() - epoch_start)
Ejemplo n.º 6
0
def train(model: Hidden, device: torch.device,
          hidden_config: HiDDenConfiguration, train_options: TrainingOptions,
          this_run_folder: str, tb_logger):
    """
    Trains the HiDDeN model

    Each epoch runs one pass over the training loader and one pass over the
    validation loader. Training and validation losses are collected in
    separate accumulators, so that ``validation.csv``, the validation progress
    print and the checkpoint only ever see validation losses.

    :param model: The model
    :param device: torch.device object, usually this is GPU (if avaliable), otherwise CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX logger.
                Pass None to disable TensorboardX logging
    :return:
    """

    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)
    file_count = len(train_data.dataset)
    # Ceiling division: a trailing partial batch still counts as a step.
    if file_count % train_options.batch_size == 0:
        steps_in_epoch = file_count // train_options.batch_size
    else:
        steps_in_epoch = file_count // train_options.batch_size + 1

    print_each = 10
    images_to_save = 8
    saved_images_size = (512, 512)

    for epoch in range(train_options.start_epoch,
                       train_options.number_of_epochs + 1):
        print('\nStarting epoch {}/{}'.format(epoch,
                                              train_options.number_of_epochs))
        print('Batch size = {}\nSteps in epoch = {}'.format(
            train_options.batch_size, steps_in_epoch))
        # Maps loss name -> list of per-batch values for the training pass.
        training_losses = {}
        epoch_start = time.time()
        step = 1
        for image, _ in train_data:
            image = image.to(device)
            # A fresh random binary message is drawn for every batch.
            message = torch.Tensor(
                np.random.choice(
                    [0, 1],
                    (image.shape[0], hidden_config.message_length))).to(device)
            losses, _ = model.train_on_batch([image, message])

            for name, loss in losses.items():
                training_losses.setdefault(name, []).append(loss)
            if step % print_each == 0 or step == steps_in_epoch:
                print('Epoch: {}/{} Step: {}/{}'.format(
                    epoch, train_options.number_of_epochs, step,
                    steps_in_epoch))
                utils.print_progress(training_losses)
                print('-' * 40)
            step += 1

        train_duration = time.time() - epoch_start
        print('Epoch {} training duration {:.2f} sec'.format(
            epoch, train_duration))
        print('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        first_iteration = True
        # Separate accumulator for validation. Reusing the training dict here
        # would append validation values onto the training lists and report
        # the mixture as "validation" losses.
        validation_losses = {}

        print('Running validation for epoch {}/{}'.format(
            epoch, train_options.number_of_epochs))
        for image, _ in val_data:
            image = image.to(device)
            message = torch.Tensor(
                np.random.choice(
                    [0, 1],
                    (image.shape[0], hidden_config.message_length))).to(device)
            losses, (encoded_images, noised_images,
                     decoded_messages) = model.validate_on_batch(
                         [image, message])
            for name, loss in losses.items():
                validation_losses.setdefault(name, []).append(loss)
            # Only the first validation batch of the epoch is saved as images.
            if first_iteration:
                utils.save_images(
                    image.cpu()[:images_to_save, :, :, :],
                    encoded_images[:images_to_save, :, :, :].cpu(),
                    epoch,
                    os.path.join(this_run_folder, 'images'),
                    resize_to=saved_images_size)
                first_iteration = False

        utils.print_progress(validation_losses)
        print('-' * 40)
        utils.save_checkpoint(model, epoch, validation_losses,
                              os.path.join(this_run_folder, 'checkpoints'))
        utils.write_losses(os.path.join(this_run_folder, 'validation.csv'),
                           validation_losses, epoch,
                           time.time() - epoch_start)
Ejemplo n.º 7
0
def train(args):
    """Train a NeuralNgram language model and evaluate it on the test set.

    Builds the corpus from ``args.data_dir``, creates a new model or resumes
    one from ``args.checkpoint``, and runs SGD for ``args.epochs`` epochs.
    The latest model is saved every ``args.save_every`` steps, and the best
    model (by validation loss) after every epoch that improves on it. When
    validation loss does not improve, the learning rate is divided by 4.
    Training can be interrupted with Ctrl-C; losses are still written to
    ``args.log_dir`` and the test set is still evaluated.

    :param args: Parsed command-line options. The attributes read here are
        seed, data_dir, no_headers, lower, use_chars, batch_size, order,
        resume, checkpoint, hidden_dims, emb_dim, dropout, use_glove,
        glove_dir, tied, lr, softmax_approx, unigram_proposal, epochs, debug,
        print_every, save_every, save_dir, name and log_dir.
    :return: None
    """
    cuda = torch.cuda.is_available()

    # Set seed for reproducibility.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    data_dir = os.path.expanduser(args.data_dir)
    corpus = Corpus(data_dir,
                    headers=args.no_headers,
                    lower=args.lower,
                    chars=args.use_chars)
    train_data = batchify(corpus.train, args.batch_size)
    val_data = batchify(corpus.valid, args.batch_size)
    test_data = batchify(corpus.test, args.batch_size)
    if cuda:
        train_data, val_data, test_data = train_data.cuda(), val_data.cuda(
        ), test_data.cuda()

    # Logging
    print_args(args)
    print('Using cuda: {}'.format(cuda))
    print('Size of training set: {:,} tokens'.format(np.prod(
        train_data.size())))
    print('Size of validation set: {:,} tokens'.format(np.prod(
        val_data.size())))
    print('Size of test set: {:,} tokens'.format(np.prod(test_data.size())))
    print('Vocabulary size: {:,}'.format(corpus.vocab_size))
    print('Example data:')
    for k in range(100, 107):
        x = [corpus.dictionary.i2w[i] for i in train_data[k:args.order + k, 0]]
        y = [corpus.dictionary.i2w[train_data[k + args.order, 0]]]
        print(x, y)

    # Initialize model
    if args.resume:
        print(f'Resume training with model {args.checkpoint}...')
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only resume from checkpoints you trust.
        with open(args.checkpoint, 'rb') as f:
            model = torch.load(f)
        model_data_checks(model, corpus, args)
    else:
        hidden_dims = list_hidden_dims(args.hidden_dims)
        model = NeuralNgram(order=args.order,
                            emb_dim=args.emb_dim,
                            vocab_size=corpus.vocab_size,
                            hidden_dims=hidden_dims,
                            dropout=args.dropout)
        if args.use_glove:
            print('Loading GloVe vectors...')
            model.load_glove(args.glove_dir, i2w=corpus.dictionary.i2w)
        if args.tied:
            print('Tying weights...')
            model.tie_weights()
    # The data is on the GPU whenever CUDA is available, so the model has to
    # be there as well. This includes a resumed model, which may have been
    # saved from CPU.
    if cuda:
        model.cuda()

    parameters = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.SGD(parameters, lr=args.lr)
    if args.softmax_approx:
        if args.unigram_proposal:
            unigram = normalize(Counter(corpus.train.numpy()))
            unigram = np.array([unigram[i] for i in range(len(unigram))])
        else:
            unigram = None
        criterion = ApproximateLoss(vocab_size=len(corpus.dictionary),
                                    method='importance',
                                    unigram=unigram)
    else:
        criterion = nn.CrossEntropyLoss()

    # Training
    print('Training...')
    losses = dict(train=[], val=[])

    lr = args.lr
    best_val_loss = None

    num_steps = train_data.size(0) - args.order - 1
    batch_order = np.arange(num_steps)

    t0 = time.time()
    try:
        for epoch in range(1, args.epochs + 1):
            model.train()
            epoch_start_time = time.time()
            # Visit the n-gram start positions in a new random order each epoch.
            np.random.shuffle(batch_order)
            for step in range(1, num_steps + 1):
                idx = batch_order[step - 1]
                x, y = get_batch(train_data, idx, args.order)

                # Forward pass
                logits = model(x)
                loss = criterion(logits, y)

                if args.debug:
                    # Debugging softmax approximation.
                    xe = nn.CrossEntropyLoss()
                    true_loss = xe(logits, y)
                    # .item() instead of .data[0]: indexing a 0-dim loss
                    # tensor raises on PyTorch >= 0.5.
                    approx_value = loss.item()
                    true_value = true_loss.item()
                    print(
                        'approx {:>3.2f}, true {:>3.2f}, diff {:>3.4f}'.format(
                            approx_value, true_value,
                            true_value - approx_value))

                # Update parameters
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Save loss as a plain Python float (also copies it off the GPU).
                losses['train'].append(loss.item())

                if step % args.print_every == 0:
                    avg_loss = sum(
                        losses['train'][-args.print_every:]) / args.print_every
                    t1 = time.time()
                    steps_per_second = args.print_every / (t1 - t0)
                    print(
                        '| epoch {} | step {}/{} | loss {:.4f} | lr {:.3f} | '
                        'ngrams/sec {:.1f} | eta {}h{}m{}s'.format(
                            epoch, step, num_steps, avg_loss, lr,
                            steps_per_second * args.batch_size,
                            *clock_time(
                                (num_steps - step) / steps_per_second)))
                    t0 = time.time()

                if step % args.save_every == 0:
                    modelpath = os.path.join(args.save_dir,
                                             f'{args.name}.latest.pt')
                    with open(modelpath, 'wb') as f:
                        torch.save(model, f)

            print('Evaluating on validation set...')
            val_loss = evaluate(val_data, model, criterion)
            losses['val'].append(val_loss)
            print('-' * 89)
            print(
                '| end of epoch {:3d} | time {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'
                .format(epoch, (time.time() - epoch_start_time), val_loss,
                        np.exp(val_loss)))
            print('-' * 89)

            # Compare against None explicitly: a best loss of exactly 0.0 is
            # falsy and would otherwise be overwritten by any later value.
            if best_val_loss is None or val_loss < best_val_loss:
                modelpath = os.path.join(args.save_dir, f'{args.name}.best.pt')
                with open(modelpath, 'wb') as f:
                    torch.save(model, f)
                best_val_loss = val_loss
            else:
                # Anneal the learning rate if no improvement has been seen in the validation dataset.
                lr /= 4.0
                optimizer = torch.optim.SGD(parameters, lr=lr)
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    write_losses(losses['train'], args.log_dir, name='train-losses')
    write_losses(losses['val'], args.log_dir, name='val-losses')

    print('Evaluating on test set...')
    test_loss = evaluate(test_data, model, criterion)
    print('=' * 89)
    print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
        test_loss, np.exp(test_loss)))
    print('=' * 89)