def evaluate_hand_draw_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    eval_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up networks
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    azi_classes, ele_classes = int(360 / cfg.CONST.BIN_SIZE), int(180 / cfg.CONST.BIN_SIZE)
    view_estimater = ViewEstimater(cfg, azi_classes=azi_classes, ele_classes=ele_classes)

    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        view_estimater = torch.nn.DataParallel(view_estimater).cuda()

    # Load weights for encoder and decoder
    print('[INFO] %s Loading reconstruction weights from %s ...' %
          (dt.now(), cfg.EVALUATE_HAND_DRAW.RECONSTRUCTION_WEIGHTS))
    rec_checkpoint = torch.load(cfg.EVALUATE_HAND_DRAW.RECONSTRUCTION_WEIGHTS)
    encoder.load_state_dict(rec_checkpoint['encoder_state_dict'])
    decoder.load_state_dict(rec_checkpoint['decoder_state_dict'])
    print('[INFO] Best reconstruction result at epoch %d ...' % rec_checkpoint['epoch_idx'])

    # Load weights for view estimater
    print('[INFO] %s Loading view estimation weights from %s ...' %
          (dt.now(), cfg.EVALUATE_HAND_DRAW.VIEW_ESTIMATION_WEIGHTS))
    view_checkpoint = torch.load(cfg.EVALUATE_HAND_DRAW.VIEW_ESTIMATION_WEIGHTS)
    view_estimater.load_state_dict(view_checkpoint['view_estimator_state_dict'])
    print('[INFO] Best view estimation result at epoch %d ...' % view_checkpoint['epoch_idx'])

    for img_path in os.listdir(cfg.EVALUATE_HAND_DRAW.INPUT_IMAGE_FOLDER):
        eval_id = int(img_path[:-4])
        input_img_path = os.path.join(cfg.EVALUATE_HAND_DRAW.INPUT_IMAGE_FOLDER, img_path)
        print(input_img_path)
        evaluate_hand_draw_img(cfg, encoder, decoder, view_estimater,
                               input_img_path, eval_transforms, eval_id)
def run(ckpt_fpath):
    checkpoint = torch.load(ckpt_fpath)

    """ Load Config """
    config = dict_to_cls(checkpoint['config'])

    """ Build Data Loader """
    if config.corpus == "MSVD":
        corpus = MSVD(config)
    elif config.corpus == "MSR-VTT":
        corpus = MSRVTT(config)
    train_iter, val_iter, test_iter, vocab = \
        corpus.train_data_loader, corpus.val_data_loader, corpus.test_data_loader, corpus.vocab
    print('#vocabs: {} ({}), #words: {} ({}). Trim words which appear less than {} times.'.format(
        vocab.n_vocabs, vocab.n_vocabs_untrimmed, vocab.n_words, vocab.n_words_untrimmed,
        config.loader.min_count))

    """ Build Models """
    decoder = Decoder(rnn_type=config.decoder.rnn_type,
                      num_layers=config.decoder.rnn_num_layers,
                      num_directions=config.decoder.rnn_num_directions,
                      feat_size=config.feat.size,
                      feat_len=config.loader.frame_sample_len,
                      embedding_size=config.vocab.embedding_size,
                      hidden_size=config.decoder.rnn_hidden_size,
                      attn_size=config.decoder.rnn_attn_size,
                      output_size=vocab.n_vocabs,
                      rnn_dropout=config.decoder.rnn_dropout)
    decoder.load_state_dict(checkpoint['decoder'])
    model = CaptionGenerator(decoder, config.loader.max_caption_len, vocab)
    model = model.cuda()

    """ Train Set """
    """
    train_vid2pred = get_predicted_captions(train_iter, model, model.vocab, beam_width=5, beam_alpha=0.)
    train_vid2GTs = get_groundtruth_captions(train_iter, model.vocab)
    train_scores = score(train_vid2pred, train_vid2GTs)
    print("[TRAIN] {}".format(train_scores))
    """

    """ Validation Set """
    """
    val_vid2pred = get_predicted_captions(val_iter, model, model.vocab, beam_width=5, beam_alpha=0.)
    val_vid2GTs = get_groundtruth_captions(val_iter, model.vocab)
    val_scores = score(val_vid2pred, val_vid2GTs)
    print("[VAL] scores: {}".format(val_scores))
    """

    """ Test Set """
    test_vid2pred = get_predicted_captions(test_iter, model, model.vocab, beam_width=5, beam_alpha=0.)
    test_vid2GTs = get_groundtruth_captions(test_iter, model.vocab)
    test_scores = score(test_vid2pred, test_vid2GTs)
    print("[TEST] {}".format(test_scores))

    test_save_fpath = os.path.join(C.result_dpath, "{}_{}.csv".format(config.corpus, 'test'))
    save_result(test_vid2pred, test_vid2GTs, test_save_fpath)
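# `dict_to_cls` above is assumed to turn the checkpoint's nested config dict back
# into an object usable with attribute access (config.decoder.rnn_type, ...).
# A minimal sketch under that assumption -- not necessarily the repo's helper:
class _AttrConfig(object):
    """Hypothetical attribute-access wrapper for a nested dict."""
    pass

def dict_to_cls_sketch(d):
    cfg = _AttrConfig()
    for k, v in d.items():
        setattr(cfg, k, dict_to_cls_sketch(v) if isinstance(v, dict) else v)
    return cfg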
class Visualization_demo():
    def __init__(self, cfg, output_dir):
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)
        self.refiner = Refiner(cfg)
        self.merger = Merger(cfg)

        checkpoint = torch.load(cfg.CHECKPOINT)
        encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.load_state_dict(encoder_state_dict)
        decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
        self.decoder.load_state_dict(decoder_state_dict)
        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = clean_state_dict(checkpoint['refiner_state_dict'])
            self.refiner.load_state_dict(refiner_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = clean_state_dict(checkpoint['merger_state_dict'])
            self.merger.load_state_dict(merger_state_dict)

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        self.output_dir = output_dir

    def run_on_images(self, imgs, sid, mid, iid, sampled_idx):
        # Use self.output_dir; a bare `output_dir` is undefined in this method
        dir1 = os.path.join(self.output_dir, str(sid), str(mid))
        if not os.path.exists(dir1):
            os.makedirs(dir1)

        deprocess = imagenet_deprocess(rescale_image=False)
        image_features = self.encoder(imgs)
        raw_features, generated_volume = self.decoder(image_features)
        generated_volume = self.merger(raw_features, generated_volume)
        generated_volume = self.refiner(generated_volume)

        mesh = cubify(generated_volume, 0.3)
        # mesh = voxel_to_world(meshes)
        save_mesh = os.path.join(dir1, "%s_%s.obj" % (iid, sampled_idx))
        verts, faces = mesh.get_mesh_verts_faces(0)
        save_obj(save_mesh, verts, faces)

        generated_volume = generated_volume.squeeze()

        img = image_to_numpy(deprocess(imgs[0][0]))
        save_img = os.path.join(dir1, "%02d.png" % (iid))
        # cv2.imwrite(save_img, img[:, :, ::-1])
        cv2.imwrite(save_img, img)
        img1 = image_to_numpy(deprocess(imgs[0][1]))
        save_img1 = os.path.join(dir1, "%02d.png" % (sampled_idx))
        cv2.imwrite(save_img1, img1)
        # cv2.imwrite(save_img1, img1[:, :, ::-1])

        get_volume_views(generated_volume, dir1, iid, sampled_idx)
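# `clean_state_dict` is referenced above but not defined here. A minimal sketch,
# assuming it strips the 'module.' prefix that torch.nn.DataParallel prepends to
# parameter names (hypothetical helper, not the repo's verified implementation):
from collections import OrderedDict

def clean_state_dict_sketch(state_dict):
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items())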
def build_model(C, vocab):
    decoder = Decoder(rnn_type=C.decoder.rnn_type,
                      num_layers=C.decoder.rnn_num_layers,
                      num_directions=C.decoder.rnn_num_directions,
                      feat_size=C.feat.size,
                      feat_len=C.loader.frame_sample_len,
                      embedding_size=C.vocab.embedding_size,
                      hidden_size=C.decoder.rnn_hidden_size,
                      attn_size=C.decoder.rnn_attn_size,
                      output_size=vocab.n_vocabs,
                      rnn_dropout=C.decoder.rnn_dropout)
    if C.pretrained_decoder_fpath is not None:
        decoder.load_state_dict(torch.load(C.pretrained_decoder_fpath)['decoder'])
        print("Pretrained decoder is loaded from {}".format(C.pretrained_decoder_fpath))

    # Global and local reconstructors
    if C.reconstructor is None:
        reconstructor = None
    elif C.reconstructor.type == 'global':
        reconstructor = GlobalReconstructor(
            rnn_type=C.reconstructor.rnn_type,
            num_layers=C.reconstructor.rnn_num_layers,
            num_directions=C.reconstructor.rnn_num_directions,
            decoder_size=C.decoder.rnn_hidden_size,
            hidden_size=C.reconstructor.rnn_hidden_size,
            rnn_dropout=C.reconstructor.rnn_dropout)
    else:
        reconstructor = LocalReconstructor(
            rnn_type=C.reconstructor.rnn_type,
            num_layers=C.reconstructor.rnn_num_layers,
            num_directions=C.reconstructor.rnn_num_directions,
            decoder_size=C.decoder.rnn_hidden_size,
            hidden_size=C.reconstructor.rnn_hidden_size,
            attn_size=C.reconstructor.rnn_attn_size,
            rnn_dropout=C.reconstructor.rnn_dropout)
    if C.pretrained_reconstructor_fpath is not None:
        reconstructor.load_state_dict(
            torch.load(C.pretrained_reconstructor_fpath)['reconstructor'])
        print("Pretrained reconstructor is loaded from {}".format(
            C.pretrained_reconstructor_fpath))

    model = CaptionGenerator(decoder, reconstructor, C.loader.max_caption_len, vocab)
    model.cuda()
    return model
class Quantitative_analysis_demo():
    def __init__(self, cfg, output_dir):
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)
        self.refiner = Refiner(cfg)
        self.merger = Merger(cfg)
        # self.thresh = cfg.VOXEL_THRESH
        self.th = cfg.TEST.VOXEL_THRESH

        checkpoint = torch.load(cfg.CHECKPOINT)
        encoder_state_dict = clean_state_dict(checkpoint['encoder_state_dict'])
        self.encoder.load_state_dict(encoder_state_dict)
        decoder_state_dict = clean_state_dict(checkpoint['decoder_state_dict'])
        self.decoder.load_state_dict(decoder_state_dict)
        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = clean_state_dict(checkpoint['refiner_state_dict'])
            self.refiner.load_state_dict(refiner_state_dict)
        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = clean_state_dict(checkpoint['merger_state_dict'])
            self.merger.load_state_dict(merger_state_dict)

        self.output_dir = output_dir

    def calculate_iou(self, imgs, GT_voxels, sid, mid, iid):
        dir1 = os.path.join(self.output_dir, str(sid), str(mid))
        if not os.path.exists(dir1):
            os.makedirs(dir1)

        image_features = self.encoder(imgs)
        raw_features, generated_volume = self.decoder(image_features)
        generated_volume = self.merger(raw_features, generated_volume)
        generated_volume = self.refiner(generated_volume)
        generated_volume = generated_volume.squeeze()

        sample_iou = []
        for th in self.th:
            _volume = torch.ge(generated_volume, th).float()
            intersection = torch.sum(_volume.mul(GT_voxels)).float()
            union = torch.sum(torch.ge(_volume.add(GT_voxels), 1)).float()
            sample_iou.append((intersection / union).item())
        return sample_iou
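# A self-contained check of the thresholded voxel IoU computed in calculate_iou
# above (toy 2x2 tensors; the 0.3 threshold is illustrative):
import torch

pred = torch.tensor([[0.9, 0.2], [0.6, 0.1]])    # predicted occupancy probabilities
gt = torch.tensor([[1.0, 0.0], [1.0, 1.0]])      # ground-truth occupancy

_volume = torch.ge(pred, 0.3).float()            # binarize -> [[1, 0], [1, 0]]
intersection = torch.sum(_volume.mul(gt))        # 2 voxels occupied in both
union = torch.sum(torch.ge(_volume.add(gt), 1))  # 3 voxels occupied in either
iou = (intersection / union).item()              # 2 / 3 ~= 0.667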
def build_model(vocab):
    decoder = Decoder(rnn_type=C.decoder.rnn_type,
                      num_layers=C.decoder.rnn_num_layers,
                      num_directions=C.decoder.rnn_num_directions,
                      feat_size=C.feat.size,
                      feat_len=C.loader.frame_sample_len,
                      embedding_size=C.vocab.embedding_size,
                      hidden_size=C.decoder.rnn_hidden_size,
                      attn_size=C.decoder.rnn_attn_size,
                      output_size=vocab.n_vocabs,
                      rnn_dropout=C.decoder.rnn_dropout)
    if C.pretrained_decoder_fpath is not None:
        decoder.load_state_dict(torch.load(C.pretrained_decoder_fpath)['decoder'])
        print("Pretrained decoder is loaded from {}".format(C.pretrained_decoder_fpath))

    model = CaptionGenerator(decoder, C.loader.max_caption_len, vocab)
    model.cuda()
    return model
def main():
    checkpoint = torch.load(C.model_fpath)

    TC = MockConfig()
    TC_dict = dict(checkpoint['config'].__dict__)
    for key, val in TC_dict.items():
        setattr(TC, key, val)
    TC.build_train_data_loader = False
    TC.build_val_data_loader = False
    TC.build_test_data_loader = True
    TC.build_score_data_loader = True
    TC.test_video_fpath = C.test_video_fpath
    TC.test_caption_fpath = C.test_caption_fpath

    MSVD = _MSVD(TC)
    vocab = MSVD.vocab
    score_data_loader = MSVD.score_data_loader

    decoder = Decoder(
        model_name=TC.decoder_model,
        n_layers=TC.decoder_n_layers,
        encoder_size=TC.encoder_output_size,
        embedding_size=TC.embedding_size,
        embedding_scale=TC.embedding_scale,
        hidden_size=TC.decoder_hidden_size,
        attn_size=TC.decoder_attn_size,
        output_size=vocab.n_vocabs,
        embedding_dropout=TC.embedding_dropout,
        dropout=TC.decoder_dropout,
        out_dropout=TC.decoder_out_dropout,
    )
    decoder = decoder.to(C.device)
    decoder.load_state_dict(checkpoint['dec'])
    decoder.eval()

    scores = evaluate(TC, MSVD, score_data_loader, decoder, ("beam", 5))
    print(scores)
def test_net(cfg,
             epoch_idx=-1,
             output_dir=None,
             test_data_loader=None,
             test_writer=None,
             encoder=None,
             decoder=None,
             refiner=None,
             merger=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Load taxonomies of dataset
    taxonomies = []
    with open(cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH,
              encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST,
                                               cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        refiner = Refiner(cfg)
        merger = Merger(cfg)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        epoch_idx = checkpoint['epoch_idx']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = utils.network_utils.AverageMeter()
    refiner_losses = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    for sample_idx, (taxonomy_id, sample_name, rendering_images,
                     ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]

        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            ground_truth_volume = utils.network_utils.var_or_cuda(ground_truth_volume)

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)
            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volume = refiner(generated_volume)
                refiner_loss = bce_loss(generated_volume, ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss
            print("vox shape {}".format(generated_volume.shape))

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())

            # IoU per sample
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = torch.ge(generated_volume, th).float()
                intersection = torch.sum(_volume.mul(ground_truth_volume)).float()
                union = torch.sum(torch.ge(_volume.add(ground_truth_volume), 1)).float()
                sample_iou.append((intersection / union).item())

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Append generated volumes to TensorBoard
            if output_dir and sample_idx < 3:
                img_dir = output_dir % 'images'
                # Volume Visualization
                gv = generated_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(
                    gv, os.path.join(img_dir, 'test'), epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume Reconstructed' % sample_idx,
                                          rendering_views, epoch_idx)
                gtv = ground_truth_volume.cpu().numpy()
                rendering_views = utils.binvox_visualization.get_volume_views(
                    gtv, os.path.join(img_dir, 'test'), epoch_idx)
                if test_writer is not None:
                    test_writer.add_image('Test Sample#%02d/Volume GroundTruth' % sample_idx,
                                          rendering_views, epoch_idx)

            # Print sample loss and IoU
            print('[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s' %
                  (dt.now(), sample_idx + 1, n_samples, taxonomy_id, sample_name,
                   encoder_loss.item(), refiner_loss.item(), ['%.4f' % si for si in sample_iou]))

    # Output testing results
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'], axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()

    # Print body
    for taxonomy_id in test_iou:
        print('%s' % taxonomies[taxonomy_id]['taxonomy_name'].ljust(8), end='\t')
        print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
        if 'baseline' in taxonomies[taxonomy_id]:
            print('%.4f' % taxonomies[taxonomy_id]['baseline']['%d-view' % cfg.CONST.N_VIEWS_RENDERING],
                  end='\t\t')
        else:
            print('N/a', end='\t\t')
        for ti in test_iou[taxonomy_id]['iou']:
            print('%.4f' % ti, end='\t')
        print()

    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)

    return max_iou
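# `utils.network_utils.AverageMeter` above tracks running losses. A minimal
# sketch of such a meter, assuming the common PyTorch-examples pattern (val,
# sum, count, avg) rather than the repo's exact implementation:
class AverageMeterSketch(object):
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count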
def test_single_img_net(cfg):
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
    checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))
    # Strip the 'module.' prefix that DataParallel adds to parameter names
    fix_checkpoint = {}
    fix_checkpoint['encoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
    fix_checkpoint['decoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())
    fix_checkpoint['refiner_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['refiner_state_dict'].items())
    fix_checkpoint['merger_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['merger_state_dict'].items())

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
    decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])
    if cfg.NETWORK.USE_REFINER:
        print('Use refiner')
        refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(fix_checkpoint['merger_state_dict'])

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = '/media/caig/FECA2C89CA2C406F/dataset/ShapeNetRendering_copy/03001627/1a74a83fa6d24b3cacd67ce2c72c02e/rendering/00.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)

        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
            generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = '/media/caig/FECA2C89CA2C406F/sketch3D/sketch3D/test_output'
    gv = generated_volume.cpu().numpy()
    gv_new = np.swapaxes(gv, 2, 1)
    rendering_views = utils.binvox_visualization.get_volume_views(
        gv_new, os.path.join(img_dir), epoch_idx)
def main():
    # Parse the arguments
    args, _ = parse_arguments()

    # Set up logging
    # output_dir = Path('/content/drive/My Drive/image-captioning/output')
    output_dir = Path(args.output_directory)
    output_dir.mkdir(parents=True, exist_ok=True)
    logfile_path = Path(output_dir / "output.log")
    setup_logging(logfile=logfile_path)

    # Set up and read config.ini
    # config_file = Path('/content/drive/My Drive/image-captioning/config.ini')
    config_file = Path('../config.ini')
    reading_config(config_file)

    # TensorBoard
    tensorboard_logfile = Path(output_dir / 'tensorboard')
    tensorboard_writer = SummaryWriter(tensorboard_logfile)

    # Load dataset
    # dataset_dir = Path('/content/drive/My Drive/Flickr8k_Dataset')
    dataset_dir = Path(args.dataset)
    images_path = Path(dataset_dir / Config.get("images_dir"))
    captions_path = Path(dataset_dir / Config.get("captions_dir"))
    training_loader, validation_loader, testing_loader = data_loaders(images_path, captions_path)

    # Load the model (encoder, decoder, optimizer)
    embed_size = Config.get("encoder_embed_size")
    hidden_size = Config.get("decoder_hidden_size")
    batch_size = Config.get("training_batch_size")
    epochs = Config.get("epochs")
    feature_extraction = Config.get("feature_extraction")
    raw_captions = read_captions(captions_path)
    id_to_word, word_to_id = dictionary(raw_captions, threshold=5)
    vocab_size = len(id_to_word)
    encoder = Encoder(embed_size, feature_extraction)
    decoder = Decoder(embed_size, hidden_size, vocab_size, batch_size)

    # Load pretrained embeddings
    # pretrained_emb_dir = Path('/content/drive/My Drive/word2vec')
    pretrained_emb_dir = Path(args.pretrained_embeddings)
    pretrained_emb_file = Path(pretrained_emb_dir / Config.get("pretrained_emb_path"))
    pretrained_embeddings = load_pretrained_embeddings(pretrained_emb_file, id_to_word)

    # Load the optimizer
    learning_rate = Config.get("learning_rate")
    optimizer = adam_optimizer(encoder, decoder, learning_rate)

    # Loss function
    criterion = cross_entropy

    # Load checkpoint
    checkpoint_file = Path(output_dir / Config.get("checkpoint_file"))
    checkpoint_captioning = load_checkpoint(checkpoint_file)

    # Use the available device (GPU/CPU)
    encoder = encoder.to(Config.get("device"))
    decoder = decoder.to(Config.get("device"))
    pretrained_embeddings = pretrained_embeddings.to(Config.get("device"))

    start_epoch = 1
    if checkpoint_captioning is not None:
        start_epoch = checkpoint_captioning['epoch'] + 1
        encoder.load_state_dict(checkpoint_captioning['encoder'])
        decoder.load_state_dict(checkpoint_captioning['decoder'])
        optimizer.load_state_dict(checkpoint_captioning['optimizer'])
        logger.info('Initialized encoder, decoder and optimizer from loaded checkpoint')
    del checkpoint_captioning

    # Image captioning model
    model = ImageCaptioning(encoder, decoder, optimizer, criterion, training_loader,
                            validation_loader, testing_loader, pretrained_embeddings,
                            output_dir, tensorboard_writer)

    # Training and testing the model
    if args.training:
        validate_every = Config.get("validate_every")
        model.train(epochs, validate_every, start_epoch)
    elif args.testing:
        images_path = Path(images_path / Config.get("images_dir"))
        model.testing(id_to_word, images_path)
def test_single_img(cfg):
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    cfg.CONST.WEIGHTS = 'D:/Pix2Vox/Pix2Vox/pretrained/Pix2Vox-A-ShapeNet.pth'
    checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))
    # Strip the 'module.' prefix that DataParallel adds to parameter names
    fix_checkpoint = {}
    fix_checkpoint['encoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
    fix_checkpoint['decoder_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())
    fix_checkpoint['refiner_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['refiner_state_dict'].items())
    fix_checkpoint['merger_state_dict'] = OrderedDict(
        (k.split('module.')[1:][0], v) for k, v in checkpoint['merger_state_dict'].items())

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
    decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])
    if cfg.NETWORK.USE_REFINER:
        print('Use refiner')
        refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(fix_checkpoint['merger_state_dict'])

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = 'D:/Pix2Vox/Pix2Vox/rand/minecraft.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)

        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
            generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = 'D:/Pix2Vox/Pix2Vox/output'
    gv = generated_volume.cpu().numpy()
    gv_new = np.swapaxes(gv, 2, 1)
    print(gv_new)
    rendering_views = utils.binvox_visualization.get_volume_views(
        gv_new, os.path.join(img_dir), epoch_idx)
vocab = Vocabulary("./captions.json", args.NUM_CAPTIONS, num_fre=args.NUM_FRE)
VOCAB_SIZE = vocab.num_words
SEQ_LEN = vocab.max_sentence_len

encoder = Encoder(args.ENCODER_OUTPUT_SIZE)
decoder = Decoder(embed_size=args.EMBED_SIZE,
                  hidden_size=args.HIDDEN_SIZE,
                  attention_size=args.ATTENTION_SIZE,
                  vocab_size=VOCAB_SIZE,
                  encoder_size=2048,
                  device=device,
                  seq_len=SEQ_LEN + 2)
encoder.load_state_dict(torch.load(args.ENCODER_MODEL_LOAD_PATH))
decoder.load_state_dict(torch.load(args.DECODER_MODEL_LOAD_PATH))
encoder.to(device)
decoder.to(device)
encoder.eval()
decoder.eval()

result_json = {"images": []}
for path in IMG_PATH:
    img_name = path.split("/")[-1]
    img = Image.open(path)
    img = transform(img).unsqueeze(0).to(device)  # [BATCH_SIZE(1) * CHANNEL * INPUT_SIZE * INPUT_SIZE]

    num_sentence = args.NUM_TOP_PROB
    top_prev_prob = torch.zeros((num_sentence, 1)).to(device)
    words = torch.Tensor([vocab.SOS_token]).long().expand(
def test_net(cfg,
             model_type,
             dataset_type,
             results_file_name,
             epoch_idx=-1,
             test_data_loader=None,
             test_writer=None,
             encoder=None,
             decoder=None,
             refiner=None,
             merger=None,
             save_results_to_file=False,
             show_voxels=False,
             path_to_times_csv=None):
    if model_type == Pix2VoxTypes.Pix2Vox_A or model_type == Pix2VoxTypes.Pix2Vox_Plus_Plus_A:
        use_refiner = True
    else:
        use_refiner = False

    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(dataset_type, cfg.CONST.N_VIEWS_RENDERING,
                                               test_transforms),
            batch_size=1,
            num_workers=cfg.CONST.NUM_WORKER,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg, model_type)
        decoder = Decoder(cfg, model_type)
        if use_refiner:
            refiner = Refiner(cfg)
        merger = Merger(cfg, model_type)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            if use_refiner:
                refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        logging.info('Loading weights from %s ...' % (cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        epoch_idx = checkpoint['epoch_idx']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if use_refiner:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = AverageMeter()
    if use_refiner:
        refiner_losses = AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    if use_refiner:
        refiner.eval()
    merger.eval()

    samples_names = []
    edlosses = []
    rlosses = []
    ious_dict = {}
    for iou_threshold in cfg.TEST.VOXEL_THRESH:
        ious_dict[iou_threshold] = []

    if path_to_times_csv is not None:
        n_view_list = []
        times_list = []

    for sample_idx, (taxonomy_id, sample_name, rendering_images,
                     ground_truth_volume) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]

        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.helpers.var_or_cuda(rendering_images)
            ground_truth_volume = utils.helpers.var_or_cuda(ground_truth_volume)

            if path_to_times_csv is not None:
                start_time = time.time()

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)
            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10

            if use_refiner and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volume = refiner(generated_volume)
                refiner_loss = bce_loss(generated_volume, ground_truth_volume) * 10
            else:
                refiner_loss = encoder_loss

            if path_to_times_csv is not None:
                end_time = time.time()
                n_view_list.append(rendering_images.size()[1])
                times_list.append(end_time - start_time)

            # Append loss and accuracy to average metrics
            encoder_losses.update(encoder_loss.item())
            if use_refiner:
                refiner_losses.update(refiner_loss.item())

            # IoU per sample
            sample_iou = []
            for th in cfg.TEST.VOXEL_THRESH:
                _volume = torch.ge(generated_volume, th).float()
                intersection = torch.sum(_volume.mul(ground_truth_volume)).float()
                union = torch.sum(torch.ge(_volume.add(ground_truth_volume), 1)).float()
                sample_iou.append((intersection / union).item())
                ious_dict[th].append((intersection / union).item())

            # IoU per taxonomy
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

            # Display generated and ground-truth volumes in viewvox
            if show_voxels:
                with open("model.binvox", "wb") as f:
                    v = br.Voxels(torch.ge(generated_volume, 0.2).float().cpu().numpy()[0],
                                  (32, 32, 32), (0, 0, 0), 1, "xyz")
                    v.write(f)
                subprocess.run([VIEWVOX_EXE, "model.binvox"])

                with open("model.binvox", "wb") as f:
                    v = br.Voxels(ground_truth_volume.cpu().numpy()[0],
                                  (32, 32, 32), (0, 0, 0), 1, "xyz")
                    v.write(f)
                subprocess.run([VIEWVOX_EXE, "model.binvox"])

            # Print sample loss and IoU
            logging.info('Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s' %
                         (sample_idx + 1, n_samples, taxonomy_id, sample_name,
                          encoder_loss.item(), refiner_loss.item(),
                          ['%.4f' % si for si in sample_iou]))
            samples_names.append(sample_name)
            edlosses.append(encoder_loss.item())
            if use_refiner:
                rlosses.append(refiner_loss.item())

    if save_results_to_file:
        save_test_results_to_csv(samples_names, edlosses, rlosses, ious_dict,
                                 path_to_csv=results_file_name)
    if path_to_times_csv is not None:
        save_times_to_csv(times_list, n_view_list, path_to_csv=path_to_times_csv)

    # Output testing results
    mean_iou = []
    for taxonomy_id in test_iou:
        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'], axis=0)
        mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])
    mean_iou = np.sum(mean_iou, axis=0) / n_samples

    # Print header
    print('============================ TEST RESULTS ============================')
    print('Taxonomy', end='\t')
    print('#Sample', end='\t')
    print('Baseline', end='\t')
    for th in cfg.TEST.VOXEL_THRESH:
        print('t=%.2f' % th, end='\t')
    print()

    # Print mean IoU for each threshold
    print('Overall ', end='\t\t\t\t')
    for mi in mean_iou:
        print('%.4f' % mi, end='\t')
    print('\n')

    # Add testing results to TensorBoard
    max_iou = np.max(mean_iou)
    if test_writer is not None:
        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx)
        if use_refiner:
            test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx)
        test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)

    return max_iou
def train():
    opt = parse_opt()
    train_mode = opt.train_mode
    idx2word = json.load(open(opt.idx2word, 'r'))
    captions = json.load(open(opt.captions, 'r'))

    # Model
    decoder = Decoder(idx2word, opt.settings)
    decoder.to(opt.device)
    lr = opt.learning_rate
    optimizer, xe_criterion = decoder.get_optim_and_crit(lr)
    if opt.resume:
        print("====> loading checkpoint '{}'".format(opt.resume))
        chkpoint = torch.load(opt.resume, map_location=lambda s, l: s)
        assert opt.settings == chkpoint['settings'], \
            'opt.settings and resume model settings are different'
        assert idx2word == chkpoint['idx2word'], \
            'idx2word and resume model idx2word are different'
        decoder.load_state_dict(chkpoint['model'])
        if chkpoint['train_mode'] == train_mode:
            optimizer.load_state_dict(chkpoint['optimizer'])
            lr = optimizer.param_groups[0]['lr']
        print("====> loaded checkpoint '{}', epoch: {}, train_mode: {}".format(
            opt.resume, chkpoint['epoch'], chkpoint['train_mode']))
    elif train_mode == 'rl':
        raise Exception('"rl" mode requires a resume model')

    print('====> process image captions begin')
    word2idx = {}
    for i, w in enumerate(idx2word):
        word2idx[w] = i
    captions_id = {}
    for split, caps in captions.items():
        print('convert %s captions to index' % split)
        captions_id[split] = {}
        for fn, seqs in tqdm.tqdm(caps.items(), ncols=100):
            tmp = []
            for seq in seqs:
                # Map out-of-vocabulary words to <UNK>; dict.get with a default
                # also keeps the word at index 0 intact
                tmp.append([decoder.sos_id] +
                           [word2idx.get(w, word2idx['<UNK>']) for w in seq] +
                           [decoder.eos_id])
            captions_id[split][fn] = tmp
    captions = captions_id
    print('====> process image captions end')

    train_data = get_dataloader(opt.img_feats, captions['train'], decoder.pad_id,
                                opt.max_seq_len, opt.batch_size, opt.num_workers)
    val_data = get_dataloader(opt.img_feats, captions['val'], decoder.pad_id,
                              opt.max_seq_len, opt.batch_size, opt.num_workers,
                              shuffle=False)
    test_captions = {}
    for fn in captions['test']:
        test_captions[fn] = [[]]
    test_data = get_dataloader(opt.img_feats, test_captions, decoder.pad_id,
                               opt.max_seq_len, opt.batch_size, opt.num_workers,
                               shuffle=False)

    if train_mode == 'rl':
        rl_criterion = RewardCriterion()
        ciderd_scorer = get_ciderd_scorer(captions, decoder.sos_id, decoder.eos_id)

    def forward(data, training=True, ss_prob=0.0):
        decoder.train(training)
        loss_val = 0.0
        reward_val = 0.0
        for fns, fc_feats, (caps_tensor, lengths), ground_truth in tqdm.tqdm(data, ncols=100):
            fc_feats = fc_feats.to(opt.device)
            caps_tensor = caps_tensor.to(opt.device)

            if training and train_mode == 'rl':
                sample_captions, sample_logprobs, seq_masks = decoder(
                    fc_feats, sample_max=0, max_seq_len=opt.max_seq_len, mode=train_mode)
                decoder.eval()
                with torch.no_grad():
                    greedy_captions, _, _ = decoder(
                        fc_feats, sample_max=1, max_seq_len=opt.max_seq_len, mode=train_mode)
                decoder.train(training)
                reward = get_self_critical_reward(sample_captions, greedy_captions, fns,
                                                  ground_truth, decoder.sos_id,
                                                  decoder.eos_id, ciderd_scorer)
                loss = rl_criterion(sample_logprobs, seq_masks,
                                    torch.from_numpy(reward).float().to(opt.device))
                reward_val += float(np.mean(reward[:, 0]))
            else:
                pred = decoder(fc_feats, caps_tensor, ss_prob=ss_prob)
                loss = xe_criterion(pred, caps_tensor[:, 1:], lengths)

            loss_val += float(loss)
            if training:
                optimizer.zero_grad()
                loss.backward()
                clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
        return loss_val / len(data), reward_val / len(data)

    checkpoint_dir = os.path.join(opt.checkpoint, train_mode)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    result_dir = os.path.join(opt.result, train_mode)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    previous_loss = None
    for epoch in range(opt.max_epochs):
        print('--------------------epoch: %d' % epoch)
        ss_prob = 0.0
        if epoch > opt.scheduled_sampling_start >= 0:
            frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
            ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                          opt.scheduled_sampling_max_prob)
        train_loss, train_reward = forward(train_data, ss_prob=ss_prob)
        with torch.no_grad():
            val_loss, _ = forward(val_data, training=False)

        if train_mode == 'xe' and previous_loss is not None and val_loss >= previous_loss:
            lr = lr * 0.5
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = val_loss

        if epoch in [0, 5, 10, 15, 20, 25, 29, 30, 35, 39, 40, 45, 49]:
            # test
            results = []
            for fns, fc_feats, _, _ in tqdm.tqdm(test_data, ncols=100):
                fc_feats = fc_feats.to(opt.device)
                for i, fn in enumerate(fns):
                    fc_feat = fc_feats[i]
                    with torch.no_grad():
                        rest, _ = decoder.sample(fc_feat, beam_size=opt.beam_size,
                                                 max_seq_len=opt.max_seq_len)
                    results.append({'image_id': fn, 'caption': rest[0]})
            json.dump(results,
                      open(os.path.join(result_dir, 'result_%d.json' % epoch), 'w'))

        chkpoint = {
            'epoch': epoch,
            'model': decoder.state_dict(),
            'optimizer': optimizer.state_dict(),
            'settings': opt.settings,
            'idx2word': idx2word,
            'train_mode': train_mode,
        }
        checkpoint_path = os.path.join(
            checkpoint_dir,
            'model_%d_%.4f_%s.pth' % (epoch, val_loss, time.strftime('%m%d-%H%M')))
        torch.save(chkpoint, checkpoint_path)
        print('train_loss: %.4f, train_reward: %.4f, val_loss: %.4f' %
              (train_loss, train_reward, val_loss))
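# A small mirror of the scheduled-sampling schedule used in train() above, with
# illustrative parameter values (the real values come from opt.*):
def scheduled_sampling_prob(epoch, start=0, every=5, inc=0.05, max_prob=0.25):
    """ss_prob grows stepwise with the epoch and is capped at max_prob."""
    if not (epoch > start >= 0):
        return 0.0
    frac = (epoch - start) // every
    return min(inc * frac, max_prob)

# scheduled_sampling_prob(3)  -> 0.0   (first increase step not reached yet)
# scheduled_sampling_prob(12) -> 0.10  (frac = 2)
# scheduled_sampling_prob(40) -> 0.25  (capped at max_prob)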
def run(ckpt_fpath):
    checkpoint = torch.load(ckpt_fpath)

    """ Load Config """
    config = dict_to_cls(checkpoint['config'])

    """ Build Data Loader """
    if config.corpus == "MSVD":
        corpus = MSVD(config)
    elif config.corpus == "MSR-VTT":
        corpus = MSRVTT(config)
    train_iter, val_iter, test_iter, vocab = \
        corpus.train_data_loader, corpus.val_data_loader, corpus.test_data_loader, corpus.vocab
    print('#vocabs: {} ({}), #words: {} ({}). Trim words which appear less than {} times.'.format(
        vocab.n_vocabs, vocab.n_vocabs_untrimmed, vocab.n_words, vocab.n_words_untrimmed,
        config.loader.min_count))

    """ Build Models """
    decoder = Decoder(rnn_type=config.decoder.rnn_type,
                      num_layers=config.decoder.rnn_num_layers,
                      num_directions=config.decoder.rnn_num_directions,
                      feat_size=config.feat.size,
                      feat_len=config.loader.frame_sample_len,
                      embedding_size=config.vocab.embedding_size,
                      hidden_size=config.decoder.rnn_hidden_size,
                      attn_size=config.decoder.rnn_attn_size,
                      output_size=vocab.n_vocabs,
                      rnn_dropout=config.decoder.rnn_dropout)
    decoder.load_state_dict(checkpoint['decoder'])

    if config.reconstructor.type == 'global':
        reconstructor = GlobalReconstructor(
            rnn_type=config.reconstructor.rnn_type,
            num_layers=config.reconstructor.rnn_num_layers,
            num_directions=config.reconstructor.rnn_num_directions,
            decoder_size=config.decoder.rnn_hidden_size,
            hidden_size=config.reconstructor.rnn_hidden_size,
            rnn_dropout=config.reconstructor.rnn_dropout)
    else:
        reconstructor = LocalReconstructor(
            rnn_type=config.reconstructor.rnn_type,
            num_layers=config.reconstructor.rnn_num_layers,
            num_directions=config.reconstructor.rnn_num_directions,
            decoder_size=config.decoder.rnn_hidden_size,
            hidden_size=config.reconstructor.rnn_hidden_size,
            attn_size=config.reconstructor.rnn_attn_size,
            rnn_dropout=config.reconstructor.rnn_dropout)
    reconstructor.load_state_dict(checkpoint['reconstructor'])

    model = CaptionGenerator(decoder, reconstructor, config.loader.max_caption_len, vocab)
    model = model.cuda()

    '''
    """ Train Set """
    train_scores, train_refs, train_hypos = score(model, train_iter, vocab)
    save_result(train_refs, train_hypos, C.result_dpath, config.corpus, 'train')
    print("[TRAIN] {}".format(train_scores))

    """ Validation Set """
    val_scores, val_refs, val_hypos = score(model, val_iter, vocab)
    save_result(val_refs, val_hypos, C.result_dpath, config.corpus, 'val')
    print("[VAL] scores: {}".format(val_scores))
    '''

    """ Test Set """
    test_scores, test_refs, test_hypos = score(model, test_iter, vocab)
    save_result(test_refs, test_hypos, C.result_dpath, config.corpus, 'test')
    print("[TEST] {}".format(test_scores))
encoder.load_weights()
encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False

""" Load Facial Decoder """
batchSize = 110
net = Decoder(batchSize)
checkpoint = torch.load("./weights/decoder-iter-4449.pt", map_location=torch.device('cpu'))
net.load_state_dict(checkpoint['net_state_dict'])
net.eval()

""" Load Voice Encoder """
x = Speaker()
net2 = VoiceEncoder(1)
checkpoint = torch.load("./weights/voice-encoder-epoch-16.pt", map_location=torch.device('cpu'))
net2.load_state_dict(checkpoint['net_state_dict'])
net2.eval()
def test_net(cfg,
             epoch_idx=-1,
             output_dir=None,
             test_data_loader=None,
             test_writer=None,
             encoder=None,
             decoder=None,
             refiner=None,
             merger=None):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Load taxonomies of dataset
    taxonomies = []
    with open(cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH,
              encoding='utf-8') as file:
        taxonomies = json.loads(file.read())
    taxonomies = {t['taxonomy_id']: t for t in taxonomies}

    # Set up data loader
    if test_data_loader is None:
        # Set up data augmentation
        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
        test_transforms = utils.data_transforms.Compose([
            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
            utils.data_transforms.ToTensor(),
        ])

        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST,
                                               cfg.CONST.N_VIEWS_RENDERING, test_transforms),
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            shuffle=False)

    # Set up networks
    if decoder is None or encoder is None:
        encoder = Encoder(cfg)
        decoder = Decoder(cfg)
        refiner = Refiner(cfg)
        merger = Merger(cfg)

        if torch.cuda.is_available():
            encoder = torch.nn.DataParallel(encoder).cuda()
            decoder = torch.nn.DataParallel(decoder).cuda()
            refiner = torch.nn.DataParallel(refiner).cuda()
            merger = torch.nn.DataParallel(merger).cuda()

        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        if torch.cuda.is_available():
            checkpoint = torch.load(cfg.CONST.WEIGHTS)
        else:
            map_location = torch.device('cpu')
            checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=map_location)
        epoch_idx = checkpoint['epoch_idx']
        print('Epoch ID of the current model is {}'.format(epoch_idx))
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])

        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Testing loop
    n_samples = len(test_data_loader)
    test_iou = dict()
    encoder_losses = utils.network_utils.AverageMeter()
    refiner_losses = utils.network_utils.AverageMeter()

    # Switch models to evaluation mode
    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    print("test data loader type is {}".format(type(test_data_loader)))
    for sample_idx, (taxonomy_id, sample_name, rendering_images) in enumerate(test_data_loader):
        taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item()
        sample_name = sample_name[0]
        print("sample IDx {}".format(sample_idx))
        print("taxonomy id {}".format(taxonomy_id))

        with torch.no_grad():
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            print("Shape of the loaded images {}".format(rendering_images.shape))

            # Test the encoder, decoder, refiner and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volume = decoder(image_features)

            if cfg.NETWORK.USE_MERGER:
                generated_volume = merger(raw_features, generated_volume)
            else:
                generated_volume = torch.mean(generated_volume, dim=1)

            if cfg.NETWORK.USE_REFINER:
                generated_volume = refiner(generated_volume)
            print("vox shape {}".format(generated_volume.shape))

            gv = generated_volume.cpu().numpy()
            rendering_views = utils.binvox_visualization.get_volume_views(
                gv, os.path.join('./LargeDatasets/inference_images/', 'inference'), sample_idx)
            print("gv shape is {}".format(gv.shape))
            return gv, rendering_images
def main(config):
    print('Starting')
    checkpoints = config.checkpoint.parent.glob(config.checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in config.decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_config = torch.load(config.checkpoint.parent / 'args.pth')[0]
    encoder = Encoder(model_config.encoder)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    generators = []
    generator_ids = []
    for checkpoint in checkpoints:
        decoder = Decoder(model_config.decoder)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()

        generator = SampleGenerator(decoder, config.batch_size, wav_freq=config.rate)
        generators.append(generator)
        generator_ids.append(extract_id(checkpoint))

    xs = []
    assert config.out_dir is not None

    if len(config.sample_dir) == 1 and config.sample_dir[0].is_dir():
        top = config.sample_dir[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = config.sample_dir
    print("File paths to be used:", file_paths)

    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, rate = librosa.load(file_path, sr=config.rate)
            data = helper_functions.mu_law(data)
        elif file_path.suffix == '.h5':
            data = helper_functions.mu_law(h5py.File(file_path, 'r')['wav'][:] / (2 ** 15))
            if data.shape[-1] % config.rate != 0:
                data = data[:-(data.shape[-1] % config.rate)]
            assert data.shape[-1] % config.rate == 0
            print(data.shape)
        else:
            raise Exception(f'Unsupported filetype {file_path}')

        if config.sample_len:
            data = data[:config.sample_len]
        else:
            config.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())
    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_idx, filepath):
        wav = helper_functions.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')
        save_audio(wav.squeeze(),
                   config.out_dir / str(decoder_idx) / filepath.with_suffix('.wav').name,
                   rate=config.rate)

    yy = {}
    with torch.no_grad():
        zz = []
        for xs_batch in torch.split(xs, config.batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

        for i, generator_id in enumerate(generator_ids):
            yy[generator_id] = []
            generator = generators[i]
            for zz_batch in torch.split(zz, config.batch_size):
                print("Batch shape:", zz_batch.shape)
                splits = torch.split(zz_batch, config.split_size, -1)
                audio_data = []
                generator.reset()
                for cond in tqdm.tqdm(splits):
                    audio_data += [generator.generate(cond).cpu()]
                audio_data = torch.cat(audio_data, -1)
                yy[generator_id] += [audio_data]
            yy[generator_id] = torch.cat(yy[generator_id], dim=0)

            for sample_result, filepath in zip(yy[generator_id], file_paths):
                save(sample_result, generator_id, filepath)
            del generator
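# `helper_functions.mu_law` / `inv_mu_law` above compand audio in [-1, 1]. A
# minimal numpy sketch of the standard mu = 255 transform -- an assumption about
# these helpers, not their verified implementation:
import numpy as np

def mu_law_sketch(x, mu=255):
    """Compress a waveform in [-1, 1] with mu-law companding."""
    return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)

def inv_mu_law_sketch(y, mu=255):
    """Invert mu-law companding back to linear amplitude."""
    return np.sign(y) * ((1.0 + mu) ** np.abs(y) - 1.0) / mu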
def main(_run, _config, _log):
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                    exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)
    device = torch.device(f"cuda:{_config['gpu_id']}")

    _log.info('###### Create model ######')
    resize_dim = _config['input_size']
    encoded_h = int(resize_dim[0] / 2 ** _config['n_pool'])
    encoded_w = int(resize_dim[1] / 2 ** _config['n_pool'])

    s_encoder = SupportEncoder(_config['path']['init_path'], device)  # .to(device)
    q_encoder = QueryEncoder(_config['path']['init_path'], device)  # .to(device)
    decoder = Decoder(input_res=(encoded_h, encoded_w), output_res=resize_dim).to(device)

    checkpoint = torch.load(_config['snapshot'], map_location='cpu')
    s_encoder.load_state_dict(checkpoint['s_encoder'])
    q_encoder.load_state_dict(checkpoint['q_encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    # initializer.eval()
    # encoder.eval()
    # convlstmcell.eval()
    # decoder.eval()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = meta_data
    max_label = 1
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    testloader = DataLoader(
        dataset=ts_dataset,
        batch_size=1,
        shuffle=False,
        # num_workers=_config['n_work'],
        pin_memory=False,  # True
        drop_last=False)
    # all_samples = test_loader_Spleen()
    # all_samples = test_loader_Prostate()

    if _config['record']:
        _log.info('###### define tensorboard writer #####')
        board_name = f'board/test_{_config["board"]}_{date()}'
        writer = SummaryWriter(board_name)

    _log.info('###### Testing begins ######')
    # metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    img_cnt = 0
    # length = len(all_samples)
    length = len(testloader)
    img_lists = []
    pred_lists = []
    label_lists = []

    saves = {}
    for subj_idx in range(len(ts_dataset.get_cnts())):
        saves[subj_idx] = []

    with torch.no_grad():
        loss_valid = 0
        batch_i = 0  # use only 1 batch size for testing
        for i, sample_test in enumerate(testloader):  # even for upward, down for downward
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            img_list = []
            pred_list = []
            label_list = []
            preds = []
            fnames = sample_test['q_fname']

            s_x = sample_test['s_x'].to(device)  # [B, slice_num, 1, 256, 256]
            s_y = sample_test['s_y'].to(device)  # [B, slice_num, 1, 256, 256]
            q_x = sample_test['q_x'].to(device)  # [B, slice_num, 1, 256, 256]
            q_y = sample_test['q_y'].to(device)  # [B, slice_num, 1, 256, 256]

            s_xi = s_x[:, :, 0, :, :, :]  # [B, Support, 1, 256, 256]
            s_yi = s_y[:, :, 0, :, :, :]
            for s_idx in range(_config["n_shot"]):
                s_x_merge = s_xi.view(s_xi.size(0) * s_xi.size(1), 1, 256, 256)
                s_y_merge = s_yi.view(s_yi.size(0) * s_yi.size(1), 1, 256, 256)
                s_xi_encode_merge, _ = s_encoder(s_x_merge, s_y_merge)  # [B*S, 512, w, h]
            s_xi_encode = s_xi_encode_merge.view(s_yi.size(0), s_yi.size(1), 512,
                                                 encoded_w, encoded_h)
            s_xi_encode_avg = torch.mean(s_xi_encode, dim=1)
            # s_xi_encode, _ = s_encoder(s_xi, s_yi)  # [B, 512, w, h]
            q_xi = q_x[:, 0, :, :, :]
            q_yi = q_y[:, 0, :, :, :]
            q_xi_encode, q_ft_list = q_encoder(q_xi)
            sq_xi = torch.cat((s_xi_encode_avg, q_xi_encode), dim=1)
            yhati = decoder(sq_xi, q_ft_list)  # [B, 1, 256, 256]

            preds.append(yhati.round())
            img_list.append(q_xi[batch_i].cpu().numpy())
            pred_list.append(yhati[batch_i].round().cpu().numpy())
            label_list.append(q_yi[batch_i].cpu().numpy())
            saves[subj_idx].append([subj_idx, idx, img_list, pred_list, label_list, fnames])
            print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t", end='\r')
            img_lists.append(img_list)
            pred_lists.append(pred_list)
            label_lists.append(label_list)

    print("start computing dice similarities ... total ", len(saves))
    dice_similarities = []
    for subj_idx in range(len(saves)):
        imgs, preds, labels = [], [], []
        save_subj = saves[subj_idx]
        for i in range(len(save_subj)):
            # print(len(save_subj), len(save_subj)-q_slice_n+1, q_slice_n, i)
            subj_idx, idx, img_list, pred_list, label_list, fnames = save_subj[i]
            # print(subj_idx, idx, is_reverse, len(img_list))
            # print(i, is_reverse, is_reverse_next, is_flip)
            for j in range(len(img_list)):
                imgs.append(img_list[j])
                preds.append(pred_list[j])
                labels.append(label_list[j])

        # pdb.set_trace()
        img_arr = np.concatenate(imgs, axis=0)
        pred_arr = np.concatenate(preds, axis=0)
        label_arr = np.concatenate(labels, axis=0)
        # print(ts_dataset.slice_cnts[subj_idx], len(imgs))
        # pdb.set_trace()

        dice = np.sum([label_arr * pred_arr]) * 2.0 / (np.sum(pred_arr) + np.sum(label_arr))
        dice_similarities.append(dice)
        print(f"computing dice scores {subj_idx}/{10}", end='\n')

        if _config['record']:
            frames = []
            for frame_id in range(0, len(save_subj)):
                frames += overlay_color(torch.tensor(imgs[frame_id]),
                                        torch.tensor(preds[frame_id]),
                                        torch.tensor(labels[frame_id]),
                                        scale=_config['scale'])
            visual = make_grid(frames, normalize=True, nrow=5)
            writer.add_image(f"test/{subj_idx}", visual, i)
            writer.add_scalar(f'dice_score/{i}', dice)

        if _config['save_sample']:  # only for internal test (BCV - MICCAI2015)
            sup_idx = _config['s_idx']
            target = _config['target']
            save_name = _config['save_name']
            dirs = ["gt", "pred", "input"]
            save_dir = f"/user/home2/soopil/tmp/PANet/MICCAI2015/sample/fss1000_organ{target}_sup{sup_idx}_{save_name}"
            for dir in dirs:
                try:
                    os.makedirs(os.path.join(save_dir, dir))
                except:
                    pass

            subj_name = fnames[0][0].split("/")[-2]
            if target == 14:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Cervix/RawData/Training/img"
                orig_fname = f"{src_dir}/{subj_name}-Image.nii.gz"
            else:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Abdomen/RawData/Training/img"
                orig_fname = f"{src_dir}/img{subj_name}.nii.gz"
            itk = sitk.ReadImage(orig_fname)
            orig_spacing = itk.GetSpacing()

            label_arr = label_arr * 2.0
            # label_arr = np.concatenate([np.zeros([1,256,256]), label_arr, np.zeros([1,256,256])])
            # pred_arr = np.concatenate([np.zeros([1,256,256]), pred_arr, np.zeros([1,256,256])])
            # img_arr = np.concatenate([np.zeros([1,256,256]), img_arr, np.zeros([1,256,256])])
            # pdb.set_trace()
            itk = sitk.GetImageFromArray(label_arr)
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/gt/{subj_idx}.nii.gz")

            itk = sitk.GetImageFromArray(pred_arr.astype(float))
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/pred/{subj_idx}.nii.gz")

            itk = sitk.GetImageFromArray(img_arr.astype(float))
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/input/{subj_idx}.nii.gz")

    print(f"test result \n n : {len(dice_similarities)}, mean dice score : \
{np.mean(dice_similarities)} \n dice similarities : {dice_similarities}")
    if _config['record']:
        writer.add_scalar(f'dice_score/mean', np.mean(dice_similarities))
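# A self-contained check of the Dice similarity used above, on toy arrays:
import numpy as np

pred_arr = np.array([[1., 1., 0.], [0., 1., 0.]])
label_arr = np.array([[1., 0., 0.], [0., 1., 1.]])

# dice = 2 * |pred AND label| / (|pred| + |label|)
dice = np.sum(label_arr * pred_arr) * 2.0 / (np.sum(pred_arr) + np.sum(label_arr))
# intersection = 2, |pred| = 3, |label| = 3  ->  dice = 4/6 ~= 0.667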
def train_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data augmentation
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS,
                                          cfg.TRAIN.CONTRAST,
                                          cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loaders
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.TRAIN,
            cfg.CONST.N_VIEWS_RENDERING, train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        num_workers=cfg.TRAIN.NUM_WORKER,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    val_data_loader = torch.utils.data.DataLoader(
        dataset=val_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.VAL,
            cfg.CONST.N_VIEWS_RENDERING, val_transforms),
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        shuffle=False)

    # Set up networks
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)
    print('[DEBUG] %s Parameters in Encoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(encoder)))
    print('[DEBUG] %s Parameters in Decoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(decoder)))
    print('[DEBUG] %s Parameters in Refiner: %d.' %
          (dt.now(), utils.network_utils.count_parameters(refiner)))
    print('[DEBUG] %s Parameters in Merger: %d.' %
          (dt.now(), utils.network_utils.count_parameters(merger)))

    # Initialize weights of networks
    encoder.apply(utils.network_utils.init_weights)
    decoder.apply(utils.network_utils.init_weights)
    refiner.apply(utils.network_utils.init_weights)
    merger.apply(utils.network_utils.init_weights)

    # Set up solvers
    if cfg.TRAIN.POLICY == 'adam':
        encoder_solver = torch.optim.Adam(
            filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
            betas=cfg.TRAIN.BETAS)
        decoder_solver = torch.optim.Adam(decoder.parameters(),
                                          lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        refiner_solver = torch.optim.Adam(refiner.parameters(),
                                          lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        merger_solver = torch.optim.Adam(merger.parameters(),
                                         lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                         betas=cfg.TRAIN.BETAS)
    elif cfg.TRAIN.POLICY == 'sgd':
        encoder_solver = torch.optim.SGD(
            filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
            momentum=cfg.TRAIN.MOMENTUM)
        decoder_solver = torch.optim.SGD(decoder.parameters(),
                                         lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        refiner_solver = torch.optim.SGD(refiner.parameters(),
                                         lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        merger_solver = torch.optim.SGD(merger.parameters(),
                                        lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                        momentum=cfg.TRAIN.MOMENTUM)
    else:
        raise Exception('[FATAL] %s Unknown optimizer %s.' %
                        (dt.now(), cfg.TRAIN.POLICY))

    # Set up learning rate schedulers to decay learning rates dynamically
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        encoder_solver,
        milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        decoder_solver,
        milestones=cfg.TRAIN.DECODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    refiner_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        refiner_solver,
        milestones=cfg.TRAIN.REFINER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    merger_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        merger_solver,
        milestones=cfg.TRAIN.MERGER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)

    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        refiner = torch.nn.DataParallel(refiner).cuda()
        merger = torch.nn.DataParallel(merger).cuda()

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Load pretrained model if it exists
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        print('[INFO] %s Recovering from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        init_epoch = checkpoint['epoch_idx']
        best_iou = checkpoint['best_iou']
        best_epoch = checkpoint['best_epoch']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])
        print('[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.'
              % (dt.now(), init_epoch, best_iou, best_epoch))

    # Summary writers for TensorBoard
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
    log_dir = output_dir % 'logs'
    ckpt_dir = output_dir % 'checkpoints'
    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
    val_writer = SummaryWriter(os.path.join(log_dir, 'test'))

    # Training loop
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock
        epoch_start_time = time()

        # Batch average metrics
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        encoder_losses = utils.network_utils.AverageMeter()
        refiner_losses = utils.network_utils.AverageMeter()

        # Adjust learning rate
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        refiner_lr_scheduler.step()
        merger_lr_scheduler.step()

        # Switch models to training mode
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (taxonomy_names, sample_names, rendering_images,
                        ground_truth_volumes) in enumerate(train_data_loader):
            # Measure data time
            data_time.update(time() - batch_end_time)

            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)
            ground_truth_volumes = utils.network_utils.var_or_cuda(
                ground_truth_volumes)

            # Train the encoder, decoder, refiner, and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                generated_volumes = torch.mean(generated_volumes, dim=1)
            encoder_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10
            else:
                refiner_loss = encoder_loss

            # Gradient descent
            encoder.zero_grad()
            decoder.zero_grad()
            refiner.zero_grad()
            merger.zero_grad()

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()

            encoder_solver.step()
            decoder_solver.step()
            refiner_solver.step()
            merger_solver.step()

            # Append loss to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())
            # Append loss to TensorBoard
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar('EncoderDecoder/BatchLoss',
                                    encoder_loss.item(), n_itr)
            train_writer.add_scalar('Refiner/BatchLoss', refiner_loss.item(),
                                    n_itr)

            # Tick / tock
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            print('[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
                  % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES,
                     batch_idx + 1, n_batches, batch_time.val, data_time.val,
                     encoder_loss.item(), refiner_loss.item()))

        # Append epoch loss to TensorBoard
        train_writer.add_scalar('EncoderDecoder/EpochLoss',
                                encoder_losses.avg, epoch_idx + 1)
        train_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg,
                                epoch_idx + 1)

        # Tick / tock
        epoch_end_time = time()
        print('[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
              % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES,
                 epoch_end_time - epoch_start_time, encoder_losses.avg,
                 refiner_losses.avg))

        # Update rendering views for the next epoch
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d'
                  % (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES,
                     n_views_rendering))

        # Validate the training models
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader,
                       val_writer, encoder, decoder, refiner, merger)

        # Save weights to file
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)
            utils.network_utils.save_checkpoints(
                cfg,
                os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth' % (epoch_idx + 1)),
                epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver,
                refiner, refiner_solver, merger, merger_solver, best_iou,
                best_epoch)
        if iou > best_iou:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)
            best_iou = iou
            best_epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(
                cfg, os.path.join(ckpt_dir, 'best-ckpt.pth'), epoch_idx + 1,
                encoder, encoder_solver, decoder, decoder_solver, refiner,
                refiner_solver, merger, merger_solver, best_iou, best_epoch)

    # Close SummaryWriters for TensorBoard
    train_writer.close()
    val_writer.close()
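# train_net above leans on utils.network_utils.AverageMeter for its running
# statistics. The project's own helper is not shown in this file; a minimal
# sketch consistent with how it is used here (.update(val), .val, .avg)
# could look like this:
class AverageMeter(object):
    """Tracks the most recent value and a running average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count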
def main(_run, _config, _log):
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(os.path.dirname(
            f'{_run.observers[0].dir}/source/{source_file}'),
                    exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)
    device = torch.device(f"cuda:{_config['gpu_id']}")

    _log.info('###### Create model ######')
    resize_dim = _config['input_size']
    encoded_h = int(resize_dim[0] / 2**_config['n_pool'])
    encoded_w = int(resize_dim[1] / 2**_config['n_pool'])

    s_encoder = SupportEncoder(_config['path']['init_path'], device)
    q_encoder = QueryEncoder(_config['path']['init_path'], device)
    decoder = Decoder(input_res=(encoded_h, encoded_w),
                      output_res=resize_dim).to(device)

    checkpoint = torch.load(_config['snapshot'], map_location='cpu')
    s_encoder.load_state_dict(checkpoint['s_encoder'])
    q_encoder.load_state_dict(checkpoint['q_encoder'])
    decoder.load_state_dict(checkpoint['decoder'])

    # Switch models to evaluation mode for testing
    s_encoder.eval()
    q_encoder.eval()
    decoder.eval()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = meta_data
    max_label = 1
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    testloader = DataLoader(dataset=ts_dataset,
                            batch_size=1,
                            shuffle=False,
                            pin_memory=False,
                            drop_last=False)

    _log.info('###### Testing begins ######')
    length = len(testloader)
    img_lists = []
    pred_lists = []
    label_lists = []
    saves = {}
    for subj_idx in range(len(ts_dataset.get_cnts())):
        saves[subj_idx] = []

    with torch.no_grad():
        batch_i = 0  # only batch size 1 is used for testing
        for i, sample_test in enumerate(testloader):
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            img_list = []
            pred_list = []
            label_list = []
            preds = []

            s_x = sample_test['s_x'].to(device)  # [B, slice_num, 1, 256, 256]
            s_y = sample_test['s_y'].to(device)  # [B, slice_num, 1, 256, 256]
            q_x = sample_test['q_x'].to(device)  # [B, slice_num, 1, 256, 256]
            q_y = sample_test['q_y'].to(device)  # [B, slice_num, 1, 256, 256]
            s_fname = sample_test['s_fname']
            q_fname = sample_test['q_fname']

            # Encode the support image/mask pair and the query image,
            # then decode their concatenated features into a mask.
            s_xi = s_x[:, 0, :, :, :]  # [B, 1, 256, 256]
            s_yi = s_y[:, 0, :, :, :]
            s_xi_encode, _ = s_encoder(s_xi, s_yi)  # [B, 512, w, h]
            q_xi = q_x[:, 0, :, :, :]
            q_yi = q_y[:, 0, :, :, :]
            q_xi_encode, q_ft_list = q_encoder(q_xi)
            sq_xi = torch.cat((s_xi_encode, q_xi_encode), dim=1)
            yhati = decoder(sq_xi, q_ft_list)  # [B, 1, 256, 256]
            preds.append(yhati.round())

            img_list.append(q_xi[batch_i].cpu().numpy())
            pred_list.append(yhati[batch_i].round().cpu().numpy())
            label_list.append(q_yi[batch_i].cpu().numpy())
            saves[subj_idx].append(
                [subj_idx, idx, img_list, pred_list, label_list])
            print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t", end='\r')
            img_lists.append(img_list)
            pred_lists.append(pred_list)
            label_lists.append(label_list)

            # Save the predicted mask alongside the query file.
            q_fname_split = q_fname[0][0].split("/")
            q_fname_split[-6] = "Training_2d_2_pred"
            try_mkdirs("/".join(q_fname_split[:-1]))
            o_q_fname = "/".join(q_fname_split)
            np.save(o_q_fname, yhati.round().cpu().numpy())

    try_mkdirs("figure")
    print("start computing dice similarities ... total ", len(saves))
    for subj_idx in range(len(saves)):
        save_subj = saves[subj_idx]
        dices = []
        for slice_idx in range(len(save_subj)):
            subj_idx, idx, img_list, pred_list, label_list = save_subj[slice_idx]
            for j in range(len(img_list)):
                # Per-slice Dice coefficient
                dice = np.sum(label_list[j] * pred_list[j]) * 2.0 / (
                    np.sum(pred_list[j]) + np.sum(label_list[j]))
                dices.append(dice)
        plt.clf()
        plt.bar([k for k in range(len(dices))], dices)
        plt.savefig(f"figure/bar_{_config['target']}_{subj_idx}.png")
from models.encoder import Encoder

opt = parse_opt()
assert opt.test_model, 'please specify test_model'
assert opt.image_file, 'please specify image_file'

# Encode the input image into fc / attention features
encoder = Encoder(opt.resnet101_file)
encoder.to(opt.device)
encoder.eval()

img = skimage.io.imread(opt.image_file)
with torch.no_grad():
    img = encoder.preprocess(img)
    img = img.to(opt.device)
    fc_feat, att_feat = encoder(img)

print("====> loading checkpoint '{}'".format(opt.test_model))
chkpoint = torch.load(opt.test_model, map_location=lambda s, l: s)
decoder = Decoder(chkpoint['idx2word'], chkpoint['settings'])
decoder.load_state_dict(chkpoint['model'])
print("====> loaded checkpoint '{}', epoch: {}, train_mode: {}".format(
    opt.test_model, chkpoint['epoch'], chkpoint['train_mode']))
decoder.to(opt.device)
decoder.eval()

# Decode the image features into captions with beam search
rest, _ = decoder.sample(fc_feat,
                         att_feat,
                         beam_size=opt.beam_size,
                         max_seq_len=opt.max_seq_len)
print('generated captions:\n' + '\n'.join(rest))
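# A hypothetical invocation of the sampling script above (the script name
# and flag spellings are assumptions based on the opt fields it reads, not
# confirmed by this codebase):
#   python sample.py --test_model checkpoints/best.pth \
#                    --image_file demo.jpg \
#                    --beam_size 5 --max_seq_len 20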
def test_img(cfg):
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    cfg.CONST.WEIGHTS = '/Users/pranavpomalapally/Downloads/new-Pix2Vox-A-ShapeNet.pth'
    # Load the checkpoint on CPU since no GPU is assumed here.
    checkpoint = torch.load(cfg.CONST.WEIGHTS,
                            map_location=torch.device('cpu'))
    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    print('Use refiner')
    refiner.load_state_dict(checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        print('Use merger')
        merger.load_state_dict(checkpoint['merger_state_dict'])

    encoder.eval()
    decoder.eval()
    refiner.eval()
    merger.eval()

    img1_path = '/Users/pranavpomalapally/Downloads/09 copy.png'
    img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
    sample = np.array([img1_np])

    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    rendering_images = test_transforms(rendering_images=sample)
    rendering_images = rendering_images.unsqueeze(0)

    with torch.no_grad():
        image_features = encoder(rendering_images)
        raw_features, generated_volume = decoder(image_features)
        if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)
        # The refiner is applied unconditionally here (the usual
        # USE_REFINER / epoch check is intentionally skipped).
        generated_volume = refiner(generated_volume)

    generated_volume = generated_volume.squeeze(0)

    img_dir = '/Users/pranavpomalapally/Downloads/outputs'
    gv = generated_volume.cpu().numpy()
    gv_new = np.swapaxes(gv, 2, 1)
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # macOS OpenMP workaround
    rendering_views = utils.binvox_visualization.get_volume_views(
        gv_new, img_dir, epoch_idx)
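# The decoder/refiner emit per-voxel occupancy probabilities in [0, 1]. If a
# binary voxel grid is needed downstream of test_img, a simple threshold
# works; the function name and the 0.4 cut-off are illustrative choices,
# not values taken from this codebase:
import numpy as np

def binarize_volume(volume, threshold=0.4):
    """Threshold a [0, 1] occupancy volume into a binary voxel grid."""
    return (np.asarray(volume) >= threshold).astype(np.uint8)

# e.g. binarize_volume(gv_new).sum() gives the occupied-voxel count.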