def train():
    """Train the encoder/decoder pair driven by the module-level `config`.

    Builds the training dataloader, constructs and initializes both networks,
    optionally resumes from a checkpoint, then runs the epoch loop with
    per-epoch LR decay and optional testing / checkpointing.
    """
    # Let cuDNN benchmark conv algorithms (fastest when input sizes are fixed).
    torch.backends.cudnn.benchmark = True
    # Only the dataloader (second return value) is needed here.
    _, dataloader = create_dataloader(config.IMG_DIR + "/train",
                                      config.MESH_DIR + "/train",
                                      batch_size=config.BATCH_SIZE,
                                      used_layers=config.USED_LAYERS,
                                      img_size=config.IMAGE_SIZE,
                                      map_size=config.MAP_SIZE,
                                      augment=config.AUGMENT,
                                      workers=config.NUM_WORKERS,
                                      pin_memory=config.PIN_MEMORY,
                                      shuffle=True)
    # Input channel count depends on which input layers are selected.
    in_channels = num_channels(config.USED_LAYERS)
    encoder = Encoder(in_channels=in_channels)
    # NOTE(review): the +1 presumably adds a background/"no class" channel —
    # confirm against the Decoder definition.
    decoder = Decoder(num_classes=config.NUM_CLASSES+1)
    encoder.apply(init_weights)
    decoder.apply(init_weights)
    # Encoder parameters are filtered on requires_grad (parts may be frozen);
    # the decoder is fully trainable.
    encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad, encoder.parameters()),
                                      lr=config.ENCODER_LEARNING_RATE,
                                      betas=config.BETAS)
    decoder_solver = torch.optim.Adam(decoder.parameters(),
                                      lr=config.DECODER_LEARNING_RATE,
                                      betas=config.BETAS)
    # Step-wise LR decay at the configured milestone epochs.
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(encoder_solver,
                                                                milestones=config.ENCODER_LR_MILESTONES,
                                                                gamma=config.GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(decoder_solver,
                                                                milestones=config.DECODER_LR_MILESTONES,
                                                                gamma=config.GAMMA)
    encoder = encoder.to(config.DEVICE)
    decoder = decoder.to(config.DEVICE)
    loss_fn = LossFunction()
    init_epoch = 0
    if config.CHECKPOINT_FILE and config.LOAD_MODEL:
        # Resume: load_checkpoint returns the epoch to continue from plus the
        # restored networks.
        init_epoch, encoder, decoder = load_checkpoint(encoder, decoder,
                                                       config.CHECKPOINT_FILE,
                                                       config.DEVICE)
    # Unique run directory derived from the current timestamp
    # (non-alphanumeric characters collapsed to '-').
    output_dir = os.path.join(config.OUT_PATH,
                              re.sub("[^0-9a-zA-Z]+", "-", dt.now().isoformat()))
    for epoch_idx in range(init_epoch, config.NUM_EPOCHS):
        encoder.train()
        decoder.train()
        train_one_epoch(encoder, decoder, dataloader, loss_fn, encoder_solver,
                        decoder_solver, epoch_idx)
        # Decay LRs once per epoch, after the optimizer steps taken inside
        # train_one_epoch.
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        if config.TEST:
            test(encoder, decoder)
        if config.SAVE_MODEL:
            save_checkpoint(epoch_idx, encoder, decoder, output_dir)
    # If testing/saving was disabled during the loop, do each once at the end
    # so every run still produces a final evaluation and checkpoint.
    if not config.TEST:
        test(encoder, decoder)
    if not config.SAVE_MODEL:
        save_checkpoint(config.NUM_EPOCHS - 1, encoder, decoder, output_dir)
def train_net(cfg):
    """Train the Pix2Vox networks (encoder, decoder, refiner, merger).

    Builds the data pipeline and the four sub-networks with their optimizers
    and LR schedulers, optionally resumes from ``cfg.CONST.WEIGHTS``, then runs
    the epoch loop: train, validate via ``test_net``, and checkpoint both
    periodic and best-IoU weights.

    Args:
        cfg: experiment configuration object (attribute-style access).
    """
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use.
    torch.backends.cudnn.benchmark = True

    # Set up data augmentation.
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS, cfg.TRAIN.CONTRAST,
                                          cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loaders.
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.TRAIN, cfg.CONST.N_VIEWS_RENDERING,
            train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        num_workers=cfg.TRAIN.NUM_WORKER,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    val_data_loader = torch.utils.data.DataLoader(
        dataset=val_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.VAL, cfg.CONST.N_VIEWS_RENDERING,
            val_transforms),
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        shuffle=False)

    # Set up networks.
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)
    print('[DEBUG] %s Parameters in Encoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(encoder)))
    print('[DEBUG] %s Parameters in Decoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(decoder)))
    print('[DEBUG] %s Parameters in Refiner: %d.' %
          (dt.now(), utils.network_utils.count_parameters(refiner)))
    print('[DEBUG] %s Parameters in Merger: %d.' %
          (dt.now(), utils.network_utils.count_parameters(merger)))

    # Initialize weights of networks.
    encoder.apply(utils.network_utils.init_weights)
    decoder.apply(utils.network_utils.init_weights)
    refiner.apply(utils.network_utils.init_weights)
    merger.apply(utils.network_utils.init_weights)

    # Set up solvers. Encoder parameters are filtered on requires_grad because
    # part of the encoder (e.g. a frozen backbone) may not be trainable.
    if cfg.TRAIN.POLICY == 'adam':
        encoder_solver = torch.optim.Adam(
            filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
            betas=cfg.TRAIN.BETAS)
        decoder_solver = torch.optim.Adam(decoder.parameters(),
                                          lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        refiner_solver = torch.optim.Adam(refiner.parameters(),
                                          lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        merger_solver = torch.optim.Adam(merger.parameters(),
                                         lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                         betas=cfg.TRAIN.BETAS)
    elif cfg.TRAIN.POLICY == 'sgd':
        encoder_solver = torch.optim.SGD(
            filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
            momentum=cfg.TRAIN.MOMENTUM)
        decoder_solver = torch.optim.SGD(decoder.parameters(),
                                         lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        refiner_solver = torch.optim.SGD(refiner.parameters(),
                                         lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        merger_solver = torch.optim.SGD(merger.parameters(),
                                        lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                        momentum=cfg.TRAIN.MOMENTUM)
    else:
        raise Exception('[FATAL] %s Unknown optimizer %s.' % (dt.now(), cfg.TRAIN.POLICY))

    # Set up learning rate schedulers to decay learning rates dynamically.
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        encoder_solver, milestones=cfg.TRAIN.ENCODER_LR_MILESTONES, gamma=cfg.TRAIN.GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        decoder_solver, milestones=cfg.TRAIN.DECODER_LR_MILESTONES, gamma=cfg.TRAIN.GAMMA)
    refiner_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        refiner_solver, milestones=cfg.TRAIN.REFINER_LR_MILESTONES, gamma=cfg.TRAIN.GAMMA)
    merger_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        merger_solver, milestones=cfg.TRAIN.MERGER_LR_MILESTONES, gamma=cfg.TRAIN.GAMMA)

    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        refiner = torch.nn.DataParallel(refiner).cuda()
        merger = torch.nn.DataParallel(merger).cuda()

    # Set up loss functions.
    bce_loss = torch.nn.BCELoss()

    # Load pretrained model if it exists.
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        print('[INFO] %s Recovering from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        # BUG FIX: map_location='cpu' lets a checkpoint written on GPU be
        # resumed on a CPU-only (or differently numbered GPU) machine; the
        # state dicts are moved back to the right device by load_state_dict.
        checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location='cpu')
        init_epoch = checkpoint['epoch_idx']
        best_iou = checkpoint['best_iou']
        best_epoch = checkpoint['best_epoch']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])
        print('[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.'
              % (dt.now(), init_epoch, best_iou, best_epoch))

    # Summary writers for TensorBoard; '%s' is substituted per output kind.
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
    log_dir = output_dir % 'logs'
    ckpt_dir = output_dir % 'checkpoints'
    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
    val_writer = SummaryWriter(os.path.join(log_dir, 'test'))

    # Training loop.
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock.
        epoch_start_time = time()

        # Batch average metrics.
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        encoder_losses = utils.network_utils.AverageMeter()
        refiner_losses = utils.network_utils.AverageMeter()

        # Switch models to training mode.
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (taxonomy_names, sample_names, rendering_images,
                        ground_truth_volumes) in enumerate(train_data_loader):
            # Measure data time.
            data_time.update(time() - batch_end_time)

            # Get data from data loader.
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            ground_truth_volumes = utils.network_utils.var_or_cuda(ground_truth_volumes)

            # Train the encoder, decoder, refiner, and merger.
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)
            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                # Without the merger, average the per-view volumes.
                generated_volumes = torch.mean(generated_volumes, dim=1)
            encoder_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10
            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10
            else:
                refiner_loss = encoder_loss

            # Gradient descent.
            encoder.zero_grad()
            decoder.zero_grad()
            refiner.zero_grad()
            merger.zero_grad()
            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                # retain_graph so refiner_loss can backprop through the same
                # encoder/decoder graph; their gradients accumulate from both
                # losses before the optimizer steps.
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()
            encoder_solver.step()
            decoder_solver.step()
            refiner_solver.step()
            merger_solver.step()

            # Append loss to average metrics.
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())
            # Append loss to TensorBoard.
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar('EncoderDecoder/BatchLoss', encoder_loss.item(), n_itr)
            train_writer.add_scalar('Refiner/BatchLoss', refiner_loss.item(), n_itr)

            # Tick / tock.
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            print('[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' %
                  (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches,
                   batch_time.val, data_time.val, encoder_loss.item(), refiner_loss.item()))

        # BUG FIX: the schedulers used to be stepped at the *start* of each
        # epoch, i.e. before any optimizer.step(). PyTorch >= 1.1 requires
        # optimizer.step() before scheduler.step(); stepping first skipped the
        # initial learning rate and shifted every milestone one epoch early.
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        refiner_lr_scheduler.step()
        merger_lr_scheduler.step()

        # Append epoch loss to TensorBoard.
        train_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx + 1)
        train_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx + 1)

        # Tick / tock.
        epoch_end_time = time()
        print('[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' %
              (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES,
               epoch_end_time - epoch_start_time, encoder_losses.avg, refiner_losses.avg))

        # Update rendering views for the next epoch (randomized view count).
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' %
                  (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES, n_views_rendering))

        # Validate the training models.
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader, val_writer,
                       encoder, decoder, refiner, merger)

        # Save weights to file: periodic snapshot plus best-so-far checkpoint.
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            os.makedirs(ckpt_dir, exist_ok=True)
            utils.network_utils.save_checkpoints(
                cfg, os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth' % (epoch_idx + 1)),
                epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver,
                refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)
        if iou > best_iou:
            os.makedirs(ckpt_dir, exist_ok=True)
            best_iou = iou
            best_epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(
                cfg, os.path.join(ckpt_dir, 'best-ckpt.pth'),
                epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver,
                refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)

    # Close SummaryWriters for TensorBoard.
    train_writer.close()
    val_writer.close()
