# Example #1
def validate(val_loader, model, encoder_learn):
    """Evaluate reconstruction quality (PSNR) of `model` over `val_loader`.

    Args:
        val_loader: yields (video_frames, pad_frame_size, patch_shape) per video.
        model: DataParallel-wrapped compressive-sensing model; its
            `module.measurements` holds the sensing weights.
        encoder_learn: if True, binarize the measurement weights before
            evaluation and restore the real-valued copies afterwards.

    Returns:
        Average PSNR over all videos in the loader.
    """
    batch_time = metrics.AverageMeter()
    psnr = metrics.AverageMeter()

    # switch to evaluate mode
    model.cuda()
    model.eval()

    # binarize weights (in-place on the measurement layer) before any forward pass
    if encoder_learn:
        model.module.measurements.binarization()

    end = time.time()
    for i, (video_frames, pad_frame_size,
            patch_shape) in enumerate(val_loader):
        video_input = video_frames.cuda()
        # name of the video currently being evaluated
        print(val_loader.dataset.videos[i])

        # compute output; the model needs this video's padded frame size and
        # patch shape set on the underlying module before the forward pass
        model.module.pad_frame_size = pad_frame_size.numpy()
        model.module.patch_shape = patch_shape.numpy()
        reconstructed_video, y = model(video_input)

        # original video
        reconstructed_video = reconstructed_video.cpu().data.numpy()
        original_video = video_input.cpu().data.numpy()

        # measure accuracy and record loss
        psnr_video = metrics.psnr_accuracy(reconstructed_video, original_video)
        psnr.update(psnr_video, video_frames.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        logging.info('Test: [{0}/{1}]\t'
                     'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                     'PSNR {psnr.val:.3f} ({psnr.avg:.3f})'.format(
                         i + 1,
                         len(val_loader),
                         batch_time=batch_time,
                         psnr=psnr))

    # restore real-valued weights so later training continues from them
    if encoder_learn:
        model.module.measurements.restore()

    print(' * PSNR {psnr.avg:.3f}'.format(psnr=psnr))

    return psnr.avg
# Example #2
def train(loader, models, optimizer, criterion, writer, epoch):
    """Run one epoch of reconstruction training for an encoder/decoder pair.

    Args:
        loader: iterable of (inputs, labels) batches; labels are ignored.
        models: (encoder, decoder) network pair.
        optimizer: optimizer updating both networks.
        criterion: reconstruction loss between decoder output and input.
        writer: summary writer (currently unused; scalar logging is
            commented out below).
        epoch: zero-based epoch index, used only for console output.
    """
    timer_batch = metrics.AverageMeter()
    timer_data = metrics.AverageMeter()
    recon_losses = metrics.AverageMeter()

    encoder, decoder = models

    encoder.train()
    decoder.train()

    num_batches = len(loader)
    tic = time.time()
    for step, (batch, _) in enumerate(loader):
        batch = batch.to(device)

        # time spent waiting on the data pipeline
        timer_data.update(time.time() - tic)

        n = batch.size(0)
        # Reconstruction: encode, decode, and regress against the input
        optimizer.zero_grad()

        recon = decoder(encoder(batch))[0]
        loss_recon = criterion(recon, batch)

        loss_recon.backward()
        optimizer.step()

        # wall-clock time for the full iteration
        timer_batch.update(time.time() - tic)
        tic = time.time()

        # record the running reconstruction loss
        recon_losses.update(loss_recon.item(), n)

        # global_step = (epoch * total_iter) + i + 1
        # writer.add_scalar('train/loss', losses.val, global_step)

        if step % 10 == 0:
            print('Epoch {0} [{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss_R {lossR.val:.4f} ({lossR.avg:.4f})\t'.format(
                      epoch + 1,
                      step + 1,
                      num_batches,
                      batch_time=timer_batch,
                      data_time=timer_data,
                      lossR=recon_losses))
