def build_dataset(fields, data_type, src_path, tgt_path, src_dir=None,
                  second_data_type=None, second_src_path=None,
                  src_seq_length=0, tgt_seq_length=0,
                  src_seq_length_trunc=0, tgt_seq_length_trunc=0,
                  dynamic_dict=True, sample_rate=0,
                  window_size=0, window_stride=0, window=None,
                  normalize_audio=True, use_filter_pred=True,
                  file_to_tensor_fn=None):
    use_second_modality = second_data_type is not None
    if use_second_modality:
        # Only implemented when the primary input type is text.
        assert data_type == 'text'
        # The second data type must not be text: secondary text could simply
        # be appended to the primary input instead.
        assert second_data_type != 'text', 'second_data_type cannot be text.'
        assert second_src_path is not None and src_dir is not None, \
            'If second_data_type is set, both second_src_path and src_dir ' \
            'must be provided.'

    # Build the src/tgt example iterators from the corpus files and extract
    # the number of features.
    src_examples_iter, num_src_feats = \
        _make_examples_nfeats_tpl(data_type, src_path, src_dir,
                                  src_seq_length_trunc, sample_rate,
                                  window_size, window_stride, window,
                                  normalize_audio,
                                  file_to_tensor_fn=file_to_tensor_fn)
    if use_second_modality:
        src2_examples_iter, num_src2_feats = \
            _make_examples_nfeats_tpl(second_data_type, second_src_path,
                                      src_dir, src_seq_length_trunc,
                                      sample_rate, window_size, window_stride,
                                      window, normalize_audio, side='src2',
                                      file_to_tensor_fn=file_to_tensor_fn)

    # For all data types, the tgt-side corpus is text.
    tgt_examples_iter, num_tgt_feats = \
        TextDataset.make_text_examples_nfeats_tpl(
            tgt_path, tgt_seq_length_trunc, "tgt")

    if use_second_modality:
        dataset = MultiModalDataset(
            fields, src_examples_iter, src2_examples_iter, second_data_type,
            tgt_examples_iter, num_src_feats, num_src2_feats, num_tgt_feats,
            src_seq_length=src_seq_length, tgt_seq_length=tgt_seq_length,
            use_filter_pred=use_filter_pred)
    elif data_type == 'text':
        dataset = TextDataset(
            fields, src_examples_iter, tgt_examples_iter,
            num_src_feats, num_tgt_feats,
            src_seq_length=src_seq_length, tgt_seq_length=tgt_seq_length,
            dynamic_dict=dynamic_dict, use_filter_pred=use_filter_pred)
    elif data_type == 'img':
        dataset = ImageDataset(
            fields, src_examples_iter, tgt_examples_iter,
            num_src_feats, num_tgt_feats,
            tgt_seq_length=tgt_seq_length, use_filter_pred=use_filter_pred)
    elif data_type == 'audio':
        dataset = AudioDataset(
            fields, src_examples_iter, tgt_examples_iter,
            num_src_feats, num_tgt_feats,
            tgt_seq_length=tgt_seq_length, sample_rate=sample_rate,
            window_size=window_size, window_stride=window_stride,
            window=window, normalize_audio=normalize_audio,
            use_filter_pred=use_filter_pred)

    return dataset
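
# Illustrative sketch (not in the original script): one plausible way to call
# build_dataset for a text-primary corpus with audio as the second modality.
# The helper name, file paths, and audio parameters below are hypothetical
# placeholders chosen for illustration only.
def _example_build_text_audio_dataset(fields):
    return build_dataset(
        fields, data_type='text',
        src_path='train.src.txt', tgt_path='train.tgt.txt',
        src_dir='wav/', second_data_type='audio',
        second_src_path='train.audio.paths',
        src_seq_length=50, tgt_seq_length=50,
        sample_rate=16000, window_size=0.02, window_stride=0.01,
        window='hamming')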
def main():
    # TODO: load the checkpoint if we resume from a previous training run.
    if opt.train_from:  # opt.train_from defaults to a falsy value.
        print('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None
        model_opt = opt

    train_dataset = lazily_load_dataset('train')
    ex_generator = next(train_dataset)
    # Each example looks like:
    # {'indices': 0,
    #  'src': None,           # will not be used; should be removed when preparing data.
    #  'src_audio': FloatTensor,
    #  'src_path': wav path,  # will not be used; should be removed when preparing data.
    #  'src_text': tuple,
    #  'tgt': tuple}
    # For debugging:
    # ex = ex_generator[0]
    # getattr(ex, 'src_audio', None)
    # getattr(ex, 'src_text', None)
    # getattr(ex, 'tgt', None)

    # Load the vocabularies.
    vocabs = torch.load(opt.data + '.vocab.pt')  # keys: 'src_text', 'tgt'
    vocabs = dict(vocabs)

    # Get fields; a dict is used to store the fields for the different
    # encoders (source data types). The number of src and tgt features is set
    # to 0 here. These features could be used, but that would need more
    # modifications.
    text_fields = TextDataset.get_fields(0, 0)
    audio_fields = AudioDataset.get_fields(0, 0)
    # fields['src_text'] = fields['src']  # copy the 'src' field to 'src_text'
    # to assign the field for the text input. The corresponding field for the
    # audio input, i.e. fields['src_audio'] = audio_fields['src'], is not
    # created because it is not used below.
    for k, v in vocabs.items():
        v.stoi = defaultdict(lambda: 0, v.stoi)
        if k == 'src_text':
            text_fields['src'].vocab = v
        else:
            text_fields['tgt'].vocab = v
            audio_fields['tgt'].vocab = v
    # Keep only the fields that actually occur in the examples
    # ('indices', 'src', 'src_text', 'tgt').
    text_fields = {k: f for k, f in text_fields.items()
                   if k in ex_generator[0].__dict__}
    audio_fields = {k: f for k, f in audio_fields.items()
                    if k in ex_generator[0].__dict__}

    print(' * vocabulary size. text source = %d; target = %d'
          % (len(text_fields['src'].vocab), len(text_fields['tgt'].vocab)))
    print(' * vocabulary size. audio target = %d'
          % len(audio_fields['tgt'].vocab))

    fields_dict = {'text': text_fields, 'audio': audio_fields}

    # Build model.
    model = build_multiencoder_model(model_opt, opt,
                                     fields_dict)  # TODO: support using 'checkpoint'.
    tally_parameters(model)
    check_save_model_path()

    # Build optimizer.
    optim = build_optim(model)  # TODO: support using 'checkpoint'.

    # Do training.
    train_model(model, fields_dict, optim, data_type='multi',
                model_opt=model_opt)
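
# Hedged sketch (assumption): following OpenNMT-py's train.py convention, `opt`
# is presumably parsed at module level and the script is launched directly, e.g.
#
#     python train.py -data <data_prefix> [-train_from <checkpoint.pt>] ...
#
# with an entry-point guard along the lines of:
#
#     if __name__ == "__main__":
#         main()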