Code example #1
    def val():
        """Evaluate the model on the validation set.

        Swaps the language model's embeddings for the validation
        vocabulary, then accumulates loss and accuracy over every
        validation batch under ``torch.no_grad``.

        Returns:
            tuple: (average validation loss, populated AccuracyMeter).
        """
        loss_tracker = AverageMeter()
        accuracy_tracker = AccuracyMeter()
        progress = tqdm(total=len(val_loader))
        student.languageModel.replace_embeddings(val_vocab)
        student.eval()

        with torch.no_grad():
            for batch in val_loader:
                (loss, alphas, logits, labels,
                 batch_size, all_labels, _) = compute_loss(batch)
                # Add the self-attention regularization term.
                loss = loss + compute_self_att_loss(alphas)
                accuracy_tracker.update(logits,
                                        all_labels,
                                        cuda=args.cuda,
                                        vision=True)
                loss_tracker.update(loss.data.item(), batch_size)
                progress.update()
        progress.close()
        print('====> Validation set loss: {:.4f}'.format(loss_tracker.avg))
        print('Training Objective ({}) Validation Accuracies:'.format(
            kwargs['train_obj']))
        accuracy_tracker.print()
        return loss_tracker.avg, accuracy_tracker
Code example #2
    def val():
        """Evaluate the model on the validation set (reference game).

        Runs each validation batch through ``compute_loss`` with
        gradients disabled and aggregates loss / accuracy statistics.

        Returns:
            tuple: (average validation loss, populated AccuracyMeter).
        """
        loss_tracker = AverageMeter()
        accuracy_tracker = AccuracyMeter()
        progress = tqdm(total=len(val_loader))
        student.eval()

        with torch.no_grad():
            for batch in val_loader:
                loss, alphas, logits, batch_size = compute_loss(batch)
                # Add the self-attention regularization term.
                loss = loss + compute_self_att_loss(alphas)

                # Update progress
                accuracy_tracker.update(logits,
                                        batch_size,
                                        refGame=True,
                                        cuda=args.cuda)
                loss_tracker.update(loss.data.item(), batch_size)
                progress.update()
        progress.close()
        print('====> Validation set loss: {:.4f}'.format(loss_tracker.avg))
        print('Validation Accuracies:')
        accuracy_tracker.print(True)
        return loss_tracker.avg, accuracy_tracker
Code example #3
    def val():
        """Evaluate the model on the validation set.

        Re-initializes the loader's epoch state, then accumulates loss
        and accuracy over every validation batch with gradients off.

        Returns:
            tuple: (average validation loss, populated AccuracyMeter).
        """
        loss_tracker = AverageMeter()
        accuracy_tracker = AccuracyMeter()
        progress = tqdm(total=len(val_loader))
        val_loader.init_epoch()
        student.eval()

        with torch.no_grad():
            for batch in val_loader:
                loss, alphas, logits = compute_loss(batch)
                # Add the self-attention regularization term.
                loss = loss + compute_self_att_loss(alphas)
                accuracy_tracker.update(logits, batch)
                loss_tracker.update(loss.data.item(), batch.batch_size)
                progress.update()
        progress.close()
        print('====> Validation set loss: {:.4f}'.format(loss_tracker.avg))
        print('Training Objective ({}) Validation Accuracies:'.format(kwargs['train_obj']))
        accuracy_tracker.print()
        return loss_tracker.avg, accuracy_tracker
