import datetime
import os
import pickle
import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange

# Project-local helpers are assumed to be in scope: `utils` (save_model, log,
# update_lr), the `test` evaluation routine, and the `optimize_loss` /
# `optimize_compression_loss` helpers sketched after this function.


def train(args, model, train_loader, test_loader, device, logger, optimizers):
    """Rate-distortion(-GAN) training loop: alternates generator and
    discriminator updates, evaluates every `args.log_interval` steps, and
    checkpoints periodically."""
    start_time = time.time()
    test_loader_iter = iter(test_loader)
    current_D_steps, train_generator = 0, True
    best_loss, best_test_loss, mean_epoch_loss = np.inf, np.inf, np.inf
    train_writer = SummaryWriter(os.path.join(args.tensorboard_runs, 'train'))
    test_writer = SummaryWriter(os.path.join(args.tensorboard_runs, 'test'))
    storage, storage_test = model.storage_train, model.storage_test
    amortization_opt, hyperlatent_likelihood_opt = optimizers['amort'], optimizers['hyper']

    if model.use_discriminator is True:
        disc_opt = optimizers['disc']

    for epoch in trange(args.n_epochs, desc='Epoch'):
        epoch_loss, epoch_test_loss = [], []
        epoch_start_time = time.time()

        if epoch > 0:
            ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch,
                                         device, args=args, logger=logger)
        model.train()

        for idx, (data, bpp) in enumerate(tqdm(train_loader, desc='Train'), 0):
            data = data.to(device, dtype=torch.float)

            try:
                if model.use_discriminator is True:
                    # Train D for D_steps, then G, using distinct batches.
                    losses = model(data, train_generator=train_generator)
                    compression_loss = losses['compression']
                    disc_loss = losses['disc']

                    if train_generator is True:
                        optimize_compression_loss(compression_loss, amortization_opt,
                                                  hyperlatent_likelihood_opt)
                        train_generator = False
                    else:
                        optimize_loss(disc_loss, disc_opt)
                        current_D_steps += 1

                        if current_D_steps == args.discriminator_steps:
                            current_D_steps = 0
                            train_generator = True
                        continue
                else:
                    # Rate, distortion, perceptual only.
                    losses = model(data, train_generator=True)
                    compression_loss = losses['compression']
                    optimize_compression_loss(compression_loss, amortization_opt,
                                              hyperlatent_likelihood_opt)
            except KeyboardInterrupt:
                # Note: saving is not guaranteed!
                if model.step_counter > args.log_interval + 1:
                    logger.warning('Exiting, saving ...')
                    ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss,
                                                 epoch, device, args=args, logger=logger)
                    return model, ckpt_path
                else:
                    return model, None

            if model.step_counter % args.log_interval == 1:
                epoch_loss.append(compression_loss.item())
                mean_epoch_loss = np.mean(epoch_loss)

                best_loss = utils.log(model, storage, epoch, idx, mean_epoch_loss,
                                      compression_loss.item(), best_loss, start_time,
                                      epoch_start_time, batch_size=data.shape[0],
                                      avg_bpp=bpp.mean().item(), logger=logger,
                                      writer=train_writer)

                # Cycle through the test loader indefinitely
                # (next(it) replaces the Python 2-style it.next()).
                try:
                    test_data, test_bpp = next(test_loader_iter)
                except StopIteration:
                    test_loader_iter = iter(test_loader)
                    test_data, test_bpp = next(test_loader_iter)

                best_test_loss, epoch_test_loss = test(args, model, epoch, idx, data,
                                                       test_data, test_bpp, device,
                                                       epoch_test_loss, storage_test,
                                                       best_test_loss, start_time,
                                                       epoch_start_time, logger,
                                                       train_writer, test_writer)

                with open(os.path.join(args.storage_save,
                          'storage_{}_tmp.pkl'.format(args.name)), 'wb') as handle:
                    pickle.dump(storage, handle, protocol=pickle.HIGHEST_PROTOCOL)

                # test() may switch to eval mode; restore training mode.
                model.train()

                # LR scheduling
                utils.update_lr(args, amortization_opt, model.step_counter, logger)
                utils.update_lr(args, hyperlatent_likelihood_opt, model.step_counter, logger)
                if model.use_discriminator is True:
                    utils.update_lr(args, disc_opt, model.step_counter, logger)

                if model.step_counter > args.n_steps:
                    logger.info('Reached step limit [args.n_steps = {}]'.format(args.n_steps))
                    break

            if (idx % args.save_interval == 1) and (idx > args.save_interval):
                ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch,
                                             device, args=args, logger=logger)

        # End epoch
        mean_epoch_loss = np.mean(epoch_loss)
        mean_epoch_test_loss = np.mean(epoch_test_loss)

        logger.info('===>> Epoch {} | Mean train loss: {:.3f} | Mean test loss: {:.3f}'.format(
            epoch, mean_epoch_loss, mean_epoch_test_loss))

        if model.step_counter > args.n_steps:
            break

    with open(os.path.join(args.storage_save, 'storage_{}_{:%Y_%m_%d_%H:%M:%S}.pkl'.format(
            args.name, datetime.datetime.now())), 'wb') as handle:
        pickle.dump(storage, handle, protocol=pickle.HIGHEST_PROTOCOL)

    ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch, device,
                                 args=args, logger=logger)
    args.ckpt = ckpt_path
    logger.info('Training complete. Time elapsed: {:.3f} s. Number of steps: {}'.format(
        time.time() - start_time, model.step_counter))

    return model, ckpt_path
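# Neither optimizer helper is defined in this section. Below is a minimal
# sketch of plausible implementations, inferred from the call sites above;
# the exact bodies in the repo may differ (e.g. gradient clipping or
# retain_graph handling).

def optimize_loss(loss, opt):
    # Single-objective update: backprop the loss, step, and clear gradients.
    loss.backward()
    opt.step()
    opt.zero_grad()


def optimize_compression_loss(compression_loss, amortization_opt,
                              hyperlatent_likelihood_opt):
    # The rate-distortion(-perceptual) objective trains both the amortized
    # encoder/decoder parameters and the hyperprior likelihood parameters,
    # so a single backward pass feeds two optimizers.
    compression_loss.backward()
    amortization_opt.step()
    hyperlatent_likelihood_opt.step()
    amortization_opt.zero_grad()
    hyperlatent_likelihood_opt.zero_grad()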
# Variant of train() above with an auxiliary classifier ("classi") head.
# If both versions are kept in one module, this definition shadows the first.


def train(args, model, train_loader, test_loader, device, logger, optimizers, bpp):
    """Training loop extended with a classification objective: tracks per-batch
    classifier loss/accuracy and checkpoints when the mean test classi loss
    improves."""
    start_time = time.time()
    test_loader_iter = iter(test_loader)
    current_D_steps, train_generator = 0, True
    best_loss, best_test_loss, mean_epoch_loss = np.inf, np.inf, np.inf
    train_writer = SummaryWriter(os.path.join(args.tensorboard_runs, 'train'))
    test_writer = SummaryWriter(os.path.join(args.tensorboard_runs, 'test'))
    storage, storage_test = model.storage_train, model.storage_test

    # Per-batch buffers for the classifier loss/accuracy over one epoch
    # (torch.zeros avoids the uninitialized values of torch.Tensor(N)).
    Ntrain = len(train_loader)
    classi_loss_total_train = torch.zeros(Ntrain)
    classi_acc_total_train = torch.zeros(Ntrain)

    classi_opt, amortization_opt, hyperlatent_likelihood_opt = (
        optimizers['classi'], optimizers['amort'], optimizers['hyper'])
    # end_of_epoch_metrics(args, model, train_loader, device, logger)

    if model.use_discriminator is True:
        disc_opt = optimizers['disc']

    best_mean_test_classi_loss_total = np.inf

    for epoch in trange(args.n_epochs, desc='Epoch'):
        epoch_loss, epoch_test_loss = [], []
        epoch_start_time = time.time()

        if epoch > 0:
            ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch,
                                         device, args=args, logger=logger)
        model.train()

        # Running means over the test evaluations performed this epoch.
        test_index = 0
        test_acc_total = 0
        mean_test_acc_total = 0
        test_classi_loss_total = 0
        mean_test_classi_loss_total = np.inf  # guards epochs with no logged test step

        for idx, (data, y) in enumerate(tqdm(train_loader, desc='Train'), 0):
            # if idx == 10:
            #     break
            data = data.to(device, dtype=torch.float)
            y = y.to(device)

            try:
                if model.use_classiOnly is True:
                    # Classifier-only step: skip the generator objective and
                    # update only the classification head.
                    losses = model(data, y, train_generator=False)
                    classi_loss = losses['classi']
                    classi_acc = losses['classi_acc']
                    compression_loss = losses['compression']
                    optimize_loss(classi_loss, classi_opt)
                    classi_loss_total_train[idx] = classi_loss.data
                    classi_acc_total_train[idx] = classi_acc.data
                    model.step_counter += 1
                else:
                    if model.use_discriminator is True:
                        # Train D for D_steps, then G, using distinct batches.
                        losses = model(data, y, train_generator=train_generator)
                        compression_loss = losses['compression']
                        disc_loss = losses['disc']

                        if train_generator is True:
                            optimize_compression_loss(compression_loss, amortization_opt,
                                                      hyperlatent_likelihood_opt)
                            train_generator = False
                        else:
                            optimize_loss(disc_loss, disc_opt)
                            current_D_steps += 1

                            if current_D_steps == args.discriminator_steps:
                                current_D_steps = 0
                                train_generator = True
                            continue
                    else:
                        # Rate, distortion, perceptual only.
                        losses = model(data, y, train_generator=True)
                        compression_loss = losses['compression']
                        optimize_compression_loss(compression_loss, amortization_opt,
                                                  hyperlatent_likelihood_opt)
            except KeyboardInterrupt:
                # Note: saving is not guaranteed!
                if model.step_counter > args.log_interval + 1:
                    logger.warning('Exiting, saving ...')
                    ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss,
                                                 epoch, device, args=args, logger=logger)
                    return model, ckpt_path
                else:
                    return model, None

            if model.step_counter % args.log_interval == 0:
                epoch_loss.append(compression_loss.item())
                mean_epoch_loss = np.mean(epoch_loss)

                # best_loss = utils.log(model, storage, epoch, idx, mean_epoch_loss,
                #                       compression_loss.item(), best_loss, start_time,
                #                       epoch_start_time, batch_size=data.shape[0],
                #                       avg_bpp=bpp, logger=logger, writer=train_writer)

                # Cycle through the test loader indefinitely
                # (next(it) replaces the Python 2-style it.next()).
                try:
                    test_data, ytest = next(test_loader_iter)
                except StopIteration:
                    test_loader_iter = iter(test_loader)
                    test_data, ytest = next(test_loader_iter)
                ytest = ytest.to(device)

                best_test_loss, epoch_test_loss, mean_test_acc, mean_test_classi_loss = test(
                    args, model, epoch, idx, data, y, test_data, ytest, device,
                    epoch_test_loss, storage_test, best_test_loss, start_time,
                    epoch_start_time, logger, train_writer, test_writer)

                # Accumulate running means of the test classi loss/accuracy.
                test_index += 1
                test_classi_loss_total += mean_test_classi_loss
                mean_test_classi_loss_total = test_classi_loss_total / test_index
                test_acc_total += mean_test_acc
                mean_test_acc_total = test_acc_total / test_index

                with open(os.path.join(args.storage_save,
                          'storage_{}_tmp.pkl'.format(args.name)), 'wb') as handle:
                    pickle.dump(storage, handle, protocol=pickle.HIGHEST_PROTOCOL)

                # test() may switch to eval mode; restore training mode.
                model.train()

                if model.step_counter > args.n_steps:
                    logger.info('Reached step limit [args.n_steps = {}]'.format(args.n_steps))
                    break

            # LR scheduling
            if model.use_classiOnly is True:
                utils.update_lr(args, classi_opt, model.step_counter, logger)
            utils.update_lr(args, amortization_opt, model.step_counter, logger)
            utils.update_lr(args, hyperlatent_likelihood_opt, model.step_counter, logger)
            if model.use_discriminator is True:
                utils.update_lr(args, disc_opt, model.step_counter, logger)

        # Checkpoint when the mean test classifier loss improves.
        if mean_test_classi_loss_total < best_mean_test_classi_loss_total:
            logger.info(f'Classi loss decreased to {mean_test_classi_loss_total:.3f}. Saving model.')
            best_mean_test_classi_loss_total = mean_test_classi_loss_total
            ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch,
                                         device, args=args, logger=logger)

        # End epoch
        mean_epoch_loss = np.mean(epoch_loss)
        mean_epoch_test_loss = np.mean(epoch_test_loss)

        logger.info('===>> Epoch {} | Mean train loss: {:.3f} | Mean test loss: {:.3f} | '
                    'Mean test classi acc: {:.3f}'.format(epoch, mean_epoch_loss,
                                                          mean_epoch_test_loss,
                                                          mean_test_acc_total))
        logger.info(f'ClassiLossTrain: mean={classi_loss_total_train.mean():.3f}, '
                    f'std={classi_loss_total_train.std():.3f}')
        logger.info(f'ClassiAccTrain: mean={classi_acc_total_train.mean():.3f}, '
                    f'std={classi_acc_total_train.std():.3f}')
        # end_of_epoch_metrics(args, model, train_loader, device, logger)
        # end_of_epoch_metrics(args, model, test_loader, device, logger)

        if model.step_counter > args.n_steps:
            break

    with open(os.path.join(args.storage_save, 'storage_{}_{:%Y_%m_%d_%H:%M:%S}.pkl'.format(
            args.name, datetime.datetime.now())), 'wb') as handle:
        pickle.dump(storage, handle, protocol=pickle.HIGHEST_PROTOCOL)

    ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch, device,
                                 args=args, logger=logger)
    args.ckpt = ckpt_path
    logger.info('Training complete. Time elapsed: {:.3f} s. Number of steps: {}'.format(
        time.time() - start_time, model.step_counter))

    return model, ckpt_path
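# The discriminator/generator alternation in both loops is easy to misread,
# so its state machine is demonstrated standalone below. This is a
# self-contained toy (no model, no data); `discriminator_steps` mirrors
# args.discriminator_steps. It shows that each generator update is followed
# by `discriminator_steps` discriminator updates, each on a fresh batch
# (the `continue` in the real loop skips logging on D-only steps).

def demo_alternation(n_batches=10, discriminator_steps=2):
    current_D_steps, train_generator = 0, True
    schedule = []
    for idx in range(n_batches):
        if train_generator:
            schedule.append((idx, 'G'))  # generator/compression update
            train_generator = False
        else:
            schedule.append((idx, 'D'))  # discriminator update
            current_D_steps += 1
            if current_D_steps == discriminator_steps:
                current_D_steps = 0
                train_generator = True
    return schedule


if __name__ == '__main__':
    # Expected output for discriminator_steps=2: G D D G D D G D D G
    print(' '.join(role for _, role in demo_alternation()))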