コード例 #1
0
def main():
    """Train a VGG triplet-loss speaker-embedding network.

    Parses command-line arguments, builds train/test triplet data loaders,
    optionally resumes from a checkpoint, trains a VGGVox model with
    nn.TripletMarginLoss and RMSprop, and saves the final weights.
    Experiment tracking is done through Weights & Biases (wandb).
    """
    # Train settings
    wandb.init(project="vgg_triplet")
    global args, best_acc

    def str2bool(value):
        """Parse 'True'/'False'-style CLI strings into real booleans.

        argparse's ``type=bool`` is a trap: ``bool('False')`` is True because
        any non-empty string is truthy.  This converter keeps the existing
        CLI syntax (``--save True`` / ``--save False``) working while making
        'False' actually parse to False.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'boolean value expected, got {!r}'.format(value))

    parser = argparse.ArgumentParser(
        description='VGG Triplet-Loss Speaker Embedding')
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        metavar='N',
                        help='input batch size for training')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=32,
                        metavar='N',
                        help='input batch size for testing')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        metavar='N',
                        help='number of epochs for training')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        metavar='LR',
                        help='learning rate')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum')
    # Fixed help text: this flag *disables* CUDA (the old text said "enables").
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=2,
                        metavar='S',
                        help='random seed')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=20,
        metavar='N',
        help='how many batches to wait before logging training score')
    # Fixed help text: the actual default is 2, not 0.2.
    parser.add_argument('--margin',
                        type=float,
                        default=2,
                        metavar='M',
                        help='margin for triplet loss (default: 2)')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        help='path to latest checkpoint (default: None)')
    parser.add_argument('--name',
                        default='TripletNet_RMSprop',
                        type=str,
                        help='name of experiment')
    parser.add_argument(
        '--base-path',
        type=str,
        default=
        '/home/lucas/PycharmProjects/Papers_with_code/data/AMI/amicorpus_individual/Extracted_Speech',
        help='string to triplets')
    parser.add_argument('--ap-file',
                        default='anchor_pairs.txt',
                        type=str,
                        help='name of file with anchor-positive pairs')
    parser.add_argument('--s-file',
                        default='trimmed_sample_list.txt',
                        type=str,
                        help='name of sample list')
    parser.add_argument(
        '--save-path',
        default=
        '/home/lucas/PycharmProjects/Papers_with_code/data/models/VGG_Triplet',
        type=str,
        help='path to save models to')
    # type=str2bool (was type=bool): with type=bool, "--save False" was True.
    parser.add_argument('--save',
                        type=str2bool,
                        default=True,
                        help='save model?')
    parser.add_argument('--load',
                        type=str2bool,
                        default=False,
                        help='load model from latest checkpoint')
    args = parser.parse_args()

    # Name the wandb run after the experiment and record the full config.
    wandb.run.name = args.name
    wandb.run.save()
    wandb.config.update(args)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args.no_cuda)
    print(torch.cuda.is_available())
    device = torch.device("cuda:0") if args.cuda else torch.device("cpu")
    torch.manual_seed(args.seed)
    # Worker processes / pinned memory only pay off when feeding a GPU.
    kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {}

    train_time_loader = torch.utils.data.DataLoader(Triplet_Time_Loader(
        path=os.path.join(args.base_path, args.s_file), train=True),
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    **kwargs)
    test_time_loader = torch.utils.data.DataLoader(
        Triplet_Time_Loader(path=os.path.join(args.base_path, args.s_file),
                            train=False),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    model = VGGVox()
    if args.cuda:
        model.to(device)
    if args.load:
        # NOTE(review): args.save_path is also the torch.save() target below,
        # so this expects a plain state_dict file, not a checkpoint dict.
        model.load_state_dict(torch.load(args.save_path))
        print("Model loaded from state dict")
    wandb.watch(model)

    # Optionally resume from a checkpoint (a dict with 'epoch',
    # 'best_prec1' and 'state_dict' keys).
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # NOTE(review): best_prec1 is an unused local; presumably this was
            # meant to restore the global best_acc — confirm.
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Let cuDNN auto-tune convolution algorithms for the fixed input sizes.
    cudnn.benchmark = True

    # NOTE(review): `criterion` is never passed to train_batch/test_batch
    # below — presumably they construct their own loss; confirm.
    criterion = nn.TripletMarginLoss(margin=args.margin, p=2)

    optimizer = optim.RMSprop(model.parameters(),
                              lr=args.lr,
                              alpha=0.8,
                              momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train_batch(train_time_loader, model, optimizer, epoch)
        test_batch(test_time_loader, model, epoch)
        duration = time.time() - start_time
        print("Done training epoch {} in {:.4f}".format(epoch, duration))

    if args.save:
        torch.save(model.state_dict(), args.save_path)
        print("Model Saved")
コード例 #2
0
def main():
    """Train a VGGVox CNN on spectrogram triplets for speaker verification.

    Parses command-line arguments, seeds all RNG sources for
    reproducibility, builds train/test/validation spectrogram loaders,
    trains with RMSprop, logs metrics to wandb, and every 20 epochs saves
    the model plus a validation-set embedding plot.
    """
    global args
    wandb.init(project="vgg_triplet")
    config = wandb.config

    def str2bool(value):
        """Parse 'True'/'False'-style CLI strings into real booleans.

        argparse's ``type=bool`` is a trap: ``bool('False')`` is True because
        any non-empty string is truthy.  This converter keeps the existing
        CLI syntax (``--load-model True/False``) working while making
        'False' actually parse to False.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'boolean value expected, got {!r}'.format(value))

    parser = argparse.ArgumentParser(
        description="VGGVox CNN with Spectrograms for Speaker Verification")
    # Fixed: the train/test help strings were swapped.
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=40,
                        metavar='N',
                        help='input batch size for testing')
    parser.add_argument('--train-batch-size',
                        type=int,
                        default=40,
                        metavar='N',
                        help='input batch size for training')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        metavar='LR',
                        help='learning rate')
    parser.add_argument('--seed',
                        type=int,
                        default=1234,
                        metavar='S',
                        help='random seed')
    # Fixed help text: the actual default is 8, not 0.2.
    parser.add_argument('--margin',
                        type=float,
                        default=8,
                        metavar='M',
                        help='margin for triplet loss (default: 8)')
    parser.add_argument('--name',
                        default='VGG_Spectogram_Triplet',
                        type=str,
                        help='name of network')
    parser.add_argument(
        '--train-set',
        default=
        '/home/lucas/PycharmProjects/Data/pyannote/Extracted_Speech/trimmed_sample_list_train.txt',
        type=str,
        help='path to train samples')
    parser.add_argument(
        '--test-set',
        default=
        '/home/lucas/PycharmProjects/Data/pyannote/Extracted_Speech/trimmed_sample_list_test.txt',
        type=str,
        help='path to test samples')
    parser.add_argument(
        '--valid-set',
        default=
        '/home/lucas/PycharmProjects/Data/pyannote/Extracted_Speech/trimmed_sample_list_valid.txt',
        type=str,
        help='path to validation samples')
    parser.add_argument(
        '--model-path',
        default=
        '/home/lucas/PycharmProjects/MetricEmbeddingNet/models/VGG_Spectrogram_Triplet.pt',
        type=str,
        help='path to where models are saved/loaded')
    # type=str2bool (was type=bool): with type=bool, "--load-model False"
    # was parsed as True.
    parser.add_argument('--save-model',
                        type=str2bool,
                        default=True,
                        help='save model?')
    parser.add_argument('--load-model',
                        type=str2bool,
                        default=False,
                        help='load model?')
    parser.add_argument('--melspectrogram',
                        type=str2bool,
                        default=False,
                        help='use melspectrogram?')
    args = parser.parse_args()

    wandb.config.update(args)

    # Seed every RNG source so runs are reproducible.
    torch.manual_seed(config.seed)
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.backends.cudnn.deterministic = True
    # NOTE(review): assumes a CUDA device is present; no CPU fallback.
    device = torch.device('cuda:0')

    kwargs = {'num_workers': 6, 'pin_memory': True}
    # NOTE(review): mel=False is hard-coded in all three loaders; the
    # --melspectrogram flag is never forwarded — confirm whether it should be.
    train_loader = data.DataLoader(Spectrogram_Loader(filename=args.train_set,
                                                      mel=False),
                                   batch_size=config.train_batch_size,
                                   shuffle=True,
                                   **kwargs)
    test_loader = data.DataLoader(Spectrogram_Loader(filename=args.test_set,
                                                     mel=False),
                                  batch_size=config.test_batch_size,
                                  shuffle=True,
                                  **kwargs)
    valid_loader = data.DataLoader(Spectrogram_Loader(filename=args.valid_set,
                                                      mel=False),
                                   batch_size=config.test_batch_size,
                                   shuffle=True,
                                   **kwargs)

    model = VGGVox()
    model.to(device)

    if args.load_model:
        try:
            model.load_state_dict(torch.load(args.model_path))
        except (FileNotFoundError, RuntimeError) as exc:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; a missing or incompatible checkpoint simply
            # means training starts from scratch.
            print("Could not load model: {} ({})".format(args.model_path, exc))

    # NOTE(review): lr is hard-coded to 0.001; the --lr argument
    # (config.lr, default 0.0001) is ignored — confirm which is intended.
    optimizer = optim.RMSprop(model.parameters(),
                              lr=0.001,
                              alpha=0.8,
                              momentum=0.5)
    wandb.watch(model)

    for epoch in range(1, config.epochs + 1):
        start_time = time.time()
        train_loss, train_acc = train(train_loader=train_loader,
                                      model=model,
                                      optimizer=optimizer,
                                      device=device,
                                      epoch=epoch)
        test_loss, test_acc = test(data_loader=test_loader,
                                   model=model,
                                   device=device,
                                   epoch=epoch)
        print('Finished epoch {} in {:.2f} '.format(
            epoch, (time.time() - start_time)))
        wandb.log({
            'Train Loss': train_loss,
            'Train Accuracy': train_acc,
            'Test Loss': test_loss,
            'Test Accuracy': test_acc
        })
        # Checkpoint (and plot the validation embeddings) every 20 epochs.
        if config.save_model and (epoch % 20 == 0):
            torch.save(model.state_dict(), config.model_path)
            print("Model saved after {} epochs".format(epoch))
            plot_validation_set(valid_loader=valid_loader,
                                model=model,
                                epoch=epoch)
            print("Validation plot saved")