# Example #1
def validate(val_loader, model, criterion, device, is_test):
    """Run one evaluation pass and return (avg loss, avg top-1, avg top-5)."""
    if is_test:
        prefix = "Test: "
    else:
        prefix = "Validation: "

    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(len(val_loader), [losses, top1, top5], prefix=prefix)

    # evaluation mode: disables dropout, uses running BN statistics
    model.eval()

    with torch.no_grad():
        for batch_idx, (images, target) in enumerate(val_loader):
            if torch.cuda.is_available():
                images = images.to(device)
                target = target.to(device)

            # forward pass and loss
            output = model(images)
            loss = criterion(output, target)

            # accumulate loss and top-k accuracy, weighted by batch size
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            batch_size = images.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(acc1[0].item(), batch_size)
            top5.update(acc5[0].item(), batch_size)

            if batch_idx % 100 == 0:
                progress.display(batch_idx)

    return losses.avg, top1.avg, top5.avg
# Example #2
def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result=None):
    """Train a prototypical-contrastive (PCL-style) model for one epoch.

    The total loss is the instance-level InfoNCE loss plus, when
    `cluster_result` is provided, a ProtoNCE loss averaged over all
    prototype sets in args.num_cluster.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc_inst = AverageMeter('Acc@Inst', ':6.2f')
    acc_proto = AverageMeter('Acc@Proto', ':6.2f')

    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, acc_inst, acc_proto],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, index) in tqdm(enumerate(train_loader)):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            # images holds two augmented views (query / key)
            images[0] = images[0].cuda(args.gpu, non_blocking=True)
            images[1] = images[1].cuda(args.gpu, non_blocking=True)

        # compute output
        output, target, output_proto, target_proto = model(im_q=images[0], im_k=images[1],
                                                           cluster_result=cluster_result, index=index)

        # InfoNCE loss
        loss = criterion(output, target)

        # ProtoNCE loss (only when clustering results are available)
        if output_proto is not None:
            loss_proto = 0
            for proto_out, proto_target in zip(output_proto, target_proto):
                loss_proto += criterion(proto_out, proto_target)
                accp = accuracy(proto_out, proto_target)[0]
                # NOTE(review): accp[0] is passed without .item() —
                # presumably the meter tolerates tensors; confirm.
                acc_proto.update(accp[0], images[0].size(0))

            # average loss across all sets of prototypes
            loss_proto /= len(args.num_cluster)
            loss += loss_proto

        losses.update(loss.item(), images[0].size(0))
        acc = accuracy(output, target)[0]
        acc_inst.update(acc[0], images[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(args, epoch, loader, model, optimizer, writer):
    """Train a PixPro / PixContrast model for one epoch.

    Each batch carries two augmented views; the model is run symmetrically
    (view0 as base / view1 as momentum target, then swapped) and the two
    directional losses are summed (pixpro) or averaged (pixcontrast).

    Args:
        args: namespace with at least gpu, loss, print_freq.
        epoch: current epoch index (for logging).
        loader: yields (two-view images, correspondence targets).
        model: returns (projection of first input, momentum feats of second).
        optimizer: stepped once per kept batch.
        writer: summary writer, written on rank 0 only.

    Raises:
        ValueError: if args.loss is neither 'pixpro' nor 'pixcontrast'.
    """
    model.train()
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    lr = AverageMeter('Lr', ':.3f')

    progress = ProgressMeter(
        len(loader),
        [lr, batch_time, losses],
        prefix='Epoch: [{}]'.format(epoch))

    end = time.time()

    for _iter, (images, targets) in enumerate(loader):
        images[0], images[1] = images[0].cuda(args.gpu, non_blocking=True), images[1].cuda(args.gpu, non_blocking=True)

        # swap the image: run both (base, moment) orderings for a symmetric loss
        yi, xj_moment = model(images[0], images[1])
        yj, xi_moment = model(images[1], images[0])

        if args.loss == 'pixpro':
            base_A_matrix, moment_A_matrix = targets[0].cuda(args.gpu), targets[1].cuda(args.gpu)
            pixpro_loss = PixproLoss(args)
            overall_loss = pixpro_loss(yi, xj_moment, base_A_matrix) + pixpro_loss(yj, xi_moment, moment_A_matrix)

        elif args.loss == 'pixcontrast':
            base_A_matrix, moment_A_matrix = targets[0][0].cuda(args.gpu), targets[0][1].cuda(args.gpu)
            base_inter_mask, moment_inter_mask = targets[1][0].cuda(args.gpu), targets[1][1].cuda(args.gpu)

            pixcontrast_loss = PixContrastLoss(args)
            overall_loss = (pixcontrast_loss(yi, xj_moment, base_A_matrix, base_inter_mask)
                            + pixcontrast_loss(yj, xi_moment, moment_A_matrix, moment_inter_mask)) / 2
        else:
            # BUG FIX: the exception was previously constructed but never
            # raised, so an invalid loss type surfaced later as a confusing
            # UnboundLocalError on `base_A_matrix` / `overall_loss`.
            raise ValueError('HAVE TO SELECT PROPER LOSS TYPE')

        # if there is no intersection between the two views, skip the update
        if torch.max(base_A_matrix) < 1 and torch.max(moment_A_matrix) < 1:
            continue

        losses.update(overall_loss.item(), images[0].size(0))
        # NOTE(review): records the lr of the *last* param group only —
        # confirm all groups share a single lr.
        for param_group in optimizer.param_groups:
            cur_lr = param_group['lr']
        lr.update(cur_lr)
        optimizer.zero_grad()
        overall_loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if (_iter % args.print_freq == 0) and (args.gpu==0):
            progress.display(_iter)
            writer.add_scalar('Loss', overall_loss.item(), (epoch*len(loader))+_iter)
            writer.add_scalar('lr', cur_lr, (epoch*len(loader))+_iter)
# Example #4
def train_kd(train_loader, teacher, model, criterion, optimizer, epoch, args):
    """Train `model` for one epoch with knowledge distillation from `teacher`.

    The criterion takes (student logits, teacher logits, target); the teacher
    forward pass runs under torch.no_grad so only the student is updated.
    The `idx` element from the loader is currently unused here.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target, idx) in enumerate(train_loader):
        # grid_img = torchvision.utils.make_grid(images)
        # imshow(grid_img)
        # time.sleep(100)

        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        with torch.no_grad():
            # teacher is frozen: no gradients tracked for its forward pass
            o_teacher = teacher(images)
            #o_teacher = gaussian_noise(o_teacher, mean=0, stddev=0.5, alpha=0.4)
        # 0.1, 0.4 76.640
        # 0.3, 0.4 76.630
        # 0.5, 0.4 76.632
        # o_teacher = torch.from_numpy(o_teacher_label_train[idx]).cuda()

        loss = criterion(output, o_teacher, target)
        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.detach().item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