# Example #3
def main():
    """Evaluate a pre-trained video compressive-sensing model on a test set.

    Parses command-line args, builds the model, binarizes the measurement
    weights if they are not already binary, reconstructs every test video,
    logs per-video and average PSNR, and optionally saves side-by-side
    (original | measurements | reconstruction) videos to disk.
    """
    global args
    args = parser.parse_args()

    # massage args: temporal block options plus the overlap flag appended last
    block_opts = []
    block_opts = args.block_opts
    block_opts.append(args.block_overlap)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    model = models.__dict__[args.arch](
        block_opts, pretrained=args.pretrained_net, mask_path=None, mean=args.mean, std=args.std,
        noise=args.noise, K=args.layers_k)
    model = torch.nn.DataParallel(model, device_ids=[args.gpu_id]).cuda()

    # switch to evaluate mode
    model.eval()
    cudnn.benchmark = True

    # Data loading code
    testdir = os.path.join(args.data)

    test_loader = torch.utils.data.DataLoader(
        datasets.videocs.VideoCS(testdir, block_opts, transforms.Compose([
            transforms.ToTensor(),
        ])),
        batch_size=1, shuffle=False,
        num_workers=0, pin_memory=True)

    batch_time = metrics.AverageMeter()
    psnr = metrics.AverageMeter()

    # binarize weights unless every entry is already exactly 0 or 1
    model_weights = model.module.measurements.weight.data
    if not ((model_weights == 0) | (model_weights == 1)).all():
        model.module.measurements.binarization()

    end = time.time()
    # FIX: the original `Variable(video_frames.cuda(async=True), volatile=True)`
    # is a SyntaxError on Python >= 3.7 (`async` became a reserved keyword) and
    # `Variable`/`volatile` are deprecated no-ops in modern PyTorch; use
    # `non_blocking=True` inside a `torch.no_grad()` block instead.
    with torch.no_grad():
        for i, (video_frames, pad_frame_size, patch_shape) in enumerate(test_loader):
            video_input = video_frames.cuda(non_blocking=True)
            print(test_loader.dataset.videos[i])

            # compute output; the module needs this video's geometry first
            model.module.pad_frame_size = pad_frame_size.numpy()
            model.module.patch_shape = patch_shape.numpy()
            reconstructed_video, y = model(video_input)

            # original video
            reconstructed_video = reconstructed_video.cpu().data.numpy()
            original_video = video_input.cpu().data.numpy()

            # measure accuracy and record loss
            psnr_video = metrics.psnr_accuracy(reconstructed_video, original_video)
            psnr.update(psnr_video, video_frames.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'PSNR {psnr.val:.3f} ({psnr.avg:.3f})'.format(
                      i + 1, len(test_loader), batch_time=batch_time,
                      psnr=psnr))

            if args.save_videos is not None:
                save_path = os.path.join(args.save_videos, args.save_format)
                if not os.path.exists(save_path):
                    os.makedirs(save_path)

                # expand each measurement frame along time so it lines up with
                # the temporal block it summarizes
                y_repeat = torch.zeros(
                    *y.size()).unsqueeze(2).repeat(1, 1, args.block_opts[0], 1, 1)
                for j in range(y.size(1)):
                    y_repeat[:, j, :, :, :] = y[:, j, :, :].repeat(
                        1, args.block_opts[0], 1, 1).data
                y_repeat = y_repeat.numpy()

                # collapse the leading (batch, block, frame) dims into one
                # frame axis; reconstruction and measurements are scaled by
                # their maxima for display
                original_video = np.reshape(
                    original_video, (original_video.shape[0] * original_video.shape[1] * original_video.shape[2], original_video.shape[3], original_video.shape[4]))
                reconstructed_video = np.reshape(reconstructed_video, (reconstructed_video.shape[0] * reconstructed_video.shape[1] *
                                                                       reconstructed_video.shape[2], reconstructed_video.shape[3], reconstructed_video.shape[4])) / np.max(reconstructed_video)
                y_repeat = np.reshape(y_repeat, (y_repeat.shape[0] * y_repeat.shape[1] *
                                                 y_repeat.shape[2], y_repeat.shape[3], y_repeat.shape[4])) / np.max(y_repeat)

                write_video(save_path, test_loader.dataset.videos[i], np.dstack(
                    (original_video, y_repeat, reconstructed_video)), args.save_format)

    print(' * PSNR {psnr.avg:.3f}'.format(psnr=psnr))
