def plot(self, model_name, metric_name, metric_value):
    if model_name not in self._writers:
        self._writers[model_name] = TensorboardLogger(
            os.path.join(self._log_dir, model_name))
    if model_name not in self._steps:
        self._steps[model_name] = Counter()
    self._writers[model_name].log_value(
        metric_name, metric_value,
        step=self._steps[model_name][metric_name])
    self._steps[model_name][metric_name] += 1
    serialize(self._steps_path, self._steps)
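# A minimal sketch of the state that plot() above assumes, inferred only from
# the attributes it references. This __init__ is an illustration, not the
# original class; the class name and the serialize()/steps-file path are
# hypothetical, while Logger.log_value matches the tensorboard_logger package.
import os
from collections import Counter

from tensorboard_logger import Logger as TensorboardLogger


class MetricsPlotter:  # hypothetical name; plot() above would be a method of this class
    def __init__(self, log_dir):
        self._log_dir = log_dir
        self._steps_path = os.path.join(log_dir, "steps.json")  # assumed location
        self._writers = {}  # model_name -> TensorboardLogger, created lazily in plot()
        self._steps = {}    # model_name -> Counter of per-metric step indices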
def __init__(self, config: Configuration):
    self.config = config
    self.tensorboard_logger = TensorboardLogger(self.config.output_directory)
    self.buffer = PrioritizedBuffer(
        capacity=config.replay_capacity,
        epsilon=config.replay_min_priority,
        alpha=config.replay_prioritization_factor,
        max_priority=config.replay_max_priority,
    )
    self.beta = config.replay_importance_weight
    self.stats = LearnerStatistics(self.config, self.tensorboard_logger,
                                   self.buffer)
    learner_address = config.learner_ip_address + ":" + config.starting_port
    self._connect_sockets(learner_address)
def __init__(self, opt):
    self.exp_name = opt['name']
    self.use_tb_logger = opt['use_tb_logger']
    self.opt = opt['logger']
    self.log_dir = opt['path']['log']
    # loss log file
    self.loss_log_path = os.path.join(self.log_dir, 'loss_log.txt')
    with open(self.loss_log_path, "a") as log_file:
        log_file.write('=============== Time: ' + get_timestamp() + ' =============\n')
        log_file.write('================ Training Losses ================\n')
    # val results log file
    self.val_log_path = os.path.join(self.log_dir, 'val_log.txt')
    with open(self.val_log_path, "a") as log_file:
        log_file.write('================ Time: ' + get_timestamp() + ' ===============\n')
        log_file.write('================ Validation Results ================\n')
    if self.use_tb_logger and 'debug' not in self.exp_name:
        from tensorboard_logger import Logger as TensorboardLogger
        self.tb_logger = TensorboardLogger('../tb_logger/' + self.exp_name)
def main():
    args = get_args()
    config = Configuration("./apex/config.json")
    tensorboard_logger = TensorboardLogger(config.output_directory,
                                           args.actor_index)
    actor = Actor(config, args.actor_index, args.starting_port,
                  tensorboard_logger)
    enemy_agents = []
    for _ in range(config.snakes - 1):
        enemy_agents.append(EnemyActor(actor))
    env = BattlesnakeEnvironment(
        config,
        enemy_agents=enemy_agents,
        output_directory=f"{config.output_directory}/actor-{args.actor_index}",
        actor_idx=args.actor_index,
        tensorboard_logger=tensorboard_logger,
    )
    wait_for_initial_parameters(actor)
    while True:
        state = env.reset()
        terminal = False
        while not terminal:
            if env.stats.steps > config.random_initial_steps:
                action, greedy = actor.act(state)
            else:
                action = np.random.choice(3)
                greedy = False
            next_state, reward, terminal = env.step(action)
            actor.observe(
                Observation(state, action, reward, next_state,
                            config.discount_factor, greedy))
            state = next_state
        if env.stats.episodes % config.parameter_update_interval == 0:
            actor.update_parameters()
        if env.stats.episodes % config.report_interval == 0:
            env.stats.report()
        if env.stats.episodes % config.render_interval == 0:
            env.render()
def __init__(self, opt, tb_logger_suffix=''):
    self.exp_name = opt['name']
    self.use_tb_logger = opt['use_tb_logger']
    self.opt = opt['logger']
    self.log_dir = opt['path']['log']
    if not os.path.isdir(self.log_dir):
        os.mkdir(self.log_dir)
    # loss log file
    self.loss_log_path = os.path.join(self.log_dir, 'loss_log.txt')
    with open(self.loss_log_path, 'a') as log_file:
        log_file.write('=============== Time: ' + get_timestamp() + ' =============\n')
        log_file.write('================ Training Losses ================\n')
    # val results log file
    self.val_log_path = os.path.join(self.log_dir, 'val_log.txt')
    with open(self.val_log_path, 'a') as log_file:
        log_file.write('================ Time: ' + get_timestamp() + ' ===============\n')
        log_file.write('================ Validation Results ================\n')
    if self.use_tb_logger:  # and 'debug' not in self.exp_name:
        from tensorboard_logger import Logger as TensorboardLogger
        logger_dir_num = 0
        tb_logger_dir = self.log_dir.replace('experiments', 'logs')
        if not os.path.isdir(tb_logger_dir):
            os.mkdir(tb_logger_dir)
        existing_dirs = sorted(
            [dir.split('_')[0] for dir in os.listdir(tb_logger_dir)
             if os.path.isdir(os.path.join(tb_logger_dir, dir))],
            key=lambda x: int(x.split('_')[0]))
        if len(existing_dirs) > 0:
            logger_dir_num = int(existing_dirs[-1]) + 1
        self.tb_logger = TensorboardLogger(
            os.path.join(tb_logger_dir, str(logger_dir_num) + tb_logger_suffix))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', '-n', type=str)
    args = parser.parse_args()
    experiment_name = args.name

    HYPARAMS = load_json('./hyparams/nec_hyparams.json')[experiment_name]['hyparams']
    logger.debug('experiment_name: {} hyparams: {}'.format(experiment_name, HYPARAMS))

    # make checkpoint path
    experiment_logdir = 'experiments/{}'.format(experiment_name)
    if not os.path.exists(experiment_logdir):
        os.makedirs(experiment_logdir)

    # write to tensorboard
    tensorboard_logdir = '{}/tensorboard'.format(experiment_logdir)
    if not os.path.exists(tensorboard_logdir):
        os.mkdir(tensorboard_logdir)
    writer = TensorboardLogger(logdir=tensorboard_logdir)

    env = gym.make('CartPole-v0')
    agent = NECAgent(input_dim=env.observation_space.shape[0],
                     encode_dim=32,
                     hidden_dim=64,
                     output_dim=env.action_space.n,
                     capacity=HYPARAMS['capacity'],
                     buffer_size=HYPARAMS['buffer_size'],
                     epsilon_start=HYPARAMS['epsilon_start'],
                     epsilon_end=HYPARAMS['epsilon_end'],
                     decay_factor=HYPARAMS['decay_factor'],
                     lr=HYPARAMS['lr'],
                     p=HYPARAMS['p'],
                     similarity_threshold=HYPARAMS['similarity_threshold'],
                     alpha=HYPARAMS['alpha'])

    global_steps = 0
    for episode in range(HYPARAMS['episodes']):
        state = env.reset()
        counter = 0
        while True:
            n_steps_q = 0
            start_state = state
            # N-steps Q estimate
            for step in range(HYPARAMS['horizon']):
                state_tensor = torch.from_numpy(state).float().unsqueeze(0)
                action_tensor, value_tensor, encoded_state_tensor = agent.epsilon_greedy_infer(state_tensor)
                if step == 0:
                    start_action = action_tensor.item()
                    start_encoded_state = encoded_state_tensor
                # env.render()
                if global_steps > HYPARAMS['warmup_steps']:
                    action = action_tensor.item()
                    agent.epsilon_decay()
                else:
                    action = env.action_space.sample()
                logger.debug(
                    'episode: {} global_steps: {} value: {} action: {} state: {} epsilon: {}'
                    .format(episode, global_steps, value_tensor.item(), action,
                            state, agent.epsilon))
                next_state, reward, done, info = env.step(action)
                counter += 1
                global_steps += 1
                writer.log_training_v2(global_steps, {
                    'train/value': value_tensor.item(),
                })
                n_steps_q += (HYPARAMS['gamma'] ** step) * reward
                if done:
                    break
                state = next_state
            n_steps_q += (HYPARAMS['gamma'] ** HYPARAMS['horizon']) * agent.get_target_n_steps_q().item()
            writer.log_training_v2(global_steps, {
                'sampled/n_steps_q': n_steps_q,
            })
            logger.debug('sample n_steps_q: {}'.format(n_steps_q))
            # append to ReplayBuffer and DND
            agent.remember_to_replay_buffer(start_state, start_action, n_steps_q)
            agent.remember_to_dnd(start_encoded_state, start_action, n_steps_q)
            if global_steps / HYPARAMS['horizon'] > HYPARAMS['batch_size']:
                agent.replay(batch_size=HYPARAMS['batch_size'])
            if done:
                # update dnd
                writer.log_episode(episode + 1, counter)
                logger.info('episode done! episode: {} score: {}'.format(episode, counter))
                logger.debug('dnd[0] len: {}'.format(len(agent.dnd_list[0])))
                logger.debug('dnd[1] len: {}'.format(len(agent.dnd_list[1])))
                break
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--noplot', dest='plot', action='store_false',
                        help='Disable PlotReport extension')
    parser.add_argument('--log-dir', default=None,
                        help='directory to output TensorBoard event file '
                             '(default: runs/<DATETIME>)')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    if args.log_dir is None:
        args.log_dir = os.path.join('runs', datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(log_dir=args.log_dir)

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = L.Classifier(MLP(args.unit, 10))
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir
    if args.plot and extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'],
                                  'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    # Write training log to TensorBoard log file
    trainer.extend(TensorboardLogger(writer))

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', '-n', required=True, type=str,
                        help='name of experiment')
    parser.add_argument('--render', action='store_true', help='render gym')
    args = parser.parse_args()
    experiment_name = args.name
    is_render = args.render

    hyparams = load_json('./hyparams/dqn_hyparams.json')[experiment_name]['hyparams']

    # make checkpoint path
    experiment_logdir = 'experiments/{}'.format(experiment_name)
    if not os.path.exists(experiment_logdir):
        os.makedirs(experiment_logdir)

    # hyparameters
    lr = hyparams['lr']
    buffer_size = hyparams['buffer_size']
    gamma = hyparams['gamma']
    epsilon_start = hyparams['epsilon_start']
    epsilon_end = hyparams['epsilon_end']
    decay_factor = hyparams['decay_factor']
    batch_size = hyparams['batch_size']
    replay_freq = hyparams['replay_freq']
    target_update_freq = hyparams['target_update_freq']
    episodes = hyparams['episodes']
    warmup_steps = hyparams['warmup_steps']
    # max_steps = 1e10
    logger.debug('experiment_name: {} hyparams: {}'.format(experiment_name, hyparams))

    # write to tensorboard
    tensorboard_logdir = '{}/tensorboard'.format(experiment_logdir)
    if not os.path.exists(tensorboard_logdir):
        os.mkdir(tensorboard_logdir)
    writer = TensorboardLogger(logdir=tensorboard_logdir)

    env = gym.make('CartPole-v0')
    env.reset()
    # logger.debug('observation_space.shape: {}'.format(env.observation_space.shape))
    agent = DQNAgent(buffer_size,
                     writer=writer,
                     input_dim=env.observation_space.shape[0],
                     output_dim=env.action_space.n,
                     gamma=gamma,
                     epsilon_start=epsilon_start,
                     epsilon_end=epsilon_end,
                     decay_factor=decay_factor)

    state, _, _, _ = env.step(env.action_space.sample())  # take a random action to start with
    writer.add_graph(agent.policy_network,
                     torch.tensor([state], dtype=torch.float32))  # add model graph to tensorboard
    # state, reward, done, info = env.step(env.action_space.sample())  # take a random action to start with
    # for i in range(50):
    #     agent.remember(state, reward, env.action_space.sample(), state, False)
    # for i in range(50):
    #     agent.remember(state, reward, env.action_space.sample(), state, True)
    # loss = agent.replay(batch_size=5)

    global_steps = 0
    for episode in range(episodes):
        score = 0.0
        total_loss = 0.0
        env.reset()
        logger.debug('env.reset() episode {} starts!'.format(episode))
        # update target_network
        if episode % target_update_freq == 0:
            # 1. test replay_bufer
            # logger.debug('step: {} number of samples in bufer: {} sample: {}'.format(step, len(agent.replay_buffer), agent.replay_buffer.get_batch(2)))
            agent.update_target_network()
        for step in count():
            if is_render:
                env.render()
            action_tensor, value_tensor = agent.epsilon_greedy_infer(
                torch.tensor([state], dtype=torch.float32))
            target_value_tensor = agent.evaluate_state(
                torch.tensor([state], dtype=torch.float32))  # temp: for debug
            next_state, reward, done, info = env.step(action_tensor.item())
            # action = env.action_space.sample()
            # next_state, reward, done, info = env.step(action)  # take a random action
            # logger.debug('episode: {} state: {} reward: {} action: {} next_state: {} done: {}'.format(episode, state, reward, action, next_state, done))
            agent.remember(state, reward, action_tensor.item(), next_state, done)
            # 2. test QNetwork
            # logger.debug('state_tensor: {} action_tensor: {} value_tensor: {}'.format(state_tensor, action_tensor, value_tensor))
            # logger.debug('state_tensor: {} action: {} value: {}'.format(state_tensor, action_tensor.item(), value_tensor.item()))
            # print('state: {} reward: {} action_tensor.item(): {} next_state: {} done: {}'.format(state, reward, action_tensor.item(), next_state, done))
            score += reward

            # experience replay
            if global_steps > max(batch_size, warmup_steps) and global_steps % replay_freq == 0:
                loss = agent.replay(batch_size)
                total_loss += loss
                logger.debug('episode: {} done: {} global_steps: {} loss: {}'.format(
                    episode, done, global_steps, loss))
                writer.log_training(global_steps, loss, agent.lr,
                                    value_tensor.item(), target_value_tensor.item())
                writer.add_scalar('epsilon', agent.epsilon, global_steps)  # FIXME
            # if global_steps > max(batch_size, warmup_steps) and global_steps % 1000:
            #     writer.log_linear_weights(global_steps, 'encoder.0.weight', agent.policy_network.get_weights()['encoder.0.weight'])

            agent.epsilon_decay()
            state = next_state  # update state manually
            global_steps += 1

            if done:
                logger.info('episode done! episode: {} score: {}'.format(episode, score))
                writer.log_episode(episode, score, total_loss / (step + 1))
                # save checkpoints
                if global_steps > max(batch_size, warmup_steps) and episode % 100 == 0:
                    agent.save_checkpoint(experiment_logdir)
                break
    # logger.debug('state_tensor: {} action_tensor: {} value_tensor: {}'.format(state_tensor, action_tensor, value_tensor))
    # logger.debug('output: {} state_tensor: {} state: {}'.format(output, state_tensor, state))
    # agent.remember(state, reward, action, next_state, done)
    env.close()
def train_generator(self, current_loop_num):
    BaseLayer.set_model_parameter_requires_grad_all(self.generator, True)
    BaseLayer.set_model_parameter_requires_grad_all(self.discriminator, False)

    # train generator
    # TensorboardLogger.print_parameter(generator)
    for index in range(0, self.opt.generator_train_num):
        train_z = self.Tensor(
            np.random.normal(loc=0, scale=1,
                             size=(self.opt.batch_size, self.opt.latent_dim)))
        fake_imgs, fake_dlatents_out = self.generator(train_z)
        fake_validity = self.discriminator(fake_imgs)
        prob_fake = F.sigmoid(fake_validity).mean()
        TensorboardLogger.write_scalar('prob_fake/generator', prob_fake)
        # print('{} prob_fake(generator): {}'.format(index, prob_fake))
        g_loss = self.generator_loss(fake_validity)
        self.optimizer_g.zero_grad()
        g_loss.backward()
        self.optimizer_g.step()

    run_g_reg = current_loop_num % self.opt.g_reg_interval == 0
    if run_g_reg:
        # Generator regularization (path length regularization)
        g_reg_maxcount = 4 if 4 < self.opt.generator_train_num else self.opt.generator_train_num
        for _ in range(0, g_reg_maxcount):
            z = self.Tensor(
                np.random.normal(loc=0, scale=1,
                                 size=(self.opt.batch_size, self.opt.latent_dim)))
            pl_fake_imgs, pl_fake_dlatents_out = self.generator(z)
            g_reg, pl_length = self.generator_loss_path_reg(pl_fake_imgs,
                                                            pl_fake_dlatents_out)
            self.optimizer_g.zero_grad()
            g_reg.backward()
            self.optimizer_g.step()
            TensorboardLogger.write_scalar('loss/g_reg', g_reg)
            TensorboardLogger.write_scalar('loss/path_length', pl_length)
            TensorboardLogger.write_scalar(
                'loss/pl_mean_var',
                self.generator_loss_path_reg.pl_mean_var.mean())

    # Apply exponential-moving-average weights to the generator used for inference
    Generator.apply_decay_parameters(self.generator, self.generator_predict,
                                     decay=self.decay)
    fake_imgs_predict, fake_dlatents_out_predict = self.generator_predict(train_z)
    fake_predict_validity = self.discriminator(fake_imgs_predict)
    prob_fake_predict = F.sigmoid(fake_predict_validity).mean()
    TensorboardLogger.write_scalar('prob_fake_predict/generator', prob_fake_predict)
    # print('prob_fake_predict(generator): {}'.format(prob_fake_predict))
    Generator.apply_decay_parameters(self.generator_predict, self.generator,
                                     decay=self.opt.reverse_decay)

    if current_loop_num % self.opt.save_metrics_interval == 0:
        TensorboardLogger.write_scalar('score/g_score', fake_validity.mean())
        TensorboardLogger.write_scalar('loss/g_loss', g_loss)
        TensorboardLogger.write_histogram('generator/fake_imgs', fake_imgs)
        TensorboardLogger.write_histogram('generator/fake_dlatents_out',
                                          fake_dlatents_out)
        TensorboardLogger.write_histogram('generator/fake_imgs_predict',
                                          fake_imgs_predict)
        TensorboardLogger.write_histogram('generator/fake_dlatents_out_predict',
                                          fake_dlatents_out_predict)

    if current_loop_num % self.opt.save_images_tensorboard_interval == 0:
        # for index in range(fake_imgs.shape[0]):
        #     img = adjust_dynamic_range(fake_imgs[index].to('cpu').detach().numpy(),
        #                                drange_in=[-1, 1], drange_out=[0, 255])
        #     TensorboardLogger.write_image('images/fake/{}'.format(index), img)
        for index in range(fake_imgs_predict.shape[0]):
            img = adjust_dynamic_range(
                fake_imgs_predict[index].to('cpu').detach().numpy(),
                drange_in=[-1, 1], drange_out=[0, 255])
            TensorboardLogger.write_image('images/fake_predict/{}'.format(index), img)

    if current_loop_num % self.opt.save_images_interval == 0:
        # Save the generated images
        if not os.path.isdir(self.opt.results):
            os.makedirs(self.opt.results, exist_ok=True)
        # fake_imgs_val, fake_dlatents_out_val = generator(val_z)
        # save_image_grid(
        #     fake_imgs.to('cpu').detach().numpy(),
        #     os.path.join(self.opt.results,
        #                  '{}_fake.png'.format(TensorboardLogger.global_step)),
        #     batch_size=self.opt.batch_size,
        #     drange=[-1, 1])
        # fake_imgs_predict_val, fake_dlatents_out_predict_val = generator_predict(val_z)
        save_image_grid(fake_imgs_predict.to('cpu').detach().numpy(),
                        os.path.join(
                            self.opt.results,
                            '{}_fake_predict.png'.format(TensorboardLogger.global_step)),
                        batch_size=self.opt.batch_size,
                        drange=[-1, 1])
    return g_loss
def calculate_fid_score(self):
    fid_score = self.fid.get_score()
    TensorboardLogger.write_scalar('score/fid', fid_score)
def train_discriminator(self, current_loop_num):
    BaseLayer.set_model_parameter_requires_grad_all(self.generator, False)
    BaseLayer.set_model_parameter_requires_grad_all(self.discriminator, True)

    # train discriminator
    for index in range(0, self.opt.discriminator_train_num):
        data_iterator = self.dataloader.__iter__()
        imgs = data_iterator.next()
        # imgs = TranformDynamicRange.fade_lod(x=imgs, lod=0.0)
        # imgs = TranformDynamicRange.upscale_lod(x=imgs, lod=0.0)
        real_imgs = Variable(imgs.type(self.Tensor), requires_grad=False)
        z = self.Tensor(
            np.random.normal(loc=0, scale=1,
                             size=(self.opt.batch_size, self.opt.latent_dim)))
        fake_imgs, fake_dlatents_out = self.generator(z)
        real_validity = self.discriminator(real_imgs)
        prob_real = F.sigmoid(real_validity).mean()
        TensorboardLogger.write_scalar('prob_real/discriminator', prob_real)
        # print('{} prob_real(discriminator): {}'.format(index, prob_real))
        fake_validity = self.discriminator(fake_imgs)
        prob_fake = F.sigmoid(fake_validity).mean()
        TensorboardLogger.write_scalar('prob_fake/discriminator', prob_fake)
        # print('{} prob_fake(discriminator): {}'.format(index, prob_fake))
        d_loss = self.discriminator_loss(fake_validity, real_validity)
        self.optimizer_d.zero_grad()
        d_loss.backward()
        self.optimizer_d.step()

    run_d_reg = current_loop_num % self.opt.d_reg_interval == 0
    if run_d_reg:
        d_reg_maxcount = 4 if 4 < self.opt.discriminator_train_num else self.opt.discriminator_train_num
        for index in range(0, d_reg_maxcount):
            # Discriminator regularization (R1 penalty on real images)
            # z = self.Tensor(np.random.normal(loc=0, scale=1, size=(self.opt.batch_size, self.opt.latent_dim)))
            # fake_imgs, fake_dlatents_out = self.generator(z)
            # fake_validity = self.discriminator(fake_imgs)
            real_imgs.requires_grad = True
            real_validity = self.discriminator(real_imgs)
            d_reg = self.discriminator_loss_r1(real_validity, real_imgs)
            self.optimizer_d.zero_grad()
            d_reg.backward()
            self.optimizer_d.step()
            TensorboardLogger.writer.add_scalar(
                '{}/reg/d_reg'.format(TensorboardLogger.now), d_reg,
                TensorboardLogger.global_step)

    if current_loop_num % self.opt.save_metrics_interval == 0:
        TensorboardLogger.write_scalar('score/d_score', real_validity.mean())
        TensorboardLogger.write_scalar('loss/d_loss', d_loss)
        TensorboardLogger.write_histogram('real_imgs', real_imgs)
    return d_loss
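# The StyleGAN-style snippets above use TensorboardLogger as a class with static
# helpers (write_scalar, write_histogram, write_image) and class attributes
# (writer, now, global_step). Below is a minimal sketch of such a wrapper,
# assuming it sits on top of torch.utils.tensorboard.SummaryWriter; the details
# are illustrative assumptions, not the project's actual implementation.
from datetime import datetime

from torch.utils.tensorboard import SummaryWriter


class TensorboardLogger:
    now = datetime.now().strftime('%Y%m%d-%H%M%S')  # run tag used as a prefix
    writer = SummaryWriter()
    global_step = 0  # assumed to be advanced once per training iteration elsewhere

    @classmethod
    def write_scalar(cls, tag, value):
        cls.writer.add_scalar('{}/{}'.format(cls.now, tag), value, cls.global_step)

    @classmethod
    def write_histogram(cls, tag, values):
        cls.writer.add_histogram('{}/{}'.format(cls.now, tag), values, cls.global_step)

    @classmethod
    def write_image(cls, tag, img):
        # img is expected as a CHW array, e.g. the output of adjust_dynamic_range above.
        cls.writer.add_image('{}/{}'.format(cls.now, tag), img, cls.global_step)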
def train_lstm(
    model: LstmModel,
    criterion: torch.nn.modules.loss,
    optimizer: torch.optim,
    train_loader: torch.utils.data.dataloader.DataLoader,
    val_loader: torch.utils.data.dataloader.DataLoader,
    device: str,
    verbose: bool,
    n_epochs: int,
    kwargs_writer: Dict[str, str] = None,
) -> None:
    """Train an LstmModel and log fit/validation losses to TensorBoard.

    Parameters
    ----------
    model : LstmModel
        Network to train.
    criterion : torch.nn.modules.loss
        Loss function applied to the model outputs.
    optimizer : torch.optim
        Optimizer used to update the model parameters.
    train_loader : torch.utils.data.dataloader.DataLoader
        Loader yielding training chunks with "data" and "target" keys.
    val_loader : torch.utils.data.dataloader.DataLoader
        Loader yielding validation chunks with "data" and "target" keys.
    device : str
        Device to run on, e.g. "cpu" or "cuda".
    verbose : bool
        If True, print the average fit and validation loss after each epoch.
    n_epochs : int
        Number of training epochs.
    kwargs_writer : Dict[str, str]
        Keyword arguments forwarded to the TensorboardLogger.
    """
    model = model.to(device)
    dict_loader = {"fit": train_loader, "val": val_loader}
    writer = TensorboardLogger(kwargs_writer)
    global_step_fit = 0
    glob_step_val = 0
    for epoch in tqdm_notebook(range(1, n_epochs + 1)):
        ###################
        # train the model #
        ###################
        epoch_losses = {}
        for phase in ["fit", "val"]:
            # Accumulate the loss over the whole phase, not per chunk,
            # so the epoch average below is computed over the full dataset.
            total_loss = 0.0
            for chunk in dict_loader[phase]:
                data = chunk["data"].to(device)
                target = chunk["target"].to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == "fit"):
                    outputs = model(data)
                    loss = criterion(outputs, target)
                    if phase == "fit":
                        loss.backward()
                        optimizer.step()
                        writer.add(
                            fit_loss=loss.item(),
                            val_loss=None,
                            model_for_gradient=None,
                            step=global_step_fit,
                        )
                        global_step_fit += 1
                    else:
                        writer.add(
                            fit_loss=None,
                            val_loss=loss.item(),
                            model_for_gradient=None,
                            step=glob_step_val,
                        )
                        glob_step_val += 1
                total_loss += loss.item() * data.size(0)
            epoch_losses.update({
                f"{phase} loss": total_loss / len(dict_loader[phase].dataset)
            })
            # if phase == "fit":
            #     writer.add(fit_loss=None, val_loss=None, model_for_gradient=model)
        # print avg training statistics
        if verbose:
            print(f'Fit loss: {epoch_losses["fit loss"]:.4f} '
                  f'and Val loss: {epoch_losses["val loss"]:.4f}')
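# train_lstm above relies on a TensorboardLogger whose add() method accepts
# fit_loss, val_loss, model_for_gradient, and step keywords. A minimal sketch of
# that interface, assuming a torch SummaryWriter underneath; the class below and
# its scalar tags are assumptions for illustration, not the original code.
from typing import Dict, Optional

from torch.utils.tensorboard import SummaryWriter


class TensorboardLogger:
    def __init__(self, kwargs_writer: Optional[Dict[str, str]] = None):
        # Forward e.g. {"log_dir": "runs/lstm"} straight to SummaryWriter.
        self.writer = SummaryWriter(**(kwargs_writer or {}))

    def add(self, fit_loss=None, val_loss=None, model_for_gradient=None, step=0):
        if fit_loss is not None:
            self.writer.add_scalar("loss/fit", fit_loss, step)
        if val_loss is not None:
            self.writer.add_scalar("loss/val", val_loss, step)
        if model_for_gradient is not None:
            # Log parameter gradients as histograms, one per named parameter.
            for name, param in model_for_gradient.named_parameters():
                if param.grad is not None:
                    self.writer.add_histogram(f"grad/{name}", param.grad, step)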