# Example #5
def train(train_loader, model, criterion, optimizer, epoch, args,
          lr_scheduler):
    """Train `model` for one epoch, stepping `lr_scheduler` every batch.

    Args:
        train_loader: yields (images, target) batches.
        model: classifier producing logits.
        criterion: loss on (logits, target).
        optimizer: stepped once per batch.
        epoch: epoch index used in the progress prefix.
        args: namespace with at least gpu and print_freq.
        lr_scheduler: per-iteration scheduler, stepped after optimizer.step().
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader), [
        batch_time,
        data_time,
        losses,
        top1,
        top5,
    ],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        lr_scheduler.step()

        if i % args.print_freq == 0:
            progress.display(i)
        if i % 1000 == 0:
            # BUG FIX: get_lr() is deprecated and returns incorrect values
            # when called outside scheduler.step(); get_last_lr() is the
            # documented accessor for the most recently computed lr.
            print('cur lr: ', lr_scheduler.get_last_lr()[0])
# Example #6
def train_prune(train_loader, model, criterion, optimizer, epoch, zero_weight,
                zero_grad, args):
    """Train a pruned model for one epoch while preserving its sparsity mask.

    `zero_weight` is applied before each forward pass to re-zero pruned
    weights, and `zero_grad` is applied after backward (before the optimizer
    step) to zero their gradients, so pruned weights never receive updates.
    """
    # NOTE(review): this scheduler is created but never stepped in this
    # function — confirm whether per-iteration lr scheduling was intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target, idx) in enumerate(train_loader):
        # re-apply the pruning mask to the weights
        model.apply(zero_weight)
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.detach().item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()

        # zero the gradients of pruned weights before stepping
        model.apply(zero_grad)

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
# Example #7
    def train_epoch(self):
        """Run one training epoch over self.train_loader.

        Increments self.epoch, accumulates loss/acc/time meters per batch,
        clips gradients to args.max_grad_norm, and steps both the optimizer
        and the lr scheduler every batch.

        NOTE(review): reads a module-level `args` (device, max_grad_norm,
        log_interval) rather than self — confirm it is defined where this
        class is used.
        """
        self.model.train()
        self.epoch += 1

        # record training statistics
        avg_meters = {
            'loss': AverageMeter('Loss', ':.4e'),
            'acc': AverageMeter('Acc', ':6.2f'),
            'time': AverageMeter('Time', ':6.3f')
        }
        progress_meter = ProgressMeter(
            len(self.train_loader),
            avg_meters.values(),
            prefix="Epoch: [{}]".format(self.epoch)
        )

        # begin training from minibatches
        for ix, data in enumerate(self.train_loader):
            start_time = time.time()

            # move (input_ids, attention_mask, labels) to the target device
            input_ids, attention_mask, labels = map(
                lambda x: x.to(args.device), data
            )
            logits = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            loss = self.criterion(logits, labels)
            # batch accuracy as a fraction in [0, 1]
            acc = (logits.argmax(axis=1) == labels).float().mean().item()

            self.optimizer.zero_grad()
            loss.backward()
            # clip gradients to stabilize transformer fine-tuning
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                args.max_grad_norm
            )
            self.optimizer.step()
            self.scheduler.step()

            avg_meters['loss'].update(loss.item(), input_ids.size(0))
            avg_meters['acc'].update(acc * 100, input_ids.size(0))
            avg_meters['time'].update(time.time() - start_time)

            # log progress
            if (ix + 1) % args.log_interval == 0:
                progress_meter.display(ix + 1)

        # final display at end of epoch
        progress_meter.display(len(self.train_loader))
# Example #8
def validate(val_loader, model, criterion, args):
    """Evaluate `model` on `val_loader`; return the average top-1 accuracy."""
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    model.eval()  # inference mode: dropout off, BN uses running stats

    with torch.no_grad():
        tic = time.time()
        for step, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # forward pass and loss
            output = model(images)
            loss = criterion(output, target)

            # record loss and top-k accuracy, weighted by batch size
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            top1.update(acc1[0], n)
            top5.update(acc5[0], n)

            # wall time for this batch
            batch_time.update(time.time() - tic)
            tic = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg
# Example #9
def train(train_loader, model, criterion, optimizer, epoch, device, print_freq):
    """One training epoch: forward, loss, backward, and an SGD step per batch."""
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    model.train()  # training mode: dropout on, BN uses batch statistics

    tic = time.time()
    for step, (images, target) in enumerate(train_loader):
        # time spent waiting on the data loader
        data_time.update(time.time() - tic)

        images = images.to(device)
        target = target.to(device)

        # forward pass and loss
        output = model(images)
        loss = criterion(output, target)

        # record loss and top-k accuracy, weighted by batch size
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        n = images.size(0)
        losses.update(loss.item(), n)
        top1.update(acc1[0], n)
        top5.update(acc5[0], n)

        # gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # full wall time for this batch
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % print_freq == 0:
            progress.display(step)
# Example #10
def train(train_loader, model, criterion, optimizer, epoch, cfg, logger):
    """Train `model` for one epoch, recording stats on the shared `logger`.

    The current lr is read from the first param group for the progress
    prefix only; logger.save is called once per batch with the batch and
    epoch indices.
    """
    curr_lr = optimizer.param_groups[0]["lr"]
    progress = ProgressMeter(
        len(train_loader),
        [logger.time, logger.loss, logger.acc1, logger.acc5],
        prefix="Epoch: [{}/{}]\t"
        "LR: {}\t".format(epoch, cfg.epochs, curr_lr),
    )

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # (no data-loading meter here; `end` marks the start of the batch)

        if cfg.gpu is not None:
            images = images.cuda(cfg.gpu, non_blocking=True)
        target = target.cuda(cfg.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time and update logger meters
        logger.time.update(time.time() - end)
        logger.loss.update(loss.item(), images.size(0))
        logger.acc1.update(acc1[0].item(), images.size(0))
        logger.acc5.update(acc5[0].item(), images.size(0))
        logger.save(batch=i, epoch=epoch)

        end = time.time()

        if i % cfg.print_freq == 0:
            progress.display(i)
def validate(val_loader, model, criterion, args):
    """Evaluate `model`, optionally blurring inputs first.

    When args.blur_val is set, each batch is passed through GaussianBlurAll
    with args.sigma before the forward pass.

    Returns:
        (avg loss, avg top-1 accuracy, avg top-5 accuracy).
    """
    batch_time = AverageMeter("Time", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5],
                             prefix="Test: ")

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # blur images (evaluation-time corruption, controlled by args)
            if args.blur_val:
                images = GaussianBlurAll(images, args.sigma)
            if torch.cuda.is_available():
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(" * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1,
                                                                    top5=top5))

    return losses.avg, top1.avg, top5.avg
# Example #12
def validate(val_loader, model, criterion, cfg, epoch, logger):
    """Evaluate `model` on `val_loader`; return the average top-1 accuracy.

    Per-batch time/loss/accuracy are accumulated on the shared `logger`
    meters; the logger state is saved once after the loop.
    """
    progress = ProgressMeter(
        len(val_loader),
        [logger.time, logger.loss, logger.acc1, logger.acc5],
        prefix="Test: ",
    )

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if cfg.gpu is not None:
                images = images.cuda(cfg.gpu, non_blocking=True)
            target = target.cuda(cfg.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy / elapsed time and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # BUG FIX: logger.time was previously updated twice per batch
            # (identical call repeated), doubling its sample count and
            # skewing the reported average batch time.
            logger.time.update(time.time() - end)
            logger.loss.update(loss.item(), images.size(0))
            logger.acc1.update(acc1[0].item(), images.size(0))
            logger.acc5.update(acc5[0].item(), images.size(0))
            end = time.time()
            if i % cfg.print_freq == 0:
                progress.display(i)
        # NOTE(review): `i` is the last loop index; this raises NameError if
        # val_loader is empty — confirm it never is.
        logger.save(batch=i, epoch=epoch)

        # TODO: this should also be done with the ProgressMeter
        print(" * Acc@1 {acc1:.3f} Acc@5 {acc5:.3f}".format(
            acc1=logger.acc1.avg, acc5=logger.acc5.avg))

    return logger.acc1.avg
# Example #13
def validate_kd(val_loader, teacher, model, criterion, args):
    """Evaluate a distilled student model; return the average top-1 accuracy.

    The criterion takes (student logits, teacher logits, target); both
    forward passes run under torch.no_grad. The `idx` element from the
    loader is currently unused here.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5],
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target, idx) in enumerate(val_loader):

            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # student and teacher forward passes
            output = model(images)
            o_teacher = teacher(images)
            #o_teacher = torch.from_numpy(o_teacher_label_val[idx]).cuda()

            loss = criterion(output, o_teacher, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1,
                                                                    top5=top5))

    return top1.avg
