Example #1
def train_on_data(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = misc_utils.AverageMeter()
    data_time = misc_utils.AverageMeter()
    losses = misc_utils.AverageMeter()
    top1 = misc_utils.AverageMeter()
    top5 = misc_utils.AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (img0, img1, target, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpus is not None:
            img0 = img0.cuda(args.gpus, non_blocking=True)
            img1 = img1.cuda(args.gpus, non_blocking=True)
        target = target.cuda(args.gpus, non_blocking=True)

        # compute output
        output = model(img0, img1)
        loss = criterion(output, target)

        # measure accuracy and record loss; misc_utils.accuracy is called
        # without a topk argument here, so the top-5 meter simply mirrors
        # the top-1 numbers
        acc1 = misc_utils.accuracy(output, target)
        acc5 = acc1.copy()
        losses.update(loss.item(), img0.size(0))
        top1.update(acc1[0].cpu().numpy()[0], img0.size(0))
        top5.update(acc5[0].cpu().numpy()[0], img0.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # printing the accuracy at certain intervals
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))
        if i * len(img0) > args.train_samples:
            break
    return [epoch, batch_time.avg, losses.avg, top1.avg, top5.avg]
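These examples all lean on two small helpers, AverageMeter and accuracy, that the listings never define. Below is a minimal sketch of both in the style of the official PyTorch ImageNet example; the actual misc_utils module may differ in its details.

import torch


class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in percent) for each requested k, as a list of
    shape-(1,) tensors, matching the acc1[0].cpu().numpy()[0] usage above."""
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / target.size(0)))
        return res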
Example #2
def validate_on_data(val_loader, model, criterion, args):
    batch_time = misc_utils.AverageMeter()
    losses = misc_utils.AverageMeter()
    top1 = misc_utils.AverageMeter()
    top5 = misc_utils.AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input_image, target, _) in enumerate(val_loader):
            if args.dataset == 'natural':
                input_image = input_image[0]
                target = target[0]
            if args.gpus is not None:
                input_image = input_image.cuda(args.gpus, non_blocking=True)
            target = target.cuda(args.gpus, non_blocking=True)

            # compute output; the model returns a tuple whose second element
            # is used as the class predictions
            output = model(input_image)
            output = output[1]
            loss = criterion(output, target)

            # measure accuracy and record loss; as in Example #1, the top-5
            # meter mirrors top-1 because accuracy is called without topk
            acc1 = misc_utils.accuracy(output, target)
            acc5 = acc1.copy()
            losses.update(loss.item(), input_image.size(0))
            top1.update(acc1[0].cpu().numpy()[0], input_image.size(0))
            top5.update(acc5[0].cpu().numpy()[0], input_image.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # printing the accuracy at certain intervals
            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1,
                          top5=top5))
            if i * len(input_image) > args.val_samples:
                break
        # printing the accuracy of the epoch
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1,
                                                                    top5=top5))

    return [batch_time.avg, losses.avg, top1.avg, top5.avg]
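A hedged sketch of how Examples #1 and #2 might be wired together into a full run; the args.epochs field, the checkpoint path, and the best-model selection are illustrative assumptions, not code from the source.

import torch


def run_training(args, model, criterion, optimizer, train_loader, val_loader):
    # collects the rows returned by the two functions above
    logs = {'train': [], 'val': []}
    for epoch in range(args.epochs):
        logs['train'].append(train_on_data(train_loader, model, criterion,
                                           optimizer, epoch, args))
        logs['val'].append(validate_on_data(val_loader, model, criterion,
                                            args))
        # validate_on_data returns [batch_time, loss, top1, top5];
        # keep the checkpoint with the best top-1 accuracy
        if logs['val'][-1][2] >= max(row[2] for row in logs['val']):
            torch.save(model.state_dict(), 'model_best.pth')
    return logs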
Example #3
def train(epoch, model, train_loader, optimizer, cuda, log_interval, save_path,
          args, writer, pos_net, neg_nets):
    losses_neg = misc.AverageMeter()
    losses_pos = misc.AverageMeter()
    top1_neg = misc.AverageMeter()
    top1_pos = misc.AverageMeter()

    model.train()
    loss_dict = model.latest_losses()
    losses = {k + '_train': 0 for k in loss_dict}
    epoch_losses = {k + '_train': 0 for k in loss_dict}
    start_time = time.time()
    batch_idx, data = None, None
    for batch_idx, loader_data in enumerate(train_loader):
        if args.dataset == 'coco':
            # detection-style loaders yield a list of dicts; reverse each
            # image's channel order and build a normalised 224x224 float batch
            data = []
            for batch_data in loader_data:
                current_image = batch_data['image'][[2, 1, 0], :, :].clone()
                current_image = current_image.unsqueeze(0)
                current_image = current_image.type('torch.FloatTensor')
                current_image = nn.functional.interpolate(
                    current_image, (224, 224))
                current_image /= 255
                current_image[0] = functional.normalize(
                    current_image[0], args.mean, args.std, False)
                data.append(current_image)
            data = torch.cat(data, dim=0)
            max_len = len(train_loader.dataset)
        else:
            data = loader_data[0]
            target = loader_data[1]
            max_len = len(train_loader)
            target = target.cuda()

        data = data.cuda()
        optimizer.zero_grad()
        outputs = model(data)

        if args.dataset == 'coco':
            # undo the normalisation and hand the reconstructions back to the
            # detection networks in the loader's original format and size
            recon_imgs = preprocessing.inv_normalise_tensor(
                outputs[0], args.mean, args.std)
            for im_ind, batch_data in enumerate(loader_data):
                org_size = batch_data['image'].shape
                current_image = recon_imgs[im_ind].squeeze()[[2, 1, 0], :, :]
                current_image = current_image.unsqueeze(0)
                current_image = nn.functional.interpolate(
                    current_image, (org_size[1], org_size[2]))
                current_image *= 255
                current_image = current_image.type(torch.uint8)
                current_image = current_image.squeeze()
                batch_data['image'] = current_image
            # detection networks are run inside an event-storage context, as
            # detection frameworks require in training mode; a single negative
            # network is assumed on this path, since `neg_net` was otherwise
            # undefined here
            with EventStorage(0) as storage:
                output_neg = neg_nets[0](loader_data)
                loss_neg = sum(loss for loss in output_neg.values())
                losses_neg.update(loss_neg, data.size(0))
                output_pos = pos_net(loader_data)
                loss_pos = sum(loss for loss in output_pos.values())
                losses_pos.update(loss_pos, data.size(0))
                # define the term used by the combined loss below, which the
                # original only set on the non-coco path
                current_loss_negs = loss_neg
        else:
            recon_imgs = preprocessing.inv_normalise_tensor(
                outputs[0], args.mean, args.std)
            recon_imgs = preprocessing.normalise_tensor(
                recon_imgs, args.imagenet_mean, args.imagenet_std)
            current_loss_negs = 0
            for neg_net in neg_nets:
                output_neg = neg_net(recon_imgs)
                loss_neg = args.criterion_neg(output_neg, target)
                acc1_neg, acc5_neg = misc.accuracy(output_neg,
                                                   target,
                                                   topk=(1, 5))
                losses_neg.update(loss_neg.item(), data.size(0))
                top1_neg.update(acc1_neg[0], data.size(0))
                current_loss_negs += loss_neg

            output_pos = pos_net(recon_imgs)
            loss_pos = args.criterion_pos(output_pos, target)
            acc1_pos, acc5_pos = misc.accuracy(output_pos, target, topk=(1, 5))
            losses_pos.update(loss_pos.item(), data.size(0))
            top1_pos.update(acc1_pos[0], data.size(0))

        # reconstruction loss plus a ratio term; minimising it keeps the
        # positive network's loss low relative to the negative networks'
        loss = model.loss_function(data, *outputs) + (
            (loss_pos + current_loss_negs) / current_loss_negs)
        loss.backward()
        optimizer.step()
        latest_losses = model.latest_losses()
        for key in latest_losses:
            losses[key + '_train'] += float(latest_losses[key])
            epoch_losses[key + '_train'] += float(latest_losses[key])

        if batch_idx % log_interval == 0:
            for key in latest_losses:
                losses[key + '_train'] /= log_interval
            loss_string = ' '.join(
                ['{}: {:.6f}'.format(k, v) for k, v in losses.items()])
            logging.info('Train Epoch: {epoch} [{batch:5d}/{total_batch} '
                         '({percent:2d}%)]   time: {time:3.2f}   {loss}'
                         ' Lp: {loss_pos:.3f} Ap: {acc_pos:.3f}'
                         ' Ln: {loss_neg:.3f} An: {acc_neg:.3f}'.format(
                             epoch=epoch,
                             batch=batch_idx * len(data),
                             total_batch=max_len * len(data),
                             percent=int(100. * batch_idx / max_len),
                             time=time.time() - start_time,
                             loss=loss_string,
                             loss_pos=losses_pos.avg,
                             acc_pos=top1_pos.avg,
                             loss_neg=losses_neg.avg,
                             acc_neg=top1_neg.avg))
            start_time = time.time()
            for key in latest_losses:
                losses[key + '_train'] = 0
        # save reconstructions at a few fixed batches and at the epoch's end
        if batch_idx in [18, 180, 1650, max_len - 1]:
            save_reconstructed_images(data, epoch, outputs[0], save_path,
                                      'reconstruction_train%.5d' % batch_idx)
            write_images(data, outputs, writer, 'train', args.mean, args.std)

        if args.dataset in [
                'imagenet', 'coco', 'custom'
        ] and batch_idx * len(data) > args.max_epoch_samples:
            break

    for key in epoch_losses:
        if args.dataset != 'imagenet':
            epoch_losses[key] /= (max_len / data.shape[0])
        else:
            epoch_losses[key] /= (len(train_loader.dataset) /
                                  train_loader.batch_size)
    loss_string = '\t'.join(
        ['{}: {:.6f}'.format(k, v) for k, v in epoch_losses.items()])
    logging.info('====> Epoch: {} {}'.format(epoch, loss_string))
    # writer.add_histogram('dict frequency', outputs[3], bins=range(args.k + 1))
    # model.print_atom_hist(outputs[3])
    return epoch_losses
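The composite objective in Example #3 is the model's reconstruction loss plus (loss_pos + loss_negs) / loss_negs, which equals loss_pos / loss_negs + 1; minimising it therefore pushes the positive network's loss down and the negative networks' losses up. A tiny illustration with made-up values:

import torch

recon_loss = torch.tensor(0.25)         # stands in for model.loss_function(...)
loss_pos = torch.tensor(1.10)           # loss of the network to keep accurate
current_loss_negs = torch.tensor(2.40)  # summed losses of the nets to degrade
loss = recon_loss + (loss_pos + current_loss_negs) / current_loss_negs
print(loss)  # tensor(1.7083)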
Example #4
def _train_val(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = report_utils.AverageMeter()
    data_time = report_utils.AverageMeter()
    losses = report_utils.AverageMeter()
    top1 = report_utils.AverageMeter()

    is_train = optimizer is not None

    if is_train:
        train_test_str = 'Train'
        model.train()
        num_samples = args.train_samples
    else:
        train_test_str = 'Test'
        model.eval()
        num_samples = args.val_samples

    end = time.time()
    with torch.set_grad_enabled(is_train):
        for i, (kinematic, intensity, mass_dist, response,
                trial_name) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            kinematic = kinematic.cuda(args.gpu, non_blocking=True)
            if args.out_type == 'intensity':
                intensity = intensity.cuda(args.gpu, non_blocking=True)
                response = response.cuda(args.gpu, non_blocking=True)
            else:
                intensity = None
                response = mass_dist.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(kinematic, intensity)
            loss = criterion(output, response)

            # measure accuracy and record loss
            acc1 = report_utils.accuracy(output, response)
            losses.update(loss.item(), kinematic.size(0))
            top1.update(acc1[0].cpu().numpy()[0], kinematic.size(0))

            if is_train:
                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # printing the accuracy at certain intervals
            if i % args.print_freq == 0:
                print('{0}: [{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          train_test_str,
                          epoch,
                          i,
                          len(train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          top1=top1))
            if num_samples is not None and i * len(kinematic) > num_samples:
                break
        if not is_train:
            # printing the accuracy of the epoch
            print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return [epoch, batch_time.avg, losses.avg, top1.avg]
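_train_val folds training and evaluation into one loop: passing an optimizer enables gradients and the SGD step, while passing None switches the model to eval mode and disables gradients. Two hypothetical wrappers make the two call sites explicit; whether such wrappers exist in the source is an assumption.

def train_epoch(train_loader, model, criterion, optimizer, epoch, args):
    # optimizer present -> model.train(), gradients enabled, SGD steps
    return _train_val(train_loader, model, criterion, optimizer, epoch, args)


def validate_epoch(val_loader, model, criterion, epoch, args):
    # optimizer=None -> model.eval(), torch.set_grad_enabled(False)
    return _train_val(val_loader, model, criterion, None, epoch, args)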
Example #5
def train_on_data(train_loader, model, criterion, optimizer, epoch, args):
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    losses_obj = AverageMeter()
    top1_obj = AverageMeter()
    top5_obj = AverageMeter()
    losses_mun = AverageMeter()
    top1_mun = AverageMeter()
    top5_mun = AverageMeter()
    losses_ill = AverageMeter()
    top1_ill = AverageMeter()
    top5_ill = AverageMeter()

    # args.top_k is assumed to be set; with topks == (1, ) the two-way
    # unpacking of accuracy(...) below would fail
    if args.top_k is None:
        topks = (1, )
    else:
        topks = (1, args.top_k)

    # switch to train mode
    model.train()

    mean, std = model_utils.get_preprocessing_function(args.colour_space,
                                                       args.vision_type)
    normalise_inverse = cv2_transforms.NormalizeInverse(mean, std)
    normalise_back = transforms.Normalize(mean=mean, std=std)

    end = time.time()
    for i, (input_image, targets) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpus is not None:
            input_image = input_image.cuda(args.gpus, non_blocking=True)
        targets = targets.cuda(args.gpus, non_blocking=True)

        # compute output
        out_obj, out_mun, out_ill = model(input_image)

        if out_obj is None:
            loss_obj = 0
        else:
            loss_obj = criterion(out_obj, targets[:, 0])
            acc1_obj, acc5_obj = accuracy(out_obj, targets[:, 0], topk=topks)
            losses_obj.update(loss_obj.item(), input_image.size(0))
            top1_obj.update(acc1_obj[0], input_image.size(0))
            top5_obj.update(acc5_obj[0], input_image.size(0))
        if out_mun is None:
            loss_mun = 0
        else:
            loss_mun = criterion(out_mun, targets[:, 1])
            acc1_mun, acc5_mun = accuracy(out_mun, targets[:, 1], topk=topks)
            losses_mun.update(loss_mun.item(), input_image.size(0))
            top1_mun.update(acc1_mun[0], input_image.size(0))
            top5_mun.update(acc5_mun[0], input_image.size(0))
        if out_ill is None:
            loss_ill = 0
        else:
            loss_ill = criterion(out_ill, targets[:, 2])
            acc1_ill, acc5_ill = accuracy(out_ill, targets[:, 2], topk=topks)
            losses_ill.update(loss_ill.item(), input_image.size(0))
            top1_ill.update(acc1_ill[0], input_image.size(0))
            top5_ill.update(acc5_ill[0], input_image.size(0))

        # at least one head is assumed to be active, so the combined loss is
        # a tensor and .item()/.backward() below are valid
        loss = loss_obj + loss_mun + loss_ill
        losses.update(loss.item(), input_image.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # second pass on illuminant-corrected images; out_ill is assumed to
        # be available whenever this branch is taken
        if out_mun is None and args.ill_colour is not None:
            input_image2 = correct_image(normalise_inverse, normalise_back,
                                         input_image, out_ill, args.ill_colour)
            out_obj2, out_mun2, _ = model(input_image2)
            loss_mun2 = 0
            loss_obj2 = 0
            if out_mun2 is not None:
                loss_mun2 = criterion(out_mun2, targets[:, 1])
            if out_obj2 is not None:
                loss_obj2 = criterion(out_obj2, targets[:, 0])
            loss2 = loss_obj2 + loss_mun2
            optimizer.zero_grad()
            loss2.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # printing the accuracy at certain intervals
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.2f} ({batch_time.avg:.2f})\t'
                  'Data {data_time.val:.2f} ({data_time.avg:.2f})\t'
                  'Loss {loss.val:.2f} ({loss.avg:.2f})\t'
                  'LO {obj_loss.val:.2f} ({obj_loss.avg:.2f})\t'
                  'LM {mun_loss.val:.2f} ({mun_loss.avg:.2f})\t'
                  'LI {ill_loss.val:.2f} ({ill_loss.avg:.2f})\t'
                  'Ao {obj_acc.val:.2f} ({obj_acc.avg:.2f})\t'
                  'AM {mun_acc.val:.2f} ({mun_acc.avg:.2f})\t'
                  'AI {ill_acc.val:.2f} ({ill_acc.avg:.2f})'.format(
                      epoch,
                      i,
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      obj_loss=losses_obj,
                      mun_loss=losses_mun,
                      ill_loss=losses_ill,
                      obj_acc=top1_obj,
                      mun_acc=top1_mun,
                      ill_acc=top1_ill))
    return [
        epoch, batch_time.avg, losses.avg, losses_obj.avg, losses_mun.avg,
        losses_ill.avg, top1_obj.avg, top1_mun.avg, top1_ill.avg
    ]
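Examples #5 and #6 drive a three-head model (object, mun, illuminant) against an (N, 3) target matrix, where any head may return None to drop its loss term. Below is a toy stand-in that satisfies this contract; the head names, feature width, and class counts are placeholders, not the source's architecture.

import torch
from torch import nn


class ThreeHeadNet(nn.Module):
    def __init__(self, num_obj=10, num_mun=10, num_ill=10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Flatten(), nn.LazyLinear(64),
                                      nn.ReLU())
        self.head_obj = nn.Linear(64, num_obj)
        self.head_mun = nn.Linear(64, num_mun)
        self.head_ill = nn.Linear(64, num_ill)

    def forward(self, x):
        feats = self.backbone(x)
        # a head could be replaced by None here, which is exactly the case
        # the `if out_* is None` guards in the training loop cover
        return self.head_obj(feats), self.head_mun(feats), self.head_ill(feats)


# usage: out_obj, out_mun, out_ill = ThreeHeadNet()(torch.rand(4, 3, 32, 32))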
Example #6
def validate_on_data(val_loader, model, criterion, args):
    losses = AverageMeter()
    batch_time = AverageMeter()

    losses_obj = AverageMeter()
    top1_obj = AverageMeter()
    top5_obj = AverageMeter()
    losses_mun = AverageMeter()
    top1_mun = AverageMeter()
    top5_mun = AverageMeter()
    losses_ill = AverageMeter()
    top1_ill = AverageMeter()
    top5_ill = AverageMeter()

    # as in Example #5, args.top_k is assumed to be set
    if args.top_k is None:
        topks = (1, )
    else:
        topks = (1, args.top_k)

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input_image, targets) in enumerate(val_loader):
            if args.gpus is not None:
                input_image = input_image.cuda(args.gpus, non_blocking=True)
            targets = targets.cuda(args.gpus, non_blocking=True)

            # compute output
            out_obj, out_mun, out_ill = model(input_image)
            if out_obj is None:
                loss_obj = 0
            else:
                loss_obj = criterion(out_obj, targets[:, 0])
                acc1_obj, acc5_obj = accuracy(out_obj,
                                              targets[:, 0],
                                              topk=topks)
                losses_obj.update(loss_obj.item(), input_image.size(0))
                top1_obj.update(acc1_obj[0], input_image.size(0))
                top5_obj.update(acc5_obj[0], input_image.size(0))
            if out_mun is None:
                loss_mun = 0
            else:
                loss_mun = criterion(out_mun, targets[:, 1])
                acc1_mun, acc5_mun = accuracy(out_mun,
                                              targets[:, 1],
                                              topk=topks)
                losses_mun.update(loss_mun.item(), input_image.size(0))
                top1_mun.update(acc1_mun[0], input_image.size(0))
                top5_mun.update(acc5_mun[0], input_image.size(0))
            if out_ill is None:
                loss_ill = 0
            else:
                loss_ill = criterion(out_ill, targets[:, 2])
                acc1_ill, acc5_ill = accuracy(out_ill,
                                              targets[:, 2],
                                              topk=topks)
                losses_ill.update(loss_ill.item(), input_image.size(0))
                top1_ill.update(acc1_ill[0], input_image.size(0))
                top5_ill.update(acc5_ill[0], input_image.size(0))

            loss = loss_obj + loss_mun + loss_ill
            losses.update(loss.item(), input_image.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # printing the accuracy at certain intervals
            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.2f} ({batch_time.avg:.2f})\t'
                      'Loss {loss.val:.2f} ({loss.avg:.2f})\t'
                      'LO {obj_loss.val:.2f} ({obj_loss.avg:.2f})\t'
                      'LM {mun_loss.val:.2f} ({mun_loss.avg:.2f})\t'
                      'LI {ill_loss.val:.2f} ({ill_loss.avg:.2f})\t'
                      'Ao {obj_acc.val:.2f} ({obj_acc.avg:.2f})\t'
                      'AM {mun_acc.val:.2f} ({mun_acc.avg:.2f})\t'
                      'AI {ill_acc.val:.2f} ({ill_acc.avg:.2f})'.format(
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          obj_loss=losses_obj,
                          mun_loss=losses_mun,
                          ill_loss=losses_ill,
                          obj_acc=top1_obj,
                          mun_acc=top1_mun,
                          ill_acc=top1_ill))
        # printing the accuracy of the epoch
        print(' * AccObj {obj_acc.avg:.2f} AccMun {mun_acc.avg:.2f}'
              ' AccIll {ill_acc.avg:.2f}'.format(obj_acc=top1_obj,
                                                 mun_acc=top1_mun,
                                                 ill_acc=top1_ill))

    return [
        batch_time.avg, losses.avg, losses_obj.avg, losses_mun.avg,
        losses_ill.avg, top1_obj.avg, top1_mun.avg, top1_ill.avg
    ]