def backward_G_B(self):
    """Accumulate the B-domain generator losses and backpropagate them.

    Combines adversarial, reconstruction and mask losses, plus optional
    perceptual / edge / eye losses gated by ``self.loss_config``, into
    ``self.loss_G_B`` and calls ``backward(retain_graph=True)`` so the
    graph can be reused by a subsequent backward pass.
    """
    # Adversarial loss on B-domain discriminator predictions.
    # BUGFIX: the original passed ``self.outputApred`` (the A-domain
    # output prediction) here — an apparent copy-paste slip from
    # backward_G_A, which uses the A-domain pair (fakeApred, outputApred).
    # The symmetric B-domain counterpart is ``self.outputBpred``.
    self.loss_G_adversarial_B = loss.adversarial_loss_generator(
        self.fakeBpred,
        self.outputBpred,
        method='L2',
        loss_weight_config=self.loss_weight_config)
    # L1 reconstruction between the blended output and the real image.
    self.loss_G_reconstruction_B = loss.reconstruction_loss(
        self.outputB,
        self.realB,
        method='L1',
        loss_weight_config=self.loss_weight_config)
    # Penalize mask values relative to the configured threshold.
    self.loss_G_mask_B = loss.mask_loss(
        self.maskB,
        threshold=self.loss_config['mask_threshold'],
        method='L1',
        loss_weight_config=self.loss_weight_config)
    self.loss_G_B = (self.loss_G_adversarial_B
                     + self.loss_G_reconstruction_B
                     + self.loss_G_mask_B)
    # Optional VGG-face perceptual loss.
    if self.loss_config['pl_on']:
        self.loss_G_perceptual_B = loss.perceptual_loss(
            self.realB,
            self.fakeB,
            self.vggface,
            self.vggface_for_pl,
            method='L2',
            loss_weight_config=self.loss_weight_config)
        self.loss_G_B += self.loss_G_perceptual_B
    # Optional edge loss over the eye-mask region.
    if self.loss_config['edgeloss_on']:
        self.loss_G_edge_B = loss.edge_loss(
            self.outputB,
            self.realB,
            self.mask_eye_B,
            method='L1',
            loss_weight_config=self.loss_weight_config)
        self.loss_G_B += self.loss_G_edge_B
    # Optional eye-region loss.
    if self.loss_config['eyeloss_on']:
        self.loss_G_eye_B = loss.eye_loss(
            self.outputB,
            self.realB,
            self.mask_eye_B,
            method='L1',
            loss_weight_config=self.loss_weight_config)
        self.loss_G_B += self.loss_G_eye_B
    # retain_graph=True: the same forward graph is backpropagated again
    # elsewhere (e.g. discriminator update).
    self.loss_G_B.backward(retain_graph=True)
def backward_G_A(self):
    """Accumulate the A-domain generator losses and backpropagate them.

    Sums adversarial, reconstruction and mask losses, plus any optional
    losses enabled in ``self.loss_config``, into ``self.loss_G_A`` and
    calls ``backward(retain_graph=True)``.
    """
    lw = self.loss_weight_config
    flags = self.loss_config

    # Core losses: adversarial, reconstruction, mask.
    self.loss_G_adversarial_A = loss.adversarial_loss_generator(
        self.fakeApred, self.outputApred, method='L2',
        loss_weight_config=lw)
    self.loss_G_reconstruction_A = loss.reconstruction_loss(
        self.outputA, self.realA, method='L1', loss_weight_config=lw)
    self.loss_G_mask_A = loss.mask_loss(
        self.maskA, threshold=flags['mask_threshold'], method='L1',
        loss_weight_config=lw)
    self.loss_G_A = (self.loss_G_adversarial_A
                     + self.loss_G_reconstruction_A
                     + self.loss_G_mask_A)

    # Optional perceptual loss via VGG-face features.
    if flags['pl_on']:
        self.loss_G_perceptual_A = loss.perceptual_loss(
            self.realA, self.fakeA, self.vggface, self.vggface_for_pl,
            method='L2', loss_weight_config=lw)
        self.loss_G_A += self.loss_G_perceptual_A

    # Optional edge loss over the eye-mask region.
    if flags['edgeloss_on']:
        self.loss_G_edge_A = loss.edge_loss(
            self.outputA, self.realA, self.mask_eye_A, method='L1',
            loss_weight_config=lw)
        self.loss_G_A += self.loss_G_edge_A

    # Optional eye-region loss.
    if flags['eyeloss_on']:
        self.loss_G_eye_A = loss.eye_loss(
            self.outputA, self.realA, self.mask_eye_A, method='L1',
            loss_weight_config=lw)
        self.loss_G_A += self.loss_G_eye_A

    # Graph is reused by a later backward pass, hence retain_graph=True.
    self.loss_G_A.backward(retain_graph=True)
