예제 #1
0
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one epoch of contrastive (query/key two-view) training.

    Returns the running-average loss over the epoch.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    # Track a projection of total epoch wall-clock time instead of the
    # per-batch meters in the progress display.
    epoch_time = AverageMeter('Epoch Time', ':6.3f')
    progress = ProgressMeter(len(train_loader),
                             [epoch_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # Enable training-mode layers (dropout, batch-norm updates).
    model.train()

    end = time.time()
    for step, (images, _) in enumerate(train_loader):
        # Time spent waiting on the data loader.
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images[0] = images[0].cuda(args.gpu, non_blocking=True)
            images[1] = images[1].cuda(args.gpu, non_blocking=True)

        # Two augmented views: query (im_q) and key (im_k).
        output, target = model(im_q=images[0], im_k=images[1])
        loss = criterion(output, target)

        # acc1/acc5 are (K+1)-way contrast classifier accuracies.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = images[0].size(0)
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)

        # Standard SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Per-batch elapsed time, plus a projection for the full epoch.
        batch_time.update(time.time() - end)
        epoch_time.update(batch_time.avg * len(train_loader))
        end = time.time()

        if (step + 1) % args.print_freq == 0:
            progress.display(step)

    return losses.avg
예제 #2
0
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Train ``model`` for one epoch of supervised classification.

    Prints progress every ``args.print_freq`` batches and writes the same
    message to ``args.log_file`` every ``args.log_freq`` batches.

    Returns the running-average loss over the epoch.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    param_groups = optimizer.param_groups[0]
    curr_lr = param_groups["lr"]

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.to(args.device, non_blocking=True)
        if torch.cuda.is_available():
            target = target.to(args.device, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # BUGFIX: the message is rebuilt whenever either output channel fires.
        # Previously it was built only in the print branch, so the log branch
        # could write a stale message from an earlier iteration.
        if i % args.print_freq == 0 or i % args.log_freq == 0:
            epoch_msg = progress.get_message(i)
            epoch_msg += ("\tLr  {:.4f}".format(curr_lr))
            if i % args.print_freq == 0:
                print(epoch_msg)
            if i % args.log_freq == 0:
                args.log_file.write(epoch_msg + "\n")

    # Return the mean loss, consistent with the sibling train() above.
    return losses.avg
예제 #3
0
def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` on the validation set.

    Prints per-batch progress and a final summary line (also written to
    ``args.log_file``). Returns the mean top-1 accuracy.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5],
                             prefix='Test: ')

    # Freeze dropout / batch-norm statistics.
    model.eval()

    with torch.no_grad():
        tic = time.time()
        for step, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.to(args.device, non_blocking=True)
            if torch.cuda.is_available():
                target = target.to(args.device, non_blocking=True)

            # Forward pass and loss.
            output = model(images)
            loss = criterion(output, target)

            # Record loss and top-1/top-5 accuracy for this batch.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            top1.update(acc1[0], n)
            top5.update(acc5[0], n)

            # Per-batch wall-clock time.
            batch_time.update(time.time() - tic)
            tic = time.time()

            if step % args.print_freq == 0:
                print(progress.get_message(step))

        epoch_msg = '----------- Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} -----------'.format(
            top1=top1, top5=top5)
        print(epoch_msg)
        args.log_file.write(epoch_msg + "\n")

    return top1.avg
예제 #4
0
def train_finetune(train_loader, model, criterion, optimizer, epoch, args,
                   scaler):
    """Fine-tune a pruned model for one epoch with mixed precision.

    After each backward pass, gradients of conv weights that are exactly
    zero (i.e. pruned) are masked out so the pruned connections stay
    inactive through the optimizer step.  ``scaler`` is an AMP GradScaler
    paired with the ``autocast()`` context below.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    end = time.time()

    # True on the single process responsible for output in distributed runs.
    # NOTE(review): currently unused in this function — confirm whether
    # progress display was meant to be gated on it.
    is_main = not args.multiprocessing_distributed or (
        args.multiprocessing_distributed
        and args.rank % torch.cuda.device_count() == 0)

    for i, (images, target) in enumerate(train_loader):
        optimizer.zero_grad()

        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)

        with autocast():
            # compute output
            output = model(images)
            loss = criterion(output, target)

        scaler.scale(loss).backward()

        # Zero out gradients of pruned (zero-valued) conv weights so they
        # remain pruned after the step.  Masking the *scaled* gradients is
        # safe: multiplying by a 0/1 mask commutes with unscaling.
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                weight_copy = m.weight.data.abs().clone()
                mask = weight_copy.gt(0).float().cuda()
                m.weight.grad.data.mul_(mask)

        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
예제 #5
0
def validate(net, epoch, data_loader, args):
    """Evaluate ``net`` on ``data_loader`` for one pass.

    Progress is printed (and logged) every ``args.disp_iter`` batches and
    on the final batch.  Returns the mean top-1 accuracy.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4f')
    top1 = AverageMeter('Accuracy', ':4.2f')
    progress = ProgressMeter(
        len(data_loader),
        [batch_time, data_time, losses, top1],
        prefix="Epoch (Valid LR {:6.4f}): [{}] ".format(0, epoch))

    # Evaluation mode: no dropout, frozen batch-norm statistics.
    net.eval()

    with torch.no_grad():
        start = time.time()
        num_batches = len(data_loader)
        for step, (data, target) in enumerate(data_loader):
            data = data.to(args.device, non_blocking=True)
            target = target.to(args.device, non_blocking=True)

            # Time spent fetching this batch.
            data_time.update(time.time() - start)

            output = net(data)
            loss = F.cross_entropy(output, target)

            # Record loss and accuracy weighted by batch size.
            acc = accuracy(output, target)
            n = data.size(0)
            losses.update(loss.item(), n)
            top1.update(acc[0].item(), n)

            batch_time.update(time.time() - start)
            start = time.time()

            done = step + 1
            if done % args.disp_iter == 0 or done == num_batches:
                epoch_msg = progress.get_message(done)
                print(epoch_msg)
                args.log_file.write(epoch_msg + "\n")

        print('-------- Mean Accuracy {top1.avg:.3f} --------'.format(top1=top1))

    return top1.avg
예제 #6
0
def train(net, optimizer, epoch, data_loader, args):
    """Run one supervised training epoch over ``data_loader``.

    Progress is printed (and logged) every ``args.disp_iter`` batches and
    on the final batch.
    """
    learning_rate = optimizer.param_groups[0]["lr"]

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4f')
    top1 = AverageMeter('Accuracy', ':4.2f')
    progress = ProgressMeter(
        len(data_loader),
        [batch_time, data_time, losses, top1],
        prefix="Epoch (Train LR {:6.4f}): [{}] ".format(learning_rate, epoch))

    # Training mode: dropout active, batch-norm statistics updated.
    net.train()

    start = time.time()
    num_batches = len(data_loader)
    for step, (data, target) in enumerate(data_loader):
        data = data.to(args.device, non_blocking=True)
        target = target.to(args.device, non_blocking=True)

        # Time spent fetching this batch.
        data_time.update(time.time() - start)

        # Forward, backward, and parameter update.
        optimizer.zero_grad()
        output = net(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        # Record loss and accuracy weighted by batch size.
        acc = accuracy(output, target)
        n = data.size(0)
        losses.update(loss.item(), n)
        top1.update(acc[0].item(), n)

        batch_time.update(time.time() - start)
        start = time.time()

        done = step + 1
        if done % args.disp_iter == 0 or done == num_batches:
            epoch_msg = progress.get_message(done)
            print(epoch_msg)
            args.log_file.write(epoch_msg + "\n")
def train(model, dataset, optimizer, criterion, epoch, args, data_start_index):
    """Train the future-word/attribute classifier for one (possibly capped) epoch.

    When ``args.epoch_max_len`` is set, each call consumes at most that many
    examples of the train split and resumes from ``data_start_index``.

    Returns the index at which the next call should resume (0 once the
    split has been exhausted).
    """
    model.train()
    # Reshuffle only when starting a fresh pass over the training split.
    if data_start_index == 0:
        dataset.shuffle('train', seed=epoch + args.seed)
    if args.epoch_max_len is not None:
        # Cap this epoch and remember where to resume next time.
        data_end_index = min(data_start_index + args.epoch_max_len,
                             len(dataset.splits['train']))
        loader = dataset.loader('train',
                                num_workers=args.num_workers,
                                indices=list(
                                    range(data_start_index, data_end_index)))
        data_start_index = data_end_index if data_end_index < len(
            dataset.splits['train']) else 0
    else:
        loader = dataset.loader('train', num_workers=args.num_workers)
    loss_meter = AverageMeter('loss', ':6.4f')
    total_length = len(loader)
    progress = ProgressMeter(total_length, [loss_meter], prefix='Training: ')
    for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
        batch = [tensor.to(args.device) for tensor in batch]
        inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
        if args.task not in ['formality', 'iambic']:
            # Drop ragged final batches for these tasks.
            if not args.debug and len(
                    inputs) != args.batch_size:  # it'll screw up the bias...?
                continue
        scores = model(inputs,
                       lengths,
                       future_words,
                       log_probs,
                       syllables_to_go,
                       future_word_num_syllables,
                       rhyme_group_index,
                       run_classifier=True)
        if args.task == 'formality':  # we're learning for all positions at once. scores are batch x seq
            expanded_labels = classification_targets.unsqueeze(1).expand(
                -1, scores.shape[1])  # batch x seq
            # Mask out padding positions when computing the loss.
            length_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
            loss = criterion(
                scores.flatten()[length_mask.flatten() == 1],
                expanded_labels.flatten().float()[length_mask.flatten() == 1])
        elif args.task in ['iambic', 'newline']:
            # Positions labelled -1 are ignored.
            use_indices = classification_targets.flatten() != -1
            loss = criterion(
                scores.flatten()[use_indices],
                classification_targets.flatten().float()[use_indices])
        else:  # topic, rhyme
            loss = criterion(scores.flatten(), labels.flatten().float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.detach(), len(labels))
        if batch_num % args.train_print_freq == 0:
            progress.display(batch_num)
    progress.display(total_length)
    return data_start_index
예제 #8
0
def validate(val_loader, model, criterion, args):
    """Evaluate a pruned model on the validation set.

    The pruning mask is applied before evaluation; only the main process
    prints the summary line in distributed runs. Returns mean top-1 accuracy.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5],
                             prefix='Test: ')

    # Evaluation mode and single-process output gating.
    model.eval()
    is_main = not args.multiprocessing_distributed or (
        args.multiprocessing_distributed
        and args.rank % torch.cuda.device_count() == 0)

    # Zero out pruned weights before measuring accuracy.
    prune_by_mask(model)

    with torch.no_grad():
        tic = time.time()
        for images, target in val_loader:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # Forward pass and loss.
            output = model(images)
            loss = criterion(output, target)

            # Record loss and top-1/top-5 accuracy.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            top1.update(acc1[0], n)
            top5.update(acc5[0], n)

            # Per-batch wall-clock time.
            batch_time.update(time.time() - tic)
            tic = time.time()

        if is_main:
            print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
                top1=top1, top5=top5))

    return top1.avg
def validate(model, dataset, criterion, epoch, args):
    """Evaluate the future-word/attribute classifier on the validation split.

    Returns the mean validation loss.
    """
    model.eval()
    # Fixed seed so validation batches/ordering are reproducible across epochs.
    random.seed(0)
    loader = dataset.loader('val', num_workers=args.num_workers)
    loss_meter = AverageMeter('loss', ':6.4f')
    total_length = len(loader)
    progress = ProgressMeter(total_length, [loss_meter], prefix='Validation: ')
    with torch.no_grad():
        for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
            batch = [tensor.to(args.device) for tensor in batch]
            inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
            if args.task not in ['formality', 'iambic']:  # topic predictor
                # Drop ragged final batches for these tasks.
                if not args.debug and len(inputs) != args.batch_size:
                    continue
            scores = model(inputs,
                           lengths,
                           future_words,
                           log_probs,
                           syllables_to_go,
                           future_word_num_syllables,
                           rhyme_group_index,
                           run_classifier=True)
            if args.task == 'formality':  # we're learning for all positions at once. scores are batch x seq
                expanded_labels = classification_targets.unsqueeze(1).expand(
                    -1, scores.shape[1])  # batch x seq
                # Mask out padding positions when computing the loss.
                length_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
                loss = criterion(
                    scores.flatten()[length_mask.flatten() == 1],
                    expanded_labels.flatten().float()[length_mask.flatten() ==
                                                      1])
            elif args.task == 'intent':  # we're learning for all positions at once. scores are batch x seq x 4
                expanded_labels = classification_targets.unsqueeze(1).expand(
                    -1, scores.shape[1], -1)  # batch x seq x 4
                expanded_labels = expanded_labels.contiguous().view(
                    -1, 4)  # batch*seq x 4
                scores = scores.contiguous().view(-1, 4)
                loss = criterion(scores, expanded_labels.float())
            elif args.task in ['iambic', 'newline']:
                # Positions labelled -1 are ignored.
                use_indices = classification_targets.flatten() != -1
                loss = criterion(
                    scores.flatten()[use_indices],
                    classification_targets.flatten().float()[use_indices])
            else:  # topic, rhyme
                loss = criterion(scores.flatten(), labels.flatten().float())
            loss_meter.update(loss.detach(), len(labels))
            if batch_num % args.train_print_freq == 0:
                progress.display(batch_num)
    progress.display(total_length)
    return loss_meter.avg
예제 #10
0
def train(train_loader, model, criterion, optimizer, epoch, args, scaler):
    """One epoch of sparse training with a gradually tightened pruning mask.

    Each iteration: snapshot the dense weights, (periodically) update the
    pruning mask toward the target sparsity, prune, run the AMP
    forward/backward on the pruned model, restore the dense weights, then
    apply the optimizer step — so gradients computed on the sparse model
    update the dense copy. Uses the module-level ``iterations`` counter.
    """
    global iterations

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    end = time.time()

    # True on the single process responsible for output in distributed runs.
    # NOTE(review): currently unused in this function — confirm intent.
    is_main = not args.multiprocessing_distributed or (
        args.multiprocessing_distributed
        and args.rank % torch.cuda.device_count() == 0)

    for i, (images, target) in enumerate(train_loader):
        iterations += 1

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)

        # Snapshot of the dense (unpruned) weights, restored after backward.
        dense_model = copy.deepcopy(model.state_dict())

        if iterations % 16 == 0 and (epoch + 1) <= args.milestones[1]:
            # Update the pruning mask: cubic sparsity schedule that ramps up
            # to args.prune_rate by milestone args.milestones[1].
            target_sparsity = args.prune_rate - args.prune_rate * \
                (1 - iterations / (args.milestones[1] * len(train_loader)))**3
            update_mask(model, target_sparsity)
        prune_by_mask(model)

        optimizer.zero_grad()
        # measure data loading time
        data_time.update(time.time() - end)

        with autocast():
            # compute output
            output = model(images)
            loss = criterion(output, target)

        scaler.scale(loss).backward()

        # Restore the full dense conv weights before the optimizer step.
        for k, m in model.named_modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.copy_(dense_model[k + '.weight'])

        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
예제 #11
0
def train(model, dataset, optimizer, criterion, epoch, args, data_start_index):
    """Train for one epoch, in one of two modes selected by ``args.iw``.

    With ``args.iw`` set, trains a per-token source/target domain
    discriminator on concatenated source+target batches.  Otherwise trains
    the future-word/attribute classifier over the train split, optionally
    capped at ``args.epoch_max_len`` examples per call.

    Returns the index at which the next call should resume reading the
    train split (0 once exhausted; unchanged in the ``iw`` path).
    """
    model.train()
    if args.iw:
        # Half the batch from each domain; source and target are concatenated.
        loader = DataLoader(dataset, batch_size=args.batch_size // 2)
        loss_meter = AverageMeter('loss', ':6.4f')
        acc_meter = AverageMeter('acc', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter, acc_meter],
                                 prefix='Training: ')

        for batch_num, ((x_source, source_lengths),
                        (x_target, target_lengths)) in enumerate(
                            tqdm(iter(loader), total=len(loader),
                                 leave=False)):
            x = torch.cat([x_source, x_target]).to(args.device).long()
            # Per-position labels: 0 for source tokens, 1 for target tokens,
            # -1 for padding beyond each sequence's length.
            # NOTE(review): assumes a fixed max length of 100 — confirm.
            y_source = torch.cat([
                torch.cat([
                    torch.zeros(source_lengths[i]),
                    -torch.ones(100 - source_lengths[i])
                ]).unsqueeze(0) for i in range(source_lengths.shape[0])
            ],
                                 dim=0)
            y_target = torch.cat([
                torch.cat([
                    torch.ones(target_lengths[i]),
                    -torch.ones(100 - target_lengths[i])
                ]).unsqueeze(0) for i in range(target_lengths.shape[0])
            ],
                                 dim=0)
            y = torch.cat([y_source, y_target]).float().to(args.device)

            lengths = torch.cat([source_lengths,
                                 target_lengths]).squeeze().to(args.device)
            scores = model(x, lengths)
            # Ignore padded positions (label -1) in the loss and accuracy.
            use_indices = y.flatten() != -1
            loss = criterion(scores.flatten()[use_indices],
                             y.flatten().float()[use_indices])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Binary accuracy from thresholded sigmoid scores.
            preds = torch.round(torch.sigmoid(scores.flatten()[use_indices]))
            accs = sum(preds == y.flatten()
                       [use_indices]) / y.flatten()[use_indices].shape[0]
            acc_meter.update(accs.detach(), y.shape[0])
            loss_meter.update(loss.detach(), y.shape[0])
            if batch_num % args.train_print_freq == 0:
                progress.display(batch_num)
        progress.display(total_length)
    else:
        # Reshuffle only when starting a fresh pass over the training split.
        if data_start_index == 0:
            dataset.shuffle('train', seed=epoch + args.seed)
        if args.epoch_max_len is not None:
            # Cap this epoch and remember where to resume next time.
            data_end_index = min(data_start_index + args.epoch_max_len,
                                 len(dataset.splits['train']))
            loader = dataset.loader('train',
                                    num_workers=args.num_workers,
                                    indices=list(
                                        range(data_start_index,
                                              data_end_index)))
            data_start_index = data_end_index if data_end_index < len(
                dataset.splits['train']) else 0
        else:
            loader = dataset.loader('train', num_workers=args.num_workers)
        loss_meter = AverageMeter('loss', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter],
                                 prefix='Training: ')
        for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
            batch = [tensor.to(args.device) for tensor in batch]
            inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
            if args.task not in ['formality', 'iambic']:
                # Drop ragged final batches for these tasks.
                if not args.debug and len(
                        inputs
                ) != args.batch_size:  # it'll screw up the bias...?
                    continue
            scores = model(inputs,
                           lengths,
                           future_words,
                           log_probs,
                           syllables_to_go,
                           future_word_num_syllables,
                           rhyme_group_index,
                           run_classifier=True)
            if args.task == 'formality':  # we're learning for all positions at once. scores are batch x seq
                expanded_labels = classification_targets.unsqueeze(1).expand(
                    -1, scores.shape[1])  # batch x seq
                # Mask out padding positions when computing the loss.
                length_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
                loss = criterion(
                    scores.flatten()[length_mask.flatten() == 1],
                    expanded_labels.flatten().float()[length_mask.flatten() ==
                                                      1])
            elif args.task in ['iambic', 'newline']:
                # Positions labelled -1 are ignored.
                use_indices = classification_targets.flatten() != -1
                loss = criterion(
                    scores.flatten()[use_indices],
                    classification_targets.flatten().float()[use_indices])
            else:  # topic, rhyme
                loss = criterion(scores.flatten(), labels.flatten().float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.detach(), len(labels))
            if batch_num % args.train_print_freq == 0:
                progress.display(batch_num)
        progress.display(total_length)
    return data_start_index
예제 #12
0
def validate(model, dataset, criterion, epoch, args):
    """Evaluate the model, in one of two modes selected by ``args.iw``.

    With ``args.iw`` set, evaluates the per-token source/target domain
    discriminator on the test split; otherwise evaluates the
    future-word/attribute classifier on the validation split.

    Returns the mean loss in either mode.
    """
    model.eval()
    # Fixed seed so evaluation is reproducible across epochs.
    random.seed(0)
    if args.iw:
        loader = DataLoader(SplitDataset(args, split='test'),
                            batch_size=args.batch_size // 2)
        loss_meter = AverageMeter('loss', ':6.4f')
        acc_meter = AverageMeter('acc', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter, acc_meter],
                                 prefix='Validation: ')
        with torch.no_grad():
            for batch_num, ((x_source, source_lengths),
                            (x_target, target_lengths)) in enumerate(
                                tqdm(iter(loader),
                                     total=len(loader),
                                     leave=False)):
                x = torch.cat([x_source, x_target]).to(args.device).long()
                # Per-position labels: 0 for source tokens, 1 for target
                # tokens, -1 for padding beyond each sequence's length.
                # NOTE(review): assumes a fixed max length of 100 — confirm.
                y_source = torch.cat([
                    torch.cat([
                        torch.zeros(source_lengths[i]),
                        -torch.ones(100 - source_lengths[i])
                    ]).unsqueeze(0) for i in range(source_lengths.shape[0])
                ],
                                     dim=0)
                y_target = torch.cat([
                    torch.cat([
                        torch.ones(target_lengths[i]),
                        -torch.ones(100 - target_lengths[i])
                    ]).unsqueeze(0) for i in range(target_lengths.shape[0])
                ],
                                     dim=0)
                y = torch.cat([y_source, y_target]).float().to(args.device)

                lengths = torch.cat([source_lengths,
                                     target_lengths]).squeeze().to(args.device)
                scores = model(x, lengths)
                # Ignore padded positions (label -1) in loss and accuracy.
                use_indices = y.flatten() != -1
                loss = criterion(scores.flatten()[use_indices],
                                 y.flatten().float()[use_indices])
                preds = torch.round(
                    torch.sigmoid(scores.flatten()[use_indices]))

                accs = sum(preds == y.flatten()
                           [use_indices]) / y.flatten()[use_indices].shape[0]
                acc_meter.update(accs.detach(), y.shape[0])
                loss_meter.update(loss.detach(), y.shape[0])
                if batch_num % args.train_print_freq == 0:
                    progress.display(batch_num)
        progress.display(total_length)
        return loss_meter.avg

    else:
        loader = dataset.loader('val', num_workers=args.num_workers)
        loss_meter = AverageMeter('loss', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter],
                                 prefix='Validation: ')
        with torch.no_grad():
            for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
                batch = [tensor.to(args.device) for tensor in batch]
                inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
                if args.task not in ['formality', 'iambic']:  # topic predictor
                    # Drop ragged final batches for these tasks.
                    if not args.debug and len(inputs) != args.batch_size:
                        continue
                scores = model(inputs,
                               lengths,
                               future_words,
                               log_probs,
                               syllables_to_go,
                               future_word_num_syllables,
                               rhyme_group_index,
                               run_classifier=True)
                if args.task == 'formality':  # we're learning for all positions at once. scores are batch x seq
                    expanded_labels = classification_targets.unsqueeze(
                        1).expand(-1, scores.shape[1])  # batch x seq
                    # Mask out padding positions when computing the loss.
                    length_mask = pad_mask(lengths).permute(1,
                                                            0)  # batch x seq
                    loss = criterion(
                        scores.flatten()[length_mask.flatten() == 1],
                        expanded_labels.flatten().float()[
                            length_mask.flatten() == 1])
                elif args.task in ['iambic', 'newline']:
                    # Positions labelled -1 are ignored.
                    use_indices = classification_targets.flatten() != -1
                    loss = criterion(
                        scores.flatten()[use_indices],
                        classification_targets.flatten().float()[use_indices])
                else:  # topic, rhyme
                    loss = criterion(scores.flatten(),
                                     labels.flatten().float())
                loss_meter.update(loss.detach(), len(labels))
                if batch_num % args.train_print_freq == 0:
                    progress.display(batch_num)
        progress.display(total_length)
        return loss_meter.avg