# Example #4
def validate(args, valid_loader, model, epoch=0, criterion=False, cur_step=0):
    """Evaluate `model` on `valid_loader`; return average top-1 accuracy.

    Tracks loss and top-1/5/10 precision. Supports single-process and
    distributed evaluation; when distributed, metrics are all-reduced across
    ranks and only rank 0 logs and writes TensorBoard scalars. With
    `args.PCB`, the model is expected to return 6 part logits whose softmax
    scores are summed for prediction and whose losses are summed.

    NOTE(review): depends on module-level globals `device`, `writer`,
    `logger` and, when distributed, `rank`, `world_size`, `dist` — confirm
    they are defined in this module.
    """
    print(
        '-------------------validation_start at epoch {}---------------------'.
        format(epoch))
    top1 = metrics.AverageMeter()
    top5 = metrics.AverageMeter()
    top10 = metrics.AverageMeter()
    losses = metrics.AverageMeter()

    model.eval()
    model.to(device)
    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X, y = X.to(device, non_blocking=True), y.to(device,
                                                         non_blocking=True)
            N = X.size(0)

            # skip incomplete batches; the per-rank batch is smaller when the
            # global batch is sharded across `world_size` processes
            if args.distributed:
                if N < int(args.batch_size // world_size):
                    continue
            else:
                if N < args.batch_size:  # skip the last batch
                    continue

            logits = model(X)

            if not args.PCB:
                _, preds = torch.max(logits.data, 1)
                loss = criterion(logits, y)
            else:
                # PCB: one classifier per body part; combine all 6 parts
                part = {}
                sm = nn.Softmax(dim=1)
                num_part = 6
                for i in range(num_part):
                    part[i] = logits[i]

                score = sm(part[0]) + sm(part[1]) + sm(part[2]) + sm(
                    part[3]) + sm(part[4]) + sm(part[5])
                _, preds = torch.max(score.data, 1)

                loss = criterion(part[0], y)
                for i in range(num_part - 1):
                    loss += criterion(part[i + 1], y)

            # accuracy uses the summed part scores in the PCB case
            if args.PCB:
                prec1, prec5, prec10 = metrics.accuracy(score,
                                                        y,
                                                        topk=(1, 5, 10))
            else:
                prec1, prec5, prec10 = metrics.accuracy(logits,
                                                        y,
                                                        topk=(1, 5, 10))

            # average loss/accuracy across ranks so all processes agree
            if args.distributed:
                dist.simple_sync.allreducemean_list(
                    [loss, prec1, prec5, prec10])

            losses.update(loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)
            top10.update(prec10.item(), N)

            if args.distributed:
                if rank == 0:
                    if step % args.print_freq == 0 or step == len(
                            valid_loader) - 1:
                        logger.info(
                            "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                            "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".
                            format(epoch + 1,
                                   args.epochs,
                                   step,
                                   len(valid_loader) - 1,
                                   losses=losses,
                                   top1=top1,
                                   top5=top5))

            else:
                if step % args.print_freq == 0 or step == len(
                        valid_loader) - 1:
                    logger.info(
                        "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1,
                            args.epochs,
                            step,
                            len(valid_loader) - 1,
                            losses=losses,
                            top1=top1,
                            top5=top5))

    if args.distributed:
        if rank == 0:
            writer.add_scalar('val/loss', losses.avg, cur_step)
            writer.add_scalar('val/top1', top1.avg, cur_step)
            writer.add_scalar('val/top5', top5.avg, cur_step)
            writer.add_scalar('val/top10', top10.avg, cur_step)

            logger.info(
                "Valid: [{:2d}/{}] Final Prec@1 {:.4%}, Prec@5 {:.4%}, Prec@10 {:.4%}"
                .format(epoch + 1, args.epochs, top1.avg, top5.avg, top10.avg))

    else:
        writer.add_scalar('val/loss', losses.avg, cur_step)
        writer.add_scalar('val/top1', top1.avg, cur_step)
        writer.add_scalar('val/top5', top5.avg, cur_step)
        writer.add_scalar('val/top10', top10.avg, cur_step)

        logger.info(
            "Valid: [{:2d}/{}] Final Prec@1 {:.4%}, Prec@5 {:.4%}, Prec@10 {:.4%}"
            .format(epoch + 1, args.epochs, top1.avg, top5.avg, top10.avg))

    return top1.avg
