コード例 #1
0
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Train a contrastive (MoCo-style) model for one epoch.

    Each batch provides two augmented views of the same images; the model
    returns contrastive logits and targets, and top-1/top-5 accuracy is
    measured on the (K+1)-way contrast classification.

    Args:
        train_loader: iterable of ``(images, _)`` where ``images`` is a pair
            of augmented views.
        model: network returning ``(output, target)`` from ``im_q``/``im_k``.
        criterion: loss applied to the contrastive logits.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (for display only).
        args: namespace providing ``gpu`` and ``print_freq``.

    Returns:
        The mean training loss over the epoch.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # stliu: new progress design — report estimated whole-epoch time
    # instead of per-batch/data-loading timings.
    epoch_time = AverageMeter('Epoch Time', ':6.3f')
    progress = ProgressMeter(len(train_loader),
                             [epoch_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()

    for i, (images, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images[0] = images[0].cuda(args.gpu, non_blocking=True)
            images[1] = images[1].cuda(args.gpu, non_blocking=True)

        # compute output: contrastive logits and their targets
        output, target = model(im_q=images[0], im_k=images[1])
        loss = criterion(output, target)

        # acc1/acc5 are (K+1)-way contrast classifier accuracy
        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images[0].size(0))
        top1.update(acc1[0], images[0].size(0))
        top5.update(acc5[0], images[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time; epoch_time extrapolates the running batch
        # average to a full-epoch estimate
        batch_time.update(time.time() - end)
        epoch_time.update(batch_time.avg * len(train_loader))
        end = time.time()

        if (i + 1) % args.print_freq == 0:  # stliu: change i to i+1
            progress.display(i)

    return losses.avg
コード例 #2
0
def train(model, dataset, optimizer, criterion, epoch, args, data_start_index):
    """Run one training epoch (or epoch slice) and return the next start index.

    When ``args.epoch_max_len`` is set, only a window of the training split
    is consumed per call; the returned index marks where the next call should
    resume (wrapping to 0 once the split is exhausted).
    """
    model.train()
    # Reshuffle only at the start of a full pass over the training split.
    if data_start_index == 0:
        dataset.shuffle('train', seed=epoch + args.seed)

    if args.epoch_max_len is None:
        loader = dataset.loader('train', num_workers=args.num_workers)
    else:
        split_size = len(dataset.splits['train'])
        data_end_index = min(data_start_index + args.epoch_max_len, split_size)
        loader = dataset.loader(
            'train',
            num_workers=args.num_workers,
            indices=list(range(data_start_index, data_end_index)))
        # Resume from the end of this window next call; 0 means start over.
        data_start_index = 0 if data_end_index >= split_size else data_end_index

    loss_meter = AverageMeter('loss', ':6.4f')
    total_length = len(loader)
    progress = ProgressMeter(total_length, [loss_meter], prefix='Training: ')

    for step, raw_batch in enumerate(tqdm(loader, total=total_length)):
        moved = [t.to(args.device) for t in raw_batch]
        (inputs, lengths, future_words, log_probs, labels,
         classification_targets, syllables_to_go, future_word_num_syllables,
         rhyme_group_index) = moved
        # Skip ragged final batches for these tasks — it'll screw up the bias...?
        if (args.task not in ['formality', 'iambic'] and not args.debug
                and len(inputs) != args.batch_size):
            continue
        scores = model(inputs,
                       lengths,
                       future_words,
                       log_probs,
                       syllables_to_go,
                       future_word_num_syllables,
                       rhyme_group_index,
                       run_classifier=True)
        if args.task == 'formality':
            # Learning for all positions at once; scores are batch x seq.
            expanded_labels = classification_targets.unsqueeze(1).expand(
                -1, scores.shape[1])  # batch x seq
            length_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
            keep = length_mask.flatten() == 1
            loss = criterion(scores.flatten()[keep],
                             expanded_labels.flatten().float()[keep])
        elif args.task in ['iambic', 'newline']:
            # Positions labelled -1 carry no supervision.
            valid = classification_targets.flatten() != -1
            loss = criterion(scores.flatten()[valid],
                             classification_targets.flatten().float()[valid])
        else:  # topic, rhyme
            loss = criterion(scores.flatten(), labels.flatten().float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.detach(), len(labels))
        if step % args.train_print_freq == 0:
            progress.display(step)
    progress.display(total_length)
    return data_start_index
コード例 #3
0
def validate(model, dataset, criterion, epoch, args):
    """Evaluate the model on the validation split and return the mean loss."""
    model.eval()
    random.seed(0)  # fixed seed so the validation pass is reproducible
    loader = dataset.loader('val', num_workers=args.num_workers)
    loss_meter = AverageMeter('loss', ':6.4f')
    total_length = len(loader)
    progress = ProgressMeter(total_length, [loss_meter], prefix='Validation: ')
    with torch.no_grad():
        for step, raw_batch in enumerate(tqdm(loader, total=total_length)):
            moved = [t.to(args.device) for t in raw_batch]
            (inputs, lengths, future_words, log_probs, labels,
             classification_targets, syllables_to_go,
             future_word_num_syllables, rhyme_group_index) = moved
            # Topic predictor: skip ragged final batches.
            if (args.task not in ['formality', 'iambic'] and not args.debug
                    and len(inputs) != args.batch_size):
                continue
            scores = model(inputs,
                           lengths,
                           future_words,
                           log_probs,
                           syllables_to_go,
                           future_word_num_syllables,
                           rhyme_group_index,
                           run_classifier=True)
            if args.task == 'formality':
                # Learning for all positions at once; scores are batch x seq.
                expanded_labels = classification_targets.unsqueeze(1).expand(
                    -1, scores.shape[1])  # batch x seq
                length_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
                keep = length_mask.flatten() == 1
                loss = criterion(scores.flatten()[keep],
                                 expanded_labels.flatten().float()[keep])
            elif args.task == 'intent':
                # All positions at once; scores are batch x seq x 4.
                expanded_labels = classification_targets.unsqueeze(1).expand(
                    -1, scores.shape[1], -1)  # batch x seq x 4
                expanded_labels = expanded_labels.contiguous().view(
                    -1, 4)  # batch*seq x 4
                loss = criterion(scores.contiguous().view(-1, 4),
                                 expanded_labels.float())
            elif args.task in ['iambic', 'newline']:
                # Positions labelled -1 carry no supervision.
                valid = classification_targets.flatten() != -1
                loss = criterion(
                    scores.flatten()[valid],
                    classification_targets.flatten().float()[valid])
            else:  # topic, rhyme
                loss = criterion(scores.flatten(), labels.flatten().float())
            loss_meter.update(loss.detach(), len(labels))
            if step % args.train_print_freq == 0:
                progress.display(step)
    progress.display(total_length)
    return loss_meter.avg
コード例 #4
0
def train(model, dataset, optimizer, criterion, epoch, args, data_start_index):
    """Run one training epoch and return the next data start index.

    Two modes, selected by ``args.iw``:

    * importance-weighting mode: each batch pairs source and target
      sequences; the model is trained as a per-position discriminator with
      labels 0 (source), 1 (target) and -1 (padding, excluded from the
      loss).  NOTE(review): padding assumes a max sequence length of 100 —
      confirm against the dataset.
    * standard mode: per-task classifier loop, optionally consuming only an
      ``args.epoch_max_len``-sized window of the train split starting at
      ``data_start_index`` (the returned index resumes from there, wrapping
      to 0 when the split is exhausted).
    """
    model.train()
    if args.iw:
        # Half the batch comes from each side, hence batch_size // 2.
        loader = DataLoader(dataset, batch_size=args.batch_size // 2)
        loss_meter = AverageMeter('loss', ':6.4f')
        acc_meter = AverageMeter('acc', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter, acc_meter],
                                 prefix='Training: ')

        for batch_num, ((x_source, source_lengths),
                        (x_target, target_lengths)) in enumerate(
                            tqdm(iter(loader), total=len(loader),
                                 leave=False)):
            # Concatenate source and target into one device batch.
            x = torch.cat([x_source, x_target]).to(args.device).long()
            # Per-position labels: 0 for real source tokens, -1 for padding
            # (positions past each sequence's length, up to 100).
            y_source = torch.cat([
                torch.cat([
                    torch.zeros(source_lengths[i]),
                    -torch.ones(100 - source_lengths[i])
                ]).unsqueeze(0) for i in range(source_lengths.shape[0])
            ],
                                 dim=0)
            # Same construction for targets, labelled 1 where real.
            y_target = torch.cat([
                torch.cat([
                    torch.ones(target_lengths[i]),
                    -torch.ones(100 - target_lengths[i])
                ]).unsqueeze(0) for i in range(target_lengths.shape[0])
            ],
                                 dim=0)
            y = torch.cat([y_source, y_target]).float().to(args.device)

            lengths = torch.cat([source_lengths,
                                 target_lengths]).squeeze().to(args.device)
            scores = model(x, lengths)
            # Exclude padding positions (-1) from loss and accuracy.
            use_indices = y.flatten() != -1
            loss = criterion(scores.flatten()[use_indices],
                             y.flatten().float()[use_indices])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Binary predictions from logits; accuracy over valid positions.
            preds = torch.round(torch.sigmoid(scores.flatten()[use_indices]))
            accs = sum(preds == y.flatten()
                       [use_indices]) / y.flatten()[use_indices].shape[0]
            acc_meter.update(accs.detach(), y.shape[0])
            loss_meter.update(loss.detach(), y.shape[0])
            if batch_num % args.train_print_freq == 0:
                progress.display(batch_num)
        progress.display(total_length)
    else:
        # Reshuffle only when starting a fresh pass over the split.
        if data_start_index == 0:
            dataset.shuffle('train', seed=epoch + args.seed)
        if args.epoch_max_len is not None:
            # Consume a bounded window; wrap the start index to 0 once the
            # split is exhausted.
            data_end_index = min(data_start_index + args.epoch_max_len,
                                 len(dataset.splits['train']))
            loader = dataset.loader('train',
                                    num_workers=args.num_workers,
                                    indices=list(
                                        range(data_start_index,
                                              data_end_index)))
            data_start_index = data_end_index if data_end_index < len(
                dataset.splits['train']) else 0
        else:
            loader = dataset.loader('train', num_workers=args.num_workers)
        loss_meter = AverageMeter('loss', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter],
                                 prefix='Training: ')
        for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
            batch = [tensor.to(args.device) for tensor in batch]
            inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
            if args.task not in ['formality', 'iambic']:
                # Skip ragged final batches for these tasks.
                if not args.debug and len(
                        inputs
                ) != args.batch_size:  # it'll screw up the bias...?
                    continue
            scores = model(inputs,
                           lengths,
                           future_words,
                           log_probs,
                           syllables_to_go,
                           future_word_num_syllables,
                           rhyme_group_index,
                           run_classifier=True)
            if args.task == 'formality':  # we're learning for all positions at once. scores are batch x seq
                expanded_labels = classification_targets.unsqueeze(1).expand(
                    -1, scores.shape[1])  # batch x seq
                length_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
                loss = criterion(
                    scores.flatten()[length_mask.flatten() == 1],
                    expanded_labels.flatten().float()[length_mask.flatten() ==
                                                      1])
            elif args.task in ['iambic', 'newline']:
                # -1 marks positions without supervision.
                use_indices = classification_targets.flatten() != -1
                loss = criterion(
                    scores.flatten()[use_indices],
                    classification_targets.flatten().float()[use_indices])
            else:  # topic, rhyme
                loss = criterion(scores.flatten(), labels.flatten().float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.detach(), len(labels))
            if batch_num % args.train_print_freq == 0:
                progress.display(batch_num)
        progress.display(total_length)
    return data_start_index
コード例 #5
0
def validate(model, dataset, criterion, epoch, args):
    """Evaluate the model and return the mean validation loss.

    Two modes, selected by ``args.iw``:

    * importance-weighting mode: evaluates the per-position discriminator
      on a held-out split, with labels 0 (source), 1 (target) and -1
      (padding, excluded).  NOTE(review): padding assumes a max sequence
      length of 100 — confirm against the dataset.
    * standard mode: per-task loss on ``dataset.loader('val')``.
    """
    model.eval()
    random.seed(0)  # fixed seed so the validation pass is reproducible
    if args.iw:
        # Half the batch from each side, hence batch_size // 2.
        loader = DataLoader(SplitDataset(args, split='test'),
                            batch_size=args.batch_size // 2)
        loss_meter = AverageMeter('loss', ':6.4f')
        acc_meter = AverageMeter('acc', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter, acc_meter],
                                 prefix='Validation: ')
        with torch.no_grad():
            for batch_num, ((x_source, source_lengths),
                            (x_target, target_lengths)) in enumerate(
                                tqdm(iter(loader),
                                     total=len(loader),
                                     leave=False)):
                # Concatenate source and target into one device batch.
                x = torch.cat([x_source, x_target]).to(args.device).long()
                # Per-position labels: 0 for real source tokens, -1 padding.
                y_source = torch.cat([
                    torch.cat([
                        torch.zeros(source_lengths[i]),
                        -torch.ones(100 - source_lengths[i])
                    ]).unsqueeze(0) for i in range(source_lengths.shape[0])
                ],
                                     dim=0)
                # Same for targets, labelled 1 where real.
                y_target = torch.cat([
                    torch.cat([
                        torch.ones(target_lengths[i]),
                        -torch.ones(100 - target_lengths[i])
                    ]).unsqueeze(0) for i in range(target_lengths.shape[0])
                ],
                                     dim=0)
                y = torch.cat([y_source, y_target]).float().to(args.device)

                lengths = torch.cat([source_lengths,
                                     target_lengths]).squeeze().to(args.device)
                scores = model(x, lengths)
                # Exclude padding positions (-1) from loss and accuracy.
                use_indices = y.flatten() != -1
                loss = criterion(scores.flatten()[use_indices],
                                 y.flatten().float()[use_indices])
                preds = torch.round(
                    torch.sigmoid(scores.flatten()[use_indices]))

                accs = sum(preds == y.flatten()
                           [use_indices]) / y.flatten()[use_indices].shape[0]
                acc_meter.update(accs.detach(), y.shape[0])
                loss_meter.update(loss.detach(), y.shape[0])
                if batch_num % args.train_print_freq == 0:
                    progress.display(batch_num)
        progress.display(total_length)
        return loss_meter.avg

    else:
        loader = dataset.loader('val', num_workers=args.num_workers)
        loss_meter = AverageMeter('loss', ':6.4f')
        total_length = len(loader)
        progress = ProgressMeter(total_length, [loss_meter],
                                 prefix='Validation: ')
        with torch.no_grad():
            for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
                batch = [tensor.to(args.device) for tensor in batch]
                inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
                if args.task not in ['formality', 'iambic']:  # topic predictor
                    # Skip ragged final batches for these tasks.
                    if not args.debug and len(inputs) != args.batch_size:
                        continue
                scores = model(inputs,
                               lengths,
                               future_words,
                               log_probs,
                               syllables_to_go,
                               future_word_num_syllables,
                               rhyme_group_index,
                               run_classifier=True)
                if args.task == 'formality':  # we're learning for all positions at once. scores are batch x seq
                    expanded_labels = classification_targets.unsqueeze(
                        1).expand(-1, scores.shape[1])  # batch x seq
                    length_mask = pad_mask(lengths).permute(1,
                                                            0)  # batch x seq
                    loss = criterion(
                        scores.flatten()[length_mask.flatten() == 1],
                        expanded_labels.flatten().float()[
                            length_mask.flatten() == 1])
                elif args.task in ['iambic', 'newline']:
                    # -1 marks positions without supervision.
                    use_indices = classification_targets.flatten() != -1
                    loss = criterion(
                        scores.flatten()[use_indices],
                        classification_targets.flatten().float()[use_indices])
                else:  # topic, rhyme
                    loss = criterion(scores.flatten(),
                                     labels.flatten().float())
                loss_meter.update(loss.detach(), len(labels))
                if batch_num % args.train_print_freq == 0:
                    progress.display(batch_num)
        progress.display(total_length)
        return loss_meter.avg