def main(args):  # pylint: disable=redefined-outer-name
    """Build the speaker encoder, optionally restore a checkpoint, and train.

    Uses the module-level config ``c`` and the ``use_cuda`` flag, and assigns
    the module-level ``meta_data_train`` / ``meta_data_eval`` before calling
    ``train``.

    Args:
        args: Argument namespace. ``args.restore_path`` is read; when it is
            truthy the checkpoint at that path is loaded. ``args.restore_step``
            is written: the checkpoint's ``'step'`` value, or 0 when nothing
            is restored.
    """
    # pylint: disable=global-variable-undefined
    global meta_data_train
    global meta_data_eval

    ap = AudioProcessor(**c.audio)
    model = SpeakerEncoder(input_dim=40,
                           proj_dim=128,
                           lstm_dim=384,
                           num_lstm_layers=3)
    optimizer = RAdam(model.parameters(), lr=c.lr)
    criterion = GE2ELoss(loss_method='softmax')

    if args.restore_path:
        # The model is still on the CPU here (it is only moved by the
        # model.cuda() call further down), so map the checkpoint tensors to
        # the CPU. This also allows restoring a GPU-saved checkpoint on a
        # machine without that device.
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                # Deliberately jump to the partial-initialization path below.
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except (KeyError, RuntimeError):
            # RuntimeError must be caught as well as KeyError: it is what the
            # reinit_layers branch above raises, and what load_state_dict
            # raises on mismatched keys/shapes. Catching only KeyError made
            # this fallback unreachable in both cases.
            print(" > Partial model initialization.")
            # NOTE(review): the whole checkpoint dict (not checkpoint['model'])
            # is passed to set_init_dict, as in the original code -- confirm
            # this matches what set_init_dict expects.
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        # The optimizer state is not restored, so force the configured
        # learning rate onto every parameter group.
        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model = model.cuda()
        criterion.cuda()

    if c.lr_decay:
        # last_epoch is restore_step - 1 so the schedule resumes at the
        # restored step instead of restarting the warm-up.
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    # pylint: disable=redefined-outer-name
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    # The values returned by train() were never used here, so they are not
    # bound to local names.
    train(model, criterion, optimizer, scheduler, ap, args.restore_step)
def test_in_out(self):
    """Sanity-check the softmax GE2ELoss on random, constant and orthogonal inputs.

    Every input is laid out as num_speaker x num_utterance x dim.
    """
    # check random input
    dummy_input = T.rand(4, 5, 64)  # num_speaker x num_utterance x dim
    loss = GE2ELoss(loss_method="softmax")
    output = loss.forward(dummy_input)
    assert output.item() >= 0.0

    # check all ones (every embedding identical); the original computed the
    # loss here but asserted nothing, so apply the same non-negativity check
    # used for the random input above
    dummy_input = T.ones(4, 5, 64)  # num_speaker x num_utterance x dim
    loss = GE2ELoss(loss_method="softmax")
    output = loss.forward(dummy_input)
    assert output.item() >= 0.0

    # check speaker loss with orthogonal d-vectors: each of the 3 speakers
    # gets 5 identical utterance embeddings, and the speakers' embeddings are
    # mutually orthogonal
    d_vectors = T.nn.init.orthogonal_(T.empty(3, 64))  # in-place, non-deprecated initializer
    dummy_input = d_vectors.unsqueeze(1).repeat(1, 5, 1)  # num_speaker x num_utterance x dim
    loss = GE2ELoss(loss_method="softmax")
    output = loss.forward(dummy_input)
    assert output.item() < 0.005