def forward(self, content, style, alpha=1.0):
    """Stylize *content* with *style*; also return losses when training.

    In eval mode only the stylized image is returned.  In training mode
    the return value is a tuple:
    (stylized, content_loss, style_loss, identity_loss or 0,
     contextual_loss or 0, total_variation).
    """
    f_style = self.encode_with_intermediate(style)
    f_content = self.encode_with_intermediate(content)
    # Feature pyramids are built from the three deepest encoder levels.
    hid_content = self.feature_pyramid(f_content[-3:])
    hid_style = self.feature_pyramid(f_style[-3:])
    cs, f_cs = self.pair_inference(f_content, f_style,
                                   hid_content, hid_style)

    # Inference-only path: just the stylized image.
    if not self.training:
        return cs

    out = [
        cs,
        loss.perceptual_loss(f_cs[-3:], f_content[-3:]),  # content loss
        loss.adain_style_loss(f_cs, f_style),             # style loss
    ]

    if self.use_iden:
        # Self-to-self passes for the identity constraint.
        cc, f_cc = self.pair_inference(f_content, f_content,
                                       hid_content, hid_content, True)
        ss, f_ss = self.pair_inference(f_style, f_style,
                                       hid_style, hid_style, True)
        out.append(loss.identity_loss(cc, f_cc, content, f_content,
                                      ss, f_ss, style, f_style, 50))
    else:
        out.append(0)

    out.append(loss.contextual_loss(f_cs, f_style) if self.use_cx else 0)
    out.append(loss.total_variation(cs))
    return tuple(out)