# Example #5
def train(args,
          train_loader,
          valid_loader,
          model,
          woptimizer,
          lr_scheduler,
          epoch=0,
          criterion=False):
    """Train `model` for one epoch over `train_loader`.

    Steps the LR scheduler, runs the optimization loop (with optional PCB
    part-based loss, warm-up loss scaling, fp16 loss scaling, gradient
    clipping, and distributed gradient/BN synchronization), logs progress,
    and periodically saves the network.

    NOTE(review): depends on module-level globals `device`, `use_gpu`,
    `writer`, `logger`, `save_network` and, when distributed, `rank`,
    `world_size`, `dist`; `valid_loader` is accepted but unused here.
    """
    print('-------------------training_start at epoch {}---------------------'.
          format(epoch))
    top1 = metrics.AverageMeter()
    top5 = metrics.AverageMeter()
    top10 = metrics.AverageMeter()
    losses = metrics.AverageMeter()

    cur_step = epoch * len(train_loader)

    lr_scheduler.step()
    lr = lr_scheduler.get_lr()[0]

    if args.distributed:
        if rank == 0:
            writer.add_scalar('train/lr', lr, cur_step)
    else:
        writer.add_scalar('train/lr', lr, cur_step)

    model.train()

    running_loss = 0.0
    running_corrects = 0.0
    step = 0

    # FIX: `warm_up` was read before ever being assigned inside the loop
    # (`min(1.0, warm_up + 0.9 / warm_iteration)`), raising UnboundLocalError
    # on the first warm-up iteration. Start the multiplier at 0.1 so it ramps
    # from 0.1 toward 1.0 over the warm-up period.
    warm_up = 0.1

    for samples, labels in train_loader:
        step = step + 1
        now_batch_size, c, h, w = samples.shape
        if now_batch_size < args.batch_size:  # skip the last batch
            continue

        if use_gpu:
            #samples = Variable(samples.cuda().detach())
            #labels = Variable(labels.cuda().detach())
            samples, labels = samples.to(device), labels.to(device)
        else:
            samples, labels = Variable(samples), Variable(labels)

        model.to(device)
        woptimizer.zero_grad()
        logits = model(samples)

        if not args.PCB:
            _, preds = torch.max(logits.data, 1)
            loss = criterion(logits, labels)
        else:
            # PCB: six part classifiers; sum softmax scores and losses
            part = {}
            sm = nn.Softmax(dim=1)
            num_part = 6
            for i in range(num_part):
                part[i] = logits[i]

            score = sm(part[0]) + sm(part[1]) + sm(part[2]) + sm(part[3]) + sm(
                part[4]) + sm(part[5])
            _, preds = torch.max(score.data, 1)

            loss = criterion(part[0], labels)
            for i in range(num_part - 1):
                loss += criterion(part[i + 1], labels)

        if epoch < args.warm_epoch and args.warm_up:
            # NOTE(review): len(train_loader) is already a count of batches;
            # dividing by batch_size again looks suspect — confirm the
            # intended warm-up length.
            warm_iteration = round(
                len(train_loader) /
                args.batch_size) * args.warm_epoch  # first 5 epoch
            warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
            loss *= warm_up

        if args.fp16:  # we use optimier to backward loss
            with amp.scale_loss(loss, woptimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if args.w_grad_clip != False:
            nn.utils.clip_grad_norm_(model.weights(), args.w_grad_clip)

        if args.distributed:
            dist.simple_sync.sync_grad_sum(model)

        woptimizer.step()

        if args.distributed:
            dist.simple_sync.sync_bn_stat(model)

        if args.PCB:
            prec1, prec5, prec10 = metrics.accuracy(score,
                                                    labels,
                                                    topk=(1, 5, 10))
        else:
            prec1, prec5, prec10 = metrics.accuracy(logits,
                                                    labels,
                                                    topk=(1, 5, 10))

        if args.distributed:
            dist.simple_sync.allreducemean_list([loss, prec1, prec5, prec10])

        losses.update(loss.item(), samples.size(0))
        top1.update(prec1.item(), samples.size(0))
        top5.update(prec5.item(), samples.size(0))
        top10.update(prec10.item(), samples.size(0))

        running_loss += loss.item() * now_batch_size

        #y_loss['train'].append(losses)
        #y_err['train'].append(1.0-top1)

        if args.distributed:
            if rank == 0:
                if step % args.print_freq == 0 or step == len(
                        train_loader) - 1:
                    logger.info(
                        "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1,
                            args.epochs,
                            step,
                            len(train_loader) - 1,
                            losses=losses,
                            top1=top1,
                            top5=top5))

                writer.add_scalar('train/loss', loss.item(), cur_step)
                writer.add_scalar('train/top1', prec1.item(), cur_step)
                writer.add_scalar('train/top5', prec5.item(), cur_step)
        else:
            if step % args.print_freq == 0 or step == len(train_loader) - 1:
                logger.info(
                    "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1,
                        args.epochs,
                        step,
                        len(train_loader) - 1,
                        losses=losses,
                        top1=top1,
                        top5=top5))

            writer.add_scalar('train/loss', loss.item(), cur_step)
            writer.add_scalar('train/top1', prec1.item(), cur_step)
            writer.add_scalar('train/top5', prec5.item(), cur_step)
            writer.add_scalar('train/top10', prec10.item(), cur_step)

        cur_step += 1
    if args.distributed:
        if rank == 0:
            logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
                epoch + 1, args.epochs, top1.avg))
    else:
        logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
            epoch + 1, args.epochs, top1.avg))

    if args.distributed:
        if rank == 0:
            if epoch % args.forcesave == 0:
                save_network(args, model, epoch, top1)
    else:
        if epoch % args.forcesave == 0:
            save_network(args, model, epoch, top1)
