def main(args):
    logger = setup_logger("Listen_to_look, classification", args.checkpoint_path, True)
    logger.debug(args)

    writer = None
    if args.visualization:
        writer = setup_tbx(args.checkpoint_path, True)
    if writer is not None:
        logger.info("Allowed Tensorboard writer")

    # create model
    builder = ModelBuilder()
    net_classifier = builder.build_classifierNet(args.embedding_size,
                                                 args.num_classes).cuda()
    net_imageAudioClassify = builder.build_imageAudioClassifierNet(
        net_classifier, args).cuda()
    model = builder.build_audioPreviewLSTM(net_classifier, args)

    # define loss function (criterion) and optimizer
    criterion = {}
    criterion['CrossEntropyLoss'] = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    checkpointer = Checkpointer(model)
    if args.pretrained_model is not None:
        if not os.path.isfile(args.pretrained_model):
            list_of_models = glob.glob(
                os.path.join(args.pretrained_model, "*.pth"))
            args.pretrained_model = max(list_of_models, key=os.path.getctime)
        logger.debug("Loading model only at: {}".format(args.pretrained_model))
        checkpointer.load_model_only(f=args.pretrained_model)

    model = torch.nn.parallel.DataParallel(model).cuda()

    # DATA LOADING
    val_ds, val_collate = create_validation_dataset(args, logger=logger)
    val_loader = torch.utils.data.DataLoader(val_ds,
                                             batch_size=args.batch_size,
                                             num_workers=args.decode_threads,
                                             collate_fn=val_collate)

    video_mean_ap, video_acc, loss_avg, gtprediction_mean_ap = validate(
        args, 117, val_loader, model, criterion, val_ds=val_ds)

    print("Testing Summary for checkpoint: {}\n"
          "video accuracy/mAP/gt mAP: {} \t {} \t {}\n".format(
              args.pretrained_model, video_acc * 100, video_mean_ap * 100,
              gtprediction_mean_ap * 100))
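
# A hypothetical entry point for the test script above. The flag names mirror the
# args.* attributes referenced in main(), but the real option parser lives elsewhere
# in the repo; the defaults here are placeholders, not the repo's values.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Listen to Look classification test")
    parser.add_argument("--checkpoint_path", type=str, required=True)
    parser.add_argument("--pretrained_model", type=str, default=None)
    parser.add_argument("--embedding_size", type=int, default=512)
    parser.add_argument("--num_classes", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--decode_threads", type=int, default=4)
    parser.add_argument("--visualization", action="store_true")
    main(parser.parse_args())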
def main():
    #load test arguments
    opt = TestOptions().parse()
    opt.device = torch.device("cuda")

    # Network Builders
    builder = ModelBuilder()
    net_lipreading = builder.build_lipreadingnet(
        config_path=opt.lipreading_config_path,
        weights=opt.weights_lipreadingnet,
        extract_feats=opt.lipreading_extract_feature)
    #if identity feature dim is not 512, for resnet reduce dimension to this feature dim
    if opt.identity_feature_dim != 512:
        opt.with_fc = True
    else:
        opt.with_fc = False
    net_facial_attributes = builder.build_facial(
        pool_type=opt.visual_pool,
        fc_out=opt.identity_feature_dim,
        with_fc=opt.with_fc,
        weights=opt.weights_facial)
    net_unet = builder.build_unet(
        ngf=opt.unet_ngf,
        input_nc=opt.unet_input_nc,
        output_nc=opt.unet_output_nc,
        audioVisual_feature_dim=opt.audioVisual_feature_dim,
        identity_feature_dim=opt.identity_feature_dim,
        weights=opt.weights_unet)
    net_vocal_attributes = builder.build_vocal(pool_type=opt.audio_pool,
                                               input_channel=2,
                                               with_fc=opt.with_fc,
                                               fc_out=opt.identity_feature_dim,
                                               weights=opt.weights_vocal)

    nets = (net_lipreading, net_facial_attributes, net_unet,
            net_vocal_attributes)
    print(nets)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()
    mtcnn = MTCNN(keep_all=True, device=opt.device)

    lipreading_preprocessing_func = get_preprocessing_pipelines()['test']
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    vision_transform_list = [transforms.ToTensor()]
    if opt.normalization:
        vision_transform_list.append(normalize)
    vision_transform = transforms.Compose(vision_transform_list)

    # load data
    mouthroi_1 = load_mouthroi(opt.mouthroi1_path)
    mouthroi_2 = load_mouthroi(opt.mouthroi2_path)

    _, audio1 = wavfile.read(opt.audio1_path)
    _, audio2 = wavfile.read(opt.audio2_path)
    _, audio_offscreen = wavfile.read(opt.offscreen_audio_path)
    audio1 = audio1 / 32768
    audio2 = audio2 / 32768
    audio_offscreen = audio_offscreen / 32768

    #make sure the two audios are of the same length and then mix them
    audio_length = min(min(len(audio1), len(audio2)), len(audio_offscreen))
    audio1 = clip_audio(audio1[:audio_length])
    audio2 = clip_audio(audio2[:audio_length])
    audio_offscreen = clip_audio(audio_offscreen[:audio_length])
    audio_mix = (audio1 + audio2 + audio_offscreen * opt.noise_weight) / 3

    if opt.reliable_face:
        best_score_1 = 0
        best_score_2 = 0
        for i in range(10):
            frame_1 = load_frame(opt.video1_path)
            frame_2 = load_frame(opt.video2_path)
            boxes, scores = mtcnn.detect(frame_1)
            if scores[0] > best_score_1:
                best_frame_1 = frame_1
                best_score_1 = scores[0]  # keep track of the best score seen so far
            boxes, scores = mtcnn.detect(frame_2)
            if scores[0] > best_score_2:
                best_frame_2 = frame_2
                best_score_2 = scores[0]
        frames_1 = vision_transform(best_frame_1).squeeze().unsqueeze(0)
        frames_2 = vision_transform(best_frame_2).squeeze().unsqueeze(0)
    else:
        frame_1_list = []
        frame_2_list = []
        for i in range(opt.number_of_identity_frames):
            frame_1 = load_frame(opt.video1_path)
            frame_2 = load_frame(opt.video2_path)
            frame_1 = vision_transform(frame_1)
            frame_2 = vision_transform(frame_2)
            frame_1_list.append(frame_1)
            frame_2_list.append(frame_2)
        frames_1 = torch.stack(frame_1_list).squeeze().unsqueeze(0)
        frames_2 = torch.stack(frame_2_list).squeeze().unsqueeze(0)

    #perform separation over the whole audio using a sliding window approach
    overlap_count = np.zeros((audio_length))
    sep_audio1 = np.zeros((audio_length))
    sep_audio2 = np.zeros((audio_length))
    sliding_window_start = 0
    data = {}
    avged_sep_audio1 = np.zeros((audio_length))
    avged_sep_audio2 = np.zeros((audio_length))
    samples_per_window = int(opt.audio_length * opt.audio_sampling_rate)
    while sliding_window_start + samples_per_window < audio_length:
        sliding_window_end = sliding_window_start + samples_per_window

        #get audio spectrogram
        segment1_audio = audio1[sliding_window_start:sliding_window_end]
        segment2_audio = audio2[sliding_window_start:sliding_window_end]
        segment_offscreen = audio_offscreen[
            sliding_window_start:sliding_window_end]

        if opt.audio_normalization:
            normalizer1, segment1_audio = audio_normalize(segment1_audio)
            normalizer2, segment2_audio = audio_normalize(segment2_audio)
            _, segment_offscreen = audio_normalize(segment_offscreen)
        else:
            normalizer1 = 1
            normalizer2 = 1

        audio_segment = (segment1_audio + segment2_audio +
                         segment_offscreen * opt.noise_weight) / 3
        audio_mix_spec = generate_spectrogram_complex(audio_segment,
                                                      opt.window_size,
                                                      opt.hop_size, opt.n_fft)
        audio_spec_1 = generate_spectrogram_complex(segment1_audio,
                                                    opt.window_size,
                                                    opt.hop_size, opt.n_fft)
        audio_spec_2 = generate_spectrogram_complex(segment2_audio,
                                                    opt.window_size,
                                                    opt.hop_size, opt.n_fft)

        #get mouthroi
        frame_index_start = int(
            round(sliding_window_start / opt.audio_sampling_rate * 25))
        segment1_mouthroi = mouthroi_1[frame_index_start:(
            frame_index_start + opt.num_frames), :, :]
        segment2_mouthroi = mouthroi_2[frame_index_start:(
            frame_index_start + opt.num_frames), :, :]

        #transform mouthrois
        segment1_mouthroi = lipreading_preprocessing_func(segment1_mouthroi)
        segment2_mouthroi = lipreading_preprocessing_func(segment2_mouthroi)

        data['audio_spec_mix1'] = torch.FloatTensor(audio_mix_spec).unsqueeze(0)
        data['mouthroi_A1'] = torch.FloatTensor(segment1_mouthroi).unsqueeze(
            0).unsqueeze(0)
        data['mouthroi_B'] = torch.FloatTensor(segment2_mouthroi).unsqueeze(
            0).unsqueeze(0)
        data['audio_spec_A1'] = torch.FloatTensor(audio_spec_1).unsqueeze(0)
        data['audio_spec_B'] = torch.FloatTensor(audio_spec_2).unsqueeze(0)
        data['frame_A'] = frames_1
        data['frame_B'] = frames_2
        #don't care below
        data['frame_A'] = frames_1
        data['mouthroi_A2'] = torch.FloatTensor(segment1_mouthroi).unsqueeze(
            0).unsqueeze(0)
        data['audio_spec_A2'] = torch.FloatTensor(audio_spec_1).unsqueeze(0)
        data['audio_spec_mix2'] = torch.FloatTensor(audio_mix_spec).unsqueeze(0)

        outputs = model.forward(data)
        reconstructed_signal_1, reconstructed_signal_2 = get_separated_audio(
            outputs, data, opt)
        reconstructed_signal_1 = reconstructed_signal_1 * normalizer1
        reconstructed_signal_2 = reconstructed_signal_2 * normalizer2
        sep_audio1[sliding_window_start:sliding_window_end] = sep_audio1[
            sliding_window_start:sliding_window_end] + reconstructed_signal_1
        sep_audio2[sliding_window_start:sliding_window_end] = sep_audio2[
            sliding_window_start:sliding_window_end] + reconstructed_signal_2

        #update overlap count
        overlap_count[sliding_window_start:sliding_window_end] = overlap_count[
            sliding_window_start:sliding_window_end] + 1
        sliding_window_start = sliding_window_start + int(
            opt.hop_length * opt.audio_sampling_rate)

    #deal with the last segment
    #get audio spectrogram
    segment1_audio = audio1[-samples_per_window:]
    segment2_audio = audio2[-samples_per_window:]
    segment_offscreen = audio_offscreen[-samples_per_window:]
    if opt.audio_normalization:
        normalizer1, segment1_audio = audio_normalize(segment1_audio)
        normalizer2, segment2_audio = audio_normalize(segment2_audio)
    else:
        normalizer1 = 1
        normalizer2 = 1

    audio_segment = (segment1_audio + segment2_audio +
                     segment_offscreen * opt.noise_weight) / 3
    audio_mix_spec = generate_spectrogram_complex(audio_segment,
                                                  opt.window_size,
                                                  opt.hop_size, opt.n_fft)

    #get mouthroi
    frame_index_start = int(
        round((len(audio1) - samples_per_window) / opt.audio_sampling_rate *
              25)) - 1
    segment1_mouthroi = mouthroi_1[frame_index_start:(frame_index_start +
                                                      opt.num_frames), :, :]
    segment2_mouthroi = mouthroi_2[frame_index_start:(frame_index_start +
                                                      opt.num_frames), :, :]

    #transform mouthrois
    segment1_mouthroi = lipreading_preprocessing_func(segment1_mouthroi)
    segment2_mouthroi = lipreading_preprocessing_func(segment2_mouthroi)

    audio_spec_1 = generate_spectrogram_complex(segment1_audio,
                                                opt.window_size, opt.hop_size,
                                                opt.n_fft)
    audio_spec_2 = generate_spectrogram_complex(segment2_audio,
                                                opt.window_size, opt.hop_size,
                                                opt.n_fft)

    data['audio_spec_mix1'] = torch.FloatTensor(audio_mix_spec).unsqueeze(0)
    data['mouthroi_A1'] = torch.FloatTensor(segment1_mouthroi).unsqueeze(
        0).unsqueeze(0)
    data['mouthroi_B'] = torch.FloatTensor(segment2_mouthroi).unsqueeze(
        0).unsqueeze(0)
    data['audio_spec_A1'] = torch.FloatTensor(audio_spec_1).unsqueeze(0)
    data['audio_spec_B'] = torch.FloatTensor(audio_spec_2).unsqueeze(0)
    data['frame_A'] = frames_1
    data['frame_B'] = frames_2
    #don't care below
    data['frame_A'] = frames_1
    data['mouthroi_A2'] = torch.FloatTensor(segment1_mouthroi).unsqueeze(
        0).unsqueeze(0)
    data['audio_spec_A2'] = torch.FloatTensor(audio_spec_1).unsqueeze(0)
    data['audio_spec_mix2'] = torch.FloatTensor(audio_mix_spec).unsqueeze(0)

    outputs = model.forward(data)
    reconstructed_signal_1, reconstructed_signal_2 = get_separated_audio(
        outputs, data, opt)
    reconstructed_signal_1 = reconstructed_signal_1 * normalizer1
    reconstructed_signal_2 = reconstructed_signal_2 * normalizer2
    sep_audio1[-samples_per_window:] = sep_audio1[
        -samples_per_window:] + reconstructed_signal_1
    sep_audio2[-samples_per_window:] = sep_audio2[
        -samples_per_window:] + reconstructed_signal_2

    #update overlap count
    overlap_count[-samples_per_window:] = overlap_count[-samples_per_window:] + 1

    #divide the aggregated predicted audio by the overlap count
    avged_sep_audio1 = avged_sep_audio1 + clip_audio(
        np.divide(sep_audio1, overlap_count))
    avged_sep_audio2 = avged_sep_audio2 + clip_audio(
        np.divide(sep_audio2, overlap_count))

    #output original and separated audios
    parts1 = opt.video1_path.split('/')
    parts2 = opt.video2_path.split('/')
    video1_name = parts1[-3] + '_' + parts1[-2] + '_' + parts1[-1][:-4]
    video2_name = parts2[-3] + '_' + parts2[-2] + '_' + parts2[-1][:-4]
    output_dir = os.path.join(opt.output_dir_root,
                              video1_name + 'VS' + video2_name)
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    # note: librosa.output.write_wav was removed in librosa 0.8; this script assumes an older librosa
    librosa.output.write_wav(os.path.join(output_dir, 'audio1.wav'), audio1,
                             opt.audio_sampling_rate)
    librosa.output.write_wav(os.path.join(output_dir, 'audio2.wav'), audio2,
                             opt.audio_sampling_rate)
    librosa.output.write_wav(os.path.join(output_dir, 'audio_offscreen.wav'),
                             audio_offscreen, opt.audio_sampling_rate)
    librosa.output.write_wav(os.path.join(output_dir, 'audio_mixed.wav'),
                             audio_mix, opt.audio_sampling_rate)
    librosa.output.write_wav(os.path.join(output_dir, 'audio1_separated.wav'),
                             avged_sep_audio1, opt.audio_sampling_rate)
    librosa.output.write_wav(os.path.join(output_dir, 'audio2_separated.wav'),
                             avged_sep_audio2, opt.audio_sampling_rate)
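
# Minimal sketches of the audio helpers the script above assumes (clip_audio,
# audio_normalize, generate_spectrogram_complex). The real implementations live in the
# repo's utils; these versions only illustrate the expected behaviour and default values.
def clip_audio(audio):
    # keep samples inside [-1, 1] so mixing and writing to wav stay well-behaved
    return np.clip(audio, -1.0, 1.0)


def audio_normalize(samples, desired_rms=0.1, eps=1e-4):
    # rescale a segment to a target RMS and return the factor needed to undo it,
    # so separated outputs can be mapped back to the original loudness
    rms = np.maximum(eps, np.sqrt(np.mean(samples ** 2)))
    samples = samples * (desired_rms / rms)
    return rms / desired_rms, samples


def generate_spectrogram_complex(audio, window_size, hop_size, n_fft):
    # complex STFT stored as two channels: real and imaginary parts
    spec = librosa.stft(audio, n_fft=n_fft, hop_length=hop_size, win_length=window_size)
    return np.stack((np.real(spec), np.imag(spec)), axis=0)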
def main():
    #load test arguments
    opt = TestOptions().parse()
    opt.device = torch.device("cuda")

    # Network Builders
    builder = ModelBuilder()
    net_visual = builder.build_visual(pool_type=opt.visual_pool,
                                      weights=opt.weights_visual)
    net_unet = builder.build_unet(unet_num_layers=opt.unet_num_layers,
                                  ngf=opt.unet_ngf,
                                  input_nc=opt.unet_input_nc,
                                  output_nc=opt.unet_output_nc,
                                  weights=opt.weights_unet)
    if opt.with_additional_scene_image:
        opt.number_of_classes = opt.number_of_classes + 1
    net_classifier = builder.build_classifier(
        pool_type=opt.classifier_pool,
        num_of_classes=opt.number_of_classes,
        input_channel=opt.unet_output_nc,
        weights=opt.weights_classifier)
    nets = (net_visual, net_unet, net_classifier)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()

    #load the two audios
    audio1_path = os.path.join(opt.data_path, 'audio_11025',
                               opt.video1_name + '.wav')
    audio1, _ = librosa.load(audio1_path, sr=opt.audio_sampling_rate)
    audio2_path = os.path.join(opt.data_path, 'audio_11025',
                               opt.video2_name + '.wav')
    audio2, _ = librosa.load(audio2_path, sr=opt.audio_sampling_rate)

    #make sure the two audios are of the same length and then mix them
    audio_length = min(len(audio1), len(audio2))
    audio1 = clip_audio(audio1[:audio_length])
    audio2 = clip_audio(audio2[:audio_length])
    audio_mix = (audio1 + audio2) / 2.0

    #define the transformation to perform on visual frames
    vision_transform_list = [
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ]
    if opt.subtract_mean:
        vision_transform_list.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]))
    vision_transform = transforms.Compose(vision_transform_list)

    #load the object regions of the highest confidence score for both videos
    detectionResult1 = np.load(
        os.path.join(opt.data_path, 'detection_results',
                     opt.video1_name + '.npy'))
    detectionResult2 = np.load(
        os.path.join(opt.data_path, 'detection_results',
                     opt.video2_name + '.npy'))

    avged_sep_audio1 = np.zeros((audio_length))
    avged_sep_audio2 = np.zeros((audio_length))

    for i in range(opt.num_of_object_detections_to_use):
        det_box1 = detectionResult1[np.argmax(
            detectionResult1[:, 2]), :]  #get the box of the highest confidence score
        det_box2 = detectionResult2[np.argmax(
            detectionResult2[:, 2]), :]  #get the box of the highest confidence score
        detectionResult1[np.argmax(detectionResult1[:, 2]), 2] = 0  # set to 0 after using it
        detectionResult2[np.argmax(detectionResult2[:, 2]), 2] = 0  # set to 0 after using it
        frame_path1 = os.path.join(opt.data_path, 'frame', opt.video1_name,
                                   "%06d.png" % det_box1[0])
        frame_path2 = os.path.join(opt.data_path, 'frame', opt.video2_name,
                                   "%06d.png" % det_box2[0])
        detection1 = Image.open(frame_path1).convert('RGB').crop(
            (det_box1[-4], det_box1[-3], det_box1[-2], det_box1[-1]))
        detection2 = Image.open(frame_path2).convert('RGB').crop(
            (det_box2[-4], det_box2[-3], det_box2[-2], det_box2[-1]))

        #perform separation over the whole audio using a sliding window approach
        overlap_count = np.zeros((audio_length))
        sep_audio1 = np.zeros((audio_length))
        sep_audio2 = np.zeros((audio_length))
        sliding_window_start = 0
        data = {}
        samples_per_window = opt.audio_window
        while sliding_window_start + samples_per_window < audio_length:
            sliding_window_end = sliding_window_start + samples_per_window
            audio_segment = audio_mix[sliding_window_start:sliding_window_end]
            audio_mix_mags, audio_mix_phases = generate_spectrogram_magphase(
                audio_segment, opt.stft_frame, opt.stft_hop)
            data['audio_mix_mags'] = torch.FloatTensor(
                audio_mix_mags).unsqueeze(0)
            data['audio_mix_phases'] = torch.FloatTensor(
                audio_mix_phases).unsqueeze(0)
            data['real_audio_mags'] = data['audio_mix_mags']  #don't care for testing
            data['audio_mags'] = data['audio_mix_mags']  #don't care for testing
            #separate for video 1
            data['visuals'] = vision_transform(detection1).unsqueeze(0)
            data['labels'] = torch.FloatTensor(np.ones((1, 1)))  #don't care for testing
            data['vids'] = torch.FloatTensor(np.ones((1, 1)))  #don't care for testing
            outputs = model.forward(data)
            reconstructed_signal = get_separated_audio(outputs, data, opt)
            sep_audio1[sliding_window_start:sliding_window_end] = sep_audio1[
                sliding_window_start:sliding_window_end] + reconstructed_signal
            #separate for video 2
            data['visuals'] = vision_transform(detection2).unsqueeze(0)
            #data['label'] = torch.LongTensor([0]) #don't care for testing
            outputs = model.forward(data)
            reconstructed_signal = get_separated_audio(outputs, data, opt)
            sep_audio2[sliding_window_start:sliding_window_end] = sep_audio2[
                sliding_window_start:sliding_window_end] + reconstructed_signal
            #update overlap count
            overlap_count[sliding_window_start:sliding_window_end] = overlap_count[
                sliding_window_start:sliding_window_end] + 1
            sliding_window_start = sliding_window_start + int(
                opt.hop_size * opt.audio_sampling_rate)

        #deal with the last segment
        audio_segment = audio_mix[-samples_per_window:]
        audio_mix_mags, audio_mix_phases = generate_spectrogram_magphase(
            audio_segment, opt.stft_frame, opt.stft_hop)
        data['audio_mix_mags'] = torch.FloatTensor(audio_mix_mags).unsqueeze(0)
        data['audio_mix_phases'] = torch.FloatTensor(
            audio_mix_phases).unsqueeze(0)
        data['real_audio_mags'] = data['audio_mix_mags']  #don't care for testing
        data['audio_mags'] = data['audio_mix_mags']  #don't care for testing
        #separate for video 1
        data['visuals'] = vision_transform(detection1).unsqueeze(0)
        data['labels'] = torch.FloatTensor(np.ones((1, 1)))  #don't care for testing
        data['vids'] = torch.FloatTensor(np.ones((1, 1)))  #don't care for testing
        outputs = model.forward(data)
        reconstructed_signal = get_separated_audio(outputs, data, opt)
        sep_audio1[-samples_per_window:] = sep_audio1[
            -samples_per_window:] + reconstructed_signal
        #separate for video 2
        data['visuals'] = vision_transform(detection2).unsqueeze(0)
        outputs = model.forward(data)
        reconstructed_signal = get_separated_audio(outputs, data, opt)
        sep_audio2[-samples_per_window:] = sep_audio2[
            -samples_per_window:] + reconstructed_signal
        #update overlap count
        overlap_count[-samples_per_window:] = overlap_count[-samples_per_window:] + 1

        #divide the aggregated predicted audio by the overlap count
        avged_sep_audio1 = avged_sep_audio1 + clip_audio(
            np.divide(sep_audio1, overlap_count) * 2)
        avged_sep_audio2 = avged_sep_audio2 + clip_audio(
            np.divide(sep_audio2, overlap_count) * 2)

    separation1 = avged_sep_audio1 / opt.num_of_object_detections_to_use
    separation2 = avged_sep_audio2 / opt.num_of_object_detections_to_use

    #output original and separated audios
    output_dir = os.path.join(opt.output_dir_root,
                              opt.video1_name + 'VS' + opt.video2_name)
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    librosa.output.write_wav(os.path.join(output_dir, 'audio1.wav'), audio1,
                             opt.audio_sampling_rate)
    librosa.output.write_wav(os.path.join(output_dir, 'audio2.wav'), audio2,
                             opt.audio_sampling_rate)
    librosa.output.write_wav(os.path.join(output_dir, 'audio_mixed.wav'),
                             audio_mix, opt.audio_sampling_rate)
    librosa.output.write_wav(os.path.join(output_dir, 'audio1_separated.wav'),
                             separation1, opt.audio_sampling_rate)
    librosa.output.write_wav(os.path.join(output_dir, 'audio2_separated.wav'),
                             separation2, opt.audio_sampling_rate)

    #save the two detections
    detection1.save(os.path.join(output_dir, 'audio1.png'))
    detection2.save(os.path.join(output_dir, 'audio2.png'))

    #save the spectrograms & masks
    if opt.visualize_spectrogram:
        import matplotlib.pyplot as plt
        plt.switch_backend('agg')
        plt.ioff()
        audio1_mag = generate_spectrogram_magphase(audio1, opt.stft_frame,
                                                   opt.stft_hop,
                                                   with_phase=False)
        audio2_mag = generate_spectrogram_magphase(audio2, opt.stft_frame,
                                                   opt.stft_hop,
                                                   with_phase=False)
        audio_mix_mag = generate_spectrogram_magphase(audio_mix,
                                                      opt.stft_frame,
                                                      opt.stft_hop,
                                                      with_phase=False)
        separation1_mag = generate_spectrogram_magphase(separation1,
                                                        opt.stft_frame,
                                                        opt.stft_hop,
                                                        with_phase=False)
        separation2_mag = generate_spectrogram_magphase(separation2,
                                                        opt.stft_frame,
                                                        opt.stft_hop,
                                                        with_phase=False)
        utils.visualizeSpectrogram(audio1_mag[0, :, :],
                                   os.path.join(output_dir, 'audio1_spec.png'))
        utils.visualizeSpectrogram(audio2_mag[0, :, :],
                                   os.path.join(output_dir, 'audio2_spec.png'))
        utils.visualizeSpectrogram(
            audio_mix_mag[0, :, :],
            os.path.join(output_dir, 'audio_mixed_spec.png'))
        utils.visualizeSpectrogram(
            separation1_mag[0, :, :],
            os.path.join(output_dir, 'separation1_spec.png'))
        utils.visualizeSpectrogram(
            separation2_mag[0, :, :],
            os.path.join(output_dir, 'separation2_spec.png'))
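
# A sketch of the magnitude/phase helper the script above relies on. Parameter names
# follow the call sites (stft_frame, stft_hop, with_phase); the body illustrates the
# assumed behaviour rather than reproducing the repo's exact implementation.
def generate_spectrogram_magphase(audio, stft_frame, stft_hop, with_phase=True):
    spectro = librosa.stft(audio, n_fft=stft_frame, hop_length=stft_hop)
    spectro_mag, spectro_phase = librosa.magphase(spectro)
    spectro_mag = np.expand_dims(spectro_mag, axis=0)
    if with_phase:
        spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0)
        return spectro_mag, spectro_phase
    return spectro_mag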
#temporarily set to val to load val data
opt.mode = 'val'
data_loader_val = CreateDataLoader(opt)
dataset_val = data_loader_val.load_data()
dataset_size_val = len(data_loader_val)
print('#validation images = %d' % dataset_size_val)
opt.mode = 'train'  #set it back

if opt.tensorboard:
    from tensorboardX import SummaryWriter
    writer = SummaryWriter(comment=opt.name)
else:
    writer = None

# Network Builders
builder = ModelBuilder()
net_visual = builder.build_visual(pool_type=opt.visual_pool,
                                  fc_out=512,
                                  weights=opt.weights_visual)
net_unet = builder.build_unet(unet_num_layers=opt.unet_num_layers,
                              ngf=opt.unet_ngf,
                              input_nc=opt.unet_input_nc,
                              output_nc=opt.unet_output_nc,
                              weights=opt.weights_unet)
net_classifier = builder.build_classifier(pool_type=opt.classifier_pool,
                                          num_of_classes=opt.number_of_classes,
                                          input_channel=opt.unet_output_nc,
                                          weights=opt.weights_classifier)
nets = (net_visual, net_unet, net_classifier)

# construct our audio-visual model
validation_opt = copy.copy(opt)
validation_opt.mode = 'val'
validation_opt.enable_data_augmentation = False
data_loader_val = CreateDataLoader(validation_opt)
dataset_val = data_loader_val.load_data()
dataset_size_val = len(data_loader_val)
print('#validation clips = %d' % dataset_size_val)

if opt.tensorboard:
    from tensorboardX import SummaryWriter
    writer = SummaryWriter(comment=opt.name)
else:
    writer = None

# network builders
builder = ModelBuilder()
model, nets = builder.get_model(opt)
if opt.use_visual_info:
    net_visual, net_audio = nets
else:
    net_audio = nets[0]
if len(opt.gpu_ids) > 0:
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)

# set up optimizer
optimizer = create_optimizer(nets, opt)

# set up loss function
loss_criterion = create_loss_criterion(opt)
#if len(opt.gpu_ids) > 0:
#    loss_criterion.cuda(opt.gpu_ids[0])
    'guzheng': [],
    'piano': [],
    'pipa': [],
    'saxophone': [],
    'trumpet': [],
    'tuba': [],
    'ukulele': [],
    'violin': [],
    'xylophone': []
}
for sample in samples_list:
    ins = get_ins_name(sample)
    instruments[ins].append(get_clip_name(sample))

# Network Builders
builder = ModelBuilder()
net_visual = builder.build_visual(pool_type=opt.visual_pool,
                                  weights=opt.weights_visual)
net_unet = builder.build_unet(unet_num_layers=opt.unet_num_layers,
                              ngf=opt.unet_ngf,
                              input_nc=opt.unet_input_nc,
                              output_nc=opt.unet_output_nc,
                              weights=opt.weights_unet)
# if opt.with_additional_scene_image:
#     opt.number_of_classes = opt.number_of_classes + 1
net_classifier = builder.build_classifier(
    pool_type=opt.classifier_pool,
    num_of_classes=opt.number_of_classes + 1,
    input_channel=opt.unet_output_nc,
    weights=opt.weights_classifier)
net_refine = builder.build_refine(opt=opt,
def inference(opt):
    # network builders
    builder = ModelBuilder()
    model, _ = builder.get_model(opt)
    #model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()

    with open(opt.split_file, 'r') as fd:
        split = json.load(fd)
    audio_names = split[opt.split_subset]
    if len(audio_names) == 0:
        raise Exception("Split subset has no audios")

    #construct data loader
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    for audio_name in tqdm(audio_names):
        curr_audio_path = os.path.join(opt.input_audio_path, audio_name)

        #load the audio to perform separation
        audio, _ = librosa.load(curr_audio_path,
                                sr=opt.audio_sampling_rate,
                                mono=False)

        #perform spatialization over the whole audio using a sliding window approach
        overlap_count = np.zeros(
            (audio.shape))  #count the number of times a data point is calculated
        binaural_audio = np.zeros((audio.shape))

        #perform spatialization over the whole spectrogram in a sliding-window fashion
        sliding_window_start = 0
        data = {}
        samples_per_window = int(opt.audio_length * opt.audio_sampling_rate)
        ended = False
        while not ended:
            if sliding_window_start + samples_per_window >= audio.shape[-1]:
                sliding_window_start = audio.shape[-1] - samples_per_window
                ended = True
            sliding_window_end = sliding_window_start + samples_per_window
            data = dataset.dataset.__getitem__(
                curr_audio_path,
                audio,
                audio_start_time=sliding_window_start / opt.audio_sampling_rate,
                audio_end_time=sliding_window_end / opt.audio_sampling_rate,
                audio_start=sliding_window_start,
                audio_end=sliding_window_end)
            normalizer = data['normalizer']
            del data['normalizer']
            for k in data.keys():
                if str(type(data[k])) == "<class 'torch.Tensor'>":
                    data[k] = data[k].unsqueeze(0).to(opt.device)

            with torch.no_grad():
                output = model.forward(data)
            prediction = output['binaural_output']

            #ISTFT to convert back to audio
            if opt.model == "audioVisual":
                prediction = prediction[0, :, :, :].data[:].cpu().numpy()
                audio_segment_channel1 = audio[
                    0, sliding_window_start:sliding_window_end] / normalizer
                audio_segment_channel2 = audio[
                    1, sliding_window_start:sliding_window_end] / normalizer
                audio_segment_mix = audio_segment_channel1 + audio_segment_channel2
                reconstructed_stft_diff = prediction[0, :, :] + (
                    1j * prediction[1, :, :])
                reconstructed_signal_diff = librosa.istft(
                    reconstructed_stft_diff,
                    hop_length=160,
                    win_length=400,
                    center=True,
                    length=samples_per_window)
                reconstructed_signal_left = (audio_segment_mix +
                                             reconstructed_signal_diff) / 2
                reconstructed_signal_right = (audio_segment_mix -
                                              reconstructed_signal_diff) / 2
                reconstructed_binaural = np.concatenate(
                    (np.expand_dims(reconstructed_signal_left, axis=0),
                     np.expand_dims(reconstructed_signal_right, axis=0)),
                    axis=0) * normalizer
            else:
                reconstructed_binaural = prediction.cpu().numpy()

            binaural_audio[:, sliding_window_start:sliding_window_end] = binaural_audio[
                :, sliding_window_start:sliding_window_end] + reconstructed_binaural
            overlap_count[:, sliding_window_start:sliding_window_end] = overlap_count[
                :, sliding_window_start:sliding_window_end] + 1
            if opt.model == "audioVisual":
                sliding_window_start = sliding_window_start + int(
                    opt.hop_size * opt.audio_sampling_rate)
            else:
                sliding_window_start = sliding_window_end

        #divide aggregated predicted audio by their corresponding counts
        predicted_binaural_audio = np.divide(binaural_audio, overlap_count)

        #check output directory
        if not os.path.isdir(opt.output_dir_root):
            os.mkdir(opt.output_dir_root)
        curr_output_dir_root = os.path.join(opt.output_dir_root, audio_name)
        if not os.path.isdir(curr_output_dir_root):
            os.mkdir(curr_output_dir_root)

        mixed_mono = (audio[0, :] + audio[1, :]) / 2
        wavfile.write(
            os.path.join(curr_output_dir_root, 'predicted_binaural.wav'),
            opt.audio_sampling_rate, predicted_binaural_audio.T)
        wavfile.write(os.path.join(curr_output_dir_root, 'mixed_mono.wav'),
                      opt.audio_sampling_rate, mixed_mono.T)
        wavfile.write(os.path.join(curr_output_dir_root, 'input_binaural.wav'),
                      opt.audio_sampling_rate, audio.T)
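
# The left/right reconstruction above assumes the network predicts the spectrogram of the
# channel-difference signal. A minimal standalone sketch of that step (the helper name is
# hypothetical, not part of the repo):
def reconstruct_binaural_from_diff(mix_mono, diff_signal):
    # mono mix m = L + R and predicted difference d = L - R give L = (m + d) / 2, R = (m - d) / 2
    left = (mix_mono + diff_signal) / 2
    right = (mix_mono - diff_signal) / 2
    return np.stack((left, right), axis=0)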
def main():
    #load test arguments
    opt = TestRealOptions().parse()
    opt.device = torch.device("cuda")

    # Network Builders
    builder = ModelBuilder()
    net_lipreading = builder.build_lipreadingnet(
        config_path=opt.lipreading_config_path,
        weights=opt.weights_lipreadingnet,
        extract_feats=opt.lipreading_extract_feature)
    #if identity feature dim is not 512, for resnet reduce dimension to this feature dim
    if opt.identity_feature_dim != 512:
        opt.with_fc = True
    else:
        opt.with_fc = False
    net_facial_attributes = builder.build_facial(
        pool_type=opt.visual_pool,
        fc_out=opt.identity_feature_dim,
        with_fc=opt.with_fc,
        weights=opt.weights_facial)
    net_unet = builder.build_unet(
        ngf=opt.unet_ngf,
        input_nc=opt.unet_input_nc,
        output_nc=opt.unet_output_nc,
        audioVisual_feature_dim=opt.audioVisual_feature_dim,
        identity_feature_dim=opt.identity_feature_dim,
        weights=opt.weights_unet)
    net_vocal_attributes = builder.build_vocal(
        pool_type=opt.audio_pool,
        input_channel=2,
        with_fc=opt.with_fc,
        fc_out=opt.identity_feature_dim,
        weights=opt.weights_vocal)

    nets = (net_lipreading, net_facial_attributes, net_unet,
            net_vocal_attributes)
    print(nets)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()
    mtcnn = MTCNN(keep_all=True, device=opt.device)

    lipreading_preprocessing_func = get_preprocessing_pipelines()['test']
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    vision_transform_list = [transforms.ToTensor()]
    if opt.normalization:
        vision_transform_list.append(normalize)
    vision_transform = transforms.Compose(vision_transform_list)

    for speaker_index in range(opt.number_of_speakers):
        mouthroi_path = os.path.join(opt.mouthroi_root,
                                     'speaker' + str(speaker_index + 1) + '.npz')
        facetrack_path = os.path.join(opt.facetrack_root,
                                      'speaker' + str(speaker_index + 1) + '.mp4')

        #load data
        mouthroi = load_mouthroi(mouthroi_path)
        sr, audio = wavfile.read(opt.audio_path)
        print("sampling rate of audio: ", sr)
        if len(audio.shape) == 2:
            audio = np.mean(audio, axis=1)  #convert to mono if stereo
        audio = audio / 32768
        audio = audio / 2.0
        audio = clip_audio(audio)
        audio_length = len(audio)

        if opt.reliable_face:
            best_score = 0
            for i in range(10):
                frame = load_frame(facetrack_path)
                boxes, scores = mtcnn.detect(frame)
                if scores[0] > best_score:
                    best_frame = frame
                    best_score = scores[0]  # keep track of the best score seen so far
            frames = vision_transform(best_frame).squeeze().unsqueeze(0).cuda()
        else:
            frame_list = []
            for i in range(opt.number_of_identity_frames):
                frame = load_frame(facetrack_path)
                frame = vision_transform(frame)
                frame_list.append(frame)
            frames = torch.stack(frame_list).squeeze().unsqueeze(0).cuda()

        #perform separation over the whole audio using a sliding window approach
        sliding_window_start = 0
        overlap_count = np.zeros((audio_length))
        sep_audio = np.zeros((audio_length))
        avged_sep_audio = np.zeros((audio_length))
        samples_per_window = int(opt.audio_length * opt.audio_sampling_rate)
        while sliding_window_start + samples_per_window < audio_length:
            sliding_window_end = sliding_window_start + samples_per_window

            #get audio spectrogram
            segment_audio = audio[sliding_window_start:sliding_window_end]
            if opt.audio_normalization:
                normalizer, segment_audio = audio_normalize(segment_audio,
                                                            desired_rms=0.07)
            else:
                normalizer = 1
            audio_spec = generate_spectrogram_complex(segment_audio,
                                                      opt.window_size,
                                                      opt.hop_size, opt.n_fft)
            audio_spec = torch.FloatTensor(audio_spec).unsqueeze(0).cuda()

            #get mouthroi
            frame_index_start = int(
                round(sliding_window_start / opt.audio_sampling_rate * 25))
            segment_mouthroi = mouthroi[frame_index_start:(frame_index_start +
                                                           opt.num_frames), :, :]
            segment_mouthroi = lipreading_preprocessing_func(segment_mouthroi)
            segment_mouthroi = torch.FloatTensor(segment_mouthroi).unsqueeze(
                0).unsqueeze(0).cuda()

            reconstructed_signal = get_separated_audio(net_lipreading,
                                                       net_facial_attributes,
                                                       net_unet, audio_spec,
                                                       segment_mouthroi, frames,
                                                       opt)
            reconstructed_signal = reconstructed_signal * normalizer
            sep_audio[sliding_window_start:sliding_window_end] = sep_audio[
                sliding_window_start:sliding_window_end] + reconstructed_signal

            #update overlap count
            overlap_count[sliding_window_start:sliding_window_end] = overlap_count[
                sliding_window_start:sliding_window_end] + 1
            sliding_window_start = sliding_window_start + int(
                opt.hop_length * opt.audio_sampling_rate)

        #deal with the last segment
        segment_audio = audio[-samples_per_window:]
        if opt.audio_normalization:
            normalizer, segment_audio = audio_normalize(segment_audio,
                                                        desired_rms=0.07)
        else:
            normalizer = 1
        audio_spec = generate_spectrogram_complex(segment_audio,
                                                  opt.window_size,
                                                  opt.hop_size, opt.n_fft)
        audio_spec = torch.FloatTensor(audio_spec).unsqueeze(0).cuda()

        #get mouthroi
        frame_index_start = int(
            round((len(audio) - samples_per_window) / opt.audio_sampling_rate *
                  25)) - 1
        segment_mouthroi = mouthroi[-opt.num_frames:, :, :]
        segment_mouthroi = lipreading_preprocessing_func(segment_mouthroi)
        segment_mouthroi = torch.FloatTensor(segment_mouthroi).unsqueeze(
            0).unsqueeze(0).cuda()

        reconstructed_signal = get_separated_audio(net_lipreading,
                                                   net_facial_attributes,
                                                   net_unet, audio_spec,
                                                   segment_mouthroi, frames,
                                                   opt)
        reconstructed_signal = reconstructed_signal * normalizer
        sep_audio[-samples_per_window:] = sep_audio[
            -samples_per_window:] + reconstructed_signal

        #update overlap count
        overlap_count[-samples_per_window:] = overlap_count[-samples_per_window:] + 1

        #divide the aggregated predicted audio by the overlap count
        avged_sep_audio = clip_audio(np.divide(sep_audio, overlap_count))

        #output separated audios
        if not os.path.isdir(opt.output_dir_root):
            os.mkdir(opt.output_dir_root)
        librosa.output.write_wav(
            os.path.join(opt.output_dir_root,
                         'speaker' + str(speaker_index + 1) + '.wav'),
            avged_sep_audio, opt.audio_sampling_rate)
#temporarily set to val to load val data
opt.mode = 'val'
data_loader_val = CreateDataLoader(opt)
dataset_val = data_loader_val.load_data()
dataset_size_val = len(data_loader_val)
print('#validation images = %d' % dataset_size_val)
opt.mode = 'train'  #set it back

if opt.tensorboard:
    from tensorboardX import SummaryWriter
    writer = SummaryWriter(comment=opt.name)
else:
    writer = None

# Network Builders
builder = ModelBuilder()
net_lipreading = builder.build_lipreadingnet(
    config_path=opt.lipreading_config_path,
    weights=opt.weights_lipreadingnet,
    extract_feats=opt.lipreading_extract_feature)
#if identity feature dim is not 512, for resnet reduce dimension to this feature dim
if opt.identity_feature_dim != 512:
    opt.with_fc = True
else:
    opt.with_fc = False
net_facial_attributes = builder.build_facial(
    pool_type=opt.visual_pool,
    fc_out=opt.identity_feature_dim,
    with_fc=opt.with_fc,
    weights=opt.weights_facial)
net_unet = builder.build_unet(
def main():
    #load test arguments
    opt = TestOptions().parse()
    opt.device = torch.device("cuda")

    # network builders
    builder = ModelBuilder()
    net_visual = builder.build_visual(weights=opt.weights_visual)
    net_audio = builder.build_audio(ngf=opt.unet_ngf,
                                    input_nc=opt.unet_input_nc,
                                    output_nc=opt.unet_output_nc,
                                    weights=opt.weights_audio)
    nets = (net_visual, net_audio)

    # construct our audio-visual model
    model = AudioVisualModel(nets, opt)
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.to(opt.device)
    model.eval()

    #load the audio to perform separation
    audio, audio_rate = librosa.load(opt.input_audio_path,
                                     sr=opt.audio_sampling_rate,
                                     mono=False)
    audio_channel1 = audio[0, :]
    audio_channel2 = audio[1, :]

    #define the transformation to perform on visual frames
    vision_transform_list = [
        transforms.Resize((224, 448)),
        transforms.ToTensor()
    ]
    vision_transform_list.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]))
    vision_transform = transforms.Compose(vision_transform_list)

    #perform spatialization over the whole audio using a sliding window approach
    overlap_count = np.zeros(
        (audio.shape))  #count the number of times a data point is calculated
    binaural_audio = np.zeros((audio.shape))

    #perform spatialization over the whole spectrogram in a sliding-window fashion
    sliding_window_start = 0
    data = {}
    samples_per_window = int(opt.audio_length * opt.audio_sampling_rate)
    while sliding_window_start + samples_per_window < audio.shape[-1]:
        sliding_window_end = sliding_window_start + samples_per_window
        normalizer, audio_segment = audio_normalize(
            audio[:, sliding_window_start:sliding_window_end])
        audio_segment_channel1 = audio_segment[0, :]
        audio_segment_channel2 = audio_segment[1, :]
        audio_segment_mix = audio_segment_channel1 + audio_segment_channel2

        data['audio_diff_spec'] = torch.FloatTensor(
            generate_spectrogram(audio_segment_channel1 -
                                 audio_segment_channel2)).unsqueeze(
                                     0)  #unsqueeze to add a batch dimension
        data['audio_mix_spec'] = torch.FloatTensor(
            generate_spectrogram(audio_segment_channel1 +
                                 audio_segment_channel2)).unsqueeze(
                                     0)  #unsqueeze to add a batch dimension
        #get the frame index for current window
        frame_index = int(
            round((((sliding_window_start + samples_per_window / 2.0) /
                    audio.shape[-1]) * opt.input_audio_length + 0.05) * 10))
        image = Image.open(
            os.path.join(opt.video_frame_path,
                         str(frame_index).zfill(6) + '.jpg')).convert('RGB')
        #image = image.transpose(Image.FLIP_LEFT_RIGHT)
        frame = vision_transform(image).unsqueeze(
            0)  #unsqueeze to add a batch dimension
        data['frame'] = frame

        output = model.forward(data)
        predicted_spectrogram = output['binaural_spectrogram'][
            0, :, :, :].data[:].cpu().numpy()

        #ISTFT to convert back to audio
        reconstructed_stft_diff = predicted_spectrogram[0, :, :] + (
            1j * predicted_spectrogram[1, :, :])
        reconstructed_signal_diff = librosa.istft(reconstructed_stft_diff,
                                                  hop_length=160,
                                                  win_length=400,
                                                  center=True,
                                                  length=samples_per_window)
        reconstructed_signal_left = (audio_segment_mix +
                                     reconstructed_signal_diff) / 2
        reconstructed_signal_right = (audio_segment_mix -
                                      reconstructed_signal_diff) / 2
        reconstructed_binaural = np.concatenate(
            (np.expand_dims(reconstructed_signal_left, axis=0),
             np.expand_dims(reconstructed_signal_right, axis=0)),
            axis=0) * normalizer

        binaural_audio[:, sliding_window_start:sliding_window_end] = binaural_audio[
            :, sliding_window_start:sliding_window_end] + reconstructed_binaural
        overlap_count[:, sliding_window_start:sliding_window_end] = overlap_count[
            :, sliding_window_start:sliding_window_end] + 1
        sliding_window_start = sliding_window_start + int(
            opt.hop_size * opt.audio_sampling_rate)

    #deal with the last segment
    normalizer, audio_segment = audio_normalize(audio[:, -samples_per_window:])
    audio_segment_channel1 = audio_segment[0, :]
    audio_segment_channel2 = audio_segment[1, :]
    audio_segment_mix = audio_segment_channel1 + audio_segment_channel2  #recompute the mix for the last window
    data['audio_diff_spec'] = torch.FloatTensor(
        generate_spectrogram(audio_segment_channel1 -
                             audio_segment_channel2)).unsqueeze(
                                 0)  #unsqueeze to add a batch dimension
    data['audio_mix_spec'] = torch.FloatTensor(
        generate_spectrogram(audio_segment_channel1 +
                             audio_segment_channel2)).unsqueeze(
                                 0)  #unsqueeze to add a batch dimension
    #get the frame index for last window
    frame_index = int(
        round(((opt.input_audio_length - opt.audio_length / 2.0) + 0.05) * 10))
    image = Image.open(
        os.path.join(opt.video_frame_path,
                     str(frame_index).zfill(6) + '.jpg')).convert('RGB')
    #image = image.transpose(Image.FLIP_LEFT_RIGHT)
    frame = vision_transform(image).unsqueeze(
        0)  #unsqueeze to add a batch dimension
    data['frame'] = frame

    output = model.forward(data)
    predicted_spectrogram = output['binaural_spectrogram'][
        0, :, :, :].data[:].cpu().numpy()

    #ISTFT to convert back to audio
    reconstructed_stft_diff = predicted_spectrogram[0, :, :] + (
        1j * predicted_spectrogram[1, :, :])
    reconstructed_signal_diff = librosa.istft(reconstructed_stft_diff,
                                              hop_length=160,
                                              win_length=400,
                                              center=True,
                                              length=samples_per_window)
    reconstructed_signal_left = (audio_segment_mix +
                                 reconstructed_signal_diff) / 2
    reconstructed_signal_right = (audio_segment_mix -
                                  reconstructed_signal_diff) / 2
    reconstructed_binaural = np.concatenate(
        (np.expand_dims(reconstructed_signal_left, axis=0),
         np.expand_dims(reconstructed_signal_right, axis=0)),
        axis=0) * normalizer

    #add the spatialized audio to reconstructed_binaural
    binaural_audio[:, -samples_per_window:] = binaural_audio[
        :, -samples_per_window:] + reconstructed_binaural
    overlap_count[:, -samples_per_window:] = overlap_count[:, -samples_per_window:] + 1

    #divide aggregated predicted audio by their corresponding counts
    predicted_binaural_audio = np.divide(binaural_audio, overlap_count)

    #check output directory
    if not os.path.isdir(opt.output_dir_root):
        os.mkdir(opt.output_dir_root)

    mixed_mono = (audio_channel1 + audio_channel2) / 2
    librosa.output.write_wav(
        os.path.join(opt.output_dir_root, 'predicted_binaural.wav'),
        predicted_binaural_audio, opt.audio_sampling_rate)
    librosa.output.write_wav(
        os.path.join(opt.output_dir_root, 'mixed_mono.wav'), mixed_mono,
        opt.audio_sampling_rate)
    librosa.output.write_wav(
        os.path.join(opt.output_dir_root, 'input_binaural.wav'), audio,
        opt.audio_sampling_rate)
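
# generate_spectrogram is assumed to return a two-channel (real, imaginary) STFT. A minimal
# sketch follows: hop_length/win_length match the librosa.istft calls above, while n_fft=512
# is an assumption rather than a value taken from this script.
def generate_spectrogram(audio):
    spec = librosa.stft(audio, n_fft=512, hop_length=160, win_length=400, center=True)
    return np.stack((np.real(spec), np.imag(spec)), axis=0)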
import os
import torch
import numpy as np
from options.test_options import TestOptions
import torchvision.transforms as transforms
from models.models import ModelBuilder
from models.audioVisual_model import AudioVisualModel
from data_loader.custom_dataset_data_loader import CustomDatasetDataLoader
from util.util import compute_errors
from models import criterion

loss_criterion = criterion.LogDepthLoss()
opt = TestOptions().parse()
opt.device = torch.device("cuda")

builder = ModelBuilder()
net_audiodepth = builder.build_audiodepth(
    opt.audio_shape,
    weights=os.path.join(opt.checkpoints_dir,
                         'audiodepth_' + opt.dataset + '.pth'))
net_rgbdepth = builder.build_rgbdepth(
    weights=os.path.join(opt.checkpoints_dir,
                         'rgbdepth_' + opt.dataset + '.pth'))
net_attention = builder.build_attention(
    weights=os.path.join(opt.checkpoints_dir,
                         'attention_' + opt.dataset + '.pth'))
net_material = builder.build_material_property(
    weights=os.path.join(opt.checkpoints_dir,
                         'material_' + opt.dataset + '.pth'))
nets = (net_rgbdepth, net_audiodepth, net_attention, net_material)
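
# compute_errors is imported from util.util above; for reference, a sketch of the standard
# monocular-depth metrics such a helper typically reports (abs_rel, rmse, log10 and the
# delta accuracies). This is an assumption about its behaviour, not the repo's code.
def compute_errors_sketch(gt, pred):
    thresh = np.maximum(gt / pred, pred / gt)
    a1 = (thresh < 1.25).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()
    abs_rel = np.mean(np.abs(gt - pred) / gt)
    rmse = np.sqrt(np.mean((gt - pred) ** 2))
    log10 = np.mean(np.abs(np.log10(gt) - np.log10(pred)))
    return abs_rel, rmse, a1, a2, a3, log10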
#temporarily set to val to load val data
opt.mode = 'val'
data_loader_val = CreateDataLoader(opt)
dataset_val = data_loader_val.load_data()
dataset_size_val = len(data_loader_val)
print('#validation clips = %d' % dataset_size_val)
opt.mode = 'train'  #set it back

if opt.tensorboard:
    #from tensorboardX import SummaryWriter
    writer = SummaryWriter(comment=opt.name)
else:
    writer = None

# network builders
builder = ModelBuilder()
net_visual = builder.build_visual(weights=opt.weights_visual)
net_audio = builder.build_audio(
    ngf=opt.unet_ngf,
    input_nc=opt.unet_input_nc,
    output_nc=opt.unet_output_nc,
    weights=opt.weights_audio)
nets = (net_visual, net_audio)

# construct our audio-visual model
model = AudioVisualModel(nets, opt)
# run the model on multiple GPUs in parallel
model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
model.to(opt.device)
print("model is created")
def main(args):
    os.makedirs(args.checkpoint_path, exist_ok=True)

    # Setup logging system
    logger = setup_logger(
        "Listen_to_look, audio_preview classification single modality",
        args.checkpoint_path, True)
    logger.debug(args)

    # Epoch logging
    epoch_log = setup_logger("Listen_to_look: results", args.checkpoint_path,
                             True, logname="epoch.log")
    epoch_log.info("epoch,loss,acc,lr")

    writer = None
    if args.visualization:
        writer = setup_tbx(args.checkpoint_path, True)
    if writer is not None:
        logger.info("Allowed Tensorboard writer")

    # Define the model
    builder = ModelBuilder()
    net_classifier = builder.build_classifierNet(args.embedding_size,
                                                 args.num_classes).cuda()
    net_imageAudioClassify = builder.build_imageAudioClassifierNet(
        net_classifier, args).cuda()
    model = builder.build_audioPreviewLSTM(net_classifier, args)
    model = model.cuda()

    # DATA LOADING
    train_ds, train_collate = create_training_dataset(args, logger=logger)
    val_ds, val_collate = create_validation_dataset(args, logger=logger)

    train_loader = torch.utils.data.DataLoader(train_ds,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.decode_threads,
                                               collate_fn=train_collate)
    val_loader = torch.utils.data.DataLoader(val_ds,
                                             batch_size=args.batch_size,
                                             num_workers=4,
                                             collate_fn=val_collate)

    args.iters_per_epoch = len(train_loader)
    args.warmup_iters = args.warmup_epochs * args.iters_per_epoch
    args.milestones = [args.iters_per_epoch * m for m in args.milestones]

    # define loss function (criterion) and optimizer
    criterion = {}
    criterion['CrossEntropyLoss'] = nn.CrossEntropyLoss().cuda()

    if args.freeze_imageAudioNet:
        param_groups = [{
            'params': model.queryfeature_mlp.parameters(),
            'lr': args.lr
        }, {
            'params': model.prediction_fc.parameters(),
            'lr': args.lr
        }, {
            'params': model.key_conv1x1.parameters(),
            'lr': args.lr
        }, {
            'params': model.rnn.parameters(),
            'lr': args.lr
        }, {
            'params': net_classifier.parameters(),
            'lr': args.lr
        }]
        optimizer = torch.optim.SGD(param_groups,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=1)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=1)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.milestones)
    # make optimizer scheduler
    if args.scheduler:
        scheduler = default_lr_scheduler(optimizer, args.milestones,
                                         args.warmup_iters)

    cudnn.benchmark = True

    # setting up the checkpointing system
    write_here = True
    checkpointer = Checkpointer(model,
                                optimizer,
                                save_dir=args.checkpoint_path,
                                save_to_disk=write_here,
                                scheduler=scheduler,
                                logger=logger)
    if args.pretrained_model is not None:
        logger.debug("Loading model only at: {}".format(args.pretrained_model))
        checkpointer.load_model_only(f=args.pretrained_model)
    if checkpointer.has_checkpoint():
        # call load checkpoint
        logger.debug("Loading last checkpoint")
        checkpointer.load()

    model = torch.nn.parallel.DataParallel(model).cuda()
    logger.debug(model)

    # Log all info
    if writer:
        writer.add_text("namespace", repr(args))
        writer.add_text("model", str(model))

    #
    # TRAINING
    #
    logger.debug("Entering the training loop")
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_accuracy, train_loss = train_epoch(args,
                                                 epoch,
                                                 train_loader,
                                                 model,
                                                 criterion,
                                                 optimizer,
                                                 scheduler,
                                                 logger,
                                                 epoch_logger=epoch_log,
                                                 checkpointer=checkpointer,
                                                 writer=writer)
        test_map, test_accuracy, test_loss, _ = validate(args,
                                                         epoch,
                                                         val_loader,
                                                         model,
                                                         criterion,
                                                         epoch_logger=epoch_log,
                                                         writer=writer)
        if writer is not None:
            writer.add_scalars('training_curves/accuracies', {
                'train': train_accuracy,
                'val': test_accuracy
            }, epoch)
            writer.add_scalars('training_curves/loss', {
                'train': train_loss,
                'val': test_loss
            }, epoch)
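
# default_lr_scheduler (selected above when args.scheduler is set) is assumed to combine a
# linear warmup with the MultiStep milestones; a minimal sketch using LambdaLR, with the
# warmup_factor/gamma defaults as placeholders rather than the repo's settings.
def default_lr_scheduler(optimizer, milestones, warmup_iters, gamma=0.1, warmup_factor=0.001):
    def lr_lambda(iteration):
        warmup = 1.0
        if warmup_iters > 0 and iteration < warmup_iters:
            alpha = iteration / warmup_iters
            warmup = warmup_factor * (1 - alpha) + alpha
        # decay by gamma at every milestone that has already passed
        return warmup * gamma ** len([m for m in milestones if m <= iteration])
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)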