Exemplo n.º 1
0
def main():
    """Train a DeepSpeaker embedding model, or evaluate a saved one.

    The run configuration comes from ``get_args()``.

    * With ``args.train`` set: features are extracted when the feature
      directory does not exist yet, the model and optimizer are built,
      optionally restored from ``args.resume``, and handed to ``fit``.
    * Otherwise: an ``EmbeddingNet`` is built on the CPU, optionally restored
      from ``args.model``, and handed to ``predict`` together with the
      threshold whose recorded accuracy (``acc.txt``) is highest.

    Raises:
        ValueError: if training is requested and ``args.selection`` is not one
            of ``'randomhard'``, ``'hardest'``, ``'semihard'`` or ``'all'``.
    """
    args = get_args()

    # NOTE: there is no separator between the margin and the class-count
    # fields. The format string is kept unchanged so that directories written
    # by earlier runs (and their thres.txt / acc.txt) still resolve.
    logdir = (
        'log/{}-emb{}-{}layers-{}resblk-lr{}-wd{}-maxlen{}-alpha10-margin{}'
        '{}class-{}sample-{}selector'
    ).format(
        args.name,
        args.embedding_size,
        args.layers,
        args.resblk,
        args.lr,
        args.wd,
        args.maxlen,
        args.margin,
        args.n_classes,
        args.n_samples,
        args.selection,
    )
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    # One entry per layer, each holding the same residual-block count.
    resblock = [args.resblk] * args.layers

    if args.train:
        logger = Logger(logdir)
        # Features are extracted only when the feature directory is missing.
        if not os.path.exists(args.trainfeature):
            os.mkdir(args.trainfeature)
            # ``args.training-dataset`` parsed as ``args.training - dataset``.
            # Assumes the option is declared as ``--training-dataset``, which
            # argparse exposes as ``training_dataset`` -- confirm in get_args.
            extractFeature(args.training_dataset, args.trainfeature)
        trainset = DeepSpkDataset(args.trainfeature, args.maxlen)
        pre_loader = DataLoader(trainset,
                                batch_size=128,
                                shuffle=True,
                                num_workers=8)
        train_batch_sampler = BalancedBatchSampler(trainset.train_labels,
                                                   n_classes=args.n_classes,
                                                   n_samples=args.n_samples)
        online_train_loader = torch.utils.data.DataLoader(
            trainset,
            batch_sampler=train_batch_sampler,
            num_workers=1,
            pin_memory=True)
        margin = args.margin

        embedding_net = EmbeddingNet(resblock,
                                     embedding_size=args.embedding_size,
                                     layers=args.layers)
        model = DeepSpeaker(embedding_net, trainset.get_num_class())
        device = torch.device('cuda:0')
        # Move the model to the GPU before creating the optimizer, so the
        # optimizer is initialised against the GPU-resident parameters.
        model.to(device)
        # Only the embedding network's parameters are handed to the optimizer.
        optimizer = optim.SGD(model.embedding_net.parameters(),
                              lr=args.lr,
                              momentum=0.99,
                              weight_decay=args.wd)
        start_epoch = 0
        if args.resume:
            if os.path.isfile(args.resume):
                print('=> loading checkpoint {}'.format(args.resume))
                # Map the stored tensors onto the training device regardless
                # of the device they were saved from.
                checkpoint = torch.load(args.resume, map_location=device)
                start_epoch = checkpoint['epoch']
                model.embedding_net.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                # Best effort: a missing checkpoint means training from scratch.
                print('=> no checkpoint found at {}'.format(args.resume))

        pretrain_epoch = args.pretrain_epoch

        # Exactly one selector must be chosen; an unknown name used to leave
        # ``selector`` unbound and fail later with an unrelated NameError.
        if args.selection == 'randomhard':
            selector = RandomNegativeTripletSelector(margin)
        elif args.selection == 'hardest':
            selector = HardestNegativeTripletSelector(margin)
        elif args.selection == 'semihard':
            selector = SemihardNegativeTripletSelector(margin)
        elif args.selection == 'all':
            print('warning : select all triplet may take very long time')
            selector = AllTripletSelector()
        else:
            raise ValueError(
                'unknown triplet selection {!r}; expected one of '
                "'randomhard', 'hardest', 'semihard', 'all'"
                .format(args.selection))

        loss_fn = OnlineTripletLoss(margin, selector)
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=args.lr_adjust_step,
                                        gamma=args.lr_decay,
                                        last_epoch=-1)
        n_epochs = args.n_epochs
        log_interval = 50
        fit(online_train_loader,
            pre_loader,
            model,
            loss_fn,
            optimizer,
            scheduler,
            pretrain_epoch,
            n_epochs,
            True,  # presumably the "use CUDA" flag -- confirm against fit()
            device,
            log_interval,
            log_dir=logdir,
            eval_path=args.evalfeature,
            logger=logger,
            metrics=[AverageNonzeroTripletsMetric()],
            evaluatee=args.eval,
            start_epoch=start_epoch)
    else:
        if not os.path.exists(args.testfeature):
            os.mkdir(args.testfeature)
            # Same dash/underscore fix as for the training dataset above.
            extractFeature(args.test_dataset, args.testfeature)
        model = EmbeddingNet(resblock,
                             embedding_size=args.embedding_size,
                             layers=args.layers)
        model.cpu()
        if args.model:
            if os.path.isfile(args.model):
                print('=> loading checkpoint {}'.format(args.model))
                # The model lives on the CPU here; without map_location a
                # checkpoint saved from a GPU cannot be loaded on a machine
                # that has no CUDA device.
                checkpoint = torch.load(args.model, map_location='cpu')
                model.load_state_dict(checkpoint['state_dict'])
            else:
                # Best effort: prediction continues with untrained weights.
                print('=> no checkpoint found at {}'.format(args.model))
        # np.loadtxt returns a 0-d array for a single-value file; force 1-d so
        # that indexing with the argmax position always works.
        thres = np.atleast_1d(np.loadtxt(os.path.join(logdir, 'thres.txt')))
        acc = np.atleast_1d(np.loadtxt(os.path.join(logdir, 'acc.txt')))
        idx = np.argmax(acc)
        best_thres = thres[idx]
        predict(model, args.testfeature, best_thres)
        # NOTE(review): this fragment reads X_trainE, hdm, mrg, lre and epch,
        # none of which are defined in the visible source. It appears to belong
        # to a different function than the code directly above -- confirm.

        # Unpack the two dimensions of X_trainE (requires a 2-D shape);
        # presumably (number of samples, input feature dimension) -- confirm.
        n_sampE, IE_dim = X_trainE.shape

        # Local aliases for the externally supplied settings.
        h_dim = hdm  # used below as the encoder's output width
        Z_in = h_dim
        marg = mrg  # passed to the random-negative triplet selector below
        lrE = lre
        epoch = epch

        # Empty accumulators; presumably train/test cost and AUC histories --
        # confirm against the code that fills them.
        costtr = []
        auctr = []
        costts = []
        aucts = []

        # Two triplet selectors: one built from ``marg``, one with no arguments.
        triplet_selector = RandomNegativeTripletSelector(marg)
        triplet_selector2 = AllTripletSelector()

        class AEE(nn.Module):
            """Encoder: Linear(IE_dim -> h_dim), BatchNorm1d, ReLU, Dropout.

            ``IE_dim`` and ``h_dim`` are read from the enclosing scope; the
            dropout layer uses its default probability.
            """

            def __init__(self):
                super(AEE, self).__init__()
                # Layers are created in the same order as they are applied.
                encoder_layers = [
                    nn.Linear(IE_dim, h_dim),
                    nn.BatchNorm1d(h_dim),
                    nn.ReLU(),
                    nn.Dropout(),
                ]
                self.EnE = torch.nn.Sequential(*encoder_layers)

            def forward(self, x):
                """Return the encoder output for input ``x``."""
                return self.EnE(x)

        class OnlineTriplet(nn.Module):
            def __init__(self, marg, triplet_selector):