def validate(args, valid_loader, model, epoch=0, cur_step=0):
    """Evaluate `model` on `valid_loader`; return average top-1 accuracy.

    Uses the model's own `criterion` attribute for the loss. Supports both
    single-process and distributed evaluation; when distributed, metrics are
    all-reduced and only rank 0 logs / writes TensorBoard scalars.

    NOTE(review): depends on module-level globals `device`, `writer`,
    `logger` and, when distributed, `rank`, `world_size`, `dist`.
    """
    print(
        '-------------------validation_start at epoch {}---------------------'.
        format(epoch))
    top1 = metrics.AverageMeter()
    top5 = metrics.AverageMeter()
    top10 = metrics.AverageMeter()
    losses = metrics.AverageMeter()

    model.eval()
    model.to(device)
    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X, y = X.to(device, non_blocking=True), y.to(device,
                                                         non_blocking=True)
            N = X.size(0)

            ### Must branch on distributed mode here; otherwise the
            ### batch-skip condition is always true (per-rank batches are
            ### smaller than the global batch size).
            if args.distributed:
                if N < int(args.batch_size // world_size):
                    continue
            else:
                if N < args.batch_size:  # skip the last batch
                    continue

            logits = model(X)
            loss = model.criterion(logits, y)

            prec1, prec5, prec10 = metrics.accuracy(logits, y, topk=(1, 5, 10))

            # average loss/accuracy across ranks so all processes agree
            if args.distributed:
                dist.simple_sync.allreducemean_list(
                    [loss, prec1, prec5, prec10])

            losses.update(loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)
            top10.update(prec10.item(), N)

            if args.distributed:
                if rank == 0:
                    if step % args.print_freq == 0 or step == len(
                            valid_loader) - 1:
                        logger.info(
                            "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                            "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".
                            format(epoch + 1,
                                   args.epochs,
                                   step,
                                   len(valid_loader) - 1,
                                   losses=losses,
                                   top1=top1,
                                   top5=top5))

            else:
                if step % args.print_freq == 0 or step == len(
                        valid_loader) - 1:
                    logger.info(
                        "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1,
                            args.epochs,
                            step,
                            len(valid_loader) - 1,
                            losses=losses,
                            top1=top1,
                            top5=top5))

    if args.distributed:
        if rank == 0:
            writer.add_scalar('val/loss', losses.avg, cur_step)
            writer.add_scalar('val/top1', top1.avg, cur_step)
            writer.add_scalar('val/top5', top5.avg, cur_step)
            writer.add_scalar('val/top10', top10.avg, cur_step)

            logger.info(
                "Valid: [{:2d}/{}] Final Prec@1 {:.4%}, Prec@5 {:.4%}, Prec@10 {:.4%}"
                .format(epoch + 1, args.epochs, top1.avg, top5.avg, top10.avg))
    else:
        writer.add_scalar('val/loss', losses.avg, cur_step)
        writer.add_scalar('val/top1', top1.avg, cur_step)
        writer.add_scalar('val/top5', top5.avg, cur_step)
        writer.add_scalar('val/top10', top10.avg, cur_step)

        logger.info(
            "Valid: [{:2d}/{}] Final Prec@1 {:.4%}, Prec@5 {:.4%}, Prec@10 {:.4%}"
            .format(epoch + 1, args.epochs, top1.avg, top5.avg, top10.avg))

    return top1.avg
