Example #1
    def train(self, dataset: BaseADDataset, net: CVDDNet):
        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Get number of attention heads
        n_attention_heads = net.n_attention_heads

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        # Initialize context vectors
        net.c.data = torch.from_numpy(
            initialize_context_vectors(
                net, train_loader, self.device)[np.newaxis, :]).to(self.device)
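        # (net.c presumably stores one context vector per attention head, with
        #  shape (1, n_attention_heads, embedding_dim); this matches the
        #  CCT = net.c @ net.c.transpose(1, 2) product used below)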

        # Set parameters and optimizer (Adam optimizer for now)
        parameters = filter(lambda p: p.requires_grad, net.parameters())
        optimizer = optim.Adam(parameters,
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()
        alpha_i = 0
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            if epoch in self.alpha_milestones:
                net.alpha = float(self.alphas[alpha_i])
                logger.info('  Temperature alpha scheduler: new alpha is %g' %
                            net.alpha)
                alpha_i += 1

            epoch_loss = 0.0
            n_batches = 0
            att_matrix = np.zeros((n_attention_heads, n_attention_heads))
            dists_per_head = ()
            epoch_start_time = time.time()
            for data in train_loader:
                _, text_batch, _, _ = data

                text_batch = text_batch.to(self.device)
                # text_batch.shape = (sentence_length, batch_size)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize

                # forward pass
                cosine_dists, context_weights, A = net(text_batch)
                # The cosine_dists are the scores that are used at test time to compute AUC (for context_dist_mean)

                scores = context_weights * cosine_dists
                # scores.shape = (batch_size, n_attention_heads)
                # A.shape = (batch_size, n_attention_heads, sentence_length)

                # orthogonality penalty: P = mean((CCT - I)^2), with CCT = net.c @ net.c^T
                I = torch.eye(n_attention_heads).to(self.device)
                CCT = net.c @ net.c.transpose(1, 2)
                P = torch.mean((CCT.squeeze() - I)**2)

                # compute loss
                loss_P = self.lambda_p * P
                loss_emp = torch.mean(torch.sum(scores, dim=1))
                loss = loss_emp + loss_P

                # Get scores
                dists_per_head += (cosine_dists.cpu().data.numpy(), )

                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net.parameters(),
                    0.5)  # clip the total gradient norm to at most 0.5
                optimizer.step()

                # Get attention matrix
                AAT = A @ A.transpose(1, 2)
                att_matrix += torch.mean(AAT, 0).cpu().data.numpy()

                epoch_loss += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
                f'| Train Loss: {epoch_loss / n_batches:.6f} |')

            # Save distances per attention head and attention matrix
            self.train_dists = np.concatenate(dists_per_head)
            self.train_att_matrix = att_matrix / n_batches
            self.train_att_matrix = self.train_att_matrix.tolist()

        self.train_time = time.time() - start_time

        # Get context vectors
        self.c = np.squeeze(net.c.cpu().data.numpy())
        self.c = self.c.tolist()

        # Get top words per context
        self.train_top_words = get_top_words_per_context(
            dataset.train_set, dataset.encoder, net, train_loader, self.device)

        # Log results
        logger.info('Training Time: {:.3f}s'.format(self.train_time))
        logger.info('Finished training.')

        # Now train our unsupervised-risk (UnsupRisk) model. Note: this branch
        # calls .numpy() on network tensors directly, so it assumes self.device
        # is the CPU.
        if self.withUnsuprisk:

            self.coslin = [
                cosinelin.Cosinelin(self.inputsize, 1)
                for i in range(glob.nheads)
            ]
            for i in range(glob.nheads):
                self.coslin[i].setWeightsFromCentroids(
                    net.c.squeeze()[i].detach().numpy())

            dataset.train_set = Subset(dataset.train_set0, dataset.alltrain)
            for i, row in enumerate(dataset.train_set):
                row['index'] = i
            tmptrloader, _ = dataset.loaders(
                batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

            trainbatches = []
            for data in tmptrloader:
                idx, text_batch, label_batch, _ = data
                trainbatches.append((idx, text_batch, label_batch))
            # flatten all batches into one list of single utterances (and labels)
            ltext, llab = [], []
            for idx, text_batch, label_batch in trainbatches:
                # text_batch has shape (utt_len, batch); transposing it and
                # calling .tolist() gives one Python list per utterance
                ltext += text_batch.transpose(0, 1).tolist()
                llab += label_batch.tolist()
            print("lens %d %d" % (len(ltext), len(llab)))
            posclassidxs = [i for i in range(len(ltext)) if llab[i] == 0]
            negclassidxs = [i for i in range(len(ltext)) if llab[i] == 1]
            print("n0 n1", len(posclassidxs), len(negclassidxs))

            # possibly add some outliers (label 1) to the training set
            random.shuffle(negclassidxs)
            ncorr = min(len(negclassidxs),
                        int(glob.traincorr * float(len(posclassidxs))))
            trainidx = posclassidxs + negclassidxs[0:ncorr]
            print("trainidx", trainidx)

            # and prepare a dev set
            try:
                with open("devdata." + str(glob.nc) + "." + str(glob.nheads),
                          "rb") as devfile:
                    devidx = pickle.load(devfile)
            except (OSError, EOFError, pickle.UnpicklingError):  # no cached dev split; build one
                random.shuffle(negclassidxs)
                ncorr = min(len(negclassidxs),
                            int(glob.devcorr * float(len(posclassidxs))))
                print("devncorr " + str(ncorr), glob.traincorr,
                      len(negclassidxs), len(posclassidxs))
                devidx = posclassidxs + negclassidxs[0:ncorr]
                with open("devdata." + str(glob.nc) + "." + str(glob.nheads),
                          "wb") as devfile:
                    pickle.dump(devidx, devfile, pickle.HIGHEST_PROTOCOL)
            ndev1 = sum([llab[i] for i in devidx])
            ndev0 = len(devidx) - ndev1
            print("DEV %d %d" % (ndev0, ndev1))

            if True:
                # also prepare the test set for our unsuprisk
                _, tstloader = dataset.loaders(
                    batch_size=self.batch_size,
                    num_workers=self.n_jobs_dataloader)
                testbatches = []
                for data in tstloader:
                    idx, text_batch, label_batch, _ = data
                    testbatches.append((idx, text_batch, label_batch))
                testtext = []
                for idx, text_batch, label_batch in testbatches:
                    testtext += text_batch.transpose(0, 1).tolist()

            # test the original CVDD model on the test set:
            self.test(dataset, net)
            self.testtrain(dataset, net)
            print("detson from there on, we start unsuprisk training")

            n_attention_heads = net.n_attention_heads
            logger.info('Starting training unsup...')
            n_batches = 0
            net.eval()
            uttembeds, scores = [], []
            with torch.no_grad():
                for data in [ltext[i] for i in trainidx]:
                    data = torch.tensor(data)
                    text_batch = data.unsqueeze(0).to(self.device)
                    # we want (utt len, batch)
                    text_batch = text_batch.transpose(0, 1)

                    # forward pass
                    cosine_dists, context_weights, A = net(text_batch)
                    uttembeds.append(net.M.detach().numpy())
                    tsc = torch.mean(cosine_dists, dim=1)
                    scores.append(tsc.item())
                    n_batches += 1

            print("text_batch", text_batch.size())
            print("cosine_dists", cosine_dists.size())
            xscores = torch.tensor(scores).to(self.device)
            print("xscores", xscores.size())
            # xscores = (batch,)
            print(xscores[0:3])
            uttemb = torch.tensor(uttembeds).to(self.device)
            print("uttembb", uttemb.size())
            # uttemb = (batch, 1, 3, 300)

            if False:
                # save training + test corpus
                with open("traindata." + str(glob.nc) + "." + str(glob.nheads),
                          "wb") as devfile:
                    pickle.dump(uttemb.detach().cpu().numpy(), devfile,
                                pickle.HIGHEST_PROTOCOL)
                with open("coslin." + str(glob.nc) + "." + str(glob.nheads),
                          "ab") as devfile:
                    pickle.dump(glob.nheads, devfile, pickle.HIGHEST_PROTOCOL)

            lossfct = UnsupRisk(glob.p0, self.device)
            detparms = [l.parameters() for l in self.coslin]
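            # chain.from_iterable flattens the per-head parameter generators in
            # detparms into a single iterable of parameters for Adam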
            optimizerRisk = optim.Adam(chain.from_iterable(detparms),
                                       lr=glob.lr,
                                       weight_decay=self.weight_decay)
            print("starting unsup epochs %d" % (glob.nep, ))
            for epoch in range(glob.nep):
                self.unsupepoch = epoch
                optimizerRisk.zero_grad()
                for l in self.coslin:
                    l.train()
                # uttemb holds the utterance embeddings, shape (n_utterances, 1, n_heads, hidden_size)
                cospred = []
                cosmean = torch.zeros((uttemb.size(0), ))
                assert net.c.size(0) == 1  # net.c keeps a leading singleton dim (for broadcasting over the batch)
                for i in range(glob.nheads):
                    mm = uttemb.transpose(0, 2)[i].squeeze()
                    conehead = self.coslin[i](mm)
                    cospred.append(conehead)
                    cosmean += conehead
                cosmean /= float(glob.nheads)

                loss = lossfct(cosmean)
                if not (float('-inf') < float(loss.item()) < float('inf')):
                    print("WARNING %f at unsup epoch %d" %
                          (loss.item(), epoch))
                    # nan or inf
                    continue
                loss.backward()
                optimizerRisk.step()

                if False:
                    # save the weights at every epoch
                    detparms = [l.parameters() for l in self.coslin]
                    with open(
                            "coslin." + str(glob.nc) + "." + str(glob.nheads),
                            "ab") as devfile:
                        for pr in detparms:
                            for p in pr:
                                pickle.dump(p.detach().cpu().numpy(), devfile,
                                            pickle.HIGHEST_PROTOCOL)

                if True:
                    # compute unsup loss on DEV
                    uttembeds = []
                    with torch.no_grad():
                        for data in [ltext[i] for i in devidx]:
                            data = torch.tensor(data)
                            text_batch = data.unsqueeze(0).to(self.device)
                            text_batch = text_batch.transpose(0, 1)
                            cosine_dists, context_weights, A = net(text_batch)
                            selfattemb = net.M.detach().numpy()
                            # optional multiplicative noise on the embeddings
                            # (the 0.0 amplitude means it is disabled here)
                            noise = 0.0 * (np.random.rand(*selfattemb.shape) -
                                           0.5)
                            noise2 = np.multiply(selfattemb, noise)
                            selfattemb += noise2
                            uttembeds.append(selfattemb)
                    uttemb = torch.tensor(uttembeds).to(self.device)
                    for l in self.coslin:
                        l.eval()
                    cospred = []
                    cosmean = torch.zeros((uttemb.size(0), ))
                    for i in range(glob.nheads):
                        mm = uttemb.transpose(0, 2)[i].squeeze()
                        conehead = self.coslin[i](mm)
                        cospred.append(conehead)
                        cosmean += conehead
                    cosmean /= float(glob.nheads)
                    devloss = lossfct(cosmean)
                    if not (float('-inf') < float(devloss.item()) <
                            float('inf')):
                        print("WARNING %f at unsup epoch DEV %d" %
                              (devloss.item(), epoch))
                        # nan or inf

                if True:
                    # compute unsup loss on TEST
                    uttembeds = []
                    with torch.no_grad():
                        for data in testtext:
                            data = torch.tensor(data)
                            text_batch = data.unsqueeze(0).to(self.device)
                            text_batch = text_batch.transpose(0, 1)
                            cosine_dists, context_weights, A = net(text_batch)
                            uttembeds.append(net.M.detach().numpy())
                    uttemb = torch.tensor(uttembeds).to(self.device)
                    for l in self.coslin:
                        l.eval()
                    cospred = []
                    cosmean = torch.zeros((uttemb.size(0), ))
                    for i in range(glob.nheads):
                        mm = uttemb.transpose(0, 2)[i].squeeze()
                        conehead = self.coslin[i](mm)
                        cospred.append(conehead)
                        cosmean += conehead
                    cosmean /= float(glob.nheads)
                    testloss = lossfct(cosmean)
                    if not (float('-inf') < float(testloss.item()) <
                            float('inf')):
                        print("WARNING %f at unsup epoch TEST %d" %
                              (testloss.item(), epoch))
                        # nan or inf

                print(
                    "unsuprisk epoch %d trainloss %f devloss %f testloss %f" %
                    (epoch, loss.item(), devloss.item(), testloss.item()))
                self.test(dataset, net)
                self.testtrain(dataset, net)

        return net
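
For orientation, here is a minimal, runnable sketch of the per-head scoring and averaging step used in the unsuprisk stage above. It uses random tensors with an assumed shape (n_utterances, 1, n_heads, hidden_size) for the stacked utterance embeddings and torch.nn.Linear as a stand-in for the project-specific cosinelin.Cosinelin layer; the UnsupRisk loss itself is not reproduced.

import torch
import torch.nn as nn

# Assumed, illustrative shapes: 8 utterances, 3 attention heads, 300-dim embeddings.
n_utts, n_heads, hidden = 8, 3, 300
uttemb = torch.rand(n_utts, 1, n_heads, hidden)  # stand-in for the stacked net.M outputs

# one scorer per attention head; nn.Linear stands in for cosinelin.Cosinelin
coslin = [nn.Linear(hidden, 1) for _ in range(n_heads)]

cosmean = torch.zeros(n_utts)
for i in range(n_heads):
    # transpose(0, 2) gives (n_heads, 1, n_utts, hidden); selecting head i and
    # squeezing the singleton dim yields (n_utts, hidden), as in the loop above
    mm = uttemb.transpose(0, 2)[i].squeeze()
    cosmean += coslin[i](mm).squeeze(-1)  # one score per utterance for head i
cosmean /= float(n_heads)
print(cosmean.shape)  # torch.Size([8])

The averaged per-head scores (cosmean) are what the UnsupRisk loss consumes in the code above.
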
Example #2
    def train(self, dataset: BaseADDataset, net: CVDDNet):
        logger = logging.getLogger()

        # Set device for network
        net = net.to(self.device)

        # Get number of attention heads
        n_attention_heads = net.n_attention_heads

        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size,
                                          num_workers=self.n_jobs_dataloader)

        # Initialize context vectors
        net.c.data = torch.from_numpy(
            initialize_context_vectors(net, train_loader,
                                       self.device)[np.newaxis, :]).to(self.device)

        # Set parameters and optimizer (Adam optimizer for now)
        parameters = filter(lambda p: p.requires_grad, net.parameters())
        optimizer = optim.Adam(parameters,
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        # Set learning rate scheduler
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=self.lr_milestones, gamma=0.1)

        # Training
        logger.info('Starting training...')
        start_time = time.time()
        net.train()
        alpha_i = 0
        for epoch in range(self.n_epochs):

            scheduler.step()
            if epoch in self.lr_milestones:
                logger.info('  LR scheduler: new learning rate is %g' %
                            float(scheduler.get_lr()[0]))

            if epoch in self.alpha_milestones:
                net.alpha = float(self.alphas[alpha_i])
                logger.info('  Temperature alpha scheduler: new alpha is %g' %
                            net.alpha)
                alpha_i += 1

            epoch_loss = 0.0
            n_batches = 0
            att_matrix = np.zeros((n_attention_heads, n_attention_heads))
            dists_per_head = ()
            epoch_start_time = time.time()
            for data in train_loader:
                _, text_batch, _, _ = data
                text_batch = text_batch.to(self.device)
                # text_batch.shape = (sentence_length, batch_size)

                # Zero the network parameter gradients
                optimizer.zero_grad()

                # Update network parameters via backpropagation: forward + backward + optimize

                # forward pass
                cosine_dists, context_weights, A = net(text_batch)
                scores = context_weights * cosine_dists
                # scores.shape = (batch_size, n_attention_heads)
                # A.shape = (batch_size, n_attention_heads, sentence_length)

                # orthogonality penalty: P = mean((CCT - I)^2), with CCT = net.c @ net.c^T
                I = torch.eye(n_attention_heads).to(self.device)
                CCT = net.c @ net.c.transpose(1, 2)
                P = torch.mean((CCT.squeeze() - I)**2)

                # compute loss
                loss_P = self.lambda_p * P
                loss_emp = torch.mean(torch.sum(scores, dim=1))
                loss = loss_emp + loss_P

                # Get scores
                dists_per_head += (cosine_dists.cpu().data.numpy(), )

                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net.parameters(),
                    0.5)  # clip the total gradient norm to at most 0.5
                optimizer.step()

                # Get attention matrix
                AAT = A @ A.transpose(1, 2)
                att_matrix += torch.mean(AAT, 0).cpu().data.numpy()

                epoch_loss += loss.item()
                n_batches += 1

            # log epoch statistics
            epoch_train_time = time.time() - epoch_start_time
            logger.info(
                f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
                f'| Train Loss: {epoch_loss / n_batches:.6f} |')

            # Save distances per attention head and attention matrix
            self.train_dists = np.concatenate(dists_per_head)
            self.train_att_matrix = att_matrix / n_batches
            self.train_att_matrix = self.train_att_matrix.tolist()

        self.train_time = time.time() - start_time

        # Get context vectors
        self.c = np.squeeze(net.c.cpu().data.numpy())
        self.c = self.c.tolist()

        # Get top words per context
        self.train_top_words = get_top_words_per_context(
            dataset.train_set, dataset.encoder, net, train_loader, self.device)

        # Log results
        logger.info('Training Time: {:.3f}s'.format(self.train_time))
        logger.info('Finished training.')

        return net
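
For reference, here is a minimal, runnable sketch of the loss computed inside both training loops above. The tensors below are random stand-ins with assumed, illustrative shapes (batch_size = 4, n_attention_heads = 3, embedding_dim = 300) for the outputs of net(text_batch) and for net.c, and lambda_p is set to 1.0 purely for illustration.

import torch

batch_size, n_heads, dim = 4, 3, 300
lambda_p = 1.0

cosine_dists = torch.rand(batch_size, n_heads)          # stand-in for the network's cosine distances
context_weights = torch.softmax(torch.rand(batch_size, n_heads), dim=1)  # stand-in for the attention-derived weights
c = torch.rand(1, n_heads, dim)                         # stand-in for net.c

# empirical term: batch mean of the context-weighted cosine distances summed over heads
scores = context_weights * cosine_dists
loss_emp = torch.mean(torch.sum(scores, dim=1))

# orthogonality penalty: P = mean((CCT - I)^2), pushing context vectors towards orthonormality
I = torch.eye(n_heads)
CCT = c @ c.transpose(1, 2)
P = torch.mean((CCT.squeeze() - I) ** 2)

loss = loss_emp + lambda_p * P
print(loss.item())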