def train(cfg):
    """Train the feed-forward style-transfer network described by ``cfg``.

    ``cfg`` is a dict-like config of paths, checkpoint intervals and
    hyper-parameters.  Side effects: writes log/CSV files, progress
    images and checkpoints under the configured directories.
    """
    # Set device if gpu is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build network
    net = ImageTransformationNet().to(device)
    # Setup optimizer
    optimizer = optim.Adam(net.parameters())
    # Load state if resuming training
    if cfg['resume']:
        checkpoint = torch.load(cfg['resume'])
        net.load_state_dict(checkpoint['net_state_dict'])
        optimizer.load_state_dict(checkpoint['opt_state_dict'])
        # Get starting epoch and batch (expects weight file in form EPOCH_<>_BATCH_<>.pt)
        parts = cfg['resume'].split('_')
        # NOTE(review): training resumes AT checkpoint['epoch'], so the
        # saved epoch is re-run from first_batch — confirm intended.
        first_epoch = int(checkpoint['epoch'])
        first_batch = int(parts[-1].split('.')[0])
        # Setup dataloader; `initial` only offsets the progress bar display.
        train_data = tqdm(build_data_loader(cfg), initial=first_batch)
    else:
        # Setup dataloader
        train_data = tqdm(build_data_loader(cfg))
        # Set first epoch and batch
        first_epoch = 1
        first_batch = 0
    # Fetch style image and style grams
    style_im = load_image(cfg['style_image'], cfg)
    style_grams = get_style_grams(style_im, cfg)
    # Setup log file if specified
    log_dir = Path('logs')
    log_dir.mkdir(parents=True, exist_ok=True)
    if cfg['log_file'] and not cfg['resume']:
        today = datetime.datetime.today().strftime('%m/%d/%Y')
        header = f'Feed-Forward Style Transfer Training Log - {today}'
        with open(cfg['log_file'], 'w+') as file:
            file.write(header + '\n\n')
    # Setup log CSV if specified
    if cfg['csv_log_file'] and not cfg['resume']:
        utils.setup_csv(cfg)
    for epoch in range(first_epoch, cfg['epochs'] + 1):
        # Keep track of per epoch loss
        content_loss = 0
        style_loss = 0
        total_var_loss = 0
        train_loss = 0
        num_batches = 0
        # Setup first batch to start enumerate at proper place
        if epoch == first_epoch:
            start = first_batch
        else:
            start = 0
        # NOTE(review): enumerate(start=...) only offsets the reported
        # index; the dataloader still yields from its first batch, so
        # resuming does not skip already-seen data — confirm intended.
        for i, batch in enumerate(train_data, start=start):
            batch = batch.to(device)
            # Put batch through network
            batch_styled = net(batch)
            # Get vgg activations for styled and unstyled batch
            features = vgg_activations(batch_styled)
            content_features = vgg_activations(batch)
            # Get loss (content and style terms against precomputed grams)
            c_loss, s_loss = perceptual_loss(features=features,
                                             content_features=content_features,
                                             style_grams=style_grams,
                                             cfg=cfg)
            tv_loss = total_variation_loss(batch_styled, cfg)
            total_loss = c_loss + s_loss + tv_loss
            # Backpropagate
            total_loss.backward()
            # Do one step of optimization
            optimizer.step()
            # Clear gradients before next batch
            optimizer.zero_grad()
            # Update summary statistics (no autograd bookkeeping needed)
            with torch.no_grad():
                content_loss += c_loss.item()
                style_loss += s_loss.item()
                total_var_loss += tv_loss.item()
                train_loss += total_loss.item()
                num_batches += 1
                # Update progress bar with running per-epoch averages
                avg_loss = round(train_loss / num_batches, 2)
                avg_c_loss = round(content_loss / num_batches, 2)
                avg_s_loss = round(style_loss / num_batches, 1)
                avg_tv_loss = round(total_var_loss / num_batches, 3)
                train_data.set_description(
                    f'C - {avg_c_loss} | S - {avg_s_loss} | TV - {avg_tv_loss} | Total - {avg_loss}'
                )
                train_data.refresh()
            # Create progress image if specified
            if cfg['image_checkpoint'] and ((i + 1) % cfg['image_checkpoint'] == 0):
                save_path = str(
                    Path(
                        cfg['image_checkpoint_dir'],
                        f'EPOCH_{str(epoch).zfill(3)}_BATCH_{str(i+1).zfill(5)}.png'
                    ))
                utils.make_checkpoint_image(cfg, net, save_path)
            # Save weights if specified
            if cfg['save_checkpoint'] and ((i + 1) % cfg['save_checkpoint'] == 0):
                save_path = str(
                    Path(
                        cfg['save_checkpoint_dir'],
                        f'EPOCH_{str(epoch).zfill(3)}_BATCH_{str(i+1).zfill(5)}.pth'
                    ))
                checkpoint = {
                    'epoch': epoch,
                    'net_state_dict': net.state_dict(),
                    'opt_state_dict': optimizer.state_dict(),
                    'loss': avg_loss
                }
                torch.save(checkpoint, save_path)
            # Write progress row to CSV
            if cfg['csv_checkpoint'] and ((i + 1) % cfg['csv_checkpoint'] == 0):
                row = [
                    epoch, i + 1, avg_c_loss, avg_s_loss, avg_tv_loss,
                    avg_loss
                ]
                utils.write_progress_row(cfg, row)
        # Write loss at end of each epoch
        if cfg['log_file']:
            avg_loss = round(train_loss / num_batches, 4)
            line = f'EPOCH {epoch} | Loss - {avg_loss}'
            with open(cfg['log_file'], 'a') as file:
                file.write(line + '\n')
        # Save network if specified (every epoch_save_checkpoint epochs)
        if cfg['epoch_save_checkpoint'] and (
                epoch % cfg['epoch_save_checkpoint'] == 0):
            save_path = str(
                Path(cfg['save_checkpoint_dir'],
                     f'EPOCH_{str(epoch).zfill(3)}.pth'))
            checkpoint = {
                'epoch': epoch,
                'net_state_dict': net.state_dict(),
                'opt_state_dict': optimizer.state_dict(),
                'loss': round(train_loss / num_batches, 4)
            }
            torch.save(checkpoint, save_path)
