def main():
    """Run a validation-only pass of a trained HiDDeN model under several
    WebP quality settings and append the accumulated losses to a CSV file.

    Loads the experiment options and the last checkpoint stored under
    ``--runs_root``, rebuilds the model, then for each noise name runs the
    whole validation set through ``validate_on_batch_specific_noise`` and
    writes the results via ``write_validation_loss``.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    parser = argparse.ArgumentParser(description='Training of HiDDeN nets')
    parser.add_argument('--hostname', default=socket.gethostname(),
                        help='the host name of the running server')
    parser.add_argument('--data-dir', '-d', required=True, type=str,
                        help='The directory where the data is stored.')
    parser.add_argument('--runs_root', '-r', default=os.path.join('.', 'experiments'), type=str,
                        help='The root folder where data about experiments are stored.')
    parser.add_argument('--batch-size', '-b', default=1, type=int,
                        help='Validation batch size.')
    args = parser.parse_args()

    # NOTE(review): these host-specific branches overwrite the --data-dir CLI
    # value on every path (the final else fires for any unknown host), so the
    # required argument is effectively ignored — confirm this is intended.
    if args.hostname == 'ee898-System-Product-Name':
        args.data_dir = '/home/ee898/Desktop/chaoning/ImageNet'
        args.hostname = 'ee898'
    elif args.hostname == 'DL178':
        args.data_dir = '/media/user/SSD1TB-2/ImageNet'
    else:
        args.data_dir = '/workspace/data_local/imagenet_pytorch'
    # Explicit check instead of `assert`, which is stripped under python -O.
    if not args.data_dir:
        raise ValueError('data_dir must be a non-empty path')

    print_each = 25  # print running losses every N validation steps

    # Informational only: list completed experiment folders under runs_root.
    completed_runs = [
        o for o in os.listdir(args.runs_root)
        if os.path.isdir(os.path.join(args.runs_root, o)) and o != 'no-noise-defaults'
    ]
    print(completed_runs)

    write_csv_header = True
    current_run = args.runs_root
    print(f'Run folder: {current_run}')
    options_file = os.path.join(current_run, 'options-and-config.pickle')
    train_options, hidden_config, noise_config = utils.load_options(options_file)
    # Point both folders at the 'val' split: only validation data is used here,
    # but the loader helper expects a train folder too.
    train_options.train_folder = os.path.join(args.data_dir, 'val')
    train_options.validation_folder = os.path.join(args.data_dir, 'val')
    train_options.batch_size = args.batch_size
    checkpoint, chpt_file_name = utils.load_last_checkpoint(
        os.path.join(current_run, 'checkpoints'))
    print(f'Loaded checkpoint from file {chpt_file_name}')

    noiser = Noiser(noise_config, device, 'jpeg')
    model = Hidden(hidden_config, device, noiser, tb_logger=None)
    utils.model_from_checkpoint(model, checkpoint)

    print('Model loaded successfully. Starting validation run...')
    _, val_data = utils.get_data_loaders(hidden_config, train_options)
    file_count = len(val_data.dataset)
    # Ceiling division: a partial final batch still counts as a step.
    steps_in_epoch = -(-file_count // train_options.batch_size)

    with torch.no_grad():
        noises = ['webp_10', 'webp_25', 'webp_50', 'webp_75', 'webp_90']
        for noise in noises:
            losses_accu = {}
            step = 0
            for image, _ in val_data:
                step += 1
                image = image.to(device)
                # Random binary message, one row per image in the batch.
                message = torch.Tensor(np.random.choice(
                    [0, 1], (image.shape[0], hidden_config.message_length))).to(device)
                losses, (encoded_images, noised_images, decoded_messages) = \
                    model.validate_on_batch_specific_noise([image, message], noise=noise)
                if not losses_accu:  # dict is empty, initialize one meter per loss name
                    for name in losses:
                        losses_accu[name] = AverageMeter()
                for name, loss in losses.items():
                    losses_accu[name].update(loss)
                if step % print_each == 0 or step == steps_in_epoch:
                    print(f'Step {step}/{steps_in_epoch}')
                    utils.print_progress(losses_accu)
                    print('-' * 40)

            write_validation_loss(os.path.join(args.runs_root, 'validation_run.csv'),
                                  losses_accu, noise, checkpoint['epoch'],
                                  write_header=write_csv_header)
            write_csv_header = False
def train_own_noise(model: Hidden, device: torch.device, hidden_config: HiDDenConfiguration,
                    train_options: TrainingOptions, this_run_folder: str, tb_logger, noise,
                    steps_per_epoch: int = 313):
    """
    Trains the HiDDeN model, validating after each epoch against one specific noise.

    :param model: The model
    :param device: torch.device object, usually this is GPU (if available), otherwise CPU.
    :param hidden_config: The network configuration
    :param train_options: The training settings
    :param this_run_folder: The parent folder for the current training run to store
        training artifacts/results/logs.
    :param tb_logger: TensorBoardLogger object which is a thin wrapper for TensorboardX
        logger. Pass None to disable TensorboardX logging
    :param noise: noise name forwarded to ``model.validate_on_batch_specific_noise``
        during the per-epoch validation pass.
    :param steps_per_epoch: number of training steps per epoch. Defaults to 313 (the
        previously hard-coded value); pass ``None`` to derive it from the dataset
        size and batch size instead.
    :return:
    """
    train_data, val_data = utils.get_data_loaders(hidden_config, train_options)
    if steps_per_epoch is None:
        file_count = len(train_data.dataset)
        # Ceiling division: a partial final batch still counts as a step.
        steps_in_epoch = -(-file_count // train_options.batch_size)
    else:
        steps_in_epoch = steps_per_epoch

    print_each = 10
    images_to_save = 8
    saved_images_size = (512, 512)  # for qualitative check purpose to use a larger size

    for epoch in range(train_options.start_epoch, train_options.number_of_epochs + 1):
        logging.info('\nStarting epoch {}/{}'.format(epoch, train_options.number_of_epochs))
        logging.info('Batch size = {}\nSteps in epoch = {}'.format(
            train_options.batch_size, steps_in_epoch))
        training_losses = defaultdict(AverageMeter)
        if train_options.video_dataset:
            # Reshuffle frame order between epochs for video datasets.
            random.shuffle(train_data.dataset)
        epoch_start = time.time()

        # enumerate(start=1) fixes an off-by-one in the previous manual counter:
        # the loop now runs exactly steps_in_epoch batches and the final-step
        # progress log actually fires.
        for step, (image, _) in enumerate(train_data, start=1):
            image = image.to(device)
            # Random binary message, one row per image in the batch.
            message = torch.Tensor(np.random.choice(
                [0, 1], (image.shape[0], hidden_config.message_length))).to(device)
            losses, _ = model.train_on_batch([image, message])

            for name, loss in losses.items():
                training_losses[name].update(loss)
            if step % print_each == 0 or step == steps_in_epoch:
                logging.info('Epoch: {}/{} Step: {}/{}'.format(
                    epoch, train_options.number_of_epochs, step, steps_in_epoch))
                utils.log_progress(training_losses)
                logging.info('-' * 40)
            if step == steps_in_epoch:
                break

        train_duration = time.time() - epoch_start
        logging.info('Epoch {} training duration {:.2f} sec'.format(epoch, train_duration))
        logging.info('-' * 40)
        utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                           training_losses, epoch, train_duration)
        if tb_logger is not None:
            tb_logger.save_losses(training_losses, epoch)
            tb_logger.save_grads(epoch)
            tb_logger.save_tensors(epoch)

        first_iteration = True
        validation_losses = defaultdict(AverageMeter)
        logging.info('Running validation for epoch {}/{} for noise {}'.format(
            epoch, train_options.number_of_epochs, noise))
        # Validate on roughly a tenth of an epoch's worth of batches.
        for step, (image, _) in enumerate(val_data, start=1):
            image = image.to(device)
            message = torch.Tensor(np.random.choice(
                [0, 1], (image.shape[0], hidden_config.message_length))).to(device)
            losses, (encoded_images, noised_images, decoded_messages) = \
                model.validate_on_batch_specific_noise([image, message], noise=noise)
            for name, loss in losses.items():
                validation_losses[name].update(loss)
            if first_iteration:
                # Save one qualitative batch of original/encoded image pairs.
                if hidden_config.enable_fp16:
                    image = image.float()
                    encoded_images = encoded_images.float()
                utils.save_images(image.cpu()[:images_to_save, :, :, :],
                                  encoded_images[:images_to_save, :, :, :].cpu(),
                                  epoch,
                                  os.path.join(this_run_folder, 'images'),
                                  resize_to=saved_images_size)
                first_iteration = False
            if step == steps_in_epoch // 10:
                break

        utils.log_progress(validation_losses)
        logging.info('-' * 40)
        utils.save_checkpoint(model, train_options.experiment_name, epoch,
                              os.path.join(this_run_folder, 'checkpoints'))
        utils.write_losses(
            os.path.join(this_run_folder, 'validation_' + noise + '.csv'),
            validation_losses, epoch, time.time() - epoch_start)