def train(args,
          train_loader,
          valid_loader,
          model,
          architect,
          w_optim,
          alpha_optim,
          lr_scheduler,
          epoch=0):
    """Run one epoch of DARTS-style bi-level architecture search.

    Per step: update architecture parameters (`alpha_optim`) via the
    architect's unrolled backward on a validation batch, then update network
    weights (`w_optim`) on the paired training batch, with optional gradient
    clipping and distributed gradient/BN synchronization.

    NOTE(review): depends on module-level globals `device`, `writer`,
    `logger`, `save_network` and, when distributed, `rank`, `world_size`,
    `dist`.
    """
    print('-------------------training_start at epoch {}---------------------'.
          format(epoch))
    top1 = metrics.AverageMeter()
    top5 = metrics.AverageMeter()
    top10 = metrics.AverageMeter()
    losses = metrics.AverageMeter()

    cur_step = epoch * len(train_loader)

    lr_scheduler.step()
    lr = lr_scheduler.get_lr()[0]

    if args.distributed:
        if rank == 0:
            writer.add_scalar('train/lr', lr, cur_step)
    else:
        writer.add_scalar('train/lr', lr, cur_step)

    model.train()

    running_loss = 0.0
    running_corrects = 0.0
    #step = 0
    model.to(device)

    # iterate train and valid loaders in lockstep: the train batch updates
    # weights, the valid batch updates architecture parameters
    for step, ((trn_X, trn_y),
               (val_X, val_y)) in enumerate(zip(train_loader, valid_loader)):
        #step = step+1
        now_batch_size, c, h, w = trn_X.shape
        trn_X, trn_y = trn_X.to(device,
                                non_blocking=True), trn_y.to(device,
                                                             non_blocking=True)
        val_X, val_y = val_X.to(device,
                                non_blocking=True), val_y.to(device,
                                                             non_blocking=True)

        # skip incomplete batches (per-rank threshold when distributed)
        if args.distributed:
            if now_batch_size < int(args.batch_size // world_size):
                continue
        else:
            if now_batch_size < args.batch_size:  # skip the last batch
                continue

        # architecture step: second-order (unrolled) gradient w.r.t. alphas
        alpha_optim.zero_grad()
        architect.unrolled_backward(trn_X, trn_y, val_X, val_y, lr, w_optim)
        alpha_optim.step()

        # weight step on the training batch
        w_optim.zero_grad()
        logits = model(trn_X)
        loss = model.criterion(logits, trn_y)
        loss.backward()

        # gradient clipping
        if args.w_grad_clip != False:
            nn.utils.clip_grad_norm_(model.weights(), args.w_grad_clip)

        if args.distributed:
            if args.sync_grad_sum:
                dist.sync_grad_sum(model)
            else:
                dist.sync_grad_mean(model)

        w_optim.step()

        if args.distributed:
            dist.sync_bn_stat(model)

        prec1, prec5, prec10 = metrics.accuracy(logits, trn_y, topk=(1, 5, 10))

        # average loss/accuracy across ranks so all processes agree
        if args.distributed:
            dist.simple_sync.allreducemean_list([loss, prec1, prec5, prec10])

        losses.update(loss.item(), now_batch_size)
        top1.update(prec1.item(), now_batch_size)
        top5.update(prec5.item(), now_batch_size)
        top10.update(prec10.item(), now_batch_size)

        #running_loss += loss.item() * now_batch_size

        #y_loss['train'].append(losses)
        #y_err['train'].append(1.0-top1)

        if args.distributed:
            if rank == 0:
                if step % args.print_freq == 0 or step == len(
                        train_loader) - 1:
                    logger.info(
                        "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1,
                            args.epochs,
                            step,
                            len(train_loader) - 1,
                            losses=losses,
                            top1=top1,
                            top5=top5))

                writer.add_scalar('train/loss', loss.item(), cur_step)
                writer.add_scalar('train/top1', prec1.item(), cur_step)
                writer.add_scalar('train/top5', prec5.item(), cur_step)
                writer.add_scalar('train/top10', prec10.item(), cur_step)
        else:
            if step % args.print_freq == 0 or step == len(train_loader) - 1:
                logger.info(
                    "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1,
                        args.epochs,
                        step,
                        len(train_loader) - 1,
                        losses=losses,
                        top1=top1,
                        top5=top5))

            writer.add_scalar('train/loss', loss.item(), cur_step)
            writer.add_scalar('train/top1', prec1.item(), cur_step)
            writer.add_scalar('train/top5', prec5.item(), cur_step)
            writer.add_scalar('train/top10', prec10.item(), cur_step)

        cur_step += 1
    if args.distributed:
        if rank == 0:
            logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
                epoch + 1, args.epochs, top1.avg))
    else:
        logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
            epoch + 1, args.epochs, top1.avg))

    if epoch % args.forcesave == 0:
        save_network(args, model, epoch, top1)