# Example #14
def train(
    train_loader, models, optimizers, criterion, epoch, device, method_name, **kwargs,
):
    """Train an ensemble of models for one epoch with a noisy-label method.

    The method name selects how per-model losses and the kept sample indices
    are computed (co-teaching, co-teaching+, jocor, f-correction, decouple,
    or a general per-model loss).

    Returns:
        (loss_avgs, top1_avgs, top5_avgs, inds_updates): per-model average
        loss / top-1 / top-5 and the sample indices each model updated on.
    """
    loss_meters = []
    top1_meters = []
    top5_meters = []
    inds_updates = [[] for _ in range(len(models))]

    show_logs = []
    for i in range(len(models)):
        loss_meter = AverageMeter(f"Loss{i}", ":.4e")
        top1_meter = AverageMeter(f"Acc{i}@1", ":6.2f")
        top5_meter = AverageMeter(f"Acc{i}@5", ":6.2f")
        loss_meters.append(loss_meter)
        top1_meters.append(top1_meter)
        top5_meters.append(top5_meter)
        show_logs += [loss_meter, top1_meter, top5_meter]
    progress = ProgressMeter(
        len(train_loader), show_logs, prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    for i in range(len(models)):
        models[i].train()

    for i, (images, target, indexes) in enumerate(train_loader):
        if torch.cuda.is_available():
            images = images.to(device)
            target = target.to(device)

        # one forward pass per model
        outputs = []
        for m in range(len(models)):
            output = models[m](images)
            outputs.append(output)

        # calculate loss and selected index
        if method_name in ["ours", "ftl", "greedy", "precision"]:
            ind = indexes.cpu().numpy()
            losses, ind_updates = loss_general(outputs, target, criterion)
        elif method_name == "f-correction":
            losses, ind_updates = loss_forward(outputs, target, kwargs["P"])
        elif method_name == "decouple":
            losses, ind_updates = loss_decouple(outputs, target, criterion)
        elif method_name == "co-teaching":
            losses, ind_updates = loss_coteaching(
                outputs, target, kwargs["rate_schedule"][epoch]
            )
        elif method_name == "co-teaching+":
            ind = indexes.cpu().numpy().transpose()
            if epoch < kwargs["init_epoch"]:
                # warm up with plain co-teaching before the "+" variant
                losses, ind_updates = loss_coteaching(
                    outputs, target, kwargs["rate_schedule"][epoch]
                )
            else:
                losses, ind_updates = loss_coteaching_plus(
                    outputs, target, kwargs["rate_schedule"][epoch], ind, epoch * i,
                )
        elif method_name == "jocor":
            losses, ind_updates = loss_jocor(
                outputs, target, kwargs["rate_schedule"][epoch], kwargs["co_lambda"]
            )
        else:
            losses, ind_updates = loss_general(outputs, target, criterion)

        # skip batches with missing or non-finite losses
        if None in losses or any(~torch.isfinite(torch.tensor(losses))):
            continue

        # compute gradient and do BP
        for loss, optimizer in zip(losses, optimizers):
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure accuracy and record loss
        for m in range(len(models)):
            acc1, acc5 = accuracy(outputs[m], target, topk=(1, 5))

            top1_meters[m].update(acc1[0].item(), images.size(0))
            top5_meters[m].update(acc5[0].item(), images.size(0))
            if len(ind_updates[m]) > 0:
                # loss was computed on the selected subset only
                loss_meters[m].update(losses[m].item(), len(ind_updates[m]))
                inds_updates[m] += indexes[ind_updates[m]].numpy().tolist()
            else:
                loss_meters[m].update(losses[m].item(), images.size(0))

        if i % 100 == 0:
            progress.display(i)

    loss_avgs = [loss_meter.avg for loss_meter in loss_meters]
    top1_avgs = [top1_meter.avg for top1_meter in top1_meters]
    # BUG FIX: top-5 averages previously iterated top1_meters (copy-paste),
    # returning top-1 values in the top-5 slot.
    top5_avgs = [top5_meter.avg for top5_meter in top5_meters]

    return loss_avgs, top1_avgs, top5_avgs, inds_updates
# Example #15
def ssl(
    model,
    device,
    dataloader,
    criterion,
    optimizer,
    lr_scheduler=None,
    epoch=0,
    args=None,
):
    """One epoch of self-supervised contrastive training (SupCon / SimCLR).

    Each batch carries two augmented views, which are concatenated along the
    batch dimension for a single forward pass and then split back into the
    (bsz, 2, dim) feature layout the contrastive criterion expects.

    Raises:
        ValueError: if args.training_mode is neither 'SupCon' nor 'SimCLR'.
    """
    print(
        " ->->->->->->->->->-> One epoch with self-supervised training <-<-<-<-<-<-<-<-<-<-"
    )

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    progress = ProgressMeter(
        len(dataloader),
        [batch_time, data_time, losses],
        prefix="Epoch: [{}]".format(epoch),
    )

    model.train()
    end = time.time()

    for i, data in enumerate(dataloader):
        # stack the two augmented views into one 2*bsz batch
        images, target = data[0], data[1].to(device)
        images = torch.cat([images[0], images[1]], dim=0).to(device)
        bsz = target.shape[0]

        # basic properties of training (printed once per epoch)
        if i == 0:
            print(
                images.shape,
                target.shape,
                f"Batch_size from args: {args.batch_size}",
                "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]),
            )
            print("Pixel range for training images : [{}, {}]".format(
                torch.min(images).data.cpu().numpy(),
                torch.max(images).data.cpu().numpy(),
            ))

        # split back into per-view features: (bsz, 2, dim)
        features = model(images)
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        if args.training_mode == "SupCon":
            # supervised contrastive: labels guide the positives
            loss = criterion(features, target)
        elif args.training_mode == "SimCLR":
            # unsupervised: the other view is the only positive
            loss = criterion(features)
        else:
            raise ValueError("training mode not supported")

        losses.update(loss.item(), bsz)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if lr_scheduler:
            lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
# Example #16
def validate(val_loader, model, classifier, criterion, config, logger, block=None):
    """Evaluate a (backbone, classifier) pair on a long-tailed dataset.

    Besides overall top-1/top-5, computes per-class accuracy split into
    head/medium/tail class groups (via config.*_class_idx ranges) and the
    expected calibration error over 15 confidence bins. For the 'places'
    dataset an extra `block` module sits between backbone and classifier.

    Returns:
        (avg top-1 accuracy, expected calibration error in percent).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.3f')
    top1 = AverageMeter('Acc@1', ':6.3f')
    top5 = AverageMeter('Acc@5', ':6.3f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Eval: ')

    # switch to evaluate mode
    model.eval()
    if config.dataset == 'places':
        block.eval()
    classifier.eval()
    # per-class sample counts and per-class correct counts
    class_num = torch.zeros(config.num_classes).cuda()
    correct = torch.zeros(config.num_classes).cuda()

    # accumulated over all batches for calibration analysis
    confidence = np.array([])
    pred_class = np.array([])
    true_class = np.array([])

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if config.gpu is not None:
                images = images.cuda(config.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(config.gpu, non_blocking=True)

            # compute output
            feat = model(images)
            if config.dataset == 'places':
                feat = block(feat)
            output = classifier(feat)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # per-class counts: (target + predict == 2) marks correct hits
            _, predicted = output.max(1)
            target_one_hot = F.one_hot(target, config.num_classes)
            predict_one_hot = F.one_hot(predicted, config.num_classes)
            class_num = class_num + target_one_hot.sum(dim=0).to(torch.float)
            correct = correct + (target_one_hot + predict_one_hot == 2).sum(dim=0).to(torch.float)

            # collect confidences/predictions for calibration
            prob = torch.softmax(output, dim=1)
            confidence_part, pred_class_part = torch.max(prob, dim=1)
            confidence = np.append(confidence, confidence_part.cpu().numpy())
            pred_class = np.append(pred_class, pred_class_part.cpu().numpy())
            true_class = np.append(true_class, target.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % config.print_freq == 0:
                progress.display(i, logger)

        # accuracy per class, then averaged within head/medium/tail groups
        acc_classes = correct / class_num
        head_acc = acc_classes[config.head_class_idx[0]:config.head_class_idx[1]].mean() * 100

        med_acc = acc_classes[config.med_class_idx[0]:config.med_class_idx[1]].mean() * 100
        tail_acc = acc_classes[config.tail_class_idx[0]:config.tail_class_idx[1]].mean() * 100
        logger.info('* Acc@1 {top1.avg:.3f}% Acc@5 {top5.avg:.3f}% HAcc {head_acc:.3f}% MAcc {med_acc:.3f}% TAcc {tail_acc:.3f}%.'.format(top1=top1, top5=top5, head_acc=head_acc, med_acc=med_acc, tail_acc=tail_acc))

        cal = calibration(true_class, pred_class, confidence, num_bins=15)
        logger.info('* ECE   {ece:.3f}%.'.format(ece=cal['expected_calibration_error'] * 100))

    return top1.avg, cal['expected_calibration_error'] * 100
# Example #17
def train(train_loader, model, classifier, criterion, optimizer, epoch, config, logger, block=None):
    """Train a (backbone, classifier) pair for one epoch, optionally with mixup.

    For the 'places' dataset the backbone is frozen (eval mode, no_grad) and
    only the intermediate `block` and the classifier are trained; otherwise
    the full backbone trains. When config.mixup is True, inputs and targets
    are mixed via mixup_data / mixup_criterion.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.3f')
    top1 = AverageMeter('Acc@1', ':6.3f')
    top5 = AverageMeter('Acc@5', ':6.3f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode (backbone frozen for 'places')
    if config.dataset == 'places':
        model.eval()
        block.train()
    else:
        model.train()
    classifier.train()

    # cap the number of steps at full batches of the dataset
    training_data_num = len(train_loader.dataset)
    end_steps = int(training_data_num / train_loader.batch_size)

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        if i > end_steps:
            break

        # measure data loading time
        data_time.update(time.time() - end)

        if torch.cuda.is_available():
            images = images.cuda(config.gpu, non_blocking=True)
            target = target.cuda(config.gpu, non_blocking=True)

        if config.mixup is True:
            # mixup: convex combination of inputs and paired targets
            images, targets_a, targets_b, lam = mixup_data(images, target, alpha=config.alpha)
            if config.dataset == 'places':
                with torch.no_grad():
                    feat_a = model(images)
                feat = block(feat_a.detach())
                output = classifier(feat)
            else:
                feat = model(images)
                output = classifier(feat)
            loss = mixup_criterion(criterion, output, targets_a, targets_b, lam)
        else:
            if config.dataset == 'places':
                with torch.no_grad():
                    feat_a = model(images)
                feat = block(feat_a.detach())
                output = classifier(feat)
            else:
                feat = model(images)
                output = classifier(feat)

            loss = criterion(output, target)

        # measure accuracy (vs. original targets even under mixup) and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.print_freq == 0:
            progress.display(i, logger)
# Example #18
def adv(model, device, dataloader, criterion, optimizer, lr_scheduler=None, epoch=0, args=None):
    """Run one epoch of adversarial training with the TRADES loss.

    Tracks both clean and adversarial top-1/top-5 accuracy. ``criterion``
    is accepted for signature compatibility with the other trainers but is
    unused here — the loss comes entirely from ``trades_loss``.

    Args:
        model: network to train (switched to train mode here).
        device: torch device batches are moved to.
        dataloader: yields (images, targets) batches.
        criterion: unused; kept for a uniform trainer signature.
        optimizer: stepped once per batch.
        lr_scheduler: optional per-step scheduler.
        epoch: epoch index, used only in the progress prefix.
        args: namespace providing the TRADES hyper-parameters (step_size,
            epsilon, num_steps, beta, clip_min/clip_max, distance) plus
            fp16, print_freq, batch_size and local_rank.

    Returns:
        dict with epoch-average "top1", "top5", "top1_adv", "top5_adv".
    """
    if args.local_rank == 0:
        print(" ->->->->->->->->->-> One epoch with Adversarial (Trades) training <-<-<-<-<-<-<-<-<-<-")

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    top1 = AverageMeter("Acc_1", ":6.2f")
    top5 = AverageMeter("Acc_5", ":6.2f")
    top1_adv = AverageMeter("Acc_1_adv", ":6.2f")
    top5_adv = AverageMeter("Acc_5_adv", ":6.2f")
    progress = ProgressMeter(
        len(dataloader),
        [batch_time, data_time, losses, top1, top5, top1_adv, top5_adv],
        prefix="Epoch: [{}]".format(epoch),
    )

    model.train()
    end = time.time()

    for i, data in enumerate(dataloader):
        images, target = data[0].to(device), data[1].to(device)

        # basic properties of training, logged once per epoch on rank 0
        if i == 0 and args.local_rank == 0:
            print(
                images.shape,
                target.shape,
                f"Batch_size from args: {args.batch_size}",
                "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]),
            )
            print(
                "Pixel range for training images : [{}, {}]".format(
                    torch.min(images).data.cpu().numpy(),
                    torch.max(images).data.cpu().numpy(),
                )
            )

        # calculate robust loss (TRADES also builds the adversarial examples)
        loss, logits, logits_adv = trades_loss(
            model=model,
            x_natural=images,
            y=target,
            device=device,
            optimizer=optimizer,
            step_size=args.step_size,
            epsilon=args.epsilon,
            perturb_steps=args.num_steps,
            beta=args.beta,
            clip_min=args.clip_min,
            clip_max=args.clip_max,
            distance=args.distance,
        )

        # measure accuracy and record loss (record the loss ONCE per batch;
        # the original called losses.update a second time below, double-
        # counting every batch in the meter)
        acc1, acc5 = accuracy(logits, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))
        acc1_adv, acc5_adv = accuracy(logits_adv, target, topk=(1, 5))
        top1_adv.update(acc1_adv[0], images.size(0))
        top5_adv.update(acc5_adv[0], images.size(0))

        optimizer.zero_grad()
        if args.fp16:
            # apex-style loss scaling for mixed precision
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        if lr_scheduler:
            lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and args.local_rank == 0:
            progress.display(i)
    result = {"top1": top1.avg, "top5":  top5.avg, "top1_adv": top1_adv.avg, "top5_adv": top5_adv.avg}
    return result
Beispiel #19
0
def baseline(model, device, dataloader, criterion, optimizer, lr_scheduler=None, epoch=0, args=None):
    """One epoch of standard (natural, non-adversarial) supervised training.

    Returns a dict with the epoch-average top-1 and top-5 accuracy.
    """
    if args.local_rank == 0:
        print(" ->->->->->->->->->-> One epoch with Baseline natural training <-<-<-<-<-<-<-<-<-<-")

    # Running averages for timing, loss and accuracy.
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    top1 = AverageMeter("Acc_1", ":6.2f")
    top5 = AverageMeter("Acc_5", ":6.2f")
    progress = ProgressMeter(
        len(dataloader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch),
    )

    model.train()
    end = time.time()

    for step, batch in enumerate(dataloader):
        images = batch[0].to(device)
        target = batch[1].to(device)

        # Log shapes, lr and pixel range once per epoch on the main rank.
        if step == 0 and args.local_rank == 0:
            print(
                images.shape,
                target.shape,
                f"Batch_size from args: {args.batch_size}",
                "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]),
            )
            print(
                "Pixel range for training images : [{}, {}]".format(
                    torch.min(images).data.cpu().numpy(),
                    torch.max(images).data.cpu().numpy(),
                )
            )

        # Forward pass and loss.
        output = model(images)
        loss = criterion(output, target)

        # Bookkeeping: loss and top-1/top-5 accuracy for this batch.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = images.size(0)
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)

        # Backward pass; scale the loss when mixed precision is enabled.
        optimizer.zero_grad()
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        if lr_scheduler:
            lr_scheduler.step()

        # Per-iteration timing.
        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0 and args.local_rank == 0:
            progress.display(step)

    return {"top1": top1.avg, "top5":  top5.avg}
