예제 #1
0
파일: train.py 프로젝트: cs50victor/riri
def main(args):  # pylint: disable=redefined-outer-name
    """Build the speaker encoder, optionally restore a checkpoint, and train.

    Relies on module-level names defined outside this function, notably the
    config object ``c``, the ``use_cuda`` flag and the ``train`` loop.

    Args:
        args: parsed command-line arguments. ``args.restore_path`` is read;
            ``args.restore_step`` is written (the checkpoint's ``'step'`` when
            restoring, otherwise 0).

    Side effects:
        Rebinds the module-level ``meta_data_train`` / ``meta_data_eval`` to
        the result of ``load_meta_data(c.datasets)``.
    """
    # pylint: disable=global-variable-undefined
    global meta_data_train
    global meta_data_eval

    ap = AudioProcessor(**c.audio)
    model = SpeakerEncoder(input_dim=40,
                           proj_dim=128,
                           lstm_dim=384,
                           num_lstm_layers=3)
    optimizer = RAdam(model.parameters(), lr=c.lr)
    criterion = GE2ELoss(loss_method='softmax')

    if args.restore_path:
        # Map tensors to the CPU so a checkpoint written on a GPU can also be
        # restored on a CPU-only machine; the model is moved to the GPU below.
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                # Deliberately jump to the partial-initialization path below.
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except (KeyError, RuntimeError):
            # RuntimeError must be caught here: it is raised both by the
            # reinit_layers request above and by a strict load_state_dict
            # whose keys/shapes do not match the current model.
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            # presumably copies only the compatible checkpoint entries into
            # model_dict -- see set_init_dict for the exact rules
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        for group in optimizer.param_groups:
            group['lr'] = c.lr
            # The optimizer state is not restored, so the base LR an LR
            # scheduler needs when resuming (last_epoch >= 0) must be set
            # explicitly; otherwise constructing it raises KeyError.
            group['initial_lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model = model.cuda()
        criterion.cuda()

    if c.lr_decay:
        # Resume the schedule at the restored step (-1 means a fresh start).
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    # pylint: disable=redefined-outer-name
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    global_step = args.restore_step
    train_loss, global_step = train(model, criterion, optimizer, scheduler, ap,
                                    global_step)
예제 #2
0
 def test_in_out(self):
     """Sanity-check GE2ELoss(loss_method="softmax") on three kinds of input.

     Every input is laid out as num_speaker x num_utterance x dim: random
     values, a constant (all-ones) batch, and a batch in which each speaker's
     utterances all equal one of three mutually orthogonal vectors.
     """
     # check random input: the loss must be non-negative
     dummy_input = T.rand(4, 5, 64)  # num_speaker x num_utterance x dim
     loss = GE2ELoss(loss_method="softmax")
     output = loss.forward(dummy_input)
     assert output.item() >= 0.0
     # check all ones (every d-vector identical): the loss must still be
     # non-negative (a NaN here would also fail this comparison)
     dummy_input = T.ones(4, 5, 64)  # num_speaker x num_utterance x dim
     loss = GE2ELoss(loss_method="softmax")
     output = loss.forward(dummy_input)
     assert output.item() >= 0.0
     # check speaker loss with orthogonal d-vectors: speakers are perfectly
     # separated, so the loss is expected to be close to zero
     dummy_input = T.empty(3, 64)
     # orthogonal_ is the in-place initializer; the un-suffixed
     # T.nn.init.orthogonal is a deprecated alias of it.
     dummy_input = T.nn.init.orthogonal_(dummy_input)
     dummy_input = T.cat([
         # (64,) -> (5, 1, 64) -> (1, 5, 64): one speaker, 5 utterances
         row.repeat(5, 1, 1).transpose(0, 1)
         for row in dummy_input
     ])  # num_speaker x num_utterance x dim
     loss = GE2ELoss(loss_method="softmax")
     output = loss.forward(dummy_input)
     assert output.item() < 0.005