def val(model, dataloader, mse_loss, rate_loss, ps):
    if ps:
        ps.log('Validating ... ')
    else:
        print('Validating ... ')

    model.eval()

    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    mse_loss_meter.reset()
    if opt.use_imp:
        rate_display_meter.reset()
        rate_loss_meter.reset()
        total_loss_meter.reset()

    for idx, data in enumerate(dataloader):
        # ps.log('%.0f%%' % (idx*100.0/len(dataloader)))
        # pdb.set_trace()  # 32*7+16 = 240
        val_data = Variable(data, volatile=True)
        # val_mask = Variable(mask, volatile=True)
        # val_o_mask = Variable(o_mask, volatile=True)
        # pdb.set_trace()
        if opt.use_gpu:
            val_data = val_data.cuda(async=True)
            # val_mask = val_mask.cuda(async=True)
            # val_o_mask = val_o_mask.cuda(async=True)

        # reconstructed, imp_mask_sigmoid = model(val_data, val_mask, val_o_mask)
        if opt.use_imp:
            reconstructed, imp_mask_sigmoid = model(val_data)
        else:
            reconstructed = model(val_data)

        # batch_loss = mse_loss(reconstructed, val_data, val_o_mask)
        batch_loss = mse_loss(reconstructed, val_data)

        batch_caffe_loss = batch_loss / (2 * opt.batch_size)
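        # Caffe-style EuclideanLoss normalization: assuming mse_loss was built
        # with size_average=False (as in the train() functions below), the loss
        # is a sum of squared errors, and dividing by (2 * batch_size) matches
        # Caffe's 1/(2N) * sum ||x - y||^2 convention.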

        if opt.use_imp and rate_loss:
            rate_loss_value = rate_loss(imp_mask_sigmoid)
            total_loss = batch_caffe_loss + rate_loss_value

        mse_loss_meter.add(batch_caffe_loss.data[0])

        if opt.use_imp:
            rate_loss_meter.add(rate_loss_value.data[0])
            rate_display_meter.add(imp_mask_sigmoid.data.mean())
            total_loss_meter.add(total_loss.data[0])

    if opt.use_imp:
        return (mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                total_loss_meter.value()[0], rate_display_meter.value()[0])
    else:
        return mse_loss_meter.value()[0]
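
# AverageValueMeter is assumed to come from torchnet (torchnet.meter) and to
# expose reset()/add()/value(), where value()[0] is the running mean.
# A minimal stand-in sketch under that assumption (hypothetical helper, only
# needed if torchnet is unavailable):
#
#     class AverageValueMeter(object):
#         def __init__(self):
#             self.reset()
#
#         def reset(self):
#             self.sum = 0.0
#             self.n = 0
#
#         def add(self, value, n=1):
#             self.sum += value * n
#             self.n += n
#
#         def value(self):
#             mean = self.sum / self.n if self.n else float('nan')
#             return mean, None  # (mean, std); std omitted in this sketch
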
def val(model, dataloader, mse_loss, rate_loss, ps):
    if ps:
        ps.log('validating ... ')
    else:
        print('run val ... ')
    model.eval()
    # avg_loss = 0
    # progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), ascii=True)
    # print(type(next(iter(dataloader))))
    # print(next(iter(dataloader)))
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    mse_loss_meter.reset()
    if opt.use_imp:
        rate_display_meter.reset()
        rate_loss_meter.reset()
        total_loss_meter.reset()

    for idx, (data, _) in enumerate(dataloader):
        # ps.log('%.0f%%' % (idx*100.0/len(dataloader)))
        val_input = Variable(data, volatile=True)
        if opt.use_gpu:
            val_input = val_input.cuda(async=True)

        reconstructed, imp_mask_sigmoid = model(val_input)

        batch_loss = mse_loss(reconstructed, val_input)
        batch_caffe_loss = batch_loss / (2 * opt.batch_size)

        if opt.use_imp and rate_loss:
            rate_loss_display = imp_mask_sigmoid
            rate_loss_value = rate_loss(rate_loss_display)
            total_loss = batch_caffe_loss + rate_loss_value

        mse_loss_meter.add(batch_caffe_loss.data[0])

        if opt.use_imp:
            rate_loss_meter.add(rate_loss_value.data[0])
            rate_display_meter.add(rate_loss_display.data.mean())
            total_loss_meter.add(total_loss.data[0])
        # progress_bar.set_description('val_iter %d: loss = %.2f' % (idx+1, batch_caffe_loss.data[0]))
        # progress_bar.set_description('val_iter %d: loss = %.2f' % (idx+1, total_loss_meter.value()[0] if opt.use_imp else mse_loss_meter.value()[0]))

        # avg_loss += batch_caffe_loss.data[0]

    # avg_loss /= len(dataloader)
    # print('avg_loss =', avg_loss)
    # print('meter loss =', loss_meter.value[0])
    # print ('Total avg loss = {}'.format(avg_loss))
    if opt.use_imp:
        return (mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                total_loss_meter.value()[0], rate_display_meter.value()[0])
    else:
        return mse_loss_meter.value()[0]
def val(compression_model,
        c_resnet_51,
        dataloader,
        class_loss,
        rate_loss=None,
        ps=None):
    if ps:
        ps.log('validating ... ')
    else:
        print('run val ... ')
    compression_model.eval()
    c_resnet_51.eval()

    class_loss_meter = AverageValueMeter()
    # if opt.use_imp:
    top5_acc_meter = AverageValueMeter()
    top1_acc_meter = AverageValueMeter()
    # total_loss_meter = AverageValueMeter()

    class_loss_meter.reset()
    # if opt.use_imp:
    top5_acc_meter.reset()
    top1_acc_meter.reset()
    # total_loss_meter.reset()

    for idx, (data, label) in enumerate(dataloader):
        #ps.log('%.0f%%' % (idx*100.0/len(dataloader)))
        val_input = Variable(data, volatile=True)
        label = Variable(label, volatile=True)

        if opt.use_gpu:
            val_input = val_input.cuda()
            label = label.cuda()

        # compressed_RGB = compression_model(val_input)

        compressed_feat = compression_model(val_input, need_decode=False)
        predicted = c_resnet_51(compressed_feat)

        val_class_loss = class_loss(predicted, label)
        class_loss_meter.add(val_class_loss.data[0])
        acc1, acc5 = accuracy(predicted.data, label.data, topk=(1, 5))

        top5_acc_meter.add(acc5[0])
        top1_acc_meter.add(acc1[0])

    return (class_loss_meter.value()[0], top5_acc_meter.value()[0],
            top1_acc_meter.value()[0])
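
# `accuracy` is assumed to be the standard top-k helper from the PyTorch
# ImageNet example; a sketch of the usual implementation (an assumption, the
# project may define its own variant):
#
#     def accuracy(output, target, topk=(1,)):
#         maxk = max(topk)
#         batch_size = target.size(0)
#         _, pred = output.topk(maxk, 1, True, True)
#         pred = pred.t()
#         correct = pred.eq(target.view(1, -1).expand_as(pred))
#         res = []
#         for k in topk:
#             correct_k = correct[:k].contiguous().view(-1).float().sum(0)
#             res.append(correct_k.mul_(100.0 / batch_size))
#         return res
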
def train(**kwargs):
    opt.parse(kwargs)
    # log file
    ps = PlotSaver("FrozenCNN_ResNet50_RGB_" +
                   time.strftime("%m_%d_%H:%M:%S") + ".log.txt")

    # step1: Model
    compression_model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        model_name="CWCNN_limu_ImageNet_imp_r={r}_γ={w}_for_resnet50".format(
            r=opt.rate_loss_threshold, w=opt.rate_loss_weight)
        if opt.use_imp else None)

    compression_model.load(None, opt.compression_model_ckpt)
    compression_model.eval()

    # if use_original_RGB:
    #     resnet_50 = resnet50()    # Official ResNet
    # else:
    #     resnet_50 = ResNet50()    # My ResNet
    c_resnet_51 = cResNet51()
    if opt.use_gpu:
        # compression_model.cuda()
        # resnet_50.cuda()
        compression_model = multiple_gpu_process(compression_model)
        c_resnet_51 = multiple_gpu_process(c_resnet_51)

    # freeze the compression network
    for param in compression_model.parameters():
        # print (param.requires_grad)
        param.requires_grad = False

    cudnn.benchmark = True
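    # cudnn.benchmark lets cuDNN auto-tune convolution algorithms, which helps
    # here because the input size is fixed (224x224 crops below).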

    # pdb.set_trace()

    # step2: Data
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    val_data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])

    train_data = datasets.ImageFolder(opt.train_data_root,
                                      train_data_transforms)
    val_data = datasets.ImageFolder(opt.val_data_root, val_data_transforms)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    class_loss = t.nn.CrossEntropyLoss()
    lr = opt.lr
    # optimizer = t.optim.Adam(resnet_50.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=opt.weight_decay)
    optimizer = t.optim.SGD(c_resnet_51.parameters(),
                            lr=lr,
                            momentum=opt.momentum,
                            weight_decay=opt.weight_decay)
    start_epoch = 0

    if opt.resume:
        start_epoch = c_resnet_51.module.load(
            None if opt.finetune else optimizer, opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    class_loss_meter = AverageValueMeter()

    class_acc_top5_meter = AverageValueMeter()
    class_acc_top1_meter = AverageValueMeter()

    # class_loss_meter = AverageMeter()
    # class_acc_top5_meter = AverageMeter()
    # class_acc_top1_meter = AverageMeter()

    # ps init

    ps.new_plot('train class loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_CE_loss")
    ps.new_plot('val class loss', 1, xlabel="epoch", ylabel="val_CE_loss")

    ps.new_plot('train top_5 acc',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_top5_acc")
    ps.new_plot('train top_1 acc',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_top1_acc")

    ps.new_plot('val top_5 acc', 1, xlabel="iteration", ylabel="val_top_5_acc")
    ps.new_plot('val top_1 acc', 1, xlabel="iteration", ylabel="val_top_1_acc")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):
        # per epoch avg loss meter
        class_loss_meter.reset()

        class_acc_top1_meter.reset()
        class_acc_top5_meter.reset()

        # cur_epoch_loss refresh every epoch

        ps.new_plot("cur epoch train class loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="cur_train_CE_loss")

        c_resnet_51.train()

        for idx, (data, label) in enumerate(train_dataloader):
            ipt = Variable(data)
            label = Variable(label)

            if opt.use_gpu:
                ipt = ipt.cuda()
                label = label.cuda()

            optimizer.zero_grad()
            # if not use_original_RGB:
            # compressed_RGB = compression_model(ipt)
            # else:
            # compressed_RGB = ipt
            # We only want the compressed features; no need to decode them.
            compressed_feat = compression_model(ipt, need_decode=False)
            # print ('RGB', compressed_RGB.requires_grad)
            predicted = c_resnet_51(compressed_feat)

            class_loss_ = class_loss(predicted, label)

            class_loss_.backward()
            optimizer.step()

            class_loss_meter.add(class_loss_.data[0])
            # class_loss_meter.update(class_loss_.data[0], ipt.size(0))

            acc1, acc5 = accuracy(predicted.data, label.data, topk=(1, 5))
            # pdb.set_trace()

            class_acc_top1_meter.add(acc1[0])
            class_acc_top5_meter.add(acc5[0])
            # class_acc_top1_meter.update(acc1[0], ipt.size(0))
            # class_acc_top5_meter.update(acc5[0], ipt.size(0))

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train class loss',
                    class_loss_meter.value()[0]
                    if opt.print_smooth else class_loss_.data[0])
                ps.add_point(
                    'cur epoch train class loss',
                    class_loss_meter.value()[0]
                    if opt.print_smooth else class_loss_.data[0])
                ps.add_point(
                    'train top_5 acc',
                    class_acc_top5_meter.value()[0]
                    if opt.print_smooth else acc5[0])
                ps.add_point(
                    'train top_1 acc',
                    class_acc_top1_meter.value()[0]
                    if opt.print_smooth else acc1[0])

                ps.log(
                    'Epoch %d/%d, Iter %d/%d, class loss = %.4f, top 5 acc = %.2f %%, top 1 acc  = %.2f %%, lr = %.8f'
                    % (epoch, opt.max_epoch, idx, len(train_dataloader),
                       class_loss_meter.value()[0],
                       class_acc_top5_meter.value()[0],
                       class_acc_top1_meter.value()[0], lr))
                # enter debug mode
                if os.path.exists(opt.debug_file):
                    pdb.set_trace()

        if use_data_parallel:
            c_resnet_51.module.save(optimizer, epoch)

        # plot before validating so progress is visible early
        # all epochs share the same image, so use the default ("") epoch
        ps.make_plot('train class loss')
        ps.make_plot('cur epoch train class loss', epoch)
        ps.make_plot("train top_5 acc")
        ps.make_plot("train top_1 acc")

        val_class_loss, val_top5_acc, val_top1_acc = val(
            compression_model, c_resnet_51, val_dataloader, class_loss, None,
            ps)

        ps.add_point('val class loss', val_class_loss)
        ps.add_point('val top_5 acc', val_top5_acc)
        ps.add_point('val top_1 acc', val_top1_acc)

        ps.make_plot('val class loss')
        ps.make_plot('val top_5 acc')
        ps.make_plot('val top_1 acc')

        ps.log(
            'Epoch:{epoch}, lr:{lr}, train_class_loss: {train_class_loss}, train_top5_acc: {train_top5_acc} %, train_top1_acc: {train_top1_acc} %, \
val_class_loss: {val_class_loss}, val_top5_acc: {val_top5_acc} %, val_top1_acc: {val_top1_acc} %'
            .format(epoch=epoch,
                    lr=lr,
                    train_class_loss=class_loss_meter.value()[0],
                    train_top5_acc=class_acc_top5_meter.value()[0],
                    train_top1_acc=class_acc_top1_meter.value()[0],
                    val_class_loss=val_class_loss,
                    val_top5_acc=val_top5_acc,
                    val_top1_acc=val_top1_acc))

        # adjust lr
        if epoch in opt.lr_decay_step_list:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
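        # PyTorch optimizers expose their hyperparameters through param_groups,
        # so writing each group's 'lr' in place is the standard way to decay
        # the learning rate during training.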
def train(**kwargs):
    opt.parse(kwargs)
    # log file
    ps = PlotSaver("Context_Baseline_p32_no_imp_plain_1_" +
                   time.strftime("%m_%d_%H:%M:%S") + ".log.txt")

    # step1: Model
    model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        model_name="Context_Baseline_p32_imp_r={r}_gama={w}_plain".format(
            r=opt.rate_loss_threshold, w=opt.rate_loss_weight)
        if opt.use_imp else "ContextBaseNoImpP32_Plain")

    if opt.use_gpu:
        model = multiple_gpu_process(model)

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    data_transforms = transforms.Compose([transforms.ToTensor(), normalize])
    train_data = datasets.ImageFolder(opt.train_data_root, data_transforms)
    val_data = datasets.ImageFolder(opt.val_data_root, data_transforms)
    # opt.batch_size  --> 1
    train_dataloader = DataLoader(train_data,
                                  1,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                1,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)

    if opt.use_imp:
        rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
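        # LimuRateLoss is this project's rate penalty; judging from its
        # arguments it presumably penalizes the mean of the importance map
        # whenever it exceeds rate_loss_threshold, scaled by rate_loss_weight
        # (an assumption, not confirmed by this file).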

    lr = opt.lr

    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    start_epoch = 0

    if opt.resume:
        if use_data_parallel:
            start_epoch = model.module.load(
                None if opt.finetune else optimizer, opt.resume, opt.finetune)
        else:
            start_epoch = model.load(None if opt.finetune else optimizer,
                                     opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):

        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            total_loss_meter = mse_loss_meter
        # cur_epoch_loss refresh every epoch

        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, mse_loss, rate_loss, ps)
            else:
                mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            # make plot
            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss:{val_mse_loss}'
                    .format(
                        epoch=epoch,
                        lr=lr,
                        # train_mse_loss = mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        model.train()

        def to_patches(x, patch_size):
            print(type(x))
            _, h, w = x.size()
            print('original h,w', h, w)
            dw = (patch_size - w % patch_size) if w % patch_size else 0
            dh = (patch_size - h % patch_size) if h % patch_size else 0

            x = F.pad(x, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))

            _, h, w = x.size()
            print(h, w)
            pdb.set_trace()

            tv.utils.save_image(x.data, 'test_images/padded_original_img.png')

            num_patch_x = w // patch_size
            num_patch_y = h // patch_size
            print(num_patch_x, num_patch_y)
            patches = []
            for i in range(num_patch_y):
                for j in range(num_patch_x):
                    patch = x[:, i * patch_size:(i + 1) * patch_size,
                              j * patch_size:(j + 1) * patch_size]
                    patches.append(patch.contiguous())
                if (j + 1) * patch_size < w:
                    # leftover columns (should not occur after padding);
                    # keep the remainder as an extra patch just in case
                    extra_patch = x[:, i * patch_size:(i + 1) * patch_size,
                                    (j + 1) * patch_size:w]
                    patches.append(extra_patch.contiguous())

            return patches
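
        # Worked example (hypothetical sizes): a 3x300x500 image with
        # patch_size = 32 is padded to 3x320x512 and yields 10 x 16 = 160
        # patches of shape 3x32x32.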

        # _ is the corresponding label; compression doesn't use it.
        for idx, (data, _) in enumerate(train_dataloader):

            if idx == 0:
                print('skip idx =', idx)
                continue
            # ipt = Variable(data[0])
            ipt = data[0]
            # if opt.use_gpu:
            #     # if not use_data_parallel:  # because ipt is also target, so we still need to transfer it to CUDA
            #         ipt = ipt.cuda(async = True)

            # ipt is a full image, so I need to split it into crops
            # so set batch_size = 1 for simplicity at first

            # pdb.set_trace()

            tv.utils.save_image(ipt, "test_imgs/original.png")
            patches = to_patches(ipt, opt.patch_size)

            for (i, p) in enumerate(patches):
                tv.utils.save_image(p, "test_imgs/%s.png" % i)

            pdb.set_trace()

            optimizer.zero_grad()
            if opt.use_imp:
                reconstructed, imp_mask_sigmoid = model(ipt)
            else:
                reconstructed = model(ipt)

            # print ('imp_mask_height', model.imp_mask_height)
            # pdb.set_trace()

            # print ('type recons', type(reconstructed.data))
            loss = mse_loss(reconstructed, ipt)
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                rate_loss_display = imp_mask_sigmoid
                rate_loss_ = rate_loss(rate_loss_display)

                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            # 1.
            total_loss.backward()
            # caffe_loss.backward()
            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(rate_loss_display.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                # print (rate_loss_display.data.mean())
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else rate_loss_display.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])
                # pdb.set_trace()
                # progress_bar.set_description('epoch %d/%d, loss = %.2f' % (epoch, opt.max_epoch, total_loss.data[0]))

                #  2.
                # ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' % (epoch, opt.max_epoch, idx, len(train_dataloader), total_loss_meter.value()[0], lr))
                # ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' % (epoch, opt.max_epoch, idx, len(train_dataloader), mse_loss_meter.value()[0], lr))

                # ps.log('loss = %f' % caffe_loss.data[0])
                # print(total_loss.data[0])
                # input('waiting......')

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))
                # enter debug mode
                if os.path.exists(opt.debug_file):
                    # import pdb
                    pdb.set_trace()

        if use_data_parallel:
            # print (type(model.module))
            # print (model)
            # print (type(model))
            model.module.save(optimizer, epoch)
        else:
            model.save(optimizer, epoch)

        # print ('case error', total_loss.data[0])
        # print ('smoothed error', total_loss_meter.value()[0])

        # plot before validating so progress is visible early
        # all epochs share the same image, so use the default ("") epoch
        ps.make_plot('train mse loss')
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        # val
        if opt.use_imp:
            mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                model, val_dataloader, mse_loss, rate_loss, ps)
        else:
            mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

        ps.add_point('val mse loss', mse_val_loss)
        if opt.use_imp:
            ps.add_point('val rate value', rate_val_display)
            ps.add_point('val rate loss', rate_val_loss)
            ps.add_point('val total loss', total_val_loss)

        # make plot
        # ps.make_plot('train mse loss', "")   # all epoch share a same img, so give "" to epoch
        # ps.make_plot('cur epoch train mse loss',epoch)
        ps.make_plot('val mse loss')

        if opt.use_imp:
            # ps.make_plot("train rate value","")
            # ps.make_plot("train rate loss","")
            # ps.make_plot("train total loss","")
            ps.make_plot('val rate value')
            ps.make_plot('val rate loss')
            ps.make_plot('val total loss')

        # log sth.
        if opt.use_imp:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        train_rate_loss=rate_loss_meter.value()[0],
                        train_total_loss=total_loss_meter.value()[0],
                        train_rate_display=rate_display_meter.value()[0],
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
        else:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        # Adaptively adjust lr: for the current lr, if the loss comes in above
        # the previous epoch's loss opt.tolerant_max times, decay the learning rate.
        # if loss_meter.value()[0] > previous_loss:
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Due to early stop anneal lr to', lr, 'at epoch',
                          epoch)
                    ps.log('Due to early stop anneal lr to %.10f at epoch %d' %
                           (lr, epoch))

            else:
                tolerant_now -= 1

        # if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
        #     same_lr_epoch = 0
        #     tolerant_now = 0
        #     lr = lr * opt.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr
        #     print ('Due to full epochs anneal lr to',lr,'at epoch',epoch)
        #     ps.log ('Due to full epochs anneal lr to %.10f at epoch %d' % (lr, epoch))

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # previous_loss = total_loss_meter.value()[0] if opt.use_imp else mse_loss_meter.value()[0]
        previous_loss = total_loss_meter.value()[0]
def val(model, dataloader, mse_loss, rate_loss, ps, epoch, show_inf_imgs=True):

    revert_transforms = transforms.Compose([
        transforms.Normalize((-1, -1, -1), (2, 2, 2)),
        transforms.Lambda(lambda tensor: t.clamp(tensor, 0.0, 1.0)),
    ])

    if ps:
        ps.log('validating ... ')
    else:
        print('run val ... ')
    model.eval()
    # avg_loss = 0
    # progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), ascii=True)

    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    mse_loss_meter.reset()
    if opt.use_imp:
        rate_loss_meter.reset()
        rate_display_meter.reset()
        total_loss_meter.reset()

    for idx, (data, _) in enumerate(dataloader):
        val_input = Variable(data, volatile=True)
        if opt.use_gpu:
            val_input = val_input.cuda(async=True)

        if opt.use_imp:
            reconstructed, imp_mask_sigmoid = model(val_input)
        else:
            reconstructed = model(val_input)

        batch_loss = mse_loss(reconstructed, val_input)
        batch_caffe_loss = batch_loss / (2 * opt.batch_size)

        if opt.use_imp and rate_loss:
            rate_loss_display = imp_mask_sigmoid
            rate_loss_value = rate_loss(rate_loss_display)
            total_loss = batch_caffe_loss + rate_loss_value

        mse_loss_meter.add(batch_caffe_loss.data[0])

        if opt.use_imp:
            rate_loss_meter.add(rate_loss_value.data[0])
            rate_display_meter.add(rate_loss_display.data.mean())
            total_loss_meter.add(total_loss.data[0])

        if show_inf_imgs:
            if idx > 5:
                continue
            reconstructed_imgs = revert_transforms(reconstructed.data[0])
            # pdb.set_trace()
            # print ('save imgs ...')

            dir_path = os.path.join(opt.save_inf_imgs_path, 'epoch_%d' % epoch)
            if not os.path.exists(dir_path):
                # print ('mkdir', dir_path)
                os.makedirs(dir_path)
            tv.utils.save_image(
                reconstructed_imgs,
                os.path.join(
                    dir_path, 'test_p{ps}_epoch{epoch}_idx{idx}.png'.format(
                        ps=opt.patch_size, epoch=epoch, idx=idx)))
            tv.utils.save_image(
                revert_transforms(val_input.data[0]),
                os.path.join(dir_path, 'origin_{idx}.png'.format(idx=idx)))

    if opt.use_imp:
        return (mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                total_loss_meter.value()[0], rate_display_meter.value()[0])
    else:
        return mse_loss_meter.value()[0]
def train(**kwargs):
    opt.parse(kwargs)
    # log file
    # tpse: test patch size effect
    ps = PlotSaver(opt.exp_id[4:] + "_" + time.strftime("%m_%d__%H:%M:%S") +
                   ".log.txt")

    # step1: Model
    model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        model_name="TPSE_p{ps}_imp_r={r}_gama={w}".format(
            ps=opt.patch_size,
            r=opt.rate_loss_threshold,
            w=opt.rate_loss_weight)
        if opt.use_imp else "TPSE_p{ps}_no_imp".format(ps=opt.patch_size))

    if opt.use_gpu:
        model = multiple_gpu_process(model)

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_data_transforms = transforms.Compose([
        transforms.RandomCrop(opt.patch_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])

    val_data_transforms = transforms.Compose(
        [transforms.CenterCrop(256),
         transforms.ToTensor(), normalize])

    train_data = datasets.ImageFolder(opt.train_data_root,
                                      train_data_transforms)
    val_data = datasets.ImageFolder(opt.val_data_root, val_data_transforms)

    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)
    if opt.use_imp:
        rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    start_epoch = 0

    if opt.resume:
        if use_data_parallel:
            start_epoch = model.module.load(
                None if opt.finetune else optimizer, opt.resume, opt.finetune)
        else:
            start_epoch = model.load(None if opt.finetune else optimizer,
                                     opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="epoch",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss', 1, xlabel="epoch", ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="epoch",
                    ylabel="val_total_loss")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):

        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            total_loss_meter = mse_loss_meter

        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, mse_loss, rate_loss, ps, -1, True)
            else:
                mse_val_loss = val(model, val_dataloader, mse_loss, None, ps,
                                   -1, True)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            # make plot
            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss:{val_mse_loss}'
                    .format(epoch=epoch, lr=lr, val_mse_loss=mse_val_loss))

        # start training
        model.train()

        # _ is the corresponding label; compression doesn't use it.
        for idx, (data, _) in enumerate(train_dataloader):

            ipt = Variable(data)

            if opt.use_gpu:
                ipt = ipt.cuda(async=True)

            optimizer.zero_grad()

            if opt.use_imp:
                reconstructed, imp_mask_sigmoid = model(ipt)
            else:
                reconstructed = model(ipt)

            loss = mse_loss(reconstructed, ipt)
            caffe_loss = loss / (2 * opt.batch_size)

            # Normalize relative to the reference 8x8 patch size: a patch of
            # side patch_size contains (patch_size / 8)^2 times as many pixels
            # as an 8x8 patch, so divide the summed loss by that factor to keep
            # losses comparable across patch sizes.
            caffe_loss = caffe_loss / ((opt.patch_size // 8)**2)
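            # e.g. patch_size = 32: (32 // 8)**2 = 16, so the summed loss over
            # a 32x32 patch is divided by an extra factor of 16, matching the
            # per-pixel scale of an 8x8 patch.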

            if opt.use_imp:
                rate_loss_display = imp_mask_sigmoid
                rate_loss_ = rate_loss(rate_loss_display)

                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            total_loss.backward()

            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(rate_loss_display.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                # print (rate_loss_display.data.mean())
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else rate_loss_display.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))
                # enter debug mode
                if os.path.exists(opt.debug_file):
                    # import pdb
                    pdb.set_trace()

        if use_data_parallel:
            model.module.save(optimizer, epoch)
        else:
            model.save(optimizer, epoch)

        # plot before validating so progress is visible early
        # all epochs share the same image, so use the default ("") epoch
        ps.make_plot('train mse loss')
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        # val
        if opt.use_imp:
            mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                model, val_dataloader, mse_loss, rate_loss, ps, epoch,
                epoch % opt.show_inf_imgs_T == 0)
        else:
            mse_val_loss = val(model, val_dataloader, mse_loss, None, ps,
                               epoch, epoch % opt.show_inf_imgs_T == 0)

        ps.add_point('val mse loss', mse_val_loss)
        if opt.use_imp:
            ps.add_point('val rate value', rate_val_display)
            ps.add_point('val rate loss', rate_val_loss)
            ps.add_point('val total loss', total_val_loss)

        ps.make_plot('val mse loss')

        if opt.use_imp:
            ps.make_plot('val rate value')
            ps.make_plot('val rate loss')
            ps.make_plot('val total loss')

        # log sth.
        if opt.use_imp:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        train_rate_loss=rate_loss_meter.value()[0],
                        train_total_loss=total_loss_meter.value()[0],
                        train_rate_display=rate_display_meter.value()[0],
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
        else:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        # Adaptively adjust lr: for the current lr, if the loss comes in above
        # the previous epoch's loss opt.tolerant_max times, decay the learning rate.
        # if loss_meter.value()[0] > previous_loss:
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Due to early stop anneal lr to', lr, 'at epoch',
                          epoch)
                    ps.log('Due to early stop anneal lr to %.10f at epoch %d' %
                           (lr, epoch))

            else:
                tolerant_now -= 1

        if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
            same_lr_epoch = 0
            tolerant_now = 0
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Due to full epochs anneal lr to', lr, 'at epoch', epoch)
            ps.log('Due to full epochs anneal lr to %.10f at epoch %d' %
                   (lr, epoch))

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        previous_loss = total_loss_meter.value()[0]
def train(**kwargs):
    opt.parse(kwargs)
    # vis = Visualizer(opt.env)
    # log file
    ps = PlotSaver("Train_ImageNet12_With_ImpMap_" +
                   time.strftime("%m_%d_%H:%M:%S") + ".log.txt")

    # step1: Model
    model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        model_name="CWCNN_limu_ImageNet_imp_r={r}_γ={w}".format(
            r=opt.rate_loss_threshold, w=opt.rate_loss_weight)
        if opt.use_imp else None)
    # if opt.use_imp else "test_pytorch")
    if opt.use_gpu:
        # model = multiple_gpu_process(model)
        model.cuda()


    # pdb.set_trace()

    cudnn.benchmark = True

    # step2: Data
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_data_transforms = transforms.Compose([
        transforms.Resize(256),
        #transforms.Scale(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_data_transforms = transforms.Compose([
        transforms.Resize(256),
        #transforms.Scale(256),
        transforms.CenterCrop(224),
        # transforms.TenCrop(224),
        # transforms.Lambda(lambda crops: t.stack(([normalize(transforms.ToTensor()(crop)) for crop in crops]))),
        transforms.ToTensor(),
        normalize
    ])
    # train_data = ImageNet_200k(opt.train_data_root, train=True, transforms=data_transforms)
    # val_data = ImageNet_200k(opt.val_data_root, train = False, transforms=data_transforms)
    train_data = datasets.ImageFolder(opt.train_data_root,
                                      train_data_transforms)
    val_data = datasets.ImageFolder(opt.val_data_root, val_data_transforms)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)

    if opt.use_imp:
        # rate_loss = RateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
        rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)

    lr = opt.lr

    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    start_epoch = 0

    if opt.resume:
        if hasattr(model, 'module'):
            start_epoch = model.module.load(
                None if opt.finetune else optimizer, opt.resume, opt.finetune)
        else:
            start_epoch = model.load(None if opt.finetune else optimizer,
                                     opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):
        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            total_loss_meter = mse_loss_meter
        # cur_epoch_loss refresh every epoch

        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")

        model.train()

        # _ is the corresponding label; compression doesn't use it.
        for idx, (data, _) in enumerate(train_dataloader):
            ipt = Variable(data)

            if opt.use_gpu:
                ipt = ipt.cuda()

            optimizer.zero_grad()  # don't forget to clear the gradients!
            reconstructed = model(ipt)

            # print ('reconstructed tensor size :', reconstructed.size())
            loss = mse_loss(reconstructed, ipt)
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                # print ('use data_parallel?',use_data_parallel)
                # pdb.set_trace()
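                # When nn.DataParallel wraps the model, the underlying network
                # (and its cached imp_mask_sigmoid attribute) lives on
                # model.module, hence the conditional attribute access below.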
                rate_loss_display = (model.module if use_data_parallel else
                                     model).imp_mask_sigmoid
                rate_loss_ = rate_loss(rate_loss_display)
                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            total_loss.backward()

            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(rate_loss_display.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else rate_loss_display.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])
                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))
                # enter debug mode
                if os.path.exists(opt.debug_file):
                    pdb.set_trace()

        # data parallel
        # if hasattr(model, 'module'):
        if use_data_parallel:
            model.module.save(optimizer, epoch)
        else:
            model.save(optimizer, epoch)

        # plot before validating so progress is visible early
        # all epochs share the same image, so use the default ("") epoch
        ps.make_plot('train mse loss')
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        # val
        if opt.use_imp:
            mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                model, val_dataloader, mse_loss, rate_loss, ps)
        else:
            mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

        ps.add_point('val mse loss', mse_val_loss)
        if opt.use_imp:
            ps.add_point('val rate value', rate_val_display)
            ps.add_point('val rate loss', rate_val_loss)
            ps.add_point('val total loss', total_val_loss)

        ps.make_plot('val mse loss')

        if opt.use_imp:
            ps.make_plot('val rate value')
            ps.make_plot('val rate loss')
            ps.make_plot('val total loss')

        # log sth.
        if opt.use_imp:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        train_rate_loss=rate_loss_meter.value()[0],
                        train_total_loss=total_loss_meter.value()[0],
                        train_rate_display=rate_display_meter.value()[0],
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
        else:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        # Adaptively adjust lr
        # For each lr: if the loss exceeds the previous epoch's loss opt.tolerant_max times, anneal the lr early.
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Anneal lr to', lr, 'at epoch', epoch,
                          'due to early stop.')
                    ps.log(
                        'Anneal lr to %.10f at epoch %d due to early stop.' %
                        (lr, epoch))

            else:
                tolerant_now -= 1

        if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
            same_lr_epoch = 0
            tolerant_now = 0
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Anneal lr to', lr, 'at epoch', epoch, 'due to full epochs.')
            ps.log('Anneal lr to %.10f at epoch %d due to full epochs.' %
                   (lr, epoch))

        previous_loss = total_loss_meter.value()[0]
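

# The AverageValueMeter used throughout these scripts comes from torchnet.meter.
# A minimal sketch of the interface the code above relies on (add / reset /
# value() -> (mean, std)) is shown purely for illustration; it is a stand-in,
# not the torchnet implementation.
class _AverageValueMeterSketch(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.sum_sq = 0.0
        self.n = 0

    def add(self, value):
        self.sum += value
        self.sum_sq += value * value
        self.n += 1

    def value(self):
        # only value()[0] (the running mean) is used by the training loops above
        if self.n == 0:
            return float('nan'), float('nan')
        mean = self.sum / self.n
        var = max(self.sum_sq / self.n - mean * mean, 0.0)
        return mean, var ** 0.5
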
def train(**kwargs):
    opt.parse(kwargs)
    # log file
    logfile_name = "Cmpr_with_YOLOv2_" + opt.exp_desc + time.strftime(
        "_%m_%d_%H:%M:%S") + ".log.txt"
    ps = PlotSaver(logfile_name)

    # step1: Model
    model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        n=opt.feat_num,
        input_4_ch=opt.input_4_ch,
        model_name="Cmpr_yolo_imp_" + opt.exp_desc + "_r={r}_gama={w}".format(
            r=opt.rate_loss_threshold, w=opt.rate_loss_weight)
        if opt.use_imp else "Cmpr_yolo_no_imp_" + opt.exp_desc)
    # pdb.set_trace()
    if opt.use_gpu:
        model = multiple_gpu_process(model)

    cudnn.benchmark = True

    # step2: Data
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_data_transforms = transforms.Compose([
        # transforms.RandomHorizontalFlip(),  TODO: reimplement this so it operates on the data and label simultaneously
        transforms.ToTensor(),
        normalize
    ])
    val_data_transforms = transforms.Compose(
        [transforms.ToTensor(), normalize])
    train_data = ImageCropWithBBoxMaskDataset(
        opt.train_data_list,
        train_data_transforms,
        contrastive_degree=opt.contrastive_degree,
        mse_bbox_weight=opt.input_original_bbox_weight)
    val_data = ImageCropWithBBoxMaskDataset(
        opt.val_data_list,
        val_data_transforms,
        contrastive_degree=opt.contrastive_degree,
        mse_bbox_weight=opt.input_original_bbox_weight)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)

    if opt.use_imp:
        # TODO: new rate loss
        rate_loss = RateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
        # rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)

    def weighted_mse_loss(input, target, weight):
        # weight[weight!=opt.mse_bbox_weight] = 1
        # weight[weight==opt.mse_bbox_weight] = opt.mse_bbox_weight
        # print('max val', weight.max())
        # return mse_loss(input, target)
        # weight_clone = weight.clone()
        # weight_clone[weight_clone == opt.input_original_bbox_weight] = 0
        # return t.sum(weight_clone * (input - target) ** 2)
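        # Pixels whose weight equals opt.input_original_bbox_inner (the bbox-interior
        # marker) are up-weighted to opt.mse_bbox_weight; every other pixel keeps
        # weight 1, so outside the bbox this stays a plain sum of squared errors.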
        weight_clone = t.ones_like(weight)
        weight_clone[weight ==
                     opt.input_original_bbox_inner] = opt.mse_bbox_weight
        return t.sum(weight_clone * (input - target)**2)

    def yolo_rate_loss(imp_map, mask_r):
        return rate_loss(imp_map)
        # V2 contrastive_degree must be 0!
        # return YoloRateLossV2(mask_r, opt.rate_loss_threshold, opt.rate_loss_weight)(imp_map)

    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    start_epoch = 0
    decay_file_create_time = -1  # track the decay file's mtime so the same file does not repeatedly anneal the lr

    if opt.resume:
        if use_data_parallel:
            start_epoch = model.module.load(
                None if opt.finetune else optimizer, opt.resume, opt.finetune)
        else:
            start_epoch = model.load(None if opt.finetune else optimizer,
                                     opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):

        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            total_loss_meter = mse_loss_meter
        # cur_epoch_loss refresh every epoch
        # vis.refresh_plot('cur epoch train mse loss')
        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")
        # progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), ascii=True)
        # progress_bar.set_description('epoch %d/%d, loss = 0.00' % (epoch, opt.max_epoch))

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, weighted_mse_loss, yolo_rate_loss,
                    ps)
            else:
                mse_val_loss = val(model, val_dataloader, weighted_mse_loss,
                                   None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss:{val_mse_loss}'
                    .format(epoch=epoch, lr=lr, val_mse_loss=mse_val_loss))

        if opt.only_init_val:
            print('Only Init Val Over!')
            return

        model.train()

        if epoch == start_epoch + 1:
            print('Start training, please inspect log file %s!' % logfile_name)
        # mask is the detection bounding box mask
        for idx, (data, mask, o_mask) in enumerate(train_dataloader):

            # pdb.set_trace()

            data = Variable(data)
            mask = Variable(mask)
            o_mask = Variable(o_mask, requires_grad=False)

            if opt.use_gpu:
                data = data.cuda(async=True)
                mask = mask.cuda(async=True)
                o_mask = o_mask.cuda(async=True)

            # pdb.set_trace()

            optimizer.zero_grad()
            reconstructed, imp_mask_sigmoid = model(data, mask, o_mask)

            # print ('imp_mask_height', model.imp_mask_height)
            # pdb.set_trace()

            # print ('type recons', type(reconstructed.data))

            loss = weighted_mse_loss(reconstructed, data, o_mask)
            # loss = mse_loss(reconstructed, data)
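            # Match Caffe's EuclideanLoss convention: sum of squared errors divided by 2N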
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                rate_loss_display = imp_mask_sigmoid
                # rate_loss_ =  rate_loss(rate_loss_display)
                rate_loss_ = yolo_rate_loss(rate_loss_display, mask)
                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            total_loss.backward()
            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(rate_loss_display.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else rate_loss_display.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))

                # enter debug mode
                if os.path.exists(opt.debug_file):
                    pdb.set_trace()

        if epoch % opt.save_interval == 0:
            print('Save checkpoint file of epoch %d.' % epoch)
            if use_data_parallel:
                model.module.save(optimizer, epoch)
            else:
                model.save(optimizer, epoch)

        ps.make_plot('train mse loss')
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        if epoch % opt.eval_interval == 0:

            print('Validating ...')
            # val
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, weighted_mse_loss, yolo_rate_loss,
                    ps)
            else:
                mse_val_loss = val(model, val_dataloader, weighted_mse_loss,
                                   None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
    val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            train_mse_loss=mse_loss_meter.value()[0],
                            train_rate_loss=rate_loss_meter.value()[0],
                            train_total_loss=total_loss_meter.value()[0],
                            train_rate_display=rate_display_meter.value()[0],
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                    .format(epoch=epoch,
                            lr=lr,
                            train_mse_loss=mse_loss_meter.value()[0],
                            val_mse_loss=mse_val_loss))

        # Adaptively adjust lr
        # For each lr: if the loss exceeds the previous epoch's loss opt.tolerant_max times, anneal the lr early.
        # update learning rate
        # if loss_meter.value()[0] > previous_loss:
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Due to early stop anneal lr to %.10f at epoch %d' %
                          (lr, epoch))
                    ps.log('Due to early stop anneal lr to %.10f at epoch %d' %
                           (lr, epoch))

            else:
                tolerant_now -= 1

        if epoch % opt.lr_anneal_epochs == 0:
            # if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
            same_lr_epoch = 0
            tolerant_now = 0
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Anneal lr to %.10f at epoch %d due to full epochs.' %
                  (lr, epoch))
            ps.log('Anneal lr to %.10f at epoch %d due to full epochs.' %
                   (lr, epoch))

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            cur_mtime = os.path.getmtime(opt.lr_decay_file)
            if cur_mtime > decay_file_create_time:
                decay_file_create_time = cur_mtime
                lr = lr * opt.lr_decay
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
                print(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))
                ps.log(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))

        previous_loss = total_loss_meter.value()[0]
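

# Every annealing branch above repeats the same param_group loop. A small helper
# (a sketch; the original code does not define one) would factor it out:
def set_lr(optimizer, lr):
    """Set the learning rate of every parameter group to lr."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Usage inside an annealing branch would then be:
#     lr = lr * opt.lr_decay
#     set_lr(optimizer, lr)
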
def test(model, dataloader, test_batch_size, mse_loss, rate_loss):

    model.eval()
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))

    if opt.save_test_img:
        if not os.path.exists(opt.test_imgs_save_path):
            os.makedirs(opt.test_imgs_save_path)
            print('Makedir %s for save test images!' % opt.test_imgs_save_path)

    eval_loss = True  # mse_loss, rate_loss, rate_disp
    eval_on_RGB = True  # RGB_mse, RGB_psnr

    revert_transforms = transforms.Compose([
        transforms.Normalize((-1, -1, -1), (2, 2, 2)),
        # transforms.Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225], [1/0.229, 1/0.224, 1/0.225]),
        transforms.Lambda(lambda tensor: t.clamp(tensor, 0.0, 1.0)),
        transforms.ToPILImage()
    ])
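    # Normalize((-1, -1, -1), (2, 2, 2)) computes (x + 1) / 2, undoing the
    # Normalize(mean=0.5, std=0.5) applied when the images were loaded.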

    # cast to float first: np.array(PIL image) yields uint8, which would wrap around on subtraction
    mse = lambda x, y: np.mean(np.square(y.astype(np.float64) - x.astype(np.float64)))
    psnr = lambda x, y: 10 * math.log10(255.**2 / mse(x, y))

    mse_meter = AverageValueMeter()
    psnr_meter = AverageValueMeter()
    rate_disp_meter = AverageValueMeter()
    caffe_loss_meter = AverageValueMeter()
    rate_loss_meter = AverageValueMeter()

    if opt.use_imp:
        total_loss_meter = AverageValueMeter()
    else:
        total_loss_meter = caffe_loss_meter

    for idx, data in progress_bar:
        test_data = Variable(data, volatile=True)
        # test_mask = Variable(mask, volatile=True)
        # pdb.set_trace()
        # mask[mask==1] = 0
        # mask[mask==opt.contrastive_degree] = 1
        # print ('type.mask', type(mask))

        # o_mask_as_weight = o_mask.clone()
        # pdb.set_trace()
        # bbox_inner = (o_mask == opt.mse_bbox_weight)
        # bbox_outer = (o_mask == 1)
        # o_mask[bbox_inner] = 1
        # o_mask[bbox_outer] = opt.mse_bbox_weight
        # print (o_mask)
        # pdb.set_trace()
        # o_mask[...] = 1
        # test_o_mask = Variable(o_mask, volatile=True)
        # test_o_mask_as_weight = Variable(o_mask_as_weight, volatile=True)

        # pdb.set_trace()
        if opt.use_gpu:
            test_data = test_data.cuda(async=True)
            # test_mask = test_mask.cuda(async=True)
            # test_o_mask = test_o_mask.cuda(async=True)
            # test_o_mask_as_weight = test_o_mask_as_weight.cuda(async=True)
            # o_mask = o_mask.cuda(async=True)

        # pdb.set_trace()
        if opt.use_imp:
            reconstructed, imp_mask_sigmoid = model(test_data)
        else:
            reconstructed = model(test_data)

        # only save the first image of each batch
        img_origin = revert_transforms(test_data.data.cpu()[0])
        img_reconstructed = revert_transforms(reconstructed.data.cpu()[0])

        if opt.save_test_img:
            if opt.use_imp:
                imp_map = transforms.ToPILImage()(
                    imp_mask_sigmoid.data.cpu()[0])
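                # the importance map is spatially downsampled, so upscale it 8x for easier viewing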
                imp_map = imp_map.resize(
                    (imp_map.size[0] * 8, imp_map.size[1] * 8))
                imp_map.save(
                    os.path.join(opt.test_imgs_save_path, "%d_imp.png" % idx))
            img_origin.save(
                os.path.join(opt.test_imgs_save_path, "%d_origin.png" % idx))
            img_reconstructed.save(
                os.path.join(opt.test_imgs_save_path, "%d_reconst.png" % idx))

        if eval_loss:
            mse_loss_v = mse_loss(reconstructed, test_data)

            caffe_loss_v = mse_loss_v / (2 * test_batch_size)

            caffe_loss_meter.add(caffe_loss_v.data[0])
            if opt.use_imp:
                rate_disp_meter.add(imp_mask_sigmoid.data.mean())

                assert rate_loss is not None
                rate_loss_v = rate_loss(imp_mask_sigmoid)
                rate_loss_meter.add(rate_loss_v.data[0])

                total_loss_v = caffe_loss_v + rate_loss_v
                total_loss_meter.add(total_loss_v.data[0])

        if eval_on_RGB:
            origin_arr = np.array(img_origin)
            reconst_arr = np.array(img_reconstructed)
            RGB_mse_v = mse(origin_arr, reconst_arr)
            RGB_psnr_v = psnr(origin_arr, reconst_arr)
            mse_meter.add(RGB_mse_v)
            psnr_meter.add(RGB_psnr_v)

    if eval_loss:
        print(
            'avg_mse_loss = {m_l}, avg_rate_loss = {r_l}, avg_rate_disp = {r_d}, avg_tot_loss = {t_l}'
            .format(m_l=caffe_loss_meter.value()[0],
                    r_l=rate_loss_meter.value()[0],
                    r_d=rate_disp_meter.value()[0],
                    t_l=total_loss_meter.value()[0]))
    if eval_on_RGB:
        print('RGB avg mse = {mse}, RGB avg psnr = {psnr}'.format(
            mse=mse_meter.value()[0], psnr=psnr_meter.value()[0]))
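

# The mse/psnr lambdas inside test() work on uint8 arrays taken from PIL images,
# hence the cast to float before subtraction. A self-contained sketch of the same
# metrics for 8-bit RGB arrays (the names rgb_mse / rgb_psnr are illustrative, not
# from this repo; imports are repeated so the sketch stands alone):
import math

import numpy as np


def rgb_mse(x, y):
    """Mean squared error between two uint8 image arrays."""
    x = x.astype(np.float64)
    y = y.astype(np.float64)
    return np.mean(np.square(y - x))


def rgb_psnr(x, y, max_val=255.0):
    """Peak signal-to-noise ratio in dB, assuming pixel values in [0, max_val]."""
    return 10.0 * math.log10(max_val ** 2 / rgb_mse(x, y))
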
def train(**kwargs):
    opt.parse(kwargs)
    # vis = Visualizer(opt.env)
    # log file
    ps = PlotSaver("CLIC_finetune_With_TestSet_0.13_" +
                   time.strftime("%m_%d_%H:%M:%S") + ".log.txt")

    # step1: Model

    # CWCNN_limu_ImageNet_imp_
    model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        model_name="CLIC_imp_ft_testSet_2_r={r}_gama={w}".format(
            r=opt.rate_loss_threshold, w=opt.rate_loss_weight)
        if opt.use_imp else None)
    # if opt.use_imp else "test_pytorch")
    if opt.use_gpu:
        model = multiple_gpu_process(model)
        # model.cuda()  # because I want model.imp_mask_sigmoid

    # freeze the decoder network, and finetune the enc + imp with train + val set
    if opt.freeze_decoder:
        for param in model.module.decoder.parameters():
            # print (param.requires_grad)
            param.requires_grad = False
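    # With requires_grad set to False here, the Adam optimizer built below with
    # filter(lambda p: p.requires_grad, ...) skips the frozen decoder parameters.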

    cudnn.benchmark = True

    # step2: Data
    # normalize = transforms.Normalize(
    #                 mean=[0.485, 0.456, 0.406],
    #                 std=[0.229, 0.224, 0.225]
    #                 )

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # train_data_transforms = transforms.Compose(
    #     [
    #         transforms.Resize(256),
    #         # transforms.RandomCrop(128),
    #         # transforms.Resize((128,128)),
    #         # transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         # transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    #         normalize
    #     ]
    # )
    # val_data_transforms = transforms.Compose(
    #     [
    #         # transforms.Resize(256),
    #         # transforms.CenterCrop(128),
    #         # transforms.Resize((128,128)),
    #         # transforms.TenCrop(224),
    #         # transforms.Lambda(lambda crops: t.stack(([normalize(transforms.ToTensor()(crop)) for crop in crops]))),
    #         transforms.ToTensor(),
    #         normalize
    #     ]
    # )

    train_data_transforms = transforms.Compose([
        transforms.Resize(256),
        #transforms.Scale(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_data_transforms = transforms.Compose([
        transforms.Resize(256),
        #transforms.Scale(256),
        transforms.CenterCrop(224),
        # transforms.TenCrop(224),
        # transforms.Lambda(lambda crops: t.stack(([normalize(transforms.ToTensor()(crop)) for crop in crops]))),
        transforms.ToTensor(),
        normalize
    ])
    # train_data = ImageNet_200k(opt.train_data_root, train=True, transforms=data_transforms)
    # val_data = ImageNet_200k(opt.val_data_root, train = False, transforms=data_transforms)
    train_data = datasets.ImageFolder(opt.train_data_root,
                                      train_data_transforms)
    val_data = datasets.ImageFolder(opt.val_data_root, val_data_transforms)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    # train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers)

    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)
    # val_dataloader = DataLoader(val_data, opt.batch_size, shuffle=False, num_workers=opt.num_workers)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)

    if opt.use_imp:
        # rate_loss = RateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
        rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)

    lr = opt.lr

    # print ('model.parameters():')
    # print (model.parameters())

    # optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
    optimizer = t.optim.Adam(filter(lambda p: p.requires_grad,
                                    model.parameters()),
                             lr=lr,
                             betas=(0.9, 0.999))

    start_epoch = 0

    if opt.resume:
        if use_data_parallel:
            start_epoch = model.module.load(
                None if opt.finetune else optimizer, opt.resume, opt.finetune)
        else:
            start_epoch = model.load(None if opt.finetune else optimizer,
                                     opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):

        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            total_loss_meter = mse_loss_meter
        # cur_epoch_loss refresh every epoch
        # vis.refresh_plot('cur epoch train mse loss')
        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")
        # progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), ascii=True)
        # progress_bar.set_description('epoch %d/%d, loss = 0.00' % (epoch, opt.max_epoch))

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, mse_loss, rate_loss, ps)
            else:
                mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            # make plot
            # ps.make_plot('train mse loss', "")   # all epoch share a same img, so give "" to epoch
            # ps.make_plot('cur epoch train mse loss',epoch)
            ps.make_plot('val mse loss')

            if opt.use_imp:
                # ps.make_plot("train rate value","")
                # ps.make_plot("train rate loss","")
                # ps.make_plot("train total loss","")
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss:{val_mse_loss}'
                    .format(
                        epoch=epoch,
                        lr=lr,
                        # train_mse_loss = mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        model.train()

        # mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

        # _ is the class label; compression does not use it.
        for idx, (data, _) in enumerate(train_dataloader):
            ipt = Variable(data)

            if opt.use_gpu:
                # if not use_data_parallel:  # ipt is also the reconstruction target, so it must be moved to CUDA either way
                ipt = ipt.cuda(async=True)

            optimizer.zero_grad()
            reconstructed, imp_mask_sigmoid = model(ipt)

            # print ('imp_mask_height', model.imp_mask_height)
            # pdb.set_trace()

            # print ('type recons', type(reconstructed.data))
            loss = mse_loss(reconstructed, ipt)
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                # rate_loss_display = model.imp_mask_sigmoid
                # rate_loss_display = (model.module if use_data_parallel else model).imp_mask_sigmoid
                rate_loss_display = imp_mask_sigmoid
                # rate_loss_display = model.imp_mask
                # print ('type of display', type(rate_loss_display.data))
                rate_loss_ = rate_loss(rate_loss_display)
                # print (
                #     'type of rate_loss_value',
                #     type(rate_loss_value.data)
                # )
                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            # 1.
            total_loss.backward()
            # caffe_loss.backward()
            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(rate_loss_display.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                # print (rate_loss_display.data.mean())
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else rate_loss_display.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])
                # pdb.set_trace()
                # progress_bar.set_description('epoch %d/%d, loss = %.2f' % (epoch, opt.max_epoch, total_loss.data[0]))

                #  2.
                # ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' % (epoch, opt.max_epoch, idx, len(train_dataloader), total_loss_meter.value()[0], lr))
                # ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' % (epoch, opt.max_epoch, idx, len(train_dataloader), mse_loss_meter.value()[0], lr))

                # ps.log('loss = %f' % caffe_loss.data[0])
                # print(total_loss.data[0])
                # input('waiting......')

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))
                # enter debug mode
                if os.path.exists(opt.debug_file):
                    # import pdb
                    pdb.set_trace()

        if use_data_parallel:
            # print (type(model.module))
            # print (model)
            # print (type(model))
            model.module.save(optimizer, epoch)
        else:
            model.save(optimizer, epoch)

        # print ('case error', total_loss.data[0])
        # print ('smoothed error', total_loss_meter.value()[0])

        # plot before validation so the training curves update without waiting for val
        ps.make_plot('train mse loss'
                     )  # all epochs share one image, so keep the default ("") epoch argument
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        # val
        if opt.use_imp:
            mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                model, val_dataloader, mse_loss, rate_loss, ps)
        else:
            mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

        ps.add_point('val mse loss', mse_val_loss)
        if opt.use_imp:
            ps.add_point('val rate value', rate_val_display)
            ps.add_point('val rate loss', rate_val_loss)
            ps.add_point('val total loss', total_val_loss)

        # make plot
        # ps.make_plot('train mse loss', "")   # all epoch share a same img, so give "" to epoch
        # ps.make_plot('cur epoch train mse loss',epoch)
        ps.make_plot('val mse loss')

        if opt.use_imp:
            # ps.make_plot("train rate value","")
            # ps.make_plot("train rate loss","")
            # ps.make_plot("train total loss","")
            ps.make_plot('val rate value')
            ps.make_plot('val rate loss')
            ps.make_plot('val total loss')

        # log sth.
        if opt.use_imp:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        train_rate_loss=rate_loss_meter.value()[0],
                        train_total_loss=total_loss_meter.value()[0],
                        train_rate_display=rate_display_meter.value()[0],
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
        else:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        # Adaptively adjust lr
        # For each lr: if the loss exceeds the previous epoch's loss opt.tolerant_max times, anneal the lr early.
        # update learning rate
        # if loss_meter.value()[0] > previous_loss:
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Due to early stop anneal lr to', lr, 'at epoch',
                          epoch)
                    ps.log('Due to early stop anneal lr to %.10f at epoch %d' %
                           (lr, epoch))

            else:
                tolerant_now -= 1

        # if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
        #     same_lr_epoch = 0
        #     tolerant_now = 0
        #     lr = lr * opt.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr
        #     print ('Due to full epochs anneal lr to',lr,'at epoch',epoch)
        #     ps.log ('Due to full epochs anneal lr to %.10f at epoch %d' % (lr, epoch))

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # previous_loss = total_loss_meter.value()[0] if opt.use_imp else mse_loss_meter.value()[0]
        previous_loss = total_loss_meter.value()[0]
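

# RateLoss / LimuRateLoss are defined elsewhere in this repo; only their constructor
# signature (threshold, weight) and their input (the sigmoid importance map, whose
# mean is logged as the "rate") are visible here. The class below is only a sketch
# of one plausible form (a hinge penalty on the mean importance above the target
# rate) and may well differ from the real implementation.
import torch as t  # already imported at the top of these scripts


class RateLossSketch(t.nn.Module):
    """Penalize the mean of the importance map when it exceeds threshold (a sketch)."""

    def __init__(self, threshold, weight):
        super(RateLossSketch, self).__init__()
        self.threshold = threshold
        self.weight = weight

    def forward(self, imp_map):
        # imp_map holds sigmoid outputs in [0, 1]; its mean is the displayed rate
        return self.weight * t.clamp(imp_map.mean() - self.threshold, min=0)
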
def train(**kwargs):
    global batch_model_id
    opt.parse(kwargs)
    if opt.use_batch_process:
        max_bp_times = len(opt.exp_ids)
        if batch_model_id >= max_bp_times:
            print('Batch Processing Done!')
            return
        else:
            print('Cur Batch Processing ID is %d/%d.' %
                  (batch_model_id + 1, max_bp_times))
            opt.r = opt.r_s[batch_model_id]
            opt.exp_id = opt.exp_ids[batch_model_id]
            opt.exp_desc = opt.exp_desc_LUT[opt.exp_id]
            opt.plot_path = "plot/plot_%d" % opt.exp_id
            print('Cur Model(exp_%d) r = %f, desc = %s. ' %
                  (opt.exp_id, opt.r, opt.exp_desc))
    opt.make_new_dirs()
    # log file
    EvalVal = opt.only_init_val and opt.init_val and not opt.test_test
    EvalTest = opt.only_init_val and opt.init_val and opt.test_test
    EvalSuffix = ""
    if EvalVal:
        EvalSuffix = "_val"
    if EvalTest:
        EvalSuffix = "_test"
    logfile_name = opt.exp_desc + time.strftime(
        "_%m_%d_%H:%M:%S") + EvalSuffix + ".log.txt"

    ps = PlotSaver(logfile_name, log_to_stdout=opt.log_to_stdout)

    # step1: Model
    # model = getattr(models, opt.model)(use_imp = opt.use_imp, n = opt.feat_num, model_name=opt.exp_desc + ("_r={r}_gm={w}".format(
    #                                                             r=opt.rate_loss_threshold,
    #                                                             w=opt.rate_loss_weight)
    #                                                           if opt.use_imp else "_no_imp"))

    model = getattr(models, opt.model)(use_imp=opt.use_imp,
                                       n=opt.feat_num,
                                       model_name=opt.exp_desc)
    # print (model)
    # pdb.set_trace()
    global use_data_parallel
    if opt.use_gpu:
        model, use_data_parallel = multiple_gpu_process(model)

    # record whether we actually run on the GPU or fall back to CPU
    opt.use_gpu = opt.use_gpu and use_data_parallel >= 0
    cudnn.benchmark = True

    # step2: Data
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_data_transforms = transforms.Compose([
        # transforms.RandomHorizontalFlip(),  TODO: reimplement this so it operates on the data and label simultaneously
        transforms.RandomCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_data_transforms = transforms.Compose(
        [transforms.CenterCrop(128),
         transforms.ToTensor(), normalize])

    caffe_data_transforms = transforms.Compose([transforms.CenterCrop(128)])
    # transforms ||  data
    train_data = ImageFilelist(
        flist=opt.train_data_list,
        transform=train_data_transforms,
    )

    val_data = ImageFilelist(
        flist=opt.val_data_list,
        prefix=opt.val_data_prefix,
        transform=val_data_transforms,
    )

    val_data_caffe = ImageFilelist(flist=opt.val_data_list,
                                   transform=caffe_data_transforms)

    test_data_caffe = ImageFilelist(flist=opt.test_data_list,
                                    transform=caffe_data_transforms)

    if opt.make_caffe_data:
        save_caffe_data(test_data_caffe)
        print('Finished making the Caffe dataset!')
        return

    # train_data = ImageCropWithBBoxMaskDataset(
    #     opt.train_data_list,
    #     train_data_transforms,
    #     contrastive_degree = opt.contrastive_degree,
    #     mse_bbox_weight = opt.input_original_bbox_weight
    # )
    # val_data = ImageCropWithBBoxMaskDataset(
    #     opt.val_data_list,
    #     val_data_transforms,
    #     contrastive_degree = opt.contrastive_degree,
    #     mse_bbox_weight = opt.input_original_bbox_weight
    # )
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)

    if opt.use_imp:
        # TODO: new rate loss
        rate_loss = RateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
        # rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)

    def weighted_mse_loss(input, target, weight):
        # weight[weight!=opt.mse_bbox_weight] = 1
        # weight[weight==opt.mse_bbox_weight] = opt.mse_bbox_weight
        # print('max val', weight.max())
        # return mse_loss(input, target)
        # weight_clone = weight.clone()
        # weight_clone[weight_clone == opt.input_original_bbox_weight] = 0
        # return t.sum(weight_clone * (input - target) ** 2)
        weight_clone = t.ones_like(weight)
        weight_clone[weight ==
                     opt.input_original_bbox_inner] = opt.mse_bbox_weight
        return t.sum(weight_clone * (input - target)**2)

    def yolo_rate_loss(imp_map, mask_r):
        return rate_loss(imp_map)
        # V2 contrastive_degree must be 0!
        # return YoloRateLossV2(mask_r, opt.rate_loss_threshold, opt.rate_loss_weight)(imp_map)

    start_epoch = 0
    decay_file_create_time = -1  # track the decay file's mtime so the same file does not repeatedly anneal the lr

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    if opt.resume:
        start_epoch = (model.module if use_data_parallel == 1 else model).load(
            None if opt.finetune else optimizer, opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training from epoch %d.' % start_epoch)
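            # Re-derive where the lr schedule would be at this epoch so the resumed
            # run continues with the same annealed lr it would otherwise have had.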
            same_lr_epoch = start_epoch % opt.lr_anneal_epochs
            decay_times = start_epoch // opt.lr_anneal_epochs
            lr = opt.lr * (opt.lr_decay**decay_times)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Decay lr %d times, now lr is %e.' % (decay_times, lr))

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    # If the resumed epoch already equals max_epoch (e.g. both are 600), extend
    # max_epoch so the init-val pass below still runs.
    if opt.only_init_val and opt.max_epoch <= start_epoch:
        opt.max_epoch = start_epoch + 2

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):

        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            total_loss_meter = mse_loss_meter
        # cur_epoch_loss refresh every epoch
        # vis.refresh_plot('cur epoch train mse loss')
        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")
        # progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), ascii=True)
        # progress_bar.set_description('epoch %d/%d, loss = 0.00' % (epoch, opt.max_epoch))

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, mse_loss, rate_loss, ps)
            else:
                mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}'
                    .format(epoch=epoch, lr=lr, val_mse_loss=mse_val_loss))

        if opt.only_init_val:
            print('Only Init Val Over!')
            return

        model.train()

        # if epoch == start_epoch + 1:
        #     print ('Start training, please inspect log file %s!' % logfile_name)
        # mask is the detection bounding box mask
        for idx, data in enumerate(train_dataloader):

            # pdb.set_trace()

            data = Variable(data)
            # mask = Variable(mask)
            # o_mask = Variable(o_mask, requires_grad=False)

            if opt.use_gpu:
                data = data.cuda(async=True)
                # mask = mask.cuda(async = True)
                # o_mask = o_mask.cuda(async = True)

            # pdb.set_trace()

            optimizer.zero_grad()
            if opt.use_imp:
                reconstructed, imp_mask_sigmoid = model(data)
            else:
                reconstructed = model(data)

            # print ('imp_mask_height', model.imp_mask_height)
            # pdb.set_trace()

            # print ('type recons', type(reconstructed.data))

            loss = mse_loss(reconstructed, data)
            # loss = mse_loss(reconstructed, data)
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                rate_loss_ = rate_loss(imp_mask_sigmoid)
                # rate_loss_ = yolo_rate_loss(imp_mask_sigmoid, mask)
                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            total_loss.backward()
            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(imp_mask_sigmoid.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else imp_mask_sigmoid.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))

                # enter debug mode
                if os.path.exists(opt.debug_file):
                    pdb.set_trace()

        if epoch % opt.save_interval == 0:
            print('Save checkpoint file of epoch %d.' % epoch)
            if use_data_parallel == 1:
                model.module.save(optimizer, epoch)
            else:
                model.save(optimizer, epoch)
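        # When use_data_parallel is set the model is presumably wrapped in nn.DataParallel,
        # so the custom save() has to be called on the underlying model.module.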

        ps.make_plot('train mse loss')
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        if epoch % opt.eval_interval == 0:
            print('Validating ...')
            # val
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, mse_loss, rate_loss, ps)
            else:
                mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
    val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            train_mse_loss=mse_loss_meter.value()[0],
                            train_rate_loss=rate_loss_meter.value()[0],
                            train_total_loss=total_loss_meter.value()[0],
                            train_rate_display=rate_display_meter.value()[0],
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                    .format(epoch=epoch,
                            lr=lr,
                            train_mse_loss=mse_loss_meter.value()[0],
                            val_mse_loss=mse_val_loss))

        # Adaptive learning-rate adjustment: for a given lr, once the training loss has
        # exceeded the previous epoch's loss opt.tolerant_max times, anneal the lr.
        if opt.use_early_adjust:
            # total_loss_meter is only updated when opt.use_imp is set; otherwise
            # fall back to the MSE meter so the comparison stays meaningful.
            cur_epoch_loss = total_loss_meter.value()[0] if opt.use_imp else mse_loss_meter.value()[0]
            if cur_epoch_loss > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Due to early stop anneal lr to %.10f at epoch %d' %
                          (lr, epoch))
                    ps.log('Due to early stop anneal lr to %.10f at epoch %d' %
                           (lr, epoch))

            else:
                # tolerant_now -= 1
                tolerant_now = 0

        if epoch % opt.lr_anneal_epochs == 0:
            # if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
            same_lr_epoch = 0
            tolerant_now = 0
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Anneal lr to %.10f at epoch %d due to full epochs.' %
                  (lr, epoch))
            ps.log('Anneal lr to %.10f at epoch %d due to full epochs.' %
                   (lr, epoch))
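        # This scheduled anneal is roughly what torch.optim.lr_scheduler.StepLR(optimizer,
        # step_size=opt.lr_anneal_epochs, gamma=opt.lr_decay) would do; it is done manually
        # here so the same `lr` variable can also be adjusted by the early-stop and
        # decay-file paths.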

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            cur_mtime = os.path.getmtime(opt.lr_decay_file)
            if cur_mtime > decay_file_create_time:
                decay_file_create_time = cur_mtime
                lr = lr * opt.lr_decay
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
                print(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))
                ps.log(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))
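        # The decay-file hook above lets the user force one extra lr decay at run time by
        # touching opt.lr_decay_file; tracking its mtime keeps it from firing every epoch.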

        previous_loss = total_loss_meter.value()[0] if opt.use_imp else mse_loss_meter.value()[0]
    if opt.use_batch_process:
        batch_model_id += 1
        train(kwargs)