Beispiel #20
0
def train(loader: DataLoader, model: nn.Module, criterion: Callable,
          optimizer: Optimizer, scheduler: object, scaler: GradScaler,
          device: torch.device, loop_steps: int, step: int):
    """Train ``model`` for exactly ``loop_steps`` batches of ``loader``.

    Mixed-precision forward/backward via ``autocast`` + ``scaler``, with
    gradient-norm clipping and a per-batch ``scheduler.step()``.

    Args:
        loader: yields ``(input_ids, attention_mask, targets, _)`` batches.
        model: model whose forward accepts ``input_ids``/``attention_mask``/
            ``labels`` and returns an object exposing ``.loss`` and
            ``.logits`` (HF-style interface).
        criterion: unused here; kept for a uniform trainer signature.
        optimizer: stepped through the GradScaler once per batch.
        scheduler: learning-rate scheduler stepped once per batch.
        scaler: AMP gradient scaler.
        device: target device for the batch tensors.
        loop_steps: number of batches to process before returning.
        step: epoch index, used only in the progress prefix.

    Returns:
        Average training loss over the processed batches.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    accuracies = AverageMeter('Accuracy', ':6.2f')
    progress = ProgressMeter(loop_steps,
                             [batch_time, data_time, losses, accuracies],
                             prefix="Epoch: [{}]".format(step))

    model.train()

    # Display roughly 50 times per epoch; the guard avoids the
    # ZeroDivisionError the original hit when loop_steps < 50
    # (loop_steps // 50 == 0 in the modulo below).
    display_every = max(1, loop_steps // 50)

    end = time.time()
    for i, (path, attention_mask, targets, _) in enumerate(loader):
        data_time.update(time.time() - end)

        non_blocking = device.type != 'cpu'
        path = path.to(device, non_blocking=non_blocking)
        # BUG FIX: the original assigned `path.to(...)` to attention_mask,
        # clobbering the real attention mask with the input ids.
        attention_mask = attention_mask.to(device, non_blocking=non_blocking)
        targets = targets.to(device, non_blocking=non_blocking)

        with autocast():
            output = model(input_ids=path,
                           attention_mask=attention_mask,
                           labels=targets)
            logits = output.logits
            # .mean() collapses per-device losses (e.g. under DataParallel).
            loss = output.loss.mean()
            # Renamed from `accuracy` to avoid shadowing the module-level
            # accuracy() helper used by the other trainers in this file.
            batch_accuracy = (logits.argmax(1) == targets).float().mean().item()

        losses.update(loss.item(), targets.shape[0])
        accuracies.update(batch_accuracy, targets.shape[0])

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        # BUG FIX: gradients must be unscaled before clipping, otherwise the
        # norm threshold is applied to the scaled gradients (see PyTorch AMP
        # gradient-clipping recipe).
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       CLIPPING_GRADIENT_NORM)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if i % display_every == 0:
            progress.display(i)

        if i == loop_steps - 1:
            break
    return losses.avg
def train(train_loader, model, criterion, optimizer, epoch, args):
    """One training epoch with Gaussian-blur data augmentation.

    The blur strength ``args.sigma`` is (re)computed per epoch from
    ``args.mode`` and applied to the input images before the forward pass.

    NOTE(review): several branches mutate shared state in place —
    ``args.sigma`` and, for "fixed-single-step", ``requires_grad`` of the
    first conv layer — so their effect persists across later epochs.
    Confirm this cross-epoch coupling is intended.

    Returns:
        (average loss, average top-1 accuracy, average top-5 accuracy)
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch + 1),
    )

    # switch to train mode
    model.train()

    # blur settings: choose this epoch's sigma according to the schedule.
    # "mix"/"random-mix" are handled per-batch inside the loop instead.
    if args.mode == "normal":
        args.sigma = 0  # no blur
    elif args.mode == "multi-steps-cbt":
        args.sigma = adjust_multi_steps_cbt(
            args.init_sigma, epoch, args.cbt_rate,
            every=5)  # sigma decay every 5 epoch
    elif args.mode == "multi-steps":
        args.sigma = adjust_multi_steps(epoch)
    elif args.mode == "single-step":
        # blur for the first half of training, then none
        if epoch >= args.epochs // 2:
            args.sigma = 0
    elif args.mode == "fixed-single-step":
        if epoch >= args.epochs // 2:
            args.sigma = 0  # no blur
            # fix parameters of 1st Conv layer
            model.features[0].weight.requires_grad = False
            model.features[0].bias.requires_grad = False
    elif args.mode == "reversed-single-step":
        # no blur first, then blur with args.reverse_sigma
        if epoch < args.epochs // 2:
            args.sigma = 0
        else:
            args.sigma = args.reverse_sigma

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # blur images (no blur when args.sigma = 0)
        if args.mode == "mix":
            half1, half2 = images.chunk(2)
            # blur first half images
            half1 = GaussianBlurAll(half1, args.sigma)
            images = torch.cat((half1, half2))
        elif args.mode == "random-mix":
            half1, half2 = images.chunk(2)
            # blur first half images
            half1 = RandomGaussianBlurAll(half1, args.min_sigma,
                                          args.max_sigma)
            images = torch.cat((half1, half2))
        else:
            images = GaussianBlurAll(images, args.sigma)

        if torch.cuda.is_available():
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

        # compute outputs
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    return losses.avg, top1.avg, top5.avg
Beispiel #22
0
        ## eval
        eval_acc = model.evaluation(testloader)

        ## log
        for i, loss in enumerate(loss_list):
            log_list[i + 1].update(loss.avg)
            writer.add_scalar(loss.name, loss.avg, epoch)
        lr_name = 'lr'
        for i, opt in enumerate(model.optimizer.optimizers):
            writer.add_scalar(lr_name, opt.param_groups[0]['lr'], epoch)
            lr_name = f'lr_{i+2}'
        log_list[-1].update(eval_acc)
        writer.add_scalar(log_list[-1].name, eval_acc, epoch)

        logger.log(progress.display(epoch), consol=False)

        ## save
        state = {
            'args': args,
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': model.optimizer.state_dict()
        }

        if max_acc < eval_acc:
            max_acc = eval_acc
            filename = os.path.join(args.save_folder,
                                    'checkpoint_best.pth.tar')
            logger.log('#' * 20 + 'Save Best Model' + '#' * 20)
            torch.save(state, filename)