def train_net(cfg):
    """Train the Pix2Vox networks (encoder, decoder, merger, refiner) with PaddlePaddle.

    Mirrors the PyTorch reference implementation: builds the data pipeline,
    the four sub-networks with per-network Adam optimizers driven by
    MultiStepDecay schedulers, optionally resumes from ``cfg.CONST.WEIGHTS``
    (a directory of .pdparams/.pdopt files), then runs the epoch loop with
    validation via ``test_net`` and periodic/best checkpointing.

    Args:
        cfg: experiment configuration object (attribute-style access).
    """
    # Set up data augmentation.
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS, cfg.TRAIN.CONTRAST,
                                          cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loaders.
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = paddle.io.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.TRAIN, cfg.CONST.N_VIEWS_RENDERING,
            train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        # num_workers > 0 crashes when /dev/shm is too small, so shared memory
        # is disabled and the default single-process loading is used.
        use_shared_memory=False,
        shuffle=True,
        drop_last=True)
    val_data_loader = paddle.io.DataLoader(
        dataset=val_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.VAL, cfg.CONST.N_VIEWS_RENDERING,
            val_transforms),
        batch_size=1,
        shuffle=False)

    # Set up networks. Unlike the PyTorch version, no explicit weight init is
    # applied here: Paddle layers self-initialize (see Paddle initializer API).
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    merger = Merger(cfg)
    refiner = Refiner(cfg)
    print('[DEBUG] %s Parameters in Encoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(encoder)))
    print('[DEBUG] %s Parameters in Decoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(decoder)))
    print('[DEBUG] %s Parameters in Merger: %d.' %
          (dt.now(), utils.network_utils.count_parameters(merger)))
    print('[DEBUG] %s Parameters in Refiner: %d.' %
          (dt.now(), utils.network_utils.count_parameters(refiner)))

    # Set up learning rate schedulers to decay learning rates dynamically.
    # In Paddle the scheduler wraps the base LR and is handed to the optimizer.
    encoder_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.ENCODER_LEARNING_RATE,
        milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    decoder_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.DECODER_LEARNING_RATE,
        milestones=cfg.TRAIN.DECODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    merger_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.MERGER_LEARNING_RATE,
        milestones=cfg.TRAIN.MERGER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    refiner_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.REFINER_LEARNING_RATE,
        milestones=cfg.TRAIN.REFINER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)

    # Set up solvers (Adam only in this port).
    encoder_solver = paddle.optimizer.Adam(learning_rate=encoder_lr_scheduler,
                                           parameters=encoder.parameters())
    decoder_solver = paddle.optimizer.Adam(learning_rate=decoder_lr_scheduler,
                                           parameters=decoder.parameters())
    merger_solver = paddle.optimizer.Adam(learning_rate=merger_lr_scheduler,
                                          parameters=merger.parameters())
    refiner_solver = paddle.optimizer.Adam(learning_rate=refiner_lr_scheduler,
                                           parameters=refiner.parameters())

    # Set up loss functions.
    bce_loss = paddle.nn.BCELoss()

    # Load pretrained model if it exists.
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        print('[INFO] %s Recovering from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        encoder_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "encoder.pdparams"))
        encoder_solver_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "encoder_solver.pdopt"))
        encoder.set_state_dict(encoder_state_dict)
        encoder_solver.set_state_dict(encoder_solver_state_dict)
        decoder_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "decoder.pdparams"))
        decoder_solver_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "decoder_solver.pdopt"))
        decoder.set_state_dict(decoder_state_dict)
        decoder_solver.set_state_dict(decoder_solver_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "merger.pdparams"))
            merger_solver_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "merger_solver.pdopt"))
            merger.set_state_dict(merger_state_dict)
            merger_solver.set_state_dict(merger_solver_state_dict)
        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "refiner.pdparams"))
            refiner_solver_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "refiner_solver.pdopt"))
            refiner.set_state_dict(refiner_state_dict)
            refiner_solver.set_state_dict(refiner_solver_state_dict)
        # NOTE(review): init_epoch/best_iou/best_epoch are printed but never
        # restored from the checkpoint, so a resumed run restarts the epoch
        # counter at 0 — confirm whether that is intended.
        print(
            '[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.'
            % (dt.now(), init_epoch, best_iou, best_epoch))

    # Summary writers (VisualDL); '%s' is substituted per output kind.
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
    log_dir = output_dir % 'logs'
    ckpt_dir = output_dir % 'checkpoints'
    train_writer = LogWriter(os.path.join(log_dir, 'train'))
    val_writer = LogWriter(os.path.join(log_dir, 'val'))

    # Training loop.
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock.
        epoch_start_time = time()

        # Batch average metrics.
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        encoder_losses = utils.network_utils.AverageMeter()
        refiner_losses = utils.network_utils.AverageMeter()

        # Switch models to training mode.
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (rendering_images, ground_truth_volumes) in enumerate(
                train_data_loader()):
            # Measure data time.
            data_time.update(time() - batch_end_time)

            # Get data from data loader.
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            ground_truth_volumes = utils.network_utils.var_or_cuda(ground_truth_volumes)

            # Train the encoder, decoder, merger, and refiner.
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)
            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                # BUG FIX: this fallback was commented out, leaving the per-view
                # volume stack unmerged whenever the merger is disabled. Average
                # over the view axis, mirroring the PyTorch reference.
                generated_volumes = paddle.mean(generated_volumes, axis=1)
            encoder_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10
            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10
            else:
                # BUG FIX: refiner_loss was left undefined when the refiner is
                # disabled, crashing at the logging calls below.
                refiner_loss = encoder_loss

            # Gradient descent.
            encoder_solver.clear_grad()
            decoder_solver.clear_grad()
            merger_solver.clear_grad()
            refiner_solver.clear_grad()
            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                # retain_graph so refiner_loss can backprop through the same
                # encoder/decoder graph; gradients accumulate from both losses.
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                # BUG FIX: without this branch no backward pass ran at all when
                # the refiner is disabled, so the networks never trained.
                encoder_loss.backward()
            encoder_solver.step()
            decoder_solver.step()
            merger_solver.step()
            refiner_solver.step()

            # Append loss to average metrics.
            encoder_losses.update(encoder_loss.numpy())
            refiner_losses.update(refiner_loss.numpy())
            # Append loss to VisualDL.
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar(tag='EncoderDecoder/BatchLoss', step=n_itr,
                                    value=encoder_loss.numpy())
            train_writer.add_scalar('Refiner/BatchLoss', value=refiner_loss.numpy(),
                                    step=n_itr)

            # Tick / tock.
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            if (batch_idx % int(cfg.CONST.INFO_BATCH)) == 0:
                print(
                    '[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
                    % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1,
                       n_batches, batch_time.val, data_time.val, encoder_loss.numpy(),
                       refiner_loss.numpy()))

        # Append epoch loss to VisualDL.
        train_writer.add_scalar(tag='EncoderDecoder/EpochLoss', step=epoch_idx + 1,
                                value=encoder_losses.avg)
        train_writer.add_scalar('Refiner/EpochLoss', value=refiner_losses.avg,
                                step=epoch_idx + 1)

        # Decay learning rates once per epoch, after the optimizer steps.
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        merger_lr_scheduler.step()
        refiner_lr_scheduler.step()

        # Tick / tock.
        epoch_end_time = time()
        print(
            '[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
            % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES,
               epoch_end_time - epoch_start_time, encoder_losses.avg, refiner_losses.avg))

        # Update rendering views for the next epoch (randomized view count).
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' %
                  (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES, n_views_rendering))

        # Validate the training models.
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader, val_writer,
                       encoder, decoder, merger, refiner)

        # Save weights to file: periodic snapshot plus best-so-far checkpoint.
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            os.makedirs(ckpt_dir, exist_ok=True)
            utils.network_utils.save_checkpoints(
                cfg, os.path.join(ckpt_dir, 'ckpt-epoch-%04d' % (epoch_idx + 1)),
                epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver,
                merger, merger_solver, refiner, refiner_solver, best_iou, best_epoch)
        if iou > best_iou:
            os.makedirs(ckpt_dir, exist_ok=True)
            best_iou = iou
            best_epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(
                cfg, os.path.join(ckpt_dir, 'best-ckpt'),
                epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver,
                merger, merger_solver, refiner, refiner_solver, best_iou, best_epoch)