# ``model_trained_mobilenet`` is an alias for ``model``, not a copy.
model_trained_mobilenet = model

# Enable gradients for every parameter of the classifier sub-module.
for param in model_trained_mobilenet.classifier.parameters():
    param.requires_grad = True

# List the name and shape of each parameter that will receive gradients.
print("All trainable parameters of model are")
for name, param in model_trained_mobilenet.named_parameters():
    if not param.requires_grad:
        continue
    print(name, param.shape)

# Both wrappers are constructed from the same backbone object.
contact_model, contact_less_model = (
    multi_task_model_classification(model_trained_mobilenet),
    multi_task_model_classification(model_trained_mobilenet),
)

# In[22]:

# Move both models to the GPU when the ``cuda`` flag is set.
if cuda:
    contact_less_model.cuda()
    contact_model.cuda()

# Online triplet loss driven by an AllTripletSelector instance.
AllTripletSelector1 = AllTripletSelector()
loss_fn = OnlineTripletLoss(margin, AllTripletSelector1)

# Each model gets its own Adam optimizer with the same hyper-parameters.
lr = 1e-3
contact_optimizer = optim.Adam(
    contact_model.parameters(), lr=lr, weight_decay=1e-4
)
contact_less_optimizer = optim.Adam(
    contact_less_model.parameters(), lr=lr, weight_decay=1e-4
)

# Step schedule for the contact optimizer: step size 8, decay factor 0.1.
contact_scheduler = lr_scheduler.StepLR(
    contact_optimizer, step_size=8, gamma=0.1, last_epoch=-1
)
contact_less_scheduler = lr_scheduler.StepLR(contact_less_optimizer,
                                             8,
                                             gamma=0.1,