Beispiel #23
0
def validate(val_loader, model, epsilon, args):
    """Evaluate robustness with a targeted random-start PGD attack (foolbox 1.x).

    For each batch only the first image is attacked (foolbox 1.8 has no
    batch support). A random target class different from the true label is
    drawn per image; if the attack fails the clean image is scored instead.

    Args:
        val_loader: yields (images, targets); only element 0 of each batch
            is used.
        model: classifier wrapped into a foolbox PyTorchModel.
        epsilon: L-inf perturbation budget.
        args: provides ``pgd_steps`` and ``print_freq``.

    Returns:
        Average adversarial top-1 accuracy (fraction of attacked images
        still predicted as their clean label).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, top1],
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    # ImageNet normalization, applied inside the foolbox wrapper.
    mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    preprocessing = (mean, std)
    fmodel = PyTorchModel(model,
                          bounds=(0, 1),
                          num_classes=1000,
                          preprocessing=preprocessing)

    clean_labels = np.zeros(len(val_loader))
    target_labels = np.zeros(len(val_loader))
    clean_pred_labels = np.zeros(len(val_loader))
    adv_pred_labels = np.zeros(len(val_loader))

    end = time.time()

    # Batch processing is not supported in foolbox 1.8, so we feed images one
    # by one. Note that we are using a batch size of 2, which means we
    # consider every other image (due to computational costs).
    for i, (images, target) in enumerate(val_loader):

        image = images.cpu().numpy()[0]
        clean_label = target.cpu().numpy()[0]

        # random target class, guaranteed different from the clean label
        target_label = np.random.choice(
            np.setdiff1d(np.arange(1000), clean_label))
        attack = RandomStartProjectedGradientDescentAttack(
            model=fmodel,
            criterion=TargetClass(target_label),
            distance=Linfinity)
        adversarial = attack(image,
                             clean_label,
                             binary_search=False,
                             epsilon=epsilon,
                             stepsize=2. / 255,
                             iterations=args.pgd_steps,
                             random_start=True)

        # BUG FIX: the original used `np.any(adversarial == None)`, an
        # elementwise ndarray-vs-None comparison (E711, FutureWarning in
        # NumPy). foolbox returns None when no adversarial is found, so an
        # identity check is the correct and equivalent test.
        if adversarial is None:
            # Non-adversarial: fall back to scoring the clean image
            adversarial = image
            target_label = clean_label

        adv_pred_labels[i] = np.argmax(fmodel.predictions(adversarial))
        clean_labels[i] = clean_label
        target_labels[i] = target_label
        clean_pred_labels[i] = np.argmax(fmodel.predictions(image))

        print('Iter, Clean, Clean_pred, Adv, Adv_pred: ', i, clean_label,
              clean_pred_labels[i], target_label, adv_pred_labels[i])

        # measure accuracy and update average
        acc1 = 100. * np.mean(clean_label == adv_pred_labels[i])
        top1.update(acc1, 1)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    print('* Acc@1 {top1.avg:.3f} '.format(top1=top1))

    return top1.avg
Beispiel #24
0
            acc1, acc5 = accuracy(sim_t_2_i,
                                  torch.arange(image.size(0)).cuda(),
                                  topk=(1, 5))
            losses.update(loss.item(), image.size(0))
            top1_acc.update(acc1[0], image.size(0))
            top5_acc.update(acc5[0], image.size(0))

            loss.backward()
            optimizer.step()

            scheduler.step()
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % cfg.TRAIN.PRINT_FREQ == 0:
                progress.display(global_step % (len(trainloader) * 30))

    if epoch % 8 == 1:
        checkpoint_file = args.name + "/checkpoint_%d.pth" % epoch
        torch.save(
            {
                "epoch": epoch,
                "global_step": global_step,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict()
            }, checkpoint_file)
    if top1_acc.avg > best_top1:
        best_top1 = top1_acc.avg
        checkpoint_file = args.name + "/checkpoint_best.pth"
        torch.save(
            {
def validate(loader: DataLoader, model: nn.Module, criterion: Callable,
             num_classes: int, num_super_classes: int, maf: torch.FloatTensor,
             args: ArgumentParser) -> torch.FloatTensor:
    """Masked-language-model style validation over genotype sequences.

    A deterministic per-batch fraction of positions (10%..90%, cycling with
    the batch index) is zeroed out; the model must reconstruct the masked
    genotype signs. Loss is weighted so positions with nonzero minor-allele
    frequency (MAF) count ``args.minor_coefficient`` times more.

    NOTE(review): ``num_classes`` and ``num_super_classes`` are unused in
    this body; ``maf_vector = maf[labels[0]]`` assumes all samples in a
    batch share the same label — confirm against the loader.

    Returns:
        Average weighted MLM loss over the validation set.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('MLM Loss', ':.4e')
    accuracies = AverageMeter('Acc', ':.4e')
    accuracy_deltas = AverageMeter('Acc Delta', ':.4e')
    progress = ProgressMeter(len(loader),
                             [batch_time, losses, accuracies, accuracy_deltas],
                             prefix="Test: ")

    model.eval()

    device = get_device(args)
    with torch.no_grad():
        end = time.time()
        for i, (genotypes, labels, super_labels) in enumerate(loader):

            ### Mask for Masked Language Modeling
            # mask fraction cycles 0.1, 0.2, ..., 0.9 with the batch index
            mask_num = int((i % 9 + 1) / 10 * genotypes.shape[1])
            mask_scores = torch.rand(genotypes.shape[1])
            mask_indices = mask_scores.argsort(descending=True)[:mask_num]
            masked_genotypes = genotypes[:, mask_indices].reshape(-1)
            # binary targets: 1 where the original genotype value was 1
            targets = (masked_genotypes == 1).float().clone().detach()
            genotypes[:, mask_indices] = 0  # zero out masked positions in-place
            maf_vector = maf[labels[0]]

            genotypes = genotypes.to(device)
            masked_genotypes = masked_genotypes.to(device)
            targets = targets.to(device)
            labels = labels.to(device)
            super_labels = super_labels.to(device)
            maf_vector = maf_vector.to(device)

            logits = model(genotypes, labels, super_labels)
            logits = logits[:, mask_indices].reshape(-1)

            # add weight to nonzero maf snps
            weights = torch.ones_like(logits)
            weight_coefficients = (maf_vector[mask_indices] > 0).repeat(
                genotypes.shape[0]).float() * (args.minor_coefficient - 1) + 1
            weights *= weight_coefficients

            loss = criterion(logits, targets, weight=weights, reduction='mean')

            # sign agreement mapped from [-1, 1] to [0, 1]
            accuracy = (masked_genotypes * logits.sign()).mean() / 2 + .5
            # baseline: always predict the majority allele per position
            # (the -.5000001 offset breaks ties at maf == .5 toward -1)
            baseline_accuracy = (
                masked_genotypes *
                (maf_vector[mask_indices].repeat(genotypes.shape[0]) -
                 .5000001).sign()).mean() / 2 + .5
            accuracy_delta = accuracy - baseline_accuracy

            losses.update(loss.item(), genotypes.shape[0])
            accuracies.update(accuracy.item(), genotypes.shape[0])
            accuracy_deltas.update(accuracy_delta.item(), genotypes.shape[0])
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
        # final summary line; NOTE(review): raises if the loader is empty
        # (loop variable `i` would be unbound) — confirm loaders are nonempty
        progress.display(i)
    return losses.avg