class Trainer:
    """Domain-adversarial trainer: one shared encoder, one decoder per
    dataset, and a domain classifier trained to identify the source dataset
    of a latent code while the encoder/decoders are trained to fool it.
    """

    def __init__(self, config):
        """Build datasets, networks, optimizers and LR schedulers from `config`."""
        self.config = config
        self.config.data.n_datasets = len(config.data.datasets)
        print("No of datasets used:", self.config.data.n_datasets)
        torch.manual_seed(config.env.seed)
        torch.cuda.manual_seed(config.env.seed)
        self.expPath = self.config.env.expPath
        self.logger = Logger("Training", "logs/training.log")
        # One DatasetSet per configured dataset path.
        self.data = [
            DatasetSet(data_path, config.data.seq_len, config.data)
            for data_path in config.data.datasets
        ]
        # Per-dataset training meters plus one shared discriminator meter.
        self.losses_recon = [
            LossMeter(f'recon {i}')
            for i in range(self.config.data.n_datasets)
        ]
        self.loss_d_right = LossMeter('d')
        self.loss_total = [
            LossMeter(f'total {i}')
            for i in range(self.config.data.n_datasets)
        ]
        # Mirror meters for evaluation.
        self.evals_recon = [
            LossMeter(f'recon {i}')
            for i in range(self.config.data.n_datasets)
        ]
        self.eval_d_right = LossMeter('eval d')
        self.eval_total = [
            LossMeter(f'eval total {i}')
            for i in range(self.config.data.n_datasets)
        ]
        # Shared encoder, one decoder per dataset, and a domain classifier
        # over the latent space.
        self.encoder = Encoder(config.encoder)
        self.decoders = torch.nn.ModuleList([
            Decoder(config.decoder) for _ in range(self.config.data.n_datasets)
        ])
        self.classifier = DomainClassifier(
            config.domain_classifier, num_classes=self.config.data.n_datasets)
        states = None
        if config.env.checkpoint:
            # Resume: args.pth (saved next to the checkpoint) holds the last
            # epoch as its final element; per-dataset state files are suffixed
            # with the dataset index.
            checkpoint_args_path = os.path.dirname(
                config.env.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)
            self.start_epoch = checkpoint_args[-1] + 1
            states = [
                torch.load(self.config.env.checkpoint + f'_{i}.pth')
                for i in range(self.config.data.n_datasets)
            ]
            # Encoder/classifier state is shared, so dataset 0's copy is used.
            self.encoder.load_state_dict(states[0]['encoder_state'])
            for i in range(self.config.data.n_datasets):
                self.decoders[i].load_state_dict(states[i]['decoder_state'])
            self.classifier.load_state_dict(states[0]['discriminator_state'])
            self.logger.info('Loaded checkpoint parameters')
            # NOTE(review): this unconditional raise makes resuming from a
            # checkpoint impossible (and leaves the load_optimizer branch
            # below unreachable) — confirm whether it is a deliberate guard.
            raise NotImplementedError
        else:
            self.start_epoch = 0
        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.classifier = torch.nn.DataParallel(self.classifier).cuda()
        for i, decoder in enumerate(self.decoders):
            self.decoders[i] = torch.nn.DataParallel(decoder).cuda()
        # One optimizer per dataset, each covering the shared encoder plus
        # that dataset's decoder.
        self.model_optimizers = [
            optim.Adam(chain(self.encoder.parameters(), decoder.parameters()),
                       lr=config.data.lr) for decoder in self.decoders
        ]
        self.classifier_optimizer = optim.Adam(self.classifier.parameters(),
                                               lr=config.data.lr)
        if config.env.checkpoint and config.env.load_optimizer:
            for i in range(self.config.data.n_datasets):
                self.model_optimizers[i].load_state_dict(
                    states[i]['model_optimizer_state'])
            self.classifier_optimizer.load_state_dict(
                states[0]['d_optimizer_state'])
        # Exponential LR decay; last_epoch is forced so a resumed run picks up
        # the decayed rate, and step() applies it immediately.
        self.lr_managers = []
        for i in range(self.config.data.n_datasets):
            self.lr_managers.append(
                torch.optim.lr_scheduler.ExponentialLR(
                    self.model_optimizers[i], config.data.lr_decay))
            self.lr_managers[i].last_epoch = self.start_epoch
            self.lr_managers[i].step()

    def eval_batch(self, x, x_aug, dset_num):
        """Evaluate one batch: reconstruction loss plus weighted classifier
        loss; updates the eval meters and returns the combined scalar loss.
        """
        x, x_aug = x.float(), x_aug.float()
        z = self.encoder(x)
        y = self.decoders[dset_num](x, z)
        z_logits = self.classifier(z)
        # Fraction of latents the classifier assigns to the true dataset.
        z_classification = torch.max(z_logits, dim=1)[1]
        z_accuracy = (z_classification == dset_num).float().mean()
        self.eval_d_right.add(z_accuracy.data.item())
        # discriminator_right = F.cross_entropy(z_logits, dset_num).mean()
        # Target is the dataset index replicated over the batch.
        discriminator_right = F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())
        total_loss = discriminator_right.data.item() * self.config.domain_classifier.d_lambda + \
                     recon_loss.mean().data.item()
        self.eval_total[dset_num].add(total_loss)
        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        """Run one adversarial training step (D phase, then G phase) and
        return the generator-side scalar loss.
        """
        x, x_aug = x.float(), x_aug.float()

        # Optimize D - classifier right: train the domain classifier to
        # identify which dataset the latent code came from.
        z = self.encoder(x)
        z_logits = self.classifier(z)
        discriminator_right = F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        loss = discriminator_right * self.config.domain_classifier.d_lambda
        self.loss_d_right.add(loss.data.item())
        self.classifier_optimizer.zero_grad()
        loss.backward()
        if self.config.domain_classifier.grad_clip is not None:
            clip_grad_value_(self.classifier.parameters(),
                             self.config.domain_classifier.grad_clip)
        self.classifier_optimizer.step()

        # optimize G - reconstructs well, classifier wrong: the encoder sees
        # the augmented input while the decoder reconstructs the original.
        z = self.encoder(x_aug)
        y = self.decoders[dset_num](x, z)
        z_logits = self.classifier(z)
        # Negated cross-entropy: the generator is rewarded for confusing D.
        discriminator_wrong = -F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        # Sanity check: log diagnostics if the D-phase loss exploded.
        if not (-100 < discriminator_right.data.item() < 100):
            self.logger.debug(f'z_logits: {z_logits.detach().cpu().numpy()}')
            self.logger.debug(f'dset_num: {dset_num}')
        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())
        loss = (recon_loss.mean() +
                self.config.domain_classifier.d_lambda * discriminator_wrong)
        self.model_optimizers[dset_num].zero_grad()
        loss.backward()
        if self.config.domain_classifier.grad_clip is not None:
            clip_grad_value_(self.encoder.parameters(),
                            self.config.domain_classifier.grad_clip)
            clip_grad_value_(self.decoders[dset_num].parameters(),
                            self.config.domain_classifier.grad_clip)
        self.model_optimizers[dset_num].step()
        self.loss_total[dset_num].add(loss.data.item())
        return loss.data.item()

    def train_epoch(self, epoch):
        """Train for one epoch, cycling through the datasets round-robin."""
        for meter in self.losses_recon:
            meter.reset()
        self.loss_d_right.reset()
        for i in range(len(self.loss_total)):
            self.loss_total[i].reset()
        self.encoder.train()
        self.classifier.train()
        for decoder in self.decoders:
            decoder.train()
        n_batches = self.config.data.epoch_len
        with tqdm(total=n_batches, desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                # Short mode: cut the epoch off early for smoke tests.
                if self.config.data.short and batch_num == 3:
                    break
                # Round-robin over datasets, one batch each.
                dset_num = batch_num % self.config.data.n_datasets
                x, x_aug = next(self.data[dset_num].train_iter)
                x = wrap_cuda(x)
                x_aug = wrap_cuda(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)
                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        """Evaluate on the validation iterators (~1/10 of the train length)."""
        for meter in self.evals_recon:
            meter.reset()
        self.eval_d_right.reset()
        for i in range(len(self.eval_total)):
            self.eval_total[i].reset()
        self.encoder.eval()
        self.classifier.eval()
        for decoder in self.decoders:
            decoder.eval()
        n_batches = int(np.ceil(self.config.data.epoch_len / 10))
        with tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                if self.config.data.short and batch_num == 10:
                    break
                dset_num = batch_num % self.config.data.n_datasets
                x, x_aug = next(self.data[dset_num].valid_iter)
                x = wrap_cuda(x)
                x_aug = wrap_cuda(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)
                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        """Render a list of meters as a comma-separated string of epoch means."""
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        """Formatted training losses: per-dataset recon + discriminator."""
        meters = [*self.losses_recon, self.loss_d_right]
        return self.format_losses(meters)

    def eval_losses(self):
        """Formatted eval losses: per-dataset recon + discriminator."""
        meters = [*self.evals_recon, self.eval_d_right]
        return self.format_losses(meters)

    def train(self):
        """Main loop: train/eval each epoch, decay LRs, and checkpoint per
        dataset (best-by-validation plus last/every-epoch models).
        """
        best_eval = [float('inf') for _ in range(self.config.data.n_datasets)]

        # Begin!
        for epoch in range(self.start_epoch,
                           self.start_epoch + self.config.env.epochs):
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)
            self.logger.info(f'Epoch %s - Train loss: (%s), Test loss (%s)',
                             epoch, self.train_losses(), self.eval_losses())
            for i in range(len(self.lr_managers)):
                self.lr_managers[i].step()
            for dataset_id in range(self.config.data.n_datasets):
                val_loss = self.eval_total[dataset_id].summarize_epoch()
                if val_loss < best_eval[dataset_id]:
                    self.save_model(f'bestmodel_{dataset_id}.pth', dataset_id)
                    best_eval[dataset_id] = val_loss
                if not self.config.env.save_per_epoch:
                    self.save_model(f'lastmodel_{dataset_id}.pth', dataset_id)
                else:
                    self.save_model(f'lastmodel_{epoch}_rank_{dataset_id}.pth',
                                    dataset_id)
            # Persist config + epoch so a resume can recover start_epoch.
            torch.save([self.config, epoch], '%s/args.pth' % self.expPath)
            self.logger.debug('Ended epoch')

    def save_model(self, filename, decoder_id):
        """Save shared encoder/classifier plus one decoder and its optimizer.

        `.module` unwraps the DataParallel wrappers applied in __init__.
        """
        save_path = self.expPath / filename
        torch.save(
            {
                'encoder_state': self.encoder.module.state_dict(),
                'decoder_state': self.decoders[decoder_id].module.state_dict(),
                'discriminator_state': self.classifier.module.state_dict(),
                'model_optimizer_state':
                self.model_optimizers[decoder_id].state_dict(),
                'dataset': decoder_id,
                'd_optimizer_state': self.classifier_optimizer.state_dict()
            }, save_path)
        self.logger.debug(f'Saved model to {save_path}')
device=device, seq_len=SEQ_LEN + 2).to(device) encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=args.LR) decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=args.LR) criterion = nn.CrossEntropyLoss( ) # 나중에 loss 계산 할때 패딩은 모두 없앨것이므로 ignore index를 설정하지 않는다 train_losses = [] validation_losses = [] for epoch in range(args.NUM_EPOCHS): train_loss = 0 validation_loss = 0 encoder.train() decoder.train() for idx, (img, caption_5, caption_lengths_5) in enumerate(train_loader): origin_img = img for i in range(args.NUM_CAPTIONS): encoder_optimizer.zero_grad() decoder_optimizer.zero_grad() img = origin_img.to(device) caption = caption_5[:, i, :].to(device) caption_lengths = caption_lengths_5[:, :, i].to(device) img = encoder(img) pred_length, captions, preds, coefs = decoder( img, caption, caption_lengths