# Example #8
    def _reset_metrics(self):
        """Re-initialize the running-loss tracker (e.g. at epoch start)."""

        # a fresh AverageMeter discards all previously accumulated values
        self.total_loss = metrics.AverageMeter()
# Example #9
def train(train_loader, model, optimizer, epoch, mseloss, encoder_learn,
          gradient_clip):
    """Train the video compressive-sensing model for one epoch.

    Args:
        train_loader: yields (video_blocks, pad_block_size, block_shape).
        model: DataParallel-wrapped model; `module.measurements` holds the
            sensing weights, `module.reconstruction` the decoder.
        optimizer: optimizer over the trainable parameters.
        epoch: current epoch index (for logging only).
        mseloss: loss object exposing `compute_loss(output, target)`.
        encoder_learn: if True, binarize the measurement weights for the
            forward/backward pass and restore the real-valued copies before
            the optimizer step (straight-through-style training).
        gradient_clip: max gradient norm for clipping.

    Returns:
        Average loss over the epoch.

    NOTE(review): reads module-level globals `args` (for `print_freq`).
    """
    batch_time = metrics.AverageMeter()
    data_time = metrics.AverageMeter()
    losses = metrics.AverageMeter()
    # NOTE(review): `psnr` is never updated in this loop — appears unused
    psnr = metrics.AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (video_blocks, pad_block_size,
            block_shape) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = video_blocks.cuda()
        input_var = Variable(video_blocks.cuda())
        target_var = Variable(target)

        # compute output; the module needs this batch's geometry first
        model.module.pad_frame_size = pad_block_size.numpy()
        model.module.patch_shape = block_shape.numpy()

        # binarize the sensing weights only for the forward/backward pass
        if encoder_learn:
            model.module.measurements.binarization()

        output, y = model(input_var)
        loss = mseloss.compute_loss(output, target_var)
        # record loss
        losses.update(loss.item(), video_blocks.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()

        if encoder_learn:
            # restore real-valued weights before the optimizer step and clip
            # gradients of the whole model (encoder + decoder)
            model.module.measurements.restore()
            nn.utils.clip_grad_norm_(model.module.parameters(), gradient_clip)
        else:
            # encoder fixed: only the reconstruction network's gradients matter
            nn.utils.clip_grad_norm_(model.module.reconstruction.parameters(),
                                     gradient_clip)

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            logging.info('Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                             epoch,
                             i,
                             len(train_loader),
                             batch_time=batch_time,
                             data_time=data_time,
                             loss=losses))
    return losses.avg