Beispiel #26
0
    def train_epoch(self):
        """Run one adversarial style-transfer training epoch.

        Each step draws one batch per style domain (0 and 1), encodes both,
        then trains in two phases: (1) the discriminators on detached real
        vs. transferred hidden sequences, (2) the encoder/generator on
        reconstruction loss plus (optionally gated) adversarial loss.

        Relies on module-level ``args``, ``loss_fn``, ``gradient_penalty``
        and ``bert_tokenizer``, and on ``self.models``, ``self.optimizer``,
        ``self.disc_optimizer``, ``self.ones``/``self.zeros`` and
        ``self.train_loaders`` set up elsewhere. Increments ``self.epoch``.
        """
        self.models.train()
        self.epoch += 1

        # record training statistics
        avg_meters = {
            'loss_rec': AverageMeter('Loss Rec', ':.4e'),
            'loss_adv': AverageMeter('Loss Adv', ':.4e'),
            'loss_disc': AverageMeter('Loss Disc', ':.4e'),
            'time': AverageMeter('Time', ':6.3f')
        }
        progress_meter = ProgressMeter(len(self.train_loaders[0]),
                                       avg_meters.values(),
                                       prefix="Epoch: [{}]".format(self.epoch))

        # begin training from minibatches
        for ix, (data_0, data_1) in enumerate(zip(*self.train_loaders)):
            start_time = time.time()

            # load text and labels
            src_0, src_len_0, labels_0 = data_0
            src_0, labels_0 = src_0.to(args.device), labels_0.to(args.device)
            src_1, src_len_1, labels_1 = data_1
            src_1, labels_1 = src_1.to(args.device), labels_1.to(args.device)

            # encode
            encoder = self.models['encoder']
            z_0 = encoder(labels_0, src_0, src_len_0)  # (batch_size, dim_z)
            z_1 = encoder(labels_1, src_1, src_len_1)

            # recon & transfer: last generator arg toggles style transfer
            generator = self.models['generator']
            inputs_0 = (z_0, labels_0, src_0)
            h_ori_seq_0, pred_ori_0 = generator(*inputs_0, src_len_0, False)
            h_trans_seq_0_to_1, _ = generator(*inputs_0, src_len_1, True)

            inputs_1 = (z_1, labels_1, src_1)
            h_ori_seq_1, pred_ori_1 = generator(*inputs_1, src_len_1, False)
            h_trans_seq_1_to_0, _ = generator(*inputs_1, src_len_0, True)

            # discriminate real and transfer
            # (detached: this phase must not backprop into the generator)
            disc_0, disc_1 = self.models['disc_0'], self.models['disc_1']
            d_0_real = disc_0(h_ori_seq_0.detach())  # detached
            d_0_fake = disc_0(h_trans_seq_1_to_0.detach())
            d_1_real = disc_1(h_ori_seq_1.detach())
            d_1_fake = disc_1(h_trans_seq_0_to_1.detach())

            # discriminator loss
            loss_disc = (loss_fn(args.gan_type)(d_0_real, self.ones) +
                         loss_fn(args.gan_type)(d_0_fake, self.zeros) +
                         loss_fn(args.gan_type)(d_1_real, self.ones) +
                         loss_fn(args.gan_type)(d_1_fake, self.zeros))
            # gradient penalty
            if args.gan_type == 'wgan-gp':
                loss_disc += args.gp_weight * gradient_penalty(
                    h_ori_seq_0,  # real data for 0
                    h_trans_seq_1_to_0,  # fake data for 0
                    disc_0)
                loss_disc += args.gp_weight * gradient_penalty(
                    h_ori_seq_1,  # real data for 1
                    h_trans_seq_0_to_1,  # fake data for 1
                    disc_1)
            avg_meters['loss_disc'].update(loss_disc.item(), src_0.size(0))

            self.disc_optimizer.zero_grad()
            loss_disc.backward()
            self.disc_optimizer.step()

            # reconstruction loss (src[1:] drops the BOS token as target)
            loss_rec = (
                F.cross_entropy(  # Recon 0 -> 0
                    pred_ori_0.view(-1, pred_ori_0.size(-1)),
                    src_0[1:].view(-1),
                    ignore_index=bert_tokenizer.pad_token_id,
                    reduction='sum') + F.cross_entropy(  # Recon 1 -> 1
                        pred_ori_1.view(-1, pred_ori_1.size(-1)),
                        src_1[1:].view(-1),
                        ignore_index=bert_tokenizer.pad_token_id,
                        reduction='sum')) / (
                            2.0 * args.batch_size
                        )  # match scale with the orginal paper
            avg_meters['loss_rec'].update(loss_rec.item(), src_0.size(0))

            # generator loss (re-run discriminators WITHOUT detach so
            # gradients flow back into the generator)
            d_0_fake = disc_0(h_trans_seq_1_to_0)  # not detached
            d_1_fake = disc_1(h_trans_seq_0_to_1)
            loss_adv = (loss_fn(args.gan_type, disc=False)
                        (d_0_fake, self.ones) +
                        loss_fn(args.gan_type, disc=False)(d_1_fake, self.ones)
                        ) / 2.0  # match scale with the original paper
            avg_meters['loss_adv'].update(loss_adv.item(), src_0.size(0))

            # XXX: threshold for training stability
            if (not args.two_stage):
                if (args.threshold is not None and loss_disc < args.threshold):
                    loss = loss_rec + args.rho * loss_adv
                else:
                    loss = loss_rec
            else:  # two_stage training
                if (args.second_stage_num > args.epochs - self.epoch):
                    # last second_stage; flow loss_adv gradients
                    loss = loss_rec + args.rho * loss_adv
                else:
                    loss = loss_rec
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            avg_meters['time'].update(time.time() - start_time)

            # log progress
            if (ix + 1) % args.log_interval == 0:
                progress_meter.display(ix + 1)

        progress_meter.display(len(self.train_loaders[0]))
