def train_net(cfg, model_type):
    """Train a Pix2Vox / Pix2Vox++ network end to end.

    Builds the encoder/decoder/merger (plus the refiner for the "A" model
    variants), trains for ``cfg.TRAIN.NUM_EPOCHS`` epochs, validates after
    every epoch via ``test_net``, and checkpoints both periodic and
    best-IoU weights.

    Args:
        cfg: Config object exposing CONST / TRAIN / TEST / DATASET / DIR /
            NETWORK sections (see references throughout this function).
        model_type: ``Pix2VoxTypes`` member selecting the architecture;
            the ``*_A`` variants additionally train a refiner network.
    """
    # Only the "A" model variants include the refiner sub-network.
    use_refiner = model_type in (Pix2VoxTypes.Pix2Vox_A, Pix2VoxTypes.Pix2Vox_Plus_Plus_A)

    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # ------------------------------------------------------------------
    # Data augmentation
    # ------------------------------------------------------------------
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS, cfg.TRAIN.CONTRAST, cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # ------------------------------------------------------------------
    # Data loaders
    # ------------------------------------------------------------------
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.TRAIN, cfg.CONST.N_VIEWS_RENDERING, train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        num_workers=cfg.CONST.NUM_WORKER,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    # Validation runs one sample at a time, in a fixed order.
    val_data_loader = torch.utils.data.DataLoader(
        dataset=val_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.VAL, cfg.CONST.N_VIEWS_RENDERING, val_transforms),
        batch_size=1,
        num_workers=cfg.CONST.NUM_WORKER,
        pin_memory=True,
        shuffle=False)

    # ------------------------------------------------------------------
    # Networks
    # ------------------------------------------------------------------
    encoder = Encoder(cfg, model_type)
    decoder = Decoder(cfg, model_type)
    # Bind refiner to None when unused so every later reference is defined.
    refiner = Refiner(cfg) if use_refiner else None
    merger = Merger(cfg, model_type)

    logging.debug('Parameters in Encoder: %d.' % (utils.helpers.count_parameters(encoder)))
    logging.debug('Parameters in Decoder: %d.' % (utils.helpers.count_parameters(decoder)))
    if use_refiner:
        logging.debug('Parameters in Refiner: %d.' % (utils.helpers.count_parameters(refiner)))
    logging.debug('Parameters in Merger: %d.' % (utils.helpers.count_parameters(merger)))

    # Initialize weights of networks
    encoder.apply(utils.helpers.init_weights)
    decoder.apply(utils.helpers.init_weights)
    if use_refiner:
        refiner.apply(utils.helpers.init_weights)
    merger.apply(utils.helpers.init_weights)

    # ------------------------------------------------------------------
    # Solvers (one optimizer per sub-network, each with its own LR)
    # ------------------------------------------------------------------
    if cfg.TRAIN.POLICY == 'adam':
        encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad, encoder.parameters()),
                                          lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        decoder_solver = torch.optim.Adam(decoder.parameters(),
                                          lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        if use_refiner:
            refiner_solver = torch.optim.Adam(refiner.parameters(),
                                              lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                              betas=cfg.TRAIN.BETAS)
        merger_solver = torch.optim.Adam(merger.parameters(),
                                         lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                         betas=cfg.TRAIN.BETAS)
    elif cfg.TRAIN.POLICY == 'sgd':
        encoder_solver = torch.optim.SGD(filter(lambda p: p.requires_grad, encoder.parameters()),
                                         lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        decoder_solver = torch.optim.SGD(decoder.parameters(),
                                         lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        if use_refiner:
            refiner_solver = torch.optim.SGD(refiner.parameters(),
                                             lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                             momentum=cfg.TRAIN.MOMENTUM)
        merger_solver = torch.optim.SGD(merger.parameters(),
                                        lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                        momentum=cfg.TRAIN.MOMENTUM)
    else:
        raise Exception('[FATAL] %s Unknown optimizer %s.' % (dt.now(), cfg.TRAIN.POLICY))

    # Set up learning rate scheduler to decay learning rates dynamically
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(encoder_solver,
                                                                milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,
                                                                gamma=cfg.TRAIN.GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(decoder_solver,
                                                                milestones=cfg.TRAIN.DECODER_LR_MILESTONES,
                                                                gamma=cfg.TRAIN.GAMMA)
    if use_refiner:
        refiner_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(refiner_solver,
                                                                    milestones=cfg.TRAIN.REFINER_LR_MILESTONES,
                                                                    gamma=cfg.TRAIN.GAMMA)
    merger_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(merger_solver,
                                                               milestones=cfg.TRAIN.MERGER_LR_MILESTONES,
                                                               gamma=cfg.TRAIN.GAMMA)

    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        if use_refiner:
            refiner = torch.nn.DataParallel(refiner).cuda()
        merger = torch.nn.DataParallel(merger).cuda()

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # ------------------------------------------------------------------
    # Load pretrained model if exists
    # ------------------------------------------------------------------
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        # NOTE(review): the checkpoint stores the *completed* epoch index, so
        # resuming from here re-runs that epoch once — confirm this is intended.
        init_epoch = checkpoint['epoch_idx']
        best_iou = checkpoint['best_iou']
        best_epoch = checkpoint['best_epoch']

        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        if use_refiner:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

        logging.info('Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.'
                     % (init_epoch, best_iou, best_epoch))

    # ------------------------------------------------------------------
    # Summary writers for TensorBoard
    # ------------------------------------------------------------------
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s')
    cfg.DIR.LOGS = output_dir % f'logs_{model_type}_{cfg.DATASET.TRAIN_DATASET}_{cfg.CONST.SHAPENET_RATIO}'
    cfg.DIR.CHECKPOINTS = output_dir % f'checkpoints_{model_type}_{cfg.DATASET.TRAIN_DATASET}_{cfg.CONST.SHAPENET_RATIO}'
    train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train'))
    val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test'))

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHS):
        epoch_start_time = time()

        # Batch average metrics
        batch_time = AverageMeter()
        data_time = AverageMeter()
        encoder_losses = AverageMeter()
        if use_refiner:
            refiner_losses = AverageMeter()

        # Switch models to training mode
        encoder.train()
        decoder.train()
        merger.train()
        if use_refiner:
            refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (taxonomy_names, sample_names, rendering_images,
                        ground_truth_volumes) in enumerate(train_data_loader):
            # Measure data time
            data_time.update(time() - batch_end_time)

            # Get data from data loader
            rendering_images = utils.helpers.var_or_cuda(rendering_images)
            ground_truth_volumes = utils.helpers.var_or_cuda(ground_truth_volumes)

            # Forward pass: encoder -> decoder -> (merger) -> (refiner)
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)

            # The merger fuses per-view volumes; before its start epoch we
            # simply average across the view dimension instead.
            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                generated_volumes = torch.mean(generated_volumes, dim=1)
            encoder_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10

            if use_refiner and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10
            else:
                # Keep refiner_loss defined for logging below.
                refiner_loss = encoder_loss

            # Gradient descent
            encoder.zero_grad()
            decoder.zero_grad()
            if use_refiner:
                refiner.zero_grad()
            # The merger is trained unconditionally (see merger.train() above),
            # so it is zeroed/stepped regardless of use_refiner.
            merger.zero_grad()

            if use_refiner and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                # retain_graph: the refiner loss backprops through the same graph.
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()

            encoder_solver.step()
            decoder_solver.step()
            if use_refiner:
                refiner_solver.step()
            merger_solver.step()

            # Append loss to average metrics
            encoder_losses.update(encoder_loss.item())
            if use_refiner:
                refiner_losses.update(refiner_loss.item())

            # Append loss to TensorBoard
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar('EncoderDecoder/BatchLoss', encoder_loss.item(), n_itr)
            if use_refiner:
                train_writer.add_scalar('Refiner/BatchLoss', refiner_loss.item(), n_itr)

            # Tick / tock
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            logging.info(
                '[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
                % (epoch_idx + 1, cfg.TRAIN.NUM_EPOCHS, batch_idx + 1, n_batches, batch_time.val,
                   data_time.val, encoder_loss.item(), refiner_loss.item()))

        # Adjust learning rate (merger always steps; see note above)
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        if use_refiner:
            refiner_lr_scheduler.step()
        merger_lr_scheduler.step()

        # Append epoch loss to TensorBoard
        train_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx + 1)
        if use_refiner:
            train_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx + 1)

        # Tick / tock
        epoch_end_time = time()
        if use_refiner:
            logging.info('[Epoch %d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
                         % (epoch_idx + 1, cfg.TRAIN.NUM_EPOCHS, epoch_end_time - epoch_start_time,
                            encoder_losses.avg, refiner_losses.avg))
        else:
            logging.info('[Epoch %d/%d] EpochTime = %.3f (s) EDLoss = %.4f'
                         % (epoch_idx + 1, cfg.TRAIN.NUM_EPOCHS, epoch_end_time - epoch_start_time,
                            encoder_losses.avg))

        # Optionally randomize the number of rendering views for the next epoch
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            logging.info('Epoch [%d/%d] Update #RenderingViews to %d'
                         % (epoch_idx + 2, cfg.TRAIN.NUM_EPOCHS, n_views_rendering))

        # Validate the training models (refiner is None when unused, which
        # matches the refiner=None keyword in the original two-branch call)
        iou = test_net(cfg, model_type, DatasetType.VAL, epoch_idx + 1, val_data_loader, val_writer,
                       encoder, decoder, refiner, merger)

        # Save weights to file: periodic checkpoints plus a best-IoU checkpoint
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0 or iou > best_iou:
            file_name = 'checkpoint-epoch-%03d.pth' % (epoch_idx + 1)
            if iou > best_iou:
                best_iou = iou
                best_epoch = epoch_idx
                file_name = 'checkpoint-best.pth'

            output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name)
            # exist_ok avoids the exists()/makedirs() race of the original code.
            os.makedirs(cfg.DIR.CHECKPOINTS, exist_ok=True)

            checkpoint = {
                'epoch_idx': epoch_idx,
                'best_iou': best_iou,
                'best_epoch': best_epoch,
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
            }
            if use_refiner:
                checkpoint['refiner_state_dict'] = refiner.state_dict()
            if cfg.NETWORK.USE_MERGER:
                checkpoint['merger_state_dict'] = merger.state_dict()
            torch.save(checkpoint, output_path)
            logging.info('Saved checkpoint to %s ...' % output_path)

    # Close SummaryWriter for TensorBoard
    train_writer.close()
    val_writer.close()
img, caption, caption_lengths ) # preds => [BATCH_SIZE * (현재 배치에서 유효한 가장 긴 길이) * VOCAB_SIZE] target = captions[:, 1:] # 이전의 start를 뺀 것과 같다. 단 길이별로 정렬되어있음 [BATCH_SIZE, SEQ_LEN(53)] preds = nn.utils.rnn.pack_padded_sequence(preds, pred_length, batch_first=True) target = nn.utils.rnn.pack_padded_sequence( target, pred_length, batch_first=True) loss = criterion(preds.data, target.data) loss = loss + args.ATTENTION_COEF * ( (1.0 - coefs.sum(dim=1))**2).mean() # regularization validation_loss += loss validation_losses.append(validation_loss) logger.write("val_loss : {}\n".format( validation_loss / (len(validation_loader) * args.NUM_CAPTIONS))) if epoch % 5 == 0: torch.save( encoder.state_dict(), os.path.join(args.MODEL_SAVE_PATH, "{}".format(epoch) + "encoder" + args.MODEL_NAME)) torch.save( decoder.state_dict(), os.path.join(args.MODEL_SAVE_PATH, "{}".format(epoch) + "decoder" + args.MODEL_NAME))