Code example #4
    def train(epoch=-1, backprop=True):
        """Run the model over the training set for a single epoch.

        Args:
            epoch: epoch index, used only in log messages.
            backprop: when True, gradients are applied via the optimizer;
                when False the model runs in eval mode with no updates.

        Returns:
            float: average training loss for the epoch.
        """
        # Data loading & progress visualization
        loss_tracker = AverageMeter()
        accuracy_tracker = AccuracyMeter()
        progress = tqdm(total=len(train_loader))

        student.languageModel.replace_embeddings(train_vocab)
        if backprop:
            student.train()
        else:
            student.eval()

        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()
            (loss, alphas, logits, labels,
             batch_size, all_labels, _) = compute_loss(batch)
            # Add the self-attention regularization term.
            loss = loss + compute_self_att_loss(alphas)

            if backprop:
                loss.backward()
                optimizer.step()

            # Update progress
            accuracy_tracker.update(logits, all_labels, cuda=args.cuda, vision=True)
            loss_tracker.update(loss.data.item(), batch_size)
            if step % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, step * batch_size, len(train_loader.dataset),
                    100. * step / len(train_loader), loss_tracker.avg))
            progress.update()
        progress.close()
        print('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, loss_tracker.avg))
        print('Training Objective ({}) Train Accuracies:'.format(
            kwargs['train_obj']))
        accuracy_tracker.print()
        return loss_tracker.avg
Code example #5
    def train(epoch=-1, backprop=True):
        """Run the model over the training set for a single epoch.

        Args:
            epoch: epoch index, used only in log messages.
            backprop: when True, gradients are applied via the optimizer;
                when False the model runs in eval mode with no updates.

        Returns:
            float: average training loss for the epoch.
        """
        # Data loading & progress visualization
        loss_tracker = AverageMeter()
        accuracy_tracker = AccuracyMeter()
        progress = tqdm(total=len(train_loader))
        train_loader.init_epoch()

        if backprop:
            student.train()
        else:
            student.eval()

        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()
            n_examples = batch.batch_size
            loss, alphas, logits = compute_loss(batch)
            # Add the self-attention regularization term.
            loss = loss + compute_self_att_loss(alphas)

            if backprop:
                loss.backward()
                optimizer.step()

            # Update progress
            accuracy_tracker.update(logits, batch, refGame=True, cuda=args.cuda)
            loss_tracker.update(loss.data.item(), n_examples)
            if step % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, step * batch.batch_size, len(train_loader.dataset),
                    100. * step / len(train_loader), loss_tracker.avg))
            progress.update()
        progress.close()
        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, loss_tracker.avg))
        print('Train Accuracies:')
        accuracy_tracker.print(True)
        return loss_tracker.avg