Beispiel #27
0
def evaluate(model, valloader, epoch, cfg, index=2):
    """Evaluate text-to-image retrieval and checkpoint on a new best top-1.

    Computes the symmetric cross-entropy retrieval loss between one
    (visual, language) embedding pair — selected by ``index`` from the
    model's output pairs — and top-1/top-5 retrieval accuracy.

    NOTE(review): depends on module-level globals beyond its parameters:
    ``best_top1_eval`` (mutated here), ``tokenizer``, ``args`` and
    ``optimizer`` (read in the checkpoint branch) — confirm they are
    defined at call time.

    Args:
        model: returns (pairs, logit_scale, cls_logits) from its forward.
        valloader: yields (image, text, [bk,] id_car) batches; ``bk`` only
            when cfg.DATA.USE_MOTION is set.
        epoch: epoch index for logging/checkpoint metadata.
        cfg: config namespace (DATA.USE_MOTION flag).
        index: which embedding pair from ``pairs`` to score.
    """
    global best_top1_eval
    print("Test::::")
    model.eval()
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1_acc = AverageMeter('Acc@1', ':6.2f')
    top5_acc = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(valloader), [batch_time, data_time, losses, top1_acc, top5_acc],
        prefix="Test Epoch: [{}]".format(epoch))
    end = time.time()
    with torch.no_grad():
        for batch_idx, batch in enumerate(valloader):
            if cfg.DATA.USE_MOTION:
                image, text, bk, id_car = batch
            else:
                image, text, id_car = batch
            tokens = tokenizer.batch_encode_plus(text,
                                                 padding='longest',
                                                 return_tensors='pt')
            data_time.update(time.time() - end)
            if cfg.DATA.USE_MOTION:
                pairs, logit_scale, cls_logits = model(
                    tokens['input_ids'].cuda(),
                    tokens['attention_mask'].cuda(), image.cuda(), bk.cuda())
            else:
                pairs, logit_scale, cls_logits = model(
                    tokens['input_ids'].cuda(),
                    tokens['attention_mask'].cuda(), image.cuda())
            # temperature: model stores log of the scale
            logit_scale = logit_scale.mean().exp()
            loss = 0

            # for visual_embeds,lang_embeds in pairs:
            visual_embeds, lang_embeds = pairs[index]
            # similarity matrices in both retrieval directions; the
            # diagonal (arange targets) marks the matching pairs
            sim_i_2_t = torch.matmul(torch.mul(logit_scale, visual_embeds),
                                     torch.t(lang_embeds))
            sim_t_2_i = sim_i_2_t.t()
            loss_t_2_i = F.cross_entropy(sim_t_2_i,
                                         torch.arange(image.size(0)).cuda())
            loss_i_2_t = F.cross_entropy(sim_i_2_t,
                                         torch.arange(image.size(0)).cuda())
            loss += (loss_t_2_i + loss_i_2_t) / 2

            acc1, acc5 = accuracy(sim_t_2_i,
                                  torch.arange(image.size(0)).cuda(),
                                  topk=(1, 5))
            losses.update(loss.item(), image.size(0))
            top1_acc.update(acc1[0], image.size(0))
            top5_acc.update(acc5[0], image.size(0))
            batch_time.update(time.time() - end)
            end = time.time()

            progress.display(batch_idx)
    # save a checkpoint whenever this evaluation beats the global best
    if top1_acc.avg > best_top1_eval:
        best_top1_eval = top1_acc.avg
        checkpoint_file = args.name + "/checkpoint_best_eval.pth"
        torch.save(
            {
                "epoch": epoch,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict()
            }, checkpoint_file)