def get_fields(data_type, n_src_features, n_tgt_features):
    """Build the `torchtext.data.Field` dict for the given source modality.

    Args:
        data_type: type of the source input.
            Options are [text|img|audio|gcn].
        n_src_features: the number of source features to create
            `torchtext.data.Field` for.
        n_tgt_features: the number of target features to create
            `torchtext.data.Field` for.

    Returns:
        A dictionary whose keys are strings and whose values
        are the corresponding Field objects.

    Raises:
        ValueError: if `data_type` is not one of the supported options.
    """
    if data_type == 'text':
        return TextDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'img':
        return ImageDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'audio':
        return AudioDataset.get_fields(n_src_features, n_tgt_features)
    elif data_type == 'gcn':
        return GCNDataset.get_fields(n_src_features, n_tgt_features)
    else:
        # Fail loudly instead of silently returning None, which would only
        # surface later as a confusing AttributeError at the call site.
        raise ValueError("Unsupported data_type: %r "
                         "(expected one of: text, img, audio, gcn)" % data_type)
def main():
    """Entry point: load data/vocab, assemble per-modality fields, build the
    multi-encoder model and optimizer, then run training.

    Relies on module-level globals (``opt``, ``torch``, ``lazily_load_dataset``,
    ``TextDataset``, ``AudioDataset``, ``build_multiencoder_model``,
    ``tally_parameters``, ``check_save_model_path``, ``build_optim``,
    ``train_model``, ``defaultdict``) being in scope.
    """
    # Todo: Load checkpoint if we resume from a previous training.
    if opt.train_from:  # opt.train_from defaults 'False'.
        print('Loading checkpoint from %s' % opt.train_from)
        # map_location keeps tensors on CPU regardless of where they were saved.
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None
        model_opt = opt

    # Peek at one shard of the training data; its examples' attribute names
    # are used below to filter the field dicts down to keys actually present.
    train_dataset = lazily_load_dataset('train')
    ex_generator = next(train_dataset)
    # Expected example attributes (per the original author's note):
    # # {'indices': 0,
    # #  'src': None,  # will not be used. should be removed when preparing data.
    # #  'src_audio': FloatTensor,
    # #  'src_path': wav path,  # will not be used. should be removed when preparing data.
    # #  'src_text': tuple,
    # #  'tgt': tuple
    # # }
    # For debug.
    # ex=ex_generator[0]
    # getattr(ex,'src_audio',None)
    # getattr(ex,'src_text',None)
    # getattr(ex,'tgt',None)
    pass

    # load vocab
    vocabs = torch.load(opt.data + '.vocab.pt')  # keys: 'src_text', 'tgt'
    vocabs = dict(vocabs)
    pass

    # get fields, we attempt to use dict to store fields for different
    # encoders(source data).
    # Here we set number of src_features and tgt_features to 0.
    # Actually, we can use these features, but it need more modifications.
    text_fields = TextDataset.get_fields(0, 0)
    audio_fields = AudioDataset.get_fields(0, 0)
    # fields['src_text'] = fields['src']  # Copy key from 'src' to 'src_text'.
    # for assigning the field for text type input.
    # the field for audio type input will not be made, i.e.,
    # fields['src_audio']=audio_fields['src']. Because it will not be used next.
    for k, v in vocabs.items():
        # Re-wrap stoi so unknown tokens map to index 0 (the <unk> slot)
        # instead of raising KeyError.
        v.stoi = defaultdict(lambda: 0, v.stoi)
        if k == 'src_text':
            text_fields['src'].vocab = v
        else:
            # Target vocab is shared between the text and audio field sets.
            text_fields['tgt'].vocab = v
            audio_fields['tgt'].vocab = v
    # Keep only the fields whose keys exist as attributes on the examples.
    # NOTE(review): ex_generator[0] assumes next(train_dataset) is indexable —
    # confirm lazily_load_dataset yields (dataset, ...) tuples or similar.
    text_fields = dict([(k, f) for (k, f) in text_fields.items()
                        if k in ex_generator[0].__dict__
                        ])  # 'indices', 'src', 'src_text', 'tgt'
    audio_fields = dict([(k, f) for (k, f) in audio_fields.items()
                         if k in ex_generator[0].__dict__])
    print(' * vocabulary size. \ntext source = %d; target = %d' %
          (len(text_fields['src'].vocab), len(text_fields['tgt'].vocab)))
    print(' * vocabulary size. audio target = %d' %
          len(audio_fields['tgt'].vocab))
    fields_dict = {'text': text_fields, 'audio': audio_fields}
    pass

    # Build model.
    model = build_multiencoder_model(
        model_opt, opt, fields_dict)  # TODO: support using 'checkpoint'.
    tally_parameters(model)
    check_save_model_path()

    # Build optimizer.
    optim = build_optim(model)  # TODO: support using 'checkpoint'.

    # Do training.
    train_model(model, fields_dict, optim, data_type='multi',
                model_opt=model_opt)