Code example #6
    def run_epoch(task, concept_loader, ref_loader, train, epoch, backprop):
        """ Run model through 1 epoch.

        Iterates the concept and reference loaders in lockstep; if the
        reference loader is longer, the remaining reference batches are
        drained in a trailing loop.

        Args:
            task: one of 'ref', 'concept', 'multi' — which objective(s)
                contribute to the loss this epoch.
            concept_loader: iterable of concept-learning batches.
            ref_loader: iterable of reference-game batches.
            train: bool; only used to label the printed accuracy header.
            epoch: epoch index, used only in log messages.
            backprop: when True, weights are updated after each batch.

        Returns:
            tuple: (avg joint loss, avg concept loss, avg ref loss,
                    concept AccuracyMeter, ref AccuracyMeter).
        """
        if train:
            epoch_type = 'Training'
        else:
            epoch_type = 'Validation'

        # Loss & Accuracy Meters
        concept_loss_meter = AverageMeter()
        concept_acc_meter = AccuracyMeter()
        ref_loss_meter = AverageMeter()
        ref_acc_meter = AccuracyMeter()
        loss_meter = AverageMeter()

        # Data loader init + visualization
        if task == 'ref':
            pbar = tqdm(total=len(ref_loader))
        elif task == 'concept':
            pbar = tqdm(total=len(concept_loader))
        else:
            pbar = tqdm(total=len(ref_loader) + len(concept_loader))

        # Set model mode
        if backprop:
            student.train()
        else:
            student.eval()

        # Train model on two objectives
        ref_batch_idx = -1
        ref_iterator = iter(ref_loader)
        ref_batches_exist = True
        for concept_batch_idx, concept_batch in enumerate(concept_loader):
            try:
                ref_batch = next(ref_iterator)
                ref_batch_idx += 1
            except StopIteration:
                ref_batches_exist = False
                if task == 'ref':
                    break

            # Compute loss + backprop
            optimizer.zero_grad()
            if task == 'ref':
                ref_loss, ref_alphas, ref_logits, ref_batch_size = compute_ref_loss(ref_batch)
                loss = ref_loss + compute_self_att_loss(ref_alphas)
            elif task == 'concept':
                concept_loss, concept_alphas, concept_logits, concept_labels, concept_batch_size, concept_all_labels, _ = compute_concept_loss(concept_batch)
                loss = concept_loss + compute_self_att_loss(concept_alphas)
            else:
                ref_loss, ref_alphas, ref_logits, ref_batch_size = compute_ref_loss(ref_batch)
                # BUGFIX: the first element of the tuple must be bound to
                # concept_loss (it was previously bound to `loss`, leaving
                # `concept_loss` undefined for the weighted sum below).
                concept_loss, concept_alphas, concept_logits, concept_labels, concept_batch_size, concept_all_labels, _ = compute_concept_loss(concept_batch)
                loss = concept_loss * kwargs['concept_loss_weight'] + ref_loss * (1.0 - kwargs['concept_loss_weight'])
                loss += compute_self_att_loss(concept_alphas) + compute_self_att_loss(ref_alphas)

            if backprop:
                loss.backward()
                optimizer.step()

            # Update concept learning progress
            if task == 'concept' or task == 'multi':
                concept_acc_meter.update(concept_logits, concept_batch, refGame=False, cuda=kwargs['cuda'])
                concept_loss_meter.update(concept_loss.data.item(), concept_batch_size)
                if concept_batch_idx % args.log_interval == 0:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tConcept Loss: {:.6f}'.format(
                        epoch, concept_batch_idx * concept_batch_size, len(concept_loader),
                        100. * concept_batch_idx / len(concept_loader), concept_loss_meter.avg))
                pbar.update()

            # Reference game learning progress
            if ref_batches_exist and (task == 'ref' or task == 'multi'):
                ref_acc_meter.update(ref_logits, ref_batch, refGame=True, cuda=kwargs['cuda'])
                ref_loss_meter.update(ref_loss.data.item(), ref_batch_size)
                if ref_batch_idx % args.log_interval == 0:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tReference Loss: {:.6f}'.format(
                        epoch, ref_batch_idx * ref_batch_size, len(ref_loader),
                        100. * ref_batch_idx / len(ref_loader), ref_loss_meter.avg))
                pbar.update()

            # Joint learning progress
            if task == 'multi':
                if ref_batches_exist:
                    loss_meter.update(loss.data.item(), ref_batch_size + concept_batch_size)
                else:
                    loss_meter.update(loss.data.item(), concept_batch_size)
                if (ref_batch_idx + concept_batch_idx) % args.log_interval == 0:
                    print('Train Epoch: {}\tJoint Loss: {:.6f}'.format(
                        epoch, loss_meter.avg))

        # Train model on reference objective, if there outstanding batches
        if task == 'ref' or task == 'multi':
            while (ref_batches_exist):
                try:
                    ref_batch = next(ref_iterator)
                    ref_batch_idx += 1
                except StopIteration:
                    break

                # Compute loss + backprop
                optimizer.zero_grad()
                ref_loss, ref_alphas, ref_logits, ref_batch_size = compute_ref_loss(ref_batch)
                # BUGFIX: rebuild the loss from this batch instead of
                # accumulating into the previous iteration's `loss`, whose
                # autograd graph was already freed by backward().
                # Match the main loop's per-task formula: unweighted for
                # 'ref', weighted for 'multi'.
                if task == 'ref':
                    loss = ref_loss + compute_self_att_loss(ref_alphas)
                else:
                    loss = ref_loss * (1.0 - kwargs['concept_loss_weight']) + compute_self_att_loss(ref_alphas)
                if backprop:
                    loss.backward()
                    optimizer.step()

                # Reference game learning progress
                ref_acc_meter.update(ref_logits, ref_batch, refGame=True, cuda=kwargs['cuda'])
                ref_loss_meter.update(ref_loss.data.item(), ref_batch_size)
                if ref_batch_idx % args.log_interval == 0:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tReference Loss: {:.6f}'.format(
                        epoch, ref_batch_idx * ref_batch_size, len(ref_loader),
                        100. * ref_batch_idx / len(ref_loader), ref_loss_meter.avg))
                pbar.update()

        pbar.close()
        if task == 'concept' or task == 'multi':
            print('====> Epoch: {} Average Concept loss: {:.4f}'.format(epoch, concept_loss_meter.avg))
        if task == 'ref' or task == 'multi':
            print('====> Epoch: {} Average Reference loss: {:.4f}'.format(epoch, ref_loss_meter.avg))
        if task == 'multi':
            print('====> Epoch: {} Average Joint loss: {:.4f}'.format(epoch, loss_meter.avg))
        if task != 'ref':
            print('Concept Learning Objective ({}) Concept {} Accuracies:'.format(epoch_type, kwargs['concept_train_obj']))
            concept_acc_meter.print()
        if task != 'concept':
            print('Reference {} Accuracies:'.format(epoch_type))
            ref_acc_meter.print(ground_truth_only=True)
        return loss_meter.avg, concept_loss_meter.avg, ref_loss_meter.avg, concept_acc_meter, ref_acc_meter