def train(args):
    """Train a first-order-motion-style model with NNabla.

    Builds the keypoint detector / generator / discriminator graphs,
    assembles the configured loss terms, then runs the multi-device
    training loop with monitoring and periodic parameter saving.
    ``args`` provides the extension context, config path, optional info
    string and optional finetuning parameters.
    """
    # get context
    ctx = get_extension_context(args.context)
    # NOTE: the "Paralell" spelling appears to match nnabla's actual
    # communicator class name — do not "fix" it.
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    config = read_yaml(args.config)

    if args.info:
        config.monitor_params.info = args.info

    if comm.size == 1:
        comm = None
    else:
        # disable outputs from logger except its rank = 0
        if comm.rank > 0:
            import logging
            logger.setLevel(logging.ERROR)

    test = False
    train_params = config.train_params
    dataset_params = config.dataset_params
    model_params = config.model_params

    loss_flags = get_loss_flags(train_params)

    start_epoch = 0
    # Per-device RNG so each rank gets a distinct data slice / augmentation.
    rng = np.random.RandomState(device_id)
    data_iterator = frame_data_iterator(
        root_dir=dataset_params.root_dir,
        frame_shape=dataset_params.frame_shape,
        id_sampling=dataset_params.id_sampling,
        is_train=True,
        random_seed=rng,
        augmentation_params=dataset_params.augmentation_params,
        batch_size=train_params['batch_size'],
        shuffle=True,
        with_memory_cache=False,
        with_file_cache=False)

    if n_devices > 1:
        # Shard the dataset across ranks.
        data_iterator = data_iterator.slice(rng=rng,
                                            num_of_slices=comm.size,
                                            slice_pos=comm.rank)
        # workaround not to use memory cache
        data_iterator._data_source._on_memory = False
        logger.info("Disabled on memory data cache.")

    bs, h, w, c = [train_params.batch_size] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        # kp_X = {"value": Variable((bs, 10, 2)), "jacobian": Variable((bs, 10, 2, 2))}
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=test, comm=comm)
        persistent_all(kp_source)

        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=test, comm=comm)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=kp_source,
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=test, comm=comm)
    # generated is a dictionary containing;
    # 'mask': Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    # 'sparse_deformed': Variable((bs, num_kp + 1, num_channel, h/4, w/4))
    # 'occlusion_map': Variable((bs, 1, h/4, w/4))
    # 'deformed': Variable((bs, c, h, w))
    # 'prediction': Variable((bs, c, h, w)) Only this is fed to discriminator.
    generated["prediction"].persistent = True

    # Multi-scale image pyramids for the real and generated frames.
    pyramide_real = get_image_pyramid(driving, train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_real)

    pyramide_fake = get_image_pyramid(generated['prediction'],
                                      train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_fake)

    total_loss_G = None  # dummy; defined temporarily, filled in below
    loss_var_dict = {}

    # perceptual loss using VGG19 (always applied)
    if loss_flags.use_perceptual_loss:
        logger.info("Use Perceptual Loss.")
        scales = train_params.scales
        weights = train_params.loss_weights.perceptual
        vgg_param_path = train_params.vgg_param_path
        percep_loss = perceptual_loss(pyramide_real, pyramide_fake,
                                      scales, weights, vgg_param_path)
        percep_loss.persistent = True
        loss_var_dict['perceptual_loss'] = percep_loss
        total_loss_G = percep_loss

    # (LS)GAN loss and feature matching loss
    if loss_flags.use_gan_loss:
        logger.info("Use GAN Loss.")
        with nn.parameter_scope("discriminator"):
            discriminator_maps_generated = multiscale_discriminator(
                pyramide_fake,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test, comm=comm)

            discriminator_maps_real = multiscale_discriminator(
                pyramide_real,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test, comm=comm)

        # Keep intermediate maps persistent so they survive forward
        # buffer clearing (used again by feature matching / monitors).
        for v in discriminator_maps_generated["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_generated["prediction_map_1"].persistent = True
        for v in discriminator_maps_real["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_real["prediction_map_1"].persistent = True

        for i, scale in enumerate(model_params.discriminator_params.scales):
            key = f'prediction_map_{scale}'.replace('.', '-')
            lsgan_loss_weight = train_params.loss_weights.generator_gan
            # LSGAN loss for Generator
            if i == 0:
                gan_loss_gen = lsgan_loss(discriminator_maps_generated[key],
                                          lsgan_loss_weight)
            else:
                gan_loss_gen += lsgan_loss(discriminator_maps_generated[key],
                                           lsgan_loss_weight)

            # LSGAN loss for Discriminator
            if i == 0:
                gan_loss_dis = lsgan_loss(discriminator_maps_real[key],
                                          lsgan_loss_weight,
                                          discriminator_maps_generated[key])
            else:
                gan_loss_dis += lsgan_loss(discriminator_maps_real[key],
                                           lsgan_loss_weight,
                                           discriminator_maps_generated[key])

        gan_loss_dis.persistent = True
        loss_var_dict['gan_loss_dis'] = gan_loss_dis
        total_loss_D = gan_loss_dis
        total_loss_D.persistent = True

        gan_loss_gen.persistent = True
        loss_var_dict['gan_loss_gen'] = gan_loss_gen
        total_loss_G += gan_loss_gen

        if loss_flags.use_feature_matching_loss:
            logger.info("Use Feature Matching Loss.")
            fm_weights = train_params.loss_weights.feature_matching
            fm_loss = feature_matching_loss(discriminator_maps_real,
                                            discriminator_maps_generated,
                                            model_params, fm_weights)
            fm_loss.persistent = True
            loss_var_dict['feature_matching_loss'] = fm_loss
            total_loss_G += fm_loss

    # transform loss (equivariance constraints on keypoints)
    if loss_flags.use_equivariance_value_loss or loss_flags.use_equivariance_jacobian_loss:
        transform = Transform(bs, **config.train_params.transform_params)
        transformed_frame = transform.transform_frame(driving)

        with nn.parameter_scope("kp_detector"):
            transformed_kp = detect_keypoint(transformed_frame,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=test, comm=comm)
        persistent_all(transformed_kp)

        # Value loss part
        if loss_flags.use_equivariance_value_loss:
            logger.info("Use Equivariance Value Loss.")
            warped_kp_value = transform.warp_coordinates(
                transformed_kp['value'])
            eq_value_weight = train_params.loss_weights.equivariance_value
            eq_value_loss = equivariance_value_loss(kp_driving['value'],
                                                    warped_kp_value,
                                                    eq_value_weight)
            eq_value_loss.persistent = True
            loss_var_dict['equivariance_value_loss'] = eq_value_loss
            total_loss_G += eq_value_loss

        # jacobian loss part
        if loss_flags.use_equivariance_jacobian_loss:
            logger.info("Use Equivariance Jacobian Loss.")
            arithmetic_jacobian = transform.jacobian(transformed_kp['value'])
            eq_jac_weight = train_params.loss_weights.equivariance_jacobian
            eq_jac_loss = equivariance_jacobian_loss(
                kp_driving['jacobian'],
                arithmetic_jacobian,
                transformed_kp['jacobian'],
                eq_jac_weight)
            eq_jac_loss.persistent = True
            loss_var_dict['equivariance_jacobian_loss'] = eq_jac_loss
            total_loss_G += eq_jac_loss

    # At least one generator loss must have been enabled.
    assert total_loss_G is not None
    total_loss_G.persistent = True
    loss_var_dict['total_loss_gen'] = total_loss_G

    # -------------------- Create Monitors --------------------
    monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir = get_monitors(
        config, loss_flags, loss_var_dict)

    if device_id == 0:
        # Dump training info .yaml
        _ = shutil.copy(args.config, log_dir)  # copy the config yaml
        training_info_yaml = os.path.join(log_dir, "training_info.yaml")
        os.rename(os.path.join(log_dir, os.path.basename(args.config)),
                  training_info_yaml)
        # then add additional information
        with open(training_info_yaml, "a", encoding="utf-8") as f:
            f.write(f"\nlog_dir: {log_dir}\nsaved_parameter: None")

    # -------------------- Solver Setup --------------------
    solvers = setup_solvers(train_params)
    solver_generator = solvers["generator"]
    solver_discriminator = solvers["discriminator"]
    solver_kp_detector = solvers["kp_detector"]

    # max epochs
    num_epochs = train_params['num_epochs']

    # iteration per epoch
    num_iter_per_epoch = data_iterator.size // bs
    # will be increased by num_repeat
    # NOTE(review): with 'or', a missing 'num_repeats' key falls through
    # to the subscript on the right (likely KeyError); 'and' looks
    # intended — confirm against the config loader's missing-key behavior.
    if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
        num_iter_per_epoch *= config.train_params.num_repeats

    # modify learning rate if current epoch exceeds the number defined in
    lr_decay_at_epochs = train_params['epoch_milestones']  # ex. [60, 90]
    gamma = 0.1  # decay rate

    # -------------------- For finetuning ---------------------
    if args.ft_params:
        assert os.path.isfile(args.ft_params)
        logger.info(f"load {args.ft_params} for finetuning.")
        nn.load_parameters(args.ft_params)
        # Epoch index is parsed from the parameter file name
        # (expects "...epoch_<N>...").
        start_epoch = int(
            os.path.splitext(os.path.basename(
                args.ft_params))[0].split("epoch_")[1])

        # set solver's state
        for name, solver in solvers.items():
            saved_states = os.path.join(
                os.path.dirname(args.ft_params),
                f"state_{name}_at_epoch_{start_epoch}.h5")
            solver.load_states(saved_states)

        start_epoch += 1
        logger.info(f"Resuming from epoch {start_epoch}.")

    logger.info(
        f"Start training. Total epoch: {num_epochs - start_epoch}, {num_iter_per_epoch * n_devices} iter/epoch."
    )

    for e in range(start_epoch, num_epochs):
        logger.info(f"Epoch: {e} / {num_epochs}.")
        data_iterator._reset()  # rewind the iterator at the beginning

        # learning rate scheduler
        if e in lr_decay_at_epochs:
            logger.info("Learning rate decayed.")
            learning_rate_decay(solvers, gamma=gamma)

        for i in range(num_iter_per_epoch):
            _driving, _source = data_iterator.next()
            source.d = _source
            driving.d = _driving

            # update generator and keypoint detector
            total_loss_G.forward()

            if device_id == 0:
                monitors_gen.add((e * num_iter_per_epoch + i) * n_devices)

            solver_generator.zero_grad()
            solver_kp_detector.zero_grad()

            # Gradient all-reduce across devices before the update.
            callback = None
            if n_devices > 1:
                params = [x.grad for x in solver_generator.get_parameters().values()] + \
                         [x.grad for x in solver_kp_detector.get_parameters().values()]
                callback = comm.all_reduce_callback(params, 2 << 20)

            total_loss_G.backward(clear_buffer=True,
                                  communicator_callbacks=callback)

            solver_generator.update()
            solver_kp_detector.update()

            if loss_flags.use_gan_loss:
                # update discriminator
                total_loss_D.forward(clear_no_need_grad=True)

                if device_id == 0:
                    monitors_dis.add((e * num_iter_per_epoch + i) * n_devices)

                solver_discriminator.zero_grad()

                callback = None
                if n_devices > 1:
                    params = [
                        x.grad for x in
                        solver_discriminator.get_parameters().values()
                    ]
                    callback = comm.all_reduce_callback(params, 2 << 20)

                total_loss_D.backward(clear_buffer=True,
                                      communicator_callbacks=callback)
                solver_discriminator.update()

            if device_id == 0:
                monitor_time.add((e * num_iter_per_epoch + i) * n_devices)

            # Periodic visualization on rank 0 only.
            if device_id == 0 and (
                    (e * num_iter_per_epoch + i) *
                    n_devices) % config.monitor_params.visualize_freq == 0:
                images_to_visualize = [
                    source.d, driving.d, generated["prediction"].d
                ]
                visuals = combine_images(images_to_visualize)
                monitor_vis.add((e * num_iter_per_epoch + i) * n_devices,
                                visuals)

        if device_id == 0:
            # Save at the checkpoint frequency and at the final epoch.
            if e % train_params.checkpoint_freq == 0 or e == num_epochs - 1:
                save_parameters(e, log_dir, solvers)

    return
def main():
    """Entry point: build model/data/optimizer and run train/validate epochs.

    Uses module-level globals (``params``, ``best_iou``, ``num_iter``,
    ``tb_writer``, loggers, ``vgg_model``, ``criterion_perceptual``) that
    other functions in this file read.  Side effects: writes params file,
    TensorBoard logs, text logs and checkpoints under the save directory.
    """
    global params, best_iou, num_iter, tb_writer, logger, logger_results
    best_iou = 0
    params = Params()
    params.save_params('{:s}/params.txt'.format(params.paths['save_dir']))
    tb_writer = SummaryWriter('{:s}/tb_logs'.format(params.paths['save_dir']))
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in params.train['gpu'])

    # set up logger
    logger, logger_results = setup_logging(params)

    # ----- create model ----- #
    model_name = params.model['name']
    if model_name == 'ResUNet34':
        model = ResUNet34(params.model['out_c'],
                          fixed_feature=params.model['fix_params'])
    elif params.model['name'] == 'UNet':
        model = UNet(3, params.model['out_c'])
    else:
        raise NotImplementedError()
    logger.info('Model: {:s}'.format(model_name))
    # if not params.train['checkpoint']:
    #     logger.info(model)
    model = nn.DataParallel(model)
    model = model.cuda()

    # VGG16 feature extractor backing the perceptual loss (global so
    # train/validate can use it).
    global vgg_model
    logger.info('=> Using VGG16 for perceptual loss...')
    vgg_model = vgg16_feat()
    vgg_model = nn.DataParallel(vgg_model).cuda()

    cudnn.benchmark = True

    # ----- define optimizer ----- #
    optimizer = torch.optim.Adam(model.parameters(), params.train['lr'],
                                 betas=(0.9, 0.99),
                                 weight_decay=params.train['weight_decay'])

    # ----- get pixel weights and define criterion ----- #
    if not params.train['weight_map']:
        criterion = torch.nn.NLLLoss().cuda()
    else:
        logger.info('=> Using weight maps...')
        # reduction='none' keeps per-pixel losses so weight maps can be applied.
        criterion = torch.nn.NLLLoss(reduction='none').cuda()

    if params.train['beta'] > 0:
        logger.info('=> Using perceptual loss...')
        global criterion_perceptual
        criterion_perceptual = perceptual_loss()

    data_transforms = {
        'train': get_transforms(params.transform['train']),
        'val': get_transforms(params.transform['val'])
    }

    # ----- load data ----- #
    dsets = {}
    for x in ['train', 'val']:
        img_dir = '{:s}/{:s}'.format(params.paths['img_dir'], x)
        target_dir = '{:s}/{:s}'.format(params.paths['label_dir'], x)
        if params.train['weight_map']:
            weight_map_dir = '{:s}/{:s}'.format(
                params.paths['weight_map_dir'], x)
            dir_list = [img_dir, weight_map_dir, target_dir]
            postfix = ['weight.png', 'label_with_contours.png']
            num_channels = [3, 1, 3]
        else:
            dir_list = [img_dir, target_dir]
            postfix = ['label_with_contours.png']
            num_channels = [3, 3]
        dsets[x] = DataFolder(dir_list, postfix, num_channels,
                              data_transforms[x])
    train_loader = DataLoader(dsets['train'],
                              batch_size=params.train['batch_size'],
                              shuffle=True,
                              num_workers=params.train['workers'])
    val_loader = DataLoader(dsets['val'],
                            batch_size=params.train['val_batch_size'],
                            shuffle=False,
                            num_workers=params.train['workers'])

    # ----- optionally load from a checkpoint for validation or resuming training ----- #
    if params.train['checkpoint']:
        if os.path.isfile(params.train['checkpoint']):
            logger.info("=> loading checkpoint '{}'".format(
                params.train['checkpoint']))
            checkpoint = torch.load(params.train['checkpoint'])
            params.train['start_epoch'] = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                params.train['checkpoint'], checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(
                params.train['checkpoint']))

    # ----- training and validation ----- #
    num_iter = params.train['num_epochs'] * len(train_loader)
    # print training parameters
    logger.info("=> Initial learning rate: {:g}".format(params.train['lr']))
    logger.info("=> Batch size: {:d}".format(params.train['batch_size']))
    # logger.info("=> Number of training iterations: {:d}".format(num_iter))
    logger.info("=> Training epochs: {:d}".format(params.train['num_epochs']))
    logger.info("=> beta: {:.1f}".format(params.train['beta']))

    for epoch in range(params.train['start_epoch'],
                       params.train['num_epochs']):
        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch + 1,
                                                params.train['num_epochs']))
        train_results = train(train_loader, model, optimizer, criterion,
                              epoch)
        train_loss, train_loss_ce, train_loss_var, \
            train_iou_nuclei, train_iou = train_results

        # evaluate on validation set
        with torch.no_grad():
            val_results = validate(val_loader, model, criterion)
        val_loss, val_loss_ce, val_loss_var, val_iou_nuclei, val_iou = val_results

        # check if it is the best accuracy (mean of the two val IoUs)
        combined_iou = (val_iou_nuclei + val_iou) / 2
        is_best = combined_iou > best_iou
        best_iou = max(combined_iou, best_iou)

        # keep a numbered checkpoint every checkpoint_freq epochs
        cp_flag = (epoch + 1) % params.train['checkpoint_freq'] == 0
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_iou': best_iou,
                'optimizer': optimizer.state_dict(),
            }, epoch, is_best, params.paths['save_dir'], cp_flag)

        # save the training results to txt files
        logger_results.info(
            '{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'
            .format(epoch + 1, train_loss, train_loss_ce, train_loss_var,
                    train_iou_nuclei, train_iou, val_loss, val_iou_nuclei,
                    val_iou))

        # tensorboard logs
        tb_writer.add_scalars(
            'epoch_losses', {
                'train_loss': train_loss,
                'train_loss_ce': train_loss_ce,
                'train_loss_var': train_loss_var,
                'val_loss': val_loss
            }, epoch)
        tb_writer.add_scalars(
            'epoch_accuracies', {
                'train_iou_nuclei': train_iou_nuclei,
                'train_iou': train_iou,
                'val_iou_nuclei': val_iou_nuclei,
                'val_iou': val_iou
            }, epoch)
    tb_writer.close()